summaryrefslogtreecommitdiff
path: root/mlir/python
diff options
context:
space:
mode:
authorAlex Zinenko <zinenko@google.com>2023-01-04 14:04:53 +0000
committerAlex Zinenko <zinenko@google.com>2023-01-19 10:19:37 +0000
commit88c5027b93a9f447a8b3ce02e5d74f1c10c14da1 (patch)
treee46c4fd4f2483ca5dc1fffea34d8d90df932b1a0 /mlir/python
parent883c117d1a4cce3c19aa521fccaf8f938269fc57 (diff)
downloadllvm-88c5027b93a9f447a8b3ce02e5d74f1c10c14da1.tar.gz
[mlir] make multi-size tiling use transform parameters
Use the recently introduced transform dialect parameter mechanism to perform controllable multi-size tiling with sizes computed at the transformation time rather than at runtime. This requires to generalize tile and split structured transform operations to work with any transform dialect handle types, which is desirable in itself to avoid unchecked overuse of PDL OperationType. Reviewed By: shabalin Differential Revision: https://reviews.llvm.org/D140980
Diffstat (limited to 'mlir/python')
-rw-r--r--mlir/python/mlir/dialects/_structured_transform_ops_ext.py73
1 files changed, 55 insertions, 18 deletions
diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
index 2525ea34c375..f045e5c13c1e 100644
--- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
@@ -5,11 +5,11 @@
try:
from ..ir import *
from ._ods_common import get_op_result_or_value as _get_op_result_or_value
- from ..dialects import pdl
+ from ..dialects import pdl, transform
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
-from typing import List, Optional, Sequence, Union
+from typing import List, Optional, Sequence, Union, overload
IntOrAttrList = Sequence[Union[IntegerAttr, int]]
OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]]
@@ -51,13 +51,13 @@ def _get_int_array_attr(
def _get_dense_int64_array_attr(
values: Sequence[int]) -> DenseI64ArrayAttr:
- """Creates a dense integer array from a sequence of integers.
+ """Creates a dense integer array from a sequence of integers.
Expects the thread-local MLIR context to have been set by the context
manager.
"""
- if values is None:
- return DenseI64ArrayAttr.get([])
- return DenseI64ArrayAttr.get(values)
+ if values is None:
+ return DenseI64ArrayAttr.get([])
+ return DenseI64ArrayAttr.get(values)
def _get_int_int_array_attr(
values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr,
@@ -141,6 +141,7 @@ class MultiTileSizesOp:
"""Specialization for MultitileSizesOp class."""
def __init__(self,
+ result_type: Type,
target: Union[Operation, Value],
*,
dimension: Union[int, IntegerAttr],
@@ -149,9 +150,9 @@ class MultiTileSizesOp:
loc=None,
ip=None):
super().__init__(
- pdl.OperationType.get(),
- pdl.OperationType.get(),
- pdl.OperationType.get(),
+ result_type,
+ result_type,
+ result_type,
_get_op_result_or_value(target),
dimension=_get_int64_attr(dimension),
target_size=_get_int64_attr(target_size),
@@ -223,11 +224,12 @@ class SplitOp:
static_split_point = _get_int64_attr(ShapedType.get_dynamic_size())
dynamic_split_point = _get_op_result_or_value(split_point)
- pdl_operation_type = pdl.OperationType.get()
+ target = _get_op_result_or_value(target)
+
super().__init__(
- pdl_operation_type,
- pdl_operation_type,
- _get_op_result_or_value(target),
+ target.type,
+ target.type,
+ target,
dimension=dimension,
static_split_point=static_split_point,
dynamic_split_point=dynamic_split_point,
@@ -238,7 +240,9 @@ class SplitOp:
class TileOp:
"""Specialization for TileOp class."""
+ @overload
def __init__(self,
+ loop_types: Union[Type, List[Type]],
target: Union[Operation, Value],
*,
sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation,
@@ -246,9 +250,28 @@ class TileOp:
interchange: OptionalIntList = None,
loc=None,
ip=None):
- pdl_operation_type = pdl.OperationType.get()
- i64_type = IntegerType.get_signless(64)
+ ...
+ @overload
+ def __init__(self,
+ target: Union[Operation, Value],
+ *,
+ sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation,
+ Value]], ArrayAttr]] = None,
+ interchange: OptionalIntList = None,
+ loc=None,
+ ip=None):
+ ...
+
+ def __init__(self,
+ loop_types_or_target: Union[Type, List[Type], Operation, Value],
+ target_or_none: Optional[Union[Operation, Value]] = None,
+ *,
+ sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation,
+ Value]], ArrayAttr]] = None,
+ interchange: OptionalIntList = None,
+ loc=None,
+ ip=None):
if sizes is None:
sizes = []
@@ -267,12 +290,26 @@ class TileOp:
num_loops = sum(
v if v == 0 else 1 for v in self.__extract_values(sizes_attr))
+
+ if isinstance(loop_types_or_target, (Operation, Value)):
+ loop_types = [transform.AnyOpType.get()] * num_loops
+ target = loop_types_or_target
+ assert target_or_none is None, "Cannot construct TileOp with two targets."
+ else:
+ loop_types = ([loop_types_or_target] * num_loops) if isinstance(
+ loop_types_or_target, Type) else loop_types_or_target
+ target = target_or_none
+
+ target = _get_op_result_or_value(target)
+
super().__init__(
- pdl_operation_type, [pdl_operation_type] * num_loops,
- _get_op_result_or_value(target),
+ target.type,
+ loop_types,
+ target,
dynamic_sizes=dynamic_sizes,
static_sizes=sizes_attr,
- interchange=_get_dense_int64_array_attr(interchange) if interchange else None,
+ interchange=_get_dense_int64_array_attr(interchange)
+ if interchange else None,
loc=loc,
ip=ip)