MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
normalizers.cpp
Go to the documentation of this file.
1#include "mim/axiom.h"
2#include "mim/world.h"
3
6
7namespace mim::plug::autodiff {
8
9/// Currently this normalizer does nothin.
10/// TODO: Maybe we want to handle trivial lookup replacements here.
11const Def* normalize_ad(const Def*, const Def*, const Def*) { return {}; }
12
13const Def* normalize_AD(const Def*, const Def*, const Def* arg) {
14 auto ad_ty = autodiff_type_fun(arg);
15 if (ad_ty) return ad_ty;
16 return {};
17}
18
19const Def* normalize_Tangent(const Def*, const Def*, const Def* arg) { return tangent_type_fun(arg); }
20
21/// Currently this normalizer does nothing.
22/// We usually want to keep zeros as long as possible to avoid unnecessary allocations.
23/// A high-level addition with zero can be shortened directly.
24const Def* normalize_zero(const Def*, const Def*, const Def*) { return {}; }
25
26/// Currently resolved the full addition.
27/// There is no benefit in keeping additions around longer than necessary.
28const Def* normalize_add(const Def* type, const Def* callee, const Def* arg) {
29 auto& world = type->world();
30
31 // TODO: add tuple -> tuple of adds
32 // TODO: add zero -> other
33 // TODO: unify mapping over structure with other aspects like zero
34
35 auto T = callee->as<App>()->arg();
36 auto [a, b] = arg->projs<2>();
37
38 world.DLOG("add {} {} {}", T, a, b);
39
40 if (match<zero>(a)) {
41 world.DLOG("0+b");
42 return b;
43 }
44 if (match<zero>(b)) {
45 world.DLOG("0+a");
46 return a;
47 }
48 // A value level match would be harder as a tuple might in reality be a var or extract
49 if (auto sig = T->isa<Sigma>()) {
50 world.DLOG("add tuple");
51 auto p = sig->num_ops(); // TODO: or num_projs
52 auto ops = DefVec(p, [&](size_t i) {
53 return world.app(world.app(world.annex<add>(), sig->op(i)), {a->proj(i), b->proj(i)});
54 });
55 return world.tuple(ops);
56 } else if (auto arr = T->isa<Arr>()) {
57 // TODO: is this working for non-lit (non-tuple) or do we need a loop?
58 world.DLOG("add arrays {} {} {}", T, a, b);
59 auto pack = world.mut_pack(T);
60 auto body_type = arr->body();
61 world.DLOG("body type {}", body_type);
62 pack->set(world.app(world.app(world.annex<add>(), body_type),
63 {world.extract(a, pack->var()), world.extract(b, pack->var())}));
64 world.DLOG("pack {}", pack);
65 return pack;
66 } else if (Idx::isa(type)) {
67 world.DLOG("add int");
68 auto width = Idx::as_lit(a->type());
69 world.DLOG("width {}", width);
70 auto int_add = world.call(core::wrap::add, 0_n, Defs{a, b});
71 world.DLOG("int add {} : {}", int_add, Idx::isa(int_add->type()));
72 return int_add;
73 } else if (T->isa<App>()) {
74 assert(0 && "not handled");
75 }
76 // TODO: mem stays here (only resolved after direct simplification)
77
78 return {};
79}
80
81const Def* normalize_sum(const Def* type, const Def* callee, const Def* arg) {
82 auto& world = type->world();
83
84 auto [count, T] = callee->as<App>()->args<2>();
85
86 if (auto lit = count->isa<Lit>()) {
87 auto val = lit->get<nat_t>();
88 world.DLOG("val: {}", val);
89 auto args = arg->projs(val);
90 auto sum = world.app(world.annex<zero>(), T);
91 // This special case would also be handled by add zero
92 if (val >= 1) sum = args[0];
93 for (size_t i = 1; i < val; ++i) sum = world.app(world.app(world.annex<add>(), T), {sum, args[i]});
94 return sum;
95 }
96 assert(0);
97 return {};
98}
99
101
102} // namespace mim::plug::autodiff
#define MIM_autodiff_NORMALIZER_IMPL
Definition autogen.h:90
A (possibly parameterized) Array.
Definition tuple.h:68
Base class for all Defs.
Definition def.h:198
World & world() const noexcept
Definition def.cpp:413
auto projs(F f) const
Splits this Def via Def::projections into an Array (if A == std::dynamic_extent) or std::array (other...
Definition def.h:345
static nat_t as_lit(const Def *def)
Definition def.h:781
static const Def * isa(const Def *def)
Checks if def is a Idx s and returns s or nullptr otherwise.
Definition def.cpp:555
A dependent tuple type.
Definition tuple.h:9
The automatic differentiation Plugin
Definition autodiff.h:6
const Def * normalize_Tangent(const Def *, const Def *, const Def *arg)
const Def * normalize_add(const Def *type, const Def *callee, const Def *arg)
Currently resolves the full addition.
const Def * autodiff_type_fun(const Def *)
Definition autodiff.cpp:113
const Def * normalize_AD(const Def *, const Def *, const Def *arg)
const Def * normalize_ad(const Def *, const Def *, const Def *)
Currently this normalizer does nothing.
const Def * tangent_type_fun(const Def *)
Definition autodiff.cpp:57
const Def * normalize_sum(const Def *type, const Def *callee, const Def *arg)
const Def * normalize_zero(const Def *, const Def *, const Def *)
Currently this normalizer does nothing.
View< const Def * > Defs
Definition def.h:49
u64 nat_t
Definition types.h:43
Vector< const Def * > DefVec
Definition def.h:50
auto match(const Def *def)
Definition axiom.h:112