[Mlir-commits] [mlir] Fix overflows in complex sqrt lowering. (PR #88480)
Johannes Reifferscheid
llvmlistbot at llvm.org
Fri Apr 12 00:16:30 PDT 2024
https://github.com/jreiffers created https://github.com/llvm/llvm-project/pull/88480
This ports XLA's complex sqrt lowering. The accuracy was tested with its exhaustive_unary_test_complex test.
Note: rsqrt is still broken.
From 61f46dc7df2d8241a9047ffe1789ecdb83ce859c Mon Sep 17 00:00:00 2001
From: Johannes Reifferscheid <jreiffers at google.com>
Date: Fri, 12 Apr 2024 09:13:33 +0200
Subject: [PATCH] Fix overflows in complex sqrt lowering.
This ports XLA's complex sqrt lowering. The accuracy was tested with its
exhaustive_unary_test_complex test.
Note: rsqrt is still broken.
---
.../ComplexToStandard/ComplexToStandard.cpp | 166 ++++++----
.../convert-to-standard.mlir | 296 +++++++++++-------
.../ComplexToStandard/full-conversion.mlir | 2 +-
3 files changed, 280 insertions(+), 184 deletions(-)
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 9c82e8105f06e5..0664b053fc9e67 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -27,35 +27,52 @@ using namespace mlir;
namespace {
+// Returns |real + imag*i|, or its square root if `returnSqrt` is true.
+Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf,
+ ImplicitLocOpBuilder &b, bool returnSqrt = false) {
+ Value one = b.create<arith::ConstantOp>(real.getType(),
+ b.getFloatAttr(real.getType(), 1.0));
+
+ Value absReal = b.create<math::AbsFOp>(real, fmf);
+ Value absImag = b.create<math::AbsFOp>(imag, fmf);
+
+ Value max = b.create<arith::MaximumFOp>(absReal, absImag, fmf);
+ Value min = b.create<arith::MinimumFOp>(absReal, absImag, fmf);
+ Value ratio = b.create<arith::DivFOp>(min, max, fmf);
+ Value ratioSq = b.create<arith::MulFOp>(ratio, ratio, fmf);
+ Value ratioSqPlusOne = b.create<arith::AddFOp>(ratioSq, one, fmf);
+ Value result;
+
+ if (returnSqrt) {
+ Value quarter = b.create<arith::ConstantOp>(
+ real.getType(), b.getFloatAttr(real.getType(), 0.25));
+ // sqrt(sqrt(a*b)) would avoid the pow, but will overflow more easily.
+ Value sqrt = b.create<math::SqrtOp>(max, fmf);
+ Value p025 = b.create<math::PowFOp>(ratioSqPlusOne, quarter, fmf);
+ result = b.create<arith::MulFOp>(sqrt, p025, fmf);
+ } else {
+ Value sqrt = b.create<math::SqrtOp>(ratioSqPlusOne, fmf);
+ result = b.create<arith::MulFOp>(max, sqrt, fmf);
+ }
+
+ Value isNaN =
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, result, result, fmf);
+ return b.create<arith::SelectOp>(isNaN, min, result);
+}
+
struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+ ImplicitLocOpBuilder b(op.getLoc(), rewriter);
arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
- Type elementType = op.getType();
- Value one = b.create<arith::ConstantOp>(elementType,
- b.getFloatAttr(elementType, 1.0));
-
Value real = b.create<complex::ReOp>(adaptor.getComplex());
Value imag = b.create<complex::ImOp>(adaptor.getComplex());
- Value absReal = b.create<math::AbsFOp>(real, fmf);
- Value absImag = b.create<math::AbsFOp>(imag, fmf);
-
- Value max = b.create<arith::MaximumFOp>(absReal, absImag, fmf);
- Value min = b.create<arith::MinimumFOp>(absReal, absImag, fmf);
- Value ratio = b.create<arith::DivFOp>(min, max, fmf);
- Value ratioSq = b.create<arith::MulFOp>(ratio, ratio, fmf);
- Value ratioSqPlusOne = b.create<arith::AddFOp>(ratioSq, one, fmf);
- Value sqrt = b.create<math::SqrtOp>(ratioSqPlusOne, fmf);
- Value result = b.create<arith::MulFOp>(max, sqrt, fmf);
- Value isNaN =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, result, result, fmf);
- rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, min, result);
+ rewriter.replaceOp(op, computeAbs(real, imag, fmf, b));
return success();
}
@@ -829,60 +846,71 @@ struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
LogicalResult
matchAndRewrite(complex::SqrtOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+ ImplicitLocOpBuilder b(op.getLoc(), rewriter);
auto type = cast<ComplexType>(op.getType());
- Type elementType = type.getElementType();
- Value arg = adaptor.getComplex();
- arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
-
- Value zero =
- b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
-
- Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
- Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
-
- Value absLhs = b.create<math::AbsFOp>(real, fmf);
- Value absArg = b.create<complex::AbsOp>(elementType, arg, fmf);
- Value addAbs = b.create<arith::AddFOp>(absLhs, absArg, fmf);
+ auto elementType = type.getElementType().cast<FloatType>();
+ arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
+ auto cst = [&](APFloat v) {
+ return b.create<arith::ConstantOp>(elementType,
+ b.getFloatAttr(elementType, v));
+ };
+ const auto &floatSemantics = elementType.getFloatSemantics();
+ Value zero = cst(APFloat::getZero(floatSemantics));
Value half = b.create<arith::ConstantOp>(elementType,
b.getFloatAttr(elementType, 0.5));
- Value halfAddAbs = b.create<arith::MulFOp>(addAbs, half, fmf);
- Value sqrtAddAbs = b.create<math::SqrtOp>(halfAddAbs, fmf);
-
- Value realIsNegative =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, real, zero);
- Value imagIsNegative =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, imag, zero);
-
- Value resultReal = sqrtAddAbs;
-
- Value imagDivTwoResultReal = b.create<arith::DivFOp>(
- imag, b.create<arith::AddFOp>(resultReal, resultReal, fmf), fmf);
-
- Value negativeResultReal = b.create<arith::NegFOp>(resultReal);
+ Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
+ Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
+ Value absSqrt = computeAbs(real, imag, fmf, b, /*returnSqrt=*/true);
+ Value argArg = b.create<math::Atan2Op>(imag, real, fmf);
+ Value sqrtArg = b.create<arith::MulFOp>(argArg, half, fmf);
+ Value cos = b.create<math::CosOp>(sqrtArg, fmf);
+ Value sin = b.create<math::SinOp>(sqrtArg, fmf);
+ // sin(atan2(0, inf)) = 0, sqrt(abs(inf)) = inf, but we can't multiply
+ // 0 * inf.
+ Value sinIsZero =
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, sin, zero, fmf);
+
+ Value resultReal = b.create<arith::MulFOp>(absSqrt, cos, fmf);
Value resultImag = b.create<arith::SelectOp>(
- realIsNegative,
- b.create<arith::SelectOp>(imagIsNegative, negativeResultReal,
- resultReal),
- imagDivTwoResultReal);
-
- resultReal = b.create<arith::SelectOp>(
- realIsNegative,
- b.create<arith::DivFOp>(
- imag, b.create<arith::AddFOp>(resultImag, resultImag, fmf), fmf),
- resultReal);
-
- Value realIsZero =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
- Value imagIsZero =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
- Value argIsZero = b.create<arith::AndIOp>(realIsZero, imagIsZero);
-
- resultReal = b.create<arith::SelectOp>(argIsZero, zero, resultReal);
- resultImag = b.create<arith::SelectOp>(argIsZero, zero, resultImag);
+ sinIsZero, zero, b.create<arith::MulFOp>(absSqrt, sin, fmf));
+ if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
+ arith::FastMathFlags::ninf)) {
+ Value inf = cst(APFloat::getInf(floatSemantics));
+ Value negInf = cst(APFloat::getInf(floatSemantics, true));
+ Value nan = cst(APFloat::getNaN(floatSemantics));
+ Value absImag = b.create<math::AbsFOp>(elementType, imag, fmf);
+
+ Value absImagIsInf =
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf);
+ Value absImagIsNotInf =
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, absImag, inf, fmf);
+ Value realIsInf =
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, inf, fmf);
+ Value realIsNegInf =
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, negInf, fmf);
+
+ resultReal = b.create<arith::SelectOp>(
+ b.create<arith::AndIOp>(realIsNegInf, absImagIsNotInf), zero,
+ resultReal);
+ resultReal = b.create<arith::SelectOp>(
+ b.create<arith::OrIOp>(absImagIsInf, realIsInf), inf, resultReal);
+
+ Value imagSignInf = b.create<math::CopySignOp>(inf, imag, fmf);
+ resultImag = b.create<arith::SelectOp>(
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, absSqrt, absSqrt),
+ nan, resultImag);
+ resultImag = b.create<arith::SelectOp>(
+ b.create<arith::OrIOp>(absImagIsInf, realIsNegInf), imagSignInf,
+ resultImag);
+ }
+
+ Value resultIsZero =
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absSqrt, zero, fmf);
+ resultReal = b.create<arith::SelectOp>(resultIsZero, zero, resultReal);
+ resultImag = b.create<arith::SelectOp>(resultIsZero, zero, resultImag);
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
resultImag);
@@ -1065,7 +1093,7 @@ static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder,
// Case 2:
// 1^(c + d*i) = 1 + 0*i
Value lhsEqOne = builder.create<arith::AndIOp>(
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, one),
+ builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, one, fmf),
bEqZero);
Value cutoff2 =
builder.create<arith::SelectOp>(lhsEqOne, complexOne, cutoff1);
@@ -1073,11 +1101,11 @@ static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder,
// Case 3:
// inf^(c + 0*i) = inf + 0*i, c > 0
Value lhsEqInf = builder.create<arith::AndIOp>(
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, inf),
+ builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, inf, fmf),
bEqZero);
Value rhsGt0 = builder.create<arith::AndIOp>(
dEqZero,
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, c, zero));
+ builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, c, zero, fmf));
Value cutoff3 = builder.create<arith::SelectOp>(
builder.create<arith::AndIOp>(lhsEqInf, rhsGt0), complexInf, cutoff2);
@@ -1085,7 +1113,7 @@ static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder,
// inf^(c + 0*i) = 0 + 0*i, c < 0
Value rhsLt0 = builder.create<arith::AndIOp>(
dEqZero,
- builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, c, zero));
+ builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, c, zero, fmf));
Value cutoff4 = builder.create<arith::SelectOp>(
builder.create<arith::AndIOp>(lhsEqInf, rhsLt0), complexZero, cutoff3);
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 8d2fb09daa87b6..b22c1acacaea18 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -8,9 +8,9 @@ func.func @complex_abs(%arg: complex<f32>) -> f32 {
return %abs : f32
}
-// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[ABS_REAL:.*]] = math.absf %[[REAL]] : f32
// CHECK: %[[ABS_IMAG:.*]] = math.absf %[[IMAG]] : f32
// CHECK: %[[MAX:.*]] = arith.maximumf %[[ABS_REAL]], %[[ABS_IMAG]] : f32
@@ -250,9 +250,9 @@ func.func @complex_log(%arg: complex<f32>) -> complex<f32> {
%log = complex.log %arg: complex<f32>
return %log : complex<f32>
}
-// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[ABS_REAL:.*]] = math.absf %[[REAL]] : f32
// CHECK: %[[ABS_IMAG:.*]] = math.absf %[[IMAG]] : f32
// CHECK: %[[MAX:.*]] = arith.maximumf %[[ABS_REAL]], %[[ABS_IMAG]] : f32
@@ -493,9 +493,9 @@ func.func @complex_sign(%arg: complex<f32>) -> complex<f32> {
// CHECK: %[[REAL_IS_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
// CHECK: %[[IMAG_IS_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
// CHECK: %[[IS_ZERO:.*]] = arith.andi %[[REAL_IS_ZERO]], %[[IMAG_IS_ZERO]] : i1
-// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32>
// CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[ABS_REAL:.*]] = math.absf %[[REAL2]] : f32
// CHECK: %[[ABS_IMAG:.*]] = math.absf %[[IMAG2]] : f32
// CHECK: %[[MAX:.*]] = arith.maximumf %[[ABS_REAL]], %[[ABS_IMAG]] : f32
@@ -697,45 +697,95 @@ func.func @complex_sqrt(%arg: complex<f32>) -> complex<f32> {
return %sqrt : complex<f32>
}
-// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAR0:.*]] = complex.re %[[ARG]] : complex<f32>
-// CHECK: %[[VAR1:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK: %[[VAR2:.*]] = math.absf %[[VAR0]] : f32
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[HALF:.*]] = arith.constant 5.000000e-01 : f32
+// CHECK: %[[RE:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IM:.*]] = complex.im %[[ARG]] : complex<f32>
// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
-// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK: %[[ABS_REAL:.*]] = math.absf %[[REAL]] : f32
-// CHECK: %[[ABS_IMAG:.*]] = math.absf %[[IMAG]] : f32
-// CHECK: %[[MAX:.*]] = arith.maximumf %[[ABS_REAL]], %[[ABS_IMAG]] : f32
-// CHECK: %[[MIN:.*]] = arith.minimumf %[[ABS_REAL]], %[[ABS_IMAG]] : f32
+// CHECK: %[[ABSRE:.*]] = math.absf %[[RE]] : f32
+// CHECK: %[[ABSIM:.*]] = math.absf %[[IM]] : f32
+// CHECK: %[[MAX:.*]] = arith.maximumf %[[ABSRE]], %[[ABSIM]] : f32
+// CHECK: %[[MIN:.*]] = arith.minimumf %[[ABSRE]], %[[ABSIM]] : f32
// CHECK: %[[RATIO:.*]] = arith.divf %[[MIN]], %[[MAX]] : f32
// CHECK: %[[RATIO_SQ:.*]] = arith.mulf %[[RATIO]], %[[RATIO]] : f32
// CHECK: %[[RATIO_SQ_PLUS_ONE:.*]] = arith.addf %[[RATIO_SQ]], %[[ONE]] : f32
-// CHECK: %[[SQRT:.*]] = math.sqrt %[[RATIO_SQ_PLUS_ONE]] : f32
-// CHECK: %[[ABS_OR_NAN:.*]] = arith.mulf %[[MAX]], %[[SQRT]] : f32
-// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[ABS_OR_NAN]], %[[ABS_OR_NAN]] : f32
-// CHECK: %[[ABS:.*]] = arith.select %[[IS_NAN]], %[[MIN]], %[[ABS_OR_NAN]] : f32
-// CHECK: %[[VAR23:.*]] = arith.addf %[[VAR2]], %[[ABS]] : f32
-// CHECK: %[[CST2:.*]] = arith.constant 5.000000e-01 : f32
-// CHECK: %[[VAR24:.*]] = arith.mulf %[[VAR23]], %[[CST2]] : f32
-// CHECK: %[[VAR25:.*]] = math.sqrt %[[VAR24]] : f32
-// CHECK: %[[VAR26:.*]] = arith.cmpf olt, %[[VAR0]], %cst : f32
-// CHECK: %[[VAR27:.*]] = arith.cmpf olt, %[[VAR1]], %cst : f32
-// CHECK: %[[VAR28:.*]] = arith.addf %[[VAR25]], %[[VAR25]] : f32
-// CHECK: %[[VAR29:.*]] = arith.divf %[[VAR1]], %[[VAR28]] : f32
-// CHECK: %[[VAR30:.*]] = arith.negf %[[VAR25]] : f32
-// CHECK: %[[VAR31:.*]] = arith.select %[[VAR27]], %[[VAR30]], %[[VAR25]] : f32
-// CHECK: %[[VAR32:.*]] = arith.select %[[VAR26]], %[[VAR31]], %[[VAR29]] : f32
-// CHECK: %[[VAR33:.*]] = arith.addf %[[VAR32]], %[[VAR32]] : f32
-// CHECK: %[[VAR34:.*]] = arith.divf %[[VAR1]], %[[VAR33]] : f32
-// CHECK: %[[VAR35:.*]] = arith.select %[[VAR26]], %[[VAR34]], %[[VAR25]] : f32
-// CHECK: %[[VAR36:.*]] = arith.cmpf oeq, %[[VAR0]], %cst : f32
-// CHECK: %[[VAR37:.*]] = arith.cmpf oeq, %[[VAR1]], %cst : f32
-// CHECK: %[[VAR38:.*]] = arith.andi %[[VAR36]], %[[VAR37]] : i1
-// CHECK: %[[VAR39:.*]] = arith.select %[[VAR38]], %cst, %[[VAR35]] : f32
-// CHECK: %[[VAR40:.*]] = arith.select %[[VAR38]], %cst, %[[VAR32]] : f32
-// CHECK: %[[VAR41:.*]] = complex.create %[[VAR39]], %[[VAR40]] : complex<f32>
-// CHECK: return %[[VAR41]] : complex<f32>
+// CHECK: %[[QUARTER:.*]] = arith.constant 2.500000e-01 : f32
+// CHECK: %[[SQRT_MAX:.*]] = math.sqrt %[[MAX]] : f32
+// CHECK: %[[POW:.*]] = math.powf %[[RATIO_SQ_PLUS_ONE]], %[[QUARTER]] : f32
+// CHECK: %[[SQRT_ABS_OR_NAN:.*]] = arith.mulf %[[SQRT_MAX]], %[[POW]] : f32
+// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[SQRT_ABS_OR_NAN]], %[[SQRT_ABS_OR_NAN]] : f32
+// CHECK: %[[SQRT_ABS:.*]] = arith.select %[[IS_NAN]], %[[MIN]], %[[SQRT_ABS_OR_NAN]] : f32
+// CHECK: %[[ARGARG:.*]] = math.atan2 %[[IM]], %[[RE]] : f32
+// CHECK: %[[SQRTARG:.*]] = arith.mulf %[[ARGARG]], %[[HALF]] : f32
+// CHECK: %[[COS:.*]] = math.cos %[[SQRTARG]] : f32
+// CHECK: %[[SIN:.*]] = math.sin %[[SQRTARG]] : f32
+// CHECK: %[[SIN_ZERO:.*]] = arith.cmpf oeq, %[[SIN]], %[[ZERO]] : f32
+// CHECK: %[[RESULT_RE:.*]] = arith.mulf %[[SQRT_ABS]], %[[COS]] : f32
+// CHECK: %[[RESULT_IM:.*]] = arith.mulf %[[SQRT_ABS]], %[[SIN]] : f32
+// CHECK: %[[RESULT_IM2:.*]] = arith.select %[[SIN_ZERO]], %[[ZERO]], %[[RESULT_IM]] : f32
+// CHECK: %[[INF:.*]] = arith.constant 0x7F800000 : f32
+// CHECK: %[[NINF:.*]] = arith.constant 0xFF800000 : f32
+// CHECK: %[[NAN:.*]] = arith.constant 0x7FC00000 : f32
+// CHECK: %[[ABSIM:.*]] = math.absf %[[IM]] : f32
+// CHECK: %[[ABSIMINF:.*]] = arith.cmpf oeq, %[[ABSIM]], %[[INF]] : f32
+// CHECK: %[[ABSIMNOTINF:.*]] = arith.cmpf one, %[[ABSIM]], %[[INF]] : f32
+// CHECK: %[[REINF:.*]] = arith.cmpf oeq, %[[RE]], %[[INF]] : f32
+// CHECK: %[[RENINF:.*]] = arith.cmpf oeq, %[[RE]], %[[NINF]] : f32
+// CHECK: %[[RESULT_RE_ZERO:.*]] = arith.andi %[[RENINF]], %[[ABSIMNOTINF]] : i1
+// CHECK: %[[RESULT_RE2:.*]] = arith.select %[[RESULT_RE_ZERO]], %[[ZERO]], %[[RESULT_RE]] : f32
+// CHECK: %[[RESULT_RE_INF:.*]] = arith.ori %[[ABSIMINF]], %[[REINF]] : i1
+// CHECK: %[[RESULT_RE3:.*]] = arith.select %[[RESULT_RE_INF]], %[[INF]], %[[RESULT_RE2]] : f32
+// CHECK: %[[INF_IM_SIGN:.*]] = math.copysign %[[INF]], %[[IM]] : f32
+// CHECK: %[[RESULT_IM_NAN:.*]] = arith.cmpf uno, %[[SQRT_ABS]], %[[SQRT_ABS]] : f32
+// CHECK: %[[RESULT_IM3:.*]] = arith.select %[[RESULT_IM_NAN]], %[[NAN]], %[[RESULT_IM2]] : f32
+// CHECK: %[[RESULT_IM_INF:.*]] = arith.ori %[[ABSIMINF]], %[[RENINF]] : i1
+// CHECK: %[[RESULT_IM4:.*]] = arith.select %[[RESULT_IM_INF]], %[[INF_IM_SIGN]], %[[RESULT_IM3]] : f32
+// CHECK: %[[RESULT_ZERO:.*]] = arith.cmpf oeq, %[[SQRT_ABS]], %[[ZERO]] : f32
+// CHECK: %[[RESULT_RE4:.*]] = arith.select %[[RESULT_ZERO]], %[[ZERO]], %[[RESULT_RE3]] : f32
+// CHECK: %[[RESULT_IM5:.*]] = arith.select %[[RESULT_ZERO]], %[[ZERO]], %[[RESULT_IM4]] : f32
+// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_RE4]], %[[RESULT_IM5]] : complex<f32>
+// CHECK: return %[[RESULT]] : complex<f32>
+
+// -----
+
+// CHECK-LABEL: func @complex_sqrt_nnan_ninf
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+func.func @complex_sqrt_nnan_ninf(%arg: complex<f32>) -> complex<f32> {
+ %sqrt = complex.sqrt %arg fastmath<nnan,ninf> : complex<f32>
+ return %sqrt : complex<f32>
+}
+
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[HALF:.*]] = arith.constant 5.000000e-01 : f32
+// CHECK: %[[RE:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IM:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK: %[[ABSRE:.*]] = math.absf %[[RE]] fastmath<nnan,ninf> : f32
+// CHECK: %[[ABSIM:.*]] = math.absf %[[IM]] fastmath<nnan,ninf> : f32
+// CHECK: %[[MAX:.*]] = arith.maximumf %[[ABSRE]], %[[ABSIM]] fastmath<nnan,ninf> : f32
+// CHECK: %[[MIN:.*]] = arith.minimumf %[[ABSRE]], %[[ABSIM]] fastmath<nnan,ninf> : f32
+// CHECK: %[[RATIO:.*]] = arith.divf %[[MIN]], %[[MAX]] fastmath<nnan,ninf> : f32
+// CHECK: %[[RATIO_SQ:.*]] = arith.mulf %[[RATIO]], %[[RATIO]] fastmath<nnan,ninf> : f32
+// CHECK: %[[RATIO_SQ_PLUS_ONE:.*]] = arith.addf %[[RATIO_SQ]], %[[ONE]] fastmath<nnan,ninf> : f32
+// CHECK: %[[QUARTER:.*]] = arith.constant 2.500000e-01 : f32
+// CHECK: %[[SQRT_MAX:.*]] = math.sqrt %[[MAX]] fastmath<nnan,ninf> : f32
+// CHECK: %[[POW:.*]] = math.powf %[[RATIO_SQ_PLUS_ONE]], %[[QUARTER]] fastmath<nnan,ninf> : f32
+// CHECK: %[[SQRT_ABS_OR_NAN:.*]] = arith.mulf %[[SQRT_MAX]], %[[POW]] fastmath<nnan,ninf> : f32
+// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[SQRT_ABS_OR_NAN]], %[[SQRT_ABS_OR_NAN]] fastmath<nnan,ninf> : f32
+// CHECK: %[[SQRT_ABS:.*]] = arith.select %[[IS_NAN]], %[[MIN]], %[[SQRT_ABS_OR_NAN]] : f32
+// CHECK: %[[ARGARG:.*]] = math.atan2 %[[IM]], %[[RE]] fastmath<nnan,ninf> : f32
+// CHECK: %[[SQRTARG:.*]] = arith.mulf %[[ARGARG]], %[[HALF]] fastmath<nnan,ninf> : f32
+// CHECK: %[[COS:.*]] = math.cos %[[SQRTARG]] fastmath<nnan,ninf> : f32
+// CHECK: %[[SIN:.*]] = math.sin %[[SQRTARG]] fastmath<nnan,ninf> : f32
+// CHECK: %[[SIN_ZERO:.*]] = arith.cmpf oeq, %[[SIN]], %[[ZERO]] fastmath<nnan,ninf> : f32
+// CHECK: %[[RESULT_RE:.*]] = arith.mulf %[[SQRT_ABS]], %[[COS]] fastmath<nnan,ninf> : f32
+// CHECK: %[[RESULT_IM:.*]] = arith.mulf %[[SQRT_ABS]], %[[SIN]] fastmath<nnan,ninf> : f32
+// CHECK: %[[RESULT_IM2:.*]] = arith.select %[[SIN_ZERO]], %[[ZERO]], %[[RESULT_IM]] : f32
+// CHECK: %[[RESULT_ZERO:.*]] = arith.cmpf oeq, %[[SQRT_ABS]], %[[ZERO]] fastmath<nnan,ninf> : f32
+// CHECK: %[[RESULT_RE2:.*]] = arith.select %[[RESULT_ZERO]], %[[ZERO]], %[[RESULT_RE]] : f32
+// CHECK: %[[RESULT_IM3:.*]] = arith.select %[[RESULT_ZERO]], %[[ZERO]], %[[RESULT_IM2]] : f32
+// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_RE2]], %[[RESULT_IM3]] : complex<f32>
+// CHECK: return %[[RESULT]] : complex<f32>
// -----
@@ -808,9 +858,9 @@ func.func @complex_abs_with_fmf(%arg: complex<f32>) -> f32 {
%abs = complex.abs %arg fastmath<nnan,contract> : complex<f32>
return %abs : f32
}
-// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[ABS_REAL:.*]] = math.absf %[[REAL]] fastmath<nnan,contract> : f32
// CHECK: %[[ABS_IMAG:.*]] = math.absf %[[IMAG]] fastmath<nnan,contract> : f32
// CHECK: %[[MAX:.*]] = arith.maximumf %[[ABS_REAL]], %[[ABS_IMAG]] fastmath<nnan,contract> : f32
@@ -907,9 +957,9 @@ func.func @complex_log_with_fmf(%arg: complex<f32>) -> complex<f32> {
%log = complex.log %arg fastmath<nnan,contract> : complex<f32>
return %log : complex<f32>
}
-// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[ABS_REAL:.*]] = math.absf %[[REAL]] fastmath<nnan,contract> : f32
// CHECK: %[[ABS_IMAG:.*]] = math.absf %[[IMAG]] fastmath<nnan,contract> : f32
// CHECK: %[[MAX:.*]] = arith.maximumf %[[ABS_REAL]], %[[ABS_IMAG]] fastmath<nnan,contract> : f32
@@ -1285,44 +1335,53 @@ func.func @complex_atan2_with_fmf(%lhs: complex<f32>,
// CHECK: %[[VAR184:.*]] = complex.im %[[VAR179]] : complex<f32>
// CHECK: %[[VAR185:.*]] = arith.addf %[[VAR183]], %[[VAR184]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR186:.*]] = complex.create %[[VAR182]], %[[VAR185]] : complex<f32>
-// CHECK: %[[CST_6:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAR187:.*]] = complex.re %[[VAR186]] : complex<f32>
-// CHECK: %[[VAR188:.*]] = complex.im %[[VAR186]] : complex<f32>
-// CHECK: %[[VAR189:.*]] = math.absf %[[VAR187]] fastmath<nnan,contract> : f32
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[HALF:.*]] = arith.constant 5.000000e-01 : f32
+// CHECK: %[[RE:.*]] = complex.re %[[VAR186]] : complex<f32>
+// CHECK: %[[IM:.*]] = complex.im %[[VAR186]] : complex<f32>
// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK: %[[REAL:.*]] = complex.re %[[VAR186]] : complex<f32>
-// CHECK: %[[IMAG:.*]] = complex.im %[[VAR186]] : complex<f32>
-// CHECK: %[[ABS_REAL:.*]] = math.absf %[[REAL]] fastmath<nnan,contract> : f32
-// CHECK: %[[ABS_IMAG:.*]] = math.absf %[[IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[MAX:.*]] = arith.maximumf %[[ABS_REAL]], %[[ABS_IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[MIN:.*]] = arith.minimumf %[[ABS_REAL]], %[[ABS_IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[RATIO:.*]] = arith.divf %[[MIN]], %[[MAX]] fastmath<nnan,contract> : f32
-// CHECK: %[[RATIO_SQ:.*]] = arith.mulf %[[RATIO]], %[[RATIO]] fastmath<nnan,contract> : f32
-// CHECK: %[[RATIO_SQ_PLUS_ONE:.*]] = arith.addf %[[RATIO_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
-// CHECK: %[[SQRT:.*]] = math.sqrt %[[RATIO_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
-// CHECK: %[[ABS_OR_NAN:.*]] = arith.mulf %[[MAX]], %[[SQRT]] fastmath<nnan,contract> : f32
-// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[ABS_OR_NAN]], %[[ABS_OR_NAN]] fastmath<nnan,contract> : f32
-// CHECK: %[[ABS:.*]] = arith.select %[[IS_NAN]], %[[MIN]], %[[ABS_OR_NAN]] : f32
-// CHECK: %[[VAR210:.*]] = arith.addf %[[VAR189]], %[[ABS]] fastmath<nnan,contract> : f32
-// CHECK: %[[CST_9:.*]] = arith.constant 5.000000e-01 : f32
-// CHECK: %[[VAR211:.*]] = arith.mulf %[[VAR210]], %[[CST_9]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR212:.*]] = math.sqrt %[[VAR211]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR213:.*]] = arith.cmpf olt, %[[VAR187]], %[[CST_6]] : f32
-// CHECK: %[[VAR214:.*]] = arith.cmpf olt, %[[VAR188]], %[[CST_6]] : f32
-// CHECK: %[[VAR215:.*]] = arith.addf %[[VAR212]], %[[VAR212]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR216:.*]] = arith.divf %[[VAR188]], %[[VAR215]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR217:.*]] = arith.negf %[[VAR212]] : f32
-// CHECK: %[[VAR218:.*]] = arith.select %[[VAR214]], %[[VAR217]], %[[VAR212]] : f32
-// CHECK: %[[VAR219:.*]] = arith.select %[[VAR213]], %[[VAR218]], %[[VAR216]] : f32
-// CHECK: %[[VAR220:.*]] = arith.addf %[[VAR219]], %[[VAR219]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR221:.*]] = arith.divf %[[VAR188]], %[[VAR220]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR222:.*]] = arith.select %[[VAR213]], %[[VAR221]], %[[VAR212]] : f32
-// CHECK: %[[VAR223:.*]] = arith.cmpf oeq, %[[VAR187]], %[[CST_6]] : f32
-// CHECK: %[[VAR224:.*]] = arith.cmpf oeq, %[[VAR188]], %[[CST_6]] : f32
-// CHECK: %[[VAR225:.*]] = arith.andi %[[VAR223]], %[[VAR224]] : i1
-// CHECK: %[[VAR226:.*]] = arith.select %[[VAR225]], %[[CST_6]], %[[VAR222]] : f32
-// CHECK: %[[VAR227:.*]] = arith.select %[[VAR225]], %[[CST_6]], %[[VAR219]] : f32
-// CHECK: %[[VAR228:.*]] = complex.create %[[VAR226]], %[[VAR227]] : complex<f32>
+// CHECK: %[[ABSRE:.*]] = math.absf %[[RE]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABSIM:.*]] = math.absf %[[IM]] fastmath<nnan,contract> : f32
+// CHECK: %[[MAX:.*]] = arith.maximumf %[[ABSRE]], %[[ABSIM]] fastmath<nnan,contract> : f32
+// CHECK: %[[MIN:.*]] = arith.minimumf %[[ABSRE]], %[[ABSIM]] fastmath<nnan,contract> : f32
+// CHECK: %[[RATIO:.*]] = arith.divf %[[MIN]], %[[MAX]] fastmath<nnan,contract> : f32
+// CHECK: %[[RATIO_SQ:.*]] = arith.mulf %[[RATIO]], %[[RATIO]] fastmath<nnan,contract> : f32
+// CHECK: %[[RATIO_SQ_PLUS_ONE:.*]] = arith.addf %[[RATIO_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[QUARTER:.*]] = arith.constant 2.500000e-01 : f32
+// CHECK: %[[SQRT_MAX:.*]] = math.sqrt %[[MAX]] fastmath<nnan,contract> : f32
+// CHECK: %[[POW:.*]] = math.powf %[[RATIO_SQ_PLUS_ONE]], %[[QUARTER]] fastmath<nnan,contract> : f32
+// CHECK: %[[SQRT_ABS_OR_NAN:.*]] = arith.mulf %[[SQRT_MAX]], %[[POW]] fastmath<nnan,contract> : f32
+// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[SQRT_ABS_OR_NAN]], %[[SQRT_ABS_OR_NAN]] fastmath<nnan,contract> : f32
+// CHECK: %[[SQRT_ABS:.*]] = arith.select %[[IS_NAN]], %[[MIN]], %[[SQRT_ABS_OR_NAN]] : f32
+// CHECK: %[[ARGARG:.*]] = math.atan2 %[[IM]], %[[RE]] fastmath<nnan,contract> : f32
+// CHECK: %[[SQRTARG:.*]] = arith.mulf %[[ARGARG]], %[[HALF]] fastmath<nnan,contract> : f32
+// CHECK: %[[COS:.*]] = math.cos %[[SQRTARG]] fastmath<nnan,contract> : f32
+// CHECK: %[[SIN:.*]] = math.sin %[[SQRTARG]] fastmath<nnan,contract> : f32
+// CHECK: %[[SIN_ZERO:.*]] = arith.cmpf oeq, %[[SIN]], %[[ZERO]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT_RE:.*]] = arith.mulf %[[SQRT_ABS]], %[[COS]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT_IM:.*]] = arith.mulf %[[SQRT_ABS]], %[[SIN]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT_IM2:.*]] = arith.select %[[SIN_ZERO]], %[[ZERO]], %[[RESULT_IM]] : f32
+// CHECK: %[[INF:.*]] = arith.constant 0x7F800000 : f32
+// CHECK: %[[NINF:.*]] = arith.constant 0xFF800000 : f32
+// CHECK: %[[NAN:.*]] = arith.constant 0x7FC00000 : f32
+// CHECK: %[[ABSIM:.*]] = math.absf %[[IM]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABSIMINF:.*]] = arith.cmpf oeq, %[[ABSIM]], %[[INF]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABSIMNOTINF:.*]] = arith.cmpf one, %[[ABSIM]], %[[INF]] fastmath<nnan,contract> : f32
+// CHECK: %[[REINF:.*]] = arith.cmpf oeq, %[[RE]], %[[INF]] fastmath<nnan,contract> : f32
+// CHECK: %[[RENINF:.*]] = arith.cmpf oeq, %[[RE]], %[[NINF]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT_RE_ZERO:.*]] = arith.andi %[[RENINF]], %[[ABSIMNOTINF]] : i1
+// CHECK: %[[RESULT_RE2:.*]] = arith.select %[[RESULT_RE_ZERO]], %[[ZERO]], %[[RESULT_RE]] : f32
+// CHECK: %[[RESULT_RE_INF:.*]] = arith.ori %[[ABSIMINF]], %[[REINF]] : i1
+// CHECK: %[[RESULT_RE3:.*]] = arith.select %[[RESULT_RE_INF]], %[[INF]], %[[RESULT_RE2]] : f32
+// CHECK: %[[INF_IM_SIGN:.*]] = math.copysign %[[INF]], %[[IM]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT_IM_NAN:.*]] = arith.cmpf uno, %[[SQRT_ABS]], %[[SQRT_ABS]] : f32
+// CHECK: %[[RESULT_IM3:.*]] = arith.select %[[RESULT_IM_NAN]], %[[NAN]], %[[RESULT_IM2]] : f32
+// CHECK: %[[RESULT_IM_INF:.*]] = arith.ori %[[ABSIMINF]], %[[RENINF]] : i1
+// CHECK: %[[RESULT_IM4:.*]] = arith.select %[[RESULT_IM_INF]], %[[INF_IM_SIGN]], %[[RESULT_IM3]] : f32
+// CHECK: %[[RESULT_ZERO:.*]] = arith.cmpf oeq, %[[SQRT_ABS]], %[[ZERO]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT_RE4:.*]] = arith.select %[[RESULT_ZERO]], %[[ZERO]], %[[RESULT_RE3]] : f32
+// CHECK: %[[RESULT_IM5:.*]] = arith.select %[[RESULT_ZERO]], %[[ZERO]], %[[RESULT_IM4]] : f32
+// CHECK: %[[VAR228:.*]] = complex.create %[[RESULT_RE4]], %[[RESULT_IM5]] : complex<f32>
// CHECK: %[[CST_10:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[CST_11:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[VAR229:.*]] = complex.create %[[CST_10]], %[[CST_11]] : complex<f32>
@@ -1519,9 +1578,9 @@ func.func @complex_atan2_with_fmf(%lhs: complex<f32>,
// CHECK: %[[VAR413:.*]] = arith.select %[[VAR412]], %[[VAR408]], %[[VAR402]] : f32
// CHECK: %[[VAR414:.*]] = arith.select %[[VAR412]], %[[VAR409]], %[[VAR403]] : f32
// CHECK: %[[VAR415:.*]] = complex.create %[[VAR413]], %[[VAR414]] : complex<f32>
-// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[REAL:.*]] = complex.re %[[VAR415]] : complex<f32>
// CHECK: %[[IMAG:.*]] = complex.im %[[VAR415]] : complex<f32>
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[ABS_REAL:.*]] = math.absf %[[REAL]] fastmath<nnan,contract> : f32
// CHECK: %[[ABS_IMAG:.*]] = math.absf %[[IMAG]] fastmath<nnan,contract> : f32
// CHECK: %[[MAX:.*]] = arith.maximumf %[[ABS_REAL]], %[[ABS_IMAG]] fastmath<nnan,contract> : f32
@@ -1756,45 +1815,54 @@ func.func @complex_sqrt_with_fmf(%arg: complex<f32>) -> complex<f32> {
return %sqrt : complex<f32>
}
-// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAR0:.*]] = complex.re %[[ARG]] : complex<f32>
-// CHECK: %[[VAR1:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK: %[[VAR2:.*]] = math.absf %[[VAR0]] fastmath<nnan,contract> : f32
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[HALF:.*]] = arith.constant 5.000000e-01 : f32
+// CHECK: %[[RE:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IM:.*]] = complex.im %[[ARG]] : complex<f32>
// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
-// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK: %[[ABS_REAL:.*]] = math.absf %[[REAL]] fastmath<nnan,contract> : f32
-// CHECK: %[[ABS_IMAG:.*]] = math.absf %[[IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[MAX:.*]] = arith.maximumf %[[ABS_REAL]], %[[ABS_IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[MIN:.*]] = arith.minimumf %[[ABS_REAL]], %[[ABS_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABSRE:.*]] = math.absf %[[RE]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABSIM:.*]] = math.absf %[[IM]] fastmath<nnan,contract> : f32
+// CHECK: %[[MAX:.*]] = arith.maximumf %[[ABSRE]], %[[ABSIM]] fastmath<nnan,contract> : f32
+// CHECK: %[[MIN:.*]] = arith.minimumf %[[ABSRE]], %[[ABSIM]] fastmath<nnan,contract> : f32
// CHECK: %[[RATIO:.*]] = arith.divf %[[MIN]], %[[MAX]] fastmath<nnan,contract> : f32
// CHECK: %[[RATIO_SQ:.*]] = arith.mulf %[[RATIO]], %[[RATIO]] fastmath<nnan,contract> : f32
// CHECK: %[[RATIO_SQ_PLUS_ONE:.*]] = arith.addf %[[RATIO_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
-// CHECK: %[[SQRT:.*]] = math.sqrt %[[RATIO_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
-// CHECK: %[[ABS_OR_NAN:.*]] = arith.mulf %[[MAX]], %[[SQRT]] fastmath<nnan,contract> : f32
-// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[ABS_OR_NAN]], %[[ABS_OR_NAN]] fastmath<nnan,contract> : f32
-// CHECK: %[[ABS:.*]] = arith.select %[[IS_NAN]], %[[MIN]], %[[ABS_OR_NAN]] : f32
-// CHECK: %[[VAR23:.*]] = arith.addf %[[VAR2]], %[[ABS]] fastmath<nnan,contract> : f32
-// CHECK: %[[CST2:.*]] = arith.constant 5.000000e-01 : f32
-// CHECK: %[[VAR24:.*]] = arith.mulf %[[VAR23]], %[[CST2]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR25:.*]] = math.sqrt %[[VAR24]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR26:.*]] = arith.cmpf olt, %[[VAR0]], %cst : f32
-// CHECK: %[[VAR27:.*]] = arith.cmpf olt, %[[VAR1]], %cst : f32
-// CHECK: %[[VAR28:.*]] = arith.addf %[[VAR25]], %[[VAR25]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR29:.*]] = arith.divf %[[VAR1]], %[[VAR28]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR30:.*]] = arith.negf %[[VAR25]] : f32
-// CHECK: %[[VAR31:.*]] = arith.select %[[VAR27]], %[[VAR30]], %[[VAR25]] : f32
-// CHECK: %[[VAR32:.*]] = arith.select %[[VAR26]], %[[VAR31]], %[[VAR29]] : f32
-// CHECK: %[[VAR33:.*]] = arith.addf %[[VAR32]], %[[VAR32]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR34:.*]] = arith.divf %[[VAR1]], %[[VAR33]] fastmath<nnan,contract> : f32
-// CHECK: %[[VAR35:.*]] = arith.select %[[VAR26]], %[[VAR34]], %[[VAR25]] : f32
-// CHECK: %[[VAR36:.*]] = arith.cmpf oeq, %[[VAR0]], %cst : f32
-// CHECK: %[[VAR37:.*]] = arith.cmpf oeq, %[[VAR1]], %cst : f32
-// CHECK: %[[VAR38:.*]] = arith.andi %[[VAR36]], %[[VAR37]] : i1
-// CHECK: %[[VAR39:.*]] = arith.select %[[VAR38]], %cst, %[[VAR35]] : f32
-// CHECK: %[[VAR40:.*]] = arith.select %[[VAR38]], %cst, %[[VAR32]] : f32
-// CHECK: %[[VAR41:.*]] = complex.create %[[VAR39]], %[[VAR40]] : complex<f32>
-// CHECK: return %[[VAR41]] : complex<f32>
+// CHECK: %[[QUARTER:.*]] = arith.constant 2.500000e-01 : f32
+// CHECK: %[[SQRT_MAX:.*]] = math.sqrt %[[MAX]] fastmath<nnan,contract> : f32
+// CHECK: %[[POW:.*]] = math.powf %[[RATIO_SQ_PLUS_ONE]], %[[QUARTER]] fastmath<nnan,contract> : f32
+// CHECK: %[[SQRT_ABS_OR_NAN:.*]] = arith.mulf %[[SQRT_MAX]], %[[POW]] fastmath<nnan,contract> : f32
+// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[SQRT_ABS_OR_NAN]], %[[SQRT_ABS_OR_NAN]] fastmath<nnan,contract> : f32
+// CHECK: %[[SQRT_ABS:.*]] = arith.select %[[IS_NAN]], %[[MIN]], %[[SQRT_ABS_OR_NAN]] : f32
+// CHECK: %[[ARGARG:.*]] = math.atan2 %[[IM]], %[[RE]] fastmath<nnan,contract> : f32
+// CHECK: %[[SQRTARG:.*]] = arith.mulf %[[ARGARG]], %[[HALF]] fastmath<nnan,contract> : f32
+// CHECK: %[[COS:.*]] = math.cos %[[SQRTARG]] fastmath<nnan,contract> : f32
+// CHECK: %[[SIN:.*]] = math.sin %[[SQRTARG]] fastmath<nnan,contract> : f32
+// CHECK: %[[SIN_ZERO:.*]] = arith.cmpf oeq, %[[SIN]], %[[ZERO]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT_RE:.*]] = arith.mulf %[[SQRT_ABS]], %[[COS]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT_IM:.*]] = arith.mulf %[[SQRT_ABS]], %[[SIN]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT_IM2:.*]] = arith.select %[[SIN_ZERO]], %[[ZERO]], %[[RESULT_IM]] : f32
+// CHECK: %[[INF:.*]] = arith.constant 0x7F800000 : f32
+// CHECK: %[[NINF:.*]] = arith.constant 0xFF800000 : f32
+// CHECK: %[[NAN:.*]] = arith.constant 0x7FC00000 : f32
+// CHECK: %[[ABSIM:.*]] = math.absf %[[IM]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABSIMINF:.*]] = arith.cmpf oeq, %[[ABSIM]], %[[INF]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABSIMNOTINF:.*]] = arith.cmpf one, %[[ABSIM]], %[[INF]] fastmath<nnan,contract> : f32
+// CHECK: %[[REINF:.*]] = arith.cmpf oeq, %[[RE]], %[[INF]] fastmath<nnan,contract> : f32
+// CHECK: %[[RENINF:.*]] = arith.cmpf oeq, %[[RE]], %[[NINF]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT_RE_ZERO:.*]] = arith.andi %[[RENINF]], %[[ABSIMNOTINF]] : i1
+// CHECK: %[[RESULT_RE2:.*]] = arith.select %[[RESULT_RE_ZERO]], %[[ZERO]], %[[RESULT_RE]] : f32
+// CHECK: %[[RESULT_RE_INF:.*]] = arith.ori %[[ABSIMINF]], %[[REINF]] : i1
+// CHECK: %[[RESULT_RE3:.*]] = arith.select %[[RESULT_RE_INF]], %[[INF]], %[[RESULT_RE2]] : f32
+// CHECK: %[[INF_IM_SIGN:.*]] = math.copysign %[[INF]], %[[IM]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT_IM_NAN:.*]] = arith.cmpf uno, %[[SQRT_ABS]], %[[SQRT_ABS]] : f32
+// CHECK: %[[RESULT_IM3:.*]] = arith.select %[[RESULT_IM_NAN]], %[[NAN]], %[[RESULT_IM2]] : f32
+// CHECK: %[[RESULT_IM_INF:.*]] = arith.ori %[[ABSIMINF]], %[[RENINF]] : i1
+// CHECK: %[[RESULT_IM4:.*]] = arith.select %[[RESULT_IM_INF]], %[[INF_IM_SIGN]], %[[RESULT_IM3]] : f32
+// CHECK: %[[RESULT_ZERO:.*]] = arith.cmpf oeq, %[[SQRT_ABS]], %[[ZERO]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT_RE4:.*]] = arith.select %[[RESULT_ZERO]], %[[ZERO]], %[[RESULT_RE3]] : f32
+// CHECK: %[[RESULT_IM5:.*]] = arith.select %[[RESULT_ZERO]], %[[ZERO]], %[[RESULT_IM4]] : f32
+// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_RE4]], %[[RESULT_IM5]] : complex<f32>
+// CHECK: return %[[RESULT]] : complex<f32>
// -----
@@ -1857,9 +1925,9 @@ func.func @complex_sign_with_fmf(%arg: complex<f32>) -> complex<f32> {
// CHECK: %[[REAL_IS_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
// CHECK: %[[IMAG_IS_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
// CHECK: %[[IS_ZERO:.*]] = arith.andi %[[REAL_IS_ZERO]], %[[IMAG_IS_ZERO]] : i1
-// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32>
// CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[ABS_REAL:.*]] = math.absf %[[REAL2]] fastmath<nnan,contract> : f32
// CHECK: %[[ABS_IMAG:.*]] = math.absf %[[IMAG2]] fastmath<nnan,contract> : f32
// CHECK: %[[MAX:.*]] = arith.maximumf %[[ABS_REAL]], %[[ABS_IMAG]] fastmath<nnan,contract> : f32
diff --git a/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir b/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir
index 2649d004a76acc..110a78631fb95a 100644
--- a/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir
@@ -6,10 +6,10 @@ func.func @complex_abs(%arg: complex<f32>) -> f32 {
%abs = complex.abs %arg: complex<f32>
return %abs : f32
}
-// CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
// CHECK: %[[REAL:.*]] = llvm.extractvalue %[[ARG]][0] : ![[C_TY]]
// CHECK: %[[IMAG:.*]] = llvm.extractvalue %[[ARG]][1] : ![[C_TY]]
+// CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
// CHECK: %[[ABS_REAL:.*]] = llvm.intr.fabs(%[[REAL]]) : (f32) -> f32
// CHECK: %[[ABS_IMAG:.*]] = llvm.intr.fabs(%[[IMAG]]) : (f32) -> f32
// CHECK: %[[MAX:.*]] = llvm.intr.maximum(%[[ABS_REAL]], %[[ABS_IMAG]]) : (f32, f32) -> f32
More information about the Mlir-commits
mailing list