[Mlir-commits] [mlir] Fix overflows in complex sqrt lowering. (PR #88480)

Johannes Reifferscheid llvmlistbot at llvm.org
Fri Apr 12 00:16:30 PDT 2024


https://github.com/jreiffers created https://github.com/llvm/llvm-project/pull/88480

This ports XLA's complex sqrt lowering. Its accuracy was tested with XLA's exhaustive_unary_test_complex test.

Note: the complex rsqrt lowering is still broken; this change does not fix it.

From 61f46dc7df2d8241a9047ffe1789ecdb83ce859c Mon Sep 17 00:00:00 2001
From: Johannes Reifferscheid <jreiffers at google.com>
Date: Fri, 12 Apr 2024 09:13:33 +0200
Subject: [PATCH] Fix overflows in complex sqrt lowering.

This ports XLA's complex sqrt lowering. Its accuracy was tested with XLA's
exhaustive_unary_test_complex test.

Note: the complex rsqrt lowering is still broken; this change does not fix it.
---
 .../ComplexToStandard/ComplexToStandard.cpp   | 166 ++++++----
 .../convert-to-standard.mlir                  | 296 +++++++++++-------
 .../ComplexToStandard/full-conversion.mlir    |   2 +-
 3 files changed, 280 insertions(+), 184 deletions(-)

diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 9c82e8105f06e5..0664b053fc9e67 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -27,35 +27,52 @@ using namespace mlir;
 
 namespace {
 
+// Returns the absolute value or its square root.
+Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf,
+                 ImplicitLocOpBuilder &b, bool returnSqrt = false) {
+  Value one = b.create<arith::ConstantOp>(real.getType(),
+                                          b.getFloatAttr(real.getType(), 1.0));
+
+  Value absReal = b.create<math::AbsFOp>(real, fmf);
+  Value absImag = b.create<math::AbsFOp>(imag, fmf);
+
+  Value max = b.create<arith::MaximumFOp>(absReal, absImag, fmf);
+  Value min = b.create<arith::MinimumFOp>(absReal, absImag, fmf);
+  Value ratio = b.create<arith::DivFOp>(min, max, fmf);
+  Value ratioSq = b.create<arith::MulFOp>(ratio, ratio, fmf);
+  Value ratioSqPlusOne = b.create<arith::AddFOp>(ratioSq, one, fmf);
+  Value result;
+
+  if (returnSqrt) {
+    Value quarter = b.create<arith::ConstantOp>(
+        real.getType(), b.getFloatAttr(real.getType(), 0.25));
+    // sqrt(sqrt(a*b)) would avoid the pow, but will overflow more easily.
+    Value sqrt = b.create<math::SqrtOp>(max, fmf);
+    Value p025 = b.create<math::PowFOp>(ratioSqPlusOne, quarter, fmf);
+    result = b.create<arith::MulFOp>(sqrt, p025, fmf);
+  } else {
+    Value sqrt = b.create<math::SqrtOp>(ratioSqPlusOne, fmf);
+    result = b.create<arith::MulFOp>(max, sqrt, fmf);
+  }
+
+  Value isNaN =
+      b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, result, result, fmf);
+  return b.create<arith::SelectOp>(isNaN, min, result);
+}
+
 struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
   using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
 
     arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
 
-    Type elementType = op.getType();
-    Value one = b.create<arith::ConstantOp>(elementType,
-                                            b.getFloatAttr(elementType, 1.0));
-
     Value real = b.create<complex::ReOp>(adaptor.getComplex());
     Value imag = b.create<complex::ImOp>(adaptor.getComplex());
-    Value absReal = b.create<math::AbsFOp>(real, fmf);
-    Value absImag = b.create<math::AbsFOp>(imag, fmf);
-
-    Value max = b.create<arith::MaximumFOp>(absReal, absImag, fmf);
-    Value min = b.create<arith::MinimumFOp>(absReal, absImag, fmf);
-    Value ratio = b.create<arith::DivFOp>(min, max, fmf);
-    Value ratioSq = b.create<arith::MulFOp>(ratio, ratio, fmf);
-    Value ratioSqPlusOne = b.create<arith::AddFOp>(ratioSq, one, fmf);
-    Value sqrt = b.create<math::SqrtOp>(ratioSqPlusOne, fmf);
-    Value result = b.create<arith::MulFOp>(max, sqrt, fmf);
-    Value isNaN =
-        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, result, result, fmf);
-    rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, min, result);
+    rewriter.replaceOp(op, computeAbs(real, imag, fmf, b));
 
     return success();
   }
@@ -829,60 +846,71 @@ struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
   LogicalResult
   matchAndRewrite(complex::SqrtOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
 
     auto type = cast<ComplexType>(op.getType());
-    Type elementType = type.getElementType();
-    Value arg = adaptor.getComplex();
-    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
-
-    Value zero =
-        b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
-
-    Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
-    Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
-
-    Value absLhs = b.create<math::AbsFOp>(real, fmf);
-    Value absArg = b.create<complex::AbsOp>(elementType, arg, fmf);
-    Value addAbs = b.create<arith::AddFOp>(absLhs, absArg, fmf);
+    auto elementType = type.getElementType().cast<FloatType>();
+    arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
 
+    auto cst = [&](APFloat v) {
+      return b.create<arith::ConstantOp>(elementType,
+                                         b.getFloatAttr(elementType, v));
+    };
+    const auto &floatSemantics = elementType.getFloatSemantics();
+    Value zero = cst(APFloat::getZero(floatSemantics));
     Value half = b.create<arith::ConstantOp>(elementType,
                                              b.getFloatAttr(elementType, 0.5));
-    Value halfAddAbs = b.create<arith::MulFOp>(addAbs, half, fmf);
-    Value sqrtAddAbs = b.create<math::SqrtOp>(halfAddAbs, fmf);
-
-    Value realIsNegative =
-        b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, real, zero);
-    Value imagIsNegative =
-        b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, imag, zero);
-
-    Value resultReal = sqrtAddAbs;
-
-    Value imagDivTwoResultReal = b.create<arith::DivFOp>(
-        imag, b.create<arith::AddFOp>(resultReal, resultReal, fmf), fmf);
-
-    Value negativeResultReal = b.create<arith::NegFOp>(resultReal);
 
