summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicolas Vasilache <ntv@google.com>2020-07-10 09:49:22 -0400
committerNicolas Vasilache <ntv@google.com>2020-07-10 11:09:27 -0400
commita490d387e6e6085b35a6850581b62db3d2d47009 (patch)
tree083d976b251c99d2e5923e2bdd0d7742a03303ef
parent9fd4b5faacbdfb887389c9ac246efa23be1cd334 (diff)
downloadllvm-a490d387e6e6085b35a6850581b62db3d2d47009.tar.gz
[mlir][Vector] Add ExtractOp folding when fed by a TransposeOp
TransposeOp are often followed by ExtractOp. In certain cases however, it is unnecessary (and even detrimental) to lower a TransposeOp to either a flat transpose (llvm.matrix intrinsics) or to unrolled scalar insert / extract chains. Providing foldings of ExtractOp mitigates some of the unnecessary complexity. Differential revision: https://reviews.llvm.org/D83487
-rw-r--r--mlir/include/mlir/IR/AffineMap.h9
-rw-r--r--mlir/lib/Dialect/Vector/VectorOps.cpp60
-rw-r--r--mlir/lib/IR/AffineMap.cpp23
-rw-r--r--mlir/test/Dialect/Vector/canonicalize.mlir41
4 files changed, 130 insertions, 3 deletions
diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index a44723024843..54f81db92a3e 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -170,6 +170,10 @@ public:
/// `(d0)[s0, s1, s2] -> (d0 + s1 + s2 + 1, d0 - s0 - s2 - 1)`
AffineMap compose(AffineMap map);
+ /// Applies composition by the dims of `this` to the integer `values` and
+ /// returns the resulting values. `this` must be symbol-less.
+ SmallVector<int64_t, 4> compose(ArrayRef<int64_t> values);
+
/// Returns true if the AffineMap represents a subset (i.e. a projection) of a
/// symbol-less permutation map.
bool isProjectedPermutation();
@@ -180,6 +184,11 @@ public:
/// Returns the map consisting of the `resultPos` subset.
AffineMap getSubMap(ArrayRef<unsigned> resultPos);
+ /// Returns the map consisting of the most major `numResults` results.
+ /// Returns the null AffineMap if `numResults` == 0.
+ /// Returns `*this` if `numResults` >= `this->getNumResults()`.
+ AffineMap getMajorSubMap(unsigned numResults);
+
/// Returns the map consisting of the most minor `numResults` results.
/// Returns the null AffineMap if `numResults` == 0.
/// Returns `*this` if `numResults` >= `this->getNumResults()`.
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 0aae97f24d20..cdf09c4a8f68 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -18,6 +18,7 @@
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
@@ -602,6 +603,63 @@ static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
return success();
}
+/// Fold the result of an ExtractOp in place when it comes from a TransposeOp.
+static LogicalResult foldExtractOpFromTranspose(ExtractOp extractOp) {
+ auto transposeOp = extractOp.vector().getDefiningOp<TransposeOp>();
+ if (!transposeOp)
+ return failure();
+
+ auto permutation = extractVector<unsigned>(transposeOp.transp());
+ auto extractedPos = extractVector<int64_t>(extractOp.position());
+
+ // If transposition permutation is larger than the ExtractOp, all minor
+ // dimensions must be an identity for folding to occur. If not, individual
+ // elements within the extracted value are transposed and this is not just a
+ // simple folding.
+ unsigned minorRank = permutation.size() - extractedPos.size();
+ MLIRContext *ctx = extractOp.getContext();
+ AffineMap permutationMap = AffineMap::getPermutationMap(permutation, ctx);
+ AffineMap minorMap = permutationMap.getMinorSubMap(minorRank);
+ if (minorMap && !AffineMap::isMinorIdentity(minorMap))
+ return failure();
+
+ // %1 = transpose %0[x, y, z] : vector<axbxcxf32>
+ // %2 = extract %1[u, v] : vector<..xf32>
+ // may turn into:
+ // %2 = extract %0[w, x] : vector<..xf32>
+ // iff z == 2 and [w, x] = [x, y]^-1 o [u, v] here o denotes composition and
+ // -1 denotes the inverse.
+ permutationMap = permutationMap.getMajorSubMap(extractedPos.size());
+ // The major submap has fewer results but the same number of dims. To compose
+ // cleanly, we need to drop dims to form a "square matrix". This is possible
+ // because:
+ // (a) this is a permutation map and
+ // (b) the minor map has already been checked to be identity.
+ // Therefore, the major map cannot contain dims of position greater or equal
+ // than the number of results.
+ assert(llvm::all_of(permutationMap.getResults(),
+ [&](AffineExpr e) {
+ auto dim = e.dyn_cast<AffineDimExpr>();
+ return dim && dim.getPosition() <
+ permutationMap.getNumResults();
+ }) &&
+ "Unexpected map results depend on higher rank positions");
+ // Project on the first domain dimensions to allow composition.
+ permutationMap = AffineMap::get(permutationMap.getNumResults(), 0,
+ permutationMap.getResults(), ctx);
+
+ extractOp.setOperand(transposeOp.vector());
+ // Compose the inverse permutation map with the extractedPos.
+ auto newExtractedPos =
+ inversePermutation(permutationMap).compose(extractedPos);
+ // OpBuilder is only used as a helper to build an I64ArrayAttr.
+ OpBuilder b(extractOp.getContext());
+ extractOp.setAttr(ExtractOp::getPositionAttrName(),
+ b.getI64ArrayAttr(newExtractedPos));
+
+ return success();
+}
+
/// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps. The
/// result is always the input to some InsertOp.
static Value foldExtractOpFromInsertChainAndTranspose(ExtractOp extractOp) {
@@ -689,6 +747,8 @@ static Value foldExtractOpFromInsertChainAndTranspose(ExtractOp extractOp) {
OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
if (succeeded(foldExtractOpFromExtractChain(*this)))
return getResult();
+ if (succeeded(foldExtractOpFromTranspose(*this)))
+ return getResult();
if (auto val = foldExtractOpFromInsertChainAndTranspose(*this))
return val;
return OpFoldResult();
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index c17df954558b..b09c51a3abbb 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -330,6 +330,21 @@ AffineMap AffineMap::compose(AffineMap map) {
return AffineMap::get(numDims, numSymbols, exprs, map.getContext());
}
+SmallVector<int64_t, 4> AffineMap::compose(ArrayRef<int64_t> values) {
+ assert(getNumSymbols() == 0 && "Expected symbol-less map");
+ SmallVector<AffineExpr, 4> exprs;
+ exprs.reserve(values.size());
+ MLIRContext *ctx = getContext();
+ for (auto v : values)
+ exprs.push_back(getAffineConstantExpr(v, ctx));
+ auto resMap = compose(AffineMap::get(0, 0, exprs, ctx));
+ SmallVector<int64_t, 4> res;
+ res.reserve(resMap.getNumResults());
+ for (auto e : resMap.getResults())
+ res.push_back(e.cast<AffineConstantExpr>().getValue());
+ return res;
+}
+
bool AffineMap::isProjectedPermutation() {
if (getNumSymbols() > 0)
return false;
@@ -360,6 +375,14 @@ AffineMap AffineMap::getSubMap(ArrayRef<unsigned> resultPos) {
return AffineMap::get(getNumDims(), getNumSymbols(), exprs, getContext());
}
+AffineMap AffineMap::getMajorSubMap(unsigned numResults) {
+ if (numResults == 0)
+ return AffineMap();
+ if (numResults > getNumResults())
+ return *this;
+ return getSubMap(llvm::to_vector<4>(llvm::seq<unsigned>(0, numResults)));
+}
+
AffineMap AffineMap::getMinorSubMap(unsigned numResults) {
if (numResults == 0)
return AffineMap();
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 94f3f627e777..836a9869248d 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -300,13 +300,48 @@ func @insert_extract_transpose_3d_2d(
// CHECK-LABEL: fold_extracts
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<3x4x5x6xf32>
-// CHECK-NEXT: vector.extract %[[A]][0, 1, 2, 3] : vector<3x4x5x6xf32>
-// CHECK-NEXT: vector.extract %[[A]][0] : vector<3x4x5x6xf32>
-// CHECK-NEXT: return
func @fold_extracts(%a : vector<3x4x5x6xf32>) -> (f32, vector<4x5x6xf32>) {
%b = vector.extract %a[0] : vector<3x4x5x6xf32>
%c = vector.extract %b[1, 2] : vector<4x5x6xf32>
+ // CHECK-NEXT: vector.extract %[[A]][0, 1, 2, 3] : vector<3x4x5x6xf32>
%d = vector.extract %c[3] : vector<6xf32>
+
+ // CHECK-NEXT: vector.extract %[[A]][0] : vector<3x4x5x6xf32>
%e = vector.extract %a[0] : vector<3x4x5x6xf32>
+
+ // CHECK-NEXT: return
return %d, %e : f32, vector<4x5x6xf32>
}
+
+// -----
+
+// CHECK-LABEL: fold_extract_transpose
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<3x4x5x6xf32>
+// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x6x5x6xf32>
+func @fold_extract_transpose(
+ %a : vector<3x4x5x6xf32>, %b : vector<3x6x5x6xf32>) -> (
+ vector<6xf32>, vector<6xf32>, vector<6xf32>) {
+ // [3] is a proper most minor identity map in transpose.
+ // Permutation is a self inverse and we have.
+ // [0, 2, 1] ^ -1 o [0, 1, 2] = [0, 2, 1] o [0, 1, 2]
+ // = [0, 2, 1]
+ // CHECK-NEXT: vector.extract %[[A]][0, 2, 1] : vector<3x4x5x6xf32>
+ %0 = vector.transpose %a, [0, 2, 1, 3] : vector<3x4x5x6xf32> to vector<3x5x4x6xf32>
+ %1 = vector.extract %0[0, 1, 2] : vector<3x5x4x6xf32>
+
+ // [3] is a proper most minor identity map in transpose.
+ // Permutation is a not self inverse and we have.
+ // [1, 2, 0] ^ -1 o [0, 1, 2] = [2, 0, 1] o [0, 1, 2]
+ // = [2, 0, 1]
+ // CHECK-NEXT: vector.extract %[[A]][2, 0, 1] : vector<3x4x5x6xf32>
+ %2 = vector.transpose %a, [1, 2, 0, 3] : vector<3x4x5x6xf32> to vector<4x5x3x6xf32>
+ %3 = vector.extract %2[0, 1, 2] : vector<4x5x3x6xf32>
+
+ // Not a minor identity map so intra-vector level has been permuted
+ // CHECK-NEXT: vector.transpose %[[B]], [0, 2, 3, 1]
+ // CHECK-NEXT: vector.extract %{{.*}}[0, 1, 2]
+ %4 = vector.transpose %b, [0, 2, 3, 1] : vector<3x6x5x6xf32> to vector<3x5x6x6xf32>
+ %5 = vector.extract %4[0, 1, 2] : vector<3x5x6x6xf32>
+
+ return %1, %3, %5 : vector<6xf32>, vector<6xf32>, vector<6xf32>
+}