[Mlir-commits] [mlir] [mlir][complex] Prevent underflow in complex.abs (PR #76316)

Matthias Springer llvmlistbot at llvm.org
Tue Jan 23 00:20:40 PST 2024


================
@@ -26,29 +26,57 @@ namespace mlir {
 using namespace mlir;
 
 namespace {
+// The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780.
 struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
   using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto loc = op.getLoc();
-    auto type = op.getType();
+    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
 
     arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
 
-    Value real =
-        rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex());
-    Value imag =
-        rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex());
-    Value realSqr =
-        rewriter.create<arith::MulFOp>(loc, real, real, fmf.getValue());
-    Value imagSqr =
-        rewriter.create<arith::MulFOp>(loc, imag, imag, fmf.getValue());
-    Value sqNorm =
-        rewriter.create<arith::AddFOp>(loc, realSqr, imagSqr, fmf.getValue());
-
-    rewriter.replaceOpWithNewOp<math::SqrtOp>(op, sqNorm);
+    Type elementType = op.getType();
+    Value arg = adaptor.getComplex();
+
+    Value zero =
+        b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
+    Value one = b.create<arith::ConstantOp>(elementType,
+                                            b.getFloatAttr(elementType, 1.0));
+
+    Value real = b.create<complex::ReOp>(elementType, arg);
+    Value imag = b.create<complex::ImOp>(elementType, arg);
+
+    Value realIsZero =
+        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
+    Value imagIsZero =
+        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
+
+    // Real > Imag
+    Value imagDivReal = b.create<arith::DivFOp>(imag, real, fmf.getValue());
----------------
matthias-springer wrote:

I'm wondering what happens here if `real` is zero. Should the computation be wrapped in an `scf.if` instead of using the `arith.select` below?


https://github.com/llvm/llvm-project/pull/76316


More information about the Mlir-commits mailing list