summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTobias Gysi <gysit@google.com>2021-09-27 19:20:56 +0000
committerTobias Gysi <gysit@google.com>2021-09-27 19:21:37 +0000
commitd20d0e145d2fce3a39bd6a76df47455134167cdd (patch)
treea53dd0d927717df7d230b27f4057834ccd52256d
parent2a7a768dad3a77571fae8506d84078fe4ce3d105 (diff)
downloadllvm-d20d0e145d2fce3a39bd6a76df47455134167cdd.tar.gz
[mlir][linalg] Finer-grained padding control.
Adapt the signature of the PaddingValueComputationFunction callback to either return the padding value or failure to signal padding is not desired. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D110572
-rw-r--r--mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h16
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp18
-rw-r--r--mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir11
-rw-r--r--mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp27
4 files changed, 42 insertions, 30 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 8530b89d5ea9..03843bdfaa01 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -438,10 +438,11 @@ private:
using TileSizeComputationFunction =
std::function<SmallVector<Value, 4>(OpBuilder &, Operation *)>;
-/// Specify the padding value for an OpOperand. This should be a function of
-/// both the operation and the operand type.
+/// Callback returning the padding value to use for a given OpOperand or failure
+/// for no padding. This should be a function of both the operation and the
+/// operand type.
using PaddingValueComputationFunction =
- std::function<Value(OpBuilder &, OpOperand &)>;
+ std::function<FailureOr<Value>(OpBuilder &, OpOperand &)>;
struct LinalgTilingOptions {
/// Computation function that returns the tile sizes for each operation.
@@ -504,10 +505,11 @@ struct LinalgTilingOptions {
return *this;
}
- /// Computation function that returns a padding value to use when padding to
- /// force static sizes. When `paddingValueComputationFunction` is set, padding
- /// operations are introduced, that guarantee the underlying op is statically
- /// shaped and can thus be vectorized.
+ /// Callback returning the padding value to use for a given OpOperand or
+ /// failure for no padding. Padding operations are introduced if
+ /// `paddingValueComputationFunction` is set and does not return failure.
+ /// Padding all operands guarantees the operation is statically shaped and
+ /// thus can be vectorized.
PaddingValueComputationFunction paddingValueComputationFunction = nullptr;
LinalgTilingOptions &
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index cef9e5a030ae..1d28451ae05e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -145,16 +145,21 @@ LinalgTilingOptions &mlir::linalg::LinalgTilingOptions::scalarizeDynamicDims() {
return *this;
}
-/// Try to compute a static bounding box for `operand`. The padding happens
-/// even if the operand already has static shape. `result` is the result of a
-/// freshly created PadTensorOp. Return failure if the operand cannot be padded
-/// to a static shape.
+/// Helper function that tries to pad `opOperand`. Exit early and return success
+/// for scalar operands or if `paddingFunc` returns failure. Otherwise, try to
+/// pad the operand even if it already has a static shape. Set `result` to the
+/// result of the created PadTensorOp or return failure if the operand cannot be
+/// padded to a static shape.
static LogicalResult padOperandToSmallestStaticBoundingBox(
PatternRewriter &rewriter, linalg::LinalgOp opToPad, OpOperand *opOperand,
const PaddingValueComputationFunction &paddingFunc, Value &result) {
// Can't pad scalars.
if (opToPad.getShape(opOperand).empty())
return success();
+ // Can't pad if no padding value is known.
+ FailureOr<Value> paddingValue = paddingFunc(rewriter, *opOperand);
+ if (failed(paddingValue))
+ return success();
auto sliceOp = opOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
// Not a slice op, cannot construct a static bounding box.
if (!sliceOp)
@@ -173,12 +178,11 @@ static LogicalResult padOperandToSmallestStaticBoundingBox(
opToPad, "No constant bounding box can be found for padding");
staticSizes.push_back(indexAttr.getInt());
}
- Value pad = paddingFunc(rewriter, *opOperand);
auto staticTensorType = RankedTensorType::get(
staticSizes, getElementTypeOrSelf(opOperand->get()));
result = linalg::PadTensorOp::createPadHighOp(
- staticTensorType, opOperand->get(), pad, /*packing=*/true,
- opToPad->getLoc(), rewriter);
+ staticTensorType, opOperand->get(), paddingValue.getValue(),
+ /*packing=*/true, opToPad->getLoc(), rewriter);
return success();
}
diff --git a/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir
index fb283814a466..7f30762d48dc 100644
--- a/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-pattern pad-tiles tile-sizes=2,3,4" -canonicalize | FileCheck %s
-// RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-pattern pad-tiles tile-sizes=2,3" -canonicalize | FileCheck %s -check-prefix=CHECK-1DIM-TILE
+// RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-pattern padded-operands=0,1,2 tile-sizes=2,3,4" -canonicalize | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-pattern padded-operands=0,1 tile-sizes=2,3" -canonicalize | FileCheck %s -check-prefix=CHECK-1DIM-TILE
// CHECK-LABEL: func @matmul_tensors(
// CHECK-SAME: %[[TA:[0-9a-z]+]]: tensor<?x?xi8>
@@ -97,6 +97,7 @@ func @matmul_partially_padded_tensors(
// CHECK: linalg.matmul_i8_i8_i32 ins({{.*}}, {{.*}} : tensor<2x4xi8>, tensor<4x3xi8>) outs({{.*}} : tensor<2x3xi32>) -> tensor<2x3xi32>
+// Check only the the input operands are padded.
// CHECK-1DIM-TILE: func @matmul_partially_padded_tensors(
// CHECK-1DIM-TILE-SAME: %[[TA:[0-9a-z]+]]: tensor<?x8xi8>
// CHECK-1DIM-TILE-SAME: %[[TB:[0-9a-z]+]]: tensor<8x?xi8>
@@ -111,10 +112,8 @@ func @matmul_partially_padded_tensors(
// CHECK-1DIM-TILE: : tensor<?x8xi8> to tensor<2x8xi8>
// CHECK-1DIM-TILE: %[[pB:.*]] = linalg.pad_tensor %[[sTB]] packing low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
// CHECK-1DIM-TILE: : tensor<8x?xi8> to tensor<8x3xi8>
-// CHECK-1DIM-TILE: %[[pC:.*]] = linalg.pad_tensor %[[sTC]] packing low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
-// CHECK-1DIM-TILE: : tensor<?x?xi32> to tensor<2x3xi32>
-// CHECK-1DIM-TILE: %[[pD:.*]] = linalg.matmul_i8_i8_i32 ins(%[[pA]], %[[pB]] : tensor<2x8xi8>, tensor<8x3xi8>)
-// CHECK-1DIM-TILE: outs(%[[pC]] : tensor<2x3xi32>) -> tensor<2x3xi32>
+// CHECK-1DIM-TILE: %[[pD:.*]] = linalg.matmul_i8_i8_i32 ins(%[[pA]], %[[pB]] : tensor<2x8xi8>, tensor<8x3xi8>)
+// CHECK-1DIM-TILE: outs(%[[sTC]] : tensor<?x?xi32>) -> tensor<?x?xi32>
// Check that the tile-and-pad transformation actually introduces the padding
// as requested, even if original operation already operates on static
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 4b4b6acbc917..74ff41c052de 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -108,9 +108,10 @@ struct TestLinalgTransforms
llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into "
"pad_tensor(subtensor)"),
llvm::cl::init(false)};
- Option<bool> padTiles{*this, "pad-tiles",
- llvm::cl::desc("Pad tiles when test-tile-pattern"),
- llvm::cl::init(false)};
+ ListOption<int64_t> paddedOperands{
+ *this, "padded-operands",
+ llvm::cl::desc("Operands to pad when test-tile-pattern"),
+ llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
ListOption<int64_t> peeledLoops{
*this, "peeled-loops",
llvm::cl::desc("Loops to be peeled when test-tile-pattern"),
@@ -576,7 +577,8 @@ static Value getNeutralOfLinalgOp(OpBuilder &b, OpOperand &op) {
}
static void applyTilePattern(FuncOp funcOp, std::string loopType,
- ArrayRef<int64_t> tileSizes, bool padTiles,
+ ArrayRef<int64_t> tileSizes,
+ ArrayRef<int64_t> paddedOperands,
ArrayRef<int64_t> peeledLoops,
bool scalarizeDynamicDims) {
MLIRContext *context = funcOp.getContext();
@@ -597,10 +599,15 @@ static void applyTilePattern(FuncOp funcOp, std::string loopType,
} else {
linalgTilingOptions.setTileSizes(tileSizes);
}
- if (padTiles)
- linalgTilingOptions.setPaddingValueComputationFunction(
- getNeutralOfLinalgOp);
-
+ if (!paddedOperands.empty()) {
+ auto paddingFunc = [&](OpBuilder &b,
+ OpOperand &opOperand) -> FailureOr<Value> {
+ if (llvm::count(paddedOperands, opOperand.getOperandNumber()) == 0)
+ return failure();
+ return getNeutralOfLinalgOp(b, opOperand);
+ };
+ linalgTilingOptions.setPaddingValueComputationFunction(paddingFunc);
+ }
tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulOp>,
linalg::LinalgTilingPattern<linalg::MatmulI8I8I32Op>,
linalg::LinalgTilingPattern<linalg::GenericOp>>(
@@ -734,10 +741,10 @@ void TestLinalgTransforms::runOnFunction() {
return applyTiledLoopPeelingPattern(getFunction(), testTiledLoopPeeling,
skipPartial);
if (testTilePattern)
- return applyTilePattern(getFunction(), loopType, tileSizes, padTiles,
+ return applyTilePattern(getFunction(), loopType, tileSizes, paddedOperands,
peeledLoops, /*scalarizeDynamicDims=*/false);
if (testTileScalarizeDynamicDims)
- return applyTilePattern(getFunction(), loopType, tileSizes, padTiles,
+ return applyTilePattern(getFunction(), loopType, tileSizes, paddedOperands,
/*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true);
if (testHoistPadding) {
getFunction().walk([&](linalg::PadTensorOp padTensorOp) {