[Mlir-commits] [mlir] 0795715 - [mlir][std] Add SignedCeilDivIOp and SignedFloorDivIOp with std to std lowering triggered by -std-expand-divs option. The new operations support positive/negative numerator/denominator numbers.

Alexandre Eichenberger llvmlistbot at llvm.org
Wed Nov 4 11:18:01 PST 2020


Author: Alexandre Eichenberger
Date: 2020-11-04T14:16:23-05:00
New Revision: 0795715616416382717d5302f33de5bd10cfec96

URL: https://github.com/llvm/llvm-project/commit/0795715616416382717d5302f33de5bd10cfec96
DIFF: https://github.com/llvm/llvm-project/commit/0795715616416382717d5302f33de5bd10cfec96.diff

LOG: [mlir][std] Add SignedCeilDivIOp and SignedFloorDivIOp with std to std lowering triggered by -std-expand-divs option. The new operations support positive/negative numerator/denominator numbers.

Differential Revision: https://reviews.llvm.org/D89726

Signed-off-by: Alexandre Eichenberger <alexe at us.ibm.com>

Added: 
    mlir/integration_test/Dialect/Standard/CPU/test-ceil-floor-pos-neg.mlir
    mlir/lib/Dialect/StandardOps/Transforms/StdExpandDivs.cpp
    mlir/test/Dialect/Standard/std-expand-divs.mlir

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
    mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
    mlir/test/IR/core-ops.mlir
    mlir/test/Transforms/canonicalize.mlir
    mlir/test/Transforms/constant-fold.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index e88cb655f63f..793a441a9c7a 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2823,6 +2823,64 @@ def SignedDivIOp : IntArithmeticOp<"divi_signed"> {
   let hasFolder = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// SignedFloorDivIOp
+//===----------------------------------------------------------------------===//
+
+def SignedFloorDivIOp : IntArithmeticOp<"floordivi_signed"> {
+  let summary = "signed floor integer division operation";
+  let description = [{
+    Syntax:
+
+    ```
+    operation ::= ssa-id `=` `floordivi_signed` ssa-use `,` ssa-use `:` type
+    ```
+
+    Signed integer division. Rounds towards negative infinity, i.e. `5 / -2 = -3`
+    (whereas `divi_signed`, which rounds towards zero, yields `-2`).
+
+    Note: the semantics of division by zero or signed division overflow (minimum
+    value divided by -1) is TBD; do NOT assume any specific behavior.
+
+    Example:
+
+    ```mlir
+    // Scalar signed integer division.
+    %a = floordivi_signed %b, %c : i64
+    ```
+  }];
+  let hasFolder = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// SignedCeilDivIOp
+//===----------------------------------------------------------------------===//
+
+def SignedCeilDivIOp : IntArithmeticOp<"ceildivi_signed"> {
+  let summary = "signed ceil integer division operation";
+  let description = [{
+    Syntax:
+
+    ```
+    operation ::= ssa-id `=` `ceildivi_signed` ssa-use `,` ssa-use `:` type
+    ```
+
+    Signed integer division. Rounds towards positive infinity, i.e. `7 / 2 = 4`
+    and `7 / -2 = -3`.
+
+    Note: the semantics of division by zero or signed division overflow (minimum
+    value divided by -1) is TBD; do NOT assume any specific behavior.
+
+    Example:
+
+    ```mlir
+    // Scalar signed integer division.
+    %a = ceildivi_signed %b, %c : i64
+    ```
+  }];
+  let hasFolder = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // SignedRemIOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
index 76fa79a77b25..d592842520c0 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
@@ -41,6 +41,17 @@ std::unique_ptr<Pass> createStdBufferizePass();
 /// Creates an instance of func bufferization pass.
 std::unique_ptr<Pass> createFuncBufferizePass();
 
+/// Creates an instance of the StdExpandDivs pass that legalizes Std dialect
+/// divisions so that they can be converted to LLVM. For example,
+/// `std.ceildivi_signed` gets expanded into a number of std operations, which
+/// can be lowered to LLVM.
+std::unique_ptr<Pass> createStdExpandDivsPass();
+
+/// Collects patterns that expand `std.ceildivi_signed` and
+/// `std.floordivi_signed` into simpler Std dialect operations.
+void populateStdExpandDivsRewritePatterns(MLIRContext *context,
+                                          OwningRewritePatternList &patterns);
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
index b0b172c33e82..d91e3be8d4cb 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
@@ -22,6 +22,16 @@ def StdBufferize : FunctionPass<"std-bufferize"> {
   let dependentDialects = ["scf::SCFDialect"];
 }
 
+def StdExpandDivs : FunctionPass<"std-expand-divs"> {
+  let summary = "Legalize div std dialect operations to be convertible to LLVM.";
+  let description = [{
+    Expands `std.ceildivi_signed` and `std.floordivi_signed` into sequences of
+    simpler std operations (`constant`, `cmpi`, `select`, `addi`, `subi`,
+    `divi_signed`, `and`, `or`) that can be lowered to the LLVM dialect.
+  }];
+  let constructor = "mlir::createStdExpandDivsPass()";
+}
+
 def FuncBufferize : Pass<"func-bufferize", "ModuleOp"> {
   let summary = "Bufferize func/call/return ops";
   let description = [{

diff  --git a/mlir/integration_test/Dialect/Standard/CPU/test-ceil-floor-pos-neg.mlir b/mlir/integration_test/Dialect/Standard/CPU/test-ceil-floor-pos-neg.mlir
new file mode 100644
index 000000000000..86bb626d627b
--- /dev/null
+++ b/mlir/integration_test/Dialect/Standard/CPU/test-ceil-floor-pos-neg.mlir
@@ -0,0 +1,82 @@
+// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-std -std-expand-divs -convert-vector-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void  \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+func @transfer_read_1d(%A : memref<40xi32>, %base1: index) {
+  %i42 = constant -42: i32
+  %f = vector.transfer_read %A[%base1], %i42
+      {permutation_map = affine_map<(d0) -> (d0)>} :
+    memref<40xi32>, vector<40xi32>
+  vector.print %f: vector<40xi32>
+  return
+}
+
+func @entry() {
+  %c0 = constant 0: index
+  %c20 = constant 20: i32
+  %c10 = constant 10: i32
+  %cmin10 = constant -10: i32
+  %A = alloc() : memref<40xi32>
+
+  // Print the numerators: i - 20 for i in [0, 40).
+  affine.for %i = 0 to 40  {
+    %ii = index_cast %i: index to i32
+    %num = subi %ii, %c20 : i32
+    store %num, %A[%i] : memref<40xi32>
+  }
+  call @transfer_read_1d(%A, %c0) : (memref<40xi32>, index) -> ()
+
+  // test with ceil(*, 10)
+  affine.for %i = 0 to 40  {
+    %ii = index_cast %i: index to i32
+    %num = subi %ii, %c20 : i32
+    %val = ceildivi_signed %num, %c10 : i32
+    store %val, %A[%i] : memref<40xi32>
+  }
+  call @transfer_read_1d(%A, %c0) : (memref<40xi32>, index) -> ()
+
+  // test with floor(*, 10)
+  affine.for %i = 0 to 40  {
+    %ii = index_cast %i: index to i32
+    %num = subi %ii, %c20 : i32
+    %val = floordivi_signed %num, %c10 : i32
+    store %val, %A[%i] : memref<40xi32>
+  }
+  call @transfer_read_1d(%A, %c0) : (memref<40xi32>, index) -> ()
+
+  // test with ceil(*, -10)
+  affine.for %i = 0 to 40  {
+    %ii = index_cast %i: index to i32
+    %num = subi %ii, %c20 : i32
+    %val = ceildivi_signed %num, %cmin10 : i32
+    store %val, %A[%i] : memref<40xi32>
+  }
+  call @transfer_read_1d(%A, %c0) : (memref<40xi32>, index) -> ()
+
+  // test with floor(*, -10)
+  affine.for %i = 0 to 40  {
+    %ii = index_cast %i: index to i32
+    %num = subi %ii, %c20 : i32
+    %val = floordivi_signed %num, %cmin10 : i32
+    store %val, %A[%i] : memref<40xi32>
+  }
+  call @transfer_read_1d(%A, %c0) : (memref<40xi32>, index) -> ()
+
+  dealloc %A : memref<40xi32>
+  return
+}
+
+// List below is aligned for easy manual check
+// legend: num, ceil(num, 10), floor(num, 10), ceil(num, -10), floor(num, -10)
+//  ( -20, -19, -18, -17, -16, -15, -14, -13, -12, -11, -10,  -9,  -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,  14,  15,  16,  17,  18,  19 )
+//  (  -2,  -1,  -1,  -1,  -1,  -1,  -1,  -1,  -1,  -1,  -1,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   2,   2,   2,   2,   2,   2,   2,   2,   2 )
+//  (  -2,  -2,  -2,  -2,  -2,  -2,  -2,  -2,  -2,  -2,  -1,  -1,  -1,  -1,  -1,  -1,  -1,  -1,  -1,  -1,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1 )
+//  (   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  -1,  -1,  -1,  -1,  -1,  -1,  -1,  -1,  -1,  -1 )
+//  (   2,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  -1,  -1,  -1,  -1,  -1,  -1,  -1,  -1,  -1,  -1,  -2,  -2,  -2,  -2,  -2,  -2,  -2,  -2,  -2 )
+
+// CHECK:( -20, -19, -18, -17, -16, -15, -14, -13, -12, -11, -10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19 )
+// CHECK:( -2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2 )
+// CHECK:( -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
+// CHECK:( 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 )
+// CHECK:( 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -2, -2, -2, -2 )

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 264f93443a82..9b5875e70793 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2893,6 +2893,154 @@ OpFoldResult SignedDivIOp::fold(ArrayRef<Attribute> operands) {
   return overflowOrDiv0 ? Attribute() : result;
 }
 
+//===----------------------------------------------------------------------===//
+// Overflow-accumulating APInt helpers for floor/ceil division folding
+//===----------------------------------------------------------------------===//
+
+// APInt's `*_ov` methods *assign* their overflow flag instead of OR-ing into
+// it, so chaining several of them on one flag can silently clear an overflow
+// reported by an earlier step. The helpers below only ever set the flag.
+
+/// Returns `a - b`; sets `overflow` if the subtraction overflows.
+static APInt ssubAccumOv(const APInt &a, const APInt &b, bool &overflow) {
+  bool stepOverflow = false;
+  APInt res = a.ssub_ov(b, stepOverflow);
+  overflow |= stepOverflow;
+  return res;
+}
+
+/// Returns `a + b`; sets `overflow` if the addition overflows.
+static APInt saddAccumOv(const APInt &a, const APInt &b, bool &overflow) {
+  bool stepOverflow = false;
+  APInt res = a.sadd_ov(b, stepOverflow);
+  overflow |= stepOverflow;
+  return res;
+}
+
+/// Returns `a / b` rounded towards zero; sets `overflow` if the division
+/// overflows (minimum signed value divided by -1). `b` must be nonzero.
+static APInt sdivAccumOv(const APInt &a, const APInt &b, bool &overflow) {
+  bool stepOverflow = false;
+  APInt res = a.sdiv_ov(b, stepOverflow);
+  overflow |= stepOverflow;
+  return res;
+}
+
+/// Returns ceil(a / b), computed as (a - 1) / b + 1. Only valid when both `a`
+/// and `b` are strictly positive: for a == 0 and b > 1 it would yield 1
+/// instead of 0, so callers must route zero numerators elsewhere. Sets
+/// `overflow` if any intermediate step overflows.
+static APInt signedCeilNonnegInputs(APInt a, APInt b, bool &overflow) {
+  APInt one(a.getBitWidth(), 1, true); // Signed value 1.
+  APInt val = sdivAccumOv(ssubAccumOv(a, one, overflow), b, overflow);
+  return saddAccumOv(val, one, overflow);
+}
+
+//===----------------------------------------------------------------------===//
+// SignedFloorDivIOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult SignedFloorDivIOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2 && "binary operation takes two operands");
+
+  // Don't fold if it would overflow or if it requires a division by zero.
+  // The flag is sticky: an overflow in any intermediate step of any element
+  // abandons the whole fold.
+  bool overflowOrDiv0 = false;
+  auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
+    if (overflowOrDiv0 || !b) {
+      overflowOrDiv0 = true;
+      return a;
+    }
+    APInt zero = APInt::getNullValue(a.getBitWidth());
+    if (a.sge(zero) && b.sgt(zero)) {
+      // Both positive (or a is zero), return a / b.
+      return sdivAccumOv(a, b, overflowOrDiv0);
+    }
+    if (a.sle(zero) && b.slt(zero)) {
+      // Both negative (or a is zero), return -a / -b.
+      APInt posA = ssubAccumOv(zero, a, overflowOrDiv0);
+      APInt posB = ssubAccumOv(zero, b, overflowOrDiv0);
+      return sdivAccumOv(posA, posB, overflowOrDiv0);
+    }
+    if (a.slt(zero) && b.sgt(zero)) {
+      // A is negative, b is positive, return - ceil(-a, b).
+      APInt posA = ssubAccumOv(zero, a, overflowOrDiv0);
+      APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0);
+      return ssubAccumOv(zero, ceil, overflowOrDiv0);
+    }
+    // A is positive, b is negative, return - ceil(a, -b).
+    APInt posB = ssubAccumOv(zero, b, overflowOrDiv0);
+    APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0);
+    return ssubAccumOv(zero, ceil, overflowOrDiv0);
+  });
+
+  // Fold out floor division by one. Assumes all tensors of all ones are
+  // splats.
+  if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
+    if (rhs.getValue() == 1)
+      return lhs();
+  } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
+    if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
+      return lhs();
+  }
+
+  return overflowOrDiv0 ? Attribute() : result;
+}
+
+//===----------------------------------------------------------------------===//
+// SignedCeilDivIOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult SignedCeilDivIOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2 && "binary operation takes two operands");
+
+  // Don't fold if it would overflow or if it requires a division by zero.
+  // The flag is sticky: an overflow in any intermediate step of any element
+  // abandons the whole fold.
+  bool overflowOrDiv0 = false;
+  auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
+    if (overflowOrDiv0 || !b) {
+      overflowOrDiv0 = true;
+      return a;
+    }
+    APInt zero = APInt::getNullValue(a.getBitWidth());
+    if (a.sgt(zero) && b.sgt(zero)) {
+      // Both positive, return ceil(a, b).
+      return signedCeilNonnegInputs(a, b, overflowOrDiv0);
+    }
+    if (a.slt(zero) && b.slt(zero)) {
+      // Both negative, return ceil(-a, -b).
+      APInt posA = ssubAccumOv(zero, a, overflowOrDiv0);
+      APInt posB = ssubAccumOv(zero, b, overflowOrDiv0);
+      return signedCeilNonnegInputs(posA, posB, overflowOrDiv0);
+    }
+    if (a.slt(zero) && b.sgt(zero)) {
+      // A is negative, b is positive, return - ( -a / b).
+      APInt posA = ssubAccumOv(zero, a, overflowOrDiv0);
+      APInt div = sdivAccumOv(posA, b, overflowOrDiv0);
+      return ssubAccumOv(zero, div, overflowOrDiv0);
+    }
+    // Remaining cases: a >= 0 and b < 0, or a == 0 and b > 0 (which yields
+    // 0); return - (a / -b).
+    APInt posB = ssubAccumOv(zero, b, overflowOrDiv0);
+    APInt div = sdivAccumOv(a, posB, overflowOrDiv0);
+    return ssubAccumOv(zero, div, overflowOrDiv0);
+  });
+
+  // Fold out ceil division by one. Assumes all tensors of all ones are
+  // splats.
+  if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
+    if (rhs.getValue() == 1)
+      return lhs();
+  } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
+    if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
+      return lhs();
+  }
+
+  return overflowOrDiv0 ? Attribute() : result;
+}
+
 //===----------------------------------------------------------------------===//
 // SignedRemIOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
