[Mlir-commits] [mlir] 8165eaa - [mlir](arithmetic) Add ceildivui to the arithmetic dialect

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Nov 10 17:49:17 PST 2021


Author: lipracer
Date: 2021-11-11T01:49:14Z
New Revision: 8165eaa8853195e3211cb84b77083c3f10adcb5e

URL: https://github.com/llvm/llvm-project/commit/8165eaa8853195e3211cb84b77083c3f10adcb5e
DIFF: https://github.com/llvm/llvm-project/commit/8165eaa8853195e3211cb84b77083c3f10adcb5e.diff

LOG: [mlir](arithmetic) Add ceildivui to the arithmetic dialect

The motivation is discussed in [[ https://llvm.discourse.group/t/adding-unsigned-integer-ceil-and-floor-in-std-dialect/4541 | Adding unsigned integer ceil in Std Dialect ]].

Lowering the signed ceil division op (ceilDivOp) generates the code below. When n and m are known to be unsigned integers, its checks on the signs of the operands are redundant.
So this patch adds an unsigned operation that lowers to simpler instructions.
```
ceilDiv(n, m)
  x = (m > 0) ? -1 : 1
  return (n*m>0) ? ((n+x) / m) + 1 : - (-n / m)
```
The unsigned operation expands to:
```
ceilDivU(n, m)
  return n == 0 ? 0 : ((n - 1) / m) + 1
```

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D113363

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
    mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
    mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp
    mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp
    mlir/test/Dialect/Arithmetic/expand-ops.mlir
    mlir/test/Integration/Dialect/Standard/CPU/test-ceil-floor-pos-neg.mlir
    mlir/test/Transforms/canonicalize.mlir
    mlir/test/Transforms/constant-fold.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
index 4ec51079322ae..2e90455daaa14 100644
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
@@ -276,6 +276,30 @@ def Arith_DivSIOp : Arith_IntBinaryOp<"divsi"> {
   let hasFolder = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// CeilDivUIOp
+//===----------------------------------------------------------------------===//
+
+def Arith_CeilDivUIOp : Arith_IntBinaryOp<"ceildivui"> { // Unsigned counterpart of CeilDivSIOp.
+  let summary = "unsigned ceil integer division operation";
+  let description = [{
+    Unsigned integer division. Rounds towards positive infinity. Treats the 
+    leading bit as the most significant, i.e. for `i16` given two's complement
+    representation, `6 / -2 = 6 / (2^16 - 2) = 1`.
+
+    Note: the semantics of division by zero is TBD; do NOT assume any specific
+    behavior.
+
+    Example:
+
+    ```mlir
+    // Scalar unsigned integer division.
+    %a = arith.ceildivui %b, %c : i64
+    ```
+  }];
+  let hasFolder = 1; // Implemented by arith::CeilDivUIOp::fold (ArithmeticOps.cpp).
+}
+
 //===----------------------------------------------------------------------===//
 // CeilDivSIOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
index 09c742bc82176..19b5bd05ab088 100644
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -306,6 +306,36 @@ static APInt signedCeilNonnegInputs(APInt a, APInt b, bool &overflow) {
   return val.sadd_ov(one, overflow);
 }
 
+//===----------------------------------------------------------------------===//
+// CeilDivUIOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) {
+  bool overflowOrDiv0 = false; // Set on any division by zero or overflow.
+  auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
+    if (overflowOrDiv0 || !b) { // Stop folding once any element fails.
+      overflowOrDiv0 = true;
+      return a; // Placeholder only; discarded because the flag is set.
+    }
+    APInt quotient = a.udiv(b);
+    if (!a.urem(b)) // Exact division: no rounding up needed.
+      return quotient;
+    APInt one(a.getBitWidth(), 1, true);
+    return quotient.uadd_ov(one, overflowOrDiv0); // Round up; b >= 2 here, so no wrap.
+  });
+  // Fold out ceil division by one (x ceildivui 1 == x). Assumes all tensors of
+  // all ones are splats; a non-splat all-ones divisor skips this shortcut.
+  if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
+    if (rhs.getValue() == 1)
+      return getLhs();
+  } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
+    if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
+      return getLhs();
+  }
+
+  return overflowOrDiv0 ? Attribute() : result; // Never fold to the placeholder.
+}
+
 //===----------------------------------------------------------------------===//
 // CeilDivSIOp
 //===----------------------------------------------------------------------===//
