MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
autodiff_rewrite_inner.cpp
Go to the documentation of this file.
2
5
6using namespace std::literals;
7
8namespace mim::plug::autodiff {
9
10const Def* Eval::augment_lit(const Lit* lit, Lam* f, Lam*) {
11 auto pb = zero_pullback(lit->type(), f->dom(2, 0));
12 partial_pullback[lit] = pb;
13 return lit;
14}
15
16const Def* Eval::augment_var(const Var* var, Lam*, Lam*) {
17 assert(augmented.count(var));
18 auto aug_var = augmented[var];
19 assert(partial_pullback.count(aug_var));
20 return var;
21}
22
23const Def* Eval::augment_lam(Lam* lam, Lam* f, Lam* f_diff) {
24 // TODO: we need partial pullbacks for tuples (higher-order / ret-cont application)
25 // also for higher-order args, ret_cont (at another point)
26 // the pullback is not important but formally required by tuple rule
27 if (augmented.count(lam)) {
28 // We already know the function:
29 // * recursion
30 // * higher order arguments
31 // * new encounter of previous function
32 DLOG("already augmented {} : {} to {} : {}", lam, lam->type(), augmented[lam], augmented[lam]->type());
33 return augmented[lam];
34 }
35 // TODO: better fix (another pass as analysis?)
36 // TODO: handle open functions
37 if (Lam::isa_basicblock(lam) || lam->sym().view().find("ret") != std::string::npos
38 || lam->sym().view().find("_cont") != std::string::npos) {
39 // A open continuation behaves the same as return:
40 // ```
41 // cont: Cn[X]
42 // cont': Cn[X,Cn[X,A]]
43 // ```
44 // There is dependency on the closed function context.
45 // (All derivatives are with respect to the arguments of a closed function.)
46
47 DLOG("found an open continuation {} : {}", lam, lam->type());
48 auto cont_dom = lam->type()->dom(); // not only 0 but all
49 auto pb_ty = pullback_type(cont_dom, f->dom(2, 0));
50 auto aug_dom = autodiff_type_fun(cont_dom);
51 DLOG("augmented domain {}", aug_dom);
52 DLOG("pb type is {}", pb_ty);
53 auto aug_lam = world().mut_con({aug_dom, pb_ty})->set("aug_"s + lam->sym().str());
54 auto aug_var = aug_lam->var((nat_t)0);
55 augmented[lam->var()] = aug_var;
56 augmented[lam] = aug_lam; // TODO: only one of these two
57 derived[lam] = aug_lam;
58 auto pb = aug_lam->var(1);
59 partial_pullback[aug_var] = pb;
60 // We are still in same closed function.
61 auto new_body = augment(lam->body(), f, f_diff);
62 // TODO we also need to rewrite the filter
63 aug_lam->set(lam->filter(), new_body);
64
65 auto lam_pb = zero_pullback(lam->type(), f->dom(2, 0));
66 partial_pullback[aug_lam] = lam_pb;
67 DLOG("augmented {} : {}", lam, lam->type());
68 DLOG("to {} : {}", aug_lam, aug_lam->type());
69 DLOG("ppb for lam cont: {}", lam_pb);
70
71 return aug_lam;
72 }
73 DLOG("found a closed function call {} : {}", lam, lam->type());
74 // Some general function in the program needs to be differentiated.
75 auto aug_lam = world().call<ad>(lam);
76 // TODO: directly more association here? => partly inline op_autodiff
77 DLOG("augmented function is {} : {}", aug_lam, aug_lam->type());
78 return aug_lam;
79}
80
81const Def* Eval::augment_extract(const Extract* ext, Lam* f, Lam* f_diff) {
82 auto tuple = ext->tuple();
83 auto index = ext->index();
84
85 auto aug_tuple = augment(tuple, f, f_diff);
86 auto aug_index = augment(index, f, f_diff);
87
88 const Def* pb;
89 DLOG("tuple was: {} : {}", tuple, tuple->type());
90 DLOG("aug tuple: {} : {}", aug_tuple, aug_tuple->type());
91 if (shadow_pullback.count(aug_tuple)) {
92 auto shadow_tuple_pb = shadow_pullback[aug_tuple];
93 DLOG("Shadow pullback: {} : {}", shadow_tuple_pb, shadow_tuple_pb->type());
94 pb = world().extract(shadow_tuple_pb, aug_index);
95 } else {
96 // ```
97 // e:T, b:B
98 // b = e#i
99 // b* = \lambda (s:B). e* (insert s at i in (zero T))
100 // ```
101 assert(partial_pullback.count(aug_tuple));
102 auto tuple_pb = partial_pullback[aug_tuple];
103 auto pb_ty = pullback_type(ext->type(), f->dom(2, 0));
104 auto pb_fun = world().mut_lam(pb_ty)->set("extract_pb");
105 DLOG("Pullback: {} : {}", pb_fun, pb_fun->type());
106 auto pb_tangent = pb_fun->var(0_s)->set("s");
107 auto tuple_tan = world().insert(world().call<zero>(aug_tuple->type()), aug_index, pb_tangent)->set("tup_s");
108 pb_fun->app(true, tuple_pb, {tuple_tan, pb_fun->var(1) /* ret_var but make sure to select correct one */});
109 pb = pb_fun;
110 }
111
112 auto aug_ext = world().extract(aug_tuple, aug_index);
113 partial_pullback[aug_ext] = pb;
114
115 return aug_ext;
116}
117
118const Def* Eval::augment_tuple(const Tuple* tup, Lam* f, Lam* f_diff) {
119 // TODO: should use ops instead?
120 auto aug_ops = tup->projs([&](const Def* op) -> const Def* { return augment(op, f, f_diff); });
121 auto aug_tup = world().tuple(aug_ops);
122
123 auto pbs = DefVec(Defs(aug_ops), [&](const Def* op) { return partial_pullback[op]; });
124 DLOG("tuple pbs {,}", pbs);
125 // shadow pb = tuple of pbs
126 auto shadow_pb = world().tuple(pbs);
127 shadow_pullback[aug_tup] = shadow_pb;
128
129 // ```
130 // \lambda (s:[E0,...,Em]).
131 // sum (m,A)
132 // ((cps2ds e0*) (s#0), ..., (cps2ds em*) (s#m))
133 // ```
134 auto pb_ty = pullback_type(tup->type(), f->dom(2, 0));
135 auto pb = world().mut_lam(pb_ty)->set("tup_pb");
136 DLOG("Augmented tuple: {} : {}", aug_tup, aug_tup->type());
137 DLOG("Tuple Pullback: {} : {}", pb, pb->type());
138 DLOG("shadow pb: {} : {}", shadow_pb, shadow_pb->type());
139
140 auto pb_tangent = pb->var(0_s)->set("tup_s");
141
142 auto tangents = DefVec(pbs.size(), [&](nat_t i) {
143 return world().app(direct::op_cps2ds_dep(pbs[i]), world().extract(pb_tangent, i));
144 });
145 pb->app(true, pb->var(1),
146 // summed up tangents
147 op_sum(tangent_type_fun(f->dom(2, 0)), tangents));
148 partial_pullback[aug_tup] = pb;
149
150 return aug_tup;
151}
152
153const Def* Eval::augment_pack(const Pack* pack, Lam* f, Lam* f_diff) {
154 auto arity = pack->arity(); // TODO: arity vs shape
155 auto body = pack->body();
156
157 auto aug_arity = augment_(arity, f, f_diff);
158 auto aug_body = augment(body, f, f_diff);
159
160 auto aug_pack = world().pack(aug_arity, aug_body);
161
162 assert(partial_pullback[aug_body] && "pack pullback should exists");
163 // TODO: or use scale axm
164 auto body_pb = partial_pullback[aug_body];
165 auto pb_pack = world().pack(aug_arity, body_pb);
166 shadow_pullback[aug_pack] = pb_pack;
167
168 DLOG("shadow pb of pack: {} : {}", pb_pack, pb_pack->type());
169
170 auto pb_type = pullback_type(pack->type(), f->dom(2, 0));
171 auto pb = world().mut_lam(pb_type)->set("pack_pb");
172
173 DLOG("pb of pack: {} : {}", pb, pb_type);
174
175 auto f_arg_ty_diff = tangent_type_fun(f->dom(2, 0));
176 auto app_pb = world().mut_pack(world().arr(aug_arity, f_arg_ty_diff));
177
178 // TODO: special case for const width (special tuple)
179
180 // <i:n, cps2ds body_pb (s#i)>
181 app_pb->set(world().app(direct::op_cps2ds_dep(body_pb), world().extract(pb->var((nat_t)0), app_pb->var())));
182
183 DLOG("app pb of pack: {} : {}", app_pb, app_pb->type());
184
185 auto sumup = world().app(world().annex<sum>(), {aug_arity, f_arg_ty_diff});
186 DLOG("sumup: {} : {}", sumup, sumup->type());
187
188 pb->app(true, pb->var(1), world().app(sumup, app_pb));
189
190 partial_pullback[aug_pack] = pb;
191
192 return aug_pack;
193}
194
195const Def* Eval::augment_app(const App* app, Lam* f, Lam* f_diff) {
196 auto callee = app->callee();
197 auto arg = app->arg();
198
199 auto aug_arg = augment(arg, f, f_diff);
200 auto aug_callee = augment(callee, f, f_diff);
201
202 DLOG("augmented argument <{}> {} : {}", aug_arg->unique_name(), aug_arg, aug_arg->type());
203 DLOG("augmented callee <{}> {} : {}", aug_callee->unique_name(), aug_callee, aug_callee->type());
204 // TODO: move down to if(!is_cont(callee))
205 if (!Pi::isa_cn(callee->type()) && Pi::isa_cn(aug_callee->type())) {
206 aug_callee = direct::op_cps2ds_dep(aug_callee);
207 DLOG("wrapped augmented callee: <{}> {} : {}", aug_callee->unique_name(), aug_callee, aug_callee->type());
208 }
209
210 // nested (inner application)
211 if (app->type()->isa<Pi>()) {
212 DLOG("Nested application callee: {} : {}", aug_callee, aug_callee->type());
213 DLOG("Nested application arg: {} : {}", aug_arg, aug_arg->type());
214 auto aug_app = world().app(aug_callee, aug_arg);
215 DLOG("Nested application result: <{}> {} : {}", aug_app->unique_name(), aug_app, aug_app->type());
216 // We do not add a pullback as the pullback is bundled in the cps call or returned by the ds call
217 return aug_app;
218 }
219
220 // continuation (ret, if, ...)
221 if (Pi::isa_basicblock(callee->type())) {
222 // TODO: check if function (not operator)
223 // The original function is an open function (return cont / continuation) of type `Cn[E]`
224 // The augmented function `aug_callee` looks like a function but is not really a function has the type `Cn[E,
225 // Cn[E, Cn[A]]]`
226
227 // ret(e) => ret'(e, e*)
228
229 DLOG("continuation {} : {} => {} : {}", callee, callee->type(), aug_callee, aug_callee->type());
230
231 auto arg_pb = partial_pullback[aug_arg];
232 auto aug_app = world().app(aug_callee, {aug_arg, arg_pb});
233 DLOG("Augmented application: {} : {}", aug_app, aug_app->type());
234 return aug_app;
235 }
236
237 // ds function
238 if (!Pi::isa_cn(callee->type())) {
239 auto aug_app = world().app(aug_callee, aug_arg);
240 DLOG("Augmented application: <{}> {} : {}", aug_app->unique_name(), aug_app, aug_app->type());
241
242 DLOG("ds function: {} : {}", aug_app, aug_app->type());
243 // The calle is ds function (e.g. operator (or its partial application))
244 auto [aug_res, fun_pb] = aug_app->projs<2>();
245 // We compose `fun_pb` with `argument_pb` to get the result pb
246 // TODO: combine case with cps function case
247 auto arg_pb = partial_pullback[aug_arg];
248 assert(arg_pb);
249 // `fun_pb: out_tan -> arg_tan`
250 // `arg_pb: arg_tan -> fun_tan`
251 DLOG("function pullback: {} : {}", fun_pb, fun_pb->type());
252 DLOG("argument pullback: {} : {}", arg_pb, arg_pb->type());
253 auto res_pb = compose_cn(arg_pb, fun_pb);
254 DLOG("result pullback: {} : {}", res_pb, res_pb->type());
255 partial_pullback[aug_res] = res_pb;
256 world().debug_dump();
257 return aug_res;
258 }
259
260 // TODO: dest with a function such that f args != g args
261 {
262 // normal function app
263 // ```
264 // g: cn[E, cn X]
265 // g(args,cont)
266 // g': cn[E, cn[X, cn[X, cn E]]]
267 // g'(aug_args, ____)
268 // ```
269 auto g = callee;
270 // At this point g_deriv might still be "autodiff ... g".
271 auto g_deriv = aug_callee;
272 DLOG("g: {} : {}", g, g->type());
273 DLOG("g': {} : {}", g_deriv, g_deriv->type());
274
275 auto [real_aug_args, aug_cont] = aug_arg->projs<2>();
276 DLOG("real_aug_args: {} : {}", real_aug_args, real_aug_args->type());
277 DLOG("aug_cont: {} : {}", aug_cont, aug_cont->type());
278 auto e_pb = partial_pullback[real_aug_args];
279 DLOG("e_pb (arg_pb): {} : {}", e_pb, e_pb->type());
280
281 // TODO: better debug names
282 auto ret_g_deriv_ty = g_deriv->type()->as<Pi>()->dom(1);
283 DLOG("ret_g_deriv_ty: {} ", ret_g_deriv_ty);
284 auto c1_ty = ret_g_deriv_ty->as<Pi>();
285 DLOG("c1_ty: (cn[X, cn[X+, cn E+]]) {}", c1_ty);
286 auto c1 = world().mut_lam(c1_ty)->set("c1");
287 auto res = c1->var((nat_t)0);
288 auto r_pb = c1->var(1);
289 c1->app(true, aug_cont, {res, compose_cn(e_pb, r_pb)});
290
291 auto aug_app = world().app(aug_callee, {real_aug_args, c1});
292 DLOG("aug_app: {} : {}", aug_app, aug_app->type());
293
294 // The result is * => no pb needed, no composition needed.
295 return aug_app;
296 }
297 assert(false && "should not be reached");
298}
299
300/// Rewrites the given definition in a lambda environment.
301const Def* Eval::augment_(const Def* def, Lam* f, Lam* f_diff) {
302 // We use macros above to avoid recomputation.
303 // TODO: Alternative:
304 // Use class instances to rewrite inside a function and save such values (f, f_diff, f->dom(2, 0)).
305
306 DLOG("Augment def {} : {}", def, def->type());
307
308 // Applications are continuations, operators, or full functions
309 if (auto app = def->isa<App>()) {
310 auto callee = app->callee();
311 auto arg = app->arg();
312 DLOG("Augment application: app {} with {}", callee, arg);
313 return augment_app(app, f, f_diff);
314 } else if (auto ext = def->isa<Extract>()) {
315 auto tuple = ext->tuple();
316 auto index = ext->index();
317 DLOG("Augment extract: {} #[{}]", tuple, index);
318 return augment_extract(ext, f, f_diff);
319 } else if (auto var = def->isa<Var>()) {
320 DLOG("Augment variable: {}", var);
321 return augment_var(var, f, f_diff);
322 } else if (auto lam = def->isa_mut<Lam>()) {
323 DLOG("Augment mut lambda: {}", lam);
324 return augment_lam(lam, f, f_diff);
325 } else if (auto lam = def->isa<Lam>()) {
326 ELOG("Augment lambda: {}", lam);
327 assert(false && "can not handle non-mutable lambdas");
328 } else if (auto lit = def->isa<Lit>()) {
329 DLOG("Augment literal: {}", def);
330 return augment_lit(lit, f, f_diff);
331 } else if (auto tup = def->isa<Tuple>()) {
332 DLOG("Augment tuple: {}", def);
333 return augment_tuple(tup, f, f_diff);
334 } else if (auto pack = def->isa<Pack>()) {
335 // TODO: handle mut packs (dependencies in the pack) (=> see paper about vectors)
336 auto arity = pack->arity(); // TODO: arity vs shape
337 auto body = pack->body();
338 DLOG("Augment pack: {} : {} with {}", arity, arity->type(), body);
339 return augment_pack(pack, f, f_diff);
340 } else if (auto ax = def->isa<Axm>()) {
341 // TODO: move concrete handling to own function / file / directory (file per plugin)
342 DLOG("Augment axm: {} : {}", ax, ax->type());
343 DLOG("axm curry: {}", ax->curry());
344 DLOG("axm flags: {}", ax->flags());
345 auto diff_name = ax->sym().str();
346 find_and_replace(diff_name, ".", "_");
347 find_and_replace(diff_name, "%", "");
348 diff_name = "internal_diff_" + diff_name;
349 DLOG("axm name: {}", ax->sym());
350 DLOG("axm function name: {}", diff_name);
351
352 auto diff_fun = world().external(world().sym(diff_name));
353 if (!diff_fun) {
354 ELOG("derivation not found: {}", diff_name);
355 auto expected_type = autodiff_type_fun(ax->type());
356 ELOG("expected: {} : {}", diff_name, expected_type);
357 assert(false && "unhandled axm");
358 }
359 // TODO: why does this cause a depth error?
360 return diff_fun;
361 }
362
363 // TODO: handle Pi for axm app
364 // TODO: remaining (lambda, axm)
365
366 ELOG("did not expect to augment: {} : {}", def, def->type());
367 ELOG("node: {}", def->node_name());
368 assert(false && "augment not implemented on this def");
369 fe::unreachable();
370}
371
372} // namespace mim::plug::autodiff
const Def * callee() const
Definition lam.h:277
const Def * arg() const
Definition lam.h:286
Definition axm.h:9
Base class for all Defs.
Definition def.h:251
Def * set(size_t i, const Def *)
Successively set from left to right.
Definition def.cpp:266
std::string_view node_name() const
Definition def.cpp:454
T * isa_mut() const
If this is mutable, it will cast constness away and perform a dynamic_cast to T.
Definition def.h:482
const Def * var(nat_t a, nat_t i) noexcept
Definition def.h:429
auto projs(F f) const
Splits this Def via Def::projections into an Array (if A == std::dynamic_extent) or std::array (other...
Definition def.h:390
const Def * type() const noexcept
Yields the "raw" type of this Def (maybe nullptr).
Definition def.h:295
Sym sym() const
Definition def.h:504
Extracts from a Sigma or Array-typed Extract::tuple the element at position Extract::index.
Definition tuple.h:206
const Def * tuple() const
Definition tuple.h:216
const Def * index() const
Definition tuple.h:217
A function.
Definition lam.h:111
const Def * filter() const
Definition lam.h:123
Lam * set(Filter filter, const Def *body)
Definition lam.cpp:29
const Pi * type() const
Definition lam.h:131
static const Lam * isa_basicblock(const Def *d)
Definition lam.h:143
const Def * body() const
Definition lam.h:124
A (possibly parameterized) Tuple.
Definition tuple.h:166
const Def * arity() const final
Definition tuple.cpp:46
Pack * set(const Def *body)
Definition tuple.h:183
size_t index() const
Definition pass.h:99
A dependent function type.
Definition lam.h:15
static const Pi * isa_cn(const Def *d)
Is this a continuation - i.e. is the Pi::codom mim::Bottom?
Definition lam.h:48
static const Pi * isa_basicblock(const Def *d)
Is this a continuation (Pi::isa_cn) that is not Pi::isa_returning?
Definition lam.h:52
const Def * dom() const
Definition lam.h:36
const Def * body() const
Definition tuple.h:91
World & world()
Definition pass.h:64
flags_t annex() const
Definition pass.h:68
Data constructor for a Sigma.
Definition tuple.h:68
const Def * insert(const Def *d, const Def *i, const Def *val)
Definition world.cpp:425
const Def * pack(const Def *arity, const Def *body)
Definition world.h:379
const Def * app(const Def *callee, const Def *arg)
Definition world.cpp:196
const Def * tuple(Defs ops)
Definition world.cpp:288
void debug_dump()
Dump in Debug build if World::log::level is Log::Level::Debug.
Definition dump.cpp:493
const Def * extract(const Def *d, const Def *i)
Definition world.cpp:346
Pack * mut_pack(const Def *type)
Definition world.h:377
Def * external(Sym name)
Lookup by name.
Definition world.h:208
const Def * call(Id id, Args &&... args)
Complete curried call of annexes obeying implicits.
Definition world.h:546
Lam * mut_con(const Def *dom)
Definition world.h:309
Lam * mut_lam(const Pi *pi)
Definition world.h:297
const Def * augment_lit(const Lit *, Lam *, Lam *)
const Def * augment_tuple(const Tuple *, Lam *, Lam *)
const Def * augment_pack(const Pack *pack, Lam *f, Lam *f_diff)
const Def * augment_app(const App *, Lam *, Lam *)
const Def * augment_lam(Lam *, Lam *, Lam *)
const Def * augment_(const Def *, Lam *, Lam *)
Rewrites the given definition in a lambda environment.
const Def * augment(const Def *, Lam *, Lam *)
Applies to (open) expressions in a functional context.
Definition eval.cpp:10
const Def * augment_extract(const Extract *, Lam *, Lam *)
const Def * augment_var(const Var *, Lam *, Lam *)
helper functions for augment
#define ELOG(...)
Definition log.h:89
#define DLOG(...)
Vaporizes to nothingness in Debug build.
Definition log.h:95
The automatic differentiation Plugin
Definition autodiff.h:6
const Def * op_sum(const Def *T, Defs)
Definition autodiff.cpp:167
const Def * autodiff_type_fun(const Def *)
Definition autodiff.cpp:111
const Def * tangent_type_fun(const Def *)
Definition autodiff.cpp:55
const Def * zero_pullback(const Def *E, const Def *A)
Definition autodiff.cpp:42
const Pi * pullback_type(const Def *E, const Def *A)
Computes the pullback type `E* -> A*`, where `E` is the type of the expression (the return type for a function) and `A` is the type of the arg...
Definition autodiff.cpp:60
const Def * op_cps2ds_dep(const Def *k)
Definition direct.h:15
The tuple Plugin
View< const Def * > Defs
Definition def.h:76
u64 nat_t
Definition types.h:43
Vector< const Def * > DefVec
Definition def.h:77
void find_and_replace(std::string &str, std::string_view what, std::string_view repl)
Replaces all occurrences of what with repl.
Definition util.h:74
const Def * compose_cn(const Def *f, const Def *g)
The high level view is:
Definition lam.cpp:62