MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
cps2ds.cpp
Go to the documentation of this file.
2
3#include <iostream>
4#include <type_traits>
5
6#include <mim/lam.h>
7
10
11namespace mim::plug::direct {
12
13void CPS2DS::enter() { rewrite_lam(curr_mut()); }
14
15void CPS2DS::rewrite_lam(Lam* lam) {
16 if (auto [_, ins] = rewritten_lams.emplace(lam); !ins) return;
17
18 if (lam->isa_imm() || !lam->is_set() || lam->codom()->isa<Type>()) {
19 world().DLOG("skipped {}", lam);
20 return;
21 }
22
23 world().DLOG("Rewrite lam: {}", lam->sym());
24
25 lam_stack.push_back(curr_lam_);
26 curr_lam_ = lam;
27
28 auto new_f = rewrite_body(curr_lam_->filter());
29 auto new_b = rewrite_body(curr_lam_->body());
30 // curr_lam_ might be different at this point (newly introduced continuation).
31 world().DLOG("Result of rewrite {} in {}", lam, curr_lam_);
32 // TODO This is odd: Why is this *sometimes* not set?
33 curr_lam_->unset()->set({new_f, new_b});
34 curr_lam_ = lam_stack.back();
35 lam_stack.pop_back();
36}
37
38const Def* CPS2DS::rewrite_body(const Def* def) {
39 if (!def) return nullptr;
40 if (auto i = rewritten_.find(def); i != rewritten_.end()) return i->second;
41 auto new_def = rewrite_body_(def);
42 rewritten_[def] = new_def;
43 return rewritten_[def];
44}
45
/// Uncached worker of rewrite_body: rewrites @p def bottom-up and turns every call through
/// `direct::cps2ds_dep` into a CPS call with an explicit continuation.
///
/// Side effects: may rewrite nested Lams (rewrite_lam), insert into rewritten_ and
/// rewritten_lams, reset the body of curr_lam_, and replace curr_lam_ by a new continuation.
///
/// Cases:
/// - App whose rewritten callee matches `direct::cps2ds_dep`: emits the CPS call as the body of
///   curr_lam_ and returns the new continuation's variable in place of the call's result.
/// - any other App: rebuilt from the rewritten callee and argument.
/// - mutable Lam: processed via rewrite_lam and returned itself.
/// - Var / Global: returned unchanged. Vars remapped in rewritten_ are already replaced by the
///   cache lookup in rewrite_body.
/// - Tuple: rebuilt from rewritten elements, keeping the original type and debug info.
/// - other mutables: stubbed with a rewritten type, ops rewritten, immutabilized if possible.
/// - Infer: returned unchanged, with a warning.
/// - everything else: rebuilt with rewritten ops and the original type.
46const Def* CPS2DS::rewrite_body_(const Def* def) {
47 if (auto app = def->isa<App>()) {
48 auto callee = app->callee();
49 auto arg = app->arg();
50 auto new_callee = rewrite_body(callee);
51 auto new_arg = rewrite_body(arg);
52
// A direct-style call of the form `cps2ds_dep(f') a` needs to be converted to CPS.
53 if (auto cps2ds = match<direct::cps2ds_dep>(new_callee)) {
54 world().DLOG("rewrite callee {} : {}", callee, callee->type());
55 world().DLOG("rewrite arg {} : {}", arg, arg->type());
56 // TODO: rewrite function?
// The argument of the cps2ds_dep application is the CPS function f' (see the sketch below).
57 auto cps_fun = cps2ds->arg();
58 cps_fun = rewrite_body(cps_fun);
59 world().DLOG("function: {} : {}", cps_fun, cps_fun->type());
60
61 // ```
62 // h:
63 // b = f a
64 // C[b]
65 // ```
66 // =>
67 // ```
68 // h:
69 // f'(a,h_cont)
70 //
71 // h_cont(b):
72 // C[b]
73 //
74 // f : A -> B
75 // f': .Cn [A, ret: .Cn[B]]
76 // ```
77
78 // TODO: rewrite map vs mim::rewrite
79 // TODO: unify replacements
80
81 // We instantiate the function type with the applied argument.
82 auto ty = callee->type();
83 auto ret_ty = ty->as<Pi>()->codom();
84 world().DLOG("callee {} : {}", callee, ty);
85 world().DLOG("new arguments {} : {}", new_arg, new_arg->type());
86 world().DLOG("ret_ty {}", ret_ty);
87
// For a mutable (dependent) Pi, the codomain may mention the Pi's variable. The rewritten
// argument is substituted for it. NOTE(review): this assumes `reduce` yields the Pi's ops with
// the variable substituted, the last of which is the codomain; confirm against Pi::reduce.
// For an immutable Pi the codomain is used as is.
88 auto inst_ret_ty = ret_ty;
89 if (auto pi = ty->isa_mut<Pi>()) inst_ret_ty = pi->reduce(new_arg).back();
90
91 // The continuation that receives the result of the cps function call.
92 auto new_name = world().append_suffix(curr_lam_->sym(), "_cps_cont");
93 auto fun_cont = world().mut_con(inst_ret_ty)->set(new_name);
// Mark the continuation as visited so rewrite_lam skips it. Its body is built below, because
// it becomes curr_lam_.
94 rewritten_lams.insert(fun_cont);
95
96 // Generate the cps function call `f a` -> `f_cps(a,cont)`
97 auto cps_call = world().app(cps_fun, {new_arg, fun_cont})->set("cps_call");
98 world().DLOG(" curr_lam {}", curr_lam_->sym());
// The CPS call becomes the body of the current Lam. A Lam that is already set keeps its
// filter. A continuation created by an earlier cps2ds call in the same body is still unset,
// so it gets `lit_ff()` as its filter.
99 if (curr_lam_->is_set()) {
100 auto filter = curr_lam_->filter();
101 curr_lam_->reset({filter, cps_call});
102 } else {
103 curr_lam_->set(world().lit_ff(), cps_call);
104 }
105
106 // Fixme: would be great to PE the newly added overhead away..
107 // The current PE just does not terminate on loops.. :/
108 // TODO: Set filter (inline call wrapper)
109 // curr_lam_->set_filter(true);
110
111 // fun_cont->set_filter(curr_lam_->filter());
112
113 // We write the body context in the newly created continuation that has access to the result
114 // (as its argument).
115 curr_lam_ = fun_cont;
116 // `res` is the result of the cps function.
117 auto res = fun_cont->var();
118
// `res` replaces `def`. rewrite_body caches it under `def`, so later uses of `def` also get `res`.
119 world().DLOG(" result {} : {} instead of {} : {}", res, res->type(), def, def->type());
120 return res;
121 }
122
// Not a cps2ds call: rebuild the application from its rewritten parts.
123 return world().app(new_callee, new_arg);
124 }
125
// Nested Lams are rewritten in place; the Lam itself remains the result.
126 if (auto lam = def->isa_mut<Lam>()) {
127 rewrite_lam(lam);
128 return lam;
129 }
130
131 if (def->isa<Var>()) return def;
132 if (def->isa<Global>()) return def;
133
134 if (auto tuple = def->isa<Tuple>()) {
135 auto elements = DefVec(tuple->ops(), [&](const Def* op) { return rewrite_body(op); });
136 return world().tuple(def->type(), elements)->set(tuple->dbg());
137 }
138
139 // TODO there are more problems like this:
140 // 1. we have to also rewrite the type (regardless of mut/imm)
141 // 2. muts may be recursive, so it's important to first build the stub and put into rewritten_ before recursing
142 if (auto old_mut = def->isa_mut()) {
143 auto new_type = rewrite_body(old_mut->type());
144 auto new_mut = old_mut->stub(new_type);
// Register the stub, and map the old variable to the new one, before the ops are rewritten.
// This way recursive references resolve to the stub through the rewritten_ cache.
145 rewritten_[old_mut] = new_mut;
146 if (auto var = old_mut->has_var()) rewritten_[var] = new_mut->var();
147 auto new_ops = DefVec(def->ops(), [&](const Def* op) { return rewrite_body(op); });
148 new_mut->set(new_ops);
149
// Prefer an immutable equivalent if one exists, and update the cache entry accordingly.
150 if (auto imm = new_mut->immutabilize()) return rewritten_[old_mut] = imm;
151 return new_mut;
152 }
153
154 auto new_ops = DefVec(def->ops(), [&](const Def* op) { return rewrite_body(op); });
155 world().DLOG("def {} : {} [{}]", def, def->type(), def->node_name());
156
// NOTE(review): the ops of an Infer have already been rewritten above, with all side effects,
// yet the Infer itself is returned unchanged.
157 if (def->isa<Infer>()) {
158 world().WLOG("infer node {} : {} [{}]", def, def->type(), def->node_name());
159 return def;
160 }
161
// Generic case: same node kind and (unrewritten) type, rewritten operands.
162 return def->rebuild(def->type(), new_ops);
163}
164
165} // namespace mim::plug::direct
bool is_set() const
Yields true if empty or the last op is set.
Definition def.cpp:311
Def * set(size_t i, const Def *def)
Successively set from left to right.
Definition def.cpp:248
Ref var(nat_t a, nat_t i)
Definition def.h:398
Sym sym() const
Definition def.h:465
const T * isa_imm() const
Definition def.h:438
Def * reset(size_t i, const Def *def)
Successively reset from left to right.
Definition def.h:288
A function.
Definition lam.h:96
Lam * unset()
Definition lam.h:169
Ref filter() const
Definition lam.h:106
Lam * set(Filter filter, const Def *body)
Definition lam.h:158
Ref codom() const
Definition lam.h:123
Ref body() const
Definition lam.h:107
World & world()
Definition pass.h:296
Lam * curr_mut() const
Definition pass.h:232
Ref tuple(Defs ops)
Definition world.cpp:239
Lam * mut_con(Ref dom)
Definition world.h:296
Ref app(Ref callee, Ref arg)
Definition world.cpp:187
Sym append_suffix(Sym name, std::string suffix)
Appends a suffix or an increasing number if the suffix already exists.
Definition world.cpp:534
const Type * type(Ref level)
Definition world.cpp:94
void enter() override
Invoked just before PassMan::curr_mut's body is rewritten (via Pass::rewrite).
Definition cps2ds.cpp:13
@ Pi
Definition def.h:39
The direct style Plugin
Definition direct.h:7
Vector< const Def * > DefVec
Definition def.h:61
auto match(Ref def)
Definition axiom.h:105