[llvm] users/goldsteinn/ctpop of not (PR #77859)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 11 17:07:31 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: None (goldsteinn)

<details>
<summary>Changes</summary>

- Add tests for folding `(add/sub/disjoint_or/icmp C, (ctpop (not x)))`; NFC
- Add folds for `(add/sub/disjoint_or/icmp C, (ctpop (not x)))`


---
Full diff: https://github.com/llvm/llvm-project/pull/77859.diff


6 Files Affected:

- (modified) llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp (+6) 
- (modified) llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp (+3) 
- (modified) llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp (+3) 
- (modified) llvm/lib/Transforms/InstCombine/InstCombineInternal.h (+4) 
- (modified) llvm/lib/Transforms/InstCombine/InstructionCombining.cpp (+77) 
- (added) llvm/test/Transforms/InstCombine/fold-ctpop-of-not.ll (+166) 


``````````diff
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index c7e6f32c5406a6..8a00b75a1f7404 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1683,6 +1683,9 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
     }
   }
 
+  if (Instruction *R = tryFoldInstWithCtpopWithNot(&I))
+    return R;
+
   // TODO(jingyue): Consider willNotOverflowSignedAdd and
   // willNotOverflowUnsignedAdd to reduce the number of invocations of
   // computeKnownBits.
@@ -2445,6 +2448,9 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
     }
   }
 
+  if (Instruction *R = tryFoldInstWithCtpopWithNot(&I))
+    return R;
+
   if (Instruction *R = foldSubOfMinMax(I, Builder))
     return R;
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 0620752e321394..de06fb8badf817 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -3398,6 +3398,9 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
   if (Instruction *R = foldBinOpShiftWithShift(I))
     return R;
 
+  if (Instruction *R = tryFoldInstWithCtpopWithNot(&I))
+    return R;
+
   Value *X, *Y;
   const APInt *CV;
   if (match(&I, m_c_Or(m_OneUse(m_Xor(m_Value(X), m_APInt(CV))), m_Value(Y))) &&
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 7c1aff445524de..8c0fd662255130 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -1323,6 +1323,9 @@ Instruction *InstCombinerImpl::foldICmpWithConstant(ICmpInst &Cmp) {
       return replaceInstUsesWith(Cmp, NewPhi);
     }
 
+  if (Instruction *R = tryFoldInstWithCtpopWithNot(&Cmp))
+    return R;
+
   return nullptr;
 }
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 21c61bd990184d..c24b6e3a5b33c0 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -505,6 +505,10 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
   Value *SimplifySelectsFeedingBinaryOp(BinaryOperator &I, Value *LHS,
                                         Value *RHS);
 
+  // If `I` has a constant operand and a `(ctpop (not x))` operand, try to fold
+  // `I` via `(ctpop (not x)) == BitWidth(x) - (ctpop x)`; nullptr on failure.
+  Instruction *tryFoldInstWithCtpopWithNot(Instruction *I);
+
   // (Binop1 (Binop2 (logic_shift X, C), C1), (logic_shift Y, C))
   //    -> (logic_shift (Binop1 (Binop2 X, inv_logic_shift(C1, C)), Y), C)
   // (Binop1 (Binop2 (logic_shift X, Amt), Mask), (logic_shift Y, Amt))
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 7f2018b3a19958..732ab7ad8b3223 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -740,6 +740,109 @@ static Value *tryFactorization(BinaryOperator &I, const SimplifyQuery &SQ,
   return RetVal;
 }
 