+    Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
+    Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
+    Value absSqrt = computeAbs(real, imag, fmf, b, /*returnSqrt=*/true);
+    Value argArg = b.create<math::Atan2Op>(imag, real, fmf);
+    Value sqrtArg = b.create<arith::MulFOp>(argArg, half, fmf);
+    Value cos = b.create<math::CosOp>(sqrtArg, fmf);
+    Value sin = b.create<math::SinOp>(sqrtArg, fmf);
+    // sin(atan2(0, inf)) = 0, sqrt(abs(inf)) = inf, but we can't multiply
+    // 0 * inf.
+    Value sinIsZero =
+        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, sin, zero, fmf);
+
+    Value resultReal = b.create<arith::MulFOp>(absSqrt, cos, fmf);
     Value resultImag = b.create<arith::SelectOp>(
-        realIsNegative,
-        b.create<arith::SelectOp>(imagIsNegative, negativeResultReal,
-                                  resultReal),
-        imagDivTwoResultReal);
-
-    resultReal = b.create<arith::SelectOp>(
-        realIsNegative,
-        b.create<arith::DivFOp>(
-            imag, b.create<arith::AddFOp>(resultImag, resultImag, fmf), fmf),
-        resultReal);
-
-    Value realIsZero =
-        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
-    Value imagIsZero =
-        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
-    Value argIsZero = b.create<arith::AndIOp>(realIsZero, imagIsZero);
-
-    resultReal = b.create<arith::SelectOp>(argIsZero, zero, resultReal);
-    resultImag = b.create<arith::SelectOp>(argIsZero, zero, resultImag);
+        sinIsZero, zero, b.create<arith::MulFOp>(absSqrt, sin, fmf));
+    if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
+                                            arith::FastMathFlags::ninf)) {
+      Value inf = cst(APFloat::getInf(floatSemantics));
+      Value negInf = cst(APFloat::getInf(floatSemantics, true));
+      Value nan = cst(APFloat::getNaN(floatSemantics));
+      Value absImag = b.create<math::AbsFOp>(elementType, imag, fmf);
+
+      Value absImagIsInf =
+          b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf);
+      Value absImagIsNotInf =
+          b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, absImag, inf, fmf);
+      Value realIsInf =
+          b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, inf, fmf);
+      Value realIsNegInf =
+          b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, negInf, fmf);
+
+      resultReal = b.create<arith::SelectOp>(
+          b.create<arith::AndIOp>(realIsNegInf, absImagIsNotInf), zero,
+          resultReal);
+      resultReal = b.create<arith::SelectOp>(
+          b.create<arith::OrIOp>(absImagIsInf, realIsInf), inf, resultReal);
+
+      Value imagSignInf = b.create<math::CopySignOp>(inf, imag, fmf);
+      resultImag = b.create<arith::SelectOp>(
+          b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, absSqrt, absSqrt),
+          nan, resultImag);
+      resultImag = b.create<arith::SelectOp>(
+          b.create<arith::OrIOp>(absImagIsInf, realIsNegInf), imagSignInf,
+          resultImag);
+    }
+
+    Value resultIsZero =
+        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absSqrt, zero, fmf);
+    resultReal = b.create<arith::SelectOp>(resultIsZero, zero, resultReal);
+    resultImag = b.create<arith::SelectOp>(resultIsZero, zero, resultImag);
 
     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
                                                    resultImag);
@@ -1065,7 +1093,7 @@ static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder,
   // Case 2:
   // 1^(c + d*i) = 1 + 0*i
   Value lhsEqOne = builder.create<arith::AndIOp>(
-      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, one),
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, one, fmf),
       bEqZero);
   Value cutoff2 =
       builder.create<arith::SelectOp>(lhsEqOne, complexOne, cutoff1);
@@ -1073,11 +1101,11 @@ static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder,
   // Case 3:
   // inf^(c + 0*i) = inf + 0*i, c > 0
   Value lhsEqInf = builder.create<arith::AndIOp>(
-      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, inf),
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, inf, fmf),
       bEqZero);
   Value rhsGt0 = builder.create<arith::AndIOp>(
       dEqZero,
-      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, c, zero));
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, c, zero, fmf));
   Value cutoff3 = builder.create<arith::SelectOp>(
       builder.create<arith::AndIOp>(lhsEqInf, rhsGt0), complexInf, cutoff2);
 
@@ -1085,7 +1113,7 @@ static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder,
   // inf^(c + 0*i) = 0 + 0*i, c < 0
   Value rhsLt0 = builder.create<arith::AndIOp>(
       dEqZero,
-      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, c, zero));
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, c, zero, fmf));
   Value cutoff4 = builder.create<arith::SelectOp>(
       builder.create<arith::AndIOp>(lhsEqInf, rhsLt0), complexZero, cutoff3);
 
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 8d2fb09daa87b6..b22c1acacaea18 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -8,9 +8,9 @@ func.func @complex_abs(%arg: complex<f32>) -> f32 {
   return %abs : f32
 }
 
-// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
 // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK: %[[ABS_REAL:.*]] = math.absf %[[REAL]] : f32
 // CHECK: %[[ABS_IMAG:.*]] = math.absf %[[IMAG]] : f32
 // CHECK: %[[MAX:.*]] = arith.maximumf %[[ABS_REAL]], %[[ABS_IMAG]] : f32
@@ -250,9 +250,9 @@ func.func @complex_log(%arg: complex<f32>) -> complex<f32> {
   %log = complex.log %arg: complex<f32>
   return %log : complex<f32>
 }
-// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
 // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK: %[[ABS_REAL:.*]] = math.absf %[[REAL]] : f32
 // CHECK: %[[ABS_IMAG:.*]] = math.absf %[[IMAG]] : f32
 // CHECK: %[[MAX:.*]] = arith.maximumf %[[ABS_REAL]], %[[ABS_IMAG]] : f32
