[llvm] [AMDGPU] Enable amdgpu-sw-lower-lds pass to lower LDS accesses to use device global memory (PR #87265)
Matt Arsenault via llvm-commits
llvm-commits at lists.llvm.org
Tue Apr 2 08:04:06 PDT 2024
================
@@ -0,0 +1,865 @@
+//===-- AMDGPUSwLowerLDS.cpp -----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass lowers the local data store, LDS, uses in kernel and non-kernel
+// functions in module with dynamically allocated device global memory.
+//
+// Replacement of Kernel LDS accesses:
+// For a kernel, LDS access can be static or dynamic which are direct
+// (accessed within kernel) and indirect (accessed through non-kernels).
+// A device global memory equal to size of all these LDS globals will be
+// allocated. At the prologue of the kernel, a single work-item from the
+// work-group, does a "malloc" and stores the pointer of the allocation in
+// new LDS global that will be created for the kernel. This will be called
+// "malloc LDS global" in this pass.
+// Each LDS access corresponds to an offset in the allocated memory.
+// All static LDS accesses will be allocated first and then dynamic LDS
+// will occupy the device global memory.
+// To store the offsets corresponding to all LDS accesses, another global
+// variable is created which will be called "metadata global" in this pass.
+// - Malloc LDS Global:
+// It is LDS global of ptr type with name
+// "llvm.amdgcn.sw.lds.<kernel-name>".
+// - Metadata Global:
+// It is of struct type, with n members. n equals the number of LDS
+// globals accessed by the kernel(direct and indirect). Each member of
+// struct is another struct of type {i32, i32}. First member corresponds
+// to offset, second member corresponds to size of LDS global being
+// replaced. It will have name "llvm.amdgcn.sw.lds.<kernel-name>.md".
+// This global will have an initializer with static LDS related offsets
+// and sizes initialized. But for dynamic LDS related entries, offsets
+// will be initialized to previous static LDS allocation end offset. Sizes
+// for them will be zero initially. These dynamic LDS offset and size
+// values will be updated within the kernel, since kernel can read the
+// dynamic LDS size allocation done at runtime with query to
+// "hidden_dynamic_lds_size" hidden kernel argument.
+//
+// LDS accesses within the kernel will be replaced by "gep" ptr to
+// corresponding offset into allocated device global memory for the kernel.
+// At the epilogue of kernel, allocated memory would be made free by the same
+// single work-item.
+//
+// Replacement of non-kernel LDS accesses:
+// Multiple kernels can access the same non-kernel function.
+// All the kernels accessing LDS through non-kernels are sorted and
+// assigned a kernel-id. All the LDS globals accessed by non-kernels
+// are sorted. This information is used to build two tables:
+// - Base table:
+// Base table will have single row, with elements of the row
+// placed as per kernel ID. Each element in the row corresponds
+// to address of "malloc LDS global" variable created for
+// that kernel.
+// - Offset table:
+// Offset table will have multiple rows and columns.
+// Rows are assumed to be from 0 to (n-1). n is total number
+// of kernels accessing the LDS through non-kernels.
+// Each row will have m elements. m is the total number of
+// unique LDS globals accessed by all non-kernels.
+// Each element in the row corresponds to the address of
+// the replacement of LDS global done by that particular kernel.
+// A LDS variable in non-kernel will be replaced based on the information
+// from base and offset tables. Based on kernel-id query, address of "malloc
+// LDS global" for that corresponding kernel is obtained from base table.
+// The Offset into the base "malloc LDS global" is obtained from
+// corresponding element in offset table. With this information, replacement
+// value is obtained.
+//===----------------------------------------------------------------------===//
+
+#include "AMDGPU.h"
+#include "Utils/AMDGPUMemoryUtils.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/SetOperations.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Analysis/CallGraph.h"
+#include "llvm/Analysis/DomTreeUpdater.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicsAMDGPU.h"
+#include "llvm/IR/MDBuilder.h"
+#include "llvm/IR/ReplaceConstant.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+#include "llvm/Transforms/Utils/ModuleUtils.h"
+
+#include <algorithm>
+
+#define DEBUG_TYPE "amdgpu-sw-lower-lds"
+
+using namespace llvm;
+using namespace AMDGPU;
+
+namespace {
+
+// Callback that, given a function, returns a pointer to its dominator tree.
+using DomTreeCallback = function_ref<DominatorTree *(Function &F)>;
+
+// LDS globals of one access kind (direct or indirect), split into the
+// statically sized ones and the dynamically sized ones. Both sets keep
+// insertion order.
+struct LDSAccessTypeInfo {
+ SetVector<GlobalVariable *> StaticLDSGlobals;
+ SetVector<GlobalVariable *> DynamicLDSGlobals;
+};
+
+// Struct to hold all the metadata required for a kernel
+// to replace the uses of an LDS global with the corresponding offset
+// into device global memory.
+struct KernelLDSParameters {
+ // LDS global created for the kernel ("llvm.amdgcn.sw.lds.<kernel-name>").
+ GlobalVariable *MallocLDSGlobal{nullptr};
+ // Metadata global created for the kernel
+ // ("llvm.amdgcn.sw.lds.<kernel-name>.md") holding {offset, size} items.
+ GlobalVariable *MallocMetadataGlobal{nullptr};
+ // LDS globals accessed within the kernel itself.
+ LDSAccessTypeInfo DirectAccess;
+ // LDS globals accessed through non-kernel functions.
+ LDSAccessTypeInfo IndirectAccess;
+ // Per LDS global, the three indices {0, item index, 0} selecting its entry
+ // in the metadata global.
+ DenseMap<GlobalVariable *, SmallVector<uint32_t, 3>>
+ LDSToReplacementIndicesMap;
+ // Id assigned from the name-sorted kernel order; stays -1 until assigned.
+ int32_t KernelId{-1};
+ // Running sum of the aligned sizes of all LDS globals of the kernel.
+ uint32_t MallocSize{0};
+};
+
+// Struct to store the information needed to create the base and offset
+// tables for all the non-kernel LDS accesses.
+struct NonKernelLDSParameters {
+ // Base table global; null until it is built.
+ GlobalVariable *LDSBaseTable{nullptr};
+ // Offset table global; null until it is built.
+ GlobalVariable *LDSOffsetTable{nullptr};
+ // Kernels in the order used for the tables.
+ SetVector<Function *> OrderedKernels;
+ // LDS globals in the order used for the tables.
+ SetVector<GlobalVariable *> OrdereLDSGlobals;
+};
+
+// Implementation of the amdgpu-sw-lower-lds transformation on one module.
+// Keeps the per-kernel lowering state and an IR builder shared by all steps.
+class AMDGPUSwLowerLDS {
+public:
+ // Binds the lowering to module `mod`; the IR builder is created on the
+ // module's context and `Callback` is stored for dominator tree queries.
+ AMDGPUSwLowerLDS(Module &mod, DomTreeCallback Callback)
+ : M(mod), IRB(M.getContext()), DTCallback(Callback) {}
+ // Entry point of the lowering.
+ // NOTE(review): presumably returns whether the module was changed - the
+ // definition is outside the visible hunk, confirm there.
+ bool Run();
+ // Fills `functions` with the LDS globals used by instructions of
+ // non-kernel functions.
+ void GetUsesOfLDSByNonKernels(CallGraph const &CG,
+ FunctionVariableMap &functions);
+ // Returns `Kernels` sorted by name and assigns each kernel its id.
+ SetVector<Function *>
+ GetOrderedIndirectLDSAccessingKernels(SetVector<Function *> &&Kernels);
+ // Returns `Variables` sorted by name.
+ SetVector<GlobalVariable *>
+ GetOrderedNonKernelAllLDSGlobals(SetVector<GlobalVariable *> &&Variables);
+ // Create the per-kernel "malloc LDS global", the per-kernel metadata global
+ // and the LDS-global-to-indices map for kernel `Func`.
+ void PopulateMallocLDSGlobal(Function *Func);
+ void PopulateMallocMetadataGlobal(Function *Func);
+ void PopulateLDSToReplacementIndicesMap(Function *Func);
+ // NOTE(review): the definitions of the remaining members are not part of
+ // the visible hunk; see the file header comment for the intended lowering
+ // of kernel and non-kernel accesses.
+ void ReplaceKernelLDSAccesses(Function *Func);
+ void LowerKernelLDSAccesses(Function *Func, DomTreeUpdater &DTU);
+ void BuildNonKernelLDSOffsetTable(
+ std::shared_ptr<NonKernelLDSParameters> &NKLDSParams);
+ void BuildNonKernelLDSBaseTable(
+ std::shared_ptr<NonKernelLDSParameters> &NKLDSParams);
+ Constant *
+ GetAddressesOfVariablesInKernel(Function *Func,
+ SetVector<GlobalVariable *> &Variables);
+ void LowerNonKernelLDSAccesses(
+ Function *Func, SetVector<GlobalVariable *> &LDSGlobals,
+ std::shared_ptr<NonKernelLDSParameters> &NKLDSParams);
+
+private:
+ // Module being transformed.
+ Module &M;
+ // Builder shared by the lowering steps.
+ IRBuilder<> IRB;
+ // Provider of per-function dominator trees.
+ DomTreeCallback DTCallback;
+ // Lowering state of every kernel, keyed by the kernel function.
+ DenseMap<Function *, std::shared_ptr<KernelLDSParameters>>
+ KernelToLDSParametersMap;
+};
+
+// Orders a list of named IR objects (globals or functions) by their name.
+//
+// \p V holds pointers to objects providing getName(); it is sorted in place.
+// Returns a SetVector holding the elements of \p V in sorted order.
+template <typename T> SetVector<T> SortByName(std::vector<T> &&V) {
+  llvm::sort(V, [](const auto *L, const auto *R) {
+    return L->getName() < R->getName();
+  });
+  // Return the temporary directly; wrapping a prvalue in std::move() only
+  // prevents copy elision.
+  return SetVector<T>(V.begin(), V.end());
+}
+
+// Returns the LDS globals accessed by non-kernels, sorted by name.
+//
+// \p Variables is only read; its elements are copied into a temporary vector
+// that is handed to SortByName.
+SetVector<GlobalVariable *> AMDGPUSwLowerLDS::GetOrderedNonKernelAllLDSGlobals(
+    SetVector<GlobalVariable *> &&Variables) {
+  // Return the sorted set directly instead of moving a named local, which
+  // would inhibit return value optimization.
+  return SortByName(
+      std::vector<GlobalVariable *>(Variables.begin(), Variables.end()));
+}
+
+// Sorts the kernels that access LDS through non-kernels by name and assigns
+// each one a kernel id equal to its position in the sorted order.
+//
+// The id is attached to the kernel as "llvm.amdgcn.lds.kernel.id" metadata
+// and also recorded in the kernel's KernelLDSParameters. Aborts with a fatal
+// error if there are more kernels than a 32 bit id can address.
+// Returns the kernels in sorted order.
+SetVector<Function *> AMDGPUSwLowerLDS::GetOrderedIndirectLDSAccessingKernels(
+    SetVector<Function *> &&Kernels) {
+  LLVMContext &Ctx = M.getContext();
+  if (Kernels.size() > UINT32_MAX) {
+    // 32 bit keeps it in one SGPR. > 2**32 kernels won't fit on the GPU
+    report_fatal_error("Unimplemented SW LDS lowering for > 2**32 kernels");
+  }
+  SetVector<Function *> OrderedKernels =
+      SortByName(std::vector<Function *>(Kernels.begin(), Kernels.end()));
+  // Iterate over the container that is actually indexed; the size check above
+  // guarantees that the count fits in 32 bits.
+  const uint32_t NumKernels = static_cast<uint32_t>(OrderedKernels.size());
+  for (uint32_t Id = 0; Id != NumKernels; ++Id) {
+    Metadata *AttrMDArgs[1] = {
+        ConstantAsMetadata::get(IRB.getInt32(Id)),
+    };
+    Function *Func = OrderedKernels[Id];
+    Func->setMetadata("llvm.amdgcn.lds.kernel.id",
+                      MDNode::get(Ctx, AttrMDArgs));
+    auto &LDSParams = KernelToLDSParametersMap[Func];
+    assert(LDSParams);
+    LDSParams->KernelId = static_cast<int32_t>(Id);
+  }
+  // A named local is returned by value; no std::move() so that the return
+  // value optimization can apply.
+  return OrderedKernels;
+}
+
+// Collects, for every non-kernel function, the LDS globals that are used by
+// instructions of that function (uses made by called functions are not
+// attributed to the caller).
+//
+// \p CG is currently unused. \p functions receives the result: each LDS
+// global is inserted into the set of every non-kernel function that uses it.
+// Aborts with a fatal error on LDS variables with absolute addresses.
+void AMDGPUSwLowerLDS::GetUsesOfLDSByNonKernels(
+    CallGraph const &CG, FunctionVariableMap &functions) {
+  for (auto &GV : M.globals()) {
+    if (!AMDGPU::isLDSVariableToLower(GV))
+      continue;
+
+    if (GV.isAbsoluteSymbolRef())
+      report_fatal_error(
+          "LDS variables with absolute addresses are unimplemented.");
+
+    // The address of the global can reach an instruction through any number
+    // of nested constants (bitcast, addrspacecast, ptrtoint, gep, ...), each
+    // of which may have several users. Walk the constant users transitively
+    // instead of peeking through a single single-use cast.
+    SmallVector<User *, 8> Worklist(GV.user_begin(), GV.user_end());
+    DenseSet<User *> Visited;
+    while (!Worklist.empty()) {
+      User *U = Worklist.pop_back_val();
+      if (!Visited.insert(U).second)
+        continue;
+
+      if (auto *I = dyn_cast<Instruction>(U)) {
+        Function *F = I->getFunction();
+        if (!isKernelLDS(F))
+          functions[F].insert(&GV);
+        continue;
+      }
+
+      // Look through constants, but never through another global: users of a
+      // global whose initializer mentions GV are not users of GV itself.
+      if (isa<Constant>(U) && !isa<GlobalValue>(U))
+        Worklist.append(U->user_begin(), U->user_end());
+    }
+  }
+}
+
+// Creates the "malloc LDS global" of kernel \p Func: an internal,
+// pointer-typed global in the LOCAL address space, initialized to poison and
+// named "llvm.amdgcn.sw.lds.<kernel-name>". The new global is recorded in the
+// kernel's KernelLDSParameters.
+void AMDGPUSwLowerLDS::PopulateMallocLDSGlobal(Function *Func) {
+  auto &KernelParams = KernelToLDSParametersMap[Func];
+  assert(KernelParams);
+  auto *PtrTy = IRB.getPtrTy();
+  KernelParams->MallocLDSGlobal = new GlobalVariable(
+      M, PtrTy, /*isConstant=*/false, GlobalValue::InternalLinkage,
+      PoisonValue::get(PtrTy), "llvm.amdgcn.sw.lds." + Func->getName(),
+      /*InsertBefore=*/nullptr, GlobalValue::NotThreadLocal,
+      AMDGPUAS::LOCAL_ADDRESS, /*isExternallyInitialized=*/false);
+}
+
+// Creates the metadata global of kernel \p Func.
+//
+// The global is a struct with one {StartOffset, SizeInBytes} item per LDS
+// global of the kernel, in the order: direct static, indirect static, direct
+// dynamic, indirect dynamic. Start offsets are assigned consecutively, each
+// item's size being rounded up to the largest alignment found among all LDS
+// globals of the kernel; the running total is accumulated in MallocSize.
+void AMDGPUSwLowerLDS::PopulateMallocMetadataGlobal(Function *Func) {
+  auto &LDSParams = KernelToLDSParametersMap[Func];
+  assert(LDSParams);
+  auto &Ctx = M.getContext();
+  auto &DL = M.getDataLayout();
+  Type *Int32Ty = IRB.getInt32Ty();
+
+  // Largest alignment over every LDS global of the kernel. An explicit
+  // alignment on the global wins over the one derived from the data layout.
+  Align MaxAlignment(1);
+  SetVector<GlobalVariable *> *AllLDSSets[] = {
+      &LDSParams->DirectAccess.StaticLDSGlobals,
+      &LDSParams->DirectAccess.DynamicLDSGlobals,
+      &LDSParams->IndirectAccess.StaticLDSGlobals,
+      &LDSParams->IndirectAccess.DynamicLDSGlobals};
+  for (SetVector<GlobalVariable *> *LDSSet : AllLDSSets) {
+    for (GlobalVariable *GV : *LDSSet) {
+      uint32_t ExplicitAlign = GV->getAlignment();
+      Align GVAlign =
+          ExplicitAlign ? Align(ExplicitAlign) : AMDGPU::getAlign(DL, GV);
+      MaxAlignment = std::max(MaxAlignment, GVAlign);
+    }
+  }
+  uint32_t MaxAlignValue = MaxAlignment.value();
+
+  // Item type: {StartOffset, SizeInBytes}
+  StructType *LDSItemTy = StructType::create(
+      Ctx, {Int32Ty, Int32Ty},
+      "llvm.amdgcn.sw.lds." + Func->getName().str() + ".md.item");
+
+  // Appends one item per global of the given set, advancing MallocSize by
+  // the aligned size of each global.
+  std::vector<Type *> ItemTypes;
+  std::vector<Constant *> ItemInits;
+  auto AppendItems = [&](SetVector<GlobalVariable *> &LDSGlobals) {
+    for (GlobalVariable *GV : LDSGlobals) {
+      const uint64_t SizeInBytes = DL.getTypeAllocSize(GV->getValueType());
+      ItemTypes.push_back(LDSItemTy);
+      Constant *StartOffset = ConstantInt::get(Int32Ty, LDSParams->MallocSize);
+      Constant *Size = ConstantInt::get(Int32Ty, SizeInBytes);
+      uint64_t PaddedSize =
+          ((SizeInBytes + MaxAlignValue - 1) / MaxAlignValue) * MaxAlignValue;
+      LDSParams->MallocSize += PaddedSize;
+      ItemInits.push_back(ConstantStruct::get(LDSItemTy, {StartOffset, Size}));
+    }
+  };
+  AppendItems(LDSParams->DirectAccess.StaticLDSGlobals);
+  AppendItems(LDSParams->IndirectAccess.StaticLDSGlobals);
+  AppendItems(LDSParams->DirectAccess.DynamicLDSGlobals);
+  AppendItems(LDSParams->IndirectAccess.DynamicLDSGlobals);
+
+  StructType *MetadataStructType = StructType::create(
+      Ctx, ItemTypes,
+      ("llvm.amdgcn.sw.lds." + Func->getName().str() + ".md.type"));
+  GlobalVariable *MetadataGV = new GlobalVariable(
+      M, MetadataStructType, /*isConstant=*/false,
+      GlobalValue::InternalLinkage, PoisonValue::get(MetadataStructType),
+      ("llvm.amdgcn.sw.lds." + Func->getName().str() + ".md"),
+      /*InsertBefore=*/nullptr, GlobalValue::NotThreadLocal,
+      AMDGPUAS::GLOBAL_ADDRESS, /*isExternallyInitialized=*/false);
+  LDSParams->MallocMetadataGlobal = MetadataGV;
+  MetadataGV->setInitializer(
+      ConstantStruct::get(MetadataStructType, ItemInits));
+  MetadataGV->setAlignment(MaxAlignment);
+
+  // Mark the metadata global with the NoAddress sanitizer flag.
+  GlobalValue::SanitizerMetadata SanMD;
+  SanMD.NoAddress = true;
+  MetadataGV->setSanitizerMetadata(SanMD);
+}
+
+// Records, for every LDS global of kernel \p Func, the index triple
+// {0, item index, 0} that selects its entry. Item indices are assigned
+// consecutively in the order: direct static, indirect static, direct
+// dynamic, indirect dynamic.
+void AMDGPUSwLowerLDS::PopulateLDSToReplacementIndicesMap(Function *Func) {
+  auto &LDSParams = KernelToLDSParametersMap[Func];
+  assert(LDSParams);
+  SetVector<GlobalVariable *> *OrderedSets[] = {
+      &LDSParams->DirectAccess.StaticLDSGlobals,
+      &LDSParams->IndirectAccess.StaticLDSGlobals,
+      &LDSParams->DirectAccess.DynamicLDSGlobals,
+      &LDSParams->IndirectAccess.DynamicLDSGlobals};
+  // A single running index spans all four sets.
+  uint32_t ItemIdx = 0;
+  for (SetVector<GlobalVariable *> *LDSSet : OrderedSets) {
+    for (GlobalVariable *GV : *LDSSet) {
+      LDSParams->LDSToReplacementIndicesMap[GV] = {0, ItemIdx, 0};
+      ++ItemIdx;
+    }
+  }
+}
+
+static void ReplacesUsesOfGlobalInFunction(Function *Func, GlobalVariable *GV,
+ Value *Replacement) {
+ // Replace all uses of LDS global in this Function with a Replacement.
+ auto ReplaceUsesLambda = [Func](const Use &U) -> bool {
+ auto *FUU = U.getUser();
+ bool isCast = isa<BitCastOperator, AddrSpaceCastOperator>(FUU);
+ if (isCast && FUU->hasOneUse() && !FUU->user_begin()->user_empty())
+ FUU = *FUU->user_begin();
----------------
arsenm wrote:
Same as above, can't rely on these arbitrary address restrictions. You should also be able to find it nested in any kind of ConstantExpr including ptrtoint
https://github.com/llvm/llvm-project/pull/87265
More information about the llvm-commits
mailing list