[Mlir-commits] [mlir] 4da9651 - [mlir][math] Expand math.exp2 to use math.exp.

Robert Suderman llvmlistbot at llvm.org
Thu Apr 13 09:08:19 PDT 2023


Author: Balaji V. Iyer
Date: 2023-04-13T16:06:04Z
New Revision: 4da96515ea8552cdf14c6aa6310d2a91fbe74641

URL: https://github.com/llvm/llvm-project/commit/4da96515ea8552cdf14c6aa6310d2a91fbe74641
DIFF: https://github.com/llvm/llvm-project/commit/4da96515ea8552cdf14c6aa6310d2a91fbe74641.diff

LOG: [mlir][math] Expand math.exp2 to use math.exp.

Exp2 operations are passed directly through to libm. This is problematic in
situations where libm is not available. This patch expands math.exp2 into
math.exp applied to the input multiplied by ln(2) (the natural log of 2).

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D148064

Added: 
    mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir

Modified: 
    mlir/include/mlir/Dialect/Math/Transforms/Passes.h
    mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
    mlir/test/Dialect/Math/expand-math.mlir
    mlir/test/lib/Dialect/Math/TestExpandMath.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index 1b32de2b99683..3ac18c3a24184 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -19,6 +19,7 @@ void populateExpandTanhPattern(RewritePatternSet &patterns);
 void populateExpandFmaFPattern(RewritePatternSet &patterns);
 void populateExpandFloorFPattern(RewritePatternSet &patterns);
 void populateExpandCeilFPattern(RewritePatternSet &patterns);
+void populateExpandExp2FPattern(RewritePatternSet &patterns);
 void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
 
 struct MathPolynomialApproximationOptions {

diff  --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index b70ac4e006eac..e9447dcbd5394 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -158,6 +158,22 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
   return success();
 }
 
+// Expands math.exp2 into math.exp: exp2(x) -> exp(x * ln(2)).
+//   Proof: let y = 2^x. Taking natural logs, ln(y) = x * ln(2),
+//   hence y = e^(x * ln(2)). This avoids needing libm's exp2.
+//   The ln(2) constant is built with the operand's type (opType).
+static LogicalResult convertExp2fOp(math::Exp2Op op,
+                                    PatternRewriter &rewriter) {
+  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+  Value operand = op.getOperand();
+  Type opType = operand.getType();
+  Value ln2 = createFloatConst(op->getLoc(), opType, llvm::numbers::ln2, b);
+  Value mult = b.create<arith::MulFOp>(opType, operand, ln2);
+  Value exp = b.create<math::ExpOp>(mult); // `b` supplies the location.
+  rewriter.replaceOp(op, exp);
+  return success();
+}
+
 // Converts math.ctlz to scf and arith operations. This is done
 // by performing a binary search on the bits.
 static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
@@ -222,6 +238,10 @@ void mlir::populateExpandCeilFPattern(RewritePatternSet &patterns) {
   patterns.add(convertCeilOp);
 }
 
+void mlir::populateExpandExp2FPattern(RewritePatternSet &patterns) {
+  patterns.add(&convertExp2fOp); // math.exp2(x) -> math.exp(x * ln(2))
+}
+
 void mlir::populateExpandFloorFPattern(RewritePatternSet &patterns) {
   patterns.add(convertFloorOp);
 }

diff  --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 8ab644931b669..50986969fca57 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -165,3 +165,27 @@ func.func @ceilf_func(%a: f64) -> f64 {
   %ret = math.ceil %a : f64
   return %ret : f64
 }
