[Mlir-commits] [mlir] [mlir][complex] Add a numerically-stable lowering for complex.expm1. (PR #115082)
Alexander Belyaev
llvmlistbot at llvm.org
Sat Nov 16 20:48:54 PST 2024
================
@@ -520,29 +520,94 @@ struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
}
};
+/// Emits IR that evaluates the polynomial described by `coefficients` at
+/// `arg` using Horner's scheme, one `math.fma` per remaining coefficient:
+///   ((c[0] * x + c[1]) * x + c[2]) * x + ... + c[n-1]
+/// Coefficients are therefore ordered from the highest-degree term down to
+/// the constant term. `fmf` is forwarded to every emitted `math.fma`.
+/// Preconditions: `coefficients` is non-empty (its first element is read
+/// unconditionally) and `arg` has a FloatType (enforced by `mlir::cast`).
+Value evaluatePolynomial(ImplicitLocOpBuilder &b, Value arg,
+                         ArrayRef<double> coefficients,
+                         arith::FastMathFlagsAttr fmf) {
+  auto argType = mlir::cast<FloatType>(arg.getType());
+  // Seed the accumulator with the leading (highest-degree) coefficient.
+  Value poly = b.create<arith::ConstantOp>(
+      b.getFloatAttr(argType, coefficients.front()));
+  // Fold in the remaining coefficients: poly = poly * arg + c. Iterating the
+  // ArrayRef directly avoids comparing a signed index against size().
+  for (double coefficient : coefficients.drop_front()) {
+    Value coeffValue =
+        b.create<arith::ConstantOp>(b.getFloatAttr(argType, coefficient));
+    poly = b.create<math::FmaOp>(poly, arg, coeffValue, fmf);
+  }
+  return poly;
+}
+
struct Expm1OpConversion : public OpConversionPattern<complex::Expm1Op> {
using OpConversionPattern<complex::Expm1Op>::OpConversionPattern;
+ // e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i
+ // [handle inaccuracies when a and/or b are small]
+ // = ((e^a - 1) * cos(b) + cos(b) - 1) + e^a*sin(b)i
+ // = (expm1(a) * cos(b) + cosm1(b)) + e^a*sin(b)i
LogicalResult
matchAndRewrite(complex::Expm1Op op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto type = cast<ComplexType>(adaptor.getComplex().getType());
- auto elementType = cast<FloatType>(type.getElementType());
+ auto type = op.getType();
+ auto elemType = mlir::cast<FloatType>(type.getElementType());
+
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
+ ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+ Value real = b.create<complex::ReOp>(adaptor.getComplex());
+ Value imag = b.create<complex::ImOp>(adaptor.getComplex());
- mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
- Value exp = b.create<complex::ExpOp>(adaptor.getComplex(), fmf.getValue());
+ Value zero = b.create<arith::ConstantOp>(b.getFloatAttr(elemType, 0.0));
+ Value one = b.create<arith::ConstantOp>(b.getFloatAttr(elemType, 1.0));
- Value real = b.create<complex::ReOp>(elementType, exp);
- Value one = b.create<arith::ConstantOp>(elementType,
- b.getFloatAttr(elementType, 1));
- Value realMinusOne = b.create<arith::SubFOp>(real, one, fmf.getValue());
- Value imag = b.create<complex::ImOp>(elementType, exp);
+ Value expm1Real = b.create<math::ExpM1Op>(real, fmf);
+ Value expReal = b.create<arith::AddFOp>(expm1Real, one, fmf);
+
+ Value sinImag = b.create<math::SinOp>(imag, fmf);
+ Value cosm1Imag = emitCosm1(imag, fmf, b);
+ Value cosImag = b.create<arith::AddFOp>(cosm1Imag, one, fmf);
- rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, realMinusOne,
- imag);
+ Value realResult = b.create<arith::AddFOp>(
+ b.create<arith::MulFOp>(expm1Real, cosImag, fmf), cosm1Imag, fmf);
+
+ Value imageIsZero = b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag,
----------------
pifon2a wrote:
done
https://github.com/llvm/llvm-project/pull/115082
More information about the Mlir-commits
mailing list