Thorin 1.9.0
The Higher ORder INtermediate representation
Loading...
Searching...
No Matches
autodiff_eval.cpp
Go to the documentation of this file.
2
3#include <iostream>
4
5#include <thorin/lam.h>
6
10#include "thorin/plug/mem/mem.h"
11
12namespace thorin::plug::autodiff {
13
14// TODO: maybe use template (https://codereview.stackexchange.com/questions/141961/memoization-via-template) to memoize
15Ref AutoDiffEval::augment(Ref def, Lam* f, Lam* f_diff) {
16 if (auto i = augmented.find(def); i != augmented.end()) return i->second;
17 augmented[def] = augment_(def, f, f_diff);
18 return augmented[def];
19}
20
22 if (auto i = derived.find(def); i != derived.end()) return i->second;
23 derived[def] = derive_(def);
24 return derived[def];
25}
26
28 if (auto ad_app = match<ad>(def); ad_app) {
29 // callee = autodiff T
30 // arg = function of type T
31 // (or operator)
32 auto arg = ad_app->arg();
33 world().DLOG("found a autodiff::autodiff of {}", arg);
34
35 if (arg->isa<Lam>()) return derive(arg);
36
37 // TODO: handle operators analogous
38
39 assert(0 && "not implemented");
40 return def;
41 }
42
43 return def;
44}
45
46} // namespace thorin::plug::autodiff
A function.
Definition lam.h:97
World & world()
Definition pass.h:296
Helper class to retrieve Infer::arg if present.
Definition def.h:87
Ref derive_(Ref)
In addition to the derivation, the pullback is registered and the maps are initialized.
Ref derive(Ref)
Acts on toplevel autodiff on closed terms:
Ref augment_(Ref, Lam *, Lam *)
Rewrites the given definition in a lambda environment.
Ref rewrite(Ref) override
Detect autodiff calls.
Ref augment(Ref, Lam *, Lam *)
Applies to (open) expressions in a functional context.
The automatic differentiation Plugin
Definition autodiff.h:7