[llvm] [DAG] Add funnel-shift matchers to SDPatternMatch (Fixes #185880) (PR #186593)
Vedant Neve via llvm-commits
llvm-commits at lists.llvm.org
Wed Apr 1 00:00:31 PDT 2026
https://github.com/0bVdnt updated https://github.com/llvm/llvm-project/pull/186593
From 25b332cc1f375f26ff17e85255ed0c114489f5ea Mon Sep 17 00:00:00 2001
From: Vedant Neve <vedantneve13 at gmail.com>
Date: Sat, 14 Mar 2026 13:06:27 +0000
Subject: [PATCH 1/2] [DAG] Add funnel-shift matchers to SDPatternMatch
---
llvm/include/llvm/CodeGen/SDPatternMatch.h | 100 ++++++++++++++++++
.../CodeGen/SelectionDAGPatternMatchTest.cpp | 80 ++++++++++++++
2 files changed, 180 insertions(+)
diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index 62762d0c50d6a..9e861d4fff1e6 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -962,6 +962,106 @@ inline BinaryOpc_match<LHS, RHS> m_Rotr(const LHS &L, const RHS &R) {
return BinaryOpc_match<LHS, RHS>(ISD::ROTR, L, R);
}
+template <typename T0_P, typename T1_P, typename T2_P>
+inline TernaryOpc_match<T0_P, T1_P, T2_P>
+m_FShL(const T0_P &Op0, const T1_P &Op1, const T2_P &Op2) {
+ return m_TernaryOp(ISD::FSHL, Op0, Op1, Op2);
+}
+
+template <typename T0_P, typename T1_P, typename T2_P>
+inline TernaryOpc_match<T0_P, T1_P, T2_P>
+m_FShR(const T0_P &Op0, const T1_P &Op1, const T2_P &Op2) {
+ return m_TernaryOp(ISD::FSHR, Op0, Op1, Op2);
+}
+
+template <typename T0_P, typename T1_P, typename T2_P, bool Left>
+struct FunnelShiftLike_match {
+ T0_P Op0;
+ T1_P Op1;
+ T2_P Op2;
+
+ FunnelShiftLike_match(const T0_P &Op0, const T1_P &Op1, const T2_P &Op2)
+ : Op0(Op0), Op1(Op1), Op2(Op2) {}
+
+ static bool hasComplementaryConstantShifts(SDValue ShlAmt, SDValue SrlAmt,
+ unsigned BitWidth) {
+ return ISD::matchBinaryPredicate(
+ ShlAmt, SrlAmt, [BitWidth](ConstantSDNode *ShlC, ConstantSDNode *SrlC) {
+ if (!ShlC || !SrlC)
+ return false;
+
+ const APInt &ShlV = ShlC->getAPIntValue();
+ const APInt &SrlV = SrlC->getAPIntValue();
+ unsigned SumWidth = ShlV.getBitWidth();
+ if (SrlV.getBitWidth() > SumWidth)
+ SumWidth = SrlV.getBitWidth();
+ ++SumWidth;
+
+ return ShlV.zext(SumWidth) + SrlV.zext(SumWidth) ==
+ APInt(SumWidth, BitWidth);
+ });
+ }
+
+ template <typename MatchContext>
+ bool matchOperands(const MatchContext &Ctx, SDValue X, SDValue Y, SDValue Z) {
+ return Op0.match(Ctx, X) && Op1.match(Ctx, Y) && Op2.match(Ctx, Z);
+ }
+
+ template <typename MatchContext>
+ bool matchShiftOr(const MatchContext &Ctx, SDValue ShlOp, SDValue SrlOp,
+ unsigned BitWidth) {
+ SDValue X, Y, ShlAmt, SrlAmt;
+ if (!sd_context_match(ShlOp, Ctx, m_Shl(m_Value(X), m_Value(ShlAmt))) ||
+ !sd_context_match(SrlOp, Ctx, m_Srl(m_Value(Y), m_Value(SrlAmt))) ||
+ !hasComplementaryConstantShifts(ShlAmt, SrlAmt, BitWidth))
+ return false;
+
+ return matchOperands(Ctx, X, Y, Left ? ShlAmt : SrlAmt);
+ }
+
+ template <typename MatchContext>
+ bool match(const MatchContext &Ctx, SDValue N) {
+ if (sd_context_match(N, Ctx, m_Opc(Left ? ISD::FSHL : ISD::FSHR))) {
+ EffectiveOperands<false> EO(N, Ctx);
+ assert(EO.Size == 3);
+ return matchOperands(Ctx, N->getOperand(EO.FirstIndex),
+ N->getOperand(EO.FirstIndex + 1),
+ N->getOperand(EO.FirstIndex + 2));
+ }
+
+ if (sd_context_match(N, Ctx, m_Opc(Left ? ISD::ROTL : ISD::ROTR))) {
+ EffectiveOperands<false> EO(N, Ctx);
+ assert(EO.Size == 2);
+ SDValue X = N->getOperand(EO.FirstIndex);
+ return matchOperands(Ctx, X, X, N->getOperand(EO.FirstIndex + 1));
+ }
+
+ if (sd_context_match(N, Ctx, m_Opc(ISD::OR))) {
+ EffectiveOperands<false> EO(N, Ctx);
+ assert(EO.Size == 2);
+ SDValue LHS = N->getOperand(EO.FirstIndex);
+ SDValue RHS = N->getOperand(EO.FirstIndex + 1);
+ unsigned BitWidth = N.getValueType().getScalarSizeInBits();
+ return matchShiftOr(Ctx, LHS, RHS, BitWidth) ||
+ matchShiftOr(Ctx, RHS, LHS, BitWidth);
+ }
+
+ return false;
+ }
+};
+
+template <typename T0_P, typename T1_P, typename T2_P>
+inline FunnelShiftLike_match<T0_P, T1_P, T2_P, true>
+m_FShLLike(const T0_P &Op0, const T1_P &Op1, const T2_P &Op2) {
+ return FunnelShiftLike_match<T0_P, T1_P, T2_P, true>(Op0, Op1, Op2);
+}
+
+template <typename T0_P, typename T1_P, typename T2_P>
+inline FunnelShiftLike_match<T0_P, T1_P, T2_P, false>
+m_FShRLike(const T0_P &Op0, const T1_P &Op1, const T2_P &Op2) {
+ return FunnelShiftLike_match<T0_P, T1_P, T2_P, false>(Op0, Op1, Op2);
+}
+
template <typename LHS, typename RHS>
inline BinaryOpc_match<LHS, RHS, true> m_Clmul(const LHS &L, const RHS &R) {
return BinaryOpc_match<LHS, RHS, true>(ISD::CLMUL, L, R);
diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
index e3f4ce20f4234..a15d24f8335c9 100644
--- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
@@ -632,6 +632,86 @@ TEST_F(SelectionDAGPatternMatchTest, matchGenericTernaryOp) {
sd_match(FAdd, m_c_TernaryOp(ISD::FMA, m_Value(), m_Value(), m_Value())));
}
+TEST_F(SelectionDAGPatternMatchTest, matchFunnelShift) {
+ SDLoc DL;
+ auto Int32VT = EVT::getIntegerVT(Context, 32);
+
+ SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL,
+ Register::index2VirtReg(1), Int32VT);
+ SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL,
+ Register::index2VirtReg(2), Int32VT);
+ SDValue Op2 = DAG->getCopyFromReg(DAG->getEntryNode(), DL,
+ Register::index2VirtReg(3), Int32VT);
+ SDValue C7 = DAG->getConstant(7, DL, Int32VT);
+ SDValue C24 = DAG->getConstant(24, DL, Int32VT);
+ SDValue C25 = DAG->getConstant(25, DL, Int32VT);
+
+ SDValue FShL = DAG->getNode(ISD::FSHL, DL, Int32VT, Op0, Op1, Op2);
+ SDValue FShR = DAG->getNode(ISD::FSHR, DL, Int32VT, Op0, Op1, Op2);
+ SDValue Rotl = DAG->getNode(ISD::ROTL, DL, Int32VT, Op0, Op2);
+ SDValue Rotr = DAG->getNode(ISD::ROTR, DL, Int32VT, Op0, Op2);
+
+ SDValue Shl7 = DAG->getNode(ISD::SHL, DL, Int32VT, Op0, C7);
+ SDValue Srl25 = DAG->getNode(ISD::SRL, DL, Int32VT, Op1, C25);
+ SDValue Srl24 = DAG->getNode(ISD::SRL, DL, Int32VT, Op1, C24);
+ SDValue OrFSh = DAG->getNode(ISD::OR, DL, Int32VT, Shl7, Srl25);
+ SDValue OrFShCommuted = DAG->getNode(ISD::OR, DL, Int32VT, Srl25, Shl7);
+ SDValue BadOrFSh = DAG->getNode(ISD::OR, DL, Int32VT, Shl7, Srl24);
+
+ using namespace SDPatternMatch;
+ EXPECT_TRUE(sd_match(
+ FShL, m_FShL(m_Specific(Op0), m_Specific(Op1), m_Specific(Op2))));
+ EXPECT_TRUE(sd_match(
+ FShR, m_FShR(m_Specific(Op0), m_Specific(Op1), m_Specific(Op2))));
+ EXPECT_FALSE(sd_match(FShL, m_FShR(m_Value(), m_Value(), m_Value())));
+ EXPECT_FALSE(sd_match(FShR, m_FShL(m_Value(), m_Value(), m_Value())));
+
+ EXPECT_TRUE(sd_match(
+ FShL, m_FShLLike(m_Specific(Op0), m_Specific(Op1), m_Specific(Op2))));
+ EXPECT_TRUE(sd_match(
+ FShR, m_FShRLike(m_Specific(Op0), m_Specific(Op1), m_Specific(Op2))));
+ EXPECT_FALSE(
+ sd_match(FShL, m_FShRLike(m_Specific(Op0), m_Specific(Op1), m_Value())));
+ EXPECT_FALSE(
+ sd_match(FShR, m_FShLLike(m_Specific(Op0), m_Specific(Op1), m_Value())));
+
+ EXPECT_TRUE(sd_match(
+ Rotl, m_FShLLike(m_Specific(Op0), m_Specific(Op0), m_Specific(Op2))));
+ EXPECT_TRUE(sd_match(
+ Rotr, m_FShRLike(m_Specific(Op0), m_Specific(Op0), m_Specific(Op2))));
+ EXPECT_FALSE(sd_match(
+ Rotl, m_FShLLike(m_Specific(Op0), m_Specific(Op1), m_Specific(Op2))));
+ EXPECT_FALSE(sd_match(
+ Rotr, m_FShRLike(m_Specific(Op0), m_Specific(Op1), m_Specific(Op2))));
+ EXPECT_FALSE(sd_match(Rotl, m_FShRLike(m_Value(), m_Value(), m_Value())));
+ EXPECT_FALSE(sd_match(Rotr, m_FShLLike(m_Value(), m_Value(), m_Value())));
+
+ SDValue A, B, C;
+ EXPECT_TRUE(sd_match(Rotl, m_FShLLike(m_Value(A), m_Value(B), m_Value(C))));
+ EXPECT_EQ(A, Op0);
+ EXPECT_EQ(B, Op0);
+ EXPECT_EQ(C, Op2);
+
+ A = B = C = SDValue();
+ EXPECT_TRUE(sd_match(Rotr, m_FShRLike(m_Value(A), m_Value(B), m_Value(C))));
+ EXPECT_EQ(A, Op0);
+ EXPECT_EQ(B, Op0);
+ EXPECT_EQ(C, Op2);
+
+ EXPECT_TRUE(sd_match(
+ OrFSh, m_FShLLike(m_Specific(Op0), m_Specific(Op1), m_SpecificInt(7))));
+ EXPECT_TRUE(sd_match(
+ OrFSh, m_FShRLike(m_Specific(Op0), m_Specific(Op1), m_SpecificInt(25))));
+ EXPECT_TRUE(
+ sd_match(OrFShCommuted,
+ m_FShLLike(m_Specific(Op0), m_Specific(Op1), m_SpecificInt(7))));
+ EXPECT_TRUE(
+ sd_match(OrFShCommuted, m_FShRLike(m_Specific(Op0), m_Specific(Op1),
+ m_SpecificInt(25))));
+ EXPECT_FALSE(sd_match(BadOrFSh, m_FShLLike(m_Value(), m_Value(), m_Value())));
+ EXPECT_FALSE(sd_match(BadOrFSh, m_FShRLike(m_Value(), m_Value(), m_Value())));
+}
+
TEST_F(SelectionDAGPatternMatchTest, matchUnaryOp) {
SDLoc DL;
auto Int32VT = EVT::getIntegerVT(Context, 32);
From 078ca6d1cc3b0c73434996515c3359a5eae5607b Mon Sep 17 00:00:00 2001
From: Vedant Neve <vedantneve13 at gmail.com>
Date: Wed, 1 Apr 2026 06:50:28 +0000
Subject: [PATCH 2/2] Address review feedback
---
llvm/include/llvm/CodeGen/SDPatternMatch.h | 87 +++++++++----------
.../CodeGen/SelectionDAGPatternMatchTest.cpp | 19 +++-
2 files changed, 54 insertions(+), 52 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index 9e861d4fff1e6..f57205a0071b4 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -17,6 +17,7 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
+#include "llvm/ADT/bit.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"
#include "llvm/CodeGen/TargetLowering.h"
@@ -983,23 +984,16 @@ struct FunnelShiftLike_match {
FunnelShiftLike_match(const T0_P &Op0, const T1_P &Op1, const T2_P &Op2)
: Op0(Op0), Op1(Op1), Op2(Op2) {}
- static bool hasComplementaryConstantShifts(SDValue ShlAmt, SDValue SrlAmt,
+ static bool hasComplementaryConstantShifts(const APInt &ShlV,
+ const APInt &SrlV,
unsigned BitWidth) {
- return ISD::matchBinaryPredicate(
- ShlAmt, SrlAmt, [BitWidth](ConstantSDNode *ShlC, ConstantSDNode *SrlC) {
- if (!ShlC || !SrlC)
- return false;
-
- const APInt &ShlV = ShlC->getAPIntValue();
- const APInt &SrlV = SrlC->getAPIntValue();
- unsigned SumWidth = ShlV.getBitWidth();
- if (SrlV.getBitWidth() > SumWidth)
- SumWidth = SrlV.getBitWidth();
- ++SumWidth;
-
- return ShlV.zext(SumWidth) + SrlV.zext(SumWidth) ==
- APInt(SumWidth, BitWidth);
- });
+ unsigned SumWidth = std::max(ShlV.getBitWidth(), SrlV.getBitWidth()) + 1;
+ unsigned BitWidthBits = llvm::bit_width(BitWidth);
+ if (BitWidthBits > SumWidth)
+ return false;
+
+ return ShlV.zext(SumWidth) + SrlV.zext(SumWidth) ==
+ APInt(SumWidth, BitWidth);
}
template <typename MatchContext>
@@ -1008,43 +1002,22 @@ struct FunnelShiftLike_match {
}
template <typename MatchContext>
- bool matchShiftOr(const MatchContext &Ctx, SDValue ShlOp, SDValue SrlOp,
- unsigned BitWidth) {
- SDValue X, Y, ShlAmt, SrlAmt;
- if (!sd_context_match(ShlOp, Ctx, m_Shl(m_Value(X), m_Value(ShlAmt))) ||
- !sd_context_match(SrlOp, Ctx, m_Srl(m_Value(Y), m_Value(SrlAmt))) ||
- !hasComplementaryConstantShifts(ShlAmt, SrlAmt, BitWidth))
- return false;
-
- return matchOperands(Ctx, X, Y, Left ? ShlAmt : SrlAmt);
- }
+ bool matchShiftOr(const MatchContext &Ctx, SDValue N, unsigned BitWidth);
template <typename MatchContext>
bool match(const MatchContext &Ctx, SDValue N) {
- if (sd_context_match(N, Ctx, m_Opc(Left ? ISD::FSHL : ISD::FSHR))) {
- EffectiveOperands<false> EO(N, Ctx);
- assert(EO.Size == 3);
- return matchOperands(Ctx, N->getOperand(EO.FirstIndex),
- N->getOperand(EO.FirstIndex + 1),
- N->getOperand(EO.FirstIndex + 2));
- }
+ if (sd_context_match(N, Ctx,
+ Left ? m_FShL(Op0, Op1, Op2) : m_FShR(Op0, Op1, Op2)))
+ return true;
- if (sd_context_match(N, Ctx, m_Opc(Left ? ISD::ROTL : ISD::ROTR))) {
- EffectiveOperands<false> EO(N, Ctx);
- assert(EO.Size == 2);
- SDValue X = N->getOperand(EO.FirstIndex);
- return matchOperands(Ctx, X, X, N->getOperand(EO.FirstIndex + 1));
- }
+ SDValue X, Z;
+ if (sd_context_match(N, Ctx,
+ Left ? m_Rotl(m_Value(X), m_Value(Z))
+ : m_Rotr(m_Value(X), m_Value(Z))))
+ return matchOperands(Ctx, X, X, Z);
- if (sd_context_match(N, Ctx, m_Opc(ISD::OR))) {
- EffectiveOperands<false> EO(N, Ctx);
- assert(EO.Size == 2);
- SDValue LHS = N->getOperand(EO.FirstIndex);
- SDValue RHS = N->getOperand(EO.FirstIndex + 1);
- unsigned BitWidth = N.getValueType().getScalarSizeInBits();
- return matchShiftOr(Ctx, LHS, RHS, BitWidth) ||
- matchShiftOr(Ctx, RHS, LHS, BitWidth);
- }
+ if (N->getOpcode() == ISD::OR)
+ return matchShiftOr(Ctx, N, N.getValueType().getScalarSizeInBits());
return false;
}
@@ -1321,6 +1294,24 @@ inline Constant64_match<int64_t> m_ConstInt(int64_t &V) {
return Constant64_match<int64_t>(V);
}
+template <typename T0_P, typename T1_P, typename T2_P, bool Left>
+template <typename MatchContext>
+bool FunnelShiftLike_match<T0_P, T1_P, T2_P, Left>::matchShiftOr(
+ const MatchContext &Ctx, SDValue N, unsigned BitWidth) {
+ SDValue X, Y, ShlAmt, SrlAmt;
+ APInt ShlConst, SrlConst;
+ if (!sd_context_match(
+ N, Ctx,
+ m_Or(
+ m_Shl(m_Value(X), m_AllOf(m_Value(ShlAmt), m_ConstInt(ShlConst))),
+ m_Srl(m_Value(Y),
+ m_AllOf(m_Value(SrlAmt), m_ConstInt(SrlConst))))) ||
+ !hasComplementaryConstantShifts(ShlConst, SrlConst, BitWidth))
+ return false;
+
+ return matchOperands(Ctx, X, Y, Left ? ShlAmt : SrlAmt);
+}
+
struct SpecificInt_match {
APInt IntVal;
diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
index a15d24f8335c9..d77908a099693 100644
--- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
@@ -670,10 +670,8 @@ TEST_F(SelectionDAGPatternMatchTest, matchFunnelShift) {
FShL, m_FShLLike(m_Specific(Op0), m_Specific(Op1), m_Specific(Op2))));
EXPECT_TRUE(sd_match(
FShR, m_FShRLike(m_Specific(Op0), m_Specific(Op1), m_Specific(Op2))));
- EXPECT_FALSE(
- sd_match(FShL, m_FShRLike(m_Specific(Op0), m_Specific(Op1), m_Value())));
- EXPECT_FALSE(
- sd_match(FShR, m_FShLLike(m_Specific(Op0), m_Specific(Op1), m_Value())));
+ EXPECT_FALSE(sd_match(FShL, m_FShRLike(m_Value(), m_Value(), m_Value())));
+ EXPECT_FALSE(sd_match(FShR, m_FShLLike(m_Value(), m_Value(), m_Value())));
EXPECT_TRUE(sd_match(
Rotl, m_FShLLike(m_Specific(Op0), m_Specific(Op0), m_Specific(Op2))));
@@ -710,6 +708,19 @@ TEST_F(SelectionDAGPatternMatchTest, matchFunnelShift) {
m_SpecificInt(25))));
EXPECT_FALSE(sd_match(BadOrFSh, m_FShLLike(m_Value(), m_Value(), m_Value())));
EXPECT_FALSE(sd_match(BadOrFSh, m_FShRLike(m_Value(), m_Value(), m_Value())));
+
+ auto Int1024VT = EVT::getIntegerVT(Context, 1024);
+ auto Int8VT = EVT::getIntegerVT(Context, 8);
+ SDValue WideOp0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL,
+ Register::index2VirtReg(4), Int1024VT);
+ SDValue WideOp1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL,
+ Register::index2VirtReg(5), Int1024VT);
+ SDValue C0I8 = DAG->getConstant(0, DL, Int8VT);
+ SDValue WideShl = DAG->getNode(ISD::SHL, DL, Int1024VT, WideOp0, C0I8);
+ SDValue WideSrl = DAG->getNode(ISD::SRL, DL, Int1024VT, WideOp1, C0I8);
+ SDValue WideOr = DAG->getNode(ISD::OR, DL, Int1024VT, WideShl, WideSrl);
+ EXPECT_FALSE(sd_match(WideOr, m_FShLLike(m_Value(), m_Value(), m_Value())));
+ EXPECT_FALSE(sd_match(WideOr, m_FShRLike(m_Value(), m_Value(), m_Value())));
}
TEST_F(SelectionDAGPatternMatchTest, matchUnaryOp) {
More information about the llvm-commits
mailing list