Thorin 1.9.0
The Higher ORder INtermediate representation
Loading...
Searching...
No Matches
autodiff_rewrite_toplevel.cpp
Go to the documentation of this file.
3
5
6/// Additionally to the derivation, the pullback is registered and the maps are initialized.
8 auto& world = def->world();
9 auto lam = def->as_mut<Lam>(); // TODO check if mutable
10 world.DLOG("Derive lambda: {}", def);
11 auto deriv_ty = autodiff_type_fun_pi(lam->type());
12 auto deriv = world.mut_lam(deriv_ty)->set(lam->sym().str() + "_deriv");
13
14 // We first pre-register the derivatives.
15 // This knowledge is needed for recursion.
16 // (Alternatively, we could also use projections out the variables instead of pre-partial-pullback
17 // initialization.)
18 derived[lam] = deriv;
19
20 auto [arg_ty, ret_pi] = lam->type()->doms<2>();
21 auto deriv_all_args = deriv->var();
22 Ref deriv_arg = deriv->var(0_s)->set("arg");
23
24 // We generate the shadow pullbacks dynamically to save work and avoid code duplication.
25 // Only the toplevel pullback for arguments and return continuation is special cased.
26
27 // TODO: check identity: could use identity tangent(arg_ty) = tangent(augment(arg_ty)) with deriv_arg->type() =
28 // augment(arg_ty) We give the argument the identity pullback.
29 auto arg_id_pb = id_pullback(arg_ty);
30 partial_pullback[deriv_arg] = arg_id_pb;
31 // The return continuation has to formally exist but should never be directly accessed.
32 auto ret_var = deriv->var(1);
33 auto ret_pb = zero_pullback(lam->var(1)->type(), arg_ty);
34 partial_pullback[ret_var] = ret_pb;
35
36 shadow_pullback[deriv_all_args] = world.tuple({arg_id_pb, ret_pb});
37 world.DLOG("pullback for argument {} : {} is {} : {}", deriv_arg, deriv_arg->type(), arg_id_pb, arg_id_pb->type());
38 world.DLOG("args shadow pb is {} : {}", shadow_pullback[deriv_all_args], shadow_pullback[deriv_all_args]->type());
39
40 // We pre-register the augment replacements.
41 // The function and its variables are replaced by their new derived versions.
42 // TODO: maybe leave out function call (duplication with derived)
43 augmented[def] = deriv;
44 world.DLOG("Associate {} with {}", def, deriv);
45 world.DLOG(" {} : {}", lam, lam->type());
46 world.DLOG(" {} : {}", deriv, deriv->type());
47 augmented[lam->var()] = deriv->var();
48 world.DLOG("Associate vars {} with {}", lam->var(), deriv->var());
49
50 // already contains the correct application of
51 // deriv->ret_var() by specification
52 // f : cn[R] has a partial derivative (exception to closed rule)
53 // f': cn[R, cn[R, cn[A]]]
54 // this is needed for continuations (without closure conversion)
55 // but also essentially for the return continuation
56
57 // Here a reminder of types:
58 // The expression `e: B` has the implicit function `e_fun: A -> B`
59 // The partial pullback is then `e*: B* -> A*`
60 // The derivatived version is `e': B' × (B* -> A*)` which is an application of `e'_fun: A' -> B' × (B* -> A*)`
61 auto new_body = augment(lam->body(), lam, deriv);
62 deriv->set(true, new_body);
63
64 return deriv;
65}
66
67} // namespace thorin::plug::autodiff
Ref var(nat_t a, nat_t i)
Definition def.h:403
const Def * type() const
Yields the raw type of this Def, which may be nullptr.
Definition def.h:248
T * as_mut() const
Asserts that this is a mutable Def, casts constness away, and performs a static_cast to T.
Definition def.h:457
Def * set(size_t i, const Def *def)
Successively set from left to right.
Definition def.cpp:254
World & world() const
Definition def.cpp:421
A function.
Definition lam.h:97
Lam * set(Filter filter, const Def *body)
Definition lam.h:159
World & world()
Definition pass.h:296
Helper class to retrieve Infer::arg if present.
Definition def.h:87
Ref tuple(Defs ops)
Definition world.cpp:226
Lam * mut_lam(const Pi *pi)
Definition world.h:263
Ref derive_(Ref)
In addition to the derivation, the pullback is registered and the maps are initialized.
Ref augment(Ref, Lam *, Lam *)
Applies to (open) expressions in a functional context.
The automatic differentiation plugin.
Definition autodiff.h:7
const Pi * autodiff_type_fun_pi(const Pi *)
Definition autodiff.cpp:90
const Def * id_pullback(const Def *)
Definition autodiff.cpp:32
const Def * zero_pullback(const Def *E, const Def *A)
Definition autodiff.cpp:44