MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
add_mem.cpp
Go to the documentation of this file.
2
3#include <memory>
4
6
7#include "mim/plug/mem/mem.h"
8
9namespace mim::plug::mem {
10
11namespace {
12
13std::pair<const App*, Vector<Lam*>> isa_apped_mut_lam_in_tuple(const Def* def) {
14 if (auto app = def->isa<App>()) {
15 Vector<Lam*> lams;
16 std::deque<const Def*> wl;
17 wl.push_back(app->callee());
18 while (!wl.empty()) {
19 auto elem = wl.front();
20 wl.pop_front();
21 if (auto mut = elem->isa_mut<Lam>()) {
22 lams.push_back(mut);
23 } else if (auto extract = elem->isa<Extract>()) {
24 if (auto tuple = extract->tuple()->isa<Tuple>())
25 for (auto&& op : tuple->ops()) wl.push_back(op);
26 else
27 return {nullptr, {}};
28 } else {
29 return {nullptr, {}};
30 }
31 }
32 return {app, lams};
33 }
34 return {nullptr, {}};
35}
36
37// @pre isa_apped_mut_lam_in_tuple(def) valid
38template<class F, class H> const Def* rewrite_mut_lam_in_tuple(const Def* def, F&& rewrite, H&& rewrite_idx) {
39 auto& w = def->world();
40 if (auto mut = def->isa_mut<Lam>()) return std::forward<F>(rewrite)(mut);
41
42 auto extract = def->as<Extract>();
43 auto tuple = extract->tuple()->as<Tuple>();
44 auto new_ops = DefVec(tuple->ops(), [&](const Def* op) {
45 return rewrite_mut_lam_in_tuple(op, std::forward<F>(rewrite), std::forward<H>(rewrite_idx));
46 });
47 return w.extract(w.tuple(new_ops), rewrite_idx(extract->index()));
48}
49
50// @pre isa_apped_mut_lam_in_tuple(def) valid
51template<class RewriteCallee, class RewriteArg, class RewriteIdx>
52const Def* rewrite_apped_mut_lam_in_tuple(const Def* def,
53 RewriteCallee&& rewrite_callee,
54 RewriteArg&& rewrite_arg,
55 RewriteIdx&& rewrite_idx) {
56 auto app = def->as<App>();
57 auto callee = rewrite_mut_lam_in_tuple(app->callee(), std::forward<RewriteCallee>(rewrite_callee),
58 std::forward<RewriteIdx>(rewrite_idx));
59 auto arg = std::forward<RewriteArg>(rewrite_arg)(app->arg());
60 return app->rebuild(app->type(), {callee, arg});
61}
62
63} // namespace
64
65// Entry point of the phase.
66void AddMem::visit(const Scope& scope) {
67 if (auto entry = scope.entry()->isa_mut<Lam>()) {
68 scope.free_muts(); // cache this.
69 sched_ = Scheduler{scope};
70 add_mem_to_lams(entry, entry);
71 }
72}
73
74const Def* AddMem::mem_for_lam(Lam* lam) const {
75 if (auto it = mem_rewritten_.find(lam); it != mem_rewritten_.end()) {
76 // We created a new lambda. Therefore, we want to lookup the mem for the new lambda.
77 lam = it->second->as_mut<Lam>();
78 }
79 if (auto it = val2mem_.find(lam); it != val2mem_.end()) {
80 lam->world().DLOG("found mem for {} in val2mem_ : {}", lam, it->second);
81 // We found a (overwritten) memory in the lambda.
82 return it->second;
83 }
84 // As a fallback, we lookup the memory in vars of the lambda.
85 auto mem = mem::mem_var(lam);
86 assert(mem && "mut must have mem!");
87 return mem;
88}
89
90const Def* AddMem::rewrite_type(const Def* type) {
91 if (auto pi = type->isa<Pi>()) return rewrite_pi(pi);
92
93 if (auto it = mem_rewritten_.find(type); it != mem_rewritten_.end()) return it->second;
94
95 auto new_ops = DefVec(type->num_ops(), [&](size_t i) { return rewrite_type(type->op(i)); });
96 return mem_rewritten_[type] = type->rebuild(type->type(), new_ops);
97}
98
99const Def* AddMem::rewrite_pi(const Pi* pi) {
100 if (auto it = mem_rewritten_.find(pi); it != mem_rewritten_.end()) return it->second;
101
102 auto dom = pi->dom();
103 auto new_dom = DefVec(dom->num_projs(), [&](size_t i) { return rewrite_type(dom->proj(i)); });
104 if (pi->num_doms() == 0 || !match<mem::M>(pi->dom(0_s))) {
105 new_dom
106 = DefVec(dom->num_projs() + 1, [&](size_t i) { return i == 0 ? world().annex<mem::M>() : new_dom[i - 1]; });
107 }
108
109 return mem_rewritten_[pi] = world().pi(new_dom, pi->codom());
110}
111
// Rewrites def (encountered while rewriting curr_lam) and returns its rewritten counterpart.
// Lambdas get types from rewrite_pi (leading mem::M), and call sites pass the "current" mem
// (see mem_for_lam) in the mem slot.
//
// curr_lam: the lambda being rewritten. rewrite_lam passes itself here when descending into its
//           filter/body. It is used to resolve already-known mem defs and to detect re-entry into
//           a lambda that was previously only stubbed.
// Side effects: fills mem_rewritten_ (old def -> new def) and val2mem_ (lambda -> latest mem);
//               may stub lambdas and transfer externality to the stub.
// Relies on sched_, which visit() sets up for the enclosing scope.
const Def* AddMem::add_mem_to_lams(Lam* curr_lam, const Def* def) {
    // Lambda the scheduler places def in; most nested rewrites below use it as their curr_lam.
    auto place = static_cast<Lam*>(sched_.smart(def));

    // world().DLOG("rewriting {} : {} in {}", def, def->type(), place);

    // Globals, unset mutable Lams and Axioms are left untouched.
    if (auto global = def->isa<Global>()) return global;
    if (auto mut_lam = def->isa_mut<Lam>(); mut_lam && !mut_lam->is_set()) return def;
    if (auto ax = def->isa<Axiom>()) return ax;
    // Memoized: a known mem resolves to the *current* mem of curr_lam; other defs return the
    // cached rewrite, except when def is curr_lam itself (then fall through and rewrite its body).
    if (auto it = mem_rewritten_.find(def); it != mem_rewritten_.end()) {
        auto tmp = it->second;
        if (match<mem::M>(def->type())) {
            world().DLOG("already known mem {} in {}", def, curr_lam);
            auto new_mem = mem_for_lam(curr_lam);
            world().DLOG("new mem {} in {}", new_mem, curr_lam);
            return new_mem;
        }
        if (curr_lam != def) {
            // world().DLOG("rewritten def: {} : {} in {}", tmp, tmp->type(), curr_lam);
            return tmp;
        }
    }
    if (match<mem::M>(def->type())) world().DLOG("new mem {} in {}", def, curr_lam);

    // Rewrites a mutable Lam: stubs it with the mem-augmented type if that type changed.
    // Lams not bound in the current scope are only stubbed and recorded. Bound ones get their
    // vars remapped, their first var registered as their mem, and filter/body rewritten.
    auto rewrite_lam = [&](Lam* lam) -> const Def* {
        auto pi = lam->type()->as<Pi>();
        auto new_lam = lam;

        if (auto it = mem_rewritten_.find(lam); it != mem_rewritten_.end()) {
            if (curr_lam == lam) // i.e. we've stubbed this, but now we rewrite it
                new_lam = it->second->as_mut<Lam>();
            else if (auto pi = it->second->type()->as<Pi>(); pi->num_doms() > 0 && match<mem::M>(pi->dom(0_s)))
                return it->second; // already rewritten to a mem-taking lambda
        }

        if (!lam->is_set()) return lam;
        world().DLOG("rewrite lam {}", lam);

        bool is_bound = sched_.scope().bound(lam) || lam == curr_lam;

        if (new_lam == lam) // if not stubbed yet
            if (auto new_pi = rewrite_pi(pi); new_pi != pi) new_lam = lam->stub(new_pi);

        if (!is_bound) {
            world().DLOG("free lam {}", lam);
            mem_rewritten_[lam] = new_lam;
            return new_lam;
        }

        // Map old vars onto the new ones, shifted by the number of prepended doms.
        auto var_offset = new_lam->num_doms() - lam->num_doms(); // have we added a mem var?
        if (lam->num_vars() != 0) mem_rewritten_[lam->var()] = new_lam->var();
        for (size_t i = 0; i < lam->num_vars() && new_lam->num_vars() > 1; ++i)
            mem_rewritten_[lam->var(i)] = new_lam->var(i + var_offset);

        // The first var of the new lambda is its initial mem.
        auto var = new_lam->var(0_n);
        mem_rewritten_[new_lam] = new_lam;
        mem_rewritten_[lam] = new_lam;
        val2mem_[new_lam] = var;
        val2mem_[lam] = var;
        mem_rewritten_[var] = var;
        auto filter = add_mem_to_lams(lam, lam->filter());
        auto body = add_mem_to_lams(lam, lam->body());
        new_lam->unset()->set({filter, body});

        if (lam != new_lam && lam->is_external()) lam->transfer_external(new_lam);
        return new_lam;
    };

    // rewrite top-level lams
    if (auto lam = def->isa_mut<Lam>()) return rewrite_lam(lam);
    assert(!def->isa_mut());

    if (auto pi = def->isa<Pi>()) return rewrite_pi(pi);

    // Rewrites a call argument so that slot 0 holds the current mem of `place`.
    // If the old arg already starts with a mem, that mem is rewritten first (depth-first) and then
    // replaced; otherwise a mem slot is prepended (offset 1).
    auto rewrite_arg = [&](const Def* arg) -> const Def* {
        size_t offset = (arg->type()->num_projs() > 0 && match<mem::M>(arg->type()->proj(0))) ? 0 : 1;
        if (offset == 0) {
            // depth-first, follow the mems
            add_mem_to_lams(place, arg->proj(0));
        }

        DefVec new_args{arg->type()->num_projs() + offset};
        // Slot 0 is filled last, so mem_for_lam(place) sees any mem recorded in val2mem_ while
        // the other operands were rewritten.
        for (int i = new_args.size() - 1; i >= 0; i--) {
            new_args[i]
                = i == 0 ? add_mem_to_lams(place, mem_for_lam(place)) : add_mem_to_lams(place, arg->proj(i - offset));
        }
        return arg->world().tuple(new_args);
    };

    // call-site of a mutable lambda
    // (callee is a mutable Lam or an Extract from a Tuple of such; see isa_apped_mut_lam_in_tuple)
    if (auto apped_mut = isa_apped_mut_lam_in_tuple(def); apped_mut.first) {
        return mem_rewritten_[def]
             = rewrite_apped_mut_lam_in_tuple(def, std::move(rewrite_lam), std::move(rewrite_arg),
                                              [&](const Def* def) { return add_mem_to_lams(place, def); });
    }

    // call-site of a continuation
    // (callee depends on a Var, e.g. a lambda parameter; its type is expected to gain a mem too)
    if (auto app = def->isa<App>(); app && (app->callee()->dep() & Dep::Var)) {
        return mem_rewritten_[def]
             = app->rebuild(app->type(), {add_mem_to_lams(place, app->callee()), rewrite_arg(app->arg())});
    }

    // call-site of an axiom (assuming mems are only in the final app..)
    // assume all "negative" curry depths are fully applied axioms, so we do not want to rewrite those here..
    // NOTE(review): `^` binds tighter than `&&`, so this tests `(curry ^ 0x8000) != 0`, which holds for
    // every curry value except exactly 0x8000. If "negative" means the high bit is set,
    // `!(app->curry() & 0x8000)` may have been intended -- confirm.
    if (auto app = def->isa<App>(); app && app->axiom() && app->curry() ^ 0x8000) {
        auto arg = app->arg();
        DefVec new_args(arg->num_projs());
        for (int i = new_args.size() - 1; i >= 0; i--) {
            // replace memory operand with followed mem
            if (match<mem::M>(arg->proj(i)->type())) {
                // depth-first, follow the mems
                add_mem_to_lams(place, arg->proj(i));
                new_args[i] = add_mem_to_lams(place, mem_for_lam(place));
            } else {
                new_args[i] = add_mem_to_lams(place, arg->proj(i));
            }
        }
        auto rewritten = mem_rewritten_[def] = app->rebuild(app->type(), {add_mem_to_lams(place, app->callee()),
                                                                          world().tuple(new_args)->set(arg->dbg())})
                                                   ->set(app->dbg());
        // If the axiom yields a mem (directly or as its first projection), it becomes place's current mem.
        if (match<mem::M>(rewritten->type())) {
            world().DLOG("memory from axiom {} : {}", rewritten, rewritten->type());
            val2mem_[place] = rewritten;
        }
        if (rewritten->num_projs() > 0 && match<mem::M>(rewritten->proj(0)->type())) {
            world().DLOG("memory from axiom 2 {} : {}", rewritten, rewritten->type());
            mem_rewritten_[rewritten->proj(0)] = rewritten->proj(0);
            val2mem_[place] = rewritten->proj(0);
        }
        return rewritten;
    }

    // all other apps: when rewriting the callee adds a mem to the doms, add a mem to the arg as well..
    if (auto app = def->isa<App>()) {
        auto new_callee = add_mem_to_lams(place, app->callee());
        auto new_arg = add_mem_to_lams(place, app->arg());
        // Exactly one extra dom means rewrite_pi prepended a mem: re-derive the arg with a mem slot.
        if (app->callee()->type()->as<Pi>()->num_doms() + 1 == new_callee->type()->as<Pi>()->num_doms())
            new_arg = rewrite_arg(app->arg());
        auto rewritten = mem_rewritten_[def] = app->rebuild(app->type(), {new_callee, new_arg})->set(app->dbg());
        // As for axioms: a mem result (direct or first projection) becomes place's current mem.
        if (match<mem::M>(rewritten->type())) {
            world().DLOG("memory from other {} : {}", rewritten, rewritten->type());
            val2mem_[place] = rewritten;
        }
        if (rewritten->num_projs() > 0 && match<mem::M>(rewritten->proj(0)->type())) {
            world().DLOG("memory from other 2 {} : {}", rewritten, rewritten->type());
            mem_rewritten_[rewritten->proj(0)] = rewritten->proj(0);
            val2mem_[place] = rewritten->proj(0);
        }
        return rewritten;
    }

    // Any other immutable def: rewrite operands (mem operands are replaced by place's current mem)
    // and rebuild with a rewritten type.
    auto new_ops = DefVec(def->ops(), [&](const Def* op) {
        if (match<mem::M>(op->type())) {
            // depth-first, follow the mems
            add_mem_to_lams(place, op);
            return add_mem_to_lams(place, mem_for_lam(place));
        }
        return add_mem_to_lams(place, op);
    });

    auto tmp = mem_rewritten_[def] = def->rebuild(rewrite_type(def->type()), new_ops)->set(def->dbg());
    // if (match<mem::M>(tmp->type())) {
    //     world().DLOG("memory from other op 1 {} : {}", tmp, tmp->type());
    //     val2mem_[place] = tmp;
    // }
    // if (tmp->num_projs() > 0 && match<mem::M>(tmp->proj(0)->type())) {
    //     world().DLOG("memory from other op 2 {} : {}", tmp, tmp->type());
    //     mem_rewritten_[tmp->proj(0)] = tmp->proj(0);
    //     val2mem_[place] = tmp->proj(0);
    // }
    return tmp;
}
283
284} // namespace mim::plug::mem
Base class for all Defs.
Definition def.h:223
bool is_set() const
Yields true if empty or the last op is set.
Definition def.cpp:309
T * as_mut() const
Asserts that this is a mutable, casts constness away and performs a static_cast to T.
Definition def.h:455
nat_t num_vars()
Definition def.h:401
World & world() const
Definition def.cpp:415
T * isa_mut() const
If this is *mut*able, it will cast constness away and perform a dynamic_cast to T.
Definition def.h:447
bool is_external() const
Definition def.h:431
void transfer_external(Def *to)
Definition def.h:434
Ref var(nat_t a, nat_t i)
Definition def.h:401
A function.
Definition lam.h:103
Ref filter() const
Definition lam.h:113
Lam * stub(Ref type)
Definition lam.h:180
const Pi * type() const
Definition lam.h:115
Ref body() const
Definition lam.h:114
World & world()
Definition phase.h:25
Def * smart(const Def *)
Definition schedule.cpp:84
const Scope & scope() const
Definition schedule.h:26
const Scope & scope() const
Definition phase.h:159
A Scope represents a region of Defs that are live from the view of an entry's Var.
Definition scope.h:20
Def * entry() const
Definition scope.h:31
const MutSet & free_muts() const
All muts that occur free in this Scope.
Definition scope.h:43
bool bound(const Def *def) const
Definition scope.h:38
const Pi * pi(Ref dom, Ref codom, bool implicit=false)
Definition world.h:255
void visit(const Scope &) override
Definition add_mem.cpp:66
@ App
Definition def.h:40
@ Tuple
Definition def.h:40
@ Extract
Definition def.h:40
@ Pi
Definition def.h:40
@ Lam
Definition def.h:40
The mem Plugin
Definition mem.h:11
Ref mem_var(Lam *lam)
Returns the memory argument of a function if it has one.
Definition mem.h:38
Vector< const Def * > DefVec
Definition def.h:62
auto match(Ref def)
Definition axiom.h:112
DefVec rewrite(Def *mut, Ref arg)
Definition rewrite.cpp:45
Vector(I, I, A=A()) -> Vector< typename std::iterator_traits< I >::value_type, Default_Inlined_Size< typename std::iterator_traits< I >::value_type >, A >