[Mlir-commits] [mlir] [mlir][math] Reland 58ef9bec071383744fb703ff08df9806f25e4095 (PR #85436)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Mar 16 07:14:24 PDT 2024
https://github.com/srcarroll updated https://github.com/llvm/llvm-project/pull/85436
>From 4f5549ea6bf82c92f51af2e010e803f1b9c5fbbb Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Fri, 15 Mar 2024 12:15:33 -0500
Subject: [PATCH 1/3] Reland 58ef9bec071383744fb703ff08df9806f25e4095
---
.../Math/Transforms/ExpandPatterns.cpp | 2 +-
mlir/test/Dialect/Math/expand-math.mlir | 2 +-
.../test-expand-math-approx.mlir | 19 +++++++++++++++++++
3 files changed, 21 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 1750171b81a10e..fceafcff8490c3 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -108,7 +108,7 @@ static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
// Compute sign(x) = cast<float_type>(x < 0) * (-2) + 1
Value sign = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT,
op.getOperand(), zero);
- sign = rewriter.create<arith::SIToFPOp>(loc, floatType, sign);
+ sign = rewriter.create<arith::UIToFPOp>(loc, floatType, sign);
sign = rewriter.create<arith::MulFOp>(loc, sign, negTwo);
sign = rewriter.create<arith::AddFOp>(loc, sign, one);
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 86ee5c8620472b..6326d3a71874b4 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -9,7 +9,7 @@ func.func @tanh(%arg: f32) -> f32 {
// CHECK-DAG: %[[ONE:.+]] = arith.constant 1.000000e+00 : f32
// CHECK-DAG: %[[TWO:.+]] = arith.constant -2.000000e+00 : f32
// CHECK: %[[VAL0:.+]] = arith.cmpf olt, %arg0, %[[ZERO]] : f32
-// CHECK: %[[VAL1:.+]] = arith.sitofp %[[VAL0]] : i1 to f32
+// CHECK: %[[VAL1:.+]] = arith.uitofp %[[VAL0]] : i1 to f32
// CHECK: %[[VAL2:.+]] = arith.mulf %[[VAL1]], %[[TWO]] : f32
// CHECK: %[[SIGN:.+]] = arith.addf %[[VAL2]], %[[ONE]] : f32
// CHECK: %[[POSX:.+]] = arith.mulf %[[SIGN]], %arg0 : f32
diff --git a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
index 541a201c94c586..e2229a392bbf76 100644
--- a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
@@ -683,6 +683,24 @@ func.func @cosh() {
return
}
+// -------------------------------------------------------------------------- //
+// Tanh.
+// -------------------------------------------------------------------------- //
+
+func.func @tanh_8xf32(%a : vector<8xf32>) {
+ %r = math.tanh %a : vector<8xf32>
+ vector.print %r : vector<8xf32>
+ return
+}
+
+func.func @tanh() {
+ // CHECK: -1, -0.761594, -0.291313, 0, 0.291313, 0.761594, 1, 1
+ %v3 = arith.constant dense<[0xff800000, -1.0, -0.3, 0.0, 0.3, 1.0, 10.0, 0x7f800000]> : vector<8xf32>
+ call @tanh_8xf32(%v3) : (vector<8xf32>) -> ()
+
+ return
+}
+
func.func @main() {
call @exp2f() : () -> ()
call @roundf() : () -> ()
@@ -690,5 +708,6 @@ func.func @main() {
call @roundeven() : () -> ()
call @sinh() : () -> ()
call @cosh() : () -> ()
+ call @tanh() : () -> ()
return
}
>From 1ba29891a49a624e3058413ba0661eaf4a5cf10e Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Fri, 15 Mar 2024 18:45:26 -0500
Subject: [PATCH 2/3] fix bad merge
---
.../Math/Transforms/ExpandPatterns.cpp | 26 ++++++++-----------
mlir/test/Dialect/Math/expand-math.mlir | 9 ++-----
2 files changed, 13 insertions(+), 22 deletions(-)
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 6a4a4f0c33e301..fceafcff8490c3 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -91,14 +91,19 @@ static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter) {
}
/// Expands tanh op into
-/// 1) 1-exp^{-2x} / 1+exp^{-2x}, if x => 0
-/// 2) exp^{2x}-1 / exp^{2x}+1 , if x < 0
+/// (1 - exp^{-2x}) / (1 + exp^{-2x})
+/// To avoid overflow we exploit the reflection symmetry `tanh(-x) = -tanh(x)`.
+/// We compute a "sign" value which is -1 if the input is negative and +1
+/// otherwise. Multiplying the input by this value guarantees the product is
+/// non-negative, which in turn guarantees `exp^{-2x * sign(x)}` is in (0, 1].
+/// Expand the computation on the input `x * sign(x)`, then multiply the
+/// result by `sign(x)` to restore the sign of the real result.
static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
auto floatType = op.getOperand().getType();
Location loc = op.getLoc();
+ Value zero = createFloatConst(loc, floatType, 0.0, rewriter);
Value one = createFloatConst(loc, floatType, 1.0, rewriter);
- Value two = createFloatConst(loc, floatType, 2.0, rewriter);
- Value doubledX = rewriter.create<arith::MulFOp>(loc, op.getOperand(), two);
+ Value negTwo = createFloatConst(loc, floatType, -2.0, rewriter);
// Compute sign(x) = cast<float_type>(x < 0) * (-2) + 1
Value sign = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT,
@@ -117,18 +122,9 @@ static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
Value divisor = rewriter.create<arith::AddFOp>(loc, one, exp2x);
Value positiveRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor);
- // Case 2: tanh(x) = exp^{2x}-1 / exp^{2x}+1
- exp2x = rewriter.create<math::ExpOp>(loc, doubledX);
- dividend = rewriter.create<arith::SubFOp>(loc, exp2x, one);
- divisor = rewriter.create<arith::AddFOp>(loc, exp2x, one);
- Value negativeRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor);
+ // Multiply result by sign(x) to retain signs from negative inputs
+ rewriter.replaceOpWithNewOp<arith::MulFOp>(op, sign, positiveRes);
- // tanh(x) = x >= 0 ? positiveRes : negativeRes
- Value zero = createFloatConst(loc, floatType, 0.0, rewriter);
- Value cmpRes = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
- op.getOperand(), zero);
- rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmpRes, positiveRes,
- negativeRes);
return success();
}
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 1274ab97d3da69..6326d3a71874b4 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -17,13 +17,8 @@ func.func @tanh(%arg: f32) -> f32 {
// CHECK: %[[EXP1:.+]] = math.exp %[[NEGDOUBLEDX]] : f32
// CHECK: %[[DIVIDEND1:.+]] = arith.subf %[[ONE]], %[[EXP1]] : f32
// CHECK: %[[DIVISOR1:.+]] = arith.addf %[[EXP1]], %[[ONE]] : f32
-// CHECK: %[[RES1:.+]] = arith.divf %[[DIVIDEND1]], %[[DIVISOR1]] : f32
-// CHECK: %[[EXP2:.+]] = math.exp %[[DOUBLEDX]] : f32
-// CHECK: %[[DIVIDEND2:.+]] = arith.subf %[[EXP2]], %[[ONE]] : f32
-// CHECK: %[[DIVISOR2:.+]] = arith.addf %[[EXP2]], %[[ONE]] : f32
-// CHECK: %[[RES2:.+]] = arith.divf %[[DIVIDEND2]], %[[DIVISOR2]] : f32
-// CHECK: %[[COND:.+]] = arith.cmpf oge, %arg0, %[[ZERO]] : f32
-// CHECK: %[[RESULT:.+]] = arith.select %[[COND]], %[[RES1]], %[[RES2]] : f32
+// CHECK: %[[POSRES:.+]] = arith.divf %[[DIVIDEND1]], %[[DIVISOR1]] : f32
+// CHECK: %[[RESULT:.+]] = arith.mulf %[[SIGN]], %[[POSRES]] : f32
// CHECK: return %[[RESULT]]
// -----
>From 99c1b0d260b9bf6b03fe11856f89dfaaea84c6db Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sat, 16 Mar 2024 09:13:48 -0500
Subject: [PATCH 3/3] Change value names to more appropriate
---
mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp | 12 +++++++-----
1 file changed, 7 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index fceafcff8490c3..e1ab9c905447b7 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -106,11 +106,13 @@ static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
Value negTwo = createFloatConst(loc, floatType, -2.0, rewriter);
// Compute sign(x) = cast<float_type>(x < 0) * (-2) + 1
- Value sign = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT,
- op.getOperand(), zero);
- sign = rewriter.create<arith::UIToFPOp>(loc, floatType, sign);
- sign = rewriter.create<arith::MulFOp>(loc, sign, negTwo);
- sign = rewriter.create<arith::AddFOp>(loc, sign, one);
+ Value isNegative = rewriter.create<arith::CmpFOp>(
+ loc, arith::CmpFPredicate::OLT, op.getOperand(), zero);
+ Value isNegativeFloat =
+ rewriter.create<arith::UIToFPOp>(loc, floatType, isNegative);
+ Value isNegativeTimesNegTwo =
+ rewriter.create<arith::MulFOp>(loc, isNegativeFloat, negTwo);
+ Value sign = rewriter.create<arith::AddFOp>(loc, isNegativeTimesNegTwo, one);
// Normalize input to positive value: y = sign(x) * x
Value positiveX = rewriter.create<arith::MulFOp>(loc, sign, op.getOperand());
More information about the Mlir-commits
mailing list