MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
autodiff_eval.h
Go to the documentation of this file.
1#pragma once
2
3#include <mim/def.h>
4#include <mim/pass/pass.h>
5
6namespace mim::plug::autodiff {
7
8/// This pass is the heart of AD (automatic differentiation).
9/// It replaces each `autodiff fun` call with the differentiated function.
10class AutoDiffEval : public RWPass<AutoDiffEval, Lam> {
11public:
13 : RWPass(man, "autodiff_eval") {}
14
15 /// Detects autodiff calls.
16 Ref rewrite(Ref) override;
17
18 /// Acts on top-level autodiff calls on closed terms:
19 /// * Replaces lambdas and operators with their appropriate derivatives.
20 /// * Creates a new lambda, associates variables, initializes the maps, and calls `augment`.
21 Ref derive(Ref);
23
24 /// Applies to (open) expressions in a functional context.
25 /// Returns the rewritten expression and augments the partial and modular pullbacks.
26 /// The rewrite is the identity on the term, up to renaming of variables.
27 /// Apart from that, only pullbacks are added.
28 /// To do so, some calls (e.g. axioms) are replaced by their derivatives.
29 /// This transformation can be seen as an augmentation with a dual computation that generates the derivatives.
30 Ref augment(Ref, Lam*, Lam*);
31 Ref augment_(Ref, Lam*, Lam*);
32 /// Helper functions for `augment`, one per kind of expression.
33 Ref augment_var(const Var*, Lam*, Lam*);
35 Ref augment_extract(const Extract*, Lam*, Lam*);
36 Ref augment_app(const App*, Lam*, Lam*);
37 Ref augment_lit(const Lit*, Lam*, Lam*);
38 Ref augment_tuple(const Tuple*, Lam*, Lam*);
39 Ref augment_pack(const Pack* pack, Lam* f, Lam* f_diff);
40
41private:
42 /// Transforms closed terms (lambda, operator) to derived expressions.
43 /// `f => f' = λ x. (f x, f*_x)`
44 /// src Def -> dst Def
45 Def2Def derived;
46 /// Maps expressions (not necessarily closed) in a functional context to their augmented (differentiated) counterparts.
47 /// src Def -> dst Def
48 Def2Def augmented;
49
50 /// Partial pullback of each rewritten expression: dst Def -> dst Def
51 Def2Def partial_pullback;
52
53 /// Shadows the structure of containers for additional auxiliary pullbacks.
54 /// A very advanced optimization might be able to recover shadow pullbacks from the partial pullbacks.
55 /// Example: The shadow pullback of a tuple is a tuple of pullbacks.
56 /// Shadow pullbacks are not modular (i.e., not composable with other pullbacks).
57 /// The shadow pullback only preserves structure shallowly:
58 /// An n-times nested tuple has a tuple of "normal" pullbacks as its shadow pullback.
59 /// Each inner nested tuple should have its own shadow pullback by construction.
60 /// ```
61 /// e   : [B0, [B11, B12], B2]
62 /// e*  : [B0, [B11, B12], B2] -> A
63 /// e*_S: [B0 -> A, [B11, B12] -> A, B2 -> A]
64 /// ```
65 /// Short theory of shadow pullbacks:
66 /// ```
67 /// t: [B0, ..., Bn]
68 /// t*: [B0, ..., Bn] -> A
69 /// t*_S: [B0 -> A, ..., Bn -> A]
70 /// b = t#i : Bi
71 /// b* : Bi -> A
72 /// b* = t*_S #i (if exists)
73 /// ```
74 /// This is equivalent to applying the full pullback `t*` to a zero tuple that carries `s` at position i:
75 /// `λ (s:Bi). t* (insert s at i in (zero [B0, ..., Bn]))`
76 /// dst Def -> dst Def
77 Def2Def shadow_pullback;
78};
79
80} // namespace mim::plug::autodiff
Extracts from a Sigma or Array-typed Extract::tuple the element at position Extract::index.
Definition tuple.h:119
A function.
Definition lam.h:96
A (possibly parameterized) Tuple.
Definition tuple.h:88
An optimizer that combines several optimizations in an optimal way.
Definition pass.h:107
PassMan & man()
Definition pass.h:30
Inherit from this class using CRTP if your Pass needs neither state nor a fixed-point iteration.
Definition pass.h:220
Helper class to retrieve Infer::arg if present.
Definition def.h:85
Data constructor for a Sigma.
Definition tuple.h:40
This pass is the heart of AD.
Ref augment_var(const Var *, Lam *, Lam *)
helper functions for augment
Ref derive_(Ref)
In addition to the derivation, the pullback is registered and the maps are initialized.
Ref augment_lit(const Lit *, Lam *, Lam *)
Ref augment_tuple(const Tuple *, Lam *, Lam *)
Ref augment_(Ref, Lam *, Lam *)
Rewrites the given definition in a lambda environment.
Ref augment(Ref, Lam *, Lam *)
Applies to (open) expressions in a functional context.
Ref augment_app(const App *, Lam *, Lam *)
Ref rewrite(Ref) override
Detect autodiff calls.
Ref derive(Ref)
Acts on top-level autodiff calls on closed terms:
Ref augment_pack(const Pack *pack, Lam *f, Lam *f_diff)
Ref augment_extract(const Extract *, Lam *, Lam *)
The automatic differentiation Plugin
Definition autodiff.h:6
DefMap< const Def * > Def2Def
Definition def.h:59