[llvm] [InstCombine] Add folds for (add/sub/disjoint_or/icmp C, (ctpop (not x))) (PR #77859)

via llvm-commits llvm-commits at lists.llvm.org
Sun Jan 14 01:14:32 PST 2024


================
@@ -740,6 +740,87 @@ static Value *tryFactorization(BinaryOperator &I, const SimplifyQuery &SQ,
   return RetVal;
 }
 
+// If `I` has one Const operand and the other matches `(ctpop (not x))`,
+// replace `(ctpop (not x))` with `(sub nuw nsw BitWidth(x), (ctpop x))`.
+// This is only useful if the new subtract can fold, so we only handle the
+// following cases:
+//    1) (add/sub/disjoint_or C, (ctpop (not x)))
+//        -> (add/sub/disjoint_or C', (ctpop x))
+//    2) (cmp pred C, (ctpop (not x)))
+//        -> (cmp pred C', (ctpop x))
+Instruction *InstCombinerImpl::tryFoldInstWithCtpopWithNot(Instruction *I) {
+  unsigned Opc = I->getOpcode();
+  unsigned ConstIdx = 1;
+  switch (Opc) {
+  default:
+    return nullptr;
+    // (ctpop (not x)) <-> (sub nuw nsw BitWidth(x) - (ctpop x))
+    // We can fold the BitWidth(x) with add/sub/icmp as long as the other
+    // operand is constant.
+  case Instruction::Sub:
+    ConstIdx = 0;
+    break;
+  case Instruction::Or:
+    if (!match(I, m_DisjointOr(m_Value(), m_Value())))
+      return nullptr;
+    [[fallthrough]];
+  case Instruction::Add:
+  case Instruction::ICmp:
+    break;
+  }
+
+  Value *Op;
+  // Find ctpop.
+  if (!match(I->getOperand(1 - ConstIdx),
+             m_OneUse(m_Intrinsic<Intrinsic::ctpop>(m_Value(Op)))))
+    return nullptr;
+
+  Constant *C;
+  // Check other operand is ImmConstant.
+  if (!match(I->getOperand(ConstIdx), m_ImmConstant(C)))
+    return nullptr;
+
+  Type *Ty = Op->getType();
+  Constant *BitWidthC = ConstantInt::get(Ty, Ty->getScalarSizeInBits());
+  // Need an extra check for icmp. Note: if this check triggers, it generally
+  // means the icmp will simplify to true/false.
+  if (Opc == Instruction::ICmp && !cast<ICmpInst>(I)->isEquality() &&
+      !ConstantExpr::getICmp(ICmpInst::ICMP_UGT, C, BitWidthC)->isZeroValue())
+    return nullptr;
+
+  // Check we can invert `(not x)` for free.
+  bool Consumes = false;
+  if (!isFreeToInvert(Op, Op->hasOneUse(), Consumes) || !Consumes)
+    return nullptr;
+  Value *NotOp = getFreelyInverted(Op, Op->hasOneUse(), &Builder);
+  assert(NotOp != nullptr &&
+         "Desync between isFreeToInvert and getFreelyInverted");
+
+  Value *CtpopOfNotOp = Builder.CreateIntrinsic(Ty, Intrinsic::ctpop, NotOp);
+
+  Value *R = nullptr;
+
+  // Do the transformation here to avoid potentially introducing an infinite
+  // loop.
+  switch (Opc) {
+  case Instruction::Sub:
+    R = Builder.CreateAdd(CtpopOfNotOp, ConstantExpr::getSub(C, BitWidthC));
+    break;
+  case Instruction::Or:
+  case Instruction::Add:
+    R = Builder.CreateSub(ConstantExpr::getAdd(C, BitWidthC), CtpopOfNotOp);
+    break;
+  case Instruction::ICmp:
+    R = Builder.CreateICmp(cast<ICmpInst>(I)->getSwappedPredicate(),
----------------
goldsteinn wrote:

I just disabled it for signed predicates. It shouldn't make a difference, as any comparison of `ctpop, C` should be trivial to simplify in unsigned form.

https://github.com/llvm/llvm-project/pull/77859


More information about the llvm-commits mailing list