[Mlir-commits] [mlir] [mlir][gpu] Add gpu.ballot operation to GPU dialect (PR #188647)

Bangtian Liu llvmlistbot at llvm.org
Wed Mar 25 17:51:24 PDT 2026


https://github.com/bangtianliu updated https://github.com/llvm/llvm-project/pull/188647

>From 278f7188f30c6d2651f0b545025676027a19c2b1 Mon Sep 17 00:00:00 2001
From: Bangtian Liu <liubangtian at gmail.com>
Date: Wed, 25 Mar 2026 17:49:09 -0700
Subject: [PATCH] [mlir][gpu] Add gpu.ballot operation to GPU dialect

Signed-off-by: Bangtian Liu <liubangtian at gmail.com>
---
 mlir/include/mlir/Dialect/GPU/IR/GPUOps.td    | 23 +++++++
 .../GPUToNVVM/LowerGpuOpsToNVVMOps.cpp        | 28 ++++++++-
 .../GPUToROCDL/LowerGpuOpsToROCDLOps.cpp      | 15 ++++-
 mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 63 ++++++++++++++++++-
 mlir/lib/Dialect/GPU/IR/GPUDialect.cpp        | 18 ++++++
 .../Conversion/GPUToNVVM/gpu-to-nvvm.mlir     | 16 +++++
 .../Conversion/GPUToROCDL/gpu-to-rocdl.mlir   | 10 +++
 .../Conversion/GPUToSPIRV/gpu-to-spirv.mlir   | 44 +++++++++++++
 .../GPU/broadcast-speculatability.mlir        | 18 ++++++
 mlir/test/Dialect/GPU/invalid.mlir            |  8 +++
 mlir/test/Dialect/GPU/ops.mlir                | 10 +++
 11 files changed, 249 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 36e0875f53b0a..d302c2f3d8fe8 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -3288,4 +3288,27 @@ def GPU_SubgroupBroadcastOp : GPU_Op<"subgroup_broadcast",
   let hasVerifier = 1;
 }
 
+def GPU_BallotOp : GPU_Op<"ballot",
+    [NoMemoryEffect,
+     DeclareOpInterfaceMethods<ConditionallySpeculatable, ["getSpeculatability"]>]>,
+    Arguments<(ins I1:$predicate)>,
+    Results<(outs AnySignlessInteger:$result)> {
+  let summary = "Gathers predicate bits from active threads in a subgroup.";
+  let description = [{
+    The `ballot` op performs a ballot operation across the active threads of a
+    subgroup. Each active thread contributes its `predicate` as a single bit:
+    bit N of the result is set iff the thread in lane N is active and its
+    predicate is true. Bits of inactive lanes are zero.
+
+    The result must be a signless integer of at least 32 bits. It should be
+    wide enough to hold one bit per subgroup lane (e.g. `i64` for 64-wide
+    subgroups); lanes whose index is not below the result width may not be
+    represented. Because the result depends on which lanes are active, the op
+    is not speculatable.
+
+    Example:
+    ```mlir
+    %0 = gpu.ballot %pred : i32
+    %1 = gpu.ballot %pred : i64
+    ```
+  }];
+  let assemblyFormat = "$predicate attr-dict `:` type($result)";
+  let hasVerifier = 1;
+}
+
 #endif // GPU_OPS
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 660b24b071b49..5b71c1ef73735 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -236,6 +236,30 @@ struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
   }
 };
 
