[Mlir-commits] [mlir] [mlir][math] Reland 58ef9bec071383744fb703ff08df9806f25e4095 (PR #85436)

llvmlistbot at llvm.org
Sat Mar 16 07:06:56 PDT 2024


================
@@ -91,34 +91,40 @@ static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter) {
 }
 
 /// Expands tanh op into
-///   1) 1-exp^{-2x} / 1+exp^{-2x}, if x >= 0
-///   2) exp^{2x}-1 / exp^{2x}+1  , if x < 0
+/// (1 - exp^{-2x}) / (1 + exp^{-2x})
+/// To avoid overflow we exploit the reflection symmetry `tanh(-x) = -tanh(x)`.
+/// We compute a "sign" value which is -1 if the input is negative and +1
+/// otherwise. Multiplying the input by this value guarantees a non-negative
+/// result, which in turn guarantees that `exp^{-2x * sign(x)}` is in (0, 1].
+/// Expand the computation on the input `x * sign(x)`, then multiply the
+/// result by `sign(x)` to restore the sign of the true result.
 static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
   auto floatType = op.getOperand().getType();
   Location loc = op.getLoc();
+  Value zero = createFloatConst(loc, floatType, 0.0, rewriter);
   Value one = createFloatConst(loc, floatType, 1.0, rewriter);
-  Value two = createFloatConst(loc, floatType, 2.0, rewriter);
-  Value doubledX = rewriter.create<arith::MulFOp>(loc, op.getOperand(), two);
+  Value negTwo = createFloatConst(loc, floatType, -2.0, rewriter);
+
+  // Compute sign(x) = cast<float_type>(x < 0) * (-2) + 1
+  Value sign = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT,
+                                              op.getOperand(), zero);
+  sign = rewriter.create<arith::UIToFPOp>(loc, floatType, sign);
----------------
srcarroll wrote:

Yah, I just gave all the intermediate results on the way to `sign` the same name so I didn't have to come up with new names, lol. I'll change it.
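To make the reflection trick in the new doc comment concrete, here is a minimal scalar sketch in C++. It models the arithmetic that the rewritten `convertTanhOp` is described as emitting, assuming plain `float` and `<cmath>` in place of the `arith`/`math` dialect ops. The names `signOf`, `expandedTanh` and `naiveTanh` are hypothetical helpers for illustration, not part of MLIR.

```cpp
#include <cmath>

// Hypothetical scalar model of the lowering; the real pattern emits MLIR
// arith/math ops rather than calling <cmath> directly.

// sign(x) = float(x < 0) * (-2) + 1: -1 for negative x, +1 otherwise
// (including x == 0). Mirrors the CmpFOp(OLT) + UIToFPOp sequence.
float signOf(float x) {
  float isNeg = (x < 0.0f) ? 1.0f : 0.0f;  // stands in for uitofp(i1)
  return isNeg * -2.0f + 1.0f;
}

// tanh(x) = sign(x) * (1 - exp(-2|x|)) / (1 + exp(-2|x|)).
// Since x * sign(x) >= 0, the exponential is at most 1 and cannot overflow.
float expandedTanh(float x) {
  float sign = signOf(x);
  float absX = x * sign;              // non-negative
  float e = std::exp(-2.0f * absX);   // at most 1; may underflow to 0
  float t = (1.0f - e) / (1.0f + e);  // tanh(|x|)
  return t * sign;                    // tanh(-x) = -tanh(x) restores the sign
}

// Naive expansion without the symmetry trick, for comparison: exp(-2x)
// overflows to +inf for sufficiently negative x, giving -inf/inf = NaN.
float naiveTanh(float x) {
  float e = std::exp(-2.0f * x);
  return (1.0f - e) / (1.0f + e);
}
```

For example, `naiveTanh(-100.0f)` evaluates `exp(200)`, which overflows `float` to infinity and yields NaN, whereas `expandedTanh(-100.0f)` only evaluates `exp(-200)`, which underflows harmlessly to 0 and returns -1. Deriving `sign` arithmetically from the comparison lets one expression cover both signs of `x`.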

https://github.com/llvm/llvm-project/pull/85436

