[Mlir-commits] [mlir] 2217888 - [mlir][math] Expand math.ceilf to truncate, compares and increments
Robert Suderman
llvmlistbot at llvm.org
Tue Apr 11 06:53:57 PDT 2023
Author: Balaji V. Iyer
Date: 2023-04-11T13:52:45Z
New Revision: 2217888d2c86f70ced50eba1d68185ccf0fdade3
URL: https://github.com/llvm/llvm-project/commit/2217888d2c86f70ced50eba1d68185ccf0fdade3
DIFF: https://github.com/llvm/llvm-project/commit/2217888d2c86f70ced50eba1d68185ccf0fdade3.diff
LOG: [mlir][math] Expand math.ceilf to truncate, compares and increments
Calls to ceilf are lowered directly to libm. This is problematic in
situations where libm is not available. This patch breaks down
a ceilf call into a truncation followed by an increment when the
truncated value is smaller than the input value.
Reviewed By: rsuderman
Differential Revision: https://reviews.llvm.org/D147974
Added:
Modified:
mlir/include/mlir/Dialect/Math/Transforms/Passes.h
mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
mlir/test/Dialect/Math/expand-math.mlir
mlir/test/lib/Dialect/Math/TestExpandMath.cpp
mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index f93374855a8d2..1b32de2b99683 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -18,6 +18,7 @@ void populateExpandTanPattern(RewritePatternSet &patterns);
void populateExpandTanhPattern(RewritePatternSet &patterns);
void populateExpandFmaFPattern(RewritePatternSet &patterns);
void populateExpandFloorFPattern(RewritePatternSet &patterns);
+/// Adds to `patterns` a pattern that expands `math.ceil` into a
+/// truncation-based sequence of simpler operations, so that lowering it does
+/// not require a call to libm's `ceilf`.
+void populateExpandCeilFPattern(RewritePatternSet &patterns);
void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
struct MathPolynomialApproximationOptions {
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 2dab48dfda436..b70ac4e006eac 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -46,6 +46,13 @@ static Value createIntConst(Location loc, Type type, int64_t value,
return b.create<arith::ConstantOp>(loc, attr);
}
+/// Truncates the floating-point `operand` towards zero by round-tripping it
+/// through a signed 64-bit integer (arith.fptosi followed by arith.sitofp) and
+/// returns the result with the operand's original type.
+///
+/// For shaped operands (vectors/tensors) the intermediate integer type keeps
+/// the operand's shape with an i64 element type: both casts are elementwise,
+/// so a plain scalar i64 result would produce mismatched, invalid IR.
+///
+/// NOTE(review): fptosi is not well-defined for NaN, infinities or magnitudes
+/// that do not fit in i64 (LLVM lowers it to poison), so callers must guard
+/// those inputs. The round-trip also drops the sign of zero, e.g. -0.5
+/// truncates to +0.0 rather than -0.0.
+static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b) {
+  Type opType = operand.getType();
+  Type i64Type = b.getI64Type();
+  if (auto shapedTy = opType.dyn_cast<ShapedType>())
+    i64Type = shapedTy.clone(i64Type);
+  Value fixedConvert = b.create<arith::FPToSIOp>(i64Type, operand);
+  return b.create<arith::SIToFPOp>(opType, fixedConvert);
+}
+
/// Expands tanh op into
/// 1) 1-exp^{-2x} / 1+exp^{-2x}, if x => 0
/// 2) exp^{2x}-1 / exp^{2x}+1 , if x < 0
@@ -112,8 +119,7 @@ static LogicalResult convertFloorOp(math::FloorOp op,
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value operand = op.getOperand();
Type opType = operand.getType();
- Value fixedConvert = b.create<arith::FPToSIOp>(b.getI64Type(), operand);
- Value fpFixedConvert = b.create<arith::SIToFPOp>(opType, fixedConvert);
+ Value fpFixedConvert = createTruncatedFPValue(operand, b);
// Creating constants for later use.
Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
@@ -128,6 +134,30 @@ static LogicalResult convertFloorOp(math::FloorOp op,
return success();
}
+// Converts a ceilf() function to the following:
+// ceilf(float x) ->
+//   y = (float)(int64_t) x              (truncate towards zero)
+//   if (x > y) then incr = 1 else incr = 0
+//   r = copysign(y + incr, x)           (e.g. ceil(-0.5) must be -0.0)
+//   if (|x| >= 2^(p-1) or x is NaN) then r = x
+//   return r                            <= replaces the ceilf op.
+// Here p is the precision (mantissa width including the implicit bit) of the
+// float type. Every finite value with |x| >= 2^(p-1) is already an integer,
+// so it is returned unchanged. The same select also discards the (possibly
+// poison) i64 round-trip result for NaN, infinities and values too large for
+// i64. Float types with p > 64 cannot be handled by an i64 round-trip and
+// are left untouched.
+static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
+  Value operand = op.getOperand();
+  Type opType = operand.getType();
+
+  // Wider significands (e.g. f128) can hold non-integral values outside the
+  // i64 range, which the truncation below cannot represent.
+  Type elemType = opType;
+  if (auto shapedTy = opType.dyn_cast<ShapedType>())
+    elemType = shapedTy.getElementType();
+  unsigned precision = elemType.cast<FloatType>().getFPMantissaWidth();
+  if (precision > 64)
+    return rewriter.notifyMatchFailure(op, "float type too wide to expand");
+
+  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+  Value fpFixedConvert = createTruncatedFPValue(operand, b);
+
+  // Creating constants for later use.
+  Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
+  Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
+  // 2^(p-1): at and above this magnitude every value of the type is integral.
+  Value integralBound = createFloatConst(
+      op->getLoc(), opType,
+      static_cast<double>(uint64_t(1) << (precision - 1)), rewriter);
+
+  Value gtCheck = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, operand,
+                                          fpFixedConvert);
+  Value incrValue = b.create<arith::SelectOp>(gtCheck, one, zero);
+  Value sum = b.create<arith::AddFOp>(opType, fpFixedConvert, incrValue);
+  // The integer round-trip loses the sign of zero; ceil keeps x's sign.
+  Value signedSum = b.create<math::CopySignOp>(sum, operand);
+
+  // UGE is also true when x is NaN, so NaN, +-inf and already-integral values
+  // are returned as-is instead of the truncation-based result.
+  Value absOperand = b.create<math::AbsFOp>(operand);
+  Value keepOperand = b.create<arith::CmpFOp>(arith::CmpFPredicate::UGE,
+                                              absOperand, integralBound);
+  Value ret = b.create<arith::SelectOp>(keepOperand, operand, signedSum);
+  rewriter.replaceOp(op, ret);
+  return success();
+}
+
// Converts math.ctlz to scf and arith operations. This is done
// by performing a binary search on the bits.
static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
@@ -187,6 +217,11 @@ void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) {
void mlir::populateExpandFmaFPattern(RewritePatternSet &patterns) {
patterns.add(convertFmaFOp);
}
+
+/// Registers convertCeilOp, which expands math.ceil without calling libm.
+void mlir::populateExpandCeilFPattern(RewritePatternSet &patterns) {
+  patterns.add(convertCeilOp);
+}
+
void mlir::populateExpandFloorFPattern(RewritePatternSet &patterns) {
patterns.add(convertFloorOp);
}
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index d67193c974051..8ab644931b669 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -148,3 +148,20 @@ func.func @floorf_func(%a: f64) -> f64 {
%ret = math.floor %a : f64
return %ret : f64
}
+
+// -----
+
+// CHECK-LABEL: func @ceilf_func
+// CHECK-SAME: ([[ARG0:%.+]]: f64) -> f64
+func.func @ceilf_func(%a: f64) -> f64 {
+  // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000
+  // CHECK-DAG: [[CST_0:%.+]] = arith.constant 1.000
+  // CHECK-DAG: [[BOUND:%.+]] = arith.constant {{.+}} : f64
+  // CHECK-NEXT: [[CVTI:%.+]] = arith.fptosi [[ARG0]]
+  // CHECK-NEXT: [[CVTF:%.+]] = arith.sitofp [[CVTI]]
+  // CHECK-NEXT: [[COMP:%.+]] = arith.cmpf ogt, [[ARG0]], [[CVTF]]
+  // CHECK-NEXT: [[INCR:%.+]] = arith.select [[COMP]], [[CST_0]], [[CST]]
+  // CHECK-NEXT: [[ADDF:%.+]] = arith.addf [[CVTF]], [[INCR]]
+  // CHECK-NEXT: [[SIGN:%.+]] = math.copysign [[ADDF]], [[ARG0]]
+  // CHECK-NEXT: [[ABS:%.+]] = math.absf [[ARG0]]
+  // CHECK-NEXT: [[BIG:%.+]] = arith.cmpf uge, [[ABS]], [[BOUND]]
+  // CHECK-NEXT: [[RES:%.+]] = arith.select [[BIG]], [[ARG0]], [[SIGN]]
+  // CHECK-NEXT: return [[RES]]
+  %ret = math.ceil %a : f64
+  return %ret : f64
+}
diff --git a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
index e6a44894b4319..c670617f8446f 100644
--- a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
+++ b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
@@ -41,6 +41,7 @@ void TestExpandMathPass::runOnOperation() {
populateExpandTanhPattern(patterns);
populateExpandFmaFPattern(patterns);
populateExpandFloorFPattern(patterns);
+ populateExpandCeilFPattern(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
diff --git a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
index 0fff84d3eef1b..130147b01d0a7 100644
--- a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
@@ -647,6 +647,43 @@ func.func @floorf() {
return
}
+// -------------------------------------------------------------------------- //
+// ceil.
+// -------------------------------------------------------------------------- //
+// Applies math.ceil to %a and prints the result with vector.print; the
+// expected printed values are listed in @ceilf below.
+func.func @func_ceilf32(%a : f32) {
+ %r = math.ceil %a : f32
+ vector.print %r : f32
+ return
+}
+
+// Exercises math.ceil on positive and negative non-integral inputs (rounded
+// towards +infinity), on zero, and on already-integral inputs (unchanged).
+// NOTE(review): no input lies in (-1, 0), so the sign of a -0.0 result
+// (e.g. ceil(-0.5)) is not exercised here.
+func.func @ceilf() {
+ // CHECK: 4
+ %a = arith.constant 3.8 : f32
+ call @func_ceilf32(%a) : (f32) -> ()
+
+ // CHECK: -3
+ %b = arith.constant -3.8 : f32
+ call @func_ceilf32(%b) : (f32) -> ()
+
+ // CHECK: 0
+ %c = arith.constant 0.0 : f32
+ call @func_ceilf32(%c) : (f32) -> ()
+
+ // CHECK: -4
+ %d = arith.constant -4.2 : f32
+ call @func_ceilf32(%d) : (f32) -> ()
+
+ // CHECK: -495
+ %e = arith.constant -495.0 : f32
+ call @func_ceilf32(%e) : (f32) -> ()
+
+ // CHECK: 495
+ %f = arith.constant 495.0 : f32
+ call @func_ceilf32(%f) : (f32) -> ()
+
+ return
+}
+
func.func @main() {
call @tanh(): () -> ()
call @log(): () -> ()
@@ -661,6 +698,7 @@ func.func @main() {
call @atan2() : () -> ()
call @cbrt() : () -> ()
call @floorf() : () -> ()
+ call @ceilf() : () -> ()
return
}
More information about the Mlir-commits
mailing list