[llvm] [AMDGPU] Split struct kernel arguments (PR #133786)
Yaxun Liu via llvm-commits
llvm-commits at lists.llvm.org
Thu Jan 22 06:49:18 PST 2026
================
@@ -0,0 +1,365 @@
+//===--- AMDGPUSplitKernelArguments.cpp - Split kernel arguments ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// \file This pass flattens struct-type kernel arguments. It eliminates
+// unused fields and keeps only the used fields. The objective is to
+// facilitate preloading of kernel arguments by later passes.
+//
+//===----------------------------------------------------------------------===//
+#include "AMDGPU.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Module.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/FileSystem.h"
+#include "llvm/Transforms/Utils/Cloning.h"
+
+#define DEBUG_TYPE "amdgpu-split-kernel-arguments"
+
+using namespace llvm;
+
+namespace {
+static cl::opt<bool> EnableSplitKernelArgs(
+ "amdgpu-enable-split-kernel-args",
+ cl::desc("Enable splitting of AMDGPU kernel arguments"), cl::init(false));
+
+bool parseOriginalArgAttribute(StringRef S, unsigned &RootIdx,
+ uint64_t &BaseOff) {
+ auto Parts = S.split(':');
+ if (Parts.second.empty())
+ return false;
+ if (Parts.first.getAsInteger(10, RootIdx))
+ return false;
+ if (Parts.second.getAsInteger(10, BaseOff))
+ return false;
+ return true;
+}
+
/// Legacy-PM module pass that splits struct-type byref kernel arguments into
/// their individually-loaded fields (see file header). Operates module-wide
/// because it rewrites function signatures.
class AMDGPUSplitKernelArguments : public ModulePass {
public:
  static char ID; // Pass identification, replacement for typeid.

  AMDGPUSplitKernelArguments() : ModulePass(ID) {}

  bool runOnModule(Module &M) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    // Arguments are rewritten but control flow is never altered.
    AU.setPreservesCFG();
  }

private:
  /// Attribute recording "<original-arg-index>:<byte-offset>" on split
  /// arguments so later passes can map them back to the source argument.
  static inline constexpr StringRef OriginalArgAttr = "amdgpu-original-arg";

  /// Rewrites one kernel; returns true if \p F was changed.
  bool processFunction(Function &F);

  /// Traverses all users of an argument to check if it's suitable for
  /// splitting. A suitable argument is only used by a chain of
  /// GEPs that terminate in LoadInsts.
  /// @return True if the argument is suitable, false otherwise.
  /// @param Loads [out] The list of terminating loads found.
  /// @param GEPs [out] The list of GEPs in the use-chains.
  bool areArgUsersValidForSplit(Argument &Arg,
                                SmallVectorImpl<LoadInst *> &Loads,
                                SmallVectorImpl<GetElementPtrInst *> &GEPs);
};
+
+} // end anonymous namespace
+
+bool AMDGPUSplitKernelArguments::areArgUsersValidForSplit(
+ Argument &Arg, SmallVectorImpl<LoadInst *> &Loads,
+ SmallVectorImpl<GetElementPtrInst *> &GEPs) {
+
+ SmallVector<User *, 16> Worklist(Arg.user_begin(), Arg.user_end());
+ SetVector<User *> Visited;
+
+ while (!Worklist.empty()) {
+ User *U = Worklist.pop_back_val();
+ if (!Visited.insert(U))
+ continue;
+
+ if (auto *LI = dyn_cast<LoadInst>(U)) {
+ Loads.push_back(LI);
+ } else if (auto *GEP = dyn_cast<GetElementPtrInst>(U)) {
+ GEPs.push_back(GEP);
+ for (User *GEPUser : GEP->users()) {
+ Worklist.push_back(GEPUser);
+ }
+ } else
+ return false;
+ }
+
+ const DataLayout &DL = Arg.getParent()->getParent()->getDataLayout();
+ for (const LoadInst *LI : Loads) {
+ APInt Offset(DL.getPointerSizeInBits(), 0);
+
+ const Value *Base =
+ LI->getPointerOperand()->stripAndAccumulateConstantOffsets(
+ DL, Offset, /*AllowNonInbounds=*/false);
+
+ // Non-constant index in GEP
+ if (Base != &Arg)
+ return false;
+ }
+
+ return true;
+}
+bool AMDGPUSplitKernelArguments::processFunction(Function &F) {
+ const DataLayout &DL = F.getParent()->getDataLayout();
+
+ SmallVector<std::tuple<unsigned, unsigned, uint64_t>, 8> NewArgMappings;
+ DenseMap<Argument *, SmallVector<LoadInst *, 8>> ArgToLoadsMap;
+ DenseMap<Argument *, SmallVector<GetElementPtrInst *, 8>> ArgToGEPsMap;
+ SmallVector<Argument *, 8> StructArgs;
+ SmallVector<Type *, 8> NewArgTypes;
+
+ // Collect struct arguments and new argument types
+ unsigned OriginalArgIndex = 0;
+ unsigned NewArgIndex = 0;
+ auto HandlePassthroughArg = [&](Argument &Arg) {
+ NewArgTypes.push_back(Arg.getType());
+
+ if (!Arg.hasAttribute(OriginalArgAttr) && NewArgIndex != OriginalArgIndex) {
+ NewArgMappings.emplace_back(NewArgIndex, OriginalArgIndex, 0);
+ }
+
+ ++NewArgIndex;
+ ++OriginalArgIndex;
+ };
+
+ for (Argument &Arg : F.args()) {
+ if (Arg.use_empty()) {
+ HandlePassthroughArg(Arg);
+ continue;
+ }
+
+ PointerType *PT = dyn_cast<PointerType>(Arg.getType());
+ if (!PT || !Arg.hasByRefAttr() ||
+ !isa<StructType>(Arg.getParamByRefType())) {
+ HandlePassthroughArg(Arg);
+ continue;
+ }
+
+ SmallVector<LoadInst *, 8> Loads;
+ SmallVector<GetElementPtrInst *, 8> GEPs;
+ if (areArgUsersValidForSplit(Arg, Loads, GEPs)) {
+ // The 'Loads' vector is currently in an unstable order. Sort it by
+ // byte offset to ensure a deterministic argument order for the new
+ // function.
+ auto GetOffset = [&](LoadInst *LI) -> uint64_t {
+ uint64_t Offset = 0;
+ if (auto *GEP = dyn_cast<GetElementPtrInst>(LI->getPointerOperand())) {
+ APInt OffsetAPInt(DL.getPointerSizeInBits(), 0);
+ if (GEP->accumulateConstantOffset(DL, OffsetAPInt))
+ Offset = OffsetAPInt.getZExtValue();
+ }
+ return Offset;
+ };
+ llvm::stable_sort(Loads, [&](LoadInst *A, LoadInst *B) {
----------------
yxsamliu wrote:
Done. Changed to `llvm::sort`.
https://github.com/llvm/llvm-project/pull/133786
More information about the llvm-commits
mailing list