index 1334e7f83d55..7a63f234094f 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
@@ -5,6 +5,7 @@ add_mlir_dialect_library(MLIRStandardOpsTransforms
   ExpandTanh.cpp
   FuncBufferize.cpp
   FuncConversions.cpp
+  StdExpandDivs.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/StandardOps/Transforms

diff  --git a/mlir/lib/Dialect/StandardOps/Transforms/StdExpandDivs.cpp b/mlir/lib/Dialect/StandardOps/Transforms/StdExpandDivs.cpp
new file mode 100644
index 000000000000..224ecefabf6b
--- /dev/null
+++ b/mlir/lib/Dialect/StandardOps/Transforms/StdExpandDivs.cpp
@@ -0,0 +1,158 @@
+//===- StdExpandDivs.cpp - Code to prepare Std for lowering Divs to LLVM  -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements Std transformations that expand Div operations to help
+// the lowering to LLVM. Currently implemented transformations are Ceil and
+// Floor for Signed Integers.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+/// Creates a constant of `type` equal to `value`. For vector and tensor types
+/// the constant is a splat, so every element equals `value`.
+static Value createIntConstant(PatternRewriter &rewriter, Location loc,
+                               Type type, int64_t value) {
+  if (auto shapedType = type.dyn_cast<ShapedType>()) {
+    Attribute element =
+        rewriter.getIntegerAttr(shapedType.getElementType(), value);
+    return rewriter.create<ConstantOp>(
+        loc, DenseElementsAttr::get(shapedType, element));
+  }
+  return rewriter.create<ConstantOp>(loc, rewriter.getIntegerAttr(type, value));
+}
+
+namespace {
+
+/// Expands SignedCeilDivIOp (n, m) into
+///   1) x = (m > 0) ? -1 : 1
+///   2) (n*m>0) ? ((n+x) / m) + 1 : - (-n / m)
+struct SignedCeilDivIOpConverter : public OpRewritePattern<SignedCeilDivIOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(SignedCeilDivIOp op,
+                                PatternRewriter &rewriter) const final {
+    Location loc = op.getLoc();
+    Type type = op.getType();
+    Value a = op.lhs();
+    Value b = op.rhs();
+    Value plusOne = createIntConstant(rewriter, loc, type, 1);
+    Value zero = createIntConstant(rewriter, loc, type, 0);
+    Value minusOne = createIntConstant(rewriter, loc, type, -1);
+    // Compute x = (b>0) ? -1 : 1.
+    Value compare = rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, b, zero);
+    Value x = rewriter.create<SelectOp>(loc, compare, minusOne, plusOne);
+    // Compute positive res: 1 + ((x+a)/b).
+    Value xPlusA = rewriter.create<AddIOp>(loc, x, a);
+    Value xPlusADivB = rewriter.create<SignedDivIOp>(loc, xPlusA, b);
+    Value posRes = rewriter.create<AddIOp>(loc, plusOne, xPlusADivB);
+    // Compute negative res: - ((-a)/b).
+    Value minusA = rewriter.create<SubIOp>(loc, zero, a);
+    Value minusADivB = rewriter.create<SignedDivIOp>(loc, minusA, b);
+    Value negRes = rewriter.create<SubIOp>(loc, zero, minusADivB);
+    // Result is (a*b>0) ? pos result : neg result.
+    // Note, we want to avoid using a*b because of possible overflow.
+    // The cases that matter are a>0, a==0, a<0, b>0 and b<0. We do
+    // not particularly care if a*b>0 is true or false when b is zero
+    // as this will result in an illegal divide. So `a*b>0` can be reformulated
+    // as `(a<0 && b<0) || (a>0 && b>0)` or `(a<0 && b<0) || (a>0 && b>=0)`.
+    // We pick the first expression here.
+    Value aNeg = rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, a, zero);
+    Value aPos = rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, a, zero);
+    Value bNeg = rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, b, zero);
+    Value bPos = rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, b, zero);
+    Value firstTerm = rewriter.create<AndOp>(loc, aNeg, bNeg);
+    Value secondTerm = rewriter.create<AndOp>(loc, aPos, bPos);
+    Value compareRes = rewriter.create<OrOp>(loc, firstTerm, secondTerm);
+    Value res = rewriter.create<SelectOp>(loc, compareRes, posRes, negRes);
+    // Perform substitution and return success.
+    rewriter.replaceOp(op, {res});
+    return success();
+  }
+};
+
+/// Expands SignedFloorDivIOp (n, m) into
+///   1)  x = (m<0) ? 1 : -1
+///   2)  return (n*m<0) ? - ((-n+x) / m) -1 : n / m
+struct SignedFloorDivIOpConverter : public OpRewritePattern<SignedFloorDivIOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(SignedFloorDivIOp op,
+                                PatternRewriter &rewriter) const final {
+    Location loc = op.getLoc();
+    Type type = op.getType();
+    Value a = op.lhs();
+    Value b = op.rhs();
+    Value plusOne = createIntConstant(rewriter, loc, type, 1);
+    Value zero = createIntConstant(rewriter, loc, type, 0);
+    Value minusOne = createIntConstant(rewriter, loc, type, -1);
+    // Compute x = (b<0) ? 1 : -1.
+    Value compare = rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, b, zero);
+    Value x = rewriter.create<SelectOp>(loc, compare, plusOne, minusOne);
+    // Compute negative res: -1 - ((x-a)/b).
+    Value xMinusA = rewriter.create<SubIOp>(loc, x, a);
+    Value xMinusADivB = rewriter.create<SignedDivIOp>(loc, xMinusA, b);
+    Value negRes = rewriter.create<SubIOp>(loc, minusOne, xMinusADivB);
+    // Compute positive res: a/b.
+    Value posRes = rewriter.create<SignedDivIOp>(loc, a, b);
+    // Result is (a*b<0) ? negative result : positive result.
+    // Note, we want to avoid using a*b because of possible overflow.
+    // The cases that matter are a>0, a==0, a<0, b>0 and b<0. We do
+    // not particularly care if a*b<0 is true or false when b is zero
+    // as this will result in an illegal divide. So `a*b<0` can be reformulated
+    // as `(a<0 && b>0) || (a>0 && b<0)` or `(a<0 && b>=0) || (a>0 && b<0)`.
+    // We pick the first expression here.
+    Value aNeg = rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, a, zero);
+    Value aPos = rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, a, zero);
+    Value bNeg = rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, b, zero);
+    Value bPos = rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, b, zero);
+    Value firstTerm = rewriter.create<AndOp>(loc, aNeg, bPos);
+    Value secondTerm = rewriter.create<AndOp>(loc, aPos, bNeg);
+    Value compareRes = rewriter.create<OrOp>(loc, firstTerm, secondTerm);
+    Value res = rewriter.create<SelectOp>(loc, compareRes, negRes, posRes);
+    // Perform substitution and return success.
+    rewriter.replaceOp(op, {res});
+    return success();
+  }
+};
+
+} // namespace
+
+namespace {
+struct StdExpandDivs : public StdExpandDivsBase<StdExpandDivs> {
+  void runOnFunction() override;
+};
+} // namespace
+
+void StdExpandDivs::runOnFunction() {
+  MLIRContext &ctx = getContext();
+
+  OwningRewritePatternList patterns;
+  populateStdExpandDivsRewritePatterns(&ctx, patterns);
+
+  ConversionTarget target(ctx);
+  target.addLegalDialect<StandardOpsDialect>();
+  target.addIllegalOp<SignedCeilDivIOp, SignedFloorDivIOp>();
+  if (failed(
+          applyPartialConversion(getFunction(), target, std::move(patterns))))
+    signalPassFailure();
+}
+
+void mlir::populateStdExpandDivsRewritePatterns(
+    MLIRContext *context, OwningRewritePatternList &patterns) {
+  patterns.insert<SignedCeilDivIOpConverter, SignedFloorDivIOpConverter>(
+      context);
+}
+
+std::unique_ptr<Pass> mlir::createStdExpandDivsPass() {
+  return std::make_unique<StdExpandDivs>();
+}

