[llvm] [InstCombine] Fold `(x < y) ? -1 : zext(x != y)` into `u/scmp(x,y)` (PR #101049)

via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 29 10:57:12 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Volodymyr Vasylkun (Poseydon42)

<details>
<summary>Changes</summary>

This patch adds the fold described in the title to InstCombine. The pattern arises when naive implementations of three-way comparison in high-level languages are lowered to LLVM IR and then optimized.

[Proofs](https://alive2.llvm.org/ce/z/w4QLq_)

This would close #99746.



---
Full diff: https://github.com/llvm/llvm-project/pull/101049.diff


4 Files Affected:

- (modified) llvm/lib/Transforms/InstCombine/InstCombineInternal.h (+1) 
- (modified) llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp (+35) 
- (modified) llvm/test/Transforms/InstCombine/scmp.ll (+96) 
- (modified) llvm/test/Transforms/InstCombine/ucmp.ll (+96) 


``````````diff
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 64fbcc80e0edf..19d27e277bfda 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -731,6 +731,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
 
   // Helpers of visitSelectInst().
   Instruction *foldSelectOfBools(SelectInst &SI);
+  Instruction *foldSelectToCmp(SelectInst &SI);
   Instruction *foldSelectExtConst(SelectInst &Sel);
   Instruction *foldSelectOpOp(SelectInst &SI, Instruction *TI, Instruction *FI);
   Instruction *foldSelectIntoOp(SelectInst &SI, Value *, Value *);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index aaf4ece3249a2..a6815f92955b0 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -3529,6 +3529,38 @@ static Instruction *foldBitCeil(SelectInst &SI, IRBuilderBase &Builder) {
                                 Masked);
 }
 
+// Fold `(x < y) ? -1 : zext(x != y)`, i.e. the sequence
+//   %lt = icmp ult/slt iM %x, %y
+//   %ne0 = icmp ne iM %x, %y
+//   %ne = zext i1 %ne0 to iN
+//   %r = select i1 %lt, iN -1, iN %ne
+// into `%r = call iN @llvm.ucmp/scmp(%x, %y)`; ult selects ucmp, slt scmp.
+// TODO: also match commuted `icmp ne %y, %x` and swapped `icmp sgt %y, %x`.
+Instruction *InstCombinerImpl::foldSelectToCmp(SelectInst &SI) {
+  // m_AllOnes also accepts splat all-ones vector constants.
+  if (!match(SI.getTrueValue(), m_AllOnes()))
+    return nullptr;
+
+  Value *LHS, *RHS;
+  ICmpInst::Predicate NEPred, LTPred;
+  if (!match(SI.getFalseValue(),
+             m_ZExt(m_ICmp(NEPred, m_Value(LHS), m_Value(RHS)))) ||
+      NEPred != ICmpInst::ICMP_NE ||
+      !match(SI.getCondition(),
+             m_ICmp(LTPred, m_Specific(LHS), m_Specific(RHS))) ||
+      !ICmpInst::isLT(LTPred))
+    return nullptr;
+
+  // icmp also accepts pointers, but ucmp/scmp only take integer operands.
+  if (!LHS->getType()->isIntOrIntVectorTy())
+    return nullptr;
+
+  Intrinsic::ID IID =
+      ICmpInst::isSigned(LTPred) ? Intrinsic::scmp : Intrinsic::ucmp;
+  return replaceInstUsesWith(
+      SI, Builder.CreateIntrinsic(SI.getType(), IID, {LHS, RHS}));
+}
+
 bool InstCombinerImpl::fmulByZeroIsZero(Value *MulVal, FastMathFlags FMF,
                                         const Instruction *CtxI) const {
   KnownFPClass Known = computeKnownFPClass(MulVal, FMF, fcNegative, CtxI);
@@ -4111,5 +4143,8 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
     }
   }
 
+  if (Instruction *I = foldSelectToCmp(SI))
+    return I;
+
   return nullptr;
 }
diff --git a/llvm/test/Transforms/InstCombine/scmp.ll b/llvm/test/Transforms/InstCombine/scmp.ll
index 2523872562cad..5ae7970499fa7 100644
--- a/llvm/test/Transforms/InstCombine/scmp.ll
+++ b/llvm/test/Transforms/InstCombine/scmp.ll
@@ -183,3 +183,99 @@ define i8 @scmp_negated_multiuse(i32 %x, i32 %y) {
   %2 = sub i8 0, %1
   ret i8 %2
 }
