MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
autodiff_rewrite_toplevel.cpp
Go to the documentation of this file.
3
4namespace mim::plug::autodiff {
5
6/// Additionally to the derivation, the pullback is registered and the maps are initialized.
7const Def* AutoDiffEval::derive_(const Def* def) {
8 auto lam = def->as_mut<Lam>(); // TODO check if mutable
9 world().DLOG("Derive lambda: {}", def);
10 auto deriv_ty = autodiff_type_fun_pi(lam->type());
11 auto deriv = world().mut_lam(deriv_ty)->set(lam->sym().str() + "_deriv");
12
13 // We first pre-register the derivatives.
14 // This knowledge is needed for recursion.
15 // (Alternatively, we could also use projections out the variables instead of pre-partial-pullback
16 // initialization.)
17 derived[lam] = deriv;
18
19 auto [arg_ty, ret_pi] = lam->type()->doms<2>();
20 auto deriv_all_args = deriv->var();
21 const Def* deriv_arg = deriv->var(0_s)->set("arg");
22
23 // We generate the shadow pullbacks dynamically to save work and avoid code duplication.
24 // Only the toplevel pullback for arguments and return continuation is special cased.
25
26 // TODO: check identity: could use identity tangent(arg_ty) = tangent(augment(arg_ty)) with deriv_arg->type() =
27 // augment(arg_ty) We give the argument the identity pullback.
28 auto arg_id_pb = id_pullback(arg_ty);
29 partial_pullback[deriv_arg] = arg_id_pb;
30 // The return continuation has to formally exist but should never be directly accessed.
31 auto ret_var = deriv->var(1);
32 auto ret_pb = zero_pullback(lam->var(1)->type(), arg_ty);
33 partial_pullback[ret_var] = ret_pb;
34
35 shadow_pullback[deriv_all_args] = world().tuple({arg_id_pb, ret_pb});
36 world().DLOG("pullback for argument {} : {} is {} : {}", deriv_arg, deriv_arg->type(), arg_id_pb,
37 arg_id_pb->type());
38 world().DLOG("args shadow pb is {} : {}", shadow_pullback[deriv_all_args], shadow_pullback[deriv_all_args]->type());
39
40 // We pre-register the augment replacements.
41 // The function and its variables are replaced by their new derived versions.
42 // TODO: maybe leave out function call (duplication with derived)
43 augmented[def] = deriv;
44 world().DLOG("Associate {} with {}", def, deriv);
45 world().DLOG(" {} : {}", lam, lam->type());
46 world().DLOG(" {} : {}", deriv, deriv->type());
47 augmented[lam->var()] = deriv->var();
48 world().DLOG("Associate vars {} with {}", lam->var(), deriv->var());
49
50 // already contains the correct application of
51 // deriv->ret_var() by specification
52 // f : cn[R] has a partial derivative (exception to closed rule)
53 // f': cn[R, cn[R, cn[A]]]
54 // this is needed for continuations (without closure conversion)
55 // but also essentially for the return continuation
56
57 // Here a reminder of types:
58 // The expression `e: B` has the implicit function `e_fun: A -> B`
59 // The partial pullback is then `e*: B* -> A*`
60 // The derivatived version is `e': B' × (B* -> A*)` which is an application of `e'_fun: A' -> B' × (B* -> A*)`
61 auto new_body = augment(lam->body(), lam, deriv);
62 deriv->set(true, new_body);
63
64 return deriv;
65}
66
67} // namespace mim::plug::autodiff
Base class for all Defs.
Definition def.h:198
Def * set(size_t i, const Def *)
Successively set from left to right.
Definition def.cpp:266
T * as_mut() const
Asserts that this is a mutable, casts constness away and performs a static_cast to T.
Definition def.h:438
const Def * var(nat_t a, nat_t i) noexcept
Definition def.h:379
const Def * type() const noexcept
Yields the "raw" type of this Def (maybe nullptr).
Definition def.h:242
A function.
Definition lam.h:105
Lam * set(Filter filter, const Def *body)
Definition lam.h:164
const Pi * type() const
Definition lam.h:125
World & world()
Definition pass.h:296
const Def * tuple(Defs ops)
Definition world.cpp:246
const Def * var(const Def *type, Def *mut)
Definition world.cpp:167
Lam * mut_lam(const Pi *pi)
Definition world.h:280
const Def * augment(const Def *, Lam *, Lam *)
Applies to (open) expressions in a functional context.
const Def * derive_(const Def *)
Additionally to the derivation, the pullback is registered and the maps are initialized.
The automatic differentiation Plugin
Definition autodiff.h:6
const Pi * autodiff_type_fun_pi(const Pi *)
Definition autodiff.cpp:90
const Def * zero_pullback(const Def *E, const Def *A)
Definition autodiff.cpp:44
const Def * id_pullback(const Def *)
Definition autodiff.cpp:32