[llvm] [AMDGPU] Split struct kernel arguments (PR #133786)
Shilei Tian via llvm-commits
llvm-commits at lists.llvm.org
Fri Jan 23 10:07:11 PST 2026
================
@@ -42,8 +51,494 @@ static cl::opt<bool>
cl::desc("Enable preload kernel arguments to SGPRs"),
cl::init(true));
+static cl::opt<bool> EnableKernargLayoutChange(
+ "amdgpu-kernarg-layout-change",
+ cl::desc("Allow changing kernel argument segment layout when splitting "
+ "byref structs (remove unused fields, reorder for packing). "
+ "When disabled (default), all struct fields are preserved in "
+ "their original order."),
+ cl::init(false));
+
namespace {
+//===----------------------------------------------------------------------===//
+// Kernel Argument Splitting Logic
+//
+// The following functions handle splitting of byref struct kernel arguments
+// into scalar arguments. This enables preloading of struct fields that would
+// otherwise not be preloadable due to the byref attribute.
+//===----------------------------------------------------------------------===//
+
+// Attribute name for tracking original argument index and offset
+static constexpr StringRef OriginalArgAttr = "amdgpu-original-arg";
+
+// Prefix for the backup declaration of the original kernel (used for
+// metadata generation)
+static constexpr StringRef OriginalKernelPrefix = "__amdgpu_orig_kernel_";
+
+// Attribute to store the name of the backup declaration
+static constexpr StringRef OriginalKernelAttr = "amdgpu-original-kernel";
+
+static bool parseOriginalArgAttribute(StringRef S, unsigned &RootIdx,
+ uint64_t &BaseOff) {
+ auto Parts = S.split(':');
+ if (Parts.second.empty())
+ return false;
+ if (Parts.first.getAsInteger(10, RootIdx))
+ return false;
+ if (Parts.second.getAsInteger(10, BaseOff))
+ return false;
+ return true;
+}
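+// For example, an "amdgpu-original-arg" value of "2:16" (illustrative) maps
+// a split-off argument back to original argument 2 at byte offset 16.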
+
+/// Traverses all users of an argument to check whether it is suitable for
+/// splitting. A suitable argument is used only by loads, either directly or
+/// through chains of GEPs that terminate in LoadInsts.
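+/// For example (illustrative IR), this pattern is accepted:
+///   %p = getelementptr inbounds %struct.S, ptr %arg, i64 0, i32 1
+///   %v = load i32, ptr %p
+/// Any other kind of user (store, call, ptrtoint, ...) rejects the argument.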
+static bool
+areArgUsersValidForSplit(Argument &Arg, SmallVectorImpl<LoadInst *> &Loads,
+ SmallVectorImpl<GetElementPtrInst *> &GEPs) {
+ SmallVector<User *, 16> Worklist(Arg.user_begin(), Arg.user_end());
+ SetVector<User *> Visited;
+
+ while (!Worklist.empty()) {
+ User *U = Worklist.pop_back_val();
+ if (!Visited.insert(U))
+ continue;
+
+ if (auto *LI = dyn_cast<LoadInst>(U)) {
+ Loads.push_back(LI);
+ } else if (auto *GEP = dyn_cast<GetElementPtrInst>(U)) {
+ GEPs.push_back(GEP);
+ for (User *GEPUser : GEP->users()) {
+ Worklist.push_back(GEPUser);
+ }
+ } else
+ return false;
+ }
+
+ const DataLayout &DL = Arg.getParent()->getParent()->getDataLayout();
+ for (const LoadInst *LI : Loads) {
+ APInt Offset(DL.getPointerSizeInBits(), 0);
+ const Value *Base =
+ LI->getPointerOperand()->stripAndAccumulateConstantOffsets(
+ DL, Offset, /*AllowNonInbounds=*/false);
+ if (Base != &Arg)
+ return false;
+ }
+
+ return true;
+}
+
+/// Information about a struct field to be flattened into a scalar argument.
+struct FieldInfo {
+ Type *Ty;
+ uint64_t Offset;
+ LoadInst *Load; // nullptr if field is unused
+};
+
+/// Recursively collect all leaf (scalar) fields from a type with their offsets.
+/// This flattens nested structs and arrays into individual scalar fields.
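+/// For example, { i32, { float, [2 x i16] } } flattens to four leaves at
+/// offsets 0 (i32), 4 (float), 8 and 10 (the two i16s), assuming natural
+/// alignment from the DataLayout (illustrative).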
+static void collectLeafFields(Type *Ty, const DataLayout &DL,
+ uint64_t BaseOffset,
+ SmallVectorImpl<FieldInfo> &Fields) {
+ if (auto *STy = dyn_cast<StructType>(Ty)) {
+ const StructLayout *SL = DL.getStructLayout(STy);
+ for (unsigned I = 0; I < STy->getNumElements(); ++I) {
+ Type *ElemTy = STy->getElementType(I);
+ uint64_t ElemOffset = BaseOffset + SL->getElementOffset(I);
+ collectLeafFields(ElemTy, DL, ElemOffset, Fields);
+ }
+ } else if (auto *ATy = dyn_cast<ArrayType>(Ty)) {
+ Type *ElemTy = ATy->getElementType();
+ uint64_t ElemSize = DL.getTypeAllocSize(ElemTy);
+ for (uint64_t I = 0; I < ATy->getNumElements(); ++I) {
+ collectLeafFields(ElemTy, DL, BaseOffset + I * ElemSize, Fields);
+ }
+ } else {
+ // Leaf type (scalar, vector, pointer, etc.)
+ Fields.push_back({Ty, BaseOffset, nullptr});
+ }
+}
+
+/// Check whether the split arguments can be preloaded into SGPRs: compute
+/// the size of the new argument layout after splitting and verify that it
+/// fits in the available user SGPRs.
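+/// Each user SGPR covers 4 bytes, so, for instance, 14 free user SGPRs can
+/// accommodate a 56-byte flattened layout (illustrative numbers).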
+static bool canPreloadSplitArgs(
+ Function &F, const GCNSubtarget &ST,
+ const DenseMap<Argument *, SmallVector<FieldInfo, 8>> &ArgToFieldsMap) {
+ GCNUserSGPRUsageInfo UserSGPRInfo(F, ST);
+ unsigned NumFreeUserSGPRs = UserSGPRInfo.getNumFreeUserSGPRs();
+ uint64_t AvailableBytes = NumFreeUserSGPRs * 4;
+
+ const DataLayout &DL = F.getParent()->getDataLayout();
+ uint64_t NewArgOffset = 0;
+
+ // Calculate the new arg layout size after splitting
+ for (Argument &Arg : F.args()) {
+ auto It = ArgToFieldsMap.find(&Arg);
+ if (It != ArgToFieldsMap.end()) {
+ // This arg will be split - add sizes of replacement scalar args
+ for (const FieldInfo &FI : It->second) {
+ Align ABITypeAlign = DL.getABITypeAlign(FI.Ty);
+ uint64_t AllocSize = DL.getTypeAllocSize(FI.Ty);
+ NewArgOffset = alignTo(NewArgOffset, ABITypeAlign) + AllocSize;
+ }
+ } else {
+ // This arg is not split - keep original size
+ Type *ArgTy = Arg.getType();
+ if (Arg.hasByRefAttr())
+ ArgTy = Arg.getParamByRefType();
+ Align ABITypeAlign = DL.getABITypeAlign(ArgTy);
+ uint64_t AllocSize = DL.getTypeAllocSize(ArgTy);
+ NewArgOffset = alignTo(NewArgOffset, ABITypeAlign) + AllocSize;
+ }
+ }
+
+ return NewArgOffset <= AvailableBytes;
+}
+
+/// Try to split byref struct kernel arguments into scalar arguments.
+/// Returns the new function with split arguments, or nullptr if no split
+/// was performed. If a new function is returned, the original function F
+/// has been erased and should not be used.
+///
+/// When EnableKernargLayoutChange is false (default), ALL struct fields are
+/// preserved in their original order (recursively flattening nested structs),
+/// maintaining the kernel argument segment layout. Unused fields become dead
+/// arguments.
+///
+/// When EnableKernargLayoutChange is true, only used fields are kept and
+/// the layout may change.
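+/// For example (illustrative), a kernel taking "ptr byref(%struct.S) %in"
+/// with %struct.S = { i32, float } may be rewritten to take "i32, float",
+/// with loads through %in replaced by the new scalar arguments.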
+static Function *trySplitKernelArguments(Function &F, const GCNSubtarget &ST) {
+ if (F.isDeclaration() || F.getCallingConv() != CallingConv::AMDGPU_KERNEL ||
+ F.arg_empty())
+ return nullptr;
+
+ if (!ST.hasKernargPreload())
+ return nullptr;
+
+ const DataLayout &DL = F.getParent()->getDataLayout();
+
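+ // Each entry is (new arg index, original root arg index, byte offset into
+ // the original argument); the offset is 0 for arguments that only moved.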
+ SmallVector<std::tuple<unsigned, unsigned, uint64_t>, 8> NewArgMappings;
+ DenseMap<Argument *, SmallVector<LoadInst *, 8>> ArgToLoadsMap;
+ DenseMap<Argument *, SmallVector<GetElementPtrInst *, 8>> ArgToGEPsMap;
+ // Maps struct arg to field info (type, offset, associated load if any)
+ DenseMap<Argument *, SmallVector<FieldInfo, 8>> ArgToFieldsMap;
+ SmallVector<Argument *, 8> StructArgs;
+ SmallVector<Type *, 8> NewArgTypes;
+
+ unsigned OriginalArgIndex = 0;
+ unsigned NewArgIndex = 0;
+ auto HandlePassthroughArg = [&](Argument &Arg) {
+ NewArgTypes.push_back(Arg.getType());
+ if (!Arg.hasAttribute(OriginalArgAttr) && NewArgIndex != OriginalArgIndex) {
+ NewArgMappings.emplace_back(NewArgIndex, OriginalArgIndex, 0);
+ }
+ ++NewArgIndex;
+ ++OriginalArgIndex;
+ };
+
+ for (Argument &Arg : F.args()) {
+ PointerType *PT = dyn_cast<PointerType>(Arg.getType());
+ if (!PT || !Arg.hasByRefAttr()) {
+ HandlePassthroughArg(Arg);
+ continue;
+ }
+
+ StructType *STy = dyn_cast<StructType>(Arg.getParamByRefType());
+ if (!STy) {
+ HandlePassthroughArg(Arg);
+ continue;
+ }
+
+ // Collect loads from this struct argument
+ SmallVector<LoadInst *, 8> Loads;
+ SmallVector<GetElementPtrInst *, 8> GEPs;
+
+ // Check if all users are valid for splitting (GEPs + loads)
+ bool HasValidUsers =
+ Arg.use_empty() || areArgUsersValidForSplit(Arg, Loads, GEPs);
+ if (!HasValidUsers) {
+ HandlePassthroughArg(Arg);
+ continue;
+ }
+
+ // Helper to get load offset. Returns std::nullopt if offset can't be
+ // computed (e.g., variable-index GEP).
+ auto GetLoadOffset = [&](LoadInst *LI) -> std::optional<uint64_t> {
+ Value *Ptr = LI->getPointerOperand();
+ // Direct load from argument (offset 0)
+ if (Ptr == &Arg)
+ return 0;
+ if (auto *GEP = dyn_cast<GetElementPtrInst>(Ptr)) {
+ APInt OffsetAPInt(DL.getPointerSizeInBits(), 0);
+ if (GEP->accumulateConstantOffset(DL, OffsetAPInt))
+ return OffsetAPInt.getZExtValue();
+ }
+ return std::nullopt;
+ };
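+ // E.g. a GEP whose indices fold to a constant byte offset of 8 yields 8
+ // (illustrative), while a variable-index GEP yields std::nullopt and the
+ // argument is left unsplit below.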
+
+ unsigned RootIdx = OriginalArgIndex;
+ uint64_t BaseOffset = 0;
+
+ if (Arg.hasAttribute(OriginalArgAttr)) {
+ Attribute Attr =
+ F.getAttributes().getParamAttr(OriginalArgIndex, OriginalArgAttr);
+ (void)parseOriginalArgAttribute(Attr.getValueAsString(), RootIdx,
+ BaseOffset);
+ }
+
+ // Build a map from offset to load for matching. Skip splitting if any
+ // load has a variable-index GEP (can't compute constant offset).
+ DenseMap<uint64_t, LoadInst *> OffsetToLoad;
+ bool HasVariableIndexLoad = false;
+ for (LoadInst *LI : Loads) {
+ auto Off = GetLoadOffset(LI);
+ if (!Off) {
+ HasVariableIndexLoad = true;
+ break;
+ }
+ OffsetToLoad[*Off] = LI;
+ }
+
+ if (HasVariableIndexLoad) {
+ LLVM_DEBUG(dbgs() << "Skipping split for " << F.getName()
+ << ": load with variable-index GEP\n");
+ HandlePassthroughArg(Arg);
+ continue;
+ }
+
+ StructArgs.push_back(&Arg);
+ ArgToLoadsMap[&Arg] = Loads;
+ ArgToGEPsMap[&Arg] = GEPs;
+
+ SmallVector<FieldInfo, 8> Fields;
+
+ if (EnableKernargLayoutChange) {
+ // Layout change allowed: only keep used fields, sorted by offset
+ llvm::sort(Loads, [&](LoadInst *A, LoadInst *B) {
+ return *GetLoadOffset(A) < *GetLoadOffset(B);
+ });
+
+ for (LoadInst *LI : Loads) {
+ uint64_t LocalOff = *GetLoadOffset(LI);
----------------
shiltian wrote:
There is no point in calling this again. Cache it in the first run?
https://github.com/llvm/llvm-project/pull/133786
More information about the llvm-commits mailing list