[Mlir-commits] [mlir] Fix complex tanh overflows. (PR #88708)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Apr 15 03:18:25 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Johannes Reifferscheid (jreiffers)

<details>
<summary>Changes</summary>

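This ports XLA's lowering of complex tanh and was verified using XLA's exhaustive_unary_test_complex test. The new lowering evaluates tanh(x + iy) = (sinh(2x) + i sin(2y)) / (cosh(2x) + cos(2y)). It builds the real-part terms from expm1(2x) and expm1(-2x), and it returns ±1 (with the sign of x) for the real part once their sum overflows to infinity. This replaces the previous formula, (tanh(x) + i tan(y)) / (1 + i tanh(x) tan(y)), whose tan(y) = sin(y) / cos(y) term grows without bound as cos(y) approaches zero.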

---
Full diff: https://github.com/llvm/llvm-project/pull/88708.diff


2 Files Affected:

- (modified) mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp (+72-18) 
- (modified) mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir (+85-17) 


``````````diff
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 3ebee9baff31bd..03e578136e5901 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -978,30 +978,84 @@ struct TanhOpConversion : public OpConversionPattern<complex::TanhOp> {
   LogicalResult
   matchAndRewrite(complex::TanhOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
     auto loc = op.getLoc();
     auto type = cast<ComplexType>(adaptor.getComplex().getType());
     auto elementType = cast<FloatType>(type.getElementType());
-    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
+    arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
+    const auto &floatSemantics = elementType.getFloatSemantics();
 
-    // The hyperbolic tangent for complex number can be calculated as follows.
-    // tanh(x + i * y) = (tanh(x) + i * tan(y)) / (1 + tanh(x) * tan(y))
-    // See: https://proofwiki.org/wiki/Hyperbolic_Tangent_of_Complex_Number
     Value real =
-        rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
+        b.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
     Value imag =
-        rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
-    Value tanhA = rewriter.create<math::TanhOp>(loc, real, fmf);
-    Value cosB = rewriter.create<math::CosOp>(loc, imag, fmf);
-    Value sinB = rewriter.create<math::SinOp>(loc, imag, fmf);
-    Value tanB = rewriter.create<arith::DivFOp>(loc, sinB, cosB, fmf);
-    Value numerator =
-        rewriter.create<complex::CreateOp>(loc, type, tanhA, tanB);
-    Value one = rewriter.create<arith::ConstantOp>(
-        loc, elementType, rewriter.getFloatAttr(elementType, 1));
-    Value mul = rewriter.create<arith::MulFOp>(loc, tanhA, tanB, fmf);
-    Value denominator = rewriter.create<complex::CreateOp>(loc, type, one, mul);
-    rewriter.replaceOpWithNewOp<complex::DivOp>(op, numerator, denominator,
-                                                fmf);
+        b.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
+
+    auto cst = [&](APFloat v) {
+      return b.create<arith::ConstantOp>(elementType,
+                                         b.getFloatAttr(elementType, v));
+    };
+    Value inf = cst(APFloat::getInf(floatSemantics));
+    Value negOne = b.create<arith::ConstantOp>(
+        elementType, b.getFloatAttr(elementType, -1.0));
+    Value four = b.create<arith::ConstantOp>(elementType,
+                                             b.getFloatAttr(elementType, 4.0));
+    Value twoReal = b.create<arith::AddFOp>(real, real, fmf);
+    Value negTwoReal = b.create<arith::MulFOp>(negOne, twoReal, fmf);
+
+    Value expTwoRealMinusOne = b.create<math::ExpM1Op>(twoReal, fmf);
+    Value expNegTwoRealMinusOne = b.create<math::ExpM1Op>(negTwoReal, fmf);
+    Value realNum =
+        b.create<arith::SubFOp>(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf);
+
+    Value cosImag = b.create<math::CosOp>(imag, fmf);
+    Value cosImagSq = b.create<arith::MulFOp>(cosImag, cosImag, fmf);
+    Value twoCosTwoImagPlusOne = b.create<arith::MulFOp>(cosImagSq, four, fmf);
+    Value sinImag = b.create<math::SinOp>(imag, fmf);
+
+    Value imagNum = b.create<arith::MulFOp>(
+        four, b.create<arith::MulFOp>(cosImag, sinImag, fmf), fmf);
+
+    Value expSumMinusTwo =
+        b.create<arith::AddFOp>(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf);
+    Value denom =
+        b.create<arith::AddFOp>(expSumMinusTwo, twoCosTwoImagPlusOne, fmf);
+
+    Value isInf = b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
+                                          expSumMinusTwo, inf, fmf);
+    Value realLimit = b.create<math::CopySignOp>(negOne, real, fmf);
+
+    Value resultReal = b.create<arith::SelectOp>(
+        isInf, realLimit, b.create<arith::DivFOp>(realNum, denom, fmf));
+    Value resultImag = b.create<arith::DivFOp>(imagNum, denom, fmf);
+
+    if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
+                                            arith::FastMathFlags::ninf)) {
+      Value absReal = b.create<math::AbsFOp>(real, fmf);
+      Value zero = b.create<arith::ConstantOp>(
+          elementType, b.getFloatAttr(elementType, 0.0));
+      Value nan = cst(APFloat::getNaN(floatSemantics));
+
+      Value absRealIsInf =
+          b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absReal, inf, fmf);
+      Value imagIsZero =
+          b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero, fmf);
+      Value absRealIsNotInf = b.create<arith::XOrIOp>(
+          absRealIsInf, b.create<arith::ConstantIntOp>(true, /*width=*/1));
+
+      Value imagNumIsNaN = b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO,
+                                                   imagNum, imagNum, fmf);
+      Value resultRealIsNaN =
+          b.create<arith::AndIOp>(imagNumIsNaN, absRealIsNotInf);
+      Value resultImagIsZero = b.create<arith::OrIOp>(
+          imagIsZero, b.create<arith::AndIOp>(absRealIsInf, imagNumIsNaN));
+
+      resultReal = b.create<arith::SelectOp>(resultRealIsNaN, nan, resultReal);
+      resultImag =
+          b.create<arith::SelectOp>(resultImagIsZero, zero, resultImag);
+    }
+
+    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
+                                                   resultImag);
     return success();
   }
 };
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 8b4ea9777f7976..fa1d564d6ad355 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -679,14 +679,42 @@ func.func @complex_tanh(%arg: complex<f32>) -> complex<f32> {
 }
 // CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
 // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK: %[[TANH_A:.*]] = math.tanh %[[REAL]] : f32
-// CHECK: %[[COS_B:.*]] = math.cos %[[IMAG]] : f32
-// CHECK: %[[SIN_B:.*]] = math.sin %[[IMAG]] : f32
-// CHECK: %[[TAN_B:.*]] = arith.divf %[[SIN_B]], %[[COS_B]] : f32
-// CHECK: %[[NUM:.*]] = complex.create %[[TANH_A]], %[[TAN_B]] : complex<f32>
-// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK: %[[MUL:.*]] = arith.mulf %[[TANH_A]], %[[TAN_B]] : f32
-// CHECK: %[[DENOM:.*]] = complex.create %[[ONE]], %[[MUL]] : complex<f32>
+// CHECK: %[[INF:.*]] = arith.constant 0x7F800000 : f32
+// CHECK: %[[NEG_ONE:.*]] = arith.constant -1.000000e+00 : f32
+// CHECK: %[[FOUR:.*]] = arith.constant 4.000000e+00 : f32
+// CHECK: %[[TWO_REAL:.*]] = arith.addf %[[REAL]], %[[REAL]] : f32
+// CHECK: %[[NEG_TWO_REAL:.*]] = arith.mulf %[[NEG_ONE]], %[[TWO_REAL]] : f32
+// CHECK: %[[EXPM1:.*]] = math.expm1 %[[TWO_REAL]] : f32
+// CHECK: %[[EXPM1_2:.*]] = math.expm1 %[[NEG_TWO_REAL]] : f32
+// CHECK: %[[REAL_NUM:.*]] = arith.subf %[[EXPM1]], %[[EXPM1_2]] : f32
+// CHECK: %[[COS:.*]] = math.cos %[[IMAG]] : f32
+// CHECK: %[[COS_SQ:.*]] = arith.mulf %[[COS]], %[[COS]] : f32
+// CHECK: %[[FOUR_COS_SQ:.*]] = arith.mulf %[[COS_SQ]], %[[FOUR]] : f32
+// CHECK: %[[SIN:.*]] = math.sin %[[IMAG]] : f32
+// CHECK: %[[MUL:.*]] = arith.mulf %[[COS]], %[[SIN]] : f32
+// CHECK: %[[IMAG_NUM:.*]] = arith.mulf %[[FOUR]], %[[MUL]] : f32
+// CHECK: %[[ADD:.*]] = arith.addf %[[EXPM1]], %[[EXPM1_2]] : f32
+// CHECK: %[[DENOM:.*]] = arith.addf %[[ADD]], %[[FOUR_COS_SQ]] : f32
+// CHECK: %[[IS_INF:.*]] = arith.cmpf oeq, %[[ADD]], %[[INF]] : f32
+// CHECK: %[[LIMIT:.*]] = math.copysign %[[NEG_ONE]], %[[REAL]] : f32
+// CHECK: %[[RESULT_REAL:.*]] = arith.divf %[[REAL_NUM]], %[[DENOM]] : f32
+// CHECK: %[[RESULT_REAL2:.*]] = arith.select %[[IS_INF]], %[[LIMIT]], %[[RESULT_REAL]] : f32
+// CHECK: %[[RESULT_IMAG:.*]] = arith.divf %[[IMAG_NUM]], %[[DENOM]] : f32
+// CHECK: %[[ABS_REAL:.*]] = math.absf %[[REAL]] : f32
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[NAN:.*]] = arith.constant 0x7FC00000 : f32
+// CHECK: %[[ABS_REAL_INF:.*]] = arith.cmpf oeq, %[[ABS_REAL]], %[[INF]] : f32
+// CHECK: %[[IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
+// CHECK: %true = arith.constant true
+// CHECK: %[[ABS_REAL_NOT_INF:.*]] = arith.xori %[[ABS_REAL_INF]], %true : i1
+// CHECK: %[[IMAG_IS_NAN:.*]] = arith.cmpf uno, %[[IMAG_NUM]], %[[IMAG_NUM]] : f32
+// CHECK: %[[REAL_IS_NAN:.*]] = arith.andi %[[IMAG_IS_NAN]], %[[ABS_REAL_NOT_INF]] : i1
+// CHECK: %[[AND:.*]] = arith.andi %[[ABS_REAL_INF]], %[[IMAG_IS_NAN]] : i1
+// CHECK: %[[IMAG_IS_NAN2:.*]] = arith.ori %[[IMAG_ZERO]], %[[AND]] : i1
+// CHECK: %[[RESULT_REAL3:.*]] = arith.select %[[REAL_IS_NAN]], %[[NAN]], %[[RESULT_REAL2]] : f32
+// CHECK: %[[RESULT_IMAG2:.*]] = arith.select %[[IMAG_IS_NAN2]], %[[ZERO]], %[[RESULT_IMAG]] : f32
+// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL3]], %[[RESULT_IMAG2]] : complex<f32>
+// CHECK: return %[[RESULT]] : complex<f32>
 
 // -----
 
@@ -2100,7 +2128,6 @@ func.func @complex_tan_with_fmf(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL_WITH_SPECIAL_CASES]], %[[RESULT_IMAG_WITH_SPECIAL_CASES]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
 
-
 // -----
 
 // CHECK-LABEL: func @complex_tanh_with_fmf
@@ -2109,13 +2136,54 @@ func.func @complex_tanh_with_fmf(%arg: complex<f32>) -> complex<f32> {
   %tanh = complex.tanh %arg fastmath<nnan,contract> : complex<f32>
   return %tanh : complex<f32>
 }
+
 // CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
 // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK: %[[TANH_A:.*]] = math.tanh %[[REAL]] fastmath<nnan,contract> : f32
-// CHECK: %[[COS_B:.*]] = math.cos %[[IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[SIN_B:.*]] = math.sin %[[IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[TAN_B:.*]] = arith.divf %[[SIN_B]], %[[COS_B]] fastmath<nnan,contract> : f32
-// CHECK: %[[NUM:.*]] = complex.create %[[TANH_A]], %[[TAN_B]] : complex<f32>
-// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK: %[[MUL:.*]] = arith.mulf %[[TANH_A]], %[[TAN_B]] fastmath<nnan,contract> : f32
-// CHECK: %[[DENOM:.*]] = complex.create %[[ONE]], %[[MUL]] : complex<f32>
+// CHECK: %[[INF:.*]] = arith.constant 0x7F800000 : f32
+// CHECK: %[[NEG_ONE:.*]] = arith.constant -1.000000e+00 : f32
+// CHECK: %[[FOUR:.*]] = arith.constant 4.000000e+00 : f32
+// CHECK: %[[TWO_REAL:.*]] = arith.addf %[[REAL]], %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[NEG_TWO_REAL:.*]] = arith.mulf %[[NEG_ONE]], %[[TWO_REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[EXPM1:.*]] = math.expm1 %[[TWO_REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[EXPM1_2:.*]] = math.expm1 %[[NEG_TWO_REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_NUM:.*]] = arith.subf %[[EXPM1]], %[[EXPM1_2]] fastmath<nnan,contract> : f32
+// CHECK: %[[COS:.*]] = math.cos %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[COS_SQ:.*]] = arith.mulf %[[COS]], %[[COS]] fastmath<nnan,contract> : f32
+// CHECK: %[[FOUR_COS_SQ:.*]] = arith.mulf %[[COS_SQ]], %[[FOUR]] fastmath<nnan,contract> : f32
+// CHECK: %[[SIN:.*]] = math.sin %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[MUL:.*]] = arith.mulf %[[COS]], %[[SIN]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG_NUM:.*]] = arith.mulf %[[FOUR]], %[[MUL]] fastmath<nnan,contract> : f32
+// CHECK: %[[ADD:.*]] = arith.addf %[[EXPM1]], %[[EXPM1_2]] fastmath<nnan,contract> : f32
+// CHECK: %[[DENOM:.*]] = arith.addf %[[ADD]], %[[FOUR_COS_SQ]] fastmath<nnan,contract> : f32
+// CHECK: %[[IS_INF:.*]] = arith.cmpf oeq, %[[ADD]], %[[INF]] fastmath<nnan,contract> : f32
+// CHECK: %[[LIMIT:.*]] = math.copysign %[[NEG_ONE]], %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT_REAL:.*]] = arith.divf %[[REAL_NUM]], %[[DENOM]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT_REAL2:.*]] = arith.select %[[IS_INF]], %[[LIMIT]], %[[RESULT_REAL]] : f32
+// CHECK: %[[RESULT_IMAG:.*]] = arith.divf %[[IMAG_NUM]], %[[DENOM]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABS_REAL:.*]] = math.absf %[[REAL]] fastmath<nnan,contract> : f32
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[NAN:.*]] = arith.constant 0x7FC00000 : f32
+// CHECK: %[[ABS_REAL_INF:.*]] = arith.cmpf oeq, %[[ABS_REAL]], %[[INF]] fastmath<nnan,contract> : f32
+// CHECK: %[[IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] fastmath<nnan,contract> : f32
+// CHECK: %true = arith.constant true
+// CHECK: %[[ABS_REAL_NOT_INF:.*]] = arith.xori %[[ABS_REAL_INF]], %true : i1
+// CHECK: %[[IMAG_IS_NAN:.*]] = arith.cmpf uno, %[[IMAG_NUM]], %[[IMAG_NUM]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_IS_NAN:.*]] = arith.andi %[[IMAG_IS_NAN]], %[[ABS_REAL_NOT_INF]] : i1
+// CHECK: %[[AND:.*]] = arith.andi %[[ABS_REAL_INF]], %[[IMAG_IS_NAN]] : i1
+// CHECK: %[[IMAG_IS_NAN2:.*]] = arith.ori %[[IMAG_ZERO]], %[[AND]] : i1
+// CHECK: %[[RESULT_REAL3:.*]] = arith.select %[[REAL_IS_NAN]], %[[NAN]], %[[RESULT_REAL2]] : f32
+// CHECK: %[[RESULT_IMAG2:.*]] = arith.select %[[IMAG_IS_NAN2]], %[[ZERO]], %[[RESULT_IMAG]] : f32
+// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL3]], %[[RESULT_IMAG2]] : complex<f32>
+// CHECK: return %[[RESULT]] : complex<f32>
+
+// -----
+
+// CHECK-LABEL: func @complex_tanh_nnan_ninf
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+func.func @complex_tanh_nnan_ninf(%arg: complex<f32>) -> complex<f32> {
+  %tanh = complex.tanh %arg fastmath<nnan,ninf> : complex<f32>
+  return %tanh : complex<f32>
+}
+
+// CHECK-COUNT-1: arith.select
+// CHECK-NOT: arith.select

``````````

</details>


https://github.com/llvm/llvm-project/pull/88708


More information about the Mlir-commits mailing list