MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
normalizers.cpp
Go to the documentation of this file.
1#include <mim/normalize.h>
2
4
5namespace mim::plug::math {
6
7namespace {
8
9// clang-format off
10template<class Id, Id id, nat_t w>
11Res fold(u64 a) {
12 using T = w2f<w>;
13 auto x = bitcast<T>(a);
14 if constexpr (std::is_same_v<Id, tri>) {
15 if constexpr (false) {}
16 else if constexpr (id == tri:: sin ) return sin (x);
17 else if constexpr (id == tri:: cos ) return cos (x);
18 else if constexpr (id == tri:: tan ) return tan (x);
19 else if constexpr (id == tri:: sinh) return sinh(x);
20 else if constexpr (id == tri:: cosh) return cosh(x);
21 else if constexpr (id == tri:: tanh) return tanh(x);
22 else if constexpr (id == tri::asin ) return asin (x);
23 else if constexpr (id == tri::acos ) return acos (x);
24 else if constexpr (id == tri::atan ) return atan (x);
25 else if constexpr (id == tri::asinh) return asinh(x);
26 else if constexpr (id == tri::acosh) return acosh(x);
27 else if constexpr (id == tri::atanh) return atanh(x);
28 else fe::unreachable();
29 } else if constexpr (std::is_same_v<Id, rt>) {
30 if constexpr (false) {}
31 else if constexpr (id == rt::sq) return std::sqrt(x);
32 else if constexpr (id == rt::cb) return std::cbrt(x);
33 else []<bool flag = false>() { static_assert(flag, "missing sub tag"); }();
34 } else if constexpr (std::is_same_v<Id, exp>) {
35 if constexpr (false) {}
36 else if constexpr (id == exp::exp ) return std::exp (x);
37 else if constexpr (id == exp::exp2 ) return std::exp2 (x);
38 else if constexpr (id == exp::exp10) return std::pow(T(10), x);
39 else if constexpr (id == exp::log ) return std::log (x);
40 else if constexpr (id == exp::log2 ) return std::log2 (x);
41 else if constexpr (id == exp::log10) return std::log10(x);
42 else fe::unreachable();
43 } else if constexpr (std::is_same_v<Id, er>) {
44 if constexpr (false) {}
45 else if constexpr (id == er::f ) return std::erf (x);
46 else if constexpr (id == er::fc) return std::erfc(x);
47 else []<bool flag = false>() { static_assert(flag, "missing sub tag"); }();
48 } else if constexpr (std::is_same_v<Id, gamma>) {
49 if constexpr (false) {}
50 else if constexpr (id == gamma::t) return std::tgamma(x);
51 else if constexpr (id == gamma::l) return std::lgamma(x);
52 else []<bool flag = false>() { static_assert(flag, "missing sub tag"); }();
53 } else if constexpr (std::is_same_v<Id, round>) {
54 if constexpr (false) {}
55 else if constexpr (id == round::f) return std::floor (x);
56 else if constexpr (id == round::c) return std::ceil (x);
57 else if constexpr (id == round::r) return std::round (x);
58 else if constexpr (id == round::t) return std::trunc (x);
59 else []<bool flag = false>() { static_assert(flag, "missing sub tag"); }();
60 } else {
61 []<bool flag = false>() { static_assert(flag, "missing tag"); }();
62 }
63}
64
65template<class Id, nat_t w>
66Res fold(u64 a) {
67 using T = w2f<w>;
68 auto x = bitcast<T>(a);
69 if constexpr (std::is_same_v<Id, abs>) {
70 return std::abs(x);
71 } else {
72 []<bool flag = false>() { static_assert(flag, "missing tag"); }();
73 }
74}
75
76template<class Id>
77Ref fold(World& world, Ref type, const Def* a) {
78 if (a->isa<Bot>()) return world.bot(type);
79 auto la = a->isa<Lit>();
80
81 if (la) {
82 nat_t width = *isa_f(a->type());
83 Res res;
84 switch (width) {
85#define CODE(i) \
86 case i: res = fold<Id, i>(la->get()); break;
88#undef CODE
89 default: fe::unreachable();
90 }
91
92 return world.lit(type, *res);
93 }
94
95 return nullptr;
96}
97
98template<class Id, Id id, nat_t w>
99Res fold(u64 a, u64 b) {
100 using T = w2f<w>;
101 auto x = bitcast<T>(a), y = bitcast<T>(b);
102 if constexpr (std::is_same_v<Id, arith>) {
103 if constexpr (false) {}
104 else if constexpr (id == arith::add) return x + y;
105 else if constexpr (id == arith::sub) return x - y;
106 else if constexpr (id == arith::mul) return x * y;
107 else if constexpr (id == arith::div) return x / y;
108 else if constexpr (id == arith::rem) return rem(x, y);
109 else []<bool flag = false>() { static_assert(flag, "missing sub tag"); }();
110 } else if constexpr (std::is_same_v<Id, math::extrema>) {
111 if (x == T(-0.0) && y == T(+0.0)) return (id == extrema::fmin || id == extrema::ieee754min) ? x : y;
112 if (x == T(+0.0) && y == T(-0.0)) return (id == extrema::fmin || id == extrema::ieee754min) ? y : x;
113
114 if constexpr (id == extrema::fmin || id == extrema::fmax) {
115 return id == extrema::fmin ? std::fmin(x, y) : std::fmax(x, y);
116 } else if constexpr (id == extrema::ieee754min || id == extrema::ieee754max) {
117 if (std::isnan(x)) return x;
118 if (std::isnan(y)) return y;
119 return id == extrema::ieee754min ? std::fmin(x, y) : std::fmax(x, y);
120 } else {
121 []<bool flag = false>() { static_assert(flag, "missing sub tag"); }();
122 }
123 } else if constexpr (std::is_same_v<Id, pow>) {
124 return std::pow(a, b);
125 } else if constexpr (std::is_same_v<Id, cmp>) {
126 using std::isunordered;
127 bool res = false;
128 res |= ((id & cmp::u) != cmp::f) && isunordered(x, y);
129 res |= ((id & cmp::g) != cmp::f) && x > y;
130 res |= ((id & cmp::l) != cmp::f) && x < y;
131 res |= ((id & cmp::e) != cmp::f) && x == y;
132 return res;
133 } else {
134 []<bool flag = false>() { static_assert(flag, "missing tag"); }();
135 }
136}
137// clang-format on
138
139template<class Id, Id id> Ref fold(World& world, Ref type, const Def* a) {
140 if (a->isa<Bot>()) return world.bot(type);
141
142 if (auto la = Lit::isa(a)) {
143 nat_t width = *isa_f(a->type());
144 Res res;
145 switch (width) {
146#define CODE(i) \
147 case i: res = fold<Id, id, i>(*la); break;
149#undef CODE
150 default: fe::unreachable();
151 }
152
153 return world.lit(type, *res);
154 }
155
156 return nullptr;
157}
158
159// Note that @p a and @p b are passed by reference as fold also commutes if possible.
160template<class Id, Id id> Ref fold(World& world, Ref type, const Def*& a, const Def*& b) {
161 if (a->isa<Bot>() || b->isa<Bot>()) return world.bot(type);
162
163 if (auto la = Lit::isa(a)) {
164 if (auto lb = Lit::isa(b)) {
165 nat_t width = *isa_f(a->type());
166 Res res;
167 switch (width) {
168#define CODE(i) \
169 case i: res = fold<Id, id, i>(*la, *lb); break;
171#undef CODE
172 default: fe::unreachable();
173 }
174
175 return world.lit(type, *res);
176 }
177 }
178
179 if (is_commutative(id)) commute(a, b);
180 return nullptr;
181}
182
183/// Reassociates @p a und @p b according to following rules.
184/// We use the following naming convention while literals are prefixed with an 'l':
185/// ```
186/// a op b
187/// (x op y) op (z op w)
188///
189/// (1) la op (lz op w) -> (la op lz) op w
190/// (2) (lx op y) op (lz op w) -> (lx op lz) op (y op w)
191/// (3) a op (lz op w) -> lz op (a op w)
192/// (4) (lx op y) op b -> lx op (y op b)
193/// ```
194template<class Id> Ref reassociate(Id id, World& world, [[maybe_unused]] const App* ab, Ref a, Ref b) {
195 if (!is_associative(id)) return nullptr;
196
197 auto xy = match<Id>(id, a);
198 auto zw = match<Id>(id, b);
199 auto la = a->isa<Lit>();
200 auto [x, y] = xy ? xy->template args<2>() : std::array<const Def*, 2>{nullptr, nullptr};
201 auto [z, w] = zw ? zw->template args<2>() : std::array<const Def*, 2>{nullptr, nullptr};
202 auto lx = Lit::isa(x);
203 auto lz = Lit::isa(z);
204
205 // build mode for all new ops by using the least upper bound of all involved apps
206 auto mode = (nat_t)Mode::bot;
207 auto check_mode = [&](const App* app) {
208 auto app_m = Lit::isa(app->arg(0));
209 if (!app_m || !(*app_m & Mode::reassoc)) return false;
210 mode &= *app_m; // least upper bound
211 return true;
212 };
213
214 if (!check_mode(ab)) return nullptr;
215 if (lx && !check_mode(xy->decurry())) return nullptr;
216 if (lz && !check_mode(zw->decurry())) return nullptr;
217
218 auto make_op = [&](Ref a, Ref b) { return world.call(id, mode, Defs{a, b}); };
219
220 if (la && lz) return make_op(make_op(a, z), w); // (1)
221 if (lx && lz) return make_op(make_op(x, z), make_op(y, w)); // (2)
222 if (lz) return make_op(z, make_op(a, w)); // (3)
223 if (lx) return make_op(x, make_op(y, b)); // (4)
224 return nullptr;
225}
226
227template<class Id, Id id, nat_t sw, nat_t dw> Res fold(u64 a) {
228 using S = std::conditional_t<id == conv::s2f, w2s<sw>, std::conditional_t<id == conv::u2f, w2u<sw>, w2f<sw>>>;
229 using D = std::conditional_t<id == conv::f2s, w2s<dw>, std::conditional_t<id == conv::f2u, w2u<dw>, w2f<dw>>>;
230 return D(bitcast<S>(a));
231}
232
233} // namespace
234
235template<arith id> Ref normalize_arith(Ref type, Ref c, Ref arg) {
236 auto& world = type->world();
237 auto callee = c->as<App>();
238 auto [a, b] = arg->projs<2>();
239 auto mode = callee->arg();
240 auto lm = Lit::isa(mode);
241 auto w = isa_f(a->type());
242
243 if (auto result = fold<arith, id>(world, type, a, b)) return result;
244
245 // clang-format off
246 // TODO check mode properly
247 if (lm && *lm == Mode::fast) {
248 if (auto la = a->isa<Lit>()) {
249 if (la == lit_f(world, *w, 0.0)) {
250 switch (id) {
251 case arith::add: return b; // 0 + b -> b
252 case arith::sub: break;
253 case arith::mul: return la; // 0 * b -> 0
254 case arith::div: return la; // 0 / b -> 0
255 case arith::rem: return la; // 0 % b -> 0
256 }
257 }
258
259 if (la == lit_f(world, *w, 1.0)) {
260 switch (id) {
261 case arith::add: break;
262 case arith::sub: break;
263 case arith::mul: return b; // 1 * b -> b
264 case arith::div: break;
265 case arith::rem: break;
266 }
267 }
268 }
269
270 if (auto lb = b->isa<Lit>()) {
271 if (lb == lit_f(world, *w, 0.0)) {
272 switch (id) {
273 case arith::sub: return a; // a - 0 -> a
274 case arith::div: break;
275 case arith::rem: break;
276 default: fe::unreachable();
277 // add, mul are commutative, the literal has been normalized to the left
278 }
279 }
280 }
281
282 if (a == b) {
283 switch (id) {
284 case arith::add: return world.call(arith::mul, mode, Defs{lit_f(world, *w, 2.0), a}); // a + a -> 2 * a
285 case arith::sub: return lit_f(world, *w, 0.0); // a - a -> 0
286 case arith::mul: break;
287 case arith::div: return lit_f(world, *w, 1.0); // a / a -> 1
288 case arith::rem: break;
289 }
290 }
291 }
292 // clang-format on
293
294 if (auto res = reassociate<arith>(id, world, callee, a, b)) return res;
295
296 return world.raw_app(type, callee, {a, b});
297}
298
299template<extrema id> Ref normalize_extrema(Ref type, Ref c, Ref arg) {
300 auto& world = type->world();
301 auto callee = c->as<App>();
302 auto [a, b] = arg->projs<2>();
303 auto m = callee->arg();
304 auto lm = Lit::isa(m);
305
306 if (auto lit = fold<extrema, id>(world, type, a, b)) return lit;
307
308 if (lm && *lm & (Mode::nnan | Mode::nsz)) { // if ignore NaNs and signed zero, then *imum -> *num
309 switch (id) {
310 case extrema::ieee754min: return world.call(extrema::fmin, m, Defs{a, b});
311 case extrema::ieee754max: return world.call(extrema::fmax, m, Defs{a, b});
312 default: break;
313 }
314 }
315
316 return world.raw_app(type, c, arg);
317}
318
319template<tri id> Ref normalize_tri(Ref type, Ref c, Ref arg) {
320 auto& world = type->world();
321 if (auto lit = fold<tri, id>(world, type, arg)) return lit;
322 return world.raw_app(type, c, arg);
323}
324
326 auto& world = type->world();
327 auto [a, b] = arg->projs<2>();
328 if (auto lit = fold<pow, /*dummy*/ pow(0)>(world, type, a, b)) return lit;
329 return world.raw_app(type, c, arg);
330}
331
332template<rt id> Ref normalize_rt(Ref type, Ref c, Ref arg) {
333 auto& world = type->world();
334 if (auto lit = fold<rt, id>(world, type, arg)) return lit;
335 return world.raw_app(type, c, arg);
336}
337
338template<exp id> Ref normalize_exp(Ref type, Ref c, Ref arg) {
339 auto& world = type->world();
340 if (auto lit = fold<exp, id>(world, type, arg)) return lit;
341 return world.raw_app(type, c, arg);
342}
343
344template<er id> Ref normalize_er(Ref type, Ref c, Ref arg) {
345 auto& world = type->world();
346 if (auto lit = fold<er, id>(world, type, arg)) return lit;
347 return world.raw_app(type, c, arg);
348}
349
350template<gamma id> Ref normalize_gamma(Ref type, Ref c, Ref arg) {
351 auto& world = type->world();
352 if (auto lit = fold<gamma, id>(world, type, arg)) return lit;
353 return world.raw_app(type, c, arg);
354}
355
356template<cmp id> Ref normalize_cmp(Ref type, Ref c, Ref arg) {
357 auto& world = type->world();
358 auto callee = c->as<App>();
359 auto [a, b] = arg->projs<2>();
360
361 if (auto result = fold<cmp, id>(world, type, a, b)) return result;
362 if (id == cmp::f) return world.lit_ff();
363 if (id == cmp::t) return world.lit_tt();
364
365 return world.raw_app(type, callee, {a, b});
366}
367
368template<conv id> Ref normalize_conv(Ref dst_t, Ref c, Ref x) {
369 auto& world = dst_t->world();
370 auto callee = c->as<App>();
371 auto s_t = x->type()->as<App>();
372 auto d_t = dst_t->as<App>();
373 auto s = s_t->arg();
374 auto d = d_t->arg();
375 auto ls = Lit::isa(s);
376 auto ld = Lit::isa(d);
377
378 if (s_t == d_t) return x;
379 if (x->isa<Bot>()) return world.bot(d_t);
380
381 if (auto l = Lit::isa(x); l && ls && ld) {
382 constexpr bool sf = id == conv::f2f || id == conv::f2s || id == conv::f2u;
383 constexpr bool df = id == conv::f2f || id == conv::s2f || id == conv::u2f;
384 constexpr nat_t min_s = sf ? 16 : 1;
385 constexpr nat_t min_d = df ? 16 : 1;
386 auto sw = sf ? isa_f(s_t) : Idx::size2bitwidth(*ls);
387 auto dw = df ? isa_f(d_t) : Idx::size2bitwidth(*ld);
388
389 if (sw && dw) {
390 Res res;
391 // clang-format off
392 if (false) {}
393#define M(S, D) \
394 else if (S == *sw && D == *dw) { \
395 if constexpr (S >= min_s && D >= min_d) \
396 res = fold<conv, id, S, D>(*l); \
397 else \
398 goto out; \
399 }
400 M( 1, 1) M( 1, 8) M( 1, 16) M( 1, 32) M( 1, 64)
401 M( 8, 1) M( 8, 8) M( 8, 16) M( 8, 32) M( 8, 64)
402 M(16, 1) M(16, 8) M(16, 16) M(16, 32) M(16, 64)
403 M(32, 1) M(32, 8) M(32, 16) M(32, 32) M(32, 64)
404 M(64, 1) M(64, 8) M(64, 16) M(64, 32) M(64, 64)
405
406 else fe::unreachable();
407 // clang-format on
408 return world.lit(d_t, *res);
409 }
410 }
411out:
412 return world.raw_app(dst_t, callee, x);
413}
414
416 auto& world = type->world();
417 if (auto lit = fold<abs>(world, type, arg)) return lit;
418 return world.raw_app(type, c, arg);
419}
420
421template<round id> Ref normalize_round(Ref type, Ref c, Ref arg) {
422 auto& world = type->world();
423 if (auto lit = fold<round, id>(world, type, arg)) return lit;
424 return world.raw_app(type, c, arg);
425}
426
428
429} // namespace mim::plug::math
const Def * arg() const
Definition lam.h:227
World & world() const
Definition def.cpp:415
auto projs(F f) const
Splits this Def via Def::projections into an Array (if A == -1_n) or std::array (otherwise).
Definition def.h:367
const Def * type() const
Definition def.h:248
static constexpr nat_t size2bitwidth(nat_t n)
Definition def.h:808
static std::optional< T > isa(Ref def)
Definition def.h:763
Helper class to retrieve Infer::arg if present.
Definition def.h:86
Utility class when folding constants in normalizers.
Definition normalize.h:8
This is a thin wrapper for std::span<T, N> with the following additional features:
Definition span.h:28
Extremum. Either Top (Up) or Bottom.
Definition lattice.h:156
#define M(S, D)
#define MIM_math_NORMALIZER_IMPL
Definition autogen.h:369
@ App
Definition def.h:40
@ Lit
Definition def.h:40
The math Plugin
Definition math.h:8
Ref normalize_rt(Ref type, Ref c, Ref arg)
const Lit * lit_f(World &w, R val)
Definition math.h:89
Ref normalize_abs(Ref type, Ref c, Ref arg)
Ref normalize_exp(Ref type, Ref c, Ref arg)
Ref normalize_tri(Ref type, Ref c, Ref arg)
Ref normalize_conv(Ref dst_t, Ref c, Ref x)
Ref normalize_extrema(Ref type, Ref c, Ref arg)
Ref normalize_er(Ref type, Ref c, Ref arg)
@ reassoc
Allow reassociation transformations for floating-point operations.
@ nsz
No Signed Zeros.
@ bot
Alias for Mode::fast.
Ref normalize_round(Ref type, Ref c, Ref arg)
Ref normalize_arith(Ref type, Ref c, Ref arg)
std::optional< nat_t > isa_f(Ref def)
Definition math.h:76
Ref normalize_gamma(Ref type, Ref c, Ref arg)
Ref mode(World &w, VMode m)
mim::plug::math::VMode -> Ref.
Definition math.h:46
Ref normalize_cmp(Ref type, Ref c, Ref arg)
Ref normalize_pow(Ref type, Ref c, Ref arg)
View< const Def * > Defs
Definition def.h:61
u64 nat_t
Definition types.h:43
D bitcast(const S &src)
A bitcast from src of type S to D.
Definition util.h:26
typename detail::w2f_< w >::type w2f
Definition types.h:74
auto match(Ref def)
Definition axiom.h:112
void commute(const Def *&a, const Def *&b)
Swaps the operands so that a Lit is on the left or, if no Lit is present, so that the Def with the smaller Def::gid is on the left.
Definition normalize.h:25
constexpr bool is_commutative(Id)
Definition axiom.h:139
constexpr bool is_associative(Id id)
Definition axiom.h:142
TExt< false > Bot
Definition lattice.h:177
uint64_t u64
Definition types.h:34
Definition span.h:104
#define CODE(t, str)
Definition tok.h:54
#define MIM_16_32_64(m)
Definition types.h:26