summaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
authorAlex Zinenko <zinenko@google.com>2023-05-15 12:28:21 +0000
committerAlex Zinenko <zinenko@google.com>2023-05-15 14:30:19 +0000
commit1365ff74cb7d9f15feafdd4fbe2996d2f9e42a5e (patch)
tree7e73dd27cac7366dae969a3b97953353b0781458 /mlir
parentd421f5226048e4a5d88aab157d0f4d434c43f208 (diff)
downloadllvm-1365ff74cb7d9f15feafdd4fbe2996d2f9e42a5e.tar.gz
[mlir] allow repeated payload in structured.fuse_into_containing
Structured fusion proceeds by iteratively finding the next suitable producer to be fused into the loop. Therefore, it shouldn't matter if the same producer is listed multiple times (e.g., it is used as multiple operands). Adjust the implementation of the transform op to support this case. Also fix the checking code in the interpreter to actually respect the TransformOpInterface indication that repeated payload is allowed, it seems to have been accidentally dropped in one of the refactorings. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D150561
Diffstat (limited to 'mlir')
-rw-r--r--mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td9
-rw-r--r--mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp9
-rw-r--r--mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp4
-rw-r--r--mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir36
4 files changed, 52 insertions, 6 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index cdeabb743551..c7bc3767b27c 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -143,7 +143,8 @@ def FuseOp : Op<Transform_Dialect, "structured.fuse",
def FuseIntoContainingOp :
Op<Transform_Dialect, "structured.fuse_into_containing_op",
- [DeclareOpInterfaceMethods<TransformOpInterface>,
+ [DeclareOpInterfaceMethods<TransformOpInterface,
+ ["allowsRepeatedHandleOperands"]>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let summary = "Fuse a producer into a containing operation.";
@@ -160,7 +161,7 @@ def FuseIntoContainingOp :
producer op handle may be associated with multiple payload ops. This
transform fuses producers one-by-one, always picking an unspecified producer
that has at least one use inside the containing op among the
- producers.
+ producers. A producer can be listed multiple times in the handle.
Note: If a producer has multiple uses inside the containing op, it is
currently tiled and/or cloned multiple times into the containing op.
@@ -176,8 +177,8 @@ def FuseIntoContainingOp :
containing op. I.e., "producers" that are not consumed within the containing
op are rejected by this operation.
- This operation reads and frees the producer handle.
- This operation reads the containing op handle.
+ This operation consumes the producer handle.
+ This operation only reads the containing op handle.
}];
let arguments = (ins PDL_Operation:$producer_op,
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index afef59990afc..0703ca31f402 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -571,6 +571,11 @@ static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag,
return fusedOp;
}
+bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() {
+ // Allow repeated handles since we are fusing everything anyway.
+ return true;
+}
+
DiagnosedSilenceableFailure
transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
transform::TransformState &state) {
@@ -591,8 +596,8 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
// Helper function to find the next producer that should be fused. Take any
// producer that has a use inside the containing op.
- SmallVector<Operation *> remainingProducers(producerOps.begin(),
- producerOps.end());
+ SetVector<Operation *> remainingProducers(producerOps.begin(),
+ producerOps.end());
auto getNextProducer = [&]() -> FailureOr<Operation *> {
for (const auto &it : enumerate(remainingProducers)) {
Operation *producerOp = it.value();
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index bad1d74fb473..5685187e853f 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -724,6 +724,10 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
FULL_LDBG("--handle not consumed -> SKIP\n");
continue;
}
+ if (transform.allowsRepeatedHandleOperands()) {
+ FULL_LDBG("--op allows repeated handles -> SKIP\n");
+ continue;
+ }
FULL_LDBG("--handle is consumed\n");
Type operandType = operand.get().getType();
diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
index d6b3ff3181b2..537ee8664df4 100644
--- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
@@ -247,3 +247,39 @@ module {
transform.structured.fuse_into_containing_op %0 into %1
}
}
+
+// -----
+
+module {
+ // CHECK-LABEL: func.func @fuse_repeated
+ func.func @fuse_repeated(%fill: tensor<2xf32>, %output: tensor<2xf32>) -> tensor<2xf32> {
+ %c0 = arith.constant 0.0 : f32
+ %0 = linalg.fill ins(%c0 : f32) outs(%fill : tensor<2xf32>) -> tensor<2xf32>
+
+ // CHECK: scf.forall
+ %1 = scf.forall (%i) in (2) shared_outs(%arg1 = %output) -> (tensor<2xf32>) {
+ %2 = tensor.extract_slice %0[%i][1][1] : tensor<2xf32> to tensor<1xf32>
+ %3 = tensor.extract_slice %arg1[%i][1][1] : tensor<2xf32> to tensor<1xf32>
+ // CHECK: %[[FUSED:.+]] = linalg.fill
+ // CHECK: elemwise_unary ins(%[[FUSED]]
+ %4 = linalg.elemwise_unary ins(%2 : tensor<1xf32>) outs(%3 : tensor<1xf32>) -> tensor<1xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %4 into %arg1[%i][1][1] : tensor<1xf32> into tensor<2xf32>
+ }
+ }
+
+ return %1 : tensor<2xf32>
+ }
+
+ transform.sequence failures(propagate) {
+ ^bb1(%arg1: !transform.any_op):
+ %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !pdl.operation
+ %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !pdl.operation
+
+ // Create a new handle that points to `linalg.fill` twice.
+ %2 = transform.merge_handles %0, %0 : !pdl.operation
+
+ // It shouldn't be a problem to fuse this handle.
+ transform.structured.fuse_into_containing_op %2 into %1
+ }
+}