diff options
author | Alex Zinenko <zinenko@google.com> | 2023-01-04 14:04:53 +0000 |
---|---|---|
committer | Alex Zinenko <zinenko@google.com> | 2023-01-19 10:19:37 +0000 |
commit | 88c5027b93a9f447a8b3ce02e5d74f1c10c14da1 (patch) | |
tree | e46c4fd4f2483ca5dc1fffea34d8d90df932b1a0 /mlir/python | |
parent | 883c117d1a4cce3c19aa521fccaf8f938269fc57 (diff) | |
download | llvm-88c5027b93a9f447a8b3ce02e5d74f1c10c14da1.tar.gz |
[mlir] make multi-size tiling use transform parameters
Use the recently introduced transform dialect parameter mechanism to
perform controllable multi-size tiling with sizes computed at the
transformation time rather than at runtime.
This requires to generalize tile and split structured transform
operations to work with any transform dialect handle types, which is
desirable in itself to avoid unchecked overuse of PDL OperationType.
Reviewed By: shabalin
Differential Revision: https://reviews.llvm.org/D140980
Diffstat (limited to 'mlir/python')
-rw-r--r-- | mlir/python/mlir/dialects/_structured_transform_ops_ext.py | 73 |
1 files changed, 55 insertions, 18 deletions
diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py index 2525ea34c375..f045e5c13c1e 100644 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -5,11 +5,11 @@ try: from ..ir import * from ._ods_common import get_op_result_or_value as _get_op_result_or_value - from ..dialects import pdl + from ..dialects import pdl, transform except ImportError as e: raise RuntimeError("Error loading imports from extension module") from e -from typing import List, Optional, Sequence, Union +from typing import List, Optional, Sequence, Union, overload IntOrAttrList = Sequence[Union[IntegerAttr, int]] OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]] @@ -51,13 +51,13 @@ def _get_int_array_attr( def _get_dense_int64_array_attr( values: Sequence[int]) -> DenseI64ArrayAttr: - """Creates a dense integer array from a sequence of integers. + """Creates a dense integer array from a sequence of integers. Expects the thread-local MLIR context to have been set by the context manager. """ - if values is None: - return DenseI64ArrayAttr.get([]) - return DenseI64ArrayAttr.get(values) + if values is None: + return DenseI64ArrayAttr.get([]) + return DenseI64ArrayAttr.get(values) def _get_int_int_array_attr( values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr, @@ -141,6 +141,7 @@ class MultiTileSizesOp: """Specialization for MultitileSizesOp class.""" def __init__(self, + result_type: Type, target: Union[Operation, Value], *, dimension: Union[int, IntegerAttr], @@ -149,9 +150,9 @@ class MultiTileSizesOp: loc=None, ip=None): super().__init__( - pdl.OperationType.get(), - pdl.OperationType.get(), - pdl.OperationType.get(), + result_type, + result_type, + result_type, _get_op_result_or_value(target), dimension=_get_int64_attr(dimension), target_size=_get_int64_attr(target_size), @@ -223,11 +224,12 @@ class SplitOp: static_split_point = _get_int64_attr(ShapedType.get_dynamic_size()) dynamic_split_point = _get_op_result_or_value(split_point) - pdl_operation_type = pdl.OperationType.get() + target = _get_op_result_or_value(target) + super().__init__( - pdl_operation_type, - pdl_operation_type, - _get_op_result_or_value(target), + target.type, + target.type, + target, dimension=dimension, static_split_point=static_split_point, dynamic_split_point=dynamic_split_point, @@ -238,7 +240,9 @@ class SplitOp: class TileOp: """Specialization for TileOp class.""" + @overload def __init__(self, + loop_types: Union[Type, List[Type]], target: Union[Operation, Value], *, sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation, @@ -246,9 +250,28 @@ class TileOp: interchange: OptionalIntList = None, loc=None, ip=None): - pdl_operation_type = pdl.OperationType.get() - i64_type = IntegerType.get_signless(64) + ... + @overload + def __init__(self, + target: Union[Operation, Value], + *, + sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation, + Value]], ArrayAttr]] = None, + interchange: OptionalIntList = None, + loc=None, + ip=None): + ... + + def __init__(self, + loop_types_or_target: Union[Type, List[Type], Operation, Value], + target_or_none: Optional[Union[Operation, Value]] = None, + *, + sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation, + Value]], ArrayAttr]] = None, + interchange: OptionalIntList = None, + loc=None, + ip=None): if sizes is None: sizes = [] @@ -267,12 +290,26 @@ class TileOp: num_loops = sum( v if v == 0 else 1 for v in self.__extract_values(sizes_attr)) + + if isinstance(loop_types_or_target, (Operation, Value)): + loop_types = [transform.AnyOpType.get()] * num_loops + target = loop_types_or_target + assert target_or_none is None, "Cannot construct TileOp with two targets." + else: + loop_types = ([loop_types_or_target] * num_loops) if isinstance( + loop_types_or_target, Type) else loop_types_or_target + target = target_or_none + + target = _get_op_result_or_value(target) + super().__init__( - pdl_operation_type, [pdl_operation_type] * num_loops, - _get_op_result_or_value(target), + target.type, + loop_types, + target, dynamic_sizes=dynamic_sizes, static_sizes=sizes_attr, - interchange=_get_dense_int64_array_attr(interchange) if interchange else None, + interchange=_get_dense_int64_array_attr(interchange) + if interchange else None, loc=loc, ip=ip) |