[Mlir-commits] [mlir] [mlir][math] Add Polynomial Approximation for few ops (PR #90718)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed May 1 03:17:59 PDT 2024
https://github.com/jinchen62 created https://github.com/llvm/llvm-project/pull/90718
- acosh
- asinh
- atanh
- cosh
- sinh
>From 4b95d91428e01e11f4288daff92d75d8c4871dc9 Mon Sep 17 00:00:00 2001
From: jinchen62 <jinchenye62 at gmail.com>
Date: Tue, 30 Apr 2024 17:14:03 -0700
Subject: [PATCH] [mlir][math] Add Polynomial Approximation for acosh, asinh,
atanh, cosh, sinh ops
---
.../Transforms/PolynomialApproximation.cpp | 165 +++++++++++++-
.../math-polynomial-approx.mlir | 202 ++++++++++++++++++
2 files changed, 358 insertions(+), 9 deletions(-)
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 428c1c37c4e8b5..36bb8c6245e499 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -615,10 +615,47 @@ TanhApproximation::matchAndRewrite(math::TanhOp op,
return success();
}
-#define LN2_VALUE \
- 0.693147180559945309417232121458176568075500134360255254120680009493393621L
-#define LOG2E_VALUE \
- 1.442695040888963407359924681001892137426645954152985934135449406931109219L
+//----------------------------------------------------------------------------//
+// AtanhOp approximation.
+//----------------------------------------------------------------------------//
+
+namespace {
+/// Rewrite pattern anchored on `math::AtanhOp`. The rewrite logic lives in the
+/// out-of-line definition of `AtanhApproximation::matchAndRewrite`.
+struct AtanhApproximation : OpRewritePattern<math::AtanhOp> {
+  // Reuse the constructors of the base rewrite pattern.
+  using OpRewritePattern<math::AtanhOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(math::AtanhOp op,
+                                PatternRewriter &rewriter) const final;
+};
+} // namespace
+
+/// Replaces `math.atanh` on f32 (or f32-element) operands with the closed form
+///   atanh(x) = log((1 + x) / (1 - x)) / 2
+/// built from arith ops and a `math.log`. Any other element type is reported
+/// as a match failure and the op is left untouched.
+LogicalResult
+AtanhApproximation::matchAndRewrite(math::AtanhOp op,
+                                    PatternRewriter &rewriter) const {
+  Value input = op.getOperand();
+  if (!getElementTypeOrSelf(input).isF32())
+    return rewriter.notifyMatchFailure(op, "unsupported operand type");
+
+  VectorShape shape = vectorShape(input);
+  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+
+  // `broadcast` / `f32Cst` are helpers defined elsewhere in this file;
+  // presumably they materialize an f32 constant expanded to `shape`.
+  Value one = broadcast(builder, f32Cst(builder, 1.0), shape);
+  Value onePlusX = builder.create<arith::AddFOp>(input, one);
+  // (1 - x) is emitted as (-x) + 1.
+  Value negX = builder.create<arith::NegFOp>(input);
+  Value oneMinusX = builder.create<arith::AddFOp>(negX, one);
+  Value ratio = builder.create<arith::DivFOp>(onePlusX, oneMinusX);
+  Value logRatio = builder.create<math::LogOp>(ratio);
+  Value two = broadcast(builder, f32Cst(builder, 2.0), shape);
+  Value halved = builder.create<arith::DivFOp>(logRatio, two);
+
+  rewriter.replaceOp(op, halved);
+  return success();
+}
//----------------------------------------------------------------------------//
// LogOp and Log2Op approximation.
@@ -635,6 +672,11 @@ struct LogApproximationBase : public OpRewritePattern<Op> {
};
} // namespace
+// Natural logarithm of 2 (ln 2), spelled as a long double literal.
+#define LN2_VALUE \
+ 0.693147180559945309417232121458176568075500134360255254120680009493393621L
+// Base-2 logarithm of e (log2(e) = 1 / ln 2), spelled as a long double literal.
+#define LOG2E_VALUE \
+ 1.442695040888963407359924681001892137426645954152985934135449406931109219L
+
// This approximation comes from Julien Pommier's SSE math library.
// Link: http://gruntthepeon.free.fr/ssemath
template <typename Op>
@@ -1316,6 +1358,106 @@ LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
return success();
}
+//----------------------------------------------------------------------------//
+// SinhOp and CoshOp approximation.
+//----------------------------------------------------------------------------//
+
+namespace {
+
+/// Rewrite pattern shared by the hyperbolic sine and cosine ops. `OpTy` is the
+/// op being matched and `isSine` picks the sine flavour of the rewrite; the
+/// rewrite itself is defined out of line.
+template <bool isSine, typename OpTy>
+struct SinhAndCoshApproximation : OpRewritePattern<OpTy> {
+  // Reuse the constructors of the base rewrite pattern.
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const final;
+};
+} // namespace
+
+/// Expands `math.sinh` (isSine == true) or `math.cosh` (isSine == false) on
+/// f32 or f32-element operands into exponentials:
+///   sinh(x) = (exp(x) - exp(-x)) / 2
+///   cosh(x) = (exp(x) + exp(-x)) / 2
+/// Operands with any other element type are reported as a match failure.
+template <bool isSine, typename OpTy>
+LogicalResult SinhAndCoshApproximation<isSine, OpTy>::matchAndRewrite(
+    OpTy op, PatternRewriter &rewriter) const {
+  static_assert(
+      llvm::is_one_of<OpTy, math::SinhOp, math::CoshOp>::value,
+      "SinhAndCoshApproximation pattern expects math::SinhOp or math::CoshOp");
+
+  if (!getElementTypeOrSelf(op.getOperand()).isF32())
+    return rewriter.notifyMatchFailure(op, "unsupported operand type");
+
+  auto operand = op.getOperand();
+  VectorShape shape = vectorShape(operand);
+
+  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+  // `broadcast` is a helper defined elsewhere in this file; presumably it
+  // expands a scalar value to the operand's vector shape.
+  auto bcast = [&](Value value) -> Value {
+    return broadcast(builder, value, shape);
+  };
+
+  Value expPos = builder.create<math::ExpOp>(operand);
+  Value negated = builder.create<arith::NegFOp>(operand);
+  Value expNeg = builder.create<math::ExpOp>(negated);
+
+  // The two functions differ only in the sign between the exponentials.
+  // NOTE(review): for sinh the subtraction cancels for inputs near zero, so
+  // relative accuracy there may be limited -- confirm this is acceptable.
+  Value combined;
+  if (isSine)
+    combined = builder.create<arith::SubFOp>(expPos, expNeg);
+  else
+    combined = builder.create<arith::AddFOp>(expPos, expNeg);
+
+  Value cstTwo = bcast(f32Cst(builder, 2.0));
+  Value res = builder.create<arith::DivFOp>(combined, cstTwo);
+  rewriter.replaceOp(op, res);
+
+  return success();
+}
+
+//----------------------------------------------------------------------------//
+// AsinhOp and AcoshOp approximation.
+//----------------------------------------------------------------------------//
+
+namespace {
+
+/// Rewrite pattern shared by the inverse hyperbolic sine and cosine ops.
+/// `OpTy` is the op being matched and `isSine` picks the sine flavour of the
+/// rewrite; the rewrite itself is defined out of line.
+template <bool isSine, typename OpTy>
+struct AsinhAndAcoshApproximation : OpRewritePattern<OpTy> {
+  // Reuse the constructors of the base rewrite pattern.
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const final;
+};
+} // namespace
+
+/// Expands `math.asinh` (isSine == true) or `math.acosh` (isSine == false) on
+/// f32 or f32-element operands into a logarithm:
+///   asinh(x) = log(x + sqrt(x * x + 1))
+///   acosh(x) = log(x + sqrt(x * x - 1))
+/// Operands with any other element type are reported as a match failure.
+template <bool isSine, typename OpTy>
+LogicalResult AsinhAndAcoshApproximation<isSine, OpTy>::matchAndRewrite(
+    OpTy op, PatternRewriter &rewriter) const {
+  static_assert(
+      llvm::is_one_of<OpTy, math::AsinhOp, math::AcoshOp>::value,
+      "AsinhAndAcoshApproximation pattern expects math::AsinhOp or "
+      "math::AcoshOp");
+
+  if (!getElementTypeOrSelf(op.getOperand()).isF32())
+    return rewriter.notifyMatchFailure(op, "unsupported operand type");
+
+  auto operand = op.getOperand();
+  VectorShape shape = vectorShape(operand);
+
+  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+  // `broadcast` is a helper defined elsewhere in this file; presumably it
+  // expands a scalar value to the operand's vector shape.
+  auto bcast = [&](Value value) -> Value {
+    return broadcast(builder, value, shape);
+  };
+
+  // NOTE(review): x * x overflows f32 for large |x|, and for asinh the sum
+  // x + sqrt(x * x + 1) cancels for large negative x -- confirm the accuracy
+  // of this direct formula is acceptable over the intended input range.
+  Value squared = builder.create<arith::MulFOp>(operand, operand);
+  Value cstOne = bcast(f32Cst(builder, 1.0));
+
+  // The two functions differ only in the sign of the constant under the root.
+  Value radicand;
+  if (isSine)
+    radicand = builder.create<arith::AddFOp>(squared, cstOne);
+  else
+    radicand = builder.create<arith::SubFOp>(squared, cstOne);
+
+  Value root = builder.create<math::SqrtOp>(radicand);
+  Value sum = builder.create<arith::AddFOp>(operand, root);
+  Value res = builder.create<math::LogOp>(sum);
+  rewriter.replaceOp(op, res);
+
+  return success();
+}
+
//----------------------------------------------------------------------------//
// Cbrt approximation.
//----------------------------------------------------------------------------//
@@ -1505,11 +1647,16 @@ void mlir::populateMathPolynomialApproximationPatterns(
ReuseF32Expansion<math::SinOp>, ReuseF32Expansion<math::CosOp>>(
patterns.getContext());
- patterns.add<AtanApproximation, Atan2Approximation, TanhApproximation,
- LogApproximation, Log2Approximation, Log1pApproximation,
- ErfPolynomialApproximation, ExpApproximation, ExpM1Approximation,
- CbrtApproximation, SinAndCosApproximation<true, math::SinOp>,
- SinAndCosApproximation<false, math::CosOp>>(
+ patterns.add<AtanApproximation, Atan2Approximation, AtanhApproximation,
+ TanhApproximation, LogApproximation, Log2Approximation,
+ Log1pApproximation, ErfPolynomialApproximation, ExpApproximation,
+ ExpM1Approximation, CbrtApproximation,
+ SinAndCosApproximation<true, math::SinOp>,
+ SinAndCosApproximation<false, math::CosOp>,
+ SinhAndCoshApproximation<true, math::SinhOp>,
+ SinhAndCoshApproximation<false, math::CoshOp>,
+ AsinhAndAcoshApproximation<true, math::AsinhOp>,
+ AsinhAndAcoshApproximation<false, math::AcoshOp>>(
patterns.getContext());
if (options.enableAvx2) {
patterns.add<RsqrtApproximation, ReuseF32Expansion<math::RsqrtOp>>(
diff --git a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
index d3b19be9ecaf8f..9b73cdf57f5a35 100644
--- a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
@@ -568,6 +568,203 @@ func.func @atan2() {
}
+// -------------------------------------------------------------------------- //
+// sinh
+// -------------------------------------------------------------------------- //
+
+// Computes math.sinh of a scalar f32 and prints the result.
+func.func @sinh_f32(%x : f32) {
+  %y = math.sinh %x : f32
+  vector.print %y : f32
+  return
+}
+
+// Computes math.sinh element-wise on a vector<3xf32> and prints the result.
+func.func @sinh_3xf32(%x : vector<3xf32>) {
+  %y = math.sinh %x : vector<3xf32>
+  vector.print %y : vector<3xf32>
+  return
+}
+
+// Drives the sinh wrappers; printed values are matched in call order.
+func.func @sinh() {
+  // Inputs for the scalar and vector wrappers.
+  %x_zero = arith.constant 0.0 : f32
+  %x_half = arith.constant 0.5 : f32
+  %x_neg_one = arith.constant -1.0 : f32
+  %x_three = arith.constant 3.0 : f32
+  %x_vec = arith.constant dense<[0.25, 0.875, 2.0]> : vector<3xf32>
+
+  // CHECK: 0
+  call @sinh_f32(%x_zero) : (f32) -> ()
+
+  // CHECK: 0.521095
+  call @sinh_f32(%x_half) : (f32) -> ()
+
+  // CHECK: -1.1752
+  call @sinh_f32(%x_neg_one) : (f32) -> ()
+
+  // CHECK: 10.0179
+  call @sinh_f32(%x_three) : (f32) -> ()
+
+  // CHECK: 0.252612, 0.991007, 3.62686
+  call @sinh_3xf32(%x_vec) : (vector<3xf32>) -> ()
+
+  return
+}
+
+
+// -------------------------------------------------------------------------- //
+// cosh
+// -------------------------------------------------------------------------- //
+
+// Computes math.cosh of a scalar f32 and prints the result.
+func.func @cosh_f32(%x : f32) {
+  %y = math.cosh %x : f32
+  vector.print %y : f32
+  return
+}
+
+// Computes math.cosh element-wise on a vector<3xf32> and prints the result.
+func.func @cosh_3xf32(%x : vector<3xf32>) {
+  %y = math.cosh %x : vector<3xf32>
+  vector.print %y : vector<3xf32>
+  return
+}
+
+// Drives the cosh wrappers; printed values are matched in call order.
+func.func @cosh() {
+  // Inputs for the scalar and vector wrappers.
+  %x_zero = arith.constant 0.0 : f32
+  %x_one = arith.constant 1.0 : f32
+  %x_neg_one = arith.constant -1.0 : f32
+  %x_three = arith.constant 3.0 : f32
+  %x_vec = arith.constant dense<[0.25, 0.875, 2.0]> : vector<3xf32>
+
+  // CHECK: 1
+  call @cosh_f32(%x_zero) : (f32) -> ()
+
+  // CHECK: 1.54308
+  call @cosh_f32(%x_one) : (f32) -> ()
+
+  // CHECK: 1.54308
+  call @cosh_f32(%x_neg_one) : (f32) -> ()
+
+  // CHECK: 10.0677
+  call @cosh_f32(%x_three) : (f32) -> ()
+
+  // CHECK: 1.03141, 1.40787, 3.7622
+  call @cosh_3xf32(%x_vec) : (vector<3xf32>) -> ()
+
+  return
+}
+
+
+// -------------------------------------------------------------------------- //
+// asinh
+// -------------------------------------------------------------------------- //
+
+// Computes math.asinh of a scalar f32 and prints the result.
+func.func @asinh_f32(%x : f32) {
+  %y = math.asinh %x : f32
+  vector.print %y : f32
+  return
+}
+
+// Computes math.asinh element-wise on a vector<3xf32> and prints the result.
+func.func @asinh_3xf32(%x : vector<3xf32>) {
+  %y = math.asinh %x : vector<3xf32>
+  vector.print %y : vector<3xf32>
+  return
+}
+
+// Drives the asinh wrappers; printed values are matched in call order.
+func.func @asinh() {
+  // Inputs for the scalar and vector wrappers.
+  %x_zero = arith.constant 0.0 : f32
+  %x_one = arith.constant 1.0 : f32
+  %x_neg_one = arith.constant -1.0 : f32
+  %x_three = arith.constant 3.0 : f32
+  %x_vec = arith.constant dense<[0.25, 0.875, 2.0]> : vector<3xf32>
+
+  // CHECK: 0
+  call @asinh_f32(%x_zero) : (f32) -> ()
+
+  // CHECK: 0.881374
+  call @asinh_f32(%x_one) : (f32) -> ()
+
+  // CHECK: -0.881374
+  call @asinh_f32(%x_neg_one) : (f32) -> ()
+
+  // CHECK: 1.81845
+  call @asinh_f32(%x_three) : (f32) -> ()
+
+  // CHECK: 0.247466, 0.790169, 1.44364
+  call @asinh_3xf32(%x_vec) : (vector<3xf32>) -> ()
+
+  return
+}
+
+
+// -------------------------------------------------------------------------- //
+// acosh
+// -------------------------------------------------------------------------- //
+
+// Computes math.acosh of a scalar f32 and prints the result.
+func.func @acosh_f32(%x : f32) {
+  %y = math.acosh %x : f32
+  vector.print %y : f32
+  return
+}
+
+// Computes math.acosh element-wise on a vector<3xf32> and prints the result.
+func.func @acosh_3xf32(%x : vector<3xf32>) {
+  %y = math.acosh %x : vector<3xf32>
+  vector.print %y : vector<3xf32>
+  return
+}
+
+// Drives the acosh wrappers; printed values are matched in call order.
+// All inputs used here are >= 1.
+func.func @acosh() {
+  // Inputs for the scalar and vector wrappers.
+  %x_one = arith.constant 1.0 : f32
+  %x_two = arith.constant 2.0 : f32
+  %x_ten = arith.constant 10.0 : f32
+  %x_vec = arith.constant dense<[1.5, 3.0, 6.0]> : vector<3xf32>
+
+  // CHECK: 0
+  call @acosh_f32(%x_one) : (f32) -> ()
+
+  // CHECK: 1.31696
+  call @acosh_f32(%x_two) : (f32) -> ()
+
+  // CHECK: 2.99322
+  call @acosh_f32(%x_ten) : (f32) -> ()
+
+  // CHECK: 0.962424, 1.76275, 2.47789
+  call @acosh_3xf32(%x_vec) : (vector<3xf32>) -> ()
+
+  return
+}
+
+
+// -------------------------------------------------------------------------- //
+// atanh
+// -------------------------------------------------------------------------- //
+
+// Computes math.atanh of a scalar f32 and prints the result.
+func.func @atanh_f32(%x : f32) {
+  %y = math.atanh %x : f32
+  vector.print %y : f32
+  return
+}
+
+// Computes math.atanh element-wise on a vector<3xf32> and prints the result.
+func.func @atanh_3xf32(%x : vector<3xf32>) {
+  %y = math.atanh %x : vector<3xf32>
+  vector.print %y : vector<3xf32>
+  return
+}
+
+// Drives the atanh wrappers; printed values are matched in call order.
+// All inputs used here lie strictly inside (-1, 1).
+func.func @atanh() {
+  // Inputs for the scalar and vector wrappers.
+  %x_zero = arith.constant 0.0 : f32
+  %x_half = arith.constant 0.5 : f32
+  %x_neg_half = arith.constant -0.5 : f32
+  %x_vec = arith.constant dense<[0.25, 0.375, 0.995]> : vector<3xf32>
+
+  // CHECK: 0
+  call @atanh_f32(%x_zero) : (f32) -> ()
+
+  // CHECK: 0.549306
+  call @atanh_f32(%x_half) : (f32) -> ()
+
+  // CHECK: -0.549306
+  call @atanh_f32(%x_neg_half) : (f32) -> ()
+
+  // CHECK: 0.255413, 0.394229, 2.99448
+  call @atanh_3xf32(%x_vec) : (vector<3xf32>) -> ()
+
+  return
+}
+
+
// -------------------------------------------------------------------------- //
// Cbrt.
// -------------------------------------------------------------------------- //
@@ -696,6 +893,11 @@ func.func @main() {
call @cos(): () -> ()
call @atan() : () -> ()
call @atan2() : () -> ()
+ call @sinh() : () -> ()
+ call @cosh() : () -> ()
+ call @asinh() : () -> ()
+ call @acosh() : () -> ()
+ call @atanh() : () -> ()
call @cbrt() : () -> ()
call @floorf() : () -> ()
call @ceilf() : () -> ()
More information about the Mlir-commits
mailing list