MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
check.cpp
Go to the documentation of this file.
1#include "mim/check.h"
2
3#include <absl/container/fixed_array.h>
4#include <fe/assert.h>
5
6#include "mim/rewrite.h"
7#include "mim/world.h"
8
9namespace mim {
10
11namespace {
12
13static bool needs_zonk(const Def* def) {
14 if (def->has_dep(Dep::Hole)) {
15 for (auto mut : def->local_muts())
16 if (auto hole = mut->isa<Hole>(); hole && hole->is_set()) return true;
17 }
18
19 return false;
20}
21
22class Zonker : public Rewriter {
23public:
24 Zonker(World& world)
25 : Rewriter(world) {}
26
27 const Def* rewrite(const Def* def) override {
28 if (auto hole = def->isa_mut<Hole>()) {
29 auto [last, op] = hole->find();
30 return op ? rewrite(op) : last;
31 }
32
33 return needs_zonk(def) ? Rewriter::rewrite(def) : def;
34 }
35};
36
37} // namespace
38
39const Def* Def::zonk() const { return needs_zonk(this) ? Zonker(world()).rewrite(this) : this; }
40
41const Def* Def::zonk_mut() {
42 bool zonk = false;
43 for (auto def : deps())
44 if (needs_zonk(def)) {
45 zonk = true;
46 break;
47 }
48
49 if (zonk) {
50 auto zonker = Zonker(world());
51 auto old_type = type();
52 auto old_ops = absl::FixedArray<const Def*>(ops().begin(), ops().end());
53 unset();
54 set_type(zonker.rewrite(old_type));
55 for (size_t i = 0, e = num_ops(); i != e; ++i) set(i, zonker.rewrite(old_ops[i]));
56 }
57
58 if (auto imm = immutabilize()) return imm;
59 return nullptr;
60}
61
63 return DefVec(defs.size(), [defs](size_t i) { return defs[i]->zonk(); });
64}
65
66/*
67 * Hole
68 */
69
70std::pair<Hole*, const Def*> Hole::find() {
71 auto def = Def::op(0);
72 auto last = this;
73
74 for (; def;) {
75 if (auto h = def->isa_mut<Hole>()) {
76 def = h->op();
77 last = h;
78 } else {
79 break;
80 }
81 }
82
83 auto root = def ? def : last;
84
85 // path compression
86 for (auto h = this; h != last;) {
87 auto next = h->op()->as_mut<Hole>();
88 h->unset()->set(root);
89 h = next;
90 }
91
92 return {last, def};
93}
94
95const Def* Hole::tuplefy(nat_t n) {
96 if (is_set()) return this;
97
98 auto& w = world();
99 auto holes = absl::FixedArray<const Def*>(n);
100 if (auto sigma = type()->isa_mut<Sigma>(); sigma && n >= 1 && sigma->has_var()) {
101 auto var = sigma->has_var();
102 auto rw = VarRewriter(var, this);
103 holes[0] = w.mut_hole(sigma->op(0));
104 for (size_t i = 1; i != n; ++i) {
105 rw.map(sigma->var(n, i - 1), holes[i - 1]);
106 holes[i] = w.mut_hole(rw.rewrite(sigma->op(i)));
107 }
108 } else {
109 for (size_t i = 0; i != n; ++i) holes[i] = w.mut_hole(type()->proj(n, i));
110 }
111
112 auto tuple = w.tuple(holes);
113 set(tuple);
114 return tuple;
115}
116
117/*
118 * Check
119 */
120
#ifdef MIM_ENABLE_CHECKS
/// Signals an alpha-equivalence failure; breaks into the debugger (Mode::Check only) if `break_on_alpha` is set.
template<Checker::Mode mode> bool Checker::fail() {
    if constexpr (mode == Check) {
        if (world().flags().break_on_alpha) fe::breakpoint();
    }
    return false;
}

