MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
eval.cpp
Go to the documentation of this file.
2
3#include <mim/lam.h>
4
6
7namespace mim::plug::autodiff {
8
9// TODO: maybe use template (https://codereview.stackexchange.com/questions/141961/memoization-via-template) to memoize
10const Def* Eval::augment(const Def* def, Lam* f, Lam* f_diff) {
11 if (auto i = augmented.find(def); i != augmented.end()) return i->second;
12 augmented[def] = augment_(def, f, f_diff);
13 return augmented[def];
14}
15
16const Def* Eval::derive(const Def* def) {
17 if (auto i = derived.find(def); i != derived.end()) return i->second;
18 derived[def] = derive_(def);
19 return derived[def];
20}
21
22const Def* Eval::rewrite(const Def* def) {
23 if (auto ad_app = Axm::isa<ad>(def); ad_app) {
24 // callee = autodiff T
25 // arg = function of type T
26 // (or operator)
27 auto arg = ad_app->arg();
28 DLOG("found a autodiff::autodiff of {}", arg);
29
30 if (arg->isa<Lam>()) return derive(arg);
31
32 // TODO: handle operators analogous
33
34 assert(0 && "not implemented");
35 return def;
36 }
37
38 return def;
39}
40
41} // namespace mim::plug::autodiff
static auto isa(const Def *def)
Definition axm.h:107
Base class for all Defs.
Definition def.h:251
A function.
Definition lam.h:111
const Def * derive_(const Def *)
In addition to the derivation, the pullback is registered and the maps are initialized.
const Def * augment_(const Def *, Lam *, Lam *)
Rewrites the given definition in a lambda environment.
const Def * augment(const Def *, Lam *, Lam *)
Applies to (open) expressions in a functional context.
Definition eval.cpp:10
const Def * rewrite(const Def *) override
Detect autodiff calls.
Definition eval.cpp:22
const Def * derive(const Def *)
Acts on toplevel autodiff on closed terms:
Definition eval.cpp:16
#define DLOG(...)
Vaporizes to nothingness in Debug build.
Definition log.h:95
The automatic differentiation Plugin
Definition autodiff.h:6