[Mlir-commits] [mlir] [mlir][math] Reland 58ef9bec071383744fb703ff08df9806f25e4095 (PR #85436)

Kai Sasaki llvmlistbot at llvm.org
Sat Mar 16 03:45:28 PDT 2024


================
@@ -91,34 +91,40 @@ static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter) {
 }
 
 /// Expands tanh op into
-///   1) 1-exp^{-2x} / 1+exp^{-2x}, if x => 0
-///   2) exp^{2x}-1 / exp^{2x}+1  , if x < 0
+/// 1-exp^{-2x} / 1+exp^{-2x}
+/// To avoid overflow we exploit the reflection symmetry `tanh(-x) = -tanh(x)`.
+/// We compute a "signs" value which is -1 if input is negative and +1 if input
+/// is positive.  Then multiply the input by this value, guaranteeing that the
+/// result is positive, which also guarantees `exp^{-2x * sign(x)}` is in (0,
+/// 1]. Expand the computation on the input `x * sign(x)`, then multiply the
+/// result by `sign(x)` to retain sign of the real result.
 static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
   auto floatType = op.getOperand().getType();
   Location loc = op.getLoc();
+  Value zero = createFloatConst(loc, floatType, 0.0, rewriter);
   Value one = createFloatConst(loc, floatType, 1.0, rewriter);
-  Value two = createFloatConst(loc, floatType, 2.0, rewriter);
-  Value doubledX = rewriter.create<arith::MulFOp>(loc, op.getOperand(), two);
+  Value negTwo = createFloatConst(loc, floatType, -2.0, rewriter);
+
+  // Compute sign(x) = cast<float_type>(x < 0) * (-2) + 1
+  Value sign = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT,
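The expansion described in the doc comment above can be checked numerically with a small scalar C++ sketch. This is not the MLIR lowering itself. It mimics the ops with plain `float` arithmetic, and the helper names `signFromLessThan` and `expandedTanh` are hypothetical. The comparison result is converted to 0.0 or 1.0, so `sign(x)` evaluates to -1 for negative inputs and +1 otherwise.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical scalar helper: sign(x) = float(x < 0) * (-2) + 1.
// The comparison result is converted as an unsigned 0/1 value.
float signFromLessThan(float x) {
  bool isNegative = x < 0.0f;               // stands in for the i1 from cmpf olt
  float asFloat = isNegative ? 1.0f : 0.0f; // 1.0 if negative, 0.0 otherwise
  return asFloat * -2.0f + 1.0f;            // -1.0 if negative, +1.0 otherwise
}

// Hypothetical scalar version of the expansion:
// tanh(x) = sign(x) * (1 - exp(-2|x|)) / (1 + exp(-2|x|)).
float expandedTanh(float x) {
  float sign = signFromLessThan(x);
  float positiveX = x * sign;                // |x|, never negative
  float e = std::exp(-2.0f * positiveX);     // in (0, 1] mathematically;
                                             // may underflow to 0 in float
  assert(std::isnan(e) || (e >= 0.0f && e <= 1.0f));
  float magnitude = (1.0f - e) / (1.0f + e); // tanh(|x|)
  return magnitude * sign;                   // restore the sign of tanh(x)
}
```

Because `exp(-2|x|)` never exceeds 1, large inputs cannot overflow. In single precision the exponential simply underflows to 0 for large `|x|`, and the result saturates at +1 or -1.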
----------------
Lewuathe wrote:

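The result of `arith.cmpf` has type `i1`. If we convert an `i1` value of 1 to a float with a signed conversion (`arith.sitofp`), the result is -1, because the single bit is treated as the sign bit and sign-extended. With a signed cast, `cast(x < 0) * (-2) + 1` would evaluate to 3 rather than -1 for negative `x`. If we use `arith.uitofp` instead, 1 stays 1 (and 0 stays 0).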
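To illustrate the cast issue, the following sketch models an `i1` as a `std::uint8_t` holding 0 or 1 and compares a signed (sitofp-like) and an unsigned (uitofp-like) conversion to `float`. The helpers `unsignedI1ToFloat`, `signedI1ToFloat` and `signWith` are hypothetical stand-ins written for illustration, not MLIR APIs.

```cpp
#include <cassert>
#include <cstdint>

// uitofp-like (hypothetical): treat the single bit as an unsigned magnitude.
float unsignedI1ToFloat(std::uint8_t bit) {
  assert(bit <= 1 && "an i1 holds only 0 or 1");
  return static_cast<float>(bit); // 0 -> 0.0, 1 -> 1.0
}

// sitofp-like (hypothetical): treat the single bit as a two's-complement
// sign bit, so sign-extending 1 gives all ones, i.e. -1.
float signedI1ToFloat(std::uint8_t bit) {
  assert(bit <= 1 && "an i1 holds only 0 or 1");
  std::int32_t extended = -static_cast<std::int32_t>(bit); // 1 -> -1, 0 -> 0
  return static_cast<float>(extended);
}

// sign = cast(x < 0) * (-2) + 1, parameterised by the conversion used.
float signWith(float x, float (*convert)(std::uint8_t)) {
  std::uint8_t isNegative = x < 0.0f ? 1 : 0; // plays the role of cmpf olt
  return convert(isNegative) * -2.0f + 1.0f;
}
```

With the unsigned conversion, `signWith(-2.0f, unsignedI1ToFloat)` yields -1.0f as intended, while `signWith(-2.0f, signedI1ToFloat)` yields 3.0f, which is the problem the comment points out.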

https://github.com/llvm/llvm-project/pull/85436