+// If `I` has one constant operand and the other matches `(ctpop (not x))`,
+// replace `(ctpop (not x))` with `(sub nuw nsw BitWidth(x), (ctpop x))` and
+// fold BitWidth(x) into the constant. This is only useful if the new subtract
+// folds away, so we only handle the following cases:
+//    1) (add/disjoint_or (ctpop (not x)), C)
+//        -> (sub (C + BitWidth(x)), (ctpop x))
+//    2) (sub C, (ctpop (not x)))
+//        -> (add (ctpop x), (C - BitWidth(x)))
+//    3) (icmp pred (ctpop (not x)), C)
+//        -> (icmp swapped(pred) (ctpop x), (BitWidth(x) - C))
+//       for equality and unsigned predicates only.
+// Returns the result of replaceInstUsesWith on success, nullptr otherwise.
+Instruction *InstCombinerImpl::tryFoldInstWithCtpopWithNot(Instruction *I) {
+  unsigned Opc = I->getOpcode();
+  // Index of the constant operand; the ctpop is the other operand.
+  unsigned ConstIdx = 1;
+  switch (Opc) {
+  default:
+    return nullptr;
+    // (ctpop (not x)) <-> (sub nuw nsw BitWidth(x) - (ctpop x))
+    // We can fold the BitWidth(x) with add/sub/icmp as long as the other
+    // operand is constant.
+  case Instruction::Sub:
+    // Only `(sub C, (ctpop (not x)))`: the constant is the minuend.
+    ConstIdx = 0;
+    break;
+  case Instruction::ICmp:
+    // Signed predicates are unsound for tiny types: as a signed i2 value,
+    // BitWidth(x) == 2 is -2, so `BitWidth(x) - (ctpop x)` is not
+    // order-reversing under signed comparison.
+    if (cast<ICmpInst>(I)->isSigned())
+      return nullptr;
+    break;
+  case Instruction::Or:
+    // A disjoint `or` is equivalent to an `add`.
+    if (!match(I, m_DisjointOr(m_Value(), m_Value())))
+      return nullptr;
+    [[fallthrough]];
+  case Instruction::Add:
+    break;
+  }
+
+  // Find `(ctpop Op)`. Require a single use so the old ctpop goes away instead
+  // of leaving two popcounts behind.
+  Value *Op;
+  if (!match(I->getOperand(1 - ConstIdx),
+             m_OneUse(m_Intrinsic<Intrinsic::ctpop>(m_Value(Op)))))
+    return nullptr;
+
+  // Check other operand is ImmConstant.
+  Constant *C;
+  if (!match(I->getOperand(ConstIdx), m_ImmConstant(C)))
+    return nullptr;
+
+  Type *Ty = Op->getType();
+  Constant *BitWidthC = ConstantInt::get(Ty, Ty->getScalarSizeInBits());
+  // Relational icmps additionally need C <= BitWidth(x) in every lane so that
+  // `BitWidth(x) - C` does not wrap. If this check fails, the icmp compares
+  // against a value above ctpop's maximum and generally simplifies to
+  // true/false elsewhere.
+  if (Opc == Instruction::ICmp && !cast<ICmpInst>(I)->isEquality() &&
+      !ConstantExpr::getICmp(ICmpInst::ICMP_UGT, C, BitWidthC)->isZeroValue())
+    return nullptr;
+
+  // Only fold if `Op` can be inverted for free AND doing so consumes an
+  // existing `not`. Otherwise the fold can undo itself: if e.g. `(add y, C1)`
+  // is freely invertible to `(sub ~C1, y)`, the add and sub cases above would
+  // keep swapping the two forms forever. Query without a builder first so no
+  // IR is created unless we actually fold.
+  bool Consumes = false;
+  if (!isFreeToInvert(Op, Op->hasOneUse(), Consumes) || !Consumes)
+    return nullptr;
+  Value *NotOp = getFreelyInverted(Op, Op->hasOneUse(), &Builder);
+  assert(NotOp != nullptr &&
+         "Desync between isFreeToInvert and getFreelyInverted");
+  Value *CtpopOfNotOp = Builder.CreateIntrinsic(Ty, Intrinsic::ctpop, NotOp);
+
+  Value *R = nullptr;
+
+  // Do the transformation here to avoid potentially introducing an infinite
+  // loop.
+  switch (Opc) {
+  case Instruction::Sub:
+    // C - (BW - ctpop(x)) == ctpop(x) + (C - BW)
+    R = Builder.CreateAdd(CtpopOfNotOp, ConstantExpr::getSub(C, BitWidthC));
+    break;
+  case Instruction::Or:
+  case Instruction::Add:
+    // (BW - ctpop(x)) + C == (C + BW) - ctpop(x)
+    R = Builder.CreateSub(ConstantExpr::getAdd(C, BitWidthC), CtpopOfNotOp);
+    break;
+  case Instruction::ICmp:
+    // (BW - ctpop(x)) pred C <=> ctpop(x) swapped(pred) (BW - C)
+    R = Builder.CreateICmp(cast<ICmpInst>(I)->getSwappedPredicate(),
+                           CtpopOfNotOp, ConstantExpr::getSub(BitWidthC, C));
+    break;
+  default:
+    llvm_unreachable("Unhandled Opcode");
+  }
+  assert(R != nullptr);
+  return replaceInstUsesWith(*I, R);
+}
+
 // (Binop1 (Binop2 (logic_shift X, C), C1), (logic_shift Y, C))
 //   IFF
 //    1) the logic_shifts match
