MimIR 0.1
MimIR is my Intermediate Representation
The automatic differentiation Plugin

See also
mim::plug::autodiff

Dependencies

plugin mem;
import compile;

For derivatives:

plugin core;

Types

axm %autodiff.Tangent: * → *, normalize_Tangent;

Operations

%autodiff.ad

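These axioms operate on functions and types: %autodiff.AD maps types to their augmented types, and %autodiff.ad maps terms to their augmented terms.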

For function types the augmented type is computed: (T → U) => (T → U × (U → T))

axm %autodiff.AD: * → *, normalize_AD;

On closed terms (functions, operators, higher-order arguments, registered axioms, etc.) the augmented term is returned. The augmented term f' returns the result together with the pullback f*: autodiff f = f' = λ args. (f args, f*)

axm %autodiff.ad: {T: *} → T → %autodiff.AD T, normalize_ad;
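The relationship between a function and its augmented version can be modelled outside MimIR. The following Python sketch only illustrates the semantics described above; it is not how the plugin is implemented, and the names `square` and `square_aug` are hypothetical.

```python
from typing import Callable, Tuple

# Hypothetical original function f : T -> U (here float -> float).
def square(x: float) -> float:
    return x * x

# Hand-written augmented function f' : T -> U x (U -> T).
# It returns the result together with the pullback f*.
def square_aug(x: float) -> Tuple[float, Callable[[float], float]]:
    def pullback(s: float) -> float:
        # s is the output tangent; d(x*x)/dx = 2*x
        return 2.0 * x * s
    return square(x), pullback

result, pb = square_aug(3.0)
gradient = pb(1.0)  # seed the pullback with output tangent 1
```

Seeding the pullback with 1 yields the ordinary derivative at the evaluation point; any other seed scales it accordingly.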

%autodiff.zero

Represents a universal zero such that (zero T) +_T t = t.

axm %autodiff.zero: [T: *] → T, normalize_zero;

%autodiff.add

A universal addition that consumes zeros and defaults to normal addition for scalar types. It lifts additions over types with structure and performs special casing for types which do not allow for addition. The sum construct performs addition over a list of terms.

TODO: how do we handle summations that need memory? (grab current memory?)

axm %autodiff.add: {T: *} → [T, T] → T, normalize_add;
axm %autodiff.sum: [n:Nat, T: *] → «n; T» → T, normalize_sum;
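The intended behaviour of zero, add and sum can be sketched in Python. This is a loose model under stated assumptions: the `Zero` class, the `ZERO` constant, and the use of Python tuples for "types with structure" are hypothetical stand-ins, and types that do not allow addition are not covered.

```python
from functools import reduce

class Zero:
    """Hypothetical stand-in for the universal zero (valid at any type)."""
    def __repr__(self):
        return "Zero"

ZERO = Zero()

def add(a, b):
    # Zeros are consumed: zero + t = t and t + zero = t.
    if isinstance(a, Zero):
        return b
    if isinstance(b, Zero):
        return a
    # Structured values (tuples) are added component-wise.
    if isinstance(a, tuple) and isinstance(b, tuple):
        if len(a) != len(b):
            raise ValueError("tuple arity mismatch")
        return tuple(add(x, y) for x, y in zip(a, b))
    # Scalars fall back to ordinary addition.
    return a + b

def total(terms):
    # Analogue of the sum construct: fold add over a list of terms.
    return reduce(add, terms, ZERO)
```

Because zeros are consumed rather than materialised, a tangent that is never touched stays symbolic until it meets a concrete value.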

Passes and Phases

Passes

axm %autodiff.ad_eval_pass: %compile.Pass;
axm %autodiff.ad_zero_pass: %compile.Pass;
axm %autodiff.ad_zero_cleanup_pass: %compile.Pass;
axm %autodiff.ad_ext_cleanup_pass: %compile.Pass;

Phases

let ad_cleanup_phase =
    %compile.phases_to_phase (⊤:Nat) (
        (%compile.passes_to_phase 1 %autodiff.ad_zero_cleanup_pass),
        (%compile.passes_to_phase 1 %autodiff.ad_ext_cleanup_pass)
    );
let ad_phase =
    %compile.phases_to_phase (⊤:Nat) (
        optimization_phase,
        (%compile.passes_to_phase 1 %autodiff.ad_eval_pass),
        (%compile.passes_to_phase 1 %autodiff.ad_zero_pass),
        ad_cleanup_phase
    );

Registered translations

In this section, we define translations for axioms of other plugins. This would best be done using a registration mechanism in a third plugin, or at least in a separate file.

The general concept is that a call to an axiom is replaced with a call to the augmented axiom. The augmented axiom needs a wrapper for meta arguments (a higher-order function). Appropriate cps2ds wrappers are introduced because the augmented axioms are in continuation-passing style (CPS), whereas the original axioms were in direct style (DS). Example:

mul' => args → result*pullback
call: r = mul (m,w) (a,b)
res : r,r* = mul' (m,w) (a,b)

The types (with Int for (Int w)) are:

mul : [m:Nat,w:Nat] → [a:Int,b:Int] → Int
r : Int
r* : cn[Int,cn[Int,Int]]

The pullback has to be in cps for compliance.

mul* := λ s. (s*b,s*a)
mul'_cps : [m:Nat,w:Nat] → cn[[Int,Int],cn[Int, cn[Int,cn[Int,Int]]]]
r,r* = (cps2ds (mul'_cps (m,w))) (a,b)
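The role of the cps2ds wrapper in the example above can be modelled in Python. This is a sketch under assumptions: `mul_aug_cps` and this `cps2ds` are hypothetical Python functions that mimic the shapes of mul'_cps and the wrapper, continuations are ordinary callbacks, and integer wrap-around is ignored.

```python
def mul_aug_cps(m, w):
    """Hypothetical CPS model of mul': takes (args, return continuation)."""
    def body(args, ret):
        a, b = args
        def pullback_cps(s, ret_pb):
            # mul* := λ s. (s*b, s*a), delivered through a continuation
            ret_pb((s * b, s * a))
        ret((a * b, pullback_cps))
    return body

def cps2ds(f_cps):
    """Turn a CPS function into a direct-style one by capturing its result."""
    def direct(args):
        box = []
        f_cps(args, box.append)  # the continuation stores the result
        return box[0]
    return direct

# r, r* = (cps2ds (mul'_cps (m, w))) (a, b)
r, r_star = cps2ds(mul_aug_cps(0, 32))((3, 4))
# The pullback is itself in CPS, so it is wrapped the same way before use.
tangents = cps2ds(r_star)(1)
```

Here `r` is the product and `tangents` holds the pair (s*b, s*a) for the seed s = 1.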

The pullback is the derivative with respect to the input, weighted with the output tangent s. For arithmetic operations, s is simply multiplied with each partial derivative, which gives the tangent of the i-th input:

∂_i f(x1,...,xn) * s

The applied partial pullback then has to be:

sum_i x_i*(∂_i f(x1,...,xn) * s)

where x_i* is the partial pullback of the operand x_i. This is a direct result of chaining the pullback of the operation with the partial pullback of the argument tuple. The tuple pullback transports the partial pullbacks of the operands and handles the sums. By its nature, the pullback of a tuple needs to be a sum.
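The summation performed by the tuple pullback can be illustrated with a small Python model of f(x) = x * x, where both operands of the multiplication are the same argument. All function names here are hypothetical, and the operand pullbacks are taken to be the identity because the operands are the function argument itself.

```python
def identity_pullback(t):
    # Partial pullback x* of a plain argument x with respect to itself.
    return t

def tuple_pullback(operand_pullbacks):
    # Pullback of a tuple: apply each operand's pullback to its tangent
    # and sum the results.
    def pb(tangents):
        return sum(p(t) for p, t in zip(operand_pullbacks, tangents))
    return pb

def mul_pullback(a, b):
    # mul* := λ s. (s*b, s*a)
    return lambda s: (s * b, s * a)

def square_pullback(x):
    # Chain: output tangent -> mul* -> tuple pullback of (x, x).
    inner = tuple_pullback([identity_pullback, identity_pullback])
    outer = mul_pullback(x, x)
    return lambda s: inner(outer(s))
```

For x = 3 and s = 1 the two operand tangents are both 3, and their sum 6 is the expected derivative 2x.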

%core.icmp.xYgLE (sle)

The comparison pullback exists formally but is not used.

fun extern internal_diff_core_icmp_xYgLE {s: Nat} (ab: «2; Idx s»)@tt: [Bool, Cn[Bool, Cn«2; Idx s»]] =
    return (%core.icmp.sle ab, fn [Bool]@tt: «2; Idx s» = return ‹2; 0:(Idx s)›);

%core.wrap.add

s ↦ (s, s)

fun extern internal_diff_core_wrap_add {s: Nat} (m: Nat) (ab: «2; Idx s»)@tt: [Idx s, Cn[Idx s, Cn«2; Idx s»]] =
    return (%core.wrap.add m ab, fn (i: Idx s)@tt: «2; Idx s» = return ‹2; i›);

%core.wrap.mul

s ↦ (s*b, s*a)

fun extern internal_diff_core_wrap_mul {s: Nat} (m: Nat) (a b: Idx s) as ab@tt: [Idx s, Cn[Idx s, Cn«2; Idx s»]] =
    return (%core.wrap.mul m (i, b), %core.wrap.mul m (i, a)));
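The three registered translations can be mirrored in Python to make their results easy to check by hand. This is a rough model under assumptions: `Idx s` is represented as integers modulo a hypothetical `size` parameter, the mode argument m is dropped, pullbacks are plain functions rather than continuations, and all function names are invented for illustration.

```python
def to_signed(size, v):
    # Interpret an index value in two's complement (assumes size is even).
    return v - size if v >= size // 2 else v

def icmp_sle_aug(size, a, b):
    # Result of the signed comparison plus a pullback that yields zeros.
    result = to_signed(size, a) <= to_signed(size, b)
    return result, lambda _s: (0, 0)

def wrap_add_aug(size, a, b):
    # s ↦ (s, s)
    return (a + b) % size, lambda i: (i, i)

def wrap_mul_aug(size, a, b):
    # s ↦ (s*b, s*a), using the same wrapping multiplication
    return (a * b) % size, lambda i: ((i * b) % size, (i * a) % size)

prod, prod_pb = wrap_mul_aug(256, 20, 13)
total_, total_pb = wrap_add_aug(256, 200, 100)
cmp_, cmp_pb = icmp_sle_aug(256, 255, 1)
```

With `size = 256`, the product 260 wraps to 4 while the pullback seeded with 1 still returns the operands swapped, (13, 20).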