@@ -493,9 +493,9 @@ func.func @complex_sign(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[REAL_IS_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
 // CHECK: %[[IMAG_IS_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
 // CHECK: %[[IS_ZERO:.*]] = arith.andi %[[REAL_IS_ZERO]], %[[IMAG_IS_ZERO]] : i1
-// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32>
 // CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK: %[[ABS_REAL:.*]] = math.absf %[[REAL2]] : f32
 // CHECK: %[[ABS_IMAG:.*]] = math.absf %[[IMAG2]] : f32
 // CHECK: %[[MAX:.*]] = arith.maximumf %[[ABS_REAL]], %[[ABS_IMAG]] : f32
@@ -697,45 +697,95 @@ func.func @complex_sqrt(%arg: complex<f32>) -> complex<f32> {
   return %sqrt : complex<f32>
 }
 
-// CHECK: %[[CST:.*]]  = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAR0:.*]] = complex.re %[[ARG]] : complex<f32>
-// CHECK: %[[VAR1:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK: %[[VAR2:.*]] = math.absf %[[VAR0]] : f32
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[HALF:.*]] = arith.constant 5.000000e-01 : f32
+// CHECK: %[[RE:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IM:.*]] = complex.im %[[ARG]] : complex<f32>
 // CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
-// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK: %[[ABS_REAL:.*]] = math.absf %[[REAL]] : f32
-// CHECK: %[[ABS_IMAG:.*]] = math.absf %[[IMAG]] : f32
-// CHECK: %[[MAX:.*]] = arith.maximumf %[[ABS_REAL]], %[[ABS_IMAG]] : f32
-// CHECK: %[[MIN:.*]] = arith.minimumf %[[ABS_REAL]], %[[ABS_IMAG]] : f32
+// CHECK: %[[ABSRE:.*]] = math.absf %[[RE]] : f32
+// CHECK: %[[ABSIM:.*]] = math.absf %[[IM]] : f32
+// CHECK: %[[MAX:.*]] = arith.maximumf %[[ABSRE]], %[[ABSIM]] : f32
+// CHECK: %[[MIN:.*]] = arith.minimumf %[[ABSRE]], %[[ABSIM]] : f32
 // CHECK: %[[RATIO:.*]] = arith.divf %[[MIN]], %[[MAX]] : f32
 // CHECK: %[[RATIO_SQ:.*]] = arith.mulf %[[RATIO]], %[[RATIO]] : f32
 // CHECK: %[[RATIO_SQ_PLUS_ONE:.*]] = arith.addf %[[RATIO_SQ]], %[[ONE]] : f32
-// CHECK: %[[SQRT:.*]] = math.sqrt %[[RATIO_SQ_PLUS_ONE]] : f32
-// CHECK: %[[ABS_OR_NAN:.*]] = arith.mulf %[[MAX]], %[[SQRT]] : f32
-// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[ABS_OR_NAN]], %[[ABS_OR_NAN]] : f32
-// CHECK: %[[ABS:.*]] = arith.select %[[IS_NAN]], %[[MIN]], %[[ABS_OR_NAN]] : f32
-// CHECK: %[[VAR23:.*]] = arith.addf %[[VAR2]], %[[ABS]] : f32
-// CHECK: %[[CST2:.*]]  = arith.constant 5.000000e-01 : f32
-// CHECK: %[[VAR24:.*]] = arith.mulf %[[VAR23]], %[[CST2]] : f32
-// CHECK: %[[VAR25:.*]] = math.sqrt %[[VAR24]] : f32
-// CHECK: %[[VAR26:.*]] = arith.cmpf olt, %[[VAR0]], %cst : f32
-// CHECK: %[[VAR27:.*]] = arith.cmpf olt, %[[VAR1]], %cst : f32
-// CHECK: %[[VAR28:.*]] = arith.addf %[[VAR25]], %[[VAR25]] : f32
-// CHECK: %[[VAR29:.*]] = arith.divf %[[VAR1]], %[[VAR28]] : f32
-// CHECK: %[[VAR30:.*]] = arith.negf %[[VAR25]] : f32
-// CHECK: %[[VAR31:.*]] = arith.select %[[VAR27]], %[[VAR30]], %[[VAR25]] : f32
-// CHECK: %[[VAR32:.*]] = arith.select %[[VAR26]], %[[VAR31]], %[[VAR29]] : f32
-// CHECK: %[[VAR33:.*]] = arith.addf %[[VAR32]], %[[VAR32]] : f32
-// CHECK: %[[VAR34:.*]] = arith.divf %[[VAR1]], %[[VAR33]] : f32
-// CHECK: %[[VAR35:.*]] = arith.select %[[VAR26]], %[[VAR34]], %[[VAR25]] : f32
-// CHECK: %[[VAR36:.*]] = arith.cmpf oeq, %[[VAR0]], %cst : f32
-// CHECK: %[[VAR37:.*]] = arith.cmpf oeq, %[[VAR1]], %cst : f32
-// CHECK: %[[VAR38:.*]] = arith.andi %[[VAR36]], %[[VAR37]] : i1
-// CHECK: %[[VAR39:.*]] = arith.select %[[VAR38]], %cst, %[[VAR35]] : f32
-// CHECK: %[[VAR40:.*]] = arith.select %[[VAR38]], %cst, %[[VAR32]] : f32
-// CHECK: %[[VAR41:.*]] = complex.create %[[VAR39]], %[[VAR40]] : complex<f32>
-// CHECK: return %[[VAR41]] : complex<f32>
+// CHECK: %[[QUARTER:.*]] = arith.constant 2.500000e-01 : f32
+// CHECK: %[[SQRT_MAX:.*]] = math.sqrt %[[MAX]] : f32
+// CHECK: %[[POW:.*]] = math.powf %[[RATIO_SQ_PLUS_ONE]], %[[QUARTER]] : f32
+// CHECK: %[[SQRT_ABS_OR_NAN:.*]] = arith.mulf %[[SQRT_MAX]], %[[POW]] : f32
+// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[SQRT_ABS_OR_NAN]], %[[SQRT_ABS_OR_NAN]] : f32
+// CHECK: %[[SQRT_ABS:.*]] = arith.select %[[IS_NAN]], %[[MIN]], %[[SQRT_ABS_OR_NAN]] : f32
+// CHECK: %[[ARGARG:.*]] = math.atan2 %[[IM]], %[[RE]] : f32
+// CHECK: %[[SQRTARG:.*]] = arith.mulf %[[ARGARG]], %[[HALF]] : f32
+// CHECK: %[[COS:.*]] = math.cos %[[SQRTARG]] : f32
+// CHECK: %[[SIN:.*]] = math.sin %[[SQRTARG]] : f32
+// CHECK: %[[SIN_ZERO:.*]] = arith.cmpf oeq, %[[SIN]], %[[ZERO]] : f32
+// CHECK: %[[RESULT_RE:.*]] = arith.mulf %[[SQRT_ABS]], %[[COS]] : f32
+// CHECK: %[[RESULT_IM:.*]] = arith.mulf %[[SQRT_ABS]], %[[SIN]] : f32
+// CHECK: %[[RESULT_IM2:.*]] = arith.select %[[SIN_ZERO]], %[[ZERO]], %[[RESULT_IM]] : f32
+// CHECK: %[[INF:.*]] = arith.constant 0x7F800000 : f32
+// CHECK: %[[NINF:.*]] = arith.constant 0xFF800000 : f32
+// CHECK: %[[NAN:.*]] = arith.constant 0x7FC00000 : f32
+// CHECK: %[[ABSIM:.*]] = math.absf %[[IM]] : f32
+// CHECK: %[[ABSIMINF:.*]] = arith.cmpf oeq, %[[ABSIM]], %[[INF]] : f32
+// CHECK: %[[ABSIMNOTINF:.*]] = arith.cmpf one, %[[ABSIM]], %[[INF]] : f32
+// CHECK: %[[REINF:.*]] = arith.cmpf oeq, %[[RE]], %[[INF]] : f32
+// CHECK: %[[RENINF:.*]] = arith.cmpf oeq, %[[RE]], %[[NINF]] : f32
+// CHECK: %[[RESULT_RE_ZERO:.*]] = arith.andi %[[RENINF]], %[[ABSIMNOTINF]] : i1
+// CHECK: %[[RESULT_RE2:.*]] = arith.select %[[RESULT_RE_ZERO]], %[[ZERO]], %[[RESULT_RE]] : f32
+// CHECK: %[[RESUL_IM_INF:.*]] = arith.ori %[[ABSIMINF]], %[[REINF]] : i1
+// CHECK: %[[RESULT_RE3:.*]] = arith.select %[[RESUL_IM_INF]], %[[INF]], %[[RESULT_RE2]] : f32
+// CHECK: %[[INF_IM_SIGN:.*]] = math.copysign %[[INF]], %[[IM]] : f32
+// CHECK: %[[RESULT_IM_NAN:.*]] = arith.cmpf uno, %[[SQRT_ABS]], %[[SQRT_ABS]] : f32
+// CHECK: %[[RESULT_IM3:.*]] = arith.select %[[RESULT_IM_NAN]], %[[NAN]], %[[RESULT_IM2]] : f32
+// CHECK: %[[RESULT_IM_INF:.*]] = arith.ori %[[ABSIMINF]], %[[RENINF]] : i1
+// CHECK: %[[RESULT_IM4:.*]] = arith.select %[[RESULT_IM_INF]], %[[INF_IM_SIGN]], %[[RESULT_IM3]] : f32
+// CHECK: %[[RESULT_ZERO:.*]] = arith.cmpf oeq, %[[SQRT_ABS]], %[[ZERO]] : f32
+// CHECK: %[[RESULT_RE4:.*]] = arith.select %[[RESULT_ZERO]], %[[ZERO]], %[[RESULT_RE3]] : f32
+// CHECK: %[[RESULT_IM5:.*]] = arith.select %[[RESULT_ZERO]], %[[ZERO]], %[[RESULT_IM4]] : f32
+// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_RE4]], %[[RESULT_IM5]] : complex<f32>
+// CHECK: return %[[RESULT]] : complex<f32>
+
+// -----
+
+// CHECK-LABEL: func @complex_sqrt_nnan_ninf
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+func.func @complex_sqrt_nnan_ninf(%arg: complex<f32>) -> complex<f32> {
+  %sqrt = complex.sqrt %arg fastmath<nnan,ninf> : complex<f32>
+  return %sqrt : complex<f32>
+}
+
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[HALF:.*]] = arith.constant 5.000000e-01 : f32
+// CHECK: %[[RE:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IM:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK: %[[ABSRE:.*]] = math.absf %[[RE]] fastmath<nnan,ninf> : f32
+// CHECK: %[[ABSIM:.*]] = math.absf %[[IM]] fastmath<nnan,ninf> : f32
+// CHECK: %[[MAX:.*]] = arith.maximumf %[[ABSRE]], %[[ABSIM]] fastmath<nnan,ninf> : f32
+// CHECK: %[[MIN:.*]] = arith.minimumf %[[ABSRE]], %[[ABSIM]] fastmath<nnan,ninf> : f32
+// CHECK: %[[RATIO:.*]] = arith.divf %[[MIN]], %[[MAX]] fastmath<nnan,ninf> : f32
+// CHECK: %[[RATIO_SQ:.*]] = arith.mulf %[[RATIO]], %[[RATIO]] fastmath<nnan,ninf> : f32
+// CHECK: %[[RATIO_SQ_PLUS_ONE:.*]] = arith.addf %[[RATIO_SQ]], %[[ONE]] fastmath<nnan,ninf> : f32
+// CHECK: %[[QUARTER:.*]] = arith.constant 2.500000e-01 : f32
+// CHECK: %[[SQRT_MAX:.*]] = math.sqrt %[[MAX]] fastmath<nnan,ninf> : f32
+// CHECK: %[[POW:.*]] = math.powf %[[RATIO_SQ_PLUS_ONE]], %[[QUARTER]] fastmath<nnan,ninf> : f32
+// CHECK: %[[SQRT_ABS_OR_NAN:.*]] = arith.mulf %[[SQRT_MAX]], %[[POW]] fastmath<nnan,ninf> : f32
+// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[SQRT_ABS_OR_NAN]], %[[SQRT_ABS_OR_NAN]] fastmath<nnan,ninf> : f32
+// CHECK: %[[SQRT_ABS:.*]] = arith.select %[[IS_NAN]], %[[MIN]], %[[SQRT_ABS_OR_NAN]] : f32
+// CHECK: %[[ARGARG:.*]] = math.atan2 %[[IM]], %[[RE]] fastmath<nnan,ninf> : f32
+// CHECK: %[[SQRTARG:.*]] = arith.mulf %[[ARGARG]], %[[HALF]] fastmath<nnan,ninf> : f32
+// CHECK: %[[COS:.*]] = math.cos %[[SQRTARG]] fastmath<nnan,ninf> : f32
+// CHECK: %[[SIN:.*]] = math.sin %[[SQRTARG]] fastmath<nnan,ninf> : f32
+// CHECK: %[[SIN_ZERO:.*]] = arith.cmpf oeq, %[[SIN]], %[[ZERO]] fastmath<nnan,ninf> : f32
+// CHECK: %[[RESULT_RE:.*]] = arith.mulf %[[SQRT_ABS]], %[[COS]] fastmath<nnan,ninf> : f32
+// CHECK: %[[RESULT_IM:.*]] = arith.mulf %[[SQRT_ABS]], %[[SIN]] fastmath<nnan,ninf> : f32
+// CHECK: %[[RESULT_IM2:.*]] = arith.select %[[SIN_ZERO]], %[[ZERO]], %[[RESULT_IM]] : f32
+// CHECK: %[[RESULT_ZERO:.*]] = arith.cmpf oeq, %[[SQRT_ABS]], %[[ZERO]] fastmath<nnan,ninf> : f32
+// CHECK: %[[RESULT_RE2:.*]] = arith.select %[[RESULT_ZERO]], %[[ZERO]], %[[RESULT_RE]] : f32
+// CHECK: %[[RESULT_IM3:.*]] = arith.select %[[RESULT_ZERO]], %[[ZERO]], %[[RESULT_IM2]] : f32
+// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_RE2]], %[[RESULT_IM3]] : complex<f32>
+// CHECK: return %[[RESULT]] : complex<f32>
 
 // -----
 
