MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
normalizers.cpp
Go to the documentation of this file.
1#include <mim/normalize.h>
2
4#include <mim/plug/mem/mem.h>
5
7
8namespace mim::plug::core {
9
10namespace {
11
12// clang-format off
13// See https://stackoverflow.com/a/64354296 for static_assert trick below.
14template<class Id, Id id, nat_t w>
15Res fold(u64 a, u64 b, [[maybe_unused]] bool nsw, [[maybe_unused]] bool nuw) {
16 using ST = w2s<w>;
17 using UT = w2u<w>;
19 auto u = mim::bitcast<UT>(a), v = mim::bitcast<UT>(b);
20
21 if constexpr (std::is_same_v<Id, wrap>) {
22 if constexpr (id == wrap::add) {
23 auto res = u + v;
24 if (nuw && res < u) return {};
25 // TODO nsw
26 return res;
27 } else if constexpr (id == wrap::sub) {
28 auto res = u - v;
29 // TODO nsw
30 return res;
31 } else if constexpr (id == wrap::mul) {
32 if constexpr (std::is_same_v<UT, bool>)
33 return UT(u & v);
34 else
35 return UT(u * v);
36 } else if constexpr (id == wrap::shl) {
37 if (b >= w) return {};
38 decltype(u) res;
39 if constexpr (std::is_same_v<UT, bool>)
40 res = bool(u64(u) << u64(v));
41 else
42 res = u << v;
43 if (nuw && res < u) return {};
44 if (nsw && get_sign(u) != get_sign(res)) return {};
45 return res;
46 } else {
47 []<bool flag = false>() { static_assert(flag, "missing sub tag"); }();
48 }
49 } else if constexpr (std::is_same_v<Id, shr>) {
50 if (b >= w) return {};
51 if constexpr (false) {}
52 else if constexpr (id == shr::a) return s >> t;
53 else if constexpr (id == shr::l) return u >> v;
54 else []<bool flag = false>() { static_assert(flag, "missing sub tag"); }();
55 } else if constexpr (std::is_same_v<Id, div>) {
56 if (b == 0) return {};
57 if constexpr (false) {}
58 else if constexpr (id == div::sdiv) return s / t;
59 else if constexpr (id == div::udiv) return u / v;
60 else if constexpr (id == div::srem) return s % t;
61 else if constexpr (id == div::urem) return u % v;
62 else []<bool flag = false>() { static_assert(flag, "missing sub tag"); }();
63 } else if constexpr (std::is_same_v<Id, icmp>) {
64 bool res = false;
65 auto pm = !(u >> UT(w - 1)) && (v >> UT(w - 1));
66 auto mp = (u >> UT(w - 1)) && !(v >> UT(w - 1));
67 res |= ((id & icmp::Xygle) != icmp::f) && pm;
68 res |= ((id & icmp::xYgle) != icmp::f) && mp;
69 res |= ((id & icmp::xyGle) != icmp::f) && u > v && !mp;
70 res |= ((id & icmp::xygLe) != icmp::f) && u < v && !pm;
71 res |= ((id & icmp::xyglE) != icmp::f) && u == v;
72 return res;
73 } else if constexpr (std::is_same_v<Id, extrema>) {
74 if constexpr (false) {}
75 else if(id == extrema::sm) return std::min(u, v);
76 else if(id == extrema::Sm) return std::min(s, t);
77 else if(id == extrema::sM) return std::max(u, v);
78 else if(id == extrema::SM) return std::max(s, t);
79 } else {
80 []<bool flag = false>() { static_assert(flag, "missing tag"); }();
81 }
82}
83// clang-format on
84
85// Note that @p a and @p b are passed by reference as fold also commutes if possible.
86template<class Id, Id id>
87const Def* fold(World& world, const Def* type, const Def*& a, const Def*& b, const Def* mode = {}) {
88 if (a->isa<Bot>() || b->isa<Bot>()) return world.bot(type);
89
90 if (auto la = Lit::isa(a)) {
91 if (auto lb = Lit::isa(b)) {
92 auto size = Lit::as(Idx::isa(a->type()));
93 auto width = Idx::size2bitwidth(size);
94 bool nsw = false, nuw = false;
95 if constexpr (std::is_same_v<Id, wrap>) {
96 auto m = mode ? Lit::as(mode) : 0_n;
97 nsw = m & Mode::nsw;
98 nuw = m & Mode::nuw;
99 }
100
101 Res res;
102 switch (width) {
103#define CODE(i) \
104 case i: res = fold<Id, id, i>(*la, *lb, nsw, nuw); break;
106#undef CODE
107 default:
108 // TODO this is super rough but at least better than just bailing out
109 res = fold<Id, id, 64>(*la, *lb, false, false);
110 if (res && !std::is_same_v<Id, icmp>) *res %= size;
111 }
112
113 return res ? world.lit(type, *res) : world.bot(type);
114 }
115 }
116
117 if (::mim::is_commutative(id) && commute(a, b)) std::swap(a, b);
118 return nullptr;
119}
120
121template<class Id, nat_t w> Res fold(u64 a, [[maybe_unused]] bool nsw, [[maybe_unused]] bool nuw) {
122 using ST = w2s<w>;
123 auto s = mim::bitcast<ST>(a);
124
125 if constexpr (std::is_same_v<Id, abs>)
126 return std::abs(s);
127 else
128 []<bool flag = false>() { static_assert(flag, "missing tag"); }
129 ();
130}
131
132template<class Id> const Def* fold(World& world, const Def* type, const Def*& a) {
133 if (a->isa<Bot>()) return world.bot(type);
134
135 if (auto la = Lit::isa(a)) {
136 auto size = Lit::as(Idx::isa(a->type()));
137 auto width = Idx::size2bitwidth(size);
138 bool nsw = false, nuw = false;
139 Res res;
140 switch (width) {
141#define CODE(i) \
142 case i: res = fold<Id, i>(*la, nsw, nuw); break;
144#undef CODE
145 default:
146 res = fold<Id, 64>(*la, false, false);
147 if (res && !std::is_same_v<Id, icmp>) *res %= size;
148 }
149
150 return res ? world.lit(type, *res) : world.bot(type);
151 }
152 return nullptr;
153}
154
155/// Reassociates @p a and @p b according to following rules.
156/// We use the following naming convention while literals are prefixed with an `l`:
157/// ```
158/// a op b
159/// (x op y) op (z op w)
160///
161/// (1) la op (lz op w) -> (la op lz) op w
162/// (2) (lx op y) op (lz op w) -> (lx op lz) op (y op w)
163/// (3) a op (lz op w) -> lz op (a op w)
164/// (4) (lx op y) op b -> lx op (y op b)
165/// ```
166template<class Id>
167const Def* reassociate(Id id, World& world, [[maybe_unused]] const App* ab, const Def* a, const Def* b) {
168 if (!is_associative(id)) return nullptr;
169
170 auto xy = Axm::isa<Id>(id, a);
171 auto zw = Axm::isa<Id>(id, b);
172 auto la = a->isa<Lit>();
173 auto [x, y] = xy ? xy->template args<2>() : std::array<const Def*, 2>{nullptr, nullptr};
174 auto [z, w] = zw ? zw->template args<2>() : std::array<const Def*, 2>{nullptr, nullptr};
175 auto lx = Lit::isa(x);
176 auto lz = Lit::isa(z);
177
178 // if we reassociate, we have to forget about nsw/nuw
179 auto make_op = [&world, id](const Def* a, const Def* b) { return world.call(id, Mode::none, Defs{a, b}); };
180
181 if (la && lz) return make_op(make_op(a, z), w); // (1)
182 if (lx && lz) return make_op(make_op(x, z), make_op(y, w)); // (2)
183 if (lz) return make_op(z, make_op(a, w)); // (3)
184 if (lx) return make_op(x, make_op(y, b)); // (4)
185 return nullptr;
186}
187
188template<class Id> const Def* merge_cmps(std::array<std::array<u64, 2>, 2> tab, const Def* a, const Def* b) {
189 static_assert(sizeof(sub_t) == 1, "if this ever changes, please adjust the logic below");
190 static constexpr size_t num_bits = std::bit_width(Annex::Num<Id> - 1_u64);
191
192 auto& world = a->world();
193 auto a_cmp = Axm::isa<Id>(a);
194 auto b_cmp = Axm::isa<Id>(b);
195
196 if (a_cmp && b_cmp && a_cmp->arg() == b_cmp->arg()) {
197 // push sub bits of a_cmp and b_cmp through truth table
198 sub_t res = 0;
199 sub_t a_sub = a_cmp.sub();
200 sub_t b_sub = b_cmp.sub();
201 for (size_t i = 0; i != num_bits; ++i, res >>= 1, a_sub >>= 1, b_sub >>= 1)
202 res |= tab[a_sub & 1][b_sub & 1] << 7_u8;
203 res >>= (7_u8 - u8(num_bits));
204
205 if constexpr (std::is_same_v<Id, math::cmp>)
206 return world.call(math::cmp(res), /*mode*/ a_cmp->decurry()->arg(), a_cmp->arg());
207 else
208 return world.call(icmp(Annex::Base<icmp> | res), a_cmp->arg());
209 }
210
211 return nullptr;
212}
213
214} // namespace
215
216template<nat id> const Def* normalize_nat(const Def* type, const Def* callee, const Def* arg) {
217 auto& world = type->world();
218 auto [a, b] = arg->projs<2>();
219 if (is_commutative(id) && commute(a, b)) std::swap(a, b);
220 auto la = Lit::isa(a);
221 auto lb = Lit::isa(b);
222
223 if (la) {
224 if (lb) {
225 switch (id) {
226 case nat::add: return world.lit_nat(*la + *lb);
227 case nat::sub: return *la < *lb ? world.lit_nat_0() : world.lit_nat(*la - *lb);
228 case nat::mul: return world.lit_nat(*la * *lb);
229 }
230 }
231
232 if (*la == 0) {
233 switch (id) {
234 case nat::add: return b;
235 case nat::sub: return a; // 0 - b = 0
236 case nat::mul: return a; // 0 * b = 0
237 }
238 }
239
240 if (*la == 1 && id == nat::mul) return b; // 1 * b = b
241 }
242
243 if (lb && *lb == 0 && id == nat::sub) return a; // a - 0 = a
244
245 if (a == b) {
246 switch (id) {
247 case nat::add: return world.call(nat::mul, Defs{world.lit_nat(2), a}); // a + a = 2 * a
248 case nat::sub: return world.lit_nat(0); // a - a = 0
249 case nat::mul: break;
250 }
251 }
252
253 return world.raw_app(type, callee, {a, b});
254}
255
256template<ncmp id> const Def* normalize_ncmp(const Def* type, const Def* callee, const Def* arg) {
257 auto& world = type->world();
258
259 if (id == ncmp::t) return world.lit_tt();
260 if (id == ncmp::f) return world.lit_ff();
261
262 auto [a, b] = arg->projs<2>();
263 if (is_commutative(id) && commute(a, b)) std::swap(a, b);
264
265 if (auto la = Lit::isa(a)) {
266 if (auto lb = Lit::isa(b)) {
267 // clang-format off
268 switch (id) {
269 case ncmp:: e: return world.lit_bool(*la == *lb);
270 case ncmp::ne: return world.lit_bool(*la != *lb);
271 case ncmp::l : return world.lit_bool(*la < *lb);
272 case ncmp::le: return world.lit_bool(*la <= *lb);
273 case ncmp::g : return world.lit_bool(*la > *lb);
274 case ncmp::ge: return world.lit_bool(*la >= *lb);
275 default: fe::unreachable();
276 }
277 // clang-format on
278 }
279 }
280
281 return world.raw_app(type, callee, {a, b});
282}
283
284template<icmp id> const Def* normalize_icmp(const Def* type, const Def* c, const Def* arg) {
285 auto& world = type->world();
286 auto callee = c->as<App>();
287 auto [a, b] = arg->projs<2>();
288
289 if (auto result = fold<icmp, id>(world, type, a, b)) return result;
290 if (id == icmp::f) return world.lit_ff();
291 if (id == icmp::t) return world.lit_tt();
292 if (a == b) {
293 if (id & (icmp::e & 0xff)) return world.lit_tt();
294 if (id == icmp::ne) return world.lit_ff();
295 }
296
297 return world.raw_app(type, callee, {a, b});
298}
299
300template<extrema id> const Def* normalize_extrema(const Def* type, const Def* c, const Def* arg) {
301 auto& world = type->world();
302 auto callee = c->as<App>();
303 auto [a, b] = arg->projs<2>();
304 if (auto result = fold<extrema, id>(world, type, a, b)) return result;
305 return world.raw_app(type, callee, {a, b});
306}
307
308const Def* normalize_abs(const Def* type, const Def*, const Def* arg) {
309 auto& world = type->world();
310 auto [mem, a] = arg->projs<2>();
311 auto [_, actual_type] = type->projs<2>();
312 auto make_res = [&, mem = mem](const Def* res) { return world.tuple({mem, res}); };
313
314 if (auto result = fold<abs>(world, actual_type, a)) return make_res(result);
315 return {};
316}
317
318template<bit1 id> const Def* normalize_bit1(const Def* type, const Def* c, const Def* a) {
319 auto& world = type->world();
320 auto callee = c->as<App>();
321 auto s = callee->decurry()->arg();
322 // TODO cope with wrap around
323
324 if constexpr (id == bit1::id) return a;
325
326 if (auto ls = Lit::isa(s)) {
327 switch (id) {
328 case bit1::f: return world.lit_idx(*ls, 0);
329 case bit1::t: return world.lit_idx(*ls, *ls - 1_u64);
330 case bit1::id: fe::unreachable();
331 default: break;
332 }
333
334 assert(id == bit1::neg);
335 if (auto la = Lit::isa(a)) return world.lit_idx_mod(*ls, ~*la);
336 }
337
338 return {};
339}
340
/// Normalizes the bitwise binary operation @p id on the pair @p arg of `Idx s` values.
/// The size `s` is the argument of the decurried callee @p c; if it is a literal, constants such as `s - 1` can be built.
/// Rewrites are tried in this order: merging two comparisons, operand-independent operations, full constant folding,
/// identities that leave a "unary" remainder (driven by the truth table of @p id), and finally reassociation.
341template<bit2 id> const Def* normalize_bit2(const Def* type, const Def* c, const Def* arg) {
342 auto& world = type->world();
343 auto callee = c->as<App>();
344 auto [a, b] = arg->projs<2>();
345 auto s = callee->decurry()->arg();
346 auto ls = Lit::isa(s);
347 // TODO cope with wrap around
348
349 if (is_commutative(id) && commute(a, b)) std::swap(a, b);
350
// tab[x][y] is the result bit of `x op y` for single bits x (from a) and y (from b).
351 auto tab = make_truth_table(id);
// `op(cmp1(x, y), cmp2(x, y))` collapses into a single comparison for `icmp` and `math::cmp`.
352 if (auto res = merge_cmps<icmp>(tab, a, b)) return res;
353 if (auto res = merge_cmps<math::cmp>(tab, a, b)) return res;
354
355 auto la = Lit::isa(a);
356 auto lb = Lit::isa(b);
357
// Operations that do not need literal operands. `t` needs a literal size to build `s - 1`;
// `ciff`/`nciff` are rewritten to `iff`/`niff` with swapped operands.
358 // clang-format off
359 switch (id) {
360 case bit2:: f: return world.lit(type, 0);
361 case bit2:: t: if (ls) return world.lit(type, *ls-1_u64); break;
362 case bit2:: fst: return a;
363 case bit2:: snd: return b;
364 case bit2:: nfst: return world.call(bit1::neg, s, a);
365 case bit2:: nsnd: return world.call(bit1::neg, s, b);
366 case bit2:: ciff: return world.call(bit2:: iff, s, Defs{b, a});
367 case bit2::nciff: return world.call(bit2::niff, s, Defs{b, a});
368 default: break;
369 }
370
// Both operands and the size are literals: evaluate directly. Results built with `~` can have bits set beyond
// the size, hence the `_mod` variant for those.
371 if (la && lb && ls) {
372 switch (id) {
373 case bit2::and_: return world.lit_idx (*ls, *la & *lb);
374 case bit2:: or_: return world.lit_idx (*ls, *la | *lb);
375 case bit2::xor_: return world.lit_idx (*ls, *la ^ *lb);
376 case bit2::nand: return world.lit_idx_mod(*ls, ~(*la & *lb));
377 case bit2:: nor: return world.lit_idx_mod(*ls, ~(*la | *lb));
378 case bit2::nxor: return world.lit_idx_mod(*ls, ~(*la ^ *lb));
379 case bit2:: iff: return world.lit_idx_mod(*ls, ~ *la | *lb);
380 case bit2::niff: return world.lit_idx (*ls, *la & ~*lb);
381 default: fe::unreachable();
382 }
383 }
384
// unary(x, y, v): result of the operation when only operand `v` varies and x/y are the truth-table outputs for a
// 0/1 bit of `v`: constant 0, constant `s - 1`, `v` itself, or `neg v`; nullptr if that cannot be expressed here.
// NOTE(review): `s - 1` only means "all bits set" if `s` is a power of two; `xor_` is excluded from the `neg`
// case for a reason not visible here - confirm both.
385 // TODO rewrite using bit2
386 auto unary = [&](bool x, bool y, const Def* a) -> const Def* {
387 if (!x && !y) return world.lit(type, 0);
388 if ( x && y) return ls ? world.lit(type, *ls-1_u64) : nullptr;
389 if (!x && y) return a;
390 if ( x && !y && id != bit2::xor_) return world.call(bit1::neg, s, a);
391 return nullptr;
392 };
393 // clang-format on
394
// `a op a` only depends on the diagonal of the truth table.
395 if (is_commutative(id) && a == b) {
396 if (auto res = unary(tab[0][0], tab[1][1], a)) return res;
397 }
398
// A literal operand equal to 0 or `s - 1` fixes one row (for a) or column (for b) of the truth table.
399 if (la) {
400 if (*la == 0) {
401 if (auto res = unary(tab[0][0], tab[0][1], b)) return res;
402 } else if (ls && *la == *ls - 1_u64) {
403 if (auto res = unary(tab[1][0], tab[1][1], b)) return res;
404 }
405 }
406
407 if (lb) {
408 if (*lb == 0) {
409 if (auto res = unary(tab[0][0], tab[1][0], a)) return res;
410 } else if (ls && *lb == *ls - 1_u64) {
411 if (auto res = unary(tab[0][1], tab[1][1], a)) return res;
412 }
413 }
414
415 if (auto res = reassociate<bit2>(id, world, callee, a, b)) return res;
416
417 return world.raw_app(type, callee, {a, b});
418}
419
420const Def* normalize_idx(const Def* type, const Def* c, const Def* arg) {
421 auto& world = type->world();
422 auto callee = c->as<App>();
423 if (auto i = Lit::isa(arg)) {
424 if (auto s = Lit::isa(Idx::isa(type))) {
425 if (*i < *s) return world.lit_idx(*s, *i);
426 if (auto m = Lit::isa(callee->arg())) return *m ? world.bot(type) : world.lit_idx_mod(*s, *i);
427 }
428 }
429
430 return {};
431}
432
433const Def* normalize_idx_unsafe(const Def*, const Def*, const Def* arg) {
434 auto& world = arg->world();
435 if (auto i = Lit::isa(arg)) return world.lit_idx_unsafe(*i);
436 return {};
437}
438
439template<shr id> const Def* normalize_shr(const Def* type, const Def* c, const Def* arg) {
440 auto& world = type->world();
441 auto callee = c->as<App>();
442 auto [a, b] = arg->projs<2>();
443 auto s = Idx::isa(arg->type());
444 auto ls = Lit::isa(s);
445
446 if (auto result = fold<shr, id>(world, type, a, b)) return result;
447
448 if (auto la = Lit::isa(a); la && *la == 0) {
449 switch (id) {
450 case shr::a: return a;
451 case shr::l: return a;
452 }
453 }
454
455 if (auto lb = Lit::isa(b)) {
456 if (ls && *lb > *ls) return world.bot(type);
457
458 if (*lb == 0) {
459 switch (id) {
460 case shr::a: return a;
461 case shr::l: return a;
462 }
463 }
464 }
465
466 return world.raw_app(type, callee, {a, b});
467}
468
469template<wrap id> const Def* normalize_wrap(const Def* type, const Def* c, const Def* arg) {
470 auto& world = type->world();
471 auto callee = c->as<App>();
472 auto [a, b] = arg->projs<2>();
473 auto mode = callee->arg();
474 auto s = Idx::isa(a->type());
475 auto ls = Lit::isa(s);
476
477 if (auto result = fold<wrap, id>(world, type, a, b)) return result;
478
479 // clang-format off
480 if (auto la = Lit::isa(a)) {
481 if (*la == 0) {
482 switch (id) {
483 case wrap::add: return b; // 0 + b -> b
484 case wrap::sub: break;
485 case wrap::mul: return a; // 0 * b -> 0
486 case wrap::shl: return a; // 0 << b -> 0
487 }
488 } else if (*la == 1) {
489 switch (id) {
490 case wrap::add: break;
491 case wrap::sub: break;
492 case wrap::mul: return b; // 1 * b -> b
493 case wrap::shl: break;
494 }
495 }
496 }
497
498 if (auto lb = Lit::isa(b)) {
499 if (*lb == 0) {
500 switch (id) {
501 case wrap::sub: return a; // a - 0 -> a
502 case wrap::shl: return a; // a >> 0 -> a
503 default: fe::unreachable();
504 // add, mul are commutative, the literal has been normalized to the left
505 }
506 }
507
508 if (auto lm = Lit::isa(mode); lm && ls && *lm == 0 && id == wrap::sub)
509 return world.call(wrap::add, mode, Defs{a, world.lit_idx_mod(*ls, ~*lb + 1_u64)}); // a - lb -> a + (~lb + 1)
510 else if (id == wrap::shl && ls && *lb > *ls)
511 return world.bot(type);
512 }
513
514 if (a == b) {
515 switch (id) {
516 case wrap::add: return world.call(wrap::mul, mode, Defs{world.lit(type, 2), a}); // a + a -> 2 * a
517 case wrap::sub: return world.lit(type, 0); // a - a -> 0
518 case wrap::mul: break;
519 case wrap::shl: break;
520 }
521 }
522 // clang-format on
523
524 if (auto res = reassociate<wrap>(id, world, callee, a, b)) return res;
525
526 return world.raw_app(type, callee, {a, b});
527}
528
529template<div id> const Def* normalize_div(const Def* full_type, const Def*, const Def* arg) {
530 auto& world = full_type->world();
531 auto [mem, ab] = arg->projs<2>();
532 auto [a, b] = ab->projs<2>();
533 auto [_, type] = full_type->projs<2>(); // peel off actual type
534 auto make_res = [&, mem = mem](const Def* res) { return world.tuple({mem, res}); };
535
536 if (auto result = fold<div, id>(world, type, a, b)) return make_res(result);
537
538 if (auto la = Lit::isa(a)) {
539 if (*la == 0) return make_res(a); // 0 / b -> 0 and 0 % b -> 0
540 }
541
542 if (auto lb = Lit::isa(b)) {
543 if (*lb == 0) return make_res(world.bot(type)); // a / 0 -> ⊥ and a % 0 -> ⊥
544
545 if (*lb == 1) {
546 switch (id) {
547 case div::sdiv: return make_res(a); // a / 1 -> a
548 case div::udiv: return make_res(a); // a / 1 -> a
549 case div::srem: return make_res(world.lit(type, 0)); // a % 1 -> 0
550 case div::urem: return make_res(world.lit(type, 0)); // a % 1 -> 0
551 }
552 }
553 }
554
555 if (a == b) {
556 switch (id) {
557 case div::sdiv: return make_res(world.lit(type, 1)); // a / a -> 1
558 case div::udiv: return make_res(world.lit(type, 1)); // a / a -> 1
559 case div::srem: return make_res(world.lit(type, 0)); // a % a -> 0
560 case div::urem: return make_res(world.lit(type, 0)); // a % a -> 0
561 }
562 }
563
564 return {};
565}
566
567template<conv id> const Def* normalize_conv(const Def* dst_t, const Def*, const Def* x) {
568 auto& world = dst_t->world();
569 auto s_t = x->type()->as<App>();
570 auto d_t = dst_t->as<App>();
571 auto s = s_t->arg();
572 auto d = d_t->arg();
573 auto ls = Lit::isa(s);
574 auto ld = Lit::isa(d);
575
576 if (s_t == d_t) return x;
577 if (x->isa<Bot>()) return world.bot(d_t);
578 if constexpr (id == conv::s) {
579 if (ls && ld && *ld < *ls) return world.call(conv::u, d, x); // just truncate - we don't care for signedness
580 }
581
582 if (auto l = Lit::isa(x); l && ls && ld) {
583 if constexpr (id == conv::u) {
584 if (*ld == 0) return world.lit(d_t, *l); // I64
585 return world.lit(d_t, *l % *ld);
586 }
587
588 auto sw = Idx::size2bitwidth(*ls);
589 auto dw = Idx::size2bitwidth(*ld);
590
591 // clang-format off
592 if (false) {}
593#define M(S, D) \
594 else if (S == sw && D == dw) return world.lit(d_t, w2s<D>(mim::bitcast<w2s<S>>(*l)));
595 M( 1, 8) M( 1, 16) M( 1, 32) M( 1, 64)
596 M( 8, 16) M( 8, 32) M( 8, 64)
597 M(16, 32) M(16, 64)
598 M(32, 64)
599 else assert(false && "TODO: conversion between different Idx sizes");
600 // clang-format on
601 }
602
603 return {};
604}
605
606const Def* normalize_bitcast(const Def* dst_t, const Def*, const Def* src) {
607 auto& world = dst_t->world();
608 auto src_t = src->type();
609
610 if (src->isa<Bot>()) return world.bot(dst_t);
611 if (src_t == dst_t) return src;
612
613 if (auto other = Axm::isa<bitcast>(src))
614 return other->arg()->type() == dst_t ? other->arg() : world.call<bitcast>(dst_t, other->arg());
615
616 if (auto l = Lit::isa(src)) {
617 if (dst_t->isa<Nat>()) return world.lit(dst_t, *l);
618 if (Idx::isa(dst_t)) return world.lit(dst_t, *l);
619 }
620
621 return {};
622}
623
624// TODO this currently hard-codes x86_64 ABI
625// TODO in contrast to C, we might want to give singleton types like 'Idx 1' or '[]' a size of 0 and simply nuke each
626// and every occurance of these types in a later phase
627// TODO Pi and others
628template<trait id> const Def* normalize_trait(const Def*, const Def*, const Def* type) {
629 auto& world = type->world();
630 if (auto ptr = Axm::isa<mem::Ptr>(type)) {
631 return world.lit_nat(8);
632 } else if (type->isa<Pi>()) {
633 return world.lit_nat(8); // Gets lowered to function ptr
634 } else if (auto size = Idx::isa(type)) {
635 if (auto w = Idx::size2bitwidth(size)) return world.lit_nat(std::max(1_n, std::bit_ceil(*w) / 8_n));
636 } else if (auto w = math::isa_f(type)) {
637 switch (*w) {
638 case 16: return world.lit_nat(2);
639 case 32: return world.lit_nat(4);
640 case 64: return world.lit_nat(8);
641 default: fe::unreachable();
642 }
643 } else if (type->isa<Sigma>() || type->isa<Meet>()) {
644 u64 offset = 0;
645 u64 align = 1;
646 for (auto t : type->ops()) {
647 auto a = Lit::isa(core::op(trait::align, t));
648 auto s = Lit::isa(core::op(trait::size, t));
649 if (!a || !s) return {};
650
651 align = std::max(align, *a);
652 offset = pad(offset, *a) + *s;
653 }
654
655 offset = pad(offset, align);
656 u64 size = std::max(1_u64, offset);
657
658 switch (id) {
659 case trait::align: return world.lit_nat(align);
660 case trait::size: return world.lit_nat(size);
661 }
662 } else if (auto arr = type->isa_imm<Arr>()) {
663 auto align = op(trait::align, arr->body());
664 if constexpr (id == trait::align) return align;
665 auto b = op(trait::size, arr->body());
666 if (b->isa<Lit>()) return world.call(nat::mul, Defs{arr->shape(), b});
667 } else if (auto join = type->isa<Join>()) {
668 if (auto sigma = convert(join)) return core::op(id, sigma);
669 }
670
671 return {};
672}
673
674template<pe id> const Def* normalize_pe(const Def* type, const Def*, const Def* arg) {
675 auto& world = type->world();
676
677 if constexpr (id == pe::is_closed) {
678 if (Axm::isa(pe::hlt, arg)) return world.lit_ff();
679 if (arg->is_closed()) return world.lit_tt();
680 }
681
682 return {};
683}
684
686
687} // namespace mim::plug::core
const Def * arg() const
Definition lam.h:230
A (possibly parameterized) Array.
Definition tuple.h:100
static auto isa(const Def *def)
Definition axm.h:104
Base class for all Defs.
Definition def.h:203
World & world() const noexcept
Definition def.cpp:380
auto projs(F f) const
Splits this Def via Def::projections into an Array (if A == std::dynamic_extent) or std::array (other...
Definition def.h:350
const Def * type() const noexcept
Yields the "raw" type of this Def (maybe nullptr).
Definition def.h:247
bool is_closed() const
Has no free_vars()?
Definition def.cpp:359
static constexpr nat_t size2bitwidth(nat_t n)
Definition def.h:795
static const Def * isa(const Def *def)
Checks if def is an Idx s and returns s or nullptr otherwise.
Definition def.cpp:522
static std::optional< T > isa(const Def *def)
Definition def.h:733
static T as(const Def *def)
Definition def.h:738
A dependent function type.
Definition lam.h:11
A dependent tuple type.
Definition tuple.h:15
const Lit * lit_idx_unsafe(u64 val)
Definition world.h:402
#define MIM_core_NORMALIZER_IMPL
Definition autogen.h:302
#define M(S, D)
The core Plugin
Definition core.h:8
const Def * normalize_nat(const Def *type, const Def *callee, const Def *arg)
const Def * normalize_idx_unsafe(const Def *, const Def *, const Def *arg)
const Sigma * convert(const TBound< up > *b)
Definition core.cpp:19
const Def * normalize_div(const Def *full_type, const Def *, const Def *arg)
const Def * normalize_pe(const Def *type, const Def *, const Def *arg)
const Def * normalize_extrema(const Def *type, const Def *c, const Def *arg)
const Def * normalize_icmp(const Def *type, const Def *c, const Def *arg)
const Def * normalize_bit1(const Def *type, const Def *c, const Def *a)
const Def * normalize_conv(const Def *dst_t, const Def *, const Def *x)
const Def * normalize_bit2(const Def *type, const Def *c, const Def *arg)
const Def * normalize_wrap(const Def *type, const Def *c, const Def *arg)
const Def * normalize_trait(const Def *, const Def *, const Def *type)
const Def * op(trait o, const Def *type)
Definition core.h:33
const Def * normalize_abs(const Def *type, const Def *, const Def *arg)
const Def * normalize_idx(const Def *type, const Def *c, const Def *arg)
constexpr std::array< std::array< u64, 2 >, 2 > make_truth_table(bit2 id)
Definition core.h:50
const Def * normalize_bitcast(const Def *dst_t, const Def *, const Def *src)
const Def * normalize_ncmp(const Def *type, const Def *callee, const Def *arg)
@ nuw
No Unsigned Wrap around.
Definition core.h:16
@ none
Wrap around.
Definition core.h:14
@ nsw
No Signed Wrap around.
Definition core.h:15
const Def * normalize_shr(const Def *type, const Def *c, const Def *arg)
std::optional< nat_t > isa_f(const Def *def)
Definition math.h:76
The mem Plugin
Definition mem.h:11
View< const Def * > Defs
Definition def.h:49
u8 sub_t
Definition types.h:48
D bitcast(const S &src)
A bitcast from src of type S to D.
Definition util.h:22
bool commute(const Def *a, const Def *b)
Swap Lit to left - or smaller Def::gid, if no lit present.
Definition normalize.h:27
TBound< true > Join
AKA union.
Definition lattice.h:161
u64 pad(u64 offset, u64 align)
Definition util.h:45
constexpr bool is_commutative(Id)
Definition axm.h:146
typename detail::w2s_< w >::type w2s
Definition types.h:73
constexpr bool is_associative(Id id)
Definition axm.h:149
typename detail::w2u_< w >::type w2u
Definition types.h:72
TExt< false > Bot
Definition lattice.h:158
uint64_t u64
Definition types.h:34
bool get_sign(T val)
Definition util.h:36
uint8_t u8
Definition types.h:34
TBound< false > Meet
AKA intersection.
Definition lattice.h:160
@ App
Definition def.h:85
@ Lit
Definition def.h:85
CODE(node, name, _)
Definition def.h:84
static constexpr size_t Num
Definition plugin.h:115
static constexpr flags_t Base
Definition plugin.h:118
#define MIM_1_8_16_32_64(m)
Definition types.h:24