MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
math.h
Go to the documentation of this file.
1#pragma once
2
3#include <mim/axiom.h>
4#include <mim/world.h>
5
7
8namespace mim::plug::math {
9
10/// @name Mode
11///@{
12// clang-format off
13/// Allowed optimizations for a specific operation.
14enum class Mode : nat_t {
15 top = 0,
16 none = top, ///< Alias for Mode::none.
17 nnan = 1 << 0, ///< No NaNs.
18 ///< Allow optimizations to assume the arguments and result are not NaN.
19 ///< Such optimizations are required to retain defined behavior over NaNs, but the value of the result is undefined.
20 ninf = 1 << 1, ///< No Infs.
21 ///< Allow optimizations to assume the arguments and result are not +/-Inf.
22 ///< Such optimizations are required to retain defined behavior over +/-Inf, but the value of the result is undefined.
23 nsz = 1 << 2, ///< No Signed Zeros.
24 ///< Allow optimizations to treat the sign of a zero argument or result as insignificant.
25 arcp = 1 << 3, ///< Allow Reciprocal.
26 ///< Allow optimizations to use the reciprocal of an argument rather than perform division.
27 contract = 1 << 4, ///< Allow floating-point contraction
28 ///< (e.g. fusing a multiply followed by an addition into a fused multiply-and-add).
29 afn = 1 << 5, ///< Approximate functions.
30 ///< Allow substitution of approximate calculations for functions (sin, log, sqrt, etc).
31 reassoc = 1 << 6, ///< Allow reassociation transformations for floating-point operations.
32 ///< This may dramatically change results in floating point.
33 finite = nnan | ninf, ///< Mode::nnan `|` Mode::ninf.
34 unsafe = nsz | arcp | reassoc, ///< Mode::nsz `|` Mode::arcp `|` Mode::reassoc
35 fast = nnan | ninf | nsz
36 | arcp | contract | afn
37 | reassoc, ///< All flags.
38 bot = fast, ///< Alias for Mode::fast.
39};
40// clang-format on
41
42/// Give Mode as mim::plug::math::Mode, mim::nat_t or Ref.
43using VMode = std::variant<Mode, nat_t, Ref>;
44
45/// mim::plug::math::VMode -> Ref.
46inline Ref mode(World& w, VMode m) {
47 if (auto def = std::get_if<Ref>(&m)) return *def;
48 if (auto nat = std::get_if<nat_t>(&m)) return w.lit_nat(*nat);
49 return w.lit_nat((nat_t)std::get<Mode>(m));
50}
51///@}
52
53/// @name %%math.F
54///@{
55inline Ref type_f(Ref pe) {
56 World& w = pe->world();
57 return w.app(w.annex<F>(), pe);
58}
59inline Ref type_f(World& w, nat_t p, nat_t e) {
60 auto lp = w.lit_nat(p);
61 auto le = w.lit_nat(e);
62 return type_f(w.tuple({lp, le}));
63}
64template<nat_t P, nat_t E> inline auto match_f(Ref def) {
65 if (auto f_ty = match<F>(def)) {
66 auto [p, e] = f_ty->arg()->projs<2>([](auto op) { return Lit::isa(op); });
67 if (p && e && *p == P && *e == E) return f_ty;
68 }
69 return Match<F, App>();
70}
71
72inline auto match_f16(Ref def) { return match_f<10, 5>(def); }
73inline auto match_f32(Ref def) { return match_f<23, 8>(def); }
74inline auto match_f64(Ref def) { return match_f<52, 11>(def); }
75
76inline std::optional<nat_t> isa_f(Ref def) {
77 if (auto f_ty = match<F>(def)) {
78 if (auto [p, e] = f_ty->arg()->projs<2>([](auto op) { return Lit::isa(op); }); p && e) {
79 if (*p == 10 && e == 5) return 16;
80 if (*p == 23 && e == 8) return 32;
81 if (*p == 52 && e == 11) return 64;
82 }
83 }
84 return {};
85}
86
87// clang-format off
88template<class R>
89const Lit* lit_f(World& w, R val) {
90 static_assert(std::is_floating_point<R>() || std::is_same<R, mim::f16>());
91 if constexpr (false) {}
92 else if constexpr (sizeof(R) == 2) return w.lit(w.annex<F16>(), mim::bitcast<u16>(val));
93 else if constexpr (sizeof(R) == 4) return w.lit(w.annex<F32>(), mim::bitcast<u32>(val));
94 else if constexpr (sizeof(R) == 8) return w.lit(w.annex<F64>(), mim::bitcast<u64>(val));
95 else fe::unreachable();
96}
97
98inline const Lit* lit_f(World& w, nat_t width, mim::f64 val) {
99 switch (width) {
100 case 16: assert(mim::f64(mim::f16(mim::f32(val))) == val && "loosing precision"); return lit_f(w, mim::f16(mim::f32(val)));
101 case 32: assert(mim::f64(mim::f32( (val))) == val && "loosing precision"); return lit_f(w, mim::f32( (val)));
102 case 64: assert(mim::f64(mim::f64( (val))) == val && "loosing precision"); return lit_f(w, mim::f64( (val)));
103 default: fe::unreachable();
104 }
105}
106// clang-format on
107///@}
108
109/// @name %%math.arith
110///@{
111inline Ref op_rminus(VMode m, Ref a) {
112 World& w = a->world();
113 auto s = isa_f(a->type());
114 return w.call(arith::sub, mode(w, m), Defs{lit_f(w, *s, -0.0), a});
115}
116///@}
117
118} // namespace mim::plug::math
119
120namespace mim {
121
122/// @name is_commutative/is_associative
123///@{
124// clang-format off
125constexpr bool is_commutative(plug::math::extrema ) { return true; }
127constexpr bool is_commutative(plug::math::cmp id) { return id == plug::math::cmp ::e || id == plug::math::cmp ::ne ; }
128constexpr bool is_associative(plug::math::arith id) { return is_commutative(id); }
129// clang-format off
130///@}
131
132} // namespace mim
133
134#ifndef DOXYGEN
135template<> struct fe::is_bit_enum<mim::plug::math::Mode> : std::true_type {};
136#endif
static std::optional< T > isa(Ref def)
Definition def.h:763
Helper class to retrieve Infer::arg if present.
Definition def.h:86
This is a thin wrapper for std::span<T, N> with the following additional features:
Definition span.h:28
The World represents the whole program and manages creation of MimIR nodes (Defs).
Definition world.h:33
The math Plugin
Definition math.h:8
auto match_f(Ref def)
Definition math.h:64
const Lit * lit_f(World &w, R val)
Definition math.h:89
std::variant< Mode, nat_t, Ref > VMode
Give Mode as mim::plug::math::Mode, mim::nat_t or Ref.
Definition math.h:43
auto match_f16(Ref def)
Definition math.h:72
Ref op_rminus(VMode m, Ref a)
Definition math.h:111
auto match_f32(Ref def)
Definition math.h:73
Mode
Allowed optimizations for a specific operation.
Definition math.h:14
@ arcp
Allow Reciprocal.
@ none
Alias for Mode::top.
@ afn
Approximate functions.
@ unsafe
Mode::nsz | Mode::arcp | Mode::reassoc.
@ reassoc
Allow reassociation transformations for floating-point operations.
@ contract
Allow floating-point contraction (e.g. fusing a multiply followed by an addition into a fused multiply-and-add).
@ nsz
No Signed Zeros.
@ finite
Mode::nnan | Mode::ninf.
@ bot
Alias for Mode::fast.
auto match_f64(Ref def)
Definition math.h:74
Ref type_f(Ref pe)
Definition math.h:55
std::optional< nat_t > isa_f(Ref def)
Definition math.h:76
Ref mode(World &w, VMode m)
mim::plug::math::VMode -> Ref.
Definition math.h:46
Definition cfg.h:11
u64 nat_t
Definition types.h:43
D bitcast(const S &src)
A bitcast from src of type S to D.
Definition util.h:26
auto match(Ref def)
Definition axiom.h:112
double f64
Definition types.h:41
constexpr bool is_commutative(Id)
Definition axiom.h:139
float f32
Definition types.h:40
half f16
Definition types.h:39
constexpr bool is_associative(Id id)
Definition axiom.h:142
constexpr decltype(auto) get(mim::Span< T, N > span)
Definition span.h:113