MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
normalizers.cpp
Go to the documentation of this file.
1#include <iostream>
2
3#include "mim/axiom.h"
4#include "mim/world.h"
5
7
8// TODO: combine map_reduce calls
9
10namespace mim::plug::matrix {
11
12/// Normalizer for read opertions
13/// - read(constMat v) -> v
14/// - read(insert m v i, i) -> v (TODO: check with map_reduce)
15/// - read(insert m v i, j) -> read(m, i) if i <> j (TODO: wanted? useful?)
16/// - read(transpose m, (i,j)) -> read(m, (j,i)) (TODO: check for map_reduce)
17/// - read(product m1 m2, (i,j)) -> ... (TODO: check with map_reduce)
18/// - read (map_reduce f) idx = loop f idx (TODO: implement => use inner loop from lowering phase)
19Ref normalize_read(Ref type, Ref callee, Ref arg) {
20 auto& world = type->world();
21 auto [mem, mat, index] = arg->projs<3>();
22
23 world.DLOG("normalizing read: mat: {}\n", mat);
24
25 if (auto mex = mat->isa<Extract>()) {
26 world.DLOG(" extract: {}\n", mex);
27 auto ccall = mex->tuple();
28 world.DLOG(" ex_mat: {}\n", ccall);
29 if (auto mcm = match<constMat>(ccall)) {
30 world.DLOG(" const mat: {}\n", mcm);
31 auto [cmem, v] = mcm->arg()->projs<2>();
32 return world.tuple({mem, v});
33 }
34 }
35
36 return world.raw_app(type, callee, arg);
37}
38
39/// Normalizer for write operations
40/// TODO: implement
41Ref normalize_insert(Ref type, Ref callee, Ref arg) {
42 auto& world = type->world();
43 // auto [mat, index, val] = arg->projs<3>();
44
45 // same as read
46 // TODO:
47
48 return world.raw_app(type, callee, arg);
49}
50
// Planned normalization rules for transpose operations (see normalize_transpose):
// - transpose (constMat v) -> constMat v (TODO: implement)
// - transpose (insert m v (i,j)) -> insert (transpose m) v (j,i) (TODO: implement, maybe other way around?)
// - transpose (transpose m) -> m (TODO: implement)
55
56/// - shape (\@mat n (k1,k2,...,kn) i) -> (k1,k2,...,kn)\#i (TODO: implement)
57Ref normalize_shape(Ref type, Ref callee, Ref arg) {
58 auto& world = type->world();
59 auto [mat, index] = arg->projs<2>();
60 auto [dims, sizes, body_type] = match<Mat, false>(mat->type())->args<3>();
61 (void)callee;
62
63 return world.extract(sizes, index);
64}
65
66/// Matrix normalizer for product on two-dimensional matrices
67/// - product (constMat v1, constMat v2) -> constMat v1 * v2 * dim (TODO: implement)
68/// - product (constMat v, m) -> ... (TODO: implement)
69/// - product (m, constMat v) -> ... (TODO: implement)
70/// - product (id, m) -> m (TODO: check)
71/// - product (m, id) -> m
72
73/// - map(constMat v, f) -> constMat f(v) (TODO: implement)
74/// - map f (map g m) -> map (f . g) m (TODO: implement)
75/// - map f (zipWith g m1 m2) -> zipWith (f . g) m1 m2 (TODO: implement)
77 auto max_idx = init;
78
79 for (auto inp : inputs) {
80 auto [indices, mat] = inp->projs<2>();
81 auto indice_count = Lit::isa(indices->arity());
82 if (!indice_count) return -1;
83 for (auto idx : indices->projs()) {
84 auto idx_val = Lit::isa(idx);
85 if (!idx_val) return -1;
86 if (idx_val > max_idx) max_idx = idx_val.value();
87 }
88 }
89
90 return max_idx;
91}
92
93/// map_reduce normalizers
94/// - TODO: map_reduce (..., ((idx,map_reduce([out, ]...), ...))) -> unify idx, out (out is implicit), name vars apart
95/// requires: same reduction, distributive reduction
96/// we assume distributivity of the reduction function
97Ref normalize_map_reduce(Ref type, Ref callee, Ref arg) {
98 auto& world = type->world();
99
100 // // TODO: now that map_reduce returns a mem needs to check if extract from map_reduce
101 return world.raw_app(type, callee, arg);
102}
103
104Ref normalize_prod(Ref type, Ref callee, Ref arg) {
105 auto& world = type->world();
106 return world.raw_app(type, callee, arg);
107}
108
109Ref normalize_transpose(Ref type, Ref callee, Ref arg) {
110 auto& world = type->world();
111 return world.raw_app(type, callee, arg);
112}
113
115
116} // namespace mim::plug::matrix
World & world() const
Definition def.cpp:415
auto projs(F f) const
Splits this Def via Def::projections into an Array (if A == -1_n) or std::array (otherwise).
Definition def.h:367
Extracts from a Sigma or Array-typed Extract::tuple the element at position Extract::index.
Definition tuple.h:152
static std::optional< T > isa(Ref def)
Definition def.h:763
Helper class to retrieve Infer::arg if present.
Definition def.h:86
This is a thin wrapper for std::span<T, N> with the following additional features:
Definition span.h:28
Ref raw_app(Ref type, Ref callee, Ref arg)
Definition world.cpp:217
#define MIM_matrix_NORMALIZER_IMPL
Definition autogen.h:118
The matrix Plugin
Definition matrix.h:9
Ref normalize_shape(Ref type, Ref callee, Ref arg)
Normalizer for shape queries: extracts the requested size from the matrix type's size tuple.
Ref normalize_insert(Ref type, Ref callee, Ref arg)
Normalizer for write operations TODO: implement.
Ref normalize_prod(Ref type, Ref callee, Ref arg)
Ref normalize_map_reduce(Ref type, Ref callee, Ref arg)
map_reduce normalizers
Ref normalize_read(Ref type, Ref callee, Ref arg)
Normalizer for read operations.
u64 get_max_index(u64 init, Defs inputs)
Returns the largest literal index found in the inputs' index tuples (or init, if larger); returns u64(-1) if any index arity or index is not a literal.
Ref normalize_transpose(Ref type, Ref callee, Ref arg)
The mem Plugin
Definition mem.h:11
auto match(Ref def)
Definition axiom.h:112
uint64_t u64
Definition types.h:34