[llvm] [LoopUtils] Cache VFs in addDiffRuntimeChecks (NFC) (PR #130157)

Ramkumar Ramachandra via llvm-commits llvm-commits at lists.llvm.org
Fri Mar 7 07:58:12 PST 2025


https://github.com/artagnon updated https://github.com/llvm/llvm-project/pull/130157

>From 30ebfa4b83352cba07a6ad54a4a71f4f597f4a3f Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Thu, 6 Mar 2025 14:17:27 +0000
Subject: [PATCH] [LoopUtils] Cache VFs in addDiffRuntimeChecks (NFC)

Caching the runtime VF, which is actually a vscale expression, is safe.
We previously thought that it was unsafe, partly due to a bad FIXME in
one of the tests. Strip the FIXME, and demonstrate that
GeneratedRTChecks::create does the right thing by moving the logic for
caching the runtime VF to LoopUtils. As a result, we improve the code in
GeneratedRTChecks::create, avoiding a non-intuitive footgun.
---
 llvm/lib/Transforms/Utils/LoopUtils.cpp       | 10 +++++--
 .../Transforms/Vectorize/LoopVectorize.cpp    | 26 ++++++++-----------
 .../sve-runtime-check-size-based-threshold.ll |  1 -
 3 files changed, 19 insertions(+), 18 deletions(-)

diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp
index 84c08556f8a25..63ad32f6fb561 100644
--- a/llvm/lib/Transforms/Utils/LoopUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp
@@ -2048,12 +2048,18 @@ Value *llvm::addDiffRuntimeChecks(
   // Map to keep track of created compares, The key is the pair of operands for
   // the compare, to allow detecting and re-using redundant compares.
   DenseMap<std::pair<Value *, Value *>, Value *> SeenCompares;
+  // Map to detect redundant values returned by GetVF.
+  DenseMap<Type *, Value *> SeenVFs;
   for (const auto &[SrcStart, SinkStart, AccessSize, NeedsFreeze] : Checks) {
     Type *Ty = SinkStart->getType();
+    Value *VF = SeenVFs.lookup(Ty);
+    if (!VF) {
+      VF = GetVF(ChkBuilder, Ty->getScalarSizeInBits());
+      SeenVFs.insert({Ty, VF});
+    }
     // Compute VF * IC * AccessSize.
     auto *VFTimesICTimesSize =
-        ChkBuilder.CreateMul(GetVF(ChkBuilder, Ty->getScalarSizeInBits()),
-                             ConstantInt::get(Ty, IC * AccessSize));
+        ChkBuilder.CreateMul(VF, ConstantInt::get(Ty, IC * AccessSize));
     Value *Diff =
         Expander.expandCodeFor(SE.getMinusSCEV(SinkStart, SrcStart), Ty, Loc);
 
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 74ddf906ff9fd..354c63236882a 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -1923,21 +1923,17 @@ class GeneratedRTChecks {
                                  "vector.memcheck");
 
       auto DiffChecks = RtPtrChecking.getDiffChecks();
-      if (DiffChecks) {
-        Value *RuntimeVF = nullptr;
-        MemRuntimeCheckCond = addDiffRuntimeChecks(
-            MemCheckBlock->getTerminator(), *DiffChecks, MemCheckExp,
-            [VF, &RuntimeVF](IRBuilderBase &B, unsigned Bits) {
-              if (!RuntimeVF)
-                RuntimeVF = getRuntimeVF(B, B.getIntNTy(Bits), VF);
-              return RuntimeVF;
-            },
-            IC);
-      } else {
-        MemRuntimeCheckCond = addRuntimeChecks(
-            MemCheckBlock->getTerminator(), L, RtPtrChecking.getChecks(),
-            MemCheckExp, VectorizerParams::HoistRuntimeChecks);
-      }
+      MemRuntimeCheckCond =
+          DiffChecks
+              ? addDiffRuntimeChecks(
+                    MemCheckBlock->getTerminator(), *DiffChecks, MemCheckExp,
+                    [VF](IRBuilderBase &B, unsigned Bits) {
+                      return getRuntimeVF(B, B.getIntNTy(Bits), VF);
+                    },
+                    IC)
+              : addRuntimeChecks(MemCheckBlock->getTerminator(), L,
+                                 RtPtrChecking.getChecks(), MemCheckExp,
+                                 VectorizerParams::HoistRuntimeChecks);
       assert(MemRuntimeCheckCond &&
              "no RT checks generated although RtPtrChecking "
              "claimed checks are required");
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/sve-runtime-check-size-based-threshold.ll b/llvm/test/Transforms/LoopVectorize/AArch64/sve-runtime-check-size-based-threshold.ll
index feb27caf305a2..40116f161dd6b 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/sve-runtime-check-size-based-threshold.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/sve-runtime-check-size-based-threshold.ll
@@ -5,7 +5,6 @@ target triple = "aarch64-unknown-linux-gnu"
 
 ; Test case where the minimum profitable trip count due to runtime checks
 ; exceeds VF.getKnownMinValue() * UF.
-; FIXME: The code currently incorrectly is missing a umax(VF * UF, 28).
 define void @min_trip_count_due_to_runtime_checks_1(ptr %dst.1, ptr %dst.2, ptr %src.1, ptr %src.2, i64 %n) {
 ; CHECK-LABEL: @min_trip_count_due_to_runtime_checks_1(
 ; CHECK-NEXT:  entry:



More information about the llvm-commits mailing list