+
+// -----
+
+// CHECK-LABEL:     func @exp2f_func
+// CHECK-SAME:      ([[ARG0:%.+]]: f64) -> f64
+func.func @exp2f_func(%x: f64) -> f64 {
+  // CHECK-DAG:     [[CST:%.+]]  = arith.constant 0.69314718055994529
+  // CHECK:         [[MULF:%.+]] = arith.mulf [[ARG0]], [[CST]]
+  // CHECK:         [[EXP:%.+]]  = math.exp [[MULF]]
+  // CHECK:         return [[EXP]]
+  %pow2 = math.exp2 %x : f64
+  return %pow2 : f64
+}
+
+// CHECK-LABEL:     func @exp2f_func_tensor
+// CHECK-SAME:      ([[ARG0:%.+]]: tensor<1xf32>) -> tensor<1xf32>
+func.func @exp2f_func_tensor(%x: tensor<1xf32>) -> tensor<1xf32> {
+  // CHECK-DAG:     [[CST:%.+]]  = arith.constant dense<0.693147182>
+  // CHECK:         [[MULF:%.+]] = arith.mulf [[ARG0]], [[CST]]
+  // CHECK:         [[EXP:%.+]]  = math.exp [[MULF]]
+  // CHECK:         return [[EXP]]
+  %pow2 = math.exp2 %x : tensor<1xf32>
+  return %pow2 : tensor<1xf32>
+}

diff  --git a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
index c670617f8446f..29eff9959bebc 100644
--- a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
+++ b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
@@ -37,6 +37,7 @@ struct TestExpandMathPass
 void TestExpandMathPass::runOnOperation() {
   RewritePatternSet patterns(&getContext());
   populateExpandCtlzPattern(patterns);
+  populateExpandExp2FPattern(patterns);
   populateExpandTanPattern(patterns);
   populateExpandTanhPattern(patterns);
   populateExpandFmaFPattern(patterns);

diff  --git a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
new file mode 100644
index 0000000000000..3fb3b2b719134
--- /dev/null
+++ b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
@@ -0,0 +1,61 @@
+// RUN:   mlir-opt %s -pass-pipeline="builtin.module(func.func(test-expand-math,convert-arith-to-llvm),convert-vector-to-llvm,func.func(convert-math-to-llvm),convert-func-to-llvm,reconcile-unrealized-casts)" \
+// RUN: | mlir-cpu-runner                                                      \
+// RUN:     -e main -entry-point-result=void -O0                               \
+// RUN:     -shared-libs=%mlir_c_runner_utils  \
+// RUN:     -shared-libs=%mlir_runner_utils    \
+// RUN: | FileCheck %s
+
+// -------------------------------------------------------------------------- //
+// exp2f.
+// -------------------------------------------------------------------------- //
+func.func @func_exp2f(%input : f64) {
+  %result = math.exp2 %input : f64
+  vector.print %result : f64
+  return
+}
+
+func.func @exp2f() {
+  // CHECK: 2
+  %a = arith.constant 1.0 : f64
+  call @func_exp2f(%a) : (f64) -> ()
+
+  // CHECK: 4
+  %b = arith.constant 2.0 : f64
+  call @func_exp2f(%b) : (f64) -> ()
+
+  // CHECK: 5.65685
+  %c = arith.constant 2.5 : f64
+  call @func_exp2f(%c) : (f64) -> ()
+
+  // CHECK: 0.29730
+  %d = arith.constant -1.75 : f64
+  call @func_exp2f(%d) : (f64) -> ()
+
+  // CHECK: 1.09581
+  %e = arith.constant 0.132 : f64
+  call @func_exp2f(%e) : (f64) -> ()
+
+  // CHECK: inf
+  %f1 = arith.constant 0.00 : f64
+  %f2 = arith.constant 1.00 : f64
+  %f = arith.divf %f2, %f1 : f64 // 1.0 / 0.0 = +inf
+  call @func_exp2f(%f) : (f64) -> ()
+
+  // CHECK: inf
+  %g = arith.constant 5038939.0 : f64 // 2^x overflows f64
+  call @func_exp2f(%g) : (f64) -> ()
+
+  // CHECK: 0
+  %neg_inf = arith.constant 0xfff0000000000000 : f64 // f64 bit pattern of -inf
+  call @func_exp2f(%neg_inf) : (f64) -> ()
+
+  // CHECK: nan
+  %nan = arith.constant 0x7ff8000000000000 : f64 // f64 bit pattern of quiet NaN
+  call @func_exp2f(%nan) : (f64) -> ()
+  return
+}
+
+func.func @main() {
+  func.call @exp2f() : () -> ()
+  return
+}


        


More information about the Mlir-commits mailing list