[llvm] [SCEV] Add initial matchers for SCEV expressions. (NFC) (PR #119390)
Florian Hahn via llvm-commits
llvm-commits at lists.llvm.org
Tue Dec 10 07:12:11 PST 2024
https://github.com/fhahn created https://github.com/llvm/llvm-project/pull/119390
This patch adds initial matchers for SCEV expressions with an arbitrary
number of operands and specializes them for binary add expressions.
Also adds matchers for SCEVConstant and SCEVUnknown.
This patch only converts a few instances to use the new matchers to make
sure everything builds as expected for now.
The goal of the matchers is to hopefully make it slightly easier to write
code matching SCEV patterns.
Depends on https://github.com/llvm/llvm-project/pull/119389
>From a363bee42a7594f983fe742bcd314644ef9f7812 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Tue, 10 Dec 2024 14:28:31 +0000
Subject: [PATCH 1/2] [SCEV] Add initial pattern matching for SCEV constants.
(NFC)
Add initial pattern matching for SCEV constants. Follow-up patches will
add additional matchers for various SCEV expressions.
---
.../Analysis/ScalarEvolutionPatternMatch.h | 59 +++++++++++++++++++
llvm/lib/Analysis/ScalarEvolution.cpp | 13 ++--
2 files changed, 66 insertions(+), 6 deletions(-)
create mode 100644 llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
new file mode 100644
index 00000000000000..636b9f8e1544f7
--- /dev/null
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
@@ -0,0 +1,59 @@
+//===- ScalarEvolutionPatternMatch.h - Match on SCEVs -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file provides a simple and efficient mechanism for performing general
+// tree-based pattern matches on SCEVs, based on LLVM's IR pattern matchers.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_ANALYSIS_SCALAREVOLUTIONPATTERNMATCH_H
+#define LLVM_ANALYSIS_SCALAREVOLUTIONPATTERNMATCH_H
+
+#include "llvm/Analysis/ScalarEvolutionExpressions.h"
+
+namespace llvm {
+namespace SCEVPatternMatch {
+
+template <typename Val, typename Pattern>
+bool match(const SCEV *S, const Pattern &P) {
+ return P.match(S);
+}
+
+/// Match a specified integer value. \p BitWidth optionally specifies the
+/// bitwidth the matched constant must have. If it is 0, the matched constant
+/// can have any bitwidth.
+template <unsigned BitWidth = 0> struct specific_intval {
+ APInt Val;
+
+ specific_intval(APInt V) : Val(std::move(V)) {}
+
+ bool match(const SCEV *S) const {
+ const auto *C = dyn_cast<SCEVConstant>(S);
+ if (!C)
+ return false;
+
+ if (BitWidth != 0 && C->getAPInt().getBitWidth() != BitWidth)
+ return false;
+ return APInt::isSameValue(C->getAPInt(), Val);
+ }
+};
+
+inline specific_intval<0> m_scev_Zero() {
+ return specific_intval<0>(APInt(64, 0));
+}
+inline specific_intval<0> m_scev_One() {
+ return specific_intval<0>(APInt(64, 1));
+}
+inline specific_intval<0> m_scev_MinusOne() {
+ return specific_intval<0>(APInt(64, -1));
+}
+
+} // namespace SCEVPatternMatch
+} // namespace llvm
+
+#endif
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index cad10486cbf3fa..741431ac8aa158 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -79,6 +79,7 @@
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/MemoryBuiltins.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include "llvm/Analysis/ScalarEvolutionPatternMatch.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Config/llvm-config.h"
@@ -133,6 +134,7 @@
using namespace llvm;
using namespace PatternMatch;
+using namespace SCEVPatternMatch;
#define DEBUG_TYPE "scalar-evolution"
@@ -3423,9 +3425,8 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
return S;
// 0 udiv Y == 0
- if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
- if (LHSC->getValue()->isZero())
- return LHS;
+ if (match(LHS, m_scev_Zero()))
+ return LHS;
if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
if (RHSC->getValue()->isOne())
@@ -10593,7 +10594,6 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
// Get the initial value for the loop.
const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
- const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
if (!isLoopInvariant(Step, L))
return getCouldNotCompute();
@@ -10615,8 +10615,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
// Handle unitary steps, which cannot wraparound.
// 1*N = -Start; -1*N = Start (mod 2^BW), so:
// N = Distance (as unsigned)
- if (StepC &&
- (StepC->getValue()->isOne() || StepC->getValue()->isMinusOne())) {
+
+ if (match(Step, m_CombineOr(m_scev_One(), m_scev_MinusOne()))) {
APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, Guards));
MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
@@ -10668,6 +10668,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
}
// Solve the general equation.
+ const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
if (!StepC || StepC->getValue()->isZero())
return getCouldNotCompute();
const SCEV *E = SolveLinEquationWithOverflow(
>From ac5edf57ab3d1aa7d9bf2e72c6a656bcefa18e4a Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Tue, 10 Dec 2024 15:01:52 +0000
Subject: [PATCH 2/2] [SCEV] Add initial matchers for SCEV expressions. (NFC)
This patch adds initial matchers for SCEV expressions with an arbitrary
number of operands and specializes them for binary add expressions.
Also adds matchers for SCEVConstant and SCEVUnknown.
This patch only converts a few instances to use the new matchers to make
sure everything builds as expected for now.
Depends on https://github.com/llvm/llvm-project/pull/119389
---
.../Analysis/ScalarEvolutionPatternMatch.h | 81 +++++++++++++++++++
llvm/lib/Analysis/ScalarEvolution.cpp | 12 ++-
2 files changed, 86 insertions(+), 7 deletions(-)
diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
index 636b9f8e1544f7..1dff1a672fd225 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
@@ -53,6 +53,87 @@ inline specific_intval<0> m_scev_MinusOne() {
return specific_intval<0>(APInt(64, -1));
}
+template <typename Class> struct class_match {
+ template <typename ITy> bool match(ITy *V) const { return isa<Class>(V); }
+};
+
+template <typename Class> struct bind_ty {
+ Class *&VR;
+
+ bind_ty(Class *&V) : VR(V) {}
+
+ template <typename ITy> bool match(ITy *V) const {
+ if (auto *CV = dyn_cast<Class>(V)) {
+ VR = CV;
+ return true;
+ }
+ return false;
+ }
+};
+
+/// Match a SCEV, capturing it if we match.
+inline bind_ty<const SCEV> m_SCEV(const SCEV *&V) { return V; }
+inline bind_ty<const SCEVConstant> m_SCEVConstant(const SCEVConstant *&V) {
+ return V;
+}
+inline bind_ty<const SCEVUnknown> m_SCEVUnknown(const SCEVUnknown *&V) {
+ return V;
+}
+
+namespace detail {
+
+template <typename TupleTy, typename Fn, std::size_t... Is>
+bool CheckTupleElements(const TupleTy &Ops, Fn P, std::index_sequence<Is...>) {
+ return (P(std::get<Is>(Ops), Is) && ...);
+}
+
+/// Helper to check if predicate \p P holds on all tuple elements in \p Ops
+template <typename TupleTy, typename Fn>
+bool all_of_tuple_elements(const TupleTy &Ops, Fn P) {
+ return CheckTupleElements(
+ Ops, P, std::make_index_sequence<std::tuple_size<TupleTy>::value>{});
+}
+
+} // namespace detail
+
+template <typename Ops_t, typename SCEVTy> struct SCEV_match {
+ Ops_t Ops;
+
+ SCEV_match() : Ops() {
+ static_assert(std::tuple_size<Ops_t>::value == 0 &&
+ "constructor can only be used with zero operands");
+ }
+ SCEV_match(Ops_t Ops) : Ops(Ops) {}
+ template <typename A_t, typename B_t> SCEV_match(A_t A, B_t B) : Ops({A, B}) {
+ static_assert(std::tuple_size<Ops_t>::value == 2 &&
+ "constructor can only be used for binary matcher");
+ }
+
+ bool match(const SCEV *S) const {
+ auto *Cast = dyn_cast<SCEVTy>(S);
+ if (!Cast || Cast->getNumOperands() != std::tuple_size<Ops_t>::value)
+ return false;
+ return detail::all_of_tuple_elements(Ops, [Cast](auto Op, unsigned Idx) {
+ return Op.match(Cast->getOperand(Idx));
+ });
+ }
+};
+
+template <typename Op0_t, typename Op1_t, typename SCEVTy>
+using BinarySCEV_match = SCEV_match<std::tuple<Op0_t, Op1_t>, SCEVTy>;
+
+template <typename Op0_t, typename Op1_t, typename SCEVTy>
+inline BinarySCEV_match<Op0_t, Op1_t, SCEVTy> m_scev_Binary(const Op0_t &Op0,
+ const Op1_t &Op1) {
+ return BinarySCEV_match<Op0_t, Op1_t, SCEVTy>(Op0, Op1);
+}
+
+template <typename Op0_t, typename Op1_t>
+inline BinarySCEV_match<Op0_t, Op1_t, SCEVAddExpr>
+m_scev_Add(const Op0_t &Op0, const Op1_t &Op1) {
+ return BinarySCEV_match<Op0_t, Op1_t, SCEVAddExpr>(Op0, Op1);
+}
+
} // namespace SCEVPatternMatch
} // namespace llvm
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 741431ac8aa158..ed92335606ad90 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -15393,14 +15393,12 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
// (X >=u C1).
auto MatchRangeCheckIdiom = [&SE, Predicate, LHS, RHS, &RewriteMap,
&ExprsToRewrite]() {
- auto *AddExpr = dyn_cast<SCEVAddExpr>(LHS);
- if (!AddExpr || AddExpr->getNumOperands() != 2)
- return false;
-
- auto *C1 = dyn_cast<SCEVConstant>(AddExpr->getOperand(0));
- auto *LHSUnknown = dyn_cast<SCEVUnknown>(AddExpr->getOperand(1));
+ const SCEVConstant *C1;
+ const SCEVUnknown *LHSUnknown;
auto *C2 = dyn_cast<SCEVConstant>(RHS);
- if (!C1 || !C2 || !LHSUnknown)
+ if (!match(LHS,
+ m_scev_Add(m_SCEVConstant(C1), m_SCEVUnknown(LHSUnknown))) ||
+ !C2)
return false;
auto ExactRegion =
More information about the llvm-commits
mailing list