[Mlir-commits] [clang] [llvm] [mlir] [NVPTX] Auto-upgrade nvvm.grid_constant to param attribute (PR #155489)
Alex MacLean
llvmlistbot at llvm.org
Tue Aug 26 13:14:20 PDT 2025
https://github.com/AlexMaclean created https://github.com/llvm/llvm-project/pull/155489
Auto-upgrade the !"grid_constant" entry in !nvvm.annotations to a "nvvm.grid_constant" parameter attribute. The attribute is simpler for front-ends to apply, and both faster and simpler to query.
From 178af6d5a6ae46a1db9969fab050b7240efaf1a1 Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Sun, 24 Aug 2025 05:13:27 +0000
Subject: [PATCH] [NVPTX] Auto-upgrade nvvm.grid_constant to param attribute
---
clang/lib/CodeGen/Targets/NVPTX.cpp | 42 +------
clang/test/CodeGenCUDA/grid-constant.cu | 16 +--
llvm/docs/NVPTXUsage.rst | 59 ++++------
llvm/lib/IR/AutoUpgrade.cpp | 10 ++
llvm/lib/Target/NVPTX/NVPTXUtilities.cpp | 32 +-----
.../CodeGen/NVPTX/lower-args-gridconstant.ll | 104 ++++++------------
.../CodeGen/NVPTX/upgrade-nvvm-annotations.ll | 13 ++-
.../Dialect/NVVM/NVVMToLLVMIRTranslation.cpp | 45 +-------
mlir/test/Target/LLVMIR/nvvmir.mlir | 10 +-
9 files changed, 91 insertions(+), 240 deletions(-)
diff --git a/clang/lib/CodeGen/Targets/NVPTX.cpp b/clang/lib/CodeGen/Targets/NVPTX.cpp
index e874617796f86..78790daa1874a 100644
--- a/clang/lib/CodeGen/Targets/NVPTX.cpp
+++ b/clang/lib/CodeGen/Targets/NVPTX.cpp
@@ -87,10 +87,6 @@ class NVPTXTargetCodeGenInfo : public TargetCodeGenInfo {
static void addNVVMMetadata(llvm::GlobalValue *GV, StringRef Name,
int Operand);
- static void
- addGridConstantNVVMMetadata(llvm::GlobalValue *GV,
- const SmallVectorImpl<int> &GridConstantArgs);
-
private:
static void emitBuiltinSurfTexDeviceCopy(CodeGenFunction &CGF, LValue Dst,
LValue Src) {
@@ -266,27 +262,24 @@ void NVPTXTargetCodeGenInfo::setTargetAttributes(
// By default, all functions are device functions
if (FD->hasAttr<DeviceKernelAttr>() || FD->hasAttr<CUDAGlobalAttr>()) {
// OpenCL/CUDA kernel functions get kernel metadata
- // Create !{<func-ref>, metadata !"kernel", i32 1} node
// And kernel functions are not subject to inlining
F->addFnAttr(llvm::Attribute::NoInline);
if (FD->hasAttr<CUDAGlobalAttr>()) {
- SmallVector<int, 10> GCI;
+ F->setCallingConv(llvm::CallingConv::PTX_Kernel);
+
for (auto IV : llvm::enumerate(FD->parameters()))
if (IV.value()->hasAttr<CUDAGridConstantAttr>())
- // For some reason arg indices are 1-based in NVVM
- GCI.push_back(IV.index() + 1);
- // Create !{<func-ref>, metadata !"kernel", i32 1} node
- F->setCallingConv(llvm::CallingConv::PTX_Kernel);
- addGridConstantNVVMMetadata(F, GCI);
+ F->addParamAttr(
+ IV.index(),
+ llvm::Attribute::get(F->getContext(), "nvvm.grid_constant"));
}
if (CUDALaunchBoundsAttr *Attr = FD->getAttr<CUDALaunchBoundsAttr>())
M.handleCUDALaunchBoundsAttr(F, Attr);
}
}
// Attach kernel metadata directly if compiling for NVPTX.
- if (FD->hasAttr<DeviceKernelAttr>()) {
+ if (FD->hasAttr<DeviceKernelAttr>())
F->setCallingConv(llvm::CallingConv::PTX_Kernel);
- }
}
void NVPTXTargetCodeGenInfo::addNVVMMetadata(llvm::GlobalValue *GV,
@@ -306,29 +299,6 @@ void NVPTXTargetCodeGenInfo::addNVVMMetadata(llvm::GlobalValue *GV,
MD->addOperand(llvm::MDNode::get(Ctx, MDVals));
}
-void NVPTXTargetCodeGenInfo::addGridConstantNVVMMetadata(
- llvm::GlobalValue *GV, const SmallVectorImpl<int> &GridConstantArgs) {
-
- llvm::Module *M = GV->getParent();
- llvm::LLVMContext &Ctx = M->getContext();
-
- // Get "nvvm.annotations" metadata node
- llvm::NamedMDNode *MD = M->getOrInsertNamedMetadata("nvvm.annotations");
-
- SmallVector<llvm::Metadata *, 5> MDVals = {llvm::ConstantAsMetadata::get(GV)};
- if (!GridConstantArgs.empty()) {
- SmallVector<llvm::Metadata *, 10> GCM;
- for (int I : GridConstantArgs)
- GCM.push_back(llvm::ConstantAsMetadata::get(
- llvm::ConstantInt::get(llvm::Type::getInt32Ty(Ctx), I)));
- MDVals.append({llvm::MDString::get(Ctx, "grid_constant"),
- llvm::MDNode::get(Ctx, GCM)});
- }
-
- // Append metadata to nvvm.annotations
- MD->addOperand(llvm::MDNode::get(Ctx, MDVals));
-}
-
bool NVPTXTargetCodeGenInfo::shouldEmitStaticExternCAliases() const {
return false;
}
diff --git a/clang/test/CodeGenCUDA/grid-constant.cu b/clang/test/CodeGenCUDA/grid-constant.cu
index e7000cab3cda5..120b854e56746 100644
--- a/clang/test/CodeGenCUDA/grid-constant.cu
+++ b/clang/test/CodeGenCUDA/grid-constant.cu
@@ -19,13 +19,9 @@ void foo() {
tkernel_const<S><<<1,1>>>({});
tkernel<const S><<<1,1>>>(1, {});
}
-//.
-//.
-// CHECK: [[META0:![0-9]+]] = !{ptr @_Z6kernel1Sii, !"grid_constant", [[META1:![0-9]+]]}
-// CHECK: [[META1]] = !{i32 1, i32 3}
-// CHECK: [[META2:![0-9]+]] = !{ptr @_Z13tkernel_constIK1SEvT_, !"grid_constant", [[META3:![0-9]+]]}
-// CHECK: [[META3]] = !{i32 1}
-// CHECK: [[META4:![0-9]+]] = !{ptr @_Z13tkernel_constI1SEvT_, !"grid_constant", [[META3]]}
-// CHECK: [[META5:![0-9]+]] = !{ptr @_Z7tkernelIK1SEviT_, !"grid_constant", [[META6:![0-9]+]]}
-// CHECK: [[META6]] = !{i32 2}
-//.
+
+// CHECK: define dso_local ptx_kernel void @_Z6kernel1Sii(ptr noundef byval(%struct.S) align 1 "nvvm.grid_constant" %gc_arg1, i32 noundef %arg2, i32 noundef "nvvm.grid_constant" %gc_arg3)
+// CHECK: define ptx_kernel void @_Z13tkernel_constIK1SEvT_(ptr noundef byval(%struct.S) align 1 "nvvm.grid_constant" %arg)
+// CHECK: define ptx_kernel void @_Z13tkernel_constI1SEvT_(ptr noundef byval(%struct.S) align 1 "nvvm.grid_constant" %arg)
+// CHECK: define ptx_kernel void @_Z7tkernelIK1SEviT_(i32 noundef %dummy, ptr noundef byval(%struct.S) align 1 "nvvm.grid_constant" %arg)
+
diff --git a/llvm/docs/NVPTXUsage.rst b/llvm/docs/NVPTXUsage.rst
index 629bf2ea5afb4..4c8c605edfdd6 100644
--- a/llvm/docs/NVPTXUsage.rst
+++ b/llvm/docs/NVPTXUsage.rst
@@ -57,6 +57,19 @@ not.
When compiled, the PTX kernel functions are callable by host-side code.
+
+Parameter Attributes
+--------------------
+
+``"nvvm.grid_constant"``
+ This attribute may be attached to a ``byval`` parameter of a kernel function
+ to indicate that the parameter should be lowered as a direct reference to
+ the grid-constant memory of the parameter, as opposed to a copy of the
+ parameter in local memory. Writing to a grid-constant parameter is
+ undefined behavior. Unlike a normal ``byval`` parameter, the address of a
+ grid-constant parameter is not unique to a given function invocation but
+ instead is shared by all threads in the grid.
+
.. _nvptx_fnattrs:
Function Attributes
@@ -2289,9 +2302,9 @@ The Kernel
; Intrinsic to read X component of thread ID
declare i32 @llvm.nvvm.read.ptx.sreg.tid.x() readnone nounwind
- define void @kernel(ptr addrspace(1) %A,
- ptr addrspace(1) %B,
- ptr addrspace(1) %C) {
+ define ptx_kernel void @kernel(ptr addrspace(1) %A,
+ ptr addrspace(1) %B,
+ ptr addrspace(1) %C) {
entry:
; What is my ID?
%id = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x() readnone nounwind
@@ -2314,9 +2327,6 @@ The Kernel
ret void
}
- !nvvm.annotations = !{!0}
- !0 = !{ptr @kernel, !"kernel", i32 1}
-
We can use the LLVM ``llc`` tool to directly run the NVPTX code generator:
@@ -2442,34 +2452,6 @@ and non-generic address spaces.
See :ref:`address_spaces` and :ref:`nvptx_intrinsics` for more information.
-Kernel Metadata
-^^^^^^^^^^^^^^^
-
-In PTX, a function can be either a `kernel` function (callable from the host
-program), or a `device` function (callable only from GPU code). You can think
-of `kernel` functions as entry-points in the GPU program. To mark an LLVM IR
-function as a `kernel` function, we make use of special LLVM metadata. The
-NVPTX back-end will look for a named metadata node called
-``nvvm.annotations``. This named metadata must contain a list of metadata that
-describe the IR. For our purposes, we need to declare a metadata node that
-assigns the "kernel" attribute to the LLVM IR function that should be emitted
-as a PTX `kernel` function. These metadata nodes take the form:
-
-.. code-block:: text
-
- !{<function ref>, metadata !"kernel", i32 1}
-
-For the previous example, we have:
-
-.. code-block:: llvm
-
- !nvvm.annotations = !{!0}
- !0 = !{ptr @kernel, !"kernel", i32 1}
-
-Here, we have a single metadata declaration in ``nvvm.annotations``. This
-metadata annotates our ``@kernel`` function with the ``kernel`` attribute.
-
-
Running the Kernel
------------------
@@ -2669,9 +2651,9 @@ Libdevice provides an ``__nv_powf`` function that we will use.
; libdevice function
declare float @__nv_powf(float, float)
- define void @kernel(ptr addrspace(1) %A,
- ptr addrspace(1) %B,
- ptr addrspace(1) %C) {
+ define ptx_kernel void @kernel(ptr addrspace(1) %A,
+ ptr addrspace(1) %B,
+ ptr addrspace(1) %C) {
entry:
; What is my ID?
%id = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x() readnone nounwind
@@ -2694,9 +2676,6 @@ Libdevice provides an ``__nv_powf`` function that we will use.
ret void
}
- !nvvm.annotations = !{!0}
- !0 = !{ptr @kernel, !"kernel", i32 1}
-
To compile this kernel, we perform the following steps:
diff --git a/llvm/lib/IR/AutoUpgrade.cpp b/llvm/lib/IR/AutoUpgrade.cpp
index e200f3626e69d..7ea9c6dff13b8 100644
--- a/llvm/lib/IR/AutoUpgrade.cpp
+++ b/llvm/lib/IR/AutoUpgrade.cpp
@@ -5381,6 +5381,16 @@ bool static upgradeSingleNVVMAnnotation(GlobalValue *GV, StringRef K,
upgradeNVVMFnVectorAttr("nvvm.cluster_dim", K[0], GV, V);
return true;
}
+ if (K == "grid_constant") {
+ const auto Attr = Attribute::get(GV->getContext(), "nvvm.grid_constant");
+ for (const auto &Op : cast<MDNode>(V)->operands()) {
+ // For some reason, the index is 1-based in the metadata. Good thing we're
+ // able to auto-upgrade it!
+ const auto Index = mdconst::extract<ConstantInt>(Op)->getZExtValue() - 1;
+ cast<Function>(GV)->addParamAttr(Index, Attr);
+ }
+ return true;
+ }
return false;
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
index 274b04fdd30b5..8e97b422218f7 100644
--- a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
@@ -55,15 +55,6 @@ void clearAnnotationCache(const Module *Mod) {
AC.Cache.erase(Mod);
}
-static void readIntVecFromMDNode(const MDNode *MetadataNode,
- std::vector<unsigned> &Vec) {
- for (unsigned i = 0, e = MetadataNode->getNumOperands(); i != e; ++i) {
- ConstantInt *Val =
- mdconst::extract<ConstantInt>(MetadataNode->getOperand(i));
- Vec.push_back(Val->getZExtValue());
- }
-}
-
static void cacheAnnotationFromMD(const MDNode *MetadataNode,
key_val_pair_t &retval) {
auto &AC = getAnnotationCache();
@@ -83,19 +74,8 @@ static void cacheAnnotationFromMD(const MDNode *MetadataNode,
if (ConstantInt *Val = mdconst::dyn_extract<ConstantInt>(
MetadataNode->getOperand(i + 1))) {
retval[Key].push_back(Val->getZExtValue());
- } else if (MDNode *VecMd =
- dyn_cast<MDNode>(MetadataNode->getOperand(i + 1))) {
- // note: only "grid_constant" annotations support vector MDNodes.
- // assert: there can only exist one unique key value pair of
- // the form (string key, MDNode node). Operands of such a node
- // shall always be unsigned ints.
- auto [It, Inserted] = retval.try_emplace(Key);
- if (Inserted) {
- readIntVecFromMDNode(VecMd, It->second);
- continue;
- }
} else {
- llvm_unreachable("Value operand not a constant int or an mdnode");
+ llvm_unreachable("Value operand not a constant int");
}
}
}
@@ -179,16 +159,13 @@ static bool globalHasNVVMAnnotation(const Value &V, const std::string &Prop) {
}
static bool argHasNVVMAnnotation(const Value &Val,
- const std::string &Annotation,
- const bool StartArgIndexAtOne = false) {
+ const std::string &Annotation) {
if (const Argument *Arg = dyn_cast<Argument>(&Val)) {
const Function *Func = Arg->getParent();
std::vector<unsigned> Annot;
if (findAllNVVMAnnotation(Func, Annotation, Annot)) {
- const unsigned BaseOffset = StartArgIndexAtOne ? 1 : 0;
- if (is_contained(Annot, BaseOffset + Arg->getArgNo())) {
+ if (is_contained(Annot, Arg->getArgNo()))
return true;
- }
}
}
return false;
@@ -250,8 +227,7 @@ bool isParamGridConstant(const Argument &Arg) {
}
// "grid_constant" counts argument indices starting from 1
- if (argHasNVVMAnnotation(Arg, "grid_constant",
- /*StartArgIndexAtOne*/ true))
+ if (Arg.hasAttribute("nvvm.grid_constant"))
return true;
return false;
diff --git a/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll b/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll
index 8adde4ceefbf4..01ab47145940c 100644
--- a/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll
+++ b/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll
@@ -49,14 +49,14 @@ define dso_local noundef i32 @non_kernel_function(ptr nocapture noundef readonly
; PTX-NEXT: st.param.b32 [func_retval0], %r10;
; PTX-NEXT: ret;
entry:
- %a. = select i1 %b, ptr %a, ptr addrspacecast (ptr addrspace(1) @gi to ptr), !dbg !17
- %idx.ext = sext i32 %c to i64, !dbg !18
- %add.ptr = getelementptr inbounds i8, ptr %a., i64 %idx.ext, !dbg !18
- %0 = load i32, ptr %add.ptr, align 1, !dbg !19
- ret i32 %0, !dbg !23
+ %a. = select i1 %b, ptr %a, ptr addrspacecast (ptr addrspace(1) @gi to ptr)
+ %idx.ext = sext i32 %c to i64
+ %add.ptr = getelementptr inbounds i8, ptr %a., i64 %idx.ext
+ %0 = load i32, ptr %add.ptr, align 1
+ ret i32 %0
}
-define ptx_kernel void @grid_const_int(ptr byval(i32) align 4 %input1, i32 %input2, ptr %out, i32 %n) {
+define ptx_kernel void @grid_const_int(ptr byval(i32) align 4 "nvvm.grid_constant" %input1, i32 %input2, ptr %out, i32 %n) {
; PTX-LABEL: grid_const_int(
; PTX: {
; PTX-NEXT: .reg .b32 %r<4>;
@@ -71,7 +71,7 @@ define ptx_kernel void @grid_const_int(ptr byval(i32) align 4 %input1, i32 %inpu
; PTX-NEXT: st.global.b32 [%rd2], %r3;
; PTX-NEXT: ret;
; OPT-LABEL: define ptx_kernel void @grid_const_int(
-; OPT-SAME: ptr byval(i32) align 4 [[INPUT1:%.*]], i32 [[INPUT2:%.*]], ptr [[OUT:%.*]], i32 [[N:%.*]]) #[[ATTR0]] {
+; OPT-SAME: ptr byval(i32) align 4 "nvvm.grid_constant" [[INPUT1:%.*]], i32 [[INPUT2:%.*]], ptr [[OUT:%.*]], i32 [[N:%.*]]) #[[ATTR0]] {
; OPT-NEXT: [[INPUT11:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT1]])
; OPT-NEXT: [[TMP:%.*]] = load i32, ptr addrspace(101) [[INPUT11]], align 4
; OPT-NEXT: [[ADD:%.*]] = add i32 [[TMP]], [[INPUT2]]
@@ -85,7 +85,7 @@ define ptx_kernel void @grid_const_int(ptr byval(i32) align 4 %input1, i32 %inpu
%struct.s = type { i32, i32 }
-define ptx_kernel void @grid_const_struct(ptr byval(%struct.s) align 4 %input, ptr %out){
+define ptx_kernel void @grid_const_struct(ptr byval(%struct.s) align 4 "nvvm.grid_constant" %input, ptr %out){
; PTX-LABEL: grid_const_struct(
; PTX: {
; PTX-NEXT: .reg .b32 %r<4>;
@@ -100,7 +100,7 @@ define ptx_kernel void @grid_const_struct(ptr byval(%struct.s) align 4 %input, p
; PTX-NEXT: st.global.b32 [%rd2], %r3;
; PTX-NEXT: ret;
; OPT-LABEL: define ptx_kernel void @grid_const_struct(
-; OPT-SAME: ptr byval([[STRUCT_S:%.*]]) align 4 [[INPUT:%.*]], ptr [[OUT:%.*]]) #[[ATTR0]] {
+; OPT-SAME: ptr byval([[STRUCT_S:%.*]]) align 4 "nvvm.grid_constant" [[INPUT:%.*]], ptr [[OUT:%.*]]) #[[ATTR0]] {
; OPT-NEXT: [[INPUT1:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT]])
; OPT-NEXT: [[GEP13:%.*]] = getelementptr inbounds [[STRUCT_S]], ptr addrspace(101) [[INPUT1]], i32 0, i32 0
; OPT-NEXT: [[GEP22:%.*]] = getelementptr inbounds [[STRUCT_S]], ptr addrspace(101) [[INPUT1]], i32 0, i32 1
@@ -118,7 +118,7 @@ define ptx_kernel void @grid_const_struct(ptr byval(%struct.s) align 4 %input, p
ret void
}
-define ptx_kernel void @grid_const_escape(ptr byval(%struct.s) align 4 %input) {
+define ptx_kernel void @grid_const_escape(ptr byval(%struct.s) align 4 "nvvm.grid_constant" %input) {
; PTX-LABEL: grid_const_escape(
; PTX: {
; PTX-NEXT: .reg .b64 %rd<4>;
@@ -136,7 +136,7 @@ define ptx_kernel void @grid_const_escape(ptr byval(%struct.s) align 4 %input) {
; PTX-NEXT: } // callseq 0
; PTX-NEXT: ret;
; OPT-LABEL: define ptx_kernel void @grid_const_escape(
-; OPT-SAME: ptr byval([[STRUCT_S:%.*]]) align 4 [[INPUT:%.*]]) #[[ATTR0]] {
+; OPT-SAME: ptr byval([[STRUCT_S:%.*]]) align 4 "nvvm.grid_constant" [[INPUT:%.*]]) #[[ATTR0]] {
; OPT-NEXT: [[TMP1:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT]])
; OPT-NEXT: [[INPUT_PARAM_GEN:%.*]] = addrspacecast ptr addrspace(101) [[TMP1]] to ptr
; OPT-NEXT: [[CALL:%.*]] = call i32 @escape(ptr [[INPUT_PARAM_GEN]])
@@ -145,7 +145,7 @@ define ptx_kernel void @grid_const_escape(ptr byval(%struct.s) align 4 %input) {
ret void
}
-define ptx_kernel void @multiple_grid_const_escape(ptr byval(%struct.s) align 4 %input, i32 %a, ptr byval(i32) align 4 %b) {
+define ptx_kernel void @multiple_grid_const_escape(ptr byval(%struct.s) align 4 "nvvm.grid_constant" %input, i32 %a, ptr byval(i32) align 4 "nvvm.grid_constant" %b) {
; PTX-LABEL: multiple_grid_const_escape(
; PTX: {
; PTX-NEXT: .local .align 4 .b8 __local_depot4[4];
@@ -179,7 +179,7 @@ define ptx_kernel void @multiple_grid_const_escape(ptr byval(%struct.s) align 4
; PTX-NEXT: } // callseq 1
; PTX-NEXT: ret;
; OPT-LABEL: define ptx_kernel void @multiple_grid_const_escape(
-; OPT-SAME: ptr byval([[STRUCT_S:%.*]]) align 4 [[INPUT:%.*]], i32 [[A:%.*]], ptr byval(i32) align 4 [[B:%.*]]) #[[ATTR0]] {
+; OPT-SAME: ptr byval([[STRUCT_S:%.*]]) align 4 "nvvm.grid_constant" [[INPUT:%.*]], i32 [[A:%.*]], ptr byval(i32) align 4 "nvvm.grid_constant" [[B:%.*]]) #[[ATTR0]] {
; OPT-NEXT: [[TMP1:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[B]])
; OPT-NEXT: [[B_PARAM_GEN:%.*]] = addrspacecast ptr addrspace(101) [[TMP1]] to ptr
; OPT-NEXT: [[TMP2:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT]])
@@ -194,7 +194,7 @@ define ptx_kernel void @multiple_grid_const_escape(ptr byval(%struct.s) align 4
ret void
}
-define ptx_kernel void @grid_const_memory_escape(ptr byval(%struct.s) align 4 %input, ptr %addr) {
+define ptx_kernel void @grid_const_memory_escape(ptr byval(%struct.s) align 4 "nvvm.grid_constant" %input, ptr %addr) {
; PTX-LABEL: grid_const_memory_escape(
; PTX: {
; PTX-NEXT: .reg .b64 %rd<5>;
@@ -207,7 +207,7 @@ define ptx_kernel void @grid_const_memory_escape(ptr byval(%struct.s) align 4 %i
; PTX-NEXT: st.global.b64 [%rd3], %rd4;
; PTX-NEXT: ret;
; OPT-LABEL: define ptx_kernel void @grid_const_memory_escape(
-; OPT-SAME: ptr byval([[STRUCT_S:%.*]]) align 4 [[INPUT:%.*]], ptr [[ADDR:%.*]]) #[[ATTR0]] {
+; OPT-SAME: ptr byval([[STRUCT_S:%.*]]) align 4 "nvvm.grid_constant" [[INPUT:%.*]], ptr [[ADDR:%.*]]) #[[ATTR0]] {
; OPT-NEXT: [[TMP1:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT]])
; OPT-NEXT: [[INPUT1:%.*]] = addrspacecast ptr addrspace(101) [[TMP1]] to ptr
; OPT-NEXT: store ptr [[INPUT1]], ptr [[ADDR]], align 8
@@ -216,7 +216,7 @@ define ptx_kernel void @grid_const_memory_escape(ptr byval(%struct.s) align 4 %i
ret void
}
-define ptx_kernel void @grid_const_inlineasm_escape(ptr byval(%struct.s) align 4 %input, ptr %result) {
+define ptx_kernel void @grid_const_inlineasm_escape(ptr byval(%struct.s) align 4 "nvvm.grid_constant" %input, ptr %result) {
; PTX-LABEL: grid_const_inlineasm_escape(
; PTX: {
; PTX-NEXT: .reg .b64 %rd<7>;
@@ -234,7 +234,7 @@ define ptx_kernel void @grid_const_inlineasm_escape(ptr byval(%struct.s) align 4
; PTX-NEXT: ret;
; PTX-NOT .local
; OPT-LABEL: define ptx_kernel void @grid_const_inlineasm_escape(
-; OPT-SAME: ptr byval([[STRUCT_S:%.*]]) align 4 [[INPUT:%.*]], ptr [[RESULT:%.*]]) #[[ATTR0]] {
+; OPT-SAME: ptr byval([[STRUCT_S:%.*]]) align 4 "nvvm.grid_constant" [[INPUT:%.*]], ptr [[RESULT:%.*]]) #[[ATTR0]] {
; OPT-NEXT: [[TMP1:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT]])
; OPT-NEXT: [[INPUT1:%.*]] = addrspacecast ptr addrspace(101) [[TMP1]] to ptr
; OPT-NEXT: [[TMPPTR1:%.*]] = getelementptr inbounds [[STRUCT_S]], ptr [[INPUT1]], i32 0, i32 0
@@ -249,7 +249,7 @@ define ptx_kernel void @grid_const_inlineasm_escape(ptr byval(%struct.s) align 4
ret void
}
-define ptx_kernel void @grid_const_partial_escape(ptr byval(i32) %input, ptr %output) {
+define ptx_kernel void @grid_const_partial_escape(ptr byval(i32) "nvvm.grid_constant" %input, ptr %output) {
; PTX-LABEL: grid_const_partial_escape(
; PTX: {
; PTX-NEXT: .reg .b32 %r<3>;
@@ -273,7 +273,7 @@ define ptx_kernel void @grid_const_partial_escape(ptr byval(i32) %input, ptr %ou
; PTX-NEXT: } // callseq 2
; PTX-NEXT: ret;
; OPT-LABEL: define ptx_kernel void @grid_const_partial_escape(
-; OPT-SAME: ptr byval(i32) [[INPUT:%.*]], ptr [[OUTPUT:%.*]]) #[[ATTR0]] {
+; OPT-SAME: ptr byval(i32) "nvvm.grid_constant" [[INPUT:%.*]], ptr [[OUTPUT:%.*]]) #[[ATTR0]] {
; OPT-NEXT: [[TMP1:%.*]] = call ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT]])
; OPT-NEXT: [[INPUT1_GEN:%.*]] = addrspacecast ptr addrspace(101) [[TMP1]] to ptr
; OPT-NEXT: [[VAL1:%.*]] = load i32, ptr [[INPUT1_GEN]], align 4
@@ -288,7 +288,7 @@ define ptx_kernel void @grid_const_partial_escape(ptr byval(i32) %input, ptr %ou
ret void
}
-define ptx_kernel i32 @grid_const_partial_escapemem(ptr byval(%struct.s) %input, ptr %output) {
+define ptx_kernel i32 @grid_const_partial_escapemem(ptr byval(%struct.s) "nvvm.grid_constant" %input, ptr %output) {
; PTX-LABEL: grid_const_partial_escapemem(
; PTX: {
; PTX-NEXT: .reg .b32 %r<4>;
@@ -314,7 +314,7 @@ define ptx_kernel i32 @grid_const_partial_escapemem(ptr byval(%struct.s) %input,
; PTX-NEXT: st.param.b32 [func_retval0], %r3;
; PTX-NEXT: ret;
; OPT-LABEL: define ptx_kernel i32 @grid_const_partial_escapemem(
-; OPT-SAME: ptr byval([[STRUCT_S:%.*]]) [[INPUT:%.*]], ptr [[OUTPUT:%.*]]) #[[ATTR0]] {
+; OPT-SAME: ptr byval([[STRUCT_S:%.*]]) "nvvm.grid_constant" [[INPUT:%.*]], ptr [[OUTPUT:%.*]]) #[[ATTR0]] {
; OPT-NEXT: [[TMP1:%.*]] = call ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT]])
; OPT-NEXT: [[INPUT1:%.*]] = addrspacecast ptr addrspace(101) [[TMP1]] to ptr
; OPT-NEXT: [[PTR1:%.*]] = getelementptr inbounds [[STRUCT_S]], ptr [[INPUT1]], i32 0, i32 0
@@ -335,7 +335,7 @@ define ptx_kernel i32 @grid_const_partial_escapemem(ptr byval(%struct.s) %input,
ret i32 %add
}
-define ptx_kernel void @grid_const_phi(ptr byval(%struct.s) align 4 %input1, ptr %inout) {
+define ptx_kernel void @grid_const_phi(ptr byval(%struct.s) align 4 "nvvm.grid_constant" %input1, ptr %inout) {
; PTX-LABEL: grid_const_phi(
; PTX: {
; PTX-NEXT: .reg .pred %p<2>;
@@ -356,7 +356,7 @@ define ptx_kernel void @grid_const_phi(ptr byval(%struct.s) align 4 %input1, ptr
; PTX-NEXT: st.global.b32 [%rd1], %r2;
; PTX-NEXT: ret;
; OPT-LABEL: define ptx_kernel void @grid_const_phi(
-; OPT-SAME: ptr byval([[STRUCT_S:%.*]]) align 4 [[INPUT1:%.*]], ptr [[INOUT:%.*]]) #[[ATTR0]] {
+; OPT-SAME: ptr byval([[STRUCT_S:%.*]]) align 4 "nvvm.grid_constant" [[INPUT1:%.*]], ptr [[INOUT:%.*]]) #[[ATTR0]] {
; OPT-NEXT: [[TMP1:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT1]])
; OPT-NEXT: [[INPUT1_PARAM_GEN:%.*]] = addrspacecast ptr addrspace(101) [[TMP1]] to ptr
; OPT-NEXT: [[VAL:%.*]] = load i32, ptr [[INOUT]], align 4
@@ -391,7 +391,7 @@ merge:
}
; NOTE: %input2 is *not* grid_constant
-define ptx_kernel void @grid_const_phi_ngc(ptr byval(%struct.s) align 4 %input1, ptr byval(%struct.s) %input2, ptr %inout) {
+define ptx_kernel void @grid_const_phi_ngc(ptr byval(%struct.s) align 4 "nvvm.grid_constant" %input1, ptr byval(%struct.s) %input2, ptr %inout) {
; PTX-LABEL: grid_const_phi_ngc(
; PTX: {
; PTX-NEXT: .reg .pred %p<2>;
@@ -413,7 +413,7 @@ define ptx_kernel void @grid_const_phi_ngc(ptr byval(%struct.s) align 4 %input1,
; PTX-NEXT: st.global.b32 [%rd1], %r2;
; PTX-NEXT: ret;
; OPT-LABEL: define ptx_kernel void @grid_const_phi_ngc(
-; OPT-SAME: ptr byval([[STRUCT_S:%.*]]) align 4 [[INPUT1:%.*]], ptr byval([[STRUCT_S]]) [[INPUT2:%.*]], ptr [[INOUT:%.*]]) #[[ATTR0]] {
+; OPT-SAME: ptr byval([[STRUCT_S:%.*]]) align 4 "nvvm.grid_constant" [[INPUT1:%.*]], ptr byval([[STRUCT_S]]) [[INPUT2:%.*]], ptr [[INOUT:%.*]]) #[[ATTR0]] {
; OPT-NEXT: [[TMP1:%.*]] = call ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT2]])
; OPT-NEXT: [[INPUT2_PARAM_GEN:%.*]] = addrspacecast ptr addrspace(101) [[TMP1]] to ptr
; OPT-NEXT: [[TMP2:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT1]])
@@ -449,7 +449,7 @@ merge:
}
; NOTE: %input2 is *not* grid_constant
-define ptx_kernel void @grid_const_select(ptr byval(i32) align 4 %input1, ptr byval(i32) %input2, ptr %inout) {
+define ptx_kernel void @grid_const_select(ptr byval(i32) align 4 "nvvm.grid_constant" %input1, ptr byval(i32) %input2, ptr %inout) {
; PTX-LABEL: grid_const_select(
; PTX: {
; PTX-NEXT: .reg .pred %p<2>;
@@ -468,7 +468,7 @@ define ptx_kernel void @grid_const_select(ptr byval(i32) align 4 %input1, ptr by
; PTX-NEXT: st.global.b32 [%rd3], %r2;
; PTX-NEXT: ret;
; OPT-LABEL: define ptx_kernel void @grid_const_select(
-; OPT-SAME: ptr byval(i32) align 4 [[INPUT1:%.*]], ptr byval(i32) [[INPUT2:%.*]], ptr [[INOUT:%.*]]) #[[ATTR0]] {
+; OPT-SAME: ptr byval(i32) align 4 "nvvm.grid_constant" [[INPUT1:%.*]], ptr byval(i32) [[INPUT2:%.*]], ptr [[INOUT:%.*]]) #[[ATTR0]] {
; OPT-NEXT: [[TMP1:%.*]] = call ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT2]])
; OPT-NEXT: [[INPUT2_PARAM_GEN:%.*]] = addrspacecast ptr addrspace(101) [[TMP1]] to ptr
; OPT-NEXT: [[TMP2:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT1]])
@@ -487,7 +487,7 @@ define ptx_kernel void @grid_const_select(ptr byval(i32) align 4 %input1, ptr by
ret void
}
-define ptx_kernel i32 @grid_const_ptrtoint(ptr byval(i32) %input) {
+define ptx_kernel i32 @grid_const_ptrtoint(ptr byval(i32) "nvvm.grid_constant" %input) {
; PTX-LABEL: grid_const_ptrtoint(
; PTX: {
; PTX-NEXT: .reg .b32 %r<4>;
@@ -502,7 +502,7 @@ define ptx_kernel i32 @grid_const_ptrtoint(ptr byval(i32) %input) {
; PTX-NEXT: st.param.b32 [func_retval0], %r3;
; PTX-NEXT: ret;
; OPT-LABEL: define ptx_kernel i32 @grid_const_ptrtoint(
-; OPT-SAME: ptr byval(i32) align 4 [[INPUT:%.*]]) #[[ATTR0]] {
+; OPT-SAME: ptr byval(i32) align 4 "nvvm.grid_constant" [[INPUT:%.*]]) #[[ATTR0]] {
; OPT-NEXT: [[INPUT2:%.*]] = call ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT]])
; OPT-NEXT: [[INPUT3:%.*]] = load i32, ptr addrspace(101) [[INPUT2]], align 4
; OPT-NEXT: [[INPUT1:%.*]] = addrspacecast ptr addrspace(101) [[INPUT2]] to ptr
@@ -517,9 +517,9 @@ define ptx_kernel i32 @grid_const_ptrtoint(ptr byval(i32) %input) {
declare void @device_func(ptr byval(i32) align 4)
-define ptx_kernel void @test_forward_byval_arg(ptr byval(i32) align 4 %input) {
+define ptx_kernel void @test_forward_byval_arg(ptr byval(i32) align 4 "nvvm.grid_constant" %input) {
; OPT-LABEL: define ptx_kernel void @test_forward_byval_arg(
-; OPT-SAME: ptr byval(i32) align 4 [[INPUT:%.*]]) #[[ATTR0]] {
+; OPT-SAME: ptr byval(i32) align 4 "nvvm.grid_constant" [[INPUT:%.*]]) #[[ATTR0]] {
; OPT-NEXT: [[INPUT_PARAM:%.*]] = call align 4 ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr [[INPUT]])
; OPT-NEXT: [[INPUT_PARAM_GEN:%.*]] = addrspacecast ptr addrspace(101) [[INPUT_PARAM]] to ptr
; OPT-NEXT: call void @device_func(ptr byval(i32) align 4 [[INPUT_PARAM_GEN]])
@@ -545,45 +545,3 @@ define ptx_kernel void @test_forward_byval_arg(ptr byval(i32) align 4 %input) {
declare dso_local void @dummy() local_unnamed_addr
declare dso_local ptr @escape(ptr) local_unnamed_addr
declare dso_local ptr @escape3(ptr, ptr, ptr) local_unnamed_addr
-
-!nvvm.annotations = !{!0, !1, !2, !3, !4, !5, !6, !7, !8, !9, !10, !11, !12, !13, !14, !15, !16, !17, !18, !19, !20, !21, !22, !23, !24}
-
-!0 = !{ptr @grid_const_int, !"grid_constant", !1}
-!1 = !{i32 1}
-
-!2 = !{ptr @grid_const_struct, !"grid_constant", !3}
-!3 = !{i32 1}
-
-!4 = !{ptr @grid_const_escape, !"grid_constant", !5}
-!5 = !{i32 1}
-
-!6 = !{ptr @multiple_grid_const_escape, !"grid_constant", !7}
-!7 = !{i32 1, i32 3}
-
-!8 = !{ptr @grid_const_memory_escape, !"grid_constant", !9}
-!9 = !{i32 1}
-
-!10 = !{ptr @grid_const_inlineasm_escape, !"grid_constant", !11}
-!11 = !{i32 1}
-
-!12 = !{ptr @grid_const_partial_escape, !"grid_constant", !13}
-!13 = !{i32 1}
-
-!14 = !{ptr @grid_const_partial_escapemem, !"grid_constant", !15}
-!15 = !{i32 1}
-
-!16 = !{ptr @grid_const_phi, !"grid_constant", !17}
-!17 = !{i32 1}
-
-!18 = !{ptr @grid_const_phi_ngc, !"grid_constant", !19}
-!19 = !{i32 1}
-
-!20 = !{ptr @grid_const_select, !"grid_constant", !21}
-!21 = !{i32 1}
-
-!22 = !{ptr @grid_const_ptrtoint, !"grid_constant", !23}
-!23 = !{i32 1}
-
-!24 = !{ptr @test_forward_byval_arg, !"grid_constant", !25}
-!25 = !{i32 1}
-
diff --git a/llvm/test/CodeGen/NVPTX/upgrade-nvvm-annotations.ll b/llvm/test/CodeGen/NVPTX/upgrade-nvvm-annotations.ll
index 84c7a124a6f3e..80fd47f85795c 100644
--- a/llvm/test/CodeGen/NVPTX/upgrade-nvvm-annotations.ll
+++ b/llvm/test/CodeGen/NVPTX/upgrade-nvvm-annotations.ll
@@ -96,7 +96,15 @@ define void @test_cluster_dim() {
ret void
}
-!nvvm.annotations = !{!0, !1, !2, !3, !4, !5, !6, !7, !8, !9, !10, !11, !12}
+define void @test_grid_constant(ptr byval(i32) %input1, i32 %input2, ptr byval(i32) %input3) {
+; CHECK-LABEL: define void @test_grid_constant(
+; CHECK-SAME: ptr byval(i32) "nvvm.grid_constant" [[INPUT1:%.*]], i32 [[INPUT2:%.*]], ptr byval(i32) "nvvm.grid_constant" [[INPUT3:%.*]]) {
+; CHECK-NEXT: ret void
+;
+ ret void
+}
+
+!nvvm.annotations = !{!0, !1, !2, !3, !4, !5, !6, !7, !8, !9, !10, !11, !12, !13}
!0 = !{ptr @test_align, !"align", i32 u0x00000008, !"align", i32 u0x00010008, !"align", i32 u0x00020010}
!1 = !{null, !"align", i32 u0x00000008, !"align", i32 u0x00010008, !"align", i32 u0x00020008}
@@ -111,7 +119,8 @@ define void @test_cluster_dim() {
!10 = !{ptr @test_maxntid_4, !"maxntidz", i32 100}
!11 = !{ptr @test_reqntid, !"reqntidx", i32 31, !"reqntidy", i32 32, !"reqntidz", i32 33}
!12 = !{ptr @test_cluster_dim, !"cluster_dim_x", i32 101, !"cluster_dim_y", i32 102, !"cluster_dim_z", i32 103}
-
+!13 = !{ptr @test_grid_constant, !"grid_constant", !14}
+!14 = !{i32 1, i32 3}
;.
; CHECK: attributes #[[ATTR0]] = { "nvvm.maxclusterrank"="2" }
; CHECK: attributes #[[ATTR1]] = { "nvvm.maxclusterrank"="3" }
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index a20701ce75bc0..7f69af14df338 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -483,51 +483,10 @@ class NVVMDialectLLVMIRTranslationInterface
llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
llvm::Function *llvmFunc =
moduleTranslation.lookupFunction(funcOp.getName());
- llvm::NamedMDNode *nvvmAnnotations =
- moduleTranslation.getOrInsertNamedModuleMetadata("nvvm.annotations");
if (attribute.getName() == NVVM::NVVMDialect::getGridConstantAttrName()) {
- llvm::MDNode *gridConstantMetaData = nullptr;
-
- // Check if a 'grid_constant' metadata node exists for the given function
- for (llvm::MDNode *opnd : llvm::reverse(nvvmAnnotations->operands())) {
- if (opnd->getNumOperands() == 3 &&
- opnd->getOperand(0) == llvm::ValueAsMetadata::get(llvmFunc) &&
- opnd->getOperand(1) ==
- llvm::MDString::get(llvmContext, "grid_constant")) {
- gridConstantMetaData = opnd;
- break;
- }
- }
-
- // 'grid_constant' is a function-level meta data node with a list of
- // integers, where each integer n denotes that the nth parameter has the
- // grid_constant annotation (numbering from 1). This requires aggregating
- // the indices of the individual parameters that have this attribute.
- llvm::Type *i32 = llvm::IntegerType::get(llvmContext, 32);
- if (gridConstantMetaData == nullptr) {
- // Create a new 'grid_constant' metadata node
- SmallVector<llvm::Metadata *> gridConstMetadata = {
- llvm::ValueAsMetadata::getConstant(
- llvm::ConstantInt::get(i32, argIdx + 1))};
- llvm::Metadata *llvmMetadata[] = {
- llvm::ValueAsMetadata::get(llvmFunc),
- llvm::MDString::get(llvmContext, "grid_constant"),
- llvm::MDNode::get(llvmContext, gridConstMetadata)};
- llvm::MDNode *llvmMetadataNode =
- llvm::MDNode::get(llvmContext, llvmMetadata);
- nvvmAnnotations->addOperand(llvmMetadataNode);
- } else {
- // Append argIdx + 1 to the 'grid_constant' argument list
- if (auto argList =
- dyn_cast<llvm::MDTuple>(gridConstantMetaData->getOperand(2))) {
- llvm::TempMDTuple clonedArgList = argList->clone();
- clonedArgList->push_back((llvm::ValueAsMetadata::getConstant(
- llvm::ConstantInt::get(i32, argIdx + 1))));
- gridConstantMetaData->replaceOperandWith(
- 2, llvm::MDNode::replaceWithUniqued(std::move(clonedArgList)));
- }
- }
+ llvmFunc->addParamAttr(
+ argIdx, llvm::Attribute::get(llvmContext, "nvvm.grid_constant"));
}
return success();
}
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index fa7dd1daf96ed..62aeb071c5786 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -705,19 +705,13 @@ llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.blocksareclusters,
// CHECK: define ptx_kernel void @kernel_func() #[[ATTR0:[0-9]+]]
// CHECK: attributes #[[ATTR0]] = { "nvvm.blocksareclusters" "nvvm.cluster_dim"="3,5,7" "nvvm.reqntid"="1,23,32" }
// -----
-// CHECK: define ptx_kernel void @kernel_func
-// CHECK: !nvvm.annotations =
-// CHECK: !{{.*}} = !{ptr @kernel_func, !"grid_constant", ![[ID:[[:alnum:]]+]]}
-// CHECK: ![[ID]] = !{i32 1}
+// CHECK: define ptx_kernel void @kernel_func(ptr byval(i32) "nvvm.grid_constant" %0)
llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}) attributes {nvvm.kernel} {
llvm.return
}
// -----
-// CHECK: define ptx_kernel void @kernel_func
-// CHECK: !nvvm.annotations =
-// CHECK: !{{.*}} = !{ptr @kernel_func, !"grid_constant", ![[ID:[[:alnum:]]+]]}
-// CHECK: ![[ID]] = !{i32 1, i32 3}
+// CHECK: define ptx_kernel void @kernel_func(ptr byval(i32) "nvvm.grid_constant" %0, float %1, ptr byval(float) "nvvm.grid_constant" %2)
llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}, %arg1: f32, %arg2: !llvm.ptr {llvm.byval = f32, nvvm.grid_constant}) attributes {nvvm.kernel} {
llvm.return
}
More information about the Mlir-commits
mailing list