[llvm] [AMDGPU] Split struct kernel arguments (PR #133786)

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 31 17:31:12 PDT 2025


================
@@ -0,0 +1,373 @@
+//===--- AMDGPUSplitKernelArguments.cpp - Split kernel arguments ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// \file This pass flattens struct-type kernel arguments. It eliminates unused
+// fields and keeps only the used ones. The objective is to facilitate
+// preloading of kernel arguments by later passes.
+//
+//===----------------------------------------------------------------------===//
+#include "AMDGPU.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Module.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/FileSystem.h"
+#include "llvm/Transforms/Utils/Cloning.h"
+
+#define DEBUG_TYPE "amdgpu-split-kernel-arguments"
+
+using namespace llvm;
+
+/// Command-line switch `-amdgpu-enable-split-kernel-args`; defaults to false.
+/// Declared `static` at file scope (not inside the anonymous namespace) per the
+/// LLVM coding standards, which reserve anonymous namespaces for classes.
+static cl::opt<bool> EnableSplitKernelArgs(
+    "amdgpu-enable-split-kernel-args",
+    cl::desc("Enable splitting of AMDGPU kernel arguments"), cl::init(false));
+
+namespace {
+
+/// Legacy module pass that flattens struct-type arguments of AMDGPU kernels,
+/// keeping only the fields that are used, so that later passes can preload
+/// kernel arguments.
+class AMDGPUSplitKernelArguments : public ModulePass {
+public:
+  static char ID;
+
+  AMDGPUSplitKernelArguments() : ModulePass(ID) {}
+
+  bool runOnModule(Module &M) override;
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.setPreservesCFG();
+  }
+
+private:
+  /// Rewrites the arguments of a single function. Returns false without
+  /// touching F when it is a declaration, is not an AMDGPU_KERNEL, or has no
+  /// arguments.
+  bool processFunction(Function &F);
+};
+
+} // end anonymous namespace
+
+bool AMDGPUSplitKernelArguments::processFunction(Function &F) {
+  const DataLayout &DL = F.getParent()->getDataLayout();
+  LLVM_DEBUG(dbgs() << "Entering AMDGPUSplitKernelArguments::processFunction "
+                    << F.getName() << '\n');
+  if (F.isDeclaration()) {
+    LLVM_DEBUG(dbgs() << "Function is a declaration, skipping\n");
+    return false;
+  }
+
+  CallingConv::ID CC = F.getCallingConv();
+  if (CC != CallingConv::AMDGPU_KERNEL || F.arg_empty()) {
+    LLVM_DEBUG(dbgs() << "non-kernel or arg_empty\n");
+    return false;
+  }
+
+  SmallVector<std::tuple<unsigned, unsigned, uint64_t>, 8> NewArgMappings;
+  DenseMap<Argument *, SmallVector<LoadInst *, 8>> ArgToLoadsMap;
+  DenseMap<Argument *, SmallVector<GetElementPtrInst *, 8>> ArgToGEPsMap;
+  SmallVector<Argument *, 8> StructArgs;
+  SmallVector<Type *, 8> NewArgTypes;
+
+  auto convertAddressSpace = [](Type *Ty) -> Type * {
+    if (auto *PtrTy = dyn_cast<PointerType>(Ty)) {
+      if (PtrTy->getAddressSpace() == AMDGPUAS::FLAT_ADDRESS) {
+        return PointerType::get(PtrTy->getContext(), AMDGPUAS::GLOBAL_ADDRESS);
+      }
+    }
+    return Ty;
+  };
+
+  // Collect struct arguments and new argument types
+  unsigned OriginalArgIndex = 0;
+  unsigned NewArgIndex = 0;
+  for (Argument &Arg : F.args()) {
+    LLVM_DEBUG(dbgs() << "Processing argument: " << Arg << "\n");
+    if (Arg.use_empty()) {
+      NewArgTypes.push_back(convertAddressSpace(Arg.getType()));
+      NewArgMappings.push_back(
+          std::make_tuple(NewArgIndex, OriginalArgIndex, 0));
+      ++NewArgIndex;
+      ++OriginalArgIndex;
+      LLVM_DEBUG(dbgs() << "use empty\n");
+      continue;
+    }
+
+    PointerType *PT = dyn_cast<PointerType>(Arg.getType());
+    if (!PT) {
+      NewArgTypes.push_back(Arg.getType());
+      LLVM_DEBUG(dbgs() << "not a pointer\n");
+      // Include mapping if indices have changed
+      if (NewArgIndex != OriginalArgIndex)
+        NewArgMappings.push_back(
+            std::make_tuple(NewArgIndex, OriginalArgIndex, 0));
+      ++NewArgIndex;
+      ++OriginalArgIndex;
+      continue;
+    }
+
+    const bool IsByRef = Arg.hasByRefAttr();
+    if (!IsByRef) {
+      NewArgTypes.push_back(Arg.getType());
+      LLVM_DEBUG(dbgs() << "not byref\n");
----------------
arsenm wrote:

```suggestion
```

https://github.com/llvm/llvm-project/pull/133786


More information about the llvm-commits mailing list