[Mlir-commits] [mlir] af9eb1e - [mlir][math] Expand math.floorf to truncate, compares and increments

Robert Suderman llvmlistbot at llvm.org
Mon Apr 10 14:09:20 PDT 2023


Author: Balaji V. Iyer
Date: 2023-04-10T21:04:27Z
New Revision: af9eb1e3845f699a15aef73b6398fc62499fda01

URL: https://github.com/llvm/llvm-project/commit/af9eb1e3845f699a15aef73b6398fc62499fda01
DIFF: https://github.com/llvm/llvm-project/commit/af9eb1e3845f699a15aef73b6398fc62499fda01.diff

LOG: [mlir][math] Expand math.floorf to truncate, compares and increments

Floor operations are currently passed straight through to libm. This is
problematic in situations where libm is not available. This patch expands
a floorf call into a truncation followed, for negative inputs, by an
adjustment of -1.

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D147966

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Math/Transforms/Passes.h
    mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
    mlir/test/Dialect/Math/expand-math.mlir
    mlir/test/lib/Dialect/Math/TestExpandMath.cpp
    mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index 597618033cd0a..f93374855a8d2 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -17,7 +17,7 @@ void populateExpandCtlzPattern(RewritePatternSet &patterns);
 void populateExpandTanPattern(RewritePatternSet &patterns);
 void populateExpandTanhPattern(RewritePatternSet &patterns);
 void populateExpandFmaFPattern(RewritePatternSet &patterns);
-
+void populateExpandFloorFPattern(RewritePatternSet &patterns);
 void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
 
 struct MathPolynomialApproximationOptions {

diff  --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index f3e807e13102b..2dab48dfda436 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -102,6 +102,32 @@ static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
   return success();
 }
 
+// Expands math.floor into arith ops; the final add replaces the floor op:
+//     y = (float)(int64) x        (arith.fptosi to i64, then arith.sitofp)
+//     incr = (x < 0) ? -1 : 0     (arith.cmpf olt, then arith.select)
+//     result = y + incr           (arith.addf)
+// NOTE(review): integral x < 0 (e.g. -2.0) also gets incr = -1 -> -3.0; confirm.
+static LogicalResult convertFloorOp(math::FloorOp op,
+                                    PatternRewriter &rewriter) {
+  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+  Value operand = op.getOperand();
+  Type opType = operand.getType();
+  Value fixedConvert = b.create<arith::FPToSIOp>(b.getI64Type(), operand);
+  Value fpFixedConvert = b.create<arith::SIToFPOp>(opType, fixedConvert);
+
+  // Constants of the operand's type: the two possible values of the increment.
+  Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
+  Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter);
+
+  Value negCheck =
+      b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero);
+  Value incrValue =
+      b.create<arith::SelectOp>(op->getLoc(), negCheck, negOne, zero);
+  Value ret = b.create<arith::AddFOp>(opType, fpFixedConvert, incrValue);
+  rewriter.replaceOp(op, ret);
+  return success();
+}
+
 // Converts math.ctlz to scf and arith operations. This is done
 // by performing a binary search on the bits.
 static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
@@ -161,3 +187,6 @@ void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) {
 void mlir::populateExpandFmaFPattern(RewritePatternSet &patterns) {
   patterns.add(convertFmaFOp);
 }
+void mlir::populateExpandFloorFPattern(RewritePatternSet &patterns) {
+  patterns.add(convertFloorOp); // Adds the math.floor expansion to `patterns`.
+}

diff  --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index cc6c401d0c356..d67193c974051 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -131,3 +131,20 @@ func.func @fmaf_func(%a: f64, %b: f64, %c: f64) -> f64 {
   %ret = math.fma %a, %b, %c : f64
   return %ret : f64
 }
+
+// -----
+
+// CHECK-LABEL:     func @floorf_func
+// CHECK-SAME:      ([[ARG0:%.+]]: f64) -> f64
+func.func @floorf_func(%a: f64) -> f64 {
+  // CHECK-DAG:   [[CST:%.+]] = arith.constant 0.000
+  // CHECK-DAG:   [[CST_0:%.+]] = arith.constant -1.000
+  // CHECK-NEXT:   [[CVTI:%.+]] = arith.fptosi [[ARG0]]
+  // CHECK-NEXT:   [[CVTF:%.+]] = arith.sitofp [[CVTI]]
+  // CHECK-NEXT:   [[COMP:%.+]] = arith.cmpf olt, [[ARG0]], [[CST]]
+  // CHECK-NEXT:   [[INCR:%.+]] = arith.select [[COMP]], [[CST_0]], [[CST]]
+  // CHECK-NEXT:   [[ADDF:%.+]] = arith.addf [[CVTF]], [[INCR]]
+  // CHECK-NEXT:   return [[ADDF]]
+  %ret = math.floor %a : f64
+  return %ret : f64
+}

diff  --git a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
index 12bc3afe7ef75..e6a44894b4319 100644
--- a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
+++ b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
@@ -40,6 +40,7 @@ void TestExpandMathPass::runOnOperation() {
   populateExpandTanPattern(patterns);
   populateExpandTanhPattern(patterns);
   populateExpandFmaFPattern(patterns);
+  populateExpandFloorFPattern(patterns);
   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
 }
 

diff  --git a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
index 665d3280c0c46..0fff84d3eef1b 100644
--- a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
@@ -610,6 +610,43 @@ func.func @cbrt() {
   return
 }
 
+// -------------------------------------------------------------------------- //
+// floor.
+// -------------------------------------------------------------------------- //
+func.func @func_floorf32(%a : f32) {
+  %r = math.floor %a : f32
+  vector.print %r : f32
+  return
+}
+
+func.func @floorf() {
+  // CHECK: 3
+  %a = arith.constant 3.8 : f32
+  call @func_floorf32(%a) : (f32) -> ()
+
+  // CHECK: -4
+  %b = arith.constant -3.8 : f32
+  call @func_floorf32(%b) : (f32) -> ()
+
+  // CHECK: 0
+  %c = arith.constant 0.0 : f32
+  call @func_floorf32(%c) : (f32) -> ()
+
+  // CHECK: -5
+  %d = arith.constant -4.2 : f32
+  call @func_floorf32(%d) : (f32) -> ()
+
+  // CHECK: -2
+  %e = arith.constant -2.0 : f32
+  call @func_floorf32(%e) : (f32) -> ()
+
+  // CHECK: 2
+  %f = arith.constant 2.0 : f32
+  call @func_floorf32(%f) : (f32) -> ()
+
+  return
+}
+
 func.func @main() {
   call @tanh(): () -> ()
   call @log(): () -> ()
@@ -623,6 +660,7 @@ func.func @main() {
   call @atan() : () -> ()
   call @atan2() : () -> ()
   call @cbrt() : () -> ()
+  call @floorf() : () -> ()
   return
 }
 


        


More information about the Mlir-commits mailing list