[Mlir-commits] [mlir] [mlir][math] Added `math.sinh` with expansions to `math.exp` (PR #75517)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Dec 14 11:18:32 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Rob Suderman (rsuderman)

<details>
<summary>Changes</summary>

Includes end-to-end tests for the CPU runner, folders using `libm`, and lowerings to the corresponding `libm` functions. Also adds an expansion of `math.cosh` to `math.exp`.

---
Full diff: https://github.com/llvm/llvm-project/pull/75517.diff


8 Files Affected:

- (modified) mlir/include/mlir/Dialect/Math/IR/MathOps.td (+21) 
- (modified) mlir/include/mlir/Dialect/Math/Transforms/Passes.h (+2) 
- (modified) mlir/lib/Conversion/MathToLibm/MathToLibm.cpp (+1) 
- (modified) mlir/lib/Dialect/Math/IR/MathOps.cpp (+18) 
- (modified) mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp (+40) 
- (modified) mlir/test/Conversion/MathToLibm/convert-to-libm.mlir (+12) 
- (modified) mlir/test/lib/Dialect/Math/TestExpandMath.cpp (+2) 
- (modified) mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir (+94) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index b9daa91b28a9bd..211cb31d50bdcf 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -375,6 +375,27 @@ def Math_SinOp : Math_FloatUnaryOp<"sin"> {
   let hasFolder = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// SinhOp
+//===----------------------------------------------------------------------===//
+
+def Math_SinhOp : Math_FloatUnaryOp<"sinh"> {
+  let summary = "hyperbolic sine of the specified value";
+  let description = [{
+    The `sinh` operation computes the hyperbolic sine. It takes one operand
+    of floating point type (i.e., scalar, tensor or vector) and returns one
+    result of the same type. It has no standard attributes.
+
+    Example:
+
+    ```mlir
+    // Scalar hyperbolic sine value.
+    %a = math.sinh %b : f64
+    ```
+  }];
+  // Constant operands are folded with the host libm (`sinh` for 64-bit,
+  // `sinhf` for 32-bit floats); other widths are left unfolded.
+  let hasFolder = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // CountLeadingZerosOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index 817d6e1dae051f..9e6759ef229d6f 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -25,6 +25,8 @@ class RewritePatternSet;
 
 void populateExpandCtlzPattern(RewritePatternSet &patterns);
 void populateExpandTanPattern(RewritePatternSet &patterns);
+void populateExpandSinhPattern(RewritePatternSet &patterns);
+void populateExpandCoshPattern(RewritePatternSet &patterns);
 void populateExpandTanhPattern(RewritePatternSet &patterns);
 void populateExpandFmaFPattern(RewritePatternSet &patterns);
 void populateExpandFloorFPattern(RewritePatternSet &patterns);
diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
index 6e30c07de4d57e..80eec9b2df7458 100644
--- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -177,6 +177,7 @@ void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns) {
                                            "roundeven");
   populatePatternsForOp<math::RoundOp>(patterns, ctx, "roundf", "round");
   populatePatternsForOp<math::SinOp>(patterns, ctx, "sinf", "sin");
+  populatePatternsForOp<math::SinhOp>(patterns, ctx, "sinhf", "sinh");
   populatePatternsForOp<math::TanOp>(patterns, ctx, "tanf", "tan");
   populatePatternsForOp<math::TanhOp>(patterns, ctx, "tanhf", "tanh");
   populatePatternsForOp<math::TruncOp>(patterns, ctx, "truncf", "trunc");
diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index 6b8c3a53a422fa..bac46996fce73e 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -180,6 +180,24 @@ OpFoldResult math::SinOp::fold(FoldAdaptor adaptor) {
       });
 }
 
