1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
|
//===- AffineExprVisitor.h - MLIR AffineExpr Visitor Class ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the AffineExpr visitor class.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_AFFINEEXPRVISITOR_H
#define MLIR_IR_AFFINEEXPRVISITOR_H
#include "mlir/IR/AffineExpr.h"
#include "llvm/ADT/ArrayRef.h"
namespace mlir {
/// Base class for AffineExpr visitors/walkers.
///
/// AffineExpr visitors are used when you want to perform different actions
/// for different kinds of AffineExprs without having to use lots of casts
/// and a big switch statement.
///
/// To define your own visitor, inherit from this class, specifying your
/// new type for the 'SubClass' template parameter, and "override" visitXXX
/// functions in your class. This class is defined in terms of statically
/// resolved overloading (CRTP: calls go through static_cast<SubClass *>), not
/// virtual functions.
///
/// For example, here is a visitor that counts the number of AffineDimExprs
/// in an AffineExpr.
///
///  /// Declare the class. Note that we derive from AffineExprVisitor
///  /// instantiated with our new subclass's type.
///
///  struct DimExprCounter : public AffineExprVisitor<DimExprCounter> {
///    unsigned numDimExprs;
///    DimExprCounter() : numDimExprs(0) {}
///    void visitDimExpr(AffineDimExpr expr) { ++numDimExprs; }
///  };
///
///  And this class would be used like this (walkPostOrder is needed because
///  visit() only looks at the root expression and does not recurse):
///    DimExprCounter dec;
///    dec.walkPostOrder(affineExpr);
///    numDimExprs = dec.numDimExprs;
///
/// AffineExprVisitor provides the following visit methods for binary affine
/// op expressions, each taking an AffineBinaryOpExpr:
///   visitAddExpr, visitMulExpr, visitModExpr, visitFloorDivExpr and
///   visitCeilDivExpr.
/// The default implementations of these methods forward to the general
/// visitAffineBinaryOpExpr method.
///
/// In addition, visit methods are provided for the following affine
/// expressions: AffineConstantExpr (visitConstantExpr), AffineDimExpr
/// (visitDimExpr), and AffineSymbolExpr (visitSymbolExpr).
///
/// If you don't implement visitXXX for some affine expression kind, the
/// default implementation in this class is used: the binary op methods forward
/// to visitAffineBinaryOpExpr, and every other default does nothing and
/// returns a value-initialized RetTy.
///
/// Two entry points are provided:
///  - visit(expr) dispatches on the kind of `expr` only; it does not recurse
///    into operands.
///  - walkPostOrder(expr) first walks the LHS and then the RHS of binary op
///    expressions, and only then visits `expr` itself. The results produced
///    for the operand subexpressions are discarded; only the result for `expr`
///    is returned.
///
/// Note that this class is specifically designed as a template to avoid
/// virtual function call overhead. Defining and using an AffineExprVisitor is
/// just as efficient as having your own switch statement over the expression
/// kind.
template <typename SubClass, typename RetTy = void>
class AffineExprVisitor {
  //===--------------------------------------------------------------------===//
  // Interface code - This is the public interface of the AffineExprVisitor
  // that you use to visit affine expressions...
public:
  /// Walks `expr` in post order: for a binary op expression the LHS subtree,
  /// then the RHS subtree, is walked before `expr` itself is visited.
  /// Returns the result of visiting `expr`.
  RetTy walkPostOrder(AffineExpr expr) {
    static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value,
                  "Must instantiate with a derived type of AffineExprVisitor");
    switch (expr.getKind()) {
    case AffineExprKind::Add: {
      auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
      walkOperandsPostOrder(binOpExpr);
      return static_cast<SubClass *>(this)->visitAddExpr(binOpExpr);
    }
    case AffineExprKind::Mul: {
      auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
      walkOperandsPostOrder(binOpExpr);
      return static_cast<SubClass *>(this)->visitMulExpr(binOpExpr);
    }
    case AffineExprKind::Mod: {
      auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
      walkOperandsPostOrder(binOpExpr);
      return static_cast<SubClass *>(this)->visitModExpr(binOpExpr);
    }
    case AffineExprKind::FloorDiv: {
      auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
      walkOperandsPostOrder(binOpExpr);
      return static_cast<SubClass *>(this)->visitFloorDivExpr(binOpExpr);
    }
    case AffineExprKind::CeilDiv: {
      auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
      walkOperandsPostOrder(binOpExpr);
      return static_cast<SubClass *>(this)->visitCeilDivExpr(binOpExpr);
    }
    case AffineExprKind::Constant:
      return static_cast<SubClass *>(this)->visitConstantExpr(
          expr.cast<AffineConstantExpr>());
    case AffineExprKind::DimId:
      return static_cast<SubClass *>(this)->visitDimExpr(
          expr.cast<AffineDimExpr>());
    case AffineExprKind::SymbolId:
      return static_cast<SubClass *>(this)->visitSymbolExpr(
          expr.cast<AffineSymbolExpr>());
    }
    // Every AffineExprKind returns from the switch above. Without this, a
    // non-void RetTy would flow off the end of the function (undefined
    // behavior) if an unexpected kind ever reached here. Mirrors visit().
    llvm_unreachable("Unknown AffineExpr");
  }

