[llvm] [InstCombine] Handle trunc i1 pattern in eq-of-parts fold (PR #112704)

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Mon Nov 25 01:36:09 PST 2024


https://github.com/nikic updated https://github.com/llvm/llvm-project/pull/112704

From dd6eff5d59912138d07e56fc0bab66599e3fa2db Mon Sep 17 00:00:00 2001
From: Nikita Popov <npopov at redhat.com>
Date: Wed, 16 Oct 2024 17:23:32 +0200
Subject: [PATCH 1/2] [InstCombine] Handle trunc i1 pattern in eq-of-parts fold

Inequality of the low bits of x and y can be represented as
`(trunc (xor x, y) to i1)`, and equality as the `not` of that
value. Since InstCombine no longer canonicalizes these patterns
to a masked icmp, the eq-of-parts fold has to handle them directly.

Proofs: https://alive2.llvm.org/ce/z/qidkzq

Fixes https://github.com/llvm/llvm-project/issues/110919.
---
 .../InstCombine/InstCombineAndOrXor.cpp       | 23 ++++++++++++++-----
 .../InstCombine/InstCombineInternal.h         |  2 +-
 .../Transforms/InstCombine/eq-of-parts.ll     | 11 ++-------
 3 files changed, 20 insertions(+), 16 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index c6f14018a750f5..659b8b86699fa9 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -1185,14 +1185,25 @@ static Value *extractIntPart(const IntPart &P, IRBuilderBase &Builder) {
 /// (icmp eq X0, Y0) & (icmp eq X1, Y1) -> icmp eq X01, Y01
 /// (icmp ne X0, Y0) | (icmp ne X1, Y1) -> icmp ne X01, Y01
 /// where X0, X1 and Y0, Y1 are adjacent parts extracted from an integer.
-Value *InstCombinerImpl::foldEqOfParts(ICmpInst *Cmp0, ICmpInst *Cmp1,
-                                       bool IsAnd) {
+Value *InstCombinerImpl::foldEqOfParts(Value *Cmp0, Value *Cmp1, bool IsAnd) {
   if (!Cmp0->hasOneUse() || !Cmp1->hasOneUse())
     return nullptr;
 
   CmpInst::Predicate Pred = IsAnd ? CmpInst::ICMP_EQ : CmpInst::ICMP_NE;
-  auto GetMatchPart = [&](ICmpInst *Cmp,
+  auto GetMatchPart = [&](Value *CmpV,
                           unsigned OpNo) -> std::optional<IntPart> {
+    Value *X, *Y;
+    // icmp ne (and x, 1), (and y, 1) <=> trunc (xor x, y) to i1
+    // icmp eq (and x, 1), (and y, 1) <=> not (trunc (xor x, y) to i1)
+    if (Pred == CmpInst::ICMP_NE
+            ? match(CmpV, m_Trunc(m_Xor(m_Value(X), m_Value(Y))))
+            : match(CmpV, m_Not(m_Trunc(m_Xor(m_Value(X), m_Value(Y))))))
+      return {{OpNo == 0 ? X : Y, 0, 1}};
+
+    auto *Cmp = dyn_cast<ICmpInst>(CmpV);
+    if (!Cmp)
+      return std::nullopt;
+
     if (Pred == Cmp->getPredicate())
       return matchIntPart(Cmp->getOperand(OpNo));
 
@@ -3404,9 +3415,6 @@ Value *InstCombinerImpl::foldAndOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
       return X;
   }
 
-  if (Value *X = foldEqOfParts(LHS, RHS, IsAnd))
-    return X;
-
   // (icmp ne A, 0) | (icmp ne B, 0) --> (icmp ne (A|B), 0)
   // (icmp eq A, 0) & (icmp eq B, 0) --> (icmp eq (A|B), 0)
   // TODO: Remove this and below when foldLogOpOfMaskedICmps can handle undefs.
@@ -3529,6 +3537,9 @@ Value *InstCombinerImpl::foldBooleanAndOr(Value *LHS, Value *RHS,
       if (Value *Res = foldLogicOfFCmps(LHSCmp, RHSCmp, IsAnd, IsLogical))
         return Res;
 
+  if (Value *Res = foldEqOfParts(LHS, RHS, IsAnd))
+    return Res;
+
   return nullptr;
 }
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 9588930d7658c4..0508ed48fc19c4 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -412,7 +412,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
                           bool IsAnd, bool IsLogical = false);
   Value *foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS, BinaryOperator &Xor);
 
-  Value *foldEqOfParts(ICmpInst *Cmp0, ICmpInst *Cmp1, bool IsAnd);
+  Value *foldEqOfParts(Value *Cmp0, Value *Cmp1, bool IsAnd);
 
   Value *foldAndOrOfICmpsUsingRanges(ICmpInst *ICmp1, ICmpInst *ICmp2,
                                      bool IsAnd);
diff --git a/llvm/test/Transforms/InstCombine/eq-of-parts.ll b/llvm/test/Transforms/InstCombine/eq-of-parts.ll
index 9494dd6bf8e5b5..d07c2e6a5be521 100644
--- a/llvm/test/Transforms/InstCombine/eq-of-parts.ll
+++ b/llvm/test/Transforms/InstCombine/eq-of-parts.ll
@@ -1441,11 +1441,7 @@ define i1 @ne_optimized_highbits_cmp_todo_overlapping(i32 %x, i32 %y) {
 
 define i1 @and_trunc_i1(i8 %a1, i8 %a2) {
 ; CHECK-LABEL: @and_trunc_i1(
-; CHECK-NEXT:    [[XOR:%.*]] = xor i8 [[A1:%.*]], [[A2:%.*]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i8 [[XOR]], 2
-; CHECK-NEXT:    [[LOBIT:%.*]] = trunc i8 [[XOR]] to i1
-; CHECK-NEXT:    [[LOBIT_INV:%.*]] = xor i1 [[LOBIT]], true
-; CHECK-NEXT:    [[AND:%.*]] = and i1 [[CMP]], [[LOBIT_INV]]
+; CHECK-NEXT:    [[AND:%.*]] = icmp eq i8 [[A1:%.*]], [[A2:%.*]]
 ; CHECK-NEXT:    ret i1 [[AND]]
 ;
   %xor = xor i8 %a1, %a2
@@ -1494,10 +1490,7 @@ define i1 @and_trunc_i1_wrong_operands(i8 %a1, i8 %a2, i8 %a3) {
 
 define i1 @or_trunc_i1(i64 %a1, i64 %a2) {
 ; CHECK-LABEL: @or_trunc_i1(
-; CHECK-NEXT:    [[XOR:%.*]] = xor i64 [[A2:%.*]], [[A1:%.*]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ugt i64 [[XOR]], 1
-; CHECK-NEXT:    [[TRUNC:%.*]] = trunc i64 [[XOR]] to i1
-; CHECK-NEXT:    [[OR:%.*]] = or i1 [[CMP]], [[TRUNC]]
+; CHECK-NEXT:    [[OR:%.*]] = icmp ne i64 [[A2:%.*]], [[A1:%.*]]
 ; CHECK-NEXT:    ret i1 [[OR]]
 ;
   %xor = xor i64 %a2, %a1

From 7117fbbec6abb882c415c5570c135b98eb9fe8bd Mon Sep 17 00:00:00 2001
From: Nikita Popov <npopov at redhat.com>
Date: Mon, 25 Nov 2024 10:35:50 +0100
Subject: [PATCH 2/2] [InstCombine] Assert that foldEqOfParts operands are i1

---
 llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 659b8b86699fa9..b4033fc2a418a1 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -1192,6 +1192,8 @@ Value *InstCombinerImpl::foldEqOfParts(Value *Cmp0, Value *Cmp1, bool IsAnd) {
   CmpInst::Predicate Pred = IsAnd ? CmpInst::ICMP_EQ : CmpInst::ICMP_NE;
   auto GetMatchPart = [&](Value *CmpV,
                           unsigned OpNo) -> std::optional<IntPart> {
+    assert(CmpV->getType()->isIntOrIntVectorTy(1) && "Must be bool");
+
     Value *X, *Y;
     // icmp ne (and x, 1), (and y, 1) <=> trunc (xor x, y) to i1
     // icmp eq (and x, 1), (and y, 1) <=> not (trunc (xor x, y) to i1)



More information about the llvm-commits mailing list