diff options
author | Matthias Springer <springerm@google.com> | 2022-12-07 16:22:07 +0100 |
---|---|---|
committer | Matthias Springer <springerm@google.com> | 2022-12-07 16:25:10 +0100 |
commit | 9cdf6b641da1a7ba0145b224460c64efd65017e0 (patch) | |
tree | ca1d20c2f9c62636e56612b930b47de8244d0b41 /mlir/lib/Dialect | |
parent | cb3ea52a5ae5178b8cd257bd61c6e05d9a186b4d (diff) | |
download | llvm-9cdf6b641da1a7ba0145b224460c64efd65017e0.tar.gz |
[mlir][tensor] Support parallel_insert_slice in reassociative reshape folder
Differential Revision: https://reviews.llvm.org/D139540
Diffstat (limited to 'mlir/lib/Dialect')
-rw-r--r-- | mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp | 23 |
1 files changed, 12 insertions, 11 deletions
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp index b655df3c2cc4..d40e5f33d2a7 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp @@ -11,8 +11,6 @@ #include "mlir/IR/PatternMatch.h" #include "llvm/Support/Debug.h" -#define DEBUG_TYPE "mlir-tensor-split-padding" - using namespace mlir; using namespace mlir::tensor; @@ -51,13 +49,14 @@ struct FoldExpandOfRankReducingExtract }; /// Fold insert_slice(collapse_shape) ops that cancel itself out. -struct FoldInsertOfRankReducingInsert : public OpRewritePattern<InsertSliceOp> { - using OpRewritePattern<InsertSliceOp>::OpRewritePattern; +template <typename OpTy> +struct FoldInsertOfRankReducingInsert : public OpRewritePattern<OpTy> { + using OpRewritePattern<OpTy>::OpRewritePattern; - LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp, + LogicalResult matchAndRewrite(OpTy insertSliceOp, PatternRewriter &rewriter) const override { auto collapseShapeOp = - insertSliceOp.getSource().getDefiningOp<CollapseShapeOp>(); + insertSliceOp.getSource().template getDefiningOp<CollapseShapeOp>(); if (!collapseShapeOp) return failure(); RankedTensorType srcType = collapseShapeOp.getSrcType(); @@ -67,16 +66,16 @@ struct FoldInsertOfRankReducingInsert : public OpRewritePattern<InsertSliceOp> { // has no rank-reduction anymore are supported at the moment. RankedTensorType nonReducingInsertType = RankedTensorType::get(insertSliceOp.getStaticSizes(), - insertSliceOp.getType().getElementType()); + insertSliceOp.getDestType().getElementType()); if (nonReducingInsertType != srcType) return failure(); SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets(); SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes(); SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides(); - rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>( - insertSliceOp, collapseShapeOp.getSrc(), insertSliceOp.getDest(), - mixedOffsets, mixedSizes, mixedStrides); + rewriter.replaceOpWithNewOp<OpTy>(insertSliceOp, collapseShapeOp.getSrc(), + insertSliceOp.getDest(), mixedOffsets, + mixedSizes, mixedStrides); return success(); } }; @@ -84,6 +83,8 @@ struct FoldInsertOfRankReducingInsert : public OpRewritePattern<InsertSliceOp> { void mlir::tensor::populateReassociativeReshapeFoldingPatterns( RewritePatternSet &patterns) { - patterns.add<FoldExpandOfRankReducingExtract, FoldInsertOfRankReducingInsert>( + patterns.add<FoldExpandOfRankReducingExtract, + FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>, + FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>>( patterns.getContext()); } |