MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
The tensor Plugin

See also
mim::plug::tensor

Dependencies

plugin tuple;
plugin refly;
plugin vec;

A tensor plugin

Types

Represents an algebraic Ring.

let Ring = [
T: *,
_0: T,
_1: T,
add: [T, T] → T,
mul: [T, T] → T,
];

Operations

%tensor.prod_2d

let nat_ring = (Nat, 0, 1, core.nat.add, core.nat.mul);

%tensor.element_wise

axm %tensor.map: {T: *}
→ {r: Nat, s: «r; Nat»}
→ {ni: Nat, Is: «ni; *»}
→ [app: «i: ni; Is#i» → T]
→ [is: «i: ni; «s; Is#i» »]
→ «s; T»;

%tensor.transpose

axm %tensor.transpose: {T: *}
→ {r: Nat, s: «r; Nat»}
→ [input: «s;T», permutation: «r;Idx r»]
→ « <j: r; s#(permutation#j)>; T»;

%tensor.slice

axm %tensor.slice: {T: *}
→ {r: Nat, s: «r; Nat»}
→ [input: «s;T», output_shape: «r; Nat», start_indices: «i :r; Idx (s#i)», steps: «i: r; Nat»]
→ «output_shape; T»;

%tensor.reshape

axm %tensor.reshape: {T: *}
→ {r: Nat, s: «r; Nat»}
→ [input: «s;T», output_shape: «r; Nat»]
→ «output_shape; T»;

%tensor.dot_general

Generalized dot product:

  • out_r = nb + (r1 - nc - nb) + (r2 - nc - nb)
axm %tensor.dot_general: [R: Ring]
→ {r1 r2: Nat}
→ {nc nb: Nat}
→ {s1: «r1; Nat», s2: «r2; Nat»}
→ [c1: «nc; Idx r1», c2: «nc; Idx r2», b1: «nb; Idx r1», b2: «nb; Idx r2»]
→ [«s1; R#T», «s2; R#T»]
→ let bs = ‹i: nb; %refly.check (%core.ncmp.e (s1#(b1#i), s2#(b2#i)), s1#(b1#i), "batching dims don't match")›;
let cs = ‹i: nc; %refly.check (%core.ncmp.e (s1#(c1#i), s2#(c2#i)), s1#(c1#i), "contracting dims don't match")›;
let bc_1 = %tuple.cat (b1, c1);
let bc_2 = %tuple.cat (b2, c2);
let s12_res = %tuple.cat (%vec.diff (s1, bc_1), %vec.diff (s2, bc_2));
let s_out = %tuple.cat (bs, s12_res);
«s_out; R#T», normalize_dot;

%tensor.prod_2d

//lam %tensor.prod_2d_lam {R: Ring} {m: Nat, k: Nat, l: Nat} [t1: «m, k; R#T», t2: «k, l; R#T»]
// : «m, l; R#T»
// = %tensor.dot_general R (0_2, 1_2, (), ()) (t1, t2);
axm %tensor.dot_2d_00: {R: Ring} → {m: Nat, k: Nat, l: Nat} → [t1: «m, k; R#T», t2: «k, l; R#T»] → «m, l; R#T»;
axm %tensor.dot_2d_01: {R: Ring} → {m: Nat, k: Nat, l: Nat} → [t1: «m, k; R#T», t2: «k, l; R#T»] → «m, l; R#T»;
axm %tensor.dot_2d_10: {R: Ring} → {m: Nat, k: Nat, l: Nat} → [t1: «m, k; R#T», t2: «k, l; R#T»] → «m, l; R#T»;
axm %tensor.dot_2d_11: {R: Ring} → {m: Nat, k: Nat, l: Nat} → [t1: «m, k; R#T», t2: «k, l; R#T»] → «m, l; R#T»;

%tensor.broadcast_in_dim

axm %tensor.broadcast_in_dim: {T: *}
→ {r_in r_out: Nat, s_in: «r_in; Nat», s_out: «r_out; Nat»}
→ [«s_in; T», «r_in; Idx r_in»]
→ «s_out; T»;

%tensor.reduce

ni tensors each of shape s and type Is#i in arg is

  • i: ni is a notation for i : (Idx ni)
axm %tensor.reduce: {r: Nat, s: «r; Nat»}
→ {ni: Nat, Is: «ni; *»}
→ [f: «2; «i: ni; Is#i» » → «i: ni; Is#i»]
→ [is: «i: ni; «s; Is#i» », init: «i: ni; Is#i», dims: «r; Bool»]
→ «i: ni; « <j: r; (s#j, 1)#(dims#j)>; Is#i» »;

%tensor.map_reduce

  • nis: number of inputs
  • T/R/Sis/o : respectively the type/rank/shape of the inputs/output
  • f : function to reduce over (takes an element of type To and one of each type in Tis, and returns a To)
  • init : accumulator to start f with
  • subs : for each input, for each dimension, an index to compute the output in Einstein notation
  • is : the inputs

Returns a tensor obtained by folding f following the indexes in subs

axm %tensor.map_reduce: {nis: Nat}
→ {To: *, Ro: Nat}
→ [So: «Ro; Nat»]
→ {Tis: «nis; *», Ris: «i: nis; Nat», Sis: «i:nis; «Ris#i; Nat»»}
→ [f: [To, «i: nis; Tis#i»] → To, init: To]
→ [subs: «i: nis; « Ris#i; Nat»»]
→ [is: «i: nis; «Sis#i; Tis#i» »]
→ «So; To»;