[llvm] [AMDGPU][InstCombine] Fold ballot intrinsic based on llvm.assume hints (PR #160670)

Jay Foad via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 11 04:09:59 PST 2025


================
@@ -1341,6 +1342,71 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
       Call->takeName(&II);
       return IC.replaceInstUsesWith(II, Call);
     }
+
+    // Fold ballot intrinsic based on llvm.assume hint about the result.
+    //
+    // assume(ballot(x) == ballot(true)) -> x = true
+    // assume(ballot(x) == -1)           -> x = true
+    // assume(ballot(x) == 0)            -> x = false
+    if (Arg->getType()->isIntegerTy(1)) {
+      for (auto &AssumeVH : IC.getAssumptionCache().assumptionsFor(&II)) {
+        if (!AssumeVH)
+          continue;
+
+        auto *Assume = cast<AssumeInst>(AssumeVH);
+        Value *Cond = Assume->getArgOperand(0);
+
+        // Check if assume condition is an equality comparison.
+        auto *ICI = dyn_cast<ICmpInst>(Cond);
+        if (!ICI || ICI->getPredicate() != ICmpInst::ICMP_EQ)
+          continue;
+
+        // Extract the ballot and the value being compared against it.
+        Value *LHS = ICI->getOperand(0), *RHS = ICI->getOperand(1);
+        Value *CompareValue = (LHS == &II) ? RHS : (RHS == &II) ? LHS : nullptr;
+        if (!CompareValue)
+          continue;
+
+        // Determine the constant value of the ballot's condition argument.
+        std::optional<bool> InferredCondValue;
+        if (auto *CI = dyn_cast<ConstantInt>(CompareValue)) {
+          // ballot(x) == -1 means all lanes have x = true.
+          if (CI->isMinusOne())
+            InferredCondValue = true;
+          // ballot(x) == 0 means all lanes have x = false.
+          else if (CI->isZero())
+            InferredCondValue = false;
+        } else if (match(CompareValue,
+                         m_Intrinsic<Intrinsic::amdgcn_ballot>(m_One()))) {
+          // ballot(x) == ballot(true) means x = true.
+          InferredCondValue = true;
+        } else if (match(CompareValue,
+                         m_Intrinsic<Intrinsic::amdgcn_ballot>(m_Zero()))) {
----------------
jayfoad wrote:

Don't need to do this since you can rely on InstCombine combining ballot(0) -> 0.

https://github.com/llvm/llvm-project/pull/160670


More information about the llvm-commits mailing list