MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
The matrix Plugin

See also
mim::plug::matrix

Dependencies

plugin core;
plugin math;
// needed to access cps2ds
plugin direct;
plugin affine;

Types

%matrix.Mat

Matrices are n-dimensional tensors with elements of type T. They can be seen as a generalization of Coq's vector type (a container with a fixed number of elements specified on type level).

matrix = [n: Nat, S: «n; Nat», T: *] → *
matrix n S T = «S_0; «S_1; ... «S_{n-1}; T» ... »»

=> a matrix is a dependent array

// Type constructor for matrices. Takes one argument tuple consisting of
//   n: the number of dimensions,
//   S: an n-ary array holding one Nat (the size) per dimension,
//   T: the element type,
// and yields a type.
axm %matrix.Mat: [n: Nat, S: «n; Nat», T: *] → *;

Operations

%matrix.shape

Extracts the size along the i-th dimension from the type. For a dependent matrix this is a simple projection to S(i).

Normalization

  • resolve shape calls at construction by replacing them with the size argument
// Curried: first the (n, S, T) triple, bound as a whole under the name `nST`;
// then a matrix of type `%matrix.Mat nST` together with a dimension index `i: Idx n`.
// The result is a Nat. `normalize_shape` is the normalizer attached to this axiom.
axm %matrix.shape: [n: Nat, S: «n; Nat», T: *]::nST [%matrix.Mat nST, i: Idx n] → Nat, normalize_shape;

%matrix.constMat

a constant matrix

// Curried: first the (n, S, T) triple (bound as `nST`); then a memory token and a
// single value of the element type T. Returns a memory token and a matrix of type
// `%matrix.Mat nST`. No normalizer is attached to this axiom.
axm %matrix.constMat: [n: Nat, S: «n; Nat», T: *]::nST [%mem.M, T] → [%mem.M, %matrix.Mat nST];

%matrix.read

read _ (mat, idx) : body_type

Accesses an element of the matrix. (currently: arithmetic pointer access)

Normalization:

  • read(insert)
  • read(const)
