MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
reshape.cpp
Go to the documentation of this file.
2
3#include <functional>
4#include <sstream>
5#include <vector>
6
7#include <mim/check.h>
8#include <mim/def.h>
9#include <mim/lam.h>
10#include <mim/tuple.h>
11
12#include "mim/plug/mem/mem.h"
13
14namespace mim::plug::mem {
15
16namespace {
17
18bool is_mem_ty(const Def* T) { return match<mem::M>(T); }
19
20// TODO merge with should_flatten from tuple.*
21bool should_flatten(const Def* T) {
22 // handle [] cases
23 if (T->isa<Sigma>()) return true;
24 // also handle normalized tuple-arrays ((a:I32,b:I32) : <<2;I32>>)
25 // TODO: handle better than with magic number
26 // (do we want to flatten any array with more than 2 elements?)
27 // (2 elements are needed for conditionals)
28 // TODO: e.g. lea explicitely does not want to flatten
29
30 // TODO: annotate with test cases that need these special cases
31 // Problem with 2 Arr -> flatten
32 // lea (2, <<2;I32>>, ...) -> lea (2, I32, I32, ...)
33 if (auto lit = T->arity()->isa<Lit>(); lit && lit->get<u64>() <= 2) {
34 if (auto arr = T->isa<Arr>(); arr && arr->body()->isa<Pi>()) return lit->get<u64>() > 1;
35 }
36 return false;
37}
38
39// TODO merge with tuple.*
40DefVec flatten_ty(const Def* T) {
41 DefVec types;
42 if (should_flatten(T)) {
43 for (auto P : T->projs()) {
44 auto inner_types = flatten_ty(P);
45 types.insert(types.end(), inner_types.begin(), inner_types.end());
46 }
47 } else {
48 types.push_back(T);
49 }
50 return types;
51}
52
53// TODO try to remove code duplication with flatten_ty
54DefVec flatten_def(const Def* def) {
55 DefVec defs;
56 if (should_flatten(def->type())) {
57 for (auto P : def->projs()) {
58 auto inner_defs = flatten_def(P);
59 defs.insert(defs.end(), inner_defs.begin(), inner_defs.end());
60 }
61 } else {
62 defs.push_back(def);
63 }
64 return defs;
65}
66
67} // namespace
68
69void Reshape::enter() { rewrite_def(curr_mut()); }
70
71const Def* Reshape::rewrite_def(const Def* def) {
72 if (auto i = old2new_.find(def); i != old2new_.end()) return i->second;
73 auto new_def = rewrite_def_(def);
74 old2new_[def] = new_def;
75 return new_def;
76}
77
78const Def* Reshape::rewrite_def_(const Def* def) {
79 // We ignore types, Globals, and Axioms.
80 switch (def->node()) {
81 // TODO: check if bot: Cn[[A,B],Cn[Ret]] is handled correctly
82 // case Node::Bot:
83 // case Node::Top:
84 case Node::Type:
85 case Node::Univ:
86 case Node::Nat:
87 case Node::Axiom:
88 case Node::Global: return def;
89 }
90
91 // This is dead code for debugging purposes.
92 // It allows for inspection of the current def.
93 std::stringstream ss;
94 ss << def << " : " << def->type() << " [" << def->node_name() << "]";
95 std::string str = ss.str();
96
97 // vars are handled by association.
98 if (def->isa<Var>()) world().ELOG("Var: {}", def);
99 assert(!def->isa<Var>());
100
101 if (auto app = def->isa<App>()) {
102 auto callee = rewrite_def(app->callee());
103 auto arg = rewrite_def(app->arg());
104
105 world().DLOG("callee: {} : {}", callee, callee->type());
106
107 // Reshape normally (not to callee) to ensure that callee is reshaped correctly.
108 auto reshaped_arg = reshape(arg);
109 world().DLOG("reshape arg {} : {}", arg, arg->type());
110 world().DLOG("into arg {} : {}", reshaped_arg, reshaped_arg->type());
111 auto new_app = world().app(callee, reshaped_arg);
112 return new_app;
113 } else if (auto lam = def->isa_mut<Lam>()) {
114 world().DLOG("rewrite_def lam {} : {}", def, def->type());
115 auto new_lam = reshape_lam(lam);
116 world().DLOG("rewrote lam {} : {}", def, def->type());
117 world().DLOG("into lam {} : {}", new_lam, new_lam->type());
118 return new_lam;
119 } else if (auto tuple = def->isa<Tuple>()) {
120 auto elements = DefVec(tuple->ops(), [&](const Def* op) { return rewrite_def(op); });
121 return world().tuple(elements);
122 } else {
123 auto new_ops = DefVec(def->num_ops(), [&](auto i) { return rewrite_def(def->op(i)); });
124 // Warning: if the new_type is not correct, inconcistencies will arise.
125 auto new_type = rewrite_def(def->type());
126 auto new_def = def->rebuild(new_type, new_ops);
127 return new_def;
128 }
129}
130
131Lam* Reshape::reshape_lam(Lam* old_lam) {
132 if (!old_lam->is_set()) {
133 world().DLOG("reshape_lam: {} is not a set", old_lam);
134 return old_lam;
135 }
136 auto pi_ty = old_lam->type();
137 auto new_ty = reshape_type(pi_ty)->as<Pi>();
138
139 Lam* new_lam;
140 if (*old_lam->sym() == "main") {
141 new_lam = old_lam;
142 } else {
143 new_lam = old_lam->stub(new_ty);
144 if (!old_lam->is_external()) new_lam->debug_suffix("_reshape");
145 old2new_[old_lam] = new_lam;
146 }
147
148 world().DLOG("Reshape lam: {} : {}", old_lam, pi_ty);
149 world().DLOG(" to: {} : {}", new_lam, new_ty);
150
151 // We associate the arguments (reshape the old vars).
152 // Alternatively, we could use beta reduction (reduce) to do this for us.
153 auto new_arg = new_lam->var();
154
155 // We deeply associate `old_lam->var()` with `new_arg` in a reconstructed shape.
156 // Idea: first make new_arg into "atomic" old_lam list, then recrusively imitate `old_lam->var`.
157 auto reformed_new_arg = reshape(new_arg, old_lam->var()->type()); // `old_lam->var()->type() = pi_ty`
158 world().DLOG("var {} : {}", old_lam->var(), old_lam->var()->type());
159 world().DLOG("new var {} : {}", new_arg, new_arg->type());
160 world().DLOG("reshaped new_var {} : {}", reformed_new_arg, reformed_new_arg->type());
161 world().DLOG("{}", old_lam->var()->type());
162 world().DLOG("{}", reformed_new_arg->type());
163 old2new_[old_lam->var()] = reformed_new_arg;
164 // TODO: add if necessary. This probably was an issue with unintended overriding due to bad previous naming.
165 // TODO: Remove after testing.
166 // old2new_[new_arg] = new_arg;
167
168 auto new_body = rewrite_def(old_lam->body());
169 auto new_filter = rewrite_def(old_lam->filter());
170 new_lam->unset();
171 new_lam->set(new_filter, new_body);
172
173 if (old_lam->is_external()) old_lam->transfer_external(new_lam);
174
175 world().DLOG("finished transforming: {} : {}", new_lam, new_ty);
176 return new_lam;
177}
178
179const Def* Reshape::reshape_type(const Def* T) {
180 if (auto pi = T->isa<Pi>()) {
181 auto new_dom = reshape_type(pi->dom());
182 auto new_cod = reshape_type(pi->codom());
183 return world().pi(new_dom, new_cod);
184 } else if (auto sigma = T->isa<Sigma>()) {
185 auto flat_types = flatten_ty(sigma);
186 auto new_types = DefVec(flat_types.size());
187 std::ranges::transform(flat_types, new_types.begin(), [&](auto T) { return reshape_type(T); });
188 if (mode_ == Mode::Flat) {
189 const Def* mem = nullptr;
190 // find mem
191 for (auto i = new_types.begin(); i != new_types.end(); i++)
192 if (is_mem_ty(*i) && !mem) mem = *i;
193 // filter out mems
194 new_types.erase(std::remove_if(new_types.begin(), new_types.end(), is_mem_ty), new_types.end());
195 // readd mem in the front
196 if (mem) new_types.insert(new_types.begin(), mem);
197 auto reshaped_type = world().sigma(new_types);
198 return reshaped_type;
199 } else {
200 if (new_types.size() == 0) return world().sigma();
201 if (new_types.size() == 1) return new_types[0];
202 const Def* mem = nullptr;
203 const Def* ret = nullptr;
204 // find mem
205 for (auto i = new_types.begin(); i != new_types.end(); i++)
206 if (is_mem_ty(*i) && !mem) mem = *i;
207 // filter out mems
208 new_types.erase(std::remove_if(new_types.begin(), new_types.end(), is_mem_ty), new_types.end());
209 // TODO: more fine-grained test
210 if (new_types.back()->isa<Pi>()) {
211 ret = new_types.back();
212 new_types.pop_back();
213 }
214 // Create the arg form `[[mem,args],ret]`
215 const Def* args = world().sigma(new_types);
216 if (mem) args = world().sigma({mem, args});
217 if (ret) args = world().sigma({args, ret});
218 return args;
219 }
220 } else {
221 return T;
222 }
223}
224
225const Def* Reshape::reshape(DefVec& defs, const Def* T, const Def* mem) {
226 auto& world = T->world();
227 if (should_flatten(T)) {
228 auto tuples = T->projs([&](auto P) { return reshape(defs, P, mem); });
229 return world.tuple(tuples);
230 } else {
231 const Def* def;
232 if (is_mem_ty(T)) {
233 assert(mem != nullptr && "Reshape: mems not found");
234 def = mem;
235 } else {
236 do {
237 assert(defs.size() > 0 && "Reshape: not enough arguments");
238 def = defs.front();
239 defs.erase(defs.begin());
240 } while (is_mem_ty(def->type()));
241 }
242 // For inner function types, we override the type
243 if (!def->type()->isa<Pi>()) {
244 if (!Check::alpha(def->type(), T)) world.ELOG("reconstruct T {} from def {}", T, def->type());
245 assert(Check::alpha(def->type(), T) && "Reshape: argument type mismatch");
246 }
247 return def;
248 }
249}
250
251const Def* Reshape::reshape(const Def* def, const Def* target) {
252 def->world().DLOG("reshape:\n {} =>\n {}", def->type(), target);
253 auto flat_defs = flatten_def(def);
254 const Def* mem = nullptr;
255 // find mem
256 for (auto i = flat_defs.begin(); i != flat_defs.end(); i++)
257 if (is_mem_ty((*i)->type()) && !mem) mem = *i;
258 def->world().DLOG("mem: {}", mem);
259 return reshape(flat_defs, target, mem);
260}
261
262// called for new lambda arguments, app arguments
263// We can not (directly) replace it with the more general version above due to the mem erasure.
264// TODO: ignore mem erase, replace with more general
265// TODO: capture names
266const Def* Reshape::reshape(const Def* def) {
267 auto flat_defs = flatten_def(def);
268 if (flat_defs.size() == 1) return flat_defs[0];
269 // TODO: move mem removal to flatten_def
270 if (mode_ == Mode::Flat) {
271 const Def* mem = nullptr;
272 // find mem
273 for (auto i = flat_defs.begin(); i != flat_defs.end(); i++)
274 if (is_mem_ty((*i)->type()) && !mem) mem = *i;
275 // filter out mems
276 flat_defs.erase(
277 std::remove_if(flat_defs.begin(), flat_defs.end(), [](const Def* def) { return is_mem_ty(def->type()); }),
278 flat_defs.end());
279 // insert mem
280 if (mem) flat_defs.insert(flat_defs.begin(), mem);
281 return world().tuple(flat_defs);
282 } else {
283 // arg style
284 // [[mem,args],ret]
285 const Def* mem = nullptr;
286 const Def* ret = nullptr;
287 // find mem
288 for (auto i = flat_defs.begin(); i != flat_defs.end(); i++)
289 if (is_mem_ty((*i)->type()) && !mem) mem = *i;
290 // filter out mems
291 flat_defs.erase(
292 std::remove_if(flat_defs.begin(), flat_defs.end(), [](const Def* def) { return is_mem_ty(def->type()); }),
293 flat_defs.end());
294 if (flat_defs.back()->type()->isa<Pi>()) {
295 ret = flat_defs.back();
296 flat_defs.pop_back();
297 }
298 const Def* args = world().tuple(flat_defs);
299 if (mem) args = world().tuple({mem, args});
300 if (ret) args = world().tuple({args, ret});
301 return args;
302 }
303}
304
305} // namespace mim::plug::mem
static bool alpha(Ref d1, Ref d2)
Are d1 and d2 α-equivalent?
Definition check.h:65
Base class for all Defs.
Definition def.h:223
size_t num_ops() const
Definition def.h:270
std::string_view node_name() const
Definition def.cpp:431
T * isa_mut() const
If this is *mut*able, it will cast constness away and perform a dynamic_cast to T.
Definition def.h:447
const Def * type() const
Definition def.h:248
node_t node() const
Definition def.h:241
Ref rebuild(World &w, Ref type, Defs ops) const
Rebuilds this Def with the given type and ops.
Definition def.h:509
World & world()
Definition pass.h:296
Lam * curr_mut() const
Definition pass.h:232
Ref tuple(Defs ops)
Definition world.cpp:238
Ref var(Ref type, Def *mut)
Definition world.cpp:156
Ref app(Ref callee, Ref arg)
Definition world.cpp:186
const Pi * pi(Ref dom, Ref codom, bool implicit=false)
Definition world.h:255
Ref sigma(Defs ops)
Definition world.cpp:230
void enter() override
Fall-through to rewrite_def, which in turn dispatches mutable lambdas to reshape_lam.
Definition reshape.cpp:69
@ Type
Definition def.h:40
@ Univ
Definition def.h:40
@ Nat
Definition def.h:40
@ Axiom
Definition def.h:40
@ Global
Definition def.h:40
@ Pi
Definition def.h:40
@ Lam
Definition def.h:40
The mem Plugin
Definition mem.h:11
Vector< const Def * > DefVec
Definition def.h:62
auto match(Ref def)
Definition axiom.h:112
uint64_t u64
Definition types.h:34