[Mlir-commits] [mlir] Fix complex abs corner cases. (PR #88373)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Apr 11 03:00:17 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Johannes Reifferscheid (jreiffers)
<details>
<summary>Changes</summary>
The current implementation fails for very small and very large values. For example, (0, -inf) should return inf, but it returns -inf.
This ports the logic used in XLA. Tested with XLA's exhaustive_binary_test_f32_f64.
---
Patch is 39.58 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/88373.diff
3 Files Affected:
- (modified) mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp (+16-38)
- (modified) mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir (+134-214)
- (modified) mlir/test/Conversion/ComplexToStandard/full-conversion.mlir (+12-22)
``````````diff
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index a6fcf6a758c07f..462036e51a1f1c 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -26,7 +26,7 @@ namespace mlir {
using namespace mlir;
namespace {
-// The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780.
+
struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
@@ -35,49 +35,27 @@ struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
ConversionPatternRewriter &rewriter) const override {
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
- arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
+ arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
Type elementType = op.getType();
- Value arg = adaptor.getComplex();
-
- Value zero =
- b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
Value one = b.create<arith::ConstantOp>(elementType,
b.getFloatAttr(elementType, 1.0));
- Value real = b.create<complex::ReOp>(elementType, arg);
- Value imag = b.create<complex::ImOp>(elementType, arg);
-
- Value realIsZero =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
- Value imagIsZero =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
+ Value real = b.create<complex::ReOp>(adaptor.getComplex());
+ Value imag = b.create<complex::ImOp>(adaptor.getComplex());
+ Value absReal = b.create<math::AbsFOp>(real, fmf);
+ Value absImag = b.create<math::AbsFOp>(imag, fmf);
- // Real > Imag
- Value imagDivReal = b.create<arith::DivFOp>(imag, real, fmf.getValue());
- Value imagSq =
- b.create<arith::MulFOp>(imagDivReal, imagDivReal, fmf.getValue());
- Value imagSqPlusOne = b.create<arith::AddFOp>(imagSq, one, fmf.getValue());
- Value imagSqrt = b.create<math::SqrtOp>(imagSqPlusOne, fmf.getValue());
- Value realAbs = b.create<math::AbsFOp>(real, fmf.getValue());
- Value absImag = b.create<arith::MulFOp>(imagSqrt, realAbs, fmf.getValue());
-
- // Real <= Imag
- Value realDivImag = b.create<arith::DivFOp>(real, imag, fmf.getValue());
- Value realSq =
- b.create<arith::MulFOp>(realDivImag, realDivImag, fmf.getValue());
- Value realSqPlusOne = b.create<arith::AddFOp>(realSq, one, fmf.getValue());
- Value realSqrt = b.create<math::SqrtOp>(realSqPlusOne, fmf.getValue());
- Value imagAbs = b.create<math::AbsFOp>(imag, fmf.getValue());
- Value absReal = b.create<arith::MulFOp>(realSqrt, imagAbs, fmf.getValue());
-
- rewriter.replaceOpWithNewOp<arith::SelectOp>(
- op, realIsZero, imagAbs,
- b.create<arith::SelectOp>(
- imagIsZero, realAbs,
- b.create<arith::SelectOp>(
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, real, imag),
- absImag, absReal)));
+ Value max = b.create<arith::MaximumFOp>(absReal, absImag, fmf);
+ Value min = b.create<arith::MinimumFOp>(absReal, absImag, fmf);
+ Value ratio = b.create<arith::DivFOp>(min, max, fmf);
+ Value ratioSq = b.create<arith::MulFOp>(ratio, ratio, fmf);
+ Value ratioSqPlusOne = b.create<arith::AddFOp>(ratioSq, one, fmf);
+ Value sqrt = b.create<math::SqrtOp>(ratioSqPlusOne, fmf);
+ Value result = b.create<arith::MulFOp>(max, sqrt, fmf);
+ Value isNaN =
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, result, result, fmf);
+ rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, min, result);
return success();
}
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 46dba04a88aa0c..a1de61d10bb226 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -8,29 +8,21 @@ func.func @complex_abs(%arg: complex<f32>) -> f32 {
return %abs : f32
}
-// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
-// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
-// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] : f32
-// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32
-// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] : f32
-// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] : f32
-// CHECK: %[[REAL_ABS:.*]] = math.absf %[[REAL]] : f32
-// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL_ABS]] : f32
-// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] : f32
-// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32
-// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] : f32
-// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] : f32
-// CHECK: %[[IMAG_ABS:.*]] = math.absf %[[IMAG]] : f32
-// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG_ABS]] : f32
-// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32
-// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
-// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL_ABS]], %[[ABS1]] : f32
-// CHECK: %[[ABS3:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG_ABS]], %[[ABS2]] : f32
-// CHECK: return %[[ABS3]] : f32
+// CHECK: %[[ABS_REAL:.*]] = math.absf %[[REAL]] : f32
+// CHECK: %[[ABS_IMAG:.*]] = math.absf %[[IMAG]] : f32
+// CHECK: %[[MAX:.*]] = arith.maximumf %[[ABS_REAL]], %[[ABS_IMAG]] : f32
+// CHECK: %[[MIN:.*]] = arith.minimumf %[[ABS_REAL]], %[[ABS_IMAG]] : f32
+// CHECK: %[[RATIO:.*]] = arith.divf %[[MIN]], %[[MAX]] : f32
+// CHECK: %[[RATIO_SQ:.*]] = arith.mulf %[[RATIO]], %[[RATIO]] : f32
+// CHECK: %[[RATIO_SQ_PLUS_ONE:.*]] = arith.addf %[[RATIO_SQ]], %[[ONE]] : f32
+// CHECK: %[[SQRT:.*]] = math.sqrt %[[RATIO_SQ_PLUS_ONE]] : f32
+// CHECK: %[[ABS_OR_NAN:.*]] = arith.mulf %[[MAX]], %[[SQRT]] : f32
+// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[ABS_OR_NAN]], %[[ABS_OR_NAN]] : f32
+// CHECK: %[[ABS:.*]] = arith.select %[[IS_NAN]], %[[MIN]], %[[ABS_OR_NAN]] : f32
+// CHECK: return %[[ABS]] : f32
// -----
@@ -258,29 +250,21 @@ func.func @complex_log(%arg: complex<f32>) -> complex<f32> {
%log = complex.log %arg: complex<f32>
return %log : complex<f32>
}
-// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
-// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
-// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] : f32
-// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32
-// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] : f32
-// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] : f32
-// CHECK: %[[REAL_ABS:.*]] = math.absf %[[REAL]] : f32
-// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL_ABS]] : f32
-// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] : f32
-// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32
-// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] : f32
-// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] : f32
-// CHECK: %[[IMAG_ABS:.*]] = math.absf %[[IMAG]] : f32
-// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG_ABS]] : f32
-// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32
-// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
-// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL_ABS]], %[[ABS1]] : f32
-// CHECK: %[[NORM:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG_ABS]], %[[ABS2]] : f32
-// CHECK: %[[RESULT_REAL:.*]] = math.log %[[NORM]] : f32
+// CHECK: %[[ABS_REAL:.*]] = math.absf %[[REAL]] : f32
+// CHECK: %[[ABS_IMAG:.*]] = math.absf %[[IMAG]] : f32
+// CHECK: %[[MAX:.*]] = arith.maximumf %[[ABS_REAL]], %[[ABS_IMAG]] : f32
+// CHECK: %[[MIN:.*]] = arith.minimumf %[[ABS_REAL]], %[[ABS_IMAG]] : f32
+// CHECK: %[[RATIO:.*]] = arith.divf %[[MIN]], %[[MAX]] : f32
+// CHECK: %[[RATIO_SQ:.*]] = arith.mulf %[[RATIO]], %[[RATIO]] : f32
+// CHECK: %[[RATIO_SQ_PLUS_ONE:.*]] = arith.addf %[[RATIO_SQ]], %[[ONE]] : f32
+// CHECK: %[[SQRT:.*]] = math.sqrt %[[RATIO_SQ_PLUS_ONE]] : f32
+// CHECK: %[[RESULT:.*]] = arith.mulf %[[MAX]], %[[SQRT]] : f32
+// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[RESULT]], %[[RESULT]] : f32
+// CHECK: %[[ABS:.*]] = arith.select %[[IS_NAN]], %[[MIN]], %[[RESULT]] : f32
+// CHECK: %[[RESULT_REAL:.*]] = math.log %[[ABS]] : f32
// CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32>
// CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32>
// CHECK: %[[RESULT_IMAG:.*]] = math.atan2 %[[IMAG2]], %[[REAL2]] : f32
@@ -509,30 +493,22 @@ func.func @complex_sign(%arg: complex<f32>) -> complex<f32> {
// CHECK: %[[REAL_IS_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
// CHECK: %[[IMAG_IS_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
// CHECK: %[[IS_ZERO:.*]] = arith.andi %[[REAL_IS_ZERO]], %[[IMAG_IS_ZERO]] : i1
-// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32>
// CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL2]], %[[ZERO]] : f32
-// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG2]], %[[ZERO]] : f32
-// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG2]], %[[REAL2]] : f32
-// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] : f32
-// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] : f32
-// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] : f32
-// CHECK: %[[REAL_ABS:.*]] = math.absf %[[REAL2]] : f32
-// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL_ABS]] : f32
-// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL2]], %[[IMAG2]] : f32
-// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] : f32
-// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] : f32
-// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] : f32
-// CHECK: %[[IMAG_ABS:.*]] = math.absf %[[IMAG2]] : f32
-// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG_ABS]] : f32
-// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL2]], %[[IMAG2]] : f32
-// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
-// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL_ABS]], %[[ABS1]] : f32
-// CHECK: %[[NORM:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG_ABS]], %[[ABS2]] : f32
-// CHECK: %[[REAL_SIGN:.*]] = arith.divf %[[REAL]], %[[NORM]] : f32
-// CHECK: %[[IMAG_SIGN:.*]] = arith.divf %[[IMAG]], %[[NORM]] : f32
+// CHECK: %[[ABS_REAL:.*]] = math.absf %[[REAL2]] : f32
+// CHECK: %[[ABS_IMAG:.*]] = math.absf %[[IMAG2]] : f32
+// CHECK: %[[MAX:.*]] = arith.maximumf %[[ABS_REAL]], %[[ABS_IMAG]] : f32
+// CHECK: %[[MIN:.*]] = arith.minimumf %[[ABS_REAL]], %[[ABS_IMAG]] : f32
+// CHECK: %[[RATIO:.*]] = arith.divf %[[MIN]], %[[MAX]] : f32
+// CHECK: %[[RATIO_SQ:.*]] = arith.mulf %[[RATIO]], %[[RATIO]] : f32
+// CHECK: %[[RATIO_SQ_PLUS_ONE:.*]] = arith.addf %[[RATIO_SQ]], %[[ONE]] : f32
+// CHECK: %[[SQRT:.*]] = math.sqrt %[[RATIO_SQ_PLUS_ONE]] : f32
+// CHECK: %[[ABS_OR_NAN:.*]] = arith.mulf %[[MAX]], %[[SQRT]] : f32
+// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[ABS_OR_NAN]], %[[ABS_OR_NAN]] : f32
+// CHECK: %[[ABS:.*]] = arith.select %[[IS_NAN]], %[[MIN]], %[[ABS_OR_NAN]] : f32
+// CHECK: %[[REAL_SIGN:.*]] = arith.divf %[[REAL]], %[[ABS]] : f32
+// CHECK: %[[IMAG_SIGN:.*]] = arith.divf %[[IMAG]], %[[ABS]] : f32
// CHECK: %[[SIGN:.*]] = complex.create %[[REAL_SIGN]], %[[IMAG_SIGN]] : complex<f32>
// CHECK: %[[RESULT:.*]] = arith.select %[[IS_ZERO]], %[[ARG]], %[[SIGN]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
@@ -725,29 +701,21 @@ func.func @complex_sqrt(%arg: complex<f32>) -> complex<f32> {
// CHECK: %[[VAR0:.*]] = complex.re %[[ARG]] : complex<f32>
// CHECK: %[[VAR1:.*]] = complex.im %[[ARG]] : complex<f32>
// CHECK: %[[VAR2:.*]] = math.absf %[[VAR0]] : f32
-// CHECK: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[CST1:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK: %[[VAR3:.*]] = complex.re %[[ARG]] : complex<f32>
-// CHECK: %[[VAR4:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK: %[[VAR5:.*]] = arith.cmpf oeq, %[[VAR3]], %[[CST0]] : f32
-// CHECK: %[[VAR6:.*]] = arith.cmpf oeq, %[[VAR4]], %[[CST0]] : f32
-// CHECK: %[[VAR7:.*]] = arith.divf %[[VAR4]], %[[VAR3]] : f32
-// CHECK: %[[VAR8:.*]] = arith.mulf %[[VAR7]], %[[VAR7]] : f32
-// CHECK: %[[VAR9:.*]] = arith.addf %[[VAR8]], %[[CST1]] : f32
-// CHECK: %[[VAR10:.*]] = math.sqrt %[[VAR9]] : f32
-// CHECK: %[[VAR11:.*]] = math.absf %[[VAR3]] : f32
-// CHECK: %[[VAR12:.*]] = arith.mulf %[[VAR10]], %[[VAR11]] : f32
-// CHECK: %[[VAR13:.*]] = arith.divf %[[VAR3]], %[[VAR4]] : f32
-// CHECK: %[[VAR14:.*]] = arith.mulf %[[VAR13]], %[[VAR13]] : f32
-// CHECK: %[[VAR15:.*]] = arith.addf %[[VAR14]], %[[CST1]] : f32
-// CHECK: %[[VAR16:.*]] = math.sqrt %[[VAR15]] : f32
-// CHECK: %[[VAR17:.*]] = math.absf %[[VAR4]] : f32
-// CHECK: %[[VAR18:.*]] = arith.mulf %[[VAR16]], %[[VAR17]] : f32
-// CHECK: %[[VAR19:.*]] = arith.cmpf ogt, %[[VAR3]], %[[VAR4]] : f32
-// CHECK: %[[VAR20:.*]] = arith.select %[[VAR19]], %[[VAR12]], %[[VAR18]] : f32
-// CHECK: %[[VAR21:.*]] = arith.select %[[VAR6]], %[[VAR11]], %[[VAR20]] : f32
-// CHECK: %[[VAR22:.*]] = arith.select %[[VAR5]], %[[VAR17]], %[[VAR21]] : f32
-// CHECK: %[[VAR23:.*]] = arith.addf %[[VAR2]], %[[VAR22]] : f32
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[ABS_REAL:.*]] = math.absf %[[REAL]] : f32
+// CHECK: %[[ABS_IMAG:.*]] = math.absf %[[IMAG]] : f32
+// CHECK: %[[MAX:.*]] = arith.maximumf %[[ABS_REAL]], %[[ABS_IMAG]] : f32
+// CHECK: %[[MIN:.*]] = arith.minimumf %[[ABS_REAL]], %[[ABS_IMAG]] : f32
+// CHECK: %[[RATIO:.*]] = arith.divf %[[MIN]], %[[MAX]] : f32
+// CHECK: %[[RATIO_SQ:.*]] = arith.mulf %[[RATIO]], %[[RATIO]] : f32
+// CHECK: %[[RATIO_SQ_PLUS_ONE:.*]] = arith.addf %[[RATIO_SQ]], %[[ONE]] : f32
+// CHECK: %[[SQRT:.*]] = math.sqrt %[[RATIO_SQ_PLUS_ONE]] : f32
+// CHECK: %[[ABS_OR_NAN:.*]] = arith.mulf %[[MAX]], %[[SQRT]] : f32
+// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[ABS_OR_NAN]], %[[ABS_OR_NAN]] : f32
+// CHECK: %[[ABS:.*]] = arith.select %[[IS_NAN]], %[[MIN]], %[[ABS_OR_NAN]] : f32
+// CHECK: %[[VAR23:.*]] = arith.addf %[[VAR2]], %[[ABS]] : f32
// CHECK: %[[CST2:.*]] = arith.constant 5.000000e-01 : f32
// CHECK: %[[VAR24:.*]] = arith.mulf %[[VAR23]], %[[CST2]] : f32
// CHECK: %[[VAR25:.*]] = math.sqrt %[[VAR24]] : f32
@@ -821,29 +789,21 @@ func.func @complex_abs_with_fmf(%arg: complex<f32>) -> f32 {
%abs = complex.abs %arg fastmath<nnan,contract> : complex<f32>
return %abs : f32
}
-// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
-// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
-// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] fastmath<nnan,contract> : f32
-// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] fastmath<nnan,contract> : f32
-// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
-// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
-// CHECK: %[[REAL_ABS:.*]] = math.absf %[[REAL]] fastmath<nnan,contract> : f32
-// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL_ABS]] fastmath<nnan,contract> : f32
-// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
-// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
-// CHECK: %[[IMAG_ABS:.*]] = math.absf %[[IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[ABS_REAL:.*]] = arith.mulf %[[REAL_SQRT]], %[[IMAG_ABS]] fastmath<nnan,contract> : f32
-// CHECK: %[[REAL_GT_IMAG:.*]] = arith.cmpf ogt, %[[REAL]], %[[IMAG]] : f32
-// CHECK: %[[ABS1:.*]] = arith.select %[[REAL_GT_IMAG]], %[[ABS_IMAG]], %[[ABS_REAL]] : f32
-// CHECK: %[[ABS2:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[REAL_ABS]], %[[ABS1]] : f32
-// CHECK: %[[ABS3:.*]] = arith.select %[[IS_REAL_ZERO]], %[[IMAG_ABS]], %[[ABS2]] : f32
-// CHECK: return %[[ABS3]] : f32
+// CHECK: %[[ABS_REAL:.*]] = math.absf %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABS_IMAG:.*]] = math.absf %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[MAX:.*]] = arith.maximumf %[[ABS_REAL]], %[[ABS_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[MIN:.*]] = arith.minimumf %[[ABS_REAL]], %[[ABS_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[RATIO:.*]] = arith.divf %[[MIN]], %[[MAX]] fastmath<nnan,contract> : f32
+// CHECK: %[[RATIO_SQ:.*]] = arith.mulf %[[RATIO]], %[[RATIO]] fastmath<nnan,contract> : f32
+// CHECK: %[[RATIO_SQ_PLUS_ONE:.*]] = arith.addf %[[RATIO_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[SQRT:.*]] = math.sqrt %[[RATIO_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABS_OR_NAN:.*]] = arith.mulf %[[MAX]], %[[SQRT]] fastmath<nnan,contract> : f32
+// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[ABS_OR_NAN]], %[[ABS_OR_NAN]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABS:.*]] = arith.select %[[IS_NAN]], %[[MIN]], %[[ABS_OR_NAN]] : f32
+// CHECK: return %[[ABS]] : f32
// -----
@@ -928,29 +888,21 @@ func.func @complex_log_with_fmf(%arg: complex<f32>) -> complex<f32> {
%log = complex.log %arg fastmath<nnan,contract> : complex<f32>
return %log : complex<f32>
}
-// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK: %[[IS_REAL_ZERO:.*]] = arith.cmpf oeq, %[[REAL]], %[[ZERO]] : f32
-// CHECK: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
-// CHECK: %[[IMAG_DIV_REAL:.*]] = arith.divf %[[IMAG]], %[[REAL]] fastmath<nnan,contract> : f32
-// CHECK: %[[IMAG_SQ:.*]] = arith.mulf %[[IMAG_DIV_REAL]], %[[IMAG_DIV_REAL]] fastmath<nnan,contract> : f32
-// CHECK: %[[IMAG_SQ_PLUS_ONE:.*]] = arith.addf %[[IMAG_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
-// CHECK: %[[IMAG_SQRT:.*]] = math.sqrt %[[IMAG_SQ_PLUS_ONE]] fastmath<nnan,contract> : f32
-// CHECK: %[[REAL_ABS:.*]] = math.absf %[[REAL]] fastmath<nnan,contract> : f32
-// CHECK: %[[ABS_IMAG:.*]] = arith.mulf %[[IMAG_SQRT]], %[[REAL_ABS]] fastmath<nnan,contract> : f32
-// CHECK: %[[REAL_DIV_IMAG:.*]] = arith.divf %[[REAL]], %[[IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[REAL_SQ:.*]] = arith.mulf %[[REAL_DIV_IMAG]], %[[REAL_DIV_IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[REAL_SQ_PLUS_ONE:.*]] = arith.addf %[[REAL_SQ]], %[[ONE]] fastmath<nnan,contract> : f32
-// CHECK: %[[REAL_SQRT:.*]] = math.sqrt %[[REAL_SQ_PLUS_ONE]] fastmath<nnan,con...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/88373
More information about the Mlir-commits
mailing list