[Mlir-commits] [mlir] 7e2d672 - Add polynomial approximation for trigonometric sine and cosine functions
Ahmed Taei
llvmlistbot at llvm.org
Mon Jun 21 13:00:40 PDT 2021
Author: Ahmed S. Taei
Date: 2021-06-21T13:00:33-07:00
New Revision: 7e2d672a672c0559f6e5c417c5ee2514402cf18e
URL: https://github.com/llvm/llvm-project/commit/7e2d672a672c0559f6e5c417c5ee2514402cf18e
DIFF: https://github.com/llvm/llvm-project/commit/7e2d672a672c0559f6e5c417c5ee2514402cf18e.diff
LOG: Add polynomial approximation for trigonometric sine and cosine functions
The approximation relies on a range-reduced value y \in [0, pi/2]. For an input x,
sin(x) equals sin(y), -sin(y), cos(y) or -cos(y), depending on which quadrant x
is in, where sin(y) and cos(y) are approximated with a 5th-degree polynomial (of y^2).
As a result, a single pattern can be used to compute the approximation for both sine and cosine.
Reviewed By: ezhulenev
Differential Revision: https://reviews.llvm.org/D104582
Added:
Modified:
mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index da58b244d6c67..7844d8c475d2d 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -630,11 +630,146 @@ ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
return success();
}
+//----------------------------------------------------------------------------//
+// Sin and Cos approximation.
+//----------------------------------------------------------------------------//
+
+namespace {
+/// Rewrite pattern for OpTy (math::SinOp or math::CosOp); `isSine` picks the variant.
+template <bool isSine, typename OpTy>
+struct SinAndCosApproximation : public OpRewritePattern<OpTy> {
+public:
+ using OpRewritePattern<OpTy>::OpRewritePattern;
+ // Defined out of line; replaces `op` with its polynomial approximation.
+ LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const final;
+};
+} // namespace
+// Range-reduction constants 2/pi and pi/2, written as long double literals.
+#define TWO_OVER_PI \
+ 0.6366197723675813430755350534900574481378385829618257949906693762L
+#define PI_OVER_2 \
+ 1.5707963267948966192313216916397514420985846996875529104874722961L
+
+// Approximates sin(x) or cos(x) with polynomials in y^2 evaluated on the reduced
+// argument y = x - k * pi/2, k = floor(x * 2/pi). Depending on k mod 4 the result
+// is sin(y), -sin(y), cos(y) or -cos(y); unsupported operand types do not match.
+template <bool isSine, typename OpTy>
+LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
+ OpTy op, PatternRewriter &rewriter) const {
+ static_assert(
+ llvm::is_one_of<OpTy, math::SinOp, math::CosOp>::value,
+ "SinAndCosApproximation pattern expects math::SinOp or math::CosOp");
+ auto width = vectorWidth(op.operand().getType(), isF32);
+ if (!width.hasValue())
+ return rewriter.notifyMatchFailure(op, "unsupported operand type");
+ // The helper lambdas below emit ops through `builder` at the location of `op`.
+ ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+ auto bcast = [&](Value value) -> Value {
+ return broadcast(builder, value, *width);
+ };
+ auto mul = [&](Value a, Value b) -> Value {
+ return builder.create<MulFOp>(a, b);
+ };
+ auto sub = [&](Value a, Value b) -> Value {
+ return builder.create<SubFOp>(a, b);
+ };
+ auto floor = [&](Value a) { return builder.create<FloorFOp>(a); };
+ // i32 type broadcast to `width`; used for the float-to-int conversion below.
+ auto i32Vec = broadcast(builder.getI32Type(), *width);
+ auto fPToSingedInteger = [&](Value a) -> Value {
+ return builder.create<FPToSIOp>(a, i32Vec);
+ };
+ // On a two's-complement i32, `a & 3` is a mod 4 in [0, 3], also for negative a.
+ auto modulo4 = [&](Value a) -> Value {
+ return builder.create<AndOp>(a, bcast(i32Cst(builder, 3)));
+ };
+
+ auto isEqualTo = [&](Value a, Value b) -> Value {
+ return builder.create<CmpIOp>(CmpIPredicate::eq, a, b);
+ };
+
+ auto isGreaterThan = [&](Value a, Value b) -> Value {
+ return builder.create<CmpIOp>(CmpIPredicate::sgt, a, b);
+ };
+
+ auto select = [&](Value cond, Value t, Value f) -> Value {
+ return builder.create<SelectOp>(cond, t, f);
+ };
+
+ auto fmla = [&](Value a, Value b, Value c) {
+ return builder.create<FmaFOp>(a, b, c);
+ };
+
+ auto bitwiseOr = [&](Value a, Value b) { return builder.create<OrOp>(a, b); };
+ // Range-reduction constants as f32, broadcast to `width`.
+ Value twoOverPi = bcast(f32Cst(builder, TWO_OVER_PI));
+ Value piOverTwo = bcast(f32Cst(builder, PI_OVER_2));
+
+ Value x = op.operand();
+ // k = floor(x * 2/pi): index of the pi/2-wide interval that contains x.
+ Value k = floor(mul(x, twoOverPi));
+ // y = x - k * pi/2: the range-reduced argument.
+ Value y = sub(x, mul(k, piOverTwo));
+
+ Value cstOne = bcast(f32Cst(builder, 1.0));
+ Value cstNegativeOne = bcast(f32Cst(builder, -1.0));
+ // Sine table: coefficients of y^2 .. y^10 in sin(y) / y (the result is scaled by y).
+ Value cstSC2 = bcast(f32Cst(builder, -0.16666667163372039794921875f));
+ Value cstSC4 = bcast(f32Cst(builder, 8.333347737789154052734375e-3f));
+ Value cstSC6 = bcast(f32Cst(builder, -1.9842604524455964565277099609375e-4f));
+ Value cstSC8 =
+ bcast(f32Cst(builder, 2.760012648650445044040679931640625e-6f));
+ Value cstSC10 =
+ bcast(f32Cst(builder, -2.50293279435709337121807038784027099609375e-8f));
+ // Cosine table: coefficients of y^2 .. y^10 in cos(y) (the result is scaled by 1).
+ Value cstCC2 = bcast(f32Cst(builder, -0.5f));
+ Value cstCC4 = bcast(f32Cst(builder, 4.166664183139801025390625e-2f));
+ Value cstCC6 = bcast(f32Cst(builder, -1.388833043165504932403564453125e-3f));
+ Value cstCC8 = bcast(f32Cst(builder, 2.47562347794882953166961669921875e-5f));
+ Value cstCC10 =
+ bcast(f32Cst(builder, -2.59630184018533327616751194000244140625e-7f));
+ // Quadrant index: k converted to i32 and reduced mod 4.
+ Value kMod4 = modulo4(fPToSingedInteger(k));
+
+ Value kR0 = isEqualTo(kMod4, bcast(i32Cst(builder, 0)));
+ Value kR1 = isEqualTo(kMod4, bcast(i32Cst(builder, 1)));
+ Value kR2 = isEqualTo(kMod4, bcast(i32Cst(builder, 2)));
+ Value kR3 = isEqualTo(kMod4, bcast(i32Cst(builder, 3)));
+ // sinuseCos: quadrants that use the cosine table; negativeRange: quadrants negated.
+ Value sinuseCos = isSine ? bitwiseOr(kR1, kR3) : bitwiseOr(kR0, kR2);
+ Value negativeRange = isSine ? isGreaterThan(kMod4, bcast(i32Cst(builder, 1)))
+ : bitwiseOr(kR1, kR2);
+
+ Value y2 = mul(y, y);
+ // Per-element choice of table and scale factor: 1 for cosine, y for sine.
+ Value base = select(sinuseCos, cstOne, y);
+ Value cstC2 = select(sinuseCos, cstCC2, cstSC2);
+ Value cstC4 = select(sinuseCos, cstCC4, cstSC4);
+ Value cstC6 = select(sinuseCos, cstCC6, cstSC6);
+ Value cstC8 = select(sinuseCos, cstCC8, cstSC8);
+ Value cstC10 = select(sinuseCos, cstCC10, cstSC10);
+ // Horner evaluation in y^2: v5 = 1 + C2*y^2 + C4*y^4 + ... + C10*y^10.
+ Value v1 = fmla(y2, cstC10, cstC8);
+ Value v2 = fmla(y2, v1, cstC6);
+ Value v3 = fmla(y2, v2, cstC4);
+ Value v4 = fmla(y2, v3, cstC2);
+ Value v5 = fmla(y2, v4, cstOne);
+ Value v6 = mul(base, v5);
+ // Negate the result in the quadrants flagged by negativeRange.
+ Value approximation = select(negativeRange, mul(cstNegativeOne, v6), v6);
+
+ rewriter.replaceOp(op, approximation);
+
+ return success();
+}
+
//----------------------------------------------------------------------------//
void mlir::populateMathPolynomialApproximationPatterns(
RewritePatternSet &patterns) {
patterns.add<TanhApproximation, LogApproximation, Log2Approximation,
- Log1pApproximation, ExpApproximation, ExpM1Approximation>(
+ Log1pApproximation, ExpApproximation, ExpM1Approximation,
+ SinAndCosApproximation<true, math::SinOp>,
+ SinAndCosApproximation<false, math::CosOp>>(
patterns.getContext());
}
diff --git a/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir b/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir
index a3b06e9d7752d..8584d312b6b45 100644
--- a/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir
+++ b/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir
@@ -219,6 +219,83 @@ func @expm1() {
return
}
+// -------------------------------------------------------------------------- //
+// Sin.
+// -------------------------------------------------------------------------- //
+func @sin() {
+ // CHECK: 0
+ %0 = constant 0.0 : f32
+ %sin_0 = math.sin %0 : f32
+ vector.print %sin_0 : f32
+
+ // CHECK: 0.707107
+ %pi_over_4 = constant 0.78539816339 : f32
+ %sin_pi_over_4 = math.sin %pi_over_4 : f32
+ vector.print %sin_pi_over_4 : f32
+
+ // CHECK: 1
+ %pi_over_2 = constant 1.57079632679 : f32
+ %sin_pi_over_2 = math.sin %pi_over_2 : f32
+ vector.print %sin_pi_over_2 : f32
+
+ // NOTE(review): pi is not exact in f32, so the printed value is presumably a tiny nonzero number that merely contains a "0" digit - confirm.
+ // CHECK: 0
+ %pi = constant 3.14159265359 : f32
+ %sin_pi = math.sin %pi : f32
+ vector.print %sin_pi : f32
+
+ // CHECK: -1
+ %pi_3_over_2 = constant 4.71238898038 : f32
+ %sin_pi_3_over_2 = math.sin %pi_3_over_2 : f32
+ vector.print %sin_pi_3_over_2 : f32
+
+ // CHECK: 0, 0.866025, -1
+ %vec_x = constant dense<[9.42477796077, 2.09439510239, -1.57079632679]> : vector<3xf32>
+ %sin_vec_x = math.sin %vec_x : vector<3xf32>
+ vector.print %sin_vec_x : vector<3xf32>
+
+ return
+}
+
+// -------------------------------------------------------------------------- //
+// Cos.
+// -------------------------------------------------------------------------- //
+
+func @cos() {
+ // CHECK: 1
+ %0 = constant 0.0 : f32
+ %cos_0 = math.cos %0 : f32
+ vector.print %cos_0 : f32
+
+ // CHECK: 0.707107
+ %pi_over_4 = constant 0.78539816339 : f32
+ %cos_pi_over_4 = math.cos %pi_over_4 : f32
+ vector.print %cos_pi_over_4 : f32
+
+ // CHECK: 0
+ %pi_over_2 = constant 1.57079632679 : f32
+ %cos_pi_over_2 = math.cos %pi_over_2 : f32
+ vector.print %cos_pi_over_2 : f32
+
+ // CHECK: -1
+ %pi = constant 3.14159265359 : f32
+ %cos_pi = math.cos %pi : f32
+ vector.print %cos_pi : f32
+
+ // CHECK: 0
+ %pi_3_over_2 = constant 4.71238898038 : f32
+ %cos_pi_3_over_2 = math.cos %pi_3_over_2 : f32
+ vector.print %cos_pi_3_over_2 : f32
+
+ // CHECK: -1, -0.5, 0
+ %vec_x = constant dense<[9.42477796077, 2.09439510239, -1.57079632679]> : vector<3xf32>
+ %cos_vec_x = math.cos %vec_x : vector<3xf32>
+ vector.print %cos_vec_x : vector<3xf32>
+
+
+ return
+}
+
func @main() {
call @tanh(): () -> ()
@@ -227,5 +304,7 @@ func @main() {
call @log1p(): () -> ()
call @exp(): () -> ()
call @expm1(): () -> ()
+ call @sin(): () -> ()
+ call @cos(): () -> ()
return
}
More information about the Mlir-commits
mailing list