[Mlir-commits] [mlir] [MLIR][NVVM] Add an explicit mask operand to elect.sync (PR #145509)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jun 24 06:13:44 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-llvm

Author: Durgadoss R (durga4github)

<details>
<summary>Changes</summary>

This patch adds an optional, explicit `membermask` operand to `nvvm.elect.sync`.
When the operand is provided, it overrides the default mask value of 0xffffffff.

---
Full diff: https://github.com/llvm/llvm-project/pull/145509.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+8-6) 
- (added) mlir/test/Target/LLVMIR/nvvm/elect.mlir (+20) 
- (modified) mlir/test/Target/LLVMIR/nvvmir.mlir (-9) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 418931b931265..6895e946b8a45 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -965,19 +965,21 @@ def NVVM_ElectSyncOp : NVVM_Op<"elect.sync">
   let summary = "Elect one leader thread";
   let description = [{
     The `elect.sync` instruction elects one predicated active leader
-    thread from among a set of threads specified in membermask.
-    The membermask is set to `0xFFFFFFFF` for the current version
-    of this Op. The predicate result is set to `True` for the
-    leader thread, and `False` for all other threads.
+    thread from among a set of threads specified in the `membermask`.
+    When the `membermask` is not provided explicitly, a default value
+    of `0xFFFFFFFF` is used. The predicate result is set to `True` for
+    the leader thread, and `False` for all other threads.
 
     [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-elect-sync)
   }];
 
+  let arguments = (ins Optional<I32>:$membermask);
   let results = (outs I1:$pred);
-  let assemblyFormat = "attr-dict `->` type(results)";  
+  let assemblyFormat = "($membermask^)? attr-dict `->` type(results)";
   string llvmBuilder = [{
     auto *resultTuple = createIntrinsicCall(builder,
-        llvm::Intrinsic::nvvm_elect_sync, {builder.getInt32(0xFFFFFFFF)});
+        llvm::Intrinsic::nvvm_elect_sync,
+        {$membermask ? $membermask : builder.getInt32(0xFFFFFFFF)});
     // Extract the second value into $pred
     $pred = builder.CreateExtractValue(resultTuple, 1);
   }];
diff --git a/mlir/test/Target/LLVMIR/nvvm/elect.mlir b/mlir/test/Target/LLVMIR/nvvm/elect.mlir
new file mode 100644
index 0000000000000..3c5cac4b650bb
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/elect.mlir
@@ -0,0 +1,20 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: @test_nvvm_elect_sync
+llvm.func @test_nvvm_elect_sync() -> i1 {
+  // CHECK: %[[RES:.*]] = call { i32, i1 } @llvm.nvvm.elect.sync(i32 -1)
+  // CHECK-NEXT: %[[PRED:.*]] = extractvalue { i32, i1 } %[[RES]], 1
+  // CHECK-NEXT: ret i1 %[[PRED]]
+  %0 = nvvm.elect.sync -> i1
+  llvm.return %0 : i1
+}
+
+// CHECK-LABEL: @test_nvvm_elect_sync_mask
+llvm.func @test_nvvm_elect_sync_mask(%mask : i32) -> i1 {
+  // CHECK: %[[RES:.*]] = call { i32, i1 } @llvm.nvvm.elect.sync(i32 %0)
+  // CHECK-NEXT: %[[PRED:.*]] = extractvalue { i32, i1 } %[[RES]], 1
+  // CHECK-NEXT: ret i1 %[[PRED]]
+  %0 = nvvm.elect.sync %mask -> i1
+  llvm.return %0 : i1
+}
+
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 660d0a22dce9c..f86a04186f512 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -265,15 +265,6 @@ llvm.func @nvvm_vote(%0 : i32, %1 : i1) -> i32 {
   llvm.return %3 : i32
 }
 
-// CHECK-LABEL: @nvvm_elect_sync
-llvm.func @nvvm_elect_sync() -> i1 {
-  // CHECK: %[[RES:.*]] = call { i32, i1 } @llvm.nvvm.elect.sync(i32 -1)
-  // CHECK-NEXT: %[[PRED:.*]] = extractvalue { i32, i1 } %[[RES]], 1
-  // CHECK-NEXT: ret i1 %[[PRED]]
-  %0 = nvvm.elect.sync -> i1
-  llvm.return %0 : i1
-}
-
 // CHECK-LABEL: @nvvm_mma_mn8n8k4_row_col_f32_f32
 llvm.func @nvvm_mma_mn8n8k4_row_col_f32_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
                     %b0 : vector<2xf16>, %b1 : vector<2xf16>,

``````````

</details>


https://github.com/llvm/llvm-project/pull/145509


More information about the Mlir-commits mailing list