Diffstat (limited to 'mlir')
-rw-r--r--  mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h        |  11
-rw-r--r--  mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp                |   2
-rw-r--r--  mlir/lib/Dialect/Vector/IR/VectorOps.cpp                                    | 204
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt                           |   1
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/VectorTransferTensorSliceTransforms.cpp | 237
-rw-r--r--  mlir/test/Dialect/Vector/canonicalize.mlir                                  | 110
-rw-r--r--  mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir      | 109
-rw-r--r--  mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp                       |  22
8 files changed, 384 insertions, 312 deletions
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 325860079b3d..2912c0252872 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -218,6 +218,17 @@ void populateBreakDownVectorBitCastOpPatterns(
void populateVectorInsertExtractStridedSliceTransforms(
RewritePatternSet &patterns, PatternBenefit benefit = 1);
+/// Collect patterns to fold tensor.extract_slice -> vector.transfer_read and
+/// vector.transfer_write -> tensor.insert_slice op chains into vector transfer
+/// read and write ops.
+///
+/// If `controlFn` is not nullptr, the pattern will only apply to ops where
+/// `controlFn` returns true, given the vector transfer read/write op as input.
+void populateVectorTransferTensorSliceTransforms(
+ RewritePatternSet &patterns,
+ std::function<bool(Operation *vectorOp)> controlFn = nullptr,
+ PatternBenefit benefit = 1);
+
/// Collect a set of patterns to unroll vector operations to smaller shapes.
/// `options` structure controls which operations are unrolled and the target
/// shape.
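
A minimal usage sketch (not part of the patch): a downstream pass could register these
patterns with a control function to restrict which transfer ops get folded. The pass body
and the scf.for filter below are hypothetical and only illustrate the `controlFn` hook;
they assume the SCF dialect and the greedy rewrite driver headers are included.

```
// Hypothetical pass body: fold slice ops only around transfers nested in loops.
void runOnOperation() /* override */ {
  RewritePatternSet patterns(&getContext());
  vector::populateVectorTransferTensorSliceTransforms(
      patterns,
      /*controlFn=*/[](Operation *vectorOp) {
        // Only fold transfer ops that are nested inside an scf.for.
        return vectorOp->getParentOfType<scf::ForOp>() != nullptr;
      });
  if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
    signalPassFailure();
}
```

Passing no `controlFn` (as the Linalg vectorize transform op does below) applies the
patterns unconditionally.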
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index baabf1ae67fc..4fe9b9fab6a7 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -29,6 +29,7 @@
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/TilingInterface.h"
@@ -2880,6 +2881,7 @@ transform::VectorizeOp::applyToOne(Operation *target,
/*benefit=*/2);
vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
+ vector::populateVectorTransferTensorSliceTransforms(patterns);
patterns.add<CopyVectorizationPattern>(ctx);
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 80ea03f6e8d8..1549237f8c9f 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3692,108 +3692,7 @@ void TransferReadOp::getEffects(
SideEffects::DefaultResource::get());
}
-/// Returns true if all rank reduced in the given `extractOp` happen in leading
-/// dimensions earlier than last `trailingRank` dimensions.
-static bool areAllRankReducedLeadingDim(tensor::ExtractSliceOp extractOp,
- unsigned trailingRank) {
- // If no ranks are reduced at all, it's a degenerated case; always true.
- if (extractOp.getSourceType().getRank() == extractOp.getType().getRank())
- return true;
-
- RankedTensorType inferredType = extractOp.inferResultType(
- extractOp.getSourceType(), extractOp.getMixedOffsets(),
- extractOp.getMixedSizes(), extractOp.getMixedStrides());
- return extractOp.getType().getShape().take_back(trailingRank) ==
- inferredType.getShape().take_back(trailingRank);
-}
-
namespace {
-/// Fold transfer_reads of a tensor.extract_slice op. E.g.:
-///
-/// ```
-/// %0 = tensor.extract_slice %t[%a, %b] [%c, %d] [1, 1]
-/// : tensor<?x?xf32> to tensor<?x?xf32>
-/// %1 = vector.transfer_read %0[%e, %f], %cst {in_bounds = [true, true]}
-/// : tensor<?x?xf32>, vector<4x5xf32>
-/// ```
-/// is rewritten to:
-/// ```
-/// %p0 = arith.addi %a, %e : index
-/// %p1 = arith.addi %b, %f : index
-/// %1 = vector.transfer_read %t[%p0, %p1], %cst {in_bounds = [true, true]}
-/// : tensor<?x?xf32>, vector<4x5xf32>
-/// ```
-// TODO: this is brittle and should be deprecated in favor of a more general
-// pattern that applies on-demand.
-struct FoldExtractSliceIntoTransferRead
- : public OpRewritePattern<TransferReadOp> {
-public:
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(TransferReadOp xferOp,
- PatternRewriter &rewriter) const override {
- // TODO: support 0-d corner case.
- if (xferOp.getTransferRank() == 0)
- return failure();
- if (xferOp.hasOutOfBoundsDim())
- return failure();
- if (!xferOp.getPermutationMap().isMinorIdentity())
- return failure();
- if (xferOp.getMask())
- return failure();
- auto extractOp = xferOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
- if (!extractOp)
- return failure();
- if (!extractOp.hasUnitStride())
- return failure();
-
- // Bail on illegal rank-reduction: we need to check that the rank-reduced
- // dims are exactly the leading dims. I.e. the following is illegal:
- // ```
- // %0 = tensor.extract_slice %t[0,0,0][2,1,4][1,1,1] :
- // tensor<2x1x4xf32> to tensor<2x4xf32>
- // %1 = vector.transfer_read %0[0,0], %cst :
- // tensor<2x4xf32>, vector<2x4xf32>
- // ```
- //
- // Cannot fold into:
- // ```
- // %0 = vector.transfer_read %t[0,0,0], %cst :
- // tensor<2x1x4xf32>, vector<2x4xf32>
- // ```
- // For this, check the trailing `vectorRank` dims of the extract_slice
- // result tensor match the trailing dims of the inferred result tensor.
- if (!areAllRankReducedLeadingDim(extractOp, extractOp.getType().getRank()))
- return failure();
-
- int64_t rankReduced =
- extractOp.getSourceType().getRank() - extractOp.getType().getRank();
-
- SmallVector<Value> newIndices;
- // In case this is a rank-reducing ExtractSliceOp, copy rank-reduced
- // indices first.
- for (int64_t i = 0; i < rankReduced; ++i) {
- OpFoldResult offset = extractOp.getMixedOffsets()[i];
- newIndices.push_back(getValueOrCreateConstantIndexOp(
- rewriter, extractOp.getLoc(), offset));
- }
- for (const auto &it : llvm::enumerate(xferOp.getIndices())) {
- OpFoldResult offset =
- extractOp.getMixedOffsets()[it.index() + rankReduced];
- newIndices.push_back(rewriter.create<arith::AddIOp>(
- xferOp->getLoc(), it.value(),
- getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(),
- offset)));
- }
- SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
- rewriter.replaceOpWithNewOp<TransferReadOp>(
- xferOp, xferOp.getVectorType(), extractOp.getSource(), newIndices,
- xferOp.getPadding(), ArrayRef<bool>{inBounds});
-
- return success();
- }
-};
-
/// Store to load forwarding for transfer operations with permutation maps.
/// Even if the permutation maps are different we can still propagate the store
/// into the load if the size of the dimensions read and written match. Then we
@@ -3875,13 +3774,7 @@ struct TransferReadAfterWriteToBroadcast
void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- // clang-format off
- results.add <
- // TODO: this is brittle and should be deprecated in favor of a
- // more general pattern that applies on-demand.
- FoldExtractSliceIntoTransferRead,
- TransferReadAfterWriteToBroadcast>(context);
- // clang-format on
+ results.add<TransferReadAfterWriteToBroadcast>(context);
}
//===----------------------------------------------------------------------===//
@@ -4217,93 +4110,6 @@ public:
}
};
-/// Fold tensor.insert_slice into vector.transfer_write if the transfer_write
-/// could directly write to the insert_slice's destination. E.g.:
-///
-/// ```
-/// %0 = vector.transfer_write %v, %t1[%c0, %c0] {in_bounds = [true, true]}
-/// : vector<4x5xf32>, tensor<4x5xf32>
-/// %1 = tensor.insert_slice %0 into %t2[%a, %b] [4, 5] [1, 1]
-/// : tensor<4x5xf32> into tensor<?x?xf32>
-/// ```
-/// is rewritten to:
-/// ```
-/// %1 = vector.transfer_write %v, %t2[%a, %b] {in_bounds = [true, true]}
-/// : vector<4x5xf32>, tensor<?x?xf32>
-/// ```
-// TODO: this is brittle and should be deprecated in favor of a more general
-// pattern that applies on-demand.
-struct FoldInsertSliceIntoTransferWrite
- : public OpRewritePattern<tensor::InsertSliceOp> {
-public:
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
- PatternRewriter &rewriter) const override {
- if (!insertOp.hasUnitStride())
- return failure();
-
- auto xferOp = insertOp.getSource().getDefiningOp<TransferWriteOp>();
- if (!xferOp)
- return failure();
- // TODO: support 0-d corner case.
- if (xferOp.getTransferRank() == 0)
- return failure();
-
- if (xferOp.hasOutOfBoundsDim())
- return failure();
- if (xferOp.getVectorType().getRank() != xferOp.getShapedType().getRank())
- return failure();
- if (xferOp.getMask())
- return failure();
- // Fold only if the TransferWriteOp completely overwrites the `source` with
- // a vector. I.e., the result of the TransferWriteOp is a new tensor whose
- // content is the data of the vector.
- if (!llvm::equal(xferOp.getVectorType().getShape(),
- xferOp.getShapedType().getShape()))
- return failure();
- if (!xferOp.getPermutationMap().isIdentity())
- return failure();
-
- // Bail on illegal rank-reduction: we need to check that the rank-reduced
- // dims are exactly the leading dims. I.e. the following is illegal:
- // ```
- // %0 = vector.transfer_write %v, %t[0,0], %cst :
- // vector<2x4xf32>, tensor<2x4xf32>
- // %1 = tensor.insert_slice %0 into %tt[0,0,0][2,1,4][1,1,1] :
- // tensor<2x4xf32> into tensor<2x1x4xf32>
- // ```
- //
- // Cannot fold into:
- // ```
- // %0 = vector.transfer_write %v, %t[0,0,0], %cst :
- // vector<2x4xf32>, tensor<2x1x4xf32>
- // ```
- // For this, check the trailing `vectorRank` dims of the insert_slice result
- // tensor match the trailing dims of the inferred result tensor.
- int64_t rankReduced =
- insertOp.getType().getRank() - insertOp.getSourceType().getRank();
- int64_t vectorRank = xferOp.getVectorType().getRank();
- RankedTensorType inferredSourceTensorType =
- tensor::ExtractSliceOp::inferResultType(
- insertOp.getType(), insertOp.getMixedOffsets(),
- insertOp.getMixedSizes(), insertOp.getMixedStrides());
- auto actualSourceTensorShape = insertOp.getSourceType().getShape();
- if (rankReduced > 0 &&
- actualSourceTensorShape.take_back(vectorRank) !=
- inferredSourceTensorType.getShape().take_back(vectorRank))
- return failure();
-
- SmallVector<Value> indices = getValueOrCreateConstantIndexOp(
- rewriter, insertOp.getLoc(), insertOp.getMixedOffsets());
- SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
- rewriter.replaceOpWithNewOp<TransferWriteOp>(insertOp, xferOp.getVector(),
- insertOp.getDest(), indices,
- ArrayRef<bool>{inBounds});
- return success();
- }
-};
-
/// Rewrite tensor::ExtractSliceOp(vector::TransferWriteOp) to
/// vector::TransferWriteOp(tensor::ExtractSliceOp) if the full slice is
/// overwritten and inserted into another tensor. After this rewrite, the
@@ -4415,13 +4221,7 @@ public:
void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- // clang-format off
- results.add<FoldWaw,
- // TODO: this is brittle and should be deprecated in favor of a
- // more general pattern that applies on-demand.
- FoldInsertSliceIntoTransferWrite,
- SwapExtractSliceOfTransferWrite>(context);
- // clang-format on
+ results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index deba91573e0f..2d269ca3555d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
VectorDropLeadUnitDim.cpp
VectorInsertExtractStridedSliceRewritePatterns.cpp
VectorTransferOpTransforms.cpp
+ VectorTransferTensorSliceTransforms.cpp
VectorTransferSplitRewritePatterns.cpp
VectorTransforms.cpp
VectorUnroll.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferTensorSliceTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferTensorSliceTransforms.cpp
new file mode 100644
index 000000000000..b3bd2cc85dfe
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferTensorSliceTransforms.cpp
@@ -0,0 +1,237 @@
+//===- VectorTransferTensorSliceTransforms.cpp ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+
+/// Returns true if all rank reductions in the given `extractOp` happen in
+/// leading dimensions, i.e., before the last `trailingRank` dimensions.
+static bool areAllRankReducedLeadingDim(tensor::ExtractSliceOp extractOp,
+ unsigned trailingRank) {
+  // If no ranks are reduced at all, it's a degenerate case; always true.
+ if (extractOp.getSourceType().getRank() == extractOp.getType().getRank())
+ return true;
+
+ RankedTensorType inferredType = extractOp.inferResultType(
+ extractOp.getSourceType(), extractOp.getMixedOffsets(),
+ extractOp.getMixedSizes(), extractOp.getMixedStrides());
+ return extractOp.getType().getShape().take_back(trailingRank) ==
+ inferredType.getShape().take_back(trailingRank);
+}
+
+namespace {
+/// Fold transfer_reads of a tensor.extract_slice op. E.g.:
+///
+/// ```
+/// %0 = tensor.extract_slice %t[%a, %b] [%c, %d] [1, 1]
+/// : tensor<?x?xf32> to tensor<?x?xf32>
+/// %1 = vector.transfer_read %0[%e, %f], %cst {in_bounds = [true, true]}
+/// : tensor<?x?xf32>, vector<4x5xf32>
+/// ```
+/// is rewritten to:
+/// ```
+/// %p0 = arith.addi %a, %e : index
+/// %p1 = arith.addi %b, %f : index
+/// %1 = vector.transfer_read %t[%p0, %p1], %cst {in_bounds = [true, true]}
+/// : tensor<?x?xf32>, vector<4x5xf32>
+/// ```
+// TODO: this is brittle and should be deprecated in favor of a more general
+// pattern that applies on-demand.
+class FoldExtractSliceIntoTransferRead final
+ : public OpRewritePattern<vector::TransferReadOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ FoldExtractSliceIntoTransferRead(MLIRContext *context,
+ std::function<bool(Operation *op)> controlFn,
+ PatternBenefit benefit)
+ : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}
+
+ LogicalResult matchAndRewrite(vector::TransferReadOp xferOp,
+ PatternRewriter &rewriter) const override {
+ if (controlFn && !controlFn(xferOp))
+ return failure();
+
+ // TODO: support 0-d corner case.
+ if (xferOp.getTransferRank() == 0)
+ return failure();
+ if (xferOp.hasOutOfBoundsDim())
+ return failure();
+ if (!xferOp.getPermutationMap().isMinorIdentity())
+ return failure();
+ if (xferOp.getMask())
+ return failure();
+ auto extractOp = xferOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
+ if (!extractOp)
+ return failure();
+ if (!extractOp.hasUnitStride())
+ return failure();
+
+ // Bail on illegal rank-reduction: we need to check that the rank-reduced
+ // dims are exactly the leading dims. I.e. the following is illegal:
+ // ```
+ // %0 = tensor.extract_slice %t[0,0,0][2,1,4][1,1,1] :
+ // tensor<2x1x4xf32> to tensor<2x4xf32>
+ // %1 = vector.transfer_read %0[0,0], %cst :
+ // tensor<2x4xf32>, vector<2x4xf32>
+ // ```
+ //
+ // Cannot fold into:
+ // ```
+ // %0 = vector.transfer_read %t[0,0,0], %cst :
+ // tensor<2x1x4xf32>, vector<2x4xf32>
+ // ```
+ // For this, check the trailing `vectorRank` dims of the extract_slice
+ // result tensor match the trailing dims of the inferred result tensor.
+ if (!areAllRankReducedLeadingDim(extractOp, extractOp.getType().getRank()))
+ return failure();
+
+ int64_t rankReduced =
+ extractOp.getSourceType().getRank() - extractOp.getType().getRank();
+
+ SmallVector<Value> newIndices;
+ // In case this is a rank-reducing ExtractSliceOp, copy rank-reduced
+ // indices first.
+ for (int64_t i = 0; i < rankReduced; ++i) {
+ OpFoldResult offset = extractOp.getMixedOffsets()[i];
+ newIndices.push_back(getValueOrCreateConstantIndexOp(
+ rewriter, extractOp.getLoc(), offset));
+ }
+ for (const auto &it : llvm::enumerate(xferOp.getIndices())) {
+ OpFoldResult offset =
+ extractOp.getMixedOffsets()[it.index() + rankReduced];
+ newIndices.push_back(rewriter.create<arith::AddIOp>(
+ xferOp->getLoc(), it.value(),
+ getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(),
+ offset)));
+ }
+ SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
+ rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
+ xferOp, xferOp.getVectorType(), extractOp.getSource(), newIndices,
+ xferOp.getPadding(), ArrayRef<bool>{inBounds});
+
+ return success();
+ }
+
+private:
+ std::function<bool(Operation *)> controlFn;
+};
+
+/// Fold tensor.insert_slice into vector.transfer_write if the transfer_write
+/// could directly write to the insert_slice's destination. E.g.:
+///
+/// ```
+/// %0 = vector.transfer_write %v, %t1[%c0, %c0] {in_bounds = [true, true]}
+/// : vector<4x5xf32>, tensor<4x5xf32>
+/// %1 = tensor.insert_slice %0 into %t2[%a, %b] [4, 5] [1, 1]
+/// : tensor<4x5xf32> into tensor<?x?xf32>
+/// ```
+/// is rewritten to:
+/// ```
+/// %1 = vector.transfer_write %v, %t2[%a, %b] {in_bounds = [true, true]}
+/// : vector<4x5xf32>, tensor<?x?xf32>
+/// ```
+// TODO: this is brittle and should be deprecated in favor of a more general
+// pattern that applies on-demand.
+class FoldInsertSliceIntoTransferWrite final
+ : public OpRewritePattern<tensor::InsertSliceOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ FoldInsertSliceIntoTransferWrite(MLIRContext *context,
+ std::function<bool(Operation *op)> controlFn,
+ PatternBenefit benefit)
+ : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}
+
+ LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
+ PatternRewriter &rewriter) const override {
+ if (!insertOp.hasUnitStride())
+ return failure();
+
+ auto xferOp = insertOp.getSource().getDefiningOp<vector::TransferWriteOp>();
+ if (!xferOp)
+ return failure();
+ if (controlFn && !controlFn(xferOp))
+ return failure();
+
+ // TODO: support 0-d corner case.
+ if (xferOp.getTransferRank() == 0)
+ return failure();
+
+ if (xferOp.hasOutOfBoundsDim())
+ return failure();
+ if (xferOp.getVectorType().getRank() != xferOp.getShapedType().getRank())
+ return failure();
+ if (xferOp.getMask())
+ return failure();
+ // Fold only if the TransferWriteOp completely overwrites the `source` with
+ // a vector. I.e., the result of the TransferWriteOp is a new tensor whose
+ // content is the data of the vector.
+ if (!llvm::equal(xferOp.getVectorType().getShape(),
+ xferOp.getShapedType().getShape()))
+ return failure();
+ if (!xferOp.getPermutationMap().isIdentity())
+ return failure();
+
+ // Bail on illegal rank-reduction: we need to check that the rank-reduced
+ // dims are exactly the leading dims. I.e. the following is illegal:
+ // ```
+ // %0 = vector.transfer_write %v, %t[0,0], %cst :
+ // vector<2x4xf32>, tensor<2x4xf32>
+ // %1 = tensor.insert_slice %0 into %tt[0,0,0][2,1,4][1,1,1] :
+ // tensor<2x4xf32> into tensor<2x1x4xf32>
+ // ```
+ //
+ // Cannot fold into:
+ // ```
+ // %0 = vector.transfer_write %v, %t[0,0,0], %cst :
+ // vector<2x4xf32>, tensor<2x1x4xf32>
+ // ```
+ // For this, check the trailing `vectorRank` dims of the insert_slice result
+ // tensor match the trailing dims of the inferred result tensor.
+ int64_t rankReduced =
+ insertOp.getType().getRank() - insertOp.getSourceType().getRank();
+ int64_t vectorRank = xferOp.getVectorType().getRank();
+ RankedTensorType inferredSourceTensorType =
+ tensor::ExtractSliceOp::inferResultType(
+ insertOp.getType(), insertOp.getMixedOffsets(),
+ insertOp.getMixedSizes(), insertOp.getMixedStrides());
+ auto actualSourceTensorShape = insertOp.getSourceType().getShape();
+ if (rankReduced > 0 &&
+ actualSourceTensorShape.take_back(vectorRank) !=
+ inferredSourceTensorType.getShape().take_back(vectorRank))
+ return failure();
+
+ SmallVector<Value> indices = getValueOrCreateConstantIndexOp(
+ rewriter, insertOp.getLoc(), insertOp.getMixedOffsets());
+ SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
+ rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+ insertOp, xferOp.getVector(), insertOp.getDest(), indices,
+ ArrayRef<bool>{inBounds});
+ return success();
+ }
+
+private:
+ std::function<bool(Operation *)> controlFn;
+};
+
+} // namespace
+
+void vector::populateVectorTransferTensorSliceTransforms(
+ RewritePatternSet &patterns,
+ std::function<bool(Operation *vectorOp)> controlFn,
+ PatternBenefit benefit) {
+ patterns
+ .add<FoldExtractSliceIntoTransferRead, FoldInsertSliceIntoTransferWrite>(
+ patterns.getContext(), controlFn, benefit);
+}
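
For reference, a hand-written sketch of how the two patterns compose (hypothetical IR,
not taken from the tests): a tile that is sliced out, read, fully overwritten, and
re-inserted collapses to transfers on the original tensor, modulo folding of the
zero-index additions and DCE of the now-dead slice and write ops.

```
// Before (hypothetical): slice, read, fully overwrite, insert back.
%slice = tensor.extract_slice %t[%i, %j] [4, 5] [1, 1]
    : tensor<?x?xf32> to tensor<4x5xf32>
%v = vector.transfer_read %slice[%c0, %c0], %cst {in_bounds = [true, true]}
    : tensor<4x5xf32>, vector<4x5xf32>
%w = vector.transfer_write %v, %slice[%c0, %c0] {in_bounds = [true, true]}
    : vector<4x5xf32>, tensor<4x5xf32>
%r = tensor.insert_slice %w into %t[%i, %j] [4, 5] [1, 1]
    : tensor<4x5xf32> into tensor<?x?xf32>

// After both folds (plus constant-index folding and DCE): no slice ops remain.
%v = vector.transfer_read %t[%i, %j], %cst {in_bounds = [true, true]}
    : tensor<?x?xf32>, vector<4x5xf32>
%r = vector.transfer_write %v, %t[%i, %j] {in_bounds = [true, true]}
    : vector<4x5xf32>, tensor<?x?xf32>
```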
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 88c91ff46a8b..4ce4350f0e4f 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1201,116 +1201,6 @@ func.func @dead_store_tensor_negative(%arg0 : tensor<4x4xf32>,
// -----
-// CHECK-LABEL: func @transfer_read_of_extract_slice(
-// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
-// CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
-// CHECK-DAG: %[[c8:.*]] = arith.constant 8 : index
-// CHECK: %[[add:.*]] = arith.addi %[[s1]], %[[c4]]
-// CHECK: %[[r:.*]] = vector.transfer_read %[[t]][%[[c8]], %[[add]]], %{{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<5x6xf32>
-// CHECK: return %[[r]]
-func.func @transfer_read_of_extract_slice(%t : tensor<?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
- %c3 = arith.constant 3 : index
- %c4 = arith.constant 4 : index
- %cst = arith.constant 0.0 : f32
- %0 = tensor.extract_slice %t[5, %s1] [10, %s2] [1, 1] : tensor<?x?xf32> to tensor<10x?xf32>
- %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor<10x?xf32>, vector<5x6xf32>
- return %1 : vector<5x6xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @transfer_read_of_extract_slice(
-// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
-// CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
-// CHECK-DAG: %[[c8:.*]] = arith.constant 8 : index
-// CHECK: %[[add:.*]] = arith.addi %[[s1]], %[[c4]]
-// CHECK: %[[r:.*]] = vector.transfer_read %[[t]][%[[c8]], %[[add]]], %{{.*}} {in_bounds = [true]} : tensor<?x?xf32>, vector<6xf32>
-// CHECK: return %[[r]]
-func.func @transfer_read_of_extract_slice(%t : tensor<?x?xf32>, %s1 : index, %s2 : index) -> vector<6xf32> {
- %c3 = arith.constant 3 : index
- %c4 = arith.constant 4 : index
- %cst = arith.constant 0.0 : f32
- %0 = tensor.extract_slice %t[5, %s1] [10, %s2] [1, 1] : tensor<?x?xf32> to tensor<10x?xf32>
- %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true]} : tensor<10x?xf32>, vector<6xf32>
- return %1 : vector<6xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @transfer_read_of_extract_slice_rank_reducing(
-// CHECK-SAME: %[[t:.*]]: tensor<?x?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
-// CHECK-DAG: %[[c3:.*]] = arith.constant 3 : index
-// CHECK-DAG: %[[c5:.*]] = arith.constant 5 : index
-// CHECK-DAG: %[[c10:.*]] = arith.constant 10 : index
-// CHECK: %[[add:.*]] = arith.addi %[[s1]], %[[c3]]
-// CHECK: %[[r:.*]] = vector.transfer_read %[[t]][%[[c5]], %[[add]], %[[c10]]], %{{.*}} {in_bounds = [true, true]} : tensor<?x?x?xf32>, vector<5x6xf32>
-// CHECK: return %[[r]]
-func.func @transfer_read_of_extract_slice_rank_reducing(%t : tensor<?x?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
- %c3 = arith.constant 3 : index
- %c4 = arith.constant 4 : index
- %cst = arith.constant 0.0 : f32
- %0 = tensor.extract_slice %t[5, %s1, 6] [1, %s2, 12] [1, 1, 1] : tensor<?x?x?xf32> to tensor<?x12xf32>
- %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor<?x12xf32>, vector<5x6xf32>
- return %1 : vector<5x6xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @transfer_read_of_extract_slice_illegal_rank_reducing(
-// CHECK: extract_slice
-// CHECK: vector.transfer_read
-func.func @transfer_read_of_extract_slice_illegal_rank_reducing(%t : tensor<?x?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
- %c3 = arith.constant 3 : index
- %c4 = arith.constant 4 : index
- %cst = arith.constant 0.0 : f32
- %0 = tensor.extract_slice %t[5, %s1, 6] [%s2, 1, 12] [1, 1, 1] : tensor<?x?x?xf32> to tensor<?x12xf32>
- %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor<?x12xf32>, vector<5x6xf32>
- return %1 : vector<5x6xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @insert_slice_of_transfer_write(
-// CHECK-SAME: %[[t1:.*]]: tensor<?x12xf32>, %[[v:.*]]: vector<5x6xf32>, %[[s:.*]]: index
-// CHECK: %[[c3:.*]] = arith.constant 3 : index
-// CHECK: %[[r:.*]] = vector.transfer_write %[[v]], %[[t1]][%[[c3]], %[[s]]] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<?x12xf32>
-// CHECK: return %[[r]]
-func.func @insert_slice_of_transfer_write(%t1 : tensor<?x12xf32>, %v : vector<5x6xf32>, %s : index, %t2 : tensor<5x6xf32>) -> tensor<?x12xf32> {
- %c0 = arith.constant 0 : index
- %0 = vector.transfer_write %v, %t2[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<5x6xf32>
- %1 = tensor.insert_slice %0 into %t1[3, %s] [5, 6] [1, 1] : tensor<5x6xf32> into tensor<?x12xf32>
- return %1 : tensor<?x12xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @insert_slice_of_transfer_write_illegal_rank_extending(
-// CHECK: vector.transfer_write
-// CHECK: insert_slice
-func.func @insert_slice_of_transfer_write_illegal_rank_extending(%t1 : tensor<?x?x12xf32>, %v : vector<5x6xf32>, %s : index, %t2 : tensor<5x6xf32>) -> tensor<?x?x12xf32> {
- %c0 = arith.constant 0 : index
- %0 = vector.transfer_write %v, %t2[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<5x6xf32>
- %1 = tensor.insert_slice %0 into %t1[4, 3, %s] [5, 1, 6] [1, 1, 1] : tensor<5x6xf32> into tensor<?x?x12xf32>
- return %1 : tensor<?x?x12xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @insert_slice_of_transfer_write_rank_extending(
-// CHECK-SAME: %[[t1:.*]]: tensor<?x?x12xf32>, %[[v:.*]]: vector<5x6xf32>, %[[s:.*]]: index
-// CHECK-DAG: %[[c3:.*]] = arith.constant 3 : index
-// CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
-// CHECK: %[[r:.*]] = vector.transfer_write %[[v]], %[[t1]][%[[c4]], %[[c3]], %[[s]]] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<?x?x12xf32>
-// CHECK: return %[[r]]
-func.func @insert_slice_of_transfer_write_rank_extending(%t1 : tensor<?x?x12xf32>, %v : vector<5x6xf32>, %s : index, %t2 : tensor<5x6xf32>) -> tensor<?x?x12xf32> {
- %c0 = arith.constant 0 : index
- %0 = vector.transfer_write %v, %t2[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<5x6xf32>
- %1 = tensor.insert_slice %0 into %t1[4, 3, %s] [1, 5, 6] [1, 1, 1] : tensor<5x6xf32> into tensor<?x?x12xf32>
- return %1 : tensor<?x?x12xf32>
-}
-
-// -----
-
// CHECK: #[[$MAP:[0-9a-z]+]] = affine_map<(d0, d1) -> (d1, d0)>
// CHECK-LABEL: func @swap_extract_slice_transfer_write
diff --git a/mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir b/mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir
new file mode 100644
index 000000000000..cc17025fe0f1
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir
@@ -0,0 +1,109 @@
+// RUN: mlir-opt -split-input-file -test-vector-transfer-tensor-slice-patterns %s | FileCheck %s
+
+// CHECK-LABEL: func @transfer_read_of_extract_slice(
+// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
+// CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[c8:.*]] = arith.constant 8 : index
+// CHECK: %[[add:.*]] = arith.addi %[[s1]], %[[c4]]
+// CHECK: %[[r:.*]] = vector.transfer_read %[[t]][%[[c8]], %[[add]]], %{{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<5x6xf32>
+// CHECK: return %[[r]]
+func.func @transfer_read_of_extract_slice(%t : tensor<?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
+ %c3 = arith.constant 3 : index
+ %c4 = arith.constant 4 : index
+ %cst = arith.constant 0.0 : f32
+ %0 = tensor.extract_slice %t[5, %s1] [10, %s2] [1, 1] : tensor<?x?xf32> to tensor<10x?xf32>
+ %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor<10x?xf32>, vector<5x6xf32>
+ return %1 : vector<5x6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @transfer_read_of_extract_slice(
+// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
+// CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[c8:.*]] = arith.constant 8 : index
+// CHECK: %[[add:.*]] = arith.addi %[[s1]], %[[c4]]
+// CHECK: %[[r:.*]] = vector.transfer_read %[[t]][%[[c8]], %[[add]]], %{{.*}} {in_bounds = [true]} : tensor<?x?xf32>, vector<6xf32>
+// CHECK: return %[[r]]
+func.func @transfer_read_of_extract_slice(%t : tensor<?x?xf32>, %s1 : index, %s2 : index) -> vector<6xf32> {
+ %c3 = arith.constant 3 : index
+ %c4 = arith.constant 4 : index
+ %cst = arith.constant 0.0 : f32
+ %0 = tensor.extract_slice %t[5, %s1] [10, %s2] [1, 1] : tensor<?x?xf32> to tensor<10x?xf32>
+ %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true]} : tensor<10x?xf32>, vector<6xf32>
+ return %1 : vector<6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @transfer_read_of_extract_slice_rank_reducing(
+// CHECK-SAME: %[[t:.*]]: tensor<?x?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
+// CHECK-DAG: %[[c3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[c5:.*]] = arith.constant 5 : index
+// CHECK-DAG: %[[c10:.*]] = arith.constant 10 : index
+// CHECK: %[[add:.*]] = arith.addi %[[s1]], %[[c3]]
+// CHECK: %[[r:.*]] = vector.transfer_read %[[t]][%[[c5]], %[[add]], %[[c10]]], %{{.*}} {in_bounds = [true, true]} : tensor<?x?x?xf32>, vector<5x6xf32>
+// CHECK: return %[[r]]
+func.func @transfer_read_of_extract_slice_rank_reducing(%t : tensor<?x?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
+ %c3 = arith.constant 3 : index
+ %c4 = arith.constant 4 : index
+ %cst = arith.constant 0.0 : f32
+ %0 = tensor.extract_slice %t[5, %s1, 6] [1, %s2, 12] [1, 1, 1] : tensor<?x?x?xf32> to tensor<?x12xf32>
+ %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor<?x12xf32>, vector<5x6xf32>
+ return %1 : vector<5x6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @transfer_read_of_extract_slice_illegal_rank_reducing(
+// CHECK: extract_slice
+// CHECK: vector.transfer_read
+func.func @transfer_read_of_extract_slice_illegal_rank_reducing(%t : tensor<?x?x?xf32>, %s1 : index, %s2 : index) -> vector<5x6xf32> {
+ %c3 = arith.constant 3 : index
+ %c4 = arith.constant 4 : index
+ %cst = arith.constant 0.0 : f32
+ %0 = tensor.extract_slice %t[5, %s1, 6] [%s2, 1, 12] [1, 1, 1] : tensor<?x?x?xf32> to tensor<?x12xf32>
+ %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor<?x12xf32>, vector<5x6xf32>
+ return %1 : vector<5x6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @insert_slice_of_transfer_write(
+// CHECK-SAME: %[[t1:.*]]: tensor<?x12xf32>, %[[v:.*]]: vector<5x6xf32>, %[[s:.*]]: index
+// CHECK: %[[c3:.*]] = arith.constant 3 : index
+// CHECK: %[[r:.*]] = vector.transfer_write %[[v]], %[[t1]][%[[c3]], %[[s]]] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<?x12xf32>
+// CHECK: return %[[r]]
+func.func @insert_slice_of_transfer_write(%t1 : tensor<?x12xf32>, %v : vector<5x6xf32>, %s : index, %t2 : tensor<5x6xf32>) -> tensor<?x12xf32> {
+ %c0 = arith.constant 0 : index
+ %0 = vector.transfer_write %v, %t2[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<5x6xf32>
+ %1 = tensor.insert_slice %0 into %t1[3, %s] [5, 6] [1, 1] : tensor<5x6xf32> into tensor<?x12xf32>
+ return %1 : tensor<?x12xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @insert_slice_of_transfer_write_illegal_rank_extending(
+// CHECK: vector.transfer_write
+// CHECK: insert_slice
+func.func @insert_slice_of_transfer_write_illegal_rank_extending(%t1 : tensor<?x?x12xf32>, %v : vector<5x6xf32>, %s : index, %t2 : tensor<5x6xf32>) -> tensor<?x?x12xf32> {
+ %c0 = arith.constant 0 : index
+ %0 = vector.transfer_write %v, %t2[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<5x6xf32>
+ %1 = tensor.insert_slice %0 into %t1[4, 3, %s] [5, 1, 6] [1, 1, 1] : tensor<5x6xf32> into tensor<?x?x12xf32>
+ return %1 : tensor<?x?x12xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @insert_slice_of_transfer_write_rank_extending(
+// CHECK-SAME: %[[t1:.*]]: tensor<?x?x12xf32>, %[[v:.*]]: vector<5x6xf32>, %[[s:.*]]: index
+// CHECK-DAG: %[[c3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
+// CHECK: %[[r:.*]] = vector.transfer_write %[[v]], %[[t1]][%[[c4]], %[[c3]], %[[s]]] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<?x?x12xf32>
+// CHECK: return %[[r]]
+func.func @insert_slice_of_transfer_write_rank_extending(%t1 : tensor<?x?x12xf32>, %v : vector<5x6xf32>, %s : index, %t2 : tensor<5x6xf32>) -> tensor<?x?x12xf32> {
+ %c0 = arith.constant 0 : index
+ %0 = vector.transfer_write %v, %t2[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<5x6xf32>
+ %1 = tensor.insert_slice %0 into %t1[4, 3, %s] [1, 5, 6] [1, 1, 1] : tensor<5x6xf32> into tensor<?x?x12xf32>
+ return %1 : tensor<?x?x12xf32>
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index d0c79ab98915..50dfeff635cc 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -679,6 +679,26 @@ struct TestVectorGatherLowering
}
};
+struct TestVectorTransferTensorSlicePatterns
+ : public PassWrapper<TestVectorTransferTensorSlicePatterns,
+ OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+ TestVectorTransferTensorSlicePatterns)
+
+ StringRef getArgument() const final {
+ return "test-vector-transfer-tensor-slice-patterns";
+ }
+ StringRef getDescription() const final {
+ return "Test patterns that fold vector transfer and tensor slice ops";
+ }
+
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ populateVectorTransferTensorSliceTransforms(patterns);
+ (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+ }
+};
+
} // namespace
namespace mlir {
@@ -713,6 +733,8 @@ void registerTestVectorLowerings() {
PassRegistration<TestCreateVectorBroadcast>();
PassRegistration<TestVectorGatherLowering>();
+
+ PassRegistration<TestVectorTransferTensorSlicePatterns>();
}
} // namespace test
} // namespace mlir