summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlex Zinenko <zinenko@google.com>2023-04-12 08:10:24 +0000
committerAlex Zinenko <zinenko@google.com>2023-04-25 22:44:01 +0000
commit135d29c3f50f21823645ed3f772c979c8253cd5d (patch)
tree9c6557663ea60ad3ddf74d756516c1f6379a4988
parentdcfdb963d4f036f02bbe4d8cf3fa55294c49fca7 (diff)
downloadllvm-135d29c3f50f21823645ed3f772c979c8253cd5d.tar.gz
[mlir] reorgnize Linalg TransformOps files. NFC
Mirror the separation between LinalgTransformOps and LinalgMatchOps in headers. Create a separate pair of files for the extension. Depends on D148017 Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D148075
-rw-r--r--mlir/include/mlir/Dialect/Linalg/TransformOps/DialectExtension.h15
-rw-r--r--mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.h48
-rw-r--r--mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h27
-rw-r--r--mlir/include/mlir/InitAllDialects.h2
-rw-r--r--mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt1
-rw-r--r--mlir/lib/Dialect/Linalg/TransformOps/DialectExtension.cpp59
-rw-r--r--mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp2
-rw-r--r--mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp40
-rw-r--r--utils/bazel/llvm-project-overlay/mlir/BUILD.bazel8
9 files changed, 131 insertions, 71 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/DialectExtension.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/DialectExtension.h
new file mode 100644
index 000000000000..b89fbc527d56
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/DialectExtension.h
@@ -0,0 +1,15 @@
+//===- DialectExtension.h - Linalg transform dialect extension --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+namespace mlir {
+class DialectRegistry;
+
+namespace linalg {
+void registerTransformDialectExtension(DialectRegistry &registry);
+} // namespace linalg
+} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.h
new file mode 100644
index 000000000000..d6bbcf88b79f
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.h
@@ -0,0 +1,48 @@
+//===- LinalgMatchOps.h - Linalg transform matcher ops ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LINALG_TRANSFORMOPS_LINALGMATCHOPS_H
+#define MLIR_DIALECT_LINALG_TRANSFORMOPS_LINALGMATCHOPS_H
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Transform/IR/MatchInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
+
+namespace mlir {
+namespace transform {
+
+namespace detail {
+LogicalResult verifyStructuredOpPredicateOpTrait(Operation *op,
+ Value structuredOpHandle);
+} // namespace detail
+
+template <typename OpTy>
+class StructuredOpPredicateOpTrait
+ : public OpTrait::TraitBase<OpTy, StructuredOpPredicateOpTrait> {
+public:
+ static LogicalResult verifyTrait(Operation *op) {
+ static_assert(
+ OpTy::template hasTrait<SingleOpMatcherOpTrait>(),
+ "StructuredOpPredicateOpTrait requires SingleOpMatcherOpTrait");
+
+ return detail::verifyStructuredOpPredicateOpTrait(
+ op, cast<OpTy>(op).getOperandHandle());
+ }
+};
+
+} // namespace transform
+} // namespace mlir
+
+//===----------------------------------------------------------------------===//
+// Linalg Matcher Operations
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.h.inc"
+
+#endif // MLIR_DIALECT_LINALG_TRANSFORMOPS_LINALGMATCHOPS_H
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
index 6276c5687808..775377e5e5bc 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
@@ -11,7 +11,6 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
-#include "mlir/Dialect/Transform/IR/MatchInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
@@ -56,30 +55,7 @@ DiagnosedSilenceableFailure tileToForallOpImpl(
ArrayRef<OpFoldResult> mixedTileSizes, std::optional<ArrayAttr> mapping,
SmallVector<Operation *> &tileOps, SmallVector<Operation *> &tiledOps);
-namespace detail {
-LogicalResult verifyStructuredOpPredicateOpTrait(Operation *op,
- Value structuredOpHandle);
-} // namespace detail
-
-template <typename OpTy>
-class StructuredOpPredicateOpTrait
- : public OpTrait::TraitBase<OpTy, StructuredOpPredicateOpTrait> {
-public:
- static LogicalResult verifyTrait(Operation *op) {
- static_assert(
- OpTy::template hasTrait<SingleOpMatcherOpTrait>(),
- "StructuredOpPredicateOpTrait requires SingleOpMatcherOpTrait");
-
- return detail::verifyStructuredOpPredicateOpTrait(
- op, cast<OpTy>(op).getOperandHandle());
- }
-};
-
} // namespace transform
-
-namespace linalg {
-void registerTransformDialectExtension(DialectRegistry &registry);
-} // namespace linalg
} // namespace mlir
//===----------------------------------------------------------------------===//
@@ -91,7 +67,4 @@ void registerTransformDialectExtension(DialectRegistry &registry);
#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h.inc"
-#define GET_OP_CLASSES
-#include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.h.inc"
-
#endif // MLIR_DIALECT_LINALG_TRANSFORMOPS_LINALGTRANSFORMOPS_H
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 578594eee9df..6f78babc6f33 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -42,7 +42,7 @@
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.h"
-#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
+#include "mlir/Dialect/Linalg/TransformOps/DialectExtension.h"
#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt
index 079d585d4ea3..01038ed297f1 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt
@@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRLinalgTransformOps
+ DialectExtension.cpp
LinalgMatchOps.cpp
LinalgTransformOps.cpp
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/DialectExtension.cpp b/mlir/lib/Dialect/Linalg/TransformOps/DialectExtension.cpp
new file mode 100644
index 000000000000..6cc296105be7
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/TransformOps/DialectExtension.cpp
@@ -0,0 +1,59 @@
+//===- DialectExtension.cpp - Linalg transform dialect extension ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/TransformOps/DialectExtension.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.h"
+#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
+#include "mlir/Dialect/PDL/IR/PDL.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+
+using namespace mlir;
+
+namespace {
+/// Registers new ops and declares PDL as dependent dialect since the
+/// additional ops are using PDL types for operands and results.
+class LinalgTransformDialectExtension
+ : public transform::TransformDialectExtension<
+ LinalgTransformDialectExtension> {
+public:
+ using Base::Base;
+
+ void init() {
+ declareDependentDialect<pdl::PDLDialect>();
+ declareDependentDialect<linalg::LinalgDialect>();
+
+ declareGeneratedDialect<affine::AffineDialect>();
+ declareGeneratedDialect<arith::ArithDialect>();
+ declareGeneratedDialect<scf::SCFDialect>();
+ declareGeneratedDialect<vector::VectorDialect>();
+ declareGeneratedDialect<gpu::GPUDialect>();
+ declareGeneratedDialect<tensor::TensorDialect>();
+
+ registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
+ >();
+ registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp.inc"
+ >();
+ }
+};
+} // namespace
+
+void mlir::linalg::registerTransformDialectExtension(
+ DialectRegistry &registry) {
+ registry.addExtensions<LinalgTransformDialectExtension>();
+}
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
index e2d895f12f53..1936a53201f2 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
@@ -6,9 +6,9 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
#include "mlir/Dialect/Transform/IR/MatchInterfaces.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/FunctionImplementation.h"
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index f113d3af7b44..9c7236e21842 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -9,6 +9,7 @@
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
#include "mlir/AsmParser/AsmParser.h"
+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
@@ -3113,46 +3114,7 @@ DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne(
return diag;
}
-//===----------------------------------------------------------------------===//
-// Transform op registration
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// Registers new ops and declares PDL as dependent dialect since the
-/// additional ops are using PDL types for operands and results.
-class LinalgTransformDialectExtension
- : public transform::TransformDialectExtension<
- LinalgTransformDialectExtension> {
-public:
- using Base::Base;
-
- void init() {
- declareDependentDialect<pdl::PDLDialect>();
- declareDependentDialect<LinalgDialect>();
- declareGeneratedDialect<affine::AffineDialect>();
- declareGeneratedDialect<arith::ArithDialect>();
- declareGeneratedDialect<scf::SCFDialect>();
- declareGeneratedDialect<vector::VectorDialect>();
- declareGeneratedDialect<gpu::GPUDialect>();
-
- registerTransformOps<
-#define GET_OP_LIST
-#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
- >();
- registerTransformOps<
-#define GET_OP_LIST
-#include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp.inc"
- >();
- }
-};
-} // namespace
-
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
-
-void mlir::linalg::registerTransformDialectExtension(
- DialectRegistry &registry) {
- registry.addExtensions<LinalgTransformDialectExtension>();
-}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 3190e65fc310..beaf0f160d65 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -8787,9 +8787,9 @@ cc_library(
srcs = glob([
"lib/Dialect/Linalg/TransformOps/*.cpp",
]),
- hdrs = [
- "include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h",
- ],
+ hdrs = glob([
+ "include/mlir/Dialect/Linalg/TransformOps/*.h",
+ ]),
includes = ["include"],
deps = [
":AffineDialect",
@@ -8807,6 +8807,7 @@ cc_library(
":LinalgTransforms",
":LinalgUtils",
":PDLDialect",
+ ":SCFDialect",
":SCFTransforms",
":Support",
":TensorDialect",
@@ -8815,6 +8816,7 @@ cc_library(
":TransformDialect",
":TransformDialectUtils",
":TransformUtils",
+ ":VectorDialect",
":VectorTransforms",
"//llvm:Support",
],