[llvm] [AMDGPU] Split struct kernel arguments (PR #133786)

Yaxun Liu via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 22 06:49:18 PST 2026


================
@@ -0,0 +1,365 @@
+//===--- AMDGPUSplitKernelArguments.cpp - Split kernel arguments ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// \file This pass flattens struct-type kernel arguments. It eliminates
+// unused fields and keeps only the fields that are actually used. The
+// objective is to facilitate preloading of kernel arguments by later passes.
+//
+//===----------------------------------------------------------------------===//
+#include "AMDGPU.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Module.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/FileSystem.h"
+#include "llvm/Transforms/Utils/Cloning.h"
+
+#define DEBUG_TYPE "amdgpu-split-kernel-arguments"
+
+using namespace llvm;
+
+namespace {
+static cl::opt<bool> EnableSplitKernelArgs(
+    "amdgpu-enable-split-kernel-args",
+    cl::desc("Enable splitting of AMDGPU kernel arguments"), cl::init(false));
+
+/// Decode an "amdgpu-original-arg" attribute value of the form
+/// "<root-index>:<byte-offset>" into its two integer components.
+/// \return False when the separator is missing or either field is not a
+/// valid base-10 integer.
+bool parseOriginalArgAttribute(StringRef S, unsigned &RootIdx,
+                               uint64_t &BaseOff) {
+  auto [IdxStr, OffStr] = S.split(':');
+  if (OffStr.empty())
+    return false;
+  // getAsInteger returns true on failure; both fields must parse.
+  return !IdxStr.getAsInteger(10, RootIdx) &&
+         !OffStr.getAsInteger(10, BaseOff);
+}
+
+/// Legacy-PM module pass that splits byref struct-type kernel arguments
+/// into individual arguments for their used fields.
+class AMDGPUSplitKernelArguments : public ModulePass {
+public:
+  static char ID; // Pass identification, replacement for typeid.
+
+  AMDGPUSplitKernelArguments() : ModulePass(ID) {}
+
+  bool runOnModule(Module &M) override;
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    // Only argument lists and their uses are rewritten; control flow is
+    // untouched.
+    AU.setPreservesCFG();
+  }
+
+private:
+  /// Attribute recording an argument's pre-split root index and byte
+  /// offset as "<index>:<offset>"; see parseOriginalArgAttribute.
+  static inline constexpr StringRef OriginalArgAttr = "amdgpu-original-arg";
+
+  /// Performs argument splitting on the function \p F.
+  bool processFunction(Function &F);
+
+  /// Traverses all users of an argument to check if it's suitable for
+  /// splitting. A suitable argument is only used by a chain of
+  /// GEPs that terminate in LoadInsts.
+  /// @return True if the argument is suitable, false otherwise.
+  /// @param Loads [out] The list of terminating loads found.
+  /// @param GEPs [out] The list of GEPs in the use-chains.
+  bool areArgUsersValidForSplit(Argument &Arg,
+                                SmallVectorImpl<LoadInst *> &Loads,
+                                SmallVectorImpl<GetElementPtrInst *> &GEPs);
+};
+
+} // end anonymous namespace
+
+bool AMDGPUSplitKernelArguments::areArgUsersValidForSplit(
+    Argument &Arg, SmallVectorImpl<LoadInst *> &Loads,
+    SmallVectorImpl<GetElementPtrInst *> &GEPs) {
+
+  // Walk the entire use-graph of the argument. Only chains of GEPs that
+  // terminate in simple loads are supported; any other user disqualifies
+  // the argument from splitting.
+  SmallVector<User *, 16> Worklist(Arg.user_begin(), Arg.user_end());
+  SetVector<User *> Visited;
+
+  while (!Worklist.empty()) {
+    User *U = Worklist.pop_back_val();
+    if (!Visited.insert(U))
+      continue;
+
+    if (auto *LI = dyn_cast<LoadInst>(U)) {
+      // Volatile or atomic loads must not be rewritten; leave the
+      // argument unsplit in that case.
+      if (!LI->isSimple())
+        return false;
+      Loads.push_back(LI);
+    } else if (auto *GEP = dyn_cast<GetElementPtrInst>(U)) {
+      GEPs.push_back(GEP);
+      for (User *GEPUser : GEP->users()) {
+        Worklist.push_back(GEPUser);
+      }
+    } else
+      return false;
+  }
+
+  const DataLayout &DL = Arg.getParent()->getParent()->getDataLayout();
+  for (const LoadInst *LI : Loads) {
+    // Size the accumulator by the index width of the pointer's address
+    // space: stripAndAccumulateConstantOffsets asserts that the APInt
+    // width matches the index type size, which may differ from the
+    // default address space's pointer width.
+    APInt Offset(
+        DL.getIndexTypeSizeInBits(LI->getPointerOperand()->getType()), 0);
+
+    const Value *Base =
+        LI->getPointerOperand()->stripAndAccumulateConstantOffsets(
+            DL, Offset, /*AllowNonInbounds=*/false);
+
+    // A GEP with a non-constant index does not strip back to the
+    // argument itself; reject it.
+    if (Base != &Arg)
+      return false;
+  }
+
+  return true;
+}
+bool AMDGPUSplitKernelArguments::processFunction(Function &F) {
+  const DataLayout &DL = F.getParent()->getDataLayout();
+
+  SmallVector<std::tuple<unsigned, unsigned, uint64_t>, 8> NewArgMappings;
+  DenseMap<Argument *, SmallVector<LoadInst *, 8>> ArgToLoadsMap;
+  DenseMap<Argument *, SmallVector<GetElementPtrInst *, 8>> ArgToGEPsMap;
+  SmallVector<Argument *, 8> StructArgs;
+  SmallVector<Type *, 8> NewArgTypes;
+
+  // Collect struct arguments and new argument types
+  unsigned OriginalArgIndex = 0;
+  unsigned NewArgIndex = 0;
+  auto HandlePassthroughArg = [&](Argument &Arg) {
+    NewArgTypes.push_back(Arg.getType());
+
+    if (!Arg.hasAttribute(OriginalArgAttr) && NewArgIndex != OriginalArgIndex) {
+      NewArgMappings.emplace_back(NewArgIndex, OriginalArgIndex, 0);
+    }
+
+    ++NewArgIndex;
+    ++OriginalArgIndex;
+  };
+
+  for (Argument &Arg : F.args()) {
+    if (Arg.use_empty()) {
+      HandlePassthroughArg(Arg);
+      continue;
+    }
+
+    PointerType *PT = dyn_cast<PointerType>(Arg.getType());
+    if (!PT || !Arg.hasByRefAttr() ||
+        !isa<StructType>(Arg.getParamByRefType())) {
+      HandlePassthroughArg(Arg);
+      continue;
+    }
+
+    SmallVector<LoadInst *, 8> Loads;
+    SmallVector<GetElementPtrInst *, 8> GEPs;
+    if (areArgUsersValidForSplit(Arg, Loads, GEPs)) {
+      // The 'Loads' vector is currently in an unstable order. Sort it by
+      // byte offset to ensure a deterministic argument order for the new
+      // function.
+      auto GetOffset = [&](LoadInst *LI) -> uint64_t {
+        uint64_t Offset = 0;
+        if (auto *GEP = dyn_cast<GetElementPtrInst>(LI->getPointerOperand())) {
+          APInt OffsetAPInt(DL.getPointerSizeInBits(), 0);
+          if (GEP->accumulateConstantOffset(DL, OffsetAPInt))
+            Offset = OffsetAPInt.getZExtValue();
+        }
+        return Offset;
+      };
+      llvm::stable_sort(Loads, [&](LoadInst *A, LoadInst *B) {
----------------
yxsamliu wrote:

Done. Changed to `llvm::sort`.

https://github.com/llvm/llvm-project/pull/133786


More information about the llvm-commits mailing list