MimIR 0.1
MimIR is my Intermediate Representation
This pass is the heart of AD.
#include <mim/plug/autodiff/pass/eval.h>
Public Member Functions

  Eval(World& world, flags_t annex)
  const Def* rewrite(const Def*) override
      Detect autodiff calls.
  const Def* derive(const Def*)
      Acts on top-level autodiff on closed terms.
  const Def* derive_(const Def*)
      In addition to the derivation, the pullback is registered and the maps are initialized.
  const Def* augment(const Def*, Lam*, Lam*)
      Applies to (open) expressions in a functional context.
  const Def* augment_(const Def*, Lam*, Lam*)
      Rewrites the given definition in a lambda environment.
  const Def* augment_var(const Var*, Lam*, Lam*)
      Helper functions for augment.
  const Def* augment_lam(Lam*, Lam*, Lam*)
  const Def* augment_extract(const Extract*, Lam*, Lam*)
  const Def* augment_app(const App*, Lam*, Lam*)
  const Def* augment_lit(const Lit*, Lam*, Lam*)
  const Def* augment_tuple(const Tuple*, Lam*, Lam*)
  const Def* augment_pack(const Pack* pack, Lam* f, Lam* f_diff)
Public Member Functions inherited from mim::RWPass< Eval, Lam >

  RWPass(World& world, std::string name)
  RWPass(World& world, flags_t annex)
  bool inspect() const override
      Should the PassMan even consider this pass?
  Lam* curr_mut() const
Public Member Functions inherited from mim::Pass

  Pass(World& world, std::string name)
  Pass(World& world, flags_t annex)
  virtual void init(PassMan*)
  PassMan& man()
  const PassMan& man() const
  size_t index() const
  virtual const Def* rewrite(const Var* var)
  virtual const Def* rewrite(const Proxy* proxy)
  virtual undo_t analyze(const Def*)
  virtual undo_t analyze(const Var*)
  virtual undo_t analyze(const Proxy*)
  virtual bool fixed_point() const
  virtual void enter()
      Invoked just before Pass::rewrite is applied to PassMan::curr_mut's body.
  virtual void prepare()
      Invoked once before entering the main rewrite loop.
  const Proxy* proxy(const Def* type, Defs ops, u32 tag = 0)
  const Proxy* isa_proxy(const Def* def, u32 tag = 0)
      Checks whether the given def is a Proxy whose Proxy::pass matches this Pass's IPass::index.
  const Proxy* as_proxy(const Def* def, u32 tag = 0)
Public Member Functions inherited from mim::Stage

  World& world()
  Driver& driver()
  Log& log() const
  std::string_view name() const
  flags_t annex() const
  Stage(World& world, std::string name)
  Stage(World& world, flags_t annex)
  virtual ~Stage() = default
  virtual std::unique_ptr<Stage> recreate()
      Creates a new instance; needed by a fixed-point PhaseMan.
  virtual void apply(const App*)
      Invoked if your Stage has additional args.
  virtual void apply(Stage&)
      Ditto, but invoked by Stage::recreate.
Additional Inherited Members

  static auto create(const Flags2Stages& stages, const Def* def)
  template<class A, class P>
  static void hook(Flags2Stages& stages)
  std::string name_
Detailed Description

This pass is the heart of AD.
We replace an autodiff fun call with the differentiated function.
Definition at line 12 of file eval.h.
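The following is a minimal sketch of what "replacing an autodiff fun call with the differentiated function" amounts to. It is plain C++, not MimIR code. It assumes that the differentiated function returns the primal result together with a pullback that maps an output cotangent to an input cotangent, which is what the pullback terminology on this page suggests. The names f, f_diff, and autodiff_f are hypothetical.

```cpp
#include <cassert>
#include <functional>
#include <utility>

// Hypothetical illustration, not MimIR's API.
// A pullback maps a cotangent of the output to a cotangent of the input.
using Pullback = std::function<double(double)>;
using Diffed   = std::function<std::pair<double, Pullback>(double)>;

// Original function: f(x) = x * x + 3 * x.
double f(double x) { return x * x + 3.0 * x; }

// Hand-written stand-in for what the pass would generate for "autodiff f":
// the primal value plus a pullback that scales the cotangent by f'(x) = 2x + 3.
std::pair<double, Pullback> f_diff(double x) {
    Pullback pb = [x](double dy) { return dy * (2.0 * x + 3.0); };
    return {f(x), pb};
}

// The rewrite itself: a call site that requested "autodiff f" gets f_diff.
Diffed autodiff_f() { return f_diff; }
```

For example, calling autodiff_f()(2.0) yields the value 10 and a pullback with pb(1.0) == 7, which is f'(2).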
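Constructor & Destructor Documentation

mim::plug::autodiff::Eval::Eval(World& world, flags_t annex)
References mim::Stage::annex(), mim::RWPass< Eval, Lam >::RWPass(), and mim::Stage::world().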
Member Function Documentation

const Def* mim::plug::autodiff::Eval::augment(const Def*, Lam*, Lam*)
Applies to (open) expressions in a functional context.
Returns the rewritten expression and augments the partial and modular pullbacks. Up to renaming of variables, the rewrite is the identity on the term; the only additions are pullbacks. To compute them, some calls (e.g., axiom calls) are replaced by their derivatives. The transformation can be viewed as augmenting the original computation with a dual computation that generates the derivatives.
Definition at line 10 of file eval.cpp.
References augment_().
Referenced by augment_app(), augment_extract(), augment_lam(), augment_pack(), augment_tuple(), and derive_().
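The sketch below shows the augmentation idea on a small scale. Each subexpression keeps its primal value and additionally gets a partial pullback from its cotangent to the cotangent of the input. The node-kind switch loosely mirrors how augment_ dispatches to the augment_* helpers.

It makes several simplifying assumptions:
- It uses a hypothetical four-node expression language (Var, Lit, Add, Mul) with a single scalar input x.
- It evaluates expressions instead of rewriting terms.
- Expr, Augmented, mk, and augment are illustrative names, not MimIR's Def, Lam, or Eval::augment.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <utility>

// Hypothetical mini-IR for illustration; MimIR's Def, Var, Lit, App differ.
struct Expr {
    enum Kind { Var, Lit, Add, Mul };
    Kind kind = Var;
    double lit = 0.0;                // payload of Lit
    std::shared_ptr<Expr> lhs, rhs;  // operands of Add and Mul
};
using E = std::shared_ptr<Expr>;

E mk(Expr::Kind k, double lit = 0.0, E lhs = nullptr, E rhs = nullptr) {
    auto e  = std::make_shared<Expr>();
    e->kind = k;
    e->lit  = lit;
    e->lhs  = std::move(lhs);
    e->rhs  = std::move(rhs);
    return e;
}

// Partial pullback: cotangent of a subexpression -> cotangent of the input x.
using Pullback = std::function<double(double)>;
struct Augmented {
    double value;  // primal result, left unchanged
    Pullback pb;   // the added dual computation
};

// Evaluates e at x and, alongside, builds its partial pullback.
Augmented augment(const E& e, double x) {
    switch (e->kind) {
    case Expr::Var:
        return {x, [](double d) { return d; }};         // identity pullback
    case Expr::Lit:
        return {e->lit, [](double) { return 0.0; }};    // zero pullback
    case Expr::Add: {
        Augmented a = augment(e->lhs, x), b = augment(e->rhs, x);
        return {a.value + b.value, [a, b](double d) { return a.pb(d) + b.pb(d); }};
    }
    case Expr::Mul: {
        Augmented a = augment(e->lhs, x), b = augment(e->rhs, x);
        // Product rule: each factor receives the cotangent scaled by the other.
        return {a.value * b.value,
                [a, b](double d) { return a.pb(d * b.value) + b.pb(d * a.value); }};
    }
    }
    return {0.0, [](double) { return 0.0; }};  // not reached for valid kinds
}
```

For (x + 3) * x at x = 2, this yields the value 10 and pb(1.0) == 7, matching the derivative 2x + 3.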
const Def* mim::plug::autodiff::Eval::augment_(const Def*, Lam*, Lam*)
Rewrites the given definition in a lambda environment.
Definition at line 301 of file autodiff_rewrite_inner.cpp.
References augment_app(), augment_extract(), augment_lam(), augment_lit(), augment_pack(), augment_tuple(), augment_var(), mim::plug::autodiff::autodiff_type_fun(), DLOG, ELOG, mim::World::external(), mim::find_and_replace(), mim::Pass::index(), mim::Def::isa_mut(), mim::Def::node_name(), mim::Def::type(), and mim::Stage::world().
Referenced by augment(), and augment_pack().
const Def* mim::plug::autodiff::Eval::augment_app(const App*, Lam*, Lam*)
Definition at line 195 of file autodiff_rewrite_inner.cpp.
References mim::World::app(), mim::App::arg(), augment(), mim::App::callee(), mim::compose_cn(), mim::World::debug_dump(), DLOG, mim::Pi::isa_basicblock(), mim::Pi::isa_cn(), mim::World::mut_lam(), mim::plug::direct::op_cps2ds_dep(), mim::Lam::set(), mim::Def::type(), mim::Def::var(), and mim::Stage::world().
Referenced by augment_().
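augment_app references mim::compose_cn. That suggests the pullback of an application is obtained by composing pullbacks, i.e., the chain rule. The sketch below illustrates the idea with plain C++ closures. It is an assumption-laden sketch: MimIR composes continuations, not std::function values, and compose and square_pullback are hypothetical names.

```cpp
#include <cassert>
#include <functional>

using Pullback = std::function<double(double)>;

// Hypothetical analogue of pullback composition for an application f(e):
// the result's cotangent first passes through f's pullback (taken at e's
// value), then through e's partial pullback. This is the chain rule.
Pullback compose(Pullback callee_pb, Pullback arg_pb) {
    return [callee_pb, arg_pb](double d) { return arg_pb(callee_pb(d)); };
}

// Example callee f(u) = u * u: its pullback at u scales the cotangent by 2u.
Pullback square_pullback(double u) {
    return [u](double d) { return d * 2.0 * u; };
}
```

With e = 3x at x = 2 (so u = 6), the composed pullback maps 1 to 36, which is d/dx (3x)^2 = 18x at x = 2.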
const Def* mim::plug::autodiff::Eval::augment_extract(const Extract* ext, Lam* f, Lam* f_diff)
Definition at line 81 of file autodiff_rewrite_inner.cpp.
References augment(), DLOG, mim::World::extract(), mim::Extract::index(), mim::Pass::index(), mim::World::insert(), mim::World::mut_lam(), mim::plug::autodiff::pullback_type(), mim::Def::set(), mim::Lam::set(), mim::Extract::tuple(), mim::Def::type(), and mim::Stage::world().
Referenced by augment_().
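augment_extract references mim::World::insert and mim::Extract::index. One plausible reading is that the pullback of a projection places the incoming cotangent at the extracted index of an otherwise zero tuple cotangent. That reading is an assumption, not something this page states. The sketch below shows the idea, with a tuple cotangent modeled as std::vector<double>; extract_pullback is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical types: the cotangent of an n-tuple of doubles is a vector.
using ScalarPullback = std::function<double(double)>;
using TuplePullback  = std::function<double(const std::vector<double>&)>;

// Pullback of extracting element i from an n-tuple: insert the incoming
// cotangent at position i of an all-zero tuple cotangent, then pass that
// to the tuple's own partial pullback.
ScalarPullback extract_pullback(TuplePullback tuple_pb, std::size_t n, std::size_t i) {
    return [tuple_pb, n, i](double d) {
        std::vector<double> dt(n, 0.0);
        dt[i] = d;
        return tuple_pb(dt);
    };
}
```

For the tuple (2x, 3x), whose pullback maps (d0, d1) to 2 * d0 + 3 * d1, extracting index 1 gives a pullback that maps 1 to 3.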
const Def* mim::plug::autodiff::Eval::augment_lam(Lam*, Lam*, Lam*)
Definition at line 23 of file autodiff_rewrite_inner.cpp.
References augment(), mim::plug::autodiff::autodiff_type_fun(), mim::Lam::body(), mim::World::call(), DLOG, mim::Pi::dom(), mim::Lam::filter(), mim::Lam::isa_basicblock(), mim::World::mut_con(), mim::plug::autodiff::pullback_type(), mim::Def::sym(), mim::Def::type(), mim::Lam::type(), mim::Def::var(), mim::Stage::world(), and mim::plug::autodiff::zero_pullback().
Referenced by augment_().
const Def* mim::plug::autodiff::Eval::augment_lit(const Lit*, Lam*, Lam*)
Definition at line 10 of file autodiff_rewrite_inner.cpp.
References mim::plug::autodiff::zero_pullback().
Referenced by augment_().
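augment_lit only references zero_pullback. That fits the usual rule that a constant does not depend on the function's inputs. The sketch below assumes the gradient with respect to n scalar inputs is a std::vector<double>. Both zero_pullback and augment_literal here are hypothetical stand-ins, not the plugin's functions.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

using Grad     = std::vector<double>;
using Pullback = std::function<Grad(double)>;

// Hypothetical stand-in for zero_pullback: whatever cotangent reaches a
// literal, its contribution to the gradient of all n_inputs inputs is zero.
Pullback zero_pullback(std::size_t n_inputs) {
    return [n_inputs](double /*d*/) { return Grad(n_inputs, 0.0); };
}

// Augmenting a literal: the value is kept as is, the pullback is zero.
struct AugmentedLit {
    double value;
    Pullback pb;
};

AugmentedLit augment_literal(double lit, std::size_t n_inputs) {
    return {lit, zero_pullback(n_inputs)};
}
```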
const Def* mim::plug::autodiff::Eval::augment_pack(const Pack* pack, Lam* f, Lam* f_diff)
Definition at line 153 of file autodiff_rewrite_inner.cpp.
References mim::Stage::annex(), mim::World::app(), mim::Pack::arity(), augment(), augment_(), mim::Seq::body(), DLOG, mim::World::mut_lam(), mim::World::mut_pack(), mim::plug::direct::op_cps2ds_dep(), mim::World::pack(), mim::plug::autodiff::pullback_type(), mim::Lam::set(), mim::Pack::set(), mim::plug::autodiff::tangent_type_fun(), mim::Def::type(), and mim::Stage::world().
Referenced by augment_().
const Def* mim::plug::autodiff::Eval::augment_tuple(const Tuple*, Lam*, Lam*)
Definition at line 118 of file autodiff_rewrite_inner.cpp.
References augment(), DLOG, mim::World::mut_lam(), mim::plug::autodiff::op_sum(), mim::Def::projs(), mim::plug::autodiff::pullback_type(), mim::Lam::set(), mim::plug::autodiff::tangent_type_fun(), mim::World::tuple(), mim::Def::type(), and mim::Stage::world().
Referenced by augment_().
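augment_tuple references op_sum and Def::projs. That is consistent with the standard rule for tuples: split the tuple cotangent into its projections, pull each one back through its component's pullback, and sum the results. The sketch below illustrates that rule. The names Grad, sum, and tuple_pullback are hypothetical, and gradients are assumed to be std::vector<double> over all function inputs.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

using Grad          = std::vector<double>;                          // gradient w.r.t. all inputs
using Pullback      = std::function<Grad(double)>;                  // one scalar component
using TuplePullback = std::function<Grad(const std::vector<double>&)>;

// Hypothetical analogue of op_sum: element-wise sum of two gradients.
Grad sum(const Grad& a, const Grad& b) {
    Grad r(a.size());
    for (std::size_t k = 0; k < a.size(); ++k) r[k] = a[k] + b[k];
    return r;
}

// Pullback of a tuple (e_0, ..., e_{m-1}): project the tuple cotangent,
// pull each projection back through its component's pullback, and sum.
TuplePullback tuple_pullback(std::vector<Pullback> pbs, std::size_t n_inputs) {
    return [pbs, n_inputs](const std::vector<double>& dt) {
        Grad acc(n_inputs, 0.0);
        for (std::size_t i = 0; i < pbs.size(); ++i) acc = sum(acc, pbs[i](dt[i]));
        return acc;
    };
}
```

For inputs (x, y) = (2, 3) and the tuple (x, x * y), the cotangent (1, 1) is pulled back to the gradient (1 + 3, 0 + 2) = (4, 2).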
const Def* mim::plug::autodiff::Eval::augment_var(const Var*, Lam*, Lam*)
Helper functions for augment.
Definition at line 16 of file autodiff_rewrite_inner.cpp.
Referenced by augment_().
const Def* mim::plug::autodiff::Eval::derive_(const Def*)
In addition to the derivation, the pullback is registered and the maps are initialized.
Definition at line 7 of file autodiff_rewrite_toplevel.cpp.
References mim::Def::as_mut(), augment(), mim::plug::autodiff::autodiff_type_fun_pi(), DLOG, mim::plug::autodiff::id_pullback(), mim::World::mut_lam(), mim::Def::set(), mim::Lam::set(), mim::World::tuple(), mim::Def::type(), mim::Lam::type(), mim::Def::var(), mim::Stage::world(), and mim::plug::autodiff::zero_pullback().
Referenced by derive().
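derive_ is described as registering the pullback and initializing the maps, and it references id_pullback. The sketch below illustrates that top-level step for one hard-coded closed function g(x) = x * x + 1. It initializes a per-variable pullback map, seeds the parameter with the identity pullback, augments the body by hand, and registers the result's pullback. Deriver, partial_pullback, and derive_g are hypothetical names and do not reflect Eval's actual members.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <string>
#include <utility>

using Pullback = std::function<double(double)>;

// Hypothetical top-level driver, not MimIR's Eval.
class Deriver {
public:
    // Partial pullback of each variable, keyed by name.
    std::map<std::string, Pullback> partial_pullback;

    // Derives the closed function g(x) = x * x + 1 at the point x.
    std::pair<double, Pullback> derive_g(double x) {
        partial_pullback.clear();                             // initialize the maps
        partial_pullback["x"] = [](double d) { return d; };   // identity pullback for x
        Pullback pb_x = partial_pullback.at("x");
        // Augmented body: product rule for x * x; the literal 1 contributes nothing.
        Pullback pb_g = [pb_x, x](double d) { return pb_x(d * x) + pb_x(d * x); };
        partial_pullback["g"] = pb_g;                         // register the result's pullback
        return {x * x + 1.0, pb_g};
    }
};
```

At x = 3 this returns the value 10 and a pullback with pb(1.0) == 6, which is g'(3) = 2 * 3.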