21bool should_flatten(
const Def* T) {
23 if (T->isa<Sigma>())
return true;
33 if (
auto lit = T->arity()->isa<Lit>(); lit &&
lit->get<
u64>() <= 2) {
34 if (
auto arr = T->isa<Arr>(); arr && arr->body()->isa<Pi>())
return lit->get<
u64>() > 1;
40DefVec flatten_ty(
const Def* T) {
42 if (should_flatten(T)) {
43 for (
auto P : T->projs()) {
44 auto inner_types = flatten_ty(P);
45 types.insert(types.end(), inner_types.begin(), inner_types.end());
54DefVec flatten_def(
const Def* def) {
56 if (should_flatten(def->type())) {
57 for (
auto P : def->projs()) {
58 auto inner_defs = flatten_def(P);
59 defs.insert(defs.end(), inner_defs.begin(), inner_defs.end());
71const Def* Reshape::rewrite_def(
const Def* def) {
72 if (
auto i = old2new_.find(def); i != old2new_.end())
return i->second;
73 auto new_def = rewrite_def_(def);
74 old2new_[def] = new_def;
78const Def* Reshape::rewrite_def_(
const Def* def) {
80 switch (def->
node()) {
94 ss << def <<
" : " << def->
type() <<
" [" << def->
node_name() <<
"]";
95 std::string str = ss.str();
98 if (def->isa<
Var>())
world().ELOG(
"Var: {}", def);
99 assert(!def->isa<
Var>());
101 if (
auto app = def->isa<
App>()) {
102 auto callee = rewrite_def(app->callee());
103 auto arg = rewrite_def(app->arg());
105 world().DLOG(
"callee: {} : {}", callee, callee->type());
108 auto reshaped_arg = reshape(arg);
109 world().DLOG(
"reshape arg {} : {}", arg, arg->type());
110 world().DLOG(
"into arg {} : {}", reshaped_arg, reshaped_arg->type());
111 auto new_app =
world().
app(callee, reshaped_arg);
113 }
else if (
auto lam = def->
isa_mut<Lam>()) {
114 world().DLOG(
"rewrite_def lam {} : {}", def, def->
type());
115 auto new_lam = reshape_lam(lam);
116 world().DLOG(
"rewrote lam {} : {}", def, def->
type());
117 world().DLOG(
"into lam {} : {}", new_lam, new_lam->type());
119 }
else if (
auto tuple = def->isa<Tuple>()) {
120 auto elements =
DefVec(tuple->ops(), [&](
const Def* op) { return rewrite_def(op); });
123 auto new_ops =
DefVec(def->
num_ops(), [&](
auto i) { return rewrite_def(def->op(i)); });
125 auto new_type = rewrite_def(def->
type());
126 auto new_def = def->
rebuild(new_type, new_ops);
131Lam* Reshape::reshape_lam(Lam* old_lam) {
132 if (!old_lam->is_set()) {
133 world().DLOG(
"reshape_lam: {} is not a set", old_lam);
136 auto pi_ty = old_lam->type();
137 auto new_ty = reshape_type(pi_ty)->as<
Pi>();
140 if (*old_lam->sym() ==
"main") {
143 new_lam = old_lam->stub(new_ty);
144 if (!old_lam->is_external()) new_lam->debug_suffix(
"_reshape");
145 old2new_[old_lam] = new_lam;
148 world().DLOG(
"Reshape lam: {} : {}", old_lam, pi_ty);
149 world().DLOG(
" to: {} : {}", new_lam, new_ty);
153 auto new_arg = new_lam->
var();
157 auto reformed_new_arg = reshape(new_arg, old_lam->var()->type());
158 world().DLOG(
"var {} : {}", old_lam->var(), old_lam->var()->type());
159 world().DLOG(
"new var {} : {}", new_arg, new_arg->type());
160 world().DLOG(
"reshaped new_var {} : {}", reformed_new_arg, reformed_new_arg->type());
161 world().DLOG(
"{}", old_lam->var()->type());
162 world().DLOG(
"{}", reformed_new_arg->type());
163 old2new_[old_lam->var()] = reformed_new_arg;
168 auto new_body = rewrite_def(old_lam->body());
169 auto new_filter = rewrite_def(old_lam->filter());
171 new_lam->set(new_filter, new_body);
173 if (old_lam->is_external()) old_lam->transfer_external(new_lam);
175 world().DLOG(
"finished transforming: {} : {}", new_lam, new_ty);
179const Def* Reshape::reshape_type(
const Def* T) {
180 if (
auto pi = T->isa<Pi>()) {
181 auto new_dom = reshape_type(pi->dom());
182 auto new_cod = reshape_type(pi->codom());
183 return world().
pi(new_dom, new_cod);
184 }
else if (
auto sigma = T->isa<Sigma>()) {
185 auto flat_types = flatten_ty(sigma);
186 auto new_types =
DefVec(flat_types.size());
187 std::ranges::transform(flat_types, new_types.begin(), [&](
auto T) { return reshape_type(T); });
189 const Def*
mem =
nullptr;
191 for (
auto i = new_types.begin(); i != new_types.end(); i++)
192 if (is_mem_ty(*i) && !
mem)
mem = *i;
194 new_types.erase(std::remove_if(new_types.begin(), new_types.end(), is_mem_ty), new_types.end());
196 if (
mem) new_types.insert(new_types.begin(),
mem);
197 auto reshaped_type =
world().
sigma(new_types);
198 return reshaped_type;
200 if (new_types.size() == 0)
return world().
sigma();
201 if (new_types.size() == 1)
return new_types[0];
202 const Def*
mem =
nullptr;
203 const Def* ret =
nullptr;
205 for (
auto i = new_types.begin(); i != new_types.end(); i++)
206 if (is_mem_ty(*i) && !
mem)
mem = *i;
208 new_types.erase(std::remove_if(new_types.begin(), new_types.end(), is_mem_ty), new_types.end());
210 if (new_types.back()->isa<Pi>()) {
211 ret = new_types.back();
212 new_types.pop_back();
225const Def* Reshape::reshape(
DefVec& defs,
const Def* T,
const Def*
mem) {
226 auto&
world = T->world();
227 if (should_flatten(T)) {
228 auto tuples = T->projs([&](
auto P) {
return reshape(defs, P,
mem); });
233 assert(
mem !=
nullptr &&
"Reshape: mems not found");
237 assert(defs.size() > 0 &&
"Reshape: not enough arguments");
239 defs.erase(defs.begin());
240 }
while (is_mem_ty(def->type()));
243 if (!def->type()->isa<Pi>()) {
244 if (!
Check::alpha(def->type(), T))
world.ELOG(
"reconstruct T {} from def {}", T, def->type());
245 assert(
Check::alpha(def->type(), T) &&
"Reshape: argument type mismatch");
251const Def* Reshape::reshape(
const Def* def,
const Def* target) {
252 def->world().DLOG(
"reshape:\n {} =>\n {}", def->type(), target);
253 auto flat_defs = flatten_def(def);
254 const Def*
mem =
nullptr;
256 for (
auto i = flat_defs.begin(); i != flat_defs.end(); i++)
257 if (is_mem_ty((*i)->type()) && !
mem)
mem = *i;
258 def->world().DLOG(
"mem: {}",
mem);
259 return reshape(flat_defs, target,
mem);
266const Def* Reshape::reshape(
const Def* def) {
267 auto flat_defs = flatten_def(def);
268 if (flat_defs.size() == 1)
return flat_defs[0];
271 const Def*
mem =
nullptr;
273 for (
auto i = flat_defs.begin(); i != flat_defs.end(); i++)
274 if (is_mem_ty((*i)->type()) && !
mem)
mem = *i;
277 std::remove_if(flat_defs.begin(), flat_defs.end(), [](
const Def* def) { return is_mem_ty(def->type()); }),
280 if (
mem) flat_defs.insert(flat_defs.begin(),
mem);
285 const Def*
mem =
nullptr;
286 const Def* ret =
nullptr;
288 for (
auto i = flat_defs.begin(); i != flat_defs.end(); i++)
289 if (is_mem_ty((*i)->type()) && !
mem)
mem = *i;
292 std::remove_if(flat_defs.begin(), flat_defs.end(), [](
const Def* def) { return is_mem_ty(def->type()); }),
294 if (flat_defs.back()->type()->isa<Pi>()) {
295 ret = flat_defs.back();
296 flat_defs.pop_back();
static bool alpha(Ref d1, Ref d2)
Are d1 and d2 α-equivalent?
std::string_view node_name() const
T * isa_mut() const
If this is *mut*able, it will cast constness away and perform a dynamic_cast to T.
Ref rebuild(World &w, Ref type, Defs ops) const
Def::rebuilds this Def while using new_op as substitute for its i'th Def::op.
Ref var(Ref type, Def *mut)
Ref app(Ref callee, Ref arg)
const Pi * pi(Ref dom, Ref codom, bool implicit=false)
void enter() override
Fall-through to rewrite_def which falls through to rewrite_lam.
Vector< const Def * > DefVec