[Mlir-commits] [mlir] [mlir][math] Add conversions for acosh, asinh, atanh (PR #90718)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon May 6 07:54:24 PDT 2024


https://github.com/jinchen62 updated https://github.com/llvm/llvm-project/pull/90718

>From 511067125c52ff28f13477d5383813e48cd7f344 Mon Sep 17 00:00:00 2001
From: jinchen62 <jinchenye62 at gmail.com>
Date: Tue, 30 Apr 2024 17:14:03 -0700
Subject: [PATCH] [mlir][math] Add expansion patterns for acosh, asinh,
 atanh ops

---
 .../mlir/Dialect/Math/Transforms/Passes.h     |   4 +-
 .../Math/Transforms/ExpandPatterns.cpp        | 127 ++++++++++--------
 mlir/test/lib/Dialect/Math/TestExpandMath.cpp |   4 +-
 .../test-expand-math-approx.mlir              | 119 ++++++++++++++++
 4 files changed, 198 insertions(+), 56 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index e2c513047c77a5..e86d8c541777e5 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -30,7 +30,9 @@ void populateExpandCtlzPattern(RewritePatternSet &patterns);
 void populateExpandTanPattern(RewritePatternSet &patterns);
 void populateExpandSinhPattern(RewritePatternSet &patterns);
 void populateExpandCoshPattern(RewritePatternSet &patterns);
-void populateExpandTanhPattern(RewritePatternSet &patterns);
+void populateExpandAsinhPattern(RewritePatternSet &patterns);
+void populateExpandAcoshPattern(RewritePatternSet &patterns);
+void populateExpandAtanhPattern(RewritePatternSet &patterns);
 void populateExpandFmaFPattern(RewritePatternSet &patterns);
 void populateExpandFloorFPattern(RewritePatternSet &patterns);
 void populateExpandCeilFPattern(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 42629e149e9ffb..18d3e2bd328918 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -73,14 +73,14 @@ static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter) {
   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
   Value operand = op.getOperand();
   Type opType = operand.getType();
-  Value exp = b.create<math::ExpOp>(operand);
 
-  Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
-  Value nexp = b.create<arith::DivFOp>(one, exp);
+  Value exp = b.create<math::ExpOp>(operand);
+  Value neg = b.create<arith::NegFOp>(operand);
+  Value nexp = b.create<math::ExpOp>(neg);
   Value sub = b.create<arith::SubFOp>(exp, nexp);
-  Value two = createFloatConst(op->getLoc(), opType, 2.0, rewriter);
-  Value div = b.create<arith::DivFOp>(sub, two);
-  rewriter.replaceOp(op, div);
+  Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
+  Value res = b.create<arith::MulFOp>(sub, half);
+  rewriter.replaceOp(op, res);
   return success();
 }
 
@@ -89,54 +89,14 @@ static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter) {
   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
   Value operand = op.getOperand();
   Type opType = operand.getType();
-  Value exp = b.create<math::ExpOp>(operand);
 
-  Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
-  Value nexp = b.create<arith::DivFOp>(one, exp);
+  Value exp = b.create<math::ExpOp>(operand);
+  Value neg = b.create<arith::NegFOp>(operand);
+  Value nexp = b.create<math::ExpOp>(neg);
   Value add = b.create<arith::AddFOp>(exp, nexp);
-  Value two = createFloatConst(op->getLoc(), opType, 2.0, rewriter);
-  Value div = b.create<arith::DivFOp>(add, two);
-  rewriter.replaceOp(op, div);
-  return success();
-}
-
-/// Expands tanh op into
-/// 1-exp^{-2x} / 1+exp^{-2x}
-/// To avoid overflow we exploit the reflection symmetry `tanh(-x) = -tanh(x)`.
-/// We compute a "signs" value which is -1 if input is negative and +1 if input
-/// is positive.  Then multiply the input by this value, guaranteeing that the
-/// result is positive, which also guarantees `exp^{-2x * sign(x)}` is in (0,
-/// 1]. Expand the computation on the input `x * sign(x)`, then multiply the
-/// result by `sign(x)` to retain sign of the real result.
-static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
-  auto floatType = op.getOperand().getType();
-  Location loc = op.getLoc();
-  Value zero = createFloatConst(loc, floatType, 0.0, rewriter);
-  Value one = createFloatConst(loc, floatType, 1.0, rewriter);
-  Value negTwo = createFloatConst(loc, floatType, -2.0, rewriter);
-
-  // Compute sign(x) = cast<float_type>(x < 0) * (-2) + 1
-  Value isNegative = rewriter.create<arith::CmpFOp>(
-      loc, arith::CmpFPredicate::OLT, op.getOperand(), zero);
-  Value isNegativeFloat =
-      rewriter.create<arith::UIToFPOp>(loc, floatType, isNegative);
-  Value isNegativeTimesNegTwo =
-      rewriter.create<arith::MulFOp>(loc, isNegativeFloat, negTwo);
-  Value sign = rewriter.create<arith::AddFOp>(loc, isNegativeTimesNegTwo, one);
-
-  // Normalize input to positive value: y = sign(x) * x
-  Value positiveX = rewriter.create<arith::MulFOp>(loc, sign, op.getOperand());
-
-  // Decompose on normalized input
-  Value negDoubledX = rewriter.create<arith::MulFOp>(loc, negTwo, positiveX);
-  Value exp2x = rewriter.create<math::ExpOp>(loc, negDoubledX);
-  Value dividend = rewriter.create<arith::SubFOp>(loc, one, exp2x);
-  Value divisor = rewriter.create<arith::AddFOp>(loc, one, exp2x);
-  Value positiveRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor);
-
-  // Multiply result by sign(x) to retain signs from negative inputs
-  rewriter.replaceOpWithNewOp<arith::MulFOp>(op, sign, positiveRes);
-
+  Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
+  Value res = b.create<arith::MulFOp>(add, half);
+  rewriter.replaceOp(op, res);
   return success();
 }
 
