[llvm] [AMDGPU] Split struct kernel arguments (PR #133786)

Yaxun Liu via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 23 20:55:36 PDT 2025


https://github.com/yxsamliu updated https://github.com/llvm/llvm-project/pull/133786

>From 4fe4ed1ff0e446ead58fce1f15ffff3b91195196 Mon Sep 17 00:00:00 2001
From: "Yaxun (Sam) Liu" <yaxun.liu at amd.com>
Date: Mon, 31 Mar 2025 20:01:32 -0400
Subject: [PATCH] [AMDGPU] Split struct kernel arguments

The AMDGPU backend has a pass that transforms kernels so
that firmware can preload kernel arguments into SGPRs
instead of loading them from the kernel argument segment.
This pass can improve kernel latency, but it cannot
preload struct-type kernel arguments.

This patch adds a pass to the AMDGPU backend that splits
and flattens struct-type kernel arguments so that later
passes can preload them into SGPRs.

The pass collects load and GEP/load instructions that use
struct-type kernel arguments and turns each loaded value
into a new kernel argument. If all uses of a struct-type
kernel argument can be replaced this way, the pass creates
a new kernel with the new signature and rewrites the
instructions of the old kernel to use the new arguments.
It adds a parameter attribute that encodes the mapping
from each new kernel argument to the original kernel
argument index and byte offset. The HSA metadata streamer
emits kernel argument metadata based on that attribute,
and the runtime processes the kernel arguments based on
the metadata.
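
A minimal before/after sketch of the transformation
(illustrative names; the complete, autogenerated checks
are in the added test amdgpu-split-kernel-args.ll):

  %struct.S = type { i32, i64 }

  ; Before the pass:
  define amdgpu_kernel void @k(ptr addrspace(1) %out,
      ptr addrspace(4) byref(%struct.S) %s) {
    %fp = getelementptr inbounds i8, ptr addrspace(4) %s, i64 8
    %fv = load i64, ptr addrspace(4) %fp, align 8
    store i64 %fv, ptr addrspace(1) %out, align 8
    ret void
  }

  ; After the pass, the field at byte offset 8 of original
  ; argument 1 becomes a plain scalar argument:
  define amdgpu_kernel void @k(ptr addrspace(1) %out,
      i64 "amdgpu-original-arg"="1:8" %fv) {
    store i64 %fv, ptr addrspace(1) %out, align 8
    ret void
  }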

The pass is disabled by default and can be enabled with the
LLVM option `-amdgpu-enable-split-kernel-args`.
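
For example, the new-PM pass can be exercised directly with
opt, mirroring the RUN lines in the added test (the input
file name is a placeholder):

  opt -S -mtriple=amdgcn-amd-amdhsa \
      -passes=amdgpu-split-kernel-arguments \
      -amdgpu-enable-split-kernel-args < input.ll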
---
 llvm/lib/Target/AMDGPU/AMDGPU.h               |   9 +
 .../AMDGPU/AMDGPUHSAMetadataStreamer.cpp      |  33 +-
 .../Target/AMDGPU/AMDGPUHSAMetadataStreamer.h |   4 +-
 llvm/lib/Target/AMDGPU/AMDGPUPassRegistry.def |   1 +
 .../AMDGPU/AMDGPUSplitKernelArguments.cpp     | 365 ++++++++++++++++++
 .../lib/Target/AMDGPU/AMDGPUTargetMachine.cpp |   6 +
 llvm/lib/Target/AMDGPU/CMakeLists.txt         |   1 +
 .../AMDGPU/amdgpu-split-kernel-args.ll        | 252 ++++++++++++
 llvm/test/CodeGen/AMDGPU/llc-pipeline.ll      |   4 +
 9 files changed, 671 insertions(+), 4 deletions(-)
 create mode 100644 llvm/lib/Target/AMDGPU/AMDGPUSplitKernelArguments.cpp
 create mode 100644 llvm/test/CodeGen/AMDGPU/amdgpu-split-kernel-args.ll

diff --git a/llvm/lib/Target/AMDGPU/AMDGPU.h b/llvm/lib/Target/AMDGPU/AMDGPU.h
index 71dd99c0d7a53..3f2e9d70b0f73 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPU.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPU.h
@@ -122,6 +122,15 @@ struct AMDGPUPromoteKernelArgumentsPass
   PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
 };
 
+ModulePass *createAMDGPUSplitKernelArgumentsPass();
+void initializeAMDGPUSplitKernelArgumentsPass(PassRegistry &);
+extern char &AMDGPUSplitKernelArgumentsID;
+
+struct AMDGPUSplitKernelArgumentsPass
+    : PassInfoMixin<AMDGPUSplitKernelArgumentsPass> {
+  PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM);
+};
+
 ModulePass *createAMDGPULowerKernelAttributesPass();
 void initializeAMDGPULowerKernelAttributesPass(PassRegistry &);
 extern char &AMDGPULowerKernelAttributesID;
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUHSAMetadataStreamer.cpp b/llvm/lib/Target/AMDGPU/AMDGPUHSAMetadataStreamer.cpp
index 2991778a1bbc7..45998b2d05044 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUHSAMetadataStreamer.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUHSAMetadataStreamer.cpp
@@ -357,17 +357,38 @@ void MetadataStreamerMsgPackV4::emitKernelArg(const Argument &Arg,
   Align ArgAlign;
   std::tie(ArgTy, ArgAlign) = getArgumentTypeAlign(Arg, DL);
 
+  // Assume the argument was not split from a struct-type argument unless the
+  // parameter attribute amdgpu-original-arg says otherwise.
+  unsigned OriginalArgIndex = ~0U;
+  uint64_t OriginalArgOffset = 0;
+  Attribute Attr =
+      Func->getAttributes().getParamAttr(ArgNo, "amdgpu-original-arg");
+  if (Attr.isValid()) {
+    StringRef MappingStr = Attr.getValueAsString();
+    SmallVector<StringRef, 2> Elements;
+    MappingStr.split(Elements, ':');
+    if (Elements.size() == 2) {
+      if (Elements[0].getAsInteger(10, OriginalArgIndex))
+        report_fatal_error(
+            "Invalid original argument index in amdgpu-original-arg attribute");
+      if (Elements[1].getAsInteger(10, OriginalArgOffset))
+        report_fatal_error("Invalid original argument offset in "
+                           "amdgpu-original-arg attribute");
+    }
+  }
+
   emitKernelArg(DL, ArgTy, ArgAlign,
                 getValueKind(ArgTy, TypeQual, BaseTypeName), Offset, Args,
-                PointeeAlign, Name, TypeName, BaseTypeName, ActAccQual,
-                AccQual, TypeQual);
+                PointeeAlign, Name, TypeName, BaseTypeName, ActAccQual, AccQual,
+                TypeQual, OriginalArgIndex, OriginalArgOffset);
 }
 
 void MetadataStreamerMsgPackV4::emitKernelArg(
     const DataLayout &DL, Type *Ty, Align Alignment, StringRef ValueKind,
     unsigned &Offset, msgpack::ArrayDocNode Args, MaybeAlign PointeeAlign,
     StringRef Name, StringRef TypeName, StringRef BaseTypeName,
-    StringRef ActAccQual, StringRef AccQual, StringRef TypeQual) {
+    StringRef ActAccQual, StringRef AccQual, StringRef TypeQual,
+    unsigned OriginalArgIndex, uint64_t OriginalArgOffset) {
   auto Arg = Args.getDocument()->getMapNode();
 
   if (!Name.empty())
@@ -409,6 +430,12 @@ void MetadataStreamerMsgPackV4::emitKernelArg(
       Arg[".is_pipe"] = Arg.getDocument()->getNode(true);
   }
 
+  // Add original argument index and offset to the metadata
+  if (OriginalArgIndex != ~0U) {
+    Arg[".original_arg_index"] = Arg.getDocument()->getNode(OriginalArgIndex);
+    Arg[".original_arg_offset"] = Arg.getDocument()->getNode(OriginalArgOffset);
+  }
+
   Args.push_back(Arg);
 }
 
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUHSAMetadataStreamer.h b/llvm/lib/Target/AMDGPU/AMDGPUHSAMetadataStreamer.h
index 22dfcb4a4ec1d..edf31ad0d6c43 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUHSAMetadataStreamer.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPUHSAMetadataStreamer.h
@@ -116,7 +116,9 @@ class LLVM_EXTERNAL_VISIBILITY MetadataStreamerMsgPackV4
                      MaybeAlign PointeeAlign = std::nullopt,
                      StringRef Name = "", StringRef TypeName = "",
                      StringRef BaseTypeName = "", StringRef ActAccQual = "",
-                     StringRef AccQual = "", StringRef TypeQual = "");
+                     StringRef AccQual = "", StringRef TypeQual = "",
+                     unsigned OriginalArgIndex = ~0U,
+                     uint64_t OriginalArgOffset = 0);
 
   void emitHiddenKernelArgs(const MachineFunction &MF, unsigned &Offset,
                             msgpack::ArrayDocNode Args) override;
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUPassRegistry.def b/llvm/lib/Target/AMDGPU/AMDGPUPassRegistry.def
index 13453963eec6d..a10b9a882b702 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUPassRegistry.def
+++ b/llvm/lib/Target/AMDGPU/AMDGPUPassRegistry.def
@@ -30,6 +30,7 @@ MODULE_PASS("amdgpu-printf-runtime-binding", AMDGPUPrintfRuntimeBindingPass())
 MODULE_PASS("amdgpu-remove-incompatible-functions", AMDGPURemoveIncompatibleFunctionsPass(*this))
 MODULE_PASS("amdgpu-sw-lower-lds", AMDGPUSwLowerLDSPass(*this))
 MODULE_PASS("amdgpu-unify-metadata", AMDGPUUnifyMetadataPass())
+MODULE_PASS("amdgpu-split-kernel-arguments", AMDGPUSplitKernelArgumentsPass())
 #undef MODULE_PASS
 
 #ifndef MODULE_PASS_WITH_PARAMS
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUSplitKernelArguments.cpp b/llvm/lib/Target/AMDGPU/AMDGPUSplitKernelArguments.cpp
new file mode 100644
index 0000000000000..b5b7f4f709c1f
--- /dev/null
+++ b/llvm/lib/Target/AMDGPU/AMDGPUSplitKernelArguments.cpp
@@ -0,0 +1,365 @@
+//===--- AMDGPUSplitKernelArguments.cpp - Split kernel arguments ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// \file This pass flattens struct-type kernel arguments. It eliminates
+// unused fields and keeps only the used ones. The objective is to facilitate
+// preloading of kernel arguments by later passes.
+//
+//===----------------------------------------------------------------------===//
+#include "AMDGPU.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Module.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/FileSystem.h"
+#include "llvm/Transforms/Utils/Cloning.h"
+
+#define DEBUG_TYPE "amdgpu-split-kernel-arguments"
+
+using namespace llvm;
+
+namespace {
+static cl::opt<bool> EnableSplitKernelArgs(
+    "amdgpu-enable-split-kernel-args",
+    cl::desc("Enable splitting of AMDGPU kernel arguments"), cl::init(false));
+
+bool parseOriginalArgAttribute(StringRef S, unsigned &RootIdx,
+                               uint64_t &BaseOff) {
+  auto Parts = S.split(':');
+  if (Parts.second.empty())
+    return false;
+  if (Parts.first.getAsInteger(10, RootIdx))
+    return false;
+  if (Parts.second.getAsInteger(10, BaseOff))
+    return false;
+  return true;
+}
+
+class AMDGPUSplitKernelArguments : public ModulePass {
+public:
+  static char ID;
+
+  AMDGPUSplitKernelArguments() : ModulePass(ID) {}
+
+  bool runOnModule(Module &M) override;
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.setPreservesCFG();
+  }
+
+private:
+  static inline constexpr StringRef OriginalArgAttr = "amdgpu-original-arg";
+
+  bool processFunction(Function &F);
+
+  /// Traverses all users of an argument to check if it's suitable for
+  /// splitting. A suitable argument is used only by loads, either directly
+  /// or through chains of GEPs with constant offsets that end in loads.
+  /// @return True if the argument is suitable, false otherwise.
+  /// @param Loads [out] The list of terminating loads found.
+  /// @param GEPs [out] The list of GEPs in the use-chains.
+  bool areArgUsersValidForSplit(Argument &Arg,
+                                SmallVectorImpl<LoadInst *> &Loads,
+                                SmallVectorImpl<GetElementPtrInst *> &GEPs);
+};
+
+} // end anonymous namespace
+
+bool AMDGPUSplitKernelArguments::areArgUsersValidForSplit(
+    Argument &Arg, SmallVectorImpl<LoadInst *> &Loads,
+    SmallVectorImpl<GetElementPtrInst *> &GEPs) {
+
+  SmallVector<User *, 16> Worklist(Arg.user_begin(), Arg.user_end());
+  SetVector<User *> Visited;
+
+  while (!Worklist.empty()) {
+    User *U = Worklist.pop_back_val();
+    if (!Visited.insert(U))
+      continue;
+
+    if (auto *LI = dyn_cast<LoadInst>(U)) {
+      Loads.push_back(LI);
+    } else if (auto *GEP = dyn_cast<GetElementPtrInst>(U)) {
+      GEPs.push_back(GEP);
+      for (User *GEPUser : GEP->users()) {
+        Worklist.push_back(GEPUser);
+      }
+    } else
+      return false;
+  }
+
+  const DataLayout &DL = Arg.getParent()->getParent()->getDataLayout();
+  for (const LoadInst *LI : Loads) {
+    APInt Offset(DL.getPointerSizeInBits(), 0);
+
+    const Value *Base =
+        LI->getPointerOperand()->stripAndAccumulateConstantOffsets(
+            DL, Offset, /*AllowNonInbounds=*/false);
+
+    // Non-constant index in GEP
+    if (Base != &Arg)
+      return false;
+  }
+
+  return true;
+}
+bool AMDGPUSplitKernelArguments::processFunction(Function &F) {
+  const DataLayout &DL = F.getParent()->getDataLayout();
+
+  SmallVector<std::tuple<unsigned, unsigned, uint64_t>, 8> NewArgMappings;
+  DenseMap<Argument *, SmallVector<LoadInst *, 8>> ArgToLoadsMap;
+  DenseMap<Argument *, SmallVector<GetElementPtrInst *, 8>> ArgToGEPsMap;
+  SmallVector<Argument *, 8> StructArgs;
+  SmallVector<Type *, 8> NewArgTypes;
+
+  // Collect struct arguments and new argument types
+  unsigned OriginalArgIndex = 0;
+  unsigned NewArgIndex = 0;
+  auto HandlePassthroughArg = [&](Argument &Arg) {
+    NewArgTypes.push_back(Arg.getType());
+
+    if (!Arg.hasAttribute(OriginalArgAttr) && NewArgIndex != OriginalArgIndex) {
+      NewArgMappings.emplace_back(NewArgIndex, OriginalArgIndex, 0);
+    }
+
+    ++NewArgIndex;
+    ++OriginalArgIndex;
+  };
+
+  for (Argument &Arg : F.args()) {
+    if (Arg.use_empty()) {
+      HandlePassthroughArg(Arg);
+      continue;
+    }
+
+    PointerType *PT = dyn_cast<PointerType>(Arg.getType());
+    if (!PT || !Arg.hasByRefAttr() ||
+        !isa<StructType>(Arg.getParamByRefType())) {
+      HandlePassthroughArg(Arg);
+      continue;
+    }
+
+    SmallVector<LoadInst *, 8> Loads;
+    SmallVector<GetElementPtrInst *, 8> GEPs;
+    if (areArgUsersValidForSplit(Arg, Loads, GEPs)) {
+      // The 'Loads' vector is currently in an unstable order. Sort it by
+      // byte offset to ensure a deterministic argument order for the new
+      // function.
+      auto GetOffset = [&](LoadInst *LI) -> uint64_t {
+        uint64_t Offset = 0;
+        if (auto *GEP = dyn_cast<GetElementPtrInst>(LI->getPointerOperand())) {
+          APInt OffsetAPInt(DL.getPointerSizeInBits(), 0);
+          if (GEP->accumulateConstantOffset(DL, OffsetAPInt))
+            Offset = OffsetAPInt.getZExtValue();
+        }
+        return Offset;
+      };
+      llvm::stable_sort(Loads, [&](LoadInst *A, LoadInst *B) {
+        return GetOffset(A) < GetOffset(B);
+      });
+
+      unsigned RootIdx = OriginalArgIndex;
+      uint64_t BaseOffset = 0;
+
+      // If the argument has been shifted or split before due to previous
+      // kernel argument splitting passes, get the real original index
+      // and offset before any passes.
+      if (Arg.hasAttribute(OriginalArgAttr)) {
+        Attribute Attr =
+            F.getAttributes().getParamAttr(OriginalArgIndex, OriginalArgAttr);
+        (void)parseOriginalArgAttribute(Attr.getValueAsString(), RootIdx,
+                                        BaseOffset);
+      }
+
+      StructArgs.push_back(&Arg);
+      ArgToLoadsMap[&Arg] = Loads;
+      ArgToGEPsMap[&Arg] = GEPs;
+      for (LoadInst *LI : Loads) {
+        NewArgTypes.push_back(LI->getType());
+
+        uint64_t LocalOff = GetOffset(LI);
+        uint64_t FinalOff = BaseOffset + LocalOff;
+
+        // Map each new argument to the real original argument index and offset
+        // before any passes
+        NewArgMappings.emplace_back(NewArgIndex, RootIdx, FinalOff);
+        ++NewArgIndex;
+      }
+      ++OriginalArgIndex;
+      continue;
+    }
+    // Argument is not suitable for splitting, treat as passthrough.
+    HandlePassthroughArg(Arg);
+  }
+
+  if (StructArgs.empty())
+    return false;
+
+  // Collect function and return attributes
+  AttributeList OldAttrs = F.getAttributes();
+  AttributeSet FnAttrs = OldAttrs.getFnAttrs();
+  AttributeSet RetAttrs = OldAttrs.getRetAttrs();
+
+  // Create new function type
+  FunctionType *NewFT =
+      FunctionType::get(F.getReturnType(), NewArgTypes, F.isVarArg());
+  Function *NewF =
+      Function::Create(NewFT, F.getLinkage(), F.getAddressSpace(), F.getName());
+  F.getParent()->getFunctionList().insert(F.getIterator(), NewF);
+  NewF->takeName(&F);
+  NewF->setVisibility(F.getVisibility());
+  if (F.hasComdat())
+    NewF->setComdat(F.getComdat());
+  NewF->setDSOLocal(F.isDSOLocal());
+  NewF->setUnnamedAddr(F.getUnnamedAddr());
+  NewF->setCallingConv(F.getCallingConv());
+
+  // Build new parameter attributes
+  SmallVector<AttributeSet, 8> NewArgAttrSets;
+  NewArgIndex = 0;
+  for (Argument &Arg : F.args()) {
+    if (ArgToLoadsMap.count(&Arg)) {
+      for ([[maybe_unused]] LoadInst *LI : ArgToLoadsMap[&Arg]) {
+        NewArgAttrSets.push_back(AttributeSet());
+        ++NewArgIndex;
+      }
+    } else {
+      AttributeSet ArgAttrs = OldAttrs.getParamAttrs(Arg.getArgNo());
+      NewArgAttrSets.push_back(ArgAttrs);
+      ++NewArgIndex;
+    }
+  }
+
+  // Build the new AttributeList
+  AttributeList NewAttrList =
+      AttributeList::get(F.getContext(), FnAttrs, RetAttrs, NewArgAttrSets);
+  NewF->setAttributes(NewAttrList);
+
+  // Attach the mapping to the original arguments as a parameter attribute
+  // in the format "OriginalArgIndex:Offset".
+  // Note: NewArgMappings maps each new argument to the real original index
+  // and offset before any split-kernel-argument passes ran.
+  for (const auto &Info : NewArgMappings) {
+    unsigned NewArgIdx, RootArgIdx;
+    uint64_t Offset;
+    std::tie(NewArgIdx, RootArgIdx, Offset) = Info;
+    NewF->addParamAttr(
+        NewArgIdx,
+        Attribute::get(NewF->getContext(), OriginalArgAttr,
+                       (Twine(RootArgIdx) + ":" + Twine(Offset)).str()));
+  }
+
+  LLVM_DEBUG(dbgs() << "New empty function:\n" << *NewF << '\n');
+
+  NewF->splice(NewF->begin(), &F);
+
+  // Map old arguments and loads to new arguments
+  DenseMap<Value *, Value *> VMap;
+  auto NewArgIt = NewF->arg_begin();
+  for (Argument &Arg : F.args()) {
+    if (ArgToLoadsMap.contains(&Arg)) {
+      for (LoadInst *LI : ArgToLoadsMap[&Arg]) {
+        NewArgIt->takeName(LI);
+        Value *NewArg = &*NewArgIt++;
+        if (isa<PointerType>(NewArg->getType()) &&
+            isa<PointerType>(LI->getType())) {
+          IRBuilder<> Builder(LI);
+          Value *CastedArg = Builder.CreatePointerBitCastOrAddrSpaceCast(
+              NewArg, LI->getType());
+          VMap[LI] = CastedArg;
+        } else {
+          VMap[LI] = NewArg;
+        }
+      }
+      PoisonValue *PoisonArg = PoisonValue::get(Arg.getType());
+      Arg.replaceAllUsesWith(PoisonArg);
+    } else {
+      NewArgIt->takeName(&Arg);
+      Value *NewArg = &*NewArgIt;
+      if (isa<PointerType>(NewArg->getType()) &&
+          isa<PointerType>(Arg.getType())) {
+        IRBuilder<> Builder(&*NewF->begin()->begin());
+        Value *CastedArg =
+            Builder.CreatePointerBitCastOrAddrSpaceCast(NewArg, Arg.getType());
+        Arg.replaceAllUsesWith(CastedArg);
+      } else {
+        Arg.replaceAllUsesWith(NewArg);
+      }
+      ++NewArgIt;
+    }
+  }
+
+  // Replace LoadInsts with new arguments
+  for (auto &Entry : ArgToLoadsMap) {
+    for (LoadInst *LI : Entry.second) {
+      Value *NewArg = VMap[LI];
+      LI->replaceAllUsesWith(NewArg);
+      LI->eraseFromParent();
+    }
+  }
+
+  // Erase GEPs
+  for (auto &Entry : ArgToGEPsMap) {
+    for (GetElementPtrInst *GEP : Entry.second) {
+      GEP->replaceAllUsesWith(PoisonValue::get(GEP->getType()));
+      GEP->eraseFromParent();
+    }
+  }
+
+  LLVM_DEBUG(dbgs() << "New function after transformation:\n" << *NewF << '\n');
+
+  F.replaceAllUsesWith(NewF);
+  F.eraseFromParent();
+
+  return true;
+}
+
+bool AMDGPUSplitKernelArguments::runOnModule(Module &M) {
+  if (!EnableSplitKernelArgs)
+    return false;
+  bool Changed = false;
+  SmallVector<Function *, 16> FunctionsToProcess;
+
+  for (Function &F : M) {
+    if (F.isDeclaration() || F.getCallingConv() != CallingConv::AMDGPU_KERNEL ||
+        F.arg_empty())
+      continue;
+    FunctionsToProcess.push_back(&F);
+  }
+
+  for (Function *F : FunctionsToProcess)
+    Changed |= processFunction(*F);
+
+  return Changed;
+}
+
+INITIALIZE_PASS_BEGIN(AMDGPUSplitKernelArguments, DEBUG_TYPE,
+                      "AMDGPU Split Kernel Arguments", false, false)
+INITIALIZE_PASS_END(AMDGPUSplitKernelArguments, DEBUG_TYPE,
+                    "AMDGPU Split Kernel Arguments", false, false)
+
+char AMDGPUSplitKernelArguments::ID = 0;
+
+ModulePass *llvm::createAMDGPUSplitKernelArgumentsPass() {
+  return new AMDGPUSplitKernelArguments();
+}
+
+PreservedAnalyses
+AMDGPUSplitKernelArgumentsPass::run(Module &M, ModuleAnalysisManager &AM) {
+  AMDGPUSplitKernelArguments Splitter;
+  bool Changed = Splitter.runOnModule(M);
+
+  if (!Changed)
+    return PreservedAnalyses::all();
+
+  return PreservedAnalyses::none();
+}
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp b/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp
index d2e4825cf3c81..e0e4da5b194ff 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp
@@ -522,6 +522,7 @@ extern "C" LLVM_ABI LLVM_EXTERNAL_VISIBILITY void LLVMInitializeAMDGPUTarget() {
   initializeAMDGPUAtomicOptimizerPass(*PR);
   initializeAMDGPULowerKernelArgumentsPass(*PR);
   initializeAMDGPUPromoteKernelArgumentsPass(*PR);
+  initializeAMDGPUSplitKernelArgumentsPass(*PR);
   initializeAMDGPULowerKernelAttributesPass(*PR);
   initializeAMDGPUExportKernelRuntimeHandlesLegacyPass(*PR);
   initializeAMDGPUPostLegalizerCombinerPass(*PR);
@@ -901,6 +902,7 @@ void AMDGPUTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) {
     if (Level != OptimizationLevel::O0) {
       if (!isLTOPreLink(Phase)) {
         AMDGPUAttributorOptions Opts;
+        MPM.addPass(AMDGPUSplitKernelArgumentsPass());
         MPM.addPass(AMDGPUAttributorPass(*this, Opts, Phase));
       }
     }
@@ -933,6 +935,7 @@ void AMDGPUTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) {
             PM.addPass(InternalizePass(mustPreserveGV));
             PM.addPass(GlobalDCEPass());
           }
+          PM.addPass(AMDGPUSplitKernelArgumentsPass());
           if (EnableAMDGPUAttributor) {
             AMDGPUAttributorOptions Opt;
             if (HasClosedWorldAssumption)
@@ -1291,6 +1294,9 @@ void AMDGPUPassConfig::addIRPasses() {
     addPass(createAMDGPULowerModuleLDSLegacyPass(&TM));
   }
 
+  if (TM.getOptLevel() > CodeGenOptLevel::None)
+    addPass(createAMDGPUSplitKernelArgumentsPass());
+
   // Run atomic optimizer before Atomic Expand
   if ((TM.getTargetTriple().isAMDGCN()) &&
       (TM.getOptLevel() >= CodeGenOptLevel::Less) &&
diff --git a/llvm/lib/Target/AMDGPU/CMakeLists.txt b/llvm/lib/Target/AMDGPU/CMakeLists.txt
index 928a5001e0c98..741b62edc8ec5 100644
--- a/llvm/lib/Target/AMDGPU/CMakeLists.txt
+++ b/llvm/lib/Target/AMDGPU/CMakeLists.txt
@@ -92,6 +92,7 @@ add_llvm_target(AMDGPUCodeGen
   AMDGPUPrintfRuntimeBinding.cpp
   AMDGPUPromoteAlloca.cpp
   AMDGPUPromoteKernelArguments.cpp
+  AMDGPUSplitKernelArguments.cpp
   AMDGPURegBankCombiner.cpp
   AMDGPURegBankLegalize.cpp
   AMDGPURegBankLegalizeHelper.cpp
diff --git a/llvm/test/CodeGen/AMDGPU/amdgpu-split-kernel-args.ll b/llvm/test/CodeGen/AMDGPU/amdgpu-split-kernel-args.ll
new file mode 100644
index 0000000000000..a7eace7145c2a
--- /dev/null
+++ b/llvm/test/CodeGen/AMDGPU/amdgpu-split-kernel-args.ll
@@ -0,0 +1,252 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-attributes --check-globals all --version 5
+; RUN: opt -S -mtriple=amdgcn-amd-amdhsa -passes=amdgpu-split-kernel-arguments -amdgpu-enable-split-kernel-args < %s | FileCheck %s
+; RUN: opt -S -mtriple=amdgcn-amd-amdhsa -passes=amdgpu-split-kernel-arguments -amdgpu-enable-split-kernel-args < %s > %t.ll
+; RUN: llc -mtriple=amdgcn-amd-amdhsa -mcpu=gfx1100 < %t.ll | FileCheck --check-prefix=GCN %s
+;
+; The LLVM IR is from the following HIP program:
+;
+; struct A {
+; int i;
+; char c;
+; long l;
+; int *p;
+; };
+
+; struct B {
+; char c;
+; A a1;
+; int i;
+; A a2;
+; };
+;
+; __global__ void test(int *out, int i, A a, char c, B b) {
+;  *out = i + a.l + c + a.l + b.a1.c;
+;  b.a2.p[2] = a.l + b.a2.c;
+;}
+;
+%struct.A = type { i32, i8, i64, ptr }
+%struct.B = type { i8, %struct.A, i32, %struct.A }
+
+; The "amdgpu-original-arg" function parameter attribute encodes how the
+; argument was split from the original kernel argument.
+;
+; Format: "amdgpu-original-arg"="OrigIndex:OrigOffset"
+; - OrigIndex: Index of the original kernel argument before splitting
+; - OrigOffset: Byte offset within the original argument
+
+;--- Main test case for successful split ---
+
+define amdgpu_kernel void @_Z4testPii1Ac1B(
+; CHECK-LABEL: define amdgpu_kernel void @_Z4testPii1Ac1B(
+; CHECK-SAME: ptr addrspace(1) noundef writeonly captures(none) initializes((0, 4)) [[OUT:%.*]], i32 noundef [[I:%.*]], i64 "amdgpu-original-arg"="2:8" [[A_L:%.*]], i8 noundef [[C:%.*]], i8 "amdgpu-original-arg"="4:12" [[B_A1_C:%.*]], i8 "amdgpu-original-arg"="4:44" [[B_A2_C:%.*]], ptr "amdgpu-original-arg"="4:56" [[B_A2_P:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP1:%.*]] = addrspacecast ptr [[B_A2_P]] to ptr addrspace(1)
+; CHECK-NEXT:    [[CONV:%.*]] = zext i32 [[I]] to i64
+; CHECK-NEXT:    [[CONV3:%.*]] = sext i8 [[C]] to i64
+; CHECK-NEXT:    [[CONV8:%.*]] = sext i8 [[B_A1_C]] to i64
+; CHECK-NEXT:    [[FACTOR:%.*]] = shl i64 [[A_L]], 1
+; CHECK-NEXT:    [[ADD4:%.*]] = add nsw i64 [[CONV]], [[CONV3]]
+; CHECK-NEXT:    [[ADD6:%.*]] = add i64 [[ADD4]], [[FACTOR]]
+; CHECK-NEXT:    [[ADD9:%.*]] = add i64 [[ADD6]], [[CONV8]]
+; CHECK-NEXT:    [[CONV10:%.*]] = trunc i64 [[ADD9]] to i32
+; CHECK-NEXT:    store i32 [[CONV10]], ptr addrspace(1) [[OUT]], align 4
+; CHECK-NEXT:    [[B_A2_C_SEXT:%.*]] = sext i8 [[B_A2_C]] to i64
+; CHECK-NEXT:    [[ADD14:%.*]] = add nsw i64 [[A_L]], [[B_A2_C_SEXT]]
+; CHECK-NEXT:    [[CONV15:%.*]] = trunc i64 [[ADD14]] to i32
+; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds i32, ptr addrspace(1) [[TMP1]], i64 2
+; CHECK-NEXT:    store i32 [[CONV15]], ptr addrspace(1) [[ARRAYIDX]], align 4
+; CHECK-NEXT:    ret void
+;
+  ptr addrspace(1) noundef writeonly captures(none) initializes((0, 4)) %out,
+  i32 noundef %i,
+  ptr addrspace(4) noundef readonly byref(%struct.A) align 8 captures(none) %a,
+  i8 noundef %c,
+  ptr addrspace(4) noundef readonly byref(%struct.B) align 8 captures(none) %b
+) {
+entry:
+  ; Load a.l from struct A
+  %a.l.addr = getelementptr inbounds nuw i8, ptr addrspace(4) %a, i64 8
+  %a.l = load i64, ptr addrspace(4) %a.l.addr, align 8
+
+  ; Load b.a1.c from struct B
+  %b.a1.c.addr = getelementptr inbounds nuw i8, ptr addrspace(4) %b, i64 12
+  %b.a1.c = load i8, ptr addrspace(4) %b.a1.c.addr, align 4
+
+  ; Load b.a2.c from struct B
+  %b.a2.c.addr = getelementptr inbounds nuw i8, ptr addrspace(4) %b, i64 44
+  %b.a2.c = load i8, ptr addrspace(4) %b.a2.c.addr, align 4
+
+  ; Load b.a2.p from struct B
+  %b.a2.p.addr = getelementptr inbounds nuw i8, ptr addrspace(4) %b, i64 56
+  %b.a2.p = load ptr, ptr addrspace(4) %b.a2.p.addr, align 8
+
+  ; Cast b.a2.p to global address space
+  %b.a2.p.global = addrspacecast ptr %b.a2.p to ptr addrspace(1)
+
+  ; Compute i + a.l + c + a.l + b.a1.c
+  %i.zext = zext i32 %i to i64
+  %c.sext = sext i8 %c to i64
+  %b.a1.c.sext = sext i8 %b.a1.c to i64
+  %a.l.x2 = shl i64 %a.l, 1
+
+  %tmp_sum1 = add nsw i64 %i.zext, %c.sext
+  %tmp_sum2 = add i64 %tmp_sum1, %a.l.x2
+  %tmp_sum3 = add i64 %tmp_sum2, %b.a1.c.sext
+  %result = trunc i64 %tmp_sum3 to i32
+  store i32 %result, ptr addrspace(1) %out, align 4
+
+  ; Compute a.l + b.a2.c and store to b.a2.p[2]
+  %b.a2.c.sext = sext i8 %b.a2.c to i64
+  %sum_store = add nsw i64 %a.l, %b.a2.c.sext
+  %store_val = trunc i64 %sum_store to i32
+
+  %b.a2.p.elem2 = getelementptr inbounds i32, ptr addrspace(1) %b.a2.p.global, i64 2
+  store i32 %store_val, ptr addrspace(1) %b.a2.p.elem2, align 4
+
+  ret void
+}
+
+; --- Re-split test: arg #0 and #1 were split previously, arg #2 is byref struct ---
+; The second run of the pass must flatten arg #2 into two new parameters.
+;
+define amdgpu_kernel void @test_resplit(
+; CHECK-LABEL: define amdgpu_kernel void @test_resplit(
+; CHECK-SAME: i32 "amdgpu-original-arg"="1:0" [[A_I:%.*]], i64 "amdgpu-original-arg"="1:8" [[A_L:%.*]], i8 "amdgpu-original-arg"="2:12" [[B_A1_C:%.*]], i8 "amdgpu-original-arg"="2:44" [[B_A2_C:%.*]], ptr addrspace(1) noundef "amdgpu-original-arg"="3:0" [[DST:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[A_I_ZEXT:%.*]] = zext i32 [[A_I]] to i64
+; CHECK-NEXT:    [[B_A1_C_SEXT:%.*]] = sext i8 [[B_A1_C]] to i64
+; CHECK-NEXT:    [[B_A2_C_SEXT:%.*]] = sext i8 [[B_A2_C]] to i64
+; CHECK-NEXT:    [[SUM1:%.*]] = add i64 [[A_I_ZEXT]], [[A_L]]
+; CHECK-NEXT:    [[SUM2:%.*]] = add i64 [[SUM1]], [[B_A1_C_SEXT]]
+; CHECK-NEXT:    [[SUM3:%.*]] = add i64 [[SUM2]], [[B_A2_C_SEXT]]
+; CHECK-NEXT:    [[TRUNC:%.*]] = trunc i64 [[SUM3]] to i32
+; CHECK-NEXT:    store i32 [[TRUNC]], ptr addrspace(1) [[DST]], align 4
+; CHECK-NEXT:    ret void
+;
+  i32 "amdgpu-original-arg"="1:0"  %a.i,          ; piece of original arg #1
+  i64 "amdgpu-original-arg"="1:8"  %a.l,          ; piece of original arg #1
+  ptr addrspace(4) noundef readonly
+  byref(%struct.B) align 8 %b,               ; original arg #2 (to split)
+  ptr addrspace(1) noundef %dst) {               ; ordinary output pointer
+entry:
+  ; load b.a1.c  (offset 12)
+  %b.a1.c.addr = getelementptr inbounds i8, ptr addrspace(4) %b, i64 12
+  %b.a1.c      = load i8, ptr addrspace(4) %b.a1.c.addr, align 4
+
+  ; load b.a2.c  (offset 44)
+  %b.a2.c.addr = getelementptr inbounds i8, ptr addrspace(4) %b, i64 44
+  %b.a2.c      = load i8, ptr addrspace(4) %b.a2.c.addr, align 4
+
+  ; sum up and store to the separate dst pointer
+  %a.i.zext      = zext i32 %a.i to i64
+  %b.a1.c.sext   = sext i8  %b.a1.c to i64
+  %b.a2.c.sext   = sext i8  %b.a2.c to i64
+  %sum1          = add i64 %a.i.zext, %a.l
+  %sum2          = add i64 %sum1, %b.a1.c.sext
+  %sum3          = add i64 %sum2, %b.a2.c.sext
+  %sum.trunc     = trunc i64 %sum3 to i32
+  store i32 %sum.trunc, ptr addrspace(1) %dst, align 4
+  ret void
+}
+
+; --- Additional test cases for passthrough logic ---
+
+; Test case for a struct argument that is never used.
+; Should not be split.
+define amdgpu_kernel void @test_unused_arg(ptr addrspace(4) byref(%struct.A) %unused_arg) {
+; CHECK-LABEL: define amdgpu_kernel void @test_unused_arg(
+; CHECK-SAME: ptr addrspace(4) byref([[STRUCT_A:%.*]]) [[UNUSED_ARG:%.*]]) {
+; CHECK-NEXT:    ret void
+;
+  ret void
+}
+
+; Test case for a pointer argument that does not have the 'byref' attribute.
+; Should not be split.
+define amdgpu_kernel void @test_no_byref_arg(ptr %ptr) {
+; CHECK-LABEL: define amdgpu_kernel void @test_no_byref_arg(
+; CHECK-SAME: ptr [[PTR:%.*]]) {
+; CHECK-NEXT:    [[VAL:%.*]] = load i32, ptr [[PTR]], align 4
+; CHECK-NEXT:    ret void
+;
+  %val = load i32, ptr %ptr, align 4
+  ret void
+}
+
+; Test case for a 'byref' argument that points to a non-struct type.
+; Should not be split.
+define amdgpu_kernel void @test_byref_non_struct_arg(ptr byref(i32) %ptr) {
+; CHECK-LABEL: define amdgpu_kernel void @test_byref_non_struct_arg(
+; CHECK-SAME: ptr byref(i32) [[PTR:%.*]]) {
+; CHECK-NEXT:    [[VAL:%.*]] = load i32, ptr [[PTR]], align 4
+; CHECK-NEXT:    ret void
+;
+  %val = load i32, ptr %ptr, align 4
+  ret void
+}
+
+; Test case for an argument that is used by an unsupported instruction (a store).
+; Should not be split.
+define amdgpu_kernel void @test_unsupported_user(ptr byref(%struct.A) %a) {
+; CHECK-LABEL: define amdgpu_kernel void @test_unsupported_user(
+; CHECK-SAME: ptr byref([[STRUCT_A:%.*]]) [[A:%.*]]) {
+; CHECK-NEXT:    store ptr null, ptr [[A]], align 8
+; CHECK-NEXT:    ret void
+;
+  store ptr null, ptr %a, align 8
+  ret void
+}
+
+; Test case for a load from a GEP with a variable, non-constant offset.
+; Should not be split.
+define amdgpu_kernel void @test_variable_offset(ptr byref(%struct.A) %a, i32 %idx) {
+  ; GEP with a variable, non-constant index into the struct.
+; CHECK-LABEL: define amdgpu_kernel void @test_variable_offset(
+; CHECK-SAME: ptr byref([[STRUCT_A:%.*]]) [[A:%.*]], i32 [[IDX:%.*]]) {
+; CHECK-NEXT:    [[P_FIELD_PTR:%.*]] = getelementptr inbounds [[STRUCT_A]], ptr [[A]], i32 [[IDX]]
+; CHECK-NEXT:    [[P_FIELD_VAL:%.*]] = load ptr, ptr [[P_FIELD_PTR]], align 8
+; CHECK-NEXT:    ret void
+;
+  %p_field_ptr = getelementptr inbounds %struct.A, ptr %a, i32 %idx
+  %p_field_val = load ptr, ptr %p_field_ptr, align 8
+  ret void
+}
+
+; GCN:   - .address_space:  global
+; GCN-NEXT:     .name:           out
+; GCN-NEXT:     .offset:         0
+; GCN-NEXT:     .size:           8
+; GCN-NEXT:     .value_kind:     global_buffer
+; GCN-NEXT:   - .name:           i
+; GCN-NEXT:     .offset:         8
+; GCN-NEXT:     .size:           4
+; GCN-NEXT:     .value_kind:     by_value
+; GCN-NEXT:   - .name:           a.l
+; GCN-NEXT:     .offset:         16
+; GCN-NEXT:     .original_arg_index: 2
+; GCN-NEXT:     .original_arg_offset: 8
+; GCN-NEXT:     .size:           8
+; GCN-NEXT:     .value_kind:     by_value
+; GCN-NEXT:   - .name:           c
+; GCN-NEXT:     .offset:         24
+; GCN-NEXT:     .size:           1
+; GCN-NEXT:     .value_kind:     by_value
+; GCN-NEXT:   - .name:           b.a1.c
+; GCN-NEXT:     .offset:         25
+; GCN-NEXT:     .original_arg_index: 4
+; GCN-NEXT:     .original_arg_offset: 12
+; GCN-NEXT:     .size:           1
+; GCN-NEXT:     .value_kind:     by_value
+; GCN-NEXT:   - .name:           b.a2.c
+; GCN-NEXT:     .offset:         26
+; GCN-NEXT:     .original_arg_index: 4
+; GCN-NEXT:     .original_arg_offset: 44
+; GCN-NEXT:     .size:           1
+; GCN-NEXT:     .value_kind:     by_value
+; GCN-NEXT:   - .address_space:  generic
+; GCN-NEXT:     .name:           b.a2.p
+; GCN-NEXT:     .offset:         32
+; GCN-NEXT:     .original_arg_index: 4
+; GCN-NEXT:     .original_arg_offset: 56
+; GCN-NEXT:     .size:           8
+; GCN-NEXT:     .value_kind:     global_buffer
diff --git a/llvm/test/CodeGen/AMDGPU/llc-pipeline.ll b/llvm/test/CodeGen/AMDGPU/llc-pipeline.ll
index dd2ff2e013cc8..356ef710463ff 100644
--- a/llvm/test/CodeGen/AMDGPU/llc-pipeline.ll
+++ b/llvm/test/CodeGen/AMDGPU/llc-pipeline.ll
@@ -187,6 +187,7 @@
 ; GCN-O1-NEXT:    Externalize enqueued block runtime handles
 ; GCN-O1-NEXT:    AMDGPU Software lowering of LDS
 ; GCN-O1-NEXT:    Lower uses of LDS variables from non-kernel functions
+; GCN-O1-NEXT:    AMDGPU Split Kernel Arguments
 ; GCN-O1-NEXT:    FunctionPass Manager
 ; GCN-O1-NEXT:      Dominator Tree Construction
 ; GCN-O1-NEXT:      Cycle Info Analysis
@@ -469,6 +470,7 @@
 ; GCN-O1-OPTS-NEXT:    Externalize enqueued block runtime handles
 ; GCN-O1-OPTS-NEXT:    AMDGPU Software lowering of LDS
 ; GCN-O1-OPTS-NEXT:    Lower uses of LDS variables from non-kernel functions
+; GCN-O1-OPTS-NEXT:    AMDGPU Split Kernel Arguments
 ; GCN-O1-OPTS-NEXT:    FunctionPass Manager
 ; GCN-O1-OPTS-NEXT:      Dominator Tree Construction
 ; GCN-O1-OPTS-NEXT:      Cycle Info Analysis
@@ -781,6 +783,7 @@
 ; GCN-O2-NEXT:    Externalize enqueued block runtime handles
 ; GCN-O2-NEXT:    AMDGPU Software lowering of LDS
 ; GCN-O2-NEXT:    Lower uses of LDS variables from non-kernel functions
+; GCN-O2-NEXT:    AMDGPU Split Kernel Arguments
 ; GCN-O2-NEXT:    FunctionPass Manager
 ; GCN-O2-NEXT:      Dominator Tree Construction
 ; GCN-O2-NEXT:      Cycle Info Analysis
@@ -1097,6 +1100,7 @@
 ; GCN-O3-NEXT:    Externalize enqueued block runtime handles
 ; GCN-O3-NEXT:    AMDGPU Software lowering of LDS
 ; GCN-O3-NEXT:    Lower uses of LDS variables from non-kernel functions
+; GCN-O3-NEXT:    AMDGPU Split Kernel Arguments
 ; GCN-O3-NEXT:    FunctionPass Manager
 ; GCN-O3-NEXT:      Dominator Tree Construction
 ; GCN-O3-NEXT:      Cycle Info Analysis


