summaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
Diffstat (limited to 'mlir')
-rw-r--r--mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td9
-rw-r--r--mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp9
-rw-r--r--mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp4
-rw-r--r--mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir36
4 files changed, 52 insertions, 6 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index cdeabb743551..c7bc3767b27c 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -143,7 +143,8 @@ def FuseOp : Op<Transform_Dialect, "structured.fuse",
def FuseIntoContainingOp :
Op<Transform_Dialect, "structured.fuse_into_containing_op",
- [DeclareOpInterfaceMethods<TransformOpInterface>,
+ [DeclareOpInterfaceMethods<TransformOpInterface,
+ ["allowsRepeatedHandleOperands"]>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let summary = "Fuse a producer into a containing operation.";
@@ -160,7 +161,7 @@ def FuseIntoContainingOp :
producer op handle may be associated with multiple payload ops. This
transform fuses producers one-by-one, always picking an unspecified producer
that has at least one use inside the containing op among the
- producers.
+ producers. A producer can be listed multiple times in the handle.
Note: If a producer has multiple uses inside the containing op, it is
currently tiled and/or cloned multiple times into the containing op.
@@ -176,8 +177,8 @@ def FuseIntoContainingOp :
containing op. I.e., "producers" that are not consumed within the containing
op are rejected by this operation.
- This operation reads and frees the producer handle.
- This operation reads the containing op handle.
+ This operation consumes the producer handle.
+ This operation only reads the containing op handle.
}];
let arguments = (ins PDL_Operation:$producer_op,
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index afef59990afc..0703ca31f402 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -571,6 +571,11 @@ static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag,
return fusedOp;
}
+bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() {
+ // Allow repeated handles since we are fusing everything anyway.
+ return true;
+}
+
DiagnosedSilenceableFailure
transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
transform::TransformState &state) {
@@ -591,8 +596,8 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
// Helper function to find the next producer that should be fused. Take any
// producer that has a use inside the containing op.
- SmallVector<Operation *> remainingProducers(producerOps.begin(),
- producerOps.end());
+ SetVector<Operation *> remainingProducers(producerOps.begin(),
+ producerOps.end());
auto getNextProducer = [&]() -> FailureOr<Operation *> {
for (const auto &it : enumerate(remainingProducers)) {
Operation *producerOp = it.value();
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index bad1d74fb473..5685187e853f 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -724,6 +724,10 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
FULL_LDBG("--handle not consumed -> SKIP\n");
continue;
}
+ if (transform.allowsRepeatedHandleOperands()) {
+ FULL_LDBG("--op allows repeated handles -> SKIP\n");
+ continue;
+ }
FULL_LDBG("--handle is consumed\n");
Type operandType = operand.get().getType();
diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
index d6b3ff3181b2..537ee8664df4 100644
--- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
@@ -247,3 +247,39 @@ module {
transform.structured.fuse_into_containing_op %0 into %1
}
}
+
+// -----
+
+module {
+ // CHECK-LABEL: func.func @fuse_repeated
+ func.func @fuse_repeated(%fill: tensor<2xf32>, %output: tensor<2xf32>) -> tensor<2xf32> {
+ %c0 = arith.constant 0.0 : f32
+ %0 = linalg.fill ins(%c0 : f32) outs(%fill : tensor<2xf32>) -> tensor<2xf32>
+
+ // CHECK: scf.forall
+ %1 = scf.forall (%i) in (2) shared_outs(%arg1 = %output) -> (tensor<2xf32>) {
+ %2 = tensor.extract_slice %0[%i][1][1] : tensor<2xf32> to tensor<1xf32>
+ %3 = tensor.extract_slice %arg1[%i][1][1] : tensor<2xf32> to tensor<1xf32>
+ // CHECK: %[[FUSED:.+]] = linalg.fill
+ // CHECK: elemwise_unary ins(%[[FUSED]]
+ %4 = linalg.elemwise_unary ins(%2 : tensor<1xf32>) outs(%3 : tensor<1xf32>) -> tensor<1xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %4 into %arg1[%i][1][1] : tensor<1xf32> into tensor<2xf32>
+ }
+ }
+
+ return %1 : tensor<2xf32>
+ }
+
+ transform.sequence failures(propagate) {
+ ^bb1(%arg1: !transform.any_op):
+ %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !pdl.operation
+ %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !pdl.operation
+
+ // Create a new handle that points to `linalg.fill` twice.
+ %2 = transform.merge_handles %0, %0 : !pdl.operation
+
+ // It shouldn't be a problem to fuse this handle.
+ transform.structured.fuse_into_containing_op %2 into %1
+ }
+}