14 auto& world = type->
world();
15 return world.
raw_app(type, callee, arg);
19 auto& world = type->
world();
21 if (ad_ty)
return ad_ty;
22 return world.raw_app(type, callee, arg);
31 auto& world = type->
world();
32 return world.
raw_app(type, callee, arg);
38 auto& world = type->
world();
44 auto T = callee->as<
App>()->arg();
45 auto [a, b] = arg->
projs<2>();
47 world.DLOG(
"add {} {} {}", T, a, b);
58 if (
auto sig = T->isa<
Sigma>()) {
59 world.DLOG(
"add tuple");
60 auto p = sig->num_ops();
61 auto ops =
DefVec(p, [&](
size_t i) {
62 return world.app(world.app(world.annex<
add>(), sig->op(i)), {a->proj(i), b->proj(i)});
64 return world.tuple(ops);
65 }
else if (
auto arr = T->isa<
Arr>()) {
67 world.DLOG(
"add arrays {} {} {}", T, a, b);
68 auto pack = world.mut_pack(T);
69 auto body_type = arr->body();
70 world.DLOG(
"body type {}", body_type);
71 pack->set(world.app(world.app(world.annex<
add>(), body_type),
72 {world.extract(a, pack->var()), world.extract(b, pack->var())}));
73 world.DLOG(
"pack {}", pack);
76 world.DLOG(
"add int");
77 auto width =
Lit::as(world.iinfer(a));
78 world.DLOG(
"width {}", width);
80 world.DLOG(
"int add {} : {}", int_add, world.iinfer(int_add));
82 }
else if (T->isa<
App>()) {
83 assert(0 &&
"not handled");
87 return world.raw_app(type, callee, arg);
91 auto& world = type->
world();
93 auto [count, T] = callee->as<
App>()->args<2>();
95 if (
auto lit = count->isa<
Lit>()) {
96 auto val = lit->get<
nat_t>();
97 world.DLOG(
"val: {}", val);
98 auto args = arg->
projs(val);
99 auto sum = world.app(world.annex<
zero>(), T);
101 if (val >= 1)
sum = args[0];
102 for (
size_t i = 1; i < val; ++i)
sum = world.app(world.app(world.annex<
add>(), T), {sum, args[i]});
107 return world.raw_app(type, callee, arg);
#define MIM_autodiff_NORMALIZER_IMPL
A (possibly parameterized) Array.
auto projs(F f) const
Splits this Def via Def::projections into an Array (if A == -1_n) or std::array (otherwise).
static Ref size(Ref def)
Checks whether def is an Idx s and returns s, or nullptr otherwise.
Helper class to retrieve Infer::arg if present.
This is a thin wrapper for std::span<T, N> with the following additional features:
Ref raw_app(Ref type, Ref callee, Ref arg)
The automatic differentiation Plugin
Ref normalize_Tangent(Ref, Ref, Ref arg)
const Def * autodiff_type_fun(const Def *)
Ref normalize_sum(Ref type, Ref callee, Ref arg)
const Def * tangent_type_fun(const Def *)
Ref normalize_ad(Ref type, Ref callee, Ref arg)
Currently this normalizer does nothing.
Ref normalize_AD(Ref type, Ref callee, Ref arg)
Ref normalize_add(Ref type, Ref callee, Ref arg)
Currently resolves the full addition.
Ref normalize_zero(Ref type, Ref callee, Ref arg)
Currently this normalizer does nothing.
Vector< const Def * > DefVec