MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
rewrite.cpp
Go to the documentation of this file.
1#include "mim/rewrite.h"
2
3#include <absl/container/fixed_array.h>
4
5#include "mim/world.h"
6
7#include "fe/assert.h"
8
9// Don't use fancy C++-lambdas; it's way too annoying stepping through them in a debugger.
10
11namespace mim {
12
13/*
14 * Rewriter
15 */
16
17const Def* Rewriter::map(const Def* old_def, Defs new_defs) {
18 return old2news_.back()[old_def] = world().tuple(new_defs);
19}
20const Def* Rewriter::map(Defs old_defs, const Def* new_def) {
21 return old2news_.back()[world().tuple(old_defs)] = new_def;
22}
23const Def* Rewriter::map(Defs old_defs, Defs new_defs) {
24 return old2news_.back()[world().tuple(old_defs)] = world().tuple(new_defs);
25}
26
27const Def* Rewriter::rewrite(const Def* old_def) {
28 if (auto new_def = lookup(old_def)) return new_def;
29
30 auto new_def = old_def->isa_mut() ? rewrite_mut((Def*)old_def) : rewrite_imm(old_def);
31 return new_def->set(old_def->dbg());
32}
33
// Helper macros: each expands to one `case` label for node kind N that forwards to the typed hook
// rewrite_mut_N / rewrite_imm_N (downcasting via as<N>()). They are #undef'd again after the two dispatchers.
34// clang-format off
35#define CODE_MUT(N) case Node::N: new_def = rewrite_mut_##N(old_mut->as<N>()); break;
36#define CODE_IMM(N) case Node::N: new_def = rewrite_imm_##N(old_def->as<N>()); break;
37// clang-format on
38
// Dispatches an immutable Def to the rewrite_imm_<Node> hook matching old_def->node(),
// then records old_def -> result via map() and returns that result.
39const Def* Rewriter::rewrite_imm(const Def* old_def) {
40 const Def* new_def;
41 switch (old_def->node()) {
// NOTE(review): the generated `case` lines (source line 42, presumably an X-macro expansion with CODE_IMM)
// are not visible in this listing.
43 default: fe::unreachable();
44 }
45 return map(old_def, new_def);
46}
47
// Dispatches a mutable to the rewrite_mut_<Node> hook matching old_mut->node().
// Unlike rewrite_imm(), the result is NOT recorded here; the visible hooks do it themselves
// (rewrite_stub() and rewrite_mut_Seq call map()).
48const Def* Rewriter::rewrite_mut(Def* old_mut) {
49 const Def* new_def;
50 switch (old_mut->node()) {
// NOTE(review): the generated `case` lines (source line 51, presumably an X-macro expansion with CODE_MUT)
// are not visible in this listing.
52 default: fe::unreachable();
53 }
54 return new_def;
55}
56
57#undef CODE_MUT
58#undef CODE_IMM
59
// Rewrites every element of `ops` in index order and returns the results in a new DefVec of the same size.
// NOTE(review): the signature line (source line 60) is missing from this listing; presumably
// `DefVec Rewriter::rewrite(Defs ops)` — confirm.
61 auto new_ops = DefVec(ops.size());
62 for (size_t i = 0, e = ops.size(); i != e; ++i)
63 new_ops[i] = rewrite(ops[i]);
64 return new_ops;
65}
66
67#ifndef DOXYGEN
68// clang-format off
// Immutable hooks: rebuild the node in world() from its rewritten operands; non-Def payloads
// (literal value, implicit flag, offset, pass/tag) are copied unchanged. Idx, Nat and Univ have no
// operands and simply map to world()'s corresponding constants.
// Note: when a hook passes several rewrite(...) calls as arguments, C++ leaves their evaluation order unspecified.
69const Def* Rewriter::rewrite_imm_Idx (const Idx* ) { return world().type_idx(); }
70const Def* Rewriter::rewrite_imm_Nat (const Nat* ) { return world().type_nat(); }
71const Def* Rewriter::rewrite_imm_Univ (const Univ* ) { return world().univ(); }
72const Def* Rewriter::rewrite_imm_App (const App* d) { return world().app (rewrite(d->callee()), rewrite(d->arg())); }
73const Def* Rewriter::rewrite_imm_Inj (const Inj* d) { return world().inj (rewrite(d->type()), rewrite(d->value())); }
74const Def* Rewriter::rewrite_imm_Insert(const Insert* d) { return world().insert(rewrite(d->tuple()), rewrite(d->index()), rewrite(d->value())); }
75const Def* Rewriter::rewrite_imm_Lam (const Lam* d) { return world().lam (rewrite(d->type())->as<Pi>(), rewrite(d->filter()), rewrite(d->body())); }
76const Def* Rewriter::rewrite_imm_Lit (const Lit* d) { return world().lit (rewrite(d->type()), d->get()); }
77const Def* Rewriter::rewrite_imm_Match (const Match* d) { return world().match (rewrite(d->ops())); }
78const Def* Rewriter::rewrite_imm_Merge (const Merge* d) { return world().merge (rewrite(d->type()), rewrite(d->ops())); }
79const Def* Rewriter::rewrite_imm_Pi (const Pi* d) { return world().pi (rewrite(d->dom()), rewrite(d->codom()), d->is_implicit()); }
80const Def* Rewriter::rewrite_imm_Proxy (const Proxy* d) { return world().proxy (rewrite(d->type()), rewrite(d->ops()), d->pass(), d->tag()); }
81const Def* Rewriter::rewrite_imm_Reform(const Reform* d) { return world().reform(rewrite(d->meta_type())); }
82const Def* Rewriter::rewrite_imm_Rule (const Rule* d) { return world().rule (rewrite(d->type()), rewrite(d->lhs()), rewrite(d->rhs()), rewrite(d->guard())); }
83const Def* Rewriter::rewrite_imm_Sigma (const Sigma* d) { return world().sigma (rewrite(d->ops())); }
84const Def* Rewriter::rewrite_imm_Split (const Split* d) { return world().split (rewrite(d->type()), rewrite(d->value())); }
85const Def* Rewriter::rewrite_imm_Tuple (const Tuple* d) { return world().tuple (rewrite(d->type()), rewrite(d->ops())); }
86const Def* Rewriter::rewrite_imm_Type (const Type* d) { return world().type (rewrite(d->level())); }
87const Def* Rewriter::rewrite_imm_UInc (const UInc* d) { return world().uinc (rewrite(d->op()), d->offset()); }
88const Def* Rewriter::rewrite_imm_UMax (const UMax* d) { return world().umax (rewrite(d->ops())); }
89const Def* Rewriter::rewrite_imm_Uniq (const Uniq* d) { return world().uniq (rewrite(d->op())); }
90const Def* Rewriter::rewrite_imm_Var (const Var* d) { return world().var (rewrite(d->mut())->as_mut()); }
91const Def* Rewriter::rewrite_imm_Top (const Top* d) { return world().top (rewrite(d->type())); }
92const Def* Rewriter::rewrite_imm_Bot (const Bot* d) { return world().bot (rewrite(d->type())); }
93const Def* Rewriter::rewrite_imm_Meet (const Meet* d) { return world().meet (rewrite(d->ops())); }
94const Def* Rewriter::rewrite_imm_Join (const Join* d) { return world().join (rewrite(d->ops())); }
95
// Arr and Pack share one implementation through their common base class Seq.
96const Def* Rewriter::rewrite_imm_Arr (const Arr* d) { return rewrite_imm_Seq(d); }
97const Def* Rewriter::rewrite_imm_Pack(const Pack* d) { return rewrite_imm_Seq(d); }
98const Def* Rewriter::rewrite_mut_Arr ( Arr* d) { return rewrite_mut_Seq(d); }
99const Def* Rewriter::rewrite_mut_Pack( Pack* d) { return rewrite_mut_Seq(d); }
100
// Mutable hooks: allocate an unfilled mutable of the same kind in world() (with rewritten type and copied
// flags/size), then let rewrite_stub() record old -> new and fill in the rewritten ops.
101const Def* Rewriter::rewrite_mut_Pi (Pi* d) { return rewrite_stub(d, world().mut_pi (rewrite(d->type()), d->is_implicit())); }
102const Def* Rewriter::rewrite_mut_Lam (Lam* d) { return rewrite_stub(d, world().mut_lam (rewrite(d->type())->as<Pi>())); }
103const Def* Rewriter::rewrite_mut_Rule (Rule* d) { return rewrite_stub(d, world().mut_rule (rewrite(d->type())->as<Reform>())); }
104const Def* Rewriter::rewrite_mut_Sigma (Sigma* d) { return rewrite_stub(d, world().mut_sigma(rewrite(d->type()), d->num_ops())); }
105const Def* Rewriter::rewrite_mut_Global(Global* d) { return rewrite_stub(d, world().global (rewrite(d->type()), d->is_mutable())); }
106// clang-format on
107
108const Def* Rewriter::rewrite_imm_Axm(const Axm* a) {
109 if (&a->world() != &world()) {
110 auto type = rewrite(a->type());
111 return world().axm(a->normalizer(), a->curry(), a->trip(), type, a->plugin(), a->tag(), a->sub());
112 }
113 return a;
114}
115
116const Def* Rewriter::rewrite_imm_Extract(const Extract* ex) {
117 auto new_index = rewrite(ex->index());
118 if (auto index = Lit::isa(new_index)) {
119 if (auto tuple = ex->tuple()->isa<Tuple>()) return map(ex, rewrite(tuple->op(*index)));
120 if (auto pack = ex->tuple()->isa_imm<Pack>(); pack && pack->arity()->is_closed())
121 return map(ex, rewrite(pack->body()));
122 }
123
124 auto new_tuple = rewrite(ex->tuple());
125 return world().extract(new_tuple, new_index);
126}
127
128const Def* Rewriter::rewrite_mut_Hole(Hole* hole) {
129 auto [last, op] = hole->find();
130 return op ? rewrite(op) : rewrite_stub(last, world().mut_hole(rewrite(last->type())));
131}
132
133#endif
134
// NOTE(review): the signature line (source line 135) is missing from this listing; per the index it is
// `const Def* Rewriter::rewrite_imm_Seq(const Seq* seq)`.
// Rebuilds an immutable Arr/Pack from its rewritten arity and body. If the rewritten arity is the
// literal 0, the result collapses to world().prod(seq->is_intro()) with no ops.
136 auto new_arity = rewrite(seq->arity());
137 if (auto l = Lit::isa(new_arity); l && *l == 0) return world().prod(seq->is_intro());
138 return world().seq(seq->is_intro(), new_arity, rewrite(seq->body()));
139}
140
// NOTE(review): the signature line (source line 141) is missing from this listing; per the index it is
// `const Def* Rewriter::rewrite_mut_Seq(Seq* seq)`.
// Rewrites a mutable Arr/Pack. Depending on the rewritten arity and on whether the seq's Var exists,
// the result is an unfilled stub, an empty prod, an unrolled prod, an immutable seq, or a new mutable seq.
// Not filled in yet: only create a stub (with rewritten type) and record it.
142 if (!seq->is_set()) {
143 auto new_seq = seq->as_mut<Seq>()->stub(world(), rewrite(seq->type()));
144 return map(seq, new_seq);
145 }
146
147 auto new_arity = rewrite(seq->arity());
148 auto l = Lit::isa(new_arity);
// Rewritten arity is the literal 0: collapse to an empty prod.
// NOTE(review): unlike the other exits below, this result is not recorded via map().
149 if (l && *l == 0) return world().prod(seq->is_intro());
150
// The Var exists and the arity is a literal no larger than flags().scalarize_threshold: unroll into an
// explicit prod. The body is rewritten once per index i, in a fresh map scope (push/pop) where the
// Var is mapped to lit_idx(e, i).
151 if (auto var = seq->has_var(); var && l && *l <= world().flags().scalarize_threshold) {
152 auto new_ops = absl::FixedArray<const Def*>(*l);
153 for (size_t i = 0, e = *l; i != e; ++i) {
154 push();
155 map(var, world().lit_idx(e, i));
156 new_ops[i] = rewrite(seq->body());
157 pop();
158 }
159 return map(seq, world().prod(seq->is_intro(), new_ops));
160 }
161
// The Var was never created, so the body cannot mention it: an immutable seq suffices.
// Otherwise keep it mutable: a fresh mut_seq stub that rewrite_stub() records and fills.
162 if (!seq->has_var()) return map(seq, world().seq(seq->is_intro(), new_arity, rewrite(seq->body())));
163 return rewrite_stub(seq->as_mut(), world().mut_seq(seq->is_intro(), rewrite(seq->type())));
164}
165
166const Def* Rewriter::rewrite_stub(Def* old_mut, Def* new_mut) {
167 map(old_mut, new_mut);
168
169 if (old_mut->is_set()) {
170 for (size_t i = 0, e = old_mut->num_ops(); i != e; ++i)
171 new_mut->set(i, rewrite(old_mut->op(i)));
172 if (auto new_imm = new_mut->immutabilize()) return map(old_mut, new_imm);
173 }
174
175 return new_mut;
176}
177
178/*
179 * VarRewriter
180 */
181
182const Def* VarRewriter::rewrite(const Def* old_def) {
183 if (auto new_def = lookup(old_def)) return new_def;
184
185 if (auto old_mut = old_def->isa_mut())
186 return has_intersection(old_mut) ? rewrite_mut(old_mut)->set(old_mut->dbg()) : old_mut;
187
188 if (old_def->local_vars().empty() && old_def->local_muts().empty()) return old_def; // safe to skip
189
190 return has_intersection(old_def) ? rewrite_imm(old_def)->set(old_def->dbg()) : old_def;
191}
192
// NOTE(review): the signature line (source line 193) is missing from this listing; per the index it is
// `const Def* VarRewriter::rewrite_mut(Def* mut) final`.
// If mut's Var has been created, it is added to the innermost var set (vars_.back()) through
// world().vars().insert before delegating to Rewriter::rewrite_mut.
// NOTE(review): presumably this makes has_intersection() also report Defs mentioning this binder's Var — confirm.
194 if (auto var = mut->has_var()) {
195 auto& vars = vars_.back();
196 vars = world().vars().insert(vars, var);
197 }
198
199 return Rewriter::rewrite_mut(mut);
200}
201
202/*
203 * Zonker
204 */
205
206const Def* Zonker::map(const Def* old_def, const Def* new_def) {
207 auto repr = lookup(new_def); // always normalize new_def to its representative
208 if (!repr) repr = new_def;
209 return old2news_.back()[old_def] = repr;
210}
211
212const Def* Zonker::lookup(const Def* old_def) {
213 for (auto& old2new : old2news_ | std::views::reverse) {
214 const Def* repr;
215 auto path = DefVec();
216 while (true) {
217 repr = get(old_def);
218
219 if (repr == nullptr) break;
220
221 path.emplace_back(repr);
222 if (repr == old_def) break; // explicit self-map
223
224 old_def = repr;
225 }
226
227 if (path.empty()) continue;
228
229 // path compression: flatten all visited nodes
230 for (auto def : path)
231 old2new[def] = repr;
232
233 return repr;
234 }
235
236 return nullptr;
237}
238
239const Def* Zonker::rewrite(const Def* def) {
240 if (auto hole = def->isa_mut<Hole>()) {
241 auto [last, op] = hole->find();
242 def = op ? op : last;
243 }
244
245 return def->needs_zonk() ? Rewriter::rewrite(def) : def;
246}
247
// NOTE(review): the signature line (source line 248) is missing from this listing; per the index it is
// `const Def* Zonker::rewire_mut(Def* mut)`.
// Rewrites `mut` in place: its type and ops are replaced by their rewritten versions. If the result can be
// immutabilized, that immutable is recorded for `mut` and returned; otherwise `mut` itself is returned.
// Self-map first, so references back to `mut` encountered while rewriting its ops resolve to `mut`.
249 map(mut, mut);
250
// Snapshot the type and ops before unset() clears the ops.
251 auto old_type = mut->type();
252 auto old_ops = absl::FixedArray<const Def*>(mut->ops().begin(), mut->ops().end());
253
254 mut->unset()->set_type(rewrite(old_type));
255
256 for (size_t i = 0, e = mut->num_ops(); i != e; ++i)
257 mut->set(i, rewrite(old_ops[i]));
258
259 if (auto new_imm = mut->immutabilize()) return map(mut, new_imm);
260
261 return mut;
262}
263
264} // namespace mim
A (possibly parameterized) Array.
Definition tuple.h:117
Definition axm.h:9
Base class for all Defs.
Definition def.h:252
bool is_set() const
Yields true if empty or the last op is set.
Definition def.cpp:298
constexpr Node node() const noexcept
Definition def.h:275
Def * set(size_t i, const Def *)
Successively set from left to right.
Definition def.cpp:266
T * as_mut() const
Asserts that this is a mutable, casts constness away and performs a static_cast to T.
Definition def.h:507
Def * set_type(const Def *)
Update type.
Definition def.cpp:283
bool is_intro() const noexcept
Definition def.h:285
constexpr auto ops() const noexcept
Definition def.h:306
Vars local_vars() const
Vars reachable by following immutable deps().
Definition def.cpp:348
T * isa_mut() const
If this is mutable, it will cast constness away and perform a dynamic_cast to T.
Definition def.h:498
const Def * op(size_t i) const noexcept
Definition def.h:309
virtual const Def * immutabilize()
Tries to make an immutable from a mutable.
Definition def.h:569
Muts local_muts() const
Mutables reachable by following immutable deps(); mut->local_muts() is by definition the set { mut }...
Definition def.cpp:332
const Def * type() const noexcept
Yields the "raw" type of this Def (maybe nullptr).
Definition def.cpp:452
virtual const Def * arity() const
Definition def.cpp:558
Def * unset()
Unsets all Def::ops; works even if they are not set at all or only partially set.
Definition def.cpp:289
bool needs_zonk() const
Yields true if Def::local_muts() contains a Hole that is set.
Definition check.cpp:12
Dbg dbg() const
Definition def.h:518
const Var * has_var()
Returns a non-null pointer only if the Var of this mutable has ever been created.
Definition def.h:434
constexpr size_t num_ops() const noexcept
Definition def.h:310
Extracts from a Sigma or Array-typed Extract::tuple the element at position Extract::index.
Definition tuple.h:206
This node is a hole in the IR that is inferred by its context later on.
Definition check.h:14
A built-in constant of type Nat -> *.
Definition def.h:878
Constructs a Join value.
Definition lattice.h:70
Creates a new Tuple / Pack by inserting Insert::value at position Insert::index into Insert::tuple.
Definition tuple.h:233
A function.
Definition lam.h:110
static std::optional< T > isa(const Def *def)
Definition def.h:843
Scrutinize Match::scrutinee() and dispatch to Match::arms.
Definition lattice.h:116
Constructs a Meet value.
Definition lattice.h:53
A (possibly parameterized) Tuple.
Definition tuple.h:166
A dependent function type.
Definition lam.h:14
Type formation of a rewrite Rule.
Definition rule.h:9
virtual const Def * rewrite_imm_Seq(const Seq *seq)
Definition rewrite.cpp:135
virtual const Def * rewrite_mut_Seq(Seq *seq)
Definition rewrite.cpp:141
World & world()
Definition rewrite.h:48
virtual const Def * rewrite_mut(Def *)
Definition rewrite.cpp:48
virtual void push()
Definition rewrite.h:53
virtual const Def * rewrite_stub(Def *, Def *)
Definition rewrite.cpp:166
virtual const Def * map(const Def *old_def, const Def *new_def)
Definition rewrite.h:60
virtual void pop()
Definition rewrite.h:54
std::deque< Def2Def > old2news_
Definition rewrite.h:108
virtual const Def * rewrite_imm(const Def *)
Definition rewrite.cpp:39
virtual const Def * rewrite(const Def *)
Definition rewrite.cpp:27
virtual const Def * lookup(const Def *old_def)
Lookup old_def by searching in reverse through the stack of maps.
Definition rewrite.h:70
A rewrite rule.
Definition rule.h:38
Base class for Arr and Pack.
Definition tuple.h:84
const Def * body() const
Definition tuple.h:91
constexpr bool empty() const noexcept
Is empty?
Definition sets.h:244
A dependent tuple type.
Definition tuple.h:20
Picks the aspect of a Meet [value](Pick::value) by its [type](Def::type).
Definition lattice.h:93
Data constructor for a Sigma.
Definition tuple.h:68
A singleton wraps a type into a higher-order type.
Definition lattice.h:180
const Def * rewrite_mut(Def *) final
Definition rewrite.cpp:193
const Def * rewrite(const Def *) final
Definition rewrite.cpp:182
A variable introduced by a binder (mutable).
Definition def.h:719
const Def * insert(const Def *d, const Def *i, const Def *val)
Definition world.cpp:443
const Def * meet(Defs ops)
Definition world.h:522
const Def * uinc(const Def *op, level_t offset=1)
Definition world.cpp:118
const Lit * lit(const Def *type, u64 val)
Definition world.cpp:526
const Type * type(const Def *level)
Definition world.cpp:108
const Def * sigma(Defs ops)
Definition world.cpp:283
const Def * app(const Def *callee, const Def *arg)
Definition world.cpp:201
const Def * match(Defs)
Definition world.cpp:614
const Pi * pi(const Def *dom, const Def *codom, bool implicit=false)
Definition world.h:306
const Def * seq(bool term, const Def *arity, const Def *body)
Definition world.cpp:492
const Univ * univ()
Definition world.h:255
const Def * bot(const Def *type)
Definition world.h:514
const Idx * type_idx()
Definition world.h:540
const Reform * reform(const Def *meta_type)
Definition world.h:362
const Nat * type_nat()
Definition world.h:539
const Lam * lam(const Pi *pi, Lam::Filter f, const Def *body)
Definition world.h:334
const Def * tuple(Defs ops)
Definition world.cpp:293
const Def * inj(const Def *type, const Def *value)
Definition world.cpp:599
const Axm * axm(NormalizeFn n, u8 curry, u8 trip, const Def *type, plugin_t p, tag_t t, sub_t s)
Definition world.h:288
const Def * extract(const Def *d, const Def *i)
Definition world.cpp:358
const Def * join(Defs ops)
Definition world.h:521
const Proxy * proxy(const Def *type, Defs ops, u32 index, u32 tag)
Definition world.h:271
const Def * var(Def *mut)
Definition world.cpp:180
const Def * uniq(const Def *inhabitant)
Definition world.cpp:650
const Def * prod(bool term, Defs ops)
Definition world.h:427
const Def * umax(Defs)
Definition world.cpp:138
const Def * merge(const Def *type, Defs ops)
Definition world.cpp:581
const Def * top(const Def *type)
Definition world.h:515
auto & vars()
Definition world.h:592
const Def * split(const Def *type, const Def *value)
Definition world.cpp:607
const Rule * rule(const Reform *type, const Def *lhs, const Def *rhs, const Def *guard)
Definition world.h:364
const Def * rewire_mut(Def *)
Definition rewrite.cpp:248
const Def * lookup(const Def *old_def) final
Lookup old_def by searching in reverse through the stack of maps.
Definition rewrite.cpp:212
const Def * rewrite(const Def *) final
Definition rewrite.cpp:239
const Def * map(const Def *old_def, const Def *new_def) final
Definition rewrite.cpp:206
#define MIM_MUT_NODE(X)
Definition def.h:53
#define MIM_IMM_NODE(X)
Definition def.h:37
const Def * op(trait o, const Def *type)
Definition core.h:33
Definition ast.h:14
View< const Def * > Defs
Definition def.h:77
Vector< const Def * > DefVec
Definition def.h:78
TBound< true > Join
AKA union.
Definition lattice.h:174
TExt< true > Top
Definition lattice.h:172
TExt< false > Bot
Definition lattice.h:171
TBound< false > Meet
AKA intersection.
Definition lattice.h:173
@ Nat
Definition def.h:115
@ Pi
Definition def.h:115
@ Pack
Definition def.h:115
@ Reform
Definition def.h:115
@ Tuple
Definition def.h:115
#define CODE_MUT(N)
Definition rewrite.h:87
#define CODE_IMM(N)
Definition rewrite.h:86