[Mlir-commits] [mlir] [MLIR][NVVM] Update Wgmma.fence Ops to use intrinsics (PR #120956)
Srinivasa Ravi
llvmlistbot at llvm.org
Mon Dec 23 03:10:03 PST 2024
https://github.com/Wolfram70 updated https://github.com/llvm/llvm-project/pull/120956
From 21a06ab2a41fbb4789ca3ca941316f4a6bcf4e0f Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Wed, 18 Dec 2024 14:26:49 +0530
Subject: [PATCH] [MLIR][NVVM] Update Wgmma.fence Ops to use intrinsics
This patch updates the WgmmaFenceAlignedOp, WgmmaGroupSyncAlignedOp, and
WgmmaWaitGroupSyncOp operations in the NVVM dialect to lower to the
corresponding LLVM intrinsics instead of inline PTX.
The existing test under Conversion/NVVMToLLVM is updated to check that
these ops are no longer rewritten to inline assembly, and separate tests
are added under Target/LLVMIR to verify the lowered intrinsics.
---
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 20 +++++++-------
mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td | 2 +-
.../Conversion/NVVMToLLVM/nvvm-to-llvm.mlir | 14 +++++-----
mlir/test/Target/LLVMIR/nvvmir.mlir | 26 +++++++++++++++++++
4 files changed, 43 insertions(+), 19 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 530135b912b9e6..a2d2102b59dece 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -2130,7 +2130,7 @@ def NVVM_CpAsyncBulkTensorReduceOp :
// NVVM Wgmma Ops
//===----------------------------------------------------------------------===//
-def NVVM_WgmmaFenceAlignedOp : NVVM_PTXBuilder_Op<"wgmma.fence.aligned"> {
+def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned"> {
let arguments = (ins);
let description = [{
Enforce an ordering of register accesses between warpgroup level matrix
@@ -2139,12 +2139,12 @@ def NVVM_WgmmaFenceAlignedOp : NVVM_PTXBuilder_Op<"wgmma.fence.aligned"> {
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions-wgmma-fence)
}];
let assemblyFormat = "attr-dict";
- let extraClassDefinition = [{
- std::string $cppClass::getPtx() { return std::string("wgmma.fence.sync.aligned;"); }
+ string llvmBuilder = [{
+ createIntrinsicCall(builder, llvm::Intrinsic::nvvm_wgmma_fence_sync_aligned);
}];
}
-def NVVM_WgmmaGroupSyncAlignedOp : NVVM_PTXBuilder_Op<"wgmma.commit.group.sync.aligned">,
+def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned">,
Arguments<(ins )> {
let assemblyFormat = "attr-dict";
let description = [{
@@ -2152,21 +2152,21 @@ def NVVM_WgmmaGroupSyncAlignedOp : NVVM_PTXBuilder_Op<"wgmma.commit.group.sync.a
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions-wgmma-commit-group)
}];
- let extraClassDefinition = [{
- std::string $cppClass::getPtx() { return std::string("wgmma.commit_group.sync.aligned;"); }
+ string llvmBuilder = [{
+ createIntrinsicCall(builder, llvm::Intrinsic::nvvm_wgmma_commit_group_sync_aligned);
}];
}
-def NVVM_WgmmaWaitGroupSyncOp : NVVM_PTXBuilder_Op<"wgmma.wait.group.sync.aligned">{
- let arguments = (ins I32Attr:$group);
+def NVVM_WgmmaWaitGroupSyncOp : NVVM_Op<"wgmma.wait.group.sync.aligned">{
+ let arguments = (ins I64Attr:$group);
let assemblyFormat = "attr-dict $group";
let description = [{
Signal the completion of a preceding warpgroup operation.
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions-wgmma-wait-group)
}];
- let extraClassDefinition = [{
- std::string $cppClass::getPtx() { return std::string("wgmma.wait_group.sync.aligned %0;"); }
+ string llvmBuilder = [{
+ createIntrinsicCall(builder, llvm::Intrinsic::nvvm_wgmma_wait_group_sync_aligned, builder.getInt64($group));
}];
}
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index b39f2ee594cd4a..f48fa9976da128 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -771,7 +771,7 @@ def NVGPU_WarpgroupMmaOp : NVGPU_Op<"warpgroup.mma"> {
let arguments = (ins NVGPU_WarpgroupMatrixDescriptor:$descriptorA,
NVGPU_WarpgroupMatrixDescriptor:$descriptorB,
- DefaultValuedOptionalAttr<I32Attr, "1">:$waitGroup,
+ DefaultValuedOptionalAttr<I64Attr, "1">:$waitGroup,
OptionalAttr<UnitAttr>:$transposeA,
OptionalAttr<UnitAttr>:$transposeB,
NVGPU_WarpgroupAccumulator:$matrixC);
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 66b736c18718f3..84ea55ceb5acc2 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -266,19 +266,17 @@ func.func @wgmma_execute() {
nvvm.wgmma.fence.aligned
nvvm.wgmma.commit.group.sync.aligned
nvvm.wgmma.wait.group.sync.aligned 0
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.fence.sync.aligned;"
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.commit_group.sync.aligned;"
- // CHECK: %[[S0:.+]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.wait_group.sync.aligned $0;", "n" %[[S0]] : (i32)
+ // CHECK: nvvm.wgmma.fence.aligned
+ // CHECK: nvvm.wgmma.commit.group.sync.aligned
+ // CHECK: nvvm.wgmma.wait.group.sync.aligned 0
nvvm.wgmma.fence.aligned
nvvm.wgmma.commit.group.sync.aligned
nvvm.wgmma.wait.group.sync.aligned 5
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.fence.sync.aligned;"
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.commit_group.sync.aligned;"
- // CHECK: %[[S1:.+]] = llvm.mlir.constant(5 : i32) : i32
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.wait_group.sync.aligned $0;", "n" %[[S1]] : (i32)
+ // CHECK: nvvm.wgmma.fence.aligned
+ // CHECK: nvvm.wgmma.commit.group.sync.aligned
+ // CHECK: nvvm.wgmma.wait.group.sync.aligned 5
return
}
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 6a32190694b470..b69d77496351c1 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -714,3 +714,29 @@ llvm.func @nvvm_breakpoint() {
nvvm.breakpoint
llvm.return
}
+
+// -----
+// CHECK-LABEL: @nvvm_wgmma_fence_aligned
+llvm.func @nvvm_wgmma_fence_aligned() {
+ // CHECK: call void @llvm.nvvm.wgmma.fence.sync.aligned()
+ nvvm.wgmma.fence.aligned
+ llvm.return
+}
+
+// -----
+// CHECK-LABEL: @nvvm_wgmma_commit_group_aligned
+llvm.func @nvvm_wgmma_commit_group_aligned() {
+ // CHECK: call void @llvm.nvvm.wgmma.commit_group.sync.aligned()
+ nvvm.wgmma.commit.group.sync.aligned
+ llvm.return
+}
+
+// -----
+// CHECK-LABEL: @nvvm_wgmma_wait_group_aligned
+llvm.func @nvvm_wgmma_wait_group_aligned() {
+ // CHECK: call void @llvm.nvvm.wgmma.wait_group.sync.aligned(i64 0)
+ nvvm.wgmma.wait.group.sync.aligned 0
+ // CHECK: call void @llvm.nvvm.wgmma.wait_group.sync.aligned(i64 20)
+ nvvm.wgmma.wait.group.sync.aligned 20
+ llvm.return
+}
More information about the Mlir-commits mailing list