// Checks whether `def` is an App and, if so, walks a worklist that is seeded with the App's
// callee. Per the signature the result is a pair of the App and a Vector of Lam*.
// NOTE(review): this listing is an excerpt - the embedded source line numbers skip
// (15, 18, 20, 22, 26 and later), so the worklist loop header, the body of the mutable-Lam
// branch, the pop of the worklist and all return statements are not visible here.
13std::pair<const App*, Vector<Lam*>> isa_apped_mut_lam_in_tuple(
const Def* def) {
// Only Apps are inspected further.
14 if (
auto app = def->isa<App>()) {
// Worklist of Defs still to inspect; starts with the callee of the App.
16 std::deque<const Def*> wl;
17 wl.push_back(app->callee());
// Look at the element at the front of the worklist.
19 auto elem = wl.front();
// Case 1: the element is a mutable Lam (what is done with `mut` is not shown in this excerpt).
21 if (
auto mut = elem->isa_mut<Lam>()) {
23 }
else if (
auto extract = elem->isa<Extract>()) {
// Case 2: the element is an Extract; if it extracts from a literal Tuple, every op of that
// tuple is queued for inspection.
24 if (
auto tuple = extract->tuple()->isa<Tuple>())
25 for (
auto&& op : tuple->ops()) wl.push_back(op);
// Rewrites `def`, which is either a mutable Lam or an Extract from a Tuple whose ops are
// (recursively) of the same shape.
//  - rewrite:     callable invoked with the mutable Lam when `def` is one; its result is returned.
//  - rewrite_idx: callable applied to the index of the Extract.
// For the Extract case a new tuple is built from the rewritten ops and a new extract with the
// rewritten index is returned.
// NOTE(review): std::forward<F>(rewrite) / std::forward<H>(rewrite_idx) are applied on every
// invocation of the per-op lambda below, and rewrite_idx is used again for the final extract.
// If a caller passes rvalue functors whose call consumes state, confirm this re-use is safe.
// NOTE(review): excerpt - embedded line numbers skip 41, 46 and 48+, so the end of the lambda
// and the closing brace of the function are not visible here.
38template<
class F,
class H>
const Def* rewrite_mut_lam_in_tuple(
const Def* def, F&&
rewrite, H&& rewrite_idx) {
39 auto&
w = def->world();
// Base case: a mutable Lam is handed directly to `rewrite`.
40 if (
auto mut = def->isa_mut<Lam>())
return std::forward<F>(
rewrite)(mut);
// Otherwise `def` is cast to an Extract and its tuple operand to a Tuple
// (presumably as<> is an asserting cast like as_mut - TODO confirm).
42 auto extract = def->as<
Extract>();
43 auto tuple = extract->tuple()->as<
Tuple>();
// Recursively rewrite every op of the tuple.
44 auto new_ops =
DefVec(tuple->ops(), [&](
const Def* op) {
45 return rewrite_mut_lam_in_tuple(op, std::forward<F>(rewrite), std::forward<H>(rewrite_idx));
// Re-create the extract over the rebuilt tuple, with the index mapped through rewrite_idx.
47 return w.extract(
w.tuple(new_ops), rewrite_idx(extract->index()));
// Rewrites an App whose callee is a mutable Lam or an Extract from a Tuple of such callees.
//  - rewrite_callee: forwarded to rewrite_mut_lam_in_tuple as the Lam rewriter.
//  - rewrite_arg:    called once with the App's argument; its result becomes the new argument.
//  - rewrite_idx:    forwarded to rewrite_mut_lam_in_tuple as the extract-index rewriter.
// Returns the App rebuilt with its original type and the new {callee, arg} operands.
// NOTE(review): excerpt - the closing brace of the function is not visible here.
51template<
class RewriteCallee,
class RewriteArg,
class RewriteIdx>
52const Def* rewrite_apped_mut_lam_in_tuple(
const Def* def,
53 RewriteCallee&& rewrite_callee,
54 RewriteArg&& rewrite_arg,
55 RewriteIdx&& rewrite_idx) {
// `def` is cast to an App (presumably an asserting cast - TODO confirm).
56 auto app = def->as<
App>();
57 auto callee = rewrite_mut_lam_in_tuple(app->callee(), std::forward<RewriteCallee>(rewrite_callee),
58 std::forward<RewriteIdx>(rewrite_idx));
59 auto arg = std::forward<RewriteArg>(rewrite_arg)(app->arg());
// The App keeps its original type; only the operands are replaced.
60 return app->rebuild(app->type(), {callee, arg});
70 add_mem_to_lams(entry, entry);
// Looks up the memory value associated with `lam`.
// Visible logic: mem_rewritten_ is consulted first, then val2mem_ (a hit there is logged),
// and finally a local `mem` is asserted to be non-null.
// NOTE(review): excerpt - embedded line numbers skip 76-78, 81-85 and 87+, so the statements
// that compute `mem` and every return statement are not visible here.
74const Def* AddMem::mem_for_lam(
Lam* lam)
const {
// Lookup keyed by `lam` in mem_rewritten_; what happens on a hit is not shown in this excerpt.
75 if (
auto it = mem_rewritten_.find(lam); it != mem_rewritten_.end()) {
// Lookup keyed by `lam` in val2mem_; a hit is reported via DLOG together with the mapped value.
79 if (
auto it = val2mem_.find(lam); it != val2mem_.end()) {
80 lam->
world().DLOG(
"found mem for {} in val2mem_ : {}", lam, it->second);
// Invariant check: a memory value must have been determined by this point.
86 assert(
mem &&
"mut must have mem!");
// Rewrites a type: Pi types are delegated to rewrite_pi, previously rewritten types are served
// from mem_rewritten_, and any other type is rebuilt from its recursively rewritten ops.
// The rebuilt result is stored in mem_rewritten_ under the original type before being returned.
// NOTE(review): excerpt - the closing brace of the function is not visible here.
90const Def* AddMem::rewrite_type(
const Def* type) {
// Function types have their own rewriting routine.
91 if (
auto pi = type->isa<Pi>())
return rewrite_pi(pi);
// Cache hit: reuse the earlier result.
93 if (
auto it = mem_rewritten_.find(type); it != mem_rewritten_.end())
return it->second;
// Rewrite each op by index, then rebuild with the unchanged type-of-type and memoize.
95 auto new_ops =
DefVec(type->num_ops(), [&](
size_t i) { return rewrite_type(type->op(i)); });
96 return mem_rewritten_[type] = type->rebuild(type->type(), new_ops);
// Rewrites a Pi type. Results are memoized in mem_rewritten_ keyed by the original Pi.
// Visible logic: every projection of the domain is rewritten via rewrite_type; a DefVec one
// element longer is then built with the mem::M annex at index 0 followed by the rewritten
// projections; finally a new Pi over `new_dom` with the original codomain is created.
// NOTE(review): excerpt - embedded line numbers skip 101, 104-105 and 107-108, so the target of
// the assignment on line 106 and any condition guarding it are not visible here.
99const Def* AddMem::rewrite_pi(
const Pi* pi) {
// Cache hit: reuse the earlier result.
100 if (
auto it = mem_rewritten_.find(pi); it != mem_rewritten_.end())
return it->second;
102 auto dom = pi->dom();
// Rewrite the domain projection by projection.
103 auto new_dom =
DefVec(dom->num_projs(), [&](
size_t i) { return rewrite_type(dom->proj(i)); });
106 =
DefVec(dom->num_projs() + 1, [&](
size_t i) { return i == 0 ? world().annex<mem::M>() : new_dom[i - 1]; });
// The codomain is taken over unchanged from the original Pi.
109 return mem_rewritten_[pi] =
world().
pi(new_dom, pi->codom());
// Rewrites `def` as seen from `curr_lam`, recording rewritten Defs in mem_rewritten_ and the
// memory value currently associated with a Lam in val2mem_.
// Dispatch visible in this excerpt:
//  - Global, unset mutable Lam, Axiom: returned unchanged.
//  - Def already present in mem_rewritten_: handled by the "already known" branch.
//  - mutable Lam: handled by the local rewrite_lam lambda.
//  - Pi: handled by rewrite_pi.
//  - App: four cases (applied mutable Lam in tuple, callee depending on a Var, axiom
//    application, any other App).
//  - anything else: rebuilt from recursively rewritten ops with a rewritten type.
// NOTE(review): excerpt - the embedded source line numbers skip in many places, so closing
// braces, several conditions and several return statements are not visible here. The comments
// below describe only the statements that are visible.
112const Def* AddMem::add_mem_to_lams(Lam* curr_lam,
const Def* def) {
// `place` is the result of sched_.smart(def) cast to Lam*; it is used as the curr_lam argument
// of the recursive calls and as the key into val2mem_ further down.
113 auto place =
static_cast<Lam*
>(sched_.
smart(def));
// Globals are returned unchanged.
117 if (
auto global = def->isa<
Global>())
return global;
// Mutable Lams that are not set are returned unchanged.
118 if (
auto mut_lam = def->isa_mut<Lam>(); mut_lam && !mut_lam->is_set())
return def;
// Axioms are returned unchanged.
119 if (
auto ax = def->isa<
Axiom>())
return ax;
// `def` has been rewritten before: fetch the cached value and the memory of curr_lam.
// What is done with `tmp` / `new_mem` when curr_lam != def is not visible in this excerpt.
120 if (
auto it = mem_rewritten_.find(def); it != mem_rewritten_.end()) {
121 auto tmp = it->second;
123 world().DLOG(
"already known mem {} in {}", def, curr_lam);
124 auto new_mem = mem_for_lam(curr_lam);
125 world().DLOG(
"new mem {} in {}", new_mem, curr_lam);
128 if (curr_lam != def) {
// Local helper that rewrites a mutable Lam (captures by reference).
135 auto rewrite_lam = [&](
Lam* lam) ->
const Def* {
136 auto pi = lam->
type()->as<
Pi>();
// If the Lam already has an entry in mem_rewritten_, that entry is reused as the new Lam;
// the else-if checks whether the mapped Def's Pi already has mem::M as its first domain.
139 if (
auto it = mem_rewritten_.find(lam); it != mem_rewritten_.end()) {
141 new_lam = it->second->as_mut<
Lam>();
142 else if (
auto pi = it->second->type()->as<Pi>(); pi->num_doms() > 0 &&
match<mem::M>(pi->dom(0_s)))
// Unset Lams are left alone.
146 if (!lam->
is_set())
return lam;
147 world().DLOG(
"rewrite lam {}", lam);
// A Lam counts as bound if the scheduler's scope reports it as bound or it is curr_lam itself.
149 bool is_bound = sched_.
scope().
bound(lam) || lam == curr_lam;
// Only create a stub when rewriting the Pi actually changed it.
152 if (
auto new_pi = rewrite_pi(pi); new_pi != pi) new_lam = lam->
stub(new_pi);
155 world().DLOG(
"free lam {}", lam);
156 mem_rewritten_[lam] = new_lam;
// Map the old Lam's variables to the new Lam's variables; var_offset is the number of
// domains the new Lam has in addition to the old one.
160 auto var_offset = new_lam->num_doms() - lam->num_doms();
161 if (lam->
num_vars() != 0) mem_rewritten_[lam->
var()] = new_lam->var();
162 for (
size_t i = 0; i < lam->
num_vars() && new_lam->num_vars() > 1; ++i)
163 mem_rewritten_[lam->
var(i)] = new_lam->var(i + var_offset);
// The first variable of the new Lam is registered as that Lam's memory value in val2mem_,
// and both Lams as well as that variable are registered in mem_rewritten_.
165 auto var = new_lam->var(0_n);
166 mem_rewritten_[new_lam] = new_lam;
167 mem_rewritten_[lam] = new_lam;
168 val2mem_[new_lam] = var;
170 mem_rewritten_[var] = var;
// Filter and body are rewritten with the old Lam as the current Lam, then installed on new_lam.
171 auto filter = add_mem_to_lams(lam, lam->
filter());
172 auto body = add_mem_to_lams(lam, lam->
body());
173 new_lam->unset()->set({filter, body});
// Mutable Lams are handled by the helper above; no other mutable is expected past this point.
180 if (
auto lam = def->isa_mut<Lam>())
return rewrite_lam(lam);
181 assert(!def->isa_mut());
// Function types are rewritten via rewrite_pi.
183 if (
auto pi = def->isa<Pi>())
return rewrite_pi(pi);
// Local helper that rewrites an argument and places the current memory at index 0.
// offset is 0 when the argument's type already has mem::M as its first projection, else 1.
185 auto rewrite_arg = [&](
const Def* arg) ->
const Def* {
186 size_t offset = (arg->type()->num_projs() > 0 &&
match<mem::M>(arg->type()->proj(0))) ? 0 : 1;
// The first projection is rewritten and the result discarded (the guarding condition, if any,
// is not visible in this excerpt).
189 add_mem_to_lams(place, arg->proj(0));
// Slots are filled from the last index down to 0: index 0 receives the rewritten memory of
// `place`, every other index the rewritten projection i - offset. The assignment target on
// line 195 is not visible (presumably new_args[i] - TODO confirm).
192 DefVec new_args{arg->type()->num_projs() + offset};
193 for (
int i = new_args.size() - 1; i >= 0; i--) {
195 = i == 0 ? add_mem_to_lams(place, mem_for_lam(place)) : add_mem_to_lams(place, arg->proj(i - offset));
197 return arg->world().tuple(new_args);
// App case 1: isa_apped_mut_lam_in_tuple reports an App; callee and argument are rewritten
// through the helpers above and the result is memoized.
201 if (
auto apped_mut = isa_apped_mut_lam_in_tuple(def); apped_mut.first) {
202 return mem_rewritten_[def]
203 = rewrite_apped_mut_lam_in_tuple(def, std::move(rewrite_lam), std::move(rewrite_arg),
204 [&](
const Def* def) {
return add_mem_to_lams(place, def); });
// App case 2: the callee depends on a Var; rebuild with rewritten callee and rewrite_arg(arg).
208 if (
auto app = def->isa<App>(); app && (app->callee()->dep() &
Dep::Var)) {
209 return mem_rewritten_[def]
210 = app->rebuild(app->type(), {add_mem_to_lams(place, app->callee()), rewrite_arg(app->arg())});
// App case 3: application of an axiom.
// NOTE(review): `app->curry() ^ 0x8000` is a bitwise XOR used as a condition, so it is true for
// every curry() value except exactly 0x8000 - confirm this is the intended test.
215 if (
auto app = def->isa<App>(); app && app->axiom() && app->curry() ^ 0x8000) {
216 auto arg = app->arg();
217 DefVec new_args(arg->num_projs());
// Projections are processed from last to first. The visible statements show one path where the
// projection is rewritten (result discarded) and the slot receives the rewritten memory of
// `place`, and one path where the slot receives the rewritten projection; the conditions
// selecting between them are not visible in this excerpt.
218 for (
int i = new_args.size() - 1; i >= 0; i--) {
222 add_mem_to_lams(place, arg->proj(i));
223 new_args[i] = add_mem_to_lams(place, mem_for_lam(place));
225 new_args[i] = add_mem_to_lams(place, arg->proj(i));
// Rebuild the App with the rewritten callee and the new argument tuple, and memoize it.
228 auto rewritten = mem_rewritten_[def] = app->rebuild(app->type(), {add_mem_to_lams(place, app->callee()),
229 world().tuple(new_args)->set(arg->dbg())})
232 world().DLOG(
"memory from axiom {} : {}", rewritten, rewritten->type());
// The rewritten App becomes the memory value of `place`; if its first projection has type
// mem::M, that projection is recorded instead.
233 val2mem_[place] = rewritten;
235 if (rewritten->num_projs() > 0 &&
match<mem::M>(rewritten->proj(0)->type())) {
236 world().DLOG(
"memory from axiom 2 {} : {}", rewritten, rewritten->type());
237 mem_rewritten_[rewritten->proj(0)] = rewritten->proj(0);
238 val2mem_[place] = rewritten->proj(0);
// App case 4: any other App.
244 if (
auto app = def->isa<App>()) {
245 auto new_callee = add_mem_to_lams(place, app->callee());
246 auto new_arg = add_mem_to_lams(place, app->arg());
// If the rewritten callee takes exactly one more domain than the original, the argument is
// rebuilt through rewrite_arg so that it carries the memory as well.
247 if (app->callee()->type()->as<Pi>()->num_doms() + 1 == new_callee->type()->as<Pi>()->num_doms())
248 new_arg = rewrite_arg(app->arg());
249 auto rewritten = mem_rewritten_[def] = app->rebuild(app->type(), {new_callee, new_arg})->set(app->dbg());
251 world().DLOG(
"memory from other {} : {}", rewritten, rewritten->type());
// Same memory bookkeeping as in the axiom case.
252 val2mem_[place] = rewritten;
254 if (rewritten->num_projs() > 0 &&
match<mem::M>(rewritten->proj(0)->type())) {
255 world().DLOG(
"memory from other 2 {} : {}", rewritten, rewritten->type());
256 mem_rewritten_[rewritten->proj(0)] = rewritten->proj(0);
257 val2mem_[place] = rewritten->proj(0);
// Fallback: rewrite every op. Ops whose type is mem::M are rewritten (result discarded) and
// then replaced by the rewritten memory of `place`; all other ops are rewritten normally.
262 auto new_ops =
DefVec(def->ops(), [&](
const Def* op) {
263 if (match<mem::M>(op->type())) {
265 add_mem_to_lams(place, op);
266 return add_mem_to_lams(place, mem_for_lam(place));
268 return add_mem_to_lams(place, op);
// Rebuild `def` with a rewritten type and the new ops, keep its debug info, and memoize it.
// What happens with `tmp` afterwards is not visible in this excerpt.
271 auto tmp = mem_rewritten_[def] = def->rebuild(rewrite_type(def->type()), new_ops)->set(def->dbg());
bool is_set() const
Yields true if empty or the last op is set.
T * as_mut() const
Asserts that this is a mutable, casts constness away and performs a static_cast to T.
T * isa_mut() const
If this is *mut*able, it will cast constness away and perform a dynamic_cast to T.
void transfer_external(Def *to)
Ref var(nat_t a, nat_t i)
const Scope & scope() const
const Scope & scope() const
A Scope represents a region of Defs that are live from the view of an entry's Var.
const MutSet & free_muts() const
All muts that occur free in this Scope.
bool bound(const Def *def) const
const Pi * pi(Ref dom, Ref codom, bool implicit=false)
void visit(const Scope &) override
Ref mem_var(Lam *lam)
Returns the memory argument of a function if it has one.
Vector< const Def * > DefVec
DefVec rewrite(Def *mut, Ref arg)
Vector(I, I, A=A()) -> Vector< typename std::iterator_traits< I >::value_type, Default_Inlined_Size< typename std::iterator_traits< I >::value_type >, A >