MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
lower_matrix_mediumlevel.cpp
Go to the documentation of this file.
2
3#include <iostream>
4
5#include <mim/lam.h>
6
7#include "mim/def.h"
8
10#include "mim/plug/core/core.h"
13#include "mim/plug/mem/mem.h"
14
15using namespace std::string_literals;
16
17namespace mim::plug::matrix {
18
20 if (auto i = rewritten.find(def); i != rewritten.end()) return i->second;
21 auto new_def = rewrite_(def);
22 rewritten[def] = new_def;
23 return rewritten[def];
24}
25
26std::pair<Lam*, const Def*> counting_for(const Def* bound, DefVec acc, const Def* exit, Sym name) {
27 auto& world = bound->world();
28 auto acc_ty = world.tuple(acc)->type();
29 auto body
30 = world.mut_con({/* iter */ world.type_i32(), /* acc */ acc_ty, /* return */ world.cn(acc_ty)})->set(name);
31 auto for_loop
32 = world.call<affine::For>(Defs{world.lit_i32(0), bound, world.lit_i32(1), world.tuple(acc), body, exit});
33 return {body, for_loop};
34}
35
36// TODO: compare with other impala version (why is one easier than the other?)
37// TODO: replace sum_ptr by using sum as accumulator
38// TODO: extract inner loop into function (for read normalizer)
40 if (auto map_reduce_ax = Axm::isa<matrix::map_reduce>(def); map_reduce_ax) {
41 // meta arguments:
42 // * n = out-count, (nat)
43 // * S = out-dim, (n*nat)
44 // * T = out-type (*)
45 // * m = in-count (nat)
46 // * NI = in-dim-count (m*nat)
47 // * TI = types (m**)
48 // * SI = dimensions (m*NI#i)
49 // arguments:
50 // * mem
51 // * zero = accumulator init (T)
52 // * combination function (mem, acc, inputs) -> (mem, acc)
53 // * input matrices
54 auto [mem, zero, comb, inputs] = map_reduce_ax->args<4>();
55 auto [n, S, T, m, NI, TI, SI] = map_reduce_ax->callee()->as<App>()->args<7>();
56 world().DLOG("map_reduce_ax {} : {}", map_reduce_ax, map_reduce_ax->type());
57 world().DLOG("meta variables:");
58 world().DLOG(" n = {}", n);
59 world().DLOG(" S = {}", S);
60 world().DLOG(" T = {}", T);
61 world().DLOG(" m = {}", m);
62 world().DLOG(" NI = {} : {}", NI, NI->type());
63 world().DLOG(" TI = {} : {}", TI, TI->type());
64 world().DLOG(" SI = {} : {}", SI, SI->type());
65 world().DLOG("arguments:");
66 world().DLOG(" mem = {}", mem);
67 world().DLOG(" zero = {}", zero);
68 world().DLOG(" comb = {} : {}", comb, comb->type());
69 world().DLOG(" inputs = {} : {}", inputs, inputs->type());
70
71 // Our goal is to generate a call to a function that performs:
72 // ```
73 // matrix = new matrix (n, S, T)
74 // for out_idx { // n for loops
75 // acc = zero
76 // for in_idx { // remaining loops
77 // inps = read from matrices // m-tuple
78 // acc = comb(mem, acc, inps)
79 // }
80 // write acc to output matrix
81 // }
82 // return matrix
83 // ```
84
85 absl::flat_hash_map<u64, const Def*> dims; // idx ↦ nat (size bound = dimension)
86 absl::flat_hash_map<u64, const Def*> raw_iterator; // idx ↦ I32
87 absl::flat_hash_map<u64, const Def*> iterator; // idx ↦ %Idx (S/NI#i)
88 Vector<u64> out_indices; // output indices 0..n-1
89 Vector<u64> in_indices; // input indices ≥ n
90
91 Vector<const Def*> output_dims; // i<n ↦ nat (dimension S#i)
92 Vector<DefVec> input_dims; // i<m ↦ j<NI#i ↦ nat (dimension SI#i#j)
93 Vector<u64> n_input; // i<m ↦ nat (number of dimensions of SI#i)
94
95 auto n_lit = n->isa<Lit>();
96 auto m_lit = m->isa<Lit>();
97 if (!n_lit || !m_lit) {
98 world().DLOG("n or m is not a literal");
99 return def;
100 }
101
102 auto n_nat = n_lit->get<u64>(); // number of output dimensions (in S)
103 auto m_nat = m_lit->get<u64>(); // number of input matrices
104
105 // collect output dimensions
106 world().DLOG("out dims (n) = {}", n_nat);
107 for (u64 i = 0; i < n_nat; ++i) {
108 auto dim = S->proj(n_nat, i);
109 world().DLOG("dim {} = {}", i, dim);
110 dims[i] = dim;
111 output_dims.push_back(dim);
112 }
113
114 // collect other (input) dimensions
115 world().DLOG("matrix count (m) = {}", m_nat);
116
117 for (u64 i = 0; i < m_nat; ++i) {
118 auto ni = NI->proj(m_nat, i);
119 auto ni_lit = Lit::isa(ni);
120 if (!ni_lit) {
121 world().DLOG("matrix {} has non-constant dimension count", i);
122 return def;
123 }
124 u64 ni_nat = *ni_lit;
125 world().DLOG(" dims({i}) = {}", i, ni_nat);
126 auto SI_i = SI->proj(m_nat, i);
127 DefVec input_dims_i;
128 for (u64 j = 0; j < ni_nat; ++j) {
129 auto dim = SI_i->proj(ni_nat, j);
130 world().DLOG(" dim {} {} = {}", i, j, dim);
131 // dims[i * n_nat + j] = dim;
132 input_dims_i.push_back(dim);
133 }
134 input_dims.push_back(input_dims_i);
135 n_input.push_back(ni_nat);
136 }
137
138 // extracts bounds for each index (in, out)
139 for (u64 i = 0; i < m_nat; ++i) {
140 world().DLOG("investigate {} / {}", i, m_nat);
141 auto [indices, mat] = inputs->proj(m_nat, i)->projs<2>();
142 world().DLOG(" indices {} = {}", i, indices);
143 world().DLOG(" matrix {} = {}", i, mat);
144 for (u64 j = 0; j < n_input[i]; ++j) {
145 // world().DLOG(" dimension {} / {}", j, n_input[i]);
146 auto idx = indices->proj(n_input[i], j);
147 auto idx_lit = Lit::isa(idx);
148 if (!idx_lit) {
149 world().DLOG(" index {} {} is not a literal", i, j);
150 return def;
151 }
152 u64 idx_nat = *idx_lit;
153 auto dim = input_dims[i][j];
154 world().DLOG(" index {} = {}", j, idx);
155 world().DLOG(" dim {} = {}", idx, dim);
156 if (!dims.contains(idx_nat)) {
157 dims[idx_nat] = dim;
158 world().DLOG(" {} ↦ {}", idx_nat, dim);
159 } else {
160 // assert(dims[idx_nat] == dim);
161 auto prev_dim = dims[idx_nat];
162 world().DLOG(" prev dim {} = {}", idx_nat, prev_dim);
163 // override with more precise information
164 if (auto dim_lit = dim->isa<Lit>()) {
165 if (auto prev_dim_lit = prev_dim->isa<Lit>())
166 assert(dim_lit->get<u64>() == prev_dim_lit->get<u64>() && "dimensions must be equal");
167 else
168 dims[idx_nat] = dim;
169 }
170 }
171 }
172 }
173
174 for (auto [idx, dim] : dims) {
175 world().ILOG("dim {} = {}", idx, dim);
176 if (idx < n_nat)
177 out_indices.push_back(idx);
178 else
179 in_indices.push_back(idx);
180 }
181 // sort indices to make checks easier later.
182 std::sort(out_indices.begin(), out_indices.end());
183 std::sort(in_indices.begin(), in_indices.end());
184
185 // create function `%mem.M -> [%mem.M, %matrix.Mat (n,S,T)]` to replace axm call
186
187 auto mem_type = world().annex<mem::M>();
188 auto fun = world().mut_fun(mem_type, map_reduce_ax->type())->set("mapRed");
189
190 // assert(0);
191 auto ds_fun = direct::op_cps2ds_dep(fun);
192 world().DLOG("ds_fun {} : {}", ds_fun, ds_fun->type());
193 auto call = world().app(ds_fun, mem);
194 world().DLOG("call {} : {}", call, call->type());
195
196 // flowchart:
197 // ```
198 // -> init
199 // -> forOut1 with yieldOut1
200 // => exitOut1 = return_cont
201 // -> forOut2 with yieldOut2
202 // => exitOut2 = yieldOut1
203 // -> ...
204 // -> accumulator init
205 // -> forIn1 with yieldIn1
206 // => exitIn1 = writeCont
207 // -> forIn2 with yieldIn2
208 // => exitIn2 = yieldIn1
209 // -> ...
210 // -> read matrices
211 // -> fun
212 // => exitFun = yieldInM
213 //
214 // (return path)
215 // -> ...
216 // -> write
217 // -> yieldOutN
218 // -> ...
219 // ```
220
221 // First create the output matrix.
222 auto current_mem = mem;
223 auto [mem2, init_mat] = world().app(world().annex<matrix::init>(), {n, S, T, current_mem})->projs<2>();
224 current_mem = mem2;
225
226 // The continuation to proceed with -- initially the function's return, reached after all output loops.
227 auto cont = fun->var(1);
228 auto current_mut = fun;
229
230 // Each of the outer loops contains the memory and matrix as accumulator (in an inner monad).
231 DefVec acc = {current_mem, init_mat};
232
233 for (auto idx : out_indices) {
234 auto for_name = world().sym("forIn_"s + std::to_string(idx));
235 auto dim_nat_def = dims[idx];
236 auto dim = world().call<core::bitcast>(world().type_i32(), dim_nat_def);
237
238 auto [body, for_call] = counting_for(dim, acc, cont, for_name);
239 auto [iter, new_acc, yield] = body->vars<3>();
240 cont = yield;
241 raw_iterator[idx] = iter;
242 iterator[idx] = world().call<core::bitcast>(world().type_idx(dim_nat_def), iter);
243 auto [new_mem, new_mat] = new_acc->projs<2>();
244 acc = {new_mem, new_mat};
245 current_mut->set(true, for_call);
246 current_mut = body;
247 }
248
249 // Now the inner loops for the inputs:
250 // Each of the inner loops contains the element accumulator and memory as accumulator (in an inner monad).
251 world().DLOG("acc at inner: {;}", acc);
252
253 // First create the accumulator.
254 auto element_acc = zero;
255 element_acc->set("acc");
256 current_mem = acc[0];
257 auto wb_matrix = acc[1];
258 assert(wb_matrix);
259 world().DLOG("wb_matrix {} : {}", wb_matrix, wb_matrix->type());
260
261 // Write back element to matrix. Set this as return after all inner loops.
262 auto write_back = mem::mut_con(T)->set("matrixWriteBack");
263 world().DLOG("write_back {} : {}", write_back, write_back->type());
264 auto [wb_mem, element_final] = write_back->vars<2>();
265
266 auto output_iterators = DefVec((size_t)n_nat, [&](u64 i) {
267 auto idx = out_indices[i];
268 if (idx != i) world().ELOG("output indices must be consecutive 0..n-1 but {} != {}", idx, i);
269 assert(idx == i && "output indices must be consecutive 0..n-1");
270 auto iter_idx_def = iterator[idx];
271 return iter_idx_def;
272 });
273 auto output_it_tuple = world().tuple(output_iterators);
274 world().DLOG("output tuple: {} : {}", output_it_tuple, output_it_tuple->type());
275
276 auto [wb_mem2, written_matrix] = world()
277 .app(world().app(world().annex<matrix::insert>(), {n, S, T}),
278 {wb_mem, wb_matrix, output_it_tuple, element_final})
279 ->projs<2>();
280
281 write_back->app(true, cont, {wb_mem2, written_matrix});
282
283 // From here on the continuations take the element and memory.
284 acc = {current_mem, element_acc};
285 cont = write_back;
286
287 // TODO this is copy&paste code from above
288 for (auto idx : in_indices) {
289 auto for_name = world().sym("forIn_"s + std::to_string(idx));
290 auto dim_nat_def = dims[idx];
291 auto dim = world().call<core::bitcast>(world().type_i32(), dim_nat_def);
292
293 auto [body, for_call] = counting_for(dim, acc, cont, for_name);
294 auto [iter, new_acc, yield] = body->vars<3>();
295 cont = yield;
296 raw_iterator[idx] = iter;
297 iterator[idx] = world().call<core::bitcast>(world().type_idx(dim_nat_def), iter);
298 auto [new_mem, new_element] = new_acc->projs<2>();
299 acc = {new_mem, new_element};
300 current_mut->set(true, for_call);
301 current_mut = body;
302 }
303
304 // For testing: id in innermost loop instead of read, fun:
305 // current_mut->app(true, cont, acc);
306
307 current_mem = acc[0];
308 element_acc = acc[1];
309
310 // Read element from input matrix.
311 DefVec input_elements((size_t)m_nat);
312 for (u64 i = 0; i < m_nat; i++) {
313 // TODO: case m_nat == 1
314 auto input_i = inputs->proj(m_nat, i);
315 auto [input_idx_tup, input_matrix] = input_i->projs<2>();
316
317 world().DLOG("input matrix {} is {} : {}", i, input_matrix, input_matrix->type());
318
319 auto indices = input_idx_tup->projs(n_input[i]);
320 auto input_iterators = DefVec(n_input[i], [&](u64 j) {
321 auto idx = indices[j];
322 auto idx_lit = idx->as<Lit>()->get<u64>();
323 world().DLOG(" idx {} {} = {}", i, j, idx_lit);
324 return iterator[idx_lit];
325 });
326 auto input_it_tuple = world().tuple(input_iterators);
327
328 auto read_entry = op_read(current_mem, input_matrix, input_it_tuple);
329 world().DLOG("read_entry {} : {}", read_entry, read_entry->type());
330 auto [new_mem, element_i] = read_entry->projs<2>();
331 current_mem = new_mem;
332 input_elements[i] = element_i;
333 }
334
335 world().DLOG(" read elements {,}", input_elements);
336 world().DLOG(" fun {} : {}", fun, fun->type());
337
338 // TODO: make non-scalar or completely scalar?
339 current_mut->app(true, comb, {world().tuple({current_mem, element_acc, world().tuple(input_elements)}), cont});
340
341 return call;
342 }
343
344 return def;
345}
346
347} // namespace mim::plug::matrix
static auto isa(const Def *def)
Definition axm.h:107
Base class for all Defs.
Definition def.h:251
World & world() const noexcept
Definition def.cpp:436
const Def * var(nat_t a, nat_t i) noexcept
Definition def.h:429
const Def * type() const noexcept
Yields the "raw" type of this Def (maybe nullptr).
Definition def.h:295
Lam * set(Filter filter, const Def *body)
Definition lam.cpp:29
static std::optional< T > isa(const Def *def)
Definition def.h:810
T get() const
Definition def.h:797
World & world()
Definition pass.h:326
This is a thin wrapper for absl::InlinedVector<T, N, A> which is a drop-in replacement for std::vector.
Definition vector.h:18
const Def * app(const Def *callee, const Def *arg)
Definition world.cpp:196
const Idx * type_idx()
Definition world.h:498
const Def * annex(Id id)
Lookup annex by Axm::id.
Definition world.h:181
Lam * mut_fun(const Def *dom, const Def *codom)
Definition world.h:311
const Def * tuple(Defs ops)
Definition world.cpp:288
Sym sym(std::string_view)
Definition world.cpp:75
const Def * call(Id id, Args &&... args)
Complete curried call of annexes obeying implicits.
Definition world.h:542
auto & vars()
Definition world.h:550
const Def * type_i32()
Definition world.h:514
const Def * rewrite(const Def *) override
Custom rewrite function; memoized version of rewrite_.
const Def * op_cps2ds_dep(const Def *k)
Definition direct.h:15
The matrix Plugin
Definition matrix.h:7
const Def * op_read(const Def *mem, const Def *matrix, const Def *idx)
Definition matrix.h:17
std::pair< Lam *, const Def * > counting_for(const Def *bound, DefVec acc, const Def *exit, Sym name)
The mem Plugin
Definition mem.h:11
Lam * mut_con(World &w)
Yields con[mem.M].
Definition mem.h:16
View< const Def * > Defs
Definition def.h:76
Vector< const Def * > DefVec
Definition def.h:77
constexpr decltype(auto) get(Span< T, N > span) noexcept
Definition span.h:115
uint64_t u64
Definition types.h:34