[llvm-branch-commits] [mlir] ee3842f - Revert "Fix complex log1p accuracy with large abs values. (#88260)"

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Apr 10 09:24:29 PDT 2024


Author: Mehdi Amini
Date: 2024-04-10T18:24:25+02:00
New Revision: ee3842f82072f80020ec4449a26b6bc6bb44573b

URL: https://github.com/llvm/llvm-project/commit/ee3842f82072f80020ec4449a26b6bc6bb44573b
DIFF: https://github.com/llvm/llvm-project/commit/ee3842f82072f80020ec4449a26b6bc6bb44573b.diff

LOG: Revert "Fix complex log1p accuracy with large abs values. (#88260)"

This reverts commit 49ef12a08c4c7d7ae4765929e72fe2320a12b08c.

Added: 
    

Modified: 
    mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
    mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 0aa1de5fa5d9a1..9c3c4d96a301ef 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -570,39 +570,37 @@ struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
                   ConversionPatternRewriter &rewriter) const override {
     auto type = cast<ComplexType>(adaptor.getComplex().getType());
     auto elementType = cast<FloatType>(type.getElementType());
-    arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
+    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
 
-    Value real = b.create<complex::ReOp>(adaptor.getComplex());
-    Value imag = b.create<complex::ImOp>(adaptor.getComplex());
+    Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
+    Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
 
     Value half = b.create<arith::ConstantOp>(elementType,
                                              b.getFloatAttr(elementType, 0.5));
     Value one = b.create<arith::ConstantOp>(elementType,
                                             b.getFloatAttr(elementType, 1));
-    Value realPlusOne = b.create<arith::AddFOp>(real, one, fmf);
-    Value absRealPlusOne = b.create<math::AbsFOp>(realPlusOne, fmf);
-    Value absImag = b.create<math::AbsFOp>(imag, fmf);
-
-    Value maxAbs = b.create<arith::MaximumFOp>(absRealPlusOne, absImag, fmf);
-    Value minAbs = b.create<arith::MinimumFOp>(absRealPlusOne, absImag, fmf);
-
-    Value maxAbsOfRealPlusOneAndImagMinusOne = b.create<arith::SelectOp>(
-        b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, realPlusOne, absImag,
-                                fmf),
-        real, b.create<arith::SubFOp>(maxAbs, one, fmf));
-    Value minMaxRatio = b.create<arith::DivFOp>(minAbs, maxAbs, fmf);
-    Value logOfMaxAbsOfRealPlusOneAndImag =
-        b.create<math::Log1pOp>(maxAbsOfRealPlusOneAndImagMinusOne, fmf);
-    Value logOfSqrtPart = b.create<math::Log1pOp>(
-        b.create<arith::MulFOp>(minMaxRatio, minMaxRatio, fmf), fmf);
-    Value r = b.create<arith::AddFOp>(
-        b.create<arith::MulFOp>(half, logOfSqrtPart, fmf),
-        logOfMaxAbsOfRealPlusOneAndImag, fmf);
-    Value resultReal = b.create<arith::SelectOp>(
-        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, r, r, fmf), minAbs,
-        r);
-    Value resultImag = b.create<math::Atan2Op>(imag, realPlusOne, fmf);
+    Value two = b.create<arith::ConstantOp>(elementType,
+                                            b.getFloatAttr(elementType, 2));
+
+    // log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1)
+    // log((a+1)+bi) = .5*log(a*a + 2*a + 1 + b*b) + i*atan2(b, a+1)
+    // log((a+1)+bi) = .5*log1p(a*a + 2*a + b*b) + i*atan2(b, a+1)
+    Value sumSq = b.create<arith::MulFOp>(real, real, fmf.getValue());
+    sumSq = b.create<arith::AddFOp>(
+        sumSq, b.create<arith::MulFOp>(real, two, fmf.getValue()),
+        fmf.getValue());
+    sumSq = b.create<arith::AddFOp>(
+        sumSq, b.create<arith::MulFOp>(imag, imag, fmf.getValue()),
+        fmf.getValue());
+    Value logSumSq =
+        b.create<math::Log1pOp>(elementType, sumSq, fmf.getValue());
+    Value resultReal = b.create<arith::MulFOp>(logSumSq, half, fmf.getValue());
+
+    Value realPlusOne = b.create<arith::AddFOp>(real, one, fmf.getValue());
+
+    Value resultImag =
+        b.create<math::Atan2Op>(elementType, imag, realPlusOne, fmf.getValue());
     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
                                                    resultImag);
     return success();

