diff options
author | Xiang Li <python3kgae@outlook.com> | 2023-02-12 12:19:17 -0500 |
---|---|---|
committer | Xiang Li <python3kgae@outlook.com> | 2023-02-13 09:35:19 -0500 |
commit | cc4fb5837647d913e1c1df5ff398be2896ea6d07 (patch) | |
tree | 81907e0c4c3e685cc8c09299c56f5cce42767f28 /mlir/lib/IR | |
parent | e9eaee9da196265d20dbeaf7920c24ccb33e2d04 (diff) | |
download | llvm-cc4fb5837647d913e1c1df5ff398be2896ea6d07.tar.gz |
[mlir] support complex type in DenseElementsAttr::get.
Fixes #60662 https://github.com/llvm/llvm-project/issues/60662
Allow ComplexType when create DenseElementsAttr.
Also allow build ConstantOp for integer complex.
Differential Revision: https://reviews.llvm.org/D143848
Diffstat (limited to 'mlir/lib/IR')
-rw-r--r-- | mlir/lib/IR/BuiltinAttributes.cpp | 37 |
1 files changed, 36 insertions, 1 deletions
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp index ff4aa65fc888..b99ec22999fc 100644 --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -891,9 +891,44 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type, ArrayRef<Attribute> values) { assert(hasSameElementsOrSplat(type, values)); + Type eltType = type.getElementType(); + + // Take care complex type case first. + if (auto complexType = eltType.dyn_cast<ComplexType>()) { + if (complexType.getElementType().isIntOrIndex()) { + SmallVector<std::complex<APInt>> complexValues; + complexValues.reserve(values.size()); + for (Attribute attr : values) { + assert(attr.isa<ArrayAttr>() && + "expected ArrayAttr for complex"); + auto arrayAttr = attr.cast<ArrayAttr>(); + assert(arrayAttr.size() == 2 && "expected 2 element for complex"); + auto attr0 = arrayAttr[0]; + auto attr1 = arrayAttr[1]; + complexValues.push_back( + std::complex<APInt>(attr0.cast<IntegerAttr>().getValue(), + attr1.cast<IntegerAttr>().getValue())); + } + return DenseElementsAttr::get(type, complexValues); + } + // Must be float. + SmallVector<std::complex<APFloat>> complexValues; + complexValues.reserve(values.size()); + for (Attribute attr : values) { + assert(attr.isa<ArrayAttr>() && "expected ArrayAttr for complex"); + auto arrayAttr = attr.cast<ArrayAttr>(); + assert(arrayAttr.size() == 2 && "expected 2 element for complex"); + auto attr0 = arrayAttr[0]; + auto attr1 = arrayAttr[1]; + complexValues.push_back( + std::complex<APFloat>(attr0.cast<FloatAttr>().getValue(), + attr1.cast<FloatAttr>().getValue())); + } + return DenseElementsAttr::get(type, complexValues); + } + // If the element type is not based on int/float/index, assume it is a string // type. - Type eltType = type.getElementType(); if (!eltType.isIntOrIndexOrFloat()) { SmallVector<StringRef, 8> stringValues; stringValues.reserve(values.size()); |