MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
normalizers.cpp
Go to the documentation of this file.
1#include "mim/axiom.h"
2#include "mim/world.h"
3
6
7namespace mim::plug::autodiff {
8
9/// Currently this normalizer does nothin.
10/// TODO: Maybe we want to handle trivial lookup replacements here.
11Ref normalize_ad(Ref, Ref, Ref) { return {}; }
12
14 auto ad_ty = autodiff_type_fun(arg);
15 if (ad_ty) return ad_ty;
16 return {};
17}
18
20
21/// Currently this normalizer does nothing.
22/// We usually want to keep zeros as long as possible to avoid unnecessary allocations.
23/// A high-level addition with zero can be shortened directly.
24Ref normalize_zero(Ref, Ref, Ref) { return {}; }
25
26/// Currently resolved the full addition.
27/// There is no benefit in keeping additions around longer than necessary.
28Ref normalize_add(Ref type, Ref callee, Ref arg) {
29 auto& world = type->world();
30
31 // TODO: add tuple -> tuple of adds
32 // TODO: add zero -> other
33 // TODO: unify mapping over structure with other aspects like zero
34
35 auto T = callee->as<App>()->arg();
36 auto [a, b] = arg->projs<2>();
37
38 world.DLOG("add {} {} {}", T, a, b);
39
40 if (match<zero>(a)) {
41 world.DLOG("0+b");
42 return b;
43 }
44 if (match<zero>(b)) {
45 world.DLOG("0+a");
46 return a;
47 }
48 // A value level match would be harder as a tuple might in reality be a var or extract
49 if (auto sig = T->isa<Sigma>()) {
50 world.DLOG("add tuple");
51 auto p = sig->num_ops(); // TODO: or num_projs
52 auto ops = DefVec(p, [&](size_t i) {
53 return world.app(world.app(world.annex<add>(), sig->op(i)), {a->proj(i), b->proj(i)});
54 });
55 return world.tuple(ops);
56 } else if (auto arr = T->isa<Arr>()) {
57 // TODO: is this working for non-lit (non-tuple) or do we need a loop?
58 world.DLOG("add arrays {} {} {}", T, a, b);
59 auto pack = world.mut_pack(T);
60 auto body_type = arr->body();
61 world.DLOG("body type {}", body_type);
62 pack->set(world.app(world.app(world.annex<add>(), body_type),
63 {world.extract(a, pack->var()), world.extract(b, pack->var())}));
64 world.DLOG("pack {}", pack);
65 return pack;
66 } else if (Idx::isa(type)) {
67 world.DLOG("add int");
68 auto width = Idx::as_lit(a->type());
69 world.DLOG("width {}", width);
70 auto int_add = world.call(core::wrap::add, 0_n, Defs{a, b});
71 world.DLOG("int add {} : {}", int_add, Idx::isa(int_add->type()));
72 return int_add;
73 } else if (T->isa<App>()) {
74 assert(0 && "not handled");
75 }
76 // TODO: mem stays here (only resolved after direct simplification)
77
78 return {};
79}
80
81Ref normalize_sum(Ref type, Ref callee, Ref arg) {
82 auto& world = type->world();
83
84 auto [count, T] = callee->as<App>()->args<2>();
85
86 if (auto lit = count->isa<Lit>()) {
87 auto val = lit->get<nat_t>();
88 world.DLOG("val: {}", val);
89 auto args = arg->projs(val);
90 auto sum = world.app(world.annex<zero>(), T);
91 // This special case would also be handled by add zero
92 if (val >= 1) sum = args[0];
93 for (size_t i = 1; i < val; ++i) sum = world.app(world.app(world.annex<add>(), T), {sum, args[i]});
94 return sum;
95 }
96 assert(0);
97 return {};
98}
99
101
102} // namespace mim::plug::autodiff
#define MIM_autodiff_NORMALIZER_IMPL
Definition autogen.h:90
A (possibly parameterized) Array.
Definition tuple.h:67
World & world() const
Definition def.cpp:411
auto projs(F f) const
Splits this Def via Def::projections into an Array (if A == std::dynamic_extent) or std::array (other...
Definition def.h:361
static Ref isa(Ref def)
Definition def.cpp:552
static nat_t as_lit(Ref def)
Definition def.h:812
Helper class to retrieve Infer::arg if present.
Definition def.h:86
A dependent tuple type.
Definition tuple.h:9
This is a thin wrapper for std::span<T, N> with the following additional features:
Definition span.h:28
The automatic differentiation Plugin
Definition autodiff.h:6
Ref normalize_Tangent(Ref, Ref, Ref arg)
Ref normalize_zero(Ref, Ref, Ref)
Currently this normalizer does nothing.
const Def * autodiff_type_fun(const Def *)
Definition autodiff.cpp:113
Ref normalize_sum(Ref type, Ref callee, Ref arg)
const Def * tangent_type_fun(const Def *)
Definition autodiff.cpp:57
Ref normalize_AD(Ref, Ref, Ref arg)
Ref normalize_ad(Ref, Ref, Ref)
Currently this normalizer does nothing.
Ref normalize_add(Ref type, Ref callee, Ref arg)
Currently resolves the full addition.
u64 nat_t
Definition types.h:43
Vector< const Def * > DefVec
Definition def.h:62
auto match(Ref def)
Definition axiom.h:112