[Mlir-commits] [mlir] [mlir][complex] Prevent underflow in complex.abs (#79786) (PR #81092)

Kai Sasaki llvmlistbot at llvm.org
Wed Feb 7 22:33:44 PST 2024


================
@@ -26,29 +26,59 @@ namespace mlir {
 using namespace mlir;
 
 namespace {
+// The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780.
 struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
   using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto loc = op.getLoc();
-    auto type = op.getType();
+    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
 
     arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
 
-    Value real =
-        rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex());
-    Value imag =
-        rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex());
-    Value realSqr =
-        rewriter.create<arith::MulFOp>(loc, real, real, fmf.getValue());
-    Value imagSqr =
-        rewriter.create<arith::MulFOp>(loc, imag, imag, fmf.getValue());
-    Value sqNorm =
-        rewriter.create<arith::AddFOp>(loc, realSqr, imagSqr, fmf.getValue());
-
-    rewriter.replaceOpWithNewOp<math::SqrtOp>(op, sqNorm);
+    Type elementType = op.getType();
+    Value arg = adaptor.getComplex();
+
+    Value zero =
+        b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
+    Value one = b.create<arith::ConstantOp>(elementType,
+                                            b.getFloatAttr(elementType, 1.0));
+
+    Value real = b.create<complex::ReOp>(elementType, arg);
+    Value imag = b.create<complex::ImOp>(elementType, arg);
+
+    Value realIsZero =
+        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
+    Value imagIsZero =
+        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
+
+    // Real > Imag
+    Value imagDivReal = b.create<arith::DivFOp>(imag, real, fmf.getValue());
+    Value imagSq =
+        b.create<arith::MulFOp>(imagDivReal, imagDivReal, fmf.getValue());
+    Value imagSqPlusOne = b.create<arith::AddFOp>(imagSq, one, fmf.getValue());
+    Value imagSqrt = b.create<math::SqrtOp>(imagSqPlusOne, fmf.getValue());
+    Value realAbs = b.create<math::AbsFOp>(real, fmf.getValue());
+    Value absImag = b.create<arith::MulFOp>(imagSqrt, realAbs, fmf.getValue());
+
+    // Real <= Imag
+    Value realDivImag = b.create<arith::DivFOp>(real, imag, fmf.getValue());
+    Value realSq =
+        b.create<arith::MulFOp>(realDivImag, realDivImag, fmf.getValue());
+    Value realSqPlusOne = b.create<arith::AddFOp>(realSq, one, fmf.getValue());
+    Value realSqrt = b.create<math::SqrtOp>(realSqPlusOne, fmf.getValue());
+    Value imagAbs = b.create<math::AbsFOp>(imag, fmf.getValue());
+    Value absReal = b.create<arith::MulFOp>(realSqrt, imagAbs, fmf.getValue());
+
+    rewriter.replaceOpWithNewOp<arith::SelectOp>(
+        op, realIsZero, imagAbs,
----------------
Lewuathe wrote:

When the real part is zero, the result is `imagAbs` (the absolute value of the imaginary part).

https://github.com/llvm/llvm-project/pull/81092


More information about the Mlir-commits mailing list