MimIR 0.1
MimIR is my Intermediate Representation
reshape.h
```cpp
#pragma once

#include "mim/phase.h"

namespace mim::plug::mem::pass {

/// The general idea of this Pass is to change the shape of function signatures.
/// * Example: `Cn[[mem, A, B], C, ret]`
/// * Arg:  `Cn[[mem, [A, B, C]], ret]` (in general `Cn[[mem, args], ret]`)
/// * Flat: `Cn[mem, A, B, C, ret]` (in general `Cn[mem, ...args, ret]`)
///
/// For convenience, we use Arg style for optimizations.
/// Its invariant is that every closed function has at most one "real" argument plus a return continuation.
/// If memory is present, that argument is a pair of the memory and the remaining arguments.
/// Code generation, however, requires Flat style, in particular for closure conversion.
///
/// The pass rewrites the signatures of all functions, reassociating their arguments consistently.
/// The change is propagated to all (nested) applications of these functions.
// TODO: use RWPhase instead
class Reshape : public RWPass<Reshape, Lam> {
public:
    enum Mode { Flat, Arg };

    void apply(Mode);
    void apply(const App* app) final;
    void apply(Stage& s) final { apply(static_cast<Reshape&>(s).mode()); }

    Mode mode() const { return mode_; }

    /// Fall-through to `rewrite_def` which falls through to `reshape_lam`.
    void enter() override;

private:
    /// Memoized version of `rewrite_def_`.
    const Def* rewrite_def(const Def* def);
    /// Replace lambdas with reshaped versions, reshape application arguments, and replace vars and already rewritten
    /// lambdas.
    const Def* rewrite_def_(const Def* def);
    /// Create a new lambda with the reshaped signature and rewrite its body.
    /// The old var is associated with a reshaped version of the new var in `old2new_`.
    Lam* reshape_lam(Lam* def);

    /// Reshapes a type into its flat or arg representation.
    const Def* reshape_type(const Def* T);
    /// Reshapes a def into its flat or arg representation.
    const Def* reshape(const Def* def);
    /// Generalized version of `reshape` that transforms `def` to match the shape of `target`.
    const Def* reshape(const Def* def, const Def* target);
    /// Reconstructs the target type by taking defs out of the queue.
    const Def* reshape(DefVec&, const Def* target, const Def* mem);

    /// Keeps track of the replacements.
    Def2Def old2new_;
    /// The mode to rewrite all lambdas to: either `Flat` or `Arg`.
    Mode mode_;
};

} // namespace mim::plug::mem::pass
```