MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
nest.cpp
Go to the documentation of this file.
1#include "mim/nest.h"
2
3#include "mim/world.h"
4
5namespace mim {
6
8 : world_(r->world())
9 , root_(make_node(r)) {
10 populate();
11}
12
14 : world_(muts.front()->world())
15 , root_(make_node(nullptr)) {
16 for (auto mut : muts)
17 make_node(mut, root_);
18 populate();
19}
20
22 : world_(world)
23 , root_(make_node(nullptr)) {
24 world.for_each(false, [this](Def* mut) { make_node(mut, root_); });
25 populate();
26}
27
28void Nest::populate() {
29 std::queue<Node*> queue;
30
31 if (root()->mut())
32 queue.push(root_);
33 else
34 for (auto child : root_->children().nodes())
35 queue.push(child);
36
37 while (!queue.empty()) {
38 auto curr_node = pop(queue);
39 for (auto op : curr_node->mut()->deps()) {
40 for (auto local_mut : op->local_muts()) {
41 if ((*this)[local_mut] || !contains(local_mut)) continue;
42
43 if (curr_node->level() < local_mut->free_vars().size()) {
44 for (auto node = curr_node;; node = node->inest_) {
45 if (auto var = node->mut()->has_var()) {
46 if (local_mut->free_vars().contains(var)) {
47 queue.push(make_node(local_mut, node));
48 break;
49 }
50 }
51 }
52 } else {
53 uint32_t max = 0;
54 auto inest = root_;
55 for (auto var : local_mut->free_vars()) {
56 if (auto node = (*this)[var->mut()]; node && node->level() > max) {
57 max = node->level();
58 inest = node;
59 }
60 }
61 queue.push(make_node(local_mut, inest));
62 }
63 }
64 }
65 }
66}
67
68Nest::Node* Nest::make_node(Def* mut, Node* inest) {
69 auto node = std::unique_ptr<Node>(new Node(*this, mut, inest)); // can't use make_unique - c'tor is private
70 auto res = node.get();
71 mut2node_.emplace(mut, std::move(node));
72 if (mut) {
73 if (auto var = mut->has_var()) vars_ = world().vars().insert(vars_, var);
74 }
75 return res;
76}
77
78void Nest::assign_postorder_numbers() const {
79 if (root()->postorder_number_.has_value()) return;
80
81 auto number = 0;
82
83 std::function<void(const Nest::Node*)> visit = [&](const Nest::Node* node) {
84 if (node->postorder_number_.has_value()) return; // already visited
85 node->postorder_number_ = 0; // mark in progress
86 for (auto op : node->mut()->deps()) {
87 for (auto mut : op->local_muts())
88 if (auto succ = node->nest()[mut]) visit(succ);
89 }
90 node->postorder_number_ = ++number;
91 };
92
93 if (root()->mut()) {
94 // not virtual root, visit
95 visit(root());
96 } else {
97 // virtual root, visit children
98 for (auto [_, node] : root()->children())
99 visit(node);
100 root()->postorder_number_ = ++number;
101 }
102}
103
104template<bool bootstrapping>
105const Nest::Node* Nest::lca(const Node* n, const Node* m) {
106 while (n != m) {
107 // Nest::lca is also used within with_dominance and should not call it recursively
108 if constexpr (!bootstrapping) {
109 n->calc_dominance();
110 m->calc_dominance();
111 }
112 if (n->postorder_number_ < m->postorder_number_)
113 n = n->idom_ ? n->idom_ : n->inest();
114 else if (m->postorder_number_ < n->postorder_number_)
115 m = m->idom_ ? m->idom_ : m->inest();
116 }
117
118 return n;
119}
120
121void Nest::calc_sibl_deps(Node* curr) const {
122 if (curr->mut()) {
123 for (auto op : curr->mut()->deps()) {
124 for (auto local_mut : op->local_muts()) {
125 if (auto local_node = const_cast<Nest&>(*this)[local_mut]) {
126 if (local_node == curr)
127 local_node->link(local_node);
128 else if (auto inest = local_node->inest()) {
129 if (auto curr_child = inest->curr_child) {
130 assert(inest->children().contains(curr_child->mut()));
131 curr_child->link(local_node);
132 }
133 }
134 }
135 }
136 }
137 }
138
139 for (auto child : curr->children().nodes()) {
140 curr->curr_child = child;
141 calc_sibl_deps(child);
142 curr->curr_child = nullptr;
143 }
144}
145
146void Nest::calc_SCCs(Node* curr) const {
147 curr->calc_SCCs();
148 for (auto [_, child] : curr->children()) {
149 child->loop_depth_ = child->is_recursive() ? curr->loop_depth() + 1 : curr->loop_depth();
150 calc_SCCs(child);
151 }
152}
153
154void Nest::Node::calc_SCCs() {
155 Stack stack;
156 for (int i = 0; auto& [_, node] : children())
157 if (node->idx_ == Unvisited) i = node->tarjan(i, this, stack);
158}
159
160uint32_t Nest::Node::tarjan(uint32_t i, Node* inest, Stack& stack) {
161 this->idx_ = this->low_ = i++;
162 this->on_stack_ = true;
163 stack.emplace(this);
164
165 for (auto dep : this->sibl_deps_.nodes_) {
166 if (dep->idx_ == Unvisited) i = dep->tarjan(i, inest, stack);
167 if (dep->on_stack_) this->low_ = std::min(this->low_, dep->low_);
168 }
169
170 if (this->idx_ == this->low_) {
171 inest_->topo_.emplace_front(std::make_unique<SCC>());
172 SCC* scc = inest_->topo_.front().get();
173 Node* node;
174 int num = 0;
175 do {
176 node = pop(stack);
177 node->on_stack_ = false;
178 node->recursive_ = true;
179 node->low_ = this->idx_;
180 ++num;
181
182 scc->emplace(node);
183 auto [_, ins] = inest_->SCCs_.emplace(node, scc);
184 assert_unused(ins);
185 } while (node != this);
186
187 if (num == 1 && !this->sibl_deps().contains(this)) this->recursive_ = false;
188 }
189
190 return i;
191}
192
193/// Calculates dominance using Cooper-Harvey-Kennedy algorithm
194/// from Cooper et al, "A Simple, Fast Dominance Algorithm".
195/// https://www.clear.rice.edu/comp512/Lectures/Papers/TR06-33870-Dom.pdf
196const Nest::Node* Nest::Node::calc_dominance() const {
197 if (idom_ || is_root() || !inest()->mut()) return this;
198 nest().assign_postorder_numbers();
199
200 if (!inest()->mut()) idom_ = inest();
201
202 // Holds all siblings in reverse post-order coming from the parent
203 absl::flat_hash_set<const Node*> visited;
205
206 // Initialize entry nodes directly referenced by the parent
207 for (auto op : inest()->mut()->deps()) {
208 for (auto local_mut : op->local_muts())
209 if (auto node = nest()[local_mut]; node && node->inest() == inest()) node->idom_ = inest();
210 }
211
212 std::function<void(const Node*)> visit = [&](const Node* node) {
213 if (visited.contains(node)) return; // already visited
214 visited.insert(node);
215 for (auto child : node->sibl_deps())
216 visit(child);
217 nodes.push_back(node);
218 };
219
220 // Traverse siblings in postorder
221 for (auto op : inest()->mut()->deps()) {
222 for (auto mut : op->local_muts())
223 if (auto node = nest()[mut]; node && node->inest() == inest()) visit(node);
224 }
225
226 // Actual dominance algorithm
227 for (bool todo = true; todo;) {
228 todo = false;
229 for (auto node : nodes | std::ranges::views::reverse) {
230 // skip entry nodes
231 if (node->idom_ == inest()) continue;
232
233 const Node* new_idom = nullptr;
234 for (auto user : node->sibl_deps<false>())
235 if (user->idom_) new_idom = new_idom ? Nest::lca<true>(new_idom, user) : user;
236 if (node->idom_ != new_idom) {
237 node->idom_ = new_idom;
238 todo = true;
239 }
240 }
241 }
242
243 return this;
244}
245
246template const Nest::Node* Nest::lca<true>(const Node*, const Node*);
247template const Nest::Node* Nest::lca<false>(const Node*, const Node*);
248
249} // namespace mim
Base class for all Defs.
Definition def.h:251
Muts local_muts() const
Mutables reachable by following immutable deps(); mut->local_muts() is by definition the set { mut }...
Definition def.cpp:331
const Children & children() const
Definition nest.h:73
const Node * inest() const
Immediate nester/parent of this Node.
Definition nest.h:20
Builds a nesting tree of all mutables/binders.
Definition nest.h:12
auto nodes() const
Definition nest.h:213
World & world() const
Definition nest.h:201
static const Node * lca(const Node *n, const Node *m)
Least common ancestor of n and m.
Definition nest.cpp:105
auto muts() const
Definition nest.h:212
const Node * root() const
Definition nest.h:202
Nest(Def *root)
Definition nest.cpp:7
bool contains(const Def *def) const
Definition nest.h:204
The World represents the whole program and manages creation of MimIR nodes (Defs).
Definition world.h:31
auto & vars()
Definition world.h:583
const Def * op(trait o, const Def *type)
Definition core.h:33
Definition ast.h:14
auto pop(S &s) -> decltype(s.top(), typename S::value_type())
Definition util.h:83
Span< const T, N > View
Definition span.h:98
Node
Definition def.h:112
Vector(I, I, A=A()) -> Vector< typename std::iterator_traits< I >::value_type, Default_Inlined_Size< typename std::iterator_traits< I >::value_type >, A >
auto nodes() const
Definition nest.h:43