MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
lower_matrix_mediumlevel.cpp
Go to the documentation of this file.
2
3#include <iostream>
4
5#include <mim/lam.h>
6
7#include "mim/def.h"
8
10#include "mim/plug/core/core.h"
13#include "mim/plug/mem/mem.h"
14
15using namespace std::string_literals;
16
17namespace mim::plug::matrix {
18
20 if (auto i = rewritten.find(def); i != rewritten.end()) return i->second;
21 auto new_def = rewrite_(def);
22 rewritten[def] = new_def;
23 return rewritten[def];
24}
25
26std::pair<Lam*, Ref> counting_for(Ref bound, DefVec acc, Ref exit, const char* name = "for_body") {
27 auto& world = bound->world();
28 auto acc_ty = world.tuple(acc)->type();
29 auto body
30 = world.mut_con({/* iter */ world.type_i32(), /* acc */ acc_ty, /* return */ world.cn(acc_ty)})->set(name);
31 auto for_loop = affine::op_for(world, world.lit_i32(0), bound, world.lit_i32(1), acc, body, exit);
32 return {body, for_loop};
33}
34
35// TODO: compare with other impala version (why is one easier than the other?)
36// TODO: replace sum_ptr by using sum as accumulator
37// TODO: extract inner loop into function (for read normalizer)
39 auto& world = def->world();
40
41 if (auto map_reduce_ax = match<matrix::map_reduce>(def); map_reduce_ax) {
42 // meta arguments:
43 // * n = out-count, (nat)
44 // * S = out-dim, (n*nat)
45 // * T = out-type (*)
46 // * m = in-count (nat)
47 // * NI = in-dim-count (m*nat)
48 // * TI = types (m**)
49 // * SI = dimensions (m*NI#i)
50 // arguments:
51 // * mem
52 // * zero = accumulator init (T)
53 // * combination function (mem, acc, inputs) -> (mem, acc)
54 // * input matrices
55 auto [mem, zero, comb, inputs] = map_reduce_ax->args<4>();
56 auto [n, S, T, m, NI, TI, SI] = map_reduce_ax->callee()->as<App>()->args<7>();
57 world.DLOG("map_reduce_ax {} : {}", map_reduce_ax, map_reduce_ax->type());
58 world.DLOG("meta variables:");
59 world.DLOG(" n = {}", n);
60 world.DLOG(" S = {}", S);
61 world.DLOG(" T = {}", T);
62 world.DLOG(" m = {}", m);
63 world.DLOG(" NI = {} : {}", NI, NI->type());
64 world.DLOG(" TI = {} : {}", TI, TI->type());
65 world.DLOG(" SI = {} : {}", SI, SI->type());
66 world.DLOG("arguments:");
67 world.DLOG(" mem = {}", mem);
68 world.DLOG(" zero = {}", zero);
69 world.DLOG(" comb = {} : {}", comb, comb->type());
70 world.DLOG(" inputs = {} : {}", inputs, inputs->type());
71
72 // Our goal is to generate a call to a function that performs:
73 // ```
74 // matrix = new matrix (n, S, T)
75 // for out_idx { // n for loops
76 // acc = zero
77 // for in_idx { // remaining loops
78 // inps = read from matrices // m-tuple
79 // acc = comb(mem, acc, inps)
80 // }
81 // write acc to output matrix
82 // }
83 // return matrix
84 // ```
85
86 absl::flat_hash_map<u64, Ref> dims; // idx ↦ nat (size bound = dimension)
87 absl::flat_hash_map<u64, Ref> raw_iterator; // idx ↦ I32
88 absl::flat_hash_map<u64, Ref> iterator; // idx ↦ %Idx (S/NI#i)
89 Vector<u64> out_indices; // output indices 0..n-1
90 Vector<u64> in_indices; // input indices ≥ n
91
92 Vector<Ref> output_dims; // i<n ↦ nat (dimension S#i)
93 Vector<DefVec> input_dims; // i<m ↦ j<NI#i ↦ nat (dimension SI#i#j)
94 Vector<u64> n_input; // i<m ↦ nat (number of dimensions of SI#i)
95
96 auto n_lit = n->isa<Lit>();
97 auto m_lit = m->isa<Lit>();
98 if (!n_lit || !m_lit) {
99 world.DLOG("n or m is not a literal");
100 return def;
101 }
102
103 auto n_nat = n_lit->get<u64>(); // number of output dimensions (in S)
104 auto m_nat = m_lit->get<u64>(); // number of input matrices
105
106 // collect output dimensions
107 world.DLOG("out dims (n) = {}", n_nat);
108 for (u64 i = 0; i < n_nat; ++i) {
109 auto dim = S->proj(n_nat, i);
110 world.DLOG("dim {} = {}", i, dim);
111 dims[i] = dim;
112 output_dims.push_back(dim);
113 }
114
115 // collect other (input) dimensions
116 world.DLOG("matrix count (m) = {}", m_nat);
117
118 for (u64 i = 0; i < m_nat; ++i) {
119 auto ni = NI->proj(m_nat, i);
120 auto ni_lit = Lit::isa(ni);
121 if (!ni_lit) {
122 world.DLOG("matrix {} has non-constant dimension count", i);
123 return def;
124 }
125 u64 ni_nat = *ni_lit;
126 world.DLOG(" dims({i}) = {}", i, ni_nat);
127 auto SI_i = SI->proj(m_nat, i);
128 DefVec input_dims_i;
129 for (u64 j = 0; j < ni_nat; ++j) {
130 auto dim = SI_i->proj(ni_nat, j);
131 world.DLOG(" dim {} {} = {}", i, j, dim);
132 // dims[i * n_nat + j] = dim;
133 input_dims_i.push_back(dim);
134 }
135 input_dims.push_back(input_dims_i);
136 n_input.push_back(ni_nat);
137 }
138
139 // extracts bounds for each index (in, out)
140 for (u64 i = 0; i < m_nat; ++i) {
141 world.DLOG("investigate {} / {}", i, m_nat);
142 auto [indices, mat] = inputs->proj(m_nat, i)->projs<2>();
143 world.DLOG(" indices {} = {}", i, indices);
144 world.DLOG(" matrix {} = {}", i, mat);
145 for (u64 j = 0; j < n_input[i]; ++j) {
146 // world.DLOG(" dimension {} / {}", j, n_input[i]);
147 auto idx = indices->proj(n_input[i], j);
148 auto idx_lit = Lit::isa(idx);
149 if (!idx_lit) {
150 world.DLOG(" index {} {} is not a literal", i, j);
151 return def;
152 }
153 u64 idx_nat = *idx_lit;
154 auto dim = input_dims[i][j];
155 world.DLOG(" index {} = {}", j, idx);
156 world.DLOG(" dim {} = {}", idx, dim);
157 if (!dims.contains(idx_nat)) {
158 dims[idx_nat] = dim;
159 world.DLOG(" {} ↦ {}", idx_nat, dim);
160 } else {
161 // assert(dims[idx_nat] == dim);
162 auto prev_dim = dims[idx_nat];
163 world.DLOG(" prev dim {} = {}", idx_nat, prev_dim);
164 // override with more precise information
165 if (auto dim_lit = dim->isa<Lit>()) {
166 if (auto prev_dim_lit = prev_dim->isa<Lit>())
167 assert(dim_lit->get<u64>() == prev_dim_lit->get<u64>() && "dimensions must be equal");
168 else
169 dims[idx_nat] = dim;
170 }
171 }
172 }
173 }
174
175 for (auto [idx, dim] : dims) {
176 world.ILOG("dim {} = {}", idx, dim);
177 if (idx < n_nat)
178 out_indices.push_back(idx);
179 else
180 in_indices.push_back(idx);
181 }
182 // sort indices to make checks easier later.
183 std::sort(out_indices.begin(), out_indices.end());
184 std::sort(in_indices.begin(), in_indices.end());
185
186 // create function `%mem.M -> [%mem.M, %matrix.Mat (n,S,T)]` to replace axiom call
187
188 auto mem_type = world.annex<mem::M>();
189 auto fun = world.mut_fun(mem_type, map_reduce_ax->type())->set("mapRed");
190
191 // assert(0);
192 auto ds_fun = direct::op_cps2ds_dep(fun);
193 world.DLOG("ds_fun {} : {}", ds_fun, ds_fun->type());
194 auto call = world.app(ds_fun, mem);
195 world.DLOG("call {} : {}", call, call->type());
196
197 // flowchart:
198 // ```
199 // -> init
200 // -> forOut1 with yieldOut1
201 // => exitOut1 = return_cont
202 // -> forOut2 with yieldOut2
203 // => exitOut2 = yieldOut1
204 // -> ...
205 // -> accumulator init
206 // -> forIn1 with yieldIn1
207 // => exitIn1 = writeCont
208 // -> forIn2 with yieldIn2
209 // => exitIn2 = yieldIn1
210 // -> ...
211 // -> read matrices
212 // -> fun
213 // => exitFun = yieldInM
214 //
215 // (return path)
216 // -> ...
217 // -> write
218 // -> yieldOutN
219 // -> ...
220 // ```
221
222 // First create the output matrix.
223 auto current_mem = mem;
224 auto [mem2, init_mat] = world.app(world.annex<matrix::init>(), {n, S, T, current_mem})->projs<2>();
225 current_mem = mem2;
226
227 // The continuation to jump to next -- initially the return continuation, reached after all output loops.
228 auto cont = fun->var(1);
229 auto current_mut = fun;
230
231 // Each of the outer loops contains the memory and matrix as accumulator (in an inner monad).
232 DefVec acc = {current_mem, init_mat};
233
234 for (auto idx : out_indices) {
235 auto for_name = world.sym("forIn_"s + std::to_string(idx));
236 auto dim_nat_def = dims[idx];
237 auto dim = world.call<core::bitcast>(world.type_i32(), dim_nat_def);
238
239 auto [body, for_call] = counting_for(dim, acc, cont, for_name);
240 auto [iter, new_acc, yield] = body->vars<3>();
241 cont = yield;
242 raw_iterator[idx] = iter;
243 iterator[idx] = world.call<core::bitcast>(world.type_idx(dim_nat_def), iter);
244 auto [new_mem, new_mat] = new_acc->projs<2>();
245 acc = {new_mem, new_mat};
246 current_mut->set(true, for_call);
247 current_mut = body;
248 }
249
250 // Now the inner loops for the inputs:
251 // Each of the inner loops contains the element accumulator and memory as accumulator (in an inner monad).
252 world.DLOG("acc at inner: {;}", acc);
253
254 // First create the accumulator.
255 auto element_acc = zero;
256 element_acc->set("acc");
257 current_mem = acc[0];
258 auto wb_matrix = acc[1];
259 assert(wb_matrix);
260 world.DLOG("wb_matrix {} : {}", wb_matrix, wb_matrix->type());
261
262 // Write back element to matrix. Set this as return after all inner loops.
263 auto write_back = mem::mut_con(T)->set("matrixWriteBack");
264 world.DLOG("write_back {} : {}", write_back, write_back->type());
265 auto [wb_mem, element_final] = write_back->vars<2>();
266
267 auto output_iterators = DefVec((size_t)n_nat, [&](u64 i) {
268 auto idx = out_indices[i];
269 if (idx != i) world.ELOG("output indices must be consecutive 0..n-1 but {} != {}", idx, i);
270 assert(idx == i && "output indices must be consecutive 0..n-1");
271 auto iter_idx_def = iterator[idx];
272 return iter_idx_def;
273 });
274 auto output_it_tuple = world.tuple(output_iterators);
275 world.DLOG("output tuple: {} : {}", output_it_tuple, output_it_tuple->type());
276
277 auto [wb_mem2, written_matrix] = world
278 .app(world.app(world.annex<matrix::insert>(), {n, S, T}),
279 {wb_mem, wb_matrix, output_it_tuple, element_final})
280 ->projs<2>();
281
282 write_back->app(true, cont, {wb_mem2, written_matrix});
283
284 // From here on the continuations take the element and memory.
285 acc = {current_mem, element_acc};
286 cont = write_back;
287
288 // TODO this is copy&paste code from above
289 for (auto idx : in_indices) {
290 auto for_name = world.sym("forIn_"s + std::to_string(idx));
291 auto dim_nat_def = dims[idx];
292 auto dim = world.call<core::bitcast>(world.type_i32(), dim_nat_def);
293
294 auto [body, for_call] = counting_for(dim, acc, cont, for_name);
295 auto [iter, new_acc, yield] = body->vars<3>();
296 cont = yield;
297 raw_iterator[idx] = iter;
298 iterator[idx] = world.call<core::bitcast>(world.type_idx(dim_nat_def), iter);
299 auto [new_mem, new_element] = new_acc->projs<2>();
300 acc = {new_mem, new_element};
301 current_mut->set(true, for_call);
302 current_mut = body;
303 }
304
305 // For testing: id in innermost loop instead of read, fun:
306 // current_mut->app(true, cont, acc);
307
308 current_mem = acc[0];
309 element_acc = acc[1];
310
311 // Read element from input matrix.
312 DefVec input_elements((size_t)m_nat);
313 for (u64 i = 0; i < m_nat; i++) {
314 // TODO: case m_nat == 1
315 auto input_i = inputs->proj(m_nat, i);
316 auto [input_idx_tup, input_matrix] = input_i->projs<2>();
317
318 world.DLOG("input matrix {} is {} : {}", i, input_matrix, input_matrix->type());
319
320 auto indices = input_idx_tup->projs(n_input[i]);
321 auto input_iterators = DefVec(n_input[i], [&](u64 j) {
322 auto idx = indices[j];
323 auto idx_lit = idx->as<Lit>()->get<u64>();
324 world.DLOG(" idx {} {} = {}", i, j, idx_lit);
325 return iterator[idx_lit];
326 });
327 auto input_it_tuple = world.tuple(input_iterators);
328
329 auto read_entry = op_read(current_mem, input_matrix, input_it_tuple);
330 world.DLOG("read_entry {} : {}", read_entry, read_entry->type());
331 auto [new_mem, element_i] = read_entry->projs<2>();
332 current_mem = new_mem;
333 input_elements[i] = element_i;
334 }
335
336 world.DLOG(" read elements {,}", input_elements);
337 world.DLOG(" fun {} : {}", fun, fun->type());
338
339 // TODO: make non-scalar or completely scalar?
340 current_mut->app(true, comb, {world.tuple({current_mem, element_acc, world.tuple(input_elements)}), cont});
341
342 return call;
343 }
344
345 return def;
346}
347
348} // namespace mim::plug::matrix
World & world() const
Definition def.cpp:415
auto projs(F f) const
Splits this Def via Def::projections into an Array (if A == -1_n) or std::array (otherwise).
Definition def.h:367
const Def * type() const
Definition def.h:248
Lam * set(Filter filter, const Def *body)
Definition lam.h:166
static std::optional< T > isa(Ref def)
Definition def.h:763
T get() const
Definition def.h:751
World & world()
Definition pass.h:296
Helper class to retrieve Infer::arg if present.
Definition def.h:86
Vars vars(const Var *var)
Definition world.h:505
Lam * mut_fun(Ref dom, Ref codom)
Definition world.h:302
Ref tuple(Defs ops)
Definition world.cpp:238
const Idx * type_idx()
Definition world.h:469
const Def * annex(Id id)
Lookup annex by Axiom::id.
Definition world.h:185
Ref app(Ref callee, Ref arg)
Definition world.cpp:186
Sym sym(std::string_view)
Definition world.cpp:77
const Def * call(Id id, Args &&... args)
Definition world.h:518
Ref type_i32()
Definition world.h:485
Ref rewrite(Ref) override
custom rewrite function memoized version of rewrite_
const Def * op_for(World &w, Ref begin, Ref end, Ref step, Defs inits, Ref body, Ref brk)
Returns a fully applied affine_for axiom.
Definition affine.h:19
const Def * op_cps2ds_dep(const Def *f)
Definition direct.h:11
The matrix Plugin
Definition matrix.h:9
std::pair< Lam *, Ref > counting_for(Ref bound, DefVec acc, Ref exit, const char *name="for_body")
const Def * op_read(Ref mem, Ref matrix, Ref idx)
Definition matrix.h:19
The mem Plugin
Definition mem.h:11
Lam * mut_con(World &w)
Yields con[mem.M].
Definition mem.h:16
Vector< const Def * > DefVec
Definition def.h:62
auto match(Ref def)
Definition axiom.h:112
uint64_t u64
Definition types.h:34