// Looks for an application (App) whose callee is a mutable Lam, possibly reached
// through Extract nodes that project out of a Tuple. The result pairs the App with
// a Vector of Lams; what is returned on a mismatch is not visible in this excerpt
// (presumably {nullptr, {}} — TODO confirm).
// NOTE(review): several source lines of this definition (the worklist loop header,
// branch bodies, return statements, closing braces) are missing from this excerpt;
// the comments below describe only the visible code.
11std::pair<const App*, Vector<Lam*>> isa_apped_mut_lam_in_tuple(
const Def* def) {
12 if (
auto app = def->isa<
App>()) {
// Worklist of callee candidates, seeded with the App's callee.
14 std::deque<const Def*> wl;
15 wl.push_back(app->callee());
17 auto elem = wl.front();
// Candidate is itself a mutable Lam (branch body not visible here).
19 if (
auto mut = elem->isa_mut<
Lam>()) {
21 }
// Candidate is an Extract from a Tuple: the Tuple's operands are iterated
// (loop body not visible here; presumably they are queued in `wl`).
else if (
auto extract = elem->isa<
Extract>()) {
22 if (
auto tuple = extract->tuple()->isa<
Tuple>())
23 for (
auto&& op : tuple->ops())
// Rewrites `def`, which is expected to be either a mutable Lam or an Extract from a
// Tuple (the Extract/Tuple path uses as<> rather than the checking isa<>).
// A mutable Lam is handed to `rewrite`. For an Extract, every operand of the Tuple
// is rewritten recursively and a new Extract is built over the rebuilt Tuple, with
// the index mapped through `rewrite_idx`.
//   rewrite     - callable applied to each mutable Lam; its result replaces the Lam.
//   rewrite_idx - callable applied to the Extract's index; its result is the new index.
// NOTE(review): `rewrite`/`rewrite_idx` are std::forward-ed once per Tuple operand.
// This is harmless because the recursive call takes them by forwarding reference
// (nothing is moved), but passing them as plain lvalues would read more clearly.
// NOTE(review): the end of the operand lambda and the closing brace are missing
// from this excerpt.
37template<
class F,
class H>
38const Def* rewrite_mut_lam_in_tuple(
const Def* def, F&& rewrite, H&& rewrite_idx) {
// World used to build the replacement Tuple and Extract.
39 auto&
w = def->world();
// Base case: a mutable Lam is rewritten directly.
40 if (
auto mut = def->isa_mut<
Lam>())
return std::forward<F>(rewrite)(mut);
42 auto extract = def->as<
Extract>();
43 auto tuple = extract->tuple()->as<
Tuple>();
// Recursively rewrite every operand of the Tuple.
44 auto new_ops =
DefVec(tuple->ops(), [&](
const Def* op) {
45 return rewrite_mut_lam_in_tuple(op, std::forward<F>(rewrite), std::forward<H>(rewrite_idx));
47 return w.extract(
w.tuple(new_ops), rewrite_idx(extract->index()));
// Rebuilds the App `def` with its callee and argument rewritten.
// `def` is cast with as<> (not the checking isa<>), so it is expected to be an App.
//   rewrite_callee - applied (via rewrite_mut_lam_in_tuple) to each mutable Lam in
//                    the callee.
//   rewrite_arg    - applied to the App's argument.
//   rewrite_idx    - applied (via rewrite_mut_lam_in_tuple) to Extract indices in
//                    the callee.
// Returns the rebuilt App, which keeps the original App's type.
// NOTE(review): the closing brace of this definition is missing from this excerpt.
51template<
class RewriteCallee,
class RewriteArg,
class RewriteIdx>
52const Def* rewrite_apped_mut_lam_in_tuple(
const Def* def,
53 RewriteCallee&& rewrite_callee,
54 RewriteArg&& rewrite_arg,
55 RewriteIdx&& rewrite_idx) {
56 auto app = def->as<
App>();
57 auto callee = rewrite_mut_lam_in_tuple(app->callee(), std::forward<RewriteCallee>(rewrite_callee),
58 std::forward<RewriteIdx>(rewrite_idx));
59 auto arg = std::forward<RewriteArg>(rewrite_arg)(app->arg());
// Rebuild with the original App type; only the operands change.
60 return app->rebuild(app->type(), {callee, arg});
// Returns the memory value currently associated with `lam`.
// Visible lookup order: mem_rewritten_ first (branch body not visible), then
// val2mem_ (logged when found). The last visible statement asserts that a `mem`
// value was found; `mem` is declared on a line missing from this excerpt.
// NOTE(review): lines of this definition are missing here (including whatever sits
// between the val2mem_ lookup and the assert, and the return); it may fall back to
// mem_var(lam) — TODO confirm against the full source.
71const Def* AddMem::mem_for_lam(
Lam* lam)
const {
72 if (
auto it = mem_rewritten_.find(lam); it != mem_rewritten_.end()) {
76 if (
auto it = val2mem_.find(lam); it != val2mem_.end()) {
77 DLOG(
"found mem for {} in val2mem_ : {}", lam, it->second);
// Every Lam reaching this point is expected to have a memory value.
83 assert(
mem &&
"mut must have mem!");
// Rewrites a type for memory threading.
// Function types (Pi) are delegated to rewrite_pi. Judging by its visible code,
// rewrite_pi places a mem::M domain in front.
// Any other type is rebuilt from its recursively rewritten operands. Its own type
// (`type->type()`) is kept as is.
// Results for non-Pi types are memoized here in mem_rewritten_; rewrite_pi does its
// own memoization.
// NOTE(review): the closing brace of this definition is missing from this excerpt.
87const Def* AddMem::rewrite_type(
const Def* type) {
88 if (
auto pi = type->isa<
Pi>())
return rewrite_pi(pi);
// Memoized result.
90 if (
auto it = mem_rewritten_.find(type); it != mem_rewritten_.end())
return it->second;
92 auto new_ops =
DefVec(type->num_ops(), [&](
size_t i) { return rewrite_type(type->op(i)); });
93 return mem_rewritten_[type] = type->rebuild(type->type(), new_ops);
// Rewrites a function type (Pi) for memory threading.
// The domain projections are rewritten via rewrite_type. A domain vector is then
// built with a memory type (world().annex<mem::M>()) in front of them. Whether
// that step is conditional is not visible in this excerpt.
// The codomain is reused unchanged (`pi->codom()`). Results are memoized in
// mem_rewritten_.
// NOTE(review): lines of this definition are missing from this excerpt: the
// declaration of `dom` (presumably pi->dom()), the assignment target of the `=`
// below (presumably new_dom, which the return statement uses) and any condition
// guarding it, and the closing brace.
96const Def* AddMem::rewrite_pi(
const Pi* pi) {
97 if (
auto it = mem_rewritten_.find(pi); it != mem_rewritten_.end())
return it->second;
// Rewrite every domain projection (e.g. nested function types).
100 auto new_dom =
DefVec(dom->num_projs(), [&](
size_t i) { return rewrite_type(dom->proj(i)); });
// Prepend the memory type: index 0 is mem::M, index i > 0 is new_dom[i - 1].
103 =
DefVec(dom->num_projs() + 1, [&](
size_t i) { return i == 0 ? world().annex<mem::M>() : new_dom[i - 1]; });
106 return mem_rewritten_[pi] =
world().
pi(new_dom, pi->codom());
// Rewrites `def`, encountered while processing `curr_lam`, so that function types
// and applications thread an explicit memory (mem::M) value:
//   - mutable Lams are re-stubbed with a rewritten Pi when it changes (rewrite_lam);
//   - Pis are rewritten via rewrite_pi;
//   - application arguments receive the current memory of the enclosing Lam
//     (rewrite_arg), and memory results of applications become that Lam's
//     current memory.
// Results are memoized in mem_rewritten_. The current memory per Lam is stored in
// val2mem_ and read back through mem_for_lam().
//   curr_lam - the Lam whose filter/body is currently being rewritten.
//   def      - the expression to rewrite.
// Returns the rewritten Def.
// NOTE(review): many source lines of this definition (closing braces, some
// conditions, declarations and returns) are missing from this excerpt; the
// comments describe only the visible code.
109const Def* AddMem::add_mem_to_lams(
Lam* curr_lam,
const Def* def) {
// Lam into which the scheduler places `def`; memory values produced while
// rewriting `def` are recorded for this Lam in val2mem_.
// NOTE(review): static_cast is unchecked — assumes the scheduled mutable is a Lam.
110 auto place =
static_cast<Lam*
>(sched_.smart(curr_lam, def)->mut());
// Globals, mutable Lams that are not set, and axioms are returned unchanged.
114 if (
auto global = def->isa<
Global>())
return global;
115 if (
auto mut_lam = def->isa_mut<
Lam>(); mut_lam && !mut_lam->is_set())
return def;
116 if (
auto ax = def->isa<
Axm>())
return ax;
// Already rewritten: the memoized result is fetched and the current memory of
// curr_lam is looked up (remaining handling not visible here).
117 if (
auto it = mem_rewritten_.find(def); it != mem_rewritten_.end()) {
118 auto tmp = it->second;
120 DLOG(
"already known mem {} in {}", def, curr_lam);
121 auto new_mem = mem_for_lam(curr_lam);
122 DLOG(
"new mem {} in {}", new_mem, curr_lam);
125 if (curr_lam != def) {
// Rewrites a mutable Lam:
//   - consults mem_rewritten_ first (that handling is partly missing here);
//   - otherwise stubs a new Lam with the rewritten Pi, if the Pi changed;
//   - maps the old variables to the new ones, shifted by the number of added
//     domains;
//   - registers the new Lam's first variable as its memory;
//   - rewrites the filter and the body.
132 auto rewrite_lam = [&](
Lam* lam) ->
const Def* {
133 auto pi = lam->
type()->as<
Pi>();
136 if (
auto it = mem_rewritten_.find(lam); it != mem_rewritten_.end()) {
138 new_lam = it->second->as_mut<
Lam>();
139 else if (
auto pi = it->second->type()->as<
Pi>(); pi->num_doms() > 0 &&
Axm::isa<mem::M>(pi->dom(0_s)))
// Lams that are not set (is_set() is false) are returned unchanged.
143 if (!lam->
is_set())
return lam;
144 DLOG(
"rewrite lam {}", lam);
// Whether `lam` is contained in the schedule's nest or is the Lam being processed;
// how this flag is used is not visible here.
146 bool is_bound = sched_.nest().contains(lam) || lam == curr_lam;
// Only stub a new Lam if the rewritten Pi differs from the original one.
149 if (
auto new_pi = rewrite_pi(pi); new_pi != pi) new_lam = lam->
stub(new_pi);
152 DLOG(
"free lam {}", lam);
153 mem_rewritten_[lam] = new_lam;
// Number of domains the new Lam gained relative to the old one.
157 auto var_offset = new_lam->num_doms() - lam->num_doms();
// Map the old variable(s) to the new Lam's; projections are shifted by var_offset.
// The per-projection loop only runs when the new Lam has more than one variable.
158 if (lam->
num_vars() != 0) mem_rewritten_[lam->
var()] = new_lam->
var();
159 for (
size_t i = 0; i < lam->
num_vars() && new_lam->num_vars() > 1; ++i)
160 mem_rewritten_[lam->
var(i)] = new_lam->
var(i + var_offset);
// The new Lam's first variable becomes its current memory (presumably the mem::M
// parameter prepended by rewrite_pi — TODO confirm).
162 auto var = new_lam->
var(0_n);
163 mem_rewritten_[new_lam] = new_lam;
164 mem_rewritten_[lam] = new_lam;
165 val2mem_[new_lam] = var;
167 mem_rewritten_[var] = var;
// Rewrite filter and body with the original `lam` as the current Lam, then
// install them into the new Lam.
168 auto filter = add_mem_to_lams(lam, lam->
filter());
169 auto body = add_mem_to_lams(lam, lam->
body());
170 new_lam->unset()->set({filter, body});
177 if (
auto lam = def->isa_mut<
Lam>())
return rewrite_lam(lam);
// Mutable Lams were handled above; no other mutable kinds are expected here.
178 assert(!def->isa_mut());
180 if (
auto pi = def->isa<
Pi>())
return rewrite_pi(pi);
// Rewrites an application argument so that it starts with the current memory of
// `place`.
//   - offset 0: the argument's first projection is already mem::M, and that slot
//     is replaced.
//   - offset 1: a new leading slot is added and the remaining projections are
//     shifted by one.
182 auto rewrite_arg = [&](
const Def* arg) ->
const Def* {
183 size_t offset = (arg->type()->num_projs() > 0 &&
Axm::isa<mem::M>(arg->type()->proj(0))) ? 0 : 1;
// The old first operand is rewritten for its side effects only; the result is
// discarded (the guarding condition, if any, is not visible here).
186 add_mem_to_lams(place, arg->proj(0));
189 DefVec new_args{arg->type()->num_projs() + offset};
// Filled back to front, presumably so that memory produced while rewriting later
// operands is visible to mem_for_lam(place) when slot 0 is filled — TODO confirm.
// The assignment target of the `=` below is on a missing line (presumably
// new_args[i]).
190 for (
int i = new_args.size() - 1; i >= 0; i--) {
192 = i == 0 ? add_mem_to_lams(place, mem_for_lam(place)) : add_mem_to_lams(place, arg->proj(i - offset));
// Application of a mutable Lam, possibly reached through Extracts of Tuples (see
// isa_apped_mut_lam_in_tuple). The rewrites are:
//   - callee Lams via rewrite_lam;
//   - the argument via rewrite_arg;
//   - Extract indices via add_mem_to_lams.
198 if (
auto apped_mut = isa_apped_mut_lam_in_tuple(def); apped_mut.first) {
199 return mem_rewritten_[def]
200 = rewrite_apped_mut_lam_in_tuple(def, std::move(rewrite_lam), std::move(rewrite_arg),
201 [&](
const Def* def) {
return add_mem_to_lams(place, def); });
// Application whose callee depends on a variable: rewrite the callee and thread
// memory into the argument.
205 if (
auto app = def->isa<
App>(); app && (app->callee()->has_dep(
Dep::Var))) {
206 return mem_rewritten_[def]
207 = app->rebuild(app->type(), {add_mem_to_lams(place, app->callee()), rewrite_arg(app->arg())});
// Application of an axiom.
// NOTE(review): `app->curry() ^ 0x8000` is non-zero exactly when curry() != 0x8000;
// if that is the intent, an explicit `!=` would be clearer — confirm.
212 if (
auto app = def->isa<
App>(); app && app->axm() && app->curry() ^ 0x8000) {
213 auto arg = app->arg();
214 DefVec new_args(arg->num_projs());
// Operands are visited back to front.
//   - First (partly missing) branch: the old operand is rewritten for its side
//     effects and replaced by the current memory of `place`. Presumably this is
//     for mem::M operands; the condition is not visible here.
//   - Otherwise: the operand is rewritten as is.
215 for (
int i = new_args.size() - 1; i >= 0; i--) {
219 add_mem_to_lams(place, arg->proj(i));
220 new_args[i] = add_mem_to_lams(place, mem_for_lam(place));
222 new_args[i] = add_mem_to_lams(place, arg->proj(i));
225 auto rewritten = mem_rewritten_[def] = app->rebuild(app->type(), {add_mem_to_lams(place, app->callee()),
226 world().tuple(new_args)->set(arg->dbg())})
229 DLOG(
"memory from axm {} : {}", rewritten, rewritten->type());
// The rewritten application becomes the current memory of `place`; if its first
// projection has type mem::M, that projection is used instead.
230 val2mem_[place] = rewritten;
232 if (rewritten->num_projs() > 0 &&
Axm::isa<mem::M>(rewritten->proj(0)->type())) {
233 DLOG(
"memory from axm 2 {} : {}", rewritten, rewritten->type());
234 mem_rewritten_[rewritten->proj(0)] = rewritten->proj(0);
235 val2mem_[place] = rewritten->proj(0);
// Any other application: rewrite the callee and the argument.
// If the callee's Pi gained exactly one domain (presumably the prepended memory),
// the argument is rewritten again via rewrite_arg so that it supplies the memory.
// NOTE(review): in that case `app->arg()` is rewritten twice and the first result is
// discarded.
241 if (
auto app = def->isa<
App>()) {
242 auto new_callee = add_mem_to_lams(place, app->callee());
243 auto new_arg = add_mem_to_lams(place, app->arg());
244 if (app->callee()->type()->as<
Pi>()->num_doms() + 1 == new_callee->type()->as<
Pi>()->num_doms())
245 new_arg = rewrite_arg(app->arg());
246 auto rewritten = mem_rewritten_[def] = app->rebuild(app->type(), {new_callee, new_arg})->set(app->dbg());
248 DLOG(
"memory from other {} : {}", rewritten, rewritten->type());
// As for axioms: the result (or its mem::M first projection) becomes the current
// memory of `place`.
249 val2mem_[place] = rewritten;
251 if (rewritten->num_projs() > 0 &&
Axm::isa<mem::M>(rewritten->proj(0)->type())) {
252 DLOG(
"memory from other 2 {} : {}", rewritten, rewritten->type());
253 mem_rewritten_[rewritten->proj(0)] = rewritten->proj(0);
254 val2mem_[place] = rewritten->proj(0);
// Fallback for all other Defs: rewrite every operand.
// Operands of type mem::M are rewritten for their side effects and replaced by the
// current memory of `place`.
// The Def is rebuilt with its type passed through rewrite_type.
259 auto new_ops =
DefVec(def->ops(), [&](
const Def* op) {
260 if (Axm::isa<mem::M>(op->type())) {
262 add_mem_to_lams(place, op);
263 return add_mem_to_lams(place, mem_for_lam(place));
265 return add_mem_to_lams(place, op);
268 auto tmp = mem_rewritten_[def] = def->rebuild(rewrite_type(def->type()), new_ops)->set(def->dbg());
static auto isa(const Def *def)
bool is_set() const
Yields true if empty or the last op is set.
T * as_mut() const
Asserts that this is a mutable, casts constness away and performs a static_cast to T.
const Def * var(nat_t a, nat_t i) noexcept
void transfer_external(Def *to)
nat_t num_vars() noexcept
bool is_external() const noexcept
const Def * filter() const
Lam * stub(const Def *type)
const Nest & nest() const
Builds a nesting tree of all immutables/binders.
A dependent function type.
const Pi * pi(const Def *dom, const Def *codom, bool implicit=false)
const Def * tuple(Defs ops)
void visit(const Nest &) override
#define DLOG(...)
Vaporizes to nothingness in Debug build.
const Def * mem_var(Lam *lam)
Returns the memory argument of a function if it has one.
Vector< const Def * > DefVec
Vector(I, I, A=A()) -> Vector< typename std::iterator_traits< I >::value_type, Default_Inlined_Size< typename std::iterator_traits< I >::value_type >, A >