[llvm] a310cba - [X86] Add getAVX512Node helper. NFC.

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Sat Nov 13 06:00:10 PST 2021


Author: Simon Pilgrim
Date: 2021-11-13T13:59:42Z
New Revision: a310cbae02248023b11eefc8d1663661ac1f7721

URL: https://github.com/llvm/llvm-project/commit/a310cbae02248023b11eefc8d1663661ac1f7721
DIFF: https://github.com/llvm/llvm-project/commit/a310cbae02248023b11eefc8d1663661ac1f7721.diff

LOG: [X86] Add getAVX512Node helper. NFC.

For AVX512 targets without VLX, we have to widen 128/256-bit vectors to 512 bits to use some specific AVX512 instructions (or some other instructions with predicates, etc.).

I've pulled out the widening code from LowerFunnelShift into the helper function, so we can convert some other widening patterns in the future.

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86ISelLowering.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 71dfd090142e7..0d152d65022c2 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -6393,6 +6393,37 @@ SDValue SplitOpsAndApply(SelectionDAG &DAG, const X86Subtarget &Subtarget,
   return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Subs);
 }
 
+// Build an AVX512 node of type VT. On non-VLX targets, non-512-bit vector
+// operands are widened to 512 bits and the low VT-sized result is returned.
+static SDValue getAVX512Node(unsigned Opcode, const SDLoc &DL, MVT VT,
+                             ArrayRef<SDValue> Ops, SelectionDAG &DAG,
+                             const X86Subtarget &Subtarget) {
+  assert(Subtarget.hasAVX512() && "AVX512 target expected");
+
+  // With VLX, or when VT is already 512 bits wide, no widening is needed.
+  if (Subtarget.hasVLX() || VT.is512BitVector())
+    return DAG.getNode(Opcode, DL, VT, Ops);
+
+  // Without VLX, widen each vector operand to 512 bits.
+  MVT EltVT = VT.getScalarType();
+  MVT WideVT = MVT::getVectorVT(EltVT, 512 / EltVT.getSizeInBits());
+  SmallVector<SDValue> WideOps;
+  for (SDValue Operand : Ops) {
+    MVT OperandVT = Operand.getSimpleValueType();
+    // Scalar operands (e.g. immediates) are forwarded unchanged.
+    if (OperandVT.isVector()) {
+      assert(OperandVT.getSizeInBits() == VT.getSizeInBits() &&
+             "Vector size mismatch");
+      Operand = widenSubVector(Operand, false, Subtarget, DAG, DL, 512);
+    }
+    WideOps.push_back(Operand);
+  }
+
+  // Emit the 512-bit node, then extract the original-width low subvector.
+  SDValue WideRes = DAG.getNode(Opcode, DL, WideVT, WideOps);
+  return extractSubVector(WideRes, 0, DAG, DL, VT.getSizeInBits());
+}
+
 /// Insert i1-subvector to i1-vector.
 static SDValue insert1BitVector(SDValue Op, SelectionDAG &DAG,
                                 const X86Subtarget &Subtarget) {
@@ -29593,29 +29624,15 @@ static SDValue LowerFunnelShift(SDValue Op, const X86Subtarget &Subtarget,
     if (IsFSHR)
       std::swap(Op0, Op1);
 
-    // With AVX512, but not VLX we need to widen to get a 512-bit result type.
-    if (!Subtarget.hasVLX() && !VT.is512BitVector()) {
-      Op0 = widenSubVector(Op0, false, Subtarget, DAG, DL, 512);
-      Op1 = widenSubVector(Op1, false, Subtarget, DAG, DL, 512);
-    }
-
-    SDValue Funnel;
     APInt APIntShiftAmt;
-    MVT ResultVT = Op0.getSimpleValueType();
     if (X86::isConstantSplat(Amt, APIntShiftAmt)) {
       uint64_t ShiftAmt = APIntShiftAmt.urem(VT.getScalarSizeInBits());
-      Funnel =
-          DAG.getNode(IsFSHR ? X86ISD::VSHRD : X86ISD::VSHLD, DL, ResultVT, Op0,
-                      Op1, DAG.getTargetConstant(ShiftAmt, DL, MVT::i8));
-    } else {
-      if (!Subtarget.hasVLX() && !VT.is512BitVector())
-        Amt = widenSubVector(Amt, false, Subtarget, DAG, DL, 512);
-      Funnel = DAG.getNode(IsFSHR ? X86ISD::VSHRDV : X86ISD::VSHLDV, DL,
-                           ResultVT, Op0, Op1, Amt);
-    }
-    if (!Subtarget.hasVLX() && !VT.is512BitVector())
-      Funnel = extractSubVector(Funnel, 0, DAG, DL, VT.getSizeInBits());
-    return Funnel;
+      SDValue Imm = DAG.getTargetConstant(ShiftAmt, DL, MVT::i8);
+      return getAVX512Node(IsFSHR ? X86ISD::VSHRD : X86ISD::VSHLD, DL, VT,
+                           {Op0, Op1, Imm}, DAG, Subtarget);
+    }
+    return getAVX512Node(IsFSHR ? X86ISD::VSHRDV : X86ISD::VSHLDV, DL, VT,
+                         {Op0, Op1, Amt}, DAG, Subtarget);
   }
   assert(
       (VT == MVT::i8 || VT == MVT::i16 || VT == MVT::i32 || VT == MVT::i64) &&


        


More information about the llvm-commits mailing list