[Mlir-commits] [mlir] 516884f - [MLIR] Fix FloorDivSIOpConverter that was failing for index type after the arithmetic op refactor
llvmlistbot at llvm.org
Thu Oct 21 14:42:33 PDT 2021
Author: Mogball
Date: 2021-10-21T21:42:30Z
New Revision: 516884f58b46a60d0aa499e19d792c39f2478aa4
URL: https://github.com/llvm/llvm-project/commit/516884f58b46a60d0aa499e19d792c39f2478aa4
DIFF: https://github.com/llvm/llvm-project/commit/516884f58b46a60d0aa499e19d792c39f2478aa4.diff
LOG: [MLIR] Fix FloorDivSIOpConverter that was failing for index type after the arithmetic op refactor
arith::ConstantOp (with an attribute built by getIntegerAttr) should be used instead of arith::ConstantIntOp so that the index type is supported.
Reviewed By: Mogball
Differential Revision: https://reviews.llvm.org/D112191
Added:
Modified:
mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp
mlir/test/Dialect/Arithmetic/expand-ops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp
index 03fac9e5fb56a..af6de99ce5ce9 100644
--- a/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp
@@ -81,9 +81,12 @@ struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> {
Type type = signedFloorDivIOp.getType();
Value a = signedFloorDivIOp.lhs();
Value b = signedFloorDivIOp.rhs();
- Value plusOne = rewriter.create<arith::ConstantIntOp>(loc, 1, type);
- Value zero = rewriter.create<arith::ConstantIntOp>(loc, 0, type);
- Value minusOne = rewriter.create<arith::ConstantIntOp>(loc, -1, type);
+ Value plusOne = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getIntegerAttr(type, 1));
+ Value zero = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getIntegerAttr(type, 0));
+ Value minusOne = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getIntegerAttr(type, -1));
// Compute x = (b<0) ? 1 : -1.
Value compare =
rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
diff --git a/mlir/test/Dialect/Arithmetic/expand-ops.mlir b/mlir/test/Dialect/Arithmetic/expand-ops.mlir
index 160cf3acccce6..23ab1267e0ab6 100644
--- a/mlir/test/Dialect/Arithmetic/expand-ops.mlir
+++ b/mlir/test/Dialect/Arithmetic/expand-ops.mlir
@@ -30,6 +30,36 @@ func @ceildivi(%arg0: i32, %arg1: i32) -> (i32) {
// -----
+// Test ceil divide with index type
+// CHECK-LABEL: func @ceildivi_index
+// CHECK-SAME: ([[ARG0:%.+]]: index, [[ARG1:%.+]]: index) -> index {
+func @ceildivi_index(%arg0: index, %arg1: index) -> (index) {
+ %res = arith.ceildivsi %arg0, %arg1 : index
+ return %res : index
+
+// CHECK: [[ONE:%.+]] = arith.constant 1 : index
+// CHECK: [[ZERO:%.+]] = arith.constant 0 : index
+// CHECK: [[MINONE:%.+]] = arith.constant -1 : index
+// CHECK: [[CMP1:%.+]] = arith.cmpi sgt, [[ARG1]], [[ZERO]] : index
+// CHECK: [[X:%.+]] = select [[CMP1]], [[MINONE]], [[ONE]] : index
+// CHECK: [[TRUE1:%.+]] = arith.addi [[X]], [[ARG0]] : index
+// CHECK: [[TRUE2:%.+]] = arith.divsi [[TRUE1]], [[ARG1]] : index
+// CHECK: [[TRUE3:%.+]] = arith.addi [[ONE]], [[TRUE2]] : index
+// CHECK: [[FALSE1:%.+]] = arith.subi [[ZERO]], [[ARG0]] : index
+// CHECK: [[FALSE2:%.+]] = arith.divsi [[FALSE1]], [[ARG1]] : index
+// CHECK: [[FALSE3:%.+]] = arith.subi [[ZERO]], [[FALSE2]] : index
+// CHECK: [[NNEG:%.+]] = arith.cmpi slt, [[ARG0]], [[ZERO]] : index
+// CHECK: [[NPOS:%.+]] = arith.cmpi sgt, [[ARG0]], [[ZERO]] : index
+// CHECK: [[MNEG:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : index
+// CHECK: [[MPOS:%.+]] = arith.cmpi sgt, [[ARG1]], [[ZERO]] : index
+// CHECK: [[TERM1:%.+]] = arith.andi [[NNEG]], [[MNEG]] : i1
+// CHECK: [[TERM2:%.+]] = arith.andi [[NPOS]], [[MPOS]] : i1
+// CHECK: [[CMP2:%.+]] = arith.ori [[TERM1]], [[TERM2]] : i1
+// CHECK: [[RES:%.+]] = select [[CMP2]], [[TRUE3]], [[FALSE3]] : index
+}
+
+// -----
+
// Test floor divide with signed integer
// CHECK-LABEL: func @floordivi
// CHECK-SAME: ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32) -> i32 {
@@ -54,3 +84,30 @@ func @floordivi(%arg0: i32, %arg1: i32) -> (i32) {
// CHECK: [[CMP2:%.+]] = arith.ori [[TERM1]], [[TERM2]] : i1
// CHECK: [[RES:%.+]] = select [[CMP2]], [[TRUE3]], [[FALSE]] : i32
}
+
+// -----
+
+// Test floor divide with index type
+// CHECK-LABEL: func @floordivi_index
+// CHECK-SAME: ([[ARG0:%.+]]: index, [[ARG1:%.+]]: index) -> index {
+func @floordivi_index(%arg0: index, %arg1: index) -> (index) {
+ %res = arith.floordivsi %arg0, %arg1 : index
+ return %res : index
+// CHECK: [[ONE:%.+]] = arith.constant 1 : index
+// CHECK: [[ZERO:%.+]] = arith.constant 0 : index
+// CHECK: [[MIN1:%.+]] = arith.constant -1 : index
+// CHECK: [[CMP1:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : index
+// CHECK: [[X:%.+]] = select [[CMP1]], [[ONE]], [[MIN1]] : index
+// CHECK: [[TRUE1:%.+]] = arith.subi [[X]], [[ARG0]] : index
+// CHECK: [[TRUE2:%.+]] = arith.divsi [[TRUE1]], [[ARG1]] : index
+// CHECK: [[TRUE3:%.+]] = arith.subi [[MIN1]], [[TRUE2]] : index
+// CHECK: [[FALSE:%.+]] = arith.divsi [[ARG0]], [[ARG1]] : index
+// CHECK: [[NNEG:%.+]] = arith.cmpi slt, [[ARG0]], [[ZERO]] : index
+// CHECK: [[NPOS:%.+]] = arith.cmpi sgt, [[ARG0]], [[ZERO]] : index
+// CHECK: [[MNEG:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : index
+// CHECK: [[MPOS:%.+]] = arith.cmpi sgt, [[ARG1]], [[ZERO]] : index
+// CHECK: [[TERM1:%.+]] = arith.andi [[NNEG]], [[MPOS]] : i1
+// CHECK: [[TERM2:%.+]] = arith.andi [[NPOS]], [[MNEG]] : i1
+// CHECK: [[CMP2:%.+]] = arith.ori [[TERM1]], [[TERM2]] : i1
+// CHECK: [[RES:%.+]] = select [[CMP2]], [[TRUE3]], [[FALSE]] : index
+}
More information about the Mlir-commits mailing list