[Mlir-commits] [mlir] [MLIR][NVVM] Update Wgmma.fence Ops to use intrinsics (PR #120956)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Dec 23 01:57:34 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-llvm

Author: Srinivasa Ravi (Wolfram70)

<details>
<summary>Changes</summary>

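This PR updates the WgmmaFenceAlignedOp, WgmmaGroupSyncAlignedOp, and WgmmaWaitGroupSyncOp Ops in the NVVM Dialect to lower to the corresponding LLVM intrinsics instead of inline PTX. As part of this change, the `group` attribute of WgmmaWaitGroupSyncOp changes from I32Attr to I64Attr, matching the i64 operand passed to the intrinsic.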

The existing test under Conversion/NVVMToLLVM is updated to check that these Ops are now left unchanged by the conversion (rather than rewritten to inline assembly), and separate tests are added under Target/LLVMIR to verify the lowered intrinsics.

---
Full diff: https://github.com/llvm/llvm-project/pull/120956.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+10-10) 
- (modified) mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir (+6-8) 
- (modified) mlir/test/Target/LLVMIR/nvvmir.mlir (+26) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 530135b912b9e6..a2d2102b59dece 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -2130,7 +2130,7 @@ def NVVM_CpAsyncBulkTensorReduceOp :
 // NVVM Wgmma Ops
 //===----------------------------------------------------------------------===//
 
-def NVVM_WgmmaFenceAlignedOp : NVVM_PTXBuilder_Op<"wgmma.fence.aligned"> {
+def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned"> {
   let arguments = (ins);
   let description = [{
     Enforce an ordering of register accesses between warpgroup level matrix 
@@ -2139,12 +2139,12 @@ def NVVM_WgmmaFenceAlignedOp : NVVM_PTXBuilder_Op<"wgmma.fence.aligned"> {
     [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions-wgmma-fence)
   }];
   let assemblyFormat = "attr-dict";
-  let extraClassDefinition = [{
-    std::string $cppClass::getPtx() { return std::string("wgmma.fence.sync.aligned;"); }
+  string llvmBuilder = [{
+    createIntrinsicCall(builder, llvm::Intrinsic::nvvm_wgmma_fence_sync_aligned);
   }];
 }
 
-def NVVM_WgmmaGroupSyncAlignedOp : NVVM_PTXBuilder_Op<"wgmma.commit.group.sync.aligned">,
+def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned">,
   Arguments<(ins )> {
   let assemblyFormat = "attr-dict";
   let description = [{
@@ -2152,21 +2152,21 @@ def NVVM_WgmmaGroupSyncAlignedOp : NVVM_PTXBuilder_Op<"wgmma.commit.group.sync.a
     
     [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions-wgmma-commit-group)
   }];
-  let extraClassDefinition = [{
-    std::string $cppClass::getPtx() { return std::string("wgmma.commit_group.sync.aligned;"); }
+  string llvmBuilder = [{
+    createIntrinsicCall(builder, llvm::Intrinsic::nvvm_wgmma_commit_group_sync_aligned);
   }];
 }
 
-def NVVM_WgmmaWaitGroupSyncOp : NVVM_PTXBuilder_Op<"wgmma.wait.group.sync.aligned">{
-  let arguments = (ins I32Attr:$group);
+def NVVM_WgmmaWaitGroupSyncOp : NVVM_Op<"wgmma.wait.group.sync.aligned">{
+  let arguments = (ins I64Attr:$group);
   let assemblyFormat = "attr-dict $group";
   let description = [{
     Signal the completion of a preceding warpgroup operation.
     
     [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions-wgmma-wait-group)
   }];
-  let extraClassDefinition = [{
-    std::string $cppClass::getPtx() { return std::string("wgmma.wait_group.sync.aligned %0;"); }
+  string llvmBuilder = [{
+    createIntrinsicCall(builder, llvm::Intrinsic::nvvm_wgmma_wait_group_sync_aligned, builder.getInt64($group));
   }];
 }
 
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 66b736c18718f3..84ea55ceb5acc2 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -266,19 +266,17 @@ func.func @wgmma_execute() {
   nvvm.wgmma.fence.aligned
   nvvm.wgmma.commit.group.sync.aligned
   nvvm.wgmma.wait.group.sync.aligned 0
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.fence.sync.aligned;"
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.commit_group.sync.aligned;"
-  // CHECK: %[[S0:.+]] = llvm.mlir.constant(0 : i32) : i32
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.wait_group.sync.aligned $0;", "n" %[[S0]] : (i32)
+  // CHECK: nvvm.wgmma.fence.aligned
+  // CHECK: nvvm.wgmma.commit.group.sync.aligned
+  // CHECK: nvvm.wgmma.wait.group.sync.aligned 0
   
 
   nvvm.wgmma.fence.aligned
   nvvm.wgmma.commit.group.sync.aligned
   nvvm.wgmma.wait.group.sync.aligned 5
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.fence.sync.aligned;"
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.commit_group.sync.aligned;"
-  // CHECK: %[[S1:.+]] = llvm.mlir.constant(5 : i32) : i32
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.wait_group.sync.aligned $0;", "n" %[[S1]] : (i32)
+  // CHECK: nvvm.wgmma.fence.aligned
+  // CHECK: nvvm.wgmma.commit.group.sync.aligned
+  // CHECK: nvvm.wgmma.wait.group.sync.aligned 5
   return
 }
 
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 6a32190694b470..b69d77496351c1 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -714,3 +714,29 @@ llvm.func @nvvm_breakpoint() {
   nvvm.breakpoint
   llvm.return
 }
+
+// -----
+// CHECK-LABEL: @nvvm_wgmma_fence_aligned
+llvm.func @nvvm_wgmma_fence_aligned() {
+  // CHECK: call void @llvm.nvvm.wgmma.fence.sync.aligned()
+  nvvm.wgmma.fence.aligned
+  llvm.return
+}
+
+// -----
+// CHECK-LABEL: @nvvm_wgmma_commit_group_aligned
+llvm.func @nvvm_wgmma_commit_group_aligned() {
+  // CHECK: call void @llvm.nvvm.wgmma.commit_group.sync.aligned()
+  nvvm.wgmma.commit.group.sync.aligned
+  llvm.return
+}
+
+// -----
+// CHECK-LABEL: @nvvm_wgmma_wait_group_aligned
+llvm.func @nvvm_wgmma_wait_group_aligned() {
+  // CHECK: call void @llvm.nvvm.wgmma.wait_group.sync.aligned(i64 0)
+  nvvm.wgmma.wait.group.sync.aligned 0
+  // CHECK: call void @llvm.nvvm.wgmma.wait_group.sync.aligned(i64 20)
+  nvvm.wgmma.wait.group.sync.aligned 20
+  llvm.return
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/120956


More information about the Mlir-commits mailing list