11std::pair<const App*, Vector<Lam*>> isa_apped_mut_lam_in_tuple(
const Def* def) {
12 if (
auto app = def->isa<App>()) {
14 std::deque<const Def*> wl;
15 wl.push_back(app->callee());
17 auto elem = wl.front();
19 if (
auto mut = elem->isa_mut<Lam>()) {
21 }
else if (
auto extract = elem->isa<Extract>()) {
22 if (
auto tuple = extract->tuple()->isa<Tuple>())
23 for (
auto&& op : tuple->ops()) wl.push_back(op);
36template<
class F,
class H>
const Def* rewrite_mut_lam_in_tuple(
const Def* def, F&& rewrite, H&& rewrite_idx) {
37 auto&
w = def->world();
38 if (
auto mut = def->isa_mut<Lam>())
return std::forward<F>(rewrite)(mut);
40 auto extract = def->as<
Extract>();
41 auto tuple = extract->tuple()->as<
Tuple>();
42 auto new_ops =
DefVec(tuple->ops(), [&](
const Def* op) {
43 return rewrite_mut_lam_in_tuple(op, std::forward<F>(rewrite), std::forward<H>(rewrite_idx));
45 return w.extract(
w.tuple(new_ops), rewrite_idx(extract->index()));
49template<
class RewriteCallee,
class RewriteArg,
class RewriteIdx>
50const Def* rewrite_apped_mut_lam_in_tuple(
const Def* def,
51 RewriteCallee&& rewrite_callee,
52 RewriteArg&& rewrite_arg,
53 RewriteIdx&& rewrite_idx) {
54 auto app = def->as<
App>();
55 auto callee = rewrite_mut_lam_in_tuple(app->callee(), std::forward<RewriteCallee>(rewrite_callee),
56 std::forward<RewriteIdx>(rewrite_idx));
57 auto arg = std::forward<RewriteArg>(rewrite_arg)(app->arg());
58 return app->rebuild(app->type(), {callee, arg});
69const Def* AddMem::mem_for_lam(
Lam* lam)
const {
70 if (
auto it = mem_rewritten_.find(lam); it != mem_rewritten_.end()) {
74 if (
auto it = val2mem_.find(lam); it != val2mem_.end()) {
75 lam->
world().DLOG(
"found mem for {} in val2mem_ : {}", lam, it->second);
81 assert(
mem &&
"mut must have mem!");
85const Def* AddMem::rewrite_type(
const Def* type) {
86 if (
auto pi = type->isa<Pi>())
return rewrite_pi(pi);
88 if (
auto it = mem_rewritten_.find(type); it != mem_rewritten_.end())
return it->second;
90 auto new_ops =
DefVec(type->num_ops(), [&](
size_t i) { return rewrite_type(type->op(i)); });
91 return mem_rewritten_[type] = type->rebuild(type->type(), new_ops);
94const Def* AddMem::rewrite_pi(
const Pi* pi) {
95 if (
auto it = mem_rewritten_.find(pi); it != mem_rewritten_.end())
return it->second;
98 auto new_dom =
DefVec(dom->num_projs(), [&](
size_t i) { return rewrite_type(dom->proj(i)); });
101 =
DefVec(dom->num_projs() + 1, [&](
size_t i) { return i == 0 ? world().annex<mem::M>() : new_dom[i - 1]; });
104 return mem_rewritten_[pi] =
world().
pi(new_dom, pi->codom());
107const Def* AddMem::add_mem_to_lams(Lam* curr_lam,
const Def* def) {
108 auto place =
static_cast<Lam*
>(sched_.
smart(curr_lam, def)->mut());
112 if (
auto global = def->isa<
Global>())
return global;
113 if (
auto mut_lam = def->isa_mut<Lam>(); mut_lam && !mut_lam->is_set())
return def;
114 if (
auto ax = def->isa<
Axiom>())
return ax;
115 if (
auto it = mem_rewritten_.find(def); it != mem_rewritten_.end()) {
116 auto tmp = it->second;
118 world().DLOG(
"already known mem {} in {}", def, curr_lam);
119 auto new_mem = mem_for_lam(curr_lam);
120 world().DLOG(
"new mem {} in {}", new_mem, curr_lam);
123 if (curr_lam != def) {
130 auto rewrite_lam = [&](
Lam* lam) ->
const Def* {
131 auto pi = lam->
type()->as<
Pi>();
134 if (
auto it = mem_rewritten_.find(lam); it != mem_rewritten_.end()) {
136 new_lam = it->second->as_mut<
Lam>();
137 else if (
auto pi = it->second->type()->as<Pi>(); pi->num_doms() > 0 &&
match<mem::M>(pi->dom(0_s)))
141 if (!lam->
is_set())
return lam;
142 world().DLOG(
"rewrite lam {}", lam);
144 bool is_bound = sched_.
nest().
contains(lam) || lam == curr_lam;
147 if (
auto new_pi = rewrite_pi(pi); new_pi != pi) new_lam = lam->
stub(new_pi);
150 world().DLOG(
"free lam {}", lam);
151 mem_rewritten_[lam] = new_lam;
155 auto var_offset = new_lam->num_doms() - lam->num_doms();
156 if (lam->
num_vars() != 0) mem_rewritten_[lam->
var()] = new_lam->var();
157 for (
size_t i = 0; i < lam->
num_vars() && new_lam->num_vars() > 1; ++i)
158 mem_rewritten_[lam->
var(i)] = new_lam->var(i + var_offset);
160 auto var = new_lam->var(0_n);
161 mem_rewritten_[new_lam] = new_lam;
162 mem_rewritten_[lam] = new_lam;
163 val2mem_[new_lam] = var;
165 mem_rewritten_[var] = var;
166 auto filter = add_mem_to_lams(lam, lam->
filter());
167 auto body = add_mem_to_lams(lam, lam->
body());
168 new_lam->unset()->set({filter, body});
175 if (
auto lam = def->isa_mut<Lam>())
return rewrite_lam(lam);
176 assert(!def->isa_mut());
178 if (
auto pi = def->isa<Pi>())
return rewrite_pi(pi);
180 auto rewrite_arg = [&](
const Def* arg) ->
const Def* {
181 size_t offset = (arg->type()->num_projs() > 0 &&
match<mem::M>(arg->type()->proj(0))) ? 0 : 1;
184 add_mem_to_lams(place, arg->proj(0));
187 DefVec new_args{arg->type()->num_projs() + offset};
188 for (
int i = new_args.size() - 1; i >= 0; i--) {
190 = i == 0 ? add_mem_to_lams(place, mem_for_lam(place)) : add_mem_to_lams(place, arg->proj(i - offset));
192 return arg->world().tuple(new_args);
196 if (
auto apped_mut = isa_apped_mut_lam_in_tuple(def); apped_mut.first) {
197 return mem_rewritten_[def]
198 = rewrite_apped_mut_lam_in_tuple(def, std::move(rewrite_lam), std::move(rewrite_arg),
199 [&](
const Def* def) {
return add_mem_to_lams(place, def); });
203 if (
auto app = def->isa<App>(); app && (app->callee()->has_dep(
Dep::Var))) {
204 return mem_rewritten_[def]
205 = app->rebuild(app->type(), {add_mem_to_lams(place, app->callee()), rewrite_arg(app->arg())});
210 if (
auto app = def->isa<App>(); app && app->axiom() && app->curry() ^ 0x8000) {
211 auto arg = app->arg();
212 DefVec new_args(arg->num_projs());
213 for (
int i = new_args.size() - 1; i >= 0; i--) {
217 add_mem_to_lams(place, arg->proj(i));
218 new_args[i] = add_mem_to_lams(place, mem_for_lam(place));
220 new_args[i] = add_mem_to_lams(place, arg->proj(i));
223 auto rewritten = mem_rewritten_[def] = app->rebuild(app->type(), {add_mem_to_lams(place, app->callee()),
224 world().tuple(new_args)->set(arg->dbg())})
227 world().DLOG(
"memory from axiom {} : {}", rewritten, rewritten->type());
228 val2mem_[place] = rewritten;
230 if (rewritten->num_projs() > 0 &&
match<mem::M>(rewritten->proj(0)->type())) {
231 world().DLOG(
"memory from axiom 2 {} : {}", rewritten, rewritten->type());
232 mem_rewritten_[rewritten->proj(0)] = rewritten->proj(0);
233 val2mem_[place] = rewritten->proj(0);
239 if (
auto app = def->isa<App>()) {
240 auto new_callee = add_mem_to_lams(place, app->callee());
241 auto new_arg = add_mem_to_lams(place, app->arg());
242 if (app->callee()->type()->as<Pi>()->num_doms() + 1 == new_callee->type()->as<Pi>()->num_doms())
243 new_arg = rewrite_arg(app->arg());
244 auto rewritten = mem_rewritten_[def] = app->rebuild(app->type(), {new_callee, new_arg})->set(app->dbg());
246 world().DLOG(
"memory from other {} : {}", rewritten, rewritten->type());
247 val2mem_[place] = rewritten;
249 if (rewritten->num_projs() > 0 &&
match<mem::M>(rewritten->proj(0)->type())) {
250 world().DLOG(
"memory from other 2 {} : {}", rewritten, rewritten->type());
251 mem_rewritten_[rewritten->proj(0)] = rewritten->proj(0);
252 val2mem_[place] = rewritten->proj(0);
257 auto new_ops =
DefVec(def->ops(), [&](
const Def* op) {
258 if (match<mem::M>(op->type())) {
260 add_mem_to_lams(place, op);
261 return add_mem_to_lams(place, mem_for_lam(place));
263 return add_mem_to_lams(place, op);
266 auto tmp = mem_rewritten_[def] = def->rebuild(rewrite_type(def->type()), new_ops)->set(def->dbg());
bool is_set() const
Yields true if empty or the last op is set.
T * as_mut() const
Asserts that this is a mutable, casts constness away and performs a static_cast to T.
World & world() const noexcept
void transfer_external(Def *to)
nat_t num_vars() noexcept
Ref var(nat_t a, nat_t i) noexcept
const Nest & nest() const
Builds a nesting tree of all immutables‍/binders.
bool contains(const Def *def) const
const Nest::Node * smart(Def *curr, const Def *)
const Nest & nest() const
const Pi * pi(Ref dom, Ref codom, bool implicit=false)
void visit(const Nest &) override
Ref mem_var(Lam *lam)
Returns the memory argument of a function if it has one.
Vector< const Def * > DefVec
Vector(I, I, A=A()) -> Vector< typename std::iterator_traits< I >::value_type, Default_Inlined_Size< typename std::iterator_traits< I >::value_type >, A >