[llvm] [InstCombine] Optimize AMDGPU ballot + assume uniformity patterns (PR #160670)
    Teja Alaghari via llvm-commits 
    llvm-commits at lists.llvm.org
       
    Thu Sep 25 02:27:42 PDT 2025
    
    
  
https://github.com/TejaX-Alaghari created https://github.com/llvm/llvm-project/pull/160670
## Summary
This PR implements an InstCombine optimization that recognizes when AMDGPU ballot intrinsics are used with assumptions about uniformity, specifically the pattern `assume(ballot(cmp) == -1)`.
## Problem
In AMDGPU code, developers often use ballot intrinsics to test uniformity of conditions across a wavefront:
```llvm
%cmp = icmp eq i32 %x, 0
%ballot = call i64 @llvm.amdgcn.ballot.i64(i1 %cmp)
%uniform = icmp eq i64 %ballot, -1
call void @llvm.assume(i1 %uniform)
br i1 %cmp, label %then, label %else
```
When `ballot(cmp) == -1`, we know that `cmp` evaluates to true on every lane, which makes it uniform. However, existing optimizations do not recognize this pattern, so the expensive ballot call and the potentially divergent control flow are left in place.
## Solution
This optimization adds pattern matching in InstCombine's `visitCallInst` for `llvm.assume` intrinsics to detect:
- `assume(icmp eq (ballot(cmp), -1))` patterns
- Both `i32` and `i64` ballot variants (`@llvm.amdgcn.ballot.i32` and `@llvm.amdgcn.ballot.i64`)
When detected, it:
1. Replaces the ballot condition `cmp` with `ConstantInt::getTrue()`
2. Simplifies the assumption to always true
3. Enables subsequent passes (like SimplifyCFG) to eliminate dead branches
## Example Transformation
**Before:**
```llvm
%cmp = icmp eq i32 %x, 0
%ballot = call i64 @llvm.amdgcn.ballot.i64(i1 %cmp)
%uniform = icmp eq i64 %ballot, -1
call void @llvm.assume(i1 %uniform)
br i1 %cmp, label %then, label %else
```
**After InstCombine:**
```llvm
br i1 true, label %then, label %else
```
**After SimplifyCFG:**
```llvm
; Direct branch to %then, %else eliminated
```
From 3679d0dd4bbfe5753917c9d5a63b639ccbecfda5 Mon Sep 17 00:00:00 2001
From: TejaX-Alaghari <Teja.Alaghari at amd.com>
Date: Thu, 25 Sep 2025 14:27:36 +0530
Subject: [PATCH] [InstCombine] Optimize AMDGPU ballot + assume uniformity
 patterns
When we encounter assume(ballot(cmp) == -1), we know that cmp is uniform
across all lanes and evaluates to true. This optimization recognizes this
pattern and replaces the condition with a constant true, allowing
subsequent passes to eliminate dead code and optimize control flow.
The optimization handles both i32 and i64 ballot intrinsics and only
applies when the ballot result is compared against -1 (every bit set, i.e. the condition is true on all lanes).
This is a conservative approach that ensures correctness while enabling
significant optimizations for uniform control flow patterns.
---
 .../InstCombine/InstCombineCalls.cpp          |  33 ++++++
 .../amdgpu-assume-ballot-uniform.ll           | 108 ++++++++++++++++++
 2 files changed, 141 insertions(+)
 create mode 100644 llvm/test/Transforms/InstCombine/amdgpu-assume-ballot-uniform.ll
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 6ad493772d170..c23a4e3dfbaf3 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -3519,6 +3519,39 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
       }
     }
 
+    // Optimize AMDGPU ballot uniformity assumptions:
+    // assume(icmp eq (ballot(cmp), -1)) implies that cmp is uniform and true.
+    // This allows us to optimize away the ballot and replace cmp with true.
+    Value *BallotInst;
+    if (match(IIOperand, m_SpecificICmp(ICmpInst::ICMP_EQ, m_Value(BallotInst),
+                                        m_AllOnes()))) {
+      // Check if this is an AMDGPU ballot intrinsic
+      if (auto *BallotCall = dyn_cast<IntrinsicInst>(BallotInst)) {
+        if (BallotCall->getIntrinsicID() == Intrinsic::amdgcn_ballot) {
+          Value *BallotCondition = BallotCall->getArgOperand(0);
+
+          // If ballot(cmp) == -1, then cmp is uniform across all lanes and
+          // evaluates to true. We can safely replace BallotCondition with
+          // true since ballot == -1 implies the condition holds on all lanes.
+          if (BallotCondition->getType()->isIntOrIntVectorTy(1) &&
+              !isa<Constant>(BallotCondition)) {
+
+            // Add the condition to the worklist for further optimization
+            Worklist.pushValue(BallotCondition);
+
+            // Replace BallotCondition with true
+            BallotCondition->replaceAllUsesWith(
+                ConstantInt::getTrue(BallotCondition->getType()));
+
+            // The assumption is now always true, so we can simplify it
+            replaceUse(II->getOperandUse(0),
+                       ConstantInt::getTrue(II->getContext()));
+            return II;
+          }
+        }
+      }
+    }
+
     // If there is a dominating assume with the same condition as this one,
     // then this one is redundant, and should be removed.
     KnownBits Known(1);
