Thorin 1.9.0
The Higher ORder INtermediate representation
Loading...
Searching...
No Matches
autodiff_rewrite_inner.cpp
Go to the documentation of this file.
1#include <algorithm>
2#include <string>
3
4#include "thorin/tuple.h"
5
10
11using namespace std::literals;
12
13namespace thorin::plug::autodiff {
14
16 auto pb = zero_pullback(lit->type(), f->dom(2, 0));
17 partial_pullback[lit] = pb;
18 return lit;
19}
20
22 assert(augmented.count(var));
23 auto aug_var = augmented[var];
24 assert(partial_pullback.count(aug_var));
25 return var;
26}
27
29 auto& world = lam->world();
30 // TODO: we need partial pullbacks for tuples (higher-order / ret-cont application)
31 // also for higher-order args, ret_cont (at another point)
32 // the pullback is not important but formally required by tuple rule
33 if (augmented.count(lam)) {
34 // We already know the function:
35 // * recursion
36 // * higher order arguments
37 // * new encounter of previous function
38 world.DLOG("already augmented {} : {} to {} : {}", lam, lam->type(), augmented[lam], augmented[lam]->type());
39 return augmented[lam];
40 }
41 // TODO: better fix (another pass as analysis?)
42 // TODO: handle open functions
43 if (Lam::isa_basicblock(lam) || lam->sym().view().find("ret") != std::string::npos
44 || lam->sym().view().find("_cont") != std::string::npos) {
45 // A open continuation behaves the same as return:
46 // ```
47 // cont: Cn[X]
48 // cont': Cn[X,Cn[X,A]]
49 // ```
50 // There is dependency on the closed function context.
51 // (All derivatives are with respect to the arguments of a closed function.)
52
53 world.DLOG("found an open continuation {} : {}", lam, lam->type());
54 auto cont_dom = lam->type()->dom(); // not only 0 but all
55 auto pb_ty = pullback_type(cont_dom, f->dom(2, 0));
56 auto aug_dom = autodiff_type_fun(cont_dom);
57 world.DLOG("augmented domain {}", aug_dom);
58 world.DLOG("pb type is {}", pb_ty);
59 auto aug_lam = world.mut_con({aug_dom, pb_ty})->set("aug_"s + lam->sym().str());
60 auto aug_var = aug_lam->var((nat_t)0);
61 augmented[lam->var()] = aug_var;
62 augmented[lam] = aug_lam; // TODO: only one of these two
63 derived[lam] = aug_lam;
64 auto pb = aug_lam->var(1);
65 partial_pullback[aug_var] = pb;
66 // We are still in same closed function.
67 auto new_body = augment(lam->body(), f, f_diff);
68 // TODO we also need to rewrite the filter
69 aug_lam->set(lam->filter(), new_body);
70
71 auto lam_pb = zero_pullback(lam->type(), f->dom(2, 0));
72 partial_pullback[aug_lam] = lam_pb;
73 world.DLOG("augmented {} : {}", lam, lam->type());
74 world.DLOG("to {} : {}", aug_lam, aug_lam->type());
75 world.DLOG("ppb for lam cont: {}", lam_pb);
76
77 return aug_lam;
78 }
79 world.DLOG("found a closed function call {} : {}", lam, lam->type());
80 // Some general function in the program needs to be differentiated.
81 auto aug_lam = world.call<ad>(lam);
82 // TODO: directly more association here? => partly inline op_autodiff
83 world.DLOG("augmented function is {} : {}", aug_lam, aug_lam->type());
84 return aug_lam;
85}
86
88 auto& world = ext->world();
89
90 auto tuple = ext->tuple();
91 auto index = ext->index();
92
93 auto aug_tuple = augment(tuple, f, f_diff);
94 auto aug_index = augment(index, f, f_diff);
95
96 Ref pb;
97 world.DLOG("tuple was: {} : {}", tuple, tuple->type());
98 world.DLOG("aug tuple: {} : {}", aug_tuple, aug_tuple->type());
99 if (shadow_pullback.count(aug_tuple)) {
100 auto shadow_tuple_pb = shadow_pullback[aug_tuple];
101 world.DLOG("Shadow pullback: {} : {}", shadow_tuple_pb, shadow_tuple_pb->type());
102 pb = world.extract(shadow_tuple_pb, aug_index);
103 } else {
104 // ```
105 // e:T, b:B
106 // b = e#i
107 // b* = \lambda (s:B). e* (insert s at i in (zero T))
108 // ```
109 assert(partial_pullback.count(aug_tuple));
110 auto tuple_pb = partial_pullback[aug_tuple];
111 auto pb_ty = pullback_type(ext->type(), f->dom(2, 0));
112 auto pb_fun = world.mut_lam(pb_ty)->set("extract_pb");
113 world.DLOG("Pullback: {} : {}", pb_fun, pb_fun->type());
114 auto pb_tangent = pb_fun->var(0_s)->set("s");
115 auto tuple_tan = world.insert(world.call<zero>(aug_tuple->type()), aug_index, pb_tangent)->set("tup_s");
116 pb_fun->app(true, tuple_pb,
117 {
118 tuple_tan,
119 pb_fun->var(1) // ret_var but make sure to select correct one
120 });
121 pb = pb_fun;
122 }
123
124 auto aug_ext = world.extract(aug_tuple, aug_index);
125 partial_pullback[aug_ext] = pb;
126
127 return aug_ext;
128}
129
130Ref AutoDiffEval::augment_tuple(const Tuple* tup, Lam* f, Lam* f_diff) {
131 auto& world = tup->world();
132
133 // TODO: should use ops instead?
134 auto aug_ops = tup->projs([&](Ref op) -> const Def* { return augment(op, f, f_diff); });
135 auto aug_tup = world.tuple(aug_ops);
136
137 auto pbs = DefVec(Defs(aug_ops), [&](Ref op) { return partial_pullback[op]; });
138 world.DLOG("tuple pbs {,}", pbs);
139 // shadow pb = tuple of pbs
140 auto shadow_pb = world.tuple(pbs);
141 shadow_pullback[aug_tup] = shadow_pb;
142
143 // ```
144 // \lambda (s:[E0,...,Em]).
145 // sum (m,A)
146 // ((cps2ds e0*) (s#0), ..., (cps2ds em*) (s#m))
147 // ```
148 auto pb_ty = pullback_type(tup->type(), f->dom(2, 0));
149 auto pb = world.mut_lam(pb_ty)->set("tup_pb");
150 world.DLOG("Augmented tuple: {} : {}", aug_tup, aug_tup->type());
151 world.DLOG("Tuple Pullback: {} : {}", pb, pb->type());
152 world.DLOG("shadow pb: {} : {}", shadow_pb, shadow_pb->type());
153
154 auto pb_tangent = pb->var(0_s)->set("tup_s");
155
156 auto tangents = DefVec(
157 pbs.size(), [&](nat_t i) { return world.app(direct::op_cps2ds_dep(pbs[i]), world.extract(pb_tangent, i)); });
158 pb->app(true, pb->var(1),
159 // summed up tangents
160 op_sum(tangent_type_fun(f->dom(2, 0)), tangents));
161 partial_pullback[aug_tup] = pb;
162
163 return aug_tup;
164}
165
166Ref AutoDiffEval::augment_pack(const Pack* pack, Lam* f, Lam* f_diff) {
167 auto& world = pack->world();
168 auto shape = pack->arity(); // TODO: arity vs shape
169 auto body = pack->body();
170
171 auto aug_shape = augment_(shape, f, f_diff);
172 auto aug_body = augment(body, f, f_diff);
173
174 auto aug_pack = world.pack(aug_shape, aug_body);
175
176 assert(partial_pullback[aug_body] && "pack pullback should exists");
177 // TODO: or use scale axiom
178 auto body_pb = partial_pullback[aug_body];
179 auto pb_pack = world.pack(aug_shape, body_pb);
180 shadow_pullback[aug_pack] = pb_pack;
181
182 world.DLOG("shadow pb of pack: {} : {}", pb_pack, pb_pack->type());
183
184 auto pb_type = pullback_type(pack->type(), f->dom(2, 0));
185 auto pb = world.mut_lam(pb_type)->set("pack_pb");
186
187 world.DLOG("pb of pack: {} : {}", pb, pb_type);
188
189 auto f_arg_ty_diff = tangent_type_fun(f->dom(2, 0));
190 auto app_pb = world.mut_pack(world.arr(aug_shape, f_arg_ty_diff));
191
192 // TODO: special case for const width (special tuple)
193
194 // <i:n, cps2ds body_pb (s#i)>
195 app_pb->set(world.app(direct::op_cps2ds_dep(body_pb), world.extract(pb->var((nat_t)0), app_pb->var())));
196
197 world.DLOG("app pb of pack: {} : {}", app_pb, app_pb->type());
198
199 auto sumup = world.app(world.annex<sum>(), {aug_shape, f_arg_ty_diff});
200 world.DLOG("sumup: {} : {}", sumup, sumup->type());
201
202 pb->app(true, pb->var(1), world.app(sumup, app_pb));
203
204 partial_pullback[aug_pack] = pb;
205
206 return aug_pack;
207}
208
209Ref AutoDiffEval::augment_app(const App* app, Lam* f, Lam* f_diff) {
210 auto& world = app->world();
211
212 auto callee = app->callee();
213 auto arg = app->arg();
214
215 auto aug_arg = augment(arg, f, f_diff);
216 auto aug_callee = augment(callee, f, f_diff);
217
218 world.DLOG("augmented argument <{}> {} : {}", aug_arg->unique_name(), aug_arg, aug_arg->type());
219 world.DLOG("augmented callee <{}> {} : {}", aug_callee->unique_name(), aug_callee, aug_callee->type());
220 // TODO: move down to if(!is_cont(callee))
221 if (!Pi::isa_cn(callee->type()) && Pi::isa_cn(aug_callee->type())) {
222 aug_callee = direct::op_cps2ds_dep(aug_callee);
223 world.DLOG("wrapped augmented callee: <{}> {} : {}", aug_callee->unique_name(), aug_callee, aug_callee->type());
224 }
225
226 // nested (inner application)
227 if (app->type()->isa<Pi>()) {
228 world.DLOG("Nested application callee: {} : {}", aug_callee, aug_callee->type());
229 world.DLOG("Nested application arg: {} : {}", aug_arg, aug_arg->type());
230 auto aug_app = world.app(aug_callee, aug_arg);
231 world.DLOG("Nested application result: <{}> {} : {}", aug_app->unique_name(), aug_app, aug_app->type());
232 // We do not add a pullback as the pullback is bundled in the cps call or returned by the ds call
233 return aug_app;
234 }
235
236 // continuation (ret, if, ...)
237 if (Pi::isa_basicblock(callee->type())) {
238 // TODO: check if function (not operator)
239 // The original function is an open function (return cont / continuation) of type `Cn[E]`
240 // The augmented function `aug_callee` looks like a function but is not really a function has the type `Cn[E,
241 // Cn[E, Cn[A]]]`
242
243 // ret(e) => ret'(e, e*)
244
245 world.DLOG("continuation {} : {} => {} : {}", callee, callee->type(), aug_callee, aug_callee->type());
246
247 auto arg_pb = partial_pullback[aug_arg];
248 auto aug_app = world.app(aug_callee, {aug_arg, arg_pb});
249 world.DLOG("Augmented application: {} : {}", aug_app, aug_app->type());
250 return aug_app;
251 }
252
253 // ds function
254 if (!Pi::isa_cn(callee->type())) {
255 auto aug_app = world.app(aug_callee, aug_arg);
256 world.DLOG("Augmented application: <{}> {} : {}", aug_app->unique_name(), aug_app, aug_app->type());
257
258 world.DLOG("ds function: {} : {}", aug_app, aug_app->type());
259 // The calle is ds function (e.g. operator (or its partial application))
260 auto [aug_res, fun_pb] = aug_app->projs<2>();
261 // We compose `fun_pb` with `argument_pb` to get the result pb
262 // TODO: combine case with cps function case
263 auto arg_pb = partial_pullback[aug_arg];
264 assert(arg_pb);
265 // `fun_pb: out_tan -> arg_tan`
266 // `arg_pb: arg_tan -> fun_tan`
267 world.DLOG("function pullback: {} : {}", fun_pb, fun_pb->type());
268 world.DLOG("argument pullback: {} : {}", arg_pb, arg_pb->type());
269 auto res_pb = compose_cn(arg_pb, fun_pb);
270 world.DLOG("result pullback: {} : {}", res_pb, res_pb->type());
271 partial_pullback[aug_res] = res_pb;
273 return aug_res;
274 }
275
276 // TODO: dest with a function such that f args != g args
277 {
278 // normal function app
279 // ```
280 // g: cn[E, cn X]
281 // g(args,cont)
282 // g': cn[E, cn[X, cn[X, cn E]]]
283 // g'(aug_args, ____)
284 // ```
285 auto g = callee;
286 // At this point g_deriv might still be "autodiff ... g".
287 auto g_deriv = aug_callee;
288 world.DLOG("g: {} : {}", g, g->type());
289 world.DLOG("g': {} : {}", g_deriv, g_deriv->type());
290
291 auto [real_aug_args, aug_cont] = aug_arg->projs<2>();
292 world.DLOG("real_aug_args: {} : {}", real_aug_args, real_aug_args->type());
293 world.DLOG("aug_cont: {} : {}", aug_cont, aug_cont->type());
294 auto e_pb = partial_pullback[real_aug_args];
295 world.DLOG("e_pb (arg_pb): {} : {}", e_pb, e_pb->type());
296
297 // TODO: better debug names
298 auto ret_g_deriv_ty = g_deriv->type()->as<Pi>()->dom(1);
299 world.DLOG("ret_g_deriv_ty: {} ", ret_g_deriv_ty);
300 auto c1_ty = ret_g_deriv_ty->as<Pi>();
301 world.DLOG("c1_ty: (cn[X, cn[X+, cn E+]]) {}", c1_ty);
302 auto c1 = world.mut_lam(c1_ty)->set("c1");
303 auto res = c1->var((nat_t)0);
304 auto r_pb = c1->var(1);
305 c1->app(true, aug_cont, {res, compose_cn(e_pb, r_pb)});
306
307 auto aug_app = world.app(aug_callee, {real_aug_args, c1});
308 world.DLOG("aug_app: {} : {}", aug_app, aug_app->type());
309
310 // The result is * => no pb needed, no composition needed.
311 return aug_app;
312 }
313 assert(false && "should not be reached");
314}
315
316/// Rewrites the given definition in a lambda environment.
318 auto& world = def->world();
319 // We use macros above to avoid recomputation.
320 // TODO: Alternative:
321 // Use class instances to rewrite inside a function and save such values (f, f_diff, f->dom(2, 0)).
322
323 world.DLOG("Augment def {} : {}", def, def->type());
324
325 // Applications are continuations, operators, or full functions
326 if (auto app = def->isa<App>()) {
327 auto callee = app->callee();
328 auto arg = app->arg();
329 world.DLOG("Augment application: app {} with {}", callee, arg);
330 return augment_app(app, f, f_diff);
331 } else if (auto ext = def->isa<Extract>()) {
332 auto tuple = ext->tuple();
333 auto index = ext->index();
334 world.DLOG("Augment extract: {} #[{}]", tuple, index);
335 return augment_extract(ext, f, f_diff);
336 } else if (auto var = def->isa<Var>()) {
337 world.DLOG("Augment variable: {}", var);
338 return augment_var(var, f, f_diff);
339 } else if (auto lam = def->isa_mut<Lam>()) {
340 world.DLOG("Augment mut lambda: {}", lam);
341 return augment_lam(lam, f, f_diff);
342 } else if (auto lam = def->isa<Lam>()) {
343 world.ELOG("Augment lambda: {}", lam);
344 assert(false && "can not handle non-mutable lambdas");
345 } else if (auto lit = def->isa<Lit>()) {
346 world.DLOG("Augment literal: {}", def);
347 return augment_lit(lit, f, f_diff);
348 } else if (auto tup = def->isa<Tuple>()) {
349 world.DLOG("Augment tuple: {}", def);
350 return augment_tuple(tup, f, f_diff);
351 } else if (auto pack = def->isa<Pack>()) {
352 // TODO: handle mut packs (dependencies in the pack) (=> see paper about vectors)
353 auto shape = pack->arity(); // TODO: arity vs shape
354 auto body = pack->body();
355 world.DLOG("Augment pack: {} : {} with {}", shape, shape->type(), body);
356 return augment_pack(pack, f, f_diff);
357 } else if (auto ax = def->isa<Axiom>()) {
358 // TODO: move concrete handling to own function / file / directory (file per plugin)
359 world.DLOG("Augment axiom: {} : {}", ax, ax->type());
360 world.DLOG("axiom curry: {}", ax->curry());
361 world.DLOG("axiom flags: {}", ax->flags());
362 auto diff_name = ax->sym().str();
363 find_and_replace(diff_name, ".", "_");
364 find_and_replace(diff_name, "%", "");
365 diff_name = "internal_diff_" + diff_name;
366 world.DLOG("axiom name: {}", ax->sym());
367 world.DLOG("axiom function name: {}", diff_name);
368
369 auto diff_fun = world.external(world.sym(diff_name));
370 if (!diff_fun) {
371 world.ELOG("derivation not found: {}", diff_name);
372 auto expected_type = autodiff_type_fun(ax->type());
373 world.ELOG("expected: {} : {}", diff_name, expected_type);
374 assert(false && "unhandled axiom");
375 }
376 // TODO: why does this cause a depth error?
377 return diff_fun;
378 }
379
380 // TODO: handle Pi for axiom app
381 // TODO: remaining (lambda, axiom)
382
383 world.ELOG("did not expect to augment: {} : {}", def, def->type());
384 world.ELOG("node: {}", def->node_name());
385 assert(false && "augment not implemented on this def");
386 fe::unreachable();
387}
388
389} // namespace thorin::plug::autodiff
const Def * callee() const
Definition lam.h:206
const Def * arg() const
Definition lam.h:215
Base class for all Defs.
Definition def.h:222
Ref arity() const
Definition def.cpp:494
auto projs(F f) const
Splits this Def via Def::projections into an Array (if A == -1_n) or std::array (otherwise).
Definition def.h:369
Ref var(nat_t a, nat_t i)
Definition def.h:403
const Def * type() const
Yields the raw type of this Def, i.e. maybe nullptr.
Definition def.h:248
std::string_view node_name() const
Definition def.cpp:437
T * isa_mut() const
If this is *mut*able, it will cast constness away and perform a dynamic_cast to T.
Definition def.h:449
Sym sym() const
Definition def.h:470
Def * set(size_t i, const Def *def)
Successively set from left to right.
Definition def.cpp:254
World & world() const
Definition def.cpp:421
Extracts from a Sigma or Array-typed Extract::tuple the element at position Extract::index.
Definition tuple.h:118
const Def * index() const
Definition tuple.h:127
const Def * tuple() const
Definition tuple.h:126
A function.
Definition lam.h:97
Ref body() const
Definition lam.h:108
Ref filter() const
Definition lam.h:107
static const Lam * isa_basicblock(Ref d)
Definition lam.h:133
const Pi * type() const
Definition lam.h:109
Lam * set(Filter filter, const Def *body)
Definition lam.h:159
A (possibly parameterized) Tuple.
Definition tuple.h:87
const Def * body() const
Definition tuple.h:97
Pack * set(const Def *body)
Definition tuple.h:104
size_t index() const
Definition pass.h:33
World & world()
Definition pass.h:296
A dependent function type.
Definition lam.h:11
static const Pi * isa_basicblock(Ref d)
Is this a continuation (Pi::isa_cn) that is not Pi::isa_returning?
Definition lam.h:54
static const Pi * isa_cn(Ref d)
Is this a continuation - i.e. is the Pi::codom thorin::Bottom?
Definition lam.h:50
Ref dom() const
Definition lam.h:32
Helper class to retrieve Infer::arg if present.
Definition def.h:87
Data constructor for a Sigma.
Definition tuple.h:39
Ref insert(Ref d, Ref i, Ref val)
Definition world.cpp:340
Lam * mut_con(Ref dom)
Definition world.h:275
Pack * mut_pack(Ref type)
Definition world.h:332
Sym sym(std::string_view)
Definition world.cpp:77
Ref var(Ref type, Def *mut)
Definition world.cpp:153
Ref pack(Ref arity, Ref body)
Definition world.cpp:405
Def * external(Sym name)
Lookup by name.
Definition world.h:164
void debug_dump()
Dump in Debug build if World::log::level is Log::Level::Debug.
Definition dump.cpp:458
const Def * annex(Id id)
Lookup annex by Axiom::id.
Definition world.h:167
Ref extract(Ref d, Ref i)
Definition world.cpp:278
const Def * call(Id id, Args &&... args)
Definition world.h:497
Ref arr(Ref shape, Ref body)
Definition world.cpp:378
Ref app(Ref callee, Ref arg)
Definition world.cpp:183
Ref tuple(Defs ops)
Definition world.cpp:226
Lam * mut_lam(const Pi *pi)
Definition world.h:263
const Type * type(Ref level)
Definition world.cpp:92
Ref augment_extract(const Extract *, Lam *, Lam *)
Ref augment_app(const App *, Lam *, Lam *)
Ref augment_tuple(const Tuple *, Lam *, Lam *)
Ref augment_lit(const Lit *, Lam *, Lam *)
Ref augment_(Ref, Lam *, Lam *)
Rewrites the given definition in a lambda environment.
Ref augment_var(const Var *, Lam *, Lam *)
helper functions for augment
Ref augment(Ref, Lam *, Lam *)
Applies to (open) expressions in a functional context.
Ref augment_pack(const Pack *pack, Lam *f, Lam *f_diff)
The automatic differentiation Plugin
Definition autodiff.h:7
const Def * tangent_type_fun(const Def *)
Definition autodiff.cpp:57
const Pi * pullback_type(const Def *E, const Def *A)
computes pb type E* -> A* E - type of the expression (return type for a function) A - type of the arg...
Definition autodiff.cpp:62
const Def * zero_pullback(const Def *E, const Def *A)
Definition autodiff.cpp:44
const Def * op_sum(const Def *T, Defs)
Definition autodiff.cpp:169
const Def * autodiff_type_fun(const Def *)
Definition autodiff.cpp:113
const Def * op_cps2ds_dep(const Def *f)
Definition direct.h:11
const Def * compose_cn(const Def *f, const Def *g)
The high level view is:
Definition lam.cpp:53
u64 nat_t
Definition types.h:44
View< const Def * > Defs
Definition def.h:62
Vector< const Def * > DefVec
Definition def.h:63
void find_and_replace(std::string &str, std::string_view what, std::string_view repl)
Replaces all occurrences of what with repl.
Definition util.h:70