[Mlir-commits] [mlir] [mlir][complex] Allow integer element types in `complex.constant` ops (PR #74564)

Matthias Springer llvmlistbot at llvm.org
Tue Dec 5 22:11:53 PST 2023


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/74564

The op used to support only float element types. This was inconsistent with `ConstantOp::isBuildableWith`, which allows integer element types. The complex type allows any float/integer element type.

Note: The other complex dialect ops do not support non-float element types yet. The purpose of this change is to fix `Tensor/canonicalize.mlir`, which is currently failing when verifying the IR after each pattern application (#74270).

```
within split at mlir/test/Dialect/Tensor/canonicalize.mlir:231 offset :8:15: error: 'complex.constant' op result #0 must be complex type with floating-point elements, but got 'complex<i32>'
  %complex1 = tensor.extract %c1[] : tensor<complex<i32>>
              ^
within split at mlir/test/Dialect/Tensor/canonicalize.mlir:231 offset :8:15: note: see current operation: %0 = "complex.constant"() <{value = [1 : i32, 2 : i32]}> : () -> complex<i32>
"func.func"() <{function_type = () -> tensor<3xcomplex<i32>>, sym_name = "extract_from_elements_complex_i"}> ({
  %0 = "complex.constant"() <{value = [1 : i32, 2 : i32]}> : () -> complex<i32>
  %1 = "arith.constant"() <{value = dense<(3,2)> : tensor<complex<i32>>}> : () -> tensor<complex<i32>>
  %2 = "arith.constant"() <{value = dense<(1,2)> : tensor<complex<i32>>}> : () -> tensor<complex<i32>>
  %3 = "tensor.extract"(%1) : (tensor<complex<i32>>) -> complex<i32>
  %4 = "tensor.from_elements"(%0, %3, %0) : (complex<i32>, complex<i32>, complex<i32>) -> tensor<3xcomplex<i32>>
  "func.return"(%4) : (tensor<3xcomplex<i32>>) -> ()
}) : () -> ()
```

From b1904a55a910dbe1e9f97302489431fbb5fad691 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Wed, 6 Dec 2023 15:07:36 +0900
Subject: [PATCH] [mlir][complex] Allow integer element types in
 `complex.constant` ops

---
 mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td |  2 +-
 mlir/lib/Dialect/Complex/IR/ComplexOps.cpp         | 10 ++++++----
 mlir/test/Dialect/Complex/ops.mlir                 |  3 +++
 3 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
index ada6c14b5b713..e19d714cadf8a 100644
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -145,7 +145,7 @@ def ConstantOp : Complex_Op<"constant", [
   }];
 
   let arguments = (ins ArrayAttr:$value);
-  let results = (outs Complex<AnyFloat>:$complex);
+  let results = (outs AnyComplex:$complex);
 
   let assemblyFormat = "$value attr-dict `:` type($complex)";
   let hasFolder = 1;
diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
index 8fd914dd107ff..0557de65ff43c 100644
--- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
@@ -58,10 +58,12 @@ LogicalResult ConstantOp::verify() {
   }
 
   auto complexEltTy = getType().getElementType();
-  auto re = llvm::dyn_cast<FloatAttr>(arrayAttr[0]);
-  auto im = llvm::dyn_cast<FloatAttr>(arrayAttr[1]);
-  if (!re || !im)
-    return emitOpError("requires attribute's elements to be float attributes");
+  if (!isa<FloatAttr, IntegerAttr>(arrayAttr[0]) ||
+      !isa<FloatAttr, IntegerAttr>(arrayAttr[1]))
+    return emitOpError(
+        "requires attribute's elements to be float or integer attributes");
+  auto re = llvm::dyn_cast<TypedAttr>(arrayAttr[0]);
+  auto im = llvm::dyn_cast<TypedAttr>(arrayAttr[1]);
   if (complexEltTy != re.getType() || complexEltTy != im.getType()) {
     return emitOpError()
            << "requires attribute's element types (" << re.getType() << ", "
diff --git a/mlir/test/Dialect/Complex/ops.mlir b/mlir/test/Dialect/Complex/ops.mlir
index 1050ad0dcd530..96f17b2898c83 100644
--- a/mlir/test/Dialect/Complex/ops.mlir
+++ b/mlir/test/Dialect/Complex/ops.mlir
@@ -11,6 +11,9 @@ func.func @ops(%f: f32) {
   // CHECK: complex.constant [1.{{.*}} : f32, -1.{{.*}} : f32] : complex<f32>
   %cst_f32 = complex.constant [0.1 : f32, -1.0 : f32] : complex<f32>
 
+  // CHECK: complex.constant [true, false] : complex<i1>
+  %cst_i1 = complex.constant [1 : i1, 0 : i1] : complex<i1>
+
   // CHECK: %[[C:.*]] = complex.create %[[F]], %[[F]] : complex<f32>
   %complex = complex.create %f, %f : complex<f32>
 