[Mlir-commits] [mlir] [mlir][math] Add Polynomial Approximation for a few ops (PR #90718)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri May 3 02:36:31 PDT 2024


https://github.com/jinchen62 updated https://github.com/llvm/llvm-project/pull/90718

>From 6d32d4a6c046d4c7c1a963eb206723dd38037de2 Mon Sep 17 00:00:00 2001
From: jinchen62 <jinchenye62 at gmail.com>
Date: Tue, 30 Apr 2024 17:14:03 -0700
Subject: [PATCH] [mlir][math] Add Polynomial Approximation for acosh, asinh,
 atanh ops

---
 .../mlir/Dialect/Math/Transforms/Passes.h     |   3 +
 .../Math/Transforms/ExpandPatterns.cpp        |  87 +++++++++++--
 mlir/test/lib/Dialect/Math/TestExpandMath.cpp |   3 +
 .../test-expand-math-approx.mlir              | 119 ++++++++++++++++++
 4 files changed, 200 insertions(+), 12 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index e2c513047c77a5..24e6d9a8d98e04 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -31,6 +31,9 @@ void populateExpandTanPattern(RewritePatternSet &patterns);
 void populateExpandSinhPattern(RewritePatternSet &patterns);
 void populateExpandCoshPattern(RewritePatternSet &patterns);
 void populateExpandTanhPattern(RewritePatternSet &patterns);
+void populateExpandAsinhPattern(RewritePatternSet &patterns);
+void populateExpandAcoshPattern(RewritePatternSet &patterns);
+void populateExpandAtanhPattern(RewritePatternSet &patterns);
 void populateExpandFmaFPattern(RewritePatternSet &patterns);
 void populateExpandFloorFPattern(RewritePatternSet &patterns);
 void populateExpandCeilFPattern(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 42629e149e9ffb..5ccf3b6d72a2cc 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -73,14 +73,14 @@ static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter) {
   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
   Value operand = op.getOperand();
   Type opType = operand.getType();
-  Value exp = b.create<math::ExpOp>(operand);
 
-  Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
-  Value nexp = b.create<arith::DivFOp>(one, exp);
+  Value exp = b.create<math::ExpOp>(operand);
+  Value neg = b.create<arith::NegFOp>(operand);
+  Value nexp = b.create<math::ExpOp>(neg);
   Value sub = b.create<arith::SubFOp>(exp, nexp);
-  Value two = createFloatConst(op->getLoc(), opType, 2.0, rewriter);
-  Value div = b.create<arith::DivFOp>(sub, two);
-  rewriter.replaceOp(op, div);
+  Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
+  Value res = b.create<arith::MulFOp>(sub, half);
+  rewriter.replaceOp(op, res);
   return success();
 }
 
@@ -89,14 +89,14 @@ static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter) {
   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
   Value operand = op.getOperand();
   Type opType = operand.getType();
-  Value exp = b.create<math::ExpOp>(operand);
 
-  Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
-  Value nexp = b.create<arith::DivFOp>(one, exp);
+  Value exp = b.create<math::ExpOp>(operand);
+  Value neg = b.create<arith::NegFOp>(operand);
+  Value nexp = b.create<math::ExpOp>(neg);
   Value add = b.create<arith::AddFOp>(exp, nexp);
-  Value two = createFloatConst(op->getLoc(), opType, 2.0, rewriter);
-  Value div = b.create<arith::DivFOp>(add, two);
-  rewriter.replaceOp(op, div);
+  Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
+  Value res = b.create<arith::MulFOp>(add, half);
+  rewriter.replaceOp(op, res);
   return success();
 }
 
@@ -152,6 +152,57 @@ static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) {
   return success();
 }
 
