[llvm] [SelectionDAG] Introducing the SelectionDAG pattern matching framework (PR #78654)

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Tue Feb 13 08:35:25 PST 2024


================
@@ -0,0 +1,712 @@
+//==--------------- llvm/CodeGen/SDPatternMatch.h ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+/// \file
+/// Contains matchers for matching SelectionDAG nodes and values.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CODEGEN_SDPATTERNMATCH_H
+#define LLVM_CODEGEN_SDPATTERNMATCH_H
+
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/CodeGen/SelectionDAG.h"
+#include "llvm/CodeGen/SelectionDAGNodes.h"
+#include "llvm/CodeGen/TargetLowering.h"
+
+namespace llvm {
+namespace SDPatternMatch {
+
+/// A match context lets one pattern be reused with different semantics: for
+/// example `m_Opc(ISD::ADD)` matches plain ADD nodes by default, but can be
+/// made to match VP_ADD nodes when evaluated under a custom VPMatchContext.
+/// BasicMatchContext is the default context; it only carries the (possibly
+/// null) SelectionDAG and TargetLowering that patterns may consult.
+class BasicMatchContext {
+  const SelectionDAG *DAG = nullptr;
+  const TargetLowering *TLI = nullptr;
+
+public:
+  /// Build a context from \p DAG. When \p DAG is non-null, the
+  /// TargetLowering is taken from it; otherwise there is no TLI either.
+  explicit BasicMatchContext(const SelectionDAG *DAG) : DAG(DAG) {
+    if (DAG)
+      TLI = &DAG->getTargetLoweringInfo();
+  }
+
+  /// Build a context that only knows about \p TLI (no SelectionDAG).
+  explicit BasicMatchContext(const TargetLowering *TLI) : TLI(TLI) {}
+
+  // Every MatchContext must provide getDAG() and getTLI().
+
+  const SelectionDAG *getDAG() const { return DAG; }
+
+  /// Whenever a DAG is available, TLI was already derived from it at
+  /// construction time, so the member can be returned as-is.
+  const TargetLowering *getTLI() const { return TLI; }
+
+  // Optional trait function(s):
+
+  /// Return true if N effectively has opcode Opcode.
+  // bool match(SDValue N, unsigned Opcode)
+};
+
+/// Detection aliases for the MatchContext interface. Each alias is
+/// well-formed only when a `const` context object supports the named member
+/// call, and is intended for use with `is_detected`.
+template <typename CtxT>
+using ctx_has_get_dag = decltype(std::declval<const CtxT &>().getDAG());
+
+template <typename CtxT>
+using ctx_has_get_tli = decltype(std::declval<const CtxT &>().getTLI());
+
+/// Detects the optional `match(SDValue, unsigned Opcode)` trait hook.
+template <typename CtxT>
+using ctx_has_match = decltype(std::declval<const CtxT &>().match(
+    std::declval<SDValue>(), std::declval<unsigned>()));
+
+/// Match \p N against pattern \p P, evaluating \p P under the match context
+/// \p Ctx. The context type is required to provide getDAG() and getTLI().
+template <typename Pattern, typename MatchContext>
+[[nodiscard]] bool sd_context_match(SDValue N, const MatchContext &Ctx,
+                                    Pattern &&P) {
+  constexpr bool HasGetDAG = is_detected<ctx_has_get_dag, MatchContext>::value;
+  constexpr bool HasGetTLI = is_detected<ctx_has_get_tli, MatchContext>::value;
+  static_assert(HasGetDAG, "Match context has to implement getDAG().");
+  static_assert(HasGetTLI, "Match context has to implement getTLI().");
+  return P.match(Ctx, N);
+}
+
+/// Convenience overload: match result number 0 of node \p N.
+template <typename Pattern, typename MatchContext>
+[[nodiscard]] bool sd_context_match(SDNode *N, const MatchContext &Ctx,
+                                    Pattern &&P) {
+  const SDValue FirstResult(N, 0);
+  return sd_context_match(FirstResult, Ctx, P);
+}
+
+/// Match \p N (for an SDNode, its result 0) against \p P under a
+/// BasicMatchContext built from \p DAG.
+template <typename Pattern>
+[[nodiscard]] bool sd_match(SDNode *N, const SelectionDAG *DAG, Pattern &&P) {
+  const BasicMatchContext Ctx(DAG);
+  return sd_context_match(N, Ctx, P);
+}
+
+template <typename Pattern>
+[[nodiscard]] bool sd_match(SDValue N, const SelectionDAG *DAG, Pattern &&P) {
+  const BasicMatchContext Ctx(DAG);
+  return sd_context_match(N, Ctx, P);
+}
+
+/// DAG-less variants: the match context has neither a SelectionDAG nor a
+/// TargetLowering, so patterns consulting either will see null pointers.
+template <typename Pattern>
+[[nodiscard]] bool sd_match(SDNode *N, Pattern &&P) {
+  const SelectionDAG *NoDAG = nullptr;
+  return sd_match(N, NoDAG, P);
+}
+
+template <typename Pattern>
+[[nodiscard]] bool sd_match(SDValue N, Pattern &&P) {
+  const SelectionDAG *NoDAG = nullptr;
+  return sd_match(N, NoDAG, P);
+}
+
+// === Utilities ===
+/// Matches an SDValue. A default-constructed Value_match accepts any non-null
+/// value; one constructed from a non-null \p Match accepts only values equal
+/// to it. (Constructing from a null SDValue behaves like the default.)
+struct Value_match {
+  SDValue MatchVal;
+
+  Value_match() = default;
+
+  explicit Value_match(SDValue Match) : MatchVal(Match) {}
+
+  template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
+    // A requested value must match exactly. Falling through to the non-null
+    // check when MatchVal is set would make m_Specific accept every value.
+    if (MatchVal)
+      return MatchVal == N;
+    return N.getNode() != nullptr;
+  }
+};
+
+/// Match any valid SDValue.
+inline Value_match m_Value() { return Value_match(); }
+
+/// Match only values that compare equal to \p N.
+inline Value_match m_Specific(SDValue N) { return Value_match(N); }
+
+/// Matches a non-null value whose node has opcode \p Opcode. If the match
+/// context provides a `match(SDValue, unsigned)` hook, the opcode test is
+/// delegated to that hook instead, so contexts can reinterpret opcodes.
+struct Opcode_match {
+  unsigned Opcode;
+
+  explicit Opcode_match(unsigned Opc) : Opcode(Opc) {}
+
+  template <typename MatchContext>
+  bool match(const MatchContext &Ctx, SDValue N) {
+    if constexpr (is_detected<ctx_has_match, MatchContext>::value) {
+      // Context-specific opcode interpretation.
+      return N && Ctx.match(N, Opcode);
+    } else {
+      // Default: compare the node's opcode directly.
+      return N && N->getOpcode() == Opcode;
+    }
+  }
+};
+
+inline Opcode_match m_Opc(unsigned Opcode) { return Opcode_match(Opcode); }
+
+/// Matches values that satisfy the sub-pattern and whose particular result
+/// has exactly \p NumUses uses.
+template <unsigned NumUses, typename Pattern> struct NUses_match {
+  Pattern P;
+
+  explicit NUses_match(const Pattern &P) : P(P) {}
+
+  template <typename MatchContext>
+  bool match(const MatchContext &Ctx, SDValue N) {
+    if (!N)
+      return false;
+    // SDNode::hasNUsesOfValue is pretty expensive for nodes producing
+    // multiple results, so run the cheaper sub-pattern check first.
+    if (!P.match(Ctx, N))
+      return false;
+    return N->hasNUsesOfValue(NumUses, N.getResNo());
+  }
+};
+
+/// Match a value with exactly \p N uses that also satisfies \p P.
+template <unsigned N, typename Pattern>
+inline NUses_match<N, Pattern> m_NUses(const Pattern &P) {
+  return NUses_match<N, Pattern>(P);
+}
+/// Match any valid value with exactly \p N uses.
+template <unsigned N> inline NUses_match<N, Value_match> m_NUses() {
+  return m_NUses<N>(m_Value());
+}
+
+/// Single-use shorthands for m_NUses<1>.
+template <typename Pattern>
+inline NUses_match<1, Pattern> m_OneUse(const Pattern &P) {
+  return m_NUses<1>(P);
+}
+inline NUses_match<1, Value_match> m_OneUse() { return m_NUses<1>(); }
+
+/// Matches any non-null value and stores it into the referenced SDValue.
+struct Value_bind {
+  SDValue &BindVal;
+
+  explicit Value_bind(SDValue &N) : BindVal(N) {}
+
+  template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
+    if (!N)
+      return false;
+    BindVal = N;
+    return true;
+  }
+};
+
+/// Match any valid SDValue and capture it in \p N.
+inline Value_bind m_Value(SDValue &N) { return Value_bind(N); }
+
+/// Matches values accepted by both a TargetLowering-based predicate and the
+/// sub-pattern \p P. The predicate is stored by its concrete type rather than
+/// type-erased in a std::function, so calls can be inlined and no per-pattern
+/// type-erasure overhead is paid. The default template argument keeps the
+/// spelling `TLI_pred_match<Pattern>` (type-erased predicate) valid.
+template <typename Pattern,
+          typename PredFuncT =
+              std::function<bool(const TargetLowering &, SDValue)>>
+struct TLI_pred_match {
+  Pattern P;
+  PredFuncT PredFunc;
+
+  TLI_pred_match(PredFuncT Pred, const Pattern &P)
+      : P(P), PredFunc(std::move(Pred)) {}
+
+  template <typename MatchContext>
+  bool match(const MatchContext &Ctx, SDValue N) {
+    // Without TargetLowering the predicate cannot be evaluated; treat that as
+    // a non-match instead of dereferencing a null pointer.
+    const TargetLowering *TLI = Ctx.getTLI();
+    return TLI && N && PredFunc(*TLI, N) && P.match(Ctx, N);
+  }
+};
+
+/// Match legal SDNodes based on the information provided by TargetLowering.
+template <typename Pattern> inline auto m_LegalOp(const Pattern &P) {
+  auto IsLegalOp = [](const TargetLowering &TLI, SDValue N) {
+    return TLI.isOperationLegal(N->getOpcode(), N.getValueType());
+  };
+  return TLI_pred_match<Pattern, decltype(IsLegalOp)>(IsLegalOp, P);
+}
+
+/// Switch to a different MatchContext for subsequent patterns.
+/// NOTE: only a reference to the new context is stored, so that context must
+/// outlive every match performed with this pattern.
+template <typename NewMatchContext, typename Pattern> struct SwitchContext {
+  const NewMatchContext &Ctx;
+  Pattern P;
+
+  template <typename OrigMatchContext>
+  bool match(const OrigMatchContext &, SDValue N) {
+    return P.match(Ctx, N);
+  }
+};
+
+/// Evaluate \p P under \p Ctx instead of the enclosing match context.
+/// The pattern is stored by value: moved from when \p P is an rvalue,
+/// copied when it is an lvalue.
+template <typename MatchContext, typename Pattern>
+inline SwitchContext<MatchContext, std::decay_t<Pattern>>
+m_Context(const MatchContext &Ctx, Pattern &&P) {
+  // Pattern is a forwarding reference: decay it so SwitchContext never holds
+  // a reference member, and forward (not move) so lvalue patterns compile.
+  return SwitchContext<MatchContext, std::decay_t<Pattern>>{
+      Ctx, std::forward<Pattern>(P)};
+}
+
+// === Value type ===
+/// Matches any non-null value and stores its value type into the referenced
+/// EVT.
+struct ValueType_bind {
+  EVT &BindVT;
+
+  explicit ValueType_bind(EVT &Bind) : BindVT(Bind) {}
+
+  template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
+    if (N) {
+      BindVT = N.getValueType();
+      return true;
+    }
+    return false;
+  }
+};
+
+/// Retrieve the ValueType of the current SDValue.
+inline ValueType_bind m_VT(EVT &VT) { return ValueType_bind(VT); }
+
+template <typename Pattern> struct ValueType_match {
+  std::function<bool(EVT)> PredFunc;
----------------
nikic wrote:

Don't use std::function, it is super expensive. As you can't use function_ref here, I guess you'll have to template over the predicate as well.

https://github.com/llvm/llvm-project/pull/78654


More information about the llvm-commits mailing list