[Mlir-commits] [mlir] [mlir][math] Add Polynomial Approximation for a few ops (PR #90718)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed May 1 03:21:55 PDT 2024


https://github.com/jinchen62 updated https://github.com/llvm/llvm-project/pull/90718

>From 50bf853008d87394d6a9053fabd28f5b0055031a Mon Sep 17 00:00:00 2001
From: jinchen62 <jinchenye62 at gmail.com>
Date: Tue, 30 Apr 2024 17:14:03 -0700
Subject: [PATCH] [mlir][math] Add Polynomial Approximation for acosh, asinh,
 atanh, cosh, sinh ops

---
 .../Transforms/PolynomialApproximation.cpp    | 165 +++++++++++++-
 .../math-polynomial-approx.mlir               | 202 ++++++++++++++++++
 2 files changed, 358 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 428c1c37c4e8b5..acccfb64b945df 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -615,10 +615,47 @@ TanhApproximation::matchAndRewrite(math::TanhOp op,
   return success();
 }
 
-#define LN2_VALUE                                                              \
-  0.693147180559945309417232121458176568075500134360255254120680009493393621L
-#define LOG2E_VALUE                                                            \
-  1.442695040888963407359924681001892137426645954152985934135449406931109219L
+//----------------------------------------------------------------------------//
+// AtanhOp approximation.
+//----------------------------------------------------------------------------//
+
+namespace {
+/// Expands `math.atanh` on f32 scalars and vectors using
+///   atanh(x) = 1/2 * log((1 + x) / (1 - x)) = 1/2 * log1p(2x / (1 - x)).
+/// The log1p form keeps small inputs accurate: in the plain log form the
+/// quotient rounds to exactly 1 for tiny |x| and the result collapses to +0,
+/// losing both the magnitude and the sign of x.
+struct AtanhApproximation : public OpRewritePattern<math::AtanhOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(math::AtanhOp op,
+                                PatternRewriter &rewriter) const final;
+};
+} // namespace
+
+LogicalResult
+AtanhApproximation::matchAndRewrite(math::AtanhOp op,
+                                    PatternRewriter &rewriter) const {
+  if (!getElementTypeOrSelf(op.getOperand()).isF32())
+    return rewriter.notifyMatchFailure(op, "unsupported operand type");
+
+  Value operand = op.getOperand();
+  VectorShape shape = vectorShape(operand);
+
+  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+  auto bcast = [&](Value value) -> Value {
+    return broadcast(builder, value, shape);
+  };
+
+  // 1/2 * log1p(2x / (1 - x)). Doubling via x + x is exact. At the domain
+  // ends, x = 1 makes the ratio +inf and x = -1 makes it -1 (log1p(-1) is
+  // -inf), matching atanh(+/-1) = +/-inf.
+  Value cstOne = bcast(f32Cst(builder, 1.0));
+  Value cstHalf = bcast(f32Cst(builder, 0.5));
+  Value twoX = builder.create<arith::AddFOp>(operand, operand);
+  Value oneMinusX = builder.create<arith::SubFOp>(cstOne, operand);
+  Value ratio = builder.create<arith::DivFOp>(twoX, oneMinusX);
+  Value logTerm = builder.create<math::Log1pOp>(ratio);
+  // Multiplying by 0.5 yields the same value as dividing by 2.
+  Value res = builder.create<arith::MulFOp>(logTerm, cstHalf);
+  rewriter.replaceOp(op, res);
+
+  return success();
+}
 
 //----------------------------------------------------------------------------//
 // LogOp and Log2Op approximation.
@@ -635,6 +672,11 @@ struct LogApproximationBase : public OpRewritePattern<Op> {
 };
 } // namespace
 