+/// Lowers `gpu.ballot` to `nvvm.vote.sync ballot`.
+///
+/// The vote always yields a 32-bit lane mask; wider result types are produced
+/// by zero-extending it. The member mask is all ones, so every lane of the
+/// warp is expected to reach this op (see the PTX ISA `vote.sync` membermask
+/// requirements).
+struct GPUBallotOpToNVVM : ConvertOpToLLVMPattern<gpu::BallotOp> {
+  using ConvertOpToLLVMPattern<gpu::BallotOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::BallotOp op, gpu::BallotOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Use the converted (signless, LLVM-compatible) result type rather than
+    // the original op type so the emitted `llvm.zext` is always well-typed.
+    // This is checked before any IR is created.
+    Type resultType = getTypeConverter()->convertType(op.getType());
+    if (!resultType)
+      return rewriter.notifyMatchFailure(op, "failed to convert result type");
+
+    Location loc = op.getLoc();
+    auto i32Type = rewriter.getI32Type();
+
+    // Full mask: all 32 lanes of the warp participate.
+    Value mask = LLVM::ConstantOp::create(rewriter, loc, i32Type, -1);
+    Value result = NVVM::VoteSyncOp::create(rewriter, loc, i32Type, mask,
+                                            adaptor.getPredicate(),
+                                            NVVM::VoteSyncKind::ballot);
+
+    // Widen the 32-bit lane mask to the requested width (e.g. i64); the
+    // extra high bits are zero.
+    if (resultType != i32Type)
+      result = LLVM::ZExtOp::create(rewriter, loc, resultType, result);
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
 /// Lowering of cf.assert into a conditional __assertfail.
 struct AssertOpToAssertfailLowering
     : public ConvertOpToLLVMPattern<cf::AssertOp> {
@@ -504,8 +528,8 @@ void mlir::populateGpuToNVVMConversionPatterns(
   patterns.add<gpu::index_lowering::OpLowering<
       gpu::GridDimOp, NVVM::GridDimXOp, NVVM::GridDimYOp, NVVM::GridDimZOp>>(
       converter, IndexKind::Grid, IntrType::Dim, benefit);
-  patterns.add<GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering>(
-      converter, benefit);
+  patterns.add<GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUBallotOpToNVVM,
+               GPUReturnOpLowering>(converter, benefit);
 
   patterns.add<GPUDynamicSharedMemoryOpLowering>(
       converter, NVVM::kSharedMemoryAlignmentBit, benefit);
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index e08ec138c853a..bb4ea4c3edfbe 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -417,6 +417,19 @@ struct GPUSubgroupBroadcastOpToROCDL
   }
 };
 
+/// Lowers `gpu.ballot` to `rocdl.ballot` (the overloaded `llvm.amdgcn.ballot`
+/// intrinsic). Only i32 and i64 results are lowered, matching the wave32 and
+/// wave64 lane-mask widths; other widths are left for another pattern.
+struct GPUBallotOpToROCDL : public ConvertOpToLLVMPattern<gpu::BallotOp> {
+  using ConvertOpToLLVMPattern<gpu::BallotOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::BallotOp op, gpu::BallotOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // The verifier accepts any width >= 32, but the lane mask produced by the
+    // intrinsic is only i32 or i64; reject anything else up front.
+    Type resultType = getTypeConverter()->convertType(op.getType());
+    if (!resultType || !(resultType.isInteger(32) || resultType.isInteger(64)))
+      return rewriter.notifyMatchFailure(
+          op, "expected an i32 or i64 result for rocdl.ballot");
+
+    rewriter.replaceOpWithNewOp<ROCDL::BallotOp>(op, resultType,
+                                                 adaptor.getPredicate());
+    return success();
+  }
+};
+
 struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
   using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;
 