+// Expands math.asinh using asinh(x) = log(x + sqrt(x * x + 1)).
+static LogicalResult convertAsinhOp(math::AsinhOp op,
+                                    PatternRewriter &rewriter) {
+  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+  Value x = op.getOperand();
+  Type xType = x.getType();
+  // x * x + 1 is built as a single math.fma before taking the square root.
+  Value cstOne = createFloatConst(op->getLoc(), xType, 1.0, rewriter);
+  Value squarePlusOne = builder.create<math::FmaOp>(x, x, cstOne);
+  Value root = builder.create<math::SqrtOp>(squarePlusOne);
+  Value logArg = builder.create<arith::AddFOp>(x, root);
+  Value asinhOfX = builder.create<math::LogOp>(logArg);
+  rewriter.replaceOp(op, asinhOfX);
+  return success();
+}
+
+// Expands math.acosh using acosh(x) = log(x + sqrt(x * x - 1)).
+static LogicalResult convertAcoshOp(math::AcoshOp op,
+                                    PatternRewriter &rewriter) {
+  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+  Value x = op.getOperand();
+  Type xType = x.getType();
+  // x * x - 1 is built as a single math.fma with a -1.0 addend.
+  Value cstMinusOne = createFloatConst(op->getLoc(), xType, -1.0, rewriter);
+  Value squareMinusOne = builder.create<math::FmaOp>(x, x, cstMinusOne);
+  Value root = builder.create<math::SqrtOp>(squareMinusOne);
+  Value logArg = builder.create<arith::AddFOp>(x, root);
+  Value acoshOfX = builder.create<math::LogOp>(logArg);
+  rewriter.replaceOp(op, acoshOfX);
+  return success();
+}
+
+// Expands math.atanh using atanh(x) = 0.5 * log((1 + x) / (1 - x)).
+static LogicalResult convertAtanhOp(math::AtanhOp op,
+                                    PatternRewriter &rewriter) {
+  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+  Value x = op.getOperand();
+  Type xType = x.getType();
+  // 1 - x is formed as (-x) + 1; the final halving is a multiply by 0.5.
+  Value cstOne = createFloatConst(op->getLoc(), xType, 1.0, rewriter);
+  Value onePlusX = builder.create<arith::AddFOp>(x, cstOne);
+  Value minusX = builder.create<arith::NegFOp>(x);
+  Value oneMinusX = builder.create<arith::AddFOp>(minusX, cstOne);
+  Value ratio = builder.create<arith::DivFOp>(onePlusX, oneMinusX);
+  Value logRatio = builder.create<math::LogOp>(ratio);
+  Value cstHalf = createFloatConst(op->getLoc(), xType, 0.5, rewriter);
+  Value atanhOfX = builder.create<arith::MulFOp>(logRatio, cstHalf);
+  rewriter.replaceOp(op, atanhOfX);
+  return success();
+}
+
 static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
   Value operandA = op.getOperand(0);
@@ -584,6 +635,18 @@ void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) {
   patterns.add(convertTanhOp);
 }
 
+void mlir::populateExpandAsinhPattern(RewritePatternSet &patterns) {
+  patterns.add(convertAsinhOp); // Adds the math.asinh rewrite to `patterns`.
+}
+
+void mlir::populateExpandAcoshPattern(RewritePatternSet &patterns) {
+  patterns.add(convertAcoshOp); // Adds the math.acosh rewrite to `patterns`.
+}
+
+void mlir::populateExpandAtanhPattern(RewritePatternSet &patterns) {
+  patterns.add(convertAtanhOp); // Adds the math.atanh rewrite to `patterns`.
+}
+
 void mlir::populateExpandFmaFPattern(RewritePatternSet &patterns) {
   patterns.add(convertFmaFOp);
 }
