# [Mlir-commits] [mlir] 7e2d672 - Add polynomial approximation for trigonometric sine and cosine functions

Ahmed Taei llvmlistbot at llvm.org
Mon Jun 21 13:00:40 PDT 2021

Author: Ahmed S. Taei
Date: 2021-06-21T13:00:33-07:00
New Revision: 7e2d672a672c0559f6e5c417c5ee2514402cf18e

URL: https://github.com/llvm/llvm-project/commit/7e2d672a672c0559f6e5c417c5ee2514402cf18e
DIFF: https://github.com/llvm/llvm-project/commit/7e2d672a672c0559f6e5c417c5ee2514402cf18e.diff

LOG: Add polynomial approximation for trigonometric sine and cosine functions

The approximation relies on a range-reduced value y in [0, pi/2]. Depending on
which quadrant x is in, sin(x) equals sin(y), -sin(y), cos(y), or -cos(y), where
sin(y) and cos(y) are approximated with 5th-degree polynomials in y^2 (multiplied
by y for sine). As a result, a single pattern can compute the approximation for
both sine and cosine.

Reviewed By: ezhulenev

Differential Revision: https://reviews.llvm.org/D104582

Modified:
mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir

Removed:

################################################################################
diff  --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index da58b244d6c67..7844d8c475d2d 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -630,11 +630,146 @@ ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
return success();
}

