[llvm] [InstCombine] Try optimizing with knownbits which determined from Cond (PR #91762)

via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 3 14:19:57 PDT 2024


================
@@ -1790,6 +1790,158 @@ static Instruction *foldSelectICmpEq(SelectInst &SI, ICmpInst *ICI,
   return nullptr;
 }
 
+// The ICmpInst condition of a SelectInst is not taken into account when
+// computing KnownBits, so we miss opportunities to simplify the true or
+// false value of the select using the KnownBits implied by that condition.
+//
+// Consider:
+//   %or = or i32 %x, %y
+//   %or0 = icmp eq i32 %or, 0
+//   %and = and i32 %x, %y
+//   %cond = select i1 %or0, i32 %and, i32 %or
+//   ret i32 %cond
+//
+// Expect:
+//   %or = or i32 %x, %y
+//   ret i32 %or
+//
+// The select's ICmpInst condition tells us which bits of %x and %y are known.
+static Instruction *foldSelectICmpBinOp(SelectInst &SI, ICmpInst *ICI,
+                                        Value *CmpLHS, Value *CmpRHS,
+                                        Value *TVal, Value *FVal,
+                                        InstCombinerImpl &IC) {
+  Value *X, *Y;
+  const APInt *C;
+
+  if (!((match(CmpLHS, m_BinOp(m_Value(X), m_Value(Y))) &&
+         match(CmpRHS, m_APInt(C))) &&
+        (match(TVal, m_c_BinOp(m_Specific(X), m_Value())) ||
+         match(TVal, m_c_BinOp(m_Specific(Y), m_Value())))))
+    return nullptr;
+
+  enum SpecialKnownBits {
+    NothingSpecial = 0,
+    NoCommonBits = 1 << 1,
+    AllCommonBits = 1 << 2,
+    AllBitsEnabled = 1 << 3,
+  };
+
+  // We cannot know exactly which bits of X and Y are known.
+  // Instead, we only know which relationship holds between their bits.
+  auto isSpecialKnownBitsFor = [&](const Instruction *CmpLHS,
+                                   const APInt *CmpRHS) -> unsigned {
+    unsigned Opc = CmpLHS->getOpcode();
+    if (Opc == Instruction::And) {
+      if (CmpRHS->isZero())
+        return NoCommonBits;
+    } else if (Opc == Instruction::Xor) {
+      if (CmpRHS->isAllOnes())
+        return NoCommonBits | AllBitsEnabled;
+      if (CmpRHS->isZero())
+        return AllCommonBits;
+    }
+
+    return NothingSpecial;
+  };
+
+  auto hasOperandAt = [&](Instruction *I, Value *Op) -> int {
+    for (unsigned Idx = 0; Idx < I->getNumOperands(); Idx++) {
+      if (I->getOperand(Idx) == Op)
+        return Idx + 1;
+    }
+    return 0;
+  };
+
+  Type *TValTy = TVal->getType();
+  unsigned BitWidth = TVal->getType()->getScalarSizeInBits();
+  auto TValBop = cast<BinaryOperator>(TVal);
+  auto CmpLHSBop = cast<BinaryOperator>(CmpLHS);
+  unsigned XOrder = hasOperandAt(TValBop, X);
+  unsigned YOrder = hasOperandAt(TValBop, Y);
+  unsigned SKB = isSpecialKnownBitsFor(CmpLHSBop, C);
+
+  KnownBits Known;
+  if (TValBop->isBitwiseLogicOp()) {
+    if (SKB != SpecialKnownBits::NothingSpecial && XOrder && YOrder) {
+      if (SKB & SpecialKnownBits::NoCommonBits) {
+        if (SKB & (SpecialKnownBits::AllBitsEnabled)) {
+          if (TValBop->getOpcode() == Instruction::Xor)
+            Known = KnownBits::makeConstant(APInt(BitWidth, -1));
+        }
+        if (TValBop->getOpcode() == Instruction::And)
+          Known = KnownBits::makeConstant(APInt(BitWidth, 0));
+        else if ((match(TVal, m_c_Or(m_Specific(X), m_Specific(Y))) &&
+                  match(FVal, m_c_Xor(m_Specific(X), m_Specific(Y)))) ||
+                 (match(TVal, m_c_Xor(m_Specific(X), m_Specific(Y))) &&
+                  match(FVal, m_c_Or(m_Specific(X), m_Specific(Y)))))
+          return IC.replaceInstUsesWith(SI, FVal);
+      } else if (SKB & SpecialKnownBits::AllCommonBits) {
+        if (TValBop->getOpcode() == Instruction::And ||
+            TValBop->getOpcode() == Instruction::Or)
+          if (TValBop->hasOneUse())
+            return IC.replaceOperand(SI, 1, X);
+      } else if (SKB & SpecialKnownBits::AllBitsEnabled) {
+        if (TValBop->getOpcode() == Instruction::Or)
+          Known = KnownBits::makeConstant(APInt(BitWidth, -1));
+      }
+    } else {
+      KnownBits XKnown, YKnown, Temp;
+      KnownBits TValBop0KB, TValBop1KB;
+      XKnown = IC.computeKnownBits(X, 0, &SI);
+      IC.computeKnownBitsFromCond(X, ICI, XKnown, 0, &SI, false);
+      YKnown = IC.computeKnownBits(Y, 0, &SI);
+      IC.computeKnownBitsFromCond(Y, ICI, YKnown, 0, &SI, false);
+
+      // Estimate additional KnownBits from the relationship between X and Y
+      CmpInst::Predicate Pred = ICI->getPredicate();
+      if (Pred == ICmpInst::ICMP_EQ) {
+        if (CmpLHSBop->getOpcode() == Instruction::And) {
+          XKnown.Zero |= ~*C & YKnown.One;
+          YKnown.Zero |= ~*C & XKnown.One;
+        }
+        if (CmpLHSBop->getOpcode() == Instruction::Or) {
+          XKnown.One |= *C & YKnown.Zero;
+          YKnown.One |= *C & XKnown.Zero;
+        }
+        if (CmpLHSBop->getOpcode() == Instruction::Xor) {
+          XKnown.One |= *C & YKnown.Zero;
+          XKnown.Zero |= *C & YKnown.One;
+          YKnown.One |= *C & XKnown.Zero;
+          YKnown.Zero |= *C & XKnown.One;
+          XKnown.Zero |= ~*C & YKnown.Zero;
+          XKnown.One |= ~*C & YKnown.One;
+          YKnown.Zero |= ~*C & XKnown.Zero;
+          YKnown.One |= ~*C & XKnown.One;
+        }
+      }
+
+      auto getTValBopKB = [&](unsigned OpNum) -> KnownBits {
+        unsigned Order = OpNum + 1;
+        if (Order == XOrder)
+          return XKnown;
+        else if (Order == YOrder)
+          return YKnown;
+
+        Value *V = TValBop->getOperand(OpNum);
+        KnownBits Known = IC.computeKnownBits(V, 0, &SI);
+        return Known;
+      };
+      TValBop0KB = getTValBopKB(0);
+      TValBop1KB = getTValBopKB(1);
+      Known = analyzeKnownBitsFromAndXorOr(
+          cast<Operator>(TValBop), TValBop0KB, TValBop1KB, 0,
+          IC.getSimplifyQuery().getWithInstruction(&SI));
+    }
+  }
----------------
ParkHanbum wrote:

I'll update soon!

https://github.com/llvm/llvm-project/pull/91762


More information about the llvm-commits mailing list