[Mlir-commits] [mlir] [MLIR][NVVM] Update Wgmma.fence Ops to use intrinsics (PR #120956)

Srinivasa Ravi llvmlistbot at llvm.org
Mon Dec 23 03:10:03 PST 2024


https://github.com/Wolfram70 updated https://github.com/llvm/llvm-project/pull/120956

>From 21a06ab2a41fbb4789ca3ca941316f4a6bcf4e0f Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Wed, 18 Dec 2024 14:26:49 +0530
Subject: [PATCH] [MLIR][NVVM] Update Wgmma.fence Ops to use intrinsics

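This patch updates the WgmmaFenceAlignedOp, WgmmaGroupSyncAlignedOp, and
WgmmaWaitGroupSyncOp Ops in the NVVM Dialect to lower to the
corresponding intrinsics instead of inline PTX.

To match the i64 immediate operand of the
llvm.nvvm.wgmma.wait_group.sync.aligned intrinsic, the `group`
attribute of WgmmaWaitGroupSyncOp is changed from I32Attr to I64Attr.
The `waitGroup` attribute of NVGPU's WarpgroupMmaOp is changed the
same way.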

The existing test under Conversion/NVVMToLLVM is updated to check that
these ops are no longer rewritten to inline assembly by the conversion.
New tests are added under Target/LLVMIR to verify the emitted intrinsic
calls.
---
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td   | 20 +++++++-------
 mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td   |  2 +-
 .../Conversion/NVVMToLLVM/nvvm-to-llvm.mlir   | 14 +++++-----
 mlir/test/Target/LLVMIR/nvvmir.mlir           | 26 +++++++++++++++++++
 4 files changed, 43 insertions(+), 19 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 530135b912b9e6..a2d2102b59dece 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -2130,7 +2130,7 @@ def NVVM_CpAsyncBulkTensorReduceOp :
 // NVVM Wgmma Ops
 //===----------------------------------------------------------------------===//
 
-def NVVM_WgmmaFenceAlignedOp : NVVM_PTXBuilder_Op<"wgmma.fence.aligned"> {
+def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned"> {
   let arguments = (ins);
   let description = [{
     Enforce an ordering of register accesses between warpgroup level matrix 
@@ -2139,12 +2139,12 @@ def NVVM_WgmmaFenceAlignedOp : NVVM_PTXBuilder_Op<"wgmma.fence.aligned"> {
     [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions-wgmma-fence)
   }];
   let assemblyFormat = "attr-dict";
-  let extraClassDefinition = [{
-    std::string $cppClass::getPtx() { return std::string("wgmma.fence.sync.aligned;"); }
+  string llvmBuilder = [{
+    createIntrinsicCall(builder, llvm::Intrinsic::nvvm_wgmma_fence_sync_aligned);
   }];
 }
 
-def NVVM_WgmmaGroupSyncAlignedOp : NVVM_PTXBuilder_Op<"wgmma.commit.group.sync.aligned">,
+def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned">,
   Arguments<(ins )> {
   let assemblyFormat = "attr-dict";
   let description = [{
@@ -2152,21 +2152,21 @@ def NVVM_WgmmaGroupSyncAlignedOp : NVVM_PTXBuilder_Op<"wgmma.commit.group.sync.a
     
     [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions-wgmma-commit-group)
   }];
-  let extraClassDefinition = [{
-    std::string $cppClass::getPtx() { return std::string("wgmma.commit_group.sync.aligned;"); }
+  string llvmBuilder = [{
+    createIntrinsicCall(builder, llvm::Intrinsic::nvvm_wgmma_commit_group_sync_aligned);
   }];
 }
 
-def NVVM_WgmmaWaitGroupSyncOp : NVVM_PTXBuilder_Op<"wgmma.wait.group.sync.aligned">{
-  let arguments = (ins I32Attr:$group);
+def NVVM_WgmmaWaitGroupSyncOp : NVVM_Op<"wgmma.wait.group.sync.aligned">{
+  let arguments = (ins I64Attr:$group);
   let assemblyFormat = "attr-dict $group";
   let description = [{
     Signal the completion of a preceding warpgroup operation.
     
     [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions-wgmma-wait-group)
   }];
-  let extraClassDefinition = [{
-    std::string $cppClass::getPtx() { return std::string("wgmma.wait_group.sync.aligned %0;"); }
+  string llvmBuilder = [{
+    createIntrinsicCall(builder, llvm::Intrinsic::nvvm_wgmma_wait_group_sync_aligned, builder.getInt64($group));
   }];
 }
 
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index b39f2ee594cd4a..f48fa9976da128 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -771,7 +771,7 @@ def NVGPU_WarpgroupMmaOp : NVGPU_Op<"warpgroup.mma"> {
 
   let arguments = (ins NVGPU_WarpgroupMatrixDescriptor:$descriptorA, 
                        NVGPU_WarpgroupMatrixDescriptor:$descriptorB,                                               
-                       DefaultValuedOptionalAttr<I32Attr, "1">:$waitGroup,
+                       DefaultValuedOptionalAttr<I64Attr, "1">:$waitGroup,
                        OptionalAttr<UnitAttr>:$transposeA,
                        OptionalAttr<UnitAttr>:$transposeB,
                        NVGPU_WarpgroupAccumulator:$matrixC);
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 66b736c18718f3..84ea55ceb5acc2 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -266,19 +266,17 @@ func.func @wgmma_execute() {
   nvvm.wgmma.fence.aligned
   nvvm.wgmma.commit.group.sync.aligned
   nvvm.wgmma.wait.group.sync.aligned 0
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.fence.sync.aligned;"
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.commit_group.sync.aligned;"
-  // CHECK: %[[S0:.+]] = llvm.mlir.constant(0 : i32) : i32
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.wait_group.sync.aligned $0;", "n" %[[S0]] : (i32)
+  // CHECK: nvvm.wgmma.fence.aligned
+  // CHECK: nvvm.wgmma.commit.group.sync.aligned
+  // CHECK: nvvm.wgmma.wait.group.sync.aligned 0
   
 
   nvvm.wgmma.fence.aligned
   nvvm.wgmma.commit.group.sync.aligned
   nvvm.wgmma.wait.group.sync.aligned 5
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.fence.sync.aligned;"
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.commit_group.sync.aligned;"
-  // CHECK: %[[S1:.+]] = llvm.mlir.constant(5 : i32) : i32
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.wait_group.sync.aligned $0;", "n" %[[S1]] : (i32)
+  // CHECK: nvvm.wgmma.fence.aligned
+  // CHECK: nvvm.wgmma.commit.group.sync.aligned
+  // CHECK: nvvm.wgmma.wait.group.sync.aligned 5
   return
 }
 
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 6a32190694b470..b69d77496351c1 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -714,3 +714,29 @@ llvm.func @nvvm_breakpoint() {
   nvvm.breakpoint
   llvm.return
 }
+
+// -----
+// CHECK-LABEL: @nvvm_wgmma_fence_aligned
+llvm.func @nvvm_wgmma_fence_aligned() {
+  // CHECK: call void @llvm.nvvm.wgmma.fence.sync.aligned()
+  nvvm.wgmma.fence.aligned
+  llvm.return
+}
+
+// -----
+// CHECK-LABEL: @nvvm_wgmma_commit_group_aligned
+llvm.func @nvvm_wgmma_commit_group_aligned() {
+  // CHECK: call void @llvm.nvvm.wgmma.commit_group.sync.aligned()
+  nvvm.wgmma.commit.group.sync.aligned
+  llvm.return
+}
+
+// -----
+// CHECK-LABEL: @nvvm_wgmma_wait_group_aligned
+llvm.func @nvvm_wgmma_wait_group_aligned() {
+  // CHECK: call void @llvm.nvvm.wgmma.wait_group.sync.aligned(i64 0)
+  nvvm.wgmma.wait.group.sync.aligned 0
+  // CHECK: call void @llvm.nvvm.wgmma.wait_group.sync.aligned(i64 20)
+  nvvm.wgmma.wait.group.sync.aligned 20
+  llvm.return
+}



More information about the Mlir-commits mailing list