@@ -764,7 +777,7 @@ void mlir::populateGpuToROCDLConversionPatterns(
   patterns.add<GPUDynamicSharedMemoryOpLowering>(converter);
 
   patterns.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL,
-               GPUSubgroupBroadcastOpToROCDL>(converter);
+               GPUSubgroupBroadcastOpToROCDL, GPUBallotOpToROCDL>(converter);
   patterns.add<GPUSubgroupIdOpToROCDL, GPUSubgroupSizeOpToROCDL,
                GPUBarrierOpLowering>(converter, chipset);
 
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index c33a903d03393..ab7a0c7752fbe 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -139,6 +139,16 @@ class GPUPrintfConversion final : public OpConversionPattern<gpu::PrintfOp> {
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+/// Lowers `gpu.ballot` to `spirv.GroupNonUniformBallot`, then extracts the
+/// lane-mask word(s) needed for the requested integer result.
+class GPUBallotConversion final : public OpConversionPattern<gpu::BallotOp> {
+public:
+  using OpConversionPattern<gpu::BallotOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::BallotOp ballotOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -504,6 +514,56 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Ballot
+//===----------------------------------------------------------------------===//
+
+/// Lowers `gpu.ballot` to `spirv.GroupNonUniformBallot`.
+///
+/// The SPIR-V op returns a `vector<4xi32>` where lane N is bit (N % 32) of
+/// component (N / 32). An i32 result takes component 0; an i64 result packs
+/// components 0 and 1 as `(high << 32) | low`. Other widths are not lowered.
+LogicalResult GPUBallotConversion::matchAndRewrite(
+    gpu::BallotOp ballotOp, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  // Validate everything before creating IR: a conversion pattern must not
+  // modify IR and then report failure.
+  auto intType = cast<IntegerType>(ballotOp.getType());
+  unsigned width = intType.getWidth();
+  if (!intType.isSignless() || (width != 32 && width != 64))
+    return rewriter.notifyMatchFailure(
+        ballotOp, "unsupported result type width for ballot");
+  // The result type must be natively supported by the target (e.g. i64 needs
+  // the Int64 capability); otherwise the ops built below would be illegal.
+  const TypeConverter *converter = getTypeConverter();
+  if (converter && converter->convertType(intType) != intType)
+    return rewriter.notifyMatchFailure(
+        ballotOp, "ballot result type is not supported by the target");
+
+  Location loc = ballotOp.getLoc();
+  auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
+  auto i32Type = rewriter.getI32Type();
+  auto vec4i32Type = VectorType::get({4}, i32Type);
+
+  // SPIR-V GroupNonUniformBallot returns vector<4xi32>.
+  Value ballot = spirv::GroupNonUniformBallotOp::create(
+      rewriter, loc, vec4i32Type, scope, adaptor.getPredicate());
+
+  // Lanes 0-31 live in component 0; this is the whole answer for i32.
+  Value low = spirv::CompositeExtractOp::create(rewriter, loc, ballot,
+                                                ArrayRef<int32_t>{0});
+  if (width == 32) {
+    rewriter.replaceOp(ballotOp, low);
+    return success();
+  }
+
+  // i64: lanes 32-63 live in component 1. Zero-extend both halves and
+  // combine them as (high << 32) | low.
+  Value high = spirv::CompositeExtractOp::create(rewriter, loc, ballot,
+                                                 ArrayRef<int32_t>{1});
+  auto i64Type = rewriter.getI64Type();
+  Value lowExt = spirv::UConvertOp::create(rewriter, loc, i64Type, low);
+  Value highExt = spirv::UConvertOp::create(rewriter, loc, i64Type, high);
+  Value shift = spirv::ConstantOp::create(rewriter, loc, i64Type,
+                                          rewriter.getI64IntegerAttr(32));
+  Value highShifted =
+      spirv::ShiftLeftLogicalOp::create(rewriter, loc, highExt, shift);
+  Value result =
+      spirv::BitwiseOrOp::create(rewriter, loc, lowExt, highShifted);
+  rewriter.replaceOp(ballotOp, result);
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Rotate
 //===----------------------------------------------------------------------===//
@@ -831,7 +891,8 @@ void mlir::populateGPUToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
                                       RewritePatternSet &patterns) {
   patterns.add<
       GPUBarrierConversion, GPUFuncOpConversion, GPUModuleConversion,
-      GPUReturnOpConversion, GPUShuffleConversion, GPURotateConversion,
+      GPUReturnOpConversion, GPUShuffleConversion, GPUBallotConversion,
+      GPURotateConversion,
       LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
       LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
       LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 5d409f71847c6..06ea48428e914 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -2619,6 +2619,24 @@ OpFoldResult gpu::SubgroupBroadcastOp::fold(FoldAdaptor /*adaptor*/) {
   return nullptr;
 }
 
+//===----------------------------------------------------------------------===//
+// GPU_BallotOp
+//===----------------------------------------------------------------------===//
+
+/// A ballot observes which lanes of the subgroup are active, so it is never
+/// safe to hoist or otherwise speculate it across control flow.
+Speculation::Speculatability gpu::BallotOp::getSpeculatability() {
+  return Speculation::Speculatability::NotSpeculatable;
+}
+
+/// Checks that the ballot result is an integer of at least 32 bits.
+LogicalResult gpu::BallotOp::verify() {
+  IntegerType resultTy = dyn_cast<IntegerType>(getResult().getType());
+  if (!resultTy)
+    return emitOpError("result must be an integer type");
+
+  // Narrower lane masks are rejected.
+  if (resultTy.getWidth() >= 32)
+    return success();
+  return emitOpError("result type must have at least 32 bits");
+}
+
 //===----------------------------------------------------------------------===//
 // GPU KernelMetadataAttr
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index 929794f035b9f..a0973ab98932e 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -1197,3 +1197,19 @@ gpu.module @test_module_cluster_block_ops {
   }
 }
 
+// -----
+
+// gpu.ballot lowers to nvvm.vote.sync ballot with an all-ones (-1) i32 member
+// mask; an i64 result is obtained by zero-extending the 32-bit vote.
+gpu.module @test_module_ballot {
+  // CHECK-LABEL: func @ballot
+  func.func @ballot(%pred: i1) -> (i32, i64) {
+    // CHECK: %[[BALLOT_MASK:.*]] = llvm.mlir.constant(-1 : i32) : i32
+    // CHECK: %[[BALLOT32:.*]] = nvvm.vote.sync ballot %[[BALLOT_MASK]], %{{.*}} -> i32
+    %0 = gpu.ballot %pred : i32
+    // CHECK: %[[BALLOT_MASK2:.*]] = llvm.mlir.constant(-1 : i32) : i32
+    // CHECK: %[[BALLOT:.*]] = nvvm.vote.sync ballot %[[BALLOT_MASK2]], %{{.*}} -> i32
+    // CHECK: %[[BALLOT64:.*]] = llvm.zext %[[BALLOT]] : i32 to i64
+    %1 = gpu.ballot %pred : i64
+    func.return %0, %1 : i32, i64
+  }
+}
+
diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
index 5eaa2d0b4df28..68a5328b8eb77 100755
--- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
@@ -989,4 +989,14 @@ func.func @broadcast_3xi16(%arg0 : vector<3xi16>) -> vector<3xi16> {
   %0 = gpu.subgroup_broadcast %arg0, first_active_lane : vector<3xi16>
   func.return %0 : vector<3xi16>
 }
+
+// gpu.ballot maps directly onto rocdl.ballot for both i32 and i64 results.
+// CHECK-LABEL: func @ballot
+//  CHECK-SAME: (%[[PRED:.*]]: i1)
+func.func @ballot(%pred: i1) -> (i32, i64) {
+  // CHECK: rocdl.ballot %[[PRED]] : i32
+  %0 = gpu.ballot %pred : i32
+  // CHECK: rocdl.ballot %[[PRED]] : i64
+  %1 = gpu.ballot %pred : i64
+  func.return %0, %1 : i32, i64
+}
 }
diff --git a/mlir/test/Conversion/GPUToSPIRV/gpu-to-spirv.mlir b/mlir/test/Conversion/GPUToSPIRV/gpu-to-spirv.mlir
index 7bf6f8419be0d..9991a4ef69984 100644
--- a/mlir/test/Conversion/GPUToSPIRV/gpu-to-spirv.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/gpu-to-spirv.mlir
@@ -128,3 +128,47 @@ module attributes {gpu.container_module} {
     return
   }
 }
+
+// -----
+
+// i32 ballot: only component 0 of the vector<4xi32> ballot mask is extracted.
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Shader, GroupNonUniformBallot], []>, #spirv.resource_limits<subgroup_size = 32>>
+} {
+  gpu.module @kernels {
+    // CHECK-LABEL:  spirv.func @ballot_i32
+    //  CHECK-SAME:  (%[[PRED:.*]]: i1
+    gpu.func @ballot_i32(%pred: i1) kernel
+      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 1, 1]>} {
+      // CHECK: %[[BALLOT:.*]] = spirv.GroupNonUniformBallot <Subgroup> %[[PRED]] : vector<4xi32>
+      // CHECK: spirv.CompositeExtract %[[BALLOT]][0 : i32] : vector<4xi32>
+      %0 = gpu.ballot %pred : i32
+      gpu.return
+    }
+  }
+}
+
+// -----
+
+// i64 ballot: components 0 and 1 of the ballot mask are zero-extended and
+// combined as (high << 32) | low. The target enables the Int64 capability.
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Shader, GroupNonUniformBallot, Int64], []>, #spirv.resource_limits<subgroup_size = 64>>
+} {
+  gpu.module @kernels {
+    // CHECK-LABEL:  spirv.func @ballot_i64
+    //  CHECK-SAME:  (%[[PRED:.*]]: i1
+    gpu.func @ballot_i64(%pred: i1) kernel
+      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [64, 1, 1]>} {
+      // CHECK: %[[BALLOT:.*]] = spirv.GroupNonUniformBallot <Subgroup> %[[PRED]] : vector<4xi32>
+      // CHECK: %[[LOW:.*]] = spirv.CompositeExtract %[[BALLOT]][0 : i32] : vector<4xi32>
+      // CHECK: %[[HIGH:.*]] = spirv.CompositeExtract %[[BALLOT]][1 : i32] : vector<4xi32>
+      // CHECK: %[[LOW_EXT:.*]] = spirv.UConvert %[[LOW]] : i32 to i64
+      // CHECK: %[[HIGH_EXT:.*]] = spirv.UConvert %[[HIGH]] : i32 to i64
+      // CHECK: spirv.Constant 32 : i64
+      // CHECK: spirv.ShiftLeftLogical %[[HIGH_EXT]],
+      // CHECK: spirv.BitwiseOr %[[LOW_EXT]],
+      %0 = gpu.ballot %pred : i64
+      gpu.return
+    }
+  }
+}
diff --git a/mlir/test/Dialect/GPU/broadcast-speculatability.mlir b/mlir/test/Dialect/GPU/broadcast-speculatability.mlir
index 3cf4853effee5..7062ce4e6f46f 100644
--- a/mlir/test/Dialect/GPU/broadcast-speculatability.mlir
+++ b/mlir/test/Dialect/GPU/broadcast-speculatability.mlir
@@ -1,6 +1,7 @@
 // RUN: mlir-opt %s --loop-invariant-code-motion | FileCheck %s
 
 func.func private @side_effect(%arg0 : f32, %arg1 : f32)
