summaryrefslogtreecommitdiff
path: root/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h')
-rw-r--r--mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h11
1 files changed, 11 insertions, 0 deletions
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 325860079b3d..2912c0252872 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -218,6 +218,17 @@ void populateBreakDownVectorBitCastOpPatterns(
void populateVectorInsertExtractStridedSliceTransforms(
RewritePatternSet &patterns, PatternBenefit benefit = 1);
+/// Collect patterns to fold tensor.extract_slice -> vector.transfer_read and
+/// vector.transfer_write -> tensor.insert_slice op chains into vector tranfer
+/// read and write ops.
+///
+/// If `controlFn` is not nullptr, the pattern will only apply to ops where
+/// `controlFn` returns true, given the vector transfer read/write op as input.
+void populateVectorTransferTensorSliceTransforms(
+ RewritePatternSet &patterns,
+ std::function<bool(Operation *vectorOp)> controlFn = nullptr,
+ PatternBenefit benefit = 1);
+
/// Collect a set of pattern to unroll vector operations to a smaller shapes.
/// `options` structure controls which operations are unrolled and the target
/// shape.