[llvm] [LoopUtils] Cache VFs in addDiffRuntimeChecks (NFC) (PR #130157)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Mar 6 10:24:31 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-vectorizers
@llvm/pr-subscribers-llvm-transforms
Author: Ramkumar Ramachandra (artagnon)
<details>
<summary>Changes</summary>
Caching the runtime VF, which is actually a vscale expression, is safe. We previously thought it was unsafe, partly because of a bad FIXME in one of the tests. Strip the FIXME, and demonstrate that GeneratedRTChecks::create does the right thing by moving the runtime-VF caching logic into LoopUtils. As a result, the code in GeneratedRTChecks::create is improved and avoids a non-intuitive footgun.
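
The caching pattern described above can be modeled outside LLVM as a minimal sketch. Everything here is hypothetical and for illustration only: `Check`, `GetVFFn`, and `computeBounds` are invented names, the map is keyed on a bit width rather than on `Type *`, and plain integers stand in for the `Value *` results. The point it shows is that the callback runs once per distinct key, however many checks share that key.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <unordered_map>
#include <vector>

// Hypothetical model of one diff check: only the bit width of the sink type
// and the access size in bytes are kept.
struct Check {
  unsigned Bits;
  std::uint64_t AccessSize;
};

// Hypothetical stand-in for the GetVF callback; it returns a plain integer
// instead of an IR value.
using GetVFFn = std::function<std::uint64_t(unsigned Bits)>;

// Computes VF * IC * AccessSize for every check, invoking GetVF at most once
// per distinct bit width and reusing the cached result afterwards.
std::vector<std::uint64_t> computeBounds(const std::vector<Check> &Checks,
                                         const GetVFFn &GetVF,
                                         std::uint64_t IC) {
  assert(GetVF && "GetVF callback must be callable");
  // Cache of values already returned by GetVF, keyed by bit width.
  std::unordered_map<unsigned, std::uint64_t> SeenVFs;
  std::vector<std::uint64_t> Bounds;
  Bounds.reserve(Checks.size());
  for (const Check &C : Checks) {
    auto It = SeenVFs.find(C.Bits);
    if (It == SeenVFs.end())
      It = SeenVFs.emplace(C.Bits, GetVF(C.Bits)).first;
    Bounds.push_back(It->second * IC * C.AccessSize);
  }
  return Bounds;
}
```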
---
Full diff: https://github.com/llvm/llvm-project/pull/130157.diff
3 Files Affected:
- (modified) llvm/lib/Transforms/Utils/LoopUtils.cpp (+8-2)
- (modified) llvm/lib/Transforms/Vectorize/LoopVectorize.cpp (+11-15)
- (modified) llvm/test/Transforms/LoopVectorize/AArch64/sve-runtime-check-size-based-threshold.ll (-1)
``````````diff
diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp
index ec1692a484ce0..074fae5bb0564 100644
--- a/llvm/lib/Transforms/Utils/LoopUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp
@@ -2048,12 +2048,18 @@ Value *llvm::addDiffRuntimeChecks(
// Map to keep track of created compares, The key is the pair of operands for
// the compare, to allow detecting and re-using redundant compares.
DenseMap<std::pair<Value *, Value *>, Value *> SeenCompares;
+ // Map to detect redundant values returned by GetVF.
+ DenseMap<Type *, Value *> SeenVFs;
for (const auto &[SrcStart, SinkStart, AccessSize, NeedsFreeze] : Checks) {
Type *Ty = SinkStart->getType();
+ Value *VF = SeenVFs.lookup(Ty);
+ if (!VF) {
+ VF = GetVF(ChkBuilder, Ty->getScalarSizeInBits());
+ SeenVFs.insert({Ty, VF});
+ }
// Compute VF * IC * AccessSize.
auto *VFTimesUFTimesSize =
- ChkBuilder.CreateMul(GetVF(ChkBuilder, Ty->getScalarSizeInBits()),
- ConstantInt::get(Ty, IC * AccessSize));
+ ChkBuilder.CreateMul(VF, ConstantInt::get(Ty, IC * AccessSize));
Value *Diff =
Expander.expandCodeFor(SE.getMinusSCEV(SinkStart, SrcStart), Ty, Loc);
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index cb860a472d8f7..5fe6551c3f8e2 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -1924,21 +1924,17 @@ class GeneratedRTChecks {
"vector.memcheck");
auto DiffChecks = RtPtrChecking.getDiffChecks();
- if (DiffChecks) {
- Value *RuntimeVF = nullptr;
- MemRuntimeCheckCond = addDiffRuntimeChecks(
- MemCheckBlock->getTerminator(), *DiffChecks, MemCheckExp,
- [VF, &RuntimeVF](IRBuilderBase &B, unsigned Bits) {
- if (!RuntimeVF)
- RuntimeVF = getRuntimeVF(B, B.getIntNTy(Bits), VF);
- return RuntimeVF;
- },
- IC);
- } else {
- MemRuntimeCheckCond = addRuntimeChecks(
- MemCheckBlock->getTerminator(), L, RtPtrChecking.getChecks(),
- MemCheckExp, VectorizerParams::HoistRuntimeChecks);
- }
+ MemRuntimeCheckCond =
+ DiffChecks
+ ? addDiffRuntimeChecks(
+ MemCheckBlock->getTerminator(), *DiffChecks, MemCheckExp,
+ [VF](IRBuilderBase &B, unsigned Bits) {
+ return getRuntimeVF(B, B.getIntNTy(Bits), VF);
+ },
+ IC)
+ : addRuntimeChecks(MemCheckBlock->getTerminator(), L,
+ RtPtrChecking.getChecks(), MemCheckExp,
+ VectorizerParams::HoistRuntimeChecks);
assert(MemRuntimeCheckCond &&
"no RT checks generated although RtPtrChecking "
"claimed checks are required");
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/sve-runtime-check-size-based-threshold.ll b/llvm/test/Transforms/LoopVectorize/AArch64/sve-runtime-check-size-based-threshold.ll
index feb27caf305a2..40116f161dd6b 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/sve-runtime-check-size-based-threshold.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/sve-runtime-check-size-based-threshold.ll
@@ -5,7 +5,6 @@ target triple = "aarch64-unknown-linux-gnu"
; Test case where the minimum profitable trip count due to runtime checks
; exceeds VF.getKnownMinValue() * UF.
-; FIXME: The code currently incorrectly is missing a umax(VF * UF, 28).
define void @min_trip_count_due_to_runtime_checks_1(ptr %dst.1, ptr %dst.2, ptr %src.1, ptr %src.2, i64 %n) {
; CHECK-LABEL: @min_trip_count_due_to_runtime_checks_1(
; CHECK-NEXT: entry:
``````````
</details>
https://github.com/llvm/llvm-project/pull/130157