[Mlir-commits] [mlir] Fix complex abs with nnan/ninf. (PR #95080)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jun 11 00:34:29 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Johannes Reifferscheid (jreiffers)

<details>
<summary>Changes</summary>

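The current logic detects inf/inf and 0/0 inputs by checking whether the result is NaN. This doesn't work when the `nnan` fastmath flag is set, because those divisions then produce poison rather than NaN. With both `nnan` and `ninf`, only the 0/0 case can occur, so we can just check for a zero maximum. With only `nnan`, we have to check for the inf/inf and 0/0 cases separately.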

---
Full diff: https://github.com/llvm/llvm-project/pull/95080.diff


2 Files Affected:

- (modified) mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp (+35-1) 
- (modified) mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir (+59-22) 


``````````diff
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index d8150aeb828a5..cba82f97792f2 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -43,7 +43,6 @@ Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf,
   Value ratio = b.create<arith::DivFOp>(min, max, fmf);
   Value ratioSq = b.create<arith::MulFOp>(ratio, ratio, fmf);
   Value ratioSqPlusOne = b.create<arith::AddFOp>(ratioSq, one, fmf);
-  Value result;
 
   if (fn == AbsFn::rsqrt) {
     ratioSqPlusOne = b.create<math::RsqrtOp>(ratioSqPlusOne, fmf);
@@ -51,6 +50,7 @@ Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf,
     max = b.create<math::RsqrtOp>(max, fmf);
   }
 
+  Value result;
   if (fn == AbsFn::sqrt) {
     Value quarter = b.create<arith::ConstantOp>(
         real.getType(), b.getFloatAttr(real.getType(), 0.25));
@@ -63,6 +63,40 @@ Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf,
     result = b.create<arith::MulFOp>(max, sqrt, fmf);
   }
 
+  if (arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
+                                         arith::FastMathFlags::ninf)) {
+    // We only need to handle the 0/0 case here.
+    Value zero = b.create<arith::ConstantOp>(
+        real.getType(), b.getFloatAttr(real.getType(), 0.0));
+    Value maxIsZero =
+        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, max, zero);
+    return b.create<arith::SelectOp>(maxIsZero, min, result);
+  }
+
+  if (arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan)) {
+    Value zero = b.create<arith::ConstantOp>(
+        real.getType(), b.getFloatAttr(real.getType(), 0.0));
+    Value inf = b.create<arith::ConstantOp>(
+        real.getType(),
+        b.getFloatAttr(
+            real.getType(),
+            APFloat::getInf(
+                cast<FloatType>(real.getType()).getFloatSemantics())));
+    Value maxIsInf =
+        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, max, inf, fmf);
+    Value minIsInf =
+        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, min, inf, fmf);
+    // We need to handle inf/inf and 0/0 specially. The former is inf, the
+    // latter is 0. Both produce poison in the division.
+    Value resultIsInf = b.create<arith::AndIOp>(maxIsInf, minIsInf);
+    Value resultIsZero =
+        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, max, zero);
+    result = b.create<arith::SelectOp>(resultIsInf, inf, result);
+    result = b.create<arith::SelectOp>(resultIsZero, zero, result);
+    return result;
+  }
+
+  // This handles both inf/inf and 0/0.
   Value isNaN =
       b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, result, result, fmf);
   return b.create<arith::SelectOp>(isNaN, min, result);
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 6dafe29e2e5f6..ccc85a29c03f1 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s --convert-complex-to-standard --split-input-file |\
-// RUN: FileCheck %s --dump-input=always
+// RUN: FileCheck %s
 
 // CHECK-LABEL: func @complex_abs
 // CHECK-SAME: %[[ARG:.*]]: complex<f32>
@@ -709,9 +709,10 @@ func.func @complex_sqrt_nnan_ninf(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[QUARTER:.*]] = arith.constant 2.500000e-01 : f32
 // CHECK: %[[SQRT_MAX:.*]] = math.sqrt %[[MAX]] fastmath<nnan,ninf> : f32
 // CHECK: %[[POW:.*]] = math.powf %[[RATIO_SQ_PLUS_ONE]], %[[QUARTER]] fastmath<nnan,ninf> : f32
-// CHECK: %[[SQRT_ABS_OR_NAN:.*]] = arith.mulf %[[SQRT_MAX]], %[[POW]] fastmath<nnan,ninf> : f32
-// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[SQRT_ABS_OR_NAN]], %[[SQRT_ABS_OR_NAN]] fastmath<nnan,ninf> : f32
-// CHECK: %[[SQRT_ABS:.*]] = arith.select %[[IS_NAN]], %[[MIN]], %[[SQRT_ABS_OR_NAN]] : f32
+// CHECK: %[[SQRT_ABS_OR_POISON:.*]] = arith.mulf %[[SQRT_MAX]], %[[POW]] fastmath<nnan,ninf> : f32
+// CHECK: %[[ZERO_2:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[IS_POISON:.*]] = arith.cmpf oeq, %[[MAX]], %[[ZERO_2]] fastmath<nnan,ninf> : f32
+// CHECK: %[[SQRT_ABS:.*]] = arith.select %[[IS_POISON]], %[[MIN]], %[[SQRT_ABS_OR_POISON]] : f32
 // CHECK: %[[ARGARG:.*]] = math.atan2 %[[IM]], %[[RE]] fastmath<nnan,ninf> : f32
 // CHECK: %[[SQRTARG:.*]] = arith.mulf %[[ARGARG]], %[[HALF]] fastmath<nnan,ninf> : f32
 // CHECK: %[[COS:.*]] = math.cos %[[SQRTARG]] fastmath<nnan,ninf> : f32
@@ -823,9 +824,15 @@ func.func @complex_abs_with_fmf(%arg: complex<f32>) -> f32 {
 // CHECK: %[[RATIO_SQ:.*]] = arith.mulf %[[RATIO]], %[[RATIO]] fastmath<nnan,contract> : f32
 // CHECK: %[[RATIO_SQ_PLUS_ONE:.*]] = arith.addf %[[RATIO_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
 // CHECK: %[[SQRT:.*]] = math.sqrt %[[RATIO_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
-// CHECK: %[[ABS_OR_NAN:.*]] = arith.mulf %[[MAX]], %[[SQRT]] fastmath<nnan,contract> : f32
-// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[ABS_OR_NAN]], %[[ABS_OR_NAN]] fastmath<nnan,contract> : f32
-// CHECK: %[[ABS:.*]] = arith.select %[[IS_NAN]], %[[MIN]], %[[ABS_OR_NAN]] : f32
+// CHECK: %[[ABS_OR_POISON:.*]] = arith.mulf %[[MAX]], %[[SQRT]] fastmath<nnan,contract> : f32
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00
+// CHECK: %[[INF:.*]] = arith.constant 0x7F800000
+// CHECK: %[[MAX_IS_INF:.*]] = arith.cmpf oeq, %[[MAX]], %[[INF]] fastmath<nnan,contract> : f32
+// CHECK: %[[MIN_IS_INF:.*]] = arith.cmpf oeq, %[[MIN]], %[[INF]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT_IS_INF:.*]] = arith.andi %[[MAX_IS_INF]], %[[MIN_IS_INF]]
+// CHECK: %[[RESULT_IS_ZERO:.*]] = arith.cmpf oeq, %[[MAX]], %[[ZERO]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABS_OR_INF:.*]] = arith.select %[[RESULT_IS_INF]], %[[INF]], %[[ABS_OR_POISON]] : f32
+// CHECK: %[[ABS:.*]] = arith.select %[[RESULT_IS_ZERO]], %[[ZERO]], %[[ABS_OR_INF]] : f32
 // CHECK: return %[[ABS]] : f32
 
 // -----
@@ -922,9 +929,15 @@ func.func @complex_log_with_fmf(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[RATIO_SQ:.*]] = arith.mulf %[[RATIO]], %[[RATIO]] fastmath<nnan,contract> : f32
 // CHECK: %[[RATIO_SQ_PLUS_ONE:.*]] = arith.addf %[[RATIO_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
 // CHECK: %[[SQRT:.*]] = math.sqrt %[[RATIO_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
-// CHECK: %[[ABS_OR_NAN:.*]] = arith.mulf %[[MAX]], %[[SQRT]] fastmath<nnan,contract> : f32
-// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[ABS_OR_NAN]], %[[ABS_OR_NAN]] fastmath<nnan,contract> : f32
-// CHECK: %[[ABS:.*]] = arith.select %[[IS_NAN]], %[[MIN]], %[[ABS_OR_NAN]] : f32
+// CHECK: %[[ABS_OR_POISON:.*]] = arith.mulf %[[MAX]], %[[SQRT]] fastmath<nnan,contract> : f32
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00
+// CHECK: %[[INF:.*]] = arith.constant 0x7F800000
+// CHECK: %[[MAX_IS_INF:.*]] = arith.cmpf oeq, %[[MAX]], %[[INF]] fastmath<nnan,contract> : f32
+// CHECK: %[[MIN_IS_INF:.*]] = arith.cmpf oeq, %[[MIN]], %[[INF]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT_IS_INF:.*]] = arith.andi %[[MAX_IS_INF]], %[[MIN_IS_INF]]
+// CHECK: %[[RESULT_IS_ZERO:.*]] = arith.cmpf oeq, %[[MAX]], %[[ZERO]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABS_OR_INF:.*]] = arith.select %[[RESULT_IS_INF]], %[[INF]], %[[ABS_OR_POISON]] : f32
+// CHECK: %[[ABS:.*]] = arith.select %[[RESULT_IS_ZERO]], %[[ZERO]], %[[ABS_OR_INF]] : f32
 // CHECK: %[[RESULT_REAL:.*]] = math.log %[[ABS]] fastmath<nnan,contract> : f32
 // CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32>
 // CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32>
@@ -1304,9 +1317,15 @@ func.func @complex_atan2_with_fmf(%lhs: complex<f32>,
 // CHECK: %[[QUARTER:.*]] = arith.constant 2.500000e-01 : f32
 // CHECK: %[[SQRT_MAX:.*]] = math.sqrt %[[MAX]] fastmath<nnan,contract> : f32
 // CHECK: %[[POW:.*]] = math.powf %[[RATIO_SQ_PLUS_ONE]], %[[QUARTER]] fastmath<nnan,contract> : f32
-// CHECK: %[[SQRT_ABS_OR_NAN:.*]] = arith.mulf %[[SQRT_MAX]], %[[POW]] fastmath<nnan,contract> : f32
-// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[SQRT_ABS_OR_NAN]], %[[SQRT_ABS_OR_NAN]] fastmath<nnan,contract> : f32
-// CHECK: %[[SQRT_ABS:.*]] = arith.select %[[IS_NAN]], %[[MIN]], %[[SQRT_ABS_OR_NAN]] : f32
+// CHECK: %[[SQRT_ABS_OR_POISON:.*]] = arith.mulf %[[SQRT_MAX]], %[[POW]] fastmath<nnan,contract> : f32
+// CHECK: %[[ZERO_2:.*]] = arith.constant 0.000000e+00
+// CHECK: %[[INF:.*]] = arith.constant 0x7F800000
+// CHECK: %[[MAX_IS_INF:.*]] = arith.cmpf oeq, %[[MAX]], %[[INF]] fastmath<nnan,contract> : f32
+// CHECK: %[[MIN_IS_INF:.*]] = arith.cmpf oeq, %[[MIN]], %[[INF]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT_IS_INF:.*]] = arith.andi %[[MAX_IS_INF]], %[[MIN_IS_INF]]
+// CHECK: %[[RESULT_IS_ZERO:.*]] = arith.cmpf oeq, %[[MAX]], %[[ZERO_2]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABS_OR_INF:.*]] = arith.select %[[RESULT_IS_INF]], %[[INF]], %[[SQRT_ABS_OR_POISON]] : f32
+// CHECK: %[[SQRT_ABS:.*]] = arith.select %[[RESULT_IS_ZERO]], %[[ZERO_2]], %[[ABS_OR_INF]] : f32
 // CHECK: %[[ARGARG:.*]] = math.atan2 %[[IM]], %[[RE]] fastmath<nnan,contract> : f32
 // CHECK: %[[SQRTARG:.*]] = arith.mulf %[[ARGARG]], %[[HALF]] fastmath<nnan,contract> : f32
 // CHECK: %[[COS:.*]] = math.cos %[[SQRTARG]] fastmath<nnan,contract> : f32
@@ -1543,9 +1562,15 @@ func.func @complex_atan2_with_fmf(%lhs: complex<f32>,
 // CHECK: %[[RATIO_SQ:.*]] = arith.mulf %[[RATIO]], %[[RATIO]] fastmath<nnan,contract> : f32
 // CHECK: %[[RATIO_SQ_PLUS_ONE:.*]] = arith.addf %[[RATIO_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
 // CHECK: %[[SQRT:.*]] = math.sqrt %[[RATIO_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
-// CHECK: %[[ABS_OR_NAN:.*]] = arith.mulf %[[MAX]], %[[SQRT]] fastmath<nnan,contract> : f32
-// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[ABS_OR_NAN]], %[[ABS_OR_NAN]] fastmath<nnan,contract> : f32
-// CHECK: %[[ABS:.*]] = arith.select %[[IS_NAN]], %[[MIN]], %[[ABS_OR_NAN]] : f32
+// CHECK: %[[ABS_OR_POISON:.*]] = arith.mulf %[[MAX]], %[[SQRT]] fastmath<nnan,contract> : f32
+// CHECK: %[[ZERO_3:.*]] = arith.constant 0.000000e+00
+// CHECK: %[[INF:.*]] = arith.constant 0x7F800000
+// CHECK: %[[MAX_IS_INF:.*]] = arith.cmpf oeq, %[[MAX]], %[[INF]] fastmath<nnan,contract> : f32
+// CHECK: %[[MIN_IS_INF:.*]] = arith.cmpf oeq, %[[MIN]], %[[INF]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT_IS_INF:.*]] = arith.andi %[[MAX_IS_INF]], %[[MIN_IS_INF]]
+// CHECK: %[[RESULT_IS_ZERO:.*]] = arith.cmpf oeq, %[[MAX]], %[[ZERO_3]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABS_OR_INF:.*]] = arith.select %[[RESULT_IS_INF]], %[[INF]], %[[ABS_OR_POISON]] : f32
+// CHECK: %[[ABS:.*]] = arith.select %[[RESULT_IS_ZERO]], %[[ZERO_3]], %[[ABS_OR_INF]] : f32
 // CHECK: %[[VAR436:.*]] = math.log %[[ABS]] fastmath<nnan,contract> : f32
 // CHECK: %[[VAR437:.*]] = complex.re %[[VAR415]] : complex<f32>
 // CHECK: %[[VAR438:.*]] = complex.im %[[VAR415]] : complex<f32>
@@ -1784,9 +1809,15 @@ func.func @complex_sqrt_with_fmf(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[QUARTER:.*]] = arith.constant 2.500000e-01 : f32
 // CHECK: %[[SQRT_MAX:.*]] = math.sqrt %[[MAX]] fastmath<nnan,contract> : f32
 // CHECK: %[[POW:.*]] = math.powf %[[RATIO_SQ_PLUS_ONE]], %[[QUARTER]] fastmath<nnan,contract> : f32
-// CHECK: %[[SQRT_ABS_OR_NAN:.*]] = arith.mulf %[[SQRT_MAX]], %[[POW]] fastmath<nnan,contract> : f32
-// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[SQRT_ABS_OR_NAN]], %[[SQRT_ABS_OR_NAN]] fastmath<nnan,contract> : f32
-// CHECK: %[[SQRT_ABS:.*]] = arith.select %[[IS_NAN]], %[[MIN]], %[[SQRT_ABS_OR_NAN]] : f32
+// CHECK: %[[SQRT_ABS_OR_POISON:.*]] = arith.mulf %[[SQRT_MAX]], %[[POW]] fastmath<nnan,contract> : f32
+// CHECK: %[[ZERO_2:.*]] = arith.constant 0.000000e+00
+// CHECK: %[[INF:.*]] = arith.constant 0x7F800000
+// CHECK: %[[MAX_IS_INF:.*]] = arith.cmpf oeq, %[[MAX]], %[[INF]] fastmath<nnan,contract> : f32
+// CHECK: %[[MIN_IS_INF:.*]] = arith.cmpf oeq, %[[MIN]], %[[INF]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT_IS_INF:.*]] = arith.andi %[[MAX_IS_INF]], %[[MIN_IS_INF]]
+// CHECK: %[[RESULT_IS_ZERO:.*]] = arith.cmpf oeq, %[[MAX]], %[[ZERO_2]] fastmath<nnan,contract> : f32
+// CHECK: %[[SQRT_ABS_OR_INF:.*]] = arith.select %[[RESULT_IS_INF]], %[[INF]], %[[SQRT_ABS_OR_POISON]] : f32
+// CHECK: %[[SQRT_ABS:.*]] = arith.select %[[RESULT_IS_ZERO]], %[[ZERO_2]], %[[SQRT_ABS_OR_INF]] : f32
 // CHECK: %[[ARGARG:.*]] = math.atan2 %[[IM]], %[[RE]] fastmath<nnan,contract> : f32
 // CHECK: %[[SQRTARG:.*]] = arith.mulf %[[ARGARG]], %[[HALF]] fastmath<nnan,contract> : f32
 // CHECK: %[[COS:.*]] = math.cos %[[SQRTARG]] fastmath<nnan,contract> : f32
@@ -1890,9 +1921,15 @@ func.func @complex_sign_with_fmf(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[RATIO_SQ:.*]] = arith.mulf %[[RATIO]], %[[RATIO]] fastmath<nnan,contract> : f32
 // CHECK: %[[RATIO_SQ_PLUS_ONE:.*]] = arith.addf %[[RATIO_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
 // CHECK: %[[SQRT:.*]] = math.sqrt %[[RATIO_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
-// CHECK: %[[ABS_OR_NAN:.*]] = arith.mulf %[[MAX]], %[[SQRT]] fastmath<nnan,contract> : f32
-// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[ABS_OR_NAN]], %[[ABS_OR_NAN]] fastmath<nnan,contract> : f32
-// CHECK: %[[ABS:.*]] = arith.select %[[IS_NAN]], %[[MIN]], %[[ABS_OR_NAN]] : f32
+// CHECK: %[[ABS_OR_POISON:.*]] = arith.mulf %[[MAX]], %[[SQRT]] fastmath<nnan,contract> : f32
+// CHECK: %[[ZERO_2:.*]] = arith.constant 0.000000e+00
+// CHECK: %[[INF:.*]] = arith.constant 0x7F800000
+// CHECK: %[[MAX_IS_INF:.*]] = arith.cmpf oeq, %[[MAX]], %[[INF]] fastmath<nnan,contract> : f32
+// CHECK: %[[MIN_IS_INF:.*]] = arith.cmpf oeq, %[[MIN]], %[[INF]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT_IS_INF:.*]] = arith.andi %[[MAX_IS_INF]], %[[MIN_IS_INF]]
+// CHECK: %[[RESULT_IS_ZERO:.*]] = arith.cmpf oeq, %[[MAX]], %[[ZERO_2]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABS_OR_INF:.*]] = arith.select %[[RESULT_IS_INF]], %[[INF]], %[[ABS_OR_POISON]] : f32
+// CHECK: %[[ABS:.*]] = arith.select %[[RESULT_IS_ZERO]], %[[ZERO_2]], %[[ABS_OR_INF]] : f32
 // CHECK: %[[REAL_SIGN:.*]] = arith.divf %[[REAL]], %[[ABS]] fastmath<nnan,contract> : f32
 // CHECK: %[[IMAG_SIGN:.*]] = arith.divf %[[IMAG]], %[[ABS]] fastmath<nnan,contract> : f32
 // CHECK: %[[SIGN:.*]] = complex.create %[[REAL_SIGN]], %[[IMAG_SIGN]] : complex<f32>

``````````

</details>


https://github.com/llvm/llvm-project/pull/95080


More information about the Mlir-commits mailing list