MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
scalarize.cpp
Go to the documentation of this file.
2
3#include "mim/rewrite.h"
4#include "mim/tuple.h"
5
6#include "mim/pass/eta_exp.h"
7
8namespace mim {
9
10// TODO should also work for mutable non-dependent sigmas
11
12// TODO merge with make_scalar
13bool Scalarize::should_expand(Lam* lam) {
14 if (!isa_workable(lam)) return false;
15 if (auto i = tup2sca_.find(lam); i != tup2sca_.end() && i->second && i->second == lam) return false;
16
17 auto pi = lam->type();
18 if (lam->num_doms() > 1 && Pi::isa_cn(pi) && pi->isa_imm()) return true; // no ugly dependent pis
19
20 tup2sca_[lam] = lam;
21 return false;
22}
23
24Lam* Scalarize::make_scalar(Ref def) {
25 auto tup_lam = def->isa_mut<Lam>();
26 assert(tup_lam);
27 if (auto i = tup2sca_.find(tup_lam); i != tup2sca_.end()) return i->second;
28
29 auto types = DefVec();
30 auto arg_sz = Vector<size_t>();
31 bool todo = false;
32 for (size_t i = 0, e = tup_lam->num_doms(); i != e; ++i) {
33 auto n = flatten(types, tup_lam->dom(i), false);
34 arg_sz.push_back(n);
35 todo |= n != 1 || types.back() != tup_lam->dom(i);
36 }
37
38 if (!todo) return tup2sca_[tup_lam] = tup_lam;
39
40 auto cn = world().cn(types);
41 auto sca_lam = tup_lam->stub(cn);
42 if (eta_exp_) eta_exp_->new2old(sca_lam, tup_lam);
43 size_t n = 0;
44 world().DLOG("type {} ~> {}", tup_lam->type(), cn);
45 auto new_vars = world().tuple(DefVec(tup_lam->num_doms(), [&](auto i) {
46 auto tuple = DefVec(arg_sz.at(i), [&](auto) { return sca_lam->var(n++); });
47 return unflatten(tuple, tup_lam->dom(i), false);
48 }));
49 sca_lam->set(tup_lam->reduce(new_vars));
50 tup2sca_[sca_lam] = sca_lam;
51 tup2sca_.emplace(tup_lam, sca_lam);
52 world().DLOG("lambda {} : {} ~> {} : {}", tup_lam, tup_lam->type(), sca_lam, sca_lam->type());
53 return sca_lam;
54}
55
56Ref Scalarize::rewrite(Ref def) {
57 auto& w = world();
58 if (auto app = def->isa<App>()) {
59 Ref sca_callee = app->callee();
60
61 if (auto tup_lam = sca_callee->isa_mut<Lam>(); should_expand(tup_lam)) {
62 sca_callee = make_scalar(tup_lam);
63
64 } else if (auto proj = sca_callee->isa<Extract>()) {
65 auto tuple = proj->tuple()->isa<Tuple>();
66 if (tuple && std::all_of(tuple->ops().begin(), tuple->ops().end(), [&](Ref op) {
67 return should_expand(op->isa_mut<Lam>());
68 })) {
69 auto new_tuple = w.tuple(DefVec(tuple->num_ops(), [&](auto i) { return make_scalar(tuple->op(i)); }));
70 sca_callee = w.extract(new_tuple, proj->index());
71 w.DLOG("Expand tuple: {, } ~> {, }", tuple->ops(), new_tuple->ops());
72 }
73 }
74
75 if (sca_callee != app->callee()) {
76 auto new_args = DefVec();
77 flatten(new_args, app->arg(), false);
78 return world().app(sca_callee, new_args);
79 }
80 }
81 return def;
82}
83
84} // namespace mim
Def * set(size_t i, const Def *def)
Successively set from left to right.
Definition def.cpp:246
T * isa_mut() const
If this is *mut*able, it will cast constness away and perform a dynamic_cast to T.
Definition def.h:447
const T * isa_imm() const
Definition def.h:441
void new2old(Lam *new_lam, Lam *old_lam)
Definition eta_exp.h:22
Extracts from a Sigma or Array-typed Extract::tuple the element at position Extract::index.
Definition tuple.h:152
A function.
Definition lam.h:103
World & world()
Definition pass.h:296
Pi * stub(Ref type)
Definition lam.h:88
static const Pi * isa_cn(Ref d)
Is this a continuation - i.e. is the Pi::codom mim::Bottom?
Definition lam.h:50
Helper class to retrieve Infer::arg if present.
Definition def.h:86
Data constructor for a Sigma.
Definition tuple.h:49
const Pi * cn()
Definition world.h:267
Ref tuple(Defs ops)
Definition world.cpp:238
@ Lam
Definition def.h:40
Definition cfg.h:11
const Def * flatten(const Def *def)
Flattens a sigma/array/pack/tuple.
Definition tuple.cpp:66
Vector< const Def * > DefVec
Definition def.h:62
const Def * unflatten(const Def *def, const Def *type)
Applies the reverse transformation on a Pack / Tuple, given the original type.
Definition tuple.cpp:80
Lam * isa_workable(Lam *lam)
These are Lams that are neither nullptr, nor Lam::is_external, nor Lam::is_unset.
Definition lam.h:251
Vector(I, I, A=A()) -> Vector< typename std::iterator_traits< I >::value_type, Default_Inlined_Size< typename std::iterator_traits< I >::value_type >, A >