MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
lower_matrix_lowlevel.cpp
Go to the documentation of this file.
2
3#include <cassert>
4
5#include <mim/axm.h>
6#include <mim/def.h>
7#include <mim/lam.h>
8
9#include "mim/rewrite.h"
10
12#include "mim/plug/core/core.h"
15#include "mim/plug/mem/mem.h"
16
17namespace mim::plug::matrix {
18
19namespace {
20
21const Def* op_lea_tuple(const Def* arr, const Def* tuple) {
22 auto& world = arr->world();
23 world.DLOG("op_lea_tuple arr {} : {}", arr, arr->type());
24 auto n = tuple->num_projs();
25 auto element = arr;
26 for (size_t i = 0; i < n; ++i)
27 element = mem::op_lea(element, tuple->proj(n, i));
28 return element;
29}
30
31const Def* op_pack_tuple(u64 n, const Def* tuple, const Def* val) {
32 auto& world = val->world();
33 // TODO: find out why num_projs is wrong
34 auto element = val;
35 for (int i = n - 1; i >= 0; i--) {
36 auto dim = tuple->proj(n, i);
37 element = world.pack(dim, element);
38 }
39 world.DLOG("op_pack_tuple: {} -> {}", val, element);
40 world.DLOG(" for tuple: {} : {}", tuple, tuple->type());
41 return element;
42}
43
44const Def* arr_ty_of_matrix_ty(const Def* S, const Def* T) {
45 auto& world = S->world();
46 auto n = S->num_projs();
47 auto arr_ty = T;
48 for (int i = n - 1; i >= 0; i--) {
49 auto dim = S->proj(n, i);
50 arr_ty = world.arr(dim, arr_ty);
51 }
52 return arr_ty;
53}
54
55} // namespace
56
// Body of an App-rewriting routine; its signature line is not part of this listing
// (presumably the override of Rewriter::rewrite_imm_App taking `app` - confirm against the header).
// Applications of the matrix axioms Mat, init, read, insert and constMat are replaced by
// array/pointer constructs built with the mem helpers; every other App is handed to
// Rewriter::rewrite_imm_App.
// While bootstrapping, skip all matrix handling and defer to the base implementation.
58 if (is_bootstrapping()) return Rewriter::rewrite_imm_App(app);
59
// These axioms must not appear here anymore; per the assertion messages they are expected
// to have been lowered to for loops before this rewrite runs.
60 assert(!Axm::isa<matrix::map_reduce>(app) && "map_reduce should have been lowered to for loops by now");
61 assert(!Axm::isa<matrix::shape>(app) && "high level operations should have been lowered to for loops by now");
62 assert(!Axm::isa<matrix::prod>(app) && "high level operations should have been lowered to for loops by now");
63 assert(!Axm::isa<matrix::transpose>(app) && "high level operations should have been lowered to for loops by now");
64 assert(!Axm::isa<matrix::sum>(app) && "high level operations should have been lowered to for loops by now");
// World used below to construct the replacement Defs.
65 auto& w = new_world();
66
67 // TODO: generalize arg rewrite
// matrix::Mat: the matrix type becomes a mem::Ptr to the nested array type.
68 if (auto mat_ax = Axm::isa<matrix::Mat>(app)) {
// First argument is ignored; S and T are rewritten before use.
69 auto [_, S, T] = mat_ax->args<3>();
70 S = rewrite(S);
71 T = rewrite(T);
72 auto arr_ty = arr_ty_of_matrix_ty(S, T);
73
// The pointer's address-space argument is the literal produced by w.lit_nat_0().
74 auto addr_space = w.lit_nat_0();
75 auto ptr_ty = w.call<mem::Ptr>(Defs{arr_ty, addr_space});
76
77 return ptr_ty;
// matrix::init: allocate storage of the nested array type and return (mem, pointer).
78 } else if (auto init_ax = Axm::isa<matrix::init>(app)) {
79 DLOG("init {} : {}", app, app->type());
80 auto [_, S, T, mem] = init_ax->args<4>();
81 DLOG(" S T mem {} {} {}", S, T, mem);
82 S = rewrite(S);
83 T = rewrite(T);
84 mem = rewrite(mem);
85 DLOG(" S T mem {} {} {}", S, T, mem);
86 auto arr_ty = arr_ty_of_matrix_ty(S, T);
// mem::op_alloc yields a pair that is split into the new memory token and the pointer.
87 auto [mem2, ptr_mat] = mem::op_alloc(arr_ty, mem)->projs<2>();
88 auto res = w.tuple({mem2, ptr_mat});
89 DLOG(" res {} : {}", res, res->type());
90 return res;
// matrix::read: compute the element address from idx, then load from it.
91 } else if (auto read_ax = Axm::isa<matrix::read>(app)) {
92 auto [mem, mat, idx] = read_ax->args<3>();
93 DLOG("read_ax: {}", read_ax);
94 DLOG(" mem: {} : {}", mem, mem->type());
95 DLOG(" mat: {} : {}", mat, mat->type());
96 DLOG(" idx: {} : {}", idx, idx->type());
97 mem = rewrite(mem);
98 mat = rewrite(mat);
99 idx = rewrite(idx);
100 DLOG("rewritten read");
101 DLOG(" mem: {} : {}", mem, mem->type());
102 DLOG(" mat: {} : {}", mat, mat->type());
103 DLOG(" idx: {} : {}", idx, idx->type());
104 // TODO: check if mat is already converted
// The rewritten matrix is used directly as the base pointer for the lea chain.
105 auto ptr_mat = mat;
106 auto element_ptr = op_lea_tuple(ptr_mat, idx);
// Result is the pair (memory after the load, loaded value).
107 auto [mem2, val] = w.call<mem::load>(Defs{mem, element_ptr})->projs<2>();
108 return w.tuple({mem2, val});
// matrix::insert: compute the element address from idx, then store val to it.
109 } else if (auto insert_ax = Axm::isa<matrix::insert>(app)) {
110 auto [mem, mat, idx, val] = insert_ax->args<4>();
111 DLOG("insert_ax: {}", insert_ax);
112 DLOG(" mem: {} : {}", mem, mem->type());
113 DLOG(" mat: {} : {}", mat, mat->type());
114 DLOG(" idx: {} : {}", idx, idx->type());
115 DLOG(" val: {} : {}", val, val->type());
116 mem = rewrite(mem);
117 mat = rewrite(mat);
118 idx = rewrite(idx);
119 val = rewrite(val);
120 DLOG("rewritten insert");
121 DLOG(" mem: {} : {}", mem, mem->type());
122 DLOG(" mat: {} : {}", mat, mat->type());
123 DLOG(" idx: {} : {}", idx, idx->type());
124 DLOG(" val: {} : {}", val, val->type());
125 auto ptr_mat = mat;
126 auto element_ptr = op_lea_tuple(ptr_mat, idx);
// The store's result is used as the new memory token; the same pointer is handed back
// alongside it.
127 auto mem2 = w.call<mem::store>(Defs{mem, element_ptr, val});
128 return w.tuple({mem2, ptr_mat});
// matrix::constMat: allocate storage and initialise it with val replicated over the shape.
129 } else if (auto const_ax = Axm::isa<matrix::constMat>(app)) {
130 auto [mem, val] = const_ax->args<2>();
131 mem = rewrite(mem);
132 val = rewrite(val);
// The dimension count, shape and element type are read from the callee, which is itself an App.
133 auto [n_def, S, T] = const_ax->callee()->as<App>()->args<3>();
134 S = rewrite(S);
135 T = rewrite(T);
136 auto arr_ty = arr_ty_of_matrix_ty(S, T);
137 auto [mem2, ptr_mat] = mem::op_alloc(arr_ty, mem)->projs<2>();
138
139 // store initial value
// n_def is cast to Lit, so the dimension count is expected to be a literal here
// (NOTE(review): behaviour for a non-literal depends on Def::as - confirm).
140 auto n = n_def->as<Lit>()->get<u64>();
// Nested pack over the n dimensions of S with val as the innermost body.
141 auto initial = op_pack_tuple(n, S, val);
142
// A single store writes the packed value through the freshly allocated pointer.
143 auto mem3 = w.call<mem::store>(Defs{mem2, ptr_mat, initial});
144
145 return w.tuple({mem3, ptr_mat});
146 }
147
148 return Rewriter::rewrite_imm_App(app); // continue recursive rewriting with everything else
149}
150
151} // namespace mim::plug::matrix
static auto isa(const Def *def)
Definition axm.h:107
Base class for all Defs.
Definition def.h:251
World & world() const noexcept
Definition def.cpp:439
auto projs(F f) const
Splits this Def via Def::projections into an Array (if A == std::dynamic_extent) or std::array (other...
Definition def.h:390
const Def * type() const noexcept
Yields the "raw" type of this Def (maybe nullptr).
Definition def.cpp:447
World & new_world()
Create new Defs into this.
Definition phase.h:100
bool is_bootstrapping() const
Definition phase.h:79
virtual const Def * rewrite(const Def *)
Definition rewrite.cpp:14
const Def * rewrite_imm_App(const App *) override
#define DLOG(...)
Vaporizes to nothingness in Debug build.
Definition log.h:95
The matrix Plugin
Definition matrix.h:7
The mem Plugin
Definition mem.h:11
const Def * op_lea(const Def *ptr, const Def *index)
Definition mem.h:111
const Def * op_alloc(const Def *type, const Def *mem)
Definition mem.h:136
View< const Def * > Defs
Definition def.h:76
constexpr decltype(auto) get(Span< T, N > span) noexcept
Definition span.h:115
uint64_t u64
Definition types.h:34