  /// Visits `expr` alone, dispatching on its kind to the matching visitXXX
  /// method of SubClass. Operands of binary op expressions are NOT visited.
  RetTy visit(AffineExpr expr) {
    static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value,
                  "Must instantiate with a derived type of AffineExprVisitor");
    switch (expr.getKind()) {
    case AffineExprKind::Add: {
      auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
      return static_cast<SubClass *>(this)->visitAddExpr(binOpExpr);
    }
    case AffineExprKind::Mul: {
      auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
      return static_cast<SubClass *>(this)->visitMulExpr(binOpExpr);
    }
    case AffineExprKind::Mod: {
      auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
      return static_cast<SubClass *>(this)->visitModExpr(binOpExpr);
    }
    case AffineExprKind::FloorDiv: {
      auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
      return static_cast<SubClass *>(this)->visitFloorDivExpr(binOpExpr);
    }
    case AffineExprKind::CeilDiv: {
      auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
      return static_cast<SubClass *>(this)->visitCeilDivExpr(binOpExpr);
    }
    case AffineExprKind::Constant:
      return static_cast<SubClass *>(this)->visitConstantExpr(
          expr.cast<AffineConstantExpr>());
    case AffineExprKind::DimId:
      return static_cast<SubClass *>(this)->visitDimExpr(
          expr.cast<AffineDimExpr>());
    case AffineExprKind::SymbolId:
      return static_cast<SubClass *>(this)->visitSymbolExpr(
          expr.cast<AffineSymbolExpr>());
    }
    llvm_unreachable("Unknown AffineExpr");
  }

  //===--------------------------------------------------------------------===//
  // Visitation functions... these functions provide default fallbacks in case
  // the user does not specify what to do for a particular expression kind.
  // The kind-specific binary op methods fall back to the general
  // visitAffineBinaryOpExpr method; everything else is a no-op that returns a
  // value-initialized RetTy. All of this should be inlined perfectly, because
  // there are no virtual functions to get in the way.

  /// Fallback for all binary op expressions: does nothing.
  RetTy visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { return RetTy(); }
  RetTy visitAddExpr(AffineBinaryOpExpr expr) {
    return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
  }
  RetTy visitMulExpr(AffineBinaryOpExpr expr) {
    return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
  }
  RetTy visitModExpr(AffineBinaryOpExpr expr) {
    return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
  }
  RetTy visitFloorDivExpr(AffineBinaryOpExpr expr) {
    return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
  }
  RetTy visitCeilDivExpr(AffineBinaryOpExpr expr) {
    return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
  }
  /// Leaf-expression fallbacks: do nothing.
  RetTy visitConstantExpr(AffineConstantExpr expr) { return RetTy(); }
  RetTy visitDimExpr(AffineDimExpr expr) { return RetTy(); }
  RetTy visitSymbolExpr(AffineSymbolExpr expr) { return RetTy(); }

