MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
normalizers.cpp
Go to the documentation of this file.
1#include <algorithm>
2#include <iterator>
3#include <numeric>
4#include <ranges>
5#include <vector>
6
8#include <fe/assert.h>
9
10#include "mim/axm.h"
11#include "mim/def.h"
12#include "mim/tuple.h"
13#include "mim/world.h"
14
15#include "mim/util/log.h"
16
18
21
22namespace mim::plug::regex {
23
24template<quant id>
25const Def* normalize_quant(const Def* type, const Def* callee, const Def* arg) {
26 auto& world = type->world();
27
28 // quantifiers are idempotent
29 if (Axm::isa(id, arg)) return arg;
30
31 if constexpr (id == quant::plus) {
32 // (\d?)+ and (\d*)+ == \d*
33 if (auto optional_app = Axm::isa(quant::optional, arg))
34 return world.call(quant::star, optional_app->arg());
35 else if (auto star_app = Axm::isa(quant::star, arg))
36 return arg;
37 } else if constexpr (id == quant::star) {
38 // (\d?)* and (\d+)* == \d*
39 if (auto quant_app = Axm::isa<quant>(arg)) return world.app(callee, quant_app->arg());
40 } else if constexpr (id == quant::optional) {
41 // (\d*)? and (\d+)? == \d*
42 if (auto star_app = Axm::isa(quant::star, arg))
43 return arg;
44 else if (auto plus_app = Axm::isa(quant::plus, arg))
45 return world.call(quant::star, plus_app->arg());
46 }
47
48 return {};
49}
50
51template<class ConjOrDisj>
52void flatten_in_arg(const Def* arg, DefVec& new_args) {
53 for (const auto* proj : arg->projs()) {
54 // flatten conjs in conjs / disj in disjs
55 if (auto seq_app = Axm::isa<ConjOrDisj>(proj))
56 flatten_in_arg<ConjOrDisj>(seq_app->arg(), new_args);
57 else
58 new_args.push_back(proj);
59 }
60}
61
62template<class ConjOrDisj>
64 DefVec new_args;
65 flatten_in_arg<ConjOrDisj>(arg, new_args);
66 return new_args;
67}
68
69template<class ConjOrDisj>
71 assert(!args.empty());
72 auto& world = args.front()->world();
73 return std::accumulate(args.begin() + 1, args.end(), args.front(), [&world](const Def* lhs, const Def* rhs) {
74 return world.call<ConjOrDisj, false>(Defs{lhs, rhs});
75 });
76}
77
78const Def* normalize_conj(const Def* type, const Def* callee, const Def* arg) {
79 auto& world = type->world();
80 world.DLOG("conj {}:{} ({})", type, callee, arg);
81
82 if (auto a = Lit::isa(arg->arity())) {
83 switch (*a) {
84 case 0: return world.lit_tt();
85 case 1: return arg;
86 default: return make_binary_tree<conj>(flatten_in_arg<conj>(arg));
87 }
88 }
89
90 return {};
91}
92
93bool compare_re(const Def* lhs, const Def* rhs) {
94 auto lhs_range = Axm::isa<range>(lhs);
95 auto rhs_range = Axm::isa<range>(rhs);
96 // sort ranges by increasing lower bound
97 if (lhs_range && rhs_range) return Lit::as(lhs_range->arg()->proj(0)) < Lit::as(rhs_range->arg()->proj(0));
98 // ranges to the end
99 if (lhs_range) return false;
100 if (rhs_range) return true;
101
102 return lhs->gid() < rhs->gid(); // make irreflexive
103}
104
106 std::stable_sort(args.begin(), args.end(), &compare_re);
107 {
108 auto new_end = std::unique(args.begin(), args.end());
109 args.erase(new_end, args.end());
110 }
111}
112
113bool is_in_range(Range range, nat_t needle) { return needle >= range.first && needle <= range.second; }
114
115auto get_range(const Def* rng) -> Range {
116 auto rng_match = Axm::isa<range, false>(rng);
117 return {Lit::as<std::uint8_t>(rng_match->arg(0)), Lit::as<std::uint8_t>(rng_match->arg(1))};
118}
119
120struct app_range {
122 const Def* operator()(Range rng) { return w.call<range>(Defs{w.lit_i8(rng.first), w.lit_i8(rng.second)}); }
123};
124
125void merge_ranges(DefVec& args) {
126 auto ranges_begin = args.begin();
127 while (ranges_begin != args.end() && !Axm::isa<range>(*ranges_begin))
128 ranges_begin++;
129 if (ranges_begin == args.end()) return;
130
131 std::set<const Def*> to_remove;
132 Ranges old_ranges;
133 auto& world = (*ranges_begin)->world();
134
135 std::transform(ranges_begin, args.end(), std::back_inserter(old_ranges), get_range);
136
137 auto new_ranges = automaton::merge_ranges(
138 old_ranges, [&world](auto&&... args) { world.DLOG(std::forward<decltype(args)>(args)...); });
139
140 // invalidates ranges_begin
141 args.erase(ranges_begin, args.end());
142 std::transform(new_ranges.begin(), new_ranges.end(), std::back_inserter(args), app_range{world});
143
144 make_vector_unique(args);
145}
146
147template<cls A, cls B>
148bool equals_any(const Def* cls0, const Def* cls1) {
149 return (Axm::isa(A, cls0) && Axm::isa(B, cls1)) || (Axm::isa(A, cls1) && Axm::isa(B, cls0));
150}
151
152bool equals_any(const Def* lhs, const Def* rhs) {
153 auto check_arg_equiv = [](const Def* lhs, const Def* rhs) {
154 if (auto rng_lhs = Axm::isa<range>(lhs))
155 if (auto not_rhs = Axm::isa<not_>(rhs)) {
156 if (auto rng_rhs = Axm::isa<range>(not_rhs->arg())) return rng_lhs == rng_rhs;
157 }
158 return false;
159 };
160
161 return check_arg_equiv(lhs, rhs) || check_arg_equiv(rhs, lhs);
162}
163
164bool equals_any(Defs lhs, Defs rhs) {
165 Ranges lhs_ranges, rhs_ranges;
166 auto only_ranges = std::ranges::views::filter([](auto d) { return Axm::isa<range>(d); });
167 std::ranges::transform(lhs | only_ranges, std::back_inserter(lhs_ranges), get_range);
168 std::ranges::transform(rhs | only_ranges, std::back_inserter(rhs_ranges), get_range);
169 return std::ranges::includes(lhs_ranges, rhs_ranges) || std::ranges::includes(rhs_ranges, lhs_ranges);
170}
171
172const Def* normalize_disj(const Def* type, const Def*, const Def* arg) {
173 auto& world = type->world();
174 if (auto a = Lit::isa(arg->arity())) {
175 switch (*a) {
176 case 0: return world.lit_ff();
177 case 1: return arg;
178 default:
179 auto contains_any = [](auto args) {
180 return std::ranges::find_if(args, [](const Def* ax) -> bool { return Axm::isa<any>(ax); })
181 != args.end();
182 };
183
184 auto new_args = flatten_in_arg<disj>(arg);
185 if (contains_any(new_args)) return world.annex<any>();
186 make_vector_unique(new_args);
187 merge_ranges(new_args);
188
189 const Def* to_remove = nullptr;
190 for (const auto* cls0 : new_args) {
191 for (const auto* cls1 : new_args)
192 if (equals_any(cls0, cls1)) return world.annex<any>();
193
194 if (auto not_rhs = Axm::isa<not_>(cls0)) {
195 if (auto disj_rhs = Axm::isa<disj>(not_rhs->arg())) {
196 auto rngs = flatten_in_arg<disj>(disj_rhs->arg());
197 make_vector_unique(rngs);
198 if (equals_any(new_args, rngs)) return world.annex<any>();
199 }
200 }
201 }
202
203 erase(new_args, to_remove);
204 world.DLOG("final ranges {, }", new_args);
205
206 if (new_args.size() > 2) return make_binary_tree<disj>(new_args);
207 if (new_args.size() > 1) return world.call<disj, false>(new_args);
208 return new_args.back();
209 }
210 }
211 return {};
212}
213
214const Def* normalize_range(const Def* type, const Def* callee, const Def* arg) {
215 auto& world = type->world();
216 auto [lhs, rhs] = arg->projs<2>();
217
218 if (!lhs->isa<Var>() && !rhs->isa<Var>()) // before first PE.
219 if (lhs->as<Lit>()->get() > rhs->as<Lit>()->get()) return world.raw_app(type, callee, {rhs, lhs});
220
221 return {};
222}
223
224const Def* normalize_not(const Def*, const Def*, const Def*) { return {}; }
225
227
228} // namespace mim::plug::regex
static auto isa(const Def *def)
Definition axm.h:107
Base class for all Defs.
Definition def.h:216
auto projs(F f) const
Splits this Def via Def::projections into an Array (if A == std::dynamic_extent) or std::array (otherwise).
Definition def.h:355
constexpr u32 gid() const noexcept
Global id - unique number for this Def.
Definition def.h:235
virtual const Def * arity() const
Definition def.cpp:539
static std::optional< T > isa(const Def *def)
Definition def.h:773
T get() const
Definition def.h:760
static T as(const Def *def)
Definition def.h:779
This is a thin wrapper for absl::InlinedVector<T, N, A> which is a drop-in replacement for std::vector.
Definition vector.h:18
The World represents the whole program and manages creation of MimIR nodes (Defs).
Definition world.h:34
automaton::Range Range
Vector< Range > Ranges
std::pair< std::uint64_t, std::uint64_t > Range
std::optional< Range > merge_ranges(Range a, Range b) noexcept
The regex Plugin
Definition lower_regex.h:5
void merge_ranges(DefVec &args)
void make_vector_unique(DefVec &args)
auto get_range(const Def *rng) -> Range
void flatten_in_arg(const Def *arg, DefVec &new_args)
bool is_in_range(Range range, nat_t needle)
const Def * normalize_conj(const Def *type, const Def *callee, const Def *arg)
const Def * normalize_range(const Def *type, const Def *callee, const Def *arg)
const Def * normalize_disj(const Def *type, const Def *, const Def *arg)
const Def * normalize_quant(const Def *type, const Def *callee, const Def *arg)
const Def * normalize_not(const Def *, const Def *, const Def *)
bool compare_re(const Def *lhs, const Def *rhs)
const Def * make_binary_tree(Defs args)
bool equals_any(const Def *cls0, const Def *cls1)
View< const Def * > Defs
Definition def.h:51
u64 nat_t
Definition types.h:43
Vector< const Def * > DefVec
Definition def.h:52
Vector< T, N, A >::size_type erase(Vector< T, N, A > &c, const U &value) noexcept
Definition vector.h:70
#define MIM_regex_NORMALIZER_IMPL
Definition autogen.h:93
const Def * operator()(Range rng)