15 if (ad_ty)
return ad_ty;
29 auto& world = type->
world();
35 auto T = callee->as<
App>()->arg();
36 auto [a, b] = arg->
projs<2>();
38 world.DLOG(
"add {} {} {}", T, a, b);
49 if (
auto sig = T->isa<
Sigma>()) {
50 world.DLOG(
"add tuple");
51 auto p = sig->num_ops();
52 auto ops =
DefVec(p, [&](
size_t i) {
53 return world.app(world.app(world.annex<
add>(), sig->op(i)), {a->proj(i), b->proj(i)});
55 return world.tuple(ops);
56 }
else if (
auto arr = T->isa<
Arr>()) {
58 world.DLOG(
"add arrays {} {} {}", T, a, b);
59 auto pack = world.mut_pack(T);
60 auto body_type = arr->body();
61 world.DLOG(
"body type {}", body_type);
62 pack->set(world.app(world.app(world.annex<
add>(), body_type),
63 {world.extract(a, pack->var()), world.extract(b, pack->var())}));
64 world.DLOG(
"pack {}", pack);
67 world.DLOG(
"add int");
69 world.DLOG(
"width {}", width);
71 world.DLOG(
"int add {} : {}", int_add,
Idx::isa(int_add->type()));
73 }
else if (T->isa<
App>()) {
74 assert(0 &&
"not handled");
82 auto& world = type->
world();
84 auto [count, T] = callee->as<
App>()->args<2>();
86 if (
auto lit = count->isa<
Lit>()) {
87 auto val = lit->get<
nat_t>();
88 world.DLOG(
"val: {}", val);
89 auto args = arg->
projs(val);
90 auto sum = world.app(world.annex<
zero>(), T);
92 if (val >= 1)
sum = args[0];
93 for (
size_t i = 1; i < val; ++i)
sum = world.app(world.app(world.annex<
add>(), T), {sum, args[i]});
#define MIM_autodiff_NORMALIZER_IMPL
A (possibly parameterized) Array.
auto projs(F f) const
Splits this Def via Def::projections into an Array (if A == std::dynamic_extent) or std::array (other...
static nat_t as_lit(Ref def)
Helper class to retrieve Infer::arg if present.
This is a thin wrapper for std::span<T, N> with the following additional features:
The automatic differentiation Plugin
Ref normalize_Tangent(Ref, Ref, Ref arg)
Ref normalize_zero(Ref, Ref, Ref)
Currently this normalizer does nothing.
const Def * autodiff_type_fun(const Def *)
Ref normalize_sum(Ref type, Ref callee, Ref arg)
const Def * tangent_type_fun(const Def *)
Ref normalize_AD(Ref, Ref, Ref arg)
Ref normalize_ad(Ref, Ref, Ref)
Currently this normalizer does nothing.
Ref normalize_add(Ref type, Ref callee, Ref arg)
Currently resolves the full addition.
Vector< const Def * > DefVec