[Mlir-commits] [mlir] 516884f - [MLIR] Fix FloorDivSIOpConverter that was failing for index type after the arithmetic op refactor

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Oct 21 14:42:33 PDT 2021


Author: Mogball
Date: 2021-10-21T21:42:30Z
New Revision: 516884f58b46a60d0aa499e19d792c39f2478aa4

URL: https://github.com/llvm/llvm-project/commit/516884f58b46a60d0aa499e19d792c39f2478aa4
DIFF: https://github.com/llvm/llvm-project/commit/516884f58b46a60d0aa499e19d792c39f2478aa4.diff

LOG: [MLIR] Fix FloorDivSIOpConverter that was failing for index type after the arithmetic op refactor

Use `arith::ConstantOp` with an integer attribute instead of `arith::ConstantIntOp` so that the index type is supported (`arith::ConstantIntOp` only accepts signless integer types).

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D112191

Added: 
    

Modified: 
    mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp
    mlir/test/Dialect/Arithmetic/expand-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp
index 03fac9e5fb56a..af6de99ce5ce9 100644
--- a/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp
@@ -81,9 +81,12 @@ struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> {
     Type type = signedFloorDivIOp.getType();
     Value a = signedFloorDivIOp.lhs();
     Value b = signedFloorDivIOp.rhs();
-    Value plusOne = rewriter.create<arith::ConstantIntOp>(loc, 1, type);
-    Value zero = rewriter.create<arith::ConstantIntOp>(loc, 0, type);
-    Value minusOne = rewriter.create<arith::ConstantIntOp>(loc, -1, type);
+    Value plusOne = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getIntegerAttr(type, 1));
+    Value zero = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getIntegerAttr(type, 0));
+    Value minusOne = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getIntegerAttr(type, -1));
     // Compute x = (b<0) ? 1 : -1.
     Value compare =
         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);

diff  --git a/mlir/test/Dialect/Arithmetic/expand-ops.mlir b/mlir/test/Dialect/Arithmetic/expand-ops.mlir
index 160cf3acccce6..23ab1267e0ab6 100644
--- a/mlir/test/Dialect/Arithmetic/expand-ops.mlir
+++ b/mlir/test/Dialect/Arithmetic/expand-ops.mlir
@@ -30,6 +30,36 @@ func @ceildivi(%arg0: i32, %arg1: i32) -> (i32) {
 
 // -----
 
+// Test ceil divide with index type
+// CHECK-LABEL:       func @ceildivi_index
+// CHECK-SAME:     ([[ARG0:%.+]]: index, [[ARG1:%.+]]: index) -> index {
+func @ceildivi_index(%arg0: index, %arg1: index) -> (index) {
+  %res = arith.ceildivsi %arg0, %arg1 : index
+  return %res : index
+
+// CHECK:           [[ONE:%.+]] = arith.constant 1 : index
+// CHECK:           [[ZERO:%.+]] = arith.constant 0 : index
+// CHECK:           [[MINONE:%.+]] = arith.constant -1 : index
+// CHECK:           [[CMP1:%.+]] = arith.cmpi sgt, [[ARG1]], [[ZERO]] : index
+// CHECK:           [[X:%.+]] = select [[CMP1]], [[MINONE]], [[ONE]] : index
+// CHECK:           [[TRUE1:%.+]] = arith.addi [[X]], [[ARG0]] : index
+// CHECK:           [[TRUE2:%.+]] = arith.divsi [[TRUE1]], [[ARG1]] : index
+// CHECK:           [[TRUE3:%.+]] = arith.addi [[ONE]], [[TRUE2]] : index
+// CHECK:           [[FALSE1:%.+]] = arith.subi [[ZERO]], [[ARG0]] : index
+// CHECK:           [[FALSE2:%.+]] = arith.divsi [[FALSE1]], [[ARG1]] : index
+// CHECK:           [[FALSE3:%.+]] = arith.subi [[ZERO]], [[FALSE2]] : index
+// CHECK:           [[NNEG:%.+]] = arith.cmpi slt, [[ARG0]], [[ZERO]] : index
+// CHECK:           [[NPOS:%.+]] = arith.cmpi sgt, [[ARG0]], [[ZERO]] : index
+// CHECK:           [[MNEG:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : index
+// CHECK:           [[MPOS:%.+]] = arith.cmpi sgt, [[ARG1]], [[ZERO]] : index
+// CHECK:           [[TERM1:%.+]] = arith.andi [[NNEG]], [[MNEG]] : i1
+// CHECK:           [[TERM2:%.+]] = arith.andi [[NPOS]], [[MPOS]] : i1
+// CHECK:           [[CMP2:%.+]] = arith.ori [[TERM1]], [[TERM2]] : i1
+// CHECK:           [[RES:%.+]] = select [[CMP2]], [[TRUE3]], [[FALSE3]] : index
+}
+
+// -----
+
 // Test floor divide with signed integer
 // CHECK-LABEL:       func @floordivi
 // CHECK-SAME:     ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32) -> i32 {
@@ -54,3 +84,30 @@ func @floordivi(%arg0: i32, %arg1: i32) -> (i32) {
 // CHECK:           [[CMP2:%.+]] = arith.ori [[TERM1]], [[TERM2]] : i1
 // CHECK:           [[RES:%.+]] = select [[CMP2]], [[TRUE3]], [[FALSE]] : i32
 }
+
+// -----
+
+// Test floor divide with index type
+// CHECK-LABEL:       func @floordivi_index
+// CHECK-SAME:     ([[ARG0:%.+]]: index, [[ARG1:%.+]]: index) -> index {
+func @floordivi_index(%arg0: index, %arg1: index) -> (index) {
+  %res = arith.floordivsi %arg0, %arg1 : index
+  return %res : index
+// CHECK:           [[ONE:%.+]] = arith.constant 1 : index
+// CHECK:           [[ZERO:%.+]] = arith.constant 0 : index
+// CHECK:           [[MIN1:%.+]] = arith.constant -1 : index
+// CHECK:           [[CMP1:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : index
+// CHECK:           [[X:%.+]] = select [[CMP1]], [[ONE]], [[MIN1]] : index
+// CHECK:           [[TRUE1:%.+]] = arith.subi [[X]], [[ARG0]] : index
+// CHECK:           [[TRUE2:%.+]] = arith.divsi [[TRUE1]], [[ARG1]] : index
+// CHECK:           [[TRUE3:%.+]] = arith.subi [[MIN1]], [[TRUE2]] : index
+// CHECK:           [[FALSE:%.+]] = arith.divsi [[ARG0]], [[ARG1]] : index
+// CHECK:           [[NNEG:%.+]] = arith.cmpi slt, [[ARG0]], [[ZERO]] : index
+// CHECK:           [[NPOS:%.+]] = arith.cmpi sgt, [[ARG0]], [[ZERO]] : index
+// CHECK:           [[MNEG:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : index
+// CHECK:           [[MPOS:%.+]] = arith.cmpi sgt, [[ARG1]], [[ZERO]] : index
+// CHECK:           [[TERM1:%.+]] = arith.andi [[NNEG]], [[MPOS]] : i1
+// CHECK:           [[TERM2:%.+]] = arith.andi [[NPOS]], [[MNEG]] : i1
+// CHECK:           [[CMP2:%.+]] = arith.ori [[TERM1]], [[TERM2]] : i1
+// CHECK:           [[RES:%.+]] = select [[CMP2]], [[TRUE3]], [[FALSE]] : index
+}


        


More information about the Mlir-commits mailing list