[llvm] [X86] LowerShift - pull out repeated getVectorNumElements calls. NFC. (PR #120241)

via llvm-commits llvm-commits at lists.llvm.org
Tue Dec 17 07:03:12 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-x86

Author: Simon Pilgrim (RKSimon)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/120241.diff


1 File Affected:

- (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+5-8) 


``````````diff
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index c50db59464724c..4ed0a886c9d298 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -30000,6 +30000,7 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
   SDLoc dl(Op);
   SDValue R = Op.getOperand(0);
   SDValue Amt = Op.getOperand(1);
+  unsigned NumElts = VT.getVectorNumElements();
   unsigned EltSizeInBits = VT.getScalarSizeInBits();
   bool ConstantAmt = ISD::isBuildVectorOfConstantSDNodes(Amt.getNode());
 
@@ -30069,7 +30070,6 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
   if (ConstantAmt && (VT == MVT::v8i16 || VT == MVT::v4i32 ||
                       (VT == MVT::v16i16 && Subtarget.hasInt256()))) {
     SDValue Amt1, Amt2;
-    unsigned NumElts = VT.getVectorNumElements();
     SmallVector<int, 8> ShuffleMask;
     for (unsigned i = 0; i != NumElts; ++i) {
       SDValue A = Amt->getOperand(i);
@@ -30116,7 +30116,6 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
        VT == MVT::v8i16 || VT == MVT::v16i16 || VT == MVT::v32i16) &&
       !Subtarget.hasXOP()) {
     MVT NarrowScalarVT = VT.getScalarType();
-    int NumElts = VT.getVectorNumElements();
     // We can do this extra fast if each pair of narrow elements is shifted by
     // the same amount by doing this SWAR style: use a shift to move the valid
     // bits to the right position, mask out any bits which crossed from one
@@ -30377,7 +30376,6 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
   if ((VT == MVT::v16i8 && Subtarget.hasSSSE3()) ||
       (VT == MVT::v32i8 && Subtarget.hasInt256()) ||
       (VT == MVT::v64i8 && Subtarget.hasBWI())) {
-    unsigned NumElts = VT.getVectorNumElements();
     unsigned NumLanes = VT.getSizeInBits() / 128u;
     unsigned NumEltsPerLane = NumElts / NumLanes;
     SmallVector<APInt, 16> LUT;
@@ -30417,7 +30415,7 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
     assert((!Subtarget.hasBWI() || VT == MVT::v32i8 || VT == MVT::v16i8) &&
            "Unexpected vector type");
     MVT EvtSVT = Subtarget.hasBWI() ? MVT::i16 : MVT::i32;
-    MVT ExtVT = MVT::getVectorVT(EvtSVT, VT.getVectorNumElements());
+    MVT ExtVT = MVT::getVectorVT(EvtSVT, NumElts);
     unsigned ExtOpc = Opc == ISD::SRA ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
     R = DAG.getNode(ExtOpc, dl, ExtVT, R);
     Amt = DAG.getNode(ISD::ZERO_EXTEND, dl, ExtVT, Amt);
@@ -30431,7 +30429,6 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
       (VT == MVT::v16i8 || (VT == MVT::v32i8 && Subtarget.hasInt256()) ||
        (VT == MVT::v64i8 && Subtarget.hasBWI())) &&
       !Subtarget.hasXOP()) {
-    int NumElts = VT.getVectorNumElements();
     MVT VT16 = MVT::getVectorVT(MVT::i16, NumElts / 2);
     SDValue Cst8 = DAG.getTargetConstant(8, dl, MVT::i8);
 
@@ -30477,14 +30474,14 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
   if (VT == MVT::v16i8 ||
       (VT == MVT::v32i8 && Subtarget.hasInt256() && !Subtarget.hasXOP()) ||
       (VT == MVT::v64i8 && Subtarget.hasBWI())) {
-    MVT ExtVT = MVT::getVectorVT(MVT::i16, VT.getVectorNumElements() / 2);
+    MVT ExtVT = MVT::getVectorVT(MVT::i16, NumElts / 2);
 
     auto SignBitSelect = [&](MVT SelVT, SDValue Sel, SDValue V0, SDValue V1) {
       if (VT.is512BitVector()) {
         // On AVX512BW targets we make use of the fact that VSELECT lowers
         // to a masked blend which selects bytes based just on the sign bit
         // extracted to a mask.
-        MVT MaskVT = MVT::getVectorVT(MVT::i1, VT.getVectorNumElements());
+        MVT MaskVT = MVT::getVectorVT(MVT::i1, NumElts);
         V0 = DAG.getBitcast(VT, V0);
         V1 = DAG.getBitcast(VT, V1);
         Sel = DAG.getBitcast(VT, Sel);
@@ -30611,7 +30608,7 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
       // On SSE41 targets we can use PBLENDVB which selects bytes based just on
       // the sign bit.
       if (UseSSE41) {
-        MVT ExtVT = MVT::getVectorVT(MVT::i8, VT.getVectorNumElements() * 2);
+        MVT ExtVT = MVT::getVectorVT(MVT::i8, NumElts * 2);
         V0 = DAG.getBitcast(ExtVT, V0);
         V1 = DAG.getBitcast(ExtVT, V1);
         Sel = DAG.getBitcast(ExtVT, Sel);

``````````

</details>


https://github.com/llvm/llvm-project/pull/120241


More information about the llvm-commits mailing list