[llvm] [InstCombine] Simplify select if it combines and/or/xor (PR #73362)

via llvm-commits llvm-commits at lists.llvm.org
Mon Feb 12 11:42:46 PST 2024


================
@@ -1672,6 +1672,110 @@ static Value *foldSelectInstWithICmpConst(SelectInst &SI, ICmpInst *ICI,
   return nullptr;
 }
 
+static Instruction *foldSelectICmpEq(SelectInst &SI, ICmpInst *ICI,
+                                     InstCombinerImpl &IC) {
+  ICmpInst::Predicate Pred = ICI->getPredicate();
+  if (!ICmpInst::isEquality(Pred))
+    return nullptr;
+
+  Value *TrueVal = SI.getTrueValue();
+  Value *FalseVal = SI.getFalseValue();
+  Value *CmpLHS = ICI->getOperand(0);
+  Value *CmpRHS = ICI->getOperand(1);
+
+  if (Pred == ICmpInst::ICMP_NE) {
+    Pred = ICmpInst::ICMP_EQ;
+    std::swap(TrueVal, FalseVal);
+  }
+
+  if (!isa<BinaryOperator>(CmpLHS) || !isa<BinaryOperator>(TrueVal))
+    return nullptr;
+
+  // Transform (X == C) ? X : Y -> (X == C) ? C : Y
+  // specific handling for and/or/xor bit operation.
+  // https://alive2.llvm.org/ce/z/WW8iRR
+  // x&y -> (x|y) ^ (x^y)
+  // x|y -> (x&y) | (x^y)
+  // x^y -> (x|y) ^ (x&y)
+  Value *X, *Y;
+  Value *AllOnes = Constant::getAllOnesValue(TrueVal->getType());
+  Value *Null = Constant::getNullValue(TrueVal->getType());
+  Instruction *ISI = &cast<Instruction>(SI);
+
+  // https://alive2.llvm.org/ce/z/EzU4sx
+  // (X & Y) == C ? X | Y : X ^ Y  ->  (X & Y) == C ?  C | (X^Y) : X ^ Y
+  // (X & Y) == C ? X ^ Y : X | Y  ->  (X & Y) == C ? ~C & (X|Y) : X | Y
+  // if C == 0, X|Y:X^Y -> 0|X^Y -> X^Y, X^Y:X|Y -> -1&X|Y:X|Y -> X|Y
+  // if C == -1, X|Y:X^Y -> -1:X^Y, X^Y:X|Y -> 0:X|Y
+  // otherwise, X|Y:(X^Y)|C -> (X^Y)|C, X^Y:(X&Y)^C -> (X&Y)^C
+  if (match(CmpLHS, m_And(m_Value(X), m_Value(Y)))) {
+    if (match(TrueVal, m_c_Or(m_Specific(X), m_Specific(Y)))) {
+      if ((match(FalseVal, m_c_Xor(m_Specific(X), m_Specific(Y))) &&
+           match(CmpRHS, m_Zero())) ||
+          (match(FalseVal, m_c_Or(m_c_Xor(m_Specific(X), m_Specific(Y)),
+                                  m_Specific(CmpRHS)))))
----------------
goldsteinn wrote:

I think a helper function would be useful here and below.

I.e., instead of
```
(match(FalseVal, m_c_Xor(m_Specific(X), m_Specific(Y))) &&
           match(CmpRHS, m_Zero())) ||
          (match(FalseVal, m_c_Or(m_c_Xor(m_Specific(X), m_Specific(Y)),
                                  m_Specific(CmpRHS))))
```

you could do:

```
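auto matchFalseVal = [FalseVal, CmpRHS, X, Y](unsigned OuterOpc, unsigned InnerOpc, bool NotRHS) {
   if (match(CmpRHS, m_Zero()))
      return match(FalseVal, m_c_BinOp(InnerOpc, m_Specific(X), m_Specific(Y)));
   // m_Not(...) and m_Specific(...) have different matcher types, so they
   // cannot be the two arms of a ?: expression; branch explicitly instead.
   if (NotRHS)
      return match(FalseVal, m_c_BinOp(OuterOpc, m_c_BinOp(InnerOpc, m_Specific(X), m_Specific(Y)), m_Not(m_Specific(CmpRHS))));
   return match(FalseVal, m_c_BinOp(OuterOpc, m_c_BinOp(InnerOpc, m_Specific(X), m_Specific(Y)), m_Specific(CmpRHS)));
};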
```

https://github.com/llvm/llvm-project/pull/73362


More information about the llvm-commits mailing list