+
+; Fold ((x s< y) ? -1 : (x != y)) into scmp(x, y)
+define i8 @scmp_from_select(i32 %x, i32 %y) {
+; CHECK-LABEL: define i8 @scmp_from_select(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = call i8 @llvm.scmp.i8.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %ne_bool = icmp ne i32 %x, %y
+  %ne = zext i1 %ne_bool to i8
+  %lt = icmp slt i32 %x, %y
+  %r = select i1 %lt, i8 -1, i8 %ne
+  ret i8 %r
+}
+
+; Negative test: false value of the select is not `icmp ne x, y`
+define i8 @scmp_from_select_neg1(i32 %x, i32 %y) {
+; CHECK-LABEL: define i8 @scmp_from_select_neg1(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT:    [[NE_BOOL:%.*]] = icmp eq i32 [[X]], [[Y]]
+; CHECK-NEXT:    [[NE:%.*]] = zext i1 [[NE_BOOL]] to i8
+; CHECK-NEXT:    [[LT:%.*]] = icmp slt i32 [[X]], [[Y]]
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[LT]], i8 -1, i8 [[NE]]
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %ne_bool = icmp eq i32 %x, %y
+  %ne = zext i1 %ne_bool to i8
+  %lt = icmp slt i32 %x, %y
+  %r = select i1 %lt, i8 -1, i8 %ne
+  ret i8 %r
+}
+
+define i8 @scmp_from_select_neg2(i32 %x, i32 %y) {
+; CHECK-LABEL: define i8 @scmp_from_select_neg2(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT:    [[NE_BOOL:%.*]] = icmp ne i32 [[Y]], [[X]]
+; CHECK-NEXT:    [[NE:%.*]] = zext i1 [[NE_BOOL]] to i8
+; CHECK-NEXT:    [[LT:%.*]] = icmp slt i32 [[X]], [[Y]]
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[LT]], i8 -1, i8 [[NE]]
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %ne_bool = icmp ne i32 %y, %x
+  %ne = zext i1 %ne_bool to i8
+  %lt = icmp slt i32 %x, %y
+  %r = select i1 %lt, i8 -1, i8 %ne
+  ret i8 %r
+}
+
+; Negative test: true value of select is not -1
+define i8 @scmp_from_select_neg3(i32 %x, i32 %y) {
+; CHECK-LABEL: define i8 @scmp_from_select_neg3(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT:    [[NE_BOOL:%.*]] = icmp ne i32 [[X]], [[Y]]
+; CHECK-NEXT:    [[NE:%.*]] = zext i1 [[NE_BOOL]] to i8
+; CHECK-NEXT:    [[LT:%.*]] = icmp slt i32 [[X]], [[Y]]
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[LT]], i8 2, i8 [[NE]]
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %ne_bool = icmp ne i32 %x, %y
+  %ne = zext i1 %ne_bool to i8
+  %lt = icmp slt i32 %x, %y
+  %r = select i1 %lt, i8 2, i8 %ne
+  ret i8 %r
+}
+
+; Negative test: false value of select is sign-extended instead of zero-extended
+define i8 @scmp_from_select_neg4(i32 %x, i32 %y) {
+; CHECK-LABEL: define i8 @scmp_from_select_neg4(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT:    [[NE_BOOL:%.*]] = icmp ne i32 [[X]], [[Y]]
+; CHECK-NEXT:    [[R:%.*]] = sext i1 [[NE_BOOL]] to i8
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %ne_bool = icmp ne i32 %x, %y
+  %ne = sext i1 %ne_bool to i8
+  %lt = icmp slt i32 %x, %y
+  %r = select i1 %lt, i8 -1, i8 %ne
+  ret i8 %r
+}
+
+; Negative test: condition of select is (x s> y), not (x s< y). Equals scmp(y, x), currently not folded
+define i8 @scmp_from_select_neg5(i32 %x, i32 %y) {
+; CHECK-LABEL: define i8 @scmp_from_select_neg5(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT:    [[NE_BOOL:%.*]] = icmp ne i32 [[X]], [[Y]]
+; CHECK-NEXT:    [[NE:%.*]] = zext i1 [[NE_BOOL]] to i8
+; CHECK-NEXT:    [[LT:%.*]] = icmp sgt i32 [[X]], [[Y]]
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[LT]], i8 -1, i8 [[NE]]
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %ne_bool = icmp ne i32 %x, %y
+  %ne = zext i1 %ne_bool to i8
+  %lt = icmp sgt i32 %x, %y
+  %r = select i1 %lt, i8 -1, i8 %ne
+  ret i8 %r
+}
diff --git a/llvm/test/Transforms/InstCombine/ucmp.ll b/llvm/test/Transforms/InstCombine/ucmp.ll
index 7210455094baa..f7ce432885616 100644
--- a/llvm/test/Transforms/InstCombine/ucmp.ll
+++ b/llvm/test/Transforms/InstCombine/ucmp.ll
@@ -183,3 +183,99 @@ define i8 @ucmp_negated_multiuse(i32 %x, i32 %y) {
   %2 = sub i8 0, %1
   ret i8 %2
 }
+
+; Fold ((x u< y) ? -1 : (x != y)) into ucmp(x, y)
+define i8 @ucmp_from_select(i32 %x, i32 %y) {
+; CHECK-LABEL: define i8 @ucmp_from_select(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %ne_bool = icmp ne i32 %x, %y
+  %ne = zext i1 %ne_bool to i8
+  %lt = icmp ult i32 %x, %y
+  %r = select i1 %lt, i8 -1, i8 %ne
+  ret i8 %r
+}
+
+; Negative test: false value of the select is not `icmp ne x, y`
+define i8 @ucmp_from_select_neg1(i32 %x, i32 %y) {
+; CHECK-LABEL: define i8 @ucmp_from_select_neg1(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT:    [[NE_BOOL:%.*]] = icmp eq i32 [[X]], [[Y]]
+; CHECK-NEXT:    [[NE:%.*]] = zext i1 [[NE_BOOL]] to i8
+; CHECK-NEXT:    [[LT:%.*]] = icmp ult i32 [[X]], [[Y]]
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[LT]], i8 -1, i8 [[NE]]
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %ne_bool = icmp eq i32 %x, %y
+  %ne = zext i1 %ne_bool to i8
+  %lt = icmp ult i32 %x, %y
+  %r = select i1 %lt, i8 -1, i8 %ne
+  ret i8 %r
+}
+
+define i8 @ucmp_from_select_neg2(i32 %x, i32 %y) {
+; CHECK-LABEL: define i8 @ucmp_from_select_neg2(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT:    [[NE_BOOL:%.*]] = icmp ne i32 [[Y]], [[X]]
+; CHECK-NEXT:    [[NE:%.*]] = zext i1 [[NE_BOOL]] to i8
+; CHECK-NEXT:    [[LT:%.*]] = icmp ult i32 [[X]], [[Y]]
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[LT]], i8 -1, i8 [[NE]]
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %ne_bool = icmp ne i32 %y, %x
+  %ne = zext i1 %ne_bool to i8
+  %lt = icmp ult i32 %x, %y
+  %r = select i1 %lt, i8 -1, i8 %ne
+  ret i8 %r
+}
+
+; Negative test: true value of select is not -1
+define i8 @ucmp_from_select_neg3(i32 %x, i32 %y) {
+; CHECK-LABEL: define i8 @ucmp_from_select_neg3(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT:    [[NE_BOOL:%.*]] = icmp ne i32 [[X]], [[Y]]
+; CHECK-NEXT:    [[NE:%.*]] = zext i1 [[NE_BOOL]] to i8
+; CHECK-NEXT:    [[LT:%.*]] = icmp ult i32 [[X]], [[Y]]
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[LT]], i8 2, i8 [[NE]]
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %ne_bool = icmp ne i32 %x, %y
+  %ne = zext i1 %ne_bool to i8
+  %lt = icmp ult i32 %x, %y
+  %r = select i1 %lt, i8 2, i8 %ne
+  ret i8 %r
+}
+
+; Negative test: false value of select is sign-extended instead of zero-extended
+define i8 @ucmp_from_select_neg4(i32 %x, i32 %y) {
+; CHECK-LABEL: define i8 @ucmp_from_select_neg4(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT:    [[NE_BOOL:%.*]] = icmp ne i32 [[X]], [[Y]]
+; CHECK-NEXT:    [[R:%.*]] = sext i1 [[NE_BOOL]] to i8
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %ne_bool = icmp ne i32 %x, %y
+  %ne = sext i1 %ne_bool to i8
+  %lt = icmp ult i32 %x, %y
+  %r = select i1 %lt, i8 -1, i8 %ne
+  ret i8 %r
+}
+
+; Negative test: condition of select is (x u<= y), not (x u< y)
+define i8 @ucmp_from_select_neg5(i32 %x, i32 %y) {
+; CHECK-LABEL: define i8 @ucmp_from_select_neg5(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT:    [[NE_BOOL:%.*]] = icmp ne i32 [[X]], [[Y]]
+; CHECK-NEXT:    [[NE:%.*]] = zext i1 [[NE_BOOL]] to i8
+; CHECK-NEXT:    [[LT_NOT:%.*]] = icmp ugt i32 [[X]], [[Y]]
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[LT_NOT]], i8 [[NE]], i8 -1
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %ne_bool = icmp ne i32 %x, %y
+  %ne = zext i1 %ne_bool to i8
+  %lt = icmp ule i32 %x, %y
+  %r = select i1 %lt, i8 -1, i8 %ne
+  ret i8 %r
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/101049


More information about the llvm-commits mailing list