MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
reshape.cpp
Go to the documentation of this file.
2
3#include <functional>
4#include <sstream>
5#include <vector>
6
7#include <mim/check.h>
8#include <mim/def.h>
9#include <mim/lam.h>
10#include <mim/tuple.h>
11
12#include "mim/plug/mem/mem.h"
13
14namespace mim::plug::mem {
15
16namespace {
17
18bool is_mem_ty(const Def* T) { return match<mem::M>(T); }
19
20// TODO merge with should_flatten from tuple.*
21bool should_flatten(const Def* T) {
22 // handle [] cases
23 if (T->isa<Sigma>()) return true;
24 // also handle normalized tuple-arrays ((a:I32,b:I32) : <<2;I32>>)
25 // TODO: handle better than with magic number
26 // (do we want to flatten any array with more than 2 elements?)
27 // (2 elements are needed for conditionals)
28 // TODO: e.g. lea explicitely does not want to flatten
29
30 // TODO: annotate with test cases that need these special cases
31 // Problem with 2 Arr -> flatten
32 // lea (2, <<2;I32>>, ...) -> lea (2, I32, I32, ...)
33 if (auto lit = T->arity()->isa<Lit>(); lit && lit->get<u64>() <= 2) {
34 if (auto arr = T->isa<Arr>(); arr && arr->body()->isa<Pi>()) return lit->get<u64>() > 1;
35 }
36 return false;
37}
38
39// TODO merge with tuple.*
40DefVec flatten_ty(const Def* T) {
41 DefVec types;
42 if (should_flatten(T)) {
43 for (auto P : T->projs()) {
44 auto inner_types = flatten_ty(P);
45 types.insert(types.end(), inner_types.begin(), inner_types.end());
46 }
47 } else {
48 types.push_back(T);
49 }
50 return types;
51}
52
53// TODO try to remove code duplication with flatten_ty
54DefVec flatten_def(const Def* def) {
55 DefVec defs;
56 if (should_flatten(def->type())) {
57 for (auto P : def->projs()) {
58 auto inner_defs = flatten_def(P);
59 defs.insert(defs.end(), inner_defs.begin(), inner_defs.end());
60 }
61 } else {
62 defs.push_back(def);
63 }
64 return defs;
65}
66
67} // namespace
68
69void Reshape::enter() { rewrite_def(curr_mut()); }
70
71const Def* Reshape::rewrite_def(const Def* def) {
72 if (auto i = old2new_.find(def); i != old2new_.end()) return i->second;
73 auto new_def = rewrite_def_(def);
74 old2new_[def] = new_def;
75 return new_def;
76}
77
78const Def* Reshape::rewrite_def_(const Def* def) {
79 // We ignore types, Globals, and Axioms.
80 switch (def->node()) {
81 // TODO: check if bot: Cn[[A,B],Cn[Ret]] is handled correctly
82 // case Node::Bot:
83 // case Node::Top:
84 case Node::Type:
85 case Node::Univ:
86 case Node::Nat:
87 case Node::Axiom:
88 case Node::Global: return def;
89 default: break;
90 }
91
92 // This is dead code for debugging purposes.
93 // It allows for inspection of the current def.
94 std::stringstream ss;
95 ss << def << " : " << def->type() << " [" << def->node_name() << "]";
96 std::string str = ss.str();
97
98 // vars are handled by association.
99 if (def->isa<Var>()) world().ELOG("Var: {}", def);
100 assert(!def->isa<Var>());
101
102 if (auto app = def->isa<App>()) {
103 auto callee = rewrite_def(app->callee());
104 auto arg = rewrite_def(app->arg());
105
106 world().DLOG("callee: {} : {}", callee, callee->type());
107
108 // Reshape normally (not to callee) to ensure that callee is reshaped correctly.
109 auto reshaped_arg = reshape(arg);
110 world().DLOG("reshape arg {} : {}", arg, arg->type());
111 world().DLOG("into arg {} : {}", reshaped_arg, reshaped_arg->type());
112 auto new_app = world().app(callee, reshaped_arg);
113 return new_app;
114 } else if (auto lam = def->isa_mut<Lam>()) {
115 world().DLOG("rewrite_def lam {} : {}", def, def->type());
116 auto new_lam = reshape_lam(lam);
117 world().DLOG("rewrote lam {} : {}", def, def->type());
118 world().DLOG("into lam {} : {}", new_lam, new_lam->type());
119 return new_lam;
120 } else if (auto tuple = def->isa<Tuple>()) {
121 auto elements = DefVec(tuple->ops(), [&](const Def* op) { return rewrite_def(op); });
122 return world().tuple(elements);
123 } else {
124 auto new_ops = DefVec(def->num_ops(), [&](auto i) { return rewrite_def(def->op(i)); });
125 // Warning: if the new_type is not correct, inconcistencies will arise.
126 auto new_type = rewrite_def(def->type());
127 auto new_def = def->rebuild(new_type, new_ops);
128 return new_def;
129 }
130}
131
132Lam* Reshape::reshape_lam(Lam* old_lam) {
133 if (!old_lam->is_set()) {
134 world().DLOG("reshape_lam: {} is not a set", old_lam);
135 return old_lam;
136 }
137 auto pi_ty = old_lam->type();
138 auto new_ty = reshape_type(pi_ty)->as<Pi>();
139
140 Lam* new_lam;
141 if (*old_lam->sym() == "main") {
142 new_lam = old_lam;
143 } else {
144 new_lam = old_lam->stub(new_ty);
145 if (!old_lam->is_external()) new_lam->debug_suffix("_reshape");
146 old2new_[old_lam] = new_lam;
147 }
148
149 world().DLOG("Reshape lam: {} : {}", old_lam, pi_ty);
150 world().DLOG(" to: {} : {}", new_lam, new_ty);
151
152 // We associate the arguments (reshape the old vars).
153 // Alternatively, we could use beta reduction (reduce) to do this for us.
154 auto new_arg = new_lam->var();
155
156 // We deeply associate `old_lam->var()` with `new_arg` in a reconstructed shape.
157 // Idea: first make new_arg into "atomic" old_lam list, then recrusively imitate `old_lam->var`.
158 auto reformed_new_arg = reshape(new_arg, old_lam->var()->type()); // `old_lam->var()->type() = pi_ty`
159 world().DLOG("var {} : {}", old_lam->var(), old_lam->var()->type());
160 world().DLOG("new var {} : {}", new_arg, new_arg->type());
161 world().DLOG("reshaped new_var {} : {}", reformed_new_arg, reformed_new_arg->type());
162 world().DLOG("{}", old_lam->var()->type());
163 world().DLOG("{}", reformed_new_arg->type());
164 old2new_[old_lam->var()] = reformed_new_arg;
165 // TODO: add if necessary. This probably was an issue with unintended overriding due to bad previous naming.
166 // TODO: Remove after testing.
167 // old2new_[new_arg] = new_arg;
168
169 auto new_body = rewrite_def(old_lam->body());
170 auto new_filter = rewrite_def(old_lam->filter());
171 new_lam->unset();
172 new_lam->set(new_filter, new_body);
173
174 if (old_lam->is_external()) old_lam->transfer_external(new_lam);
175
176 world().DLOG("finished transforming: {} : {}", new_lam, new_ty);
177 return new_lam;
178}
179
180const Def* Reshape::reshape_type(const Def* T) {
181 if (auto pi = T->isa<Pi>()) {
182 auto new_dom = reshape_type(pi->dom());
183 auto new_cod = reshape_type(pi->codom());
184 return world().pi(new_dom, new_cod);
185 } else if (auto sigma = T->isa<Sigma>()) {
186 auto flat_types = flatten_ty(sigma);
187 auto new_types = DefVec(flat_types.size());
188 std::ranges::transform(flat_types, new_types.begin(), [&](auto T) { return reshape_type(T); });
189 if (mode_ == Mode::Flat) {
190 const Def* mem = nullptr;
191 // find mem
192 for (auto i = new_types.begin(); i != new_types.end(); i++)
193 if (is_mem_ty(*i) && !mem) mem = *i;
194 // filter out mems
195 new_types.erase(std::remove_if(new_types.begin(), new_types.end(), is_mem_ty), new_types.end());
196 // readd mem in the front
197 if (mem) new_types.insert(new_types.begin(), mem);
198 auto reshaped_type = world().sigma(new_types);
199 return reshaped_type;
200 } else {
201 if (new_types.size() == 0) return world().sigma();
202 if (new_types.size() == 1) return new_types[0];
203 const Def* mem = nullptr;
204 const Def* ret = nullptr;
205 // find mem
206 for (auto i = new_types.begin(); i != new_types.end(); i++)
207 if (is_mem_ty(*i) && !mem) mem = *i;
208 // filter out mems
209 new_types.erase(std::remove_if(new_types.begin(), new_types.end(), is_mem_ty), new_types.end());
210 // TODO: more fine-grained test
211 if (new_types.back()->isa<Pi>()) {
212 ret = new_types.back();
213 new_types.pop_back();
214 }
215 // Create the arg form `[[mem,args],ret]`
216 const Def* args = world().sigma(new_types);
217 if (mem) args = world().sigma({mem, args});
218 if (ret) args = world().sigma({args, ret});
219 return args;
220 }
221 } else {
222 return T;
223 }
224}
225
226const Def* Reshape::reshape(DefVec& defs, const Def* T, const Def* mem) {
227 auto& world = T->world();
228 if (should_flatten(T)) {
229 auto tuples = T->projs([&](auto P) { return reshape(defs, P, mem); });
230 return world.tuple(tuples);
231 } else {
232 const Def* def;
233 if (is_mem_ty(T)) {
234 assert(mem != nullptr && "Reshape: mems not found");
235 def = mem;
236 } else {
237 do {
238 assert(defs.size() > 0 && "Reshape: not enough arguments");
239 def = defs.front();
240 defs.erase(defs.begin());
241 } while (is_mem_ty(def->type()));
242 }
243 // For inner function types, we override the type
244 if (!def->type()->isa<Pi>()) {
245 if (!Checker::alpha<Checker::Check>(def->type(), T))
246 world.ELOG("reconstruct T {} from def {}", T, def->type());
247 assert(Checker::alpha<Checker::Check>(def->type(), T) && "Reshape: argument type mismatch");
248 }
249 return def;
250 }
251}
252
253const Def* Reshape::reshape(const Def* def, const Def* target) {
254 world().DLOG("reshape:\n {} =>\n {}", def->type(), target);
255 auto flat_defs = flatten_def(def);
256 const Def* mem = nullptr;
257 // find mem
258 for (auto i = flat_defs.begin(); i != flat_defs.end(); i++)
259 if (is_mem_ty((*i)->type()) && !mem) mem = *i;
260 world().DLOG("mem: {}", mem);
261 return reshape(flat_defs, target, mem);
262}
263
264// called for new lambda arguments, app arguments
265// We can not (directly) replace it with the more general version above due to the mem erasure.
266// TODO: ignore mem erase, replace with more general
267// TODO: capture names
268const Def* Reshape::reshape(const Def* def) {
269 auto flat_defs = flatten_def(def);
270 if (flat_defs.size() == 1) return flat_defs[0];
271 // TODO: move mem removal to flatten_def
272 if (mode_ == Mode::Flat) {
273 const Def* mem = nullptr;
274 // find mem
275 for (auto i = flat_defs.begin(); i != flat_defs.end(); i++)
276 if (is_mem_ty((*i)->type()) && !mem) mem = *i;
277 // filter out mems
278 flat_defs.erase(
279 std::remove_if(flat_defs.begin(), flat_defs.end(), [](const Def* def) { return is_mem_ty(def->type()); }),
280 flat_defs.end());
281 // insert mem
282 if (mem) flat_defs.insert(flat_defs.begin(), mem);
283 return world().tuple(flat_defs);
284 } else {
285 // arg style
286 // [[mem,args],ret]
287 const Def* mem = nullptr;
288 const Def* ret = nullptr;
289 // find mem
290 for (auto i = flat_defs.begin(); i != flat_defs.end(); i++)
291 if (is_mem_ty((*i)->type()) && !mem) mem = *i;
292 // filter out mems
293 flat_defs.erase(
294 std::remove_if(flat_defs.begin(), flat_defs.end(), [](const Def* def) { return is_mem_ty(def->type()); }),
295 flat_defs.end());
296 if (flat_defs.back()->type()->isa<Pi>()) {
297 ret = flat_defs.back();
298 flat_defs.pop_back();
299 }
300 const Def* args = world().tuple(flat_defs);
301 if (mem) args = world().tuple({mem, args});
302 if (ret) args = world().tuple({args, ret});
303 return args;
304 }
305}
306
307} // namespace mim::plug::mem
static bool alpha(const Def *d1, const Def *d2)
Definition check.h:71
Base class for all Defs.
Definition def.h:198
constexpr Node node() const noexcept
Definition def.h:221
std::string_view node_name() const
Definition def.cpp:431
T * isa_mut() const
If this is *mut*able, it will cast constness away and perform a dynamic_cast to T.
Definition def.h:430
const Def * type() const noexcept
Yields the "raw" type of this Def (maybe nullptr).
Definition def.h:242
const Def * rebuild(World &w, const Def *type, Defs ops) const
Def::rebuilds this Def while using new_op as substitute for its i'th Def::op.
Definition def.h:492
constexpr size_t num_ops() const noexcept
Definition def.h:265
World & world()
Definition pass.h:296
Lam * curr_mut() const
Definition pass.h:232
const Def * sigma(Defs ops)
Definition world.cpp:238
const Def * app(const Def *callee, const Def *arg)
Definition world.cpp:181
const Pi * pi(const Def *dom, const Def *codom, bool implicit=false)
Definition world.h:251
const Def * tuple(Defs ops)
Definition world.cpp:246
const Def * var(const Def *type, Def *mut)
Definition world.cpp:167
void enter() override
Falls through to rewrite_def, which in turn hands mutable Lams to reshape_lam.
Definition reshape.cpp:69
The mem Plugin
Definition mem.h:11
Vector< const Def * > DefVec
Definition def.h:50
auto match(const Def *def)
Definition axiom.h:112
uint64_t u64
Definition types.h:34
@ Nat
Definition def.h:85
@ Pi
Definition def.h:85
@ Univ
Definition def.h:85
@ Axiom
Definition def.h:85
@ Lam
Definition def.h:85
@ Arr
Definition def.h:85
@ Global
Definition def.h:85
@ Sigma
Definition def.h:85
@ Type
Definition def.h:85
@ Tuple
Definition def.h:85
@ Lit
Definition def.h:85