[Mlir-commits] [mlir] [mlir][math] Add conversions for acosh, asinh, atanh (PR #90718)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon May 6 21:14:24 PDT 2024
https://github.com/jinchen62 updated https://github.com/llvm/llvm-project/pull/90718
>From 92387e0e447ae7b64171bb05d8c209f61eb05921 Mon Sep 17 00:00:00 2001
From: jinchen62 <jinchenye62 at gmail.com>
Date: Tue, 30 Apr 2024 17:14:03 -0700
Subject: [PATCH] [mlir][math] Add log/sqrt-based expansions for acosh, asinh,
atanh ops
---
.../mlir/Dialect/Math/Transforms/Passes.h | 3 +
.../Math/Transforms/ExpandPatterns.cpp | 87 +++++++++++--
mlir/test/lib/Dialect/Math/TestExpandMath.cpp | 3 +
.../test-expand-math-approx.mlir | 119 ++++++++++++++++++
4 files changed, 200 insertions(+), 12 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index e2c513047c77a..24e6d9a8d98e0 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -31,6 +31,9 @@ void populateExpandTanPattern(RewritePatternSet &patterns);
void populateExpandSinhPattern(RewritePatternSet &patterns);
void populateExpandCoshPattern(RewritePatternSet &patterns);
void populateExpandTanhPattern(RewritePatternSet &patterns);
+void populateExpandAsinhPattern(RewritePatternSet &patterns);
+void populateExpandAcoshPattern(RewritePatternSet &patterns);
+void populateExpandAtanhPattern(RewritePatternSet &patterns);
void populateExpandFmaFPattern(RewritePatternSet &patterns);
void populateExpandFloorFPattern(RewritePatternSet &patterns);
void populateExpandCeilFPattern(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 42629e149e9ff..5ccf3b6d72a2c 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -73,14 +73,14 @@ static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter) {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value operand = op.getOperand();
Type opType = operand.getType();
- Value exp = b.create<math::ExpOp>(operand);
- Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
- Value nexp = b.create<arith::DivFOp>(one, exp);
+ Value exp = b.create<math::ExpOp>(operand);
+ Value neg = b.create<arith::NegFOp>(operand);
+ Value nexp = b.create<math::ExpOp>(neg); // exp(-x) evaluated directly instead of 1 / exp(x).
Value sub = b.create<arith::SubFOp>(exp, nexp);
- Value two = createFloatConst(op->getLoc(), opType, 2.0, rewriter);
- Value div = b.create<arith::DivFOp>(sub, two);
- rewriter.replaceOp(op, div);
+ Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter); // sub * 0.5 equals sub / 2 exactly.
+ Value res = b.create<arith::MulFOp>(sub, half); // NOTE(review): exp(x) - exp(-x) cancels for |x| near 0, so small-|x| results lose relative precision.
+ rewriter.replaceOp(op, res);
return success();
}
@@ -89,14 +89,14 @@ static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter) {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value operand = op.getOperand();
Type opType = operand.getType();
- Value exp = b.create<math::ExpOp>(operand);
- Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
- Value nexp = b.create<arith::DivFOp>(one, exp);
+ Value exp = b.create<math::ExpOp>(operand);
+ Value neg = b.create<arith::NegFOp>(operand);
+ Value nexp = b.create<math::ExpOp>(neg); // exp(-x) evaluated directly instead of 1 / exp(x).
Value add = b.create<arith::AddFOp>(exp, nexp);
- Value two = createFloatConst(op->getLoc(), opType, 2.0, rewriter);
- Value div = b.create<arith::DivFOp>(add, two);
- rewriter.replaceOp(op, div);
+ Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter); // add * 0.5 equals add / 2 exactly.
+ Value res = b.create<arith::MulFOp>(add, half); // Both exp terms are >= 0, so this sum does not cancel.
+ rewriter.replaceOp(op, res);
return success();
}
@@ -152,6 +152,57 @@ static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) {
return success();
}
+// asinh(float x) -> copysign(log(|x| + sqrt(x**2 + 1)), x). Evaluating on |x|
+// avoids cancellation in x + sqrt(x**2 + 1) (which reaches log(0)) for x << 0.
+static LogicalResult convertAsinhOp(math::AsinhOp op,
+                                    PatternRewriter &rewriter) {
+  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+  Value operand = op.getOperand();
+  Value one = createFloatConst(op->getLoc(), operand.getType(), 1.0, rewriter);
+  Value absX = b.create<math::AbsFOp>(operand);
+  Value fma = b.create<math::FmaOp>(absX, absX, one);
+  Value sqrt = b.create<math::SqrtOp>(fma);
+  Value log = b.create<math::LogOp>(b.create<arith::AddFOp>(absX, sqrt));
+  Value res = b.create<math::CopySignOp>(log, operand); // asinh(-0) == -0
+  rewriter.replaceOp(op, res);
+  return success();
+}
+
+// acosh(float x) -> log(x + sqrt(x**2 - 1))
+// Defined for x >= 1; smaller inputs produce NaN (via the sqrt or the log).
+static LogicalResult convertAcoshOp(math::AcoshOp op,
+                                    PatternRewriter &rewriter) {
+  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+  Value x = op.getOperand();
+  Value minusOne = createFloatConst(op->getLoc(), x.getType(), -1.0, rewriter);
+  // x**2 - 1 is formed as fma(x, x, -1). NOTE(review): x**2 overflows to +inf
+  // for very large finite x, where acosh itself is still finite.
+  Value radicand = builder.create<math::FmaOp>(x, x, minusOne);
+  Value root = builder.create<math::SqrtOp>(radicand);
+  Value sum = builder.create<arith::AddFOp>(x, root);
+  rewriter.replaceOp(op, builder.create<math::LogOp>(sum).getResult());
+  return success();
+}
+
+// atanh(float x) -> 0.5 * log((1 + x) / (1 - x))
+// x = +/-1 gives +/-inf (division by zero / log(0)); |x| > 1 gives NaN.
+static LogicalResult convertAtanhOp(math::AtanhOp op,
+                                    PatternRewriter &rewriter) {
+  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+  Value x = op.getOperand();
+  Type ty = x.getType();
+  Value one = createFloatConst(op->getLoc(), ty, 1.0, rewriter);
+  Value numerator = builder.create<arith::AddFOp>(x, one); // 1 + x
+  Value negX = builder.create<arith::NegFOp>(x);
+  Value denominator = builder.create<arith::AddFOp>(negX, one); // 1 - x
+  Value ratio = builder.create<arith::DivFOp>(numerator, denominator);
+  Value logRatio = builder.create<math::LogOp>(ratio);
+  Value half = createFloatConst(op->getLoc(), ty, 0.5, rewriter);
+  Value scaled = builder.create<arith::MulFOp>(logRatio, half);
+  rewriter.replaceOp(op, scaled);
+  return success();
+}
+
static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value operandA = op.getOperand(0);
@@ -584,6 +635,18 @@ void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) {
patterns.add(convertTanhOp);
}
+void mlir::populateExpandAsinhPattern(RewritePatternSet &patterns) {
+  patterns.add(&convertAsinhOp); // Rewrites math.asinh via log and sqrt.
+}
+
+void mlir::populateExpandAcoshPattern(RewritePatternSet &patterns) {
+  patterns.add(&convertAcoshOp); // Rewrites math.acosh via log and sqrt.
+}
+
+void mlir::populateExpandAtanhPattern(RewritePatternSet &patterns) {
+  patterns.add(&convertAtanhOp); // Rewrites math.atanh via log of a ratio.
+}
+
void mlir::populateExpandFmaFPattern(RewritePatternSet &patterns) {
patterns.add(convertFmaFOp);
}
diff --git a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
index 97600ad1ebe7a..da48ccb6e5e08 100644
--- a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
+++ b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
@@ -42,6 +42,9 @@ void TestExpandMathPass::runOnOperation() {
populateExpandSinhPattern(patterns);
populateExpandCoshPattern(patterns);
populateExpandTanhPattern(patterns);
+ populateExpandAsinhPattern(patterns);
+ populateExpandAcoshPattern(patterns);
+ populateExpandAtanhPattern(patterns);
populateExpandFmaFPattern(patterns);
populateExpandFloorFPattern(patterns);
populateExpandCeilFPattern(patterns);
diff --git a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
index 340ef30bf59c2..2b72acde6a3bb 100644
--- a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
@@ -717,6 +717,122 @@ func.func @tanh() {
return
}
+// -------------------------------------------------------------------------- //
+// Asinh.
+// -------------------------------------------------------------------------- //
+
+func.func @asinh_f32(%x : f32) {
+  %y = math.asinh %x : f32
+  vector.print %y : f32
+  return
+}
+
+func.func @asinh_3xf32(%x : vector<3xf32>) {
+  %y = math.asinh %x : vector<3xf32>
+  vector.print %y : vector<3xf32>
+  return
+}
+
+func.func @asinh() {
+  // CHECK: 0
+  %c0 = arith.constant 0.0 : f32
+  call @asinh_f32(%c0) : (f32) -> ()
+
+  // CHECK: 0.881374
+  %c1 = arith.constant 1.0 : f32 // ln(1 + sqrt(2))
+  call @asinh_f32(%c1) : (f32) -> ()
+
+  // CHECK: -0.881374
+  %cm1 = arith.constant -1.0 : f32 // asinh is odd
+  call @asinh_f32(%cm1) : (f32) -> ()
+
+  // CHECK: 1.81845
+  %c3 = arith.constant 3.0 : f32 // ln(3 + sqrt(10))
+  call @asinh_f32(%c3) : (f32) -> ()
+
+  // CHECK: 0.247466, 0.790169, 1.44364
+  %xs = arith.constant dense<[0.25, 0.875, 2.0]> : vector<3xf32>
+  call @asinh_3xf32(%xs) : (vector<3xf32>) -> ()
+
+  return
+}
+
+// -------------------------------------------------------------------------- //
+// Acosh.
+// -------------------------------------------------------------------------- //
+
+func.func @acosh_f32(%a : f32) {
+  %r = math.acosh %a : f32
+  vector.print %r : f32
+  return
+}
+
+func.func @acosh_3xf32(%a : vector<3xf32>) {
+  %r = math.acosh %a : vector<3xf32>
+  vector.print %r : vector<3xf32>
+  return
+}
+
+func.func @acosh() {
+  // CHECK: 0
+  %one = arith.constant 1.0 : f32 // acosh(1) = 0, the lower end of the domain
+  call @acosh_f32(%one) : (f32) -> ()
+
+  // CHECK: 1.31696
+  %two = arith.constant 2.0 : f32
+  call @acosh_f32(%two) : (f32) -> ()
+
+  // CHECK: 2.99322
+  %ten = arith.constant 10.0 : f32
+  call @acosh_f32(%ten) : (f32) -> ()
+
+  // CHECK: 0.962424, 1.76275, 2.47789
+  %vec_x = arith.constant dense<[1.5, 3.0, 6.0]> : vector<3xf32>
+  call @acosh_3xf32(%vec_x) : (vector<3xf32>) -> ()
+
+  return
+}
+
+// -------------------------------------------------------------------------- //
+// Atanh.
+// -------------------------------------------------------------------------- //
+
+func.func @atanh_f32(%x : f32) {
+  %y = math.atanh %x : f32
+  vector.print %y : f32
+  return
+}
+
+func.func @atanh_3xf32(%x : vector<3xf32>) {
+  %y = math.atanh %x : vector<3xf32>
+  vector.print %y : vector<3xf32>
+  return
+}
+
+func.func @atanh() {
+  // CHECK: 0
+  %c0 = arith.constant 0.0 : f32
+  call @atanh_f32(%c0) : (f32) -> ()
+
+  // CHECK: 0.549306
+  %half = arith.constant 0.5 : f32 // 0.5 * ln(3)
+  call @atanh_f32(%half) : (f32) -> ()
+
+  // CHECK: -0.549306
+  %neg_half = arith.constant -0.5 : f32 // atanh is odd
+  call @atanh_f32(%neg_half) : (f32) -> ()
+
+  // CHECK: inf
+  %c1 = arith.constant 1.0 : f32 // pole: (1 + x) / (1 - x) = 2 / 0
+  call @atanh_f32(%c1) : (f32) -> ()
+
+  // CHECK: 0.255413, 0.394229, 2.99448
+  %xs = arith.constant dense<[0.25, 0.375, 0.995]> : vector<3xf32>
+  call @atanh_3xf32(%xs) : (vector<3xf32>) -> ()
+
+  return
+}
+
func.func @main() {
call @exp2f() : () -> ()
call @roundf() : () -> ()
@@ -725,5 +841,8 @@ func.func @main() {
call @sinh() : () -> ()
call @cosh() : () -> ()
call @tanh() : () -> ()
+ call @asinh() : () -> ()
+ call @acosh() : () -> ()
+ call @atanh() : () -> ()
return
}
More information about the Mlir-commits
mailing list