+//===----------------------------------------------------------------------===//
+// SinhOp folder
+//===----------------------------------------------------------------------===//
+
+OpFoldResult math::SinhOp::fold(FoldAdaptor adaptor) {
+  // Evaluate sinh on a constant operand with the host libm: 64-bit values go
+  // through `sinh`, 32-bit values through `sinhf`; any other width bails out.
+  return constFoldUnaryOpConditional<FloatAttr>(
+      adaptor.getOperands(),
+      [](const APFloat &value) -> std::optional<APFloat> {
+        const unsigned bits = APFloat::getSizeInBits(value.getSemantics());
+        if (bits == 64)
+          return APFloat(sinh(value.convertToDouble()));
+        if (bits == 32)
+          return APFloat(sinhf(value.convertToFloat()));
+        return std::nullopt;
+      });
+}
+
 //===----------------------------------------------------------------------===//
 // CountLeadingZerosOp folder
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 9c46a4ca10a8ec..989a3e5536ec66 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -58,6 +58,38 @@ static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b) {
   return b.create<math::CopySignOp>(fpFixedConvert, operand);
 }
 
+// sinhf(float x) -> copysign((exp(x) - exp(-x)) / 2, x)
+//
+// exp(-x) is formed as 1 / exp(x) so that only a single math.exp is emitted.
+// The trailing copysign restores the sign of zero results: for x = -0.0 the
+// difference evaluates to +0.0 (exp(-0) = 1, 1 - 1 = +0), whereas
+// sinh(-0.0) must be -0.0. Whenever the difference is non-zero it already
+// carries the sign of x (exp(x) >= 1 for x > 0 and <= 1 for x < 0), so the
+// copysign only affects zero results.
+//
+// Known limitations of this exp-based form: when exp(x) rounds to 1.0 (tiny
+// |x|) the difference cancels to zero instead of ~x, and exp(x) overflows to
+// inf slightly before sinh(x) itself would.
+static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter) {
+  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+  Value operand = op.getOperand();
+  Type opType = operand.getType();
+  Value exp = b.create<math::ExpOp>(operand);
+
+  Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
+  Value nexp = b.create<arith::DivFOp>(one, exp);
+  Value sub = b.create<arith::SubFOp>(exp, nexp);
+  Value two = createFloatConst(op->getLoc(), opType, 2.0, rewriter);
+  Value div = b.create<arith::DivFOp>(sub, two);
+  // sinh is odd: force the result to take the sign of the input so that
+  // sinh(-0.0) == -0.0.
+  Value res = b.create<math::CopySignOp>(div, operand);
+  rewriter.replaceOp(op, res);
+  return success();
+}
+
+// coshf(float x) -> (exp(x) + exp(-x)) / 2
+//
+// exp(-x) is formed as 1 / exp(x), so only one math.exp is materialized; the
+// rest of the expansion is arith.divf / arith.addf on the operand's type.
+static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter) {
+  Location loc = op->getLoc();
+  ImplicitLocOpBuilder builder(loc, rewriter);
+  Value x = op.getOperand();
+  Type type = x.getType();
+  Value expPos = builder.create<math::ExpOp>(x);
+  Value expNeg = builder.create<arith::DivFOp>(
+      createFloatConst(loc, type, 1.0, rewriter), expPos);
+  Value sum = builder.create<arith::AddFOp>(expPos, expNeg);
+  Value result = builder.create<arith::DivFOp>(
+      sum, createFloatConst(loc, type, 2.0, rewriter));
+  rewriter.replaceOp(op, result);
+  return success();
+}
+
 /// Expands tanh op into
 ///   1) 1-exp^{-2x} / 1+exp^{-2x}, if x => 0
 ///   2) exp^{2x}-1 / exp^{2x}+1  , if x < 0
@@ -445,6 +477,14 @@ void mlir::populateExpandCtlzPattern(RewritePatternSet &patterns) {
   patterns.add(convertCtlzOp);
 }
 
