diff options
author | Alex Zinenko <zinenko@google.com> | 2022-10-10 14:33:59 +0000 |
---|---|---|
committer | Alex Zinenko <zinenko@google.com> | 2022-10-11 09:55:19 +0000 |
commit | 3e1f6d02f755e34a0a12a8dd439fb65f84d6621f (patch) | |
tree | acf2531c582ed1850f71c92a6fab79ce9cdd1e3f /mlir/python | |
parent | 6bb997c032f509d4f61461aa461c8431325cbb2a (diff) | |
download | llvm-3e1f6d02f755e34a0a12a8dd439fb65f84d6621f.tar.gz |
[mlir] add OperationType to the Transform dialect
Add a new OperationType handle type to the Transform dialect. This
transform type is parameterized by the name of the payload operation it
can point to. It is intended as a constraint on transformations that are
only applicable to a specific kind of payload operations. If a
transformation is applicable to a small set of operation classes, it can
be wrapped into a transform op by using a disjunctive constraint, such
as `Type<Or<[Transform_ConcreteOperation<"foo">.predicate,
Transform_ConcreteOperation<"bar">.predicate]>>` for its operand without
modifying this type. Broader sets of accepted operations should be
modeled as specific types.
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D135586
Diffstat (limited to 'mlir/python')
-rw-r--r-- | mlir/python/CMakeLists.txt | 14 | ||||
-rw-r--r-- | mlir/python/mlir/_mlir_libs/_mlir/dialects/transform/__init__.pyi | 26 | ||||
-rw-r--r-- | mlir/python/mlir/dialects/_transform_ops_ext.py | 10 | ||||
-rw-r--r-- | mlir/python/mlir/dialects/transform/__init__.py | 1 |
4 files changed, 51 insertions, 0 deletions
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index ecff102fe833..0a4c2f803641 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -121,6 +121,7 @@ declare_mlir_dialect_python_bindings( SOURCES dialects/_transform_ops_ext.py dialects/transform/__init__.py + _mlir_libs/_mlir/dialects/transform/__init__.pyi DIALECT_NAME transform) declare_mlir_dialect_extension_python_bindings( @@ -353,6 +354,19 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.SparseTensor.Pybind MLIRCAPISparseTensor ) +declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Pybind + MODULE_NAME _mlirDialectsTransform + ADD_TO_PARENT MLIRPythonSources.Dialects.transform + ROOT_DIR "${PYTHON_SOURCE_DIR}" + SOURCES + DialectTransform.cpp + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPIIR + MLIRCAPITransformDialect +) + declare_mlir_python_extension(MLIRPythonExtension.AsyncDialectPasses MODULE_NAME _mlirAsyncPasses ADD_TO_PARENT MLIRPythonSources.Dialects.async_dialect diff --git a/mlir/python/mlir/_mlir_libs/_mlir/dialects/transform/__init__.pyi b/mlir/python/mlir/_mlir_libs/_mlir/dialects/transform/__init__.pyi new file mode 100644 index 000000000000..2a29541734a8 --- /dev/null +++ b/mlir/python/mlir/_mlir_libs/_mlir/dialects/transform/__init__.pyi @@ -0,0 +1,26 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Optional + +from mlir.ir import Type, Context + + +class AnyOpType(Type): + @staticmethod + def isinstance(type: Type) -> bool: ... + + @staticmethod + def get(context: Optional[Context] = None) -> AnyOpType: ... + + +class OperationType(Type): + @staticmethod + def isinstance(type: Type) -> bool: ... + + @staticmethod + def get(operation_name: str, context: Optional[Context] = None) -> OperationType: ... + + @property + def operation_name(self) -> str: ... diff --git a/mlir/python/mlir/dialects/_transform_ops_ext.py b/mlir/python/mlir/dialects/_transform_ops_ext.py index 18cd3adb06f7..5cd57b050012 100644 --- a/mlir/python/mlir/dialects/_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_transform_ops_ext.py @@ -18,6 +18,16 @@ def _get_symbol_ref_attr(value: Union[Attribute, str]): return FlatSymbolRefAttr.get(value) +class CastOp: + + def __init__(self, result_type: Type, target: Union[Operation, Value], *, loc=None, ip=None): + super().__init__( + result_type, + _get_op_result_or_value(target), + loc=loc, + ip=ip) + + class GetClosestIsolatedParentOp: def __init__(self, result_type: Type, target: Union[Operation, Value], *, loc=None, ip=None): diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py index d4d71274c26c..78956c437004 100644 --- a/mlir/python/mlir/dialects/transform/__init__.py +++ b/mlir/python/mlir/dialects/transform/__init__.py @@ -18,3 +18,4 @@ class FailurePropagationMode(Enum): return 2 from .._transform_ops_gen import * +from ..._mlir_libs._mlirDialectsTransform import * |