@@ -808,9 +858,9 @@ func.func @complex_abs_with_fmf(%arg: complex<f32>) -> f32 {
   %abs = complex.abs %arg fastmath<nnan,contract> : complex<f32>
   return %abs : f32
 }
-// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
 // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK: %[[ABS_REAL:.*]] = math.absf %[[REAL]] fastmath<nnan,contract> : f32
 // CHECK: %[[ABS_IMAG:.*]] = math.absf %[[IMAG]] fastmath<nnan,contract> : f32
 // CHECK: %[[MAX:.*]] = arith.maximumf %[[ABS_REAL]], %[[ABS_IMAG]] fastmath<nnan,contract> : f32
@@ -907,9 +957,9 @@ func.func @complex_log_with_fmf(%arg: complex<f32>) -> complex<f32> {
   %log = complex.log %arg fastmath<nnan,contract> : complex<f32>
   return %log : complex<f32>
 }
-// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
 // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK: %[[ABS_REAL:.*]] = math.absf %[[REAL]] fastmath<nnan,contract> : f32
 // CHECK: %[[ABS_IMAG:.*]] = math.absf %[[IMAG]] fastmath<nnan,contract> : f32
 // CHECK: %[[MAX:.*]] = arith.maximumf %[[ABS_REAL]], %[[ABS_IMAG]] fastmath<nnan,contract> : f32
