MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
lower_matrix_lowlevel.cpp
Go to the documentation of this file.
2
3#include <cassert>
4
5#include <mim/axm.h>
6#include <mim/def.h>
7#include <mim/lam.h>
8#include "mim/rewrite.h"
9
11#include "mim/plug/core/core.h"
14#include "mim/plug/mem/mem.h"
15
16namespace mim::plug::matrix {
17
18namespace {
19
20const Def* op_lea_tuple(const Def* arr, const Def* tuple) {
21 auto& world = arr->world();
22 world.DLOG("op_lea_tuple arr {} : {}", arr, arr->type());
23 auto n = tuple->num_projs();
24 auto element = arr;
25 for (size_t i = 0; i < n; ++i)
26 element = mem::op_lea(element, tuple->proj(n, i));
27 return element;
28}
29
30const Def* op_pack_tuple(u64 n, const Def* tuple, const Def* val) {
31 auto& world = val->world();
32 // TODO: find out why num_projs is wrong
33 auto element = val;
34 for (int i = n - 1; i >= 0; i--) {
35 auto dim = tuple->proj(n, i);
36 element = world.pack(dim, element);
37 }
38 world.DLOG("op_pack_tuple: {} -> {}", val, element);
39 world.DLOG(" for tuple: {} : {}", tuple, tuple->type());
40 return element;
41}
42
43const Def* arr_ty_of_matrix_ty(const Def* S, const Def* T) {
44 auto& world = S->world();
45 auto n = S->num_projs();
46 auto arr_ty = T;
47 for (int i = n - 1; i >= 0; i--) {
48 auto dim = S->proj(n, i);
49 arr_ty = world.arr(dim, arr_ty);
50 }
51 return arr_ty;
52}
53
54} // namespace
55
// NOTE(review): this extraction lost the function's signature line (original source line 56).
// Per the Doxygen tooltip residue on this page it is the `const Def* rewrite_imm_App(const App*) override`
// of the phase; confirm against the real source.
//
// Lowers the low-level matrix axioms that remain at this point to `mem` operations on nested arrays:
// - `%matrix.Mat` (type)  -> `mem::Ptr` to the nested array type, address space 0
// - `%matrix.init`        -> `mem::op_alloc`, result `(mem, ptr)`
// - `%matrix.read`        -> lea chain + `mem::load`, result `(mem, value)`
// - `%matrix.insert`      -> lea chain + `mem::store`, result `(mem, ptr)` (same pointer)
// - `%matrix.constMat`    -> `mem::op_alloc` + one `mem::store` of a packed value, result `(mem, ptr)`
// Everything else is forwarded to the generic `Rewriter::rewrite_imm_App`.
// While bootstrapping, nothing is lowered; defer to the generic rewriter.
57 if (is_bootstrapping()) return Rewriter::rewrite_imm_App(app);
58
// High-level matrix operations must already have been turned into loops by an earlier phase.
59 assert(!Axm::isa<matrix::map_reduce>(app) && "map_reduce should have been lowered to for loops by now");
60 assert(!Axm::isa<matrix::shape>(app) && "high level operations should have been lowered to for loops by now");
61 assert(!Axm::isa<matrix::prod>(app) && "high level operations should have been lowered to for loops by now");
62 assert(!Axm::isa<matrix::transpose>(app) && "high level operations should have been lowered to for loops by now");
63 assert(!Axm::isa<matrix::sum>(app) && "high level operations should have been lowered to for loops by now");
// All lowered Defs are created in the destination world.
64 auto& w = new_world();
65
// Each branch rewrites the operands it uses explicitly before building the lowered code.
66 // TODO: generalize arg rewrite
67 if (auto mat_ax = Axm::isa<matrix::Mat>(app)) {
// Matrix type: pointer (address space 0) to the nested array «S#0; … «S#(n-1); T»».
68 auto [_, S, T] = mat_ax->args<3>();
69 S = rewrite(S);
70 T = rewrite(T);
71 auto arr_ty = arr_ty_of_matrix_ty(S, T);
72
73 auto addr_space = w.lit_nat_0();
74 auto ptr_ty = w.call<mem::Ptr>(Defs{arr_ty, addr_space});
75
76 return ptr_ty;
77 } else if (auto init_ax = Axm::isa<matrix::init>(app)) {
// Matrix creation: allocate storage for the nested array; no store is emitted here.
78 w.DLOG("init {} : {}", app, app->type());
79 auto [_, S, T, mem] = init_ax->args<4>();
80 w.DLOG(" S T mem {} {} {}", S, T, mem);
81 S = rewrite(S);
82 T = rewrite(T);
83 mem = rewrite(mem);
84 w.DLOG(" S T mem {} {} {}", S, T, mem);
85 auto arr_ty = arr_ty_of_matrix_ty(S, T);
86 auto [mem2, ptr_mat] = mem::op_alloc(arr_ty, mem)->projs<2>();
87 auto res = w.tuple({mem2, ptr_mat});
88 w.DLOG(" res {} : {}", res, res->type());
89 return res;
90 } else if (auto read_ax = Axm::isa<matrix::read>(app)) {
// Element read: address the element with one lea per index component, then load it.
91 auto [mem, mat, idx] = read_ax->args<3>();
92 w.DLOG("read_ax: {}", read_ax);
93 w.DLOG(" mem: {} : {}", mem, mem->type());
94 w.DLOG(" mat: {} : {}", mat, mat->type());
95 w.DLOG(" idx: {} : {}", idx, idx->type());
96 mem = rewrite(mem);
97 mat = rewrite(mat);
98 idx = rewrite(idx);
99 w.DLOG("rewritten read");
100 w.DLOG(" mem: {} : {}", mem, mem->type());
101 w.DLOG(" mat: {} : {}", mat, mat->type());
102 w.DLOG(" idx: {} : {}", idx, idx->type());
// NOTE(review): the rewritten `mat` is used directly as the array pointer, i.e. it is assumed
// to already be lowered to a `mem` pointer (see the TODO).
103 // TODO: check if mat is already converted
104 auto ptr_mat = mat;
105 auto element_ptr = op_lea_tuple(ptr_mat, idx);
106 auto [mem2, val] = w.call<mem::load>(Defs{mem, element_ptr})->projs<2>();
107 return w.tuple({mem2, val});
108 } else if (auto insert_ax = Axm::isa<matrix::insert>(app)) {
// Element write: store `val` in place and return the *same* pointer as the "new" matrix.
// NOTE(review): the old and new matrix therefore alias; this presumably relies on matrices
// being used linearly upstream — confirm.
109 auto [mem, mat, idx, val] = insert_ax->args<4>();
110 w.DLOG("insert_ax: {}", insert_ax);
111 w.DLOG(" mem: {} : {}", mem, mem->type());
112 w.DLOG(" mat: {} : {}", mat, mat->type());
113 w.DLOG(" idx: {} : {}", idx, idx->type());
114 w.DLOG(" val: {} : {}", val, val->type());
115 mem = rewrite(mem);
116 mat = rewrite(mat);
117 idx = rewrite(idx);
118 val = rewrite(val);
119 w.DLOG("rewritten insert");
120 w.DLOG(" mem: {} : {}", mem, mem->type());
121 w.DLOG(" mat: {} : {}", mat, mat->type());
122 w.DLOG(" idx: {} : {}", idx, idx->type());
123 w.DLOG(" val: {} : {}", val, val->type());
124 auto ptr_mat = mat;
125 auto element_ptr = op_lea_tuple(ptr_mat, idx);
126 auto mem2 = w.call<mem::store>(Defs{mem, element_ptr, val});
127 return w.tuple({mem2, ptr_mat});
128 } else if (auto const_ax = Axm::isa<matrix::constMat>(app)) {
// Constant matrix: allocate, then initialize every element with one store of `val`
// replicated over the whole shape.
129 auto [mem, val] = const_ax->args<2>();
130 mem = rewrite(mem);
131 val = rewrite(val);
// The dimension count, shape and element type are the arguments of the curried callee.
132 auto [n_def, S, T] = const_ax->callee()->as<App>()->args<3>();
133 S = rewrite(S);
134 T = rewrite(T);
135 auto arr_ty = arr_ty_of_matrix_ty(S, T);
136 auto [mem2, ptr_mat] = mem::op_alloc(arr_ty, mem)->projs<2>();
137
138 // store initial value
// `n_def` is used without rewriting and must be a literal here (`as<Lit>`).
139 auto n = n_def->as<Lit>()->get<u64>();
140 auto initial = op_pack_tuple(n, S, val);
141
// A single store through the base pointer writes the entire packed array.
142 auto mem3 = w.call<mem::store>(Defs{mem2, ptr_mat, initial});
143
144 return w.tuple({mem3, ptr_mat});
145 }
146
147 return Rewriter::rewrite_imm_App(app); // continue recursive rewriting with everything else
148}
149
150} // namespace mim::plug::matrix
static auto isa(const Def *def)
Definition axm.h:107
Base class for all Defs.
Definition def.h:251
World & world() const noexcept
Definition def.cpp:436
auto projs(F f) const
Splits this Def via Def::projections into an Array (if A == std::dynamic_extent) or std::array (otherwise).
Definition def.h:390
const Def * type() const noexcept
Yields the "raw" type of this Def (maybe nullptr).
Definition def.h:295
World & new_world()
Create new Defs into this.
Definition phase.h:99
bool is_bootstrapping() const
Definition phase.h:102
virtual const Def * rewrite(const Def *)
Definition rewrite.cpp:14
const Def * rewrite_imm_App(const App *) override
The matrix Plugin
Definition matrix.h:7
The mem Plugin
Definition mem.h:11
const Def * op_lea(const Def *ptr, const Def *index)
Definition mem.h:111
const Def * op_alloc(const Def *type, const Def *mem)
Definition mem.h:136
View< const Def * > Defs
Definition def.h:76
constexpr decltype(auto) get(Span< T, N > span) noexcept
Definition span.h:115
uint64_t u64
Definition types.h:34