[Mlir-commits] [mlir] [mlir][gpu] Add gpu.ballot operation to GPU dialect (PR #188647)

Bangtian Liu llvmlistbot at llvm.org
Thu Mar 26 08:13:22 PDT 2026


https://github.com/bangtianliu updated https://github.com/llvm/llvm-project/pull/188647

>From c413203cbc64528b3883eb3f6542a5204c6d470c Mon Sep 17 00:00:00 2001
From: Bangtian Liu <liubangtian at gmail.com>
Date: Wed, 25 Mar 2026 17:49:09 -0700
Subject: [PATCH] [mlir][gpu] Add gpu.ballot operation to GPU dialect

Signed-off-by: Bangtian Liu <liubangtian at gmail.com>
---
 mlir/include/mlir/Dialect/GPU/IR/GPUOps.td    | 23 +++++++++++++++++++
 .../GPUToROCDL/LowerGpuOpsToROCDLOps.cpp      | 15 +++++++++++-
 mlir/lib/Dialect/GPU/IR/GPUDialect.cpp        | 18 +++++++++++++++
 .../Conversion/GPUToROCDL/gpu-to-rocdl.mlir   | 10 ++++++++
 .../GPU/broadcast-speculatability.mlir        | 18 +++++++++++++++
 mlir/test/Dialect/GPU/invalid.mlir            |  8 +++++++
 mlir/test/Dialect/GPU/ops.mlir                | 10 ++++++++
 7 files changed, 101 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 36e0875f53b0a..d302c2f3d8fe8 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -3288,4 +3288,27 @@ def GPU_SubgroupBroadcastOp : GPU_Op<"subgroup_broadcast",
   let hasVerifier = 1;
 }
 
+def GPU_BallotOp : GPU_Op<"ballot",
+    [NoMemoryEffect,
+     DeclareOpInterfaceMethods<ConditionallySpeculatable, ["getSpeculatability"]>]>,
+    Arguments<(ins I1:$predicate)>,
+    Results<(outs AnySignlessInteger:$result)> {
+  let summary = "Collects predicate values from the active lanes of a subgroup.";
+  let description = [{
+    The `ballot` op performs a ballot across the active lanes of a subgroup.
+    Each active lane contributes its predicate as a single bit: bit N of the
+    result is set iff lane N is active and its predicate is true.
+
+    The result must be a signless integer with at least 32 bits.
+
+    Example:
+    ```mlir
+    %0 = gpu.ballot %pred : i32
+    %1 = gpu.ballot %pred : i64
+    ```
+  }];
+  let assemblyFormat = "$predicate attr-dict `:` type($result)";
+  let hasVerifier = 1;
+}
+
 #endif // GPU_OPS
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index e08ec138c853a..bb4ea4c3edfbe 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -417,6 +417,30 @@ struct GPUSubgroupBroadcastOpToROCDL
   }
 };

+/// Lowers `gpu.ballot` to `rocdl.ballot`.
+///
+/// rocdl.ballot is only expected to handle i32 and i64 results, whereas the
+/// gpu.ballot verifier accepts any integer of at least 32 bits, so other
+/// widths are reported as a match failure instead of being passed through.
+struct GPUBallotOpToROCDL : public ConvertOpToLLVMPattern<gpu::BallotOp> {
+  using ConvertOpToLLVMPattern<gpu::BallotOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::BallotOp op, gpu::BallotOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Convert the result type rather than reusing it verbatim so the new op
+    // always carries an LLVM-compatible type.
+    Type resultType = getTypeConverter()->convertType(op.getType());
+    if (!resultType || !(resultType.isSignlessInteger(32) ||
+                         resultType.isSignlessInteger(64)))
+      return rewriter.notifyMatchFailure(
+          op, "rocdl.ballot only supports i32 and i64 results");
+    rewriter.replaceOpWithNewOp<ROCDL::BallotOp>(op, resultType,
+                                                 adaptor.getPredicate());
+    return success();
+  }
+};
+
 struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
   using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;

