[llvm] [NVPTX] Cleanup/Refactoring in NVPTX AsmPrinter and RegisterInfo (NFC) (PR #126800)

Justin Fargnoli via llvm-commits llvm-commits at lists.llvm.org
Tue Feb 11 14:06:33 PST 2025


================
@@ -1464,161 +1415,143 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
 
   O << "(\n";
 
-  for (I = F->arg_begin(), E = F->arg_end(); I != E; ++I, paramIndex++) {
-    Type *Ty = I->getType();
+  for (const Argument &Arg : F->args()) {
+    Type *Ty = Arg.getType();
+    const std::string ParamSym = TLI->getParamName(F, Arg.getArgNo());
 
-    if (!first)
+    if (!IsFirst)
       O << ",\n";
 
-    first = false;
+    IsFirst = false;
 
     // Handle image/sampler parameters
-    if (isKernelFunc) {
-      if (isSampler(*I) || isImage(*I)) {
-        std::string ParamSym;
-        raw_string_ostream ParamStr(ParamSym);
-        ParamStr << F->getName() << "_param_" << paramIndex;
-        ParamStr.flush();
-        bool EmitImagePtr = !MFI || !MFI->checkImageHandleSymbol(ParamSym);
-        if (isImage(*I)) {
-          if (isImageWriteOnly(*I) || isImageReadWrite(*I)) {
-            if (EmitImagePtr)
-              O << "\t.param .u64 .ptr .surfref ";
-            else
-              O << "\t.param .surfref ";
-            O << TLI->getParamName(F, paramIndex);
-          }
-          else { // Default image is read_only
-            if (EmitImagePtr)
-              O << "\t.param .u64 .ptr .texref ";
-            else
-              O << "\t.param .texref ";
-            O << TLI->getParamName(F, paramIndex);
-          }
-        } else {
-          if (EmitImagePtr)
-            O << "\t.param .u64 .ptr .samplerref ";
-          else
-            O << "\t.param .samplerref ";
-          O << TLI->getParamName(F, paramIndex);
-        }
+    if (IsKernelFunc) {
+      const bool IsSampler = isSampler(Arg);
+      const bool IsTexture = !IsSampler && isImageReadOnly(Arg);
+      const bool IsSurface = !IsSampler && !IsTexture &&
+                             (isImageReadWrite(Arg) || isImageWriteOnly(Arg));
+      if (IsSampler || IsTexture || IsSurface) {
+        const bool EmitImgPtr = !MFI || !MFI->checkImageHandleSymbol(ParamSym);
+        O << "\t.param ";
+        if (EmitImgPtr)
+          O << ".u64 .ptr ";
+
+        if (IsSampler)
+          O << ".samplerref ";
+        else if (IsTexture)
+          O << ".texref ";
+        else // IsSurface
+          O << ".samplerref ";
+        O << ParamSym;
         continue;
       }
     }
 
-    auto getOptimalAlignForParam = [TLI, &DL, &PAL, F,
-                                    paramIndex](Type *Ty) -> Align {
+    auto GetOptimalAlignForParam = [TLI, &DL, F, &Arg](Type *Ty) -> Align {
       if (MaybeAlign StackAlign =
-              getAlign(*F, paramIndex + AttributeList::FirstArgIndex))
+              getAlign(*F, Arg.getArgNo() + AttributeList::FirstArgIndex))
         return StackAlign.value();
 
       Align TypeAlign = TLI->getFunctionParamOptimizedAlign(F, Ty, DL);
-      MaybeAlign ParamAlign = PAL.getParamAlignment(paramIndex);
+      MaybeAlign ParamAlign =
+          Arg.hasByValAttr() ? Arg.getParamAlign() : MaybeAlign();
       return std::max(TypeAlign, ParamAlign.valueOrOne());
     };
 
-    if (!PAL.hasParamAttr(paramIndex, Attribute::ByVal)) {
-      if (ShouldPassAsArray(Ty)) {
-        // Just print .param .align <a> .b8 .param[size];
-        // <a>  = optimal alignment for the element type; always multiple of
-        //        PAL.getParamAlignment
-        // size = typeallocsize of element type
-        Align OptimalAlign = getOptimalAlignForParam(Ty);
+    if (Arg.hasByValAttr()) {
+      // param has byVal attribute.
+      Type *ETy = Arg.getParamByValType();
+      assert(ETy && "Param should have byval type");
+
+      // Print .param .align <a> .b8 .param[size];
+      // <a>  = optimal alignment for the element type; always multiple of
+      //        PAL.getParamAlignment
+      // size = typeallocsize of element type
+      const Align OptimalAlign =
+          IsKernelFunc ? GetOptimalAlignForParam(ETy)
+                       : TLI->getFunctionByValParamAlign(
+                             F, ETy, Arg.getParamAlign().valueOrOne(), DL);
+
+      O << "\t.param .align " << OptimalAlign.value() << " .b8 " << ParamSym
+        << "[" << DL.getTypeAllocSize(ETy) << "]";
+      continue;
+    }
 
-        O << "\t.param .align " << OptimalAlign.value() << " .b8 ";
-        O << TLI->getParamName(F, paramIndex);
-        O << "[" << DL.getTypeAllocSize(Ty) << "]";
+    if (ShouldPassAsArray(Ty)) {
+      // Just print .param .align <a> .b8 .param[size];
+      // <a>  = optimal alignment for the element type; always multiple of
+      //        PAL.getParamAlignment
+      // size = typeallocsize of element type
+      Align OptimalAlign = GetOptimalAlignForParam(Ty);
 
-        continue;
-      }
-      // Just a scalar
-      auto *PTy = dyn_cast<PointerType>(Ty);
-      unsigned PTySizeInBits = 0;
-      if (PTy) {
-        PTySizeInBits =
-            TLI->getPointerTy(DL, PTy->getAddressSpace()).getSizeInBits();
-        assert(PTySizeInBits && "Invalid pointer size");
-      }
+      O << "\t.param .align " << OptimalAlign.value() << " .b8 " << ParamSym
+        << "[" << DL.getTypeAllocSize(Ty) << "]";
 
-      if (isKernelFunc) {
-        if (PTy) {
-          O << "\t.param .u" << PTySizeInBits << " .ptr";
-
-          switch (PTy->getAddressSpace()) {
-          default:
-            break;
-          case ADDRESS_SPACE_GLOBAL:
-            O << " .global";
-            break;
-          case ADDRESS_SPACE_SHARED:
-            O << " .shared";
-            break;
-          case ADDRESS_SPACE_CONST:
-            O << " .const";
-            break;
-          case ADDRESS_SPACE_LOCAL:
-            O << " .local";
-            break;
-          }
+      continue;
+    }
+    // Just a scalar
+    auto *PTy = dyn_cast<PointerType>(Ty);
+    unsigned PTySizeInBits = 0;
+    if (PTy) {
+      PTySizeInBits =
+          TLI->getPointerTy(DL, PTy->getAddressSpace()).getSizeInBits();
+      assert(PTySizeInBits && "Invalid pointer size");
+    }
 
-          O << " .align " << I->getParamAlign().valueOrOne().value();
-          O << " " << TLI->getParamName(F, paramIndex);
-          continue;
+    if (IsKernelFunc) {
+      if (PTy) {
----------------
justinfargnoli wrote:

Opportunity for improvement: make `if (PTy)` the outer condition, and merge it with the preceding `if (PTy)` block that computes `PTySizeInBits`.

https://github.com/llvm/llvm-project/pull/126800


More information about the llvm-commits mailing list