MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
autodiff_eval.h
Go to the documentation of this file.
1#pragma once
2
3#include <mim/def.h>
4
5#include <mim/pass/pass.h>
6
7namespace mim::plug::autodiff {
8
9/// This pass is the heart of AD.
10/// We replace an `autodiff fun` call with the differentiated function.
11class AutoDiffEval : public RWPass<AutoDiffEval, Lam> {
12public:
14 : RWPass(man, "autodiff_eval") {}
15
16 /// Detect autodiff calls.
17 const Def* rewrite(const Def*) override;
18
19 /// Acts on toplevel autodiff on closed terms:
20 /// * Replaces lambdas, operators with the appropriate derivatives.
21 /// * Creates new lambda, calls associate variables, init maps, calls augment.
22 const Def* derive(const Def*);
23 const Def* derive_(const Def*);
24
25 /// Applies to (open) expressions in a functional context.
26 /// Returns the rewritten expressions and augments the partial and modular pullbacks.
27 /// The rewrite is identity on the term up to renaming of variables.
28 /// Otherwise, only pullbacks are added.
29 /// To do so, some calls (e.g. axioms) are replaced by their derivatives.
30 /// This transformation can be seen as an augmentation with a dual computation that generates the derivatives.
31 const Def* augment(const Def*, Lam*, Lam*);
32 const Def* augment_(const Def*, Lam*, Lam*);
33 /// helper functions for augment
34 const Def* augment_var(const Var*, Lam*, Lam*);
35 const Def* augment_lam(Lam*, Lam*, Lam*);
36 const Def* augment_extract(const Extract*, Lam*, Lam*);
37 const Def* augment_app(const App*, Lam*, Lam*);
38 const Def* augment_lit(const Lit*, Lam*, Lam*);
39 const Def* augment_tuple(const Tuple*, Lam*, Lam*);
40 const Def* augment_pack(const Pack* pack, Lam* f, Lam* f_diff);
41
42private:
43 /// Transforms closed terms (lambda, operator) to derived expressions.
44 /// `f => f' = λ x. (f x, f*_x)`
45 /// src Def -> dst Def
46 Def2Def derived;
47 /// Associates expressions (not necessarily closed) in a functional context to their derivatived counterpart.
48 /// src Def -> dst Def
49 Def2Def augmented;
50
51 /// dst Def -> dst Def
52 Def2Def partial_pullback;
53
54 /// Shadows the structure of containers for additional auxiliary pullbacks.
55 /// A very advanced optimization might be able to recover shadow pullbacks from the partial pullbacks.
56 /// Example: The shadow pullback of a tuple is a tuple of pullbacks.
57 /// Shadow pullbacks are not modular (composable with other pullbacks).
58 /// The structure pullback only preserves structure shallowly:
59 /// A n-times nested tuple has a tuple of "normal" pullbacks as shadow pullback.
60 /// Each inner nested tuples should have their own structure pullback by construction.
61 /// ```
62 /// e : [B0, [B11, B12], B2]
63 /// e* : [B0, [B11, B12], B2] -> A
64 /// e*S: [B0 -> A, [B11, B12] -> A, B2 -> A]
65 /// ```
66 /// short theory of shadow pb:
67 /// ```
68 /// t: [B0, ..., Bn]
69 /// t*: [B0, ..., Bn] -> A
70 /// t*_S: [B0 -> A, ..., Bn -> A]
71 /// b = t#i : Bi
72 /// b* : Bi -> A
73 /// b* = t*_S #i (if exists)
74 /// ```
75 /// This is equivalent to:
76 /// `\lambda (s:Bi). t*_S (insert s at i in (zero [B0, ..., Bn]))`
77 /// dst Def -> dst Def
78 Def2Def shadow_pullback;
79};
80
81} // namespace mim::plug::autodiff
Base class for all Defs.
Definition def.h:198
Extracts from a Sigma or Array-typed Extract::tuple the element at position Extract::index.
Definition tuple.h:155
A function.
Definition lam.h:105
A (possibly parameterized) Tuple.
Definition tuple.h:115
PassMan & man()
Definition pass.h:30
friend class PassMan
Definition pass.h:101
RWPass(PassMan &man, std::string_view name)
Definition pass.h:222
Data constructor for a Sigma.
Definition tuple.h:50
const Def * augment_tuple(const Tuple *, Lam *, Lam *)
const Def * augment_pack(const Pack *pack, Lam *f, Lam *f_diff)
const Def * augment(const Def *, Lam *, Lam *)
Applies to (open) expressions in a functional context.
const Def * augment_var(const Var *, Lam *, Lam *)
helper functions for augment
const Def * derive(const Def *)
Acts on toplevel autodiff on closed terms:
const Def * augment_app(const App *, Lam *, Lam *)
const Def * augment_(const Def *, Lam *, Lam *)
Rewrites the given definition in a lambda environment.
const Def * augment_lam(Lam *, Lam *, Lam *)
const Def * derive_(const Def *)
In addition to the derivation, the pullback is registered and the maps are initialized.
const Def * augment_extract(const Extract *, Lam *, Lam *)
const Def * augment_lit(const Lit *, Lam *, Lam *)
const Def * rewrite(const Def *) override
Detect autodiff calls.
The automatic differentiation Plugin
Definition autodiff.h:6
DefMap< const Def * > Def2Def
Definition def.h:48