[Mlir-commits] [mlir] be91157 - [mlir][math] Expand math.round to truncate, compare and increment.

Robert Suderman llvmlistbot at llvm.org
Thu Apr 13 11:02:33 PDT 2023


Author: Balaji V. Iyer
Date: 2023-04-13T18:02:10Z
New Revision: be9115788c7f223dfc6d369455ce84c0e443743b

URL: https://github.com/llvm/llvm-project/commit/be9115788c7f223dfc6d369455ce84c0e443743b
DIFF: https://github.com/llvm/llvm-project/commit/be9115788c7f223dfc6d369455ce84c0e443743b.diff

LOG: [mlir][math] Expand math.round to truncate, compare and increment.

`math.round` is currently lowered directly to a libm call, which is a problem
in environments where libm is not available. This patch expands `math.round`
inline instead: it adds 0.5 to a non-negative input (or subtracts 0.5 from a
negative one) and then truncates the result.

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D148026

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Math/Transforms/Passes.h
    mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
    mlir/test/Dialect/Math/expand-math.mlir
    mlir/test/lib/Dialect/Math/TestExpandMath.cpp
    mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index 3ac18c3a24184..6cd5b0a409223 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -20,6 +20,8 @@ void populateExpandFmaFPattern(RewritePatternSet &patterns);
 void populateExpandFloorFPattern(RewritePatternSet &patterns);
 void populateExpandCeilFPattern(RewritePatternSet &patterns);
 void populateExpandExp2FPattern(RewritePatternSet &patterns);
+/// Adds a pattern that expands math.round into arith/math operations.
+void populateExpandRoundFPattern(RewritePatternSet &patterns);
 void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
 
 struct MathPolynomialApproximationOptions {

diff  --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index e9447dcbd5394..bc35263e12b2d 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -174,6 +174,61 @@ static LogicalResult convertExp2fOp(math::Exp2Op op,
   return success();
 }
 
+// Expands `math.round` (round half away from zero) without calling libm:
+//
+//   a      = |x|
+//   t      = trunc(a)                    (fptosi/sitofp round trip)
+//   r      = copysign(t + (a - t >= 0.5 ? 1 : 0), x)
+//   result = a < 2^(p-1) ? r : x         (p = significand precision)
+//
+// `a - t` is exact, so unlike `trunc(x + copysign(0.5, x))` this does not
+// round the largest value below 0.5 up to 1. The copysign keeps the sign of
+// zero results, e.g. round(-0.3) == -0.0. Every finite value with
+// |x| >= 2^(p-1) is already integral, and Inf/NaN fail the ordered compare;
+// both are returned unchanged instead of going through the i64 conversion,
+// which cannot represent them.
+static LogicalResult convertRoundOp(math::RoundOp op,
+                                    PatternRewriter &rewriter) {
+  Location loc = op->getLoc();
+  ImplicitLocOpBuilder b(loc, rewriter);
+  Value operand = op.getOperand();
+  Type opType = operand.getType();
+
+  Type elementType = opType;
+  if (auto shapedType = opType.dyn_cast<ShapedType>())
+    elementType = shapedType.getElementType();
+  unsigned precision = elementType.cast<FloatType>().getFPMantissaWidth();
+  // Non-integral magnitudes reach up to 2^(precision-1) and are truncated
+  // through i64, so wider types (e.g. f128) are left unexpanded.
+  if (precision > 64)
+    return rewriter.notifyMatchFailure(
+        op, "float type is too wide for the i64-based expansion");
+
+  // Smallest magnitude at which every value of the type is integral.
+  double integralBound = static_cast<double>(uint64_t(1) << (precision - 1));
+  Value zero = createFloatConst(loc, opType, 0.0, rewriter);
+  Value half = createFloatConst(loc, opType, 0.5, rewriter);
+  Value one = createFloatConst(loc, opType, 1.0, rewriter);
+  Value bound = createFloatConst(loc, opType, integralBound, rewriter);
+
+  // Round the magnitude half-up, then restore the sign of the input.
+  Value absValue = b.create<math::AbsFOp>(operand);
+  Value truncValue = createTruncatedFPValue(absValue, b);
+  Value fraction = b.create<arith::SubFOp>(absValue, truncValue);
+  Value roundUp =
+      b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, fraction, half);
+  Value increment = b.create<arith::SelectOp>(roundUp, one, zero);
+  Value magnitude = b.create<arith::AddFOp>(truncValue, increment);
+  Value rounded = b.create<math::CopySignOp>(magnitude, operand);
+
+  // Pass through large (already integral) values, infinities and NaNs.
+  Value isSmall =
+      b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, absValue, bound);
+  Value result = b.create<arith::SelectOp>(isSmall, rounded, operand);
+  rewriter.replaceOp(op, result);
+  return success();
+}
+
 // Converts math.ctlz to scf and arith operations. This is done
 // by performing a binary search on the bits.
 static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