diff  --git a/mlir/test/Dialect/Standard/std-expand-divs.mlir b/mlir/test/Dialect/Standard/std-expand-divs.mlir
new file mode 100644
index 000000000000..354a7863f423
--- /dev/null
+++ b/mlir/test/Dialect/Standard/std-expand-divs.mlir
@@ -0,0 +1,58 @@
+// RUN: mlir-opt -std-expand-divs %s -split-input-file | FileCheck %s
+
+// Test floor divide with signed integer
+// CHECK-LABEL:       func @floordivi
+// CHECK-SAME:     ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32) -> i32 {
+func @floordivi(%arg0: i32, %arg1: i32) -> (i32) {
+  %res = floordivi_signed %arg0, %arg1 : i32
+  return %res : i32
+// CHECK:           [[ONE:%.+]] = constant 1 : i32
+// CHECK:           [[ZERO:%.+]] = constant 0 : i32
+// CHECK:           [[MIN1:%.+]] = constant -1 : i32
+// CHECK:           [[CMP1:%.+]] = cmpi "slt", [[ARG1]], [[ZERO]] : i32
+// CHECK:           [[X:%.+]] = select [[CMP1]], [[ONE]], [[MIN1]] : i32
+// CHECK:           [[TRUE1:%.+]] = subi [[X]], [[ARG0]] : i32
+// CHECK:           [[TRUE2:%.+]] = divi_signed [[TRUE1]], [[ARG1]] : i32
+// CHECK:           [[TRUE3:%.+]] = subi [[MIN1]], [[TRUE2]] : i32
+// CHECK:           [[FALSE:%.+]] = divi_signed [[ARG0]], [[ARG1]] : i32
+// CHECK:           [[NNEG:%.+]] = cmpi "slt", [[ARG0]], [[ZERO]] : i32
+// CHECK:           [[NPOS:%.+]] = cmpi "sgt", [[ARG0]], [[ZERO]] : i32
+// CHECK:           [[MNEG:%.+]] = cmpi "slt", [[ARG1]], [[ZERO]] : i32
+// CHECK:           [[MPOS:%.+]] = cmpi "sgt", [[ARG1]], [[ZERO]] : i32
+// CHECK:           [[TERM1:%.+]] = and [[NNEG]], [[MPOS]] : i1
+// CHECK:           [[TERM2:%.+]] = and [[NPOS]], [[MNEG]] : i1
+// CHECK:           [[CMP2:%.+]] = or [[TERM1]], [[TERM2]] : i1
+// CHECK:           [[RES:%.+]] = select [[CMP2]], [[TRUE3]], [[FALSE]] : i32
+// CHECK:           return [[RES]] : i32
+}
+
+// -----
+
+// Test ceil divide with signed integer
+// CHECK-LABEL:       func @ceildivi
+// CHECK-SAME:     ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32) -> i32 {
+func @ceildivi(%arg0: i32, %arg1: i32) -> (i32) {
+  %res = ceildivi_signed %arg0, %arg1 : i32
+  return %res : i32
+
+// CHECK:           [[ONE:%.+]] = constant 1 : i32
+// CHECK:           [[ZERO:%.+]] = constant 0 : i32
+// CHECK:           [[MINONE:%.+]] = constant -1 : i32
+// CHECK:           [[CMP1:%.+]] = cmpi "sgt", [[ARG1]], [[ZERO]] : i32
+// CHECK:           [[X:%.+]] = select [[CMP1]], [[MINONE]], [[ONE]] : i32
+// CHECK:           [[TRUE1:%.+]] = addi [[X]], [[ARG0]] : i32
+// CHECK:           [[TRUE2:%.+]] = divi_signed [[TRUE1]], [[ARG1]] : i32
+// CHECK:           [[TRUE3:%.+]] = addi [[ONE]], [[TRUE2]] : i32
+// CHECK:           [[FALSE1:%.+]] = subi [[ZERO]], [[ARG0]] : i32
+// CHECK:           [[FALSE2:%.+]] = divi_signed [[FALSE1]], [[ARG1]] : i32
+// CHECK:           [[FALSE3:%.+]] = subi [[ZERO]], [[FALSE2]] : i32
+// CHECK:           [[NNEG:%.+]] = cmpi "slt", [[ARG0]], [[ZERO]] : i32
+// CHECK:           [[NPOS:%.+]] = cmpi "sgt", [[ARG0]], [[ZERO]] : i32
+// CHECK:           [[MNEG:%.+]] = cmpi "slt", [[ARG1]], [[ZERO]] : i32
+// CHECK:           [[MPOS:%.+]] = cmpi "sgt", [[ARG1]], [[ZERO]] : i32
+// CHECK:           [[TERM1:%.+]] = and [[NNEG]], [[MNEG]] : i1
+// CHECK:           [[TERM2:%.+]] = and [[NPOS]], [[MPOS]] : i1
+// CHECK:           [[CMP2:%.+]] = or [[TERM1]], [[TERM2]] : i1
+// CHECK:           [[RES:%.+]] = select [[CMP2]], [[TRUE3]], [[FALSE3]] : i32
+// CHECK:           return [[RES]] : i32
+}

