[Mlir-commits] [mlir] 917e451 - [mlir][arith] cmpi: move constant to the right side

Ivan Butygin llvmlistbot at llvm.org
Fri Jul 22 03:42:18 PDT 2022


Author: Ivan Butygin
Date: 2022-07-22T12:39:17+02:00
New Revision: 917e4519bc2ac6fe490953b82c69f7c7a9511dbd

URL: https://github.com/llvm/llvm-project/commit/917e4519bc2ac6fe490953b82c69f7c7a9511dbd
DIFF: https://github.com/llvm/llvm-project/commit/917e4519bc2ac6fe490953b82c69f7c7a9511dbd.diff

LOG: [mlir][arith] cmpi: move constant to the right side

Convert arith.cmpi to the canonical form with constants on the right side
to simplify further optimizations and open more opportunities for CSE.


Differential Revision: https://reviews.llvm.org/D129929

Added: 
    

Modified: 
    flang/test/Fir/boxproc.fir
    flang/test/Lower/array-character.f90
    flang/test/Lower/host-associated.f90
    mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
    mlir/test/Dialect/Arithmetic/canonicalize.mlir
    mlir/test/Dialect/Vector/vector-contract-transforms.mlir
    mlir/test/Transforms/sccp-structured.mlir

Removed: 
    


