MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
normalizers.cpp
Go to the documentation of this file.
1#include <mim/normalize.h>
2
4#include <mim/plug/mem/mem.h>
5
7
8namespace mim::plug::core {
9
10namespace {
11
12// clang-format off
13// See https://stackoverflow.com/a/64354296 for static_assert trick below.
14template<class Id, Id id, nat_t w>
15Res fold(u64 a, u64 b, [[maybe_unused]] bool nsw, [[maybe_unused]] bool nuw) {
16 using ST = w2s<w>;
17 using UT = w2u<w>;
19 auto u = mim::bitcast<UT>(a), v = mim::bitcast<UT>(b);
20
21 if constexpr (std::is_same_v<Id, wrap>) {
22 if constexpr (id == wrap::add) {
23 auto res = u + v;
24 if (nuw && res < u) return {};
25 // TODO nsw
26 return res;
27 } else if constexpr (id == wrap::sub) {
28 auto res = u - v;
29 // TODO nsw
30 return res;
31 } else if constexpr (id == wrap::mul) {
32 if constexpr (std::is_same_v<UT, bool>)
33 return UT(u & v);
34 else
35 return UT(u * v);
36 } else if constexpr (id == wrap::shl) {
37 if (b >= w) return {};
38 decltype(u) res;
39 if constexpr (std::is_same_v<UT, bool>)
40 res = bool(u64(u) << u64(v));
41 else
42 res = u << v;
43 if (nuw && res < u) return {};
44 if (nsw && get_sign(u) != get_sign(res)) return {};
45 return res;
46 } else {
47 []<bool flag = false>() { static_assert(flag, "missing sub tag"); }();
48 }
49 } else if constexpr (std::is_same_v<Id, shr>) {
50 if (b >= w) return {};
51 if constexpr (false) {}
52 else if constexpr (id == shr::a) return s >> t;
53 else if constexpr (id == shr::l) return u >> v;
54 else []<bool flag = false>() { static_assert(flag, "missing sub tag"); }();
55 } else if constexpr (std::is_same_v<Id, div>) {
56 if (b == 0) return {};
57 if constexpr (false) {}
58 else if constexpr (id == div::sdiv) return s / t;
59 else if constexpr (id == div::udiv) return u / v;
60 else if constexpr (id == div::srem) return s % t;
61 else if constexpr (id == div::urem) return u % v;
62 else []<bool flag = false>() { static_assert(flag, "missing sub tag"); }();
63 } else if constexpr (std::is_same_v<Id, icmp>) {
64 bool res = false;
65 auto pm = !(u >> UT(w - 1)) && (v >> UT(w - 1));
66 auto mp = (u >> UT(w - 1)) && !(v >> UT(w - 1));
67 res |= ((id & icmp::Xygle) != icmp::f) && pm;
68 res |= ((id & icmp::xYgle) != icmp::f) && mp;
69 res |= ((id & icmp::xyGle) != icmp::f) && u > v && !mp;
70 res |= ((id & icmp::xygLe) != icmp::f) && u < v && !pm;
71 res |= ((id & icmp::xyglE) != icmp::f) && u == v;
72 return res;
73 } else if constexpr (std::is_same_v<Id, extrema>) {
74 if constexpr (false) {}
75 else if(id == extrema::sm) return std::min(u, v);
76 else if(id == extrema::Sm) return std::min(s, t);
77 else if(id == extrema::sM) return std::max(u, v);
78 else if(id == extrema::SM) return std::max(s, t);
79 } else {
80 []<bool flag = false>() { static_assert(flag, "missing tag"); }();
81 }
82}
83// clang-format on
84
85// Note that @p a and @p b are passed by reference as fold also commutes if possible.
// Tries to constant-fold `id (a, b)`:
// - if either operand is ⊥, the result is ⊥;
// - if both operands are literals, they are folded by the bit-level fold<Id, id, w> for their bit width w; an empty
//   result (an undefined operation) becomes ⊥;
// - otherwise, commutative operations get their operands ordered via commute and nullptr is returned.
// @p mode is only read for Id == wrap; if given, it must be a literal (Lit::as) carrying Mode::nsw / Mode::nuw.
// If omitted, neither flag is assumed.
86template<class Id, Id id> Ref fold(World& world, Ref type, const Def*& a, const Def*& b, Ref mode = {}) {
87 if (a->isa<Bot>() || b->isa<Bot>()) return world.bot(type);
88
89 if (auto la = Lit::isa(a)) {
90 if (auto lb = Lit::isa(b)) {
91 auto size = Lit::as(Idx::isa(a->type()));
92 auto width = Idx::size2bitwidth(size);
93 bool nsw = false, nuw = false;
94 if constexpr (std::is_same_v<Id, wrap>) {
95 auto m = mode ? Lit::as(mode) : 0_n;
96 nsw = m & Mode::nsw;
97 nuw = m & Mode::nuw;
98 }
99
100 Res res;
// The CODE macro generates one case per supported standard bit width (presumably via MIM_1_8_16_32_64).
101 switch (width) {
102#define CODE(i) \
103 case i: res = fold<Id, id, i>(*la, *lb, nsw, nuw); break;
105#undef CODE
106 default:
107 // TODO this is super rough but at least better than just bailing out
// Any other width: fold with 64 bits (ignoring nsw/nuw) and wrap around modulo size.
// icmp yields a Bool, which must not be reduced modulo the operands' size.
108 res = fold<Id, id, 64>(*la, *lb, false, false);
109 if (res && !std::is_same_v<Id, icmp>) *res %= size;
110 }
111
112 return res ? world.lit(type, *res) : world.bot(type);
113 }
114 }
115
116 if (::mim::is_commutative(id)) commute(a, b);
117 return nullptr;
118}
119
120template<class Id, nat_t w> Res fold(u64 a, [[maybe_unused]] bool nsw, [[maybe_unused]] bool nuw) {
121 using ST = w2s<w>;
122 auto s = mim::bitcast<ST>(a);
123
124 if constexpr (std::is_same_v<Id, abs>)
125 return std::abs(s);
126 else
127 []<bool flag = false>() { static_assert(flag, "missing tag"); }
128 ();
129}
130
// Unary counterpart of the binary fold above: returns ⊥ if @p a is ⊥, a literal of @p type if @p a is a literal
// (folded by fold<Id, w> for a's bit width w; an empty result becomes ⊥), and nullptr otherwise.
// nsw/nuw are never set for unary operations.
131template<class Id> Ref fold(World& world, Ref type, const Def*& a) {
132 if (a->isa<Bot>()) return world.bot(type);
133
134 if (auto la = Lit::isa(a)) {
135 auto size = Lit::as(Idx::isa(a->type()));
136 auto width = Idx::size2bitwidth(size);
137 bool nsw = false, nuw = false;
138 Res res;
// The CODE macro generates one case per supported standard bit width (presumably via MIM_1_8_16_32_64).
139 switch (width) {
140#define CODE(i) \
141 case i: res = fold<Id, i>(*la, nsw, nuw); break;
143#undef CODE
144 default:
// Any other width: fold with 64 bits and wrap around modulo size.
// NOTE(review): the icmp exclusion mirrors the binary fold; presumably irrelevant for unary Ids - confirm.
145 res = fold<Id, 64>(*la, false, false);
146 if (res && !std::is_same_v<Id, icmp>) *res %= size;
147 }
148
149 return res ? world.lit(type, *res) : world.bot(type);
150 }
151 return nullptr;
152}
153
154/// Reassociates @p a and @p b according to following rules.
155/// We use the following naming convention while literals are prefixed with an `l`:
156/// ```
157/// a op b
158/// (x op y) op (z op w)
159///
160/// (1) la op (lz op w) -> (la op lz) op w
161/// (2) (lx op y) op (lz op w) -> (lx op lz) op (y op w)
162/// (3) a op (lz op w) -> lz op (a op w)
163/// (4) (lx op y) op b -> lx op (y op b)
164/// ```
165template<class Id> Ref reassociate(Id id, World& world, [[maybe_unused]] const App* ab, Ref a, Ref b) {
166 if (!is_associative(id)) return nullptr;
167
168 auto xy = match<Id>(id, a);
169 auto zw = match<Id>(id, b);
170 auto la = a->isa<Lit>();
171 auto [x, y] = xy ? xy->template args<2>() : std::array<const Def*, 2>{nullptr, nullptr};
172 auto [z, w] = zw ? zw->template args<2>() : std::array<const Def*, 2>{nullptr, nullptr};
173 auto lx = Lit::isa(x);
174 auto lz = Lit::isa(z);
175
176 // if we reassociate, we have to forget about nsw/nuw
177 auto make_op = [&world, id](Ref a, Ref b) { return world.call(id, Mode::none, Defs{a, b}); };
178
179 if (la && lz) return make_op(make_op(a, z), w); // (1)
180 if (lx && lz) return make_op(make_op(x, z), make_op(y, w)); // (2)
181 if (lz) return make_op(z, make_op(a, w)); // (3)
182 if (lx) return make_op(x, make_op(y, b)); // (4)
183 return nullptr;
184}
185
186template<class Id> Ref merge_cmps(std::array<std::array<u64, 2>, 2> tab, Ref a, Ref b) {
187 static_assert(sizeof(sub_t) == 1, "if this ever changes, please adjust the logic below");
188 static constexpr size_t num_bits = std::bit_width(Annex::Num<Id> - 1_u64);
189
190 auto& world = a->world();
191 auto a_cmp = match<Id>(a);
192 auto b_cmp = match<Id>(b);
193
194 if (a_cmp && b_cmp && a_cmp->arg() == b_cmp->arg()) {
195 // push sub bits of a_cmp and b_cmp through truth table
196 sub_t res = 0;
197 sub_t a_sub = a_cmp.sub();
198 sub_t b_sub = b_cmp.sub();
199 for (size_t i = 0; i != num_bits; ++i, res >>= 1, a_sub >>= 1, b_sub >>= 1)
200 res |= tab[a_sub & 1][b_sub & 1] << 7_u8;
201 res >>= (7_u8 - u8(num_bits));
202
203 if constexpr (std::is_same_v<Id, math::cmp>)
204 return world.call(math::cmp(res), /*mode*/ a_cmp->decurry()->arg(), a_cmp->arg());
205 else
206 return world.call(icmp(Annex::Base<icmp> | res), a_cmp->arg());
207 }
208
209 return nullptr;
210}
211
212} // namespace
213
214template<nat id> Ref normalize_nat(Ref type, Ref callee, Ref arg) {
215 auto& world = type->world();
216 auto [a, b] = arg->projs<2>();
217 if (is_commutative(id)) commute(a, b);
218 auto la = Lit::isa(a);
219 auto lb = Lit::isa(b);
220
221 if (la) {
222 if (lb) {
223 switch (id) {
224 case nat::add: return world.lit_nat(*la + *lb);
225 case nat::sub: return *la < *lb ? world.lit_nat_0() : world.lit_nat(*la - *lb);
226 case nat::mul: return world.lit_nat(*la * *lb);
227 }
228 }
229
230 if (*la == 0) {
231 switch (id) {
232 case nat::add: return b;
233 case nat::sub: return a; // 0 - b = 0
234 case nat::mul: return a; // 0 * b = 0
235 }
236 }
237
238 if (*la == 1 && id == nat::mul) return b; // 1 * b = b
239 }
240
241 if (lb && *lb == 0 && id == nat::sub) return a; // a - 0 = a
242
243 return world.raw_app(type, callee, {a, b});
244}
245
246template<ncmp id> Ref normalize_ncmp(Ref type, Ref callee, Ref arg) {
247 auto& world = type->world();
248
249 if (id == ncmp::t) return world.lit_tt();
250 if (id == ncmp::f) return world.lit_ff();
251
252 auto [a, b] = arg->projs<2>();
253 if (is_commutative(id)) commute(a, b);
254
255 if (auto la = Lit::isa(a)) {
256 if (auto lb = Lit::isa(b)) {
257 // clang-format off
258 switch (id) {
259 case ncmp:: e: return world.lit_bool(*la == *lb);
260 case ncmp::ne: return world.lit_bool(*la != *lb);
261 case ncmp::l : return world.lit_bool(*la < *lb);
262 case ncmp::le: return world.lit_bool(*la <= *lb);
263 case ncmp::g : return world.lit_bool(*la > *lb);
264 case ncmp::ge: return world.lit_bool(*la >= *lb);
265 default: fe::unreachable();
266 }
267 // clang-format on
268 }
269 }
270
271 return world.raw_app(type, callee, {a, b});
272}
273
274template<icmp id> Ref normalize_icmp(Ref type, Ref c, Ref arg) {
275 auto& world = type->world();
276 auto callee = c->as<App>();
277 auto [a, b] = arg->projs<2>();
278
279 if (auto result = fold<icmp, id>(world, type, a, b)) return result;
280 if (id == icmp::f) return world.lit_ff();
281 if (id == icmp::t) return world.lit_tt();
282 if (a == b) {
283 if (id & (icmp::e & 0xff)) return world.lit_tt();
284 if (id == icmp::ne) return world.lit_ff();
285 }
286
287 return world.raw_app(type, callee, {a, b});
288}
289
290template<extrema id> Ref normalize_extrema(Ref type, Ref c, Ref arg) {
291 auto& world = type->world();
292 auto callee = c->as<App>();
293 auto [a, b] = arg->projs<2>();
294 if (auto result = fold<extrema, id>(world, type, a, b)) return result;
295 return world.raw_app(type, callee, {a, b});
296}
297
// Body of the abs normalizer (its signature line is not part of this listing).
// arg is a pair (mem, a) and type a pair whose second component is the actual result type; a literal (or ⊥) operand
// is constant-folded via fold<abs> and returned together with the unchanged mem. Otherwise, nothing is normalized.
299 auto& world = type->world();
300 auto [mem, a] = arg->projs<2>();
301 auto [_, actual_type] = type->projs<2>();
// Re-pairs a result with the incoming mem (init-capture, as structured bindings are captured by copy here).
302 auto make_res = [&, mem = mem](Ref res) { return world.tuple({mem, res}); };
303
304 if (auto result = fold<abs>(world, actual_type, a)) return make_res(result);
305 return {};
306}
307
308template<bit1 id> Ref normalize_bit1(Ref type, Ref c, Ref a) {
309 auto& world = type->world();
310 auto callee = c->as<App>();
311 auto s = callee->decurry()->arg();
312 // TODO cope with wrap around
313
314 if constexpr (id == bit1::id) return a;
315
316 if (auto ls = Lit::isa(s)) {
317 switch (id) {
318 case bit1::f: return world.lit_idx(*ls, 0);
319 case bit1::t: return world.lit_idx(*ls, *ls - 1_u64);
320 case bit1::id: fe::unreachable();
321 default: break;
322 }
323
324 assert(id == bit1::neg);
325 if (auto la = Lit::isa(a)) return world.lit_idx_mod(*ls, ~*la);
326 }
327
328 return {};
329}
330
// Normalizes the bitwise operation `id (a, b)` on operands of type `Idx s`, where s is the argument of the curried
// callee's callee. The steps are, in order:
// 1. commute commutative operations;
// 2. merge two icmp/math::cmp comparisons of the same operands (see merge_cmps);
// 3. resolve operations that ignore at least one operand (f, t, fst, snd, nfst, nsnd) and rewrite ciff/nciff to
//    iff/niff with swapped operands;
// 4. constant-fold two literal operands (requires a literal size s);
// 5. simplify `a op a` and `op` with a literal 0 or s - 1 operand via the truth table;
// 6. reassociate.
331template<bit2 id> Ref normalize_bit2(Ref type, Ref c, Ref arg) {
332 auto& world = type->world();
333 auto callee = c->as<App>();
334 auto [a, b] = arg->projs<2>();
335 auto s = callee->decurry()->arg();
336 auto ls = Lit::isa(s);
337 // TODO cope with wrap around
338
339 if (is_commutative(id)) commute(a, b);
340
// tab[x][y] is the result for a bit x of a and a bit y of b (see make_truth_table).
341 auto tab = make_truth_table(id);
342 if (auto res = merge_cmps<icmp>(tab, a, b)) return res;
343 if (auto res = merge_cmps<math::cmp>(tab, a, b)) return res;
344
345 auto la = Lit::isa(a);
346 auto lb = Lit::isa(b);
347
348 // clang-format off
349 switch (id) {
350 case bit2:: f: return world.lit(type, 0);
351 case bit2:: t: if (ls) return world.lit(type, *ls-1_u64); break;
352 case bit2:: fst: return a;
353 case bit2:: snd: return b;
354 case bit2:: nfst: return world.call(bit1::neg, s, a);
355 case bit2:: nsnd: return world.call(bit1::neg, s, b);
356 case bit2:: ciff: return world.call(bit2:: iff, s, Defs{b, a});
357 case bit2::nciff: return world.call(bit2::niff, s, Defs{b, a});
358 default: break;
359 }
360
// Both operands and the size are known: fold. Results that may exceed s are reduced via lit_idx_mod.
361 if (la && lb && ls) {
362 switch (id) {
363 case bit2::and_: return world.lit_idx (*ls, *la & *lb);
364 case bit2:: or_: return world.lit_idx (*ls, *la | *lb);
365 case bit2::xor_: return world.lit_idx (*ls, *la ^ *lb);
366 case bit2::nand: return world.lit_idx_mod(*ls, ~(*la & *lb));
367 case bit2:: nor: return world.lit_idx_mod(*ls, ~(*la | *lb));
368 case bit2::nxor: return world.lit_idx_mod(*ls, ~(*la ^ *lb));
369 case bit2:: iff: return world.lit_idx_mod(*ls, ~ *la | *lb);
370 case bit2::niff: return world.lit_idx (*ls, *la & ~*lb);
371 default: fe::unreachable();
372 }
373 }
374
375 // TODO rewrite using bit2
// unary(x, y, v) simplifies `op` once one operand is fixed: x (y) is the result bit if the corresponding bit of v is
// 0 (1). Hence, the result is 0, s - 1 (only if s is a literal), v, or neg v - or nullptr if nothing applies.
// NOTE(review): xor_ is excluded from the neg rewrite; the reason is not visible here - confirm before relying on it.
376 auto unary = [&](bool x, bool y, Ref a) -> Ref {
377 if (!x && !y) return world.lit(type, 0);
378 if ( x && y) return ls ? world.lit(type, *ls-1_u64) : nullptr;
379 if (!x && y) return a;
380 if ( x && !y && id != bit2::xor_) return world.call(bit1::neg, s, a);
381 return nullptr;
382 };
383 // clang-format on
384
// a op a: both bits are always equal, so only tab[0][0] and tab[1][1] matter.
385 if (is_commutative(id) && a == b) {
386 if (auto res = unary(tab[0][0], tab[1][1], a)) return res;
387 }
388
// A literal 0 fixes all bits to 0; the literal s - 1 is treated as "all ones" (it is all ones if s is a power of two).
389 if (la) {
390 if (*la == 0) {
391 if (auto res = unary(tab[0][0], tab[0][1], b)) return res;
392 } else if (ls && *la == *ls - 1_u64) {
393 if (auto res = unary(tab[1][0], tab[1][1], b)) return res;
394 }
395 }
396
397 if (lb) {
398 if (*lb == 0) {
399 if (auto res = unary(tab[0][0], tab[1][0], a)) return res;
400 } else if (ls && *lb == *ls - 1_u64) {
401 if (auto res = unary(tab[0][1], tab[1][1], a)) return res;
402 }
403 }
404
405 if (auto res = reassociate<bit2>(id, world, callee, a, b)) return res;
406
407 return world.raw_app(type, callee, {a, b});
408}
409
410Ref normalize_idx(Ref type, Ref c, Ref arg) {
411 auto& world = type->world();
412 auto callee = c->as<App>();
413 if (auto i = Lit::isa(arg)) {
414 if (auto s = Lit::isa(Idx::isa(type))) {
415 if (*i < *s) return world.lit_idx(*s, *i);
416 if (auto m = Lit::isa(callee->arg())) return *m ? world.bot(type) : world.lit_idx_mod(*s, *i);
417 }
418 }
419
420 return {};
421}
422
423template<shr id> Ref normalize_shr(Ref type, Ref c, Ref arg) {
424 auto& world = type->world();
425 auto callee = c->as<App>();
426 auto [a, b] = arg->projs<2>();
427 auto s = Idx::isa(arg->type());
428 auto ls = Lit::isa(s);
429
430 if (auto result = fold<shr, id>(world, type, a, b)) return result;
431
432 if (auto la = Lit::isa(a); la && *la == 0) {
433 switch (id) {
434 case shr::a: return a;
435 case shr::l: return a;
436 }
437 }
438
439 if (auto lb = Lit::isa(b)) {
440 if (ls && *lb > *ls) return world.bot(type);
441
442 if (*lb == 0) {
443 switch (id) {
444 case shr::a: return a;
445 case shr::l: return a;
446 }
447 }
448 }
449
450 return world.raw_app(type, callee, {a, b});
451}
452
453template<wrap id> Ref normalize_wrap(Ref type, Ref c, Ref arg) {
454 auto& world = type->world();
455 auto callee = c->as<App>();
456 auto [a, b] = arg->projs<2>();
457 auto mode = callee->arg();
458 auto s = Idx::isa(a->type());
459 auto ls = Lit::isa(s);
460
461 if (auto result = fold<wrap, id>(world, type, a, b)) return result;
462
463 // clang-format off
464 if (auto la = Lit::isa(a)) {
465 if (*la == 0) {
466 switch (id) {
467 case wrap::add: return b; // 0 + b -> b
468 case wrap::sub: break;
469 case wrap::mul: return a; // 0 * b -> 0
470 case wrap::shl: return a; // 0 << b -> 0
471 }
472 } else if (*la == 1) {
473 switch (id) {
474 case wrap::add: break;
475 case wrap::sub: break;
476 case wrap::mul: return b; // 1 * b -> b
477 case wrap::shl: break;
478 }
479 }
480 }
481
482 if (auto lb = Lit::isa(b)) {
483 if (*lb == 0) {
484 switch (id) {
485 case wrap::sub: return a; // a - 0 -> a
486 case wrap::shl: return a; // a >> 0 -> a
487 default: fe::unreachable();
488 // add, mul are commutative, the literal has been normalized to the left
489 }
490 }
491
492 if (auto lm = Lit::isa(mode); lm && *lm == 0 && id == wrap::sub)
493 return world.call(wrap::add, mode, Defs{a, world.lit_idx_mod(*ls, ~*lb + 1_u64)}); // a - lb -> a + (~lb + 1)
494 else if (id == wrap::shl && ls && *lb > *ls)
495 return world.bot(type);
496 }
497
498 if (a == b) {
499 switch (id) {
500 case wrap::add: return world.call(wrap::mul, mode, Defs{world.lit(type, 2), a}); // a + a -> 2 * a
501 case wrap::sub: return world.lit(type, 0); // a - a -> 0
502 case wrap::mul: break;
503 case wrap::shl: break;
504 }
505 }
506 // clang-format on
507
508 if (auto res = reassociate<wrap>(id, world, callee, a, b)) return res;
509
510 return world.raw_app(type, callee, {a, b});
511}
512
513template<div id> Ref normalize_div(Ref full_type, Ref, Ref arg) {
514 auto& world = full_type->world();
515 auto [mem, ab] = arg->projs<2>();
516 auto [a, b] = ab->projs<2>();
517 auto [_, type] = full_type->projs<2>(); // peel off actual type
518 auto make_res = [&, mem = mem](Ref res) { return world.tuple({mem, res}); };
519
520 if (auto result = fold<div, id>(world, type, a, b)) return make_res(result);
521
522 if (auto la = Lit::isa(a)) {
523 if (*la == 0) return make_res(a); // 0 / b -> 0 and 0 % b -> 0
524 }
525
526 if (auto lb = Lit::isa(b)) {
527 if (*lb == 0) return make_res(world.bot(type)); // a / 0 -> ⊥ and a % 0 -> ⊥
528
529 if (*lb == 1) {
530 switch (id) {
531 case div::sdiv: return make_res(a); // a / 1 -> a
532 case div::udiv: return make_res(a); // a / 1 -> a
533 case div::srem: return make_res(world.lit(type, 0)); // a % 1 -> 0
534 case div::urem: return make_res(world.lit(type, 0)); // a % 1 -> 0
535 }
536 }
537 }
538
539 if (a == b) {
540 switch (id) {
541 case div::sdiv: return make_res(world.lit(type, 1)); // a / a -> 1
542 case div::udiv: return make_res(world.lit(type, 1)); // a / a -> 1
543 case div::srem: return make_res(world.lit(type, 0)); // a % a -> 0
544 case div::urem: return make_res(world.lit(type, 0)); // a % a -> 0
545 }
546 }
547
548 return {};
549}
550
551template<conv id> Ref normalize_conv(Ref dst_t, Ref, Ref x) {
552 auto& world = dst_t->world();
553 auto s_t = x->type()->as<App>();
554 auto d_t = dst_t->as<App>();
555 auto s = s_t->arg();
556 auto d = d_t->arg();
557 auto ls = Lit::isa(s);
558 auto ld = Lit::isa(d);
559
560 if (s_t == d_t) return x;
561 if (x->isa<Bot>()) return world.bot(d_t);
562 if constexpr (id == conv::s) {
563 if (ls && ld && *ld < *ls) return world.call(conv::u, d, x); // just truncate - we don't care for signedness
564 }
565
566 if (auto l = Lit::isa(x); l && ls && ld) {
567 if constexpr (id == conv::u) {
568 if (*ld == 0) return world.lit(d_t, *l); // I64
569 return world.lit(d_t, *l % *ld);
570 }
571
572 auto sw = Idx::size2bitwidth(*ls);
573 auto dw = Idx::size2bitwidth(*ld);
574
575 // clang-format off
576 if (false) {}
577#define M(S, D) \
578 else if (S == sw && D == dw) return world.lit(d_t, w2s<D>(mim::bitcast<w2s<S>>(*l)));
579 M( 1, 8) M( 1, 16) M( 1, 32) M( 1, 64)
580 M( 8, 16) M( 8, 32) M( 8, 64)
581 M(16, 32) M(16, 64)
582 M(32, 64)
583 else assert(false && "TODO: conversion between different Idx sizes");
584 // clang-format on
585 }
586
587 return {};
588}
589
// Body of the bitcast normalizer (its signature line is not part of this listing):
// - a bitcast of ⊥ is ⊥ and a bitcast to the source's own type is the identity;
// - bitcast(bitcast(y)) collapses to y if y already has type dst_t, and to a single bitcast of y otherwise;
// - a literal becomes a literal of the Nat or Idx dst_t with the same value.
//   NOTE(review): the value is reused unchanged; if dst_t is a smaller Idx, world.lit may reject it - confirm.
591 auto& world = dst_t->world();
592 auto src_t = src->type();
593
594 if (src->isa<Bot>()) return world.bot(dst_t);
595 if (src_t == dst_t) return src;
596
597 if (auto other = match<bitcast>(src))
598 return other->arg()->type() == dst_t ? *other->arg() : world.call<bitcast>(dst_t, other->arg());
599
600 if (auto l = Lit::isa(src)) {
601 if (dst_t->isa<Nat>()) return world.lit(dst_t, *l);
602 if (Idx::isa(dst_t)) return world.lit(dst_t, *l);
603 }
604
605 return {};
606}
607
608// TODO this currently hard-codes x86_64 ABI
609// TODO in contrast to C, we might want to give singleton types like 'Idx 1' or '[]' a size of 0 and simply nuke each
610// and every occurance of these types in a later phase
611// TODO Pi and others
612template<trait id> Ref normalize_trait(Ref, Ref, Ref type) {
613 auto& world = type->world();
614 if (auto ptr = match<mem::Ptr>(type)) {
615 return world.lit_nat(8);
616 } else if (type->isa<Pi>()) {
617 return world.lit_nat(8); // Gets lowered to function ptr
618 } else if (auto size = Idx::isa(type)) {
619 if (auto w = Idx::size2bitwidth(size)) return world.lit_nat(std::max(1_n, std::bit_ceil(*w) / 8_n));
620 } else if (auto w = math::isa_f(type)) {
621 switch (*w) {
622 case 16: return world.lit_nat(2);
623 case 32: return world.lit_nat(4);
624 case 64: return world.lit_nat(8);
625 default: fe::unreachable();
626 }
627 } else if (type->isa<Sigma>() || type->isa<Meet>()) {
628 u64 offset = 0;
629 u64 align = 1;
630 for (auto t : type->ops()) {
631 auto a = Lit::isa(core::op(trait::align, t));
632 auto s = Lit::isa(core::op(trait::size, t));
633 if (!a || !s) return {};
634
635 align = std::max(align, *a);
636 offset = pad(offset, *a) + *s;
637 }
638
639 offset = pad(offset, align);
640 u64 size = std::max(1_u64, offset);
641
642 switch (id) {
643 case trait::align: return world.lit_nat(align);
644 case trait::size: return world.lit_nat(size);
645 }
646 } else if (auto arr = type->isa_imm<Arr>()) {
647 auto align = op(trait::align, arr->body());
648 if constexpr (id == trait::align) return align;
649 auto b = op(trait::size, arr->body());
650 if (b->isa<Lit>()) return world.call(nat::mul, Defs{arr->shape(), b});
651 } else if (auto join = type->isa<Join>()) {
652 if (auto sigma = convert(join)) return core::op(id, sigma);
653 }
654
655 return {};
656}
657
658Ref normalize_zip(Ref type, Ref c, Ref arg) {
659 auto& w = type->world();
660 auto callee = c->as<App>();
661 auto is_os = callee->arg();
662 auto [n_i, Is, n_o, Os, f] = is_os->projs<5>();
663 auto [r, s] = callee->decurry()->args<2>();
664 auto lr = Lit::isa(r);
665 auto ls = Lit::isa(s);
666
667 // TODO commute
668 // TODO reassociate
669 // TODO more than one Os
670 // TODO select which Is/Os to zip
671
672 if (lr && ls && *lr == 1 && *ls == 1) return w.app(f, arg);
673
674 if (auto l_in = Lit::isa(n_i)) {
675 auto args = arg->projs(*l_in);
676
677 if (lr && std::ranges::all_of(args, [](Ref arg) { return arg->isa<Tuple, Pack>(); })) {
678 auto shapes = s->projs(*lr);
679 auto s_n = Lit::isa(shapes.front());
680
681 if (s_n) {
682 auto elems = DefVec(*s_n, [&, f = f](size_t s_i) {
683 auto inner_args = DefVec(args.size(), [&](size_t i) { return args[i]->proj(*s_n, s_i); });
684 if (*lr == 1) {
685 return w.app(f, inner_args);
686 } else {
687 auto app_zip = w.app(w.annex<zip>(), {w.lit_nat(*lr - 1), w.tuple(shapes.view().subspan(1))});
688 return w.app(w.app(app_zip, is_os), inner_args);
689 }
690 });
691 return w.tuple(elems);
692 }
693 }
694 }
695
696 return {};
697}
698
699template<pe id> Ref normalize_pe(Ref type, Ref, Ref arg) {
700 auto& world = type->world();
701
702 if constexpr (id == pe::known) {
703 if (match(pe::hlt, arg)) return world.lit_ff();
704 if (arg->dep_const()) return world.lit_tt();
705 }
706
707 return {};
708}
709
711
712} // namespace mim::plug::core
Ref arg() const
Definition lam.h:222
const App * decurry() const
Returns App::callee again as App.
Definition lam.h:214
A (possibly parameterized) Array.
Definition tuple.h:67
Ref type() const
Definition def.h:251
World & world() const
Definition def.cpp:411
bool dep_const() const
Definition def.h:332
auto projs(F f) const
Splits this Def via Def::projections into an Array (if A == std::dynamic_extent) or std::array (other...
Definition def.h:361
auto ops() const
Definition def.h:271
const T * isa_imm() const
Definition def.h:439
static constexpr nat_t size2bitwidth(nat_t n)
Definition def.h:823
static Ref isa(Ref def)
Definition def.cpp:552
static std::optional< T > isa(Ref def)
Definition def.h:762
static T as(Ref def)
Definition def.h:767
A (possibly parameterized) Tuple.
Definition tuple.h:114
A dependent function type.
Definition lam.h:11
Helper class to retrieve Infer::arg if present.
Definition def.h:86
A dependent tuple type.
Definition tuple.h:9
This is a thin wrapper for std::span<T, N> with the following additional features:
Definition span.h:28
Specific Bound depending on Up.
Definition lattice.h:31
Extremum. Either Top (Up) or Bottom.
Definition lattice.h:156
Data constructor for a Sigma.
Definition tuple.h:50
const Lit * lit_tt()
Definition world.h:426
#define MIM_core_NORMALIZER_IMPL
Definition autogen.h:295
#define M(S, D)
@ Lit
Definition def.h:40
The core Plugin
Definition core.h:8
Ref normalize_extrema(Ref type, Ref c, Ref arg)
Ref normalize_bit1(Ref type, Ref c, Ref a)
Ref normalize_div(Ref full_type, Ref, Ref arg)
const Sigma * convert(const TBound< up > *b)
Definition core.cpp:19
Ref normalize_trait(Ref, Ref, Ref type)
Ref normalize_ncmp(Ref type, Ref callee, Ref arg)
Ref normalize_nat(Ref type, Ref callee, Ref arg)
Ref normalize_icmp(Ref type, Ref c, Ref arg)
Ref normalize_abs(Ref type, Ref, Ref arg)
Ref normalize_bit2(Ref type, Ref c, Ref arg)
Ref normalize_wrap(Ref type, Ref c, Ref arg)
Ref op(trait o, Ref type)
Definition core.h:33
constexpr std::array< std::array< u64, 2 >, 2 > make_truth_table(bit2 id)
Definition core.h:50
Ref normalize_zip(Ref type, Ref c, Ref arg)
Ref normalize_shr(Ref type, Ref c, Ref arg)
Ref normalize_conv(Ref dst_t, Ref, Ref x)
@ nuw
No Unsigned Wrap around.
@ none
Wrap around.
@ nsw
No Signed Wrap around.
Ref normalize_bitcast(Ref dst_t, Ref, Ref src)
Ref normalize_pe(Ref type, Ref, Ref arg)
Ref normalize_idx(Ref type, Ref c, Ref arg)
std::optional< nat_t > isa_f(Ref def)
Definition math.h:76
The mem Plugin
Definition mem.h:11
View< const Def * > Defs
Definition def.h:61
Vector< const Def * > DefVec
Definition def.h:62
u8 sub_t
Definition types.h:48
D bitcast(const S &src)
A bitcast from src of type S to D.
Definition util.h:26
auto match(Ref def)
Definition axiom.h:112
void commute(const Def *&a, const Def *&b)
Swaps a and b so that a Lit ends up on the left, or, if no Lit is present, the Def with the smaller Def::gid.
Definition normalize.h:25
u64 pad(u64 offset, u64 align)
Definition util.h:49
constexpr bool is_commutative(Id)
Definition axiom.h:139
typename detail::w2s_< w >::type w2s
Definition types.h:73
constexpr bool is_associative(Id id)
Definition axiom.h:142
typename detail::w2u_< w >::type w2u
Definition types.h:72
TExt< false > Bot
Definition lattice.h:177
uint64_t u64
Definition types.h:34
bool get_sign(T val)
Definition util.h:40
uint8_t u8
Definition types.h:34
Definition span.h:104
static constexpr size_t Num
Definition plugin.h:115
static constexpr flags_t Base
Definition plugin.h:118
#define CODE(t, str)
Definition tok.h:54
#define MIM_1_8_16_32_64(m)
Definition types.h:24