21bool should_flatten(
const Def* T) {
23 if (T->isa<
Sigma>())
return true;
33 if (
auto lit = T->arity()->isa<
Lit>(); lit &&
lit->get<
u64>() <= 2) {
34 if (
auto arr = T->isa<
Arr>(); arr && arr->body()->isa<
Pi>())
return lit->get<
u64>() > 1;
40DefVec flatten_ty(
const Def* T) {
42 if (should_flatten(T)) {
43 for (
auto P : T->projs()) {
44 auto inner_types = flatten_ty(P);
45 types.insert(types.end(), inner_types.begin(), inner_types.end());
54DefVec flatten_def(
const Def* def) {
56 if (should_flatten(def->type())) {
57 for (
auto P : def->projs()) {
58 auto inner_defs = flatten_def(P);
59 defs.insert(defs.end(), inner_defs.begin(), inner_defs.end());
71const Def* Reshape::rewrite_def(
const Def* def) {
72 if (
auto i = old2new_.find(def); i != old2new_.end())
return i->second;
73 auto new_def = rewrite_def_(def);
74 old2new_[def] = new_def;
78const Def* Reshape::rewrite_def_(
const Def* def) {
80 switch (def->
node()) {
95 ss << def <<
" : " << def->
type() <<
" [" << def->
node_name() <<
"]";
96 std::string str = ss.str();
99 if (def->isa<
Var>())
world().ELOG(
"Var: {}", def);
100 assert(!def->isa<
Var>());
102 if (
auto app = def->isa<
App>()) {
103 auto callee = rewrite_def(app->callee());
104 auto arg = rewrite_def(app->arg());
106 world().DLOG(
"callee: {} : {}", callee, callee->type());
109 auto reshaped_arg = reshape(arg);
110 world().DLOG(
"reshape arg {} : {}", arg, arg->type());
111 world().DLOG(
"into arg {} : {}", reshaped_arg, reshaped_arg->type());
112 auto new_app =
world().
app(callee, reshaped_arg);
115 world().DLOG(
"rewrite_def lam {} : {}", def, def->
type());
116 auto new_lam = reshape_lam(lam);
117 world().DLOG(
"rewrote lam {} : {}", def, def->
type());
118 world().DLOG(
"into lam {} : {}", new_lam, new_lam->type());
120 }
else if (
auto tuple = def->isa<
Tuple>()) {
121 auto elements =
DefVec(tuple->ops(), [&](
const Def* op) { return rewrite_def(op); });
124 auto new_ops =
DefVec(def->
num_ops(), [&](
auto i) { return rewrite_def(def->op(i)); });
126 auto new_type = rewrite_def(def->
type());
127 auto new_def = def->
rebuild(new_type, new_ops);
132Lam* Reshape::reshape_lam(
Lam* old_lam) {
133 if (!old_lam->is_set()) {
134 world().DLOG(
"reshape_lam: {} is not a set", old_lam);
137 auto pi_ty = old_lam->type();
138 auto new_ty = reshape_type(pi_ty)->as<
Pi>();
141 if (*old_lam->sym() ==
"main") {
144 new_lam = old_lam->stub(new_ty);
145 if (!old_lam->is_external()) new_lam->debug_suffix(
"_reshape");
146 old2new_[old_lam] = new_lam;
149 world().DLOG(
"Reshape lam: {} : {}", old_lam, pi_ty);
150 world().DLOG(
" to: {} : {}", new_lam, new_ty);
154 auto new_arg = new_lam->
var();
158 auto reformed_new_arg = reshape(new_arg, old_lam->var()->type());
159 world().DLOG(
"var {} : {}", old_lam->var(), old_lam->var()->type());
160 world().DLOG(
"new var {} : {}", new_arg, new_arg->type());
161 world().DLOG(
"reshaped new_var {} : {}", reformed_new_arg, reformed_new_arg->type());
162 world().DLOG(
"{}", old_lam->var()->type());
163 world().DLOG(
"{}", reformed_new_arg->type());
164 old2new_[old_lam->var()] = reformed_new_arg;
169 auto new_body = rewrite_def(old_lam->body());
170 auto new_filter = rewrite_def(old_lam->filter());
172 new_lam->set(new_filter, new_body);
174 if (old_lam->is_external()) old_lam->transfer_external(new_lam);
176 world().DLOG(
"finished transforming: {} : {}", new_lam, new_ty);
180const Def* Reshape::reshape_type(
const Def* T) {
181 if (
auto pi = T->isa<
Pi>()) {
182 auto new_dom = reshape_type(pi->dom());
183 auto new_cod = reshape_type(pi->codom());
184 return world().
pi(new_dom, new_cod);
185 }
else if (
auto sigma = T->isa<
Sigma>()) {
186 auto flat_types = flatten_ty(sigma);
187 auto new_types =
DefVec(flat_types.size());
188 std::ranges::transform(flat_types, new_types.begin(), [&](
auto T) { return reshape_type(T); });
190 const Def* mem =
nullptr;
192 for (
auto i = new_types.begin(); i != new_types.end(); i++)
193 if (is_mem_ty(*i) && !mem) mem = *i;
195 new_types.erase(std::remove_if(new_types.begin(), new_types.end(), is_mem_ty), new_types.end());
197 if (mem) new_types.insert(new_types.begin(), mem);
198 auto reshaped_type =
world().
sigma(new_types);
199 return reshaped_type;
201 if (new_types.size() == 0)
return world().
sigma();
202 if (new_types.size() == 1)
return new_types[0];
203 const Def* mem =
nullptr;
204 const Def* ret =
nullptr;
206 for (
auto i = new_types.begin(); i != new_types.end(); i++)
207 if (is_mem_ty(*i) && !mem) mem = *i;
209 new_types.erase(std::remove_if(new_types.begin(), new_types.end(), is_mem_ty), new_types.end());
211 if (new_types.back()->isa<
Pi>()) {
212 ret = new_types.back();
213 new_types.pop_back();
226const Def* Reshape::reshape(
DefVec& defs,
const Def* T,
const Def* mem) {
227 auto&
world = T->world();
228 if (should_flatten(T)) {
229 auto tuples = T->projs([&](
auto P) {
return reshape(defs, P, mem); });
230 return world.tuple(tuples);
234 assert(mem !=
nullptr &&
"Reshape: mems not found");
238 assert(defs.size() > 0 &&
"Reshape: not enough arguments");
240 defs.erase(defs.begin());
241 }
while (is_mem_ty(def->type()));
244 if (!def->type()->isa<
Pi>()) {
246 world.ELOG(
"reconstruct T {} from def {}", T, def->type());
253const Def* Reshape::reshape(
const Def* def,
const Def* target) {
254 world().DLOG(
"reshape:\n {} =>\n {}", def->type(), target);
255 auto flat_defs = flatten_def(def);
256 const Def* mem =
nullptr;
258 for (
auto i = flat_defs.begin(); i != flat_defs.end(); i++)
259 if (is_mem_ty((*i)->type()) && !mem) mem = *i;
260 world().DLOG(
"mem: {}", mem);
261 return reshape(flat_defs, target, mem);
268const Def* Reshape::reshape(
const Def* def) {
269 auto flat_defs = flatten_def(def);
270 if (flat_defs.size() == 1)
return flat_defs[0];
273 const Def* mem =
nullptr;
275 for (
auto i = flat_defs.begin(); i != flat_defs.end(); i++)
276 if (is_mem_ty((*i)->type()) && !mem) mem = *i;
279 std::remove_if(flat_defs.begin(), flat_defs.end(), [](
const Def* def) { return is_mem_ty(def->type()); }),
282 if (mem) flat_defs.insert(flat_defs.begin(), mem);
287 const Def* mem =
nullptr;
288 const Def* ret =
nullptr;
290 for (
auto i = flat_defs.begin(); i != flat_defs.end(); i++)
291 if (is_mem_ty((*i)->type()) && !mem) mem = *i;
294 std::remove_if(flat_defs.begin(), flat_defs.end(), [](
const Def* def) { return is_mem_ty(def->type()); }),
296 if (flat_defs.back()->type()->isa<
Pi>()) {
297 ret = flat_defs.back();
298 flat_defs.pop_back();
static bool alpha(const Def *d1, const Def *d2)
constexpr Node node() const noexcept
std::string_view node_name() const
T * isa_mut() const
If this is *mut*able, it will cast constness away and perform a dynamic_cast to T.
const Def * type() const noexcept
Yields the "raw" type of this Def (maybe nullptr).
const Def * rebuild(World &w, const Def *type, Defs ops) const
Def::rebuilds this Def while using new_op as substitute for its i'th Def::op.
constexpr size_t num_ops() const noexcept
const Def * sigma(Defs ops)
const Def * app(const Def *callee, const Def *arg)
const Pi * pi(const Def *dom, const Def *codom, bool implicit=false)
const Def * tuple(Defs ops)
const Def * var(const Def *type, Def *mut)
void enter() override
Fall-through to rewrite_def which falls through to rewrite_lam.
Vector< const Def * > DefVec
auto match(const Def *def)