[Mlir-commits] [mlir] [mlir][math] Add conversions for acosh, asinh, atanh (PR #90718)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri May 3 20:20:59 PDT 2024
https://github.com/jinchen62 updated https://github.com/llvm/llvm-project/pull/90718
>From 1d3a1a8a14e96baddcf1aa402928bd4baaa02f5e Mon Sep 17 00:00:00 2001
From: jinchen62 <jinchenye62 at gmail.com>
Date: Tue, 30 Apr 2024 17:14:03 -0700
Subject: [PATCH] [mlir][math] Add expansion patterns for acosh, asinh,
atanh ops
---
.../mlir/Dialect/Math/Transforms/Passes.h | 4 +
.../Math/Transforms/ExpandPatterns.cpp | 106 ++++++++++++++--
mlir/test/lib/Dialect/Math/TestExpandMath.cpp | 14 +--
.../test-expand-math-approx.mlir | 119 ++++++++++++++++++
4 files changed, 218 insertions(+), 25 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index e2c513047c77a5..e9719e7949c28a 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -31,6 +31,9 @@ void populateExpandTanPattern(RewritePatternSet &patterns);
void populateExpandSinhPattern(RewritePatternSet &patterns);
void populateExpandCoshPattern(RewritePatternSet &patterns);
void populateExpandTanhPattern(RewritePatternSet &patterns);
+/// Expand math.asinh, math.acosh and math.atanh into sequences of math.log and
+/// other elementary math/arith ops.
+void populateExpandAsinhPattern(RewritePatternSet &patterns);
+void populateExpandAcoshPattern(RewritePatternSet &patterns);
+void populateExpandAtanhPattern(RewritePatternSet &patterns);
void populateExpandFmaFPattern(RewritePatternSet &patterns);
void populateExpandFloorFPattern(RewritePatternSet &patterns);
void populateExpandCeilFPattern(RewritePatternSet &patterns);
@@ -39,6 +42,7 @@ void populateExpandPowFPattern(RewritePatternSet &patterns);
void populateExpandFPowIPattern(RewritePatternSet &patterns);
void populateExpandRoundFPattern(RewritePatternSet &patterns);
void populateExpandRoundEvenPattern(RewritePatternSet &patterns);
+/// Convenience wrapper: adds the math expansion patterns (ctlz, exp2f, tan, the
+/// hyperbolic functions and their inverses, fma, floor/ceil/round/roundeven,
+/// powf, fpowi) in a single call.
+void populateMathExpandPatterns(RewritePatternSet &patterns);
void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
struct MathPolynomialApproximationOptions {
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 42629e149e9ffb..4fdb10a021b09f 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -73,14 +73,14 @@ static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter) {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value operand = op.getOperand();
Type opType = operand.getType();
- Value exp = b.create<math::ExpOp>(operand);
- Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
- Value nexp = b.create<arith::DivFOp>(one, exp);
+  // exp(-x) is computed with its own exp of the negated operand (rather than
+  // as 1 / exp(x)), and the final halving is a multiply by 0.5.
+ Value exp = b.create<math::ExpOp>(operand);
+ Value neg = b.create<arith::NegFOp>(operand);
+ Value nexp = b.create<math::ExpOp>(neg);
Value sub = b.create<arith::SubFOp>(exp, nexp);
- Value two = createFloatConst(op->getLoc(), opType, 2.0, rewriter);
- Value div = b.create<arith::DivFOp>(sub, two);
- rewriter.replaceOp(op, div);
+ Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
+ Value res = b.create<arith::MulFOp>(sub, half);
+ rewriter.replaceOp(op, res);
return success();
}
@@ -89,14 +89,14 @@ static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter) {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value operand = op.getOperand();
Type opType = operand.getType();
- Value exp = b.create<math::ExpOp>(operand);
- Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
- Value nexp = b.create<arith::DivFOp>(one, exp);
+  // exp(-x) is computed with its own exp of the negated operand (rather than
+  // as 1 / exp(x)), and the final halving is a multiply by 0.5.
+ Value exp = b.create<math::ExpOp>(operand);
+ Value neg = b.create<arith::NegFOp>(operand);
+ Value nexp = b.create<math::ExpOp>(neg);
Value add = b.create<arith::AddFOp>(exp, nexp);
- Value two = createFloatConst(op->getLoc(), opType, 2.0, rewriter);
- Value div = b.create<arith::DivFOp>(add, two);
- rewriter.replaceOp(op, div);
+ Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
+ Value res = b.create<arith::MulFOp>(add, half);
+ rewriter.replaceOp(op, res);
return success();
}
@@ -152,6 +152,57 @@ static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) {
return success();
}
+// asinh(float x) -> copysign(log(|x| + sqrt(x**2 + 1)), x)
+//
+// asinh is an odd function, so the magnitude is computed from |x| and the sign
+// of x is restored at the end. Evaluating log(x + sqrt(x**2 + 1)) directly is
+// inaccurate for negative x: x and sqrt(x**2 + 1) nearly cancel, and once
+// x**2 + 1 rounds to x**2 the sum is exactly 0 and the result becomes -inf
+// (in f32 this already happens at x = -1e4). copysign also maps -0.0 to -0.0.
+// NOTE(review): x**2 still overflows to +inf for |x| beyond roughly the square
+// root of the largest finite value, so the result saturates to +/-inf there.
+static LogicalResult convertAsinhOp(math::AsinhOp op,
+                                    PatternRewriter &rewriter) {
+  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+  Value operand = op.getOperand();
+  Type opType = operand.getType();
+
+  Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
+  Value absX = b.create<math::AbsFOp>(operand);
+  // |x| * |x| + 1 as a single math.fma.
+  Value radicand = b.create<math::FmaOp>(absX, absX, one);
+  Value root = b.create<math::SqrtOp>(radicand);
+  Value sum = b.create<arith::AddFOp>(absX, root);
+  Value logAbs = b.create<math::LogOp>(sum);
+  // Reapply the sign of the input: asinh(-x) == -asinh(x).
+  Value res = b.create<math::CopySignOp>(logAbs, operand);
+  rewriter.replaceOp(op, res);
+  return success();
+}
+
+// acosh(float x) -> log(x + sqrt(x - 1) * sqrt(x + 1))
+//
+// Equal to log(x + sqrt(x**2 - 1)) for x >= 1, but better behaved elsewhere:
+//  * For every x < 1, x - 1 is negative, so sqrt (and hence the result) is NaN.
+//    With x**2 - 1 the radicand is non-negative for x <= -1; for large
+//    negative x it rounds to x**2, x + sqrt(x**2) cancels to exactly 0 and the
+//    result is -inf instead of NaN (in f32 already at x = -1e4).
+//  * x is never squared, so large finite x no longer overflows to +inf; only
+//    the final addition can overflow, for x above about half the largest
+//    finite value.
+//  * x - 1 is computed exactly for x in [1, 2], keeping the radicand accurate
+//    near x = 1.
+static LogicalResult convertAcoshOp(math::AcoshOp op,
+                                    PatternRewriter &rewriter) {
+  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+  Value operand = op.getOperand();
+  Type opType = operand.getType();
+
+  Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
+  Value xMinusOne = b.create<arith::SubFOp>(operand, one);
+  Value xPlusOne = b.create<arith::AddFOp>(operand, one);
+  // sqrt(x - 1) * sqrt(x + 1) == sqrt(x**2 - 1) on the domain x >= 1.
+  Value sqrtMinus = b.create<math::SqrtOp>(xMinusOne);
+  Value sqrtPlus = b.create<math::SqrtOp>(xPlusOne);
+  Value root = b.create<arith::MulFOp>(sqrtMinus, sqrtPlus);
+  Value sum = b.create<arith::AddFOp>(operand, root);
+  Value res = b.create<math::LogOp>(sum);
+  rewriter.replaceOp(op, res);
+  return success();
+}
+
+// atanh(float x) -> log((1 + x) / (1 - x)) / 2
+//
+// Emitted as log((x + 1) / (-x + 1)) * 0.5. At the ends of the domain the
+// ratio becomes 2 / 0 (x = 1) or 0 / 2 (x = -1), giving +inf and -inf.
+// NOTE(review): for very small |x| the ratio rounds to 1 and the result loses
+// its significant digits; a log1p-based form would be more accurate there.
+static LogicalResult convertAtanhOp(math::AtanhOp op,
+                                    PatternRewriter &rewriter) {
+  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+  Value x = op.getOperand();
+  Type elemTy = x.getType();
+
+  Value unit = createFloatConst(op->getLoc(), elemTy, 1.0, rewriter);
+  // Numerator 1 + x.
+  Value numer = builder.create<arith::AddFOp>(x, unit);
+  // Denominator 1 - x, formed as (-x) + 1.
+  Value minusX = builder.create<arith::NegFOp>(x);
+  Value denom = builder.create<arith::AddFOp>(minusX, unit);
+  Value ratio = builder.create<arith::DivFOp>(numer, denom);
+  Value logRatio = builder.create<math::LogOp>(ratio);
+  Value pointFive = createFloatConst(op->getLoc(), elemTy, 0.5, rewriter);
+  Value halved = builder.create<arith::MulFOp>(logRatio, pointFive);
+  rewriter.replaceOp(op, halved);
+  return success();
+}
+
static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value operandA = op.getOperand(0);
@@ -584,6 +635,18 @@ void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) {
patterns.add(convertTanhOp);
}
+/// Adds the rewrite that expands math.asinh (see convertAsinhOp).
+void mlir::populateExpandAsinhPattern(RewritePatternSet &patterns) {
+  patterns.add(&convertAsinhOp);
+}
+
+/// Adds the rewrite that expands math.acosh (see convertAcoshOp).
+void mlir::populateExpandAcoshPattern(RewritePatternSet &patterns) {
+  patterns.add(&convertAcoshOp);
+}
+
+/// Adds the rewrite that expands math.atanh (see convertAtanhOp).
+void mlir::populateExpandAtanhPattern(RewritePatternSet &patterns) {
+  patterns.add(&convertAtanhOp);
+}
+
void mlir::populateExpandFmaFPattern(RewritePatternSet &patterns) {
patterns.add(convertFmaFOp);
}
@@ -615,3 +678,22 @@ void mlir::populateExpandFloorFPattern(RewritePatternSet &patterns) {
void mlir::populateExpandRoundEvenPattern(RewritePatternSet &patterns) {
patterns.add(convertRoundEvenOp);
}
+
+/// Adds all math expansion patterns in one call.
+///
+/// Delegates to the individual populateExpand*Pattern entry points instead of
+/// re-listing the convert* functions, so this aggregate always registers
+/// exactly what those entry points register and cannot drift from them.
+void mlir::populateMathExpandPatterns(RewritePatternSet &patterns) {
+  populateExpandCtlzPattern(patterns);
+  populateExpandSinhPattern(patterns);
+  populateExpandCoshPattern(patterns);
+  populateExpandTanPattern(patterns);
+  populateExpandTanhPattern(patterns);
+  populateExpandAsinhPattern(patterns);
+  populateExpandAcoshPattern(patterns);
+  populateExpandAtanhPattern(patterns);
+  populateExpandFmaFPattern(patterns);
+  populateExpandCeilFPattern(patterns);
+  populateExpandExp2FPattern(patterns);
+  populateExpandPowFPattern(patterns);
+  populateExpandFPowIPattern(patterns);
+  populateExpandRoundFPattern(patterns);
+  populateExpandFloorFPattern(patterns);
+  populateExpandRoundEvenPattern(patterns);
+}
diff --git a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
index 97600ad1ebe7a3..c90794d5c99a05 100644
--- a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
+++ b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
@@ -36,19 +36,7 @@ struct TestExpandMathPass
void TestExpandMathPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
- populateExpandCtlzPattern(patterns);
- populateExpandExp2FPattern(patterns);
- populateExpandTanPattern(patterns);
- populateExpandSinhPattern(patterns);
- populateExpandCoshPattern(patterns);
- populateExpandTanhPattern(patterns);
- populateExpandFmaFPattern(patterns);
- populateExpandFloorFPattern(patterns);
- populateExpandCeilFPattern(patterns);
- populateExpandPowFPattern(patterns);
- populateExpandFPowIPattern(patterns);
- populateExpandRoundFPattern(patterns);
- populateExpandRoundEvenPattern(patterns);
+  // One aggregate call registers every expansion pattern (including the
+  // asinh/acosh/atanh ones) instead of listing each populate function here.
+ populateMathExpandPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
diff --git a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
index 340ef30bf59c29..2b72acde6a3bb7 100644
--- a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
@@ -717,6 +717,122 @@ func.func @tanh() {
return
}
+// -------------------------------------------------------------------------- //
+// Asinh.
+// -------------------------------------------------------------------------- //
+
+// Prints asinh of a single f32 value.
+func.func @asinh_f32(%x : f32) {
+  %y = math.asinh %x : f32
+  vector.print %y : f32
+  return
+}
+
+// Prints asinh of each lane of a vector<3xf32>.
+func.func @asinh_3xf32(%x : vector<3xf32>) {
+  %y = math.asinh %x : vector<3xf32>
+  vector.print %y : vector<3xf32>
+  return
+}
+
+func.func @asinh() {
+  // CHECK: 0
+  %in_zero = arith.constant 0.0 : f32
+  call @asinh_f32(%in_zero) : (f32) -> ()
+
+  // CHECK: 0.881374
+  %in_one = arith.constant 1.0 : f32
+  call @asinh_f32(%in_one) : (f32) -> ()
+
+  // asinh is odd, so this is the negation of the previous result.
+  // CHECK: -0.881374
+  %in_neg_one = arith.constant -1.0 : f32
+  call @asinh_f32(%in_neg_one) : (f32) -> ()
+
+  // CHECK: 1.81845
+  %in_three = arith.constant 3.0 : f32
+  call @asinh_f32(%in_three) : (f32) -> ()
+
+  // CHECK: 0.247466, 0.790169, 1.44364
+  %in_vec = arith.constant dense<[0.25, 0.875, 2.0]> : vector<3xf32>
+  call @asinh_3xf32(%in_vec) : (vector<3xf32>) -> ()
+
+  return
+}
+
+// -------------------------------------------------------------------------- //
+// Acosh.
+// -------------------------------------------------------------------------- //
+
+func.func @acosh_f32(%a : f32) {
+  %r = math.acosh %a : f32
+  vector.print %r : f32
+  return
+}
+
+func.func @acosh_3xf32(%a : vector<3xf32>) {
+  %r = math.acosh %a : vector<3xf32>
+  vector.print %r : vector<3xf32>
+  return
+}
+
+func.func @acosh() {
+  // acosh(1) is the lower end of the domain.
+  // CHECK: 0
+  %one = arith.constant 1.0 : f32
+  call @acosh_f32(%one) : (f32) -> ()
+
+  // CHECK: 1.31696
+  %cst1 = arith.constant 2.0 : f32
+  call @acosh_f32(%cst1) : (f32) -> ()
+
+  // CHECK: 2.99322
+  %cst2 = arith.constant 10.0 : f32
+  call @acosh_f32(%cst2) : (f32) -> ()
+
+  // Inputs below 1 are outside the domain and produce NaN.
+  // CHECK: nan
+  %cst3 = arith.constant 0.5 : f32
+  call @acosh_f32(%cst3) : (f32) -> ()
+
+  // CHECK: 0.962424, 1.76275, 2.47789
+  %vec_x = arith.constant dense<[1.5, 3.0, 6.0]> : vector<3xf32>
+  call @acosh_3xf32(%vec_x) : (vector<3xf32>) -> ()
+
+  return
+}
+
+// -------------------------------------------------------------------------- //
+// Atanh.
+// -------------------------------------------------------------------------- //
+
+// Prints atanh of a single f32 value.
+func.func @atanh_f32(%x : f32) {
+  %y = math.atanh %x : f32
+  vector.print %y : f32
+  return
+}
+
+// Prints atanh of each lane of a vector<3xf32>.
+func.func @atanh_3xf32(%x : vector<3xf32>) {
+  %y = math.atanh %x : vector<3xf32>
+  vector.print %y : vector<3xf32>
+  return
+}
+
+func.func @atanh() {
+  // CHECK: 0
+  %in_zero = arith.constant 0.0 : f32
+  call @atanh_f32(%in_zero) : (f32) -> ()
+
+  // CHECK: 0.549306
+  %in_half = arith.constant 0.5 : f32
+  call @atanh_f32(%in_half) : (f32) -> ()
+
+  // atanh is odd, so this is the negation of the previous result.
+  // CHECK: -0.549306
+  %in_neg_half = arith.constant -0.5 : f32
+  call @atanh_f32(%in_neg_half) : (f32) -> ()
+
+  // x = 1 is the upper end of the domain; the result diverges to +inf.
+  // CHECK: inf
+  %in_one = arith.constant 1.0 : f32
+  call @atanh_f32(%in_one) : (f32) -> ()
+
+  // CHECK: 0.255413, 0.394229, 2.99448
+  %in_vec = arith.constant dense<[0.25, 0.375, 0.995]> : vector<3xf32>
+  call @atanh_3xf32(%in_vec) : (vector<3xf32>) -> ()
+
+  return
+}
+
func.func @main() {
call @exp2f() : () -> ()
call @roundf() : () -> ()
@@ -725,5 +841,8 @@ func.func @main() {
call @sinh() : () -> ()
call @cosh() : () -> ()
call @tanh() : () -> ()
+  // Inverse hyperbolic functions.
+ call @asinh() : () -> ()
+ call @acosh() : () -> ()
+ call @atanh() : () -> ()
return
}
More information about the Mlir-commits
mailing list