MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
normalizers.cpp
Go to the documentation of this file.
1#include <iostream>
2
3#include "mim/axiom.h"
4#include "mim/world.h"
5
8
9namespace mim::plug::autodiff {
10
11/// Currently this normalizer does nothin.
12/// TODO: Maybe we want to handle trivial lookup replacements here.
13Ref normalize_ad(Ref type, Ref callee, Ref arg) {
14 auto& world = type->world();
15 return world.raw_app(type, callee, arg);
16}
17
18Ref normalize_AD(Ref type, Ref callee, Ref arg) {
19 auto& world = type->world();
20 auto ad_ty = autodiff_type_fun(arg);
21 if (ad_ty) return ad_ty;
22 return world.raw_app(type, callee, arg);
23}
24
26
27/// Currently this normalizer does nothing.
28/// We usually want to keep zeros as long as possible to avoid unnecessary allocations.
29/// A high-level addition with zero can be shortened directly.
30Ref normalize_zero(Ref type, Ref callee, Ref arg) {
31 auto& world = type->world();
32 return world.raw_app(type, callee, arg);
33}
34
35/// Currently resolved the full addition.
36/// There is no benefit in keeping additions around longer than necessary.
37Ref normalize_add(Ref type, Ref callee, Ref arg) {
38 auto& world = type->world();
39
40 // TODO: add tuple -> tuple of adds
41 // TODO: add zero -> other
42 // TODO: unify mapping over structure with other aspects like zero
43
44 auto T = callee->as<App>()->arg();
45 auto [a, b] = arg->projs<2>();
46
47 world.DLOG("add {} {} {}", T, a, b);
48
49 if (match<zero>(a)) {
50 world.DLOG("0+b");
51 return b;
52 }
53 if (match<zero>(b)) {
54 world.DLOG("0+a");
55 return a;
56 }
57 // A value level match would be harder as a tuple might in reality be a var or extract
58 if (auto sig = T->isa<Sigma>()) {
59 world.DLOG("add tuple");
60 auto p = sig->num_ops(); // TODO: or num_projs
61 auto ops = DefVec(p, [&](size_t i) {
62 return world.app(world.app(world.annex<add>(), sig->op(i)), {a->proj(i), b->proj(i)});
63 });
64 return world.tuple(ops);
65 } else if (auto arr = T->isa<Arr>()) {
66 // TODO: is this working for non-lit (non-tuple) or do we need a loop?
67 world.DLOG("add arrays {} {} {}", T, a, b);
68 auto pack = world.mut_pack(T);
69 auto body_type = arr->body();
70 world.DLOG("body type {}", body_type);
71 pack->set(world.app(world.app(world.annex<add>(), body_type),
72 {world.extract(a, pack->var()), world.extract(b, pack->var())}));
73 world.DLOG("pack {}", pack);
74 return pack;
75 } else if (Idx::size(type)) {
76 world.DLOG("add int");
77 auto width = Lit::as(world.iinfer(a));
78 world.DLOG("width {}", width);
79 auto int_add = world.call(core::wrap::add, 0_n, Defs{a, b});
80 world.DLOG("int add {} : {}", int_add, world.iinfer(int_add));
81 return int_add;
82 } else if (T->isa<App>()) {
83 assert(0 && "not handled");
84 }
85 // TODO: mem stays here (only resolved after direct simplification)
86
87 return world.raw_app(type, callee, arg);
88}
89
90Ref normalize_sum(Ref type, Ref callee, Ref arg) {
91 auto& world = type->world();
92
93 auto [count, T] = callee->as<App>()->args<2>();
94
95 if (auto lit = count->isa<Lit>()) {
96 auto val = lit->get<nat_t>();
97 world.DLOG("val: {}", val);
98 auto args = arg->projs(val);
99 auto sum = world.app(world.annex<zero>(), T);
100 // This special case would also be handled by add zero
101 if (val >= 1) sum = args[0];
102 for (size_t i = 1; i < val; ++i) sum = world.app(world.app(world.annex<add>(), T), {sum, args[i]});
103 return sum;
104 }
105 assert(0);
106
107 return world.raw_app(type, callee, arg);
108}
109
111
112} // namespace mim::plug::autodiff
#define MIM_autodiff_NORMALIZER_IMPL
Definition autogen.h:90
A (possibly parameterized) Array.
Definition tuple.h:66
World & world() const
Definition def.cpp:415
auto projs(F f) const
Splits this Def via Def::projections into an Array (if A == -1_n) or std::array (otherwise).
Definition def.h:367
static Ref size(Ref def)
Checks if def is an `Idx s` and returns `s`, or `nullptr` otherwise.
Definition def.cpp:556
static T as(Ref def)
Definition def.h:768
Helper class to retrieve Infer::arg if present.
Definition def.h:86
A dependent tuple type.
Definition tuple.h:9
This is a thin wrapper for std::span<T, N> with the following additional features:
Definition span.h:28
Ref raw_app(Ref type, Ref callee, Ref arg)
Definition world.cpp:217
The automatic differentiation Plugin
Definition autodiff.h:6
Ref normalize_Tangent(Ref, Ref, Ref arg)
const Def * autodiff_type_fun(const Def *)
Definition autodiff.cpp:113
Ref normalize_sum(Ref type, Ref callee, Ref arg)
const Def * tangent_type_fun(const Def *)
Definition autodiff.cpp:57
Ref normalize_ad(Ref type, Ref callee, Ref arg)
Currently this normalizer does nothing.
Ref normalize_AD(Ref type, Ref callee, Ref arg)
Ref normalize_add(Ref type, Ref callee, Ref arg)
Currently resolves the full addition.
Ref normalize_zero(Ref type, Ref callee, Ref arg)
Currently this normalizer does nothing.
u64 nat_t
Definition types.h:43
Vector< const Def * > DefVec
Definition def.h:62
auto match(Ref def)
Definition axiom.h:112