[Mlir-commits] [clang] [clang-tools-extra] [llvm] [flang] [mlir] [lldb] [mlir][complex] Prevent underflow in complex.abs (PR #76316)

Kai Sasaki llvmlistbot at llvm.org
Thu Jan 25 16:56:01 PST 2024


https://github.com/Lewuathe updated https://github.com/llvm/llvm-project/pull/76316

>From a5810363e546da073543cb2d62cceb956c46b2e6 Mon Sep 17 00:00:00 2001
From: Kai Sasaki <lewuathe at gmail.com>
Date: Fri, 15 Dec 2023 15:53:54 +0900
Subject: [PATCH 1/2] [mlir][complex] Prevent underflow in complex.abs

---
 .../ComplexToStandard/ComplexToStandard.cpp   |  56 ++++++---
 .../convert-to-standard.mlir                  | 115 ++++++++++++++----
 .../ComplexToStandard/full-conversion.mlir    |  25 +++-
 .../Dialect/Complex/CPU/correctness.mlir      |  38 ++++++
 4 files changed, 194 insertions(+), 40 deletions(-)

diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index bf753c7062f366..7c1db57b55f996 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -26,29 +26,57 @@ namespace mlir {
 using namespace mlir;
 
 namespace {
+// The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780.
 struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
   using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto loc = op.getLoc();
-    auto type = op.getType();
+    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
 
     arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
 
-    Value real =
-        rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex());
-    Value imag =
-        rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex());
-    Value realSqr =
-        rewriter.create<arith::MulFOp>(loc, real, real, fmf.getValue());
-    Value imagSqr =
-        rewriter.create<arith::MulFOp>(loc, imag, imag, fmf.getValue());
-    Value sqNorm =
-        rewriter.create<arith::AddFOp>(loc, realSqr, imagSqr, fmf.getValue());
-
-    rewriter.replaceOpWithNewOp<math::SqrtOp>(op, sqNorm);
+    Type elementType = op.getType();
+    Value arg = adaptor.getComplex();
+
+    Value zero =
+        b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
+    Value one = b.create<arith::ConstantOp>(elementType,
+                                            b.getFloatAttr(elementType, 1.0));
+
+    Value real = b.create<complex::ReOp>(elementType, arg);
+    Value imag = b.create<complex::ImOp>(elementType, arg);
+
+    Value realIsZero =
+        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
+    Value imagIsZero =
+        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
+
+    // Real > Imag
+    Value imagDivReal = b.create<arith::DivFOp>(imag, real, fmf.getValue());
+    Value imagSq =
+        b.create<arith::MulFOp>(imagDivReal, imagDivReal, fmf.getValue());
+    Value imagSqPlusOne = b.create<arith::AddFOp>(imagSq, one, fmf.getValue());
+    Value imagSqrt = b.create<math::SqrtOp>(imagSqPlusOne, fmf.getValue());
+    Value absImag = b.create<arith::MulFOp>(imagSqrt, real, fmf.getValue());
+
+    // Real <= Imag
+    Value realDivImag = b.create<arith::DivFOp>(real, imag, fmf.getValue());
+    Value realSq =
+        b.create<arith::MulFOp>(realDivImag, realDivImag, fmf.getValue());
+    Value realSqPlusOne = b.create<arith::AddFOp>(realSq, one, fmf.getValue());
+    Value realSqrt = b.create<math::SqrtOp>(realSqPlusOne, fmf.getValue());
+    Value absReal = b.create<arith::MulFOp>(realSqrt, imag, fmf.getValue());
+
+    rewriter.replaceOpWithNewOp<arith::SelectOp>(
+        op, realIsZero, imag,
+        b.create<arith::SelectOp>(
+            imagIsZero, real,
+            b.create<arith::SelectOp>(
+                b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, real, imag),
+                absImag, absReal)));
+
     return success();
   }
 };
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 3af28150fd5c3f..1028c9aae92c05 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -7,13 +7,28 @@ func.func @complex_abs(%arg: complex<f32>) -> f32 {
   %abs = complex.abs %arg: complex<f32>
   return %abs : f32
 }
