MimIR 0.1
MimIR is my Intermediate Representation
mim::plug::autodiff::AutoDiffEval Class Reference

This pass is the heart of AD (automatic differentiation).

#include <mim/plug/autodiff/pass/autodiff_eval.h>

Inheritance diagram for mim::plug::autodiff::AutoDiffEval (image not reproduced): mim::Pass → mim::RWPass< AutoDiffEval, Lam > → mim::plug::autodiff::AutoDiffEval.

Public Member Functions

 AutoDiffEval (PassMan &man)
 
const Def * rewrite (const Def *) override
 Detect autodiff calls.
 
const Def * derive (const Def *)
 Handles top-level autodiff calls on closed terms.
 
const Def * derive_ (const Def *)
 In addition to the derivation, registers the pullback and initializes the maps.
 
const Def * augment (const Def *, Lam *, Lam *)
 Applies to (open) expressions in a functional context.
 
const Def * augment_ (const Def *, Lam *, Lam *)
 Rewrites the given definition in a lambda environment.
 
const Def * augment_var (const Var *, Lam *, Lam *)
 Helper functions for augment().
 
const Def * augment_lam (Lam *, Lam *, Lam *)
 
const Def * augment_extract (const Extract *, Lam *, Lam *)
 
const Def * augment_app (const App *, Lam *, Lam *)
 
const Def * augment_lit (const Lit *, Lam *, Lam *)
 
const Def * augment_tuple (const Tuple *, Lam *, Lam *)
 
const Def * augment_pack (const Pack *pack, Lam *f, Lam *f_diff)
 
- Public Member Functions inherited from mim::RWPass< AutoDiffEval, Lam >
 RWPass (PassMan &man, std::string_view name)
 
bool inspect () const override
 Should the PassMan even consider this pass?
 
Lam * curr_mut () const
 
- Public Member Functions inherited from mim::Pass
 Pass (PassMan &, std::string_view name)
 
virtual ~Pass ()=default
 
World & world ()
 
PassMan & man ()
 
const PassMan & man () const
 
std::string_view name () const
 
size_t index () const
 
virtual const Def * rewrite (const Var *var)
 
virtual const Def * rewrite (const Proxy *proxy)
 
virtual undo_t analyze (const Def *)
 
virtual undo_t analyze (const Var *)
 
virtual undo_t analyze (const Proxy *)
 
virtual bool fixed_point () const
 
virtual void enter ()
 Invoked just before Pass::rewrite is applied to the body of PassMan::curr_mut.
 
virtual void prepare ()
 Invoked once before entering the main rewrite loop.
 
const Proxy * proxy (const Def *type, Defs ops, u32 tag=0)
 
const Proxy * isa_proxy (const Def *def, u32 tag=0)
 Checks whether the given def is a Proxy whose Proxy::pass matches this Pass's IPass::index.
 
const Proxy * as_proxy (const Def *def, u32 tag=0)
 

Detailed Description

This pass is the heart of AD.

It replaces each autodiff call on a function with the differentiated function.
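To give a rough idea of what "the differentiated function" is, here is a direct-style C++ analogue. This is an assumption for illustration only: MimIR works on its own IR rather than on C++ closures. The derivative of f : double → double returns the primal result together with a pullback that maps a cotangent of the output back to a cotangent of the input. The names `cube`, `cube_diff`, `Pullback`, and `Derivative` are hypothetical, and the derivative is written by hand, not generated by the pass.

```cpp
#include <cassert>
#include <functional>
#include <utility>

// Hypothetical direct-style analogue (not MimIR's representation):
// a pullback maps a cotangent of the output back to a cotangent of the input.
using Pullback = std::function<double(double)>;
// The differentiated function returns the primal result and its pullback.
using Derivative = std::function<std::pair<double, Pullback>(double)>;

// Original function: f(x) = x^3.
double cube(double x) { return x * x * x; }

// Hand-written stand-in for what differentiating cube could produce.
Derivative cube_diff() {
    return [](double x) {
        double y = x * x * x;  // primal result, unchanged
        // d(x^3)/dx = 3x^2, scaled by the incoming output cotangent dy.
        Pullback pb = [x](double dy) { return dy * 3.0 * x * x; };
        return std::make_pair(y, pb);
    };
}
```

For example, `cube_diff()(2.0)` yields the value 8 and a pullback that maps 1.0 to 12.0.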

Definition at line 11 of file autodiff_eval.h.

Constructor & Destructor Documentation

◆ AutoDiffEval()

mim::plug::autodiff::AutoDiffEval::AutoDiffEval ( PassMan & man)
inline

Member Function Documentation

◆ augment()

const Def * mim::plug::autodiff::AutoDiffEval::augment ( const Def * def,
Lam * f,
Lam * f_diff )

Applies to (open) expressions in a functional context.

Returns the rewritten expression and augments the partial and modular pullbacks. Up to renaming of variables, the rewrite is the identity on the term; apart from that, it only adds pullbacks. To do so, some calls (e.g., axioms) are replaced by their derivatives. The transformation can be seen as augmenting the computation with a dual computation that generates the derivatives.
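The augmentation idea can be sketched with a toy model in which every value carries its primal result plus a pullback. The primal arithmetic stays exactly as written and only the pullbacks are composed, which mirrors "identity on the term, only pullbacks are added". `Aug`, `var`, `lit`, `add`, and `mul` are hypothetical helpers restricted to a single scalar input; they are not MimIR's API.

```cpp
#include <cassert>
#include <functional>

// Hypothetical toy model (not MimIR's API): an augmented value carries its primal
// value and a pullback from a cotangent of this value to a cotangent of the
// single scalar input x.
struct Aug {
    double val;
    std::function<double(double)> pb;
};

// The input variable: its pullback is the identity.
Aug var(double x) { return {x, [](double d) { return d; }}; }

// A literal does not depend on x, so its pullback is zero.
Aug lit(double c) { return {c, [](double) { return 0.0; }}; }

// The primal computation stays as-is; only the pullbacks are composed.
Aug add(const Aug& a, const Aug& b) {
    return {a.val + b.val, [da = a.pb, db = b.pb](double d) { return da(d) + db(d); }};
}

Aug mul(const Aug& a, const Aug& b) {
    // Product rule: the cotangent flowing to a is scaled by b, and vice versa.
    return {a.val * b.val, [da = a.pb, db = b.pb, av = a.val, bv = b.val](double d) {
                return da(d * bv) + db(d * av);
            }};
}
```

For f(x) = x·x + 3 at x = 2, `add(mul(x, x), lit(3.0))` has value 7, and its pullback maps 1.0 to 4.0.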

Definition at line 15 of file autodiff_eval.cpp.

References augment_().

Referenced by augment_app(), augment_extract(), augment_lam(), augment_pack(), augment_tuple(), and derive_().

◆ augment_()

const Def * mim::plug::autodiff::AutoDiffEval::augment_ ( const Def * def,
Lam * f,
Lam * f_diff )

◆ augment_app()

◆ augment_extract()

◆ augment_lam()

◆ augment_lit()

const Def * mim::plug::autodiff::AutoDiffEval::augment_lit ( const Lit * lit,
Lam * f,
Lam *  )

Definition at line 15 of file autodiff_rewrite_inner.cpp.

References mim::plug::autodiff::zero_pullback().

Referenced by augment_().
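augment_lit() references zero_pullback(): a literal does not depend on the function's inputs, so its pullback sends every output cotangent to zero. A minimal sketch, assuming the inputs are flattened into a vector of n scalars; `zero_pullback_sketch`, `augment_lit_sketch`, and `AugLit` are hypothetical names, and the real zero_pullback() builds IR rather than a C++ closure.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical sketch (not MimIR's API): inputs are assumed to be flattened
// into n scalars, so a cotangent for the inputs is a vector of length n.
using Cotangent = std::vector<double>;
using Pullback = std::function<Cotangent(double)>;

// Analogue of a zero pullback: every output cotangent maps to all zeros.
Pullback zero_pullback_sketch(std::size_t n_inputs) {
    return [n_inputs](double) { return Cotangent(n_inputs, 0.0); };
}

struct AugLit {
    double val;   // the literal itself, unchanged
    Pullback pb;  // constant term: contributes nothing to any input's gradient
};

AugLit augment_lit_sketch(double c, std::size_t n_inputs) {
    return {c, zero_pullback_sketch(n_inputs)};
}
```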

◆ augment_pack()

◆ augment_tuple()

◆ augment_var()

const Def * mim::plug::autodiff::AutoDiffEval::augment_var ( const Var * var,
Lam * ,
Lam *  )

Helper functions for augment().

Definition at line 21 of file autodiff_rewrite_inner.cpp.

Referenced by augment_().

◆ derive()

const Def * mim::plug::autodiff::AutoDiffEval::derive ( const Def * def)

Acts on top-level autodiff calls on closed terms:

  • Replaces lambdas and operators with the appropriate derivatives.
  • Creates a new lambda, associates variables, initializes the maps, and calls augment().
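The steps of derive() can be mimicked on a tiny expression tree. Everything here is a hypothetical mini-IR, not MimIR's Def hierarchy: `derive` builds a new closure, binds its argument vector to the input variables, initializes a fresh memo map, and calls `augment`, which returns the value and a pullback producing the gradient. The memo map is assumed to play the role of the pass's internal maps, so a shared sub-expression is augmented only once.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <map>
#include <memory>
#include <vector>

// Hypothetical mini-IR (not MimIR's Def hierarchy).
struct Expr {
    enum Kind { Var, Lit, Add, Mul } kind;
    std::size_t index = 0;           // Var: position of the input
    double value = 0.0;              // Lit: constant value
    std::shared_ptr<Expr> lhs, rhs;  // Add/Mul: operands
};
using E = std::shared_ptr<Expr>;

E var(std::size_t i) { return std::make_shared<Expr>(Expr{Expr::Var, i, 0.0, nullptr, nullptr}); }
E lit(double c) { return std::make_shared<Expr>(Expr{Expr::Lit, 0, c, nullptr, nullptr}); }
E add(E a, E b) { return std::make_shared<Expr>(Expr{Expr::Add, 0, 0.0, a, b}); }
E mul(E a, E b) { return std::make_shared<Expr>(Expr{Expr::Mul, 0, 0.0, a, b}); }

using Grad = std::vector<double>;
using Pullback = std::function<Grad(double)>;
struct Aug { double val; Pullback pb; };

// Augments e at input x; `memo` stands in for the pass's maps.
Aug augment(const E& e, const std::vector<double>& x, std::map<const Expr*, Aug>& memo) {
    if (auto it = memo.find(e.get()); it != memo.end()) return it->second;
    const std::size_t n = x.size();
    Aug r;
    switch (e->kind) {
        case Expr::Var: {
            const std::size_t i = e->index;
            // One-hot pullback: the cotangent flows only to input i.
            r = {x[i], [n, i](double d) { Grad g(n, 0.0); g[i] = d; return g; }};
            break;
        }
        case Expr::Lit:
            r = {e->value, [n](double) { return Grad(n, 0.0); }};  // zero pullback
            break;
        case Expr::Add:
        case Expr::Mul: {
            const Aug a = augment(e->lhs, x, memo);
            const Aug b = augment(e->rhs, x, memo);
            const bool is_mul = e->kind == Expr::Mul;
            r = {is_mul ? a.val * b.val : a.val + b.val, [a, b, is_mul](double d) {
                     // Sum the operand gradients (product rule for Mul).
                     Grad g = a.pb(is_mul ? d * b.val : d);
                     const Grad gb = b.pb(is_mul ? d * a.val : d);
                     for (std::size_t k = 0; k < g.size(); ++k) g[k] += gb[k];
                     return g;
                 }};
            break;
        }
    }
    memo.emplace(e.get(), r);
    return r;
}

// Builds a new closure: binds the inputs, initializes fresh maps, calls augment.
std::function<Aug(const std::vector<double>&)> derive(E body) {
    return [body](const std::vector<double>& x) {
        std::map<const Expr*, Aug> memo;
        return augment(body, x, memo);
    };
}
```

For f(x0, x1) = x0·x1 + x0 at (3, 4), the derived closure returns 15 and a pullback that maps 1.0 to the gradient (5, 3); the node for x0 is augmented once and then reused from the map.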

Definition at line 21 of file autodiff_eval.cpp.

References derive_().

Referenced by rewrite().

◆ derive_()

const Def * mim::plug::autodiff::AutoDiffEval::derive_ ( const Def * def)

◆ rewrite()

const Def * mim::plug::autodiff::AutoDiffEval::rewrite ( const Def * def)
override virtual

Detect autodiff calls.

Reimplemented from mim::Pass.

Definition at line 27 of file autodiff_eval.cpp.

References derive(), mim::match(), and mim::Pass::world().
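The detection step in rewrite() can be sketched as a pattern match that replaces `autodiff(f)` with `derive(f)` and leaves every other node unchanged. This is a toy stand-in with hypothetical `Node`, `rewrite_sketch`, and `derive_sketch`; the real rewrite() relies on mim::match() and world() instead of string comparison.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <vector>

// Hypothetical toy IR (not MimIR's): a node is an operator name plus arguments.
struct Node {
    std::string op;  // e.g. "autodiff", or the name of a function
    std::vector<std::shared_ptr<Node>> args;
};
using N = std::shared_ptr<Node>;

// Stand-in for derive(): produces a node that names the derivative of f.
N derive_sketch(const N& f) {
    return std::make_shared<Node>(Node{f->op + "_diff", {}});
}

// Stand-in for rewrite(): detect a call autodiff(f) and replace it with
// derive(f); every other node is returned unchanged.
N rewrite_sketch(const N& def) {
    if (def->op == "autodiff" && def->args.size() == 1) return derive_sketch(def->args[0]);
    return def;
}
```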


The documentation for this class was generated from the following files: