MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
lower_for.cpp
Go to the documentation of this file.
2
3#include <mim/lam.h>
4#include <mim/tuple.h>
5
6#include <mim/plug/mem/mem.h>
7
9
11
12namespace {
13
14const Def* merge_s(const Def* elem, const Def* sigma, const Def* mem) {
15 auto& w = elem->world();
16 if (mem) {
17 auto elems = sigma->projs();
18 return merge_sigma(elem, elems);
19 }
20 return w.sigma({elem, sigma});
21}
22
23const Def* merge_t(const Def* elem, const Def* tuple, const Def* mem) {
24 auto& w = elem->world();
25 if (mem) {
26 auto elems = tuple->projs();
27 return merge_tuple(elem, elems);
28 }
29 return w.tuple({elem, tuple});
30}
31
32} // namespace
33
34const Def* LowerFor::rewrite_imm_App(const App* app) {
35 if (auto for_ax = Axm::isa<affine::For>(app)) {
36 DLOG("rewriting for axm: `{}`", for_ax);
37 auto [old_body, old_exit, args] = for_ax->uncurry_args<3>();
38 auto [new_begin, new_end, new_step, new_init] = args->projs<4>([this](const Def* def) { return rewrite(def); });
39
40 auto old_body_lam = old_body->isa_mut<Lam>();
41 auto old_exit_lam = old_exit->isa_mut<Lam>();
42 if (!old_body_lam) old_body_lam = Lam::eta_expand(old_body);
43 if (!old_exit_lam) old_exit_lam = Lam::eta_expand(old_exit);
44
45 auto new_mem = mem::mem_def(new_init);
46 auto new_head_lam = new_world().mut_con(merge_s(new_begin->type(), new_init->type(), new_mem))->set("head");
47 auto new_phis = new_head_lam->vars();
48 auto new_iter = new_phis.front();
49 auto new_acc = new_world().tuple(new_phis.view().subspan(1));
50 new_mem = mem::mem_var(new_head_lam);
51 auto new_bb_dom = new_mem ? new_mem->type() : new_world().sigma();
52
53 auto new_body = new_world().mut_con(new_bb_dom)->set("new_body");
54 auto new_exit = new_world().mut_con(new_bb_dom)->set("new_exit");
55 auto new_yield = new_world().mut_con(new_init->type())->set("new_yield");
56 auto new_cmp = new_world().call(core::icmp::ul, Defs{new_iter, new_end});
57 auto new_inc = new_world().call(core::wrap::add, core::Mode::nusw, Defs{new_iter, new_step});
58
59 new_head_lam->branch(false, new_cmp, new_body, new_exit, new_mem);
60 new_yield->app(false, new_head_lam, merge_t(new_inc, new_yield->var(), new_mem));
61
62 push();
63 map(old_body_lam->var(), {new_iter, new_acc, new_yield});
64 new_body->set({rewrite(old_body_lam->filter()), rewrite(old_body_lam->body())});
65 pop();
66
67 push();
68 map(old_exit_lam->var(), new_acc);
69 new_exit->set({rewrite(old_exit_lam->filter()), rewrite(old_exit_lam->body())});
70 pop();
71
72 return new_world().app(new_head_lam, merge_t(new_begin, new_init, new_mem));
73 }
74
75 return Rewriter::rewrite_imm_App(app);
76}
77
78} // namespace mim::plug::affine::phase
static auto isa(const Def *def)
Definition axm.h:107
Base class for all Defs.
Definition def.h:251
T * isa_mut() const
If this is mutable, it will cast constness away and perform a dynamic_cast to T.
Definition def.h:485
const Def * var(nat_t a, nat_t i) noexcept
Definition def.h:429
auto vars(F f) noexcept
Definition def.h:429
A function.
Definition lam.h:111
const Def * filter() const
Definition lam.h:123
Lam * set(Filter filter, const Def *body)
Definition lam.cpp:29
static Lam * eta_expand(Filter, const Def *f)
Definition lam.cpp:51
const Def * body() const
Definition lam.h:124
World & new_world()
Create new Defs into this.
Definition phase.h:100
virtual void push()
Definition rewrite.h:40
const Def * map(const Def *old_def, const Def *new_def)
Map old_def to new_def and returns new_def.
Definition rewrite.h:46
virtual void pop()
Definition rewrite.h:41
virtual const Def * rewrite(const Def *)
Definition rewrite.cpp:14
const Def * sigma(Defs ops)
Definition world.cpp:276
const Def * app(const Def *callee, const Def *arg)
Definition world.cpp:194
const Def * tuple(Defs ops)
Definition world.cpp:286
const Def * call(Id id, Args &&... args)
Complete curried call of annexes obeying implicits.
Definition world.h:547
Lam * mut_con(const Def *dom)
Definition world.h:310
const Def * rewrite_imm_App(const App *) final
Definition lower_for.cpp:34
#define DLOG(...)
Vaporizes to nothingness in Debug build.
Definition log.h:95
const Def * mem_var(Lam *lam)
Returns the memory argument of a function if it has one.
Definition mem.h:38
const Def * mem_def(const Def *def)
Returns the (first) element of type mem::M from the given tuple.
Definition mem.h:25
View< const Def * > Defs
Definition def.h:76
const Def * merge_sigma(const Def *def, Defs defs)
Definition tuple.cpp:136
const Def * merge_tuple(const Def *def, Defs defs)
Definition tuple.cpp:141