################################################################################
diff  --git a/flang/test/Fir/boxproc.fir b/flang/test/Fir/boxproc.fir
index 128f6e0b252e7..30c293e182505 100644
--- a/flang/test/Fir/boxproc.fir
+++ b/flang/test/Fir/boxproc.fir
@@ -93,7 +93,7 @@ func.func @_QPtest_proc_dummy_other(%arg0: !fir.boxproc<() -> ()>) {
 // CHECK:         %[[VAL_4:.*]] = load { ptr, i64 }, ptr %[[VAL_3]], align 8
 // CHECK:         %[[VAL_5:.*]] = extractvalue { ptr, i64 } %[[VAL_4]], 0
 // CHECK:         %[[VAL_6:.*]] = extractvalue { ptr, i64 } %[[VAL_4]], 1
-// CHECK:         %[[VAL_8:.*]] = icmp slt i64 10, %[[VAL_6]]
+// CHECK:         %[[VAL_8:.*]] = icmp sgt i64 %[[VAL_6]], 10
 // CHECK:         %[[VAL_9:.*]] = select i1 %[[VAL_8]], i64 10, i64 %[[VAL_6]]
 // CHECK:         call void @llvm.memmove.p0.p0.i64(ptr %[[VAL_0]], ptr %[[VAL_5]], i64 %[[VAL_9]], i1 false)
 // CHECK:         %[[VAL_10:.*]] = sub i64 10, %[[VAL_9]]
@@ -129,7 +129,7 @@ func.func @_QPtest_proc_dummy_other(%arg0: !fir.boxproc<() -> ()>) {
 // CHECK:         %[[VAL_27:.*]] = load [1 x i8], ptr %[[VAL_26]], align 1
 // CHECK:         %[[VAL_29:.*]] = getelementptr [1 x i8], ptr %[[VAL_14]], i64 %[[VAL_18]]
 // CHECK:         store [1 x i8] %[[VAL_27]], ptr %[[VAL_29]], align 1
-// CHECK:         %[[VAL_30:.*]] = icmp slt i64 40, %[[VAL_13]]
+// CHECK:         %[[VAL_30:.*]] = icmp sgt i64 %[[VAL_13]], 40
 // CHECK:         %[[VAL_31:.*]] =  select i1 %[[VAL_30]], i64 40, i64 %[[VAL_13]]
 // CHECK:         call void @llvm.memmove.p0.p0.i64(ptr %[[VAL_0]], ptr %[[VAL_14]], i64 %[[VAL_31]], i1 false)
 // CHECK:         %[[VAL_32:.*]] = sub i64 40, %[[VAL_31]]

diff  --git a/flang/test/Lower/array-character.f90 b/flang/test/Lower/array-character.f90
index 476dae43b67ae..82b77fbb6f662 100644
--- a/flang/test/Lower/array-character.f90
+++ b/flang/test/Lower/array-character.f90
@@ -22,7 +22,7 @@ subroutine issue(c1, c2)
   ! CHECK: %[[VAL_16:.*]] = arith.addi %[[VAL_13]], %[[VAL_7]] : index
   ! CHECK: %[[VAL_17:.*]] = fir.array_coor %[[VAL_11]](%[[VAL_12]]) %[[VAL_16]] typeparams %[[VAL_10]]#1 : (!fir.ref<!fir.array<3x!fir.char<1,?>>>, !fir.shape<1>, index, index) -> !fir.ref<!fir.char<1,?>>
   ! CHECK: %[[VAL_18:.*]] = fir.array_coor %[[VAL_9]](%[[VAL_12]]) %[[VAL_16]] : (!fir.ref<!fir.array<3x!fir.char<1,4>>>, !fir.shape<1>, index) -> !fir.ref<!fir.char<1,4>>
-  ! CHECK: %[[VAL_19:.*]] = arith.cmpi slt, %[[VAL_5]], %[[VAL_10]]#1 : index
+  ! CHECK: %[[VAL_19:.*]] = arith.cmpi sgt, %[[VAL_10]]#1, %[[VAL_5]] : index
   ! CHECK: %[[VAL_20:.*]] = arith.select %[[VAL_19]], %[[VAL_5]], %[[VAL_10]]#1 : index
   ! CHECK: %[[VAL_21:.*]] = fir.convert %[[VAL_20]] : (index) -> i64
   ! CHECK: %[[VAL_22:.*]] = fir.convert %[[VAL_18]] : (!fir.ref<!fir.char<1,4>>) -> !fir.ref<i8>

diff  --git a/flang/test/Lower/host-associated.f90 b/flang/test/Lower/host-associated.f90
index aa1f8b022646a..4dc66cb33a806 100644
--- a/flang/test/Lower/host-associated.f90
+++ b/flang/test/Lower/host-associated.f90
@@ -540,7 +540,7 @@ end subroutine test_proc_dummy_other
 ! CHECK:         %[[VAL_10:.*]] = fir.load %[[VAL_9]] : !fir.ref<!fir.boxchar<1>>
 ! CHECK:         %[[VAL_11:.*]]:2 = fir.unboxchar %[[VAL_10]] : (!fir.boxchar<1>) -> (!fir.ref<!fir.char<1,?>>, index)
 ! CHECK:         %[[VAL_12:.*]] = fir.convert %[[VAL_0]] : (!fir.ref<!fir.char<1,10>>) -> !fir.ref<!fir.char<1,?>>
-! CHECK:         %[[VAL_13:.*]] = arith.cmpi slt, %[[VAL_4]], %[[VAL_11]]#1 : index
+! CHECK:         %[[VAL_13:.*]] = arith.cmpi sgt, %[[VAL_11]]#1, %[[VAL_4]] : index
 ! CHECK:         %[[VAL_14:.*]] = arith.select %[[VAL_13]], %[[VAL_4]], %[[VAL_11]]#1 : index
 ! CHECK:         %[[VAL_15:.*]] = fir.convert %[[VAL_14]] : (index) -> i64
 ! CHECK:         %[[VAL_16:.*]] = fir.convert %[[VAL_12]] : (!fir.ref<!fir.char<1,?>>) -> !fir.ref<i8>
@@ -607,7 +607,7 @@ end subroutine test_proc_dummy_other
 ! CHECK:         %[[VAL_34:.*]] = arith.subi %[[VAL_25]], %[[VAL_6]] : index
 ! CHECK:         br ^bb1(%[[VAL_33]], %[[VAL_34]] : index, index)
 ! CHECK:       ^bb3:
-! CHECK:         %[[VAL_35:.*]] = arith.cmpi slt, %[[VAL_3]], %[[VAL_19]] : index
+! CHECK:         %[[VAL_35:.*]] = arith.cmpi sgt, %[[VAL_19]], %[[VAL_3]] : index
 ! CHECK:         %[[VAL_36:.*]] = arith.select %[[VAL_35]], %[[VAL_3]], %[[VAL_19]] : index
 ! CHECK:         %[[VAL_37:.*]] = fir.convert %[[VAL_36]] : (index) -> i64
 ! CHECK:         %[[VAL_38:.*]] = fir.convert %[[VAL_9]] : (!fir.ref<!fir.char<1,?>>) -> !fir.ref<i8>

diff  --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
index f9f4c686a0503..939c969d95cae 100644
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -1332,11 +1332,38 @@ OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
     }
   }
 
