summaryrefslogtreecommitdiff
path: root/mlir/python
diff options
context:
space:
mode:
authorJacques Pienaar <jpienaar@google.com>2022-12-21 16:22:39 -0800
committerJacques Pienaar <jpienaar@google.com>2022-12-21 16:22:39 -0800
commitb57acb9a405c289069345a498ebfc1d1b9b110de (patch)
tree384faebf8c4f2c0e4ef4ce2e25ffbbabf5bf7eea /mlir/python
parent02f4cfa33d3d970b8ad3ddeda73f59785ab19984 (diff)
downloadllvm-b57acb9a405c289069345a498ebfc1d1b9b110de.tar.gz
Revert "Revert "[mlir][py] Enable building ops with raw inputs""
Fix Python 3.6.9 issue encountered due to type checking here. Will add back in follow up. This reverts commit 1f47fee2948ef48781084afe0426171d000d7997.
Diffstat (limited to 'mlir/python')
-rw-r--r--mlir/python/mlir/ir.py41
1 files changed, 41 insertions, 0 deletions
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index 99e88ff74384..82468e8b76b4 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -4,3 +4,44 @@
from ._mlir_libs._mlir.ir import *
from ._mlir_libs._mlir.ir import _GlobalDebug
+
+
+# Convenience decorator for registering user-friendly Attribute builders.
+def register_attribute_builder(kind):
+ def decorator_builder(func):
+ AttrBuilder.insert(kind, func)
+ return func
+ return decorator_builder
+
+
+@register_attribute_builder("BoolAttr")
+def _boolAttr(x, context):
+ return BoolAttr.get(x, context=context)
+
+@register_attribute_builder("IndexAttr")
+def _indexAttr(x, context):
+ return IntegerAttr.get(IndexType.get(context=context), x)
+
+@register_attribute_builder("I32Attr")
+def _i32Attr(x, context):
+ return IntegerAttr.get(
+ IntegerType.get_signless(32, context=context), x)
+
+@register_attribute_builder("I64Attr")
+def _i64Attr(x, context):
+ return IntegerAttr.get(
+ IntegerType.get_signless(64, context=context), x)
+
+@register_attribute_builder("SymbolNameAttr")
+def _symbolNameAttr(x, context):
+ return StringAttr.get(x, context=context)
+
+try:
+ import numpy as np
+ @register_attribute_builder("IndexElementsAttr")
+ def _indexElementsAttr(x, context):
+ return DenseElementsAttr.get(
+ np.array(x, dtype=np.int64), type=IndexType.get(context=context),
+ context=context)
+except ImportError:
+ pass