MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
lower_for.cpp
Go to the documentation of this file.
2
3#include <mim/lam.h>
4#include <mim/tuple.h>
5
6#include <mim/plug/mem/mem.h>
7
9
10namespace mim::plug::affine {
11
12namespace {
13
14const Def* merge_s(World& w, const Def* elem, const Def* sigma, const Def* mem) {
15 if (mem) {
16 auto elems = sigma->projs();
17 return merge_sigma(elem, elems);
18 }
19 return w.sigma({elem, sigma});
20}
21
22const Def* merge_t(World& w, const Def* elem, const Def* tuple, const Def* mem) {
23 if (mem) {
24 auto elems = tuple->projs();
25 return merge_tuple(elem, elems);
26 }
27 return w.tuple({elem, tuple});
28}
29
30const Def* eta_expand(World& w, const Def* f) {
31 auto eta = w.mut_con(Pi::isa_cn(f->type())->dom());
32 eta->app(false, f, eta->var());
33 return eta;
34}
35
36} // namespace
37
38const Def* LowerFor::rewrite(const Def* def) {
39 if (auto i = rewritten_.find(def); i != rewritten_.end()) return i->second;
40
41 if (auto for_ax = Axm::isa<affine::For>(def)) {
42 world().DLOG("rewriting for axm: {} within {}", for_ax, curr_mut());
43 auto [begin, end, step, init, body, exit] = for_ax->args<6>();
44
45 auto body_lam = body->isa_mut<Lam>();
46 auto exit_lam = exit->isa_mut<Lam>();
47 if (!body_lam) body = eta_expand(world(), body);
48 if (!exit_lam) exit = eta_expand(world(), exit);
49
50 auto mem = mem::mem_def(init);
51 auto head_lam = world().mut_con(merge_s(world(), begin->type(), init->type(), mem))->set("head");
52 auto phis = head_lam->vars();
53 auto iter = phis.front();
54 auto acc = world().tuple(phis.view().subspan(1));
55 mem = mem::mem_var(head_lam);
56 auto bb_dom = mem ? mem->type() : world().sigma();
57 auto new_body = world().mut_con(bb_dom)->set("new_body");
58 auto new_exit = world().mut_con(bb_dom)->set("new_exit");
59 auto new_yield = world().mut_con(init->type())->set("new_yield");
60 auto cmp = world().call(core::icmp::ul, Defs{iter, end});
61 auto new_iter = world().call(core::wrap::add, core::Mode::nusw, Defs{iter, step});
62
63 head_lam->branch(false, cmp, new_body, new_exit, mem);
64 new_yield->app(false, head_lam, merge_t(world(), new_iter, new_yield->var(), mem));
65 new_body->set(body->reduce(world().tuple({iter, acc, new_yield})));
66 new_exit->set(exit->reduce(acc));
67
68 return rewritten_[def] = world().app(head_lam, merge_t(world(), begin, init, mem));
69 }
70
71 return def;
72}
73
74} // namespace mim::plug::affine
static auto isa(const Def *def)
Definition axm.h:104
Base class for all Defs.
Definition def.h:203
Def * set(size_t i, const Def *)
Successively set from left to right.
Definition def.cpp:240
T * isa_mut() const
If this is mutable, it will cast constness away and perform a dynamic_cast to T.
Definition def.h:434
auto vars(F f) noexcept
Definition def.h:384
A function.
Definition lam.h:106
Lam * set(Filter filter, const Def *body)
Definition lam.h:165
World & world()
Definition pass.h:296
static const Pi * isa_cn(const Def *d)
Is this a continuation - i.e. is the Pi::codom mim::Bottom?
Definition lam.h:44
const Def * dom() const
Definition lam.h:32
Lam * curr_mut() const
Definition pass.h:232
const Type * type(const Def *level)
Definition world.cpp:106
const Def * sigma(Defs ops)
Definition world.cpp:272
const Def * app(const Def *callee, const Def *arg)
Definition world.cpp:190
const Def * tuple(Defs ops)
Definition world.cpp:282
const Def * var(const Def *type, Def *mut)
Definition world.cpp:174
const Def * call(Id id, Args &&... args)
Complete curried call of annexes obeying implicits.
Definition world.h:500
Lam * mut_con(const Def *dom)
Definition world.h:296
const Def * rewrite(const Def *) override
Definition lower_for.cpp:38
The affine Plugin
Definition lower_for.h:7
The mem Plugin
Definition mem.h:11
const Def * mem_var(Lam *lam)
Returns the memory argument of a function if it has one.
Definition mem.h:38
const Def * mem_def(const Def *def)
Returns the (first) element of type mem::M from the given tuple.
Definition mem.h:25
The tuple Plugin
View< const Def * > Defs
Definition def.h:49
const Def * merge_sigma(const Def *def, Defs defs)
Definition tuple.cpp:114
const Def * merge_tuple(const Def *def, Defs defs)
Definition tuple.cpp:119