+  // Move constant to the right side.
+  if (operands[0] && !operands[1]) {
+    // Do not use invertPredicate, as it will change eq to ne and vice versa.
+    using Pred = CmpIPredicate;
+    const std::pair<Pred, Pred> invPreds[] = {
+        {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
+        {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
+        {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
+        {Pred::ne, Pred::ne},
+    };
+    Pred origPred = getPredicate();
+    for (auto pred : invPreds) {
+      if (origPred == pred.first) {
+        setPredicateAttr(CmpIPredicateAttr::get(getContext(), pred.second));
+        Value lhs = getLhs();
+        Value rhs = getRhs();
+        getLhsMutable().assign(rhs);
+        getRhsMutable().assign(lhs);
+        return getResult();
+      }
+    }
+    llvm_unreachable("unknown cmpi predicate kind");
+  }
+
   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
-  auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
-  if (!lhs || !rhs)
+  if (!lhs)
     return {};
 
+  // We are moving constants to the right side; So if lhs is constant rhs is
+  // guaranteed to be a constant.
+  auto rhs = operands.back().cast<IntegerAttr>();
+
   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
   return BoolAttr::get(getContext(), val);
 }

diff  --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
index 2f563eb598d79..0a4d08ae071af 100644
--- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -127,6 +127,41 @@ func.func @cmpi_equal_vector_operands(%arg0: vector<1x8xi64>)
 
 // -----
 
+// Test case: a constant LHS is moved to the RHS and the predicate is mirrored.
+// CHECK-LABEL: @cmpi_const_right(
+//  CHECK-SAME: %[[ARG:.*]]:
+//       CHECK:   %[[C:.*]] = arith.constant 1 : i64
+//       CHECK:   %[[R0:.*]] = arith.cmpi eq, %[[ARG]], %[[C]] : i64
+//       CHECK:   %[[R1:.*]] = arith.cmpi sge, %[[ARG]], %[[C]] : i64
+//       CHECK:   %[[R2:.*]] = arith.cmpi sle, %[[ARG]], %[[C]] : i64
+//       CHECK:   %[[R3:.*]] = arith.cmpi uge, %[[ARG]], %[[C]] : i64
+//       CHECK:   %[[R4:.*]] = arith.cmpi ule, %[[ARG]], %[[C]] : i64
+//       CHECK:   %[[R5:.*]] = arith.cmpi ne, %[[ARG]], %[[C]] : i64
+//       CHECK:   %[[R6:.*]] = arith.cmpi sgt, %[[ARG]], %[[C]] : i64
+//       CHECK:   %[[R7:.*]] = arith.cmpi slt, %[[ARG]], %[[C]] : i64
+//       CHECK:   %[[R8:.*]] = arith.cmpi ugt, %[[ARG]], %[[C]] : i64
+//       CHECK:   %[[R9:.*]] = arith.cmpi ult, %[[ARG]], %[[C]] : i64
+//       CHECK:   return %[[R0]], %[[R1]], %[[R2]], %[[R3]], %[[R4]],
+//  CHECK-SAME:          %[[R5]], %[[R6]], %[[R7]], %[[R8]], %[[R9]]
+func.func @cmpi_const_right(%arg0: i64)
+    -> (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1) {
+  %c1 = arith.constant 1 : i64
+  %0 = arith.cmpi eq, %c1, %arg0 : i64
+  %1 = arith.cmpi sle, %c1, %arg0 : i64
+  %2 = arith.cmpi sge, %c1, %arg0 : i64
+  %3 = arith.cmpi ule, %c1, %arg0 : i64
+  %4 = arith.cmpi uge, %c1, %arg0 : i64
+  %5 = arith.cmpi ne, %c1, %arg0 : i64
+  %6 = arith.cmpi slt, %c1, %arg0 : i64
+  %7 = arith.cmpi sgt, %c1, %arg0 : i64
+  %8 = arith.cmpi ult, %c1, %arg0 : i64
+  %9 = arith.cmpi ugt, %c1, %arg0 : i64
+  return %0, %1, %2, %3, %4, %5, %6, %7, %8, %9
+      : i1, i1, i1, i1, i1, i1, i1, i1, i1, i1
+}
+
+// -----
+
 // CHECK-LABEL: @cmpOfExtSI
 //  CHECK-NEXT:   return %arg0
 func.func @cmpOfExtSI(%arg0: i1) -> i1 {

diff  --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
index 72bfdd6e580b2..1c63337103482 100644
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -819,10 +819,10 @@ func.func @genbool_var_1d(%arg0: index) -> vector<3xi1> {
 // CHECK:      %[[c0:.*]] = arith.constant 0 : index
 // CHECK:      %[[c1:.*]] = arith.constant 1 : index
 // CHECK:      %[[T0:.*]] = vector.create_mask %[[B]] : vector<3xi1>
-// CHECK:      %[[T1:.*]] = arith.cmpi slt, %[[c0]], %[[A]] : index
+// CHECK:      %[[T1:.*]] = arith.cmpi sgt, %[[A]], %[[c0]] : index
 // CHECK:      %[[T2:.*]] = arith.select %[[T1]], %[[T0]], %[[C1]] : vector<3xi1>
 // CHECK:      %[[T3:.*]] = vector.insert %[[T2]], %[[C2]] [0] : vector<3xi1> into vector<2x3xi1>
-// CHECK:      %[[T4:.*]] = arith.cmpi slt, %[[c1]], %[[A]] : index
+// CHECK:      %[[T4:.*]] = arith.cmpi sgt, %[[A]], %[[c1]] : index
 // CHECK:      %[[T5:.*]] = arith.select %[[T4]], %[[T0]], %[[C1]] : vector<3xi1>
 // CHECK:      %[[T6:.*]] = vector.insert %[[T5]], %[[T3]] [1] : vector<3xi1> into vector<2x3xi1>
 // CHECK:      return %[[T6]] : vector<2x3xi1>
@@ -842,13 +842,13 @@ func.func @genbool_var_2d(%arg0: index, %arg1: index) -> vector<2x3xi1> {
 // CHECK-DAG:  %[[c0:.*]] = arith.constant 0 : index
 // CHECK-DAG:  %[[c1:.*]] = arith.constant 1 : index
 // CHECK:      %[[T0:.*]] = vector.create_mask %[[C]] : vector<7xi1>
-// CHECK:      %[[T1:.*]] = arith.cmpi slt, %[[c0]], %[[B]] : index
+// CHECK:      %[[T1:.*]] = arith.cmpi sgt, %[[B]], %[[c0]] : index
 // CHECK:      %[[T2:.*]] = arith.select %[[T1]], %[[T0]], %[[C1]] : vector<7xi1>
 // CHECK:      %[[T3:.*]] = vector.insert %[[T2]], %[[C2]] [0] : vector<7xi1> into vector<1x7xi1>
-// CHECK:      %[[T4:.*]] = arith.cmpi slt, %[[c0]], %[[A]] : index
+// CHECK:      %[[T4:.*]] = arith.cmpi sgt, %[[A]], %[[c0]] : index
 // CHECK:      %[[T5:.*]] = arith.select %[[T4]], %[[T3]], %[[C2]] : vector<1x7xi1>
 // CHECK:      %[[T6:.*]] = vector.insert %[[T5]], %[[C3]] [0] : vector<1x7xi1> into vector<2x1x7xi1>
-// CHECK:      %[[T7:.*]] = arith.cmpi slt, %[[c1]], %[[A]] : index
+// CHECK:      %[[T7:.*]] = arith.cmpi sgt, %[[A]], %[[c1]] : index
 // CHECK:      %[[T8:.*]] = arith.select %[[T7]], %[[T3]], %[[C2]] : vector<1x7xi1>
 // CHECK:      %[[T9:.*]] = vector.insert %[[T8]], %[[T6]] [1] : vector<1x7xi1> into vector<2x1x7xi1>
 // CHECK:      return %[[T9]] : vector<2x1x7xi1>

diff  --git a/mlir/test/Transforms/sccp-structured.mlir b/mlir/test/Transforms/sccp-structured.mlir
index 32af26d623018..529d41554a473 100644
--- a/mlir/test/Transforms/sccp-structured.mlir
+++ b/mlir/test/Transforms/sccp-structured.mlir
@@ -141,7 +141,7 @@ func.func @loop_region_branch_terminator_op(%arg1 : i32) {
 
   %c2_i32 = arith.constant 2 : i32
    %0 = scf.while (%arg2 = %c2_i32) : (i32) -> (i32) {
-    %1 = arith.cmpi slt, %arg2, %arg1 : i32
+    %1 = arith.cmpi sgt, %arg1, %arg2 : i32
     scf.condition(%1) %arg2 : i32
   } do {
   ^bb0(%arg2: i32):


        


More information about the Mlir-commits mailing list