MimIR 0.1
MimIR is my Intermediate Representation
This pass is the heart of AD.
#include <mim/plug/autodiff/pass/autodiff_eval.h>
Public Member Functions
  AutoDiffEval (PassMan &man)
  Ref rewrite (Ref) override
      Detect autodiff calls.
  Ref derive (Ref)
      Acts on toplevel autodiff on closed terms:
  Ref derive_ (Ref)
      In addition to the derivation, the pullback is registered and the maps are initialized.
  Ref augment (Ref, Lam *, Lam *)
      Applies to (open) expressions in a functional context.
  Ref augment_ (Ref, Lam *, Lam *)
      Rewrites the given definition in a lambda environment.
  Ref augment_var (const Var *, Lam *, Lam *)
      Helper functions for augment.
  Ref augment_lam (Lam *, Lam *, Lam *)
  Ref augment_extract (const Extract *, Lam *, Lam *)
  Ref augment_app (const App *, Lam *, Lam *)
  Ref augment_lit (const Lit *, Lam *, Lam *)
  Ref augment_tuple (const Tuple *, Lam *, Lam *)
  Ref augment_pack (const Pack *pack, Lam *f, Lam *f_diff)
Public Member Functions inherited from mim::RWPass<AutoDiffEval, Lam>
  RWPass (PassMan &man, std::string_view name)
  bool inspect () const override
      Should the PassMan even consider this pass?
  Lam * curr_mut () const
Public Member Functions inherited from mim::Pass
  Pass (PassMan &, std::string_view name)
  virtual ~Pass ()=default
  World & world ()
  PassMan & man ()
  const PassMan & man () const
  std::string_view name () const
  size_t index () const
  virtual Ref rewrite (const Var *var)
  virtual Ref rewrite (const Proxy *proxy)
  virtual undo_t analyze (Ref)
  virtual undo_t analyze (const Var *)
  virtual undo_t analyze (const Proxy *)
  virtual bool fixed_point () const
  virtual void enter ()
      Invoked just before Pass::rewrite is applied to PassMan::curr_mut's body.
  virtual void prepare ()
      Invoked once before entering the main rewrite loop.
  const Proxy * proxy (Ref type, Defs ops, u32 tag=0)
  const Proxy * isa_proxy (Ref def, u32 tag=0)
      Checks whether the given def is a Proxy whose Proxy::pass matches this Pass's IPass::index.
  const Proxy * as_proxy (Ref def, u32 tag=0)
Detailed Description
This pass is the heart of AD.
We replace an autodiff fun call with the differentiated function.
Definition at line 10 of file autodiff_eval.h.
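Constructor & Destructor Documentation
mim::plug::autodiff::AutoDiffEval::AutoDiffEval (PassMan &man) [inline]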
Definition at line 12 of file autodiff_eval.h.
Member Function Documentation
Ref mim::plug::autodiff::AutoDiffEval::augment (Ref, Lam *, Lam *)
Applies to (open) expressions in a functional context.
Returns the rewritten expression and augments the partial and modular pullbacks. Up to renaming of variables, the rewrite is the identity on the term; beyond that, only pullbacks are added. To do so, some calls (e.g., axioms) are replaced by their derivatives. The transformation can be seen as augmenting the term with a dual computation that produces the derivatives.
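The augmentation described above can be illustrated on scalars. This is a minimal sketch, not MimIR's implementation: it assumes a toy setting in which every value depends on a single input x and a pullback is an ordinary std::function<double(double)>. The names Augmented, aug_input, aug_const, aug_add, aug_mul and f_aug are hypothetical. Each primal value is computed exactly as before and only a pullback is attached, which is the sense in which the rewrite is the identity on the term.

```cpp
#include <cassert>
#include <functional>
#include <utility>

// Hypothetical toy model, not MimIR's API: an augmented scalar computation returns its
// primal value together with a pullback that maps an output cotangent s to the
// cotangent of the single input x.
using Pullback = std::function<double(double)>;
using Augmented = std::pair<double, Pullback>;

// The input itself: its pullback is the identity.
Augmented aug_input(double x) { return {x, [](double s) { return s; }}; }

// A constant does not depend on x: its pullback is zero.
Augmented aug_const(double c) { return {c, [](double) { return 0.0; }}; }

// Primitive operations keep their primal value; only the pullback is new.
Augmented aug_add(const Augmented& a, const Augmented& b) {
    Pullback pa = a.second, pb = b.second;
    return {a.first + b.first, [pa, pb](double s) { return pa(s) + pb(s); }};
}

Augmented aug_mul(const Augmented& a, const Augmented& b) {
    double va = a.first, vb = b.first;
    Pullback pa = a.second, pb = b.second;
    // Product rule in reverse mode: each factor receives s times the other factor.
    return {va * vb, [=](double s) { return pa(s * vb) + pb(s * va); }};
}

// f(x) = x * x + 3 * x: the value is f(x), and the pullback maps s to f'(x) * s.
Augmented f_aug(double x) {
    Augmented xa = aug_input(x);
    return aug_add(aug_mul(xa, xa), aug_mul(aug_const(3.0), xa));
}
```

For x = 2, f_aug returns the value 10 and a pullback that maps 1 to 7, which equals f'(2) = 2 * 2 + 3.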
Definition at line 15 of file autodiff_eval.cpp.
References augment_().
Referenced by augment_app(), augment_extract(), augment_lam(), augment_pack(), augment_tuple(), and derive_().
Ref mim::plug::autodiff::AutoDiffEval::augment_ (Ref, Lam *, Lam *)
Rewrites the given definition in a lambda environment.
Definition at line 313 of file autodiff_rewrite_inner.cpp.
References augment_app(), augment_extract(), augment_lam(), augment_lit(), augment_pack(), augment_tuple(), augment_var(), mim::plug::autodiff::autodiff_type_fun(), mim::World::external(), mim::find_and_replace(), mim::Pass::index(), mim::Def::isa_mut(), mim::Def::node_name(), mim::World::sym(), mim::Def::type(), mim::Def::world(), and mim::Pass::world().
Referenced by augment(), and augment_pack().
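augment_ dispatches on the kind of definition and delegates to one helper per kind, as its references to augment_var, augment_lit, augment_app and the other helpers indicate. The sketch below mirrors that shape on a hypothetical three-node toy IR (Lit, Var, Add). It assumes, without confirmation from this page, that augment acts as a caching wrapper around augment_, so shared subterms are rewritten only once. To stay short, it represents each pullback by the scalar derivative it multiplies with, which suffices for a single scalar input.

```cpp
#include <cassert>
#include <map>
#include <stdexcept>

// Hypothetical toy IR, not MimIR's Def hierarchy.
struct Node { virtual ~Node() = default; };
struct Lit : Node { double value; explicit Lit(double v) : value(v) {} };
struct Var : Node {};
struct Add : Node {
    const Node* lhs;
    const Node* rhs;
    Add(const Node* l, const Node* r) : lhs(l), rhs(r) {}
};

// Result of augmenting a node: its value at x and the derivative its pullback scales by.
struct Dual { double value; double deriv; };

class Augmenter {
public:
    explicit Augmenter(double x) : x_(x) {}

    // Caching entry point (assumed analogue of augment): each node is rewritten once.
    Dual augment(const Node* n) {
        if (auto it = cache_.find(n); it != cache_.end()) return it->second;
        Dual d = augment_(n);
        cache_.emplace(n, d);
        return d;
    }

    int rewrites = 0; // number of calls to augment_, for illustration

private:
    // Dispatch on the node kind (assumed analogue of augment_ calling its helpers).
    Dual augment_(const Node* n) {
        ++rewrites;
        if (auto lit = dynamic_cast<const Lit*>(n)) return augment_lit(lit);
        if (dynamic_cast<const Var*>(n)) return augment_var();
        if (auto add = dynamic_cast<const Add*>(n)) return augment_add(add);
        throw std::logic_error("unsupported node kind");
    }

    Dual augment_lit(const Lit* lit) { return {lit->value, 0.0}; } // zero pullback
    Dual augment_var() { return {x_, 1.0}; }                        // identity pullback
    Dual augment_add(const Add* add) {
        Dual l = augment(add->lhs), r = augment(add->rhs);
        return {l.value + r.value, l.deriv + r.deriv};
    }

    double x_;
    std::map<const Node*, Dual> cache_;
};
```

For the term (x + x) + 3 at x = 2, the sketch yields the value 7 and derivative 2, and the shared Var node is rewritten only once.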
Ref mim::plug::autodiff::AutoDiffEval::augment_app (const App *, Lam *, Lam *)
Definition at line 205 of file autodiff_rewrite_inner.cpp.
References mim::World::app(), mim::App::arg(), augment(), mim::App::callee(), mim::compose_cn(), mim::World::debug_dump(), mim::Pi::isa_basicblock(), mim::Pi::isa_cn(), mim::World::mut_lam(), mim::plug::direct::op_cps2ds_dep(), mim::Lam::set(), mim::Def::type(), mim::World::type(), mim::Def::var(), mim::Def::world(), and mim::Pass::world().
Referenced by augment_().
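For an application, the chain rule says that the pullback of callee(arg) first maps the result cotangent through the callee's pullback and then through the argument's pullback. The reference to mim::compose_cn suggests that augment_app builds such a composition, but MimIR works on continuations, so the following scalar sketch with the hypothetical helpers compose and app_pullback is only an analogy.

```cpp
#include <cassert>
#include <functional>
#include <utility>

// Hypothetical stand-in for a pullback on scalars (not MimIR's continuation types).
using Pullback = std::function<double(double)>;

// compose(first, second) applies first and then second, i.e. it returns second o first.
Pullback compose(Pullback first, Pullback second) {
    return [first = std::move(first), second = std::move(second)](double s) {
        return second(first(s));
    };
}

// For callee(arg) with arg = a(x):
//  - callee_pb maps the cotangent of the result to the cotangent of arg,
//  - arg_pb maps the cotangent of arg to the cotangent of x.
// The pullback of the application runs callee_pb and then arg_pb (chain rule).
Pullback app_pullback(Pullback callee_pb, Pullback arg_pb) {
    return compose(std::move(callee_pb), std::move(arg_pb));
}
```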
Ref mim::plug::autodiff::AutoDiffEval::augment_extract (const Extract *ext, Lam *f, Lam *f_diff)
Definition at line 87 of file autodiff_rewrite_inner.cpp.
References augment(), mim::World::call(), mim::World::extract(), mim::Extract::index(), mim::Pass::index(), mim::World::insert(), mim::World::mut_lam(), mim::plug::autodiff::pullback_type(), mim::Def::set(), mim::Lam::set(), mim::Extract::tuple(), mim::Def::type(), mim::World::var(), mim::Def::world(), and mim::Pass::world().
Referenced by augment_().
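A projection t#i only touches one component of t. Assuming the usual reverse-mode rule (the references to World::insert and pullback_type are consistent with it, but the exact construction is not shown on this page), its pullback places the incoming cotangent at position i of an otherwise-zero tuple and hands that tuple to the pullback of t. The names extract_pullback, TuplePullback and ScalarPullback are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <utility>
#include <vector>

// Hypothetical toy: a tuple is a std::vector<double>; its pullback maps a tuple-shaped
// cotangent to the cotangent of the scalar input x.
using TuplePullback = std::function<double(const std::vector<double>&)>;
using ScalarPullback = std::function<double(double)>;

// Pullback of the projection t#index: put the incoming cotangent s at `index` of an
// otherwise-zero tuple and pass that tuple to the pullback of t.
ScalarPullback extract_pullback(TuplePullback tuple_pb, std::size_t arity, std::size_t index) {
    assert(index < arity);
    return [tuple_pb = std::move(tuple_pb), arity, index](double s) {
        std::vector<double> seed(arity, 0.0); // zero cotangent for every component...
        seed[index] = s;                      // ...except the extracted one
        return tuple_pb(seed);
    };
}
```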
Ref mim::plug::autodiff::AutoDiffEval::augment_lam (Lam *, Lam *, Lam *)
Definition at line 28 of file autodiff_rewrite_inner.cpp.
References augment(), mim::plug::autodiff::autodiff_type_fun(), mim::Lam::body(), mim::World::call(), mim::Pi::dom(), mim::Lam::filter(), mim::Lam::isa_basicblock(), mim::World::mut_con(), mim::plug::autodiff::pullback_type(), mim::Def::sym(), mim::Lam::type(), mim::Def::var(), mim::Def::world(), mim::Pass::world(), and mim::plug::autodiff::zero_pullback().
Referenced by augment_().
Ref mim::plug::autodiff::AutoDiffEval::augment_lit (const Lit *, Lam *, Lam *)
Definition at line 15 of file autodiff_rewrite_inner.cpp.
References mim::plug::autodiff::zero_pullback().
Referenced by augment_().
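A literal does not depend on any input, which is why augment_lit only needs zero_pullback. The sketch below shows the idea in isolation; make_zero_pullback is a hypothetical stand-in and makes no claim about the signature of mim::plug::autodiff::zero_pullback.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical toy, not MimIR's zero_pullback: a pullback maps the cotangent of one value
// to the cotangents of all num_inputs inputs of the enclosing function.
using Pullback = std::function<std::vector<double>(double)>;

// A literal does not depend on any input, so whatever cotangent arrives,
// every input receives zero.
Pullback make_zero_pullback(std::size_t num_inputs) {
    return [num_inputs](double /*s*/) { return std::vector<double>(num_inputs, 0.0); };
}
```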
Ref mim::plug::autodiff::AutoDiffEval::augment_pack (const Pack *pack, Lam *f, Lam *f_diff)
Definition at line 162 of file autodiff_rewrite_inner.cpp.
References mim::World::annex(), mim::World::app(), mim::Def::arity(), mim::World::arr(), augment(), augment_(), mim::Pack::body(), mim::World::extract(), mim::World::mut_lam(), mim::World::mut_pack(), mim::plug::direct::op_cps2ds_dep(), mim::World::pack(), mim::plug::autodiff::pullback_type(), mim::Lam::set(), mim::Pack::set(), mim::plug::autodiff::tangent_type_fun(), mim::Def::type(), mim::Def::world(), and mim::Pass::world().
Referenced by augment_().
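A pack builds an array from a body that may depend on the element index. Assuming each element's pullback is available, the pullback of the whole pack sums the contributions of all elements, because every element may depend on the same input. The sketch uses scalars and the hypothetical names IndexedPullback, ArrayPullback and pack_pullback; it does not reproduce how augment_pack constructs its mutable pack.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <utility>
#include <vector>

// Hypothetical toy: the pack builds an array of n elements, element i being body(i).
// body_pb(i, s) maps the cotangent s of element i to the cotangent of the input x.
using IndexedPullback = std::function<double(std::size_t, double)>;
using ArrayPullback = std::function<double(const std::vector<double>&)>;

// Every element may depend on x, so the cotangent of x is the sum of the
// element-wise contributions.
ArrayPullback pack_pullback(IndexedPullback body_pb, std::size_t n) {
    return [body_pb = std::move(body_pb), n](const std::vector<double>& seed) {
        assert(seed.size() == n);
        double acc = 0.0;
        for (std::size_t i = 0; i < n; ++i) acc += body_pb(i, seed[i]);
        return acc;
    };
}
```

For example, for the array whose element i is i * x, each element pullback scales by i.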
Ref mim::plug::autodiff::AutoDiffEval::augment_tuple (const Tuple *, Lam *, Lam *)
Definition at line 126 of file autodiff_rewrite_inner.cpp.
References augment(), mim::World::mut_lam(), mim::plug::autodiff::op_sum(), mim::Def::projs(), mim::plug::autodiff::pullback_type(), mim::Def::set(), mim::Lam::set(), mim::plug::autodiff::tangent_type_fun(), mim::World::tuple(), mim::Def::type(), mim::World::var(), mim::Def::world(), and mim::Pass::world().
Referenced by augment_().
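For a tuple, each component is augmented separately and the resulting pullbacks must be merged. The reference to mim::plug::autodiff::op_sum suggests that input cotangents are summed; the sketch below assumes exactly that, for a function of m scalar inputs, using the hypothetical helpers add and tuple_pullback.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <utility>
#include <vector>

// Hypothetical toy: the enclosing function has m scalar inputs, and a component
// pullback maps that component's cotangent to an m-vector of input cotangents.
using Cotangent = std::vector<double>;
using ComponentPullback = std::function<Cotangent(double)>;
using TuplePullback = std::function<Cotangent(const std::vector<double>&)>;

// Element-wise sum of two input cotangents (the role op_sum appears to play).
Cotangent add(const Cotangent& a, const Cotangent& b) {
    assert(a.size() == b.size());
    Cotangent r(a.size());
    for (std::size_t i = 0; i < a.size(); ++i) r[i] = a[i] + b[i];
    return r;
}

// Pullback of (e_0, ..., e_k): feed seed[i] to the pullback of e_i and sum the results.
TuplePullback tuple_pullback(std::vector<ComponentPullback> pbs, std::size_t m) {
    return [pbs = std::move(pbs), m](const std::vector<double>& seed) {
        assert(seed.size() == pbs.size());
        Cotangent acc(m, 0.0);
        for (std::size_t i = 0; i < pbs.size(); ++i) acc = add(acc, pbs[i](seed[i]));
        return acc;
    };
}
```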
Ref mim::plug::autodiff::AutoDiffEval::augment_var (const Var *, Lam *, Lam *)
Helper functions for augment.
Definition at line 21 of file autodiff_rewrite_inner.cpp.
Referenced by augment_().
Ref mim::plug::autodiff::AutoDiffEval::derive (Ref)
Acts on toplevel autodiff on closed terms:
Definition at line 21 of file autodiff_eval.cpp.
References derive_().
Referenced by rewrite().
Ref mim::plug::autodiff::AutoDiffEval::derive_ (Ref)
In addition to the derivation, the pullback is registered and the maps are initialized.
Definition at line 7 of file autodiff_rewrite_toplevel.cpp.
References mim::Def::as_mut(), augment(), mim::plug::autodiff::autodiff_type_fun_pi(), mim::plug::autodiff::id_pullback(), mim::World::mut_lam(), mim::Def::set(), mim::Lam::set(), mim::World::tuple(), mim::Def::type(), mim::Def::var(), mim::Def::world(), mim::Pass::world(), and mim::plug::autodiff::zero_pullback().
Referenced by derive().
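derive_ turns a closed function into one that additionally returns a pullback, and its reference to id_pullback suggests the argument is seeded with the identity pullback. The following minimal sketch shows that interface on scalars, assuming f : double -> double becomes f' : double -> (double, pullback). Aug, its operators and derive are hypothetical and do not reflect MimIR's types.

```cpp
#include <cassert>
#include <functional>
#include <utility>

// Hypothetical toy, not MimIR's API: a number paired with its pullback with respect
// to the input of the function being differentiated.
using Pullback = std::function<double(double)>;
struct Aug {
    double value;
    Pullback pb;
};

Aug operator+(const Aug& a, const Aug& b) {
    return {a.value + b.value, [a, b](double s) { return a.pb(s) + b.pb(s); }};
}

Aug operator*(const Aug& a, const Aug& b) {
    return {a.value * b.value, [a, b](double s) { return a.pb(s * b.value) + b.pb(s * a.value); }};
}

// derive turns f into f' : x -> (f(x), pullback).
template <class F>
std::function<std::pair<double, Pullback>(double)> derive(F f) {
    return [f](double x) {
        Aug in{x, [](double s) { return s; }}; // identity pullback for the argument
        Aug out = f(in);                       // augment the body
        return std::make_pair(out.value, out.pb);
    };
}
```

For f(x) = x * x + x at x = 3, the derived function returns 12 and a pullback that maps 1 to 7.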
Ref mim::plug::autodiff::AutoDiffEval::rewrite (Ref) [override]
Detect autodiff calls.
Reimplemented from mim::Pass.
Definition at line 27 of file autodiff_eval.cpp.
References derive(), mim::match(), and mim::Pass::world().
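rewrite looks for autodiff calls (via mim::match) and replaces each one by the result of derive. The sketch below mimics that control flow on a hypothetical toy term type. It assumes, as the derive/derive_ split suggests but this page does not state, that derived functions are cached so repeated autodiff calls on the same function reuse one derivative. ToyAutoDiffEval and the "_diff" suffix are illustrative choices only.

```cpp
#include <cassert>
#include <map>
#include <string>

// Hypothetical toy term: either a plain function name or an autodiff call wrapping one.
// Not MimIR's Def/match API.
struct Term {
    bool is_autodiff;
    std::string fun;
};

class ToyAutoDiffEval {
public:
    // Leave non-matching terms alone; replace autodiff calls with the derived function.
    Term rewrite(const Term& t) {
        if (!t.is_autodiff) return t;      // not an autodiff call
        return Term{false, derive(t.fun)}; // replace with the derivative
    }

    int derivations = 0; // number of functions actually differentiated

private:
    // Each function is differentiated once; later autodiff calls reuse the result.
    std::string derive(const std::string& f) {
        if (auto it = derived_.find(f); it != derived_.end()) return it->second;
        ++derivations;
        return derived_[f] = f + "_diff"; // hypothetical naming of the derived function
    }

    std::map<std::string, std::string> derived_;
};
```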