private:
  /// Walks the operands of `expr` (LHS first, then RHS), each in post order.
  /// The operands' results are intentionally discarded, so this returns void.
  /// Previously it was declared to return RetTy without returning a value,
  /// which is undefined behavior for any non-void RetTy.
  void walkOperandsPostOrder(AffineBinaryOpExpr expr) {
    walkPostOrder(expr.getLHS());
    walkPostOrder(expr.getRHS());
  }
};
// This class is used to flatten a pure affine expression (AffineExpr,
// which is in a tree form) into a sum of products (w.r.t constants) when
// possible, and in that process simplifying the expression. For a modulo,
// floordiv, or a ceildiv expression, an additional identifier, called a local
// identifier, is introduced to rewrite the expression as a sum-of-products
// affine expression. Each local identifier is always and by construction a
// floordiv of a pure add/mul affine function of dimensional, symbolic, and
// other local identifiers, in a non-mutually recursive way. Hence, every local
// identifier can ultimately always be recovered as an affine function of
// dimensional and symbolic identifiers (involving floordiv's); note however
// that by AffineExpr construction, some floordiv combinations are converted to
// mod's. The result of the flattening is a flattened expression and a set of
// constraints involving just the local variables.
//
// d2 + (d0 + d1) floordiv 4 is flattened to d2 + q where 'q' is the local
// variable introduced, with localVarCst containing 4*q <= d0 + d1 <= 4*q + 3.
//
// NOTE(review): 'localVarCst' is not a member of this class; presumably the
// constraints are recorded by derived classes that override the virtual
// addLocalFloorDivId / addLocalIdSemiAffine hooks — confirm.
//
// The simplification performed includes the accumulation of contributions for
// each dimensional and symbolic identifier together, the simplification of
// floordiv/ceildiv/mod expressions and other simplifications that in turn
// happen as a result. A simplification that this flattening naturally performs
// is of simplifying the numerator and denominator of floordiv/ceildiv, and
// folding a modulo expression to a zero, if possible. Three examples are below:
//
// (((d0 + 3 * d1) + d0) - 2 * d1) - d0 simplified to d0 + d1
// (d0 - d0 mod 4 + 4) mod 4 simplified to 0
// (3*d0 + 2*d1 + d0) floordiv 2 + d1 simplified to 2*d0 + 2*d1
//
// The way the flattening works for the second example is as follows: d0 % 4 is
// replaced by d0 - 4*q with q being introduced: the expression then simplifies
// to: (d0 - (d0 - 4q) + 4) = 4q + 4, modulo of which w.r.t 4 simplifies to
// zero. Note that an affine expression may not always be expressible purely as
// a sum of products involving just the original dimensional and symbolic
// identifiers due to the presence of modulo/floordiv/ceildiv expressions that
// may not be eliminated after simplification; in such cases, the final
// expression can be reconstructed by replacing the local identifiers with their
// corresponding explicit form stored in 'localExprs' (note that each of the
// explicit forms itself would have been simplified).
//
// The expression walk method here performs a linear time post order walk that
// performs the above simplifications through visit methods, with partial
// results being stored in 'operandExprStack'. When a parent expr is visited,
// the flattened expressions corresponding to its two operands would already be
// on the stack - the parent expression looks at the two flattened expressions
// and combines the two. It pops off the operand expressions and pushes the
// combined result (although this is done in-place on its LHS operand expr).
// When the walk is completed, the flattened form of the top-level expression
// would be left on the stack.
//
// A flattener can be repeatedly used for multiple affine expressions that bind
// to the same operands, for example, for all result expressions of an
// AffineMap or AffineValueMap. In such cases, using it for multiple expressions
// is more efficient than creating a new flattener for each expression since
// common identical div and mod expressions appearing across different
// expressions are mapped to the same local identifier (same column position in
// 'localVarCst').
class SimpleAffineExprFlattener
    : public AffineExprVisitor<SimpleAffineExprFlattener> {
public:
  // Flattened expression layout: [dims, symbols, locals, constant]
  // Stack that holds the LHS and RHS operands while visiting a binary op expr.
  // In future, consider adding a prepass to determine how big the SmallVector's
  // will be, and linearize this to std::vector<int64_t> to prevent
  // SmallVector moves on re-allocation.
  std::vector<SmallVector<int64_t, 8>> operandExprStack;

  unsigned numDims;
  unsigned numSymbols;

  // Number of newly introduced identifiers to flatten mod/floordiv/ceildiv's.
  unsigned numLocals;

  // AffineExpr's corresponding to the floordiv/ceildiv/mod expressions for
  // which new identifiers were introduced; if the latter do not get canceled
  // out, these expressions can be readily used to reconstruct the AffineExpr
  // (tree) form. Note that these expressions themselves would have been
  // simplified (recursively) by this pass. E.g. d0 + (d0 + 2*d1 + d0) ceildiv 4
  // will be simplified to d0 + q, where q = (d0 + d1) ceildiv 2. (d0 + d1)
  // ceildiv 2 would be the local expression stored for q.
  SmallVector<AffineExpr, 4> localExprs;

  SimpleAffineExprFlattener(unsigned numDims, unsigned numSymbols);

  virtual ~SimpleAffineExprFlattener() = default;

  // Visitor methods. These hide (they do not virtually override) the defaults
  // of AffineExprVisitor: the base dispatches to them through
  // static_cast<SimpleAffineExprFlattener *>, so a class further derived from
  // this one cannot intercept them by redefining them. Customize behavior
  // through the virtual addLocal* hooks below instead.
  void visitMulExpr(AffineBinaryOpExpr expr);
  void visitAddExpr(AffineBinaryOpExpr expr);
  void visitDimExpr(AffineDimExpr expr);
  void visitSymbolExpr(AffineSymbolExpr expr);
  void visitConstantExpr(AffineConstantExpr expr);
  void visitCeilDivExpr(AffineBinaryOpExpr expr);
  void visitFloorDivExpr(AffineBinaryOpExpr expr);

  //
  // t = expr mod c <=> t = expr - c*q and c*q <= expr <= c*q + c - 1
  //
  // A mod expression "expr mod c" is thus flattened by introducing a new local
  // variable q (= expr floordiv c), such that expr mod c is replaced with
  // 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst.
  void visitModExpr(AffineBinaryOpExpr expr);

protected:
  // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
  // The local identifier added is always a floordiv of a pure add/mul affine
  // function of other identifiers, coefficients of which are specified in
  // dividend and with respect to a positive constant divisor. localExpr is the
  // simplified tree expression (AffineExpr) corresponding to the quantifier.
  virtual void addLocalFloorDivId(ArrayRef<int64_t> dividend, int64_t divisor,
                                  AffineExpr localExpr);

  /// Add a local identifier (needed to flatten a mod, floordiv, ceildiv, mul
  /// expr) when the rhs is a symbolic expression. The local identifier added
  /// may be a floordiv, ceildiv, mul or mod of a pure affine/semi-affine
  /// function of other identifiers, coefficients of which are specified in the
  /// lhs of the mod, floordiv, ceildiv or mul expression and with respect to a
  /// symbolic rhs expression. `localExpr` is the simplified tree expression
  /// (AffineExpr) corresponding to the quantifier.
  virtual void addLocalIdSemiAffine(AffineExpr localExpr);

private:
  /// Adds `expr`, which may be a mod, ceildiv, floordiv or mul expression
  /// representing the affine expression corresponding to the quantifier
  /// introduced as the local variable corresponding to `expr`. If the
  /// quantifier is already present, we put the coefficient in the proper index
  /// of `result`, otherwise we add a new local variable and put the coefficient
  /// there.
  void addLocalVariableSemiAffine(AffineExpr expr,
                                  SmallVectorImpl<int64_t> &result,
                                  unsigned long resultSize);

  // t = expr floordiv c <=> t = q, c * q <= expr <= c * q + c - 1
  // A floordiv is thus flattened by introducing a new local variable q, and
  // replacing that expression with 'q' while adding the constraints
  // c * q <= expr <= c * q + c - 1 to localVarCst (done by
  // FlatAffineConstraints::addLocalFloorDiv).
  //
  // A ceildiv is similarly flattened:
  // t = expr ceildiv c <=> t = (expr + c - 1) floordiv c
  void visitDivExpr(AffineBinaryOpExpr expr, bool isCeil);

  // Looks up `localExpr` among the local identifiers introduced so far.
  // NOTE(review): the signed return type suggests a negative sentinel (likely
  // -1) when absent — confirm against the out-of-line definition.
  int findLocalId(AffineExpr localExpr);

  // Column helpers for the flattened layout [dims, symbols, locals, constant].
  // Member functions defined in the class body are implicitly inline.

  // Total number of columns, including the trailing constant column.
  unsigned getNumCols() const {
    return numDims + numSymbols + numLocals + 1;
  }
  // Index of the constant term (the last column).
  unsigned getConstantIndex() const { return getNumCols() - 1; }
  // Index of the first local identifier column (right after the symbols).
  unsigned getLocalVarStartIndex() const { return numDims + numSymbols; }
  // Index of the first symbol column (right after the dims).
  unsigned getSymbolStartIndex() const { return numDims; }
  // Index of the first dim column.
  unsigned getDimStartIndex() const { return 0; }
};
} // namespace mlir
#endif // MLIR_IR_AFFINEEXPRVISITOR_H
|