17 if (ad_ty)
return ad_ty;
31 auto& world = type->world();
37 auto T = callee->as<
App>()->arg();
38 auto [a, b] = arg->
projs<2>();
40 world.DLOG(
"add {} {} {}", T, a, b);
51 if (
auto sig = T->isa<
Sigma>()) {
52 world.DLOG(
"add tuple");
53 auto p = sig->num_ops();
54 auto ops =
DefVec(p, [&](
size_t i) {
55 return world.app(world.app(world.annex<
add>(), sig->op(i)), {a->proj(i), b->proj(i)});
57 return world.tuple(ops);
58 }
else if (
auto arr = T->isa<
Arr>()) {
60 world.DLOG(
"add arrays {} {} {}", T, a, b);
61 auto pack = world.mut_pack(T);
62 auto body_type = arr->body();
63 world.DLOG(
"body type {}", body_type);
64 pack->set(world.app(world.app(world.annex<
add>(), body_type),
65 {world.extract(a, pack->var()), world.extract(b, pack->var())}));
66 world.DLOG(
"pack {}", pack);
69 world.DLOG(
"add int");
71 world.DLOG(
"width {}", width);
73 world.DLOG(
"int add {} : {}", int_add,
Idx::isa(int_add->type()));
78 }
else if (T->isa<
App>()) {
79 assert(0 &&
"not handled");
86 auto& world = type->world();
88 auto [count, T] = callee->as<
App>()->args<2>();
90 if (
auto lit = count->isa<
Lit>()) {
91 auto val = lit->get<
nat_t>();
92 world.DLOG(
"val: {}", val);
93 auto args = arg->
projs(val);
94 auto sum = world.app(world.annex<
zero>(), T);
96 if (val >= 1)
sum = args[0];
97 for (
size_t i = 1; i < val; ++i)
98 sum = world.app(world.app(world.annex<
add>(), T), {sum, args[i]});
#define MIM_autodiff_NORMALIZER_IMPL
A (possibly parameterized) Array.
static auto isa(const Def *def)
auto projs(F f) const
Splits this Def via Def::projections into an Array (if A == std::dynamic_extent) or std::array (otherwise).
static nat_t as_lit(const Def *def)
static const Def * isa(const Def *def)
Checks if def is an Idx s and returns s, or nullptr otherwise.
The automatic differentiation Plugin
const Def * normalize_Tangent(const Def *, const Def *, const Def *arg)
const Def * normalize_add(const Def *type, const Def *callee, const Def *arg)
Currently resolves the full addition.
const Def * autodiff_type_fun(const Def *)
const Def * normalize_AD(const Def *, const Def *, const Def *arg)
const Def * normalize_ad(const Def *, const Def *, const Def *)
Currently this normalizer does nothing.
const Def * tangent_type_fun(const Def *)
const Def * normalize_sum(const Def *type, const Def *callee, const Def *arg)
const Def * normalize_zero(const Def *, const Def *, const Def *)
Currently this normalizer does nothing.
Vector< const Def * > DefVec