+//----------------------------------------------------------------------------//
+// Sin and Cos approximation.
+//----------------------------------------------------------------------------//
+
+namespace {
+
+// Rewrite pattern that replaces a math::SinOp / math::CosOp with a polynomial
+// approximation. `isSine` selects which function is approximated and OpTy is
+// the op being matched; the pairing is enforced in matchAndRewrite.
+template <bool isSine, typename OpTy>
+struct SinAndCosApproximation : OpRewritePattern<OpTy> {
+  // Reuse the base pattern's constructors.
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const final;
+};
+
+} // namespace
+
+// 2/pi and pi/2 at long-double precision; they are passed to f32Cst, which
+// narrows them to f32. Typed constants (instead of macros) respect scoping and
+// do not leak into the rest of the file; the upper-case names are kept so the
+// existing uses compile unchanged.
+static constexpr long double TWO_OVER_PI =
+    0.6366197723675813430755350534900574481378385829618257949906693762L;
+static constexpr long double PI_OVER_2 =
+    1.5707963267948966192313216916397514420985846996875529104874722961L;
+
+// Approximates sin(x) (isSine = true) or cos(x) (isSine = false).
+//
+// Range reduction: k = floor(x * 2/pi) and y = x - k * pi/2, so y lies in
+// [0, pi/2] up to rounding. The result then depends on the quadrant k mod 4:
+//
+//   k mod 4 |  sin(x)  |  cos(x)
+//   --------+----------+---------
+//      0    |  sin(y)  |  cos(y)
+//      1    |  cos(y)  | -sin(y)
+//      2    | -sin(y)  | -cos(y)
+//      3    | -cos(y)  |  sin(y)
+//
+// sin(y) and cos(y) share a single Horner evaluation whose coefficients are
+// chosen per element with selects:
+//   cos(y) ~= 1 + y^2 * (CC2 + y^2 * (CC4 + ... + y^2 * CC10))
+//   sin(y) ~= y * (1 + y^2 * (SC2 + y^2 * (SC4 + ... + y^2 * SC10)))
+//
+// NOTE(review): the reduction uses a single f32 pi/2 constant and converts k
+// to i32 with fptosi, so accuracy degrades as |x| grows, and inputs with
+// |x| * 2/pi outside the i32 range are not handled.
+template <bool isSine, typename OpTy>
+LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
+    OpTy op, PatternRewriter &rewriter) const {
+  static_assert(
+      llvm::is_one_of<OpTy, math::SinOp, math::CosOp>::value,
+      "SinAndCosApproximation pattern expects math::SinOp or math::CosOp");
+  // Only operands accepted by vectorWidth(..., isF32) are rewritten.
+  auto width = vectorWidth(op.operand().getType(), isF32);
+  if (!width.hasValue())
+    return rewriter.notifyMatchFailure(op, "unsupported operand type");
+
+  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+  // Broadcasts a scalar value to the operand's vector width.
+  auto bcast = [&](Value value) -> Value {
+    return broadcast(builder, value, *width);
+  };
+  auto mul = [&](Value a, Value b) -> Value {
+    return builder.create<MulFOp>(a, b);
+  };
+  auto sub = [&](Value a, Value b) -> Value {
+    return builder.create<SubFOp>(a, b);
+  };
+  auto floor = [&](Value a) { return builder.create<FloorFOp>(a); };
+
+  auto i32Vec = broadcast(builder.getI32Type(), *width);
+  auto fpToSignedInteger = [&](Value a) -> Value {
+    return builder.create<FPToSIOp>(a, i32Vec);
+  };
+
+  // `a & 3` is `a mod 4` in [0, 3], also for negative a (two's complement).
+  auto modulo4 = [&](Value a) -> Value {
+    return builder.create<AndOp>(a, bcast(i32Cst(builder, 3)));
+  };
+
+  auto isEqualTo = [&](Value a, Value b) -> Value {
+    return builder.create<CmpIOp>(CmpIPredicate::eq, a, b);
+  };
+
+  auto isGreaterThan = [&](Value a, Value b) -> Value {
+    return builder.create<CmpIOp>(CmpIPredicate::sgt, a, b);
+  };
+
+  auto select = [&](Value cond, Value t, Value f) -> Value {
+    return builder.create<SelectOp>(cond, t, f);
+  };
+
+  auto fmla = [&](Value a, Value b, Value c) {
+    return builder.create<FmaFOp>(a, b, c);
+  };
+
+  auto bitwiseOr = [&](Value a, Value b) { return builder.create<OrOp>(a, b); };
+
+  Value twoOverPi = bcast(f32Cst(builder, TWO_OVER_PI));
+  Value piOverTwo = bcast(f32Cst(builder, PI_OVER_2));
+
+  Value x = op.operand();
+
+  // Range reduction: k = floor(x * 2/pi), y = x - k * pi/2.
+  Value k = floor(mul(x, twoOverPi));
+
+  Value y = sub(x, mul(k, piOverTwo));
+
+  Value cstOne = bcast(f32Cst(builder, 1.0));
+  Value cstNegativeOne = bcast(f32Cst(builder, -1.0));
+
+  // Coefficients of the y^2 polynomial approximating sin(y) / y.
+  Value cstSC2 = bcast(f32Cst(builder, -0.16666667163372039794921875f));
+  Value cstSC4 = bcast(f32Cst(builder, 8.333347737789154052734375e-3f));
+  Value cstSC6 = bcast(f32Cst(builder, -1.9842604524455964565277099609375e-4f));
+  Value cstSC8 =
+      bcast(f32Cst(builder, 2.760012648650445044040679931640625e-6f));
+  Value cstSC10 =
+      bcast(f32Cst(builder, -2.50293279435709337121807038784027099609375e-8f));
+
+  // Coefficients of the y^2 polynomial approximating cos(y).
+  Value cstCC2 = bcast(f32Cst(builder, -0.5f));
+  Value cstCC4 = bcast(f32Cst(builder, 4.166664183139801025390625e-2f));
+  Value cstCC6 = bcast(f32Cst(builder, -1.388833043165504932403564453125e-3f));
+  Value cstCC8 = bcast(f32Cst(builder, 2.47562347794882953166961669921875e-5f));
+  Value cstCC10 =
+      bcast(f32Cst(builder, -2.59630184018533327616751194000244140625e-7f));
+
+  // Quadrant index k mod 4 and one mask per quadrant.
+  Value kMod4 = modulo4(fpToSignedInteger(k));
+
+  Value kR0 = isEqualTo(kMod4, bcast(i32Cst(builder, 0)));
+  Value kR1 = isEqualTo(kMod4, bcast(i32Cst(builder, 1)));
+  Value kR2 = isEqualTo(kMod4, bcast(i32Cst(builder, 2)));
+  Value kR3 = isEqualTo(kMod4, bcast(i32Cst(builder, 3)));
+
+  // useCos: evaluate the cos(y) polynomial instead of the sin(y) one.
+  // negativeRange: negate the polynomial result (see the quadrant table).
+  Value useCos = isSine ? bitwiseOr(kR1, kR3) : bitwiseOr(kR0, kR2);
+  Value negativeRange = isSine ? isGreaterThan(kMod4, bcast(i32Cst(builder, 1)))
+                               : bitwiseOr(kR1, kR2);
+
+  Value y2 = mul(y, y);
+
+  // Per-element choice between the sin and cos polynomials; sin(y) is
+  // y * P(y^2) while cos(y) is 1 * Q(y^2).
+  Value base = select(useCos, cstOne, y);
+  Value cstC2 = select(useCos, cstCC2, cstSC2);
+  Value cstC4 = select(useCos, cstCC4, cstSC4);
+  Value cstC6 = select(useCos, cstCC6, cstSC6);
+  Value cstC8 = select(useCos, cstCC8, cstSC8);
+  Value cstC10 = select(useCos, cstCC10, cstSC10);
+
+  // Horner scheme:
+  //   v5 = 1 + y2 * (C2 + y2 * (C4 + y2 * (C6 + y2 * (C8 + y2 * C10))))
+  Value v1 = fmla(y2, cstC10, cstC8);
+  Value v2 = fmla(y2, v1, cstC6);
+  Value v3 = fmla(y2, v2, cstC4);
+  Value v4 = fmla(y2, v3, cstC2);
+  Value v5 = fmla(y2, v4, cstOne);
+  Value v6 = mul(base, v5);
+
+  Value approximation = select(negativeRange, mul(cstNegativeOne, v6), v6);
+
+  rewriter.replaceOp(op, approximation);
+
+  return success();
+}
+
//----------------------------------------------------------------------------//

