11using namespace std::literals;
17 partial_pullback[lit] = pb;
22 assert(augmented.count(var));
23 auto aug_var = augmented[var];
24 assert(partial_pullback.count(aug_var));
33 if (augmented.count(lam)) {
38 world.DLOG(
"already augmented {} : {} to {} : {}", lam, lam->
type(), augmented[lam], augmented[lam]->type());
39 return augmented[lam];
44 || lam->
sym().view().find(
"_cont") != std::string::npos) {
53 world.DLOG(
"found an open continuation {} : {}", lam, lam->
type());
54 auto cont_dom = lam->
type()->
dom();
57 world.DLOG(
"augmented domain {}", aug_dom);
58 world.DLOG(
"pb type is {}", pb_ty);
59 auto aug_lam =
world.
mut_con({aug_dom, pb_ty})->set(
"aug_"s + lam->
sym().str());
60 auto aug_var = aug_lam->
var((
nat_t)0);
61 augmented[lam->
var()] = aug_var;
62 augmented[lam] = aug_lam;
63 derived[lam] = aug_lam;
64 auto pb = aug_lam->var(1);
65 partial_pullback[aug_var] = pb;
69 aug_lam->set(lam->
filter(), new_body);
72 partial_pullback[aug_lam] = lam_pb;
73 world.DLOG(
"augmented {} : {}", lam, lam->
type());
74 world.DLOG(
"to {} : {}", aug_lam, aug_lam->type());
75 world.DLOG(
"ppb for lam cont: {}", lam_pb);
79 world.DLOG(
"found a closed function call {} : {}", lam, lam->
type());
83 world.DLOG(
"augmented function is {} : {}", aug_lam, aug_lam->type());
90 auto tuple = ext->
tuple();
93 auto aug_tuple =
augment(tuple, f, f_diff);
97 world.DLOG(
"tuple was: {} : {}", tuple, tuple->type());
98 world.DLOG(
"aug tuple: {} : {}", aug_tuple, aug_tuple->type());
99 if (shadow_pullback.count(aug_tuple)) {
100 auto shadow_tuple_pb = shadow_pullback[aug_tuple];
101 world.DLOG(
"Shadow pullback: {} : {}", shadow_tuple_pb, shadow_tuple_pb->type());
109 assert(partial_pullback.count(aug_tuple));
110 auto tuple_pb = partial_pullback[aug_tuple];
113 world.DLOG(
"Pullback: {} : {}", pb_fun, pb_fun->type());
114 auto pb_tangent = pb_fun->
var(0_s)->
set(
"s");
116 pb_fun->app(
true, tuple_pb, {tuple_tan, pb_fun->var(1) });
121 partial_pullback[aug_ext] = pb;
130 auto aug_ops = tup->
projs([&](
Ref op) ->
const Def* {
return augment(op, f, f_diff); });
133 auto pbs =
DefVec(
Defs(aug_ops), [&](
Ref op) {
return partial_pullback[op]; });
134 world.DLOG(
"tuple pbs {,}", pbs);
137 shadow_pullback[aug_tup] = shadow_pb;
146 world.DLOG(
"Augmented tuple: {} : {}", aug_tup, aug_tup->type());
147 world.DLOG(
"Tuple Pullback: {} : {}", pb, pb->type());
148 world.DLOG(
"shadow pb: {} : {}", shadow_pb, shadow_pb->type());
150 auto pb_tangent = pb->
var(0_s)->
set(
"tup_s");
153 pbs.size(), [&](
nat_t i) { return world.app(direct::op_cps2ds_dep(pbs[i]), world.extract(pb_tangent, i)); });
154 pb->app(
true, pb->var(1),
157 partial_pullback[aug_tup] = pb;
164 auto shape = pack->
arity();
165 auto body = pack->
body();
167 auto aug_shape =
augment_(shape, f, f_diff);
168 auto aug_body =
augment(body, f, f_diff);
170 auto aug_pack =
world.
pack(aug_shape, aug_body);
172 assert(partial_pullback[aug_body] &&
"pack pullback should exists");
174 auto body_pb = partial_pullback[aug_body];
175 auto pb_pack =
world.
pack(aug_shape, body_pb);
176 shadow_pullback[aug_pack] = pb_pack;
178 world.DLOG(
"shadow pb of pack: {} : {}", pb_pack, pb_pack->type());
183 world.DLOG(
"pb of pack: {} : {}", pb, pb_type);
193 world.DLOG(
"app pb of pack: {} : {}", app_pb, app_pb->type());
196 world.DLOG(
"sumup: {} : {}", sumup, sumup->type());
198 pb->
app(
true, pb->var(1),
world.
app(sumup, app_pb));
200 partial_pullback[aug_pack] = pb;
208 auto callee = app->
callee();
209 auto arg = app->
arg();
211 auto aug_arg =
augment(arg, f, f_diff);
212 auto aug_callee =
augment(callee, f, f_diff);
214 world.DLOG(
"augmented argument <{}> {} : {}", aug_arg->unique_name(), aug_arg, aug_arg->type());
215 world.DLOG(
"augmented callee <{}> {} : {}", aug_callee->unique_name(), aug_callee, aug_callee->type());
219 world.DLOG(
"wrapped augmented callee: <{}> {} : {}", aug_callee->unique_name(), aug_callee, aug_callee->type());
223 if (app->
type()->isa<
Pi>()) {
224 world.DLOG(
"Nested application callee: {} : {}", aug_callee, aug_callee->type());
225 world.DLOG(
"Nested application arg: {} : {}", aug_arg, aug_arg->type());
226 auto aug_app =
world.
app(aug_callee, aug_arg);
227 world.DLOG(
"Nested application result: <{}> {} : {}", aug_app->unique_name(), aug_app, aug_app->type());
241 world.DLOG(
"continuation {} : {} => {} : {}", callee, callee->type(), aug_callee, aug_callee->type());
243 auto arg_pb = partial_pullback[aug_arg];
244 auto aug_app =
world.
app(aug_callee, {aug_arg, arg_pb});
245 world.DLOG(
"Augmented application: {} : {}", aug_app, aug_app->type());
251 auto aug_app =
world.
app(aug_callee, aug_arg);
252 world.DLOG(
"Augmented application: <{}> {} : {}", aug_app->unique_name(), aug_app, aug_app->type());
254 world.DLOG(
"ds function: {} : {}", aug_app, aug_app->type());
256 auto [aug_res, fun_pb] = aug_app->projs<2>();
259 auto arg_pb = partial_pullback[aug_arg];
263 world.DLOG(
"function pullback: {} : {}", fun_pb, fun_pb->type());
264 world.DLOG(
"argument pullback: {} : {}", arg_pb, arg_pb->type());
266 world.DLOG(
"result pullback: {} : {}", res_pb, res_pb->type());
267 partial_pullback[aug_res] = res_pb;
283 auto g_deriv = aug_callee;
284 world.DLOG(
"g: {} : {}", g, g->type());
285 world.DLOG(
"g': {} : {}", g_deriv, g_deriv->type());
287 auto [real_aug_args, aug_cont] = aug_arg->projs<2>();
288 world.DLOG(
"real_aug_args: {} : {}", real_aug_args, real_aug_args->type());
289 world.DLOG(
"aug_cont: {} : {}", aug_cont, aug_cont->type());
290 auto e_pb = partial_pullback[real_aug_args];
291 world.DLOG(
"e_pb (arg_pb): {} : {}", e_pb, e_pb->type());
294 auto ret_g_deriv_ty = g_deriv->
type()->as<
Pi>()->dom(1);
295 world.DLOG(
"ret_g_deriv_ty: {} ", ret_g_deriv_ty);
296 auto c1_ty = ret_g_deriv_ty->as<
Pi>();
297 world.DLOG(
"c1_ty: (cn[X, cn[X+, cn E+]]) {}", c1_ty);
300 auto r_pb = c1->
var(1);
301 c1->app(
true, aug_cont, {res,
compose_cn(e_pb, r_pb)});
303 auto aug_app =
world.
app(aug_callee, {real_aug_args, c1});
304 world.DLOG(
"aug_app: {} : {}", aug_app, aug_app->type());
309 assert(
false &&
"should not be reached");
319 world.DLOG(
"Augment def {} : {}", def, def->
type());
322 if (
auto app = def->isa<
App>()) {
323 auto callee = app->callee();
324 auto arg = app->arg();
325 world.DLOG(
"Augment application: app {} with {}", callee, arg);
327 }
else if (
auto ext = def->isa<
Extract>()) {
328 auto tuple = ext->tuple();
329 auto index = ext->index();
330 world.DLOG(
"Augment extract: {} #[{}]", tuple,
index);
332 }
else if (
auto var = def->isa<
Var>()) {
333 world.DLOG(
"Augment variable: {}", var);
336 world.DLOG(
"Augment mut lambda: {}", lam);
338 }
else if (
auto lam = def->isa<
Lam>()) {
339 world.ELOG(
"Augment lambda: {}", lam);
340 assert(
false &&
"can not handle non-mutable lambdas");
341 }
else if (
auto lit = def->isa<
Lit>()) {
342 world.DLOG(
"Augment literal: {}", def);
344 }
else if (
auto tup = def->isa<
Tuple>()) {
345 world.DLOG(
"Augment tuple: {}", def);
347 }
else if (
auto pack = def->isa<
Pack>()) {
349 auto shape = pack->arity();
350 auto body = pack->body();
351 world.DLOG(
"Augment pack: {} : {} with {}", shape, shape->type(), body);
353 }
else if (
auto ax = def->isa<
Axiom>()) {
355 world.DLOG(
"Augment axiom: {} : {}", ax, ax->type());
356 world.DLOG(
"axiom curry: {}", ax->curry());
357 world.DLOG(
"axiom flags: {}", ax->flags());
358 auto diff_name = ax->
sym().str();
361 diff_name =
"internal_diff_" + diff_name;
362 world.DLOG(
"axiom name: {}", ax->sym());
363 world.DLOG(
"axiom function name: {}", diff_name);
367 world.ELOG(
"derivation not found: {}", diff_name);
369 world.ELOG(
"expected: {} : {}", diff_name, expected_type);
370 assert(
false &&
"unhandled axiom");
379 world.ELOG(
"did not expect to augment: {} : {}", def, def->
type());
381 assert(
false &&
"augment not implemented on this def");
const Def * callee() const
Def * set(size_t i, const Def *def)
Successively set from left to right.
std::string_view node_name() const
T * isa_mut() const
If this is *mut*able, it will cast constness away and perform a dynamic_cast to T.
Ref var(nat_t a, nat_t i)
auto projs(F f) const
Splits this Def via Def::projections into an Array (if A == -1_n) or std::array (otherwise).
Lam * set(Filter filter, const Def *body)
static const Lam * isa_basicblock(Ref d)
A (possibly paramterized) Tuple.
Pack * set(const Def *body)
A dependent function type.
static const Pi * isa_cn(Ref d)
Is this a continuation - i.e. is the Pi::codom mim::Bottom?
static const Pi * isa_basicblock(Ref d)
Is this a continuation (Pi::isa_cn) that is not Pi::isa_returning?
Helper class to retrieve Infer::arg if present.
Data constructor for a Sigma.
const Def * annex(Id id)
Lookup annex by Axiom::id.
Ref insert(Ref d, Ref i, Ref val)
Ref var(Ref type, Def *mut)
Ref arr(Ref shape, Ref body)
void debug_dump()
Dump in Debug build if World::log::level is Log::Level::Debug.
Ref app(Ref callee, Ref arg)
Sym sym(std::string_view)
Def * external(Sym name)
Lookup by name.
const Def * call(Id id, Args &&... args)
Ref pack(Ref arity, Ref body)
Ref extract(Ref d, Ref i)
const Type * type(Ref level)
Pack * mut_pack(Ref type)
Lam * mut_lam(const Pi *pi)
Ref augment_var(const Var *, Lam *, Lam *)
helper functions for augment
Ref augment_lit(const Lit *, Lam *, Lam *)
Ref augment_tuple(const Tuple *, Lam *, Lam *)
Ref augment_lam(Lam *, Lam *, Lam *)
Ref augment_(Ref, Lam *, Lam *)
Rewrites the given definition in a lambda environment.
Ref augment(Ref, Lam *, Lam *)
Applies to (open) expressions in a functional context.
Ref augment_app(const App *, Lam *, Lam *)
Ref augment_pack(const Pack *pack, Lam *f, Lam *f_diff)
Ref augment_extract(const Extract *, Lam *, Lam *)
The automatic differentiation Plugin
const Def * op_sum(const Def *T, Defs)
const Def * autodiff_type_fun(const Def *)
const Def * tangent_type_fun(const Def *)
const Def * zero_pullback(const Def *E, const Def *A)
const Pi * pullback_type(const Def *E, const Def *A)
computes pb type E* -> A* E - type of the expression (return type for a function) A - type of the arg...
const Def * op_cps2ds_dep(const Def *f)
Vector< const Def * > DefVec
void find_and_replace(std::string &str, std::string_view what, std::string_view repl)
Replaces all occurrences of what with repl.
const Def * compose_cn(const Def *f, const Def *g)
The high level view is: