[llvm] [AMDGPU] Split struct kernel arguments (PR #133786)

Shilei Tian via llvm-commits llvm-commits at lists.llvm.org
Fri Jan 23 10:07:11 PST 2026


================
@@ -42,8 +51,494 @@ static cl::opt<bool>
                          cl::desc("Enable preload kernel arguments to SGPRs"),
                          cl::init(true));
 
+// Opt-in switch (off by default): lets byref-struct splitting produce a
+// kernarg segment layout that differs from the original struct layout.
+static cl::opt<bool> EnableKernargLayoutChange(
+    "amdgpu-kernarg-layout-change", cl::init(false),
+    cl::desc("Allow changing kernel argument segment layout when splitting "
+             "byref structs (remove unused fields, reorder for packing). "
+             "When disabled (default), all struct fields are preserved in "
+             "their original order."));
+
 namespace {
 
+//===----------------------------------------------------------------------===//
+// Kernel Argument Splitting Logic
+//
+// The following functions handle splitting of byref struct kernel arguments
+// into scalar arguments. This enables preloading of struct fields that would
+// otherwise not be preloadable due to the byref attribute.
+//===----------------------------------------------------------------------===//
+
+// String attribute placed on a split argument to record its origin. Its value
+// has the form "<original arg index>:<byte offset>", the format read back by
+// parseOriginalArgAttribute.
+static constexpr StringRef OriginalArgAttr{"amdgpu-original-arg"};
+
+// A backup declaration of the original (unsplit) kernel is kept for metadata
+// generation. OriginalKernelAttr is the attribute storing that declaration's
+// name; OriginalKernelPrefix is the prefix used for that name.
+static constexpr StringRef OriginalKernelAttr{"amdgpu-original-kernel"};
+static constexpr StringRef OriginalKernelPrefix{"__amdgpu_orig_kernel_"};
+
+/// Parse an OriginalArgAttr value of the form "<RootIdx>:<BaseOff>", where
+/// both components are base-10 unsigned integers.
+///
+/// \param S        the attribute string.
+/// \param RootIdx  receives the original argument index on success.
+/// \param BaseOff  receives the byte offset within that argument on success.
+/// \returns true on success. On failure (no ':', an empty or malformed
+///          component, or a value out of range) BOTH outputs are left
+///          untouched, so a caller that pre-initializes them never observes a
+///          half-updated index/offset pair.
+static bool parseOriginalArgAttribute(StringRef S, unsigned &RootIdx,
+                                      uint64_t &BaseOff) {
+  auto [IdxStr, OffStr] = S.split(':');
+  unsigned ParsedIdx = 0;
+  uint64_t ParsedOff = 0;
+  // Note: StringRef::getAsInteger returns true on *failure*.
+  if (OffStr.empty() || IdxStr.getAsInteger(10, ParsedIdx) ||
+      OffStr.getAsInteger(10, ParsedOff))
+    return false;
+  // Commit only once both components parsed successfully.
+  RootIdx = ParsedIdx;
+  BaseOff = ParsedOff;
+  return true;
+}
+
+/// Decide whether the byref argument \p Arg can be split, collecting the
+/// instructions that use it.
+///
+/// \p Arg qualifies only if each of its users is either
+///   * a simple (non-volatile, non-atomic) load of \p Arg itself, or
+///   * a GEP based directly on \p Arg whose users are all such simple loads,
+/// and every load address strips back to \p Arg through inbounds GEPs with
+/// constant offsets.
+///
+/// Nested GEP chains (a GEP of a GEP of \p Arg) are rejected on purpose.
+/// NOTE(review): the per-load offset computation in trySplitKernelArguments
+/// only looks through the load's immediate GEP (via
+/// GetElementPtrInst::accumulateConstantOffset). For a chain it would produce
+/// an offset relative to the inner GEP rather than to \p Arg, and so match
+/// the wrong field. Relax this only together with that computation.
+/// Volatile/atomic loads are rejected so that their access semantics are not
+/// lost when the load is replaced by a plain argument value.
+///
+/// \param Arg    the candidate argument.
+/// \param Loads  receives every accepted load.
+/// \param GEPs   receives every GEP based on \p Arg.
+/// \returns true if \p Arg is suitable for splitting. On false, \p Loads and
+///          \p GEPs may hold a partial result and should be discarded.
+static bool
+areArgUsersValidForSplit(Argument &Arg, SmallVectorImpl<LoadInst *> &Loads,
+                         SmallVectorImpl<GetElementPtrInst *> &GEPs) {
+  // Record U if it is a simple load; report whether it was one.
+  auto TryAddSimpleLoad = [&Loads](User *U) {
+    auto *LI = dyn_cast<LoadInst>(U);
+    if (!LI || !LI->isSimple())
+      return false;
+    Loads.push_back(LI);
+    return true;
+  };
+
+  // A load has a single operand and a GEP uses Arg only as its base pointer,
+  // so no instruction can be reached twice here; no visited set is needed.
+  for (User *U : Arg.users()) {
+    if (TryAddSimpleLoad(U))
+      continue;
+
+    auto *GEP = dyn_cast<GetElementPtrInst>(U);
+    if (!GEP)
+      return false; // Non-simple load, store, call, cast, ...
+    GEPs.push_back(GEP);
+    for (User *GEPUser : GEP->users())
+      if (!TryAddSimpleLoad(GEPUser))
+        return false; // Includes nested GEPs; see the note above.
+  }
+
+  const DataLayout &DL = Arg.getParent()->getParent()->getDataLayout();
+  // The accumulated offset must be as wide as the index type of the pointer
+  // being stripped, i.e. of Arg's own address space.
+  const unsigned IndexWidth = DL.getIndexTypeSizeInBits(Arg.getType());
+  for (const LoadInst *LI : Loads) {
+    APInt Offset(IndexWidth, 0);
+    const Value *Base =
+        LI->getPointerOperand()->stripAndAccumulateConstantOffsets(
+            DL, Offset, /*AllowNonInbounds=*/false);
+    // Stopping anywhere but Arg means a non-inbounds or non-constant GEP.
+    if (Base != &Arg)
+      return false;
+  }
+
+  return true;
+}
+
+/// One leaf field of a byref struct argument, to be flattened into its own
+/// scalar kernel argument. Members are default-initialized so a FieldInfo is
+/// never left with indeterminate values. Brace aggregate initialization such
+/// as {Ty, Offset, nullptr} remains valid.
+struct FieldInfo {
+  /// Type of the leaf (scalar, vector, pointer, ...).
+  Type *Ty = nullptr;
+  /// Byte offset of the leaf, as accumulated by collectLeafFields from its
+  /// BaseOffset argument.
+  uint64_t Offset = 0;
+  /// Load reading this field, or nullptr if the field is unused.
+  LoadInst *Load = nullptr;
+};
+
+/// Flatten \p Ty into its leaf fields, appending one FieldInfo per leaf to
+/// \p Fields. Struct members are visited in declaration order and array
+/// elements by increasing index. Both kinds of aggregate are descended into
+/// recursively; any other type (scalar, vector, pointer, ...) is a leaf.
+/// A leaf's offset is \p BaseOffset plus its byte position inside \p Ty, taken
+/// from the struct layout for members and from the element alloc size for
+/// array elements. Leaves are recorded with no associated load.
+static void collectLeafFields(Type *Ty, const DataLayout &DL,
+                              uint64_t BaseOffset,
+                              SmallVectorImpl<FieldInfo> &Fields) {
+  if (auto *StructTy = dyn_cast<StructType>(Ty)) {
+    const StructLayout *Layout = DL.getStructLayout(StructTy);
+    unsigned Idx = 0;
+    for (Type *Member : StructTy->elements()) {
+      uint64_t MemberOffset = BaseOffset + Layout->getElementOffset(Idx++);
+      collectLeafFields(Member, DL, MemberOffset, Fields);
+    }
+    return;
+  }
+
+  if (auto *ArrTy = dyn_cast<ArrayType>(Ty)) {
+    Type *Elt = ArrTy->getElementType();
+    uint64_t Stride = DL.getTypeAllocSize(Elt);
+    uint64_t EltOffset = BaseOffset;
+    for (uint64_t Remaining = ArrTy->getNumElements(); Remaining != 0;
+         --Remaining, EltOffset += Stride)
+      collectLeafFields(Elt, DL, EltOffset, Fields);
+    return;
+  }
+
+  // Anything that is neither a struct nor an array is a leaf.
+  Fields.push_back(FieldInfo{Ty, BaseOffset, nullptr});
+}
+
+/// Check whether the split arguments could be preloaded into SGPRs.
+///
+/// Simulates the explicit kernarg layout after splitting by walking \p F's
+/// arguments in order:
+///  * An argument found in \p ArgToFieldsMap contributes one entry per
+///    FieldInfo, aligned to that field type's ABI alignment.
+///  * Any other argument contributes its own type, or the byref pointee type
+///    for byref arguments.
+/// For an unsplit byref argument, an explicit `align` parameter attribute is
+/// honoured when present. The pointee's ABI alignment is only the fallback;
+/// ignoring the attribute could underestimate the layout size.
+///
+/// \param F              the kernel being split.
+/// \param ST             subtarget used to count free user SGPRs.
+/// \param ArgToFieldsMap arguments that will be split, with their leaf fields.
+/// \returns true if the simulated size fits in the free user SGPRs
+///          (4 bytes each).
+static bool canPreloadSplitArgs(
+    Function &F, const GCNSubtarget &ST,
+    const DenseMap<Argument *, SmallVector<FieldInfo, 8>> &ArgToFieldsMap) {
+  GCNUserSGPRUsageInfo UserSGPRInfo(F, ST);
+  // Widen before multiplying: each user SGPR holds 4 bytes.
+  const uint64_t AvailableBytes =
+      uint64_t(UserSGPRInfo.getNumFreeUserSGPRs()) * 4;
+
+  const DataLayout &DL = F.getParent()->getDataLayout();
+  uint64_t NewArgOffset = 0;
+  // Place one argument of type Ty at the next offset aligned to Alignment.
+  auto Append = [&](Type *Ty, Align Alignment) {
+    uint64_t AllocSize = DL.getTypeAllocSize(Ty);
+    NewArgOffset = alignTo(NewArgOffset, Alignment) + AllocSize;
+  };
+
+  for (Argument &Arg : F.args()) {
+    auto It = ArgToFieldsMap.find(&Arg);
+    if (It != ArgToFieldsMap.end()) {
+      // Split argument: one replacement scalar argument per leaf field.
+      for (const FieldInfo &FI : It->second)
+        Append(FI.Ty, DL.getABITypeAlign(FI.Ty));
+      continue;
+    }
+
+    // Unsplit argument: keeps its original type and alignment.
+    const bool IsByRef = Arg.hasByRefAttr();
+    Type *ArgTy = IsByRef ? Arg.getParamByRefType() : Arg.getType();
+    MaybeAlign ExplicitAlign = IsByRef ? Arg.getParamAlign() : MaybeAlign();
+    Append(ArgTy, ExplicitAlign.value_or(DL.getABITypeAlign(ArgTy)));
+  }
+
+  return NewArgOffset <= AvailableBytes;
+}
+
+/// Try to split byref struct kernel arguments into scalar arguments.
+/// Returns the new function with split arguments, or nullptr if no split
+/// was performed. If a new function is returned, the original function F
+/// has been erased and should not be used.
+///
+/// When EnableKernargLayoutChange is false (default), ALL struct fields are
+/// preserved in their original order (recursively flattening nested structs),
+/// maintaining the kernel argument segment layout. Unused fields become dead
+/// arguments.
+///
+/// When EnableKernargLayoutChange is true, only used fields are kept and
+/// the layout may change.
+static Function *trySplitKernelArguments(Function &F, const GCNSubtarget &ST) {
+  if (F.isDeclaration() || F.getCallingConv() != CallingConv::AMDGPU_KERNEL ||
+      F.arg_empty())
+    return nullptr;
+
+  if (!ST.hasKernargPreload())
+    return nullptr;
+
+  const DataLayout &DL = F.getParent()->getDataLayout();
+
+  SmallVector<std::tuple<unsigned, unsigned, uint64_t>, 8> NewArgMappings;
+  DenseMap<Argument *, SmallVector<LoadInst *, 8>> ArgToLoadsMap;
+  DenseMap<Argument *, SmallVector<GetElementPtrInst *, 8>> ArgToGEPsMap;
+  // Maps struct arg to field info (type, offset, associated load if any)
+  DenseMap<Argument *, SmallVector<FieldInfo, 8>> ArgToFieldsMap;
+  SmallVector<Argument *, 8> StructArgs;
+  SmallVector<Type *, 8> NewArgTypes;
+
+  unsigned OriginalArgIndex = 0;
+  unsigned NewArgIndex = 0;
+  auto HandlePassthroughArg = [&](Argument &Arg) {
+    NewArgTypes.push_back(Arg.getType());
+    if (!Arg.hasAttribute(OriginalArgAttr) && NewArgIndex != OriginalArgIndex) {
+      NewArgMappings.emplace_back(NewArgIndex, OriginalArgIndex, 0);
+    }
+    ++NewArgIndex;
+    ++OriginalArgIndex;
+  };
+
+  for (Argument &Arg : F.args()) {
+    PointerType *PT = dyn_cast<PointerType>(Arg.getType());
+    if (!PT || !Arg.hasByRefAttr()) {
+      HandlePassthroughArg(Arg);
+      continue;
+    }
+
+    StructType *STy = dyn_cast<StructType>(Arg.getParamByRefType());
+    if (!STy) {
+      HandlePassthroughArg(Arg);
+      continue;
+    }
+
+    // Collect loads from this struct argument
+    SmallVector<LoadInst *, 8> Loads;
+    SmallVector<GetElementPtrInst *, 8> GEPs;
+
+    // Check if all users are valid for splitting (GEPs + loads)
+    bool HasValidUsers =
+        Arg.use_empty() || areArgUsersValidForSplit(Arg, Loads, GEPs);
+    if (!HasValidUsers) {
+      HandlePassthroughArg(Arg);
+      continue;
+    }
+
+    // Helper to get load offset. Returns std::nullopt if offset can't be
+    // computed (e.g., variable-index GEP).
+    auto GetLoadOffset = [&](LoadInst *LI) -> std::optional<uint64_t> {
+      Value *Ptr = LI->getPointerOperand();
+      // Direct load from argument (offset 0)
+      if (Ptr == &Arg)
+        return 0;
+      if (auto *GEP = dyn_cast<GetElementPtrInst>(Ptr)) {
+        APInt OffsetAPInt(DL.getPointerSizeInBits(), 0);
+        if (GEP->accumulateConstantOffset(DL, OffsetAPInt))
+          return OffsetAPInt.getZExtValue();
+      }
+      return std::nullopt;
+    };
+
+    unsigned RootIdx = OriginalArgIndex;
+    uint64_t BaseOffset = 0;
+
+    if (Arg.hasAttribute(OriginalArgAttr)) {
+      Attribute Attr = F.getAttributeAtIndex(OriginalArgIndex, OriginalArgAttr);
+      (void)parseOriginalArgAttribute(Attr.getValueAsString(), RootIdx,
+                                      BaseOffset);
+    }
+
+    // Build a map from offset to load for matching. Skip splitting if any
+    // load has a variable-index GEP (can't compute constant offset).
+    DenseMap<uint64_t, LoadInst *> OffsetToLoad;
+    bool HasVariableIndexLoad = false;
+    for (LoadInst *LI : Loads) {
+      auto Off = GetLoadOffset(LI);
+      if (!Off) {
+        HasVariableIndexLoad = true;
+        break;
+      }
+      OffsetToLoad[*Off] = LI;
+    }
+
+    if (HasVariableIndexLoad) {
+      LLVM_DEBUG(dbgs() << "Skipping split for " << F.getName()
+                        << ": load with variable-index GEP\n");
+      HandlePassthroughArg(Arg);
+      continue;
+    }
+
+    StructArgs.push_back(&Arg);
+    ArgToLoadsMap[&Arg] = Loads;
+    ArgToGEPsMap[&Arg] = GEPs;
+
+    SmallVector<FieldInfo, 8> Fields;
+
+    if (EnableKernargLayoutChange) {
+      // Layout change allowed: only keep used fields, sorted by offset
+      llvm::sort(Loads, [&](LoadInst *A, LoadInst *B) {
+        return *GetLoadOffset(A) < *GetLoadOffset(B);
+      });
+
+      for (LoadInst *LI : Loads) {
+        uint64_t LocalOff = *GetLoadOffset(LI);
----------------
shiltian wrote:

There is no point in calling this again. Cache it in the first run?

https://github.com/llvm/llvm-project/pull/133786


More information about the llvm-commits mailing list