[Mlir-commits] [mlir] [mlir][complex] Add a numerically-stable lowering for complex.expm1. (PR #115082)

Alexander Belyaev llvmlistbot at llvm.org
Sun Nov 17 08:25:09 PST 2024


https://github.com/pifon2a updated https://github.com/llvm/llvm-project/pull/115082

>From ec60b65085236e5831357230e357e786898e88b7 Mon Sep 17 00:00:00 2001
From: Alexander Belyaev <pifon at google.com>
Date: Sun, 17 Nov 2024 17:24:13 +0100
Subject: [PATCH] [mlir][complex] Add a numerically-stable lowering for
 complex.expm1.

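The current conversion to Standard in the MLIR repo computes exp(z) and then
subtracts 1 from its real part. That subtraction cancels catastrophically
whenever e^re(z) * cos(im(z)) is close to 1 (e.g. for small |z|), so the
result is not numerically stable. The new lowering computes the real part as
expm1(re(z)) * cos(im(z)) + cosm1(im(z)) instead.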
---
 .../ComplexToStandard/ComplexToStandard.cpp   | 87 ++++++++++++++++---
 .../convert-to-standard.mlir                  | 85 +++++++++---------
 2 files changed, 120 insertions(+), 52 deletions(-)

diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 6656be830989a4..9282518191274f 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -520,29 +520,115 @@ struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
   }
 };
 
+/// Evaluates the polynomial
+///   coefficients[0] * arg^(n-1) + ... + coefficients[n-1]
+/// at `arg` with Horner's scheme, i.e. `coefficients` run from the
+/// highest-degree term down to the constant term. Emits one `math.fma` per
+/// coefficient after the first; constants use the float type of `arg`.
+Value evaluatePolynomial(ImplicitLocOpBuilder &b, Value arg,
+                         ArrayRef<double> coefficients,
+                         arith::FastMathFlagsAttr fmf) {
+  assert(!coefficients.empty() && "expected at least one coefficient");
+  auto argType = mlir::cast<FloatType>(arg.getType());
+  Value poly =
+      b.create<arith::ConstantOp>(b.getFloatAttr(argType, coefficients[0]));
+  for (double coefficient : coefficients.drop_front()) {
+    Value cst =
+        b.create<arith::ConstantOp>(b.getFloatAttr(argType, coefficient));
+    poly = b.create<math::FmaOp>(poly, arg, cst, fmf);
+  }
+  return poly;
+}
+
 struct Expm1OpConversion : public OpConversionPattern<complex::Expm1Op> {
   using OpConversionPattern<complex::Expm1Op>::OpConversionPattern;
 
+  // e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i
+  //            [handle inaccuracies when a and/or b are small]
+  //            = ((e^a - 1) * cos(b) + cos(b) - 1) + e^a*sin(b)i
+  //            = (expm1(a) * cos(b) + cosm1(b)) + e^a*sin(b)i
+  //
+  // Evaluating e^z and then subtracting 1 from its real part would cancel
+  // catastrophically whenever e^a*cos(b) is close to 1 (e.g. for small |z|).
   LogicalResult
   matchAndRewrite(complex::Expm1Op op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto type = cast<ComplexType>(adaptor.getComplex().getType());
-    auto elementType = cast<FloatType>(type.getElementType());
+    auto type = op.getType();
+    auto elemType = mlir::cast<FloatType>(type.getElementType());
+
     arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    Value real = b.create<complex::ReOp>(adaptor.getComplex());
+    Value imag = b.create<complex::ImOp>(adaptor.getComplex());
 
-    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-    Value exp = b.create<complex::ExpOp>(adaptor.getComplex(), fmf.getValue());
+    Value zero = b.create<arith::ConstantOp>(b.getFloatAttr(elemType, 0.0));
+    Value one = b.create<arith::ConstantOp>(b.getFloatAttr(elemType, 1.0));
 
-    Value real = b.create<complex::ReOp>(elementType, exp);
-    Value one = b.create<arith::ConstantOp>(elementType,
-                                            b.getFloatAttr(elementType, 1));
-    Value realMinusOne = b.create<arith::SubFOp>(real, one, fmf.getValue());
-    Value imag = b.create<complex::ImOp>(elementType, exp);
+    Value expm1Real = b.create<math::ExpM1Op>(real, fmf);
+    Value expReal = b.create<arith::AddFOp>(expm1Real, one, fmf);
+
+    Value sinImag = b.create<math::SinOp>(imag, fmf);
+    Value cosm1Imag = emitCosm1(imag, fmf, b);
+    Value cosImag = b.create<arith::AddFOp>(cosm1Imag, one, fmf);
 
-    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, realMinusOne,
-                                                   imag);
+    // Real part: expm1(a) * cos(b) + cosm1(b).
+    Value realResult = b.create<arith::AddFOp>(
+        b.create<arith::MulFOp>(expm1Real, cosImag, fmf), cosm1Imag, fmf);
+
+    // Imaginary part: e^a * sin(b). For b == +-0 return b itself: this
+    // avoids inf * 0 = NaN when a = +inf and, unlike a literal +0.0, keeps
+    // the sign of a -0.0 imaginary input.
+    Value imagIsZero = b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag,
+                                               zero, fmf.getValue());
+    Value imagResult = b.create<arith::SelectOp>(
+        imagIsZero, imag, b.create<arith::MulFOp>(expReal, sinImag, fmf));
+
+    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, realResult,
+                                                   imagResult);
     return success();
   }
+
+private:
+  /// Emits IR computing cos(arg) - 1 for a scalar float `arg`.
+  ///
+  /// When arg^2 >= (pi/4)^2 the result is cos(arg) - 1 computed directly.
+  /// Below that threshold, where the subtraction would cancel, it evaluates
+  /// the cephes cosm1 expansion -x^2/2 + x^4 * P(x^2) instead.
+  Value emitCosm1(Value arg, arith::FastMathFlagsAttr fmf,
+                  ImplicitLocOpBuilder &b) const {
+    auto argType = mlir::cast<FloatType>(arg.getType());
+    auto negHalf = b.create<arith::ConstantOp>(b.getFloatAttr(argType, -0.5));
+    auto negOne = b.create<arith::ConstantOp>(b.getFloatAttr(argType, -1.0));
+
+    // Coefficients of P, highest degree first (copied from cephes cosm1).
+    static constexpr double kCoeffs[] = {
+        4.7377507964246204691685E-14, -1.1470284843425359765671E-11,
+        2.0876754287081521758361E-9,  -2.7557319214999787979814E-7,
+        2.4801587301570552304991E-5,  -1.3888888888888872993737E-3,
+        4.1666666666666666609054E-2,
+    };
+    Value cosArg = b.create<math::CosOp>(arg, fmf);
+    Value forLargeArg = b.create<arith::AddFOp>(cosArg, negOne, fmf);
+
+    Value argPow2 = b.create<arith::MulFOp>(arg, arg, fmf);
+    Value argPow4 = b.create<arith::MulFOp>(argPow2, argPow2, fmf);
+    Value poly = evaluatePolynomial(b, argPow2, kCoeffs, fmf);
+
+    // Emit the two products as separate statements so their order in the IR
+    // does not depend on the (unspecified) evaluation order of function
+    // arguments.
+    Value argPow4TimesPoly = b.create<arith::MulFOp>(argPow4, poly, fmf);
+    Value negHalfArgPow2 = b.create<arith::MulFOp>(negHalf, argPow2, fmf);
+    Value forSmallArg =
+        b.create<arith::AddFOp>(argPow4TimesPoly, negHalfArgPow2, fmf);
+
+    // (pi/4)^2 is approximately 0.61685.
+    Value piOver4Pow2 =
+        b.create<arith::ConstantOp>(b.getFloatAttr(argType, 0.61685));
+    Value cond = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, argPow2,
+                                         piOver4Pow2, fmf.getValue());
+    return b.create<arith::SelectOp>(cond, forLargeArg, forSmallArg);
+  }
 };
 
 struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index d7767bda08435f..3d73292e6b8868 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -221,26 +221,52 @@ func.func @complex_exp(%arg: complex<f32>) -> complex<f32> {
 
 // -----
 
-// CHECK-LABEL:   func.func @complex_expm1(
-// CHECK-SAME:                             %[[ARG:.*]]: complex<f32>) -> complex<f32> {
+// CHECK-LABEL: func.func @complex_expm1(
+// CHECK-SAME:    %[[ARG:.*]]: complex<f32>) -> complex<f32> {
 func.func @complex_expm1(%arg: complex<f32>) -> complex<f32> {
-  %expm1 = complex.expm1 %arg: complex<f32>
+  %expm1 = complex.expm1 %arg fastmath<nnan,contract> : complex<f32>
   return %expm1 : complex<f32>
 }
-// CHECK: %[[REAL_I:.*]] = complex.re %[[ARG]] : complex<f32>
-// CHECK: %[[IMAG_I:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK: %[[EXP:.*]] = math.exp %[[REAL_I]] : f32
-// CHECK: %[[COS:.*]] = math.cos %[[IMAG_I]] : f32
-// CHECK: %[[RES_REAL:.*]] = arith.mulf %[[EXP]], %[[COS]] : f32
-// CHECK: %[[SIN:.*]] = math.sin %[[IMAG_I]] : f32
-// CHECK: %[[RES_IMAG:.*]] = arith.mulf %[[EXP]], %[[SIN]] : f32
-// CHECK: %[[RES_EXP:.*]] = complex.create %[[RES_REAL]], %[[RES_IMAG]] : complex<f32>
-// CHECK: %[[REAL:.*]] = complex.re %[[RES_EXP]] : complex<f32>
-// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK: %[[REAL_M1:.*]] = arith.subf %[[REAL]], %[[ONE]] : f32
-// CHECK: %[[IMAG:.*]] = complex.im %[[RES_EXP]] : complex<f32>
-// CHECK: %[[RES:.*]] = complex.create %[[REAL_M1]], %[[IMAG]] : complex<f32>
-// CHECK: return %[[RES]] : complex<f32>
+// CHECK:  %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK:  %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK-DAG:  %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG:  %[[C1_F32:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK:  %[[EXPM1:.*]] = math.expm1 %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK:  %[[EXP_REAL:.*]] = arith.addf %[[EXPM1]], %[[C1_F32]] fastmath<nnan,contract> : f32
+// CHECK:  %[[SIN_IMAG:.*]] = math.sin %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK:  %[[NEG_HALF:.*]] = arith.constant -5.000000e-01 : f32
+// CHECK:  %[[NEG_ONE:.*]] = arith.constant -1.000000e+00 : f32
+// CHECK:  %[[COS:.*]] = math.cos %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK:  %[[COSM1_LARGE:.*]] = arith.addf %[[COS]], %[[NEG_ONE]] fastmath<nnan,contract> : f32
+// CHECK:  %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK:  %[[IMAG_POW4:.*]] = arith.mulf %[[IMAG_SQ]], %[[IMAG_SQ]] fastmath<nnan,contract> : f32
+// CHECK-DAG:  %[[COEF0:.*]] = arith.constant 4.73775072E-14 : f32
+// CHECK-DAG:  %[[COEF1:.*]] = arith.constant -1.14702848E-11 : f32
+// CHECK:  %[[FMA0:.*]] = math.fma %[[COEF0]], %[[IMAG_SQ]], %[[COEF1]] fastmath<nnan,contract> : f32
+// CHECK:  %[[COEF2:.*]] = arith.constant 2.08767537E-9 : f32
+// CHECK:  %[[FMA1:.*]] = math.fma %[[FMA0]], %[[IMAG_SQ]], %[[COEF2]] fastmath<nnan,contract> : f32
+// CHECK:  %[[COEF3:.*]] = arith.constant -2.755732E-7 : f32
+// CHECK:  %[[FMA2:.*]] = math.fma %[[FMA1]], %[[IMAG_SQ]], %[[COEF3]] fastmath<nnan,contract> : f32
+// CHECK:  %[[COEF4:.*]] = arith.constant 2.48015876E-5 : f32
+// CHECK:  %[[FMA3:.*]] = math.fma %[[FMA2]], %[[IMAG_SQ]], %[[COEF4]] fastmath<nnan,contract> : f32
+// CHECK:  %[[COEF5:.*]] = arith.constant -0.00138888892 : f32
+// CHECK:  %[[FMA4:.*]] = math.fma %[[FMA3]], %[[IMAG_SQ]], %[[COEF5]] fastmath<nnan,contract> : f32
+// CHECK:  %[[COEF6:.*]] = arith.constant 0.0416666679 : f32
+// CHECK:  %[[FMA5:.*]] = math.fma %[[FMA4]], %[[IMAG_SQ]], %[[COEF6]] fastmath<nnan,contract> : f32
+// CHECK:  %[[POW4_POLY:.*]] = arith.mulf %[[IMAG_POW4]], %[[FMA5]] fastmath<nnan,contract> : f32
+// CHECK:  %[[NEG_HALF_SQ:.*]] = arith.mulf %[[NEG_HALF]], %[[IMAG_SQ]] fastmath<nnan,contract> : f32
+// CHECK:  %[[COSM1_SMALL:.*]] = arith.addf %[[POW4_POLY]], %[[NEG_HALF_SQ]] fastmath<nnan,contract> : f32
+// CHECK:  %[[PI_OVER_4_SQ:.*]] = arith.constant 6.168500e-01 : f32
+// CHECK:  %[[IS_LARGE:.*]] = arith.cmpf oge, %[[IMAG_SQ]], %[[PI_OVER_4_SQ]] fastmath<nnan,contract> : f32
+// CHECK:  %[[COSM1:.*]] = arith.select %[[IS_LARGE]], %[[COSM1_LARGE]], %[[COSM1_SMALL]] : f32
+// CHECK:  %[[COS_IMAG:.*]] = arith.addf %[[COSM1]], %[[C1_F32]] fastmath<nnan,contract> : f32
+// CHECK:  %[[EXPM1_COS:.*]] = arith.mulf %[[EXPM1]], %[[COS_IMAG]] fastmath<nnan,contract> : f32
+// CHECK:  %[[REAL_RES:.*]] = arith.addf %[[EXPM1_COS]], %[[COSM1]] fastmath<nnan,contract> : f32
+// CHECK:  %[[IMAG_IS_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[C0_F32]] fastmath<nnan,contract> : f32
+// CHECK:  %[[EXP_SIN:.*]] = arith.mulf %[[EXP_REAL]], %[[SIN_IMAG]] fastmath<nnan,contract> : f32
+// CHECK:  %[[IMAG_RES:.*]] = arith.select %[[IMAG_IS_ZERO]], %[[IMAG]], %[[EXP_SIN]] : f32
+// CHECK:  %[[RESULT:.*]] = complex.create %[[REAL_RES]], %[[IMAG_RES]] : complex<f32>
+// CHECK:  return %[[RESULT]] : complex<f32>
 
 // -----
 
@@ -882,29 +908,6 @@ func.func @complex_exp_with_fmf(%arg: complex<f32>) -> complex<f32> {
 
 // -----
 
-// CHECK-LABEL:   func.func @complex_expm1_with_fmf(
-// CHECK-SAME:                             %[[ARG:.*]]: complex<f32>) -> complex<f32> {
-func.func @complex_expm1_with_fmf(%arg: complex<f32>) -> complex<f32> {
-  %expm1 = complex.expm1 %arg fastmath<nnan,contract> : complex<f32>
-  return %expm1 : complex<f32>
-}
-// CHECK: %[[REAL_I:.*]] = complex.re %[[ARG]] : complex<f32>
-// CHECK: %[[IMAG_I:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK: %[[EXP:.*]] = math.exp %[[REAL_I]] fastmath<nnan,contract> : f32
-// CHECK: %[[COS:.*]] = math.cos %[[IMAG_I]] fastmath<nnan,contract> : f32
-// CHECK: %[[RES_REAL:.*]] = arith.mulf %[[EXP]], %[[COS]] fastmath<nnan,contract> : f32
-// CHECK: %[[SIN:.*]] = math.sin %[[IMAG_I]] fastmath<nnan,contract> : f32
-// CHECK: %[[RES_IMAG:.*]] = arith.mulf %[[EXP]], %[[SIN]] fastmath<nnan,contract> : f32
-// CHECK: %[[RES_EXP:.*]] = complex.create %[[RES_REAL]], %[[RES_IMAG]] : complex<f32>
-// CHECK: %[[REAL:.*]] = complex.re %[[RES_EXP]] : complex<f32>
-// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK: %[[REAL_M1:.*]] = arith.subf %[[REAL]], %[[ONE]] fastmath<nnan,contract> : f32
-// CHECK: %[[IMAG:.*]] = complex.im %[[RES_EXP]] : complex<f32>
-// CHECK: %[[RES:.*]] = complex.create %[[REAL_M1]], %[[IMAG]] : complex<f32>
-// CHECK: return %[[RES]] : complex<f32>
-
-// -----
-
 // CHECK-LABEL: func @complex_log_with_fmf
 // CHECK-SAME: %[[ARG:.*]]: complex<f32>
 func.func @complex_log_with_fmf(%arg: complex<f32>) -> complex<f32> {
@@ -2020,4 +2023,4 @@ func.func @complex_angle_with_fmf(%arg: complex<f32>) -> f32 {
 // CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
 // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
 // CHECK: %[[RESULT:.*]] = math.atan2 %[[IMAG]], %[[REAL]] fastmath<nnan,contract> : f32
-// CHECK: return %[[RESULT]] : f32
\ No newline at end of file
+// CHECK: return %[[RESULT]] : f32



More information about the Mlir-commits mailing list