MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
normalizers.cpp
Go to the documentation of this file.
#include <limits>

#include <mim/normalize.h>

#include <mim/plug/math/math.h>
#include <mim/plug/mem/mem.h>

#include "mim/plug/core/core.h"

8namespace mim::plug::core {
9
10namespace {
11
12// clang-format off
13// See https://stackoverflow.com/a/64354296 for static_assert trick below.
14template<class Id, Id id, nat_t w>
15Res fold(u64 a, u64 b, [[maybe_unused]] bool nsw, [[maybe_unused]] bool nuw) {
16 using ST = w2s<w>;
17 using UT = w2u<w>;
19 auto u = mim::bitcast<UT>(a), v = mim::bitcast<UT>(b);
20
21 if constexpr (std::is_same_v<Id, wrap>) {
22 if constexpr (id == wrap::add) {
23 auto res = u + v;
24 if (nuw && res < u) return {};
25 // TODO nsw
26 return res;
27 } else if constexpr (id == wrap::sub) {
28 auto res = u - v;
29 // TODO nsw
30 return res;
31 } else if constexpr (id == wrap::mul) {
32 if constexpr (std::is_same_v<UT, bool>)
33 return UT(u & v);
34 else
35 return UT(u * v);
36 } else if constexpr (id == wrap::shl) {
37 if (b >= w) return {};
38 decltype(u) res;
39 if constexpr (std::is_same_v<UT, bool>)
40 res = bool(u64(u) << u64(v));
41 else
42 res = u << v;
43 if (nuw && res < u) return {};
44 if (nsw && get_sign(u) != get_sign(res)) return {};
45 return res;
46 } else {
47 []<bool flag = false>() { static_assert(flag, "missing sub tag"); }();
48 }
49 } else if constexpr (std::is_same_v<Id, shr>) {
50 if (b >= w) return {};
51 if constexpr (false) {}
52 else if constexpr (id == shr::a) return s >> t;
53 else if constexpr (id == shr::l) return u >> v;
54 else []<bool flag = false>() { static_assert(flag, "missing sub tag"); }();
55 } else if constexpr (std::is_same_v<Id, div>) {
56 if (b == 0) return {};
57 if constexpr (false) {}
58 else if constexpr (id == div::sdiv) return s / t;
59 else if constexpr (id == div::udiv) return u / v;
60 else if constexpr (id == div::srem) return s % t;
61 else if constexpr (id == div::urem) return u % v;
62 else []<bool flag = false>() { static_assert(flag, "missing sub tag"); }();
63 } else if constexpr (std::is_same_v<Id, icmp>) {
64 bool res = false;
65 auto pm = !(u >> UT(w - 1)) && (v >> UT(w - 1));
66 auto mp = (u >> UT(w - 1)) && !(v >> UT(w - 1));
67 res |= ((id & icmp::Xygle) != icmp::f) && pm;
68 res |= ((id & icmp::xYgle) != icmp::f) && mp;
69 res |= ((id & icmp::xyGle) != icmp::f) && u > v && !mp;
70 res |= ((id & icmp::xygLe) != icmp::f) && u < v && !pm;
71 res |= ((id & icmp::xyglE) != icmp::f) && u == v;
72 return res;
73 } else if constexpr (std::is_same_v<Id, extrema>) {
74 if constexpr (false) {}
75 else if(id == extrema::sm) return std::min(u, v);
76 else if(id == extrema::Sm) return std::min(s, t);
77 else if(id == extrema::sM) return std::max(u, v);
78 else if(id == extrema::SM) return std::max(s, t);
79 } else {
80 []<bool flag = false>() { static_assert(flag, "missing tag"); }();
81 }
82}
83// clang-format on
84
85// Note that @p a and @p b are passed by reference as fold also commutes if possible.
86template<class Id, Id id>
87const Def* fold(World& world, const Def* type, const Def*& a, const Def*& b, const Def* mode = {}) {
88 if (a->isa<Bot>() || b->isa<Bot>()) return world.bot(type);
89
90 if (auto la = Lit::isa(a)) {
91 if (auto lb = Lit::isa(b)) {
92 auto size = Lit::as(Idx::isa(a->type()));
93 auto width = Idx::size2bitwidth(size);
94 bool nsw = false, nuw = false;
95 if constexpr (std::is_same_v<Id, wrap>) {
96 auto m = mode ? Lit::as(mode) : 0_n;
97 nsw = m & Mode::nsw;
98 nuw = m & Mode::nuw;
99 }
100
101 Res res;
102 switch (width) {
103#define CODE(i) \
104 case i: res = fold<Id, id, i>(*la, *lb, nsw, nuw); break;
106#undef CODE
107 default:
108 // TODO this is super rough but at least better than just bailing out
109 res = fold<Id, id, 64>(*la, *lb, false, false);
110 if (res && !std::is_same_v<Id, icmp>) *res %= size;
111 }
112
113 return res ? world.lit(type, *res) : world.bot(type);
114 }
115 }
116
117 if (::mim::is_commutative(id)) commute(a, b);
118 return nullptr;
119}
120
121template<class Id, nat_t w> Res fold(u64 a, [[maybe_unused]] bool nsw, [[maybe_unused]] bool nuw) {
122 using ST = w2s<w>;
123 auto s = mim::bitcast<ST>(a);
124
125 if constexpr (std::is_same_v<Id, abs>)
126 return std::abs(s);
127 else
128 []<bool flag = false>() { static_assert(flag, "missing tag"); }
129 ();
130}
131
132template<class Id> const Def* fold(World& world, const Def* type, const Def*& a) {
133 if (a->isa<Bot>()) return world.bot(type);
134
135 if (auto la = Lit::isa(a)) {
136 auto size = Lit::as(Idx::isa(a->type()));
137 auto width = Idx::size2bitwidth(size);
138 bool nsw = false, nuw = false;
139 Res res;
140 switch (width) {
141#define CODE(i) \
142 case i: res = fold<Id, i>(*la, nsw, nuw); break;
144#undef CODE
145 default:
146 res = fold<Id, 64>(*la, false, false);
147 if (res && !std::is_same_v<Id, icmp>) *res %= size;
148 }
149
150 return res ? world.lit(type, *res) : world.bot(type);
151 }
152 return nullptr;
153}
154
155/// Reassociates @p a and @p b according to following rules.
156/// We use the following naming convention while literals are prefixed with an `l`:
157/// ```
158/// a op b
159/// (x op y) op (z op w)
160///
161/// (1) la op (lz op w) -> (la op lz) op w
162/// (2) (lx op y) op (lz op w) -> (lx op lz) op (y op w)
163/// (3) a op (lz op w) -> lz op (a op w)
164/// (4) (lx op y) op b -> lx op (y op b)
165/// ```
166template<class Id>
167const Def* reassociate(Id id, World& world, [[maybe_unused]] const App* ab, const Def* a, const Def* b) {
168 if (!is_associative(id)) return nullptr;
169
170 auto xy = match<Id>(id, a);
171 auto zw = match<Id>(id, b);
172 auto la = a->isa<Lit>();
173 auto [x, y] = xy ? xy->template args<2>() : std::array<const Def*, 2>{nullptr, nullptr};
174 auto [z, w] = zw ? zw->template args<2>() : std::array<const Def*, 2>{nullptr, nullptr};
175 auto lx = Lit::isa(x);
176 auto lz = Lit::isa(z);
177
178 // if we reassociate, we have to forget about nsw/nuw
179 auto make_op = [&world, id](const Def* a, const Def* b) { return world.call(id, Mode::none, Defs{a, b}); };
180
181 if (la && lz) return make_op(make_op(a, z), w); // (1)
182 if (lx && lz) return make_op(make_op(x, z), make_op(y, w)); // (2)
183 if (lz) return make_op(z, make_op(a, w)); // (3)
184 if (lx) return make_op(x, make_op(y, b)); // (4)
185 return nullptr;
186}
187
188template<class Id> const Def* merge_cmps(std::array<std::array<u64, 2>, 2> tab, const Def* a, const Def* b) {
189 static_assert(sizeof(sub_t) == 1, "if this ever changes, please adjust the logic below");
190 static constexpr size_t num_bits = std::bit_width(Annex::Num<Id> - 1_u64);
191
192 auto& world = a->world();
193 auto a_cmp = match<Id>(a);
194 auto b_cmp = match<Id>(b);
195
196 if (a_cmp && b_cmp && a_cmp->arg() == b_cmp->arg()) {
197 // push sub bits of a_cmp and b_cmp through truth table
198 sub_t res = 0;
199 sub_t a_sub = a_cmp.sub();
200 sub_t b_sub = b_cmp.sub();
201 for (size_t i = 0; i != num_bits; ++i, res >>= 1, a_sub >>= 1, b_sub >>= 1)
202 res |= tab[a_sub & 1][b_sub & 1] << 7_u8;
203 res >>= (7_u8 - u8(num_bits));
204
205 if constexpr (std::is_same_v<Id, math::cmp>)
206 return world.call(math::cmp(res), /*mode*/ a_cmp->decurry()->arg(), a_cmp->arg());
207 else
208 return world.call(icmp(Annex::Base<icmp> | res), a_cmp->arg());
209 }
210
211 return nullptr;
212}
213
214} // namespace
215
216template<nat id> const Def* normalize_nat(const Def* type, const Def* callee, const Def* arg) {
217 auto& world = type->world();
218 auto [a, b] = arg->projs<2>();
219 if (is_commutative(id)) commute(a, b);
220 auto la = Lit::isa(a);
221 auto lb = Lit::isa(b);
222
223 if (la) {
224 if (lb) {
225 switch (id) {
226 case nat::add: return world.lit_nat(*la + *lb);
227 case nat::sub: return *la < *lb ? world.lit_nat_0() : world.lit_nat(*la - *lb);
228 case nat::mul: return world.lit_nat(*la * *lb);
229 }
230 }
231
232 if (*la == 0) {
233 switch (id) {
234 case nat::add: return b;
235 case nat::sub: return a; // 0 - b = 0
236 case nat::mul: return a; // 0 * b = 0
237 }
238 }
239
240 if (*la == 1 && id == nat::mul) return b; // 1 * b = b
241 }
242
243 if (lb && *lb == 0 && id == nat::sub) return a; // a - 0 = a
244
245 return world.raw_app(type, callee, {a, b});
246}
247
248template<ncmp id> const Def* normalize_ncmp(const Def* type, const Def* callee, const Def* arg) {
249 auto& world = type->world();
250
251 if (id == ncmp::t) return world.lit_tt();
252 if (id == ncmp::f) return world.lit_ff();
253
254 auto [a, b] = arg->projs<2>();
255 if (is_commutative(id)) commute(a, b);
256
257 if (auto la = Lit::isa(a)) {
258 if (auto lb = Lit::isa(b)) {
259 // clang-format off
260 switch (id) {
261 case ncmp:: e: return world.lit_bool(*la == *lb);
262 case ncmp::ne: return world.lit_bool(*la != *lb);
263 case ncmp::l : return world.lit_bool(*la < *lb);
264 case ncmp::le: return world.lit_bool(*la <= *lb);
265 case ncmp::g : return world.lit_bool(*la > *lb);
266 case ncmp::ge: return world.lit_bool(*la >= *lb);
267 default: fe::unreachable();
268 }
269 // clang-format on
270 }
271 }
272
273 return world.raw_app(type, callee, {a, b});
274}
275
276template<icmp id> const Def* normalize_icmp(const Def* type, const Def* c, const Def* arg) {
277 auto& world = type->world();
278 auto callee = c->as<App>();
279 auto [a, b] = arg->projs<2>();
280
281 if (auto result = fold<icmp, id>(world, type, a, b)) return result;
282 if (id == icmp::f) return world.lit_ff();
283 if (id == icmp::t) return world.lit_tt();
284 if (a == b) {
285 if (id & (icmp::e & 0xff)) return world.lit_tt();
286 if (id == icmp::ne) return world.lit_ff();
287 }
288
289 return world.raw_app(type, callee, {a, b});
290}
291
292template<extrema id> const Def* normalize_extrema(const Def* type, const Def* c, const Def* arg) {
293 auto& world = type->world();
294 auto callee = c->as<App>();
295 auto [a, b] = arg->projs<2>();
296 if (auto result = fold<extrema, id>(world, type, a, b)) return result;
297 return world.raw_app(type, callee, {a, b});
298}
299
300const Def* normalize_abs(const Def* type, const Def*, const Def* arg) {
301 auto& world = type->world();
302 auto [mem, a] = arg->projs<2>();
303 auto [_, actual_type] = type->projs<2>();
304 auto make_res = [&, mem = mem](const Def* res) { return world.tuple({mem, res}); };
305
306 if (auto result = fold<abs>(world, actual_type, a)) return make_res(result);
307 return {};
308}
309
310template<bit1 id> const Def* normalize_bit1(const Def* type, const Def* c, const Def* a) {
311 auto& world = type->world();
312 auto callee = c->as<App>();
313 auto s = callee->decurry()->arg();
314 // TODO cope with wrap around
315
316 if constexpr (id == bit1::id) return a;
317
318 if (auto ls = Lit::isa(s)) {
319 switch (id) {
320 case bit1::f: return world.lit_idx(*ls, 0);
321 case bit1::t: return world.lit_idx(*ls, *ls - 1_u64);
322 case bit1::id: fe::unreachable();
323 default: break;
324 }
325
326 assert(id == bit1::neg);
327 if (auto la = Lit::isa(a)) return world.lit_idx_mod(*ls, ~*la);
328 }
329
330 return {};
331}
332
333template<bit2 id> const Def* normalize_bit2(const Def* type, const Def* c, const Def* arg) {
334 auto& world = type->world();
335 auto callee = c->as<App>();
336 auto [a, b] = arg->projs<2>();
337 auto s = callee->decurry()->arg();
338 auto ls = Lit::isa(s);
339 // TODO cope with wrap around
340
341 if (is_commutative(id)) commute(a, b);
342
343 auto tab = make_truth_table(id);
344 if (auto res = merge_cmps<icmp>(tab, a, b)) return res;
345 if (auto res = merge_cmps<math::cmp>(tab, a, b)) return res;
346
347 auto la = Lit::isa(a);
348 auto lb = Lit::isa(b);
349
350 // clang-format off
351 switch (id) {
352 case bit2:: f: return world.lit(type, 0);
353 case bit2:: t: if (ls) return world.lit(type, *ls-1_u64); break;
354 case bit2:: fst: return a;
355 case bit2:: snd: return b;
356 case bit2:: nfst: return world.call(bit1::neg, s, a);
357 case bit2:: nsnd: return world.call(bit1::neg, s, b);
358 case bit2:: ciff: return world.call(bit2:: iff, s, Defs{b, a});
359 case bit2::nciff: return world.call(bit2::niff, s, Defs{b, a});
360 default: break;
361 }
362
363 if (la && lb && ls) {
364 switch (id) {
365 case bit2::and_: return world.lit_idx (*ls, *la & *lb);
366 case bit2:: or_: return world.lit_idx (*ls, *la | *lb);
367 case bit2::xor_: return world.lit_idx (*ls, *la ^ *lb);
368 case bit2::nand: return world.lit_idx_mod(*ls, ~(*la & *lb));
369 case bit2:: nor: return world.lit_idx_mod(*ls, ~(*la | *lb));
370 case bit2::nxor: return world.lit_idx_mod(*ls, ~(*la ^ *lb));
371 case bit2:: iff: return world.lit_idx_mod(*ls, ~ *la | *lb);
372 case bit2::niff: return world.lit_idx (*ls, *la & ~*lb);
373 default: fe::unreachable();
374 }
375 }
376
377 // TODO rewrite using bit2
378 auto unary = [&](bool x, bool y, const Def* a) -> const Def* {
379 if (!x && !y) return world.lit(type, 0);
380 if ( x && y) return ls ? world.lit(type, *ls-1_u64) : nullptr;
381 if (!x && y) return a;
382 if ( x && !y && id != bit2::xor_) return world.call(bit1::neg, s, a);
383 return nullptr;
384 };
385 // clang-format on
386
387 if (is_commutative(id) && a == b) {
388 if (auto res = unary(tab[0][0], tab[1][1], a)) return res;
389 }
390
391 if (la) {
392 if (*la == 0) {
393 if (auto res = unary(tab[0][0], tab[0][1], b)) return res;
394 } else if (ls && *la == *ls - 1_u64) {
395 if (auto res = unary(tab[1][0], tab[1][1], b)) return res;
396 }
397 }
398
399 if (lb) {
400 if (*lb == 0) {
401 if (auto res = unary(tab[0][0], tab[1][0], a)) return res;
402 } else if (ls && *lb == *ls - 1_u64) {
403 if (auto res = unary(tab[0][1], tab[1][1], a)) return res;
404 }
405 }
406
407 if (auto res = reassociate<bit2>(id, world, callee, a, b)) return res;
408
409 return world.raw_app(type, callee, {a, b});
410}
411
412const Def* normalize_idx(const Def* type, const Def* c, const Def* arg) {
413 auto& world = type->world();
414 auto callee = c->as<App>();
415 if (auto i = Lit::isa(arg)) {
416 if (auto s = Lit::isa(Idx::isa(type))) {
417 if (*i < *s) return world.lit_idx(*s, *i);
418 if (auto m = Lit::isa(callee->arg())) return *m ? world.bot(type) : world.lit_idx_mod(*s, *i);
419 }
420 }
421
422 return {};
423}
424
425template<shr id> const Def* normalize_shr(const Def* type, const Def* c, const Def* arg) {
426 auto& world = type->world();
427 auto callee = c->as<App>();
428 auto [a, b] = arg->projs<2>();
429 auto s = Idx::isa(arg->type());
430 auto ls = Lit::isa(s);
431
432 if (auto result = fold<shr, id>(world, type, a, b)) return result;
433
434 if (auto la = Lit::isa(a); la && *la == 0) {
435 switch (id) {
436 case shr::a: return a;
437 case shr::l: return a;
438 }
439 }
440
441 if (auto lb = Lit::isa(b)) {
442 if (ls && *lb > *ls) return world.bot(type);
443
444 if (*lb == 0) {
445 switch (id) {
446 case shr::a: return a;
447 case shr::l: return a;
448 }
449 }
450 }
451
452 return world.raw_app(type, callee, {a, b});
453}
454
455template<wrap id> const Def* normalize_wrap(const Def* type, const Def* c, const Def* arg) {
456 auto& world = type->world();
457 auto callee = c->as<App>();
458 auto [a, b] = arg->projs<2>();
459 auto mode = callee->arg();
460 auto s = Idx::isa(a->type());
461 auto ls = Lit::isa(s);
462
463 if (auto result = fold<wrap, id>(world, type, a, b)) return result;
464
465 // clang-format off
466 if (auto la = Lit::isa(a)) {
467 if (*la == 0) {
468 switch (id) {
469 case wrap::add: return b; // 0 + b -> b
470 case wrap::sub: break;
471 case wrap::mul: return a; // 0 * b -> 0
472 case wrap::shl: return a; // 0 << b -> 0
473 }
474 } else if (*la == 1) {
475 switch (id) {
476 case wrap::add: break;
477 case wrap::sub: break;
478 case wrap::mul: return b; // 1 * b -> b
479 case wrap::shl: break;
480 }
481 }
482 }
483
484 if (auto lb = Lit::isa(b)) {
485 if (*lb == 0) {
486 switch (id) {
487 case wrap::sub: return a; // a - 0 -> a
488 case wrap::shl: return a; // a >> 0 -> a
489 default: fe::unreachable();
490 // add, mul are commutative, the literal has been normalized to the left
491 }
492 }
493
494 if (auto lm = Lit::isa(mode); lm && *lm == 0 && id == wrap::sub)
495 return world.call(wrap::add, mode, Defs{a, world.lit_idx_mod(*ls, ~*lb + 1_u64)}); // a - lb -> a + (~lb + 1)
496 else if (id == wrap::shl && ls && *lb > *ls)
497 return world.bot(type);
498 }
499
500 if (a == b) {
501 switch (id) {
502 case wrap::add: return world.call(wrap::mul, mode, Defs{world.lit(type, 2), a}); // a + a -> 2 * a
503 case wrap::sub: return world.lit(type, 0); // a - a -> 0
504 case wrap::mul: break;
505 case wrap::shl: break;
506 }
507 }
508 // clang-format on
509
510 if (auto res = reassociate<wrap>(id, world, callee, a, b)) return res;
511
512 return world.raw_app(type, callee, {a, b});
513}
514
515template<div id> const Def* normalize_div(const Def* full_type, const Def*, const Def* arg) {
516 auto& world = full_type->world();
517 auto [mem, ab] = arg->projs<2>();
518 auto [a, b] = ab->projs<2>();
519 auto [_, type] = full_type->projs<2>(); // peel off actual type
520 auto make_res = [&, mem = mem](const Def* res) { return world.tuple({mem, res}); };
521
522 if (auto result = fold<div, id>(world, type, a, b)) return make_res(result);
523
524 if (auto la = Lit::isa(a)) {
525 if (*la == 0) return make_res(a); // 0 / b -> 0 and 0 % b -> 0
526 }
527
528 if (auto lb = Lit::isa(b)) {
529 if (*lb == 0) return make_res(world.bot(type)); // a / 0 -> ⊥ and a % 0 -> ⊥
530
531 if (*lb == 1) {
532 switch (id) {
533 case div::sdiv: return make_res(a); // a / 1 -> a
534 case div::udiv: return make_res(a); // a / 1 -> a
535 case div::srem: return make_res(world.lit(type, 0)); // a % 1 -> 0
536 case div::urem: return make_res(world.lit(type, 0)); // a % 1 -> 0
537 }
538 }
539 }
540
541 if (a == b) {
542 switch (id) {
543 case div::sdiv: return make_res(world.lit(type, 1)); // a / a -> 1
544 case div::udiv: return make_res(world.lit(type, 1)); // a / a -> 1
545 case div::srem: return make_res(world.lit(type, 0)); // a % a -> 0
546 case div::urem: return make_res(world.lit(type, 0)); // a % a -> 0
547 }
548 }
549
550 return {};
551}
552
553template<conv id> const Def* normalize_conv(const Def* dst_t, const Def*, const Def* x) {
554 auto& world = dst_t->world();
555 auto s_t = x->type()->as<App>();
556 auto d_t = dst_t->as<App>();
557 auto s = s_t->arg();
558 auto d = d_t->arg();
559 auto ls = Lit::isa(s);
560 auto ld = Lit::isa(d);
561
562 if (s_t == d_t) return x;
563 if (x->isa<Bot>()) return world.bot(d_t);
564 if constexpr (id == conv::s) {
565 if (ls && ld && *ld < *ls) return world.call(conv::u, d, x); // just truncate - we don't care for signedness
566 }
567
568 if (auto l = Lit::isa(x); l && ls && ld) {
569 if constexpr (id == conv::u) {
570 if (*ld == 0) return world.lit(d_t, *l); // I64
571 return world.lit(d_t, *l % *ld);
572 }
573
574 auto sw = Idx::size2bitwidth(*ls);
575 auto dw = Idx::size2bitwidth(*ld);
576
577 // clang-format off
578 if (false) {}
579#define M(S, D) \
580 else if (S == sw && D == dw) return world.lit(d_t, w2s<D>(mim::bitcast<w2s<S>>(*l)));
581 M( 1, 8) M( 1, 16) M( 1, 32) M( 1, 64)
582 M( 8, 16) M( 8, 32) M( 8, 64)
583 M(16, 32) M(16, 64)
584 M(32, 64)
585 else assert(false && "TODO: conversion between different Idx sizes");
586 // clang-format on
587 }
588
589 return {};
590}
591
592const Def* normalize_bitcast(const Def* dst_t, const Def*, const Def* src) {
593 auto& world = dst_t->world();
594 auto src_t = src->type();
595
596 if (src->isa<Bot>()) return world.bot(dst_t);
597 if (src_t == dst_t) return src;
598
599 if (auto other = match<bitcast>(src))
600 return other->arg()->type() == dst_t ? other->arg() : world.call<bitcast>(dst_t, other->arg());
601
602 if (auto l = Lit::isa(src)) {
603 if (dst_t->isa<Nat>()) return world.lit(dst_t, *l);
604 if (Idx::isa(dst_t)) return world.lit(dst_t, *l);
605 }
606
607 return {};
608}
609
610// TODO this currently hard-codes x86_64 ABI
611// TODO in contrast to C, we might want to give singleton types like 'Idx 1' or '[]' a size of 0 and simply nuke each
612// and every occurance of these types in a later phase
613// TODO Pi and others
614template<trait id> const Def* normalize_trait(const Def*, const Def*, const Def* type) {
615 auto& world = type->world();
616 if (auto ptr = match<mem::Ptr>(type)) {
617 return world.lit_nat(8);
618 } else if (type->isa<Pi>()) {
619 return world.lit_nat(8); // Gets lowered to function ptr
620 } else if (auto size = Idx::isa(type)) {
621 if (auto w = Idx::size2bitwidth(size)) return world.lit_nat(std::max(1_n, std::bit_ceil(*w) / 8_n));
622 } else if (auto w = math::isa_f(type)) {
623 switch (*w) {
624 case 16: return world.lit_nat(2);
625 case 32: return world.lit_nat(4);
626 case 64: return world.lit_nat(8);
627 default: fe::unreachable();
628 }
629 } else if (type->isa<Sigma>() || type->isa<Meet>()) {
630 u64 offset = 0;
631 u64 align = 1;
632 for (auto t : type->ops()) {
633 auto a = Lit::isa(core::op(trait::align, t));
634 auto s = Lit::isa(core::op(trait::size, t));
635 if (!a || !s) return {};
636
637 align = std::max(align, *a);
638 offset = pad(offset, *a) + *s;
639 }
640
641 offset = pad(offset, align);
642 u64 size = std::max(1_u64, offset);
643
644 switch (id) {
645 case trait::align: return world.lit_nat(align);
646 case trait::size: return world.lit_nat(size);
647 }
648 } else if (auto arr = type->isa_imm<Arr>()) {
649 auto align = op(trait::align, arr->body());
650 if constexpr (id == trait::align) return align;
651 auto b = op(trait::size, arr->body());
652 if (b->isa<Lit>()) return world.call(nat::mul, Defs{arr->shape(), b});
653 } else if (auto join = type->isa<Join>()) {
654 if (auto sigma = convert(join)) return core::op(id, sigma);
655 }
656
657 return {};
658}
659
660const Def* normalize_zip(const Def* type, const Def* c, const Def* arg) {
661 auto& w = type->world();
662 auto callee = c->as<App>();
663 auto is_os = callee->arg();
664 auto [n_i, Is, n_o, Os, f] = is_os->projs<5>();
665 auto [r, s] = callee->decurry()->args<2>();
666 auto lr = Lit::isa(r);
667 auto ls = Lit::isa(s);
668
669 // TODO commute
670 // TODO reassociate
671 // TODO more than one Os
672 // TODO select which Is/Os to zip
673
674 if (lr && ls && *lr == 1 && *ls == 1) return w.app(f, arg);
675
676 if (auto l_in = Lit::isa(n_i)) {
677 auto args = arg->projs(*l_in);
678
679 if (lr && std::ranges::all_of(args, [](const Def* arg) { return arg->isa<Tuple, Pack>(); })) {
680 auto shapes = s->projs(*lr);
681 auto s_n = Lit::isa(shapes.front());
682
683 if (s_n) {
684 auto elems = DefVec(*s_n, [&, f = f](size_t s_i) {
685 auto inner_args = DefVec(args.size(), [&](size_t i) { return args[i]->proj(*s_n, s_i); });
686 if (*lr == 1) {
687 return w.app(f, inner_args);
688 } else {
689 auto app_zip = w.app(w.annex<zip>(), {w.lit_nat(*lr - 1), w.tuple(shapes.view().subspan(1))});
690 return w.app(w.app(app_zip, is_os), inner_args);
691 }
692 });
693 return w.tuple(elems);
694 }
695 }
696 }
697
698 return {};
699}
700
701template<pe id> const Def* normalize_pe(const Def* type, const Def*, const Def* arg) {
702 auto& world = type->world();
703
704 if constexpr (id == pe::is_closed) {
705 if (match(pe::hlt, arg)) return world.lit_ff();
706 if (arg->is_closed()) return world.lit_tt();
707 }
708
709 return {};
710}
711
713
714} // namespace mim::plug::core
const Def * arg() const
Definition lam.h:225
A (possibly parameterized) Array.
Definition tuple.h:68
Base class for all Defs.
Definition def.h:198
World & world() const noexcept
Definition def.cpp:413
constexpr auto ops() const noexcept
Definition def.h:261
auto projs(F f) const
Splits this Def via Def::projections into an Array (if A == std::dynamic_extent) or std::array (other...
Definition def.h:345
const Def * type() const noexcept
Yields the "raw" type of this Def (maybe nullptr).
Definition def.h:242
const T * isa_imm() const
Definition def.h:424
bool is_closed() const
Has no free_vars()?
Definition def.cpp:392
static constexpr nat_t size2bitwidth(nat_t n)
Definition def.h:792
static const Def * isa(const Def *def)
Checks if def is an Idx s and returns s, or nullptr otherwise.
Definition def.cpp:555
static std::optional< T > isa(const Def *def)
Definition def.h:730
static T as(const Def *def)
Definition def.h:735
A (possibly parameterized) Tuple.
Definition tuple.h:115
A dependent function type.
Definition lam.h:11
A dependent tuple type.
Definition tuple.h:9
Data constructor for a Sigma.
Definition tuple.h:50
const Lit * lit_tt()
Definition world.h:423
#define MIM_core_NORMALIZER_IMPL
Definition autogen.h:295
#define M(S, D)
The core Plugin
Definition core.h:8
const Def * normalize_nat(const Def *type, const Def *callee, const Def *arg)
const Sigma * convert(const TBound< up > *b)
Definition core.cpp:19
const Def * normalize_div(const Def *full_type, const Def *, const Def *arg)
const Def * normalize_pe(const Def *type, const Def *, const Def *arg)
const Def * normalize_extrema(const Def *type, const Def *c, const Def *arg)
const Def * normalize_icmp(const Def *type, const Def *c, const Def *arg)
const Def * normalize_bit1(const Def *type, const Def *c, const Def *a)
const Def * normalize_conv(const Def *dst_t, const Def *, const Def *x)
const Def * normalize_bit2(const Def *type, const Def *c, const Def *arg)
const Def * normalize_wrap(const Def *type, const Def *c, const Def *arg)
const Def * normalize_trait(const Def *, const Def *, const Def *type)
const Def * op(trait o, const Def *type)
Definition core.h:33
const Def * normalize_abs(const Def *type, const Def *, const Def *arg)
const Def * normalize_idx(const Def *type, const Def *c, const Def *arg)
constexpr std::array< std::array< u64, 2 >, 2 > make_truth_table(bit2 id)
Definition core.h:50
const Def * normalize_bitcast(const Def *dst_t, const Def *, const Def *src)
const Def * normalize_ncmp(const Def *type, const Def *callee, const Def *arg)
@ nuw
No Unsigned Wrap around.
Definition core.h:16
@ none
Wrap around.
Definition core.h:14
@ nsw
No Signed Wrap around.
Definition core.h:15
const Def * normalize_zip(const Def *type, const Def *c, const Def *arg)
const Def * normalize_shr(const Def *type, const Def *c, const Def *arg)
std::optional< nat_t > isa_f(const Def *def)
Definition math.h:76
The mem Plugin
Definition mem.h:11
View< const Def * > Defs
Definition def.h:49
Vector< const Def * > DefVec
Definition def.h:50
u8 sub_t
Definition types.h:48
D bitcast(const S &src)
A bitcast from src of type S to D.
Definition util.h:23
TBound< true > Join
Definition lattice.h:180
void commute(const Def *&a, const Def *&b)
Swaps a Lit to the left or, if no Lit is present, puts the Def with the smaller Def::gid on the left.
Definition normalize.h:25
u64 pad(u64 offset, u64 align)
Definition util.h:46
constexpr bool is_commutative(Id)
Definition axiom.h:139
typename detail::w2s_< w >::type w2s
Definition types.h:73
auto match(const Def *def)
Definition axiom.h:112
constexpr bool is_associative(Id id)
Definition axiom.h:142
typename detail::w2u_< w >::type w2u
Definition types.h:72
TExt< false > Bot
Definition lattice.h:177
uint64_t u64
Definition types.h:34
bool get_sign(T val)
Definition util.h:37
uint8_t u8
Definition types.h:34
TBound< false > Meet
Definition lattice.h:179
@ App
Definition def.h:85
@ Lit
Definition def.h:85
CODE(node, name, _)
Definition def.h:84
static constexpr size_t Num
Definition plugin.h:115
static constexpr flags_t Base
Definition plugin.h:118
#define MIM_1_8_16_32_64(m)
Definition types.h:24