diff options
author | Matthias Springer <springerm@google.com> | 2022-07-25 16:13:01 +0200 |
---|---|---|
committer | Matthias Springer <springerm@google.com> | 2022-07-25 16:14:35 +0200 |
commit | a299539adeede887b39e4a913e98047651656592 (patch) | |
tree | 6ff00a043b3ae3135101a8923c001fdee6c1de96 /mlir | |
parent | 07aa8fc8db6b4b8581e0ba8ef4a66274023c0b59 (diff) | |
download | llvm-a299539adeede887b39e4a913e98047651656592.tar.gz |
[mlir][linalg] Expand test case for tile-and-fuse with transform dialect
Reverse the order of the payload ops. fuse_into_containing_op should still work.
Differential Revision: https://reviews.llvm.org/D130355
Diffstat (limited to 'mlir')
3 files changed, 83 insertions, 2 deletions
diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-fuse.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-fuse.mlir index a2547d5973a5..1109950916ed 100644 --- a/mlir/test/Dialect/Linalg/transform-tile-and-fuse.mlir +++ b/mlir/test/Dialect/Linalg/transform-tile-and-fuse.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s --test-transform-dialect-interpreter -canonicalize | FileCheck %s +// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file -canonicalize | FileCheck %s // This is a simple tile-and-fuse example with a single fusion group. @@ -22,7 +22,7 @@ module { {__producer__} ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%5 : tensor<?x?xf32>) -> tensor<?x?xf32> - %7 = linalg.generic + %7 = linalg.generic {__root__, indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>, @@ -56,3 +56,64 @@ module { } } } + +// ----- + +// Inverse the order of the payload ops passed to the tile_to_foreach_thread_op +// op. Fusion should still work. + +module { + // CHECK: func @foo + // CHECK: scf.foreach_thread {{.*}} { + // CHECK: linalg.fill + // CHECK: linalg.matmul + // CHECK: linalg.generic + // CHECK: } + func.func @foo(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?xf32>, + %D: tensor<?x?xf32>, %sz0: index, %sz1: index) + -> tensor<?x?xf32> + { + %cst = arith.constant 0.000000e+00 : f32 + %5 = linalg.fill + {__producer__} + ins(%cst : f32) + outs(%D : tensor<?x?xf32>) -> tensor<?x?xf32> + %6 = linalg.matmul + {__producer__} + ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) + outs(%5 : tensor<?x?xf32>) -> tensor<?x?xf32> + %7 = linalg.generic + {__root__, + indexing_maps = [affine_map<(d0, d1) -> (d0)>, + affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"] + } + ins(%C, %6 : tensor<?xf32>, tensor<?x?xf32>) + outs(%D : tensor<?x?xf32>) { + ^bb0(%arg2: f32, %arg3: f32, %arg4: f32): + %16 = arith.maxf %arg3, %cst : f32 + %17 = arith.cmpf ogt, %arg2, %cst : f32 + %18 = arith.select %17, %cst, %16 : f32 + linalg.yield %18 : f32 + } -> tensor<?x?xf32> + return %7 : tensor<?x?xf32> + } + + transform.with_pdl_patterns { + ^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 { + ^bb1(%arg1: !pdl.operation): + // Find the root and all producers. + %root = transform.structured.match attribute{"__root__"} in %arg1 + %producers = transform.structured.match attribute{"__producer__"} in %arg1 + %reversed_producers = transform.test_reverse_payload_ops %producers + + // Tile the root. + %foreach_thread_op, %tiled_op = transform.structured.tile_to_foreach_thread_op %root num_threads [10, 20] + + // Fuse all producers. + transform.structured.fuse_into_containing_op %reversed_producers into %foreach_thread_op + } + } +} diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp index 3893508ff19f..e7ad8afa8592 100644 --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -198,6 +198,16 @@ DiagnosedSilenceableFailure mlir::test::TestRemoveTestExtensionOp::apply( state.removeExtension<TestTransformStateExtension>(); return DiagnosedSilenceableFailure::success(); } + +DiagnosedSilenceableFailure +mlir::test::TestReversePayloadOpsOp::apply(transform::TransformResults &results, + transform::TransformState &state) { + ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget()); + auto reversedOps = llvm::to_vector(llvm::reverse(payloadOps)); + results.set(getResult().cast<OpResult>(), reversedOps); + return DiagnosedSilenceableFailure::success(); +} + DiagnosedSilenceableFailure mlir::test::TestTransformOpWithRegions::apply( transform::TransformResults &results, transform::TransformState &state) { return DiagnosedSilenceableFailure::success(); diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td index 3d05c6f4681d..4cc6b415ec60 100644 --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td @@ -101,6 +101,16 @@ def TestRemoveTestExtensionOp let cppNamespace = "::mlir::test"; } +def TestReversePayloadOpsOp + : Op<Transform_Dialect, "test_reverse_payload_ops", + [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface, + DeclareOpInterfaceMethods<TransformOpInterface>]> { + let arguments = (ins PDL_Operation:$target); + let results = (outs PDL_Operation:$result); + let assemblyFormat = "$target attr-dict"; + let cppNamespace = "::mlir::test"; +} + def TestTransformOpWithRegions : Op<Transform_Dialect, "test_transform_op_with_regions", [DeclareOpInterfaceMethods<TransformOpInterface>, |