9 world().DLOG(
"Derive lambda: {}", def);
11 auto deriv =
world().
mut_lam(deriv_ty)->
set(lam->sym().str() +
"_deriv");
19 auto [arg_ty, ret_pi] = lam->type()->doms<2>();
20 auto deriv_all_args = deriv->var();
21 Ref deriv_arg = deriv->
var(0_s)->
set(
"arg");
29 partial_pullback[deriv_arg] = arg_id_pb;
31 auto ret_var = deriv->var(1);
33 partial_pullback[ret_var] = ret_pb;
35 shadow_pullback[deriv_all_args] =
world().
tuple({arg_id_pb, ret_pb});
36 world().DLOG(
"pullback for argument {} : {} is {} : {}", deriv_arg, deriv_arg->
type(), arg_id_pb,
38 world().DLOG(
"args shadow pb is {} : {}", shadow_pullback[deriv_all_args], shadow_pullback[deriv_all_args]->type());
43 augmented[def] = deriv;
44 world().DLOG(
"Associate {} with {}", def, deriv);
45 world().DLOG(
" {} : {}", lam, lam->type());
46 world().DLOG(
" {} : {}", deriv, deriv->type());
47 augmented[lam->var()] = deriv->var();
48 world().DLOG(
"Associate vars {} with {}", lam->var(), deriv->var());
61 auto new_body =
augment(lam->body(), lam, deriv);
62 deriv->set(
true, new_body);
T * as_mut() const
Asserts that this is a mutable, casts constness away and performs a static_cast to T.
Def * set(size_t i, Ref)
Successively set from left to right.
Ref type() const noexcept
Yields the raw type of this Def, which may be nullptr.
Ref var(nat_t a, nat_t i) noexcept
Lam * set(Filter filter, Ref body)
Helper class to retrieve Infer::arg if present.
Lam * mut_lam(const Pi *pi)
Ref derive_(Ref)
In addition to the derivation, the pullback is registered and the maps are initialized.
Ref augment(Ref, Lam *, Lam *)
Applies to (open) expressions in a functional context.
The automatic differentiation Plugin
const Pi * autodiff_type_fun_pi(const Pi *)
const Def * zero_pullback(const Def *E, const Def *A)
const Def * id_pullback(const Def *)