[llvm] [AMDGPU][InstCombine] Fold ballot intrinsic based on llvm.assume hints (PR #160670)
Jay Foad via llvm-commits
llvm-commits at lists.llvm.org
Tue Nov 11 04:09:59 PST 2025
================
@@ -1341,6 +1342,71 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
Call->takeName(&II);
return IC.replaceInstUsesWith(II, Call);
}
+
+ // Fold ballot intrinsic based on llvm.assume hint about the result.
+ //
+ // assume(ballot(x) == ballot(true)) -> x = true
+ // assume(ballot(x) == -1) -> x = true
+ // assume(ballot(x) == 0) -> x = false
+ if (Arg->getType()->isIntegerTy(1)) {
+ for (auto &AssumeVH : IC.getAssumptionCache().assumptionsFor(&II)) {
+ if (!AssumeVH)
+ continue;
+
+ auto *Assume = cast<AssumeInst>(AssumeVH);
+ Value *Cond = Assume->getArgOperand(0);
+
+ // Check if assume condition is an equality comparison.
+ auto *ICI = dyn_cast<ICmpInst>(Cond);
+ if (!ICI || ICI->getPredicate() != ICmpInst::ICMP_EQ)
+ continue;
+
+ // Extract the ballot and the value being compared against it.
+ Value *LHS = ICI->getOperand(0), *RHS = ICI->getOperand(1);
+ Value *CompareValue = (LHS == &II) ? RHS : (RHS == &II) ? LHS : nullptr;
+ if (!CompareValue)
+ continue;
+
+ // Determine the constant value of the ballot's condition argument.
+ std::optional<bool> InferredCondValue;
+ if (auto *CI = dyn_cast<ConstantInt>(CompareValue)) {
+ // ballot(x) == -1 means all lanes have x = true.
+ if (CI->isMinusOne())
+ InferredCondValue = true;
+ // ballot(x) == 0 means all lanes have x = false.
+ else if (CI->isZero())
+ InferredCondValue = false;
+ } else if (match(CompareValue,
+ m_Intrinsic<Intrinsic::amdgcn_ballot>(m_One()))) {
+ // ballot(x) == ballot(true) means x = true.
+ InferredCondValue = true;
+ } else if (match(CompareValue,
+ m_Intrinsic<Intrinsic::amdgcn_ballot>(m_Zero()))) {
+ // ballot(x) == ballot(false) means x = false.
+ InferredCondValue = false;
+ }
+
+ if (!InferredCondValue)
+ continue;
+
+ Constant *ReplacementValue =
+ ConstantInt::getBool(Arg->getContext(), *InferredCondValue);
+
+ // Replace dominated uses of the condition argument.
+ bool Changed = false;
+ Arg->replaceUsesWithIf(ReplacementValue, [&](Use &U) {
+ Instruction *UserInst = dyn_cast<Instruction>(U.getUser());
+ bool Dominates =
+ UserInst && IC.getDominatorTree().dominates(Assume, U);
----------------
jayfoad wrote:
I don't quite understand why you need to check UserInst here. And if you do need to check it, do you also need to check that the user instruction belongs to the current function?
https://github.com/llvm/llvm-project/pull/160670
More information about the llvm-commits mailing list