[llvm] LAA: make stride versioning code more robust (PR #97075)
Ramkumar Ramachandra via llvm-commits
llvm-commits at lists.llvm.org
Fri Jun 28 08:56:54 PDT 2024
https://github.com/artagnon created https://github.com/llvm/llvm-project/pull/97075
The current stride versioning code in collectStridedAccess is fragile and has implicit effects. Make it more robust by treating Stride - 1 == BTC explicitly as a special case, and by operating on ConstantRanges directly. Query the exact backedge-taken count instead of its symbolic maximum. As a side effect, getStrideFromPointer can now directly return the SCEVUnknown under a cast, eliminating a second round of cast-stripping in collectStridedAccess; the change also yields a positive test update in symbolic-stride.ll.
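For readability, here is the essence of the new check pulled out of the diff below into one block. This is a simplified sketch that mirrors the patch; PSE, SE, and StrideExpr are assumed to be set up as in collectStridedAccess, and the surrounding early-exits are elided:

  // Compare the signed ranges of Stride and the exact backedge-taken
  // count after matching bitwidths.
  const SCEV *BTC = PSE->getBackedgeTakenCount();
  ConstantRange BTCRange = SE->getSignedRange(BTC);
  ConstantRange StrideRange =
      SE->getSignedRange(StrideExpr).sextOrTrunc(BTCRange.getBitWidth());

  // Stride is zero-extended for the symbolic comparison with BTC.
  const SCEV *CastedStride =
      SE->getTruncateOrZeroExtend(StrideExpr, BTC->getType());
  const SCEV *StrideMinusOne =
      SE->getMinusSCEV(CastedStride, SE->getOne(CastedStride->getType()));

  // If Stride - 1 == BTC symbolically, a Stride == 1 predicate would
  // force BTC == 0, i.e. a loop that runs at most once: not worth
  // versioning. Likewise when the range difference is all positive.
  if (StrideMinusOne == BTC ||
      StrideRange.difference(BTCRange).isAllPositive())
    return;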
From e81a7ebbe0a9e1b127935d67dc713971484e0688 Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Thu, 27 Jun 2024 17:12:33 +0100
Subject: [PATCH] LAA: make stride versioning code more robust
The current stride versioning code in collectStridedAccess is fragile
and has implicit effects. Make it more robust by treating
Stride - 1 == BTC explicitly as a special case, and by operating on
ConstantRanges directly. Query the exact backedge-taken count instead
of its symbolic maximum. As a side effect, getStrideFromPointer can
now directly return the SCEVUnknown under a cast, eliminating a
second round of cast-stripping in collectStridedAccess; the change
also yields a positive test update in symbolic-stride.ll.
---
llvm/include/llvm/IR/ConstantRange.h | 3 +
llvm/lib/Analysis/LoopAccessAnalysis.cpp | 62 ++++++++-----------
llvm/lib/IR/ConstantRange.cpp | 5 ++
.../LoopAccessAnalysis/symbolic-stride.ll | 15 ++---
4 files changed, 38 insertions(+), 47 deletions(-)
diff --git a/llvm/include/llvm/IR/ConstantRange.h b/llvm/include/llvm/IR/ConstantRange.h
index 7b94b9c6c6d11..86d0a6b35d748 100644
--- a/llvm/include/llvm/IR/ConstantRange.h
+++ b/llvm/include/llvm/IR/ConstantRange.h
@@ -277,6 +277,9 @@ class [[nodiscard]] ConstantRange {
/// Return true if all values in this range are non-negative.
bool isAllNonNegative() const;
+ /// Return true if all values in this range are positive.
+ bool isAllPositive() const;
+
/// Return the largest unsigned value contained in the ConstantRange.
APInt getUnsignedMax() const;
diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
index 38bf6d8160aa9..4932ed61ec1ba 100644
--- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp
+++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
@@ -35,6 +35,7 @@
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugLoc.h"
@@ -61,6 +62,8 @@
#include <cassert>
#include <cstdint>
#include <iterator>
+#include <optional>
+#include <sys/types.h>
#include <utility>
#include <variant>
#include <vector>
@@ -2914,7 +2917,7 @@ static const SCEV *getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *L
if (const auto *C = dyn_cast<SCEVIntegralCastExpr>(V))
if (isa<SCEVUnknown>(C->getOperand()))
- return V;
+ return C->getOperand();
return nullptr;
}
@@ -2930,7 +2933,8 @@ void LoopAccessInfo::collectStridedAccess(Value *MemAccess) {
// computation of an interesting IV - but we chose not to as we
// don't have a cost model here, and broadening the scope exposes
// far too many unprofitable cases.
- const SCEV *StrideExpr = getStrideFromPointer(Ptr, PSE->getSE(), TheLoop);
+ ScalarEvolution *SE = PSE->getSE();
+ const SCEV *StrideExpr = getStrideFromPointer(Ptr, SE, TheLoop);
if (!StrideExpr)
return;
@@ -2943,10 +2947,6 @@ void LoopAccessInfo::collectStridedAccess(Value *MemAccess) {
return;
}
- // Avoid adding the "Stride == 1" predicate when we know that
- // Stride >= Trip-Count. Such a predicate will effectively optimize a single
- // or zero iteration loop, as Trip-Count <= Stride == 1.
- //
// TODO: We are currently not making a very informed decision on when it is
// beneficial to apply stride versioning. It might make more sense that the
// users of this analysis (such as the vectorizer) will trigger it, based on
@@ -2956,40 +2956,30 @@ void LoopAccessInfo::collectStridedAccess(Value *MemAccess) {
// of various possible stride specializations, considering the alternatives
// of using gather/scatters (if available).
- const SCEV *MaxBTC = PSE->getSymbolicMaxBackedgeTakenCount();
-
- // Match the types so we can compare the stride and the MaxBTC.
- // The Stride can be positive/negative, so we sign extend Stride;
- // The backedgeTakenCount is non-negative, so we zero extend MaxBTC.
- const DataLayout &DL = TheLoop->getHeader()->getDataLayout();
- uint64_t StrideTypeSizeBits = DL.getTypeSizeInBits(StrideExpr->getType());
- uint64_t BETypeSizeBits = DL.getTypeSizeInBits(MaxBTC->getType());
- const SCEV *CastedStride = StrideExpr;
- const SCEV *CastedBECount = MaxBTC;
- ScalarEvolution *SE = PSE->getSE();
- if (BETypeSizeBits >= StrideTypeSizeBits)
- CastedStride = SE->getNoopOrSignExtend(StrideExpr, MaxBTC->getType());
- else
- CastedBECount = SE->getZeroExtendExpr(MaxBTC, StrideExpr->getType());
- const SCEV *StrideMinusBETaken = SE->getMinusSCEV(CastedStride, CastedBECount);
- // Since TripCount == BackEdgeTakenCount + 1, checking:
- // "Stride >= TripCount" is equivalent to checking:
- // Stride - MaxBTC> 0
- if (SE->isKnownPositive(StrideMinusBETaken)) {
- LLVM_DEBUG(
- dbgs() << "LAA: Stride>=TripCount; No point in versioning as the "
- "Stride==1 predicate will imply that the loop executes "
- "at most once.\n");
+ // Get the two signed ranges and compare them after adjusting for bitwidth;
+ // note that the BTC range may extend down to -1.
+ const SCEV *BTC = PSE->getBackedgeTakenCount();
+ ConstantRange BTCRange = SE->getSignedRange(BTC);
+ ConstantRange StrideRange =
+ SE->getSignedRange(StrideExpr).sextOrTrunc(BTCRange.getBitWidth());
+
+ // Stride is zero-extended to compare with BTC.
+ const SCEV *CastedStride =
+ SE->getTruncateOrZeroExtend(StrideExpr, BTC->getType());
+ const SCEV *StrideMinusOne =
+ SE->getMinusSCEV(CastedStride, SE->getOne(CastedStride->getType()));
+
+ // Stride - 1 exactly equal to BTC is a special case for which the loop should
+ // not be versioned. Otherwise, the loop should not be versioned if the range
+ // difference is all positive.
+ if (StrideMinusOne == BTC ||
+ StrideRange.difference(BTCRange).isAllPositive()) {
+ LLVM_DEBUG(dbgs() << "LAA: Not versioning with Stride==1 predicate.\n");
return;
}
LLVM_DEBUG(dbgs() << "LAA: Found a strided access that we can version.\n");
- // Strip back off the integer cast, and check that our result is a
- // SCEVUnknown as we expect.
- const SCEV *StrideBase = StrideExpr;
- if (const auto *C = dyn_cast<SCEVIntegralCastExpr>(StrideBase))
- StrideBase = C->getOperand();
- SymbolicStrides[Ptr] = cast<SCEVUnknown>(StrideBase);
+ SymbolicStrides[Ptr] = cast<SCEVUnknown>(StrideExpr);
}
LoopAccessInfo::LoopAccessInfo(Loop *L, ScalarEvolution *SE,
diff --git a/llvm/lib/IR/ConstantRange.cpp b/llvm/lib/IR/ConstantRange.cpp
index 19041704a40be..b942894d34467 100644
--- a/llvm/lib/IR/ConstantRange.cpp
+++ b/llvm/lib/IR/ConstantRange.cpp
@@ -440,6 +440,11 @@ bool ConstantRange::isAllNonNegative() const {
return !isSignWrappedSet() && Lower.isNonNegative();
}
+bool ConstantRange::isAllPositive() const {
+ // Empty and full set are automatically treated correctly.
+ return !isSignWrappedSet() && Lower.isStrictlyPositive();
+}
+
APInt ConstantRange::getUnsignedMax() const {
if (isFullSet() || isUpperWrapped())
return APInt::getMaxValue(getBitWidth());
diff --git a/llvm/test/Analysis/LoopAccessAnalysis/symbolic-stride.ll b/llvm/test/Analysis/LoopAccessAnalysis/symbolic-stride.ll
index 7c1b11e22aef2..e9aeac7ac2bc5 100644
--- a/llvm/test/Analysis/LoopAccessAnalysis/symbolic-stride.ll
+++ b/llvm/test/Analysis/LoopAccessAnalysis/symbolic-stride.ll
@@ -170,23 +170,16 @@ define void @single_stride_castexpr_multiuse(i32 %offset, ptr %src, ptr %dst, i1
; CHECK-NEXT: %gep.src = getelementptr inbounds i32, ptr %src, i64 %iv.3
; CHECK-NEXT: Grouped accesses:
; CHECK-NEXT: Group [[GRP3]]:
-; CHECK-NEXT: (Low: ((4 * %iv.1) + %dst) High: (804 + (4 * %iv.1) + (-4 * (zext i32 %offset to i64))<nsw> + %dst))
-; CHECK-NEXT: Member: {((4 * %iv.1) + %dst),+,4}<%inner.loop>
+; CHECK-NEXT: (Low: (((4 * %iv.1) + %dst) umin ((4 * %iv.1) + (4 * (sext i32 %offset to i64) * (200 + (-1 * (zext i32 %offset to i64))<nsw>)<nsw>) + %dst)) High: (4 + (((4 * %iv.1) + %dst) umax ((4 * %iv.1) + (4 * (sext i32 %offset to i64) * (200 + (-1 * (zext i32 %offset to i64))<nsw>)<nsw>) + %dst))))
+; CHECK-NEXT: Member: {((4 * %iv.1) + %dst),+,(4 * (sext i32 %offset to i64))<nsw>}<%inner.loop>
; CHECK-NEXT: Group [[GRP4]]:
-; CHECK-NEXT: (Low: (4 + %src) High: (808 + (-4 * (zext i32 %offset to i64))<nsw> + %src))
-; CHECK-NEXT: Member: {(4 + %src),+,4}<%inner.loop>
+; CHECK-NEXT: (Low: ((4 * (zext i32 %offset to i64))<nuw><nsw> + %src) High: (804 + %src))
+; CHECK-NEXT: Member: {((4 * (zext i32 %offset to i64))<nuw><nsw> + %src),+,4}<%inner.loop>
; CHECK-EMPTY:
; CHECK-NEXT: Non vectorizable stores to invariant address were not found in loop.
; CHECK-NEXT: SCEV assumptions:
-; CHECK-NEXT: Equal predicate: %offset == 1
; CHECK-EMPTY:
; CHECK-NEXT: Expressions re-written:
-; CHECK-NEXT: [PSE] %gep.src = getelementptr inbounds i32, ptr %src, i64 %iv.3:
-; CHECK-NEXT: {((4 * (zext i32 %offset to i64))<nuw><nsw> + %src),+,4}<%inner.loop>
-; CHECK-NEXT: --> {(4 + %src),+,4}<%inner.loop>
-; CHECK-NEXT: [PSE] %gep.dst = getelementptr i32, ptr %dst, i64 %iv.2:
-; CHECK-NEXT: {((4 * %iv.1) + %dst),+,(4 * (sext i32 %offset to i64))<nsw>}<%inner.loop>
-; CHECK-NEXT: --> {((4 * %iv.1) + %dst),+,4}<%inner.loop>
; CHECK-NEXT: outer.header:
; CHECK-NEXT: Report: loop is not the innermost loop
; CHECK-NEXT: Dependences:
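As a postscript, the semantics of the new ConstantRange::isAllPositive helper (added in the ConstantRange.cpp hunk above) can be illustrated as follows. This is an illustrative sketch, not part of the patch; the 8-bit values are chosen arbitrarily:

  #include "llvm/ADT/APInt.h"
  #include "llvm/IR/ConstantRange.h"
  #include <cassert>
  using namespace llvm;

  void isAllPositiveExamples() {
    // [1, 8) holds {1, ..., 7}: every value is strictly positive.
    assert(ConstantRange(APInt(8, 1), APInt(8, 8)).isAllPositive());

    // [0, 8) holds 0, which is non-negative but not strictly positive.
    assert(!ConstantRange(APInt(8, 0), APInt(8, 8)).isAllPositive());

    // [-2, 8) holds negative values, so it is not all positive.
    assert(!ConstantRange(APInt(8, -2, /*isSigned=*/true), APInt(8, 8))
                .isAllPositive());
  }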