diff --git a/llvm/test/Transforms/InstCombine/amdgpu-assume-ballot-uniform.ll b/llvm/test/Transforms/InstCombine/amdgpu-assume-ballot-uniform.ll
new file mode 100644
index 0000000000000..3bf3b317b0771
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/amdgpu-assume-ballot-uniform.ll
@@ -0,0 +1,108 @@
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+; Test case for optimizing AMDGPU ballot + assume patterns
+; When we assume that ballot(cmp) == -1, we know that cmp is uniform
+; This allows us to optimize away the ballot and directly branch
+
+define void @test_assume_ballot_uniform(i32 %x) {
+; CHECK-LABEL: @test_assume_ballot_uniform(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br i1 true, label [[FOO:%.*]], label [[BAR:%.*]]
+; CHECK:       foo:
+; CHECK-NEXT:    ret void
+; CHECK:       bar:
+; CHECK-NEXT:    ret void
+;
+entry:
+  %cmp = icmp eq i32 %x, 0
+  %ballot = call i64 @llvm.amdgcn.ballot.i64(i1 %cmp)
+  %all = icmp eq i64 %ballot, -1
+  call void @llvm.assume(i1 %all)
+  br i1 %cmp, label %foo, label %bar
+
+foo:
+  ret void
+
+bar:
+  ret void
+}
+
+; Second instance of the same pattern - same input and expected output as @test_assume_ballot_uniform
+define void @test_assume_ballot_partial(i32 %x) {
+; CHECK-LABEL: @test_assume_ballot_partial(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br i1 true, label [[FOO:%.*]], label [[BAR:%.*]]
+; CHECK:       foo:
+; CHECK-NEXT:    ret void
+; CHECK:       bar:
+; CHECK-NEXT:    ret void
+;
+entry:
+  %cmp = icmp eq i32 %x, 0
+  %ballot = call i64 @llvm.amdgcn.ballot.i64(i1 %cmp)
+  %all = icmp eq i64 %ballot, -1
+  call void @llvm.assume(i1 %all)
+  br i1 %cmp, label %foo, label %bar
+
+foo:
+  ret void
+
+bar:
+  ret void
+}
+
+; Negative test - ballot not compared to -1
+define void @test_assume_ballot_not_uniform(i32 %x) {
+; CHECK-LABEL: @test_assume_ballot_not_uniform(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[X:%.*]], 0
+; CHECK-NEXT:    [[BALLOT:%.*]] = call i64 @llvm.amdgcn.ballot.i64(i1 [[CMP]])
+; CHECK-NEXT:    [[SOME:%.*]] = icmp ne i64 [[BALLOT]], 0
+; CHECK-NEXT:    call void @llvm.assume(i1 [[SOME]])
+; CHECK-NEXT:    br i1 [[CMP]], label [[FOO:%.*]], label [[BAR:%.*]]
+; CHECK:       foo:
+; CHECK-NEXT:    ret void
+; CHECK:       bar:
+; CHECK-NEXT:    ret void
+;
+entry:
+  %cmp = icmp eq i32 %x, 0
+  %ballot = call i64 @llvm.amdgcn.ballot.i64(i1 %cmp)
+  %some = icmp ne i64 %ballot, 0
+  call void @llvm.assume(i1 %some)
+  br i1 %cmp, label %foo, label %bar
+
+foo:
+  ret void
+
+bar:
+  ret void
+}
+
+; Test with 32-bit ballot
+define void @test_assume_ballot_uniform_i32(i32 %x) {
+; CHECK-LABEL: @test_assume_ballot_uniform_i32(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br i1 true, label [[FOO:%.*]], label [[BAR:%.*]]
+; CHECK:       foo:
+; CHECK-NEXT:    ret void
+; CHECK:       bar:
+; CHECK-NEXT:    ret void
+;
+entry:
+  %cmp = icmp eq i32 %x, 0
+  %ballot = call i32 @llvm.amdgcn.ballot.i32(i1 %cmp)
+  %all = icmp eq i32 %ballot, -1  
+  call void @llvm.assume(i1 %all)
+  br i1 %cmp, label %foo, label %bar
+
+foo:
+  ret void
+
+bar:
+  ret void
+}
+
+declare i64 @llvm.amdgcn.ballot.i64(i1)
+declare i32 @llvm.amdgcn.ballot.i32(i1)
+declare void @llvm.assume(i1)
    
    
More information about the llvm-commits
mailing list