@@ -242,6 +297,10 @@ void mlir::populateExpandExp2FPattern(RewritePatternSet &patterns) {
   patterns.add(convertExp2fOp);
 }
 
+void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) {
+  patterns.add(convertRoundOp);
+}
+
 void mlir::populateExpandFloorFPattern(RewritePatternSet &patterns) {
   patterns.add(convertFloorOp);
 }

diff  --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 50986969fca57..b3a5668f3235b 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -189,3 +189,27 @@ func.func @exp2f_func_tensor(%a: tensor<1xf32>) -> tensor<1xf32> {
   %ret = math.exp2 %a : tensor<1xf32>
   return %ret : tensor<1xf32>
 }
+
+// -----
+
+// CHECK-LABEL:      func @roundf_func
+// CHECK-SAME:      ([[ARG0:%.+]]: f32) -> f32
+func.func @roundf_func(%a: f32) -> f32 {
+  // CHECK-DAG:   [[ZERO:%.+]] = arith.constant 0.000000e+00
+  // CHECK-DAG:   [[HALF:%.+]] = arith.constant 5.000000e-01
+  // CHECK-DAG:   [[ONE:%.+]] = arith.constant 1.000000e+00
+  // CHECK-DAG:   [[BOUND:%.+]] = arith.constant 8.388608e+06
+  // CHECK-DAG:   [[ABS:%.+]] = math.absf [[ARG0]]
+  // CHECK-DAG:   [[CVTI:%.+]] = arith.fptosi [[ABS]]
+  // CHECK-DAG:   [[TRUNC:%.+]] = arith.sitofp [[CVTI]]
+  // CHECK-DAG:   [[FRAC:%.+]] = arith.subf [[ABS]], [[TRUNC]]
+  // CHECK-DAG:   [[UP:%.+]] = arith.cmpf oge, [[FRAC]], [[HALF]]
+  // CHECK-DAG:   [[INCR:%.+]] = arith.select [[UP]], [[ONE]], [[ZERO]]
+  // CHECK-DAG:   [[MAG:%.+]] = arith.addf [[TRUNC]], [[INCR]]
+  // CHECK-DAG:   [[SIGNED:%.+]] = math.copysign [[MAG]], [[ARG0]]
+  // CHECK-DAG:   [[SMALL:%.+]] = arith.cmpf olt, [[ABS]], [[BOUND]]
+  // CHECK-DAG:   [[RES:%.+]] = arith.select [[SMALL]], [[SIGNED]], [[ARG0]]
+  // CHECK:       return [[RES]]
+  %ret = math.round %a : f32
+  return %ret : f32
+}

diff  --git a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
index 29eff9959bebc..5692ecf8d7237 100644
--- a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
+++ b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
@@ -43,6 +43,7 @@ void TestExpandMathPass::runOnOperation() {
   populateExpandFmaFPattern(patterns);
   populateExpandFloorFPattern(patterns);
   populateExpandCeilFPattern(patterns);
+  populateExpandRoundFPattern(patterns);
   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
 }
 

diff  --git a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
index 3fb3b2b719134..f3c7a2c4051b4 100644
--- a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
@@ -55,7 +55,54 @@ func.func @exp2f() {
   return
 }
 
+// -------------------------------------------------------------------------- //
+// round.
+// -------------------------------------------------------------------------- //
+func.func @func_roundf(%input : f32) {
+  %rounded = math.round %input : f32
+  vector.print %rounded : f32
+  return
+}
+
+func.func @roundf() {
+  // CHECK: 4
+  %pos_frac = arith.constant 3.8 : f32
+  call @func_roundf(%pos_frac) : (f32) -> ()
+
+  // CHECK: -4
+  %neg_frac = arith.constant -3.8 : f32
+  call @func_roundf(%neg_frac) : (f32) -> ()
+
+  // CHECK: 0
+  %zero = arith.constant 0.0 : f32
+  call @func_roundf(%zero) : (f32) -> ()
+
+  // CHECK: -4
+  %neg_low_frac = arith.constant -4.2 : f32
+  call @func_roundf(%neg_low_frac) : (f32) -> ()
+
+  // CHECK: -495
+  %neg_int = arith.constant -495.0 : f32
+  call @func_roundf(%neg_int) : (f32) -> ()
+
+  // CHECK: 495
+  %pos_int = arith.constant 495.0 : f32
+  call @func_roundf(%pos_int) : (f32) -> ()
+
+  // CHECK: 9
+  %pos_half = arith.constant 8.5 : f32
+  call @func_roundf(%pos_half) : (f32) -> ()
+
+  // CHECK: -9
+  %neg_half = arith.constant -8.5 : f32
+  call @func_roundf(%neg_half) : (f32) -> ()
+
+  return
+}
+
+
 func.func @main() {
   call @exp2f() : () -> ()
+  func.call @roundf() : () -> ()
   return
 }


        


More information about the Mlir-commits mailing list