[llvm-branch-commits] [mlir] de88bd7 - Revert "Fix rsqrt inaccuracies. (#88691)"

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Mon Apr 15 02:48:15 PDT 2024


Author: Johannes Reifferscheid
Date: 2024-04-15T11:48:12+02:00
New Revision: de88bd7e8925f5df51547e20f6fbd1ef006386ad

URL: https://github.com/llvm/llvm-project/commit/de88bd7e8925f5df51547e20f6fbd1ef006386ad
DIFF: https://github.com/llvm/llvm-project/commit/de88bd7e8925f5df51547e20f6fbd1ef006386ad.diff

LOG: Revert "Fix rsqrt inaccuracies. (#88691)"

This reverts commit 8ddaf750746d7f9b5f7e878870b086edc0f55326.

Added: 
    

Modified: 
    mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
    mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 3ebee9baff31bd..49eb575212ffc1 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -27,11 +27,9 @@ using namespace mlir;
 
 namespace {
 
-enum class AbsFn { abs, sqrt, rsqrt };
-
-// Returns the absolute value, its square root or its reciprocal square root.
+// Returns the absolute value or its square root.
 Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf,
-                 ImplicitLocOpBuilder &b, AbsFn fn = AbsFn::abs) {
+                 ImplicitLocOpBuilder &b, bool returnSqrt = false) {
   Value one = b.create<arith::ConstantOp>(real.getType(),
                                           b.getFloatAttr(real.getType(), 1.0));
 
@@ -45,13 +43,7 @@ Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf,
   Value ratioSqPlusOne = b.create<arith::AddFOp>(ratioSq, one, fmf);
   Value result;
 
-  if (fn == AbsFn::rsqrt) {
-    ratioSqPlusOne = b.create<math::RsqrtOp>(ratioSqPlusOne, fmf);
-    min = b.create<math::RsqrtOp>(min, fmf);
-    max = b.create<math::RsqrtOp>(max, fmf);
-  }
-
-  if (fn == AbsFn::sqrt) {
+  if (returnSqrt) {
     Value quarter = b.create<arith::ConstantOp>(
         real.getType(), b.getFloatAttr(real.getType(), 0.25));
     // sqrt(sqrt(a*b)) would avoid the pow, but will overflow more easily.
@@ -871,7 +863,7 @@ struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
 
     Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
     Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
-    Value absSqrt = computeAbs(real, imag, fmf, b, AbsFn::sqrt);
+    Value absSqrt = computeAbs(real, imag, fmf, b, /*returnSqrt=*/true);
     Value argArg = b.create<math::Atan2Op>(imag, real, fmf);
     Value sqrtArg = b.create<arith::MulFOp>(argArg, half, fmf);
     Value cos = b.create<math::CosOp>(sqrtArg, fmf);
@@ -1155,74 +1147,18 @@ struct RsqrtOpConversion : public OpConversionPattern<complex::RsqrtOp> {
   LogicalResult
   matchAndRewrite(complex::RsqrtOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
     auto type = cast<ComplexType>(adaptor.getComplex().getType());
     auto elementType = cast<FloatType>(type.getElementType());
 
-    arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
-
-    auto cst = [&](APFloat v) {
-      return b.create<arith::ConstantOp>(elementType,
-                                         b.getFloatAttr(elementType, v));
-    };
-    const auto &floatSemantics = elementType.getFloatSemantics();
-    Value zero = cst(APFloat::getZero(floatSemantics));
-    Value inf = cst(APFloat::getInf(floatSemantics));
-    Value negHalf = b.create<arith::ConstantOp>(
-        elementType, b.getFloatAttr(elementType, -0.5));
-    Value nan = cst(APFloat::getNaN(floatSemantics));
-
-    Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
-    Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
-    Value absRsqrt = computeAbs(real, imag, fmf, b, AbsFn::rsqrt);
-    Value argArg = b.create<math::Atan2Op>(imag, real, fmf);
-    Value rsqrtArg = b.create<arith::MulFOp>(argArg, negHalf, fmf);
-    Value cos = b.create<math::CosOp>(rsqrtArg, fmf);
-    Value sin = b.create<math::SinOp>(rsqrtArg, fmf);
-
-    Value resultReal = b.create<arith::MulFOp>(absRsqrt, cos, fmf);
-    Value resultImag = b.create<arith::MulFOp>(absRsqrt, sin, fmf);
-
-    if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
-                                            arith::FastMathFlags::ninf)) {
-      Value negOne = b.create<arith::ConstantOp>(
-          elementType, b.getFloatAttr(elementType, -1));
-
-      Value realSignedZero = b.create<math::CopySignOp>(zero, real, fmf);
-      Value imagSignedZero = b.create<math::CopySignOp>(zero, imag, fmf);
-      Value negImagSignedZero =
-          b.create<arith::MulFOp>(negOne, imagSignedZero, fmf);
+    Value c = builder.create<arith::ConstantOp>(
+        elementType, builder.getFloatAttr(elementType, -0.5));
+    Value d = builder.create<arith::ConstantOp>(
+        elementType, builder.getFloatAttr(elementType, 0));
 
-      Value absReal = b.create<math::AbsFOp>(real, fmf);
-      Value absImag = b.create<math::AbsFOp>(imag, fmf);
-
-      Value absImagIsInf =
-          b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf);
-      Value realIsNan =
-          b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real, fmf);
-      Value realIsInf =
-          b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absReal, inf, fmf);
-      Value inIsNanInf = b.create<arith::AndIOp>(absImagIsInf, realIsNan);
-
-      Value resultIsZero = b.create<arith::OrIOp>(inIsNanInf, realIsInf);
-
-      resultReal =
-          b.create<arith::SelectOp>(resultIsZero, realSignedZero, resultReal);
-      resultImag = b.create<arith::SelectOp>(resultIsZero, negImagSignedZero,
-                                             resultImag);
-    }
-
-    Value isRealZero =
-        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero, fmf);
-    Value isImagZero =
-        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero, fmf);
-    Value isZero = b.create<arith::AndIOp>(isRealZero, isImagZero);
-
-    resultReal = b.create<arith::SelectOp>(isZero, inf, resultReal);
-    resultImag = b.create<arith::SelectOp>(isZero, nan, resultImag);
-
-    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
-                                                   resultImag);
+    rewriter.replaceOp(op,
+                       {powOpConversionImpl(builder, type, adaptor.getComplex(),
+                                            c, d, op.getFastmath())});
     return success();
   }
 };

