[Mlir-commits] [mlir] Revert "[mlir][math] Implement alternative decomposition for tanh (#85025)" (PR #85429)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Mar 15 09:51:40 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-math

@llvm/pr-subscribers-mlir

Author: None (srcarroll)

<details>
<summary>Changes</summary>

Revert "[mlir][math] Implement alternative decomposition for tanh (#85025)"

This reverts commit 58ef9bec071383744fb703ff08df9806f25e4095.

---
Full diff: https://github.com/llvm/llvm-project/pull/85429.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp (+17-23) 
- (modified) mlir/test/Dialect/Math/expand-math.mlir (+10-9) 


``````````diff
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 1750171b81a10e..989a3e5536ec66 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -91,40 +91,34 @@ static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter) {
 }
 
 /// Expands tanh op into
-/// 1-exp^{-2x} / 1+exp^{-2x}
-/// To avoid overflow we exploit the reflection symmetry `tanh(-x) = -tanh(x)`.
-/// We compute a "signs" value which is -1 if input is negative and +1 if input
-/// is positive.  Then multiply the input by this value, guaranteeing that the
-/// result is positive, which also guarantees `exp^{-2x * sign(x)}` is in (0,
-/// 1]. Expand the computation on the input `x * sign(x)`, then multiply the
-/// result by `sign(x)` to retain sign of the real result.
+///   1) 1-exp^{-2x} / 1+exp^{-2x}, if x => 0
+///   2) exp^{2x}-1 / exp^{2x}+1  , if x < 0
 static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
   auto floatType = op.getOperand().getType();
   Location loc = op.getLoc();
-  Value zero = createFloatConst(loc, floatType, 0.0, rewriter);
   Value one = createFloatConst(loc, floatType, 1.0, rewriter);
-  Value negTwo = createFloatConst(loc, floatType, -2.0, rewriter);
-
-  // Compute sign(x) = cast<float_type>(x < 0) * (-2) + 1
-  Value sign = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT,
-                                              op.getOperand(), zero);
-  sign = rewriter.create<arith::SIToFPOp>(loc, floatType, sign);
-  sign = rewriter.create<arith::MulFOp>(loc, sign, negTwo);
-  sign = rewriter.create<arith::AddFOp>(loc, sign, one);
+  Value two = createFloatConst(loc, floatType, 2.0, rewriter);
+  Value doubledX = rewriter.create<arith::MulFOp>(loc, op.getOperand(), two);
 
-  // Normalize input to positive value: y = sign(x) * x
-  Value positiveX = rewriter.create<arith::MulFOp>(loc, sign, op.getOperand());
-
-  // Decompose on normalized input
-  Value negDoubledX = rewriter.create<arith::MulFOp>(loc, negTwo, positiveX);
+  // Case 1: tanh(x) = 1-exp^{-2x} / 1+exp^{-2x}
+  Value negDoubledX = rewriter.create<arith::NegFOp>(loc, doubledX);
   Value exp2x = rewriter.create<math::ExpOp>(loc, negDoubledX);
   Value dividend = rewriter.create<arith::SubFOp>(loc, one, exp2x);
   Value divisor = rewriter.create<arith::AddFOp>(loc, one, exp2x);
   Value positiveRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor);
 
-  // Multiply result by sign(x) to retain signs from negative inputs
-  rewriter.replaceOpWithNewOp<arith::MulFOp>(op, sign, positiveRes);
+  // Case 2: tanh(x) = exp^{2x}-1 / exp^{2x}+1
+  exp2x = rewriter.create<math::ExpOp>(loc, doubledX);
+  dividend = rewriter.create<arith::SubFOp>(loc, exp2x, one);
+  divisor = rewriter.create<arith::AddFOp>(loc, exp2x, one);
+  Value negativeRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor);
 
+  // tanh(x) = x >= 0 ? positiveRes : negativeRes
+  Value zero = createFloatConst(loc, floatType, 0.0, rewriter);
+  Value cmpRes = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
+                                                op.getOperand(), zero);
+  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmpRes, positiveRes,
+                                               negativeRes);
   return success();
 }
 
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 86ee5c8620472b..6ee65b085dad1b 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -7,18 +7,19 @@ func.func @tanh(%arg: f32) -> f32 {
 }
 // CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
 // CHECK-DAG: %[[ONE:.+]] = arith.constant 1.000000e+00 : f32
-// CHECK-DAG: %[[TWO:.+]] = arith.constant -2.000000e+00 : f32
-// CHECK: %[[VAL0:.+]] = arith.cmpf olt, %arg0, %[[ZERO]] : f32
-// CHECK: %[[VAL1:.+]] = arith.sitofp %[[VAL0]] : i1 to f32
-// CHECK: %[[VAL2:.+]] = arith.mulf %[[VAL1]], %[[TWO]] : f32
-// CHECK: %[[SIGN:.+]] = arith.addf %[[VAL2]], %[[ONE]] : f32
-// CHECK: %[[POSX:.+]] = arith.mulf %[[SIGN]], %arg0 : f32
-// CHECK: %[[NEGDOUBLEDX:.+]] = arith.mulf %[[POSX]], %[[TWO]] : f32
+// CHECK-DAG: %[[TWO:.+]] = arith.constant 2.000000e+00 : f32
+// CHECK: %[[DOUBLEDX:.+]] = arith.mulf %arg0, %[[TWO]] : f32
+// CHECK: %[[NEGDOUBLEDX:.+]] = arith.negf %[[DOUBLEDX]] : f32
 // CHECK: %[[EXP1:.+]] = math.exp %[[NEGDOUBLEDX]] : f32
 // CHECK: %[[DIVIDEND1:.+]] = arith.subf %[[ONE]], %[[EXP1]] : f32
 // CHECK: %[[DIVISOR1:.+]] = arith.addf %[[EXP1]], %[[ONE]] : f32
-// CHECK: %[[POSRES:.+]] = arith.divf %[[DIVIDEND1]], %[[DIVISOR1]] : f32
-// CHECK: %[[RESULT:.+]] = arith.mulf %[[SIGN]], %[[POSRES]] : f32
+// CHECK: %[[RES1:.+]] = arith.divf %[[DIVIDEND1]], %[[DIVISOR1]] : f32
+// CHECK: %[[EXP2:.+]] = math.exp %[[DOUBLEDX]] : f32
+// CHECK: %[[DIVIDEND2:.+]] = arith.subf %[[EXP2]], %[[ONE]] : f32
+// CHECK: %[[DIVISOR2:.+]] = arith.addf %[[EXP2]], %[[ONE]] : f32
+// CHECK: %[[RES2:.+]] = arith.divf %[[DIVIDEND2]], %[[DIVISOR2]] : f32
+// CHECK: %[[COND:.+]] = arith.cmpf oge, %arg0, %[[ZERO]] : f32
+// CHECK: %[[RESULT:.+]] = arith.select %[[COND]], %[[RES1]], %[[RES2]] : f32
 // CHECK: return %[[RESULT]]
 
 // -----

``````````

</details>


https://github.com/llvm/llvm-project/pull/85429


More information about the Mlir-commits mailing list