[Mlir-commits] [mlir] [MLIR][NVVM] Update cp.async.bulk Ops to use intrinsics (PR #78900)

Durgadoss R llvmlistbot at llvm.org
Sun Jan 21 05:38:18 PST 2024


https://github.com/durga4github created https://github.com/llvm/llvm-project/pull/78900

This patch updates the cp.async.bulk.{commit,wait}_group Ops to lower to NVVM intrinsics instead of inline PTX.
* Add documentation for the commit_group Op.
* Add tests to verify the lowering to the intrinsics.

While we are at it, fix the FileCheck prefix (CHECK-LLVM -> CHECK) in the 'nvvm.setmaxregister' test.

>From b3dff722e3714a51524158b27c84f0d0c430a5cb Mon Sep 17 00:00:00 2001
From: Durgadoss R <durgadossr at nvidia.com>
Date: Sun, 21 Jan 2024 18:56:34 +0530
Subject: [PATCH] [MLIR][NVVM] Update cp.async.bulk Ops to use intrinsics

This patch updates the cp.async.bulk.{commit,wait}_group
Ops to lower to NVVM intrinsics instead of inline PTX.
* Add documentation for the commit_group Op.
* Add tests to verify the lowering to the intrinsics.

While we are at it, fix the FileCheck prefix (CHECK-LLVM -> CHECK)
in the 'nvvm.setmaxregister' test.

Signed-off-by: Durgadoss R <durgadossr at nvidia.com>
---
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td   | 30 +++++++++++--------
 .../Conversion/NVVMToLLVM/nvvm-to-llvm.mlir   | 18 +++++------
 mlir/test/Target/LLVMIR/nvvmir.mlir           | 24 +++++++++++++--
 3 files changed, 47 insertions(+), 25 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 7140e614412f986..3916896382163ea 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1591,19 +1591,26 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
 // NVVM TMA Ops
 //===----------------------------------------------------------------------===//
 
-def NVVM_CpAsyncBulkCommitGroupOp : NVVM_PTXBuilder_Op<"cp.async.bulk.commit.group">,
+def NVVM_CpAsyncBulkCommitGroupOp : NVVM_Op<"cp.async.bulk.commit.group">,
   Arguments<(ins )> {
   let assemblyFormat = "attr-dict";
-  let extraClassDefinition = [{
-    std::string $cppClass::getPtx() { return std::string("cp.async.bulk.commit_group;"); }
+  let description = [{
+    This Op commits all prior initiated but uncommitted cp.async.bulk
+    instructions into a cp.async.bulk-group.
+
+    [For more information, see PTX ISA]
+    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-commit-group)
+  }];
+
+  string llvmBuilder = [{
+    createIntrinsicCall(builder, llvm::Intrinsic::nvvm_cp_async_bulk_commit_group); // No operands: emits a bare call to the commit_group intrinsic.
   }];
 }
 
-def NVVM_CpAsyncBulkWaitGroupOp : NVVM_PTXBuilder_Op<"cp.async.bulk.wait_group">,
+def NVVM_CpAsyncBulkWaitGroupOp : NVVM_Op<"cp.async.bulk.wait_group">,
   Arguments<(ins 
     ConfinedAttr<I32Attr, [IntMinValue<0>]>:$group, 
-    OptionalAttr<UnitAttr>:$read)> 
-{
+    OptionalAttr<UnitAttr>:$read)> {
   let assemblyFormat = "$group attr-dict";
   let description = [{
     Op waits for completion of the most recent bulk async-groups.
@@ -1620,15 +1627,14 @@ def NVVM_CpAsyncBulkWaitGroupOp : NVVM_PTXBuilder_Op<"cp.async.bulk.wait_group">
     (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group)
   }];
  
-  let extraClassDefinition = [{
-    std::string $cppClass::getPtx() { 
-      auto ptx = std::string("cp.async.bulk.wait_group");
-      if(getRead()) ptx += ".read";
-      ptx += " %0;"; return ptx; }
+  string llvmBuilder = [{
+    auto intId = op.getRead() ? // The read unit attribute selects the .read intrinsic variant.
+      llvm::Intrinsic::nvvm_cp_async_bulk_wait_group_read :
+      llvm::Intrinsic::nvvm_cp_async_bulk_wait_group;
+    createIntrinsicCall(builder, intId, builder.getInt32($group)); // Group count is passed as an i32 constant operand.
   }];
 }
 
-
 def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp : 
   NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global", 
   [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>, 
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index a9487bdf3bd218a..40131af6826487a 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -638,23 +638,19 @@ func.func @set_max_register() {
 
 // -----
 
-func.func @cp_bulk_commit() {
-  //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.commit_group;"
+func.func @cp_async_bulk_commit() {  // No longer rewritten to inline asm here; the op is expected to pass through unchanged.
+  // CHECK: nvvm.cp.async.bulk.commit.group
   nvvm.cp.async.bulk.commit.group
   func.return
 }
 
 // -----
 
-func.func @cp_bulk_wait_group() {
-  // CHECK: %[[S0:.+]] = llvm.mlir.constant(1 : i32) : i32
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.wait_group $0;", "n" %[[S0]] : (i32) -> ()
-  // CHECK: %[[S1:.+]] = llvm.mlir.constant(0 : i32) : i32
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.wait_group $0;", "n" %[[S1]] : (i32) -> ()
-  // CHECK: %[[S2:.+]] = llvm.mlir.constant(5 : i32) : i32
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.wait_group.read $0;", "n" %[[S2]] : (i32) -> ()
-  // CHECK: %[[S3:.+]] = llvm.mlir.constant(0 : i32) : i32
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.wait_group.read $0;", "n" %[[S3]] : (i32) -> ()
+func.func @cp_async_bulk_wait_group() {
+  // CHECK: nvvm.cp.async.bulk.wait_group 1
+  // CHECK: nvvm.cp.async.bulk.wait_group 0
+  // CHECK: nvvm.cp.async.bulk.wait_group 5 {read}
+  // CHECK: nvvm.cp.async.bulk.wait_group 0 {read}
   nvvm.cp.async.bulk.wait_group 1
   nvvm.cp.async.bulk.wait_group 0
   nvvm.cp.async.bulk.wait_group 5 {read}
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 8c5e3524a848f68..49f9426daabc21b 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -398,13 +398,33 @@ llvm.func @cp_async_mbarrier_arrive(%bar_shared: !llvm.ptr<3>, %bar_gen: !llvm.p
 
 // CHECK-LABEL: @llvm_nvvm_setmaxregister
 llvm.func @llvm_nvvm_setmaxregister() {
-  // CHECK-LLVM: call void @llvm.nvvm.setmaxnreg.inc.sync.aligned.u32(i32 256)
+  // CHECK: call void @llvm.nvvm.setmaxnreg.inc.sync.aligned.u32(i32 256)
   nvvm.setmaxregister increase 256
-  // CHECK-LLVM: call void @llvm.nvvm.setmaxnreg.dec.sync.aligned.u32(i32 24)
+  // CHECK: call void @llvm.nvvm.setmaxnreg.dec.sync.aligned.u32(i32 24)
   nvvm.setmaxregister decrease 24
   llvm.return
 }
 
+// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_commit_group
+llvm.func @llvm_nvvm_cp_async_bulk_commit_group() {
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.commit.group()
+  nvvm.cp.async.bulk.commit.group  // takes no operands
+  llvm.return
+}
+
+// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_wait_group
+llvm.func @llvm_nvvm_cp_async_bulk_wait_group() {
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.wait.group(i32 0)
+  nvvm.cp.async.bulk.wait_group 0
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.wait.group(i32 3)
+  nvvm.cp.async.bulk.wait_group 3
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.wait.group.read(i32 0)
+  nvvm.cp.async.bulk.wait_group 0 {read}  // the read attribute selects the .read intrinsic
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.wait.group.read(i32 3)
+  nvvm.cp.async.bulk.wait_group 3 {read}
+  llvm.return
+}
+
 // CHECK-LABEL: @ld_matrix
 llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
   // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.b16.p3(ptr addrspace(3) %{{.*}})



More information about the Mlir-commits mailing list