[Mlir-commits] [mlir] 2d4e856 - [mlir][math] Expand math.powf to exp, log and multiply

Robert Suderman llvmlistbot at llvm.org
Fri Apr 14 07:07:49 PDT 2023


Author: Balaji V. Iyer
Date: 2023-04-14T14:04:19Z
New Revision: 2d4e8567097eae48bff6ed2b0b1d7056ede15456

URL: https://github.com/llvm/llvm-project/commit/2d4e8567097eae48bff6ed2b0b1d7056ede15456
DIFF: https://github.com/llvm/llvm-project/commit/2d4e8567097eae48bff6ed2b0b1d7056ede15456.diff

LOG: [mlir][math] Expand math.powf to exp, log and multiply

Calls to powf are lowered directly to libm, which is a problem in
environments where libm is not available. This patch expands
powf(a, b) into exp(b * log(a)): the exponent is multiplied by the
natural log of the base, and the product is then exponentiated.

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D148164

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Math/Transforms/Passes.h
    mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
    mlir/test/Dialect/Math/expand-math.mlir
    mlir/test/lib/Dialect/Math/TestExpandMath.cpp
    mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index 6cd5b0a409223..245a11747d5c8 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -20,6 +20,7 @@ void populateExpandFmaFPattern(RewritePatternSet &patterns);
 void populateExpandFloorFPattern(RewritePatternSet &patterns);
 void populateExpandCeilFPattern(RewritePatternSet &patterns);
 void populateExpandExp2FPattern(RewritePatternSet &patterns);
+void populateExpandPowFPattern(RewritePatternSet &patterns);
 void populateExpandRoundFPattern(RewritePatternSet &patterns);
 void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
 

diff  --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index bc35263e12b2d..a37340d312f51 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -157,6 +157,19 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
   rewriter.replaceOp(op, ret);
   return success();
 }
+// Rewrites powf(a, b) = a^b as exp(b * ln(a)); always succeeds.
+static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
+  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+  Value operandA = op.getOperand(0);
+  Value operandB = op.getOperand(1);
+  Type opType = operandA.getType();
+  // ln(a) is NaN for a < 0, so negative bases yield NaN, unlike C pow().
+  Value logA = b.create<math::LogOp>(opType, operandA);
+  Value mult = b.create<arith::MulFOp>(opType, logA, operandB);
+  Value expResult = b.create<math::ExpOp>(opType, mult);
+  rewriter.replaceOp(op, expResult);
+  return success();
+}
 
 // exp2f(float x) -> exp(x * ln(2))
 //   Proof: Let's say 2^x = y
@@ -264,6 +277,10 @@ void mlir::populateExpandExp2FPattern(RewritePatternSet &patterns) {
   patterns.add(convertExp2fOp);
 }
 
+void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) {
+  patterns.add(convertPowfOp); // math.powf -> math.log, arith.mulf, math.exp
+}
+
 void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) {
   patterns.add(convertRoundOp);
 }

diff  --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index b3a5668f3235b..382278c060c8e 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -207,3 +207,16 @@ func.func @roundf_func(%a: f64) -> f64 {
   %ret = math.round %a : f64
   return %ret : f64
 }
+
+// -----
+
+// CHECK-LABEL:   func @powf_func
+// CHECK-SAME:    ([[ARG0:%.+]]: f64, [[ARG1:%.+]]: f64)
+func.func @powf_func(%a: f64, %b: f64) -> f64 {
+  // CHECK: [[LOG:%.+]] = math.log [[ARG0]]
+  // CHECK: [[MULT:%.+]] = arith.mulf [[LOG]], [[ARG1]]
+  // CHECK: [[EXPR:%.+]] = math.exp [[MULT]]
+  // CHECK: return [[EXPR]]
+  %ret = math.powf %a, %b : f64
+  return %ret : f64
+}

diff  --git a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
index 5692ecf8d7237..c9b3357c9b508 100644
--- a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
+++ b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
@@ -43,6 +43,7 @@ void TestExpandMathPass::runOnOperation() {
   populateExpandFmaFPattern(patterns);
   populateExpandFloorFPattern(patterns);
   populateExpandCeilFPattern(patterns);
+  populateExpandPowFPattern(patterns);
   populateExpandRoundFPattern(patterns);
   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
 }

diff  --git a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
index f3c7a2c4051b4..b72f9ba8fd258 100644
--- a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
@@ -100,9 +100,68 @@ func.func @roundf() {
   return
 }
 
+// -------------------------------------------------------------------------- //
+// powf (a^b).
+// -------------------------------------------------------------------------- //
+func.func @func_powff64(%base : f64, %power : f64) {
+  %result = math.powf %base, %power : f64
+  vector.print %result : f64
+  return
+}
+
+func.func @powf() {
+  // CHECK: 16
+  %a   = arith.constant 4.0 : f64
+  %a_p = arith.constant 2.0 : f64
+  call @func_powff64(%a, %a_p) : (f64, f64) -> ()
+
+  // CHECK: -nan
+  %b   = arith.constant -3.0 : f64  // ln(-3) is NaN; C pow(-3, 3) is -27.
+  %b_p = arith.constant 3.0 : f64
+  call @func_powff64(%b, %b_p) : (f64, f64) -> ()
+
+  // CHECK: 2.343
+  %c   = arith.constant 2.343 : f64
+  %c_p = arith.constant 1.000 : f64
+  call @func_powff64(%c, %c_p) : (f64, f64) -> ()
+
+  // CHECK: 0.176171
+  %d   = arith.constant 4.25 : f64
+  %d_p = arith.constant -1.2  : f64
+  call @func_powff64(%d, %d_p) : (f64, f64) -> ()
+
+  // CHECK: 1
+  %e   = arith.constant 4.385 : f64
+  %e_p = arith.constant 0.00 : f64  // exp(0 * ln(4.385)) = exp(0) = 1
+  call @func_powff64(%e, %e_p) : (f64, f64) -> ()
+
+  // CHECK: 6.62637
+  %f    = arith.constant 4.835 : f64
+  %f_p  = arith.constant 1.2 : f64
+  call @func_powff64(%f, %f_p) : (f64, f64) -> ()
+
+  // CHECK: -nan
+  %g    = arith.constant 0xff80000000000000 : f64  // NOTE(review): as f64 this is -2^1017, not -inf (f64 -inf is 0xfff0000000000000); ln of it is NaN.
+  call @func_powff64(%g, %g) : (f64, f64) -> ()
+
+  // CHECK: nan
+  %h = arith.constant 0x7fffffffffffffff : f64  // quiet NaN
+  call @func_powff64(%h, %h) : (f64, f64) -> ()
+
+  // CHECK: nan
+  %i = arith.constant 1.0 : f64  // ln(1) * NaN is NaN; C pow(1, NaN) is 1.
+  call @func_powff64(%i, %h) : (f64, f64) -> ()
+
+  // CHECK: inf
+  %j   = arith.constant 29385.0 : f64  // the result overflows f64 to +inf
+  %j_p = arith.constant 23598.0 : f64
+  call @func_powff64(%j, %j_p) : (f64, f64) -> ()
+  return
+}
 
 func.func @main() {
   call @exp2f() : () -> ()
   call @roundf() : () -> ()
+  call @powf() : () -> ()
   return
 }


        


More information about the Mlir-commits mailing list