[Mlir-commits] [mlir] [mlir][complex] Add a numerically-stable lowering for complex.expm1. (PR #115082)

Alexander Belyaev llvmlistbot at llvm.org
Sat Nov 16 20:48:54 PST 2024


================
@@ -520,29 +520,94 @@ struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
   }
 };
 
+/// Emits IR evaluating a polynomial at `arg` with Horner's scheme.
+///
+/// `coefficients` are ordered from the highest-degree term to the constant
+/// term, i.e. for coefficients {c0, c1, ..., cn} this computes
+/// ((c0 * arg + c1) * arg + ...) * arg + cn. Each Horner step is emitted as a
+/// single `math.fma` carrying the fastmath flags `fmf`.
+/// `arg` must have a FloatType, and `coefficients` must be non-empty.
+Value evaluatePolynomial(ImplicitLocOpBuilder &b, Value arg,
+                         ArrayRef<double> coefficients,
+                         arith::FastMathFlagsAttr fmf) {
+  auto argType = mlir::cast<FloatType>(arg.getType());
+  // Seed the accumulator with the leading (highest-degree) coefficient.
+  Value poly =
+      b.create<arith::ConstantOp>(b.getFloatAttr(argType, coefficients[0]));
+  // Range-for over the remaining coefficients; avoids the signed/unsigned
+  // comparison an `int` index against `size()` would introduce.
+  for (double coefficient : coefficients.drop_front()) {
+    poly = b.create<math::FmaOp>(
+        poly, arg,
+        b.create<arith::ConstantOp>(b.getFloatAttr(argType, coefficient)),
+        fmf);
+  }
+  return poly;
+}
+
 struct Expm1OpConversion : public OpConversionPattern<complex::Expm1Op> {
   using OpConversionPattern<complex::Expm1Op>::OpConversionPattern;
 
+  // e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i
+  //            [handle inaccuracies when a and/or b are small]
+  //            = ((e^a - 1) * cos(b) + cos(b) - 1) + e^a*sin(b)i
+  //            = (expm1(a) * cos(b) + cosm1(b)) + e^a*sin(b)i
   LogicalResult
   matchAndRewrite(complex::Expm1Op op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto type = cast<ComplexType>(adaptor.getComplex().getType());
-    auto elementType = cast<FloatType>(type.getElementType());
+    auto type = op.getType();
+    auto elemType = mlir::cast<FloatType>(type.getElementType());
+
     arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    Value real = b.create<complex::ReOp>(adaptor.getComplex());
+    Value imag = b.create<complex::ImOp>(adaptor.getComplex());
 
-    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-    Value exp = b.create<complex::ExpOp>(adaptor.getComplex(), fmf.getValue());
+    Value zero = b.create<arith::ConstantOp>(b.getFloatAttr(elemType, 0.0));
+    Value one = b.create<arith::ConstantOp>(b.getFloatAttr(elemType, 1.0));
 
-    Value real = b.create<complex::ReOp>(elementType, exp);
-    Value one = b.create<arith::ConstantOp>(elementType,
-                                            b.getFloatAttr(elementType, 1));
-    Value realMinusOne = b.create<arith::SubFOp>(real, one, fmf.getValue());
-    Value imag = b.create<complex::ImOp>(elementType, exp);
+    Value expm1Real = b.create<math::ExpM1Op>(real, fmf);
+    Value expReal = b.create<arith::AddFOp>(expm1Real, one, fmf);
+
+    Value sinImag = b.create<math::SinOp>(imag, fmf);
+    Value cosm1Imag = emitCosm1(imag, fmf, b);
+    Value cosImag = b.create<arith::AddFOp>(cosm1Imag, one, fmf);
 
-    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, realMinusOne,
-                                                   imag);
+    Value realResult = b.create<arith::AddFOp>(
+        b.create<arith::MulFOp>(expm1Real, cosImag, fmf), cosm1Imag, fmf);
+
+    Value imageIsZero = b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag,
----------------
pifon2a wrote:

done

https://github.com/llvm/llvm-project/pull/115082


More information about the Mlir-commits mailing list