diff  --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 43918904a09f40..f5d9499eadda48 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -300,22 +300,15 @@ func.func @complex_log1p(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
 // CHECK: %[[ONE_HALF:.*]] = arith.constant 5.000000e-01 : f32
 // CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK: %[[TWO:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK: %[[SQ_SUM_0:.*]] = arith.mulf %[[REAL]], %[[REAL]] : f32
+// CHECK: %[[TWO_REAL:.*]] = arith.mulf %[[REAL]], %[[TWO]] : f32
+// CHECK: %[[SQ_SUM_1:.*]] = arith.addf %[[SQ_SUM_0]], %[[TWO_REAL]] : f32
+// CHECK: %[[SQ_IMAG:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] : f32
+// CHECK: %[[SQ_SUM_2:.*]] = arith.addf %[[SQ_SUM_1]], %[[SQ_IMAG]] : f32
+// CHECK: %[[LOG_SQ_SUM:.*]] = math.log1p %[[SQ_SUM_2]] : f32
+// CHECK: %[[RESULT_REAL:.*]] = arith.mulf %[[LOG_SQ_SUM]], %[[ONE_HALF]] : f32
 // CHECK: %[[REAL_PLUS_ONE:.*]] = arith.addf %[[REAL]], %[[ONE]] : f32
-// CHECK: %[[ABS_REAL_PLUS_ONE:.*]] = math.absf %[[REAL_PLUS_ONE]] : f32
-// CHECK: %[[ABS_IMAG:.*]] = math.absf %[[IMAG]] : f32
-// CHECK: %[[MAX:.*]] = arith.maximumf %[[ABS_REAL_PLUS_ONE]], %[[ABS_IMAG]] : f32
-// CHECK: %[[MIN:.*]] = arith.minimumf %[[ABS_REAL_PLUS_ONE]], %[[ABS_IMAG]] : f32
-// CHECK: %[[CMPF:.*]] = arith.cmpf ogt, %[[REAL_PLUS_ONE]], %[[ABS_IMAG]] : f32
-// CHECK: %[[MAX_MINUS_ONE:.*]] = arith.subf %[[MAX]], %cst_0 : f32
-// CHECK: %[[SELECT:.*]] = arith.select %[[CMPF]], %0, %[[MAX_MINUS_ONE]] : f32
-// CHECK: %[[MIN_MAX_RATIO:.*]] = arith.divf %[[MIN]], %[[MAX]] : f32
-// CHECK: %[[LOG_1:.*]] = math.log1p %[[SELECT]] : f32
-// CHECK: %[[RATIO_SQ:.*]] = arith.mulf %[[MIN_MAX_RATIO]], %[[MIN_MAX_RATIO]] : f32
-// CHECK: %[[LOG_SQ:.*]] = math.log1p %[[RATIO_SQ]] : f32
-// CHECK: %[[HALF_LOG_SQ:.*]] = arith.mulf %cst, %[[LOG_SQ]] : f32
-// CHECK: %[[R:.*]] = arith.addf %[[HALF_LOG_SQ]], %[[LOG_1]] : f32
-// CHECK: %[[ISNAN:.*]] = arith.cmpf uno, %[[R]], %[[R]] : f32
-// CHECK: %[[RESULT_REAL:.*]] = arith.select %[[ISNAN]], %[[MIN]], %[[R]] : f32
 // CHECK: %[[RESULT_IMAG:.*]] = math.atan2 %[[IMAG]], %[[REAL_PLUS_ONE]] : f32
 // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
@@ -970,22 +963,15 @@ func.func @complex_log1p_with_fmf(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
 // CHECK: %[[ONE_HALF:.*]] = arith.constant 5.000000e-01 : f32
 // CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK: %[[REAL_PLUS_ONE:.*]] = arith.addf %[[REAL]], %[[ONE]] fastmath<nnan,contract>  : f32
-// CHECK: %[[ABS_REAL_PLUS_ONE:.*]] = math.absf %[[REAL_PLUS_ONE]] fastmath<nnan,contract>  : f32
-// CHECK: %[[ABS_IMAG:.*]] = math.absf %[[IMAG]] fastmath<nnan,contract>  : f32
-// CHECK: %[[MAX:.*]] = arith.maximumf %[[ABS_REAL_PLUS_ONE]], %[[ABS_IMAG]] fastmath<nnan,contract>  : f32
-// CHECK: %[[MIN:.*]] = arith.minimumf %[[ABS_REAL_PLUS_ONE]], %[[ABS_IMAG]] fastmath<nnan,contract>  : f32
-// CHECK: %[[CMPF:.*]] = arith.cmpf ogt, %[[REAL_PLUS_ONE]], %[[ABS_IMAG]] fastmath<nnan,contract>  : f32
-// CHECK: %[[MAX_MINUS_ONE:.*]] = arith.subf %[[MAX]], %cst_0 fastmath<nnan,contract>  : f32
-// CHECK: %[[SELECT:.*]] = arith.select %[[CMPF]], %0, %[[MAX_MINUS_ONE]] : f32
-// CHECK: %[[MIN_MAX_RATIO:.*]] = arith.divf %[[MIN]], %[[MAX]] fastmath<nnan,contract>  : f32
-// CHECK: %[[LOG_1:.*]] = math.log1p %[[SELECT]] fastmath<nnan,contract> : f32
-// CHECK: %[[RATIO_SQ:.*]] = arith.mulf %[[MIN_MAX_RATIO]], %[[MIN_MAX_RATIO]] fastmath<nnan,contract>  : f32
-// CHECK: %[[LOG_SQ:.*]] = math.log1p %[[RATIO_SQ]] fastmath<nnan,contract>  : f32
-// CHECK: %[[HALF_LOG_SQ:.*]] = arith.mulf %cst, %[[LOG_SQ]] fastmath<nnan,contract>  : f32
-// CHECK: %[[R:.*]] = arith.addf %[[HALF_LOG_SQ]], %[[LOG_1]] fastmath<nnan,contract>  : f32
-// CHECK: %[[ISNAN:.*]] = arith.cmpf uno, %[[R]], %[[R]] fastmath<nnan,contract> : f32
-// CHECK: %[[RESULT_REAL:.*]] = arith.select %[[ISNAN]], %[[MIN]], %[[R]] : f32
+// CHECK: %[[TWO:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK: %[[SQ_SUM_0:.*]] = arith.mulf %[[REAL]], %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[TWO_REAL:.*]] = arith.mulf %[[REAL]], %[[TWO]] fastmath<nnan,contract> : f32
+// CHECK: %[[SQ_SUM_1:.*]] = arith.addf %[[SQ_SUM_0]], %[[TWO_REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[SQ_IMAG:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[SQ_SUM_2:.*]] = arith.addf %[[SQ_SUM_1]], %[[SQ_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[LOG_SQ_SUM:.*]] = math.log1p %[[SQ_SUM_2]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT_REAL:.*]] = arith.mulf %[[LOG_SQ_SUM]], %[[ONE_HALF]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_PLUS_ONE:.*]] = arith.addf %[[REAL]], %[[ONE]] fastmath<nnan,contract> : f32
 // CHECK: %[[RESULT_IMAG:.*]] = math.atan2 %[[IMAG]], %[[REAL_PLUS_ONE]] fastmath<nnan,contract> : f32
 // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>


        


More information about the llvm-branch-commits mailing list