MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
autodiff.cpp
Go to the documentation of this file.
2
3#include <mim/config.h>
4#include <mim/phase.h>
5
6#include <mim/plug/mem/mem.h>
7
9
10using namespace std::literals;
11using namespace mim;
12using namespace mim::plug;
13
// Registers the stages provided by the autodiff plugin into `stages`.
// NOTE(review): original lines 15 and 17 are missing from this rendering. Line 17 presumably opens
// the rewrite callback (a `MIM_REPL(stages, ..., {` invocation) whose body follows, and line 15
// presumably hooks a pass into `stages` -- confirm against the real source.
14void reg_stages(Flags2Stages& stages) {
16
// Rewrite rule: an application of the `autodiff::zero` axiom is replaced by the concrete zero that
// zero_def builds for its argument type, when zero_def supports that type.
18 if (auto zero = Axm::isa<autodiff::zero>(def); zero) {
19 if (auto z = autodiff::zero_def(zero->arg())) return z;
20 }
// Empty result (nullptr); presumably signals "no rewrite for this def" to the caller -- confirm.
21 return {};
22 });
23}
24
// Plugin descriptor, made of four fields:
// - the plugin name "autodiff";
// - a callback that registers the autodiff normalizers;
// - reg_stages for stage registration;
// - nullptr for the last field.
// NOTE(review): the function header (original line 25, presumably
// `extern "C" MIM_EXPORT mim::Plugin mim_get_plugin() {`) is missing from this rendering.
26 return {"autodiff", [](Normalizers& n) { autodiff::register_normalizers(n); }, reg_stages, nullptr};
27}
28
29namespace mim::plug::autodiff {
30const Def* id_pullback(const Def* A) {
31 auto& world = A->world();
32 auto arg_pb_ty = pullback_type(A, A);
33 auto id_pb = world.mut_lam(arg_pb_ty)->set("id_pb");
34 auto id_pb_scalar = id_pb->var(0_s)->set("s");
35 id_pb->app(true,
36 id_pb->var(1), // can not use ret_var as the result might be higher order
37 id_pb_scalar);
38
39 return id_pb;
40}
41
42const Def* zero_pullback(const Def* E, const Def* A) {
43 auto& world = A->world();
44 auto A_tangent = tangent_type_fun(A);
45 auto pb_ty = pullback_type(E, A);
46 auto pb = world.mut_lam(pb_ty)->set("zero_pb");
47 world.DLOG("zero_pullback for {} resp. {} (-> {})", E, A, A_tangent);
48 pb->app(true, pb->var(1), world.call<zero>(A_tangent));
49 return pb;
50}
51
52// `P` => `P*`
53// TODO: nothing? function => R? Mem => R?
54// TODO: rename to op_tangent_type
55const Def* tangent_type_fun(const Def* ty) { return ty; }
56
57/// computes pb type `E* -> A*`
58/// `E` - type of the expression (return type for a function)
59/// `A` - type of the argument (point of orientation resp. derivative - argument type for partial pullbacks)
60const Pi* pullback_type(const Def* E, const Def* A) {
61 auto& world = E->world();
62 auto tang_arg = tangent_type_fun(A);
63 auto tang_ret = tangent_type_fun(E);
64 auto pb_ty = world.cn({tang_ret, world.cn(tang_arg)});
65 return pb_ty;
66}
67
68namespace {
69// `A,R` => `(A->R)' = A' -> R' * (R* -> A*)`
// Builds the AD type of a function with argument type `arg` and return type `ret` as
// `world.cn({aug_arg, world.cn({aug_ret, pullback_type(ret, arg)})})`.
// Returns nullptr when `aug_arg` or `aug_ret` is null, i.e. when a component has no AD type.
// NOTE(review): the lines defining `aug_arg` / `aug_ret` (original lines 73-74) are missing from
// this rendering. They are presumably `autodiff_type_fun(arg)` and `autodiff_type_fun(ret)` --
// confirm against the real source.
70const Pi* autodiff_type_fun(const Def* arg, const Def* ret) {
71 auto& world = arg->world();
72 world.DLOG("autodiff type for {} => {}", arg, ret);
75 world.DLOG("augmented types: {} => {}", aug_arg, aug_ret);
76 if (!aug_arg || !aug_ret) return nullptr;
77 // `Q* -> P*`
78 auto pb_ty = pullback_type(ret, arg);
79 world.DLOG("pb type: {}", pb_ty);
80 // `P' -> Q' * (Q* -> P*)`
81
82 auto deriv_ty = world.cn({aug_arg, world.cn({aug_ret, pb_ty})});
83 world.DLOG("autodiff type: {}", deriv_ty);
84 return deriv_ty;
85}
86} // namespace
87
88const Pi* autodiff_type_fun_pi(const Pi* pi) {
89 auto& world = pi->world();
90 if (!Pi::isa_cn(pi)) {
91 // TODO: dependency
92 auto arg = pi->dom();
93 auto ret = pi->codom();
94 if (ret->isa<Pi>()) {
95 auto aug_arg = autodiff_type_fun(arg);
96 if (!aug_arg) return nullptr;
97 auto aug_ret = autodiff_type_fun(pi->codom());
98 if (!aug_ret) return nullptr;
99 return world.pi(aug_arg, aug_ret);
100 }
101 return autodiff_type_fun(arg, ret);
102 }
103 auto [arg, ret_pi] = pi->doms<2>();
104 auto ret = ret_pi->as<Pi>()->dom();
105 world.DLOG("compute AD type for pi");
106 return autodiff_type_fun(arg, ret);
107}
108
109// In general transforms `A` => `A'`.
110// Especially `P->Q` => `P'->Q' * (Q* -> P*)`.
111const Def* autodiff_type_fun(const Def* ty) {
112 auto& world = ty->world();
113 // TODO: handle DS (operators)
114 if (auto pi = ty->isa<Pi>()) return autodiff_type_fun_pi(pi);
115 // Also handles autodiff call from axm declaration => abstract => leave it.
116 world.DLOG("AutoDiff on type: {} <{}>", ty, ty->node_name());
117 if (Idx::isa(ty)) return ty;
118 if (ty == world.type_nat()) return ty;
119 if (auto arr = ty->isa<Arr>()) {
120 auto shape = arr->arity();
121 auto body = arr->body();
122 auto body_ad = autodiff_type_fun(body);
123 if (!body_ad) return nullptr;
124 return world.arr(shape, body_ad);
125 }
126 if (auto sig = ty->isa<Sigma>()) {
127 // TODO: mut sigma
128 auto ops = DefVec(sig->ops(), [&](const Def* op) { return autodiff_type_fun(op); });
129 world.DLOG("ops: {,}", ops);
130 return world.sigma(ops);
131 }
132 // mem
133 if (Axm::isa<mem::M>(ty)) return ty;
134 world.WLOG("no-diff type: {}", ty);
135 return nullptr;
136}
137
138const Def* zero_def(const Def* T) {
139 // TODO: we want: zero mem -> zero mem or bot
140 // zero [A,B,C] -> [zero A, zero B, zero C]
141 auto& world = T->world();
142 world.DLOG("zero_def for type {} <{}>", T, T->node_name());
143 if (auto arr = T->isa<Arr>()) {
144 auto arity = arr->arity();
145 auto body = arr->body();
146 auto inner_zero = world.app(world.annex<zero>(), body);
147 auto zero_arr = world.pack(arity, inner_zero);
148 world.DLOG("zero_def for array of shape {} with type {}", arity, body);
149 world.DLOG("zero_arr: {}", zero_arr);
150 return zero_arr;
151 } else if (Idx::isa(T)) {
152 // TODO: real
153 auto zero = world.lit(T, 0)->set("zero");
154 world.DLOG("zero_def for int is {}", zero);
155 return zero;
156 } else if (auto sig = T->isa<Sigma>()) {
157 auto ops = DefVec(sig->ops(), [&](const Def* op) { return world.app(world.annex<zero>(), op); });
158 return world.tuple(ops);
159 }
160
161 // or return bot
162 // or id => zero T
163 // return world.app(world.annex<zero>(), T);
164 return nullptr;
165}
166
167const Def* op_sum(const Def* T, Defs defs) {
168 // TODO: assert all are of type T
169 auto& world = T->world();
170 return world.app(world.app(world.annex<sum>(), {world.lit_nat(defs.size()), T}), defs);
171}
172
173} // namespace mim::plug::autodiff
void reg_stages(Flags2Stages &stages)
Definition affine.cpp:11
void reg_stages(Flags2Stages &stages)
Definition autodiff.cpp:14
A (possibly parameterized) Array.
Definition tuple.h:117
static auto isa(const Def *def)
Definition axm.h:107
Base class for all Defs.
Definition def.h:251
World & world() const noexcept
Definition def.cpp:439
std::string_view node_name() const
Definition def.cpp:460
static const Def * isa(const Def *def)
Checks if def is an Idx s and returns s or nullptr otherwise.
Definition def.cpp:613
A dependent function type.
Definition lam.h:15
static const Pi * isa_cn(const Def *d)
Is this a continuation - i.e. is the Pi::codom mim::Bottom?
Definition lam.h:48
const Def * dom() const
Definition lam.h:36
const Def * codom() const
Definition lam.h:37
A dependent tuple type.
Definition tuple.h:20
static void hook(Flags2Stages &stages)
Definition pass.h:57
const Def * app(const Def *callee, const Def *arg)
Definition world.cpp:195
#define MIM_EXPORT
Definition config.h:16
The automatic differentiation Plugin
Definition autodiff.h:6
const Pi * autodiff_type_fun_pi(const Pi *)
Definition autodiff.cpp:88
const Def * op_sum(const Def *T, Defs)
Definition autodiff.cpp:167
const Def * autodiff_type_fun(const Def *)
Definition autodiff.cpp:111
const Def * zero_def(const Def *T)
Definition autodiff.cpp:138
const Def * tangent_type_fun(const Def *)
Definition autodiff.cpp:55
const Def * zero_pullback(const Def *E, const Def *A)
Definition autodiff.cpp:42
const Def * id_pullback(const Def *)
Definition autodiff.cpp:30
void register_normalizers(Normalizers &normalizers)
const Pi * pullback_type(const Def *E, const Def *A)
computes pb type E* -> A* E - type of the expression (return type for a function) A - type of the arg...
Definition autodiff.cpp:60
Definition ast.h:14
View< const Def * > Defs
Definition def.h:76
Vector< const Def * > DefVec
Definition def.h:77
absl::flat_hash_map< flags_t, std::function< std::unique_ptr< Stage >(World &)> > Flags2Stages
Maps an axiom of a Stage to a function that creates one.
Definition plugin.h:22
mim::Plugin mim_get_plugin()
absl::flat_hash_map< flags_t, NormalizeFn > Normalizers
Definition plugin.h:19
#define MIM_REPL(__stages, __annex,...)
Definition phase.h:141
Basic info and registration function pointer to be returned from a specific plugin.
Definition plugin.h:30