MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
add_mem.cpp
Go to the documentation of this file.
2
3#include <mim/schedule.h>
4
5#include "mim/plug/mem/mem.h"
6
// TODO: make this phase parametric in the address space (memory types are currently always created as mem::M(0)).
8
9namespace mim::plug::mem::phase {
10
11namespace {
12
13std::pair<const App*, Vector<Lam*>> isa_apped_mut_lam_in_tuple(const Def* def) {
14 if (auto app = def->isa<App>()) {
15 Vector<Lam*> lams;
16 std::deque<const Def*> wl;
17 wl.push_back(app->callee());
18 while (!wl.empty()) {
19 auto elem = wl.front();
20 wl.pop_front();
21 if (auto mut = elem->isa_mut<Lam>()) {
22 lams.push_back(mut);
23 } else if (auto extract = elem->isa<Extract>()) {
24 if (auto tuple = extract->tuple()->isa<Tuple>())
25 for (auto&& op : tuple->ops())
26 wl.push_back(op);
27 else
28 return {nullptr, {}};
29 } else {
30 return {nullptr, {}};
31 }
32 }
33 return {app, lams};
34 }
35 return {nullptr, {}};
36}
37
38// @pre isa_apped_mut_lam_in_tuple(def) valid
39template<class F, class H>
40const Def* rewrite_mut_lam_in_tuple(const Def* def, F&& rewrite, H&& rewrite_idx) {
41 auto& w = def->world();
42 if (auto mut = def->isa_mut<Lam>()) return std::forward<F>(rewrite)(mut);
43
44 auto extract = def->as<Extract>();
45 auto tuple = extract->tuple()->as<Tuple>();
46 auto new_ops = DefVec(tuple->ops(), [&](const Def* op) {
47 return rewrite_mut_lam_in_tuple(op, std::forward<F>(rewrite), std::forward<H>(rewrite_idx));
48 });
49 return w.extract(w.tuple(new_ops), rewrite_idx(extract->index()));
50}
51
52// @pre isa_apped_mut_lam_in_tuple(def) valid
53template<class RewriteCallee, class RewriteArg, class RewriteIdx>
54const Def* rewrite_apped_mut_lam_in_tuple(const Def* def,
55 RewriteCallee&& rewrite_callee,
56 RewriteArg&& rewrite_arg,
57 RewriteIdx&& rewrite_idx) {
58 auto app = def->as<App>();
59 auto callee = rewrite_mut_lam_in_tuple(app->callee(), std::forward<RewriteCallee>(rewrite_callee),
60 std::forward<RewriteIdx>(rewrite_idx));
61 auto arg = std::forward<RewriteArg>(rewrite_arg)(app->arg());
62 return app->rebuild(app->type(), {callee, arg});
63}
64
65} // namespace
66
67// Entry point of the phase.
68void AddMem::visit(const Nest& nest) {
69 sched_ = Scheduler{nest};
70 add_mem_to_lams(root(), root());
71}
72
73const Def* AddMem::mem_for_lam(Lam* lam) const {
74 if (auto it = mem_rewritten_.find(lam); it != mem_rewritten_.end()) {
75 // We created a new lambda. Therefore, we want to lookup the mem for the new lambda.
76 lam = it->second->as_mut<Lam>();
77 }
78 if (auto it = val2mem_.find(lam); it != val2mem_.end()) {
79 DLOG("found mem for {} in val2mem_ : {}", lam, it->second);
80 // We found a (overwritten) memory in the lambda.
81 return it->second;
82 }
83 // As a fallback, we lookup the memory in vars of the lambda.
84 auto mem = mem::mem_var(lam);
85 assert(mem && "mut must have mem!");
86 return mem;
87}
88
89const Def* AddMem::rewrite_type(const Def* type) {
90 if (auto pi = type->isa<Pi>()) return rewrite_pi(pi);
91
92 if (auto it = mem_rewritten_.find(type); it != mem_rewritten_.end()) return it->second;
93
94 auto new_ops = DefVec(type->num_ops(), [&](size_t i) { return rewrite_type(type->op(i)); });
95 return mem_rewritten_[type] = type->rebuild(type->type(), new_ops);
96}
97
98const Def* AddMem::rewrite_pi(const Pi* pi) {
99 if (auto it = mem_rewritten_.find(pi); it != mem_rewritten_.end()) return it->second;
100
101 auto dom = pi->dom();
102 auto new_dom = DefVec(dom->num_projs(), [&](size_t i) { return rewrite_type(dom->proj(i)); });
103 if (pi->num_doms() == 0 || !Axm::isa<mem::M>(pi->dom(0_s))) {
104 new_dom
105 = DefVec(dom->num_projs() + 1, [&](size_t i) { return i == 0 ? world().call<mem::M>(0) : new_dom[i - 1]; });
106 }
107
108 return mem_rewritten_[pi] = world().pi(new_dom, pi->codom());
109}
110
/// Rewrites @p def, encountered while rewriting @p curr_lam, such that memory is threaded explicitly:
/// mutable lambdas get a leading memory parameter (via rewrite_pi), call sites are given the currently live
/// memory (via mem_for_lam), and memory produced by rewritten applications is recorded as the new live
/// memory of the lambda in which the application is placed.
/// Results are cached in `mem_rewritten_`; the live memory per lambda is tracked in `val2mem_`.
/// @param curr_lam The lambda whose filter/body is currently being rewritten.
/// @param def      The Def to rewrite.
/// @returns the rewritten Def.
const Def* AddMem::add_mem_to_lams(Lam* curr_lam, const Def* def) {
    // The lambda in which the scheduler places `def`; live memory is looked up / recorded for this lambda.
    // NOTE(review): the static_cast assumes the scheduled mutable is always a Lam — confirm.
    auto place = static_cast<Lam*>(sched_.smart(curr_lam, def)->mut());

    // world().DLOG("rewriting {} : {} in {}", def, def->type(), place);

    // Globals, unset (bodiless) mutable lambdas, and axioms are returned unchanged.
    if (auto global = def->isa<Global>()) return global;
    if (auto mut_lam = def->isa_mut<Lam>(); mut_lam && !mut_lam->is_set()) return def;
    if (auto ax = def->isa<Axm>()) return ax;
    if (auto it = mem_rewritten_.find(def); it != mem_rewritten_.end()) {
        auto tmp = it->second;
        // A memory-typed Def is never reused from the cache; instead, the memory currently live in
        // curr_lam is used.
        if (Axm::isa<mem::M>(def->type())) {
            DLOG("already known mem {} in {}", def, curr_lam);
            auto new_mem = mem_for_lam(curr_lam);
            DLOG("new mem {} in {}", new_mem, curr_lam);
            return new_mem;
        }
        // Reuse the cached result unless `def` is the lambda currently being rewritten (its body still
        // has to be rewritten below).
        if (curr_lam != def) {
            // DLOG("rewritten def: {} : {} in {}", tmp, tmp->type(), curr_lam);
            return tmp;
        }
    }
    if (Axm::isa<mem::M>(def->type())) DLOG("new mem {} in {}", def, curr_lam);

    // Rewrites a mutable lambda: re-types it with a leading memory parameter (via a stub) and, if it is
    // bound in the current nest, rewrites its filter and body.
    auto rewrite_lam = [&](Lam* lam) -> const Def* {
        auto pi = lam->type()->as<Pi>();
        auto new_lam = lam;

        // Already rewritten: if we are now rewriting this very lambda, continue with the stub created earlier;
        // otherwise reuse the earlier result if its type already starts with a memory.
        // NOTE(review): the `pi` declared in the else-if shadows the outer `pi`.
        if (auto it = mem_rewritten_.find(lam); it != mem_rewritten_.end()) {
            if (curr_lam == lam) // i.e. we've stubbed this, but now we rewrite it
                new_lam = it->second->as_mut<Lam>();
            else if (auto pi = it->second->type()->as<Pi>(); pi->num_doms() > 0 && Axm::isa<mem::M>(pi->dom(0_s)))
                return it->second;
        }

        if (!lam->is_set()) return lam;
        DLOG("rewrite lam {}", lam);

        // Only lambdas contained in the nest (or the one currently being rewritten) get their body rewritten.
        bool is_bound = sched_.nest().contains(lam) || lam == curr_lam;

        // Create a stub with the memory-extended type if the type actually changes.
        if (new_lam == lam) // if not stubbed yet
            if (auto new_pi = rewrite_pi(pi); new_pi != pi) new_lam = lam->stub(new_pi);

        // Unbound lambdas are only re-typed here.
        // NOTE(review): in this case a freshly created stub is returned without its filter/body being set.
        if (!is_bound) {
            DLOG("free lam {}", lam);
            mem_rewritten_[lam] = new_lam;
            return new_lam;
        }

        // Map the old vars to the new ones, shifted by the number of doms rewrite_pi added (expected 0 or 1).
        auto var_offset = new_lam->num_doms() - lam->num_doms(); // have we added a mem var?
        if (lam->num_vars() != 0) mem_rewritten_[lam->var()] = new_lam->var();
        for (size_t i = 0; i < lam->num_vars() && new_lam->num_vars() > 1; ++i)
            mem_rewritten_[lam->var(i)] = new_lam->var(i + var_offset);

        // The first var of the new lambda is recorded as the live memory of both the old and the new lambda.
        // Filter and body are then rewritten in the context of the old lambda.
        auto var = new_lam->var(0_n);
        mem_rewritten_[new_lam] = new_lam;
        mem_rewritten_[lam] = new_lam;
        val2mem_[new_lam] = var;
        val2mem_[lam] = var;
        mem_rewritten_[var] = var;
        auto filter = add_mem_to_lams(lam, lam->filter());
        auto body = add_mem_to_lams(lam, lam->body());
        new_lam->unset()->set({filter, body});

        // Externality moves to the replacement lambda.
        if (lam != new_lam && lam->is_external()) lam->transfer_external(new_lam);
        return new_lam;
    };

    // rewrite top-level lams
    if (auto lam = def->isa_mut<Lam>()) return rewrite_lam(lam);
    assert(!def->isa_mut());

    if (auto pi = def->isa<Pi>()) return rewrite_pi(pi);

    // Rewrites an argument such that its element 0 is the live memory of `place`; a memory slot is
    // prepended if the argument's type does not start with a memory.
    auto rewrite_arg = [&](const Def* arg) -> const Def* {
        size_t offset = (arg->type()->num_projs() > 0 && Axm::isa<mem::M>(arg->type()->proj(0))) ? 0 : 1;
        if (offset == 0) {
            // depth-first, follow the mems
            add_mem_to_lams(place, arg->proj(0));
        }

        // Built back-to-front: the remaining elements are rewritten first (which may update val2mem_[place]),
        // so slot 0 receives the memory that is live afterwards.
        DefVec new_args{arg->type()->num_projs() + offset};
        for (int i = new_args.size() - 1; i >= 0; i--) {
            new_args[i]
                = i == 0 ? add_mem_to_lams(place, mem_for_lam(place)) : add_mem_to_lams(place, arg->proj(i - offset));
        }
        return world().tuple(new_args);
    };

    // call-site of a mutable lambda
    if (auto apped_mut = isa_apped_mut_lam_in_tuple(def); apped_mut.first) {
        return mem_rewritten_[def]
             = rewrite_apped_mut_lam_in_tuple(def, std::move(rewrite_lam), std::move(rewrite_arg),
                                              [&](const Def* def) { return add_mem_to_lams(place, def); });
    }

    // call-site of a continuation
    if (auto app = def->isa<App>(); app && (app->callee()->has_dep(Dep::Var))) {
        return mem_rewritten_[def]
             = app->rebuild(app->type(), {add_mem_to_lams(place, app->callee()), rewrite_arg(app->arg())});
    }

    // call-site of an axm (assuming mems are only in the final app..)
    // assume all "negative" curry depths are fully applied axms, so we do not want to rewrite those here..
    // NOTE(review): `curry() ^ 0x8000` is nonzero for every value except exactly 0x8000. If the intent is to
    // test the high ("negative") bit, `(app->curry() & 0x8000) == 0` may be what was meant — confirm against
    // the type and encoding of App::curry().
    if (auto app = def->isa<App>(); app && app->axm() && app->curry() ^ 0x8000) {
        auto arg = app->arg();
        DefVec new_args(arg->num_projs());
        // Back-to-front, as in rewrite_arg: memory operands receive the memory live after rewriting later operands.
        for (int i = new_args.size() - 1; i >= 0; i--) {
            // replace memory operand with followed mem
            if (Axm::isa<mem::M>(arg->proj(i)->type())) {
                // depth-first, follow the mems
                add_mem_to_lams(place, arg->proj(i));
                new_args[i] = add_mem_to_lams(place, mem_for_lam(place));
            } else {
                new_args[i] = add_mem_to_lams(place, arg->proj(i));
            }
        }
        auto rewritten = mem_rewritten_[def] = app->rebuild(app->type(), {add_mem_to_lams(place, app->callee()),
                                                                          world().tuple(new_args)->set(arg->dbg())})
                                                   ->set(app->dbg());
        // If the axiom yields a memory (directly or as its first projection), it becomes the live memory of `place`.
        if (Axm::isa<mem::M>(rewritten->type())) {
            DLOG("memory from axm {} : {}", rewritten, rewritten->type());
            val2mem_[place] = rewritten;
        }
        if (rewritten->num_projs() > 0 && Axm::isa<mem::M>(rewritten->proj(0)->type())) {
            DLOG("memory from axm 2 {} : {}", rewritten, rewritten->type());
            mem_rewritten_[rewritten->proj(0)] = rewritten->proj(0);
            val2mem_[place] = rewritten->proj(0);
        }
        return rewritten;
    }

    // all other apps: when rewriting the callee adds a mem to the doms, add a mem to the arg as well..
    if (auto app = def->isa<App>()) {
        auto new_callee = add_mem_to_lams(place, app->callee());
        // NOTE(review): this rewrite always runs (and updates the caches) even if its result is replaced below.
        auto new_arg = add_mem_to_lams(place, app->arg());
        if (app->callee()->type()->as<Pi>()->num_doms() + 1 == new_callee->type()->as<Pi>()->num_doms())
            new_arg = rewrite_arg(app->arg());
        auto rewritten = mem_rewritten_[def] = app->rebuild(app->type(), {new_callee, new_arg})->set(app->dbg());
        // A memory result (directly or as first projection) becomes the live memory of `place`.
        if (Axm::isa<mem::M>(rewritten->type())) {
            DLOG("memory from other {} : {}", rewritten, rewritten->type());
            val2mem_[place] = rewritten;
        }
        if (rewritten->num_projs() > 0 && Axm::isa<mem::M>(rewritten->proj(0)->type())) {
            DLOG("memory from other 2 {} : {}", rewritten, rewritten->type());
            mem_rewritten_[rewritten->proj(0)] = rewritten->proj(0);
            val2mem_[place] = rewritten->proj(0);
        }
        return rewritten;
    }

    // Any other Def: rebuild with a rewritten type and rewritten ops; memory-typed ops are replaced by the
    // memory that is live in `place`.
    auto new_ops = DefVec(def->ops(), [&](const Def* op) {
        if (Axm::isa<mem::M>(op->type())) {
            // depth-first, follow the mems
            add_mem_to_lams(place, op);
            return add_mem_to_lams(place, mem_for_lam(place));
        }
        return add_mem_to_lams(place, op);
    });

    auto tmp = mem_rewritten_[def] = def->rebuild(rewrite_type(def->type()), new_ops)->set(def->dbg());
    // if (Axm::isa<mem::M>(tmp->type())) {
    //     DLOG("memory from other op 1 {} : {}", tmp, tmp->type());
    //     val2mem_[place] = tmp;
    // }
    // if (tmp->num_projs() > 0 && test<mem::M>(tmp->proj(0)->type())) {
    //     DLOG("memory from other op 2 {} : {}", tmp, tmp->type());
    //     mem_rewritten_[tmp->proj(0)] = tmp->proj(0);
    //     val2mem_[place] = tmp->proj(0);
    // }
    return tmp;
}
282
283} // namespace mim::plug::mem::phase
static auto isa(const Def *def)
Definition axm.h:107
Lam * root() const
Definition phase.h:257
Base class for all Defs.
Definition def.h:251
bool is_set() const
Yields true if empty or the last op is set.
Definition def.cpp:295
T * as_mut() const
Asserts that this is a mutable, casts constness away and performs a static_cast to T.
Definition def.h:494
const Def * var(nat_t a, nat_t i) noexcept
Definition def.h:429
void transfer_external(Def *to)
Definition def.cpp:570
nat_t num_vars() noexcept
Definition def.h:429
bool is_external() const noexcept
Definition def.h:467
A function.
Definition lam.h:111
const Def * filter() const
Definition lam.h:123
const Pi * type() const
Definition lam.h:131
Lam * stub(const Def *type)
Definition lam.h:185
const Def * body() const
Definition lam.h:124
const Nest & nest() const
Definition phase.h:273
Builds a nesting tree of all immutables/binders.
Definition nest.h:11
A dependent function type.
Definition lam.h:15
World & world()
Definition pass.h:64
const Pi * pi(const Def *dom, const Def *codom, bool implicit=false)
Definition world.h:293
const Def * tuple(Defs ops)
Definition world.cpp:291
void visit(const Nest &) override
Definition add_mem.cpp:68
#define DLOG(...)
Vaporizes to nothingness in Debug build.
Definition log.h:95
The mem Plugin
Definition mem.h:11
const Def * mem_var(Lam *lam)
Returns the memory argument of a function if it has one.
Definition mem.h:38
Vector< const Def * > DefVec
Definition def.h:77
@ Var
Definition def.h:126
@ Pi
Definition def.h:114
@ Lam
Definition def.h:114
@ Axm
Definition def.h:114
@ Extract
Definition def.h:114
@ App
Definition def.h:114
@ Tuple
Definition def.h:114
Vector(I, I, A=A()) -> Vector< typename std::iterator_traits< I >::value_type, Default_Inlined_Size< typename std::iterator_traits< I >::value_type >, A >