MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
lower_matrix_lowlevel.cpp
Go to the documentation of this file.
2
3#include <cassert>
4
5#include <mim/axm.h>
6#include <mim/def.h>
7#include <mim/lam.h>
8
10#include "mim/plug/core/core.h"
13#include "mim/plug/mem/mem.h"
14
15namespace mim::plug::matrix {
16
17namespace {
18
19const Def* op_lea_tuple(const Def* arr, const Def* tuple) {
20 auto& world = arr->world();
21 world.DLOG("op_lea_tuple arr {} : {}", arr, arr->type());
22 auto n = tuple->num_projs();
23 auto element = arr;
24 for (size_t i = 0; i < n; ++i) element = mem::op_lea(element, tuple->proj(n, i));
25 return element;
26}
27
28const Def* op_pack_tuple(u64 n, const Def* tuple, const Def* val) {
29 auto& world = val->world();
30 // TODO: find out why num_projs is wrong
31 auto element = val;
32 for (int i = n - 1; i >= 0; i--) {
33 auto dim = tuple->proj(n, i);
34 element = world.pack(dim, element);
35 }
36 world.DLOG("op_pack_tuple: {} -> {}", val, element);
37 world.DLOG(" for tuple: {} : {}", tuple, tuple->type());
38 return element;
39}
40
41const Def* arr_ty_of_matrix_ty(const Def* S, const Def* T) {
42 auto& world = S->world();
43 auto n = S->num_projs();
44 auto arr_ty = T;
45 for (int i = n - 1; i >= 0; i--) {
46 auto dim = S->proj(n, i);
47 arr_ty = world.arr(dim, arr_ty);
48 }
49 return arr_ty;
50}
51
52} // namespace
53
55 assert(!Axm::isa<matrix::map_reduce>(def) && "map_reduce should have been lowered to for loops by now");
56 assert(!Axm::isa<matrix::shape>(def) && "high level operations should have been lowered to for loops by now");
57 assert(!Axm::isa<matrix::prod>(def) && "high level operations should have been lowered to for loops by now");
58 assert(!Axm::isa<matrix::transpose>(def) && "high level operations should have been lowered to for loops by now");
59 assert(!Axm::isa<matrix::sum>(def) && "high level operations should have been lowered to for loops by now");
60
61 // TODO: generalize arg rewrite
62 if (auto mat_ax = Axm::isa<matrix::Mat>(def)) {
63 auto [_, S, T] = mat_ax->args<3>();
64 S = rewrite(S);
65 T = rewrite(T);
66 auto arr_ty = arr_ty_of_matrix_ty(S, T);
67
68 auto addr_space = world().lit_nat_0();
69 auto ptr_ty = world().call<mem::Ptr>(Defs{arr_ty, addr_space});
70
71 return ptr_ty;
72 } else if (auto init_ax = Axm::isa<matrix::init>(def)) {
73 world().DLOG("init {} : {}", def, def->type());
74 auto [_, S, T, mem] = init_ax->args<4>();
75 world().DLOG(" S T mem {} {} {}", S, T, mem);
76 S = rewrite(S);
77 T = rewrite(T);
78 mem = rewrite(mem);
79 world().DLOG(" S T mem {} {} {}", S, T, mem);
80 auto arr_ty = arr_ty_of_matrix_ty(S, T);
81 auto [mem2, ptr_mat] = mem::op_alloc(arr_ty, mem)->projs<2>();
82 auto res = world().tuple({mem2, ptr_mat});
83 world().DLOG(" res {} : {}", res, res->type());
84 return res;
85 } else if (auto read_ax = Axm::isa<matrix::read>(def)) {
86 auto [mem, mat, idx] = read_ax->args<3>();
87 world().DLOG("read_ax: {}", read_ax);
88 world().DLOG(" mem: {} : {}", mem, mem->type());
89 world().DLOG(" mat: {} : {}", mat, mat->type());
90 world().DLOG(" idx: {} : {}", idx, idx->type());
91 mem = rewrite(mem);
92 mat = rewrite(mat);
93 idx = rewrite(idx);
94 world().DLOG("rewritten read");
95 world().DLOG(" mem: {} : {}", mem, mem->type());
96 world().DLOG(" mat: {} : {}", mat, mat->type());
97 world().DLOG(" idx: {} : {}", idx, idx->type());
98 // TODO: check if mat is already converted
99 auto ptr_mat = mat;
100 auto element_ptr = op_lea_tuple(ptr_mat, idx);
101 auto [mem2, val] = world().call<mem::load>(Defs{mem, element_ptr})->projs<2>();
102 return world().tuple({mem2, val});
103 } else if (auto insert_ax = Axm::isa<matrix::insert>(def)) {
104 auto [mem, mat, idx, val] = insert_ax->args<4>();
105 world().DLOG("insert_ax: {}", insert_ax);
106 world().DLOG(" mem: {} : {}", mem, mem->type());
107 world().DLOG(" mat: {} : {}", mat, mat->type());
108 world().DLOG(" idx: {} : {}", idx, idx->type());
109 world().DLOG(" val: {} : {}", val, val->type());
110 mem = rewrite(mem);
111 mat = rewrite(mat);
112 idx = rewrite(idx);
113 val = rewrite(val);
114 world().DLOG("rewritten insert");
115 world().DLOG(" mem: {} : {}", mem, mem->type());
116 world().DLOG(" mat: {} : {}", mat, mat->type());
117 world().DLOG(" idx: {} : {}", idx, idx->type());
118 world().DLOG(" val: {} : {}", val, val->type());
119 auto ptr_mat = mat;
120 auto element_ptr = op_lea_tuple(ptr_mat, idx);
121 auto mem2 = world().call<mem::store>(Defs{mem, element_ptr, val});
122 return world().tuple({mem2, ptr_mat});
123 } else if (auto const_ax = Axm::isa<matrix::constMat>(def)) {
124 auto [mem, val] = const_ax->args<2>();
125 mem = rewrite(mem);
126 val = rewrite(val);
127 auto [n_def, S, T] = const_ax->callee()->as<App>()->args<3>();
128 S = rewrite(S);
129 T = rewrite(T);
130 auto arr_ty = arr_ty_of_matrix_ty(S, T);
131 auto [mem2, ptr_mat] = mem::op_alloc(arr_ty, mem)->projs<2>();
132
133 // store initial value
134 auto n = n_def->as<Lit>()->get<u64>();
135 auto initial = op_pack_tuple(n, S, val);
136
137 auto mem3 = world().call<mem::store>(Defs{mem2, ptr_mat, initial});
138
139 return world().tuple({mem3, ptr_mat});
140 }
141
142 // ignore unapplied axms to avoid spurious type replacements
143 if (def->isa<Axm>()) return def;
144
145 return Rewriter::rewrite_imm(def); // continue recursive rewriting with everything else
146}
147
148} // namespace mim::plug::matrix
Definition axm.h:9
static auto isa(const Def *def)
Definition axm.h:104
Base class for all Defs.
Definition def.h:197
World & world() const noexcept
Definition def.cpp:387
auto projs(F f) const
Splits this Def via Def::projections into an Array (if A == std::dynamic_extent) or std::array (other...
Definition def.h:344
const Def * type() const noexcept
Yields the "raw" type of this Def (maybe nullptr).
Definition def.h:241
World & world()
Definition phase.h:62
virtual const Def * rewrite_imm(const Def *)
Definition rewrite.cpp:18
virtual const Def * rewrite(const Def *)
Definition rewrite.cpp:11
const Def * tuple(Defs ops)
Definition world.cpp:266
const Lit * lit_nat_0()
Definition world.h:385
const Def * call(Id id, Args &&... args)
Complete curried call of annexes obeying implicits.
Definition world.h:495
const Def * rewrite_imm(const Def *) override
The matrix Plugin
Definition matrix.h:9
The mem Plugin
Definition mem.h:11
const Def * op_lea(const Def *ptr, const Def *index)
Definition mem.h:111
const Def * op_alloc(const Def *type, const Def *mem)
Definition mem.h:136
View< const Def * > Defs
Definition def.h:48
constexpr decltype(auto) get(Span< T, N > span) noexcept
Definition span.h:107
uint64_t u64
Definition types.h:34