[llvm] 4589911 - [SCEVPatternMatch] Extend with more matchers (#138836)
via llvm-commits
llvm-commits at lists.llvm.org
Fri May 9 01:20:18 PDT 2025
Author: Ramkumar Ramachandra
Date: 2025-05-09T09:20:14+01:00
New Revision: 458991197d252e28ce4720a0770ef7d183435eeb
URL: https://github.com/llvm/llvm-project/commit/458991197d252e28ce4720a0770ef7d183435eeb
DIFF: https://github.com/llvm/llvm-project/commit/458991197d252e28ce4720a0770ef7d183435eeb.diff
LOG: [SCEVPatternMatch] Extend with more matchers (#138836)
Added:
Modified:
llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
llvm/lib/Analysis/LoopAccessAnalysis.cpp
llvm/lib/Analysis/ScalarEvolution.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
index 674147ca175ef..5e53a4f502153 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
@@ -23,6 +23,8 @@ template <typename Pattern> bool match(const SCEV *S, const Pattern &P) {
}
template <typename Predicate> struct cst_pred_ty : public Predicate {
+ cst_pred_ty() = default;
+ cst_pred_ty(uint64_t V) : Predicate(V) {}
bool match(const SCEV *S) const {
assert((isa<SCEVCouldNotCompute>(S) || !S->getType()->isVectorTy()) &&
"no vector types expected from SCEVs");
@@ -58,6 +60,8 @@ template <typename Class> struct class_match {
template <typename ITy> bool match(ITy *V) const { return isa<Class>(V); }
};
+inline class_match<const SCEV> m_SCEV() { return class_match<const SCEV>(); }
+
template <typename Class> struct bind_ty {
Class *&VR;
@@ -93,6 +97,34 @@ struct specificscev_ty {
/// Match if we have a specific specified SCEV.
inline specificscev_ty m_Specific(const SCEV *S) { return S; }
+struct is_specific_cst {
+ uint64_t CV;
+ is_specific_cst(uint64_t C) : CV(C) {}
+ bool isValue(const APInt &C) const { return C == CV; }
+};
+
+/// Match an SCEV constant with a plain unsigned integer.
+inline cst_pred_ty<is_specific_cst> m_scev_SpecificInt(uint64_t V) { return V; }
+
+struct bind_cst_ty {
+ const APInt *&CR;
+
+ bind_cst_ty(const APInt *&Op0) : CR(Op0) {}
+
+ bool match(const SCEV *S) const {
+ assert((isa<SCEVCouldNotCompute>(S) || !S->getType()->isVectorTy()) &&
+ "no vector types expected from SCEVs");
+ auto *C = dyn_cast<SCEVConstant>(S);
+ if (!C)
+ return false;
+ CR = &C->getAPInt();
+ return true;
+ }
+};
+
+/// Match an SCEV constant and bind it to an APInt.
+inline bind_cst_ty m_scev_APInt(const APInt *&C) { return C; }
+
/// Match a unary SCEV.
template <typename SCEVTy, typename Op0_t> struct SCEVUnaryExpr_match {
Op0_t Op0;
@@ -149,6 +181,17 @@ m_scev_Add(const Op0_t &Op0, const Op1_t &Op1) {
return m_scev_Binary<SCEVAddExpr>(Op0, Op1);
}
+template <typename Op0_t, typename Op1_t>
+inline SCEVBinaryExpr_match<SCEVMulExpr, Op0_t, Op1_t>
+m_scev_Mul(const Op0_t &Op0, const Op1_t &Op1) {
+ return m_scev_Binary<SCEVMulExpr>(Op0, Op1);
+}
+
+template <typename Op0_t, typename Op1_t>
+inline SCEVBinaryExpr_match<SCEVUDivExpr, Op0_t, Op1_t>
+m_scev_UDiv(const Op0_t &Op0, const Op1_t &Op1) {
+ return m_scev_Binary<SCEVUDivExpr>(Op0, Op1);
+}
} // namespace SCEVPatternMatch
} // namespace llvm
diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
index f222a9905c3bb..535b71cd5897e 100644
--- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp
+++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
@@ -30,6 +30,7 @@
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include "llvm/Analysis/ScalarEvolutionPatternMatch.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
@@ -65,6 +66,7 @@
#include <vector>
using namespace llvm;
+using namespace llvm::SCEVPatternMatch;
#define DEBUG_TYPE "loop-accesses"
@@ -811,8 +813,8 @@ getStrideFromAddRec(const SCEVAddRecExpr *AR, const Loop *Lp, Type *AccessTy,
const SCEV *Step = AR->getStepRecurrence(*PSE.getSE());
// Calculate the pointer stride and check if it is constant.
- const SCEVConstant *C = dyn_cast<SCEVConstant>(Step);
- if (!C) {
+ const APInt *APStepVal;
+ if (!match(Step, m_scev_APInt(APStepVal))) {
LLVM_DEBUG({
dbgs() << "LAA: Bad stride - Not a constant strided ";
if (Ptr)
@@ -825,13 +827,12 @@ getStrideFromAddRec(const SCEVAddRecExpr *AR, const Loop *Lp, Type *AccessTy,
const auto &DL = Lp->getHeader()->getDataLayout();
TypeSize AllocSize = DL.getTypeAllocSize(AccessTy);
int64_t Size = AllocSize.getFixedValue();
- const APInt &APStepVal = C->getAPInt();
// Huge step value - give up.
- if (APStepVal.getBitWidth() > 64)
+ if (APStepVal->getBitWidth() > 64)
return std::nullopt;
- int64_t StepVal = APStepVal.getSExtValue();
+ int64_t StepVal = APStepVal->getSExtValue();
// Strided access.
int64_t Stride = StepVal / Size;
@@ -2061,11 +2062,10 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx,
DL, SE, *(PSE.getSymbolicMaxBackedgeTakenCount()), *Dist, MaxStride))
return Dependence::NoDep;
- const SCEVConstant *ConstDist = dyn_cast<SCEVConstant>(Dist);
-
// Attempt to prove strided accesses independent.
- if (ConstDist) {
- uint64_t Distance = ConstDist->getAPInt().abs().getZExtValue();
+ const APInt *ConstDist = nullptr;
+ if (match(Dist, m_scev_APInt(ConstDist))) {
+ uint64_t Distance = ConstDist->abs().getZExtValue();
// If the distance between accesses and their strides are known constants,
// check whether the accesses interlace each other.
@@ -2111,9 +2111,8 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx,
FoundNonConstantDistanceDependence |= ShouldRetryWithRuntimeCheck;
return Dependence::Unknown;
}
- if (!HasSameSize ||
- couldPreventStoreLoadForward(
- ConstDist->getAPInt().abs().getZExtValue(), TypeByteSize)) {
+ if (!HasSameSize || couldPreventStoreLoadForward(
+ ConstDist->abs().getZExtValue(), TypeByteSize)) {
LLVM_DEBUG(
dbgs() << "LAA: Forward but may prevent st->ld forwarding\n");
return Dependence::ForwardButPreventsForwarding;
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index ac69ad598a65a..3f9614254ae7a 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -7149,16 +7149,11 @@ ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
assert(SE.getTypeSizeInBits(S->getType()) == BitWidth &&
"Should be!");
- // Peel off a constant offset:
- if (auto *SA = dyn_cast<SCEVAddExpr>(S)) {
- // In the future we could consider being smarter here and handle
- // {Start+Step,+,Step} too.
- if (SA->getNumOperands() != 2 || !isa<SCEVConstant>(SA->getOperand(0)))
- return;
-
- Offset = cast<SCEVConstant>(SA->getOperand(0))->getAPInt();
- S = SA->getOperand(1);
- }
+ // Peel off a constant offset. In the future we could consider being
+ // smarter here and handle {Start+Step,+,Step} too.
+ const APInt *Off;
+ if (match(S, m_scev_Add(m_scev_APInt(Off), m_SCEV(S))))
+ Offset = *Off;
// Peel off a cast operation
if (auto *SCast = dyn_cast<SCEVIntegralCastExpr>(S)) {
@@ -7337,11 +7332,11 @@ bool ScalarEvolution::isGuaranteedNotToBePoison(const SCEV *Op) {
bool ScalarEvolution::isGuaranteedNotToCauseUB(const SCEV *Op) {
return !SCEVExprContains(Op, [this](const SCEV *S) {
- auto *UDiv = dyn_cast<SCEVUDivExpr>(S);
+ const SCEV *Op1;
+ bool M = match(S, m_scev_UDiv(m_SCEV(), m_SCEV(Op1)));
// The UDiv may be UB if the divisor is poison or zero. Unless the divisor
// is a non-zero constant, we have to assume the UDiv may be UB.
- return UDiv && (!isKnownNonZero(UDiv->getOperand(1)) ||
- !isGuaranteedNotToBePoison(UDiv->getOperand(1)));
+ return M && (!isKnownNonZero(Op1) || !isGuaranteedNotToBePoison(Op1));
});
}
More information about the llvm-commits mailing list