MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
ds2cps.cpp
Go to the documentation of this file.
2
3#include <iostream>
4
5#include <mim/lam.h>
6
8
9namespace mim::plug::direct {
10
12 if (auto app = def->isa<App>()) {
13 if (auto lam = app->callee()->isa_mut<Lam>()) {
14 world().DLOG("encountered lam app");
15 auto new_lam = rewrite_lam(lam);
16 world().DLOG("new lam: {} : {}", new_lam, new_lam->type());
17 world().DLOG("arg: {} : {}", app->arg(), app->arg()->type());
18 auto new_app = world().app(new_lam, app->arg());
19 world().DLOG("new app: {} : {}", new_app, new_app->type());
20 return new_app;
21 }
22 }
23 return def;
24}
25
26/// This function generates the cps function `f_cps : cn [a:A, cn B]` for a ds function `f: [a : A] -> B`.
27/// The translation is associated in the `rewritten_` map.
28Ref DS2CPS::rewrite_lam(Lam* lam) {
29 if (auto i = rewritten_.find(lam); i != rewritten_.end()) return i->second;
30
31 // only look at lambdas (ds not cps)
32 if (Lam::isa_cn(lam)) return lam;
33 // ignore ds on type level
34 if (lam->type()->codom()->isa<Type>()) return lam;
35 // ignore higher order function
36 if (lam->type()->codom()->isa<Pi>()) {
37 // We can not set the filter here as this causes segfaults.
38 return lam;
39 }
40
41 world().DLOG("rewrite DS function {} : {}", lam, lam->type());
42
43 auto ty = lam->type();
44 auto var = ty->has_var();
45 auto dom = ty->dom();
46 auto codom = ty->codom();
47 auto sigma = world().mut_sigma(2);
48 // replace ds dom var with cps sigma var (cps dom)
49 auto rw_codom = var ? VarRewriter(var, sigma->var(2, 0)).rewrite(codom) : codom;
50 sigma->set(0, dom);
51 sigma->set(1, world().cn(rw_codom));
52
53 world().DLOG("original codom: {}", codom);
54 world().DLOG("rewritten codom: {}", rw_codom);
55
56 auto cps_lam = world().mut_con(sigma)->set(lam->sym().str() + "_cps");
57
58 // rewrite vars of new function
59 // calls handled separately
60 world().DLOG("body: {} : {}", lam->body(), lam->body()->type());
61
62 auto new_ops = lam->reduce(cps_lam->var(0_n));
63 auto filter = new_ops[0];
64 auto cps_body = new_ops[1];
65
66 world().DLOG("cps body: {} : {}", cps_body, cps_body->type());
67
68 cps_lam->app(filter, cps_lam->vars().back(), cps_body);
69
70 rewritten_[lam] = op_cps2ds_dep(cps_lam);
71 world().DLOG("replace {} : {}", lam, lam->type());
72 world().DLOG("with {} : {}", rewritten_[lam], rewritten_[lam]->type());
73
74 return rewritten_[lam];
75}
76
77} // namespace mim::plug::direct
Ref type() const
Definition def.h:251
Def * set(size_t i, Ref)
Successively set from left to right.
Definition def.cpp:256
DefVec reduce(Ref arg) const
Rewrites Def::ops by substituting this mutable's Var with arg.
Definition def.cpp:204
Sym sym() const
Definition def.h:466
const Var * has_var()
Only returns not nullptr, if Var of this mutable has ever been created.
Definition def.h:399
A function.
Definition lam.h:104
Lam * set(Filter filter, Ref body)
Definition lam.h:161
const Pi * type() const
Definition lam.h:116
Ref body() const
Definition lam.h:115
static const Lam * isa_cn(Ref d)
Definition lam.h:133
World & world()
Definition pass.h:296
A dependent function type.
Definition lam.h:11
Ref codom() const
Definition lam.h:33
Helper class to retrieve Infer::arg if present.
Definition def.h:86
virtual Ref rewrite(Ref)
Definition rewrite.cpp:9
Lam * mut_con(Ref dom)
Definition world.h:296
Sigma * mut_sigma(Ref type, size_t size)
Definition world.h:321
Ref app(Ref callee, Ref arg)
Definition world.cpp:170
Ref rewrite(Ref) override
Definition ds2cps.cpp:11
The direct style Plugin
Definition direct.h:7
Ref op_cps2ds_dep(Ref k)
Definition direct.h:15