MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
autodiff.cpp
Go to the documentation of this file.
2
3#include <mim/config.h>
4
5#include <mim/pass/pass.h>
6
7#include <mim/plug/mem/mem.h>
8
12
13using namespace std::literals;
14using namespace mim;
15using namespace mim::plug;
16
24
26 return {"autodiff", [](Normalizers& n) { autodiff::register_normalizers(n); }, reg_stages, nullptr};
27}
28
29namespace mim::plug::autodiff {
30
31const Def* id_pullback(const Def* A) {
32 auto& world = A->world();
33 auto arg_pb_ty = pullback_type(A, A);
34 auto id_pb = world.mut_lam(arg_pb_ty)->set("id_pb");
35 auto id_pb_scalar = id_pb->var(0_s)->set("s");
36 id_pb->app(true,
37 id_pb->var(1), // can not use ret_var as the result might be higher order
38 id_pb_scalar);
39
40 return id_pb;
41}
42
43const Def* zero_pullback(const Def* E, const Def* A) {
44 auto& world = A->world();
45 auto A_tangent = tangent_type_fun(A);
46 auto pb_ty = pullback_type(E, A);
47 auto pb = world.mut_lam(pb_ty)->set("zero_pb");
48 world.DLOG("zero_pullback for {} resp. {} (-> {})", E, A, A_tangent);
49 pb->app(true, pb->var(1), world.call<zero>(A_tangent));
50 return pb;
51}
52
53// `P` => `P*`
54// TODO: nothing? function => R? Mem => R?
55// TODO: rename to op_tangent_type
56const Def* tangent_type_fun(const Def* ty) { return ty; }
57
58/// computes pb type `E* -> A*`
59/// `E` - type of the expression (return type for a function)
60/// `A` - type of the argument (point of orientation resp. derivative - argument type for partial pullbacks)
61const Pi* pullback_type(const Def* E, const Def* A) {
62 auto& world = E->world();
63 auto tang_arg = tangent_type_fun(A);
64 auto tang_ret = tangent_type_fun(E);
65 auto pb_ty = world.cn({tang_ret, world.cn(tang_arg)});
66 return pb_ty;
67}
68
69namespace {
70// `A,R` => `(A->R)' = A' -> R' * (R* -> A*)`
// File-local helper: builds the continuation type of the derivative of a function from `arg` to `ret`.
// Returns nullptr when either augmented type is null.
// NOTE(review): the definitions of `aug_arg` and `aug_ret` (original lines 74-75) are missing from
// this listing; presumably they are the augmented forms of `arg` and `ret` - confirm against the real file.
71const Pi* autodiff_type_fun(const Def* arg, const Def* ret) {
72 auto& world = arg->world();
73 world.DLOG("autodiff type for {} => {}", arg, ret);
76 world.DLOG("augmented types: {} => {}", aug_arg, aug_ret);
77 if (!aug_arg || !aug_ret) return nullptr;
78 // `Q* -> P*`
79 auto pb_ty = pullback_type(ret, arg);
80 world.DLOG("pb type: {}", pb_ty);
81 // `P' -> Q' * (Q* -> P*)`
82
// The result is a continuation taking the augmented argument and a return continuation
// that receives the augmented result together with the pullback.
83 auto deriv_ty = world.cn({aug_arg, world.cn({aug_ret, pb_ty})});
84 world.DLOG("autodiff type: {}", deriv_ty);
85 return deriv_ty;
86}
87} // namespace
88
89const Pi* autodiff_type_fun_pi(const Pi* pi) {
90 auto& world = pi->world();
91 if (!Pi::isa_cn(pi)) {
92 // TODO: dependency
93 auto arg = pi->dom();
94 auto ret = pi->codom();
95 if (ret->isa<Pi>()) {
96 auto aug_arg = autodiff_type_fun(arg);
97 if (!aug_arg) return nullptr;
98 auto aug_ret = autodiff_type_fun(pi->codom());
99 if (!aug_ret) return nullptr;
100 return world.pi(aug_arg, aug_ret);
101 }
102 return autodiff_type_fun(arg, ret);
103 }
104 auto [arg, ret_pi] = pi->doms<2>();
105 auto ret = ret_pi->as<Pi>()->dom();
106 world.DLOG("compute AD type for pi");
107 return autodiff_type_fun(arg, ret);
108}
109
110// In general transforms `A` => `A'`.
111// Especially `P->Q` => `P'->Q' * (Q* -> P*)`.
112const Def* autodiff_type_fun(const Def* ty) {
113 auto& world = ty->world();
114 // TODO: handle DS (operators)
115 if (auto pi = ty->isa<Pi>()) return autodiff_type_fun_pi(pi);
116 // Also handles autodiff call from axm declaration => abstract => leave it.
117 world.DLOG("AutoDiff on type: {} <{}>", ty, ty->node_name());
118 if (Idx::isa(ty)) return ty;
119 if (ty == world.type_nat()) return ty;
120 if (auto arr = ty->isa<Arr>()) {
121 auto shape = arr->arity();
122 auto body = arr->body();
123 auto body_ad = autodiff_type_fun(body);
124 if (!body_ad) return nullptr;
125 return world.arr(shape, body_ad);
126 }
127 if (auto sig = ty->isa<Sigma>()) {
128 // TODO: mut sigma
129 auto ops = DefVec(sig->ops(), [&](const Def* op) { return autodiff_type_fun(op); });
130 world.DLOG("ops: {,}", ops);
131 return world.sigma(ops);
132 }
133 // mem
134 if (Axm::isa<mem::M>(ty)) return ty;
135 world.WLOG("no-diff type: {}", ty);
136 return nullptr;
137}
138
139const Def* zero_def(const Def* T) {
140 // TODO: we want: zero mem -> zero mem or bot
141 // zero [A,B,C] -> [zero A, zero B, zero C]
142 auto& world = T->world();
143 world.DLOG("zero_def for type {} <{}>", T, T->node_name());
144 if (auto arr = T->isa<Arr>()) {
145 auto arity = arr->arity();
146 auto body = arr->body();
147 auto inner_zero = world.app(world.annex<zero>(), body);
148 auto zero_arr = world.pack(arity, inner_zero);
149 world.DLOG("zero_def for array of shape {} with type {}", arity, body);
150 world.DLOG("zero_arr: {}", zero_arr);
151 return zero_arr;
152 } else if (Idx::isa(T)) {
153 // TODO: real
154 auto zero = world.lit(T, 0)->set("zero");
155 world.DLOG("zero_def for int is {}", zero);
156 return zero;
157 } else if (auto sig = T->isa<Sigma>()) {
158 auto ops = DefVec(sig->ops(), [&](const Def* op) { return world.app(world.annex<zero>(), op); });
159 return world.tuple(ops);
160 }
161
162 // or return bot
163 // or id => zero T
164 // return world.app(world.annex<zero>(), T);
165 return nullptr;
166}
167
168const Def* op_sum(const Def* T, Defs defs) {
169 // TODO: assert all are of type T
170 auto& world = T->world();
171 return world.app(world.app(world.annex<sum>(), {world.lit_nat(defs.size()), T}), defs);
172}
173
174} // namespace mim::plug::autodiff
void reg_stages(Flags2Phases &, Flags2Passes &passes)
Definition affine.cpp:12
void reg_stages(Flags2Phases &, Flags2Passes &passes)
Definition autodiff.cpp:17
A (possibly parameterized) Array.
Definition tuple.h:117
static auto isa(const Def *def)
Definition axm.h:107
Base class for all Defs.
Definition def.h:251
World & world() const noexcept
Definition def.cpp:436
std::string_view node_name() const
Definition def.cpp:454
static const Def * isa(const Def *def)
Checks if def is an Idx s and returns s or nullptr otherwise.
Definition def.cpp:605
static void hook(Flags2Passes &passes, Args &&... args)
Definition pass.h:157
A dependent function type.
Definition lam.h:13
static const Pi * isa_cn(const Def *d)
Is this a continuation - i.e. is the Pi::codom mim::Bottom?
Definition lam.h:46
const Def * dom() const
Definition lam.h:34
const Def * codom() const
Definition lam.h:35
A dependent tuple type.
Definition tuple.h:20
const Def * app(const Def *callee, const Def *arg)
Definition world.cpp:196
#define MIM_EXPORT
Definition config.h:16
The automatic differentiation Plugin
Definition autodiff.h:6
const Pi * autodiff_type_fun_pi(const Pi *)
Definition autodiff.cpp:89
const Def * op_sum(const Def *T, Defs)
Definition autodiff.cpp:168
const Def * autodiff_type_fun(const Def *)
Definition autodiff.cpp:112
const Def * zero_def(const Def *T)
Definition autodiff.cpp:139
const Def * tangent_type_fun(const Def *)
Definition autodiff.cpp:56
const Def * zero_pullback(const Def *E, const Def *A)
Definition autodiff.cpp:43
const Def * id_pullback(const Def *)
Definition autodiff.cpp:31
void register_normalizers(Normalizers &normalizers)
const Pi * pullback_type(const Def *E, const Def *A)
computes pb type E* -> A* E - type of the expression (return type for a function) A - type of the arg...
Definition autodiff.cpp:61
Definition ast.h:14
View< const Def * > Defs
Definition def.h:76
Vector< const Def * > DefVec
Definition def.h:77
absl::flat_hash_map< flags_t, std::function< void(PassMan &, const Def *)> > Flags2Passes
Definition plugin.h:24
mim::Plugin mim_get_plugin()
absl::flat_hash_map< flags_t, NormalizeFn > Normalizers
Definition plugin.h:20
absl::flat_hash_map< flags_t, std::function< void(PhaseMan &, const Def *)> > Flags2Phases
Maps an axiom of a Pass/Phase to a function that appends a new Pass/Phase to a PhaseMan.
Definition plugin.h:23
Basic info and registration function pointer to be returned from a specific plugin.
Definition plugin.h:32