[Mlir-commits] [mlir] [mlir][complex] Prevent underflow in complex.abs (#79786) (PR #81092)
Kai Sasaki
llvmlistbot at llvm.org
Wed Feb 7 22:33:44 PST 2024
================
@@ -26,29 +26,59 @@ namespace mlir {
using namespace mlir;
namespace {
+// The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780.
struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto loc = op.getLoc();
- auto type = op.getType();
+ mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
- Value real =
- rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex());
- Value imag =
- rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex());
- Value realSqr =
- rewriter.create<arith::MulFOp>(loc, real, real, fmf.getValue());
- Value imagSqr =
- rewriter.create<arith::MulFOp>(loc, imag, imag, fmf.getValue());
- Value sqNorm =
- rewriter.create<arith::AddFOp>(loc, realSqr, imagSqr, fmf.getValue());
-
- rewriter.replaceOpWithNewOp<math::SqrtOp>(op, sqNorm);
+ Type elementType = op.getType();
+ Value arg = adaptor.getComplex();
+
+ Value zero =
+ b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
+ Value one = b.create<arith::ConstantOp>(elementType,
+ b.getFloatAttr(elementType, 1.0));
+
+ Value real = b.create<complex::ReOp>(elementType, arg);
+ Value imag = b.create<complex::ImOp>(elementType, arg);
+
+ Value realIsZero =
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
+ Value imagIsZero =
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
+
+ // Real > Imag
+ Value imagDivReal = b.create<arith::DivFOp>(imag, real, fmf.getValue());
+ Value imagSq =
+ b.create<arith::MulFOp>(imagDivReal, imagDivReal, fmf.getValue());
+ Value imagSqPlusOne = b.create<arith::AddFOp>(imagSq, one, fmf.getValue());
+ Value imagSqrt = b.create<math::SqrtOp>(imagSqPlusOne, fmf.getValue());
+ Value realAbs = b.create<math::AbsFOp>(real, fmf.getValue());
+ Value absImag = b.create<arith::MulFOp>(imagSqrt, realAbs, fmf.getValue());
+
+ // Real <= Imag
+ Value realDivImag = b.create<arith::DivFOp>(real, imag, fmf.getValue());
+ Value realSq =
+ b.create<arith::MulFOp>(realDivImag, realDivImag, fmf.getValue());
+ Value realSqPlusOne = b.create<arith::AddFOp>(realSq, one, fmf.getValue());
+ Value realSqrt = b.create<math::SqrtOp>(realSqPlusOne, fmf.getValue());
+ Value imagAbs = b.create<math::AbsFOp>(imag, fmf.getValue());
+ Value absReal = b.create<arith::MulFOp>(realSqrt, imagAbs, fmf.getValue());
+
+ rewriter.replaceOpWithNewOp<arith::SelectOp>(
+ op, realIsZero, imagAbs,
----------------
Lewuathe wrote:
Use `imagAbs` when the real part is zero, since the `imag / real` branch would otherwise divide by zero.
https://github.com/llvm/llvm-project/pull/81092
More information about the Mlir-commits
mailing list