53 auto [
mem, zero, comb, inputs] = map_reduce_ax->args<4>();
54 auto [n, S, T, m, NI, TI, SI] = map_reduce_ax->callee()->as<
App>()->args<7>();
55 world().DLOG(
"map_reduce_ax {} : {}", map_reduce_ax, map_reduce_ax->type());
56 world().DLOG(
"meta variables:");
57 world().DLOG(
" n = {}", n);
58 world().DLOG(
" S = {}", S);
59 world().DLOG(
" T = {}", T);
60 world().DLOG(
" m = {}", m);
61 world().DLOG(
" NI = {} : {}", NI, NI->type());
62 world().DLOG(
" TI = {} : {}", TI, TI->type());
63 world().DLOG(
" SI = {} : {}", SI, SI->type());
64 world().DLOG(
"arguments:");
66 world().DLOG(
" zero = {}", zero);
67 world().DLOG(
" comb = {} : {}", comb, comb->type());
68 world().DLOG(
" inputs = {} : {}", inputs, inputs->type());
84 absl::flat_hash_map<u64, Ref> dims;
85 absl::flat_hash_map<u64, Ref> raw_iterator;
86 absl::flat_hash_map<u64, Ref> iterator;
94 auto n_lit = n->isa<
Lit>();
95 auto m_lit = m->isa<
Lit>();
96 if (!n_lit || !m_lit) {
97 world().DLOG(
"n or m is not a literal");
101 auto n_nat = n_lit->
get<
u64>();
102 auto m_nat = m_lit->
get<
u64>();
105 world().DLOG(
"out dims (n) = {}", n_nat);
106 for (
u64 i = 0; i < n_nat; ++i) {
107 auto dim = S->proj(n_nat, i);
108 world().DLOG(
"dim {} = {}", i, dim);
110 output_dims.push_back(dim);
114 world().DLOG(
"matrix count (m) = {}", m_nat);
116 for (
u64 i = 0; i < m_nat; ++i) {
117 auto ni = NI->proj(m_nat, i);
120 world().DLOG(
"matrix {} has non-constant dimension count", i);
123 u64 ni_nat = *ni_lit;
124 world().DLOG(
" dims({i}) = {}", i, ni_nat);
125 auto SI_i = SI->proj(m_nat, i);
127 for (
u64 j = 0; j < ni_nat; ++j) {
128 auto dim = SI_i->proj(ni_nat, j);
129 world().DLOG(
" dim {} {} = {}", i, j, dim);
131 input_dims_i.push_back(dim);
133 input_dims.push_back(input_dims_i);
134 n_input.push_back(ni_nat);
138 for (
u64 i = 0; i < m_nat; ++i) {
139 world().DLOG(
"investigate {} / {}", i, m_nat);
140 auto [indices, mat] = inputs->proj(m_nat, i)->projs<2>();
141 world().DLOG(
" indices {} = {}", i, indices);
142 world().DLOG(
" matrix {} = {}", i, mat);
143 for (
u64 j = 0; j < n_input[i]; ++j) {
145 auto idx = indices->proj(n_input[i], j);
148 world().DLOG(
" index {} {} is not a literal", i, j);
151 u64 idx_nat = *idx_lit;
152 auto dim = input_dims[i][j];
153 world().DLOG(
" index {} = {}", j, idx);
154 world().DLOG(
" dim {} = {}", idx, dim);
155 if (!dims.contains(idx_nat)) {
157 world().DLOG(
" {} ↦ {}", idx_nat, dim);
160 auto prev_dim = dims[idx_nat];
161 world().DLOG(
" prev dim {} = {}", idx_nat, prev_dim);
163 if (
auto dim_lit = dim->isa<
Lit>()) {
164 if (
auto prev_dim_lit = prev_dim->isa<
Lit>())
165 assert(dim_lit->get<
u64>() == prev_dim_lit->get<
u64>() &&
"dimensions must be equal");
173 for (
auto [idx, dim] : dims) {
174 world().ILOG(
"dim {} = {}", idx, dim);
176 out_indices.push_back(idx);
178 in_indices.push_back(idx);
181 std::sort(out_indices.begin(), out_indices.end());
182 std::sort(in_indices.begin(), in_indices.end());
187 auto fun =
world().
mut_fun(mem_type, map_reduce_ax->type())->
set(
"mapRed");
191 world().DLOG(
"ds_fun {} : {}", ds_fun, ds_fun->type());
193 world().DLOG(
"call {} : {}", call, call->type());
221 auto current_mem =
mem;
222 auto [mem2, init_mat] =
world().
app(
world().annex<matrix::init>(), {n, S, T, current_mem})->projs<2>();
226 auto cont = fun->
var(1);
227 auto current_mut = fun;
230 DefVec acc = {current_mem, init_mat};
232 for (
auto idx : out_indices) {
233 auto for_name =
world().
sym(
"forIn_"s + std::to_string(idx));
234 auto dim_nat_def = dims[idx];
237 auto [body, for_call] =
counting_for(dim, acc, cont, for_name);
238 auto [iter, new_acc, yield] = body->vars<3>();
240 raw_iterator[idx] = iter;
242 auto [new_mem, new_mat] = new_acc->
projs<2>();
243 acc = {new_mem, new_mat};
244 current_mut->set(
true, for_call);
250 world().DLOG(
"acc at inner: {;}", acc);
253 auto element_acc = zero;
254 element_acc->set(
"acc");
255 current_mem = acc[0];
256 auto wb_matrix = acc[1];
258 world().DLOG(
"wb_matrix {} : {}", wb_matrix, wb_matrix->type());
262 world().DLOG(
"write_back {} : {}", write_back, write_back->type());
263 auto [wb_mem, element_final] = write_back->
vars<2>();
265 auto output_iterators =
DefVec((
size_t)n_nat, [&](
u64 i) {
266 auto idx = out_indices[i];
267 if (idx != i)
world().ELOG(
"output indices must be consecutive 0..n-1 but {} != {}", idx, i);
268 assert(idx == i &&
"output indices must be consecutive 0..n-1");
269 auto iter_idx_def = iterator[idx];
272 auto output_it_tuple =
world().
tuple(output_iterators);
273 world().DLOG(
"output tuple: {} : {}", output_it_tuple, output_it_tuple->type());
275 auto [wb_mem2, written_matrix] =
world()
276 .
app(
world().app(
world().annex<matrix::insert>(), {n, S, T}),
277 {wb_mem, wb_matrix, output_it_tuple, element_final})
280 write_back->app(
true, cont, {wb_mem2, written_matrix});
283 acc = {current_mem, element_acc};
287 for (
auto idx : in_indices) {
288 auto for_name =
world().
sym(
"forIn_"s + std::to_string(idx));
289 auto dim_nat_def = dims[idx];
292 auto [body, for_call] =
counting_for(dim, acc, cont, for_name);
293 auto [iter, new_acc, yield] = body->vars<3>();
295 raw_iterator[idx] = iter;
297 auto [new_mem, new_element] = new_acc->
projs<2>();
298 acc = {new_mem, new_element};
299 current_mut->set(
true, for_call);
306 current_mem = acc[0];
307 element_acc = acc[1];
310 DefVec input_elements((
size_t)m_nat);
311 for (
u64 i = 0; i < m_nat; i++) {
313 auto input_i = inputs->proj(m_nat, i);
314 auto [input_idx_tup, input_matrix] = input_i->projs<2>();
316 world().DLOG(
"input matrix {} is {} : {}", i, input_matrix, input_matrix->type());
318 auto indices = input_idx_tup->projs(n_input[i]);
319 auto input_iterators =
DefVec(n_input[i], [&](
u64 j) {
320 auto idx = indices[j];
321 auto idx_lit = idx->as<
Lit>()->get<u64>();
322 world().DLOG(
" idx {} {} = {}", i, j, idx_lit);
323 return iterator[idx_lit];
325 auto input_it_tuple =
world().
tuple(input_iterators);
327 auto read_entry =
op_read(current_mem, input_matrix, input_it_tuple);
328 world().DLOG(
"read_entry {} : {}", read_entry, read_entry->type());
329 auto [new_mem, element_i] = read_entry->projs<2>();
330 current_mem = new_mem;
331 input_elements[i] = element_i;
334 world().DLOG(
" read elements {,}", input_elements);
335 world().DLOG(
" fun {} : {}", fun, fun->type());
338 current_mut->
app(
true, comb, {
world().
tuple({current_mem, element_acc,
world().
tuple(input_elements)}), cont});