[Mlir-commits] [mlir] [mlir][math] Add Polynomial Approximation for few ops (PR #90718)
Prashant Kumar
llvmlistbot at llvm.org
Wed May 1 04:34:13 PDT 2024
================
@@ -1316,6 +1358,106 @@ LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
return success();
}
+//----------------------------------------------------------------------------//
+// SinhOp and CoshOp approximation.
+//----------------------------------------------------------------------------//
+
+namespace {
+
+/// Rewrite pattern that expands `math.sinh` / `math.cosh` in terms of
+/// `math.exp`.
+///
+/// \tparam isSinh selects the formula: true expands `math::SinhOp`, false
+///                expands `math::CoshOp`.
+/// \tparam OpTy   the op type matched by this pattern; expected to agree with
+///                `isSinh`.
+template <bool isSinh, typename OpTy>
+struct SinhAndCoshApproximation : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const final;
+};
+} // namespace
+
+/// Replaces an f32 (scalar or vector) `math.sinh` / `math.cosh` with
+///   sinh(x) = (exp(x) - exp(-x)) * 0.5
+///   cosh(x) = (exp(x) + exp(-x)) * 0.5
+/// Fails to match when the operand's element type is not f32.
+template <bool isSinh, typename OpTy>
+LogicalResult SinhAndCoshApproximation<isSinh, OpTy>::matchAndRewrite(
+    OpTy op, PatternRewriter &rewriter) const {
+  static_assert(
+      llvm::is_one_of<OpTy, math::SinhOp, math::CoshOp>::value,
+      "SinhAndCoshApproximation pattern expects math::SinhOp or math::CoshOp");
+  // Guard against mismatched instantiations (e.g. <true, math::CoshOp>), which
+  // would otherwise silently emit the wrong formula.
+  static_assert(isSinh == llvm::is_one_of<OpTy, math::SinhOp>::value,
+                "isSinh must be true exactly when OpTy is math::SinhOp");
+
+  if (!getElementTypeOrSelf(op.getOperand()).isF32())
+    return rewriter.notifyMatchFailure(op, "unsupported operand type");
+
+  Value operand = op.getOperand();
+  VectorShape shape = vectorShape(operand);
+
+  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+  auto bcast = [&](Value value) -> Value {
+    return broadcast(builder, value, shape);
+  };
+
+  // sinh(x) = (exp(x) - exp(-x)) / 2
+  // cosh(x) = (exp(x) + exp(-x)) / 2
+  //
+  // NOTE(review): for sinh, exp(x) - exp(-x) cancels catastrophically for
+  // small |x| (both terms round to ~1.0 in f32), so relative accuracy near 0
+  // is poor. Also, exp(x) overflows f32 slightly before sinh/cosh do. Consider
+  // a small-|x| polynomial branch if tighter accuracy is required.
+  Value expPos = builder.create<math::ExpOp>(operand);
+  Value negOperand = builder.create<arith::NegFOp>(operand);
+  Value expNeg = builder.create<math::ExpOp>(negOperand);
+
+  Value combined;
+  if constexpr (isSinh)
+    combined = builder.create<arith::SubFOp>(expPos, expNeg);
+  else
+    combined = builder.create<arith::AddFOp>(expPos, expNeg);
+
+  // Multiplying by 0.5 is bit-identical to dividing by 2.0 (a power of two)
+  // but avoids a floating-point division.
+  Value cstHalf = bcast(f32Cst(builder, 0.5));
+  Value res = builder.create<arith::MulFOp>(combined, cstHalf);
+  rewriter.replaceOp(op, res);
+
+  return success();
+}
+
+//----------------------------------------------------------------------------//
+// AsinhOp and AcoshOp approximation.
+//----------------------------------------------------------------------------//
+
+namespace {
+
+/// Rewrite pattern that expands `math.asinh` / `math.acosh` in terms of
+/// log and sqrt:
+///   asinh(x) = log(x + sqrt(x**2 + 1))
+///   acosh(x) = log(x + sqrt(x**2 - 1))
+///
+/// \tparam isAsinh selects the formula: true expands `math::AsinhOp`, false
+///                 expands `math::AcoshOp`.
+/// \tparam OpTy    the op type matched by this pattern; expected to agree with
+///                 `isAsinh`.
+template <bool isAsinh, typename OpTy>
+struct AsinhAndAcoshApproximation : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const final;
+};
+} // namespace
+
+template <bool isSine, typename OpTy>
+LogicalResult AsinhAndAcoshApproximation<isSine, OpTy>::matchAndRewrite(
+ OpTy op, PatternRewriter &rewriter) const {
+ static_assert(
+ llvm::is_one_of<OpTy, math::AsinhOp, math::AcoshOp>::value,
+ "SinAndCosApproximation pattern expects math::AsinhOp or math::AcoshOp");
+
+ if (!getElementTypeOrSelf(op.getOperand()).isF32())
+ return rewriter.notifyMatchFailure(op, "unsupported operand type");
+
+ auto operand = op.getOperand();
+ VectorShape shape = vectorShape(operand);
+
+ ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+ auto bcast = [&](Value value) -> Value {
+ return broadcast(builder, value, shape);
+ };
+
+ // asinh: log(x + sqrt(x**2 + 1))
+ // acosh: log(x + sqrt(x**2 - 1))
+ Value squared = builder.create<arith::MulFOp>(operand, operand);
+ Value cstOne = bcast(f32Cst(builder, 1.0));
+ Value a;
+ if (isSine)
----------------
pashu123 wrote:
You can use the same ternary expression here.
https://github.com/llvm/llvm-project/pull/90718
More information about the Mlir-commits
mailing list