[llvm] [InstCombine] Fold Xor with or disjoint (PR #105992)

Amr Hesham via llvm-commits llvm-commits at lists.llvm.org
Mon Sep 2 10:29:44 PDT 2024


https://github.com/AmrDeveloper updated https://github.com/llvm/llvm-project/pull/105992

>From 77b75e66c1abdd43009c239a04c9c01e00737e26 Mon Sep 17 00:00:00 2001
From: AmrDeveloper <amr96 at programmer.net>
Date: Mon, 2 Sep 2024 19:23:18 +0200
Subject: [PATCH 1/2] [InstCombine] Add pre-commit tests

---
 llvm/test/Transforms/InstCombine/xor.ll | 176 ++++++++++++++++++++++++
 1 file changed, 176 insertions(+)

diff --git a/llvm/test/Transforms/InstCombine/xor.ll b/llvm/test/Transforms/InstCombine/xor.ll
index ea7f7382ee7c8e..1a6f5ae40c463c 100644
--- a/llvm/test/Transforms/InstCombine/xor.ll
+++ b/llvm/test/Transforms/InstCombine/xor.ll
@@ -1485,3 +1485,179 @@ define i4 @PR96857_xor_without_noundef(i4  %val0, i4  %val1, i4 %val2) {
   %val7 = xor i4 %val4, %val6
   ret i4 %val7
 }
+
+define i32 @or_disjoint_with_xor(i32 %a, i32 %b) {
+; CHECK-LABEL: @or_disjoint_with_xor(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = xor i32 [[A:%.*]], -1
+; CHECK-NEXT:    [[XOR:%.*]] = and i32 [[B:%.*]], [[TMP0]]
+; CHECK-NEXT:    ret i32 [[XOR]]
+;
+entry:
+  %or = or disjoint i32 %a, %b
+  %xor = xor i32 %or, %a ; (A | B) ^ A with the disjoint or on the left
+  ret i32 %xor
+}
+
+define i32 @xor_with_or_disjoint(i32 %a, i32 %b, i32 %c) {
+; CHECK-LABEL: @xor_with_or_disjoint(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = xor i32 [[A:%.*]], -1
+; CHECK-NEXT:    [[XOR:%.*]] = and i32 [[B:%.*]], [[TMP0]]
+; CHECK-NEXT:    ret i32 [[XOR]]
+;
+entry:
+  %or = or disjoint i32 %a, %b
+  %xor = xor i32 %a, %or ; commuted form A ^ (A | B). NOTE(review): %c is unused
+  ret i32 %xor
+}
+
+define <2 x i32> @or_disjoint_with_xor_vec(<2 x i32> %a, < 2 x i32> %b, <2 x i32> %c) {
+; CHECK-LABEL: @or_disjoint_with_xor_vec(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = xor <2 x i32> [[A:%.*]], <i32 -1, i32 -1>
+; CHECK-NEXT:    [[XOR:%.*]] = and <2 x i32> [[B:%.*]], [[TMP0]]
+; CHECK-NEXT:    ret <2 x i32> [[XOR]]
+;
+entry:
+  %or = or disjoint <2 x i32> %a, %b
+  %xor = xor <2 x i32> %or, %a ; vector (A | B) ^ A. NOTE(review): %c is unused
+  ret <2 x i32> %xor
+}
+
+define <2 x i32> @xor_with_or_disjoint_vec(<2 x i32> %a, < 2 x i32> %b, <2 x i32> %c) {
+; CHECK-LABEL: @xor_with_or_disjoint_vec(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = xor <2 x i32> [[A:%.*]], <i32 -1, i32 -1>
+; CHECK-NEXT:    [[XOR:%.*]] = and <2 x i32> [[B:%.*]], [[TMP0]]
+; CHECK-NEXT:    ret <2 x i32> [[XOR]]
+;
+entry:
+  %or = or disjoint <2 x i32> %a, %b
+  %xor = xor <2 x i32> %a, %or ; vector A ^ (A | B). NOTE(review): %c is unused
+  ret <2 x i32> %xor
+}
+
+define i32 @select_or_disjoint_xor(i32 %a, i1 %c) {
+; CHECK-LABEL: @select_or_disjoint_xor(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[S:%.*]] = select i1 [[C:%.*]], i32 0, i32 4
+; CHECK-NEXT:    [[SHL:%.*]] = shl i32 [[A:%.*]], 4
+; CHECK-NEXT:    [[OR:%.*]] = or disjoint i32 [[S]], [[SHL]]
+; CHECK-NEXT:    [[XOR:%.*]] = xor i32 [[OR]], 4
+; CHECK-NEXT:    ret i32 [[XOR]]
+;
+entry:
+  %s = select i1 %c, i32 0, i32 4
+  %shl = shl i32 %a, 4
+  %or = or disjoint i32 %s, %shl ; %s is 0 or 4 and %shl has bits 0-3 clear
+  %xor = xor i32 %or, 4
+  ret i32 %xor
+}
+
+define <2 x i32> @select_or_disjoint_xor_vec(<2 x i32> %a, i1 %c) {
+; CHECK-LABEL: @select_or_disjoint_xor_vec(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[S:%.*]] = select i1 [[C:%.*]], <2 x i32> zeroinitializer, <2 x i32> <i32 4, i32 4>
+; CHECK-NEXT:    [[SHL:%.*]] = shl <2 x i32> [[A:%.*]], <i32 4, i32 4>
+; CHECK-NEXT:    [[OR:%.*]] = or disjoint <2 x i32> [[S]], [[SHL]]
+; CHECK-NEXT:    [[XOR:%.*]] = xor <2 x i32> [[OR]], <i32 4, i32 4>
+; CHECK-NEXT:    ret <2 x i32> [[XOR]]
+;
+entry:
+  %s = select i1 %c, <2 x i32> <i32 0, i32 0>, <2 x i32> <i32 4, i32 4>
+  %shl = shl <2 x i32> %a, <i32 4, i32 4>
+  %or = or disjoint <2 x i32> %s, %shl ; %s lanes are 0 or 4, %shl lanes have bits 0-3 clear
+  %xor = xor <2 x i32> %or, <i32 4, i32 4>
+  ret <2 x i32> %xor
+}
+
+define i32 @select_or_disjoint_add(i32 %a, i1 %c) {
+; CHECK-LABEL: @select_or_disjoint_add(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[S:%.*]] = select i1 [[C:%.*]], i32 0, i32 4
+; CHECK-NEXT:    [[SHL:%.*]] = shl i32 [[A:%.*]], 4
+; CHECK-NEXT:    [[OR:%.*]] = or disjoint i32 [[S]], [[SHL]]
+; CHECK-NEXT:    [[ADD:%.*]] = add nuw nsw i32 [[OR]], 4
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+entry:
+  %s = select i1 %c, i32 0, i32 4
+  %shl = shl i32 %a, 4
+  %or = or disjoint i32 %s, %shl ; %s is 0 or 4 and %shl has bits 0-3 clear
+  %add = add i32 %or, 4
+  ret i32 %add
+}
+
+define <2 x i32> @select_or_disjoint_add_vec(<2 x i32> %a, i1 %c) {
+; CHECK-LABEL: @select_or_disjoint_add_vec(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[S:%.*]] = select i1 [[C:%.*]], <2 x i32> zeroinitializer, <2 x i32> <i32 4, i32 4>
+; CHECK-NEXT:    [[SHL:%.*]] = shl <2 x i32> [[A:%.*]], <i32 4, i32 4>
+; CHECK-NEXT:    [[OR:%.*]] = or disjoint <2 x i32> [[S]], [[SHL]]
+; CHECK-NEXT:    [[ADD:%.*]] = add nuw nsw <2 x i32> [[OR]], <i32 4, i32 4>
+; CHECK-NEXT:    ret <2 x i32> [[ADD]]
+;
+entry:
+  %s = select i1 %c, <2 x i32> <i32 0, i32 0>, <2 x i32> <i32 4, i32 4>
+  %shl = shl <2 x i32> %a, <i32 4, i32 4>
+  %or = or disjoint <2 x i32> %s, %shl ; %s lanes are 0 or 4, %shl lanes have bits 0-3 clear
+  %add = add <2 x i32> %or, <i32 4, i32 4>
+  ret <2 x i32> %add
+}
+
+define i32 @or_multi_use_disjoint_with_xor(i32 %a, i32 %b, i32 %c) {
+; CHECK-LABEL: @or_multi_use_disjoint_with_xor(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[OR:%.*]] = or disjoint i32 [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    [[XOR:%.*]] = xor i32 [[OR]], [[C:%.*]]
+; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[OR]], [[XOR]]
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+entry:
+  %or = or disjoint i32 %a, %b
+  %xor = xor i32 %or, %c ; NOTE(review): %a ^ %c / %b ^ %c cannot simplify, so this does not isolate the one-use limit
+  %add = add i32 %or, %xor ; second use of %or
+  ret i32 %add
+}
+
+define <2 x i32> @or_multi_use_disjoint_with_xor_vec(<2 x i32> %a, <2 x i32> %b, <2 x i32> %c) {
+; CHECK-LABEL: @or_multi_use_disjoint_with_xor_vec(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[OR:%.*]] = or disjoint <2 x i32> [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    [[XOR:%.*]] = xor <2 x i32> [[OR]], [[C:%.*]]
+; CHECK-NEXT:    [[ADD:%.*]] = add <2 x i32> [[OR]], [[XOR]]
+; CHECK-NEXT:    ret <2 x i32> [[ADD]]
+;
+entry:
+  %or = or disjoint <2 x i32> %a, %b
+  %xor = xor <2 x i32> %or, %c ; NOTE(review): %a ^ %c / %b ^ %c cannot simplify, so this does not isolate the one-use limit
+  %add = add <2 x i32> %or, %xor ; second use of %or
+  ret <2 x i32> %add
+}
+
+define i32 @add_with_or(i32 %a, i32 %b, i32 %c) {
+; CHECK-LABEL: @add_with_or(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    [[OR:%.*]] = or i32 [[ADD]], [[C:%.*]]
+; CHECK-NEXT:    ret i32 [[OR]]
+;
+entry:
+  %add = add i32 %a, %b
+  %or = or i32 %add, %c ; NOTE(review): no xor or disjoint or here; confirm what this test guards
+  ret i32 %or
+}
+
+define <2 x i32> @add_with_or_vec(<2 x i32> %a, <2 x i32> %b, <2 x i32> %c) {
+; CHECK-LABEL: @add_with_or_vec(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[ADD:%.*]] = add <2 x i32> [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    [[OR:%.*]] = or <2 x i32> [[ADD]], [[C:%.*]]
+; CHECK-NEXT:    ret <2 x i32> [[OR]]
+;
+entry:
+  %add = add <2 x i32> %a, %b
+  %or = or <2 x i32> %add, %c ; NOTE(review): no xor or disjoint or here; confirm what this test guards
+  ret <2 x i32> %or
+}

>From f89cd57b46832fbee576d1189c718275cc0ecf7c Mon Sep 17 00:00:00 2001
From: AmrDeveloper <amr96 at programmer.net>
Date: Mon, 2 Sep 2024 19:29:00 +0200
Subject: [PATCH 2/2] [InstCombine] Fold Xor with or disjoint

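Implement a missing optimization to fold (A | B) ^ C to (A ^ C) ^ B when the
`or` is `disjoint` (A and B share no set bits) and A ^ C simplifies; the
symmetric form (B ^ C) ^ A is used when B ^ C simplifies instead.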
---
 .../InstCombine/InstCombineAndOrXor.cpp       | 114 ++++++++++--------
 llvm/test/Transforms/InstCombine/xor.ll       |  16 +--
 2 files changed, 70 insertions(+), 60 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 4f557532f9f783..f58ffee1d2105d 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -101,16 +101,16 @@ Value *InstCombinerImpl::insertRangeTest(Value *V, const APInt &Lo,
 ///    (icmp eq (A & B), A) equals (icmp ne (A & B), 0)
 ///    (icmp ne (A & B), A) equals (icmp eq (A & B), 0)
 enum MaskedICmpType {
-  AMask_AllOnes           =     1,
-  AMask_NotAllOnes        =     2,
-  BMask_AllOnes           =     4,
-  BMask_NotAllOnes        =     8,
-  Mask_AllZeros           =    16,
-  Mask_NotAllZeros        =    32,
-  AMask_Mixed             =    64,
-  AMask_NotMixed          =   128,
-  BMask_Mixed             =   256,
-  BMask_NotMixed          =   512
+  AMask_AllOnes = 1,
+  AMask_NotAllOnes = 2,
+  BMask_AllOnes = 4,
+  BMask_NotAllOnes = 8,
+  Mask_AllZeros = 16,
+  Mask_NotAllZeros = 32,
+  AMask_Mixed = 64,
+  AMask_NotMixed = 128,
+  BMask_Mixed = 256,
+  BMask_NotMixed = 512
 };
 
 /// Return the set of patterns (from MaskedICmpType) that (icmp SCC (A & B), C)
@@ -172,15 +172,16 @@ static unsigned conjugateICmpMask(unsigned Mask) {
             << 1;
 
   NewMask |= (Mask & (AMask_NotAllOnes | BMask_NotAllOnes | Mask_NotAllZeros |
-                      AMask_NotMixed | BMask_NotMixed))
-             >> 1;
+                      AMask_NotMixed | BMask_NotMixed)) >>
+             1;
 
   return NewMask;
 }
 
 // Adapts the external decomposeBitTestICmp for local use.
-static bool decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate &Pred,
-                                 Value *&X, Value *&Y, Value *&Z) {
+static bool decomposeBitTestICmp(Value *LHS, Value *RHS,
+                                 CmpInst::Predicate &Pred, Value *&X, Value *&Y,
+                                 Value *&Z) {
   APInt Mask;
   if (!llvm::decomposeBitTestICmp(LHS, RHS, Pred, X, Mask))
     return false;
@@ -519,9 +520,9 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,
   if (Mask == 0) {
     // Even if the two sides don't share a common pattern, check if folding can
     // still happen.
-    if (Value *V = foldLogOpOfMaskedICmpsAsymmetric(
-            LHS, RHS, IsAnd, A, B, C, D, E, PredL, PredR, LHSMask, RHSMask,
-            Builder))
+    if (Value *V = foldLogOpOfMaskedICmpsAsymmetric(LHS, RHS, IsAnd, A, B, C, D,
+                                                    E, PredL, PredR, LHSMask,
+                                                    RHSMask, Builder))
       return V;
     return nullptr;
   }
@@ -680,16 +681,16 @@ Value *InstCombinerImpl::simplifyRangeCheck(ICmpInst *Cmp0, ICmpInst *Cmp1,
   if (!RangeStart)
     return nullptr;
 
-  ICmpInst::Predicate Pred0 = (Inverted ? Cmp0->getInversePredicate() :
-                               Cmp0->getPredicate());
+  ICmpInst::Predicate Pred0 =
+      (Inverted ? Cmp0->getInversePredicate() : Cmp0->getPredicate());
 
   // Accept x > -1 or x >= 0 (after potentially inverting the predicate).
   if (!((Pred0 == ICmpInst::ICMP_SGT && RangeStart->isMinusOne()) ||
         (Pred0 == ICmpInst::ICMP_SGE && RangeStart->isZero())))
     return nullptr;
 
-  ICmpInst::Predicate Pred1 = (Inverted ? Cmp1->getInversePredicate() :
-                               Cmp1->getPredicate());
+  ICmpInst::Predicate Pred1 =
+      (Inverted ? Cmp1->getInversePredicate() : Cmp1->getPredicate());
 
   Value *Input = Cmp0->getOperand(0);
   Value *RangeEnd;
@@ -707,9 +708,14 @@ Value *InstCombinerImpl::simplifyRangeCheck(ICmpInst *Cmp0, ICmpInst *Cmp1,
   // Check the upper range comparison, e.g. x < n
   ICmpInst::Predicate NewPred;
   switch (Pred1) {
-    case ICmpInst::ICMP_SLT: NewPred = ICmpInst::ICMP_ULT; break;
-    case ICmpInst::ICMP_SLE: NewPred = ICmpInst::ICMP_ULE; break;
-    default: return nullptr;
+  case ICmpInst::ICMP_SLT:
+    NewPred = ICmpInst::ICMP_ULT;
+    break;
+  case ICmpInst::ICMP_SLE:
+    NewPred = ICmpInst::ICMP_ULE;
+    break;
+  default:
+    return nullptr;
   }
 
   // This simplification is only valid if the upper range is not negative.
@@ -785,8 +791,7 @@ Value *InstCombinerImpl::foldAndOrOfICmpsOfAndWithPow2(ICmpInst *LHS,
     if (L2 == R1)
       std::swap(L1, L2);
 
-    if (L1 == R1 &&
-        isKnownToBeAPowerOfTwo(L2, false, 0, CxtI) &&
+    if (L1 == R1 && isKnownToBeAPowerOfTwo(L2, false, 0, CxtI) &&
         isKnownToBeAPowerOfTwo(R2, false, 0, CxtI)) {
       // If this is a logical and/or, then we must prevent propagation of a
       // poison value from the RHS by inserting freeze.
@@ -1636,8 +1641,8 @@ static Instruction *reassociateFCmps(BinaryOperator &BO,
 
   // Match inner binop and the predicate for combining 2 NAN checks into 1.
   Value *BO10, *BO11;
-  FCmpInst::Predicate NanPred = Opcode == Instruction::And ? FCmpInst::FCMP_ORD
-                                                           : FCmpInst::FCMP_UNO;
+  FCmpInst::Predicate NanPred =
+      Opcode == Instruction::And ? FCmpInst::FCMP_ORD : FCmpInst::FCMP_UNO;
   if (!match(Op0, m_SpecificFCmp(NanPred, m_Value(X), m_AnyZeroFP())) ||
       !match(Op1, m_BinOp(Opcode, m_Value(BO10), m_Value(BO11))))
     return nullptr;
@@ -1666,8 +1671,7 @@ static Instruction *reassociateFCmps(BinaryOperator &BO,
 /// Match variations of De Morgan's Laws:
 /// (~A & ~B) == (~(A | B))
 /// (~A | ~B) == (~(A & B))
-static Instruction *matchDeMorgansLaws(BinaryOperator &I,
-                                       InstCombiner &IC) {
+static Instruction *matchDeMorgansLaws(BinaryOperator &I, InstCombiner &IC) {
   const Instruction::BinaryOps Opcode = I.getOpcode();
   assert((Opcode == Instruction::And || Opcode == Instruction::Or) &&
          "Trying to match De Morgan's Laws with something other than and/or");
@@ -1841,10 +1845,10 @@ Instruction *InstCombinerImpl::foldCastedBitwiseLogic(BinaryOperator &I) {
   Value *Cast1Src = Cast1->getOperand(0);
 
   // fold logic(cast(A), cast(B)) -> cast(logic(A, B))
-  if ((Cast0->hasOneUse() || Cast1->hasOneUse()) &&
-      shouldOptimizeCast(Cast0) && shouldOptimizeCast(Cast1)) {
-    Value *NewOp = Builder.CreateBinOp(LogicOpc, Cast0Src, Cast1Src,
-                                       I.getName());
+  if ((Cast0->hasOneUse() || Cast1->hasOneUse()) && shouldOptimizeCast(Cast0) &&
+      shouldOptimizeCast(Cast1)) {
+    Value *NewOp =
+        Builder.CreateBinOp(LogicOpc, Cast0Src, Cast1Src, I.getName());
     return CastInst::Create(CastOpcode, NewOp, DestTy);
   }
 
@@ -2530,7 +2534,7 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
       int Log2ShiftC = ShiftC->exactLogBase2();
       int Log2C = C->exactLogBase2();
       bool IsShiftLeft =
-         cast<BinaryOperator>(Op0)->getOpcode() == Instruction::Shl;
+          cast<BinaryOperator>(Op0)->getOpcode() == Instruction::Shl;
       int BitNum = IsShiftLeft ? Log2C - Log2ShiftC : Log2ShiftC - Log2C;
       assert(BitNum >= 0 && "Expected demanded bits to handle impossible mask");
       Value *Cmp = Builder.CreateICmpEQ(X, ConstantInt::get(Ty, BitNum));
@@ -3475,8 +3479,8 @@ Value *InstCombinerImpl::foldAndOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
       }
     } else {
       if ((TrueIfSignedL && !TrueIfSignedR &&
-            match(LHS0, m_And(m_Value(X), m_Value(Y))) &&
-            match(RHS0, m_c_Or(m_Specific(X), m_Specific(Y)))) ||
+           match(LHS0, m_And(m_Value(X), m_Value(Y))) &&
+           match(RHS0, m_c_Or(m_Specific(X), m_Specific(Y)))) ||
           (!TrueIfSignedL && TrueIfSignedR &&
            match(LHS0, m_Or(m_Value(X), m_Value(Y))) &&
            match(RHS0, m_c_And(m_Specific(X), m_Specific(Y))))) {
@@ -4163,8 +4167,8 @@ Value *InstCombinerImpl::foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS,
         isSignBitCheck(PredL, *LC, TrueIfSignedL) &&
         isSignBitCheck(PredR, *RC, TrueIfSignedR)) {
       Value *XorLR = Builder.CreateXor(LHS0, RHS0);
-      return TrueIfSignedL == TrueIfSignedR ? Builder.CreateIsNeg(XorLR) :
-                                              Builder.CreateIsNotNeg(XorLR);
+      return TrueIfSignedL == TrueIfSignedR ? Builder.CreateIsNeg(XorLR)
+                                            : Builder.CreateIsNotNeg(XorLR);
     }
 
     // Fold (icmp pred1 X, C1) ^ (icmp pred2 X, C2)
@@ -4343,8 +4347,8 @@ static Instruction *canonicalizeAbs(BinaryOperator &Xor,
   Type *Ty = Xor.getType();
   Value *A;
   const APInt *ShAmt;
-  if (match(Op1, m_AShr(m_Value(A), m_APInt(ShAmt))) &&
-      Op1->hasNUses(2) && *ShAmt == Ty->getScalarSizeInBits() - 1 &&
+  if (match(Op1, m_AShr(m_Value(A), m_APInt(ShAmt))) && Op1->hasNUses(2) &&
+      *ShAmt == Ty->getScalarSizeInBits() - 1 &&
       match(Op0, m_OneUse(m_c_Add(m_Specific(A), m_Specific(Op1))))) {
     // Op1 = ashr i32 A, 31   ; smear the sign bit
     // xor (add A, Op1), Op1  ; add -1 and flip bits if negative
@@ -4580,7 +4584,8 @@ Instruction *InstCombinerImpl::foldNot(BinaryOperator &I) {
 
   // Move a 'not' ahead of casts of a bool to enable logic reduction:
   // not (bitcast (sext i1 X)) --> bitcast (sext (not i1 X))
-  if (match(NotOp, m_OneUse(m_BitCast(m_OneUse(m_SExt(m_Value(X)))))) && X->getType()->isIntOrIntVectorTy(1)) {
+  if (match(NotOp, m_OneUse(m_BitCast(m_OneUse(m_SExt(m_Value(X)))))) &&
+      X->getType()->isIntOrIntVectorTy(1)) {
     Type *SextTy = cast<BitCastOperator>(NotOp)->getSrcTy();
     Value *NotX = Builder.CreateNot(X);
     Value *Sext = Builder.CreateSExt(NotX, SextTy);
@@ -4693,7 +4698,21 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
   // calls in there are unnecessary as SimplifyDemandedInstructionBits should
   // have already taken care of those cases.
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
-  Value *M;
+  Value *X, *Y, *M;
+
+  // An `or disjoint` behaves like an `xor`, so (X | Y) ^ M can be
+  // reassociated: fold to (X ^ M) ^ Y or (Y ^ M) ^ X when the inner xor
+  // simplifies (e.g. (X | Y) ^ X --> Y).
+  if (match(&I, m_c_Xor(m_OneUse(m_c_DisjointOr(m_Value(X), m_Value(Y))),
+                        m_Value(M)))) {
+    const SimplifyQuery Q = SQ.getWithInstruction(&I);
+    if (Value *XorXM = simplifyBinOp(Instruction::Xor, X, M, Q))
+      return BinaryOperator::CreateXor(XorXM, Y);
+
+    if (Value *XorYM = simplifyBinOp(Instruction::Xor, Y, M, Q))
+      return BinaryOperator::CreateXor(XorYM, X);
+  }
+
   if (match(&I, m_c_Xor(m_c_And(m_Not(m_Value(M)), m_Value()),
                         m_c_And(m_Deferred(M), m_Value())))) {
     if (isGuaranteedNotToBeUndef(M))
@@ -4705,7 +4724,6 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
   if (Instruction *Xor = visitMaskedMerge(I, Builder))
     return Xor;
 
-  Value *X, *Y;
   Constant *C1;
   if (match(Op1, m_Constant(C1))) {
     Constant *C2;
@@ -4870,14 +4888,14 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
   // (A ^ B) ^ (A | C) --> (~A & C) ^ B -- There are 4 commuted variants.
   if (match(&I, m_c_Xor(m_OneUse(m_Xor(m_Value(A), m_Value(B))),
                         m_OneUse(m_c_Or(m_Deferred(A), m_Value(C))))))
-      return BinaryOperator::CreateXor(
-          Builder.CreateAnd(Builder.CreateNot(A), C), B);
+    return BinaryOperator::CreateXor(Builder.CreateAnd(Builder.CreateNot(A), C),
+                                     B);
 
   // (A ^ B) ^ (B | C) --> (~B & C) ^ A -- There are 4 commuted variants.
   if (match(&I, m_c_Xor(m_OneUse(m_Xor(m_Value(A), m_Value(B))),
                         m_OneUse(m_c_Or(m_Deferred(B), m_Value(C))))))
-      return BinaryOperator::CreateXor(
-          Builder.CreateAnd(Builder.CreateNot(B), C), A);
+    return BinaryOperator::CreateXor(Builder.CreateAnd(Builder.CreateNot(B), C),
+                                     A);
 
   // (A & B) ^ (A ^ B) -> (A | B)
   if (match(Op0, m_And(m_Value(A), m_Value(B))) &&
diff --git a/llvm/test/Transforms/InstCombine/xor.ll b/llvm/test/Transforms/InstCombine/xor.ll
index 1a6f5ae40c463c..9146fb6cac8102 100644
--- a/llvm/test/Transforms/InstCombine/xor.ll
+++ b/llvm/test/Transforms/InstCombine/xor.ll
@@ -1489,9 +1489,7 @@ define i4 @PR96857_xor_without_noundef(i4  %val0, i4  %val1, i4 %val2) {
 define i32 @or_disjoint_with_xor(i32 %a, i32 %b) {
 ; CHECK-LABEL: @or_disjoint_with_xor(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[TMP0:%.*]] = xor i32 [[A:%.*]], -1
-; CHECK-NEXT:    [[XOR:%.*]] = and i32 [[B:%.*]], [[TMP0]]
-; CHECK-NEXT:    ret i32 [[XOR]]
+; CHECK-NEXT:    ret i32 [[B:%.*]]
 ;
 entry:
   %or = or disjoint i32 %a, %b
@@ -1502,9 +1500,7 @@ entry:
 define i32 @xor_with_or_disjoint(i32 %a, i32 %b, i32 %c) {
 ; CHECK-LABEL: @xor_with_or_disjoint(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[TMP0:%.*]] = xor i32 [[A:%.*]], -1
-; CHECK-NEXT:    [[XOR:%.*]] = and i32 [[B:%.*]], [[TMP0]]
-; CHECK-NEXT:    ret i32 [[XOR]]
+; CHECK-NEXT:    ret i32 [[B:%.*]]
 ;
 entry:
   %or = or disjoint i32 %a, %b
@@ -1515,9 +1511,7 @@ entry:
 define <2 x i32> @or_disjoint_with_xor_vec(<2 x i32> %a, < 2 x i32> %b, <2 x i32> %c) {
 ; CHECK-LABEL: @or_disjoint_with_xor_vec(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[TMP0:%.*]] = xor <2 x i32> [[A:%.*]], <i32 -1, i32 -1>
-; CHECK-NEXT:    [[XOR:%.*]] = and <2 x i32> [[B:%.*]], [[TMP0]]
-; CHECK-NEXT:    ret <2 x i32> [[XOR]]
+; CHECK-NEXT:    ret <2 x i32> [[B:%.*]]
 ;
 entry:
   %or = or disjoint <2 x i32> %a, %b
@@ -1528,9 +1522,7 @@ entry:
 define <2 x i32> @xor_with_or_disjoint_vec(<2 x i32> %a, < 2 x i32> %b, <2 x i32> %c) {
 ; CHECK-LABEL: @xor_with_or_disjoint_vec(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[TMP0:%.*]] = xor <2 x i32> [[A:%.*]], <i32 -1, i32 -1>
-; CHECK-NEXT:    [[XOR:%.*]] = and <2 x i32> [[B:%.*]], [[TMP0]]
-; CHECK-NEXT:    ret <2 x i32> [[XOR]]
+; CHECK-NEXT:    ret <2 x i32> [[B:%.*]]
 ;
 entry:
   %or = or disjoint <2 x i32> %a, %b



More information about the llvm-commits mailing list