[Mlir-commits] [mlir] cc4fb58 - [mlir] support complex type in DenseElementsAttr::get.

Xiang Li llvmlistbot at llvm.org
Mon Feb 13 06:41:30 PST 2023


Author: Xiang Li
Date: 2023-02-13T09:35:19-05:00
New Revision: cc4fb5837647d913e1c1df5ff398be2896ea6d07

URL: https://github.com/llvm/llvm-project/commit/cc4fb5837647d913e1c1df5ff398be2896ea6d07
DIFF: https://github.com/llvm/llvm-project/commit/cc4fb5837647d913e1c1df5ff398be2896ea6d07.diff

LOG: [mlir] support complex type in DenseElementsAttr::get.

Fixes #60662 https://github.com/llvm/llvm-project/issues/60662

Allow ComplexType element types when creating a DenseElementsAttr.
Also allow building a ConstantOp for integer complex types.

Differential Revision: https://reviews.llvm.org/D143848

Added: 
    

Modified: 
    mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
    mlir/lib/IR/BuiltinAttributes.cpp
    mlir/test/Dialect/Tensor/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
index c71e2e08b6ff..28e121b6026a 100644
--- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
@@ -32,10 +32,16 @@ bool ConstantOp::isBuildableWith(Attribute value, Type type) {
     if (!complexTy || arrAttr.size() != 2)
       return false;
     auto complexEltTy = complexTy.getElementType();
-    auto re = arrAttr[0].dyn_cast<FloatAttr>();
-    auto im = arrAttr[1].dyn_cast<FloatAttr>();
-    return re && im && re.getType() == complexEltTy &&
-           im.getType() == complexEltTy;
+    if (auto fre = arrAttr[0].dyn_cast<FloatAttr>()) {
+      auto im = arrAttr[1].dyn_cast<FloatAttr>();
+      return im && fre.getType() == complexEltTy &&
+             im.getType() == complexEltTy;
+    }
+    if (auto ire = arrAttr[0].dyn_cast<IntegerAttr>()) {
+      auto im = arrAttr[1].dyn_cast<IntegerAttr>();
+      return im && ire.getType() == complexEltTy &&
+             im.getType() == complexEltTy;
+    }
   }
   return false;
 }

diff  --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index ff4aa65fc888..b99ec22999fc 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -891,9 +891,44 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
                                          ArrayRef<Attribute> values) {
   assert(hasSameElementsOrSplat(type, values));
 
+  Type eltType = type.getElementType();
+
+  // Take care complex type case first.
+  if (auto complexType = eltType.dyn_cast<ComplexType>()) {
+    if (complexType.getElementType().isIntOrIndex()) {
+      SmallVector<std::complex<APInt>> complexValues;
+      complexValues.reserve(values.size());
+      for (Attribute attr : values) {
+        assert(attr.isa<ArrayAttr>() &&
+               "expected ArrayAttr for complex");
+        auto arrayAttr = attr.cast<ArrayAttr>();
+        assert(arrayAttr.size() == 2 && "expected 2 element for complex");
+        auto attr0 = arrayAttr[0];
+        auto attr1 = arrayAttr[1];
+        complexValues.push_back(
+            std::complex<APInt>(attr0.cast<IntegerAttr>().getValue(),
+                                attr1.cast<IntegerAttr>().getValue()));
+      }
+      return DenseElementsAttr::get(type, complexValues);
+    }
+    // Must be float.
+    SmallVector<std::complex<APFloat>> complexValues;
+    complexValues.reserve(values.size());
+    for (Attribute attr : values) {
+      assert(attr.isa<ArrayAttr>() && "expected ArrayAttr for complex");
+      auto arrayAttr = attr.cast<ArrayAttr>();
+      assert(arrayAttr.size() == 2 && "expected 2 element for complex");
+      auto attr0 = arrayAttr[0];
+      auto attr1 = arrayAttr[1];
+      complexValues.push_back(
+          std::complex<APFloat>(attr0.cast<FloatAttr>().getValue(),
+                                attr1.cast<FloatAttr>().getValue()));
+    }
+    return DenseElementsAttr::get(type, complexValues);
+  }
+
   // If the element type is not based on int/float/index, assume it is a string
   // type.
-  Type eltType = type.getElementType();
   if (!eltType.isIntOrIndexOrFloat()) {
     SmallVector<StringRef, 8> stringValues;
     stringValues.reserve(values.size());

diff  --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 6b7280453f97..2c0b87178b01 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -207,6 +207,34 @@ func.func @extract_from_tensor.from_elements_3d()
 
 // -----
 
+// CHECK-LABEL: func.func @extract_from_elements_complex_i() -> tensor<3xcomplex<i32>> {
+// CHECK-NEXT:  %cst = arith.constant dense<[(1,2), (3,2), (1,2)]> : tensor<3xcomplex<i32>>
+// CHECK-NEXT:  return %cst : tensor<3xcomplex<i32>>
+func.func @extract_from_elements_complex_i() -> tensor<3xcomplex<i32>> {
+  %c1 = arith.constant dense<(1, 2)> : tensor<complex<i32>>
+  %complex1 = tensor.extract %c1[] : tensor<complex<i32>>
+  %c2 = arith.constant dense<(3, 2)> : tensor<complex<i32>>
+  %complex2 = tensor.extract %c2[] : tensor<complex<i32>>
+  %tensor = tensor.from_elements %complex1, %complex2, %complex1 : tensor<3xcomplex<i32>>
+  return %tensor : tensor<3xcomplex<i32>>
+}
+
+// -----
+
+// CHECK-LABEL:  func.func @extract_from_elements_complex_f() -> tensor<3xcomplex<f32>> {
+// CHECK-NEXT:   %cst = arith.constant dense<[(1.200000e+00,2.300000e+00), (3.200000e+00,2.100000e+00), (1.200000e+00,2.300000e+00)]> : tensor<3xcomplex<f32>>
+// CHECK-NEXT:   return %cst : tensor<3xcomplex<f32>>
+func.func @extract_from_elements_complex_f() -> tensor<3xcomplex<f32>> {
+  %c1 = arith.constant dense<(1.2, 2.3)> : tensor<complex<f32>>
+  %complex1 = tensor.extract %c1[] : tensor<complex<f32>>
+  %c2 = arith.constant dense<(3.2, 2.1)> : tensor<complex<f32>>
+  %complex2 = tensor.extract %c2[] : tensor<complex<f32>>
+  %tensor = tensor.from_elements %complex1, %complex2, %complex1 : tensor<3xcomplex<f32>>
+  return %tensor : tensor<3xcomplex<f32>>
+}
+
+// -----
+
 // Ensure the optimization doesn't segfault from bad constants
 // CHECK-LABEL: func @extract_negative_from_tensor.from_elements
 func.func @extract_negative_from_tensor.from_elements(%element : index) -> index {


        


More information about the Mlir-commits mailing list