MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
nest.cpp
Go to the documentation of this file.
1#include "mim/nest.h"
2
3#include "mim/world.h"
4
5#include "mim/phase/phase.h"
6
7namespace mim {
8
10 : world_(r->world())
11 , root_(make_node(r)) {
12 populate();
13}
14
16 : world_(muts.front()->world())
17 , root_(make_node(nullptr)) {
18 for (auto mut : muts) make_node(mut, root_);
19 populate();
20}
21
23 : world_(world)
24 , root_(make_node(nullptr)) {
25 for (auto mut : ClosedCollector<>::collect(world)) make_node(mut, root_);
26 populate();
27}
28
29void Nest::populate() {
30 std::queue<Node*> queue;
31
32 if (root()->mut())
33 queue.push(root_);
34 else
35 for (auto [_, child] : root_->child_mut2node_) queue.push(child);
36
37 while (!queue.empty()) {
38 auto curr_node = pop(queue);
39 for (auto op : curr_node->mut()->deps()) {
40 for (auto local_mut : op->local_muts()) {
41 if (mut2node(local_mut) || !contains(local_mut)) continue;
42
43 if (curr_node->level() < local_mut->free_vars().size()) {
44 for (auto node = curr_node;; node = node->parent_) {
45 if (auto var = node->mut()->has_var()) {
46 if (local_mut->free_vars().contains(var)) {
47 queue.push(make_node(local_mut, node));
48 break;
49 }
50 }
51 }
52 } else {
53 uint32_t max = 0;
54 auto parent = root_;
55 for (auto var : local_mut->free_vars()) {
56 if (auto node = mut2node_nonconst(var->mut()); node && node->level() > max) {
57 max = node->level();
58 parent = node;
59 }
60 }
61 queue.push(make_node(local_mut, parent));
62 }
63 }
64 }
65 }
66}
67
68Nest::Node* Nest::make_node(Def* mut, Node* parent) {
69 auto node = std::unique_ptr<Node>(new Node(*this, mut, parent)); // can't use make_unique - c'tor is private
70 auto res = node.get();
71 mut2node_.emplace(mut, std::move(node));
72 if (mut) {
73 if (auto var = mut->has_var()) vars_ = world().vars().insert(vars_, var);
74 }
75 return res;
76}
77
78const Nest::Node* Nest::lca(const Node* n, const Node* m) {
79 while (n->level() < m->level()) m = m->parent();
80 while (m->level() < n->level()) n = n->parent();
81 while (n != m) {
82 // TODO support longer dep chains and and the possibility to opt out from this
83 if (n->deps().depends_.contains(m)) return n;
84 if (m->deps().depends_.contains(n)) return m;
85 n = n->parent();
86 m = m->parent();
87 }
88 return n;
89}
90
91void Nest::deps(Node* curr) const {
92 if (curr->mut()) {
93 for (auto op : curr->mut()->deps()) {
94 for (auto local_mut : op->local_muts()) {
95 if (auto local_node = mut2node_nonconst(local_mut)) {
96 if (local_node == curr)
97 local_node->link(local_node);
98 else if (auto parent = local_node->parent()) {
99 if (auto curr_child = parent->curr_child) {
100 assert(parent->child_mut2node_.contains(curr_child->mut()));
101 curr_child->link(local_node);
102 }
103 }
104 }
105 }
106 }
107 }
108
109 for (auto [_, child] : curr->child_mut2node_) {
110 curr->curr_child = child;
111 deps(child);
112 curr->curr_child = nullptr;
113 }
114}
115
116void Nest::find_SCCs(Node* curr) const {
117 curr->find_SCCs();
118 for (auto [_, child] : curr->child_mut2node_) {
119 child->loop_depth_ = child->is_recursive() ? curr->loop_depth() + 1 : curr->loop_depth();
120 find_SCCs(child);
121 }
122}
123
124void Nest::Node::find_SCCs() {
125 Stack stack;
126 for (int i = 0; auto& [_, node] : child_mut2node_)
127 if (node->idx_ == Unvisited) i = node->tarjan(i, this, stack);
128}
129
130uint32_t Nest::Node::tarjan(uint32_t i, Node* parent, Stack& stack) {
131 this->idx_ = this->low_ = i++;
132 this->on_stack_ = true;
133 stack.emplace(this);
134
135 for (auto dep : this->depends_) {
136 if (dep->idx_ == Unvisited) i = dep->tarjan(i, parent, stack);
137 if (dep->on_stack_) this->low_ = std::min(this->low_, dep->low_);
138 }
139
140 if (this->idx_ == this->low_) {
141 parent_->topo_.emplace_front(std::make_unique<SCC>());
142 SCC* scc = parent_->topo_.front().get();
143 Node* node;
144 int num = 0;
145 do {
146 node = pop(stack);
147 node->on_stack_ = false;
148 node->recursive_ = true;
149 node->low_ = this->idx_;
150 ++num;
151
152 scc->emplace(node);
153 auto [_, ins] = parent_->SCCs_.emplace(node, scc);
154 assert_unused(ins);
155 } while (node != this);
156
157 if (num == 1 && !this->depends_.contains(this)) this->recursive_ = false;
158 }
159
160 return i;
161}
162
163} // namespace mim
static Vector< M * > collect(World &world)
Wrapper to directly receive all closed mutables as Vector.
Definition phase.h:189
Base class for all Defs.
Definition def.h:203
Muts local_muts() const
Mutables reachable by following immutable deps(); mut->local_muts() is by definition the set { mut }...
Definition def.cpp:279
uint32_t level() const
Definition nest.h:26
const Node * parent() const
Definition nest.h:19
World & world() const
Definition nest.h:128
static const Node * lca(const Node *n, const Node *m)
Least common ancestor of n and m.
Definition nest.cpp:78
auto muts() const
Definition nest.h:137
const Node * root() const
Definition nest.h:129
Nest(Def *root)
Definition nest.cpp:9
const auto & mut2node() const
Definition nest.h:141
bool contains(const Def *def) const
Definition nest.h:131
The World represents the whole program and manages creation of MimIR nodes (Defs).
Definition world.h:33
auto & vars()
Definition world.h:504
Definition ast.h:14
auto pop(S &s) -> decltype(s.top(), typename S::value_type())
Definition util.h:79
Span< const T, N > View
Definition span.h:93
Node
Definition def.h:83