[Mlir-commits] [mlir] 375a5cb - Don't lower log1p(x) to log(1 + x).

Johannes Reifferscheid llvmlistbot at llvm.org
Mon Aug 15 21:58:10 PDT 2022


Author: Johannes Reifferscheid
Date: 2022-08-16T06:58:00+02:00
New Revision: 375a5cb648835db0b1eacfc921cbb04844b8b3b4

URL: https://github.com/llvm/llvm-project/commit/375a5cb648835db0b1eacfc921cbb04844b8b3b4
DIFF: https://github.com/llvm/llvm-project/commit/375a5cb648835db0b1eacfc921cbb04844b8b3b4.diff

LOG: Don't lower log1p(x) to log(1 + x).

The latter loses accuracy for x near 0, because computing 1 + x rounds away the low-order bits of x before the logarithm is taken. The lowering in MathToLLVM is kept for now.

Reviewed By: bkramer

Differential Revision: https://reviews.llvm.org/D131676

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h
    mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
    mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
    mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
    mlir/test/Conversion/MathToLibm/convert-to-libm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h b/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h
index 9e7aa1a0f52ac..c07dcfd090d2a 100644
--- a/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h
+++ b/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h
@@ -15,8 +15,10 @@ template <typename T>
 class OperationPass;
 
 /// Populate the given list with patterns that convert from Math to Libm calls.
-void populateMathToLibmConversionPatterns(RewritePatternSet &patterns,
-                                          PatternBenefit benefit);
+/// If log1pBenefit is present, use it instead of benefit for the Log1p op.
+void populateMathToLibmConversionPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit,
+    llvm::Optional<PatternBenefit> log1pBenefit = llvm::None);
 
 /// Create a pass to convert Math operations to libm calls.
 std::unique_ptr<OperationPass<ModuleOp>> createConvertMathToLibmPass();

diff  --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 643806f2b0fa0..064b0db08a413 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -513,11 +513,28 @@ struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
 
     Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
     Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
+
+    Value half = b.create<arith::ConstantOp>(elementType,
+                                             b.getFloatAttr(elementType, 0.5));
     Value one = b.create<arith::ConstantOp>(elementType,
                                             b.getFloatAttr(elementType, 1));
+    Value two = b.create<arith::ConstantOp>(elementType,
+                                            b.getFloatAttr(elementType, 2));
+
+    // log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1)
+    // log((a+1)+bi) = .5*log(a*a + 2*a + 1 + b*b) + i*atan2(b, a+1)
+    // log((a+1)+bi) = .5*log1p(a*a + 2*a + b*b) + i*atan2(b, a+1)
+    Value sumSq = b.create<arith::MulFOp>(real, real);
+    sumSq = b.create<arith::AddFOp>(sumSq, b.create<arith::MulFOp>(real, two));
+    sumSq = b.create<arith::AddFOp>(sumSq, b.create<arith::MulFOp>(imag, imag));
+    Value logSumSq = b.create<math::Log1pOp>(elementType, sumSq);
+    Value resultReal = b.create<arith::MulFOp>(logSumSq, half);
+
     Value realPlusOne = b.create<arith::AddFOp>(real, one);
-    Value newComplex = b.create<complex::CreateOp>(type, realPlusOne, imag);
-    rewriter.replaceOpWithNewOp<complex::LogOp>(op, type, newComplex);
+
+    Value resultImag = b.create<math::Atan2Op>(elementType, imag, realPlusOne);
+    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
+                                                   resultImag);
     return success();
   }
 };

diff  --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
index 2cd5ca08395aa..43ce675da926f 100644
--- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -138,8 +138,9 @@ ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
   return success();
 }
 
-void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns,
-                                                PatternBenefit benefit) {
+void mlir::populateMathToLibmConversionPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit,
+    llvm::Optional<PatternBenefit> log1pBenefit) {
   patterns.add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::ExpM1Op>,
                VecOpToScalarOp<math::TanhOp>, VecOpToScalarOp<math::CosOp>,
                VecOpToScalarOp<math::SinOp>, VecOpToScalarOp<math::ErfOp>,
@@ -168,6 +169,8 @@ void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns,
                                                 "cos", benefit);
   patterns.add<ScalarOpToLibmCall<math::SinOp>>(patterns.getContext(), "sinf",
                                                 "sin", benefit);
