54 auto [
mem, zero, comb, inputs] = map_reduce_ax->args<4>();
55 auto [n, S, T, m, NI, TI, SI] = map_reduce_ax->callee()->as<
App>()->args<7>();
56 world().DLOG(
"map_reduce_ax {} : {}", map_reduce_ax, map_reduce_ax->type());
57 world().DLOG(
"meta variables:");
58 world().DLOG(
" n = {}", n);
59 world().DLOG(
" S = {}", S);
60 world().DLOG(
" T = {}", T);
61 world().DLOG(
" m = {}", m);
62 world().DLOG(
" NI = {} : {}", NI, NI->type());
63 world().DLOG(
" TI = {} : {}", TI, TI->type());
64 world().DLOG(
" SI = {} : {}", SI, SI->type());
65 world().DLOG(
"arguments:");
67 world().DLOG(
" zero = {}", zero);
68 world().DLOG(
" comb = {} : {}", comb, comb->type());
69 world().DLOG(
" inputs = {} : {}", inputs, inputs->type());
85 absl::flat_hash_map<u64, const Def*> dims;
86 absl::flat_hash_map<u64, const Def*> raw_iterator;
87 absl::flat_hash_map<u64, const Def*> iterator;
95 auto n_lit = n->isa<
Lit>();
96 auto m_lit = m->isa<
Lit>();
97 if (!n_lit || !m_lit) {
98 world().DLOG(
"n or m is not a literal");
102 auto n_nat = n_lit->
get<
u64>();
103 auto m_nat = m_lit->
get<
u64>();
106 world().DLOG(
"out dims (n) = {}", n_nat);
107 for (
u64 i = 0; i < n_nat; ++i) {
108 auto dim = S->proj(n_nat, i);
109 world().DLOG(
"dim {} = {}", i, dim);
111 output_dims.push_back(dim);
115 world().DLOG(
"matrix count (m) = {}", m_nat);
117 for (
u64 i = 0; i < m_nat; ++i) {
118 auto ni = NI->proj(m_nat, i);
121 world().DLOG(
"matrix {} has non-constant dimension count", i);
124 u64 ni_nat = *ni_lit;
125 world().DLOG(
" dims({i}) = {}", i, ni_nat);
126 auto SI_i = SI->proj(m_nat, i);
128 for (
u64 j = 0; j < ni_nat; ++j) {
129 auto dim = SI_i->proj(ni_nat, j);
130 world().DLOG(
" dim {} {} = {}", i, j, dim);
132 input_dims_i.push_back(dim);
134 input_dims.push_back(input_dims_i);
135 n_input.push_back(ni_nat);
139 for (
u64 i = 0; i < m_nat; ++i) {
140 world().DLOG(
"investigate {} / {}", i, m_nat);
141 auto [indices, mat] = inputs->proj(m_nat, i)->projs<2>();
142 world().DLOG(
" indices {} = {}", i, indices);
143 world().DLOG(
" matrix {} = {}", i, mat);
144 for (
u64 j = 0; j < n_input[i]; ++j) {
146 auto idx = indices->proj(n_input[i], j);
149 world().DLOG(
" index {} {} is not a literal", i, j);
152 u64 idx_nat = *idx_lit;
153 auto dim = input_dims[i][j];
154 world().DLOG(
" index {} = {}", j, idx);
155 world().DLOG(
" dim {} = {}", idx, dim);
156 if (!dims.contains(idx_nat)) {
158 world().DLOG(
" {} ↦ {}", idx_nat, dim);
161 auto prev_dim = dims[idx_nat];
162 world().DLOG(
" prev dim {} = {}", idx_nat, prev_dim);
164 if (
auto dim_lit = dim->isa<
Lit>()) {
165 if (
auto prev_dim_lit = prev_dim->isa<
Lit>())
166 assert(dim_lit->get<
u64>() == prev_dim_lit->get<
u64>() &&
"dimensions must be equal");
174 for (
auto [idx, dim] : dims) {
175 world().ILOG(
"dim {} = {}", idx, dim);
177 out_indices.push_back(idx);
179 in_indices.push_back(idx);
182 std::sort(out_indices.begin(), out_indices.end());
183 std::sort(in_indices.begin(), in_indices.end());
188 auto fun =
world().
mut_fun(mem_type, map_reduce_ax->type())->
set(
"mapRed");
192 world().DLOG(
"ds_fun {} : {}", ds_fun, ds_fun->type());
194 world().DLOG(
"call {} : {}", call, call->type());
222 auto current_mem =
mem;
223 auto [mem2, init_mat] =
world().
app(
world().annex<matrix::init>(), {n, S, T, current_mem})->projs<2>();
227 auto cont = fun->
var(1);
228 auto current_mut = fun;
231 DefVec acc = {current_mem, init_mat};
233 for (
auto idx : out_indices) {
234 auto for_name =
world().
sym(
"forIn_"s + std::to_string(idx));
235 auto dim_nat_def = dims[idx];
238 auto [body, for_call] =
counting_for(dim, acc, cont, for_name);
239 auto [iter, new_acc, yield] = body->vars<3>();
241 raw_iterator[idx] = iter;
243 auto [new_mem, new_mat] = new_acc->projs<2>();
244 acc = {new_mem, new_mat};
245 current_mut->set(
true, for_call);
251 world().DLOG(
"acc at inner: {;}", acc);
254 auto element_acc = zero;
255 element_acc->set(
"acc");
256 current_mem = acc[0];
257 auto wb_matrix = acc[1];
259 world().DLOG(
"wb_matrix {} : {}", wb_matrix, wb_matrix->type());
263 world().DLOG(
"write_back {} : {}", write_back, write_back->type());
264 auto [wb_mem, element_final] = write_back->
vars<2>();
266 auto output_iterators =
DefVec((
size_t)n_nat, [&](
u64 i) {
267 auto idx = out_indices[i];
268 if (idx != i)
world().ELOG(
"output indices must be consecutive 0..n-1 but {} != {}", idx, i);
269 assert(idx == i &&
"output indices must be consecutive 0..n-1");
270 auto iter_idx_def = iterator[idx];
273 auto output_it_tuple =
world().
tuple(output_iterators);
274 world().DLOG(
"output tuple: {} : {}", output_it_tuple, output_it_tuple->type());
276 auto [wb_mem2, written_matrix] =
world()
277 .
app(
world().app(
world().annex<matrix::insert>(), {n, S, T}),
278 {wb_mem, wb_matrix, output_it_tuple, element_final})
281 write_back->app(
true, cont, {wb_mem2, written_matrix});
284 acc = {current_mem, element_acc};
288 for (
auto idx : in_indices) {
289 auto for_name =
world().
sym(
"forIn_"s + std::to_string(idx));
290 auto dim_nat_def = dims[idx];
293 auto [body, for_call] =
counting_for(dim, acc, cont, for_name);
294 auto [iter, new_acc, yield] = body->vars<3>();
296 raw_iterator[idx] = iter;
298 auto [new_mem, new_element] = new_acc->projs<2>();
299 acc = {new_mem, new_element};
300 current_mut->set(
true, for_call);
307 current_mem = acc[0];
308 element_acc = acc[1];
311 DefVec input_elements((
size_t)m_nat);
312 for (
u64 i = 0; i < m_nat; i++) {
314 auto input_i = inputs->proj(m_nat, i);
315 auto [input_idx_tup, input_matrix] = input_i->projs<2>();
317 world().DLOG(
"input matrix {} is {} : {}", i, input_matrix, input_matrix->type());
319 auto indices = input_idx_tup->projs(n_input[i]);
320 auto input_iterators =
DefVec(n_input[i], [&](
u64 j) {
321 auto idx = indices[j];
323 world().DLOG(
" idx {} {} = {}", i, j, idx_lit);
324 return iterator[idx_lit];
326 auto input_it_tuple =
world().
tuple(input_iterators);
328 auto read_entry =
op_read(current_mem, input_matrix, input_it_tuple);
329 world().DLOG(
"read_entry {} : {}", read_entry, read_entry->type());
330 auto [new_mem, element_i] = read_entry->projs<2>();
331 current_mem = new_mem;
332 input_elements[i] = element_i;
335 world().DLOG(
" read elements {,}", input_elements);
336 world().DLOG(
" fun {} : {}", fun, fun->type());
339 current_mut->
app(
true, comb, {
world().
tuple({current_mem, element_acc,
world().
tuple(input_elements)}), cont});