[Mlir-commits] [mlir] 4da9651 - [mlir][math] Expand math.exp2 to use math.exp.
Robert Suderman
llvmlistbot at llvm.org
Thu Apr 13 09:08:19 PDT 2023
Author: Balaji V. Iyer
Date: 2023-04-13T16:06:04Z
New Revision: 4da96515ea8552cdf14c6aa6310d2a91fbe74641
URL: https://github.com/llvm/llvm-project/commit/4da96515ea8552cdf14c6aa6310d2a91fbe74641
DIFF: https://github.com/llvm/llvm-project/commit/4da96515ea8552cdf14c6aa6310d2a91fbe74641.diff
LOG: [mlir][math] Expand math.exp2 to use math.exp.
Exp2 functions are pushed directly to libm. This is problematic for
situations where libm is not available. This patch expands the exp2
function into exp applied to the input multiplied by ln(2) (the natural log of 2).
Reviewed By: rsuderman
Differential Revision: https://reviews.llvm.org/D148064
Added:
mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
Modified:
mlir/include/mlir/Dialect/Math/Transforms/Passes.h
mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
mlir/test/Dialect/Math/expand-math.mlir
mlir/test/lib/Dialect/Math/TestExpandMath.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index 1b32de2b99683..3ac18c3a24184 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -19,6 +19,8 @@ void populateExpandTanhPattern(RewritePatternSet &patterns);
void populateExpandFmaFPattern(RewritePatternSet &patterns);
void populateExpandFloorFPattern(RewritePatternSet &patterns);
void populateExpandCeilFPattern(RewritePatternSet &patterns);
+/// Adds a pattern rewriting math.exp2(x) as math.exp(x * ln(2)).
+void populateExpandExp2FPattern(RewritePatternSet &patterns);
void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
struct MathPolynomialApproximationOptions {
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index b70ac4e006eac..e9447dcbd5394 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -158,6 +158,22 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
return success();
}
+// Expands exp2(x) -> exp(x * ln(2)) for scalar or shaped float operands.
+// Proof: if 2^x = y, then ln(y) = x * ln(2), hence y = e^(x * ln(2)).
+// The rounded product x * ln(2) can make results inexact in the last bits,
+// even for integer x where a native exp2 would be exact.
+static LogicalResult convertExp2fOp(math::Exp2Op op,
+                                    PatternRewriter &rewriter) {
+  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+  Value operand = op.getOperand();
+  Type opType = operand.getType();
+  Value ln2 = createFloatConst(op->getLoc(), opType, llvm::numbers::ln2, b);
+  Value mult = b.create<arith::MulFOp>(opType, operand, ln2);
+  Value exp = b.create<math::ExpOp>(mult); // `b` supplies op's location.
+  rewriter.replaceOp(op, exp);
+  return success();
+}
+
// Converts math.ctlz to scf and arith operations. This is done
// by performing a binary search on the bits.
static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
@@ -222,6 +238,10 @@ void mlir::populateExpandCeilFPattern(RewritePatternSet &patterns) {
patterns.add(convertCeilOp);
}
+void mlir::populateExpandExp2FPattern(RewritePatternSet &patterns) {
+  patterns.add(convertExp2fOp);
+}
+
void mlir::populateExpandFloorFPattern(RewritePatternSet &patterns) {
patterns.add(convertFloorOp);
}
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 8ab644931b669..50986969fca57 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -165,3 +165,27 @@ func.func @ceilf_func(%a: f64) -> f64 {
%ret = math.ceil %a : f64
return %ret : f64
}
+
+// -----
+
+// CHECK-LABEL: func @exp2f_func
+// CHECK-SAME: ([[X:%.+]]: f64) -> f64
+func.func @exp2f_func(%x: f64) -> f64 {
+  // CHECK-DAG: [[LN2:%.+]] = arith.constant 0.69314718055994529
+  // CHECK: [[SCALED:%.+]] = arith.mulf [[X]], [[LN2]]
+  // CHECK: [[RES:%.+]] = math.exp [[SCALED]]
+  // CHECK: return [[RES]]
+  %res = math.exp2 %x : f64
+  return %res : f64
+}
+
+// CHECK-LABEL: func @exp2f_func_tensor
+// CHECK-SAME: ([[X:%.+]]: tensor<1xf32>) -> tensor<1xf32>
+func.func @exp2f_func_tensor(%x: tensor<1xf32>) -> tensor<1xf32> {
+  // CHECK-DAG: [[LN2:%.+]] = arith.constant dense<0.693147182>
+  // CHECK: [[SCALED:%.+]] = arith.mulf [[X]], [[LN2]]
+  // CHECK: [[RES:%.+]] = math.exp [[SCALED]]
+  // CHECK: return [[RES]]
+  %res = math.exp2 %x : tensor<1xf32>
+  return %res : tensor<1xf32>
+}
diff --git a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
index c670617f8446f..29eff9959bebc 100644
--- a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
+++ b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
@@ -37,6 +37,7 @@ struct TestExpandMathPass
void TestExpandMathPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
populateExpandCtlzPattern(patterns);
+ populateExpandExp2FPattern(patterns);
populateExpandTanPattern(patterns);
populateExpandTanhPattern(patterns);
populateExpandFmaFPattern(patterns);
diff --git a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
new file mode 100644
index 0000000000000..3fb3b2b719134
--- /dev/null
+++ b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
@@ -0,0 +1,61 @@
+// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(test-expand-math,convert-arith-to-llvm),convert-vector-to-llvm,func.func(convert-math-to-llvm),convert-func-to-llvm,reconcile-unrealized-casts)" \
+// RUN: | mlir-cpu-runner \
+// RUN: -e main -entry-point-result=void -O0 \
+// RUN: -shared-libs=%mlir_c_runner_utils \
+// RUN: -shared-libs=%mlir_runner_utils \
+// RUN: | FileCheck %s
+
+// -------------------------------------------------------------------------- //
+// exp2f.
+// -------------------------------------------------------------------------- //
+func.func @func_exp2f(%a : f64) {
+  %r = math.exp2 %a : f64
+  vector.print %r : f64
+  return
+}
+
+func.func @exp2f() {
+  // CHECK: 2
+  %a = arith.constant 1.0 : f64
+  call @func_exp2f(%a) : (f64) -> ()
+
+  // CHECK: 4
+  %b = arith.constant 2.0 : f64
+  call @func_exp2f(%b) : (f64) -> ()
+
+  // CHECK: 5.65685
+  %c = arith.constant 2.5 : f64
+  call @func_exp2f(%c) : (f64) -> ()
+
+  // CHECK: 0.29730
+  %d = arith.constant -1.75 : f64
+  call @func_exp2f(%d) : (f64) -> ()
+
+  // CHECK: 1.09581
+  %e = arith.constant 0.132 : f64
+  call @func_exp2f(%e) : (f64) -> ()
+
+  // CHECK: inf
+  %f1 = arith.constant 0.00 : f64
+  %f2 = arith.constant 1.00 : f64
+  %f = arith.divf %f2, %f1 : f64
+  call @func_exp2f(%f) : (f64) -> ()
+
+  // CHECK: inf
+  %g = arith.constant 5038939.0 : f64
+  call @func_exp2f(%g) : (f64) -> ()
+
+  // CHECK: 0
+  %neg_inf = arith.constant 0xfff0000000000000 : f64
+  call @func_exp2f(%neg_inf) : (f64) -> ()
+
+  // CHECK: nan
+  %nan = arith.constant 0x7ff8000000000000 : f64
+  call @func_exp2f(%nan) : (f64) -> ()
+  return
+}
+
+func.func @main() {
+  call @exp2f() : () -> ()
+  return
+}
More information about the Mlir-commits
mailing list