[llvm] [AMDGPU] Introduce address sanitizer instrumentation for LDS lowered by amdgpu-sw-lower-lds pass (PR #89208)

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Thu May 23 04:16:36 PDT 2024


================
@@ -1239,6 +1236,290 @@ void AddressSanitizerPass::printPipeline(
   OS << '>';
 }
 
+/// Returns the redzone size, in bytes, to use for the given ASan mapping scale.
+static uint64_t getRedzoneSizeForScale(int MappingScale) {
+  // Redzones for stack and globals never go below 32 bytes. Once
+  // 1 << MappingScale exceeds that (scales 6 and 7 give 64 and 128 bytes),
+  // the larger value is used instead.
+  const unsigned ScaleBytes = 1U << MappingScale;
+  return ScaleBytes > 32U ? ScaleBytes : 32U;
+}
+
+/// Returns the smallest redzone allowed for a global at \p Scale. This simply
+/// forwards to getRedzoneSizeForScale().
+static uint64_t getMinRedzoneSizeForGlobal(int Scale) {
+  const uint64_t MinRedzone = getRedzoneSizeForScale(Scale);
+  return MinRedzone;
+}
+
+/// Computes the redzone size, in bytes, for a global of \p SizeInBytes bytes
+/// at the given mapping \p Scale.
+///
+/// The result is chosen so that SizeInBytes + redzone is a multiple of the
+/// minimum redzone size for \p Scale (this is asserted before returning).
+static uint64_t getRedzoneSizeForGlobal(int Scale, uint64_t SizeInBytes) {
+  constexpr uint64_t MaxRedzone = 1 << 18;
+  const uint64_t MinRedzone = getMinRedzoneSizeForGlobal(Scale);
+
+  uint64_t Redzone = 0;
+  if (SizeInBytes > MinRedzone / 2) {
+    // Aim for roughly a quarter of the object size, expressed in whole
+    // MinRedzone units and kept within [MinRedzone, MaxRedzone].
+    const uint64_t Quarter = (SizeInBytes / MinRedzone / 4) * MinRedzone;
+    Redzone = std::clamp(Quarter, MinRedzone, MaxRedzone);
+
+    // Pad so that the object plus its redzone ends on a MinRedzone boundary.
+    const uint64_t Tail = SizeInBytes % MinRedzone;
+    if (Tail != 0)
+      Redzone += MinRedzone - Tail;
+  } else {
+    // Small objects (e.g. int, char[1]) that take up at most half of
+    // MinRedzone only get the remainder of a single MinRedzone-sized slot.
+    Redzone = MinRedzone - SizeInBytes;
+  }
+
+  assert((Redzone + SizeInBytes) % MinRedzone == 0);
+
+  return Redzone;
+}
+
+/// Looks up the global named "llvm.amdgcn.sw.lds.<name of F>" in \p M and
+/// returns whatever Module::getNamedGlobal() yields for it (callers in this
+/// file null-check the result).
+static GlobalVariable *getKernelSwLDSGlobal(Module &M, Function &F) {
+  SmallString<64> GlobalName;
+  GlobalName += "llvm.amdgcn.sw.lds.";
+  GlobalName += F.getName();
+  return M.getNamedGlobal(GlobalName);
+}
+
+/// Looks up the global named "llvm.amdgcn.sw.lds.<name of F>.md" in \p M and
+/// returns whatever Module::getNamedGlobal() yields for it (callers in this
+/// file null-check the result).
+static GlobalVariable *getKernelSwLDSMetadataGlobal(Module &M, Function &F) {
+  SmallString<64> GlobalName;
+  GlobalName += "llvm.amdgcn.sw.lds.";
+  GlobalName += F.getName();
+  GlobalName += ".md";
+  return M.getNamedGlobal(GlobalName);
+}
+
+/// Looks up the global named "llvm.amdgcn.<name of F>.dynlds" in \p M and
+/// returns whatever Module::getNamedGlobal() yields for it (the caller in
+/// this file null-checks the result).
+static GlobalVariable *getKernelSwDynLDSGlobal(Module &M, Function &F) {
+  SmallString<64> GlobalName;
+  GlobalName += "llvm.amdgcn.";
+  GlobalName += F.getName();
+  GlobalName += ".dynlds";
+  return M.getNamedGlobal(GlobalName);
+}
+
+/// Looks up the module-wide global named "llvm.amdgcn.sw.lds.base.table" in
+/// \p M and returns whatever Module::getNamedGlobal() yields for it.
+static GlobalVariable *getKernelSwLDSBaseGlobal(Module &M) {
+  // The name is fixed, so no temporary string buffer is needed.
+  return M.getNamedGlobal("llvm.amdgcn.sw.lds.base.table");
+}
+
+/// Adds the "amdgpu-lds-size" string attribute to \p Func.
+///
+/// The value is \p Offset printed as an unsigned decimal. When \p UsesDynLDS
+/// is set, the same number is emitted a second time after a comma
+/// ("<Offset>,<Offset>"). A zero \p Offset leaves the function untouched.
+static void updateLDSSizeFnAttr(Function *Func, uint32_t Offset,
+                                bool UsesDynLDS) {
+  if (Offset == 0)
+    return;
+
+  std::string AttrValue;
+  raw_string_ostream OS{AttrValue};
+  OS << format("%u", Offset);
+  if (UsesDynLDS)
+    OS << format(",%u", Offset);
+  Func->addFnAttr("amdgpu-lds-size", AttrValue);
+}
+
+/// Attaches !absolute_symbol metadata to \p GV consisting of the two
+/// constants Address and Address + 1.
+///
+/// The constants use the pointer-sized integer type of address space 3
+/// (presumably the AMDGPU LDS address space -- confirm against the target's
+/// address-space numbering).
+static void recordLDSAbsoluteAddress(Module &M, GlobalVariable *GV,
+                                     uint32_t Address) {
+  LLVMContext &Context = M.getContext();
+  auto *AddrIntTy = M.getDataLayout().getIntPtrType(Context, 3);
+
+  auto *RangeBegin =
+      ConstantAsMetadata::get(ConstantInt::get(AddrIntTy, Address));
+  auto *RangeEnd =
+      ConstantAsMetadata::get(ConstantInt::get(AddrIntTy, Address + 1));
+
+  MDNode *Range = MDNode::get(Context, {RangeBegin, RangeEnd});
+  GV->setMetadata(LLVMContext::MD_absolute_symbol, Range);
+}
+
+/// Rebuilds the software-LDS metadata global of \p F so that each entry also
+/// describes a redzone placed directly after the corresponding LDS item.
+///
+/// Nothing is done unless both "llvm.amdgcn.sw.lds.<F>" and
+/// "llvm.amdgcn.sw.lds.<F>.md" exist in the module. Otherwise:
+///  - a new metadata global is created with one five-field i32 entry per
+///    element of the old metadata struct (layout documented below);
+///  - item offsets are accumulated in order, padding each non-empty item to
+///    the alignment of the sw LDS global;
+///  - the resulting total size is recorded through updateLDSSizeFnAttr() and,
+///    when the kernel's ".dynlds" global exists, through
+///    recordLDSAbsoluteAddress();
+///  - GEP, load and store users of the old metadata global are redirected to
+///    the new one, which then takes over the old name; any other kind of user
+///    is a fatal error.
+///
+/// \param F     Function whose sw LDS metadata is rewritten.
+/// \param Scale Mapping scale forwarded to getRedzoneSizeForGlobal().
+static void UpdateSwLDSMetadataWithRedzoneInfo(Function &F, int Scale) {
+  Module *M = F.getParent();
+  GlobalVariable *SwLDSMetadataGlobal = getKernelSwLDSMetadataGlobal(*M, F);
+  GlobalVariable *SwLDSGlobal = getKernelSwLDSGlobal(*M, F);
+  if (!SwLDSMetadataGlobal || !SwLDSGlobal)
+    return;
+
+  LLVMContext &Ctx = M->getContext();
+  Type *Int32Ty = Type::getInt32Ty(Ctx);
+
+  Constant *MdInit = SwLDSMetadataGlobal->getInitializer();
+  Align MdAlign = Align(SwLDSMetadataGlobal->getAlign().valueOrOne());
+  Align LDSAlign = Align(SwLDSGlobal->getAlign().valueOrOne());
+
+  // cast<> already asserts that the value type is a StructType, so no
+  // separate null check is needed.
+  StructType *MDStructType =
+      cast<StructType>(SwLDSMetadataGlobal->getValueType());
+  unsigned NumStructs = MDStructType->getNumElements();
+
+  std::vector<Type *> Items;
+  std::vector<Constant *> Initializers;
+  Items.reserve(NumStructs);
+  Initializers.reserve(NumStructs);
+  uint32_t MallocSize = 0;
+  // Layout of every entry (all fields are i32):
+  //   {GV.start, GV.size, Align(GV.size + Redzone.size), Redzone.start,
+  //    Redzone.size}
+  StructType *LDSItemTy = StructType::create(
+      Ctx, {Int32Ty, Int32Ty, Int32Ty, Int32Ty, Int32Ty}, "");
+  Constant *Zero = ConstantInt::get(Int32Ty, 0);
+  for (unsigned I = 0; I < NumStructs; ++I) {
+    Items.push_back(LDSItemTy);
+
+    // Entries whose initializer is not a ConstantStruct are treated exactly
+    // like entries with a recorded size of zero.
+    unsigned GlobalSizeValue = 0;
+    if (auto *Member =
+            dyn_cast<ConstantStruct>(MdInit->getAggregateElement(I))) {
+      ConstantInt *GlobalSize =
+          cast<ConstantInt>(Member->getAggregateElement(1U));
+      GlobalSizeValue = GlobalSize->getZExtValue();
+    }
+
+    Constant *NewItemStartOffset = ConstantInt::get(Int32Ty, MallocSize);
+    if (GlobalSizeValue == 0) {
+      // Empty entry: only the current running offset is recorded.
+      Initializers.push_back(ConstantStruct::get(
+          LDSItemTy, {NewItemStartOffset, Zero, Zero, Zero, Zero}));
+      continue;
+    }
+
+    Constant *NewItemGlobalSizeConst =
+        ConstantInt::get(Int32Ty, GlobalSizeValue);
+    const uint64_t RightRedzoneSize =
+        getRedzoneSizeForGlobal(Scale, GlobalSizeValue);
+
+    // The redzone starts right after the item itself.
+    MallocSize += GlobalSizeValue;
+    Constant *NewItemRedzoneStartOffset =
+        ConstantInt::get(Int32Ty, MallocSize);
+    MallocSize += RightRedzoneSize;
+    Constant *NewItemRedzoneSize =
+        ConstantInt::get(Int32Ty, RightRedzoneSize);
+
+    unsigned NewItemAlignGlobalPlusRedzoneSize =
+        alignTo(GlobalSizeValue + RightRedzoneSize, LDSAlign);
+    Constant *NewItemAlignGlobalPlusRedzoneSizeConst =
+        ConstantInt::get(Int32Ty, NewItemAlignGlobalPlusRedzoneSize);
+
+    Initializers.push_back(ConstantStruct::get(
+        LDSItemTy, {NewItemStartOffset, NewItemGlobalSizeConst,
+                    NewItemAlignGlobalPlusRedzoneSizeConst,
+                    NewItemRedzoneStartOffset, NewItemRedzoneSize}));
+
+    // The next item starts at the next LDSAlign boundary.
+    MallocSize = alignTo(MallocSize, LDSAlign);
+  }
+
+  GlobalVariable *SwDynLDS = getKernelSwDynLDSGlobal(*M, F);
+  const bool UsesDynLDS = SwDynLDS != nullptr;
+  updateLDSSizeFnAttr(&F, MallocSize, UsesDynLDS);
+  if (UsesDynLDS)
+    recordLDSAbsoluteAddress(*M, SwDynLDS, MallocSize);
+
+  StructType *MetadataStructType = StructType::create(Ctx, Items, "");
+
+  GlobalVariable *NewSwLDSMetadataGlobal = new GlobalVariable(
+      *M, MetadataStructType, false, GlobalValue::InternalLinkage,
+      PoisonValue::get(MetadataStructType), "", nullptr,
+      GlobalValue::NotThreadLocal, 1, false);
+  Constant *Data = ConstantStruct::get(MetadataStructType, Initializers);
+  NewSwLDSMetadataGlobal->setInitializer(Data);
+  NewSwLDSMetadataGlobal->setAlignment(MdAlign);
+  // Mark the replacement global with SanitizerMetadata::NoAddress.
+  GlobalValue::SanitizerMetadata MD;
+  MD.NoAddress = true;
+  NewSwLDSMetadataGlobal->setSanitizerMetadata(MD);
+
+  // Address of field 0 of entry 0 of the new global. Built on demand so that
+  // no unused constant expression is created when there are no direct
+  // load/store users.
+  auto GetFirstFieldPtr = [&]() -> Constant * {
+    SmallVector<Constant *> Indices{Zero, Zero, Zero};
+    return ConstantExpr::getGetElementPtr(
+        MetadataStructType, NewSwLDSMetadataGlobal, Indices, true);
+  };
+
+  for (Use &U : make_early_inc_range(SwLDSMetadataGlobal->uses())) {
+    User *Usr = U.getUser();
+    if (GEPOperator *GEP = dyn_cast<GEPOperator>(Usr)) {
+      // Re-create the GEP with the same (constant) indices on the new type.
+      SmallVector<Constant *> Indices;
+      for (Use &Idx : GEP->indices())
+        Indices.push_back(cast<Constant>(Idx));
+      Constant *NewGEP = ConstantExpr::getGetElementPtr(
+          MetadataStructType, NewSwLDSMetadataGlobal, Indices, true);
+      GEP->replaceAllUsesWith(NewGEP);
+    } else if (LoadInst *Load = dyn_cast<LoadInst>(Usr)) {
+      // NOTE(review): the replacement load/store are created with default
+      // alignment/volatility rather than copying them from the original
+      // instruction -- confirm that is intended.
+      IRBuilder<> IRB(Load);
+      LoadInst *NewLoad = IRB.CreateLoad(Load->getType(), GetFirstFieldPtr());
+      Load->replaceAllUsesWith(NewLoad);
+      Load->eraseFromParent();
+    } else if (StoreInst *Store = dyn_cast<StoreInst>(Usr)) {
+      // A store produces no value, so there are no uses to transfer.
+      IRBuilder<> IRB(Store);
+      IRB.CreateStore(Store->getValueOperand(), GetFirstFieldPtr());
+      Store->eraseFromParent();
+    } else {
+      report_fatal_error("AMDGPU Sw LDS Metadata User instruction not handled");
+    }
+  }
+  SwLDSMetadataGlobal->replaceAllUsesWith(NewSwLDSMetadataGlobal);
+  NewSwLDSMetadataGlobal->takeName(SwLDSMetadataGlobal);
+  SwLDSMetadataGlobal->eraseFromParent();
+}
+
+static void poisonRedzonesForSwLDS(Function &F) {
+  Module *M = F.getParent();
+  GlobalVariable *SwLDSGlobal = getKernelSwLDSGlobal(*M, F);
+  GlobalVariable *SwLDSMetadataGlobal = getKernelSwLDSMetadataGlobal(*M, F);
+
+  if (!SwLDSGlobal || !SwLDSMetadataGlobal)
+    return;
+
+  LLVMContext &Ctx = M->getContext();
+  Type *Int64Ty = Type::getInt64Ty(Ctx);
+  Type *VoidTy = Type::getVoidTy(Ctx);
+  FunctionCallee AsanPoisonRegion = M->getOrInsertFunction(
+      StringRef("__asan_poison_region"),
+      FunctionType::get(VoidTy, {Int64Ty, Int64Ty}, false));
+  Constant *MdInit = SwLDSMetadataGlobal->getInitializer();
+
+  for (User *U : SwLDSGlobal->users()) {
+    StoreInst *SI = dyn_cast<StoreInst>(U);
+    if (!SI)
+      continue;
+
+    Type *PtrTy =
+        cast<PointerType>(SI->getValueOperand()->getType()->getScalarType());
+    unsigned int AddrSpace = PtrTy->getPointerAddressSpace();
+    if (AddrSpace != 1)
+      report_fatal_error("AMDGPU illegal store to SW LDS");
+
+    StructType *MDStructType =
+        cast<StructType>(SwLDSMetadataGlobal->getValueType());
+    assert(MDStructType);
----------------
arsenm wrote:

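This assert is redundant with the `cast<>` on the preceding lines: `cast<>` already asserts that its operand is non-null and of the requested type.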

https://github.com/llvm/llvm-project/pull/89208


More information about the llvm-commits mailing list