MimIR 0.1
MimIR is my Intermediate Representation
mim::plug::autodiff::Eval Class Reference

This pass is the heart of AD.

#include <mim/plug/autodiff/pass/eval.h>

Inheritance diagram for mim::plug::autodiff::Eval: mim::plug::autodiff::Eval → mim::RWPass< Eval, Lam > → mim::Pass → mim::Stage

Public Member Functions

 Eval (World &world, flags_t annex)
 
const Def * rewrite (const Def *) override
 Detect autodiff calls.
 
const Def * derive (const Def *)
 Acts on top-level autodiff on closed terms.
 
const Def * derive_ (const Def *)
 In addition to the derivation, registers the pullback and initializes the maps.
 
const Def * augment (const Def *, Lam *, Lam *)
 Applies to (open) expressions in a functional context.
 
const Def * augment_ (const Def *, Lam *, Lam *)
 Rewrites the given definition in a lambda environment.
 
const Def * augment_var (const Var *, Lam *, Lam *)
 Helper functions for augment.
 
const Def * augment_lam (Lam *, Lam *, Lam *)
 
const Def * augment_extract (const Extract *, Lam *, Lam *)
 
const Def * augment_app (const App *, Lam *, Lam *)
 
const Def * augment_lit (const Lit *, Lam *, Lam *)
 
const Def * augment_tuple (const Tuple *, Lam *, Lam *)
 
const Def * augment_pack (const Pack *pack, Lam *f, Lam *f_diff)
 
- Public Member Functions inherited from mim::RWPass< Eval, Lam >
 RWPass (World &world, std::string name)
 
 RWPass (World &world, flags_t annex)
 
bool inspect () const override
 Should the PassMan even consider this pass?
 
Lam * curr_mut () const
 
- Public Member Functions inherited from mim::Pass
 Pass (World &world, std::string name)
 
 Pass (World &world, flags_t annex)
 
virtual void init (PassMan *)
 
PassMan & man ()
 
const PassMan & man () const
 
size_t index () const
 
virtual const Def * rewrite (const Var *var)
 
virtual const Def * rewrite (const Proxy *proxy)
 
virtual undo_t analyze (const Def *)
 
virtual undo_t analyze (const Var *)
 
virtual undo_t analyze (const Proxy *)
 
virtual bool fixed_point () const
 
virtual void enter ()
 Invoked just before Pass::rewrite is applied to PassMan::curr_mut's body.
 
virtual void prepare ()
 Invoked once before entering the main rewrite loop.
 
const Proxy * proxy (const Def *type, Defs ops, u32 tag=0)
 
const Proxy * isa_proxy (const Def *def, u32 tag=0)
 Checks whether the given def is a Proxy whose Proxy::pass matches this Pass's IPass::index.
 
const Proxy * as_proxy (const Def *def, u32 tag=0)
 
- Public Member Functions inherited from mim::Stage
World & world ()
 
Driver & driver ()
 
Log & log () const
 
std::string_view name () const
 
flags_t annex () const
 
 Stage (World &world, std::string name)
 
 Stage (World &world, flags_t annex)
 
virtual ~Stage ()=default
 
virtual std::unique_ptr< Stage > recreate ()
 Creates a new instance; needed by a fixed-point PhaseMan.
 
virtual void apply (const App *)
 Invoked if your Stage has additional args.
 
virtual void apply (Stage &)
 Ditto, but invoked by Stage::recreate.
 

Additional Inherited Members

static auto create (const Flags2Stages &stages, const Def *def)
 
template<class A, class P>
static void hook (Flags2Stages &stages)
 
- Protected Attributes inherited from mim::Stage
std::string name_
 

Detailed Description

This pass is the heart of AD.

It replaces each autodiff call on a function with the differentiated function.
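As a rough illustration of what the replacement produces, the following C++17 sketch hand-writes the result for a scalar function. It assumes the reverse-mode convention suggested by the pullback terminology on this page: the differentiated function returns the primal result together with a pullback that maps an output sensitivity to an input sensitivity. The names `f`, `f_diff`, and `Pullback` are hypothetical and not part of MimIR's API; the pass itself operates on MimIR `Def`s, not on C++ functions.

```cpp
#include <cassert>
#include <functional>
#include <utility>

// Hypothetical illustration, not MimIR code.
// A pullback maps an output sensitivity to the corresponding input sensitivity.
using Pullback = std::function<double(double)>;

// Original function: f(x) = x * x + 3 * x.
double f(double x) { return x * x + 3.0 * x; }

// What a call "autodiff(f)" conceptually stands for after the rewrite:
// the primal result paired with its pullback.
std::pair<double, Pullback> f_diff(double x) {
    double y = x * x + 3.0 * x;
    // df/dx = 2x + 3, scaled by the incoming sensitivity s.
    Pullback pb = [x](double s) { return s * (2.0 * x + 3.0); };
    return {y, pb};
}
```

Calling `f_diff(2.0)` yields the primal value `10` and a pullback that returns `7` for the seed `1.0`, i.e. the derivative of `f` at `2`.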

Definition at line 10 of file eval.h.

Constructor & Destructor Documentation

◆ Eval()

mim::plug::autodiff::Eval::Eval ( World & world, flags_t annex )
inline

Member Function Documentation

◆ augment()

const Def * mim::plug::autodiff::Eval::augment ( const Def * def, Lam * f, Lam * f_diff )

Applies to (open) expressions in a functional context.

Returns the rewritten expression and augments the partial and modular pullbacks. Up to renaming of variables, the rewrite is the identity on the term; apart from that, only pullbacks are added. To achieve this, some calls (e.g., to axioms) are replaced by their derivatives. The transformation can be seen as augmenting the original computation with a dual computation that generates the derivatives.
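The following self-contained C++17 sketch mimics this augmentation on a hypothetical toy expression type (literals, variables, `+`, `*`) rather than on MimIR `Def`s. Each node keeps its primal value unchanged and additionally yields a pullback that accumulates derivatives into a gradient map; together these pullbacks form the dual computation described above. Literals get a zero pullback, loosely mirroring how `augment_lit` refers to `zero_pullback()`. `Expr`, `Augmented`, `Grad`, and `augment` here are illustrative names only.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>

// Hypothetical toy IR (not MimIR's Def hierarchy): literals, variables, +, *.
struct Expr {
    enum Kind { Lit, Var, Add, Mul } kind;
    double value = 0;                // used by Lit
    std::string name;                // used by Var
    std::shared_ptr<Expr> lhs, rhs;  // used by Add and Mul
};
using E = std::shared_ptr<Expr>;

E lit(double v) { return std::make_shared<Expr>(Expr{Expr::Lit, v, "", nullptr, nullptr}); }
E var(std::string n) { return std::make_shared<Expr>(Expr{Expr::Var, 0, std::move(n), nullptr, nullptr}); }
E add(E a, E b) { return std::make_shared<Expr>(Expr{Expr::Add, 0, "", a, b}); }
E mul(E a, E b) { return std::make_shared<Expr>(Expr{Expr::Mul, 0, "", a, b}); }

// Gradient: variable name -> accumulated derivative.
using Grad = std::map<std::string, double>;
// A pullback adds the contribution of the output sensitivity s to the gradient.
using Pullback = std::function<void(double s, Grad&)>;

struct Augmented {
    double primal;  // the unchanged (primal) result
    Pullback pb;    // the dual computation that produces derivatives
};

// Toy analogue of augment: keep the primal computation, attach a pullback to every node.
Augmented augment(const E& e, const std::map<std::string, double>& env) {
    switch (e->kind) {
        case Expr::Lit:
            // Constants do not depend on any variable: zero pullback.
            return {e->value, [](double, Grad&) {}};
        case Expr::Var: {
            std::string n = e->name;
            return {env.at(n), [n](double s, Grad& g) { g[n] += s; }};
        }
        case Expr::Add: {
            auto a = augment(e->lhs, env), b = augment(e->rhs, env);
            return {a.primal + b.primal, [a, b](double s, Grad& g) { a.pb(s, g); b.pb(s, g); }};
        }
        case Expr::Mul: {
            auto a = augment(e->lhs, env), b = augment(e->rhs, env);
            return {a.primal * b.primal, [a, b](double s, Grad& g) {
                a.pb(s * b.primal, g);  // d(a*b)/da = b
                b.pb(s * a.primal, g);  // d(a*b)/db = a
            }};
        }
    }
    return {0, [](double, Grad&) {}};  // unreachable for valid kinds
}
```

For `x * y + x + 2` at `x = 3`, `y = 4`, the primal result is `17`, and running the pullback with seed `1.0` accumulates `5` for `x` and `3` for `y`.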

Definition at line 10 of file eval.cpp.

References augment_().

Referenced by augment_app(), augment_extract(), augment_lam(), augment_pack(), augment_tuple(), and derive_().

◆ augment_()

const Def * mim::plug::autodiff::Eval::augment_ ( const Def * def, Lam * f, Lam * f_diff )

◆ augment_app()

◆ augment_extract()

◆ augment_lam()

◆ augment_lit()

const Def * mim::plug::autodiff::Eval::augment_lit ( const Lit * lit, Lam * f, Lam * )

Definition at line 10 of file autodiff_rewrite_inner.cpp.

References mim::plug::autodiff::zero_pullback().

Referenced by augment_().

◆ augment_pack()

◆ augment_tuple()

◆ augment_var()

const Def * mim::plug::autodiff::Eval::augment_var ( const Var * var, Lam * , Lam * )

Helper functions for augment.

Definition at line 16 of file autodiff_rewrite_inner.cpp.

Referenced by augment_().

◆ derive()

const Def * mim::plug::autodiff::Eval::derive ( const Def * def)

Acts on top-level autodiff on closed terms:

  • Replaces lambdas and operators with the appropriate derivatives.
  • Creates a new lambda, associates the variables, initializes the maps, and calls augment.
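The steps for closed terms can be sketched on a hypothetical toy IR with a single parameter, constants, and multiplication. Here `derive` returns a new function (standing in for the new lambda); when called, it initializes per-call maps from nodes to primal values and pullbacks, associates the parameter with the identity pullback, and then augments the body. This is an illustrative C++17 analogue under those assumptions, not MimIR's implementation; `Node`, `derive`, and the helper constructors are made up for the example.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <memory>
#include <utility>

// Hypothetical toy IR for a closed unary function body.
struct Node {
    enum Kind { Param, Const, Mul } kind;
    double c = 0;                    // constant value (Const)
    std::shared_ptr<Node> lhs, rhs;  // operands (Mul)
};
using N = std::shared_ptr<Node>;

N param() { return std::make_shared<Node>(Node{Node::Param}); }
N cnst(double v) { return std::make_shared<Node>(Node{Node::Const, v}); }
N mul(N a, N b) { return std::make_shared<Node>(Node{Node::Mul, 0, a, b}); }

// Output sensitivity -> sensitivity of the single parameter.
using Pullback = std::function<double(double)>;

// Illustrative analogue of derive: builds a new function for a closed term.
std::function<std::pair<double, Pullback>(double)> derive(N body) {
    return [body](double x) {
        // Maps initialized per call: primal value and pullback of each visited node.
        std::map<Node*, double> primal;
        std::map<Node*, Pullback> pullback;
        std::function<void(const N&)> augment = [&](const N& n) {
            if (primal.count(n.get())) return;  // shared subterms are augmented once
            switch (n->kind) {
                case Node::Param:  // associate the parameter: identity pullback
                    primal[n.get()] = x;
                    pullback[n.get()] = [](double s) { return s; };
                    break;
                case Node::Const:  // constants contribute nothing
                    primal[n.get()] = n->c;
                    pullback[n.get()] = [](double) { return 0.0; };
                    break;
                case Node::Mul: {
                    augment(n->lhs);
                    augment(n->rhs);
                    double a = primal[n->lhs.get()], b = primal[n->rhs.get()];
                    Pullback pa = pullback[n->lhs.get()], pb = pullback[n->rhs.get()];
                    primal[n.get()] = a * b;
                    // Captures copies, so the pullback stays valid after the maps are gone.
                    pullback[n.get()] = [=](double s) { return pa(s * b) + pb(s * a); };
                    break;
                }
            }
        };
        augment(body);
        return std::make_pair(primal[body.get()], pullback[body.get()]);
    };
}
```

Memoizing by node means a shared subterm such as the parameter in `x * x` is augmented only once, while both uses still contribute to the derivative: for `x * x * 3` at `x = 2`, the result is `12` and the pullback with seed `1.0` returns `12`.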

Definition at line 16 of file eval.cpp.

References derive_().

Referenced by rewrite().

◆ derive_()

const Def * mim::plug::autodiff::Eval::derive_ ( const Def * def)

◆ rewrite()

const Def * mim::plug::autodiff::Eval::rewrite ( const Def * def )
override, virtual

Detect autodiff calls.

Reimplemented from mim::Pass.

Definition at line 22 of file eval.cpp.

References derive(), DLOG, and mim::Axm::isa().
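Judging from its references (`derive()`, `DLOG`, `mim::Axm::isa()`), `rewrite` appears to be a local check: if the definition is an application of the autodiff axiom, it hands the argument to `derive`; otherwise it presumably returns the definition unchanged. The C++17 sketch below shows that dispatch on a hypothetical string-headed term type. The real pass matches with `mim::Axm::isa()`; the `Term` type, `mk`, and the placeholder `derive` (which merely renames the function) are assumptions for illustration.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy term: an application head(args...); a nullary term is just a name.
struct Term {
    std::string head;
    std::vector<std::shared_ptr<Term>> args;
};
using T = std::shared_ptr<Term>;

T mk(std::string head, std::vector<T> args = {}) {
    return std::make_shared<Term>(Term{std::move(head), std::move(args)});
}

// Placeholder for derive: only tags the function; the real pass builds the derivative.
T derive(const T& f) { return mk("diff_" + f->head); }

// Toy analogue of rewrite: replace an autodiff call, leave everything else untouched.
T rewrite(const T& def) {
    if (def->head == "autodiff" && def->args.size() == 1) return derive(def->args[0]);
    return def;
}
```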


The documentation for this class was generated from the following files: