MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
autodiff_rewrite_inner.cpp
Go to the documentation of this file.
1#include <algorithm>
2#include <string>
3
4#include "mim/tuple.h"
5
10
11using namespace std::literals;
12
13namespace mim::plug::autodiff {
14
16 auto pb = zero_pullback(lit->type(), f->dom(2, 0));
17 partial_pullback[lit] = pb;
18 return lit;
19}
20
22 assert(augmented.count(var));
23 auto aug_var = augmented[var];
24 assert(partial_pullback.count(aug_var));
25 return var;
26}
27
29 // TODO: we need partial pullbacks for tuples (higher-order / ret-cont application)
30 // also for higher-order args, ret_cont (at another point)
31 // the pullback is not important but formally required by tuple rule
32 if (augmented.count(lam)) {
33 // We already know the function:
34 // * recursion
35 // * higher order arguments
36 // * new encounter of previous function
37 world().DLOG("already augmented {} : {} to {} : {}", lam, lam->type(), augmented[lam], augmented[lam]->type());
38 return augmented[lam];
39 }
40 // TODO: better fix (another pass as analysis?)
41 // TODO: handle open functions
42 if (Lam::isa_basicblock(lam) || lam->sym().view().find("ret") != std::string::npos
43 || lam->sym().view().find("_cont") != std::string::npos) {
44 // An open continuation behaves the same as return:
45 // ```
46 // cont: Cn[X]
47 // cont': Cn[X,Cn[X,A]]
48 // ```
49 // There is dependency on the closed function context.
50 // (All derivatives are with respect to the arguments of a closed function.)
51
52 world().DLOG("found an open continuation {} : {}", lam, lam->type());
53 auto cont_dom = lam->type()->dom(); // not only 0 but all
54 auto pb_ty = pullback_type(cont_dom, f->dom(2, 0));
55 auto aug_dom = autodiff_type_fun(cont_dom);
56 world().DLOG("augmented domain {}", aug_dom);
57 world().DLOG("pb type is {}", pb_ty);
58 auto aug_lam = world().mut_con({aug_dom, pb_ty})->set("aug_"s + lam->sym().str());
59 auto aug_var = aug_lam->var((nat_t)0);
60 augmented[lam->var()] = aug_var;
61 augmented[lam] = aug_lam; // TODO: only one of these two
62 derived[lam] = aug_lam;
63 auto pb = aug_lam->var(1);
64 partial_pullback[aug_var] = pb;
65 // We are still in same closed function.
66 auto new_body = augment(lam->body(), f, f_diff);
67 // TODO we also need to rewrite the filter
68 aug_lam->set(lam->filter(), new_body);
69
70 auto lam_pb = zero_pullback(lam->type(), f->dom(2, 0));
71 partial_pullback[aug_lam] = lam_pb;
72 world().DLOG("augmented {} : {}", lam, lam->type());
73 world().DLOG("to {} : {}", aug_lam, aug_lam->type());
74 world().DLOG("ppb for lam cont: {}", lam_pb);
75
76 return aug_lam;
77 }
78 world().DLOG("found a closed function call {} : {}", lam, lam->type());
79 // Some general function in the program needs to be differentiated.
80 auto aug_lam = world().call<ad>(lam);
81 // TODO: directly more association here? => partly inline op_autodiff
82 world().DLOG("augmented function is {} : {}", aug_lam, aug_lam->type());
83 return aug_lam;
84}
85
87 auto tuple = ext->tuple();
88 auto index = ext->index();
89
90 auto aug_tuple = augment(tuple, f, f_diff);
91 auto aug_index = augment(index, f, f_diff);
92
93 Ref pb;
94 world().DLOG("tuple was: {} : {}", tuple, tuple->type());
95 world().DLOG("aug tuple: {} : {}", aug_tuple, aug_tuple->type());
96 if (shadow_pullback.count(aug_tuple)) {
97 auto shadow_tuple_pb = shadow_pullback[aug_tuple];
98 world().DLOG("Shadow pullback: {} : {}", shadow_tuple_pb, shadow_tuple_pb->type());
99 pb = world().extract(shadow_tuple_pb, aug_index);
100 } else {
101 // ```
102 // e:T, b:B
103 // b = e#i
104 // b* = \lambda (s:B). e* (insert s at i in (zero T))
105 // ```
106 assert(partial_pullback.count(aug_tuple));
107 auto tuple_pb = partial_pullback[aug_tuple];
108 auto pb_ty = pullback_type(ext->type(), f->dom(2, 0));
109 auto pb_fun = world().mut_lam(pb_ty)->set("extract_pb");
110 world().DLOG("Pullback: {} : {}", pb_fun, pb_fun->type());
111 auto pb_tangent = pb_fun->var(0_s)->set("s");
112 auto tuple_tan = world().insert(world().call<zero>(aug_tuple->type()), aug_index, pb_tangent)->set("tup_s");
113 pb_fun->app(true, tuple_pb, {tuple_tan, pb_fun->var(1) /* ret_var but make sure to select correct one */});
114 pb = pb_fun;
115 }
116
117 auto aug_ext = world().extract(aug_tuple, aug_index);
118 partial_pullback[aug_ext] = pb;
119
120 return aug_ext;
121}
122
123Ref AutoDiffEval::augment_tuple(const Tuple* tup, Lam* f, Lam* f_diff) {
124 // TODO: should use ops instead?
125 auto aug_ops = tup->projs([&](Ref op) -> const Def* { return augment(op, f, f_diff); });
126 auto aug_tup = world().tuple(aug_ops);
127
128 auto pbs = DefVec(Defs(aug_ops), [&](Ref op) { return partial_pullback[op]; });
129 world().DLOG("tuple pbs {,}", pbs);
130 // shadow pb = tuple of pbs
131 auto shadow_pb = world().tuple(pbs);
132 shadow_pullback[aug_tup] = shadow_pb;
133
134 // ```
135 // \lambda (s:[E0,...,Em]).
136 // sum (m,A)
137 // ((cps2ds e0*) (s#0), ..., (cps2ds em*) (s#m))
138 // ```
139 auto pb_ty = pullback_type(tup->type(), f->dom(2, 0));
140 auto pb = world().mut_lam(pb_ty)->set("tup_pb");
141 world().DLOG("Augmented tuple: {} : {}", aug_tup, aug_tup->type());
142 world().DLOG("Tuple Pullback: {} : {}", pb, pb->type());
143 world().DLOG("shadow pb: {} : {}", shadow_pb, shadow_pb->type());
144
145 auto pb_tangent = pb->var(0_s)->set("tup_s");
146
147 auto tangents = DefVec(pbs.size(), [&](nat_t i) {
148 return world().app(direct::op_cps2ds_dep(pbs[i]), world().extract(pb_tangent, i));
149 });
150 pb->app(true, pb->var(1),
151 // summed up tangents
152 op_sum(tangent_type_fun(f->dom(2, 0)), tangents));
153 partial_pullback[aug_tup] = pb;
154
155 return aug_tup;
156}
157
158Ref AutoDiffEval::augment_pack(const Pack* pack, Lam* f, Lam* f_diff) {
159 auto shape = pack->arity(); // TODO: arity vs shape
160 auto body = pack->body();
161
162 auto aug_shape = augment_(shape, f, f_diff);
163 auto aug_body = augment(body, f, f_diff);
164
165 auto aug_pack = world().pack(aug_shape, aug_body);
166
167 assert(partial_pullback[aug_body] && "pack pullback should exists");
168 // TODO: or use scale axiom
169 auto body_pb = partial_pullback[aug_body];
170 auto pb_pack = world().pack(aug_shape, body_pb);
171 shadow_pullback[aug_pack] = pb_pack;
172
173 world().DLOG("shadow pb of pack: {} : {}", pb_pack, pb_pack->type());
174
175 auto pb_type = pullback_type(pack->type(), f->dom(2, 0));
176 auto pb = world().mut_lam(pb_type)->set("pack_pb");
177
178 world().DLOG("pb of pack: {} : {}", pb, pb_type);
179
180 auto f_arg_ty_diff = tangent_type_fun(f->dom(2, 0));
181 auto app_pb = world().mut_pack(world().arr(aug_shape, f_arg_ty_diff));
182
183 // TODO: special case for const width (special tuple)
184
185 // <i:n, cps2ds body_pb (s#i)>
186 app_pb->set(world().app(direct::op_cps2ds_dep(body_pb), world().extract(pb->var((nat_t)0), app_pb->var())));
187
188 world().DLOG("app pb of pack: {} : {}", app_pb, app_pb->type());
189
190 auto sumup = world().app(world().annex<sum>(), {aug_shape, f_arg_ty_diff});
191 world().DLOG("sumup: {} : {}", sumup, sumup->type());
192
193 pb->app(true, pb->var(1), world().app(sumup, app_pb));
194
195 partial_pullback[aug_pack] = pb;
196
197 return aug_pack;
198}
199
200Ref AutoDiffEval::augment_app(const App* app, Lam* f, Lam* f_diff) {
201 auto callee = app->callee();
202 auto arg = app->arg();
203
204 auto aug_arg = augment(arg, f, f_diff);
205 auto aug_callee = augment(callee, f, f_diff);
206
207 world().DLOG("augmented argument <{}> {} : {}", aug_arg->unique_name(), aug_arg, aug_arg->type());
208 world().DLOG("augmented callee <{}> {} : {}", aug_callee->unique_name(), aug_callee, aug_callee->type());
209 // TODO: move down to if(!is_cont(callee))
210 if (!Pi::isa_cn(callee->type()) && Pi::isa_cn(aug_callee->type())) {
211 aug_callee = direct::op_cps2ds_dep(aug_callee);
212 world().DLOG("wrapped augmented callee: <{}> {} : {}", aug_callee->unique_name(), aug_callee,
213 aug_callee->type());
214 }
215
216 // nested (inner application)
217 if (app->type()->isa<Pi>()) {
218 world().DLOG("Nested application callee: {} : {}", aug_callee, aug_callee->type());
219 world().DLOG("Nested application arg: {} : {}", aug_arg, aug_arg->type());
220 auto aug_app = world().app(aug_callee, aug_arg);
221 world().DLOG("Nested application result: <{}> {} : {}", aug_app->unique_name(), aug_app, aug_app->type());
222 // We do not add a pullback as the pullback is bundled in the cps call or returned by the ds call
223 return aug_app;
224 }
225
226 // continuation (ret, if, ...)
227 if (Pi::isa_basicblock(callee->type())) {
228 // TODO: check if function (not operator)
229 // The original function is an open function (return cont / continuation) of type `Cn[E]`
230 // The augmented function `aug_callee` looks like a function but is not really a function has the type `Cn[E,
231 // Cn[E, Cn[A]]]`
232
233 // ret(e) => ret'(e, e*)
234
235 world().DLOG("continuation {} : {} => {} : {}", callee, callee->type(), aug_callee, aug_callee->type());
236
237 auto arg_pb = partial_pullback[aug_arg];
238 auto aug_app = world().app(aug_callee, {aug_arg, arg_pb});
239 world().DLOG("Augmented application: {} : {}", aug_app, aug_app->type());
240 return aug_app;
241 }
242
243 // ds function
244 if (!Pi::isa_cn(callee->type())) {
245 auto aug_app = world().app(aug_callee, aug_arg);
246 world().DLOG("Augmented application: <{}> {} : {}", aug_app->unique_name(), aug_app, aug_app->type());
247
248 world().DLOG("ds function: {} : {}", aug_app, aug_app->type());
249 // The calle is ds function (e.g. operator (or its partial application))
250 auto [aug_res, fun_pb] = aug_app->projs<2>();
251 // We compose `fun_pb` with `argument_pb` to get the result pb
252 // TODO: combine case with cps function case
253 auto arg_pb = partial_pullback[aug_arg];
254 assert(arg_pb);
255 // `fun_pb: out_tan -> arg_tan`
256 // `arg_pb: arg_tan -> fun_tan`
257 world().DLOG("function pullback: {} : {}", fun_pb, fun_pb->type());
258 world().DLOG("argument pullback: {} : {}", arg_pb, arg_pb->type());
259 auto res_pb = compose_cn(arg_pb, fun_pb);
260 world().DLOG("result pullback: {} : {}", res_pb, res_pb->type());
261 partial_pullback[aug_res] = res_pb;
262 world().debug_dump();
263 return aug_res;
264 }
265
266 // TODO: dest with a function such that f args != g args
267 {
268 // normal function app
269 // ```
270 // g: cn[E, cn X]
271 // g(args,cont)
272 // g': cn[E, cn[X, cn[X, cn E]]]
273 // g'(aug_args, ____)
274 // ```
275 auto g = callee;
276 // At this point g_deriv might still be "autodiff ... g".
277 auto g_deriv = aug_callee;
278 world().DLOG("g: {} : {}", g, g->type());
279 world().DLOG("g': {} : {}", g_deriv, g_deriv->type());
280
281 auto [real_aug_args, aug_cont] = aug_arg->projs<2>();
282 world().DLOG("real_aug_args: {} : {}", real_aug_args, real_aug_args->type());
283 world().DLOG("aug_cont: {} : {}", aug_cont, aug_cont->type());
284 auto e_pb = partial_pullback[real_aug_args];
285 world().DLOG("e_pb (arg_pb): {} : {}", e_pb, e_pb->type());
286
287 // TODO: better debug names
288 auto ret_g_deriv_ty = g_deriv->type()->as<Pi>()->dom(1);
289 world().DLOG("ret_g_deriv_ty: {} ", ret_g_deriv_ty);
290 auto c1_ty = ret_g_deriv_ty->as<Pi>();
291 world().DLOG("c1_ty: (cn[X, cn[X+, cn E+]]) {}", c1_ty);
292 auto c1 = world().mut_lam(c1_ty)->set("c1");
293 auto res = c1->var((nat_t)0);
294 auto r_pb = c1->var(1);
295 c1->app(true, aug_cont, {res, compose_cn(e_pb, r_pb)});
296
297 auto aug_app = world().app(aug_callee, {real_aug_args, c1});
298 world().DLOG("aug_app: {} : {}", aug_app, aug_app->type());
299
300 // The result is * => no pb needed, no composition needed.
301 return aug_app;
302 }
303 assert(false && "should not be reached");
304}
305
306/// Rewrites the given definition in a lambda environment.
308 // We use macros above to avoid recomputation.
309 // TODO: Alternative:
310 // Use class instances to rewrite inside a function and save such values (f, f_diff, f->dom(2, 0)).
311
312 world().DLOG("Augment def {} : {}", def, def->type());
313
314 // Applications are continuations, operators, or full functions
315 if (auto app = def->isa<App>()) {
316 auto callee = app->callee();
317 auto arg = app->arg();
318 world().DLOG("Augment application: app {} with {}", callee, arg);
319 return augment_app(app, f, f_diff);
320 } else if (auto ext = def->isa<Extract>()) {
321 auto tuple = ext->tuple();
322 auto index = ext->index();
323 world().DLOG("Augment extract: {} #[{}]", tuple, index);
324 return augment_extract(ext, f, f_diff);
325 } else if (auto var = def->isa<Var>()) {
326 world().DLOG("Augment variable: {}", var);
327 return augment_var(var, f, f_diff);
328 } else if (auto lam = def->isa_mut<Lam>()) {
329 world().DLOG("Augment mut lambda: {}", lam);
330 return augment_lam(lam, f, f_diff);
331 } else if (auto lam = def->isa<Lam>()) {
332 world().ELOG("Augment lambda: {}", lam);
333 assert(false && "can not handle non-mutable lambdas");
334 } else if (auto lit = def->isa<Lit>()) {
335 world().DLOG("Augment literal: {}", def);
336 return augment_lit(lit, f, f_diff);
337 } else if (auto tup = def->isa<Tuple>()) {
338 world().DLOG("Augment tuple: {}", def);
339 return augment_tuple(tup, f, f_diff);
340 } else if (auto pack = def->isa<Pack>()) {
341 // TODO: handle mut packs (dependencies in the pack) (=> see paper about vectors)
342 auto shape = pack->arity(); // TODO: arity vs shape
343 auto body = pack->body();
344 world().DLOG("Augment pack: {} : {} with {}", shape, shape->type(), body);
345 return augment_pack(pack, f, f_diff);
346 } else if (auto ax = def->isa<Axiom>()) {
347 // TODO: move concrete handling to own function / file / directory (file per plugin)
348 world().DLOG("Augment axiom: {} : {}", ax, ax->type());
349 world().DLOG("axiom curry: {}", ax->curry());
350 world().DLOG("axiom flags: {}", ax->flags());
351 auto diff_name = ax->sym().str();
352 find_and_replace(diff_name, ".", "_");
353 find_and_replace(diff_name, "%", "");
354 diff_name = "internal_diff_" + diff_name;
355 world().DLOG("axiom name: {}", ax->sym());
356 world().DLOG("axiom function name: {}", diff_name);
357
358 auto diff_fun = world().external(world().sym(diff_name));
359 if (!diff_fun) {
360 world().ELOG("derivation not found: {}", diff_name);
361 auto expected_type = autodiff_type_fun(ax->type());
362 world().ELOG("expected: {} : {}", diff_name, expected_type);
363 assert(false && "unhandled axiom");
364 }
365 // TODO: why does this cause a depth error?
366 return diff_fun;
367 }
368
369 // TODO: handle Pi for axiom app
370 // TODO: remaining (lambda, axiom)
371
372 world().ELOG("did not expect to augment: {} : {}", def, def->type());
373 world().ELOG("node: {}", def->node_name());
374 assert(false && "augment not implemented on this def");
375 fe::unreachable();
376}
377
378} // namespace mim::plug::autodiff
Ref arg() const
Definition lam.h:222
Ref callee() const
Definition lam.h:213
Base class for all Defs.
Definition def.h:226
Ref type() const
Definition def.h:251
std::string_view node_name() const
Definition def.cpp:427
Def * set(size_t i, Ref)
Successively set from left to right.
Definition def.cpp:256
T * isa_mut() const
If this is *mut*able, it will cast constness away and perform a dynamic_cast to T.
Definition def.h:445
Ref var(nat_t a, nat_t i)
Definition def.h:395
auto projs(F f) const
Splits this Def via Def::projections into an Array (if A == std::dynamic_extent) or std::array (other...
Definition def.h:361
Ref arity() const
Definition def.cpp:477
Sym sym() const
Definition def.h:466
Extracts from a Sigma or Array-typed Extract::tuple the element at position Extract::index.
Definition tuple.h:154
Ref index() const
Definition tuple.h:165
Ref tuple() const
Definition tuple.h:164
A function.
Definition lam.h:104
Ref filter() const
Definition lam.h:114
Lam * set(Filter filter, Ref body)
Definition lam.h:161
const Pi * type() const
Definition lam.h:116
Ref body() const
Definition lam.h:115
static const Lam * isa_basicblock(Ref d)
Definition lam.h:134
A (possibly parameterized) Tuple.
Definition tuple.h:114
Ref body() const
Definition tuple.h:124
Pack * set(Ref body)
Definition tuple.h:132
World & world()
Definition pass.h:296
size_t index() const
Definition pass.h:33
A dependent function type.
Definition lam.h:11
static const Pi * isa_cn(Ref d)
Is this a continuation - i.e. is the Pi::codom mim::Bottom?
Definition lam.h:44
static const Pi * isa_basicblock(Ref d)
Is this a continuation (Pi::isa_cn) that is not Pi::isa_returning?
Definition lam.h:48
Ref dom() const
Definition lam.h:32
Helper class to retrieve Infer::arg if present.
Definition def.h:86
Data constructor for a Sigma.
Definition tuple.h:50
Ref tuple(Defs ops)
Definition world.cpp:233
Lam * mut_con(Ref dom)
Definition world.h:296
Ref insert(Ref d, Ref i, Ref val)
Definition world.cpp:348
Ref var(Ref type, Def *mut)
Definition world.cpp:156
void debug_dump()
Dump in Debug build if World::log::level is Log::Level::Debug.
Definition dump.cpp:491
Sym sym(std::string_view)
Definition world.cpp:77
Def * external(Sym name)
Lookup by name.
Definition world.h:182
Ref pack(Ref arity, Ref body)
Definition world.cpp:419
const Def * call(Id id, Args &&... args)
Complete curried call of annexes obeying implicits.
Definition world.h:507
Ref extract(Ref d, Ref i)
Definition world.cpp:286
Ref app(Ref callee, Ref arg)
Definition world.cpp:170
const Type * type(Ref level)
Definition world.cpp:94
Pack * mut_pack(Ref type)
Definition world.h:352
Lam * mut_lam(const Pi *pi)
Definition world.h:284
Ref augment_var(const Var *, Lam *, Lam *)
helper functions for augment
Ref augment_lit(const Lit *, Lam *, Lam *)
Ref augment_tuple(const Tuple *, Lam *, Lam *)
Ref augment_(Ref, Lam *, Lam *)
Rewrites the given definition in a lambda environment.
Ref augment(Ref, Lam *, Lam *)
Applies to (open) expressions in a functional context.
Ref augment_app(const App *, Lam *, Lam *)
Ref augment_pack(const Pack *pack, Lam *f, Lam *f_diff)
Ref augment_extract(const Extract *, Lam *, Lam *)
The automatic differentiation Plugin
Definition autodiff.h:6
const Def * op_sum(const Def *T, Defs)
Definition autodiff.cpp:169
const Def * autodiff_type_fun(const Def *)
Definition autodiff.cpp:113
const Def * tangent_type_fun(const Def *)
Definition autodiff.cpp:57
const Def * zero_pullback(const Def *E, const Def *A)
Definition autodiff.cpp:44
const Pi * pullback_type(const Def *E, const Def *A)
computes pb type E* -> A* E - type of the expression (return type for a function) A - type of the arg...
Definition autodiff.cpp:62
Ref op_cps2ds_dep(Ref k)
Definition direct.h:15
View< const Def * > Defs
Definition def.h:61
u64 nat_t
Definition types.h:43
Vector< const Def * > DefVec
Definition def.h:62
void find_and_replace(std::string &str, std::string_view what, std::string_view repl)
Replaces all occurrences of what with repl.
Definition util.h:70
Ref compose_cn(Ref f, Ref g)
The high level view is:
Definition lam.cpp:53