/// Decides whether the type T is to be taken apart into its projections.
/// Only the two positive rules below are visible in this listing.
/// NOTE(review): the closing braces and the fallback result of this function
/// are not shown here - confirm against the full file.
17bool should_flatten(
const Def* T) {
// Rule 1: every Sigma is flattened.
19 if (T->isa<
Sigma>())
return true;
// Rule 2 is only considered when the arity of T is a literal of at most 2.
29 if (
auto lit = T->arity()->isa<
Lit>(); lit &&
lit->get<
u64>() <= 2) {
// An Arr whose body is a Pi is flattened exactly when that literal arity is greater than 1.
30 if (
auto arr = T->isa<
Arr>(); arr && arr->body()->isa<
Pi>())
return lit->get<
u64>() > 1;
/// Collects the component types of T after flattening.
/// @param T the type to flatten.
/// NOTE(review): the declaration of `types`, the branch taken when T is not
/// flattened, and the return statement are not visible here; presumably
/// `types` is the DefVec that gets returned - confirm.
36DefVec flatten_ty(
const Def* T) {
38 if (should_flatten(T)) {
// Recurse into every projection of T and append its flattened types in projection order.
39 for (
auto P : T->projs()) {
40 auto inner_types = flatten_ty(P);
41 types.insert(types.end(), inner_types.begin(), inner_types.end());
/// Value-level counterpart of flatten_ty: collects the components of def,
/// guided by whether def's type should be flattened.
/// @param def the value to flatten.
/// NOTE(review): the declaration of `defs`, the branch taken when the type is
/// not flattened, and the return statement are not visible here.
50DefVec flatten_def(
const Def* def) {
// The decision is made on the type of def, the recursion runs over the projections of def itself.
52 if (should_flatten(def->type())) {
53 for (
auto P : def->projs()) {
54 auto inner_defs = flatten_def(P);
55 defs.insert(defs.end(), inner_defs.begin(), inner_defs.end());
71 auto axm = app->
arg()->as<
Axm>();
/// Memoizing front end for rewrite_def_.
/// @param def the definition to rewrite.
/// NOTE(review): the final return is not visible here; presumably new_def is returned.
77const Def* Reshape::rewrite_def(
const Def* def) {
// Already rewritten: hand back the cached result from old2new_.
78 if (
auto i = old2new_.find(def); i != old2new_.end())
return i->second;
// Otherwise do the actual rewrite and remember the result under the old def.
79 auto new_def = rewrite_def_(def);
80 old2new_[def] = new_def;
/// Performs the uncached rewrite of a single definition.
/// @param def the definition to rewrite.
/// NOTE(review): several lines of this function (switch cases, the binding of
/// `lam`, the return statements and closing braces) are not visible in this
/// listing; the comments below only describe the statements that are shown.
84const Def* Reshape::rewrite_def_(
const Def* def) {
// Dispatch on the node kind of def; the case labels are not visible here.
86 switch (def->
node()) {
// Renders "<def> : <type> [<node name>]" into a string.
// NOTE(review): no use of `str` is visible in this listing - confirm it is still needed.
100 std::stringstream ss;
101 ss << def <<
" : " << def->
type() <<
" [" << def->
node_name() <<
"]";
102 std::string str = ss.str();
// A Var reaching this point is reported via ELOG before the assertion below rejects it.
105 if (def->isa<
Var>())
ELOG(
"Var: {}", def);
106 assert(!def->isa<
Var>());
// Applications: rewrite callee and argument separately.
108 if (
auto app = def->isa<
App>()) {
109 auto callee = rewrite_def(app->callee());
110 auto arg = rewrite_def(app->arg());
112 DLOG(
"callee: {} : {}", callee, callee->type());
// The rewritten argument is passed through the single-argument reshape before the new application is built.
115 auto reshaped_arg = reshape(arg);
116 DLOG(
"reshape arg {} : {}", arg, arg->type());
117 DLOG(
"into arg {} : {}", reshaped_arg, reshaped_arg->type());
118 auto new_app =
world().
app(callee, reshaped_arg);
// Lambdas are handed to reshape_lam.
// NOTE(review): `lam` is bound on a line that is not visible in this listing.
121 DLOG(
"rewrite_def lam {} : {}", def, def->
type());
122 auto new_lam = reshape_lam(lam);
123 DLOG(
"rewrote lam {} : {}", def, def->
type());
124 DLOG(
"into lam {} : {}", new_lam, new_lam->type());
126 }
else if (
auto tuple = def->isa<
Tuple>()) {
// Tuples: rewrite every operand.
127 auto elements =
DefVec(tuple->ops(), [&](
const Def* op) { return rewrite_def(op); });
// Rewrite all operands and the type of def, then rebuild def from the rewritten pieces.
130 auto new_ops =
DefVec(def->
num_ops(), [&](
auto i) { return rewrite_def(def->op(i)); });
132 auto new_type = rewrite_def(def->
type());
133 auto new_def = def->
rebuild(new_type, new_ops);
/// Creates a lambda with a reshaped type for old_lam and rewrites old_lam's
/// filter and body into it.
/// @param old_lam the lambda to reshape.
/// NOTE(review): the bodies of the two early `if`s, the declaration of
/// `new_lam` and the return statement are not visible in this listing.
138Lam* Reshape::reshape_lam(
Lam* old_lam) {
// Lambdas that are not set are only logged here; what happens next is not visible.
139 if (!old_lam->is_set()) {
140 DLOG(
"reshape_lam: {} is not a set", old_lam);
// The new type is the reshaped type of old_lam; as<Pi>() expects the result to be a Pi.
143 auto pi_ty = old_lam->type();
144 auto new_ty = reshape_type(pi_ty)->as<
Pi>();
// A lambda whose symbol is "main" is singled out; the handling itself is not visible here.
147 if (*old_lam->sym() ==
"main") {
// Stub out a lambda of the new type; non-external lambdas get the "_reshape" debug suffix.
150 new_lam = old_lam->stub(new_ty);
151 if (!old_lam->is_external()) new_lam->debug_suffix(
"_reshape");
// Registered before the body is rewritten: rewrite_def consults old2new_ first,
// so occurrences of old_lam met while rewriting its own body map to new_lam.
152 old2new_[old_lam] = new_lam;
155 DLOG(
"Reshape lam: {} : {}", old_lam, pi_ty);
156 DLOG(
" to: {} : {}", new_lam, new_ty);
160 auto new_arg = new_lam->var();
// Reshape the new variable towards the type of the old variable, so that the
// old variable can be replaced by it while the body is rewritten.
164 auto reformed_new_arg = reshape(new_arg, old_lam->var()->type());
165 DLOG(
"var {} : {}", old_lam->var(), old_lam->var()->type());
166 DLOG(
"new var {} : {}", new_arg, new_arg->type());
167 DLOG(
"reshaped new_var {} : {}", reformed_new_arg, reformed_new_arg->type());
168 DLOG(
"{}", old_lam->var()->type());
169 DLOG(
"{}", reformed_new_arg->type());
170 old2new_[old_lam->var()] = reformed_new_arg;
// Rewrite body and filter of the old lambda and install them in the new one.
175 auto new_body = rewrite_def(old_lam->body());
176 auto new_filter = rewrite_def(old_lam->filter());
178 new_lam->set(new_filter, new_body);
// An external old lambda hands its external status over to the new lambda.
180 if (old_lam->is_external()) old_lam->transfer_external(new_lam);
182 DLOG(
"finished transforming: {} : {}", new_lam, new_ty);
/// Computes the reshaped form of the type T.
/// @param T the type to reshape.
/// NOTE(review): `mem` is declared twice below, so the two statement sequences
/// must live in different scopes; the condition selecting between them, the
/// handling of other kinds of types and the remaining returns are not visible
/// in this listing.
186const Def* Reshape::reshape_type(
const Def* T) {
// Pi: reshape domain and codomain independently and build a new Pi from them.
187 if (
auto pi = T->isa<
Pi>()) {
188 auto new_dom = reshape_type(pi->dom());
189 auto new_cod = reshape_type(pi->codom());
190 return world().
pi(new_dom, new_cod);
191 }
else if (
auto sigma = T->isa<
Sigma>()) {
// Sigma: flatten it, then reshape every flattened component.
// The lambda parameter T shadows the function parameter of the same name.
192 auto flat_types = flatten_ty(sigma);
193 auto new_types =
DefVec(flat_types.size());
194 std::ranges::transform(flat_types, new_types.begin(), [&](
auto T) { return reshape_type(T); });
// First sequence: remember the first mem type, drop all mem types, put the
// remembered one (if any) at the front, and return the sigma of the result.
196 const Def* mem =
nullptr;
198 for (
auto i = new_types.begin(); i != new_types.end(); i++)
199 if (is_mem_ty(*i) && !mem) mem = *i;
201 new_types.erase(std::remove_if(new_types.begin(), new_types.end(), is_mem_ty), new_types.end());
203 if (mem) new_types.insert(new_types.begin(), mem);
204 auto reshaped_type =
world().
sigma(new_types);
205 return reshaped_type;
// Second sequence: no component yields the empty sigma, a single component is returned as is.
207 if (new_types.size() == 0)
return world().
sigma();
208 if (new_types.size() == 1)
return new_types[0];
209 const Def* mem =
nullptr;
210 const Def* ret =
nullptr;
// Again remember the first mem type and drop all mem types.
212 for (
auto i = new_types.begin(); i != new_types.end(); i++)
213 if (is_mem_ty(*i) && !mem) mem = *i;
215 new_types.erase(std::remove_if(new_types.begin(), new_types.end(), is_mem_ty), new_types.end());
// A trailing Pi is split off into `ret`.
// NOTE(review): the size checks above run before the erase; if every entry was a
// mem type, new_types is empty here and back() is undefined - confirm this cannot
// happen or is guarded on a line not shown.
217 if (new_types.back()->isa<
Pi>()) {
218 ret = new_types.back();
219 new_types.pop_back();
/// Rebuilds a value of type T out of the flat list defs.
/// @param defs flat list of values; elements are removed from its front as they are used.
/// @param T    the target type to reconstruct.
/// @param mem  the mem value passed down unchanged through the recursion.
/// NOTE(review): the loop head, the assignment to `def`, the condition guarding
/// the mem assertion and the remaining returns are not visible in this listing.
232const Def* Reshape::reshape(
DefVec& defs,
const Def* T,
const Def* mem) {
233 auto&
world = T->world();
// Flattenable target: reshape each projection of T recursively (all of them draw
// from the same defs list) and pack the results into a tuple.
234 if (should_flatten(T)) {
235 auto tuples = T->projs([&](
auto P) {
return reshape(defs, P, mem); });
236 return world.tuple(tuples);
// A mem value is required at this point; the surrounding condition is not visible here.
240 assert(mem !=
nullptr &&
"Reshape: mems not found");
// Tail of a do-while loop: the list must not be empty, its front element is
// removed, and the loop repeats while `def` has a mem type.
244 assert(defs.size() > 0 &&
"Reshape: not enough arguments");
246 defs.erase(defs.begin());
247 }
while (is_mem_ty(def->type()));
// If the type of the chosen def is not a Pi, an error mentioning T and that type is logged.
250 if (!def->type()->isa<
Pi>()) {
252 world.ELOG(
"reconstruct T {} from def {}", T, def->type());
/// Reshapes def so that it fits the type target.
/// Flattens def, picks the first mem-typed component (if any) and delegates to
/// the three-argument reshape.
/// @param def    the value to reshape.
/// @param target the type the result should have.
259const Def* Reshape::reshape(
const Def* def,
const Def* target) {
260 DLOG(
"reshape:\n {} =>\n {}", def->type(), target);
261 auto flat_defs = flatten_def(def);
// mem stays nullptr when no flattened component has a mem type.
262 const Def* mem =
nullptr;
264 for (
auto i = flat_defs.begin(); i != flat_defs.end(); i++)
265 if (is_mem_ty((*i)->type()) && !mem) mem = *i;
266 DLOG(
"mem: {}", mem);
267 return reshape(flat_defs, target, mem);
/// Reshapes def without an explicit target type.
/// @param def the value to reshape.
/// NOTE(review): `mem` is declared twice below, so the two statement sequences
/// must live in different scopes; the condition selecting between them, the call
/// that encloses each std::remove_if and all later returns are not visible in
/// this listing.
274const Def* Reshape::reshape(
const Def* def) {
// A value that flattens to a single component is returned as that component.
275 auto flat_defs = flatten_def(def);
276 if (flat_defs.size() == 1)
return flat_defs[0];
// First sequence: remember the first mem-typed component.
279 const Def* mem =
nullptr;
281 for (
auto i = flat_defs.begin(); i != flat_defs.end(); i++)
282 if (is_mem_ty((*i)->type()) && !mem) mem = *i;
// The trailing comma shows this remove_if is an argument of a call whose start is
// not visible; presumably flat_defs.erase(...), as in reshape_type - confirm.
285 std::remove_if(flat_defs.begin(), flat_defs.end(), [](
const Def* def) { return is_mem_ty(def->type()); }),
// The remembered mem value, if any, goes to the front.
288 if (mem) flat_defs.insert(flat_defs.begin(), mem);
// Second sequence: same mem handling, plus a trailing Pi-typed component is split off into `ret`.
293 const Def* mem =
nullptr;
294 const Def* ret =
nullptr;
296 for (
auto i = flat_defs.begin(); i != flat_defs.end(); i++)
297 if (is_mem_ty((*i)->type()) && !mem) mem = *i;
300 std::remove_if(flat_defs.begin(), flat_defs.end(), [](
const Def* def) { return is_mem_ty(def->type()); }),
// NOTE(review): back() requires flat_defs to be non-empty at this point - confirm
// this is guaranteed or guarded on a line not shown.
302 if (flat_defs.back()->type()->isa<
Pi>()) {
303 ret = flat_defs.back();
304 flat_defs.pop_back();
static auto isa(const Def *def)
static bool alpha(const Def *d1, const Def *d2)
constexpr Node node() const noexcept
std::string_view node_name() const
T * isa_mut() const
If this is mutable, it will cast constness away and perform a dynamic_cast to T.
const Def * type() const noexcept
Yields the "raw" type of this Def (maybe nullptr).
const Def * rebuild(World &w, const Def *type, Defs ops) const
Def::rebuilds this Def while using new_op as substitute for its i'th Def::op.
constexpr size_t num_ops() const noexcept
const Def * sigma(Defs ops)
const Def * app(const Def *callee, const Def *arg)
const Pi * pi(const Def *dom, const Def *codom, bool implicit=false)
const Def * tuple(Defs ops)
void enter() override
Fall-through to rewrite_def which falls through to rewrite_lam.
#define DLOG(...)
Vaporizes to nothingness in Debug build.
Vector< const Def * > DefVec
static consteval flags_t base()