[llvm] [AMDGPU] Introduce address sanitizer instrumentation for LDS lowered by amdgpu-sw-lower-lds pass (PR #89208)
Matt Arsenault via llvm-commits
llvm-commits at lists.llvm.org
Thu May 23 04:16:36 PDT 2024
================
@@ -1239,6 +1236,290 @@ void AddressSanitizerPass::printPipeline(
OS << '>';
}
+/// Returns the redzone granularity used for stack objects and globals at the
+/// given shadow mapping scale: one shadow granule (1 << MappingScale bytes),
+/// but never less than 32 bytes.
+static uint64_t getRedzoneSizeForScale(int MappingScale) {
+  const unsigned GranuleBytes = 1U << MappingScale;
+  // Scales up to 5 use the 32-byte floor; scales 6 and 7 need 64 and 128.
+  return GranuleBytes > 32U ? GranuleBytes : 32U;
+}
+
+/// Minimum redzone size for a global at mapping scale \p Scale. Globals use
+/// the same redzone granularity as stack objects.
+static uint64_t getMinRedzoneSizeForGlobal(int Scale) {
+  const uint64_t MinRedzone = getRedzoneSizeForScale(Scale);
+  return MinRedzone;
+}
+
+/// Computes the right redzone for a global of \p SizeInBytes bytes at shadow
+/// mapping scale \p Scale. The result is chosen so that the object plus its
+/// redzone is a multiple of the minimum global redzone (MinRZ):
+///  - objects of at most MinRZ / 2 bytes (int, char[1], ...) are padded up
+///    to exactly MinRZ;
+///  - larger objects get roughly a quarter of their size in whole MinRZ
+///    units, clamped to [MinRZ, 256 KiB], plus the padding needed to reach
+///    the next multiple of MinRZ.
+static uint64_t getRedzoneSizeForGlobal(int Scale, uint64_t SizeInBytes) {
+  constexpr uint64_t kMaxRZ = 1 << 18;
+  const uint64_t MinRZ = getMinRedzoneSizeForGlobal(Scale);
+
+  uint64_t RZ;
+  if (SizeInBytes > MinRZ / 2) {
+    // About 1/4 of the object, expressed as whole MinRZ units, within bounds.
+    const uint64_t QuarterSize = (SizeInBytes / MinRZ / 4) * MinRZ;
+    RZ = std::clamp(QuarterSize, MinRZ, kMaxRZ);
+    // Cover the object's partial trailing MinRZ unit, if it has one.
+    if (const uint64_t Tail = SizeInBytes % MinRZ)
+      RZ += MinRZ - Tail;
+  } else {
+    RZ = MinRZ - SizeInBytes;
+  }
+
+  assert((RZ + SizeInBytes) % MinRZ == 0);
+  return RZ;
+}
+
+/// Looks up the software-LDS global of kernel \p F, named
+/// "llvm.amdgcn.sw.lds.<kernel name>". Returns null if \p M has no such
+/// global.
+static GlobalVariable *getKernelSwLDSGlobal(Module &M, Function &F) {
+  return M.getNamedGlobal(("llvm.amdgcn.sw.lds." + F.getName()).str());
+}
+
+/// Looks up the software-LDS metadata global of kernel \p F, named
+/// "llvm.amdgcn.sw.lds.<kernel name>.md". Returns null if \p M has no such
+/// global.
+static GlobalVariable *getKernelSwLDSMetadataGlobal(Module &M, Function &F) {
+  return M.getNamedGlobal(
+      ("llvm.amdgcn.sw.lds." + F.getName() + ".md").str());
+}
+
+/// Looks up the dynamic-LDS global of kernel \p F, named
+/// "llvm.amdgcn.<kernel name>.dynlds". Returns null if \p M has no such
+/// global.
+static GlobalVariable *getKernelSwDynLDSGlobal(Module &M, Function &F) {
+  return M.getNamedGlobal(("llvm.amdgcn." + F.getName() + ".dynlds").str());
+}
+
+/// Looks up the module-level global "llvm.amdgcn.sw.lds.base.table".
+/// Returns null if \p M has no such global.
+static GlobalVariable *getKernelSwLDSBaseGlobal(Module &M) {
+  return M.getNamedGlobal("llvm.amdgcn.sw.lds.base.table");
+}
+
+/// Sets the "amdgpu-lds-size" function attribute of \p Func to \p Offset in
+/// decimal. When \p UsesDynLDS is true, the value is written as the pair
+/// "Offset,Offset". Nothing is done when \p Offset is zero.
+static void updateLDSSizeFnAttr(Function *Func, uint32_t Offset,
+                                bool UsesDynLDS) {
+  if (Offset == 0)
+    return;
+
+  std::string SizeStr = std::to_string(Offset);
+  if (UsesDynLDS)
+    SizeStr += "," + std::to_string(Offset);
+  Func->addFnAttr("amdgpu-lds-size", SizeStr);
+}
+
+/// Attaches !absolute_symbol metadata to \p GV, holding the pair
+/// (Address, Address + 1) as constants of the pointer-sized integer type for
+/// address space 3.
+static void recordLDSAbsoluteAddress(Module &M, GlobalVariable *GV,
+                                     uint32_t Address) {
+  LLVMContext &Ctx = M.getContext();
+  IntegerType *IntTy = M.getDataLayout().getIntPtrType(Ctx, 3);
+  Metadata *Range[] = {
+      ConstantAsMetadata::get(ConstantInt::get(IntTy, Address)),
+      ConstantAsMetadata::get(ConstantInt::get(IntTy, Address + 1))};
+  GV->setMetadata(LLVMContext::MD_absolute_symbol, MDNode::get(Ctx, Range));
+}
+
+/// Rebuilds the software-LDS metadata global of kernel \p F
+/// ("llvm.amdgcn.sw.lds.<kernel>.md") so that every entry also records a
+/// right redzone sized by getRedzoneSizeForGlobal(). It then sets the
+/// kernel's "amdgpu-lds-size" attribute to the new total size and, if the
+/// kernel has a dynamic-LDS global, records that global's absolute address
+/// at the end of that total.
+///
+/// Each new entry holds five i32 fields:
+///   {GV.start, GV.size, Align(GV.size + Redzone.size), Redzone.start,
+///    Redzone.size}
+/// Non-empty entries are laid out back to back. The running offset is
+/// rounded up to the SW LDS global's alignment after each one. Absent or
+/// zero-sized entries take no space and get no redzone. All users of the old
+/// metadata global are moved to the new global, which takes over its name.
+///
+/// Does nothing if either the SW LDS global or its metadata global is missing.
+static void UpdateSwLDSMetadataWithRedzoneInfo(Function &F, int Scale) {
+  Module *M = F.getParent();
+  GlobalVariable *SwLDSMetadataGlobal = getKernelSwLDSMetadataGlobal(*M, F);
+  GlobalVariable *SwLDSGlobal = getKernelSwLDSGlobal(*M, F);
+  if (!SwLDSMetadataGlobal || !SwLDSGlobal)
+    return;
+
+  LLVMContext &Ctx = M->getContext();
+  Type *Int32Ty = Type::getInt32Ty(Ctx);
+  Constant *Zero = ConstantInt::get(Int32Ty, 0);
+
+  Constant *MdInit = SwLDSMetadataGlobal->getInitializer();
+  Align MdAlign = SwLDSMetadataGlobal->getAlign().valueOrOne();
+  Align LDSAlign = SwLDSGlobal->getAlign().valueOrOne();
+
+  // cast<> already asserts that the value type is a StructType.
+  auto *MDStructType = cast<StructType>(SwLDSMetadataGlobal->getValueType());
+  unsigned NumStructs = MDStructType->getNumElements();
+
+  // Layout of every new entry:
+  // {GV.start, GV.size, Align(GV.size + Redzone.size), Redzone.start,
+  //  Redzone.size}
+  StructType *LDSItemTy = StructType::create(
+      Ctx, {Int32Ty, Int32Ty, Int32Ty, Int32Ty, Int32Ty}, "");
+
+  std::vector<Type *> Items(NumStructs, LDSItemTy);
+  std::vector<Constant *> Initializers;
+  Initializers.reserve(NumStructs);
+  uint32_t MallocSize = 0;
+  for (unsigned I = 0; I < NumStructs; ++I) {
+    // Field 1 of an old entry is read as the variable's size. A missing or
+    // non-ConstantStruct entry (e.g. zeroinitializer) is treated as size 0.
+    unsigned GlobalSizeValue = 0;
+    if (auto *Member =
+            dyn_cast<ConstantStruct>(MdInit->getAggregateElement(I)))
+      GlobalSizeValue =
+          cast<ConstantInt>(Member->getAggregateElement(1U))->getZExtValue();
+
+    Constant *StartOffset = ConstantInt::get(Int32Ty, MallocSize);
+    if (!GlobalSizeValue) {
+      Initializers.push_back(ConstantStruct::get(
+          LDSItemTy, {StartOffset, Zero, Zero, Zero, Zero}));
+      continue;
+    }
+
+    const uint64_t RightRedzoneSize =
+        getRedzoneSizeForGlobal(Scale, GlobalSizeValue);
+    MallocSize += GlobalSizeValue;
+    Constant *RedzoneStartOffset = ConstantInt::get(Int32Ty, MallocSize);
+    MallocSize += RightRedzoneSize;
+    unsigned AlignedGlobalPlusRedzoneSize =
+        alignTo(GlobalSizeValue + RightRedzoneSize, LDSAlign);
+    Initializers.push_back(ConstantStruct::get(
+        LDSItemTy,
+        {StartOffset, ConstantInt::get(Int32Ty, GlobalSizeValue),
+         ConstantInt::get(Int32Ty, AlignedGlobalPlusRedzoneSize),
+         RedzoneStartOffset, ConstantInt::get(Int32Ty, RightRedzoneSize)}));
+    MallocSize = alignTo(MallocSize, LDSAlign);
+  }
+
+  GlobalVariable *SwDynLDS = getKernelSwDynLDSGlobal(*M, F);
+  bool UsesDynLDS = SwDynLDS != nullptr;
+  updateLDSSizeFnAttr(&F, MallocSize, UsesDynLDS);
+  if (UsesDynLDS)
+    recordLDSAbsoluteAddress(*M, SwDynLDS, MallocSize);
+
+  StructType *MetadataStructType = StructType::create(Ctx, Items, "");
+  GlobalVariable *NewSwLDSMetadataGlobal = new GlobalVariable(
+      *M, MetadataStructType, false, GlobalValue::InternalLinkage,
+      ConstantStruct::get(MetadataStructType, Initializers), "", nullptr,
+      GlobalValue::NotThreadLocal, 1, false);
+  NewSwLDSMetadataGlobal->setAlignment(MdAlign);
+  GlobalValue::SanitizerMetadata MD;
+  MD.NoAddress = true;
+  NewSwLDSMetadataGlobal->setSanitizerMetadata(MD);
+
+  for (Use &U : make_early_inc_range(SwLDSMetadataGlobal->uses())) {
+    User *Usr = U.getUser();
+    if (auto *GEP = dyn_cast<GEPOperator>(Usr)) {
+      // The entry type changed, so a GEP is rebuilt against
+      // MetadataStructType with its original indices. Swapping only the base
+      // pointer would keep indexing with the old type.
+      // NOTE(review): this assumes the GEP's source element type is the old
+      // metadata struct type, so its indices mean the same thing in the new
+      // layout.
+      if (auto *GEPInst = dyn_cast<GetElementPtrInst>(GEP)) {
+        SmallVector<Value *> Indices(GEPInst->indices());
+        IRBuilder<> IRB(GEPInst);
+        Value *NewGEP =
+            IRB.CreateGEP(MetadataStructType, NewSwLDSMetadataGlobal, Indices,
+                          "", GEPInst->isInBounds());
+        GEPInst->replaceAllUsesWith(NewGEP);
+        if (auto *NewGEPInst = dyn_cast<Instruction>(NewGEP))
+          NewGEPInst->takeName(GEPInst);
+        // Do not leave the now-dead GEP behind in the function.
+        GEPInst->eraseFromParent();
+      } else {
+        SmallVector<Constant *> Indices;
+        for (Use &Idx : GEP->indices())
+          Indices.push_back(cast<Constant>(Idx));
+        Constant *NewGEP = ConstantExpr::getGetElementPtr(
+            MetadataStructType, NewSwLDSMetadataGlobal, Indices,
+            GEP->isInBounds());
+        GEP->replaceAllUsesWith(NewGEP);
+      }
+    } else if (isa<LoadInst, StoreInst>(Usr)) {
+      // A direct access addresses offset 0 (entry 0, field 0), which is at
+      // the same address in the new global. Retargeting the operand in place
+      // keeps the access's alignment, volatility, atomic ordering and
+      // metadata. It is also correct when the global is a store's value
+      // operand rather than its pointer.
+      U.set(NewSwLDSMetadataGlobal);
+    } else
+      report_fatal_error("AMDGPU Sw LDS Metadata User instruction not handled");
+  }
+  // Catch remaining users, e.g. the now-unused constant GEPs replaced above.
+  SwLDSMetadataGlobal->replaceAllUsesWith(NewSwLDSMetadataGlobal);
+  NewSwLDSMetadataGlobal->takeName(SwLDSMetadataGlobal);
+  SwLDSMetadataGlobal->eraseFromParent();
+}
+
+static void poisonRedzonesForSwLDS(Function &F) {
+ Module *M = F.getParent();
+ GlobalVariable *SwLDSGlobal = getKernelSwLDSGlobal(*M, F);
+ GlobalVariable *SwLDSMetadataGlobal = getKernelSwLDSMetadataGlobal(*M, F);
+
+ if (!SwLDSGlobal || !SwLDSMetadataGlobal)
+ return;
+
+ LLVMContext &Ctx = M->getContext();
+ Type *Int64Ty = Type::getInt64Ty(Ctx);
+ Type *VoidTy = Type::getVoidTy(Ctx);
+ FunctionCallee AsanPoisonRegion = M->getOrInsertFunction(
+ StringRef("__asan_poison_region"),
+ FunctionType::get(VoidTy, {Int64Ty, Int64Ty}, false));
+ Constant *MdInit = SwLDSMetadataGlobal->getInitializer();
+
+ for (User *U : SwLDSGlobal->users()) {
+ StoreInst *SI = dyn_cast<StoreInst>(U);
+ if (!SI)
+ continue;
+
+ Type *PtrTy =
+ cast<PointerType>(SI->getValueOperand()->getType()->getScalarType());
+ unsigned int AddrSpace = PtrTy->getPointerAddressSpace();
+ if (AddrSpace != 1)
+ report_fatal_error("AMDGPU illegal store to SW LDS");
+
+ StructType *MDStructType =
+ cast<StructType>(SwLDSMetadataGlobal->getValueType());
+ assert(MDStructType);
----------------
arsenm wrote:
This assert is redundant: cast<> already asserts that the value is a StructType, so it can never return null here.
https://github.com/llvm/llvm-project/pull/89208
More information about the llvm-commits
mailing list