MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
check.cpp
Go to the documentation of this file.
1#include "mim/check.h"
2
3#include <absl/container/fixed_array.h>
4#include <fe/assert.h>
5
6#include "mim/rewrite.h"
7#include "mim/world.h"
8
9namespace mim {
10
11namespace {
12
13static bool needs_zonk(const Def* def) {
14 if (def->has_dep(Dep::Hole)) {
15 for (auto mut : def->local_muts())
16 if (auto infer = mut->isa<Hole>(); infer && infer->is_set()) return true;
17 }
18
19 return false;
20}
21
22class Zonker : public Rewriter {
23public:
24 Zonker(World& world)
25 : Rewriter(world) {}
26
27 const Def* rewrite(const Def* def) override {
28 def = Hole::find(def);
29 return needs_zonk(def) ? Rewriter::rewrite(def) : def;
30 }
31};
32
33} // namespace
34
35const Def* Def::zonk() const {
36 auto def = Hole::find(this);
37 return needs_zonk(def) ? Zonker(world()).rewrite(def) : def;
38}
39
40/*
41 * Hole
42 */
43
44const Def* Hole::find(const Def* def) {
45 // find root
46 auto res = def;
47 for (auto hole = res->isa_mut<Hole>(); hole && hole->op(); hole = res->isa_mut<Hole>()) res = hole->op();
48 // TODO don't re-update last infer
49
50 // path compression: set all Holes along the chain to res
51 for (auto hole = def->isa_mut<Hole>(); hole && hole->op(); hole = def->isa_mut<Hole>()) {
52 def = hole->op();
53 hole->reset(res);
54 }
55
56 return res;
57}
58
59const Def* Hole::tuplefy() {
60 if (auto a = type()->isa_lit_arity(); a && !is_set()) {
61 auto& w = world();
62 auto n = *a;
63 auto infers = absl::FixedArray<const Def*>(n);
64 if (auto sigma = type()->isa_mut<Sigma>(); sigma && n >= 1 && sigma->has_var()) {
65 auto var = sigma->has_var();
66 auto rw = VarRewriter(var, this);
67 infers[0] = w.mut_hole(sigma->op(0));
68 for (size_t i = 1; i != n; ++i) {
69 rw.map(sigma->var(n, i - 1), infers[i - 1]);
70 infers[i] = w.mut_hole(rw.rewrite(sigma->op(i)));
71 }
72 } else {
73 for (size_t i = 0; i != n; ++i) infers[i] = w.mut_hole(type()->proj(n, i));
74 }
75
76 auto tuple = w.tuple(infers);
77 set(tuple);
78 return tuple;
79 }
80 return this;
81}
82
83/*
84 * Check
85 */
86
#ifdef MIM_ENABLE_CHECKS
/// Signals a failed alpha-equivalence test; in Mode::Check it traps into the debugger if `break_on_alpha` is set.
/// Always yields false so callers can simply `return fail<mode>();`.
template<Checker::Mode mode> bool Checker::fail() {
    if constexpr (mode == Check) {
        if (world().flags().break_on_alpha) fe::breakpoint();
    }
    return false;
}

/// Signals a failed assignability check; traps into the debugger if `break_on_alpha` is set.
/// Always yields nullptr.
const Def* Checker::fail() {
    if (world().flags().break_on_alpha) fe::breakpoint();
    return nullptr;
}
#endif
98
99template<Checker::Mode mode> bool Checker::alpha_(const Def* d1, const Def* d2) {
100 d1 = Hole::find(d1);
101 d2 = Hole::find(d2);
102
103 if (!d1 && !d2) return true;
104 if (!d1 || !d2) return fail<mode>();
105
106 // It is only safe to check for pointer equality if there are no Vars involved.
107 // Otherwise, we have to look more thoroughly.
108 // Example: λx.x - λz.x
109 if (!d1->has_dep(Dep::Var) && !d2->has_dep(Dep::Var) && d1 == d2) return true;
110
111 auto h1 = d1->isa_mut<Hole>();
112 auto h2 = d2->isa_mut<Hole>();
113
114 if ((!h1 && !d1->is_set()) || (!h2 && !d2->is_set())) return fail<mode>();
115
116 if (mode == Check) {
117 if (h1 && h2) {
118 // union by rank
119 if (h1->rank() < h2->rank()) std::swap(h1, h2); // make sure h1 is heavier or equal
120 h2->set(h1); // make h1 new root
121 if (h1->rank() == h2->rank()) ++h1->rank();
122 return true;
123 } else if (h1) {
124 h1->set(d2);
125 return true;
126 } else if (h2) {
127 h2->set(d1);
128 return true;
129 }
130 }
131
132 auto mut1 = d1->isa_mut();
133 auto mut2 = d2->isa_mut();
134 if (mut1 && mut2 && mut1 == mut2) return true;
135 // Globals are HACKs and require additionaly HACKs:
136 // Unless they are pointer equal (above) always consider them unequal.
137 if (d1->isa<Global>() || d2->isa<Global>()) return false;
138
139 if (mut1) {
140 if (auto [i, ins] = binders_.emplace(mut1, d2); !ins) return i->second == d2;
141 }
142 if (mut2) {
143 if (auto [i, ins] = binders_.emplace(mut2, d1); !ins) return i->second == d1;
144 }
145
146 // normalize:
147 if ((d1->isa<Lit>() && !d2->isa<Lit>()) // Lit to right
148 || (!d1->isa<UMax>() && d2->isa<UMax>()) // UMax to left
149 || (!d1->isa<Extract>() && d2->isa<Extract>())) // Extract to left
150 std::swap(d1, d2);
151
152 return alpha_internal<mode>(d1, d2);
153}
154
155template<Checker::Mode mode> bool Checker::alpha_internal(const Def* d1, const Def* d2) {
156 if (d1->type() && d2->type() && !alpha_<mode>(d1->type(), d2->type())) return fail<mode>();
157 if (d1->isa<Top>() || d2->isa<Top>()) return mode == Check;
158 if (mode == Test && (d1->isa_mut<Hole>() || d2->isa_mut<Hole>())) return fail<mode>();
159 if (!alpha_<mode>(d1->arity(), d2->arity())) return fail<mode>();
160
161 if (auto ts = d1->isa<Tuple, Sigma>()) {
162 size_t a = ts->num_ops();
163 for (size_t i = 0; i != a; ++i)
164 if (!alpha_<mode>(ts->op(i), d2->proj(a, i))) return fail<mode>();
165 return true;
166 } else if (auto pa = d1->isa<Pack, Arr>()) {
167 if (pa->node() == d2->node()) return alpha_<mode>(pa->ops().back(), d2->ops().back());
168 if (auto a = pa->isa_lit_arity()) {
169 for (size_t i = 0; i != *a; ++i)
170 if (!alpha_<mode>(pa->proj(*a, i), d2->proj(*a, i))) return fail<mode>();
171 return true;
172 }
173 } else if (auto umax = d1->isa<UMax>(); umax && umax->has_dep(Dep::Hole) && !d2->isa<UMax>()) {
174 // .umax(a, ?) == x => .umax(a, x)
175 for (auto op : umax->ops())
176 if (auto inf = op->isa_mut<Hole>(); inf && !inf->is_set()) inf->set(d2);
177 d1 = umax->rebuild(umax->type(), umax->ops());
178 }
179
180 if (d1->node() != d2->node() || d1->flags() != d2->flags() || d1->num_ops() != d2->num_ops()) return fail<mode>();
181
182 if (auto var1 = d1->isa<Var>()) {
183 auto var2 = d2->as<Var>();
184 if (auto i = binders_.find(var1->mut()); i != binders_.end()) return i->second == var2->mut();
185 if (auto i = binders_.find(var2->mut()); i != binders_.end()) return fail<mode>(); // var2 is bound
186 // both var1 and var2 are free: OK, when they are the same or in Check mode
187 return var1 == var2 || mode == Check;
188 }
189
190 for (size_t i = 0, e = d1->num_ops(); i != e; ++i)
191 if (!alpha_<mode>(d1->op(i), d2->op(i))) return fail<mode>();
192 return true;
193}
194
195const Def* Checker::assignable_(const Def* type, const Def* val) {
196 auto val_ty = Hole::find(val->type());
197 if (type == val_ty) return val;
198
199 auto& w = world();
200 if (auto sigma = type->isa<Sigma>()) {
201 if (!alpha_<Check>(type->arity(), val_ty->arity())) return fail();
202
203 size_t a = sigma->num_ops();
204 auto red = sigma->reduce(val);
205 auto new_ops = absl::FixedArray<const Def*>(red.size());
206 for (size_t i = 0; i != a; ++i) {
207 auto new_val = assignable_(red[i], val->proj(a, i));
208 if (new_val)
209 new_ops[i] = new_val;
210 else
211 return fail();
212 }
213 return w.tuple(new_ops);
214 } else if (auto arr = type->isa<Arr>()) {
215 if (!alpha_<Check>(type->arity(), val_ty->arity())) return fail();
216
217 // TODO ack sclarize threshold
218 if (auto a = Lit::isa(arr->arity())) {
219 auto new_ops = absl::FixedArray<const Def*>(*a);
220 for (size_t i = 0; i != *a; ++i) {
221 auto new_val = assignable_(arr->proj(*a, i), val->proj(*a, i));
222 if (new_val)
223 new_ops[i] = new_val;
224 else
225 return fail();
226 }
227 return w.tuple(new_ops);
228 }
229 } else if (auto inj = val->isa<Inj>()) {
230 if (auto new_val = assignable_(type, inj->value())) return w.inj(type, new_val);
231 return fail();
232 } else if (auto uniq = val->type()->isa<Uniq>()) {
233 if (auto new_val = assignable(type, uniq->inhabitant())) return new_val;
234 return fail();
235 }
236
237 return alpha_<Check>(type, val_ty) ? val : fail();
238}
239
// NOTE(review): the function header (source line 240) is missing from this view. Per the generated docs this is
// is_uniform(Defs defs): yields defs.front() if all defs are alpha-equivalent in Mode::Test, and nullptr otherwise.
// An empty defs also yields nullptr.
241 if (defs.empty()) return nullptr;
242 auto first = defs.front();
// Compare every other element against the first one; Mode::Test does not bind any Holes in alpha_.
243 for (size_t i = 1, e = defs.size(); i != e; ++i)
244 if (!alpha<Test>(first, defs[i])) return nullptr;
245 return first;
246}
247
248/*
249 * infer & check
250 */
251
// Per-operand check hook: not implemented yet, so every operand is accepted unchanged.
252const Def* Arr::check(size_t, const Def* def) { return def; } // TODO
253
// Whole-node check: infers the sort of the array from its body and yields it.
// It reports an error mentioning both the declared and the inferred sort.
254const Def* Arr::check() {
255 auto t = body()->unfold_type();
// NOTE(review): source line 256 (presumably the mismatch test guarding the error below) is missing from this view.
257 error(type()->loc(), "declared sort '{}' of array does not match inferred one '{}'", type(), t);
258 return t;
259}
260
// NOTE(review): the function header (source line 261) is missing from this view; per the generated docs this is
// a static infer(World&, Defs) that builds the Sigma of the operands' types, i.e. the type of a Tuple of ops.
262 auto elems = absl::FixedArray<const Def*>(ops.size());
263 for (size_t i = 0, e = ops.size(); i != e; ++i) elems[i] = ops[i]->type();
264 return world.sigma(elems);
265}
266
267const Def* Sigma::infer(World& w, Defs ops) {
268 auto elems = absl::FixedArray<const Def*>(ops.size());
269 for (size_t i = 0, e = ops.size(); i != e; ++i) elems[i] = ops[i]->unfold_type();
270 return w.umax<Sort::Kind>(elems);
271}
272
// Per-operand check hook: not implemented yet, so every operand is accepted unchanged.
273const Def* Sigma::check(size_t, const Def* def) { return def; } // TODO
274
// Whole-node check: recomputes the sort via infer() and compares it against the declared type().
// On a mismatch, the declared type is kept with a warning (workaround for closure conversion) in the else branch.
275const Def* Sigma::check() {
276 auto t = infer(world(), ops());
277 if (t != type()) {
278 // TODO HACK
// NOTE(review): source line 279 (presumably the `if` deciding between returning t and keeping type()) is missing
// from this view.
280 return t;
281 else {
282 world().WLOG(
283 "incorrect type '{}' for '{}'. Correct one would be: '{}'. I'll keep this one nevertheless due to "
284 "bugs in clos-conv",
285 type(), this, t);
286 return type();
287 }
288 }
289 return t;
290}
291
292const Def* Pi::infer(const Def* dom, const Def* codom) {
293 auto& w = dom->world();
294 return w.umax<Sort::Kind>({dom->unfold_type(), codom->unfold_type()});
295}
296
// Per-operand check hook: every operand is accepted unchanged.
297const Def* Pi::check(size_t, const Def* def) { return def; }
298
// Whole-node check: infers the sort from dom() and codom() and yields it.
// It reports an error mentioning both the declared and the inferred sort.
299const Def* Pi::check() {
300 auto t = infer(dom(), codom());
// NOTE(review): source line 301 (presumably the mismatch test guarding the error below) is missing from this view.
302 error(type()->loc(), "declared sort '{}' of function type does not match inferred one '{}'", type(), t);
303 return t;
304}
305
306const Def* Lam::check(size_t i, const Def* def) {
307 if (i == 0) {
308 if (auto filter = Checker::assignable(world().type_bool(), def)) return filter;
309 throw Error().error(filter()->loc(), "filter '{}' of lambda is of type '{}' but must be of type 'Bool'",
310 filter(), filter()->type());
311 } else if (i == 1) {
312 if (auto body = Checker::assignable(codom(), def)) return body;
313 throw Error()
314 .error(def->loc(), "body of function is not assignable to declared codomain")
315 .note(def->loc(), "body: '{}'", def)
316 .note(def->loc(), "type: '{}'", def->type())
317 .note(codom()->loc(), "codomain: '{}'", codom());
318 }
319 fe::unreachable();
320}
321
// Explicit instantiations of Checker::alpha_ for both modes; excluded from Doxygen runs.
322#ifndef DOXYGEN
323template bool Checker::alpha_<Checker::Check>(const Def*, const Def*);
324template bool Checker::alpha_<Checker::Test>(const Def*, const Def*);
325#endif
326
327} // namespace mim
const Def * body() const
Definition tuple.h:85
const Def * check() override
Definition check.cpp:254
static const Def * is_uniform(Defs defs)
Yields defs.front(), if all defs are Check::alpha-equivalent (Mode::Test) and nullptr otherwise.
Definition check.cpp:240
static bool alpha(const Def *d1, const Def *d2)
Definition check.h:71
World & world()
Definition check.h:60
@ Test
In Mode::Test, no type inference is happening and Holes will not be touched.
Definition check.h:68
@ Check
In Mode::Check, type inference is happening and Holes will be resolved, if possible.
Definition check.h:65
static const Def * assignable(const Def *type, const Def *value)
Can value be assigned to something of type type?
Definition check.h:77
Base class for all Defs.
Definition def.h:197
bool is_set() const
Yields true if empty or the last op is set.
Definition def.cpp:278
const Def * proj(nat_t a, nat_t i) const
Similar to World::extract while assuming an arity of a, but also works on Sigmas and Arrays.
Definition def.cpp:502
const Def * zonk() const
Definition check.cpp:35
World & world() const noexcept
Definition def.cpp:387
virtual const Def * check()
Definition def.h:517
constexpr auto ops() const noexcept
Definition def.h:260
T * isa_mut() const
If this is mutable, it will cast constness away and perform a dynamic_cast to T.
Definition def.h:429
const Def * op(size_t i) const noexcept
Definition def.h:263
const Def * var(nat_t a, nat_t i) noexcept
Definition def.h:378
const Def * unfold_type() const
Yields the type of this Def and builds a new Type (UInc n) if necessary.
Definition def.cpp:394
const Def * type() const noexcept
Yields the "raw" type of this Def (maybe nullptr).
Definition def.h:241
Loc loc() const
Definition def.h:449
std::optional< nat_t > isa_lit_arity() const
Definition def.cpp:472
Def * reset(size_t i, const Def *def)
Successively reset from left to right.
Definition def.h:285
Error & error(Loc loc, const char *s, Args &&... args)
Definition dbg.h:71
Error & note(Loc loc, const char *s, Args &&... args)
Definition dbg.h:73
This node is a hole in the IR that is inferred by its context later on.
Definition check.h:10
Hole * set(const Def *op)
Definition check.h:21
const Def * tuplefy()
If unset, explode to Tuple.
Definition check.cpp:59
static const Def * find(const Def *)
Union-Find to unify Holes.
Definition check.cpp:44
const Def * filter() const
Definition lam.h:118
const Pi * type() const
Definition lam.h:126
const Def * body() const
Definition lam.h:119
const Def * codom() const
Definition lam.h:128
static std::optional< T > isa(const Def *def)
Definition def.h:712
static const Def * infer(const Def *dom, const Def *codom)
Definition check.cpp:292
const Def * dom() const
Definition lam.h:32
const Def * codom() const
Definition lam.h:33
const Def * check() override
Definition check.cpp:299
Recursively rebuilds part of a program into the provided World w.r.t. Rewriter::map.
Definition rewrite.h:9
virtual const Def * rewrite(const Def *)
Definition rewrite.cpp:11
friend class World
Definition tuple.h:51
static const Def * infer(World &, Defs)
Definition check.cpp:267
const Def * check() override
Definition check.cpp:275
friend class World
Definition tuple.h:68
static const Def * infer(World &, Defs)
Definition check.cpp:261
const Def * op(trait o, const Def *type)
Definition core.h:33
Definition ast.h:14
View< const Def * > Defs
Definition def.h:48
@ Var
Definition def.h:99
@ Hole
Definition def.h:100
void error(Loc loc, const char *f, Args &&... args)
Definition dbg.h:122
TExt< true > Top
Definition lattice.h:159
@ Kind
Definition def.h:93
@ Arr
Definition def.h:84
@ Pack
Definition def.h:84
@ Global
Definition def.h:84
@ Var
Definition def.h:84
@ Inj
Definition def.h:84
@ Hole
Definition def.h:84
@ Sigma
Definition def.h:84
@ Extract
Definition def.h:84
@ Uniq
Definition def.h:84
@ Tuple
Definition def.h:84
@ Lit
Definition def.h:84
@ UMax
Definition def.h:84