@@ -1285,44 +1335,53 @@ func.func @complex_atan2_with_fmf(%lhs: complex<f32>,
 // CHECK: %[[VAR184:.*]] = complex.im %[[VAR179]] : complex<f32>
 // CHECK: %[[VAR185:.*]] = arith.addf %[[VAR183]], %[[VAR184]] fastmath<nnan,contract> : f32
 // CHECK: %[[VAR186:.*]] = complex.create %[[VAR182]], %[[VAR185]] : complex<f32>
-// CHECK: %[[CST_6:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAR187:.*]] = complex.re %[[VAR186]] : complex<f32>
-// CHECK: %[[VAR188:.*]] = complex.im %[[VAR186]] : complex<f32>
-// CHECK: %[[VAR189:.*]] = math.absf %[[VAR187]] fastmath<nnan,contract> : f32
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[HALF:.*]] = arith.constant 5.000000e-01 : f32
+// CHECK: %[[RE:.*]] = complex.re %[[VAR186]] : complex<f32>
+// CHECK: %[[IM:.*]] = complex.im %[[VAR186]] : complex<f32>
 // CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK: %[[REAL:.*]] = complex.re %[[VAR186]] : complex<f32>
-// CHECK: %[[IMAG:.*]] = complex.im %[[VAR186]] : complex<f32>
-// CHECK: %[[ABS_REAL:.*]] = math.absf %[[REAL]] fastmath<nnan,contract>  : f32
-// CHECK: %[[ABS_IMAG:.*]] = math.absf %[[IMAG]] fastmath<nnan,contract>  : f32
-// CHECK: %[[MAX:.*]] = arith.maximumf %[[ABS_REAL]], %[[ABS_IMAG]] fastmath<nnan,contract>  : f32
-// CHECK: %[[MIN:.*]] = arith.minimumf %[[ABS_REAL]], %[[ABS_IMAG]] fastmath<nnan,contract>  : f32
-// CHECK: %[[RATIO:.*]] = arith.divf %[[MIN]], %[[MAX]] fastmath<nnan,contract>  : f32
-// CHECK: %[[RATIO_SQ:.*]] = arith.mulf %[[RATIO]], %[[RATIO]] fastmath<nnan,contract>  : f32
-// CHECK: %[[RATIO_SQ_PLUS_ONE:.*]] = arith.addf %[[RATIO_SQ]], %[[ONE]] fastmath<nnan,contract>  : f32
-// CHECK: %[[SQRT:.*]] = math.sqrt %[[RATIO_SQ_PLUS_ONE]] fastmath<nnan,contract>  : f32
-// CHECK: %[[ABS_OR_NAN:.*]] = arith.mulf %[[MAX]], %[[SQRT]] fastmath<nnan,contract>  : f32
-// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[ABS_OR_NAN]], %[[ABS_OR_NAN]] fastmath<nnan,contract>  : f32
-// CHECK: %[[ABS:.*]] = arith.select %[[IS_NAN]], %[[MIN]], %[[ABS_OR_NAN]] : f32
-// CHECK: %[[VAR210:.*]] = arith.addf %[[VAR189]], %[[ABS]] fastmath<nnan,contract> : f32
-// CHECK: %[[CST_9:.*]] = arith.constant 5.000000e-01 : f32
-// CHECK: %[[VAR211:.*]] = arith.mulf %[[VAR210]], %[[CST_9]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR212:.*]] = math.sqrt %[[VAR211]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR213:.*]] = arith.cmpf olt, %[[VAR187]], %[[CST_6]] : f32
-// CHECK: %[[VAR214:.*]] = arith.cmpf olt, %[[VAR188]], %[[CST_6]] : f32
-// CHECK: %[[VAR215:.*]] = arith.addf %[[VAR212]], %[[VAR212]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR216:.*]] = arith.divf %[[VAR188]], %[[VAR215]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR217:.*]] = arith.negf %[[VAR212]] : f32
-// CHECK: %[[VAR218:.*]] = arith.select %[[VAR214]], %[[VAR217]], %[[VAR212]] : f32
-// CHECK: %[[VAR219:.*]] = arith.select %[[VAR213]], %[[VAR218]], %[[VAR216]] : f32
-// CHECK: %[[VAR220:.*]] = arith.addf %[[VAR219]], %[[VAR219]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR221:.*]] = arith.divf %[[VAR188]], %[[VAR220]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR222:.*]] = arith.select %[[VAR213]], %[[VAR221]], %[[VAR212]] : f32
-// CHECK: %[[VAR223:.*]] = arith.cmpf oeq, %[[VAR187]], %[[CST_6]] : f32
-// CHECK: %[[VAR224:.*]] = arith.cmpf oeq, %[[VAR188]], %[[CST_6]] : f32
-// CHECK: %[[VAR225:.*]] = arith.andi %[[VAR223]], %[[VAR224]] : i1
-// CHECK: %[[VAR226:.*]] = arith.select %[[VAR225]], %[[CST_6]], %[[VAR222]] : f32
-// CHECK: %[[VAR227:.*]] = arith.select %[[VAR225]], %[[CST_6]], %[[VAR219]] : f32
-// CHECK: %[[VAR228:.*]] = complex.create %[[VAR226]], %[[VAR227]] : complex<f32>
+// CHECK: %[[ABSRE:.*]] = math.absf %[[RE]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABSIM:.*]] = math.absf %[[IM]] fastmath<nnan,contract> : f32
+// CHECK: %[[MAX:.*]] = arith.maximumf %[[ABSRE]], %[[ABSIM]] fastmath<nnan,contract> : f32
+// CHECK: %[[MIN:.*]] = arith.minimumf %[[ABSRE]], %[[ABSIM]] fastmath<nnan,contract> : f32
+// CHECK: %[[RATIO:.*]] = arith.divf %[[MIN]], %[[MAX]] fastmath<nnan,contract> : f32
+// CHECK: %[[RATIO_SQ:.*]] = arith.mulf %[[RATIO]], %[[RATIO]] fastmath<nnan,contract> : f32
+// CHECK: %[[RATIO_SQ_PLUS_ONE:.*]] = arith.addf %[[RATIO_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[QUARTER:.*]] = arith.constant 2.500000e-01 : f32
+// CHECK: %[[SQRT_MAX:.*]] = math.sqrt %[[MAX]] fastmath<nnan,contract> : f32
+// CHECK: %[[POW:.*]] = math.powf %[[RATIO_SQ_PLUS_ONE]], %[[QUARTER]] fastmath<nnan,contract> : f32
+// CHECK: %[[SQRT_ABS_OR_NAN:.*]] = arith.mulf %[[SQRT_MAX]], %[[POW]] fastmath<nnan,contract> : f32
+// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[SQRT_ABS_OR_NAN]], %[[SQRT_ABS_OR_NAN]] fastmath<nnan,contract> : f32
+// CHECK: %[[SQRT_ABS:.*]] = arith.select %[[IS_NAN]], %[[MIN]], %[[SQRT_ABS_OR_NAN]] : f32
+// CHECK: %[[ARGARG:.*]] = math.atan2 %[[IM]], %[[RE]] fastmath<nnan,contract> : f32
+// CHECK: %[[SQRTARG:.*]] = arith.mulf %[[ARGARG]], %[[HALF]] fastmath<nnan,contract> : f32
+// CHECK: %[[COS:.*]] = math.cos %[[SQRTARG]] fastmath<nnan,contract> : f32
+// CHECK: %[[SIN:.*]] = math.sin %[[SQRTARG]] fastmath<nnan,contract> : f32
+// CHECK: %[[SIN_ZERO:.*]] = arith.cmpf oeq, %[[SIN]], %[[ZERO]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT_RE:.*]] = arith.mulf %[[SQRT_ABS]], %[[COS]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT_IM:.*]] = arith.mulf %[[SQRT_ABS]], %[[SIN]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT_IM2:.*]] = arith.select %[[SIN_ZERO]], %[[ZERO]], %[[RESULT_IM]] : f32
+// CHECK: %[[INF:.*]] = arith.constant 0x7F800000 : f32
+// CHECK: %[[NINF:.*]] = arith.constant 0xFF800000 : f32
+// CHECK: %[[NAN:.*]] = arith.constant 0x7FC00000 : f32
+// CHECK: %[[ABSIM:.*]] = math.absf %[[IM]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABSIMINF:.*]] = arith.cmpf oeq, %[[ABSIM]], %[[INF]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABSIMNOTINF:.*]] = arith.cmpf one, %[[ABSIM]], %[[INF]] fastmath<nnan,contract> : f32
+// CHECK: %[[REINF:.*]] = arith.cmpf oeq, %[[RE]], %[[INF]] fastmath<nnan,contract> : f32
+// CHECK: %[[RENINF:.*]] = arith.cmpf oeq, %[[RE]], %[[NINF]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT_RE_ZERO:.*]] = arith.andi %[[RENINF]], %[[ABSIMNOTINF]] : i1
+// CHECK: %[[RESULT_RE2:.*]] = arith.select %[[RESULT_RE_ZERO]], %[[ZERO]], %[[RESULT_RE]] : f32
+// CHECK: %[[RESUL_IM_INF:.*]] = arith.ori %[[ABSIMINF]], %[[REINF]] : i1
+// CHECK: %[[RESULT_RE3:.*]] = arith.select %[[RESUL_IM_INF]], %[[INF]], %[[RESULT_RE2]] : f32
+// CHECK: %[[INF_IM_SIGN:.*]] = math.copysign %[[INF]], %[[IM]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT_IM_NAN:.*]] = arith.cmpf uno, %[[SQRT_ABS]], %[[SQRT_ABS]] : f32
+// CHECK: %[[RESULT_IM3:.*]] = arith.select %[[RESULT_IM_NAN]], %[[NAN]], %[[RESULT_IM2]] : f32
+// CHECK: %[[RESULT_IM_INF:.*]] = arith.ori %[[ABSIMINF]], %[[RENINF]] : i1
+// CHECK: %[[RESULT_IM4:.*]] = arith.select %[[RESULT_IM_INF]], %[[INF_IM_SIGN]], %[[RESULT_IM3]] : f32
+// CHECK: %[[RESULT_ZERO:.*]] = arith.cmpf oeq, %[[SQRT_ABS]], %[[ZERO]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT_RE4:.*]] = arith.select %[[RESULT_ZERO]], %[[ZERO]], %[[RESULT_RE3]] : f32
+// CHECK: %[[RESULT_IM5:.*]] = arith.select %[[RESULT_ZERO]], %[[ZERO]], %[[RESULT_IM4]] : f32
+// CHECK: %[[VAR228:.*]] = complex.create %[[RESULT_RE4]], %[[RESULT_IM5]] : complex<f32>
 // CHECK: %[[CST_10:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK: %[[CST_11:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK: %[[VAR229:.*]] = complex.create %[[CST_10]], %[[CST_11]] : complex<f32>
