[llvm] [SCEVPatternMatch] Extend with more matchers (PR #138836)
Ramkumar Ramachandra via llvm-commits
llvm-commits at lists.llvm.org
Thu May 8 13:32:21 PDT 2025
https://github.com/artagnon updated https://github.com/llvm/llvm-project/pull/138836
>From eecf6d1f52d027ceeec71ed7d5d986716c1b83d5 Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Wed, 7 May 2025 10:16:25 +0100
Subject: [PATCH 1/4] [SCEVPatternMatch] Extend with more matchers
---
.../Analysis/ScalarEvolutionPatternMatch.h | 48 +++++++++++++++++++
llvm/lib/Analysis/LoopAccessAnalysis.cpp | 39 +++++----------
llvm/lib/Analysis/ScalarEvolution.cpp | 21 ++++----
3 files changed, 69 insertions(+), 39 deletions(-)
diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
index 674147ca175ef..f5fa9c9b72c77 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
@@ -58,6 +58,8 @@ template <typename Class> struct class_match {
template <typename ITy> bool match(ITy *V) const { return isa<Class>(V); }
};
+inline class_match<const SCEV> m_SCEV() { return class_match<const SCEV>(); }
+
template <typename Class> struct bind_ty {
Class *&VR;
@@ -93,6 +95,41 @@ struct specificscev_ty {
/// Match if we have a specific specified SCEV.
inline specificscev_ty m_Specific(const SCEV *S) { return S; }
+template <typename Class> struct cst_match {
+ Class CV;
+
+ cst_match(Class Op0) : CV(Op0) {}
+
+ bool match(const SCEV *S) const {
+ assert((isa<SCEVCouldNotCompute>(S) || !S->getType()->isVectorTy()) &&
+ "no vector types expected from SCEVs");
+ auto *C = dyn_cast<SCEVConstant>(S);
+ return C && C->getAPInt() == CV;
+ }
+};
+
+/// Match an SCEV constant with a plain unsigned integer.
+inline cst_match<uint64_t> m_SCEVConstant(uint64_t V) { return V; }
+
+struct bind_cst_ty {
+ const APInt *&CR;
+
+ bind_cst_ty(const APInt *&Op0) : CR(Op0) {}
+
+ bool match(const SCEV *S) const {
+ assert((isa<SCEVCouldNotCompute>(S) || !S->getType()->isVectorTy()) &&
+ "no vector types expected from SCEVs");
+ auto *C = dyn_cast<SCEVConstant>(S);
+ if (!C)
+ return false;
+ CR = &C->getAPInt();
+ return true;
+ }
+};
+
+/// Match an SCEV constant and bind it to an APInt.
+inline bind_cst_ty m_SCEVConstant(const APInt *&C) { return C; }
+
/// Match a unary SCEV.
template <typename SCEVTy, typename Op0_t> struct SCEVUnaryExpr_match {
Op0_t Op0;
@@ -149,6 +186,17 @@ m_scev_Add(const Op0_t &Op0, const Op1_t &Op1) {
return m_scev_Binary<SCEVAddExpr>(Op0, Op1);
}
+template <typename Op0_t, typename Op1_t>
+inline SCEVBinaryExpr_match<SCEVMulExpr, Op0_t, Op1_t>
+m_scev_Mul(const Op0_t &Op0, const Op1_t &Op1) {
+ return m_scev_Binary<SCEVMulExpr>(Op0, Op1);
+}
+
+template <typename Op0_t, typename Op1_t>
+inline SCEVBinaryExpr_match<SCEVUDivExpr, Op0_t, Op1_t>
+m_scev_UDiv(const Op0_t &Op0, const Op1_t &Op1) {
+ return m_scev_Binary<SCEVUDivExpr>(Op0, Op1);
+}
} // namespace SCEVPatternMatch
} // namespace llvm
diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
index f222a9905c3bb..dc3ea7833c129 100644
--- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp
+++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
@@ -30,6 +30,7 @@
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include "llvm/Analysis/ScalarEvolutionPatternMatch.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
@@ -65,6 +66,7 @@
#include <vector>
using namespace llvm;
+using namespace llvm::SCEVPatternMatch;
#define DEBUG_TYPE "loop-accesses"
@@ -811,8 +813,8 @@ getStrideFromAddRec(const SCEVAddRecExpr *AR, const Loop *Lp, Type *AccessTy,
const SCEV *Step = AR->getStepRecurrence(*PSE.getSE());
// Calculate the pointer stride and check if it is constant.
- const SCEVConstant *C = dyn_cast<SCEVConstant>(Step);
- if (!C) {
+ const APInt *APStepVal;
+ if (!match(Step, m_SCEVConstant(APStepVal))) {
LLVM_DEBUG({
dbgs() << "LAA: Bad stride - Not a constant strided ";
if (Ptr)
@@ -825,13 +827,12 @@ getStrideFromAddRec(const SCEVAddRecExpr *AR, const Loop *Lp, Type *AccessTy,
const auto &DL = Lp->getHeader()->getDataLayout();
TypeSize AllocSize = DL.getTypeAllocSize(AccessTy);
int64_t Size = AllocSize.getFixedValue();
- const APInt &APStepVal = C->getAPInt();
// Huge step value - give up.
- if (APStepVal.getBitWidth() > 64)
+ if (APStepVal->getBitWidth() > 64)
return std::nullopt;
- int64_t StepVal = APStepVal.getSExtValue();
+ int64_t StepVal = APStepVal->getSExtValue();
// Strided access.
int64_t Stride = StepVal / Size;
@@ -2061,11 +2062,10 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx,
DL, SE, *(PSE.getSymbolicMaxBackedgeTakenCount()), *Dist, MaxStride))
return Dependence::NoDep;
- const SCEVConstant *ConstDist = dyn_cast<SCEVConstant>(Dist);
-
// Attempt to prove strided accesses independent.
- if (ConstDist) {
- uint64_t Distance = ConstDist->getAPInt().abs().getZExtValue();
+ const APInt *ConstDist = nullptr;
+ if (match(Dist, m_SCEVConstant(ConstDist))) {
+ uint64_t Distance = ConstDist->abs().getZExtValue();
// If the distance between accesses and their strides are known constants,
// check whether the accesses interlace each other.
@@ -2111,9 +2111,8 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx,
FoundNonConstantDistanceDependence |= ShouldRetryWithRuntimeCheck;
return Dependence::Unknown;
}
- if (!HasSameSize ||
- couldPreventStoreLoadForward(
- ConstDist->getAPInt().abs().getZExtValue(), TypeByteSize)) {
+ if (!HasSameSize || couldPreventStoreLoadForward(
+ ConstDist->abs().getZExtValue(), TypeByteSize)) {
LLVM_DEBUG(
dbgs() << "LAA: Forward but may prevent st->ld forwarding\n");
return Dependence::ForwardButPreventsForwarding;
@@ -2864,20 +2863,8 @@ static const SCEV *getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *L
// Strip off the size of access multiplication if we are still analyzing the
// pointer.
- if (OrigPtr == Ptr) {
- if (auto *M = dyn_cast<SCEVMulExpr>(V)) {
- auto *StepConst = dyn_cast<SCEVConstant>(M->getOperand(0));
- if (!StepConst)
- return nullptr;
-
- auto StepVal = StepConst->getAPInt().trySExtValue();
- // Bail out on a non-unit pointer access size.
- if (!StepVal || StepVal != 1)
- return nullptr;
-
- V = M->getOperand(1);
- }
- }
+ if (OrigPtr == Ptr)
+ match(V, m_scev_Mul(m_SCEVConstant(1), m_SCEV(V)));
// Note that the restriction after this loop invariant check are only
// profitability restrictions.
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index ac69ad598a65a..f0c35d233c42a 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -7149,16 +7149,11 @@ ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
assert(SE.getTypeSizeInBits(S->getType()) == BitWidth &&
"Should be!");
- // Peel off a constant offset:
- if (auto *SA = dyn_cast<SCEVAddExpr>(S)) {
- // In the future we could consider being smarter here and handle
- // {Start+Step,+,Step} too.
- if (SA->getNumOperands() != 2 || !isa<SCEVConstant>(SA->getOperand(0)))
- return;
-
- Offset = cast<SCEVConstant>(SA->getOperand(0))->getAPInt();
- S = SA->getOperand(1);
- }
+ // Peel off a constant offset. In the future we could consider being
+ // smarter here and handle {Start+Step,+,Step} too.
+ const APInt *Off;
+ if (match(S, m_scev_Add(m_SCEVConstant(Off), m_SCEV(S))))
+ Offset = *Off;
// Peel off a cast operation
if (auto *SCast = dyn_cast<SCEVIntegralCastExpr>(S)) {
@@ -7337,11 +7332,11 @@ bool ScalarEvolution::isGuaranteedNotToBePoison(const SCEV *Op) {
bool ScalarEvolution::isGuaranteedNotToCauseUB(const SCEV *Op) {
return !SCEVExprContains(Op, [this](const SCEV *S) {
- auto *UDiv = dyn_cast<SCEVUDivExpr>(S);
+ const SCEV *Op1;
+ bool M = match(S, m_scev_UDiv(m_SCEV(), m_SCEV(Op1)));
// The UDiv may be UB if the divisor is poison or zero. Unless the divisor
// is a non-zero constant, we have to assume the UDiv may be UB.
- return UDiv && (!isKnownNonZero(UDiv->getOperand(1)) ||
- !isGuaranteedNotToBePoison(UDiv->getOperand(1)));
+ return M && (!isKnownNonZero(Op1) || !isGuaranteedNotToBePoison(Op1));
});
}
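
A quick illustration of how the new matchers from this first patch compose at a call site. This is a minimal sketch, not part of the patch, and the helper name is made up:

  #include "llvm/Analysis/ScalarEvolution.h"
  #include "llvm/Analysis/ScalarEvolutionPatternMatch.h"

  using namespace llvm;
  using namespace llvm::SCEVPatternMatch;

  // Hypothetical helper: true if S has the shape ((1 * X) /u C) for a
  // non-zero constant C. m_SCEVConstant(1) uses the uint64_t overload to
  // match a specific value, m_SCEVConstant(C) uses the APInt-binding
  // overload, and m_SCEV() matches any SCEV without binding it.
  static bool isUnitMulUDivByNonZeroConst(const SCEV *S) {
    const APInt *C;
    return match(S, m_scev_UDiv(m_scev_Mul(m_SCEVConstant(1), m_SCEV()),
                                m_SCEVConstant(C))) &&
           !C->isZero();
  }

Note that the two m_SCEVConstant overloads are distinguished only by argument type at this point in the series; the next patch gives them distinct names.
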
>From 5d3656055602f042bca5bc0a44038e611e3f0478 Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Thu, 8 May 2025 16:42:57 +0100
Subject: [PATCH 2/4] [SCEVPatternMatch] Rename fns
---
llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h | 4 ++--
llvm/lib/Analysis/LoopAccessAnalysis.cpp | 6 +++---
llvm/lib/Analysis/ScalarEvolution.cpp | 2 +-
3 files changed, 6 insertions(+), 6 deletions(-)
diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
index f5fa9c9b72c77..536d74f296931 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
@@ -109,7 +109,7 @@ template <typename Class> struct cst_match {
};
/// Match an SCEV constant with a plain unsigned integer.
-inline cst_match<uint64_t> m_SCEVConstant(uint64_t V) { return V; }
+inline cst_match<uint64_t> m_scev_SpecificInt(uint64_t V) { return V; }
struct bind_cst_ty {
const APInt *&CR;
@@ -128,7 +128,7 @@ struct bind_cst_ty {
};
/// Match an SCEV constant and bind it to an APInt.
-inline bind_cst_ty m_SCEVConstant(const APInt *&C) { return C; }
+inline bind_cst_ty m_scev_APInt(const APInt *&C) { return C; }
/// Match a unary SCEV.
template <typename SCEVTy, typename Op0_t> struct SCEVUnaryExpr_match {
diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
index dc3ea7833c129..d98d0d936d788 100644
--- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp
+++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
@@ -814,7 +814,7 @@ getStrideFromAddRec(const SCEVAddRecExpr *AR, const Loop *Lp, Type *AccessTy,
// Calculate the pointer stride and check if it is constant.
const APInt *APStepVal;
- if (!match(Step, m_SCEVConstant(APStepVal))) {
+ if (!match(Step, m_scev_APInt(APStepVal))) {
LLVM_DEBUG({
dbgs() << "LAA: Bad stride - Not a constant strided ";
if (Ptr)
@@ -2064,7 +2064,7 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx,
// Attempt to prove strided accesses independent.
const APInt *ConstDist = nullptr;
- if (match(Dist, m_SCEVConstant(ConstDist))) {
+ if (match(Dist, m_scev_APInt(ConstDist))) {
uint64_t Distance = ConstDist->abs().getZExtValue();
// If the distance between accesses and their strides are known constants,
@@ -2864,7 +2864,7 @@ static const SCEV *getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *L
// Strip off the size of access multiplication if we are still analyzing the
// pointer.
if (OrigPtr == Ptr)
- match(V, m_scev_Mul(m_SCEVConstant(1), m_SCEV(V)));
+ match(V, m_scev_Mul(m_scev_SpecificInt(1), m_SCEV(V)));
// Note that the restriction after this loop invariant check are only
// profitability restrictions.
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index f0c35d233c42a..3f9614254ae7a 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -7152,7 +7152,7 @@ ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
// Peel off a constant offset. In the future we could consider being
// smarter here and handle {Start+Step,+,Step} too.
const APInt *Off;
- if (match(S, m_scev_Add(m_SCEVConstant(Off), m_SCEV(S))))
+ if (match(S, m_scev_Add(m_scev_APInt(Off), m_SCEV(S))))
Offset = *Off;
// Peel off a cast operation
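
After the rename in this patch the two constant matchers are no longer spelled the same. A small sketch of the distinction (hypothetical helpers, same includes and using-directives as the sketch after patch 1):

  // m_scev_SpecificInt(2) only matches the constant 2.
  static bool isMulByTwo(const SCEV *S) {
    return match(S, m_scev_Mul(m_scev_SpecificInt(2), m_SCEV()));
  }

  // m_scev_APInt(Factor) matches any SCEVConstant and binds its APInt.
  static bool getConstMulFactor(const SCEV *S, const APInt *&Factor) {
    return match(S, m_scev_Mul(m_scev_APInt(Factor), m_SCEV()));
  }
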
>From 780e98ebab8a430e8ceeca568848b208e3572c07 Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Thu, 8 May 2025 20:55:14 +0100
Subject: [PATCH 3/4] [LAA] Revert some code
---
llvm/lib/Analysis/LoopAccessAnalysis.cpp | 16 ++++++++++++++--
1 file changed, 14 insertions(+), 2 deletions(-)
diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
index d98d0d936d788..535b71cd5897e 100644
--- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp
+++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
@@ -2863,8 +2863,20 @@ static const SCEV *getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *L
// Strip off the size of access multiplication if we are still analyzing the
// pointer.
- if (OrigPtr == Ptr)
- match(V, m_scev_Mul(m_scev_SpecificInt(1), m_SCEV(V)));
+ if (OrigPtr == Ptr) {
+ if (auto *M = dyn_cast<SCEVMulExpr>(V)) {
+ auto *StepConst = dyn_cast<SCEVConstant>(M->getOperand(0));
+ if (!StepConst)
+ return nullptr;
+
+ auto StepVal = StepConst->getAPInt().trySExtValue();
+ // Bail out on a non-unit pointer access size.
+ if (!StepVal || StepVal != 1)
+ return nullptr;
+
+ V = M->getOperand(1);
+ }
+ }
// Note that the restriction after this loop invariant check are only
// profitability restrictions.
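
An observation on this revert, not stated in the patch itself: the matcher form introduced earlier only strips the multiplication when the pattern matches, while the original code also returns nullptr when the first operand is a non-constant or non-unit factor, so the rewrite was a subtle behaviour change. For reference, a hedged sketch of how those lines of getStrideFromPointer could keep the matchers while preserving the bail-out (assuming m_scev_Mul only matches two-operand muls; this fragment belongs inside that function and is not part of the patch):

  if (OrigPtr == Ptr) {
    if (isa<SCEVMulExpr>(V)) {
      const SCEV *Remainder;
      // Bail out on a non-unit (or non-constant) pointer access size.
      if (!match(V, m_scev_Mul(m_scev_SpecificInt(1), m_SCEV(Remainder))))
        return nullptr;
      V = Remainder;
    }
  }
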
>From 1a5b7b72579c4025b456a2385ce42f3cc4c447b5 Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Thu, 8 May 2025 21:30:35 +0100
Subject: [PATCH 4/4] [SCEVPatternMatch] Introduce is_specific_cst
---
.../Analysis/ScalarEvolutionPatternMatch.h | 19 +++++++------------
1 file changed, 7 insertions(+), 12 deletions(-)
diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
index 536d74f296931..5e53a4f502153 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
@@ -23,6 +23,8 @@ template <typename Pattern> bool match(const SCEV *S, const Pattern &P) {
}
template <typename Predicate> struct cst_pred_ty : public Predicate {
+ cst_pred_ty() = default;
+ cst_pred_ty(uint64_t V) : Predicate(V) {}
bool match(const SCEV *S) const {
assert((isa<SCEVCouldNotCompute>(S) || !S->getType()->isVectorTy()) &&
"no vector types expected from SCEVs");
@@ -95,21 +97,14 @@ struct specificscev_ty {
/// Match if we have a specific specified SCEV.
inline specificscev_ty m_Specific(const SCEV *S) { return S; }
-template <typename Class> struct cst_match {
- Class CV;
-
- cst_match(Class Op0) : CV(Op0) {}
-
- bool match(const SCEV *S) const {
- assert((isa<SCEVCouldNotCompute>(S) || !S->getType()->isVectorTy()) &&
- "no vector types expected from SCEVs");
- auto *C = dyn_cast<SCEVConstant>(S);
- return C && C->getAPInt() == CV;
- }
+struct is_specific_cst {
+ uint64_t CV;
+ is_specific_cst(uint64_t C) : CV(C) {}
+ bool isValue(const APInt &C) const { return C == CV; }
};
/// Match an SCEV constant with a plain unsigned integer.
-inline cst_match<uint64_t> m_scev_SpecificInt(uint64_t V) { return V; }
+inline cst_pred_ty<is_specific_cst> m_scev_SpecificInt(uint64_t V) { return V; }
struct bind_cst_ty {
const APInt *&CR;
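
The last patch reshapes m_scev_SpecificInt on top of cst_pred_ty/is_specific_cst; call sites are unchanged. Assuming cst_pred_ty::match dispatches to the predicate's isValue(), which is what the is_specific_cst definition suggests (the rest of the match body is outside this hunk), further constant predicates could be written in the same style. A hypothetical sketch, not part of the patch:

  // Matches any SCEVConstant whose value is a power of two (assumes
  // cst_pred_ty forwards to isValue()).
  struct is_power_of_two {
    bool isValue(const APInt &C) const { return C.isPowerOf2(); }
  };
  inline cst_pred_ty<is_power_of_two> m_scev_PowerOf2() { return {}; }

  // Usage: match(S, m_scev_Mul(m_scev_PowerOf2(), m_SCEV()))
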