MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
normalizers.cpp
Go to the documentation of this file.
1#include <mim/normalize.h>
2
4#include <mim/plug/mem/mem.h>
5
7
8namespace mim::plug::core {
9
10namespace {
11
12// clang-format off
13// See https://stackoverflow.com/a/64354296 for static_assert trick below.
14template<class Id, Id id, nat_t w>
15Res fold(u64 a, u64 b, [[maybe_unused]] bool nsw, [[maybe_unused]] bool nuw) {
16 using ST = w2s<w>;
17 using UT = w2u<w>;
19 auto u = mim::bitcast<UT>(a), v = mim::bitcast<UT>(b);
20
21 if constexpr (std::is_same_v<Id, wrap>) {
22 if constexpr (id == wrap::add) {
23 auto res = u + v;
24 if (nuw && res < u) return {};
25 // TODO nsw
26 return res;
27 } else if constexpr (id == wrap::sub) {
28 auto res = u - v;
29 // TODO nsw
30 return res;
31 } else if constexpr (id == wrap::mul) {
32 if constexpr (std::is_same_v<UT, bool>)
33 return UT(u & v);
34 else
35 return UT(u * v);
36 } else if constexpr (id == wrap::shl) {
37 if (b >= w) return {};
38 decltype(u) res;
39 if constexpr (std::is_same_v<UT, bool>)
40 res = bool(u64(u) << u64(v));
41 else
42 res = u << v;
43 if (nuw && res < u) return {};
44 if (nsw && get_sign(u) != get_sign(res)) return {};
45 return res;
46 } else {
47 []<bool flag = false>() { static_assert(flag, "missing sub tag"); }();
48 }
49 } else if constexpr (std::is_same_v<Id, shr>) {
50 if (b >= w) return {};
51 if constexpr (false) {}
52 else if constexpr (id == shr::a) return s >> t;
53 else if constexpr (id == shr::l) return u >> v;
54 else []<bool flag = false>() { static_assert(flag, "missing sub tag"); }();
55 } else if constexpr (std::is_same_v<Id, div>) {
56 if (b == 0) return {};
57 if constexpr (false) {}
58 else if constexpr (id == div::sdiv) return s / t;
59 else if constexpr (id == div::udiv) return u / v;
60 else if constexpr (id == div::srem) return s % t;
61 else if constexpr (id == div::urem) return u % v;
62 else []<bool flag = false>() { static_assert(flag, "missing sub tag"); }();
63 } else if constexpr (std::is_same_v<Id, icmp>) {
64 bool res = false;
65 auto pm = !(u >> UT(w - 1)) && (v >> UT(w - 1));
66 auto mp = (u >> UT(w - 1)) && !(v >> UT(w - 1));
67 res |= ((id & icmp::Xygle) != icmp::f) && pm;
68 res |= ((id & icmp::xYgle) != icmp::f) && mp;
69 res |= ((id & icmp::xyGle) != icmp::f) && u > v && !mp;
70 res |= ((id & icmp::xygLe) != icmp::f) && u < v && !pm;
71 res |= ((id & icmp::xyglE) != icmp::f) && u == v;
72 return res;
73 } else if constexpr (std::is_same_v<Id, extrema>) {
74 if constexpr (false) {}
75 else if(id == extrema::sm) return std::min(u, v);
76 else if(id == extrema::Sm) return std::min(s, t);
77 else if(id == extrema::sM) return std::max(u, v);
78 else if(id == extrema::SM) return std::max(s, t);
79 } else {
80 []<bool flag = false>() { static_assert(flag, "missing tag"); }();
81 }
82}
83// clang-format on
84
85// Note that @p a and @p b are passed by reference as fold also commutes if possible.
86template<class Id, Id id>
87const Def* fold(World& world, const Def* type, const Def*& a, const Def*& b, const Def* mode = {}) {
88 if (a->isa<Bot>() || b->isa<Bot>()) return world.bot(type);
89
90 if (auto la = Lit::isa(a)) {
91 if (auto lb = Lit::isa(b)) {
92 auto size = Lit::as(Idx::isa(a->type()));
93 auto width = Idx::size2bitwidth(size);
94 bool nsw = false, nuw = false;
95 if constexpr (std::is_same_v<Id, wrap>) {
96 auto m = mode ? Lit::as(mode) : 0_n;
97 nsw = m & Mode::nsw;
98 nuw = m & Mode::nuw;
99 }
100
101 Res res;
102 switch (width) {
103#define CODE(i) \
104 case i: res = fold<Id, id, i>(*la, *lb, nsw, nuw); break;
106#undef CODE
107 default:
108 // TODO this is super rough but at least better than just bailing out
109 res = fold<Id, id, 64>(*la, *lb, false, false);
110 if (res && !std::is_same_v<Id, icmp>) *res %= size;
111 }
112
113 return res ? world.lit(type, *res) : world.bot(type);
114 }
115 }
116
117 if (::mim::is_commutative(id) && Def::greater(a, b)) std::swap(a, b);
118 return nullptr;
119}
120
121template<class Id, nat_t w>
122Res fold(u64 a, [[maybe_unused]] bool nsw, [[maybe_unused]] bool nuw) {
123 using ST = w2s<w>;
124 auto s = mim::bitcast<ST>(a);
125
126 if constexpr (std::is_same_v<Id, abs>)
127 return std::abs(s);
128 else
129 []<bool flag = false>() { static_assert(flag, "missing tag"); }();
130}
131
132template<class Id>
133const Def* fold(World& world, const Def* type, const Def*& a) {
134 if (a->isa<Bot>()) return world.bot(type);
135
136 if (auto la = Lit::isa(a)) {
137 auto size = Lit::as(Idx::isa(a->type()));
138 auto width = Idx::size2bitwidth(size);
139 bool nsw = false, nuw = false;
140 Res res;
141 switch (width) {
142#define CODE(i) \
143 case i: res = fold<Id, i>(*la, nsw, nuw); break;
145#undef CODE
146 default:
147 res = fold<Id, 64>(*la, false, false);
148 if (res && !std::is_same_v<Id, icmp>) *res %= size;
149 }
150
151 return res ? world.lit(type, *res) : world.bot(type);
152 }
153 return nullptr;
154}
155
156/// Reassociates @p a and @p b according to following rules.
157/// We use the following naming convention while literals are prefixed with an `l`:
158/// ```
159/// a op b
160/// (x op y) op (z op w)
161///
162/// (1) la op (lz op w) -> (la op lz) op w
163/// (2) (lx op y) op (lz op w) -> (lx op lz) op (y op w)
164/// (3) a op (lz op w) -> lz op (a op w)
165/// (4) (lx op y) op b -> lx op (y op b)
166/// ```
167template<class Id>
168const Def* reassociate(Id id, World& world, [[maybe_unused]] const App* ab, const Def* a, const Def* b) {
169 if (!is_associative(id)) return nullptr;
170
171 auto xy = Axm::isa<Id>(id, a);
172 auto zw = Axm::isa<Id>(id, b);
173 auto la = a->isa<Lit>();
174 auto [x, y] = xy ? xy->template args<2>() : std::array<const Def*, 2>{nullptr, nullptr};
175 auto [z, w] = zw ? zw->template args<2>() : std::array<const Def*, 2>{nullptr, nullptr};
176 auto lx = Lit::isa(x);
177 auto lz = Lit::isa(z);
178
179 // if we reassociate, we have to forget about nsw/nuw
180 auto make_op = [&world, id](const Def* a, const Def* b) { return world.call(id, Mode::none, Defs{a, b}); };
181
182 if (la && lz) return make_op(make_op(a, z), w); // (1)
183 if (lx && lz) return make_op(make_op(x, z), make_op(y, w)); // (2)
184 if (lz) return make_op(z, make_op(a, w)); // (3)
185 if (lx) return make_op(x, make_op(y, b)); // (4)
186 return nullptr;
187}
188
/// Merges the two comparisons @p a and @p b of family @p Id into a single comparison of family @p Id,
/// provided both are @p Id comparisons over the *same* argument.
/// Each sub-tag bit of the two comparisons is combined independently via the truth table @p tab
/// (indexed as `tab[bit of a][bit of b]`), e.g. OR-ing the outcome sets of two comparisons.
/// @returns the merged comparison, or `nullptr` if @p a and @p b are not such a pair.
189template<class Id>
190const Def* merge_cmps(std::array<std::array<u64, 2>, 2> tab, const Def* a, const Def* b) {
191 static_assert(sizeof(sub_t) == 1, "if this ever changes, please adjust the logic below");
// Number of sub-tag bits of @p Id. The bit juggling below places bits at position 7 of a u8,
// so it presumably only works for num_bits <= 7 — NOTE(review): confirm for new comparison families.
192 static constexpr size_t num_bits = std::bit_width(Annex::num<Id>() - 1_u64);
193
194 auto& world = a->world();
195 auto a_cmp = Axm::isa<Id>(a);
196 auto b_cmp = Axm::isa<Id>(b);
197
198 if (a_cmp && b_cmp && a_cmp->arg() == b_cmp->arg()) {
199 // push sub bits of a_cmp and b_cmp through truth table
200 sub_t res = 0;
201 sub_t a_sub = a_cmp.sub();
202 sub_t b_sub = b_cmp.sub();
// Each iteration deposits the combined bit i at position 7 and then shifts res right once.
// Bit i therefore ends up at position 7 - num_bits + i; the final shift by (7 - num_bits) moves it to position i.
203 for (size_t i = 0; i != num_bits; ++i, res >>= 1, a_sub >>= 1, b_sub >>= 1)
204 res |= tab[a_sub & 1][b_sub & 1] << 7_u8;
205 res >>= (7_u8 - u8(num_bits));
206
// math::cmp is curried with a mode; the merged comparison reuses the mode of a_cmp.
207 if constexpr (std::is_same_v<Id, math::cmp>)
208 return world.call(math::cmp(res), /*mode*/ a_cmp->decurry()->arg(), a_cmp->arg());
209 else
210 return world.call(icmp(Annex::base<icmp>() | res), a_cmp->arg());
211 }
212
213 return nullptr;
214}
215
216} // namespace
217
218template<nat id>
219const Def* normalize_nat(const Def* type, const Def* callee, const Def* arg) {
220 auto& world = type->world();
221 auto [a, b] = arg->projs<2>();
222 if (is_commutative(id) && Def::greater(a, b)) std::swap(a, b);
223 auto la = Lit::isa(a);
224 auto lb = Lit::isa(b);
225
226 if (la) {
227 if (lb) {
228 switch (id) {
229 case nat::add: return world.lit_nat(*la + *lb);
230 case nat::sub: return *la < *lb ? world.lit_nat_0() : world.lit_nat(*la - *lb);
231 case nat::mul: return world.lit_nat(*la * *lb);
232 }
233 }
234
235 if (*la == 0) {
236 switch (id) {
237 case nat::add: return b;
238 case nat::sub: return a; // 0 - b = 0
239 case nat::mul: return a; // 0 * b = 0
240 }
241 }
242
243 if (*la == 1 && id == nat::mul) return b; // 1 * b = b
244 }
245
246 if (lb && *lb == 0 && id == nat::sub) return a; // a - 0 = a
247
248 if (a == b) {
249 switch (id) {
250 case nat::add: return world.call(nat::mul, Defs{world.lit_nat(2), a}); // a + a = 2 * a
251 case nat::sub: return world.lit_nat(0); // a - a = 0
252 case nat::mul: break;
253 }
254 }
255
256 return world.raw_app(type, callee, {a, b});
257}
258
259template<ncmp id>
260const Def* normalize_ncmp(const Def* type, const Def* callee, const Def* arg) {
261 auto& world = type->world();
262
263 if (id == ncmp::t) return world.lit_tt();
264 if (id == ncmp::f) return world.lit_ff();
265
266 auto [a, b] = arg->projs<2>();
267 if (is_commutative(id) && Def::greater(a, b)) std::swap(a, b);
268
269 if (a == b) {
270 if (id & (icmp::e & 0xff)) return world.lit_tt();
271 if (id == ncmp::ne) return world.lit_ff();
272 }
273
274 if (auto la = Lit::isa(a)) {
275 if (auto lb = Lit::isa(b)) {
276 // clang-format off
277 switch (id) {
278 case ncmp:: e: return world.lit_bool(*la == *lb);
279 case ncmp::ne: return world.lit_bool(*la != *lb);
280 case ncmp::l : return world.lit_bool(*la < *lb);
281 case ncmp::le: return world.lit_bool(*la <= *lb);
282 case ncmp::g : return world.lit_bool(*la > *lb);
283 case ncmp::ge: return world.lit_bool(*la >= *lb);
284 default: fe::unreachable();
285 }
286 // clang-format on
287 }
288 }
289
290 return world.raw_app(type, callee, {a, b});
291}
292
293template<icmp id>
294const Def* normalize_icmp(const Def* type, const Def* c, const Def* arg) {
295 auto& world = type->world();
296 auto callee = c->as<App>();
297 auto [a, b] = arg->projs<2>();
298
299 if (auto result = fold<icmp, id>(world, type, a, b)) return result;
300 if (id == icmp::f) return world.lit_ff();
301 if (id == icmp::t) return world.lit_tt();
302 if (a == b) {
303 if (id & (icmp::e & 0xff)) return world.lit_tt();
304 if (id == icmp::ne) return world.lit_ff();
305 }
306
307 return world.raw_app(type, callee, {a, b});
308}
309
310template<extrema id>
311const Def* normalize_extrema(const Def* type, const Def* c, const Def* arg) {
312 auto& world = type->world();
313 auto callee = c->as<App>();
314 auto [a, b] = arg->projs<2>();
315 if (auto result = fold<extrema, id>(world, type, a, b)) return result;
316 return world.raw_app(type, callee, {a, b});
317}
318
319const Def* normalize_abs(const Def* type, const Def*, const Def* arg) {
320 auto& world = type->world();
321 auto [mem, a] = arg->projs<2>();
322 auto [_, actual_type] = type->projs<2>();
323 auto make_res = [&, mem = mem](const Def* res) { return world.tuple({mem, res}); };
324
325 if (auto result = fold<abs>(world, actual_type, a)) return make_res(result);
326 return {};
327}
328
329template<bit1 id>
330const Def* normalize_bit1(const Def* type, const Def* c, const Def* a) {
331 auto& world = type->world();
332 auto callee = c->as<App>();
333 auto s = callee->decurry()->arg();
334 // TODO cope with wrap around
335
336 if constexpr (id == bit1::id) return a;
337
338 if (auto ls = Lit::isa(s)) {
339 switch (id) {
340 case bit1::f: return world.lit_idx(*ls, 0);
341 case bit1::t: return world.lit_idx(*ls, *ls - 1_u64);
342 case bit1::id: fe::unreachable();
343 default: break;
344 }
345
346 assert(id == bit1::neg);
347 if (auto la = Lit::isa(a)) return world.lit_idx_mod(*ls, ~*la);
348 }
349
350 return {};
351}
352
/// Normalizes the binary bit operation @p id applied to `arg = (a, b)`.
/// The bit vector size `s` is the argument of the callee's curried application.
/// In order, it tries to
/// 1. merge two `icmp` or `math::cmp` comparisons over the same arguments (see merge_cmps),
/// 2. resolve operations that ignore one or both operands (`f`, `t`, `fst`, `snd`, `nfst`, `nsnd`) and rewrite
///    `ciff`/`nciff` into `iff`/`niff` with swapped operands,
/// 3. fold two literals (if `s` is a literal),
/// 4. simplify `a op a` (commutative ops only) and operations with an all-zeros/all-ones literal operand,
/// 5. reassociate.
353template<bit2 id>
354const Def* normalize_bit2(const Def* type, const Def* c, const Def* arg) {
355 auto& world = type->world();
356 auto callee = c->as<App>();
357 auto [a, b] = arg->projs<2>();
358 auto s = callee->decurry()->arg();
359 auto ls = Lit::isa(s);
360 // TODO cope with wrap around
361
362 if (is_commutative(id) && Def::greater(a, b)) std::swap(a, b);
363
// tab[x][y] is the result bit of `id` for input bits x (from a) and y (from b).
364 auto tab = make_truth_table(id);
365 if (auto res = merge_cmps<icmp>(tab, a, b)) return res;
366 if (auto res = merge_cmps<math::cmp>(tab, a, b)) return res;
367
368 auto la = Lit::isa(a);
369 auto lb = Lit::isa(b);
370
371 // clang-format off
// Operations whose result does not depend on both operands.
// ciff/nciff are iff/niff with swapped operands.
372 switch (id) {
373 case bit2:: f: return world.lit(type, 0);
374 case bit2:: t: if (ls) return world.lit(type, *ls-1_u64); break;
375 case bit2:: fst: return a;
376 case bit2:: snd: return b;
377 case bit2:: nfst: return world.call(bit1::neg, s, a);
378 case bit2:: nsnd: return world.call(bit1::neg, s, b);
379 case bit2:: ciff: return world.call(bit2:: iff, s, Defs{b, a});
380 case bit2::nciff: return world.call(bit2::niff, s, Defs{b, a});
381 default: break;
382 }
383
// Two literals: fold. Complemented results use lit_idx_mod to reduce the flipped high bits modulo s.
384 if (la && lb && ls) {
385 switch (id) {
386 case bit2::and_: return world.lit_idx (*ls, *la & *lb);
387 case bit2:: or_: return world.lit_idx (*ls, *la | *lb);
388 case bit2::xor_: return world.lit_idx (*ls, *la ^ *lb);
389 case bit2::nand: return world.lit_idx_mod(*ls, ~(*la & *lb));
390 case bit2:: nor: return world.lit_idx_mod(*ls, ~(*la | *lb));
391 case bit2::nxor: return world.lit_idx_mod(*ls, ~(*la ^ *lb));
392 case bit2:: iff: return world.lit_idx_mod(*ls, ~ *la | *lb);
393 case bit2::niff: return world.lit_idx (*ls, *la & ~*lb);
394 default: fe::unreachable();
395 }
396 }
397
398 // TODO rewrite using bit2
// Maps the result bits for the remaining operand `a` being 0 (x) and 1 (y) to a simpler term:
// (0,0) -> all zeros, (1,1) -> all ones (only if s is known), (0,1) -> a, (1,0) -> ~a.
// NOTE(review): xor_ is excluded from the `~a` rewrite; the reason is not visible here - confirm before changing.
399 auto unary = [&](bool x, bool y, const Def* a) -> const Def* {
400 if (!x && !y) return world.lit(type, 0);
401 if ( x && y) return ls ? world.lit(type, *ls-1_u64) : nullptr;
402 if (!x && y) return a;
403 if ( x && !y && id != bit2::xor_) return world.call(bit1::neg, s, a);
404 return nullptr;
405 };
406 // clang-format on
407
// a op a: both input bits are always equal, so only tab[0][0] and tab[1][1] matter.
408 if (is_commutative(id) && a == b) {
409 if (auto res = unary(tab[0][0], tab[1][1], a)) return res;
410 }
411
// a is all zeros or all ones: the result only depends on b.
412 if (la) {
413 if (*la == 0) {
414 if (auto res = unary(tab[0][0], tab[0][1], b)) return res;
415 } else if (ls && *la == *ls - 1_u64) {
416 if (auto res = unary(tab[1][0], tab[1][1], b)) return res;
417 }
418 }
419
// b is all zeros or all ones: the result only depends on a.
420 if (lb) {
421 if (*lb == 0) {
422 if (auto res = unary(tab[0][0], tab[1][0], a)) return res;
423 } else if (ls && *lb == *ls - 1_u64) {
424 if (auto res = unary(tab[0][1], tab[1][1], a)) return res;
425 }
426 }
427
428 if (auto res = reassociate<bit2>(id, world, callee, a, b)) return res;
429
430 return world.raw_app(type, callee, {a, b});
431}
432
433const Def* normalize_idx(const Def* type, const Def* c, const Def* arg) {
434 auto& world = type->world();
435 auto callee = c->as<App>();
436 if (auto i = Lit::isa(arg)) {
437 if (auto s = Lit::isa(Idx::isa(type))) {
438 if (*i < *s) return world.lit_idx(*s, *i);
439 if (auto m = Lit::isa(callee->arg())) return *m ? world.bot(type) : world.lit_idx_mod(*s, *i);
440 }
441 }
442
443 return {};
444}
445
446const Def* normalize_idx_unsafe(const Def*, const Def*, const Def* arg) {
447 auto& world = arg->world();
448 if (auto i = Lit::isa(arg)) return world.lit_idx_unsafe(*i);
449 return {};
450}
451
452template<shr id>
453const Def* normalize_shr(const Def* type, const Def* c, const Def* arg) {
454 auto& world = type->world();
455 auto callee = c->as<App>();
456 auto [a, b] = arg->projs<2>();
457 auto s = Idx::isa(arg->type());
458 auto ls = Lit::isa(s);
459
460 if (auto result = fold<shr, id>(world, type, a, b)) return result;
461
462 if (auto la = Lit::isa(a); la && *la == 0) {
463 switch (id) {
464 case shr::a: return a;
465 case shr::l: return a;
466 }
467 }
468
469 if (auto lb = Lit::isa(b)) {
470 if (ls && *lb > *ls) return world.bot(type);
471
472 if (*lb == 0) {
473 switch (id) {
474 case shr::a: return a;
475 case shr::l: return a;
476 }
477 }
478 }
479
480 return world.raw_app(type, callee, {a, b});
481}
482
483template<wrap id>
484const Def* normalize_wrap(const Def* type, const Def* c, const Def* arg) {
485 auto& world = type->world();
486 auto callee = c->as<App>();
487 auto [a, b] = arg->projs<2>();
488 auto mode = callee->arg();
489 auto s = Idx::isa(a->type());
490 auto ls = Lit::isa(s);
491
492 if (auto result = fold<wrap, id>(world, type, a, b)) return result;
493
494 // clang-format off
495 if (auto la = Lit::isa(a)) {
496 if (*la == 0) {
497 switch (id) {
498 case wrap::add: return b; // 0 + b -> b
499 case wrap::sub: break;
500 case wrap::mul: return a; // 0 * b -> 0
501 case wrap::shl: return a; // 0 << b -> 0
502 }
503 } else if (*la == 1) {
504 switch (id) {
505 case wrap::add: break;
506 case wrap::sub: break;
507 case wrap::mul: return b; // 1 * b -> b
508 case wrap::shl: break;
509 }
510 }
511 }
512
513 if (auto lb = Lit::isa(b)) {
514 if (*lb == 0) {
515 switch (id) {
516 case wrap::sub: return a; // a - 0 -> a
517 case wrap::shl: return a; // a >> 0 -> a
518 default: fe::unreachable();
519 // add, mul are commutative, the literal has been normalized to the left
520 }
521 }
522
523 if (auto lm = Lit::isa(mode); lm && ls && *lm == 0 && id == wrap::sub)
524 return world.call(wrap::add, mode, Defs{a, world.lit_idx_mod(*ls, ~*lb + 1_u64)}); // a - lb -> a + (~lb + 1)
525 else if (id == wrap::shl && ls && *lb > *ls)
526 return world.bot(type);
527 }
528
529 if (a == b) {
530 switch (id) {
531 case wrap::add: return world.call(wrap::mul, mode, Defs{world.lit(type, 2), a}); // a + a -> 2 * a
532 case wrap::sub: return world.lit(type, 0); // a - a -> 0
533 case wrap::mul: break;
534 case wrap::shl: break;
535 }
536 }
537 // clang-format on
538
539 if (auto res = reassociate<wrap>(id, world, callee, a, b)) return res;
540
541 return world.raw_app(type, callee, {a, b});
542}
543
544template<div id>
545const Def* normalize_div(const Def* full_type, const Def*, const Def* arg) {
546 auto& world = full_type->world();
547 auto [mem, ab] = arg->projs<2>();
548 auto [a, b] = ab->projs<2>();
549 auto [_, type] = full_type->projs<2>(); // peel off actual type
550 auto make_res = [&, mem = mem](const Def* res) { return world.tuple({mem, res}); };
551
552 if (auto result = fold<div, id>(world, type, a, b)) return make_res(result);
553
554 if (auto la = Lit::isa(a)) {
555 if (*la == 0) return make_res(a); // 0 / b -> 0 and 0 % b -> 0
556 }
557
558 if (auto lb = Lit::isa(b)) {
559 if (*lb == 0) return make_res(world.bot(type)); // a / 0 -> ⊥ and a % 0 -> ⊥
560
561 if (*lb == 1) {
562 switch (id) {
563 case div::sdiv: return make_res(a); // a / 1 -> a
564 case div::udiv: return make_res(a); // a / 1 -> a
565 case div::srem: return make_res(world.lit(type, 0)); // a % 1 -> 0
566 case div::urem: return make_res(world.lit(type, 0)); // a % 1 -> 0
567 }
568 }
569 }
570
571 if (a == b) {
572 switch (id) {
573 case div::sdiv: return make_res(world.lit(type, 1)); // a / a -> 1
574 case div::udiv: return make_res(world.lit(type, 1)); // a / a -> 1
575 case div::srem: return make_res(world.lit(type, 0)); // a % a -> 0
576 case div::urem: return make_res(world.lit(type, 0)); // a % a -> 0
577 }
578 }
579
580 return {};
581}
582
583template<conv id>
584const Def* normalize_conv(const Def* dst_t, const Def*, const Def* x) {
585 auto& world = dst_t->world();
586 auto s_t = x->type()->as<App>();
587 auto d_t = dst_t->as<App>();
588 auto s = s_t->arg();
589 auto d = d_t->arg();
590 auto ls = Lit::isa(s);
591 auto ld = Lit::isa(d);
592
593 if (s_t == d_t) return x;
594 if (x->isa<Bot>()) return world.bot(d_t);
595 if constexpr (id == conv::s) {
596 if (ls && ld && *ld < *ls) return world.call(conv::u, d, x); // just truncate - we don't care for signedness
597 }
598
599 if (auto l = Lit::isa(x); l && ls && ld) {
600 if constexpr (id == conv::u) {
601 if (*ld == 0) return world.lit(d_t, *l); // I64
602 return world.lit(d_t, *l % *ld);
603 }
604
605 auto sw = Idx::size2bitwidth(*ls);
606 auto dw = Idx::size2bitwidth(*ld);
607
608 // clang-format off
609 if (false) {}
610#define M(S, D) \
611 else if (S == sw && D == dw) return world.lit(d_t, w2s<D>(mim::bitcast<w2s<S>>(*l)));
612 M( 1, 8) M( 1, 16) M( 1, 32) M( 1, 64)
613 M( 8, 16) M( 8, 32) M( 8, 64)
614 M(16, 32) M(16, 64)
615 M(32, 64)
616 else assert(false && "TODO: conversion between different Idx sizes");
617 // clang-format on
618 }
619
620 return {};
621}
622
623const Def* normalize_bitcast(const Def* dst_t, const Def*, const Def* src) {
624 auto& world = dst_t->world();
625 auto src_t = src->type();
626
627 if (src->isa<Bot>()) return world.bot(dst_t);
628 if (src_t == dst_t) return src;
629
630 if (auto other = Axm::isa<bitcast>(src))
631 return other->arg()->type() == dst_t ? other->arg() : world.call<bitcast>(dst_t, other->arg());
632
633 if (auto l = Lit::isa(src)) {
634 if (dst_t->isa<Nat>()) return world.lit(dst_t, *l);
635 if (Idx::isa(dst_t)) return world.lit(dst_t, *l);
636 }
637
638 return {};
639}
640
641// TODO this currently hard-codes x86_64 ABI
642// TODO in contrast to C, we might want to give singleton types like 'Idx 1' or '[]' a size of 0 and simply nuke each
643// and every occurrence of these types in a later phase
644// TODO Pi and others
/// Normalizes the type trait @p id (`size` or `align`, in bytes) of @p type:
/// * pointers and functions (lowered to function pointers): 8,
/// * `Idx s`: its bit width rounded up to a power of two, in bytes (at least 1),
/// * floats of 16/32/64 bits: 2/4/8,
/// * Sigma/Meet: C-like struct layout - each member is placed at an offset padded to its alignment and the total is
///   padded to the maximal member alignment (the size is at least 1),
/// * arrays: the element alignment, and `arity * element size` (the latter only if the element size is a literal),
/// * Join: via its Sigma representation (see convert).
/// @returns `nullptr` if the trait cannot (yet) be determined.
645template<trait id>
646const Def* normalize_trait(const Def*, const Def*, const Def* type) {
647 auto& world = type->world();
648 if (auto ptr = Axm::isa<mem::Ptr>(type)) {
649 return world.lit_nat(8);
650 } else if (type->isa<Pi>()) {
651 return world.lit_nat(8); // Gets lowered to function ptr
652 } else if (auto size = Idx::isa(type)) {
// e.g. a 1-bit Idx still occupies 1 byte; a 64-bit Idx occupies 8.
653 if (auto w = Idx::size2bitwidth(size)) return world.lit_nat(std::max(1_n, std::bit_ceil(*w) / 8_n));
654 } else if (auto w = math::isa_f(type)) {
655 switch (*w) {
656 case 16: return world.lit_nat(2);
657 case 32: return world.lit_nat(4);
658 case 64: return world.lit_nat(8);
659 default: fe::unreachable();
660 }
661 } else if (type->isa<Sigma>() || type->isa<Meet>()) {
662 u64 offset = 0;
663 u64 align = 1;
// Lay out the members in order; bail out if any member's size or alignment is not (yet) a literal.
664 for (auto t : type->ops()) {
665 auto a = Lit::isa(core::op(trait::align, t));
666 auto s = Lit::isa(core::op(trait::size, t));
667 if (!a || !s) return {};
668
669 align = std::max(align, *a);
670 offset = pad(offset, *a) + *s;
671 }
672
// Trailing padding so that consecutive elements of this type stay aligned.
673 offset = pad(offset, align);
674 u64 size = std::max(1_u64, offset);
675
676 switch (id) {
677 case trait::align: return world.lit_nat(align);
678 case trait::size: return world.lit_nat(size);
679 }
680 } else if (auto arr = type->isa_imm<Arr>()) {
681 auto align = op(trait::align, arr->body());
682 if constexpr (id == trait::align) return align;
683 auto b = op(trait::size, arr->body());
684 if (b->isa<Lit>()) return world.call(nat::mul, Defs{arr->arity(), b});
685 } else if (auto join = type->isa<Join>()) {
686 if (auto sigma = convert(join)) return core::op(id, sigma);
687 }
688
689 return {};
690}
691
692template<pe id>
693const Def* normalize_pe(const Def* type, const Def*, const Def* arg) {
694 auto& world = type->world();
695
696 if constexpr (id == pe::is_closed) {
697 if (Axm::isa(pe::hlt, arg)) return world.lit_ff();
698 if (arg->is_closed()) return world.lit_tt();
699 }
700
701 return {};
702}
703
705
706} // namespace mim::plug::core
const Def * arg() const
Definition lam.h:285
A (possibly parameterized) Array.
Definition tuple.h:117
static auto isa(const Def *def)
Definition axm.h:107
Base class for all Defs.
Definition def.h:251
World & world() const noexcept
Definition def.cpp:438
auto projs(F f) const
Splits this Def via Def::projections into an Array (if A == std::dynamic_extent) or std::array (other...
Definition def.h:390
const Def * type() const noexcept
Yields the "raw" type of this Def (maybe nullptr).
Definition def.cpp:446
static bool greater(const Def *a, const Def *b)
Definition def.cpp:549
bool is_closed() const
Has no free_vars()?
Definition def.cpp:417
static constexpr nat_t size2bitwidth(nat_t n)
Definition def.h:893
static const Def * isa(const Def *def)
Checks if def is an Idx s and returns s or nullptr otherwise.
Definition def.cpp:610
static std::optional< T > isa(const Def *def)
Definition def.h:826
static T as(const Def *def)
Definition def.h:832
A dependent function type.
Definition lam.h:14
A dependent tuple type.
Definition tuple.h:20
const Lit * lit_idx_unsafe(u64 val)
Definition world.h:468
#define MIM_core_NORMALIZER_IMPL
Definition autogen.h:302
#define M(S, D)
The core Plugin
Definition core.h:8
const Def * normalize_nat(const Def *type, const Def *callee, const Def *arg)
const Def * normalize_idx_unsafe(const Def *, const Def *, const Def *arg)
const Sigma * convert(const TBound< up > *b)
Definition core.cpp:18
const Def * normalize_div(const Def *full_type, const Def *, const Def *arg)
const Def * normalize_pe(const Def *type, const Def *, const Def *arg)
const Def * normalize_extrema(const Def *type, const Def *c, const Def *arg)
const Def * normalize_icmp(const Def *type, const Def *c, const Def *arg)
const Def * normalize_bit1(const Def *type, const Def *c, const Def *a)
const Def * normalize_conv(const Def *dst_t, const Def *, const Def *x)
const Def * normalize_bit2(const Def *type, const Def *c, const Def *arg)
const Def * normalize_wrap(const Def *type, const Def *c, const Def *arg)
const Def * normalize_trait(const Def *, const Def *, const Def *type)
const Def * op(trait o, const Def *type)
Definition core.h:33
const Def * normalize_abs(const Def *type, const Def *, const Def *arg)
const Def * normalize_idx(const Def *type, const Def *c, const Def *arg)
constexpr std::array< std::array< u64, 2 >, 2 > make_truth_table(bit2 id)
Definition core.h:50
const Def * normalize_bitcast(const Def *dst_t, const Def *, const Def *src)
const Def * normalize_ncmp(const Def *type, const Def *callee, const Def *arg)
@ nuw
No Unsigned Wrap around.
Definition core.h:16
@ none
Wrap around.
Definition core.h:14
@ nsw
No Signed Wrap around.
Definition core.h:15
const Def * normalize_shr(const Def *type, const Def *c, const Def *arg)
std::optional< nat_t > isa_f(const Def *def)
Definition math.h:76
The mem Plugin
Definition mem.h:11
View< const Def * > Defs
Definition def.h:76
u8 sub_t
Definition types.h:49
D bitcast(const S &src)
A bitcast from src of type S to D.
Definition util.h:23
TBound< true > Join
AKA union.
Definition lattice.h:174
u64 pad(u64 offset, u64 align)
Definition util.h:47
constexpr bool is_commutative(Id)
Definition axm.h:152
typename detail::w2s_< w >::type w2s
Definition types.h:74
constexpr bool is_associative(Id id)
Definition axm.h:158
typename detail::w2u_< w >::type w2u
Definition types.h:73
TExt< false > Bot
Definition lattice.h:171
uint64_t u64
Definition types.h:35
bool get_sign(T val)
Definition util.h:38
uint8_t u8
Definition types.h:35
TBound< false > Meet
AKA intersection.
Definition lattice.h:173
CODE(node, _)
Definition def.h:113
@ App
Definition def.h:114
@ Lit
Definition def.h:114
static consteval size_t num()
Definition plugin.h:118
static consteval flags_t base()
Definition plugin.h:119
#define MIM_1_8_16_32_64(m)
Definition types.h:25