[llvm] [SelectionDAG] Introducing the SelectionDAG pattern matching framework (PR #78654)

Min-Yih Hsu via llvm-commits llvm-commits at lists.llvm.org
Wed Feb 21 16:48:19 PST 2024


https://github.com/mshockwave updated https://github.com/llvm/llvm-project/pull/78654

>From f22ad774f1bb49a338dcc6a9015d55c2e9069e56 Mon Sep 17 00:00:00 2001
From: Min Hsu <min.hsu at sifive.com>
Date: Thu, 18 Jan 2024 15:56:51 -0800
Subject: [PATCH 1/7] [SelectionDAG] Introducing SelectionDAG pattern matching
 framework

TBA...
---
 llvm/include/llvm/CodeGen/SDPatternMatch.h    | 612 ++++++++++++++++++
 llvm/unittests/CodeGen/CMakeLists.txt         |   1 +
 .../CodeGen/SelectionDAGPatternMatchTest.cpp  | 249 +++++++
 3 files changed, 862 insertions(+)
 create mode 100644 llvm/include/llvm/CodeGen/SDPatternMatch.h
 create mode 100644 llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp

diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
new file mode 100644
index 00000000000000..03393abb11a7e1
--- /dev/null
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -0,0 +1,612 @@
+//==--------------- llvm/CodeGen/SDPatternMatch.h ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+/// \file
+/// Contains matchers for matching SelectionDAG nodes and values.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CODEGEN_SDPATTERNMATCH_H
+#define LLVM_CODEGEN_SDPATTERNMATCH_H
+
+#include "llvm/ADT/APInt.h"
+#include "llvm/CodeGen/SelectionDAG.h"
+#include "llvm/CodeGen/SelectionDAGNodes.h"
+#include "llvm/CodeGen/TargetLowering.h"
+
+namespace llvm {
+namespace SDPatternMatch {
+
+template <typename Pattern>
+[[nodiscard]] bool sd_match(SDNode *N, const SelectionDAG *DAG, Pattern &&P) {
+  return P.match(DAG, SDValue(N, 0));
+}
+
+template <typename Pattern>
+[[nodiscard]] bool sd_match(SDValue N, const SelectionDAG *DAG, Pattern &&P) {
+  return P.match(DAG, N);
+}
+
+template <typename Pattern>
+[[nodiscard]] bool sd_match(SDNode *N, Pattern &&P) {
+  return P.match(nullptr, SDValue(N, 0));
+}
+
+template <typename Pattern>
+[[nodiscard]] bool sd_match(SDValue N, Pattern &&P) {
+  return P.match(nullptr, N);
+}
+
+// === Utilities ===
+/// Matches either one specific SDValue or, when none is given, any valid one.
+struct Value_match {
+  SDValue MatchVal; // Left empty by m_Value(): "match any non-empty value".
+  Value_match() = default;
+  explicit Value_match(SDValue Match) : MatchVal(Match) {}
+
+  bool match(const SelectionDAG *, SDValue N) {
+    // A specific value must compare equal; otherwise any non-empty N matches.
+    return MatchVal ? MatchVal == N : N.getNode() != nullptr;
+  }
+};
+
+/// Match any valid SDValue.
+inline Value_match m_Value() { return Value_match(); }
+
+inline Value_match m_Specific(SDValue N) { return Value_match(N); }
+
+/// Matches any non-empty value whose node carries the given opcode.
+struct Opcode_match {
+  unsigned Opcode;
+  explicit Opcode_match(unsigned Opc) : Opcode(Opc) {}
+
+  bool match(const SelectionDAG *, SDValue N) {
+    return N ? N->getOpcode() == Opcode : false;
+  }
+};
+
+inline Opcode_match m_Opc(unsigned Opcode) { return Opcode_match(Opcode); }
+
+template <unsigned NumUses, typename Pattern> struct NUses_match {
+  Pattern P;
+
+  explicit NUses_match(const Pattern &P) : P(P) {}
+
+  bool match(const SelectionDAG *DAG, SDValue N) {
+    return N && N->hasNUsesOfValue(NumUses, N.getResNo()) && P.match(DAG, N);
+  }
+};
+
+template <typename Pattern>
+inline NUses_match<1, Pattern> m_OneUse(const Pattern &P) {
+  return NUses_match<1, Pattern>(P);
+}
+template <unsigned N, typename Pattern>
+inline NUses_match<N, Pattern> m_NUses(const Pattern &P) {
+  return NUses_match<N, Pattern>(P);
+}
+
+inline NUses_match<1, Value_match> m_OneUse() {
+  return NUses_match<1, Value_match>(m_Value());
+}
+template <unsigned N> inline NUses_match<N, Value_match> m_NUses() {
+  return NUses_match<N, Value_match>(m_Value());
+}
+
+/// Captures the value being matched into a caller-owned SDValue.
+struct Value_bind {
+  SDValue &BindVal;
+  explicit Value_bind(SDValue &N) : BindVal(N) {}
+
+  bool match(const SelectionDAG *, SDValue N) {
+    // An empty value never matches and leaves BindVal untouched.
+    if (!N)
+      return false;
+    BindVal = N;
+    return true;
+  }
+};
+
+inline Value_bind m_Value(SDValue &N) { return Value_bind(N); }
+
+template <typename Pattern> struct TLI_pred_match {
+  Pattern P;
+  std::function<bool(const TargetLowering &, SDValue)> PredFunc;
+
+  TLI_pred_match(decltype(PredFunc) &&Pred, const Pattern &P)
+      : P(P), PredFunc(std::move(Pred)) {}
+
+  bool match(const SelectionDAG *DAG, SDValue N) {
+    return DAG && N && PredFunc(DAG->getTargetLoweringInfo(), N) &&
+           P.match(DAG, N);
+  }
+};
+
+/// Match legal SDNodes based on the information provided by TargetLowering.
+template <typename Pattern>
+inline TLI_pred_match<Pattern> m_LegalOp(const Pattern &P) {
+  return TLI_pred_match<Pattern>(
+      [](const TargetLowering &TLI, SDValue N) {
+        return TLI.isOperationLegal(N->getOpcode(), N.getValueType());
+      },
+      P);
+}
+
+// === Value type ===
+/// Captures the value type of the matched SDValue into a caller-owned EVT.
+struct ValueType_bind {
+  EVT &BindVT;
+  explicit ValueType_bind(EVT &Bind) : BindVT(Bind) {}
+
+  bool match(const SelectionDAG *, SDValue N) {
+    const bool Valid = static_cast<bool>(N);
+    if (Valid)
+      BindVT = N.getValueType();
+    return Valid;
+  }
+};
+
+/// Retrieve the ValueType of the current SDValue.
+inline ValueType_bind m_VT(EVT &VT) { return ValueType_bind(VT); }
+
+template <typename Pattern> struct ValueType_match {
+  std::function<bool(EVT)> PredFunc;
+  Pattern P;
+
+  ValueType_match(decltype(PredFunc) &&Pred, const Pattern &P)
+      : PredFunc(std::move(Pred)), P(P) {}
+
+  bool match(const SelectionDAG *DAG, SDValue N) {
+    return N && PredFunc(N.getValueType()) && P.match(DAG, N);
+  }
+};
+
+/// Match a specific ValueType.
+template <typename Pattern>
+inline ValueType_match<Pattern> m_SpecificVT(EVT RefVT, const Pattern &P) {
+  return ValueType_match<Pattern>([=](EVT VT) { return VT == RefVT; }, P);
+}
+inline ValueType_match<Value_match> m_SpecificVT(EVT RefVT) {
+  return ValueType_match<Value_match>([=](EVT VT) { return VT == RefVT; },
+                                      m_Value());
+}
+
+inline ValueType_match<Value_match> m_Glue() { return m_SpecificVT(MVT::Glue); }
+inline ValueType_match<Value_match> m_OtherVT() {
+  return m_SpecificVT(MVT::Other);
+}
+
+/// Match any integer ValueTypes.
+template <typename Pattern>
+inline ValueType_match<Pattern> m_IntegerVT(const Pattern &P) {
+  return ValueType_match<Pattern>([](EVT VT) { return VT.isInteger(); }, P);
+}
+inline ValueType_match<Value_match> m_IntegerVT() {
+  return ValueType_match<Value_match>([](EVT VT) { return VT.isInteger(); },
+                                      m_Value());
+}
+
+/// Match any floating point ValueTypes.
+template <typename Pattern>
+inline ValueType_match<Pattern> m_FloatingPointVT(const Pattern &P) {
+  return ValueType_match<Pattern>([](EVT VT) { return VT.isFloatingPoint(); },
+                                  P);
+}
+inline ValueType_match<Value_match> m_FloatingPointVT() {
+  return ValueType_match<Value_match>(
+      [](EVT VT) { return VT.isFloatingPoint(); }, m_Value());
+}
+
+/// Match any vector ValueTypes.
+template <typename Pattern>
+inline ValueType_match<Pattern> m_VectorVT(const Pattern &P) {
+  return ValueType_match<Pattern>([](EVT VT) { return VT.isVector(); }, P);
+}
+inline ValueType_match<Value_match> m_VectorVT() {
+  return ValueType_match<Value_match>([](EVT VT) { return VT.isVector(); },
+                                      m_Value());
+}
+
+/// Match fixed-length vector ValueTypes.
+template <typename Pattern>
+inline ValueType_match<Pattern> m_FixedVectorVT(const Pattern &P) {
+  return ValueType_match<Pattern>(
+      [](EVT VT) { return VT.isFixedLengthVector(); }, P);
+}
+inline ValueType_match<Value_match> m_FixedVectorVT() {
+  return ValueType_match<Value_match>(
+      [](EVT VT) { return VT.isFixedLengthVector(); }, m_Value());
+}
+
+/// Match scalable vector ValueTypes.
+template <typename Pattern>
+inline ValueType_match<Pattern> m_ScalableVectorVT(const Pattern &P) {
+  return ValueType_match<Pattern>([](EVT VT) { return VT.isScalableVector(); },
+                                  P);
+}
+inline ValueType_match<Value_match> m_ScalableVectorVT() {
+  return ValueType_match<Value_match>(
+      [](EVT VT) { return VT.isScalableVector(); }, m_Value());
+}
+
+/// Match legal ValueTypes based on the information provided by TargetLowering.
+template <typename Pattern>
+inline TLI_pred_match<Pattern> m_LegalType(const Pattern &P) {
+  return TLI_pred_match<Pattern>(
+      [](const TargetLowering &TLI, SDValue N) {
+        return TLI.isTypeLegal(N.getValueType());
+      },
+      P);
+}
+
+// === Patterns combinators ===
+template <typename... Preds> struct And {
+  bool match(const SelectionDAG *, SDValue N) { return true; }
+};
+
+template <typename Pred, typename... Preds>
+struct And<Pred, Preds...> : And<Preds...> {
+  Pred P;
+  And(Pred &&p, Preds &&...preds)
+      : And<Preds...>(std::forward<Preds>(preds)...), P(std::forward<Pred>(p)) {
+  }
+
+  bool match(const SelectionDAG *DAG, SDValue N) {
+    return P.match(DAG, N) && And<Preds...>::match(DAG, N);
+  }
+};
+
+template <typename... Preds> struct Or {
+  bool match(const SelectionDAG *, SDValue N) { return false; }
+};
+
+template <typename Pred, typename... Preds>
+struct Or<Pred, Preds...> : Or<Preds...> {
+  Pred P;
+  Or(Pred &&p, Preds &&...preds)
+      : Or<Preds...>(std::forward<Preds>(preds)...), P(std::forward<Pred>(p)) {}
+
+  bool match(const SelectionDAG *DAG, SDValue N) {
+    return P.match(DAG, N) || Or<Preds...>::match(DAG, N);
+  }
+};
+
+template <typename... Preds> And<Preds...> m_all_of(Preds &&...preds) {
+  return And<Preds...>(std::forward<Preds>(preds)...);
+}
+
+template <typename... Preds> Or<Preds...> m_any_of(Preds &&...preds) {
+  return Or<Preds...>(std::forward<Preds>(preds)...);
+}
+
+// === Generic node matching ===
+/// Terminal case of Node_match, reached after all operand predicates matched.
+template <typename... OpndPreds> struct Node_match {
+  unsigned Opcode;
+  unsigned OpIdx;
+
+  Node_match(unsigned Opc, unsigned OpIdx) : Opcode(Opc), OpIdx(OpIdx) {}
+
+  bool match(const SelectionDAG *, SDValue N) {
+    if (!N)
+      return false;
+
+    // Only the outermost matcher (OpIdx == 0) is responsible for the opcode;
+    // with zero operand predicates that is this terminal case itself.
+    if (OpIdx == 0 && N->getOpcode() != Opcode)
+      return false;
+
+    // Fail when the node has more operands than there were predicates.
+    return N->getNumOperands() == OpIdx;
+  }
+};
+
+/// Recursive case of Node_match: checks operand number OpIdx with P, then
+/// defers the remaining operands to the base class.
+template <typename OpndPred, typename... OpndPreds>
+struct Node_match<OpndPred, OpndPreds...> : Node_match<OpndPreds...> {
+  unsigned Opcode;
+  unsigned OpIdx;
+  OpndPred P;
+
+  Node_match(unsigned Opc, unsigned OpIdx, OpndPred &&p, OpndPreds &&...preds)
+      : Node_match<OpndPreds...>(Opc, OpIdx + 1,
+                                 std::forward<OpndPreds>(preds)...),
+        Opcode(Opc), OpIdx(OpIdx), P(std::forward<OpndPred>(p)) {}
+
+  bool match(const SelectionDAG *DAG, SDValue N) {
+    if (!N)
+      return false;
+
+    // Only the outermost matcher (OpIdx == 0) verifies the opcode.
+    if (OpIdx == 0 && N->getOpcode() != Opcode)
+      return false;
+
+    // More predicates than operands: nothing left for P to inspect.
+    if (OpIdx >= N->getNumOperands())
+      return false;
+
+    return P.match(DAG, N->getOperand(OpIdx)) &&
+           Node_match<OpndPreds...>::match(DAG, N);
+  }
+};
+
+template <typename... OpndPreds>
+Node_match<OpndPreds...> m_Node(unsigned Opcode, OpndPreds &&...preds) {
+  return Node_match<OpndPreds...>(Opcode, 0, std::forward<OpndPreds>(preds)...);
+}
+
+/// Summarizes the operands of a node that are neither chain nor glue: how
+/// many there are, and the index of the first one.
+struct EffectiveOperands {
+  unsigned Size = 0;
+  unsigned FirstIndex = 0;
+
+  explicit EffectiveOperands(SDValue N) {
+    const unsigned NumOps = N->getNumOperands();
+    // NumOps doubles as the "no effective operand seen yet" sentinel.
+    FirstIndex = NumOps;
+    for (unsigned Idx = 0; Idx != NumOps; ++Idx) {
+      // Chain (MVT::Other) and glue operands are ignored by default.
+      const EVT OpVT = N->getOperand(Idx).getValueType();
+      if (OpVT == MVT::Glue || OpVT == MVT::Other)
+        continue;
+      if (FirstIndex == NumOps)
+        FirstIndex = Idx;
+      ++Size;
+    }
+  }
+};
+
+// === Binary operations ===
+/// Matches a node with the given opcode and exactly two effective (non-chain,
+/// non-glue) operands; Commutable also tries the operands swapped.
+template <typename LHS_P, typename RHS_P, bool Commutable = false>
+struct BinaryOpc_match {
+  unsigned Opcode;
+  LHS_P LHS;
+  RHS_P RHS;
+
+  BinaryOpc_match(unsigned Opc, const LHS_P &L, const RHS_P &R)
+      : Opcode(Opc), LHS(L), RHS(R) {}
+
+  bool match(const SelectionDAG *DAG, SDValue N) {
+    if (!N || N->getOpcode() != Opcode)
+      return false;
+
+    const EffectiveOperands EO(N);
+    if (EO.Size != 2)
+      return false;
+
+    const SDValue Op0 = N->getOperand(EO.FirstIndex);
+    const SDValue Op1 = N->getOperand(EO.FirstIndex + 1);
+    if (LHS.match(DAG, Op0) && RHS.match(DAG, Op1))
+      return true;
+    return Commutable && LHS.match(DAG, Op1) && RHS.match(DAG, Op0);
+  }
+};
+
+template <typename LHS, typename RHS>
+inline BinaryOpc_match<LHS, RHS, false> m_BinOp(unsigned Opc, const LHS &L,
+                                                const RHS &R) {
+  return BinaryOpc_match<LHS, RHS, false>(Opc, L, R);
+}
+template <typename LHS, typename RHS>
+inline BinaryOpc_match<LHS, RHS, true> m_c_BinOp(unsigned Opc, const LHS &L,
+                                                 const RHS &R) {
+  return BinaryOpc_match<LHS, RHS, true>(Opc, L, R);
+}
+
+// Common binary operations
+template <typename LHS, typename RHS>
+inline BinaryOpc_match<LHS, RHS, true> m_Add(const LHS &L, const RHS &R) {
+  return BinaryOpc_match<LHS, RHS, true>(ISD::ADD, L, R);
+}
+
+template <typename LHS, typename RHS>
+inline BinaryOpc_match<LHS, RHS, false> m_Sub(const LHS &L, const RHS &R) {
+  return BinaryOpc_match<LHS, RHS, false>(ISD::SUB, L, R);
+}
+
+template <typename LHS, typename RHS>
+inline BinaryOpc_match<LHS, RHS, true> m_Mul(const LHS &L, const RHS &R) {
+  return BinaryOpc_match<LHS, RHS, true>(ISD::MUL, L, R);
+}
+
+template <typename LHS, typename RHS>
+inline BinaryOpc_match<LHS, RHS, false> m_UDiv(const LHS &L, const RHS &R) {
+  return BinaryOpc_match<LHS, RHS, false>(ISD::UDIV, L, R);
+}
+template <typename LHS, typename RHS>
+inline BinaryOpc_match<LHS, RHS, false> m_SDiv(const LHS &L, const RHS &R) {
+  return BinaryOpc_match<LHS, RHS, false>(ISD::SDIV, L, R);
+}
+
+template <typename LHS, typename RHS>
+inline BinaryOpc_match<LHS, RHS, false> m_URem(const LHS &L, const RHS &R) {
+  return BinaryOpc_match<LHS, RHS, false>(ISD::UREM, L, R);
+}
+template <typename LHS, typename RHS>
+inline BinaryOpc_match<LHS, RHS, false> m_SRem(const LHS &L, const RHS &R) {
+  return BinaryOpc_match<LHS, RHS, false>(ISD::SREM, L, R);
+}
+
+template <typename LHS, typename RHS>
+inline BinaryOpc_match<LHS, RHS, false> m_Shl(const LHS &L, const RHS &R) {
+  return BinaryOpc_match<LHS, RHS, false>(ISD::SHL, L, R);
+}
+
+template <typename LHS, typename RHS>
+inline BinaryOpc_match<LHS, RHS, false> m_Sra(const LHS &L, const RHS &R) {
+  return BinaryOpc_match<LHS, RHS, false>(ISD::SRA, L, R);
+}
+template <typename LHS, typename RHS>
+inline BinaryOpc_match<LHS, RHS, false> m_Srl(const LHS &L, const RHS &R) {
+  return BinaryOpc_match<LHS, RHS, false>(ISD::SRL, L, R);
+}
+
+template <typename LHS, typename RHS>
+inline BinaryOpc_match<LHS, RHS, true> m_FAdd(const LHS &L, const RHS &R) {
+  return BinaryOpc_match<LHS, RHS, true>(ISD::FADD, L, R);
+}
+
+template <typename LHS, typename RHS>
+inline BinaryOpc_match<LHS, RHS, false> m_FSub(const LHS &L, const RHS &R) {
+  return BinaryOpc_match<LHS, RHS, false>(ISD::FSUB, L, R);
+}
+
+template <typename LHS, typename RHS>
+inline BinaryOpc_match<LHS, RHS, true> m_FMul(const LHS &L, const RHS &R) {
+  return BinaryOpc_match<LHS, RHS, true>(ISD::FMUL, L, R);
+}
+
+template <typename LHS, typename RHS>
+inline BinaryOpc_match<LHS, RHS, false> m_FDiv(const LHS &L, const RHS &R) {
+  return BinaryOpc_match<LHS, RHS, false>(ISD::FDIV, L, R);
+}
+
+template <typename LHS, typename RHS>
+inline BinaryOpc_match<LHS, RHS, false> m_FRem(const LHS &L, const RHS &R) {
+  return BinaryOpc_match<LHS, RHS, false>(ISD::FREM, L, R);
+}
+
+// === Unary operations ===
+/// Matches a node with the given opcode and exactly one effective (non-chain,
+/// non-glue) operand, which must in turn satisfy Opnd.
+template <typename Opnd_P> struct UnaryOpc_match {
+  unsigned Opcode;
+  Opnd_P Opnd;
+
+  UnaryOpc_match(unsigned Opc, const Opnd_P &Op) : Opcode(Opc), Opnd(Op) {}
+
+  bool match(const SelectionDAG *DAG, SDValue N) {
+    if (!N || N->getOpcode() != Opcode)
+      return false;
+
+    const EffectiveOperands EO(N);
+    if (EO.Size != 1)
+      return false;
+
+    return Opnd.match(DAG, N->getOperand(EO.FirstIndex));
+  }
+};
+
+template <typename Opnd>
+inline UnaryOpc_match<Opnd> m_UnaryOp(unsigned Opc, const Opnd &Op) {
+  return UnaryOpc_match<Opnd>(Opc, Op);
+}
+
+template <typename Opnd> inline UnaryOpc_match<Opnd> m_ZExt(const Opnd &Op) {
+  return UnaryOpc_match<Opnd>(ISD::ZERO_EXTEND, Op);
+}
+
+template <typename Opnd> inline UnaryOpc_match<Opnd> m_SExt(const Opnd &Op) {
+  return UnaryOpc_match<Opnd>(ISD::SIGN_EXTEND, Op);
+}
+
+template <typename Opnd> inline UnaryOpc_match<Opnd> m_AnyExt(const Opnd &Op) {
+  return UnaryOpc_match<Opnd>(ISD::ANY_EXTEND, Op);
+}
+
+template <typename Opnd> inline UnaryOpc_match<Opnd> m_Trunc(const Opnd &Op) {
+  return UnaryOpc_match<Opnd>(ISD::TRUNCATE, Op);
+}
+
+// === Constants ===
+struct ConstantInt_match {
+  APInt *BindVal; // Optional out-parameter; nullptr means "do not capture".
+
+  explicit ConstantInt_match(APInt *V) : BindVal(V) {}
+
+  bool match(const SelectionDAG *, SDValue N) {
+    // The logic here is similar to that in
+    // SelectionDAG::isConstantIntBuildVectorOrConstantInt, but the latter also
+    // treats GlobalAddressSDNode as a constant, which is difficult to turn into
+    // an APInt.
+    if (auto *C = dyn_cast_or_null<ConstantSDNode>(N.getNode())) {
+      if (BindVal)
+        *BindVal = C->getAPIntValue();
+      return true;
+    }
+    // isConstantSplatVector expects a real node, so reject empty values first.
+    APInt Discard;
+    return N && ISD::isConstantSplatVector(N.getNode(),
+                                           BindVal ? *BindVal : Discard);
+  }
+};
+/// Match any integer constants or splat of an integer constant.
+inline ConstantInt_match m_ConstInt() { return ConstantInt_match(nullptr); }
+/// Match any integer constants or splat of an integer constant; return the
+/// specific constant or constant splat value.
+inline ConstantInt_match m_ConstInt(APInt &V) { return ConstantInt_match(&V); }
+
+/// Matches an integer constant (or constant splat) equal to a given value.
+struct SpecificInt_match {
+  APInt IntVal;
+
+  explicit SpecificInt_match(APInt APV) : IntVal(std::move(APV)) {}
+
+  bool match(const SelectionDAG *DAG, SDValue N) {
+    APInt Found;
+    return sd_match(N, DAG, m_ConstInt(Found)) &&
+           APInt::isSameValue(IntVal, Found);
+  }
+};
+
+/// Match a specific integer constant or constant splat value.
+inline SpecificInt_match m_SpecificInt(APInt V) {
+  return SpecificInt_match(std::move(V));
+}
+inline SpecificInt_match m_SpecificInt(uint64_t V) {
+  return SpecificInt_match(APInt(64, V));
+}
+
+inline SpecificInt_match m_Zero() { return m_SpecificInt(0U); }
+inline SpecificInt_match m_AllOnes() { return m_SpecificInt(~0U); } // NOTE(review): ~0U has only 32 set bits and becomes a 64-bit APInt; all-ones constants of other widths presumably do not match -- verify.
+
+/// Match a constant (or constant splat) that the target treats as boolean
+/// true, according to TargetLowering::getBooleanContents.
+inline TLI_pred_match<Value_match> m_True() {
+  auto IsTrue = [](const TargetLowering &TLI, SDValue N) {
+    APInt ConstVal;
+    if (!sd_match(N, m_ConstInt(ConstVal)))
+      return false;
+
+    switch (TLI.getBooleanContents(N.getValueType())) {
+    case TargetLowering::ZeroOrOneBooleanContent:
+      return ConstVal.isOne();
+    case TargetLowering::ZeroOrNegativeOneBooleanContent:
+      return ConstVal.isAllOnes();
+    case TargetLowering::UndefinedBooleanContent:
+      return (ConstVal & 0x01) == 1;
+    }
+    return false;
+  };
+  return TLI_pred_match<Value_match>(std::move(IsTrue), m_Value());
+}
+/// Match a constant (or constant splat) that the target treats as boolean
+/// false, according to TargetLowering::getBooleanContents.
+inline TLI_pred_match<Value_match> m_False() {
+  auto IsFalse = [](const TargetLowering &TLI, SDValue N) {
+    APInt ConstVal;
+    if (!sd_match(N, m_ConstInt(ConstVal)))
+      return false;
+
+    switch (TLI.getBooleanContents(N.getValueType())) {
+    case TargetLowering::ZeroOrOneBooleanContent:
+    case TargetLowering::ZeroOrNegativeOneBooleanContent:
+      return ConstVal.isZero();
+    case TargetLowering::UndefinedBooleanContent:
+      return (ConstVal & 0x01) == 0;
+    }
+    return false;
+  };
+  return TLI_pred_match<Value_match>(std::move(IsFalse), m_Value());
+}
+} // namespace SDPatternMatch
+} // namespace llvm
+#endif
diff --git a/llvm/unittests/CodeGen/CMakeLists.txt b/llvm/unittests/CodeGen/CMakeLists.txt
index c78cbfcc281939..fb7f8bcf6adfed 100644
--- a/llvm/unittests/CodeGen/CMakeLists.txt
+++ b/llvm/unittests/CodeGen/CMakeLists.txt
@@ -41,6 +41,7 @@ add_llvm_unittest(CodeGenTests
   ScalableVectorMVTsTest.cpp
   SchedBoundary.cpp
   SelectionDAGAddressAnalysisTest.cpp
+  SelectionDAGPatternMatchTest.cpp
   TypeTraitsTest.cpp
   TargetOptionsTest.cpp
   TestAsmPrinter.cpp
diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
new file mode 100644
index 00000000000000..be7c66a026b29d
--- /dev/null
+++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
@@ -0,0 +1,249 @@
+//===---- llvm/unittest/CodeGen/SelectionDAGPatternMatchTest.cpp  ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/OptimizationRemarkEmitter.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/CodeGen/MachineModuleInfo.h"
+#include "llvm/CodeGen/SDPatternMatch.h"
+#include "llvm/CodeGen/TargetLowering.h"
+#include "llvm/MC/TargetRegistry.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/TargetSelect.h"
+#include "llvm/Target/TargetMachine.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+class SelectionDAGPatternMatchTest : public testing::Test {
+protected:
+  static void SetUpTestCase() {
+    InitializeAllTargets();
+    InitializeAllTargetMCs();
+  }
+
+  void SetUp() override {
+    StringRef Assembly = "@g = global i32 0\n"
+                         "@g_alias = alias i32, i32* @g\n"
+                         "define i32 @f() {\n"
+                         "  %1 = load i32, i32* @g\n"
+                         "  ret i32 %1\n"
+                         "}";
+
+    Triple TargetTriple("riscv64--");
+    std::string Error;
+    const Target *T = TargetRegistry::lookupTarget("", TargetTriple, Error);
+    // FIXME: These tests do not depend on RISCV specifically, but we have to
+    // initialize a target. A skeleton Target for unittests would allow us to
+    // always run these tests.
+    if (!T)
+      GTEST_SKIP();
+
+    TargetOptions Options;
+    TM = std::unique_ptr<LLVMTargetMachine>(static_cast<LLVMTargetMachine *>(
+        T->createTargetMachine("riscv64", "", "+m,+f,+d,+v", Options,
+                               std::nullopt, std::nullopt,
+                               CodeGenOptLevel::Aggressive)));
+    if (!TM)
+      GTEST_SKIP();
+
+    SMDiagnostic SMError;
+    M = parseAssemblyString(Assembly, SMError, Context);
+    if (!M)
+      report_fatal_error(SMError.getMessage());
+    M->setDataLayout(TM->createDataLayout());
+
+    F = M->getFunction("f");
+    if (!F)
+      report_fatal_error("F?");
+    G = M->getGlobalVariable("g");
+    if (!G)
+      report_fatal_error("G?");
+    AliasedG = M->getNamedAlias("g_alias");
+    if (!AliasedG)
+      report_fatal_error("AliasedG?");
+
+    MachineModuleInfo MMI(TM.get());
+
+    MF = std::make_unique<MachineFunction>(*F, *TM, *TM->getSubtargetImpl(*F),
+                                           0, MMI);
+
+    DAG = std::make_unique<SelectionDAG>(*TM, CodeGenOptLevel::None);
+    if (!DAG)
+      report_fatal_error("DAG?");
+    OptimizationRemarkEmitter ORE(F);
+    DAG->init(*MF, ORE, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr);
+  }
+
+  TargetLoweringBase::LegalizeTypeAction getTypeAction(EVT VT) {
+    return DAG->getTargetLoweringInfo().getTypeAction(Context, VT);
+  }
+
+  EVT getTypeToTransformTo(EVT VT) {
+    return DAG->getTargetLoweringInfo().getTypeToTransformTo(Context, VT);
+  }
+
+  LLVMContext Context;
+  std::unique_ptr<LLVMTargetMachine> TM;
+  std::unique_ptr<Module> M;
+  Function *F;
+  GlobalVariable *G;
+  GlobalAlias *AliasedG;
+  std::unique_ptr<MachineFunction> MF;
+  std::unique_ptr<SelectionDAG> DAG;
+};
+
+TEST_F(SelectionDAGPatternMatchTest, matchValueType) {
+  SDLoc DL;
+  auto Int32VT = EVT::getIntegerVT(Context, 32);
+  auto Float32VT = EVT::getFloatingPointVT(32);
+  auto VInt32VT = EVT::getVectorVT(Context, Int32VT, 4);
+
+  SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
+  SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Float32VT);
+  SDValue Op2 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, VInt32VT);
+
+  using namespace SDPatternMatch;
+  EXPECT_TRUE(sd_match(Op0, m_SpecificVT(Int32VT)));
+  EVT BindVT;
+  EXPECT_TRUE(sd_match(Op1, m_VT(BindVT)));
+  EXPECT_EQ(BindVT, Float32VT);
+  EXPECT_TRUE(sd_match(Op0, m_IntegerVT()));
+  EXPECT_TRUE(sd_match(Op1, m_FloatingPointVT()));
+  EXPECT_TRUE(sd_match(Op2, m_VectorVT()));
+  EXPECT_FALSE(sd_match(Op2, m_ScalableVectorVT()));
+}
+
+TEST_F(SelectionDAGPatternMatchTest, matchBinaryOp) {
+  SDLoc DL;
+  auto Int32VT = EVT::getIntegerVT(Context, 32);
+  auto Float32VT = EVT::getFloatingPointVT(32);
+
+  SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
+  SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Int32VT);
+  SDValue Op2 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 3, Float32VT);
+
+  SDValue Add = DAG->getNode(ISD::ADD, DL, Int32VT, Op0, Op1);
+  SDValue Sub = DAG->getNode(ISD::SUB, DL, Int32VT, Add, Op0);
+  SDValue Mul = DAG->getNode(ISD::MUL, DL, Int32VT, Add, Sub);
+
+  SDValue SFAdd = DAG->getNode(ISD::STRICT_FADD, DL, {Float32VT, MVT::Other},
+                               {DAG->getEntryNode(), Op2, Op2});
+
+  using namespace SDPatternMatch;
+  EXPECT_TRUE(sd_match(Sub, m_BinOp(ISD::SUB, m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(Sub, m_Sub(m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(Add, m_c_BinOp(ISD::ADD, m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(Add, m_Add(m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(
+      Mul, m_Mul(m_OneUse(m_Opc(ISD::SUB)), m_NUses<2>(m_Specific(Add)))));
+  EXPECT_TRUE(sd_match(SFAdd, m_BinOp(ISD::STRICT_FADD, m_SpecificVT(Float32VT),
+                                      m_SpecificVT(Float32VT))));
+  EXPECT_FALSE(sd_match(
+      SFAdd, m_BinOp(ISD::STRICT_FADD, m_OtherVT(), m_SpecificVT(Float32VT))));
+}
+
+TEST_F(SelectionDAGPatternMatchTest, matchUnaryOp) {
+  SDLoc DL;
+  auto Int32VT = EVT::getIntegerVT(Context, 32);
+  auto Int64VT = EVT::getIntegerVT(Context, 64);
+
+  SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
+  SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int64VT);
+
+  SDValue ZExt = DAG->getNode(ISD::ZERO_EXTEND, DL, Int64VT, Op0);
+  SDValue SExt = DAG->getNode(ISD::SIGN_EXTEND, DL, Int64VT, Op0);
+  SDValue Trunc = DAG->getNode(ISD::TRUNCATE, DL, Int32VT, Op1);
+
+  using namespace SDPatternMatch;
+  EXPECT_TRUE(sd_match(ZExt, m_UnaryOp(ISD::ZERO_EXTEND, m_Value())));
+  EXPECT_TRUE(sd_match(SExt, m_SExt(m_Value())));
+  EXPECT_TRUE(sd_match(Trunc, m_Trunc(m_Specific(Op1))));
+}
+
+TEST_F(SelectionDAGPatternMatchTest, matchConstants) {
+  SDLoc DL;
+  auto Int32VT = EVT::getIntegerVT(Context, 32);
+  auto VInt32VT = EVT::getVectorVT(Context, Int32VT, 4);
+
+  SDValue Arg0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
+
+  SDValue Const3 = DAG->getConstant(3, DL, Int32VT);
+  SDValue Const87 = DAG->getConstant(87, DL, Int32VT);
+  SDValue Splat = DAG->getSplat(VInt32VT, DL, Arg0);
+  SDValue ConstSplat = DAG->getSplat(VInt32VT, DL, Const3);
+  SDValue Zero = DAG->getConstant(0, DL, Int32VT);
+  SDValue One = DAG->getConstant(1, DL, Int32VT);
+  SDValue AllOnes = DAG->getConstant(APInt::getAllOnes(32), DL, Int32VT);
+
+  using namespace SDPatternMatch;
+  EXPECT_TRUE(sd_match(Const87, m_ConstInt()));
+  EXPECT_FALSE(sd_match(Arg0, m_ConstInt()));
+  APInt ConstVal;
+  EXPECT_TRUE(sd_match(ConstSplat, m_ConstInt(ConstVal)));
+  EXPECT_EQ(ConstVal, 3);
+  EXPECT_FALSE(sd_match(Splat, m_ConstInt()));
+
+  EXPECT_TRUE(sd_match(Const87, m_SpecificInt(87)));
+  EXPECT_TRUE(sd_match(Const3, m_SpecificInt(ConstVal)));
+  EXPECT_TRUE(sd_match(AllOnes, m_AllOnes()));
+
+  EXPECT_TRUE(sd_match(Zero, DAG.get(), m_False()));
+  EXPECT_TRUE(sd_match(One, DAG.get(), m_True()));
+  EXPECT_FALSE(sd_match(AllOnes, DAG.get(), m_True()));
+}
+
+TEST_F(SelectionDAGPatternMatchTest, patternCombinators) {
+  SDLoc DL;
+  auto Int32VT = EVT::getIntegerVT(Context, 32);
+
+  SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
+  SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Int32VT);
+
+  SDValue Add = DAG->getNode(ISD::ADD, DL, Int32VT, Op0, Op1);
+  SDValue Sub = DAG->getNode(ISD::SUB, DL, Int32VT, Add, Op0);
+
+  using namespace SDPatternMatch;
+  EXPECT_TRUE(sd_match(
+      Sub, m_any_of(m_Opc(ISD::ADD), m_Opc(ISD::SUB), m_Opc(ISD::MUL))));
+  EXPECT_TRUE(sd_match(Add, m_all_of(m_Opc(ISD::ADD), m_OneUse())));
+}
+
+TEST_F(SelectionDAGPatternMatchTest, matchNode) {
+  SDLoc DL;
+  auto Int32VT = EVT::getIntegerVT(Context, 32);
+
+  SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
+  SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Int32VT);
+
+  SDValue Add = DAG->getNode(ISD::ADD, DL, Int32VT, Op0, Op1);
+
+  using namespace SDPatternMatch;
+  EXPECT_TRUE(sd_match(Add, m_Node(ISD::ADD, m_Value(), m_Value())));
+  EXPECT_FALSE(sd_match(Add, m_Node(ISD::SUB, m_Value(), m_Value())));
+  EXPECT_FALSE(sd_match(Add, m_Node(ISD::ADD, m_Value())));
+  EXPECT_FALSE(
+      sd_match(Add, m_Node(ISD::ADD, m_Value(), m_Value(), m_Value())));
+  EXPECT_FALSE(sd_match(Add, m_Node(ISD::ADD, m_ConstInt(), m_Value())));
+}
+
+TEST_F(SelectionDAGPatternMatchTest, matchAdvancedProperties) {
+  SDLoc DL;
+  auto Int16VT = EVT::getIntegerVT(Context, 16);
+  auto Int64VT = EVT::getIntegerVT(Context, 64);
+
+  SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int64VT);
+  SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Int16VT);
+
+  SDValue Add = DAG->getNode(ISD::ADD, DL, Int64VT, Op0, Op0);
+
+  using namespace SDPatternMatch;
+  EXPECT_TRUE(sd_match(Op0, DAG.get(), m_LegalType(m_Value())));
+  EXPECT_FALSE(sd_match(Op1, DAG.get(), m_LegalType(m_Value())));
+  EXPECT_TRUE(sd_match(Add, DAG.get(),
+                       m_LegalOp(m_IntegerVT(m_Add(m_Value(), m_Value())))));
+}

>From c884481e3537923b8176a903f14bd4c018b6c3f3 Mon Sep 17 00:00:00 2001
From: Min Hsu <min.hsu at sifive.com>
Date: Fri, 19 Jan 2024 13:51:32 -0800
Subject: [PATCH 2/7] In NUses, check the subsequent pattern before checking
 the # of users

As SDNode::hasNUsesOfValue is pretty expensive
---
 llvm/include/llvm/CodeGen/SDPatternMatch.h | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index 03393abb11a7e1..b1b8c63c8a6d41 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -77,7 +77,10 @@ template <unsigned NumUses, typename Pattern> struct NUses_match {
   explicit NUses_match(const Pattern &P) : P(P) {}
 
   bool match(const SelectionDAG *DAG, SDValue N) {
-    return N && N->hasNUsesOfValue(NumUses, N.getResNo()) && P.match(DAG, N);
+    // SDNode::hasNUsesOfValue is pretty expensive when the SDNode produces
+    // multiple results, hence we check the subsequent pattern here before
+    // checking the number of value users.
+    return N && P.match(DAG, N) && N->hasNUsesOfValue(NumUses, N.getResNo());
   }
 };
 

>From c94e7a4e01f4b2da6d1521d68ea67ee75884ecbe Mon Sep 17 00:00:00 2001
From: Min Hsu <min.hsu at sifive.com>
Date: Mon, 12 Feb 2024 15:40:54 -0800
Subject: [PATCH 3/7] Introducing MatchContext into SDPatternMatch

---
 llvm/include/llvm/CodeGen/SDPatternMatch.h    | 175 ++++++++++++++----
 .../CodeGen/SelectionDAGPatternMatchTest.cpp  |  39 ++++
 2 files changed, 175 insertions(+), 39 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index b1b8c63c8a6d41..16b077deff1cbb 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -14,6 +14,7 @@
 #define LLVM_CODEGEN_SDPATTERNMATCH_H
 
 #include "llvm/ADT/APInt.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/CodeGen/SelectionDAG.h"
 #include "llvm/CodeGen/SelectionDAGNodes.h"
 #include "llvm/CodeGen/TargetLowering.h"
@@ -21,24 +22,81 @@
 namespace llvm {
 namespace SDPatternMatch {
 
+/// MatchContext can repurpose existing patterns to behave differently under
+/// a certain context. For instance, `m_Opc(ISD::ADD)` matches plain ADD nodes
+/// in normal circumstances, but matches VP_ADD nodes under a custom
+/// VPMatchContext. This design is meant to facilitate code / pattern reuse.
+class BasicMatchContext {
+  const SelectionDAG *DAG;
+  const TargetLowering *TLI;
+
+public:
+  explicit BasicMatchContext(const SelectionDAG *DAG)
+      : DAG(DAG), TLI(DAG ? &DAG->getTargetLoweringInfo() : nullptr) {}
+
+  explicit BasicMatchContext(const TargetLowering *TLI)
+      : DAG(nullptr), TLI(TLI) {}
+
+  // A valid MatchContext has to implement the following functions.
+
+  const SelectionDAG *getDAG() const { return DAG; }
+
+  const TargetLowering *getTLI() const {
+    if (TLI)
+      return TLI;
+    return DAG ? &DAG->getTargetLoweringInfo() : nullptr;
+  }
+
+  // Optional trait function(s)
+
+  /// Return true if N effectively has opcode Opcode.
+  // bool match(SDValue N, unsigned Opcode)
+};
+
+template <typename MatchContext>
+using ctx_has_get_dag = decltype(std::declval<const MatchContext &>().getDAG());
+
+template <typename MatchContext>
+using ctx_has_get_tli = decltype(std::declval<const MatchContext &>().getTLI());
+
+template <typename MatchContext>
+using ctx_has_match = decltype(std::declval<const MatchContext &>().match(
+    std::declval<SDValue>(), std::declval<unsigned>()));
+
+template <typename Pattern, typename MatchContext>
+[[nodiscard]] bool sd_context_match(SDValue N, const MatchContext &Ctx,
+                                    Pattern &&P) {
+  static_assert(is_detected<ctx_has_get_dag, MatchContext>::value,
+                "Match context has to implement getDAG().");
+  static_assert(is_detected<ctx_has_get_tli, MatchContext>::value,
+                "Match context has to implement getTLI().");
+  return P.match(Ctx, N);
+}
+
+template <typename Pattern, typename MatchContext>
+[[nodiscard]] bool sd_context_match(SDNode *N, const MatchContext &Ctx,
+                                    Pattern &&P) {
+  return sd_context_match(SDValue(N, 0), Ctx, P);
+}
+
 template <typename Pattern>
 [[nodiscard]] bool sd_match(SDNode *N, const SelectionDAG *DAG, Pattern &&P) {
-  return P.match(DAG, SDValue(N, 0));
+  return sd_context_match(N, BasicMatchContext(DAG), P);
 }
 
 template <typename Pattern>
 [[nodiscard]] bool sd_match(SDValue N, const SelectionDAG *DAG, Pattern &&P) {
-  return P.match(DAG, N);
+  return sd_context_match(N, BasicMatchContext(DAG), P);
 }
 
 template <typename Pattern>
 [[nodiscard]] bool sd_match(SDNode *N, Pattern &&P) {
-  return P.match(nullptr, SDValue(N, 0));
+  return sd_match(N, nullptr, P);
 }
 
 template <typename Pattern>
 [[nodiscard]] bool sd_match(SDValue N, Pattern &&P) {
-  return P.match(nullptr, N);
+  return sd_match(N, nullptr, P);
 }
 
 // === Utilities ===
@@ -49,7 +107,7 @@ struct Value_match {
 
   explicit Value_match(SDValue Match) : MatchVal(Match) {}
 
-  bool match(const SelectionDAG *, SDValue N) {
+  template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
     return (MatchVal && (MatchVal == N)) || N.getNode();
   }
 };
@@ -64,7 +122,16 @@ struct Opcode_match {
 
   explicit Opcode_match(unsigned Opc) : Opcode(Opc) {}
 
-  bool match(const SelectionDAG *, SDValue N) {
+  template <typename MatchContext>
+  std::enable_if_t<is_detected<ctx_has_match, MatchContext>::value, bool>
+  match(const MatchContext &Ctx, SDValue N) {
+    return N && Ctx.match(N, Opcode);
+  }
+
+  // Default implementation.
+  template <typename MatchContext>
+  std::enable_if_t<!is_detected<ctx_has_match, MatchContext>::value, bool>
+  match(const MatchContext &, SDValue N) {
     return N && N->getOpcode() == Opcode;
   }
 };
@@ -76,11 +143,12 @@ template <unsigned NumUses, typename Pattern> struct NUses_match {
 
   explicit NUses_match(const Pattern &P) : P(P) {}
 
-  bool match(const SelectionDAG *DAG, SDValue N) {
+  template <typename MatchContext>
+  bool match(const MatchContext &Ctx, SDValue N) {
     // SDNode::hasNUsesOfValue is pretty expensive when the SDNode produces
     // multiple results, hence we check the subsequent pattern here before
     // checking the number of value users.
-    return N && P.match(DAG, N) && N->hasNUsesOfValue(NumUses, N.getResNo());
+    return N && P.match(Ctx, N) && N->hasNUsesOfValue(NumUses, N.getResNo());
   }
 };
 
@@ -105,7 +173,7 @@ struct Value_bind {
 
   explicit Value_bind(SDValue &N) : BindVal(N) {}
 
-  bool match(const SelectionDAG *, SDValue N) {
+  template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
     if (N) {
       BindVal = N;
       return true;
@@ -123,9 +191,9 @@ template <typename Pattern> struct TLI_pred_match {
   TLI_pred_match(decltype(PredFunc) &&Pred, const Pattern &P)
       : P(P), PredFunc(std::move(Pred)) {}
 
-  bool match(const SelectionDAG *DAG, SDValue N) {
-    return DAG && N && PredFunc(DAG->getTargetLoweringInfo(), N) &&
-           P.match(DAG, N);
+  template <typename MatchContext>
+  bool match(const MatchContext &Ctx, SDValue N) {
+    return Ctx.getTLI() && N && PredFunc(*Ctx.getTLI(), N) && P.match(Ctx, N);
   }
 };
 
@@ -139,13 +207,30 @@ inline TLI_pred_match<Pattern> m_LegalOp(const Pattern &P) {
       P);
 }
 
+/// Switch to a different MatchContext for subsequent patterns.
+template <typename NewMatchContext, typename Pattern> struct SwitchContext {
+  const NewMatchContext &Ctx;
+  Pattern P;
+
+  template <typename OrigMatchContext>
+  bool match(const OrigMatchContext &, SDValue N) {
+    return P.match(Ctx, N);
+  }
+};
+
+template <typename MatchContext, typename Pattern>
+inline SwitchContext<MatchContext, Pattern> m_Context(const MatchContext &Ctx,
+                                                      Pattern &&P) {
+  return SwitchContext<MatchContext, Pattern>{Ctx, std::move(P)};
+}
+
 // === Value type ===
 struct ValueType_bind {
   EVT &BindVT;
 
   explicit ValueType_bind(EVT &Bind) : BindVT(Bind) {}
 
-  bool match(const SelectionDAG *, SDValue N) {
+  template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
     if (!N)
       return false;
     BindVT = N.getValueType();
@@ -163,8 +248,9 @@ template <typename Pattern> struct ValueType_match {
   ValueType_match(decltype(PredFunc) &&Pred, const Pattern &P)
       : PredFunc(std::move(Pred)), P(P) {}
 
-  bool match(const SelectionDAG *DAG, SDValue N) {
-    return N && PredFunc(N.getValueType()) && P.match(DAG, N);
+  template <typename MatchContext>
+  bool match(const MatchContext &Ctx, SDValue N) {
+    return N && PredFunc(N.getValueType()) && P.match(Ctx, N);
   }
 };
 
@@ -248,7 +334,9 @@ inline TLI_pred_match<Pattern> m_LegalType(const Pattern &P) {
 
 // === Patterns combinators ===
 template <typename... Preds> struct And {
-  bool match(const SelectionDAG *, SDValue N) { return true; }
+  template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
+    return true;
+  }
 };
 
 template <typename Pred, typename... Preds>
@@ -258,13 +346,16 @@ struct And<Pred, Preds...> : And<Preds...> {
       : And<Preds...>(std::forward<Preds>(preds)...), P(std::forward<Pred>(p)) {
   }
 
-  bool match(const SelectionDAG *DAG, SDValue N) {
-    return P.match(DAG, N) && And<Preds...>::match(DAG, N);
+  template <typename MatchContext>
+  bool match(const MatchContext &Ctx, SDValue N) {
+    return P.match(Ctx, N) && And<Preds...>::match(Ctx, N);
   }
 };
 
 template <typename... Preds> struct Or {
-  bool match(const SelectionDAG *, SDValue N) { return false; }
+  template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
+    return false;
+  }
 };
 
 template <typename Pred, typename... Preds>
@@ -273,8 +364,9 @@ struct Or<Pred, Preds...> : Or<Preds...> {
   Or(Pred &&p, Preds &&...preds)
       : Or<Preds...>(std::forward<Preds>(preds)...), P(std::forward<Pred>(p)) {}
 
-  bool match(const SelectionDAG *DAG, SDValue N) {
-    return P.match(DAG, N) || Or<Preds...>::match(DAG, N);
+  template <typename MatchContext>
+  bool match(const MatchContext &Ctx, SDValue N) {
+    return P.match(Ctx, N) || Or<Preds...>::match(Ctx, N);
   }
 };
 
@@ -293,13 +385,14 @@ template <typename... OpndPreds> struct Node_match {
 
   Node_match(unsigned Opc, unsigned OpIdx) : Opcode(Opc), OpIdx(OpIdx) {}
 
-  bool match(const SelectionDAG *, SDValue N) {
+  template <typename MatchContext>
+  bool match(const MatchContext &Ctx, SDValue N) {
     if (!N)
       return false;
 
     if (OpIdx == 0) {
       // Check opcode
-      if (N->getOpcode() != Opcode)
+      if (!sd_context_match(N, Ctx, m_Opc(Opcode)))
         return false;
     }
 
@@ -319,19 +412,20 @@ struct Node_match<OpndPred, OpndPreds...> : Node_match<OpndPreds...> {
                                  std::forward<OpndPreds>(preds)...),
         Opcode(Opc), OpIdx(OpIdx), P(std::forward<OpndPred>(p)) {}
 
-  bool match(const SelectionDAG *DAG, SDValue N) {
+  template <typename MatchContext>
+  bool match(const MatchContext &Ctx, SDValue N) {
     if (!N)
       return false;
 
     if (OpIdx == 0) {
       // Check opcode
-      if (N->getOpcode() != Opcode)
+      if (!sd_context_match(N, Ctx, m_Opc(Opcode)))
         return false;
     }
 
     if (OpIdx < N->getNumOperands())
-      return P.match(DAG, N->getOperand(OpIdx)) &&
-             Node_match<OpndPreds...>::match(DAG, N);
+      return P.match(Ctx, N->getOperand(OpIdx)) &&
+             Node_match<OpndPreds...>::match(Ctx, N);
 
     // This is the case where there are more predicates than operands.
     return false;
@@ -375,18 +469,19 @@ struct BinaryOpc_match {
   BinaryOpc_match(unsigned Opc, const LHS_P &L, const RHS_P &R)
       : Opcode(Opc), LHS(L), RHS(R) {}
 
-  bool match(const SelectionDAG *DAG, SDValue N) {
+  template <typename MatchContext>
+  bool match(const MatchContext &Ctx, SDValue N) {
     if (!N)
       return false;
 
-    if (N->getOpcode() == Opcode) {
+    if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
       EffectiveOperands EO(N);
       if (EO.Size == 2)
-        return (LHS.match(DAG, N->getOperand(EO.FirstIndex)) &&
-                RHS.match(DAG, N->getOperand(EO.FirstIndex + 1))) ||
+        return (LHS.match(Ctx, N->getOperand(EO.FirstIndex)) &&
+                RHS.match(Ctx, N->getOperand(EO.FirstIndex + 1))) ||
                (Commutable &&
-                LHS.match(DAG, N->getOperand(EO.FirstIndex + 1)) &&
-                RHS.match(DAG, N->getOperand(EO.FirstIndex)));
+                LHS.match(Ctx, N->getOperand(EO.FirstIndex + 1)) &&
+                RHS.match(Ctx, N->getOperand(EO.FirstIndex)));
     }
 
     return false;
@@ -484,14 +579,15 @@ template <typename Opnd_P> struct UnaryOpc_match {
 
   UnaryOpc_match(unsigned Opc, const Opnd_P &Op) : Opcode(Opc), Opnd(Op) {}
 
-  bool match(const SelectionDAG *DAG, SDValue N) {
+  template <typename MatchContext>
+  bool match(const MatchContext &Ctx, SDValue N) {
     if (!N)
       return false;
 
-    if (N->getOpcode() == Opcode) {
+    if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
       EffectiveOperands EO(N);
       if (EO.Size == 1)
-        return Opnd.match(DAG, N->getOperand(EO.FirstIndex));
+        return Opnd.match(Ctx, N->getOperand(EO.FirstIndex));
     }
 
     return false;
@@ -525,7 +621,7 @@ struct ConstantInt_match {
 
   explicit ConstantInt_match(APInt *V) : BindVal(V) {}
 
-  bool match(const SelectionDAG *, SDValue N) {
+  template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
     // The logics here are similar to that in
     // SelectionDAG::isConstantIntBuildVectorOrConstantInt, but the latter also
     // treats GlobalAddressSDNode as a constant, which is difficult to turn into
@@ -552,9 +648,10 @@ struct SpecificInt_match {
 
   explicit SpecificInt_match(APInt APV) : IntVal(std::move(APV)) {}
 
-  bool match(const SelectionDAG *DAG, SDValue N) {
+  template <typename MatchContext>
+  bool match(const MatchContext &Ctx, SDValue N) {
     APInt ConstInt;
-    if (sd_match(N, DAG, m_ConstInt(ConstInt)))
+    if (sd_context_match(N, Ctx, m_ConstInt(ConstInt)))
       return APInt::isSameValue(IntVal, ConstInt);
     return false;
   }
diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
index be7c66a026b29d..369918c2b6e611 100644
--- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
@@ -231,6 +231,45 @@ TEST_F(SelectionDAGPatternMatchTest, matchNode) {
   EXPECT_FALSE(sd_match(Add, m_Node(ISD::ADD, m_ConstInt(), m_Value())));
 }
 
+namespace {
+struct VPMatchContext : public SDPatternMatch::BasicMatchContext {
+  using SDPatternMatch::BasicMatchContext::BasicMatchContext;
+
+  bool match(SDValue OpVal, unsigned Opc) const {
+    if (!OpVal->isVPOpcode())
+      return OpVal->getOpcode() == Opc;
+
+    auto BaseOpc = ISD::getBaseOpcodeForVP(OpVal->getOpcode(), false);
+    return BaseOpc.has_value() && *BaseOpc == Opc;
+  }
+};
+} // anonymous namespace
+TEST_F(SelectionDAGPatternMatchTest, matchContext) {
+  SDLoc DL;
+  auto BoolVT = EVT::getIntegerVT(Context, 1);
+  auto Int32VT = EVT::getIntegerVT(Context, 32);
+  auto VInt32VT = EVT::getVectorVT(Context, Int32VT, 4);
+  auto MaskVT = EVT::getVectorVT(Context, BoolVT, 4);
+
+  SDValue Scalar0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
+  SDValue Vector0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, VInt32VT);
+  SDValue Mask0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 3, MaskVT);
+
+  SDValue VPAdd = DAG->getNode(ISD::VP_ADD, DL, VInt32VT,
+                               {Vector0, Vector0, Mask0, Scalar0});
+  SDValue VPReduceAdd = DAG->getNode(ISD::VP_REDUCE_ADD, DL, Int32VT,
+                                     {Scalar0, VPAdd, Mask0, Scalar0});
+
+  using namespace SDPatternMatch;
+  VPMatchContext VPCtx(DAG.get());
+  EXPECT_TRUE(sd_context_match(VPAdd, VPCtx, m_Opc(ISD::ADD)));
+  // VP_REDUCE_ADD doesn't have a base opcode, so we use a normal
+  // sd_match before switching to VPMatchContext when checking VPAdd.
+  EXPECT_TRUE(sd_match(VPReduceAdd, m_Node(ISD::VP_REDUCE_ADD, m_Value(),
+                                           m_Context(VPCtx, m_Opc(ISD::ADD)),
+                                           m_Value(), m_Value())));
+}
+
 TEST_F(SelectionDAGPatternMatchTest, matchAdvancedProperties) {
   SDLoc DL;
   auto Int16VT = EVT::getIntegerVT(Context, 16);

>From 5b9e51d989d86371e8e1b8594d63e36051ddc684 Mon Sep 17 00:00:00 2001
From: Min Hsu <min.hsu at sifive.com>
Date: Wed, 14 Feb 2024 16:48:17 -0800
Subject: [PATCH 4/7] Address reviewer comments

---
 llvm/include/llvm/CodeGen/SDPatternMatch.h    | 179 +++++++++---------
 .../CodeGen/SelectionDAGPatternMatchTest.cpp  |  13 +-
 2 files changed, 95 insertions(+), 97 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index 16b077deff1cbb..db2cd204b26763 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -41,11 +41,7 @@ class BasicMatchContext {
 
   const SelectionDAG *getDAG() const { return DAG; }
 
-  const TargetLowering *getTLI() const {
-    if (TLI)
-      return TLI;
-    return DAG ? &DAG->getTargetLoweringInfo() : nullptr;
-  }
+  const TargetLowering *getTLI() const { return TLI; }
 
   // Optional trait function(s)
 
@@ -108,7 +104,9 @@ struct Value_match {
   explicit Value_match(SDValue Match) : MatchVal(Match) {}
 
   template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
-    return (MatchVal && (MatchVal == N)) || N.getNode();
+    if (MatchVal)
+      return MatchVal == N;
+    return N.getNode();
   }
 };
 
@@ -125,14 +123,14 @@ struct Opcode_match {
   template <typename MatchContext>
   std::enable_if_t<is_detected<ctx_has_match, MatchContext>::value, bool>
   match(const MatchContext &Ctx, SDValue N) {
-    return N && Ctx.match(N, Opcode);
+    return Ctx.match(N, Opcode);
   }
 
   // Default implementation.
   template <typename MatchContext>
   std::enable_if_t<!is_detected<ctx_has_match, MatchContext>::value, bool>
   match(const MatchContext &, SDValue N) {
-    return N && N->getOpcode() == Opcode;
+    return N->getOpcode() == Opcode;
   }
 };
 
@@ -148,7 +146,7 @@ template <unsigned NumUses, typename Pattern> struct NUses_match {
     // SDNode::hasNUsesOfValue is pretty expensive when the SDNode produces
     // multiple results, hence we check the subsequent pattern here before
     // checking the number of value users.
-    return N && P.match(Ctx, N) && N->hasNUsesOfValue(NumUses, N.getResNo());
+    return P.match(Ctx, N) && N->hasNUsesOfValue(NumUses, N.getResNo());
   }
 };
 
@@ -174,11 +172,8 @@ struct Value_bind {
   explicit Value_bind(SDValue &N) : BindVal(N) {}
 
   template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
-    if (N) {
-      BindVal = N;
-      return true;
-    }
-    return false;
+    BindVal = N;
+    return true;
   }
 };
 
@@ -193,7 +188,8 @@ template <typename Pattern> struct TLI_pred_match {
 
   template <typename MatchContext>
   bool match(const MatchContext &Ctx, SDValue N) {
-    return Ctx.getTLI() && N && PredFunc(*Ctx.getTLI(), N) && P.match(Ctx, N);
+    assert(Ctx.getTLI() && "TargetLowering is required for this pattern.");
+    return PredFunc(*Ctx.getTLI(), N) && P.match(Ctx, N);
   }
 };
 
@@ -231,8 +227,6 @@ struct ValueType_bind {
   explicit ValueType_bind(EVT &Bind) : BindVT(Bind) {}
 
   template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
-    if (!N)
-      return false;
     BindVT = N.getValueType();
     return true;
   }
@@ -241,85 +235,77 @@ struct ValueType_bind {
 /// Retreive the ValueType of the current SDValue.
 inline ValueType_bind m_VT(EVT &VT) { return ValueType_bind(VT); }
 
-template <typename Pattern> struct ValueType_match {
-  std::function<bool(EVT)> PredFunc;
+template <typename Pattern, typename PredFuncT> struct ValueType_match {
+  PredFuncT PredFunc;
   Pattern P;
 
-  ValueType_match(decltype(PredFunc) &&Pred, const Pattern &P)
-      : PredFunc(std::move(Pred)), P(P) {}
+  ValueType_match(const PredFuncT &Pred, const Pattern &P)
+      : PredFunc(Pred), P(P) {}
 
   template <typename MatchContext>
   bool match(const MatchContext &Ctx, SDValue N) {
-    return N && PredFunc(N.getValueType()) && P.match(Ctx, N);
+    return PredFunc(N.getValueType()) && P.match(Ctx, N);
   }
 };
 
+// Explicit deduction guide.
+template <typename PredFuncT, typename Pattern>
+ValueType_match(const PredFuncT &Pred, const Pattern &P)
+    -> ValueType_match<Pattern, PredFuncT>;
+
 /// Match a specific ValueType.
 template <typename Pattern>
-inline ValueType_match<Pattern> m_SpecificVT(EVT RefVT, const Pattern &P) {
-  return ValueType_match<Pattern>([=](EVT VT) { return VT == RefVT; }, P);
+inline auto m_SpecificVT(EVT RefVT, const Pattern &P) {
+  return ValueType_match{[=](EVT VT) { return VT == RefVT; }, P};
 }
-inline ValueType_match<Value_match> m_SpecificVT(EVT RefVT) {
-  return ValueType_match<Value_match>([=](EVT VT) { return VT == RefVT; },
-                                      m_Value());
+inline auto m_SpecificVT(EVT RefVT) {
+  return ValueType_match{[=](EVT VT) { return VT == RefVT; }, m_Value()};
 }
 
-inline ValueType_match<Value_match> m_Glue() { return m_SpecificVT(MVT::Glue); }
-inline ValueType_match<Value_match> m_OtherVT() {
-  return m_SpecificVT(MVT::Other);
-}
+inline auto m_Glue() { return m_SpecificVT(MVT::Glue); }
+inline auto m_OtherVT() { return m_SpecificVT(MVT::Other); }
 
 /// Match any integer ValueTypes.
-template <typename Pattern>
-inline ValueType_match<Pattern> m_IntegerVT(const Pattern &P) {
-  return ValueType_match<Pattern>([](EVT VT) { return VT.isInteger(); }, P);
+template <typename Pattern> inline auto m_IntegerVT(const Pattern &P) {
+  return ValueType_match{[](EVT VT) { return VT.isInteger(); }, P};
 }
-inline ValueType_match<Value_match> m_IntegerVT() {
-  return ValueType_match<Value_match>([](EVT VT) { return VT.isInteger(); },
-                                      m_Value());
+inline auto m_IntegerVT() {
+  return ValueType_match{[](EVT VT) { return VT.isInteger(); }, m_Value()};
 }
 
 /// Match any floating point ValueTypes.
-template <typename Pattern>
-inline ValueType_match<Pattern> m_FloatingPointVT(const Pattern &P) {
-  return ValueType_match<Pattern>([](EVT VT) { return VT.isFloatingPoint(); },
-                                  P);
+template <typename Pattern> inline auto m_FloatingPointVT(const Pattern &P) {
+  return ValueType_match{[](EVT VT) { return VT.isFloatingPoint(); }, P};
 }
-inline ValueType_match<Value_match> m_FloatingPointVT() {
-  return ValueType_match<Value_match>(
-      [](EVT VT) { return VT.isFloatingPoint(); }, m_Value());
+inline auto m_FloatingPointVT() {
+  return ValueType_match{[](EVT VT) { return VT.isFloatingPoint(); },
+                         m_Value()};
 }
 
 /// Match any vector ValueTypes.
-template <typename Pattern>
-inline ValueType_match<Pattern> m_VectorVT(const Pattern &P) {
-  return ValueType_match<Pattern>([](EVT VT) { return VT.isVector(); }, P);
+template <typename Pattern> inline auto m_VectorVT(const Pattern &P) {
+  return ValueType_match{[](EVT VT) { return VT.isVector(); }, P};
 }
-inline ValueType_match<Value_match> m_VectorVT() {
-  return ValueType_match<Value_match>([](EVT VT) { return VT.isVector(); },
-                                      m_Value());
+inline auto m_VectorVT() {
+  return ValueType_match{[](EVT VT) { return VT.isVector(); }, m_Value()};
 }
 
 /// Match fixed-length vector ValueTypes.
-template <typename Pattern>
-inline ValueType_match<Pattern> m_FixedVectorVT(const Pattern &P) {
-  return ValueType_match<Pattern>(
-      [](EVT VT) { return VT.isFixedLengthVector(); }, P);
+template <typename Pattern> inline auto m_FixedVectorVT(const Pattern &P) {
+  return ValueType_match{[](EVT VT) { return VT.isFixedLengthVector(); }, P};
 }
-inline ValueType_match<Value_match> m_FixedVectorVT() {
-  return ValueType_match<Value_match>(
-      [](EVT VT) { return VT.isFixedLengthVector(); }, m_Value());
+inline auto m_FixedVectorVT() {
+  return ValueType_match{[](EVT VT) { return VT.isFixedLengthVector(); },
+                         m_Value()};
 }
 
 /// Match scalable vector ValueTypes.
-template <typename Pattern>
-inline ValueType_match<Pattern> m_ScalableVectorVT(const Pattern &P) {
-  return ValueType_match<Pattern>([](EVT VT) { return VT.isScalableVector(); },
-                                  P);
+template <typename Pattern> inline auto m_ScalableVectorVT(const Pattern &P) {
+  return ValueType_match{[](EVT VT) { return VT.isScalableVector(); }, P};
 }
-inline ValueType_match<Value_match> m_ScalableVectorVT() {
-  return ValueType_match<Value_match>(
-      [](EVT VT) { return VT.isScalableVector(); }, m_Value());
+inline auto m_ScalableVectorVT() {
+  return ValueType_match{[](EVT VT) { return VT.isScalableVector(); },
+                         m_Value()};
 }
 
 /// Match legal ValueTypes based on the information provided by TargetLowering.
@@ -370,11 +356,11 @@ struct Or<Pred, Preds...> : Or<Preds...> {
   }
 };
 
-template <typename... Preds> And<Preds...> m_all_of(Preds &&...preds) {
+template <typename... Preds> And<Preds...> m_AllOf(Preds &&...preds) {
   return And<Preds...>(std::forward<Preds>(preds)...);
 }
 
-template <typename... Preds> Or<Preds...> m_any_of(Preds &&...preds) {
+template <typename... Preds> Or<Preds...> m_AnyOf(Preds &&...preds) {
   return Or<Preds...>(std::forward<Preds>(preds)...);
 }
 
@@ -387,9 +373,6 @@ template <typename... OpndPreds> struct Node_match {
 
   template <typename MatchContext>
   bool match(const MatchContext &Ctx, SDValue N) {
-    if (!N)
-      return false;
-
     if (OpIdx == 0) {
       // Check opcode
       if (!sd_context_match(N, Ctx, m_Opc(Opcode)))
@@ -414,9 +397,6 @@ struct Node_match<OpndPred, OpndPreds...> : Node_match<OpndPreds...> {
 
   template <typename MatchContext>
   bool match(const MatchContext &Ctx, SDValue N) {
-    if (!N)
-      return false;
-
     if (OpIdx == 0) {
       // Check opcode
       if (!sd_context_match(N, Ctx, m_Opc(Opcode)))
@@ -439,7 +419,7 @@ Node_match<OpndPreds...> m_Node(unsigned Opcode, OpndPreds &&...preds) {
 
 /// Provide number of operands that are not chain or glue, as well as the first
 /// index of such operand.
-struct EffectiveOperands {
+template <bool ExcludeChain> struct EffectiveOperands {
   unsigned Size = 0;
   unsigned FirstIndex = 0;
 
@@ -459,8 +439,16 @@ struct EffectiveOperands {
   }
 };
 
+template <> struct EffectiveOperands<false> {
+  unsigned Size = 0;
+  unsigned FirstIndex = 0;
+
+  explicit EffectiveOperands(SDValue N) : Size(N->getNumOperands()) {}
+};
+
 // === Binary operations ===
-template <typename LHS_P, typename RHS_P, bool Commutable = false>
+template <typename LHS_P, typename RHS_P, bool Commutable = false,
+          bool ExcludeChain = false>
 struct BinaryOpc_match {
   unsigned Opcode;
   LHS_P LHS;
@@ -471,17 +459,13 @@ struct BinaryOpc_match {
 
   template <typename MatchContext>
   bool match(const MatchContext &Ctx, SDValue N) {
-    if (!N)
-      return false;
-
     if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
-      EffectiveOperands EO(N);
-      if (EO.Size == 2)
-        return (LHS.match(Ctx, N->getOperand(EO.FirstIndex)) &&
-                RHS.match(Ctx, N->getOperand(EO.FirstIndex + 1))) ||
-               (Commutable &&
-                LHS.match(Ctx, N->getOperand(EO.FirstIndex + 1)) &&
-                RHS.match(Ctx, N->getOperand(EO.FirstIndex)));
+      EffectiveOperands<ExcludeChain> EO(N);
+      assert(EO.Size == 2);
+      return (LHS.match(Ctx, N->getOperand(EO.FirstIndex)) &&
+              RHS.match(Ctx, N->getOperand(EO.FirstIndex + 1))) ||
+             (Commutable && LHS.match(Ctx, N->getOperand(EO.FirstIndex + 1)) &&
+              RHS.match(Ctx, N->getOperand(EO.FirstIndex)));
     }
 
     return false;
@@ -499,6 +483,17 @@ inline BinaryOpc_match<LHS, RHS, true> m_c_BinOp(unsigned Opc, const LHS &L,
   return BinaryOpc_match<LHS, RHS, true>(Opc, L, R);
 }
 
+template <typename LHS, typename RHS>
+inline BinaryOpc_match<LHS, RHS, false, true>
+m_ChainedBinOp(unsigned Opc, const LHS &L, const RHS &R) {
+  return BinaryOpc_match<LHS, RHS, false, true>(Opc, L, R);
+}
+template <typename LHS, typename RHS>
+inline BinaryOpc_match<LHS, RHS, true, true>
+m_c_ChainedBinOp(unsigned Opc, const LHS &L, const RHS &R) {
+  return BinaryOpc_match<LHS, RHS, true, true>(Opc, L, R);
+}
+
 // Common binary operations
 template <typename LHS, typename RHS>
 inline BinaryOpc_match<LHS, RHS, true> m_Add(const LHS &L, const RHS &R) {
@@ -573,7 +568,7 @@ inline BinaryOpc_match<LHS, RHS, false> m_FRem(const LHS &L, const RHS &R) {
 }
 
 // === Unary operations ===
-template <typename Opnd_P> struct UnaryOpc_match {
+template <typename Opnd_P, bool ExcludeChain = false> struct UnaryOpc_match {
   unsigned Opcode;
   Opnd_P Opnd;
 
@@ -581,13 +576,10 @@ template <typename Opnd_P> struct UnaryOpc_match {
 
   template <typename MatchContext>
   bool match(const MatchContext &Ctx, SDValue N) {
-    if (!N)
-      return false;
-
     if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
-      EffectiveOperands EO(N);
-      if (EO.Size == 1)
-        return Opnd.match(Ctx, N->getOperand(EO.FirstIndex));
+      EffectiveOperands<ExcludeChain> EO(N);
+      assert(EO.Size == 1);
+      return Opnd.match(Ctx, N->getOperand(EO.FirstIndex));
     }
 
     return false;
@@ -598,6 +590,11 @@ template <typename Opnd>
 inline UnaryOpc_match<Opnd> m_UnaryOp(unsigned Opc, const Opnd &Op) {
   return UnaryOpc_match<Opnd>(Opc, Op);
 }
+template <typename Opnd>
+inline UnaryOpc_match<Opnd, true> m_ChainedUnaryOp(unsigned Opc,
+                                                   const Opnd &Op) {
+  return UnaryOpc_match<Opnd, true>(Opc, Op);
+}
 
 template <typename Opnd> inline UnaryOpc_match<Opnd> m_ZExt(const Opnd &Op) {
   return UnaryOpc_match<Opnd>(ISD::ZERO_EXTEND, Op);
diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
index 369918c2b6e611..77755ca1bd071d 100644
--- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
@@ -141,10 +141,11 @@ TEST_F(SelectionDAGPatternMatchTest, matchBinaryOp) {
   EXPECT_TRUE(sd_match(Add, m_Add(m_Value(), m_Value())));
   EXPECT_TRUE(sd_match(
       Mul, m_Mul(m_OneUse(m_Opc(ISD::SUB)), m_NUses<2>(m_Specific(Add)))));
-  EXPECT_TRUE(sd_match(SFAdd, m_BinOp(ISD::STRICT_FADD, m_SpecificVT(Float32VT),
-                                      m_SpecificVT(Float32VT))));
-  EXPECT_FALSE(sd_match(
-      SFAdd, m_BinOp(ISD::STRICT_FADD, m_OtherVT(), m_SpecificVT(Float32VT))));
+  EXPECT_TRUE(
+      sd_match(SFAdd, m_ChainedBinOp(ISD::STRICT_FADD, m_SpecificVT(Float32VT),
+                                     m_SpecificVT(Float32VT))));
+  EXPECT_FALSE(sd_match(SFAdd, m_ChainedBinOp(ISD::STRICT_FADD, m_OtherVT(),
+                                              m_SpecificVT(Float32VT))));
 }
 
 TEST_F(SelectionDAGPatternMatchTest, matchUnaryOp) {
@@ -209,8 +210,8 @@ TEST_F(SelectionDAGPatternMatchTest, patternCombinators) {
 
   using namespace SDPatternMatch;
   EXPECT_TRUE(sd_match(
-      Sub, m_any_of(m_Opc(ISD::ADD), m_Opc(ISD::SUB), m_Opc(ISD::MUL))));
-  EXPECT_TRUE(sd_match(Add, m_all_of(m_Opc(ISD::ADD), m_OneUse())));
+      Sub, m_AnyOf(m_Opc(ISD::ADD), m_Opc(ISD::SUB), m_Opc(ISD::MUL))));
+  EXPECT_TRUE(sd_match(Add, m_AllOf(m_Opc(ISD::ADD), m_OneUse())));
 }
 
 TEST_F(SelectionDAGPatternMatchTest, matchNode) {

>From e47df78582cfd6ad7af6df602c4ac82a65b1b5ca Mon Sep 17 00:00:00 2001
From: Min Hsu <min.hsu at sifive.com>
Date: Wed, 14 Feb 2024 17:00:21 -0800
Subject: [PATCH 5/7] fixup! Address reviewer comments

---
 llvm/include/llvm/CodeGen/SDPatternMatch.h | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index db2cd204b26763..e4bb226202572a 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -113,7 +113,10 @@ struct Value_match {
 /// Match any valid SDValue.
 inline Value_match m_Value() { return Value_match(); }
 
-inline Value_match m_Specific(SDValue N) { return Value_match(N); }
+inline Value_match m_Specific(SDValue N) {
+  assert(N);
+  return Value_match(N);
+}
 
 struct Opcode_match {
   unsigned Opcode;

>From 364f112f4a2efc00486dc6b9dbdb6a8565b4a41f Mon Sep 17 00:00:00 2001
From: Min Hsu <min.hsu at sifive.com>
Date: Wed, 14 Feb 2024 17:15:37 -0800
Subject: [PATCH 6/7] Add m_Deferred

---
 llvm/include/llvm/CodeGen/SDPatternMatch.h    | 19 +++++++++++++++++++
 .../CodeGen/SelectionDAGPatternMatchTest.cpp  |  3 +++
 2 files changed, 22 insertions(+)

diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index e4bb226202572a..993dc3a09ed399 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -118,6 +118,25 @@ inline Value_match m_Specific(SDValue N) {
   return Value_match(N);
 }
 
+struct DeferredValue_match {
+  SDValue &MatchVal;
+
+  explicit DeferredValue_match(SDValue &Match) : MatchVal(Match) {}
+
+  template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
+    return N == MatchVal;
+  }
+};
+
+/// Similar to m_Specific, but the specific value to match is determined by
+/// another sub-pattern in the same sd_match() expression. For instance,
+/// we cannot match `(add V, V)` with `m_Add(m_Value(X), m_Specific(X))` since
+/// `X` is not initialized at the time it got copied into `m_Specific`. Instead,
+/// we should use `m_Add(m_Value(X), m_Deferred(X))`.
+inline DeferredValue_match m_Deferred(SDValue &V) {
+  return DeferredValue_match(V);
+}
+
 struct Opcode_match {
   unsigned Opcode;
 
diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
index 77755ca1bd071d..17fc3ce8af2677 100644
--- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
@@ -144,6 +144,9 @@ TEST_F(SelectionDAGPatternMatchTest, matchBinaryOp) {
   EXPECT_TRUE(
       sd_match(SFAdd, m_ChainedBinOp(ISD::STRICT_FADD, m_SpecificVT(Float32VT),
                                      m_SpecificVT(Float32VT))));
+  SDValue BindVal;
+  EXPECT_TRUE(sd_match(SFAdd, m_ChainedBinOp(ISD::STRICT_FADD, m_Value(BindVal),
+                                             m_Deferred(BindVal))));
   EXPECT_FALSE(sd_match(SFAdd, m_ChainedBinOp(ISD::STRICT_FADD, m_OtherVT(),
                                               m_SpecificVT(Float32VT))));
 }

>From 5c27bc863a7ea7414aaa145d02fdfc32f58c11b9 Mon Sep 17 00:00:00 2001
From: Min Hsu <min.hsu at sifive.com>
Date: Wed, 21 Feb 2024 16:47:04 -0800
Subject: [PATCH 7/7] Address reviewer comments

The major one was simplifying Node_match into Operands_match.
---
 llvm/include/llvm/CodeGen/SDPatternMatch.h | 97 +++++++++-------------
 1 file changed, 41 insertions(+), 56 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index 993dc3a09ed399..c047b045e798c2 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -201,12 +201,12 @@ struct Value_bind {
 
 inline Value_bind m_Value(SDValue &N) { return Value_bind(N); }
 
-template <typename Pattern> struct TLI_pred_match {
+template <typename Pattern, typename PredFuncT> struct TLI_pred_match {
   Pattern P;
-  std::function<bool(const TargetLowering &, SDValue)> PredFunc;
+  PredFuncT PredFunc;
 
-  TLI_pred_match(decltype(PredFunc) &&Pred, const Pattern &P)
-      : P(P), PredFunc(std::move(Pred)) {}
+  TLI_pred_match(const PredFuncT &Pred, const Pattern &P)
+      : P(P), PredFunc(Pred) {}
 
   template <typename MatchContext>
   bool match(const MatchContext &Ctx, SDValue N) {
@@ -215,14 +215,18 @@ template <typename Pattern> struct TLI_pred_match {
   }
 };
 
+// Explicit deduction guide.
+template <typename PredFuncT, typename Pattern>
+TLI_pred_match(const PredFuncT &Pred, const Pattern &P)
+    -> TLI_pred_match<Pattern, PredFuncT>;
+
 /// Match legal SDNodes based on the information provided by TargetLowering.
-template <typename Pattern>
-inline TLI_pred_match<Pattern> m_LegalOp(const Pattern &P) {
-  return TLI_pred_match<Pattern>(
-      [](const TargetLowering &TLI, SDValue N) {
-        return TLI.isOperationLegal(N->getOpcode(), N.getValueType());
-      },
-      P);
+template <typename Pattern> inline auto m_LegalOp(const Pattern &P) {
+  return TLI_pred_match{[](const TargetLowering &TLI, SDValue N) {
+                          return TLI.isOperationLegal(N->getOpcode(),
+                                                      N.getValueType());
+                        },
+                        P};
 }
 
 /// Switch to a different MatchContext for subsequent patterns.
@@ -331,13 +335,11 @@ inline auto m_ScalableVectorVT() {
 }
 
 /// Match legal ValueTypes based on the information provided by TargetLowering.
-template <typename Pattern>
-inline TLI_pred_match<Pattern> m_LegalType(const Pattern &P) {
-  return TLI_pred_match<Pattern>(
-      [](const TargetLowering &TLI, SDValue N) {
-        return TLI.isTypeLegal(N.getValueType());
-      },
-      P);
+template <typename Pattern> inline auto m_LegalType(const Pattern &P) {
+  return TLI_pred_match{[](const TargetLowering &TLI, SDValue N) {
+                          return TLI.isTypeLegal(N.getValueType());
+                        },
+                        P};
 }
 
 // === Patterns combinators ===
@@ -387,47 +389,29 @@ template <typename... Preds> Or<Preds...> m_AnyOf(Preds &&...preds) {
 }
 
 // === Generic node matching ===
-template <typename... OpndPreds> struct Node_match {
-  unsigned Opcode;
-  unsigned OpIdx;
-
-  Node_match(unsigned Opc, unsigned OpIdx) : Opcode(Opc), OpIdx(OpIdx) {}
-
+template <unsigned OpIdx, typename... OpndPreds> struct Operands_match {
   template <typename MatchContext>
   bool match(const MatchContext &Ctx, SDValue N) {
-    if (OpIdx == 0) {
-      // Check opcode
-      if (!sd_context_match(N, Ctx, m_Opc(Opcode)))
-        return false;
-    }
-
     // Returns false if there are more operands than predicates;
     return N->getNumOperands() == OpIdx;
   }
 };
 
-template <typename OpndPred, typename... OpndPreds>
-struct Node_match<OpndPred, OpndPreds...> : Node_match<OpndPreds...> {
-  unsigned Opcode;
-  unsigned OpIdx;
+template <unsigned OpIdx, typename OpndPred, typename... OpndPreds>
+struct Operands_match<OpIdx, OpndPred, OpndPreds...>
+    : Operands_match<OpIdx + 1, OpndPreds...> {
   OpndPred P;
 
-  Node_match(unsigned Opc, unsigned OpIdx, OpndPred &&p, OpndPreds &&...preds)
-      : Node_match<OpndPreds...>(Opc, OpIdx + 1,
-                                 std::forward<OpndPreds>(preds)...),
-        Opcode(Opc), OpIdx(OpIdx), P(std::forward<OpndPred>(p)) {}
+  Operands_match(OpndPred &&p, OpndPreds &&...preds)
+      : Operands_match<OpIdx + 1, OpndPreds...>(
+            std::forward<OpndPreds>(preds)...),
+        P(std::forward<OpndPred>(p)) {}
 
   template <typename MatchContext>
   bool match(const MatchContext &Ctx, SDValue N) {
-    if (OpIdx == 0) {
-      // Check opcode
-      if (!sd_context_match(N, Ctx, m_Opc(Opcode)))
-        return false;
-    }
-
     if (OpIdx < N->getNumOperands())
       return P.match(Ctx, N->getOperand(OpIdx)) &&
-             Node_match<OpndPreds...>::match(Ctx, N);
+             Operands_match<OpIdx + 1, OpndPreds...>::match(Ctx, N);
 
     // This is the case where there are more predicates than operands.
     return false;
@@ -435,8 +419,9 @@ struct Node_match<OpndPred, OpndPreds...> : Node_match<OpndPreds...> {
 };
 
 template <typename... OpndPreds>
-Node_match<OpndPreds...> m_Node(unsigned Opcode, OpndPreds &&...preds) {
-  return Node_match<OpndPreds...>(Opcode, 0, std::forward<OpndPreds>(preds)...);
+auto m_Node(unsigned Opcode, OpndPreds &&...preds) {
+  return m_AllOf(m_Opc(Opcode), Operands_match<0, OpndPreds...>(
+                                    std::forward<OpndPreds>(preds)...));
 }
 
 /// Provide number of operands that are not chain or glue, as well as the first
@@ -448,14 +433,14 @@ template <bool ExcludeChain> struct EffectiveOperands {
   explicit EffectiveOperands(SDValue N) {
     const unsigned TotalNumOps = N->getNumOperands();
     FirstIndex = TotalNumOps;
-    for (unsigned i = 0; i < TotalNumOps; ++i) {
+    for (unsigned I = 0; I < TotalNumOps; ++I) {
       // Count the number of non-chain and non-glue nodes (we ignore chain
       // and glue by default) and retreive the operand index offset.
-      EVT VT = N->getOperand(i).getValueType();
+      EVT VT = N->getOperand(I).getValueType();
       if (VT != MVT::Glue && VT != MVT::Other) {
         ++Size;
         if (FirstIndex == TotalNumOps)
-          FirstIndex = i;
+          FirstIndex = I;
       }
     }
   }
@@ -689,8 +674,8 @@ inline SpecificInt_match m_AllOnes() { return m_SpecificInt(~0U); }
 
 /// Match true boolean value based on the information provided by
 /// TargetLowering.
-inline TLI_pred_match<Value_match> m_True() {
-  return TLI_pred_match<Value_match>(
+inline auto m_True() {
+  return TLI_pred_match{
       [](const TargetLowering &TLI, SDValue N) {
         APInt ConstVal;
         if (sd_match(N, m_ConstInt(ConstVal)))
@@ -705,12 +690,12 @@ inline TLI_pred_match<Value_match> m_True() {
 
         return false;
       },
-      m_Value());
+      m_Value()};
 }
 /// Match false boolean value based on the information provided by
 /// TargetLowering.
-inline TLI_pred_match<Value_match> m_False() {
-  return TLI_pred_match<Value_match>(
+inline auto m_False() {
+  return TLI_pred_match{
       [](const TargetLowering &TLI, SDValue N) {
         APInt ConstVal;
         if (sd_match(N, m_ConstInt(ConstVal)))
@@ -724,7 +709,7 @@ inline TLI_pred_match<Value_match> m_False() {
 
         return false;
       },
-      m_Value());
+      m_Value()};
 }
 } // namespace SDPatternMatch
 } // namespace llvm



More information about the llvm-commits mailing list