- See also
- mim::plug::tensor
Dependencies
plugin tuple;
plugin refly;
plugin vec;
A tensor plugin
Types
Represents an algebraic Ring.
let Ring = [
T: *,
_0: T,
_1: T,
add: [T, T] → T,
mul: [T, T] → T,
];
Operations
%tensor.prod_2d
let nat_ring = (Nat, 0, 1, core.nat.add, core.nat.mul);
%tensor.element_wise
axm %tensor.map: {T: *}
→ {r: Nat, s: «r; Nat»}
→ {ni: Nat, Is: «ni; *»}
→ [app: «i: ni; Is#i» → T]
→ [is: «i: ni; «s; Is#i» »]
→ «s; T»;
%tensor.transpose
axm %tensor.transpose: {T: *}
→ {r: Nat, s: «r; Nat»}
→ [input: «s;T», permutation: «r;Idx r»]
→ « <j: r; s#(permutation#j)>; T»;
%tensor.slice
axm %tensor.slice: {T: *}
→ {r: Nat, s: «r; Nat»}
→ [input: «s;T», output_shape: «r; Nat», start_indices: «i :r; Idx (s#i)», steps: «i: r; Nat»]
→ «output_shape; T»;
%tensor.reshape
axm %tensor.reshape: {T: *}
→ {r: Nat, s: «r; Nat»}
→ [input: «s;T», output_shape: «r; Nat»]
→ «output_shape; T»;
%tensor.dot_general
Generalized dot product:
out_r = nb + (r1 - nc - nb) + (r2 - nc - nb)
axm %tensor.dot_general: [R: Ring]
→ {r1 r2: Nat}
→ {nc nb: Nat}
→ {s1: «r1; Nat», s2: «r2; Nat»}
→ [c1: «nc; Idx r1», c2: «nc; Idx r2», b1: «nb; Idx r1», b2: «nb; Idx r2»]
→ [«s1; R#T», «s2; R#T»]
→ let bs = ‹i: nb; %refly.check (%core.ncmp.e (s1#(b1#i), s2#(b2#i)), s1#(b1#i), "batching dims don't match")›;
let cs = ‹i: nc; %refly.check (%core.ncmp.e (s1#(c1#i), s2#(c2#i)), s1#(c1#i), "contracting dims don't match")›;
let bc_1 = %tuple.cat (b1, c1);
let bc_2 = %tuple.cat (b2, c2);
let s12_res = %tuple.cat (%vec.diff (s1, bc_1), %vec.diff (s2, bc_2));
let s_out = %tuple.cat (bs, s12_res);
«s_out; R#T», normalize_dot;
%tensor.prod_2d
//lam %tensor.prod_2d_lam {R: Ring} {m: Nat, k: Nat, l: Nat} [t1: «m, k; R#T», t2: «k, l; R#T»]
// : «m, l; R#T»
// = %tensor.dot_general R (0_2, 1_2, (), ()) (t1, t2);
axm %tensor.dot_2d_00: {R: Ring} → {m: Nat, k: Nat, l: Nat} → [t1: «m, k; R#T», t2: «k, l; R#T»] → «m, l; R#T»;
axm %tensor.dot_2d_01: {R: Ring} → {m: Nat, k: Nat, l: Nat} → [t1: «m, k; R#T», t2: «k, l; R#T»] → «m, l; R#T»;
axm %tensor.dot_2d_10: {R: Ring} → {m: Nat, k: Nat, l: Nat} → [t1: «m, k; R#T», t2: «k, l; R#T»] → «m, l; R#T»;
axm %tensor.dot_2d_11: {R: Ring} → {m: Nat, k: Nat, l: Nat} → [t1: «m, k; R#T», t2: «k, l; R#T»] → «m, l; R#T»;
%tensor.broadcast_in_dim
axm %tensor.broadcast_in_dim: {T: *}
→ {r_in r_out: Nat, s_in: «r_in; Nat», s_out: «r_out; Nat»}
→ [«s_in; T», «r_in; Idx r_in»]
→ «s_out; T»;
%tensor.reduce
ni
tensors each of shape s
and type Is#i
in arg is
- i: ni is a notation for i : (Idx ni)
axm %tensor.reduce: {r: Nat, s: «r; Nat»}
→ {ni: Nat, Is: «ni; *»}
→ [f: «2; «i: ni; Is#i» » → «i: ni; Is#i»]
→ [is: «i: ni; «s; Is#i» », init: «i: ni; Is#i», dims: «r; Bool»]
→ «i: ni; « <j: r; (s#j, 1)#(dims#j)>; Is#i» »;
%tensor.map_reduce
nis
: number of inputs
T/R/S
is/o
: respectively the type/rank/shape of the inputs/output
f
: function to reduce over (takes an element of type To
and one of each type in Tis
, and returns a To
)
init
: accumulator to start f
with
subs
: for each input, for each dimension, an index to compute the output in Einstein notation
is
: the inputs
Returns a tensor obtained by folding f
following the indexes in subs
axm %tensor.map_reduce: {nis: Nat}
→ {To: *, Ro: Nat}
→ [So: «Ro; Nat»]
→ {Tis: «nis; *», Ris: «i: nis; Nat», Sis: «i:nis; «Ris#i; Nat»»}
→ [f: [To, «i: nis; Tis#i»] → To, init: To]
→ [subs: «i: nis; « Ris#i; Nat»»]
→ [is: «i: nis; «Sis#i; Tis#i» »]
→ «So; To»;