[Mlir-commits] [mlir] Fix complex abs with nnan/ninf. (PR #95080)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jun 11 00:34:29 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Johannes Reifferscheid (jreiffers)
<details>
<summary>Changes</summary>
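The current logic detects inf/inf and 0/0 inputs by checking the computed result for NaN. This does not work under all fastmath flags: with nnan, the NaN produced by the division is poison, so checking it for NaN is meaningless. With both nnan and ninf, it suffices to check for a zero maximum (only 0/0 can occur). With only nnan, the inf/inf and 0/0 cases have to be checked separately.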
---
Full diff: https://github.com/llvm/llvm-project/pull/95080.diff
2 Files Affected:
- (modified) mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp (+35-1)
- (modified) mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir (+59-22)
``````````diff
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index d8150aeb828a5..cba82f97792f2 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -43,7 +43,6 @@ Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf,
Value ratio = b.create<arith::DivFOp>(min, max, fmf);
Value ratioSq = b.create<arith::MulFOp>(ratio, ratio, fmf);
Value ratioSqPlusOne = b.create<arith::AddFOp>(ratioSq, one, fmf);
- Value result;
if (fn == AbsFn::rsqrt) {
ratioSqPlusOne = b.create<math::RsqrtOp>(ratioSqPlusOne, fmf);
@@ -51,6 +50,7 @@ Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf,
max = b.create<math::RsqrtOp>(max, fmf);
}
+ Value result;
if (fn == AbsFn::sqrt) {
Value quarter = b.create<arith::ConstantOp>(
real.getType(), b.getFloatAttr(real.getType(), 0.25));
@@ -63,6 +63,40 @@ Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf,
result = b.create<arith::MulFOp>(max, sqrt, fmf);
}
+ if (arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
+ arith::FastMathFlags::ninf)) {
+ // We only need to handle the 0/0 case here.
+ Value zero = b.create<arith::ConstantOp>(
+ real.getType(), b.getFloatAttr(real.getType(), 0.0));
+ Value maxIsZero =
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, max, zero);
+ return b.create<arith::SelectOp>(maxIsZero, min, result);
+ }
+
+ if (arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan)) {
+ Value zero = b.create<arith::ConstantOp>(
+ real.getType(), b.getFloatAttr(real.getType(), 0.0));
+ Value inf = b.create<arith::ConstantOp>(
+ real.getType(),
+ b.getFloatAttr(
+ real.getType(),
+ APFloat::getInf(
+ cast<FloatType>(real.getType()).getFloatSemantics())));
+ Value maxIsInf =
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, max, inf, fmf);
+ Value minIsInf =
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, min, inf, fmf);
+ // We need to handle inf/inf and 0/0 specially. The former is inf, the
+ // latter is 0. Both produce poison in the division.
+ Value resultIsInf = b.create<arith::AndIOp>(maxIsInf, minIsInf);
+ Value resultIsZero =
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, max, zero);
+ result = b.create<arith::SelectOp>(resultIsInf, inf, result);
+ result = b.create<arith::SelectOp>(resultIsZero, zero, result);
+ return result;
+ }
+
+ // This handles both inf/inf and 0/0.
Value isNaN =
b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, result, result, fmf);
return b.create<arith::SelectOp>(isNaN, min, result);
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 6dafe29e2e5f6..ccc85a29c03f1 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -1,5 +1,5 @@
// RUN: mlir-opt %s --convert-complex-to-standard --split-input-file |\
-// RUN: FileCheck %s --dump-input=always
+// RUN: FileCheck %s
// CHECK-LABEL: func @complex_abs
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
@@ -709,9 +709,10 @@ func.func @complex_sqrt_nnan_ninf(%arg: complex<f32>) -> complex<f32> {
// CHECK: %[[QUARTER:.*]] = arith.constant 2.500000e-01 : f32
// CHECK: %[[SQRT_MAX:.*]] = math.sqrt %[[MAX]] fastmath<nnan,ninf> : f32
// CHECK: %[[POW:.*]] = math.powf %[[RATIO_SQ_PLUS_ONE]], %[[QUARTER]] fastmath<nnan,ninf> : f32
-// CHECK: %[[SQRT_ABS_OR_NAN:.*]] = arith.mulf %[[SQRT_MAX]], %[[POW]] fastmath<nnan,ninf> : f32
-// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[SQRT_ABS_OR_NAN]], %[[SQRT_ABS_OR_NAN]] fastmath<nnan,ninf> : f32
-// CHECK: %[[SQRT_ABS:.*]] = arith.select %[[IS_NAN]], %[[MIN]], %[[SQRT_ABS_OR_NAN]] : f32
+// CHECK: %[[SQRT_ABS_OR_POISON:.*]] = arith.mulf %[[SQRT_MAX]], %[[POW]] fastmath<nnan,ninf> : f32
+// CHECK: %[[ZERO_2:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[IS_POISON:.*]] = arith.cmpf oeq, %[[MAX]], %[[ZERO_2]] fastmath<nnan,ninf> : f32
+// CHECK: %[[SQRT_ABS:.*]] = arith.select %[[IS_POISON]], %[[MIN]], %[[SQRT_ABS_OR_POISON]] : f32
// CHECK: %[[ARGARG:.*]] = math.atan2 %[[IM]], %[[RE]] fastmath<nnan,ninf> : f32
// CHECK: %[[SQRTARG:.*]] = arith.mulf %[[ARGARG]], %[[HALF]] fastmath<nnan,ninf> : f32
// CHECK: %[[COS:.*]] = math.cos %[[SQRTARG]] fastmath<nnan,ninf> : f32
@@ -823,9 +824,15 @@ func.func @complex_abs_with_fmf(%arg: complex<f32>) -> f32 {
// CHECK: %[[RATIO_SQ:.*]] = arith.mulf %[[RATIO]], %[[RATIO]] fastmath<nnan,contract> : f32
// CHECK: %[[RATIO_SQ_PLUS_ONE:.*]] = arith.addf %[[RATIO_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
// CHECK: %[[SQRT:.*]] = math.sqrt %[[RATIO_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
-// CHECK: %[[ABS_OR_NAN:.*]] = arith.mulf %[[MAX]], %[[SQRT]] fastmath<nnan,contract> : f32
-// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[ABS_OR_NAN]], %[[ABS_OR_NAN]] fastmath<nnan,contract> : f32
-// CHECK: %[[ABS:.*]] = arith.select %[[IS_NAN]], %[[MIN]], %[[ABS_OR_NAN]] : f32
+// CHECK: %[[ABS_OR_POISON:.*]] = arith.mulf %[[MAX]], %[[SQRT]] fastmath<nnan,contract> : f32
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00
+// CHECK: %[[INF:.*]] = arith.constant 0x7F800000
+// CHECK: %[[MAX_IS_INF:.*]] = arith.cmpf oeq, %[[MAX]], %[[INF]] fastmath<nnan,contract> : f32
+// CHECK: %[[MIN_IS_INF:.*]] = arith.cmpf oeq, %[[MIN]], %[[INF]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT_IS_INF:.*]] = arith.andi %[[MAX_IS_INF]], %[[MIN_IS_INF]]
+// CHECK: %[[RESULT_IS_ZERO:.*]] = arith.cmpf oeq, %[[MAX]], %[[ZERO]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABS_OR_INF:.*]] = arith.select %[[RESULT_IS_INF]], %[[INF]], %[[ABS_OR_POISON]] : f32
+// CHECK: %[[ABS:.*]] = arith.select %[[RESULT_IS_ZERO]], %[[ZERO]], %[[ABS_OR_INF]] : f32
// CHECK: return %[[ABS]] : f32
// -----
@@ -922,9 +929,15 @@ func.func @complex_log_with_fmf(%arg: complex<f32>) -> complex<f32> {
// CHECK: %[[RATIO_SQ:.*]] = arith.mulf %[[RATIO]], %[[RATIO]] fastmath<nnan,contract> : f32
// CHECK: %[[RATIO_SQ_PLUS_ONE:.*]] = arith.addf %[[RATIO_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
// CHECK: %[[SQRT:.*]] = math.sqrt %[[RATIO_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
-// CHECK: %[[ABS_OR_NAN:.*]] = arith.mulf %[[MAX]], %[[SQRT]] fastmath<nnan,contract> : f32
-// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[ABS_OR_NAN]], %[[ABS_OR_NAN]] fastmath<nnan,contract> : f32
-// CHECK: %[[ABS:.*]] = arith.select %[[IS_NAN]], %[[MIN]], %[[ABS_OR_NAN]] : f32
+// CHECK: %[[ABS_OR_POISON:.*]] = arith.mulf %[[MAX]], %[[SQRT]] fastmath<nnan,contract> : f32
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00
+// CHECK: %[[INF:.*]] = arith.constant 0x7F800000
+// CHECK: %[[MAX_IS_INF:.*]] = arith.cmpf oeq, %[[MAX]], %[[INF]] fastmath<nnan,contract> : f32
+// CHECK: %[[MIN_IS_INF:.*]] = arith.cmpf oeq, %[[MIN]], %[[INF]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT_IS_INF:.*]] = arith.andi %[[MAX_IS_INF]], %[[MIN_IS_INF]]
+// CHECK: %[[RESULT_IS_ZERO:.*]] = arith.cmpf oeq, %[[MAX]], %[[ZERO]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABS_OR_INF:.*]] = arith.select %[[RESULT_IS_INF]], %[[INF]], %[[ABS_OR_POISON]] : f32
+// CHECK: %[[ABS:.*]] = arith.select %[[RESULT_IS_ZERO]], %[[ZERO]], %[[ABS_OR_INF]] : f32
// CHECK: %[[RESULT_REAL:.*]] = math.log %[[ABS]] fastmath<nnan,contract> : f32
// CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32>
// CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32>
@@ -1304,9 +1317,15 @@ func.func @complex_atan2_with_fmf(%lhs: complex<f32>,
// CHECK: %[[QUARTER:.*]] = arith.constant 2.500000e-01 : f32
// CHECK: %[[SQRT_MAX:.*]] = math.sqrt %[[MAX]] fastmath<nnan,contract> : f32
// CHECK: %[[POW:.*]] = math.powf %[[RATIO_SQ_PLUS_ONE]], %[[QUARTER]] fastmath<nnan,contract> : f32
-// CHECK: %[[SQRT_ABS_OR_NAN:.*]] = arith.mulf %[[SQRT_MAX]], %[[POW]] fastmath<nnan,contract> : f32
-// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[SQRT_ABS_OR_NAN]], %[[SQRT_ABS_OR_NAN]] fastmath<nnan,contract> : f32
-// CHECK: %[[SQRT_ABS:.*]] = arith.select %[[IS_NAN]], %[[MIN]], %[[SQRT_ABS_OR_NAN]] : f32
+// CHECK: %[[SQRT_ABS_OR_POISON:.*]] = arith.mulf %[[SQRT_MAX]], %[[POW]] fastmath<nnan,contract> : f32
+// CHECK: %[[ZERO_2:.*]] = arith.constant 0.000000e+00
+// CHECK: %[[INF:.*]] = arith.constant 0x7F800000
+// CHECK: %[[MAX_IS_INF:.*]] = arith.cmpf oeq, %[[MAX]], %[[INF]] fastmath<nnan,contract> : f32
+// CHECK: %[[MIN_IS_INF:.*]] = arith.cmpf oeq, %[[MIN]], %[[INF]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT_IS_INF:.*]] = arith.andi %[[MAX_IS_INF]], %[[MIN_IS_INF]]
+// CHECK: %[[RESULT_IS_ZERO:.*]] = arith.cmpf oeq, %[[MAX]], %[[ZERO_2]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABS_OR_INF:.*]] = arith.select %[[RESULT_IS_INF]], %[[INF]], %[[SQRT_ABS_OR_POISON]] : f32
+// CHECK: %[[SQRT_ABS:.*]] = arith.select %[[RESULT_IS_ZERO]], %[[ZERO_2]], %[[ABS_OR_INF]] : f32
// CHECK: %[[ARGARG:.*]] = math.atan2 %[[IM]], %[[RE]] fastmath<nnan,contract> : f32
// CHECK: %[[SQRTARG:.*]] = arith.mulf %[[ARGARG]], %[[HALF]] fastmath<nnan,contract> : f32
// CHECK: %[[COS:.*]] = math.cos %[[SQRTARG]] fastmath<nnan,contract> : f32
@@ -1543,9 +1562,15 @@ func.func @complex_atan2_with_fmf(%lhs: complex<f32>,
// CHECK: %[[RATIO_SQ:.*]] = arith.mulf %[[RATIO]], %[[RATIO]] fastmath<nnan,contract> : f32
// CHECK: %[[RATIO_SQ_PLUS_ONE:.*]] = arith.addf %[[RATIO_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
// CHECK: %[[SQRT:.*]] = math.sqrt %[[RATIO_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
-// CHECK: %[[ABS_OR_NAN:.*]] = arith.mulf %[[MAX]], %[[SQRT]] fastmath<nnan,contract> : f32
-// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[ABS_OR_NAN]], %[[ABS_OR_NAN]] fastmath<nnan,contract> : f32
-// CHECK: %[[ABS:.*]] = arith.select %[[IS_NAN]], %[[MIN]], %[[ABS_OR_NAN]] : f32
+// CHECK: %[[ABS_OR_POISON:.*]] = arith.mulf %[[MAX]], %[[SQRT]] fastmath<nnan,contract> : f32
+// CHECK: %[[ZERO_3:.*]] = arith.constant 0.000000e+00
+// CHECK: %[[INF:.*]] = arith.constant 0x7F800000
+// CHECK: %[[MAX_IS_INF:.*]] = arith.cmpf oeq, %[[MAX]], %[[INF]] fastmath<nnan,contract> : f32
+// CHECK: %[[MIN_IS_INF:.*]] = arith.cmpf oeq, %[[MIN]], %[[INF]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT_IS_INF:.*]] = arith.andi %[[MAX_IS_INF]], %[[MIN_IS_INF]]
+// CHECK: %[[RESULT_IS_ZERO:.*]] = arith.cmpf oeq, %[[MAX]], %[[ZERO_3]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABS_OR_INF:.*]] = arith.select %[[RESULT_IS_INF]], %[[INF]], %[[ABS_OR_POISON]] : f32
+// CHECK: %[[ABS:.*]] = arith.select %[[RESULT_IS_ZERO]], %[[ZERO_3]], %[[ABS_OR_INF]] : f32
// CHECK: %[[VAR436:.*]] = math.log %[[ABS]] fastmath<nnan,contract> : f32
// CHECK: %[[VAR437:.*]] = complex.re %[[VAR415]] : complex<f32>
// CHECK: %[[VAR438:.*]] = complex.im %[[VAR415]] : complex<f32>
@@ -1784,9 +1809,15 @@ func.func @complex_sqrt_with_fmf(%arg: complex<f32>) -> complex<f32> {
// CHECK: %[[QUARTER:.*]] = arith.constant 2.500000e-01 : f32
// CHECK: %[[SQRT_MAX:.*]] = math.sqrt %[[MAX]] fastmath<nnan,contract> : f32
// CHECK: %[[POW:.*]] = math.powf %[[RATIO_SQ_PLUS_ONE]], %[[QUARTER]] fastmath<nnan,contract> : f32
-// CHECK: %[[SQRT_ABS_OR_NAN:.*]] = arith.mulf %[[SQRT_MAX]], %[[POW]] fastmath<nnan,contract> : f32
-// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[SQRT_ABS_OR_NAN]], %[[SQRT_ABS_OR_NAN]] fastmath<nnan,contract> : f32
-// CHECK: %[[SQRT_ABS:.*]] = arith.select %[[IS_NAN]], %[[MIN]], %[[SQRT_ABS_OR_NAN]] : f32
+// CHECK: %[[SQRT_ABS_OR_POISON:.*]] = arith.mulf %[[SQRT_MAX]], %[[POW]] fastmath<nnan,contract> : f32
+// CHECK: %[[ZERO_2:.*]] = arith.constant 0.000000e+00
+// CHECK: %[[INF:.*]] = arith.constant 0x7F800000
+// CHECK: %[[MAX_IS_INF:.*]] = arith.cmpf oeq, %[[MAX]], %[[INF]] fastmath<nnan,contract> : f32
+// CHECK: %[[MIN_IS_INF:.*]] = arith.cmpf oeq, %[[MIN]], %[[INF]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT_IS_INF:.*]] = arith.andi %[[MAX_IS_INF]], %[[MIN_IS_INF]]
+// CHECK: %[[RESULT_IS_ZERO:.*]] = arith.cmpf oeq, %[[MAX]], %[[ZERO_2]] fastmath<nnan,contract> : f32
+// CHECK: %[[SQRT_ABS_OR_INF:.*]] = arith.select %[[RESULT_IS_INF]], %[[INF]], %[[SQRT_ABS_OR_POISON]] : f32
+// CHECK: %[[SQRT_ABS:.*]] = arith.select %[[RESULT_IS_ZERO]], %[[ZERO_2]], %[[SQRT_ABS_OR_INF]] : f32
// CHECK: %[[ARGARG:.*]] = math.atan2 %[[IM]], %[[RE]] fastmath<nnan,contract> : f32
// CHECK: %[[SQRTARG:.*]] = arith.mulf %[[ARGARG]], %[[HALF]] fastmath<nnan,contract> : f32
// CHECK: %[[COS:.*]] = math.cos %[[SQRTARG]] fastmath<nnan,contract> : f32
@@ -1890,9 +1921,15 @@ func.func @complex_sign_with_fmf(%arg: complex<f32>) -> complex<f32> {
// CHECK: %[[RATIO_SQ:.*]] = arith.mulf %[[RATIO]], %[[RATIO]] fastmath<nnan,contract> : f32
// CHECK: %[[RATIO_SQ_PLUS_ONE:.*]] = arith.addf %[[RATIO_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
// CHECK: %[[SQRT:.*]] = math.sqrt %[[RATIO_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
-// CHECK: %[[ABS_OR_NAN:.*]] = arith.mulf %[[MAX]], %[[SQRT]] fastmath<nnan,contract> : f32
-// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[ABS_OR_NAN]], %[[ABS_OR_NAN]] fastmath<nnan,contract> : f32
-// CHECK: %[[ABS:.*]] = arith.select %[[IS_NAN]], %[[MIN]], %[[ABS_OR_NAN]] : f32
+// CHECK: %[[ABS_OR_POISON:.*]] = arith.mulf %[[MAX]], %[[SQRT]] fastmath<nnan,contract> : f32
+// CHECK: %[[ZERO_2:.*]] = arith.constant 0.000000e+00
+// CHECK: %[[INF:.*]] = arith.constant 0x7F800000
+// CHECK: %[[MAX_IS_INF:.*]] = arith.cmpf oeq, %[[MAX]], %[[INF]] fastmath<nnan,contract> : f32
+// CHECK: %[[MIN_IS_INF:.*]] = arith.cmpf oeq, %[[MIN]], %[[INF]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT_IS_INF:.*]] = arith.andi %[[MAX_IS_INF]], %[[MIN_IS_INF]]
+// CHECK: %[[RESULT_IS_ZERO:.*]] = arith.cmpf oeq, %[[MAX]], %[[ZERO_2]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABS_OR_INF:.*]] = arith.select %[[RESULT_IS_INF]], %[[INF]], %[[ABS_OR_POISON]] : f32
+// CHECK: %[[ABS:.*]] = arith.select %[[RESULT_IS_ZERO]], %[[ZERO_2]], %[[ABS_OR_INF]] : f32
// CHECK: %[[REAL_SIGN:.*]] = arith.divf %[[REAL]], %[[ABS]] fastmath<nnan,contract> : f32
// CHECK: %[[IMAG_SIGN:.*]] = arith.divf %[[IMAG]], %[[ABS]] fastmath<nnan,contract> : f32
// CHECK: %[[SIGN:.*]] = complex.create %[[REAL_SIGN]], %[[IMAG_SIGN]] : complex<f32>
``````````
</details>
https://github.com/llvm/llvm-project/pull/95080
More information about the Mlir-commits mailing list