[Mlir-commits] [mlir] [mlir][math] Add Polynomial Approximation for few ops (PR #90718)

Prashant Kumar llvmlistbot at llvm.org
Wed May 1 04:34:13 PDT 2024


================
@@ -1316,6 +1358,106 @@ LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
   return success();
 }
 
+//----------------------------------------------------------------------------//
+// SinhOp and CoshOp approximation.
+//----------------------------------------------------------------------------//
+
+namespace {
+
+/// Rewrite pattern shared by `math::SinhOp` and `math::CoshOp`. `isSine` is
+/// the compile-time flag the rewrite branches on to emit the sinh form rather
+/// than the cosh form. `matchAndRewrite` is defined out of line.
+template <bool isSine, typename OpTy>
+struct SinhAndCoshApproximation : OpRewritePattern<OpTy> {
+  // Reuse the base-class constructors.
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const final;
+};
+} // namespace
+
+/// Replaces a `math::SinhOp` / `math::CoshOp` with an equivalent expression
+/// built from `math::ExpOp`:
+///   sinh(x) = (exp(x) - exp(-x)) / 2
+///   cosh(x) = (exp(x) + exp(-x)) / 2
+/// `isSine` selects the sinh form (subtraction); otherwise the cosh form
+/// (addition) is emitted.
+///
+/// Returns a match failure when the operand's element type is not f32, and
+/// success once `op` has been replaced.
+template <bool isSine, typename OpTy>
+LogicalResult SinhAndCoshApproximation<isSine, OpTy>::matchAndRewrite(
+    OpTy op, PatternRewriter &rewriter) const {
+  static_assert(
+      llvm::is_one_of<OpTy, math::SinhOp, math::CoshOp>::value,
+      "SinhAndCoshApproximation pattern expects math::SinhOp or math::CoshOp");
+
+  auto operand = op.getOperand();
+  if (!getElementTypeOrSelf(operand).isF32())
+    return rewriter.notifyMatchFailure(op, "unsupported operand type");
+
+  VectorShape shape = vectorShape(operand);
+
+  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+  // Broadcasts a constant to the shape computed from the operand.
+  auto bcast = [&](Value value) -> Value {
+    return broadcast(builder, value, shape);
+  };
+
+  // sinh: 1/2 * (exp(x) - exp(-x))
+  // cosh: 1/2 * (exp(x) + exp(-x))
+  Value expPos = builder.create<math::ExpOp>(operand);
+  Value negated = builder.create<arith::NegFOp>(operand);
+  Value expNeg = builder.create<math::ExpOp>(negated);
+  // Only the combining op differs between the two functions.
+  Value combined;
+  if (isSine)
+    combined = builder.create<arith::SubFOp>(expPos, expNeg);
+  else
+    combined = builder.create<arith::AddFOp>(expPos, expNeg);
+  Value cstTwo = bcast(f32Cst(builder, 2.0));
+  Value res = builder.create<arith::DivFOp>(combined, cstTwo);
+  rewriter.replaceOp(op, res);
+
+  return success();
+}
+
+//----------------------------------------------------------------------------//
+// AsinhOp and AcoshOp approximation.
+//----------------------------------------------------------------------------//
+
+namespace {
+
+/// Rewrite pattern shared by `math::AsinhOp` and `math::AcoshOp`. `isSine` is
+/// the compile-time flag the rewrite branches on. `matchAndRewrite` is defined
+/// out of line.
+template <bool isSine, typename OpTy>
+struct AsinhAndAcoshApproximation : OpRewritePattern<OpTy> {
+  // Reuse the base-class constructors.
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const final;
+};
+} // namespace
+
+template <bool isSine, typename OpTy>
+LogicalResult AsinhAndAcoshApproximation<isSine, OpTy>::matchAndRewrite(
+    OpTy op, PatternRewriter &rewriter) const {
+  static_assert(
+      llvm::is_one_of<OpTy, math::AsinhOp, math::AcoshOp>::value,
+      "SinAndCosApproximation pattern expects math::AsinhOp or math::AcoshOp");
+
+  if (!getElementTypeOrSelf(op.getOperand()).isF32())
+    return rewriter.notifyMatchFailure(op, "unsupported operand type");
+
+  auto operand = op.getOperand();
+  VectorShape shape = vectorShape(operand);
+
+  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+  auto bcast = [&](Value value) -> Value {
+    return broadcast(builder, value, shape);
+  };
+
+  // asinh: log(x + sqrt(x**2 + 1))
+  // acosh: log(x + sqrt(x**2 - 1))
+  Value squared = builder.create<arith::MulFOp>(operand, operand);
+  Value cstOne = bcast(f32Cst(builder, 1.0));
+  Value a;
+  if (isSine)
----------------
pashu123 wrote:

Can use the same ternary statement.

https://github.com/llvm/llvm-project/pull/90718


More information about the Mlir-commits mailing list