+// External sink so that ballot results in the tests below have a real use.
+func.func private @use_i32(%arg0 : i32)
 
 // CHECK-LABEL: func @broadcast_hoisting
 //  CHECK-SAME: (%[[ARG:.*]]: f32, %[[IDX:.*]]: i32, {{.*}}: index)
@@ -20,3 +21,20 @@ func.func @broadcast_hoisting(%arg0 : f32, %arg1 : i32, %arg2 : index) {
   }
   func.return
 }
+
+// Loop-invariant code motion must leave gpu.ballot inside the loop: although
+// its operand is loop-invariant, the op reports NotSpeculatable.
+// CHECK-LABEL: func @ballot_no_hoisting
+//  CHECK-SAME: (%[[PRED:.*]]: i1, {{.*}}: index)
+func.func @ballot_no_hoisting(%pred: i1, %n: index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // Ballot cannot be speculated across control flow because
+  // it depends on active lanes, which can change.
+  // CHECK: scf.for
+  // CHECK: %[[BALLOT:.*]] = gpu.ballot %[[PRED]] : i32
+  // CHECK: func.call @use_i32(%[[BALLOT]])
+  scf.for %i = %c0 to %n step %c1 {
+    %0 = gpu.ballot %pred : i32
+    func.call @use_i32(%0) : (i32) -> ()
+  }
+  func.return
+}
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index bf862b2c5ae3c..67bb01b13784b 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -1119,3 +1119,11 @@ func.func @warp_execute_wrong_terminator() {
   }
   return
 }
+
+// -----
+
+// The gpu.ballot verifier rejects results narrower than 32 bits.
+func.func @ballot_invalid_type(%pred: i1) {
+  // expected-error @+1 {{'gpu.ballot' op result type must have at least 32 bits}}
+  %0 = gpu.ballot %pred : i16
+  return
+}
diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
index 1d05268ed4475..a5dad3f931cc1 100644
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -586,3 +586,13 @@ func.func @subgroup_broadcast(%arg0 : f32, %arg1 : i32) -> (f32, f32) {
   %1 = gpu.subgroup_broadcast %arg0, specific_lane %arg1 : f32
   func.return %0, %1 : f32, f32
 }
+
+// Round-trips gpu.ballot with i32 and i64 results through the custom
+// assembly format.
+// CHECK-LABEL: func @ballot
+//  CHECK-SAME: (%[[PRED:.*]]: i1)
+func.func @ballot(%pred: i1) -> (i32, i64) {
+  // CHECK: gpu.ballot %[[PRED]] : i32
+  %0 = gpu.ballot %pred : i32
+  // CHECK: gpu.ballot %[[PRED]] : i64
+  %1 = gpu.ballot %pred : i64
+  func.return %0, %1 : i32, i64
+}



More information about the Mlir-commits mailing list