[llvm] [AMDGPU] Split struct kernel arguments (PR #133786)
Matt Arsenault via llvm-commits
llvm-commits at lists.llvm.org
Mon Mar 31 17:31:12 PDT 2025
================
@@ -0,0 +1,373 @@
+//===--- AMDGPUSplitKernelArguments.cpp - Split kernel arguments ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// \file This pass flattens struct-type kernel arguments. It eliminates unused
+// fields and keeps only the fields that are actually used. The objective is to
+// facilitate preloading of kernel arguments by later passes.
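+//
+// For example (an illustrative sketch, not taken from this patch's tests), a
+// kernel that receives a byref struct argument and loads only some of its
+// fields, such as
+//
+//   %struct.S = type { i32, float, i64 }
+//   define amdgpu_kernel void @k(ptr addrspace(4) byref(%struct.S) %in)
+//
+// is rewritten so that each field load is fed by its own scalar kernel
+// argument instead of going through the struct pointer.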
+//
+//===----------------------------------------------------------------------===//
+#include "AMDGPU.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Module.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/FileSystem.h"
+#include "llvm/Transforms/Utils/Cloning.h"
+
+#define DEBUG_TYPE "amdgpu-split-kernel-arguments"
+
+using namespace llvm;
+
+namespace {
+static cl::opt<bool> EnableSplitKernelArgs(
+ "amdgpu-enable-split-kernel-args",
+ cl::desc("Enable splitting of AMDGPU kernel arguments"),
+ cl::init(false));
+
+class AMDGPUSplitKernelArguments : public ModulePass {
+public:
+ static char ID;
+
+ AMDGPUSplitKernelArguments() : ModulePass(ID) {}
+
+ bool runOnModule(Module &M) override;
+
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.setPreservesCFG();
+ }
+
+private:
+ bool processFunction(Function &F);
+};
+
+} // end anonymous namespace
+
+bool AMDGPUSplitKernelArguments::processFunction(Function &F) {
+ const DataLayout &DL = F.getParent()->getDataLayout();
+ LLVM_DEBUG(dbgs() << "Entering AMDGPUSplitKernelArguments::processFunction "
+ << F.getName() << '\n');
+ if (F.isDeclaration()) {
+ LLVM_DEBUG(dbgs() << "Function is a declaration, skipping\n");
+ return false;
+ }
+
+ CallingConv::ID CC = F.getCallingConv();
+ if (CC != CallingConv::AMDGPU_KERNEL || F.arg_empty()) {
+ LLVM_DEBUG(dbgs() << "non-kernel or arg_empty\n");
+ return false;
+ }
+
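+ // NewArgMappings records (new arg index, original arg index, byte offset)
+ // triples describing where each new argument comes from in the original
+ // signature. ArgToLoadsMap and ArgToGEPsMap collect the loads and GEPs
+ // rooted at each struct argument selected for splitting.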
+ SmallVector<std::tuple<unsigned, unsigned, uint64_t>, 8> NewArgMappings;
+ DenseMap<Argument *, SmallVector<LoadInst *, 8>> ArgToLoadsMap;
+ DenseMap<Argument *, SmallVector<GetElementPtrInst *, 8>> ArgToGEPsMap;
+ SmallVector<Argument *, 8> StructArgs;
+ SmallVector<Type *, 8> NewArgTypes;
+
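+ // Pointers to the flat address space are rewritten to global pointers in
+ // the new argument list; any other type is returned unchanged.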
+ auto convertAddressSpace = [](Type *Ty) -> Type * {
+ if (auto *PtrTy = dyn_cast<PointerType>(Ty)) {
+ if (PtrTy->getAddressSpace() == AMDGPUAS::FLAT_ADDRESS) {
+ return PointerType::get(PtrTy->getContext(), AMDGPUAS::GLOBAL_ADDRESS);
+ }
+ }
+ return Ty;
+ };
+
+ // Collect struct arguments and new argument types
+ unsigned OriginalArgIndex = 0;
+ unsigned NewArgIndex = 0;
+ for (Argument &Arg : F.args()) {
+ LLVM_DEBUG(dbgs() << "Processing argument: " << Arg << "\n");
+ if (Arg.use_empty()) {
+ NewArgTypes.push_back(convertAddressSpace(Arg.getType()));
+ NewArgMappings.push_back(
+ std::make_tuple(NewArgIndex, OriginalArgIndex, 0));
+ ++NewArgIndex;
+ ++OriginalArgIndex;
+ LLVM_DEBUG(dbgs() << "use empty\n");
+ continue;
+ }
+
+ PointerType *PT = dyn_cast<PointerType>(Arg.getType());
+ if (!PT) {
+ NewArgTypes.push_back(Arg.getType());
+ LLVM_DEBUG(dbgs() << "not a pointer\n");
+ // Include mapping if indices have changed
+ if (NewArgIndex != OriginalArgIndex)
+ NewArgMappings.push_back(
+ std::make_tuple(NewArgIndex, OriginalArgIndex, 0));
+ ++NewArgIndex;
+ ++OriginalArgIndex;
+ continue;
+ }
+
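+ // Only byref arguments are candidates for splitting; non-byref pointers
+ // are carried over unchanged.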
+ const bool IsByRef = Arg.hasByRefAttr();
+ if (!IsByRef) {
+ NewArgTypes.push_back(Arg.getType());
+ LLVM_DEBUG(dbgs() << "not byref\n");
+ // Include mapping if indices have changed
+ if (NewArgIndex != OriginalArgIndex)
+ NewArgMappings.push_back(
+ std::make_tuple(NewArgIndex, OriginalArgIndex, 0));
+ ++NewArgIndex;
+ ++OriginalArgIndex;
+ continue;
+ }
+
+ Type *ArgTy = Arg.getParamByRefType();
+ StructType *ST = dyn_cast<StructType>(ArgTy);
+ if (!ST) {
+ NewArgTypes.push_back(Arg.getType());
+ LLVM_DEBUG(dbgs() << "not a struct\n");
+ // Include mapping if indices have changed
+ if (NewArgIndex != OriginalArgIndex)
+ NewArgMappings.push_back(
+ std::make_tuple(NewArgIndex, OriginalArgIndex, 0));
+ ++NewArgIndex;
+ ++OriginalArgIndex;
+ continue;
+ }
+
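+ // The argument is only split when every user is either a direct load or a
+ // GEP whose users are all loads; any other use keeps the argument intact.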
+ bool AllLoadsOrGEPs = true;
+ SmallVector<LoadInst *, 8> Loads;
+ SmallVector<GetElementPtrInst *, 8> GEPs;
+ for (User *U : Arg.users()) {
+ LLVM_DEBUG(dbgs() << " User: " << *U << "\n");
+ if (auto *LI = dyn_cast<LoadInst>(U)) {
+ Loads.push_back(LI);
+ } else if (auto *GEP = dyn_cast<GetElementPtrInst>(U)) {
+ GEPs.push_back(GEP);
+ for (User *GEPUser : GEP->users()) {
+ LLVM_DEBUG(dbgs() << " GEP User: " << *GEPUser << "\n");
+ if (auto *GEPLoad = dyn_cast<LoadInst>(GEPUser)) {
+ Loads.push_back(GEPLoad);
+ } else {
+ AllLoadsOrGEPs = false;
+ break;
+ }
+ }
+ } else {
+ AllLoadsOrGEPs = false;
+ break;
+ }
+ if (!AllLoadsOrGEPs)
+ break;
+ }
+ LLVM_DEBUG(dbgs() << " AllLoadsOrGEPs: "
+ << (AllLoadsOrGEPs ? "true" : "false") << "\n");
+
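+ // Each load from a splittable struct argument becomes its own new
+ // argument; otherwise the struct argument is kept with its pointer
+ // address space converted.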
+ if (AllLoadsOrGEPs) {
+ StructArgs.push_back(&Arg);
+ ArgToLoadsMap[&Arg] = Loads;
+ ArgToGEPsMap[&Arg] = GEPs;
+ for (LoadInst *LI : Loads) {
+ Type *NewType = convertAddressSpace(LI->getType());
+ NewArgTypes.push_back(NewType);
+
+ // Compute offset
+ uint64_t Offset = 0;
+ if (auto *GEP = dyn_cast<GetElementPtrInst>(LI->getPointerOperand())) {
+ APInt OffsetAPInt(DL.getIndexTypeSizeInBits(GEP->getPointerOperandType()), 0);
+ if (GEP->accumulateConstantOffset(DL, OffsetAPInt))
+ Offset = OffsetAPInt.getZExtValue();
+ }
+
+ // Map each new argument to the original argument index and offset
+ NewArgMappings.push_back(
+ std::make_tuple(NewArgIndex, OriginalArgIndex, Offset));
+ ++NewArgIndex;
+ }
+ } else {
+ NewArgTypes.push_back(convertAddressSpace(Arg.getType()));
+ // Include mapping if indices have changed
+ if (NewArgIndex != OriginalArgIndex)
+ NewArgMappings.push_back(
+ std::make_tuple(NewArgIndex, OriginalArgIndex, 0));
+ ++NewArgIndex;
+ }
+ ++OriginalArgIndex;
+ }
+
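+ // Nothing to split; leave the function untouched.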
+ if (StructArgs.empty())
+ return false;
+
+ // Collect function and return attributes
+ AttributeList OldAttrs = F.getAttributes();
+ AttributeSet FnAttrs = OldAttrs.getFnAttrs();
+ AttributeSet RetAttrs = OldAttrs.getRetAttrs();
+
+ // Create new function type
+ FunctionType *NewFT =
+ FunctionType::get(F.getReturnType(), NewArgTypes, F.isVarArg());
+ Function *NewF =
+ Function::Create(NewFT, F.getLinkage(), F.getAddressSpace(), F.getName());
+ F.getParent()->getFunctionList().insert(F.getIterator(), NewF);
+ NewF->takeName(&F);
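+ // Carry the remaining function-level properties over from the original
+ // function.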
+ NewF->setVisibility(F.getVisibility());
+ if (F.hasComdat())
+ NewF->setComdat(F.getComdat());
+ NewF->setDSOLocal(F.isDSOLocal());
+ NewF->setUnnamedAddr(F.getUnnamedAddr());
+ NewF->setCallingConv(F.getCallingConv());
+
+ // Build new parameter attributes
+ SmallVector<AttributeSet, 8> NewArgAttrSets;
+ NewArgIndex = 0;
+ for (Argument &Arg : F.args()) {
+ if (ArgToLoadsMap.count(&Arg)) {
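+ // Arguments created from split-out loads start with no parameter
+ // attributes.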
+ for (LoadInst *LI : ArgToLoadsMap[&Arg]) {
+ (void)LI;
+ NewArgAttrSets.push_back(AttributeSet());
+ ++NewArgIndex;
+ }
+ } else {
+ AttributeSet ArgAttrs = OldAttrs.getParamAttrs(Arg.getArgNo());
+ NewArgAttrSets.push_back(ArgAttrs);
+ ++NewArgIndex;
+ }
+ }
+
+ // Build the new AttributeList
+ AttributeList NewAttrList =
+ AttributeList::get(F.getContext(), FnAttrs, RetAttrs, NewArgAttrSets);
+ NewF->setAttributes(NewAttrList);
+
+ // Add the mapping of the new arguments to the old arguments as a function
+ // attribute in the format "NewArgIndex:OriginalArgIndex:Offset,..."
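+ // e.g. "2:1:0,3:1:8" would mean new arguments 2 and 3 both come from old
+ // argument 1, at byte offsets 0 and 8 respectively (illustrative values).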
+ std::string MappingStr;
+ for (const auto &Info : NewArgMappings) {
----------------
arsenm wrote:
This format is going to be really hard to follow. If you really want to mark up an argument, it should be directly attached to the argument
https://github.com/llvm/llvm-project/pull/133786