[llvm] r349319 - [X86] Pull out constant splat rotation detection.

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Sun Dec 16 11:46:04 PST 2018


Author: rksimon
Date: Sun Dec 16 11:46:04 2018
New Revision: 349319

URL: http://llvm.org/viewvc/llvm-project?rev=349319&view=rev
Log:
[X86] Pull out constant splat rotation detection.

We had three different approaches for detecting a constant splat rotation amount; consistently use getTargetConstantBitsFromNode instead, and allow undef elements in the splat.

Modified:
    llvm/trunk/lib/Target/X86/X86ISelLowering.cpp

Modified: llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/X86/X86ISelLowering.cpp?rev=349319&r1=349318&r2=349319&view=diff
==============================================================================
--- llvm/trunk/lib/Target/X86/X86ISelLowering.cpp (original)
+++ llvm/trunk/lib/Target/X86/X86ISelLowering.cpp Sun Dec 16 11:46:04 2018
@@ -24727,21 +24727,31 @@ static SDValue LowerRotate(SDValue Op, c
   SDValue Amt = Op.getOperand(1);
   unsigned Opcode = Op.getOpcode();
   unsigned EltSizeInBits = VT.getScalarSizeInBits();
+  int NumElts = VT.getVectorNumElements();
+
+  // Check for constant splat rotation amount.
+  APInt UndefElts;
+  SmallVector<APInt, 32> EltBits;
+  int CstSplatIndex = -1;
+  if (getTargetConstantBitsFromNode(Amt, EltSizeInBits, UndefElts, EltBits))
+    for (int i = 0; i != NumElts; ++i)
+      if (!UndefElts[i]) {
+        if (CstSplatIndex < 0 || EltBits[i] == EltBits[CstSplatIndex]) {
+          CstSplatIndex = i;
+          continue;
+        }
+        CstSplatIndex = -1;
+        break;
+      }
 
   // AVX512 implicitly uses modulo rotation amounts.
   if (Subtarget.hasAVX512() && 32 <= EltSizeInBits) {
     // Attempt to rotate by immediate.
-    APInt UndefElts;
-    SmallVector<APInt, 16> EltBits;
-    if (getTargetConstantBitsFromNode(Amt, EltSizeInBits, UndefElts, EltBits)) {
-      if (!UndefElts && llvm::all_of(EltBits, [EltBits](APInt &V) {
-            return EltBits[0] == V;
-          })) {
-        unsigned Op = (Opcode == ISD::ROTL ? X86ISD::VROTLI : X86ISD::VROTRI);
-        uint64_t RotateAmt = EltBits[0].urem(EltSizeInBits);
-        return DAG.getNode(Op, DL, VT, R,
-                           DAG.getConstant(RotateAmt, DL, MVT::i8));
-      }
+    if (0 <= CstSplatIndex) {
+      unsigned Op = (Opcode == ISD::ROTL ? X86ISD::VROTLI : X86ISD::VROTRI);
+      uint64_t RotateAmt = EltBits[CstSplatIndex].urem(EltSizeInBits);
+      return DAG.getNode(Op, DL, VT, R,
+                         DAG.getConstant(RotateAmt, DL, MVT::i8));
     }
 
     // Else, fall-back on VPROLV/VPRORV.
@@ -24759,12 +24769,10 @@ static SDValue LowerRotate(SDValue Op, c
     assert(VT.is128BitVector() && "Only rotate 128-bit vectors!");
 
     // Attempt to rotate by immediate.
-    if (auto *BVAmt = dyn_cast<BuildVectorSDNode>(Amt)) {
-      if (auto *RotateConst = BVAmt->getConstantSplatNode()) {
-        uint64_t RotateAmt = RotateConst->getAPIntValue().urem(EltSizeInBits);
-        return DAG.getNode(X86ISD::VROTLI, DL, VT, R,
-                           DAG.getConstant(RotateAmt, DL, MVT::i8));
-      }
+    if (0 <= CstSplatIndex) {
+      uint64_t RotateAmt = EltBits[CstSplatIndex].urem(EltSizeInBits);
+      return DAG.getNode(X86ISD::VROTLI, DL, VT, R,
+                         DAG.getConstant(RotateAmt, DL, MVT::i8));
     }
 
     // Use general rotate by variable (per-element).
@@ -24781,15 +24789,14 @@ static SDValue LowerRotate(SDValue Op, c
          "Only vXi32/vXi16/vXi8 vector rotates supported");
 
   // Rotate by an uniform constant - expand back to shifts.
-  if (auto *BVAmt = dyn_cast<BuildVectorSDNode>(Amt))
-    if (BVAmt->getConstantSplatNode())
-      return SDValue();
+  if (0 <= CstSplatIndex)
+    return SDValue();
 
   // v16i8/v32i8: Split rotation into rot4/rot2/rot1 stages and select by
   // the amount bit.
   if (EltSizeInBits == 8) {
     // We don't need ModuloAmt here as we just peek at individual bits.
-    MVT ExtVT = MVT::getVectorVT(MVT::i16, VT.getVectorNumElements() / 2);
+    MVT ExtVT = MVT::getVectorVT(MVT::i16, NumElts / 2);
 
     auto SignBitSelect = [&](MVT SelVT, SDValue Sel, SDValue V0, SDValue V1) {
       if (Subtarget.hasSSE41()) {




More information about the llvm-commits mailing list