[Mlir-commits] [mlir] [mlir][math] Add Polynomial Approximation for acos, asin op (PR #90962)

Prashant Kumar llvmlistbot at llvm.org
Fri May 3 05:21:36 PDT 2024


https://github.com/pashu123 created https://github.com/llvm/llvm-project/pull/90962

Adds polynomial approximations for the math.acos and math.asin ops, along with integration tests.
The approximation is borrowed from
https://stackoverflow.com/a/42683455

>From 0be25e06c6773c34c8d592f07d30a767003f28b1 Mon Sep 17 00:00:00 2001
From: Prashant Kumar <pk5561 at gmail.com>
Date: Fri, 26 Apr 2024 18:00:18 +0530
Subject: [PATCH] [mlir][math] Add Polynomial Approximation for acos, asin op

Adds polynomial approximations for the math.acos and math.asin ops.
Also adds integration tests.
The approximation is borrowed from
https://stackoverflow.com/a/42683455
---
 .../Transforms/PolynomialApproximation.cpp    | 160 +++++++++++++++++-
 .../math-polynomial-approx.mlir               |  80 +++++++++
 2 files changed, 234 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 428c1c37c4e8b5..f4fae68da63b3a 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -821,6 +821,193 @@ Log1pApproximation::matchAndRewrite(math::Log1pOp op,
   return success();
 }
 
+//----------------------------------------------------------------------------//
+// Asin approximation.
+//----------------------------------------------------------------------------//
+
+// Approximates asin(x).
+// The core polynomial is taken from the following stackoverflow post:
+// https://stackoverflow.com/a/42683455
+// That polynomial is only accurate for |x| <= 9/16, so larger magnitudes are
+// range-reduced with the identity
+//   asin(|x|) = pi/2 - 2 * asin(sqrt((1 - |x|) / 2)),
+// which maps |x| in (9/16, 1] into [0, sqrt(7/32)], i.e. about [0, 0.468].
+// asin is odd, so the operand's sign is restored with copysign at the end.
+// Operands with |x| > 1 yield NaN through the sqrt of a negative value.
+namespace {
+struct AsinPolynomialApproximation : public OpRewritePattern<math::AsinOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(math::AsinOp op,
+                                PatternRewriter &rewriter) const final;
+};
+} // namespace
+LogicalResult
+AsinPolynomialApproximation::matchAndRewrite(math::AsinOp op,
+                                             PatternRewriter &rewriter) const {
+  Value operand = op.getOperand();
+  Type elementType = getElementTypeOrSelf(operand);
+
+  if (!(elementType.isF32() || elementType.isF16()))
+    return rewriter.notifyMatchFailure(op,
+                                       "only f32 and f16 type is supported.");
+  VectorShape shape = vectorShape(operand);
+
+  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+  auto bcast = [&](Value value) -> Value {
+    return broadcast(builder, value, shape);
+  };
+
+  auto fma = [&](Value a, Value b, Value c) -> Value {
+    return builder.create<math::FmaOp>(a, b, c);
+  };
+
+  auto mul = [&](Value a, Value b) -> Value {
+    return builder.create<arith::MulFOp>(a, b);
+  };
+
+  // Range reduction: for |x| > 9/16 evaluate the core polynomial at
+  // sqrt((1 - |x|) / 2) = sqrt(fma(-0.5, |x|, 0.5)) instead of at |x|.
+  Value absX = builder.create<math::AbsFOp>(operand);
+  Value useReduction = builder.create<arith::CmpFOp>(
+      arith::CmpFPredicate::OGT, absX,
+      bcast(floatCst(builder, 0.5625, elementType)));
+  Value reduced = builder.create<math::SqrtOp>(
+      fma(bcast(floatCst(builder, -0.5, elementType)), absX,
+          bcast(floatCst(builder, 0.5, elementType))));
+  Value x = builder.create<arith::SelectOp>(useReduction, reduced, absX);
+
+  // Core polynomial, accurate for x in [-9/16, 9/16].
+  Value s = mul(x, x);
+  Value q = mul(s, s);
+  Value r = bcast(floatCst(builder, 5.5579749017470502e-2, elementType));
+  Value t = bcast(floatCst(builder, -6.2027913464120114e-2, elementType));
+
+  r = fma(r, q, bcast(floatCst(builder, 5.4224464349245036e-2, elementType)));
+  t = fma(t, q, bcast(floatCst(builder, -1.1326992890324464e-2, elementType)));
+  r = fma(r, q, bcast(floatCst(builder, 1.5268872539397656e-2, elementType)));
+  t = fma(t, q, bcast(floatCst(builder, 1.0493798473372081e-2, elementType)));
+  r = fma(r, q, bcast(floatCst(builder, 1.4106045900607047e-2, elementType)));
+  t = fma(t, q, bcast(floatCst(builder, 1.7339776384962050e-2, elementType)));
+  r = fma(r, q, bcast(floatCst(builder, 2.2372961589651054e-2, elementType)));
+  t = fma(t, q, bcast(floatCst(builder, 3.0381912707941005e-2, elementType)));
+  r = fma(r, q, bcast(floatCst(builder, 4.4642857881094775e-2, elementType)));
+  t = fma(t, q, bcast(floatCst(builder, 7.4999999991367292e-2, elementType)));
+  r = fma(r, s, t);
+  r = fma(r, s, bcast(floatCst(builder, 1.6666666666670193e-1, elementType)));
+  t = mul(x, s);
+  r = fma(r, t, x);
+
+  // Undo the range reduction: asin(|x|) = pi/2 - 2 * asin(reduced).
+  Value unreduced =
+      fma(bcast(floatCst(builder, -2.0, elementType)), r,
+          bcast(floatCst(builder, 1.57079632679489661923, elementType)));
+  r = builder.create<arith::SelectOp>(useReduction, unreduced, r);
+
+  // asin is odd: restore the sign of the operand (also keeps asin(-0) = -0).
+  r = builder.create<math::CopySignOp>(r, operand);
+
+  rewriter.replaceOp(op, r);
+  return success();
+}
+
+//----------------------------------------------------------------------------//
+// Acos approximation.
+//----------------------------------------------------------------------------//
+
+// Approximates acos(x).
+// This approximation is based on the following stackoverflow post:
+// https://stackoverflow.com/a/42683455
+// With r = -|x|, acos(|x|) is evaluated as
+//   pi/2 + asin(r)                 if r > -9/16,
+//   2 * asin(sqrt((1 + r) / 2))    otherwise,
+// and then reflected with acos(x) = pi - acos(|x|) for x in [-1, 0).
+// For x in [-1, 1] the asin result that ends up selected always comes from
+// an argument with magnitude <= 9/16. pi/2 and pi are formed by fma from the
+// factor pairs (9.3282184640716537e-1, 1.6839188885261840e+0) and
+// (1.8656436928143307e+0, 1.6839188885261840e+0). Both branches are always
+// materialized and blended with arith.select, so vector operands work too.
+namespace {
+struct AcosPolynomialApproximation : public OpRewritePattern<math::AcosOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(math::AcosOp op,
+                                PatternRewriter &rewriter) const final;
+};
+} // namespace
+LogicalResult
+AcosPolynomialApproximation::matchAndRewrite(math::AcosOp op,
+                                             PatternRewriter &rewriter) const {
+  Value operand = op.getOperand();
+  Type elementType = getElementTypeOrSelf(operand);
+
+  if (!(elementType.isF32() || elementType.isF16()))
+    return rewriter.notifyMatchFailure(op,
+                                       "only f32 and f16 type is supported.");
+  VectorShape shape = vectorShape(operand);
+
+  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+  auto bcast = [&](Value value) -> Value {
+    return broadcast(builder, value, shape);
+  };
+
+  auto fma = [&](Value a, Value b, Value c) -> Value {
+    return builder.create<math::FmaOp>(a, b, c);
+  };
+
+  auto mul = [&](Value a, Value b) -> Value {
+    return builder.create<arith::MulFOp>(a, b);
+  };
+
+  Value negOperand = builder.create<arith::NegFOp>(operand);
+  Value zero = bcast(floatCst(builder, 0.0, elementType));
+  Value half = bcast(floatCst(builder, 0.5, elementType));
+  Value negOne = bcast(floatCst(builder, -1.0, elementType));
+  // r = -|x|; NaN operands fail the OGT test and pass through unchanged.
+  Value selR =
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, operand, zero);
+  Value r = builder.create<arith::SelectOp>(selR, negOperand, operand);
+  Value chkConst = bcast(floatCst(builder, -0.5625, elementType));
+  Value firstPred =
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, r, chkConst);
+
+  // acos(|x|) = pi/2 + asin(r) when r > -9/16.
+  Value trueVal =
+      fma(bcast(floatCst(builder, 9.3282184640716537e-1, elementType)),
+          bcast(floatCst(builder, 1.6839188885261840e+0, elementType)),
+          builder.create<math::AsinOp>(r));
+
+  // acos(|x|) = 2 * asin(sqrt((1 + r) / 2)) otherwise.
+  Value falseVal = builder.create<math::SqrtOp>(fma(half, r, half));
+  falseVal = builder.create<math::AsinOp>(falseVal);
+  falseVal = mul(bcast(floatCst(builder, 2.0, elementType)), falseVal);
+
+  r = builder.create<arith::SelectOp>(firstPred, trueVal, falseVal);
+
+  // Check whether the operand lies in between [-1.0, 0.0).
+  Value greaterThanNegOne =
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, operand, negOne);
+
+  Value lessThanZero =
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero);
+
+  Value betweenNegOneZero =
+      builder.create<arith::AndIOp>(greaterThanNegOne, lessThanZero);
+
+  // acos(x) = pi - acos(|x|) for x in [-1, 0).
+  trueVal = fma(bcast(floatCst(builder, 1.8656436928143307e+0, elementType)),
+                bcast(floatCst(builder, 1.6839188885261840e+0, elementType)),
+                builder.create<arith::NegFOp>(r));
+
+  Value finalVal =
+      builder.create<arith::SelectOp>(betweenNegOneZero, trueVal, r);
+
+  rewriter.replaceOp(op, finalVal);
+  return success();
+}
+
 //----------------------------------------------------------------------------//
 // Erf approximation.
 //----------------------------------------------------------------------------//
