[Mlir-commits] [mlir] [MLIR][NVVM] Update the elect.sync Op to use intrinsics (PR #113757)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Oct 26 01:22:05 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-llvm
Author: Durgadoss R (durga4github)
<details>
<summary>Changes</summary>
Recently, we added an intrinsic for the elect.sync PTX instruction (PR 104780).
This patch updates the corresponding Op in the NVVM Dialect to lower
to the intrinsic instead of inline PTX.
The existing test under Conversion/ is migrated to check for the new pattern.
A separate test is added under the Target/ directory to verify the
lowered intrinsic.
---
Full diff: https://github.com/llvm/llvm-project/pull/113757.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+18-15)
- (modified) mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir (+1-7)
- (modified) mlir/test/Target/LLVMIR/nvvmir.mlir (+9)
``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 5806295cedb198..7cb4b5c346ad97 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -783,24 +783,27 @@ def NVVM_SyncWarpOp :
let assemblyFormat = "$mask attr-dict `:` type($mask)";
}
-
-def NVVM_ElectSyncOp : NVVM_Op<"elect.sync",
- [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>
+def NVVM_ElectSyncOp : NVVM_Op<"elect.sync">
{
+ let summary = "Elect one leader thread";
+ let description = [{
+ The `elect.sync` instruction elects one predicated active leader
+ thread from among a set of threads specified in membermask.
+ The membermask is set to `0xFFFFFFFF` for the current version
+ of this Op. The predicate result is set to `True` for the
+ leader thread, and `False` for all other threads.
+
+ [For more information, see PTX ISA]
+ (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-elect-sync)
+ }];
+
let results = (outs I1:$pred);
let assemblyFormat = "attr-dict `->` type(results)";
- let extraClassDefinition = [{
- std::string $cppClass::getPtx() {
- return std::string(
- "{ \n"
- ".reg .u32 rx; \n"
- ".reg .pred px; \n"
- " mov.pred %0, 0; \n"
- " elect.sync rx | px, 0xFFFFFFFF;\n"
- "@px mov.pred %0, 1; \n"
- "}\n"
- );
- }
+ string llvmBuilder = [{
+ auto *resultTuple = createIntrinsicCall(builder,
+ llvm::Intrinsic::nvvm_elect_sync, {builder.getInt32(0xFFFFFFFF)});
+ // Extract the second value into $pred
+ $pred = builder.CreateExtractValue(resultTuple, 1);
}];
}
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 375e2951a037cd..66b736c18718f3 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -579,13 +579,7 @@ func.func @wgmma_f32_e5m2_e4m3(%descA : i64, %descB : i64) -> !mat32f32 {
// -----
func.func @elect_one_leader_sync() {
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "{
- // CHECK-SAME: .reg .u32 rx;
- // CHECK-SAME: .reg .pred px;
- // CHECK-SAME: mov.pred $0, 0;
- // CHECK-SAME: elect.sync rx | px, 0xFFFFFFFF;
- // CHECK-SAME: @px mov.pred $0, 1;
- // CHECK-SAME: "=b" : () -> i1
+ // CHECK: %[[RES:.*]] = nvvm.elect.sync -> i1
%cnd = nvvm.elect.sync -> i1
return
}
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 0471e5faf84578..75ce958b43fd34 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -259,6 +259,15 @@ llvm.func @nvvm_vote(%0 : i32, %1 : i1) -> i32 {
llvm.return %3 : i32
}
+// CHECK-LABEL: @nvvm_elect_sync
+llvm.func @nvvm_elect_sync() -> i1 {
+ // CHECK: %[[RES:.*]] = call { i32, i1 } @llvm.nvvm.elect.sync(i32 -1)
+ // CHECK-NEXT: %[[PRED:.*]] = extractvalue { i32, i1 } %[[RES]], 1
+ // CHECK-NEXT: ret i1 %[[PRED]]
+ %0 = nvvm.elect.sync -> i1
+ llvm.return %0 : i1
+}
+
// CHECK-LABEL: @nvvm_mma_mn8n8k4_row_col_f32_f32
llvm.func @nvvm_mma_mn8n8k4_row_col_f32_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
``````````
</details>
https://github.com/llvm/llvm-project/pull/113757
More information about the Mlir-commits
mailing list