MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
autodiff.cpp
Go to the documentation of this file.
2
3#include <mim/config.h>
4
5#include <mim/pass/pass.h>
7
9#include <mim/plug/mem/mem.h>
10
14
15using namespace std::literals;
16using namespace mim;
17using namespace mim::plug;
18
29
30namespace mim::plug::autodiff {
31
32const Def* id_pullback(const Def* A) {
33 auto& world = A->world();
34 auto arg_pb_ty = pullback_type(A, A);
35 auto id_pb = world.mut_lam(arg_pb_ty)->set("id_pb");
36 auto id_pb_scalar = id_pb->var(0_s)->set("s");
37 id_pb->app(true,
38 id_pb->var(1), // can not use ret_var as the result might be higher order
39 id_pb_scalar);
40
41 return id_pb;
42}
43
44const Def* zero_pullback(const Def* E, const Def* A) {
45 auto& world = A->world();
46 auto A_tangent = tangent_type_fun(A);
47 auto pb_ty = pullback_type(E, A);
48 auto pb = world.mut_lam(pb_ty)->set("zero_pb");
49 world.DLOG("zero_pullback for {} resp. {} (-> {})", E, A, A_tangent);
50 pb->app(true, pb->var(1), world.call<zero>(A_tangent));
51 return pb;
52}
53
54// `P` => `P*`
55// TODO: nothing? function => R? Mem => R?
56// TODO: rename to op_tangent_type
57const Def* tangent_type_fun(const Def* ty) { return ty; }
58
59/// computes pb type `E* -> A*`
60/// `E` - type of the expression (return type for a function)
61/// `A` - type of the argument (point of orientation resp. derivative - argument type for partial pullbacks)
62const Pi* pullback_type(const Def* E, const Def* A) {
63 auto& world = E->world();
64 auto tang_arg = tangent_type_fun(A);
65 auto tang_ret = tangent_type_fun(E);
66 auto pb_ty = world.cn({tang_ret, world.cn(tang_arg)});
67 return pb_ty;
68}
69
70namespace {
71// `A,R` => `(A->R)' = A' -> R' * (R* -> A*)`
72const Pi* autodiff_type_fun(const Def* arg, const Def* ret) {
73 auto& world = arg->world();
74 world.DLOG("autodiff type for {} => {}", arg, ret);
77 world.DLOG("augmented types: {} => {}", aug_arg, aug_ret);
78 if (!aug_arg || !aug_ret) return nullptr;
79 // `Q* -> P*`
80 auto pb_ty = pullback_type(ret, arg);
81 world.DLOG("pb type: {}", pb_ty);
82 // `P' -> Q' * (Q* -> P*)`
83
84 auto deriv_ty = world.cn({aug_arg, world.cn({aug_ret, pb_ty})});
85 world.DLOG("autodiff type: {}", deriv_ty);
86 return deriv_ty;
87}
88} // namespace
89
90const Pi* autodiff_type_fun_pi(const Pi* pi) {
91 auto& world = pi->world();
92 if (!Pi::isa_cn(pi)) {
93 // TODO: dependency
94 auto arg = pi->dom();
95 auto ret = pi->codom();
96 if (ret->isa<Pi>()) {
97 auto aug_arg = autodiff_type_fun(arg);
98 if (!aug_arg) return nullptr;
99 auto aug_ret = autodiff_type_fun(pi->codom());
100 if (!aug_ret) return nullptr;
101 return world.pi(aug_arg, aug_ret);
102 }
103 return autodiff_type_fun(arg, ret);
104 }
105 auto [arg, ret_pi] = pi->doms<2>();
106 auto ret = ret_pi->as<Pi>()->dom();
107 world.DLOG("compute AD type for pi");
108 return autodiff_type_fun(arg, ret);
109}
110
111// In general transforms `A` => `A'`.
112// Especially `P->Q` => `P'->Q' * (Q* -> P*)`.
113const Def* autodiff_type_fun(const Def* ty) {
114 auto& world = ty->world();
115 // TODO: handle DS (operators)
116 if (auto pi = ty->isa<Pi>()) return autodiff_type_fun_pi(pi);
117 // Also handles autodiff call from axiom declaration => abstract => leave it.
118 world.DLOG("AutoDiff on type: {} <{}>", ty, ty->node_name());
119 if (Idx::size(ty)) return ty;
120 if (ty == world.type_nat()) return ty;
121 if (auto arr = ty->isa<Arr>()) {
122 auto shape = arr->shape();
123 auto body = arr->body();
124 auto body_ad = autodiff_type_fun(body);
125 if (!body_ad) return nullptr;
126 return world.arr(shape, body_ad);
127 }
128 if (auto sig = ty->isa<Sigma>()) {
129 // TODO: mut sigma
130 auto ops = DefVec(sig->ops(), [&](const Def* op) { return autodiff_type_fun(op); });
131 world.DLOG("ops: {,}", ops);
132 return world.sigma(ops);
133 }
134 // mem
135 if (match<mem::M>(ty)) return ty;
136 world.WLOG("no-diff type: {}", ty);
137 return nullptr;
138}
139
140const Def* zero_def(const Def* T) {
141 // TODO: we want: zero mem -> zero mem or bot
142 // zero [A,B,C] -> [zero A, zero B, zero C]
143 auto& world = T->world();
144 world.DLOG("zero_def for type {} <{}>", T, T->node_name());
145 if (auto arr = T->isa<Arr>()) {
146 auto shape = arr->shape();
147 auto body = arr->body();
148 auto inner_zero = world.app(world.annex<zero>(), body);
149 auto zero_arr = world.pack(shape, inner_zero);
150 world.DLOG("zero_def for array of shape {} with type {}", shape, body);
151 world.DLOG("zero_arr: {}", zero_arr);
152 return zero_arr;
153 } else if (Idx::size(T)) {
154 // TODO: real
155 auto zero = world.lit(T, 0)->set("zero");
156 world.DLOG("zero_def for int is {}", zero);
157 return zero;
158 } else if (auto sig = T->isa<Sigma>()) {
159 auto ops = DefVec(sig->ops(), [&](const Def* op) { return world.app(world.annex<zero>(), op); });
160 return world.tuple(ops);
161 }
162
163 // or return bot
164 // or id => zero T
165 // return world.app(world.annex<zero>(), T);
166 return nullptr;
167}
168
169const Def* op_sum(const Def* T, Defs defs) {
170 // TODO: assert all are of type T
171 auto& world = T->world();
172 return world.app(world.app(world.annex<sum>(), {world.lit_nat(defs.size()), T}), defs);
173}
174
175} // namespace mim::plug::autodiff
A (possibly parameterized) Array.
Definition tuple.h:66
Base class for all Defs.
Definition def.h:223
std::string_view node_name() const
Definition def.cpp:431
World & world() const
Definition def.cpp:415
static Ref size(Ref def)
Checks if def is an Idx s and returns s, or nullptr otherwise.
Definition def.cpp:556
A dependent function type.
Definition lam.h:11
Ref codom() const
Definition lam.h:40
static const Pi * isa_cn(Ref d)
Is this a continuation - i.e. is the Pi::codom mim::Bottom?
Definition lam.h:50
Ref dom() const
Definition lam.h:32
A dependent tuple type.
Definition tuple.h:9
This is a thin wrapper for std::span<T, N> with the following additional features:
Definition span.h:28
Ref app(Ref callee, Ref arg)
Definition world.cpp:186
#define MIM_EXPORT
Definition config.h:16
The automatic differentiation Plugin
Definition autodiff.h:6
const Pi * autodiff_type_fun_pi(const Pi *)
Definition autodiff.cpp:90
const Def * op_sum(const Def *T, Defs)
Definition autodiff.cpp:169
const Def * autodiff_type_fun(const Def *)
Definition autodiff.cpp:113
const Def * zero_def(const Def *T)
Definition autodiff.cpp:140
const Def * tangent_type_fun(const Def *)
Definition autodiff.cpp:57
const Def * zero_pullback(const Def *E, const Def *A)
Definition autodiff.cpp:44
const Def * id_pullback(const Def *)
Definition autodiff.cpp:32
void register_normalizers(Normalizers &normalizers)
const Pi * pullback_type(const Def *E, const Def *A)
computes pb type E* -> A* E - type of the expression (return type for a function) A - type of the arg...
Definition autodiff.cpp:62
Definition cfg.h:11
Vector< const Def * > DefVec
Definition def.h:62
auto match(Ref def)
Definition axiom.h:112
absl::flat_hash_map< flags_t, std::function< void(World &, PipelineBuilder &, const Def *)> > Passes
axiom ↦ (pipeline part) × (axiom application) → () The function should inspect Application to const...
Definition plugin.h:22
void register_pass(Passes &passes, CArgs &&... args)
MIM_EXPORT mim::Plugin mim_get_plugin()
absl::flat_hash_map< flags_t, NormalizeFn > Normalizers
Definition plugin.h:19
Basic info and registration function pointer to be returned from a specific plugin.
Definition plugin.h:29