MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
normalizers.cpp
Go to the documentation of this file.
2#include <mim/plug/mem/mem.h>
3
4#include "mim/axm.h"
5#include "mim/world.h"
6
8
9namespace mim::plug::autodiff {
10
11/// Currently this normalizer does nothin.
12/// TODO: Maybe we want to handle trivial lookup replacements here.
13const Def* normalize_ad(const Def*, const Def*, const Def*) { return {}; }
14
15const Def* normalize_AD(const Def*, const Def*, const Def* arg) {
16 auto ad_ty = autodiff_type_fun(arg);
17 if (ad_ty) return ad_ty;
18 return {};
19}
20
21const Def* normalize_Tangent(const Def*, const Def*, const Def* arg) { return tangent_type_fun(arg); }
22
23/// Currently this normalizer does nothing.
24/// We usually want to keep zeros as long as possible to avoid unnecessary allocations.
25/// A high-level addition with zero can be shortened directly.
26const Def* normalize_zero(const Def*, const Def*, const Def*) { return {}; }
27
28/// Currently resolved the full addition.
29/// There is no benefit in keeping additions around longer than necessary.
30const Def* normalize_add(const Def* type, const Def* callee, const Def* arg) {
31 auto& world = type->world();
32
33 // TODO: add tuple -> tuple of adds
34 // TODO: add zero -> other
35 // TODO: unify mapping over structure with other aspects like zero
36
37 auto T = callee->as<App>()->arg();
38 auto [a, b] = arg->projs<2>();
39
40 world.DLOG("add {} {} {}", T, a, b);
41
42 if (Axm::isa<zero>(a)) {
43 world.DLOG("0+b");
44 return b;
45 }
46 if (Axm::isa<zero>(b)) {
47 world.DLOG("0+a");
48 return a;
49 }
50 // A value level match would be harder as a tuple might in reality be a var or extract
51 if (auto sig = T->isa<Sigma>()) {
52 world.DLOG("add tuple");
53 auto p = sig->num_ops(); // TODO: or num_projs
54 auto ops = DefVec(p, [&](size_t i) {
55 return world.app(world.app(world.annex<add>(), sig->op(i)), {a->proj(i), b->proj(i)});
56 });
57 return world.tuple(ops);
58 } else if (auto arr = T->isa<Arr>()) {
59 // TODO: is this working for non-lit (non-tuple) or do we need a loop?
60 world.DLOG("add arrays {} {} {}", T, a, b);
61 auto pack = world.mut_pack(T);
62 auto body_type = arr->body();
63 world.DLOG("body type {}", body_type);
64 pack->set(world.app(world.app(world.annex<add>(), body_type),
65 {world.extract(a, pack->var()), world.extract(b, pack->var())}));
66 world.DLOG("pack {}", pack);
67 return pack;
68 } else if (Idx::isa(type)) {
69 world.DLOG("add int");
70 auto width = Idx::as_lit(a->type());
71 world.DLOG("width {}", width);
72 auto int_add = world.call(core::wrap::add, 0_n, Defs{a, b});
73 world.DLOG("int add {} : {}", int_add, Idx::isa(int_add->type()));
74 return int_add;
75 } else if (Axm::isa<mem::M>(type)) {
76 // TODO: mem stays here (only resolved after direct simplification)
77 return {};
78 } else if (T->isa<App>()) {
79 assert(0 && "not handled");
80 }
81
82 return {};
83}
84
85const Def* normalize_sum(const Def* type, const Def* callee, const Def* arg) {
86 auto& world = type->world();
87
88 auto [count, T] = callee->as<App>()->args<2>();
89
90 if (auto lit = count->isa<Lit>()) {
91 auto val = lit->get<nat_t>();
92 world.DLOG("val: {}", val);
93 auto args = arg->projs(val);
94 auto sum = world.app(world.annex<zero>(), T);
95 // This special case would also be handled by add zero
96 if (val >= 1) sum = args[0];
97 for (size_t i = 1; i < val; ++i)
98 sum = world.app(world.app(world.annex<add>(), T), {sum, args[i]});
99 return sum;
100 }
101 assert(0);
102 return {};
103}
104
106
107} // namespace mim::plug::autodiff
#define MIM_autodiff_NORMALIZER_IMPL
Definition autogen.h:76
A (possibly parameterized) Array.
Definition tuple.h:117
static auto isa(const Def *def)
Definition axm.h:107
Base class for all Defs.
Definition def.h:251
auto projs(F f) const
Splits this Def via Def::projections into an Array (if A == std::dynamic_extent) or std::array (other...
Definition def.h:390
static nat_t as_lit(const Def *def)
Definition def.h:882
static const Def * isa(const Def *def)
Checks if def is an Idx s and returns s, or nullptr otherwise.
Definition def.cpp:610
A dependent tuple type.
Definition tuple.h:20
The automatic differentiation Plugin
Definition autodiff.h:6
const Def * normalize_Tangent(const Def *, const Def *, const Def *arg)
const Def * normalize_add(const Def *type, const Def *callee, const Def *arg)
Currently resolves the full addition.
const Def * autodiff_type_fun(const Def *)
Definition autodiff.cpp:111
const Def * normalize_AD(const Def *, const Def *, const Def *arg)
const Def * normalize_ad(const Def *, const Def *, const Def *)
Currently this normalizer does nothing.
const Def * tangent_type_fun(const Def *)
Definition autodiff.cpp:55
const Def * normalize_sum(const Def *type, const Def *callee, const Def *arg)
const Def * normalize_zero(const Def *, const Def *, const Def *)
Currently this normalizer does nothing.
View< const Def * > Defs
Definition def.h:76
u64 nat_t
Definition types.h:43
Vector< const Def * > DefVec
Definition def.h:77