[llvm] Add SD matchers and unit test coverage for ISD::VECTOR_SHUFFLE (PR #119592)
Aidan Goldfarb via llvm-commits
llvm-commits at lists.llvm.org
Thu Dec 12 11:12:30 PST 2024
https://github.com/AidanGoldfarb updated https://github.com/llvm/llvm-project/pull/119592
>From 4649aa0cc320d12ba47f90604270cd0df66e1621 Mon Sep 17 00:00:00 2001
From: Aidan <aidan.goldfarb at mail.mcgill.ca>
Date: Wed, 11 Dec 2024 12:13:10 -0500
Subject: [PATCH 1/3] Add SD matchers and unit test coverage for
ISD::VECTOR_SHUFFLE
---
llvm/include/llvm/CodeGen/SDPatternMatch.h | 50 +++++++++++++++++++
.../CodeGen/SelectionDAGPatternMatchTest.cpp | 19 +++++++
2 files changed, 69 insertions(+)
diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index 96667952a16efc..95686d310d92b2 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -226,6 +226,15 @@ inline SwitchContext<MatchContext, Pattern> m_Context(const MatchContext &Ctx,
return SwitchContext<MatchContext, Pattern>{Ctx, std::move(P)};
}
+struct m_Mask { // Binder used with SDShuffle_match: captures the node's shuffle mask.
+ ArrayRef<int> &MaskRef; // Caller-owned ArrayRef that match() overwrites.
+ m_Mask(ArrayRef<int> &MaskRef) : MaskRef(MaskRef) {}
+ bool match(ArrayRef<int> Mask) { // Stores Mask into MaskRef; never fails.
+ MaskRef = Mask;
+ return true;
+ }
+}; // NOTE(review): after PATCH 3/3 m_Shuffle takes an ArrayRef<int> mask, so this binder can no longer be passed to it; drop it or keep a binding overload.
+
// === Value type ===
struct ValueType_bind {
EVT &BindVT;
@@ -540,6 +549,24 @@ struct BinaryOpc_match {
}
};
+/// Matches shuffle.
+template <typename T0, typename T1, typename T2> struct SDShuffle_match {
+ T0 Op1;
+ T1 Op2;
+ T2 Mask;
+
+ SDShuffle_match(const T0 &Op1, const T1 &Op2, const T2 &Mask)
+ : Op1(Op1), Op2(Op2), Mask(Mask) {}
+
+ template <typename MatchContext>
+ bool match(const MatchContext &Ctx, SDValue N) {
+ if (auto *I = dyn_cast<ShuffleVectorSDNode>(N)) {
+ return Op1.match(Ctx, I->getOperand(0)) &&
+ Op2.match(Ctx, I->getOperand(1)) && Mask.match(I->getMask());
+ }
+ return false;
+ }
+}; // T2 is any matcher exposing bool match(ArrayRef<int>), e.g. m_Mask. NOTE(review): PATCH 3/3 replaces T2 with a fixed ArrayRef<int>.
template <typename LHS_P, typename RHS_P, typename Pred_t,
bool Commutable = false, bool ExcludeChain = false>
struct MaxMin_match {
@@ -790,6 +817,29 @@ inline BinaryOpc_match<LHS, RHS> m_FRem(const LHS &L, const RHS &R) {
return BinaryOpc_match<LHS, RHS>(ISD::FREM, L, R);
}
+template <typename LHS, typename RHS> // Operands-only form: matches VECTOR_SHUFFLE of (v1, v2); no mask is checked.
+inline BinaryOpc_match<LHS, RHS> m_Shuffle(const LHS &v1, const RHS &v2) {
+ return BinaryOpc_match<LHS, RHS>(ISD::VECTOR_SHUFFLE, v1, v2);
+}
+
+// template <typename LHS, typename RHS, typename Mask_t>
+// inline TernaryOpc_match<LHS, RHS, Mask_t>
+// m_Shuffle(const LHS &v1, const RHS &v2, const Mask_t &mask) {
+// return TernaryOpc_match<LHS, RHS, Mask_t>(ISD::VECTOR_SHUFFLE, v1, v2,
+// mask);
+// }
+// template <typename LHS, typename RHS, typename Mask_t>
+// inline bool
+// m_Shuffle(const LHS &v1, const RHS &v2,const Mask_t &mask) {
+// return BinaryOpc_match<LHS, RHS>(ISD::VECTOR_SHUFFLE, v1, v2) && true;
+// }
+
+template <typename V1_t, typename V2_t, typename Mask_t>
+inline SDShuffle_match<V1_t, V2_t, Mask_t>
+m_Shuffle(const V1_t &v1, const V2_t &v2, const Mask_t &mask) {
+ return SDShuffle_match<V1_t, V2_t, Mask_t>(v1, v2, mask);
+}
+
// === Unary operations ===
template <typename Opnd_P, bool ExcludeChain = false> struct UnaryOpc_match {
unsigned Opcode;
diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
index b9bcb2479fdcc4..2987f18e336323 100644
--- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
@@ -119,6 +119,25 @@ TEST_F(SelectionDAGPatternMatchTest, matchValueType) {
EXPECT_FALSE(sd_match(Op2, m_ScalableVectorVT()));
}
+TEST_F(SelectionDAGPatternMatchTest, matchVecShuffle) {
+ SDLoc DL;
+ auto Int32VT = EVT::getIntegerVT(Context, 32);
+ auto VInt32VT = EVT::getVectorVT(Context, Int32VT, 4);
+ SmallVector<int, 4> MaskData = {2, 0, 3, 1};
+ ArrayRef<int> Mask(MaskData);
+
+ SDValue V0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, VInt32VT);
+ SDValue V1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, VInt32VT);
+ SDValue VecShuffleWithMask_0 =
+ DAG->getVectorShuffle(VInt32VT, DL, V0, V1, MaskData); // Shuffle of V0 and V1 with mask {2, 0, 3, 1}.
+
+ using namespace SDPatternMatch; // NOTE(review): Mask already equals MaskData, so the m_Mask check below cannot detect a missing bind; m_Value(V0)/m_Value(V1) presumably rebind rather than compare (m_Specific would compare) - confirm.
+ EXPECT_TRUE(
+ sd_match(VecShuffleWithMask_0, m_Shuffle(m_Value(V0), m_Value(V1))));
+ EXPECT_TRUE(sd_match(VecShuffleWithMask_0,
+ m_Shuffle(m_Value(V0), m_Value(V1), m_Mask(Mask))));
+}
+
TEST_F(SelectionDAGPatternMatchTest, matchTernaryOp) {
SDLoc DL;
auto Int32VT = EVT::getIntegerVT(Context, 32);
>From 1b3dd6f5c5b59a2e1652804693057e21eb24d18f Mon Sep 17 00:00:00 2001
From: Aidan <aidan.goldfarb at mail.mcgill.ca>
Date: Wed, 11 Dec 2024 12:21:51 -0500
Subject: [PATCH 2/3] Remove commented-out m_Shuffle overloads
---
llvm/include/llvm/CodeGen/SDPatternMatch.h | 12 ------------
1 file changed, 12 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index 95686d310d92b2..752296634aaaaa 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -822,18 +822,6 @@ inline BinaryOpc_match<LHS, RHS> m_Shuffle(const LHS &v1, const RHS &v2) {
return BinaryOpc_match<LHS, RHS>(ISD::VECTOR_SHUFFLE, v1, v2);
}
-// template <typename LHS, typename RHS, typename Mask_t>
-// inline TernaryOpc_match<LHS, RHS, Mask_t>
-// m_Shuffle(const LHS &v1, const RHS &v2, const Mask_t &mask) {
-// return TernaryOpc_match<LHS, RHS, Mask_t>(ISD::VECTOR_SHUFFLE, v1, v2,
-// mask);
-// }
-// template <typename LHS, typename RHS, typename Mask_t>
-// inline bool
-// m_Shuffle(const LHS &v1, const RHS &v2,const Mask_t &mask) {
-// return BinaryOpc_match<LHS, RHS>(ISD::VECTOR_SHUFFLE, v1, v2) && true;
-// }
-
template <typename V1_t, typename V2_t, typename Mask_t>
inline SDShuffle_match<V1_t, V2_t, Mask_t>
m_Shuffle(const V1_t &v1, const V2_t &v2, const Mask_t &mask) {
>From 65ab88c3f44090afb85f7734040cdd820e57b8a2 Mon Sep 17 00:00:00 2001
From: Aidan <aidan.goldfarb at mail.mcgill.ca>
Date: Thu, 12 Dec 2024 14:12:07 -0500
Subject: [PATCH 3/3] Store the shuffle mask directly in SDShuffle_match.
Replace the Mask.match() sub-matcher call with a direct comparison of the
stored ArrayRef<int> against the node's mask, and update the
matchVecShuffle test accordingly.
---
llvm/include/llvm/CodeGen/SDPatternMatch.h | 21 ++++++++++++---------
.../CodeGen/SelectionDAGPatternMatchTest.cpp | 11 +++++++++--
2 files changed, 21 insertions(+), 11 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index 752296634aaaaa..df70eed88a36f8 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -550,19 +550,21 @@ struct BinaryOpc_match {
};
-/// Matches shuffle.
-template <typename T0, typename T1, typename T2> struct SDShuffle_match {
+/// Matches a VECTOR_SHUFFLE whose operands match Op1 and Op2 and whose mask
+/// equals Mask exactly (same length and same indices).
+template <typename T0, typename T1> struct SDShuffle_match {
T0 Op1;
T1 Op2;
- T2 Mask;
+ /// Expected mask. Non-owning view: the caller's storage must outlive the matcher.
+ ArrayRef<int> Mask;
- SDShuffle_match(const T0 &Op1, const T1 &Op2, const T2 &Mask)
+ SDShuffle_match(const T0 &Op1, const T1 &Op2, ArrayRef<int> Mask)
: Op1(Op1), Op2(Op2), Mask(Mask) {}
template <typename MatchContext>
bool match(const MatchContext &Ctx, SDValue N) {
if (auto *I = dyn_cast<ShuffleVectorSDNode>(N)) {
return Op1.match(Ctx, I->getOperand(0)) &&
- Op2.match(Ctx, I->getOperand(1)) && Mask.match(I->getMask());
+ Op2.match(Ctx, I->getOperand(1)) && Mask == I->getMask();
}
return false;
}
@@ -822,10 +824,11 @@ inline BinaryOpc_match<LHS, RHS> m_Shuffle(const LHS &v1, const RHS &v2) {
return BinaryOpc_match<LHS, RHS>(ISD::VECTOR_SHUFFLE, v1, v2);
}
-template <typename V1_t, typename V2_t, typename Mask_t>
-inline SDShuffle_match<V1_t, V2_t, Mask_t>
-m_Shuffle(const V1_t &v1, const V2_t &v2, const Mask_t &mask) {
- return SDShuffle_match<V1_t, V2_t, Mask_t>(v1, v2, mask);
+/// Matches a shuffle of \p v1 and \p v2 whose mask equals \p mask exactly.
+template <typename V1_t, typename V2_t>
+inline SDShuffle_match<V1_t, V2_t>
+m_Shuffle(const V1_t &v1, const V2_t &v2, ArrayRef<int> mask) {
+ return SDShuffle_match<V1_t, V2_t>(v1, v2, mask);
}
// === Unary operations ===
diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
index 2987f18e336323..c72f454e02382a 100644
--- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
@@ -124,7 +124,7 @@ TEST_F(SelectionDAGPatternMatchTest, matchVecShuffle) {
auto Int32VT = EVT::getIntegerVT(Context, 32);
auto VInt32VT = EVT::getVectorVT(Context, Int32VT, 4);
SmallVector<int, 4> MaskData = {2, 0, 3, 1};
- ArrayRef<int> Mask(MaskData);
+ SmallVector<int, 4> OtherMaskData = {0, 1, 2, 3};
SDValue V0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, VInt32VT);
SDValue V1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, VInt32VT);
@@ -135,7 +135,14 @@ TEST_F(SelectionDAGPatternMatchTest, matchVecShuffle) {
EXPECT_TRUE(
sd_match(VecShuffleWithMask_0, m_Shuffle(m_Value(V0), m_Value(V1))));
EXPECT_TRUE(sd_match(VecShuffleWithMask_0,
- m_Shuffle(m_Value(V0), m_Value(V1), m_Mask(Mask))));
+ m_Shuffle(m_Specific(V0), m_Specific(V1), MaskData)));
+ // Same length, different indices: must not match.
+ EXPECT_FALSE(sd_match(VecShuffleWithMask_0,
+ m_Shuffle(m_Value(), m_Value(), OtherMaskData)));
+ // A strict prefix of the real mask must not match either.
+ EXPECT_FALSE(sd_match(VecShuffleWithMask_0,
+ m_Shuffle(m_Value(), m_Value(),
+ ArrayRef<int>(MaskData).take_front(2))));
}
TEST_F(SelectionDAGPatternMatchTest, matchTernaryOp) {
More information about the llvm-commits
mailing list