[Mlir-commits] [mlir] [mlir][math] Add conversions for acosh, asinh, atanh (PR #90718)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon May 6 08:12:17 PDT 2024
https://github.com/jinchen62 updated https://github.com/llvm/llvm-project/pull/90718
>From 56702629e7b76b110850ed0f5532b841b14837c2 Mon Sep 17 00:00:00 2001
From: jinchen62 <jinchenye62 at gmail.com>
Date: Tue, 30 Apr 2024 17:14:03 -0700
Subject: [PATCH] [mlir][math] Add expansions for acosh, asinh,
atanh ops
---
.../mlir/Dialect/Math/Transforms/Passes.h | 4 +-
.../Math/Transforms/ExpandPatterns.cpp | 127 ++++++++++--------
mlir/test/Dialect/Math/expand-math.mlir | 33 -----
mlir/test/lib/Dialect/Math/TestExpandMath.cpp | 4 +-
.../test-expand-math-approx.mlir | 120 +++++++++++++++--
5 files changed, 189 insertions(+), 99 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index e2c513047c77a5..e86d8c541777e5 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -30,7 +30,9 @@ void populateExpandCtlzPattern(RewritePatternSet &patterns);
void populateExpandTanPattern(RewritePatternSet &patterns);
void populateExpandSinhPattern(RewritePatternSet &patterns);
void populateExpandCoshPattern(RewritePatternSet &patterns);
-void populateExpandTanhPattern(RewritePatternSet &patterns);
+void populateExpandAsinhPattern(RewritePatternSet &patterns);
+void populateExpandAcoshPattern(RewritePatternSet &patterns);
+void populateExpandAtanhPattern(RewritePatternSet &patterns);
void populateExpandFmaFPattern(RewritePatternSet &patterns);
void populateExpandFloorFPattern(RewritePatternSet &patterns);
void populateExpandCeilFPattern(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 42629e149e9ffb..18d3e2bd328918 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -73,14 +73,14 @@ static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter) {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value operand = op.getOperand();
Type opType = operand.getType();
- Value exp = b.create<math::ExpOp>(operand);
- Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
- Value nexp = b.create<arith::DivFOp>(one, exp);
+ Value exp = b.create<math::ExpOp>(operand);
+ Value neg = b.create<arith::NegFOp>(operand);
+ Value nexp = b.create<math::ExpOp>(neg);
Value sub = b.create<arith::SubFOp>(exp, nexp);
- Value two = createFloatConst(op->getLoc(), opType, 2.0, rewriter);
- Value div = b.create<arith::DivFOp>(sub, two);
- rewriter.replaceOp(op, div);
+ Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
+ Value res = b.create<arith::MulFOp>(sub, half);
+ rewriter.replaceOp(op, res);
return success();
}
@@ -89,54 +89,14 @@ static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter) {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value operand = op.getOperand();
Type opType = operand.getType();
- Value exp = b.create<math::ExpOp>(operand);
- Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
- Value nexp = b.create<arith::DivFOp>(one, exp);
+ Value exp = b.create<math::ExpOp>(operand);
+ Value neg = b.create<arith::NegFOp>(operand);
+ Value nexp = b.create<math::ExpOp>(neg);
Value add = b.create<arith::AddFOp>(exp, nexp);
- Value two = createFloatConst(op->getLoc(), opType, 2.0, rewriter);
- Value div = b.create<arith::DivFOp>(add, two);
- rewriter.replaceOp(op, div);
- return success();
-}
-
-/// Expands tanh op into
-/// 1-exp^{-2x} / 1+exp^{-2x}
-/// To avoid overflow we exploit the reflection symmetry `tanh(-x) = -tanh(x)`.
-/// We compute a "signs" value which is -1 if input is negative and +1 if input
-/// is positive. Then multiply the input by this value, guaranteeing that the
-/// result is positive, which also guarantees `exp^{-2x * sign(x)}` is in (0,
-/// 1]. Expand the computation on the input `x * sign(x)`, then multiply the
-/// result by `sign(x)` to retain sign of the real result.
-static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
- auto floatType = op.getOperand().getType();
- Location loc = op.getLoc();
- Value zero = createFloatConst(loc, floatType, 0.0, rewriter);
- Value one = createFloatConst(loc, floatType, 1.0, rewriter);
- Value negTwo = createFloatConst(loc, floatType, -2.0, rewriter);
-
- // Compute sign(x) = cast<float_type>(x < 0) * (-2) + 1
- Value isNegative = rewriter.create<arith::CmpFOp>(
- loc, arith::CmpFPredicate::OLT, op.getOperand(), zero);
- Value isNegativeFloat =
- rewriter.create<arith::UIToFPOp>(loc, floatType, isNegative);
- Value isNegativeTimesNegTwo =
- rewriter.create<arith::MulFOp>(loc, isNegativeFloat, negTwo);
- Value sign = rewriter.create<arith::AddFOp>(loc, isNegativeTimesNegTwo, one);
-
- // Normalize input to positive value: y = sign(x) * x
- Value positiveX = rewriter.create<arith::MulFOp>(loc, sign, op.getOperand());
-
- // Decompose on normalized input
- Value negDoubledX = rewriter.create<arith::MulFOp>(loc, negTwo, positiveX);
- Value exp2x = rewriter.create<math::ExpOp>(loc, negDoubledX);
- Value dividend = rewriter.create<arith::SubFOp>(loc, one, exp2x);
- Value divisor = rewriter.create<arith::AddFOp>(loc, one, exp2x);
- Value positiveRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor);
-
- // Multiply result by sign(x) to retain signs from negative inputs
- rewriter.replaceOpWithNewOp<arith::MulFOp>(op, sign, positiveRes);
-
+ Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
+ Value res = b.create<arith::MulFOp>(add, half);
+ rewriter.replaceOp(op, res);
return success();
}
@@ -152,6 +112,57 @@ static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) {
return success();
}
+// asinh(float x) -> copysign(log(|x| + sqrt(x**2 + 1)), x), using asinh(-x) == -asinh(x)
+static LogicalResult convertAsinhOp(math::AsinhOp op,
+ PatternRewriter &rewriter) {
+ ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+ Value operand = op.getOperand();
+ Value one = createFloatConst(op->getLoc(), operand.getType(), 1.0, rewriter);
+ // Using |x| avoids x + sqrt(x**2 + 1) cancelling to 0 (log -> -inf) for x << 0.
+ Value absX = b.create<math::AbsFOp>(operand);
+ Value fma = b.create<math::FmaOp>(absX, absX, one);
+ Value sqrt = b.create<math::SqrtOp>(fma);
+ Value add = b.create<arith::AddFOp>(absX, sqrt);
+ Value log = b.create<math::LogOp>(add);
+ // Restore the sign of x; copysign also keeps asinh(-0) == -0.
+ rewriter.replaceOpWithNewOp<math::CopySignOp>(op, log, operand);
+ return success();
+
+// acosh(float x) -> log(x + sqrt(x - 1) * sqrt(x + 1)); factored so x**2 never overflows
+static LogicalResult convertAcoshOp(math::AcoshOp op,
+ PatternRewriter &rewriter) {
+ ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+ Value operand = op.getOperand();
+ Value one = createFloatConst(op->getLoc(), operand.getType(), 1.0, rewriter);
+ Value xm1 = b.create<arith::SubFOp>(operand, one); // < 0 for every x < 1
+ Value xp1 = b.create<arith::AddFOp>(operand, one);
+ Value sqrtXm1 = b.create<math::SqrtOp>(xm1); // NaN for x < 1, incl. large negative x
+ Value sqrtXp1 = b.create<math::SqrtOp>(xp1);
+ Value root = b.create<arith::MulFOp>(sqrtXm1, sqrtXp1); // == sqrt(x**2 - 1)
+ Value add = b.create<arith::AddFOp>(operand, root);
+ rewriter.replaceOpWithNewOp<math::LogOp>(op, add);
+ return success();
+}
+
+// atanh(float x) -> log((1 + x) / (1 - x)) / 2, evaluated as log1p(2x / (1 - x)) / 2
+static LogicalResult convertAtanhOp(math::AtanhOp op,
+ PatternRewriter &rewriter) {
+ ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+ Value operand = op.getOperand();
+ Type opType = operand.getType();
+
+ Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
+ // log1p keeps relative precision for small |x|, where (1 + x) / (1 - x) rounds toward 1.
+ Value twoX = b.create<arith::AddFOp>(operand, operand);
+ Value sub = b.create<arith::SubFOp>(one, operand);
+ Value div = b.create<arith::DivFOp>(twoX, sub);
+ Value log1p = b.create<math::Log1pOp>(div);
+ Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
+ Value res = b.create<arith::MulFOp>(log1p, half);
+ rewriter.replaceOp(op, res);
+ return success();
+}
+
static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value operandA = op.getOperand(0);
@@ -580,8 +591,16 @@ void mlir::populateExpandTanPattern(RewritePatternSet &patterns) {
patterns.add(convertTanOp);
}
-void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) {
- patterns.add(convertTanhOp);
+void mlir::populateExpandAsinhPattern(RewritePatternSet &patterns) {
+ patterns.add(convertAsinhOp);
+}
+
+void mlir::populateExpandAcoshPattern(RewritePatternSet &patterns) {
+ patterns.add(convertAcoshOp);
+}
+
+void mlir::populateExpandAtanhPattern(RewritePatternSet &patterns) {
+ patterns.add(convertAtanhOp);
}
void mlir::populateExpandFmaFPattern(RewritePatternSet &patterns) {
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 3d94b55126d097..f57c6e3ddacba6 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -1,38 +1,5 @@
// RUN: mlir-opt %s --split-input-file -test-expand-math | FileCheck %s
-// CHECK-LABEL: func @tanh
-func.func @tanh(%arg: f32) -> f32 {
- %res = math.tanh %arg : f32
- return %res : f32
-}
-// CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK-DAG: %[[ONE:.+]] = arith.constant 1.000000e+00 : f32
-// CHECK-DAG: %[[TWO:.+]] = arith.constant -2.000000e+00 : f32
-// CHECK: %[[VAL0:.+]] = arith.cmpf olt, %arg0, %[[ZERO]] : f32
-// CHECK: %[[VAL1:.+]] = arith.uitofp %[[VAL0]] : i1 to f32
-// CHECK: %[[VAL2:.+]] = arith.mulf %[[VAL1]], %[[TWO]] : f32
-// CHECK: %[[SIGN:.+]] = arith.addf %[[VAL2]], %[[ONE]] : f32
-// CHECK: %[[POSX:.+]] = arith.mulf %[[SIGN]], %arg0 : f32
-// CHECK: %[[NEGDOUBLEDX:.+]] = arith.mulf %[[POSX]], %[[TWO]] : f32
-// CHECK: %[[EXP1:.+]] = math.exp %[[NEGDOUBLEDX]] : f32
-// CHECK: %[[DIVIDEND1:.+]] = arith.subf %[[ONE]], %[[EXP1]] : f32
-// CHECK: %[[DIVISOR1:.+]] = arith.addf %[[EXP1]], %[[ONE]] : f32
-// CHECK: %[[POSRES:.+]] = arith.divf %[[DIVIDEND1]], %[[DIVISOR1]] : f32
-// CHECK: %[[RESULT:.+]] = arith.mulf %[[SIGN]], %[[POSRES]] : f32
-// CHECK: return %[[RESULT]]
-
-// -----
-
-
-// CHECK-LABEL: func @vector_tanh
-func.func @vector_tanh(%arg: vector<4xf32>) -> vector<4xf32> {
- // CHECK-NOT: math.tanh
- %res = math.tanh %arg : vector<4xf32>
- return %res : vector<4xf32>
-}
-
-// -----
-
// CHECK-LABEL: func @tan
func.func @tan(%arg: f32) -> f32 {
%res = math.tan %arg : f32
diff --git a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
index 97600ad1ebe7a3..87cad41e79bdbc 100644
--- a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
+++ b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
@@ -41,7 +41,9 @@ void TestExpandMathPass::runOnOperation() {
populateExpandTanPattern(patterns);
populateExpandSinhPattern(patterns);
populateExpandCoshPattern(patterns);
- populateExpandTanhPattern(patterns);
+ populateExpandAsinhPattern(patterns);
+ populateExpandAcoshPattern(patterns);
+ populateExpandAtanhPattern(patterns);
populateExpandFmaFPattern(patterns);
populateExpandFloorFPattern(patterns);
populateExpandCeilFPattern(patterns);
diff --git a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
index 340ef30bf59c29..7f7ad9cc7efc13 100644
--- a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
@@ -700,21 +700,119 @@ func.func @cosh() {
}
// -------------------------------------------------------------------------- //
-// Tanh.
+// Asinh.
// -------------------------------------------------------------------------- //
-func.func @tanh_8xf32(%a : vector<8xf32>) {
- %r = math.tanh %a : vector<8xf32>
- vector.print %r : vector<8xf32>
+func.func @asinh_f32(%a : f32) {
+ %r = math.asinh %a : f32
+ vector.print %r : f32
+ return
+}
+
+func.func @asinh_3xf32(%a : vector<3xf32>) {
+ %r = math.asinh %a : vector<3xf32>
+ vector.print %r : vector<3xf32>
return
}
-func.func @tanh() {
- // CHECK: -1, -0.761594, -0.291313, 0, 0.291313, 0.761594, 1, 1
- %v3 = arith.constant dense<[0xff800000, -1.0, -0.3, 0.0, 0.3, 1.0, 10.0, 0x7f800000]> : vector<8xf32>
- call @tanh_8xf32(%v3) : (vector<8xf32>) -> ()
+func.func @asinh() {
+ // CHECK: 0
+ %zero = arith.constant 0.0 : f32
+ call @asinh_f32(%zero) : (f32) -> ()
- return
+ // CHECK: 0.881374
+ %cst1 = arith.constant 1.0 : f32
+ call @asinh_f32(%cst1) : (f32) -> ()
+
+ // CHECK: -0.881374
+ %cst2 = arith.constant -1.0 : f32
+ call @asinh_f32(%cst2) : (f32) -> ()
+
+ // CHECK: 1.81845
+ %cst3 = arith.constant 3.0 : f32
+ call @asinh_f32(%cst3) : (f32) -> ()
+
+ // CHECK: 0.247466, 0.790169, 1.44364
+ %vec_x = arith.constant dense<[0.25, 0.875, 2.0]> : vector<3xf32>
+ call @asinh_3xf32(%vec_x) : (vector<3xf32>) -> ()
+
+ return
+}
+
+// -------------------------------------------------------------------------- //
+// Acosh.
+// -------------------------------------------------------------------------- //
+
+func.func @acosh_f32(%a : f32) {
+ %r = math.acosh %a : f32
+ vector.print %r : f32
+ return
+}
+
+func.func @acosh_3xf32(%a : vector<3xf32>) {
+ %r = math.acosh %a : vector<3xf32>
+ vector.print %r : vector<3xf32>
+ return
+}
+
+func.func @acosh() {
+ // CHECK: 0
+ %zero = arith.constant 1.0 : f32
+ call @acosh_f32(%zero) : (f32) -> ()
+
+ // CHECK: 1.31696
+ %cst1 = arith.constant 2.0 : f32
+ call @acosh_f32(%cst1) : (f32) -> ()
+
+ // CHECK: 2.99322
+ %cst2 = arith.constant 10.0 : f32
+ call @acosh_f32(%cst2) : (f32) -> ()
+
+ // CHECK: 0.962424, 1.76275, 2.47789
+ %vec_x = arith.constant dense<[1.5, 3.0, 6.0]> : vector<3xf32>
+ call @acosh_3xf32(%vec_x) : (vector<3xf32>) -> ()
+
+ return
+}
+
+// -------------------------------------------------------------------------- //
+// Atanh.
+// -------------------------------------------------------------------------- //
+
+func.func @atanh_f32(%a : f32) {
+ %r = math.atanh %a : f32
+ vector.print %r : f32
+ return
+}
+
+func.func @atanh_3xf32(%a : vector<3xf32>) {
+ %r = math.atanh %a : vector<3xf32>
+ vector.print %r : vector<3xf32>
+ return
+}
+
+func.func @atanh() {
+ // CHECK: 0
+ %zero = arith.constant 0.0 : f32
+ call @atanh_f32(%zero) : (f32) -> ()
+
+ // CHECK: 0.549306
+ %cst1 = arith.constant 0.5 : f32
+ call @atanh_f32(%cst1) : (f32) -> ()
+
+ // CHECK: -0.549306
+ %cst2 = arith.constant -0.5 : f32
+ call @atanh_f32(%cst2) : (f32) -> ()
+
+ // CHECK: inf
+ %cst3 = arith.constant 1.0 : f32
+ call @atanh_f32(%cst3) : (f32) -> ()
+
+ // CHECK: 0.255413, 0.394229, 2.99448
+ %vec_x = arith.constant dense<[0.25, 0.375, 0.995]> : vector<3xf32>
+ call @atanh_3xf32(%vec_x) : (vector<3xf32>) -> ()
+
+ return
}
func.func @main() {
@@ -724,6 +822,8 @@ func.func @main() {
call @roundeven() : () -> ()
call @sinh() : () -> ()
call @cosh() : () -> ()
- call @tanh() : () -> ()
+ call @asinh() : () -> ()
+ call @acosh() : () -> ()
+ call @atanh() : () -> ()
return
}
More information about the Mlir-commits
mailing list