[llvm] fc5254c - [LoopUtils] Simplify code for runtime check generation a bit (NFCI).

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Tue Jun 4 04:12:50 PDT 2024


Author: Florian Hahn
Date: 2024-06-04T12:12:29+01:00
New Revision: fc5254c8ac02d29e7daab4ecce42cb5a82c8b3a2

URL: https://github.com/llvm/llvm-project/commit/fc5254c8ac02d29e7daab4ecce42cb5a82c8b3a2
DIFF: https://github.com/llvm/llvm-project/commit/fc5254c8ac02d29e7daab4ecce42cb5a82c8b3a2.diff

LOG: [LoopUtils] Simplify code for runtime check generation a bit (NFCI).

Store getSE result in variable to re-use and use structured bindings
when looping over bounds.

Added: 
    

Modified: 
    llvm/lib/Transforms/Utils/LoopUtils.cpp

Removed: 
    


################################################################################
diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp
index cc883a7dc2927..de3eb4a4ed5be 100644
--- a/llvm/lib/Transforms/Utils/LoopUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp
@@ -1743,16 +1743,16 @@ static PointerBounds expandBounds(const RuntimeCheckingPtrGroup *CG,
     auto *HighAR = cast<SCEVAddRecExpr>(High);
     auto *LowAR = cast<SCEVAddRecExpr>(Low);
     const Loop *OuterLoop = TheLoop->getParentLoop();
-    const SCEV *Recur = LowAR->getStepRecurrence(*Exp.getSE());
-    if (Recur == HighAR->getStepRecurrence(*Exp.getSE()) &&
+    ScalarEvolution &SE = *Exp.getSE();
+    const SCEV *Recur = LowAR->getStepRecurrence(SE);
+    if (Recur == HighAR->getStepRecurrence(SE) &&
         HighAR->getLoop() == OuterLoop && LowAR->getLoop() == OuterLoop) {
       BasicBlock *OuterLoopLatch = OuterLoop->getLoopLatch();
-      const SCEV *OuterExitCount =
-          Exp.getSE()->getExitCount(OuterLoop, OuterLoopLatch);
+      const SCEV *OuterExitCount = SE.getExitCount(OuterLoop, OuterLoopLatch);
       if (!isa<SCEVCouldNotCompute>(OuterExitCount) &&
           OuterExitCount->getType()->isIntegerTy()) {
-        const SCEV *NewHigh = cast<SCEVAddRecExpr>(High)->evaluateAtIteration(
-            OuterExitCount, *Exp.getSE());
+        const SCEV *NewHigh =
+            cast<SCEVAddRecExpr>(High)->evaluateAtIteration(OuterExitCount, SE);
         if (!isa<SCEVCouldNotCompute>(NewHigh)) {
           LLVM_DEBUG(dbgs() << "LAA: Expanded RT check for range to include "
                                "outer loop in order to permit hoisting\n");
@@ -1760,7 +1760,7 @@ static PointerBounds expandBounds(const RuntimeCheckingPtrGroup *CG,
           Low = cast<SCEVAddRecExpr>(Low)->getStart();
           // If there is a possibility that the stride is negative then we have
           // to generate extra checks to ensure the stride is positive.
-          if (!Exp.getSE()->isKnownNonNegative(Recur)) {
+          if (!SE.isKnownNonNegative(Recur)) {
             Stride = Recur;
             LLVM_DEBUG(dbgs() << "LAA: ... but need to check stride is "
                                  "positive: "
@@ -1821,8 +1821,7 @@ Value *llvm::addRuntimeChecks(
   // Our instructions might fold to a constant.
   Value *MemoryRuntimeCheck = nullptr;
 
-  for (const auto &Check : ExpandedChecks) {
-    const PointerBounds &A = Check.first, &B = Check.second;
+  for (const auto &[A, B] : ExpandedChecks) {
     // Check if two pointers (A and B) conflict where conflict is computed as:
     // start(A) <= end(B) && start(B) <= end(A)
 
@@ -1880,14 +1879,14 @@ Value *llvm::addDiffRuntimeChecks(
   // Map to keep track of created compares, The key is the pair of operands for
   // the compare, to allow detecting and re-using redundant compares.
   DenseMap<std::pair<Value *, Value *>, Value *> SeenCompares;
-  for (const auto &C : Checks) {
-    Type *Ty = C.SinkStart->getType();
+  for (const auto &[SrcStart, SinkStart, AccessSize, NeedsFreeze] : Checks) {
+    Type *Ty = SinkStart->getType();
     // Compute VF * IC * AccessSize.
     auto *VFTimesUFTimesSize =
         ChkBuilder.CreateMul(GetVF(ChkBuilder, Ty->getScalarSizeInBits()),
-                             ConstantInt::get(Ty, IC * C.AccessSize));
-    Value *Diff = Expander.expandCodeFor(
-        SE.getMinusSCEV(C.SinkStart, C.SrcStart), Ty, Loc);
+                             ConstantInt::get(Ty, IC * AccessSize));
+    Value *Diff =
+        Expander.expandCodeFor(SE.getMinusSCEV(SinkStart, SrcStart), Ty, Loc);
 
     // Check if the same compare has already been created earlier. In that case,
     // there is no need to check it again.
@@ -1898,7 +1897,7 @@ Value *llvm::addDiffRuntimeChecks(
     IsConflict =
         ChkBuilder.CreateICmpULT(Diff, VFTimesUFTimesSize, "diff.check");
     SeenCompares.insert({{Diff, VFTimesUFTimesSize}, IsConflict});
-    if (C.NeedsFreeze)
+    if (NeedsFreeze)
       IsConflict =
           ChkBuilder.CreateFreeze(IsConflict, IsConflict->getName() + ".fr");
     if (MemoryRuntimeCheck) {


        


More information about the llvm-commits mailing list