diff options
author | Tobias Gysi <gysit@google.com> | 2021-09-27 19:20:56 +0000 |
---|---|---|
committer | Tobias Gysi <gysit@google.com> | 2021-09-27 19:21:37 +0000 |
commit | d20d0e145d2fce3a39bd6a76df47455134167cdd (patch) | |
tree | a53dd0d927717df7d230b27f4057834ccd52256d | |
parent | 2a7a768dad3a77571fae8506d84078fe4ce3d105 (diff) | |
download | llvm-d20d0e145d2fce3a39bd6a76df47455134167cdd.tar.gz |
[mlir][linalg] Finer-grained padding control.
Adapt the signature of the PaddingValueComputationFunction callback to either return the padding value or failure to signal that padding is not desired.
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D110572
4 files changed, 42 insertions, 30 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 8530b89d5ea9..03843bdfaa01 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -438,10 +438,11 @@ private: using TileSizeComputationFunction = std::function<SmallVector<Value, 4>(OpBuilder &, Operation *)>; -/// Specify the padding value for an OpOperand. This should be a function of -/// both the operation and the operand type. +/// Callback returning the padding value to use for a given OpOperand or failure +/// for no padding. This should be a function of both the operation and the +/// operand type. using PaddingValueComputationFunction = - std::function<Value(OpBuilder &, OpOperand &)>; + std::function<FailureOr<Value>(OpBuilder &, OpOperand &)>; struct LinalgTilingOptions { /// Computation function that returns the tile sizes for each operation. @@ -504,10 +505,11 @@ struct LinalgTilingOptions { return *this; } - /// Computation function that returns a padding value to use when padding to - /// force static sizes. When `paddingValueComputationFunction` is set, padding - /// operations are introduced, that guarantee the underlying op is statically - /// shaped and can thus be vectorized. + /// Callback returning the padding value to use for a given OpOperand or + /// failure for no padding. Padding operations are introduced if + /// `paddingValueComputationFunction` is set and does not return failure. + /// Padding all operands guarantees the operation is statically shaped and + /// thus can be vectorized. 
PaddingValueComputationFunction paddingValueComputationFunction = nullptr; LinalgTilingOptions & diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index cef9e5a030ae..1d28451ae05e 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -145,16 +145,21 @@ LinalgTilingOptions &mlir::linalg::LinalgTilingOptions::scalarizeDynamicDims() { return *this; } -/// Try to compute a static bounding box for `operand`. The padding happens -/// even if the operand already has static shape. `result` is the result of a -/// freshly created PadTensorOp. Return failure if the operand cannot be padded -/// to a static shape. +/// Helper function that tries to pad `opOperand`. Exit early and return success +/// for scalar operands or if `paddingFunc` returns failure. Otherwise, try to +/// pad the operand even if it already has a static shape. Set `result` to the +/// result of the created PadTensorOp or return failure if the operand cannot be +/// padded to a static shape. static LogicalResult padOperandToSmallestStaticBoundingBox( PatternRewriter &rewriter, linalg::LinalgOp opToPad, OpOperand *opOperand, const PaddingValueComputationFunction &paddingFunc, Value &result) { // Can't pad scalars. if (opToPad.getShape(opOperand).empty()) return success(); + // Can't pad if no padding value is known. + FailureOr<Value> paddingValue = paddingFunc(rewriter, *opOperand); + if (failed(paddingValue)) + return success(); auto sliceOp = opOperand->get().getDefiningOp<tensor::ExtractSliceOp>(); // Not a slice op, cannot construct a static bounding box. 
if (!sliceOp) @@ -173,12 +178,11 @@ static LogicalResult padOperandToSmallestStaticBoundingBox( opToPad, "No constant bounding box can be found for padding"); staticSizes.push_back(indexAttr.getInt()); } - Value pad = paddingFunc(rewriter, *opOperand); auto staticTensorType = RankedTensorType::get( staticSizes, getElementTypeOrSelf(opOperand->get())); result = linalg::PadTensorOp::createPadHighOp( - staticTensorType, opOperand->get(), pad, /*packing=*/true, - opToPad->getLoc(), rewriter); + staticTensorType, opOperand->get(), paddingValue.getValue(), + /*packing=*/true, opToPad->getLoc(), rewriter); return success(); } diff --git a/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir index fb283814a466..7f30762d48dc 100644 --- a/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir +++ b/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir @@ -1,5 +1,5 @@ -// RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-pattern pad-tiles tile-sizes=2,3,4" -canonicalize | FileCheck %s -// RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-pattern pad-tiles tile-sizes=2,3" -canonicalize | FileCheck %s -check-prefix=CHECK-1DIM-TILE +// RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-pattern padded-operands=0,1,2 tile-sizes=2,3,4" -canonicalize | FileCheck %s +// RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-pattern padded-operands=0,1 tile-sizes=2,3" -canonicalize | FileCheck %s -check-prefix=CHECK-1DIM-TILE // CHECK-LABEL: func @matmul_tensors( // CHECK-SAME: %[[TA:[0-9a-z]+]]: tensor<?x?xi8> @@ -97,6 +97,7 @@ func @matmul_partially_padded_tensors( // CHECK: linalg.matmul_i8_i8_i32 ins({{.*}}, {{.*}} : tensor<2x4xi8>, tensor<4x3xi8>) outs({{.*}} : tensor<2x3xi32>) -> tensor<2x3xi32> +// Check only the input operands are padded. 
// CHECK-1DIM-TILE: func @matmul_partially_padded_tensors( // CHECK-1DIM-TILE-SAME: %[[TA:[0-9a-z]+]]: tensor<?x8xi8> // CHECK-1DIM-TILE-SAME: %[[TB:[0-9a-z]+]]: tensor<8x?xi8> @@ -111,10 +112,8 @@ func @matmul_partially_padded_tensors( // CHECK-1DIM-TILE: : tensor<?x8xi8> to tensor<2x8xi8> // CHECK-1DIM-TILE: %[[pB:.*]] = linalg.pad_tensor %[[sTB]] packing low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}] // CHECK-1DIM-TILE: : tensor<8x?xi8> to tensor<8x3xi8> -// CHECK-1DIM-TILE: %[[pC:.*]] = linalg.pad_tensor %[[sTC]] packing low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}] -// CHECK-1DIM-TILE: : tensor<?x?xi32> to tensor<2x3xi32> -// CHECK-1DIM-TILE: %[[pD:.*]] = linalg.matmul_i8_i8_i32 ins(%[[pA]], %[[pB]] : tensor<2x8xi8>, tensor<8x3xi8>) -// CHECK-1DIM-TILE: outs(%[[pC]] : tensor<2x3xi32>) -> tensor<2x3xi32> +// CHECK-1DIM-TILE: %[[pD:.*]] = linalg.matmul_i8_i8_i32 ins(%[[pA]], %[[pB]] : tensor<2x8xi8>, tensor<8x3xi8>) +// CHECK-1DIM-TILE: outs(%[[sTC]] : tensor<?x?xi32>) -> tensor<?x?xi32> // Check that the tile-and-pad transformation actually introduces the padding // as requested, even if original operation already operates on static diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp index 4b4b6acbc917..74ff41c052de 100644 --- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp @@ -108,9 +108,10 @@ struct TestLinalgTransforms llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into " "pad_tensor(subtensor)"), llvm::cl::init(false)}; - Option<bool> padTiles{*this, "pad-tiles", - llvm::cl::desc("Pad tiles when test-tile-pattern"), - llvm::cl::init(false)}; + ListOption<int64_t> paddedOperands{ + *this, "padded-operands", + llvm::cl::desc("Operands to pad when test-tile-pattern"), + llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; ListOption<int64_t> peeledLoops{ *this, "peeled-loops", llvm::cl::desc("Loops to be peeled 
when test-tile-pattern"), @@ -576,7 +577,8 @@ static Value getNeutralOfLinalgOp(OpBuilder &b, OpOperand &op) { } static void applyTilePattern(FuncOp funcOp, std::string loopType, - ArrayRef<int64_t> tileSizes, bool padTiles, + ArrayRef<int64_t> tileSizes, + ArrayRef<int64_t> paddedOperands, ArrayRef<int64_t> peeledLoops, bool scalarizeDynamicDims) { MLIRContext *context = funcOp.getContext(); @@ -597,10 +599,15 @@ static void applyTilePattern(FuncOp funcOp, std::string loopType, } else { linalgTilingOptions.setTileSizes(tileSizes); } - if (padTiles) - linalgTilingOptions.setPaddingValueComputationFunction( - getNeutralOfLinalgOp); - + if (!paddedOperands.empty()) { + auto paddingFunc = [&](OpBuilder &b, + OpOperand &opOperand) -> FailureOr<Value> { + if (llvm::count(paddedOperands, opOperand.getOperandNumber()) == 0) + return failure(); + return getNeutralOfLinalgOp(b, opOperand); + }; + linalgTilingOptions.setPaddingValueComputationFunction(paddingFunc); + } tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulOp>, linalg::LinalgTilingPattern<linalg::MatmulI8I8I32Op>, linalg::LinalgTilingPattern<linalg::GenericOp>>( @@ -734,10 +741,10 @@ void TestLinalgTransforms::runOnFunction() { return applyTiledLoopPeelingPattern(getFunction(), testTiledLoopPeeling, skipPartial); if (testTilePattern) - return applyTilePattern(getFunction(), loopType, tileSizes, padTiles, + return applyTilePattern(getFunction(), loopType, tileSizes, paddedOperands, peeledLoops, /*scalarizeDynamicDims=*/false); if (testTileScalarizeDynamicDims) - return applyTilePattern(getFunction(), loopType, tileSizes, padTiles, + return applyTilePattern(getFunction(), loopType, tileSizes, paddedOperands, /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true); if (testHoistPadding) { getFunction().walk([&](linalg::PadTensorOp padTensorOp) { |