[llvm] [LoopUtils] Cache VFs in addDiffRuntimeChecks (NFC) (PR #130157)
Ramkumar Ramachandra via llvm-commits
llvm-commits at lists.llvm.org
Fri Mar 7 07:58:12 PST 2025
https://github.com/artagnon updated https://github.com/llvm/llvm-project/pull/130157
>From 30ebfa4b83352cba07a6ad54a4a71f4f597f4a3f Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Thu, 6 Mar 2025 14:17:27 +0000
Subject: [PATCH] [LoopUtils] Cache VFs in addDiffRuntimeChecks (NFC)
Caching the runtime-VF, which is actually a vscale expression, is safe,
although we previously thought it was unsafe, partly due to an incorrect
FIXME in one of the tests. Strip the FIXME, and demonstrate that
GeneratedRTChecks::create does the right thing by moving the logic for
caching the runtime-VF into LoopUtils. As a result, the code in
GeneratedRTChecks::create is simplified, removing a non-intuitive footgun.
---
llvm/lib/Transforms/Utils/LoopUtils.cpp | 10 +++++--
.../Transforms/Vectorize/LoopVectorize.cpp | 26 ++++++++-----------
.../sve-runtime-check-size-based-threshold.ll | 1 -
3 files changed, 19 insertions(+), 18 deletions(-)
diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp
index 84c08556f8a25..63ad32f6fb561 100644
--- a/llvm/lib/Transforms/Utils/LoopUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp
@@ -2048,12 +2048,18 @@ Value *llvm::addDiffRuntimeChecks(
// Map to keep track of created compares, The key is the pair of operands for
// the compare, to allow detecting and re-using redundant compares.
DenseMap<std::pair<Value *, Value *>, Value *> SeenCompares;
+ // Map to detect redundant values returned by GetVF.
+ DenseMap<Type *, Value *> SeenVFs;
for (const auto &[SrcStart, SinkStart, AccessSize, NeedsFreeze] : Checks) {
Type *Ty = SinkStart->getType();
+ Value *VF = SeenVFs.lookup(Ty);
+ if (!VF) {
+ VF = GetVF(ChkBuilder, Ty->getScalarSizeInBits());
+ SeenVFs.insert({Ty, VF});
+ }
// Compute VF * IC * AccessSize.
auto *VFTimesICTimesSize =
- ChkBuilder.CreateMul(GetVF(ChkBuilder, Ty->getScalarSizeInBits()),
- ConstantInt::get(Ty, IC * AccessSize));
+ ChkBuilder.CreateMul(VF, ConstantInt::get(Ty, IC * AccessSize));
Value *Diff =
Expander.expandCodeFor(SE.getMinusSCEV(SinkStart, SrcStart), Ty, Loc);
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 74ddf906ff9fd..354c63236882a 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -1923,21 +1923,17 @@ class GeneratedRTChecks {
"vector.memcheck");
auto DiffChecks = RtPtrChecking.getDiffChecks();
- if (DiffChecks) {
- Value *RuntimeVF = nullptr;
- MemRuntimeCheckCond = addDiffRuntimeChecks(
- MemCheckBlock->getTerminator(), *DiffChecks, MemCheckExp,
- [VF, &RuntimeVF](IRBuilderBase &B, unsigned Bits) {
- if (!RuntimeVF)
- RuntimeVF = getRuntimeVF(B, B.getIntNTy(Bits), VF);
- return RuntimeVF;
- },
- IC);
- } else {
- MemRuntimeCheckCond = addRuntimeChecks(
- MemCheckBlock->getTerminator(), L, RtPtrChecking.getChecks(),
- MemCheckExp, VectorizerParams::HoistRuntimeChecks);
- }
+ MemRuntimeCheckCond =
+ DiffChecks
+ ? addDiffRuntimeChecks(
+ MemCheckBlock->getTerminator(), *DiffChecks, MemCheckExp,
+ [VF](IRBuilderBase &B, unsigned Bits) {
+ return getRuntimeVF(B, B.getIntNTy(Bits), VF);
+ },
+ IC)
+ : addRuntimeChecks(MemCheckBlock->getTerminator(), L,
+ RtPtrChecking.getChecks(), MemCheckExp,
+ VectorizerParams::HoistRuntimeChecks);
assert(MemRuntimeCheckCond &&
"no RT checks generated although RtPtrChecking "
"claimed checks are required");
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/sve-runtime-check-size-based-threshold.ll b/llvm/test/Transforms/LoopVectorize/AArch64/sve-runtime-check-size-based-threshold.ll
index feb27caf305a2..40116f161dd6b 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/sve-runtime-check-size-based-threshold.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/sve-runtime-check-size-based-threshold.ll
@@ -5,7 +5,6 @@ target triple = "aarch64-unknown-linux-gnu"
; Test case where the minimum profitable trip count due to runtime checks
; exceeds VF.getKnownMinValue() * UF.
-; FIXME: The code currently incorrectly is missing a umax(VF * UF, 28).
define void @min_trip_count_due_to_runtime_checks_1(ptr %dst.1, ptr %dst.2, ptr %src.1, ptr %src.2, i64 %n) {
; CHECK-LABEL: @min_trip_count_due_to_runtime_checks_1(
; CHECK-NEXT: entry:
More information about the llvm-commits
mailing list