[Mlir-commits] [mlir] 8165eaa - [mlir](arithmetic) Add ceildivui to the arithmetic dialect
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Nov 10 17:49:17 PST 2021
Author: lipracer
Date: 2021-11-11T01:49:14Z
New Revision: 8165eaa8853195e3211cb84b77083c3f10adcb5e
URL: https://github.com/llvm/llvm-project/commit/8165eaa8853195e3211cb84b77083c3f10adcb5e
DIFF: https://github.com/llvm/llvm-project/commit/8165eaa8853195e3211cb84b77083c3f10adcb5e.diff
LOG: [mlir](arithmetic) Add ceildivui to the arithmetic dialect
The motivation is discussed in the Discourse thread "Adding unsigned integer ceil in Std Dialect" (https://llvm.discourse.group/t/adding-unsigned-integer-ceil-and-floor-in-std-dialect/4541).
Lowering the signed ceil-division op (ceildivsi) generates the code below. When n and m are known to be unsigned integers, its checks on the operand signs are redundant.
An unsigned variant is therefore added so that such code can be lowered to fewer instructions.
```
ceilDiv(n, m)
x = (m > 0) ? -1 : 1
return (n*m>0) ? ((n+x) / m) + 1 : - (-n / m)
```
unsigned operations:
```
ceilDivU(n, m)
return n == 0 ? 0 : ((n - 1) / m) + 1
```
Reviewed By: Mogball
Differential Revision: https://reviews.llvm.org/D113363
Added:
Modified:
mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp
mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp
mlir/test/Dialect/Arithmetic/expand-ops.mlir
mlir/test/Integration/Dialect/Standard/CPU/test-ceil-floor-pos-neg.mlir
mlir/test/Transforms/canonicalize.mlir
mlir/test/Transforms/constant-fold.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
index 4ec51079322ae..2e90455daaa14 100644
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
@@ -276,6 +276,30 @@ def Arith_DivSIOp : Arith_IntBinaryOp<"divsi"> {
let hasFolder = 1;
}
+//===----------------------------------------------------------------------===//
+// CeilDivUIOp
+//===----------------------------------------------------------------------===//
+
+def Arith_CeilDivUIOp : Arith_IntBinaryOp<"ceildivui"> {
+  let summary = "unsigned ceil integer division operation";
+  let description = [{
+    Unsigned integer division. Rounds towards positive infinity. Treats the
+    leading bit as the most significant, i.e. for `i16` given two's complement
+    representation, `6 / -2 = 6 / (2^16 - 2) = 1` (whereas `arith.divui`,
+    which rounds towards zero, yields 0 for the same operands).
+
+    Note: the semantics of division by zero is TBD; do NOT assume any specific
+    behavior.
+
+    Example:
+
+    ```mlir
+    // Scalar unsigned integer ceil division.
+    %a = arith.ceildivui %b, %c : i64
+    ```
+  }];
+  let hasFolder = 1;
+}
+
//===----------------------------------------------------------------------===//
// CeilDivSIOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
index 09c742bc82176..19b5bd05ab088 100644
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -306,6 +306,36 @@ static APInt signedCeilNonnegInputs(APInt a, APInt b, bool &overflow) {
return val.sadd_ov(one, overflow);
}
+//===----------------------------------------------------------------------===//
+// CeilDivUIOp
+//===----------------------------------------------------------------------===//
+
+/// Folds `ceildivui`. When constFoldBinaryOp can evaluate both operands,
+/// computes ceil(a / b) treating both as unsigned (udiv/urem). A zero divisor
+/// is never constant-folded. Independently, a divisor equal to 1 (scalar
+/// IntegerAttr or splat tensor) folds the op to its lhs.
+OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) {
+ bool overflowOrDiv0 = false;
+ auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
+ // A zero divisor (or an earlier failure on another element) aborts the
+ // fold; the value returned here is discarded by the check at the end.
+ if (overflowOrDiv0 || !b) {
+ overflowOrDiv0 = true;
+ return a;
+ }
+ APInt quotient = a.udiv(b);
+ // Exact division: the floor quotient is already the ceiling.
+ if (!a.urem(b))
+ return quotient;
+ // Inexact division rounds up by one. A nonzero remainder implies b >= 2,
+ // so quotient + 1 cannot wrap; uadd_ov still records an overflow
+ // defensively, which would also suppress the fold.
+ APInt one(a.getBitWidth(), 1, true);
+ return quotient.uadd_ov(one, overflowOrDiv0);
+ });
+ // Fold out ceil division by one. Assumes all tensors of all ones are
+ // splats.
+ if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
+ if (rhs.getValue() == 1)
+ return getLhs();
+ } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
+ if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
+ return getLhs();
+ }
+
+ // A null Attribute means "no fold" (division by zero or overflow seen).
+ return overflowOrDiv0 ? Attribute() : result;
+}
+
//===----------------------------------------------------------------------===//
// CeilDivSIOp
//===----------------------------------------------------------------------===//
@@ -342,7 +372,7 @@ OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) {
return zero.ssub_ov(div, overflowOrDiv0);
});
- // Fold out floor division by one. Assumes all tensors of all ones are
+ // Fold out ceil division by one. Assumes all tensors of all ones are
// splats.
if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
if (rhs.getValue() == 1)
diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp
index 8cfd1c91838b3..87e41bb1c2e28 100644
--- a/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp
@@ -13,6 +13,30 @@ using namespace mlir;
namespace {
+/// Expands CeilDivUIOp (n, m) into
+/// n == 0 ? 0 : ((n-1) / m) + 1
+/// All arithmetic is unsigned. For n > 0, ((n - 1) / m) + 1 equals ceil(n / m)
+/// without forming n + m - 1, which could overflow. Constants are created
+/// with the operand type, so the same expansion serves integer and index
+/// operands.
+struct CeilDivUIOpConverter : public OpRewritePattern<arith::CeilDivUIOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(arith::CeilDivUIOp op,
+ PatternRewriter &rewriter) const final {
+ Location loc = op.getLoc();
+ Value a = op.lhs();
+ Value b = op.rhs();
+ Value zero = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getIntegerAttr(a.getType(), 0));
+ // Selector for the n == 0 special case.
+ Value compare =
+ rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, a, zero);
+ Value one = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getIntegerAttr(a.getType(), 1));
+ // n - 1: wraps around when n == 0, but that lane is discarded by the
+ // select below.
+ Value minusOne = rewriter.create<arith::SubIOp>(loc, a, one);
+ // NOTE: the divui is emitted unconditionally, so m == 0 reaches it even
+ // when n == 0; division by zero for ceildivui is documented as TBD.
+ Value quotient = rewriter.create<arith::DivUIOp>(loc, minusOne, b);
+ Value plusOne = rewriter.create<arith::AddIOp>(loc, quotient, one);
+ Value res = rewriter.create<SelectOp>(loc, compare, zero, plusOne);
+ rewriter.replaceOp(op, {res});
+ return success();
+ }
+};
+
/// Expands CeilDivSIOp (n, m) into
/// 1) x = (m > 0) ? -1 : 1
/// 2) (n*m>0) ? ((n+x) / m) + 1 : - (-n / m)
@@ -132,7 +156,8 @@ struct ArithmeticExpandOpsPass
arith::populateArithmeticExpandOpsPatterns(patterns);
target.addLegalDialect<arith::ArithmeticDialect, StandardOpsDialect>();
- target.addIllegalOp<arith::CeilDivSIOp, arith::FloorDivSIOp>();
+ target.addIllegalOp<arith::CeilDivUIOp, arith::CeilDivSIOp,
+ arith::FloorDivSIOp>();
if (failed(
applyPartialConversion(getFunction(), target, std::move(patterns))))
@@ -144,8 +169,9 @@ struct ArithmeticExpandOpsPass
+/// Adds to `patterns` the rewrites that expand the arith ceil/floor division
+/// ops (ceildivui, ceildivsi, floordivsi) into simpler arith/std operations.
void mlir::arith::populateArithmeticExpandOpsPatterns(
RewritePatternSet &patterns) {
- patterns.add<CeilDivSIOpConverter, FloorDivSIOpConverter>(
- patterns.getContext());
+ patterns
+ .add<CeilDivUIOpConverter, CeilDivSIOpConverter, FloorDivSIOpConverter>(
+ patterns.getContext());
}
std::unique_ptr<Pass> mlir::arith::createArithmeticExpandOpsPass() {
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp
index 25bac1ad8466c..4955b83b80bb8 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp
@@ -175,7 +175,8 @@ struct StdExpandOpsPass : public StdExpandOpsBase<StdExpandOpsPass> {
target.addLegalDialect<arith::ArithmeticDialect, memref::MemRefDialect,
StandardOpsDialect>();
- target.addIllegalOp<arith::CeilDivSIOp, arith::FloorDivSIOp>();
+ target.addIllegalOp<arith::CeilDivUIOp, arith::CeilDivSIOp,
+ arith::FloorDivSIOp>();
target.addDynamicallyLegalOp<AtomicRMWOp>([](AtomicRMWOp op) {
return op.getKind() != AtomicRMWKind::maxf &&
op.getKind() != AtomicRMWKind::minf;
diff --git a/mlir/test/Dialect/Arithmetic/expand-ops.mlir b/mlir/test/Dialect/Arithmetic/expand-ops.mlir
index 23ab1267e0ab6..a1bd39208be37 100644
--- a/mlir/test/Dialect/Arithmetic/expand-ops.mlir
+++ b/mlir/test/Dialect/Arithmetic/expand-ops.mlir
@@ -111,3 +111,37 @@ func @floordivi_index(%arg0: index, %arg1: index) -> (index) {
// CHECK: [[CMP2:%.+]] = arith.ori [[TERM1]], [[TERM2]] : i1
// CHECK: [[RES:%.+]] = select [[CMP2]], [[TRUE3]], [[FALSE]] : index
}
+
+// -----
+
+// Test expansion of unsigned ceil division on i32:
+// n == 0 ? 0 : ((n - 1) / m) + 1
+// CHECK-LABEL: func @ceildivui
+// CHECK-SAME: ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32) -> i32 {
+func @ceildivui(%arg0: i32, %arg1: i32) -> (i32) {
+  %res = arith.ceildivui %arg0, %arg1 : i32
+  return %res : i32
+// CHECK: [[ZERO:%.+]] = arith.constant 0 : i32
+// CHECK: [[ISZERO:%.+]] = arith.cmpi eq, [[ARG0]], [[ZERO]] : i32
+// CHECK: [[ONE:%.+]] = arith.constant 1 : i32
+// CHECK: [[SUB:%.+]] = arith.subi [[ARG0]], [[ONE]] : i32
+// CHECK: [[DIV:%.+]] = arith.divui [[SUB]], [[ARG1]] : i32
+// CHECK: [[ADD:%.+]] = arith.addi [[DIV]], [[ONE]] : i32
+// CHECK: [[RES:%.+]] = select [[ISZERO]], [[ZERO]], [[ADD]] : i32
+// CHECK: return [[RES]] : i32
+}
+
+// -----
+
+// Test expansion of unsigned ceil division on index operands; the emitted
+// sequence is the same as for integers.
+// CHECK-LABEL: func @ceildivui_index
+// CHECK-SAME: ([[ARG0:%.+]]: index, [[ARG1:%.+]]: index) -> index {
+func @ceildivui_index(%arg0: index, %arg1: index) -> (index) {
+  %res = arith.ceildivui %arg0, %arg1 : index
+  return %res : index
+// CHECK: [[ZERO:%.+]] = arith.constant 0 : index
+// CHECK: [[ISZERO:%.+]] = arith.cmpi eq, [[ARG0]], [[ZERO]] : index
+// CHECK: [[ONE:%.+]] = arith.constant 1 : index
+// CHECK: [[SUB:%.+]] = arith.subi [[ARG0]], [[ONE]] : index
+// CHECK: [[DIV:%.+]] = arith.divui [[SUB]], [[ARG1]] : index
+// CHECK: [[ADD:%.+]] = arith.addi [[DIV]], [[ONE]] : index
+// CHECK: [[RES:%.+]] = select [[ISZERO]], [[ZERO]], [[ADD]] : index
+// CHECK: return [[RES]] : index
+}
diff --git a/mlir/test/Integration/Dialect/Standard/CPU/test-ceil-floor-pos-neg.mlir b/mlir/test/Integration/Dialect/Standard/CPU/test-ceil-floor-pos-neg.mlir
index eaadd9908b7ed..d55d39b93dbef 100644
--- a/mlir/test/Integration/Dialect/Standard/CPU/test-ceil-floor-pos-neg.mlir
+++ b/mlir/test/Integration/Dialect/Standard/CPU/test-ceil-floor-pos-neg.mlir
@@ -17,6 +17,7 @@ func @entry() {
%c20 = arith.constant 20: i32
%c10 = arith.constant 10: i32
%cmin10 = arith.constant -10: i32
+ %cmax_int = arith.constant 2147483647: i32
%A = memref.alloc() : memref<40xi32>
// print numerator
@@ -64,20 +65,39 @@ func @entry() {
}
call @transfer_read_2d(%A, %c0) : (memref<40xi32>, index) -> ()
+ // test with ceildivui(*, 10)
+ affine.for %i = 0 to 40 {
+ %ii = arith.index_cast %i: index to i32
+ %val = arith.ceildivui %ii, %c10 : i32
+ memref.store %val, %A[%i] : memref<40xi32>
+ }
+ call @transfer_read_2d(%A, %c0) : (memref<40xi32>, index) -> ()
+
+ // test with ceildivui(* - 20, max_int)
+ affine.for %i = 0 to 40 {
+ %ii = arith.index_cast %i: index to i32
+ %ii30 = arith.subi %ii, %c20 : i32
+ %val = arith.ceildivui %ii30, %cmax_int : i32
+ memref.store %val, %A[%i] : memref<40xi32>
+ }
+ call @transfer_read_2d(%A, %c0) : (memref<40xi32>, index) -> ()
+
memref.dealloc %A : memref<40xi32>
return
}
// List below is aligned for easy manual check
-// legend: num, ceil(num, 10), floor(num, 10), ceil(num, -10), floor(num, -10)
+// legend: num, signed_ceil(num, 10), floor(num, 10), signed_ceil(num, -10), floor(num, -10), unsigned_ceil(num + 20, 10), unsigned_ceil(num, max_int)
// ( -20, -19, -18, -17, -16, -15, -14, -13, -12, -11, -10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19 )
// ( -2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2 )
-// ( -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -1, -1, -1,-1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
-// ( 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 )
-// ( 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -2, -2, -2, -2 )
+// ( -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
+// ( 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 )
+// ( 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -2, -2, -2, -2 )
// CHECK:( -20, -19, -18, -17, -16, -15, -14, -13, -12, -11, -10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19 )
// CHECK:( -2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2 )
// CHECK:( -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
// CHECK:( 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 )
// CHECK:( 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -2, -2, -2, -2 )
+// CHECK:( 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4 )
+// CHECK:( 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index f12ff02aff383..2a0c7a34a5e0a 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -1028,6 +1028,26 @@ func @tensor_arith.ceildivsi_by_one(%arg0: tensor<4x5xi32>) -> tensor<4x5xi32> {
// -----
+// ceildivui by a scalar constant 1 folds away, returning the lhs operand.
+// CHECK-LABEL: func @arith.ceildivui_by_one
+// CHECK-SAME: %[[ARG:[a-zA-Z0-9]+]]
+func @arith.ceildivui_by_one(%arg0: i32) -> (i32) {
+ %c1 = arith.constant 1 : i32
+ %res = arith.ceildivui %arg0, %c1 : i32
+ // CHECK: return %[[ARG]]
+ return %res : i32
+}
+
+// Same fold when the divisor is a splat tensor of ones.
+// CHECK-LABEL: func @tensor_arith.ceildivui_by_one
+// CHECK-SAME: %[[ARG:[a-zA-Z0-9]+]]
+func @tensor_arith.ceildivui_by_one(%arg0: tensor<4x5xi32>) -> tensor<4x5xi32> {
+ %c1 = arith.constant dense<1> : tensor<4x5xi32>
+ %res = arith.ceildivui %arg0, %c1 : tensor<4x5xi32>
+ // CHECK: return %[[ARG]]
+ return %res : tensor<4x5xi32>
+}
+
+// -----
+
// CHECK-LABEL: func @memref_cast_folding_subview
func @memref_cast_folding_subview(%arg0: memref<4x5xf32>, %i: index) -> (memref<?x?xf32, offset:? , strides: [?, ?]>) {
%0 = memref.cast %arg0 : memref<4x5xf32> to memref<?x?xf32>
diff --git a/mlir/test/Transforms/constant-fold.mlir b/mlir/test/Transforms/constant-fold.mlir
index edd40bafda5df..5406a8588ce4b 100644
--- a/mlir/test/Transforms/constant-fold.mlir
+++ b/mlir/test/Transforms/constant-fold.mlir
@@ -478,6 +478,44 @@ func @simple_arith.ceildivsi() -> (i32, i32, i32, i32, i32) {
// -----
+// Constant folding of ceildivui. Negative literals are reinterpreted as
+// unsigned 32-bit values (e.g. -2 is 4294967294).
+// CHECK-LABEL: func @simple_arith.ceildivui
+func @simple_arith.ceildivui() -> (i32, i32, i32, i32, i32) {
+  // CHECK-DAG: [[C0:%.+]] = arith.constant 0
+  %z = arith.constant 0 : i32
+  // CHECK-DAG: [[C7:%.+]] = arith.constant 7
+  %0 = arith.constant 7 : i32
+  %1 = arith.constant 2 : i32
+
+  // ceildivui(7, 2) = 4
+  // CHECK-NEXT: [[C4:%.+]] = arith.constant 4 : i32
+  %2 = arith.ceildivui %0, %1 : i32
+
+  %3 = arith.constant -2 : i32
+
+  // ceildivui(7, -2) = ceil(7 / 4294967294) = 1
+  // CHECK-NEXT: [[C1:%.+]] = arith.constant 1 : i32
+  %4 = arith.ceildivui %0, %3 : i32
+
+  %5 = arith.constant -8 : i32
+
+  // ceildivui(-8, 2) = ceil(4294967288 / 2) = 2147483644
+  // CHECK-NEXT: [[CBIG:%.+]] = arith.constant 2147483644 : i32
+  %6 = arith.ceildivui %5, %1 : i32
+
+  %7 = arith.constant -15 : i32
+
+  // ceildivui(-15, -2) = ceil(4294967281 / 4294967294) = 1; the fold reuses
+  // the constant 1 materialized above, so no second constant is emitted.
+  // CHECK-NOT: arith.constant 1 : i32
+  %8 = arith.ceildivui %7, %3 : i32
+
+  // Division by zero is not folded.
+  // CHECK-NEXT: [[XZ:%.+]] = arith.ceildivui [[C7]], [[C0]]
+  %9 = arith.ceildivui %0, %z : i32
+
+  return %2, %4, %6, %8, %9 : i32, i32, i32, i32, i32
+}
+
+// -----
+
// CHECK-LABEL: func @simple_arith.remsi
func @simple_arith.remsi(%a : i32) -> (i32, i32, i32) {
%0 = arith.constant 5 : i32
More information about the Mlir-commits
mailing list