@@ -764,7 +788,7 @@ void mlir::populateGpuToROCDLConversionPatterns(
   patterns.add<GPUDynamicSharedMemoryOpLowering>(converter);

   patterns.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL,
-               GPUSubgroupBroadcastOpToROCDL>(converter);
+               GPUSubgroupBroadcastOpToROCDL, GPUBallotOpToROCDL>(converter);
   patterns.add<GPUSubgroupIdOpToROCDL, GPUSubgroupSizeOpToROCDL,
                 GPUBarrierOpLowering>(converter, chipset);

diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 5d409f71847c6..06ea48428e914 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -2619,6 +2619,24 @@ OpFoldResult gpu::SubgroupBroadcastOp::fold(FoldAdaptor /*adaptor*/) {
   return nullptr;
 }
 
+//===----------------------------------------------------------------------===//
+// GPU_BallotOp
+//===----------------------------------------------------------------------===//
+
+Speculation::Speculatability gpu::BallotOp::getSpeculatability() {
+  // Hoisting may change the set of active lanes and hence the ballot result.
+  return Speculation::Speculatability::NotSpeculatable;
+}
+
+LogicalResult gpu::BallotOp::verify() {
+  // The ODS result constraint already guarantees an IntegerType, so only the
+  // minimum width has to be checked here.
+  unsigned width = cast<IntegerType>(getResult().getType()).getWidth();
+  if (width < 32)
+    return emitOpError("result type must have at least 32 bits");
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // GPU KernelMetadataAttr
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
index 5eaa2d0b4df28..68a5328b8eb77 100755
--- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
@@ -989,4 +989,14 @@ func.func @broadcast_3xi16(%arg0 : vector<3xi16>) -> vector<3xi16> {
   %0 = gpu.subgroup_broadcast %arg0, first_active_lane : vector<3xi16>
   func.return %0 : vector<3xi16>
 }
+
+// CHECK-LABEL: func @ballot
+//  CHECK-SAME: (%[[P:.*]]: i1)
+func.func @ballot(%p: i1) -> (i32, i64) {
+  // CHECK: rocdl.ballot %[[P]] : i32
+  %narrow = gpu.ballot %p : i32
+  // CHECK: rocdl.ballot %[[P]] : i64
+  %wide = gpu.ballot %p : i64
+  func.return %narrow, %wide : i32, i64
+}
 }
diff --git a/mlir/test/Dialect/GPU/broadcast-speculatability.mlir b/mlir/test/Dialect/GPU/broadcast-speculatability.mlir
index 3cf4853effee5..7062ce4e6f46f 100644
--- a/mlir/test/Dialect/GPU/broadcast-speculatability.mlir
+++ b/mlir/test/Dialect/GPU/broadcast-speculatability.mlir
@@ -1,6 +1,7 @@
 // RUN: mlir-opt %s --loop-invariant-code-motion | FileCheck %s
 
 func.func private @side_effect(%arg0 : f32, %arg1 : f32)
+func.func private @use_i32(%arg0 : i32)
 
 // CHECK-LABEL: func @broadcast_hoisting
 //  CHECK-SAME: (%[[ARG:.*]]: f32, %[[IDX:.*]]: i32, {{.*}}: index)
@@ -20,3 +21,20 @@ func.func @broadcast_hoisting(%arg0 : f32, %arg1 : i32, %arg2 : index) {
   }
   func.return
 }
+
+// CHECK-LABEL: func @ballot_no_hoisting
+//  CHECK-SAME: (%[[PRED:.*]]: i1, {{.*}}: index)
+func.func @ballot_no_hoisting(%pred: i1, %n: index) {
+  %lb = arith.constant 0 : index
+  %step = arith.constant 1 : index
+  // The ballot is not speculatable: the set of active lanes may differ at a
+  // hoisted location, so LICM must leave it inside the loop body.
+  // CHECK: scf.for
+  // CHECK: %[[BALLOT:.*]] = gpu.ballot %[[PRED]] : i32
+  // CHECK: func.call @use_i32(%[[BALLOT]])
+  scf.for %iv = %lb to %n step %step {
+    %mask = gpu.ballot %pred : i32
+    func.call @use_i32(%mask) : (i32) -> ()
+  }
+  func.return
+}
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index bf862b2c5ae3c..67bb01b13784b 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -1119,3 +1119,11 @@ func.func @warp_execute_wrong_terminator() {
   }
   return
 }
+
+// -----
+
+func.func @ballot_invalid_type(%p: i1) {
+  // expected-error @+1 {{'gpu.ballot' op result type must have at least 32 bits}}
+  %narrow = gpu.ballot %p : i16
+  func.return
+}
diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
index 1d05268ed4475..a5dad3f931cc1 100644
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -586,3 +586,13 @@ func.func @subgroup_broadcast(%arg0 : f32, %arg1 : i32) -> (f32, f32) {
   %1 = gpu.subgroup_broadcast %arg0, specific_lane %arg1 : f32
   func.return %0, %1 : f32, f32
 }
+
+// CHECK-LABEL: func @ballot
+//  CHECK-SAME: (%[[P:.*]]: i1)
+func.func @ballot(%p: i1) -> (i32, i64) {
+  // CHECK: gpu.ballot %[[P]] : i32
+  %narrow = gpu.ballot %p : i32
+  // CHECK: gpu.ballot %[[P]] : i64
+  %wide = gpu.ballot %p : i64
+  func.return %narrow, %wide : i32, i64
+}



More information about the Mlir-commits mailing list