MimIR 0.1
MimIR is my Intermediate Representation
This pass is the heart of AD.
#include <mim/plug/autodiff/pass/autodiff_eval.h>
Public Member Functions
  AutoDiffEval (PassMan &man)
  const Def * rewrite (const Def *) override
      Detect autodiff calls.
  const Def * derive (const Def *)
      Acts on top-level autodiff calls on closed terms.
  const Def * derive_ (const Def *)
      In addition to performing the derivation, registers the pullback and initializes the maps.
  const Def * augment (const Def *, Lam *, Lam *)
      Applies to (open) expressions in a functional context.
  const Def * augment_ (const Def *, Lam *, Lam *)
      Rewrites the given definition in a lambda environment.
  const Def * augment_var (const Var *, Lam *, Lam *)
      Helper functions for augment.
  const Def * augment_lam (Lam *, Lam *, Lam *)
  const Def * augment_extract (const Extract *, Lam *, Lam *)
  const Def * augment_app (const App *, Lam *, Lam *)
  const Def * augment_lit (const Lit *, Lam *, Lam *)
  const Def * augment_tuple (const Tuple *, Lam *, Lam *)
  const Def * augment_pack (const Pack *pack, Lam *f, Lam *f_diff)
Public Member Functions inherited from mim::RWPass< AutoDiffEval, Lam >
  RWPass (PassMan &man, std::string_view name)
  bool inspect () const override
      Should the PassMan even consider this pass?
  Lam * curr_mut () const
Public Member Functions inherited from mim::Pass
  Pass (PassMan &, std::string_view name)
  virtual ~Pass ()=default
  World & world ()
  PassMan & man ()
  const PassMan & man () const
  std::string_view name () const
  size_t index () const
  virtual const Def * rewrite (const Var *var)
  virtual const Def * rewrite (const Proxy *proxy)
  virtual undo_t analyze (const Def *)
  virtual undo_t analyze (const Var *)
  virtual undo_t analyze (const Proxy *)
  virtual bool fixed_point () const
  virtual void enter ()
      Invoked just before Pass::rewrite() is run on PassMan::curr_mut's body.
  virtual void prepare ()
      Invoked once before entering the main rewrite loop.
  const Proxy * proxy (const Def *type, Defs ops, u32 tag=0)
  const Proxy * isa_proxy (const Def *def, u32 tag=0)
      Checks whether the given def is a Proxy whose Proxy::pass matches this Pass's IPass::index.
  const Proxy * as_proxy (const Def *def, u32 tag=0)
Detailed Description
This pass is the heart of AD.
We replace an autodiff call on a function fun with the differentiated function.
Definition at line 11 of file autodiff_eval.h.
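To make the idea concrete, here is a minimal, self-contained C++ sketch of what replacing an autodiff call means, assuming a toy model in which a differentiated function returns its result together with a pullback. The names Pullback, Derived, and autodiff_square are hypothetical illustrations and not part of MimIR's API; the pass itself operates on MimIR Defs.

```cpp
#include <cassert>
#include <functional>
#include <utility>

// Hypothetical toy model, not MimIR's API: a pullback maps the cotangent
// (sensitivity) of the output back to the cotangent of the input.
using Pullback = std::function<double(double)>;

// The differentiated form of a function double -> double returns its
// primal result together with a pullback.
using Derived = std::function<std::pair<double, Pullback>(double)>;

// Hand-written derivative of f(x) = x * x, standing in for what a pass
// could produce when it replaces an `autodiff f` call.
Derived autodiff_square() {
    return [](double x) {
        double y = x * x;                                     // primal result
        Pullback pb = [x](double s) { return 2.0 * x * s; };  // d(x*x)/dx = 2x
        return std::make_pair(y, pb);
    };
}
```

Calling autodiff_square()(3.0) yields the value 9 and a pullback that maps the output cotangent 1 to the input cotangent 6.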
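Constructor & Destructor Documentation
mim::plug::autodiff::AutoDiffEval::AutoDiffEval (PassMan & man) [inline]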
Definition at line 13 of file autodiff_eval.h.
References mim::Pass::man(), mim::Pass::PassMan, and mim::RWPass< AutoDiffEval, Lam >::RWPass().
Member Function Documentation
const Def * mim::plug::autodiff::AutoDiffEval::augment (const Def *, Lam *, Lam *)
Applies to (open) expressions in a functional context.
Returns the rewritten expression and augments the partial and modular pullbacks. On the term itself, the rewrite is the identity up to renaming of variables; beyond that, it only adds pullbacks. To do so, some calls (e.g., axioms) are replaced by their derivatives. The transformation can be seen as augmenting the program with a dual computation that produces the derivatives.
Definition at line 15 of file autodiff_eval.cpp.
References augment_().
Referenced by augment_app(), augment_extract(), augment_lam(), augment_pack(), augment_tuple(), and derive_().
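The following sketch illustrates the augmentation idea on a hypothetical toy expression language in one variable x (the Expr type and this augment function are illustrative, not MimIR's Def hierarchy): the primal value is computed unchanged, and every case contributes a pullback. As with augment_lit, which references zero_pullback(), constants get a zero pullback here.

```cpp
#include <cassert>
#include <functional>
#include <memory>

// Hypothetical toy expression language in one variable x (not MimIR's Def).
struct Expr {
    enum Kind { Lit, Var, Add, Mul } kind;
    double value = 0.0;               // used by Lit
    std::shared_ptr<Expr> lhs, rhs;   // used by Add and Mul
};
using E = std::shared_ptr<Expr>;

E lit(double v) { return std::make_shared<Expr>(Expr{Expr::Lit, v, nullptr, nullptr}); }
E var() { return std::make_shared<Expr>(Expr{Expr::Var, 0.0, nullptr, nullptr}); }
E add(E a, E b) { return std::make_shared<Expr>(Expr{Expr::Add, 0.0, a, b}); }
E mul(E a, E b) { return std::make_shared<Expr>(Expr{Expr::Mul, 0.0, a, b}); }

// Result of augmenting an expression: the unchanged primal value plus a
// pullback from the output cotangent to the cotangent of x.
struct Augmented {
    double value;
    std::function<double(double)> pullback;
};

Augmented augment(const E& e, double x) {
    switch (e->kind) {
    case Expr::Lit: // constants do not depend on x: zero pullback
        return {e->value, [](double) { return 0.0; }};
    case Expr::Var: // x itself: the pullback passes the cotangent through
        return {x, [](double s) { return s; }};
    case Expr::Add: { // sum rule: both operands receive the cotangent
        Augmented a = augment(e->lhs, x), b = augment(e->rhs, x);
        return {a.value + b.value,
                [a, b](double s) { return a.pullback(s) + b.pullback(s); }};
    }
    case Expr::Mul: { // product rule: scale by the other operand's value
        Augmented a = augment(e->lhs, x), b = augment(e->rhs, x);
        return {a.value * b.value,
                [a, b](double s) { return a.pullback(s * b.value) + b.pullback(s * a.value); }};
    }
    }
    assert(false && "unknown expression kind");
    return {0.0, nullptr};
}
```

For f(x) = x * x + 3 at x = 2, augment returns the value 7 and a pullback that maps 1 to 4, the derivative 2x.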
const Def * mim::plug::autodiff::AutoDiffEval::augment_ (const Def *, Lam *, Lam *)
Rewrites the given definition in a lambda environment.
Definition at line 307 of file autodiff_rewrite_inner.cpp.
References augment_app(), augment_extract(), augment_lam(), augment_lit(), augment_pack(), augment_tuple(), augment_var(), mim::plug::autodiff::autodiff_type_fun(), mim::World::external(), mim::find_and_replace(), mim::Pass::index(), mim::Def::isa_mut(), mim::Def::node_name(), mim::World::sym(), mim::Def::type(), and mim::Pass::world().
Referenced by augment(), and augment_pack().
const Def * mim::plug::autodiff::AutoDiffEval::augment_app (const App * app, Lam * f, Lam * f_diff)
Definition at line 200 of file autodiff_rewrite_inner.cpp.
References mim::World::app(), mim::App::arg(), augment(), mim::App::callee(), mim::compose_cn(), mim::World::debug_dump(), mim::Pi::isa_basicblock(), mim::Pi::isa_cn(), mim::World::mut_lam(), mim::plug::direct::op_cps2ds_dep(), mim::Lam::set(), mim::Def::type(), mim::World::type(), mim::Def::var(), and mim::Pass::world().
Referenced by augment_().
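augment_app references mim::compose_cn(). Assuming this composes pullbacks according to the chain rule, the idea for an application g(f(x)) can be sketched with hypothetical toy types (not MimIR's API): the pullbacks are composed in the reverse order of evaluation.

```cpp
#include <cassert>
#include <functional>
#include <utility>

// Hypothetical toy types (not MimIR's API).
using Pullback = std::function<double(double)>;
using Derived  = std::function<std::pair<double, Pullback>(double)>;

// Differentiating an application g(f(x)): run f, feed its result to g, and
// compose the pullbacks in reverse order (chain rule): pb = pb_f . pb_g.
Derived compose(Derived g, Derived f) {
    return [g, f](double x) {
        std::pair<double, Pullback> fy = f(x);        // forward through f
        std::pair<double, Pullback> gz = g(fy.first); // forward through g
        Pullback pb_f = fy.second, pb_g = gz.second;
        Pullback pb = [pb_f, pb_g](double s) { return pb_f(pb_g(s)); };
        return std::make_pair(gz.first, pb);
    };
}
```

With f(x) = x * x and g(y) = 3y, the composition at x = 2 yields 12, and its pullback maps 1 to 12, matching d(3x^2)/dx = 6x.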
const Def * mim::plug::autodiff::AutoDiffEval::augment_extract (const Extract * ext, Lam * f, Lam * f_diff)
Definition at line 86 of file autodiff_rewrite_inner.cpp.
References augment(), mim::World::extract(), mim::Extract::index(), mim::Pass::index(), mim::World::insert(), mim::World::mut_lam(), mim::plug::autodiff::pullback_type(), mim::Def::set(), mim::Lam::set(), mim::Extract::tuple(), mim::Def::type(), mim::World::var(), and mim::Pass::world().
Referenced by augment_().
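augment_extract references World::insert() and pullback_type(). A plausible reading, sketched here with a hypothetical helper that is not MimIR code, is that the pullback of extracting component i from an n-tuple places the incoming cotangent at index i of an otherwise zero tangent.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical sketch (not MimIR's API): the pullback of extracting
// component i from an n-tuple places the incoming cotangent s at index i
// of an otherwise zero tangent tuple, i.e. an "insert" into zero.
std::vector<double> extract_pullback(std::size_t n, std::size_t i, double s) {
    assert(i < n && "extract index out of range");
    std::vector<double> grad(n, 0.0); // zero tangent of the whole tuple
    grad[i] = s;                      // insert the cotangent at index i
    return grad;
}
```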
const Def * mim::plug::autodiff::AutoDiffEval::augment_lam (Lam *, Lam *, Lam *)
Definition at line 28 of file autodiff_rewrite_inner.cpp.
References augment(), mim::plug::autodiff::autodiff_type_fun(), mim::Lam::body(), mim::World::call(), mim::Pi::dom(), mim::Lam::filter(), mim::Lam::isa_basicblock(), mim::World::mut_con(), mim::plug::autodiff::pullback_type(), mim::Def::sym(), mim::Def::type(), mim::Lam::type(), mim::Def::var(), mim::Pass::world(), and mim::plug::autodiff::zero_pullback().
Referenced by augment_().
const Def * mim::plug::autodiff::AutoDiffEval::augment_lit (const Lit *, Lam *, Lam *)
Definition at line 15 of file autodiff_rewrite_inner.cpp.
References mim::plug::autodiff::zero_pullback().
Referenced by augment_().
const Def * mim::plug::autodiff::AutoDiffEval::augment_pack (const Pack * pack, Lam * f, Lam * f_diff)
Definition at line 158 of file autodiff_rewrite_inner.cpp.
References mim::World::app(), mim::Def::arity(), augment(), augment_(), mim::Pack::body(), mim::World::mut_lam(), mim::World::mut_pack(), mim::plug::direct::op_cps2ds_dep(), mim::World::pack(), mim::plug::autodiff::pullback_type(), mim::Lam::set(), mim::Pack::set(), mim::plug::autodiff::tangent_type_fun(), mim::Def::type(), and mim::Pass::world().
Referenced by augment_().
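A pack replicates its body. Assuming the body does not depend on the pack index, the pullback of an n-element pack can be sketched (with a hypothetical helper, not MimIR's implementation) as applying the body's pullback to each incoming cotangent component and summing the contributions.

```cpp
#include <cassert>
#include <functional>
#include <vector>

using Pullback = std::function<double(double)>;

// Hypothetical sketch: every component of the pack is the same body, so all
// components share the body's pullback. The pack's pullback applies it to
// each cotangent component and accumulates the results.
double pack_pullback(const Pullback& body_pb, const std::vector<double>& s) {
    double total = 0.0;
    for (double si : s) total += body_pb(si);
    return total;
}
```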
const Def * mim::plug::autodiff::AutoDiffEval::augment_tuple (const Tuple * tup, Lam * f, Lam * f_diff)
Definition at line 123 of file autodiff_rewrite_inner.cpp.
References augment(), mim::World::mut_lam(), mim::plug::autodiff::op_sum(), mim::Def::projs(), mim::plug::autodiff::pullback_type(), mim::Def::set(), mim::Lam::set(), mim::plug::autodiff::tangent_type_fun(), mim::World::tuple(), mim::Def::type(), mim::World::var(), and mim::Pass::world().
Referenced by augment_().
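augment_tuple references op_sum() and Def::projs(). A hedged sketch of the underlying math, with hypothetical names that are not MimIR code: each tuple component has its own pullback into the shared input, and the tuple's pullback projects the cotangent tuple, applies each component's pullback, and sums the contributions (presumably the role op_sum plays for tangents).

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

using Pullback = std::function<double(double)>;

// Hypothetical sketch: component i of the tuple has pullback pbs[i] into the
// shared input. The tuple's pullback projects the incoming cotangent tuple s,
// applies pbs[i] to s[i], and sums the per-component results.
double tuple_pullback(const std::vector<Pullback>& pbs, const std::vector<double>& s) {
    assert(pbs.size() == s.size() && "one cotangent per tuple component");
    double total = 0.0;
    for (std::size_t i = 0; i < pbs.size(); ++i) total += pbs[i](s[i]);
    return total;
}
```

For the tuple (2x, 5x) and the cotangent tuple (1, 3), the result is 2 * 1 + 5 * 3 = 17.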
const Def * mim::plug::autodiff::AutoDiffEval::augment_var (const Var *, Lam *, Lam *)
Helper functions for augment.
Definition at line 21 of file autodiff_rewrite_inner.cpp.
Referenced by augment_().
const Def * mim::plug::autodiff::AutoDiffEval::derive (const Def *)
Acts on top-level autodiff calls on closed terms.
Definition at line 21 of file autodiff_eval.cpp.
References derive_().
Referenced by rewrite().
const Def * mim::plug::autodiff::AutoDiffEval::derive_ (const Def *)
In addition to performing the derivation, registers the pullback and initializes the maps.
Definition at line 7 of file autodiff_rewrite_toplevel.cpp.
References mim::Def::as_mut(), augment(), mim::plug::autodiff::autodiff_type_fun_pi(), mim::plug::autodiff::id_pullback(), mim::World::mut_lam(), mim::Def::set(), mim::Lam::set(), mim::World::tuple(), mim::Def::type(), mim::Lam::type(), mim::Def::var(), mim::World::var(), mim::Pass::world(), and mim::plug::autodiff::zero_pullback().
Referenced by derive().
const Def * mim::plug::autodiff::AutoDiffEval::rewrite (const Def *) override
Detect autodiff calls.
Reimplemented from mim::Pass.
Definition at line 27 of file autodiff_eval.cpp.
References derive(), mim::match(), and mim::Pass::world().
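rewrite() uses mim::match() to detect autodiff calls. The toy sketch below (a hypothetical Node type and rewrite function, not MimIR's Def, App, or match) shows only the shape of such a rewrite: an application whose callee is autodiff is replaced by a derived function, and every other node is returned unchanged.

```cpp
#include <cassert>
#include <memory>
#include <string>

// Hypothetical toy IR (not MimIR's Def/App/match): a node is either a
// named function or an application ("app") of a callee to an argument.
struct Node {
    std::string name;                  // function name, or "app" for applications
    std::shared_ptr<Node> callee, arg; // set only for applications
};
using N = std::shared_ptr<Node>;

// Detect `autodiff fun` applications and replace them with a derived
// function; every other node is left unchanged.
N rewrite(const N& n) {
    bool is_autodiff_call = n->name == "app" && n->callee && n->arg &&
                            n->callee->name == "autodiff";
    if (!is_autodiff_call) return n;
    // Stand-in for derive(): here we only name the differentiated function.
    return std::make_shared<Node>(Node{n->arg->name + "_diff", nullptr, nullptr});
}
```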