MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
normalizers.cpp
Go to the documentation of this file.
1#include <iostream>
2
3#include "mim/axiom.h"
4#include "mim/world.h"
5
7
8// TODO: combine map_reduce calls
9
10namespace mim::plug::matrix {
11
12/// Normalizer for read opertions
13/// - read(constMat v) -> v
14/// - read(insert m v i, i) -> v (TODO: check with map_reduce)
15/// - read(insert m v i, j) -> read(m, i) if i <> j (TODO: wanted? useful?)
16/// - read(transpose m, (i,j)) -> read(m, (j,i)) (TODO: check for map_reduce)
17/// - read(product m1 m2, (i,j)) -> ... (TODO: check with map_reduce)
18/// - read (map_reduce f) idx = loop f idx (TODO: implement => use inner loop from lowering phase)
19const Def* normalize_read(const Def* type, const Def*, const Def* arg) {
20 auto& world = type->world();
21 auto [mem, mat, index] = arg->projs<3>();
22
23 world.DLOG("normalizing read: mat: {}\n", mat);
24
25 if (auto mex = mat->isa<Extract>()) {
26 world.DLOG(" extract: {}\n", mex);
27 auto ccall = mex->tuple();
28 world.DLOG(" ex_mat: {}\n", ccall);
29 if (auto mcm = match<constMat>(ccall)) {
30 world.DLOG(" const mat: {}\n", mcm);
31 auto [cmem, v] = mcm->arg()->projs<2>();
32 return world.tuple({mem, v});
33 }
34 }
35
36 return {};
37}
38
39/// Normalizer for write operations
40/// TODO: implement
41const Def* normalize_insert(const Def*, const Def*, const Def*) { return {}; }
42
43/// Normalizer for transpose operations
44/// - transpose (constMat v) -> cosntMat v (TODO: implement)
45/// - transpose (insert m v (i,j)) -> insert (transpose m) v (j,i) (TODO: implement, maybe other way around?)
46/// - transpose (tranpose m) -> m (TODO: implement)
47
48/// - shape (\@mat n (k1,k2,...,kn) i) -> (k1,k2,...,kn)\#i (TODO: implement)
49const Def* normalize_shape(const Def* type, const Def* callee, const Def* arg) {
50 auto& world = type->world();
51 auto [mat, index] = arg->projs<2>();
52 auto [dims, sizes, body_type] = match<Mat, false>(mat->type())->args<3>();
53 (void)callee;
54
55 return world.extract(sizes, index);
56}
57
58/// Matrix normalizer for product on two-dimensional matrices
59/// - product (constMat v1, constMat v2) -> constMat v1 * v2 * dim (TODO: implement)
60/// - product (constMat v, m) -> ... (TODO: implement)
61/// - product (m, constMat v) -> ... (TODO: implement)
62/// - product (id, m) -> m (TODO: check)
63/// - product (m, id) -> m
64
65/// - map(constMat v, f) -> constMat f(v) (TODO: implement)
66/// - map f (map g m) -> map (f . g) m (TODO: implement)
67/// - map f (zipWith g m1 m2) -> zipWith (f . g) m1 m2 (TODO: implement)
69 auto max_idx = init;
70
71 for (auto inp : inputs) {
72 auto [indices, mat] = inp->projs<2>();
73 auto indice_count = Lit::isa(indices->arity());
74 if (!indice_count) return -1;
75 for (auto idx : indices->projs()) {
76 auto idx_val = Lit::isa(idx);
77 if (!idx_val) return -1;
78 if (idx_val > max_idx) max_idx = idx_val.value();
79 }
80 }
81
82 return max_idx;
83}
84
85/// map_reduce normalizers
86/// - TODO: map_reduce (..., ((idx,map_reduce([out, ]...), ...))) -> unify idx, out (out is implicit), name vars apart
87/// requires: same reduction, distributive reduction
88/// we assume distributivity of the reduction function
89const Def* normalize_map_reduce(const Def*, const Def*, const Def*) { return {}; }
90const Def* normalize_prod(const Def*, const Def*, const Def*) { return {}; }
91const Def* normalize_transpose(const Def*, const Def*, const Def*) { return {}; }
92
94
95} // namespace mim::plug::matrix
Base class for all Defs.
Definition def.h:198
World & world() const noexcept
Definition def.cpp:413
auto projs(F f) const
Splits this Def via Def::projections into an Array (if A == std::dynamic_extent) or std::array (other...
Definition def.h:345
Extracts from a Sigma or Array-typed Extract::tuple the element at position Extract::index.
Definition tuple.h:155
static std::optional< T > isa(const Def *def)
Definition def.h:730
#define MIM_matrix_NORMALIZER_IMPL
Definition autogen.h:118
The matrix Plugin
Definition matrix.h:9
const Def * normalize_insert(const Def *, const Def *, const Def *)
Normalizer for write operations TODO: implement.
const Def * normalize_map_reduce(const Def *, const Def *, const Def *)
map_reduce normalizers
const Def * normalize_shape(const Def *type, const Def *callee, const Def *arg)
Normalizer for shape queries: returns the size of the requested matrix dimension.
const Def * normalize_prod(const Def *, const Def *, const Def *)
u64 get_max_index(u64 init, Defs inputs)
Returns the largest literal index among the inputs, starting from init; returns -1 (as u64) if any index is not a literal.
const Def * normalize_read(const Def *type, const Def *, const Def *arg)
Normalizer for read operations.
const Def * normalize_transpose(const Def *, const Def *, const Def *)
The mem Plugin
Definition mem.h:11
View< const Def * > Defs
Definition def.h:49
auto match(const Def *def)
Definition axiom.h:112
uint64_t u64
Definition types.h:34