[Mlir-commits] [mlir] 3ac1744 - [mlir][nvvm] Introduce performance tuning directives
Guray Ozen
llvmlistbot at llvm.org
Fri Oct 28 05:02:47 PDT 2022
Author: Guray Ozen
Date: 2022-10-28T14:02:40+02:00
New Revision: 3ac17449cf988bfcde804a4cc532420ed1657595
URL: https://github.com/llvm/llvm-project/commit/3ac17449cf988bfcde804a4cc532420ed1657595
DIFF: https://github.com/llvm/llvm-project/commit/3ac17449cf988bfcde804a4cc532420ed1657595.diff
LOG: [mlir][nvvm] Introduce performance tuning directives
The PTX programming model provides several performance-tuning directives; see https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#performance-tuning-directives
The downstream compiler, `ptxas`, uses this information for better register allocation and for other resource management that improves performance.
This revision introduces all of the kernel-level directives to MLIR's NVVM dialect. They are listed below:
```
maxnreg -> max registers per thread
maxntid -> max threads per CTA
reqntid -> exact number of threads per CTA
minnctapersm -> min CTAs per SM (exposed as nvvm.minctasm)
```
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D136931
Added:
Modified:
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
mlir/test/Target/LLVMIR/nvvmir.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index eea3516949ac1..25afd7ac4762b 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -34,6 +34,29 @@ def NVVM_Dialect : Dialect {
/// Get the name of the attribute used to annotate external kernel
/// functions.
static StringRef getKernelFuncAttrName() { return "nvvm.kernel"; }
+ /// Get the name of the attribute used to annotate the maximum number of
+ /// threads per CTA for kernel functions.
+ static StringRef getMaxntidAttrName() { return "nvvm.maxntid"; }
+ /// Get the metadata names used for the x, y and z dimensions of maxntid.
+ static StringRef getMaxntidXName() { return "maxntidx"; }
+ static StringRef getMaxntidYName() { return "maxntidy"; }
+ static StringRef getMaxntidZName() { return "maxntidz"; }
+
+ /// Get the name of the attribute used to annotate the exact number of
+ /// threads per CTA required by kernel functions.
+ static StringRef getReqntidAttrName() { return "nvvm.reqntid"; }
+ /// Get the metadata names used for the x, y and z dimensions of reqntid.
+ static StringRef getReqntidXName() { return "reqntidx"; }
+ static StringRef getReqntidYName() { return "reqntidy"; }
+ static StringRef getReqntidZName() { return "reqntidz"; }
+
+ /// Get the name of the attribute used to annotate the minimum number of
+ /// CTAs per SM for kernel functions.
+ static StringRef getMinctasmAttrName() { return "nvvm.minctasm"; }
+
+ /// Get the name of the attribute used to annotate max number of
+ /// registers that can be allocated per thread.
+ static StringRef getMaxnregAttrName() { return "nvvm.maxnreg"; }
}];
let useDefaultAttributePrinterParser = 1;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index ae9e95da4cd8d..f80ad40675461 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -16,7 +16,9 @@
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/MLIRContext.h"
@@ -27,6 +29,7 @@
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Type.h"
+#include "llvm/Support/Casting.h"
#include "llvm/Support/SourceMgr.h"
using namespace mlir;
@@ -672,13 +675,37 @@ void NVVMDialect::initialize() {
LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
NamedAttribute attr) {
+ StringAttr attrName = attr.getName();
// Kernel function attribute should be attached to functions.
- if (attr.getName() == NVVMDialect::getKernelFuncAttrName()) {
+ if (attrName == NVVMDialect::getKernelFuncAttrName()) {
if (!isa<LLVM::LLVMFuncOp>(op)) {
return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
<< "' attribute attached to unexpected op";
}
}
+ // maxntid and reqntid must be non-empty integer arrays of at most 3 dims.
+ if (attrName == NVVMDialect::getMaxntidAttrName() ||
+ attrName == NVVMDialect::getReqntidAttrName()) {
+ auto values = attr.getValue().dyn_cast<ArrayAttr>();
+ if (!values || values.empty() || values.size() > 3)
+ return op->emitError()
+ << "'" << attrName
+ << "' attribute must be integer array with maximum 3 index";
+ for (Attribute val : values) {
+ if (!val.isa<IntegerAttr>())
+ return op->emitError()
+ << "'" << attrName
+ << "' attribute must be integer array with maximum 3 index";
+ }
+ }
+ // minctasm and maxnreg must be scalar integer constants.
+ if (attrName == NVVMDialect::getMinctasmAttrName() ||
+ attrName == NVVMDialect::getMaxnregAttrName()) {
+ if (!attr.getValue().isa<IntegerAttr>())
+ return op->emitError()
+ << "'" << attrName << "' attribute must be integer constant";
+ }
+
return success();
}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index e09260dca28a2..feaf5ca3f563f 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -13,7 +13,9 @@
#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Operation.h"
+#include "mlir/Support/LogicalResult.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
#include "llvm/IR/IRBuilder.h"
@@ -116,21 +118,65 @@ class NVVMDialectLLVMIRTranslationInterface
LogicalResult
amendOperation(Operation *op, NamedAttribute attribute,
LLVM::ModuleTranslation &moduleTranslation) const final {
- if (attribute.getName() == NVVM::NVVMDialect::getKernelFuncAttrName()) {
- auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
- if (!func)
- return failure();
+ auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
+ if (!func)
+ return failure();
+ llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
+ llvm::Function *llvmFunc = moduleTranslation.lookupFunction(func.getName());
- llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
- llvm::Function *llvmFunc =
- moduleTranslation.lookupFunction(func.getName());
+ // Appends {func, !name, i32 dim} to the module's !nvvm.annotations.
+ auto generateMetadata = [&](int dim, StringRef name) {
llvm::Metadata *llvmMetadata[] = {
+ llvm::ValueAsMetadata::get(llvmFunc),
+ llvm::MDString::get(llvmContext, name),
+ llvm::ValueAsMetadata::get(llvm::ConstantInt::get(
+ llvm::Type::getInt32Ty(llvmContext), dim))};
+ llvm::MDNode *llvmMetadataNode =
+ llvm::MDNode::get(llvmContext, llvmMetadata);
+ moduleTranslation.getOrInsertNamedModuleMetadata("nvvm.annotations")
+ ->addOperand(llvmMetadataNode);
+ };
+ if (attribute.getName() == NVVM::NVVMDialect::getMaxntidAttrName()) {
+ if (!attribute.getValue().dyn_cast<ArrayAttr>())
+ return failure();
+ SmallVector<int64_t> values =
+ extractFromI64ArrayAttr(attribute.getValue());
+ generateMetadata(values[0], NVVM::NVVMDialect::getMaxntidXName());
+ if (values.size() > 1)
+ generateMetadata(values[1], NVVM::NVVMDialect::getMaxntidYName());
+ if (values.size() > 2)
+ generateMetadata(values[2], NVVM::NVVMDialect::getMaxntidZName());
+ } else if (attribute.getName() == NVVM::NVVMDialect::getReqntidAttrName()) {
+ if (!attribute.getValue().dyn_cast<ArrayAttr>())
+ return failure();
+ SmallVector<int64_t> values =
+ extractFromI64ArrayAttr(attribute.getValue());
+ generateMetadata(values[0], NVVM::NVVMDialect::getReqntidXName());
+ if (values.size() > 1)
+ generateMetadata(values[1], NVVM::NVVMDialect::getReqntidYName());
+ if (values.size() > 2)
+ generateMetadata(values[2], NVVM::NVVMDialect::getReqntidZName());
+ } else if (attribute.getName() ==
+ NVVM::NVVMDialect::getMinctasmAttrName()) {
+ // Fail instead of dereferencing a null attr if the value is not an integer.
+ auto value = attribute.getValue().dyn_cast<IntegerAttr>();
+ if (!value)
+ return failure();
+ generateMetadata(value.getInt(), "minctasm");
+ } else if (attribute.getName() == NVVM::NVVMDialect::getMaxnregAttrName()) {
+ auto value = attribute.getValue().dyn_cast<IntegerAttr>();
+ if (!value)
+ return failure();
+ generateMetadata(value.getInt(), "maxnreg");
+ } else if (attribute.getName() ==
+ NVVM::NVVMDialect::getKernelFuncAttrName()) {
+ llvm::Metadata *llvmMetadataKernel[] = {
llvm::ValueAsMetadata::get(llvmFunc),
llvm::MDString::get(llvmContext, "kernel"),
llvm::ValueAsMetadata::get(
llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), 1))};
llvm::MDNode *llvmMetadataNode =
- llvm::MDNode::get(llvmContext, llvmMetadata);
+ llvm::MDNode::get(llvmContext, llvmMetadataKernel);
moduleTranslation.getOrInsertNamedModuleMetadata("nvvm.annotations")
->addOperand(llvmMetadataNode);
}
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index a90fabe3461d9..fed2a48f9fa8b 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+// RUN: mlir-translate -mlir-to-llvmir %s -split-input-file --verify-diagnostics | FileCheck %s
// CHECK-LABEL: @nvvm_special_regs
llvm.func @nvvm_special_regs() -> i32 {
@@ -349,3 +349,90 @@ llvm.func @kernel_func() attributes {nvvm.kernel} {
// CHECK: !nvvm.annotations =
// CHECK-NOT: {ptr @nvvm_special_regs, !"kernel", i32 1}
// CHECK: {ptr @kernel_func, !"kernel", i32 1}
+
+// -----
+
+llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = [1,23,32]} {
+ llvm.return
+}
+
+// CHECK: !nvvm.annotations =
+// CHECK-NOT: {ptr @nvvm_special_regs, !"kernel", i32 1}
+// CHECK: {ptr @kernel_func, !"kernel", i32 1}
+// CHECK: {ptr @kernel_func, !"maxntidx", i32 1}
+// CHECK: {ptr @kernel_func, !"maxntidy", i32 23}
+// CHECK: {ptr @kernel_func, !"maxntidz", i32 32}
+// -----
+
+llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.reqntid = [1,23,32]} {
+ llvm.return
+}
+
+// CHECK: !nvvm.annotations =
+// CHECK-NOT: {ptr @nvvm_special_regs, !"kernel", i32 1}
+// CHECK: {ptr @kernel_func, !"kernel", i32 1}
+// CHECK: {ptr @kernel_func, !"reqntidx", i32 1}
+// CHECK: {ptr @kernel_func, !"reqntidy", i32 23}
+// CHECK: {ptr @kernel_func, !"reqntidz", i32 32}
+// -----
+
+llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.minctasm = 16} {
+ llvm.return
+}
+
+// CHECK: !nvvm.annotations =
+// CHECK-NOT: {ptr @nvvm_special_regs, !"kernel", i32 1}
+// CHECK: {ptr @kernel_func, !"kernel", i32 1}
+// CHECK: {ptr @kernel_func, !"minctasm", i32 16}
+// -----
+
+llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxnreg = 16} {
+ llvm.return
+}
+
+// CHECK: !nvvm.annotations =
+// CHECK-NOT: {ptr @nvvm_special_regs, !"kernel", i32 1}
+// CHECK: {ptr @kernel_func, !"kernel", i32 1}
+// CHECK: {ptr @kernel_func, !"maxnreg", i32 16}
+// -----
+
+llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = [1,23,32],
+ nvvm.minctasm = 16, nvvm.maxnreg = 32} {
+ llvm.return
+}
+
+// CHECK: !nvvm.annotations =
+// CHECK-NOT: {ptr @nvvm_special_regs, !"kernel", i32 1}
+// CHECK: {ptr @kernel_func, !"kernel", i32 1}
+// CHECK: {ptr @kernel_func, !"maxnreg", i32 32}
+// CHECK: {ptr @kernel_func, !"maxntidx", i32 1}
+// CHECK: {ptr @kernel_func, !"maxntidy", i32 23}
+// CHECK: {ptr @kernel_func, !"maxntidz", i32 32}
+// CHECK: {ptr @kernel_func, !"minctasm", i32 16}
+
+// -----
+// expected-error @below {{'"nvvm.minctasm"' attribute must be integer constant}}
+llvm.func @kernel_func() attributes {nvvm.kernel,
+nvvm.minctasm = "foo"} {
+ llvm.return
+}
+
+
+// -----
+// expected-error @below {{'"nvvm.maxnreg"' attribute must be integer constant}}
+llvm.func @kernel_func() attributes {nvvm.kernel,
+nvvm.maxnreg = "boo"} {
+ llvm.return
+}
+// -----
+// expected-error @below {{'"nvvm.reqntid"' attribute must be integer array with maximum 3 index}}
+llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.reqntid = [3,4,5,6]} {
+ llvm.return
+}
+
+// -----
+// expected-error @below {{'"nvvm.maxntid"' attribute must be integer array with maximum 3 index}}
+llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = [3,4,5,6]} {
+ llvm.return
+}
+
More information about the Mlir-commits
mailing list