[Mlir-commits] [mlir] [mlir][nvvm] Introduce `setmaxregister.sync.aligned` Op (PR #73780)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Nov 29 02:56:46 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-llvm
Author: Guray Ozen (grypp)
<details>
<summary>Changes</summary>
This PR introduces the `setmaxregister.sync.aligned` Op, which increases or decreases the maximum number of per-thread registers.
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#miscellaneous-instructions-setmaxnreg
---
Full diff: https://github.com/llvm/llvm-project/pull/73780.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+22)
- (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+9-7)
- (modified) mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir (+10)
``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 829fb68549307c8..cbe7c3919d62043 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -400,6 +400,28 @@ def NVVM_FenceScClusterOp : NVVM_Op<"fence.sc.cluster"> {
let assemblyFormat = "attr-dict";
}
+def SetMaxRegisterActionIncrease : I32EnumAttrCase<"increase", 0>;
+def SetMaxRegisterActionDecrease : I32EnumAttrCase<"decrease", 1>;
+def SetMaxRegisterAction : I32EnumAttr<"SetMaxRegisterAction", "NVVM set max register action",
+ [SetMaxRegisterActionDecrease, SetMaxRegisterActionIncrease]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::NVVM";
+}
+def SetMaxRegisterActionAttr : EnumAttr<NVVM_Dialect, SetMaxRegisterAction, "action">;
+
+def NVVM_SetMaxRegisterOp : NVVM_PTXBuilder_Op<"setmaxregister.sync.aligned"> {
+ let arguments = (ins I32Attr:$count, SetMaxRegisterActionAttr:$action);
+ let assemblyFormat = "$action $count attr-dict";
+ let extraClassDefinition = [{
+ std::string $cppClass::getPtx() {
+ if(getAction() == NVVM::SetMaxRegisterAction::increase)
+ return std::string("setmaxnreg.inc.sync.aligned.u32 %0;");
+ return std::string("setmaxnreg.dec.sync.aligned.u32 %0;");
+ }
+ }];
+ let hasVerifier = 1;
+}
+
def ShflKindBfly : I32EnumAttrCase<"bfly", 0>;
def ShflKindUp : I32EnumAttrCase<"up", 1>;
def ShflKindDown : I32EnumAttrCase<"down", 2>;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 63ceebb08e5baa7..e2280c398153411 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -213,8 +213,7 @@ void MmaOp::print(OpAsmPrinter &p) {
p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
// Print the types of the operands and result.
- p << " : "
- << "(";
+ p << " : " << "(";
llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
frags[1].regs[0].getType(),
frags[2].regs[0].getType()},
@@ -954,9 +953,7 @@ std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
ss << "},";
// Need to map read/write registers correctly.
regCnt = (regCnt * 2);
- ss << " $" << (regCnt) << ","
- << " $" << (regCnt + 1) << ","
- << " p";
+ ss << " $" << (regCnt) << "," << " $" << (regCnt + 1) << "," << " p";
if (!outputType.isInteger(32)) {
ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);
}
@@ -964,8 +961,7 @@ std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
if (isF16) {
ss << ", $" << (regCnt + 5) << ", $" << (regCnt + 6);
}
- ss << ";\n"
- << "}\n";
+ ss << ";\n" << "}\n";
ss.flush();
return ptx;
}
@@ -1007,6 +1003,12 @@ void NVVM::WgmmaMmaAsyncOp::getAsmValues(
}
}
+LogicalResult NVVM::SetMaxRegisterOp::verify() {
+ if (getCount() % 8)
+ return emitOpError("new register size must be multiple of 8");
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// NVVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 5fa907850cedf30..fe4c33854485cda 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -611,3 +611,13 @@ llvm.func @init_mbarrier_arrive_expect_tx(%desc : !llvm.ptr, %pred : i1) {
nvvm.prefetch.tensormap %desc, predicate = %pred : !llvm.ptr, i1
llvm.return
}
+
+// -----
+
+func.func @set_max_register() {
+ //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "setmaxnreg.inc.sync.aligned.u32 [$0];", "n"
+ nvvm.setmaxregister.sync.aligned increase 232
+ //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "setmaxnreg.dec.sync.aligned.u32 [$0];", "n"
+ nvvm.setmaxregister.sync.aligned decrease 40
+ func.return
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/73780
More information about the Mlir-commits
mailing list