diff --git a/llvm/test/Transforms/InstCombine/fold-ctpop-of-not.ll b/llvm/test/Transforms/InstCombine/fold-ctpop-of-not.ll
new file mode 100644
index 00000000000000..9fa3bb66bb7f10
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/fold-ctpop-of-not.ll
@@ -0,0 +1,166 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+declare i8 @llvm.ctpop.i8(i8)
+declare <2 x i8> @llvm.ctpop.v2i8(<2 x i8>)
+
+define i8 @fold_sub_c_ctpop(i8 %x) {
+; CHECK-LABEL: @fold_sub_c_ctpop(
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.ctpop.i8(i8 [[X:%.*]]), !range [[RNG0:![0-9]+]]
+; CHECK-NEXT:    [[R:%.*]] = add nuw nsw i8 [[TMP1]], 4
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %nx = xor i8 %x, -1
+  %cnt = call i8 @llvm.ctpop.i8(i8 %nx)
+  %r = sub i8 12, %cnt
+  ret i8 %r
+}
+
+define i8 @fold_sub_var_ctpop_fail(i8 %x, i8 %y) {
+; CHECK-LABEL: @fold_sub_var_ctpop_fail(
+; CHECK-NEXT:    [[NX:%.*]] = xor i8 [[X:%.*]], -1
+; CHECK-NEXT:    [[CNT:%.*]] = call i8 @llvm.ctpop.i8(i8 [[NX]]), !range [[RNG0]]
+; CHECK-NEXT:    [[R:%.*]] = sub i8 [[Y:%.*]], [[CNT]]
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %nx = xor i8 %x, -1
+  %cnt = call i8 @llvm.ctpop.i8(i8 %nx)
+  %r = sub i8 %y, %cnt
+  ret i8 %r
+}
+
+define <2 x i8> @fold_sub_ctpop_c(<2 x i8> %x) {
+; CHECK-LABEL: @fold_sub_ctpop_c(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <2 x i8> @llvm.ctpop.v2i8(<2 x i8> [[X:%.*]]), !range [[RNG0]]
+; CHECK-NEXT:    [[R:%.*]] = sub nuw nsw <2 x i8> <i8 -55, i8 -56>, [[TMP1]]
+; CHECK-NEXT:    ret <2 x i8> [[R]]
+;
+  %nx = xor <2 x i8> %x, <i8 -1, i8 -1>
+  %cnt = call <2 x i8> @llvm.ctpop.v2i8(<2 x i8> %nx)
+  %r = sub <2 x i8> %cnt, <i8 63, i8 64>
+  ret <2 x i8> %r
+}
+
+define i8 @fold_add_ctpop_c(i8 %x) {
+; CHECK-LABEL: @fold_add_ctpop_c(
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.ctpop.i8(i8 [[X:%.*]]), !range [[RNG0]]
+; CHECK-NEXT:    [[R:%.*]] = sub nuw nsw i8 71, [[TMP1]]
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %nx = xor i8 %x, -1
+  %cnt = call i8 @llvm.ctpop.i8(i8 %nx)
+  %r = add i8 %cnt, 63
+  ret i8 %r
+}
+
+define i8 @fold_distjoint_or_ctpop_c(i8 %x) {
+; CHECK-LABEL: @fold_distjoint_or_ctpop_c(
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.ctpop.i8(i8 [[X:%.*]]), !range [[RNG0]]
+; CHECK-NEXT:    [[R:%.*]] = sub nuw nsw i8 72, [[TMP1]]
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %nx = xor i8 %x, -1
+  %cnt = call i8 @llvm.ctpop.i8(i8 %nx)
+  %r = or i8 %cnt, 64
+  ret i8 %r
+}
+
+define i8 @fold_or_ctpop_c_fail(i8 %x) {
+; CHECK-LABEL: @fold_or_ctpop_c_fail(
+; CHECK-NEXT:    [[NX:%.*]] = xor i8 [[X:%.*]], -1
+; CHECK-NEXT:    [[CNT:%.*]] = call i8 @llvm.ctpop.i8(i8 [[NX]]), !range [[RNG0]]
+; CHECK-NEXT:    [[R:%.*]] = or i8 [[CNT]], 65
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %nx = xor i8 %x, -1
+  %cnt = call i8 @llvm.ctpop.i8(i8 %nx)
+  %r = or i8 %cnt, 65
+  ret i8 %r
+}
+
+define i8 @fold_add_ctpop_var_fail(i8 %x, i8 %y) {
+; CHECK-LABEL: @fold_add_ctpop_var_fail(
+; CHECK-NEXT:    [[NX:%.*]] = xor i8 [[X:%.*]], -1
+; CHECK-NEXT:    [[CNT:%.*]] = call i8 @llvm.ctpop.i8(i8 [[NX]]), !range [[RNG0]]
+; CHECK-NEXT:    [[R:%.*]] = add i8 [[CNT]], [[Y:%.*]]
+; CHECK-NEXT:    ret i8 [[R]]
+;
+  %nx = xor i8 %x, -1
+  %cnt = call i8 @llvm.ctpop.i8(i8 %nx)
+  %r = add i8 %cnt, %y
+  ret i8 %r
+}
+
+define i1 @fold_cmp_eq_ctpop_c(i8 %x) {
+; CHECK-LABEL: @fold_cmp_eq_ctpop_c(
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.ctpop.i8(i8 [[X:%.*]]), !range [[RNG0]]
+; CHECK-NEXT:    [[R:%.*]] = icmp eq i8 [[TMP1]], 6
+; CHECK-NEXT:    ret i1 [[R]]
+;
+  %nx = xor i8 %x, -1
+  %cnt = call i8 @llvm.ctpop.i8(i8 %nx)
+  %r = icmp eq i8 %cnt, 2
+  ret i1 %r
+}
+
+define <2 x i1> @fold_cmp_ne_ctpop_c(<2 x i8> %x) {
+; CHECK-LABEL: @fold_cmp_ne_ctpop_c(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <2 x i8> @llvm.ctpop.v2i8(<2 x i8> [[X:%.*]]), !range [[RNG0]]
+; CHECK-NEXT:    [[R:%.*]] = icmp ne <2 x i8> [[TMP1]], <i8 -36, i8 5>
+; CHECK-NEXT:    ret <2 x i1> [[R]]
+;
+  %nx = xor <2 x i8> %x, <i8 -1, i8 -1>
+  %cnt = call <2 x i8> @llvm.ctpop.v2i8(<2 x i8> %nx)
+  %r = icmp ne <2 x i8> %cnt, <i8 44, i8 3>
+  ret <2 x i1> %r
+}
+
+define <2 x i1> @fold_cmp_ne_ctpop_var_fail(<2 x i8> %x, <2 x i8> %y) {
+; CHECK-LABEL: @fold_cmp_ne_ctpop_var_fail(
+; CHECK-NEXT:    [[NX:%.*]] = xor <2 x i8> [[X:%.*]], <i8 -1, i8 -1>
+; CHECK-NEXT:    [[CNT:%.*]] = call <2 x i8> @llvm.ctpop.v2i8(<2 x i8> [[NX]]), !range [[RNG0]]
+; CHECK-NEXT:    [[R:%.*]] = icmp ne <2 x i8> [[CNT]], [[Y:%.*]]
+; CHECK-NEXT:    ret <2 x i1> [[R]]
+;
+  %nx = xor <2 x i8> %x, <i8 -1, i8 -1>
+  %cnt = call <2 x i8> @llvm.ctpop.v2i8(<2 x i8> %nx)
+  %r = icmp ne <2 x i8> %cnt, %y
+  ret <2 x i1> %r
+}
+
+define i1 @fold_cmp_ult_ctpop_c(i8 %x) {
+; CHECK-LABEL: @fold_cmp_ult_ctpop_c(
+; CHECK-NEXT:    [[TMP1:%.*]] = call i8 @llvm.ctpop.i8(i8 [[X:%.*]]), !range [[RNG0]]
+; CHECK-NEXT:    [[R:%.*]] = icmp ugt i8 [[TMP1]], 3
+; CHECK-NEXT:    ret i1 [[R]]
+;
+  %nx = xor i8 %x, -1
+  %cnt = call i8 @llvm.ctpop.i8(i8 %nx)
+  %r = icmp ult i8 %cnt, 5
+  ret i1 %r
+}
+
+define <2 x i1> @fold_cmp_ugt_ctpop_c(<2 x i8> %x) {
+; CHECK-LABEL: @fold_cmp_ugt_ctpop_c(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <2 x i8> @llvm.ctpop.v2i8(<2 x i8> [[X:%.*]]), !range [[RNG0]]
+; CHECK-NEXT:    [[R:%.*]] = icmp ult <2 x i8> [[TMP1]], <i8 0, i8 2>
+; CHECK-NEXT:    ret <2 x i1> [[R]]
+;
+  %nx = xor <2 x i8> %x, <i8 -1, i8 -1>
+  %cnt = call <2 x i8> @llvm.ctpop.v2i8(<2 x i8> %nx)
+  %r = icmp ugt <2 x i8> %cnt, <i8 8, i8 6>
+  ret <2 x i1> %r
+}
+
+define <2 x i1> @fold_cmp_ugt_ctpop_c_out_of_range_fail(<2 x i8> %x) {
+; CHECK-LABEL: @fold_cmp_ugt_ctpop_c_out_of_range_fail(
+; CHECK-NEXT:    [[NX:%.*]] = xor <2 x i8> [[X:%.*]], <i8 -1, i8 -1>
+; CHECK-NEXT:    [[CNT:%.*]] = call <2 x i8> @llvm.ctpop.v2i8(<2 x i8> [[NX]]), !range [[RNG0]]
+; CHECK-NEXT:    [[R:%.*]] = icmp ugt <2 x i8> [[CNT]], <i8 2, i8 10>
+; CHECK-NEXT:    ret <2 x i1> [[R]]
+;
+  %nx = xor <2 x i8> %x, <i8 -1, i8 -1>
+  %cnt = call <2 x i8> @llvm.ctpop.v2i8(<2 x i8> %nx)
+  %r = icmp ugt <2 x i8> %cnt, <i8 2, i8 10>
+  ret <2 x i1> %r
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/77859


More information about the llvm-commits mailing list