summaryrefslogtreecommitdiff
path: root/mlir/lib/IR
diff options
context:
space:
mode:
authorXiang Li <python3kgae@outlook.com>2023-02-12 12:19:17 -0500
committerXiang Li <python3kgae@outlook.com>2023-02-13 09:35:19 -0500
commitcc4fb5837647d913e1c1df5ff398be2896ea6d07 (patch)
tree81907e0c4c3e685cc8c09299c56f5cce42767f28 /mlir/lib/IR
parente9eaee9da196265d20dbeaf7920c24ccb33e2d04 (diff)
downloadllvm-cc4fb5837647d913e1c1df5ff398be2896ea6d07.tar.gz
[mlir] support complex type in DenseElementsAttr::get.
Fixes #60662 https://github.com/llvm/llvm-project/issues/60662 Allow ComplexType when create DenseElementsAttr. Also allow build ConstantOp for integer complex. Differential Revision: https://reviews.llvm.org/D143848
Diffstat (limited to 'mlir/lib/IR')
-rw-r--r--mlir/lib/IR/BuiltinAttributes.cpp37
1 files changed, 36 insertions, 1 deletions
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index ff4aa65fc888..b99ec22999fc 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -891,9 +891,44 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
ArrayRef<Attribute> values) {
assert(hasSameElementsOrSplat(type, values));
+ Type eltType = type.getElementType();
+
+ // Take care complex type case first.
+ if (auto complexType = eltType.dyn_cast<ComplexType>()) {
+ if (complexType.getElementType().isIntOrIndex()) {
+ SmallVector<std::complex<APInt>> complexValues;
+ complexValues.reserve(values.size());
+ for (Attribute attr : values) {
+ assert(attr.isa<ArrayAttr>() &&
+ "expected ArrayAttr for complex");
+ auto arrayAttr = attr.cast<ArrayAttr>();
+ assert(arrayAttr.size() == 2 && "expected 2 element for complex");
+ auto attr0 = arrayAttr[0];
+ auto attr1 = arrayAttr[1];
+ complexValues.push_back(
+ std::complex<APInt>(attr0.cast<IntegerAttr>().getValue(),
+ attr1.cast<IntegerAttr>().getValue()));
+ }
+ return DenseElementsAttr::get(type, complexValues);
+ }
+ // Must be float.
+ SmallVector<std::complex<APFloat>> complexValues;
+ complexValues.reserve(values.size());
+ for (Attribute attr : values) {
+ assert(attr.isa<ArrayAttr>() && "expected ArrayAttr for complex");
+ auto arrayAttr = attr.cast<ArrayAttr>();
+ assert(arrayAttr.size() == 2 && "expected 2 element for complex");
+ auto attr0 = arrayAttr[0];
+ auto attr1 = arrayAttr[1];
+ complexValues.push_back(
+ std::complex<APFloat>(attr0.cast<FloatAttr>().getValue(),
+ attr1.cast<FloatAttr>().getValue()));
+ }
+ return DenseElementsAttr::get(type, complexValues);
+ }
+
// If the element type is not based on int/float/index, assume it is a string
// type.
- Type eltType = type.getElementType();
if (!eltType.isIntOrIndexOrFloat()) {
SmallVector<StringRef, 8> stringValues;
stringValues.reserve(values.size());