[Mlir-commits] [mlir] cc4fb58 - [mlir] support complex type in DenseElementsAttr::get.
Xiang Li
llvmlistbot at llvm.org
Mon Feb 13 06:41:30 PST 2023
Author: Xiang Li
Date: 2023-02-13T09:35:19-05:00
New Revision: cc4fb5837647d913e1c1df5ff398be2896ea6d07
URL: https://github.com/llvm/llvm-project/commit/cc4fb5837647d913e1c1df5ff398be2896ea6d07
DIFF: https://github.com/llvm/llvm-project/commit/cc4fb5837647d913e1c1df5ff398be2896ea6d07.diff
LOG: [mlir] support complex type in DenseElementsAttr::get.
Fixes #60662 https://github.com/llvm/llvm-project/issues/60662
Allow ComplexType element types when creating a DenseElementsAttr.
Also allow building a complex ConstantOp with integer element types.
Differential Revision: https://reviews.llvm.org/D143848
Added:
Modified:
mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
mlir/lib/IR/BuiltinAttributes.cpp
mlir/test/Dialect/Tensor/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
index c71e2e08b6ff..28e121b6026a 100644
--- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
@@ -32,10 +32,16 @@ bool ConstantOp::isBuildableWith(Attribute value, Type type) {
if (!complexTy || arrAttr.size() != 2)
return false;
auto complexEltTy = complexTy.getElementType();
- auto re = arrAttr[0].dyn_cast<FloatAttr>();
- auto im = arrAttr[1].dyn_cast<FloatAttr>();
- return re && im && re.getType() == complexEltTy &&
- im.getType() == complexEltTy;
+ if (auto fre = arrAttr[0].dyn_cast<FloatAttr>()) {
+ auto im = arrAttr[1].dyn_cast<FloatAttr>();
+ return im && fre.getType() == complexEltTy &&
+ im.getType() == complexEltTy;
+ }
+ if (auto ire = arrAttr[0].dyn_cast<IntegerAttr>()) {
+ auto im = arrAttr[1].dyn_cast<IntegerAttr>();
+ return im && ire.getType() == complexEltTy &&
+ im.getType() == complexEltTy;
+ }
}
return false;
}
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index ff4aa65fc888..b99ec22999fc 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -891,9 +891,44 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
ArrayRef<Attribute> values) {
assert(hasSameElementsOrSplat(type, values));
+ Type eltType = type.getElementType();
+
+ // Take care complex type case first.
+ if (auto complexType = eltType.dyn_cast<ComplexType>()) {
+ if (complexType.getElementType().isIntOrIndex()) {
+ SmallVector<std::complex<APInt>> complexValues;
+ complexValues.reserve(values.size());
+ for (Attribute attr : values) {
+ assert(attr.isa<ArrayAttr>() &&
+ "expected ArrayAttr for complex");
+ auto arrayAttr = attr.cast<ArrayAttr>();
+ assert(arrayAttr.size() == 2 && "expected 2 element for complex");
+ auto attr0 = arrayAttr[0];
+ auto attr1 = arrayAttr[1];
+ complexValues.push_back(
+ std::complex<APInt>(attr0.cast<IntegerAttr>().getValue(),
+ attr1.cast<IntegerAttr>().getValue()));
+ }
+ return DenseElementsAttr::get(type, complexValues);
+ }
+ // Must be float.
+ SmallVector<std::complex<APFloat>> complexValues;
+ complexValues.reserve(values.size());
+ for (Attribute attr : values) {
+ assert(attr.isa<ArrayAttr>() && "expected ArrayAttr for complex");
+ auto arrayAttr = attr.cast<ArrayAttr>();
+ assert(arrayAttr.size() == 2 && "expected 2 element for complex");
+ auto attr0 = arrayAttr[0];
+ auto attr1 = arrayAttr[1];
+ complexValues.push_back(
+ std::complex<APFloat>(attr0.cast<FloatAttr>().getValue(),
+ attr1.cast<FloatAttr>().getValue()));
+ }
+ return DenseElementsAttr::get(type, complexValues);
+ }
+
// If the element type is not based on int/float/index, assume it is a string
// type.
- Type eltType = type.getElementType();
if (!eltType.isIntOrIndexOrFloat()) {
SmallVector<StringRef, 8> stringValues;
stringValues.reserve(values.size());
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 6b7280453f97..2c0b87178b01 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -207,6 +207,34 @@ func.func @extract_from_tensor.from_elements_3d()
// -----
+// CHECK-LABEL: func.func @extract_from_elements_complex_i() -> tensor<3xcomplex<i32>> {
+// CHECK-NEXT: %cst = arith.constant dense<[(1,2), (3,2), (1,2)]> : tensor<3xcomplex<i32>>
+// CHECK-NEXT: return %cst : tensor<3xcomplex<i32>>
+func.func @extract_from_elements_complex_i() -> tensor<3xcomplex<i32>> {
+ %c1 = arith.constant dense<(1, 2)> : tensor<complex<i32>>
+ %complex1 = tensor.extract %c1[] : tensor<complex<i32>>
+ %c2 = arith.constant dense<(3, 2)> : tensor<complex<i32>>
+ %complex2 = tensor.extract %c2[] : tensor<complex<i32>>
+ %tensor = tensor.from_elements %complex1, %complex2, %complex1 : tensor<3xcomplex<i32>>
+ return %tensor : tensor<3xcomplex<i32>>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @extract_from_elements_complex_f() -> tensor<3xcomplex<f32>> {
+// CHECK-NEXT: %cst = arith.constant dense<[(1.200000e+00,2.300000e+00), (3.200000e+00,2.100000e+00), (1.200000e+00,2.300000e+00)]> : tensor<3xcomplex<f32>>
+// CHECK-NEXT: return %cst : tensor<3xcomplex<f32>>
+func.func @extract_from_elements_complex_f() -> tensor<3xcomplex<f32>> {
+ %c1 = arith.constant dense<(1.2, 2.3)> : tensor<complex<f32>>
+ %complex1 = tensor.extract %c1[] : tensor<complex<f32>>
+ %c2 = arith.constant dense<(3.2, 2.1)> : tensor<complex<f32>>
+ %complex2 = tensor.extract %c2[] : tensor<complex<f32>>
+ %tensor = tensor.from_elements %complex1, %complex2, %complex1 : tensor<3xcomplex<f32>>
+ return %tensor : tensor<3xcomplex<f32>>
+}
+
+// -----
+
// Ensure the optimization doesn't segfault from bad constants
// CHECK-LABEL: func @extract_negative_from_tensor.from_elements
func.func @extract_negative_from_tensor.from_elements(%element : index) -> index {
More information about the Mlir-commits mailing list