+
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
 // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK-DAG: %[[REAL_SQ:.*]] = arith.mulf %[[REAL]], %[[REAL]] : f32
-// CHECK-DAG: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] : f32
-// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[REAL_SQ]], %[[IMAG_SQ]] : f32
-// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
-// CHECK: return %[[NORM]] : f32
+// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
+// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
+// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] : f32
+// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32
+// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] : f32
+// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] : f32
+// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL]] : f32
+// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] : f32
+// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32
+// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] : f32
+// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] : f32
+// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG]] : f32
+// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32
+// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
+// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL]], %[[ABS1]] : f32
+// CHECK: %[[ABS3:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG]], %[[ABS2]] : f32
+// CHECK: return %[[ABS3]] : f32
 
 // -----
 
@@ -241,12 +256,26 @@ func.func @complex_log(%arg: complex<f32>) -> complex<f32> {
   %log = complex.log %arg: complex<f32>
   return %log : complex<f32>
 }
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
 // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK: %[[SQR_REAL:.*]] = arith.mulf %[[REAL]], %[[REAL]] : f32
-// CHECK: %[[SQR_IMAG:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] : f32
-// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[SQR_REAL]], %[[SQR_IMAG]] : f32
-// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
+// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
+// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
+// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] : f32
+// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32
+// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] : f32
+// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] : f32
+// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL]] : f32
+// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] : f32
+// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32
+// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] : f32
+// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] : f32
+// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG]] : f32
+// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32
+// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
+// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL]], %[[ABS1]] : f32
+// CHECK: %[[NORM:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG]], %[[ABS2]] : f32
 // CHECK: %[[RESULT_REAL:.*]] = math.log %[[NORM]] : f32
 // CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32>
 // CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32>
