15using namespace std::literals;
33 auto& world = A->
world();
35 auto id_pb = world.mut_lam(arg_pb_ty)->set(
"id_pb");
36 auto id_pb_scalar = id_pb->var(0_s)->set(
"s");
45 auto& world = A->
world();
48 auto pb = world.mut_lam(pb_ty)->set(
"zero_pb");
49 world.DLOG(
"zero_pullback for {} resp. {} (-> {})", E, A, A_tangent);
50 pb->app(
true, pb->var(1), world.call<
zero>(A_tangent));
63 auto& world = E->
world();
66 auto pb_ty = world.cn({tang_ret, world.cn(tang_arg)});
73 auto& world = arg->
world();
74 world.DLOG(
"autodiff type for {} => {}", arg, ret);
77 world.DLOG(
"augmented types: {} => {}", aug_arg, aug_ret);
78 if (!aug_arg || !aug_ret)
return nullptr;
81 world.DLOG(
"pb type: {}", pb_ty);
84 auto deriv_ty = world.cn({aug_arg, world.cn({aug_ret, pb_ty})});
85 world.DLOG(
"autodiff type: {}", deriv_ty);
91 auto& world = pi->
world();
95 auto ret = pi->
codom();
98 if (!aug_arg)
return nullptr;
100 if (!aug_ret)
return nullptr;
101 return world.pi(aug_arg, aug_ret);
105 auto [arg, ret_pi] = pi->doms<2>();
106 auto ret = ret_pi->as<
Pi>()->dom();
107 world.DLOG(
"compute AD type for pi");
114 auto& world = ty->
world();
118 world.DLOG(
"AutoDiff on type: {} <{}>", ty, ty->
node_name());
120 if (ty == world.type_nat())
return ty;
121 if (
auto arr = ty->isa<
Arr>()) {
122 auto shape = arr->shape();
123 auto body = arr->body();
125 if (!body_ad)
return nullptr;
126 return world.arr(shape, body_ad);
128 if (
auto sig = ty->isa<
Sigma>()) {
130 auto ops =
DefVec(sig->ops(), [&](
const Def* op) { return autodiff_type_fun(op); });
131 world.DLOG(
"ops: {,}", ops);
132 return world.sigma(ops);
136 world.WLOG(
"no-diff type: {}", ty);
143 auto& world = T->
world();
144 world.DLOG(
"zero_def for type {} <{}>", T, T->node_name());
145 if (
auto arr = T->isa<
Arr>()) {
146 auto shape = arr->shape();
147 auto body = arr->body();
148 auto inner_zero = world.app(world.annex<
zero>(), body);
149 auto zero_arr = world.pack(shape, inner_zero);
150 world.DLOG(
"zero_def for array of shape {} with type {}", shape, body);
151 world.DLOG(
"zero_arr: {}", zero_arr);
155 auto zero = world.lit(T, 0)->set(
"zero");
156 world.DLOG(
"zero_def for int is {}",
zero);
158 }
else if (
auto sig = T->isa<
Sigma>()) {
159 auto ops =
DefVec(sig->ops(), [&](
const Def* op) { return world.app(world.annex<zero>(), op); });
160 return world.tuple(ops);
171 auto& world = T->
world();
172 return world.
app(world.app(world.annex<
sum>(), {world.lit_nat(defs.size()), T}), defs);
A (possibly parameterized) Array.
std::string_view node_name() const
static Ref size(Ref def)
Checks if def is an Idx s and returns s, or nullptr otherwise.
A dependent function type.
static const Pi * isa_cn(Ref d)
Is this a continuation - i.e. is the Pi::codom mim::Bottom?
This is a thin wrapper for std::span<T, N> with the following additional features:
Ref app(Ref callee, Ref arg)
The automatic differentiation Plugin
const Pi * autodiff_type_fun_pi(const Pi *)
const Def * op_sum(const Def *T, Defs)
const Def * autodiff_type_fun(const Def *)
const Def * zero_def(const Def *T)
const Def * tangent_type_fun(const Def *)
const Def * zero_pullback(const Def *E, const Def *A)
const Def * id_pullback(const Def *)
void register_normalizers(Normalizers &normalizers)
const Pi * pullback_type(const Def *E, const Def *A)
computes pb type E* -> A* E - type of the expression (return type for a function) A - type of the arg...
Vector< const Def * > DefVec
absl::flat_hash_map< flags_t, std::function< void(World &, PipelineBuilder &, const Def *)> > Passes
axiom ↦ (pipeline part) × (axiom application) → () The function should inspect Application to const...
void register_pass(Passes &passes, CArgs &&... args)
MIM_EXPORT mim::Plugin mim_get_plugin()
absl::flat_hash_map< flags_t, NormalizeFn > Normalizers
Basic info and registration function pointer to be returned from a specific plugin.