MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
normalizers.cpp
Go to the documentation of this file.
1#include <mim/normalize.h>
2
4#include <mim/plug/mem/mem.h>
5
7
8namespace mim::plug::core {
9
10namespace {
11
12// clang-format off
13// See https://stackoverflow.com/a/64354296 for static_assert trick below.
14template<class Id, Id id, nat_t w>
15Res fold(u64 a, u64 b, [[maybe_unused]] bool nsw, [[maybe_unused]] bool nuw) {
16 using ST = w2s<w>;
17 using UT = w2u<w>;
19 auto u = mim::bitcast<UT>(a), v = mim::bitcast<UT>(b);
20
21 if constexpr (std::is_same_v<Id, wrap>) {
22 if constexpr (id == wrap::add) {
23 auto res = u + v;
24 if (nuw && res < u) return {};
25 // TODO nsw
26 return res;
27 } else if constexpr (id == wrap::sub) {
28 auto res = u - v;
29 // TODO nsw
30 return res;
31 } else if constexpr (id == wrap::mul) {
32 if constexpr (std::is_same_v<UT, bool>)
33 return UT(u & v);
34 else
35 return UT(u * v);
36 } else if constexpr (id == wrap::shl) {
37 if (b >= w) return {};
38 decltype(u) res;
39 if constexpr (std::is_same_v<UT, bool>)
40 res = bool(u64(u) << u64(v));
41 else
42 res = u << v;
43 if (nuw && res < u) return {};
44 if (nsw && get_sign(u) != get_sign(res)) return {};
45 return res;
46 } else {
47 []<bool flag = false>() { static_assert(flag, "missing sub tag"); }();
48 }
49 } else if constexpr (std::is_same_v<Id, shr>) {
50 if (b >= w) return {};
51 if constexpr (false) {}
52 else if constexpr (id == shr::a) return s >> t;
53 else if constexpr (id == shr::l) return u >> v;
54 else []<bool flag = false>() { static_assert(flag, "missing sub tag"); }();
55 } else if constexpr (std::is_same_v<Id, div>) {
56 if (b == 0) return {};
57 if constexpr (false) {}
58 else if constexpr (id == div::sdiv) return s / t;
59 else if constexpr (id == div::udiv) return u / v;
60 else if constexpr (id == div::srem) return s % t;
61 else if constexpr (id == div::urem) return u % v;
62 else []<bool flag = false>() { static_assert(flag, "missing sub tag"); }();
63 } else if constexpr (std::is_same_v<Id, icmp>) {
64 bool res = false;
65 auto pm = !(u >> UT(w - 1)) && (v >> UT(w - 1));
66 auto mp = (u >> UT(w - 1)) && !(v >> UT(w - 1));
67 res |= ((id & icmp::Xygle) != icmp::f) && pm;
68 res |= ((id & icmp::xYgle) != icmp::f) && mp;
69 res |= ((id & icmp::xyGle) != icmp::f) && u > v && !mp;
70 res |= ((id & icmp::xygLe) != icmp::f) && u < v && !pm;
71 res |= ((id & icmp::xyglE) != icmp::f) && u == v;
72 return res;
73 } else if constexpr (std::is_same_v<Id, extrema>) {
74 if constexpr (false) {}
75 else if(id == extrema::sm) return std::min(u, v);
76 else if(id == extrema::Sm) return std::min(s, t);
77 else if(id == extrema::sM) return std::max(u, v);
78 else if(id == extrema::SM) return std::max(s, t);
79 } else {
80 []<bool flag = false>() { static_assert(flag, "missing tag"); }();
81 }
82}
83// clang-format on
84
85// Note that @p a and @p b are passed by reference as fold also commutes if possible.
86template<class Id, Id id> Ref fold(World& world, Ref type, const Def*& a, const Def*& b, Ref mode = {}) {
87 if (a->isa<Bot>() || b->isa<Bot>()) return world.bot(type);
88
89 if (auto la = Lit::isa(a)) {
90 if (auto lb = Lit::isa(b)) {
91 auto size = Lit::as(Idx::size(a->type()));
92 auto width = Idx::size2bitwidth(size);
93 bool nsw = false, nuw = false;
94 if constexpr (std::is_same_v<Id, wrap>) {
95 auto m = mode ? Lit::as(mode) : 0_n;
96 nsw = m & Mode::nsw;
97 nuw = m & Mode::nuw;
98 }
99
100 Res res;
101 switch (width) {
102#define CODE(i) \
103 case i: res = fold<Id, id, i>(*la, *lb, nsw, nuw); break;
105#undef CODE
106 default:
107 // TODO this is super rough but at least better than just bailing out
108 res = fold<Id, id, 64>(*la, *lb, false, false);
109 if (res && !std::is_same_v<Id, icmp>) *res %= size;
110 }
111
112 return res ? world.lit(type, *res) : world.bot(type);
113 }
114 }
115
116 if (::mim::is_commutative(id)) commute(a, b);
117 return nullptr;
118}
119
120template<class Id, nat_t w> Res fold(u64 a, [[maybe_unused]] bool nsw, [[maybe_unused]] bool nuw) {
121 using ST = w2s<w>;
122 auto s = mim::bitcast<ST>(a);
123
124 if constexpr (std::is_same_v<Id, abs>)
125 return std::abs(s);
126 else
127 []<bool flag = false>() { static_assert(flag, "missing tag"); }
128 ();
129}
130
131template<class Id> Ref fold(World& world, Ref type, const Def*& a) {
132 if (a->isa<Bot>()) return world.bot(type);
133
134 if (auto la = Lit::isa(a)) {
135 auto size = Lit::as(Idx::size(a->type()));
136 auto width = Idx::size2bitwidth(size);
137 bool nsw = false, nuw = false;
138 Res res;
139 switch (width) {
140#define CODE(i) \
141 case i: res = fold<Id, i>(*la, nsw, nuw); break;
143#undef CODE
144 default:
145 res = fold<Id, 64>(*la, false, false);
146 if (res && !std::is_same_v<Id, icmp>) *res %= size;
147 }
148
149 return res ? world.lit(type, *res) : world.bot(type);
150 }
151 return nullptr;
152}
153
154/// Reassociates @p a and @p b according to following rules.
155/// We use the following naming convention while literals are prefixed with an `l`:
156/// ```
157/// a op b
158/// (x op y) op (z op w)
159///
160/// (1) la op (lz op w) -> (la op lz) op w
161/// (2) (lx op y) op (lz op w) -> (lx op lz) op (y op w)
162/// (3) a op (lz op w) -> lz op (a op w)
163/// (4) (lx op y) op b -> lx op (y op b)
164/// ```
165template<class Id> Ref reassociate(Id id, World& world, [[maybe_unused]] const App* ab, Ref a, Ref b) {
166 if (!is_associative(id)) return nullptr;
167
168 auto xy = match<Id>(id, a);
169 auto zw = match<Id>(id, b);
170 auto la = a->isa<Lit>();
171 auto [x, y] = xy ? xy->template args<2>() : std::array<const Def*, 2>{nullptr, nullptr};
172 auto [z, w] = zw ? zw->template args<2>() : std::array<const Def*, 2>{nullptr, nullptr};
173 auto lx = Lit::isa(x);
174 auto lz = Lit::isa(z);
175
176 // if we reassociate, we have to forget about nsw/nuw
177 auto make_op = [&world, id](Ref a, Ref b) { return world.call(id, Mode::none, Defs{a, b}); };
178
179 if (la && lz) return make_op(make_op(a, z), w); // (1)
180 if (lx && lz) return make_op(make_op(x, z), make_op(y, w)); // (2)
181 if (lz) return make_op(z, make_op(a, w)); // (3)
182 if (lx) return make_op(x, make_op(y, b)); // (4)
183 return nullptr;
184}
185
186template<class Id> Ref merge_cmps(std::array<std::array<u64, 2>, 2> tab, Ref a, Ref b) {
187 static_assert(sizeof(sub_t) == 1, "if this ever changes, please adjust the logic below");
188 static constexpr size_t num_bits = std::bit_width(Annex::Num<Id> - 1_u64);
189
190 auto& world = a->world();
191 auto a_cmp = match<Id>(a);
192 auto b_cmp = match<Id>(b);
193
194 if (a_cmp && b_cmp && a_cmp->arg() == b_cmp->arg()) {
195 // push sub bits of a_cmp and b_cmp through truth table
196 sub_t res = 0;
197 sub_t a_sub = a_cmp.sub();
198 sub_t b_sub = b_cmp.sub();
199 for (size_t i = 0; i != num_bits; ++i, res >>= 1, a_sub >>= 1, b_sub >>= 1)
200 res |= tab[a_sub & 1][b_sub & 1] << 7_u8;
201 res >>= (7_u8 - u8(num_bits));
202
203 if constexpr (std::is_same_v<Id, math::cmp>)
204 return world.call(math::cmp(res), /*mode*/ a_cmp->decurry()->arg(), a_cmp->arg());
205 else
206 return world.call(icmp(Annex::Base<icmp> | res), a_cmp->arg());
207 }
208
209 return nullptr;
210}
211
212} // namespace
213
214template<nat id> Ref normalize_nat(Ref type, Ref callee, Ref arg) {
215 auto& world = type->world();
216 auto [a, b] = arg->projs<2>();
217 if (is_commutative(id)) commute(a, b);
218 auto la = Lit::isa(a);
219 auto lb = Lit::isa(b);
220
221 if (la) {
222 if (lb) {
223 switch (id) {
224 case nat::add: return world.lit_nat(*la + *lb);
225 case nat::sub: return *la < *lb ? world.lit_nat_0() : world.lit_nat(*la - *lb);
226 case nat::mul: return world.lit_nat(*la * *lb);
227 }
228 }
229
230 if (*la == 0) {
231 switch (id) {
232 case nat::add: return b;
233 case nat::sub: return a; // 0 - b = 0
234 case nat::mul: return a; // 0 * b = 0
235 }
236 }
237
238 if (*la == 1 && id == nat::mul) return b; // 1 * b = b
239 }
240
241 if (lb && *lb == 0 && id == nat::sub) return a; // a - 0 = a
242
243 return world.raw_app(type, callee, {a, b});
244}
245
246template<ncmp id> Ref normalize_ncmp(Ref type, Ref callee, Ref arg) {
247 auto& world = type->world();
248
249 if (id == ncmp::t) return world.lit_tt();
250 if (id == ncmp::f) return world.lit_ff();
251
252 auto [a, b] = arg->projs<2>();
253 if (is_commutative(id)) commute(a, b);
254
255 if (auto la = Lit::isa(a)) {
256 if (auto lb = Lit::isa(b)) {
257 // clang-format off
258 switch (id) {
259 case ncmp:: e: return world.lit_bool(*la == *lb);
260 case ncmp::ne: return world.lit_bool(*la != *lb);
261 case ncmp::l : return world.lit_bool(*la < *lb);
262 case ncmp::le: return world.lit_bool(*la <= *lb);
263 case ncmp::g : return world.lit_bool(*la > *lb);
264 case ncmp::ge: return world.lit_bool(*la >= *lb);
265 default: fe::unreachable();
266 }
267 // clang-format on
268 }
269 }
270
271 return world.raw_app(type, callee, {a, b});
272}
273
274template<icmp id> Ref normalize_icmp(Ref type, Ref c, Ref arg) {
275 auto& world = type->world();
276 auto callee = c->as<App>();
277 auto [a, b] = arg->projs<2>();
278
279 if (auto result = fold<icmp, id>(world, type, a, b)) return result;
280 if (id == icmp::f) return world.lit_ff();
281 if (id == icmp::t) return world.lit_tt();
282 if (a == b) {
283 if (id & (icmp::e & 0xff)) return world.lit_tt();
284 if (id == icmp::ne) return world.lit_ff();
285 }
286
287 return world.raw_app(type, callee, {a, b});
288}
289
290template<extrema id> Ref normalize_extrema(Ref type, Ref c, Ref arg) {
291 auto& world = type->world();
292 auto callee = c->as<App>();
293 auto [a, b] = arg->projs<2>();
294 if (auto result = fold<extrema, id>(world, type, a, b)) return result;
295 return world.raw_app(type, callee, {a, b});
296}
297
298Ref normalize_abs(Ref type, Ref c, Ref arg) {
299 auto& world = type->world();
300 auto callee = c->as<App>();
301 auto [mem, a] = arg->projs<2>();
302 auto [_, actual_type] = type->projs<2>();
303 auto make_res = [&, mem = mem](Ref res) { return world.tuple({mem, res}); };
304
305 if (auto result = fold<abs>(world, actual_type, a)) return make_res(result);
306 return world.raw_app(type, callee, arg);
307}
308
309template<bit1 id> Ref normalize_bit1(Ref type, Ref c, Ref a) {
310 auto& world = type->world();
311 auto callee = c->as<App>();
312 auto s = callee->decurry()->arg();
313 // TODO cope with wrap around
314
315 if constexpr (id == bit1::id) return a;
316
317 if (auto ls = Lit::isa(s)) {
318 switch (id) {
319 case bit1::f: return world.lit_idx(*ls, 0);
320 case bit1::t: return world.lit_idx(*ls, *ls - 1_u64);
321 case bit1::id: fe::unreachable();
322 default: break;
323 }
324
325 assert(id == bit1::neg);
326 if (auto la = Lit::isa(a)) return world.lit_idx_mod(*ls, ~*la);
327 }
328
329 return world.raw_app(type, callee, a);
330}
331
/// Normalizes the binary bitwise operation @p id (`%core.bit2`) on `Idx s`; `s` is the argument of the curried callee.
/// Rewrites are tried in this order:
/// - merging two `icmp`/`math::cmp` comparisons of the same argument;
/// - id-specific rewrites (constants, projections, negations, and swapping operands for the converse implications);
/// - folding two literals when `s` is a literal;
/// - simplifications when both operands coincide or one operand is 0 or `s - 1`;
/// - finally, reassociation.
template<bit2 id> Ref normalize_bit2(Ref type, Ref c, Ref arg) {
    auto& world = type->world();
    auto callee = c->as<App>();
    auto [a, b] = arg->projs<2>();
    auto s = callee->decurry()->arg();
    auto ls = Lit::isa(s);
    // TODO cope with wrap around

    if (is_commutative(id)) commute(a, b);

    // Truth table of id, indexed tab[bit of a][bit of b] (cf. merge_cmps and the `unary` uses below).
    auto tab = make_truth_table(id);
    // id(cmp1(args), cmp2(args)) -> cmp3(args)
    if (auto res = merge_cmps<icmp>(tab, a, b)) return res;
    if (auto res = merge_cmps<math::cmp>(tab, a, b)) return res;

    auto la = Lit::isa(a);
    auto lb = Lit::isa(b);

    // clang-format off
    switch (id) {
        case bit2::    f: return world.lit(type, 0);
        case bit2::    t: if (ls) return world.lit(type, *ls-1_u64); break; // literal s - 1 (all ones for power-of-two s)
        case bit2::  fst: return a;
        case bit2::  snd: return b;
        case bit2:: nfst: return world.call(bit1::neg, s, a);
        case bit2:: nsnd: return world.call(bit1::neg, s, b);
        case bit2:: ciff: return world.call(bit2:: iff, s, Defs{b, a}); // converse: same op with swapped operands
        case bit2::nciff: return world.call(bit2::niff, s, Defs{b, a});
        default:          break;
    }

    // Fold two literals. The ops that complement their result set bits beyond s and therefore use lit_idx_mod.
    if (la && lb && ls) {
        switch (id) {
            case bit2::and_: return world.lit_idx    (*ls, *la & *lb);
            case bit2:: or_: return world.lit_idx    (*ls, *la | *lb);
            case bit2::xor_: return world.lit_idx    (*ls, *la ^ *lb);
            case bit2::nand: return world.lit_idx_mod(*ls, ~(*la & *lb));
            case bit2:: nor: return world.lit_idx_mod(*ls, ~(*la | *lb));
            case bit2::nxor: return world.lit_idx_mod(*ls, ~(*la ^ *lb));
            case bit2:: iff: return world.lit_idx_mod(*ls, ~ *la | *lb);
            case bit2::niff: return world.lit_idx    (*ls, *la & ~*lb);
            default: fe::unreachable();
        }
    }

    // TODO rewrite using bit2
    // Simplifies id once one operand is fixed: x / y are the result bits when the remaining operand @p a is 0 / 1.
    // The result is: 0; `s - 1` (needs a literal s); @p a itself; neg(@p a); or nullptr if no rule applies.
    // NOTE(review): the neg(@p a) rewrite is skipped for xor_ without a stated reason — confirm why.
    auto unary = [&](bool x, bool y, Ref a) -> Ref {
        if (!x && !y) return world.lit(type, 0);
        if ( x &&  y) return ls ? world.lit(type, *ls-1_u64) : nullptr;
        if (!x &&  y) return a;
        if ( x && !y && id != bit2::xor_) return world.call(bit1::neg, s, a);
        return nullptr;
    };
    // clang-format on

    // a id a: only the diagonal of the truth table is relevant.
    if (is_commutative(id) && a == b) {
        if (auto res = unary(tab[0][0], tab[1][1], a)) return res;
    }

    // 0 id b and (s - 1) id b
    if (la) {
        if (*la == 0) {
            if (auto res = unary(tab[0][0], tab[0][1], b)) return res;
        } else if (ls && *la == *ls - 1_u64) {
            if (auto res = unary(tab[1][0], tab[1][1], b)) return res;
        }
    }

    // a id 0 and a id (s - 1)
    if (lb) {
        if (*lb == 0) {
            if (auto res = unary(tab[0][0], tab[1][0], a)) return res;
        } else if (ls && *lb == *ls - 1_u64) {
            if (auto res = unary(tab[0][1], tab[1][1], a)) return res;
        }
    }

    if (auto res = reassociate<bit2>(id, world, callee, a, b)) return res;

    return world.raw_app(type, callee, {a, b});
}
410
411Ref normalize_idx(Ref type, Ref c, Ref arg) {
412 auto& world = type->world();
413 auto callee = c->as<App>();
414 if (auto i = Lit::isa(arg)) {
415 if (auto s = Lit::isa(Idx::size(type))) {
416 if (*i < *s) return world.lit_idx(*s, *i);
417 if (auto m = Lit::isa(callee->arg())) return *m ? world.bot(type) : world.lit_idx_mod(*s, *i);
418 }
419 }
420
421 return world.raw_app(type, c, arg);
422}
423
424template<shr id> Ref normalize_shr(Ref type, Ref c, Ref arg) {
425 auto& world = type->world();
426 auto callee = c->as<App>();
427 auto [a, b] = arg->projs<2>();
428 auto s = Idx::size(arg->type());
429 auto ls = Lit::isa(s);
430
431 if (auto result = fold<shr, id>(world, type, a, b)) return result;
432
433 if (auto la = Lit::isa(a); la && *la == 0) {
434 switch (id) {
435 case shr::a: return a;
436 case shr::l: return a;
437 }
438 }
439
440 if (auto lb = Lit::isa(b)) {
441 if (ls && *lb > *ls) return world.bot(type);
442
443 if (*lb == 0) {
444 switch (id) {
445 case shr::a: return a;
446 case shr::l: return a;
447 }
448 }
449 }
450
451 return world.raw_app(type, callee, {a, b});
452}
453
454template<wrap id> Ref normalize_wrap(Ref type, Ref c, Ref arg) {
455 auto& world = type->world();
456 auto callee = c->as<App>();
457 auto [a, b] = arg->projs<2>();
458 auto mode = callee->arg();
459 auto s = Idx::size(a->type());
460 auto ls = Lit::isa(s);
461
462 if (auto result = fold<wrap, id>(world, type, a, b)) return result;
463
464 // clang-format off
465 if (auto la = Lit::isa(a)) {
466 if (*la == 0) {
467 switch (id) {
468 case wrap::add: return b; // 0 + b -> b
469 case wrap::sub: break;
470 case wrap::mul: return a; // 0 * b -> 0
471 case wrap::shl: return a; // 0 << b -> 0
472 }
473 } else if (*la == 1) {
474 switch (id) {
475 case wrap::add: break;
476 case wrap::sub: break;
477 case wrap::mul: return b; // 1 * b -> b
478 case wrap::shl: break;
479 }
480 }
481 }
482
483 if (auto lb = Lit::isa(b)) {
484 if (*lb == 0) {
485 switch (id) {
486 case wrap::sub: return a; // a - 0 -> a
487 case wrap::shl: return a; // a >> 0 -> a
488 default: fe::unreachable();
489 // add, mul are commutative, the literal has been normalized to the left
490 }
491 }
492
493 if (auto lm = Lit::isa(mode); lm && *lm == 0 && id == wrap::sub)
494 return world.call(wrap::add, mode, Defs{a, world.lit_idx_mod(*ls, ~*lb + 1_u64)}); // a - lb -> a + (~lb + 1)
495 else if (id == wrap::shl && ls && *lb > *ls)
496 return world.bot(type);
497 }
498
499 if (a == b) {
500 switch (id) {
501 case wrap::add: return world.call(wrap::mul, mode, Defs{world.lit(type, 2), a}); // a + a -> 2 * a
502 case wrap::sub: return world.lit(type, 0); // a - a -> 0
503 case wrap::mul: break;
504 case wrap::shl: break;
505 }
506 }
507 // clang-format on
508
509 if (auto res = reassociate<wrap>(id, world, callee, a, b)) return res;
510
511 return world.raw_app(type, callee, {a, b});
512}
513
514template<div id> Ref normalize_div(Ref full_type, Ref c, Ref arg) {
515 auto& world = full_type->world();
516 auto callee = c->as<App>();
517 auto [mem, ab] = arg->projs<2>();
518 auto [a, b] = ab->projs<2>();
519 auto [_, type] = full_type->projs<2>(); // peel off actual type
520 auto make_res = [&, mem = mem](Ref res) { return world.tuple({mem, res}); };
521
522 if (auto result = fold<div, id>(world, type, a, b)) return make_res(result);
523
524 if (auto la = Lit::isa(a)) {
525 if (*la == 0) return make_res(a); // 0 / b -> 0 and 0 % b -> 0
526 }
527
528 if (auto lb = Lit::isa(b)) {
529 if (*lb == 0) return make_res(world.bot(type)); // a / 0 -> ⊥ and a % 0 -> ⊥
530
531 if (*lb == 1) {
532 switch (id) {
533 case div::sdiv: return make_res(a); // a / 1 -> a
534 case div::udiv: return make_res(a); // a / 1 -> a
535 case div::srem: return make_res(world.lit(type, 0)); // a % 1 -> 0
536 case div::urem: return make_res(world.lit(type, 0)); // a % 1 -> 0
537 }
538 }
539 }
540
541 if (a == b) {
542 switch (id) {
543 case div::sdiv: return make_res(world.lit(type, 1)); // a / a -> 1
544 case div::udiv: return make_res(world.lit(type, 1)); // a / a -> 1
545 case div::srem: return make_res(world.lit(type, 0)); // a % a -> 0
546 case div::urem: return make_res(world.lit(type, 0)); // a % a -> 0
547 }
548 }
549
550 return world.raw_app(full_type, callee, arg);
551}
552
553template<conv id> Ref normalize_conv(Ref dst_t, Ref c, Ref x) {
554 auto& world = dst_t->world();
555 auto callee = c->as<App>();
556 auto s_t = x->type()->as<App>();
557 auto d_t = dst_t->as<App>();
558 auto s = s_t->arg();
559 auto d = d_t->arg();
560 auto ls = Lit::isa(s);
561 auto ld = Lit::isa(d);
562
563 if (s_t == d_t) return x;
564 if (x->isa<Bot>()) return world.bot(d_t);
565 if constexpr (id == conv::s) {
566 if (ls && ld && *ld < *ls) return world.call(conv::u, d, x); // just truncate - we don't care for signedness
567 }
568
569 if (auto l = Lit::isa(x); l && ls && ld) {
570 if constexpr (id == conv::u) {
571 if (*ld == 0) return world.lit(d_t, *l); // I64
572 return world.lit(d_t, *l % *ld);
573 }
574
575 auto sw = Idx::size2bitwidth(*ls);
576 auto dw = Idx::size2bitwidth(*ld);
577
578 // clang-format off
579 if (false) {}
580#define M(S, D) \
581 else if (S == sw && D == dw) return world.lit(d_t, w2s<D>(mim::bitcast<w2s<S>>(*l)));
582 M( 1, 8) M( 1, 16) M( 1, 32) M( 1, 64)
583 M( 8, 16) M( 8, 32) M( 8, 64)
584 M(16, 32) M(16, 64)
585 M(32, 64)
586 else assert(false && "TODO: conversion between different Idx sizes");
587 // clang-format on
588 }
589
590 return world.raw_app(dst_t, callee, x);
591}
592
593Ref normalize_bitcast(Ref dst_t, Ref callee, Ref src) {
594 auto& world = dst_t->world();
595 auto src_t = src->type();
596
597 if (src->isa<Bot>()) return world.bot(dst_t);
598 if (src_t == dst_t) return src;
599
600 if (auto other = match<bitcast>(src))
601 return other->arg()->type() == dst_t ? other->arg() : world.call<bitcast>(dst_t, other->arg());
602
603 if (auto l = Lit::isa(src)) {
604 if (dst_t->isa<Nat>()) return world.lit(dst_t, *l);
605 if (Idx::size(dst_t)) return world.lit(dst_t, *l);
606 }
607
608 return world.raw_app(dst_t, callee, src);
609}
610
611// TODO this currently hard-codes x86_64 ABI
612// TODO in contrast to C, we might want to give singleton types like '.Idx 1' or '[]' a size of 0 and simply nuke each
613// and every occurance of these types in a later phase
614// TODO Pi and others
615template<trait id> Ref normalize_trait(Ref nat, Ref callee, Ref type) {
616 auto& world = type->world();
617 if (auto ptr = match<mem::Ptr>(type)) {
618 return world.lit_nat(8);
619 } else if (type->isa<Pi>()) {
620 return world.lit_nat(8); // Gets lowered to function ptr
621 } else if (auto size = Idx::size(type)) {
622 if (auto w = Idx::size2bitwidth(size)) return world.lit_nat(std::max(1_n, std::bit_ceil(*w) / 8_n));
623 } else if (auto w = math::isa_f(type)) {
624 switch (*w) {
625 case 16: return world.lit_nat(2);
626 case 32: return world.lit_nat(4);
627 case 64: return world.lit_nat(8);
628 default: fe::unreachable();
629 }
630 } else if (type->isa<Sigma>() || type->isa<Meet>()) {
631 u64 offset = 0;
632 u64 align = 1;
633 for (auto t : type->ops()) {
634 auto a = Lit::isa(core::op(trait::align, t));
635 auto s = Lit::isa(core::op(trait::size, t));
636 if (!a || !s) goto out;
637
638 align = std::max(align, *a);
639 offset = pad(offset, *a) + *s;
640 }
641
642 offset = pad(offset, align);
643 u64 size = std::max(1_u64, offset);
644
645 switch (id) {
646 case trait::align: return world.lit_nat(align);
647 case trait::size: return world.lit_nat(size);
648 }
649 } else if (auto arr = type->isa_imm<Arr>()) {
650 auto align = op(trait::align, arr->body());
651 if constexpr (id == trait::align) return align;
652 auto b = op(trait::size, arr->body());
653 if (b->isa<Lit>()) return world.call(nat::mul, Defs{arr->shape(), b});
654 } else if (auto join = type->isa<Join>()) {
655 if (auto sigma = convert(join)) return core::op(id, sigma);
656 }
657
658out:
659 return world.raw_app(nat, callee, type);
660}
661
662Ref normalize_zip(Ref type, Ref c, Ref arg) {
663 auto& w = type->world();
664 auto callee = c->as<App>();
665 auto is_os = callee->arg();
666 auto [n_i, Is, n_o, Os, f] = is_os->projs<5>();
667 auto [r, s] = callee->decurry()->args<2>();
668 auto lr = Lit::isa(r);
669 auto ls = Lit::isa(s);
670
671 // TODO commute
672 // TODO reassociate
673 // TODO more than one Os
674 // TODO select which Is/Os to zip
675
676 if (lr && ls && *lr == 1 && *ls == 1) return w.app(f, arg);
677
678 if (auto l_in = Lit::isa(n_i)) {
679 auto args = arg->projs(*l_in);
680
681 if (lr && std::ranges::all_of(args, [](Ref arg) { return arg->isa<Tuple, Pack>(); })) {
682 auto shapes = s->projs(*lr);
683 auto s_n = Lit::isa(shapes.front());
684
685 if (s_n) {
686 auto elems = DefVec(*s_n, [&, f = f](size_t s_i) {
687 auto inner_args = DefVec(args.size(), [&](size_t i) { return args[i]->proj(*s_n, s_i); });
688 if (*lr == 1) {
689 return w.app(f, inner_args);
690 } else {
691 auto app_zip = w.app(w.annex<zip>(), {w.lit_nat(*lr - 1), w.tuple(shapes.view().subspan(1))});
692 return w.app(w.app(app_zip, is_os), inner_args);
693 }
694 });
695 return w.tuple(elems);
696 }
697 }
698 }
699
700 return w.raw_app(type, callee, arg);
701}
702
703template<pe id> Ref normalize_pe(Ref type, Ref callee, Ref arg) {
704 auto& world = type->world();
705
706 if constexpr (id == pe::known) {
707 if (match(pe::hlt, arg)) return world.lit_ff();
708 if (arg->dep_const()) return world.lit_tt();
709 }
710
711 return world.raw_app(type, callee, arg);
712}
713
715
716} // namespace mim::plug::core
const App * decurry() const
Returns App::callee again as App.
Definition lam.h:206
const Def * arg() const
Definition lam.h:214
A (possibly parameterized) Array.
Definition tuple.h:51
World & world() const
Definition def.cpp:417
bool dep_const() const
Definition def.h:335
auto projs(F f) const
Splits this Def via Def::projections into an Array (if A == -1_n) or std::array (otherwise).
Definition def.h:364
auto ops() const
Definition def.h:265
const Def * type() const
Definition def.h:245
const T * isa_imm() const
Definition def.h:438
static Ref size(Ref def)
Checks if def is a .Idx s and returns s or nullptr otherwise.
Definition def.cpp:558
static constexpr nat_t size2bitwidth(nat_t n)
Definition def.h:744
static std::optional< T > isa(Ref def)
Definition def.h:712
static T as(Ref def)
Definition def.h:717
A (possibly parameterized) Tuple.
Definition tuple.h:88
A dependent function type.
Definition lam.h:11
Helper class to retrieve Infer::arg if present.
Definition def.h:85
A dependent tuple type.
Definition tuple.h:9
This is a thin wrapper for std::span<T, N> with the following additional features:
Definition span.h:28
Specific Bound depending on Up.
Definition lattice.h:31
Extremum. Either Top (Up) or Bottom.
Definition lattice.h:130
Data constructor for a Sigma.
Definition tuple.h:40
const Lit * lit_tt()
Definition world.h:427
Ref raw_app(Ref type, Ref callee, Ref arg)
Definition world.cpp:218
#define MIM_core_NORMALIZER_IMPL
Definition autogen.h:380
#define M(S, D)
@ Lit
Definition def.h:39
The core Plugin
Definition core.h:8
Ref normalize_extrema(Ref type, Ref c, Ref arg)
Ref normalize_bit1(Ref type, Ref c, Ref a)
const Sigma * convert(const TBound< up > *b)
Definition core.cpp:19
Ref normalize_bitcast(Ref dst_t, Ref callee, Ref src)
Ref normalize_ncmp(Ref type, Ref callee, Ref arg)
Ref normalize_div(Ref full_type, Ref c, Ref arg)
Ref normalize_nat(Ref type, Ref callee, Ref arg)
Ref normalize_conv(Ref dst_t, Ref c, Ref x)
Ref normalize_trait(Ref nat, Ref callee, Ref type)
Ref normalize_icmp(Ref type, Ref c, Ref arg)
Ref normalize_bit2(Ref type, Ref c, Ref arg)
Ref normalize_wrap(Ref type, Ref c, Ref arg)
Ref normalize_pe(Ref type, Ref callee, Ref arg)
Ref op(trait o, Ref type)
Definition core.h:35
Ref normalize_abs(Ref type, Ref c, Ref arg)
constexpr std::array< std::array< u64, 2 >, 2 > make_truth_table(bit2 id)
Definition core.h:52
Ref normalize_zip(Ref type, Ref c, Ref arg)
Ref normalize_shr(Ref type, Ref c, Ref arg)
@ nuw
No Unsigned Wrap around.
@ none
Wrap around.
@ nsw
No Signed Wrap around.
Ref normalize_idx(Ref type, Ref c, Ref arg)
std::optional< nat_t > isa_f(Ref def)
Definition math.h:78
The mem Plugin
Definition mem.h:10
View< const Def * > Defs
Definition def.h:60
Vector< const Def * > DefVec
Definition def.h:61
u8 sub_t
Definition types.h:48
D bitcast(const S &src)
A bitcast from src of type S to D.
Definition util.h:26
auto match(Ref def)
Definition axiom.h:105
void commute(const Def *&a, const Def *&b)
Swaps a Lit to the left, or, if no Lit is present, the operand with the smaller Def::gid.
Definition normalize.h:25
u64 pad(u64 offset, u64 align)
Definition util.h:49
constexpr bool is_commutative(Id)
Definition axiom.h:132
typename detail::w2s_< w >::type w2s
Definition types.h:73
constexpr bool is_associative(Id id)
Definition axiom.h:135
typename detail::w2u_< w >::type w2u
Definition types.h:72
TExt< false > Bot
Definition lattice.h:151
uint64_t u64
Definition types.h:34
bool get_sign(T val)
Definition util.h:40
uint8_t u8
Definition types.h:34
Definition span.h:104
static constexpr size_t Num
Definition plugin.h:115
static constexpr flags_t Base
Definition plugin.h:118
#define CODE(t, str)
Definition tok.h:53
#define MIM_1_8_16_32_64(m)
Definition types.h:24