@@ -1505,12 +1652,13 @@ void mlir::populateMathPolynomialApproximationPatterns(
            ReuseF32Expansion<math::SinOp>, ReuseF32Expansion<math::CosOp>>(
           patterns.getContext());
 
-  patterns.add<AtanApproximation, Atan2Approximation, TanhApproximation,
-               LogApproximation, Log2Approximation, Log1pApproximation,
-               ErfPolynomialApproximation, ExpApproximation, ExpM1Approximation,
-               CbrtApproximation, SinAndCosApproximation<true, math::SinOp>,
-               SinAndCosApproximation<false, math::CosOp>>(
-      patterns.getContext());
+  patterns
+      .add<AtanApproximation, Atan2Approximation, TanhApproximation,
+           LogApproximation, Log2Approximation, Log1pApproximation,
+           ErfPolynomialApproximation, AsinPolynomialApproximation,
+           AcosPolynomialApproximation, ExpApproximation, ExpM1Approximation,
+           CbrtApproximation, SinAndCosApproximation<true, math::SinOp>,
+           SinAndCosApproximation<false, math::CosOp>>(patterns.getContext());
   if (options.enableAvx2) {
     patterns.add<RsqrtApproximation, ReuseF32Expansion<math::RsqrtOp>>(
         patterns.getContext());
diff --git a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
index d3b19be9ecaf8f..370c5baa0adef3 100644
--- a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
@@ -461,6 +461,93 @@ func.func @cos() {
   return
 }
 
+// -------------------------------------------------------------------------- //
+// Asin.
+// -------------------------------------------------------------------------- //
+func.func @asin_f32(%a : f32) {
+  %r = math.asin %a : f32
+  vector.print %r : f32
+  return
+}
+
+func.func @asin_3xf32(%a : vector<3xf32>) {
+  %r = math.asin %a : vector<3xf32>
+  vector.print %r : vector<3xf32>
+  return
+}
+
+func.func @asin() {
+  // CHECK: 0
+  %zero = arith.constant 0.0 : f32
+  call @asin_f32(%zero) : (f32) -> ()
+
+  // CHECK: -0.597406
+  %cst1 = arith.constant -0.5625 : f32
+  call @asin_f32(%cst1) : (f32) -> ()
+
+  // CHECK: -0.384397
+  %cst2 = arith.constant -0.375 : f32
+  call @asin_f32(%cst2) : (f32) -> ()
+
+  // CHECK: -0.25268
+  %cst3 = arith.constant -0.25 : f32
+  call @asin_f32(%cst3) : (f32) -> ()
+
+  // CHECK: 0.25268, 0.384397, 0.597406
+  %vec_x = arith.constant dense<[0.25, 0.375, 0.5625]> : vector<3xf32>
+  call @asin_3xf32(%vec_x) : (vector<3xf32>) -> ()
+
+  return
+}
+
+// -------------------------------------------------------------------------- //
+// Acos.
+// -------------------------------------------------------------------------- //
+func.func @acos_f32(%a : f32) {
+  %r = math.acos %a : f32
+  vector.print %r : f32
+  return
+}
+
+func.func @acos_3xf32(%a : vector<3xf32>) {
+  %r = math.acos %a : vector<3xf32>
+  vector.print %r : vector<3xf32>
+  return
+}
+
+func.func @acos() {
+  // CHECK: 1.5708
+  %zero = arith.constant 0.0 : f32
+  call @acos_f32(%zero) : (f32) -> ()
+
+  // CHECK: 2.1682
+  %cst1 = arith.constant -0.5625 : f32
+  call @acos_f32(%cst1) : (f32) -> ()
+
+  // CHECK: 1.95519
+  %cst2 = arith.constant -0.375 : f32
+  call @acos_f32(%cst2) : (f32) -> ()
+
+  // CHECK: 1.82348
+  %cst3 = arith.constant -0.25 : f32
+  call @acos_f32(%cst3) : (f32) -> ()
+
+  // CHECK: 1.31812, 1.1864, 0.97339
+  %vec_x = arith.constant dense<[0.25, 0.375, 0.5625]> : vector<3xf32>
+  call @acos_3xf32(%vec_x) : (vector<3xf32>) -> ()
+
+  // Inputs with |x| > 9/16 and the domain endpoints +-1.
+  // CHECK: 3.14159
+  %cst4 = arith.constant -1.0 : f32
+  call @acos_f32(%cst4) : (f32) -> ()
+
+  // CHECK: 0.451027, 2.69057, 0
+  %vec_edge = arith.constant dense<[0.9, -0.9, 1.0]> : vector<3xf32>
+  call @acos_3xf32(%vec_edge) : (vector<3xf32>) -> ()
+
+  return
+}
+
 // -------------------------------------------------------------------------- //
 // Atan.
 // -------------------------------------------------------------------------- //
@@ -694,6 +772,8 @@ func.func @main() {
   call @expm1(): () -> ()
   call @sin(): () -> ()
   call @cos(): () -> ()
+  call @asin(): () -> ()
+  call @acos(): () -> ()
   call @atan() : () -> ()
   call @atan2() : () -> ()
   call @cbrt() : () -> ()



More information about the Mlir-commits mailing list