Thorin 1.9.0
The Higher ORder INtermediate representation
Loading...
Searching...
No Matches
The matrix Plugin
See also
thorin::plug::matrix

Dependencies

.plugin core;
.plugin math;
// needed to access cps2ds
.plugin direct;
.plugin affine;

Types

%matrix.Mat

Thorin matrices are n-dimensional tensors with elements of type T. They can be seen as a generalization of Coq's vector type (a container whose fixed number of elements is specified at the type level).

matrix = Π [n: .Nat, S: «n; .Nat», T: *] -> *
matrix n S T = «S_0; «S_1; ... «S_{n-1}; T» ... »»

=> a matrix is a dependent array

// Matrix type constructor: takes the number of dimensions n, a shape S with
// n sizes of type .Nat, and an element type T, and yields a type.
.ax %matrix.Mat: Π [n: .Nat, S: «n; .Nat», T: *] -> *;

Operations

%matrix.shape

Extracts the size along the i-th dimension from the type. For a dependent matrix this is a simple projection to S(i).

Normalization

  • resolve shape calls at construction by replacing them with the size argument
// Takes the matrix parameters nST = (n, S, T), then a matrix of type
// %matrix.Mat nST together with a dimension index i: .Idx n; yields a .Nat.
// Normalizer: normalize_shape.
.ax %matrix.shape: Π nST::[n: .Nat, S: «n; .Nat», T: *][%matrix.Mat nST, i: .Idx n] -> .Nat, normalize_shape;

%matrix.constMat

A constant matrix, built from a single value of the element type T.

// Takes the matrix parameters nST = (n, S, T), then a memory token and one value
// of the element type T; yields a memory token and a matrix of type %matrix.Mat nST.
// No normalizer is attached to this axiom.
.ax %matrix.constMat: Π nST::[n: .Nat, S: «n; .Nat», T: *][%mem.M, T] -> [%mem.M, %matrix.Mat nST];

%matrix.read

read (dims, sizes, type) (mem, mat, idx) : [mem, body_type]

Accesses an element of the matrix. (currently: arithmetic pointer access)

Normalization:

  • read(insert)
  • read(const)
