[Mlir-commits] [mlir] Reland 58ef9bec071383744fb703ff08df9806f25e4095 (PR #85436)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Mar 15 16:45:40 PDT 2024


https://github.com/srcarroll updated https://github.com/llvm/llvm-project/pull/85436

>From 4f5549ea6bf82c92f51af2e010e803f1b9c5fbbb Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Fri, 15 Mar 2024 12:15:33 -0500
Subject: [PATCH 1/2] Reland 58ef9bec071383744fb703ff08df9806f25e4095

---
 .../Math/Transforms/ExpandPatterns.cpp        |  2 +-
 mlir/test/Dialect/Math/expand-math.mlir       |  2 +-
 .../test-expand-math-approx.mlir              | 19 +++++++++++++++++++
 3 files changed, 21 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 1750171b81a10e..fceafcff8490c3 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -108,7 +108,7 @@ static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
   // Compute sign(x) = cast<float_type>(x < 0) * (-2) + 1
   Value sign = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT,
                                               op.getOperand(), zero);
-  sign = rewriter.create<arith::SIToFPOp>(loc, floatType, sign);
+  sign = rewriter.create<arith::UIToFPOp>(loc, floatType, sign);
   sign = rewriter.create<arith::MulFOp>(loc, sign, negTwo);
   sign = rewriter.create<arith::AddFOp>(loc, sign, one);
 
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 86ee5c8620472b..6326d3a71874b4 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -9,7 +9,7 @@ func.func @tanh(%arg: f32) -> f32 {
 // CHECK-DAG: %[[ONE:.+]] = arith.constant 1.000000e+00 : f32
 // CHECK-DAG: %[[TWO:.+]] = arith.constant -2.000000e+00 : f32
 // CHECK: %[[VAL0:.+]] = arith.cmpf olt, %arg0, %[[ZERO]] : f32
-// CHECK: %[[VAL1:.+]] = arith.sitofp %[[VAL0]] : i1 to f32
+// CHECK: %[[VAL1:.+]] = arith.uitofp %[[VAL0]] : i1 to f32
 // CHECK: %[[VAL2:.+]] = arith.mulf %[[VAL1]], %[[TWO]] : f32
 // CHECK: %[[SIGN:.+]] = arith.addf %[[VAL2]], %[[ONE]] : f32
 // CHECK: %[[POSX:.+]] = arith.mulf %[[SIGN]], %arg0 : f32
diff --git a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
index 541a201c94c586..e2229a392bbf76 100644
--- a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
@@ -683,6 +683,24 @@ func.func @cosh() {
  return
 }
 
+// -------------------------------------------------------------------------- //
+// Tanh.
+// -------------------------------------------------------------------------- //
+
+func.func @tanh_8xf32(%a : vector<8xf32>) {
+  %r = math.tanh %a : vector<8xf32>
+  vector.print %r : vector<8xf32>
+  return
+}
+
+func.func @tanh() {
+  // CHECK: -1, -0.761594, -0.291313, 0, 0.291313, 0.761594, 1, 1
+  %v3 = arith.constant dense<[0xff800000, -1.0, -0.3, 0.0, 0.3, 1.0, 10.0, 0x7f800000]> : vector<8xf32>
+  call @tanh_8xf32(%v3) : (vector<8xf32>) -> ()
+
+ return
+}
+
 func.func @main() {
   call @exp2f() : () -> ()
   call @roundf() : () -> ()
@@ -690,5 +708,6 @@ func.func @main() {
   call @roundeven() : () -> ()
   call @sinh() : () -> ()
   call @cosh() : () -> ()
+  call @tanh() : () -> ()
   return
 }

>From 1ba29891a49a624e3058413ba0661eaf4a5cf10e Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Fri, 15 Mar 2024 18:45:26 -0500
Subject: [PATCH 2/2] fix bad merge

---
 .../Math/Transforms/ExpandPatterns.cpp        | 26 ++++++++-----------
 mlir/test/Dialect/Math/expand-math.mlir       |  9 ++-----
 2 files changed, 13 insertions(+), 22 deletions(-)

diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 6a4a4f0c33e301..fceafcff8490c3 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -91,14 +91,19 @@ static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter) {
 }
 
 /// Expands tanh op into
-///   1) 1-exp^{-2x} / 1+exp^{-2x}, if x => 0
-///   2) exp^{2x}-1 / exp^{2x}+1  , if x < 0
+/// (1 - exp^{-2x}) / (1 + exp^{-2x})
+/// To avoid overflow we exploit the reflection symmetry `tanh(-x) = -tanh(x)`.
+/// We compute a `sign` value which is -1 if the input is negative and +1
+/// otherwise (including zero). Multiplying the input by this value yields a
+/// non-negative value, which guarantees `exp^{-2x * sign(x)}` is in (0, 1].
+/// Expand the computation on the input `x * sign(x)`, then multiply the
+/// result by `sign(x)` to restore the sign of the real result.
 static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
   auto floatType = op.getOperand().getType();
   Location loc = op.getLoc();
+  Value zero = createFloatConst(loc, floatType, 0.0, rewriter);
   Value one = createFloatConst(loc, floatType, 1.0, rewriter);
-  Value two = createFloatConst(loc, floatType, 2.0, rewriter);
-  Value doubledX = rewriter.create<arith::MulFOp>(loc, op.getOperand(), two);
+  Value negTwo = createFloatConst(loc, floatType, -2.0, rewriter);
 
   // Compute sign(x) = cast<float_type>(x < 0) * (-2) + 1
   Value sign = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT,
@@ -117,18 +122,9 @@ static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
   Value divisor = rewriter.create<arith::AddFOp>(loc, one, exp2x);
   Value positiveRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor);
 
