summaryrefslogtreecommitdiff
path: root/mlir/python
diff options
context:
space:
mode:
authorAlex Zinenko <zinenko@google.com>2022-10-10 14:33:59 +0000
committerAlex Zinenko <zinenko@google.com>2022-10-11 09:55:19 +0000
commit3e1f6d02f755e34a0a12a8dd439fb65f84d6621f (patch)
treeacf2531c582ed1850f71c92a6fab79ce9cdd1e3f /mlir/python
parent6bb997c032f509d4f61461aa461c8431325cbb2a (diff)
downloadllvm-3e1f6d02f755e34a0a12a8dd439fb65f84d6621f.tar.gz
[mlir] add OperationType to the Transform dialect
Add a new OperationType handle type to the Transform dialect. This transform type is parameterized by the name of the payload operation it can point to. It is intended as a constraint on transformations that are only applicable to a specific kind of payload operations. If a transformation is applicable to a small set of operation classes, it can be wrapped into a transform op by using a disjunctive constraint, such as `Type<Or<[Transform_ConcreteOperation<"foo">.predicate, Transform_ConcreteOperation<"bar">.predicate]>>` for its operand without modifying this type. Broader sets of accepted operations should be modeled as specific types. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D135586
Diffstat (limited to 'mlir/python')
-rw-r--r--mlir/python/CMakeLists.txt14
-rw-r--r--mlir/python/mlir/_mlir_libs/_mlir/dialects/transform/__init__.pyi26
-rw-r--r--mlir/python/mlir/dialects/_transform_ops_ext.py10
-rw-r--r--mlir/python/mlir/dialects/transform/__init__.py1
4 files changed, 51 insertions, 0 deletions
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index ecff102fe833..0a4c2f803641 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -121,6 +121,7 @@ declare_mlir_dialect_python_bindings(
SOURCES
dialects/_transform_ops_ext.py
dialects/transform/__init__.py
+ _mlir_libs/_mlir/dialects/transform/__init__.pyi
DIALECT_NAME transform)
declare_mlir_dialect_extension_python_bindings(
@@ -353,6 +354,19 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.SparseTensor.Pybind
MLIRCAPISparseTensor
)
+declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Pybind
+ MODULE_NAME _mlirDialectsTransform
+ ADD_TO_PARENT MLIRPythonSources.Dialects.transform
+ ROOT_DIR "${PYTHON_SOURCE_DIR}"
+ SOURCES
+ DialectTransform.cpp
+ PRIVATE_LINK_LIBS
+ LLVMSupport
+ EMBED_CAPI_LINK_LIBS
+ MLIRCAPIIR
+ MLIRCAPITransformDialect
+)
+
declare_mlir_python_extension(MLIRPythonExtension.AsyncDialectPasses
MODULE_NAME _mlirAsyncPasses
ADD_TO_PARENT MLIRPythonSources.Dialects.async_dialect
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/dialects/transform/__init__.pyi b/mlir/python/mlir/_mlir_libs/_mlir/dialects/transform/__init__.pyi
new file mode 100644
index 000000000000..2a29541734a8
--- /dev/null
+++ b/mlir/python/mlir/_mlir_libs/_mlir/dialects/transform/__init__.pyi
@@ -0,0 +1,26 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from typing import Optional
+
+from mlir.ir import Type, Context
+
+
+class AnyOpType(Type):
+ @staticmethod
+ def isinstance(type: Type) -> bool: ...
+
+ @staticmethod
+ def get(context: Optional[Context] = None) -> AnyOpType: ...
+
+
+class OperationType(Type):
+ @staticmethod
+ def isinstance(type: Type) -> bool: ...
+
+ @staticmethod
+ def get(operation_name: str, context: Optional[Context] = None) -> OperationType: ...
+
+ @property
+ def operation_name(self) -> str: ...
diff --git a/mlir/python/mlir/dialects/_transform_ops_ext.py b/mlir/python/mlir/dialects/_transform_ops_ext.py
index 18cd3adb06f7..5cd57b050012 100644
--- a/mlir/python/mlir/dialects/_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_transform_ops_ext.py
@@ -18,6 +18,16 @@ def _get_symbol_ref_attr(value: Union[Attribute, str]):
return FlatSymbolRefAttr.get(value)
+class CastOp:
+
+ def __init__(self, result_type: Type, target: Union[Operation, Value], *, loc=None, ip=None):
+ super().__init__(
+ result_type,
+ _get_op_result_or_value(target),
+ loc=loc,
+ ip=ip)
+
+
class GetClosestIsolatedParentOp:
def __init__(self, result_type: Type, target: Union[Operation, Value], *, loc=None, ip=None):
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index d4d71274c26c..78956c437004 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -18,3 +18,4 @@ class FailurePropagationMode(Enum):
return 2
from .._transform_ops_gen import *
+from ..._mlir_libs._mlirDialectsTransform import *