@@ -152,6 +112,57 @@ static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) {
   return success();
 }
 
+// Expands math.asinh(x) into log(x + sqrt(x * x + 1)).
+static LogicalResult convertAsinhOp(math::AsinhOp op,
+                                    PatternRewriter &rewriter) {
+  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+  Value operand = op.getOperand();
+  Type opType = operand.getType();
+
+  Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
+  Value fma = b.create<math::FmaOp>(operand, operand, one); // x * x + 1
+  Value sqrt = b.create<math::SqrtOp>(fma);
+  Value add = b.create<arith::AddFOp>(operand, sqrt);
+  Value res = b.create<math::LogOp>(add);
+  rewriter.replaceOp(op, res);
+  return success();
+}
+
+// Expands math.acosh(x) into log(x + sqrt(x * x - 1)).
+static LogicalResult convertAcoshOp(math::AcoshOp op,
+                                    PatternRewriter &rewriter) {
+  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+  Value operand = op.getOperand();
+  Type opType = operand.getType();
+
+  Value negOne = createFloatConst(op->getLoc(), opType, -1.0, rewriter);
+  Value fma = b.create<math::FmaOp>(operand, operand, negOne); // x * x - 1
+  Value sqrt = b.create<math::SqrtOp>(fma);
+  Value add = b.create<arith::AddFOp>(operand, sqrt);
+  Value res = b.create<math::LogOp>(add);
+  rewriter.replaceOp(op, res);
+  return success();
+}
+
+// Expands math.atanh(x) into 0.5 * log((1 + x) / (1 - x)).
+static LogicalResult convertAtanhOp(math::AtanhOp op,
+                                    PatternRewriter &rewriter) {
+  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+  Value operand = op.getOperand();
+  Type opType = operand.getType();
+
+  Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
+  Value add = b.create<arith::AddFOp>(operand, one);
+  Value neg = b.create<arith::NegFOp>(operand);
+  Value sub = b.create<arith::AddFOp>(neg, one); // 1 - x, formed as -x + 1
+  Value div = b.create<arith::DivFOp>(add, sub);
+  Value log = b.create<math::LogOp>(div);
+  Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
+  Value res = b.create<arith::MulFOp>(log, half);
+  rewriter.replaceOp(op, res);
+  return success();
+}
+
 static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
   Value operandA = op.getOperand(0);
@@ -580,8 +591,16 @@ void mlir::populateExpandTanPattern(RewritePatternSet &patterns) {
   patterns.add(convertTanOp);
 }
 
