[Mlir-commits] [mlir] 0795715 - [mlir][std] Add SignedCeilDivIOp and SignedFloorDivIOp with std to std lowering triggered by -std-expand-divs option. The new operations support positive/negative numerator/denominator numbers.
Alexandre Eichenberger
llvmlistbot at llvm.org
Wed Nov 4 11:18:01 PST 2020
Author: Alexandre Eichenberger
Date: 2020-11-04T14:16:23-05:00
New Revision: 0795715616416382717d5302f33de5bd10cfec96
URL: https://github.com/llvm/llvm-project/commit/0795715616416382717d5302f33de5bd10cfec96
DIFF: https://github.com/llvm/llvm-project/commit/0795715616416382717d5302f33de5bd10cfec96.diff
LOG: [mlir][std] Add SignedCeilDivIOp and SignedFloorDivIOp with std to std lowering triggered by -std-expand-divs option. The new operations support positive/negative numerator/denominator numbers.
Differential Revision: https://reviews.llvm.org/D89726
Signed-off-by: Alexandre Eichenberger <alexe at us.ibm.com>
Added:
mlir/integration_test/Dialect/Standard/CPU/test-ceil-floor-pos-neg.mlir
mlir/lib/Dialect/StandardOps/Transforms/StdExpandDivs.cpp
mlir/test/Dialect/Standard/std-expand-divs.mlir
Modified:
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
mlir/test/IR/core-ops.mlir
mlir/test/Transforms/canonicalize.mlir
mlir/test/Transforms/constant-fold.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index e88cb655f63f..793a441a9c7a 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2823,6 +2823,63 @@ def SignedDivIOp : IntArithmeticOp<"divi_signed"> {
let hasFolder = 1;
}
+//===----------------------------------------------------------------------===//
+// SignedFloorDivIOp
+//===----------------------------------------------------------------------===//
+
+def SignedFloorDivIOp : IntArithmeticOp<"floordivi_signed"> {
+  let summary = "signed floor integer division operation";
+  let description = [{
+    Syntax:
+
+    ```
+    operation ::= ssa-id `=` `floordivi_signed` ssa-use `,` ssa-use `:` type
+    ```
+
+    Signed integer division. Rounds towards negative infinity, i.e.
+    `5 / -2 = -3`.
+
+    Note: the semantics of division by zero or signed division overflow
+    (minimum value divided by -1) are TBD; do NOT assume any specific behavior.
+
+    The `-std-expand-divs` pass rewrites this operation into a sequence of
+    other std operations (built around `divi_signed`) that can be lowered to
+    LLVM.
+
+    Example:
+
+    ```mlir
+    // Scalar signed integer division.
+    %a = floordivi_signed %b, %c : i64
+
+    // SIMD vector element-wise division.
+    %f = floordivi_signed %g, %h : vector<4xi32>
+    ```
+  }];
+  let hasFolder = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// SignedCeilDivIOp
+//===----------------------------------------------------------------------===//
+
+def SignedCeilDivIOp : IntArithmeticOp<"ceildivi_signed"> {
+  let summary = "signed ceil integer division operation";
+  let description = [{
+    Syntax:
+
+    ```
+    operation ::= ssa-id `=` `ceildivi_signed` ssa-use `,` ssa-use `:` type
+    ```
+
+    Signed integer division. Rounds towards positive infinity, i.e.
+    `7 / -2 = -3`.
+
+    Note: the semantics of division by zero or signed division overflow
+    (minimum value divided by -1) are TBD; do NOT assume any specific behavior.
+
+    The `-std-expand-divs` pass rewrites this operation into a sequence of
+    other std operations (built around `divi_signed`) that can be lowered to
+    LLVM.
+
+    Example:
+
+    ```mlir
+    // Scalar signed integer division.
+    %a = ceildivi_signed %b, %c : i64
+
+    // SIMD vector element-wise division.
+    %f = ceildivi_signed %g, %h : vector<4xi32>
+    ```
+  }];
+  let hasFolder = 1;
+}
+
//===----------------------------------------------------------------------===//
// SignedRemIOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
index 76fa79a77b25..d592842520c0 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
@@ -41,6 +41,16 @@ std::unique_ptr<Pass> createStdBufferizePass();
/// Creates an instance of func bufferization pass.
std::unique_ptr<Pass> createFuncBufferizePass();
+/// Creates an instance of the StdExpandDivs pass that legalizes Std
+/// dialect Divs to be convertible to LLVM. For example,
+/// `std.ceildivi_signed` gets transformed to a number of std operations,
+/// which can be lowered to LLVM.
+std::unique_ptr<Pass> createStdExpandDivsPass();
+
+/// Adds to `patterns` the rewrite patterns that expand `std.ceildivi_signed`
+/// and `std.floordivi_signed` into sequences of other std operations.
+void populateStdExpandDivsRewritePatterns(MLIRContext *context,
+                                          OwningRewritePatternList &patterns);
+
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
index b0b172c33e82..d91e3be8d4cb 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
@@ -22,6 +22,11 @@ def StdBufferize : FunctionPass<"std-bufferize"> {
let dependentDialects = ["scf::SCFDialect"];
}
+def StdExpandDivs : FunctionPass<"std-expand-divs"> {
+  let summary = "Legalize div std dialect operations to be convertible to LLVM.";
+  let description = [{
+    Expands `std.ceildivi_signed` and `std.floordivi_signed` into sequences of
+    other std operations (`constant`, `cmpi`, `select`, `addi`, `subi`,
+    `divi_signed`, `and` and `or`) that can be lowered to LLVM.
+  }];
+  let constructor = "mlir::createStdExpandDivsPass()";
+}
+
def FuncBufferize : Pass<"func-bufferize", "ModuleOp"> {
let summary = "Bufferize func/call/return ops";
let description = [{
diff --git a/mlir/integration_test/Dialect/Standard/CPU/test-ceil-floor-pos-neg.mlir b/mlir/integration_test/Dialect/Standard/CPU/test-ceil-floor-pos-neg.mlir
new file mode 100644
index 000000000000..86bb626d627b
--- /dev/null
+++ b/mlir/integration_test/Dialect/Standard/CPU/test-ceil-floor-pos-neg.mlir
@@ -0,0 +1,82 @@
+// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-std -std-expand-divs -convert-vector-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+// Reads the 40 i32 elements of %A starting at index %base1 into a single
+// vector (with -42 as the transfer_read padding value) and prints it.
+// NOTE(review): despite its name, this is a 1-D read of a 1-D memref.
+func @transfer_read_2d(%A : memref<40xi32>, %base1: index) {
+ %i42 = constant -42: i32
+ %f = vector.transfer_read %A[%base1], %i42
+ {permutation_map = affine_map<(d0) -> (d0)>} :
+ memref<40xi32>, vector<40xi32>
+ vector.print %f: vector<40xi32>
+ return
+}
+
+// Fills %A with the numerators i - 20 for i in [0, 40) and prints them. It then
+// overwrites %A with ceildivi_signed / floordivi_signed of those numerators by
+// 10 and by -10, printing the buffer after each loop. The five printed vectors
+// are compared against the expected-output lines at the end of this file, so
+// the order of the loops matters.
+func @entry() {
+ %c0 = constant 0: index
+ %c20 = constant 20: i32
+ %c10 = constant 10: i32
+ %cmin10 = constant -10: i32
+ %A = alloc() : memref<40xi32>
+
+ // print numerator
+ // (%ii30 holds the numerator i - 20 despite its name.)
+ affine.for %i = 0 to 40 {
+ %ii = index_cast %i: index to i32
+ %ii30 = subi %ii, %c20 : i32
+ store %ii30, %A[%i] : memref<40xi32>
+ }
+ call @transfer_read_2d(%A, %c0) : (memref<40xi32>, index) -> ()
+
+ // test with ceil(*, 10)
+ affine.for %i = 0 to 40 {
+ %ii = index_cast %i: index to i32
+ %ii30 = subi %ii, %c20 : i32
+ %val = ceildivi_signed %ii30, %c10 : i32
+ store %val, %A[%i] : memref<40xi32>
+ }
+ call @transfer_read_2d(%A, %c0) : (memref<40xi32>, index) -> ()
+
+ // test with floor(*, 10)
+ affine.for %i = 0 to 40 {
+ %ii = index_cast %i: index to i32
+ %ii30 = subi %ii, %c20 : i32
+ %val = floordivi_signed %ii30, %c10 : i32
+ store %val, %A[%i] : memref<40xi32>
+ }
+ call @transfer_read_2d(%A, %c0) : (memref<40xi32>, index) -> ()
+
+
+ // test with ceil(*, -10)
+ affine.for %i = 0 to 40 {
+ %ii = index_cast %i: index to i32
+ %ii30 = subi %ii, %c20 : i32
+ %val = ceildivi_signed %ii30, %cmin10 : i32
+ store %val, %A[%i] : memref<40xi32>
+ }
+ call @transfer_read_2d(%A, %c0) : (memref<40xi32>, index) -> ()
+
+ // test with floor(*, -10)
+ affine.for %i = 0 to 40 {
+ %ii = index_cast %i: index to i32
+ %ii30 = subi %ii, %c20 : i32
+ %val = floordivi_signed %ii30, %cmin10 : i32
+ store %val, %A[%i] : memref<40xi32>
+ }
+ call @transfer_read_2d(%A, %c0) : (memref<40xi32>, index) -> ()
+
+ // NOTE(review): %A is never deallocated; harmless for this one-shot test.
+ return
+}
+
+// List below is aligned for easy manual check
+// legend: num, ceil(num, 10), floor(num, 10), ceil(num, -10), floor(num, -10)
+// ( -20, -19, -18, -17, -16, -15, -14, -13, -12, -11, -10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19 )
+// ( -2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2 )
+// ( -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
+// ( 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 )
+// ( 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -2, -2, -2, -2 )
+
+// CHECK:( -20, -19, -18, -17, -16, -15, -14, -13, -12, -11, -10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19 )
+// CHECK:( -2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2 )
+// CHECK:( -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
+// CHECK:( 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 )
+// CHECK:( 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -2, -2, -2, -2 )
\ No newline at end of file
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 264f93443a82..9b5875e70793 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2893,6 +2893,113 @@ OpFoldResult SignedDivIOp::fold(ArrayRef<Attribute> operands) {
return overflowOrDiv0 ? Attribute() : result;
}
+//===----------------------------------------------------------------------===//
+// SignedFloorDivIOp
+//===----------------------------------------------------------------------===//
+
+/// Returns the signed quotient `a / b` rounded toward positive infinity when
+/// `roundUp` is true and toward negative infinity otherwise; `b` must be
+/// nonzero. `overflow` is set to whether the quotient is not representable,
+/// which only happens for the minimum signed value divided by -1 (the returned
+/// value is then meaningless).
+///
+/// The result is derived from the truncating quotient and remainder instead of
+/// negating operands: negating the minimum signed value overflows, and since
+/// each `*_ov` call overwrites (rather than accumulates into) its overflow
+/// flag, such an intermediate overflow could otherwise go unnoticed.
+static APInt signedRoundedDiv(const APInt &a, const APInt &b, bool roundUp,
+                              bool &overflow) {
+  // Truncating division (rounds toward zero).
+  APInt quotient = a.sdiv_ov(b, overflow);
+  if (overflow)
+    return quotient;
+  // An exact division needs no rounding adjustment.
+  APInt remainder = a.srem(b);
+  if (remainder.isNullValue())
+    return quotient;
+  // The remainder of a truncating division has the sign of `a`, so the exact
+  // quotient is positive iff the remainder and `b` have the same sign.
+  bool exactIsPositive = remainder.isNegative() == b.isNegative();
+  // Truncation rounded a positive quotient down and a negative one up; step
+  // one unit when that disagrees with the requested direction. This cannot
+  // overflow: an inexact division implies |b| >= 2, so |quotient| is at most
+  // half of the representable range.
+  if (roundUp && exactIsPositive)
+    return quotient + 1;
+  if (!roundUp && !exactIsPositive)
+    return quotient - 1;
+  return quotient;
+}
+
+/// Constant-folds `floordivi_signed` element-wise (rounding toward negative
+/// infinity), e.g. floor(7, -2) = -4, and folds division by one to the
+/// dividend. No constant result is produced when any element divides by zero
+/// or overflows.
+OpFoldResult SignedFloorDivIOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2 && "binary operation takes two operands");
+
+  // Don't fold if it would overflow or if it requires a division by zero.
+  bool overflowOrDiv0 = false;
+  auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
+    if (overflowOrDiv0 || !b) {
+      overflowOrDiv0 = true;
+      return a;
+    }
+    return signedRoundedDiv(a, b, /*roundUp=*/false, overflowOrDiv0);
+  });
+
+  // Fold out floor division by one. Assumes all tensors of all ones are
+  // splats.
+  if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
+    if (rhs.getValue() == 1)
+      return lhs();
+  } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
+    if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
+      return lhs();
+  }
+
+  return overflowOrDiv0 ? Attribute() : result;
+}
+
+//===----------------------------------------------------------------------===//
+// SignedCeilDivIOp
+//===----------------------------------------------------------------------===//
+
+/// Constant-folds `ceildivi_signed` element-wise (rounding toward positive
+/// infinity), e.g. ceil(7, 2) = 4, and folds division by one to the dividend.
+/// No constant result is produced when any element divides by zero or
+/// overflows.
+OpFoldResult SignedCeilDivIOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2 && "binary operation takes two operands");
+
+  // Don't fold if it would overflow or if it requires a division by zero.
+  bool overflowOrDiv0 = false;
+  auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
+    if (overflowOrDiv0 || !b) {
+      overflowOrDiv0 = true;
+      return a;
+    }
+    return signedRoundedDiv(a, b, /*roundUp=*/true, overflowOrDiv0);
+  });
+
+  // Fold out ceil division by one. Assumes all tensors of all ones are
+  // splats.
+  if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
+    if (rhs.getValue() == 1)
+      return lhs();
+  } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
+    if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
+      return lhs();
+  }
+
+  return overflowOrDiv0 ? Attribute() : result;
+}
+
//===----------------------------------------------------------------------===//
// SignedRemIOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
index 1334e7f83d55..7a63f234094f 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
@@ -5,6 +5,7 @@ add_mlir_dialect_library(MLIRStandardOpsTransforms
ExpandTanh.cpp
FuncBufferize.cpp
FuncConversions.cpp
+ StdExpandDivs.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/StandardOps/Transforms
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/StdExpandDivs.cpp b/mlir/lib/Dialect/StandardOps/Transforms/StdExpandDivs.cpp
new file mode 100644
index 000000000000..224ecefabf6b
--- /dev/null
+++ b/mlir/lib/Dialect/StandardOps/Transforms/StdExpandDivs.cpp
@@ -0,0 +1,155 @@
+//===- StdExpandDivs.cpp - Code to prepare Std for lowering Divs to LLVM -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements Std transformations that expand Div operations to help
+// with the lowering to LLVM. Currently implemented transformations are Ceil
+// and Floor for Signed Integers.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Expands SignedCeilDivIOp (n, m) into
+///   1) x = (m > 0) ? -1 : 1
+///   2) (n*m > 0) ? ((n+x) / m) + 1 : - (-n / m)
+/// Operands may be integer/index scalars or vectors / statically shaped
+/// tensors thereof; the helper constants are splatted in the shaped case.
+///
+/// NOTE(review): both arms of the final select are always computed. For
+/// n = INT_MAX and m = -1 the unused first arm evaluates (INT_MAX + 1) / -1,
+/// i.e. a `divi_signed` of the minimum value by -1, whose semantics are TBD
+/// (presumably UB once lowered to LLVM `sdiv`), even though the expected result
+/// -INT_MAX is representable. Confirm this is acceptable.
+struct SignedCeilDivIOpConverter : public OpRewritePattern<SignedCeilDivIOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SignedCeilDivIOp op,
+                                PatternRewriter &rewriter) const final {
+    Location loc = op.getLoc();
+    Type type = op.getType();
+    Value a = op.lhs();
+    Value b = op.rhs();
+    // Materializes the integer constant `value` with the operand type. For
+    // vector/tensor operands the scalar attribute is splatted, since
+    // `getIntegerAttr` expects a scalar integer or index type.
+    auto shapedType = type.dyn_cast<ShapedType>();
+    auto createConstant = [&](int64_t value) -> Value {
+      if (!shapedType)
+        return rewriter.create<ConstantOp>(
+            loc, rewriter.getIntegerAttr(type, value));
+      Attribute element =
+          rewriter.getIntegerAttr(shapedType.getElementType(), value);
+      return rewriter.create<ConstantOp>(
+          loc, DenseElementsAttr::get(shapedType, element));
+    };
+    Value plusOne = createConstant(1);
+    Value zero = createConstant(0);
+    Value minusOne = createConstant(-1);
+    // Compute x = (b>0) ? -1 : 1.
+    Value compare = rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, b, zero);
+    Value x = rewriter.create<SelectOp>(loc, compare, minusOne, plusOne);
+    // Compute positive res: 1 + ((x+a)/b).
+    Value xPlusA = rewriter.create<AddIOp>(loc, x, a);
+    Value xPlusADivB = rewriter.create<SignedDivIOp>(loc, xPlusA, b);
+    Value posRes = rewriter.create<AddIOp>(loc, plusOne, xPlusADivB);
+    // Compute negative res: - ((-a)/b). Also yields 0 when a == 0.
+    Value minusA = rewriter.create<SubIOp>(loc, zero, a);
+    Value minusADivB = rewriter.create<SignedDivIOp>(loc, minusA, b);
+    Value negRes = rewriter.create<SubIOp>(loc, zero, minusADivB);
+    // Result is (a*b>0) ? pos result : neg result.
+    // Note, we want to avoid using a*b because of possible overflow.
+    // The cases that matter are a>0, a==0, a<0, b>0 and b<0. We do
+    // not particularly care if a*b>0 is true or false when b is zero
+    // as this will result in an illegal divide. So `a*b>0` can be reformulated
+    // as `(a<0 && b<0) || (a>0 && b>0)' or `(a<0 && b<0) || (a>0 && b>=0)'.
+    // We pick the first expression here.
+    Value aNeg = rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, a, zero);
+    Value aPos = rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, a, zero);
+    Value bNeg = rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, b, zero);
+    Value bPos = rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, b, zero);
+    Value firstTerm = rewriter.create<AndOp>(loc, aNeg, bNeg);
+    Value secondTerm = rewriter.create<AndOp>(loc, aPos, bPos);
+    Value compareRes = rewriter.create<OrOp>(loc, firstTerm, secondTerm);
+    Value res = rewriter.create<SelectOp>(loc, compareRes, posRes, negRes);
+    // Perform substitution and return success.
+    rewriter.replaceOp(op, {res});
+    return success();
+  }
+};
+
+/// Expands SignedFloorDivIOp (n, m) into
+///   1) x = (m < 0) ? 1 : -1
+///   2) return (n*m < 0) ? - ((-n+x) / m) -1 : n / m
+/// Operands may be integer/index scalars or vectors / statically shaped
+/// tensors thereof; the helper constants are splatted in the shaped case.
+struct SignedFloorDivIOpConverter : public OpRewritePattern<SignedFloorDivIOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SignedFloorDivIOp op,
+                                PatternRewriter &rewriter) const final {
+    Location loc = op.getLoc();
+    Type type = op.getType();
+    Value a = op.lhs();
+    Value b = op.rhs();
+    // Materializes the integer constant `value` with the operand type. For
+    // vector/tensor operands the scalar attribute is splatted, since
+    // `getIntegerAttr` expects a scalar integer or index type.
+    auto shapedType = type.dyn_cast<ShapedType>();
+    auto createConstant = [&](int64_t value) -> Value {
+      if (!shapedType)
+        return rewriter.create<ConstantOp>(
+            loc, rewriter.getIntegerAttr(type, value));
+      Attribute element =
+          rewriter.getIntegerAttr(shapedType.getElementType(), value);
+      return rewriter.create<ConstantOp>(
+          loc, DenseElementsAttr::get(shapedType, element));
+    };
+    Value plusOne = createConstant(1);
+    Value zero = createConstant(0);
+    Value minusOne = createConstant(-1);
+    // Compute x = (b<0) ? 1 : -1.
+    Value compare = rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, b, zero);
+    Value x = rewriter.create<SelectOp>(loc, compare, plusOne, minusOne);
+    // Compute negative res: -1 - ((x-a)/b).
+    Value xMinusA = rewriter.create<SubIOp>(loc, x, a);
+    Value xMinusADivB = rewriter.create<SignedDivIOp>(loc, xMinusA, b);
+    Value negRes = rewriter.create<SubIOp>(loc, minusOne, xMinusADivB);
+    // Compute positive res: a/b. Also yields 0 when a == 0.
+    Value posRes = rewriter.create<SignedDivIOp>(loc, a, b);
+    // Result is (a*b<0) ? negative result : positive result.
+    // Note, we want to avoid using a*b because of possible overflow.
+    // The cases that matter are a>0, a==0, a<0, b>0 and b<0. We do
+    // not particularly care if a*b<0 is true or false when b is zero
+    // as this will result in an illegal divide. So `a*b<0` can be reformulated
+    // as `(a<0 && b>0) || (a>0 && b<0)' or `(a<0 && b>=0) || (a>0 && b<0)'.
+    // We pick the first expression here.
+    Value aNeg = rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, a, zero);
+    Value aPos = rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, a, zero);
+    Value bNeg = rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, b, zero);
+    Value bPos = rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, b, zero);
+    Value firstTerm = rewriter.create<AndOp>(loc, aNeg, bPos);
+    Value secondTerm = rewriter.create<AndOp>(loc, aPos, bNeg);
+    Value compareRes = rewriter.create<OrOp>(loc, firstTerm, secondTerm);
+    Value res = rewriter.create<SelectOp>(loc, compareRes, negRes, posRes);
+    // Perform substitution and return success.
+    rewriter.replaceOp(op, {res});
+    return success();
+  }
+};
+
+} // namespace
+
+namespace {
+/// Function pass that expands the signed ceil/floor integer divisions into
+/// other std operations.
+class StdExpandDivs : public StdExpandDivsBase<StdExpandDivs> {
+public:
+  void runOnFunction() override;
+};
+} // namespace
+
+void StdExpandDivs::runOnFunction() {
+  MLIRContext &ctx = getContext();
+
+  // Everything in std stays legal except the two ops this pass expands.
+  ConversionTarget target(ctx);
+  target.addLegalDialect<StandardOpsDialect>();
+  target.addIllegalOp<SignedCeilDivIOp, SignedFloorDivIOp>();
+
+  OwningRewritePatternList patterns;
+  populateStdExpandDivsRewritePatterns(&ctx, patterns);
+
+  LogicalResult converted =
+      applyPartialConversion(getFunction(), target, std::move(patterns));
+  if (failed(converted))
+    signalPassFailure();
+}
+
+void mlir::populateStdExpandDivsRewritePatterns(
+    MLIRContext *context, OwningRewritePatternList &patterns) {
+  // One expansion pattern per rounding mode: ceil first, then floor.
+  patterns.insert<SignedCeilDivIOpConverter>(context);
+  patterns.insert<SignedFloorDivIOpConverter>(context);
+}
+
+std::unique_ptr<Pass> mlir::createStdExpandDivsPass() {
+  std::unique_ptr<Pass> pass = std::make_unique<StdExpandDivs>();
+  return pass;
+}
diff --git a/mlir/test/Dialect/Standard/std-expand-divs.mlir b/mlir/test/Dialect/Standard/std-expand-divs.mlir
new file mode 100644
index 000000000000..354a7863f423
--- /dev/null
+++ b/mlir/test/Dialect/Standard/std-expand-divs.mlir
@@ -0,0 +1,56 @@
+// RUN: mlir-opt -std-expand-divs %s -split-input-file | FileCheck %s
+
+// Test floor divide with signed integer. With a = arg0 and b = arg1, the op is
+// expanded into x = (b < 0) ? 1 : -1, then the select picks -1 - ((x - a) / b)
+// when (a < 0 && b > 0) || (a > 0 && b < 0), and a / b otherwise.
+// CHECK-LABEL: func @floordivi
+// CHECK-SAME: ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32) -> i32 {
+func @floordivi(%arg0: i32, %arg1: i32) -> (i32) {
+ %res = floordivi_signed %arg0, %arg1 : i32
+ return %res : i32
+// Expected expansion, in emission order:
+// CHECK: [[ONE:%.+]] = constant 1 : i32
+// CHECK: [[ZERO:%.+]] = constant 0 : i32
+// CHECK: [[MIN1:%.+]] = constant -1 : i32
+// CHECK: [[CMP1:%.+]] = cmpi "slt", [[ARG1]], [[ZERO]] : i32
+// CHECK: [[X:%.+]] = select [[CMP1]], [[ONE]], [[MIN1]] : i32
+// CHECK: [[TRUE1:%.+]] = subi [[X]], [[ARG0]] : i32
+// CHECK: [[TRUE2:%.+]] = divi_signed [[TRUE1]], [[ARG1]] : i32
+// CHECK: [[TRUE3:%.+]] = subi [[MIN1]], [[TRUE2]] : i32
+// CHECK: [[FALSE:%.+]] = divi_signed [[ARG0]], [[ARG1]] : i32
+// CHECK: [[NNEG:%.+]] = cmpi "slt", [[ARG0]], [[ZERO]] : i32
+// CHECK: [[NPOS:%.+]] = cmpi "sgt", [[ARG0]], [[ZERO]] : i32
+// CHECK: [[MNEG:%.+]] = cmpi "slt", [[ARG1]], [[ZERO]] : i32
+// CHECK: [[MPOS:%.+]] = cmpi "sgt", [[ARG1]], [[ZERO]] : i32
+// CHECK: [[TERM1:%.+]] = and [[NNEG]], [[MPOS]] : i1
+// CHECK: [[TERM2:%.+]] = and [[NPOS]], [[MNEG]] : i1
+// CHECK: [[CMP2:%.+]] = or [[TERM1]], [[TERM2]] : i1
+// CHECK: [[RES:%.+]] = select [[CMP2]], [[TRUE3]], [[FALSE]] : i32
+}
+
+// -----
+
+// Test ceil divide with signed integer. With a = arg0 and b = arg1, the op is
+// expanded into x = (b > 0) ? -1 : 1, then the select picks 1 + ((x + a) / b)
+// when (a < 0 && b < 0) || (a > 0 && b > 0), and 0 - ((0 - a) / b) otherwise.
+// CHECK-LABEL: func @ceildivi
+// CHECK-SAME: ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32) -> i32 {
+func @ceildivi(%arg0: i32, %arg1: i32) -> (i32) {
+ %res = ceildivi_signed %arg0, %arg1 : i32
+ return %res : i32
+
+// Expected expansion, in emission order:
+// CHECK: [[ONE:%.+]] = constant 1 : i32
+// CHECK: [[ZERO:%.+]] = constant 0 : i32
+// CHECK: [[MINONE:%.+]] = constant -1 : i32
+// CHECK: [[CMP1:%.+]] = cmpi "sgt", [[ARG1]], [[ZERO]] : i32
+// CHECK: [[X:%.+]] = select [[CMP1]], [[MINONE]], [[ONE]] : i32
+// CHECK: [[TRUE1:%.+]] = addi [[X]], [[ARG0]] : i32
+// CHECK: [[TRUE2:%.+]] = divi_signed [[TRUE1]], [[ARG1]] : i32
+// CHECK: [[TRUE3:%.+]] = addi [[ONE]], [[TRUE2]] : i32
+// CHECK: [[FALSE1:%.+]] = subi [[ZERO]], [[ARG0]] : i32
+// CHECK: [[FALSE2:%.+]] = divi_signed [[FALSE1]], [[ARG1]] : i32
+// CHECK: [[FALSE3:%.+]] = subi [[ZERO]], [[FALSE2]] : i32
+// CHECK: [[NNEG:%.+]] = cmpi "slt", [[ARG0]], [[ZERO]] : i32
+// CHECK: [[NPOS:%.+]] = cmpi "sgt", [[ARG0]], [[ZERO]] : i32
+// CHECK: [[MNEG:%.+]] = cmpi "slt", [[ARG1]], [[ZERO]] : i32
+// CHECK: [[MPOS:%.+]] = cmpi "sgt", [[ARG1]], [[ZERO]] : i32
+// CHECK: [[TERM1:%.+]] = and [[NNEG]], [[MNEG]] : i1
+// CHECK: [[TERM2:%.+]] = and [[NPOS]], [[MPOS]] : i1
+// CHECK: [[CMP2:%.+]] = or [[TERM1]], [[TERM2]] : i1
+// CHECK: [[RES:%.+]] = select [[CMP2]], [[TRUE3]], [[FALSE3]] : i32
+}
diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
index da7394eae784..dfed144fb44e 100644
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -569,6 +569,30 @@ func @standard_instrs(tensor<4x4x?xf32>, f32, i32, index, i64, f16) {
// CHECK: %{{[0-9]+}} = floorf %arg0 : tensor<4x4x?xf32>
%166 = floorf %t : tensor<4x4x?xf32>
+ // CHECK: %{{[0-9]+}} = floordivi_signed %arg2, %arg2 : i32
+ %167 = floordivi_signed %i, %i : i32
+
+ // CHECK: %{{[0-9]+}} = floordivi_signed %arg3, %arg3 : index
+ %168 = floordivi_signed %idx, %idx : index
+
+ // CHECK: %{{[0-9]+}} = floordivi_signed %cst_5, %cst_5 : vector<42xi32>
+ %169 = floordivi_signed %vci32, %vci32 : vector<42 x i32>
+
+ // CHECK: %{{[0-9]+}} = floordivi_signed %cst_4, %cst_4 : tensor<42xi32>
+ %170 = floordivi_signed %tci32, %tci32 : tensor<42 x i32>
+
+ // CHECK: %{{[0-9]+}} = ceildivi_signed %arg2, %arg2 : i32
+ %171 = ceildivi_signed %i, %i : i32
+
+ // CHECK: %{{[0-9]+}} = ceildivi_signed %arg3, %arg3 : index
+ %172 = ceildivi_signed %idx, %idx : index
+
+ // CHECK: %{{[0-9]+}} = ceildivi_signed %cst_5, %cst_5 : vector<42xi32>
+ %173 = ceildivi_signed %vci32, %vci32 : vector<42 x i32>
+
+ // CHECK: %{{[0-9]+}} = ceildivi_signed %cst_4, %cst_4 : tensor<42xi32>
+ %174 = ceildivi_signed %tci32, %tci32 : tensor<42 x i32>
+
return
}
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index dc7be097b0c0..7b8c45cb409b 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -949,6 +949,46 @@ func @tensor_divi_unsigned_by_one(%arg0: tensor<4x5xi32>) -> tensor<4x5xi32> {
// -----
+// CHECK-LABEL: func @floordivi_signed_by_one
+// CHECK-SAME: %[[ARG:[a-zA-Z0-9]+]]
+func @floordivi_signed_by_one(%arg0: i32) -> (i32) {
+  // Flooring division by the constant one folds to the dividend.
+  %one = constant 1 : i32
+  %quotient = floordivi_signed %arg0, %one : i32
+  // CHECK: return %[[ARG]]
+  return %quotient : i32
+}
+
+// CHECK-LABEL: func @tensor_floordivi_signed_by_one
+// CHECK-SAME: %[[ARG:[a-zA-Z0-9]+]]
+func @tensor_floordivi_signed_by_one(%arg0: tensor<4x5xi32>) -> tensor<4x5xi32> {
+  // A splat of ones is recognized as division by one as well.
+  %ones = constant dense<1> : tensor<4x5xi32>
+  %quotient = floordivi_signed %arg0, %ones : tensor<4x5xi32>
+  // CHECK: return %[[ARG]]
+  return %quotient : tensor<4x5xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @ceildivi_signed_by_one
+// CHECK-SAME: %[[ARG:[a-zA-Z0-9]+]]
+func @ceildivi_signed_by_one(%arg0: i32) -> (i32) {
+  // Ceiling division by the constant one folds to the dividend.
+  %one = constant 1 : i32
+  %quotient = ceildivi_signed %arg0, %one : i32
+  // CHECK: return %[[ARG]]
+  return %quotient : i32
+}
+
+// CHECK-LABEL: func @tensor_ceildivi_signed_by_one
+// CHECK-SAME: %[[ARG:[a-zA-Z0-9]+]]
+func @tensor_ceildivi_signed_by_one(%arg0: tensor<4x5xi32>) -> tensor<4x5xi32> {
+  // A splat of ones is recognized as division by one as well.
+  %ones = constant dense<1> : tensor<4x5xi32>
+  %quotient = ceildivi_signed %arg0, %ones : tensor<4x5xi32>
+  // CHECK: return %[[ARG]]
+  return %quotient : tensor<4x5xi32>
+}
+
+// -----
+
// CHECK-LABEL: func @memref_cast_folding_subview
func @memref_cast_folding_subview(%arg0: memref<4x5xf32>, %i: index) -> (memref<?x?xf32, offset:? , strides: [?, ?]>) {
%0 = memref_cast %arg0 : memref<4x5xf32> to memref<?x?xf32>
diff --git a/mlir/test/Transforms/constant-fold.mlir b/mlir/test/Transforms/constant-fold.mlir
index c75c89877830..234717863046 100644
--- a/mlir/test/Transforms/constant-fold.mlir
+++ b/mlir/test/Transforms/constant-fold.mlir
@@ -402,6 +402,82 @@ func @divi_unsigned_splat_tensor() -> (tensor<4xi32>, tensor<4xi32>, tensor<4xi3
// -----
+// Constant folding of floordivi_signed for every sign combination of the
+// operands; the division by zero at the end must stay unfolded.
+// CHECK-LABEL: func @simple_floordivi_signed
+func @simple_floordivi_signed() -> (i32, i32, i32, i32, i32) {
+  // CHECK-DAG: [[C0:%.+]] = constant 0
+  %z = constant 0 : i32
+  // CHECK-DAG: [[C7:%.+]] = constant 7
+  %0 = constant 7 : i32
+  %1 = constant 2 : i32
+
+  // floor(7, 2) = 3
+  // CHECK-NEXT: [[C3:%.+]] = constant 3 : i32
+  %2 = floordivi_signed %0, %1 : i32
+
+  %3 = constant -2 : i32
+
+  // floor(7, -2) = -4
+  // CHECK-NEXT: [[CM4:%.+]] = constant -4 : i32
+  %4 = floordivi_signed %0, %3 : i32
+
+  %5 = constant -9 : i32
+
+  // floor(-9, 2) = -5
+  // CHECK-NEXT: [[CM5:%.+]] = constant -5 : i32
+  %6 = floordivi_signed %5, %1 : i32
+
+  %7 = constant -13 : i32
+
+  // floor(-13, -2) = 6
+  // CHECK-NEXT: [[C6:%.+]] = constant 6 : i32
+  %8 = floordivi_signed %7, %3 : i32
+
+  // Division by zero is not folded.
+  // CHECK-NEXT: [[XZ:%.+]] = floordivi_signed [[C7]], [[C0]]
+  %9 = floordivi_signed %0, %z : i32
+
+  return %2, %4, %6, %8, %9 : i32, i32, i32, i32, i32
+}
+
+// -----
+
+// Constant folding of ceildivi_signed for every sign combination of the
+// operands; the division by zero at the end must stay unfolded.
+// CHECK-LABEL: func @simple_ceildivi_signed
+func @simple_ceildivi_signed() -> (i32, i32, i32, i32, i32) {
+  // CHECK-DAG: [[C0:%.+]] = constant 0
+  %z = constant 0 : i32
+  // CHECK-DAG: [[C7:%.+]] = constant 7
+  %0 = constant 7 : i32
+  %1 = constant 2 : i32
+
+  // ceil(7, 2) = 4
+  // CHECK-NEXT: [[C4:%.+]] = constant 4 : i32
+  %2 = ceildivi_signed %0, %1 : i32
+
+  %3 = constant -2 : i32
+
+  // ceil(7, -2) = -3
+  // CHECK-NEXT: [[CM3:%.+]] = constant -3 : i32
+  %4 = ceildivi_signed %0, %3 : i32
+
+  %5 = constant -9 : i32
+
+  // ceil(-9, 2) = -4
+  // CHECK-NEXT: [[CM4:%.+]] = constant -4 : i32
+  %6 = ceildivi_signed %5, %1 : i32
+
+  %7 = constant -15 : i32
+
+  // ceil(-15, -2) = 8
+  // CHECK-NEXT: [[C8:%.+]] = constant 8 : i32
+  %8 = ceildivi_signed %7, %3 : i32
+
+  // Division by zero is not folded.
+  // CHECK-NEXT: [[XZ:%.+]] = ceildivi_signed [[C7]], [[C0]]
+  %9 = ceildivi_signed %0, %z : i32
+
+  return %2, %4, %6, %8, %9 : i32, i32, i32, i32, i32
+}
+
+// -----
+
// CHECK-LABEL: func @simple_remi_signed
func @simple_remi_signed(%a : i32) -> (i32, i32, i32) {
%0 = constant 5 : i32
More information about the Mlir-commits
mailing list