- See also
- thorin::plug::matrix
Dependencies
.plugin core;
.plugin math;
// needed to access cps2ds
.plugin direct;
.plugin affine;
Types
%matrix.Mat
Thorin matrices are n-dimensional tensors with elements of type T. They can be seen as a generalization of Coq's vector type (a container with a fixed number of elements specified on type level).
matrix = Π [n: .Nat, S: «n; .Nat», T: *] -> *
matrix n S T = «S_0; «S_1; ... «S_{n-1}; T» ... »»
=> a matrix is a dependent array
// Matrix type constructor.
//   n: number of dimensions
//   S: n-ary array of `.Nat` holding the size of each dimension
//   T: element type
// Yields a type (`*`).
.ax %matrix.Mat: Π [n: .Nat, S: «n; .Nat», T: *] -> *;
Operations
%matrix.shape
Extracts the size along the i-th dimension from the type. For a dependent matrix this is a simple projection to S(i).
Normalization
- resolve shape calls at construction by replacing them with the size argument
// Size query.
// First argument group `nST = (n, S, T)`: the matrix type parameters.
// Second argument group: a matrix of type `%matrix.Mat nST` and a dimension index `i: .Idx n`.
// Result: a `.Nat`.
// Normalizer: `normalize_shape`.
.ax %matrix.shape: Π nST::[n: .Nat, S: «n; .Nat», T: *][%matrix.Mat nST, i: .Idx n] -> .Nat, normalize_shape;
%matrix.constMat
A constant matrix.
// Constant matrix constructor.
// First argument group `nST = (n, S, T)`: the matrix type parameters.
// Second argument group: a memory token and one value of the element type `T`.
// Result: a memory token and a matrix of type `%matrix.Mat nST`.
// NOTE(review): presumably every element is set to the given value -- confirm in the lowering.
// No normalizer is attached.
.ax %matrix.constMat: Π nST::[n: .Nat, S: «n; .Nat», T: *][%mem.M, T] -> [%mem.M, %matrix.Mat nST];
%matrix.read
read _ (mat, idx) : body_type
Accesses an element of the matrix. (currently: arithmetic pointer access)
Normalization:
// Element read.
// First argument group `nST = (n, S, T)`: the matrix type parameters.
// Second argument group:
//   - a memory token,
//   - the matrix to read from,
//   - idx: one index per dimension, where component `i` has type `.Idx S#i`.
// Result: a memory token and the element of type `T`.
// Normalizer: `normalize_read`.
.ax %matrix.read: Π nST::[n: .Nat, S: «n; .Nat», T: *][%mem.M, %matrix.Mat nST, idx: «i: n; .Idx S#i»] -> [%mem.M, T], normalize_read;
%matrix.insert
insert (dims, sizes, type) (mat, idx, val) : mat
Depending on the matrix implementation, this operation needs the mem monad. It can be implemented either as a write or as an array insertion.
Normalization:
- with other inserts
- with initialization
// Element insertion.
// First argument group `nST = (n, S, T)`: the matrix type parameters.
// Second argument group:
//   - a memory token,
//   - the matrix to insert into,
//   - idx: one index per dimension, where component `i` has type `.Idx S#i`,
//   - val: the value of type `T` to store.
// Result: a memory token and a matrix of the same type `%matrix.Mat nST`.
// Normalizer: `normalize_insert`.
.ax %matrix.insert: Π nST::[n: .Nat, S: «n; .Nat», T: *][%mem.M, %matrix.Mat nST, idx: «i: n; .Idx S#i», val: T] -> [%mem.M, %matrix.Mat nST], normalize_insert;
%matrix.init
A fresh matrix with uninitialized values.
// Fresh matrix.
// Unlike the element-wise operations, this axiom takes a single argument tuple:
// the type parameters `n`, `S`, `T` together with the memory token.
// Result: a memory token and a matrix of type `%matrix.Mat (n, S, T)`.
// No normalizer is attached.
.ax %matrix.init: Π [n: .Nat, S: «n; .Nat», T: *, %mem.M] -> [%mem.M, %matrix.Mat (n, S, T)];
High-level matrix operations
// TODO: define alias: * fst, snd, split * zip = zipWith id
// Matrix product on 2-dimensional float matrices.
// Type arguments: sizes `m`, `k`, `l` and the pair `(p, e)` passed to `%math.F`.
// Value arguments: a memory token, an (m, k) matrix and a (k, l) matrix.
// Result: a memory token and an (m, l) matrix with the same element type.
// Normalizer: `normalize_prod`.
.ax %matrix.prod: Π [m: .Nat, k: .Nat, l: .Nat, [p: .Nat, e:.Nat]]
[%mem.M, %matrix.Mat (2, (m, k), %math.F (p, e)), %matrix.Mat (2, (k, l), %math.F (p, e))]
-> [%mem.M, %matrix.Mat (2, (m, l), %math.F (p, e))], normalize_prod;
// Transposition of a 2-dimensional matrix.
// Type arguments: the size pair `(k, l)` and the element type `T`.
// Value arguments: a memory token and a (k, l) matrix.
// Result: a memory token and an (l, k) matrix with the same element type.
// Normalizer: `normalize_transpose`.
.ax %matrix.transpose: Π [[k:.Nat, l:.Nat], T: *]
[%mem.M, %matrix.Mat (2, (k, l), T)]
-> [%mem.M, %matrix.Mat (2, (l, k), T)], normalize_transpose;
// Reduction of an n-dimensional float matrix to a scalar.
// Type arguments: dimension count `n`, sizes `S` and the pair `(p, e)` passed to `%math.F`.
// Value arguments: a memory token and a matrix of type `%matrix.Mat (n, S, %math.F (p, e))`.
// Result: a memory token and a single `%math.F (p, e)` value.
// No normalizer is attached.
.ax %matrix.sum: Π [n: .Nat, S: «n; .Nat», [p:.Nat, e:.Nat]]
[%mem.M, %matrix.Mat (n, S, %math.F (p, e))]
-> [%mem.M, %math.F (p, e)];
// TODO: handle reduction case: n=0, S=[] => not empty but scalar
Our notation is inspired by einsum (with some generalizations):
The map_reduce
operation can be seen as the minimal abstraction over general iteration/control flow schemes over tensors.
map_reduce applications:
einsum(idx, MatrixIndices) = map_reduce(0,+, product, MatrixIndices)
map f M = map_reduce (0,+, f,[(idx, M)])
(TODO: get rid of reduce step if not needed with dummy values)
reduce acc f M = map_reduce (n=0) (acc, f, id,[(idx, M)])
(TODO: see index problem above) einsum application:
transpose ij->ji (einsum(,[(1,0), M]))
trace ii->
sum ij ->
col sum ij -> j
mat vec prod ik, k->i
mat mat prod ik, kj -> ij
dot product i, i ->
dot matrix ij, ij ->
outer product i, j -> ij
// TODO: introduce dummy values (zero, add, ...) in refly and use these
// dummy = has correct type but can not produce code (should always be eliminated)
//
// General map/reduce over matrices.
//
// Type-level arguments (first group):
//   n, S, T : dimension count, sizes and element type of the output matrix
//   m       : number of input matrices
//   NI      : for each input, its number of dimensions
//   TI      : for each input, its element type
//   SI      : for each input `i`, its `NI#i` sizes
//
// Value-level arguments (second group):
//   mem   : memory token
//   zero  : initial accumulator value of the output element type
//   comb  : function receiving a memory token, the current accumulator and one
//           element per input; returns a memory token and the new accumulator
//   input : for each input `i`, a pair of `NI#i` naturals and the matrix itself
//           NOTE(review): the naturals presumably name the iteration index bound to
//           each input dimension (cf. `((0, 2), M)` for `ik`) -- confirm in the lowering.
//
// Result: a memory token and the output matrix `%matrix.Mat (n, S, T)`.
// Normalizer: `normalize_map_reduce`.
.ax %matrix.map_reduce:
    // out shape depends on in shape but is complex
    Π [
        n: .Nat, S: «n; .Nat», T: *, // out shape
        m: .Nat, // number of inputs
        NI: «m; .Nat», // input dimensions
        TI: «m; *», // input types
        SI: «i:m; «NI#i; .Nat»» // input shapes
    ][
        mem: %mem.M, // memory
        zero: T, // initial value
        // TODO: propagate change: no addition but instead take acc as argument (like mlir.linarith.generic)
        comb: .Fn [%mem.M, T, «i: m; TI#i»] -> [%mem.M, T], // inner combination
        // out_index not needed => always ij (0 ... n) for n dimensions
        input: «i: m; [«NI#i; .Nat», %matrix.Mat (NI#i, SI#i, TI#i)]»
    ] -> [%mem.M, %matrix.Mat (n, S, T)],
    normalize_map_reduce;
Unfolding functions
product
Follow the principle ij <- ik, kj
(out[i, j] = sum_k in1[i, k] * in2[k, j]
) by using multiplication as the combination function and addition as the reduction function.
// Matrix product expressed through %matrix.map_reduce (ij <- ik, kj).
//   m, k, l : sizes of the operands; M is (m, k), N is (k, l), the result is (m, l)
//   pe      : the pair handed to %math.F to form the element type
//   mem     : memory token threaded through the computation
// Returns the memory token and the (m, l) product matrix.
.fun .extern internal_mapRed_matrix_prod!(m k l: .Nat, pe: «2; .Nat»)
        !(mem: %mem.M,
          M: %matrix.Mat (2, (m, k), %math.F pe),
          N: %matrix.Mat (2, (k, l), %math.F pe))
        : [%mem.M, %matrix.Mat (2, (m, l), %math.F pe)] =
    .let Elem = %math.F pe;
    // start value of the accumulator: an F64 zero converted to the element type
    .let acc_init = %math.conv.f2f pe 0.0:%math.F64;
    // map_reduce instantiated for a 2-d (m, l) output and two 2-d inputs
    // of shapes (m, k) and (k, l), all with element type Elem
    .let prod_reduce = %matrix.map_reduce
        (2, (m, l), Elem, 2, (2, 2), (Elem, Elem), ((m, k), (k, l)));
    return (
        prod_reduce (
            mem,
            acc_init,
            // accumulate: partial + (factors#0 * factors#1)
            .fn (mem: %mem.M, partial: Elem, factors: «2; Elem»): [%mem.M, Elem] =
                return (mem, %math.arith.add 0 (partial, %math.arith.mul 0 factors)),
            // M is addressed with indices (0, 2), N with indices (2, 1)
            (((0, 2), M), ((2, 1), N))
        )
    );
transpose
Transpose a matrix by iterating the indices in swapped order.
// TODO: check code for 1-matrix edge case
// TODO: would this automatically be handled by read(transpose) ?
// Transposition expressed through %matrix.map_reduce.
//   (k, l) : sizes of the input matrix M; the result has sizes (l, k)
//   T      : element type
//   mem    : memory token threaded through the computation
// Returns the memory token and the (l, k) matrix.
.fun .extern internal_mapRed_matrix_transpose!((k l: .Nat), T: *)
        !(mem: %mem.M, M: %matrix.Mat (2, (k, l), T))
        : [%mem.M, %matrix.Mat (2, (l, k), T)] =
    .let init = ⊥:T; // TODO: use generalized zero
    // map_reduce instantiated for a 2-d (l, k) output and a single 2-d input of shape (k, l)
    .let swap_reduce = %matrix.map_reduce (2, (l, k), T, 1, 2, T, (k, l));
    return (
        swap_reduce (
            mem,
            init,
            // The accumulator is dropped; the element that was read is passed through as is.
            .fn (mem: %mem.M, acc elem: T): [%mem.M, T] = return (mem, elem),
            // the input is addressed with the indices in swapped order
            ((1, 0), M)
        )
    );
sum
Sums up all elements of a matrix and returns a scalar.
// TODO: test 0d matrix (edge cases in code)
// Sum of all elements expressed through %matrix.map_reduce.
//   n, S : dimension count and sizes of the input matrix M
//   pe   : the pair handed to %math.F to form the element type
//   mem  : memory token threaded through the computation
// Returns the memory token and the scalar sum.
.fun .extern internal_mapRed_matrix_sum!(n: .Nat, S: «n; .Nat», pe: «2; .Nat»)
        !(mem: %mem.M, M: %matrix.Mat (n, S, %math.F pe))
        : [%mem.M, %math.F pe] =
    .let Elem = %math.F pe;
    // TODO: use generalized zero
    .let f64_zero = 0.0:(%math.F (52,11));
    .let acc_init = %math.conv.f2f pe f64_zero;
    // input indices 1 .. n, one per input dimension
    // should be normalized to lit tuple
    .let in_idxs = ‹i: n; %core.nat.add (1, %core.bitcast .Nat i)›;
    // map_reduce instantiated for a 1-d output of size 1 and a single n-d input of shape S
    .let sum_reduce = %matrix.map_reduce (1, (1), Elem, 1, n, Elem, S);
    .let (`mem, total) = sum_reduce (
        mem,
        acc_init,
        .fn (mem: %mem.M, acc: Elem, elem: Elem): [%mem.M, Elem] =
            return (mem, %math.arith.add 0 (acc, elem)),
        ((in_idxs, M))
    );
    // the 1-element result matrix is reinterpreted as a scalar
    return (mem, %core.bitcast Elem total); // TODO: test this cast
Passes and Phases
Passes
// Compiler-pipeline entry points of the matrix plugin.
// The first three are declared as `%compile.Pass`.
// NOTE(review): what each one lowers is only suggested by its name; the implementations are not visible here.
.ax %matrix.lower_matrix_high_level_map_reduce: %compile.Pass;
.ax %matrix.lower_matrix_medium_level: %compile.Pass;
.ax %matrix.internal_map_reduce_cleanup: %compile.Pass;
// Declared as a `%compile.Phase`, not a `%compile.Pass`, although it is listed with the passes.
.ax %matrix.lower_matrix_low_level: %compile.Phase;
Phases
// Combined lowering phase for matrices, built with `%compile.phases_to_phase`
// from three entries, listed in this order:
//   1. a pass phase made of the pass list
//      (lower_matrix_high_level_map_reduce, lower_matrix_medium_level),
//   2. internal_map_reduce_cleanup wrapped as a single-pass phase,
//   3. lower_matrix_low_level, which already is a phase.
// NOTE(review): `⊤:.Nat` presumably leaves the phase count to be inferred -- confirm.
.let matrix_lower_phase = {
%compile.phases_to_phase (⊤:.Nat)
(
(%compile.pass_phase (%compile.pass_list
%matrix.lower_matrix_high_level_map_reduce
%matrix.lower_matrix_medium_level
)),
// TODO: only in map_red namespace
%compile.single_pass_phase %matrix.internal_map_reduce_cleanup,
%matrix.lower_matrix_low_level
)
};