MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
cps2ds.cpp
Go to the documentation of this file.
2
3#include <ranges>
4
5#include "mim/def.h"
6#include "mim/rewrite.h"
7#include "mim/schedule.h"
8#include "mim/world.h"
9
11
12#define DEBUG_CPS2DS 0
13
14namespace mim::plug::direct {
15
// Body of CPS2DSPhase::start(), the phase's entry point. It resets all per-run state,
// builds one Nest per mutable non-type-level Lam, and then rewrites every external Lam.
17#if DEBUG_CPS2DS
18 world().debug_dump();
19#endif
20
// Reset all per-run caches before rewriting.
21 scheduler_.clear();
22 nests_.clear();
23 lam2lam_.clear();
24 rewritten_.clear();
25
// Build a Nest for every mutable Lam whose codomain is not a Type.
// scheduler() searches these Nests to decide where a cps2ds call gets placed.
26 world().for_each(true, [this](Def* mut) {
27 if (auto lam = mut->isa_mut<Lam>(); lam && !lam->codom()->isa<Type>())
28 nests_.try_emplace(lam, std::make_unique<Nest>(lam));
29 });
30
// Rewrite each external Lam. current_external_ is read by rewrite() as the bound for
// late placement, and by scheduler() as the fallback scope when no Nest contains a Def.
31 for (auto external : world().externals().muts())
32 if (auto lam = external->isa_mut<Lam>()) {
33 current_external_ = lam;
34 rewrite_lam(lam);
35 }
36
37#if DEBUG_CPS2DS
38 world().debug_dump();
39#endif
40}
41
42const Def* CPS2DSPhase::rewrite_lam(Lam* lam) {
43 if (auto i = rewritten_.find(lam); i != rewritten_.end()) return i->second;
44 if (lam2lam_.contains(lam)) return lam;
45 if (lam->isa_imm() || !lam->is_set() || lam->codom()->isa<Type>()) {
46 world().DLOG("skipped {}", lam);
47 return lam;
48 }
49
50 lam2lam_[lam] = lam;
51
52 world().DLOG("Rewriting lam: {}", lam->unique_name());
53
54 auto filter = rewrite(lam->filter());
55
56 if (auto body = lam->body()->isa<App>(); !body) {
57 world().DLOG(" non-app body {}, skipped", lam->body());
58 auto new_body = rewrite(lam->body());
59 lam->unset()->set(filter, new_body);
60 return rewritten_[lam] = lam;
61 }
62
63 auto body = lam->body()->as<App>();
64 auto new_arg = rewrite(body->arg());
65
66 auto new_callee = rewrite(body->callee());
67 auto new_lam = result_lam(lam);
68#if DEBUG_CPS2DS
69 world().DLOG("Result of rewrite {} set for {}", lam->unique_name(), new_lam->unique_name());
70#endif
71
72 if (world().log().level() >= mim::Log::Level::Debug) body->dump(1);
73
74 new_lam->unset()->app(filter, new_callee, new_arg);
75
76 return rewritten_[lam] = lam;
77}
78
79const Def* CPS2DSPhase::rewrite(const Def* def) {
80 if (auto i = rewritten_.find(def); i != rewritten_.end()) return i->second;
81
82 if (auto lam = def->isa_mut<Lam>()) return rewrite_lam(lam);
83
84 if (auto app = def->isa<App>()) {
85 if (auto cps2ds = Axm::isa<direct::cps2ds_dep>(app->callee())) {
86 auto cps_lam = rewrite(cps2ds->arg())->as<Lam>();
87
88 auto call_arg = rewrite(app->arg());
89
90 if (world().log().level() >= mim::Log::Level::Debug) {
91 cps2ds->dump(2);
92 cps2ds->arg()->dump(2);
93 }
94
95 auto early = scheduler(app).early(app);
96 auto late = scheduler(app).late(current_external_, app);
97 auto node = Nest::lca(early, late);
98#if DEBUG_CPS2DS
99 world().DLOG("scheduling {} between {} (level {}) and {} (level {}) at {}", app,
100 early->mut() ? early->mut()->unique_name() : "root", early->level(),
101 late->mut() ? late->mut()->unique_name() : "root", late->level(),
102 node->mut() ? node->mut()->unique_name() : "root");
103#endif
104 auto lam = result_lam(node->mut()->as_mut<Lam>());
105
106#if DEBUG_CPS2DS
107 world().DLOG("current lam: {} : {}", lam->unique_name(), lam->type());
108#endif
109
110 auto cn_dom = cps_lam->ret_dom();
111 auto cont = make_continuation(cn_dom, app, cps_lam->sym());
112#if DEBUG_CPS2DS
113 world().DLOG("continuation created: {} : {}", cont, cont->type());
114 if (world().log().level() >= mim::Log::Level::Debug) cont->dump(2);
115#endif
116 {
117 auto filter = rewritten_[lam->filter()] = rewrite(lam->filter());
118 auto body = world().app(cps_lam, world().tuple({call_arg, cont}));
119 rewritten_[lam] = lam->unset()->set(filter, body);
120 lam2lam_[lam] = cont;
121 }
122
123#if DEBUG_CPS2DS
124 world().DLOG("point the lam to the cont: {} = {}", lam->unique_name(), lam->body());
125 if (world().log().level() >= mim::Log::Level::Debug) {
126 lam->dump(2);
127 cont->dump(2);
128 }
129#endif
130
131 return cont->var(); // rewritten_[def] = cont->var(); done in make_continuation
132 }
133 }
134
135 DefVec new_ops{def->ops(), [this](const Def* d) { return rewrite(d); }};
136 auto new_def = def->rebuild(def->type(), new_ops);
137 rewritten_[def] = new_def;
138 return new_def;
139}
140
141Lam* CPS2DSPhase::make_continuation(const Def* cn_type, const Def* arg, Sym prefix) {
142#if DEBUG_CPS2DS
143 world().DLOG("make_continuation {} : {} ({})", prefix, cn_type, arg);
144 if (world().log().level() >= mim::Log::Level::Debug) arg->dump(2);
145#endif
146 auto name = world().append_suffix(prefix, "_cps2ds_cont");
147 auto cont = world().mut_con(cn_type)->set(name)->set_filter(false);
148
149 rewritten_[arg] = cont->var();
150
151 return cont;
152}
153
154Lam* CPS2DSPhase::result_lam(Lam* lam) {
155 if (auto i = lam2lam_.find(lam); i != lam2lam_.end())
156 if (i->second != lam) return result_lam(i->second);
157 return lam;
158}
159
160Scheduler& CPS2DSPhase::scheduler(const Def* def) {
161 auto get_or_make = [&](const Def* lam, const Nest& nest) -> Scheduler& {
162 if (auto sched = scheduler_.find(lam); sched != scheduler_.end()) {
163#if DEBUG_CPS2DS
164 world().DLOG("found existing scheduler for {}", lam);
165#endif
166 return sched->second;
167 } else {
168#if DEBUG_CPS2DS
169 world().DLOG("creating new scheduler for {}", lam);
170#endif
171 auto [it, inserted] = scheduler_.insert({lam, Scheduler(nest)});
172 return it->second;
173 }
174 };
175 for (const auto& [lam, nest] : nests_) {
176#if DEBUG_CPS2DS
177 world().DLOG("looking for scheduler in {} for {}", lam, def);
178#endif
179 if (nest->contains(def)) return get_or_make(lam, *nest);
180 }
181#if DEBUG_CPS2DS
182 world().DLOG("no scheduler found for {}, using current external {}", def, current_external_);
183#endif
184 return get_or_make(current_external_, curr_external_nest());
185}
186
187const Nest& CPS2DSPhase::curr_external_nest() const {
188 auto i = nests_.find(current_external_);
189 assert(i != nests_.end());
190 return *i->second;
191}
192
193} // namespace mim::plug::direct
static auto isa(const Def *def)
Definition axm.h:107
Base class for all Defs.
Definition def.h:251
bool is_set() const
Yields true if empty or the last op is set.
Definition def.cpp:298
void dump() const
Definition dump.cpp:452
T * isa_mut() const
If this is mutable, it will cast constness away and perform a dynamic_cast to T.
Definition def.h:486
const Def * var(nat_t a, nat_t i) noexcept
Definition def.h:429
std::string unique_name() const
name + "_" + Def::gid
Definition def.cpp:579
const T * isa_imm() const
Definition def.h:480
A function.
Definition lam.h:110
Lam * unset()
Definition lam.h:179
const Def * filter() const
Definition lam.h:122
Lam * set(Filter filter, const Def *body)
Definition lam.cpp:29
Lam * set_filter(Filter)
Set filter first.
Definition lam.cpp:28
const Pi * type() const
Definition lam.h:130
const Def * body() const
Definition lam.h:123
const Def * codom() const
Definition lam.h:132
static const Node * lca(const Node *n, const Node *m)
Least common ancestor of n and m.
Definition nest.cpp:105
World & world()
Definition pass.h:64
std::string_view name() const
Definition pass.h:67
Log & log() const
Definition pass.h:66
const Def * insert(const Def *d, const Def *i, const Def *val)
Definition world.cpp:432
const Def * app(const Def *callee, const Def *arg)
Definition world.cpp:201
void for_each(bool elide_empty, std::function< void(Def *)>)
Definition world.cpp:679
void debug_dump()
Dumps the world (in Debug builds only) if `World::log().level()` is `Log::Level::Debug`.
Definition dump.cpp:493
Sym append_suffix(Sym name, std::string suffix)
Appends a suffix or an increasing number if the suffix already exists.
Definition world.cpp:643
const Def * var(Def *mut)
Definition world.cpp:180
Lam * mut_con(const Def *dom)
Definition world.h:338
void dump(std::ostream &os)
Dump to os.
Definition dump.cpp:469
void start() final
Actual entry point of the phase.
Definition cps2ds.cpp:16
The direct-style plugin.
Definition direct.h:8
Vector< const Def * > DefVec
Definition def.h:77
@ Lam
Definition def.h:114
@ App
Definition def.h:114