diff options
author | MaheshRavishankar <ravishankarm@google.com> | 2021-03-29 09:18:43 -0700 |
---|---|---|
committer | MaheshRavishankar <ravishankarm@google.com> | 2021-03-29 09:19:36 -0700 |
commit | f0a2fe7f79d79c757fca5bd1498a014f2f98bb72 (patch) | |
tree | 8ab3c9472b15409fa9520d86ee6fff2806268006 | |
parent | e8515ca8478f96f7d2eddadc4d310ac29bb04abe (diff) | |
download | llvm-f0a2fe7f79d79c757fca5bd1498a014f2f98bb72.tar.gz |
[mlir][Linalg] Rewrite SubTensors that take a slice out of a unit-extend dimension.
Subtensor operations that are taking a slice out of a tensor that is
unit-extent along a dimension can be rewritten to drop that dimension.
Differential Revision: https://reviews.llvm.org/D99226
-rw-r--r-- | mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp | 89 | ||||
-rw-r--r-- | mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir | 82 |
2 files changed, 152 insertions, 19 deletions
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp index 2d3e16fab960..d5f08056d551 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -171,8 +171,6 @@ LogicalResult replaceBlockArgForUnitDimLoops<IndexedGenericOp>( namespace { /// Pattern to fold unit-trip count loops in GenericOps. -// TODO: Generalize this to indexed-generic as well by modifying the region args -// as well. template <typename GenericOpTy> struct FoldUnitDimLoops : public OpRewritePattern<GenericOpTy> { using OpRewritePattern<GenericOpTy>::OpRewritePattern; @@ -375,9 +373,7 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOpTy> { return success(); } }; -} // namespace -namespace { /// Pattern to fold pair of reshape ops where the intermediate has unit-dims for /// example: /// @@ -428,12 +424,12 @@ struct FoldReshapeOpWithUnitExtent : OpRewritePattern<TensorReshapeOp> { parentSrcType.getRank() == dstType.getRank()) return failure(); - // Check if the result tensor_reshape after folding the reshapeOp and - // parentReshapeOp are combined. - // If the final tensor_reshape is folding, the parentReshapeOp is - // introducing unit-dims, and the reshapeOp does an actual reshape. - // If the final tensor_reshape op is expanding, the reshapeOp is - // introducing unit-dims, and the parentReshapeOp does an actual reshape. + // Check if the result tensor_reshape is folding or expanding after folding + // the reshapeOp and parentReshapeOp are combined. If the final + // tensor_reshape is folding, the parentReshapeOp is introducing unit-dims, + // and the reshapeOp does an actual reshape. If the final tensor_reshape op + // is expanding, the reshapeOp is introducing unit-dims, and the + // parentReshapeOp does an actual reshape. bool isFoldingPattern = parentSrcType.getRank() > dstType.getRank(); ArrayRef<int64_t> expandedShape = isFoldingPattern ? parentSrcType.getShape() : dstType.getShape(); @@ -485,6 +481,77 @@ struct FoldReshapeOpWithUnitExtent : OpRewritePattern<TensorReshapeOp> { return success(); } }; + +/// Pattern to fold subtensors that are just taking a slice of unit-dimension +/// tensor. For example +/// +/// %1 = subtensor %0[0, %o1, 0] [1, %s1, 1] [1, 1, 1] +/// : tensor<1x?x1xf32> to tensor<1x?x1xf32> +/// +/// can be replaced with +/// +/// %0 = linalg.tensor_reshape %0 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>] +/// : tensor<1x?x1xf32> into tensor<?xf32> +/// %1 = subtensor %0[%o1] [%s1] [1] : tensor<?xf32> to tensor<?xf32> +/// %2 = linalg.tensor_reshape %1 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>] +/// : tensor<?xf32> into tensor<1x?x1xf32> +/// +/// The additional tensor_reshapes will hopefully get canonicalized away with +/// other reshapes that drop unit dimensions. Three condiitions to fold a +/// dimension +/// - The offset must be 0 +/// - The size must be 1 +/// - The dimension of the source type must be 1. +struct FoldUnitDimSubTensorOp : public OpRewritePattern<SubTensorOp> { + using OpRewritePattern<SubTensorOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(SubTensorOp subTensorOp, + PatternRewriter &rewriter) const override { + SmallVector<OpFoldResult> mixedOffsets = subTensorOp.getMixedOffsets(); + SmallVector<OpFoldResult> mixedSizes = subTensorOp.getMixedSizes(); + SmallVector<OpFoldResult> mixedStrides = subTensorOp.getMixedStrides(); + auto hasValue = [](OpFoldResult valueOrAttr, int64_t val) { + auto attr = valueOrAttr.dyn_cast<Attribute>(); + return attr && attr.cast<IntegerAttr>().getInt() == val; + }; + + if (llvm::any_of(mixedStrides, [&](OpFoldResult valueOrAttr) { + return !hasValue(valueOrAttr, 1); + })) + return failure(); + + // Find the expanded unit dimensions. + SmallVector<ReassociationIndices> reassociation; + SmallVector<OpFoldResult> newOffsets, newSizes; + ArrayRef<int64_t> sourceShape = subTensorOp.getSourceType().getShape(); + ReassociationIndices curr; + for (int64_t dim : llvm::seq<int64_t>(0, mixedOffsets.size())) { + curr.push_back(dim); + if (sourceShape[dim] == 1 && hasValue(mixedOffsets[dim], 0) && + hasValue(mixedSizes[dim], 1)) { + continue; + } + newOffsets.push_back(mixedOffsets[dim]); + newSizes.push_back(mixedSizes[dim]); + reassociation.emplace_back(ReassociationIndices{}); + std::swap(reassociation.back(), curr); + } + if (newOffsets.size() == mixedOffsets.size()) + return failure(); + reassociation.back().append(curr.begin(), curr.end()); + SmallVector<OpFoldResult> newStrides(newOffsets.size(), + rewriter.getI64IntegerAttr(1)); + Location loc = subTensorOp->getLoc(); + auto srcReshape = rewriter.create<TensorReshapeOp>( + loc, subTensorOp.source(), reassociation); + auto newSubTensorOp = rewriter.create<SubTensorOp>( + loc, srcReshape, newOffsets, newSizes, newStrides); + rewriter.replaceOpWithNewOp<TensorReshapeOp>( + subTensorOp, subTensorOp.getType(), newSubTensorOp, reassociation); + return success(); + } +}; + } // namespace /// Patterns that are used to canonicalize the use of unit-extent dims for @@ -493,7 +560,7 @@ void mlir::populateLinalgFoldUnitExtentDimsPatterns( RewritePatternSet &patterns) { auto *context = patterns.getContext(); patterns.add<FoldUnitDimLoops<GenericOp>, FoldUnitDimLoops<IndexedGenericOp>, - ReplaceUnitExtentTensors<GenericOp>, + FoldUnitDimSubTensorOp, ReplaceUnitExtentTensors<GenericOp>, ReplaceUnitExtentTensors<IndexedGenericOp>>(context); TensorReshapeOp::getCanonicalizationPatterns(patterns, context); patterns.add<FoldReshapeOpWithUnitExtent>(context); diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir index cb5d1089eb85..2a6711018988 100644 --- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir +++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir @@ -55,12 +55,12 @@ func @drop_one_trip_loops_indexed_generic outs(%shape: tensor<?x1x?x1x?xi32>) { ^bb0(%arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : i32, %arg7 : i32) : - %1 = addi %arg1, %arg2 : index - %2 = addi %1, %arg3 : index - %3 = addi %2, %arg4 : index - %4 = addi %3, %arg5 : index - %5 = index_cast %4 : index to i32 - %6 = addi %5, %arg6 : i32 + %1 = addi %arg1, %arg2 : index + %2 = addi %1, %arg3 : index + %3 = addi %2, %arg4 : index + %4 = addi %3, %arg5 : index + %5 = index_cast %4 : index to i32 + %6 = addi %5, %arg6 : i32 linalg.yield %6 : i32 } -> tensor<?x1x?x1x?xi32> return %0 : tensor<?x1x?x1x?xi32> @@ -120,8 +120,8 @@ func @drop_all_loops_indexed_generic outs(%arg0 : tensor<1x1xi32>) { ^bb0(%arg1 : index, %arg2 : index, %arg3: i32, %arg4: i32) : %1 = addi %arg1, %arg2 : index - %2 = index_cast %1 : index to i32 - %3 = addi %2, %arg3 : i32 + %2 = index_cast %1 : index to i32 + %3 = addi %2, %arg3 : i32 linalg.yield %3 : i32 } -> tensor<1x1xi32> return %0 : tensor<1x1xi32> @@ -390,3 +390,69 @@ func @fold_unit_dim_for_init_tensor(%input: tensor<1x1000xf32>) -> tensor<1xf32> // CHECK-SAME: outs(%[[FILL]] : tensor<f32>) // CHECK: %[[GENERIC_RESHAPE:.+]] = linalg.tensor_reshape %[[GENERIC]] [] : tensor<f32> into tensor<1xf32> // CHECK: return %[[GENERIC_RESHAPE:.+]] : tensor<1xf32> + + +// ----- + +func @fold_subtensor( + %arg0 : tensor<1x?x?x1x?x1x1xf32>, %arg1 : index, %arg2 : index, + %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index) + -> tensor<1x?x?x1x?x1x1xf32> { + %0 = subtensor %arg0[0, %arg1, %arg2, 0, %arg3, 0, 0] + [1, %arg4, %arg5, 1, %arg6, 1, 1] [1, 1, 1, 1, 1, 1, 1] : + tensor<1x?x?x1x?x1x1xf32> to tensor<1x?x?x1x?x1x1xf32> + return %0 : tensor<1x?x?x1x?x1x1xf32> +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)> +// CHECK: func @fold_subtensor +// CHECK-SAME: %[[ARG0:.+]]: tensor<1x?x?x1x?x1x1xf32> +// CHECK-SAME: %[[ARG1:[a-z0-9]+]]: index +// CHECK-SAME: %[[ARG2:[a-z0-9]+]]: index +// CHECK-SAME: %[[ARG3:[a-z0-9]+]]: index +// CHECK-SAME: %[[ARG4:[a-z0-9]+]]: index +// CHECK-SAME: %[[ARG5:[a-z0-9]+]]: index +// CHECK-SAME: %[[ARG6:[a-z0-9]+]]: index +// CHECK: %[[SRC_RESHAPE:.+]] = linalg.tensor_reshape %[[ARG0]] +// CHECK-SAME: [#[[MAP0]], #[[MAP1]], #[[MAP2]]] +// CHECK: %[[SUBTENSOR:.+]] = subtensor %[[SRC_RESHAPE]] +// CHECK-SAME: [%[[ARG1]], %[[ARG2]], %[[ARG3]]] +// CHECK-SAME: [%[[ARG4]], %[[ARG5]], %[[ARG6]]] +// CHECK: %[[RESULT_RESHAPE:.+]] = linalg.tensor_reshape %[[SUBTENSOR]] +// CHECK-SAME: [#[[MAP0]], #[[MAP1]], #[[MAP2]]] +// CHECK: return %[[RESULT_RESHAPE]] + +// ----- + +func @no_fold_subtensor( + %arg0 : tensor<1x?x?x?x?x1x1xf32>, %arg1 : index, %arg2 : index, + %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index) + -> tensor<1x?x?x1x?x1x1xf32> { + %0 = subtensor %arg0[%arg1, 0, %arg2, 0, 0, %arg3, 0] + [1, %arg4, %arg5, 1, %arg6, 1, 1] [1, 1, 1, 1, 1, 1, 1] : + tensor<1x?x?x?x?x1x1xf32> to tensor<1x?x?x1x?x1x1xf32> + return %0 : tensor<1x?x?x1x?x1x1xf32> +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2)> +// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3)> +// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4)> +// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6)> +// CHECK: func @no_fold_subtensor +// CHECK-SAME: %[[ARG0:.+]]: tensor<1x?x?x?x?x1x1xf32> +// CHECK-SAME: %[[ARG1:[a-z0-9]+]]: index +// CHECK-SAME: %[[ARG2:[a-z0-9]+]]: index +// CHECK-SAME: %[[ARG3:[a-z0-9]+]]: index +// CHECK-SAME: %[[ARG4:[a-z0-9]+]]: index +// CHECK-SAME: %[[ARG5:[a-z0-9]+]]: index +// CHECK-SAME: %[[ARG6:[a-z0-9]+]]: index +// CHECK: %[[SRC_RESHAPE:.+]] = linalg.tensor_reshape %[[ARG0]] +// CHECK-SAME: [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]], #[[MAP4]], #[[MAP5]]] +// CHECK: %[[SUBTENSOR:.+]] = subtensor %[[SRC_RESHAPE]] +// CHECK-SAME: [%[[ARG1]], 0, %[[ARG2]], 0, 0, %[[ARG3]]] +// CHECK-SAME: [1, %[[ARG4]], %[[ARG5]], 1, %[[ARG6]], 1] +// CHECK: %[[RESULT_RESHAPE:.+]] = linalg.tensor_reshape %[[SUBTENSOR]] +// CHECK-SAME: [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]], #[[MAP4]], #[[MAP5]]] +// CHECK: return %[[RESULT_RESHAPE]] |