MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
tuple.cpp
Go to the documentation of this file.
1#include "mim/tuple.h"
2
3#include <cassert>
4
5#include "mim/world.h"
6
7// TODO this code needs to be rewritten
8
9namespace mim {
10
11// clang-format off
12const Def* Arr ::rebuild(World& w, const Def* shape, const Def* body) const { return w.arr (shape, body)->set(dbg()); }
13const Def* Pack::rebuild(World& w, const Def* shape, const Def* body) const { return w.pack(shape, body)->set(dbg()); }
14
15const Def* Arr ::prod(World& w, Defs ops) const { return w.sigma(ops)->set(dbg()); }
16const Def* Pack::prod(World& w, Defs ops) const { return w.tuple(ops)->set(dbg()); }
17// clang-format on
18
19namespace {
20bool should_flatten(const Def* def) {
21 auto type = (def->is_term() ? def->type() : def);
22 if (type->isa<Sigma>()) return true;
23 if (auto arr = type->isa<Arr>()) {
24 if (auto a = arr->isa_lit_arity(); a && *a > def->world().flags().scalarize_threshold) return false;
25 return true;
26 }
27 return false;
28}
29
30bool mut_val_or_typ(const Def* def) {
31 auto typ = def->is_term() ? def->type() : def;
32 return typ->isa_mut();
33}
34
35const Def* unflatten(Defs defs, const Def* type, size_t& j, bool flatten_muts) {
36 if (!defs.empty() && defs[0]->type() == type) return defs[j++];
37 if (auto a = type->isa_lit_arity();
38 flatten_muts == mut_val_or_typ(type) && a && *a != 1 && a <= type->world().flags().scalarize_threshold) {
39 auto& world = type->world();
40 auto ops = DefVec(*a, [&](size_t i) { return unflatten(defs, type->proj(*a, i), j, flatten_muts); });
41 return world.tuple(type, ops);
42 }
43
44 return defs[j++];
45}
46} // namespace
47
48const Def* Pack::shape() const {
49 if (auto arr = type()->isa<Arr>()) return arr->shape();
50 if (type() == world().sigma()) return world().lit_nat_0();
51 return world().lit_nat_1();
52}
53
54bool is_unit(const Def* def) { return def->type() == def->world().sigma(); }
55
56std::string tuple2str(const Def* def) {
57 if (def == nullptr) return {};
58
59 auto& w = def->world();
60 auto res = std::string();
61 if (auto n = Lit::isa(def->arity())) {
62 for (size_t i = 0; i != *n; ++i) {
63 auto elem = def->proj(*n, i);
64 if (elem->type() == w.type_i8()) {
65 if (auto l = Lit::isa<char>(elem)) {
66 res.push_back(*l);
67 continue;
68 }
69 }
70 return {};
71 }
72 }
73 return res;
74}
75
76size_t flatten(DefVec& ops, const Def* def, bool flatten_muts) {
77 if (auto a = def->isa_lit_arity(); a && *a != 1 && should_flatten(def) && flatten_muts == mut_val_or_typ(def)) {
78 auto n = 0;
79 for (size_t i = 0; i != *a; ++i) n += flatten(ops, def->proj(*a, i), flatten_muts);
80 return n;
81 } else {
82 ops.emplace_back(def);
83 return 1;
84 }
85}
86
87const Def* flatten(const Def* def) {
88 if (!should_flatten(def)) return def;
89 DefVec ops;
90 flatten(ops, def);
91 return def->is_intro() ? def->world().tuple(def->type(), ops) : def->world().sigma(ops);
92}
93
94const Def* unflatten(Defs defs, const Def* type, bool flatten_muts) {
95 size_t j = 0;
96 auto def = unflatten(defs, type, j, flatten_muts);
97 assert(j == defs.size());
98 return def;
99}
100
101const Def* unflatten(const Def* def, const Def* type) { return unflatten(def->projs(Lit::as(def->arity())), type); }
102
103DefVec merge(const Def* def, Defs defs) {
104 return DefVec(defs.size() + 1, [&](size_t i) { return i == 0 ? def : defs[i - 1]; });
105}
106
108 DefVec result(a.size() + b.size());
109 auto [_, o] = std::ranges::copy(a, result.begin());
110 std::ranges::copy(b, o);
111 return result;
112}
113
114const Def* merge_sigma(const Def* def, Defs defs) {
115 if (auto sigma = def->isa_imm<Sigma>()) return def->world().sigma(merge(sigma->ops(), defs));
116 return def->world().sigma(merge(def, defs));
117}
118
119const Def* merge_tuple(const Def* def, Defs defs) {
120 auto& w = def->world();
121 if (auto sigma = def->type()->isa_imm<Sigma>()) {
122 auto a = sigma->num_ops();
123 auto tuple = DefVec(a, [&](auto i) { return w.extract(def, a, i); });
124 return w.tuple(merge(tuple, defs));
125 }
126
127 return def->world().tuple(merge(def, defs));
128}
129
130const Def* tuple_of_types(const Def* t) {
131 auto& world = t->world();
132 if (auto sigma = t->isa<Sigma>()) return world.tuple(sigma->ops());
133 if (auto arr = t->isa<Arr>()) return world.pack(arr->shape(), arr->body());
134 return t;
135}
136
137} // namespace mim
A (possibly parameterized) Array.
Definition tuple.h:100
const Def * shape() const final
Definition tuple.h:110
friend class World
Definition tuple.h:145
Base class for all Defs.
Definition def.h:203
const Def * proj(nat_t a, nat_t i) const
Similar to World::extract while assuming an arity of a, but also works on Sigmas and Arrays.
Definition def.cpp:492
World & world() const noexcept
Definition def.cpp:377
bool is_intro() const noexcept
Definition def.h:236
constexpr auto ops() const noexcept
Definition def.h:266
bool is_term() const
Definition def.cpp:421
auto projs(F f) const
Splits this Def via Def::projections into an Array (if A == std::dynamic_extent) or std::array (other...
Definition def.h:350
const Def * type() const noexcept
Yields the "raw" type of this Def (maybe nullptr).
Definition def.h:247
std::optional< nat_t > isa_lit_arity() const
Definition def.cpp:462
const Def * arity() const
Definition def.cpp:455
const T * isa_imm() const
Definition def.h:429
Dbg dbg() const
Definition def.h:453
static std::optional< T > isa(const Def *def)
Definition def.h:733
static T as(const Def *def)
Definition def.h:738
friend class World
Definition tuple.h:188
const Def * prod(World &, Defs) const final
Creates either a Tuple or Sigma.
Definition tuple.cpp:16
const Def * shape() const final
Definition tuple.cpp:48
const Def * rebuild(World &w, const Def *shape, const Def *body) const final
Definition tuple.cpp:13
const Def * body() const
Definition tuple.h:86
Def(World *, Node, const Def *type, Defs ops, flags_t flags)
Constructor for an immutable Def.
Definition def.cpp:23
A dependent tuple type.
Definition tuple.h:15
const Def * sigma(Defs ops)
Definition world.cpp:272
const Def * tuple(Defs ops)
Definition world.cpp:282
Flags & flags()
Retrieve compile Flags.
Definition world.cpp:72
const Lit * lit_nat_0()
Definition world.h:385
const Lit * lit_nat_1()
Definition world.h:386
Definition ast.h:14
View< const Def * > Defs
Definition def.h:49
const Def * flatten(const Def *def)
Flattens a sigma/array/pack/tuple.
Definition tuple.cpp:87
Vector< const Def * > DefVec
Definition def.h:50
bool is_unit(const Def *)
Definition tuple.cpp:54
uint64_t scalarize_threshold
Definition flags.h:13
const Def * merge_sigma(const Def *def, Defs defs)
Definition tuple.cpp:114
std::string tuple2str(const Def *)
Definition tuple.cpp:56
DefVec merge(Defs, Defs)
Definition tuple.cpp:107
const Def * unflatten(const Def *def, const Def *type)
Applies the reverse transformation on a Pack / Tuple, given the original type.
Definition tuple.cpp:101
const Def * tuple_of_types(const Def *t)
Definition tuple.cpp:130
const Def * merge_tuple(const Def *def, Defs defs)
Definition tuple.cpp:119