-void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) {
-  patterns.add(convertTanhOp);
+void mlir::populateExpandAsinhPattern(RewritePatternSet &patterns) {
+  patterns.add(convertAsinhOp);
+}
+
+void mlir::populateExpandAcoshPattern(RewritePatternSet &patterns) {
+  patterns.add(convertAcoshOp);
+}
+
+void mlir::populateExpandAtanhPattern(RewritePatternSet &patterns) {
+  patterns.add(convertAtanhOp);
 }
 
 void mlir::populateExpandFmaFPattern(RewritePatternSet &patterns) {
diff --git a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
index 97600ad1ebe7a3..87cad41e79bdbc 100644
--- a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
+++ b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
@@ -41,7 +41,9 @@ void TestExpandMathPass::runOnOperation() {
   populateExpandTanPattern(patterns);
   populateExpandSinhPattern(patterns);
   populateExpandCoshPattern(patterns);
-  populateExpandTanhPattern(patterns);
+  populateExpandAsinhPattern(patterns);
+  populateExpandAcoshPattern(patterns);
+  populateExpandAtanhPattern(patterns);
   populateExpandFmaFPattern(patterns);
   populateExpandFloorFPattern(patterns);
   populateExpandCeilFPattern(patterns);
diff --git a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
index 340ef30bf59c29..2b72acde6a3bb7 100644
--- a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
@@ -717,6 +717,122 @@ func.func @tanh() {
  return
 }
 
+// -------------------------------------------------------------------------- //
+// Asinh.
+// -------------------------------------------------------------------------- //
+
+func.func @asinh_f32(%a : f32) {
+  %r = math.asinh %a : f32
+  vector.print %r : f32
+  return
+}
+
+func.func @asinh_3xf32(%a : vector<3xf32>) {
+  %r = math.asinh %a : vector<3xf32>
+  vector.print %r : vector<3xf32>
+  return
+}
+
+func.func @asinh() {
+  // CHECK: 0
+  %zero = arith.constant 0.0 : f32
+  call @asinh_f32(%zero) : (f32) -> ()
+
+  // CHECK: 0.881374
+  %cst1 = arith.constant 1.0 : f32
+  call @asinh_f32(%cst1) : (f32) -> ()
+
+  // CHECK: -0.881374
+  %cst2 = arith.constant -1.0 : f32
+  call @asinh_f32(%cst2) : (f32) -> ()
+
+  // CHECK: 1.81845
+  %cst3 = arith.constant 3.0 : f32
+  call @asinh_f32(%cst3) : (f32) -> ()
+
+  // CHECK: 0.247466, 0.790169, 1.44364
+  %vec_x = arith.constant dense<[0.25, 0.875, 2.0]> : vector<3xf32>
+  call @asinh_3xf32(%vec_x) : (vector<3xf32>) -> ()
+
+  return
+}
+
+// -------------------------------------------------------------------------- //
+// Acosh.
+// -------------------------------------------------------------------------- //
+
+func.func @acosh_f32(%a : f32) {
+  %r = math.acosh %a : f32
+  vector.print %r : f32
+  return
+}
+
+func.func @acosh_3xf32(%a : vector<3xf32>) {
+  %r = math.acosh %a : vector<3xf32>
+  vector.print %r : vector<3xf32>
+  return
+}
+
+func.func @acosh() {
+  // CHECK: 0
+  %one = arith.constant 1.0 : f32
+  call @acosh_f32(%one) : (f32) -> ()
+
+  // CHECK: 1.31696
+  %cst1 = arith.constant 2.0 : f32
+  call @acosh_f32(%cst1) : (f32) -> ()
+
+  // CHECK: 2.99322
+  %cst2 = arith.constant 10.0 : f32
+  call @acosh_f32(%cst2) : (f32) -> ()
+
+  // CHECK: 0.962424, 1.76275, 2.47789
+  %vec_x = arith.constant dense<[1.5, 3.0, 6.0]> : vector<3xf32>
+  call @acosh_3xf32(%vec_x) : (vector<3xf32>) -> ()
+
+  return
+}
+
+// -------------------------------------------------------------------------- //
+// Atanh.
+// -------------------------------------------------------------------------- //
+
+func.func @atanh_f32(%a : f32) {
+  %r = math.atanh %a : f32
+  vector.print %r : f32
+  return
+}
+
+func.func @atanh_3xf32(%a : vector<3xf32>) {
+  %r = math.atanh %a : vector<3xf32>
+  vector.print %r : vector<3xf32>
+  return
+}
+
+func.func @atanh() {
+  // CHECK: 0
+  %zero = arith.constant 0.0 : f32
+  call @atanh_f32(%zero) : (f32) -> ()
+
+  // CHECK: 0.549306
+  %cst1 = arith.constant 0.5 : f32
+  call @atanh_f32(%cst1) : (f32) -> ()
+
+  // CHECK: -0.549306
+  %cst2 = arith.constant -0.5 : f32
+  call @atanh_f32(%cst2) : (f32) -> ()
+
+  // CHECK: inf
+  %cst3 = arith.constant 1.0 : f32
+  call @atanh_f32(%cst3) : (f32) -> ()
+
+  // CHECK: 0.255413, 0.394229, 2.99448
+  %vec_x = arith.constant dense<[0.25, 0.375, 0.995]> : vector<3xf32>
+  call @atanh_3xf32(%vec_x) : (vector<3xf32>) -> ()
+
+  return
+}
+
 func.func @main() {
   call @exp2f() : () -> ()
   call @roundf() : () -> ()
@@ -725,5 +841,8 @@ func.func @main() {
   call @sinh() : () -> ()
   call @cosh() : () -> ()
   call @tanh() : () -> ()
+  call @asinh() : () -> ()
+  call @acosh() : () -> ()
+  call @atanh() : () -> ()
   return
 }



More information about the Mlir-commits mailing list