MimIR 0.1
MimIR is my Intermediate Representation
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages Concepts
autodiff_eval.cpp
Go to the documentation of this file.
2
3#include <iostream>
4
5#include <mim/lam.h>
6
10#include "mim/plug/mem/mem.h"
11
12namespace mim::plug::autodiff {
13
14// TODO: maybe use template (https://codereview.stackexchange.com/questions/141961/memoization-via-template) to memoize
15const Def* AutoDiffEval::augment(const Def* def, Lam* f, Lam* f_diff) {
16 if (auto i = augmented.find(def); i != augmented.end()) return i->second;
17 augmented[def] = augment_(def, f, f_diff);
18 return augmented[def];
19}
20
21const Def* AutoDiffEval::derive(const Def* def) {
22 if (auto i = derived.find(def); i != derived.end()) return i->second;
23 derived[def] = derive_(def);
24 return derived[def];
25}
26
27const Def* AutoDiffEval::rewrite(const Def* def) {
28 if (auto ad_app = Axm::isa<ad>(def); ad_app) {
29 // callee = autodiff T
30 // arg = function of type T
31 // (or operator)
32 auto arg = ad_app->arg();
33 world().DLOG("found a autodiff::autodiff of {}", arg);
34
35 if (arg->isa<Lam>()) return derive(arg);
36
37 // TODO: handle operators analogous
38
39 assert(0 && "not implemented");
40 return def;
41 }
42
43 return def;
44}
45
46} // namespace mim::plug::autodiff
static auto isa(const Def *def)
Definition axm.h:104
Base class for all Defs.
Definition def.h:197
A function.
Definition lam.h:106
World & world()
Definition pass.h:296
const Def * augment(const Def *, Lam *, Lam *)
Applies to (open) expressions in a functional context.
const Def * derive(const Def *)
Acts on top-level autodiff on closed terms:
const Def * augment_(const Def *, Lam *, Lam *)
Rewrites the given definition in a lambda environment.
const Def * derive_(const Def *)
In addition to the derivation, the pullback is registered and the maps are initialized.
const Def * rewrite(const Def *) override
Detect autodiff calls.
The automatic differentiation Plugin
Definition autodiff.h:6