@@ -1519,9 +1578,9 @@ func.func @complex_atan2_with_fmf(%lhs: complex<f32>,
 // CHECK: %[[VAR413:.*]] = arith.select %[[VAR412]], %[[VAR408]], %[[VAR402]] : f32
 // CHECK: %[[VAR414:.*]] = arith.select %[[VAR412]], %[[VAR409]], %[[VAR403]] : f32
 // CHECK: %[[VAR415:.*]] = complex.create %[[VAR413]], %[[VAR414]] : complex<f32>
-// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK: %[[REAL:.*]] = complex.re %[[VAR415]] : complex<f32>
 // CHECK: %[[IMAG:.*]] = complex.im %[[VAR415]] : complex<f32>
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK: %[[ABS_REAL:.*]] = math.absf %[[REAL]] fastmath<nnan,contract> : f32
 // CHECK: %[[ABS_IMAG:.*]] = math.absf %[[IMAG]] fastmath<nnan,contract> : f32
 // CHECK: %[[MAX:.*]] = arith.maximumf %[[ABS_REAL]], %[[ABS_IMAG]] fastmath<nnan,contract> : f32
@@ -1756,45 +1815,54 @@ func.func @complex_sqrt_with_fmf(%arg: complex<f32>) -> complex<f32> {
   return %sqrt : complex<f32>
 }
 
-// CHECK: %[[CST:.*]]  = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAR0:.*]] = complex.re %[[ARG]] : complex<f32>
-// CHECK: %[[VAR1:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK: %[[VAR2:.*]] = math.absf %[[VAR0]] fastmath<nnan,contract> : f32
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[HALF:.*]] = arith.constant 5.000000e-01 : f32
+// CHECK: %[[RE:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IM:.*]] = complex.im %[[ARG]] : complex<f32>
 // CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
-// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK: %[[ABS_REAL:.*]] = math.absf %[[REAL]] fastmath<nnan,contract> : f32
-// CHECK: %[[ABS_IMAG:.*]] = math.absf %[[IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[MAX:.*]] = arith.maximumf %[[ABS_REAL]], %[[ABS_IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[MIN:.*]] = arith.minimumf %[[ABS_REAL]], %[[ABS_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABSRE:.*]] = math.absf %[[RE]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABSIM:.*]] = math.absf %[[IM]] fastmath<nnan,contract> : f32
+// CHECK: %[[MAX:.*]] = arith.maximumf %[[ABSRE]], %[[ABSIM]] fastmath<nnan,contract> : f32
+// CHECK: %[[MIN:.*]] = arith.minimumf %[[ABSRE]], %[[ABSIM]] fastmath<nnan,contract> : f32
 // CHECK: %[[RATIO:.*]] = arith.divf %[[MIN]], %[[MAX]] fastmath<nnan,contract> : f32
 // CHECK: %[[RATIO_SQ:.*]] = arith.mulf %[[RATIO]], %[[RATIO]] fastmath<nnan,contract> : f32
 // CHECK: %[[RATIO_SQ_PLUS_ONE:.*]] = arith.addf %[[RATIO_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
-// CHECK: %[[SQRT:.*]] = math.sqrt %[[RATIO_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
-// CHECK: %[[ABS_OR_NAN:.*]] = arith.mulf %[[MAX]], %[[SQRT]] fastmath<nnan,contract> : f32
-// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[ABS_OR_NAN]], %[[ABS_OR_NAN]] fastmath<nnan,contract> : f32
-// CHECK: %[[ABS:.*]] = arith.select %[[IS_NAN]], %[[MIN]], %[[ABS_OR_NAN]] : f32
-// CHECK: %[[VAR23:.*]] = arith.addf %[[VAR2]], %[[ABS]] fastmath<nnan,contract> : f32
-// CHECK: %[[CST2:.*]]  = arith.constant 5.000000e-01 : f32
-// CHECK: %[[VAR24:.*]] = arith.mulf %[[VAR23]], %[[CST2]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR25:.*]] = math.sqrt %[[VAR24]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR26:.*]] = arith.cmpf olt, %[[VAR0]], %cst : f32
-// CHECK: %[[VAR27:.*]] = arith.cmpf olt, %[[VAR1]], %cst : f32
-// CHECK: %[[VAR28:.*]] = arith.addf %[[VAR25]], %[[VAR25]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR29:.*]] = arith.divf %[[VAR1]], %[[VAR28]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR30:.*]] = arith.negf %[[VAR25]] : f32
-// CHECK: %[[VAR31:.*]] = arith.select %[[VAR27]], %[[VAR30]], %[[VAR25]] : f32
-// CHECK: %[[VAR32:.*]] = arith.select %[[VAR26]], %[[VAR31]], %[[VAR29]] : f32
-// CHECK: %[[VAR33:.*]] = arith.addf %[[VAR32]], %[[VAR32]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR34:.*]] = arith.divf %[[VAR1]], %[[VAR33]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR35:.*]] = arith.select %[[VAR26]], %[[VAR34]], %[[VAR25]] : f32
-// CHECK: %[[VAR36:.*]] = arith.cmpf oeq, %[[VAR0]], %cst : f32
-// CHECK: %[[VAR37:.*]] = arith.cmpf oeq, %[[VAR1]], %cst : f32
-// CHECK: %[[VAR38:.*]] = arith.andi %[[VAR36]], %[[VAR37]] : i1
-// CHECK: %[[VAR39:.*]] = arith.select %[[VAR38]], %cst, %[[VAR35]] : f32
-// CHECK: %[[VAR40:.*]] = arith.select %[[VAR38]], %cst, %[[VAR32]] : f32
-// CHECK: %[[VAR41:.*]] = complex.create %[[VAR39]], %[[VAR40]] : complex<f32>
-// CHECK: return %[[VAR41]] : complex<f32>
+// CHECK: %[[QUARTER:.*]] = arith.constant 2.500000e-01 : f32
+// CHECK: %[[SQRT_MAX:.*]] = math.sqrt %[[MAX]] fastmath<nnan,contract> : f32
+// CHECK: %[[POW:.*]] = math.powf %[[RATIO_SQ_PLUS_ONE]], %[[QUARTER]] fastmath<nnan,contract> : f32
+// CHECK: %[[SQRT_ABS_OR_NAN:.*]] = arith.mulf %[[SQRT_MAX]], %[[POW]] fastmath<nnan,contract> : f32
+// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[SQRT_ABS_OR_NAN]], %[[SQRT_ABS_OR_NAN]] fastmath<nnan,contract> : f32
+// CHECK: %[[SQRT_ABS:.*]] = arith.select %[[IS_NAN]], %[[MIN]], %[[SQRT_ABS_OR_NAN]] : f32
+// CHECK: %[[ARGARG:.*]] = math.atan2 %[[IM]], %[[RE]] fastmath<nnan,contract> : f32
+// CHECK: %[[SQRTARG:.*]] = arith.mulf %[[ARGARG]], %[[HALF]] fastmath<nnan,contract> : f32
+// CHECK: %[[COS:.*]] = math.cos %[[SQRTARG]] fastmath<nnan,contract> : f32
+// CHECK: %[[SIN:.*]] = math.sin %[[SQRTARG]] fastmath<nnan,contract> : f32
+// CHECK: %[[SIN_ZERO:.*]] = arith.cmpf oeq, %[[SIN]], %[[ZERO]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT_RE:.*]] = arith.mulf %[[SQRT_ABS]], %[[COS]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT_IM:.*]] = arith.mulf %[[SQRT_ABS]], %[[SIN]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT_IM2:.*]] = arith.select %[[SIN_ZERO]], %[[ZERO]], %[[RESULT_IM]] : f32
+// CHECK: %[[INF:.*]] = arith.constant 0x7F800000 : f32
+// CHECK: %[[NINF:.*]] = arith.constant 0xFF800000 : f32
+// CHECK: %[[NAN:.*]] = arith.constant 0x7FC00000 : f32
+// CHECK: %[[ABSIM:.*]] = math.absf %[[IM]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABSIMINF:.*]] = arith.cmpf oeq, %[[ABSIM]], %[[INF]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABSIMNOTINF:.*]] = arith.cmpf one, %[[ABSIM]], %[[INF]] fastmath<nnan,contract> : f32
+// CHECK: %[[REINF:.*]] = arith.cmpf oeq, %[[RE]], %[[INF]] fastmath<nnan,contract> : f32
+// CHECK: %[[RENINF:.*]] = arith.cmpf oeq, %[[RE]], %[[NINF]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT_RE_ZERO:.*]] = arith.andi %[[RENINF]], %[[ABSIMNOTINF]] : i1
+// CHECK: %[[RESULT_RE2:.*]] = arith.select %[[RESULT_RE_ZERO]], %[[ZERO]], %[[RESULT_RE]] : f32
+// CHECK: %[[RESULT_RE_INF:.*]] = arith.ori %[[ABSIMINF]], %[[REINF]] : i1
+// CHECK: %[[RESULT_RE3:.*]] = arith.select %[[RESULT_RE_INF]], %[[INF]], %[[RESULT_RE2]] : f32
+// CHECK: %[[INF_IM_SIGN:.*]] = math.copysign %[[INF]], %[[IM]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT_IM_NAN:.*]] = arith.cmpf uno, %[[SQRT_ABS]], %[[SQRT_ABS]] : f32
+// CHECK: %[[RESULT_IM3:.*]] = arith.select %[[RESULT_IM_NAN]], %[[NAN]], %[[RESULT_IM2]] : f32
+// CHECK: %[[RESULT_IM_INF:.*]] = arith.ori %[[ABSIMINF]], %[[RENINF]] : i1
+// CHECK: %[[RESULT_IM4:.*]] = arith.select %[[RESULT_IM_INF]], %[[INF_IM_SIGN]], %[[RESULT_IM3]] : f32
+// CHECK: %[[RESULT_ZERO:.*]] = arith.cmpf oeq, %[[SQRT_ABS]], %[[ZERO]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT_RE4:.*]] = arith.select %[[RESULT_ZERO]], %[[ZERO]], %[[RESULT_RE3]] : f32
+// CHECK: %[[RESULT_IM5:.*]] = arith.select %[[RESULT_ZERO]], %[[ZERO]], %[[RESULT_IM4]] : f32
+// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_RE4]], %[[RESULT_IM5]] : complex<f32>
+// CHECK: return %[[RESULT]] : complex<f32>
 
 // -----
 
