[llvm] [AMDGPU] Introduce address sanitizer instrumentation for LDS lowered by amdgpu-sw-lower-lds pass (PR #89208)

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Thu May 23 04:16:36 PDT 2024


================
@@ -1239,6 +1236,290 @@ void AddressSanitizerPass::printPipeline(
   OS << '>';
 }
 
+// Returns the redzone size, in bytes, implied by the shadow mapping scale:
+// the larger of 32 and (1 << MappingScale).
+static uint64_t getRedzoneSizeForScale(int MappingScale) {
+  // The redzone used for stack and globals never drops below 32 bytes. For
+  // scales 6 and 7 the shifted value dominates, giving 64 and 128 bytes.
+  constexpr unsigned kBaseRedzone = 32U;
+  const unsigned ScaleGranule = 1U << MappingScale;
+  return ScaleGranule > kBaseRedzone ? ScaleGranule : kBaseRedzone;
+}
+
+// Minimum redzone size, in bytes, for a global at the given mapping scale.
+// This simply forwards to getRedzoneSizeForScale.
+static uint64_t getMinRedzoneSizeForGlobal(int Scale) {
+  const uint64_t MinRedzone = getRedzoneSizeForScale(Scale);
+  return MinRedzone;
+}
+
+// Computes the redzone size, in bytes, to place after an object of
+// SizeInBytes bytes at the given mapping scale. The result is chosen so that
+// (redzone + SizeInBytes) is a multiple of the minimum redzone size.
+static uint64_t getRedzoneSizeForGlobal(int Scale, uint64_t SizeInBytes) {
+  constexpr uint64_t kMaxRZ = uint64_t{1} << 18;
+  const uint64_t MinRZ = getMinRedzoneSizeForGlobal(Scale);
+
+  uint64_t RZ = 0;
+  if (SizeInBytes > MinRZ / 2) {
+    // Aim for roughly a quarter of the object size, expressed in whole MinRZ
+    // units, and keep the result within [MinRZ, kMaxRZ].
+    const uint64_t QuarterInMinRZUnits = (SizeInBytes / MinRZ / 4) * MinRZ;
+    RZ = std::clamp(QuarterInMinRZUnits, MinRZ, kMaxRZ);
+
+    // Pad so that object plus redzone ends on a MinRZ boundary.
+    const uint64_t Remainder = SizeInBytes % MinRZ;
+    if (Remainder != 0)
+      RZ += MinRZ - Remainder;
+  } else {
+    // Small objects (e.g. int, char[1]) occupying at most half of MinRZ get
+    // a reduced redzone: just enough to fill out a single MinRZ-sized slot.
+    RZ = MinRZ - SizeInBytes;
+  }
+
+  assert((RZ + SizeInBytes) % MinRZ == 0);
+
+  return RZ;
+}
+
+// Looks up the global variable in M named "llvm.amdgcn.sw.lds.<name of F>"
+// and returns the result of Module::getNamedGlobal for that name.
+static GlobalVariable *getKernelSwLDSGlobal(Module &M, Function &F) {
+  SmallString<64> GlobalName;
+  GlobalName += "llvm.amdgcn.sw.lds.";
+  GlobalName += F.getName();
+  return M.getNamedGlobal(GlobalName);
+}
+
+// Looks up the global variable in M named "llvm.amdgcn.sw.lds.<name of F>.md"
+// and returns the result of Module::getNamedGlobal for that name.
+static GlobalVariable *getKernelSwLDSMetadataGlobal(Module &M, Function &F) {
+  SmallString<64> GlobalName;
+  GlobalName += "llvm.amdgcn.sw.lds.";
+  GlobalName += F.getName();
+  GlobalName += ".md";
+  return M.getNamedGlobal(GlobalName);
+}
+
+// Looks up the global variable in M named "llvm.amdgcn.<name of F>.dynlds"
+// and returns the result of Module::getNamedGlobal for that name.
+static GlobalVariable *getKernelSwDynLDSGlobal(Module &M, Function &F) {
+  SmallString<64> GlobalName;
+  GlobalName += "llvm.amdgcn.";
+  GlobalName += F.getName();
+  GlobalName += ".dynlds";
+  return M.getNamedGlobal(GlobalName);
+}
+
+// Looks up the module-wide global named "llvm.amdgcn.sw.lds.base.table" and
+// returns the result of Module::getNamedGlobal for that name.
+static GlobalVariable *getKernelSwLDSBaseGlobal(Module &M) {
+  // The name is fixed, so no temporary string buffer is needed.
+  return M.getNamedGlobal("llvm.amdgcn.sw.lds.base.table");
+}
+
+// Attaches the "amdgpu-lds-size" string attribute to Func when Offset is
+// non-zero. The attribute value is "<Offset>", or "<Offset>,<Offset>" when
+// UsesDynLDS is set. A zero Offset leaves Func untouched.
+static void updateLDSSizeFnAttr(Function *Func, uint32_t Offset,
+                                bool UsesDynLDS) {
+  if (Offset == 0)
+    return;
+
+  std::string AttrValue;
+  raw_string_ostream OS(AttrValue);
+  OS << format("%u", Offset);
+  // With dynamic LDS the same value is emitted a second time after a comma.
+  if (UsesDynLDS)
+    OS << format(",%u", Offset);
+  Func->addFnAttr("amdgpu-lds-size", AttrValue);
+}
+
+// Tags GV with !absolute_symbol metadata describing the half-open range
+// [Address, Address + 1), using the module's pointer-sized integer type for
+// address space 3 (presumably the AMDGPU LDS address space -- confirm).
+static void recordLDSAbsoluteAddress(Module &M, GlobalVariable *GV,
+                                     uint32_t Address) {
+  LLVMContext &Ctx = M.getContext();
+  auto *AddrIntTy = M.getDataLayout().getIntPtrType(Ctx, 3);
+
+  // Lower and upper bound of the range, in that order.
+  Metadata *Range[] = {
+      ConstantAsMetadata::get(ConstantInt::get(AddrIntTy, Address)),
+      ConstantAsMetadata::get(ConstantInt::get(AddrIntTy, Address + 1))};
+
+  GV->setMetadata(LLVMContext::MD_absolute_symbol, MDNode::get(Ctx, Range));
+}
+
+static void UpdateSwLDSMetadataWithRedzoneInfo(Function &F, int Scale) {
+  Module *M = F.getParent();
+  GlobalVariable *SwLDSMetadataGlobal = getKernelSwLDSMetadataGlobal(*M, F);
+  GlobalVariable *SwLDSGlobal = getKernelSwLDSGlobal(*M, F);
+  if (!SwLDSMetadataGlobal || !SwLDSGlobal)
+    return;
+
+  LLVMContext &Ctx = M->getContext();
+  Type *Int32Ty = Type::getInt32Ty(Ctx);
+
+  Constant *MdInit = SwLDSMetadataGlobal->getInitializer();
+  Align MdAlign = Align(SwLDSMetadataGlobal->getAlign().valueOrOne());
+  Align LDSAlign = Align(SwLDSGlobal->getAlign().valueOrOne());
+
+  StructType *MDStructType =
+      cast<StructType>(SwLDSMetadataGlobal->getValueType());
+  assert(MDStructType);
+  unsigned NumStructs = MDStructType->getNumElements();
+
+  std::vector<Type *> Items;
+  std::vector<Constant *> Initializers;
+  uint32_t MallocSize = 0;
+  //{GV.start, Align(GV.size + Redzone.size), Redzone.start, Redzone.size}
+  StructType *LDSItemTy = StructType::create(
+      Ctx, {Int32Ty, Int32Ty, Int32Ty, Int32Ty, Int32Ty}, "");
+  for (unsigned i = 0; i < NumStructs; i++) {
+    Items.push_back(LDSItemTy);
+    ConstantStruct *member =
+        dyn_cast<ConstantStruct>(MdInit->getAggregateElement(i));
+    Constant *NewInitItem;
+    if (member) {
+      ConstantInt *GlobalSize =
+          cast<ConstantInt>(member->getAggregateElement(1U));
+      unsigned GlobalSizeValue = GlobalSize->getZExtValue();
+      Constant *NewItemStartOffset = ConstantInt::get(Int32Ty, MallocSize);
+      if (GlobalSizeValue) {
+        Constant *NewItemGlobalSizeConst =
+            ConstantInt::get(Int32Ty, GlobalSizeValue);
+        const uint64_t RightRedzoneSize =
+            getRedzoneSizeForGlobal(Scale, GlobalSizeValue);
+        MallocSize += GlobalSizeValue;
+        Constant *NewItemRedzoneStartOffset =
+            ConstantInt::get(Int32Ty, MallocSize);
+        MallocSize += RightRedzoneSize;
+        Constant *NewItemRedzoneSize =
+            ConstantInt::get(Int32Ty, RightRedzoneSize);
+
+        unsigned NewItemAlignGlobalPlusRedzoneSize =
+            alignTo(GlobalSizeValue + RightRedzoneSize, LDSAlign);
+        Constant *NewItemAlignGlobalPlusRedzoneSizeConst =
+            ConstantInt::get(Int32Ty, NewItemAlignGlobalPlusRedzoneSize);
+        NewInitItem = ConstantStruct::get(
+            LDSItemTy, {NewItemStartOffset, NewItemGlobalSizeConst,
+                        NewItemAlignGlobalPlusRedzoneSizeConst,
+                        NewItemRedzoneStartOffset, NewItemRedzoneSize});
+        MallocSize = alignTo(MallocSize, LDSAlign);
+      } else {
+        Constant *CurrMallocSize = ConstantInt::get(Int32Ty, MallocSize);
+        Constant *zero = ConstantInt::get(Int32Ty, 0);
+        NewInitItem = ConstantStruct::get(
+            LDSItemTy, {CurrMallocSize, zero, zero, zero, zero});
+      }
+    } else {
+      Constant *CurrMallocSize = ConstantInt::get(Int32Ty, MallocSize);
+      Constant *zero = ConstantInt::get(Int32Ty, 0);
+      NewInitItem = ConstantStruct::get(
+          LDSItemTy, {CurrMallocSize, zero, zero, zero, zero});
+    }
+    Initializers.push_back(NewInitItem);
+  }
+  GlobalVariable *SwDynLDS = getKernelSwDynLDSGlobal(*M, F);
+  bool usesDynLDS = SwDynLDS ? true : false;
+  updateLDSSizeFnAttr(&F, MallocSize, usesDynLDS);
+  if (usesDynLDS)
+    recordLDSAbsoluteAddress(*M, SwDynLDS, MallocSize);
+
+  StructType *MetadataStructType = StructType::create(Ctx, Items, "");
+
+  GlobalVariable *NewSwLDSMetadataGlobal = new GlobalVariable(
+      *M, MetadataStructType, false, GlobalValue::InternalLinkage,
+      PoisonValue::get(MetadataStructType), "", nullptr,
+      GlobalValue::NotThreadLocal, 1, false);
+  Constant *Data = ConstantStruct::get(MetadataStructType, Initializers);
+  NewSwLDSMetadataGlobal->setInitializer(Data);
+  NewSwLDSMetadataGlobal->setAlignment(MdAlign);
+  GlobalValue::SanitizerMetadata MD;
+  MD.NoAddress = true;
+  NewSwLDSMetadataGlobal->setSanitizerMetadata(MD);
+
+  for (Use &U : make_early_inc_range(SwLDSMetadataGlobal->uses())) {
----------------
arsenm wrote:

Why do you need to reconstruct every possible user? This handling is nowhere near complete, and can't work for an arbitrary call.

https://github.com/llvm/llvm-project/pull/89208


More information about the llvm-commits mailing list