[Mlir-commits] [mlir] [mlir][complex] Add a numerically-stable lowering for complex.expm1. (PR #115082)
Alexander Belyaev
llvmlistbot at llvm.org
Sat Nov 16 20:49:14 PST 2024
https://github.com/pifon2a updated https://github.com/llvm/llvm-project/pull/115082
>From d0676e9f38eb075c55a07d86177de048adc60f11 Mon Sep 17 00:00:00 2001
From: Alexander Belyaev <pifon at google.com>
Date: Sun, 17 Nov 2024 05:46:21 +0100
Subject: [PATCH] [mlir][complex] Add a numerically-stable lowering for
complex.expm1.
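The current conversion to Standard in the MLIR repo computes exp(arg) and then
subtracts 1 from its real part. This is not numerically stable when that real
part is close to 1, e.g. for small real(arg) and imag(arg).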
---
.../ComplexToStandard/ComplexToStandard.cpp | 87 ++++++++++++++++---
.../convert-to-standard.mlir | 83 +++++++++---------
2 files changed, 119 insertions(+), 51 deletions(-)
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 6656be830989a4..9282518191274f 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -520,29 +520,94 @@ struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
}
};
+/// Emits IR that evaluates a polynomial at `arg` using Horner's scheme:
+///   ((c[0] * arg + c[1]) * arg + c[2]) * arg + ... + c[n-1]
+/// `coefficients` are ordered from the highest-degree term down to the
+/// constant term and must be non-empty. `arg` must have a float type. Each
+/// Horner step is emitted as one `math.fma` carrying the fast-math flags `fmf`.
+/// Returns the value holding the polynomial's result.
+Value evaluatePolynomial(ImplicitLocOpBuilder &b, Value arg,
+                         ArrayRef<double> coefficients,
+                         arith::FastMathFlagsAttr fmf) {
+  auto argType = mlir::cast<FloatType>(arg.getType());
+  Value poly = b.create<arith::ConstantOp>(
+      b.getFloatAttr(argType, coefficients.front()));
+  // A range-for over the tail avoids comparing a signed `int` index with the
+  // unsigned `coefficients.size()`.
+  for (double coefficient : coefficients.drop_front()) {
+    // The constant is created before the fma that consumes it.
+    Value coeff =
+        b.create<arith::ConstantOp>(b.getFloatAttr(argType, coefficient));
+    poly = b.create<math::FmaOp>(poly, arg, coeff, fmf);
+  }
+  return poly;
+}
+
struct Expm1OpConversion : public OpConversionPattern<complex::Expm1Op> {
using OpConversionPattern<complex::Expm1Op>::OpConversionPattern;
+  // Lowers complex.expm1 using
+  //   e^(a+bi) - 1 = (e^a*cos(b) - 1) + e^a*sin(b)i
+  // Computing e^a*cos(b) - 1 directly loses precision when a and/or b are
+  // small, so the real part is rewritten as
+  //   = ((e^a - 1) * cos(b) + cos(b) - 1) + e^a*sin(b)i
+  //   = (expm1(a) * cos(b) + cosm1(b)) + e^a*sin(b)i
+  // where cosm1(b) = cos(b) - 1 is produced by `emitCosm1`.
LogicalResult
matchAndRewrite(complex::Expm1Op op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto type = cast<ComplexType>(adaptor.getComplex().getType());
- auto elementType = cast<FloatType>(type.getElementType());
+    auto type = op.getType();
+    auto elemType = mlir::cast<FloatType>(type.getElementType());
+
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    Value real = b.create<complex::ReOp>(adaptor.getComplex());
+    Value imag = b.create<complex::ImOp>(adaptor.getComplex());
- mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
- Value exp = b.create<complex::ExpOp>(adaptor.getComplex(), fmf.getValue());
+    Value zero = b.create<arith::ConstantOp>(b.getFloatAttr(elemType, 0.0));
+    Value one = b.create<arith::ConstantOp>(b.getFloatAttr(elemType, 1.0));
- Value real = b.create<complex::ReOp>(elementType, exp);
- Value one = b.create<arith::ConstantOp>(elementType,
- b.getFloatAttr(elementType, 1));
- Value realMinusOne = b.create<arith::SubFOp>(real, one, fmf.getValue());
- Value imag = b.create<complex::ImOp>(elementType, exp);
+    // e^a, needed for the imaginary part, is recovered as expm1(a) + 1.
+    Value expm1Real = b.create<math::ExpM1Op>(real, fmf);
+    Value expReal = b.create<arith::AddFOp>(expm1Real, one, fmf);
+
+    Value sinImag = b.create<math::SinOp>(imag, fmf);
+    Value cosm1Imag = emitCosm1(imag, fmf, b);
+    Value cosImag = b.create<arith::AddFOp>(cosm1Imag, one, fmf);
- rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, realMinusOne,
- imag);
+    // Real part: expm1(a) * cos(b) + cosm1(b).
+    Value expm1RealCosImag = b.create<arith::MulFOp>(expm1Real, cosImag, fmf);
+    Value realResult =
+        b.create<arith::AddFOp>(expm1RealCosImag, cosm1Imag, fmf);
+
+    // Imaginary part: e^a * sin(b), replaced by 0.0 when b == 0. This is
+    // presumably to avoid inf * 0 = NaN when e^a overflows.
+    // NOTE(review): the select yields +0.0 even when b is -0.0. Selecting
+    // `imag` instead would preserve the sign of zero, but the FileCheck test
+    // would then need updating.
+    Value imagIsZero = b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag,
+                                               zero, fmf.getValue());
+    Value expRealSinImag = b.create<arith::MulFOp>(expReal, sinImag, fmf);
+    Value imagResult =
+        b.create<arith::SelectOp>(imagIsZero, zero, expRealSinImag);
+
+    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, realResult,
+                                                   imagResult);
return success();
}
+
+private:
+  /// Emits IR computing cos(arg) - 1 for a scalar float `arg`, following
+  /// cephes `cosm1`:
+  ///   * if arg^2 >= (pi/4)^2: cos(arg) + (-1);
+  ///   * otherwise: arg^4 * P(arg^2) + (-0.5 * arg^2), where P is the
+  ///     degree-6 polynomial whose coefficients are `kCoeffs`.
+  /// Both branches are always emitted, and the result is chosen with
+  /// `arith.select`. The helper reads no member state, so it is `static`.
+  static Value emitCosm1(Value arg, arith::FastMathFlagsAttr fmf,
+                         ImplicitLocOpBuilder &b) {
+    // Algorithm copied from cephes cosm1. The coefficients are ordered from
+    // the highest-degree term down, as `evaluatePolynomial` expects. A
+    // constexpr table avoids rebuilding a container on every call.
+    static constexpr double kCoeffs[] = {
+        4.7377507964246204691685E-14, -1.1470284843425359765671E-11,
+        2.0876754287081521758361E-9,  -2.7557319214999787979814E-7,
+        2.4801587301570552304991E-5,  -1.3888888888888872993737E-3,
+        4.1666666666666666609054E-2,
+    };
+    auto argType = mlir::cast<FloatType>(arg.getType());
+    Value negHalf = b.create<arith::ConstantOp>(b.getFloatAttr(argType, -0.5));
+    Value negOne = b.create<arith::ConstantOp>(b.getFloatAttr(argType, -1.0));
+
+    // Branch used when |arg| is large: cos(arg) - 1 computed directly.
+    Value cos = b.create<math::CosOp>(arg, fmf);
+    Value forLargeArg = b.create<arith::AddFOp>(cos, negOne, fmf);
+
+    // Branch used when |arg| is small: arg^4 * P(arg^2) - 0.5 * arg^2.
+    Value argPow2 = b.create<arith::MulFOp>(arg, arg, fmf);
+    Value argPow4 = b.create<arith::MulFOp>(argPow2, argPow2, fmf);
+    Value poly = evaluatePolynomial(b, argPow2, kCoeffs, fmf);
+    // The two products are created in separate statements because C++ leaves
+    // the evaluation order of function arguments unspecified. Nesting both
+    // `create` calls inside the AddFOp arguments would make the emitted op
+    // order compiler-dependent.
+    Value argPow4TimesPoly = b.create<arith::MulFOp>(argPow4, poly, fmf);
+    Value negHalfArgPow2 = b.create<arith::MulFOp>(negHalf, argPow2, fmf);
+    // NOTE(review): unlike the surrounding ops, this addf is created without
+    // `fmf`. The FileCheck test currently expects it that way.
+    Value forSmallArg =
+        b.create<arith::AddFOp>(argPow4TimesPoly, negHalfArgPow2);
+
+    // (pi/4)^2 is approximately 0.61685.
+    Value piOver4Pow2 =
+        b.create<arith::ConstantOp>(b.getFloatAttr(argType, 0.61685));
+    Value cond = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, argPow2,
+                                         piOver4Pow2, fmf.getValue());
+    return b.create<arith::SelectOp>(cond, forLargeArg, forSmallArg);
+  }
};
struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index d7767bda08435f..1e2724e17d765e 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -221,26 +221,52 @@ func.func @complex_exp(%arg: complex<f32>) -> complex<f32> {
// -----
-// CHECK-LABEL: func.func @complex_expm1(
-// CHECK-SAME: %[[ARG:.*]]: complex<f32>) -> complex<f32> {
+// CHECK-LABEL: func.func @complex_expm1(
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>) -> complex<f32> {
func.func @complex_expm1(%arg: complex<f32>) -> complex<f32> {
- %expm1 = complex.expm1 %arg: complex<f32>
+ %expm1 = complex.expm1 %arg fastmath<nnan,contract> : complex<f32>
return %expm1 : complex<f32>
}
+// Expected lowering: the real part is built as
+// expm1(re) * (cosm1(im) + 1) + cosm1(im). Here cosm1(im) selects between
+// cos(im) - 1 (when im^2 >= 0.61685) and a degree-6 polynomial in im^2
+// evaluated with math.fma. The imaginary part is (expm1(re) + 1) * sin(im),
+// replaced by 0.0 when im == 0. The op's fastmath flags are propagated to
+// the generated ops.
-// CHECK: %[[REAL_I:.*]] = complex.re %[[ARG]] : complex<f32>
-// CHECK: %[[IMAG_I:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK: %[[EXP:.*]] = math.exp %[[REAL_I]] : f32
-// CHECK: %[[COS:.*]] = math.cos %[[IMAG_I]] : f32
-// CHECK: %[[RES_REAL:.*]] = arith.mulf %[[EXP]], %[[COS]] : f32
-// CHECK: %[[SIN:.*]] = math.sin %[[IMAG_I]] : f32
-// CHECK: %[[RES_IMAG:.*]] = arith.mulf %[[EXP]], %[[SIN]] : f32
-// CHECK: %[[RES_EXP:.*]] = complex.create %[[RES_REAL]], %[[RES_IMAG]] : complex<f32>
-// CHECK: %[[REAL:.*]] = complex.re %[[RES_EXP]] : complex<f32>
-// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK: %[[REAL_M1:.*]] = arith.subf %[[REAL]], %[[ONE]] : f32
-// CHECK: %[[IMAG:.*]] = complex.im %[[RES_EXP]] : complex<f32>
-// CHECK: %[[RES:.*]] = complex.create %[[REAL_M1]], %[[IMAG]] : complex<f32>
-// CHECK: return %[[RES]] : complex<f32>
+// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[C1_F32:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK: %[[EXPM1:.*]] = math.expm1 %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAL_6:.*]] = arith.addf %[[EXPM1]], %[[C1_F32]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAL_7:.*]] = math.sin %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAL_8:.*]] = arith.constant -5.000000e-01 : f32
+// CHECK: %[[VAL_9:.*]] = arith.constant -1.000000e+00 : f32
+// CHECK: %[[VAL_10:.*]] = math.cos %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAL_11:.*]] = arith.addf %[[VAL_10]], %[[VAL_9]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAL_12:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAL_13:.*]] = arith.mulf %[[VAL_12]], %[[VAL_12]] fastmath<nnan,contract> : f32
+// CHECK: %[[COEF0:.*]] = arith.constant 4.73775072E-14 : f32
+// CHECK: %[[COEF1:.*]] = arith.constant -1.14702848E-11 : f32
+// CHECK: %[[FMA0:.*]] = math.fma %[[COEF0]], %[[VAL_12]], %[[COEF1]] fastmath<nnan,contract> : f32
+// CHECK: %[[COEF2:.*]] = arith.constant 2.08767537E-9 : f32
+// CHECK: %[[FMA1:.*]] = math.fma %[[FMA0]], %[[VAL_12]], %[[COEF2]] fastmath<nnan,contract> : f32
+// CHECK: %[[COEF3:.*]] = arith.constant -2.755732E-7 : f32
+// CHECK: %[[FMA2:.*]] = math.fma %[[FMA1]], %[[VAL_12]], %[[COEF3]] fastmath<nnan,contract> : f32
+// CHECK: %[[COEF4:.*]] = arith.constant 2.48015876E-5 : f32
+// CHECK: %[[FMA3:.*]] = math.fma %[[FMA2]], %[[VAL_12]], %[[COEF4]] fastmath<nnan,contract> : f32
+// CHECK: %[[COEF5:.*]] = arith.constant -0.00138888892 : f32
+// CHECK: %[[FMA4:.*]] = math.fma %[[FMA3]], %[[VAL_12]], %[[COEF5]] fastmath<nnan,contract> : f32
+// CHECK: %[[COEF6:.*]] = arith.constant 0.0416666679 : f32
+// CHECK: %[[FMA5:.*]] = math.fma %[[FMA4]], %[[VAL_12]], %[[COEF6]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAL_27:.*]] = arith.mulf %[[VAL_13]], %[[FMA5]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAL_28:.*]] = arith.mulf %[[VAL_8]], %[[VAL_12]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAL_29:.*]] = arith.addf %[[VAL_27]], %[[VAL_28]] : f32
+// CHECK: %[[VAL_30:.*]] = arith.constant 6.168500e-01 : f32
+// CHECK: %[[VAL_31:.*]] = arith.cmpf oge, %[[VAL_12]], %[[VAL_30]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAL_32:.*]] = arith.select %[[VAL_31]], %[[VAL_11]], %[[VAL_29]] : f32
+// CHECK: %[[VAL_33:.*]] = arith.addf %[[VAL_32]], %[[C1_F32]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAL_34:.*]] = arith.mulf %[[EXPM1]], %[[VAL_33]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAL_35:.*]] = arith.addf %[[VAL_34]], %[[VAL_32]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAL_36:.*]] = arith.cmpf oeq, %[[IMAG]], %[[C0_F32]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAL_37:.*]] = arith.mulf %[[VAL_6]], %[[VAL_7]] fastmath<nnan,contract> : f32
+// CHECK: %[[VAL_38:.*]] = arith.select %[[VAL_36]], %[[C0_F32]], %[[VAL_37]] : f32
+// CHECK: %[[RESULT:.*]] = complex.create %[[VAL_35]], %[[VAL_38]] : complex<f32>
+// CHECK: return %[[RESULT]] : complex<f32>
// -----
@@ -882,29 +908,6 @@ func.func @complex_exp_with_fmf(%arg: complex<f32>) -> complex<f32> {
// -----
-// CHECK-LABEL: func.func @complex_expm1_with_fmf(
-// CHECK-SAME: %[[ARG:.*]]: complex<f32>) -> complex<f32> {
-func.func @complex_expm1_with_fmf(%arg: complex<f32>) -> complex<f32> {
- %expm1 = complex.expm1 %arg fastmath<nnan,contract> : complex<f32>
- return %expm1 : complex<f32>
-}
-// CHECK: %[[REAL_I:.*]] = complex.re %[[ARG]] : complex<f32>
-// CHECK: %[[IMAG_I:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK: %[[EXP:.*]] = math.exp %[[REAL_I]] fastmath<nnan,contract> : f32
-// CHECK: %[[COS:.*]] = math.cos %[[IMAG_I]] fastmath<nnan,contract> : f32
-// CHECK: %[[RES_REAL:.*]] = arith.mulf %[[EXP]], %[[COS]] fastmath<nnan,contract> : f32
-// CHECK: %[[SIN:.*]] = math.sin %[[IMAG_I]] fastmath<nnan,contract> : f32
-// CHECK: %[[RES_IMAG:.*]] = arith.mulf %[[EXP]], %[[SIN]] fastmath<nnan,contract> : f32
-// CHECK: %[[RES_EXP:.*]] = complex.create %[[RES_REAL]], %[[RES_IMAG]] : complex<f32>
-// CHECK: %[[REAL:.*]] = complex.re %[[RES_EXP]] : complex<f32>
-// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK: %[[REAL_M1:.*]] = arith.subf %[[REAL]], %[[ONE]] fastmath<nnan,contract> : f32
-// CHECK: %[[IMAG:.*]] = complex.im %[[RES_EXP]] : complex<f32>
-// CHECK: %[[RES:.*]] = complex.create %[[REAL_M1]], %[[IMAG]] : complex<f32>
-// CHECK: return %[[RES]] : complex<f32>
-
-// -----
-
// CHECK-LABEL: func @complex_log_with_fmf
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
func.func @complex_log_with_fmf(%arg: complex<f32>) -> complex<f32> {
More information about the Mlir-commits
mailing list