@@ -469,12 +498,26 @@ func.func @complex_sign(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[REAL_IS_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
 // CHECK: %[[IMAG_IS_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
 // CHECK: %[[IS_ZERO:.*]] = arith.andi %[[REAL_IS_ZERO]], %[[IMAG_IS_ZERO]] : i1
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32>
 // CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK: %[[SQR_REAL:.*]] = arith.mulf %[[REAL2]], %[[REAL2]] : f32
-// CHECK: %[[SQR_IMAG:.*]] = arith.mulf %[[IMAG2]], %[[IMAG2]] : f32
-// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[SQR_REAL]], %[[SQR_IMAG]] : f32
-// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
+// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL2]], %[[ZERO]] : f32
+// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG2]], %[[ZERO]] : f32
+// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG2]], %[[REAL2]] : f32
+// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32
+// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] : f32
+// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] : f32
+// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL2]] : f32
+// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL2]], %[[IMAG2]] : f32
+// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32
+// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] : f32
+// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] : f32
+// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG2]] : f32
+// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL2]], %[[IMAG2]] : f32
+// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
+// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL2]], %[[ABS1]] : f32
+// CHECK: %[[NORM:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG2]], %[[ABS2]] : f32
 // CHECK: %[[REAL_SIGN:.*]] = arith.divf %[[REAL]], %[[NORM]] : f32
 // CHECK: %[[IMAG_SIGN:.*]] = arith.divf %[[IMAG]], %[[NORM]] : f32
 // CHECK: %[[SIGN:.*]] = complex.create %[[REAL_SIGN]], %[[IMAG_SIGN]] : complex<f32>
@@ -716,13 +759,27 @@ func.func @complex_abs_with_fmf(%arg: complex<f32>) -> f32 {
   %abs = complex.abs %arg fastmath<nnan,contract> : complex<f32>
   return %abs : f32
 }
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
 // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK-DAG: %[[REAL_SQ:.*]] = arith.mulf %[[REAL]], %[[REAL]] fastmath<nnan,contract> : f32
-// CHECK-DAG: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[REAL_SQ]], %[[IMAG_SQ]] fastmath<nnan,contract> : f32
-// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
-// CHECK: return %[[NORM]] : f32
+// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
+// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
+// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32
+// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
+// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL]], %[[ABS1]] : f32
+// CHECK: %[[ABS3:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG]], %[[ABS2]] : f32
+// CHECK: return %[[ABS3]] : f32
 
 // -----
 
@@ -807,12 +864,26 @@ func.func @complex_log_with_fmf(%arg: complex<f32>) -> complex<f32> {
   %log = complex.log %arg fastmath<nnan,contract> : complex<f32>
   return %log : complex<f32>
 }
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
 // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK: %[[SQR_REAL:.*]] = arith.mulf %[[REAL]], %[[REAL]] fastmath<nnan,contract> : f32
-// CHECK: %[[SQR_IMAG:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[SQR_REAL]], %[[SQR_IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
+// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
+// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
+// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32
+// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
+// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL]], %[[ABS1]] : f32
+// CHECK: %[[NORM:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG]], %[[ABS2]] : f32
 // CHECK: %[[RESULT_REAL:.*]] = math.log %[[NORM]] fastmath<nnan,contract> : f32
 // CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32>
 // CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32>
diff --git a/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir b/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir
index 9983dd46f09433..d710dc8e1adeb7 100644
--- a/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir
@@ -6,12 +6,29 @@ func.func @complex_abs(%arg: complex<f32>) -> f32 {
   %abs = complex.abs %arg: complex<f32>
   return %abs : f32
 }
+// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
+// CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
 // CHECK: %[[REAL:.*]] = llvm.extractvalue %[[ARG]][0] : ![[C_TY]]
 // CHECK: %[[IMAG:.*]] = llvm.extractvalue %[[ARG]][1] : ![[C_TY]]
-// CHECK-DAG: %[[REAL_SQ:.*]] = llvm.fmul %[[REAL]], %[[REAL]]  : f32
-// CHECK-DAG: %[[IMAG_SQ:.*]] = llvm.fmul %[[IMAG]], %[[IMAG]]  : f32
-// CHECK: %[[SQ_NORM:.*]] = llvm.fadd %[[REAL_SQ]], %[[IMAG_SQ]]  : f32
-// CHECK: %[[NORM:.*]] = llvm.intr.sqrt(%[[SQ_NORM]]) : (f32) -> f32
+// CHECK: %[[REAL_IS_ZERO:.*]] = llvm.fcmp "oeq" %[[REAL]], %[[ZERO]] : f32
+// CHECK: %[[IMAG_IS_ZERO:.*]] = llvm.fcmp "oeq" %[[IMAG]], %[[ZERO]] : f32
+
+// CHECK: %[[IMAG_DIV_REAL:.*]] = llvm.fdiv %[[IMAG]], %[[REAL]] : f32
+// CHECK: %[[IMAG_SQ:.*]] = llvm.fmul %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]]  : f32
+// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = llvm.fadd %[[IMAG_SQ]], %[[ONE]] : f32
+// CHECK: %[[IMAG_SQRT:.*]] = llvm.intr.sqrt(%[[IMAG_SQ_PLUS_ONE]]) : (f32) -> f32
+// CHECK: %[[ABS_IMAG:.*]] = llvm.fmul %[[IMAG_SQRT]], %[[REAL]] : f32
+
+// CHECK: %[[REAL_DIV_IMAG:.*]] = llvm.fdiv %[[REAL]], %[[IMAG]] : f32
+// CHECK: %[[REAL_SQ:.*]] = llvm.fmul %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32
+// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = llvm.fadd %[[REAL_SQ]], %[[ONE]]  : f32
+// CHECK: %[[REAL_SQRT:.*]] = llvm.intr.sqrt(%[[REAL_SQ_PLUS_ONE]])  : (f32) -> f32
+// CHECK: %[[ABS_REAL:.*]] = llvm.fmul %[[REAL_SQRT]], %[[IMAG]]  : f32
+
+// CHECK: %[[REAL_GT_IMAG:.*]] = llvm.fcmp "ogt" %[[REAL]], %[[IMAG]] : f32
+// CHECK: %[[ABS1:.*]] = llvm.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : i1, f32
+// CHECK: %[[ABS2:.*]] = llvm.select %[[IMAG_IS_ZERO]], %[[REAL]], %[[ABS1]] : i1, f32
+// CHECK: %[[NORM:.*]] = llvm.select %[[REAL_IS_ZERO]], %[[IMAG]], %[[ABS2]] : i1, f32
 // CHECK: llvm.return %[[NORM]] : f32
 
 // CHECK-LABEL: llvm.func @complex_eq
diff --git a/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
index 349b92a7aefa2e..b7849945b3cf49 100644
--- a/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
+++ b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
@@ -106,6 +106,27 @@ func.func @angle(%arg: complex<f32>) -> f32 {
   func.return %angle : f32
 }
 
+func.func @test_element_f64(%input: tensor<?xcomplex<f64>>,
+                      %func: (complex<f64>) -> f64) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %size = tensor.dim %input, %c0: tensor<?xcomplex<f64>>
+
+  scf.for %i = %c0 to %size step %c1 {
+    %elem = tensor.extract %input[%i]: tensor<?xcomplex<f64>>
+
+    %val = func.call_indirect %func(%elem) : (complex<f64>) -> f64
+    vector.print %val : f64
+    scf.yield
+  }
+  func.return
+}
+
+func.func @abs(%arg: complex<f64>) -> f64 {
+  %abs = complex.abs %arg : complex<f64>
+  func.return %abs : f64
+}
+
 func.func @entry() {
   // complex.sqrt test
   %sqrt_test = arith.constant dense<[
@@ -300,5 +321,22 @@ func.func @entry() {
   call @test_element(%angle_test_cast, %angle_func)
     : (tensor<?xcomplex<f32>>, (complex<f32>) -> f32) -> ()
 
+  // complex.abs test
+  %abs_test = arith.constant dense<[
+    (1.0, 1.0),
+    // CHECK:  1.414
+    (1.0e300, 1.0e300),
+    // CHECK-NEXT:  1.41421e+300
+    (1.0e-300, 1.0e-300)
+    // CHECK-NEXT:  1.41421e-300
+  ]> : tensor<3xcomplex<f64>>
+  %abs_test_cast = tensor.cast %abs_test
+    :  tensor<3xcomplex<f64>> to tensor<?xcomplex<f64>>
+
+  %abs_func = func.constant @abs : (complex<f64>) -> f64
+
+  call @test_element_f64(%abs_test_cast, %abs_func)
+    : (tensor<?xcomplex<f64>>, (complex<f64>) -> f64) -> ()
+
   func.return
 }

>From c4e8c8b2c67933896f2f42a86fc994811011b6ad Mon Sep 17 00:00:00 2001
From: Kai Sasaki <lewuathe at gmail.com>
Date: Wed, 24 Jan 2024 09:22:50 +0900
Subject: [PATCH 2/2] [mlir][complex] Add test case to check zero element
 handling

---
 .../Integration/Dialect/Complex/CPU/correctness.mlir | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
index b7849945b3cf49..c8327e94def8ab 100644
--- a/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
+++ b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
@@ -327,11 +327,17 @@ func.func @entry() {
     // CHECK:  1.414
     (1.0e300, 1.0e300),
     // CHECK-NEXT:  1.41421e+300
-    (1.0e-300, 1.0e-300)
+    (1.0e-300, 1.0e-300),
     // CHECK-NEXT:  1.41421e-300
-  ]> : tensor<3xcomplex<f64>>
+    (5.0, 0.0),
+    // CHECK-NEXT:  5
+    (0.0, 6.0),
+    // CHECK-NEXT:  6
+    (7.0, 8.0)
+    // CHECK-NEXT:  10.6301
+  ]> : tensor<6xcomplex<f64>>
   %abs_test_cast = tensor.cast %abs_test
-    :  tensor<3xcomplex<f64>> to tensor<?xcomplex<f64>>
+    :  tensor<6xcomplex<f64>> to tensor<?xcomplex<f64>>
 
   %abs_func = func.constant @abs : (complex<f64>) -> f64
 



More information about the Mlir-commits mailing list