[Mlir-commits] [mlir] 375a5cb - Don't lower log1p(x) to log(1 + x).
Johannes Reifferscheid
llvmlistbot at llvm.org
Mon Aug 15 21:58:10 PDT 2022
Author: Johannes Reifferscheid
Date: 2022-08-16T06:58:00+02:00
New Revision: 375a5cb648835db0b1eacfc921cbb04844b8b3b4
URL: https://github.com/llvm/llvm-project/commit/375a5cb648835db0b1eacfc921cbb04844b8b3b4
DIFF: https://github.com/llvm/llvm-project/commit/375a5cb648835db0b1eacfc921cbb04844b8b3b4.diff
LOG: Don't lower log1p(x) to log(1 + x).
The latter loses accuracy for x near 0, because computing 1 + x rounds away the low-order bits of a small x before the logarithm is taken. The lowering in MathToLLVM is kept for now.
Reviewed By: bkramer
Differential Revision: https://reviews.llvm.org/D131676
Added:
Modified:
mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h
mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h b/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h
index 9e7aa1a0f52ac..c07dcfd090d2a 100644
--- a/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h
+++ b/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h
@@ -15,8 +15,10 @@ template <typename T>
class OperationPass;
/// Populate the given list with patterns that convert from Math to Libm calls.
-void populateMathToLibmConversionPatterns(RewritePatternSet &patterns,
- PatternBenefit benefit);
+/// If log1pBenefit is present, use it instead of benefit for the Log1p op.
+void populateMathToLibmConversionPatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit,
+ llvm::Optional<PatternBenefit> log1pBenefit = llvm::None);
/// Create a pass to convert Math operations to libm calls.
std::unique_ptr<OperationPass<ModuleOp>> createConvertMathToLibmPass();
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 643806f2b0fa0..064b0db08a413 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -513,11 +513,28 @@ struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
+
+ Value half = b.create<arith::ConstantOp>(elementType,
+ b.getFloatAttr(elementType, 0.5));
Value one = b.create<arith::ConstantOp>(elementType,
b.getFloatAttr(elementType, 1));
+ Value two = b.create<arith::ConstantOp>(elementType,
+ b.getFloatAttr(elementType, 2));
+
+ // log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1)
+ // log((a+1)+bi) = .5*log(a*a + 2*a + 1 + b*b) + i*atan2(b, a+1)
+ // log((a+1)+bi) = .5*log1p(a*a + 2*a + b*b) + i*atan2(b, a+1)
+ Value sumSq = b.create<arith::MulFOp>(real, real);
+ sumSq = b.create<arith::AddFOp>(sumSq, b.create<arith::MulFOp>(real, two));
+ sumSq = b.create<arith::AddFOp>(sumSq, b.create<arith::MulFOp>(imag, imag));
+ Value logSumSq = b.create<math::Log1pOp>(elementType, sumSq);
+ Value resultReal = b.create<arith::MulFOp>(logSumSq, half);
+
Value realPlusOne = b.create<arith::AddFOp>(real, one);
- Value newComplex = b.create<complex::CreateOp>(type, realPlusOne, imag);
- rewriter.replaceOpWithNewOp<complex::LogOp>(op, type, newComplex);
+
+ Value resultImag = b.create<math::Atan2Op>(elementType, imag, realPlusOne);
+ rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
+ resultImag);
return success();
}
};
diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
index 2cd5ca08395aa..43ce675da926f 100644
--- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -138,8 +138,9 @@ ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
return success();
}
-void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns,
- PatternBenefit benefit) {
+void mlir::populateMathToLibmConversionPatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit,
+ llvm::Optional<PatternBenefit> log1pBenefit) {
patterns.add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::ExpM1Op>,
VecOpToScalarOp<math::TanhOp>, VecOpToScalarOp<math::CosOp>,
VecOpToScalarOp<math::SinOp>, VecOpToScalarOp<math::ErfOp>,
@@ -168,6 +169,8 @@ void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns,
"cos", benefit);
patterns.add<ScalarOpToLibmCall<math::SinOp>>(patterns.getContext(), "sinf",
"sin", benefit);
+ patterns.add<ScalarOpToLibmCall<math::Log1pOp>>(
+ patterns.getContext(), "log1pf", "log1p", log1pBenefit.value_or(benefit));
}
namespace {
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index cac758a89b61d..e11187af14b86 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -1,5 +1,5 @@
// RUN: mlir-opt %s --convert-complex-to-standard --split-input-file |\
-// RUN: FileCheck %s
+// RUN: FileCheck %s --dump-input=always
// CHECK-LABEL: func @complex_abs
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
@@ -262,21 +262,21 @@ func.func @complex_log1p(%arg: complex<f32>) -> complex<f32> {
%log1p = complex.log1p %arg: complex<f32>
return %log1p : complex<f32>
}
+
// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[ONE_HALF:.*]] = arith.constant 5.000000e-01 : f32
// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK: %[[TWO:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK: %[[SQ_SUM_0:.*]] = arith.mulf %[[REAL]], %[[REAL]] : f32
+// CHECK: %[[TWO_REAL:.*]] = arith.mulf %[[REAL]], %[[TWO]] : f32
+// CHECK: %[[SQ_SUM_1:.*]] = arith.addf %[[SQ_SUM_0]], %[[TWO_REAL]] : f32
+// CHECK: %[[SQ_IMAG:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] : f32
+// CHECK: %[[SQ_SUM_2:.*]] = arith.addf %[[SQ_SUM_1]], %[[SQ_IMAG]] : f32
+// CHECK: %[[LOG_SQ_SUM:.*]] = math.log1p %[[SQ_SUM_2]] : f32
+// CHECK: %[[RESULT_REAL:.*]] = arith.mulf %[[LOG_SQ_SUM]], %[[ONE_HALF]] : f32
// CHECK: %[[REAL_PLUS_ONE:.*]] = arith.addf %[[REAL]], %[[ONE]] : f32
-// CHECK: %[[NEW_COMPLEX:.*]] = complex.create %[[REAL_PLUS_ONE]], %[[IMAG]] : complex<f32>
-// CHECK: %[[REAL:.*]] = complex.re %[[NEW_COMPLEX]] : complex<f32>
-// CHECK: %[[IMAG:.*]] = complex.im %[[NEW_COMPLEX]] : complex<f32>
-// CHECK: %[[SQR_REAL:.*]] = arith.mulf %[[REAL]], %[[REAL]] : f32
-// CHECK: %[[SQR_IMAG:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] : f32
-// CHECK: %[[SQ_NORM:.*]] = arith.addf %[[SQR_REAL]], %[[SQR_IMAG]] : f32
-// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
-// CHECK: %[[RESULT_REAL:.*]] = math.log %[[NORM]] : f32
-// CHECK: %[[REAL2:.*]] = complex.re %[[NEW_COMPLEX]] : complex<f32>
-// CHECK: %[[IMAG2:.*]] = complex.im %[[NEW_COMPLEX]] : complex<f32>
-// CHECK: %[[RESULT_IMAG:.*]] = math.atan2 %[[IMAG2]], %[[REAL2]] : f32
+// CHECK: %[[RESULT_IMAG:.*]] = math.atan2 %[[IMAG]], %[[REAL_PLUS_ONE]] : f32
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
diff --git a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
index ced15f571a40f..f67c994a3b78b 100644
--- a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
+++ b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
@@ -303,3 +303,15 @@ func.func @tan_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vec
%double_result = math.tan %double : vector<2xf64>
return %float_result, %double_result : vector<2xf32>, vector<2xf64>
}
+
+// CHECK-LABEL: func @log1p_caller
+// CHECK-SAME: %[[FLOAT:.*]]: f32
+// CHECK-SAME: %[[DOUBLE:.*]]: f64
+func.func @log1p_caller(%float: f32, %double: f64) -> (f32, f64) {
+ // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @log1pf(%[[FLOAT]]) : (f32) -> f32
+ %float_result = math.log1p %float : f32
+ // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @log1p(%[[DOUBLE]]) : (f64) -> f64
+ %double_result = math.log1p %double : f64
+ // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
+ return %float_result, %double_result : f32, f64
+}
\ No newline at end of file
More information about the Mlir-commits mailing list