-  // Case 2: tanh(x) = exp^{2x}-1 / exp^{2x}+1
-  exp2x = rewriter.create<math::ExpOp>(loc, doubledX);
-  dividend = rewriter.create<arith::SubFOp>(loc, exp2x, one);
-  divisor = rewriter.create<arith::AddFOp>(loc, exp2x, one);
-  Value negativeRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor);
+  // Multiply result by sign(x) to retain signs from negative inputs
+  rewriter.replaceOpWithNewOp<arith::MulFOp>(op, sign, positiveRes);
 
-  // tanh(x) = x >= 0 ? positiveRes : negativeRes
-  Value zero = createFloatConst(loc, floatType, 0.0, rewriter);
-  Value cmpRes = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
-                                                op.getOperand(), zero);
-  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmpRes, positiveRes,
-                                               negativeRes);
   return success();
 }
 
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 1274ab97d3da69..6326d3a71874b4 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -17,13 +17,8 @@ func.func @tanh(%arg: f32) -> f32 {
 // CHECK: %[[EXP1:.+]] = math.exp %[[NEGDOUBLEDX]] : f32
 // CHECK: %[[DIVIDEND1:.+]] = arith.subf %[[ONE]], %[[EXP1]] : f32
 // CHECK: %[[DIVISOR1:.+]] = arith.addf %[[EXP1]], %[[ONE]] : f32
-// CHECK: %[[RES1:.+]] = arith.divf %[[DIVIDEND1]], %[[DIVISOR1]] : f32
-// CHECK: %[[EXP2:.+]] = math.exp %[[DOUBLEDX]] : f32
-// CHECK: %[[DIVIDEND2:.+]] = arith.subf %[[EXP2]], %[[ONE]] : f32
-// CHECK: %[[DIVISOR2:.+]] = arith.addf %[[EXP2]], %[[ONE]] : f32
-// CHECK: %[[RES2:.+]] = arith.divf %[[DIVIDEND2]], %[[DIVISOR2]] : f32
-// CHECK: %[[COND:.+]] = arith.cmpf oge, %arg0, %[[ZERO]] : f32
-// CHECK: %[[RESULT:.+]] = arith.select %[[COND]], %[[RES1]], %[[RES2]] : f32
+// CHECK: %[[POSRES:.+]] = arith.divf %[[DIVIDEND1]], %[[DIVISOR1]] : f32
+// CHECK: %[[RESULT:.+]] = arith.mulf %[[SIGN]], %[[POSRES]] : f32
 // CHECK: return %[[RESULT]]
 
 // -----



More information about the Mlir-commits mailing list