MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
check.cpp
Go to the documentation of this file.
1#include "mim/check.h"
2
3#include "mim/rewrite.h"
4#include "mim/world.h"
5
6namespace mim {
7
8namespace {
9
10class InferRewriter : public Rewriter {
11public:
12 InferRewriter(World& world)
13 : Rewriter(world) {}
14
15 Ref rewrite(Ref old_def) override {
16 if (Infer::should_eliminate(old_def)) return Rewriter::rewrite(old_def);
17 return old_def;
18 }
19};
20
21} // namespace
22
23/*
24 * Infer
25 */
26
27const Def* Ref::refer(const Def* def) { return def ? Infer::find(def) : nullptr; }
28
29const Def* Infer::find(const Def* def) {
30 // find root
31 auto res = def;
32 for (auto infer = res->isa_mut<Infer>(); infer && infer->op(); infer = res->isa_mut<Infer>()) res = infer->op();
33 // TODO don't re-update last infer
34
35 // path compression: set all Infers along the chain to res
36 for (auto infer = def->isa_mut<Infer>(); infer && infer->op(); infer = def->isa_mut<Infer>()) {
37 def = infer->op();
38 infer->reset(res);
39 }
40
41 assert((!res->isa<Infer>() || res != res->op(0)) && "an Infer shouldn't point to itself");
42
43 // If we have an Infer as operand, try to get rid of it now.
44 // TODO why does this not work?
45 // if (res->isa_imm() && res->has_dep(Dep::Infer)) {
46 if (res->isa<Tuple>() || res->isa<Type>()) {
47 auto new_type = Ref::refer(res->type());
48 bool update = new_type != res->type();
49
50 auto new_ops = DefVec(res->num_ops(), [res, &update](size_t i) {
51 auto r = Ref::refer(res->op(i));
52 update |= r != res->op(i);
53 return r;
54 });
55
56 if (update) return res->rebuild(new_type, new_ops);
57 }
58
59 return res;
60}
61
63 if (std::ranges::any_of(refs, [](auto pref) { return should_eliminate(*pref); })) {
64 auto& world = (*refs.front())->world();
65 InferRewriter rw(world);
66 for (size_t i = 0, e = refs.size(); i != e; ++i) {
67 auto ref = *refs[i];
68 *refs[i] = ref->has_dep(Dep::Infer) ? rw.rewrite(ref) : ref;
69 }
70 return true;
71 }
72 return false;
73}
74/*
75 * Check
76 */
77
#ifdef MIM_ENABLE_CHECKS
/// Failure hook of the α-equivalence checker: always yields false.
/// In infer mode it first hits a breakpoint if the world's break_on_alpha_unequal flag is set.
template<bool infer> bool Check::fail() {
    if constexpr (infer) {
        if (world().flags().break_on_alpha_unequal) fe::breakpoint();
    }
    return false;
}
#endif
84
85template<bool infer> bool Check::alpha_(Ref r1, Ref r2) {
86 auto d1 = *r1; // find
87 auto d2 = *r2; // find
88
89 if (!d1 && !d2) return true;
90 if (!d1 || !d2) return fail<infer>();
91
92 // It is only safe to check for pointer equality if there are no Vars involved.
93 // Otherwise, we have to look more thoroughly.
94 // Example: λx.x - λz.x
95 if (!d1->has_dep(Dep::Var) && !d2->has_dep(Dep::Var) && d1 == d2) return true;
96 auto mut1 = d1->isa_mut();
97 auto mut2 = d2->isa_mut();
98 if (mut1 && mut2 && mut1 == mut2) return true;
99 // Globals are HACKs and require additionaly HACKs:
100 // Unless they are pointer equal (above) always consider them unequal.
101 if (d1->isa<Global>() || d2->isa<Global>()) return false;
102
103 if (mut1) {
104 if (auto [i, ins] = done_.emplace(mut1, d2); !ins) return i->second == d2;
105 }
106 if (mut2) {
107 if (auto [i, ins] = done_.emplace(mut2, d1); !ins) return i->second == d1;
108 }
109
110 auto i1 = d1->isa_mut<Infer>();
111 auto i2 = d2->isa_mut<Infer>();
112
113 if ((!i1 && !d1->is_set()) || (!i2 && !d2->is_set())) return fail<infer>();
114
115 if (infer) {
116 if (i1 && i2) {
117 // union by rank
118 if (i1->rank() < i2->rank()) std::swap(i1, i2); // make sure i1 is heavier or equal
119 i2->set(i1); // make i1 new root
120 if (i1->rank() == i2->rank()) ++i1->rank();
121 return true;
122 } else if (i1) {
123 i1->set(d2);
124 return true;
125 } else if (i2) {
126 i2->set(d1);
127 return true;
128 }
129 }
130
131 // normalize:
132 if ((d1->isa<Lit>() && !d2->isa<Lit>()) // Lit to right
133 || (!d1->isa<UMax>() && d2->isa<UMax>()) // UMax to left
134 || (d1->gid() > d2->gid())) // smaller gid to left
135 std::swap(d1, d2);
136
137 return alpha_internal<infer>(d1, d2);
138}
139
140template<bool infer> bool Check::alpha_internal(Ref d1, Ref d2) {
141 if (!alpha_<infer>(d1->type(), d2->type())) return fail<infer>();
142 if (d1->isa<Top>() || d2->isa<Top>()) return infer;
143 if (!infer && (d1->isa_mut<Infer>() || d2->isa_mut<Infer>())) return fail<infer>();
144 if (!alpha_<infer>(d1->arity(), d2->arity())) return fail<infer>();
145
146 // vars are equal if they appeared under the same binder
147 if (auto mut1 = d1->isa_mut()) assert_emplace(vars_, mut1, d2->isa_mut());
148 if (auto mut2 = d2->isa_mut()) assert_emplace(vars_, mut2, d1->isa_mut());
149
150 if (auto ts = d1->isa<Tuple, Sigma>()) {
151 size_t a = ts->num_ops();
152 for (size_t i = 0; i != a; ++i)
153 if (!alpha_<infer>(ts->op(i), d2->proj(a, i))) return fail<infer>();
154 return true;
155 } else if (auto pa = d1->isa<Pack, Arr>()) {
156 if (pa->node() == d2->node()) return alpha_<infer>(pa->ops().back(), d2->ops().back());
157 if (auto a = pa->isa_lit_arity()) {
158 for (size_t i = 0; i != *a; ++i)
159 if (!alpha_<infer>(pa->proj(*a, i), d2->proj(*a, i))) return fail<infer>();
160 return true;
161 }
162 } else if (auto umax = d1->isa<UMax>(); umax && umax->has_dep(Dep::Infer) && !d2->isa<UMax>()) {
163 // .umax(a, ?) == x => .umax(a, x)
164 for (auto op : umax->ops())
165 if (auto inf = op->isa_mut<Infer>(); inf && !inf->is_set()) inf->set(d2);
166 d1 = umax->rebuild(umax->type(), umax->ops());
167 }
168
169 if (d1->node() != d2->node() || d1->flags() != d2->flags() || d1->num_ops() != d2->num_ops()) return fail<infer>();
170
171 if (auto var1 = d1->isa<Var>()) {
172 auto var2 = d2->as<Var>();
173 if (auto i = vars_.find(var1->mut()); i != vars_.end()) return i->second == var2->mut();
174 if (auto i = vars_.find(var2->mut()); i != vars_.end()) return fail<infer>(); // var2 is bound
175 // both var1 and var2 are free: OK, when they are the same or in infer mode
176 return var1 == var2 || infer;
177 }
178
179 for (size_t i = 0, e = d1->num_ops(); i != e; ++i)
180 if (!alpha_<infer>(d1->op(i), d2->op(i))) return fail<infer>();
181 return true;
182}
183
184bool Check::assignable_(Ref type, Ref val) {
185 auto val_ty = Ref::refer(val->type());
186 if (type == val_ty) return true;
187
188 if (auto infer = val->isa_mut<Infer>()) return alpha_<true>(type, infer->type());
189
190 if (auto sigma = type->isa<Sigma>()) {
191 if (!alpha_<true>(type->arity(), val_ty->arity())) return fail<true>();
192
193 size_t a = sigma->num_ops();
194 auto red = sigma->reduce(val);
195 for (size_t i = 0; i != a; ++i)
196 if (!assignable_(red[i], val->proj(a, i))) return fail<true>();
197 return true;
198 } else if (auto arr = type->isa<Arr>()) {
199 if (!alpha_<true>(type->arity(), val_ty->arity())) return fail<true>();
200
201 if (auto a = Lit::isa(arr->arity())) {
202 for (size_t i = 0; i != *a; ++i)
203 if (!assignable_(arr->proj(*a, i), val->proj(*a, i))) return fail<true>();
204 return true;
205 }
206 } else if (auto vel = val->isa<Vel>()) {
207 return assignable_(type, vel->value());
208 }
209
210 return alpha_<true>(type, val_ty);
211}
212
214 if (defs.empty()) return nullptr;
215 auto first = defs.front();
216 for (size_t i = 1, e = defs.size(); i != e; ++i)
217 if (!alpha<false>(first, defs[i])) return nullptr;
218 return first;
219}
220
221/*
222 * infer & check
223 */
224
226 auto t = body()->unfold_type();
227 if (!Check::alpha(t, type()))
228 error(type()->loc(), "declared sort '{}' of array does not match inferred one '{}'", type(), t);
229 if (t != type()) set_type(t);
230}
231
233 if (ops.size() == 0) return w.type<1>();
234 auto kinds = DefVec(ops.size(), [ops](size_t i) { return ops[i]->unfold_type(); });
235 return w.umax<Sort::Kind>(kinds);
236}
237
239 auto t = infer(world(), ops());
240 if (t != type()) {
241 // TODO HACK
242 if (Check::alpha(t, type()))
243 set_type(t);
244 else
245 world().WLOG(
246 "incorrect type '{}' for '{}'. Correct one would be: '{}'. I'll keep this one nevertheless due to "
247 "bugs in clos-conv",
248 type(), this, t);
249 }
250}
251
253 if (!Check::alpha(filter()->type(), world().type_bool())) {
254 error(filter()->loc(), "filter '{}' of lambda is of type '{}' but must be of type '.Bool'", filter(),
255 filter()->type());
256 }
257 if (!Check::assignable(codom(), body())) {
258 throw Error()
259 .error(body()->loc(), "body of function is not assignable to declared codomain")
260 .note(body()->loc(), "body: '{}'", body())
261 .note(body()->loc(), "type: '{}'", body()->type())
262 .note(codom()->loc(), "codomain: '{}'", codom());
263 }
264}
265
266Ref Pi::infer(Ref dom, Ref codom) {
267 auto& w = dom->world();
268 return w.umax<Sort::Kind>({dom->unfold_type(), codom->unfold_type()});
269}
270
271void Pi::check() {
272 auto t = infer(dom(), codom());
273 if (!Check::alpha(t, type()))
274 error(type()->loc(), "declared sort '{}' of function type does not match inferred one '{}'", type(), t);
275 if (t != type()) set_type(t);
276}
277
278#ifndef DOXYGEN
279template bool Check::alpha_<true>(Ref, Ref);
280template bool Check::alpha_<false>(Ref, Ref);
281#endif
282
283} // namespace mim
void check() override
Definition check.cpp:225
const Def * body() const
Definition tuple.h:62
static bool alpha(Ref d1, Ref d2)
Are d1 and d2 α-equivalent?
Definition check.h:59
static bool assignable(Ref type, Ref value)
Can value be assigned to something of type type?
Definition check.h:63
World & world()
Definition check.h:52
static Ref is_uniform(Defs defs)
Yields defs.front() if all defs are Check::alpha-equivalent (with infer = false), and nullptr otherwise (including when defs is empty).
Definition check.cpp:213
Base class for all Defs.
Definition def.h:220
const Def * op(size_t i) const
Definition def.h:266
Def * set_type(const Def *)
Definition def.cpp:290
World & world() const
Definition def.cpp:417
T * isa_mut() const
If this is *mut*able, it will cast constness away and perform a dynamic_cast to T.
Definition def.h:444
const Def * unfold_type() const
Yields the type of this Def and builds a new .Type (UInc n) if necessary.
Definition def.cpp:423
auto ops() const
Definition def.h:265
void update()
Resolves Infers of this Def's type.
Definition def.h:296
const Def * type() const
Definition def.h:245
Loc loc() const
Definition def.h:464
Def * reset(size_t i, const Def *def)
Successively reset from left to right.
Definition def.h:288
Error & error(Loc loc, const char *s, Args &&... args)
Definition dbg.h:71
Error & note(Loc loc, const char *s, Args &&... args)
Definition dbg.h:73
This node is a hole in the IR that is inferred by its context later on.
Definition check.h:12
static bool eliminate(Vector< Ref * >)
Eliminate Infers that may have been resolved in the meantime by rebuilding.
Definition check.cpp:62
static bool should_eliminate(Ref def)
Definition check.h:29
static const Def * find(const Def *)
Union-Find to unify Infer nodes.
Definition check.cpp:29
const Def * op() const
Definition check.h:20
Ref filter() const
Definition lam.h:106
void check() override
Definition check.cpp:252
Ref codom() const
Definition lam.h:123
const Pi * type() const
Definition lam.h:108
Ref body() const
Definition lam.h:107
static std::optional< T > isa(Ref def)
Definition def.h:712
Ref codom() const
Definition lam.h:40
static Ref infer(Ref dom, Ref codom)
Definition check.cpp:266
void check() override
Definition check.cpp:271
Ref dom() const
Definition lam.h:32
Helper class to resolve Infers (following Infer::op) if present.
Definition def.h:85
static const Def * refer(const Def *def)
Resolves def via Infer::find (following set Infer::op links); yields nullptr if def is nullptr.
Definition check.cpp:27
virtual Ref rewrite(Ref)
Definition rewrite.cpp:9
static Ref infer(World &, Defs)
Definition check.cpp:232
void check() override
Definition check.cpp:238
This is a thin wrapper for std::span<T, N> with the following additional features:
Definition span.h:28
Data constructor for a Sigma.
Definition tuple.h:40
This is a thin wrapper for absl::InlinedVector<T, N, A> which in turn is a drop-in replacement for std::vector<T, A>.
Definition vector.h:16
The World represents the whole program and manages creation of MimIR nodes (Defs).
Definition world.h:33
Ref op(trait o, Ref type)
Definition core.h:35
Definition cfg.h:11
Vector< const Def * > DefVec
Definition def.h:61
auto assert_emplace(C &container, Args &&... args)
Invokes emplace on container, asserts that insertion actually happened, and returns the iterator.
Definition util.h:102
void error(Loc loc, const char *f, Args &&... args)
Definition dbg.h:122
TExt< true > Top
Definition lattice.h:152
DefVec rewrite(Def *mut, Ref arg)
Definition rewrite.cpp:45