[llvm] Add support for "grid_constant" in NVPTXLowerArgs. (PR #96125)

via llvm-commits llvm-commits at lists.llvm.org
Wed Jun 19 17:34:10 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-nvptx

Author: Akshay Deodhar (akshayrdeodhar)

<details>
<summary>Changes</summary>

- Adds a helper function, `isParamGridConstant`, for checking whether a byval kernel argument is annotated as grid_constant.
- Adds support for cvta.param (exposed as the `int_nvvm_ptr_param_to_gen` intrinsic) using changes from https://github.com/llvm/llvm-project/pull/95289
- Supports grid_constant pointers that may escape conservatively: instead of creating a temporary local copy, all uses of the argument are replaced with a pointer cast to the generic address space with cvta.param.

---

Patch is 22.20 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/96125.diff


6 Files Affected:

- (modified) llvm/include/llvm/IR/IntrinsicsNVVM.td (+5) 
- (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+1) 
- (modified) llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp (+52-21) 
- (modified) llvm/lib/Target/NVPTX/NVPTXUtilities.cpp (+78-65) 
- (modified) llvm/lib/Target/NVPTX/NVPTXUtilities.h (+1) 
- (added) llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll (+155) 


``````````diff
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 0a9139e0062ba..b7c828566e375 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1596,6 +1596,11 @@ def int_nvvm_ptr_gen_to_param: Intrinsic<[llvm_anyptr_ty],
                                    [IntrNoMem, IntrSpeculatable, IntrNoCallback],
                                    "llvm.nvvm.ptr.gen.to.param">;
 
+// sm70+, PTX7.7+
+def int_nvvm_ptr_param_to_gen: DefaultAttrsIntrinsic<[llvm_anyptr_ty],
+                                     [llvm_anyptr_ty],
+                                   [IntrNoMem, IntrSpeculatable, IntrNoCallback]>;
+
 // Move intrinsics, used in nvvm internally
 
 def int_nvvm_move_i16 : Intrinsic<[llvm_i16_ty], [llvm_i16_ty], [IntrNoMem],
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index a65170e56aa24..3e7f8d63439c8 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -2475,6 +2475,7 @@ defm cvta_local  : NG_TO_G<"local", int_nvvm_ptr_local_to_gen, useShortPtrLocal>
 defm cvta_shared : NG_TO_G<"shared", int_nvvm_ptr_shared_to_gen, useShortPtrShared>;
 defm cvta_global : NG_TO_G<"global", int_nvvm_ptr_global_to_gen, False>;
 defm cvta_const  : NG_TO_G<"const", int_nvvm_ptr_constant_to_gen, useShortPtrConst>;
+defm cvta_param : NG_TO_G<"param", int_nvvm_ptr_param_to_gen, False>;
 
 defm cvta_to_local  : G_TO_NG<"local", int_nvvm_ptr_gen_to_local, useShortPtrLocal>;
 defm cvta_to_shared : G_TO_NG<"shared", int_nvvm_ptr_gen_to_shared, useShortPtrShared>;
diff --git a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
index cde02c25c4834..1116b8e6313f7 100644
--- a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
@@ -95,7 +95,9 @@
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/CodeGen/TargetPassConfig.h"
 #include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicsNVPTX.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Type.h"
 #include "llvm/InitializePasses.h"
@@ -336,8 +338,9 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
     while (!ValuesToCheck.empty()) {
       Value *V = ValuesToCheck.pop_back_val();
       if (!IsALoadChainInstr(V)) {
-        LLVM_DEBUG(dbgs() << "Need a copy of " << *Arg << " because of " << *V
-                          << "\n");
+        LLVM_DEBUG(dbgs() << "Need a "
+                          << (isParamGridConstant(*Arg) ? "cast " : "copy ")
+                          << "of " << *Arg << " because of " << *V << "\n");
         (void)Arg;
         return false;
       }
@@ -366,27 +369,55 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
     return;
   }
 
-  // Otherwise we have to create a temporary copy.
   const DataLayout &DL = Func->getParent()->getDataLayout();
   unsigned AS = DL.getAllocaAddrSpace();
-  AllocaInst *AllocA = new AllocaInst(StructType, AS, Arg->getName(), FirstInst);
-  // Set the alignment to alignment of the byval parameter. This is because,
-  // later load/stores assume that alignment, and we are going to replace
-  // the use of the byval parameter with this alloca instruction.
-  AllocA->setAlignment(Func->getParamAlign(Arg->getArgNo())
-                           .value_or(DL.getPrefTypeAlign(StructType)));
-  Arg->replaceAllUsesWith(AllocA);
-
-  Value *ArgInParam = new AddrSpaceCastInst(
-      Arg, PointerType::get(StructType, ADDRESS_SPACE_PARAM), Arg->getName(),
-      FirstInst);
-  // Be sure to propagate alignment to this load; LLVM doesn't know that NVPTX
-  // addrspacecast preserves alignment.  Since params are constant, this load is
-  // definitely not volatile.
-  LoadInst *LI =
-      new LoadInst(StructType, ArgInParam, Arg->getName(),
-                   /*isVolatile=*/false, AllocA->getAlign(), FirstInst);
-  new StoreInst(LI, AllocA, FirstInst);
+  if (isParamGridConstant(*Arg)) {
+    // Writes to a grid constant are undefined behaviour. We do not need a
+    // temporary copy. When a pointer might have escaped, conservatively replace
+    // all of its uses (which might include a device function call) with a cast
+    // to the generic address space.
+    // TODO: only cast byval grid constant parameters at use points that need
+    // generic address (e.g., merging parameter pointers with other address
+    // space, or escaping to call-sites, inline-asm, memory), and use the
+    // parameter address space for normal loads.
+    IRBuilder<> IRB(&Func->getEntryBlock().front());
+
+    // Cast argument to param address space
+    AddrSpaceCastInst *CastToParam =
+        cast<AddrSpaceCastInst>(IRB.CreateAddrSpaceCast(
+            Arg, IRB.getPtrTy(ADDRESS_SPACE_PARAM), Arg->getName() + ".param"));
+
+    // Cast param address to generic address space
+    Value *CvtToGenCall = IRB.CreateIntrinsic(
+        IRB.getPtrTy(ADDRESS_SPACE_GENERIC), Intrinsic::nvvm_ptr_param_to_gen,
+        CastToParam, nullptr, CastToParam->getName() + ".gen");
+
+    Arg->replaceAllUsesWith(CvtToGenCall);
+
+    // Do not replace Arg in the cast to param space
+    CastToParam->setOperand(0, Arg);
+  } else {
+    // Otherwise we have to create a temporary copy.
+    AllocaInst *AllocA =
+        new AllocaInst(StructType, AS, Arg->getName(), FirstInst);
+    // Set the alignment to alignment of the byval parameter. This is because,
+    // later load/stores assume that alignment, and we are going to replace
+    // the use of the byval parameter with this alloca instruction.
+    AllocA->setAlignment(Func->getParamAlign(Arg->getArgNo())
+                             .value_or(DL.getPrefTypeAlign(StructType)));
+    Arg->replaceAllUsesWith(AllocA);
+
+    Value *ArgInParam = new AddrSpaceCastInst(
+        Arg, PointerType::get(Arg->getContext(), ADDRESS_SPACE_PARAM),
+        Arg->getName(), FirstInst);
+    // Be sure to propagate alignment to this load; LLVM doesn't know that NVPTX
+    // addrspacecast preserves alignment.  Since params are constant, this load
+    // is definitely not volatile.
+    LoadInst *LI =
+        new LoadInst(StructType, ArgInParam, Arg->getName(),
+                     /*isVolatile=*/false, AllocA->getAlign(), FirstInst);
+    new StoreInst(LI, AllocA, FirstInst);
+  }
 }
 
 void NVPTXLowerArgs::markPointerAsGlobal(Value *Ptr) {
diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
index 3a536db1c9727..96db2079ed59f 100644
--- a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp
@@ -52,29 +52,45 @@ void clearAnnotationCache(const Module *Mod) {
   AC.Cache.erase(Mod);
 }
 
-static void cacheAnnotationFromMD(const MDNode *md, key_val_pair_t &retval) {
+static void readIntVecFromMDNode(const MDNode *MetadataNode,
+                                 std::vector<unsigned> &Vec) {
+  for (unsigned i = 0, e = MetadataNode->getNumOperands(); i != e; ++i) {
+    ConstantInt *Val =
+        mdconst::extract<ConstantInt>(MetadataNode->getOperand(i));
+    Vec.push_back(Val->getZExtValue());
+  }
+}
+
+static void cacheAnnotationFromMD(const MDNode *MetadataNode,
+                                  key_val_pair_t &retval) {
   auto &AC = getAnnotationCache();
   std::lock_guard<sys::Mutex> Guard(AC.Lock);
-  assert(md && "Invalid mdnode for annotation");
-  assert((md->getNumOperands() % 2) == 1 && "Invalid number of operands");
+  assert(MetadataNode && "Invalid mdnode for annotation");
+  assert((MetadataNode->getNumOperands() % 2) == 1 &&
+         "Invalid number of operands");
   // start index = 1, to skip the global variable key
   // increment = 2, to skip the value for each property-value pairs
-  for (unsigned i = 1, e = md->getNumOperands(); i != e; i += 2) {
+  for (unsigned i = 1, e = MetadataNode->getNumOperands(); i != e; i += 2) {
     // property
-    const MDString *prop = dyn_cast<MDString>(md->getOperand(i));
+    const MDString *prop = dyn_cast<MDString>(MetadataNode->getOperand(i));
     assert(prop && "Annotation property not a string");
+    std::string Key = prop->getString().str();
 
     // value
-    ConstantInt *Val = mdconst::dyn_extract<ConstantInt>(md->getOperand(i + 1));
-    assert(Val && "Value operand not a constant int");
-
-    std::string keyname = prop->getString().str();
-    if (retval.find(keyname) != retval.end())
-      retval[keyname].push_back(Val->getZExtValue());
-    else {
-      std::vector<unsigned> tmp;
-      tmp.push_back(Val->getZExtValue());
-      retval[keyname] = tmp;
+    if (ConstantInt *Val = mdconst::dyn_extract<ConstantInt>(
+            MetadataNode->getOperand(i + 1))) {
+      retval[Key].push_back(Val->getZExtValue());
+    } else if (MDNode *VecMd =
+                   dyn_cast<MDNode>(MetadataNode->getOperand(i + 1))) {
+      // assert: there can only exist one unique key value pair of
+      // the form (string key, MDNode node). Operands of such a node
+      // shall always be unsigned ints.
+      if (retval.find(Key) == retval.end()) {
+        readIntVecFromMDNode(VecMd, retval[Key]);
+        continue;
+      }
+    } else {
+      llvm_unreachable("Value operand not a constant int or an mdnode");
     }
   }
 }
@@ -153,9 +169,9 @@ bool findAllNVVMAnnotation(const GlobalValue *gv, const std::string &prop,
 
 bool isTexture(const Value &val) {
   if (const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
-    unsigned annot;
-    if (findOneNVVMAnnotation(gv, "texture", annot)) {
-      assert((annot == 1) && "Unexpected annotation on a texture symbol");
+    unsigned Annot;
+    if (findOneNVVMAnnotation(gv, "texture", Annot)) {
+      assert((Annot == 1) && "Unexpected annotation on a texture symbol");
       return true;
     }
   }
@@ -164,70 +180,68 @@ bool isTexture(const Value &val) {
 
 bool isSurface(const Value &val) {
   if (const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
-    unsigned annot;
-    if (findOneNVVMAnnotation(gv, "surface", annot)) {
-      assert((annot == 1) && "Unexpected annotation on a surface symbol");
+    unsigned Annot;
+    if (findOneNVVMAnnotation(gv, "surface", Annot)) {
+      assert((Annot == 1) && "Unexpected annotation on a surface symbol");
       return true;
     }
   }
   return false;
 }
 
-bool isSampler(const Value &val) {
-  const char *AnnotationName = "sampler";
-
-  if (const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
-    unsigned annot;
-    if (findOneNVVMAnnotation(gv, AnnotationName, annot)) {
-      assert((annot == 1) && "Unexpected annotation on a sampler symbol");
-      return true;
-    }
-  }
-  if (const Argument *arg = dyn_cast<Argument>(&val)) {
-    const Function *func = arg->getParent();
-    std::vector<unsigned> annot;
-    if (findAllNVVMAnnotation(func, AnnotationName, annot)) {
-      if (is_contained(annot, arg->getArgNo()))
+static bool argHasNVVMAnnotation(const Value &Val,
+                                 const std::string &Annotation,
+                                 const bool StartArgIndexAtOne = false) {
+  if (const Argument *Arg = dyn_cast<Argument>(&Val)) {
+    const Function *Func = Arg->getParent();
+    std::vector<unsigned> Annot;
+    if (findAllNVVMAnnotation(Func, Annotation, Annot)) {
+      const unsigned BaseOffset = StartArgIndexAtOne ? 1 : 0;
+      if (is_contained(Annot, BaseOffset + Arg->getArgNo())) {
         return true;
+      }
     }
   }
   return false;
 }
 
-bool isImageReadOnly(const Value &val) {
-  if (const Argument *arg = dyn_cast<Argument>(&val)) {
-    const Function *func = arg->getParent();
-    std::vector<unsigned> annot;
-    if (findAllNVVMAnnotation(func, "rdoimage", annot)) {
-      if (is_contained(annot, arg->getArgNo()))
-        return true;
+bool isParamGridConstant(const Value &V) {
+  if (const Argument *Arg = dyn_cast<Argument>(&V)) {
+    std::vector<unsigned> Annot;
+    // "grid_constant" counts argument indices starting from 1
+    if (Arg->hasByValAttr() &&
+        argHasNVVMAnnotation(*Arg, "grid_constant", true)) {
+      assert(isKernelFunction(*Arg->getParent()) &&
+             "only kernel arguments can be grid_constant");
+      return true;
     }
   }
   return false;
 }
 
-bool isImageWriteOnly(const Value &val) {
-  if (const Argument *arg = dyn_cast<Argument>(&val)) {
-    const Function *func = arg->getParent();
-    std::vector<unsigned> annot;
-    if (findAllNVVMAnnotation(func, "wroimage", annot)) {
-      if (is_contained(annot, arg->getArgNo()))
-        return true;
+bool isSampler(const Value &val) {
+  const char *AnnotationName = "sampler";
+
+  if (const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
+    unsigned Annot;
+    if (findOneNVVMAnnotation(gv, AnnotationName, Annot)) {
+      assert((Annot == 1) && "Unexpected annotation on a sampler symbol");
+      return true;
     }
   }
-  return false;
+  return argHasNVVMAnnotation(val, AnnotationName);
+}
+
+bool isImageReadOnly(const Value &val) {
+  return argHasNVVMAnnotation(val, "rdoimage");
+}
+
+bool isImageWriteOnly(const Value &val) {
+  return argHasNVVMAnnotation(val, "wroimage");
 }
 
 bool isImageReadWrite(const Value &val) {
-  if (const Argument *arg = dyn_cast<Argument>(&val)) {
-    const Function *func = arg->getParent();
-    std::vector<unsigned> annot;
-    if (findAllNVVMAnnotation(func, "rdwrimage", annot)) {
-      if (is_contained(annot, arg->getArgNo()))
-        return true;
-    }
-  }
-  return false;
+  return argHasNVVMAnnotation(val, "rdwrimage");
 }
 
 bool isImage(const Value &val) {
@@ -236,9 +250,9 @@ bool isImage(const Value &val) {
 
 bool isManaged(const Value &val) {
   if(const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
-    unsigned annot;
-    if (findOneNVVMAnnotation(gv, "managed", annot)) {
-      assert((annot == 1) && "Unexpected annotation on a managed symbol");
+    unsigned Annot;
+    if (findOneNVVMAnnotation(gv, "managed", Annot)) {
+      assert((Annot == 1) && "Unexpected annotation on a managed symbol");
       return true;
     }
   }
@@ -323,8 +337,7 @@ bool getMaxNReg(const Function &F, unsigned &x) {
 
 bool isKernelFunction(const Function &F) {
   unsigned x = 0;
-  bool retval = findOneNVVMAnnotation(&F, "kernel", x);
-  if (!retval) {
+  if (!findOneNVVMAnnotation(&F, "kernel", x)) {
     // There is no NVVM metadata, check the calling convention
     return F.getCallingConv() == CallingConv::PTX_Kernel;
   }
diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.h b/llvm/lib/Target/NVPTX/NVPTXUtilities.h
index e020bc0f02e96..c15ff6cae1f27 100644
--- a/llvm/lib/Target/NVPTX/NVPTXUtilities.h
+++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.h
@@ -62,6 +62,7 @@ bool getMaxClusterRank(const Function &, unsigned &);
 bool getMinCTASm(const Function &, unsigned &);
 bool getMaxNReg(const Function &, unsigned &);
 bool isKernelFunction(const Function &);
+bool isParamGridConstant(const Value &);
 
 MaybeAlign getAlign(const Function &, unsigned);
 MaybeAlign getAlign(const CallInst &, unsigned);
diff --git a/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll b/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll
new file mode 100644
index 0000000000000..46f54e0e6f4d4
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll
@@ -0,0 +1,155 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -S -nvptx-lower-args --mtriple nvptx64-nvidia-cuda | FileCheck %s --check-prefixes OPT
+; RUN: llc < %s -mcpu=sm_70 --mtriple nvptx64-nvidia-cuda | FileCheck %s --check-prefixes PTX
+
+define void @grid_const_int(ptr byval(i32) align 4 %input1, i32 %input2, ptr %out, i32 %n) {
+; PTX-LABEL: grid_const_int(
+; PTX-NOT:     ld.u32
+; PTX:         ld.param.{{.*}} [[R2:%.*]], [grid_const_int_param_0];
+; 
+; OPT-LABEL: define void @grid_const_int(
+; OPT-SAME: ptr byval(i32) align 4 [[INPUT1:%.*]], i32 [[INPUT2:%.*]], ptr [[OUT:%.*]], i32 [[N:%.*]]) {
+; OPT-NOT:     alloca
+; OPT:         [[INPUT11:%.*]] = addrspacecast ptr [[INPUT1]] to ptr addrspace(101)
+; OPT:         [[TMP:%.*]] = load i32, ptr addrspace(101) [[INPUT11]], align 4
+;
+  %tmp = load i32, ptr %input1, align 4
+  %add = add i32 %tmp, %input2
+  store i32 %add, ptr %out
+  ret void
+}
+
+%struct.s = type { i32, i32 }
+
+define void @grid_const_struct(ptr byval(%struct.s) align 4 %input, ptr %out){
+; PTX-LABEL: grid_const_struct(
+; PTX:       {
+; PTX-NOT:     ld.u32
+; PTX:         ld.param.{{.*}} [[R1:%.*]], [grid_const_struct_param_0];
+; PTX:         ld.param.{{.*}} [[R2:%.*]], [grid_const_struct_param_0+4];
+;
+; OPT-LABEL: define void @grid_const_struct(
+; OPT-SAME: ptr byval([[STRUCT_S:%.*]]) align 4 [[INPUT:%.*]], ptr [[OUT:%.*]]) {
+; OPT-NOT:     alloca
+; OPT:         [[INPUT1:%.*]] = addrspacecast ptr [[INPUT]] to ptr addrspace(101)
+; OPT:         [[GEP13:%.*]] = getelementptr inbounds [[STRUCT_S]], ptr addrspace(101) [[INPUT1]], i32 0, i32 0
+; OPT:         [[GEP22:%.*]] = getelementptr inbounds [[STRUCT_S]], ptr addrspace(101) [[INPUT1]], i32 0, i32 1
+; OPT:         [[TMP1:%.*]] = load i32, ptr addrspace(101) [[GEP13]], align 4
+; OPT:         [[TMP2:%.*]] = load i32, ptr addrspace(101) [[GEP22]], align 4
+;
+  %gep1 = getelementptr inbounds %struct.s, ptr %input, i32 0, i32 0
+  %gep2 = getelementptr inbounds %struct.s, ptr %input, i32 0, i32 1
+  %int1 = load i32, ptr %gep1
+  %int2 = load i32, ptr %gep2
+  %add = add i32 %int1, %int2
+  store i32 %add, ptr %out
+  ret void
+}
+
+define void @grid_const_escape(ptr byval(%struct.s) align 4 %input) {
+; PTX-LABEL: grid_const_escape(
+; PTX:       {
+; PTX-NOT:     .local
+; PTX:         cvta.param.{{.*}}
+; OPT-LABEL: define void @grid_const_escape(
+; OPT-SAME: ptr byval([[STRUCT_S:%.*]]) align 4 [[INPUT:%.*]]) {
+; OPT-NOT:     alloca [[STRUCT_S]]
+; OPT:         [[INPUT_PARAM:%.*]] = addrspacecast ptr [[INPUT]] to ptr addrspace(101)
+; OPT:         [[INPUT_PARAM_GEN:%.*]] = call ptr @llvm.nvvm.ptr.param.to.gen.p0.p101(ptr addrspace(101) [[INPUT_PARAM]])
+; OPT:         [[CALL:%.*]] = call i32 @escape(ptr [[INPUT_PARAM_GEN]])
+;
+  %call = call i32 @escape(ptr %input)
+  ret void
+}
+
+define void @multiple_grid_const_escape(ptr byval(%struct.s) align 4 %input, i32 %a, ptr byval(i32) align 4 %b) {
+; PTX-LABEL: multiple_grid_const_escape(
+; PTX:         mov.{{.*}} [[RD1:%.*]], multiple_grid_const_escape_param_0;
+; PTX:         mov.{{.*}} [[RD2:%.*]], multiple_grid_const_escape_param_2;
+; PTX:         mov.{{.*}} [[RD3:%.*]], [[RD2]];
+; PTX:         cvta.param.{{.*}} [[RD4:%.*]], [[RD3]];
+; PTX:         mov.u64 [[RD5:%.*]], [[RD1]];
+; PTX:         cvta.param.{{.*}} [[RD6:%.*]], [[RD5]];
+; PTX:         {
+; PTX:         st.param.b64 [param0+0], [[RD6]];
+; PTX:         st.param.b64 [param2+0], [[RD4]];
+;
+; OPT-LABEL: define void @multiple_grid_const_escape(
+; OPT-SAME: ptr byval([[STRUCT_S:%.*]]) align 4 [[INPUT:%.*]], i32 [[A:%.*]], ptr byval(i32) align 4 [[B:%.*]]) {
+; OPT-NOT:     alloca i32
+; OPT:         [[B_PARAM:%.*]] = addrspacecast ptr [[B]] to ptr addrspace(101)
+; OPT:         [[B_PARAM_GEN:%.*]] = call ptr @llvm.nvvm.ptr.param.to.gen.p0.p101(ptr addrspace(101) [[B_PARAM]])
+; OPT-NOT:     alloca [[STRUCT_S]]
+; OPT:         [[INPUT_PARAM:%.*]] = addrspacecast ptr [[INPUT]] to ptr addrspace(101)
+; OPT:         [[INPUT_PARAM_GEN:%.*]] = call ptr @llvm.nvvm.ptr.param.to.gen.p0.p101(ptr addrspace(101) [[INPUT_PARAM]])
+; OPT:         [[CALL:%.*]] = call i32 @escape3(ptr [[INPUT_PARAM_GEN]], ptr {{.*}}, ptr [[B_PARAM_GEN]])
+;
+  %a.addr = alloca i32, align 4
+  store i32 %a, ptr %a.addr, align 4
+  %call = call i32 @escape3(ptr %input, ptr %a.addr, ptr %b)
+  ret void
+}
+
+define void @grid_const_memory_escape(ptr byval(%struct.s) align 4 %input, ptr %addr) {
+; PTX-LABEL: grid_const_memory_escape(
+; PTX-NOT:     .local
+; PTX:         mov.b64 [[RD1:%.*]], grid_const_memory_escape_param_0;
+; PTX:         cvta.param.u64 [[RD3:%.*]], [[RD2:%.*]];
+; PTX:         st.global.u64 [[[RD4:%.*]]], [[RD3]];
+;
+; OPT-LABEL: define void @grid_const_memory_escape(
+; OPT-SAME: ptr byval([[STRUCT_S:%.*]]) align 4 [[INPUT:%.*]], ptr [[ADDR:%.*]]) {
+; OPT-NOT:     alloca [[STRUCT_S]]
+; OPT:         [[INPUT_PARAM:%.*]] = addrspacecast ptr [[INPUT]] to ptr addrspace(101)
+; OPT:         [[INPUT_PARAM_GEN:%....
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/96125


More information about the llvm-commits mailing list