[Mlir-commits] [mlir] [MLIR][NVVM]: Update setmaxregister NVVM Op (PR #77594)

Durgadoss R llvmlistbot at llvm.org
Wed Jan 10 04:20:20 PST 2024


https://github.com/durga4github created https://github.com/llvm/llvm-project/pull/77594

This patch updates the NVVM `setmaxregister` Op to lower to the
`llvm.nvvm.setmaxnreg.{inc,dec}.sync.aligned.u32` intrinsics instead of inline PTX.

* The Op's interface remains the same, as expected.
* Tests are added in Target/LLVMIR/nvvmir.mlir to verify the lowered intrinsics.

>From 06a3914d1100b4ef9b646388b98d784fe56066f9 Mon Sep 17 00:00:00 2001
From: Durgadoss R <durgadossr at nvidia.com>
Date: Wed, 10 Jan 2024 17:39:27 +0530
Subject: [PATCH] [MLIR][NVVM]: Update setmaxregister NVVM Op

...to use the NVVM intrinsics instead of inline PTX.

* The Op's interface remains the same, as expected.
* Tests are added in Target/LLVMIR/nvvmir.mlir to verify
  the lowered intrinsics.

Signed-off-by: Durgadoss R <durgadossr at nvidia.com>
---
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td      | 16 ++++++++--------
 .../test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir |  5 +++--
 mlir/test/Target/LLVMIR/nvvmir.mlir              |  9 +++++++++
 3 files changed, 20 insertions(+), 10 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 3a6c6e5438c6d7..1941c4dece1b86 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -463,17 +463,17 @@ def SetMaxRegisterAction : I32EnumAttr<"SetMaxRegisterAction", "NVVM set max reg
 }
 def SetMaxRegisterActionAttr : EnumAttr<NVVM_Dialect, SetMaxRegisterAction, "action">;
 
-def NVVM_SetMaxRegisterOp : NVVM_PTXBuilder_Op<"setmaxregister"> {
+def NVVM_SetMaxRegisterOp : NVVM_Op<"setmaxregister"> {
   let arguments = (ins I32Attr:$regCount, SetMaxRegisterActionAttr:$action);
   let assemblyFormat = "$action $regCount attr-dict";
-  let extraClassDefinition = [{        
-    std::string $cppClass::getPtx() {
-      if(getAction() == NVVM::SetMaxRegisterAction::increase)
-        return std::string("setmaxnreg.inc.sync.aligned.u32 %0;");
-      return std::string("setmaxnreg.dec.sync.aligned.u32 %0;");
-    }
-  }];
   let hasVerifier = 1;
+  string llvmBuilder = [{
+    auto intId = (op.getAction() == NVVM::SetMaxRegisterAction::increase) ?
+      llvm::Intrinsic::nvvm_setmaxnreg_inc_sync_aligned_u32 :
+      llvm::Intrinsic::nvvm_setmaxnreg_dec_sync_aligned_u32;
+    // Action picks inc/dec; both intrinsics take regCount as an i32 constant.
+    createIntrinsicCall(builder, intId, builder.getInt32($regCount));
+  }];
 }
 
 def NVVM_FenceMbarrierInitOp : NVVM_PTXBuilder_Op<"fence.mbarrier.init"> {
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 7e08ec6ffcbd89..2ee92e3d9527a6 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -628,9 +628,10 @@ llvm.func @init_mbarrier_arrive_expect_tx(%desc : !llvm.ptr, %pred : i1) {
 // -----
 
 func.func @set_max_register() {
-  //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "setmaxnreg.inc.sync.aligned.u32 $0;", "n"
+  // CHECK: nvvm.setmaxregister increase 232
   nvvm.setmaxregister increase 232
-  //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "setmaxnreg.dec.sync.aligned.u32 $0;", "n"
+  // No inline asm any more: the op survives this pass and lowers at translation.
+  // CHECK: nvvm.setmaxregister decrease 40
   nvvm.setmaxregister decrease 40
   func.return
 }
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index f83be9dbb2ff30..423b1a133a4ae2 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -369,6 +369,15 @@ llvm.func @cp_async_mbarrier_arrive(%bar_shared: !llvm.ptr<3>, %bar_gen: !llvm.p
   llvm.return
 }
 
+// CHECK-LABEL: @llvm_nvvm_setmaxregister
+llvm.func @llvm_nvvm_setmaxregister() {
+  // CHECK: call void @llvm.nvvm.setmaxnreg.inc.sync.aligned.u32(i32 256)
+  nvvm.setmaxregister increase 256
+  // CHECK: call void @llvm.nvvm.setmaxnreg.dec.sync.aligned.u32(i32 24)
+  nvvm.setmaxregister decrease 24
+  llvm.return
+}
+
 // CHECK-LABEL: @ld_matrix
 llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
   // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.b16.p3(ptr addrspace(3) %{{.*}})



More information about the Mlir-commits mailing list