[Mlir-commits] [mlir] Fix complex abs corner cases. (PR #88373)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Apr 11 03:00:17 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Johannes Reifferscheid (jreiffers)

<details>
<summary>Changes</summary>

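The current implementation fails for very small, very large, and infinite values. For example, (inf, inf) should return inf, but it returns NaN because the expansion computes inf / inf. It also picks the divisor by comparing the signed components instead of their magnitudes, so (-1e30, 1e-30) in f32 overflows to inf instead of returning 1e30.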

This ports the logic used in XLA. Tested with XLA's exhaustive_binary_test_f32_f64.

---

Patch is 39.58 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/88373.diff


3 Files Affected:

- (modified) mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp (+16-38) 
- (modified) mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir (+134-214) 
- (modified) mlir/test/Conversion/ComplexToStandard/full-conversion.mlir (+12-22) 


``````````diff
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index a6fcf6a758c07f..462036e51a1f1c 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -26,7 +26,7 @@ namespace mlir {
 using namespace mlir;
 
 namespace {
-// The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780.
+
 struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
   using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
 
@@ -35,49 +35,27 @@ struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
                   ConversionPatternRewriter &rewriter) const override {
     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
 
-    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
+    arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
 
     Type elementType = op.getType();
-    Value arg = adaptor.getComplex();
-
-    Value zero =
-        b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
     Value one = b.create<arith::ConstantOp>(elementType,
                                             b.getFloatAttr(elementType, 1.0));
 
-    Value real = b.create<complex::ReOp>(elementType, arg);
-    Value imag = b.create<complex::ImOp>(elementType, arg);
-
-    Value realIsZero =
-        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
-    Value imagIsZero =
-        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
+    Value real = b.create<complex::ReOp>(adaptor.getComplex());
+    Value imag = b.create<complex::ImOp>(adaptor.getComplex());
+    Value absReal = b.create<math::AbsFOp>(real, fmf);
+    Value absImag = b.create<math::AbsFOp>(imag, fmf);
 
-    // Real > Imag
-    Value imagDivReal = b.create<arith::DivFOp>(imag, real, fmf.getValue());
-    Value imagSq =
-        b.create<arith::MulFOp>(imagDivReal, imagDivReal, fmf.getValue());
-    Value imagSqPlusOne = b.create<arith::AddFOp>(imagSq, one, fmf.getValue());
-    Value imagSqrt = b.create<math::SqrtOp>(imagSqPlusOne, fmf.getValue());
-    Value realAbs = b.create<math::AbsFOp>(real, fmf.getValue());
-    Value absImag = b.create<arith::MulFOp>(imagSqrt, realAbs, fmf.getValue());
-
-    // Real <= Imag
-    Value realDivImag = b.create<arith::DivFOp>(real, imag, fmf.getValue());
-    Value realSq =
-        b.create<arith::MulFOp>(realDivImag, realDivImag, fmf.getValue());
-    Value realSqPlusOne = b.create<arith::AddFOp>(realSq, one, fmf.getValue());
-    Value realSqrt = b.create<math::SqrtOp>(realSqPlusOne, fmf.getValue());
-    Value imagAbs = b.create<math::AbsFOp>(imag, fmf.getValue());
-    Value absReal = b.create<arith::MulFOp>(realSqrt, imagAbs, fmf.getValue());
-
-    rewriter.replaceOpWithNewOp<arith::SelectOp>(
-        op, realIsZero, imagAbs,
-        b.create<arith::SelectOp>(
-            imagIsZero, realAbs,
-            b.create<arith::SelectOp>(
-                b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, real, imag),
-                absImag, absReal)));
+    Value max = b.create<arith::MaximumFOp>(absReal, absImag, fmf);
+    Value min = b.create<arith::MinimumFOp>(absReal, absImag, fmf);
+    Value ratio = b.create<arith::DivFOp>(min, max, fmf);
+    Value ratioSq = b.create<arith::MulFOp>(ratio, ratio, fmf);
+    Value ratioSqPlusOne = b.create<arith::AddFOp>(ratioSq, one, fmf);
+    Value sqrt = b.create<math::SqrtOp>(ratioSqPlusOne, fmf);
+    Value result = b.create<arith::MulFOp>(max, sqrt, fmf);
+    Value isNaN =
+        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, result, result, fmf);
+    rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, min, result);
 
     return success();
   }
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 46dba04a88aa0c..a1de61d10bb226 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -8,29 +8,21 @@ func.func @complex_abs(%arg: complex<f32>) -> f32 {
   return %abs : f32
 }
 
-// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
 // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
-// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
-// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] : f32
-// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32
-// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] : f32
-// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] : f32
-// CHECK: %[[REAL_ABS:.*]] = math.absf %[[REAL]] : f32
-// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL_ABS]] : f32
-// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] : f32
-// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32
-// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] : f32
-// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] : f32
-// CHECK: %[[IMAG_ABS:.*]] = math.absf %[[IMAG]] : f32
-// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG_ABS]] : f32
-// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32
-// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
-// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL_ABS]], %[[ABS1]] : f32
-// CHECK: %[[ABS3:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG_ABS]], %[[ABS2]] : f32
-// CHECK: return %[[ABS3]] : f32
+// CHECK: %[[ABS_REAL:.*]] = math.absf %[[REAL]] : f32
+// CHECK: %[[ABS_IMAG:.*]] = math.absf %[[IMAG]] : f32
+// CHECK: %[[MAX:.*]] = arith.maximumf %[[ABS_REAL]], %[[ABS_IMAG]] : f32
+// CHECK: %[[MIN:.*]] = arith.minimumf %[[ABS_REAL]], %[[ABS_IMAG]] : f32
+// CHECK: %[[RATIO:.*]] = arith.divf %[[MIN]], %[[MAX]] : f32
+// CHECK: %[[RATIO_SQ:.*]] = arith.mulf %[[RATIO]], %[[RATIO]] : f32
+// CHECK: %[[RATIO_SQ_PLUS_ONE:.*]] = arith.addf %[[RATIO_SQ]], %[[ONE]] : f32
+// CHECK: %[[SQRT:.*]] = math.sqrt %[[RATIO_SQ_PLUS_ONE]] : f32
+// CHECK: %[[ABS_OR_NAN:.*]] = arith.mulf %[[MAX]], %[[SQRT]] : f32
+// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[ABS_OR_NAN]], %[[ABS_OR_NAN]] : f32
+// CHECK: %[[ABS:.*]] = arith.select %[[IS_NAN]], %[[MIN]], %[[ABS_OR_NAN]] : f32
+// CHECK: return %[[ABS]] : f32
 
 // -----
 
@@ -258,29 +250,21 @@ func.func @complex_log(%arg: complex<f32>) -> complex<f32> {
   %log = complex.log %arg: complex<f32>
   return %log : complex<f32>
 }
-// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
 // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
-// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
-// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] : f32
-// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32
-// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] : f32
-// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] : f32
-// CHECK: %[[REAL_ABS:.*]] = math.absf %[[REAL]] : f32
-// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL_ABS]] : f32
-// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] : f32
-// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32
-// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] : f32
-// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] : f32
-// CHECK: %[[IMAG_ABS:.*]] = math.absf %[[IMAG]] : f32
-// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG_ABS]] : f32
-// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32
-// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
-// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL_ABS]], %[[ABS1]] : f32
-// CHECK: %[[NORM:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG_ABS]], %[[ABS2]] : f32
-// CHECK: %[[RESULT_REAL:.*]] = math.log %[[NORM]] : f32
+// CHECK: %[[ABS_REAL:.*]] = math.absf %[[REAL]] : f32
+// CHECK: %[[ABS_IMAG:.*]] = math.absf %[[IMAG]] : f32
+// CHECK: %[[MAX:.*]] = arith.maximumf %[[ABS_REAL]], %[[ABS_IMAG]] : f32
+// CHECK: %[[MIN:.*]] = arith.minimumf %[[ABS_REAL]], %[[ABS_IMAG]] : f32
+// CHECK: %[[RATIO:.*]] = arith.divf %[[MIN]], %[[MAX]] : f32
+// CHECK: %[[RATIO_SQ:.*]] = arith.mulf %[[RATIO]], %[[RATIO]] : f32
+// CHECK: %[[RATIO_SQ_PLUS_ONE:.*]] = arith.addf %[[RATIO_SQ]], %[[ONE]] : f32
+// CHECK: %[[SQRT:.*]] = math.sqrt %[[RATIO_SQ_PLUS_ONE]] : f32
+// CHECK: %[[RESULT:.*]] = arith.mulf %[[MAX]], %[[SQRT]] : f32
+// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[RESULT]], %[[RESULT]] : f32
+// CHECK: %[[ABS:.*]] = arith.select %[[IS_NAN]], %[[MIN]], %[[RESULT]] : f32
+// CHECK: %[[RESULT_REAL:.*]] = math.log %[[ABS]] : f32
 // CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32>
 // CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32>
 // CHECK: %[[RESULT_IMAG:.*]] = math.atan2 %[[IMAG2]], %[[REAL2]] : f32
@@ -509,30 +493,22 @@ func.func @complex_sign(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[REAL_IS_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
 // CHECK: %[[IMAG_IS_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
 // CHECK: %[[IS_ZERO:.*]] = arith.andi %[[REAL_IS_ZERO]], %[[IMAG_IS_ZERO]] : i1
-// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32>
 // CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL2]], %[[ZERO]] : f32
-// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG2]], %[[ZERO]] : f32
-// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG2]], %[[REAL2]] : f32
-// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32
-// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] : f32
-// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] : f32
-// CHECK: %[[REAL_ABS:.*]] = math.absf %[[REAL2]] : f32
-// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL_ABS]] : f32
-// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL2]], %[[IMAG2]] : f32
-// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32
-// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] : f32
-// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] : f32
-// CHECK: %[[IMAG_ABS:.*]] = math.absf %[[IMAG2]] : f32
-// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG_ABS]] : f32
-// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL2]], %[[IMAG2]] : f32
-// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
-// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL_ABS]], %[[ABS1]] : f32
-// CHECK: %[[NORM:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG_ABS]], %[[ABS2]] : f32
-// CHECK: %[[REAL_SIGN:.*]] = arith.divf %[[REAL]], %[[NORM]] : f32
-// CHECK: %[[IMAG_SIGN:.*]] = arith.divf %[[IMAG]], %[[NORM]] : f32
+// CHECK: %[[ABS_REAL:.*]] = math.absf %[[REAL2]] : f32
+// CHECK: %[[ABS_IMAG:.*]] = math.absf %[[IMAG2]] : f32
+// CHECK: %[[MAX:.*]] = arith.maximumf %[[ABS_REAL]], %[[ABS_IMAG]] : f32
+// CHECK: %[[MIN:.*]] = arith.minimumf %[[ABS_REAL]], %[[ABS_IMAG]] : f32
+// CHECK: %[[RATIO:.*]] = arith.divf %[[MIN]], %[[MAX]] : f32
+// CHECK: %[[RATIO_SQ:.*]] = arith.mulf %[[RATIO]], %[[RATIO]] : f32
+// CHECK: %[[RATIO_SQ_PLUS_ONE:.*]] = arith.addf %[[RATIO_SQ]], %[[ONE]] : f32
+// CHECK: %[[SQRT:.*]] = math.sqrt %[[RATIO_SQ_PLUS_ONE]] : f32
+// CHECK: %[[ABS_OR_NAN:.*]] = arith.mulf %[[MAX]], %[[SQRT]] : f32
+// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[ABS_OR_NAN]], %[[ABS_OR_NAN]] : f32
+// CHECK: %[[ABS:.*]] = arith.select %[[IS_NAN]], %[[MIN]], %[[ABS_OR_NAN]] : f32
+// CHECK: %[[REAL_SIGN:.*]] = arith.divf %[[REAL]], %[[ABS]] : f32
+// CHECK: %[[IMAG_SIGN:.*]] = arith.divf %[[IMAG]], %[[ABS]] : f32
 // CHECK: %[[SIGN:.*]] = complex.create %[[REAL_SIGN]], %[[IMAG_SIGN]] : complex<f32>
 // CHECK: %[[RESULT:.*]] = arith.select %[[IS_ZERO]], %[[ARG]], %[[SIGN]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
@@ -725,29 +701,21 @@ func.func @complex_sqrt(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[VAR0:.*]] = complex.re %[[ARG]] : complex<f32>
 // CHECK: %[[VAR1:.*]] = complex.im %[[ARG]] : complex<f32>
 // CHECK: %[[VAR2:.*]] = math.absf %[[VAR0]] : f32
-// CHECK: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[CST1:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK: %[[VAR3:.*]] = complex.re %[[ARG]] : complex<f32>
-// CHECK: %[[VAR4:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK: %[[VAR5:.*]] = arith.cmpf oeq, %[[VAR3]], %[[CST0]] : f32
-// CHECK: %[[VAR6:.*]] = arith.cmpf oeq, %[[VAR4]], %[[CST0]] : f32
-// CHECK: %[[VAR7:.*]] = arith.divf %[[VAR4]], %[[VAR3]] : f32
-// CHECK: %[[VAR8:.*]] = arith.mulf %[[VAR7]], %[[VAR7]] : f32
-// CHECK: %[[VAR9:.*]] = arith.addf %[[VAR8]], %[[CST1]] : f32
-// CHECK: %[[VAR10:.*]] = math.sqrt %[[VAR9]] : f32
-// CHECK: %[[VAR11:.*]] = math.absf %[[VAR3]] : f32
-// CHECK: %[[VAR12:.*]] = arith.mulf %[[VAR10]], %[[VAR11]] : f32
-// CHECK: %[[VAR13:.*]] = arith.divf %[[VAR3]], %[[VAR4]] : f32
-// CHECK: %[[VAR14:.*]] = arith.mulf %[[VAR13]], %[[VAR13]] : f32
-// CHECK: %[[VAR15:.*]] = arith.addf %[[VAR14]], %[[CST1]] : f32
-// CHECK: %[[VAR16:.*]] = math.sqrt %[[VAR15]] : f32
-// CHECK: %[[VAR17:.*]] = math.absf %[[VAR4]] : f32
-// CHECK: %[[VAR18:.*]] = arith.mulf %[[VAR16]], %[[VAR17]] : f32
-// CHECK: %[[VAR19:.*]] = arith.cmpf ogt, %[[VAR3]], %[[VAR4]] : f32
-// CHECK: %[[VAR20:.*]] = arith.select %[[VAR19]], %[[VAR12]], %[[VAR18]] : f32
-// CHECK: %[[VAR21:.*]] = arith.select %[[VAR6]], %[[VAR11]], %[[VAR20]] : f32
-// CHECK: %[[VAR22:.*]] = arith.select %[[VAR5]], %[[VAR17]], %[[VAR21]] : f32
-// CHECK: %[[VAR23:.*]] = arith.addf %[[VAR2]], %[[VAR22]] : f32
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[ABS_REAL:.*]] = math.absf %[[REAL]] : f32
+// CHECK: %[[ABS_IMAG:.*]] = math.absf %[[IMAG]] : f32
+// CHECK: %[[MAX:.*]] = arith.maximumf %[[ABS_REAL]], %[[ABS_IMAG]] : f32
+// CHECK: %[[MIN:.*]] = arith.minimumf %[[ABS_REAL]], %[[ABS_IMAG]] : f32
+// CHECK: %[[RATIO:.*]] = arith.divf %[[MIN]], %[[MAX]] : f32
+// CHECK: %[[RATIO_SQ:.*]] = arith.mulf %[[RATIO]], %[[RATIO]] : f32
+// CHECK: %[[RATIO_SQ_PLUS_ONE:.*]] = arith.addf %[[RATIO_SQ]], %[[ONE]] : f32
+// CHECK: %[[SQRT:.*]] = math.sqrt %[[RATIO_SQ_PLUS_ONE]] : f32
+// CHECK: %[[ABS_OR_NAN:.*]] = arith.mulf %[[MAX]], %[[SQRT]] : f32
+// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[ABS_OR_NAN]], %[[ABS_OR_NAN]] : f32
+// CHECK: %[[ABS:.*]] = arith.select %[[IS_NAN]], %[[MIN]], %[[ABS_OR_NAN]] : f32
+// CHECK: %[[VAR23:.*]] = arith.addf %[[VAR2]], %[[ABS]] : f32
 // CHECK: %[[CST2:.*]]  = arith.constant 5.000000e-01 : f32
 // CHECK: %[[VAR24:.*]] = arith.mulf %[[VAR23]], %[[CST2]] : f32
 // CHECK: %[[VAR25:.*]] = math.sqrt %[[VAR24]] : f32
@@ -821,29 +789,21 @@ func.func @complex_abs_with_fmf(%arg: complex<f32>) -> f32 {
   %abs = complex.abs %arg fastmath<nnan,contract> : complex<f32>
   return %abs : f32
 }
-// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
 // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
-// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
-// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] fastmath<nnan,contract> : f32
-// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] fastmath<nnan,contract> : f32
-// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
-// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
-// CHECK: %[[REAL_ABS:.*]] = math.absf %[[REAL]] fastmath<nnan,contract> : f32
-// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL_ABS]] fastmath<nnan,contract> : f32
-// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
-// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
-// CHECK: %[[IMAG_ABS:.*]] = math.absf %[[IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG_ABS]] fastmath<nnan,contract> : f32
-// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32
-// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
-// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL_ABS]], %[[ABS1]] : f32
-// CHECK: %[[ABS3:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG_ABS]], %[[ABS2]] : f32
-// CHECK: return %[[ABS3]] : f32
+// CHECK: %[[ABS_REAL:.*]] = math.absf %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABS_IMAG:.*]] = math.absf %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[MAX:.*]] = arith.maximumf %[[ABS_REAL]], %[[ABS_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[MIN:.*]] = arith.minimumf %[[ABS_REAL]], %[[ABS_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[RATIO:.*]] = arith.divf %[[MIN]], %[[MAX]] fastmath<nnan,contract> : f32
+// CHECK: %[[RATIO_SQ:.*]] = arith.mulf %[[RATIO]], %[[RATIO]] fastmath<nnan,contract> : f32
+// CHECK: %[[RATIO_SQ_PLUS_ONE:.*]] = arith.addf %[[RATIO_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[SQRT:.*]] = math.sqrt %[[RATIO_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABS_OR_NAN:.*]] = arith.mulf %[[MAX]], %[[SQRT]] fastmath<nnan,contract> : f32
+// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[ABS_OR_NAN]], %[[ABS_OR_NAN]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABS:.*]] = arith.select %[[IS_NAN]], %[[MIN]], %[[ABS_OR_NAN]] : f32
+// CHECK: return %[[ABS]] : f32
 
 // -----
 
@@ -928,29 +888,21 @@ func.func @complex_log_with_fmf(%arg: complex<f32>) -> complex<f32> {
   %log = complex.log %arg fastmath<nnan,contract> : complex<f32>
   return %log : complex<f32>
 }
-// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
 // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
-// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
-// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] fastmath<nnan,contract> : f32
-// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] fastmath<nnan,contract> : f32
-// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
-// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
-// CHECK: %[[REAL_ABS:.*]] = math.absf %[[REAL]] fastmath<nnan,contract> : f32
-// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL_ABS]] fastmath<nnan,contract> : f32
-// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
-// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] fastmath<nnan,con...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/88373


More information about the Mlir-commits mailing list