summaryrefslogtreecommitdiff
path: root/mlir/python
diff options
context:
space:
mode:
authorJacques Pienaar <jpienaar@google.com>2022-12-21 10:10:31 -0800
committerJacques Pienaar <jpienaar@google.com>2022-12-21 10:10:31 -0800
commit3781b7905d8d808e5d4e97d597263f8ac48541b8 (patch)
tree20442b66107661aa1a1db33ec328dc2c9d5ba4dd /mlir/python
parent2b60ed405b8110b20ab2e383839759ea34003127 (diff)
downloadllvm-3781b7905d8d808e5d4e97d597263f8ac48541b8.tar.gz
[mlir][py] Enable building ops with raw inputs
For cases where we can automatically construct the Attribute allow for more user-friendly input. This is consistent with C++ builder generation as well choice of which single builder to generate here (most specialized/user-friendly). Registration of attribute builders from more pythonic input is all Python side. The downside is that * extra checking to see if user provided a custom builder in op builders, * the ODS attribute name is load bearing upside is that * easily change these/register dialect specific ones in downstream projects, * adding support/changing to different convenience builders are all along with the rest of the convenience functions in Python (and no additional changes to tablegen file or recompilation needed); Allow for both building with Attributes as well as raw inputs. This change should therefore be backwards compatible as well as allow for avoiding recreating Attribute where already available. Differential Revision: https://reviews.llvm.org/D139568
Diffstat (limited to 'mlir/python')
-rw-r--r--mlir/python/mlir/ir.py41
1 files changed, 41 insertions, 0 deletions
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index 99e88ff74384..19986917d69b 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -4,3 +4,44 @@
from ._mlir_libs._mlir.ir import *
from ._mlir_libs._mlir.ir import _GlobalDebug
+
+
+# Convenience decorator for registering user-friendly Attribute builders.
+def register_attribute_builder(kind):
+ def decorator_builder(func):
+ AttrBuilder.insert(kind, func)
+ return func
+ return decorator_builder
+
+
+@register_attribute_builder("BoolAttr")
+def _boolAttr(x: bool, context: Context):
+ return BoolAttr.get(x, context=context)
+
+@register_attribute_builder("IndexAttr")
+def _indexAttr(x: int, context: Context):
+ return IntegerAttr.get(IndexType.get(context=context), x)
+
+@register_attribute_builder("I32Attr")
+def _i32Attr(x: int, context: Context):
+ return IntegerAttr.get(
+ IntegerType.get_signless(32, context=context), x)
+
+@register_attribute_builder("I64Attr")
+def _i64Attr(x: int, context: Context):
+ return IntegerAttr.get(
+ IntegerType.get_signless(64, context=context), x)
+
+@register_attribute_builder("SymbolNameAttr")
+def _symbolNameAttr(x: str, context: Context):
+ return StringAttr.get(x, context=context)
+
+try:
+ import numpy as np
+ @register_attribute_builder("IndexElementsAttr")
+ def _indexElementsAttr(x: list[int], context: Context):
+ return DenseElementsAttr.get(
+ np.array(x, dtype=np.int64), type=IndexType.get(context=context),
+ context=context)
+except ImportError:
+ pass