MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
cps2ds.cpp
Go to the documentation of this file.
2
3#include <ranges>
4
5#include "mim/def.h"
6#include "mim/rewrite.h"
7#include "mim/schedule.h"
8#include "mim/world.h"
9
11
12#define DEBUG_CPS2DS 0
13
14namespace mim::plug::direct {
15
// Entry point of the phase: resets all per-run state, builds one Nest per relevant mutable Lam and rewrites
// every external mutable Lam.
// NOTE(review): this listing lost the function's signature line (source line 16; presumably
// `void CPS2DSPhase::start() {`, cf. `void start() final` in the index at the end of the page).
17#if DEBUG_CPS2DS
18 world().debug_dump();
19#endif
20
// Reset all per-run state so the phase starts from scratch.
21 scheduler_.clear();
22 nests_.clear();
23 lam2lam_.clear();
24 rewritten_.clear();
25
// One Nest per mutable Lam whose codomain is not a Type (type-level functions are left out).
// scheduler() searches these Nests to find the one that contains a given def.
26 world().for_each(true, [this](Def* mut) {
27 if (auto lam = mut->isa_mut<Lam>(); lam && !lam->codom()->isa<Type>()) nests_.insert({lam, Nest(lam)});
28 });
29
// Rewrite every external mutable Lam. current_external_ records the external being processed;
// scheduler() falls back to its Nest when no other Nest contains a def.
30 for (auto external : world().externals().muts())
31 if (auto lam = external->isa_mut<Lam>()) {
32 current_external_ = lam;
33 rewrite_lam(lam);
34 }
35
36#if DEBUG_CPS2DS
37 world().debug_dump();
38#endif
39}
40
41const Def* CPS2DSPhase::rewrite_lam(Lam* lam) {
42 if (auto i = rewritten_.find(lam); i != rewritten_.end()) return i->second;
43 if (lam2lam_.contains(lam)) return lam;
44 if (lam->isa_imm() || !lam->is_set() || lam->codom()->isa<Type>()) {
45 world().DLOG("skipped {}", lam);
46 return lam;
47 }
48
49 lam2lam_[lam] = lam;
50
51 world().DLOG("Rewriting lam: {}", lam->unique_name());
52
53 auto filter = rewrite(lam->filter());
54
55 if (auto body = lam->body()->isa<App>(); !body) {
56 world().DLOG(" non-app body {}, skipped", lam->body());
57 auto new_body = rewrite(lam->body());
58 lam->unset()->set(filter, new_body);
59 return rewritten_[lam] = lam;
60 }
61
62 auto body = lam->body()->as<App>();
63 auto new_arg = rewrite(body->arg());
64
65 auto new_callee = rewrite(body->callee());
66 auto new_lam = result_lam(lam);
67#if DEBUG_CPS2DS
68 world().DLOG("Result of rewrite {} set for {}", lam->unique_name(), new_lam->unique_name());
69#endif
70
71 if (world().log().level() >= mim::Log::Level::Debug) body->dump(1);
72
73 new_lam->unset()->app(filter, new_callee, new_arg);
74
75 return rewritten_[lam] = lam;
76}
77
/// Rewrites @p def and memoizes the result in rewritten_.
/// - Mutable Lams are handed to rewrite_lam().
/// - An App whose callee is a `direct::cps2ds_dep` axiom application is lowered: in the Lam chosen via the
///   scheduler, the (rewritten) CPS function is called with the rewritten argument and a fresh continuation,
///   and the App itself is replaced by that continuation's variable.
/// - Every other def is rebuilt from its rewritten ops; its type is passed through unrewritten.
78const Def* CPS2DSPhase::rewrite(const Def* def) {
79 if (auto i = rewritten_.find(def); i != rewritten_.end()) return i->second;
80
81 if (auto lam = def->isa_mut<Lam>()) return rewrite_lam(lam);
82
83 if (auto app = def->isa<App>()) {
84 if (auto cps2ds = Axm::isa<direct::cps2ds_dep>(app->callee())) {
// Rewrite the CPS function and the call argument before placing this call: cps2ds calls nested in them are
// lowered first, so result_lam() below already follows the redirects they install.
85 auto cps_lam = rewrite(cps2ds->arg())->as<Lam>();
86
87 auto call_arg = rewrite(app->arg());
88
89 if (world().log().level() >= mim::Log::Level::Debug) {
90 cps2ds->dump(2);
91 cps2ds->arg()->dump(2);
92 }
93
// Place the call at the least common ancestor (Nest::lca) of the scheduler's early and late positions of app.
94 auto early = scheduler(app).early(app);
95 auto late = scheduler(app).late(current_external_, app);
96 auto node = Nest::lca(early, late);
97#if DEBUG_CPS2DS
98 world().DLOG("scheduling {} between {} (level {}) and {} (level {}) at {}", app,
99 early->mut() ? early->mut()->unique_name() : "root", early->level(),
100 late->mut() ? late->mut()->unique_name() : "root", late->level(),
101 node->mut() ? node->mut()->unique_name() : "root");
102#endif
// NOTE(review): the debug log above treats node->mut() == nullptr as "root", but the next line dereferences it
// unconditionally — confirm the LCA can never be the root here.
// result_lam() follows earlier redirects so the call is appended to the Lam that currently holds the code.
103 auto lam = result_lam(node->mut()->as_mut<Lam>());
104
105#if DEBUG_CPS2DS
106 world().DLOG("current lam: {} : {}", lam->unique_name(), lam->type());
107#endif
108
// Continuation receiving the CPS function's result (domain: cps_lam->ret_dom()).
// make_continuation() also records rewritten_[app] = cont->var().
109 auto cn_dom = cps_lam->ret_dom();
110 auto cont = make_continuation(cn_dom, app, cps_lam->sym());
111#if DEBUG_CPS2DS
112 world().DLOG("continuation created: {} : {}", cont, cont->type());
113 if (world().log().level() >= mim::Log::Level::Debug) cont->dump(2);
114#endif
115 {
// End `lam` with the call `cps_lam(call_arg, cont)` and redirect `lam` to `cont`: code that follows in `lam`
// (resolved via result_lam()) is emitted into the continuation from now on.
116 auto filter = rewritten_[lam->filter()] = rewrite(lam->filter());
117 auto body = world().app(cps_lam, world().tuple({call_arg, cont}));
118 rewritten_[lam] = lam->unset()->set(filter, body);
119 lam2lam_[lam] = cont;
120 }
121
122#if DEBUG_CPS2DS
123 world().DLOG("point the lam to the cont: {} = {}", lam->unique_name(), lam->body());
124 if (world().log().level() >= mim::Log::Level::Debug) {
125 lam->dump(2);
126 cont->dump(2);
127 }
128#endif
129
130 return cont->var(); // rewritten_[def] = cont->var(); done in make_continuation
131 }
132 }
133
// Generic case: rebuild with rewritten operands (the type is kept as is) and memoize the result.
134 DefVec new_ops{def->ops(), [this](const Def* d) { return rewrite(d); }};
135 auto new_def = def->rebuild(def->type(), new_ops);
136 rewritten_[def] = new_def;
137 return new_def;
138}
139
140Lam* CPS2DSPhase::make_continuation(const Def* cn_type, const Def* arg, Sym prefix) {
141#if DEBUG_CPS2DS
142 world().DLOG("make_continuation {} : {} ({})", prefix, cn_type, arg);
143 if (world().log().level() >= mim::Log::Level::Debug) arg->dump(2);
144#endif
145 auto name = world().append_suffix(prefix, "_cps2ds_cont");
146 auto cont = world().mut_con(cn_type)->set(name)->set_filter(false);
147
148 rewritten_[arg] = cont->var();
149
150 return cont;
151}
152
153Lam* CPS2DSPhase::result_lam(Lam* lam) {
154 if (auto i = lam2lam_.find(lam); i != lam2lam_.end())
155 if (i->second != lam) return result_lam(i->second);
156 return lam;
157}
158
159Scheduler& CPS2DSPhase::scheduler(const Def* def) {
160 auto get_or_make = [&](const Def* lam, const Nest& nest) -> Scheduler& {
161 if (auto sched = scheduler_.find(lam); sched != scheduler_.end()) {
162#if DEBUG_CPS2DS
163 world().DLOG("found existing scheduler for {}", lam);
164#endif
165 return sched->second;
166 } else {
167#if DEBUG_CPS2DS
168 world().DLOG("creating new scheduler for {}", lam);
169#endif
170 auto [it, inserted] = scheduler_.insert({lam, Scheduler(nest)});
171 return it->second;
172 }
173 };
174 for (auto& [lam, nest] : nests_) {
175#if DEBUG_CPS2DS
176 world().DLOG("looking for scheduler in {} for {}", lam, def);
177#endif
178 if (nest.contains(def)) return get_or_make(lam, nest);
179 }
180#if DEBUG_CPS2DS
181 world().DLOG("no scheduler found for {}, using current external {}", def, current_external_);
182#endif
183 return get_or_make(current_external_, curr_external_nest());
184}
185
186const Nest& CPS2DSPhase::curr_external_nest() const {
187 auto i = nests_.find(current_external_);
188 assert(i != nests_.end());
189 return i->second;
190}
191
192} // namespace mim::plug::direct
static auto isa(const Def *def)
Definition axm.h:107
Base class for all Defs.
Definition def.h:251
bool is_set() const
Yields true if empty or the last op is set.
Definition def.cpp:297
void dump() const
Definition dump.cpp:452
T * isa_mut() const
If this is mutable, it will cast constness away and perform a dynamic_cast to T.
Definition def.h:486
const Def * var(nat_t a, nat_t i) noexcept
Definition def.h:429
std::string unique_name() const
name + "_" + Def::gid
Definition def.cpp:578
const T * isa_imm() const
Definition def.h:480
A function.
Definition lam.h:110
Lam * unset()
Definition lam.h:179
const Def * filter() const
Definition lam.h:122
Lam * set(Filter filter, const Def *body)
Definition lam.cpp:29
Lam * set_filter(Filter)
Set filter first.
Definition lam.cpp:28
const Pi * type() const
Definition lam.h:130
const Def * body() const
Definition lam.h:123
const Def * codom() const
Definition lam.h:132
Builds a nesting tree of all immutables/binders.
Definition nest.h:11
static const Node * lca(const Node *n, const Node *m)
Least common ancestor of n and m.
Definition nest.cpp:78
World & world()
Definition pass.h:64
std::string_view name() const
Definition pass.h:67
Log & log() const
Definition pass.h:66
const Def * insert(const Def *d, const Def *i, const Def *val)
Definition world.cpp:432
const Def * app(const Def *callee, const Def *arg)
Definition world.cpp:201
void for_each(bool elide_empty, std::function< void(Def *)>)
Definition world.cpp:679
void debug_dump()
Dump in Debug build if World::log::level is Log::Level::Debug.
Definition dump.cpp:493
Sym append_suffix(Sym name, std::string suffix)
Appends a suffix or an increasing number if the suffix already exists.
Definition world.cpp:643
const Def * var(Def *mut)
Definition world.cpp:180
Lam * mut_con(const Def *dom)
Definition world.h:334
void dump(std::ostream &os)
Dump to os.
Definition dump.cpp:469
void start() final
Actual entry.
Definition cps2ds.cpp:16
The direct style Plugin
Definition direct.h:8
Vector< const Def * > DefVec
Definition def.h:77
@ Lam
Definition def.h:114
@ App
Definition def.h:114