[llvm] 61d3ad9 - [SCEVPatternMatch] Introduce m_scev_AffineAddRec (#140377)
via llvm-commits
llvm-commits at lists.llvm.org
Mon May 19 04:02:10 PDT 2025
Author: Ramkumar Ramachandra
Date: 2025-05-19T12:02:07+01:00
New Revision: 61d3ad963c9a8764a20c01c26374000d9ba5975d
URL: https://github.com/llvm/llvm-project/commit/61d3ad963c9a8764a20c01c26374000d9ba5975d
DIFF: https://github.com/llvm/llvm-project/commit/61d3ad963c9a8764a20c01c26374000d9ba5975d.diff
LOG: [SCEVPatternMatch] Introduce m_scev_AffineAddRec (#140377)
Introduce m_scev_AffineAddRec to match affine AddRecs, a class_match for
SCEVConstant, and demonstrate their utility in LSR and SCEV. While at
it, rename m_Specific to m_scev_Specific for clarity.
Added:
Modified:
llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
llvm/lib/Analysis/ScalarEvolution.cpp
llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
index 5e53a4f502153..cfb1b4c6ea6b4 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
@@ -61,6 +61,9 @@ template <typename Class> struct class_match {
};
inline class_match<const SCEV> m_SCEV() { return class_match<const SCEV>(); }
+inline class_match<const SCEVConstant> m_SCEVConstant() {
+ return class_match<const SCEVConstant>();
+}
template <typename Class> struct bind_ty {
Class *&VR;
@@ -95,7 +98,7 @@ struct specificscev_ty {
};
/// Match if we have a specific specified SCEV.
-inline specificscev_ty m_Specific(const SCEV *S) { return S; }
+inline specificscev_ty m_scev_Specific(const SCEV *S) { return S; }
struct is_specific_cst {
uint64_t CV;
@@ -192,6 +195,12 @@ inline SCEVBinaryExpr_match<SCEVUDivExpr, Op0_t, Op1_t>
m_scev_UDiv(const Op0_t &Op0, const Op1_t &Op1) {
return m_scev_Binary<SCEVUDivExpr>(Op0, Op1);
}
+
+template <typename Op0_t, typename Op1_t>
+inline SCEVBinaryExpr_match<SCEVAddRecExpr, Op0_t, Op1_t>
+m_scev_AffineAddRec(const Op0_t &Op0, const Op1_t &Op1) {
+ return m_scev_Binary<SCEVAddRecExpr>(Op0, Op1);
+}
} // namespace SCEVPatternMatch
} // namespace llvm
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 69714a112310e..342c8e39e2d3c 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -12480,26 +12480,21 @@ static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE,
if (!ICmpInst::isRelational(Pred))
return false;
- const SCEVAddRecExpr *LAR = dyn_cast<SCEVAddRecExpr>(LHS);
- if (!LAR)
- return false;
- const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
- if (!RAR)
+ const SCEV *LStart, *RStart, *Step;
+ if (!match(LHS, m_scev_AffineAddRec(m_SCEV(LStart), m_SCEV(Step))) ||
+ !match(RHS, m_scev_AffineAddRec(m_SCEV(RStart), m_scev_Specific(Step))))
return false;
+ const SCEVAddRecExpr *LAR = cast<SCEVAddRecExpr>(LHS);
+ const SCEVAddRecExpr *RAR = cast<SCEVAddRecExpr>(RHS);
if (LAR->getLoop() != RAR->getLoop())
return false;
- if (!LAR->isAffine() || !RAR->isAffine())
- return false;
-
- if (LAR->getStepRecurrence(SE) != RAR->getStepRecurrence(SE))
- return false;
SCEV::NoWrapFlags NW = ICmpInst::isSigned(Pred) ?
SCEV::FlagNSW : SCEV::FlagNUW;
if (!LAR->getNoWrapFlags(NW) || !RAR->getNoWrapFlags(NW))
return false;
- return SE.isKnownPredicate(Pred, LAR->getStart(), RAR->getStart());
+ return SE.isKnownPredicate(Pred, LStart, RStart);
}
/// Is LHS `Pred` RHS true on the virtue of LHS or RHS being a Min or Max
@@ -12716,7 +12711,7 @@ static bool isKnownPredicateExtendIdiom(CmpPredicate Pred, const SCEV *LHS,
case ICmpInst::ICMP_SLE: {
// If operand >=s 0 then ZExt == SExt. If operand <s 0 then SExt <s ZExt.
return match(LHS, m_scev_SExt(m_SCEV(Op))) &&
- match(RHS, m_scev_ZExt(m_Specific(Op)));
+ match(RHS, m_scev_ZExt(m_scev_Specific(Op)));
}
case ICmpInst::ICMP_UGE:
std::swap(LHS, RHS);
@@ -12724,7 +12719,7 @@ static bool isKnownPredicateExtendIdiom(CmpPredicate Pred, const SCEV *LHS,
case ICmpInst::ICMP_ULE: {
// If operand >=u 0 then ZExt == SExt. If operand <u 0 then ZExt <u SExt.
return match(LHS, m_scev_ZExt(m_SCEV(Op))) &&
- match(RHS, m_scev_SExt(m_Specific(Op)));
+ match(RHS, m_scev_SExt(m_scev_Specific(Op)));
}
default:
return false;
diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
index 33d6c77f61cfd..bdab14ed34c54 100644
--- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -77,11 +77,11 @@
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/ScalarEvolutionNormalization.h"
+#include "llvm/Analysis/ScalarEvolutionPatternMatch.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/BinaryFormat/Dwarf.h"
-#include "llvm/Config/llvm-config.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
@@ -128,6 +128,7 @@
#include <utility>
using namespace llvm;
+using namespace SCEVPatternMatch;
#define DEBUG_TYPE "loop-reduce"
@@ -556,16 +557,17 @@ static void DoInitialMatch(const SCEV *S, Loop *L,
}
// Look at addrec operands.
- if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S))
- if (!AR->getStart()->isZero() && AR->isAffine()) {
- DoInitialMatch(AR->getStart(), L, Good, Bad, SE);
- DoInitialMatch(SE.getAddRecExpr(SE.getConstant(AR->getType(), 0),
- AR->getStepRecurrence(SE),
- // FIXME: AR->getNoWrapFlags()
- AR->getLoop(), SCEV::FlagAnyWrap),
- L, Good, Bad, SE);
- return;
- }
+ const SCEV *Start, *Step;
+ if (match(S, m_scev_AffineAddRec(m_SCEV(Start), m_SCEV(Step))) &&
+ !Start->isZero()) {
+ DoInitialMatch(Start, L, Good, Bad, SE);
+ DoInitialMatch(SE.getAddRecExpr(SE.getConstant(S->getType(), 0), Step,
+ // FIXME: AR->getNoWrapFlags()
+ cast<SCEVAddRecExpr>(S)->getLoop(),
+ SCEV::FlagAnyWrap),
+ L, Good, Bad, SE);
+ return;
+ }
// Handle a multiplication by -1 (negation) if it didn't fold.
if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S))
@@ -1411,22 +1413,16 @@ void Cost::RateRegister(const Formula &F, const SCEV *Reg,
unsigned LoopCost = 1;
if (TTI->isIndexedLoadLegal(TTI->MIM_PostInc, AR->getType()) ||
TTI->isIndexedStoreLegal(TTI->MIM_PostInc, AR->getType())) {
-
- // If the step size matches the base offset, we could use pre-indexed
- // addressing.
- if (AMK == TTI::AMK_PreIndexed && F.BaseOffset.isFixed()) {
- if (auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*SE)))
- if (Step->getAPInt() == F.BaseOffset.getFixedValue())
- LoopCost = 0;
- } else if (AMK == TTI::AMK_PostIndexed) {
- const SCEV *LoopStep = AR->getStepRecurrence(*SE);
- if (isa<SCEVConstant>(LoopStep)) {
- const SCEV *LoopStart = AR->getStart();
- if (!isa<SCEVConstant>(LoopStart) &&
- SE->isLoopInvariant(LoopStart, L))
- LoopCost = 0;
- }
- }
+ const SCEV *Start;
+ const SCEVConstant *Step;
+ if (match(AR, m_scev_AffineAddRec(m_SCEV(Start), m_SCEVConstant(Step))))
+ // If the step size matches the base offset, we could use pre-indexed
+ // addressing.
+ if ((AMK == TTI::AMK_PreIndexed && F.BaseOffset.isFixed() &&
+ Step->getAPInt() == F.BaseOffset.getFixedValue()) ||
+ (AMK == TTI::AMK_PostIndexed && !isa<SCEVConstant>(Start) &&
+ SE->isLoopInvariant(Start, L)))
+ LoopCost = 0;
}
C.AddRecCost += LoopCost;
@@ -2519,13 +2515,11 @@ ICmpInst *LSRInstance::OptimizeMax(ICmpInst *Cond, IVStrideUse* &CondUse) {
// Check the relevant induction variable for conformance to
// the pattern.
const SCEV *IV = SE.getSCEV(Cond->getOperand(0));
- const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(IV);
- if (!AR || !AR->isAffine() ||
- AR->getStart() != One ||
- AR->getStepRecurrence(SE) != One)
+ if (!match(IV,
+ m_scev_AffineAddRec(m_scev_SpecificInt(1), m_scev_SpecificInt(1))))
return Cond;
- assert(AR->getLoop() == L &&
+ assert(cast<SCEVAddRecExpr>(IV)->getLoop() == L &&
"Loop condition operand is an addrec in a different loop!");
// Check the right operand of the select, and remember it, as it will
@@ -3320,7 +3314,7 @@ void LSRInstance::CollectChains() {
void LSRInstance::FinalizeChain(IVChain &Chain) {
assert(!Chain.Incs.empty() && "empty IV chains are not allowed");
LLVM_DEBUG(dbgs() << "Final Chain: " << *Chain.Incs[0].UserInst << "\n");
-
+
for (const IVInc &Inc : Chain) {
LLVM_DEBUG(dbgs() << " Inc: " << *Inc.UserInst << "\n");
auto UseI = find(Inc.UserInst->operands(), Inc.IVOperand);
@@ -3823,26 +3817,27 @@ static const SCEV *CollectSubexprs(const SCEV *S, const SCEVConstant *C,
Ops.push_back(C ? SE.getMulExpr(C, Remainder) : Remainder);
}
return nullptr;
- } else if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S)) {
+ }
+ const SCEV *Start, *Step;
+ if (match(S, m_scev_AffineAddRec(m_SCEV(Start), m_SCEV(Step)))) {
// Split a non-zero base out of an addrec.
- if (AR->getStart()->isZero() || !AR->isAffine())
+ if (Start->isZero())
return S;
- const SCEV *Remainder = CollectSubexprs(AR->getStart(),
- C, Ops, L, SE, Depth+1);
+ const SCEV *Remainder = CollectSubexprs(Start, C, Ops, L, SE, Depth + 1);
// Split the non-zero AddRec unless it is part of a nested recurrence that
// does not pertain to this loop.
- if (Remainder && (AR->getLoop() == L || !isa<SCEVAddRecExpr>(Remainder))) {
+ if (Remainder && (cast<SCEVAddRecExpr>(S)->getLoop() == L ||
+ !isa<SCEVAddRecExpr>(Remainder))) {
Ops.push_back(C ? SE.getMulExpr(C, Remainder) : Remainder);
Remainder = nullptr;
}
- if (Remainder != AR->getStart()) {
+ if (Remainder != Start) {
if (!Remainder)
- Remainder = SE.getConstant(AR->getType(), 0);
- return SE.getAddRecExpr(Remainder,
- AR->getStepRecurrence(SE),
- AR->getLoop(),
- //FIXME: AR->getNoWrapFlags(SCEV::FlagNW)
+ Remainder = SE.getConstant(S->getType(), 0);
+ return SE.getAddRecExpr(Remainder, Step,
+ cast<SCEVAddRecExpr>(S)->getLoop(),
+ // FIXME: AR->getNoWrapFlags(SCEV::FlagNW)
SCEV::FlagAnyWrap);
}
} else if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) {
@@ -3870,17 +3865,13 @@ static bool mayUsePostIncMode(const TargetTransformInfo &TTI,
if (LU.Kind != LSRUse::Address ||
!LU.AccessTy.getType()->isIntOrIntVectorTy())
return false;
- const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S);
- if (!AR)
- return false;
- const SCEV *LoopStep = AR->getStepRecurrence(SE);
- if (!isa<SCEVConstant>(LoopStep))
+ const SCEV *Start;
+ if (!match(S, m_scev_AffineAddRec(m_SCEV(Start), m_SCEVConstant())))
return false;
// Check if a post-indexed load/store can be used.
- if (TTI.isIndexedLoadLegal(TTI.MIM_PostInc, AR->getType()) ||
- TTI.isIndexedStoreLegal(TTI.MIM_PostInc, AR->getType())) {
- const SCEV *LoopStart = AR->getStart();
- if (!isa<SCEVConstant>(LoopStart) && SE.isLoopInvariant(LoopStart, L))
+ if (TTI.isIndexedLoadLegal(TTI.MIM_PostInc, S->getType()) ||
+ TTI.isIndexedStoreLegal(TTI.MIM_PostInc, S->getType())) {
+ if (!isa<SCEVConstant>(Start) && SE.isLoopInvariant(Start, L))
return true;
}
return false;
@@ -4139,18 +4130,15 @@ void LSRInstance::GenerateConstantOffsetsImpl(
// base pointer for each iteration of the loop, resulting in no extra add/sub
// instructions for pointer updating.
if (AMK == TTI::AMK_PreIndexed && LU.Kind == LSRUse::Address) {
- if (auto *GAR = dyn_cast<SCEVAddRecExpr>(G)) {
- if (auto *StepRec =
- dyn_cast<SCEVConstant>(GAR->getStepRecurrence(SE))) {
- const APInt &StepInt = StepRec->getAPInt();
- int64_t Step = StepInt.isNegative() ?
- StepInt.getSExtValue() : StepInt.getZExtValue();
-
- for (Immediate Offset : Worklist) {
- if (Offset.isFixed()) {
- Offset = Immediate::getFixed(Offset.getFixedValue() - Step);
- GenerateOffset(G, Offset);
- }
+ const APInt *StepInt;
+ if (match(G, m_scev_AffineAddRec(m_SCEV(), m_scev_APInt(StepInt)))) {
+ int64_t Step = StepInt->isNegative() ? StepInt->getSExtValue()
+ : StepInt->getZExtValue();
+
+ for (Immediate Offset : Worklist) {
+ if (Offset.isFixed()) {
+ Offset = Immediate::getFixed(Offset.getFixedValue() - Step);
+ GenerateOffset(G, Offset);
}
}
}
@@ -6621,7 +6609,7 @@ struct SCEVDbgValueBuilder {
if (Op.getOp() != dwarf::DW_OP_LLVM_arg) {
Op.appendToVector(DestExpr);
continue;
- }
+ }
DestExpr.push_back(dwarf::DW_OP_LLVM_arg);
// `DW_OP_LLVM_arg n` represents the nth LocationOp in this SCEV,
More information about the llvm-commits
mailing list