MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
math.h
Go to the documentation of this file.
1#pragma once
2
3#include <mim/axiom.h>
4#include <mim/world.h>
5
7
8namespace mim::plug::math {
9
10/// @name Mode
11///@{
12// clang-format off
13/// Allowed optimizations for a specific operation.
14enum class Mode : nat_t {
15 top = 0,
16 none = top, ///< Alias for Mode::none.
17 nnan = 1 << 0, ///< No NaNs.
18 ///< Allow optimizations to assume the arguments and result are not NaN.
19 ///< Such optimizations are required to retain defined behavior over NaNs, but the value of the result is undefined.
20 ninf = 1 << 1, ///< No Infs.
21 ///< Allow optimizations to assume the arguments and result are not +/-Inf.
22 ///< Such optimizations are required to retain defined behavior over +/-Inf, but the value of the result is undefined.
23 nsz = 1 << 2, ///< No Signed Zeros.
24 ///< Allow optimizations to treat the sign of a zero argument or result as insignificant.
25 arcp = 1 << 3, ///< Allow Reciprocal.
26 ///< Allow optimizations to use the reciprocal of an argument rather than perform division.
27 contract = 1 << 4, ///< Allow floating-point contraction
28 ///< (e.g. fusing a multiply followed by an addition into a fused multiply-and-add).
29 afn = 1 << 5, ///< Approximate functions.
30 ///< Allow substitution of approximate calculations for functions (sin, log, sqrt, etc).
31 reassoc = 1 << 6, ///< Allow reassociation transformations for floating-point operations.
32 ///< This may dramatically change results in floating point.
33 finite = nnan | ninf, ///< Mode::nnan `|` Mode::ninf.
34 unsafe = nsz | arcp | reassoc, ///< Mode::nsz `|` Mode::arcp `|` Mode::reassoc
36 | arcp | contract | afn
37 | reassoc, ///< All flags.
38 bot = fast, ///< Alias for Mode::fast.
39};
40// clang-format on
41
42/// Give Mode as mim::plug::math::Mode, mim::nat_t or const Def*.
43using VMode = std::variant<Mode, nat_t, const Def*>;
44
45/// mim::plug::math::VMode -> const Def*.
46inline const Def* mode(World& w, VMode m) {
47 if (auto def = std::get_if<const Def*>(&m)) return *def;
48 if (auto nat = std::get_if<nat_t>(&m)) return w.lit_nat(*nat);
49 return w.lit_nat((nat_t)std::get<Mode>(m));
50}
51///@}
52
53/// @name %%math.F
54///@{
55inline const Def* type_f(const Def* pe) {
56 World& w = pe->world();
57 return w.app(w.annex<F>(), pe);
58}
59inline const Def* type_f(World& w, nat_t p, nat_t e) {
60 auto lp = w.lit_nat(p);
61 auto le = w.lit_nat(e);
62 return type_f(w.tuple({lp, le}));
63}
64template<nat_t P, nat_t E> inline auto match_f(const Def* def) {
65 if (auto f_ty = match<F>(def)) {
66 auto [p, e] = f_ty->arg()->projs<2>([](auto op) { return Lit::isa(op); });
67 if (p && e && *p == P && *e == E) return f_ty;
68 }
69 return Match<F, App>();
70}
71
72inline auto match_f16(const Def* def) { return match_f<10, 5>(def); }
73inline auto match_f32(const Def* def) { return match_f<23, 8>(def); }
74inline auto match_f64(const Def* def) { return match_f<52, 11>(def); }
75
76inline std::optional<nat_t> isa_f(const Def* def) {
77 if (auto f_ty = match<F>(def)) {
78 if (auto [p, e] = f_ty->arg()->projs<2>([](auto op) { return Lit::isa(op); }); p && e) {
79 if (*p == 10 && e == 5) return 16;
80 if (*p == 23 && e == 8) return 32;
81 if (*p == 52 && e == 11) return 64;
82 }
83 }
84 return {};
85}
86
87// clang-format off
88template<class R>
89const Lit* lit_f(World& w, R val) {
90 static_assert(std::is_floating_point<R>() || std::is_same<R, mim::f16>());
91 if constexpr (false) {}
92 else if constexpr (sizeof(R) == 2) return w.lit(w.annex<F16>(), mim::bitcast<u16>(val));
93 else if constexpr (sizeof(R) == 4) return w.lit(w.annex<F32>(), mim::bitcast<u32>(val));
94 else if constexpr (sizeof(R) == 8) return w.lit(w.annex<F64>(), mim::bitcast<u64>(val));
95 else fe::unreachable();
96}
97
98inline const Lit* lit_f(World& w, nat_t width, mim::f64 val) {
99 switch (width) {
100 case 16: assert(mim::f64(mim::f16(mim::f32(val))) == val && "loosing precision"); return lit_f(w, mim::f16(mim::f32(val)));
101 case 32: assert(mim::f64(mim::f32( (val))) == val && "loosing precision"); return lit_f(w, mim::f32( (val)));
102 case 64: assert(mim::f64(mim::f64( (val))) == val && "loosing precision"); return lit_f(w, mim::f64( (val)));
103 default: fe::unreachable();
104 }
105}
106// clang-format on
107///@}
108
109/// @name %%math.arith
110///@{
111inline const Def* op_rminus(VMode m, const Def* a) {
112 World& w = a->world();
113 auto s = isa_f(a->type());
114 return w.call(arith::sub, mode(w, m), Defs{lit_f(w, *s, -0.0), a});
115}
116///@}
117
118} // namespace mim::plug::math
119
120namespace mim {
121
122/// @name is_commutative/is_associative
123///@{
124// clang-format off
125constexpr bool is_commutative(plug::math::extrema ) { return true; }
127constexpr bool is_commutative(plug::math::cmp id) { return id == plug::math::cmp ::e || id == plug::math::cmp ::ne ; }
128constexpr bool is_associative(plug::math::arith id) { return is_commutative(id); }
129// clang-format off
130///@}
131
132} // namespace mim
133
134#ifndef DOXYGEN
135template<> struct fe::is_bit_enum<mim::plug::math::Mode> : std::true_type {};
136#endif
Base class for all Defs.
Definition def.h:198
static std::optional< T > isa(const Def *def)
Definition def.h:730
The World represents the whole program and manages creation of MimIR nodes (Defs).
Definition world.h:33
The math Plugin
Definition math.h:8
const Lit * lit_f(World &w, R val)
Definition math.h:89
const Def * type_f(const Def *pe)
Definition math.h:55
auto match_f(const Def *def)
Definition math.h:64
Mode
Allowed optimizations for a specific operation.
Definition math.h:14
@ arcp
Allow Reciprocal.
Definition math.h:25
@ fast
All flags.
Definition math.h:35
@ none
Alias for Mode::top.
Definition math.h:16
@ afn
Approximate functions.
Definition math.h:29
@ ninf
No Infs.
Definition math.h:20
@ unsafe
Mode::nsz | Mode::arcp | Mode::reassoc.
Definition math.h:34
@ reassoc
Allow reassociation transformations for floating-point operations.
Definition math.h:31
@ contract
Allow floating-point contraction (e.g. fusing a multiply followed by an addition into a fused multiply-and-add).
Definition math.h:27
@ nsz
No Signed Zeros.
Definition math.h:23
@ nnan
No NaNs.
Definition math.h:17
@ finite
Mode::nnan | Mode::ninf.
Definition math.h:33
@ bot
Alias for Mode::fast.
Definition math.h:38
const Def * mode(World &w, VMode m)
mim::plug::math::VMode -> const Def*.
Definition math.h:46
std::variant< Mode, nat_t, const Def * > VMode
Give Mode as mim::plug::math::Mode, mim::nat_t or const Def*.
Definition math.h:43
auto match_f16(const Def *def)
Definition math.h:72
std::optional< nat_t > isa_f(const Def *def)
Definition math.h:76
const Def * op_rminus(VMode m, const Def *a)
Definition math.h:111
auto match_f64(const Def *def)
Definition math.h:74
auto match_f32(const Def *def)
Definition math.h:73
Definition ast.h:14
View< const Def * > Defs
Definition def.h:49
u64 nat_t
Definition types.h:43
D bitcast(const S &src)
A bitcast from src of type S to D.
Definition util.h:23
double f64
Definition types.h:41
constexpr bool is_commutative(Id)
Definition axiom.h:139
float f32
Definition types.h:40
auto match(const Def *def)
Definition axiom.h:112
half f16
Definition types.h:39
constexpr bool is_associative(Id id)
Definition axiom.h:142