[Mlir-commits] [mlir] [mlir][gpu] Add gpu.ballot operation to GPU dialect (PR #188647)
Bangtian Liu
llvmlistbot at llvm.org
Wed Mar 25 17:57:26 PDT 2026
https://github.com/bangtianliu updated https://github.com/llvm/llvm-project/pull/188647
>From a3d65882925a3b8dc10792aa03233154c50c6e9f Mon Sep 17 00:00:00 2001
From: Bangtian Liu <liubangtian at gmail.com>
Date: Wed, 25 Mar 2026 17:49:09 -0700
Subject: [PATCH] [mlir][gpu] Add gpu.ballot operation to GPU dialect
Signed-off-by: Bangtian Liu <liubangtian at gmail.com>
---
mlir/include/mlir/Dialect/GPU/IR/GPUOps.td | 23 +++++++
.../GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 28 ++++++++-
.../GPUToROCDL/LowerGpuOpsToROCDLOps.cpp | 15 ++++-
mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 63 ++++++++++++++++++-
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 18 ++++++
.../Conversion/GPUToNVVM/gpu-to-nvvm.mlir | 16 +++++
.../Conversion/GPUToROCDL/gpu-to-rocdl.mlir | 10 +++
.../Conversion/GPUToSPIRV/gpu-to-spirv.mlir | 44 +++++++++++++
.../GPU/broadcast-speculatability.mlir | 18 ++++++
mlir/test/Dialect/GPU/invalid.mlir | 8 +++
mlir/test/Dialect/GPU/ops.mlir | 10 +++
11 files changed, 249 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 36e0875f53b0a..d302c2f3d8fe8 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -3288,4 +3288,27 @@ def GPU_SubgroupBroadcastOp : GPU_Op<"subgroup_broadcast",
let hasVerifier = 1;
}
+def GPU_BallotOp : GPU_Op<"ballot",
+    [NoMemoryEffect,
+     DeclareOpInterfaceMethods<ConditionallySpeculatable, ["getSpeculatability"]>]>,
+    Arguments<(ins I1:$predicate)>,
+    Results<(outs AnyInteger:$result)> {
+  let summary = "Collects predicate values from all threads in a subgroup.";
+  let description = [{
+    The `ballot` op performs a ballot operation across the active threads in a
+    subgroup. Each active thread contributes its predicate value as a single
+    bit: the Nth bit of the integer result is set iff the Nth thread is active
+    and its predicate is true.
+
+    Because the result depends on which threads are active, the op is never
+    speculatable and must not be moved across control flow.
+
+    The result type width must be at least 32 bits (minimum subgroup size).
+
+    Example:
+    ```mlir
+    %0 = gpu.ballot %pred : i32
+    %1 = gpu.ballot %pred : i64
+    ```
+  }];
+  let assemblyFormat = "$predicate attr-dict `:` type($result)";
+  let hasVerifier = 1;
+}
+
#endif // GPU_OPS
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 660b24b071b49..5b71c1ef73735 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -236,6 +236,30 @@ struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
}
};
+/// Lowers gpu.ballot to an nvvm.vote.sync in ballot mode. The vote always
+/// yields an i32; wider result types are obtained by zero-extension.
+struct GPUBallotOpToNVVM : ConvertOpToLLVMPattern<gpu::BallotOp> {
+  using ConvertOpToLLVMPattern<gpu::BallotOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::BallotOp op, gpu::BallotOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    const Location loc = op.getLoc();
+    auto voteType = rewriter.getI32Type();
+
+    // An all-ones member mask makes all 32 lanes participate in the vote.
+    Value fullMask = LLVM::ConstantOp::create(rewriter, loc, voteType, -1);
+    Value vote = NVVM::VoteSyncOp::create(rewriter, loc, voteType, fullMask,
+                                          adaptor.getPredicate(),
+                                          NVVM::VoteSyncKind::ballot);
+
+    // An i32 result is the vote itself.
+    if (op.getType() == voteType) {
+      rewriter.replaceOp(op, vote);
+      return success();
+    }
+
+    // Any wider result (e.g. i64) is the vote zero-extended to that width.
+    rewriter.replaceOpWithNewOp<LLVM::ZExtOp>(op, op.getType(), vote);
+    return success();
+  }
+};
+
/// Lowering of cf.assert into a conditional __assertfail.
struct AssertOpToAssertfailLowering
: public ConvertOpToLLVMPattern<cf::AssertOp> {
@@ -504,8 +528,8 @@ void mlir::populateGpuToNVVMConversionPatterns(
patterns.add<gpu::index_lowering::OpLowering<
gpu::GridDimOp, NVVM::GridDimXOp, NVVM::GridDimYOp, NVVM::GridDimZOp>>(
converter, IndexKind::Grid, IntrType::Dim, benefit);
- patterns.add<GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering>(
- converter, benefit);
+ patterns.add<GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUBallotOpToNVVM,
+ GPUReturnOpLowering>(converter, benefit);
patterns.add<GPUDynamicSharedMemoryOpLowering>(
converter, NVVM::kSharedMemoryAlignmentBit, benefit);
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index e08ec138c853a..bb4ea4c3edfbe 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -417,6 +417,19 @@ struct GPUSubgroupBroadcastOpToROCDL
}
};
+/// Lowers gpu.ballot to rocdl.ballot, which produces the ballot directly in
+/// the requested i32 or i64 result type.
+struct GPUBallotOpToROCDL : public ConvertOpToLLVMPattern<gpu::BallotOp> {
+  using ConvertOpToLLVMPattern<gpu::BallotOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::BallotOp op, gpu::BallotOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // rocdl.ballot directly supports i32 and i64 result types only, while the
+    // gpu.ballot verifier admits any integer of at least 32 bits. Reject the
+    // other widths here rather than emitting an unsupported rocdl.ballot.
+    unsigned width = cast<IntegerType>(op.getType()).getWidth();
+    if (width != 32 && width != 64)
+      return rewriter.notifyMatchFailure(
+          op, "unsupported result type width for ballot");
+
+    rewriter.replaceOpWithNewOp<ROCDL::BallotOp>(op, op.getType(),
+                                                 adaptor.getPredicate());
+    return success();
+  }
+};
+
struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;
@@ -764,7 +777,7 @@ void mlir::populateGpuToROCDLConversionPatterns(
patterns.add<GPUDynamicSharedMemoryOpLowering>(converter);
patterns.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL,
- GPUSubgroupBroadcastOpToROCDL>(converter);
+ GPUSubgroupBroadcastOpToROCDL, GPUBallotOpToROCDL>(converter);
patterns.add<GPUSubgroupIdOpToROCDL, GPUSubgroupSizeOpToROCDL,
GPUBarrierOpLowering>(converter, chipset);
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index c33a903d03393..f3768decbe695 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -139,6 +139,16 @@ class GPUPrintfConversion final : public OpConversionPattern<gpu::PrintfOp> {
ConversionPatternRewriter &rewriter) const override;
};
+/// Lowers gpu.ballot to spirv.GroupNonUniformBallot for i32 and i64 results.
+class GPUBallotConversion final : public OpConversionPattern<gpu::BallotOp> {
+public:
+ using Base::Base;
+
+ LogicalResult
+ matchAndRewrite(gpu::BallotOp ballotOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -504,6 +514,56 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
return success();
}
+//===----------------------------------------------------------------------===//
+// Ballot
+//===----------------------------------------------------------------------===//
+
+LogicalResult GPUBallotConversion::matchAndRewrite(
+    gpu::BallotOp ballotOp, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  // Validate the result width before creating any op: a conversion pattern
+  // should leave the IR untouched when it reports a match failure.
+  unsigned width = cast<IntegerType>(ballotOp.getType()).getWidth();
+  if (width != 32 && width != 64)
+    return rewriter.notifyMatchFailure(
+        ballotOp, "unsupported result type width for ballot");
+
+  Location loc = ballotOp.getLoc();
+  auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
+  auto i32Type = rewriter.getI32Type();
+  auto vec4i32Type = VectorType::get({4}, i32Type);
+
+  // SPIR-V GroupNonUniformBallot returns vector<4xi32>.
+  Value ballot = spirv::GroupNonUniformBallotOp::create(
+      rewriter, loc, vec4i32Type, scope, adaptor.getPredicate());
+
+  // The first element is the low 32-bit part; it is the whole i32 result.
+  Value low = spirv::CompositeExtractOp::create(rewriter, loc, ballot,
+                                                ArrayRef<int32_t>{0});
+  if (width == 32) {
+    rewriter.replaceOp(ballotOp, low);
+    return success();
+  }
+
+  // i64 result: combine the first two elements as (high << 32) | low.
+  Value high = spirv::CompositeExtractOp::create(rewriter, loc, ballot,
+                                                 ArrayRef<int32_t>{1});
+  auto i64Type = rewriter.getI64Type();
+  // Zero-extend low and high to i64.
+  Value lowExt = spirv::UConvertOp::create(rewriter, loc, i64Type, low);
+  Value highExt = spirv::UConvertOp::create(rewriter, loc, i64Type, high);
+  // Shift high left by 32 bits.
+  Value shift = spirv::ConstantOp::create(rewriter, loc, i64Type,
+                                          rewriter.getI64IntegerAttr(32));
+  Value highShifted =
+      spirv::ShiftLeftLogicalOp::create(rewriter, loc, highExt, shift);
+  // Combine with bitwise OR.
+  Value result =
+      spirv::BitwiseOrOp::create(rewriter, loc, lowExt, highShifted);
+  rewriter.replaceOp(ballotOp, result);
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// Rotate
//===----------------------------------------------------------------------===//
@@ -831,7 +891,8 @@ void mlir::populateGPUToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns) {
patterns.add<
GPUBarrierConversion, GPUFuncOpConversion, GPUModuleConversion,
- GPUReturnOpConversion, GPUShuffleConversion, GPURotateConversion,
+ GPUReturnOpConversion, GPUShuffleConversion, GPUBallotConversion,
+ GPURotateConversion,
LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 5d409f71847c6..06ea48428e914 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -2619,6 +2619,24 @@ OpFoldResult gpu::SubgroupBroadcastOp::fold(FoldAdaptor /*adaptor*/) {
return nullptr;
}
+//===----------------------------------------------------------------------===//
+// GPU_BallotOp
+//===----------------------------------------------------------------------===//
+
+Speculation::Speculatability gpu::BallotOp::getSpeculatability() {
+ // Result depends on the active lanes; never speculate across control flow.
+ return Speculation::NotSpeculatable;
+}
+
+LogicalResult gpu::BallotOp::verify() {
+  // The ODS `AnyInteger` result constraint is verified before this custom
+  // verifier runs, so the result is known to be an integer here; only the
+  // minimum width remains to be checked.
+  auto intType = cast<IntegerType>(getResult().getType());
+  if (intType.getWidth() < 32)
+    return emitOpError("result type must have at least 32 bits");
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// GPU KernelMetadataAttr
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index 929794f035b9f..a0973ab98932e 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -1197,3 +1197,19 @@ gpu.module @test_module_cluster_block_ops {
}
}
+// -----
+
+gpu.module @test_module_ballot {
+ // CHECK-LABEL: func @ballot
+ func.func @ballot(%pred: i1) -> (i32, i64) {
+ // CHECK: %[[BALLOT_MASK:.*]] = llvm.mlir.constant(-1 : i32) : i32
+ // CHECK: %[[BALLOT32:.*]] = nvvm.vote.sync ballot %[[BALLOT_MASK]], %{{.*}} -> i32
+ %0 = gpu.ballot %pred : i32
+ // CHECK: %[[BALLOT_MASK2:.*]] = llvm.mlir.constant(-1 : i32) : i32
+ // CHECK: %[[BALLOT:.*]] = nvvm.vote.sync ballot %[[BALLOT_MASK2]], %{{.*}} -> i32
+ // CHECK: %[[BALLOT64:.*]] = llvm.zext %[[BALLOT]] : i32 to i64
+ %1 = gpu.ballot %pred : i64
+ func.return %0, %1 : i32, i64
+ }
+}
+
diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
index 5eaa2d0b4df28..68a5328b8eb77 100755
--- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
@@ -989,4 +989,14 @@ func.func @broadcast_3xi16(%arg0 : vector<3xi16>) -> vector<3xi16> {
%0 = gpu.subgroup_broadcast %arg0, first_active_lane : vector<3xi16>
func.return %0 : vector<3xi16>
}
+
+// CHECK-LABEL: func @ballot
+// CHECK-SAME: (%[[PRED:.*]]: i1)
+func.func @ballot(%pred: i1) -> (i32, i64) {
+ // CHECK: rocdl.ballot %[[PRED]] : i32
+ %0 = gpu.ballot %pred : i32
+ // CHECK: rocdl.ballot %[[PRED]] : i64
+ %1 = gpu.ballot %pred : i64
+ func.return %0, %1 : i32, i64
+}
}
diff --git a/mlir/test/Conversion/GPUToSPIRV/gpu-to-spirv.mlir b/mlir/test/Conversion/GPUToSPIRV/gpu-to-spirv.mlir
index 7bf6f8419be0d..9991a4ef69984 100644
--- a/mlir/test/Conversion/GPUToSPIRV/gpu-to-spirv.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/gpu-to-spirv.mlir
@@ -128,3 +128,47 @@ module attributes {gpu.container_module} {
return
}
}
+
+// -----
+
+module attributes {
+ gpu.container_module,
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Shader, GroupNonUniformBallot], []>, #spirv.resource_limits<subgroup_size = 32>>
+} {
+ gpu.module @kernels {
+ // CHECK-LABEL: spirv.func @ballot_i32
+ // CHECK-SAME: (%[[PRED:.*]]: i1
+ gpu.func @ballot_i32(%pred: i1) kernel
+ attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 1, 1]>} {
+ // CHECK: %[[BALLOT:.*]] = spirv.GroupNonUniformBallot <Subgroup> %[[PRED]] : vector<4xi32>
+ // CHECK: spirv.CompositeExtract %[[BALLOT]][0 : i32] : vector<4xi32>
+ %0 = gpu.ballot %pred : i32
+ gpu.return
+ }
+ }
+}
+
+// -----
+
+module attributes {
+ gpu.container_module,
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Shader, GroupNonUniformBallot, Int64], []>, #spirv.resource_limits<subgroup_size = 64>>
+} {
+ gpu.module @kernels {
+ // CHECK-LABEL: spirv.func @ballot_i64
+ // CHECK-SAME: (%[[PRED:.*]]: i1
+ gpu.func @ballot_i64(%pred: i1) kernel
+ attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [64, 1, 1]>} {
+ // CHECK: %[[BALLOT:.*]] = spirv.GroupNonUniformBallot <Subgroup> %[[PRED]] : vector<4xi32>
+ // CHECK: %[[LOW:.*]] = spirv.CompositeExtract %[[BALLOT]][0 : i32] : vector<4xi32>
+ // CHECK: %[[HIGH:.*]] = spirv.CompositeExtract %[[BALLOT]][1 : i32] : vector<4xi32>
+ // CHECK: %[[LOW_EXT:.*]] = spirv.UConvert %[[LOW]] : i32 to i64
+ // CHECK: %[[HIGH_EXT:.*]] = spirv.UConvert %[[HIGH]] : i32 to i64
+ // CHECK: spirv.Constant 32 : i64
+ // CHECK: spirv.ShiftLeftLogical %[[HIGH_EXT]],
+ // CHECK: spirv.BitwiseOr %[[LOW_EXT]],
+ %0 = gpu.ballot %pred : i64
+ gpu.return
+ }
+ }
+}
diff --git a/mlir/test/Dialect/GPU/broadcast-speculatability.mlir b/mlir/test/Dialect/GPU/broadcast-speculatability.mlir
index 3cf4853effee5..7062ce4e6f46f 100644
--- a/mlir/test/Dialect/GPU/broadcast-speculatability.mlir
+++ b/mlir/test/Dialect/GPU/broadcast-speculatability.mlir
@@ -1,6 +1,7 @@
// RUN: mlir-opt %s --loop-invariant-code-motion | FileCheck %s
func.func private @side_effect(%arg0 : f32, %arg1 : f32)
+func.func private @use_i32(%arg0 : i32)
// CHECK-LABEL: func @broadcast_hoisting
// CHECK-SAME: (%[[ARG:.*]]: f32, %[[IDX:.*]]: i32, {{.*}}: index)
@@ -20,3 +21,20 @@ func.func @broadcast_hoisting(%arg0 : f32, %arg1 : i32, %arg2 : index) {
}
func.return
}
+
+// CHECK-LABEL: func @ballot_no_hoisting
+// CHECK-SAME: (%[[PRED:.*]]: i1, {{.*}}: index)
+func.func @ballot_no_hoisting(%pred: i1, %n: index) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ // Ballot cannot be speculated across control flow because
+ // it depends on active lanes, which can change.
+ // CHECK: scf.for
+ // CHECK: %[[BALLOT:.*]] = gpu.ballot %[[PRED]] : i32
+ // CHECK: func.call @use_i32(%[[BALLOT]])
+ scf.for %i = %c0 to %n step %c1 {
+ %0 = gpu.ballot %pred : i32
+ func.call @use_i32(%0) : (i32) -> ()
+ }
+ func.return
+}
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index bf862b2c5ae3c..67bb01b13784b 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -1119,3 +1119,11 @@ func.func @warp_execute_wrong_terminator() {
}
return
}
+
+// -----
+
+func.func @ballot_invalid_type(%pred: i1) {
+ // expected-error @+1 {{'gpu.ballot' op result type must have at least 32 bits}}
+ %0 = gpu.ballot %pred : i16
+ return
+}
diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
index 1d05268ed4475..a5dad3f931cc1 100644
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -586,3 +586,13 @@ func.func @subgroup_broadcast(%arg0 : f32, %arg1 : i32) -> (f32, f32) {
%1 = gpu.subgroup_broadcast %arg0, specific_lane %arg1 : f32
func.return %0, %1 : f32, f32
}
+
+// CHECK-LABEL: func @ballot
+// CHECK-SAME: (%[[PRED:.*]]: i1)
+func.func @ballot(%pred: i1) -> (i32, i64) {
+ // CHECK: gpu.ballot %[[PRED]] : i32
+ %0 = gpu.ballot %pred : i32
+ // CHECK: gpu.ballot %[[PRED]] : i64
+ %1 = gpu.ballot %pred : i64
+ func.return %0, %1 : i32, i64
+}
More information about the Mlir-commits
mailing list