MimIR 0.1
MimIR is my Intermediate Representation
mim::plug::autodiff::AutoDiffEval Class Reference

This pass is the heart of AD.

#include <mim/plug/autodiff/pass/autodiff_eval.h>

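Inheritance diagram for mim::plug::autodiff::AutoDiffEval (image not reproduced): AutoDiffEval derives from mim::RWPass< AutoDiffEval, Lam >, which derives from mim::Pass.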

Public Member Functions

 AutoDiffEval (PassMan &man)
 
Ref rewrite (Ref) override
 Detect autodiff calls.
 
Ref derive (Ref)
 Acts on top-level autodiff calls on closed terms.
 
Ref derive_ (Ref)
 In addition to performing the derivation, registers the pullback and initializes the maps.
 
Ref augment (Ref, Lam *, Lam *)
 Applies to (open) expressions in a functional context.
 
Ref augment_ (Ref, Lam *, Lam *)
 Rewrites the given definition in a lambda environment.
 
Ref augment_var (const Var *, Lam *, Lam *)
 Helper functions for augment.
 
Ref augment_lam (Lam *, Lam *, Lam *)
 
Ref augment_extract (const Extract *, Lam *, Lam *)
 
Ref augment_app (const App *, Lam *, Lam *)
 
Ref augment_lit (const Lit *, Lam *, Lam *)
 
Ref augment_tuple (const Tuple *, Lam *, Lam *)
 
Ref augment_pack (const Pack *pack, Lam *f, Lam *f_diff)
 
- Public Member Functions inherited from mim::RWPass< AutoDiffEval, Lam >
 RWPass (PassMan &man, std::string_view name)
 
bool inspect () const override
 Should the PassMan even consider this pass?
 
Lam * curr_mut () const
 
- Public Member Functions inherited from mim::Pass
 Pass (PassMan &, std::string_view name)
 
virtual ~Pass ()=default
 
World & world ()

PassMan & man ()

const PassMan & man () const
 
std::string_view name () const
 
size_t index () const
 
virtual Ref rewrite (const Var *var)
 
virtual Ref rewrite (const Proxy *proxy)
 
virtual undo_t analyze (Ref)
 
virtual undo_t analyze (const Var *)
 
virtual undo_t analyze (const Proxy *)
 
virtual bool fixed_point () const
 
virtual void enter ()
 Invoked just before Pass::rewrite is applied to PassMan::curr_mut's body.
 
virtual void prepare ()
 Invoked once before entering the main rewrite loop.
 
const Proxy * proxy (Ref type, Defs ops, u32 tag=0)

const Proxy * isa_proxy (Ref def, u32 tag=0)
 Checks whether the given def is a Proxy whose Proxy::pass matches this Pass's Pass::index.
 
const Proxy * as_proxy (Ref def, u32 tag=0)
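The inherited hooks prepare() and enter() fire in a fixed order relative to rewriting: prepare() once before the main loop, enter() just before each mutable's body is rewritten. A minimal sketch of that ordering, assuming a single pass and a plain list of mutables. PassSketch, run_sketch, and rewrite_body are hypothetical names, and a real PassMan also handles multiple passes, proxies, undo, and fixed points, none of which are modeled here.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical pass mirroring the two documented hooks; not mim::Pass itself.
struct PassSketch {
    std::string curr;              // stands in for PassMan::curr_mut
    std::vector<std::string> log;  // records the order in which the hooks fire
    virtual ~PassSketch() = default;
    virtual void prepare() { log.push_back("prepare"); }
    virtual void enter() { log.push_back("enter " + curr); }
    virtual void rewrite_body() { log.push_back("rewrite " + curr); }
};

// Hypothetical single-pass driver: prepare() runs once before the main loop,
// and enter() runs just before each mutable's body is rewritten.
void run_sketch(PassSketch& pass, const std::vector<std::string>& muts) {
    pass.prepare();
    for (const auto& m : muts) {
        pass.curr = m;
        pass.enter();
        pass.rewrite_body();
    }
}
```

Running it over two mutables f and g logs: prepare, enter f, rewrite f, enter g, rewrite g.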
 

Detailed Description

This pass is the heart of AD.

We replace every autodiff call on a function with the differentiated function.

Definition at line 10 of file autodiff_eval.h.
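To make this replacement concrete, here is a hand-written sketch of what an autodiff call might turn into for a scalar function. It assumes the usual reverse-mode convention, where the derivative of f : A → B returns the primal result paired with a pullback B → A. This page does not spell out MimIR's exact types, so Pullback, f, and f_diff are hypothetical names, and the derivative is written by hand rather than generated by the pass.

```cpp
#include <cassert>
#include <functional>
#include <utility>

// Hypothetical illustration, not MimIR's API. Assumed convention: the derivative
// of f : A -> B returns the primal result together with a pullback B -> A that
// maps a cotangent of the output back to a cotangent of the input.
using Pullback = std::function<double(double)>;

// Original function: f(x) = x * x.
double f(double x) { return x * x; }

// Hand-written stand-in for what a call autodiff(f) would be replaced with:
// it returns f(x) and the pullback s -> s * 2x.
std::pair<double, Pullback> f_diff(double x) {
    double y = f(x);
    Pullback pb = [x](double s) { return s * 2.0 * x; };
    return {y, pb};
}
```

f_diff(3.0) yields 9.0 together with a pullback that maps a cotangent of 1.0 to 6.0, the derivative 2x at x = 3.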

Constructor & Destructor Documentation

◆ AutoDiffEval()

mim::plug::autodiff::AutoDiffEval::AutoDiffEval ( PassMan & man)
inline

Definition at line 12 of file autodiff_eval.h.

Member Function Documentation

◆ augment()

Ref mim::plug::autodiff::AutoDiffEval::augment ( Ref def,
Lam * f,
Lam * f_diff )

Applies to (open) expressions in a functional context.

Returns the rewritten expression and augments the partial and modular pullbacks. Up to renaming of variables, the rewrite is the identity on the term; beyond that, only pullbacks are added. To do so, some calls (e.g., axioms) are replaced by their derivatives. The transformation can therefore be viewed as augmenting the program with a dual computation that generates the derivatives.

Definition at line 15 of file autodiff_eval.cpp.

References augment_().

Referenced by augment_app(), augment_extract(), augment_lam(), augment_pack(), augment_tuple(), and derive_().
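The sketch below illustrates the idea behind augment on a hypothetical toy expression language. Expr, node, Augmented, and this augment are illustrative names, not MimIR's API. One deliberate simplification: MimIR rewrites IR terms, while this sketch evaluates numerically at a given point (operator-overloading style rather than source transformation). The shape carries over, though. Every subexpression keeps its primal value unchanged and gains a partial pullback, and the pullbacks compose via the chain rule.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <memory>
#include <vector>

// Hypothetical toy expression language; not MimIR's Def hierarchy.
struct Expr {
    enum Kind { Lit, Var, Add, Mul } kind;
    double value;                    // Lit: the constant
    std::size_t index;               // Var: which parameter of the enclosing function
    std::shared_ptr<Expr> lhs, rhs;  // Add / Mul: operands
};
using ExprPtr = std::shared_ptr<Expr>;

ExprPtr node(Expr::Kind k, double v = 0, std::size_t i = 0,
             ExprPtr l = nullptr, ExprPtr r = nullptr) {
    return std::make_shared<Expr>(Expr{k, v, i, l, r});
}

using Grad = std::vector<double>;              // one entry per function parameter
using Pullback = std::function<Grad(double)>;  // output cotangent -> gradient

struct Augmented {
    double value;       // primal value: exactly what plain evaluation would give
    Pullback pullback;  // partial pullback of this subexpression
};

// Augments an open expression at the point x (the values of its free parameters).
Augmented augment(const ExprPtr& e, const std::vector<double>& x) {
    const std::size_t n = x.size();
    switch (e->kind) {
        case Expr::Lit:  // constants contribute a zero pullback
            return {e->value, [n](double) { return Grad(n, 0.0); }};
        case Expr::Var: {  // a parameter routes the cotangent to its own slot
            const std::size_t i = e->index;
            assert(i < n);
            return {x[i], [n, i](double s) { Grad g(n, 0.0); g[i] = s; return g; }};
        }
        case Expr::Add:
        case Expr::Mul: {
            Augmented a = augment(e->lhs, x), b = augment(e->rhs, x);
            const bool mul = e->kind == Expr::Mul;
            // Local partial derivatives of the operation w.r.t. each operand.
            const double da = mul ? b.value : 1.0, db = mul ? a.value : 1.0;
            Pullback pa = a.pullback, pb = b.pullback;
            // Chain rule: scale the cotangent per operand, then sum the gradients.
            return {mul ? a.value * b.value : a.value + b.value,
                    [pa, pb, da, db](double s) {
                        Grad g = pa(s * da), h = pb(s * db);
                        for (std::size_t k = 0; k < g.size(); ++k) g[k] += h[k];
                        return g;
                    }};
        }
    }
    return {0.0, nullptr};  // not reached for valid kinds
}
```

For f(x0, x1) = x0 · x1 + 3 · x0 at (2, 5), augment returns the value 16, and its pullback maps a cotangent of 1 to the gradient (8, 2).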

◆ augment_()

◆ augment_app()

◆ augment_extract()

◆ augment_lam()

◆ augment_lit()

Ref mim::plug::autodiff::AutoDiffEval::augment_lit ( const Lit * lit,
Lam * f,
Lam *  )

Definition at line 15 of file autodiff_rewrite_inner.cpp.

References mim::plug::autodiff::zero_pullback().

Referenced by augment_().
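augment_lit references mim::plug::autodiff::zero_pullback(). This suggests that a literal is kept unchanged and paired with a pullback that discards the incoming cotangent, since a constant does not depend on any input. Below is a minimal sketch of that idea with hypothetical names (zero_pullback_sketch, augment_lit_sketch). It also assumes a gradient is represented as a vector over the enclosing function's inputs, which this page does not specify.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <utility>
#include <vector>

using Grad = std::vector<double>;              // gradient over the function's inputs
using Pullback = std::function<Grad(double)>;  // output cotangent -> input gradient

// Hypothetical stand-in for a zero pullback over `arity` inputs: whatever
// cotangent arrives, a constant contributes nothing to any input.
Pullback zero_pullback_sketch(std::size_t arity) {
    return [arity](double /*cotangent*/) { return Grad(arity, 0.0); };
}

// Hypothetical augment_lit: the literal is returned unchanged, paired with
// the zero pullback.
std::pair<double, Pullback> augment_lit_sketch(double lit, std::size_t arity) {
    return {lit, zero_pullback_sketch(arity)};
}
```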

◆ augment_pack()

◆ augment_tuple()

◆ augment_var()

Ref mim::plug::autodiff::AutoDiffEval::augment_var ( const Var * var,
Lam * ,
Lam *  )

Helper functions for augment.

Definition at line 21 of file autodiff_rewrite_inner.cpp.

Referenced by augment_().

◆ derive()

Ref mim::plug::autodiff::AutoDiffEval::derive ( Ref def)

Acts on top-level autodiff calls on closed terms:

  • Replaces lambdas and operators with their appropriate derivatives.
  • Creates a new lambda, associates variables, initializes the maps, and calls augment.

Definition at line 21 of file autodiff_eval.cpp.

References derive_().

Referenced by rewrite().
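derive wraps a closed term into a new function and sets up bookkeeping before augmenting. Per the member list, derive_ additionally registers the pullback and initializes the maps. The sketch below mimics that division of labour on a hypothetical toy language of unary-operation pipelines. DeriverSketch, Op, Program, and the name-keyed map are assumptions for illustration, not MimIR code. Here derive checks the map for an existing result and otherwise calls derive_, which builds the new function and records it. Inside the new function, the forward loop that records local derivatives plays the role of augment.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <functional>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy "closed term": a pipeline of primitive unary operations.
enum class Op { Square, Sin, Exp };
using Program = std::vector<Op>;

using Pullback = std::function<double(double)>;
using Derived = std::function<std::pair<double, Pullback>(double)>;

class DeriverSketch {
public:
    // Top level: reuse an already derived function, otherwise derive it now.
    Derived derive(const std::string& name, const Program& prog) {
        if (auto it = derived_.find(name); it != derived_.end()) return it->second;
        return derive_(name, prog);
    }
    std::size_t num_derived() const { return derived_.size(); }

private:
    // Creates a fresh function that runs the program forward, records each
    // step's local derivative, and returns the value with its pullback.
    // The new function is registered in the map before it is returned.
    Derived derive_(const std::string& name, const Program& prog) {
        Derived d = [prog](double x) {
            std::vector<double> partials;
            double v = x;
            for (Op op : prog) {
                switch (op) {
                    case Op::Square: partials.push_back(2.0 * v); v = v * v; break;
                    case Op::Sin: partials.push_back(std::cos(v)); v = std::sin(v); break;
                    case Op::Exp: v = std::exp(v); partials.push_back(v); break;
                }
            }
            // Reverse sweep: chain the local derivatives from last step to first.
            Pullback pb = [partials](double s) {
                for (auto it = partials.rbegin(); it != partials.rend(); ++it) s *= *it;
                return s;
            };
            return std::make_pair(v, pb);
        };
        derived_.emplace(name, d);
        return d;
    }

    std::map<std::string, Derived> derived_;
};
```

Deriving {Square, Square} (x⁴) and evaluating at 2 gives 16 with a pullback that maps 1 to 32. Asking for an already derived name returns the stored function instead of building a new one.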

◆ derive_()

Ref mim::plug::autodiff::AutoDiffEval::derive_ ( Ref def)

◆ rewrite()

Ref mim::plug::autodiff::AutoDiffEval::rewrite ( Ref def)
override virtual

Detect autodiff calls.

Reimplemented from mim::Pass.

Definition at line 27 of file autodiff_eval.cpp.

References derive(), mim::match(), and mim::Pass::world().
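rewrite detects autodiff calls (it references mim::match()) and hands them to derive. The sketch below shows that detect-and-replace step on a hypothetical toy term type. Term, mk, derive_sketch, and rewrite_sketch are illustrative names, and a string comparison on the head symbol stands in for mim::match. For self-containment, the sketch traverses the term itself. In the pass framework documented here, the PassMan presumably drives traversal (compare enter()) and passes individual definitions to rewrite.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy term: a head symbol applied to argument terms.
struct Term {
    std::string head;
    std::vector<std::shared_ptr<Term>> args;
};
using TermPtr = std::shared_ptr<Term>;

TermPtr mk(std::string head, std::vector<TermPtr> args = {}) {
    return std::make_shared<Term>(Term{std::move(head), std::move(args)});
}

// Placeholder for derive(): it only renames the function to mark it as derived.
TermPtr derive_sketch(const TermPtr& fn) { return mk(fn->head + "_diff"); }

// Walks the term top-down. A node of the form autodiff(fn) is replaced as a
// whole by derive_sketch(fn); every other node is rebuilt with rewritten arguments.
TermPtr rewrite_sketch(const TermPtr& t) {
    if (t->head == "autodiff" && t->args.size() == 1) return derive_sketch(t->args[0]);
    std::vector<TermPtr> args;
    for (const auto& a : t->args) args.push_back(rewrite_sketch(a));
    return mk(t->head, std::move(args));
}
```

Rewriting app(autodiff(f), x) yields app(f_diff, x) and leaves the original term untouched.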

