[llvm] [InstCombine] Modify `foldSelectICmpEq` to only handle more useful and simple cases. (PR #121672)

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Sun Jan 5 01:42:50 PST 2025


================
@@ -1835,96 +1839,79 @@ static Instruction *foldSelectICmpEq(SelectInst &SI, ICmpInst *ICI,
   if (Pred == ICmpInst::ICMP_NE)
     std::swap(TrueVal, FalseVal);
 
-  if (Instruction *Res =
-          foldSelectWithExtremeEqCond(CmpLHS, CmpRHS, TrueVal, FalseVal))
-    return Res;
+  if (auto *R = foldSelectWithExtremeEqCond(CmpLHS, CmpRHS, TrueVal, FalseVal))
+    return R;
 
-  // Transform (X == C) ? X : Y -> (X == C) ? C : Y
-  // specific handling for Bitwise operation.
-  // x&y -> (x|y) ^ (x^y)  or  (x|y) & ~(x^y)
-  // x|y -> (x&y) | (x^y)  or  (x&y) ^  (x^y)
-  // x^y -> (x|y) ^ (x&y)  or  (x|y) & ~(x&y)
   Value *X, *Y;
-  if (!match(CmpLHS, m_BitwiseLogic(m_Value(X), m_Value(Y))) ||
-      !match(TrueVal, m_c_BitwiseLogic(m_Specific(X), m_Specific(Y))))
-    return nullptr;
-
-  const unsigned AndOps = Instruction::And, OrOps = Instruction::Or,
-                 XorOps = Instruction::Xor, NoOps = 0;
-  enum NotMask { None = 0, NotInner, NotRHS };
-
-  auto matchFalseVal = [&](unsigned OuterOpc, unsigned InnerOpc,
-                           unsigned NotMask) {
-    auto matchInner = m_c_BinOp(InnerOpc, m_Specific(X), m_Specific(Y));
-    if (OuterOpc == NoOps)
-      return match(CmpRHS, m_Zero()) && match(FalseVal, matchInner);
-
-    if (NotMask == NotInner) {
-      return match(FalseVal, m_c_BinOp(OuterOpc, m_NotForbidPoison(matchInner),
-                                       m_Specific(CmpRHS)));
-    } else if (NotMask == NotRHS) {
-      return match(FalseVal, m_c_BinOp(OuterOpc, matchInner,
-                                       m_NotForbidPoison(m_Specific(CmpRHS))));
-    } else {
-      return match(FalseVal,
-                   m_c_BinOp(OuterOpc, matchInner, m_Specific(CmpRHS)));
-    }
-  };
-
-  // (X&Y)==C ? X|Y : X^Y -> (X^Y)|C : X^Y  or (X^Y)^ C : X^Y
-  // (X&Y)==C ? X^Y : X|Y -> (X|Y)^C : X|Y  or (X|Y)&~C : X|Y
-  if (match(CmpLHS, m_And(m_Value(X), m_Value(Y)))) {
-    if (match(TrueVal, m_c_Or(m_Specific(X), m_Specific(Y)))) {
-      // (X&Y)==C ? X|Y : (X^Y)|C -> (X^Y)|C : (X^Y)|C -> (X^Y)|C
-      // (X&Y)==C ? X|Y : (X^Y)^C -> (X^Y)^C : (X^Y)^C -> (X^Y)^C
-      if (matchFalseVal(OrOps, XorOps, None) ||
-          matchFalseVal(XorOps, XorOps, None))
-        return IC.replaceInstUsesWith(SI, FalseVal);
-    } else if (match(TrueVal, m_c_Xor(m_Specific(X), m_Specific(Y)))) {
-      // (X&Y)==C ? X^Y : (X|Y)^ C -> (X|Y)^ C : (X|Y)^ C -> (X|Y)^ C
-      // (X&Y)==C ? X^Y : (X|Y)&~C -> (X|Y)&~C : (X|Y)&~C -> (X|Y)&~C
-      if (matchFalseVal(XorOps, OrOps, None) ||
-          matchFalseVal(AndOps, OrOps, NotRHS))
+  if (match(CmpRHS, m_Zero())) {
+    // (X & Y) == 0 ? X |/^/+ Y : X |/^/+ Y -> X |/^/+ Y (false arm)
+    // `(X & Y) == 0` implies no common bits which means:
+	// `X ^ Y == X | Y == X + Y`
+    // https://alive2.llvm.org/ce/z/jjcduh
+    if (match(CmpLHS, m_And(m_Value(X), m_Value(Y)))) {
+      auto MatchAddOrXor =
+          m_CombineOr(m_c_Add(m_Specific(X), m_Specific(Y)),
+                      m_CombineOr(m_c_Or(m_Specific(X), m_Specific(Y)),
+                                  m_c_Xor(m_Specific(X), m_Specific(Y))));
+      if (match(TrueVal, MatchAddOrXor) && match(FalseVal, MatchAddOrXor))
         return IC.replaceInstUsesWith(SI, FalseVal);
     }
-  }
 
-  // (X|Y)==C ? X&Y : X^Y -> (X^Y)^C : X^Y  or  ~(X^Y)&C : X^Y
-  // (X|Y)==C ? X^Y : X&Y -> (X&Y)^C : X&Y  or  ~(X&Y)&C : X&Y
-  if (match(CmpLHS, m_Or(m_Value(X), m_Value(Y)))) {
-    if (match(TrueVal, m_c_And(m_Specific(X), m_Specific(Y)))) {
-      // (X|Y)==C ? X&Y: (X^Y)^C -> (X^Y)^C: (X^Y)^C ->  (X^Y)^C
-      // (X|Y)==C ? X&Y:~(X^Y)&C ->~(X^Y)&C:~(X^Y)&C -> ~(X^Y)&C
-      if (matchFalseVal(XorOps, XorOps, None) ||
-          matchFalseVal(AndOps, XorOps, NotInner))
-        return IC.replaceInstUsesWith(SI, FalseVal);
-    } else if (match(TrueVal, m_c_Xor(m_Specific(X), m_Specific(Y)))) {
-      // (X|Y)==C ? X^Y : (X&Y)^C ->  (X&Y)^C : (X&Y)^C ->  (X&Y)^C
-      // (X|Y)==C ? X^Y :~(X&Y)&C -> ~(X&Y)&C :~(X&Y)&C -> ~(X&Y)&C
-      if (matchFalseVal(XorOps, AndOps, None) ||
-          matchFalseVal(AndOps, AndOps, NotInner))
-        return IC.replaceInstUsesWith(SI, FalseVal);
-    }
-  }
+    // (X | Y) == 0 ? X Op0 Y : X Op1 Y -> X Op1 Y
+    // For any `Op0` and `Op1` that are zero when `X` and `Y` are zero.
+	// https://alive2.llvm.org/ce/z/azHzBW
+    if (match(CmpLHS, m_Or(m_Value(X), m_Value(Y))) &&
+        (match(TrueVal, m_c_BinOp(m_Specific(X), m_Specific(Y))) ||
+		 // In true arm we can also accept just `0`.
+         match(TrueVal, m_Zero())) &&
+        match(FalseVal, m_c_BinOp(m_Specific(X), m_Specific(Y)))) {
+      auto IsOpcZeroWithZeros = [](Value *V) {
+        auto *I = dyn_cast<Instruction>(V);
+        if (!I)
+          return false;
+        switch (I->getOpcode()) {
+        case Instruction::And:
+        case Instruction::Or:
+        case Instruction::Xor:
+        case Instruction::Mul:
+        case Instruction::Add:
+        case Instruction::Sub:
+        case Instruction::Shl:
+        case Instruction::AShr:
+        case Instruction::LShr:
+          return true;
+        default:
+          return false;
+        }
+      };
 
-  // (X^Y)==C ? X&Y : X|Y -> (X|Y)^C : X|Y  or (X|Y)&~C : X|Y
-  // (X^Y)==C ? X|Y : X&Y -> (X&Y)|C : X&Y  or (X&Y)^ C : X&Y
-  if (match(CmpLHS, m_Xor(m_Value(X), m_Value(Y)))) {
-    if ((match(TrueVal, m_c_And(m_Specific(X), m_Specific(Y))))) {
-      // (X^Y)==C ? X&Y : (X|Y)^C -> (X|Y)^C
-      // (X^Y)==C ? X&Y : (X|Y)&~C -> (X|Y)&~C
-      if (matchFalseVal(XorOps, OrOps, None) ||
-          matchFalseVal(AndOps, OrOps, NotRHS))
-        return IC.replaceInstUsesWith(SI, FalseVal);
-    } else if (match(TrueVal, m_c_Or(m_Specific(X), m_Specific(Y)))) {
-      // (X^Y)==C ? (X|Y) : (X&Y)|C -> (X&Y)|C
-      // (X^Y)==C ? (X|Y) : (X&Y)^C -> (X&Y)^C
-      if (matchFalseVal(OrOps, AndOps, None) ||
-          matchFalseVal(XorOps, AndOps, None))
+      if ((match(TrueVal, m_Zero()) || IsOpcZeroWithZeros(TrueVal)) &&
+          IsOpcZeroWithZeros(FalseVal))
         return IC.replaceInstUsesWith(SI, FalseVal);
     }
   }
+  // (X == Y) ? X | Y : X & Y
+  // (X == Y) ? X & Y : X | Y
+  // If `X == Y` then `X == Y == X | Y == X & Y`.
+  // NB: `X == Y` is canonicalization of `(X ^ Y) == 0`.
+  // https://alive2.llvm.org/ce/z/SJskbz
+  X = CmpLHS;
+  Y = CmpRHS;
+  auto MatchOrAnd = m_CombineOr(m_c_Or(m_Specific(X), m_Specific(Y)),
+                                m_c_And(m_Specific(X), m_Specific(Y)));
+  if (match(FalseVal, MatchOrAnd) &&
+      // In the true arm we can also just match `X` or `Y`.
+      (match(TrueVal, MatchOrAnd) || match(TrueVal, m_Specific(X)) ||
+       match(TrueVal, m_Specific(Y)))) {
+    // Can't preserve `or disjoint` here so rebuild.
+    auto *BO = dyn_cast<BinaryOperator>(FalseVal);
+    if (!BO)
+      return nullptr;
 
+    return IC.replaceInstUsesWith(
+        SI, IC.Builder.CreateBinOp(BO->getOpcode(), BO->getOperand(0),
+                                   BO->getOperand(1)));
+  }
----------------
nikic wrote:

Can we handle this case by changing https://github.com/llvm/llvm-project/blob/3321c2d72ab7757dbdd38bdd99a76d89293dac8a/llvm/lib/Analysis/InstructionSimplify.cpp#L4596-L4612 to compare both simplified values instead? Currently we simplify one value and compare it against the other, unsimplified value. If we compared both simplified values, I think we should be able to handle this pattern without any specialized code.

https://github.com/llvm/llvm-project/pull/121672


More information about the llvm-commits mailing list