[Mlir-commits] [mlir] [mlir][math] Add Polynomial Approximation for few ops (PR #90718)

Prashant Kumar llvmlistbot at llvm.org
Wed May 1 04:34:13 PDT 2024


================
@@ -615,10 +615,47 @@ TanhApproximation::matchAndRewrite(math::TanhOp op,
   return success();
 }
 
-#define LN2_VALUE                                                              \
-  0.693147180559945309417232121458176568075500134360255254120680009493393621L
-#define LOG2E_VALUE                                                            \
-  1.442695040888963407359924681001892137426645954152985934135449406931109219L
+//----------------------------------------------------------------------------//
+// AtanhOp approximation.
+//----------------------------------------------------------------------------//
+
+namespace {
+struct AtanhApproximation : public OpRewritePattern<math::AtanhOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(math::AtanhOp op,
+                                PatternRewriter &rewriter) const final;
+};
+} // namespace
+
+LogicalResult
+AtanhApproximation::matchAndRewrite(math::AtanhOp op,
+                                    PatternRewriter &rewriter) const {
+  if (!getElementTypeOrSelf(op.getOperand()).isF32())
+    return rewriter.notifyMatchFailure(op, "unsupported operand type");
+
+  auto operand = op.getOperand();
+  VectorShape shape = vectorShape(operand);
+
+  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+  auto bcast = [&](Value value) -> Value {
+    return broadcast(builder, value, shape);
+  };
+
+  // 1/2 * log((1 + x) / (1 - x))
+  Value cstOne = bcast(f32Cst(builder, 1.0));
+  Value add = builder.create<arith::AddFOp>(operand, cstOne);
+  Value neg = builder.create<arith::NegFOp>(operand);
+  Value sub = builder.create<arith::AddFOp>(neg, cstOne);
+  Value div = builder.create<arith::DivFOp>(add, sub);
----------------
pashu123 wrote:

The only input where this might give undefined behaviour is 1.0, since the divisor `1 - x` becomes zero there. Could you add a test and check whether it returns infinity? For an example of how to guard this case, see https://github.com/llvm/llvm-project/blob/803e03fbb7cd97461f349fb6e235592681fc1e6c/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp#L257
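
To make the concern concrete, here is a rough scalar model of the expansion in the hunk above, assuming IEEE-754 `float` arithmetic without fast-math flags. `atanhViaLog` and `checkEdgeCases` are hypothetical names used only for illustration; they are not part of MLIR, and the lowered vector code may behave differently depending on how `math.log` itself is approximated.

```cpp
#include <cassert>
#include <cmath>
#include <limits>

// Assumption: float follows IEEE-754, so x / 0 yields +/-inf rather than trapping.
static_assert(std::numeric_limits<float>::is_iec559,
              "this sketch assumes IEEE-754 float semantics");

// Scalar model of the patch's expansion: 0.5 * log((1 + x) / (1 - x)).
// Hypothetical helper for illustration only.
inline float atanhViaLog(float x) {
  float add = x + 1.0f;   // mirrors arith.addf(operand, cstOne)
  float sub = -x + 1.0f;  // mirrors arith.addf(arith.negf(operand), cstOne)
  float div = add / sub;  // mirrors arith.divf(add, sub)
  return 0.5f * std::log(div);
}

// Edge cases worth covering in a test.
inline void checkEdgeCases() {
  const float inf = std::numeric_limits<float>::infinity();
  // x = 1: 2 / 0 = +inf, log(+inf) = +inf.
  assert(atanhViaLog(1.0f) == inf);
  // x = -1: 0 / 2 = 0, log(0) = -inf.
  assert(atanhViaLog(-1.0f) == -inf);
  // Outside [-1, 1] the quotient is negative, so log returns NaN.
  assert(std::isnan(atanhViaLog(2.0f)));
  // Interior point: agrees with the library atanh to float precision.
  assert(std::fabs(atanhViaLog(0.5f) - std::atanh(0.5f)) < 1e-6f);
}
```

Under these assumptions the scalar formula already returns infinity at 1.0 without an explicit guard. Whether the same holds after the pattern is applied still needs a lit or runner test, because the division and the log are rewritten separately.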

https://github.com/llvm/llvm-project/pull/90718

