[llvm] 217e0f3 - [SCEV] Add initial pattern matching for SCEV constants. (NFC) (#119389)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Dec 13 02:36:34 PST 2024
Author: Florian Hahn
Date: 2024-12-13T10:36:30Z
New Revision: 217e0f39710dec3348c996ecf98a76fd08b69853
URL: https://github.com/llvm/llvm-project/commit/217e0f39710dec3348c996ecf98a76fd08b69853
DIFF: https://github.com/llvm/llvm-project/commit/217e0f39710dec3348c996ecf98a76fd08b69853.diff
LOG: [SCEV] Add initial pattern matching for SCEV constants. (NFC) (#119389)
Add initial pattern matching for SCEV constants. Follow-up patches will
add additional matchers for various SCEV expressions.
This patch only converts a few instances to use the new matchers to make
sure everything builds as expected for now.
PR: https://github.com/llvm/llvm-project/pull/119389
Added:
llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
Modified:
llvm/lib/Analysis/ScalarEvolution.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
new file mode 100644
index 00000000000000..21d2ef3c867d7d
--- /dev/null
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
@@ -0,0 +1,67 @@
+//===----------------------------------------------------------------------===//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file provides a simple and efficient mechanism for performing general
+// tree-based pattern matches on SCEVs, based on LLVM's IR pattern matchers.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_ANALYSIS_SCALAREVOLUTIONPATTERNMATCH_H
+#define LLVM_ANALYSIS_SCALAREVOLUTIONPATTERNMATCH_H
+
+#include "llvm/Analysis/ScalarEvolutionExpressions.h"
+
+namespace llvm {
+namespace SCEVPatternMatch {
+
+/// Returns true if the SCEV \p S matches the pattern \p P.
+///
+/// Only the pattern type is a template parameter, so it is always deducible
+/// from the call; the matched value is always a `const SCEV *`.
+template <typename Pattern> bool match(const SCEV *S, const Pattern &P) {
+  return P.match(S);
+}
+
+/// Matches a SCEVConstant whose APInt value satisfies Predicate::isValue.
+/// Non-constant SCEVs never match.
+///
+/// match() and the predicates' isValue() are const so a matcher can be used
+/// through the `const Pattern &` taken by match() above and inside
+/// const-qualified combinators such as m_CombineOr.
+template <typename Predicate> struct cst_pred_ty : public Predicate {
+  bool match(const SCEV *S) const {
+    assert((isa<SCEVCouldNotCompute>(S) || !S->getType()->isVectorTy()) &&
+           "no vector types expected from SCEVs");
+    auto *C = dyn_cast<SCEVConstant>(S);
+    return C && this->isValue(C->getAPInt());
+  }
+};
+
+struct is_zero {
+  bool isValue(const APInt &C) const { return C.isZero(); }
+};
+/// Match an integer 0.
+inline cst_pred_ty<is_zero> m_scev_Zero() { return cst_pred_ty<is_zero>(); }
+
+struct is_one {
+  bool isValue(const APInt &C) const { return C.isOne(); }
+};
+/// Match an integer 1.
+inline cst_pred_ty<is_one> m_scev_One() { return cst_pred_ty<is_one>(); }
+
+struct is_all_ones {
+  bool isValue(const APInt &C) const { return C.isAllOnes(); }
+};
+/// Match an integer with all bits set.
+inline cst_pred_ty<is_all_ones> m_scev_AllOnes() {
+  return cst_pred_ty<is_all_ones>();
+}
+
+} // namespace SCEVPatternMatch
+} // namespace llvm
+
+#endif // LLVM_ANALYSIS_SCALAREVOLUTIONPATTERNMATCH_H
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index cad10486cbf3fa..e18133971f5bf0 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -79,6 +79,7 @@
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/MemoryBuiltins.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include "llvm/Analysis/ScalarEvolutionPatternMatch.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Config/llvm-config.h"
@@ -133,6 +134,7 @@
using namespace llvm;
using namespace PatternMatch;
+using namespace SCEVPatternMatch;
#define DEBUG_TYPE "scalar-evolution"
@@ -443,23 +445,11 @@ ArrayRef<const SCEV *> SCEV::operands() const {
llvm_unreachable("Unknown SCEV kind!");
}
-bool SCEV::isZero() const {
- if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
- return SC->getValue()->isZero();
- return false;
-}
+bool SCEV::isZero() const { return match(this, m_scev_Zero()); }
-bool SCEV::isOne() const {
- if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
- return SC->getValue()->isOne();
- return false;
-}
+bool SCEV::isOne() const { return match(this, m_scev_One()); }
-bool SCEV::isAllOnesValue() const {
- if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
- return SC->getValue()->isMinusOne();
- return false;
-}
+bool SCEV::isAllOnesValue() const { return match(this, m_scev_AllOnes()); }
bool SCEV::isNonConstantNegative() const {
const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(this);
@@ -3423,9 +3413,8 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
return S;
// 0 udiv Y == 0
- if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
- if (LHSC->getValue()->isZero())
- return LHS;
+ if (match(LHS, m_scev_Zero()))
+ return LHS;
if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
if (RHSC->getValue()->isOne())
@@ -10593,7 +10582,6 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
// Get the initial value for the loop.
const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
- const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
if (!isLoopInvariant(Step, L))
return getCouldNotCompute();
@@ -10615,8 +10603,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
// Handle unitary steps, which cannot wraparound.
// 1*N = -Start; -1*N = Start (mod 2^BW), so:
// N = Distance (as unsigned)
- if (StepC &&
- (StepC->getValue()->isOne() || StepC->getValue()->isMinusOne())) {
+
+ if (match(Step, m_CombineOr(m_scev_One(), m_scev_AllOnes()))) {
APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, Guards));
MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
@@ -10668,6 +10656,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
}
// Solve the general equation.
+ const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
if (!StepC || StepC->getValue()->isZero())
return getCouldNotCompute();
const SCEV *E = SolveLinEquationWithOverflow(
@@ -15510,9 +15499,7 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
// If we have LHS == 0, check if LHS is computing a property of some unknown
// SCEV %v which we can rewrite %v to express explicitly.
- const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS);
- if (Predicate == CmpInst::ICMP_EQ && RHSC &&
- RHSC->getValue()->isNullValue()) {
+ if (Predicate == CmpInst::ICMP_EQ && match(RHS, m_scev_Zero())) {
// If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
// explicitly express that.
const SCEV *URemLHS = nullptr;
@@ -15693,8 +15680,7 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
To = RHS;
break;
case CmpInst::ICMP_NE:
- if (isa<SCEVConstant>(RHS) &&
- cast<SCEVConstant>(RHS)->getValue()->isNullValue()) {
+ if (match(RHS, m_scev_Zero())) {
const SCEV *OneAlignedUp =
DividesBy ? GetNextSCEVDividesByDivisor(One, DividesBy) : One;
To = SE.getUMaxExpr(FromRewritten, OneAlignedUp);
More information about the llvm-commits
mailing list