[llvm] Added SD matchers and unit test coverage for ISD::VECTOR_SHUFFLE (PR #119592)

Aidan Goldfarb via llvm-commits llvm-commits at lists.llvm.org
Wed Dec 11 09:16:43 PST 2024


https://github.com/AidanGoldfarb created https://github.com/llvm/llvm-project/pull/119592

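This PR addresses #118845. It adds SDPatternMatch support for
ISD::VECTOR_SHUFFLE:
- a two-operand m_Shuffle(V1, V2) that ignores the mask;
- a three-operand m_Shuffle(V1, V2, Mask) backed by the new SDShuffle_match;
- an m_Mask binder that captures the shuffle mask;
- a unit test in SelectionDAGPatternMatchTest.cpp.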

>From 4649aa0cc320d12ba47f90604270cd0df66e1621 Mon Sep 17 00:00:00 2001
From: Aidan <aidan.goldfarb at mail.mcgill.ca>
Date: Wed, 11 Dec 2024 12:13:10 -0500
Subject: [PATCH] Added SD matchers and unit test coverage for
 ISD::VECTOR_SHUFFLE

---
 llvm/include/llvm/CodeGen/SDPatternMatch.h    | 50 +++++++++++++++++++
 .../CodeGen/SelectionDAGPatternMatchTest.cpp  | 19 +++++++
 2 files changed, 69 insertions(+)

diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index 96667952a16efc..95686d310d92b2 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -226,6 +226,30 @@ inline SwitchContext<MatchContext, Pattern> m_Context(const MatchContext &Ctx,
   return SwitchContext<MatchContext, Pattern>{Ctx, std::move(P)};
 }
 
+/// Shuffle-mask binder, intended as the mask operand of m_Shuffle.
+///
+/// Always succeeds and stores the matched shuffle's mask into the referenced
+/// ArrayRef (a non-owning view of the mask data held by the node).
+struct m_Mask {
+  ArrayRef<int> &MaskRef;
+  m_Mask(ArrayRef<int> &MaskRef) : MaskRef(MaskRef) {}
+  bool match(ArrayRef<int> Mask) {
+    MaskRef = Mask;
+    return true;
+  }
+};
+
+/// Shuffle-mask matcher, intended as the mask operand of m_Shuffle.
+///
+/// Succeeds only if the matched shuffle's mask is element-wise equal to the
+/// given mask. Only an ArrayRef view is stored, so the caller's mask data
+/// must outlive this matcher.
+struct m_SpecificMask {
+  ArrayRef<int> MaskRef;
+  m_SpecificMask(ArrayRef<int> MaskRef) : MaskRef(MaskRef) {}
+  bool match(ArrayRef<int> Mask) { return MaskRef == Mask; }
+};
+
 // === Value type ===
 struct ValueType_bind {
   EVT &BindVT;
@@ -540,6 +549,31 @@ struct BinaryOpc_match {
   }
 };
 
+/// Matches an ISD::VECTOR_SHUFFLE node (a ShuffleVectorSDNode).
+///
+/// Op1 and Op2 are matched, with the match context, against the node's first
+/// and second operands. Mask is matched against the node's shuffle mask via
+/// `bool match(ArrayRef<int>)` (e.g. m_Mask) and does not see the context.
+template <typename T0, typename T1, typename T2> struct SDShuffle_match {
+  T0 Op1;
+  T1 Op2;
+  T2 Mask;
+
+  SDShuffle_match(const T0 &Op1, const T1 &Op2, const T2 &Mask)
+      : Op1(Op1), Op2(Op2), Mask(Mask) {}
+
+  template <typename MatchContext>
+  bool match(const MatchContext &Ctx, SDValue N) {
+    auto *Shuf = dyn_cast<ShuffleVectorSDNode>(N);
+    if (!Shuf)
+      return false;
+    // Operands first, mask last: a binding mask pattern is only updated once
+    // both operands have matched.
+    return Op1.match(Ctx, Shuf->getOperand(0)) &&
+           Op2.match(Ctx, Shuf->getOperand(1)) && Mask.match(Shuf->getMask());
+  }
+};
+
 template <typename LHS_P, typename RHS_P, typename Pred_t,
           bool Commutable = false, bool ExcludeChain = false>
 struct MaxMin_match {
@@ -790,6 +817,22 @@ inline BinaryOpc_match<LHS, RHS> m_FRem(const LHS &L, const RHS &R) {
   return BinaryOpc_match<LHS, RHS>(ISD::FREM, L, R);
 }
 
+/// Matches an ISD::VECTOR_SHUFFLE node whose first and second operands match
+/// V1 and V2 respectively. The shuffle mask is not inspected.
+template <typename LHS, typename RHS>
+inline BinaryOpc_match<LHS, RHS> m_Shuffle(const LHS &V1, const RHS &V2) {
+  return BinaryOpc_match<LHS, RHS>(ISD::VECTOR_SHUFFLE, V1, V2);
+}
+
+/// Matches an ISD::VECTOR_SHUFFLE node whose first and second operands match
+/// V1 and V2 respectively and whose shuffle mask matches Mask (for example
+/// m_Mask, which binds the mask).
+template <typename V1_t, typename V2_t, typename Mask_t>
+inline SDShuffle_match<V1_t, V2_t, Mask_t>
+m_Shuffle(const V1_t &V1, const V2_t &V2, const Mask_t &Mask) {
+  return SDShuffle_match<V1_t, V2_t, Mask_t>(V1, V2, Mask);
+}
+
 // === Unary operations ===
 template <typename Opnd_P, bool ExcludeChain = false> struct UnaryOpc_match {
   unsigned Opcode;
diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
index b9bcb2479fdcc4..2987f18e336323 100644
--- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
@@ -119,6 +119,40 @@ TEST_F(SelectionDAGPatternMatchTest, matchValueType) {
   EXPECT_FALSE(sd_match(Op2, m_ScalableVectorVT()));
 }
 
+TEST_F(SelectionDAGPatternMatchTest, matchVecShuffle) {
+  SDLoc DL;
+  auto Int32VT = EVT::getIntegerVT(Context, 32);
+  auto VInt32VT = EVT::getVectorVT(Context, Int32VT, 4);
+  // The mask must take lanes from both inputs: getVectorShuffle replaces an
+  // unreferenced second operand with UNDEF, which would defeat the operand
+  // checks below.
+  const SmallVector<int, 4> MaskData = {2, 4, 1, 7};
+  ArrayRef<int> Mask;
+
+  SDValue V0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, VInt32VT);
+  SDValue V1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, VInt32VT);
+  SDValue VecShuffleWithMask =
+      DAG->getVectorShuffle(VInt32VT, DL, V0, V1, MaskData);
+  SDValue NotAShuffle = DAG->getNode(ISD::ADD, DL, VInt32VT, V0, V1);
+
+  using namespace SDPatternMatch;
+  // Use m_Specific rather than m_Value(V): the latter rebinds V instead of
+  // comparing against it, so it cannot check operand order.
+  EXPECT_TRUE(sd_match(VecShuffleWithMask,
+                       m_Shuffle(m_Specific(V0), m_Specific(V1))));
+  EXPECT_FALSE(sd_match(VecShuffleWithMask,
+                        m_Shuffle(m_Specific(V1), m_Specific(V0))));
+  EXPECT_FALSE(sd_match(NotAShuffle, m_Shuffle(m_Value(), m_Value())));
+
+  // m_Mask binds the shuffle's mask; check that it round-trips.
+  EXPECT_TRUE(sd_match(
+      VecShuffleWithMask,
+      m_Shuffle(m_Specific(V0), m_Specific(V1), m_Mask(Mask))));
+  EXPECT_TRUE(Mask.equals(MaskData));
+  EXPECT_FALSE(sd_match(NotAShuffle,
+                        m_Shuffle(m_Value(), m_Value(), m_Mask(Mask))));
+}
+
 TEST_F(SelectionDAGPatternMatchTest, matchTernaryOp) {
   SDLoc DL;
   auto Int32VT = EVT::getIntegerVT(Context, 32);



More information about the llvm-commits mailing list