[llvm] [SelectionDAG] Introducing the SelectionDAG pattern matching framework (PR #78654)
Min-Yih Hsu via llvm-commits
llvm-commits at lists.llvm.org
Wed Feb 21 16:48:28 PST 2024
================
@@ -0,0 +1,731 @@
+//==--------------- llvm/CodeGen/SDPatternMatch.h ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+/// \file
+/// Contains matchers for matching SelectionDAG nodes and values.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CODEGEN_SDPATTERNMATCH_H
+#define LLVM_CODEGEN_SDPATTERNMATCH_H
+
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"
#include "llvm/CodeGen/TargetLowering.h"

#include <cassert>
#include <functional>
#include <type_traits>
#include <utility>
+
+namespace llvm {
+namespace SDPatternMatch {
+
+/// MatchContext can repurpose existing patterns to behave differently under
+/// a certain context. For instance, `m_Opc(ISD::ADD)` matches plain ADD nodes
+/// in normal circumstances, but matches VP_ADD nodes under a custom
+/// VPMatchContext. This design is meant to facilitate code / pattern reusing.
+class BasicMatchContext {
+ const SelectionDAG *DAG;
+ const TargetLowering *TLI;
+
+public:
+ explicit BasicMatchContext(const SelectionDAG *DAG)
+ : DAG(DAG), TLI(DAG ? &DAG->getTargetLoweringInfo() : nullptr) {}
+
+ explicit BasicMatchContext(const TargetLowering *TLI)
+ : DAG(nullptr), TLI(TLI) {}
+
+ // A valid MatchContext has to implement the following functions.
+
+ const SelectionDAG *getDAG() const { return DAG; }
+
+ const TargetLowering *getTLI() const { return TLI; }
+
+ // Optional trait function(s)
+
+ /// Return true if N effectively has opcode Opcode.
+ // bool match(SDValue N, unsigned Opcode)
+};
+
/// Detection-idiom helper: the type returned by `getDAG()` on a const
/// MatchContext. Substitution fails when the member does not exist.
template <typename MatchContext>
using ctx_has_get_dag = decltype(std::declval<MatchContext const &>().getDAG());

/// Detection-idiom helper: the type returned by `getTLI()` on a const
/// MatchContext. Substitution fails when the member does not exist.
template <typename MatchContext>
using ctx_has_get_tli = decltype(std::declval<MatchContext const &>().getTLI());
+
+template <typename MatchContext>
+using ctx_has_match = decltype(std::declval<const MatchContext &>().match(
+ std::declval<SDValue>(), std::declval<unsigned>()));
+
+template <typename Pattern, typename MatchContext>
+[[nodiscard]] bool sd_context_match(SDValue N, const MatchContext &Ctx,
+ Pattern &&P) {
+ static_assert(is_detected<ctx_has_get_dag, MatchContext>::value,
+ "Match context has to implement getDAG().");
+ static_assert(is_detected<ctx_has_get_tli, MatchContext>::value,
+ "Match context has to implement getTLI().");
+ return P.match(Ctx, N);
+}
+
+template <typename Pattern, typename MatchContext>
+[[nodiscard]] bool sd_context_match(SDNode *N, const MatchContext &Ctx,
+ Pattern &&P) {
+ return sd_context_match(SDValue(N, 0), Ctx, P);
+}
+
+template <typename Pattern>
+[[nodiscard]] bool sd_match(SDNode *N, const SelectionDAG *DAG, Pattern &&P) {
+ return sd_context_match(N, BasicMatchContext(DAG), P);
+}
+
+template <typename Pattern>
+[[nodiscard]] bool sd_match(SDValue N, const SelectionDAG *DAG, Pattern &&P) {
+ return sd_context_match(N, BasicMatchContext(DAG), P);
+}
+
+template <typename Pattern>
+[[nodiscard]] bool sd_match(SDNode *N, Pattern &&P) {
+ return sd_match(N, nullptr, P);
+}
+
+template <typename Pattern>
+[[nodiscard]] bool sd_match(SDValue N, Pattern &&P) {
+ return sd_match(N, nullptr, P);
+}
+
+// === Utilities ===
+struct Value_match {
+ SDValue MatchVal;
+
+ Value_match() = default;
+
+ explicit Value_match(SDValue Match) : MatchVal(Match) {}
+
+ template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
+ if (MatchVal)
+ return MatchVal == N;
+ return N.getNode();
+ }
+};
+
+/// Match any valid SDValue.
+inline Value_match m_Value() { return Value_match(); }
+
+inline Value_match m_Specific(SDValue N) {
+ assert(N);
+ return Value_match(N);
+}
+
+struct DeferredValue_match {
+ SDValue &MatchVal;
+
+ explicit DeferredValue_match(SDValue &Match) : MatchVal(Match) {}
+
+ template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
+ return N == MatchVal;
+ }
+};
+
+/// Similar to m_Specific, but the specific value to match is determined by
+/// another sub-pattern in the same sd_match() expression. For instance,
+/// We cannot match `(add V, V)` with `m_Add(m_Value(X), m_Specific(X))` since
+/// `X` is not initialized at the time it got copied into `m_Specific`. Instead,
+/// we should use `m_Add(m_Value(X), m_Deferred(X))`.
+inline DeferredValue_match m_Deferred(SDValue &V) {
+ return DeferredValue_match(V);
+}
+
+struct Opcode_match {
+ unsigned Opcode;
+
+ explicit Opcode_match(unsigned Opc) : Opcode(Opc) {}
+
+ template <typename MatchContext>
+ std::enable_if_t<is_detected<ctx_has_match, MatchContext>::value, bool>
+ match(const MatchContext &Ctx, SDValue N) {
+ return Ctx.match(N, Opcode);
+ }
+
+ // Default implementation.
+ template <typename MatchContext>
+ std::enable_if_t<!is_detected<ctx_has_match, MatchContext>::value, bool>
+ match(const MatchContext &, SDValue N) {
+ return N->getOpcode() == Opcode;
+ }
+};
+
+inline Opcode_match m_Opc(unsigned Opcode) { return Opcode_match(Opcode); }
+
+template <unsigned NumUses, typename Pattern> struct NUses_match {
+ Pattern P;
+
+ explicit NUses_match(const Pattern &P) : P(P) {}
+
+ template <typename MatchContext>
+ bool match(const MatchContext &Ctx, SDValue N) {
+ // SDNode::hasNUsesOfValue is pretty expensive when the SDNode produces
+ // multiple results, hence we check the subsequent pattern here before
+ // checking the number of value users.
+ return P.match(Ctx, N) && N->hasNUsesOfValue(NumUses, N.getResNo());
+ }
+};
+
+template <typename Pattern>
+inline NUses_match<1, Pattern> m_OneUse(const Pattern &P) {
+ return NUses_match<1, Pattern>(P);
+}
+template <unsigned N, typename Pattern>
+inline NUses_match<N, Pattern> m_NUses(const Pattern &P) {
+ return NUses_match<N, Pattern>(P);
+}
+
+inline NUses_match<1, Value_match> m_OneUse() {
+ return NUses_match<1, Value_match>(m_Value());
+}
+template <unsigned N> inline NUses_match<N, Value_match> m_NUses() {
+ return NUses_match<N, Value_match>(m_Value());
+}
+
+struct Value_bind {
+ SDValue &BindVal;
+
+ explicit Value_bind(SDValue &N) : BindVal(N) {}
+
+ template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
+ BindVal = N;
+ return true;
+ }
+};
+
+inline Value_bind m_Value(SDValue &N) { return Value_bind(N); }
+
+template <typename Pattern> struct TLI_pred_match {
+ Pattern P;
+ std::function<bool(const TargetLowering &, SDValue)> PredFunc;
+
+ TLI_pred_match(decltype(PredFunc) &&Pred, const Pattern &P)
+ : P(P), PredFunc(std::move(Pred)) {}
+
+ template <typename MatchContext>
+ bool match(const MatchContext &Ctx, SDValue N) {
+ assert(Ctx.getTLI() && "TargetLowering is required for this pattern.");
+ return PredFunc(*Ctx.getTLI(), N) && P.match(Ctx, N);
+ }
+};
+
+/// Match legal SDNodes based on the information provided by TargetLowering.
+template <typename Pattern>
+inline TLI_pred_match<Pattern> m_LegalOp(const Pattern &P) {
+ return TLI_pred_match<Pattern>(
+ [](const TargetLowering &TLI, SDValue N) {
+ return TLI.isOperationLegal(N->getOpcode(), N.getValueType());
+ },
+ P);
+}
+
+/// Switch to a different MatchContext for subsequent patterns.
+template <typename NewMatchContext, typename Pattern> struct SwitchContext {
+ const NewMatchContext &Ctx;
+ Pattern P;
+
+ template <typename OrigMatchContext>
+ bool match(const OrigMatchContext &, SDValue N) {
+ return P.match(Ctx, N);
+ }
+};
+
+template <typename MatchContext, typename Pattern>
+inline SwitchContext<MatchContext, Pattern> m_Context(const MatchContext &Ctx,
+ Pattern &&P) {
+ return SwitchContext<MatchContext, Pattern>{Ctx, std::move(P)};
+}
+
+// === Value type ===
+struct ValueType_bind {
+ EVT &BindVT;
+
+ explicit ValueType_bind(EVT &Bind) : BindVT(Bind) {}
+
+ template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
+ BindVT = N.getValueType();
+ return true;
+ }
+};
+
+/// Retreive the ValueType of the current SDValue.
+inline ValueType_bind m_VT(EVT &VT) { return ValueType_bind(VT); }
+
+template <typename Pattern, typename PredFuncT> struct ValueType_match {
+ PredFuncT PredFunc;
+ Pattern P;
+
+ ValueType_match(const PredFuncT &Pred, const Pattern &P)
+ : PredFunc(Pred), P(P) {}
+
+ template <typename MatchContext>
+ bool match(const MatchContext &Ctx, SDValue N) {
+ return PredFunc(N.getValueType()) && P.match(Ctx, N);
+ }
+};
+
+// Explicit deduction guide.
+template <typename PredFuncT, typename Pattern>
+ValueType_match(const PredFuncT &Pred, const Pattern &P)
+ -> ValueType_match<Pattern, PredFuncT>;
+
+/// Match a specific ValueType.
+template <typename Pattern>
+inline auto m_SpecificVT(EVT RefVT, const Pattern &P) {
+ return ValueType_match{[=](EVT VT) { return VT == RefVT; }, P};
+}
+inline auto m_SpecificVT(EVT RefVT) {
+ return ValueType_match{[=](EVT VT) { return VT == RefVT; }, m_Value()};
+}
+
+inline auto m_Glue() { return m_SpecificVT(MVT::Glue); }
+inline auto m_OtherVT() { return m_SpecificVT(MVT::Other); }
+
+/// Match any integer ValueTypes.
+template <typename Pattern> inline auto m_IntegerVT(const Pattern &P) {
+ return ValueType_match{[](EVT VT) { return VT.isInteger(); }, P};
+}
+inline auto m_IntegerVT() {
+ return ValueType_match{[](EVT VT) { return VT.isInteger(); }, m_Value()};
+}
+
+/// Match any floating point ValueTypes.
+template <typename Pattern> inline auto m_FloatingPointVT(const Pattern &P) {
+ return ValueType_match{[](EVT VT) { return VT.isFloatingPoint(); }, P};
+}
+inline auto m_FloatingPointVT() {
+ return ValueType_match{[](EVT VT) { return VT.isFloatingPoint(); },
+ m_Value()};
+}
+
+/// Match any vector ValueTypes.
+template <typename Pattern> inline auto m_VectorVT(const Pattern &P) {
+ return ValueType_match{[](EVT VT) { return VT.isVector(); }, P};
+}
+inline auto m_VectorVT() {
+ return ValueType_match{[](EVT VT) { return VT.isVector(); }, m_Value()};
+}
+
+/// Match fixed-length vector ValueTypes.
+template <typename Pattern> inline auto m_FixedVectorVT(const Pattern &P) {
+ return ValueType_match{[](EVT VT) { return VT.isFixedLengthVector(); }, P};
+}
+inline auto m_FixedVectorVT() {
+ return ValueType_match{[](EVT VT) { return VT.isFixedLengthVector(); },
+ m_Value()};
+}
+
+/// Match scalable vector ValueTypes.
+template <typename Pattern> inline auto m_ScalableVectorVT(const Pattern &P) {
+ return ValueType_match{[](EVT VT) { return VT.isScalableVector(); }, P};
+}
+inline auto m_ScalableVectorVT() {
+ return ValueType_match{[](EVT VT) { return VT.isScalableVector(); },
+ m_Value()};
+}
+
+/// Match legal ValueTypes based on the information provided by TargetLowering.
+template <typename Pattern>
+inline TLI_pred_match<Pattern> m_LegalType(const Pattern &P) {
+ return TLI_pred_match<Pattern>(
+ [](const TargetLowering &TLI, SDValue N) {
+ return TLI.isTypeLegal(N.getValueType());
+ },
+ P);
+}
+
+// === Patterns combinators ===
+template <typename... Preds> struct And {
+ template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
+ return true;
+ }
+};
+
+template <typename Pred, typename... Preds>
+struct And<Pred, Preds...> : And<Preds...> {
+ Pred P;
+ And(Pred &&p, Preds &&...preds)
+ : And<Preds...>(std::forward<Preds>(preds)...), P(std::forward<Pred>(p)) {
+ }
+
+ template <typename MatchContext>
+ bool match(const MatchContext &Ctx, SDValue N) {
+ return P.match(Ctx, N) && And<Preds...>::match(Ctx, N);
+ }
+};
+
+template <typename... Preds> struct Or {
+ template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
+ return false;
+ }
+};
+
+template <typename Pred, typename... Preds>
+struct Or<Pred, Preds...> : Or<Preds...> {
+ Pred P;
+ Or(Pred &&p, Preds &&...preds)
+ : Or<Preds...>(std::forward<Preds>(preds)...), P(std::forward<Pred>(p)) {}
+
+ template <typename MatchContext>
+ bool match(const MatchContext &Ctx, SDValue N) {
+ return P.match(Ctx, N) || Or<Preds...>::match(Ctx, N);
+ }
+};
+
+template <typename... Preds> And<Preds...> m_AllOf(Preds &&...preds) {
+ return And<Preds...>(std::forward<Preds>(preds)...);
+}
+
+template <typename... Preds> Or<Preds...> m_AnyOf(Preds &&...preds) {
+ return Or<Preds...>(std::forward<Preds>(preds)...);
+}
+
+// === Generic node matching ===
+template <typename... OpndPreds> struct Node_match {
+ unsigned Opcode;
+ unsigned OpIdx;
+
+ Node_match(unsigned Opc, unsigned OpIdx) : Opcode(Opc), OpIdx(OpIdx) {}
+
+ template <typename MatchContext>
+ bool match(const MatchContext &Ctx, SDValue N) {
+ if (OpIdx == 0) {
+ // Check opcode
+ if (!sd_context_match(N, Ctx, m_Opc(Opcode)))
+ return false;
+ }
+
+ // Returns false if there are more operands than predicates;
+ return N->getNumOperands() == OpIdx;
+ }
+};
+
+template <typename OpndPred, typename... OpndPreds>
+struct Node_match<OpndPred, OpndPreds...> : Node_match<OpndPreds...> {
+ unsigned Opcode;
+ unsigned OpIdx;
+ OpndPred P;
+
+ Node_match(unsigned Opc, unsigned OpIdx, OpndPred &&p, OpndPreds &&...preds)
+ : Node_match<OpndPreds...>(Opc, OpIdx + 1,
+ std::forward<OpndPreds>(preds)...),
+ Opcode(Opc), OpIdx(OpIdx), P(std::forward<OpndPred>(p)) {}
+
+ template <typename MatchContext>
+ bool match(const MatchContext &Ctx, SDValue N) {
+ if (OpIdx == 0) {
+ // Check opcode
+ if (!sd_context_match(N, Ctx, m_Opc(Opcode)))
+ return false;
+ }
+
+ if (OpIdx < N->getNumOperands())
+ return P.match(Ctx, N->getOperand(OpIdx)) &&
+ Node_match<OpndPreds...>::match(Ctx, N);
+
+ // This is the case where there are more predicates than operands.
+ return false;
+ }
+};
+
+template <typename... OpndPreds>
+Node_match<OpndPreds...> m_Node(unsigned Opcode, OpndPreds &&...preds) {
+ return Node_match<OpndPreds...>(Opcode, 0, std::forward<OpndPreds>(preds)...);
+}
+
+/// Provide number of operands that are not chain or glue, as well as the first
+/// index of such operand.
+template <bool ExcludeChain> struct EffectiveOperands {
+ unsigned Size = 0;
+ unsigned FirstIndex = 0;
+
+ explicit EffectiveOperands(SDValue N) {
+ const unsigned TotalNumOps = N->getNumOperands();
+ FirstIndex = TotalNumOps;
+ for (unsigned i = 0; i < TotalNumOps; ++i) {
----------------
mshockwave wrote:
Done
https://github.com/llvm/llvm-project/pull/78654
More information about the llvm-commits
mailing list