MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
rewrite.cpp
Go to the documentation of this file.
1#include "mim/rewrite.h"
2
3#include <absl/container/fixed_array.h>
4
5#include "mim/check.h"
6#include "mim/world.h"
7
8// Don't use fancy C++-lambdas; it's way too annoying stepping through them in a debugger.
9
10namespace mim {
11
12const Def* Rewriter::rewrite(const Def* old_def) {
13 if (old_def->isa<Univ>()) return world().univ();
14 if (auto new_def = lookup(old_def)) return new_def;
15
16 // clang-format off
17 if (auto arr = old_def->isa<Arr >()) return rewrite_arr (arr );
18 if (auto pack = old_def->isa<Pack >()) return rewrite_pack (pack );
19 if (auto extract = old_def->isa<Extract >()) return rewrite_extract(extract);
20 if (auto hole = old_def->isa_mut<Hole>()) return rewrite_hole (hole );
21 // clang-format on
22
23 if (auto old_mut = old_def->isa_mut()) return rewrite_mut(old_mut);
24 return map(old_def, rewrite_imm(old_def));
25}
26
27const Def* Rewriter::rewrite_imm(const Def* old_def) {
28 auto new_type = old_def->isa<Type>() ? nullptr : rewrite(old_def->type());
29 auto size = old_def->num_ops();
30 auto new_ops = absl::FixedArray<const Def*>(size);
31 for (size_t i = 0; i != size; ++i) new_ops[i] = rewrite(old_def->op(i));
32 return old_def->rebuild(world(), new_type, new_ops);
33}
34
35const Def* Rewriter::rewrite_mut(Def* old_mut) {
36 auto new_type = rewrite(old_mut->type());
37 auto new_mut = old_mut->stub(world(), new_type);
38 map(old_mut, new_mut);
39
40 if (old_mut->is_set()) {
41 for (size_t i = 0, e = old_mut->num_ops(); i != e; ++i) new_mut->set(i, rewrite(old_mut->op(i)));
42 if (auto new_imm = new_mut->immutabilize()) return map(old_mut, new_imm);
43 }
44
45 return new_mut;
46}
47
48const Def* Rewriter::rewrite_seq(const Seq* seq) {
49 if (!seq->is_set()) {
50 auto new_seq = seq->as_mut<Seq>()->stub(world(), rewrite(seq->type()));
51 return map(seq, new_seq);
52 }
53
54 auto new_shape = rewrite(seq->shape());
55
56 if (auto l = Lit::isa(new_shape); l && *l <= world().flags().scalarize_threshold) {
57 auto new_ops = absl::FixedArray<const Def*>(*l);
58 for (size_t i = 0, e = *l; i != e; ++i) {
59 if (auto var = seq->has_var()) {
60 push();
61 map(var, world().lit_idx(e, i));
62 new_ops[i] = rewrite(seq->body());
63 pop();
64 } else {
65 new_ops[i] = rewrite(seq->body());
66 }
67 }
68 return map(seq, seq->prod(world(), new_ops));
69 }
70
71 if (!seq->has_var()) return map(seq, seq->rebuild(world(), new_shape, rewrite(seq->body())));
72 return rewrite_mut(seq->as_mut());
73}
74
76 auto new_index = rewrite(ex->index());
77 if (auto index = Lit::isa(new_index)) {
78 if (auto tuple = ex->tuple()->isa<Tuple>()) return map(ex, rewrite(tuple->op(*index)));
79 if (auto pack = ex->tuple()->isa_imm<Pack>(); pack && pack->shape()->is_closed())
80 return map(ex, rewrite(pack->body()));
81 }
82
83 auto new_tuple = rewrite(ex->tuple());
84 return map(ex, world().extract(new_tuple, new_index)->set(ex->dbg()));
85}
86
88 auto [last, op] = hole->find();
89 if (op) return rewrite(op);
90 return rewrite_mut(last);
91}
92
93} // namespace mim
A (possibly parameterized) Array.
Definition tuple.h:100
Base class for all Defs.
Definition def.h:203
bool is_set() const
Yields true if empty or the last op is set.
Definition def.cpp:268
T * as_mut() const
Asserts that this is a mutable, casts constness away and performs a static_cast to T.
Definition def.h:442
T * isa_mut() const
If this is mutable, it will cast constness away and perform a dynamic_cast to T.
Definition def.h:434
const Def * op(size_t i) const noexcept
Definition def.h:269
const Def * type() const noexcept
Yields the "raw" type of this Def (maybe nullptr).
Definition def.h:247
const Def * rebuild(World &w, const Def *type, Defs ops) const
Rebuilds this Def in World `w` with `type` as its new type and `ops` as its new operands.
Definition def.h:496
const T * isa_imm() const
Definition def.h:429
bool is_closed() const
Has no free_vars()?
Definition def.cpp:356
Dbg dbg() const
Definition def.h:453
Def * stub(World &w, const Def *type)
Definition def.h:492
const Var * has_var()
Returns a non-null pointer only if the Var of this mutable has ever been created.
Definition def.h:388
constexpr size_t num_ops() const noexcept
Definition def.h:270
Extracts from a Sigma or Array-typed Extract::tuple the element at position Extract::index.
Definition tuple.h:192
const Def * tuple() const
Definition tuple.h:202
const Def * index() const
Definition tuple.h:203
This node is a hole in the IR that is inferred by its context later on.
Definition check.h:10
std::pair< Hole *, const Def * > find()
Transitively walks up Holes until the last one while path-compressing everything.
Definition check.cpp:70
static std::optional< T > isa(const Def *def)
Definition def.h:733
A (possibly parameterized) Tuple.
Definition tuple.h:150
const Def * shape() const final
Definition tuple.cpp:48
virtual const Def * rewrite_hole(Hole *)
Definition rewrite.cpp:87
void pop()
Definition rewrite.h:23
const Def * lookup(const Def *old_def)
Looks up old_def by searching in reverse through the stack of maps.
Definition rewrite.h:31
World & world()
Definition rewrite.h:18
virtual const Def * rewrite_mut(Def *)
Definition rewrite.cpp:35
virtual const Def * rewrite_pack(const Pack *pack)
Definition rewrite.h:45
const Def * map(const Def *old_def, const Def *new_def)
Maps old_def to new_def and returns new_def.
Definition rewrite.h:27
void push()
Definition rewrite.h:22
virtual const Def * rewrite_imm(const Def *)
Definition rewrite.cpp:27
virtual const Def * rewrite_seq(const Seq *)
Definition rewrite.cpp:48
virtual const Def * rewrite(const Def *)
Definition rewrite.cpp:12
virtual const Def * rewrite_extract(const Extract *)
Definition rewrite.cpp:75
virtual const Def * rewrite_arr(const Arr *arr)
Definition rewrite.h:44
Base class for Arr and Pack.
Definition tuple.h:78
virtual const Def * prod(World &w, Defs) const =0
Creates either a Tuple or Sigma.
virtual const Def * shape() const =0
const Def * body() const
Definition tuple.h:86
virtual const Def * rebuild(World &, const Def *shape, const Def *body) const =0
Data constructor for a Sigma.
Definition tuple.h:62
const Univ * univ()
Definition world.h:200
Flags & flags()
Retrieves the compile Flags.
Definition world.cpp:72
Definition ast.h:14
uint64_t scalarize_threshold
Definition flags.h:13