MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
scalarize.cpp
Go to the documentation of this file.
2
3#include "mim/tuple.h"
4
5#include "mim/pass/eta_exp.h"
6
7namespace mim {
8
9// TODO should also work for mutable non-dependent sigmas
10
11// TODO merge with make_scalar
12bool Scalarize::should_expand(Lam* lam) {
13 if (!isa_workable(lam)) return false;
14 if (auto i = tup2sca_.find(lam); i != tup2sca_.end() && i->second && i->second == lam) return false;
15
16 auto pi = lam->type();
17 if (lam->num_doms() > 1 && Pi::isa_cn(pi) && pi->isa_imm()) return true; // no ugly dependent pis
18
19 tup2sca_[lam] = lam;
20 return false;
21}
22
23Lam* Scalarize::make_scalar(const Def* def) {
24 auto tup_lam = def->isa_mut<Lam>();
25 assert(tup_lam);
26 if (auto i = tup2sca_.find(tup_lam); i != tup2sca_.end()) return i->second;
27
28 auto types = DefVec();
29 auto arg_sz = Vector<size_t>();
30 bool todo = false;
31 for (size_t i = 0, e = tup_lam->num_doms(); i != e; ++i) {
32 auto n = flatten(types, tup_lam->dom(i), false);
33 arg_sz.push_back(n);
34 todo |= n != 1 || types.back() != tup_lam->dom(i);
35 }
36
37 if (!todo) return tup2sca_[tup_lam] = tup_lam;
38
39 auto cn = world().cn(types);
40 auto sca_lam = tup_lam->stub(cn);
41 if (eta_exp_) eta_exp_->new2old(sca_lam, tup_lam);
42 size_t n = 0;
43 world().DLOG("type {} ~> {}", tup_lam->type(), cn);
44 auto new_vars = world().tuple(DefVec(tup_lam->num_doms(), [&](auto i) {
45 auto tuple = DefVec(arg_sz.at(i), [&](auto) { return sca_lam->var(n++); });
46 return unflatten(tuple, tup_lam->dom(i), false);
47 }));
48 sca_lam->set(tup_lam->reduce(new_vars));
49 tup2sca_[sca_lam] = sca_lam;
50 tup2sca_.emplace(tup_lam, sca_lam);
51 world().DLOG("lambda {} : {} ~> {} : {}", tup_lam, tup_lam->type(), sca_lam, sca_lam->type());
52 return sca_lam;
53}
54
55const Def* Scalarize::rewrite(const Def* def) {
56 auto& w = world();
57 if (auto app = def->isa<App>()) {
58 const Def* sca_callee = app->callee();
59
60 if (auto tup_lam = sca_callee->isa_mut<Lam>(); should_expand(tup_lam)) {
61 sca_callee = make_scalar(tup_lam);
62
63 } else if (auto proj = sca_callee->isa<Extract>()) {
64 auto tuple = proj->tuple()->isa<Tuple>();
65 if (tuple && std::all_of(tuple->ops().begin(), tuple->ops().end(), [&](const Def* op) {
66 return should_expand(op->isa_mut<Lam>());
67 })) {
68 auto new_tuple = w.tuple(DefVec(tuple->num_ops(), [&](auto i) { return make_scalar(tuple->op(i)); }));
69 sca_callee = w.extract(new_tuple, proj->index());
70 w.DLOG("Expand tuple: {, } ~> {, }", tuple->ops(), new_tuple->ops());
71 }
72 }
73
74 if (sca_callee != app->callee()) {
75 auto new_args = DefVec();
76 flatten(new_args, app->arg(), false);
77 return world().app(sca_callee, new_args);
78 }
79 }
80 return def;
81}
82
83} // namespace mim
Base class for all Defs.
Definition def.h:197
Def * set(size_t i, const Def *)
Successively set from left to right.
Definition def.cpp:240
T * isa_mut() const
If this is mutable, it will cast constness away and perform a dynamic_cast to T.
Definition def.h:429
const T * isa_imm() const
Definition def.h:423
Extracts from a Sigma or Array-typed Extract::tuple the element at position Extract::index.
Definition tuple.h:163
A function.
Definition lam.h:106
World & world()
Definition pass.h:296
static const Pi * isa_cn(const Def *d)
Is this a continuation - i.e. is the Pi::codom mim::Bottom?
Definition lam.h:44
Pi * stub(const Def *type)
Definition lam.h:89
const Def * rewrite(const Def *) override
Definition scalarize.cpp:55
Data constructor for a Sigma.
Definition tuple.h:56
const Pi * cn()
Definition world.h:263
const Def * tuple(Defs ops)
Definition world.cpp:266
Definition ast.h:14
const Def * flatten(const Def *def)
Flattens a sigma/array/pack/tuple.
Definition tuple.cpp:66
Vector< const Def * > DefVec
Definition def.h:49
const Def * unflatten(const Def *def, const Def *type)
Applies the reverse transformation on a Pack / Tuple, given the original type.
Definition tuple.cpp:80
Lam * isa_workable(Lam *lam)
These are Lams that are neither nullptr, nor Lam::is_external, nor Lam::is_unset.
Definition lam.h:254
@ Lam
Definition def.h:84
Vector(I, I, A=A()) -> Vector< typename std::iterator_traits< I >::value_type, Default_Inlined_Size< typename std::iterator_traits< I >::value_type >, A >