[llvm] [InstCombine] Try optimizing with knownbits which determined from Cond (PR #91762)

via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 17 08:11:24 PDT 2024


================
@@ -1809,6 +1809,198 @@ static Instruction *foldSelectICmpEq(SelectInst &SI, ICmpInst *ICI,
   return nullptr;
 }
 
+// The condition (ICmpInst) of a SelectInst is not taken into account when
+// computing KnownBits, so we miss the opportunity to optimize the true or
+// false value of the select using the known bits implied by the ICmpInst.
+//
+// Consider:
+//   %or = or i32 %x, %y
+//   %or0 = icmp eq i32 %or, 0
+//   %and = and i32 %x, %y
+//   %cond = select i1 %or0, i32 %and, i32 %or
+//   ret i32 %cond
+//
+// Expect:
+//   %or = or i32 %x, %y
+//   ret i32 %or
+//
+// From the ICmpInst of the SelectInst we can determine which bits of %x and %y are set.
+static Instruction *foldSelectICmpBinOp(SelectInst &SI, ICmpInst *ICI,
+                                        Value *CmpLHS, Value *CmpRHS,
+                                        Value *TVal, Value *FVal,
+                                        InstCombinerImpl &IC) {
+  Value *X, *Y;
+  const APInt *C;
+  unsigned CmpLHSOpc;
+  bool IsDisjoint = false;
+  // Specially handling for X^Y==0 transformed to X==Y
+  if (match(TVal, m_c_BitwiseLogic(m_Specific(CmpLHS), m_Specific(CmpRHS)))) {
+    X = CmpLHS;
+    Y = CmpRHS;
+    APInt ZeroVal = APInt::getZero(CmpLHS->getType()->getScalarSizeInBits());
+    C = const_cast<APInt *>(&ZeroVal);
+    CmpLHSOpc = Instruction::Xor;
+  } else if ((match(CmpLHS, m_BinOp(m_Value(X), m_Value(Y))) &&
+              match(CmpRHS, m_APInt(C))) &&
+             (match(TVal, m_c_BinOp(m_Specific(X), m_Value())) ||
+              match(TVal, m_c_BinOp(m_Specific(Y), m_Value())))) {
+    if (auto Inst = dyn_cast<PossiblyDisjointInst>(CmpLHS)) {
+      if (Inst->isDisjoint())
+        IsDisjoint = true;
+      CmpLHSOpc = Instruction::Or;
+    } else
+      CmpLHSOpc = cast<BinaryOperator>(CmpLHS)->getOpcode();
+  } else
+    return nullptr;
+
+  enum SpecialKnownBits {
+    NothingSpecial = 0,
+    NoCommonBits = 1 << 1,
+    AllCommonBits = 1 << 2,
+    AllBitsEnabled = 1 << 3,
+  };
+
+  // We cannot know exactly what bits is known in X Y.
+  // Instead, we just know what relationship exist for.
+  auto isSpecialKnownBitsFor = [&]() -> unsigned {
+    if (CmpLHSOpc == Instruction::And) {
+      if (C->isZero())
+        return NoCommonBits;
+    } else if (CmpLHSOpc == Instruction::Xor) {
+      if (C->isAllOnes())
+        return NoCommonBits | AllBitsEnabled;
+      if (C->isZero())
+        return AllCommonBits;
+    } else if (CmpLHSOpc == Instruction::Or && IsDisjoint) {
+      if (C->isAllOnes())
+        return NoCommonBits | AllBitsEnabled;
+      return NoCommonBits;
+    }
+
+    return NothingSpecial;
+  };
+
+  auto hasOperandAt = [&](Instruction *I, Value *Op) -> int {
+    for (unsigned Idx = 0; Idx < I->getNumOperands(); Idx++) {
+      if (I->getOperand(Idx) == Op)
+        return Idx + 1;
+    }
+    return 0;
+  };
+
+  Type *TValTy = TVal->getType();
+  unsigned BitWidth = TVal->getType()->getScalarSizeInBits();
+  auto TValBop = cast<BinaryOperator>(TVal);
+  unsigned XOrder = hasOperandAt(TValBop, X);
+  unsigned YOrder = hasOperandAt(TValBop, Y);
+  unsigned SKB = isSpecialKnownBitsFor();
+
+  KnownBits Known;
+  if (TValBop->isBitwiseLogicOp()) {
+    // We handle if we know specific knownbits from cond of selectinst.
+    // ex) X&Y==-1 ? X^Y : False
+    if (SKB != SpecialKnownBits::NothingSpecial && XOrder && YOrder) {
+      // No common bits between X, Y
+      if (SKB & SpecialKnownBits::NoCommonBits) {
+        if (SKB & (SpecialKnownBits::AllBitsEnabled)) {
+          // If X op Y == -1, then XOR must be -1
+          if (TValBop->getOpcode() == Instruction::Xor)
+            Known = KnownBits::makeConstant(APInt(BitWidth, -1));
+        }
+        // If Trueval is X&Y then it should be 0.
+        if (TValBop->getOpcode() == Instruction::And)
+          Known = KnownBits::makeConstant(APInt(BitWidth, 0));
+        // X|Y can be replace with X^Y, X^Y can be replace with X|Y
+        // This replacing is meaningful when falseval is same.
+        else if ((match(TVal, m_c_Or(m_Specific(X), m_Specific(Y))) &&
+                  match(FVal, m_c_Xor(m_Specific(X), m_Specific(Y)))) ||
+                 (match(TVal, m_c_Xor(m_Specific(X), m_Specific(Y))) &&
+                  match(FVal, m_c_Or(m_Specific(X), m_Specific(Y)))))
+          return IC.replaceInstUsesWith(SI, FVal);
+        // All common bits between X, Y
+      } else if (SKB & SpecialKnownBits::AllCommonBits) {
+        // We can replace (X&Y) and (X|Y) to X or Y
+        if (TValBop->getOpcode() == Instruction::And ||
+            TValBop->getOpcode() == Instruction::Or)
+          if (TValBop->hasOneUse())
+            return IC.replaceOperand(SI, 1, X);
+      } else if (SKB & SpecialKnownBits::AllBitsEnabled) {
+        // We can replace (X|Y) to -1
+        if (TValBop->getOpcode() == Instruction::Or)
+          Known = KnownBits::makeConstant(APInt(BitWidth, -1));
+      }
+    } else {
+      KnownBits XKnown, YKnown, Temp;
+      KnownBits TValBop0KB, TValBop1KB;
+      // computeKnowBits calculates the KnownBits in the branching condition
+      // that the specified variable passes in the execution flow. however, it
+      // does not contain the SelectInst condition, so there is an optimization
+      // opportunity to update the knownbits obtained by calculating KnownBits
+      // with the SelectInst condition.
+      XKnown = IC.computeKnownBits(X, 0, &SI);
+      IC.computeKnownBitsFromCond(X, ICI, XKnown, 0, &SI, false);
+      YKnown = IC.computeKnownBits(Y, 0, &SI);
+      IC.computeKnownBitsFromCond(Y, ICI, YKnown, 0, &SI, false);
+      CmpInst::Predicate Pred = ICI->getPredicate();
+      if (Pred == ICmpInst::ICMP_EQ) {
+        // Estimate additional KnownBits from the relationship between X and Y
+        if (CmpLHSOpc == Instruction::And) {
+          // The bit that are set to 1 at `~C&Y` must be 0 in X
+          // The bit that are set to 1 at `~C&X` must be 0 in Y
+          XKnown.Zero |= ~*C & YKnown.One;
+          YKnown.Zero |= ~*C & XKnown.One;
+        }
+        if (CmpLHSOpc == Instruction::Or) {
+          // The bit that are set to 0 at `C&Y` must be 1 in X
+          // The bit that are set to 0 at `C&X` must be 1 in Y
+          XKnown.One |= *C & YKnown.Zero;
+          YKnown.One |= *C & XKnown.Zero;
----------------
ParkHanbum wrote:

Actually, this way of discovering known bits existed in ValueTracking before, but it was removed recently. I think it can still be used for this commit, so I'm not confident about creating a new API for this in ValueTracking.

But that is a beginner's opinion; I'll update the code soon.


https://github.com/llvm/llvm-project/pull/91762


More information about the llvm-commits mailing list