MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
autodiff_eval.cpp
Go to the documentation of this file.
2
3#include <iostream>
4
5#include <mim/lam.h>
6
10#include "mim/plug/mem/mem.h"
11
12namespace mim::plug::autodiff {
13
14// TODO: maybe use template (https://codereview.stackexchange.com/questions/141961/memoization-via-template) to memoize
15Ref AutoDiffEval::augment(Ref def, Lam* f, Lam* f_diff) {
16 if (auto i = augmented.find(def); i != augmented.end()) return i->second;
17 augmented[def] = augment_(def, f, f_diff);
18 return augmented[def];
19}
20
22 if (auto i = derived.find(def); i != derived.end()) return i->second;
23 derived[def] = derive_(def);
24 return derived[def];
25}
26
28 if (auto ad_app = match<ad>(def); ad_app) {
29 // callee = autodiff T
30 // arg = function of type T
31 // (or operator)
32 auto arg = ad_app->arg();
33 world().DLOG("found a autodiff::autodiff of {}", arg);
34
35 if (arg->isa<Lam>()) return derive(arg);
36
37 // TODO: handle operators analogous
38
39 assert(0 && "not implemented");
40 return def;
41 }
42
43 return def;
44}
45
46} // namespace mim::plug::autodiff
A function.
Definition lam.h:103
World & world()
Definition pass.h:296
Helper class to retrieve Infer::arg if present.
Definition def.h:86
Ref derive_(Ref)
In addition to the derivation, the pullback is registered and the maps are initialized.
Ref augment_(Ref, Lam *, Lam *)
Rewrites the given definition in a lambda environment.
Ref augment(Ref, Lam *, Lam *)
Applies to (open) expressions in a functional context.
Ref rewrite(Ref) override
Detect autodiff calls.
Ref derive(Ref)
Acts on top-level autodiff calls on closed terms:
The automatic differentiation Plugin
Definition autodiff.h:6
auto match(Ref def)
Definition axiom.h:112