MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
normalizers.cpp
Go to the documentation of this file.
1#include <mim/normalize.h>
2
4#include <mim/plug/mem/mem.h>
5
7
8namespace mim::plug::core {
9
10namespace {
11
12// clang-format off
13// See https://stackoverflow.com/a/64354296 for static_assert trick below.
14template<class Id, Id id, nat_t w>
15Res fold(u64 a, u64 b, [[maybe_unused]] bool nsw, [[maybe_unused]] bool nuw) {
16 using ST = w2s<w>;
17 using UT = w2u<w>;
19 auto u = mim::bitcast<UT>(a), v = mim::bitcast<UT>(b);
20
21 if constexpr (std::is_same_v<Id, wrap>) {
22 if constexpr (id == wrap::add) {
23 auto res = u + v;
24 if (nuw && res < u) return {};
25 // TODO nsw
26 return res;
27 } else if constexpr (id == wrap::sub) {
28 auto res = u - v;
29 // TODO nsw
30 return res;
31 } else if constexpr (id == wrap::mul) {
32 if constexpr (std::is_same_v<UT, bool>)
33 return UT(u & v);
34 else
35 return UT(u * v);
36 } else if constexpr (id == wrap::shl) {
37 if (b >= w) return {};
38 decltype(u) res;
39 if constexpr (std::is_same_v<UT, bool>)
40 res = bool(u64(u) << u64(v));
41 else
42 res = u << v;
43 if (nuw && res < u) return {};
44 if (nsw && get_sign(u) != get_sign(res)) return {};
45 return res;
46 } else {
47 []<bool flag = false>() { static_assert(flag, "missing sub tag"); }();
48 }
49 } else if constexpr (std::is_same_v<Id, shr>) {
50 if (b >= w) return {};
51 if constexpr (false) {}
52 else if constexpr (id == shr::a) return s >> t;
53 else if constexpr (id == shr::l) return u >> v;
54 else []<bool flag = false>() { static_assert(flag, "missing sub tag"); }();
55 } else if constexpr (std::is_same_v<Id, div>) {
56 if (b == 0) return {};
57 if constexpr (false) {}
58 else if constexpr (id == div::sdiv) return s / t;
59 else if constexpr (id == div::udiv) return u / v;
60 else if constexpr (id == div::srem) return s % t;
61 else if constexpr (id == div::urem) return u % v;
62 else []<bool flag = false>() { static_assert(flag, "missing sub tag"); }();
63 } else if constexpr (std::is_same_v<Id, icmp>) {
64 bool res = false;
65 auto pm = !(u >> UT(w - 1)) && (v >> UT(w - 1));
66 auto mp = (u >> UT(w - 1)) && !(v >> UT(w - 1));
67 res |= ((id & icmp::Xygle) != icmp::f) && pm;
68 res |= ((id & icmp::xYgle) != icmp::f) && mp;
69 res |= ((id & icmp::xyGle) != icmp::f) && u > v && !mp;
70 res |= ((id & icmp::xygLe) != icmp::f) && u < v && !pm;
71 res |= ((id & icmp::xyglE) != icmp::f) && u == v;
72 return res;
73 } else if constexpr (std::is_same_v<Id, extrema>) {
74 if constexpr (false) {}
75 else if(id == extrema::sm) return std::min(u, v);
76 else if(id == extrema::Sm) return std::min(s, t);
77 else if(id == extrema::sM) return std::max(u, v);
78 else if(id == extrema::SM) return std::max(s, t);
79 } else {
80 []<bool flag = false>() { static_assert(flag, "missing tag"); }();
81 }
82}
83// clang-format on
84
85// Note that @p a and @p b are passed by reference as fold also commutes if possible.
86template<class Id, Id id>
87const Def* fold(World& world, const Def* type, const Def*& a, const Def*& b, const Def* mode = {}) {
88 if (a->isa<Bot>() || b->isa<Bot>()) return world.bot(type);
89
90 if (auto la = Lit::isa(a)) {
91 if (auto lb = Lit::isa(b)) {
92 auto size = Lit::as(Idx::isa(a->type()));
93 auto width = Idx::size2bitwidth(size);
94 bool nsw = false, nuw = false;
95 if constexpr (std::is_same_v<Id, wrap>) {
96 auto m = mode ? Lit::as(mode) : 0_n;
97 nsw = m & Mode::nsw;
98 nuw = m & Mode::nuw;
99 }
100
101 Res res;
102 switch (width) {
103#define CODE(i) \
104 case i: res = fold<Id, id, i>(*la, *lb, nsw, nuw); break;
106#undef CODE
107 default:
108 // TODO this is super rough but at least better than just bailing out
109 res = fold<Id, id, 64>(*la, *lb, false, false);
110 if (res && !std::is_same_v<Id, icmp>) *res %= size;
111 }
112
113 return res ? world.lit(type, *res) : world.bot(type);
114 }
115 }
116
117 if (::mim::is_commutative(id) && Def::greater(a, b)) std::swap(a, b);
118 return nullptr;
119}
120
121template<class Id, nat_t w>
122Res fold(u64 a, [[maybe_unused]] bool nsw, [[maybe_unused]] bool nuw) {
123 using ST = w2s<w>;
124 auto s = mim::bitcast<ST>(a);
125
126 if constexpr (std::is_same_v<Id, abs>)
127 return std::abs(s);
128 else
129 []<bool flag = false>() { static_assert(flag, "missing tag"); }
130 ();
131}
132
133template<class Id>
134const Def* fold(World& world, const Def* type, const Def*& a) {
135 if (a->isa<Bot>()) return world.bot(type);
136
137 if (auto la = Lit::isa(a)) {
138 auto size = Lit::as(Idx::isa(a->type()));
139 auto width = Idx::size2bitwidth(size);
140 bool nsw = false, nuw = false;
141 Res res;
142 switch (width) {
143#define CODE(i) \
144 case i: res = fold<Id, i>(*la, nsw, nuw); break;
146#undef CODE
147 default:
148 res = fold<Id, 64>(*la, false, false);
149 if (res && !std::is_same_v<Id, icmp>) *res %= size;
150 }
151
152 return res ? world.lit(type, *res) : world.bot(type);
153 }
154 return nullptr;
155}
156
157/// Reassociates @p a and @p b according to following rules.
158/// We use the following naming convention while literals are prefixed with an `l`:
159/// ```
160/// a op b
161/// (x op y) op (z op w)
162///
163/// (1) la op (lz op w) -> (la op lz) op w
164/// (2) (lx op y) op (lz op w) -> (lx op lz) op (y op w)
165/// (3) a op (lz op w) -> lz op (a op w)
166/// (4) (lx op y) op b -> lx op (y op b)
167/// ```
168template<class Id>
169const Def* reassociate(Id id, World& world, [[maybe_unused]] const App* ab, const Def* a, const Def* b) {
170 if (!is_associative(id)) return nullptr;
171
172 auto xy = Axm::isa<Id>(id, a);
173 auto zw = Axm::isa<Id>(id, b);
174 auto la = a->isa<Lit>();
175 auto [x, y] = xy ? xy->template args<2>() : std::array<const Def*, 2>{nullptr, nullptr};
176 auto [z, w] = zw ? zw->template args<2>() : std::array<const Def*, 2>{nullptr, nullptr};
177 auto lx = Lit::isa(x);
178 auto lz = Lit::isa(z);
179
180 // if we reassociate, we have to forget about nsw/nuw
181 auto make_op = [&world, id](const Def* a, const Def* b) { return world.call(id, Mode::none, Defs{a, b}); };
182
183 if (la && lz) return make_op(make_op(a, z), w); // (1)
184 if (lx && lz) return make_op(make_op(x, z), make_op(y, w)); // (2)
185 if (lz) return make_op(z, make_op(a, w)); // (3)
186 if (lx) return make_op(x, make_op(y, b)); // (4)
187 return nullptr;
188}
189
190template<class Id>
191const Def* merge_cmps(std::array<std::array<u64, 2>, 2> tab, const Def* a, const Def* b) {
192 static_assert(sizeof(sub_t) == 1, "if this ever changes, please adjust the logic below");
193 static constexpr size_t num_bits = std::bit_width(Annex::Num<Id> - 1_u64);
194
195 auto& world = a->world();
196 auto a_cmp = Axm::isa<Id>(a);
197 auto b_cmp = Axm::isa<Id>(b);
198
199 if (a_cmp && b_cmp && a_cmp->arg() == b_cmp->arg()) {
200 // push sub bits of a_cmp and b_cmp through truth table
201 sub_t res = 0;
202 sub_t a_sub = a_cmp.sub();
203 sub_t b_sub = b_cmp.sub();
204 for (size_t i = 0; i != num_bits; ++i, res >>= 1, a_sub >>= 1, b_sub >>= 1)
205 res |= tab[a_sub & 1][b_sub & 1] << 7_u8;
206 res >>= (7_u8 - u8(num_bits));
207
208 if constexpr (std::is_same_v<Id, math::cmp>)
209 return world.call(math::cmp(res), /*mode*/ a_cmp->decurry()->arg(), a_cmp->arg());
210 else
211 return world.call(icmp(Annex::Base<icmp> | res), a_cmp->arg());
212 }
213
214 return nullptr;
215}
216
217} // namespace
218
219template<nat id>
220const Def* normalize_nat(const Def* type, const Def* callee, const Def* arg) {
221 auto& world = type->world();
222 auto [a, b] = arg->projs<2>();
223 if (is_commutative(id) && Def::greater(a, b)) std::swap(a, b);
224 auto la = Lit::isa(a);
225 auto lb = Lit::isa(b);
226
227 if (la) {
228 if (lb) {
229 switch (id) {
230 case nat::add: return world.lit_nat(*la + *lb);
231 case nat::sub: return *la < *lb ? world.lit_nat_0() : world.lit_nat(*la - *lb);
232 case nat::mul: return world.lit_nat(*la * *lb);
233 }
234 }
235
236 if (*la == 0) {
237 switch (id) {
238 case nat::add: return b;
239 case nat::sub: return a; // 0 - b = 0
240 case nat::mul: return a; // 0 * b = 0
241 }
242 }
243
244 if (*la == 1 && id == nat::mul) return b; // 1 * b = b
245 }
246
247 if (lb && *lb == 0 && id == nat::sub) return a; // a - 0 = a
248
249 if (a == b) {
250 switch (id) {
251 case nat::add: return world.call(nat::mul, Defs{world.lit_nat(2), a}); // a + a = 2 * a
252 case nat::sub: return world.lit_nat(0); // a - a = 0
253 case nat::mul: break;
254 }
255 }
256
257 return world.raw_app(type, callee, {a, b});
258}
259
260template<ncmp id>
261const Def* normalize_ncmp(const Def* type, const Def* callee, const Def* arg) {
262 auto& world = type->world();
263
264 if (id == ncmp::t) return world.lit_tt();
265 if (id == ncmp::f) return world.lit_ff();
266
267 auto [a, b] = arg->projs<2>();
268 if (is_commutative(id) && Def::greater(a, b)) std::swap(a, b);
269
270 if (auto la = Lit::isa(a)) {
271 if (auto lb = Lit::isa(b)) {
272 // clang-format off
273 switch (id) {
274 case ncmp:: e: return world.lit_bool(*la == *lb);
275 case ncmp::ne: return world.lit_bool(*la != *lb);
276 case ncmp::l : return world.lit_bool(*la < *lb);
277 case ncmp::le: return world.lit_bool(*la <= *lb);
278 case ncmp::g : return world.lit_bool(*la > *lb);
279 case ncmp::ge: return world.lit_bool(*la >= *lb);
280 default: fe::unreachable();
281 }
282 // clang-format on
283 }
284 }
285
286 return world.raw_app(type, callee, {a, b});
287}
288
289template<icmp id>
290const Def* normalize_icmp(const Def* type, const Def* c, const Def* arg) {
291 auto& world = type->world();
292 auto callee = c->as<App>();
293 auto [a, b] = arg->projs<2>();
294
295 if (auto result = fold<icmp, id>(world, type, a, b)) return result;
296 if (id == icmp::f) return world.lit_ff();
297 if (id == icmp::t) return world.lit_tt();
298 if (a == b) {
299 if (id & (icmp::e & 0xff)) return world.lit_tt();
300 if (id == icmp::ne) return world.lit_ff();
301 }
302
303 return world.raw_app(type, callee, {a, b});
304}
305
306template<extrema id>
307const Def* normalize_extrema(const Def* type, const Def* c, const Def* arg) {
308 auto& world = type->world();
309 auto callee = c->as<App>();
310 auto [a, b] = arg->projs<2>();
311 if (auto result = fold<extrema, id>(world, type, a, b)) return result;
312 return world.raw_app(type, callee, {a, b});
313}
314
315const Def* normalize_abs(const Def* type, const Def*, const Def* arg) {
316 auto& world = type->world();
317 auto [mem, a] = arg->projs<2>();
318 auto [_, actual_type] = type->projs<2>();
319 auto make_res = [&, mem = mem](const Def* res) { return world.tuple({mem, res}); };
320
321 if (auto result = fold<abs>(world, actual_type, a)) return make_res(result);
322 return {};
323}
324
325template<bit1 id>
326const Def* normalize_bit1(const Def* type, const Def* c, const Def* a) {
327 auto& world = type->world();
328 auto callee = c->as<App>();
329 auto s = callee->decurry()->arg();
330 // TODO cope with wrap around
331
332 if constexpr (id == bit1::id) return a;
333
334 if (auto ls = Lit::isa(s)) {
335 switch (id) {
336 case bit1::f: return world.lit_idx(*ls, 0);
337 case bit1::t: return world.lit_idx(*ls, *ls - 1_u64);
338 case bit1::id: fe::unreachable();
339 default: break;
340 }
341
342 assert(id == bit1::neg);
343 if (auto la = Lit::isa(a)) return world.lit_idx_mod(*ls, ~*la);
344 }
345
346 return {};
347}
348
/// Normalizes the binary bit operation @p id on the pair @p arg. The 16 two-input boolean functions are handled
/// by the switches below.
/// The curried argument of @p c is the Idx size `s`; the all-ones value `s - 1` is only available when `s` is a literal.
/// `tab = make_truth_table(id)` is indexed as `tab[bit of a][bit of b]` (see the uses of `unary` below).
/// Steps:
/// 1. Canonicalize commutative operands.
/// 2. Merge two comparisons on the same operands.
/// 3. Rewrite the trivial ops.
/// 4. Fold two literals.
/// 5. Reduce `a op a` and ops with an all-zeros/all-ones literal operand to a unary op.
/// 6. Finally try reassociation.
template<bit2 id>
const Def* normalize_bit2(const Def* type, const Def* c, const Def* arg) {
    auto& world = type->world();
    auto callee = c->as<App>();
    auto [a, b] = arg->projs<2>();
    auto s = callee->decurry()->arg();
    auto ls = Lit::isa(s);
    // TODO cope with wrap around

    if (is_commutative(id) && Def::greater(a, b)) std::swap(a, b);

    // If a and b are comparisons (icmp or math::cmp) of the same operands, combine them into one comparison.
    auto tab = make_truth_table(id);
    if (auto res = merge_cmps<icmp>(tab, a, b)) return res;
    if (auto res = merge_cmps<math::cmp>(tab, a, b)) return res;

    auto la = Lit::isa(a);
    auto lb = Lit::isa(b);

    // clang-format off
    // Ops that ignore at least one operand, or merely negate one or swap both operands.
    switch (id) {
        case bit2::    f: return world.lit(type, 0);
        case bit2::    t: if (ls) return world.lit(type, *ls-1_u64); break;
        case bit2::  fst: return a;
        case bit2::  snd: return b;
        case bit2:: nfst: return world.call(bit1::neg, s, a);
        case bit2:: nsnd: return world.call(bit1::neg, s, b);
        case bit2:: ciff: return world.call(bit2:: iff, s, Defs{b, a});
        case bit2::nciff: return world.call(bit2::niff, s, Defs{b, a});
        default:          break;
    }

    // Both operands and the size known: evaluate directly.
    // Results whose complement may set bits beyond `s` go through lit_idx_mod (presumably reducing modulo `s`).
    if (la && lb && ls) {
        switch (id) {
            case bit2::and_: return world.lit_idx    (*ls,   *la &  *lb);
            case bit2:: or_: return world.lit_idx    (*ls,   *la |  *lb);
            case bit2::xor_: return world.lit_idx    (*ls,   *la ^  *lb);
            case bit2::nand: return world.lit_idx_mod(*ls, ~(*la &  *lb));
            case bit2:: nor: return world.lit_idx_mod(*ls, ~(*la |  *lb));
            case bit2::nxor: return world.lit_idx_mod(*ls, ~(*la ^  *lb));
            case bit2:: iff: return world.lit_idx_mod(*ls, ~ *la |  *lb);
            case bit2::niff: return world.lit_idx    (*ls,   *la & ~*lb);
            default: fe::unreachable();
        }
    }

    // TODO rewrite using bit2
    // Builds the one-input function with output x for input bit 0 and y for input bit 1, applied to a:
    // constant 0, constant all-ones (only if the size is known), identity, or negation.
    // Returns nullptr if no such rewrite applies.
    // NOTE(review): the negation case is skipped for xor_ — the reason is not visible here.
    auto unary = [&](bool x, bool y, const Def* a) -> const Def* {
        if (!x && !y) return world.lit(type, 0);
        if ( x &&  y) return ls ? world.lit(type, *ls-1_u64) : nullptr;
        if (!x &&  y) return a;
        if ( x && !y && id != bit2::xor_) return world.call(bit1::neg, s, a);
        return nullptr;
    };
    // clang-format on

    // a op a: only the diagonal of the truth table matters.
    if (is_commutative(id) && a == b) {
        if (auto res = unary(tab[0][0], tab[1][1], a)) return res;
    }

    // Left operand is all zeros (row 0) or all ones (row 1): the result is a unary function of b.
    if (la) {
        if (*la == 0) {
            if (auto res = unary(tab[0][0], tab[0][1], b)) return res;
        } else if (ls && *la == *ls - 1_u64) {
            if (auto res = unary(tab[1][0], tab[1][1], b)) return res;
        }
    }

    // Right operand is all zeros (column 0) or all ones (column 1): the result is a unary function of a.
    if (lb) {
        if (*lb == 0) {
            if (auto res = unary(tab[0][0], tab[1][0], a)) return res;
        } else if (ls && *lb == *ls - 1_u64) {
            if (auto res = unary(tab[0][1], tab[1][1], a)) return res;
        }
    }

    if (auto res = reassociate<bit2>(id, world, callee, a, b)) return res;

    return world.raw_app(type, callee, {a, b});
}
428
429const Def* normalize_idx(const Def* type, const Def* c, const Def* arg) {
430 auto& world = type->world();
431 auto callee = c->as<App>();
432 if (auto i = Lit::isa(arg)) {
433 if (auto s = Lit::isa(Idx::isa(type))) {
434 if (*i < *s) return world.lit_idx(*s, *i);
435 if (auto m = Lit::isa(callee->arg())) return *m ? world.bot(type) : world.lit_idx_mod(*s, *i);
436 }
437 }
438
439 return {};
440}
441
442const Def* normalize_idx_unsafe(const Def*, const Def*, const Def* arg) {
443 auto& world = arg->world();
444 if (auto i = Lit::isa(arg)) return world.lit_idx_unsafe(*i);
445 return {};
446}
447
448template<shr id>
449const Def* normalize_shr(const Def* type, const Def* c, const Def* arg) {
450 auto& world = type->world();
451 auto callee = c->as<App>();
452 auto [a, b] = arg->projs<2>();
453 auto s = Idx::isa(arg->type());
454 auto ls = Lit::isa(s);
455
456 if (auto result = fold<shr, id>(world, type, a, b)) return result;
457
458 if (auto la = Lit::isa(a); la && *la == 0) {
459 switch (id) {
460 case shr::a: return a;
461 case shr::l: return a;
462 }
463 }
464
465 if (auto lb = Lit::isa(b)) {
466 if (ls && *lb > *ls) return world.bot(type);
467
468 if (*lb == 0) {
469 switch (id) {
470 case shr::a: return a;
471 case shr::l: return a;
472 }
473 }
474 }
475
476 return world.raw_app(type, callee, {a, b});
477}
478
479template<wrap id>
480const Def* normalize_wrap(const Def* type, const Def* c, const Def* arg) {
481 auto& world = type->world();
482 auto callee = c->as<App>();
483 auto [a, b] = arg->projs<2>();
484 auto mode = callee->arg();
485 auto s = Idx::isa(a->type());
486 auto ls = Lit::isa(s);
487
488 if (auto result = fold<wrap, id>(world, type, a, b)) return result;
489
490 // clang-format off
491 if (auto la = Lit::isa(a)) {
492 if (*la == 0) {
493 switch (id) {
494 case wrap::add: return b; // 0 + b -> b
495 case wrap::sub: break;
496 case wrap::mul: return a; // 0 * b -> 0
497 case wrap::shl: return a; // 0 << b -> 0
498 }
499 } else if (*la == 1) {
500 switch (id) {
501 case wrap::add: break;
502 case wrap::sub: break;
503 case wrap::mul: return b; // 1 * b -> b
504 case wrap::shl: break;
505 }
506 }
507 }
508
509 if (auto lb = Lit::isa(b)) {
510 if (*lb == 0) {
511 switch (id) {
512 case wrap::sub: return a; // a - 0 -> a
513 case wrap::shl: return a; // a >> 0 -> a
514 default: fe::unreachable();
515 // add, mul are commutative, the literal has been normalized to the left
516 }
517 }
518
519 if (auto lm = Lit::isa(mode); lm && ls && *lm == 0 && id == wrap::sub)
520 return world.call(wrap::add, mode, Defs{a, world.lit_idx_mod(*ls, ~*lb + 1_u64)}); // a - lb -> a + (~lb + 1)
521 else if (id == wrap::shl && ls && *lb > *ls)
522 return world.bot(type);
523 }
524
525 if (a == b) {
526 switch (id) {
527 case wrap::add: return world.call(wrap::mul, mode, Defs{world.lit(type, 2), a}); // a + a -> 2 * a
528 case wrap::sub: return world.lit(type, 0); // a - a -> 0
529 case wrap::mul: break;
530 case wrap::shl: break;
531 }
532 }
533 // clang-format on
534
535 if (auto res = reassociate<wrap>(id, world, callee, a, b)) return res;
536
537 return world.raw_app(type, callee, {a, b});
538}
539
540template<div id>
541const Def* normalize_div(const Def* full_type, const Def*, const Def* arg) {
542 auto& world = full_type->world();
543 auto [mem, ab] = arg->projs<2>();
544 auto [a, b] = ab->projs<2>();
545 auto [_, type] = full_type->projs<2>(); // peel off actual type
546 auto make_res = [&, mem = mem](const Def* res) { return world.tuple({mem, res}); };
547
548 if (auto result = fold<div, id>(world, type, a, b)) return make_res(result);
549
550 if (auto la = Lit::isa(a)) {
551 if (*la == 0) return make_res(a); // 0 / b -> 0 and 0 % b -> 0
552 }
553
554 if (auto lb = Lit::isa(b)) {
555 if (*lb == 0) return make_res(world.bot(type)); // a / 0 -> ⊥ and a % 0 -> ⊥
556
557 if (*lb == 1) {
558 switch (id) {
559 case div::sdiv: return make_res(a); // a / 1 -> a
560 case div::udiv: return make_res(a); // a / 1 -> a
561 case div::srem: return make_res(world.lit(type, 0)); // a % 1 -> 0
562 case div::urem: return make_res(world.lit(type, 0)); // a % 1 -> 0
563 }
564 }
565 }
566
567 if (a == b) {
568 switch (id) {
569 case div::sdiv: return make_res(world.lit(type, 1)); // a / a -> 1
570 case div::udiv: return make_res(world.lit(type, 1)); // a / a -> 1
571 case div::srem: return make_res(world.lit(type, 0)); // a % a -> 0
572 case div::urem: return make_res(world.lit(type, 0)); // a % a -> 0
573 }
574 }
575
576 return {};
577}
578
579template<conv id>
580const Def* normalize_conv(const Def* dst_t, const Def*, const Def* x) {
581 auto& world = dst_t->world();
582 auto s_t = x->type()->as<App>();
583 auto d_t = dst_t->as<App>();
584 auto s = s_t->arg();
585 auto d = d_t->arg();
586 auto ls = Lit::isa(s);
587 auto ld = Lit::isa(d);
588
589 if (s_t == d_t) return x;
590 if (x->isa<Bot>()) return world.bot(d_t);
591 if constexpr (id == conv::s) {
592 if (ls && ld && *ld < *ls) return world.call(conv::u, d, x); // just truncate - we don't care for signedness
593 }
594
595 if (auto l = Lit::isa(x); l && ls && ld) {
596 if constexpr (id == conv::u) {
597 if (*ld == 0) return world.lit(d_t, *l); // I64
598 return world.lit(d_t, *l % *ld);
599 }
600
601 auto sw = Idx::size2bitwidth(*ls);
602 auto dw = Idx::size2bitwidth(*ld);
603
604 // clang-format off
605 if (false) {}
606#define M(S, D) \
607 else if (S == sw && D == dw) return world.lit(d_t, w2s<D>(mim::bitcast<w2s<S>>(*l)));
608 M( 1, 8) M( 1, 16) M( 1, 32) M( 1, 64)
609 M( 8, 16) M( 8, 32) M( 8, 64)
610 M(16, 32) M(16, 64)
611 M(32, 64)
612 else assert(false && "TODO: conversion between different Idx sizes");
613 // clang-format on
614 }
615
616 return {};
617}
618
619const Def* normalize_bitcast(const Def* dst_t, const Def*, const Def* src) {
620 auto& world = dst_t->world();
621 auto src_t = src->type();
622
623 if (src->isa<Bot>()) return world.bot(dst_t);
624 if (src_t == dst_t) return src;
625
626 if (auto other = Axm::isa<bitcast>(src))
627 return other->arg()->type() == dst_t ? other->arg() : world.call<bitcast>(dst_t, other->arg());
628
629 if (auto l = Lit::isa(src)) {
630 if (dst_t->isa<Nat>()) return world.lit(dst_t, *l);
631 if (Idx::isa(dst_t)) return world.lit(dst_t, *l);
632 }
633
634 return {};
635}
636
637// TODO this currently hard-codes x86_64 ABI
638// TODO in contrast to C, we might want to give singleton types like 'Idx 1' or '[]' a size of 0 and simply nuke each
639// and every occurance of these types in a later phase
640// TODO Pi and others
641template<trait id>
642const Def* normalize_trait(const Def*, const Def*, const Def* type) {
643 auto& world = type->world();
644 if (auto ptr = Axm::isa<mem::Ptr>(type)) {
645 return world.lit_nat(8);
646 } else if (type->isa<Pi>()) {
647 return world.lit_nat(8); // Gets lowered to function ptr
648 } else if (auto size = Idx::isa(type)) {
649 if (auto w = Idx::size2bitwidth(size)) return world.lit_nat(std::max(1_n, std::bit_ceil(*w) / 8_n));
650 } else if (auto w = math::isa_f(type)) {
651 switch (*w) {
652 case 16: return world.lit_nat(2);
653 case 32: return world.lit_nat(4);
654 case 64: return world.lit_nat(8);
655 default: fe::unreachable();
656 }
657 } else if (type->isa<Sigma>() || type->isa<Meet>()) {
658 u64 offset = 0;
659 u64 align = 1;
660 for (auto t : type->ops()) {
661 auto a = Lit::isa(core::op(trait::align, t));
662 auto s = Lit::isa(core::op(trait::size, t));
663 if (!a || !s) return {};
664
665 align = std::max(align, *a);
666 offset = pad(offset, *a) + *s;
667 }
668
669 offset = pad(offset, align);
670 u64 size = std::max(1_u64, offset);
671
672 switch (id) {
673 case trait::align: return world.lit_nat(align);
674 case trait::size: return world.lit_nat(size);
675 }
676 } else if (auto arr = type->isa_imm<Arr>()) {
677 auto align = op(trait::align, arr->body());
678 if constexpr (id == trait::align) return align;
679 auto b = op(trait::size, arr->body());
680 if (b->isa<Lit>()) return world.call(nat::mul, Defs{arr->arity(), b});
681 } else if (auto join = type->isa<Join>()) {
682 if (auto sigma = convert(join)) return core::op(id, sigma);
683 }
684
685 return {};
686}
687
688template<pe id>
689const Def* normalize_pe(const Def* type, const Def*, const Def* arg) {
690 auto& world = type->world();
691
692 if constexpr (id == pe::is_closed) {
693 if (Axm::isa(pe::hlt, arg)) return world.lit_ff();
694 if (arg->is_closed()) return world.lit_tt();
695 }
696
697 return {};
698}
699
701
702} // namespace mim::plug::core
const Def * arg() const
Definition lam.h:233
A (possibly parameterized) Array.
Definition tuple.h:120
static auto isa(const Def *def)
Definition axm.h:107
Base class for all Defs.
Definition def.h:216
World & world() const noexcept
Definition def.cpp:413
auto projs(F f) const
Splits this Def via Def::projections into an Array (if A == std::dynamic_extent) or std::array (other...
Definition def.h:355
const Def * type() const noexcept
Yields the "raw" type of this Def (maybe nullptr).
Definition def.h:260
static bool greater(const Def *a, const Def *b)
Definition def.cpp:520
bool is_closed() const
Has no free_vars()?
Definition def.cpp:392
static constexpr nat_t size2bitwidth(nat_t n)
Definition def.h:840
static const Def * isa(const Def *def)
Checks if def is a Idx s and returns s or nullptr otherwise.
Definition def.cpp:576
static std::optional< T > isa(const Def *def)
Definition def.h:773
static T as(const Def *def)
Definition def.h:779
A dependent function type.
Definition lam.h:11
A dependent tuple type.
Definition tuple.h:20
const Lit * lit_idx_unsafe(u64 val)
Definition world.h:422
#define MIM_core_NORMALIZER_IMPL
Definition autogen.h:302
#define M(S, D)
The core Plugin
Definition core.h:8
const Def * normalize_nat(const Def *type, const Def *callee, const Def *arg)
const Def * normalize_idx_unsafe(const Def *, const Def *, const Def *arg)
const Sigma * convert(const TBound< up > *b)
Definition core.cpp:19
const Def * normalize_div(const Def *full_type, const Def *, const Def *arg)
const Def * normalize_pe(const Def *type, const Def *, const Def *arg)
const Def * normalize_extrema(const Def *type, const Def *c, const Def *arg)
const Def * normalize_icmp(const Def *type, const Def *c, const Def *arg)
const Def * normalize_bit1(const Def *type, const Def *c, const Def *a)
const Def * normalize_conv(const Def *dst_t, const Def *, const Def *x)
const Def * normalize_bit2(const Def *type, const Def *c, const Def *arg)
const Def * normalize_wrap(const Def *type, const Def *c, const Def *arg)
const Def * normalize_trait(const Def *, const Def *, const Def *type)
const Def * op(trait o, const Def *type)
Definition core.h:33
const Def * normalize_abs(const Def *type, const Def *, const Def *arg)
const Def * normalize_idx(const Def *type, const Def *c, const Def *arg)
constexpr std::array< std::array< u64, 2 >, 2 > make_truth_table(bit2 id)
Definition core.h:50
const Def * normalize_bitcast(const Def *dst_t, const Def *, const Def *src)
const Def * normalize_ncmp(const Def *type, const Def *callee, const Def *arg)
@ nuw
No Unsigned Wrap around.
Definition core.h:16
@ none
Wrap around.
Definition core.h:14
@ nsw
No Signed Wrap around.
Definition core.h:15
const Def * normalize_shr(const Def *type, const Def *c, const Def *arg)
std::optional< nat_t > isa_f(const Def *def)
Definition math.h:76
The mem Plugin
Definition mem.h:11
View< const Def * > Defs
Definition def.h:51
u8 sub_t
Definition types.h:48
D bitcast(const S &src)
A bitcast from src of type S to D.
Definition util.h:23
TBound< true > Join
AKA union.
Definition lattice.h:174
u64 pad(u64 offset, u64 align)
Definition util.h:47
constexpr bool is_commutative(Id)
Definition axm.h:152
typename detail::w2s_< w >::type w2s
Definition types.h:73
constexpr bool is_associative(Id id)
Definition axm.h:158
typename detail::w2u_< w >::type w2u
Definition types.h:72
TExt< false > Bot
Definition lattice.h:171
uint64_t u64
Definition types.h:34
bool get_sign(T val)
Definition util.h:38
uint8_t u8
Definition types.h:34
TBound< false > Meet
AKA intersection.
Definition lattice.h:173
CODE(node, _)
Definition def.h:88
@ App
Definition def.h:89
@ Lit
Definition def.h:89
static constexpr size_t Num
Definition plugin.h:115
static constexpr flags_t Base
Definition plugin.h:118
#define MIM_1_8_16_32_64(m)
Definition types.h:24