MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
lower_matrix_mediumlevel.h
Go to the documentation of this file.
1#pragma once
2
3#include <mim/def.h>
4
5#include <mim/pass/pass.h>
6
7namespace mim::plug::matrix {
8
9/// In this step, we lower `map_reduce` operations into affine for loops making the iteration scheme explicit.
10/// Pseudo-code:
11/// ```
12/// out_matrix = init
13/// for output_indices:
14/// acc = zero
15/// for input_indices:
16/// element_[0..m] = read(matrix[0..m], indices)
17/// acc = f (acc, elements)
18/// insert (out_matrix, output_indices, acc)
19/// return out_matrix
20/// ```
21///
22/// Detailed pseudo-code:
23/// * out indices = (0,1,2, ..., n)
24/// * bounds in S
25/// * we assume that certain paramters are constant and statically known
26/// to avoid inline-metaprogramming like multiiter
27/// e.g. the number of matrizes, the dimensions, the indices
28/// ```
29/// // iterate over out indices
30/// output = init_matrix (n,S,T)
31/// for i_0 in [0, S#0)
32/// ...
33/// for i_{n-1} in [0, S#(n-1))
34/// s = zero
35/// // iterate over non-out indices
36/// for j in [0, SI#(...)]:
37/// // indices depend on the specified access
38/// // input#k#0
39/// e_0 = read (input#0#1, (i_1, i_0))
40/// ...
41/// e_(m-1) = read (input#(m-1)#1, (i_2, j))
42///
43/// s = add(s, mul (e_0, ..., e_(m-1)) )
44/// write (output, (i_0, ..., i_{n-1}), s)
45/// ```
46class LowerMatrixMediumLevel : public RWPass<LowerMatrixMediumLevel, Lam> {
47public:
49 : RWPass(man, "lower_matrix_mediumlevel") {}
50
51 /// custom rewrite function
52 /// memoized version of rewrite_
53 const Def* rewrite(const Def*) override;
54 const Def* rewrite_(const Def*);
55
56private:
57 Def2Def rewritten;
58};
59
60} // namespace mim::plug::matrix
Base class for all Defs.
Definition def.h:198
PassMan & man()
Definition pass.h:30
friend class PassMan
Definition pass.h:101
RWPass(PassMan &man, std::string_view name)
Definition pass.h:222
const Def * rewrite(const Def *) override
Custom rewrite function; memoized version of rewrite_.
The matrix Plugin
Definition matrix.h:9
DefMap< const Def * > Def2Def
Definition def.h:48