MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
rewrite.h
Go to the documentation of this file.
1#pragma once
2
3#include <memory>
4
5#include "mim/world.h"
6
7namespace mim {
8
9/// Recurseivly rebuilds part of a program **into** the provided World w.r.t.\ Rewriter::map.
10/// This World may be different than the World we started with.
11class Rewriter {
12public:
13 Rewriter(std::unique_ptr<World>&& ptr)
14 : ptr_(std::move(ptr))
15 , world_(ptr_.get()) {
16 push(); // create root map
17 }
18
20 : world_(&world) {
21 push(); // create root map
22 }
23
24 void reset(std::unique_ptr<World>&& ptr) {
25 ptr_ = std::move(ptr);
26 world_ = ptr_.get();
27 reset();
28 }
29
30 void reset() {
31 pop();
32 assert(old2news_.empty());
33 push();
34 }
35
36 World& world() { return *world_; }
37
38 /// @name Stack of Maps
39 ///@{
40 virtual void push() { old2news_.emplace_back(Def2Def{}); }
41 virtual void pop() { old2news_.pop_back(); }
42
43 /// Map @p old_def to @p new_def and returns @p new_def.
44 /// @returns `new_def`
45 // clang-format off
46 const Def* map(const Def* old_def , const Def* new_def ) { return old2news_.back()[ old_def ] = new_def ; }
47 const Def* map(const Def* old_def , Defs new_defs) { return old2news_.back()[ old_def ] = world().tuple(new_defs); }
48 const Def* map(Defs old_defs, const Def* new_def ) { return old2news_.back()[world().tuple(old_defs)] = new_def ; }
49 const Def* map(Defs old_defs, Defs new_defs) { return old2news_.back()[world().tuple(old_defs)] = world().tuple(new_defs); }
50 // clang-format on
51
52 /// Lookup `old_def` by searching in reverse through the stack of maps.
53 /// @returns `nullptr` if nothing was found.
54 const Def* lookup(const Def* old_def) {
55 for (const auto& old2new : old2news_ | std::views::reverse)
56 if (auto i = old2new.find(old_def); i != old2new.end()) return i->second;
57 return nullptr;
58 }
59 ///@}
60
61 /// @name rewrite
62 /// Recursively rewrite old Def%s.
63 ///@{
64 virtual const Def* rewrite(const Def*);
65 virtual const Def* rewrite_imm(const Def*);
66 virtual const Def* rewrite_mut(Def*);
67 virtual const Def* rewrite_stub(Def*, Def*);
68 virtual DefVec rewrite(Defs);
69
70#define CODE_IMM(N) virtual const Def* rewrite_imm_##N(const N*);
71#define CODE_MUT(N) virtual const Def* rewrite_mut_##N(N*);
74#undef CODE_IMM
75#undef CODE_MUT
76
77 virtual const Def* rewrite_imm_Seq(const Seq* seq);
78 virtual const Def* rewrite_mut_Seq(Seq* seq);
79 ///@}
80
81private:
82 std::unique_ptr<World> ptr_;
83 World* world_;
84 std::deque<Def2Def> old2news_;
85};
86
87class VarRewriter : public Rewriter {
88public:
91
92 VarRewriter(const Var* var, const Def* arg)
93 : Rewriter(arg->world()) {
94 add(var, arg);
95 }
96
97 void add(const Var* var, const Def* arg) {
98 map(var, arg);
99 vars_.emplace_back(var);
100 }
101
102 void push() final {
104 vars_.emplace_back(Vars());
105 }
106
107 void pop() final {
108 vars_.pop_back();
110 }
111
112 const Def* rewrite(const Def* old_def) final {
113 if (auto new_def = lookup(old_def)) return new_def;
114
115 if (auto old_mut = old_def->isa_mut())
116 return has_intersection(old_mut) ? rewrite_mut(old_mut)->set(old_mut->dbg()) : old_mut;
117
118 if (old_def->local_vars().empty() && old_def->local_muts().empty()) return old_def; // safe to skip
119
120 return has_intersection(old_def) ? rewrite_imm(old_def)->set(old_def->dbg()) : old_def;
121 }
122
123 const Def* rewrite_mut(Def* mut) final {
124 if (auto var = mut->has_var()) {
125 auto& vars = vars_.back();
126 vars = world().vars().insert(vars, var);
127 }
128
129 return Rewriter::rewrite_mut(mut);
130 }
131
132private:
133 bool has_intersection(const Def* old_def) {
134 for (const auto& vars : vars_ | std::views::reverse)
135 if (vars.has_intersection(old_def->free_vars())) return true;
136 return false;
137 }
138
139 Vector<Vars> vars_;
140};
141
142} // namespace mim
Base class for all Defs.
Definition def.h:251
Def * set(size_t i, const Def *)
Successively set from left to right.
Definition def.cpp:266
Vars free_vars() const
Compute a global solution by transitively following mutables as well.
Definition def.cpp:334
const Def * map(const Def *old_def, Defs new_defs)
Definition rewrite.h:47
const Def * lookup(const Def *old_def)
Lookup old_def by searching in reverse through the stack of maps.
Definition rewrite.h:54
virtual const Def * rewrite_imm_Seq(const Seq *seq)
Definition rewrite.cpp:121
virtual const Def * rewrite_mut_Seq(Seq *seq)
Definition rewrite.cpp:127
World & world()
Definition rewrite.h:36
virtual const Def * rewrite_mut(Def *)
Definition rewrite.cpp:35
virtual void push()
Definition rewrite.h:40
virtual const Def * rewrite_stub(Def *, Def *)
Definition rewrite.cpp:152
const Def * map(const Def *old_def, const Def *new_def)
Maps old_def to new_def and returns new_def.
Definition rewrite.h:46
virtual void pop()
Definition rewrite.h:41
void reset(std::unique_ptr< World > &&ptr)
Definition rewrite.h:24
Rewriter(World &world)
Definition rewrite.h:19
void reset()
Definition rewrite.h:30
const Def * map(Defs old_defs, const Def *new_def)
Definition rewrite.h:48
virtual const Def * rewrite_imm(const Def *)
Definition rewrite.cpp:26
const Def * map(Defs old_defs, Defs new_defs)
Definition rewrite.h:49
Rewriter(std::unique_ptr< World > &&ptr)
Definition rewrite.h:13
virtual const Def * rewrite(const Def *)
Definition rewrite.cpp:14
void add(const Var *var, const Def *arg)
Definition rewrite.h:97
void pop() final
Definition rewrite.h:107
const Def * rewrite(const Def *old_def) final
Definition rewrite.h:112
const Def * rewrite_mut(Def *mut) final
Definition rewrite.h:123
void push() final
Definition rewrite.h:102
VarRewriter(World &world)
Definition rewrite.h:89
VarRewriter(const Var *var, const Def *arg)
Definition rewrite.h:92
The World represents the whole program and manages creation of MimIR nodes (Defs).
Definition world.h:36
const Def * tuple(Defs ops)
Definition world.cpp:288
auto & vars()
Definition world.h:554
#define MIM_MUT_NODE(m)
Definition def.h:52
#define MIM_IMM_NODE(m)
Definition def.h:36
Definition ast.h:14
View< const Def * > Defs
Definition def.h:76
DefMap< const Def * > Def2Def
Definition def.h:75
Vector< const Def * > DefVec
Definition def.h:77
constexpr decltype(auto) get(Span< T, N > span) noexcept
Definition span.h:115
Sets< const Var >::Set Vars
Definition def.h:97
Vector(I, I, A=A()) -> Vector< typename std::iterator_traits< I >::value_type, Default_Inlined_Size< typename std::iterator_traits< I >::value_type >, A >
Definition span.h:122
#define CODE_MUT(N)
Definition rewrite.h:71
#define CODE_IMM(N)
Definition rewrite.h:70