summaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
authorLei Zhang <antiagainst@google.com>2023-05-17 08:57:13 -0700
committerLei Zhang <antiagainst@google.com>2023-05-17 09:01:19 -0700
commite000b62a342cac907fd77cfdd070f0b055f0c3c4 (patch)
tree45564fc0513335d152027662e842538b271157a3 /mlir
parent4eab303404d6bb2252b4baf807c5ac87a0fa3125 (diff)
downloadllvm-e000b62a342cac907fd77cfdd070f0b055f0c3c4.tar.gz
[mlir][vector] Separate out vector transfer + tensor slice patterns
These patterns touches the structure generated from tiling so it affects later steps like bufferization and vector hoisting. Instead of putting them in canonicalization, this commit creates separate entry points for them to be called explicitly. This is NFC regarding the functionality and tests of those patterns. It also addresses two TODO items in the codebase. Reviewed By: ThomasRaoux Differential Revision: https://reviews.llvm.org/D150702
Diffstat (limited to 'mlir')
-rw-r--r--mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h11
-rw-r--r--mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp2
-rw-r--r--mlir/lib/Dialect/Vector/IR/VectorOps.cpp204
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt1
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorTransferTensorSliceTransforms.cpp237
-rw-r--r--mlir/test/Dialect/Vector/canonicalize.mlir110
-rw-r--r--mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir109
-rw-r--r--mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp22
8 files changed, 384 insertions, 312 deletions
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 325860079b3d..2912c0252872 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -218,6 +218,17 @@ void populateBreakDownVectorBitCastOpPatterns(
void populateVectorInsertExtractStridedSliceTransforms(
RewritePatternSet &patterns, PatternBenefit benefit = 1);
+/// Collect patterns to fold tensor.extract_slice -> vector.transfer_read and
+/// vector.transfer_write -> tensor.insert_slice op chains into vector tranfer
+/// read and write ops.
+///
+/// If `controlFn` is not nullptr, the pattern will only apply to ops where
+/// `controlFn` returns true, given the vector transfer read/write op as input.
+void populateVectorTransferTensorSliceTransforms(
+ RewritePatternSet &patterns,
+ std::function<bool(Operation *vectorOp)> controlFn = nullptr,
+ PatternBenefit benefit = 1);
+
/// Collect a set of pattern to unroll vector operations to a smaller shapes.
/// `options` structure controls which operations are unrolled and the target
/// shape.
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index baabf1ae67fc..4fe9b9fab6a7 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -29,6 +29,7 @@
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/TilingInterface.h"
@@ -2880,6 +2881,7 @@ transform::VectorizeOp::applyToOne(Operation *target,
/*benefit=*/2);
vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
+ vector::populateVectorTransferTensorSliceTransforms(patterns);
patterns.add<CopyVectorizationPattern>(ctx);
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 80ea03f6e8d8..1549237f8c9f 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3692,108 +3692,7 @@ void TransferReadOp::getEffects(
SideEffects::DefaultResource::get());
}
-/// Returns true if all rank reduced in the given `extractOp` happen in leading
-/// dimensions earlier than last `trailingRank` dimensions.
-static bool areAllRankReducedLeadingDim(tensor::ExtractSliceOp extractOp,
- unsigned trailingRank) {
- // If no ranks are reduced at all, it's a degenerated case; always true.
- if (extractOp.getSourceType().getRank() == extractOp.getType().getRank())
- return true;
-
- RankedTensorType inferredType = extractOp.inferResultType(
- extractOp.getSourceType(), extractOp.getMixedOffsets(),
- extractOp.getMixedSizes(), extractOp.getMixedStrides());
- return extractOp.getType().getShape().take_back(trailingRank) ==
- inferredType.getShape().take_back(trailingRank);
-}
-
namespace {
-/// Fold transfer_reads of a tensor.extract_slice op. E.g.:
-///
-/// ```
-/// %0 = tensor.extract_slice %t[%a, %b] [%c, %d] [1, 1]
-/// : tensor<?x?xf32> to tensor<?x?xf32>
-/// %1 = vector.transfer_read %0[%e, %f], %cst {in_bounds = [true, true]}
-/// : tensor<?x?xf32>, vector<4x5xf32>
-/// ```
-/// is rewritten to:
-/// ```
-/// %p0 = arith.addi %a, %e : index
-/// %p1 = arith.addi %b, %f : index
-/// %1 = vector.transfer_read %t[%p0, %p1], %cst {in_bounds = [true, true]}
-/// : tensor<?x?xf32>, vector<4x5xf32>
-/// ```
-// TODO: this is brittle and should be deprecated in favor of a more general
-// pattern that applies on-demand.
-struct FoldExtractSliceIntoTransferRead
- : public OpRewritePattern<TransferReadOp> {
-public:
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(TransferReadOp xferOp,
- PatternRewriter &rewriter) const override {
- // TODO: support 0-d corner case.
- if (xferOp.getTransferRank() == 0)
- return failure();
- if (xferOp.hasOutOfBoundsDim())
- return failure();
- if (!xferOp.getPermutationMap().isMinorIdentity())
- return failure();
- if (xferOp.getMask())
- return failure();
- auto extractOp = xferOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
- if (!extractOp)
- return failure();
- if (!extractOp.hasUnitStride())
- return failure();
-
- // Bail on illegal rank-reduction: we need to check that the rank-reduced
- // dims are exactly the leading dims. I.e. the following is illegal:
- // ```
- // %0 = tensor.extract_slice %t[0,0,0][2,1,4][1,1,1] :
- // tensor<2x1x4xf32> to tensor<2x4xf32>
- // %1 = vector.transfer_read %0[0,0], %cst :
- // tensor<2x4xf32>, vector<2x4xf32>
- // ```
- //
- // Cannot fold into:
- // ```
- // %0 = vector.transfer_read %t[0,0,0], %cst :
- // tensor<2x1x4xf32>, vector<2x4xf32>
- // ```
- // For this, check the trailing `vectorRank` dims of the extract_slice
- // result tensor match the trailing dims of the inferred result tensor.
- if (!areAllRankReducedLeadingDim(extractOp, extractOp.getType().getRank()))
- return failure();
-
- int64_t rankReduced =
- extractOp.getSourceType().getRank() - extractOp.getType().getRank();
-
- SmallVector<Value> newIndices;
- // In case this is a rank-reducing ExtractSliceOp, copy rank-reduced
- // indices first.
- for (int64_t i = 0; i < rankReduced; ++i) {
- OpFoldResult offset = extractOp.getMixedOffsets()[i];
- newIndices.push_back(getValueOrCreateConstantIndexOp(
- rewriter, extractOp.getLoc(), offset));
- }
- for (const auto &it : llvm::enumerate(xferOp.getIndices())) {
- OpFoldResult offset =
- extractOp.getMixedOffsets()[it.index() + rankReduced];
- newIndices.push_back(rewriter.create<arith::AddIOp>(
- xferOp->getLoc(), it.value(),
- getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(),
- offset)));
- }
- SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
- rewriter.replaceOpWithNewOp<TransferReadOp>(
- xferOp, xferOp.getVectorType(), extractOp.getSource(), newIndices,
- xferOp.getPadding(), ArrayRef<bool>{inBounds});
-
- return success();
- }
-};
-
/// Store to load forwarding for transfer operations with permuation maps.
/// Even if the permutation maps are different we can still propagate the store
/// into the load if the size of the dimensions read and written match. Then we
@@ -3875,13 +3774,7 @@ struct TransferReadAfterWriteToBroadcast
void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- // clang-format off
- results.add <
- // TODO: this is brittle and should be deprecated in favor of a
- // more general pattern that applies on-demand.
- FoldExtractSliceIntoTransferRead,
- TransferReadAfterWriteToBroadcast>(context);
- // clang-format on
+ results.add<TransferReadAfterWriteToBroadcast>(context);
}
//===----------------------------------------------------------------------===//
@@ -4217,93 +4110,6 @@ public:
}
};
-/// Fold tensor.insert_slice into vector.transfer_write if the transfer_write
-/// could directly write to the insert_slice's destination. E.g.:
-///
-/// ```
-/// %0 = vector.transfer_write %v, %t1[%c0, %c0] {in_bounds = [true, true]}
-/// : vector<4x5xf32>, tensor<4x5xf32>
-/// %1 = tensor.insert_slice %0 into %t2[%a, %b] [4, 5] [1, 1]
-/// : tensor<4x5xf32> into tensor<?x?xf32>
-/// ```
-/// is rewritten to:
-/// ```
-/// %1 = vector.transfer_write %v, %t2[%a, %b] {in_bounds = [true, true]}
-/// : vector<4x5xf32>, tensor<?x?xf32>
-/// ```
-// TODO: this is brittle and should be deprecated in favor of a more general
-// pattern that applies on-demand.
-struct FoldInsertSliceIntoTransferWrite
- : public OpRewritePattern<tensor::InsertSliceOp> {
-public:
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
- PatternRewriter &rewriter) const override {
- if (!insertOp.hasUnitStride())
- return failure();
-
- auto xferOp = insertOp.getSource().getDefiningOp<TransferWriteOp>();
- if (!xferOp)
- return failure();
- // TODO: support 0-d corner case.
- if (xferOp.getTransferRank() == 0)
- return failure();
-
- if (xferOp.hasOutOfBoundsDim())
- return failure();
- if (xferOp.getVectorType().getRank() != xferOp.getShapedType().getRank())
- return failure();
- if (xferOp.getMask())
- return failure();
- // Fold only if the TransferWriteOp completely overwrites the `source` with
- // a vector. I.e., the result of the TransferWriteOp is a new tensor whose
- // content is the data of the vector.
- if (!llvm::equal(xferOp.getVectorType().getShape(),
- xferOp.getShapedType().getShape()))
- return failure();
- if (!xferOp.getPermutationMap().isIdentity())
- return failure();
-
- // Bail on illegal rank-reduction: we need to check that the rank-reduced
- // dims are exactly the leading dims. I.e. the following is illegal:
- // ```
- // %0 = vector.transfer_write %v, %t[0,0], %cst :
- // vector<2x4xf32>, tensor<2x4xf32>
- // %1 = tensor.insert_slice %0 into %tt[0,0,0][2,1,4][1,1,1] :
- // tensor<2x4xf32> into tensor<2x1x4xf32>
- // ```
- //
- // Cannot fold into:
- // ```
- // %0 = vector.transfer_write %v, %t[0,0,0], %cst :
- // vector<2x4xf32>, tensor<2x1x4xf32>
- // ```
- // For this, check the trailing `vectorRank` dims of the insert_slice result
- // tensor match the trailing dims of the inferred result tensor.
- int64_t rankReduced =
- insertOp.getType().getRank() - insertOp.getSourceType().getRank();
- int64_t vectorRank = xferOp.getVectorType().getRank();
- RankedTensorType inferredSourceTensorType =
- tensor::ExtractSliceOp::inferResultType(
- insertOp.getType(), insertOp.getMixedOffsets(),
- insertOp.getMixedSizes(), insertOp.getMixedStrides());
- auto actualSourceTensorShape = insertOp.getSourceType().getShape();
- if (rankReduced > 0 &&
- actualSourceTensorShape.take_back(vectorRank) !=
- inferredSourceTensorType.getShape().take_back(vectorRank))
- return failure();
-
- SmallVector<Value> indices = getValueOrCreateConstantIndexOp(
- rewriter, insertOp.getLoc(), insertOp.getMixedOffsets());
- SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
- rewriter.replaceOpWithNewOp<TransferWriteOp>(insertOp, xferOp.getVector(),
- insertOp.getDest(), indices,
- ArrayRef<bool>{inBounds});
- return success();
- }
-};
-
/// Rewrite tensor::ExtractSliceOp(vector::TransferWriteOp) to
/// vector::TransferWriteOp(tensor::ExtractSliceOp) if the full slice is
/// overwritten and inserted into another tensor. After this rewrite, the
@@ -4415,13 +4221,7 @@ public:
void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- // clang-format off
- results.add<FoldWaw,
- // TODO: this is brittle and should be deprecated in favor of a
- // more general pattern that applies on-demand.
- FoldInsertSliceIntoTransferWrite,
- SwapExtractSliceOfTransferWrite>(context);
- // clang-format on
+ results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index deba91573e0f..2d269ca3555d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
VectorDropLeadUnitDim.cpp
VectorInsertExtractStridedSliceRewritePatterns.cpp
VectorTransferOpTransforms.cpp
+ VectorTransferTensorSliceTransforms.cpp
VectorTransferSplitRewritePatterns.cpp
VectorTransforms.cpp
VectorUnroll.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferTensorSliceTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferTensorSliceTransforms.cpp
new file mode 100644
index 000000000000..b3bd2cc85dfe
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferTensorSliceTransforms.cpp
@@ -0,0 +1,237 @@
+//===- VectorTransferTensorSliceTransforms.cpp ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+
+/// Returns true if all rank reduced in the given `extractOp` happen in leading
+/// dimensions earlier than last `trailingRank` dimensions.
+static bool areAllRankReducedLeadingDim(tensor::ExtractSliceOp extractOp,
+ unsigned trailingRank) {
+ // If no ranks are reduced at all, it's a degenerated case; always true.
+ if (extractOp.getSourceType().getRank() == extractOp.getType().getRank())
+ return true;
+
+ RankedTensorType inferredType = extractOp.inferResultType(
+ extractOp.getSourceType(), extractOp.getMixedOffsets(),
+ extractOp.getMixedSizes(), extractOp.getMixedStrides());
+ return extractOp.getType().getShape().take_back(trailingRank) ==
+ inferredType.getShape().take_back(trailingRank);
+}
+
+namespace {
+/// Fold transfer_reads of a tensor.extract_slice op. E.g.:
+///
+/// ```
+/// %0 = tensor.extract_slice %t[%a, %b] [%c, %d] [1, 1]
+/// : tensor<?x?xf32> to tensor<?x?xf32>
+/// %1 = vector.transfer_read %0[%e, %f], %cst {in_bounds = [true, true]}
+/// : tensor<?x?xf32>, vector<4x5xf32>
+/// ```
+/// is rewritten to:
+/// ```
+/// %p0 = arith.addi %a, %e : index
+/// %p1 = arith.addi %b, %f : index
+/// %1 = vector.transfer_read %t[%p0, %p1], %cst {in_bounds = [true, true]}
+/// : tensor<?x?xf32>, vector<4x5xf32>
+/// ```
+// TODO: this is brittle and should be deprecated in favor of a more general
+// pattern that applies on-demand.
+class FoldExtractSliceIntoTransferRead final
+ : public OpRewritePattern<vector::TransferReadOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ FoldExtractSliceIntoTransferRead(MLIRContext *context,
+ std::function<bool(Operation *op)> controlFn,
+ PatternBenefit benefit)
+ : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}
+
+ LogicalResult matchAndRewrite(vector::TransferReadOp xferOp,
+ PatternRewriter &rewriter) const override {
+ if (controlFn && !controlFn(xferOp))
+ return failure();
+
+ // TODO: support 0-d corner case.
+ if (xferOp.getTransferRank() == 0)
+ return failure();
+ if (xferOp.hasOutOfBoundsDim())
+ return failure();
+ if (!xferOp.getPermutationMap().isMinorIdentity())
+ return failure();
+ if (xferOp.getMask())
+ return failure();
+ auto extractOp = xferOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
+ if (!extractOp)
+ return failure();
+ if (!extractOp.hasUnitStride())
+ return failure();
+
+ // Bail on illegal rank-reduction: we need to check that the rank-reduced
+ // dims are exactly the leading dims. I.e. the following is illegal:
+ // ```
+ // %0 = tensor.extract_slice %t[0,0,0][2,1,4][1,1,1] :
+ // tensor<2x1x4xf32> to tensor<2x4xf32>
+ // %1 = vector.transfer_read %0[0,0], %cst :
+ // tensor<2x4xf32>, vector<2x4xf32>
+ // ```
+ //
+ // Cannot fold into:
+ // ```
+ // %0 = vector.transfer_read %t[0,0,0], %cst :
+ // tensor<2x1x4xf32>, vector<2x4xf32>
+ // ```
+ // For this, check the trailing `vectorRank` dims of the extract_slice
+ // result tensor match the trailing dims of the inferred result tensor.
+ if (!areAllRankReducedLeadingDim(extractOp, extractOp.getType().getRank()))
+ return failure();
+
+ int64_t rankReduced =
+ extractOp.getSourceType().getRank() - extractOp.getType().getRank();
+
+ SmallVector<Value> newIndices;
+ // In case this is a rank-reducing ExtractSliceOp, copy rank-reduced
+ // indices first.
+ for (int64_t i = 0; i < rankReduced; ++i) {
+ OpFoldResult offset = extractOp.getMixedOffsets()[i];
+ newIndices.push_back(getValueOrCreateConstantIndexOp(
+ rewriter, extractOp.getLoc(), offset));
+ }
+ for (const auto &it : llvm::enumerate(xferOp.getIndices())) {
+ OpFoldResult offset =
+ extractOp.getMixedOffsets()[it.index() + rankReduced];
+ newIndices.push_back(rewriter.create<arith::AddIOp>(
+ xferOp->getLoc(), it.value(),
+ getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(),
+ offset)));
+ }
+ SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
+ rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
+ xferOp, xferOp.getVectorType(), extractOp.getSource(), newIndices,
+ xferOp.getPadding(), ArrayRef<bool>{inBounds});
+
+ return success();
+ }
+
+private:
+ std::function<bool(Operation *)> controlFn;
+};
+
+/// Fold tensor.insert_slice into vector.transfer_write if the transfer_write
+/// could directly write to the insert_slice's destination. E.g.:
+///
+/// ```
+/// %0 = vector.transfer_write %v, %t1[%c0, %c0] {in_bounds = [true, true]}
+/// : vector<4x5xf32>, tensor<4x5xf32>
+/// %1 = tensor.insert_slice %0 into %t2[%a, %b] [4, 5] [1, 1]
+/// : tensor<4x5xf32> into tensor<?x?xf32>
+/// ```
+/// is rewritten to:
+/// ```
+/// %1 = vector.transfer_write %v, %t2[%a, %b] {in_bounds = [true, true]}
+/// : vector<4x5xf32>, tensor<?x?xf32>
+/// ```
+// TODO: this is brittle and should be deprecated in favor of a more general
+// pattern that applies on-demand.
+class FoldInsertSliceIntoTransferWrite final
+ : public OpRewritePattern<tensor::InsertSliceOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ FoldInsertSliceIntoTransferWrite(MLIRContext *context,
+ std::function<bool(Operation *op)> controlFn,
+ PatternBenefit benefit)
+ : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}
+
+ LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
+ PatternRewriter &rewriter) const override {
+ if (!insertOp.hasUnitStride())
+ return failure();
+
+ auto xferOp = insertOp.getSource().getDefiningOp<vector::TransferWriteOp>();
+ if (!xferOp)
+ return failure();
+ if (controlFn && !controlFn(xferOp))
+ return failure();
+
+ // TODO: support 0-d corner case.
+ if (xferOp.getTransferRank() == 0)
+ return failure();
+
+ if (xferOp.hasOutOfBoundsDim())
+ return failure();
+ if (xferOp.getVectorType().getRank() != xferOp.getShapedType().getRank())
+ return failure();
+ if (xferOp.getMask())
+ return failure();
+ // Fold only if the TransferWriteOp completely overwrites the `source` with
+ // a vector. I.e., the result of the TransferWriteOp is a new tensor whose
+ // content is the data of the vector.
+ if (!llvm::equal(xferOp.getVectorType().getShape(),
+ xferOp.getShapedType().getShape()))
+ return failure();
+ if (!xferOp.getPermutationMap().isIdentity())
+ return failure();
+
+ // Bail on illegal rank-reduction: we need to check that the rank-reduced
+ // dims are exactly the leading dims. I.e. the following is illegal:
+ // ```
+ // %0 = vector.transfer_write %v, %t[0,0], %cst :
+ // vector<2x4xf32>, tensor<2x4xf32>
+ // %1 = tensor.insert_slice %0 into %tt[0,0,0][2,1,4][1,1,1] :
+ // tensor<2x4xf32> into tensor<2x1x4xf32>
+ // ```
+ //
+ // Cannot fold into:
+ // ```
+ // %0 = vector.transfer_write %v, %t[0,0,0], %cst :
+ // vector<2x4xf32>, tensor<2x1x4xf32>
+ // ```
+ // For this, check the trailing `vectorRank` dims of the insert_slice result
+ // tensor match the trailing dims of the inferred result tensor.
+ int64_t rankReduced =
+ insertOp.getType().getRank() - insertOp.getSourceType().getRank();
+ int64_t vectorRank = xferOp.getVectorType().getRank();
+ RankedTensorType inferredSourceTensorType =
+ tensor::ExtractSliceOp::inferResultType(
+ insertOp.getType(), insertOp.getMixedOffsets(),
+ insertOp.getMixedSizes(), insertOp.getMixedStrides());
+ auto actualSourceTensorShape = insertOp.getSourceType().getShape();
+ if (rankReduced > 0 &&
+ actualSourceTensorShape.take_back(vectorRank) !=
+ inferredSourceTensorType.getShape().take_back(vectorRank))
+ return failure();
+
+ SmallVector<Value> indices = getValueOrCreateConstantIndexOp(
+ rewriter, insertOp.getLoc(), insertOp.getMixedOffsets());
+ SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
+ rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+ insertOp, xferOp.getVector(), insertOp.getDest(), indices,
+ ArrayRef<bool>{inBounds});
+ return success();
+ }
+
+private:
+ std::function<bool(Operation *)> controlFn;
+};
+
+} // namespace
+
+void vector::populateVectorTransferTensorSliceTransforms(
+ RewritePatternSet &patterns,
+ std::function<bool(Operation *vectorOp)> controlFn,
+ PatternBenefit benefit) {
+ patterns
+ .add<FoldExtractSliceIntoTransferRead, FoldInsertSliceIntoTransferWrite>(
+ patterns.getContext(), controlFn, benefit);
+}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 88c91ff46a8b..4ce4350f0e4f 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1201,116 +1201,6 @@ func.func @dead_store_tensor_negative(%arg0 : tensor<4x4xf32>,
// -----
-// CHECK-LABEL: func @transfer_read_of_extract_slice(
-// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
-// CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
-// CHECK-DAG: %[[c8:.*]] = arith.constant 8 : index
-// CHECK: %[[add:.*]] = arith.addi %[[s1]], %[[c4]]
-// CHECK: %[[r:.*]] = vector.transfer_read %[[t]][%[[c8]], %[[add]]], %{{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<5x6xf32>
-// CHECK: return %[[r]]
-func.func @transfer_read_of_extract_slice(%t : tensor<?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
- %c3 = arith.constant 3 : index
- %c4 = arith.constant 4 : index
- %cst = arith.constant 0.0 : f32
- %0 = tensor.extract_slice %t[5, %s1] [10, %s2] [1, 1] : tensor<?x?xf32> to tensor<10x?xf32>
- %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor<10x?xf32>, vector<5x6xf32>
- return %1 : vector<5x6xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @transfer_read_of_extract_slice(
-// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
-// CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
-// CHECK-DAG: %[[c8:.*]] = arith.constant 8 : index
-// CHECK: %[[add:.*]] = arith.addi %[[s1]], %[[c4]]
-// CHECK: %[[r:.*]] = vector.transfer_read %[[t]][%[[c8]], %[[add]]], %{{.*}} {in_bounds = [true]} : tensor<?x?xf32>, vector<6xf32>
-// CHECK: return %[[r]]
-func.func @transfer_read_of_extract_slice(%t : tensor<?x?xf32>, %s1 : index, %s2 : index) -> vector<6xf32> {
- %c3 = arith.constant 3 : index
- %c4 = arith.constant 4 : index
- %cst = arith.constant 0.0 : f32
- %0 = tensor.extract_slice %t[5, %s1] [10, %s2] [1, 1] : tensor<?x?xf32> to tensor<10x?xf32>
- %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true]} : tensor<10x?xf32>, vector<6xf32>
- return %1 : vector<6xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @transfer_read_of_extract_slice_rank_reducing(
-// CHECK-SAME: %[[t:.*]]: tensor<?x?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
-// CHECK-DAG: %[[c3:.*]] = arith.constant 3 : index
-// CHECK-DAG: %[[c5:.*]] = arith.constant 5 : index
-// CHECK-DAG: %[[c10:.*]] = arith.constant 10 : index
-// CHECK: %[[add:.*]] = arith.addi %[[s1]], %[[c3]]
-// CHECK: %[[r:.*]] = vector.transfer_read %[[t]][%[[c5]], %[[add]], %[[c10]]], %{{.*}} {in_bounds = [true, true]} : tensor<?x?x?xf32>, vector<5x6xf32>
-// CHECK: return %[[r]]
-func.func @transfer_read_of_extract_slice_rank_reducing(%t : tensor<?x?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
- %c3 = arith.constant 3 : index
- %c4 = arith.constant 4 : index
- %cst = arith.constant 0.0 : f32
- %0 = tensor.extract_slice %t[5, %s1, 6] [1, %s2, 12] [1, 1, 1] : tensor<?x?x?xf32> to tensor<?x12xf32>
- %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor<?x12xf32>, vector<5x6xf32>
- return %1 : vector<5x6xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @transfer_read_of_extract_slice_illegal_rank_reducing(
-// CHECK: extract_slice
-// CHECK: vector.transfer_read
-func.func @transfer_read_of_extract_slice_illegal_rank_reducing(%t : tensor<?x?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
- %c3 = arith.constant 3 : index
- %c4 = arith.constant 4 : index
- %cst = arith.constant 0.0 : f32
- %0 = tensor.extract_slice %t[5, %s1, 6] [%s2, 1, 12] [1, 1, 1] : tensor<?x?x?xf32> to tensor<?x12xf32>
- %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor<?x12xf32>, vector<5x6xf32>
- return %1 : vector<5x6xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @insert_slice_of_transfer_write(
-// CHECK-SAME: %[[t1:.*]]: tensor<?x12xf32>, %[[v:.*]]: vector<5x6xf32>, %[[s:.*]]: index
-// CHECK: %[[c3:.*]] = arith.constant 3 : index
-// CHECK: %[[r:.*]] = vector.transfer_write %[[v]], %[[t1]][%[[c3]], %[[s]]] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<?x12xf32>
-// CHECK: return %[[r]]
-func.func @insert_slice_of_transfer_write(%t1 : tensor<?x12xf32>, %v : vector<5x6xf32>, %s : index, %t2 : tensor<5x6xf32>) -> tensor<?x12xf32> {
- %c0 = arith.constant 0 : index
- %0 = vector.transfer_write %v, %t2[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<5x6xf32>
- %1 = tensor.insert_slice %0 into %t1[3, %s] [5, 6] [1, 1] : tensor<5x6xf32> into tensor<?x12xf32>
- return %1 : tensor<?x12xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @insert_slice_of_transfer_write_illegal_rank_extending(
-// CHECK: vector.transfer_write
-// CHECK: insert_slice
-func.func @insert_slice_of_transfer_write_illegal_rank_extending(%t1 : tensor<?x?x12xf32>, %v : vector<5x6xf32>, %s : index, %t2 : tensor<5x6xf32>) -> tensor<?x?x12xf32> {
- %c0 = arith.constant 0 : index
- %0 = vector.transfer_write %v, %t2[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<5x6xf32>
- %1 = tensor.insert_slice %0 into %t1[4, 3, %s] [5, 1, 6] [1, 1, 1] : tensor<5x6xf32> into tensor<?x?x12xf32>
- return %1 : tensor<?x?x12xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @insert_slice_of_transfer_write_rank_extending(
-// CHECK-SAME: %[[t1:.*]]: tensor<?x?x12xf32>, %[[v:.*]]: vector<5x6xf32>, %[[s:.*]]: index
-// CHECK-DAG: %[[c3:.*]] = arith.constant 3 : index
-// CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
-// CHECK: %[[r:.*]] = vector.transfer_write %[[v]], %[[t1]][%[[c4]], %[[c3]], %[[s]]] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<?x?x12xf32>
-// CHECK: return %[[r]]
-func.func @insert_slice_of_transfer_write_rank_extending(%t1 : tensor<?x?x12xf32>, %v : vector<5x6xf32>, %s : index, %t2 : tensor<5x6xf32>) -> tensor<?x?x12xf32> {
- %c0 = arith.constant 0 : index
- %0 = vector.transfer_write %v, %t2[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<5x6xf32>
- %1 = tensor.insert_slice %0 into %t1[4, 3, %s] [1, 5, 6] [1, 1, 1] : tensor<5x6xf32> into tensor<?x?x12xf32>
- return %1 : tensor<?x?x12xf32>
-}
-
-// -----
-
// CHECK: #[[$MAP:[0-9a-z]+]] = affine_map<(d0, d1) -> (d1, d0)>
// CHECK-LABEL: func @swap_extract_slice_transfer_write
diff --git a/mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir b/mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir
new file mode 100644
index 000000000000..cc17025fe0f1
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir
@@ -0,0 +1,109 @@
+// RUN: mlir-opt -split-input-file -test-vector-transfer-tensor-slice-patterns %s | FileCheck %s
+
+// CHECK-LABEL: func @transfer_read_of_extract_slice(
+// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
+// CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[c8:.*]] = arith.constant 8 : index
+// CHECK: %[[add:.*]] = arith.addi %[[s1]], %[[c4]]
+// CHECK: %[[r:.*]] = vector.transfer_read %[[t]][%[[c8]], %[[add]]], %{{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<5x6xf32>
+// CHECK: return %[[r]]
+func.func @transfer_read_of_extract_slice(%t : tensor<?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
+ %c3 = arith.constant 3 : index
+ %c4 = arith.constant 4 : index
+ %cst = arith.constant 0.0 : f32
+ %0 = tensor.extract_slice %t[5, %s1] [10, %s2] [1, 1] : tensor<?x?xf32> to tensor<10x?xf32>
+ %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor<10x?xf32>, vector<5x6xf32>
+ return %1 : vector<5x6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @transfer_read_of_extract_slice(
+// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
+// CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[c8:.*]] = arith.constant 8 : index
+// CHECK: %[[add:.*]] = arith.addi %[[s1]], %[[c4]]
+// CHECK: %[[r:.*]] = vector.transfer_read %[[t]][%[[c8]], %[[add]]], %{{.*}} {in_bounds = [true]} : tensor<?x?xf32>, vector<6xf32>
+// CHECK: return %[[r]]
+func.func @transfer_read_of_extract_slice(%t : tensor<?x?xf32>, %s1 : index, %s2 : index) -> vector<6xf32> {
+ %c3 = arith.constant 3 : index
+ %c4 = arith.constant 4 : index
+ %cst = arith.constant 0.0 : f32
+ %0 = tensor.extract_slice %t[5, %s1] [10, %s2] [1, 1] : tensor<?x?xf32> to tensor<10x?xf32>
+ %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true]} : tensor<10x?xf32>, vector<6xf32>
+ return %1 : vector<6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @transfer_read_of_extract_slice_rank_reducing(
+// CHECK-SAME: %[[t:.*]]: tensor<?x?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
+// CHECK-DAG: %[[c3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[c5:.*]] = arith.constant 5 : index
+// CHECK-DAG: %[[c10:.*]] = arith.constant 10 : index
+// CHECK: %[[add:.*]] = arith.addi %[[s1]], %[[c3]]
+// CHECK: %[[r:.*]] = vector.transfer_read %[[t]][%[[c5]], %[[add]], %[[c10]]], %{{.*}} {in_bounds = [true, true]} : tensor<?x?x?xf32>, vector<5x6xf32>
+// CHECK: return %[[r]]
+func.func @transfer_read_of_extract_slice_rank_reducing(%t : tensor<?x?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
+ %c3 = arith.constant 3 : index
+ %c4 = arith.constant 4 : index
+ %cst = arith.constant 0.0 : f32
+ %0 = tensor.extract_slice %t[5, %s1, 6] [1, %s2, 12] [1, 1, 1] : tensor<?x?x?xf32> to tensor<?x12xf32>
+ %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor<?x12xf32>, vector<5x6xf32>
+ return %1 : vector<5x6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @transfer_read_of_extract_slice_illegal_rank_reducing(
+// CHECK: extract_slice
+// CHECK: vector.transfer_read
+func.func @transfer_read_of_extract_slice_illegal_rank_reducing(%t : tensor<?x?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
+ %c3 = arith.constant 3 : index
+ %c4 = arith.constant 4 : index
+ %cst = arith.constant 0.0 : f32
+ %0 = tensor.extract_slice %t[5, %s1, 6] [%s2, 1, 12] [1, 1, 1] : tensor<?x?x?xf32> to tensor<?x12xf32>
+ %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor<?x12xf32>, vector<5x6xf32>
+ return %1 : vector<5x6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @insert_slice_of_transfer_write(
+// CHECK-SAME: %[[t1:.*]]: tensor<?x12xf32>, %[[v:.*]]: vector<5x6xf32>, %[[s:.*]]: index
+// CHECK: %[[c3:.*]] = arith.constant 3 : index
+// CHECK: %[[r:.*]] = vector.transfer_write %[[v]], %[[t1]][%[[c3]], %[[s]]] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<?x12xf32>
+// CHECK: return %[[r]]
+func.func @insert_slice_of_transfer_write(%t1 : tensor<?x12xf32>, %v : vector<5x6xf32>, %s : index, %t2 : tensor<5x6xf32>) -> tensor<?x12xf32> {
+ %c0 = arith.constant 0 : index
+ %0 = vector.transfer_write %v, %t2[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<5x6xf32>
+ %1 = tensor.insert_slice %0 into %t1[3, %s] [5, 6] [1, 1] : tensor<5x6xf32> into tensor<?x12xf32>
+ return %1 : tensor<?x12xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @insert_slice_of_transfer_write_illegal_rank_extending(
+// CHECK: vector.transfer_write
+// CHECK: insert_slice
+func.func @insert_slice_of_transfer_write_illegal_rank_extending(%t1 : tensor<?x?x12xf32>, %v : vector<5x6xf32>, %s : index, %t2 : tensor<5x6xf32>) -> tensor<?x?x12xf32> {
+ %c0 = arith.constant 0 : index
+ %0 = vector.transfer_write %v, %t2[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<5x6xf32>
+ %1 = tensor.insert_slice %0 into %t1[4, 3, %s] [5, 1, 6] [1, 1, 1] : tensor<5x6xf32> into tensor<?x?x12xf32>
+ return %1 : tensor<?x?x12xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @insert_slice_of_transfer_write_rank_extending(
+// CHECK-SAME: %[[t1:.*]]: tensor<?x?x12xf32>, %[[v:.*]]: vector<5x6xf32>, %[[s:.*]]: index
+// CHECK-DAG: %[[c3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
+// CHECK: %[[r:.*]] = vector.transfer_write %[[v]], %[[t1]][%[[c4]], %[[c3]], %[[s]]] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<?x?x12xf32>
+// CHECK: return %[[r]]
+func.func @insert_slice_of_transfer_write_rank_extending(%t1 : tensor<?x?x12xf32>, %v : vector<5x6xf32>, %s : index, %t2 : tensor<5x6xf32>) -> tensor<?x?x12xf32> {
+ %c0 = arith.constant 0 : index
+ %0 = vector.transfer_write %v, %t2[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<5x6xf32>
+ %1 = tensor.insert_slice %0 into %t1[4, 3, %s] [1, 5, 6] [1, 1, 1] : tensor<5x6xf32> into tensor<?x?x12xf32>
+ return %1 : tensor<?x?x12xf32>
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index d0c79ab98915..50dfeff635cc 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -679,6 +679,26 @@ struct TestVectorGatherLowering
}
};
+struct TestVectorTransferTensorSlicePatterns
+ : public PassWrapper<TestVectorTransferTensorSlicePatterns,
+ OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+ TestVectorTransferTensorSlicePatterns)
+
+ StringRef getArgument() const final {
+ return "test-vector-transfer-tensor-slice-patterns";
+ }
+ StringRef getDescription() const final {
+ return "Test patterns that fold vector transfer and tensor slice ops";
+ }
+
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ populateVectorTransferTensorSliceTransforms(patterns);
+ (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+ }
+};
+
} // namespace
namespace mlir {
@@ -713,6 +733,8 @@ void registerTestVectorLowerings() {
PassRegistration<TestCreateVectorBroadcast>();
PassRegistration<TestVectorGatherLowering>();
+
+ PassRegistration<TestVectorTransferTensorSlicePatterns>();
}
} // namespace test
} // namespace mlir