[llvm] [InstCombine] Fold `(x < y) ? -1 : zext(x != y)` into `u/scmp(x,y)` (PR #101049)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Jul 29 10:57:12 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms
Author: Volodymyr Vasylkun (Poseydon42)
<details>
<summary>Changes</summary>
This patch adds a fold to InstCombine that turns `(x < y) ? -1 : zext(x != y)` into `ucmp(x, y)` / `scmp(x, y)`. This pattern is produced when naive implementations of a three-way comparison in high-level languages are lowered to LLVM IR and then optimized.
[Proofs](https://alive2.llvm.org/ce/z/w4QLq_)
This would close #<!-- -->99746
---
Full diff: https://github.com/llvm/llvm-project/pull/101049.diff
4 Files Affected:
- (modified) llvm/lib/Transforms/InstCombine/InstCombineInternal.h (+1)
- (modified) llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp (+35)
- (modified) llvm/test/Transforms/InstCombine/scmp.ll (+96)
- (modified) llvm/test/Transforms/InstCombine/ucmp.ll (+96)
``````````diff
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 64fbcc80e0edf..19d27e277bfda 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -731,6 +731,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
// Helpers of visitSelectInst().
Instruction *foldSelectOfBools(SelectInst &SI);
+ Instruction *foldSelectToCmp(SelectInst &SI);
Instruction *foldSelectExtConst(SelectInst &Sel);
Instruction *foldSelectOpOp(SelectInst &SI, Instruction *TI, Instruction *FI);
Instruction *foldSelectIntoOp(SelectInst &SI, Value *, Value *);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index aaf4ece3249a2..a6815f92955b0 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -3529,6 +3529,38 @@ static Instruction *foldBitCeil(SelectInst &SI, IRBuilderBase &Builder) {
Masked);
}
+// Try to fold a select that hand-implements a three-way comparison:
+//   %lt  = icmp ult/slt iM %x, %y
+//   %ne0 = icmp ne iM %x, %y
+//   %ne  = zext i1 %ne0 to iN
+//   %r   = select i1 %lt, iN -1, iN %ne
+// into
+//   %r = call iN @llvm.ucmp/scmp(%x, %y)
+//
+// Both compares must use the same operands in the same order; a commuted
+// `icmp ne %y, %x` is not matched. Returns the value produced by
+// replaceInstUsesWith() when the fold applies and nullptr otherwise.
+Instruction *InstCombinerImpl::foldSelectToCmp(SelectInst &SI) {
+  // The true arm has to be the ConstantInt -1. A single checked dyn_cast<>
+  // replaces the isa<> test followed by a second, unchecked dyn_cast<>.
+  auto *TrueC = dyn_cast<ConstantInt>(SI.getTrueValue());
+  if (!TrueC || !TrueC->isAllOnesValue())
+    return nullptr;
+
+  // The false arm has to be zext(icmp ne LHS, RHS); this binds LHS and RHS.
+  Value *LHS, *RHS;
+  ICmpInst::Predicate NEPred;
+  if (!match(SI.getFalseValue(),
+             m_ZExt(m_ICmp(NEPred, m_Value(LHS), m_Value(RHS)))) ||
+      NEPred != ICmpInst::ICMP_NE)
+    return nullptr;
+
+  // The condition has to be a less-than compare of exactly the operands bound
+  // above, in the same order.
+  ICmpInst::Predicate LTPred;
+  if (!match(SI.getCondition(),
+             m_ICmp(LTPred, m_Specific(LHS), m_Specific(RHS))) ||
+      !ICmpInst::isLT(LTPred))
+    return nullptr;
+
+  // Signedness of the less-than compare selects scmp vs. ucmp; the result
+  // type is the type of the select's false arm.
+  bool IsSigned = ICmpInst::isSigned(LTPred);
+  Instruction *Result = Builder.CreateIntrinsic(
+      SI.getFalseValue()->getType(),
+      IsSigned ? Intrinsic::scmp : Intrinsic::ucmp, {LHS, RHS});
+  return replaceInstUsesWith(SI, Result);
+}
+
bool InstCombinerImpl::fmulByZeroIsZero(Value *MulVal, FastMathFlags FMF,
const Instruction *CtxI) const {
KnownFPClass Known = computeKnownFPClass(MulVal, FMF, fcNegative, CtxI);
@@ -4111,5 +4143,8 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
}
}
+ if (auto *Instruction = foldSelectToCmp(SI))
+ return Instruction;
+
return nullptr;
}
diff --git a/llvm/test/Transforms/InstCombine/scmp.ll b/llvm/test/Transforms/InstCombine/scmp.ll
index 2523872562cad..5ae7970499fa7 100644
--- a/llvm/test/Transforms/InstCombine/scmp.ll
+++ b/llvm/test/Transforms/InstCombine/scmp.ll
@@ -183,3 +183,99 @@ define i8 @scmp_negated_multiuse(i32 %x, i32 %y) {
%2 = sub i8 0, %1
ret i8 %2
}
+
+; Positive test: ((x s< y) ? -1 : zext(x != y)) is folded into a single
+; llvm.scmp(x, y) call.
+define i8 @scmp_from_select(i32 %x, i32 %y) {
+; CHECK-LABEL: define i8 @scmp_from_select(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT: [[R:%.*]] = call i8 @llvm.scmp.i8.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT: ret i8 [[R]]
+;
+ %ne_bool = icmp ne i32 %x, %y
+ %ne = zext i1 %ne_bool to i8
+ %lt = icmp slt i32 %x, %y
+ %r = select i1 %lt, i8 -1, i8 %ne
+ ret i8 %r
+}
+
+; Negative test: the compare feeding the zext in the false arm is
+; `icmp eq x, y` rather than `icmp ne x, y`, so the select is left unchanged.
+define i8 @scmp_from_select_neg1(i32 %x, i32 %y) {
+; CHECK-LABEL: define i8 @scmp_from_select_neg1(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT: [[NE_BOOL:%.*]] = icmp eq i32 [[X]], [[Y]]
+; CHECK-NEXT: [[NE:%.*]] = zext i1 [[NE_BOOL]] to i8
+; CHECK-NEXT: [[LT:%.*]] = icmp slt i32 [[X]], [[Y]]
+; CHECK-NEXT: [[R:%.*]] = select i1 [[LT]], i8 -1, i8 [[NE]]
+; CHECK-NEXT: ret i8 [[R]]
+;
+ %ne_bool = icmp eq i32 %x, %y
+ %ne = zext i1 %ne_bool to i8
+ %lt = icmp slt i32 %x, %y
+ %r = select i1 %lt, i8 -1, i8 %ne
+ ret i8 %r
+}
+
+; Negative test: the `icmp ne` has its operands swapped (y, x) relative to the
+; select condition (x, y); the CHECK lines expect the select to be left as is.
+define i8 @scmp_from_select_neg2(i32 %x, i32 %y) {
+; CHECK-LABEL: define i8 @scmp_from_select_neg2(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT: [[NE_BOOL:%.*]] = icmp ne i32 [[Y]], [[X]]
+; CHECK-NEXT: [[NE:%.*]] = zext i1 [[NE_BOOL]] to i8
+; CHECK-NEXT: [[LT:%.*]] = icmp slt i32 [[X]], [[Y]]
+; CHECK-NEXT: [[R:%.*]] = select i1 [[LT]], i8 -1, i8 [[NE]]
+; CHECK-NEXT: ret i8 [[R]]
+;
+ %ne_bool = icmp ne i32 %y, %x
+ %ne = zext i1 %ne_bool to i8
+ %lt = icmp slt i32 %x, %y
+ %r = select i1 %lt, i8 -1, i8 %ne
+ ret i8 %r
+}
+
+; Negative test: the true value of the select is 2 rather than -1, so the
+; select is left unchanged.
+define i8 @scmp_from_select_neg3(i32 %x, i32 %y) {
+; CHECK-LABEL: define i8 @scmp_from_select_neg3(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT: [[NE_BOOL:%.*]] = icmp ne i32 [[X]], [[Y]]
+; CHECK-NEXT: [[NE:%.*]] = zext i1 [[NE_BOOL]] to i8
+; CHECK-NEXT: [[LT:%.*]] = icmp slt i32 [[X]], [[Y]]
+; CHECK-NEXT: [[R:%.*]] = select i1 [[LT]], i8 2, i8 [[NE]]
+; CHECK-NEXT: ret i8 [[R]]
+;
+ %ne_bool = icmp ne i32 %x, %y
+ %ne = zext i1 %ne_bool to i8
+ %lt = icmp slt i32 %x, %y
+ %r = select i1 %lt, i8 2, i8 %ne
+ ret i8 %r
+}
+
+; Negative test: false value of select is sign-extended instead of
+; zero-extended. No scmp call is expected; per the CHECK lines the select
+; collapses to the sext of (x != y).
+define i8 @scmp_from_select_neg4(i32 %x, i32 %y) {
+; CHECK-LABEL: define i8 @scmp_from_select_neg4(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT: [[NE_BOOL:%.*]] = icmp ne i32 [[X]], [[Y]]
+; CHECK-NEXT: [[R:%.*]] = sext i1 [[NE_BOOL]] to i8
+; CHECK-NEXT: ret i8 [[R]]
+;
+ %ne_bool = icmp ne i32 %x, %y
+ %ne = sext i1 %ne_bool to i8
+ %lt = icmp slt i32 %x, %y
+ %r = select i1 %lt, i8 -1, i8 %ne
+ ret i8 %r
+}
+
+; Negative test: condition of select is (x s> y) rather than (x s< y), so the
+; select is left unchanged.
+define i8 @scmp_from_select_neg5(i32 %x, i32 %y) {
+; CHECK-LABEL: define i8 @scmp_from_select_neg5(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT: [[NE_BOOL:%.*]] = icmp ne i32 [[X]], [[Y]]
+; CHECK-NEXT: [[NE:%.*]] = zext i1 [[NE_BOOL]] to i8
+; CHECK-NEXT: [[LT:%.*]] = icmp sgt i32 [[X]], [[Y]]
+; CHECK-NEXT: [[R:%.*]] = select i1 [[LT]], i8 -1, i8 [[NE]]
+; CHECK-NEXT: ret i8 [[R]]
+;
+ %ne_bool = icmp ne i32 %x, %y
+ %ne = zext i1 %ne_bool to i8
+ %lt = icmp sgt i32 %x, %y
+ %r = select i1 %lt, i8 -1, i8 %ne
+ ret i8 %r
+}
diff --git a/llvm/test/Transforms/InstCombine/ucmp.ll b/llvm/test/Transforms/InstCombine/ucmp.ll
index 7210455094baa..f7ce432885616 100644
--- a/llvm/test/Transforms/InstCombine/ucmp.ll
+++ b/llvm/test/Transforms/InstCombine/ucmp.ll
@@ -183,3 +183,99 @@ define i8 @ucmp_negated_multiuse(i32 %x, i32 %y) {
%2 = sub i8 0, %1
ret i8 %2
}
+
+; Positive test: ((x u< y) ? -1 : zext(x != y)) is folded into a single
+; llvm.ucmp(x, y) call.
+define i8 @ucmp_from_select(i32 %x, i32 %y) {
+; CHECK-LABEL: define i8 @ucmp_from_select(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT: [[R:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT: ret i8 [[R]]
+;
+ %ne_bool = icmp ne i32 %x, %y
+ %ne = zext i1 %ne_bool to i8
+ %lt = icmp ult i32 %x, %y
+ %r = select i1 %lt, i8 -1, i8 %ne
+ ret i8 %r
+}
+
+; Negative test: the compare feeding the zext in the false arm is
+; `icmp eq x, y` rather than `icmp ne x, y`, so the select is left unchanged.
+define i8 @ucmp_from_select_neg1(i32 %x, i32 %y) {
+; CHECK-LABEL: define i8 @ucmp_from_select_neg1(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT: [[NE_BOOL:%.*]] = icmp eq i32 [[X]], [[Y]]
+; CHECK-NEXT: [[NE:%.*]] = zext i1 [[NE_BOOL]] to i8
+; CHECK-NEXT: [[LT:%.*]] = icmp ult i32 [[X]], [[Y]]
+; CHECK-NEXT: [[R:%.*]] = select i1 [[LT]], i8 -1, i8 [[NE]]
+; CHECK-NEXT: ret i8 [[R]]
+;
+ %ne_bool = icmp eq i32 %x, %y
+ %ne = zext i1 %ne_bool to i8
+ %lt = icmp ult i32 %x, %y
+ %r = select i1 %lt, i8 -1, i8 %ne
+ ret i8 %r
+}
+
+; Negative test: the `icmp ne` has its operands swapped (y, x) relative to the
+; select condition (x, y); the CHECK lines expect the select to be left as is.
+define i8 @ucmp_from_select_neg2(i32 %x, i32 %y) {
+; CHECK-LABEL: define i8 @ucmp_from_select_neg2(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT: [[NE_BOOL:%.*]] = icmp ne i32 [[Y]], [[X]]
+; CHECK-NEXT: [[NE:%.*]] = zext i1 [[NE_BOOL]] to i8
+; CHECK-NEXT: [[LT:%.*]] = icmp ult i32 [[X]], [[Y]]
+; CHECK-NEXT: [[R:%.*]] = select i1 [[LT]], i8 -1, i8 [[NE]]
+; CHECK-NEXT: ret i8 [[R]]
+;
+ %ne_bool = icmp ne i32 %y, %x
+ %ne = zext i1 %ne_bool to i8
+ %lt = icmp ult i32 %x, %y
+ %r = select i1 %lt, i8 -1, i8 %ne
+ ret i8 %r
+}
+
+; Negative test: the true value of the select is 2 rather than -1, so the
+; select is left unchanged.
+define i8 @ucmp_from_select_neg3(i32 %x, i32 %y) {
+; CHECK-LABEL: define i8 @ucmp_from_select_neg3(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT: [[NE_BOOL:%.*]] = icmp ne i32 [[X]], [[Y]]
+; CHECK-NEXT: [[NE:%.*]] = zext i1 [[NE_BOOL]] to i8
+; CHECK-NEXT: [[LT:%.*]] = icmp ult i32 [[X]], [[Y]]
+; CHECK-NEXT: [[R:%.*]] = select i1 [[LT]], i8 2, i8 [[NE]]
+; CHECK-NEXT: ret i8 [[R]]
+;
+ %ne_bool = icmp ne i32 %x, %y
+ %ne = zext i1 %ne_bool to i8
+ %lt = icmp ult i32 %x, %y
+ %r = select i1 %lt, i8 2, i8 %ne
+ ret i8 %r
+}
+
+; Negative test: false value of select is sign-extended instead of
+; zero-extended. No ucmp call is expected; per the CHECK lines the select
+; collapses to the sext of (x != y).
+define i8 @ucmp_from_select_neg4(i32 %x, i32 %y) {
+; CHECK-LABEL: define i8 @ucmp_from_select_neg4(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT: [[NE_BOOL:%.*]] = icmp ne i32 [[X]], [[Y]]
+; CHECK-NEXT: [[R:%.*]] = sext i1 [[NE_BOOL]] to i8
+; CHECK-NEXT: ret i8 [[R]]
+;
+ %ne_bool = icmp ne i32 %x, %y
+ %ne = sext i1 %ne_bool to i8
+ %lt = icmp ult i32 %x, %y
+ %r = select i1 %lt, i8 -1, i8 %ne
+ ret i8 %r
+}
+
+; Negative test: condition of select is (x u<= y) rather than (x u< y). No
+; ucmp call is expected; per the CHECK lines the compare becomes `icmp ugt`
+; with the select arms swapped.
+define i8 @ucmp_from_select_neg5(i32 %x, i32 %y) {
+; CHECK-LABEL: define i8 @ucmp_from_select_neg5(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT: [[NE_BOOL:%.*]] = icmp ne i32 [[X]], [[Y]]
+; CHECK-NEXT: [[NE:%.*]] = zext i1 [[NE_BOOL]] to i8
+; CHECK-NEXT: [[LT_NOT:%.*]] = icmp ugt i32 [[X]], [[Y]]
+; CHECK-NEXT: [[R:%.*]] = select i1 [[LT_NOT]], i8 [[NE]], i8 -1
+; CHECK-NEXT: ret i8 [[R]]
+;
+ %ne_bool = icmp ne i32 %x, %y
+ %ne = zext i1 %ne_bool to i8
+ %lt = icmp ule i32 %x, %y
+ %r = select i1 %lt, i8 -1, i8 %ne
+ ret i8 %r
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/101049
More information about the llvm-commits
mailing list