[Mlir-commits] [mlir] [mlir][math] Add Polynomial Approximation for acos, asin op (PR #90962)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri May 3 05:22:05 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Prashant Kumar (pashu123)
<details>
<summary>Changes</summary>
Adds polynomial approximations for the math.acos and math.asin ops, along with integration tests.
The approximation is borrowed from
https://stackoverflow.com/a/42683455
---
Full diff: https://github.com/llvm/llvm-project/pull/90962.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp (+154-6)
- (modified) mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir (+80)
``````````diff
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 428c1c37c4e8b5..f4fae68da63b3a 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -821,6 +821,153 @@ Log1pApproximation::matchAndRewrite(math::Log1pOp op,
return success();
}
+//----------------------------------------------------------------------------//
+// Asin approximation.
+//----------------------------------------------------------------------------//
+
+// Rewrite pattern that replaces math::AsinOp with a polynomial in the operand,
+// emitted as arith::MulFOp and math::FmaOp ops. Coefficients are taken from:
+// https://stackoverflow.com/a/42683455
+namespace {
+struct AsinPolynomialApproximation : public OpRewritePattern<math::AsinOp> {
+public:
+ using OpRewritePattern::OpRewritePattern; // reuse the base-class constructors
+ // Defined out of line, right after this anonymous namespace.
+ LogicalResult matchAndRewrite(math::AsinOp op,
+ PatternRewriter &rewriter) const final;
+};
+} // namespace
+LogicalResult
+AsinPolynomialApproximation::matchAndRewrite(math::AsinOp op,
+ PatternRewriter &rewriter) const {
+ Value operand = op.getOperand();
+ Type elementType = getElementTypeOrSelf(operand);
+ // Only f32 and f16 element types are rewritten; anything else is a match failure.
+ if (!(elementType.isF32() || elementType.isF16()))
+ return rewriter.notifyMatchFailure(op,
+ "only f32 and f16 type is supported.");
+ VectorShape shape = vectorShape(operand);
+ // Local helpers: bcast forwards to broadcast() with the operand's shape.
+ ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+ auto bcast = [&](Value value) -> Value {
+ return broadcast(builder, value, shape);
+ };
+ // fma(a, b, c) emits a math::FmaOp.
+ auto fma = [&](Value a, Value b, Value c) -> Value {
+ return builder.create<math::FmaOp>(a, b, c);
+ };
+ // mul(a, b) emits an arith::MulFOp.
+ auto mul = [&](Value a, Value b) -> Value {
+ return builder.create<arith::MulFOp>(a, b);
+ };
+ // With x = operand: s = x^2, q = x^4; r and t start two coefficient chains stepped in q.
+ Value s = mul(operand, operand);
+ Value q = mul(s, s);
+ Value r = bcast(floatCst(builder, 5.5579749017470502e-2, elementType));
+ Value t = bcast(floatCst(builder, -6.2027913464120114e-2, elementType));
+ // NOTE(review): the operand is used as-is (no range reduction); confirm accuracy for |x| near 1.
+ r = fma(r, q, bcast(floatCst(builder, 5.4224464349245036e-2, elementType)));
+ t = fma(t, q, bcast(floatCst(builder, -1.1326992890324464e-2, elementType)));
+ r = fma(r, q, bcast(floatCst(builder, 1.5268872539397656e-2, elementType)));
+ t = fma(t, q, bcast(floatCst(builder, 1.0493798473372081e-2, elementType)));
+ r = fma(r, q, bcast(floatCst(builder, 1.4106045900607047e-2, elementType)));
+ t = fma(t, q, bcast(floatCst(builder, 1.7339776384962050e-2, elementType)));
+ r = fma(r, q, bcast(floatCst(builder, 2.2372961589651054e-2, elementType)));
+ t = fma(t, q, bcast(floatCst(builder, 3.0381912707941005e-2, elementType)));
+ r = fma(r, q, bcast(floatCst(builder, 4.4642857881094775e-2, elementType)));
+ t = fma(t, q, bcast(floatCst(builder, 7.4999999991367292e-2, elementType)));
+ r = fma(r, s, t); // merge the two chains: r * x^2 + t
+ r = fma(r, s, bcast(floatCst(builder, 1.6666666666670193e-1, elementType)));
+ t = mul(operand, s); // x^3
+ r = fma(r, t, operand); // result = r * x^3 + x
+ // The polynomial value replaces the original op.
+ rewriter.replaceOp(op, r);
+ return success();
+}
+
+//----------------------------------------------------------------------------//
+// Acos approximation.
+//----------------------------------------------------------------------------//
+
+// Rewrite pattern that replaces math::AcosOp with an expression built from
+// math::AsinOp, math::SqrtOp, math::FmaOp and arith::SelectOp. Scheme from:
+// https://stackoverflow.com/a/42683455
+namespace {
+struct AcosPolynomialApproximation : public OpRewritePattern<math::AcosOp> {
+public:
+ using OpRewritePattern::OpRewritePattern; // reuse the base-class constructors
+ // Defined out of line, right after this anonymous namespace.
+ LogicalResult matchAndRewrite(math::AcosOp op,
+ PatternRewriter &rewriter) const final;
+};
+} // namespace
+LogicalResult
+AcosPolynomialApproximation::matchAndRewrite(math::AcosOp op,
+ PatternRewriter &rewriter) const {
+ Value operand = op.getOperand();
+ Type elementType = getElementTypeOrSelf(operand);
+ // Only f32 and f16 element types are rewritten; anything else is a match failure.
+ if (!(elementType.isF32() || elementType.isF16()))
+ return rewriter.notifyMatchFailure(op,
+ "only f32 and f16 type is supported.");
+ VectorShape shape = vectorShape(operand);
+ // Local helpers: bcast forwards to broadcast() with the operand's shape.
+ ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+ auto bcast = [&](Value value) -> Value {
+ return broadcast(builder, value, shape);
+ };
+ // fma(a, b, c) emits a math::FmaOp.
+ auto fma = [&](Value a, Value b, Value c) -> Value {
+ return builder.create<math::FmaOp>(a, b, c);
+ };
+ // mul(a, b) emits an arith::MulFOp.
+ auto mul = [&](Value a, Value b) -> Value {
+ return builder.create<arith::MulFOp>(a, b);
+ };
+ // r = operand > 0 ? -operand : operand (ordered compare, OGT).
+ Value negOperand = builder.create<arith::NegFOp>(operand);
+ Value zero = bcast(floatCst(builder, 0.0, elementType));
+ Value half = bcast(floatCst(builder, 0.5, elementType));
+ Value negOne = bcast(floatCst(builder, -1.0, elementType));
+ Value selR =
+ builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, operand, zero);
+ Value r = builder.create<arith::SelectOp>(selR, negOperand, operand);
+ Value chkConst = bcast(floatCst(builder, -0.5625, elementType));
+ Value firstPred =
+ builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, r, chkConst);
+ // If r > -0.5625: (0.9328... * 1.6839..., approximately pi/2) + asin(r).
+ Value trueVal =
+ fma(bcast(floatCst(builder, 9.3282184640716537e-1, elementType)),
+ bcast(floatCst(builder, 1.6839188885261840e+0, elementType)),
+ builder.create<math::AsinOp>(r));
+ // Otherwise: 2 * asin(sqrt(0.5 * r + 0.5)).
+ Value falseVal = builder.create<math::SqrtOp>(fma(half, r, half));
+ falseVal = builder.create<math::AsinOp>(falseVal);
+ falseVal = mul(bcast(floatCst(builder, 2.0, elementType)), falseVal);
+ // NOTE(review): the math.asin ops created here are presumably expanded by the asin pattern - confirm.
+ r = builder.create<arith::SelectOp>(firstPred, trueVal, falseVal);
+
+ // Check whether the operand lies in [-1.0, 0.0).
+ Value greaterThanNegOne =
+ builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, operand, negOne);
+
+ Value lessThanZero =
+ builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero);
+
+ Value betweenNegOneZero =
+ builder.create<arith::AndIOp>(greaterThanNegOne, lessThanZero);
+ // For those operands: (1.8656... * 1.6839..., approximately pi) - r.
+ trueVal = fma(bcast(floatCst(builder, 1.8656436928143307e+0, elementType)),
+ bcast(floatCst(builder, 1.6839188885261840e+0, elementType)),
+ builder.create<arith::NegFOp>(r));
+ // Operands outside [-1.0, 0.0) keep r unchanged.
+ Value finalVal =
+ builder.create<arith::SelectOp>(betweenNegOneZero, trueVal, r);
+ // The selected value replaces the original op.
+ rewriter.replaceOp(op, finalVal);
+ return success();
+}
+
//----------------------------------------------------------------------------//
// Erf approximation.
//----------------------------------------------------------------------------//
@@ -1505,12 +1652,13 @@ void mlir::populateMathPolynomialApproximationPatterns(
ReuseF32Expansion<math::SinOp>, ReuseF32Expansion<math::CosOp>>(
patterns.getContext());
- patterns.add<AtanApproximation, Atan2Approximation, TanhApproximation,
- LogApproximation, Log2Approximation, Log1pApproximation,
- ErfPolynomialApproximation, ExpApproximation, ExpM1Approximation,
- CbrtApproximation, SinAndCosApproximation<true, math::SinOp>,
- SinAndCosApproximation<false, math::CosOp>>(
- patterns.getContext());
+ patterns
+ .add<AtanApproximation, Atan2Approximation, TanhApproximation,
+ LogApproximation, Log2Approximation, Log1pApproximation,
+ ErfPolynomialApproximation, AsinPolynomialApproximation,
+ AcosPolynomialApproximation, ExpApproximation, ExpM1Approximation,
+ CbrtApproximation, SinAndCosApproximation<true, math::SinOp>,
+ SinAndCosApproximation<false, math::CosOp>>(patterns.getContext());
if (options.enableAvx2) {
patterns.add<RsqrtApproximation, ReuseF32Expansion<math::RsqrtOp>>(
patterns.getContext());
diff --git a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
index d3b19be9ecaf8f..370c5baa0adef3 100644
--- a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
@@ -461,6 +461,84 @@ func.func @cos() {
return
}
+// -------------------------------------------------------------------------- //
+// Asin.
+// -------------------------------------------------------------------------- //
+func.func @asin_f32(%a : f32) { // Scalar helper: prints math.asin of %a.
+ %r = math.asin %a : f32
+ vector.print %r : f32
+ return
+}
+
+func.func @asin_3xf32(%a : vector<3xf32>) { // Vector helper: prints math.asin of %a.
+ %r = math.asin %a : vector<3xf32>
+ vector.print %r : vector<3xf32>
+ return
+}
+
+func.func @asin() { // Driver: each call below is preceded by its expected printed value.
+ // CHECK: 0
+ %zero = arith.constant 0.0 : f32
+ call @asin_f32(%zero) : (f32) -> ()
+
+ // CHECK: -0.597406
+ %cst1 = arith.constant -0.5625 : f32
+ call @asin_f32(%cst1) : (f32) -> ()
+
+ // CHECK: -0.384397
+ %cst2 = arith.constant -0.375 : f32
+ call @asin_f32(%cst2) : (f32) -> ()
+
+ // CHECK: -0.25268
+ %cst3 = arith.constant -0.25 : f32
+ call @asin_f32(%cst3) : (f32) -> ()
+
+ // CHECK: 0.25268, 0.384397, 0.597406
+ %vec_x = arith.constant dense<[0.25, 0.375, 0.5625]> : vector<3xf32>
+ call @asin_3xf32(%vec_x) : (vector<3xf32>) -> ()
+ // All inputs in this function lie within [-0.5625, 0.5625].
+ return
+}
+
+// -------------------------------------------------------------------------- //
+// Acos.
+// -------------------------------------------------------------------------- //
+func.func @acos_f32(%a : f32) { // Scalar helper: prints math.acos of %a.
+ %r = math.acos %a : f32
+ vector.print %r : f32
+ return
+}
+
+func.func @acos_3xf32(%a : vector<3xf32>) { // Vector helper: prints math.acos of %a.
+ %r = math.acos %a : vector<3xf32>
+ vector.print %r : vector<3xf32>
+ return
+}
+
+func.func @acos() { // Driver: each call below is preceded by its expected printed value.
+ // CHECK: 1.5708
+ %zero = arith.constant 0.0 : f32
+ call @acos_f32(%zero) : (f32) -> ()
+
+ // CHECK: 2.1682
+ %cst1 = arith.constant -0.5625 : f32
+ call @acos_f32(%cst1) : (f32) -> ()
+
+ // CHECK: 1.95519
+ %cst2 = arith.constant -0.375 : f32
+ call @acos_f32(%cst2) : (f32) -> ()
+
+ // CHECK: 1.82348
+ %cst3 = arith.constant -0.25 : f32
+ call @acos_f32(%cst3) : (f32) -> ()
+
+ // CHECK: 1.31812, 1.1864, 0.97339
+ %vec_x = arith.constant dense<[0.25, 0.375, 0.5625]> : vector<3xf32>
+ call @acos_3xf32(%vec_x) : (vector<3xf32>) -> ()
+ // All inputs in this function lie within [-0.5625, 0.5625].
+ return
+}
+
// -------------------------------------------------------------------------- //
// Atan.
// -------------------------------------------------------------------------- //
@@ -694,6 +772,8 @@ func.func @main() {
call @expm1(): () -> ()
call @sin(): () -> ()
call @cos(): () -> ()
+ call @asin(): () -> ()
+ call @acos(): () -> ()
call @atan() : () -> ()
call @atan2() : () -> ()
call @cbrt() : () -> ()
``````````
</details>
https://github.com/llvm/llvm-project/pull/90962
More information about the Mlir-commits
mailing list