[Mlir-commits] [mlir] [mlir][math] Reland 58ef9bec071383744fb703ff08df9806f25e4095 (PR #85436)
Kai Sasaki
llvmlistbot at llvm.org
Sat Mar 16 03:45:29 PDT 2024
================
@@ -91,34 +91,40 @@ static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter) {
}
/// Expands tanh op into
-/// 1) 1-exp^{-2x} / 1+exp^{-2x}, if x => 0
-/// 2) exp^{2x}-1 / exp^{2x}+1 , if x < 0
+/// (1 - exp^{-2x}) / (1 + exp^{-2x})
+/// To avoid overflow we exploit the reflection symmetry `tanh(-x) = -tanh(x)`.
+/// We compute a "sign" value which is -1 if the input is negative and +1
+/// otherwise. Multiplying the input by this value guarantees that the result
+/// is non-negative, which in turn guarantees that `exp^{-2x * sign(x)}` is in
+/// (0, 1]. Expand the computation on the input `x * sign(x)`, then multiply
+/// the result by `sign(x)` to restore the sign of the real result.
static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
auto floatType = op.getOperand().getType();
Location loc = op.getLoc();
+ Value zero = createFloatConst(loc, floatType, 0.0, rewriter);
Value one = createFloatConst(loc, floatType, 1.0, rewriter);
- Value two = createFloatConst(loc, floatType, 2.0, rewriter);
- Value doubledX = rewriter.create<arith::MulFOp>(loc, op.getOperand(), two);
+ Value negTwo = createFloatConst(loc, floatType, -2.0, rewriter);
+
+ // Compute sign(x) = cast<float_type>(x < 0) * (-2) + 1
+ Value sign = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT,
+ op.getOperand(), zero);
+ sign = rewriter.create<arith::UIToFPOp>(loc, floatType, sign);
----------------
Lewuathe wrote:
I think this value should be named `signFlag` or `isPos`, considering that we use it as a boolean indicating whether the input is positive. The name `sign` suggests that it can take the values 1 or -1.
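To make the naming question concrete, here is a minimal scalar sketch of the expansion described in the patch's doc comment, written with plain `float` arithmetic rather than MLIR ops. It assumes the lowering continues as the comment describes (turn the 0/1 flag into -1/+1, multiply, exponentiate, divide, multiply back); the rest of the pattern is not shown in the quoted hunk, so this follows the comment, not the actual code. The function name `expandedTanh` and the variable name `isNeg` are hypothetical. The flag is called `isNeg` here because the `OLT` comparison against zero is true for negative inputs; only after the `* (-2) + 1` step does the value become a true -1/+1 sign.

```cpp
#include <cmath>

// Hypothetical scalar model of the tanh expansion from the patch comment.
// It mirrors the op sequence (compare, uitofp, mul/add, exp, div, mul) with
// plain float math; it is not the MLIR rewrite pattern itself.
float expandedTanh(float x) {
  // 0/1 flag: 1.0f when x < 0, else 0.0f (models CmpFOp OLT + UIToFPOp).
  float isNeg = (x < 0.0f) ? 1.0f : 0.0f;
  // sign(x) = isNeg * (-2) + 1, i.e. -1 for negative inputs, +1 otherwise.
  float sign = isNeg * -2.0f + 1.0f;
  // x * sign(x) is non-negative, so the exponent below is <= 0.
  float absX = x * sign;
  // exp(-2|x|) lies in (0, 1], so it cannot overflow for any finite x.
  float e = std::exp(-2.0f * absX);
  // tanh(|x|) = (1 - exp(-2|x|)) / (1 + exp(-2|x|)).
  float result = (1.0f - e) / (1.0f + e);
  // Restore the sign using tanh(-x) = -tanh(x).
  return result * sign;
}
```

For large magnitudes such as `x = -100.0f`, `exp(-200)` underflows to zero in single precision and the function returns exactly `-1.0f`, rather than evaluating an exponential of a large positive argument.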
https://github.com/llvm/llvm-project/pull/85436