@@ -342,7 +372,7 @@ OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) {
     return zero.ssub_ov(div, overflowOrDiv0);
   });
 
-  // Fold out floor division by one. Assumes all tensors of all ones are
+  // Fold out ceil division by one. Assumes all tensors of all ones are
   // splats.
   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
     if (rhs.getValue() == 1)

diff  --git a/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp
index 8cfd1c91838b3..87e41bb1c2e28 100644
--- a/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp
@@ -13,6 +13,30 @@ using namespace mlir;
 
 namespace {
 
+/// Expands CeilDivUIOp (n, m) into n == 0 ? 0 : ((n - 1) / m) + 1; the zero
+/// case is selected separately because n - 1 wraps around when n == 0.
+struct CeilDivUIOpConverter : public OpRewritePattern<arith::CeilDivUIOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::CeilDivUIOp op,
+                                PatternRewriter &rewriter) const final {
+    Location loc = op.getLoc();
+    Value a = op.lhs();
+    Value b = op.rhs();
+    Value zero = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getIntegerAttr(a.getType(), 0));
+    Value compare =
+        rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, a, zero);
+    Value one = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getIntegerAttr(a.getType(), 1));
+    Value minusOne = rewriter.create<arith::SubIOp>(loc, a, one); // Wraps if a == 0.
+    Value quotient = rewriter.create<arith::DivUIOp>(loc, minusOne, b);
+    Value plusOne = rewriter.create<arith::AddIOp>(loc, quotient, one);
+    Value res = rewriter.create<SelectOp>(loc, compare, zero, plusOne); // a == 0 ? 0 : q + 1
+    rewriter.replaceOp(op, {res});
+    return success();
+  }
+};
+
 /// Expands CeilDivSIOp (n, m) into
 ///   1) x = (m > 0) ? -1 : 1
 ///   2) (n*m>0) ? ((n+x) / m) + 1 : - (-n / m)
@@ -132,7 +156,8 @@ struct ArithmeticExpandOpsPass
     arith::populateArithmeticExpandOpsPatterns(patterns);
 
     target.addLegalDialect<arith::ArithmeticDialect, StandardOpsDialect>();
-    target.addIllegalOp<arith::CeilDivSIOp, arith::FloorDivSIOp>();
+    target.addIllegalOp<arith::CeilDivUIOp, arith::CeilDivSIOp,
+                        arith::FloorDivSIOp>();
 
     if (failed(
             applyPartialConversion(getFunction(), target, std::move(patterns))))
