11using namespace std::literals;
17 partial_pullback[lit] = pb;
22 assert(augmented.count(var));
23 auto aug_var = augmented[var];
24 assert(partial_pullback.count(aug_var));
33 if (augmented.count(lam)) {
38 world.DLOG(
"already augmented {} : {} to {} : {}", lam, lam->
type(), augmented[lam], augmented[lam]->type());
39 return augmented[lam];
44 || lam->
sym().view().find(
"_cont") != std::string::npos) {
53 world.DLOG(
"found an open continuation {} : {}", lam, lam->
type());
54 auto cont_dom = lam->
type()->
dom();
57 world.DLOG(
"augmented domain {}", aug_dom);
58 world.DLOG(
"pb type is {}", pb_ty);
59 auto aug_lam =
world.
mut_con({aug_dom, pb_ty})->set(
"aug_"s + lam->
sym().str());
60 auto aug_var = aug_lam->
var((
nat_t)0);
61 augmented[lam->
var()] = aug_var;
62 augmented[lam] = aug_lam;
63 derived[lam] = aug_lam;
64 auto pb = aug_lam->var(1);
65 partial_pullback[aug_var] = pb;
69 aug_lam->set(lam->
filter(), new_body);
72 partial_pullback[aug_lam] = lam_pb;
73 world.DLOG(
"augmented {} : {}", lam, lam->
type());
74 world.DLOG(
"to {} : {}", aug_lam, aug_lam->type());
75 world.DLOG(
"ppb for lam cont: {}", lam_pb);
79 world.DLOG(
"found a closed function call {} : {}", lam, lam->
type());
83 world.DLOG(
"augmented function is {} : {}", aug_lam, aug_lam->type());
90 auto tuple = ext->
tuple();
93 auto aug_tuple =
augment(tuple, f, f_diff);
97 world.DLOG(
"tuple was: {} : {}", tuple, tuple->type());
98 world.DLOG(
"aug tuple: {} : {}", aug_tuple, aug_tuple->type());
99 if (shadow_pullback.count(aug_tuple)) {
100 auto shadow_tuple_pb = shadow_pullback[aug_tuple];
101 world.DLOG(
"Shadow pullback: {} : {}", shadow_tuple_pb, shadow_tuple_pb->type());
109 assert(partial_pullback.count(aug_tuple));
110 auto tuple_pb = partial_pullback[aug_tuple];
113 world.DLOG(
"Pullback: {} : {}", pb_fun, pb_fun->type());
114 auto pb_tangent = pb_fun->
var(0_s)->
set(
"s");
116 pb_fun->app(
true, tuple_pb,
125 partial_pullback[aug_ext] = pb;
134 auto aug_ops = tup->
projs([&](
Ref op) ->
const Def* {
return augment(op, f, f_diff); });
137 auto pbs =
DefVec(
Defs(aug_ops), [&](
Ref op) {
return partial_pullback[op]; });
138 world.DLOG(
"tuple pbs {,}", pbs);
141 shadow_pullback[aug_tup] = shadow_pb;
150 world.DLOG(
"Augmented tuple: {} : {}", aug_tup, aug_tup->type());
151 world.DLOG(
"Tuple Pullback: {} : {}", pb, pb->type());
152 world.DLOG(
"shadow pb: {} : {}", shadow_pb, shadow_pb->type());
154 auto pb_tangent = pb->
var(0_s)->
set(
"tup_s");
157 pbs.size(), [&](
nat_t i) { return world.app(direct::op_cps2ds_dep(pbs[i]), world.extract(pb_tangent, i)); });
158 pb->app(
true, pb->var(1),
161 partial_pullback[aug_tup] = pb;
168 auto shape = pack->
arity();
169 auto body = pack->
body();
171 auto aug_shape =
augment_(shape, f, f_diff);
172 auto aug_body =
augment(body, f, f_diff);
174 auto aug_pack =
world.
pack(aug_shape, aug_body);
176 assert(partial_pullback[aug_body] &&
"pack pullback should exists");
178 auto body_pb = partial_pullback[aug_body];
179 auto pb_pack =
world.
pack(aug_shape, body_pb);
180 shadow_pullback[aug_pack] = pb_pack;
182 world.DLOG(
"shadow pb of pack: {} : {}", pb_pack, pb_pack->type());
187 world.DLOG(
"pb of pack: {} : {}", pb, pb_type);
197 world.DLOG(
"app pb of pack: {} : {}", app_pb, app_pb->type());
200 world.DLOG(
"sumup: {} : {}", sumup, sumup->type());
202 pb->
app(
true, pb->var(1),
world.
app(sumup, app_pb));
204 partial_pullback[aug_pack] = pb;
212 auto callee = app->
callee();
213 auto arg = app->
arg();
215 auto aug_arg =
augment(arg, f, f_diff);
216 auto aug_callee =
augment(callee, f, f_diff);
218 world.DLOG(
"augmented argument <{}> {} : {}", aug_arg->unique_name(), aug_arg, aug_arg->type());
219 world.DLOG(
"augmented callee <{}> {} : {}", aug_callee->unique_name(), aug_callee, aug_callee->type());
223 world.DLOG(
"wrapped augmented callee: <{}> {} : {}", aug_callee->unique_name(), aug_callee, aug_callee->type());
227 if (app->
type()->isa<
Pi>()) {
228 world.DLOG(
"Nested application callee: {} : {}", aug_callee, aug_callee->type());
229 world.DLOG(
"Nested application arg: {} : {}", aug_arg, aug_arg->type());
230 auto aug_app =
world.
app(aug_callee, aug_arg);
231 world.DLOG(
"Nested application result: <{}> {} : {}", aug_app->unique_name(), aug_app, aug_app->type());
245 world.DLOG(
"continuation {} : {} => {} : {}", callee, callee->type(), aug_callee, aug_callee->type());
247 auto arg_pb = partial_pullback[aug_arg];
248 auto aug_app =
world.
app(aug_callee, {aug_arg, arg_pb});
249 world.DLOG(
"Augmented application: {} : {}", aug_app, aug_app->type());
255 auto aug_app =
world.
app(aug_callee, aug_arg);
256 world.DLOG(
"Augmented application: <{}> {} : {}", aug_app->unique_name(), aug_app, aug_app->type());
258 world.DLOG(
"ds function: {} : {}", aug_app, aug_app->type());
260 auto [aug_res, fun_pb] = aug_app->projs<2>();
263 auto arg_pb = partial_pullback[aug_arg];
267 world.DLOG(
"function pullback: {} : {}", fun_pb, fun_pb->type());
268 world.DLOG(
"argument pullback: {} : {}", arg_pb, arg_pb->type());
270 world.DLOG(
"result pullback: {} : {}", res_pb, res_pb->type());
271 partial_pullback[aug_res] = res_pb;
287 auto g_deriv = aug_callee;
288 world.DLOG(
"g: {} : {}", g, g->type());
289 world.DLOG(
"g': {} : {}", g_deriv, g_deriv->type());
291 auto [real_aug_args, aug_cont] = aug_arg->projs<2>();
292 world.DLOG(
"real_aug_args: {} : {}", real_aug_args, real_aug_args->type());
293 world.DLOG(
"aug_cont: {} : {}", aug_cont, aug_cont->type());
294 auto e_pb = partial_pullback[real_aug_args];
295 world.DLOG(
"e_pb (arg_pb): {} : {}", e_pb, e_pb->type());
298 auto ret_g_deriv_ty = g_deriv->
type()->as<
Pi>()->dom(1);
299 world.DLOG(
"ret_g_deriv_ty: {} ", ret_g_deriv_ty);
300 auto c1_ty = ret_g_deriv_ty->as<
Pi>();
301 world.DLOG(
"c1_ty: (cn[X, cn[X+, cn E+]]) {}", c1_ty);
304 auto r_pb = c1->
var(1);
305 c1->app(
true, aug_cont, {res,
compose_cn(e_pb, r_pb)});
307 auto aug_app =
world.
app(aug_callee, {real_aug_args, c1});
308 world.DLOG(
"aug_app: {} : {}", aug_app, aug_app->type());
313 assert(
false &&
"should not be reached");
323 world.DLOG(
"Augment def {} : {}", def, def->
type());
326 if (
auto app = def->isa<
App>()) {
327 auto callee = app->callee();
328 auto arg = app->arg();
329 world.DLOG(
"Augment application: app {} with {}", callee, arg);
331 }
else if (
auto ext = def->isa<
Extract>()) {
332 auto tuple = ext->tuple();
333 auto index = ext->index();
334 world.DLOG(
"Augment extract: {} #[{}]", tuple,
index);
336 }
else if (
auto var = def->isa<
Var>()) {
337 world.DLOG(
"Augment variable: {}", var);
340 world.DLOG(
"Augment mut lambda: {}", lam);
342 }
else if (
auto lam = def->isa<
Lam>()) {
343 world.ELOG(
"Augment lambda: {}", lam);
344 assert(
false &&
"can not handle non-mutable lambdas");
345 }
else if (
auto lit = def->isa<
Lit>()) {
346 world.DLOG(
"Augment literal: {}", def);
348 }
else if (
auto tup = def->isa<
Tuple>()) {
349 world.DLOG(
"Augment tuple: {}", def);
351 }
else if (
auto pack = def->isa<
Pack>()) {
353 auto shape = pack->arity();
354 auto body = pack->body();
355 world.DLOG(
"Augment pack: {} : {} with {}", shape, shape->type(), body);
357 }
else if (
auto ax = def->isa<
Axiom>()) {
359 world.DLOG(
"Augment axiom: {} : {}", ax, ax->type());
360 world.DLOG(
"axiom curry: {}", ax->curry());
361 world.DLOG(
"axiom flags: {}", ax->flags());
362 auto diff_name = ax->
sym().str();
365 diff_name =
"internal_diff_" + diff_name;
366 world.DLOG(
"axiom name: {}", ax->sym());
367 world.DLOG(
"axiom function name: {}", diff_name);
371 world.ELOG(
"derivation not found: {}", diff_name);
373 world.ELOG(
"expected: {} : {}", diff_name, expected_type);
374 assert(
false &&
"unhandled axiom");
383 world.ELOG(
"did not expect to augment: {} : {}", def, def->
type());
385 assert(
false &&
"augment not implemented on this def");
const Def * callee() const
auto projs(F f) const
Splits this Def via Def::projections into an Array (if A == -1_n) or std::array (otherwise).
Ref var(nat_t a, nat_t i)
const Def * type() const
Yields the raw type of this Def, i.e. maybe nullptr.
std::string_view node_name() const
T * isa_mut() const
If this is *mut*able, it will cast constness away and perform a dynamic_cast to T.
Def * set(size_t i, const Def *def)
Successively set from left to right.
static const Lam * isa_basicblock(Ref d)
Lam * set(Filter filter, const Def *body)
A (possibly paramterized) Tuple.
Pack * set(const Def *body)
A dependent function type.
static const Pi * isa_basicblock(Ref d)
Is this a continuation (Pi::isa_cn) that is not Pi::isa_returning?
static const Pi * isa_cn(Ref d)
Is this a continuation - i.e. is the Pi::codom thorin::Bottom?
Helper class to retrieve Infer::arg if present.
Data constructor for a Sigma.
Ref insert(Ref d, Ref i, Ref val)
Pack * mut_pack(Ref type)
Sym sym(std::string_view)
Ref var(Ref type, Def *mut)
Ref pack(Ref arity, Ref body)
Def * external(Sym name)
Lookup by name.
void debug_dump()
Dump in Debug build if World::log::level is Log::Level::Debug.
const Def * annex(Id id)
Lookup annex by Axiom::id.
Ref extract(Ref d, Ref i)
const Def * call(Id id, Args &&... args)
Ref arr(Ref shape, Ref body)
Ref app(Ref callee, Ref arg)
Lam * mut_lam(const Pi *pi)
const Type * type(Ref level)
Ref augment_extract(const Extract *, Lam *, Lam *)
Ref augment_app(const App *, Lam *, Lam *)
Ref augment_tuple(const Tuple *, Lam *, Lam *)
Ref augment_lit(const Lit *, Lam *, Lam *)
Ref augment_(Ref, Lam *, Lam *)
Rewrites the given definition in a lambda environment.
Ref augment_lam(Lam *, Lam *, Lam *)
Ref augment_var(const Var *, Lam *, Lam *)
helper functions for augment
Ref augment(Ref, Lam *, Lam *)
Applies to (open) expressions in a functional context.
Ref augment_pack(const Pack *pack, Lam *f, Lam *f_diff)
The automatic differentiation Plugin
const Def * tangent_type_fun(const Def *)
const Pi * pullback_type(const Def *E, const Def *A)
computes pb type E* -> A* E - type of the expression (return type for a function) A - type of the arg...
const Def * zero_pullback(const Def *E, const Def *A)
const Def * op_sum(const Def *T, Defs)
const Def * autodiff_type_fun(const Def *)
const Def * op_cps2ds_dep(const Def *f)
const Def * compose_cn(const Def *f, const Def *g)
The high level view is:
Vector< const Def * > DefVec
void find_and_replace(std::string &str, std::string_view what, std::string_view repl)
Replaces all occurrences of what with repl.