summaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect
diff options
context:
space:
mode:
authorMatthias Springer <springerm@google.com>2022-12-07 16:22:07 +0100
committerMatthias Springer <springerm@google.com>2022-12-07 16:25:10 +0100
commit9cdf6b641da1a7ba0145b224460c64efd65017e0 (patch)
treeca1d20c2f9c62636e56612b930b47de8244d0b41 /mlir/lib/Dialect
parentcb3ea52a5ae5178b8cd257bd61c6e05d9a186b4d (diff)
downloadllvm-9cdf6b641da1a7ba0145b224460c64efd65017e0.tar.gz
[mlir][tensor] Support parallel_insert_slice in reassociative reshape folder
Differential Revision: https://reviews.llvm.org/D139540
Diffstat (limited to 'mlir/lib/Dialect')
-rw-r--r--mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp23
1 files changed, 12 insertions, 11 deletions
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index b655df3c2cc4..d40e5f33d2a7 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -11,8 +11,6 @@
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"
-#define DEBUG_TYPE "mlir-tensor-split-padding"
-
using namespace mlir;
using namespace mlir::tensor;
@@ -51,13 +49,14 @@ struct FoldExpandOfRankReducingExtract
};
/// Fold insert_slice(collapse_shape) ops that cancel itself out.
-struct FoldInsertOfRankReducingInsert : public OpRewritePattern<InsertSliceOp> {
- using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
+template <typename OpTy>
+struct FoldInsertOfRankReducingInsert : public OpRewritePattern<OpTy> {
+ using OpRewritePattern<OpTy>::OpRewritePattern;
- LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
+ LogicalResult matchAndRewrite(OpTy insertSliceOp,
PatternRewriter &rewriter) const override {
auto collapseShapeOp =
- insertSliceOp.getSource().getDefiningOp<CollapseShapeOp>();
+ insertSliceOp.getSource().template getDefiningOp<CollapseShapeOp>();
if (!collapseShapeOp)
return failure();
RankedTensorType srcType = collapseShapeOp.getSrcType();
@@ -67,16 +66,16 @@ struct FoldInsertOfRankReducingInsert : public OpRewritePattern<InsertSliceOp> {
// has no rank-reduction anymore are supported at the moment.
RankedTensorType nonReducingInsertType =
RankedTensorType::get(insertSliceOp.getStaticSizes(),
- insertSliceOp.getType().getElementType());
+ insertSliceOp.getDestType().getElementType());
if (nonReducingInsertType != srcType)
return failure();
SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
- rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
- insertSliceOp, collapseShapeOp.getSrc(), insertSliceOp.getDest(),
- mixedOffsets, mixedSizes, mixedStrides);
+ rewriter.replaceOpWithNewOp<OpTy>(insertSliceOp, collapseShapeOp.getSrc(),
+ insertSliceOp.getDest(), mixedOffsets,
+ mixedSizes, mixedStrides);
return success();
}
};
@@ -84,6 +83,8 @@ struct FoldInsertOfRankReducingInsert : public OpRewritePattern<InsertSliceOp> {
void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
RewritePatternSet &patterns) {
- patterns.add<FoldExpandOfRankReducingExtract, FoldInsertOfRankReducingInsert>(
+ patterns.add<FoldExpandOfRankReducingExtract,
+ FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>,
+ FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>>(
patterns.getContext());
}