// Takes the matrix parameters nST = (n, S, T), then a memory token, a matrix and an
// index tuple idx with one component per dimension (component i has type .Idx S#i);
// yields a memory token and a value of type T. Normalizer: normalize_read.
.ax %matrix.read: Π nST::[n: .Nat, S: «n; .Nat», T: *][%mem.M, %matrix.Mat nST, idx: «i: n; .Idx S#i»] -> [%mem.M, T], normalize_read;

%matrix.insert

insert (dims, sizes, type) (mem, mat, idx, val) : [mem, mat]

Depending on the matrix implementation, this operation needs the mem monad. The implementation can be either a write or an array insertion.

Normalization:

  • with other inserts
  • with initialization
// Takes the matrix parameters nST = (n, S, T), then a memory token, a matrix, an
// index tuple idx (component i has type .Idx S#i) and a value val: T; yields a
// memory token and a matrix of the same type. Normalizer: normalize_insert.
.ax %matrix.insert: Π nST::[n: .Nat, S: «n; .Nat», T: *][%mem.M, %matrix.Mat nST, idx: «i: n; .Idx S#i», val: T] -> [%mem.M, %matrix.Mat nST], normalize_insert;

%matrix.init

A fresh matrix with uninitialized values.

// All four arguments (n, S, T and the memory token) are passed as a single tuple.
// Yields a memory token and a matrix of type %matrix.Mat (n, S, T).
// No normalizer is attached to this axiom.
.ax %matrix.init: Π [n: .Nat, S: «n; .Nat», T: *, %mem.M] -> [%mem.M, %matrix.Mat (n, S, T)];

High-level matrix operations

// TODO: define alias: * fst, snd, split * zip = zipWith id
// Matrix product on 2-dimensional matrices over %math.F (p, e):
// a matrix of shape (m, k) and one of shape (k, l) yield a matrix of shape (m, l).
// The memory token is threaded through. Normalizer: normalize_prod.
.ax %matrix.prod: Π [m: .Nat, k: .Nat, l: .Nat, [p: .Nat, e:.Nat]]
[%mem.M, %matrix.Mat (2, (m, k), %math.F (p, e)), %matrix.Mat (2, (k, l), %math.F (p, e))]
-> [%mem.M, %matrix.Mat (2, (m, l), %math.F (p, e))], normalize_prod;
// Transposition of a 2-dimensional matrix with arbitrary element type T:
// a matrix of shape (k, l) yields a matrix of shape (l, k).
// The memory token is threaded through. Normalizer: normalize_transpose.
.ax %matrix.transpose: Π [[k:.Nat, l:.Nat], T: *]
[%mem.M, %matrix.Mat (2, (k, l), T)]
-> [%mem.M, %matrix.Mat (2, (l, k), T)], normalize_transpose;
// Reduces an n-dimensional matrix of shape S over %math.F (p, e) to a single
// value of that float type. The memory token is threaded through.
// No normalizer is attached to this axiom.
.ax %matrix.sum: Π [n: .Nat, S: «n; .Nat», [p:.Nat, e:.Nat]]
[%mem.M, %matrix.Mat (n, S, %math.F (p, e))]
-> [%mem.M, %math.F (p, e)];
// TODO: handle reduction case: n=0, S=[] => not empty but scalar

Our notation is inspired by einsum (with some generalizations):

The map_reduce operation can be seen as the minimal abstraction over general iteration/control flow schemes over tensors.

map_reduce applications:

  • einsum(idx, MatrixIndices) = map_reduce(0,+, product, MatrixIndices)
  • map f M = map_reduce (0,+, f,[(idx, M)]) (TODO: get rid of reduce step if not needed with dummy values)
  • reduce acc f M = map_reduce (n=0) (acc, f, id,[(idx, M)]) (TODO: see index problem above)

einsum applications:

  • transpose ij -> ji (einsum(,[(1,0), M]))
  • trace ii->
  • sum ij ->
  • col sum ij -> j
  • mat vec prod ik, k->i
  • mat mat prod ik, kj -> ij
  • dot product i, i ->
  • dot matrix ij, ij ->
  • outer product i, j -> ij

TODO: introduce dummy values (zero, add, ...) in refly and use these (dummy = has the correct type but cannot produce code; should always be eliminated)
    // Generic map-reduce over m input matrices that produces one output matrix of
    // type %matrix.Mat (n, S, T). The first argument group describes the output and
    // all inputs at the type level; the second group carries the runtime values.
    // Normalizer: normalize_map_reduce.
    .ax %matrix.map_reduce:
    // out shape depends on in shape but is complex
    Π [
    n: .Nat, S: «n; .Nat», T: *, // out shape
    m: .Nat, // number of inputs
    NI: «m; .Nat», // input dimensions
    TI: «m; *», // input types
    SI: «i:m; «NI#i; .Nat»» // input shapes
    ][
    mem: %mem.M, // memory
    zero: T, // initial value
    // TODO: propagate change: no addition but instead take acc as argument (like mlir.linarith.generic)
    // comb receives the memory token, a value of the output element type T and
    // one element per input (typed TI#i); it returns a memory token and a new T.
    comb: .Fn [%mem.M, T, «i: m; TI#i»] -> [%mem.M, T], // inner combination
    // out_index not needed => always ij (0 ... n) for n dimensions
    // Each input pairs a tuple of NI#i .Nat values (one per input dimension)
    // with the input matrix itself.
    input: «i: m; [«NI#i; .Nat», %matrix.Mat (NI#i, SI#i, TI#i)]»
    ] -> [%mem.M, %matrix.Mat (n, S, T)],
    normalize_map_reduce;

Unfolding functions

product

Follows the principle ij <- ik, kj (out[i, j] = sum_k in1[i, k] * in2[k, j]) by using multiplication as the combination function and addition as the reduction function.

// Unfolds the matrix product into a single %matrix.map_reduce call.
// Index scheme ij <- ik, kj: the output uses indices (0, 1), M is read at
// (0, 2) and N at (2, 1), so index 2 is the one shared by both inputs.
.fun .extern internal_mapRed_matrix_prod!(m k l: .Nat, pe: «2; .Nat»)
    !(mem: %mem.M,
      M: %matrix.Mat (2, (m, k), %math.F pe),
      N: %matrix.Mat (2, (k, l), %math.F pe))
    : [%mem.M, %matrix.Mat (2, (m, l), %math.F pe)] =
    .let Elem  = %math.F pe;
    // Initial accumulator: 0.0 given as F64 and converted to the element type.
    .let start = %math.conv.f2f pe 0.0:%math.F64;
    return (
        %matrix.map_reduce
            (
                2, (m, l), Elem,   // output: 2 dimensions, shape (m, l)
                2,                 // number of inputs
                (2, 2),            // dimensions of each input
                (Elem, Elem),      // element types of the inputs
                ((m, k), (k, l))   // shapes of the inputs
            )
            (
                mem,
                start,
                // Combination step: partial + (product of the two input elements).
                .fn (mem: %mem.M, partial: Elem, factors: «2; Elem»): [%mem.M, Elem] =
                    return (mem, %math.arith.add 0 (partial, %math.arith.mul 0 factors)),
                (
                    ((0, 2), M),
                    ((2, 1), N)
                )
            )
    );

transpose

Transpose a matrix by iterating the indices in swapped order.

// TODO: check code for 1-matrix edge case
// TODO: would this automatically be handled by read(transpose) ?
// Unfolds transposition into a %matrix.map_reduce call with a single input:
// the output has shape (l, k) and the input M is read with its indices swapped.
.fun .extern internal_mapRed_matrix_transpose!((k l: .Nat), T: *)
    !(mem: %mem.M, M: %matrix.Mat (2, (k, l), T))
    : [%mem.M, %matrix.Mat (2, (l, k), T)] =
    .let dummy = ⊥:T; // TODO: use generalized zero
    return (
        %matrix.map_reduce
            (
                2, (l, k), T,   // output: 2 dimensions, shape (l, k)
                1,              // number of inputs
                2,              // dimensions of the input
                T,              // element type of the input
                (k, l)          // shape of the input
            )
            (
                mem,
                dummy,
                // The accumulator is not used; the element read from M is passed through.
                .fn (mem: %mem.M, unused elem: T): [%mem.M, T] = return (mem, elem),
                ((1, 0), M)
            )
    );

sum

Sums up all elements of a matrix and returns a scalar.

// TODO: test 0d matrix (edge cases in code)
// Unfolds the sum over all elements into a %matrix.map_reduce call whose output
// is a 1-dimensional matrix of size 1; that result is then bitcast to the scalar
// element type.
.fun .extern internal_mapRed_matrix_sum!(n: .Nat, S: «n; .Nat», pe: «2; .Nat»)
    !(mem: %mem.M, M: %matrix.Mat (n, S, %math.F pe))
    : [%mem.M, %math.F pe] =
    .let Elem = %math.F pe;
    // TODO: use generalized zero
    .let f64_zero = 0.0:(%math.F (52,11));
    .let start    = %math.conv.f2f pe f64_zero;
    // Input dimension i is given index i + 1.
    // should be normalized to lit tuple
    .let in_idxs = ‹i: n; %core.nat.add (1, %core.bitcast .Nat i)›;
    .let (`mem, total) = %matrix.map_reduce
        (
            1, (1), Elem,   // output: 1 dimension of size 1
            1,              // number of inputs
            n,              // dimensions of the input
            Elem,           // element type of the input
            S               // shape of the input
        )
        (
            mem,
            start,
            // Combination step: add the current element to the running total.
            .fn (mem: %mem.M, partial: Elem, elem: Elem): [%mem.M, Elem]
                = return (mem, %math.arith.add 0 (partial, elem)),
            (
                (in_idxs, M)
            )
        );
    return (mem, %core.bitcast Elem total); // TODO: test this cast

Passes and Phases

Passes

// Lowering stages of the matrix plugin, declared as opaque compile-plugin values.
// The first three are passes (%compile.Pass).
.ax %matrix.lower_matrix_high_level_map_reduce: %compile.Pass;
.ax %matrix.lower_matrix_medium_level: %compile.Pass;
.ax %matrix.internal_map_reduce_cleanup: %compile.Pass;
// Note: this one is declared as a phase (%compile.Phase), not a pass.
.ax %matrix.lower_matrix_low_level: %compile.Phase;

Phases

// Combines the matrix lowering steps into one phase via %compile.phases_to_phase.
// The tuple passed to it has three entries, in this order:
//   1. a pass phase built from the pass list
//      [lower_matrix_high_level_map_reduce, lower_matrix_medium_level],
//   2. a single-pass phase wrapping internal_map_reduce_cleanup,
//   3. lower_matrix_low_level, which is already declared as a %compile.Phase.
.let matrix_lower_phase = {
%compile.phases_to_phase (⊤:.Nat)
(
(%compile.pass_phase (%compile.pass_list
%matrix.lower_matrix_high_level_map_reduce
%matrix.lower_matrix_medium_level
)),
// TODO: only in map_red namespace
%compile.single_pass_phase %matrix.internal_map_reduce_cleanup,
%matrix.lower_matrix_low_level
)
};