MimIR 0.1
MimIR is my Intermediate Representation
Loading...
Searching...
No Matches
lower_matrix_highlevel.cpp
Go to the documentation of this file.
2
3#include <iostream>
4
5#include <mim/lam.h>
6
11#include "mim/plug/mem/mem.h"
12
13namespace mim::plug::matrix {
14
15namespace {
16
17// clang-format off
18absl::flat_hash_map<flags_t, flags_t> axm_to_impl_map = {
22};
23// clang-format on
24
25std::optional<const Def*> internal_function_of_axm(const Axm* axm, const Def* meta_args, const Def* args) {
26 auto& world = axm->world();
27 if (auto it = axm_to_impl_map.find(axm->flags()); it != axm_to_impl_map.end()) {
28 const Def* spec_fun = world.implicit_app(world.flags2annex().at(it->second), meta_args);
29 auto ds_fun = direct::op_cps2ds_dep(spec_fun);
30 return world.app(ds_fun, args);
31 }
32 return std::nullopt;
33}
34
35} // namespace
36
38 if (auto i = rewritten.find(def); i != rewritten.end()) return i->second;
39 auto new_def = rewrite_(def);
40 rewritten[def] = new_def;
41 return rewritten[def];
42}
43
45 if (auto mat_ax = Axm::isa<matrix::prod>(def)) {
46 auto [mem, M, N] = mat_ax->args<3>();
47 auto [m, k, l, w] = mat_ax->decurry()->args<4>();
48 auto w_lit = Lit::isa(w);
49
50 auto ext_fun = world().externals()[world().sym("extern_matrix_prod")];
51 if (ext_fun && (w_lit && *w_lit == 64)) {
52 auto ds_fun = direct::op_cps2ds_dep(ext_fun);
53 auto fun_app = world().app(ds_fun, {mem, m, k, l, M, N});
54 return fun_app;
55 }
56 }
57
58 if (auto outer_app = def->isa<App>()) {
59 if (auto inner_app = outer_app->callee()->isa<App>()) {
60 if (auto axm = inner_app->callee()->isa<Axm>()) {
61 if (auto internal_function = internal_function_of_axm(axm, inner_app->arg(), outer_app->arg())) {
62 DLOG("lower matrix axm {} in {} : {}", *axm->sym(), def, def->type());
63 DLOG("lower matrix axm using: {} : {}", *internal_function, (*internal_function)->type());
64 return *internal_function;
65 }
66 }
67 }
68 }
69
70 return def;
71}
72
73} // namespace mim::plug::matrix
Definition axm.h:9
static auto isa(const Def *def)
Definition axm.h:107
Base class for all Defs.
Definition def.h:251
const Def * type() const noexcept
Yields the "raw" type of this Def (maybe nullptr).
Definition def.cpp:446
static std::optional< T > isa(const Def *def)
Definition def.h:826
World & world()
Definition pass.h:64
const Def * app(const Def *callee, const Def *arg)
Definition world.cpp:201
Sym sym(std::string_view)
Definition world.cpp:90
const Externals & externals() const
Definition world.h:236
const Def * rewrite(const Def *) override
Custom rewrite function; memoized version of rewrite_.
#define M(S, D)
#define DLOG(...)
Vaporizes to nothingness in Debug build.
Definition log.h:95
const Def * op_cps2ds_dep(const Def *k)
Definition direct.h:16
The matrix Plugin
Definition matrix.h:7
The mem Plugin
Definition mem.h:11
u64 flags_t
Definition types.h:46
@ Axm
Definition def.h:114
static constexpr flags_t Base
Definition plugin.h:117