13bool Scalarize::should_expand(Lam* lam) {
15 if (
auto i = tup2sca_.find(lam); i != tup2sca_.end() && i->second && i->second == lam)
return false;
17 auto pi = lam->type();
24Lam* Scalarize::make_scalar(Ref def) {
25 auto tup_lam = def->isa_mut<
Lam>();
27 if (
auto i = tup2sca_.find(tup_lam); i != tup2sca_.end())
return i->second;
32 for (
size_t i = 0, e = tup_lam->num_doms(); i != e; ++i) {
33 auto n =
flatten(types, tup_lam->dom(i),
false);
35 todo |= n != 1 || types.back() != tup_lam->dom(i);
38 if (!todo)
return tup2sca_[tup_lam] = tup_lam;
41 auto sca_lam = tup_lam->
stub(cn);
42 if (eta_exp_) eta_exp_->
new2old(sca_lam, tup_lam);
44 world().DLOG(
"type {} ~> {}", tup_lam->type(), cn);
46 auto tuple = DefVec(arg_sz.at(i), [&](auto) { return sca_lam->var(n++); });
47 return unflatten(tuple, tup_lam->dom(i),
false);
49 sca_lam->
set(tup_lam->reduce(new_vars));
50 tup2sca_[sca_lam] = sca_lam;
51 tup2sca_.emplace(tup_lam, sca_lam);
52 world().DLOG(
"lambda {} : {} ~> {} : {}", tup_lam, tup_lam->type(), sca_lam, sca_lam->type());
58 if (
auto app = def->isa<
App>()) {
59 Ref sca_callee = app->callee();
61 if (
auto tup_lam = sca_callee->
isa_mut<
Lam>(); should_expand(tup_lam)) {
62 sca_callee = make_scalar(tup_lam);
64 }
else if (
auto proj = sca_callee->isa<
Extract>()) {
65 auto tuple = proj->tuple()->isa<
Tuple>();
66 if (tuple && std::all_of(tuple->ops().begin(), tuple->ops().end(), [&](
Ref op) {
67 return should_expand(op->isa_mut<Lam>());
69 auto new_tuple = w.tuple(
DefVec(tuple->num_ops(), [&](
auto i) { return make_scalar(tuple->op(i)); }));
70 sca_callee = w.extract(new_tuple, proj->index());
71 w.DLOG(
"Expand tuple: {, } ~> {, }", tuple->ops(), new_tuple->ops());
75 if (sca_callee != app->callee()) {
77 flatten(new_args, app->arg(),
false);
78 return world().app(sca_callee, new_args);
Def * set(size_t i, const Def *def)
Successively set from left to right.
T * isa_mut() const
If this is *mut*able, it will cast constness away and perform a dynamic_cast to T.
const T * isa_imm() const
void new2old(Lam *new_lam, Lam *old_lam)
static const Pi * isa_cn(Ref d)
Is this a continuation - i.e. is the Pi::codom mim::Bottom?
Helper class to retrieve Infer::arg if present.
Data constructor for a Sigma.
const Def * flatten(const Def *def)
Flattens a sigma/array/pack/tuple.
Vector< const Def * > DefVec
const Def * unflatten(const Def *def, const Def *type)
Applies the reverse transformation on a Pack / Tuple, given the original type.
Lam * isa_workable(Lam *lam)
These are Lams that are neither nullptr, nor Lam::is_external, nor Lam::is_unset.
Vector(I, I, A=A()) -> Vector< typename std::iterator_traits< I >::value_type, Default_Inlined_Size< typename std::iterator_traits< I >::value_type >, A >