[llvm] Add SD matchers and unit test coverage for ISD::VECTOR_SHUFFLE (PR #119592)

Aidan Goldfarb via llvm-commits llvm-commits at lists.llvm.org
Mon Dec 16 17:30:54 PST 2024


https://github.com/AidanGoldfarb updated https://github.com/llvm/llvm-project/pull/119592

>From 4649aa0cc320d12ba47f90604270cd0df66e1621 Mon Sep 17 00:00:00 2001
From: Aidan <aidan.goldfarb at mail.mcgill.ca>
Date: Wed, 11 Dec 2024 12:13:10 -0500
Subject: [PATCH 1/8] Add SD matchers and unit test coverage for
 ISD::VECTOR_SHUFFLE

---
 llvm/include/llvm/CodeGen/SDPatternMatch.h    | 50 +++++++++++++++++++
 .../CodeGen/SelectionDAGPatternMatchTest.cpp  | 19 +++++++
 2 files changed, 69 insertions(+)

diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index 96667952a16efc..95686d310d92b2 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -226,6 +226,15 @@ inline SwitchContext<MatchContext, Pattern> m_Context(const MatchContext &Ctx,
   return SwitchContext<MatchContext, Pattern>{Ctx, std::move(P)};
 }
 
+struct m_Mask {
+  ArrayRef<int> &MaskRef;
+  m_Mask(ArrayRef<int> &MaskRef) : MaskRef(MaskRef) {}
+  bool match(ArrayRef<int> Mask) {
+    MaskRef = Mask;
+    return true;
+  }
+};
+
 // === Value type ===
 struct ValueType_bind {
   EVT &BindVT;
@@ -540,6 +549,24 @@ struct BinaryOpc_match {
   }
 };
 
+/// Matches shuffle.
+template <typename T0, typename T1, typename T2> struct SDShuffle_match {
+  T0 Op1;
+  T1 Op2;
+  T2 Mask;
+
+  SDShuffle_match(const T0 &Op1, const T1 &Op2, const T2 &Mask)
+      : Op1(Op1), Op2(Op2), Mask(Mask) {}
+
+  template <typename MatchContext>
+  bool match(const MatchContext &Ctx, SDValue N) {
+    if (auto *I = dyn_cast<ShuffleVectorSDNode>(N)) {
+      return Op1.match(Ctx, I->getOperand(0)) &&
+             Op2.match(Ctx, I->getOperand(1)) && Mask.match(I->getMask());
+    }
+    return false;
+  }
+};
 template <typename LHS_P, typename RHS_P, typename Pred_t,
           bool Commutable = false, bool ExcludeChain = false>
 struct MaxMin_match {
@@ -790,6 +817,29 @@ inline BinaryOpc_match<LHS, RHS> m_FRem(const LHS &L, const RHS &R) {
   return BinaryOpc_match<LHS, RHS>(ISD::FREM, L, R);
 }
 
+template <typename LHS, typename RHS>
+inline BinaryOpc_match<LHS, RHS> m_Shuffle(const LHS &v1, const RHS &v2) {
+  return BinaryOpc_match<LHS, RHS>(ISD::VECTOR_SHUFFLE, v1, v2);
+}
+
+// template <typename LHS, typename RHS, typename Mask_t>
+// inline TernaryOpc_match<LHS, RHS, Mask_t>
+// m_Shuffle(const LHS &v1, const RHS &v2, const Mask_t &mask) {
+//   return TernaryOpc_match<LHS, RHS, Mask_t>(ISD::VECTOR_SHUFFLE, v1, v2,
+//   mask);
+// }
+// template <typename LHS, typename RHS, typename Mask_t>
+// inline bool
+// m_Shuffle(const LHS &v1, const RHS &v2,const Mask_t &mask) {
+//   return BinaryOpc_match<LHS, RHS>(ISD::VECTOR_SHUFFLE, v1, v2) && true;
+// }
+
+template <typename V1_t, typename V2_t, typename Mask_t>
+inline SDShuffle_match<V1_t, V2_t, Mask_t>
+m_Shuffle(const V1_t &v1, const V2_t &v2, const Mask_t &mask) {
+  return SDShuffle_match<V1_t, V2_t, Mask_t>(v1, v2, mask);
+}
+
 // === Unary operations ===
 template <typename Opnd_P, bool ExcludeChain = false> struct UnaryOpc_match {
   unsigned Opcode;
diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
index b9bcb2479fdcc4..2987f18e336323 100644
--- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
@@ -119,6 +119,25 @@ TEST_F(SelectionDAGPatternMatchTest, matchValueType) {
   EXPECT_FALSE(sd_match(Op2, m_ScalableVectorVT()));
 }
 
+TEST_F(SelectionDAGPatternMatchTest, matchVecShuffle) {
+  SDLoc DL;
+  auto Int32VT = EVT::getIntegerVT(Context, 32);
+  auto VInt32VT = EVT::getVectorVT(Context, Int32VT, 4);
+  SmallVector<int, 4> MaskData = {2, 0, 3, 1};
+  ArrayRef<int> Mask(MaskData);
+
+  SDValue V0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, VInt32VT);
+  SDValue V1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, VInt32VT);
+  SDValue VecShuffleWithMask_0 =
+      DAG->getVectorShuffle(VInt32VT, DL, V0, V1, MaskData);
+
+  using namespace SDPatternMatch;
+  EXPECT_TRUE(
+      sd_match(VecShuffleWithMask_0, m_Shuffle(m_Value(V0), m_Value(V1))));
+  EXPECT_TRUE(sd_match(VecShuffleWithMask_0,
+                       m_Shuffle(m_Value(V0), m_Value(V1), m_Mask(Mask))));
+}
+
 TEST_F(SelectionDAGPatternMatchTest, matchTernaryOp) {
   SDLoc DL;
   auto Int32VT = EVT::getIntegerVT(Context, 32);

>From 1b3dd6f5c5b59a2e1652804693057e21eb24d18f Mon Sep 17 00:00:00 2001
From: Aidan <aidan.goldfarb at mail.mcgill.ca>
Date: Wed, 11 Dec 2024 12:21:51 -0500
Subject: [PATCH 2/8] Remove commented-out m_Shuffle drafts

---
 llvm/include/llvm/CodeGen/SDPatternMatch.h | 12 ------------
 1 file changed, 12 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index 95686d310d92b2..752296634aaaaa 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -822,18 +822,6 @@ inline BinaryOpc_match<LHS, RHS> m_Shuffle(const LHS &v1, const RHS &v2) {
   return BinaryOpc_match<LHS, RHS>(ISD::VECTOR_SHUFFLE, v1, v2);
 }
 
-// template <typename LHS, typename RHS, typename Mask_t>
-// inline TernaryOpc_match<LHS, RHS, Mask_t>
-// m_Shuffle(const LHS &v1, const RHS &v2, const Mask_t &mask) {
-//   return TernaryOpc_match<LHS, RHS, Mask_t>(ISD::VECTOR_SHUFFLE, v1, v2,
-//   mask);
-// }
-// template <typename LHS, typename RHS, typename Mask_t>
-// inline bool
-// m_Shuffle(const LHS &v1, const RHS &v2,const Mask_t &mask) {
-//   return BinaryOpc_match<LHS, RHS>(ISD::VECTOR_SHUFFLE, v1, v2) && true;
-// }
-
 template <typename V1_t, typename V2_t, typename Mask_t>
 inline SDShuffle_match<V1_t, V2_t, Mask_t>
 m_Shuffle(const V1_t &v1, const V2_t &v2, const Mask_t &mask) {

>From 65ab88c3f44090afb85f7734040cdd820e57b8a2 Mon Sep 17 00:00:00 2001
From: Aidan <aidan.goldfarb at mail.mcgill.ca>
Date: Thu, 12 Dec 2024 14:12:07 -0500
Subject: [PATCH 3/8] Change SDShuffle_match to store the mask directly. Replace
 the call to Mask.match() with a std::equal() comparison against
 I->getMask(). Change the matchVecShuffle test to not initialize Mask, and add
 a check comparing the contents of MaskData to Mask.

---
 llvm/include/llvm/CodeGen/SDPatternMatch.h       | 16 ++++++++--------
 .../CodeGen/SelectionDAGPatternMatchTest.cpp     |  5 +++--
 2 files changed, 11 insertions(+), 10 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index 752296634aaaaa..df70eed88a36f8 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -550,19 +550,19 @@ struct BinaryOpc_match {
 };
 
 /// Matches shuffle.
-template <typename T0, typename T1, typename T2> struct SDShuffle_match {
+template <typename T0, typename T1> struct SDShuffle_match {
   T0 Op1;
   T1 Op2;
-  T2 Mask;
+  ArrayRef<int> Mask;
 
-  SDShuffle_match(const T0 &Op1, const T1 &Op2, const T2 &Mask)
+  SDShuffle_match(const T0 &Op1, const T1 &Op2, const ArrayRef<int> &Mask)
       : Op1(Op1), Op2(Op2), Mask(Mask) {}
 
   template <typename MatchContext>
   bool match(const MatchContext &Ctx, SDValue N) {
     if (auto *I = dyn_cast<ShuffleVectorSDNode>(N)) {
       return Op1.match(Ctx, I->getOperand(0)) &&
-             Op2.match(Ctx, I->getOperand(1)) && Mask.match(I->getMask());
+             Op2.match(Ctx, I->getOperand(1)) && std::equal(Mask.begin(), Mask.end(), I->getMask().begin());
     }
     return false;
   }
@@ -822,10 +822,10 @@ inline BinaryOpc_match<LHS, RHS> m_Shuffle(const LHS &v1, const RHS &v2) {
   return BinaryOpc_match<LHS, RHS>(ISD::VECTOR_SHUFFLE, v1, v2);
 }
 
-template <typename V1_t, typename V2_t, typename Mask_t>
-inline SDShuffle_match<V1_t, V2_t, Mask_t>
-m_Shuffle(const V1_t &v1, const V2_t &v2, const Mask_t &mask) {
-  return SDShuffle_match<V1_t, V2_t, Mask_t>(v1, v2, mask);
+template <typename V1_t, typename V2_t>
+inline SDShuffle_match<V1_t, V2_t>
+m_Shuffle(const V1_t &v1, const V2_t &v2, const ArrayRef<int> mask) {
+  return SDShuffle_match<V1_t, V2_t>(v1, v2, mask);
 }
 
 // === Unary operations ===
diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
index 2987f18e336323..c72f454e02382a 100644
--- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
@@ -124,7 +124,7 @@ TEST_F(SelectionDAGPatternMatchTest, matchVecShuffle) {
   auto Int32VT = EVT::getIntegerVT(Context, 32);
   auto VInt32VT = EVT::getVectorVT(Context, Int32VT, 4);
   SmallVector<int, 4> MaskData = {2, 0, 3, 1};
-  ArrayRef<int> Mask(MaskData);
+  ArrayRef<int> Mask;
 
   SDValue V0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, VInt32VT);
   SDValue V1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, VInt32VT);
@@ -135,7 +135,8 @@ TEST_F(SelectionDAGPatternMatchTest, matchVecShuffle) {
   EXPECT_TRUE(
       sd_match(VecShuffleWithMask_0, m_Shuffle(m_Value(V0), m_Value(V1))));
   EXPECT_TRUE(sd_match(VecShuffleWithMask_0,
-                       m_Shuffle(m_Value(V0), m_Value(V1), m_Mask(Mask))));
+                       m_Shuffle(m_Value(V0), m_Value(V1), Mask)));
+  EXPECT_TRUE(std::equal(Mask.begin(), Mask.end(), MaskData.begin()));
 }
 
 TEST_F(SelectionDAGPatternMatchTest, matchTernaryOp) {

>From c577f7b1e0a634beaa1b18594f8afeb35d6196a8 Mon Sep 17 00:00:00 2001
From: Aidan <aidan.goldfarb at mail.mcgill.ca>
Date: Thu, 12 Dec 2024 14:19:17 -0500
Subject: [PATCH 4/8] formatting fix

---
 llvm/include/llvm/CodeGen/SDPatternMatch.h | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index df70eed88a36f8..6b4562e53c3f9a 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -562,7 +562,8 @@ template <typename T0, typename T1> struct SDShuffle_match {
   bool match(const MatchContext &Ctx, SDValue N) {
     if (auto *I = dyn_cast<ShuffleVectorSDNode>(N)) {
       return Op1.match(Ctx, I->getOperand(0)) &&
-             Op2.match(Ctx, I->getOperand(1)) && std::equal(Mask.begin(), Mask.end(), I->getMask().begin());
+             Op2.match(Ctx, I->getOperand(1)) &&
+             std::equal(Mask.begin(), Mask.end(), I->getMask().begin());
     }
     return false;
   }
@@ -823,8 +824,8 @@ inline BinaryOpc_match<LHS, RHS> m_Shuffle(const LHS &v1, const RHS &v2) {
 }
 
 template <typename V1_t, typename V2_t>
-inline SDShuffle_match<V1_t, V2_t>
-m_Shuffle(const V1_t &v1, const V2_t &v2, const ArrayRef<int> mask) {
+inline SDShuffle_match<V1_t, V2_t> m_Shuffle(const V1_t &v1, const V2_t &v2,
+                                             const ArrayRef<int> mask) {
   return SDShuffle_match<V1_t, V2_t>(v1, v2, mask);
 }
 

>From b365d5bbf9a89393e0bd3cca147d873ef19996c2 Mon Sep 17 00:00:00 2001
From: Aidan <aidan.goldfarb at mail.mcgill.ca>
Date: Sat, 14 Dec 2024 13:24:13 -0500
Subject: [PATCH 5/8] Apply proposed changes: split the functionality of
 m_Shuffle. The first variant captures the mask into an ArrayRef<int>&; the
 second (m_ShuffleSpecificMask) matches specific mask contents given as an
 ArrayRef. Remove m_Mask and update the tests.

---
 llvm/include/llvm/CodeGen/SDPatternMatch.h    | 55 ++++++++++++-------
 .../CodeGen/SelectionDAGPatternMatchTest.cpp  | 12 ++--
 2 files changed, 40 insertions(+), 27 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index 6b4562e53c3f9a..230dd087b35f31 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -226,15 +226,6 @@ inline SwitchContext<MatchContext, Pattern> m_Context(const MatchContext &Ctx,
   return SwitchContext<MatchContext, Pattern>{Ctx, std::move(P)};
 }
 
-struct m_Mask {
-  ArrayRef<int> &MaskRef;
-  m_Mask(ArrayRef<int> &MaskRef) : MaskRef(MaskRef) {}
-  bool match(ArrayRef<int> Mask) {
-    MaskRef = Mask;
-    return true;
-  }
-};
-
 // === Value type ===
 struct ValueType_bind {
   EVT &BindVT;
@@ -549,21 +540,41 @@ struct BinaryOpc_match {
   }
 };
 
-/// Matches shuffle.
+/// Matches a shuffle while capturing its mask
 template <typename T0, typename T1> struct SDShuffle_match {
   T0 Op1;
   T1 Op2;
-  ArrayRef<int> Mask;
 
-  SDShuffle_match(const T0 &Op1, const T1 &Op2, const ArrayRef<int> &Mask)
-      : Op1(Op1), Op2(Op2), Mask(Mask) {}
+  const ArrayRef<int> *MaskRef;
+
+  // capturing mask
+  SDShuffle_match(const T0 &Op1, const T1 &Op2, const ArrayRef<int> &MaskRef)
+      : Op1(Op1), Op2(Op2), MaskRef(&MaskRef) {}
+
+  template <typename MatchContext>
+  bool match(const MatchContext &Ctx, SDValue N) {
+    if (auto *I = dyn_cast<ShuffleVectorSDNode>(N)) {
+      return Op1.match(Ctx, I->getOperand(0)) &&
+             Op2.match(Ctx, I->getOperand(1));
+    }
+    return false;
+  }
+};
+
+/// Matches a shuffle against a specific mask
+template <typename T0, typename T1> struct SDShuffle_maskMatch {
+  T0 Op1;
+  T1 Op2;
+  ArrayRef<int> SpecificMask;
+
+  SDShuffle_maskMatch(const T0 &Op1, const T1 &Op2, const ArrayRef<int> Mask)
+      : Op1(Op1), Op2(Op2), SpecificMask(Mask) {}
 
   template <typename MatchContext>
   bool match(const MatchContext &Ctx, SDValue N) {
     if (auto *I = dyn_cast<ShuffleVectorSDNode>(N)) {
       return Op1.match(Ctx, I->getOperand(0)) &&
-             Op2.match(Ctx, I->getOperand(1)) &&
-             std::equal(Mask.begin(), Mask.end(), I->getMask().begin());
+             Op2.match(Ctx, I->getOperand(1)) && I->getMask() == SpecificMask;
     }
     return false;
   }
@@ -818,15 +829,17 @@ inline BinaryOpc_match<LHS, RHS> m_FRem(const LHS &L, const RHS &R) {
   return BinaryOpc_match<LHS, RHS>(ISD::FREM, L, R);
 }
 
-template <typename LHS, typename RHS>
-inline BinaryOpc_match<LHS, RHS> m_Shuffle(const LHS &v1, const RHS &v2) {
-  return BinaryOpc_match<LHS, RHS>(ISD::VECTOR_SHUFFLE, v1, v2);
+template <typename V1_t, typename V2_t>
+inline SDShuffle_match<V1_t, V2_t> m_Shuffle(const V1_t &v1, const V2_t &v2,
+                                             const ArrayRef<int> &maskRef) {
+  return SDShuffle_match<V1_t, V2_t>(v1, v2, maskRef);
 }
 
 template <typename V1_t, typename V2_t>
-inline SDShuffle_match<V1_t, V2_t> m_Shuffle(const V1_t &v1, const V2_t &v2,
-                                             const ArrayRef<int> mask) {
-  return SDShuffle_match<V1_t, V2_t>(v1, v2, mask);
+inline SDShuffle_maskMatch<V1_t, V2_t>
+m_ShuffleSpecificMask(const V1_t &v1, const V2_t &v2,
+                      const ArrayRef<int> mask) {
+  return SDShuffle_maskMatch<V1_t, V2_t>(v1, v2, mask);
 }
 
 // === Unary operations ===
diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
index c72f454e02382a..e8bfdc89bbd41d 100644
--- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
@@ -124,19 +124,19 @@ TEST_F(SelectionDAGPatternMatchTest, matchVecShuffle) {
   auto Int32VT = EVT::getIntegerVT(Context, 32);
   auto VInt32VT = EVT::getVectorVT(Context, Int32VT, 4);
   SmallVector<int, 4> MaskData = {2, 0, 3, 1};
-  ArrayRef<int> Mask;
+  ArrayRef<int> CapturedMask;
 
   SDValue V0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, VInt32VT);
   SDValue V1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, VInt32VT);
-  SDValue VecShuffleWithMask_0 =
+  SDValue VecShuffleWithMask =
       DAG->getVectorShuffle(VInt32VT, DL, V0, V1, MaskData);
 
   using namespace SDPatternMatch;
+  EXPECT_TRUE(sd_match(VecShuffleWithMask,
+                       m_Shuffle(m_Value(V0), m_Value(V1), CapturedMask)));
   EXPECT_TRUE(
-      sd_match(VecShuffleWithMask_0, m_Shuffle(m_Value(V0), m_Value(V1))));
-  EXPECT_TRUE(sd_match(VecShuffleWithMask_0,
-                       m_Shuffle(m_Value(V0), m_Value(V1), Mask)));
-  EXPECT_TRUE(std::equal(Mask.begin(), Mask.end(), MaskData.begin()));
+      sd_match(VecShuffleWithMask,
+               m_ShuffleSpecificMask(m_Value(V0), m_Value(V1), MaskData)));
 }
 
 TEST_F(SelectionDAGPatternMatchTest, matchTernaryOp) {

>From c77ac06ce5f361841bec5a2857a3de285e074e23 Mon Sep 17 00:00:00 2001
From: Aidan <aidan.goldfarb at mail.mcgill.ca>
Date: Sun, 15 Dec 2024 11:47:01 -0500
Subject: [PATCH 6/8] Update element-wise ArrayRef comparisons to use
 std::equal() instead of the overloaded ==. Re-add the comparison for the
 capture test. Remove const from the ArrayRef parameters. Correctly capture
 the ArrayRef in the m_Shuffle variant. Change m_Value(V) to m_Value(), as the
 vector operands are not needed.

---
 llvm/include/llvm/CodeGen/SDPatternMatch.h    | 20 ++++++++++---------
 .../CodeGen/SelectionDAGPatternMatchTest.cpp  |  9 +++++----
 2 files changed, 16 insertions(+), 13 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index 230dd087b35f31..74b4e67609547b 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -545,15 +545,16 @@ template <typename T0, typename T1> struct SDShuffle_match {
   T0 Op1;
   T1 Op2;
 
-  const ArrayRef<int> *MaskRef;
+  ArrayRef<int> &CapturedMask;
 
   // capturing mask
-  SDShuffle_match(const T0 &Op1, const T1 &Op2, const ArrayRef<int> &MaskRef)
-      : Op1(Op1), Op2(Op2), MaskRef(&MaskRef) {}
+  SDShuffle_match(const T0 &Op1, const T1 &Op2, ArrayRef<int> &MaskRef)
+      : Op1(Op1), Op2(Op2), CapturedMask(MaskRef) {}
 
   template <typename MatchContext>
   bool match(const MatchContext &Ctx, SDValue N) {
     if (auto *I = dyn_cast<ShuffleVectorSDNode>(N)) {
+      CapturedMask = I->getMask();
       return Op1.match(Ctx, I->getOperand(0)) &&
              Op2.match(Ctx, I->getOperand(1));
     }
@@ -567,14 +568,16 @@ template <typename T0, typename T1> struct SDShuffle_maskMatch {
   T1 Op2;
   ArrayRef<int> SpecificMask;
 
-  SDShuffle_maskMatch(const T0 &Op1, const T1 &Op2, const ArrayRef<int> Mask)
+  SDShuffle_maskMatch(const T0 &Op1, const T1 &Op2, ArrayRef<int> Mask)
       : Op1(Op1), Op2(Op2), SpecificMask(Mask) {}
 
   template <typename MatchContext>
   bool match(const MatchContext &Ctx, SDValue N) {
     if (auto *I = dyn_cast<ShuffleVectorSDNode>(N)) {
       return Op1.match(Ctx, I->getOperand(0)) &&
-             Op2.match(Ctx, I->getOperand(1)) && I->getMask() == SpecificMask;
+             Op2.match(Ctx, I->getOperand(1)) &&
+             std::equal(SpecificMask.begin(), SpecificMask.end(),
+                        I->getMask().begin(), I->getMask().end());
     }
     return false;
   }
@@ -831,14 +834,13 @@ inline BinaryOpc_match<LHS, RHS> m_FRem(const LHS &L, const RHS &R) {
 
 template <typename V1_t, typename V2_t>
 inline SDShuffle_match<V1_t, V2_t> m_Shuffle(const V1_t &v1, const V2_t &v2,
-                                             const ArrayRef<int> &maskRef) {
-  return SDShuffle_match<V1_t, V2_t>(v1, v2, maskRef);
+                                             ArrayRef<int> &mask) {
+  return SDShuffle_match<V1_t, V2_t>(v1, v2, mask);
 }
 
 template <typename V1_t, typename V2_t>
 inline SDShuffle_maskMatch<V1_t, V2_t>
-m_ShuffleSpecificMask(const V1_t &v1, const V2_t &v2,
-                      const ArrayRef<int> mask) {
+m_ShuffleSpecificMask(const V1_t &v1, const V2_t &v2, ArrayRef<int> mask) {
   return SDShuffle_maskMatch<V1_t, V2_t>(v1, v2, mask);
 }
 
diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
index e8bfdc89bbd41d..e071b0599b3c34 100644
--- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
@@ -133,10 +133,11 @@ TEST_F(SelectionDAGPatternMatchTest, matchVecShuffle) {
 
   using namespace SDPatternMatch;
   EXPECT_TRUE(sd_match(VecShuffleWithMask,
-                       m_Shuffle(m_Value(V0), m_Value(V1), CapturedMask)));
-  EXPECT_TRUE(
-      sd_match(VecShuffleWithMask,
-               m_ShuffleSpecificMask(m_Value(V0), m_Value(V1), MaskData)));
+                       m_Shuffle(m_Value(), m_Value(), CapturedMask)));
+  EXPECT_TRUE(sd_match(VecShuffleWithMask,
+                       m_ShuffleSpecificMask(m_Value(), m_Value(), MaskData)));
+  EXPECT_TRUE(std::equal(MaskData.begin(), MaskData.end(), CapturedMask.begin(),
+                         CapturedMask.end()));
 }
 
 TEST_F(SelectionDAGPatternMatchTest, matchTernaryOp) {

>From d818fd534bd45a4c1a1dae316785641343422745 Mon Sep 17 00:00:00 2001
From: Aidan <aidan.goldfarb at mail.mcgill.ca>
Date: Sun, 15 Dec 2024 12:22:42 -0500
Subject: [PATCH 7/8] Add EXPECT_FALSE tests for mask capture and
 matching

---
 llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
index e071b0599b3c34..134a1d8216b75a 100644
--- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
@@ -124,6 +124,7 @@ TEST_F(SelectionDAGPatternMatchTest, matchVecShuffle) {
   auto Int32VT = EVT::getIntegerVT(Context, 32);
   auto VInt32VT = EVT::getVectorVT(Context, Int32VT, 4);
   SmallVector<int, 4> MaskData = {2, 0, 3, 1};
+  SmallVector<int, 4> otherMaskData = {1, 2, 3, 4};
   ArrayRef<int> CapturedMask;
 
   SDValue V0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, VInt32VT);
@@ -138,6 +139,11 @@ TEST_F(SelectionDAGPatternMatchTest, matchVecShuffle) {
                        m_ShuffleSpecificMask(m_Value(), m_Value(), MaskData)));
   EXPECT_TRUE(std::equal(MaskData.begin(), MaskData.end(), CapturedMask.begin(),
                          CapturedMask.end()));
+  EXPECT_FALSE(
+      sd_match(VecShuffleWithMask,
+               m_ShuffleSpecificMask(m_Value(), m_Value(), otherMaskData)));
+  EXPECT_FALSE(std::equal(otherMaskData.begin(), otherMaskData.end(),
+                          CapturedMask.begin(), CapturedMask.end()));
 }
 
 TEST_F(SelectionDAGPatternMatchTest, matchTernaryOp) {

>From e9987c63308f82bd657c96406a48af95fb283513 Mon Sep 17 00:00:00 2001
From: Aidan <aidan.goldfarb at mail.mcgill.ca>
Date: Mon, 16 Dec 2024 20:30:39 -0500
Subject: [PATCH 8/8] Only capture the mask if the operands match; reorder tests

---
 llvm/include/llvm/CodeGen/SDPatternMatch.h              | 8 +++++---
 llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp | 5 +++--
 2 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index 74b4e67609547b..cb67e74cfd01d0 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -554,9 +554,11 @@ template <typename T0, typename T1> struct SDShuffle_match {
   template <typename MatchContext>
   bool match(const MatchContext &Ctx, SDValue N) {
     if (auto *I = dyn_cast<ShuffleVectorSDNode>(N)) {
-      CapturedMask = I->getMask();
-      return Op1.match(Ctx, I->getOperand(0)) &&
-             Op2.match(Ctx, I->getOperand(1));
+      if (Op1.match(Ctx, I->getOperand(0)) &&
+          Op2.match(Ctx, I->getOperand(1))) {
+        CapturedMask = I->getMask();
+        return true;
+      }
     }
     return false;
   }
diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
index 134a1d8216b75a..698a795d2736fb 100644
--- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
@@ -137,11 +137,12 @@ TEST_F(SelectionDAGPatternMatchTest, matchVecShuffle) {
                        m_Shuffle(m_Value(), m_Value(), CapturedMask)));
   EXPECT_TRUE(sd_match(VecShuffleWithMask,
                        m_ShuffleSpecificMask(m_Value(), m_Value(), MaskData)));
-  EXPECT_TRUE(std::equal(MaskData.begin(), MaskData.end(), CapturedMask.begin(),
-                         CapturedMask.end()));
   EXPECT_FALSE(
       sd_match(VecShuffleWithMask,
                m_ShuffleSpecificMask(m_Value(), m_Value(), otherMaskData)));
+
+  EXPECT_TRUE(std::equal(MaskData.begin(), MaskData.end(), CapturedMask.begin(),
+                         CapturedMask.end()));
   EXPECT_FALSE(std::equal(otherMaskData.begin(), otherMaskData.end(),
                           CapturedMask.begin(), CapturedMask.end()));
 }



More information about the llvm-commits mailing list