MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
rewrite.cpp
Go to the documentation of this file.
1#include "mim/rewrite.h"
2
3#include <absl/container/fixed_array.h>
4
5#include "mim/check.h"
6#include "mim/world.h"
7
8#include "fe/assert.h"
9
10// Don't use fancy C++-lambdas; it's way too annoying stepping through them in a debugger.
11
12namespace mim {
13
14const Def* Rewriter::rewrite(const Def* old_def) {
15 if (auto new_def = lookup(old_def)) return new_def;
16
17 auto new_def = old_def->isa_mut() ? rewrite_mut((Def*)old_def) : rewrite_imm(old_def);
18 return new_def->set(old_def->dbg());
19}
20
21// clang-format off
// Case generators for the node-kind switches in rewrite_mut() / rewrite_imm():
// each expands to a `case Node::N` that forwards to the handler named after the
// node kind (rewrite_mut_N / rewrite_imm_N) and stores its result in `new_def`.
22#define CODE_MUT(N) case Node::N: new_def = rewrite_mut_##N(old_mut->as<N>()); break;
23#define CODE_IMM(N) case Node::N: new_def = rewrite_imm_##N(old_def->as<N>()); break;
24// clang-format on
25
// Dispatches an immutable Def to its per-node handler.
// Switches on old_def->node(); a node kind without a case ends in fe::unreachable().
// The handler's result is recorded with map(old_def, new_def), whose return value is returned.
// NOTE(review): this listing skips source line 29 inside the switch - presumably the
// MIM_IMM_NODE(CODE_IMM) expansion that supplies the `case` labels; confirm against the repository.
26const Def* Rewriter::rewrite_imm(const Def* old_def) {
27 const Def* new_def;
28 switch (old_def->node()) {
30 default: fe::unreachable();
31 }
32 return map(old_def, new_def);
33}
34
// Dispatches a mutable Def to its per-node handler.
// Switches on old_mut->node(); a node kind without a case ends in fe::unreachable().
// The handler's result is returned directly; this function itself does not call map().
// NOTE(review): this listing skips source line 38 inside the switch - presumably the
// MIM_MUT_NODE(CODE_MUT) expansion that supplies the `case` labels; confirm against the repository.
35const Def* Rewriter::rewrite_mut(Def* old_mut) {
36 const Def* new_def;
37 switch (old_mut->node()) {
39 default: fe::unreachable();
40 }
41 return new_def;
42}
43
44#undef CODE_MUT
45#undef CODE_IMM
46
48 auto new_ops = DefVec(ops.size());
49 for (size_t i = 0, e = ops.size(); i != e; ++i)
50 new_ops[i] = rewrite(ops[i]);
51 return new_ops;
52}
53
54#ifndef DOXYGEN
55// clang-format off
// Per-node handlers for immutables.
// Idx, Nat and Univ have no operands to rewrite and simply return the corresponding constant of world().
// Every other handler rebuilds its node through the matching World factory, passing the rewritten
// type/operands of the old node; non-Def payload is forwarded unchanged (Lit::get(), Pi::is_implicit(),
// Proxy::pass()/tag(), UInc::offset()).
// NOTE(review): operands are rewritten inside the call's argument list, so the order in which they are
// rewritten is whatever the compiler picks for argument evaluation.
56const Def* Rewriter::rewrite_imm_Idx (const Idx* ) { return world().type_idx(); }
57const Def* Rewriter::rewrite_imm_Nat (const Nat* ) { return world().type_nat(); }
58const Def* Rewriter::rewrite_imm_Univ (const Univ* ) { return world().univ(); }
59const Def* Rewriter::rewrite_imm_App (const App* d) { return world().app (rewrite(d->callee()), rewrite(d->arg())); }
60const Def* Rewriter::rewrite_imm_Inj (const Inj* d) { return world().inj (rewrite(d->type()), rewrite(d->value())); }
61const Def* Rewriter::rewrite_imm_Insert(const Insert* d) { return world().insert(rewrite(d->tuple()), rewrite(d->index()), rewrite(d->value())); }
62const Def* Rewriter::rewrite_imm_Lam (const Lam* d) { return world().lam (rewrite(d->type())->as<Pi>(), rewrite(d->filter()), rewrite(d->body())); }
63const Def* Rewriter::rewrite_imm_Lit (const Lit* d) { return world().lit (rewrite(d->type()), d->get()); }
64const Def* Rewriter::rewrite_imm_Match (const Match* d) { return world().match (rewrite(d->ops())); }
65const Def* Rewriter::rewrite_imm_Merge (const Merge* d) { return world().merge (rewrite(d->type()), rewrite(d->ops())); }
66const Def* Rewriter::rewrite_imm_Pi (const Pi* d) { return world().pi (rewrite(d->dom()), rewrite(d->codom()), d->is_implicit()); }
67const Def* Rewriter::rewrite_imm_Proxy (const Proxy* d) { return world().proxy (rewrite(d->type()), rewrite(d->ops()), d->pass(), d->tag()); }
68const Def* Rewriter::rewrite_imm_Rule (const Rule* d) { return world().rule (rewrite(d->type()), rewrite(d->lhs()), rewrite(d->rhs()), rewrite(d->guard())); }
69const Def* Rewriter::rewrite_imm_Sigma (const Sigma* d) { return world().sigma (rewrite(d->ops())); }
70const Def* Rewriter::rewrite_imm_Split (const Split* d) { return world().split (rewrite(d->type()), rewrite(d->value())); }
71const Def* Rewriter::rewrite_imm_Tuple (const Tuple* d) { return world().tuple (rewrite(d->type()), rewrite(d->ops())); }
72const Def* Rewriter::rewrite_imm_Type (const Type* d) { return world().type (rewrite(d->level())); }
73const Def* Rewriter::rewrite_imm_UInc (const UInc* d) { return world().uinc (rewrite(d->op()), d->offset()); }
74const Def* Rewriter::rewrite_imm_UMax (const UMax* d) { return world().umax (rewrite(d->ops())); }
75const Def* Rewriter::rewrite_imm_Uniq (const Uniq* d) { return world().uniq (rewrite(d->op())); }
// A Var is rebuilt against the rewritten mutable it belongs to.
76const Def* Rewriter::rewrite_imm_Var (const Var* d) { return world().var (rewrite(d->type()), rewrite(d->mut())->as_mut()); }
77const Def* Rewriter::rewrite_imm_Top (const Top* d) { return world().top (rewrite(d->type())); }
78const Def* Rewriter::rewrite_imm_Bot (const Bot* d) { return world().bot (rewrite(d->type())); }
79const Def* Rewriter::rewrite_imm_Meet (const Meet* d) { return world().meet (rewrite(d->ops())); }
80const Def* Rewriter::rewrite_imm_Join (const Join* d) { return world().join (rewrite(d->ops())); }
81
// Arr and Pack have no logic of their own here: both the immutable and the mutable
// variants forward to the shared Seq handlers (rewrite_imm_Seq / rewrite_mut_Seq).
82const Def* Rewriter::rewrite_imm_Arr (const Arr* d) { return rewrite_imm_Seq(d); }
83const Def* Rewriter::rewrite_imm_Pack(const Pack* d) { return rewrite_imm_Seq(d); }
84const Def* Rewriter::rewrite_mut_Arr ( Arr* d) { return rewrite_mut_Seq(d); }
85const Def* Rewriter::rewrite_mut_Pack( Pack* d) { return rewrite_mut_Seq(d); }
86
// Per-node handlers for mutables.
// Each one creates a new mutable in world() from the rewritten type of the old one, carrying over
// the node-specific extras (Pi::is_implicit(), Sigma::num_ops(), Global::is_mutable()), and passes
// the old and the new mutable on to rewrite_stub().
87const Def* Rewriter::rewrite_mut_Pi (Pi* d) { return rewrite_stub(d, world().mut_pi (rewrite(d->type()), d->is_implicit())); }
88const Def* Rewriter::rewrite_mut_Lam (Lam* d) { return rewrite_stub(d, world().mut_lam (rewrite(d->type())->as<Pi>())); }
89const Def* Rewriter::rewrite_mut_Rule (Rule* d) { return rewrite_stub(d, world().mut_rule (rewrite(d->type())->as<Reform>())); }
90const Def* Rewriter::rewrite_mut_Sigma (Sigma* d) { return rewrite_stub(d, world().mut_sigma(rewrite(d->type()), d->num_ops())); }
91const Def* Rewriter::rewrite_mut_Global(Global* d) { return rewrite_stub(d, world().global (rewrite(d->type()), d->is_mutable())); }
92// clang-format on
93
94const Def* Rewriter::rewrite_imm_Axm(const Axm* a) {
95 if (&a->world() != &world()) {
96 auto type = rewrite(a->type());
97 return world().axm(a->normalizer(), a->curry(), a->trip(), type, a->plugin(), a->tag(), a->sub());
98 }
99 return a;
100}
101
102const Def* Rewriter::rewrite_imm_Extract(const Extract* ex) {
103 auto new_index = rewrite(ex->index());
104 if (auto index = Lit::isa(new_index)) {
105 if (auto tuple = ex->tuple()->isa<Tuple>()) return map(ex, rewrite(tuple->op(*index)));
106 if (auto pack = ex->tuple()->isa_imm<Pack>(); pack && pack->arity()->is_closed())
107 return map(ex, rewrite(pack->body()));
108 }
109
110 auto new_tuple = rewrite(ex->tuple());
111 return world().extract(new_tuple, new_index);
112}
113
114const Def* Rewriter::rewrite_mut_Hole(Hole* hole) {
115 auto [last, op] = hole->find();
116 return op ? rewrite(op) : rewrite_stub(last, world().mut_hole(rewrite(last->type())));
117}
118
119#endif
120
122 auto new_arity = rewrite(seq->arity());
123 if (auto l = Lit::isa(new_arity); l && *l == 0) return world().prod(seq->is_intro());
124 return world().seq(seq->is_intro(), new_arity, rewrite(seq->body()));
125}
126
128 if (!seq->is_set()) {
129 auto new_seq = seq->as_mut<Seq>()->stub(world(), rewrite(seq->type()));
130 return map(seq, new_seq);
131 }
132
133 auto new_arity = rewrite(seq->arity());
134 auto l = Lit::isa(new_arity);
135 if (l && *l == 0) return world().prod(seq->is_intro());
136
137 if (auto var = seq->has_var(); var && l && *l <= world().flags().scalarize_threshold) {
138 auto new_ops = absl::FixedArray<const Def*>(*l);
139 for (size_t i = 0, e = *l; i != e; ++i) {
140 push();
141 map(var, world().lit_idx(e, i));
142 new_ops[i] = rewrite(seq->body());
143 pop();
144 }
145 return map(seq, world().prod(seq->is_intro(), new_ops));
146 }
147
148 if (!seq->has_var()) return map(seq, world().seq(seq->is_intro(), new_arity, rewrite(seq->body())));
149 return rewrite_stub(seq->as_mut(), world().mut_seq(seq->is_term(), rewrite(seq->type())));
150}
151
152const Def* Rewriter::rewrite_stub(Def* old_mut, Def* new_mut) {
153 map(old_mut, new_mut);
154
155 if (old_mut->is_set()) {
156 for (size_t i = 0, e = old_mut->num_ops(); i != e; ++i)
157 new_mut->set(i, rewrite(old_mut->op(i)));
158 if (auto new_imm = new_mut->immutabilize()) return map(old_mut, new_imm);
159 }
160
161 return new_mut;
162}
163
164} // namespace mim
A (possibly parameterized) Array.
Definition tuple.h:117
Definition axm.h:9
Base class for all Defs.
Definition def.h:251
bool is_set() const
Yields true if empty or the last op is set.
Definition def.cpp:298
constexpr Node node() const noexcept
Definition def.h:274
Def * set(size_t i, const Def *)
Successively set from left to right.
Definition def.cpp:266
T * as_mut() const
Asserts that this is a mutable, casts constness away and performs a static_cast to T.
Definition def.h:491
bool is_intro() const noexcept
Definition def.h:284
T * isa_mut() const
If this is mutable, it will cast constness away and perform a dynamic_cast to T.
Definition def.h:482
bool is_term() const
Definition def.cpp:480
const Def * op(size_t i) const noexcept
Definition def.h:308
virtual const Def * immutabilize()
Tries to make an immutable from a mutable.
Definition def.h:553
const Def * type() const noexcept
Yields the "raw" type of this Def (maybe nullptr).
Definition def.h:295
virtual const Def * arity() const
Definition def.cpp:546
Dbg dbg() const
Definition def.h:502
const Var * has_var()
Only returns not nullptr, if Var of this mutable has ever been created.
Definition def.h:433
constexpr size_t num_ops() const noexcept
Definition def.h:309
Extracts from a Sigma or Array-typed Extract::tuple the element at position Extract::index.
Definition tuple.h:206
This node is a hole in the IR that is inferred by its context later on.
Definition check.h:14
A built-in constant of type Nat -> *.
Definition def.h:845
Constructs a Join value.
Definition lattice.h:70
Creates a new Tuple / Pack by inserting Insert::value at position Insert::index into Insert::tuple.
Definition tuple.h:233
A function.
Definition lam.h:109
static std::optional< T > isa(const Def *def)
Definition def.h:810
Scrutinize Match::scrutinee() and dispatch to Match::arms.
Definition lattice.h:116
Constructs a Meet value.
Definition lattice.h:53
A (possibly parameterized) Tuple.
Definition tuple.h:166
A dependent function type.
Definition lam.h:13
const Def * lookup(const Def *old_def)
Lookup old_def by searching in reverse through the stack of maps.
Definition rewrite.h:49
virtual const Def * rewrite_imm_Seq(const Seq *seq)
Definition rewrite.cpp:121
virtual const Def * rewrite_mut_Seq(Seq *seq)
Definition rewrite.cpp:127
World & world()
Definition rewrite.h:36
virtual const Def * rewrite_mut(Def *)
Definition rewrite.cpp:35
virtual void push()
Definition rewrite.h:40
virtual const Def * rewrite_stub(Def *, Def *)
Definition rewrite.cpp:152
const Def * map(const Def *old_def, const Def *new_def)
Map old_def to new_def and returns new_def.
Definition rewrite.h:45
virtual void pop()
Definition rewrite.h:41
virtual const Def * rewrite_imm(const Def *)
Definition rewrite.cpp:26
virtual const Def * rewrite(const Def *)
Definition rewrite.cpp:14
A rewrite rule.
Definition rule.h:38
Base class for Arr and Pack.
Definition tuple.h:84
const Def * body() const
Definition tuple.h:91
A dependent tuple type.
Definition tuple.h:20
Picks the aspect of a Meet [value](Pick::value) by its [type](Def::type).
Definition lattice.h:93
Data constructor for a Sigma.
Definition tuple.h:68
A singleton wraps a type into a higher order type.
Definition lattice.h:180
const Def * insert(const Def *d, const Def *i, const Def *val)
Definition world.cpp:425
const Def * meet(Defs ops)
Definition world.h:480
const Def * uinc(const Def *op, level_t offset=1)
Definition world.cpp:115
const Lit * lit(const Def *type, u64 val)
Definition world.cpp:508
const Type * type(const Def *level)
Definition world.cpp:106
const Def * sigma(Defs ops)
Definition world.cpp:278
const Def * app(const Def *callee, const Def *arg)
Definition world.cpp:196
const Def * match(Defs)
Definition world.cpp:595
const Pi * pi(const Def *dom, const Def *codom, bool implicit=false)
Definition world.h:264
const Def * seq(bool term, const Def *arity, const Def *body)
Definition world.cpp:474
const Univ * univ()
Definition world.h:213
const Def * bot(const Def *type)
Definition world.h:472
const Idx * type_idx()
Definition world.h:498
const Nat * type_nat()
Definition world.h:497
const Lam * lam(const Pi *pi, Lam::Filter f, const Def *body)
Definition world.h:292
const Def * tuple(Defs ops)
Definition world.cpp:288
const Def * inj(const Def *type, const Def *value)
Definition world.cpp:580
const Axm * axm(NormalizeFn n, u8 curry, u8 trip, const Def *type, plugin_t p, tag_t t, sub_t s)
Definition world.h:246
const Def * var(const Def *type, Def *mut)
Definition world.cpp:177
const Def * extract(const Def *d, const Def *i)
Definition world.cpp:346
const Def * join(Defs ops)
Definition world.h:479
const Proxy * proxy(const Def *type, Defs ops, u32 index, u32 tag)
Definition world.h:229
const Def * uniq(const Def *inhabitant)
Definition world.cpp:631
const Def * prod(bool term, Defs ops)
Definition world.h:385
const Def * umax(Defs)
Definition world.cpp:135
const Def * merge(const Def *type, Defs ops)
Definition world.cpp:562
const Def * top(const Def *type)
Definition world.h:473
const Def * split(const Def *type, const Def *value)
Definition world.cpp:588
const Rule * rule(const Reform *type, const Def *lhs, const Def *rhs, const Def *guard)
Definition world.h:322
#define MIM_MUT_NODE(m)
Definition def.h:52
#define MIM_IMM_NODE(m)
Definition def.h:36
const Def * op(trait o, const Def *type)
Definition core.h:33
Definition ast.h:14
View< const Def * > Defs
Definition def.h:76
Vector< const Def * > DefVec
Definition def.h:77
TBound< true > Join
AKA union.
Definition lattice.h:174
TExt< true > Top
Definition lattice.h:172
TExt< false > Bot
Definition lattice.h:171
TBound< false > Meet
AKA intersection.
Definition lattice.h:173
@ Nat
Definition def.h:114
@ Pi
Definition def.h:114
@ Pack
Definition def.h:114
@ Reform
Definition def.h:114
@ Tuple
Definition def.h:114
#define CODE_MUT(N)
Definition rewrite.h:66
#define CODE_IMM(N)
Definition rewrite.h:65