MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
clos_conv.cpp
Go to the documentation of this file.
2
3#include "mim/check.h"
4
6
7using namespace std::literals;
8
9namespace mim::plug::clos {
10
11namespace {
12
13bool is_memop_res(const Def* fd) {
14 auto proj = fd->isa<Extract>();
15 if (!proj) return false;
16 auto types = proj->tuple()->type()->ops();
17 return std::any_of(types.begin(), types.end(), [](auto d) { return Axm::isa<mem::M>(d); });
18}
19
20} // namespace
21
22DefSet free_defs(const Nest& nest) {
23 DefSet bound;
24 DefSet free;
25 std::queue<const Def*> queue;
26 queue.emplace(nest.root()->mut());
27
28 while (!queue.empty()) {
29 auto def = pop(queue);
30 for (auto op : def->deps()) {
31 if (op->is_closed()) {
32 // do nothing
33 } else if (nest.contains(op)) {
34 if (auto [_, ins] = bound.emplace(op); ins) queue.emplace(op);
35 } else {
36 free.emplace(op);
37 }
38 }
39 }
40
41 return free;
42}
43
44/*
45 * Free variable analysis
46 */
47
48void FreeDefAna::split_fd(Node* node, const Def* fd, bool& init_node, NodeQueue& worklist) {
49 assert(!Axm::isa<mem::M>(fd) && "mem tokens must not be free");
50 if (fd->is_closed()) return;
51 if (auto [var, lam] = ca_isa_var<Lam>(fd); var && lam) {
52 if (var != lam->ret_var()) node->add_fvs(fd);
53 } else if (auto q = Axm::isa(attr::freeBB, fd)) {
54 node->add_fvs(q);
55 } else if (auto pred = fd->isa_mut()) {
56 if (pred != node->mut) {
57 auto [pnode, inserted] = build_node(pred, worklist);
58 node->preds.push_back(pnode);
59 pnode->succs.push_back(node);
60 init_node |= inserted;
61 }
62 } else if (fd->has_dep(Dep::Var) && !fd->isa<Tuple>()) {
63 // Note: Var's can still have Def::Top, if their type is a mut!
64 // So the first case is *not* redundant
65 node->add_fvs(fd);
66 } else if (is_memop_res(fd)) {
67 // Results of memops must not be floated down
68 node->add_fvs(fd);
69 } else {
70 for (auto op : fd->ops())
71 split_fd(node, op, init_node, worklist);
72 }
73}
74
75std::pair<FreeDefAna::Node*, bool> FreeDefAna::build_node(Def* mut, NodeQueue& worklist) {
76 auto& w = world();
77 auto [p, inserted] = lam2nodes_.emplace(mut, nullptr);
78 if (!inserted) return {p->second.get(), false};
79 w.DLOG("FVA: create node: {}", mut);
80 p->second = std::make_unique<Node>(Node{mut, {}, {}, {}, 0});
81 auto node = p->second.get();
82 auto nest = Nest(mut);
83 bool init_node = false;
84 for (auto v : free_defs(nest))
85 split_fd(node, v, init_node, worklist);
86 if (!init_node) {
87 worklist.push(node);
88 w.DLOG("FVA: init {}", mut);
89 }
90 return {node, true};
91}
92
93void FreeDefAna::run(NodeQueue& worklist) {
94 while (!worklist.empty()) {
95 auto node = worklist.front();
96 worklist.pop();
97 if (is_done(node)) continue;
98 auto changed = is_bot(node);
99 mark(node);
100 for (auto p : node->preds) {
101 auto& pfvs = p->fvs;
102 for (auto&& pfv : pfvs)
103 changed |= node->add_fvs(pfv).second;
104 }
105 if (changed)
106 for (auto s : node->succs)
107 worklist.push(s);
108 }
109}
110
112 auto worklist = NodeQueue();
113 auto [node, _] = build_node(lam, worklist);
114 if (!is_done(node)) {
115 cur_pass_id++;
116 run(worklist);
117 }
118
119 return node->fvs;
120}
121
122/*
123 * Closure Conversion
124 */
125
127 auto subst = Def2Def();
128 for (auto mut : world().copy_externals())
129 rewrite(mut, subst);
130 while (!worklist_.empty()) {
131 auto def = worklist_.front();
132 subst = Def2Def();
133 worklist_.pop();
134 if (auto i = closures_.find(def); i != closures_.end()) {
135 rewrite_body(i->second.fn, subst);
136 } else {
137 DLOG("RUN: rewrite def {}", def);
138 rewrite(def, subst);
139 }
140 }
141}
142
143void ClosConv::rewrite_body(Lam* new_lam, Def2Def& subst) {
144 auto& w = world();
145 auto it = closures_.find(new_lam);
146 assert(it != closures_.end() && "closure should have a stub if rewrite_body is called!");
147 auto [old_fn, num_fvs, env, new_fn] = it->second;
148
149 if (!old_fn->is_set()) return;
150
151 DLOG("rw body: {} [old={}, env={}]\nt", new_fn, old_fn, env);
152 auto env_param = new_fn->var(Clos_Env_Param)->set("closure_env");
153 if (num_fvs == 1) {
154 subst.emplace(env, env_param);
155 } else {
156 for (size_t i = 0; i < num_fvs; i++) {
157 auto fv = env->op(i);
158 auto sym = w.sym("fv_"s + (fv->sym() ? fv->sym().str() : std::to_string(i)));
159 subst.emplace(fv, env_param->proj(i)->set(sym));
160 }
161 }
162
163 auto params = w.tuple(DefVec(old_fn->num_doms(), [&](auto i) { return new_lam->var(skip_env(i)); }));
164 subst.emplace(old_fn->var(), params);
165
166 auto filter = rewrite(new_fn->filter(), subst);
167 auto body = rewrite(new_fn->body(), subst);
168 new_fn->unset()->set({filter, body});
169}
170
/// Closure-converts @p def and memoizes the result in @p subst (globals and closure types are also cached in
/// glob_muts_). Special cases, tried before the generic structural rewrite:
/// - continuation types (Pi::isa_cn) become closure types via type_clos;
/// - continuation Lams are stubbed (make_stub) and packed with their rewritten environment via clos_pack;
/// - `attr::returning`, `attr::fstclassBB` and `attr::freeBB` annotations on Lams;
/// - the return var of an enclosing Lam (see HACK below).
/// Everything else is rebuilt from its rewritten type and ops; an App whose rewritten callee has a Sigma type is
/// turned into clos_apply.
const Def* ClosConv::rewrite(const Def* def, Def2Def& subst) {
    switch (def->node()) {
        case Node::Type:
        case Node::Univ:
        case Node::Nat:
        case Node::Bot: // TODO This is used by the AD stuff????
        case Node::Top: return def;
        default: break;
    }

    auto& w = world();
    // Records new_def as the image of def in subst and returns it.
    auto map = [&](const Def* new_def) {
        subst[def] = new_def;
        assert(subst[def] == new_def);
        return new_def;
    };

    if (auto i = subst.find(def); i != subst.end()) {
        return i->second;
    } else if (auto pi = Pi::isa_cn(def)) {
        return map(type_clos(pi, subst));
    } else if (auto lam = def->isa_mut<Lam>(); lam && Lam::isa_cn(lam)) {
        // A continuation used as a value becomes a closure: (rewritten env, stub) packed at the closure type.
        auto [_, __, fv_env, new_lam] = make_stub(lam, subst);
        auto clos_ty = rewrite(lam->type(), subst);
        auto env = rewrite(fv_env, subst);
        auto closure = clos_pack(env, new_lam, clos_ty);
        DLOG("RW: pack {} ~> {} : {}", lam, closure, clos_ty);
        return map(closure);
    } else if (auto a = Axm::isa<attr>(def)) {
        switch (a.id()) {
            case attr::returning:
                if (auto ret_lam = a->arg()->isa_mut<Lam>()) {
                    // assert(ret_lam && ret_lam->is_basicblock());
                    // Note: This should be cont_lam's only occurrence after η-expansion, so it's okay to
                    // put into the local subst only
                    auto new_doms
                        = DefVec(ret_lam->num_doms(), [&](auto i) { return rewrite(ret_lam->dom(i), subst); });
                    auto new_lam = ret_lam->stub(w.cn(new_doms));
                    subst[ret_lam] = new_lam;
                    if (ret_lam->is_set()) {
                        new_lam->set_filter(rewrite(ret_lam->filter(), subst));
                        new_lam->set_body(rewrite(ret_lam->body(), subst));
                    }
                    // Returned without map(): only ret_lam is recorded in subst, not the annotation itself.
                    return new_lam;
                }
                break;
            case attr::fstclassBB:
            case attr::freeBB: {
                // Note: Same thing about η-conversion applies here
                auto bb_lam = a->arg()->isa_mut<Lam>();
                assert(bb_lam && Lam::isa_basicblock(bb_lam));
                // Stubbed with no free defs and packed with an empty environment; its body is rewritten right away
                // under the current subst instead of going through worklist_.
                auto [_, __, ___, new_lam] = make_stub({}, bb_lam, subst);
                subst[bb_lam] = clos_pack(w.tuple(), new_lam, rewrite(bb_lam->type(), subst));
                rewrite_body(new_lam, subst);
                return map(subst[bb_lam]);
            }
            default: break;
        }
    } else if (auto [var, lam] = ca_isa_var<Lam>(def); var && lam && lam->ret_var() == var) {
        // HACK to rewrite a retvar that is defined in an enclosing lambda
        // If we put external bb's into the env, this should never happen
        auto new_lam = make_stub(lam, subst).fn;
        auto new_idx = skip_env(Lit::as(var->index()));
        return map(new_lam->var(new_idx));
    }

    auto new_type = rewrite(def->type(), subst);

    if (auto mut = def->isa_mut()) {
        if (auto global = def->isa_mut<Global>()) {
            // Globals are rewritten once, under a fresh (shadowing) subst, and cached in glob_muts_.
            if (auto i = glob_muts_.find(global); i != glob_muts_.end()) return i->second;
            auto subst = Def2Def();
            return glob_muts_[mut] = rewrite_mut(global, new_type, subst);
        }
        assert(!isa_clos_type(mut));
        DLOG("RW: mut {}", mut);
        auto new_mut = rewrite_mut(mut, new_type, subst);
        // Try to reduce the amount of muts that are created
        // (keep the original if it is alpha-equivalent to the rewrite, else use an immutable form if one exists).
        if (!mut->isa_mut<Global>() && Checker::alpha<Checker::Check>(mut, new_mut)) return map(mut);
        if (auto imm = new_mut->immutabilize()) return map(imm);
        return map(new_mut);
    } else {
        auto new_ops = DefVec(def->num_ops(), [&](auto i) { return rewrite(def->op(i), subst); });
        // A callee that was rewritten to a Sigma-typed value is a closure, so it is applied via clos_apply.
        if (auto app = def->isa<App>(); app && new_ops[0]->type()->isa<Sigma>())
            return map(clos_apply(new_ops[0], new_ops[1]));
        else if (def->isa<Axm>())
            return def; // axioms are returned unchanged and not recorded in subst
        else
            return map(def->rebuild(new_type, new_ops));
    }

    // Both branches above return.
    fe::unreachable();
}
264
265Def* ClosConv::rewrite_mut(Def* mut, const Def* new_type, Def2Def& subst) {
266 auto new_mut = mut->stub(new_type);
267 subst.emplace(mut, new_mut);
268 for (size_t i = 0; i < mut->num_ops(); i++)
269 if (mut->op(i)) new_mut->set(i, rewrite(mut->op(i), subst));
270 return new_mut;
271}
272
273const Pi* ClosConv::rewrite_type_cn(const Pi* pi, Def2Def& subst) {
274 assert(Pi::isa_basicblock(pi));
275 auto new_ops = DefVec(pi->num_doms(), [&](auto i) { return rewrite(pi->dom(i), subst); });
276 return world().cn(new_ops);
277}
278
279const Def* ClosConv::type_clos(const Pi* pi, Def2Def& subst, const Def* env_type) {
280 if (auto i = glob_muts_.find(pi); i != glob_muts_.end() && !env_type) return i->second;
281 auto& w = world();
282 auto new_doms = DefVec(pi->num_doms(), [&](auto i) {
283 return (i == pi->num_doms() - 1 && Pi::isa_returning(pi)) ? rewrite_type_cn(pi->ret_pi(), subst)
284 : rewrite(pi->dom(i), subst);
285 });
286 auto ct = ctype(w, new_doms, env_type);
287 if (!env_type) {
288 glob_muts_.emplace(pi, ct);
289 DLOG("C-TYPE: pct {} ~~> {}", pi, ct);
290 } else {
291 DLOG("C-TYPE: ct {}, env = {} ~~> {}", pi, env_type, ct);
292 }
293 return ct;
294}
295
296ClosConv::Stub ClosConv::make_stub(const DefSet& fvs, Lam* old_lam, Def2Def& subst) {
297 auto& w = world();
298 auto env = w.tuple(DefVec(fvs.begin(), fvs.end()));
299 auto num_fvs = fvs.size();
300 auto env_type = rewrite(env->type(), subst);
301 auto new_fn_type = type_clos(old_lam->type(), subst, env_type)->as<Pi>();
302 auto new_lam = old_lam->stub(new_fn_type);
303 // TODO
304 // new_lam->set_debug_name((old_lam->is_external() || !old_lam->is_set()) ? "cc_" + old_lam->name() :
305 // old_lam->name());
306 if (!isa_workable(old_lam)) {
307 auto new_ext_type = w.cn(clos_remove_env(new_fn_type->dom()));
308 auto new_ext_lam = old_lam->stub(new_ext_type);
309 DLOG("wrap ext lam: {} -> stub: {}, ext: {}", old_lam, new_lam, new_ext_lam);
310 if (old_lam->is_set()) {
311 old_lam->transfer_external(new_ext_lam);
312 new_ext_lam->app(false, new_lam, clos_insert_env(env, new_ext_lam->var()));
313 new_lam->set(old_lam->filter(), old_lam->body());
314 } else {
315 new_ext_lam->unset();
316 new_lam->app(false, new_ext_lam, clos_remove_env(new_lam->var()));
317 }
318 } else {
319 new_lam->set(old_lam->filter(), old_lam->body());
320 }
321 DLOG("STUB {} ~~> ({}, {})", old_lam, env, new_lam);
322 auto closure = Stub{old_lam, num_fvs, env, new_lam};
323 closures_.emplace(old_lam, closure);
324 closures_.emplace(closure.fn, closure);
325 return closure;
326}
327
328ClosConv::Stub ClosConv::make_stub(Lam* old_lam, Def2Def& subst) {
329 if (auto i = closures_.find(old_lam); i != closures_.end()) return i->second;
330 auto fvs = fva_.run(old_lam);
331 auto closure = make_stub(fvs, old_lam, subst);
332 worklist_.emplace(closure.fn);
333 return closure;
334}
335
336} // namespace mim::plug::clos
static auto isa(const Def *def)
Definition axm.h:107
static bool alpha(const Def *d1, const Def *d2)
Definition check.h:96
Base class for all Defs.
Definition def.h:251
Defs deps() const noexcept
Definition def.cpp:464
constexpr auto ops() const noexcept
Definition def.h:305
T * isa_mut() const
If this is mutable, it will cast constness away and perform a dynamic_cast to T.
Definition def.h:482
const Def * var(nat_t a, nat_t i) noexcept
Definition def.h:429
bool has_dep() const
Definition def.h:357
bool is_closed() const
Has no free_vars()?
Definition def.cpp:415
A function.
Definition lam.h:111
static const Lam * isa_cn(const Def *d)
Definition lam.h:142
Lam * set_filter(Filter)
Set filter first.
Definition lam.cpp:28
static const Lam * isa_basicblock(const Def *d)
Definition lam.h:143
Lam * set_body(const Def *body)
Set body second.
Definition lam.h:172
Lam * stub(const Def *type)
Definition lam.h:185
static T as(const Def *def)
Definition def.h:816
Def * mut() const
The mutable encapsulated in this Node, or nullptr if it's a virtual root comprising several Nodes.
Definition nest.h:22
Builds a nesting tree of all immutables‍/binders.
Definition nest.h:11
const Node * root() const
Definition nest.h:129
bool contains(const Def *def) const
Definition nest.h:131
static const Pi * isa_cn(const Def *d)
Is this a continuation - i.e. is the Pi::codom mim::Bottom?
Definition lam.h:48
static const Pi * isa_basicblock(const Def *d)
Is this a continuation (Pi::isa_cn) that is not Pi::isa_returning?
Definition lam.h:52
World & world()
Definition pass.h:64
const Pi * cn()
Definition world.h:280
void start() override
Actual entry.
DefSet & run(Lam *lam)
FreeDefAna::run computes the free defs (FD) that appear in lam's body.
#define DLOG(...)
Vaporizes to nothingness in Release build (debug messages are only emitted in Debug builds).
Definition log.h:95
The clos Plugin
Definition clos.h:7
std::tuple< const Extract *, N * > ca_isa_var(const Def *def)
Checks if def is the variable of a mut of type N.
Definition clos.h:79
DefSet free_defs(const Nest &nest)
Definition clos_conv.cpp:22
const Def * clos_remove_env(size_t i, std::function< const Def *(size_t)> f)
Definition clos.cpp:138
const Def * clos_insert_env(size_t i, const Def *env, std::function< const Def *(size_t)> f)
Definition clos.cpp:134
const Def * ctype(World &w, Defs doms, const Def *env_type=nullptr)
Definition clos.cpp:140
static constexpr size_t Clos_Env_Param
Describes where the environment is placed in the argument list.
Definition clos.h:107
size_t skip_env(size_t i)
Definition clos.h:113
const Def * clos_pack(const Def *env, const Def *fn, const Def *ct=nullptr)
Pack a typed closure. This assumes that fn expects the environment as its Clos_Env_Param-th argument.
Definition clos.cpp:73
const Def * clos_apply(const Def *closure, const Def *args)
Apply a closure to arguments.
Definition clos.cpp:95
const Sigma * isa_clos_type(const Def *def)
Definition clos.cpp:106
DefMap< const Def * > Def2Def
Definition def.h:75
auto pop(S &s) -> decltype(s.top(), typename S::value_type())
Definition util.h:83
Vector< const Def * > DefVec
Definition def.h:77
@ Var
Definition def.h:126
Lam * isa_workable(Lam *lam)
These are Lams that are neither nullptr, nor Lam::is_external, nor Lam::is_unset.
Definition lam.h:349
GIDSet< const Def * > DefSet
Definition def.h:74
Node
Definition def.h:112
@ Nat
Definition def.h:114
@ Pi
Definition def.h:114
@ Univ
Definition def.h:114
@ Bot
Definition def.h:114
@ Lam
Definition def.h:114
@ Global
Definition def.h:114
@ Axm
Definition def.h:114
@ Sigma
Definition def.h:114
@ Extract
Definition def.h:114
@ Type
Definition def.h:114
@ Top
Definition def.h:114
@ App
Definition def.h:114
@ Tuple
Definition def.h:114