MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
reshape.cpp
Go to the documentation of this file.
2
3#include <mim/check.h>
4#include <mim/def.h>
5#include <mim/lam.h>
6#include <mim/tuple.h>
7
8#include "mim/plug/mem/mem.h"
9
10namespace mim::plug::mem::pass {
11
12namespace {
13
14bool is_mem_ty(const Def* T) { return Axm::isa<mem::M>(T); }
15
16// TODO merge with should_flatten from tuple.*
17bool should_flatten(const Def* T) {
18 // handle [] cases
19 if (T->isa<Sigma>()) return true;
20 // also handle normalized tuple-arrays ((a:I32,b:I32) : <<2;I32>>)
21 // TODO: handle better than with magic number
22 // (do we want to flatten any array with more than 2 elements?)
23 // (2 elements are needed for conditionals)
24 // TODO: e.g. lea explicitely does not want to flatten
25
26 // TODO: annotate with test cases that need these special cases
27 // Problem with 2 Arr -> flatten
28 // lea (2, <<2;I32>>, ...) -> lea (2, I32, I32, ...)
29 if (auto lit = T->arity()->isa<Lit>(); lit && lit->get<u64>() <= 2) {
30 if (auto arr = T->isa<Arr>(); arr && arr->body()->isa<Pi>()) return lit->get<u64>() > 1;
31 }
32 return false;
33}
34
35// TODO merge with tuple.*
36DefVec flatten_ty(const Def* T) {
37 DefVec types;
38 if (should_flatten(T)) {
39 for (auto P : T->projs()) {
40 auto inner_types = flatten_ty(P);
41 types.insert(types.end(), inner_types.begin(), inner_types.end());
42 }
43 } else {
44 types.push_back(T);
45 }
46 return types;
47}
48
49// TODO try to remove code duplication with flatten_ty
50DefVec flatten_def(const Def* def) {
51 DefVec defs;
52 if (should_flatten(def->type())) {
53 for (auto P : def->projs()) {
54 auto inner_defs = flatten_def(P);
55 defs.insert(defs.end(), inner_defs.begin(), inner_defs.end());
56 }
57 } else {
58 defs.push_back(def);
59 }
60 return defs;
61}
62
63} // namespace
64
66 mode_ = mode;
67 name_ += mode == Flat ? " Flat" : " Arg";
68}
69
70void Reshape::apply(const App* app) {
71 auto axm = app->arg()->as<Axm>();
72 apply(axm->flags() == Annex::base<mem::reshape_arg>() ? Arg : Flat);
73}
74
75void Reshape::enter() { rewrite_def(curr_mut()); }
76
77const Def* Reshape::rewrite_def(const Def* def) {
78 if (auto i = old2new_.find(def); i != old2new_.end()) return i->second;
79 auto new_def = rewrite_def_(def);
80 old2new_[def] = new_def;
81 return new_def;
82}
83
84const Def* Reshape::rewrite_def_(const Def* def) {
85 // We ignore types, Globals, and Axms.
86 switch (def->node()) {
87 // TODO: check if bot: Cn[[A,B],Cn[Ret]] is handled correctly
88 // case Node::Bot:
89 // case Node::Top:
90 case Node::Type:
91 case Node::Univ:
92 case Node::Nat:
93 case Node::Axm:
94 case Node::Global: return def;
95 default: break;
96 }
97
98 // This is dead code for debugging purposes.
99 // It allows for inspection of the current def.
100 std::stringstream ss;
101 ss << def << " : " << def->type() << " [" << def->node_name() << "]";
102 std::string str = ss.str();
103
104 // vars are handled by association.
105 if (def->isa<Var>()) ELOG("Var: {}", def);
106 assert(!def->isa<Var>());
107
108 if (auto app = def->isa<App>()) {
109 auto callee = rewrite_def(app->callee());
110 auto arg = rewrite_def(app->arg());
111
112 DLOG("callee: {} : {}", callee, callee->type());
113
114 // Reshape normally (not to callee) to ensure that callee is reshaped correctly.
115 auto reshaped_arg = reshape(arg);
116 DLOG("reshape arg {} : {}", arg, arg->type());
117 DLOG("into arg {} : {}", reshaped_arg, reshaped_arg->type());
118 auto new_app = world().app(callee, reshaped_arg);
119 return new_app;
120 } else if (auto lam = def->isa_mut<Lam>()) {
121 DLOG("rewrite_def lam {} : {}", def, def->type());
122 auto new_lam = reshape_lam(lam);
123 DLOG("rewrote lam {} : {}", def, def->type());
124 DLOG("into lam {} : {}", new_lam, new_lam->type());
125 return new_lam;
126 } else if (auto tuple = def->isa<Tuple>()) {
127 auto elements = DefVec(tuple->ops(), [&](const Def* op) { return rewrite_def(op); });
128 return world().tuple(elements);
129 } else {
130 auto new_ops = DefVec(def->num_ops(), [&](auto i) { return rewrite_def(def->op(i)); });
131 // Warning: if the new_type is not correct, inconcistencies will arise.
132 auto new_type = rewrite_def(def->type());
133 auto new_def = def->rebuild(new_type, new_ops);
134 return new_def;
135 }
136}
137
138Lam* Reshape::reshape_lam(Lam* old_lam) {
139 if (!old_lam->is_set()) {
140 DLOG("reshape_lam: {} is not a set", old_lam);
141 return old_lam;
142 }
143 auto pi_ty = old_lam->type();
144 auto new_ty = reshape_type(pi_ty)->as<Pi>();
145
146 Lam* new_lam;
147 if (*old_lam->sym() == "main") {
148 new_lam = old_lam;
149 } else {
150 new_lam = old_lam->stub(new_ty);
151 if (!old_lam->is_external()) new_lam->debug_suffix("_reshape");
152 old2new_[old_lam] = new_lam;
153 }
154
155 DLOG("Reshape lam: {} : {}", old_lam, pi_ty);
156 DLOG(" to: {} : {}", new_lam, new_ty);
157
158 // We associate the arguments (reshape the old vars).
159 // Alternatively, we could use beta reduction (reduce) to do this for us.
160 auto new_arg = new_lam->var();
161
162 // We deeply associate `old_lam->var()` with `new_arg` in a reconstructed shape.
163 // Idea: first make new_arg into "atomic" old_lam list, then recrusively imitate `old_lam->var`.
164 auto reformed_new_arg = reshape(new_arg, old_lam->var()->type()); // `old_lam->var()->type() = pi_ty`
165 DLOG("var {} : {}", old_lam->var(), old_lam->var()->type());
166 DLOG("new var {} : {}", new_arg, new_arg->type());
167 DLOG("reshaped new_var {} : {}", reformed_new_arg, reformed_new_arg->type());
168 DLOG("{}", old_lam->var()->type());
169 DLOG("{}", reformed_new_arg->type());
170 old2new_[old_lam->var()] = reformed_new_arg;
171 // TODO: add if necessary. This probably was an issue with unintended overriding due to bad previous naming.
172 // TODO: Remove after testing.
173 // old2new_[new_arg] = new_arg;
174
175 auto new_body = rewrite_def(old_lam->body());
176 auto new_filter = rewrite_def(old_lam->filter());
177 new_lam->unset();
178 new_lam->set(new_filter, new_body);
179
180 if (old_lam->is_external()) old_lam->transfer_external(new_lam);
181
182 DLOG("finished transforming: {} : {}", new_lam, new_ty);
183 return new_lam;
184}
185
186const Def* Reshape::reshape_type(const Def* T) {
187 if (auto pi = T->isa<Pi>()) {
188 auto new_dom = reshape_type(pi->dom());
189 auto new_cod = reshape_type(pi->codom());
190 return world().pi(new_dom, new_cod);
191 } else if (auto sigma = T->isa<Sigma>()) {
192 auto flat_types = flatten_ty(sigma);
193 auto new_types = DefVec(flat_types.size());
194 std::ranges::transform(flat_types, new_types.begin(), [&](auto T) { return reshape_type(T); });
195 if (mode_ == Mode::Flat) {
196 const Def* mem = nullptr;
197 // find mem
198 for (auto i = new_types.begin(); i != new_types.end(); i++)
199 if (is_mem_ty(*i) && !mem) mem = *i;
200 // filter out mems
201 new_types.erase(std::remove_if(new_types.begin(), new_types.end(), is_mem_ty), new_types.end());
202 // readd mem in the front
203 if (mem) new_types.insert(new_types.begin(), mem);
204 auto reshaped_type = world().sigma(new_types);
205 return reshaped_type;
206 } else {
207 if (new_types.size() == 0) return world().sigma();
208 if (new_types.size() == 1) return new_types[0];
209 const Def* mem = nullptr;
210 const Def* ret = nullptr;
211 // find mem
212 for (auto i = new_types.begin(); i != new_types.end(); i++)
213 if (is_mem_ty(*i) && !mem) mem = *i;
214 // filter out mems
215 new_types.erase(std::remove_if(new_types.begin(), new_types.end(), is_mem_ty), new_types.end());
216 // TODO: more fine-grained test
217 if (new_types.back()->isa<Pi>()) {
218 ret = new_types.back();
219 new_types.pop_back();
220 }
221 // Create the arg form `[[mem,args],ret]`
222 const Def* args = world().sigma(new_types);
223 if (mem) args = world().sigma({mem, args});
224 if (ret) args = world().sigma({args, ret});
225 return args;
226 }
227 } else {
228 return T;
229 }
230}
231
232const Def* Reshape::reshape(DefVec& defs, const Def* T, const Def* mem) {
233 auto& world = T->world();
234 if (should_flatten(T)) {
235 auto tuples = T->projs([&](auto P) { return reshape(defs, P, mem); });
236 return world.tuple(tuples);
237 } else {
238 const Def* def;
239 if (is_mem_ty(T)) {
240 assert(mem != nullptr && "Reshape: mems not found");
241 def = mem;
242 } else {
243 do {
244 assert(defs.size() > 0 && "Reshape: not enough arguments");
245 def = defs.front();
246 defs.erase(defs.begin());
247 } while (is_mem_ty(def->type()));
248 }
249 // For inner function types, we override the type
250 if (!def->type()->isa<Pi>()) {
251 if (!Checker::alpha<Checker::Check>(def->type(), T))
252 world.ELOG("reconstruct T {} from def {}", T, def->type());
253 assert(Checker::alpha<Checker::Check>(def->type(), T) && "Reshape: argument type mismatch");
254 }
255 return def;
256 }
257}
258
259const Def* Reshape::reshape(const Def* def, const Def* target) {
260 DLOG("reshape:\n {} =>\n {}", def->type(), target);
261 auto flat_defs = flatten_def(def);
262 const Def* mem = nullptr;
263 // find mem
264 for (auto i = flat_defs.begin(); i != flat_defs.end(); i++)
265 if (is_mem_ty((*i)->type()) && !mem) mem = *i;
266 DLOG("mem: {}", mem);
267 return reshape(flat_defs, target, mem);
268}
269
270// called for new lambda arguments, app arguments
271// We can not (directly) replace it with the more general version above due to the mem erasure.
272// TODO: ignore mem erase, replace with more general
273// TODO: capture names
274const Def* Reshape::reshape(const Def* def) {
275 auto flat_defs = flatten_def(def);
276 if (flat_defs.size() == 1) return flat_defs[0];
277 // TODO: move mem removal to flatten_def
278 if (mode_ == Mode::Flat) {
279 const Def* mem = nullptr;
280 // find mem
281 for (auto i = flat_defs.begin(); i != flat_defs.end(); i++)
282 if (is_mem_ty((*i)->type()) && !mem) mem = *i;
283 // filter out mems
284 flat_defs.erase(
285 std::remove_if(flat_defs.begin(), flat_defs.end(), [](const Def* def) { return is_mem_ty(def->type()); }),
286 flat_defs.end());
287 // insert mem
288 if (mem) flat_defs.insert(flat_defs.begin(), mem);
289 return world().tuple(flat_defs);
290 } else {
291 // arg style
292 // [[mem,args],ret]
293 const Def* mem = nullptr;
294 const Def* ret = nullptr;
295 // find mem
296 for (auto i = flat_defs.begin(); i != flat_defs.end(); i++)
297 if (is_mem_ty((*i)->type()) && !mem) mem = *i;
298 // filter out mems
299 flat_defs.erase(
300 std::remove_if(flat_defs.begin(), flat_defs.end(), [](const Def* def) { return is_mem_ty(def->type()); }),
301 flat_defs.end());
302 if (flat_defs.back()->type()->isa<Pi>()) {
303 ret = flat_defs.back();
304 flat_defs.pop_back();
305 }
306 const Def* args = world().tuple(flat_defs);
307 if (mem) args = world().tuple({mem, args});
308 if (ret) args = world().tuple({args, ret});
309 return args;
310 }
311}
312
313} // namespace mim::plug::mem::pass
const Def * arg() const
Definition lam.h:286
Definition axm.h:9
static auto isa(const Def *def)
Definition axm.h:107
static bool alpha(const Def *d1, const Def *d2)
Definition check.h:96
Base class for all Defs.
Definition def.h:251
constexpr Node node() const noexcept
Definition def.h:274
std::string_view node_name() const
Definition def.cpp:454
T * isa_mut() const
If this is mutable, it will cast constness away and perform a dynamic_cast to T.
Definition def.h:482
const Def * type() const noexcept
Yields the "raw" type of this Def (maybe nullptr).
Definition def.h:295
const Def * rebuild(World &w, const Def *type, Defs ops) const
Def::rebuilds this Def while using new_op as substitute for its i'th Def::op.
Definition def.h:545
constexpr size_t num_ops() const noexcept
Definition def.h:309
A function.
Definition lam.h:111
Lam * curr_mut() const
Definition pass.h:307
World & world()
Definition pass.h:64
std::string name_
Definition pass.h:76
const Def * sigma(Defs ops)
Definition world.cpp:278
const Def * app(const Def *callee, const Def *arg)
Definition world.cpp:196
const Pi * pi(const Def *dom, const Def *codom, bool implicit=false)
Definition world.h:268
const Def * tuple(Defs ops)
Definition world.cpp:288
void enter() override
Fall-through to rewrite_def which falls through to rewrite_lam.
Definition reshape.cpp:75
#define ELOG(...)
Definition log.h:89
#define DLOG(...)
Vaporizes to nothingness in Debug build.
Definition log.h:95
Vector< const Def * > DefVec
Definition def.h:77
uint64_t u64
Definition types.h:34
@ Nat
Definition def.h:114
@ Pi
Definition def.h:114
@ Univ
Definition def.h:114
@ Lam
Definition def.h:114
@ Arr
Definition def.h:114
@ Global
Definition def.h:114
@ Axm
Definition def.h:114
@ Sigma
Definition def.h:114
@ Type
Definition def.h:114
@ Tuple
Definition def.h:114
@ Lit
Definition def.h:114
static consteval flags_t base()
Definition plugin.h:119