55 auto [
mem, zero, comb, inputs] = map_reduce_ax->args<4>();
56 auto [n, S, T, m, NI, TI, SI] = map_reduce_ax->callee()->as<
App>()->args<7>();
57 world.DLOG(
"map_reduce_ax {} : {}", map_reduce_ax, map_reduce_ax->type());
58 world.DLOG(
"meta variables:");
59 world.DLOG(
" n = {}", n);
60 world.DLOG(
" S = {}", S);
61 world.DLOG(
" T = {}", T);
62 world.DLOG(
" m = {}", m);
63 world.DLOG(
" NI = {} : {}", NI, NI->type());
64 world.DLOG(
" TI = {} : {}", TI, TI->type());
65 world.DLOG(
" SI = {} : {}", SI, SI->type());
66 world.DLOG(
"arguments:");
68 world.DLOG(
" zero = {}", zero);
69 world.DLOG(
" comb = {} : {}", comb, comb->type());
70 world.DLOG(
" inputs = {} : {}", inputs, inputs->type());
86 absl::flat_hash_map<u64, Ref> dims;
87 absl::flat_hash_map<u64, Ref> raw_iterator;
88 absl::flat_hash_map<u64, Ref> iterator;
96 auto n_lit = n->isa<
Lit>();
97 auto m_lit = m->isa<
Lit>();
98 if (!n_lit || !m_lit) {
99 world.DLOG(
"n or m is not a literal");
103 auto n_nat = n_lit->
get<
u64>();
104 auto m_nat = m_lit->
get<
u64>();
107 world.DLOG(
"out dims (n) = {}", n_nat);
108 for (
u64 i = 0; i < n_nat; ++i) {
109 auto dim = S->proj(n_nat, i);
110 world.DLOG(
"dim {} = {}", i, dim);
112 output_dims.push_back(dim);
116 world.DLOG(
"matrix count (m) = {}", m_nat);
118 for (
u64 i = 0; i < m_nat; ++i) {
119 auto ni = NI->proj(m_nat, i);
122 world.DLOG(
"matrix {} has non-constant dimension count", i);
125 u64 ni_nat = *ni_lit;
126 world.DLOG(
" dims({i}) = {}", i, ni_nat);
127 auto SI_i = SI->proj(m_nat, i);
129 for (
u64 j = 0; j < ni_nat; ++j) {
130 auto dim = SI_i->proj(ni_nat, j);
131 world.DLOG(
" dim {} {} = {}", i, j, dim);
133 input_dims_i.push_back(dim);
135 input_dims.push_back(input_dims_i);
136 n_input.push_back(ni_nat);
140 for (
u64 i = 0; i < m_nat; ++i) {
141 world.DLOG(
"investigate {} / {}", i, m_nat);
142 auto [indices, mat] = inputs->proj(m_nat, i)->projs<2>();
143 world.DLOG(
" indices {} = {}", i, indices);
144 world.DLOG(
" matrix {} = {}", i, mat);
145 for (
u64 j = 0; j < n_input[i]; ++j) {
147 auto idx = indices->proj(n_input[i], j);
150 world.DLOG(
" index {} {} is not a literal", i, j);
153 u64 idx_nat = *idx_lit;
154 auto dim = input_dims[i][j];
155 world.DLOG(
" index {} = {}", j, idx);
156 world.DLOG(
" dim {} = {}", idx, dim);
157 if (!dims.contains(idx_nat)) {
159 world.DLOG(
" {} ↦ {}", idx_nat, dim);
162 auto prev_dim = dims[idx_nat];
163 world.DLOG(
" prev dim {} = {}", idx_nat, prev_dim);
165 if (
auto dim_lit = dim->isa<
Lit>()) {
166 if (
auto prev_dim_lit = prev_dim->isa<
Lit>())
167 assert(dim_lit->get<
u64>() == prev_dim_lit->get<
u64>() &&
"dimensions must be equal");
175 for (
auto [idx, dim] : dims) {
176 world.ILOG(
"dim {} = {}", idx, dim);
178 out_indices.push_back(idx);
180 in_indices.push_back(idx);
183 std::sort(out_indices.begin(), out_indices.end());
184 std::sort(in_indices.begin(), in_indices.end());
189 auto fun =
world.
mut_fun(mem_type, map_reduce_ax->type())->
set(
"mapRed");
193 world.DLOG(
"ds_fun {} : {}", ds_fun, ds_fun->type());
195 world.DLOG(
"call {} : {}", call, call->type());
223 auto current_mem =
mem;
228 auto cont = fun->var(1);
229 auto current_mut = fun;
232 DefVec acc = {current_mem, init_mat};
234 for (
auto idx : out_indices) {
235 auto for_name =
world.
sym(
"forIn_"s + std::to_string(idx));
236 auto dim_nat_def = dims[idx];
239 auto [body, for_call] =
counting_for(dim, acc, cont, for_name);
240 auto [iter, new_acc, yield] = body->vars<3>();
242 raw_iterator[idx] = iter;
244 auto [new_mem, new_mat] = new_acc->projs<2>();
245 acc = {new_mem, new_mat};
246 current_mut->set(
true, for_call);
252 world.DLOG(
"acc at inner: {;}", acc);
255 auto element_acc = zero;
256 element_acc->set(
"acc");
257 current_mem = acc[0];
258 auto wb_matrix = acc[1];
260 world.DLOG(
"wb_matrix {} : {}", wb_matrix, wb_matrix->type());
264 world.DLOG(
"write_back {} : {}", write_back, write_back->type());
265 auto [wb_mem, element_final] = write_back->
vars<2>();
267 auto output_iterators =
DefVec((
size_t)n_nat, [&](
u64 i) {
268 auto idx = out_indices[i];
269 if (idx != i)
world.ELOG(
"output indices must be consecutive 0..n-1 but {} != {}", idx, i);
270 assert(idx == i &&
"output indices must be consecutive 0..n-1");
271 auto iter_idx_def = iterator[idx];
274 auto output_it_tuple =
world.
tuple(output_iterators);
275 world.DLOG(
"output tuple: {} : {}", output_it_tuple, output_it_tuple->type());
277 auto [wb_mem2, written_matrix] =
world
279 {wb_mem, wb_matrix, output_it_tuple, element_final})
282 write_back->app(
true, cont, {wb_mem2, written_matrix});
285 acc = {current_mem, element_acc};
289 for (
auto idx : in_indices) {
290 auto for_name =
world.
sym(
"forIn_"s + std::to_string(idx));
291 auto dim_nat_def = dims[idx];
294 auto [body, for_call] =
counting_for(dim, acc, cont, for_name);
295 auto [iter, new_acc, yield] = body->vars<3>();
297 raw_iterator[idx] = iter;
299 auto [new_mem, new_element] = new_acc->projs<2>();
300 acc = {new_mem, new_element};
301 current_mut->set(
true, for_call);
308 current_mem = acc[0];
309 element_acc = acc[1];
312 DefVec input_elements((
size_t)m_nat);
313 for (
u64 i = 0; i < m_nat; i++) {
315 auto input_i = inputs->proj(m_nat, i);
316 auto [input_idx_tup, input_matrix] = input_i->projs<2>();
318 world.DLOG(
"input matrix {} is {} : {}", i, input_matrix, input_matrix->type());
320 auto indices = input_idx_tup->projs(n_input[i]);
321 auto input_iterators =
DefVec(n_input[i], [&](
u64 j) {
322 auto idx = indices[j];
323 auto idx_lit = idx->as<
Lit>()->get<u64>();
324 world.DLOG(
" idx {} {} = {}", i, j, idx_lit);
325 return iterator[idx_lit];
327 auto input_it_tuple =
world.
tuple(input_iterators);
329 auto read_entry =
op_read(current_mem, input_matrix, input_it_tuple);
330 world.DLOG(
"read_entry {} : {}", read_entry, read_entry->type());
331 auto [new_mem, element_i] = read_entry->projs<2>();
332 current_mem = new_mem;
333 input_elements[i] = element_i;
336 world.DLOG(
" read elements {,}", input_elements);
337 world.DLOG(
" fun {} : {}", fun, fun->type());