Thorin 1.9.0
The Higher ORder INtermediate representation
Loading...
Searching...
No Matches
lower_matrix_lowlevel.cpp
Go to the documentation of this file.
2
3#include <cassert>
4
5#include <iostream>
6
7#include <thorin/lam.h>
8
9#include "thorin/axiom.h"
10#include "thorin/def.h"
11
18#include "thorin/plug/mem/mem.h"
19
20namespace thorin::plug::matrix {
21
22namespace {
23
24Ref op_lea_tuple(Ref arr, Ref tuple) {
25 auto& world = arr->world();
26 world.DLOG("op_lea_tuple arr {} : {}", arr, arr->type());
27 auto n = tuple->num_projs();
28 auto element = arr;
29 for (size_t i = 0; i < n; ++i) element = mem::op_lea(element, tuple->proj(n, i));
30 return element;
31}
32
33Ref op_pack_tuple(u64 n, Ref tuple, Ref val) {
34 auto& world = val->world();
35 // TODO: find out why num_projs is wrong
36 auto element = val;
37 for (int i = n - 1; i >= 0; i--) {
38 auto dim = tuple->proj(n, i);
39 element = world.pack(dim, element);
40 }
41 world.DLOG("op_pack_tuple: {} -> {}", val, element);
42 world.DLOG(" for tuple: {} : {}", tuple, tuple->type());
43 return element;
44}
45
46Ref arr_ty_of_matrix_ty(Ref S, Ref T) {
47 auto& world = S->world();
48 auto n = S->num_projs();
49 auto arr_ty = T;
50 for (int i = n - 1; i >= 0; i--) {
51 auto dim = S->proj(n, i);
52 arr_ty = world.arr(dim, arr_ty);
53 }
54 return arr_ty;
55}
56
57} // namespace
58
60 assert(!match<matrix::map_reduce>(def) && "map_reduce should have been lowered to for loops by now");
61 assert(!match<matrix::shape>(def) && "high level operations should have been lowered to for loops by now");
62 assert(!match<matrix::prod>(def) && "high level operations should have been lowered to for loops by now");
63 assert(!match<matrix::transpose>(def) && "high level operations should have been lowered to for loops by now");
64 assert(!match<matrix::sum>(def) && "high level operations should have been lowered to for loops by now");
65
66 // TODO: generalize arg rewrite
67 if (auto mat_ax = match<matrix::Mat>(def)) {
68 auto [_, S, T] = mat_ax->args<3>();
69 S = rewrite(S);
70 T = rewrite(T);
71 auto arr_ty = arr_ty_of_matrix_ty(S, T);
72
73 auto addr_space = world().lit_nat_0();
74 auto ptr_ty = world().call<mem::Ptr>(Defs{arr_ty, addr_space});
75
76 return ptr_ty;
77 } else if (auto init_ax = match<matrix::init>(def)) {
78 world().DLOG("init {} : {}", def, def->type());
79 auto [_, S, T, mem] = init_ax->args<4>();
80 world().DLOG(" S T mem {} {} {}", S, T, mem);
81 S = rewrite(S);
82 T = rewrite(T);
83 mem = rewrite(mem);
84 world().DLOG(" S T mem {} {} {}", S, T, mem);
85 auto arr_ty = arr_ty_of_matrix_ty(S, T);
86 auto [mem2, ptr_mat] = mem::op_alloc(arr_ty, mem)->projs<2>();
87 auto res = world().tuple({mem2, ptr_mat});
88 world().DLOG(" res {} : {}", res, res->type());
89 return res;
90 } else if (auto read_ax = match<matrix::read>(def)) {
91 auto [mem, mat, idx] = read_ax->args<3>();
92 world().DLOG("read_ax: {}", read_ax);
93 world().DLOG(" mem: {} : {}", mem, mem->type());
94 world().DLOG(" mat: {} : {}", mat, mat->type());
95 world().DLOG(" idx: {} : {}", idx, idx->type());
96 mem = rewrite(mem);
97 mat = rewrite(mat);
98 idx = rewrite(idx);
99 world().DLOG("rewritten read");
100 world().DLOG(" mem: {} : {}", mem, mem->type());
101 world().DLOG(" mat: {} : {}", mat, mat->type());
102 world().DLOG(" idx: {} : {}", idx, idx->type());
103 // TODO: check if mat is already converted
104 auto ptr_mat = mat;
105 auto element_ptr = op_lea_tuple(ptr_mat, idx);
106 auto [mem2, val] = world().call<mem::load>(Defs{mem, element_ptr})->projs<2>();
107 return world().tuple({mem2, val});
108 } else if (auto insert_ax = match<matrix::insert>(def)) {
109 auto [mem, mat, idx, val] = insert_ax->args<4>();
110 world().DLOG("insert_ax: {}", insert_ax);
111 world().DLOG(" mem: {} : {}", mem, mem->type());
112 world().DLOG(" mat: {} : {}", mat, mat->type());
113 world().DLOG(" idx: {} : {}", idx, idx->type());
114 world().DLOG(" val: {} : {}", val, val->type());
115 mem = rewrite(mem);
116 mat = rewrite(mat);
117 idx = rewrite(idx);
118 val = rewrite(val);
119 world().DLOG("rewritten insert");
120 world().DLOG(" mem: {} : {}", mem, mem->type());
121 world().DLOG(" mat: {} : {}", mat, mat->type());
122 world().DLOG(" idx: {} : {}", idx, idx->type());
123 world().DLOG(" val: {} : {}", val, val->type());
124 auto ptr_mat = mat;
125 auto element_ptr = op_lea_tuple(ptr_mat, idx);
126 auto mem2 = world().call<mem::store>(Defs{mem, element_ptr, val});
127 return world().tuple({mem2, ptr_mat});
128 } else if (auto const_ax = match<matrix::constMat>(def)) {
129 auto [mem, val] = const_ax->args<2>();
130 mem = rewrite(mem);
131 val = rewrite(val);
132 auto [n_def, S, T] = const_ax->callee()->as<App>()->args<3>();
133 S = rewrite(S);
134 T = rewrite(T);
135 auto arr_ty = arr_ty_of_matrix_ty(S, T);
136 auto [mem2, ptr_mat] = mem::op_alloc(arr_ty, mem)->projs<2>();
137
138 // store initial value
139 auto n = n_def->as<Lit>()->get<u64>();
140 auto initial = op_pack_tuple(n, S, val);
141
142 auto mem3 = world().call<mem::store>(Defs{mem2, ptr_mat, initial});
143
144 return world().tuple({mem3, ptr_mat});
145 }
146
147 // ignore unapplied axioms to avoid spurious type replacements
148 if (def->isa<Axiom>()) return def;
149
150 return Rewriter::rewrite_imm(def); // continue recursive rewriting with everything else
151}
152
153} // namespace thorin::plug::matrix
auto projs(F f) const
Splits this Def via Def::projections into an Array (if A == -1_n) or std::array (otherwise).
Definition def.h:369
const Def * type() const
Yields the raw type of this Def, i.e. maybe nullptr.
Definition def.h:248
World & world() const
Definition def.cpp:421
World & world()
Definition phase.h:59
Helper class to retrieve Infer::arg if present.
Definition def.h:87
virtual Ref rewrite(Ref)
Definition rewrite.cpp:9
virtual Ref rewrite_imm(Ref)
Definition rewrite.cpp:16
This is a thin wrapper for std::span<T, N> with the following additional features:
Definition span.h:28
Ref pack(Ref arity, Ref body)
Definition world.cpp:405
const Def * call(Id id, Args &&... args)
Definition world.h:497
const Lit * lit_nat_0()
Definition world.h:368
Ref tuple(Defs ops)
Definition world.cpp:226
The matrix Plugin
Definition matrix.h:9
The mem Plugin
Definition mem.h:11
Ref op_alloc(Ref type, Ref mem)
Definition mem.h:136
Ref op_lea(Ref ptr, Ref index)
Definition mem.h:111
uint64_t u64
Definition types.h:35