[llvm] ad91473 - [InstCombine] Improve eq/ne by parts to handle `ult/ugt` equality pattern.
Noah Goldstein via llvm-commits
llvm-commits at lists.llvm.org
Sat Nov 4 17:00:51 PDT 2023
Author: Noah Goldstein
Date: 2023-11-04T19:00:28-05:00
New Revision: ad9147399f196c1c9b6bec373a5ad3afadc51a4a
URL: https://github.com/llvm/llvm-project/commit/ad9147399f196c1c9b6bec373a5ad3afadc51a4a
DIFF: https://github.com/llvm/llvm-project/commit/ad9147399f196c1c9b6bec373a5ad3afadc51a4a.diff
LOG: [InstCombine] Improve eq/ne by parts to handle `ult/ugt` equality pattern.
`(icmp eq/ne (lshr x, C), (lshr y, C))` gets optimized to
`(icmp ult (xor x, y), (1 << C))` for eq, or to
`(icmp ugt (xor x, y), (1 << C) - 1)` for ne. This can cause the current
eq-of-parts detection to miss the high bits, because that comparison may
already have been rewritten into the new pattern.
This commit adds support for detecting and combining the ult/ugt
pattern.
Closes #69884
Added:
Modified:
llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
llvm/test/Transforms/InstCombine/eq-of-parts.ll
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 070d386b2f18d24..46af9bf5eed003a 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -1146,13 +1146,40 @@ Value *InstCombinerImpl::foldEqOfParts(ICmpInst *Cmp0, ICmpInst *Cmp1,
return nullptr;
CmpInst::Predicate Pred = IsAnd ? CmpInst::ICMP_EQ : CmpInst::ICMP_NE;
- if (Cmp0->getPredicate() != Pred || Cmp1->getPredicate() != Pred)
- return nullptr;
+ auto GetMatchPart = [&](ICmpInst *Cmp,
+ unsigned OpNo) -> std::optional<IntPart> {
+ if (Pred == Cmp->getPredicate())
+ return matchIntPart(Cmp->getOperand(OpNo));
+
+ const APInt *C;
+ // (icmp eq (lshr x, C), (lshr y, C)) gets optimized to:
+ // (icmp ult (xor x, y), 1 << C) so also look for that.
+ if (Pred == CmpInst::ICMP_EQ && Cmp->getPredicate() == CmpInst::ICMP_ULT) {
+ if (!match(Cmp->getOperand(1), m_Power2(C)) ||
+ !match(Cmp->getOperand(0), m_Xor(m_Value(), m_Value())))
+ return std::nullopt;
+ }
+
+ // (icmp ne (lshr x, C), (lshr y, C)) gets optimized to:
+ // (icmp ugt (xor x, y), (1 << C) - 1) so also look for that.
+ else if (Pred == CmpInst::ICMP_NE &&
+ Cmp->getPredicate() == CmpInst::ICMP_UGT) {
+ if (!match(Cmp->getOperand(1), m_LowBitMask(C)) ||
+ !match(Cmp->getOperand(0), m_Xor(m_Value(), m_Value())))
+ return std::nullopt;
+ } else {
+ return std::nullopt;
+ }
+
+ unsigned From = Pred == CmpInst::ICMP_NE ? C->popcount() : C->countr_zero();
+ Instruction *I = cast<Instruction>(Cmp->getOperand(0));
+ return {{I->getOperand(OpNo), From, C->getBitWidth() - From}};
+ };
- std::optional<IntPart> L0 = matchIntPart(Cmp0->getOperand(0));
- std::optional<IntPart> R0 = matchIntPart(Cmp0->getOperand(1));
- std::optional<IntPart> L1 = matchIntPart(Cmp1->getOperand(0));
- std::optional<IntPart> R1 = matchIntPart(Cmp1->getOperand(1));
+ std::optional<IntPart> L0 = GetMatchPart(Cmp0, 0);
+ std::optional<IntPart> R0 = GetMatchPart(Cmp0, 1);
+ std::optional<IntPart> L1 = GetMatchPart(Cmp1, 0);
+ std::optional<IntPart> R1 = GetMatchPart(Cmp1, 1);
if (!L0 || !R0 || !L1 || !R1)
return nullptr;
diff --git a/llvm/test/Transforms/InstCombine/eq-of-parts.ll b/llvm/test/Transforms/InstCombine/eq-of-parts.ll
index 5c220bde187d082..57b15ae3b96e66e 100644
--- a/llvm/test/Transforms/InstCombine/eq-of-parts.ll
+++ b/llvm/test/Transforms/InstCombine/eq-of-parts.ll
@@ -1336,12 +1336,7 @@ define i1 @ne_21_wrong_pred2(i32 %x, i32 %y) {
define i1 @eq_optimized_highbits_cmp(i32 %x, i32 %y) {
; CHECK-LABEL: @eq_optimized_highbits_cmp(
-; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[Y:%.*]], [[X:%.*]]
-; CHECK-NEXT: [[CMP_HI:%.*]] = icmp ult i32 [[XOR]], 33554432
-; CHECK-NEXT: [[TX:%.*]] = trunc i32 [[X]] to i25
-; CHECK-NEXT: [[TY:%.*]] = trunc i32 [[Y]] to i25
-; CHECK-NEXT: [[CMP_LO:%.*]] = icmp eq i25 [[TX]], [[TY]]
-; CHECK-NEXT: [[R:%.*]] = and i1 [[CMP_HI]], [[CMP_LO]]
+; CHECK-NEXT: [[R:%.*]] = icmp eq i32 [[Y:%.*]], [[X:%.*]]
; CHECK-NEXT: ret i1 [[R]]
;
%xor = xor i32 %y, %x
@@ -1393,12 +1388,7 @@ define i1 @eq_optimized_highbits_cmp_fail_not_pow2(i32 %x, i32 %y) {
define i1 @ne_optimized_highbits_cmp(i32 %x, i32 %y) {
; CHECK-LABEL: @ne_optimized_highbits_cmp(
-; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[Y:%.*]], [[X:%.*]]
-; CHECK-NEXT: [[CMP_HI:%.*]] = icmp ugt i32 [[XOR]], 16777215
-; CHECK-NEXT: [[TX:%.*]] = trunc i32 [[X]] to i24
-; CHECK-NEXT: [[TY:%.*]] = trunc i32 [[Y]] to i24
-; CHECK-NEXT: [[CMP_LO:%.*]] = icmp ne i24 [[TX]], [[TY]]
-; CHECK-NEXT: [[R:%.*]] = or i1 [[CMP_HI]], [[CMP_LO]]
+; CHECK-NEXT: [[R:%.*]] = icmp ne i32 [[Y:%.*]], [[X:%.*]]
; CHECK-NEXT: ret i1 [[R]]
;
%xor = xor i32 %y, %x
More information about the llvm-commits
mailing list