MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
clos_conv.cpp
Go to the documentation of this file.
2
3#include "mim/check.h"
4
6
8#include "mim/plug/mem/mem.h"
9
10using namespace std::literals;
11
12namespace mim::plug::clos {
13
14namespace {
15
16bool is_toplevel(const Def* fd) {
17 return fd->dep_const() || fd->isa_mut<Global>() || fd->isa<Axiom>() || !fd->is_term();
18}
19
20bool is_memop_res(const Def* fd) {
21 auto proj = fd->isa<Extract>();
22 if (!proj) return false;
23 auto types = proj->tuple()->type()->ops();
24 return std::any_of(types.begin(), types.end(), [](auto d) { return match<mem::M>(d); });
25}
26
27} // namespace
28
29/*
30 * Free variable analysis
31 */
32
33void FreeDefAna::split_fd(Node* node, const Def* fd, bool& init_node, NodeQueue& worklist) {
34 assert(!match<mem::M>(fd) && "mem tokens must not be free");
35 if (is_toplevel(fd)) return;
36 if (auto [var, lam] = ca_isa_var<Lam>(fd); var && lam) {
37 if (var != lam->ret_var()) node->add_fvs(fd);
38 } else if (auto q = match(attr::freeBB, fd)) {
39 node->add_fvs(q);
40 } else if (auto pred = fd->isa_mut()) {
41 if (pred != node->mut) {
42 auto [pnode, inserted] = build_node(pred, worklist);
43 node->preds.push_back(pnode);
44 pnode->succs.push_back(node);
45 init_node |= inserted;
46 }
47 } else if (fd->dep() == Dep::Var && !fd->isa<Tuple>()) {
48 // Note: Var's can still have Def::Top, if their type is a mut!
49 // So the first case is *not* redundant
50 node->add_fvs(fd);
51 } else if (is_memop_res(fd)) {
52 // Results of memops must not be floated down
53 node->add_fvs(fd);
54 } else {
55 for (auto op : fd->ops()) split_fd(node, op, init_node, worklist);
56 }
57}
58
59std::pair<FreeDefAna::Node*, bool> FreeDefAna::build_node(Def* mut, NodeQueue& worklist) {
60 auto& w = world();
61 auto [p, inserted] = lam2nodes_.emplace(mut, nullptr);
62 if (!inserted) return {p->second.get(), false};
63 w.DLOG("FVA: create node: {}", mut);
64 p->second = std::make_unique<Node>(Node{mut, {}, {}, {}, 0});
65 auto node = p->second.get();
66 auto scope = Scope(mut);
67 bool init_node = false;
68 for (auto v : scope.free_defs()) split_fd(node, v, init_node, worklist);
69 if (!init_node) {
70 worklist.push(node);
71 w.DLOG("FVA: init {}", mut);
72 }
73 return {node, true};
74}
75
76void FreeDefAna::run(NodeQueue& worklist) {
77 while (!worklist.empty()) {
78 auto node = worklist.front();
79 worklist.pop();
80 if (is_done(node)) continue;
81 auto changed = is_bot(node);
82 mark(node);
83 for (auto p : node->preds) {
84 auto& pfvs = p->fvs;
85 for (auto&& pfv : pfvs) changed |= node->add_fvs(pfv).second;
86 }
87 if (changed)
88 for (auto s : node->succs) worklist.push(s);
89 }
90}
91
93 auto worklist = NodeQueue();
94 auto [node, _] = build_node(lam, worklist);
95 if (!is_done(node)) {
96 cur_pass_id++;
97 run(worklist);
98 }
99
100 return node->fvs;
101}
102
103/*
104 * Closure Conversion
105 */
106
108 auto externals = Vector(world().externals().begin(), world().externals().end());
109 auto subst = Def2Def();
110 for (auto [_, ext_def] : externals) rewrite(ext_def, subst);
111 while (!worklist_.empty()) {
112 auto def = worklist_.front();
113 subst = Def2Def();
114 worklist_.pop();
115 if (auto i = closures_.find(def); i != closures_.end()) {
116 rewrite_body(i->second.fn, subst);
117 } else {
118 world().DLOG("RUN: rewrite def {}", def);
119 rewrite(def, subst);
120 }
121 }
122}
123
124void ClosConv::rewrite_body(Lam* new_lam, Def2Def& subst) {
125 auto& w = world();
126 auto it = closures_.find(new_lam);
127 assert(it != closures_.end() && "closure should have a stub if rewrite_body is called!");
128 auto [old_fn, num_fvs, env, new_fn] = it->second;
129
130 if (!old_fn->is_set()) return;
131
132 w.DLOG("rw body: {} [old={}, env={}]\nt", new_fn, old_fn, env);
133 auto env_param = new_fn->var(Clos_Env_Param)->set("closure_env");
134 if (num_fvs == 1) {
135 subst.emplace(env, env_param);
136 } else {
137 for (size_t i = 0; i < num_fvs; i++) {
138 auto fv = env->op(i);
139 auto sym = w.sym("fv_"s + (fv->sym() ? fv->sym().str() : std::to_string(i)));
140 subst.emplace(fv, env_param->proj(i)->set(sym));
141 }
142 }
143
144 auto params = w.tuple(DefVec(old_fn->num_doms(), [&](auto i) { return new_lam->var(skip_env(i)); }));
145 subst.emplace(old_fn->var(), params);
146
147 auto filter = rewrite(new_fn->filter(), subst);
148 auto body = rewrite(new_fn->body(), subst);
149 new_fn->reset({filter, body});
150}
151
/// Rewrites @p def under the substitution @p subst and returns the result.
///
/// Most results are memoized in @p subst via the local helper `map`. The cases are tried in this order:
/// 1. `Type`, `Univ`, `Nat`, `Bot` and `Top` nodes are returned unchanged.
/// 2. A def already present in @p subst yields its recorded replacement.
/// 3. A continuation `Pi` is turned into a closure type (`type_clos`).
/// 4. A mutable continuation `Lam` gets a stub (`make_stub`) and is packed into a closure (`clos_pack`).
/// 5. `attr` annotations (`returning`, `fstclassBB`, `freeBB`) are handled specially.
/// 6. The return variable of a `Lam` is redirected to the corresponding variable of that `Lam`'s stub.
/// 7. Everything else is rewritten structurally: mutables via `rewrite_mut`, immutables via `rebuild`.
152const Def* ClosConv::rewrite(const Def* def, Def2Def& subst) {
153 switch (def->node()) {
154 case Node::Type:
155 case Node::Univ:
156 case Node::Nat:
157 case Node::Bot: // TODO This is used by the AD stuff????
158 case Node::Top: return def;
159 default: break;
160 }
161
162 auto& w = world();
 // Records `def -> new_def` in subst and hands new_def back, so it can be used as `return map(...)`.
163 auto map = [&](const Def* new_def) {
164 subst[def] = new_def;
165 assert(subst[def] == new_def);
166 return new_def;
167 };
168
169 if (auto i = subst.find(def); i != subst.end()) {
170 return i->second;
171 } else if (auto pi = Pi::isa_cn(def)) {
172 return map(type_clos(pi, subst));
173 } else if (auto lam = def->isa_mut<Lam>(); lam && Lam::isa_cn(lam)) {
 // Create (or fetch) the stub for lam and pack it together with its rewritten environment.
174 auto [_, __, fv_env, new_lam] = make_stub(lam, subst);
175 auto clos_ty = rewrite(lam->type(), subst);
176 auto env = rewrite(fv_env, subst);
177 auto closure = clos_pack(env, new_lam, clos_ty);
178 world().DLOG("RW: pack {} ~> {} : {}", lam, closure, clos_ty);
179 return map(closure);
180 } else if (auto a = match<attr>(def)) {
181 switch (a.id()) {
182 case attr::returning:
183 if (auto ret_lam = a->arg()->isa_mut<Lam>()) {
184 // assert(ret_lam && ret_lam->is_basicblock());
185 // Note: This should be cont_lam's only occurrence after η-expansion, so it's okay to
186 // put into the local subst only
187 auto new_doms
188 = DefVec(ret_lam->num_doms(), [&](auto i) { return rewrite(ret_lam->dom(i), subst); });
189 auto new_lam = ret_lam->stub(w.cn(new_doms));
190 subst[ret_lam] = new_lam;
191 if (ret_lam->is_set()) {
192 new_lam->set_filter(rewrite(ret_lam->filter(), subst));
193 new_lam->set_body(rewrite(ret_lam->body(), subst));
194 }
 // Only ret_lam is recorded in subst here; the annotation `def` itself is not memoized via map.
195 return new_lam;
196 }
 // The argument is not a mutable Lam: fall through to the generic rewriting below.
197 break;
198 case attr::fstclassBB:
199 case attr::freeBB: {
200 // Note: Same thing about η-conversion applies here
201 auto bb_lam = a->arg()->isa_mut<Lam>();
202 assert(bb_lam && Lam::isa_basicblock(bb_lam));
 // The stub is created with an empty set of free variables and packed with an empty tuple as environment.
203 auto [_, __, ___, new_lam] = make_stub({}, bb_lam, subst);
204 subst[bb_lam] = clos_pack(w.tuple(), new_lam, rewrite(bb_lam->type(), subst));
205 rewrite_body(new_lam, subst);
206 return map(subst[bb_lam]);
207 }
208 default: break;
209 }
210 } else if (auto [var, lam] = ca_isa_var<Lam>(def); var && lam && lam->ret_var() == var) {
211 // HACK to rewrite a retvar that is defined in an enclosing lambda
212 // If we put external bb's into the env, this should never happen
213 auto new_lam = make_stub(lam, subst).fn;
214 auto new_idx = skip_env(Lit::as(var->index()));
215 return map(new_lam->var(new_idx));
216 }
217
 // Generic case: rewrite the type first, then the def itself.
218 auto new_type = rewrite(def->type(), subst);
219
220 if (auto mut = def->isa_mut()) {
221 if (auto global = def->isa_mut<Global>()) {
 // Globals are cached in glob_muts_ and rewritten with a fresh substitution
 // (the local `subst` intentionally shadows the parameter).
222 if (auto i = glob_muts_.find(global); i != glob_muts_.end()) return i->second;
223 auto subst = Def2Def();
224 return glob_muts_[mut] = rewrite_mut(global, new_type, subst);
225 }
226 assert(!isa_clos_type(mut));
227 w.DLOG("RW: mut {}", mut);
228 auto new_mut = rewrite_mut(mut, new_type, subst);
229 // Try to reduce the amount of muts that are created
 // NOTE(review): Globals already returned above, so the `!mut->isa_mut<Global>()` test looks redundant here.
230 if (!mut->isa_mut<Global>() && Check::alpha(mut, new_mut)) return map(mut);
231 if (auto imm = new_mut->immutabilize()) return map(imm);
232 return map(new_mut);
233 } else {
234 auto new_ops = DefVec(def->num_ops(), [&](auto i) { return rewrite(def->op(i), subst); });
 // An App whose rewritten callee has a Sigma type is turned into a closure application.
235 if (auto app = def->isa<App>(); app && new_ops[0]->type()->isa<Sigma>())
236 return map(clos_apply(new_ops[0], new_ops[1]));
237 else if (def->isa<Axiom>())
238 return def;
239 else
240 return map(def->rebuild(new_type, new_ops));
241 }
242
 // Both branches above return.
243 fe::unreachable();
244}
245
246Def* ClosConv::rewrite_mut(Def* mut, const Def* new_type, Def2Def& subst) {
247 auto new_mut = mut->stub(new_type);
248 subst.emplace(mut, new_mut);
249 for (size_t i = 0; i < mut->num_ops(); i++)
250 if (mut->op(i)) new_mut->set(i, rewrite(mut->op(i), subst));
251 return new_mut;
252}
253
254const Pi* ClosConv::rewrite_type_cn(const Pi* pi, Def2Def& subst) {
255 assert(Pi::isa_basicblock(pi));
256 auto new_ops = DefVec(pi->num_doms(), [&](auto i) { return rewrite(pi->dom(i), subst); });
257 return world().cn(new_ops);
258}
259
260const Def* ClosConv::type_clos(const Pi* pi, Def2Def& subst, const Def* env_type) {
261 if (auto i = glob_muts_.find(pi); i != glob_muts_.end() && !env_type) return i->second;
262 auto& w = world();
263 auto new_doms = DefVec(pi->num_doms(), [&](auto i) {
264 return (i == pi->num_doms() - 1 && Pi::isa_returning(pi)) ? rewrite_type_cn(pi->ret_pi(), subst)
265 : rewrite(pi->dom(i), subst);
266 });
267 auto ct = ctype(w, new_doms, env_type);
268 if (!env_type) {
269 glob_muts_.emplace(pi, ct);
270 w.DLOG("C-TYPE: pct {} ~~> {}", pi, ct);
271 } else {
272 w.DLOG("C-TYPE: ct {}, env = {} ~~> {}", pi, env_type, ct);
273 }
274 return ct;
275}
276
277ClosConv::Stub ClosConv::make_stub(const DefSet& fvs, Lam* old_lam, Def2Def& subst) {
278 auto& w = world();
279 auto env = w.tuple(DefVec(fvs.begin(), fvs.end()));
280 auto num_fvs = fvs.size();
281 auto env_type = rewrite(env->type(), subst);
282 auto new_fn_type = type_clos(old_lam->type(), subst, env_type)->as<Pi>();
283 auto new_lam = old_lam->stub(new_fn_type);
284 // TODO
285 // new_lam->set_debug_name((old_lam->is_external() || !old_lam->is_set()) ? "cc_" + old_lam->name() :
286 // old_lam->name());
287 if (!isa_workable(old_lam)) {
288 auto new_ext_type = w.cn(clos_remove_env(new_fn_type->dom()));
289 auto new_ext_lam = old_lam->stub(new_ext_type);
290 w.DLOG("wrap ext lam: {} -> stub: {}, ext: {}", old_lam, new_lam, new_ext_lam);
291 if (old_lam->is_set()) {
292 old_lam->transfer_external(new_ext_lam);
293 new_ext_lam->app(false, new_lam, clos_insert_env(env, new_ext_lam->var()));
294 new_lam->set(old_lam->filter(), old_lam->body());
295 } else {
296 new_ext_lam->unset();
297 new_lam->app(false, new_ext_lam, clos_remove_env(new_lam->var()));
298 }
299 } else {
300 new_lam->set(old_lam->filter(), old_lam->body());
301 }
302 w.DLOG("STUB {} ~~> ({}, {})", old_lam, env, new_lam);
303 auto closure = Stub{old_lam, num_fvs, env, new_lam};
304 closures_.emplace(old_lam, closure);
305 closures_.emplace(closure.fn, closure);
306 return closure;
307}
308
309ClosConv::Stub ClosConv::make_stub(Lam* old_lam, Def2Def& subst) {
310 if (auto i = closures_.find(old_lam); i != closures_.end()) return i->second;
311 auto fvs = fva_.run(old_lam);
312 auto closure = make_stub(fvs, old_lam, subst);
313 worklist_.emplace(closure.fn);
314 return closure;
315}
316
317} // namespace mim::plug::clos
static bool alpha(Ref d1, Ref d2)
Are d1 and d2 α-equivalent?
Definition check.h:65
Ref var(nat_t a, nat_t i)
Definition def.h:401
A function.
Definition lam.h:103
Lam * stub(Ref type)
Definition lam.h:180
Lam * set_filter(Filter)
Set filter first.
Definition lam.cpp:28
Lam * set_body(const Def *body)
Set body second.
Definition lam.h:168
static const Lam * isa_cn(Ref d)
Definition lam.h:138
static const Lam * isa_basicblock(Ref d)
Definition lam.h:139
static T as(Ref def)
Definition def.h:768
World & world()
Definition phase.h:25
static const Pi * isa_cn(Ref d)
Is this a continuation - i.e. is the Pi::codom mim::Bottom?
Definition lam.h:50
static const Pi * isa_basicblock(Ref d)
Is this a continuation (Pi::isa_cn) that is not Pi::isa_returning?
Definition lam.h:54
const Pi * cn()
Definition world.h:267
void start() override
Actual entry.
DefSet & run(Lam *lam)
FreeDefAna::run will compute free defs (FD) that appear in lams body.
Definition clos_conv.cpp:92
@ Type
Definition def.h:40
@ Univ
Definition def.h:40
@ Nat
Definition def.h:40
@ Bot
Definition def.h:40
@ Top
Definition def.h:40
@ Extract
Definition def.h:40
@ Global
Definition def.h:40
@ Pi
Definition def.h:40
@ Lam
Definition def.h:40
The clos Plugin
Definition clos.h:7
Ref clos_apply(Ref closure, Ref args)
Apply a closure to arguments.
Definition clos.cpp:100
Ref clos_remove_env(size_t i, std::function< Ref(size_t)> f)
Definition clos.cpp:143
static constexpr size_t Clos_Env_Param
Describes where the environment is placed in the argument list.
Definition clos.h:107
const Sigma * isa_clos_type(Ref def)
Definition clos.cpp:111
std::tuple< const Extract *, N * > ca_isa_var(Ref def)
Checks is def is the variable of a mut of type N.
Definition clos.h:79
Ref ctype(World &w, Defs doms, Ref env_type=nullptr)
Definition clos.cpp:145
size_t skip_env(size_t i)
Definition clos.h:113
Ref clos_pack(Ref env, Ref fn, Ref ct=nullptr)
Pack a typed closure. This assumes that fn expects the environment as its Clos_Env_Paramth argument.
Definition clos.cpp:78
Ref clos_insert_env(size_t i, Ref env, std::function< Ref(size_t)> f)
Definition clos.cpp:139
DefMap< const Def * > Def2Def
Definition def.h:60
Vector< const Def * > DefVec
Definition def.h:62
auto match(Ref def)
Definition axiom.h:112
Lam * isa_workable(Lam *lam)
These are Lams that are neither nullptr, nor Lam::is_external, nor Lam::is_unset.
Definition lam.h:251
GIDSet< const Def * > DefSet
Definition def.h:59
Vector(I, I, A=A()) -> Vector< typename std::iterator_traits< I >::value_type, Default_Inlined_Size< typename std::iterator_traits< I >::value_type >, A >