6using namespace std::literals;
12 partial_pullback[lit] = pb;
17 assert(augmented.count(var));
18 auto aug_var = augmented[var];
19 assert(partial_pullback.count(aug_var));
27 if (augmented.count(lam)) {
32 DLOG(
"already augmented {} : {} to {} : {}", lam, lam->
type(), augmented[lam], augmented[lam]->
type());
33 return augmented[lam];
38 || lam->
sym().view().find(
"_cont") != std::string::npos) {
47 DLOG(
"found an open continuation {} : {}", lam, lam->
type());
48 auto cont_dom = lam->
type()->
dom();
51 DLOG(
"augmented domain {}", aug_dom);
52 DLOG(
"pb type is {}", pb_ty);
53 auto aug_lam =
world().
mut_con({aug_dom, pb_ty})->set(
"aug_"s + lam->
sym().str());
54 auto aug_var = aug_lam->var((
nat_t)0);
55 augmented[lam->
var()] = aug_var;
56 augmented[lam] = aug_lam;
57 derived[lam] = aug_lam;
58 auto pb = aug_lam->var(1);
59 partial_pullback[aug_var] = pb;
63 aug_lam->set(lam->
filter(), new_body);
66 partial_pullback[aug_lam] = lam_pb;
67 DLOG(
"augmented {} : {}", lam, lam->
type());
68 DLOG(
"to {} : {}", aug_lam, aug_lam->type());
69 DLOG(
"ppb for lam cont: {}", lam_pb);
73 DLOG(
"found a closed function call {} : {}", lam, lam->
type());
77 DLOG(
"augmented function is {} : {}", aug_lam, aug_lam->type());
90 DLOG(
"aug tuple: {} : {}", aug_tuple, aug_tuple->type());
91 if (shadow_pullback.count(aug_tuple)) {
92 auto shadow_tuple_pb = shadow_pullback[aug_tuple];
93 DLOG(
"Shadow pullback: {} : {}", shadow_tuple_pb, shadow_tuple_pb->type());
101 assert(partial_pullback.count(aug_tuple));
102 auto tuple_pb = partial_pullback[aug_tuple];
105 DLOG(
"Pullback: {} : {}", pb_fun, pb_fun->type());
106 auto pb_tangent = pb_fun->var(0_s)->set(
"s");
107 auto tuple_tan =
world().
insert(
world().call<zero>(aug_tuple->type()), aug_index, pb_tangent)->
set(
"tup_s");
108 pb_fun->app(
true, tuple_pb, {tuple_tan, pb_fun->var(1) });
113 partial_pullback[aug_ext] = pb;
120 auto aug_ops = tup->
projs([&](
const Def* op) ->
const Def* {
return augment(op, f, f_diff); });
123 auto pbs =
DefVec(
Defs(aug_ops), [&](
const Def* op) {
return partial_pullback[op]; });
124 DLOG(
"tuple pbs {,}", pbs);
127 shadow_pullback[aug_tup] = shadow_pb;
136 DLOG(
"Augmented tuple: {} : {}", aug_tup, aug_tup->type());
137 DLOG(
"Tuple Pullback: {} : {}", pb, pb->type());
138 DLOG(
"shadow pb: {} : {}", shadow_pb, shadow_pb->type());
140 auto pb_tangent = pb->var(0_s)->set(
"tup_s");
143 return world().app(direct::op_cps2ds_dep(pbs[i]), world().extract(pb_tangent, i));
145 pb->app(
true, pb->var(1),
148 partial_pullback[aug_tup] = pb;
154 auto arity = pack->
arity();
155 auto body = pack->
body();
157 auto aug_arity =
augment_(arity, f, f_diff);
158 auto aug_body =
augment(body, f, f_diff);
160 auto aug_pack =
world().
pack(aug_arity, aug_body);
162 assert(partial_pullback[aug_body] &&
"pack pullback should exists");
164 auto body_pb = partial_pullback[aug_body];
165 auto pb_pack =
world().
pack(aug_arity, body_pb);
166 shadow_pullback[aug_pack] = pb_pack;
168 DLOG(
"shadow pb of pack: {} : {}", pb_pack, pb_pack->type());
173 DLOG(
"pb of pack: {} : {}", pb, pb_type);
183 DLOG(
"app pb of pack: {} : {}", app_pb, app_pb->type());
186 DLOG(
"sumup: {} : {}", sumup, sumup->type());
188 pb->app(
true, pb->var(1),
world().app(sumup, app_pb));
190 partial_pullback[aug_pack] = pb;
196 auto callee = app->
callee();
197 auto arg = app->
arg();
199 auto aug_arg =
augment(arg, f, f_diff);
200 auto aug_callee =
augment(callee, f, f_diff);
202 DLOG(
"augmented argument <{}> {} : {}", aug_arg->unique_name(), aug_arg, aug_arg->type());
203 DLOG(
"augmented callee <{}> {} : {}", aug_callee->unique_name(), aug_callee, aug_callee->type());
207 DLOG(
"wrapped augmented callee: <{}> {} : {}", aug_callee->unique_name(), aug_callee, aug_callee->type());
211 if (app->
type()->isa<
Pi>()) {
212 DLOG(
"Nested application callee: {} : {}", aug_callee, aug_callee->type());
213 DLOG(
"Nested application arg: {} : {}", aug_arg, aug_arg->type());
214 auto aug_app =
world().
app(aug_callee, aug_arg);
215 DLOG(
"Nested application result: <{}> {} : {}", aug_app->unique_name(), aug_app, aug_app->type());
229 DLOG(
"continuation {} : {} => {} : {}", callee, callee->type(), aug_callee, aug_callee->type());
231 auto arg_pb = partial_pullback[aug_arg];
232 auto aug_app =
world().
app(aug_callee, {aug_arg, arg_pb});
233 DLOG(
"Augmented application: {} : {}", aug_app, aug_app->type());
239 auto aug_app =
world().
app(aug_callee, aug_arg);
240 DLOG(
"Augmented application: <{}> {} : {}", aug_app->unique_name(), aug_app, aug_app->type());
242 DLOG(
"ds function: {} : {}", aug_app, aug_app->type());
244 auto [aug_res, fun_pb] = aug_app->projs<2>();
247 auto arg_pb = partial_pullback[aug_arg];
251 DLOG(
"function pullback: {} : {}", fun_pb, fun_pb->type());
252 DLOG(
"argument pullback: {} : {}", arg_pb, arg_pb->type());
254 DLOG(
"result pullback: {} : {}", res_pb, res_pb->type());
255 partial_pullback[aug_res] = res_pb;
271 auto g_deriv = aug_callee;
272 DLOG(
"g: {} : {}", g, g->type());
273 DLOG(
"g': {} : {}", g_deriv, g_deriv->type());
275 auto [real_aug_args, aug_cont] = aug_arg->projs<2>();
276 DLOG(
"real_aug_args: {} : {}", real_aug_args, real_aug_args->type());
277 DLOG(
"aug_cont: {} : {}", aug_cont, aug_cont->type());
278 auto e_pb = partial_pullback[real_aug_args];
279 DLOG(
"e_pb (arg_pb): {} : {}", e_pb, e_pb->type());
282 auto ret_g_deriv_ty = g_deriv->type()->as<
Pi>()->dom(1);
283 DLOG(
"ret_g_deriv_ty: {} ", ret_g_deriv_ty);
284 auto c1_ty = ret_g_deriv_ty->as<
Pi>();
285 DLOG(
"c1_ty: (cn[X, cn[X+, cn E+]]) {}", c1_ty);
288 auto r_pb = c1->
var(1);
289 c1->app(
true, aug_cont, {res,
compose_cn(e_pb, r_pb)});
291 auto aug_app =
world().
app(aug_callee, {real_aug_args, c1});
292 DLOG(
"aug_app: {} : {}", aug_app, aug_app->type());
297 assert(
false &&
"should not be reached");
306 DLOG(
"Augment def {} : {}", def, def->
type());
309 if (
auto app = def->isa<
App>()) {
310 auto callee = app->callee();
311 auto arg = app->arg();
312 DLOG(
"Augment application: app {} with {}", callee, arg);
314 }
else if (
auto ext = def->isa<
Extract>()) {
315 auto tuple = ext->tuple();
316 auto index = ext->index();
319 }
else if (
auto var = def->isa<
Var>()) {
320 DLOG(
"Augment variable: {}", var);
323 DLOG(
"Augment mut lambda: {}", lam);
325 }
else if (
auto lam = def->isa<
Lam>()) {
326 ELOG(
"Augment lambda: {}", lam);
327 assert(
false &&
"can not handle non-mutable lambdas");
328 }
else if (
auto lit = def->isa<
Lit>()) {
329 DLOG(
"Augment literal: {}", def);
331 }
else if (
auto tup = def->isa<
Tuple>()) {
332 DLOG(
"Augment tuple: {}", def);
334 }
else if (
auto pack = def->isa<
Pack>()) {
336 auto arity = pack->arity();
337 auto body = pack->body();
338 DLOG(
"Augment pack: {} : {} with {}", arity, arity->type(), body);
340 }
else if (
auto ax = def->isa<
Axm>()) {
342 DLOG(
"Augment axm: {} : {}", ax, ax->type());
343 DLOG(
"axm curry: {}", ax->curry());
344 DLOG(
"axm flags: {}", ax->flags());
345 auto diff_name = ax->sym().str();
348 diff_name =
"internal_diff_" + diff_name;
349 DLOG(
"axm name: {}", ax->sym());
350 DLOG(
"axm function name: {}", diff_name);
354 ELOG(
"derivation not found: {}", diff_name);
356 ELOG(
"expected: {} : {}", diff_name, expected_type);
357 assert(
false &&
"unhandled axm");
366 ELOG(
"did not expect to augment: {} : {}", def, def->
type());
368 assert(
false &&
"augment not implemented on this def");
const Def * callee() const
Def * set(size_t i, const Def *)
Successively set from left to right.
std::string_view node_name() const
T * isa_mut() const
If this is mutable, it will cast constness away and perform a dynamic_cast to T.
const Def * var(nat_t a, nat_t i) noexcept
auto projs(F f) const
Splits this Def via Def::projections into an Array (if A == std::dynamic_extent) or std::array (other...
const Def * type() const noexcept
Yields the "raw" type of this Def (maybe nullptr).
const Def * filter() const
Lam * set(Filter filter, const Def *body)
static const Lam * isa_basicblock(const Def *d)
A (possibly parameterized) Tuple.
const Def * arity() const final
Pack * set(const Def *body)
A dependent function type.
static const Pi * isa_cn(const Def *d)
Is this a continuation - i.e. is the Pi::codom mim::Bottom?
static const Pi * isa_basicblock(const Def *d)
Is this a continuation (Pi::isa_cn) that is not Pi::isa_returning?
Data constructor for a Sigma.
const Def * insert(const Def *d, const Def *i, const Def *val)
const Def * pack(const Def *arity, const Def *body)
const Def * app(const Def *callee, const Def *arg)
const Def * tuple(Defs ops)
void debug_dump()
Dump in Debug build if World::log::level is Log::Level::Debug.
const Def * extract(const Def *d, const Def *i)
Pack * mut_pack(const Def *type)
Def * external(Sym name)
Lookup by name.
const Def * call(Id id, Args &&... args)
Complete curried call of annexes obeying implicits.
Lam * mut_con(const Def *dom)
Lam * mut_lam(const Pi *pi)
const Def * augment_lit(const Lit *, Lam *, Lam *)
const Def * augment_tuple(const Tuple *, Lam *, Lam *)
const Def * augment_pack(const Pack *pack, Lam *f, Lam *f_diff)
const Def * augment_app(const App *, Lam *, Lam *)
const Def * augment_lam(Lam *, Lam *, Lam *)
const Def * augment_(const Def *, Lam *, Lam *)
Rewrites the given definition in a lambda environment.
const Def * augment(const Def *, Lam *, Lam *)
Applies to (open) expressions in a functional context.
const Def * augment_extract(const Extract *, Lam *, Lam *)
const Def * augment_var(const Var *, Lam *, Lam *)
helper functions for augment
#define DLOG(...)
Vaporizes to nothingness in Release build (active only in Debug builds).
The automatic differentiation Plugin
const Def * op_sum(const Def *T, Defs)
const Def * autodiff_type_fun(const Def *)
const Def * tangent_type_fun(const Def *)
const Def * zero_pullback(const Def *E, const Def *A)
const Pi * pullback_type(const Def *E, const Def *A)
Computes the pullback type E* -> A*, where E is the type of the expression (the return type for a function) and A is the type of the arg...
const Def * op_cps2ds_dep(const Def *k)
Vector< const Def * > DefVec
void find_and_replace(std::string &str, std::string_view what, std::string_view repl)
Replaces all occurrences of what with repl.
const Def * compose_cn(const Def *f, const Def *g)
The high level view is: