[llvm] c2580af - [X86] Convert shift+clamp -> avx2 shift folds to use SDPatternMatch::m_SetCC. NFC.

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 15 03:42:30 PDT 2024


Author: Simon Pilgrim
Date: 2024-07-15T11:42:12+01:00
New Revision: c2580afed7e55f13762d56400dc346f222ea5884

URL: https://github.com/llvm/llvm-project/commit/c2580afed7e55f13762d56400dc346f222ea5884
DIFF: https://github.com/llvm/llvm-project/commit/c2580afed7e55f13762d56400dc346f222ea5884.diff

LOG: [X86] Convert shift+clamp -> avx2 shift folds to use SDPatternMatch::m_SetCC. NFC.

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86ISelLowering.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index a731541ca7778..91a5526a82bbe 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -46193,15 +46193,13 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
   if (N->getOpcode() == ISD::VSELECT &&
       (LHS.getOpcode() == ISD::SRL || LHS.getOpcode() == ISD::SHL) &&
       supportedVectorVarShift(VT, Subtarget, LHS.getOpcode())) {
-    APInt SV;
+    using namespace llvm::SDPatternMatch;
     // fold select(icmp_ult(amt,BW),shl(x,amt),0) -> avx2 psllv(x,amt)
     // fold select(icmp_ult(amt,BW),srl(x,amt),0) -> avx2 psrlv(x,amt)
-    if (Cond.getOpcode() == ISD::SETCC &&
-        Cond.getOperand(0) == LHS.getOperand(1) &&
-        cast<CondCodeSDNode>(Cond.getOperand(2))->get() == ISD::SETULT &&
-        ISD::isConstantSplatVector(Cond.getOperand(1).getNode(), SV) &&
-        ISD::isConstantSplatVectorAllZeros(RHS.getNode()) &&
-        SV == VT.getScalarSizeInBits()) {
+    if (ISD::isConstantSplatVectorAllZeros(RHS.getNode()) &&
+        sd_match(Cond, m_SetCC(m_Specific(LHS.getOperand(1)),
+                               m_SpecificInt(VT.getScalarSizeInBits()),
+                               m_SpecificCondCode(ISD::SETULT)))) {
       return DAG.getNode(LHS.getOpcode() == ISD::SRL ? X86ISD::VSRLV
                                                      : X86ISD::VSHLV,
                          DL, VT, LHS.getOperand(0), LHS.getOperand(1));
@@ -48020,10 +48018,12 @@ static SDValue combineShiftToPMULH(SDNode *N, SelectionDAG &DAG,
 
 static SDValue combineShiftLeft(SDNode *N, SelectionDAG &DAG,
                                 const X86Subtarget &Subtarget) {
+  using namespace llvm::SDPatternMatch;
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
   ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
   EVT VT = N0.getValueType();
+  unsigned EltSizeInBits = VT.getScalarSizeInBits();
   SDLoc DL(N);
 
   // Exploits AVX2 VSHLV/VSRLV instructions for efficient unsigned vector shifts
@@ -48033,21 +48033,16 @@ static SDValue combineShiftLeft(SDNode *N, SelectionDAG &DAG,
     SDValue Cond = N0.getOperand(0);
     SDValue N00 = N0.getOperand(1);
     SDValue N01 = N0.getOperand(2);
-    APInt SV;
     // fold shl(select(icmp_ult(amt,BW),x,0),amt) -> avx2 psllv(x,amt)
-    if (Cond.getOpcode() == ISD::SETCC && Cond.getOperand(0) == N1 &&
-        cast<CondCodeSDNode>(Cond.getOperand(2))->get() == ISD::SETULT &&
-        ISD::isConstantSplatVector(Cond.getOperand(1).getNode(), SV) &&
-        ISD::isConstantSplatVectorAllZeros(N01.getNode()) &&
-        SV == VT.getScalarSizeInBits()) {
+    if (ISD::isConstantSplatVectorAllZeros(N01.getNode()) &&
+        sd_match(Cond, m_SetCC(m_Specific(N1), m_SpecificInt(EltSizeInBits),
+                               m_SpecificCondCode(ISD::SETULT)))) {
       return DAG.getNode(X86ISD::VSHLV, DL, VT, N00, N1);
     }
     // fold shl(select(icmp_uge(amt,BW),0,x),amt) -> avx2 psllv(x,amt)
-    if (Cond.getOpcode() == ISD::SETCC && Cond.getOperand(0) == N1 &&
-        cast<CondCodeSDNode>(Cond.getOperand(2))->get() == ISD::SETUGE &&
-        ISD::isConstantSplatVector(Cond.getOperand(1).getNode(), SV) &&
-        ISD::isConstantSplatVectorAllZeros(N00.getNode()) &&
-        SV == VT.getScalarSizeInBits()) {
+    if (ISD::isConstantSplatVectorAllZeros(N00.getNode()) &&
+        sd_match(Cond, m_SetCC(m_Specific(N1), m_SpecificInt(EltSizeInBits),
+                               m_SpecificCondCode(ISD::SETUGE)))) {
       return DAG.getNode(X86ISD::VSHLV, DL, VT, N01, N1);
     }
   }
@@ -48160,9 +48155,11 @@ static SDValue combineShiftRightArithmetic(SDNode *N, SelectionDAG &DAG,
 static SDValue combineShiftRightLogical(SDNode *N, SelectionDAG &DAG,
                                         TargetLowering::DAGCombinerInfo &DCI,
                                         const X86Subtarget &Subtarget) {
+  using namespace llvm::SDPatternMatch;
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
   EVT VT = N0.getValueType();
+  unsigned EltSizeInBits = VT.getScalarSizeInBits();
   SDLoc DL(N);
 
   if (SDValue V = combineShiftToPMULH(N, DAG, DL, Subtarget))
@@ -48175,21 +48172,16 @@ static SDValue combineShiftRightLogical(SDNode *N, SelectionDAG &DAG,
     SDValue Cond = N0.getOperand(0);
     SDValue N00 = N0.getOperand(1);
     SDValue N01 = N0.getOperand(2);
-    APInt SV;
     // fold srl(select(icmp_ult(amt,BW),x,0),amt) -> avx2 psrlv(x,amt)
-    if (Cond.getOpcode() == ISD::SETCC && Cond.getOperand(0) == N1 &&
-        cast<CondCodeSDNode>(Cond.getOperand(2))->get() == ISD::SETULT &&
-        ISD::isConstantSplatVector(Cond.getOperand(1).getNode(), SV) &&
-        ISD::isConstantSplatVectorAllZeros(N01.getNode()) &&
-        SV == VT.getScalarSizeInBits()) {
+    if (ISD::isConstantSplatVectorAllZeros(N01.getNode()) &&
+        sd_match(Cond, m_SetCC(m_Specific(N1), m_SpecificInt(EltSizeInBits),
+                               m_SpecificCondCode(ISD::SETULT)))) {
       return DAG.getNode(X86ISD::VSRLV, DL, VT, N00, N1);
     }
     // fold srl(select(icmp_uge(amt,BW),0,x),amt) -> avx2 psrlv(x,amt)
-    if (Cond.getOpcode() == ISD::SETCC && Cond.getOperand(0) == N1 &&
-        cast<CondCodeSDNode>(Cond.getOperand(2))->get() == ISD::SETUGE &&
-        ISD::isConstantSplatVector(Cond.getOperand(1).getNode(), SV) &&
-        ISD::isConstantSplatVectorAllZeros(N00.getNode()) &&
-        SV == VT.getScalarSizeInBits()) {
+    if (ISD::isConstantSplatVectorAllZeros(N00.getNode()) &&
+        sd_match(Cond, m_SetCC(m_Specific(N1), m_SpecificInt(EltSizeInBits),
+                               m_SpecificCondCode(ISD::SETUGE)))) {
       return DAG.getNode(X86ISD::VSRLV, DL, VT, N01, N1);
     }
   }


        


More information about the llvm-commits mailing list