diff options
author | Alex Zinenko <zinenko@google.com> | 2023-05-15 12:28:21 +0000 |
---|---|---|
committer | Alex Zinenko <zinenko@google.com> | 2023-05-15 14:30:19 +0000 |
commit | 1365ff74cb7d9f15feafdd4fbe2996d2f9e42a5e (patch) | |
tree | 7e73dd27cac7366dae969a3b97953353b0781458 /mlir | |
parent | d421f5226048e4a5d88aab157d0f4d434c43f208 (diff) | |
download | llvm-1365ff74cb7d9f15feafdd4fbe2996d2f9e42a5e.tar.gz |
[mlir] allow repeated payload in structured.fuse_into_containing
Structured fusion proceeds by iteratively finding the next suitable
producer to be fused into the loop. Therefore, it shouldn't matter if
the same producer is listed multiple times (e.g., it is used as multiple
operands). Adjust the implementation of the transform op to support this
case.
Also fix the checking code in the interpreter to actually respect the
TransformOpInterface indication that repeated payload is allowed; this
check seems to have been accidentally dropped in one of the refactorings.
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D150561
Diffstat (limited to 'mlir')
4 files changed, 52 insertions, 6 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index cdeabb743551..c7bc3767b27c 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -143,7 +143,8 @@ def FuseOp : Op<Transform_Dialect, "structured.fuse", def FuseIntoContainingOp : Op<Transform_Dialect, "structured.fuse_into_containing_op", - [DeclareOpInterfaceMethods<TransformOpInterface>, + [DeclareOpInterfaceMethods<TransformOpInterface, + ["allowsRepeatedHandleOperands"]>, DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> { let summary = "Fuse a producer into a containing operation."; @@ -160,7 +161,7 @@ def FuseIntoContainingOp : producer op handle may be associated with multiple payload ops. This transform fuses producers one-by-one, always picking an unspecified producer that has at least one use inside the containing op among the - producers. + producers. A producer can be listed multiple times in the handle. Note: If a producer has multiple uses inside the containing op, it is currently tiled and/or cloned multiple times into the containing op. @@ -176,8 +177,8 @@ def FuseIntoContainingOp : containing op. I.e., "producers" that are not consumed within the containing op are rejected by this operation. - This operation reads and frees the producer handle. - This operation reads the containing op handle. + This operation consumes the producer handle. + This operation only reads the containing op handle. 
}]; let arguments = (ins PDL_Operation:$producer_op, diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index afef59990afc..0703ca31f402 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -571,6 +571,11 @@ static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag, return fusedOp; } +bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() { + // Allow repeated handles since we are fusing everything anyway. + return true; +} + DiagnosedSilenceableFailure transform::FuseIntoContainingOp::apply(transform::TransformResults &results, transform::TransformState &state) { @@ -591,8 +596,8 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results, // Helper function to find the next producer that should be fused. Take any // producer that has a use inside the containing op. 
- SmallVector<Operation *> remainingProducers(producerOps.begin(), - producerOps.end()); + SetVector<Operation *> remainingProducers(producerOps.begin(), + producerOps.end()); auto getNextProducer = [&]() -> FailureOr<Operation *> { for (const auto &it : enumerate(remainingProducers)) { Operation *producerOp = it.value(); diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp index bad1d74fb473..5685187e853f 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -724,6 +724,10 @@ transform::TransformState::applyTransform(TransformOpInterface transform) { FULL_LDBG("--handle not consumed -> SKIP\n"); continue; } + if (transform.allowsRepeatedHandleOperands()) { + FULL_LDBG("--op allows repeated handles -> SKIP\n"); + continue; + } FULL_LDBG("--handle is consumed\n"); Type operandType = operand.get().getType(); diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir index d6b3ff3181b2..537ee8664df4 100644 --- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir @@ -247,3 +247,39 @@ module { transform.structured.fuse_into_containing_op %0 into %1 } } + +// ----- + +module { + // CHECK-LABEL: func.func @fuse_repeated + func.func @fuse_repeated(%fill: tensor<2xf32>, %output: tensor<2xf32>) -> tensor<2xf32> { + %c0 = arith.constant 0.0 : f32 + %0 = linalg.fill ins(%c0 : f32) outs(%fill : tensor<2xf32>) -> tensor<2xf32> + + // CHECK: scf.forall + %1 = scf.forall (%i) in (2) shared_outs(%arg1 = %output) -> (tensor<2xf32>) { + %2 = tensor.extract_slice %0[%i][1][1] : tensor<2xf32> to tensor<1xf32> + %3 = tensor.extract_slice %arg1[%i][1][1] : tensor<2xf32> to tensor<1xf32> + // CHECK: %[[FUSED:.+]] = linalg.fill + // CHECK: elemwise_unary ins(%[[FUSED]] + %4 = 
linalg.elemwise_unary ins(%2 : tensor<1xf32>) outs(%3 : tensor<1xf32>) -> tensor<1xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %4 into %arg1[%i][1][1] : tensor<1xf32> into tensor<2xf32> + } + } + + return %1 : tensor<2xf32> + } + + transform.sequence failures(propagate) { + ^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !pdl.operation + %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !pdl.operation + + // Create a new handle that points to `linalg.fill` twice. + %2 = transform.merge_handles %0, %0 : !pdl.operation + + // It shouldn't be a problem to fuse this handle. + transform.structured.fuse_into_containing_op %2 into %1 + } +} |