MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
rewrite.h
Go to the documentation of this file.
1#pragma once
2
#include <cassert>
#include <deque>
#include <memory>
#include <ranges>
#include <utility>

#include "mim/check.h"
#include "mim/def.h"
#include "mim/lam.h"
#include "mim/lattice.h"
#include "mim/rule.h"
#include "mim/tuple.h"
11
12namespace mim {
13
14class World;
15
/// Recursively rebuilds part of a program **into** the provided World w.r.t.\ Rewriter::map.
/// This World may be different from the World we started with.
18class Rewriter {
19public:
20 /// @name Construction & Destruction
21 ///@{
22 Rewriter(std::unique_ptr<World>&& ptr)
23 : ptr_(std::move(ptr))
24 , world_(ptr_.get()) {
25 push(); // create root map
26 }
28 : world_(&world) {
29 push(); // create root map
30 }
31 virtual ~Rewriter() = default;
32
33 void reset(std::unique_ptr<World>&& ptr) {
34 ptr_ = std::move(ptr);
35 world_ = ptr_.get();
36 reset();
37 }
38 void reset() {
39 pop();
40 assert(old2news_.empty());
41 push();
42 }
43 ///@}
44
45 /// @name Getters
46 ///@{
47 World& world() { return *world_; }
48 ///@}
49
50 /// @name Push / Pop
51 ///@{
52 virtual void push() { old2news_.emplace_back(Def2Def{}); }
53 virtual void pop() { old2news_.pop_back(); }
54 ///@}
55
56 /// @name Map / Lookup
57 /// Map @p old_def to @p new_def and returns @p new_def.
58 ///@{
59 virtual const Def* map(const Def* old_def, const Def* new_def) { return old2news_.back()[old_def] = new_def; }
60
61 // clang-format off
62 const Def* map(const Def* old_def , Defs new_defs);
63 const Def* map(Defs old_defs, const Def* new_def );
64 const Def* map(Defs old_defs, Defs new_defs);
65 // clang-format on
66
67 /// Lookup `old_def` by searching in reverse through the stack of maps.
68 /// @returns `nullptr` if nothing was found.
69 virtual const Def* lookup(const Def* old_def) {
70 for (const auto& old2new : old2news_ | std::views::reverse)
71 if (auto i = old2new.find(old_def); i != old2new.end()) return i->second;
72 return nullptr;
73 }
74 ///@}
75
76 /// @name rewrite
77 /// Recursively rewrite old Def%s.
78 ///@{
79 virtual const Def* rewrite(const Def*);
80 virtual const Def* rewrite_imm(const Def*);
81 virtual const Def* rewrite_mut(Def*);
82 virtual const Def* rewrite_stub(Def*, Def*);
83 virtual DefVec rewrite(Defs);
84
85#define CODE_IMM(N) virtual const Def* rewrite_imm_##N(const N*);
86#define CODE_MUT(N) virtual const Def* rewrite_mut_##N(N*);
89#undef CODE_IMM
90#undef CODE_MUT
91
92 virtual const Def* rewrite_imm_Seq(const Seq* seq);
93 virtual const Def* rewrite_mut_Seq(Seq* seq);
94 ///@}
95
96 friend void swap(Rewriter& rw1, Rewriter& rw2) noexcept {
97 using std::swap;
98 swap(rw1.old2news_, rw2.old2news_);
99 // do NOT back pointers ptr_ and world_
100 }
101
102private:
103 std::unique_ptr<World> ptr_;
104 World* world_;
105
106protected:
107 std::deque<Def2Def> old2news_;
108};
109
110class VarRewriter : public Rewriter {
111public:
112 /// @name Construction
113 ///@{
116 VarRewriter(const Var* var, const Def* arg)
117 : Rewriter(arg->world()) {
118 add(var, arg);
119 }
120
121 // Add initial mapping from @pvar -> @p arg.
122 VarRewriter& add(const Var* var, const Def* arg) {
123 map(var, arg);
124 vars_.emplace_back(var);
125 return *this;
126 }
127 ///@}
128
129 /// @name push / pop
130 ///@{
131 void push() final { Rewriter::push(), vars_.emplace_back(Vars()); }
132 void pop() final { vars_.pop_back(), Rewriter::pop(); }
133 ///@}
134
135 /// @name rewrite
136 ///@{
137 const Def* rewrite(const Def*) final;
138 const Def* rewrite_mut(Def*) final;
139 ///@}
140
141 friend void swap(VarRewriter& vrw1, VarRewriter& vrw2) noexcept {
142 using std::swap;
143 swap(static_cast<Rewriter&>(vrw1), static_cast<Rewriter&>(vrw2));
144 swap(vrw1.vars_, vrw2.vars_);
145 }
146
147private:
148 bool has_intersection(const Def* old_def) {
149 for (const auto& vars : vars_ | std::views::reverse)
150 if (vars.has_intersection(old_def->free_vars())) return true;
151 return false;
152 }
153
154 Vector<Vars> vars_;
155};
156
157class Zonker : public Rewriter {
158public:
159 /// @name C'tor
160 ///@{
163 ///@}
164
165 /// @name Stack of Maps
166 ///@{
167 const Def* map(const Def* old_def, const Def* new_def) final;
168 const Def* lookup(const Def* old_def) final;
169 ///@}
170
171 /// @name rewrite
172 ///@{
173 const Def* rewrite(const Def*) final;
174 const Def* rewrite_mut(Def* mut) final { return map(mut, mut); }
175 const Def* rewire_mut(Def*);
176 ///@}
177
178 friend void swap(Zonker& z1, Zonker& z2) noexcept {
179 using std::swap;
180 swap(static_cast<Rewriter&>(z1), static_cast<Rewriter&>(z2));
181 }
182
183private:
184 const Def* get(const Def* old_def) {
185 auto& old2new = old2news_.back();
186 if (auto i = old2new.find(old_def); i != old2new.end()) return i->second;
187 return nullptr;
188 }
189};
190
191} // namespace mim
Base class for all Defs.
Definition def.h:251
Vars free_vars() const
Compute a global solution by transitively following mutables as well.
Definition def.cpp:337
virtual ~Rewriter()=default
friend void swap(Rewriter &rw1, Rewriter &rw2) noexcept
Definition rewrite.h:96
virtual const Def * rewrite_imm_Seq(const Seq *seq)
Definition rewrite.cpp:134
virtual const Def * rewrite_mut_Seq(Seq *seq)
Definition rewrite.cpp:140
World & world()
Definition rewrite.h:47
virtual const Def * rewrite_mut(Def *)
Definition rewrite.cpp:48
virtual void push()
Definition rewrite.h:52
virtual const Def * rewrite_stub(Def *, Def *)
Definition rewrite.cpp:165
virtual const Def * map(const Def *old_def, const Def *new_def)
Definition rewrite.h:59
virtual void pop()
Definition rewrite.h:53
void reset(std::unique_ptr< World > &&ptr)
Definition rewrite.h:33
Rewriter(World &world)
Definition rewrite.h:27
void reset()
Definition rewrite.h:38
std::deque< Def2Def > old2news_
Definition rewrite.h:107
virtual const Def * rewrite_imm(const Def *)
Definition rewrite.cpp:39
Rewriter(std::unique_ptr< World > &&ptr)
Definition rewrite.h:22
virtual const Def * rewrite(const Def *)
Definition rewrite.cpp:27
virtual const Def * lookup(const Def *old_def)
Lookup old_def by searching in reverse through the stack of maps.
Definition rewrite.h:69
VarRewriter & add(const Var *var, const Def *arg)
Definition rewrite.h:122
void pop() final
Definition rewrite.h:132
const Def * rewrite_mut(Def *) final
Definition rewrite.cpp:192
void push() final
Definition rewrite.h:131
VarRewriter(World &world)
Definition rewrite.h:114
friend void swap(VarRewriter &vrw1, VarRewriter &vrw2) noexcept
Definition rewrite.h:141
const Def * rewrite(const Def *) final
Definition rewrite.cpp:181
VarRewriter(const Var *var, const Def *arg)
Definition rewrite.h:116
A variable introduced by a binder (mutable).
Definition def.h:700
The World represents the whole program and manages creation of MimIR nodes (Defs).
Definition world.h:32
const Def * rewire_mut(Def *)
Definition rewrite.cpp:247
friend void swap(Zonker &z1, Zonker &z2) noexcept
Definition rewrite.h:178
const Def * lookup(const Def *old_def) final
Lookup old_def by searching in reverse through the stack of maps.
Definition rewrite.cpp:211
const Def * rewrite(const Def *) final
Definition rewrite.cpp:238
const Def * map(const Def *old_def, const Def *new_def) final
Definition rewrite.cpp:205
Zonker(World &world)
Definition rewrite.h:161
const Def * rewrite_mut(Def *mut) final
Definition rewrite.h:174
#define MIM_MUT_NODE(m)
Definition def.h:52
#define MIM_IMM_NODE(m)
Definition def.h:36
Definition ast.h:14
View< const Def * > Defs
Definition def.h:76
DefMap< const Def * > Def2Def
Definition def.h:75
Vector< const Def * > DefVec
Definition def.h:77
constexpr decltype(auto) get(Span< T, N > span) noexcept
Definition span.h:115
Sets< const Var >::Set Vars
Definition def.h:97
Vector(I, I, A=A()) -> Vector< typename std::iterator_traits< I >::value_type, Default_Inlined_Size< typename std::iterator_traits< I >::value_type >, A >
Definition span.h:122
#define CODE_MUT(N)
Definition rewrite.h:86
#define CODE_IMM(N)
Definition rewrite.h:85