MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
add_mem.cpp
Go to the documentation of this file.
2
3#include <mim/schedule.h>
4
5#include "mim/plug/mem/mem.h"
6
7namespace mim::plug::mem::phase {
8
9namespace {
10
11std::pair<const App*, Vector<Lam*>> isa_apped_mut_lam_in_tuple(const Def* def) {
12 if (auto app = def->isa<App>()) {
13 Vector<Lam*> lams;
14 std::deque<const Def*> wl;
15 wl.push_back(app->callee());
16 while (!wl.empty()) {
17 auto elem = wl.front();
18 wl.pop_front();
19 if (auto mut = elem->isa_mut<Lam>()) {
20 lams.push_back(mut);
21 } else if (auto extract = elem->isa<Extract>()) {
22 if (auto tuple = extract->tuple()->isa<Tuple>())
23 for (auto&& op : tuple->ops())
24 wl.push_back(op);
25 else
26 return {nullptr, {}};
27 } else {
28 return {nullptr, {}};
29 }
30 }
31 return {app, lams};
32 }
33 return {nullptr, {}};
34}
35
36// @pre isa_apped_mut_lam_in_tuple(def) valid
37template<class F, class H>
38const Def* rewrite_mut_lam_in_tuple(const Def* def, F&& rewrite, H&& rewrite_idx) {
39 auto& w = def->world();
40 if (auto mut = def->isa_mut<Lam>()) return std::forward<F>(rewrite)(mut);
41
42 auto extract = def->as<Extract>();
43 auto tuple = extract->tuple()->as<Tuple>();
44 auto new_ops = DefVec(tuple->ops(), [&](const Def* op) {
45 return rewrite_mut_lam_in_tuple(op, std::forward<F>(rewrite), std::forward<H>(rewrite_idx));
46 });
47 return w.extract(w.tuple(new_ops), rewrite_idx(extract->index()));
48}
49
50// @pre isa_apped_mut_lam_in_tuple(def) valid
51template<class RewriteCallee, class RewriteArg, class RewriteIdx>
52const Def* rewrite_apped_mut_lam_in_tuple(const Def* def,
53 RewriteCallee&& rewrite_callee,
54 RewriteArg&& rewrite_arg,
55 RewriteIdx&& rewrite_idx) {
56 auto app = def->as<App>();
57 auto callee = rewrite_mut_lam_in_tuple(app->callee(), std::forward<RewriteCallee>(rewrite_callee),
58 std::forward<RewriteIdx>(rewrite_idx));
59 auto arg = std::forward<RewriteArg>(rewrite_arg)(app->arg());
60 return app->rebuild(app->type(), {callee, arg});
61}
62
63} // namespace
64
65// Entry point of the phase.
66void AddMem::visit(const Nest& nest) {
67 sched_ = Scheduler{nest};
68 add_mem_to_lams(root(), root());
69}
70
71const Def* AddMem::mem_for_lam(Lam* lam) const {
72 if (auto it = mem_rewritten_.find(lam); it != mem_rewritten_.end()) {
73 // We created a new lambda. Therefore, we want to lookup the mem for the new lambda.
74 lam = it->second->as_mut<Lam>();
75 }
76 if (auto it = val2mem_.find(lam); it != val2mem_.end()) {
77 DLOG("found mem for {} in val2mem_ : {}", lam, it->second);
78 // We found a (overwritten) memory in the lambda.
79 return it->second;
80 }
81 // As a fallback, we lookup the memory in vars of the lambda.
82 auto mem = mem::mem_var(lam);
83 assert(mem && "mut must have mem!");
84 return mem;
85}
86
87const Def* AddMem::rewrite_type(const Def* type) {
88 if (auto pi = type->isa<Pi>()) return rewrite_pi(pi);
89
90 if (auto it = mem_rewritten_.find(type); it != mem_rewritten_.end()) return it->second;
91
92 auto new_ops = DefVec(type->num_ops(), [&](size_t i) { return rewrite_type(type->op(i)); });
93 return mem_rewritten_[type] = type->rebuild(type->type(), new_ops);
94}
95
96const Def* AddMem::rewrite_pi(const Pi* pi) {
97 if (auto it = mem_rewritten_.find(pi); it != mem_rewritten_.end()) return it->second;
98
99 auto dom = pi->dom();
100 auto new_dom = DefVec(dom->num_projs(), [&](size_t i) { return rewrite_type(dom->proj(i)); });
101 if (pi->num_doms() == 0 || !Axm::isa<mem::M>(pi->dom(0_s))) {
102 new_dom
103 = DefVec(dom->num_projs() + 1, [&](size_t i) { return i == 0 ? world().annex<mem::M>() : new_dom[i - 1]; });
104 }
105
106 return mem_rewritten_[pi] = world().pi(new_dom, pi->codom());
107}
108
109const Def* AddMem::add_mem_to_lams(Lam* curr_lam, const Def* def) {
110 auto place = static_cast<Lam*>(sched_.smart(curr_lam, def)->mut());
111
112 // world().DLOG("rewriting {} : {} in {}", def, def->type(), place);
113
114 if (auto global = def->isa<Global>()) return global;
115 if (auto mut_lam = def->isa_mut<Lam>(); mut_lam && !mut_lam->is_set()) return def;
116 if (auto ax = def->isa<Axm>()) return ax;
117 if (auto it = mem_rewritten_.find(def); it != mem_rewritten_.end()) {
118 auto tmp = it->second;
119 if (Axm::isa<mem::M>(def->type())) {
120 DLOG("already known mem {} in {}", def, curr_lam);
121 auto new_mem = mem_for_lam(curr_lam);
122 DLOG("new mem {} in {}", new_mem, curr_lam);
123 return new_mem;
124 }
125 if (curr_lam != def) {
126 // DLOG("rewritten def: {} : {} in {}", tmp, tmp->type(), curr_lam);
127 return tmp;
128 }
129 }
130 if (Axm::isa<mem::M>(def->type())) DLOG("new mem {} in {}", def, curr_lam);
131
132 auto rewrite_lam = [&](Lam* lam) -> const Def* {
133 auto pi = lam->type()->as<Pi>();
134 auto new_lam = lam;
135
136 if (auto it = mem_rewritten_.find(lam); it != mem_rewritten_.end()) {
137 if (curr_lam == lam) // i.e. we've stubbed this, but now we rewrite it
138 new_lam = it->second->as_mut<Lam>();
139 else if (auto pi = it->second->type()->as<Pi>(); pi->num_doms() > 0 && Axm::isa<mem::M>(pi->dom(0_s)))
140 return it->second;
141 }
142
143 if (!lam->is_set()) return lam;
144 DLOG("rewrite lam {}", lam);
145
146 bool is_bound = sched_.nest().contains(lam) || lam == curr_lam;
147
148 if (new_lam == lam) // if not stubbed yet
149 if (auto new_pi = rewrite_pi(pi); new_pi != pi) new_lam = lam->stub(new_pi);
150
151 if (!is_bound) {
152 DLOG("free lam {}", lam);
153 mem_rewritten_[lam] = new_lam;
154 return new_lam;
155 }
156
157 auto var_offset = new_lam->num_doms() - lam->num_doms(); // have we added a mem var?
158 if (lam->num_vars() != 0) mem_rewritten_[lam->var()] = new_lam->var();
159 for (size_t i = 0; i < lam->num_vars() && new_lam->num_vars() > 1; ++i)
160 mem_rewritten_[lam->var(i)] = new_lam->var(i + var_offset);
161
162 auto var = new_lam->var(0_n);
163 mem_rewritten_[new_lam] = new_lam;
164 mem_rewritten_[lam] = new_lam;
165 val2mem_[new_lam] = var;
166 val2mem_[lam] = var;
167 mem_rewritten_[var] = var;
168 auto filter = add_mem_to_lams(lam, lam->filter());
169 auto body = add_mem_to_lams(lam, lam->body());
170 new_lam->unset()->set({filter, body});
171
172 if (lam != new_lam && lam->is_external()) lam->transfer_external(new_lam);
173 return new_lam;
174 };
175
176 // rewrite top-level lams
177 if (auto lam = def->isa_mut<Lam>()) return rewrite_lam(lam);
178 assert(!def->isa_mut());
179
180 if (auto pi = def->isa<Pi>()) return rewrite_pi(pi);
181
182 auto rewrite_arg = [&](const Def* arg) -> const Def* {
183 size_t offset = (arg->type()->num_projs() > 0 && Axm::isa<mem::M>(arg->type()->proj(0))) ? 0 : 1;
184 if (offset == 0) {
185 // depth-first, follow the mems
186 add_mem_to_lams(place, arg->proj(0));
187 }
188
189 DefVec new_args{arg->type()->num_projs() + offset};
190 for (int i = new_args.size() - 1; i >= 0; i--) {
191 new_args[i]
192 = i == 0 ? add_mem_to_lams(place, mem_for_lam(place)) : add_mem_to_lams(place, arg->proj(i - offset));
193 }
194 return world().tuple(new_args);
195 };
196
197 // call-site of a mutable lambda
198 if (auto apped_mut = isa_apped_mut_lam_in_tuple(def); apped_mut.first) {
199 return mem_rewritten_[def]
200 = rewrite_apped_mut_lam_in_tuple(def, std::move(rewrite_lam), std::move(rewrite_arg),
201 [&](const Def* def) { return add_mem_to_lams(place, def); });
202 }
203
204 // call-site of a continuation
205 if (auto app = def->isa<App>(); app && (app->callee()->has_dep(Dep::Var))) {
206 return mem_rewritten_[def]
207 = app->rebuild(app->type(), {add_mem_to_lams(place, app->callee()), rewrite_arg(app->arg())});
208 }
209
210 // call-site of an axm (assuming mems are only in the final app..)
211 // assume all "negative" curry depths are fully applied axms, so we do not want to rewrite those here..
212 if (auto app = def->isa<App>(); app && app->axm() && app->curry() ^ 0x8000) {
213 auto arg = app->arg();
214 DefVec new_args(arg->num_projs());
215 for (int i = new_args.size() - 1; i >= 0; i--) {
216 // replace memory operand with followed mem
217 if (Axm::isa<mem::M>(arg->proj(i)->type())) {
218 // depth-first, follow the mems
219 add_mem_to_lams(place, arg->proj(i));
220 new_args[i] = add_mem_to_lams(place, mem_for_lam(place));
221 } else {
222 new_args[i] = add_mem_to_lams(place, arg->proj(i));
223 }
224 }
225 auto rewritten = mem_rewritten_[def] = app->rebuild(app->type(), {add_mem_to_lams(place, app->callee()),
226 world().tuple(new_args)->set(arg->dbg())})
227 ->set(app->dbg());
228 if (Axm::isa<mem::M>(rewritten->type())) {
229 DLOG("memory from axm {} : {}", rewritten, rewritten->type());
230 val2mem_[place] = rewritten;
231 }
232 if (rewritten->num_projs() > 0 && Axm::isa<mem::M>(rewritten->proj(0)->type())) {
233 DLOG("memory from axm 2 {} : {}", rewritten, rewritten->type());
234 mem_rewritten_[rewritten->proj(0)] = rewritten->proj(0);
235 val2mem_[place] = rewritten->proj(0);
236 }
237 return rewritten;
238 }
239
240 // all other apps: when rewriting the callee adds a mem to the doms, add a mem to the arg as well..
241 if (auto app = def->isa<App>()) {
242 auto new_callee = add_mem_to_lams(place, app->callee());
243 auto new_arg = add_mem_to_lams(place, app->arg());
244 if (app->callee()->type()->as<Pi>()->num_doms() + 1 == new_callee->type()->as<Pi>()->num_doms())
245 new_arg = rewrite_arg(app->arg());
246 auto rewritten = mem_rewritten_[def] = app->rebuild(app->type(), {new_callee, new_arg})->set(app->dbg());
247 if (Axm::isa<mem::M>(rewritten->type())) {
248 DLOG("memory from other {} : {}", rewritten, rewritten->type());
249 val2mem_[place] = rewritten;
250 }
251 if (rewritten->num_projs() > 0 && Axm::isa<mem::M>(rewritten->proj(0)->type())) {
252 DLOG("memory from other 2 {} : {}", rewritten, rewritten->type());
253 mem_rewritten_[rewritten->proj(0)] = rewritten->proj(0);
254 val2mem_[place] = rewritten->proj(0);
255 }
256 return rewritten;
257 }
258
259 auto new_ops = DefVec(def->ops(), [&](const Def* op) {
260 if (Axm::isa<mem::M>(op->type())) {
261 // depth-first, follow the mems
262 add_mem_to_lams(place, op);
263 return add_mem_to_lams(place, mem_for_lam(place));
264 }
265 return add_mem_to_lams(place, op);
266 });
267
268 auto tmp = mem_rewritten_[def] = def->rebuild(rewrite_type(def->type()), new_ops)->set(def->dbg());
269 // if (Axm::isa<mem::M>(tmp->type())) {
270 // DLOG("memory from other op 1 {} : {}", tmp, tmp->type());
271 // val2mem_[place] = tmp;
272 // }
273 // if (tmp->num_projs() > 0 && test<mem::M>(tmp->proj(0)->type())) {
274 // DLOG("memory from other op 2 {} : {}", tmp, tmp->type());
275 // mem_rewritten_[tmp->proj(0)] = tmp->proj(0);
276 // val2mem_[place] = tmp->proj(0);
277 // }
278 return tmp;
279}
280
281} // namespace mim::plug::mem::phase
static auto isa(const Def *def)
Definition axm.h:107
Lam * root() const
Definition phase.h:257
Base class for all Defs.
Definition def.h:251
bool is_set() const
Yields true if empty or the last op is set.
Definition def.cpp:295
T * as_mut() const
Asserts that this is a mutable, casts constness away and performs a static_cast to T.
Definition def.h:494
const Def * var(nat_t a, nat_t i) noexcept
Definition def.h:429
void transfer_external(Def *to)
Definition def.cpp:573
nat_t num_vars() noexcept
Definition def.h:429
bool is_external() const noexcept
Definition def.h:467
A function.
Definition lam.h:111
const Def * filter() const
Definition lam.h:123
const Pi * type() const
Definition lam.h:131
Lam * stub(const Def *type)
Definition lam.h:185
const Def * body() const
Definition lam.h:124
const Nest & nest() const
Definition phase.h:273
Builds a nesting tree of all immutables‍/binders.
Definition nest.h:11
A dependent function type.
Definition lam.h:15
World & world()
Definition pass.h:64
const Pi * pi(const Def *dom, const Def *codom, bool implicit=false)
Definition world.h:265
const Def * tuple(Defs ops)
Definition world.cpp:287
void visit(const Nest &) override
Definition add_mem.cpp:66
#define DLOG(...)
Vaporizes to nothingness in Debug build.
Definition log.h:95
The mem Plugin
Definition mem.h:11
const Def * mem_var(Lam *lam)
Returns the memory argument of a function if it has one.
Definition mem.h:38
Vector< const Def * > DefVec
Definition def.h:77
@ Var
Definition def.h:126
@ Pi
Definition def.h:114
@ Lam
Definition def.h:114
@ Axm
Definition def.h:114
@ Extract
Definition def.h:114
@ App
Definition def.h:114
@ Tuple
Definition def.h:114
Vector(I, I, A=A()) -> Vector< typename std::iterator_traits< I >::value_type, Default_Inlined_Size< typename std::iterator_traits< I >::value_type >, A >