[llvm] 43e357f - [X86] Update sra(x,umin(amt,bw-1)) -> psrav(x,amt) fold to use SDPatternMatch. NFC.

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Fri Jul 12 03:36:15 PDT 2024


Author: Simon Pilgrim
Date: 2024-07-12T11:35:53+01:00
New Revision: 43e357fa60383b48f1debccbb6ea63a8a8583722

URL: https://github.com/llvm/llvm-project/commit/43e357fa60383b48f1debccbb6ea63a8a8583722
DIFF: https://github.com/llvm/llvm-project/commit/43e357fa60383b48f1debccbb6ea63a8a8583722.diff

LOG: [X86] Update sra(x,umin(amt,bw-1)) -> psrav(x,amt) fold to use SDPatternMatch. NFC.

First tentative attempt at using SDPatternMatch for x86 combine matching; the main problem so far is namespace clashing when trying to expose llvm::SDPatternMatch to the entire file.

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86ISelLowering.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 9f52d670e4b37..9b3aeb2fd1803 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -38,6 +38,7 @@
 #include "llvm/CodeGen/MachineLoopInfo.h"
 #include "llvm/CodeGen/MachineModuleInfo.h"
 #include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/CodeGen/SDPatternMatch.h"
 #include "llvm/CodeGen/TargetLowering.h"
 #include "llvm/CodeGen/WinEHFuncInfo.h"
 #include "llvm/IR/CallingConv.h"
@@ -48084,22 +48085,22 @@ static SDValue combineShiftLeft(SDNode *N, SelectionDAG &DAG,
 
 static SDValue combineShiftRightArithmetic(SDNode *N, SelectionDAG &DAG,
                                            const X86Subtarget &Subtarget) {
+  using namespace llvm::SDPatternMatch;
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
   EVT VT = N0.getValueType();
   unsigned Size = VT.getSizeInBits();
+  SDLoc DL(N);
 
   if (SDValue V = combineShiftToPMULH(N, DAG, Subtarget))
     return V;
 
-  APInt ShiftAmt;
-  if (supportedVectorVarShift(VT, Subtarget, ISD::SRA) &&
-      N1.getOpcode() == ISD::UMIN &&
-      ISD::isConstantSplatVector(N1.getOperand(1).getNode(), ShiftAmt) &&
-      ShiftAmt == VT.getScalarSizeInBits() - 1) {
-    SDValue ShrAmtVal = N1.getOperand(0);
-    SDLoc DL(N);
-    return DAG.getNode(X86ISD::VSRAV, DL, N->getVTList(), N0, ShrAmtVal);
+  // fold sra(x,umin(amt,bw-1)) -> avx2 psrav(x,amt)
+  if (supportedVectorVarShift(VT, Subtarget, ISD::SRA)) {
+    SDValue ShrAmtVal;
+    if (sd_match(N1, m_UMin(m_Value(ShrAmtVal),
+                            m_SpecificInt(VT.getScalarSizeInBits() - 1))))
+      return DAG.getNode(X86ISD::VSRAV, DL, VT, N0, ShrAmtVal);
   }
 
   // fold (SRA (SHL X, ShlConst), SraConst)
@@ -48137,7 +48138,6 @@ static SDValue combineShiftRightArithmetic(SDNode *N, SelectionDAG &DAG,
     // Only deal with (Size - ShlConst) being equal to 8, 16 or 32.
     if (ShiftSize >= Size || ShlConst != Size - ShiftSize)
       continue;
-    SDLoc DL(N);
     SDValue NN =
         DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, N00, DAG.getValueType(SVT));
     if (SraConst.eq(ShlConst))


        


More information about the llvm-commits mailing list