+// ln(2) and log2(e) as long double literals (`L` suffix), written with far
+// more digits than f32 requires.
+#define LN2_VALUE                                                              \
+  0.693147180559945309417232121458176568075500134360255254120680009493393621L
+#define LOG2E_VALUE                                                            \
+  1.442695040888963407359924681001892137426645954152985934135449406931109219L
+
 // This approximation comes from Julien Pommier's SSE math library.
 // Link: http://gruntthepeon.free.fr/ssemath
 template <typename Op>
@@ -1316,6 +1358,106 @@ LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
   return success();
 }
 
+//----------------------------------------------------------------------------//
+// SinhOp and CoshOp approximation.
+//----------------------------------------------------------------------------//
+
+namespace {
+
+/// Expands `math.sinh` / `math.cosh` on f32 scalars and vectors via exp:
+///   sinh(x) = (exp(x) - exp(-x)) / 2
+///   cosh(x) = (exp(x) + exp(-x)) / 2
+/// `isSinh` selects the formula and must be true exactly for math::SinhOp.
+template <bool isSinh, typename OpTy>
+struct SinhAndCoshApproximation : public OpRewritePattern<OpTy> {
+public:
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const final;
+};
+} // namespace
+
+template <bool isSinh, typename OpTy>
+LogicalResult SinhAndCoshApproximation<isSinh, OpTy>::matchAndRewrite(
+    OpTy op, PatternRewriter &rewriter) const {
+  static_assert(
+      llvm::is_one_of<OpTy, math::SinhOp, math::CoshOp>::value,
+      "SinhAndCoshApproximation pattern expects math::SinhOp or math::CoshOp");
+  static_assert(isSinh == std::is_same<OpTy, math::SinhOp>::value,
+                "isSinh must be true exactly when OpTy is math::SinhOp");
+
+  if (!getElementTypeOrSelf(op.getOperand()).isF32())
+    return rewriter.notifyMatchFailure(op, "unsupported operand type");
+
+  Value operand = op.getOperand();
+  VectorShape shape = vectorShape(operand);
+
+  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+  auto bcast = [&](Value value) -> Value {
+    return broadcast(builder, value, shape);
+  };
+
+  // NOTE(review): the sinh difference cancels for small |x| (both exps round
+  // to ~1), and exp(|x|) overflows slightly before sinh/cosh themselves do.
+  // An expm1-based form would be more accurate if that matters.
+  Value expPos = builder.create<math::ExpOp>(operand);
+  Value negated = builder.create<arith::NegFOp>(operand);
+  Value expNeg = builder.create<math::ExpOp>(negated);
+  Value combined;
+  if constexpr (isSinh)
+    combined = builder.create<arith::SubFOp>(expPos, expNeg);
+  else
+    combined = builder.create<arith::AddFOp>(expPos, expNeg);
+  // Multiplying by 0.5 yields the same value as dividing by 2.
+  Value cstHalf = bcast(f32Cst(builder, 0.5));
+  Value res = builder.create<arith::MulFOp>(combined, cstHalf);
+  rewriter.replaceOp(op, res);
+
+  return success();
+}
+
+//----------------------------------------------------------------------------//
+// AsinhOp and AcoshOp approximation.
+//----------------------------------------------------------------------------//
+
+namespace {
+
+/// Expands `math.asinh` / `math.acosh` on f32 scalars and vectors:
+///   asinh(x) = copysign(log(|x| + sqrt(x^2 + 1)), x)
+///   acosh(x) = log(x + sqrt((x - 1) * (x + 1)))
+/// `isAsinh` selects the formula and must be true exactly for math::AsinhOp.
+template <bool isAsinh, typename OpTy>
+struct AsinhAndAcoshApproximation : public OpRewritePattern<OpTy> {
+public:
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const final;
+};
+} // namespace
+
+template <bool isAsinh, typename OpTy>
+LogicalResult AsinhAndAcoshApproximation<isAsinh, OpTy>::matchAndRewrite(
+    OpTy op, PatternRewriter &rewriter) const {
+  static_assert(llvm::is_one_of<OpTy, math::AsinhOp, math::AcoshOp>::value,
+                "AsinhAndAcoshApproximation pattern expects math::AsinhOp or "
+                "math::AcoshOp");
+  static_assert(isAsinh == std::is_same<OpTy, math::AsinhOp>::value,
+                "isAsinh must be true exactly when OpTy is math::AsinhOp");
+
+  if (!getElementTypeOrSelf(op.getOperand()).isF32())
+    return rewriter.notifyMatchFailure(op, "unsupported operand type");
+
+  Value operand = op.getOperand();
+  VectorShape shape = vectorShape(operand);
+
+  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+  auto bcast = [&](Value value) -> Value {
+    return broadcast(builder, value, shape);
+  };
+
+  Value cstOne = bcast(f32Cst(builder, 1.0));
+  Value res;
+  if constexpr (isAsinh) {
+    // asinh is odd: evaluate on |x| and restore the sign afterwards. The
+    // direct x + sqrt(x^2 + 1) cancels catastrophically for negative x (for
+    // x = -1e4 it rounds to 0, giving -inf instead of about -9.9).
+    Value absX = builder.create<math::AbsFOp>(operand);
+    Value squared = builder.create<arith::MulFOp>(absX, absX);
+    Value radicand = builder.create<arith::AddFOp>(squared, cstOne);
+    Value root = builder.create<math::SqrtOp>(radicand);
+    Value sum = builder.create<arith::AddFOp>(absX, root);
+    Value magnitude = builder.create<math::LogOp>(sum);
+    res = builder.create<math::CopySignOp>(magnitude, operand);
+  } else {
+    // (x - 1) * (x + 1) rather than x * x - 1: x - 1 is exact for x in
+    // [0.5, 2] (Sterbenz), so the radicand stays accurate near x = 1.
+    Value xMinusOne = builder.create<arith::SubFOp>(operand, cstOne);
+    Value xPlusOne = builder.create<arith::AddFOp>(operand, cstOne);
+    Value radicand = builder.create<arith::MulFOp>(xMinusOne, xPlusOne);
+    Value root = builder.create<math::SqrtOp>(radicand);
+    Value sum = builder.create<arith::AddFOp>(operand, root);
+    res = builder.create<math::LogOp>(sum);
+  }
+  // NOTE(review): the squared term still overflows f32 for |x| above about
+  // 1.8e19, so both expansions return +/-inf there instead of ~log(2|x|).
+  rewriter.replaceOp(op, res);
+
+  return success();
+}
+
 //----------------------------------------------------------------------------//
 // Cbrt approximation.
 //----------------------------------------------------------------------------//
