diff options
author | Jacques Pienaar <jpienaar@google.com> | 2022-12-21 16:22:39 -0800 |
---|---|---|
committer | Jacques Pienaar <jpienaar@google.com> | 2022-12-21 16:22:39 -0800 |
commit | b57acb9a405c289069345a498ebfc1d1b9b110de (patch) | |
tree | 384faebf8c4f2c0e4ef4ce2e25ffbbabf5bf7eea /mlir/python | |
parent | 02f4cfa33d3d970b8ad3ddeda73f59785ab19984 (diff) | |
download | llvm-b57acb9a405c289069345a498ebfc1d1b9b110de.tar.gz |
Revert "Revert "[mlir][py] Enable building ops with raw inputs""
Fix Python 3.6.9 issue encountered due to type checking here. Will
add back in follow up.
This reverts commit 1f47fee2948ef48781084afe0426171d000d7997.
Diffstat (limited to 'mlir/python')
-rw-r--r-- | mlir/python/mlir/ir.py | 41 |
1 files changed, 41 insertions, 0 deletions
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py index 99e88ff74384..82468e8b76b4 100644 --- a/mlir/python/mlir/ir.py +++ b/mlir/python/mlir/ir.py @@ -4,3 +4,44 @@ from ._mlir_libs._mlir.ir import * from ._mlir_libs._mlir.ir import _GlobalDebug + + +# Convenience decorator for registering user-friendly Attribute builders. +def register_attribute_builder(kind): + def decorator_builder(func): + AttrBuilder.insert(kind, func) + return func + return decorator_builder + + +@register_attribute_builder("BoolAttr") +def _boolAttr(x, context): + return BoolAttr.get(x, context=context) + +@register_attribute_builder("IndexAttr") +def _indexAttr(x, context): + return IntegerAttr.get(IndexType.get(context=context), x) + +@register_attribute_builder("I32Attr") +def _i32Attr(x, context): + return IntegerAttr.get( + IntegerType.get_signless(32, context=context), x) + +@register_attribute_builder("I64Attr") +def _i64Attr(x, context): + return IntegerAttr.get( + IntegerType.get_signless(64, context=context), x) + +@register_attribute_builder("SymbolNameAttr") +def _symbolNameAttr(x, context): + return StringAttr.get(x, context=context) + +try: + import numpy as np + @register_attribute_builder("IndexElementsAttr") + def _indexElementsAttr(x, context): + return DenseElementsAttr.get( + np.array(x, dtype=np.int64), type=IndexType.get(context=context), + context=context) +except ImportError: + pass |