[llvm] e83eee3 - [DAG] Create SDPatternMatch method `m_SelectLike` to match `ISD::Select` and `ISD::VSelect` (#164069)

via llvm-commits llvm-commits at lists.llvm.org
Wed Oct 22 02:49:38 PDT 2025


Author: kper
Date: 2025-10-22T09:49:34Z
New Revision: e83eee335c477ab80612b09bf840700d6982c3ef

URL: https://github.com/llvm/llvm-project/commit/e83eee335c477ab80612b09bf840700d6982c3ef
DIFF: https://github.com/llvm/llvm-project/commit/e83eee335c477ab80612b09bf840700d6982c3ef.diff

LOG: [DAG] Create SDPatternMatch method `m_SelectLike` to match `ISD::Select` and `ISD::VSelect` (#164069)

Fixes #150019

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/SDPatternMatch.h
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index 201dc68de8b76..0dcf400962393 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -558,6 +558,11 @@ m_VSelect(const T0_P &Cond, const T1_P &T, const T2_P &F) {
   return TernaryOpc_match<T0_P, T1_P, T2_P>(ISD::VSELECT, Cond, T, F);
 }
 
+template <typename CondP, typename TrueP, typename FalseP>
+inline auto m_SelectLike(const CondP &Cond, const TrueP &T, const FalseP &F) {
+  return m_AnyOf(m_VSelect(Cond, T, F), m_Select(Cond, T, F));
+}
+
 template <typename T0_P, typename T1_P, typename T2_P>
 inline Result_match<0, TernaryOpc_match<T0_P, T1_P, T2_P>>
 m_Load(const T0_P &Ch, const T1_P &Ptr, const T2_P &Offset) {

diff  --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 310d35d9b1d1e..6aa71254fe6ef 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -2476,16 +2476,17 @@ static bool canFoldInAddressingMode(SDNode *N, SDNode *Use, SelectionDAG &DAG,
 /// masked vector operation if the target supports it.
 static SDValue foldSelectWithIdentityConstant(SDNode *N, SelectionDAG &DAG,
                                               bool ShouldCommuteOperands) {
-  // Match a select as operand 1. The identity constant that we are looking for
-  // is only valid as operand 1 of a non-commutative binop.
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
+
+  // Match a select as operand 1. The identity constant that we are looking for
+  // is only valid as operand 1 of a non-commutative binop.
   if (ShouldCommuteOperands)
     std::swap(N0, N1);
 
-  unsigned SelOpcode = N1.getOpcode();
-  if ((SelOpcode != ISD::VSELECT && SelOpcode != ISD::SELECT) ||
-      !N1.hasOneUse())
+  SDValue Cond, TVal, FVal;
+  if (!sd_match(N1, m_OneUse(m_SelectLike(m_Value(Cond), m_Value(TVal),
+                                          m_Value(FVal)))))
     return SDValue();
 
   // We can't hoist all instructions because of immediate UB (not speculatable).
@@ -2493,11 +2494,9 @@ static SDValue foldSelectWithIdentityConstant(SDNode *N, SelectionDAG &DAG,
   if (!DAG.isSafeToSpeculativelyExecuteNode(N))
     return SDValue();
 
+  unsigned SelOpcode = N1.getOpcode();
   unsigned Opcode = N->getOpcode();
   EVT VT = N->getValueType(0);
-  SDValue Cond = N1.getOperand(0);
-  SDValue TVal = N1.getOperand(1);
-  SDValue FVal = N1.getOperand(2);
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
 
   // This transform increases uses of N0, so freeze it to be safe.
@@ -13856,12 +13855,11 @@ static SDValue tryToFoldExtendSelectLoad(SDNode *N, const TargetLowering &TLI,
           Opcode == ISD::ANY_EXTEND) &&
          "Expected EXTEND dag node in input!");
 
-  if (!(N0->getOpcode() == ISD::SELECT || N0->getOpcode() == ISD::VSELECT) ||
-      !N0.hasOneUse())
+  SDValue Cond, Op1, Op2;
+  if (!sd_match(N0, m_OneUse(m_SelectLike(m_Value(Cond), m_Value(Op1),
+                                          m_Value(Op2)))))
     return SDValue();
 
-  SDValue Op1 = N0->getOperand(1);
-  SDValue Op2 = N0->getOperand(2);
   if (!isCompatibleLoad(Op1, Opcode) || !isCompatibleLoad(Op2, Opcode))
     return SDValue();
 
@@ -13883,7 +13881,7 @@ static SDValue tryToFoldExtendSelectLoad(SDNode *N, const TargetLowering &TLI,
 
   SDValue Ext1 = DAG.getNode(Opcode, DL, VT, Op1);
   SDValue Ext2 = DAG.getNode(Opcode, DL, VT, Op2);
-  return DAG.getSelect(DL, VT, N0->getOperand(0), Ext1, Ext2);
+  return DAG.getSelect(DL, VT, Cond, Ext1, Ext2);
 }
 
 /// Try to fold a sext/zext/aext dag node into a ConstantSDNode or
@@ -29620,13 +29618,14 @@ static SDValue takeInexpensiveLog2(SelectionDAG &DAG, const SDLoc &DL, EVT VT,
   }
 
   // c ? X : Y -> c ? Log2(X) : Log2(Y)
-  if ((Op.getOpcode() == ISD::SELECT || Op.getOpcode() == ISD::VSELECT) &&
-      Op.hasOneUse()) {
-    if (SDValue LogX = takeInexpensiveLog2(DAG, DL, VT, Op.getOperand(1),
-                                           Depth + 1, AssumeNonZero))
-      if (SDValue LogY = takeInexpensiveLog2(DAG, DL, VT, Op.getOperand(2),
-                                             Depth + 1, AssumeNonZero))
-        return DAG.getSelect(DL, VT, Op.getOperand(0), LogX, LogY);
+  SDValue Cond, TVal, FVal;
+  if (sd_match(Op, m_OneUse(m_SelectLike(m_Value(Cond), m_Value(TVal),
+                                         m_Value(FVal))))) {
+    if (SDValue LogX =
+            takeInexpensiveLog2(DAG, DL, VT, TVal, Depth + 1, AssumeNonZero))
+      if (SDValue LogY =
+              takeInexpensiveLog2(DAG, DL, VT, FVal, Depth + 1, AssumeNonZero))
+        return DAG.getSelect(DL, VT, Cond, LogX, LogY);
   }
 
   // log2(umin(X, Y)) -> umin(log2(X), log2(Y))

diff  --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
index 16b997901dc1c..aa56aafa2812c 100644
--- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
@@ -550,6 +550,31 @@ TEST_F(SelectionDAGPatternMatchTest, matchNode) {
   EXPECT_FALSE(sd_match(Add, m_Node(ISD::ADD, m_ConstInt(), m_Value())));
 }
 
+TEST_F(SelectionDAGPatternMatchTest, matchSelectLike) {
+  SDLoc DL;
+  auto Int32VT = EVT::getIntegerVT(Context, 32);
+  auto VInt32VT = EVT::getVectorVT(Context, Int32VT, 4);
+
+  SDValue Cond = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 0, Int32VT);
+  SDValue TVal = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
+  SDValue FVal = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Int32VT);
+
+  SDValue VCond = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 0, VInt32VT);
+  SDValue VTVal = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, VInt32VT);
+  SDValue VFVal = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, VInt32VT);
+  // Result types must match the value operands: i32 SELECT, v4i32 VSELECT.
+  SDValue Select = DAG->getNode(ISD::SELECT, DL, Int32VT, Cond, TVal, FVal);
+  SDValue VSelect =
+      DAG->getNode(ISD::VSELECT, DL, VInt32VT, VCond, VTVal, VFVal);
+
+  using namespace SDPatternMatch;
+  EXPECT_TRUE(sd_match(Select, m_SelectLike(m_Specific(Cond), m_Specific(TVal),
+                                            m_Specific(FVal))));
+  EXPECT_TRUE(
+      sd_match(VSelect, m_SelectLike(m_Specific(VCond), m_Specific(VTVal),
+                                     m_Specific(VFVal))));
+}
+
 namespace {
 struct VPMatchContext : public SDPatternMatch::BasicMatchContext {
   using SDPatternMatch::BasicMatchContext::BasicMatchContext;


        


More information about the llvm-commits mailing list