[llvm] [InstCombine] Improve eq/ne by parts to handle ult/ugt equality pattern (PR #69884)

via llvm-commits llvm-commits at lists.llvm.org
Fri Oct 27 12:52:32 PDT 2023


https://github.com/goldsteinn updated https://github.com/llvm/llvm-project/pull/69884

>From 352ceb719f062e97f81d77dc71ea00b3281932b9 Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Sun, 22 Oct 2023 01:47:44 -0500
Subject: [PATCH 1/2] [InstCombine] Add tests for new eq/ne patterns combining
 eq/ne by parts; NFC

---
 .../Transforms/InstCombine/eq-of-parts.ll     | 133 ++++++++++++++++++
 1 file changed, 133 insertions(+)

diff --git a/llvm/test/Transforms/InstCombine/eq-of-parts.ll b/llvm/test/Transforms/InstCombine/eq-of-parts.ll
index dbf671aaaa86b40..5c220bde187d082 100644
--- a/llvm/test/Transforms/InstCombine/eq-of-parts.ll
+++ b/llvm/test/Transforms/InstCombine/eq-of-parts.ll
@@ -1333,3 +1333,136 @@ define i1 @ne_21_wrong_pred2(i32 %x, i32 %y) {
   %c.210 = or i1 %c.2, %c.1
   ret i1 %c.210
 }
+
+define i1 @eq_optimized_highbits_cmp(i32 %x, i32 %y) {
+; CHECK-LABEL: @eq_optimized_highbits_cmp(
+; CHECK-NEXT:    [[XOR:%.*]] = xor i32 [[Y:%.*]], [[X:%.*]]
+; CHECK-NEXT:    [[CMP_HI:%.*]] = icmp ult i32 [[XOR]], 33554432
+; CHECK-NEXT:    [[TX:%.*]] = trunc i32 [[X]] to i25
+; CHECK-NEXT:    [[TY:%.*]] = trunc i32 [[Y]] to i25
+; CHECK-NEXT:    [[CMP_LO:%.*]] = icmp eq i25 [[TX]], [[TY]]
+; CHECK-NEXT:    [[R:%.*]] = and i1 [[CMP_HI]], [[CMP_LO]]
+; CHECK-NEXT:    ret i1 [[R]]
+;
+  %xor = xor i32 %y, %x
+  %cmp_hi = icmp ult i32 %xor, 33554432
+  %tx = trunc i32 %x to i25
+  %ty = trunc i32 %y to i25
+  %cmp_lo = icmp eq i25 %tx, %ty
+  %r = and i1 %cmp_hi, %cmp_lo
+  ret i1 %r
+}
+
+define i1 @eq_optimized_highbits_cmp_todo_overlapping(i32 %x, i32 %y) {
+; CHECK-LABEL: @eq_optimized_highbits_cmp_todo_overlapping(
+; CHECK-NEXT:    [[XOR:%.*]] = xor i32 [[Y:%.*]], [[X:%.*]]
+; CHECK-NEXT:    [[CMP_HI:%.*]] = icmp ult i32 [[XOR]], 16777216
+; CHECK-NEXT:    [[TX:%.*]] = trunc i32 [[X]] to i25
+; CHECK-NEXT:    [[TY:%.*]] = trunc i32 [[Y]] to i25
+; CHECK-NEXT:    [[CMP_LO:%.*]] = icmp eq i25 [[TX]], [[TY]]
+; CHECK-NEXT:    [[R:%.*]] = and i1 [[CMP_HI]], [[CMP_LO]]
+; CHECK-NEXT:    ret i1 [[R]]
+;
+  %xor = xor i32 %y, %x
+  %cmp_hi = icmp ult i32 %xor, 16777216
+  %tx = trunc i32 %x to i25
+  %ty = trunc i32 %y to i25
+  %cmp_lo = icmp eq i25 %tx, %ty
+  %r = and i1 %cmp_hi, %cmp_lo
+  ret i1 %r
+}
+
+define i1 @eq_optimized_highbits_cmp_fail_not_pow2(i32 %x, i32 %y) {
+; CHECK-LABEL: @eq_optimized_highbits_cmp_fail_not_pow2(
+; CHECK-NEXT:    [[XOR:%.*]] = xor i32 [[Y:%.*]], [[X:%.*]]
+; CHECK-NEXT:    [[CMP_HI:%.*]] = icmp ult i32 [[XOR]], 16777215
+; CHECK-NEXT:    [[TX:%.*]] = trunc i32 [[X]] to i24
+; CHECK-NEXT:    [[TY:%.*]] = trunc i32 [[Y]] to i24
+; CHECK-NEXT:    [[CMP_LO:%.*]] = icmp eq i24 [[TX]], [[TY]]
+; CHECK-NEXT:    [[R:%.*]] = and i1 [[CMP_HI]], [[CMP_LO]]
+; CHECK-NEXT:    ret i1 [[R]]
+;
+  %xor = xor i32 %y, %x
+  %cmp_hi = icmp ult i32 %xor, 16777215
+  %tx = trunc i32 %x to i24
+  %ty = trunc i32 %y to i24
+  %cmp_lo = icmp eq i24 %tx, %ty
+  %r = and i1 %cmp_hi, %cmp_lo
+  ret i1 %r
+}
+
+define i1 @ne_optimized_highbits_cmp(i32 %x, i32 %y) {
+; CHECK-LABEL: @ne_optimized_highbits_cmp(
+; CHECK-NEXT:    [[XOR:%.*]] = xor i32 [[Y:%.*]], [[X:%.*]]
+; CHECK-NEXT:    [[CMP_HI:%.*]] = icmp ugt i32 [[XOR]], 16777215
+; CHECK-NEXT:    [[TX:%.*]] = trunc i32 [[X]] to i24
+; CHECK-NEXT:    [[TY:%.*]] = trunc i32 [[Y]] to i24
+; CHECK-NEXT:    [[CMP_LO:%.*]] = icmp ne i24 [[TX]], [[TY]]
+; CHECK-NEXT:    [[R:%.*]] = or i1 [[CMP_HI]], [[CMP_LO]]
+; CHECK-NEXT:    ret i1 [[R]]
+;
+  %xor = xor i32 %y, %x
+  %cmp_hi = icmp ugt i32 %xor, 16777215
+  %tx = trunc i32 %x to i24
+  %ty = trunc i32 %y to i24
+  %cmp_lo = icmp ne i24 %tx, %ty
+  %r = or i1 %cmp_hi, %cmp_lo
+  ret i1 %r
+}
+
+define i1 @ne_optimized_highbits_cmp_fail_not_mask(i32 %x, i32 %y) {
+; CHECK-LABEL: @ne_optimized_highbits_cmp_fail_not_mask(
+; CHECK-NEXT:    [[XOR:%.*]] = xor i32 [[Y:%.*]], [[X:%.*]]
+; CHECK-NEXT:    [[CMP_HI:%.*]] = icmp ugt i32 [[XOR]], 16777216
+; CHECK-NEXT:    [[TX:%.*]] = trunc i32 [[X]] to i24
+; CHECK-NEXT:    [[TY:%.*]] = trunc i32 [[Y]] to i24
+; CHECK-NEXT:    [[CMP_LO:%.*]] = icmp ne i24 [[TX]], [[TY]]
+; CHECK-NEXT:    [[R:%.*]] = or i1 [[CMP_HI]], [[CMP_LO]]
+; CHECK-NEXT:    ret i1 [[R]]
+;
+  %xor = xor i32 %y, %x
+  %cmp_hi = icmp ugt i32 %xor, 16777216
+  %tx = trunc i32 %x to i24
+  %ty = trunc i32 %y to i24
+  %cmp_lo = icmp ne i24 %tx, %ty
+  %r = or i1 %cmp_hi, %cmp_lo
+  ret i1 %r
+}
+
+define i1 @ne_optimized_highbits_cmp_fail_no_combined_int(i32 %x, i32 %y) {
+; CHECK-LABEL: @ne_optimized_highbits_cmp_fail_no_combined_int(
+; CHECK-NEXT:    [[XOR:%.*]] = xor i32 [[Y:%.*]], [[X:%.*]]
+; CHECK-NEXT:    [[CMP_HI:%.*]] = icmp ugt i32 [[XOR]], 16777215
+; CHECK-NEXT:    [[TX:%.*]] = trunc i32 [[X]] to i23
+; CHECK-NEXT:    [[TY:%.*]] = trunc i32 [[Y]] to i23
+; CHECK-NEXT:    [[CMP_LO:%.*]] = icmp ne i23 [[TX]], [[TY]]
+; CHECK-NEXT:    [[R:%.*]] = or i1 [[CMP_HI]], [[CMP_LO]]
+; CHECK-NEXT:    ret i1 [[R]]
+;
+  %xor = xor i32 %y, %x
+  %cmp_hi = icmp ugt i32 %xor, 16777215
+  %tx = trunc i32 %x to i23
+  %ty = trunc i32 %y to i23
+  %cmp_lo = icmp ne i23 %tx, %ty
+  %r = or i1 %cmp_hi, %cmp_lo
+  ret i1 %r
+}
+
+define i1 @ne_optimized_highbits_cmp_todo_overlapping(i32 %x, i32 %y) {
+; CHECK-LABEL: @ne_optimized_highbits_cmp_todo_overlapping(
+; CHECK-NEXT:    [[XOR:%.*]] = xor i32 [[Y:%.*]], [[X:%.*]]
+; CHECK-NEXT:    [[CMP_HI:%.*]] = icmp ugt i32 [[XOR]], 8388607
+; CHECK-NEXT:    [[TX:%.*]] = trunc i32 [[X]] to i24
+; CHECK-NEXT:    [[TY:%.*]] = trunc i32 [[Y]] to i24
+; CHECK-NEXT:    [[CMP_LO:%.*]] = icmp ne i24 [[TX]], [[TY]]
+; CHECK-NEXT:    [[R:%.*]] = or i1 [[CMP_HI]], [[CMP_LO]]
+; CHECK-NEXT:    ret i1 [[R]]
+;
+  %xor = xor i32 %y, %x
+  %cmp_hi = icmp ugt i32 %xor, 8388607
+  %tx = trunc i32 %x to i24
+  %ty = trunc i32 %y to i24
+  %cmp_lo = icmp ne i24 %tx, %ty
+  %r = or i1 %cmp_hi, %cmp_lo
+  ret i1 %r
+}

>From 6a2540373028b77a5e732f3e938d798f7269616d Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Sun, 22 Oct 2023 01:19:53 -0500
Subject: [PATCH 2/2] [InstCombine] Improve eq/ne by parts to handle `ult/ugt`
 equality pattern.

`(icmp eq (lshr x, C), (lshr y, C))` gets optimized to
`(icmp ult (xor x, y), (1 << C))`, and the corresponding `ne` form to
`(icmp ugt (xor x, y), (1 << C) - 1)`. This can cause the current
equal-by-parts detection to miss the high-bits comparison, since it may
already have been rewritten into the new pattern.

This commit adds support for detecting / combining the ult/ugt
pattern.
---
 .../InstCombine/InstCombineAndOrXor.cpp       | 46 +++++++++++++++----
 .../Transforms/InstCombine/eq-of-parts.ll     | 14 +-----
 2 files changed, 39 insertions(+), 21 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 070d386b2f18d24..b99da410acf156a 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -1131,9 +1131,9 @@ static Value *extractIntPart(const IntPart &P, IRBuilderBase &Builder) {
   Value *V = P.From;
   if (P.StartBit)
     V = Builder.CreateLShr(V, P.StartBit);
-  Type *TruncTy = V->getType()->getWithNewBitWidth(P.NumBits);
-  if (TruncTy != V->getType())
-    V = Builder.CreateTrunc(V, TruncTy);
+  Type *OutTy = V->getType()->getWithNewBitWidth(P.NumBits);
+  if (OutTy != V->getType())
+    V = Builder.CreateTrunc(V, OutTy);
   return V;
 }
 
@@ -1146,13 +1146,41 @@ Value *InstCombinerImpl::foldEqOfParts(ICmpInst *Cmp0, ICmpInst *Cmp1,
     return nullptr;
 
   CmpInst::Predicate Pred = IsAnd ? CmpInst::ICMP_EQ : CmpInst::ICMP_NE;
-  if (Cmp0->getPredicate() != Pred || Cmp1->getPredicate() != Pred)
-    return nullptr;
+  auto GetMatchPart = [&](ICmpInst *Cmp,
+                          unsigned OpNo) -> std::optional<IntPart> {
+    if (Pred == Cmp->getPredicate())
+      return matchIntPart(Cmp->getOperand(OpNo));
+
+    const APInt *C;
+    // (icmp eq (lshr x, C), (lshr y, C)) gets optimized to:
+    // (icmp ult (xor x, y), 1 << C) so also look for that.
+    if (Pred == CmpInst::ICMP_EQ && Cmp->getPredicate() == CmpInst::ICMP_ULT) {
+      if (!match(Cmp->getOperand(1), m_Power2(C)) ||
+          !match(Cmp->getOperand(0), m_Xor(m_Value(), m_Value())))
+        return std::nullopt;
+    }
+
+    // (icmp ne (lshr x, C), (lshr y, C)) gets optimized to:
+    // (icmp ugt (xor x, y), (1 << C) - 1) so also look for that.
+    else if (Pred == CmpInst::ICMP_NE &&
+             Cmp->getPredicate() == CmpInst::ICMP_UGT) {
+      if (!match(Cmp->getOperand(1), m_LowBitMask(C)) ||
+          !match(Cmp->getOperand(0), m_Xor(m_Value(), m_Value())))
+        return std::nullopt;
+    } else {
+      return std::nullopt;
+    }
+
+    unsigned From = Pred == CmpInst::ICMP_NE ? C->popcount() : C->countr_zero();
+    Instruction *I = cast<Instruction>(Cmp->getOperand(0));
+    return {{I->getOperand(OpNo), From,
+             Cmp->getOperand(0)->getType()->getScalarSizeInBits() - From}};
+  };
 
-  std::optional<IntPart> L0 = matchIntPart(Cmp0->getOperand(0));
-  std::optional<IntPart> R0 = matchIntPart(Cmp0->getOperand(1));
-  std::optional<IntPart> L1 = matchIntPart(Cmp1->getOperand(0));
-  std::optional<IntPart> R1 = matchIntPart(Cmp1->getOperand(1));
+  std::optional<IntPart> L0 = GetMatchPart(Cmp0, 0);
+  std::optional<IntPart> R0 = GetMatchPart(Cmp0, 1);
+  std::optional<IntPart> L1 = GetMatchPart(Cmp1, 0);
+  std::optional<IntPart> R1 = GetMatchPart(Cmp1, 1);
   if (!L0 || !R0 || !L1 || !R1)
     return nullptr;
 
diff --git a/llvm/test/Transforms/InstCombine/eq-of-parts.ll b/llvm/test/Transforms/InstCombine/eq-of-parts.ll
index 5c220bde187d082..57b15ae3b96e66e 100644
--- a/llvm/test/Transforms/InstCombine/eq-of-parts.ll
+++ b/llvm/test/Transforms/InstCombine/eq-of-parts.ll
@@ -1336,12 +1336,7 @@ define i1 @ne_21_wrong_pred2(i32 %x, i32 %y) {
 
 define i1 @eq_optimized_highbits_cmp(i32 %x, i32 %y) {
 ; CHECK-LABEL: @eq_optimized_highbits_cmp(
-; CHECK-NEXT:    [[XOR:%.*]] = xor i32 [[Y:%.*]], [[X:%.*]]
-; CHECK-NEXT:    [[CMP_HI:%.*]] = icmp ult i32 [[XOR]], 33554432
-; CHECK-NEXT:    [[TX:%.*]] = trunc i32 [[X]] to i25
-; CHECK-NEXT:    [[TY:%.*]] = trunc i32 [[Y]] to i25
-; CHECK-NEXT:    [[CMP_LO:%.*]] = icmp eq i25 [[TX]], [[TY]]
-; CHECK-NEXT:    [[R:%.*]] = and i1 [[CMP_HI]], [[CMP_LO]]
+; CHECK-NEXT:    [[R:%.*]] = icmp eq i32 [[Y:%.*]], [[X:%.*]]
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %xor = xor i32 %y, %x
@@ -1393,12 +1388,7 @@ define i1 @eq_optimized_highbits_cmp_fail_not_pow2(i32 %x, i32 %y) {
 
 define i1 @ne_optimized_highbits_cmp(i32 %x, i32 %y) {
 ; CHECK-LABEL: @ne_optimized_highbits_cmp(
-; CHECK-NEXT:    [[XOR:%.*]] = xor i32 [[Y:%.*]], [[X:%.*]]
-; CHECK-NEXT:    [[CMP_HI:%.*]] = icmp ugt i32 [[XOR]], 16777215
-; CHECK-NEXT:    [[TX:%.*]] = trunc i32 [[X]] to i24
-; CHECK-NEXT:    [[TY:%.*]] = trunc i32 [[Y]] to i24
-; CHECK-NEXT:    [[CMP_LO:%.*]] = icmp ne i24 [[TX]], [[TY]]
-; CHECK-NEXT:    [[R:%.*]] = or i1 [[CMP_HI]], [[CMP_LO]]
+; CHECK-NEXT:    [[R:%.*]] = icmp ne i32 [[Y:%.*]], [[X:%.*]]
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %xor = xor i32 %y, %x



More information about the llvm-commits mailing list