+/// Registers the rewrite that expands `math.sinh` in terms of `math.exp`.
+void mlir::populateExpandSinhPattern(RewritePatternSet &patterns) {
+  patterns.add(&convertSinhOp);
+}
+
+/// Registers the rewrite that expands `math.cosh` in terms of `math.exp`.
+void mlir::populateExpandCoshPattern(RewritePatternSet &patterns) {
+  patterns.add(&convertCoshOp);
+}
+
 void mlir::populateExpandTanPattern(RewritePatternSet &patterns) {
   patterns.add(convertTanOp);
 }
diff --git a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
index eb9226dee2619d..bfe084b6ca0ab6 100644
--- a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
+++ b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
@@ -141,6 +141,18 @@ func.func @cosh_caller(%float: f32, %double: f64) -> (f32, f64)  {
   return %float_result, %double_result : f32, f64
 }
 
+// math.sinh must lower to a call of libm `sinhf` for f32 and `sinh` for f64.
+// CHECK-LABEL: func @sinh_caller
+// CHECK-SAME: %[[FLOAT:.*]]: f32
+// CHECK-SAME: %[[DOUBLE:.*]]: f64
+func.func @sinh_caller(%float: f32, %double: f64) -> (f32, f64)  {
+  // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @sinhf(%[[FLOAT]]) : (f32) -> f32
+  %float_result = math.sinh %float : f32
+  // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @sinh(%[[DOUBLE]]) : (f64) -> f64
+  %double_result = math.sinh %double : f64
+  // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
+  return %float_result, %double_result : f32, f64
+}
+
 // CHECK-LABEL: func @atan2_caller
 // CHECK-SAME: %[[FLOAT:.*]]: f32
 // CHECK-SAME: %[[DOUBLE:.*]]: f64
