[Mlir-commits] [mlir] aa165ed - [mlir][math] Added `math.sinh` with expansions to `math.exp` (#75517)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Dec 15 11:35:44 PST 2023
Author: Rob Suderman
Date: 2023-12-15T11:35:40-08:00
New Revision: aa165edca8545b212de084d5b18c3d30347f774a
URL: https://github.com/llvm/llvm-project/commit/aa165edca8545b212de084d5b18c3d30347f774a
DIFF: https://github.com/llvm/llvm-project/commit/aa165edca8545b212de084d5b18c3d30347f774a.diff
LOG: [mlir][math] Added `math.sinh` with expansions to `math.exp` (#75517)
Includes end-to-end tests for the CPU runner, folders using `libm`, and
lowerings to the corresponding `libm` functions. Also adds an expansion of
`math.cosh` to `math.exp`.
Added:
Modified:
mlir/include/mlir/Dialect/Math/IR/MathOps.td
mlir/include/mlir/Dialect/Math/Transforms/Passes.h
mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
mlir/lib/Dialect/Math/IR/MathOps.cpp
mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
mlir/test/lib/Dialect/Math/TestExpandMath.cpp
mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index b9daa91b28a9bd..211cb31d50bdcf 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -375,6 +375,27 @@ def Math_SinOp : Math_FloatUnaryOp<"sin"> {
let hasFolder = 1;
}
+//===----------------------------------------------------------------------===//
+// SinhOp
+//===----------------------------------------------------------------------===//
+
+// Elementwise hyperbolic sine. `hasFolder` enables constant folding of f32
+// and f64 operands through the C library `sinhf`/`sinh`; MathToLibm lowers
+// the op to calls to those same functions.
+def Math_SinhOp : Math_FloatUnaryOp<"sinh"> {
+ let summary = "hyperbolic sine of the specified value";
+ let description = [{
+ The `sinh` operation computes the hyperbolic sine. It takes one operand
+ of floating point type (i.e., scalar, tensor or vector) and returns one
+ result of the same type. It has no standard attributes.
+
+ Example:
+
+ ```mlir
+ // Scalar hyperbolic sine value.
+ %a = math.sinh %b : f64
+ ```
+ }];
+ let hasFolder = 1;
+}
+
//===----------------------------------------------------------------------===//
// CountLeadingZerosOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index 817d6e1dae051f..9e6759ef229d6f 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -25,6 +25,8 @@ class RewritePatternSet;
void populateExpandCtlzPattern(RewritePatternSet &patterns);
void populateExpandTanPattern(RewritePatternSet &patterns);
+/// Adds a pattern that rewrites `math.sinh` in terms of `math.exp` and
+/// arithmetic ops.
+void populateExpandSinhPattern(RewritePatternSet &patterns);
+/// Adds a pattern that rewrites `math.cosh` in terms of `math.exp` and
+/// arithmetic ops.
+void populateExpandCoshPattern(RewritePatternSet &patterns);
void populateExpandTanhPattern(RewritePatternSet &patterns);
void populateExpandFmaFPattern(RewritePatternSet &patterns);
void populateExpandFloorFPattern(RewritePatternSet &patterns);
diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
index 6e30c07de4d57e..80eec9b2df7458 100644
--- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -177,6 +177,7 @@ void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns) {
"roundeven");
populatePatternsForOp<math::RoundOp>(patterns, ctx, "roundf", "round");
populatePatternsForOp<math::SinOp>(patterns, ctx, "sinf", "sin");
+ populatePatternsForOp<math::SinhOp>(patterns, ctx, "sinhf", "sinh");
populatePatternsForOp<math::TanOp>(patterns, ctx, "tanf", "tan");
populatePatternsForOp<math::TanhOp>(patterns, ctx, "tanhf", "tanh");
populatePatternsForOp<math::TruncOp>(patterns, ctx, "truncf", "trunc");
diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index 6b8c3a53a422fa..bac46996fce73e 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -180,6 +180,24 @@ OpFoldResult math::SinOp::fold(FoldAdaptor adaptor) {
});
}
+//===----------------------------------------------------------------------===//
+// SinhOp folder
+//===----------------------------------------------------------------------===//
+
+/// Constant-folds `math.sinh` when the operand is a constant of a 32- or
+/// 64-bit float type, evaluating it with the C library `sinhf`/`sinh`.
+/// Any other bit width yields no value, so the op is left unfolded.
+OpFoldResult math::SinhOp::fold(FoldAdaptor adaptor) {
+  return constFoldUnaryOpConditional<FloatAttr>(
+      adaptor.getOperands(),
+      [](const APFloat &value) -> std::optional<APFloat> {
+        const unsigned bitWidth = value.getSizeInBits(value.getSemantics());
+        if (bitWidth == 32)
+          return APFloat(sinhf(value.convertToFloat()));
+        if (bitWidth == 64)
+          return APFloat(sinh(value.convertToDouble()));
+        // No host libm entry point is used for other widths.
+        return std::nullopt;
+      });
+}
+
//===----------------------------------------------------------------------===//
// CountLeadingZerosOp folder
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 9c46a4ca10a8ec..989a3e5536ec66 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -58,6 +58,38 @@ static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b) {
return b.create<math::CopySignOp>(fpFixedConvert, operand);
}
+// sinh(x) -> copysign((exp(|x|) - 1 / exp(|x|)) * 0.5, x)
+//
+// Mathematically this is (exp(x) - exp(-x)) / 2. Because sinh is odd, the
+// magnitude is evaluated on |x| and the input's sign is restored afterwards:
+//  - exp(|x|) is at least 1 for non-NaN x, so the reciprocal that stands in
+//    for a second exp never divides by a subnormal (precision-losing) value;
+//  - sinh(-x) == -sinh(x) holds bit-for-bit, and sinh(-0.0) yields -0.0
+//    (the unsigned formula alone computes 1 - 1 = +0.0 there).
+// Multiplying by 0.5 gives exactly the same result as dividing by 2 in binary
+// floating point, but avoids a division.
+// NOTE(review): for |x| near 0 the subtraction still cancels, so relative
+// accuracy degrades there; an expm1-based form would be needed to fix that.
+static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter) {
+  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+  Value operand = op.getOperand();
+  Type opType = operand.getType();
+
+  Value absX = b.create<math::AbsFOp>(operand);
+  Value exp = b.create<math::ExpOp>(absX);
+  Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
+  Value nexp = b.create<arith::DivFOp>(one, exp); // exp(-|x|)
+  Value sub = b.create<arith::SubFOp>(exp, nexp);
+  Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
+  Value magnitude = b.create<arith::MulFOp>(sub, half);
+  // Restore the input's sign; this also maps -0.0 to -0.0 and keeps NaN.
+  Value res = b.create<math::CopySignOp>(magnitude, operand);
+  rewriter.replaceOp(op, res);
+  return success();
+}
+
+// cosh(x) -> (exp(|x|) + 1 / exp(|x|)) * 0.5
+//
+// Mathematically this is (exp(x) + exp(-x)) / 2. Because cosh is even, it is
+// evaluated on |x|. exp(|x|) is at least 1 for non-NaN x, so the reciprocal
+// that stands in for a second exp never divides by a subnormal
+// (precision-losing) value, and cosh(-x) == cosh(x) holds bit-for-bit.
+// Multiplying by 0.5 gives exactly the same result as dividing by 2 in binary
+// floating point, but avoids a division.
+static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter) {
+  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+  Value operand = op.getOperand();
+  Type opType = operand.getType();
+
+  Value absX = b.create<math::AbsFOp>(operand);
+  Value exp = b.create<math::ExpOp>(absX);
+  Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
+  Value nexp = b.create<arith::DivFOp>(one, exp); // exp(-|x|)
+  Value add = b.create<arith::AddFOp>(exp, nexp);
+  Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
+  Value res = b.create<arith::MulFOp>(add, half);
+  rewriter.replaceOp(op, res);
+  return success();
+}
+
/// Expands tanh op into
/// 1) 1-exp^{-2x} / 1+exp^{-2x}, if x => 0
/// 2) exp^{2x}-1 / exp^{2x}+1 , if x < 0
@@ -445,6 +477,14 @@ void mlir::populateExpandCtlzPattern(RewritePatternSet &patterns) {
patterns.add(convertCtlzOp);
}
+/// Registers `convertSinhOp`, which rewrites `math.sinh` using `math.exp`.
+void mlir::populateExpandSinhPattern(RewritePatternSet &patterns) {
+  patterns.add(&convertSinhOp);
+}
+
+/// Registers `convertCoshOp`, which rewrites `math.cosh` using `math.exp`.
+void mlir::populateExpandCoshPattern(RewritePatternSet &patterns) {
+  patterns.add(&convertCoshOp);
+}
+
+/// Registers the `convertTanOp` rewrite (defined elsewhere in this file).
void mlir::populateExpandTanPattern(RewritePatternSet &patterns) {
patterns.add(convertTanOp);
}
diff --git a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
index eb9226dee2619d..bfe084b6ca0ab6 100644
--- a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
+++ b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
@@ -141,6 +141,18 @@ func.func @cosh_caller(%float: f32, %double: f64) -> (f32, f64) {
return %float_result, %double_result : f32, f64
}
+// math.sinh on f32 and f64 must lower to calls to libm sinhf and sinh.
+// CHECK-LABEL: func @sinh_caller
+// CHECK-SAME: %[[FLOAT:.*]]: f32
+// CHECK-SAME: %[[DOUBLE:.*]]: f64
+func.func @sinh_caller(%float: f32, %double: f64) -> (f32, f64) {
+ // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @sinhf(%[[FLOAT]]) : (f32) -> f32
+ %float_result = math.sinh %float : f32
+ // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @sinh(%[[DOUBLE]]) : (f64) -> f64
+ %double_result = math.sinh %double : f64
+ // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
+ return %float_result, %double_result : f32, f64
+}
+
// CHECK-LABEL: func @atan2_caller
// CHECK-SAME: %[[FLOAT:.*]]: f32
// CHECK-SAME: %[[DOUBLE:.*]]: f64
diff --git a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
index 6dae8213dd41e3..7ce8b5a7cfe9b3 100644
--- a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
+++ b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
@@ -39,6 +39,8 @@ void TestExpandMathPass::runOnOperation() {
populateExpandCtlzPattern(patterns);
populateExpandExp2FPattern(patterns);
populateExpandTanPattern(patterns);
+ populateExpandSinhPattern(patterns);
+ populateExpandCoshPattern(patterns);
populateExpandTanhPattern(patterns);
populateExpandFmaFPattern(patterns);
populateExpandFloorFPattern(patterns);
diff --git a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
index 3bf474ea47f37f..541a201c94c586 100644
--- a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
@@ -591,10 +591,104 @@ func.func @roundeven() {
return
}
+// -------------------------------------------------------------------------- //
+// Sinh.
+// -------------------------------------------------------------------------- //
+
+// Helpers that apply math.sinh to their argument and print the result.
+func.func @sinh_f32(%x : f32) {
+  %y = math.sinh %x : f32
+  vector.print %y : f32
+  return
+}
+
+func.func @sinh_4xf32(%x : vector<4xf32>) {
+  %y = math.sinh %x : vector<4xf32>
+  vector.print %y : vector<4xf32>
+  return
+}
+
+func.func @sinh_8xf32(%x : vector<8xf32>) {
+  %y = math.sinh %x : vector<8xf32>
+  vector.print %y : vector<8xf32>
+  return
+}
+
+// Evaluates math.sinh on a scalar, on 4- and 8-wide vectors, and on negated
+// inputs, whose expected output is the exact negation because sinh is odd.
+// A NaN input must propagate.
+func.func @sinh() {
+ // CHECK: 1.60192
+ %f0 = arith.constant 1.25 : f32
+ call @sinh_f32(%f0) : (f32) -> ()
+
+ // CHECK: 0.252612, 0.822317, 1.1752, 1.60192
+ %v1 = arith.constant dense<[0.25, 0.75, 1.0, 1.25]> : vector<4xf32>
+ call @sinh_4xf32(%v1) : (vector<4xf32>) -> ()
+
+ // CHECK: 0.100167, 0.201336, 0.30452, 0.410752, 0.521095, 0.636654, 0.758584, 0.888106
+ %v2 = arith.constant dense<[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]> : vector<8xf32>
+ call @sinh_8xf32(%v2) : (vector<8xf32>) -> ()
+
+ // CHECK: -0.100167, -0.201336, -0.30452, -0.410752, -0.521095, -0.636654, -0.758584, -0.888106
+ %v3 = arith.constant dense<[-0.1, -0.2, -0.3, -0.4, -0.5, -0.6, -0.7, -0.8]> : vector<8xf32>
+ call @sinh_8xf32(%v3) : (vector<8xf32>) -> ()
+
+ // 0x7fc00000 is the f32 quiet-NaN bit pattern.
+ // CHECK: nan
+ %nan = arith.constant 0x7fc00000 : f32
+ call @sinh_f32(%nan) : (f32) -> ()
+
+ return
+}
+
+// -------------------------------------------------------------------------- //
+// Cosh.
+// -------------------------------------------------------------------------- //
+
+// Helpers that apply math.cosh to their argument and print the result.
+func.func @cosh_f32(%x : f32) {
+  %y = math.cosh %x : f32
+  vector.print %y : f32
+  return
+}
+
+func.func @cosh_4xf32(%x : vector<4xf32>) {
+  %y = math.cosh %x : vector<4xf32>
+  vector.print %y : vector<4xf32>
+  return
+}
+
+func.func @cosh_8xf32(%x : vector<8xf32>) {
+  %y = math.cosh %x : vector<8xf32>
+  vector.print %y : vector<8xf32>
+  return
+}
+
+// Evaluates math.cosh on a scalar, on 4- and 8-wide vectors, and on negated
+// inputs, whose expected output is identical because cosh is even.
+// A NaN input must propagate.
+func.func @cosh() {
+ // CHECK: 1.88842
+ %f0 = arith.constant 1.25 : f32
+ call @cosh_f32(%f0) : (f32) -> ()
+
+ // CHECK: 1.03141, 1.29468, 1.54308, 1.88842
+ %v1 = arith.constant dense<[0.25, 0.75, 1.0, 1.25]> : vector<4xf32>
+ call @cosh_4xf32(%v1) : (vector<4xf32>) -> ()
+
+ // CHECK: 1.005, 1.02007, 1.04534, 1.08107, 1.12763, 1.18547, 1.25517, 1.33743
+ %v2 = arith.constant dense<[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]> : vector<8xf32>
+ call @cosh_8xf32(%v2) : (vector<8xf32>) -> ()
+
+ // CHECK: 1.005, 1.02007, 1.04534, 1.08107, 1.12763, 1.18547, 1.25517, 1.33743
+ %v3 = arith.constant dense<[-0.1, -0.2, -0.3, -0.4, -0.5, -0.6, -0.7, -0.8]> : vector<8xf32>
+ call @cosh_8xf32(%v3) : (vector<8xf32>) -> ()
+
+ // 0x7fc00000 is the f32 quiet-NaN bit pattern.
+ // CHECK: nan
+ %nan = arith.constant 0x7fc00000 : f32
+ call @cosh_f32(%nan) : (f32) -> ()
+
+ return
+}
+
func.func @main() {
call @exp2f() : () -> ()
call @roundf() : () -> ()
call @powf() : () -> ()
call @roundeven() : () -> ()
+ // Hyperbolic sine/cosine, tested alongside their math.exp-based expansions.
+ call @sinh() : () -> ()
+ call @cosh() : () -> ()
return
}
More information about the Mlir-commits
mailing list