[llvm-branch-commits] [NFC][BoundsChecking] Switch HandlerBuilder to CRTP (PR #120506)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Dec 18 16:35:12 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Vitaly Buka (vitalybuka)

<details>
<summary>Changes</summary>

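Make HandlerBuilder a CRTP base so that the derived class supplies the handler call (`CallHandler`) and its traits (`MayReturn`, `CanReuseTrapBB`). So far the only implementation is TrapHandlerBuilder. Every `ReportingMode` still uses it, so this change is NFC.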


---
Full diff: https://github.com/llvm/llvm-project/pull/120506.diff


1 File Affected:

- (modified) llvm/lib/Transforms/Instrumentation/BoundsChecking.cpp (+44-23) 


``````````diff
diff --git a/llvm/lib/Transforms/Instrumentation/BoundsChecking.cpp b/llvm/lib/Transforms/Instrumentation/BoundsChecking.cpp
index a985ab07e12eeb..56a9dd40c500f1 100644
--- a/llvm/lib/Transforms/Instrumentation/BoundsChecking.cpp
+++ b/llvm/lib/Transforms/Instrumentation/BoundsChecking.cpp
@@ -105,11 +105,11 @@ static Value *getBoundsCheckCond(Value *Ptr, Value *InstVal,
   return Or;
 }
 
-class HandlerBuilder {
-  BasicBlock *TrapBB = nullptr;
+template <class T> class HandlerBuilder {
+  BasicBlock *ReuseTrapBB = nullptr;
 
 public:
-  BasicBlock *build(BuilderTy &IRB) {
+  BasicBlock *build(BuilderTy &IRB, BasicBlock *Cont) {
     Function *Fn = IRB.GetInsertBlock()->getParent();
     auto DebugLoc = IRB.getCurrentDebugLocation();
     IRBuilder<>::InsertPointGuard Guard(IRB);
@@ -117,31 +117,45 @@ class HandlerBuilder {
     // Create a trapping basic block on demand using a callback. Depending on
     // flags, this will either create a single block for the entire function or
     // will create a fresh block every time it is called.
-    if (TrapBB && SingleTrapBB && !DebugTrapBB)
-      return TrapBB;
+    if (ReuseTrapBB)
+      return ReuseTrapBB;
 
-    TrapBB = BasicBlock::Create(Fn->getContext(), "trap", Fn);
+    BasicBlock *TrapBB = BasicBlock::Create(Fn->getContext(), "trap", Fn);
     IRB.SetInsertPoint(TrapBB);
 
-    Intrinsic::ID IntrID = DebugTrapBB ? Intrinsic::ubsantrap : Intrinsic::trap;
+    CallInst *TrapCall = T::CallHandler(IRB);
 
-    CallInst *TrapCall;
-    if (DebugTrapBB) {
-      TrapCall = IRB.CreateIntrinsic(
-          IntrID, {}, ConstantInt::get(IRB.getInt8Ty(), Fn->size()));
+    TrapCall->setDoesNotThrow();
+    TrapCall->setDebugLoc(DebugLoc);
+    if (T::MayReturn) {
+      IRB.CreateBr(Cont);
     } else {
-      TrapCall = IRB.CreateIntrinsic(IntrID, {}, {});
+      TrapCall->setDoesNotReturn();
+      IRB.CreateUnreachable();
     }
 
-    TrapCall->setDoesNotReturn();
-    TrapCall->setDoesNotThrow();
-    TrapCall->setDebugLoc(DebugLoc);
-    IRB.CreateUnreachable();
+    if (T::CanReuseTrapBB && SingleTrapBB && !DebugTrapBB)
+      ReuseTrapBB = TrapBB;
 
     return TrapBB;
   }
 };
 
+class TrapHandlerBuilder : public HandlerBuilder<TrapHandlerBuilder> {
+public:
+  static constexpr bool MayReturn = false;
+  static constexpr bool CanReuseTrapBB = true;
+
+  static CallInst *CallHandler(BuilderTy &IRB) {
+    if (!DebugTrapBB)
+      return IRB.CreateIntrinsic(Intrinsic::trap, {}, {});
+    return IRB.CreateIntrinsic(
+        Intrinsic::ubsantrap, {},
+        ConstantInt::get(IRB.getInt8Ty(),
+                         IRB.GetInsertBlock()->getParent()->size()));
+  }
+};
+
 /// Adds run-time bounds checks to memory accessing instructions.
 ///
 /// \p Or is the condition that should guard the trap.
@@ -164,7 +178,7 @@ static void insertBoundsCheck(Value *Or, BuilderTy &IRB, HandlerBuilderTy &HB) {
   BasicBlock *Cont = OldBB->splitBasicBlock(SplitI);
   OldBB->getTerminator()->eraseFromParent();
 
-  BasicBlock *TrapBB = HB.build(IRB);
+  BasicBlock *TrapBB = HB.build(IRB, Cont);
 
   if (C) {
     // If we have a constant zero, unconditionally branch.
@@ -180,8 +194,8 @@ static void insertBoundsCheck(Value *Or, BuilderTy &IRB, HandlerBuilderTy &HB) {
 
 template <typename HandlerBuilderTy>
 bool insertBoundsChecks(
-    const ArrayRef<std::pair<Instruction *, Value *>> &TrapInfo,
-    HandlerBuilderTy &HB) {
+    const ArrayRef<std::pair<Instruction *, Value *>> &TrapInfo) {
+  HandlerBuilderTy HB;
   for (const auto &Entry : TrapInfo) {
     Instruction *Inst = Entry.first;
     const DataLayout &DL = Inst->getParent()->getDataLayout();
@@ -193,7 +207,8 @@ bool insertBoundsChecks(
 }
 
 static bool addBoundsChecking(Function &F, TargetLibraryInfo &TLI,
-                              ScalarEvolution &SE) {
+                              ScalarEvolution &SE,
+                              BoundsCheckingPass::ReportingMode Mode) {
   if (F.hasFnAttribute(Attribute::NoSanitizeBounds))
     return false;
 
@@ -231,15 +246,21 @@ static bool addBoundsChecking(Function &F, TargetLibraryInfo &TLI,
       TrapInfo.push_back(std::make_pair(&I, Or));
   }
 
-  HandlerBuilder HB;
-  return insertBoundsChecks(TrapInfo, HB);
+  switch (Mode) {
+  case BoundsCheckingPass::ReportingMode::Trap:
+  case BoundsCheckingPass::ReportingMode::MinRuntimeAbort:
+  case BoundsCheckingPass::ReportingMode::MinRuntime:
+  case BoundsCheckingPass::ReportingMode::FullRuntime:
+  case BoundsCheckingPass::ReportingMode::FullRuntimeAbort:
+    return insertBoundsChecks<TrapHandlerBuilder>(TrapInfo);
+  }
 }
 
 PreservedAnalyses BoundsCheckingPass::run(Function &F, FunctionAnalysisManager &AM) {
   auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
   auto &SE = AM.getResult<ScalarEvolutionAnalysis>(F);
 
-  if (!addBoundsChecking(F, TLI, SE))
+  if (!addBoundsChecking(F, TLI, SE, Mode))
     return PreservedAnalyses::all();
 
   return PreservedAnalyses::none();

``````````

</details>


https://github.com/llvm/llvm-project/pull/120506


More information about the llvm-branch-commits mailing list