- See also
- mim::plug::matrix
Dependencies
plugin core;
plugin math;
// needed to access cps2ds
plugin direct;
plugin affine;
Types
%matrix.Mat
Matrices are n-dimensional tensors with elements of type T. They can be seen as a generalization of Coq's vector type (a container with a fixed number of elements specified on type level).
matrix = [n: Nat, S: «n; Nat», T: *] → *
matrix n S T = «S_0; «S_1; ... «S_{n-1}; T» ... »»
=> a matrix is a dependent array
// Type constructor for matrices: takes the number of dimensions `n`, the sizes `S`
// (an array of `n` Nats, one per dimension) and the element type `T`, and yields a type.
axm %matrix.Mat: [n: Nat, S: «n; Nat», T: *] → *;
Operations
%matrix.shape
Extracts the size along the i-th dimension from the type. For a dependent matrix this is a simple projection to S(i).
Normalization
- resolve shape calls at construction by replacing them with the size argument
// `::nST` names the whole (n, S, T) group so it can be passed on to `%matrix.Mat` as is.
// Takes a matrix and a dimension index `i: Idx n` and yields a Nat.
// `normalize_shape` is the normalizer attached to this axiom.
axm %matrix.shape: [n: Nat, S: «n; Nat», T: *]::nST [%matrix.Mat nST, i: Idx n] → Nat, normalize_shape;
%matrix.constMat
A constant matrix, built from a single value of the element type.
// Note: the axiom is named `constMat`.
// Takes the memory token and one value of the element type `T`;
// returns the memory token and a matrix of type `%matrix.Mat nST`.
// Presumably every element of the result equals the given value - TODO confirm.
// No normalizer is attached.
axm %matrix.constMat: [n: Nat, S: «n; Nat», T: *]::nST [%mem.M, T] → [%mem.M, %matrix.Mat nST];
%matrix.read
read (n, S, T) (mem, mat, idx) : [mem, T]
Accesses an element of the matrix. (currently: arithmetic pointer access)
Normalization:
// Takes the memory token, a matrix and a multi-index `idx` whose i-th component
// has type `Idx S#i` (i.e. is bounded by the size of dimension i).
// Returns the memory token together with a value of the element type `T`.
// `normalize_read` is the normalizer attached to this axiom.
axm %matrix.read: [n: Nat, S: «n; Nat», T: *]::nST [%mem.M, %matrix.Mat nST, idx: «i: n; Idx S#i»] → [%mem.M, T], normalize_read;
%matrix.insert
insert (dims, sizes, type) (mem, mat, idx, val) : [mem, mat]
Depending on the matrix implementation, this operation needs the mem monad. The implementation can be either a write or an array insertion.
Normalization:
- with other inserts
- with initialization
// Takes the memory token, a matrix, a multi-index `idx` (i-th component of type `Idx S#i`)
// and a value `val` of the element type `T`.
// Returns the memory token and a matrix of the same type `%matrix.Mat nST`.
// `normalize_insert` is the normalizer attached to this axiom.
axm %matrix.insert: [n: Nat, S: «n; Nat», T: *]::nST [%mem.M, %matrix.Mat nST, idx: «i: n; Idx S#i», val: T] → [%mem.M, %matrix.Mat nST], normalize_insert;
%matrix.init
A fresh matrix with uninitialized values.
// Unlike the other matrix operations, `init` is not curried: the type-level arguments
// (n, S, T) and the memory token are passed together in one argument tuple.
// Returns the memory token and a matrix of type `%matrix.Mat (n, S, T)`.
// No normalizer is attached.
axm %matrix.init: [n: Nat, S: «n; Nat», T: *, %mem.M] → [%mem.M, %matrix.Mat (n, S, T)];
High-level matrix operations
// TODO: define alias: * fst, snd, split * zip = zipWith id
// Product of two 2-dimensional floating-point matrices.
// Type-level arguments: the sizes m, k, l and the pair (p, e) that selects the
// float type `%math.F (p, e)` shared by both operands and the result.
// Operands have shapes (m, k) and (k, l); the result has shape (m, l).
// The memory token is threaded through. `normalize_prod` is the attached normalizer.
axm %matrix.prod: [m: Nat, k: Nat, l: Nat, [p: Nat, e:Nat]]
[%mem.M, %matrix.Mat (2, (m, k), %math.F (p, e)), %matrix.Mat (2, (k, l), %math.F (p, e))] →
[%mem.M, %matrix.Mat (2, (m, l), %math.F (p, e))], normalize_prod;
// Transpose of a 2-dimensional matrix with arbitrary element type `T`.
// Type-level arguments: the size pair (k, l) and `T`.
// The input has shape (k, l); the result has shape (l, k).
// The memory token is threaded through. `normalize_transpose` is the attached normalizer.
axm %matrix.transpose: [[k:Nat, l:Nat], T: *]
[%mem.M, %matrix.Mat (2, (k, l), T)] →
[%mem.M, %matrix.Mat (2, (l, k), T)], normalize_transpose;
// Reduces an n-dimensional floating-point matrix to a single scalar.
// Type-level arguments: n, the sizes S and the pair (p, e) selecting `%math.F (p, e)`.
// Returns the memory token and one value of that float type.
// No normalizer is attached to this axiom.
axm %matrix.sum: [n: Nat, S: «n; Nat», [p:Nat, e:Nat]]
[%mem.M, %matrix.Mat (n, S, %math.F (p, e))] →
[%mem.M, %math.F (p, e)];
// TODO: handle reduction case: n=0, S=[] => not empty but scalar
Our notation is inspired by einsum (with some generalizations):
The map_reduce
operation can be seen as the minimal abstraction over general iteration/control flow schemes over tensors.
map_reduce applications:
einsum(idx, MatrixIndices) = map_reduce(0,+, product, MatrixIndices)
map f M = map_reduce (0,+, f,[(idx, M)])
(TODO: get rid of reduce step if not needed with dummy values)
reduce acc f M = map_reduce (n=0) (acc, f, id,[(idx, M)])
(TODO: see index problem above) einsum application:
transpose ij→ji (einsum(,[(1,0), M]))
trace ii→
sum ij →
col sum ij → j
mat vec prod ik, k→i
mat mat prod ik, kj → ij
dot product i, i →
dot matrix ij, ij →
outer product i, j → ij
TODO: introduce dummy values (zero, add, ...) in refly and use these dummy = has correct type but can not produce code (should always be eliminated) axm %matrix.map_reduce:
// out shape depends on in shape but is complex
// First argument group: type-level description of the output matrix and of the m inputs.
[
n: Nat, S: «n; Nat», T: *, // output: number of dimensions n, sizes S, element type T
m: Nat, // number of inputs
NI: «m; Nat», // number of dimensions of each input
TI: «m; *», // element type of each input
SI: «i:m; «NI#i; Nat»» // sizes of each input (NI#i entries for input i)
][
// Second argument group: the run-time arguments.
mem: %mem.M, // memory
zero: T, // initial value of the accumulator
// TODO: propagate change: no addition but instead take acc as argument (like mlir.linarith.generic)
comb: Fn [%mem.M, T, «i: m; TI#i»] → [%mem.M, T], // inner combination: (mem, accumulator, one element per input) → (mem, new accumulator)
// out_index not needed => always ij (0 ... n) for n dimensions
input: «i: m; [«NI#i; Nat», %matrix.Mat (NI#i, SI#i, TI#i)]» // per input: its index labels (one Nat per dimension) paired with the matrix itself
] → [%mem.M, %matrix.Mat (n, S, T)],
normalize_map_reduce;
Unfolding functions
product
Follows the principle ij <- ik, kj (out[i, j] = sum_k in1[i, k] * in2[k, j]) by using multiplication as the combination function and addition as the reduction function.
// Matrix product expressed through %matrix.map_reduce:
//   out[i, j] = sum_k M[i, k] * N[k, j]
// Index labels: 0 = i and 1 = j address the output, 2 = k is the contracted index.
//   m, k, l : sizes of the operands ((m, k) and (k, l)) and of the result ((m, l))
//   pe      : pair selecting the float type `%math.F pe`
//   mem     : memory token, threaded through map_reduce
fun extern internal_mapRed_matrix_prod(m k l: Nat, pe: «2; Nat»)
                                      (mem: %mem.M,
                                       M: %matrix.Mat (2, (m, k), %math.F pe),
                                       N: %matrix.Mat (2, (k, l), %math.F pe))@tt
                                      : [%mem.M, %matrix.Mat (2, (m, l), %math.F pe)] =
    let Elem = %math.F pe;
    // Start value of the accumulator: 0.0 converted from F64 to the element type.
    let init = %math.conv.f2f pe 0.0:%math.F64;
    // One reduction step: acc + a * b, where `ab` holds the two operand elements.
    let mul_add = fn (mem_in: %mem.M, acc: Elem, ab: «2; Elem»): [%mem.M, Elem] =
        return (mem_in, %math.arith.add 0 (acc, %math.arith.mul 0 ab));
    return (
        %matrix.map_reduce
            // output: 2 dims of shape (m, l); two 2-dim inputs of shapes (m, k) and (k, l)
            (2, (m, l), Elem, 2, (2, 2), (Elem, Elem), ((m, k), (k, l)))
            // M is read at (i, k) = (0, 2), N at (k, j) = (2, 1)
            (mem, init, mul_add, (((0, 2), M), ((2, 1), N)))
    );
transpose
Transpose a matrix by iterating the indices in swapped order.
// TODO: check code for 1-matrix edge case
// TODO: would this automatically be handled by read(transpose) ?
// Transpose expressed through %matrix.map_reduce: the single input is read with
// its two index labels swapped, so out[i, j] = M[j, i].
//   (k, l) : shape of the input; the result has shape (l, k)
//   T      : element type
//   mem    : memory token, threaded through map_reduce
fun extern internal_mapRed_matrix_transpose((k l: Nat), T: *)
                                           (mem: %mem.M, M: %matrix.Mat (2, (k, l), T))@tt
                                           : [%mem.M, %matrix.Mat (2, (l, k), T)] =
    // Placeholder start value; it is never used by the combination function below.
    let placeholder = ⊥:T; // TODO: use generalized zero
    // Drops the accumulator and forwards the element that was read.
    let take_input = fn (mem_in: %mem.M, acc: T, a: T): [%mem.M, T] = return (mem_in, a);
    return (
        %matrix.map_reduce
            // output: 2 dims of shape (l, k); one 2-dim input of shape (k, l)
            (2, (l, k), T, 1, 2, T, (k, l))
            // the input is addressed with labels (1, 0), i.e. swapped relative to the output
            (mem, placeholder, take_input, ((1, 0), M))
    );
sum
Sums up all elements of a matrix and returns a scalar.
// TODO: test 0d matrix (edge cases in code)
// Sum of all elements, expressed through %matrix.map_reduce.
// The output is declared as a 1-dimensional matrix of size 1; the input's index
// labels are 1..n, so none of them coincides with the output label 0.
//   n, S : number of dimensions and sizes of the input
//   pe   : pair selecting the float type `%math.F pe`
//   mem  : memory token, threaded through map_reduce
fun extern internal_mapRed_matrix_sum(n: Nat, S: «n; Nat», pe: «2; Nat»)
                                     (mem: %mem.M, M: %matrix.Mat (n, S, %math.F pe))@tt
                                     : [%mem.M, %math.F pe] =
    let Elem = %math.F pe;
    // TODO: use generalized zero
    let zero_f64 = 0.0:(%math.F (52,11));
    let init     = %math.conv.f2f pe zero_f64;
    // Index labels i + 1 for every input dimension i.
    // should be normalized to lit tuple
    let in_labels = ‹i: n; %core.nat.add (1, %core.bitcast Nat i)›;
    // One reduction step: acc + a.
    let accumulate = fn (mem_in: %mem.M, acc: Elem, a: Elem): [%mem.M, Elem] =
        return (mem_in, %math.arith.add 0 (acc, a));
    let (mem_out, total) =
        %matrix.map_reduce
            // output: 1 dim of size 1; one n-dim input of shape S
            (1, (1), Elem, 1, n, Elem, S)
            (mem, init, accumulate, (in_labels, M));
    // The result is a one-element matrix; reinterpret it as the scalar element type.
    return (mem_out, %core.bitcast Elem total); // TODO: test this cast
Passes and Phases
Passes
// Lowering stages of the matrix plugin, declared as opaque axioms.
// `matrix_lower_phase` combines them in the order they are listed here.
axm %matrix.lower_matrix_high_level_map_reduce: %compile.Pass;
axm %matrix.lower_matrix_medium_level: %compile.Pass;
axm %matrix.internal_map_reduce_cleanup: %compile.Pass;
// NOTE: unlike the three axioms above, this one is declared as a %compile.Phase
// and is used directly as a phase in `matrix_lower_phase`.
axm %matrix.lower_matrix_low_level: %compile.Phase;
Phases
// Combined lowering phase for matrices, assembled from three sub-phases
// (presumably executed in the order listed - confirm against %compile.phases_to_phase):
//   1. one pass phase holding the high-level map_reduce and the medium-level lowering passes,
//   2. the internal map_reduce cleanup pass wrapped as a phase of its own,
//   3. the low-level lowering phase.
let matrix_lower_phase = %compile.phases_to_phase (⊤:Nat) (
    %compile.pass_phase (%compile.pass_list
        %matrix.lower_matrix_high_level_map_reduce
        %matrix.lower_matrix_medium_level),
    // TODO: only in map_red namespace
    %compile.single_pass_phase %matrix.internal_map_reduce_cleanup,
    %matrix.lower_matrix_low_level
);