[Mlir-commits] [mlir] [mlir][nvvm] Introduce `setmaxregister.sync.aligned` Op (PR #73780)
Guray Ozen
llvmlistbot at llvm.org
Wed Nov 29 05:32:16 PST 2023
https://github.com/grypp updated https://github.com/llvm/llvm-project/pull/73780
>From 9722374dee7d63086a381671a22de98c0e5007ef Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Wed, 29 Nov 2023 11:57:57 +0100
Subject: [PATCH 1/3] [mlir][nvvm] Introduce `setmaxregister.sync.aligned` Op
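This PR introduces the `setmaxregister.sync.aligned` Op, which increases or decreases the register count by lowering to the PTX `setmaxnreg.inc.sync.aligned.u32` / `setmaxnreg.dec.sync.aligned.u32` instructions.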
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#miscellaneous-instructions-setmaxnreg
---
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 22 +++++++++++++++++++
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 6 +++++
.../Conversion/NVVMToLLVM/nvvm-to-llvm.mlir | 10 +++++++++
3 files changed, 38 insertions(+)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 829fb68549307c8..cbe7c3919d62043 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -400,6 +400,28 @@ def NVVM_FenceScClusterOp : NVVM_Op<"fence.sc.cluster"> {
let assemblyFormat = "attr-dict";
}
+def SetMaxRegisterActionIncrease : I32EnumAttrCase<"increase", 0>;
+def SetMaxRegisterActionDecrease : I32EnumAttrCase<"decrease", 1>;
+def SetMaxRegisterAction : I32EnumAttr<"SetMaxRegisterAction", "NVVM set max register action",
+ [SetMaxRegisterActionDecrease, SetMaxRegisterActionIncrease]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::NVVM";
+}
+def SetMaxRegisterActionAttr : EnumAttr<NVVM_Dialect, SetMaxRegisterAction, "action">;
+
+def NVVM_SetMaxRegisterOp : NVVM_PTXBuilder_Op<"setmaxregister.sync.aligned"> {
+ let arguments = (ins I32Attr:$count, SetMaxRegisterActionAttr:$action);
+ let assemblyFormat = "$action $count attr-dict";
+ let extraClassDefinition = [{
+ std::string $cppClass::getPtx() {
+ if(getAction() == NVVM::SetMaxRegisterAction::increase)
+ return std::string("setmaxnreg.inc.sync.aligned.u32 %0;");
+ return std::string("setmaxnreg.dec.sync.aligned.u32 %0;");
+ }
+ }];
+ let hasVerifier = 1;
+}
+
def ShflKindBfly : I32EnumAttrCase<"bfly", 0>;
def ShflKindUp : I32EnumAttrCase<"up", 1>;
def ShflKindDown : I32EnumAttrCase<"down", 2>;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 63ceebb08e5baa7..c4fe1bab0ac68ae 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1007,6 +1007,12 @@ void NVVM::WgmmaMmaAsyncOp::getAsmValues(
}
}
+LogicalResult NVVM::SetMaxRegisterOp::verify() {
+ if (getCount() % 8)
+ return emitOpError("new register size must be multiple of 8");
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// NVVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 5fa907850cedf30..fe4c33854485cda 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -611,3 +611,13 @@ llvm.func @init_mbarrier_arrive_expect_tx(%desc : !llvm.ptr, %pred : i1) {
nvvm.prefetch.tensormap %desc, predicate = %pred : !llvm.ptr, i1
llvm.return
}
+
+// -----
+
+func.func @set_max_register() {
+ //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "setmaxnreg.inc.sync.aligned.u32 [$0];", "n"
+ nvvm.setmaxregister.sync.aligned increase 232
+ //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "setmaxnreg.dec.sync.aligned.u32 [$0];", "n"
+ nvvm.setmaxregister.sync.aligned decrease 40
+ func.return
+}
>From 6d65afd47adc269a5fd6a7f04f751be7d9b71e00 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Wed, 29 Nov 2023 13:41:35 +0100
Subject: [PATCH 2/3] Address review comments
---
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 6 +++---
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 4 +++-
mlir/test/Conversion/NVVMToLLVM/invalid.mlir | 7 +++++++
mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir | 8 ++++----
4 files changed, 17 insertions(+), 8 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index cbe7c3919d62043..54826f4196993d4 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -409,9 +409,9 @@ def SetMaxRegisterAction : I32EnumAttr<"SetMaxRegisterAction", "NVVM set max reg
}
def SetMaxRegisterActionAttr : EnumAttr<NVVM_Dialect, SetMaxRegisterAction, "action">;
-def NVVM_SetMaxRegisterOp : NVVM_PTXBuilder_Op<"setmaxregister.sync.aligned"> {
- let arguments = (ins I32Attr:$count, SetMaxRegisterActionAttr:$action);
- let assemblyFormat = "$action $count attr-dict";
+def NVVM_SetMaxRegisterOp : NVVM_PTXBuilder_Op<"setmaxregister"> {
+ let arguments = (ins I32Attr:$regCount, SetMaxRegisterActionAttr:$action);
+ let assemblyFormat = "$action $regCount attr-dict";
let extraClassDefinition = [{
std::string $cppClass::getPtx() {
if(getAction() == NVVM::SetMaxRegisterAction::increase)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index c4fe1bab0ac68ae..ff6b5da78bdfe34 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1008,8 +1008,10 @@ void NVVM::WgmmaMmaAsyncOp::getAsmValues(
}
LogicalResult NVVM::SetMaxRegisterOp::verify() {
- if (getCount() % 8)
+ if (getRegCount() % 8)
return emitOpError("new register size must be multiple of 8");
+ if (getRegCount() < 24 || getRegCount() > 256)
+ return emitOpError("new register size must be in between 24 to 256");
return success();
}
diff --git a/mlir/test/Conversion/NVVMToLLVM/invalid.mlir b/mlir/test/Conversion/NVVMToLLVM/invalid.mlir
index 8b9df5fa5980149..3e4d537f0160a86 100644
--- a/mlir/test/Conversion/NVVMToLLVM/invalid.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/invalid.mlir
@@ -132,3 +132,10 @@ func.func @wgmma_f32_m32(%descA : i64, %descB : i64) {
-> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
return
}
+// -----
+
+func.func @set_max_register() {
+ // expected-error @+1 {{new register size must be in between 24 to 256}}
+ nvvm.setmaxregister decrease 8
+ func.return
+}
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index fe4c33854485cda..7da4e98c40e54b4 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -615,9 +615,9 @@ llvm.func @init_mbarrier_arrive_expect_tx(%desc : !llvm.ptr, %pred : i1) {
// -----
func.func @set_max_register() {
- //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "setmaxnreg.inc.sync.aligned.u32 [$0];", "n"
- nvvm.setmaxregister.sync.aligned increase 232
- //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "setmaxnreg.dec.sync.aligned.u32 [$0];", "n"
- nvvm.setmaxregister.sync.aligned decrease 40
+ //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "setmaxnreg.inc.sync.aligned.u32 $0;", "n"
+ nvvm.setmaxregister increase 232
+ //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "setmaxnreg.dec.sync.aligned.u32 $0;", "n"
+ nvvm.setmaxregister decrease 40
func.return
}
>From a7fcc67031d8f37fdcfe054ca26aa3e39138ed40 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Wed, 29 Nov 2023 14:32:01 +0100
Subject: [PATCH 3/3] Add more tests
---
mlir/test/Conversion/NVVMToLLVM/invalid.mlir | 8 ++++++++
1 file changed, 8 insertions(+)
diff --git a/mlir/test/Conversion/NVVMToLLVM/invalid.mlir b/mlir/test/Conversion/NVVMToLLVM/invalid.mlir
index 3e4d537f0160a86..1328755f69d8965 100644
--- a/mlir/test/Conversion/NVVMToLLVM/invalid.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/invalid.mlir
@@ -139,3 +139,11 @@ func.func @set_max_register() {
nvvm.setmaxregister decrease 8
func.return
}
+
+// -----
+
+func.func @set_max_register() {
+ // expected-error @+1 {{new register size must be multiple of 8}}
+ nvvm.setmaxregister decrease 51
+ func.return
+}
More information about the Mlir-commits
mailing list