13using namespace std::literals;
32 auto& world = A->
world();
34 auto id_pb = world.mut_lam(arg_pb_ty)->set(
"id_pb");
35 auto id_pb_scalar = id_pb->var(0_s)->set(
"s");
44 auto& world = A->
world();
47 auto pb = world.mut_lam(pb_ty)->set(
"zero_pb");
48 world.DLOG(
"zero_pullback for {} resp. {} (-> {})", E, A, A_tangent);
49 pb->app(
true, pb->var(1), world.call<
zero>(A_tangent));
62 auto& world = E->
world();
65 auto pb_ty = world.cn({tang_ret, world.cn(tang_arg)});
72 auto& world = arg->
world();
73 world.DLOG(
"autodiff type for {} => {}", arg, ret);
76 world.DLOG(
"augmented types: {} => {}", aug_arg, aug_ret);
77 if (!aug_arg || !aug_ret)
return nullptr;
80 world.DLOG(
"pb type: {}", pb_ty);
83 auto deriv_ty = world.cn({aug_arg, world.cn({aug_ret, pb_ty})});
84 world.DLOG(
"autodiff type: {}", deriv_ty);
90 auto& world = pi->
world();
94 auto ret = pi->
codom();
97 if (!aug_arg)
return nullptr;
99 if (!aug_ret)
return nullptr;
100 return world.pi(aug_arg, aug_ret);
104 auto [arg, ret_pi] = pi->doms<2>();
105 auto ret = ret_pi->as<
Pi>()->dom();
106 world.DLOG(
"compute AD type for pi");
113 auto& world = ty->
world();
117 world.DLOG(
"AutoDiff on type: {} <{}>", ty, ty->
node_name());
119 if (ty == world.type_nat())
return ty;
120 if (
auto arr = ty->isa<
Arr>()) {
121 auto shape = arr->arity();
122 auto body = arr->body();
124 if (!body_ad)
return nullptr;
125 return world.arr(shape, body_ad);
127 if (
auto sig = ty->isa<
Sigma>()) {
129 auto ops =
DefVec(sig->ops(), [&](
const Def* op) { return autodiff_type_fun(op); });
130 world.DLOG(
"ops: {,}", ops);
131 return world.sigma(ops);
135 world.WLOG(
"no-diff type: {}", ty);
142 auto& world = T->
world();
143 world.DLOG(
"zero_def for type {} <{}>", T, T->
node_name());
144 if (
auto arr = T->isa<
Arr>()) {
145 auto arity = arr->arity();
146 auto body = arr->body();
147 auto inner_zero = world.app(world.annex<
zero>(), body);
148 auto zero_arr = world.pack(arity, inner_zero);
149 world.DLOG(
"zero_def for array of shape {} with type {}", arity, body);
150 world.DLOG(
"zero_arr: {}", zero_arr);
154 auto zero = world.lit(T, 0)->set(
"zero");
155 world.DLOG(
"zero_def for int is {}",
zero);
157 }
else if (
auto sig = T->isa<
Sigma>()) {
158 auto ops =
DefVec(sig->ops(), [&](
const Def* op) { return world.app(world.annex<zero>(), op); });
159 return world.tuple(ops);
170 auto& world = T->
world();
171 return world.
app(world.app(world.annex<
sum>(), {world.lit_nat(defs.size()), T}), defs);
void reg_stages(Flags2Phases &, Flags2Passes &passes)
void reg_stages(Flags2Phases &, Flags2Passes &passes)
A (possibly parameterized) Array.
static auto isa(const Def *def)
World & world() const noexcept
std::string_view node_name() const
static const Def * isa(const Def *def)
Checks if def is an Idx s and returns s, or nullptr otherwise.
static void hook(Flags2Passes &passes, Args &&... args)
A dependent function type.
static const Pi * isa_cn(const Def *d)
Is this a continuation - i.e. is the Pi::codom mim::Bottom?
const Def * codom() const
const Def * app(const Def *callee, const Def *arg)
The automatic differentiation Plugin
const Pi * autodiff_type_fun_pi(const Pi *)
const Def * op_sum(const Def *T, Defs)
const Def * autodiff_type_fun(const Def *)
const Def * zero_def(const Def *T)
const Def * tangent_type_fun(const Def *)
const Def * zero_pullback(const Def *E, const Def *A)
const Def * id_pullback(const Def *)
void register_normalizers(Normalizers &normalizers)
const Pi * pullback_type(const Def *E, const Def *A)
Computes the pullback (pb) type E* -> A*, where E is the type of the expression (the return type, for a function) and A is the type of the arg...
Vector< const Def * > DefVec
absl::flat_hash_map< flags_t, std::function< void(PassMan &, const Def *)> > Flags2Passes
mim::Plugin mim_get_plugin()
absl::flat_hash_map< flags_t, NormalizeFn > Normalizers
absl::flat_hash_map< flags_t, std::function< void(PhaseMan &, const Def *)> > Flags2Phases
Maps an axiom of a Pass/Phase to a function that appends a new Pass/Phase to a PhaseMan.
Basic info and registration function pointer to be returned from a specific plugin.