MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
normalizers.cpp
Go to the documentation of this file.
1#include <mim/normalize.h>
2
4
5namespace mim::plug::math {
6
7namespace {
8
9// clang-format off
10template<class Id, Id id, nat_t w>
11Res fold(u64 a) {
12 using T = w2f<w>;
13 auto x = bitcast<T>(a);
14 if constexpr (std::is_same_v<Id, tri>) {
15 if constexpr (false) {}
16 else if constexpr (id == tri:: sin ) return sin (x);
17 else if constexpr (id == tri:: cos ) return cos (x);
18 else if constexpr (id == tri:: tan ) return tan (x);
19 else if constexpr (id == tri:: sinh) return sinh(x);
20 else if constexpr (id == tri:: cosh) return cosh(x);
21 else if constexpr (id == tri:: tanh) return tanh(x);
22 else if constexpr (id == tri::asin ) return asin (x);
23 else if constexpr (id == tri::acos ) return acos (x);
24 else if constexpr (id == tri::atan ) return atan (x);
25 else if constexpr (id == tri::asinh) return asinh(x);
26 else if constexpr (id == tri::acosh) return acosh(x);
27 else if constexpr (id == tri::atanh) return atanh(x);
28 else fe::unreachable();
29 } else if constexpr (std::is_same_v<Id, rt>) {
30 if constexpr (false) {}
31 else if constexpr (id == rt::sq) return std::sqrt(x);
32 else if constexpr (id == rt::cb) return std::cbrt(x);
33 else []<bool flag = false>() { static_assert(flag, "missing sub tag"); }();
34 } else if constexpr (std::is_same_v<Id, exp>) {
35 if constexpr (false) {}
36 else if constexpr (id == exp::exp ) return std::exp (x);
37 else if constexpr (id == exp::exp2 ) return std::exp2 (x);
38 else if constexpr (id == exp::exp10) return std::pow(T(10), x);
39 else if constexpr (id == exp::log ) return std::log (x);
40 else if constexpr (id == exp::log2 ) return std::log2 (x);
41 else if constexpr (id == exp::log10) return std::log10(x);
42 else fe::unreachable();
43 } else if constexpr (std::is_same_v<Id, er>) {
44 if constexpr (false) {}
45 else if constexpr (id == er::f ) return std::erf (x);
46 else if constexpr (id == er::fc) return std::erfc(x);
47 else []<bool flag = false>() { static_assert(flag, "missing sub tag"); }();
48 } else if constexpr (std::is_same_v<Id, gamma>) {
49 if constexpr (false) {}
50 else if constexpr (id == gamma::t) return std::tgamma(x);
51 else if constexpr (id == gamma::l) return std::lgamma(x);
52 else []<bool flag = false>() { static_assert(flag, "missing sub tag"); }();
53 } else if constexpr (std::is_same_v<Id, round>) {
54 if constexpr (false) {}
55 else if constexpr (id == round::f) return std::floor (x);
56 else if constexpr (id == round::c) return std::ceil (x);
57 else if constexpr (id == round::r) return std::round (x);
58 else if constexpr (id == round::t) return std::trunc (x);
59 else []<bool flag = false>() { static_assert(flag, "missing sub tag"); }();
60 } else {
61 []<bool flag = false>() { static_assert(flag, "missing tag"); }();
62 }
63}
64
65template<class Id, nat_t w>
66Res fold(u64 a) {
67 using T = w2f<w>;
68 auto x = bitcast<T>(a);
69 if constexpr (std::is_same_v<Id, abs>) {
70 return std::abs(x);
71 } else {
72 []<bool flag = false>() { static_assert(flag, "missing tag"); }();
73 }
74}
75
76template<class Id>
77const Def* fold(World& world, const Def* type, const Def* a) {
78 if (a->isa<Bot>()) return world.bot(type);
79 auto la = a->isa<Lit>();
80
81 if (la) {
82 nat_t width = *isa_f(a->type());
83 Res res;
84 switch (width) {
85#define CODE(i) \
86 case i: res = fold<Id, i>(la->get()); break;
88#undef CODE
89 default: fe::unreachable();
90 }
91
92 return world.lit(type, *res);
93 }
94
95 return nullptr;
96}
97
98template<class Id, Id id, nat_t w>
99Res fold(u64 a, u64 b) {
100 using T = w2f<w>;
101 auto x = bitcast<T>(a), y = bitcast<T>(b);
102 if constexpr (std::is_same_v<Id, arith>) {
103 if constexpr (false) {}
104 else if constexpr (id == arith::add) return x + y;
105 else if constexpr (id == arith::sub) return x - y;
106 else if constexpr (id == arith::mul) return x * y;
107 else if constexpr (id == arith::div) return x / y;
108 else if constexpr (id == arith::rem) return rem(x, y);
109 else []<bool flag = false>() { static_assert(flag, "missing sub tag"); }();
110 } else if constexpr (std::is_same_v<Id, math::extrema>) {
111 if (x == T(-0.0) && y == T(+0.0)) return (id == extrema::fmin || id == extrema::ieee754min) ? x : y;
112 if (x == T(+0.0) && y == T(-0.0)) return (id == extrema::fmin || id == extrema::ieee754min) ? y : x;
113
114 if constexpr (id == extrema::fmin || id == extrema::fmax) {
115 return id == extrema::fmin ? std::fmin(x, y) : std::fmax(x, y);
116 } else if constexpr (id == extrema::ieee754min || id == extrema::ieee754max) {
117 if (std::isnan(x)) return x;
118 if (std::isnan(y)) return y;
119 return id == extrema::ieee754min ? std::fmin(x, y) : std::fmax(x, y);
120 } else {
121 []<bool flag = false>() { static_assert(flag, "missing sub tag"); }();
122 }
123 } else if constexpr (std::is_same_v<Id, pow>) {
124 return std::pow(a, b);
125 } else if constexpr (std::is_same_v<Id, cmp>) {
126 using std::isunordered;
127 bool res = false;
128 res |= ((id & cmp::u) != cmp::f) && isunordered(x, y);
129 res |= ((id & cmp::g) != cmp::f) && x > y;
130 res |= ((id & cmp::l) != cmp::f) && x < y;
131 res |= ((id & cmp::e) != cmp::f) && x == y;
132 return res;
133 } else {
134 []<bool flag = false>() { static_assert(flag, "missing tag"); }();
135 }
136}
137// clang-format on
138
139template<class Id, Id id>
140const Def* fold(World& world, const Def* type, const Def* a) {
141 if (a->isa<Bot>()) return world.bot(type);
142
143 if (auto la = Lit::isa(a)) {
144 nat_t width = *isa_f(a->type());
145 Res res;
146 switch (width) {
147#define CODE(i) \
148 case i: res = fold<Id, id, i>(*la); break;
150#undef CODE
151 default: fe::unreachable();
152 }
153
154 return world.lit(type, *res);
155 }
156
157 return nullptr;
158}
159
160// Note that @p a and @p b are passed by reference as fold also commutes if possible.
161template<class Id, Id id>
162const Def* fold(World& world, const Def* type, const Def*& a, const Def*& b) {
163 if (a->isa<Bot>() || b->isa<Bot>()) return world.bot(type);
164
165 if (auto la = Lit::isa(a)) {
166 if (auto lb = Lit::isa(b)) {
167 nat_t width = *isa_f(a->type());
168 Res res;
169 switch (width) {
170#define CODE(i) \
171 case i: res = fold<Id, id, i>(*la, *lb); break;
173#undef CODE
174 default: fe::unreachable();
175 }
176
177 return world.lit(type, *res);
178 }
179 }
180
181 if (is_commutative(id) && Def::greater(a, b)) std::swap(a, b);
182 return nullptr;
183}
184
185/// Reassociates @p a und @p b according to following rules.
186/// We use the following naming convention while literals are prefixed with an 'l':
187/// ```
188/// a op b
189/// (x op y) op (z op w)
190///
191/// (1) la op (lz op w) -> (la op lz) op w
192/// (2) (lx op y) op (lz op w) -> (lx op lz) op (y op w)
193/// (3) a op (lz op w) -> lz op (a op w)
194/// (4) (lx op y) op b -> lx op (y op b)
195/// ```
196template<class Id>
197const Def* reassociate(Id id, World& world, [[maybe_unused]] const App* ab, const Def* a, const Def* b) {
198 if (!is_associative(id)) return nullptr;
199
200 auto xy = Axm::isa<Id>(id, a);
201 auto zw = Axm::isa<Id>(id, b);
202 auto la = a->isa<Lit>();
203 auto [x, y] = xy ? xy->template args<2>() : std::array<const Def*, 2>{nullptr, nullptr};
204 auto [z, w] = zw ? zw->template args<2>() : std::array<const Def*, 2>{nullptr, nullptr};
205 auto lx = Lit::isa(x);
206 auto lz = Lit::isa(z);
207
208 // build mode for all new ops by using the least upper bound of all involved apps
209 auto mode = (nat_t)Mode::bot;
210 auto check_mode = [&](const App* app) {
211 auto app_m = Lit::isa(app->arg(0));
212 if (!app_m || !(*app_m & Mode::reassoc)) return false;
213 mode &= *app_m; // least upper bound
214 return true;
215 };
216
217 if (!check_mode(ab)) return nullptr;
218 if (lx && !check_mode(xy->decurry())) return nullptr;
219 if (lz && !check_mode(zw->decurry())) return nullptr;
220
221 auto make_op = [&](const Def* a, const Def* b) { return world.call(id, mode, Defs{a, b}); };
222
223 if (la && lz) return make_op(make_op(a, z), w); // (1)
224 if (lx && lz) return make_op(make_op(x, z), make_op(y, w)); // (2)
225 if (lz) return make_op(z, make_op(a, w)); // (3)
226 if (lx) return make_op(x, make_op(y, b)); // (4)
227 return nullptr;
228}
229
230template<class Id, Id id, nat_t sw, nat_t dw>
231Res fold(u64 a) {
232 using S = std::conditional_t<id == conv::s2f, w2s<sw>, std::conditional_t<id == conv::u2f, w2u<sw>, w2f<sw>>>;
233 using D = std::conditional_t<id == conv::f2s, w2s<dw>, std::conditional_t<id == conv::f2u, w2u<dw>, w2f<dw>>>;
234 return D(bitcast<S>(a));
235}
236
237} // namespace
238
239template<arith id>
240const Def* normalize_arith(const Def* type, const Def* c, const Def* arg) {
241 auto& world = type->world();
242 auto callee = c->as<App>();
243 auto [a, b] = arg->projs<2>();
244 auto mode = callee->arg();
245 auto lm = Lit::isa(mode);
246 auto w = isa_f(a->type());
247
248 if (auto result = fold<arith, id>(world, type, a, b)) return result;
249
250 // clang-format off
251 // TODO check mode properly
252 if (lm && *lm == Mode::fast) {
253 if (auto la = a->isa<Lit>()) {
254 if (la == lit_f(world, *w, 0.0)) {
255 switch (id) {
256 case arith::add: return b; // 0 + b -> b
257 case arith::sub: break;
258 case arith::mul: return la; // 0 * b -> 0
259 case arith::div: return la; // 0 / b -> 0
260 case arith::rem: return la; // 0 % b -> 0
261 }
262 }
263
264 if (la == lit_f(world, *w, 1.0)) {
265 switch (id) {
266 case arith::add: break;
267 case arith::sub: break;
268 case arith::mul: return b; // 1 * b -> b
269 case arith::div: break;
270 case arith::rem: break;
271 }
272 }
273 }
274
275 if (auto lb = b->isa<Lit>()) {
276 if (lb == lit_f(world, *w, 0.0)) {
277 switch (id) {
278 case arith::sub: return a; // a - 0 -> a
279 case arith::div: break;
280 case arith::rem: break;
281 default: fe::unreachable();
282 // add, mul are commutative, the literal has been normalized to the left
283 }
284 }
285 }
286
287 if (a == b) {
288 switch (id) {
289 case arith::add: return world.call(arith::mul, mode, Defs{lit_f(world, *w, 2.0), a}); // a + a -> 2 * a
290 case arith::sub: return lit_f(world, *w, 0.0); // a - a -> 0
291 case arith::mul: break;
292 case arith::div: return lit_f(world, *w, 1.0); // a / a -> 1
293 case arith::rem: break;
294 }
295 }
296 }
297 // clang-format on
298
299 if (auto res = reassociate<arith>(id, world, callee, a, b)) return res;
300
301 return world.raw_app(type, callee, {a, b});
302}
303
304template<extrema id>
305const Def* normalize_extrema(const Def* type, const Def* c, const Def* arg) {
306 auto& world = type->world();
307 auto callee = c->as<App>();
308 auto [a, b] = arg->projs<2>();
309 auto m = callee->arg();
310 auto lm = Lit::isa(m);
311 // TODO commute
312
313 if (auto lit = fold<extrema, id>(world, type, a, b)) return lit;
314
315 if (lm && *lm & (Mode::nnan | Mode::nsz)) { // if ignore NaNs and signed zero, then *imum -> *num
316 switch (id) {
317 case extrema::ieee754min: return world.call(extrema::fmin, m, Defs{a, b});
318 case extrema::ieee754max: return world.call(extrema::fmax, m, Defs{a, b});
319 default: break;
320 }
321 }
322
323 return world.raw_app(type, c, {a, b});
324}
325
326template<tri id>
327const Def* normalize_tri(const Def* type, const Def*, const Def* arg) {
328 auto& world = type->world();
329 if (auto lit = fold<tri, id>(world, type, arg)) return lit;
330 return {};
331}
332
333const Def* normalize_pow(const Def* type, const Def*, const Def* arg) {
334 auto& world = type->world();
335 auto [a, b] = arg->projs<2>();
336 if (auto lit = fold<pow, /*dummy*/ pow(0)>(world, type, a, b)) return lit;
337 return {};
338}
339
340template<rt id>
341const Def* normalize_rt(const Def* type, const Def*, const Def* arg) {
342 auto& world = type->world();
343 if (auto lit = fold<rt, id>(world, type, arg)) return lit;
344 return {};
345}
346
347template<exp id>
348const Def* normalize_exp(const Def* type, const Def*, const Def* arg) {
349 auto& world = type->world();
350 if (auto lit = fold<exp, id>(world, type, arg)) return lit;
351 return {};
352}
353
354template<er id>
355const Def* normalize_er(const Def* type, const Def*, const Def* arg) {
356 auto& world = type->world();
357 if (auto lit = fold<er, id>(world, type, arg)) return lit;
358 return {};
359}
360
361template<gamma id>
362const Def* normalize_gamma(const Def* type, const Def*, const Def* arg) {
363 auto& world = type->world();
364 if (auto lit = fold<gamma, id>(world, type, arg)) return lit;
365 return {};
366}
367
368template<cmp id>
369const Def* normalize_cmp(const Def* type, const Def* c, const Def* arg) {
370 auto& world = type->world();
371 auto callee = c->as<App>();
372 auto [a, b] = arg->projs<2>();
373
374 if (auto result = fold<cmp, id>(world, type, a, b)) return result;
375 if (id == cmp::f) return world.lit_ff();
376 if (id == cmp::t) return world.lit_tt();
377
378 return world.raw_app(type, callee, {a, b});
379}
380
381template<conv id>
382const Def* normalize_conv(const Def* dst_t, const Def*, const Def* x) {
383 auto& world = dst_t->world();
384 auto s_t = x->type()->as<App>();
385 auto d_t = dst_t->as<App>();
386 auto s = s_t->arg();
387 auto d = d_t->arg();
388 auto ls = Lit::isa(s);
389 auto ld = Lit::isa(d);
390
391 if (s_t == d_t) return x;
392 if (x->isa<Bot>()) return world.bot(d_t);
393
394 if (auto l = Lit::isa(x); l && ls && ld) {
395 constexpr bool sf = id == conv::f2f || id == conv::f2s || id == conv::f2u;
396 constexpr bool df = id == conv::f2f || id == conv::s2f || id == conv::u2f;
397 constexpr nat_t min_s = sf ? 16 : 1;
398 constexpr nat_t min_d = df ? 16 : 1;
399 auto sw = sf ? isa_f(s_t) : Idx::size2bitwidth(*ls);
400 auto dw = df ? isa_f(d_t) : Idx::size2bitwidth(*ld);
401
402 if (sw && dw) {
403 Res res;
404 // clang-format off
405 if (false) {}
406#define M(S, D) \
407 else if (S == *sw && D == *dw) { \
408 if constexpr (S >= min_s && D >= min_d) \
409 res = fold<conv, id, S, D>(*l); \
410 else \
411 return {}; \
412 }
413 M( 1, 1) M( 1, 8) M( 1, 16) M( 1, 32) M( 1, 64)
414 M( 8, 1) M( 8, 8) M( 8, 16) M( 8, 32) M( 8, 64)
415 M(16, 1) M(16, 8) M(16, 16) M(16, 32) M(16, 64)
416 M(32, 1) M(32, 8) M(32, 16) M(32, 32) M(32, 64)
417 M(64, 1) M(64, 8) M(64, 16) M(64, 32) M(64, 64)
418
419 else fe::unreachable();
420 // clang-format on
421 return world.lit(d_t, *res);
422 }
423 }
424
425 return {};
426}
427
428const Def* normalize_abs(const Def* type, const Def*, const Def* arg) {
429 auto& world = type->world();
430 if (auto lit = fold<abs>(world, type, arg)) return lit;
431 return {};
432}
433
434template<round id>
435const Def* normalize_round(const Def* type, const Def*, const Def* arg) {
436 auto& world = type->world();
437 if (auto lit = fold<round, id>(world, type, arg)) return lit;
438 return {};
439}
440
442
443} // namespace mim::plug::math
const Def * arg() const
Definition lam.h:282
static auto isa(const Def *def)
Definition axm.h:107
Base class for all Defs.
Definition def.h:251
World & world() const noexcept
Definition def.cpp:436
auto projs(F f) const
Splits this Def via Def::projections into an Array (if A == std::dynamic_extent) or std::array (other...
Definition def.h:390
const Def * type() const noexcept
Yields the "raw" type of this Def (maybe nullptr).
Definition def.h:295
static bool greater(const Def *a, const Def *b)
Definition def.cpp:543
static constexpr nat_t size2bitwidth(nat_t n)
Definition def.h:877
static std::optional< T > isa(const Def *def)
Definition def.h:810
Utility class when folding constants in normalizers.
Definition normalize.h:8
#define M(S, D)
#define MIM_math_NORMALIZER_IMPL
Definition autogen.h:369
The math Plugin
Definition math.h:8
const Def * normalize_extrema(const Def *type, const Def *c, const Def *arg)
const Lit * lit_f(World &w, R val)
Definition math.h:89
const Def * normalize_er(const Def *type, const Def *, const Def *arg)
const Def * normalize_cmp(const Def *type, const Def *c, const Def *arg)
const Def * normalize_abs(const Def *type, const Def *, const Def *arg)
const Def * normalize_gamma(const Def *type, const Def *, const Def *arg)
@ fast
All flags.
Definition math.h:35
@ reassoc
Allow reassociation transformations for floating-point operations.
Definition math.h:31
@ nsz
No Signed Zeros.
Definition math.h:23
@ nnan
No NaNs.
Definition math.h:17
@ bot
Alias for Mode::fast.
Definition math.h:38
const Def * mode(World &w, VMode m)
mim::plug::math::VMode -> const Def*.
Definition math.h:46
const Def * normalize_arith(const Def *type, const Def *c, const Def *arg)
const Def * normalize_round(const Def *type, const Def *, const Def *arg)
std::optional< nat_t > isa_f(const Def *def)
Definition math.h:76
const Def * normalize_tri(const Def *type, const Def *, const Def *arg)
const Def * normalize_exp(const Def *type, const Def *, const Def *arg)
const Def * normalize_rt(const Def *type, const Def *, const Def *arg)
const Def * normalize_pow(const Def *type, const Def *, const Def *arg)
const Def * normalize_conv(const Def *dst_t, const Def *, const Def *x)
View< const Def * > Defs
Definition def.h:76
u64 nat_t
Definition types.h:43
D bitcast(const S &src)
A bitcast from src of type S to D.
Definition util.h:23
typename detail::w2f_< w >::type w2f
Definition types.h:74
constexpr bool is_commutative(Id)
Definition axm.h:152
constexpr bool is_associative(Id id)
Definition axm.h:158
TExt< false > Bot
Definition lattice.h:171
uint64_t u64
Definition types.h:34
CODE(node, _)
Definition def.h:113
@ App
Definition def.h:114
@ Lit
Definition def.h:114
#define MIM_16_32_64(m)
Definition types.h:26