MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
autodiff_rewrite_inner.cpp
Go to the documentation of this file.
1#include <algorithm>
2#include <string>
3
4#include "mim/tuple.h"
5
10
11using namespace std::literals;
12
13namespace mim::plug::autodiff {
14
// NOTE(review): the signature line of this definition is missing from this listing. The dispatcher's call
// `augment_lit(lit, f, f_diff)` suggests this is the body of AutoDiffEval::augment_lit — confirm against the repo.
//
// Associates the literal with a zero pullback (built from the literal's type and `f->dom(2, 0)`, presumably the
// argument type of the function being differentiated) and returns the literal itself unchanged.
16 auto pb = zero_pullback(lit->type(), f->dom(2, 0));
17 partial_pullback[lit] = pb;
18 return lit;
19}
20
// NOTE(review): the signature line of this definition is missing from this listing. The dispatcher's call
// `augment_var(var, f, f_diff)` suggests this is the body of AutoDiffEval::augment_var — confirm against the repo.
//
// Checks (assert only) that `var` already has an entry in `augmented` and that this augmented variable has an
// entry in `partial_pullback`; nothing new is constructed here.
22 assert(augmented.count(var));
23 auto aug_var = augmented[var];
24 assert(partial_pullback.count(aug_var));
// NOTE(review): the original `var` is returned, not `aug_var`; `aug_var` is only consumed by the assert above.
// Confirm that returning the un-augmented variable is intended.
25 return var;
26}
27
// NOTE(review): the signature line of this definition is missing from this listing. The dispatcher's call
// `augment_lam(lam, f, f_diff)` suggests this is the body of AutoDiffEval::augment_lam — confirm against the repo.
//
// Augments a mutable lambda `lam` met while rewriting inside `f`. Three cases, tried in this order:
//  1. `lam` already has an entry in `augmented`: that entry is returned.
//  2. `lam` is a basic block, or its name contains "ret" or "_cont": it is treated as an open continuation and
//     rebuilt as a new continuation that takes a pullback as an additional second parameter.
//  3. otherwise: `lam` is treated as a closed function and wrapped via `world.call<ad>(lam)`.
29 auto& world = lam->world();
30 // TODO: we need partial pullbacks for tuples (higher-order / ret-cont application)
31 // also for higher-order args, ret_cont (at another point)
32 // the pullback is not important but formally required by tuple rule
33 if (augmented.count(lam)) {
34 // We already know the function:
35 // * recursion
36 // * higher order arguments
37 // * new encounter of previous function
38 world.DLOG("already augmented {} : {} to {} : {}", lam, lam->type(), augmented[lam], augmented[lam]->type());
39 return augmented[lam];
40 }
41 // TODO: better fix (another pass as analysis?)
42 // TODO: handle open functions
// Name-based heuristic: besides real basic blocks, any lambda whose symbol contains "ret" or "_cont" is taken
// to be an open continuation.
43 if (Lam::isa_basicblock(lam) || lam->sym().view().find("ret") != std::string::npos
44 || lam->sym().view().find("_cont") != std::string::npos) {
45 // An open continuation behaves the same as return:
46 // ```
47 // cont: Cn[X]
48 // cont': Cn[X,Cn[X,A]]
49 // ```
50 // There is dependency on the closed function context.
51 // (All derivatives are with respect to the arguments of a closed function.)
52
53 world.DLOG("found an open continuation {} : {}", lam, lam->type());
54 auto cont_dom = lam->type()->dom(); // not only 0 but all
55 auto pb_ty = pullback_type(cont_dom, f->dom(2, 0));
56 auto aug_dom = autodiff_type_fun(cont_dom);
57 world.DLOG("augmented domain {}", aug_dom);
58 world.DLOG("pb type is {}", pb_ty);
// The new continuation's domain is the pair (augmented original domain, pullback); it is named "aug_<name>".
59 auto aug_lam = world.mut_con({aug_dom, pb_ty})->set("aug_"s + lam->sym().str());
60 auto aug_var = aug_lam->var((nat_t)0);
// Register the mappings before augmenting the body so that references to `lam` or its variable inside the
// body hit the `augmented` lookup at the top of this function.
61 augmented[lam->var()] = aug_var;
62 augmented[lam] = aug_lam; // TODO: only one of these two
63 derived[lam] = aug_lam;
// The pullback of the continuation's argument is the continuation's second parameter.
64 auto pb = aug_lam->var(1);
65 partial_pullback[aug_var] = pb;
66 // We are still in same closed function.
67 auto new_body = augment(lam->body(), f, f_diff);
68 // TODO we also need to rewrite the filter
// Until then the original (un-augmented) filter is reused together with the augmented body.
69 aug_lam->set(lam->filter(), new_body);
70
// The continuation value itself gets a zero pullback (see the TODO at the top: formally required only).
71 auto lam_pb = zero_pullback(lam->type(), f->dom(2, 0));
72 partial_pullback[aug_lam] = lam_pb;
73 world.DLOG("augmented {} : {}", lam, lam->type());
74 world.DLOG("to {} : {}", aug_lam, aug_lam->type());
75 world.DLOG("ppb for lam cont: {}", lam_pb);
76
77 return aug_lam;
78 }
79 world.DLOG("found a closed function call {} : {}", lam, lam->type());
80 // Some general function in the program needs to be differentiated.
// Note that in this case nothing is recorded in `augmented`, `derived` or `partial_pullback` here.
81 auto aug_lam = world.call<ad>(lam);
82 // TODO: directly more association here? => partly inline op_autodiff
83 world.DLOG("augmented function is {} : {}", aug_lam, aug_lam->type());
84 return aug_lam;
85}
86
// NOTE(review): the signature line of this definition is missing from this listing. The dispatcher's call
// `augment_extract(ext, f, f_diff)` suggests this is the body of AutoDiffEval::augment_extract — confirm.
//
// Augments an extract `tuple#index`: both operands are augmented, a pullback for the extracted element is
// determined, and that pullback is stored in `partial_pullback` under the newly built extract, which is returned.
88 auto& world = ext->world();
89
90 auto tuple = ext->tuple();
91 auto index = ext->index();
92
93 auto aug_tuple = augment(tuple, f, f_diff);
94 auto aug_index = augment(index, f, f_diff);
95
96 Ref pb;
97 world.DLOG("tuple was: {} : {}", tuple, tuple->type());
98 world.DLOG("aug tuple: {} : {}", aug_tuple, aug_tuple->type());
// Preferred path: if a shadow pullback (a tuple/pack of per-element pullbacks, as stored by augment_tuple and
// augment_pack) exists for the augmented tuple, the element's pullback is simply extracted from it.
99 if (shadow_pullback.count(aug_tuple)) {
100 auto shadow_tuple_pb = shadow_pullback[aug_tuple];
101 world.DLOG("Shadow pullback: {} : {}", shadow_tuple_pb, shadow_tuple_pb->type());
102 pb = world.extract(shadow_tuple_pb, aug_index);
103 } else {
// Fallback: build a new pullback lambda that embeds the element tangent into a zero tangent of the whole
// tuple and forwards it to the tuple's pullback.
104 // ```
105 // e:T, b:B
106 // b = e#i
107 // b* = \lambda (s:B). e* (insert s at i in (zero T))
108 // ```
109 assert(partial_pullback.count(aug_tuple));
110 auto tuple_pb = partial_pullback[aug_tuple];
111 auto pb_ty = pullback_type(ext->type(), f->dom(2, 0));
112 auto pb_fun = world.mut_lam(pb_ty)->set("extract_pb");
113 world.DLOG("Pullback: {} : {}", pb_fun, pb_fun->type());
114 auto pb_tangent = pb_fun->var(0_s)->set("s");
115 auto tuple_tan = world.insert(world.call<zero>(aug_tuple->type()), aug_index, pb_tangent)->set("tup_s");
// Body of the pullback: call `tuple_pb` with the embedded tangent and this lambda's second variable.
// NOTE(review): the leading `true` is presumably the lambda's filter — confirm against Lam::app.
116 pb_fun->app(true, tuple_pb, {tuple_tan, pb_fun->var(1) /* ret_var but make sure to select correct one */});
117 pb = pb_fun;
118 }
119
120 auto aug_ext = world.extract(aug_tuple, aug_index);
121 partial_pullback[aug_ext] = pb;
122
123 return aug_ext;
124}
125
// Augments a tuple: every projection is augmented, the augmented tuple is rebuilt, and two pullbacks are
// recorded for it:
//  * `shadow_pullback[aug_tup]`: the tuple of the elements' partial pullbacks,
//  * `partial_pullback[aug_tup]`: a new lambda "tup_pb" that applies each element pullback to the matching
//    component of the incoming tangent and passes the sum of the results to its second variable.
// Returns the augmented tuple.
126Ref AutoDiffEval::augment_tuple(const Tuple* tup, Lam* f, Lam* f_diff) {
127 auto& world = tup->world();
128
129 // TODO: should use ops instead?
130 auto aug_ops = tup->projs([&](Ref op) -> const Def* { return augment(op, f, f_diff); });
131 auto aug_tup = world.tuple(aug_ops);
132
// Collect the partial pullback of every augmented element.
// NOTE(review): `partial_pullback[op]` is not guarded by an assert here, unlike in augment_extract; an element
// without a recorded pullback would presumably yield a null entry — confirm.
133 auto pbs = DefVec(Defs(aug_ops), [&](Ref op) { return partial_pullback[op]; });
134 world.DLOG("tuple pbs {,}", pbs);
135 // shadow pb = tuple of pbs
136 auto shadow_pb = world.tuple(pbs);
137 shadow_pullback[aug_tup] = shadow_pb;
138
139 // ```
140 // \lambda (s:[E0,...,Em]).
141 // sum (m,A)
142 // ((cps2ds e0*) (s#0), ..., (cps2ds em*) (s#m))
143 // ```
144 auto pb_ty = pullback_type(tup->type(), f->dom(2, 0));
145 auto pb = world.mut_lam(pb_ty)->set("tup_pb");
146 world.DLOG("Augmented tuple: {} : {}", aug_tup, aug_tup->type());
147 world.DLOG("Tuple Pullback: {} : {}", pb, pb->type());
148 world.DLOG("shadow pb: {} : {}", shadow_pb, shadow_pb->type());
149
150 auto pb_tangent = pb->var(0_s)->set("tup_s");
151
// One tangent per element: the element pullback (converted with op_cps2ds_dep so it can be applied directly)
// applied to the i-th component of the incoming tangent.
152 auto tangents = DefVec(
153 pbs.size(), [&](nat_t i) { return world.app(direct::op_cps2ds_dep(pbs[i]), world.extract(pb_tangent, i)); });
// Body of the pullback: pass the sum of all element tangents to the lambda's second variable.
154 pb->app(true, pb->var(1),
155 // summed up tangents
156 op_sum(tangent_type_fun(f->dom(2, 0)), tangents));
157 partial_pullback[aug_tup] = pb;
158
159 return aug_tup;
160}
161
// Augments a pack `<shape; body>`: shape and body are augmented, the pack is rebuilt, and two pullbacks are
// recorded for it:
//  * `shadow_pullback[aug_pack]`: a pack that repeats the body's pullback over the augmented shape,
//  * `partial_pullback[aug_pack]`: a new lambda "pack_pb" that applies the body pullback to every component of
//    the incoming tangent and passes the sum of the results (via the `sum` annex) to its second variable.
// Returns the augmented pack.
162Ref AutoDiffEval::augment_pack(const Pack* pack, Lam* f, Lam* f_diff) {
163 auto& world = pack->world();
164 auto shape = pack->arity(); // TODO: arity vs shape
165 auto body = pack->body();
166
// NOTE(review): the shape is rewritten with `augment_` directly while the body goes through `augment`; the
// difference between the two entry points is not visible in this file — confirm this is intended.
167 auto aug_shape = augment_(shape, f, f_diff);
168 auto aug_body = augment(body, f, f_diff);
169
170 auto aug_pack = world.pack(aug_shape, aug_body);
171
// NOTE(review): this check uses `operator[]`, which for standard map types inserts a default entry when the
// key is absent; the other functions in this file check with `count` — confirm the container's behavior.
172 assert(partial_pullback[aug_body] && "pack pullback should exists");
173 // TODO: or use scale axiom
174 auto body_pb = partial_pullback[aug_body];
175 auto pb_pack = world.pack(aug_shape, body_pb);
176 shadow_pullback[aug_pack] = pb_pack;
177
178 world.DLOG("shadow pb of pack: {} : {}", pb_pack, pb_pack->type());
179
180 auto pb_type = pullback_type(pack->type(), f->dom(2, 0));
181 auto pb = world.mut_lam(pb_type)->set("pack_pb");
182
183 world.DLOG("pb of pack: {} : {}", pb, pb_type);
184
185 auto f_arg_ty_diff = tangent_type_fun(f->dom(2, 0));
// Mutable pack whose body may refer to its own index variable (`app_pb->var()`).
186 auto app_pb = world.mut_pack(world.arr(aug_shape, f_arg_ty_diff));
187
188 // TODO: special case for const width (special tuple)
189
190 // <i:n, cps2ds body_pb (s#i)>
191 app_pb->set(world.app(direct::op_cps2ds_dep(body_pb), world.extract(pb->var((nat_t)0), app_pb->var())));
192
193 world.DLOG("app pb of pack: {} : {}", app_pb, app_pb->type());
194
// `sum` annex instantiated with the augmented shape and the tangent type of `f->dom(2, 0)`.
195 auto sumup = world.app(world.annex<sum>(), {aug_shape, f_arg_ty_diff});
196 world.DLOG("sumup: {} : {}", sumup, sumup->type());
197
// Body of the pullback: pass the summed per-element tangents to the lambda's second variable.
198 pb->app(true, pb->var(1), world.app(sumup, app_pb));
199
200 partial_pullback[aug_pack] = pb;
201
202 return aug_pack;
203}
204
// Augments an application `callee arg`. Argument and callee are augmented first (in that order); if the original
// callee is not a continuation but its augmented version is, the latter is wrapped with op_cps2ds_dep. Then one
// of four cases applies, tried in this order:
//  1. the application's type is a Pi (nested / partial application): apply and return, no pullback is recorded;
//  2. the original callee's type is a basic block: call the augmented continuation with (aug_arg, arg pullback);
//  3. the original callee is not a continuation (direct style): the augmented call yields (result, pullback);
//     the pullback is composed with the argument's pullback and recorded for the result, which is returned;
//  4. otherwise (cps function): the return continuation is replaced by a fresh lambda "c1" that composes the
//     pullbacks before forwarding to the augmented original continuation.
205Ref AutoDiffEval::augment_app(const App* app, Lam* f, Lam* f_diff) {
206 auto& world = app->world();
207
208 auto callee = app->callee();
209 auto arg = app->arg();
210
211 auto aug_arg = augment(arg, f, f_diff);
212 auto aug_callee = augment(callee, f, f_diff);
213
214 world.DLOG("augmented argument <{}> {} : {}", aug_arg->unique_name(), aug_arg, aug_arg->type());
215 world.DLOG("augmented callee <{}> {} : {}", aug_callee->unique_name(), aug_callee, aug_callee->type());
216 // TODO: move down to if(!is_cont(callee))
// A direct-style callee whose augmented version came back in continuation style is converted back so that
// it can be applied like the original.
217 if (!Pi::isa_cn(callee->type()) && Pi::isa_cn(aug_callee->type())) {
218 aug_callee = direct::op_cps2ds_dep(aug_callee);
219 world.DLOG("wrapped augmented callee: <{}> {} : {}", aug_callee->unique_name(), aug_callee, aug_callee->type());
220 }
221
222 // nested (inner application)
223 if (app->type()->isa<Pi>()) {
224 world.DLOG("Nested application callee: {} : {}", aug_callee, aug_callee->type());
225 world.DLOG("Nested application arg: {} : {}", aug_arg, aug_arg->type());
226 auto aug_app = world.app(aug_callee, aug_arg);
227 world.DLOG("Nested application result: <{}> {} : {}", aug_app->unique_name(), aug_app, aug_app->type());
228 // We do not add a pullback as the pullback is bundled in the cps call or returned by the ds call
229 return aug_app;
230 }
231
232 // continuation (ret, if, ...)
233 if (Pi::isa_basicblock(callee->type())) {
234 // TODO: check if function (not operator)
235 // The original function is an open function (return cont / continuation) of type `Cn[E]`
236 // The augmented function `aug_callee` looks like a function but is not really one; it has the type `Cn[E,
237 // Cn[E, Cn[A]]]`
238
239 // ret(e) => ret'(e, e*)
240
241 world.DLOG("continuation {} : {} => {} : {}", callee, callee->type(), aug_callee, aug_callee->type());
242
// NOTE(review): unlike the ds case below, the looked-up pullback is not asserted to exist here.
243 auto arg_pb = partial_pullback[aug_arg];
244 auto aug_app = world.app(aug_callee, {aug_arg, arg_pb});
245 world.DLOG("Augmented application: {} : {}", aug_app, aug_app->type());
246 return aug_app;
247 }
248
249 // ds function
250 if (!Pi::isa_cn(callee->type())) {
251 auto aug_app = world.app(aug_callee, aug_arg);
252 world.DLOG("Augmented application: <{}> {} : {}", aug_app->unique_name(), aug_app, aug_app->type());
253
254 world.DLOG("ds function: {} : {}", aug_app, aug_app->type());
255 // The callee is a ds function (e.g. an operator (or its partial application))
256 auto [aug_res, fun_pb] = aug_app->projs<2>();
257 // We compose `fun_pb` with `argument_pb` to get the result pb
258 // TODO: combine case with cps function case
259 auto arg_pb = partial_pullback[aug_arg];
260 assert(arg_pb);
261 // `fun_pb: out_tan -> arg_tan`
262 // `arg_pb: arg_tan -> fun_tan`
263 world.DLOG("function pullback: {} : {}", fun_pb, fun_pb->type());
264 world.DLOG("argument pullback: {} : {}", arg_pb, arg_pb->type());
265 auto res_pb = compose_cn(arg_pb, fun_pb);
266 world.DLOG("result pullback: {} : {}", res_pb, res_pb->type());
267 partial_pullback[aug_res] = res_pb;
269 return aug_res;
270 }
271
272 // TODO: dest with a function such that f args != g args
273 {
274 // normal function app
275 // ```
276 // g: cn[E, cn X]
277 // g(args,cont)
278 // g': cn[E, cn[X, cn[X, cn E]]]
279 // g'(aug_args, ____)
280 // ```
281 auto g = callee;
282 // At this point g_deriv might still be "autodiff ... g".
283 auto g_deriv = aug_callee;
284 world.DLOG("g: {} : {}", g, g->type());
285 world.DLOG("g': {} : {}", g_deriv, g_deriv->type());
286
// Split the augmented argument into the actual arguments and the (augmented) return continuation.
287 auto [real_aug_args, aug_cont] = aug_arg->projs<2>();
288 world.DLOG("real_aug_args: {} : {}", real_aug_args, real_aug_args->type());
289 world.DLOG("aug_cont: {} : {}", aug_cont, aug_cont->type());
// NOTE(review): `e_pb` is dereferenced in the log statement below without a prior existence check.
290 auto e_pb = partial_pullback[real_aug_args];
291 world.DLOG("e_pb (arg_pb): {} : {}", e_pb, e_pb->type());
292
293 // TODO: better debug names
// The type of the new return continuation `c1` is the second domain component of the derivative's type.
294 auto ret_g_deriv_ty = g_deriv->type()->as<Pi>()->dom(1);
295 world.DLOG("ret_g_deriv_ty: {} ", ret_g_deriv_ty);
296 auto c1_ty = ret_g_deriv_ty->as<Pi>();
297 world.DLOG("c1_ty: (cn[X, cn[X+, cn E+]]) {}", c1_ty);
298 auto c1 = world.mut_lam(c1_ty)->set("c1");
299 auto res = c1->var((nat_t)0);
300 auto r_pb = c1->var(1);
// `c1` receives the result and the result pullback of g'; it forwards the result together with the
// composition of the argument pullback and the result pullback to the augmented original continuation.
301 c1->app(true, aug_cont, {res, compose_cn(e_pb, r_pb)});
302
303 auto aug_app = world.app(aug_callee, {real_aug_args, c1});
304 world.DLOG("aug_app: {} : {}", aug_app, aug_app->type());
305
306 // The result is * => no pb needed, no composition needed.
307 return aug_app;
308 }
// The block above returns unconditionally, so this statement is never executed.
309 assert(false && "should not be reached");
310}
311
312/// Rewrites the given definition in a lambda environment.
// NOTE(review): the signature line of this definition is missing from this listing. The body uses `def`, `f` and
// `f_diff`, and the page index lists `Ref augment_(Ref, Lam *, Lam *)` with the same brief — presumably this is
// AutoDiffEval::augment_; confirm against the repo.
//
// Dispatches on the node kind of `def` (App, Extract, Var, mutable Lam, Lit, Tuple, Pack, Axiom) and forwards
// to the matching augment_* helper. Any other node kind is logged as an error and hits an assert.
314 auto& world = def->world();
315 // We use macros above to avoid recomputation.
316 // TODO: Alternative:
317 // Use class instances to rewrite inside a function and save such values (f, f_diff, f->dom(2, 0)).
318
319 world.DLOG("Augment def {} : {}", def, def->type());
320
321 // Applications are continuations, operators, or full functions
322 if (auto app = def->isa<App>()) {
// `callee` and `arg` are only needed for the log message.
323 auto callee = app->callee();
324 auto arg = app->arg();
325 world.DLOG("Augment application: app {} with {}", callee, arg);
326 return augment_app(app, f, f_diff);
327 } else if (auto ext = def->isa<Extract>()) {
// `tuple` and `index` are only needed for the log message.
328 auto tuple = ext->tuple();
329 auto index = ext->index();
330 world.DLOG("Augment extract: {} #[{}]", tuple, index);
331 return augment_extract(ext, f, f_diff);
332 } else if (auto var = def->isa<Var>()) {
333 world.DLOG("Augment variable: {}", var);
334 return augment_var(var, f, f_diff);
335 } else if (auto lam = def->isa_mut<Lam>()) {
336 world.DLOG("Augment mut lambda: {}", lam);
337 return augment_lam(lam, f, f_diff);
338 } else if (auto lam = def->isa<Lam>()) {
// Immutable lambdas are not supported. This branch has no return: when assertions are compiled out,
// control leaves the if-chain and reaches the generic error path at the end of the function.
339 world.ELOG("Augment lambda: {}", lam);
340 assert(false && "can not handle non-mutable lambdas");
341 } else if (auto lit = def->isa<Lit>()) {
342 world.DLOG("Augment literal: {}", def);
343 return augment_lit(lit, f, f_diff);
344 } else if (auto tup = def->isa<Tuple>()) {
345 world.DLOG("Augment tuple: {}", def);
346 return augment_tuple(tup, f, f_diff);
347 } else if (auto pack = def->isa<Pack>()) {
348 // TODO: handle mut packs (dependencies in the pack) (=> see paper about vectors)
// `shape` and `body` are only needed for the log message.
349 auto shape = pack->arity(); // TODO: arity vs shape
350 auto body = pack->body();
351 world.DLOG("Augment pack: {} : {} with {}", shape, shape->type(), body);
352 return augment_pack(pack, f, f_diff);
353 } else if (auto ax = def->isa<Axiom>()) {
354 // TODO: move concrete handling to own function / file / directory (file per plugin)
355 world.DLOG("Augment axiom: {} : {}", ax, ax->type());
356 world.DLOG("axiom curry: {}", ax->curry());
357 world.DLOG("axiom flags: {}", ax->flags());
// The derivative of an axiom is looked up by name among the world's externals: take the axiom's symbol,
// replace every "." with "_", drop every "%", and prefix the result with "internal_diff_".
358 auto diff_name = ax->sym().str();
359 find_and_replace(diff_name, ".", "_");
360 find_and_replace(diff_name, "%", "");
361 diff_name = "internal_diff_" + diff_name;
362 world.DLOG("axiom name: {}", ax->sym());
363 world.DLOG("axiom function name: {}", diff_name);
364
365 auto diff_fun = world.external(world.sym(diff_name));
366 if (!diff_fun) {
367 world.ELOG("derivation not found: {}", diff_name);
368 auto expected_type = autodiff_type_fun(ax->type());
369 world.ELOG("expected: {} : {}", diff_name, expected_type);
370 assert(false && "unhandled axiom");
371 }
372 // TODO: why does this cause a depth error?
// When assertions are compiled out and the lookup failed, the (null) lookup result is returned as is.
373 return diff_fun;
374 }
375
376 // TODO: handle Pi for axiom app
377 // TODO: remaining (lambda, axiom)
378
// Generic error path for every node kind not handled above.
379 world.ELOG("did not expect to augment: {} : {}", def, def->type());
380 world.ELOG("node: {}", def->node_name());
381 assert(false && "augment not implemented on this def");
382 fe::unreachable();
383}
384
385} // namespace mim::plug::autodiff
const Def * callee() const
Definition lam.h:218
const Def * arg() const
Definition lam.h:227
Base class for all Defs.
Definition def.h:223
Def * set(size_t i, const Def *def)
Successively set from left to right.
Definition def.cpp:246
std::string_view node_name() const
Definition def.cpp:431
World & world() const
Definition def.cpp:415
T * isa_mut() const
If this is *mut*able, it will cast constness away and perform a dynamic_cast to T.
Definition def.h:447
Ref var(nat_t a, nat_t i)
Definition def.h:401
auto projs(F f) const
Splits this Def via Def::projections into an Array (if A == -1_n) or std::array (otherwise).
Definition def.h:367
const Def * type() const
Definition def.h:248
Ref arity() const
Definition def.cpp:481
Sym sym() const
Definition def.h:468
Extracts from a Sigma or Array-typed Extract::tuple the element at position Extract::index.
Definition tuple.h:152
const Def * tuple() const
Definition tuple.h:162
const Def * index() const
Definition tuple.h:163
A function.
Definition lam.h:103
Ref filter() const
Definition lam.h:113
Lam * set(Filter filter, const Def *body)
Definition lam.h:166
const Pi * type() const
Definition lam.h:115
Ref body() const
Definition lam.h:114
static const Lam * isa_basicblock(Ref d)
Definition lam.h:139
A (possibly parameterized) Tuple.
Definition tuple.h:112
const Def * body() const
Definition tuple.h:122
Pack * set(const Def *body)
Definition tuple.h:130
World & world()
Definition pass.h:296
size_t index() const
Definition pass.h:33
A dependent function type.
Definition lam.h:11
static const Pi * isa_cn(Ref d)
Is this a continuation - i.e. is the Pi::codom mim::Bottom?
Definition lam.h:50
static const Pi * isa_basicblock(Ref d)
Is this a continuation (Pi::isa_cn) that is not Pi::isa_returning?
Definition lam.h:54
Ref dom() const
Definition lam.h:32
Helper class to retrieve Infer::arg if present.
Definition def.h:86
Data constructor for a Sigma.
Definition tuple.h:49
Ref tuple(Defs ops)
Definition world.cpp:238
Lam * mut_con(Ref dom)
Definition world.h:296
const Def * annex(Id id)
Lookup annex by Axiom::id.
Definition world.h:185
Ref insert(Ref d, Ref i, Ref val)
Definition world.cpp:352
Ref var(Ref type, Def *mut)
Definition world.cpp:156
Ref arr(Ref shape, Ref body)
Definition world.cpp:394
void debug_dump()
Dump in Debug build if World::log::level is Log::Level::Debug.
Definition dump.cpp:489
Ref app(Ref callee, Ref arg)
Definition world.cpp:186
Sym sym(std::string_view)
Definition world.cpp:77
Def * external(Sym name)
Lookup by name.
Definition world.h:182
const Def * call(Id id, Args &&... args)
Definition world.h:518
Ref pack(Ref arity, Ref body)
Definition world.cpp:421
Ref extract(Ref d, Ref i)
Definition world.cpp:290
const Type * type(Ref level)
Definition world.cpp:94
Pack * mut_pack(Ref type)
Definition world.h:353
Lam * mut_lam(const Pi *pi)
Definition world.h:284
Ref augment_var(const Var *, Lam *, Lam *)
helper functions for augment
Ref augment_lit(const Lit *, Lam *, Lam *)
Ref augment_tuple(const Tuple *, Lam *, Lam *)
Ref augment_(Ref, Lam *, Lam *)
Rewrites the given definition in a lambda environment.
Ref augment(Ref, Lam *, Lam *)
Applies to (open) expressions in a functional context.
Ref augment_app(const App *, Lam *, Lam *)
Ref augment_pack(const Pack *pack, Lam *f, Lam *f_diff)
Ref augment_extract(const Extract *, Lam *, Lam *)
The automatic differentiation Plugin
Definition autodiff.h:6
const Def * op_sum(const Def *T, Defs)
Definition autodiff.cpp:169
const Def * autodiff_type_fun(const Def *)
Definition autodiff.cpp:113
const Def * tangent_type_fun(const Def *)
Definition autodiff.cpp:57
const Def * zero_pullback(const Def *E, const Def *A)
Definition autodiff.cpp:44
const Pi * pullback_type(const Def *E, const Def *A)
computes pb type E* -> A* E - type of the expression (return type for a function) A - type of the arg...
Definition autodiff.cpp:62
const Def * op_cps2ds_dep(const Def *f)
Definition direct.h:11
View< const Def * > Defs
Definition def.h:61
u64 nat_t
Definition types.h:43
Vector< const Def * > DefVec
Definition def.h:62
void find_and_replace(std::string &str, std::string_view what, std::string_view repl)
Replaces all occurrences of what with repl.
Definition util.h:70
const Def * compose_cn(const Def *f, const Def *g)
The high level view is:
Definition lam.cpp:53