/// Signals an assignability failure; breaks into the debugger if `break_on_alpha` is set. Yields a null Def.
const Def* Checker::fail() {
    if (world().flags().break_on_alpha) fe::breakpoint();
    return nullptr;
}
#endif
132
133template<Checker::Mode mode> bool Checker::alpha_(const Def* d1_, const Def* d2_) {
134 auto ds = std::array<const Def*, 2>{d1_->zonk(), d2_->zonk()};
135 auto& [d1, d2] = ds;
136
137 if (!d1 && !d2) return true;
138 if (!d1 || !d2) return fail<mode>();
139
140 // It is only safe to check for pointer equality if there are no Vars involved.
141 // Otherwise, we have to look more thoroughly.
142 // Example: λx.x - λz.x
143 if (!d1->has_dep(Dep::Var) && !d2->has_dep(Dep::Var) && d1 == d2) return true;
144
145 auto h1 = d1->isa_mut<Hole>();
146 auto h2 = d2->isa_mut<Hole>();
147
148 if ((!h1 && !d1->is_set()) || (!h2 && !d2->is_set())) return fail<mode>();
149
150 if (mode == Check) {
151 if (h1)
152 return h1->set(d2), true;
153 else if (h2)
154 return h2->set(d1), true;
155 }
156
157 auto muts = std::array<Def*, 2>{d1->isa_mut(), d2->isa_mut()};
158 auto& [mut1, mut2] = muts;
159
160 if (mut1 && mut2 && mut1 == mut2) return true;
161
162 // Globals are HACKs and require additionaly HACKs:
163 // Unless they are pointer equal (above) always consider them unequal.
164 if (d1->isa<Global>() || d2->isa<Global>()) return false;
165
166 bool redo = false;
167 for (size_t i = 0; i != 2; ++i) {
168 auto& mut = muts[i];
169 if (!mut || !mut->is_set()) continue;
170 size_t other = (i + 1) % 2;
171
172 if (auto imm = mut->zonk_mut())
173 mut = nullptr, ds[i] = imm, redo = true;
174 else if (auto [i, ins] = binders_.emplace(mut, ds[other]); !ins)
175 return i->second == ds[other];
176 }
177
178 return redo ? alpha<mode>(d1, d2) : alpha_internal<mode>(d1, d2);
179}
180
181template<Checker::Mode mode> bool Checker::alpha_internal(const Def* d1, const Def* d2) {
182 if (d1->type() && d2->type() && !alpha_<mode>(d1->type(), d2->type())) return fail<mode>();
183 if (d1->isa<Top>() || d2->isa<Top>()) return mode == Check;
184 if (mode == Test && (d1->isa_mut<Hole>() || d2->isa_mut<Hole>())) return fail<mode>();
185 if (!alpha_<mode>(d1->arity(), d2->arity())) return fail<mode>();
186
187 auto check1 = [this](const Arr* arr, const Def* d) {
188 auto body = arr->reduce(world().lit_idx(1, 0))->zonk();
189 if (!alpha_<mode>(body, d)) return fail<mode>();
190 if (auto mut_arr = arr->isa_mut<Arr>()) mut_arr->unset()->set(world().lit_nat_1(), body->zonk());
191 return true;
192 };
193
194 if (mode == Mode::Check) {
195 if (auto arr = d1->isa<Arr>();
196 arr && arr->is_set() && arr->shape()->zonk() == world().lit_nat_1() && !d2->isa<Arr>())
197 return check1(arr, d2);
198
199 if (auto arr = d2->isa<Arr>();
200 arr && arr->is_set() && arr->shape()->zonk() == world().lit_nat_1() && !d1->isa<Arr>())
201 return check1(arr, d1);
202 }
203
204 if (auto prod = d1->isa<Prod>()) {
205 size_t a = prod->num_ops();
206 for (size_t i = 0; i != a; ++i)
207 if (!alpha_<mode>(prod->op(i), d2->proj(a, i))) return fail<mode>();
208 return true;
209 } else if (auto seq = d1->isa<Seq>()) {
210 if (seq->node() != d2->node()) return fail<mode>();
211
212 if (auto a = seq->isa_lit_arity()) {
213 for (size_t i = 0; i != *a; ++i)
214 if (!alpha_<mode>(seq->proj(*a, i), d2->proj(*a, i))) return fail<mode>();
215 return true;
216 }
217
218 auto check_arr = [this](Arr* mut_arr, const Arr* imm_arr) {
219 if (!alpha_<mode>(mut_arr->shape(), imm_arr->shape())) return fail<mode>();
220
221 auto mut_shape = mut_arr->shape()->zonk();
222 auto mut_body = mut_arr->reduce(world().top(world().type_idx(mut_arr->shape())));
223 if (!alpha_<mode>(mut_body, imm_arr->body())) return fail<mode>();
224
225 mut_arr->unset()->set(mut_shape, mut_body->zonk());
226 return true;
227 };
228
229 if (mode == Mode::Check) {
230 if (auto mut_arr = d1->isa_mut<Arr>(); mut_arr && mut_arr->is_set()) {
231 if (auto imm_arr = d2->isa_imm<Arr>()) return check_arr(mut_arr, imm_arr);
232 }
233 if (auto mut_arr = d2->isa_mut<Arr>(); mut_arr && mut_arr->is_set()) {
234 if (auto imm_arr = d1->isa_imm<Arr>()) return check_arr(mut_arr, imm_arr);
235 }
236 }
237 } else if (auto umax = d1->isa<UMax>(); umax && umax->has_dep(Dep::Hole) && !d2->isa<UMax>()) {
238 // .umax(a, ?) == x => .umax(a, x)
239 for (auto op : umax->ops())
240 if (auto inf = op->isa_mut<Hole>(); inf && !inf->is_set()) inf->set(d2);
241 d1 = umax->rebuild(umax->type(), umax->ops());
242 }
243
244 if (d1->node() != d2->node() || d1->flags() != d2->flags() || d1->num_ops() != d2->num_ops()) return fail<mode>();
245
246 if (auto var1 = d1->isa<Var>()) {
247 auto var2 = d2->as<Var>();
248 if (auto i = binders_.find(var1->mut()); i != binders_.end()) return i->second == var2->mut();
249 if (auto i = binders_.find(var2->mut()); i != binders_.end()) return fail<mode>(); // var2 is bound
250 // both var1 and var2 are free: OK, when they are the same or in Check mode
251 return var1 == var2 || mode == Check;
252 }
253
254 for (size_t i = 0, e = d1->num_ops(); i != e; ++i)
255 if (!alpha_<mode>(d1->op(i), d2->op(i))) return fail<mode>();
256 return true;
257}
258
259const Def* Checker::assignable_(const Def* type, const Def* val) {
260 auto val_ty = val->type()->zonk();
261 if (type == val_ty) return val;
262
263 auto& w = world();
264 if (auto sigma = type->isa<Sigma>()) {
265 if (!alpha_<Check>(type->arity(), val_ty->arity())) return fail();
266
267 size_t a = sigma->num_ops();
268 auto red = sigma->reduce(val);
269 auto new_ops = absl::FixedArray<const Def*>(red.size());
270 for (size_t i = 0; i != a; ++i) {
271 auto new_val = assignable_(red[i], val->proj(a, i));
272 if (new_val)
273 new_ops[i] = new_val;
274 else
275 return fail();
276 }
277 return w.tuple(new_ops);
278 } else if (auto arr = type->isa<Arr>()) {
279 if (!alpha_<Check>(type->arity(), val_ty->arity())) return fail();
280 type = type->zonk(); // TODO hack
281
282 // TODO ack sclarize threshold
283 if (auto a = Lit::isa(arr->arity())) {
284 auto new_ops = absl::FixedArray<const Def*>(*a);
285 for (size_t i = 0; i != *a; ++i) {
286 auto new_val = assignable_(arr->proj(*a, i), val->proj(*a, i));
287 if (new_val)
288 new_ops[i] = new_val;
289 else
290 return fail();
291 }
292 return w.tuple(new_ops);
293 }
294 } else if (auto inj = val->isa<Inj>()) {
295 if (auto new_val = assignable_(type, inj->value())) return w.inj(type, new_val);
296 return fail();
297 } else if (auto uniq = val->type()->isa<Uniq>()) {
298 if (auto new_val = assignable(type, uniq->inhabitant())) return new_val;
299 return fail();
300 }
301
302 return alpha_<Check>(type, val_ty) ? val : fail();
303}
304
306 if (defs.empty()) return nullptr;
307 auto first = defs.front();
308 for (size_t i = 1, e = defs.size(); i != e; ++i)
309 if (!alpha<Test>(first, defs[i])) return nullptr;
310 return first;
311}
312
313/*
314 * infer & check
315 */
316
317const Def* Arr::check(size_t, const Def* def) { return def; } // TODO
318
319const Def* Arr::check() {
320 auto t = body()->unfold_type();
322 error(type()->loc(), "declared sort '{}' of array does not match inferred one '{}'", type(), t);
323 return t;
324}
325
327 auto elems = absl::FixedArray<const Def*>(ops.size());
328 for (size_t i = 0, e = ops.size(); i != e; ++i) elems[i] = ops[i]->unfold_type();
329 return world.sigma(elems);
330}
331
333 auto elems = absl::FixedArray<const Def*>(ops.size());
334 for (size_t i = 0, e = ops.size(); i != e; ++i) elems[i] = ops[i]->unfold_type();
335 return w.umax<Sort::Kind>(elems);
336}
337
338const Def* Sigma::check(size_t, const Def* def) { return def; } // TODO
339
341 auto t = infer(world(), ops());
342 if (t != type()) {
343 // TODO HACK
345 return t;
346 else {
347 world().WLOG(
348 "incorrect type '{}' for '{}'. Correct one would be: '{}'. I'll keep this one nevertheless due to "
349 "bugs in clos-conv",
350 type(), this, t);
351 return type();
352 }
353 }
354 return t;
355}
356
357const Def* Pi::infer(const Def* dom, const Def* codom) {
358 auto& w = dom->world();
359 return w.umax<Sort::Kind>({dom->unfold_type(), codom->unfold_type()});
360}
361
362const Def* Pi::check(size_t, const Def* def) { return def; }
363
364const Def* Pi::check() {
365 auto t = infer(dom(), codom());
367 error(type()->loc(), "declared sort '{}' of function type does not match inferred one '{}'", type(), t);
368 return t;
369}
370
371const Def* Lam::check(size_t i, const Def* def) {
372 if (i == 0) {
373 if (auto filter = Checker::assignable(world().type_bool(), def)) return filter;
374 throw Error().error(filter()->loc(), "filter '{}' of lambda is of type '{}' but must be of type 'Bool'",
375 filter(), filter()->type());
376 } else if (i == 1) {
377 if (auto body = Checker::assignable(codom(), def)) return body;
378 throw Error()
379 .error(def->loc(), "body of function is not assignable to declared codomain")
380 .note(def->loc(), "body: '{}'", def)
381 .note(def->loc(), "type: '{}'", def->type())
382 .note(codom()->loc(), "codomain: '{}'", codom());
383 }
384 fe::unreachable();
385}
386
387#ifndef DOXYGEN
388template bool Checker::alpha_<Checker::Check>(const Def*, const Def*);
389template bool Checker::alpha_<Checker::Test>(const Def*, const Def*);
390#endif
391
392} // namespace mim
const Def * check() final
After all Def::ops have been Def::set, this method will be invoked to check the type of this mutable.
Definition check.cpp:319
static const Def * is_uniform(Defs defs)
Yields defs.front() if all defs are Checker::alpha-equivalent (Mode::Test), and nullptr otherwise.
Definition check.cpp:305
static bool alpha(const Def *d1, const Def *d2)
Definition check.h:66
World & world()
Definition check.h:55
@ Test
In Mode::Test, no type inference is happening and Holes will not be touched.
Definition check.h:63
@ Check
In Mode::Check, type inference is happening and Holes will be resolved, if possible.
Definition check.h:60
static const Def * assignable(const Def *type, const Def *value)
Can value be assigned to something of type type?
Definition check.h:72
Base class for all Defs.
Definition def.h:203
bool is_set() const
Yields true if empty or the last op is set.
Definition def.cpp:268
const Def * proj(nat_t a, nat_t i) const
Similar to World::extract while assuming an arity of a, but also works on Sigmas and Arrays.
Definition def.cpp:492
Def * set(size_t i, const Def *)
Successively set from left to right.
Definition def.cpp:241
Defs deps() const noexcept
Definition def.cpp:405
const Def * zonk() const
If Holes have been filled, reconstruct the program without them.
Definition check.cpp:39
World & world() const noexcept
Definition def.cpp:377
virtual const Def * check()
After all Def::ops have been Def::set, this method will be invoked to check the type of this mutable.
Definition def.h:530
Def * set_type(const Def *)
Update type.
Definition def.cpp:253
constexpr auto ops() const noexcept
Definition def.h:266
T * isa_mut() const
If this is mutable, it will cast constness away and perform a dynamic_cast to T.
Definition def.h:434
const Def * op(size_t i) const noexcept
Definition def.h:269
const Def * var(nat_t a, nat_t i) noexcept
Definition def.h:384
const Def * unfold_type() const
Yields the type of this Def and builds a new Type (UInc n) if necessary.
Definition def.cpp:384
virtual const Def * immutabilize()
Tries to make an immutable from a mutable.
Definition def.h:504
const Def * zonk_mut()
Zonks all ops of this mutable and tries to immutabilize it; if that succeeds, returns the immutable.
Definition check.cpp:41
const Def * type() const noexcept
Yields the "raw" type of this Def (maybe nullptr).
Definition def.h:247
Loc loc() const
Definition def.h:454
Def * unset()
Unsets all Def::ops; works even if they are not set at all or only partially set.
Definition def.cpp:259
constexpr size_t num_ops() const noexcept
Definition def.h:270
Error & error(Loc loc, const char *s, Args &&... args)
Definition dbg.h:71
Error & note(Loc loc, const char *s, Args &&... args)
Definition dbg.h:73
This node is a hole in the IR that is inferred by its context later on.
Definition check.h:10
std::pair< Hole *, const Def * > find()
Transitively walks up Holes until the last one while path-compressing everything.
Definition check.cpp:70
Hole * set(const Def *op)
Definition check.h:25
const Def * tuplefy(nat_t)
If unset, explode to Tuple.
Definition check.cpp:95
const Def * filter() const
Definition lam.h:118
const Pi * type() const
Definition lam.h:126
const Def * body() const
Definition lam.h:119
const Def * codom() const
Definition lam.h:128
static std::optional< T > isa(const Def *def)
Definition def.h:733
const Def * check() final
After all Def::ops have been Def::set, this method will be invoked to check the type of this mutable.
Definition check.cpp:364
static const Def * infer(const Def *dom, const Def *codom)
Definition check.cpp:357
const Def * dom() const
Definition lam.h:32
const Def * codom() const
Definition lam.h:33
Def(World *, Node, const Def *type, Defs ops, flags_t flags)
Constructor for an immutable Def.
Definition def.cpp:23
Recursively rebuilds part of a program into the provided World w.r.t. Rewriter::map.
Definition rewrite.h:11
virtual const Def * rewrite(const Def *)
Definition rewrite.cpp:12
const Def * body() const
Definition tuple.h:86
Def(World *, Node, const Def *type, Defs ops, flags_t flags)
Constructor for an immutable Def.
Definition def.cpp:23
friend class World
Definition tuple.h:57
const Def * check() final
After all Def::ops have been Def::set, this method will be invoked to check the type of this mutable.
Definition check.cpp:340
static const Def * infer(World &, Defs)
Definition check.cpp:332
friend class World
Definition tuple.h:74
static const Def * infer(World &, Defs)
Definition check.cpp:326
const Def * op(trait o, const Def *type)
Definition core.h:33
Definition ast.h:14
View< const Def * > Defs
Definition def.h:49
u64 nat_t
Definition types.h:43
Vector< const Def * > DefVec
Definition def.h:50
@ Var
Definition def.h:100
@ Hole
Definition def.h:101
void error(Loc loc, const char *f, Args &&... args)
Definition dbg.h:122
TExt< true > Top
Definition lattice.h:159
@ Kind
Definition def.h:94
@ Arr
Definition def.h:85
@ Global
Definition def.h:85
@ Var
Definition def.h:85
@ Inj
Definition def.h:85
@ Hole
Definition def.h:85
@ Sigma
Definition def.h:85
@ Uniq
Definition def.h:85
@ UMax
Definition def.h:85