MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
rewrite.cpp
Go to the documentation of this file.
1#include "mim/rewrite.h"
2
3#include <absl/container/fixed_array.h>
4
5#include "mim/world.h"
6
7#include "fe/assert.h"
8
9// Don't use fancy C++-lambdas; it's way too annoying stepping through them in a debugger.
10
11namespace mim {
12
13/*
14 * Rewriter
15 */
16
17const Def* Rewriter::map(const Def* old_def, Defs new_defs) {
18 return old2news_.back()[old_def] = world().tuple(new_defs);
19}
20const Def* Rewriter::map(Defs old_defs, const Def* new_def) {
21 return old2news_.back()[world().tuple(old_defs)] = new_def;
22}
23const Def* Rewriter::map(Defs old_defs, Defs new_defs) {
24 return old2news_.back()[world().tuple(old_defs)] = world().tuple(new_defs);
25}
26
27const Def* Rewriter::rewrite(const Def* old_def) {
28 if (auto new_def = lookup(old_def)) return new_def;
29
30 auto new_def = old_def->isa_mut() ? rewrite_mut((Def*)old_def) : rewrite_imm(old_def);
31 return new_def->set(old_def->dbg());
32}
33
34// clang-format off
35#define CODE_MUT(N) case Node::N: new_def = rewrite_mut_##N(old_mut->as<N>()); break;
36#define CODE_IMM(N) case Node::N: new_def = rewrite_imm_##N(old_def->as<N>()); break;
37// clang-format on
38
39const Def* Rewriter::rewrite_imm(const Def* old_def) {
40 const Def* new_def;
41 switch (old_def->node()) {
43 default: fe::unreachable();
44 }
45 return map(old_def, new_def);
46}
47
48const Def* Rewriter::rewrite_mut(Def* old_mut) {
49 const Def* new_def;
50 switch (old_mut->node()) {
52 default: fe::unreachable();
53 }
54 return new_def;
55}
56
57#undef CODE_MUT
58#undef CODE_IMM
59
61 auto new_ops = DefVec(ops.size());
62 for (size_t i = 0, e = ops.size(); i != e; ++i)
63 new_ops[i] = rewrite(ops[i]);
64 return new_ops;
65}
66
67#ifndef DOXYGEN
68// clang-format off
69const Def* Rewriter::rewrite_imm_Idx (const Idx* ) { return world().type_idx(); }
70const Def* Rewriter::rewrite_imm_Nat (const Nat* ) { return world().type_nat(); }
71const Def* Rewriter::rewrite_imm_Univ (const Univ* ) { return world().univ(); }
72const Def* Rewriter::rewrite_imm_App (const App* d) { return world().app (rewrite(d->callee()), rewrite(d->arg())); }
73const Def* Rewriter::rewrite_imm_Inj (const Inj* d) { return world().inj (rewrite(d->type()), rewrite(d->value())); }
74const Def* Rewriter::rewrite_imm_Insert(const Insert* d) { return world().insert(rewrite(d->tuple()), rewrite(d->index()), rewrite(d->value())); }
75const Def* Rewriter::rewrite_imm_Lam (const Lam* d) { return world().lam (rewrite(d->type())->as<Pi>(), rewrite(d->filter()), rewrite(d->body())); }
76const Def* Rewriter::rewrite_imm_Lit (const Lit* d) { return world().lit (rewrite(d->type()), d->get()); }
77const Def* Rewriter::rewrite_imm_Match (const Match* d) { return world().match (rewrite(d->ops())); }
78const Def* Rewriter::rewrite_imm_Merge (const Merge* d) { return world().merge (rewrite(d->type()), rewrite(d->ops())); }
79const Def* Rewriter::rewrite_imm_Pi (const Pi* d) { return world().pi (rewrite(d->dom()), rewrite(d->codom()), d->is_implicit()); }
80const Def* Rewriter::rewrite_imm_Proxy (const Proxy* d) { return world().proxy (rewrite(d->type()), rewrite(d->ops()), d->pass(), d->tag()); }
81const Def* Rewriter::rewrite_imm_Rule (const Rule* d) { return world().rule (rewrite(d->type()), rewrite(d->lhs()), rewrite(d->rhs()), rewrite(d->guard())); }
82const Def* Rewriter::rewrite_imm_Sigma (const Sigma* d) { return world().sigma (rewrite(d->ops())); }
83const Def* Rewriter::rewrite_imm_Split (const Split* d) { return world().split (rewrite(d->type()), rewrite(d->value())); }
84const Def* Rewriter::rewrite_imm_Tuple (const Tuple* d) { return world().tuple (rewrite(d->type()), rewrite(d->ops())); }
85const Def* Rewriter::rewrite_imm_Type (const Type* d) { return world().type (rewrite(d->level())); }
86const Def* Rewriter::rewrite_imm_UInc (const UInc* d) { return world().uinc (rewrite(d->op()), d->offset()); }
87const Def* Rewriter::rewrite_imm_UMax (const UMax* d) { return world().umax (rewrite(d->ops())); }
88const Def* Rewriter::rewrite_imm_Uniq (const Uniq* d) { return world().uniq (rewrite(d->op())); }
89const Def* Rewriter::rewrite_imm_Var (const Var* d) { return world().var (rewrite(d->mut())->as_mut()); }
90const Def* Rewriter::rewrite_imm_Top (const Top* d) { return world().top (rewrite(d->type())); }
91const Def* Rewriter::rewrite_imm_Bot (const Bot* d) { return world().bot (rewrite(d->type())); }
92const Def* Rewriter::rewrite_imm_Meet (const Meet* d) { return world().meet (rewrite(d->ops())); }
93const Def* Rewriter::rewrite_imm_Join (const Join* d) { return world().join (rewrite(d->ops())); }
94
95const Def* Rewriter::rewrite_imm_Arr (const Arr* d) { return rewrite_imm_Seq(d); }
96const Def* Rewriter::rewrite_imm_Pack(const Pack* d) { return rewrite_imm_Seq(d); }
97const Def* Rewriter::rewrite_mut_Arr ( Arr* d) { return rewrite_mut_Seq(d); }
98const Def* Rewriter::rewrite_mut_Pack( Pack* d) { return rewrite_mut_Seq(d); }
99
100const Def* Rewriter::rewrite_mut_Pi (Pi* d) { return rewrite_stub(d, world().mut_pi (rewrite(d->type()), d->is_implicit())); }
101const Def* Rewriter::rewrite_mut_Lam (Lam* d) { return rewrite_stub(d, world().mut_lam (rewrite(d->type())->as<Pi>())); }
102const Def* Rewriter::rewrite_mut_Rule (Rule* d) { return rewrite_stub(d, world().mut_rule (rewrite(d->type())->as<Reform>())); }
103const Def* Rewriter::rewrite_mut_Sigma (Sigma* d) { return rewrite_stub(d, world().mut_sigma(rewrite(d->type()), d->num_ops())); }
104const Def* Rewriter::rewrite_mut_Global(Global* d) { return rewrite_stub(d, world().global (rewrite(d->type()), d->is_mutable())); }
105// clang-format on
106
107const Def* Rewriter::rewrite_imm_Axm(const Axm* a) {
108 if (&a->world() != &world()) {
109 auto type = rewrite(a->type());
110 return world().axm(a->normalizer(), a->curry(), a->trip(), type, a->plugin(), a->tag(), a->sub());
111 }
112 return a;
113}
114
115const Def* Rewriter::rewrite_imm_Extract(const Extract* ex) {
116 auto new_index = rewrite(ex->index());
117 if (auto index = Lit::isa(new_index)) {
118 if (auto tuple = ex->tuple()->isa<Tuple>()) return map(ex, rewrite(tuple->op(*index)));
119 if (auto pack = ex->tuple()->isa_imm<Pack>(); pack && pack->arity()->is_closed())
120 return map(ex, rewrite(pack->body()));
121 }
122
123 auto new_tuple = rewrite(ex->tuple());
124 return world().extract(new_tuple, new_index);
125}
126
127const Def* Rewriter::rewrite_mut_Hole(Hole* hole) {
128 auto [last, op] = hole->find();
129 return op ? rewrite(op) : rewrite_stub(last, world().mut_hole(rewrite(last->type())));
130}
131
132#endif
133
135 auto new_arity = rewrite(seq->arity());
136 if (auto l = Lit::isa(new_arity); l && *l == 0) return world().prod(seq->is_intro());
137 return world().seq(seq->is_intro(), new_arity, rewrite(seq->body()));
138}
139
141 if (!seq->is_set()) {
142 auto new_seq = seq->as_mut<Seq>()->stub(world(), rewrite(seq->type()));
143 return map(seq, new_seq);
144 }
145
146 auto new_arity = rewrite(seq->arity());
147 auto l = Lit::isa(new_arity);
148 if (l && *l == 0) return world().prod(seq->is_intro());
149
150 if (auto var = seq->has_var(); var && l && *l <= world().flags().scalarize_threshold) {
151 auto new_ops = absl::FixedArray<const Def*>(*l);
152 for (size_t i = 0, e = *l; i != e; ++i) {
153 push();
154 map(var, world().lit_idx(e, i));
155 new_ops[i] = rewrite(seq->body());
156 pop();
157 }
158 return map(seq, world().prod(seq->is_intro(), new_ops));
159 }
160
161 if (!seq->has_var()) return map(seq, world().seq(seq->is_intro(), new_arity, rewrite(seq->body())));
162 return rewrite_stub(seq->as_mut(), world().mut_seq(seq->is_term(), rewrite(seq->type())));
163}
164
165const Def* Rewriter::rewrite_stub(Def* old_mut, Def* new_mut) {
166 map(old_mut, new_mut);
167
168 if (old_mut->is_set()) {
169 for (size_t i = 0, e = old_mut->num_ops(); i != e; ++i)
170 new_mut->set(i, rewrite(old_mut->op(i)));
171 if (auto new_imm = new_mut->immutabilize()) return map(old_mut, new_imm);
172 }
173
174 return new_mut;
175}
176
177/*
178 * VarRewriter
179 */
180
181const Def* VarRewriter::rewrite(const Def* old_def) {
182 if (auto new_def = lookup(old_def)) return new_def;
183
184 if (auto old_mut = old_def->isa_mut())
185 return has_intersection(old_mut) ? rewrite_mut(old_mut)->set(old_mut->dbg()) : old_mut;
186
187 if (old_def->local_vars().empty() && old_def->local_muts().empty()) return old_def; // safe to skip
188
189 return has_intersection(old_def) ? rewrite_imm(old_def)->set(old_def->dbg()) : old_def;
190}
191
193 if (auto var = mut->has_var()) {
194 auto& vars = vars_.back();
195 vars = world().vars().insert(vars, var);
196 }
197
198 return Rewriter::rewrite_mut(mut);
199}
200
201/*
202 * Zonker
203 */
204
205const Def* Zonker::map(const Def* old_def, const Def* new_def) {
206 auto repr = lookup(new_def); // always normalize new_def to its representative
207 if (!repr) repr = new_def;
208 return old2news_.back()[old_def] = repr;
209}
210
211const Def* Zonker::lookup(const Def* old_def) {
212 for (auto& old2new : old2news_ | std::views::reverse) {
213 const Def* repr;
214 auto path = DefVec();
215 while (true) {
216 repr = get(old_def);
217
218 if (repr == nullptr) break;
219
220 path.emplace_back(repr);
221 if (repr == old_def) break; // explicit self-map
222
223 old_def = repr;
224 }
225
226 if (path.empty()) continue;
227
228 // path compression: flatten all visited nodes
229 for (auto def : path)
230 old2new[def] = repr;
231
232 return repr;
233 }
234
235 return nullptr;
236}
237
238const Def* Zonker::rewrite(const Def* def) {
239 if (auto hole = def->isa_mut<Hole>()) {
240 auto [last, op] = hole->find();
241 def = op ? op : last;
242 }
243
244 return def->needs_zonk() ? Rewriter::rewrite(def) : def;
245}
246
248 map(mut, mut);
249
250 auto old_type = mut->type();
251 auto old_ops = absl::FixedArray<const Def*>(mut->ops().begin(), mut->ops().end());
252
253 mut->unset()->set_type(rewrite(old_type));
254
255 for (size_t i = 0, e = mut->num_ops(); i != e; ++i)
256 mut->set(i, rewrite(old_ops[i]));
257
258 if (auto new_imm = mut->immutabilize()) return map(mut, new_imm);
259
260 return mut;
261}
262
263} // namespace mim
A (possibly parameterized) Array.
Definition tuple.h:117
Definition axm.h:9
Base class for all Defs.
Definition def.h:251
bool is_set() const
Yields true if empty or the last op is set.
Definition def.cpp:295
constexpr Node node() const noexcept
Definition def.h:274
Def * set(size_t i, const Def *)
Successively set from left to right.
Definition def.cpp:263
T * as_mut() const
Asserts that this is a mutable, casts constness away and performs a static_cast to T.
Definition def.h:494
Def * set_type(const Def *)
Update type.
Definition def.cpp:280
bool is_intro() const noexcept
Definition def.h:284
constexpr auto ops() const noexcept
Definition def.h:305
Vars local_vars() const
Vars reachable by following immutable deps().
Definition def.cpp:348
T * isa_mut() const
If this is mutable, it will cast constness away and perform a dynamic_cast to T.
Definition def.h:485
bool is_term() const
Definition def.cpp:487
const Def * op(size_t i) const noexcept
Definition def.h:308
virtual const Def * immutabilize()
Tries to make an immutable from a mutable.
Definition def.h:556
Muts local_muts() const
Mutables reachable by following immutable deps(); mut->local_muts() is by definition the set { mut }...
Definition def.cpp:332
const Def * type() const noexcept
Yields the "raw" type of this Def (maybe nullptr).
Definition def.cpp:447
virtual const Def * arity() const
Definition def.cpp:553
Def * unset()
Unsets all Def::ops; works even if they are not set at all or only partially set.
Definition def.cpp:286
bool needs_zonk() const
Yields true if Def::local_muts() contains a Hole that is set.
Definition check.cpp:12
Dbg dbg() const
Definition def.h:505
const Var * has_var()
Returns a non-null pointer only if the Var of this mutable has ever been created.
Definition def.h:433
constexpr size_t num_ops() const noexcept
Definition def.h:309
Extracts from a Sigma or Array-typed Extract::tuple the element at position Extract::index.
Definition tuple.h:206
This node is a hole in the IR that is inferred by its context later on.
Definition check.h:14
A built-in constant of type Nat -> *.
Definition def.h:859
Constructs a Join value.
Definition lattice.h:70
Creates a new Tuple / Pack by inserting Insert::value at position Insert::index into Insert::tuple.
Definition tuple.h:233
A function.
Definition lam.h:111
static std::optional< T > isa(const Def *def)
Definition def.h:824
Scrutinize Match::scrutinee() and dispatch to Match::arms.
Definition lattice.h:116
Constructs a Meet value.
Definition lattice.h:53
A (possibly parameterized) Tuple.
Definition tuple.h:166
A dependent function type.
Definition lam.h:15
virtual const Def * rewrite_imm_Seq(const Seq *seq)
Definition rewrite.cpp:134
virtual const Def * rewrite_mut_Seq(Seq *seq)
Definition rewrite.cpp:140
World & world()
Definition rewrite.h:47
virtual const Def * rewrite_mut(Def *)
Definition rewrite.cpp:48
virtual void push()
Definition rewrite.h:52
virtual const Def * rewrite_stub(Def *, Def *)
Definition rewrite.cpp:165
virtual const Def * map(const Def *old_def, const Def *new_def)
Definition rewrite.h:59
virtual void pop()
Definition rewrite.h:53
std::deque< Def2Def > old2news_
Definition rewrite.h:107
virtual const Def * rewrite_imm(const Def *)
Definition rewrite.cpp:39
virtual const Def * rewrite(const Def *)
Definition rewrite.cpp:27
virtual const Def * lookup(const Def *old_def)
Looks up old_def by searching in reverse through the stack of maps.
Definition rewrite.h:69
A rewrite rule.
Definition rule.h:38
Base class for Arr and Pack.
Definition tuple.h:84
const Def * body() const
Definition tuple.h:91
constexpr bool empty() const noexcept
Is empty?
Definition sets.h:240
A dependent tuple type.
Definition tuple.h:20
Picks the aspect of a Meet [value](Pick::value) by its [type](Def::type).
Definition lattice.h:93
Data constructor for a Sigma.
Definition tuple.h:68
A singleton wraps a type into a higher order type.
Definition lattice.h:180
const Def * rewrite_mut(Def *) final
Definition rewrite.cpp:192
const Def * rewrite(const Def *) final
Definition rewrite.cpp:181
A variable introduced by a binder (mutable).
Definition def.h:700
const Def * insert(const Def *d, const Def *i, const Def *val)
Definition world.cpp:424
const Def * meet(Defs ops)
Definition world.h:481
const Def * uinc(const Def *op, level_t offset=1)
Definition world.cpp:116
const Lit * lit(const Def *type, u64 val)
Definition world.cpp:507
const Type * type(const Def *level)
Definition world.cpp:107
const Def * sigma(Defs ops)
Definition world.cpp:277
const Def * app(const Def *callee, const Def *arg)
Definition world.cpp:195
const Def * match(Defs)
Definition world.cpp:594
const Pi * pi(const Def *dom, const Def *codom, bool implicit=false)
Definition world.h:265
const Def * seq(bool term, const Def *arity, const Def *body)
Definition world.cpp:473
const Univ * univ()
Definition world.h:214
const Def * bot(const Def *type)
Definition world.h:473
const Idx * type_idx()
Definition world.h:499
const Nat * type_nat()
Definition world.h:498
const Lam * lam(const Pi *pi, Lam::Filter f, const Def *body)
Definition world.h:293
const Def * tuple(Defs ops)
Definition world.cpp:287
const Def * inj(const Def *type, const Def *value)
Definition world.cpp:579
const Axm * axm(NormalizeFn n, u8 curry, u8 trip, const Def *type, plugin_t p, tag_t t, sub_t s)
Definition world.h:247
const Def * extract(const Def *d, const Def *i)
Definition world.cpp:345
const Def * join(Defs ops)
Definition world.h:480
const Proxy * proxy(const Def *type, Defs ops, u32 index, u32 tag)
Definition world.h:230
const Def * var(Def *mut)
Definition world.cpp:178
const Def * uniq(const Def *inhabitant)
Definition world.cpp:630
const Def * prod(bool term, Defs ops)
Definition world.h:386
const Def * umax(Defs)
Definition world.cpp:136
const Def * merge(const Def *type, Defs ops)
Definition world.cpp:561
const Def * top(const Def *type)
Definition world.h:474
auto & vars()
Definition world.h:551
const Def * split(const Def *type, const Def *value)
Definition world.cpp:587
const Rule * rule(const Reform *type, const Def *lhs, const Def *rhs, const Def *guard)
Definition world.h:323
const Def * rewire_mut(Def *)
Definition rewrite.cpp:247
const Def * lookup(const Def *old_def) final
Looks up old_def by searching in reverse through the stack of maps.
Definition rewrite.cpp:211
const Def * rewrite(const Def *) final
Definition rewrite.cpp:238
const Def * map(const Def *old_def, const Def *new_def) final
Definition rewrite.cpp:205
#define MIM_MUT_NODE(m)
Definition def.h:52
#define MIM_IMM_NODE(m)
Definition def.h:36
const Def * op(trait o, const Def *type)
Definition core.h:33
Definition ast.h:14
View< const Def * > Defs
Definition def.h:76
Vector< const Def * > DefVec
Definition def.h:77
TBound< true > Join
AKA union.
Definition lattice.h:174
TExt< true > Top
Definition lattice.h:172
TExt< false > Bot
Definition lattice.h:171
TBound< false > Meet
AKA intersection.
Definition lattice.h:173
@ Nat
Definition def.h:114
@ Pi
Definition def.h:114
@ Pack
Definition def.h:114
@ Reform
Definition def.h:114
@ Tuple
Definition def.h:114
#define CODE_MUT(N)
Definition rewrite.h:86
#define CODE_IMM(N)
Definition rewrite.h:85