diff  --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
index da7394eae784..dfed144fb44e 100644
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -569,6 +569,30 @@ func @standard_instrs(tensor<4x4x?xf32>, f32, i32, index, i64, f16) {
   // CHECK: %{{[0-9]+}} = floorf %arg0 : tensor<4x4x?xf32>
   %166 = floorf %t : tensor<4x4x?xf32>
 
+  // CHECK: %{{[0-9]+}} = floordivi_signed %arg2, %arg2 : i32
+  %167 = floordivi_signed %i, %i : i32
+
+  // CHECK: %{{[0-9]+}} = floordivi_signed %arg3, %arg3 : index
+  %168 = floordivi_signed %idx, %idx : index
+
+  // CHECK: %{{[0-9]+}} = floordivi_signed %cst_5, %cst_5 : vector<42xi32>
+  %169 = floordivi_signed %vci32, %vci32 : vector<42 x i32>
+
+  // CHECK: %{{[0-9]+}} = floordivi_signed %cst_4, %cst_4 : tensor<42xi32>
+  %170 = floordivi_signed %tci32, %tci32 : tensor<42 x i32>
+
+  // CHECK: %{{[0-9]+}} = ceildivi_signed %arg2, %arg2 : i32
+  %171 = ceildivi_signed %i, %i : i32
+
+  // CHECK: %{{[0-9]+}} = ceildivi_signed %arg3, %arg3 : index
+  %172 = ceildivi_signed %idx, %idx : index
+
+  // CHECK: %{{[0-9]+}} = ceildivi_signed %cst_5, %cst_5 : vector<42xi32>
+  %173 = ceildivi_signed %vci32, %vci32 : vector<42 x i32>
+
+  // CHECK: %{{[0-9]+}} = ceildivi_signed %cst_4, %cst_4 : tensor<42xi32>
+  %174 = ceildivi_signed %tci32, %tci32 : tensor<42 x i32>
+
   return
 }
 