diff --git a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
index 6dae8213dd41e3..7ce8b5a7cfe9b3 100644
--- a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
+++ b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
@@ -39,6 +39,8 @@ void TestExpandMathPass::runOnOperation() {
   populateExpandCtlzPattern(patterns);
   populateExpandExp2FPattern(patterns);
   populateExpandTanPattern(patterns);
+  populateExpandSinhPattern(patterns);
+  populateExpandCoshPattern(patterns);
   populateExpandTanhPattern(patterns);
   populateExpandFmaFPattern(patterns);
   populateExpandFloorFPattern(patterns);
diff --git a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
index 3bf474ea47f37f..541a201c94c586 100644
--- a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
@@ -591,10 +591,104 @@ func.func @roundeven() {
   return
 }
 
+// -------------------------------------------------------------------------- //
+// Sinh.
+// -------------------------------------------------------------------------- //
+
+// Helpers: apply math.sinh to a scalar, 4-lane and 8-lane f32 input and
+// print the result.
+func.func @sinh_f32(%x : f32) {
+  %y = math.sinh %x : f32
+  vector.print %y : f32
+  return
+}
+
+func.func @sinh_4xf32(%x : vector<4xf32>) {
+  %y = math.sinh %x : vector<4xf32>
+  vector.print %y : vector<4xf32>
+  return
+}
+
+func.func @sinh_8xf32(%x : vector<8xf32>) {
+  %y = math.sinh %x : vector<8xf32>
+  vector.print %y : vector<8xf32>
+  return
+}
+
+func.func @sinh() {
+  // CHECK: 1.60192
+  %f0 = arith.constant 1.25 : f32
+  call @sinh_f32(%f0) : (f32) -> ()
+
+  // CHECK: 0.252612, 0.822317, 1.1752, 1.60192
+  %v1 = arith.constant dense<[0.25, 0.75, 1.0, 1.25]> : vector<4xf32>
+  call @sinh_4xf32(%v1) : (vector<4xf32>) -> ()
+
+  // CHECK: 0.100167, 0.201336, 0.30452, 0.410752, 0.521095, 0.636654, 0.758584, 0.888106
+  %v2 = arith.constant dense<[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]> : vector<8xf32>
+  call @sinh_8xf32(%v2) : (vector<8xf32>) -> ()
+
+  // CHECK: -0.100167, -0.201336, -0.30452, -0.410752, -0.521095, -0.636654, -0.758584, -0.888106
+  %v3 = arith.constant dense<[-0.1, -0.2, -0.3, -0.4, -0.5, -0.6, -0.7, -0.8]> : vector<8xf32>
+  call @sinh_8xf32(%v3) : (vector<8xf32>) -> ()
+
+  // CHECK: nan
+  %nan = arith.constant 0x7fc00000 : f32
+  call @sinh_f32(%nan) : (f32) -> ()
+
+  // Infinite inputs drive exp to inf / 0; the result must stay +/-inf, not nan.
+  // CHECK: inf
+  %inf = arith.constant 0x7f800000 : f32
+  call @sinh_f32(%inf) : (f32) -> ()
+
+  // CHECK: -inf
+  %ninf = arith.constant 0xff800000 : f32
+  call @sinh_f32(%ninf) : (f32) -> ()
+
+  return
+}
+
+// -------------------------------------------------------------------------- //
+// Cosh.
+// -------------------------------------------------------------------------- //
+
+// Helpers: apply math.cosh to a scalar, 4-lane and 8-lane f32 input and
+// print the result.
+func.func @cosh_f32(%x : f32) {
+  %y = math.cosh %x : f32
+  vector.print %y : f32
+  return
+}
+
+func.func @cosh_4xf32(%x : vector<4xf32>) {
+  %y = math.cosh %x : vector<4xf32>
+  vector.print %y : vector<4xf32>
+  return
+}
+
+func.func @cosh_8xf32(%x : vector<8xf32>) {
+  %y = math.cosh %x : vector<8xf32>
+  vector.print %y : vector<8xf32>
+  return
+}
+
+func.func @cosh() {
+  // CHECK: 1.88842
+  %f0 = arith.constant 1.25 : f32
+  call @cosh_f32(%f0) : (f32) -> ()
+
+  // CHECK: 1.03141, 1.29468, 1.54308, 1.88842
+  %v1 = arith.constant dense<[0.25, 0.75, 1.0, 1.25]> : vector<4xf32>
+  call @cosh_4xf32(%v1) : (vector<4xf32>) -> ()
+
+  // CHECK: 1.005, 1.02007, 1.04534, 1.08107, 1.12763, 1.18547, 1.25517, 1.33743
+  %v2 = arith.constant dense<[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]> : vector<8xf32>
+  call @cosh_8xf32(%v2) : (vector<8xf32>) -> ()
+
+  // CHECK: 1.005, 1.02007, 1.04534, 1.08107, 1.12763, 1.18547, 1.25517, 1.33743
+  %v3 = arith.constant dense<[-0.1, -0.2, -0.3, -0.4, -0.5, -0.6, -0.7, -0.8]> : vector<8xf32>
+  call @cosh_8xf32(%v3) : (vector<8xf32>) -> ()
+
+  // CHECK: nan
+  %nan = arith.constant 0x7fc00000 : f32
+  call @cosh_f32(%nan) : (f32) -> ()
+
+  // cosh is even: both infinities must produce +inf, not nan.
+  // CHECK: inf
+  %inf = arith.constant 0x7f800000 : f32
+  call @cosh_f32(%inf) : (f32) -> ()
+
+  // CHECK: inf
+  %ninf = arith.constant 0xff800000 : f32
+  call @cosh_f32(%ninf) : (f32) -> ()
+
+  return
+}
+
 func.func @main() {
   call @exp2f() : () -> ()
   call @roundf() : () -> ()
   call @powf() : () -> ()
   call @roundeven() : () -> ()
+  // Hyperbolic sine and cosine.
+  call @sinh() : () -> ()
+  call @cosh() : () -> ()
   return
 }

``````````

</details>


https://github.com/llvm/llvm-project/pull/75517


More information about the Mlir-commits mailing list