@@ -1857,9 +1925,9 @@ func.func @complex_sign_with_fmf(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[REAL_IS_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
 // CHECK: %[[IMAG_IS_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
 // CHECK: %[[IS_ZERO:.*]] = arith.andi %[[REAL_IS_ZERO]], %[[IMAG_IS_ZERO]] : i1
-// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32>
 // CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK: %[[ABS_REAL:.*]] = math.absf %[[REAL2]] fastmath<nnan,contract> : f32
 // CHECK: %[[ABS_IMAG:.*]] = math.absf %[[IMAG2]] fastmath<nnan,contract> : f32
 // CHECK: %[[MAX:.*]] = arith.maximumf %[[ABS_REAL]], %[[ABS_IMAG]] fastmath<nnan,contract> : f32
diff --git a/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir b/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir
index 2649d004a76acc..110a78631fb95a 100644
--- a/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir
@@ -6,10 +6,10 @@ func.func @complex_abs(%arg: complex<f32>) -> f32 {
   %abs = complex.abs %arg: complex<f32>
   return %abs : f32
 }
-// CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
 // CHECK: %[[REAL:.*]] = llvm.extractvalue %[[ARG]][0] : ![[C_TY]]
 // CHECK: %[[IMAG:.*]] = llvm.extractvalue %[[ARG]][1] : ![[C_TY]]
 
+// CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
 // CHECK: %[[ABS_REAL:.*]] = llvm.intr.fabs(%[[REAL]]) : (f32) -> f32
 // CHECK: %[[ABS_IMAG:.*]] = llvm.intr.fabs(%[[IMAG]]) : (f32) -> f32
 // CHECK: %[[MAX:.*]] = llvm.intr.maximum(%[[ABS_REAL]], %[[ABS_IMAG]]) : (f32, f32) -> f32



More information about the Mlir-commits mailing list