[llvm] [DAG] Add funnel-shift matchers to SDPatternMatch (Fixes #185880) (PR #186593)

Vedant Neve via llvm-commits llvm-commits at lists.llvm.org
Wed Apr 1 00:00:31 PDT 2026


https://github.com/0bVdnt updated https://github.com/llvm/llvm-project/pull/186593

>From 25b332cc1f375f26ff17e85255ed0c114489f5ea Mon Sep 17 00:00:00 2001
From: Vedant Neve <vedantneve13 at gmail.com>
Date: Sat, 14 Mar 2026 13:06:27 +0000
Subject: [PATCH 1/2] [DAG] Add funnel-shift matchers to SDPatternMatch

---
 llvm/include/llvm/CodeGen/SDPatternMatch.h    | 100 ++++++++++++++++++
 .../CodeGen/SelectionDAGPatternMatchTest.cpp  |  80 ++++++++++++++
 2 files changed, 180 insertions(+)

diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index 62762d0c50d6a..9e861d4fff1e6 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -962,6 +962,106 @@ inline BinaryOpc_match<LHS, RHS> m_Rotr(const LHS &L, const RHS &R) {
   return BinaryOpc_match<LHS, RHS>(ISD::ROTR, L, R);
 }
 
+template <typename T0_P, typename T1_P, typename T2_P>
+inline TernaryOpc_match<T0_P, T1_P, T2_P>
+m_FShL(const T0_P &Op0, const T1_P &Op1, const T2_P &Op2) {
+  return m_TernaryOp(ISD::FSHL, Op0, Op1, Op2);
+}
+
+template <typename T0_P, typename T1_P, typename T2_P>
+inline TernaryOpc_match<T0_P, T1_P, T2_P>
+m_FShR(const T0_P &Op0, const T1_P &Op1, const T2_P &Op2) {
+  return m_TernaryOp(ISD::FSHR, Op0, Op1, Op2);
+}
+
+template <typename T0_P, typename T1_P, typename T2_P, bool Left>
+struct FunnelShiftLike_match {
+  T0_P Op0;
+  T1_P Op1;
+  T2_P Op2;
+
+  FunnelShiftLike_match(const T0_P &Op0, const T1_P &Op1, const T2_P &Op2)
+      : Op0(Op0), Op1(Op1), Op2(Op2) {}
+
+  static bool hasComplementaryConstantShifts(SDValue ShlAmt, SDValue SrlAmt,
+                                             unsigned BitWidth) {
+    return ISD::matchBinaryPredicate(
+        ShlAmt, SrlAmt, [BitWidth](ConstantSDNode *ShlC, ConstantSDNode *SrlC) {
+          if (!ShlC || !SrlC)
+            return false;
+
+          const APInt &ShlV = ShlC->getAPIntValue();
+          const APInt &SrlV = SrlC->getAPIntValue();
+          unsigned SumWidth = ShlV.getBitWidth();
+          if (SrlV.getBitWidth() > SumWidth)
+            SumWidth = SrlV.getBitWidth();
+          ++SumWidth;
+
+          return ShlV.zext(SumWidth) + SrlV.zext(SumWidth) ==
+                 APInt(SumWidth, BitWidth);
+        });
+  }
+
+  template <typename MatchContext>
+  bool matchOperands(const MatchContext &Ctx, SDValue X, SDValue Y, SDValue Z) {
+    return Op0.match(Ctx, X) && Op1.match(Ctx, Y) && Op2.match(Ctx, Z);
+  }
+
+  template <typename MatchContext>
+  bool matchShiftOr(const MatchContext &Ctx, SDValue ShlOp, SDValue SrlOp,
+                    unsigned BitWidth) {
+    SDValue X, Y, ShlAmt, SrlAmt;
+    if (!sd_context_match(ShlOp, Ctx, m_Shl(m_Value(X), m_Value(ShlAmt))) ||
+        !sd_context_match(SrlOp, Ctx, m_Srl(m_Value(Y), m_Value(SrlAmt))) ||
+        !hasComplementaryConstantShifts(ShlAmt, SrlAmt, BitWidth))
+      return false;
+
+    return matchOperands(Ctx, X, Y, Left ? ShlAmt : SrlAmt);
+  }
+
+  template <typename MatchContext>
+  bool match(const MatchContext &Ctx, SDValue N) {
+    if (sd_context_match(N, Ctx, m_Opc(Left ? ISD::FSHL : ISD::FSHR))) {
+      EffectiveOperands<false> EO(N, Ctx);
+      assert(EO.Size == 3);
+      return matchOperands(Ctx, N->getOperand(EO.FirstIndex),
+                           N->getOperand(EO.FirstIndex + 1),
+                           N->getOperand(EO.FirstIndex + 2));
+    }
+
+    if (sd_context_match(N, Ctx, m_Opc(Left ? ISD::ROTL : ISD::ROTR))) {
+      EffectiveOperands<false> EO(N, Ctx);
+      assert(EO.Size == 2);
+      SDValue X = N->getOperand(EO.FirstIndex);
+      return matchOperands(Ctx, X, X, N->getOperand(EO.FirstIndex + 1));
+    }
+
+    if (sd_context_match(N, Ctx, m_Opc(ISD::OR))) {
+      EffectiveOperands<false> EO(N, Ctx);
+      assert(EO.Size == 2);
+      SDValue LHS = N->getOperand(EO.FirstIndex);
+      SDValue RHS = N->getOperand(EO.FirstIndex + 1);
+      unsigned BitWidth = N.getValueType().getScalarSizeInBits();
+      return matchShiftOr(Ctx, LHS, RHS, BitWidth) ||
+             matchShiftOr(Ctx, RHS, LHS, BitWidth);
+    }
+
+    return false;
+  }
+};
+
+template <typename T0_P, typename T1_P, typename T2_P>
+inline FunnelShiftLike_match<T0_P, T1_P, T2_P, true>
+m_FShLLike(const T0_P &Op0, const T1_P &Op1, const T2_P &Op2) {
+  return FunnelShiftLike_match<T0_P, T1_P, T2_P, true>(Op0, Op1, Op2);
+}
+
+template <typename T0_P, typename T1_P, typename T2_P>
+inline FunnelShiftLike_match<T0_P, T1_P, T2_P, false>
+m_FShRLike(const T0_P &Op0, const T1_P &Op1, const T2_P &Op2) {
+  return FunnelShiftLike_match<T0_P, T1_P, T2_P, false>(Op0, Op1, Op2);
+}
+
 template <typename LHS, typename RHS>
 inline BinaryOpc_match<LHS, RHS, true> m_Clmul(const LHS &L, const RHS &R) {
   return BinaryOpc_match<LHS, RHS, true>(ISD::CLMUL, L, R);
diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
index e3f4ce20f4234..a15d24f8335c9 100644
--- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
@@ -632,6 +632,86 @@ TEST_F(SelectionDAGPatternMatchTest, matchGenericTernaryOp) {
       sd_match(FAdd, m_c_TernaryOp(ISD::FMA, m_Value(), m_Value(), m_Value())));
 }
 
+TEST_F(SelectionDAGPatternMatchTest, matchFunnelShift) {
+  SDLoc DL;
+  auto Int32VT = EVT::getIntegerVT(Context, 32);
+
+  SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL,
+                                    Register::index2VirtReg(1), Int32VT);
+  SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL,
+                                    Register::index2VirtReg(2), Int32VT);
+  SDValue Op2 = DAG->getCopyFromReg(DAG->getEntryNode(), DL,
+                                    Register::index2VirtReg(3), Int32VT);
+  SDValue C7 = DAG->getConstant(7, DL, Int32VT);
+  SDValue C24 = DAG->getConstant(24, DL, Int32VT);
+  SDValue C25 = DAG->getConstant(25, DL, Int32VT);
+
+  SDValue FShL = DAG->getNode(ISD::FSHL, DL, Int32VT, Op0, Op1, Op2);
+  SDValue FShR = DAG->getNode(ISD::FSHR, DL, Int32VT, Op0, Op1, Op2);
+  SDValue Rotl = DAG->getNode(ISD::ROTL, DL, Int32VT, Op0, Op2);
+  SDValue Rotr = DAG->getNode(ISD::ROTR, DL, Int32VT, Op0, Op2);
+
+  SDValue Shl7 = DAG->getNode(ISD::SHL, DL, Int32VT, Op0, C7);
+  SDValue Srl25 = DAG->getNode(ISD::SRL, DL, Int32VT, Op1, C25);
+  SDValue Srl24 = DAG->getNode(ISD::SRL, DL, Int32VT, Op1, C24);
+  SDValue OrFSh = DAG->getNode(ISD::OR, DL, Int32VT, Shl7, Srl25);
+  SDValue OrFShCommuted = DAG->getNode(ISD::OR, DL, Int32VT, Srl25, Shl7);
+  SDValue BadOrFSh = DAG->getNode(ISD::OR, DL, Int32VT, Shl7, Srl24);
+
+  using namespace SDPatternMatch;
+  EXPECT_TRUE(sd_match(
+      FShL, m_FShL(m_Specific(Op0), m_Specific(Op1), m_Specific(Op2))));
+  EXPECT_TRUE(sd_match(
+      FShR, m_FShR(m_Specific(Op0), m_Specific(Op1), m_Specific(Op2))));
+  EXPECT_FALSE(sd_match(FShL, m_FShR(m_Value(), m_Value(), m_Value())));
+  EXPECT_FALSE(sd_match(FShR, m_FShL(m_Value(), m_Value(), m_Value())));
+
+  EXPECT_TRUE(sd_match(
+      FShL, m_FShLLike(m_Specific(Op0), m_Specific(Op1), m_Specific(Op2))));
+  EXPECT_TRUE(sd_match(
+      FShR, m_FShRLike(m_Specific(Op0), m_Specific(Op1), m_Specific(Op2))));
+  EXPECT_FALSE(
+      sd_match(FShL, m_FShRLike(m_Specific(Op0), m_Specific(Op1), m_Value())));
+  EXPECT_FALSE(
+      sd_match(FShR, m_FShLLike(m_Specific(Op0), m_Specific(Op1), m_Value())));
+
+  EXPECT_TRUE(sd_match(
+      Rotl, m_FShLLike(m_Specific(Op0), m_Specific(Op0), m_Specific(Op2))));
+  EXPECT_TRUE(sd_match(
+      Rotr, m_FShRLike(m_Specific(Op0), m_Specific(Op0), m_Specific(Op2))));
+  EXPECT_FALSE(sd_match(
+      Rotl, m_FShLLike(m_Specific(Op0), m_Specific(Op1), m_Specific(Op2))));
+  EXPECT_FALSE(sd_match(
+      Rotr, m_FShRLike(m_Specific(Op0), m_Specific(Op1), m_Specific(Op2))));
+  EXPECT_FALSE(sd_match(Rotl, m_FShRLike(m_Value(), m_Value(), m_Value())));
+  EXPECT_FALSE(sd_match(Rotr, m_FShLLike(m_Value(), m_Value(), m_Value())));
+
+  SDValue A, B, C;
+  EXPECT_TRUE(sd_match(Rotl, m_FShLLike(m_Value(A), m_Value(B), m_Value(C))));
+  EXPECT_EQ(A, Op0);
+  EXPECT_EQ(B, Op0);
+  EXPECT_EQ(C, Op2);
+
+  A = B = C = SDValue();
+  EXPECT_TRUE(sd_match(Rotr, m_FShRLike(m_Value(A), m_Value(B), m_Value(C))));
+  EXPECT_EQ(A, Op0);
+  EXPECT_EQ(B, Op0);
+  EXPECT_EQ(C, Op2);
+
+  EXPECT_TRUE(sd_match(
+      OrFSh, m_FShLLike(m_Specific(Op0), m_Specific(Op1), m_SpecificInt(7))));
+  EXPECT_TRUE(sd_match(
+      OrFSh, m_FShRLike(m_Specific(Op0), m_Specific(Op1), m_SpecificInt(25))));
+  EXPECT_TRUE(
+      sd_match(OrFShCommuted,
+               m_FShLLike(m_Specific(Op0), m_Specific(Op1), m_SpecificInt(7))));
+  EXPECT_TRUE(
+      sd_match(OrFShCommuted, m_FShRLike(m_Specific(Op0), m_Specific(Op1),
+                                         m_SpecificInt(25))));
+  EXPECT_FALSE(sd_match(BadOrFSh, m_FShLLike(m_Value(), m_Value(), m_Value())));
+  EXPECT_FALSE(sd_match(BadOrFSh, m_FShRLike(m_Value(), m_Value(), m_Value())));
+}
+
 TEST_F(SelectionDAGPatternMatchTest, matchUnaryOp) {
   SDLoc DL;
   auto Int32VT = EVT::getIntegerVT(Context, 32);

>From 078ca6d1cc3b0c73434996515c3359a5eae5607b Mon Sep 17 00:00:00 2001
From: Vedant Neve <vedantneve13 at gmail.com>
Date: Wed, 1 Apr 2026 06:50:28 +0000
Subject: [PATCH 2/2] Address review feedback

---
 llvm/include/llvm/CodeGen/SDPatternMatch.h    | 87 +++++++++----------
 .../CodeGen/SelectionDAGPatternMatchTest.cpp  | 19 +++-
 2 files changed, 54 insertions(+), 52 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index 9e861d4fff1e6..f57205a0071b4 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -17,6 +17,7 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
+#include "llvm/ADT/bit.h"
 #include "llvm/CodeGen/SelectionDAG.h"
 #include "llvm/CodeGen/SelectionDAGNodes.h"
 #include "llvm/CodeGen/TargetLowering.h"
@@ -983,23 +984,16 @@ struct FunnelShiftLike_match {
   FunnelShiftLike_match(const T0_P &Op0, const T1_P &Op1, const T2_P &Op2)
       : Op0(Op0), Op1(Op1), Op2(Op2) {}
 
-  static bool hasComplementaryConstantShifts(SDValue ShlAmt, SDValue SrlAmt,
+  static bool hasComplementaryConstantShifts(const APInt &ShlV,
+                                             const APInt &SrlV,
                                              unsigned BitWidth) {
-    return ISD::matchBinaryPredicate(
-        ShlAmt, SrlAmt, [BitWidth](ConstantSDNode *ShlC, ConstantSDNode *SrlC) {
-          if (!ShlC || !SrlC)
-            return false;
-
-          const APInt &ShlV = ShlC->getAPIntValue();
-          const APInt &SrlV = SrlC->getAPIntValue();
-          unsigned SumWidth = ShlV.getBitWidth();
-          if (SrlV.getBitWidth() > SumWidth)
-            SumWidth = SrlV.getBitWidth();
-          ++SumWidth;
-
-          return ShlV.zext(SumWidth) + SrlV.zext(SumWidth) ==
-                 APInt(SumWidth, BitWidth);
-        });
+    unsigned SumWidth = std::max(ShlV.getBitWidth(), SrlV.getBitWidth()) + 1;
+    unsigned BitWidthBits = llvm::bit_width(BitWidth);
+    if (BitWidthBits > SumWidth)
+      return false;
+
+    return ShlV.zext(SumWidth) + SrlV.zext(SumWidth) ==
+           APInt(SumWidth, BitWidth);
   }
 
   template <typename MatchContext>
@@ -1008,43 +1002,22 @@ struct FunnelShiftLike_match {
   }
 
   template <typename MatchContext>
-  bool matchShiftOr(const MatchContext &Ctx, SDValue ShlOp, SDValue SrlOp,
-                    unsigned BitWidth) {
-    SDValue X, Y, ShlAmt, SrlAmt;
-    if (!sd_context_match(ShlOp, Ctx, m_Shl(m_Value(X), m_Value(ShlAmt))) ||
-        !sd_context_match(SrlOp, Ctx, m_Srl(m_Value(Y), m_Value(SrlAmt))) ||
-        !hasComplementaryConstantShifts(ShlAmt, SrlAmt, BitWidth))
-      return false;
-
-    return matchOperands(Ctx, X, Y, Left ? ShlAmt : SrlAmt);
-  }
+  bool matchShiftOr(const MatchContext &Ctx, SDValue N, unsigned BitWidth);
 
   template <typename MatchContext>
   bool match(const MatchContext &Ctx, SDValue N) {
-    if (sd_context_match(N, Ctx, m_Opc(Left ? ISD::FSHL : ISD::FSHR))) {
-      EffectiveOperands<false> EO(N, Ctx);
-      assert(EO.Size == 3);
-      return matchOperands(Ctx, N->getOperand(EO.FirstIndex),
-                           N->getOperand(EO.FirstIndex + 1),
-                           N->getOperand(EO.FirstIndex + 2));
-    }
+    if (sd_context_match(N, Ctx,
+                         Left ? m_FShL(Op0, Op1, Op2) : m_FShR(Op0, Op1, Op2)))
+      return true;
 
-    if (sd_context_match(N, Ctx, m_Opc(Left ? ISD::ROTL : ISD::ROTR))) {
-      EffectiveOperands<false> EO(N, Ctx);
-      assert(EO.Size == 2);
-      SDValue X = N->getOperand(EO.FirstIndex);
-      return matchOperands(Ctx, X, X, N->getOperand(EO.FirstIndex + 1));
-    }
+    SDValue X, Z;
+    if (sd_context_match(N, Ctx,
+                         Left ? m_Rotl(m_Value(X), m_Value(Z))
+                              : m_Rotr(m_Value(X), m_Value(Z))))
+      return matchOperands(Ctx, X, X, Z);
 
-    if (sd_context_match(N, Ctx, m_Opc(ISD::OR))) {
-      EffectiveOperands<false> EO(N, Ctx);
-      assert(EO.Size == 2);
-      SDValue LHS = N->getOperand(EO.FirstIndex);
-      SDValue RHS = N->getOperand(EO.FirstIndex + 1);
-      unsigned BitWidth = N.getValueType().getScalarSizeInBits();
-      return matchShiftOr(Ctx, LHS, RHS, BitWidth) ||
-             matchShiftOr(Ctx, RHS, LHS, BitWidth);
-    }
+    if (N->getOpcode() == ISD::OR)
+      return matchShiftOr(Ctx, N, N.getValueType().getScalarSizeInBits());
 
     return false;
   }
@@ -1321,6 +1294,24 @@ inline Constant64_match<int64_t> m_ConstInt(int64_t &V) {
   return Constant64_match<int64_t>(V);
 }
 
+template <typename T0_P, typename T1_P, typename T2_P, bool Left>
+template <typename MatchContext>
+bool FunnelShiftLike_match<T0_P, T1_P, T2_P, Left>::matchShiftOr(
+    const MatchContext &Ctx, SDValue N, unsigned BitWidth) {
+  SDValue X, Y, ShlAmt, SrlAmt;
+  APInt ShlConst, SrlConst;
+  if (!sd_context_match(
+          N, Ctx,
+          m_Or(
+              m_Shl(m_Value(X), m_AllOf(m_Value(ShlAmt), m_ConstInt(ShlConst))),
+              m_Srl(m_Value(Y),
+                    m_AllOf(m_Value(SrlAmt), m_ConstInt(SrlConst))))) ||
+      !hasComplementaryConstantShifts(ShlConst, SrlConst, BitWidth))
+    return false;
+
+  return matchOperands(Ctx, X, Y, Left ? ShlAmt : SrlAmt);
+}
+
 struct SpecificInt_match {
   APInt IntVal;
 
diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
index a15d24f8335c9..d77908a099693 100644
--- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
@@ -670,10 +670,8 @@ TEST_F(SelectionDAGPatternMatchTest, matchFunnelShift) {
       FShL, m_FShLLike(m_Specific(Op0), m_Specific(Op1), m_Specific(Op2))));
   EXPECT_TRUE(sd_match(
       FShR, m_FShRLike(m_Specific(Op0), m_Specific(Op1), m_Specific(Op2))));
-  EXPECT_FALSE(
-      sd_match(FShL, m_FShRLike(m_Specific(Op0), m_Specific(Op1), m_Value())));
-  EXPECT_FALSE(
-      sd_match(FShR, m_FShLLike(m_Specific(Op0), m_Specific(Op1), m_Value())));
+  EXPECT_FALSE(sd_match(FShL, m_FShRLike(m_Value(), m_Value(), m_Value())));
+  EXPECT_FALSE(sd_match(FShR, m_FShLLike(m_Value(), m_Value(), m_Value())));
 
   EXPECT_TRUE(sd_match(
       Rotl, m_FShLLike(m_Specific(Op0), m_Specific(Op0), m_Specific(Op2))));
@@ -710,6 +708,19 @@ TEST_F(SelectionDAGPatternMatchTest, matchFunnelShift) {
                                          m_SpecificInt(25))));
   EXPECT_FALSE(sd_match(BadOrFSh, m_FShLLike(m_Value(), m_Value(), m_Value())));
   EXPECT_FALSE(sd_match(BadOrFSh, m_FShRLike(m_Value(), m_Value(), m_Value())));
+
+  auto Int1024VT = EVT::getIntegerVT(Context, 1024);
+  auto Int8VT = EVT::getIntegerVT(Context, 8);
+  SDValue WideOp0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL,
+                                        Register::index2VirtReg(4), Int1024VT);
+  SDValue WideOp1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL,
+                                        Register::index2VirtReg(5), Int1024VT);
+  SDValue C0I8 = DAG->getConstant(0, DL, Int8VT);
+  SDValue WideShl = DAG->getNode(ISD::SHL, DL, Int1024VT, WideOp0, C0I8);
+  SDValue WideSrl = DAG->getNode(ISD::SRL, DL, Int1024VT, WideOp1, C0I8);
+  SDValue WideOr = DAG->getNode(ISD::OR, DL, Int1024VT, WideShl, WideSrl);
+  EXPECT_FALSE(sd_match(WideOr, m_FShLLike(m_Value(), m_Value(), m_Value())));
+  EXPECT_FALSE(sd_match(WideOr, m_FShRLike(m_Value(), m_Value(), m_Value())));
 }
 
 TEST_F(SelectionDAGPatternMatchTest, matchUnaryOp) {



More information about the llvm-commits mailing list