void mlir::populateMathPolynomialApproximationPatterns(
RewritePatternSet &patterns) {
patterns.add<TanhApproximation, LogApproximation, Log2Approximation,
-               Log1pApproximation, ExpApproximation, ExpM1Approximation>(
+               Log1pApproximation, ExpApproximation, ExpM1Approximation,
+               SinAndCosApproximation<true, math::SinOp>,
+               SinAndCosApproximation<false, math::CosOp>>(
patterns.getContext());
}

diff  --git a/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir b/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir
index a3b06e9d7752d..8584d312b6b45 100644
--- a/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir
+++ b/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir
@@ -219,6 +219,83 @@ func @expm1() {

return
}
+// -------------------------------------------------------------------------- //
+// Sin.
+// -------------------------------------------------------------------------- //
+func @sin() {
+  // sin(0) = 0.
+  // CHECK: 0
+  %zero = constant 0.0 : f32
+  %sin_zero = math.sin %zero : f32
+  vector.print %sin_zero : f32
+
+  // sin(pi/4) = sqrt(2)/2.
+  // CHECK: 0.707107
+  %quarter_pi = constant 0.78539816339 : f32
+  %sin_quarter_pi = math.sin %quarter_pi : f32
+  vector.print %sin_quarter_pi : f32
+
+  // sin(pi/2) = 1.
+  // CHECK: 1
+  %half_pi = constant 1.57079632679 : f32
+  %sin_half_pi = math.sin %half_pi : f32
+  vector.print %sin_half_pi : f32
+
+  // sin(pi) = 0.
+  // CHECK: 0
+  %pi = constant 3.14159265359 : f32
+  %sin_pi = math.sin %pi : f32
+  vector.print %sin_pi : f32
+
+  // sin(3pi/2) = -1.
+  // CHECK: -1
+  %three_half_pi = constant 4.71238898038 : f32
+  %sin_three_half_pi = math.sin %three_half_pi : f32
+  vector.print %sin_three_half_pi : f32
+
+  // Vector lanes: sin(3pi) = 0, sin(2pi/3) = sqrt(3)/2, sin(-pi/2) = -1.
+  // CHECK: 0, 0.866025, -1
+  %vec_x = constant dense<[9.42477796077, 2.09439510239, -1.57079632679]> : vector<3xf32>
+  %sin_vec_x = math.sin %vec_x : vector<3xf32>
+  vector.print %sin_vec_x : vector<3xf32>
+
+  return
+}
+
+// -------------------------------------------------------------------------- //
+// cos.
+// -------------------------------------------------------------------------- //
+
+func @cos() {
+  // cos(0) = 1.
+  // CHECK: 1
+  %0 = constant 0.0 : f32
+  %cos_0 = math.cos %0 : f32
+  vector.print %cos_0 : f32
+
+  // cos(pi/4) = sqrt(2)/2.
+  // CHECK: 0.707107
+  %pi_over_4 = constant 0.78539816339 : f32
+  %cos_pi_over_4 = math.cos %pi_over_4 : f32
+  vector.print %cos_pi_over_4 : f32
+
+  // cos(pi/2) = 0.
+  // CHECK: 0
+  %pi_over_2 = constant 1.57079632679 : f32
+  %cos_pi_over_2 = math.cos %pi_over_2 : f32
+  vector.print %cos_pi_over_2 : f32
+
+  // cos(pi) = -1.
+  // CHECK: -1
+  %pi = constant 3.14159265359 : f32
+  %cos_pi = math.cos %pi : f32
+  vector.print %cos_pi : f32
+
+  // cos(3pi/2) = 0.
+  // CHECK: 0
+  %pi_3_over_2 = constant 4.71238898038 : f32
+  %cos_pi_3_over_2 = math.cos %pi_3_over_2 : f32
+  vector.print %cos_pi_3_over_2 : f32
+
+  // Vector lanes: cos(3pi) = -1, cos(2pi/3) = -0.5, cos(-pi/2) = 0.
+  // CHECK: -1, -0.5, 0
+  %vec_x = constant dense<[9.42477796077, 2.09439510239, -1.57079632679]> : vector<3xf32>
+  %cos_vec_x = math.cos %vec_x : vector<3xf32>
+  vector.print %cos_vec_x : vector<3xf32>
+
+  return
+}
+

func @main() {
call @tanh(): () -> ()
@@ -227,5 +304,7 @@ func @main() {
call @log1p(): () -> ()
call @exp(): () -> ()
call @expm1(): () -> ()
+  // Run the sin and cos approximation tests defined above.
+  call @sin(): () -> ()
+  call @cos(): () -> ()
return
}