MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
scalarize.cpp
Go to the documentation of this file.
2
3#include "mim/tuple.h"
4
5#include "mim/pass/eta_exp.h"
6
7namespace mim {
8
11 eta_exp_ = man->find<EtaExp>();
12}
13
14// TODO should also work for mutable non-dependent sigmas
15
16// TODO merge with make_scalar
17bool Scalarize::should_expand(Lam* lam) {
18 if (!isa_workable(lam)) return false;
19 if (auto i = tup2sca_.find(lam); i != tup2sca_.end() && i->second && i->second == lam) return false;
20
21 auto pi = lam->type();
22 if (lam->num_doms() > 1 && Pi::isa_cn(pi) && pi->isa_imm()) return true; // no ugly dependent pis
23
24 tup2sca_[lam] = lam;
25 return false;
26}
27
28Lam* Scalarize::make_scalar(const Def* def) {
29 auto tup_lam = def->isa_mut<Lam>();
30 assert(tup_lam);
31 if (auto i = tup2sca_.find(tup_lam); i != tup2sca_.end()) return i->second;
32
33 auto types = DefVec();
34 auto arg_sz = Vector<size_t>();
35 bool todo = false;
36 for (size_t i = 0, e = tup_lam->num_doms(); i != e; ++i) {
37 auto n = flatten(types, tup_lam->dom(i), false);
38 arg_sz.push_back(n);
39 todo |= n != 1 || types.back() != tup_lam->dom(i);
40 }
41
42 if (!todo) return tup2sca_[tup_lam] = tup_lam;
43
44 auto cn = world().cn(types);
45 auto sca_lam = tup_lam->stub(cn);
46 if (eta_exp_) eta_exp_->new2old(sca_lam, tup_lam);
47 size_t n = 0;
48 DLOG("type {} ~> {}", tup_lam->type(), cn);
49 auto new_vars = world().tuple(DefVec(tup_lam->num_doms(), [&](auto i) {
50 auto tuple = DefVec(arg_sz.at(i), [&](auto) { return sca_lam->var(n++); });
51 return unflatten(tuple, tup_lam->dom(i), false);
52 }));
53 sca_lam->set(tup_lam->reduce(new_vars));
54 tup2sca_[sca_lam] = sca_lam;
55 tup2sca_.emplace(tup_lam, sca_lam);
56 DLOG("lambda {} : {} ~> {} : {}", tup_lam, tup_lam->type(), sca_lam, sca_lam->type());
57 return sca_lam;
58}
59
60const Def* Scalarize::rewrite(const Def* def) {
61 auto& w = world();
62 if (auto app = def->isa<App>()) {
63 const Def* sca_callee = app->callee();
64
65 if (auto tup_lam = sca_callee->isa_mut<Lam>(); should_expand(tup_lam)) {
66 sca_callee = make_scalar(tup_lam);
67
68 } else if (auto proj = sca_callee->isa<Extract>()) {
69 auto tuple = proj->tuple()->isa<Tuple>();
70 if (tuple && std::all_of(tuple->ops().begin(), tuple->ops().end(), [&](const Def* op) {
71 return should_expand(op->isa_mut<Lam>());
72 })) {
73 auto new_tuple = w.tuple(DefVec(tuple->num_ops(), [&](auto i) { return make_scalar(tuple->op(i)); }));
74 sca_callee = w.extract(new_tuple, proj->index());
75 DLOG("Expand tuple: {, } ~> {, }", tuple->ops(), new_tuple->ops());
76 }
77 }
78
79 if (sca_callee != app->callee()) {
80 auto new_args = DefVec();
81 flatten(new_args, app->arg(), false);
82 return world().app(sca_callee, new_args);
83 }
84 }
85 return def;
86}
87
88} // namespace mim
Base class for all Defs.
Definition def.h:251
Def * set(size_t i, const Def *)
Successively set from left to right.
Definition def.cpp:266
T * isa_mut() const
If this is mutable, it will cast constness away and perform a dynamic_cast to T.
Definition def.h:482
Performs η-expansion: f -> λx.f x, if f is a Lam with more than one user and does not appear in calle...
Definition eta_exp.h:13
Extracts from a Sigma or Array-typed Extract::tuple the element at position Extract::index.
Definition tuple.h:206
A function.
Definition lam.h:111
const Pi * type() const
Definition lam.h:131
An optimizer that combines several optimizations in an optimal way.
Definition pass.h:172
virtual void init(PassMan *)
Definition pass.cpp:30
PassMan & man()
Definition pass.h:97
static const Pi * isa_cn(const Def *d)
Is this a continuation - i.e. is the Pi::codom mim::Bottom?
Definition lam.h:48
Pi * stub(const Def *type)
Definition lam.h:93
void init(PassMan *) final
Definition scalarize.cpp:9
const Def * rewrite(const Def *) override
Definition scalarize.cpp:60
World & world()
Definition pass.h:64
Data constructor for a Sigma.
Definition tuple.h:68
const Pi * cn()
Definition world.h:280
const Def * tuple(Defs ops)
Definition world.cpp:288
#define DLOG(...)
Vaporizes to nothingness in Debug build.
Definition log.h:95
Definition ast.h:14
const Def * flatten(const Def *def)
Flattens a sigma/array/pack/tuple.
Definition tuple.cpp:109
Vector< const Def * > DefVec
Definition def.h:77
const Def * unflatten(const Def *def, const Def *type)
Applies the reverse transformation on a Pack / Tuple, given the original type.
Definition tuple.cpp:123
Lam * isa_workable(Lam *lam)
These are Lams that are neither nullptr, nor Lam::is_external, nor Lam::is_unset.
Definition lam.h:349
@ Lam
Definition def.h:114
Vector(I, I, A=A()) -> Vector< typename std::iterator_traits< I >::value_type, Default_Inlined_Size< typename std::iterator_traits< I >::value_type >, A >