[Mlir-commits] [mlir] 917e451 - [mlir][arith] cmpi: move constant to the right side
Ivan Butygin
llvmlistbot at llvm.org
Fri Jul 22 03:42:18 PDT 2022
Author: Ivan Butygin
Date: 2022-07-22T12:39:17+02:00
New Revision: 917e4519bc2ac6fe490953b82c69f7c7a9511dbd
URL: https://github.com/llvm/llvm-project/commit/917e4519bc2ac6fe490953b82c69f7c7a9511dbd
DIFF: https://github.com/llvm/llvm-project/commit/917e4519bc2ac6fe490953b82c69f7c7a9511dbd.diff
LOG: [mlir][arith] cmpi: move constant to the right side
Convert arith.cmpi to the canonical form with constants on the right side
to simplify further optimizations and open more opportunities for CSE.
Differential Revision: https://reviews.llvm.org/D129929
Added:
Modified:
flang/test/Fir/boxproc.fir
flang/test/Lower/array-character.f90
flang/test/Lower/host-associated.f90
mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
mlir/test/Dialect/Arithmetic/canonicalize.mlir
mlir/test/Dialect/Vector/vector-contract-transforms.mlir
mlir/test/Transforms/sccp-structured.mlir
Removed:
################################################################################
diff --git a/flang/test/Fir/boxproc.fir b/flang/test/Fir/boxproc.fir
index 128f6e0b252e7..30c293e182505 100644
--- a/flang/test/Fir/boxproc.fir
+++ b/flang/test/Fir/boxproc.fir
@@ -93,7 +93,7 @@ func.func @_QPtest_proc_dummy_other(%arg0: !fir.boxproc<() -> ()>) {
// CHECK: %[[VAL_4:.*]] = load { ptr, i64 }, ptr %[[VAL_3]], align 8
// CHECK: %[[VAL_5:.*]] = extractvalue { ptr, i64 } %[[VAL_4]], 0
// CHECK: %[[VAL_6:.*]] = extractvalue { ptr, i64 } %[[VAL_4]], 1
-// CHECK: %[[VAL_8:.*]] = icmp slt i64 10, %[[VAL_6]]
+// CHECK: %[[VAL_8:.*]] = icmp sgt i64 %[[VAL_6]], 10
// CHECK: %[[VAL_9:.*]] = select i1 %[[VAL_8]], i64 10, i64 %[[VAL_6]]
// CHECK: call void @llvm.memmove.p0.p0.i64(ptr %[[VAL_0]], ptr %[[VAL_5]], i64 %[[VAL_9]], i1 false)
// CHECK: %[[VAL_10:.*]] = sub i64 10, %[[VAL_9]]
@@ -129,7 +129,7 @@ func.func @_QPtest_proc_dummy_other(%arg0: !fir.boxproc<() -> ()>) {
// CHECK: %[[VAL_27:.*]] = load [1 x i8], ptr %[[VAL_26]], align 1
// CHECK: %[[VAL_29:.*]] = getelementptr [1 x i8], ptr %[[VAL_14]], i64 %[[VAL_18]]
// CHECK: store [1 x i8] %[[VAL_27]], ptr %[[VAL_29]], align 1
-// CHECK: %[[VAL_30:.*]] = icmp slt i64 40, %[[VAL_13]]
+// CHECK: %[[VAL_30:.*]] = icmp sgt i64 %[[VAL_13]], 40
// CHECK: %[[VAL_31:.*]] = select i1 %[[VAL_30]], i64 40, i64 %[[VAL_13]]
// CHECK: call void @llvm.memmove.p0.p0.i64(ptr %[[VAL_0]], ptr %[[VAL_14]], i64 %[[VAL_31]], i1 false)
// CHECK: %[[VAL_32:.*]] = sub i64 40, %[[VAL_31]]
diff --git a/flang/test/Lower/array-character.f90 b/flang/test/Lower/array-character.f90
index 476dae43b67ae..82b77fbb6f662 100644
--- a/flang/test/Lower/array-character.f90
+++ b/flang/test/Lower/array-character.f90
@@ -22,7 +22,7 @@ subroutine issue(c1, c2)
! CHECK: %[[VAL_16:.*]] = arith.addi %[[VAL_13]], %[[VAL_7]] : index
! CHECK: %[[VAL_17:.*]] = fir.array_coor %[[VAL_11]](%[[VAL_12]]) %[[VAL_16]] typeparams %[[VAL_10]]#1 : (!fir.ref<!fir.array<3x!fir.char<1,?>>>, !fir.shape<1>, index, index) -> !fir.ref<!fir.char<1,?>>
! CHECK: %[[VAL_18:.*]] = fir.array_coor %[[VAL_9]](%[[VAL_12]]) %[[VAL_16]] : (!fir.ref<!fir.array<3x!fir.char<1,4>>>, !fir.shape<1>, index) -> !fir.ref<!fir.char<1,4>>
- ! CHECK: %[[VAL_19:.*]] = arith.cmpi slt, %[[VAL_5]], %[[VAL_10]]#1 : index
+ ! CHECK: %[[VAL_19:.*]] = arith.cmpi sgt, %[[VAL_10]]#1, %[[VAL_5]] : index
! CHECK: %[[VAL_20:.*]] = arith.select %[[VAL_19]], %[[VAL_5]], %[[VAL_10]]#1 : index
! CHECK: %[[VAL_21:.*]] = fir.convert %[[VAL_20]] : (index) -> i64
! CHECK: %[[VAL_22:.*]] = fir.convert %[[VAL_18]] : (!fir.ref<!fir.char<1,4>>) -> !fir.ref<i8>
diff --git a/flang/test/Lower/host-associated.f90 b/flang/test/Lower/host-associated.f90
index aa1f8b022646a..4dc66cb33a806 100644
--- a/flang/test/Lower/host-associated.f90
+++ b/flang/test/Lower/host-associated.f90
@@ -540,7 +540,7 @@ end subroutine test_proc_dummy_other
! CHECK: %[[VAL_10:.*]] = fir.load %[[VAL_9]] : !fir.ref<!fir.boxchar<1>>
! CHECK: %[[VAL_11:.*]]:2 = fir.unboxchar %[[VAL_10]] : (!fir.boxchar<1>) -> (!fir.ref<!fir.char<1,?>>, index)
! CHECK: %[[VAL_12:.*]] = fir.convert %[[VAL_0]] : (!fir.ref<!fir.char<1,10>>) -> !fir.ref<!fir.char<1,?>>
-! CHECK: %[[VAL_13:.*]] = arith.cmpi slt, %[[VAL_4]], %[[VAL_11]]#1 : index
+! CHECK: %[[VAL_13:.*]] = arith.cmpi sgt, %[[VAL_11]]#1, %[[VAL_4]] : index
! CHECK: %[[VAL_14:.*]] = arith.select %[[VAL_13]], %[[VAL_4]], %[[VAL_11]]#1 : index
! CHECK: %[[VAL_15:.*]] = fir.convert %[[VAL_14]] : (index) -> i64
! CHECK: %[[VAL_16:.*]] = fir.convert %[[VAL_12]] : (!fir.ref<!fir.char<1,?>>) -> !fir.ref<i8>
@@ -607,7 +607,7 @@ end subroutine test_proc_dummy_other
! CHECK: %[[VAL_34:.*]] = arith.subi %[[VAL_25]], %[[VAL_6]] : index
! CHECK: br ^bb1(%[[VAL_33]], %[[VAL_34]] : index, index)
! CHECK: ^bb3:
-! CHECK: %[[VAL_35:.*]] = arith.cmpi slt, %[[VAL_3]], %[[VAL_19]] : index
+! CHECK: %[[VAL_35:.*]] = arith.cmpi sgt, %[[VAL_19]], %[[VAL_3]] : index
! CHECK: %[[VAL_36:.*]] = arith.select %[[VAL_35]], %[[VAL_3]], %[[VAL_19]] : index
! CHECK: %[[VAL_37:.*]] = fir.convert %[[VAL_36]] : (index) -> i64
! CHECK: %[[VAL_38:.*]] = fir.convert %[[VAL_9]] : (!fir.ref<!fir.char<1,?>>) -> !fir.ref<i8>
diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
index f9f4c686a0503..939c969d95cae 100644
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -1332,11 +1332,38 @@ OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
}
}
+ // Move constant to the right side.
+ if (operands[0] && !operands[1]) {
+ // Do not use invertPredicate, as it will change eq to ne and vice versa.
+ using Pred = CmpIPredicate;
+ const std::pair<Pred, Pred> invPreds[] = {
+ {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
+ {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
+ {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
+ {Pred::ne, Pred::ne},
+ };
+ Pred origPred = getPredicate();
+ for (auto pred : invPreds) {
+ if (origPred == pred.first) {
+ setPredicateAttr(CmpIPredicateAttr::get(getContext(), pred.second));
+ Value lhs = getLhs();
+ Value rhs = getRhs();
+ getLhsMutable().assign(rhs);
+ getRhsMutable().assign(lhs);
+ return getResult();
+ }
+ }
+ llvm_unreachable("unknown cmpi predicate kind");
+ }
+
auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
- auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
- if (!lhs || !rhs)
+ if (!lhs)
return {};
+ // We are moving constants to the right side; So if lhs is constant rhs is
+ // guaranteed to be a constant.
+ auto rhs = operands.back().cast<IntegerAttr>();
+
auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
return BoolAttr::get(getContext(), val);
}
diff --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
index 2f563eb598d79..0a4d08ae071af 100644
--- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -127,6 +127,41 @@ func.func @cmpi_equal_vector_operands(%arg0: vector<1x8xi64>)
// -----
+// Test case: Move constant to the right side.
+// CHECK-LABEL: @cmpi_const_right(
+// CHECK-SAME: %[[ARG:.*]]:
+// CHECK: %[[C:.*]] = arith.constant 1 : i64
+// CHECK: %[[R0:.*]] = arith.cmpi eq, %[[ARG]], %[[C]] : i64
+// CHECK: %[[R1:.*]] = arith.cmpi sge, %[[ARG]], %[[C]] : i64
+// CHECK: %[[R2:.*]] = arith.cmpi sle, %[[ARG]], %[[C]] : i64
+// CHECK: %[[R3:.*]] = arith.cmpi uge, %[[ARG]], %[[C]] : i64
+// CHECK: %[[R4:.*]] = arith.cmpi ule, %[[ARG]], %[[C]] : i64
+// CHECK: %[[R5:.*]] = arith.cmpi ne, %[[ARG]], %[[C]] : i64
+// CHECK: %[[R6:.*]] = arith.cmpi sgt, %[[ARG]], %[[C]] : i64
+// CHECK: %[[R7:.*]] = arith.cmpi slt, %[[ARG]], %[[C]] : i64
+// CHECK: %[[R8:.*]] = arith.cmpi ugt, %[[ARG]], %[[C]] : i64
+// CHECK: %[[R9:.*]] = arith.cmpi ult, %[[ARG]], %[[C]] : i64
+// CHECK: return %[[R0]], %[[R1]], %[[R2]], %[[R3]], %[[R4]],
+// CHECK-SAME: %[[R5]], %[[R6]], %[[R7]], %[[R8]], %[[R9]]
+func.func @cmpi_const_right(%arg0: i64)
+ -> (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1) {
+ %c1 = arith.constant 1 : i64
+ %0 = arith.cmpi eq, %c1, %arg0 : i64
+ %1 = arith.cmpi sle, %c1, %arg0 : i64
+ %2 = arith.cmpi sge, %c1, %arg0 : i64
+ %3 = arith.cmpi ule, %c1, %arg0 : i64
+ %4 = arith.cmpi uge, %c1, %arg0 : i64
+ %5 = arith.cmpi ne, %c1, %arg0 : i64
+ %6 = arith.cmpi slt, %c1, %arg0 : i64
+ %7 = arith.cmpi sgt, %c1, %arg0 : i64
+ %8 = arith.cmpi ult, %c1, %arg0 : i64
+ %9 = arith.cmpi ugt, %c1, %arg0 : i64
+ return %0, %1, %2, %3, %4, %5, %6, %7, %8, %9
+ : i1, i1, i1, i1, i1, i1, i1, i1, i1, i1
+}
+
+// -----
+
// CHECK-LABEL: @cmpOfExtSI
// CHECK-NEXT: return %arg0
func.func @cmpOfExtSI(%arg0: i1) -> i1 {
diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
index 72bfdd6e580b2..1c63337103482 100644
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -819,10 +819,10 @@ func.func @genbool_var_1d(%arg0: index) -> vector<3xi1> {
// CHECK: %[[c0:.*]] = arith.constant 0 : index
// CHECK: %[[c1:.*]] = arith.constant 1 : index
// CHECK: %[[T0:.*]] = vector.create_mask %[[B]] : vector<3xi1>
-// CHECK: %[[T1:.*]] = arith.cmpi slt, %[[c0]], %[[A]] : index
+// CHECK: %[[T1:.*]] = arith.cmpi sgt, %[[A]], %[[c0]] : index
// CHECK: %[[T2:.*]] = arith.select %[[T1]], %[[T0]], %[[C1]] : vector<3xi1>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C2]] [0] : vector<3xi1> into vector<2x3xi1>
-// CHECK: %[[T4:.*]] = arith.cmpi slt, %[[c1]], %[[A]] : index
+// CHECK: %[[T4:.*]] = arith.cmpi sgt, %[[A]], %[[c1]] : index
// CHECK: %[[T5:.*]] = arith.select %[[T4]], %[[T0]], %[[C1]] : vector<3xi1>
// CHECK: %[[T6:.*]] = vector.insert %[[T5]], %[[T3]] [1] : vector<3xi1> into vector<2x3xi1>
// CHECK: return %[[T6]] : vector<2x3xi1>
@@ -842,13 +842,13 @@ func.func @genbool_var_2d(%arg0: index, %arg1: index) -> vector<2x3xi1> {
// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
// CHECK: %[[T0:.*]] = vector.create_mask %[[C]] : vector<7xi1>
-// CHECK: %[[T1:.*]] = arith.cmpi slt, %[[c0]], %[[B]] : index
+// CHECK: %[[T1:.*]] = arith.cmpi sgt, %[[B]], %[[c0]] : index
// CHECK: %[[T2:.*]] = arith.select %[[T1]], %[[T0]], %[[C1]] : vector<7xi1>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C2]] [0] : vector<7xi1> into vector<1x7xi1>
-// CHECK: %[[T4:.*]] = arith.cmpi slt, %[[c0]], %[[A]] : index
+// CHECK: %[[T4:.*]] = arith.cmpi sgt, %[[A]], %[[c0]] : index
// CHECK: %[[T5:.*]] = arith.select %[[T4]], %[[T3]], %[[C2]] : vector<1x7xi1>
// CHECK: %[[T6:.*]] = vector.insert %[[T5]], %[[C3]] [0] : vector<1x7xi1> into vector<2x1x7xi1>
-// CHECK: %[[T7:.*]] = arith.cmpi slt, %[[c1]], %[[A]] : index
+// CHECK: %[[T7:.*]] = arith.cmpi sgt, %[[A]], %[[c1]] : index
// CHECK: %[[T8:.*]] = arith.select %[[T7]], %[[T3]], %[[C2]] : vector<1x7xi1>
// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T6]] [1] : vector<1x7xi1> into vector<2x1x7xi1>
// CHECK: return %[[T9]] : vector<2x1x7xi1>
diff --git a/mlir/test/Transforms/sccp-structured.mlir b/mlir/test/Transforms/sccp-structured.mlir
index 32af26d623018..529d41554a473 100644
--- a/mlir/test/Transforms/sccp-structured.mlir
+++ b/mlir/test/Transforms/sccp-structured.mlir
@@ -141,7 +141,7 @@ func.func @loop_region_branch_terminator_op(%arg1 : i32) {
%c2_i32 = arith.constant 2 : i32
%0 = scf.while (%arg2 = %c2_i32) : (i32) -> (i32) {
- %1 = arith.cmpi slt, %arg2, %arg1 : i32
+ %1 = arith.cmpi sgt, %arg1, %arg2 : i32
scf.condition(%1) %arg2 : i32
} do {
^bb0(%arg2: i32):
More information about the Mlir-commits mailing list