11using namespace std::literals;
17 partial_pullback[lit] = pb;
22 assert(augmented.count(var));
23 auto aug_var = augmented[var];
24 assert(partial_pullback.count(aug_var));
32 if (augmented.count(lam)) {
37 world().DLOG(
"already augmented {} : {} to {} : {}", lam, lam->
type(), augmented[lam], augmented[lam]->type());
38 return augmented[lam];
43 || lam->
sym().view().find(
"_cont") != std::string::npos) {
52 world().DLOG(
"found an open continuation {} : {}", lam, lam->
type());
53 auto cont_dom = lam->
type()->
dom();
56 world().DLOG(
"augmented domain {}", aug_dom);
57 world().DLOG(
"pb type is {}", pb_ty);
58 auto aug_lam =
world().
mut_con({aug_dom, pb_ty})->set(
"aug_"s + lam->
sym().str());
59 auto aug_var = aug_lam->
var((
nat_t)0);
60 augmented[lam->
var()] = aug_var;
61 augmented[lam] = aug_lam;
62 derived[lam] = aug_lam;
63 auto pb = aug_lam->var(1);
64 partial_pullback[aug_var] = pb;
68 aug_lam->set(lam->
filter(), new_body);
71 partial_pullback[aug_lam] = lam_pb;
72 world().DLOG(
"augmented {} : {}", lam, lam->
type());
73 world().DLOG(
"to {} : {}", aug_lam, aug_lam->type());
74 world().DLOG(
"ppb for lam cont: {}", lam_pb);
78 world().DLOG(
"found a closed function call {} : {}", lam, lam->
type());
82 world().DLOG(
"augmented function is {} : {}", aug_lam, aug_lam->type());
87 auto tuple = ext->
tuple();
90 auto aug_tuple =
augment(tuple, f, f_diff);
94 world().DLOG(
"tuple was: {} : {}", tuple, tuple->type());
95 world().DLOG(
"aug tuple: {} : {}", aug_tuple, aug_tuple->type());
96 if (shadow_pullback.count(aug_tuple)) {
97 auto shadow_tuple_pb = shadow_pullback[aug_tuple];
98 world().DLOG(
"Shadow pullback: {} : {}", shadow_tuple_pb, shadow_tuple_pb->type());
106 assert(partial_pullback.count(aug_tuple));
107 auto tuple_pb = partial_pullback[aug_tuple];
110 world().DLOG(
"Pullback: {} : {}", pb_fun, pb_fun->type());
111 auto pb_tangent = pb_fun->
var(0_s)->
set(
"s");
112 auto tuple_tan =
world().
insert(
world().call<zero>(aug_tuple->type()), aug_index, pb_tangent)->
set(
"tup_s");
113 pb_fun->app(
true, tuple_pb, {tuple_tan, pb_fun->var(1) });
118 partial_pullback[aug_ext] = pb;
125 auto aug_ops = tup->
projs([&](
Ref op) ->
const Def* {
return augment(op, f, f_diff); });
128 auto pbs =
DefVec(
Defs(aug_ops), [&](
Ref op) {
return partial_pullback[op]; });
129 world().DLOG(
"tuple pbs {,}", pbs);
132 shadow_pullback[aug_tup] = shadow_pb;
141 world().DLOG(
"Augmented tuple: {} : {}", aug_tup, aug_tup->type());
142 world().DLOG(
"Tuple Pullback: {} : {}", pb, pb->type());
143 world().DLOG(
"shadow pb: {} : {}", shadow_pb, shadow_pb->type());
145 auto pb_tangent = pb->
var(0_s)->
set(
"tup_s");
148 return world().app(direct::op_cps2ds_dep(pbs[i]), world().extract(pb_tangent, i));
150 pb->app(
true, pb->var(1),
153 partial_pullback[aug_tup] = pb;
159 auto shape = pack->
arity();
160 auto body = pack->
body();
162 auto aug_shape =
augment_(shape, f, f_diff);
163 auto aug_body =
augment(body, f, f_diff);
165 auto aug_pack =
world().
pack(aug_shape, aug_body);
167 assert(partial_pullback[aug_body] &&
"pack pullback should exists");
169 auto body_pb = partial_pullback[aug_body];
170 auto pb_pack =
world().
pack(aug_shape, body_pb);
171 shadow_pullback[aug_pack] = pb_pack;
173 world().DLOG(
"shadow pb of pack: {} : {}", pb_pack, pb_pack->type());
178 world().DLOG(
"pb of pack: {} : {}", pb, pb_type);
188 world().DLOG(
"app pb of pack: {} : {}", app_pb, app_pb->type());
190 auto sumup =
world().
app(
world().annex<sum>(), {aug_shape, f_arg_ty_diff});
191 world().DLOG(
"sumup: {} : {}", sumup, sumup->type());
193 pb->
app(
true, pb->var(1),
world().app(sumup, app_pb));
195 partial_pullback[aug_pack] = pb;
201 auto callee = app->
callee();
202 auto arg = app->
arg();
204 auto aug_arg =
augment(arg, f, f_diff);
205 auto aug_callee =
augment(callee, f, f_diff);
207 world().DLOG(
"augmented argument <{}> {} : {}", aug_arg->unique_name(), aug_arg, aug_arg->type());
208 world().DLOG(
"augmented callee <{}> {} : {}", aug_callee->unique_name(), aug_callee, aug_callee->type());
212 world().DLOG(
"wrapped augmented callee: <{}> {} : {}", aug_callee->unique_name(), aug_callee,
217 if (app->
type()->isa<
Pi>()) {
218 world().DLOG(
"Nested application callee: {} : {}", aug_callee, aug_callee->type());
219 world().DLOG(
"Nested application arg: {} : {}", aug_arg, aug_arg->type());
220 auto aug_app =
world().
app(aug_callee, aug_arg);
221 world().DLOG(
"Nested application result: <{}> {} : {}", aug_app->unique_name(), aug_app, aug_app->type());
235 world().DLOG(
"continuation {} : {} => {} : {}", callee, callee->type(), aug_callee, aug_callee->type());
237 auto arg_pb = partial_pullback[aug_arg];
238 auto aug_app =
world().
app(aug_callee, {aug_arg, arg_pb});
239 world().DLOG(
"Augmented application: {} : {}", aug_app, aug_app->type());
245 auto aug_app =
world().
app(aug_callee, aug_arg);
246 world().DLOG(
"Augmented application: <{}> {} : {}", aug_app->unique_name(), aug_app, aug_app->type());
248 world().DLOG(
"ds function: {} : {}", aug_app, aug_app->type());
250 auto [aug_res, fun_pb] = aug_app->projs<2>();
253 auto arg_pb = partial_pullback[aug_arg];
257 world().DLOG(
"function pullback: {} : {}", fun_pb, fun_pb->type());
258 world().DLOG(
"argument pullback: {} : {}", arg_pb, arg_pb->type());
260 world().DLOG(
"result pullback: {} : {}", res_pb, res_pb->type());
261 partial_pullback[aug_res] = res_pb;
277 auto g_deriv = aug_callee;
278 world().DLOG(
"g: {} : {}", g, g->type());
279 world().DLOG(
"g': {} : {}", g_deriv, g_deriv->type());
281 auto [real_aug_args, aug_cont] = aug_arg->projs<2>();
282 world().DLOG(
"real_aug_args: {} : {}", real_aug_args, real_aug_args->type());
283 world().DLOG(
"aug_cont: {} : {}", aug_cont, aug_cont->type());
284 auto e_pb = partial_pullback[real_aug_args];
285 world().DLOG(
"e_pb (arg_pb): {} : {}", e_pb, e_pb->type());
288 auto ret_g_deriv_ty = g_deriv->
type()->as<
Pi>()->dom(1);
289 world().DLOG(
"ret_g_deriv_ty: {} ", ret_g_deriv_ty);
290 auto c1_ty = ret_g_deriv_ty->as<
Pi>();
291 world().DLOG(
"c1_ty: (cn[X, cn[X+, cn E+]]) {}", c1_ty);
294 auto r_pb = c1->
var(1);
295 c1->app(
true, aug_cont, {res,
compose_cn(e_pb, r_pb)});
297 auto aug_app =
world().
app(aug_callee, {real_aug_args, c1});
298 world().DLOG(
"aug_app: {} : {}", aug_app, aug_app->type());
303 assert(
false &&
"should not be reached");
312 world().DLOG(
"Augment def {} : {}", def, def->
type());
315 if (
auto app = def->isa<
App>()) {
316 auto callee = app->callee();
317 auto arg = app->arg();
318 world().DLOG(
"Augment application: app {} with {}", callee, arg);
320 }
else if (
auto ext = def->isa<
Extract>()) {
321 auto tuple = ext->tuple();
322 auto index = ext->index();
323 world().DLOG(
"Augment extract: {} #[{}]", tuple,
index);
325 }
else if (
auto var = def->isa<
Var>()) {
326 world().DLOG(
"Augment variable: {}", var);
329 world().DLOG(
"Augment mut lambda: {}", lam);
331 }
else if (
auto lam = def->isa<
Lam>()) {
332 world().ELOG(
"Augment lambda: {}", lam);
333 assert(
false &&
"can not handle non-mutable lambdas");
334 }
else if (
auto lit = def->isa<
Lit>()) {
335 world().DLOG(
"Augment literal: {}", def);
337 }
else if (
auto tup = def->isa<
Tuple>()) {
338 world().DLOG(
"Augment tuple: {}", def);
340 }
else if (
auto pack = def->isa<
Pack>()) {
342 auto shape = pack->arity();
343 auto body = pack->body();
344 world().DLOG(
"Augment pack: {} : {} with {}", shape, shape->type(), body);
346 }
else if (
auto ax = def->isa<
Axiom>()) {
348 world().DLOG(
"Augment axiom: {} : {}", ax, ax->type());
349 world().DLOG(
"axiom curry: {}", ax->curry());
350 world().DLOG(
"axiom flags: {}", ax->flags());
351 auto diff_name = ax->
sym().str();
354 diff_name =
"internal_diff_" + diff_name;
355 world().DLOG(
"axiom name: {}", ax->sym());
356 world().DLOG(
"axiom function name: {}", diff_name);
360 world().ELOG(
"derivation not found: {}", diff_name);
362 world().ELOG(
"expected: {} : {}", diff_name, expected_type);
363 assert(
false &&
"unhandled axiom");
372 world().ELOG(
"did not expect to augment: {} : {}", def, def->
type());
374 assert(
false &&
"augment not implemented on this def");
std::string_view node_name() const
Def * set(size_t i, Ref)
Successively set from left to right.
T * isa_mut() const
If this is *mut*able, it will cast constness away and perform a dynamic_cast to T.
Ref var(nat_t a, nat_t i)
auto projs(F f) const
Splits this Def via Def::projections into an Array (if A == std::dynamic_extent) or std::array (other...
Lam * set(Filter filter, Ref body)
static const Lam * isa_basicblock(Ref d)
A (possibly paramterized) Tuple.
A dependent function type.
static const Pi * isa_cn(Ref d)
Is this a continuation - i.e. is the Pi::codom mim::Bottom?
static const Pi * isa_basicblock(Ref d)
Is this a continuation (Pi::isa_cn) that is not Pi::isa_returning?
Helper class to retrieve Infer::arg if present.
Data constructor for a Sigma.
Ref insert(Ref d, Ref i, Ref val)
Ref var(Ref type, Def *mut)
void debug_dump()
Dump in Debug build if World::log::level is Log::Level::Debug.
Sym sym(std::string_view)
Def * external(Sym name)
Lookup by name.
Ref pack(Ref arity, Ref body)
const Def * call(Id id, Args &&... args)
Complete curried call of annexes obeying implicits.
Ref extract(Ref d, Ref i)
Ref app(Ref callee, Ref arg)
const Type * type(Ref level)
Pack * mut_pack(Ref type)
Lam * mut_lam(const Pi *pi)
Ref augment_var(const Var *, Lam *, Lam *)
helper functions for augment
Ref augment_lit(const Lit *, Lam *, Lam *)
Ref augment_tuple(const Tuple *, Lam *, Lam *)
Ref augment_lam(Lam *, Lam *, Lam *)
Ref augment_(Ref, Lam *, Lam *)
Rewrites the given definition in a lambda environment.
Ref augment(Ref, Lam *, Lam *)
Applies to (open) expressions in a functional context.
Ref augment_app(const App *, Lam *, Lam *)
Ref augment_pack(const Pack *pack, Lam *f, Lam *f_diff)
Ref augment_extract(const Extract *, Lam *, Lam *)
The automatic differentiation Plugin
const Def * op_sum(const Def *T, Defs)
const Def * autodiff_type_fun(const Def *)
const Def * tangent_type_fun(const Def *)
const Def * zero_pullback(const Def *E, const Def *A)
const Pi * pullback_type(const Def *E, const Def *A)
computes pb type E* -> A* E - type of the expression (return type for a function) A - type of the arg...
Vector< const Def * > DefVec
void find_and_replace(std::string &str, std::string_view what, std::string_view repl)
Replaces all occurrences of what with repl.
Ref compose_cn(Ref f, Ref g)
The high level view is: