diff options
author | Jacques Pienaar <jpienaar@google.com> | 2022-12-21 10:10:31 -0800 |
---|---|---|
committer | Jacques Pienaar <jpienaar@google.com> | 2022-12-21 10:10:31 -0800 |
commit | 3781b7905d8d808e5d4e97d597263f8ac48541b8 (patch) | |
tree | 20442b66107661aa1a1db33ec328dc2c9d5ba4dd /mlir/python | |
parent | 2b60ed405b8110b20ab2e383839759ea34003127 (diff) | |
download | llvm-3781b7905d8d808e5d4e97d597263f8ac48541b8.tar.gz |
[mlir][py] Enable building ops with raw inputs
For cases where we can automatically construct the Attribute allow for more
user-friendly input. This is consistent with C++ builder generation as well
choice of which single builder to generate here (most
specialized/user-friendly).
Registration of attribute builders from more pythonic input is all Python side.
The downside is that
* extra checking to see if user provided a custom builder in op builders,
* the ODS attribute name is load bearing
upside is that
* easily change these/register dialect specific ones in downstream projects,
* adding support/changing to different convenience builders are all along with
the rest of the convenience functions in Python (and no additional changes
to tablegen file or recompilation needed);
Allow for both building with Attributes as well as raw inputs. This change
should therefore be backwards compatible as well as allow for avoiding
recreating Attribute where already available.
Differential Revision: https://reviews.llvm.org/D139568
Diffstat (limited to 'mlir/python')
-rw-r--r-- | mlir/python/mlir/ir.py | 41 |
1 files changed, 41 insertions, 0 deletions
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py index 99e88ff74384..19986917d69b 100644 --- a/mlir/python/mlir/ir.py +++ b/mlir/python/mlir/ir.py @@ -4,3 +4,44 @@ from ._mlir_libs._mlir.ir import * from ._mlir_libs._mlir.ir import _GlobalDebug + + +# Convenience decorator for registering user-friendly Attribute builders. +def register_attribute_builder(kind): + def decorator_builder(func): + AttrBuilder.insert(kind, func) + return func + return decorator_builder + + +@register_attribute_builder("BoolAttr") +def _boolAttr(x: bool, context: Context): + return BoolAttr.get(x, context=context) + +@register_attribute_builder("IndexAttr") +def _indexAttr(x: int, context: Context): + return IntegerAttr.get(IndexType.get(context=context), x) + +@register_attribute_builder("I32Attr") +def _i32Attr(x: int, context: Context): + return IntegerAttr.get( + IntegerType.get_signless(32, context=context), x) + +@register_attribute_builder("I64Attr") +def _i64Attr(x: int, context: Context): + return IntegerAttr.get( + IntegerType.get_signless(64, context=context), x) + +@register_attribute_builder("SymbolNameAttr") +def _symbolNameAttr(x: str, context: Context): + return StringAttr.get(x, context=context) + +try: + import numpy as np + @register_attribute_builder("IndexElementsAttr") + def _indexElementsAttr(x: list[int], context: Context): + return DenseElementsAttr.get( + np.array(x, dtype=np.int64), type=IndexType.get(context=context), + context=context) +except ImportError: + pass |