diff  --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 8b4ea9777f7976..e0e7cdadd317d2 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -837,21 +837,6 @@ func.func @complex_rsqrt(%arg: complex<f32>) -> complex<f32> {
   return %rsqrt : complex<f32>
 }
 
-// CHECK-COUNT-5: arith.select
-// CHECK-NOT: arith.select
-
-// -----
-
-// CHECK-LABEL: func @complex_rsqrt_nnan_ninf
-// CHECK-SAME: %[[ARG:.*]]: complex<f32>
-func.func @complex_rsqrt_nnan_ninf(%arg: complex<f32>) -> complex<f32> {
-  %sqrt = complex.rsqrt %arg fastmath<nnan,ninf> : complex<f32>
-  return %sqrt : complex<f32>
-}
-
-// CHECK-COUNT-3: arith.select
-// CHECK-NOT: arith.select
-
 // -----
 
 // CHECK-LABEL:   func.func @complex_angle
@@ -2118,4 +2103,4 @@ func.func @complex_tanh_with_fmf(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[NUM:.*]] = complex.create %[[TANH_A]], %[[TAN_B]] : complex<f32>
 // CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK: %[[MUL:.*]] = arith.mulf %[[TANH_A]], %[[TAN_B]] fastmath<nnan,contract> : f32
-// CHECK: %[[DENOM:.*]] = complex.create %[[ONE]], %[[MUL]] : complex<f32>
+// CHECK: %[[DENOM:.*]] = complex.create %[[ONE]], %[[MUL]] : complex<f32>
\ No newline at end of file


        


More information about the llvm-branch-commits mailing list