Thorin 1.9.0
The Higher ORder INtermediate representation
Loading...
Searching...
No Matches
lower_matrix_highlevel.cpp
Go to the documentation of this file.
2
3#include <iostream>
4
5#include <thorin/lam.h>
6
11#include "thorin/plug/mem/mem.h"
12
13namespace thorin::plug::matrix {
14
15namespace {
16
17std::optional<Ref> internal_function_of_axiom(const Axiom* axiom, Ref meta_args, Ref args) {
18 auto& world = axiom->world();
19 auto name = axiom->sym().str();
20 find_and_replace(name, ".", "_");
21 find_and_replace(name, "%", "");
22 name = INTERNAL_PREFIX + name;
23
24 auto replacement = world.external(world.sym(name));
25 if (replacement) {
26 auto spec_fun = world.app(replacement, meta_args);
27 auto ds_fun = direct::op_cps2ds_dep(spec_fun);
28 return world.app(ds_fun, args);
29 }
30 return std::nullopt;
31}
32
33} // namespace
34
// Memoized front end of rewrite_ (the signature line of this definition is
// missing from this listing).
// Returns the cached result for `def` from `rewritten` if one exists; otherwise
// computes it once via rewrite_ and stores it under `def` before returning it.
36 if (auto i = rewritten.find(def); i != rewritten.end()) return i->second;
37 auto new_def = rewrite_(def);
38 rewritten[def] = new_def;
// NOTE(review): this looks `def` up a second time; returning `new_def` directly
// should yield the same value without the extra lookup.
39 return rewritten[def];
40}
41
// Body of rewrite_ (the signature line of this definition is missing from this
// listing). Replaces matrix operations with calls to externally provided
// implementations. A `def` that matches neither case below is returned unchanged.
//
// Case 1: a matrix::prod whose `w` meta-argument is the literal 64 becomes a
// call of the external "extern_matrix_prod", provided that external exists.
43 if (auto mat_ax = match<matrix::prod>(def)) {
// Arguments of the outer application. Presumably these are the memory token and
// the two input matrices; confirm against the axiom declaration.
44 auto [mem, M, N] = mat_ax->args<3>();
// Arguments of the curried (inner) application. Presumably these are the
// dimensions m, k, l and the element width w; confirm against the matrix plugin.
45 auto [m, k, l, w] = mat_ax->decurry()->args<4>();
// Lit::isa yields an optional; it is presumably empty when `w` is not a literal.
46 auto w_lit = Lit::isa(w);
47
48 auto ext_fun = world().external(world().sym("extern_matrix_prod"));
// Lower only if the external is present and w is the literal 64. Otherwise
// control falls through to the generic axiom lowering below.
49 if (ext_fun && (w_lit && *w_lit == 64)) {
// Presumably converts the CPS external to direct style, as the name of
// op_cps2ds_dep suggests. It is then called with (mem, m, k, l, M, N).
50 auto ds_fun = direct::op_cps2ds_dep(ext_fun);
51 auto fun_app = world().app(ds_fun, {mem, m, k, l, M, N});
52 return fun_app;
53 }
54 }
55
// Case 2: an application of the form `axiom meta_args args` is replaced by
// whatever internal_function_of_axiom resolves for that axiom.
// NOTE(review): axioms are not filtered by kind here. Only axioms with an external
// named INTERNAL_PREFIX + mangled axiom name are affected, even though the log
// text says "matrix".
56 if (auto outer_app = def->isa<App>()) {
57 if (auto inner_app = outer_app->callee()->isa<App>()) {
58 if (auto axiom = inner_app->callee()->isa<Axiom>()) {
59 if (auto internal_function = internal_function_of_axiom(axiom, inner_app->arg(), outer_app->arg())) {
60 world().DLOG("lower matrix axiom {} in {} : {}", *axiom->sym(), def, def->type());
61 world().DLOG("lower matrix axiom using: {} : {}", *internal_function, (*internal_function)->type());
62 return *internal_function;
63 }
64 }
65 }
66 }
67
// No lowering applies: keep `def` as is.
68 return def;
69}
70
71} // namespace thorin::plug::matrix
const Def * type() const
Yields the raw type of this Def, which may be nullptr.
Definition def.h:248
static std::optional< T > isa(Ref def)
Definition def.h:726
World & world()
Definition pass.h:296
Helper class to retrieve Infer::arg if present.
Definition def.h:87
Def * external(Sym name)
Lookup by name.
Definition world.h:164
Ref app(Ref callee, Ref arg)
Definition world.cpp:183
Ref rewrite(Ref) override
Custom rewrite function: the memoized version of rewrite_.
#define M(S, D)
#define INTERNAL_PREFIX
Definition matrix.h:11
const Def * op_cps2ds_dep(const Def *f)
Definition direct.h:11
The matrix Plugin
Definition matrix.h:9
The mem Plugin
Definition mem.h:11
void find_and_replace(std::string &str, std::string_view what, std::string_view repl)
Replaces all occurrences of what with repl.
Definition util.h:70