@@ -1505,11 +1647,16 @@ void mlir::populateMathPolynomialApproximationPatterns(
            ReuseF32Expansion<math::SinOp>, ReuseF32Expansion<math::CosOp>>(
           patterns.getContext());
 
-  patterns.add<AtanApproximation, Atan2Approximation, TanhApproximation,
-               LogApproximation, Log2Approximation, Log1pApproximation,
-               ErfPolynomialApproximation, ExpApproximation, ExpM1Approximation,
-               CbrtApproximation, SinAndCosApproximation<true, math::SinOp>,
-               SinAndCosApproximation<false, math::CosOp>>(
+  patterns.add<AtanApproximation, Atan2Approximation, AtanhApproximation,
+               TanhApproximation, LogApproximation, Log2Approximation,
+               Log1pApproximation, ErfPolynomialApproximation, ExpApproximation,
+               ExpM1Approximation, CbrtApproximation,
+               SinAndCosApproximation<true, math::SinOp>,
+               SinAndCosApproximation<false, math::CosOp>,
+               SinhAndCoshApproximation<true, math::SinhOp>,
+               SinhAndCoshApproximation<false, math::CoshOp>,
+               AsinhAndAcoshApproximation<true, math::AsinhOp>,
+               AsinhAndAcoshApproximation<false, math::AcoshOp>>(
       patterns.getContext());
   if (options.enableAvx2) {
     patterns.add<RsqrtApproximation, ReuseF32Expansion<math::RsqrtOp>>(
diff --git a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
index d3b19be9ecaf8f..9b73cdf57f5a35 100644
--- a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
@@ -568,6 +568,203 @@ func.func @atan2() {
 }
 
 
+// -------------------------------------------------------------------------- //
+// sinh
+// -------------------------------------------------------------------------- //
+
+func.func @sinh_f32(%x : f32) {
+  %y = math.sinh %x : f32
+  vector.print %y : f32
+  return
+}
+
+func.func @sinh_3xf32(%x : vector<3xf32>) {
+  %y = math.sinh %x : vector<3xf32>
+  vector.print %y : vector<3xf32>
+  return
+}
+
+func.func @sinh() {
+  // CHECK: 0
+  %in_zero = arith.constant 0.0 : f32
+  call @sinh_f32(%in_zero) : (f32) -> ()
+
+  // CHECK: 0.521095
+  %in_half = arith.constant 0.5 : f32
+  call @sinh_f32(%in_half) : (f32) -> ()
+
+  // CHECK: -1.1752
+  %in_neg_one = arith.constant -1.0 : f32
+  call @sinh_f32(%in_neg_one) : (f32) -> ()
+
+  // CHECK: 10.0179
+  %in_three = arith.constant 3.0 : f32
+  call @sinh_f32(%in_three) : (f32) -> ()
+
+  // CHECK: 0.252612, 0.991007, 3.62686
+  %in_vec = arith.constant dense<[0.25, 0.875, 2.0]> : vector<3xf32>
+  call @sinh_3xf32(%in_vec) : (vector<3xf32>) -> ()
+
+  return
+}
+
+
+// -------------------------------------------------------------------------- //
+// cosh
+// -------------------------------------------------------------------------- //
+
+func.func @cosh_f32(%x : f32) {
+  %y = math.cosh %x : f32
+  vector.print %y : f32
+  return
+}
+
+func.func @cosh_3xf32(%x : vector<3xf32>) {
+  %y = math.cosh %x : vector<3xf32>
+  vector.print %y : vector<3xf32>
+  return
+}
+
+func.func @cosh() {
+  // CHECK: 1
+  %in_zero = arith.constant 0.0 : f32
+  call @cosh_f32(%in_zero) : (f32) -> ()
+
+  // CHECK: 1.54308
+  %in_one = arith.constant 1.0 : f32
+  call @cosh_f32(%in_one) : (f32) -> ()
+
+  // CHECK: 1.54308
+  %in_neg_one = arith.constant -1.0 : f32
+  call @cosh_f32(%in_neg_one) : (f32) -> ()
+
+  // CHECK: 10.0677
+  %in_three = arith.constant 3.0 : f32
+  call @cosh_f32(%in_three) : (f32) -> ()
+
+  // CHECK: 1.03141, 1.40787, 3.7622
+  %in_vec = arith.constant dense<[0.25, 0.875, 2.0]> : vector<3xf32>
+  call @cosh_3xf32(%in_vec) : (vector<3xf32>) -> ()
+
+  return
+}
+
+
+// -------------------------------------------------------------------------- //
+// asinh
+// -------------------------------------------------------------------------- //
+
+func.func @asinh_f32(%x : f32) {
+  %y = math.asinh %x : f32
+  vector.print %y : f32
+  return
+}
+
+func.func @asinh_3xf32(%x : vector<3xf32>) {
+  %y = math.asinh %x : vector<3xf32>
+  vector.print %y : vector<3xf32>
+  return
+}
+
+func.func @asinh() {
+  // CHECK: 0
+  %in_zero = arith.constant 0.0 : f32
+  call @asinh_f32(%in_zero) : (f32) -> ()
+
+  // CHECK: 0.881374
+  %in_one = arith.constant 1.0 : f32
+  call @asinh_f32(%in_one) : (f32) -> ()
+
+  // CHECK: -0.881374
+  %in_neg_one = arith.constant -1.0 : f32
+  call @asinh_f32(%in_neg_one) : (f32) -> ()
+
+  // CHECK: 1.81845
+  %in_three = arith.constant 3.0 : f32
+  call @asinh_f32(%in_three) : (f32) -> ()
+
+  // CHECK: 0.247466, 0.790169, 1.44364
+  %in_vec = arith.constant dense<[0.25, 0.875, 2.0]> : vector<3xf32>
+  call @asinh_3xf32(%in_vec) : (vector<3xf32>) -> ()
+
+  return
+}
+
+
+// -------------------------------------------------------------------------- //
+// acosh
+// -------------------------------------------------------------------------- //
+
+func.func @acosh_f32(%x : f32) {
+  %y = math.acosh %x : f32
+  vector.print %y : f32
+  return
+}
+
+func.func @acosh_3xf32(%x : vector<3xf32>) {
+  %y = math.acosh %x : vector<3xf32>
+  vector.print %y : vector<3xf32>
+  return
+}
+
+func.func @acosh() {
+  // Real acosh is defined only for x >= 1, so the scalar inputs start at 1.
+  // CHECK: 0
+  %one = arith.constant 1.0 : f32
+  call @acosh_f32(%one) : (f32) -> ()
+
+  // CHECK: 1.31696
+  %two = arith.constant 2.0 : f32
+  call @acosh_f32(%two) : (f32) -> ()
+
+  // CHECK: 2.99322
+  %ten = arith.constant 10.0 : f32
+  call @acosh_f32(%ten) : (f32) -> ()
+
+  // CHECK: 0.962424, 1.76275, 2.47789
+  %vec_x = arith.constant dense<[1.5, 3.0, 6.0]> : vector<3xf32>
+  call @acosh_3xf32(%vec_x) : (vector<3xf32>) -> ()
+
+  return
+}
+
+
+// -------------------------------------------------------------------------- //
+// atanh
+// -------------------------------------------------------------------------- //
+
+func.func @atanh_f32(%x : f32) {
+  %y = math.atanh %x : f32
+  vector.print %y : f32
+  return
+}
+
+func.func @atanh_3xf32(%x : vector<3xf32>) {
+  %y = math.atanh %x : vector<3xf32>
+  vector.print %y : vector<3xf32>
+  return
+}
+
+func.func @atanh() {
+  // CHECK: 0
+  %in_zero = arith.constant 0.0 : f32
+  call @atanh_f32(%in_zero) : (f32) -> ()
+
+  // CHECK: 0.549306
+  %in_half = arith.constant 0.5 : f32
+  call @atanh_f32(%in_half) : (f32) -> ()
+
+  // CHECK: -0.549306
+  %in_neg_half = arith.constant -0.5 : f32
+  call @atanh_f32(%in_neg_half) : (f32) -> ()
+
+  // CHECK: 0.255413, 0.394229, 2.99448
+  %in_vec = arith.constant dense<[0.25, 0.375, 0.995]> : vector<3xf32>
+  call @atanh_3xf32(%in_vec) : (vector<3xf32>) -> ()
+
+  return
+}
+
+
 // -------------------------------------------------------------------------- //
 // Cbrt.
 // -------------------------------------------------------------------------- //
@@ -696,6 +893,11 @@ func.func @main() {
   call @cos(): () -> ()
   call @atan() : () -> ()
   call @atan2() : () -> ()
+  call @sinh() : () -> ()
+  call @cosh() : () -> ()
+  call @asinh() : () -> ()
+  call @acosh() : () -> ()
+  call @atanh() : () -> ()
   call @cbrt() : () -> ()
   call @floorf() : () -> ()
   call @ceilf() : () -> ()



More information about the Mlir-commits mailing list