[llvm] [InstCombine] Fold `(x < y) ? -1 : zext(x > y)` and `(x > y) ? 1 : sext(x < y)` to `ucmp/scmp(x, y)` (PR #105272)

via llvm-commits llvm-commits at lists.llvm.org
Tue Aug 20 12:19:55 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Volodymyr Vasylkun (Poseydon42)

<details>
<summary>Changes</summary>

This patch expands the existing functionality to include these two additional folds, which are nearly identical to the ones already implemented.

Proofs: https://alive2.llvm.org/ce/z/Xy7s4j

---
Full diff: https://github.com/llvm/llvm-project/pull/105272.diff


4 Files Affected:

- (modified) llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp (+13-4) 
- (modified) llvm/test/Transforms/InstCombine/scmp.ll (+28) 
- (modified) llvm/test/Transforms/InstCombine/select-select.ll (+11-31) 
- (modified) llvm/test/Transforms/InstCombine/ucmp.ll (+30-2) 


``````````diff
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 1f6d5759883fd0..18ffc209f259e0 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -3560,7 +3560,9 @@ static Instruction *foldBitCeil(SelectInst &SI, IRBuilderBase &Builder) {
 
 // This function tries to fold the following operations:
 //   (x < y) ? -1 : zext(x != y)
+//   (x < y) ? -1 : zext(x > y)
 //   (x > y) ? 1 : sext(x != y)
+//   (x > y) ? 1 : sext(x < y)
 // Into ucmp/scmp(x, y), where signedness is determined by the signedness
 // of the comparison in the original sequence.
 Instruction *InstCombinerImpl::foldSelectToCmp(SelectInst &SI) {
@@ -3589,16 +3591,23 @@ Instruction *InstCombinerImpl::foldSelectToCmp(SelectInst &SI) {
       ICmpInst::isSigned(Pred) ? Intrinsic::scmp : Intrinsic::ucmp;
 
   bool Replace = false;
+  ICmpInst::Predicate ExtendedCmpPredicate;
   // (x < y) ? -1 : zext(x != y)
+  // (x < y) ? -1 : zext(x > y)
   if (ICmpInst::isLT(Pred) && match(TV, m_AllOnes()) &&
-      match(FV, m_ZExt(m_c_SpecificICmp(ICmpInst::ICMP_NE, m_Specific(LHS),
-                                        m_Specific(RHS)))))
+      match(FV, m_ZExt(m_c_ICmp(ExtendedCmpPredicate, m_Specific(LHS),
+                                m_Specific(RHS)))) &&
+      (ExtendedCmpPredicate == ICmpInst::ICMP_NE ||
+       ICmpInst::getSwappedPredicate(ExtendedCmpPredicate) == Pred))
     Replace = true;
 
   // (x > y) ? 1 : sext(x != y)
+  // (x > y) ? 1 : sext(x < y)
   if (ICmpInst::isGT(Pred) && match(TV, m_One()) &&
-      match(FV, m_SExt(m_c_SpecificICmp(ICmpInst::ICMP_NE, m_Specific(LHS),
-                                        m_Specific(RHS)))))
+      match(FV, m_SExt(m_c_ICmp(ExtendedCmpPredicate, m_Specific(LHS),
+                                m_Specific(RHS)))) &&
+      (ExtendedCmpPredicate == ICmpInst::ICMP_NE ||
+       ICmpInst::getSwappedPredicate(ExtendedCmpPredicate) == Pred))
     Replace = true;
 
   if (Replace)
diff --git a/llvm/test/Transforms/InstCombine/scmp.ll b/llvm/test/Transforms/InstCombine/scmp.ll
index 7f374c5f9a1d64..e2312140c8c13d 100644
--- a/llvm/test/Transforms/InstCombine/scmp.ll
+++ b/llvm/test/Transforms/InstCombine/scmp.ll
@@ -223,6 +223,20 @@ define i8 @scmp_from_select_lt(i32 %x, i32 %y) {
   ret i8 %r
 }
 
+; Fold (x s< y) ? -1 : zext(x s> y) into scmp(x, y)
+define i8 @scmp_from_select_lt_and_gt(i32 %x, i32 %y) {
+; CHECK-LABEL: define i8 @scmp_from_select_lt_and_gt(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = call i8 @llvm.scmp.i8.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %gt_bool = icmp sgt i32 %x, %y
+  %gt = zext i1 %gt_bool to i8
+  %lt = icmp slt i32 %x, %y
+  %r = select i1 %lt, i8 -1, i8 %gt
+  ret i8 %r
+}
+
 ; Vector version
 define <4 x i8> @scmp_from_select_vec_lt(<4 x i32> %x, <4 x i32> %y) {
 ; CHECK-LABEL: define <4 x i8> @scmp_from_select_vec_lt(
@@ -264,3 +278,17 @@ define i8 @scmp_from_select_ge(i32 %x, i32 %y) {
   %r = select i1 %ge, i8 %ne, i8 -1
   ret i8 %r
 }
+
+; Fold (x s> y) ? 1 : sext(x s< y)
+define i8 @scmp_from_select_gt_and_lt(i32 %x, i32 %y) {
+; CHECK-LABEL: define i8 @scmp_from_select_gt_and_lt(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = call i8 @llvm.scmp.i8.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %lt_bool = icmp slt i32 %x, %y
+  %lt = sext i1 %lt_bool to i8
+  %gt = icmp sgt i32 %x, %y
+  %r = select i1 %gt, i8 1, i8 %lt
+  ret i8 %r
+}
diff --git a/llvm/test/Transforms/InstCombine/select-select.ll b/llvm/test/Transforms/InstCombine/select-select.ll
index 5460ba1bc55838..1feae5ab504dcf 100644
--- a/llvm/test/Transforms/InstCombine/select-select.ll
+++ b/llvm/test/Transforms/InstCombine/select-select.ll
@@ -18,9 +18,9 @@ define float @foo1(float %a) {
 
 define float @foo2(float %a) {
 ; CHECK-LABEL: @foo2(
-; CHECK-NEXT:    [[B:%.*]] = fcmp ule float [[C:%.*]], 0.000000e+00
-; CHECK-NEXT:    [[D:%.*]] = fcmp olt float [[C]], 1.000000e+00
-; CHECK-NEXT:    [[E:%.*]] = select i1 [[D]], float [[C]], float 1.000000e+00
+; CHECK-NEXT:    [[B:%.*]] = fcmp ule float [[A:%.*]], 0.000000e+00
+; CHECK-NEXT:    [[TMP1:%.*]] = fcmp olt float [[A]], 1.000000e+00
+; CHECK-NEXT:    [[E:%.*]] = select i1 [[TMP1]], float [[A]], float 1.000000e+00
 ; CHECK-NEXT:    [[F:%.*]] = select i1 [[B]], float 0.000000e+00, float [[E]]
 ; CHECK-NEXT:    ret float [[F]]
 ;
@@ -330,10 +330,7 @@ define i8 @strong_order_cmp_eq_ugt(i32 %a, i32 %b) {
 
 define i8 @strong_order_cmp_slt_sgt(i32 %a, i32 %b) {
 ; CHECK-LABEL: @strong_order_cmp_slt_sgt(
-; CHECK-NEXT:    [[CMP_LT:%.*]] = icmp slt i32 [[A:%.*]], [[B:%.*]]
-; CHECK-NEXT:    [[SEXT:%.*]] = sext i1 [[CMP_LT]] to i8
-; CHECK-NEXT:    [[CMP_GT:%.*]] = icmp sgt i32 [[A]], [[B]]
-; CHECK-NEXT:    [[SEL_GT:%.*]] = select i1 [[CMP_GT]], i8 1, i8 [[SEXT]]
+; CHECK-NEXT:    [[SEL_GT:%.*]] = call i8 @llvm.scmp.i8.i32(i32 [[A:%.*]], i32 [[B:%.*]])
 ; CHECK-NEXT:    ret i8 [[SEL_GT]]
 ;
   %cmp.lt = icmp slt i32 %a, %b
@@ -345,10 +342,7 @@ define i8 @strong_order_cmp_slt_sgt(i32 %a, i32 %b) {
 
 define i8 @strong_order_cmp_ult_ugt(i32 %a, i32 %b) {
 ; CHECK-LABEL: @strong_order_cmp_ult_ugt(
-; CHECK-NEXT:    [[CMP_LT:%.*]] = icmp ult i32 [[A:%.*]], [[B:%.*]]
-; CHECK-NEXT:    [[SEXT:%.*]] = sext i1 [[CMP_LT]] to i8
-; CHECK-NEXT:    [[CMP_GT:%.*]] = icmp ugt i32 [[A]], [[B]]
-; CHECK-NEXT:    [[SEL_GT:%.*]] = select i1 [[CMP_GT]], i8 1, i8 [[SEXT]]
+; CHECK-NEXT:    [[SEL_GT:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[A:%.*]], i32 [[B:%.*]])
 ; CHECK-NEXT:    ret i8 [[SEL_GT]]
 ;
   %cmp.lt = icmp ult i32 %a, %b
@@ -360,10 +354,7 @@ define i8 @strong_order_cmp_ult_ugt(i32 %a, i32 %b) {
 
 define i8 @strong_order_cmp_sgt_slt(i32 %a, i32 %b) {
 ; CHECK-LABEL: @strong_order_cmp_sgt_slt(
-; CHECK-NEXT:    [[CMP_GT:%.*]] = icmp sgt i32 [[A:%.*]], [[B:%.*]]
-; CHECK-NEXT:    [[ZEXT:%.*]] = zext i1 [[CMP_GT]] to i8
-; CHECK-NEXT:    [[CMP_LT:%.*]] = icmp slt i32 [[A]], [[B]]
-; CHECK-NEXT:    [[SEL_LT:%.*]] = select i1 [[CMP_LT]], i8 -1, i8 [[ZEXT]]
+; CHECK-NEXT:    [[SEL_LT:%.*]] = call i8 @llvm.scmp.i8.i32(i32 [[A:%.*]], i32 [[B:%.*]])
 ; CHECK-NEXT:    ret i8 [[SEL_LT]]
 ;
   %cmp.gt = icmp sgt i32 %a, %b
@@ -375,10 +366,7 @@ define i8 @strong_order_cmp_sgt_slt(i32 %a, i32 %b) {
 
 define i8 @strong_order_cmp_ugt_ult(i32 %a, i32 %b) {
 ; CHECK-LABEL: @strong_order_cmp_ugt_ult(
-; CHECK-NEXT:    [[CMP_GT:%.*]] = icmp ugt i32 [[A:%.*]], [[B:%.*]]
-; CHECK-NEXT:    [[ZEXT:%.*]] = zext i1 [[CMP_GT]] to i8
-; CHECK-NEXT:    [[CMP_LT:%.*]] = icmp ult i32 [[A]], [[B]]
-; CHECK-NEXT:    [[SEL_LT:%.*]] = select i1 [[CMP_LT]], i8 -1, i8 [[ZEXT]]
+; CHECK-NEXT:    [[SEL_LT:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[A:%.*]], i32 [[B:%.*]])
 ; CHECK-NEXT:    ret i8 [[SEL_LT]]
 ;
   %cmp.gt = icmp ugt i32 %a, %b
@@ -460,8 +448,7 @@ define i8 @strong_order_cmp_ugt_ult_zext_not_oneuse(i32 %a, i32 %b) {
 ; CHECK-NEXT:    [[CMP_GT:%.*]] = icmp ugt i32 [[A:%.*]], [[B:%.*]]
 ; CHECK-NEXT:    [[ZEXT:%.*]] = zext i1 [[CMP_GT]] to i8
 ; CHECK-NEXT:    call void @use8(i8 [[ZEXT]])
-; CHECK-NEXT:    [[CMP_LT:%.*]] = icmp ult i32 [[A]], [[B]]
-; CHECK-NEXT:    [[SEL_LT:%.*]] = select i1 [[CMP_LT]], i8 -1, i8 [[ZEXT]]
+; CHECK-NEXT:    [[SEL_LT:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[A]], i32 [[B]])
 ; CHECK-NEXT:    ret i8 [[SEL_LT]]
 ;
   %cmp.gt = icmp ugt i32 %a, %b
@@ -477,8 +464,7 @@ define i8 @strong_order_cmp_slt_sgt_sext_not_oneuse(i32 %a, i32 %b) {
 ; CHECK-NEXT:    [[CMP_LT:%.*]] = icmp slt i32 [[A:%.*]], [[B:%.*]]
 ; CHECK-NEXT:    [[SEXT:%.*]] = sext i1 [[CMP_LT]] to i8
 ; CHECK-NEXT:    call void @use8(i8 [[SEXT]])
-; CHECK-NEXT:    [[CMP_GT:%.*]] = icmp sgt i32 [[A]], [[B]]
-; CHECK-NEXT:    [[SEL_GT:%.*]] = select i1 [[CMP_GT]], i8 1, i8 [[SEXT]]
+; CHECK-NEXT:    [[SEL_GT:%.*]] = call i8 @llvm.scmp.i8.i32(i32 [[A]], i32 [[B]])
 ; CHECK-NEXT:    ret i8 [[SEL_GT]]
 ;
   %cmp.lt = icmp slt i32 %a, %b
@@ -491,10 +477,7 @@ define i8 @strong_order_cmp_slt_sgt_sext_not_oneuse(i32 %a, i32 %b) {
 
 define <2 x i8> @strong_order_cmp_ugt_ult_vector(<2 x i32> %a, <2 x i32> %b) {
 ; CHECK-LABEL: @strong_order_cmp_ugt_ult_vector(
-; CHECK-NEXT:    [[CMP_GT:%.*]] = icmp ugt <2 x i32> [[A:%.*]], [[B:%.*]]
-; CHECK-NEXT:    [[ZEXT:%.*]] = zext <2 x i1> [[CMP_GT]] to <2 x i8>
-; CHECK-NEXT:    [[CMP_LT:%.*]] = icmp ult <2 x i32> [[A]], [[B]]
-; CHECK-NEXT:    [[SEL_LT:%.*]] = select <2 x i1> [[CMP_LT]], <2 x i8> <i8 -1, i8 -1>, <2 x i8> [[ZEXT]]
+; CHECK-NEXT:    [[SEL_LT:%.*]] = call <2 x i8> @llvm.ucmp.v2i8.v2i32(<2 x i32> [[A:%.*]], <2 x i32> [[B:%.*]])
 ; CHECK-NEXT:    ret <2 x i8> [[SEL_LT]]
 ;
   %cmp.gt = icmp ugt <2 x i32> %a, %b
@@ -506,10 +489,7 @@ define <2 x i8> @strong_order_cmp_ugt_ult_vector(<2 x i32> %a, <2 x i32> %b) {
 
 define <2 x i8> @strong_order_cmp_ugt_ult_vector_poison(<2 x i32> %a, <2 x i32> %b) {
 ; CHECK-LABEL: @strong_order_cmp_ugt_ult_vector_poison(
-; CHECK-NEXT:    [[CMP_GT:%.*]] = icmp ugt <2 x i32> [[A:%.*]], [[B:%.*]]
-; CHECK-NEXT:    [[ZEXT:%.*]] = zext <2 x i1> [[CMP_GT]] to <2 x i8>
-; CHECK-NEXT:    [[CMP_LT:%.*]] = icmp ult <2 x i32> [[A]], [[B]]
-; CHECK-NEXT:    [[SEL_LT:%.*]] = select <2 x i1> [[CMP_LT]], <2 x i8> <i8 poison, i8 -1>, <2 x i8> [[ZEXT]]
+; CHECK-NEXT:    [[SEL_LT:%.*]] = call <2 x i8> @llvm.ucmp.v2i8.v2i32(<2 x i32> [[A:%.*]], <2 x i32> [[B:%.*]])
 ; CHECK-NEXT:    ret <2 x i8> [[SEL_LT]]
 ;
   %cmp.gt = icmp ugt <2 x i32> %a, %b
diff --git a/llvm/test/Transforms/InstCombine/ucmp.ll b/llvm/test/Transforms/InstCombine/ucmp.ll
index ad8a57825253b0..13755f13bb0a11 100644
--- a/llvm/test/Transforms/InstCombine/ucmp.ll
+++ b/llvm/test/Transforms/InstCombine/ucmp.ll
@@ -222,6 +222,20 @@ define i8 @ucmp_from_select_lt(i32 %x, i32 %y) {
   ret i8 %r
 }
 
+; Fold (x u< y) ? -1 : zext(x u> y) into ucmp(x, y)
+define i8 @ucmp_from_select_lt_and_gt(i32 %x, i32 %y) {
+; CHECK-LABEL: define i8 @ucmp_from_select_lt_and_gt(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %gt_bool = icmp ugt i32 %x, %y
+  %gt = zext i1 %gt_bool to i8
+  %lt = icmp ult i32 %x, %y
+  %r = select i1 %lt, i8 -1, i8 %gt
+  ret i8 %r
+}
+
 ; Vector version
 define <4 x i8> @ucmp_from_select_vec_lt(<4 x i32> %x, <4 x i32> %y) {
 ; CHECK-LABEL: define <4 x i8> @ucmp_from_select_vec_lt(
@@ -349,13 +363,13 @@ define i8 @ucmp_from_select_le_neg1(i32 %x, i32 %y) {
 ; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
 ; CHECK-NEXT:    [[NE_BOOL:%.*]] = icmp ult i32 [[X]], [[Y]]
 ; CHECK-NEXT:    [[NE:%.*]] = sext i1 [[NE_BOOL]] to i8
-; CHECK-NEXT:    [[LE_NOT:%.*]] = icmp ugt i32 [[X]], [[Y]]
+; CHECK-NEXT:    [[LE_NOT:%.*]] = icmp ult i32 [[X]], [[Y]]
 ; CHECK-NEXT:    [[R:%.*]] = select i1 [[LE_NOT]], i8 1, i8 [[NE]]
 ; CHECK-NEXT:    ret i8 [[R]]
 ;
   %ne_bool = icmp ult i32 %x, %y
   %ne = sext i1 %ne_bool to i8
-  %le = icmp ule i32 %x, %y
+  %le = icmp uge i32 %x, %y
   %r = select i1 %le, i8 %ne, i8 1
   ret i8 %r
 }
@@ -513,3 +527,17 @@ define i8 @ucmp_from_select_ge_neg4(i32 %x, i32 %y) {
   %r = select i1 %ge, i8 %ne, i8 3
   ret i8 %r
 }
+
+; Fold (x > y) ? 1 : sext(x < y)
+define i8 @ucmp_from_select_gt_and_lt(i32 %x, i32 %y) {
+; CHECK-LABEL: define i8 @ucmp_from_select_gt_and_lt(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %lt_bool = icmp ult i32 %x, %y
+  %lt = sext i1 %lt_bool to i8
+  %gt = icmp ugt i32 %x, %y
+  %r = select i1 %gt, i8 1, i8 %lt
+  ret i8 %r
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/105272


More information about the llvm-commits mailing list