+  patterns.add<ScalarOpToLibmCall<math::Log1pOp>>(
+      patterns.getContext(), "log1pf", "log1p", log1pBenefit.value_or(benefit));
 }
 
 namespace {

diff  --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index cac758a89b61d..e11187af14b86 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s --convert-complex-to-standard --split-input-file |\
-// RUN: FileCheck %s
+// RUN: FileCheck %s --dump-input=always
 
 // CHECK-LABEL: func @complex_abs
 // CHECK-SAME: %[[ARG:.*]]: complex<f32>
@@ -262,21 +262,21 @@ func.func @complex_log1p(%arg: complex<f32>) -> complex<f32> {
   %log1p = complex.log1p %arg: complex<f32>
   return %log1p : complex<f32>
 }
+
 // CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
 // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[ONE_HALF:.*]] = arith.constant 5.000000e-01 : f32
 // CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK: %[[TWO:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK: %[[SQ_SUM_0:.*]] = arith.mulf %[[REAL]], %[[REAL]] : f32 
+// CHECK: %[[TWO_REAL:.*]] = arith.mulf %[[REAL]], %[[TWO]] : f32
+// CHECK: %[[SQ_SUM_1:.*]] = arith.addf %[[SQ_SUM_0]], %[[TWO_REAL]] : f32 
+// CHECK: %[[SQ_IMAG:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] : f32 
+// CHECK: %[[SQ_SUM_2:.*]] = arith.addf %[[SQ_SUM_1]], %[[SQ_IMAG]] : f32 
+// CHECK: %[[LOG_SQ_SUM:.*]] = math.log1p %[[SQ_SUM_2]] : f32 
+// CHECK: %[[RESULT_REAL:.*]] = arith.mulf %[[LOG_SQ_SUM]], %[[ONE_HALF]] : f32
 // CHECK: %[[REAL_PLUS_ONE:.*]] = arith.addf %[[REAL]], %[[ONE]] : f32
-// CHECK: %[[NEW_COMPLEX:.*]] = complex.create %[[REAL_PLUS_ONE]], %[[IMAG]] : complex<f32>
-// CHECK: %[[REAL:.*]] = complex.re %[[NEW_COMPLEX]] : complex<f32>
-// CHECK: %[[IMAG:.*]] = complex.im %[[NEW_COMPLEX]] : complex<f32>
-// CHECK: %[[SQR_REAL:.*]] = arith.mulf %[[REAL]], %[[REAL]] : f32
-// CHECK: %[[SQR_IMAG:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] : f32
-// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[SQR_REAL]], %[[SQR_IMAG]] : f32
-// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
-// CHECK: %[[RESULT_REAL:.*]] = math.log %[[NORM]] : f32
-// CHECK: %[[REAL2:.*]] = complex.re %[[NEW_COMPLEX]] : complex<f32>
-// CHECK: %[[IMAG2:.*]] = complex.im %[[NEW_COMPLEX]] : complex<f32>
-// CHECK: %[[RESULT_IMAG:.*]] = math.atan2 %[[IMAG2]], %[[REAL2]] : f32
+// CHECK: %[[RESULT_IMAG:.*]] = math.atan2 %[[IMAG]], %[[REAL_PLUS_ONE]] : f32 
 // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
 

diff  --git a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
index ced15f571a40f..f67c994a3b78b 100644
--- a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
+++ b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
@@ -303,3 +303,15 @@ func.func @tan_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vec
   %double_result = math.tan %double : vector<2xf64>
   return %float_result, %double_result : vector<2xf32>, vector<2xf64>
 }
+
+// CHECK-LABEL: func @log1p_caller
+// CHECK-SAME: %[[FLOAT:.*]]: f32
+// CHECK-SAME: %[[DOUBLE:.*]]: f64
+func.func @log1p_caller(%float: f32, %double: f64) -> (f32, f64)  {
+  // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @log1pf(%[[FLOAT]]) : (f32) -> f32
+  %float_result = math.log1p %float : f32
+  // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @log1p(%[[DOUBLE]]) : (f64) -> f64
+  %double_result = math.log1p %double : f64
+  // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
+  return %float_result, %double_result : f32, f64
+}
\ No newline at end of file


        


More information about the Mlir-commits mailing list