MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
add_mem.cpp
Go to the documentation of this file.
2
3#include <mim/schedule.h>
4
5#include "mim/plug/mem/mem.h"
6
7namespace mim::plug::mem {
8
9namespace {
10
11std::pair<const App*, Vector<Lam*>> isa_apped_mut_lam_in_tuple(const Def* def) {
12 if (auto app = def->isa<App>()) {
13 Vector<Lam*> lams;
14 std::deque<const Def*> wl;
15 wl.push_back(app->callee());
16 while (!wl.empty()) {
17 auto elem = wl.front();
18 wl.pop_front();
19 if (auto mut = elem->isa_mut<Lam>()) {
20 lams.push_back(mut);
21 } else if (auto extract = elem->isa<Extract>()) {
22 if (auto tuple = extract->tuple()->isa<Tuple>())
23 for (auto&& op : tuple->ops()) wl.push_back(op);
24 else
25 return {nullptr, {}};
26 } else {
27 return {nullptr, {}};
28 }
29 }
30 return {app, lams};
31 }
32 return {nullptr, {}};
33}
34
35// @pre isa_apped_mut_lam_in_tuple(def) valid
36template<class F, class H> const Def* rewrite_mut_lam_in_tuple(const Def* def, F&& rewrite, H&& rewrite_idx) {
37 auto& w = def->world();
38 if (auto mut = def->isa_mut<Lam>()) return std::forward<F>(rewrite)(mut);
39
40 auto extract = def->as<Extract>();
41 auto tuple = extract->tuple()->as<Tuple>();
42 auto new_ops = DefVec(tuple->ops(), [&](const Def* op) {
43 return rewrite_mut_lam_in_tuple(op, std::forward<F>(rewrite), std::forward<H>(rewrite_idx));
44 });
45 return w.extract(w.tuple(new_ops), rewrite_idx(extract->index()));
46}
47
48// @pre isa_apped_mut_lam_in_tuple(def) valid
49template<class RewriteCallee, class RewriteArg, class RewriteIdx>
50const Def* rewrite_apped_mut_lam_in_tuple(const Def* def,
51 RewriteCallee&& rewrite_callee,
52 RewriteArg&& rewrite_arg,
53 RewriteIdx&& rewrite_idx) {
54 auto app = def->as<App>();
55 auto callee = rewrite_mut_lam_in_tuple(app->callee(), std::forward<RewriteCallee>(rewrite_callee),
56 std::forward<RewriteIdx>(rewrite_idx));
57 auto arg = std::forward<RewriteArg>(rewrite_arg)(app->arg());
58 return app->rebuild(app->type(), {callee, arg});
59}
60
61} // namespace
62
63// Entry point of the phase.
64void AddMem::visit(const Nest& nest) {
65 sched_ = Scheduler{nest};
66 add_mem_to_lams(root(), root());
67}
68
69const Def* AddMem::mem_for_lam(Lam* lam) const {
70 if (auto it = mem_rewritten_.find(lam); it != mem_rewritten_.end()) {
71 // We created a new lambda. Therefore, we want to lookup the mem for the new lambda.
72 lam = it->second->as_mut<Lam>();
73 }
74 if (auto it = val2mem_.find(lam); it != val2mem_.end()) {
75 lam->world().DLOG("found mem for {} in val2mem_ : {}", lam, it->second);
76 // We found a (overwritten) memory in the lambda.
77 return it->second;
78 }
79 // As a fallback, we lookup the memory in vars of the lambda.
80 auto mem = mem::mem_var(lam);
81 assert(mem && "mut must have mem!");
82 return mem;
83}
84
85const Def* AddMem::rewrite_type(const Def* type) {
86 if (auto pi = type->isa<Pi>()) return rewrite_pi(pi);
87
88 if (auto it = mem_rewritten_.find(type); it != mem_rewritten_.end()) return it->second;
89
90 auto new_ops = DefVec(type->num_ops(), [&](size_t i) { return rewrite_type(type->op(i)); });
91 return mem_rewritten_[type] = type->rebuild(type->type(), new_ops);
92}
93
94const Def* AddMem::rewrite_pi(const Pi* pi) {
95 if (auto it = mem_rewritten_.find(pi); it != mem_rewritten_.end()) return it->second;
96
97 auto dom = pi->dom();
98 auto new_dom = DefVec(dom->num_projs(), [&](size_t i) { return rewrite_type(dom->proj(i)); });
99 if (pi->num_doms() == 0 || !match<mem::M>(pi->dom(0_s))) {
100 new_dom
101 = DefVec(dom->num_projs() + 1, [&](size_t i) { return i == 0 ? world().annex<mem::M>() : new_dom[i - 1]; });
102 }
103
104 return mem_rewritten_[pi] = world().pi(new_dom, pi->codom());
105}
106
107const Def* AddMem::add_mem_to_lams(Lam* curr_lam, const Def* def) {
108 auto place = static_cast<Lam*>(sched_.smart(curr_lam, def)->mut());
109
110 // world().DLOG("rewriting {} : {} in {}", def, def->type(), place);
111
112 if (auto global = def->isa<Global>()) return global;
113 if (auto mut_lam = def->isa_mut<Lam>(); mut_lam && !mut_lam->is_set()) return def;
114 if (auto ax = def->isa<Axiom>()) return ax;
115 if (auto it = mem_rewritten_.find(def); it != mem_rewritten_.end()) {
116 auto tmp = it->second;
117 if (match<mem::M>(def->type())) {
118 world().DLOG("already known mem {} in {}", def, curr_lam);
119 auto new_mem = mem_for_lam(curr_lam);
120 world().DLOG("new mem {} in {}", new_mem, curr_lam);
121 return new_mem;
122 }
123 if (curr_lam != def) {
124 // world().DLOG("rewritten def: {} : {} in {}", tmp, tmp->type(), curr_lam);
125 return tmp;
126 }
127 }
128 if (match<mem::M>(def->type())) world().DLOG("new mem {} in {}", def, curr_lam);
129
130 auto rewrite_lam = [&](Lam* lam) -> const Def* {
131 auto pi = lam->type()->as<Pi>();
132 auto new_lam = lam;
133
134 if (auto it = mem_rewritten_.find(lam); it != mem_rewritten_.end()) {
135 if (curr_lam == lam) // i.e. we've stubbed this, but now we rewrite it
136 new_lam = it->second->as_mut<Lam>();
137 else if (auto pi = it->second->type()->as<Pi>(); pi->num_doms() > 0 && match<mem::M>(pi->dom(0_s)))
138 return it->second;
139 }
140
141 if (!lam->is_set()) return lam;
142 world().DLOG("rewrite lam {}", lam);
143
144 bool is_bound = sched_.nest().contains(lam) || lam == curr_lam;
145
146 if (new_lam == lam) // if not stubbed yet
147 if (auto new_pi = rewrite_pi(pi); new_pi != pi) new_lam = lam->stub(new_pi);
148
149 if (!is_bound) {
150 world().DLOG("free lam {}", lam);
151 mem_rewritten_[lam] = new_lam;
152 return new_lam;
153 }
154
155 auto var_offset = new_lam->num_doms() - lam->num_doms(); // have we added a mem var?
156 if (lam->num_vars() != 0) mem_rewritten_[lam->var()] = new_lam->var();
157 for (size_t i = 0; i < lam->num_vars() && new_lam->num_vars() > 1; ++i)
158 mem_rewritten_[lam->var(i)] = new_lam->var(i + var_offset);
159
160 auto var = new_lam->var(0_n);
161 mem_rewritten_[new_lam] = new_lam;
162 mem_rewritten_[lam] = new_lam;
163 val2mem_[new_lam] = var;
164 val2mem_[lam] = var;
165 mem_rewritten_[var] = var;
166 auto filter = add_mem_to_lams(lam, lam->filter());
167 auto body = add_mem_to_lams(lam, lam->body());
168 new_lam->unset()->set({filter, body});
169
170 if (lam != new_lam && lam->is_external()) lam->transfer_external(new_lam);
171 return new_lam;
172 };
173
174 // rewrite top-level lams
175 if (auto lam = def->isa_mut<Lam>()) return rewrite_lam(lam);
176 assert(!def->isa_mut());
177
178 if (auto pi = def->isa<Pi>()) return rewrite_pi(pi);
179
180 auto rewrite_arg = [&](const Def* arg) -> const Def* {
181 size_t offset = (arg->type()->num_projs() > 0 && match<mem::M>(arg->type()->proj(0))) ? 0 : 1;
182 if (offset == 0) {
183 // depth-first, follow the mems
184 add_mem_to_lams(place, arg->proj(0));
185 }
186
187 DefVec new_args{arg->type()->num_projs() + offset};
188 for (int i = new_args.size() - 1; i >= 0; i--) {
189 new_args[i]
190 = i == 0 ? add_mem_to_lams(place, mem_for_lam(place)) : add_mem_to_lams(place, arg->proj(i - offset));
191 }
192 return arg->world().tuple(new_args);
193 };
194
195 // call-site of a mutable lambda
196 if (auto apped_mut = isa_apped_mut_lam_in_tuple(def); apped_mut.first) {
197 return mem_rewritten_[def]
198 = rewrite_apped_mut_lam_in_tuple(def, std::move(rewrite_lam), std::move(rewrite_arg),
199 [&](const Def* def) { return add_mem_to_lams(place, def); });
200 }
201
202 // call-site of a continuation
203 if (auto app = def->isa<App>(); app && (app->callee()->has_dep(Dep::Var))) {
204 return mem_rewritten_[def]
205 = app->rebuild(app->type(), {add_mem_to_lams(place, app->callee()), rewrite_arg(app->arg())});
206 }
207
208 // call-site of an axiom (assuming mems are only in the final app..)
209 // assume all "negative" curry depths are fully applied axioms, so we do not want to rewrite those here..
210 if (auto app = def->isa<App>(); app && app->axiom() && app->curry() ^ 0x8000) {
211 auto arg = app->arg();
212 DefVec new_args(arg->num_projs());
213 for (int i = new_args.size() - 1; i >= 0; i--) {
214 // replace memory operand with followed mem
215 if (match<mem::M>(arg->proj(i)->type())) {
216 // depth-first, follow the mems
217 add_mem_to_lams(place, arg->proj(i));
218 new_args[i] = add_mem_to_lams(place, mem_for_lam(place));
219 } else {
220 new_args[i] = add_mem_to_lams(place, arg->proj(i));
221 }
222 }
223 auto rewritten = mem_rewritten_[def] = app->rebuild(app->type(), {add_mem_to_lams(place, app->callee()),
224 world().tuple(new_args)->set(arg->dbg())})
225 ->set(app->dbg());
226 if (match<mem::M>(rewritten->type())) {
227 world().DLOG("memory from axiom {} : {}", rewritten, rewritten->type());
228 val2mem_[place] = rewritten;
229 }
230 if (rewritten->num_projs() > 0 && match<mem::M>(rewritten->proj(0)->type())) {
231 world().DLOG("memory from axiom 2 {} : {}", rewritten, rewritten->type());
232 mem_rewritten_[rewritten->proj(0)] = rewritten->proj(0);
233 val2mem_[place] = rewritten->proj(0);
234 }
235 return rewritten;
236 }
237
238 // all other apps: when rewriting the callee adds a mem to the doms, add a mem to the arg as well..
239 if (auto app = def->isa<App>()) {
240 auto new_callee = add_mem_to_lams(place, app->callee());
241 auto new_arg = add_mem_to_lams(place, app->arg());
242 if (app->callee()->type()->as<Pi>()->num_doms() + 1 == new_callee->type()->as<Pi>()->num_doms())
243 new_arg = rewrite_arg(app->arg());
244 auto rewritten = mem_rewritten_[def] = app->rebuild(app->type(), {new_callee, new_arg})->set(app->dbg());
245 if (match<mem::M>(rewritten->type())) {
246 world().DLOG("memory from other {} : {}", rewritten, rewritten->type());
247 val2mem_[place] = rewritten;
248 }
249 if (rewritten->num_projs() > 0 && match<mem::M>(rewritten->proj(0)->type())) {
250 world().DLOG("memory from other 2 {} : {}", rewritten, rewritten->type());
251 mem_rewritten_[rewritten->proj(0)] = rewritten->proj(0);
252 val2mem_[place] = rewritten->proj(0);
253 }
254 return rewritten;
255 }
256
257 auto new_ops = DefVec(def->ops(), [&](const Def* op) {
258 if (match<mem::M>(op->type())) {
259 // depth-first, follow the mems
260 add_mem_to_lams(place, op);
261 return add_mem_to_lams(place, mem_for_lam(place));
262 }
263 return add_mem_to_lams(place, op);
264 });
265
266 auto tmp = mem_rewritten_[def] = def->rebuild(rewrite_type(def->type()), new_ops)->set(def->dbg());
267 // if (match<mem::M>(tmp->type())) {
268 // world().DLOG("memory from other op 1 {} : {}", tmp, tmp->type());
269 // val2mem_[place] = tmp;
270 // }
271 // if (tmp->num_projs() > 0 && match<mem::M>(tmp->proj(0)->type())) {
272 // world().DLOG("memory from other op 2 {} : {}", tmp, tmp->type());
273 // mem_rewritten_[tmp->proj(0)] = tmp->proj(0);
274 // val2mem_[place] = tmp->proj(0);
275 // }
276 return tmp;
277}
278
279} // namespace mim::plug::mem
Lam * root() const
Definition phase.h:172
Base class for all Defs.
Definition def.h:212
bool is_set() const
Yields true if empty or the last op is set.
Definition def.cpp:281
T * as_mut() const
Asserts that this is a mutable, casts constness away and performs a static_cast to T.
Definition def.h:440
World & world() const noexcept
Definition def.cpp:384
bool is_external() const
Definition def.h:416
void transfer_external(Def *to)
Definition def.h:419
nat_t num_vars() noexcept
Definition def.h:381
Ref var(nat_t a, nat_t i) noexcept
Definition def.h:381
A function.
Definition lam.h:105
Ref filter() const
Definition lam.h:117
Lam * stub(Ref type)
Definition lam.h:177
const Pi * type() const
Definition lam.h:125
Ref body() const
Definition lam.h:118
const Nest & nest() const
Definition phase.h:203
Builds a nesting tree of all immutables‍/binders.
Definition nest.h:11
bool contains(const Def *def) const
Definition nest.h:131
World & world()
Definition phase.h:27
const Nest::Node * smart(Def *curr, const Def *)
Definition schedule.cpp:79
const Nest & nest() const
Definition schedule.h:61
const Pi * pi(Ref dom, Ref codom, bool implicit=false)
Definition world.h:262
void visit(const Nest &) override
Definition add_mem.cpp:64
@ App
Definition def.h:40
@ Tuple
Definition def.h:40
@ Extract
Definition def.h:40
@ Pi
Definition def.h:40
@ Lam
Definition def.h:40
The mem Plugin
Definition mem.h:11
Ref mem_var(Lam *lam)
Returns the memory argument of a function if it has one.
Definition mem.h:38
Vector< const Def * > DefVec
Definition def.h:62
auto match(Ref def)
Definition axiom.h:112
Vector(I, I, A=A()) -> Vector< typename std::iterator_traits< I >::value_type, Default_Inlined_Size< typename std::iterator_traits< I >::value_type >, A >