[Mlir-commits] [mlir] 68433f6 - [mlir][nvvm] Introduce `setmaxregister.sync.aligned` Op (#73780)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Nov 29 06:26:35 PST 2023


Author: Guray Ozen
Date: 2023-11-29T15:26:30+01:00
New Revision: 68433f6b27a2d95d8132d4a636a7e18ed85ca9e6

URL: https://github.com/llvm/llvm-project/commit/68433f6b27a2d95d8132d4a636a7e18ed85ca9e6
DIFF: https://github.com/llvm/llvm-project/commit/68433f6b27a2d95d8132d4a636a7e18ed85ca9e6.diff

LOG: [mlir][nvvm] Introduce `setmaxregister.sync.aligned` Op (#73780)

This PR introduces the `setmaxregister.sync.aligned` Op, which increases or
decreases the maximum register count.


https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#miscellaneous-instructions-setmaxnreg

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
    mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
    mlir/test/Conversion/NVVMToLLVM/invalid.mlir
    mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 829fb68549307c8..54826f4196993d4 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -400,6 +400,44 @@ def NVVM_FenceScClusterOp : NVVM_Op<"fence.sc.cluster"> {
   let assemblyFormat = "attr-dict";
 }
 
+def SetMaxRegisterActionIncrease : I32EnumAttrCase<"increase", 0>;
+def SetMaxRegisterActionDecrease : I32EnumAttrCase<"decrease", 1>;
+def SetMaxRegisterAction : I32EnumAttr<"SetMaxRegisterAction", "NVVM set max register action",
+  [SetMaxRegisterActionIncrease, SetMaxRegisterActionDecrease]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+def SetMaxRegisterActionAttr : EnumAttr<NVVM_Dialect, SetMaxRegisterAction, "action">;
+
+def NVVM_SetMaxRegisterOp : NVVM_PTXBuilder_Op<"setmaxregister"> {
+  let summary = "Increase or decrease the maximum register count (PTX `setmaxnreg`)";
+  let description = [{
+    Emits the PTX `setmaxnreg.inc.sync.aligned.u32` instruction when
+    `action` is `increase`, and `setmaxnreg.dec.sync.aligned.u32` when it is
+    `decrease`; `regCount` is passed as an immediate operand. The verifier
+    requires `regCount` to be a multiple of 8 in the inclusive range [24, 256].
+
+    Example:
+
+    ```mlir
+    nvvm.setmaxregister increase 232
+    nvvm.setmaxregister decrease 40
+    ```
+
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#miscellaneous-instructions-setmaxnreg)
+  }];
+  let arguments = (ins I32Attr:$regCount, SetMaxRegisterActionAttr:$action);
+  let assemblyFormat = "$action $regCount attr-dict";
+  let extraClassDefinition = [{
+    std::string $cppClass::getPtx() {
+      if (getAction() == NVVM::SetMaxRegisterAction::increase)
+        return std::string("setmaxnreg.inc.sync.aligned.u32 %0;");
+      return std::string("setmaxnreg.dec.sync.aligned.u32 %0;");
+    }
+  }];
+  let hasVerifier = 1;
+}
+
 def ShflKindBfly : I32EnumAttrCase<"bfly", 0>;
 def ShflKindUp   : I32EnumAttrCase<"up", 1>;
 def ShflKindDown : I32EnumAttrCase<"down", 2>;

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 63ceebb08e5baa7..ff6b5da78bdfe34 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1007,6 +1007,17 @@ void NVVM::WgmmaMmaAsyncOp::getAsmValues(
   }
 }
 
+/// Verifies that `regCount` is a multiple of 8 in the inclusive range
+/// [24, 256], i.e. a register count that PTX `setmaxnreg` accepts.
+LogicalResult NVVM::SetMaxRegisterOp::verify() {
+  const auto count = getRegCount();
+  if (count % 8 != 0)
+    return emitOpError("new register size must be multiple of 8");
+  if (count >= 24 && count <= 256)
+    return success();
+  return emitOpError("new register size must be in between 24 to 256");
+}
+
 //===----------------------------------------------------------------------===//
 // NVVMDialect initialization, type parsing, and registration.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Conversion/NVVMToLLVM/invalid.mlir b/mlir/test/Conversion/NVVMToLLVM/invalid.mlir
index 8b9df5fa5980149..1328755f69d8965 100644
--- a/mlir/test/Conversion/NVVMToLLVM/invalid.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/invalid.mlir
@@ -132,3 +132,27 @@ func.func @wgmma_f32_m32(%descA : i64, %descB : i64) {
       -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> 
   return 
 }
+
+// -----
+
+func.func @set_max_register() {
+  // expected-error @+1 {{new register size must be in between 24 to 256}}
+  nvvm.setmaxregister decrease 8
+  func.return
+}
+
+// -----
+
+func.func @set_max_register() {
+  // expected-error @+1 {{new register size must be in between 24 to 256}}
+  nvvm.setmaxregister increase 264
+  func.return
+}
+
+// -----
+
+func.func @set_max_register() {
+  // expected-error @+1 {{new register size must be multiple of 8}}
+  nvvm.setmaxregister decrease 51
+  func.return
+}

diff  --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 5fa907850cedf30..7da4e98c40e54b4 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -611,3 +611,19 @@ llvm.func @init_mbarrier_arrive_expect_tx(%desc : !llvm.ptr, %pred : i1) {
   nvvm.prefetch.tensormap %desc, predicate = %pred : !llvm.ptr, i1
   llvm.return
 }
+
+// -----
+
+// CHECK-LABEL: @set_max_register
+func.func @set_max_register() {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "setmaxnreg.inc.sync.aligned.u32 $0;", "n"
+  nvvm.setmaxregister increase 232
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "setmaxnreg.dec.sync.aligned.u32 $0;", "n"
+  nvvm.setmaxregister decrease 40
+  // Boundary values accepted by the verifier.
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "setmaxnreg.inc.sync.aligned.u32 $0;", "n"
+  nvvm.setmaxregister increase 256
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "setmaxnreg.dec.sync.aligned.u32 $0;", "n"
+  nvvm.setmaxregister decrease 24
+  func.return
+}


        


More information about the Mlir-commits mailing list