MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
normalizers.cpp
Go to the documentation of this file.
1#include <mim/normalize.h>
2
4
5namespace mim::plug::math {
6
7namespace {
8
9// clang-format off
10template<class Id, Id id, nat_t w>
11Res fold(u64 a) {
12 using T = w2f<w>;
13 auto x = bitcast<T>(a);
14 if constexpr (std::is_same_v<Id, tri>) {
15 if constexpr (false) {}
16 else if constexpr (id == tri:: sin ) return sin (x);
17 else if constexpr (id == tri:: cos ) return cos (x);
18 else if constexpr (id == tri:: tan ) return tan (x);
19 else if constexpr (id == tri:: sinh) return sinh(x);
20 else if constexpr (id == tri:: cosh) return cosh(x);
21 else if constexpr (id == tri:: tanh) return tanh(x);
22 else if constexpr (id == tri::asin ) return asin (x);
23 else if constexpr (id == tri::acos ) return acos (x);
24 else if constexpr (id == tri::atan ) return atan (x);
25 else if constexpr (id == tri::asinh) return asinh(x);
26 else if constexpr (id == tri::acosh) return acosh(x);
27 else if constexpr (id == tri::atanh) return atanh(x);
28 else fe::unreachable();
29 } else if constexpr (std::is_same_v<Id, rt>) {
30 if constexpr (false) {}
31 else if constexpr (id == rt::sq) return std::sqrt(x);
32 else if constexpr (id == rt::cb) return std::cbrt(x);
33 else []<bool flag = false>() { static_assert(flag, "missing sub tag"); }();
34 } else if constexpr (std::is_same_v<Id, exp>) {
35 if constexpr (false) {}
36 else if constexpr (id == exp::exp ) return std::exp (x);
37 else if constexpr (id == exp::exp2 ) return std::exp2 (x);
38 else if constexpr (id == exp::exp10) return std::pow(T(10), x);
39 else if constexpr (id == exp::log ) return std::log (x);
40 else if constexpr (id == exp::log2 ) return std::log2 (x);
41 else if constexpr (id == exp::log10) return std::log10(x);
42 else fe::unreachable();
43 } else if constexpr (std::is_same_v<Id, er>) {
44 if constexpr (false) {}
45 else if constexpr (id == er::f ) return std::erf (x);
46 else if constexpr (id == er::fc) return std::erfc(x);
47 else []<bool flag = false>() { static_assert(flag, "missing sub tag"); }();
48 } else if constexpr (std::is_same_v<Id, gamma>) {
49 if constexpr (false) {}
50 else if constexpr (id == gamma::t) return std::tgamma(x);
51 else if constexpr (id == gamma::l) return std::lgamma(x);
52 else []<bool flag = false>() { static_assert(flag, "missing sub tag"); }();
53 } else if constexpr (std::is_same_v<Id, round>) {
54 if constexpr (false) {}
55 else if constexpr (id == round::f) return std::floor (x);
56 else if constexpr (id == round::c) return std::ceil (x);
57 else if constexpr (id == round::r) return std::round (x);
58 else if constexpr (id == round::t) return std::trunc (x);
59 else []<bool flag = false>() { static_assert(flag, "missing sub tag"); }();
60 } else {
61 []<bool flag = false>() { static_assert(flag, "missing tag"); }();
62 }
63}
64
65template<class Id, nat_t w>
66Res fold(u64 a) {
67 using T = w2f<w>;
68 auto x = bitcast<T>(a);
69 if constexpr (std::is_same_v<Id, abs>) {
70 return std::abs(x);
71 } else {
72 []<bool flag = false>() { static_assert(flag, "missing tag"); }();
73 }
74}
75
76template<class Id>
77const Def* fold(World& world, const Def* type, const Def* a) {
78 if (a->isa<Bot>()) return world.bot(type);
79 auto la = a->isa<Lit>();
80
81 if (la) {
82 nat_t width = *isa_f(a->type());
83 Res res;
84 switch (width) {
85#define CODE(i) \
86 case i: res = fold<Id, i>(la->get()); break;
88#undef CODE
89 default: fe::unreachable();
90 }
91
92 return world.lit(type, *res);
93 }
94
95 return nullptr;
96}
97
98template<class Id, Id id, nat_t w>
99Res fold(u64 a, u64 b) {
100 using T = w2f<w>;
101 auto x = bitcast<T>(a), y = bitcast<T>(b);
102 if constexpr (std::is_same_v<Id, arith>) {
103 if constexpr (false) {}
104 else if constexpr (id == arith::add) return x + y;
105 else if constexpr (id == arith::sub) return x - y;
106 else if constexpr (id == arith::mul) return x * y;
107 else if constexpr (id == arith::div) return x / y;
108 else if constexpr (id == arith::rem) return rem(x, y);
109 else []<bool flag = false>() { static_assert(flag, "missing sub tag"); }();
110 } else if constexpr (std::is_same_v<Id, math::extrema>) {
111 if (x == T(-0.0) && y == T(+0.0)) return (id == extrema::fmin || id == extrema::ieee754min) ? x : y;
112 if (x == T(+0.0) && y == T(-0.0)) return (id == extrema::fmin || id == extrema::ieee754min) ? y : x;
113
114 if constexpr (id == extrema::fmin || id == extrema::fmax) {
115 return id == extrema::fmin ? std::fmin(x, y) : std::fmax(x, y);
116 } else if constexpr (id == extrema::ieee754min || id == extrema::ieee754max) {
117 if (std::isnan(x)) return x;
118 if (std::isnan(y)) return y;
119 return id == extrema::ieee754min ? std::fmin(x, y) : std::fmax(x, y);
120 } else {
121 []<bool flag = false>() { static_assert(flag, "missing sub tag"); }();
122 }
123 } else if constexpr (std::is_same_v<Id, pow>) {
124 return std::pow(a, b);
125 } else if constexpr (std::is_same_v<Id, cmp>) {
126 using std::isunordered;
127 bool res = false;
128 res |= ((id & cmp::u) != cmp::f) && isunordered(x, y);
129 res |= ((id & cmp::g) != cmp::f) && x > y;
130 res |= ((id & cmp::l) != cmp::f) && x < y;
131 res |= ((id & cmp::e) != cmp::f) && x == y;
132 return res;
133 } else {
134 []<bool flag = false>() { static_assert(flag, "missing tag"); }();
135 }
136}
137// clang-format on
138
139template<class Id, Id id> const Def* fold(World& world, const Def* type, const Def* a) {
140 if (a->isa<Bot>()) return world.bot(type);
141
142 if (auto la = Lit::isa(a)) {
143 nat_t width = *isa_f(a->type());
144 Res res;
145 switch (width) {
146#define CODE(i) \
147 case i: res = fold<Id, id, i>(*la); break;
149#undef CODE
150 default: fe::unreachable();
151 }
152
153 return world.lit(type, *res);
154 }
155
156 return nullptr;
157}
158
159// Note that @p a and @p b are passed by reference as fold also commutes if possible.
160template<class Id, Id id> const Def* fold(World& world, const Def* type, const Def*& a, const Def*& b) {
161 if (a->isa<Bot>() || b->isa<Bot>()) return world.bot(type);
162
163 if (auto la = Lit::isa(a)) {
164 if (auto lb = Lit::isa(b)) {
165 nat_t width = *isa_f(a->type());
166 Res res;
167 switch (width) {
168#define CODE(i) \
169 case i: res = fold<Id, id, i>(*la, *lb); break;
171#undef CODE
172 default: fe::unreachable();
173 }
174
175 return world.lit(type, *res);
176 }
177 }
178
179 if (is_commutative(id)) commute(a, b);
180 return nullptr;
181}
182
183/// Reassociates @p a und @p b according to following rules.
184/// We use the following naming convention while literals are prefixed with an 'l':
185/// ```
186/// a op b
187/// (x op y) op (z op w)
188///
189/// (1) la op (lz op w) -> (la op lz) op w
190/// (2) (lx op y) op (lz op w) -> (lx op lz) op (y op w)
191/// (3) a op (lz op w) -> lz op (a op w)
192/// (4) (lx op y) op b -> lx op (y op b)
193/// ```
194template<class Id>
195const Def* reassociate(Id id, World& world, [[maybe_unused]] const App* ab, const Def* a, const Def* b) {
196 if (!is_associative(id)) return nullptr;
197
198 auto xy = match<Id>(id, a);
199 auto zw = match<Id>(id, b);
200 auto la = a->isa<Lit>();
201 auto [x, y] = xy ? xy->template args<2>() : std::array<const Def*, 2>{nullptr, nullptr};
202 auto [z, w] = zw ? zw->template args<2>() : std::array<const Def*, 2>{nullptr, nullptr};
203 auto lx = Lit::isa(x);
204 auto lz = Lit::isa(z);
205
206 // build mode for all new ops by using the least upper bound of all involved apps
207 auto mode = (nat_t)Mode::bot;
208 auto check_mode = [&](const App* app) {
209 auto app_m = Lit::isa(app->arg(0));
210 if (!app_m || !(*app_m & Mode::reassoc)) return false;
211 mode &= *app_m; // least upper bound
212 return true;
213 };
214
215 if (!check_mode(ab)) return nullptr;
216 if (lx && !check_mode(xy->decurry())) return nullptr;
217 if (lz && !check_mode(zw->decurry())) return nullptr;
218
219 auto make_op = [&](const Def* a, const Def* b) { return world.call(id, mode, Defs{a, b}); };
220
221 if (la && lz) return make_op(make_op(a, z), w); // (1)
222 if (lx && lz) return make_op(make_op(x, z), make_op(y, w)); // (2)
223 if (lz) return make_op(z, make_op(a, w)); // (3)
224 if (lx) return make_op(x, make_op(y, b)); // (4)
225 return nullptr;
226}
227
228template<class Id, Id id, nat_t sw, nat_t dw> Res fold(u64 a) {
229 using S = std::conditional_t<id == conv::s2f, w2s<sw>, std::conditional_t<id == conv::u2f, w2u<sw>, w2f<sw>>>;
230 using D = std::conditional_t<id == conv::f2s, w2s<dw>, std::conditional_t<id == conv::f2u, w2u<dw>, w2f<dw>>>;
231 return D(bitcast<S>(a));
232}
233
234} // namespace
235
236template<arith id> const Def* normalize_arith(const Def* type, const Def* c, const Def* arg) {
237 auto& world = type->world();
238 auto callee = c->as<App>();
239 auto [a, b] = arg->projs<2>();
240 auto mode = callee->arg();
241 auto lm = Lit::isa(mode);
242 auto w = isa_f(a->type());
243
244 if (auto result = fold<arith, id>(world, type, a, b)) return result;
245
246 // clang-format off
247 // TODO check mode properly
248 if (lm && *lm == Mode::fast) {
249 if (auto la = a->isa<Lit>()) {
250 if (la == lit_f(world, *w, 0.0)) {
251 switch (id) {
252 case arith::add: return b; // 0 + b -> b
253 case arith::sub: break;
254 case arith::mul: return la; // 0 * b -> 0
255 case arith::div: return la; // 0 / b -> 0
256 case arith::rem: return la; // 0 % b -> 0
257 }
258 }
259
260 if (la == lit_f(world, *w, 1.0)) {
261 switch (id) {
262 case arith::add: break;
263 case arith::sub: break;
264 case arith::mul: return b; // 1 * b -> b
265 case arith::div: break;
266 case arith::rem: break;
267 }
268 }
269 }
270
271 if (auto lb = b->isa<Lit>()) {
272 if (lb == lit_f(world, *w, 0.0)) {
273 switch (id) {
274 case arith::sub: return a; // a - 0 -> a
275 case arith::div: break;
276 case arith::rem: break;
277 default: fe::unreachable();
278 // add, mul are commutative, the literal has been normalized to the left
279 }
280 }
281 }
282
283 if (a == b) {
284 switch (id) {
285 case arith::add: return world.call(arith::mul, mode, Defs{lit_f(world, *w, 2.0), a}); // a + a -> 2 * a
286 case arith::sub: return lit_f(world, *w, 0.0); // a - a -> 0
287 case arith::mul: break;
288 case arith::div: return lit_f(world, *w, 1.0); // a / a -> 1
289 case arith::rem: break;
290 }
291 }
292 }
293 // clang-format on
294
295 if (auto res = reassociate<arith>(id, world, callee, a, b)) return res;
296
297 return world.raw_app(type, callee, {a, b});
298}
299
300template<extrema id> const Def* normalize_extrema(const Def* type, const Def* c, const Def* arg) {
301 auto& world = type->world();
302 auto callee = c->as<App>();
303 auto [a, b] = arg->projs<2>();
304 auto m = callee->arg();
305 auto lm = Lit::isa(m);
306 // TODO commute
307
308 if (auto lit = fold<extrema, id>(world, type, a, b)) return lit;
309
310 if (lm && *lm & (Mode::nnan | Mode::nsz)) { // if ignore NaNs and signed zero, then *imum -> *num
311 switch (id) {
312 case extrema::ieee754min: return world.call(extrema::fmin, m, Defs{a, b});
313 case extrema::ieee754max: return world.call(extrema::fmax, m, Defs{a, b});
314 default: break;
315 }
316 }
317
318 return world.raw_app(type, c, {a, b});
319}
320
321template<tri id> const Def* normalize_tri(const Def* type, const Def*, const Def* arg) {
322 auto& world = type->world();
323 if (auto lit = fold<tri, id>(world, type, arg)) return lit;
324 return {};
325}
326
327const Def* normalize_pow(const Def* type, const Def*, const Def* arg) {
328 auto& world = type->world();
329 auto [a, b] = arg->projs<2>();
330 if (auto lit = fold<pow, /*dummy*/ pow(0)>(world, type, a, b)) return lit;
331 return {};
332}
333
334template<rt id> const Def* normalize_rt(const Def* type, const Def*, const Def* arg) {
335 auto& world = type->world();
336 if (auto lit = fold<rt, id>(world, type, arg)) return lit;
337 return {};
338}
339
340template<exp id> const Def* normalize_exp(const Def* type, const Def*, const Def* arg) {
341 auto& world = type->world();
342 if (auto lit = fold<exp, id>(world, type, arg)) return lit;
343 return {};
344}
345
346template<er id> const Def* normalize_er(const Def* type, const Def*, const Def* arg) {
347 auto& world = type->world();
348 if (auto lit = fold<er, id>(world, type, arg)) return lit;
349 return {};
350}
351
352template<gamma id> const Def* normalize_gamma(const Def* type, const Def*, const Def* arg) {
353 auto& world = type->world();
354 if (auto lit = fold<gamma, id>(world, type, arg)) return lit;
355 return {};
356}
357
358template<cmp id> const Def* normalize_cmp(const Def* type, const Def* c, const Def* arg) {
359 auto& world = type->world();
360 auto callee = c->as<App>();
361 auto [a, b] = arg->projs<2>();
362
363 if (auto result = fold<cmp, id>(world, type, a, b)) return result;
364 if (id == cmp::f) return world.lit_ff();
365 if (id == cmp::t) return world.lit_tt();
366
367 return world.raw_app(type, callee, {a, b});
368}
369
370template<conv id> const Def* normalize_conv(const Def* dst_t, const Def*, const Def* x) {
371 auto& world = dst_t->world();
372 auto s_t = x->type()->as<App>();
373 auto d_t = dst_t->as<App>();
374 auto s = s_t->arg();
375 auto d = d_t->arg();
376 auto ls = Lit::isa(s);
377 auto ld = Lit::isa(d);
378
379 if (s_t == d_t) return x;
380 if (x->isa<Bot>()) return world.bot(d_t);
381
382 if (auto l = Lit::isa(x); l && ls && ld) {
383 constexpr bool sf = id == conv::f2f || id == conv::f2s || id == conv::f2u;
384 constexpr bool df = id == conv::f2f || id == conv::s2f || id == conv::u2f;
385 constexpr nat_t min_s = sf ? 16 : 1;
386 constexpr nat_t min_d = df ? 16 : 1;
387 auto sw = sf ? isa_f(s_t) : Idx::size2bitwidth(*ls);
388 auto dw = df ? isa_f(d_t) : Idx::size2bitwidth(*ld);
389
390 if (sw && dw) {
391 Res res;
392 // clang-format off
393 if (false) {}
394#define M(S, D) \
395 else if (S == *sw && D == *dw) { \
396 if constexpr (S >= min_s && D >= min_d) \
397 res = fold<conv, id, S, D>(*l); \
398 else \
399 return {}; \
400 }
401 M( 1, 1) M( 1, 8) M( 1, 16) M( 1, 32) M( 1, 64)
402 M( 8, 1) M( 8, 8) M( 8, 16) M( 8, 32) M( 8, 64)
403 M(16, 1) M(16, 8) M(16, 16) M(16, 32) M(16, 64)
404 M(32, 1) M(32, 8) M(32, 16) M(32, 32) M(32, 64)
405 M(64, 1) M(64, 8) M(64, 16) M(64, 32) M(64, 64)
406
407 else fe::unreachable();
408 // clang-format on
409 return world.lit(d_t, *res);
410 }
411 }
412
413 return {};
414}
415
416const Def* normalize_abs(const Def* type, const Def*, const Def* arg) {
417 auto& world = type->world();
418 if (auto lit = fold<abs>(world, type, arg)) return lit;
419 return {};
420}
421
422template<round id> const Def* normalize_round(const Def* type, const Def*, const Def* arg) {
423 auto& world = type->world();
424 if (auto lit = fold<round, id>(world, type, arg)) return lit;
425 return {};
426}
427
429
430} // namespace mim::plug::math
const Def * arg() const
Definition lam.h:225
Base class for all Defs.
Definition def.h:198
World & world() const noexcept
Definition def.cpp:413
auto projs(F f) const
Splits this Def via Def::projections into an Array (if A == std::dynamic_extent) or std::array (other...
Definition def.h:345
const Def * type() const noexcept
Yields the "raw" type of this Def (maybe nullptr).
Definition def.h:242
static constexpr nat_t size2bitwidth(nat_t n)
Definition def.h:792
static std::optional< T > isa(const Def *def)
Definition def.h:730
Utility class when folding constants in normalizers.
Definition normalize.h:8
#define M(S, D)
#define MIM_math_NORMALIZER_IMPL
Definition autogen.h:369
The math Plugin
Definition math.h:8
const Def * normalize_extrema(const Def *type, const Def *c, const Def *arg)
const Lit * lit_f(World &w, R val)
Definition math.h:89
const Def * normalize_er(const Def *type, const Def *, const Def *arg)
const Def * normalize_cmp(const Def *type, const Def *c, const Def *arg)
const Def * normalize_abs(const Def *type, const Def *, const Def *arg)
const Def * normalize_gamma(const Def *type, const Def *, const Def *arg)
@ fast
All flags.
Definition math.h:35
@ reassoc
Allow reassociation transformations for floating-point operations.
Definition math.h:31
@ nsz
No Signed Zeros.
Definition math.h:23
@ nnan
No NaNs.
Definition math.h:17
@ bot
Alias for Mode::fast.
Definition math.h:38
const Def * mode(World &w, VMode m)
mim::plug::math::VMode -> const Def*.
Definition math.h:46
const Def * normalize_arith(const Def *type, const Def *c, const Def *arg)
const Def * normalize_round(const Def *type, const Def *, const Def *arg)
std::optional< nat_t > isa_f(const Def *def)
Definition math.h:76
const Def * normalize_tri(const Def *type, const Def *, const Def *arg)
const Def * normalize_exp(const Def *type, const Def *, const Def *arg)
const Def * normalize_rt(const Def *type, const Def *, const Def *arg)
const Def * normalize_pow(const Def *type, const Def *, const Def *arg)
const Def * normalize_conv(const Def *dst_t, const Def *, const Def *x)
View< const Def * > Defs
Definition def.h:49
u64 nat_t
Definition types.h:43
D bitcast(const S &src)
A bitcast from src of type S to D.
Definition util.h:23
typename detail::w2f_< w >::type w2f
Definition types.h:74
void commute(const Def *&a, const Def *&b)
Swap Lit to left - or smaller Def::gid, if no lit present.
Definition normalize.h:25
constexpr bool is_commutative(Id)
Definition axiom.h:139
auto match(const Def *def)
Definition axiom.h:112
constexpr bool is_associative(Id id)
Definition axiom.h:142
TExt< false > Bot
Definition lattice.h:177
uint64_t u64
Definition types.h:34
@ App
Definition def.h:85
@ Lit
Definition def.h:85
CODE(node, name, _)
Definition def.h:84
#define MIM_16_32_64(m)
Definition types.h:26