MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
eval.h
Go to the documentation of this file.
1#pragma once
2
3#include <mim/def.h>
4#include <mim/pass.h>
5
6namespace mim::plug::autodiff {
7
8/// This pass is the heart of AD.
9/// We replace an `autodiff fun` call with the differentiated function.
10class Eval : public RWPass<Eval, Lam> {
11public:
14
15 /// Detect autodiff calls.
16 const Def* rewrite(const Def*) override;
17
18 /// Acts on toplevel autodiff on closed terms:
19 /// * Replaces lambdas, operators with the appropriate derivatives.
20 /// * Creates new lambda, calls associate variables, init maps, calls augment.
21 const Def* derive(const Def*);
22 const Def* derive_(const Def*);
23
24 /// Applies to (open) expressions in a functional context.
25 /// Returns the rewritten expressions and augments the partial and modular pullbacks.
26 /// The rewrite is identity on the term up to renaming of variables.
27 /// Otherwise, only pullbacks are added.
28 /// To do so, some calls (e.g. axms) are replaced by their derivatives.
29 /// This transformation can be seen as an augmentation with a dual computation that generates the derivatives.
30 const Def* augment(const Def*, Lam*, Lam*);
31 const Def* augment_(const Def*, Lam*, Lam*);
32 /// helper functions for augment
33 const Def* augment_var(const Var*, Lam*, Lam*);
34 const Def* augment_lam(Lam*, Lam*, Lam*);
35 const Def* augment_extract(const Extract*, Lam*, Lam*);
36 const Def* augment_app(const App*, Lam*, Lam*);
37 const Def* augment_lit(const Lit*, Lam*, Lam*);
38 const Def* augment_tuple(const Tuple*, Lam*, Lam*);
39 const Def* augment_pack(const Pack* pack, Lam* f, Lam* f_diff);
40
41private:
42 /// Transforms closed terms (lambda, operator) to derived expressions.
43 /// `f => f' = λ x. (f x, f*_x)`
44 /// src Def -> dst Def
45 Def2Def derived;
46 /// Associates expressions (not necessarily closed) in a functional context to their derivatived counterpart.
47 /// src Def -> dst Def
48 Def2Def augmented;
49
50 /// dst Def -> dst Def
51 Def2Def partial_pullback;
52
53 /// Shadows the structure of containers for additional auxiliary pullbacks.
54 /// A very advanced optimization might be able to recover shadow pullbacks from the partial pullbacks.
55 /// Example: The shadow pullback of a tuple is a tuple of pullbacks.
56 /// Shadow pullbacks are not modular (composable with other pullbacks).
57 /// The structure pullback only preserves structure shallowly:
58 /// A n-times nested tuple has a tuple of "normal" pullbacks as shadow pullback.
59 /// Each inner nested tuples should have their own structure pullback by construction.
60 /// ```
61 /// e : [B0, [B11, B12], B2]
62 /// e* : [B0, [B11, B12], B2] -> A
63 /// e*S: [B0 -> A, [B11, B12] -> A, B2 -> A]
64 /// ```
65 /// short theory of shadow pb:
66 /// ```
67 /// t: [B0, ..., Bn]
68 /// t*: [B0, ..., Bn] -> A
69 /// t*_S: [B0 -> A, ..., Bn -> A]
70 /// b = t#i : Bi
71 /// b* : Bi -> A
72 /// b* = t*_S #i (if exists)
73 /// ```
74 /// This is equivalent to:
75 /// `\lambda (s:Bi). t*_S (insert s at i in (zero [B0, ..., Bn]))`
76 /// dst Def -> dst Def
77 Def2Def shadow_pullback;
78};
79
80} // namespace mim::plug::autodiff
Base class for all Defs.
Definition def.h:251
Extracts from a Sigma or Array-typed Extract::tuple the element at position Extract::index.
Definition tuple.h:206
A function.
Definition lam.h:111
A (possibly parameterized) Tuple.
Definition tuple.h:166
RWPass(World &world, std::string name)
Definition pass.h:295
World & world()
Definition pass.h:64
flags_t annex() const
Definition pass.h:68
Data constructor for a Sigma.
Definition tuple.h:68
The World represents the whole program and manages creation of MimIR nodes (Defs).
Definition world.h:36
const Def * augment_lit(const Lit *, Lam *, Lam *)
const Def * derive_(const Def *)
In addition to the derivation, the pullback is registered and the maps are initialized.
const Def * augment_tuple(const Tuple *, Lam *, Lam *)
const Def * augment_pack(const Pack *pack, Lam *f, Lam *f_diff)
const Def * augment_app(const App *, Lam *, Lam *)
const Def * augment_lam(Lam *, Lam *, Lam *)
const Def * augment_(const Def *, Lam *, Lam *)
Rewrites the given definition in a lambda environment.
const Def * augment(const Def *, Lam *, Lam *)
Applies to (open) expressions in a functional context.
Definition eval.cpp:10
const Def * augment_extract(const Extract *, Lam *, Lam *)
const Def * rewrite(const Def *) override
Detect autodiff calls.
Definition eval.cpp:22
Eval(World &world, flags_t annex)
Definition eval.h:12
const Def * augment_var(const Var *, Lam *, Lam *)
helper functions for augment
const Def * derive(const Def *)
Acts on toplevel autodiff on closed terms:
Definition eval.cpp:16
The automatic differentiation Plugin
Definition autodiff.h:6
DefMap< const Def * > Def2Def
Definition def.h:75
u64 flags_t
Definition types.h:45