// Curried: first the (n, S, T) triple (bound as `nST`); then a memory token, the matrix
// and an index tuple `idx` with one component per dimension, where the i-th component
// has type `Idx S#i`. Returns a memory token and an element of type T.
// `normalize_read` is the normalizer attached to this axiom.
axm %matrix.read: [n: Nat, S: «n; Nat», T: *]::nST [%mem.M, %matrix.Mat nST, idx: «i: n; Idx S#i»] → [%mem.M, T], normalize_read;

%matrix.insert

insert (dims, sizes, type) (mat, idx, val) : mat

Depending on the matrix implementation, this operation needs the mem monad. The implementation can be either a write or an array insertion.

Normalization:

  • with other inserts
  • with initialization
// Curried: first the (n, S, T) triple (bound as `nST`); then a memory token, the matrix,
// an index tuple `idx` (i-th component of type `Idx S#i`) and the value `val: T` to store.
// Returns a memory token and a matrix of the same type `%matrix.Mat nST`.
// `normalize_insert` is the normalizer attached to this axiom.
axm %matrix.insert: [n: Nat, S: «n; Nat», T: *]::nST [%mem.M, %matrix.Mat nST, idx: «i: n; Idx S#i», val: T] → [%mem.M, %matrix.Mat nST], normalize_insert;

%matrix.init

A fresh matrix with uninitialized values.

// Unlike the other element-level axioms, this one is not curried: n, S, T and the
// memory token are passed together in a single argument tuple.
// Returns a memory token and a matrix of type `%matrix.Mat (n, S, T)`.
// No normalizer is attached to this axiom.
axm %matrix.init: [n: Nat, S: «n; Nat», T: *, %mem.M] → [%mem.M, %matrix.Mat (n, S, T)];

High-level matrix operations

// TODO: define alias: * fst, snd, split * zip = zipWith id
// Product of two 2-dimensional matrices over the float type `%math.F (p, e)`.
// First argument group: the sizes m, k, l and the pair (p, e) selecting the float type.
// Second argument group: a memory token, an (m, k) matrix and a (k, l) matrix.
// Result: a memory token and an (m, l) matrix of the same element type.
// `normalize_prod` is the normalizer attached to this axiom.
axm %matrix.prod: [m: Nat, k: Nat, l: Nat, [p: Nat, e:Nat]]
[%mem.M, %matrix.Mat (2, (m, k), %math.F (p, e)), %matrix.Mat (2, (k, l), %math.F (p, e))] →
[%mem.M, %matrix.Mat (2, (m, l), %math.F (p, e))], normalize_prod;
// Transposition of a 2-dimensional matrix with arbitrary element type T.
// First argument group: the size pair (k, l) and the element type T.
// Second argument group: a memory token and a (k, l) matrix.
// Result: a memory token and an (l, k) matrix of the same element type.
// `normalize_transpose` is the normalizer attached to this axiom.
axm %matrix.transpose: [[k:Nat, l:Nat], T: *]
[%mem.M, %matrix.Mat (2, (k, l), T)] →
[%mem.M, %matrix.Mat (2, (l, k), T)], normalize_transpose;
// Reduction of an n-dimensional float matrix to a single float value.
// First argument group: the dimension count n, the sizes S and the pair (p, e)
// selecting the float type `%math.F (p, e)`.
// Second argument group: a memory token and the matrix.
// Result: a memory token and one value of type `%math.F (p, e)`.
// No normalizer is attached to this axiom.
axm %matrix.sum: [n: Nat, S: «n; Nat», [p:Nat, e:Nat]]
[%mem.M, %matrix.Mat (n, S, %math.F (p, e))] →
[%mem.M, %math.F (p, e)];
// TODO: handle reduction case: n=0, S=[] => not empty but scalar

Our notation is inspired by einsum (with some generalizations):

The map_reduce operation can be seen as the minimal abstraction over general iteration/control flow schemes over tensors.

map_reduce applications:

  • einsum(idx, MatrixIndices) = map_reduce(0,+, product, MatrixIndices)
  • map f M = map_reduce (0,+, f,[(idx, M)]) (TODO: get rid of reduce step if not needed with dummy values)
  • reduce acc f M = map_reduce (n=0) (acc, f, id,[(idx, M)]) (TODO: see index problem above)

einsum applications:

  • transpose ij→ji (einsum(,[(1,0), M]))
  • trace ii→
  • sum ij →
  • col sum ij → j
  • mat vec prod ik, k→i
  • mat mat prod ik, kj → ij
  • dot product i, i →
  • dot matrix ij, ij →
  • outer product i, j → ij

TODO: introduce dummy values (zero, add, ...) in refly and use these. A dummy has the correct type but cannot produce code (it should always be eliminated).

    // General map/reduce axiom over `m` input matrices producing one output matrix.
    // First argument group: the output's dimension count, sizes and element type, followed by
    // the number of inputs and, per input, its dimension count, element type and sizes.
    // Second argument group: memory token, initial value, combination function and the inputs.
    // Result: a memory token and a matrix of type `%matrix.Mat (n, S, T)`.
    // `normalize_map_reduce` is the normalizer attached to this axiom.
    axm %matrix.map_reduce:
    // out shape depends on in shape but is complex
    [
    n: Nat, S: «n; Nat», T: *, // out shape
    m: Nat, // number of inputs
    NI: «m; Nat», // input dimensions
    TI: «m; *», // input types
    SI: «i:m; «NI#i; Nat»» // input shapes
    ][
    mem: %mem.M, // memory
    zero: T, // initial value
    // TODO: propagate change: no addition but instead take acc as argument (like mlir.linarith.generic)
    // comb receives a memory token, a value of the output element type T and one element
    // per input (the i-th of type TI#i); it returns a memory token and a new T.
    comb: Fn [%mem.M, T, «i: m; TI#i»] → [%mem.M, T], // inner combination
    // out_index not needed => always ij (0 ... n) for n dimensions
    // Each of the m input entries pairs a tuple of NI#i Nats (one per dimension of that
    // input) with the input matrix itself.
    input: «i: m; [«NI#i; Nat», %matrix.Mat (NI#i, SI#i, TI#i)]»
    ] → [%mem.M, %matrix.Mat (n, S, T)],
    normalize_map_reduce;

Unfolding functions

product

Follow the principle ij <- ik, kj (out[i, j] = sum_k in1[i, k] * in2[k, j]) by using multiplication as combination function and addition as reduction function.

// Unfolds a matrix product into a single %matrix.map_reduce call.
//   m, k, l : sizes; M is an (m, k) matrix, N is a (k, l) matrix, the result is (m, l).
//   pe      : pair selecting the float type `%math.F pe` used for all elements.
//   mem     : incoming memory token.
// Returns the memory token and the (m, l) result matrix produced by map_reduce.
fun extern internal_mapRed_matrix_prod(m k l: Nat, pe: «2; Nat»)
                                      (mem: %mem.M,
                                       M: %matrix.Mat (2, (m, k), %math.F pe),
                                       N: %matrix.Mat (2, (k, l), %math.F pe))@tt
                                      : [%mem.M, %matrix.Mat (2, (m, l), %math.F pe)] =
    let Elem = %math.F pe;
    // The initial accumulator: an F64 zero literal converted to the element type.
    let zero = %math.conv.f2f pe 0.0:%math.F64;
    // Combination step: multiply the two input elements and add the product to the accumulator.
    let comb = fn (state: %mem.M, partial: Elem, factors: «2; Elem»): [%mem.M, Elem] =
        return (state, %math.arith.add 0 (partial, %math.arith.mul 0 factors));
    // M is addressed with indices (0, 2) and N with (2, 1): index 2 is shared by both
    // inputs and is not one of the two output dimensions (ij <- ik, kj).
    let inputs = (((0, 2), M), ((2, 1), N));
    return (
        %matrix.map_reduce
            // output: 2 dimensions of sizes (m, l); two 2-dimensional inputs of element type Elem
            (2, (m, l), Elem, 2, (2, 2), (Elem, Elem), ((m, k), (k, l)))
            (mem, zero, comb, inputs)
    );

transpose

Transpose a matrix by iterating the indices in swapped order.

// TODO: check code for 1-matrix edge case
// TODO: would this automatically be handled by read(transpose) ?
// Unfolds a transposition into a single %matrix.map_reduce call.
//   (k, l) : sizes of the input matrix M; the result has sizes (l, k).
//   T      : element type (arbitrary).
//   mem    : incoming memory token.
// Returns the memory token and the (l, k) result matrix produced by map_reduce.
fun extern internal_mapRed_matrix_transpose((k l: Nat), T: *)
                                           (mem: %mem.M, M: %matrix.Mat (2, (k, l), T))@tt
                                           : [%mem.M, %matrix.Mat (2, (l, k), T)] =
    // Bottom of type T serves as the initial value; it is never used by the combiner below.
    let unused_init = ⊥:T; // TODO: use generalized zero
    // The combiner drops the accumulator and passes the element that was read straight through.
    let pick = fn (state: %mem.M, dropped: T, elem: T): [%mem.M, T] = return (state, elem);
    // The single input is addressed with indices (1, 0), i.e. in swapped order.
    let input = ((1, 0), M);
    return (
        %matrix.map_reduce
            // output: 2 dimensions of sizes (l, k); one 2-dimensional input of sizes (k, l)
            (2, (l, k), T, 1, 2, T, (k, l))
            (mem, unused_init, pick, input)
    );

sum

Sums up all elements of a matrix and returns a scalar.

// TODO: test 0d matrix (edge cases in code)
// Unfolds the sum over all elements of a matrix into a single %matrix.map_reduce call.
//   n, S : dimension count and sizes of the input matrix M.
//   pe   : pair selecting the float type `%math.F pe` of the elements.
//   mem  : incoming memory token.
// Returns the memory token and the map_reduce result bitcast to a scalar of type `%math.F pe`.
fun extern internal_mapRed_matrix_sum(n: Nat, S: «n; Nat», pe: «2; Nat»)
                                     (mem: %mem.M, M: %matrix.Mat (n, S, %math.F pe))@tt
                                     : [%mem.M, %math.F pe] =
    let Elem = %math.F pe;
    // TODO: use generalized zero
    let zero_f64 = 0.0:(%math.F (52,11));
    let zero     = %math.conv.f2f pe zero_f64;
    // One index per input dimension: dimension position d is mapped to d + 1.
    // NOTE(review): presumably this keeps every input index distinct from the output
    // index 0 — confirm against the map_reduce lowering.
    // should be normalized to lit tuple
    let in_idx = ‹d: n; %core.nat.add (1, %core.bitcast Nat d)›;
    // Combination step: add the element that was read to the accumulator.
    let accumulate = fn (state: %mem.M, partial: Elem, elem: Elem): [%mem.M, Elem] =
        return (state, %math.arith.add 0 (partial, elem));
    let input = (in_idx, M);
    // The output is declared as a 1-dimensional matrix of size 1.
    let (`mem, packed) =
        %matrix.map_reduce
            (1, (1), Elem, 1, n, Elem, S)
            (mem, zero, accumulate, input);
    return (mem, %core.bitcast Elem packed); // TODO: test this cast

Passes and Phases

Passes

// Handles for the lowering steps of this plugin. The first three are declared with type
// %compile.Pass; the last one, %matrix.lower_matrix_low_level, is declared with type
// %compile.Phase.
axm %matrix.lower_matrix_high_level_map_reduce: %compile.Pass;
axm %matrix.lower_matrix_medium_level: %compile.Pass;
axm %matrix.internal_map_reduce_cleanup: %compile.Pass;
axm %matrix.lower_matrix_low_level: %compile.Phase;

Phases

// Combines the matrix lowering steps into one phase via %compile.phases_to_phase.
// The tuple lists three entries in this order:
//   1. the high-level map_reduce and medium-level passes, grouped with
//      %compile.pass_list and wrapped by %compile.pass_phase;
//   2. the map_reduce cleanup pass, wrapped by %compile.single_pass_phase;
//   3. the low-level lowering, which is already declared as a phase.
let matrix_lower_phase =
    %compile.phases_to_phase (⊤:Nat)
        (
            %compile.pass_phase
                (%compile.pass_list
                    %matrix.lower_matrix_high_level_map_reduce
                    %matrix.lower_matrix_medium_level),
            // TODO: only in map_red namespace
            %compile.single_pass_phase %matrix.internal_map_reduce_cleanup,
            %matrix.lower_matrix_low_level
        );