10class InferRewriter :
public Rewriter {
12 InferRewriter(World& world)
15 Ref
rewrite(Ref old_def)
override {
32 for (
auto infer = res->isa_mut<
Infer>(); infer && infer->
op(); infer = res->
isa_mut<
Infer>()) res = infer->
op();
41 assert((!res->isa<
Infer>() || res != res->
op(0)) &&
"an Infer shouldn't point to itself");
46 if (res->isa<
Tuple>() || res->isa<
Type>()) {
48 bool update = new_type != res->type();
50 auto new_ops =
DefVec(res->num_ops(), [res, &
update](
size_t i) {
51 auto r = Ref::refer(res->op(i));
52 update |= r != res->op(i);
56 if (
update)
return res->rebuild(new_type, new_ops);
63 if (std::ranges::any_of(refs, [](
auto pref) {
return should_eliminate(*pref); })) {
64 auto&
world = (*refs.front())->world();
65 InferRewriter rw(
world);
66 for (
size_t i = 0, e = refs.size(); i != e; ++i) {
68 *refs[i] = ref->has_dep(
Dep::Infer) ? rw.rewrite(ref) : ref;
78#ifdef MIM_ENABLE_CHECKS
79template<
bool infer>
bool Check::fail() {
80 if (infer &&
world().flags().break_on_alpha_unequal) fe::breakpoint();
85template<
bool infer>
bool Check::alpha_(Ref r1, Ref r2) {
89 if (!d1 && !d2)
return true;
90 if (!d1 || !d2)
return fail<infer>();
95 if (!d1->has_dep(
Dep::Var) && !d2->has_dep(
Dep::Var) && d1 == d2)
return true;
96 auto mut1 = d1->isa_mut();
97 auto mut2 = d2->isa_mut();
98 if (mut1 && mut2 && mut1 == mut2)
return true;
101 if (d1->isa<Global>() || d2->isa<Global>())
return false;
104 if (
auto [i, ins] = done_.emplace(mut1, d2); !ins)
return i->second == d2;
107 if (
auto [i, ins] = done_.emplace(mut2, d1); !ins)
return i->second == d1;
110 auto i1 = d1->isa_mut<
Infer>();
111 auto i2 = d2->isa_mut<
Infer>();
113 if ((!i1 && !d1->is_set()) || (!i2 && !d2->is_set()))
return fail<infer>();
118 if (i1->rank() < i2->rank()) std::swap(i1, i2);
120 if (i1->rank() == i2->rank()) ++i1->rank();
132 if ((d1->isa<Lit>() && !d2->isa<Lit>())
133 || (!d1->isa<UMax>() && d2->isa<UMax>())
134 || (d1->gid() > d2->gid()))
137 return alpha_internal<infer>(d1, d2);
140template<
bool infer>
bool Check::alpha_internal(Ref d1, Ref d2) {
141 if (!alpha_<infer>(d1->type(), d2->type()))
return fail<infer>();
142 if (d1->isa<
Top>() || d2->isa<
Top>())
return infer;
143 if (!infer && (d1->isa_mut<
Infer>() || d2->isa_mut<
Infer>()))
return fail<infer>();
144 if (!alpha_<infer>(d1->arity(), d2->arity()))
return fail<infer>();
147 if (
auto mut1 = d1->isa_mut())
assert_emplace(vars_, mut1, d2->isa_mut());
148 if (
auto mut2 = d2->isa_mut())
assert_emplace(vars_, mut2, d1->isa_mut());
150 if (
auto ts = d1->isa<Tuple, Sigma>()) {
151 size_t a = ts->num_ops();
152 for (
size_t i = 0; i !=
a; ++i)
153 if (!alpha_<infer>(ts->op(i), d2->proj(a, i)))
return fail<infer>();
155 }
else if (
auto pa = d1->isa<Pack, Arr>()) {
156 if (pa->node() == d2->node())
return alpha_<infer>(pa->ops().back(), d2->ops().back());
157 if (
auto a = pa->isa_lit_arity()) {
158 for (
size_t i = 0; i != *
a; ++i)
159 if (!alpha_<infer>(pa->proj(*a, i), d2->proj(*a, i)))
return fail<infer>();
162 }
else if (
auto umax = d1->isa<UMax>(); umax &&
umax->has_dep(
Dep::Infer) && !d2->isa<UMax>()) {
164 for (
auto op :
umax->ops())
165 if (
auto inf =
op->
isa_mut<
Infer>(); inf && !inf->is_set()) inf->set(d2);
169 if (d1->node() != d2->node() || d1->flags() != d2->flags() || d1->num_ops() != d2->num_ops())
return fail<infer>();
171 if (
auto var1 = d1->isa<
Var>()) {
172 auto var2 = d2->as<
Var>();
173 if (
auto i = vars_.find(var1->mut()); i != vars_.end())
return i->second == var2->mut();
174 if (
auto i = vars_.find(var2->mut()); i != vars_.end())
return fail<infer>();
176 return var1 == var2 || infer;
179 for (
size_t i = 0, e = d1->num_ops(); i != e; ++i)
180 if (!alpha_<infer>(d1->op(i), d2->op(i)))
return fail<infer>();
184bool Check::assignable_(Ref type, Ref val) {
186 if (type == val_ty)
return true;
188 if (
auto infer = val->isa_mut<
Infer>())
return alpha_<true>(type, infer->type());
190 if (
auto sigma = type->isa<Sigma>()) {
191 if (!alpha_<true>(type->arity(), val_ty->arity()))
return fail<true>();
193 size_t a = sigma->num_ops();
194 auto red = sigma->reduce(val);
195 for (
size_t i = 0; i !=
a; ++i)
196 if (!assignable_(red[i], val->proj(a, i)))
return fail<true>();
198 }
else if (
auto arr = type->isa<Arr>()) {
199 if (!alpha_<true>(type->arity(), val_ty->arity()))
return fail<true>();
201 if (
auto a =
Lit::isa(arr->arity())) {
202 for (
size_t i = 0; i != *
a; ++i)
203 if (!assignable_(arr->proj(*a, i), val->proj(*a, i)))
return fail<true>();
206 }
else if (
auto vel = val->isa<Vel>()) {
207 return assignable_(type, vel->value());
210 return alpha_<true>(type, val_ty);
214 if (defs.empty())
return nullptr;
215 auto first = defs.front();
216 for (
size_t i = 1, e = defs.size(); i != e; ++i)
228 error(
type()->
loc(),
"declared sort '{}' of array does not match inferred one '{}'",
type(), t);
233 if (
ops.size() == 0)
return w.type<1>();
234 auto kinds =
DefVec(
ops.size(), [
ops](
size_t i) { return ops[i]->unfold_type(); });
246 "incorrect type '{}' for '{}'. Correct one would be: '{}'. I'll keep this one nevertheless due to "
254 error(
filter()->
loc(),
"filter '{}' of lambda is of type '{}' but must be of type 'Bool'",
filter(),
259 .
error(
body()->
loc(),
"body of function is not assignable to declared codomain")
274 error(
type()->
loc(),
"declared sort '{}' of function type does not match inferred one '{}'",
type(), t);
279template bool Check::alpha_<true>(
Ref,
Ref);
280template bool Check::alpha_<false>(
Ref,
Ref);
static bool alpha(Ref d1, Ref d2)
Are d1 and d2 α-equivalent?
static bool assignable(Ref type, Ref value)
Can value be assigned to sth of type?
static Ref is_uniform(Defs defs)
Yields defs.front(), if all defs are Check::alpha-equivalent (infer = false) and nullptr otherwise.
const Def * op(size_t i) const
Def * set_type(const Def *)
T * isa_mut() const
If this is *mut*able, it will cast constness away and perform a dynamic_cast to T.
const Def * unfold_type() const
Yields the type of this Def and builds a new Type (UInc n) if necessary.
void update()
Resolves Infers of this Def's type.
Def * reset(size_t i, const Def *def)
Successively reset from left to right.
Error & error(Loc loc, const char *s, Args &&... args)
Error & note(Loc loc, const char *s, Args &&... args)
This node is a hole in the IR that is inferred by its context later on.
static bool eliminate(Vector< Ref * >)
Eliminate Infers that may have been resolved in the meantime by rebuilding.
static bool should_eliminate(Ref def)
static const Def * find(const Def *)
Union-Find to unify Infer nodes.
static std::optional< T > isa(Ref def)
static Ref infer(Ref dom, Ref codom)
Helper class to retrieve Infer::arg if present.
static const Def * refer(const Def *def)
Retrieves Infer::arg from def.
static Ref infer(World &, Defs)
This is a thin wrapper for std::span<T, N> with the following additional features:
Data constructor for a Sigma.
This is a thin wrapper for absl::InlinedVector<T, N, / A> which in turn is a drop-in replacement for ...
The World represents the whole program and manages creation of MimIR nodes (Defs).
Ref op(trait o, Ref type)
Vector< const Def * > DefVec
auto assert_emplace(C &container, Args &&... args)
Invokes emplace on container, asserts that insertion actually happened, and returns the iterator.
void error(Loc loc, const char *f, Args &&... args)
DefVec rewrite(Def *mut, Ref arg)