@@ -144,8 +169,9 @@ struct ArithmeticExpandOpsPass
 
 void mlir::arith::populateArithmeticExpandOpsPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<CeilDivSIOpConverter, FloorDivSIOpConverter>(
-      patterns.getContext());
+  patterns // Expansions for ceildivui, ceildivsi and floordivsi.
+      .add<CeilDivUIOpConverter, CeilDivSIOpConverter, FloorDivSIOpConverter>(
+          patterns.getContext());
 }
 
 std::unique_ptr<Pass> mlir::arith::createArithmeticExpandOpsPass() {

diff  --git a/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp
index 25bac1ad8466c..4955b83b80bb8 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp
@@ -175,7 +175,8 @@ struct StdExpandOpsPass : public StdExpandOpsBase<StdExpandOpsPass> {
 
     target.addLegalDialect<arith::ArithmeticDialect, memref::MemRefDialect,
                            StandardOpsDialect>();
-    target.addIllegalOp<arith::CeilDivSIOp, arith::FloorDivSIOp>();
+    target.addIllegalOp<arith::CeilDivUIOp, arith::CeilDivSIOp,
+                        arith::FloorDivSIOp>();
     target.addDynamicallyLegalOp<AtomicRMWOp>([](AtomicRMWOp op) {
       return op.getKind() != AtomicRMWKind::maxf &&
              op.getKind() != AtomicRMWKind::minf;

diff  --git a/mlir/test/Dialect/Arithmetic/expand-ops.mlir b/mlir/test/Dialect/Arithmetic/expand-ops.mlir
index 23ab1267e0ab6..a1bd39208be37 100644
--- a/mlir/test/Dialect/Arithmetic/expand-ops.mlir
+++ b/mlir/test/Dialect/Arithmetic/expand-ops.mlir
@@ -111,3 +111,37 @@ func @floordivi_index(%arg0: index, %arg1: index) -> (index) {
 // CHECK:           [[CMP2:%.+]] = arith.ori [[TERM1]], [[TERM2]] : i1
 // CHECK:           [[RES:%.+]] = select [[CMP2]], [[TRUE3]], [[FALSE]] : index
 }
+
+// -----
+
+// Test expansion of unsigned ceil division on i32 (REM holds quotient + 1).
+// CHECK-LABEL:       func @ceildivui
+// CHECK-SAME:     ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32) -> i32 {
+func @ceildivui(%arg0: i32, %arg1: i32) -> (i32) {
+  %res = arith.ceildivui %arg0, %arg1 : i32
+  return %res : i32
+// CHECK:           [[ZERO:%.+]] = arith.constant 0 : i32
+// CHECK:           [[ISZERO:%.+]] = arith.cmpi eq, %arg0, [[ZERO]] : i32
+// CHECK:           [[ONE:%.+]] = arith.constant 1 : i32
+// CHECK:           [[SUB:%.+]] = arith.subi %arg0, [[ONE]] : i32
+// CHECK:           [[DIV:%.+]] = arith.divui [[SUB]], %arg1 : i32
+// CHECK:           [[REM:%.+]] = arith.addi [[DIV]], [[ONE]] : i32
+// CHECK:           [[RES:%.+]] = select [[ISZERO]], [[ZERO]], [[REM]] : i32
+}
+
+// -----
+
+// Test expansion of unsigned ceil division on index (REM holds quotient + 1).
+// CHECK-LABEL:       func @ceildivui_index
+// CHECK-SAME:     ([[ARG0:%.+]]: index, [[ARG1:%.+]]: index) -> index {
+func @ceildivui_index(%arg0: index, %arg1: index) -> (index) {
+  %res = arith.ceildivui %arg0, %arg1 : index
+  return %res : index
+// CHECK:           [[ZERO:%.+]] = arith.constant 0 : index
+// CHECK:           [[ISZERO:%.+]] = arith.cmpi eq, %arg0, [[ZERO]] : index
+// CHECK:           [[ONE:%.+]] = arith.constant 1 : index
+// CHECK:           [[SUB:%.+]] = arith.subi %arg0, [[ONE]] : index
+// CHECK:           [[DIV:%.+]] = arith.divui [[SUB]], %arg1 : index
+// CHECK:           [[REM:%.+]] = arith.addi [[DIV]], [[ONE]] : index
+// CHECK:           [[RES:%.+]] = select [[ISZERO]], [[ZERO]], [[REM]] : index
+}

diff  --git a/mlir/test/Integration/Dialect/Standard/CPU/test-ceil-floor-pos-neg.mlir b/mlir/test/Integration/Dialect/Standard/CPU/test-ceil-floor-pos-neg.mlir
index eaadd9908b7ed..d55d39b93dbef 100644
--- a/mlir/test/Integration/Dialect/Standard/CPU/test-ceil-floor-pos-neg.mlir
+++ b/mlir/test/Integration/Dialect/Standard/CPU/test-ceil-floor-pos-neg.mlir
@@ -17,6 +17,7 @@ func @entry() {
   %c20 = arith.constant 20: i32
   %c10 = arith.constant 10: i32
   %cmin10 = arith.constant -10: i32
+  %cmax_int = arith.constant 2147483647: i32
   %A = memref.alloc() : memref<40xi32>
 
   // print numerator
@@ -64,20 +65,39 @@ func @entry() {
   }
   call @transfer_read_2d(%A, %c0) : (memref<40xi32>, index) -> ()
 
+  // test with ceildivui(*, 10)
+  affine.for %i = 0 to 40  {
+    %ii = arith.index_cast %i: index to i32
+    %val = arith.ceildivui %ii, %c10 : i32
+    memref.store %val, %A[%i] : memref<40xi32>
+  }
+  call @transfer_read_2d(%A, %c0) : (memref<40xi32>, index) -> ()
+
+  // test with ceildivui(* - 20, max_int)
+  affine.for %i = 0 to 40  {
+    %ii = arith.index_cast %i: index to i32
+    %ii30 = arith.subi %ii, %c20 : i32
+    %val = arith.ceildivui %ii30, %cmax_int : i32
+    memref.store %val, %A[%i] : memref<40xi32>
+  }
+  call @transfer_read_2d(%A, %c0) : (memref<40xi32>, index) -> ()
+
   memref.dealloc %A : memref<40xi32>
   return
 }
 
 // List below is aligned for easy manual check
-// legend: num, ceil(num, 10), floor(num, 10), ceil(num, -10), floor(num, -10)
+// legend: num, signed_ceil(num, 10), floor(num, 10), signed_ceil(num, -10), floor(num, -10), unsigned_ceil(num + 20, 10), unsigned_ceil(num as u32, max_int)
 //  ( -20, -19, -18, -17, -16, -15, -14, -13, -12, -11, -10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19 )
 //  (  -2,  -1,  -1,  -1,  -1,  -1,  -1,  -1,  -1,  -1,  -1,  0,  0,  0,  0,  0,  0,  0,  0,  0, 0,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  2,  2,  2,  2,  2,  2,  2,  2,  2 )
-//  (  -2,  -2,  -2,  -2,  -2,  -2,  -2,  -2,  -2,  -2,  -1, -1,  -1,-1, -1, -1, -1, -1, -1, -1, 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1 )
-//  (   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   1,  1,   1, 1,  1,  1,  1,  1,  1,  1, 0,  0,  0,  0,  0,  0,  0,  0,  0,  0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 )
-//  (   2,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,  0,   0, 0,  0,  0,  0,  0,  0,  0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -2, -2, -2, -2 )
+//  (  -2,  -2,  -2,  -2,  -2,  -2,  -2,  -2,  -2,  -2,  -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1 )
+//  (   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   1,  1,  1,  1,  1,  1,  1,  1,  1,  1, 0,  0,  0,  0,  0,  0,  0,  0,  0,  0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 )
+//  (   2,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,  0,  0,  0,  0,  0,  0,  0,  0,  0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -2, -2, -2, -2 )
 
 // CHECK:( -20, -19, -18, -17, -16, -15, -14, -13, -12, -11, -10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19 )
 // CHECK:( -2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2 )
 // CHECK:( -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
 // CHECK:( 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 )
 // CHECK:( 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -2, -2, -2, -2 )
+// CHECK:( 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4 )
+// CHECK:( 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )

diff  --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index f12ff02aff383..2a0c7a34a5e0a 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -1028,6 +1028,26 @@ func @tensor_arith.ceildivsi_by_one(%arg0: tensor<4x5xi32>) -> tensor<4x5xi32> {
 
 // -----
 
+// CHECK-LABEL: func @arith.ceildivui_by_one
+// CHECK-SAME: %[[ARG:[a-zA-Z0-9]+]]
+func @arith.ceildivui_by_one(%arg0: i32) -> (i32) {
+  %c1 = arith.constant 1 : i32
+  %res = arith.ceildivui %arg0, %c1 : i32 // Division by one folds to %arg0.
+  // CHECK: return %[[ARG]]
+  return %res : i32
+}
+
+// CHECK-LABEL: func @tensor_arith.ceildivui_by_one
+// CHECK-SAME: %[[ARG:[a-zA-Z0-9]+]]
+func @tensor_arith.ceildivui_by_one(%arg0: tensor<4x5xi32>) -> tensor<4x5xi32> {
+  %c1 = arith.constant dense<1> : tensor<4x5xi32>
+  %res = arith.ceildivui %arg0, %c1 : tensor<4x5xi32> // Splat-of-one divisor folds to %arg0.
+  // CHECK: return %[[ARG]]
+  return %res : tensor<4x5xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func @memref_cast_folding_subview
 func @memref_cast_folding_subview(%arg0: memref<4x5xf32>, %i: index) -> (memref<?x?xf32, offset:? , strides: [?, ?]>) {
   %0 = memref.cast %arg0 : memref<4x5xf32> to memref<?x?xf32>

diff  --git a/mlir/test/Transforms/constant-fold.mlir b/mlir/test/Transforms/constant-fold.mlir
index edd40bafda5df..5406a8588ce4b 100644
--- a/mlir/test/Transforms/constant-fold.mlir
+++ b/mlir/test/Transforms/constant-fold.mlir
@@ -478,6 +478,44 @@ func @simple_arith.ceildivsi() -> (i32, i32, i32, i32, i32) {
 
 // -----
 
+// CHECK-LABEL: func @simple_arith.ceildivui
+func @simple_arith.ceildivui() -> (i32, i32, i32, i32, i32) {
+  // CHECK-DAG: [[C0:%.+]] = arith.constant 0
+  %z = arith.constant 0 : i32
+  // CHECK-DAG: [[C6:%.+]] = arith.constant 7
+  %0 = arith.constant 7 : i32 // Captured as C6 above; the value is 7.
+  %1 = arith.constant 2 : i32
+
+  // ceildivui(7, 2) = ceil(7 / 2) = 4
+  // CHECK-NEXT: [[C3:%.+]] = arith.constant 4 : i32
+  %2 = arith.ceildivui %0, %1 : i32
+
+  %3 = arith.constant -2 : i32
+
+  // ceildivui(7, -2): -2 is 2^32 - 2 when unsigned, so ceil(7 / 4294967294) = 1
+  // CHECK-NEXT: [[CM1:%.+]] = arith.constant 1 : i32
+  %4 = arith.ceildivui %0, %3 : i32
+
+  %5 = arith.constant -8 : i32
+
+  // ceildivui(-8, 2) = (2^32 - 8) / 2 = 2147483644
+  // CHECK-NEXT: [[CM4:%.+]] = arith.constant 2147483644 : i32
+  %6 = arith.ceildivui %5, %1 : i32
+
+  %7 = arith.constant -15 : i32
+
+  // ceildivui(-15, -2) = ceil((2^32 - 15) / (2^32 - 2)) = 1; no second constant 1 expected
+  // CHECK-NOT: arith.constant 1 : i32
+  %8 = arith.ceildivui %7, %3 : i32
+
+  // CHECK-NEXT: [[XZ:%.+]] = arith.ceildivui [[C6]], [[C0]]
+  %9 = arith.ceildivui %0, %z : i32 // Division by zero is left unfolded.
+
+  return %2, %4, %6, %8, %9 : i32, i32, i32, i32, i32
+}
+
+// -----
+
 // CHECK-LABEL: func @simple_arith.remsi
 func @simple_arith.remsi(%a : i32) -> (i32, i32, i32) {
   %0 = arith.constant 5 : i32


        


More information about the Mlir-commits mailing list