9 world().DLOG(
"Derive lambda: {}", def);
11 auto deriv =
world().
mut_lam(deriv_ty)->
set(lam->sym().str() +
"_deriv");
19 auto [arg_ty, ret_pi] = lam->
type()->doms<2>();
20 auto deriv_all_args = deriv->
var();
21 const Def* deriv_arg = deriv->
var(0_s)->
set(
"arg");
29 partial_pullback[deriv_arg] = arg_id_pb;
31 auto ret_var = deriv->var(1);
33 partial_pullback[ret_var] = ret_pb;
35 shadow_pullback[deriv_all_args] =
world().
tuple({arg_id_pb, ret_pb});
36 world().DLOG(
"pullback for argument {} : {} is {} : {}", deriv_arg, deriv_arg->
type(), arg_id_pb,
38 world().DLOG(
"args shadow pb is {} : {}", shadow_pullback[deriv_all_args], shadow_pullback[deriv_all_args]->type());
43 augmented[def] = deriv;
44 world().DLOG(
"Associate {} with {}", def, deriv);
45 world().DLOG(
" {} : {}", lam, lam->type());
46 world().DLOG(
" {} : {}", deriv, deriv->type());
47 augmented[lam->var()] = deriv->
var();
48 world().DLOG(
"Associate vars {} with {}", lam->var(), deriv->var());
61 auto new_body =
augment(lam->body(), lam, deriv);
62 deriv->set(
true, new_body);
Def * set(size_t i, const Def *)
Successively set from left to right.
T * as_mut() const
Asserts that this is a mutable, casts constness away and performs a static_cast to T.
const Def * var(nat_t a, nat_t i) noexcept
const Def * type() const noexcept
Yields the "raw" type of this Def (maybe nullptr).
Lam * set(Filter filter, const Def *body)
const Def * tuple(Defs ops)
const Def * var(const Def *type, Def *mut)
Lam * mut_lam(const Pi *pi)
const Def * augment(const Def *, Lam *, Lam *)
Applies to (open) expressions in a functional context.
const Def * derive_(const Def *)
In addition to the derivation, the pullback is registered and the maps are initialized.
The automatic differentiation Plugin
const Pi * autodiff_type_fun_pi(const Pi *)
const Def * zero_pullback(const Def *E, const Def *A)
const Def * id_pullback(const Def *)