MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
lower_matrix_mediumlevel.cpp
Go to the documentation of this file.
2
3#include <iostream>
4
5#include <mim/lam.h>
6
7#include "mim/def.h"
8
10#include "mim/plug/core/core.h"
13#include "mim/plug/mem/mem.h"
14
15using namespace std::string_literals;
16
17namespace mim::plug::matrix {
18
20 if (auto i = rewritten.find(def); i != rewritten.end()) return i->second;
21 auto new_def = rewrite_(def);
22 rewritten[def] = new_def;
23 return rewritten[def];
24}
25
26std::pair<Lam*, Ref> counting_for(Ref bound, DefVec acc, Ref exit, const char* name = "for_body") {
27 auto& world = bound->world();
28 auto acc_ty = world.tuple(acc)->type();
29 auto body
30 = world.mut_con({/* iter */ world.type_i32(), /* acc */ acc_ty, /* return */ world.cn(acc_ty)})->set(name);
31 auto for_loop = affine::op_for(world, world.lit_i32(0), bound, world.lit_i32(1), acc, body, exit);
32 return {body, for_loop};
33}
34
35// TODO: compare with other impala version (why is one easier than the other?)
36// TODO: replace sum_ptr by using sum as accumulator
37// TODO: extract inner loop into function (for read normalizer)
39 if (auto map_reduce_ax = match<matrix::map_reduce>(def); map_reduce_ax) {
40 // meta arguments:
41 // * n = out-count, (nat)
42 // * S = out-dim, (n*nat)
43 // * T = out-type (*)
44 // * m = in-count (nat)
45 // * NI = in-dim-count (m*nat)
46 // * TI = types (m**)
47 // * SI = dimensions (m*NI#i)
48 // arguments:
49 // * mem
50 // * zero = accumulator init (T)
51 // * combination function (mem, acc, inputs) -> (mem, acc)
52 // * input matrixes
53 auto [mem, zero, comb, inputs] = map_reduce_ax->args<4>();
54 auto [n, S, T, m, NI, TI, SI] = map_reduce_ax->callee()->as<App>()->args<7>();
55 world().DLOG("map_reduce_ax {} : {}", map_reduce_ax, map_reduce_ax->type());
56 world().DLOG("meta variables:");
57 world().DLOG(" n = {}", n);
58 world().DLOG(" S = {}", S);
59 world().DLOG(" T = {}", T);
60 world().DLOG(" m = {}", m);
61 world().DLOG(" NI = {} : {}", NI, NI->type());
62 world().DLOG(" TI = {} : {}", TI, TI->type());
63 world().DLOG(" SI = {} : {}", SI, SI->type());
64 world().DLOG("arguments:");
65 world().DLOG(" mem = {}", mem);
66 world().DLOG(" zero = {}", zero);
67 world().DLOG(" comb = {} : {}", comb, comb->type());
68 world().DLOG(" inputs = {} : {}", inputs, inputs->type());
69
70 // Our goal is to generate a call to a function that performs:
71 // ```
72 // matrix = new matrix (n, S, T)
73 // for out_idx { // n for loops
74 // acc = zero
75 // for in_idx { // remaining loops
76 // inps = read from matrices // m-tuple
77 // acc = comb(mem, acc, inps)
78 // }
79 // write acc to output matrix
80 // }
81 // return matrix
82 // ```
83
84 absl::flat_hash_map<u64, Ref> dims; // idx ↦ nat (size bound = dimension)
85 absl::flat_hash_map<u64, Ref> raw_iterator; // idx ↦ I32
86 absl::flat_hash_map<u64, Ref> iterator; // idx ↦ %Idx (S/NI#i)
87 Vector<u64> out_indices; // output indices 0..n-1
88 Vector<u64> in_indices; // input indices ≥ n
89
90 Vector<Ref> output_dims; // i<n ↦ nat (dimension S#i)
91 Vector<DefVec> input_dims; // i<m ↦ j<NI#i ↦ nat (dimension SI#i#j)
92 Vector<u64> n_input; // i<m ↦ nat (number of dimensions of SI#i)
93
94 auto n_lit = n->isa<Lit>();
95 auto m_lit = m->isa<Lit>();
96 if (!n_lit || !m_lit) {
97 world().DLOG("n or m is not a literal");
98 return def;
99 }
100
101 auto n_nat = n_lit->get<u64>(); // number of output dimensions (in S)
102 auto m_nat = m_lit->get<u64>(); // number of input matrices
103
104 // collect output dimensions
105 world().DLOG("out dims (n) = {}", n_nat);
106 for (u64 i = 0; i < n_nat; ++i) {
107 auto dim = S->proj(n_nat, i);
108 world().DLOG("dim {} = {}", i, dim);
109 dims[i] = dim;
110 output_dims.push_back(dim);
111 }
112
113 // collect other (input) dimensions
114 world().DLOG("matrix count (m) = {}", m_nat);
115
116 for (u64 i = 0; i < m_nat; ++i) {
117 auto ni = NI->proj(m_nat, i);
118 auto ni_lit = Lit::isa(ni);
119 if (!ni_lit) {
120 world().DLOG("matrix {} has non-constant dimension count", i);
121 return def;
122 }
123 u64 ni_nat = *ni_lit;
124 world().DLOG(" dims({i}) = {}", i, ni_nat);
125 auto SI_i = SI->proj(m_nat, i);
126 DefVec input_dims_i;
127 for (u64 j = 0; j < ni_nat; ++j) {
128 auto dim = SI_i->proj(ni_nat, j);
129 world().DLOG(" dim {} {} = {}", i, j, dim);
130 // dims[i * n_nat + j] = dim;
131 input_dims_i.push_back(dim);
132 }
133 input_dims.push_back(input_dims_i);
134 n_input.push_back(ni_nat);
135 }
136
137 // extracts bounds for each index (in, out)
138 for (u64 i = 0; i < m_nat; ++i) {
139 world().DLOG("investigate {} / {}", i, m_nat);
140 auto [indices, mat] = inputs->proj(m_nat, i)->projs<2>();
141 world().DLOG(" indices {} = {}", i, indices);
142 world().DLOG(" matrix {} = {}", i, mat);
143 for (u64 j = 0; j < n_input[i]; ++j) {
144 // world().DLOG(" dimension {} / {}", j, n_input[i]);
145 auto idx = indices->proj(n_input[i], j);
146 auto idx_lit = Lit::isa(idx);
147 if (!idx_lit) {
148 world().DLOG(" index {} {} is not a literal", i, j);
149 return def;
150 }
151 u64 idx_nat = *idx_lit;
152 auto dim = input_dims[i][j];
153 world().DLOG(" index {} = {}", j, idx);
154 world().DLOG(" dim {} = {}", idx, dim);
155 if (!dims.contains(idx_nat)) {
156 dims[idx_nat] = dim;
157 world().DLOG(" {} ↦ {}", idx_nat, dim);
158 } else {
159 // assert(dims[idx_nat] == dim);
160 auto prev_dim = dims[idx_nat];
161 world().DLOG(" prev dim {} = {}", idx_nat, prev_dim);
162 // override with more precise information
163 if (auto dim_lit = dim->isa<Lit>()) {
164 if (auto prev_dim_lit = prev_dim->isa<Lit>())
165 assert(dim_lit->get<u64>() == prev_dim_lit->get<u64>() && "dimensions must be equal");
166 else
167 dims[idx_nat] = dim;
168 }
169 }
170 }
171 }
172
173 for (auto [idx, dim] : dims) {
174 world().ILOG("dim {} = {}", idx, dim);
175 if (idx < n_nat)
176 out_indices.push_back(idx);
177 else
178 in_indices.push_back(idx);
179 }
180 // sort indices to make checks easier later.
181 std::sort(out_indices.begin(), out_indices.end());
182 std::sort(in_indices.begin(), in_indices.end());
183
184 // create function `%mem.M -> [%mem.M, %matrix.Mat (n,S,T)]` to replace axiom call
185
186 auto mem_type = world().annex<mem::M>();
187 auto fun = world().mut_fun(mem_type, map_reduce_ax->type())->set("mapRed");
188
189 // assert(0);
190 auto ds_fun = direct::op_cps2ds_dep(fun);
191 world().DLOG("ds_fun {} : {}", ds_fun, ds_fun->type());
192 auto call = world().app(ds_fun, mem);
193 world().DLOG("call {} : {}", call, call->type());
194
195 // flowchart:
196 // ```
197 // -> init
198 // -> forOut1 with yieldOut1
199 // => exitOut1 = return_cont
200 // -> forOut2 with yieldOut2
201 // => exitOut2 = yieldOut1
202 // -> ...
203 // -> accumulator init
204 // -> forIn1 with yieldIn1
205 // => exitIn1 = writeCont
206 // -> forIn2 with yieldIn2
207 // => exitIn2 = yieldIn1
208 // -> ...
209 // -> read matrices
210 // -> fun
211 // => exitFun = yieldInM
212 //
213 // (return path)
214 // -> ...
215 // -> write
216 // -> yieldOutN
217 // -> ...
218 // ```
219
220 // First create the output matrix.
221 auto current_mem = mem;
222 auto [mem2, init_mat] = world().app(world().annex<matrix::init>(), {n, S, T, current_mem})->projs<2>();
223 current_mem = mem2;
224
225 // The function on where to continue -- return after all output loops.
226 auto cont = fun->var(1);
227 auto current_mut = fun;
228
229 // Each of the outer loops contains the memory and matrix as accumulator (in an inner monad).
230 DefVec acc = {current_mem, init_mat};
231
232 for (auto idx : out_indices) {
233 auto for_name = world().sym("forIn_"s + std::to_string(idx));
234 auto dim_nat_def = dims[idx];
235 auto dim = world().call<core::bitcast>(world().type_i32(), dim_nat_def);
236
237 auto [body, for_call] = counting_for(dim, acc, cont, for_name);
238 auto [iter, new_acc, yield] = body->vars<3>();
239 cont = yield;
240 raw_iterator[idx] = iter;
241 iterator[idx] = world().call<core::bitcast>(world().type_idx(dim_nat_def), iter);
242 auto [new_mem, new_mat] = new_acc->projs<2>();
243 acc = {new_mem, new_mat};
244 current_mut->set(true, for_call);
245 current_mut = body;
246 }
247
248 // Now the inner loops for the inputs:
249 // Each of the inner loops contains the element accumulator and memory as accumulator (in an inner monad).
250 world().DLOG("acc at inner: {;}", acc);
251
252 // First create the accumulator.
253 auto element_acc = zero;
254 element_acc->set("acc");
255 current_mem = acc[0];
256 auto wb_matrix = acc[1];
257 assert(wb_matrix);
258 world().DLOG("wb_matrix {} : {}", wb_matrix, wb_matrix->type());
259
260 // Write back element to matrix. Set this as return after all inner loops.
261 auto write_back = mem::mut_con(T)->set("matrixWriteBack");
262 world().DLOG("write_back {} : {}", write_back, write_back->type());
263 auto [wb_mem, element_final] = write_back->vars<2>();
264
265 auto output_iterators = DefVec((size_t)n_nat, [&](u64 i) {
266 auto idx = out_indices[i];
267 if (idx != i) world().ELOG("output indices must be consecutive 0..n-1 but {} != {}", idx, i);
268 assert(idx == i && "output indices must be consecutive 0..n-1");
269 auto iter_idx_def = iterator[idx];
270 return iter_idx_def;
271 });
272 auto output_it_tuple = world().tuple(output_iterators);
273 world().DLOG("output tuple: {} : {}", output_it_tuple, output_it_tuple->type());
274
275 auto [wb_mem2, written_matrix] = world()
276 .app(world().app(world().annex<matrix::insert>(), {n, S, T}),
277 {wb_mem, wb_matrix, output_it_tuple, element_final})
278 ->projs<2>();
279
280 write_back->app(true, cont, {wb_mem2, written_matrix});
281
282 // From here on the continuations take the element and memory.
283 acc = {current_mem, element_acc};
284 cont = write_back;
285
286 // TODO this is copy&paste code from above
287 for (auto idx : in_indices) {
288 auto for_name = world().sym("forIn_"s + std::to_string(idx));
289 auto dim_nat_def = dims[idx];
290 auto dim = world().call<core::bitcast>(world().type_i32(), dim_nat_def);
291
292 auto [body, for_call] = counting_for(dim, acc, cont, for_name);
293 auto [iter, new_acc, yield] = body->vars<3>();
294 cont = yield;
295 raw_iterator[idx] = iter;
296 iterator[idx] = world().call<core::bitcast>(world().type_idx(dim_nat_def), iter);
297 auto [new_mem, new_element] = new_acc->projs<2>();
298 acc = {new_mem, new_element};
299 current_mut->set(true, for_call);
300 current_mut = body;
301 }
302
303 // For testing: id in innermost loop instead of read, fun:
304 // current_mut->app(true, cont, acc);
305
306 current_mem = acc[0];
307 element_acc = acc[1];
308
309 // Read element from input matrix.
310 DefVec input_elements((size_t)m_nat);
311 for (u64 i = 0; i < m_nat; i++) {
312 // TODO: case m_nat == 1
313 auto input_i = inputs->proj(m_nat, i);
314 auto [input_idx_tup, input_matrix] = input_i->projs<2>();
315
316 world().DLOG("input matrix {} is {} : {}", i, input_matrix, input_matrix->type());
317
318 auto indices = input_idx_tup->projs(n_input[i]);
319 auto input_iterators = DefVec(n_input[i], [&](u64 j) {
320 auto idx = indices[j];
321 auto idx_lit = idx->as<Lit>()->get<u64>();
322 world().DLOG(" idx {} {} = {}", i, j, idx_lit);
323 return iterator[idx_lit];
324 });
325 auto input_it_tuple = world().tuple(input_iterators);
326
327 auto read_entry = op_read(current_mem, input_matrix, input_it_tuple);
328 world().DLOG("read_entry {} : {}", read_entry, read_entry->type());
329 auto [new_mem, element_i] = read_entry->projs<2>();
330 current_mem = new_mem;
331 input_elements[i] = element_i;
332 }
333
334 world().DLOG(" read elements {,}", input_elements);
335 world().DLOG(" fun {} : {}", fun, fun->type());
336
337 // TODO: make non-scalar or completely scalar?
338 current_mut->app(true, comb, {world().tuple({current_mem, element_acc, world().tuple(input_elements)}), cont});
339
340 return call;
341 }
342
343 return def;
344}
345
346} // namespace mim::plug::matrix
Ref type() const
Definition def.h:251
World & world() const
Definition def.cpp:411
Ref var(nat_t a, nat_t i)
Definition def.h:395
auto projs(F f) const
Splits this Def via Def::projections into an Array (if A == std::dynamic_extent) or std::array (otherwise).
Definition def.h:361
Lam * set(Filter filter, Ref body)
Definition lam.h:173
static std::optional< T > isa(Ref def)
Definition def.h:758
T get() const
Definition def.h:746
World & world()
Definition pass.h:296
Helper class to retrieve Infer::arg if present.
Definition def.h:86
Lam * mut_fun(Ref dom, Ref codom)
Definition world.h:302
Ref tuple(Defs ops)
Definition world.cpp:233
const Idx * type_idx()
Definition world.h:468
const Def * annex(Id id)
Look up an annex by its Axiom::id.
Definition world.h:185
Sym sym(std::string_view)
Definition world.cpp:77
Ref type_i32()
Definition world.h:484
const Def * call(Id id, Args &&... args)
Complete curried call of annexes obeying implicits.
Definition world.h:507
Ref app(Ref callee, Ref arg)
Definition world.cpp:170
auto & vars()
Definition world.h:515
Ref rewrite(Ref) override
Custom rewrite function: memoized version of rewrite_.
const Def * op_for(World &w, Ref begin, Ref end, Ref step, Defs inits, Ref body, Ref brk)
Returns a fully applied affine_for axiom.
Definition affine.h:19
Ref op_cps2ds_dep(Ref k)
Definition direct.h:15
The matrix Plugin
Definition matrix.h:9
std::pair< Lam *, Ref > counting_for(Ref bound, DefVec acc, Ref exit, const char *name="for_body")
const Def * op_read(Ref mem, Ref matrix, Ref idx)
Definition matrix.h:19
The mem Plugin
Definition mem.h:11
Lam * mut_con(World &w)
Yields con[mem.M].
Definition mem.h:16
Vector< const Def * > DefVec
Definition def.h:62
auto match(Ref def)
Definition axiom.h:112
uint64_t u64
Definition types.h:34