diff  --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index dc7be097b0c0..7b8c45cb409b 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -949,6 +949,46 @@ func @tensor_divi_unsigned_by_one(%arg0: tensor<4x5xi32>) -> tensor<4x5xi32> {
 
 // -----
 
+// CHECK-LABEL: func @floordivi_signed_by_one
+// CHECK-SAME: %[[ARG:[a-zA-Z0-9]+]]
+func @floordivi_signed_by_one(%x: i32) -> i32 {
+  %one = constant 1 : i32
+  %quot = floordivi_signed %x, %one : i32
+  // CHECK: return %[[ARG]]
+  return %quot : i32
+}
+
+// CHECK-LABEL: func @tensor_floordivi_signed_by_one
+// CHECK-SAME: %[[ARG:[a-zA-Z0-9]+]]
+func @tensor_floordivi_signed_by_one(%x: tensor<4x5xi32>) -> tensor<4x5xi32> {
+  %ones = constant dense<1> : tensor<4x5xi32>
+  %quot = floordivi_signed %x, %ones : tensor<4x5xi32>
+  // CHECK: return %[[ARG]]
+  return %quot : tensor<4x5xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @ceildivi_signed_by_one
+// CHECK-SAME: %[[ARG:[a-zA-Z0-9]+]]
+func @ceildivi_signed_by_one(%x: i32) -> i32 {
+  %one = constant 1 : i32
+  %quot = ceildivi_signed %x, %one : i32
+  // CHECK: return %[[ARG]]
+  return %quot : i32
+}
+
+// CHECK-LABEL: func @tensor_ceildivi_signed_by_one
+// CHECK-SAME: %[[ARG:[a-zA-Z0-9]+]]
+func @tensor_ceildivi_signed_by_one(%x: tensor<4x5xi32>) -> tensor<4x5xi32> {
+  %ones = constant dense<1> : tensor<4x5xi32>
+  %quot = ceildivi_signed %x, %ones : tensor<4x5xi32>
+  // CHECK: return %[[ARG]]
+  return %quot : tensor<4x5xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func @memref_cast_folding_subview
 func @memref_cast_folding_subview(%arg0: memref<4x5xf32>, %i: index) -> (memref<?x?xf32, offset:? , strides: [?, ?]>) {
   %0 = memref_cast %arg0 : memref<4x5xf32> to memref<?x?xf32>

diff  --git a/mlir/test/Transforms/constant-fold.mlir b/mlir/test/Transforms/constant-fold.mlir
index c75c89877830..234717863046 100644
--- a/mlir/test/Transforms/constant-fold.mlir
+++ b/mlir/test/Transforms/constant-fold.mlir
@@ -402,6 +402,84 @@ func @divi_unsigned_splat_tensor() -> (tensor<4xi32>, tensor<4xi32>, tensor<4xi3
 
 // -----
 
+// CHECK-LABEL: func @simple_floordivi_signed
+func @simple_floordivi_signed() -> (i32, i32, i32, i32, i32) {
+  // CHECK-DAG: [[C0:%.+]] = constant 0
+  %z = constant 0 : i32
+  // CHECK-DAG: [[C7:%.+]] = constant 7
+  %0 = constant 7 : i32
+  %1 = constant 2 : i32
+
+  // floor(7, 2) = 3
+  // CHECK-NEXT: [[C3:%.+]] = constant 3 : i32
+  %2 = floordivi_signed %0, %1 : i32
+
+  %3 = constant -2 : i32
+
+  // floor(7, -2) = -4
+  // CHECK-NEXT: [[CM4:%.+]] = constant -4 : i32
+  %4 = floordivi_signed %0, %3 : i32
+
+  %5 = constant -9 : i32
+
+  // floor(-9, 2) = -5
+  // CHECK-NEXT: [[CM5:%.+]] = constant -5 : i32
+  %6 = floordivi_signed %5, %1 : i32
+
+  %7 = constant -13 : i32
+
+  // floor(-13, -2) = 6
+  // CHECK-NEXT: [[C6:%.+]] = constant 6 : i32
+  %8 = floordivi_signed %7, %3 : i32
+
+  // CHECK-NEXT: [[XZ:%.+]] = floordivi_signed [[C7]], [[C0]]
+  %9 = floordivi_signed %0, %z : i32
+
+  // CHECK-NEXT: return [[C3]], [[CM4]], [[CM5]], [[C6]], [[XZ]]
+  return %2, %4, %6, %8, %9 : i32, i32, i32, i32, i32
+}
+
+// -----
+
+// CHECK-LABEL: func @simple_ceildivi_signed
+func @simple_ceildivi_signed() -> (i32, i32, i32, i32, i32) {
+  // CHECK-DAG: [[C0:%.+]] = constant 0
+  %z = constant 0 : i32
+  // CHECK-DAG: [[C7:%.+]] = constant 7
+  %0 = constant 7 : i32
+  %1 = constant 2 : i32
+
+  // ceil(7, 2) = 4
+  // CHECK-NEXT: [[C4:%.+]] = constant 4 : i32
+  %2 = ceildivi_signed %0, %1 : i32
+
+  %3 = constant -2 : i32
+
+  // ceil(7, -2) = -3
+  // CHECK-NEXT: [[CM3:%.+]] = constant -3 : i32
+  %4 = ceildivi_signed %0, %3 : i32
+
+  %5 = constant -9 : i32
+
+  // ceil(-9, 2) = -4
+  // CHECK-NEXT: [[CM4:%.+]] = constant -4 : i32
+  %6 = ceildivi_signed %5, %1 : i32
+
+  %7 = constant -15 : i32
+
+  // ceil(-15, -2) = 8
+  // CHECK-NEXT: [[C8:%.+]] = constant 8 : i32
+  %8 = ceildivi_signed %7, %3 : i32
+
+  // CHECK-NEXT: [[XZ:%.+]] = ceildivi_signed [[C7]], [[C0]]
+  %9 = ceildivi_signed %0, %z : i32
+
+  // CHECK-NEXT: return [[C4]], [[CM3]], [[CM4]], [[C8]], [[XZ]]
+  return %2, %4, %6, %8, %9 : i32, i32, i32, i32, i32
+}
+
+// -----
+
 // CHECK-LABEL: func @simple_remi_signed
 func @simple_remi_signed(%a : i32) -> (i32, i32, i32) {
   %0 = constant 5 : i32


        


More information about the Mlir-commits mailing list