[llvm] [LoopUtils] Cache VFs in addDiffRuntimeChecks (NFC) (PR #130157)

via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 6 10:24:31 PST 2025


llvmbot wrote:


@llvm/pr-subscribers-vectorizers

@llvm/pr-subscribers-llvm-transforms

Author: Ramkumar Ramachandra (artagnon)

<details>
<summary>Changes</summary>

Caching the runtime VF, which is actually a vscale expression, is safe; we previously thought it was unsafe, partly because of an incorrect FIXME in one of the tests. Strip the FIXME, and show that GeneratedRTChecks::create does the right thing by moving the runtime-VF caching logic into addDiffRuntimeChecks in LoopUtils, where the VF is now cached per type. As a result, GeneratedRTChecks::create loses a non-intuitive footgun: its old callback cached a single RuntimeVF and returned it on every later call, regardless of the Bits argument.
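
A minimal, self-contained sketch of that footgun, with LLVM types replaced by hypothetical stand-ins (`FakeValue`, `FakeEmitter`, `getVFSingleSlot` and `getVFPerType` are illustrative names, not LLVM API). A single-slot cache keeps returning the first value even when a later call asks for a different width, whereas a cache keyed by width (the patch keys its `DenseMap` on `Type *`) emits one value per width and reuses it.

```cpp
#include <map>
#include <string>

// Hypothetical stand-in for llvm::Value *: a string describing the emitted
// value, so the example builds without LLVM.
using FakeValue = std::string;

// Hypothetical stand-in for getRuntimeVF(B, B.getIntNTy(Bits), VF) with a
// scalable VF; counts how many values it has "emitted".
struct FakeEmitter {
  int NumEmitted = 0;
  FakeValue runtimeVF(unsigned Bits, unsigned MinVF) {
    ++NumEmitted;
    return "i" + std::to_string(Bits) + " vscale x " + std::to_string(MinVF);
  }
};

// Old callback shape: one cached slot that ignores Bits after the first call.
FakeValue getVFSingleSlot(FakeEmitter &E, FakeValue &Cached, unsigned Bits,
                          unsigned MinVF) {
  if (Cached.empty())
    Cached = E.runtimeVF(Bits, MinVF);
  return Cached; // has the first call's width, even if Bits has changed
}

// New shape: cache keyed by integer width (the patch keys on Type *).
FakeValue getVFPerType(FakeEmitter &E, std::map<unsigned, FakeValue> &Seen,
                       unsigned Bits, unsigned MinVF) {
  auto It = Seen.find(Bits);
  if (It != Seen.end())
    return It->second; // reuse the value already emitted for this width
  FakeValue VF = E.runtimeVF(Bits, MinVF);
  Seen.insert({Bits, VF});
  return VF;
}
```

If every check happens to use the same integer type, both shapes behave identically; keying on the type simply stops the result from depending on that.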

---
Full diff: https://github.com/llvm/llvm-project/pull/130157.diff


3 Files Affected:

- (modified) llvm/lib/Transforms/Utils/LoopUtils.cpp (+8-2) 
- (modified) llvm/lib/Transforms/Vectorize/LoopVectorize.cpp (+11-15) 
- (modified) llvm/test/Transforms/LoopVectorize/AArch64/sve-runtime-check-size-based-threshold.ll (-1) 


``````````diff
diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp
index ec1692a484ce0..074fae5bb0564 100644
--- a/llvm/lib/Transforms/Utils/LoopUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp
@@ -2048,12 +2048,18 @@ Value *llvm::addDiffRuntimeChecks(
   // Map to keep track of created compares, The key is the pair of operands for
   // the compare, to allow detecting and re-using redundant compares.
   DenseMap<std::pair<Value *, Value *>, Value *> SeenCompares;
+  // Map to detect redundant values returned by GetVF.
+  DenseMap<Type *, Value *> SeenVFs;
   for (const auto &[SrcStart, SinkStart, AccessSize, NeedsFreeze] : Checks) {
     Type *Ty = SinkStart->getType();
+    Value *VF = SeenVFs.lookup(Ty);
+    if (!VF) {
+      VF = GetVF(ChkBuilder, Ty->getScalarSizeInBits());
+      SeenVFs.insert({Ty, VF});
+    }
     // Compute VF * IC * AccessSize.
     auto *VFTimesUFTimesSize =
-        ChkBuilder.CreateMul(GetVF(ChkBuilder, Ty->getScalarSizeInBits()),
-                             ConstantInt::get(Ty, IC * AccessSize));
+        ChkBuilder.CreateMul(VF, ConstantInt::get(Ty, IC * AccessSize));
     Value *Diff =
         Expander.expandCodeFor(SE.getMinusSCEV(SinkStart, SrcStart), Ty, Loc);
 
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index cb860a472d8f7..5fe6551c3f8e2 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -1924,21 +1924,17 @@ class GeneratedRTChecks {
                                  "vector.memcheck");
 
       auto DiffChecks = RtPtrChecking.getDiffChecks();
-      if (DiffChecks) {
-        Value *RuntimeVF = nullptr;
-        MemRuntimeCheckCond = addDiffRuntimeChecks(
-            MemCheckBlock->getTerminator(), *DiffChecks, MemCheckExp,
-            [VF, &RuntimeVF](IRBuilderBase &B, unsigned Bits) {
-              if (!RuntimeVF)
-                RuntimeVF = getRuntimeVF(B, B.getIntNTy(Bits), VF);
-              return RuntimeVF;
-            },
-            IC);
-      } else {
-        MemRuntimeCheckCond = addRuntimeChecks(
-            MemCheckBlock->getTerminator(), L, RtPtrChecking.getChecks(),
-            MemCheckExp, VectorizerParams::HoistRuntimeChecks);
-      }
+      MemRuntimeCheckCond =
+          DiffChecks
+              ? addDiffRuntimeChecks(
+                    MemCheckBlock->getTerminator(), *DiffChecks, MemCheckExp,
+                    [VF](IRBuilderBase &B, unsigned Bits) {
+                      return getRuntimeVF(B, B.getIntNTy(Bits), VF);
+                    },
+                    IC)
+              : addRuntimeChecks(MemCheckBlock->getTerminator(), L,
+                                 RtPtrChecking.getChecks(), MemCheckExp,
+                                 VectorizerParams::HoistRuntimeChecks);
       assert(MemRuntimeCheckCond &&
              "no RT checks generated although RtPtrChecking "
              "claimed checks are required");
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/sve-runtime-check-size-based-threshold.ll b/llvm/test/Transforms/LoopVectorize/AArch64/sve-runtime-check-size-based-threshold.ll
index feb27caf305a2..40116f161dd6b 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/sve-runtime-check-size-based-threshold.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/sve-runtime-check-size-based-threshold.ll
@@ -5,7 +5,6 @@ target triple = "aarch64-unknown-linux-gnu"
 
 ; Test case where the minimum profitable trip count due to runtime checks
 ; exceeds VF.getKnownMinValue() * UF.
-; FIXME: The code currently incorrectly is missing a umax(VF * UF, 28).
 define void @min_trip_count_due_to_runtime_checks_1(ptr %dst.1, ptr %dst.2, ptr %src.1, ptr %src.2, i64 %n) {
 ; CHECK-LABEL: @min_trip_count_due_to_runtime_checks_1(
 ; CHECK-NEXT:  entry:

``````````

</details>
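
The per-check arithmetic in the addDiffRuntimeChecks loop can be modeled with plain integers. This sketch goes beyond what the diff shows and assumes that each check flags a conflict when the unsigned difference `SinkStart - SrcStart` is less than `VF * IC * AccessSize`, and that the per-check results are OR'ed together. `DiffCheck`, `CheckResult` and `runDiffChecks` are hypothetical names, and vscale is modeled as a known integer rather than an IR value.

```cpp
#include <cstdint>
#include <map>
#include <vector>

// Hypothetical, simplified model of one diff check. Bits only serves as the
// cache key here; all arithmetic below is done in 64 bits.
struct DiffCheck {
  uint64_t SrcStart;
  uint64_t SinkStart;
  uint64_t AccessSize; // bytes
  unsigned Bits;
};

struct CheckResult {
  bool Conflict;
  int VFComputations; // how many times the runtime VF was computed
};

// Assumption: a check conflicts when (SinkStart - SrcStart) <u
// VF * IC * AccessSize, and the per-check results are OR'ed together.
CheckResult runDiffChecks(const std::vector<DiffCheck> &Checks, uint64_t VScale,
                          uint64_t MinVF, uint64_t IC) {
  std::map<unsigned, uint64_t> SeenVFs; // mirrors the patch's per-type cache
  CheckResult R{false, 0};
  for (const DiffCheck &C : Checks) {
    uint64_t VF;
    auto It = SeenVFs.find(C.Bits);
    if (It != SeenVFs.end()) {
      VF = It->second;
    } else {
      VF = VScale * MinVF; // runtime VF of a scalable vector
      ++R.VFComputations;
      SeenVFs.insert({C.Bits, VF});
    }
    uint64_t Span = VF * IC * C.AccessSize;   // VF * IC * AccessSize
    uint64_t Diff = C.SinkStart - C.SrcStart; // wraps like an unsigned sub
    R.Conflict = R.Conflict || Diff < Span;
  }
  return R;
}
```

For example, with vscale = 2, a minimum VF of 4, IC = 2 and 4-byte accesses, the span is 64 bytes: a pair 100 bytes apart passes, a pair 32 bytes apart conflicts, and the VF is computed once for all checks that share a width.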


https://github.com/llvm/llvm-project/pull/130157

