13std::pair<const App*, Vector<Lam*>> isa_apped_mut_lam_in_tuple(
const Def* def) {
14 if (
auto app = def->isa<
App>()) {
16 std::deque<const Def*> wl;
17 wl.push_back(app->callee());
19 auto elem = wl.front();
21 if (
auto mut = elem->isa_mut<
Lam>()) {
23 }
else if (
auto extract = elem->isa<
Extract>()) {
24 if (
auto tuple = extract->tuple()->isa<
Tuple>())
25 for (
auto&& op : tuple->ops())
39template<
class F,
class H>
40const Def* rewrite_mut_lam_in_tuple(
const Def* def, F&& rewrite, H&& rewrite_idx) {
41 auto&
w = def->world();
42 if (
auto mut = def->isa_mut<
Lam>())
return std::forward<F>(rewrite)(mut);
44 auto extract = def->as<
Extract>();
45 auto tuple = extract->tuple()->as<
Tuple>();
46 auto new_ops =
DefVec(tuple->ops(), [&](
const Def* op) {
47 return rewrite_mut_lam_in_tuple(op, std::forward<F>(rewrite), std::forward<H>(rewrite_idx));
49 return w.extract(
w.tuple(new_ops), rewrite_idx(extract->index()));
53template<
class RewriteCallee,
class RewriteArg,
class RewriteIdx>
54const Def* rewrite_apped_mut_lam_in_tuple(
const Def* def,
55 RewriteCallee&& rewrite_callee,
56 RewriteArg&& rewrite_arg,
57 RewriteIdx&& rewrite_idx) {
58 auto app = def->as<
App>();
59 auto callee = rewrite_mut_lam_in_tuple(app->callee(), std::forward<RewriteCallee>(rewrite_callee),
60 std::forward<RewriteIdx>(rewrite_idx));
61 auto arg = std::forward<RewriteArg>(rewrite_arg)(app->arg());
62 return app->rebuild(app->type(), {callee, arg});
73const Def* AddMem::mem_for_lam(
Lam* lam)
const {
74 if (
auto it = mem_rewritten_.find(lam); it != mem_rewritten_.end()) {
78 if (
auto it = val2mem_.find(lam); it != val2mem_.end()) {
79 DLOG(
"found mem for {} in val2mem_ : {}", lam, it->second);
85 assert(
mem &&
"mut must have mem!");
89const Def* AddMem::rewrite_type(
const Def* type) {
90 if (
auto pi = type->isa<
Pi>())
return rewrite_pi(pi);
92 if (
auto it = mem_rewritten_.find(type); it != mem_rewritten_.end())
return it->second;
94 auto new_ops =
DefVec(type->num_ops(), [&](
size_t i) { return rewrite_type(type->op(i)); });
95 return mem_rewritten_[type] = type->rebuild(type->type(), new_ops);
98const Def* AddMem::rewrite_pi(
const Pi* pi) {
99 if (
auto it = mem_rewritten_.find(pi); it != mem_rewritten_.end())
return it->second;
101 auto dom = pi->dom();
102 auto new_dom =
DefVec(dom->num_projs(), [&](
size_t i) { return rewrite_type(dom->proj(i)); });
105 =
DefVec(dom->num_projs() + 1, [&](
size_t i) { return i == 0 ? world().call<mem::M>(0) : new_dom[i - 1]; });
108 return mem_rewritten_[pi] =
world().
pi(new_dom, pi->codom());
111const Def* AddMem::add_mem_to_lams(
Lam* curr_lam,
const Def* def) {
112 auto place =
static_cast<Lam*
>(sched_.smart(curr_lam, def)->mut());
116 if (
auto global = def->isa<
Global>())
return global;
117 if (
auto mut_lam = def->isa_mut<
Lam>(); mut_lam && !mut_lam->is_set())
return def;
118 if (
auto ax = def->isa<
Axm>())
return ax;
119 if (
auto it = mem_rewritten_.find(def); it != mem_rewritten_.end()) {
120 auto tmp = it->second;
122 DLOG(
"already known mem {} in {}", def, curr_lam);
123 auto new_mem = mem_for_lam(curr_lam);
124 DLOG(
"new mem {} in {}", new_mem, curr_lam);
127 if (curr_lam != def) {
134 auto rewrite_lam = [&](
Lam* lam) ->
const Def* {
135 auto pi = lam->
type()->as<
Pi>();
138 if (
auto it = mem_rewritten_.find(lam); it != mem_rewritten_.end()) {
140 new_lam = it->second->as_mut<
Lam>();
141 else if (
auto pi = it->second->type()->as<
Pi>(); pi->num_doms() > 0 &&
Axm::isa<mem::M>(pi->dom(0_s)))
145 if (!lam->
is_set())
return lam;
146 DLOG(
"rewrite lam {}", lam);
148 bool is_bound = sched_.nest().contains(lam) || lam == curr_lam;
151 if (
auto new_pi = rewrite_pi(pi); new_pi != pi) new_lam = lam->
stub(new_pi);
154 DLOG(
"free lam {}", lam);
155 mem_rewritten_[lam] = new_lam;
159 auto var_offset = new_lam->num_doms() - lam->num_doms();
160 if (lam->
num_vars() != 0) mem_rewritten_[lam->
var()] = new_lam->
var();
161 for (
size_t i = 0; i < lam->
num_vars() && new_lam->num_vars() > 1; ++i)
162 mem_rewritten_[lam->
var(i)] = new_lam->
var(i + var_offset);
164 auto var = new_lam->
var(0_n);
165 mem_rewritten_[new_lam] = new_lam;
166 mem_rewritten_[lam] = new_lam;
167 val2mem_[new_lam] = var;
169 mem_rewritten_[var] = var;
170 auto filter = add_mem_to_lams(lam, lam->
filter());
171 auto body = add_mem_to_lams(lam, lam->
body());
172 new_lam->unset()->set({filter, body});
179 if (
auto lam = def->isa_mut<
Lam>())
return rewrite_lam(lam);
180 assert(!def->isa_mut());
182 if (
auto pi = def->isa<
Pi>())
return rewrite_pi(pi);
184 auto rewrite_arg = [&](
const Def* arg) ->
const Def* {
185 size_t offset = (arg->type()->num_projs() > 0 &&
Axm::isa<mem::M>(arg->type()->proj(0))) ? 0 : 1;
188 add_mem_to_lams(place, arg->proj(0));
191 DefVec new_args{arg->type()->num_projs() + offset};
192 for (
int i = new_args.size() - 1; i >= 0; i--) {
194 = i == 0 ? add_mem_to_lams(place, mem_for_lam(place)) : add_mem_to_lams(place, arg->proj(i - offset));
200 if (
auto apped_mut = isa_apped_mut_lam_in_tuple(def); apped_mut.first) {
201 return mem_rewritten_[def]
202 = rewrite_apped_mut_lam_in_tuple(def, std::move(rewrite_lam), std::move(rewrite_arg),
203 [&](
const Def* def) {
return add_mem_to_lams(place, def); });
207 if (
auto app = def->isa<
App>(); app && (app->callee()->has_dep(
Dep::Var))) {
208 return mem_rewritten_[def]
209 = app->rebuild(app->type(), {add_mem_to_lams(place, app->callee()), rewrite_arg(app->arg())});
214 if (
auto app = def->isa<
App>(); app && app->axm() && app->curry() ^ 0x8000) {
215 auto arg = app->arg();
216 DefVec new_args(arg->num_projs());
217 for (
int i = new_args.size() - 1; i >= 0; i--) {
221 add_mem_to_lams(place, arg->proj(i));
222 new_args[i] = add_mem_to_lams(place, mem_for_lam(place));
224 new_args[i] = add_mem_to_lams(place, arg->proj(i));
227 auto rewritten = mem_rewritten_[def] = app->rebuild(app->type(), {add_mem_to_lams(place, app->callee()),
228 world().tuple(new_args)->set(arg->dbg())})
231 DLOG(
"memory from axm {} : {}", rewritten, rewritten->type());
232 val2mem_[place] = rewritten;
234 if (rewritten->num_projs() > 0 &&
Axm::isa<mem::M>(rewritten->proj(0)->type())) {
235 DLOG(
"memory from axm 2 {} : {}", rewritten, rewritten->type());
236 mem_rewritten_[rewritten->proj(0)] = rewritten->proj(0);
237 val2mem_[place] = rewritten->proj(0);
243 if (
auto app = def->isa<
App>()) {
244 auto new_callee = add_mem_to_lams(place, app->callee());
245 auto new_arg = add_mem_to_lams(place, app->arg());
246 if (app->callee()->type()->as<
Pi>()->num_doms() + 1 == new_callee->type()->as<
Pi>()->num_doms())
247 new_arg = rewrite_arg(app->arg());
248 auto rewritten = mem_rewritten_[def] = app->rebuild(app->type(), {new_callee, new_arg})->set(app->dbg());
250 DLOG(
"memory from other {} : {}", rewritten, rewritten->type());
251 val2mem_[place] = rewritten;
253 if (rewritten->num_projs() > 0 &&
Axm::isa<mem::M>(rewritten->proj(0)->type())) {
254 DLOG(
"memory from other 2 {} : {}", rewritten, rewritten->type());
255 mem_rewritten_[rewritten->proj(0)] = rewritten->proj(0);
256 val2mem_[place] = rewritten->proj(0);
261 auto new_ops =
DefVec(def->ops(), [&](
const Def* op) {
262 if (Axm::isa<mem::M>(op->type())) {
264 add_mem_to_lams(place, op);
265 return add_mem_to_lams(place, mem_for_lam(place));
267 return add_mem_to_lams(place, op);
270 auto tmp = mem_rewritten_[def] = def->rebuild(rewrite_type(def->type()), new_ops)->set(def->dbg());
static auto isa(const Def *def)
bool is_set() const
Yields true if empty or the last op is set.
T * as_mut() const
Asserts that this is a mutable, casts constness away and performs a static_cast to T.
const Def * var(nat_t a, nat_t i) noexcept
void transfer_external(Def *to)
nat_t num_vars() noexcept
bool is_external() const noexcept
const Def * filter() const
Lam * stub(const Def *type)
const Nest & nest() const
Builds a nesting tree of all immutables‍/binders.
A dependent function type.
const Pi * pi(const Def *dom, const Def *codom, bool implicit=false)
const Def * tuple(Defs ops)
void visit(const Nest &) override
#define DLOG(...)
Vaporizes to nothingness in Debug build.
const Def * mem_var(Lam *lam)
Returns the memory argument of a function if it has one.
Vector< const Def * > DefVec
Vector(I, I, A=A()) -> Vector< typename std::iterator_traits< I >::value_type, Default_Inlined_Size< typename std::iterator_traits< I >::value_type >, A >