diff --git a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
index 97600ad1ebe7a3..da48ccb6e5e08f 100644
--- a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
+++ b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
@@ -42,6 +42,9 @@ void TestExpandMathPass::runOnOperation() {
   populateExpandSinhPattern(patterns);
   populateExpandCoshPattern(patterns);
   populateExpandTanhPattern(patterns);
+  populateExpandAsinhPattern(patterns);
+  populateExpandAcoshPattern(patterns);
+  populateExpandAtanhPattern(patterns);
   populateExpandFmaFPattern(patterns);
   populateExpandFloorFPattern(patterns);
   populateExpandCeilFPattern(patterns);
diff --git a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
index 340ef30bf59c29..2b72acde6a3bb7 100644
--- a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
@@ -717,6 +717,122 @@ func.func @tanh() {
  return
 }
 
+// -------------------------------------------------------------------------- //
+// Asinh.
+// -------------------------------------------------------------------------- //
+
+func.func @asinh_f32(%x : f32) {
+  %asinh_of_x = math.asinh %x : f32
+  vector.print %asinh_of_x : f32
+  return
+}
+
+func.func @asinh_3xf32(%xs : vector<3xf32>) {
+  %asinh_of_xs = math.asinh %xs : vector<3xf32>
+  vector.print %asinh_of_xs : vector<3xf32>
+  return
+}
+
+func.func @asinh() {
+  // CHECK: 0
+  %in_zero = arith.constant 0.0 : f32
+  call @asinh_f32(%in_zero) : (f32) -> ()
+
+  // CHECK: 0.881374
+  %in_one = arith.constant 1.0 : f32
+  call @asinh_f32(%in_one) : (f32) -> ()
+
+  // CHECK: -0.881374
+  %in_neg_one = arith.constant -1.0 : f32
+  call @asinh_f32(%in_neg_one) : (f32) -> ()
+
+  // CHECK: 1.81845
+  %in_three = arith.constant 3.0 : f32
+  call @asinh_f32(%in_three) : (f32) -> ()
+
+  // CHECK: 0.247466, 0.790169, 1.44364
+  %in_vec = arith.constant dense<[0.25, 0.875, 2.0]> : vector<3xf32>
+  call @asinh_3xf32(%in_vec) : (vector<3xf32>) -> ()
+
+  return
+}
+
+// -------------------------------------------------------------------------- //
+// Acosh.
+// -------------------------------------------------------------------------- //
+
+func.func @acosh_f32(%x : f32) {
+  %acosh_of_x = math.acosh %x : f32
+  vector.print %acosh_of_x : f32
+  return
+}
+
+func.func @acosh_3xf32(%xs : vector<3xf32>) {
+  %acosh_of_xs = math.acosh %xs : vector<3xf32>
+  vector.print %acosh_of_xs : vector<3xf32>
+  return
+}
+
+func.func @acosh() {
+  // CHECK: 0
+  %in_one = arith.constant 1.0 : f32
+  call @acosh_f32(%in_one) : (f32) -> ()
+
+  // CHECK: 1.31696
+  %in_two = arith.constant 2.0 : f32
+  call @acosh_f32(%in_two) : (f32) -> ()
+
+  // CHECK: 2.99322
+  %in_ten = arith.constant 10.0 : f32
+  call @acosh_f32(%in_ten) : (f32) -> ()
+
+  // CHECK: 0.962424, 1.76275, 2.47789
+  %in_vec = arith.constant dense<[1.5, 3.0, 6.0]> : vector<3xf32>
+  call @acosh_3xf32(%in_vec) : (vector<3xf32>) -> ()
+
+  return
+}
+
+// -------------------------------------------------------------------------- //
+// Atanh.
+// -------------------------------------------------------------------------- //
+
+func.func @atanh_f32(%x : f32) {
+  %atanh_of_x = math.atanh %x : f32
+  vector.print %atanh_of_x : f32
+  return
+}
+
+func.func @atanh_3xf32(%xs : vector<3xf32>) {
+  %atanh_of_xs = math.atanh %xs : vector<3xf32>
+  vector.print %atanh_of_xs : vector<3xf32>
+  return
+}
+
+func.func @atanh() {
+  // CHECK: 0
+  %in_zero = arith.constant 0.0 : f32
+  call @atanh_f32(%in_zero) : (f32) -> ()
+
+  // CHECK: 0.549306
+  %in_half = arith.constant 0.5 : f32
+  call @atanh_f32(%in_half) : (f32) -> ()
+
+  // CHECK: -0.549306
+  %in_neg_half = arith.constant -0.5 : f32
+  call @atanh_f32(%in_neg_half) : (f32) -> ()
+
+  // CHECK: inf
+  %in_one = arith.constant 1.0 : f32
+  call @atanh_f32(%in_one) : (f32) -> ()
+
+  // CHECK: 0.255413, 0.394229, 2.99448
+  %in_vec = arith.constant dense<[0.25, 0.375, 0.995]> : vector<3xf32>
+  call @atanh_3xf32(%in_vec) : (vector<3xf32>) -> ()
+
+  return
+}
+
 func.func @main() {
   call @exp2f() : () -> ()
   call @roundf() : () -> ()
@@ -725,5 +841,8 @@ func.func @main() {
   call @sinh() : () -> ()
   call @cosh() : () -> ()
   call @tanh() : () -> ()
+  call @asinh() : () -> ()
+  call @acosh() : () -> ()
+  call @atanh() : () -> ()
   return
 }



More information about the Mlir-commits mailing list