diff options
Diffstat (limited to 'mlir')
4 files changed, 52 insertions, 6 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index cdeabb743551..c7bc3767b27c 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -143,7 +143,8 @@ def FuseOp : Op<Transform_Dialect, "structured.fuse", def FuseIntoContainingOp : Op<Transform_Dialect, "structured.fuse_into_containing_op", - [DeclareOpInterfaceMethods<TransformOpInterface>, + [DeclareOpInterfaceMethods<TransformOpInterface, + ["allowsRepeatedHandleOperands"]>, DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> { let summary = "Fuse a producer into a containing operation."; @@ -160,7 +161,7 @@ def FuseIntoContainingOp : producer op handle may be associated with multiple payload ops. This transform fuses producers one-by-one, always picking an unspecified producer that has at least one use inside the containing op among the - producers. + producers. A producer can be listed multiple times in the handle. Note: If a producer has multiple uses inside the containing op, it is currently tiled and/or cloned multiple times into the containing op. @@ -176,8 +177,8 @@ def FuseIntoContainingOp : containing op. I.e., "producers" that are not consumed within the containing op are rejected by this operation. - This operation reads and frees the producer handle. - This operation reads the containing op handle. + This operation consumes the producer handle. + This operation only reads the containing op handle. }]; let arguments = (ins PDL_Operation:$producer_op, diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index afef59990afc..0703ca31f402 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -571,6 +571,11 @@ static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag, return fusedOp; } +bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() { + // Allow repeated handles since we are fusing everything anyway. + return true; +} + DiagnosedSilenceableFailure transform::FuseIntoContainingOp::apply(transform::TransformResults &results, transform::TransformState &state) { @@ -591,8 +596,8 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results, // Helper function to find the next producer that should be fused. Take any // producer that has a use inside the containing op. - SmallVector<Operation *> remainingProducers(producerOps.begin(), - producerOps.end()); + SetVector<Operation *> remainingProducers(producerOps.begin(), + producerOps.end()); auto getNextProducer = [&]() -> FailureOr<Operation *> { for (const auto &it : enumerate(remainingProducers)) { Operation *producerOp = it.value(); diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp index bad1d74fb473..5685187e853f 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -724,6 +724,10 @@ transform::TransformState::applyTransform(TransformOpInterface transform) { FULL_LDBG("--handle not consumed -> SKIP\n"); continue; } + if (transform.allowsRepeatedHandleOperands()) { + FULL_LDBG("--op allows repeated handles -> SKIP\n"); + continue; + } FULL_LDBG("--handle is consumed\n"); Type operandType = operand.get().getType(); diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir index d6b3ff3181b2..537ee8664df4 100644 --- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir @@ -247,3 +247,39 @@ module { transform.structured.fuse_into_containing_op %0 into %1 } } + +// ----- + +module { + // CHECK-LABEL: func.func @fuse_repeated + func.func @fuse_repeated(%fill: tensor<2xf32>, %output: tensor<2xf32>) -> tensor<2xf32> { + %c0 = arith.constant 0.0 : f32 + %0 = linalg.fill ins(%c0 : f32) outs(%fill : tensor<2xf32>) -> tensor<2xf32> + + // CHECK: scf.forall + %1 = scf.forall (%i) in (2) shared_outs(%arg1 = %output) -> (tensor<2xf32>) { + %2 = tensor.extract_slice %0[%i][1][1] : tensor<2xf32> to tensor<1xf32> + %3 = tensor.extract_slice %arg1[%i][1][1] : tensor<2xf32> to tensor<1xf32> + // CHECK: %[[FUSED:.+]] = linalg.fill + // CHECK: elemwise_unary ins(%[[FUSED]] + %4 = linalg.elemwise_unary ins(%2 : tensor<1xf32>) outs(%3 : tensor<1xf32>) -> tensor<1xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %4 into %arg1[%i][1][1] : tensor<1xf32> into tensor<2xf32> + } + } + + return %1 : tensor<2xf32> + } + + transform.sequence failures(propagate) { + ^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !pdl.operation + %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !pdl.operation + + // Create a new handle that points to `linalg.fill` twice. + %2 = transform.merge_handles %0, %0 : !pdl.operation + + // It shouldn't be a problem to fuse this handle. + transform.structured.fuse_into_containing_op %2 into %1 + } +} |