[llvm] r343383 - [X86][SSE] LowerScalarImmediateShift - use getTargetConstantBitsFromNode to get immediate data

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Sat Sep 29 09:40:35 PDT 2018


Author: rksimon
Date: Sat Sep 29 09:40:35 2018
New Revision: 343383

URL: http://llvm.org/viewvc/llvm-project?rev=343383&view=rev
Log:
[X86][SSE] LowerScalarImmediateShift - use getTargetConstantBitsFromNode to get immediate data

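Don't just attempt to find a splat build vector; instead, extract the constant shift-amount bits with getTargetConstantBitsFromNode and treat the amount as uniform if all non-undef elements share the same value.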

First step towards getting rid of all the 32-bit special case code.

Modified:
    llvm/trunk/lib/Target/X86/X86ISelLowering.cpp

Modified: llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/X86/X86ISelLowering.cpp?rev=343383&r1=343382&r2=343383&view=diff
==============================================================================
--- llvm/trunk/lib/Target/X86/X86ISelLowering.cpp (original)
+++ llvm/trunk/lib/Target/X86/X86ISelLowering.cpp Sat Sep 29 09:40:35 2018
@@ -23422,6 +23422,7 @@ static SDValue LowerScalarImmediateShift
   SDLoc dl(Op);
   SDValue R = Op.getOperand(0);
   SDValue Amt = Op.getOperand(1);
+  unsigned EltSizeInBits = VT.getScalarSizeInBits();
   unsigned X86Opc = getTargetVShiftUniformOpcode(Op.getOpcode(), false);
 
   auto ArithmeticShiftRight64 = [&](uint64_t ShiftAmt) {
@@ -23465,74 +23466,83 @@ static SDValue LowerScalarImmediateShift
   };
 
   // Optimize shl/srl/sra with constant shift amount.
-  if (auto *BVAmt = dyn_cast<BuildVectorSDNode>(Amt)) {
-    if (auto *ShiftConst = BVAmt->getConstantSplatNode()) {
-      uint64_t ShiftAmt = ShiftConst->getZExtValue();
-
-      if (SupportedVectorShiftWithImm(VT, Subtarget, Op.getOpcode()))
-        return getTargetVShiftByConstNode(X86Opc, dl, VT, R, ShiftAmt, DAG);
-
-      // i64 SRA needs to be performed as partial shifts.
-      if (((!Subtarget.hasXOP() && VT == MVT::v2i64) ||
-           (Subtarget.hasInt256() && VT == MVT::v4i64)) &&
-          Op.getOpcode() == ISD::SRA)
-        return ArithmeticShiftRight64(ShiftAmt);
-
-      if (VT == MVT::v16i8 ||
-          (Subtarget.hasInt256() && VT == MVT::v32i8) ||
-          VT == MVT::v64i8) {
-        unsigned NumElts = VT.getVectorNumElements();
-        MVT ShiftVT = MVT::getVectorVT(MVT::i16, NumElts / 2);
-
-        // Simple i8 add case
-        if (Op.getOpcode() == ISD::SHL && ShiftAmt == 1)
-          return DAG.getNode(ISD::ADD, dl, VT, R, R);
-
-        // ashr(R, 7)  === cmp_slt(R, 0)
-        if (Op.getOpcode() == ISD::SRA && ShiftAmt == 7) {
-          SDValue Zeros = getZeroVector(VT, Subtarget, DAG, dl);
-          if (VT.is512BitVector()) {
-            assert(VT == MVT::v64i8 && "Unexpected element type!");
-            SDValue CMP = DAG.getSetCC(dl, MVT::v64i1, Zeros, R,
-                                       ISD::SETGT);
-            return DAG.getNode(ISD::SIGN_EXTEND, dl, VT, CMP);
-          }
-          return DAG.getNode(X86ISD::PCMPGT, dl, VT, Zeros, R);
-        }
+  APInt UndefElts;
+  SmallVector<APInt, 8> EltBits;
+  if (getTargetConstantBitsFromNode(Amt, EltSizeInBits, UndefElts, EltBits,
+                                    true, false)) {
+    int SplatIndex = -1;
+    for (int i = 0, e = VT.getVectorNumElements(); i != e; ++i) {
+      if (UndefElts[i])
+        continue;
+      if (0 <= SplatIndex && EltBits[i] != EltBits[SplatIndex])
+        return SDValue();
+      SplatIndex = i;
+    }
+    if (SplatIndex < 0)
+      return SDValue();
 
-        // XOP can shift v16i8 directly instead of as shift v8i16 + mask.
-        if (VT == MVT::v16i8 && Subtarget.hasXOP())
-          return SDValue();
-
-        if (Op.getOpcode() == ISD::SHL) {
-          // Make a large shift.
-          SDValue SHL = getTargetVShiftByConstNode(X86ISD::VSHLI, dl, ShiftVT,
-                                                   R, ShiftAmt, DAG);
-          SHL = DAG.getBitcast(VT, SHL);
-          // Zero out the rightmost bits.
-          return DAG.getNode(ISD::AND, dl, VT, SHL,
-                             DAG.getConstant(uint8_t(-1U << ShiftAmt), dl, VT));
-        }
-        if (Op.getOpcode() == ISD::SRL) {
-          // Make a large shift.
-          SDValue SRL = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, ShiftVT,
-                                                   R, ShiftAmt, DAG);
-          SRL = DAG.getBitcast(VT, SRL);
-          // Zero out the leftmost bits.
-          return DAG.getNode(ISD::AND, dl, VT, SRL,
-                             DAG.getConstant(uint8_t(-1U) >> ShiftAmt, dl, VT));
+    uint64_t ShiftAmt = EltBits[SplatIndex].getZExtValue();
+    if (SupportedVectorShiftWithImm(VT, Subtarget, Op.getOpcode()))
+      return getTargetVShiftByConstNode(X86Opc, dl, VT, R, ShiftAmt, DAG);
+
+    // i64 SRA needs to be performed as partial shifts.
+    if (((!Subtarget.hasXOP() && VT == MVT::v2i64) ||
+         (Subtarget.hasInt256() && VT == MVT::v4i64)) &&
+        Op.getOpcode() == ISD::SRA)
+      return ArithmeticShiftRight64(ShiftAmt);
+
+    if (VT == MVT::v16i8 || (Subtarget.hasInt256() && VT == MVT::v32i8) ||
+        VT == MVT::v64i8) {
+      unsigned NumElts = VT.getVectorNumElements();
+      MVT ShiftVT = MVT::getVectorVT(MVT::i16, NumElts / 2);
+
+      // Simple i8 add case
+      if (Op.getOpcode() == ISD::SHL && ShiftAmt == 1)
+        return DAG.getNode(ISD::ADD, dl, VT, R, R);
+
+      // ashr(R, 7)  === cmp_slt(R, 0)
+      if (Op.getOpcode() == ISD::SRA && ShiftAmt == 7) {
+        SDValue Zeros = getZeroVector(VT, Subtarget, DAG, dl);
+        if (VT.is512BitVector()) {
+          assert(VT == MVT::v64i8 && "Unexpected element type!");
+          SDValue CMP = DAG.getSetCC(dl, MVT::v64i1, Zeros, R, ISD::SETGT);
+          return DAG.getNode(ISD::SIGN_EXTEND, dl, VT, CMP);
         }
-        if (Op.getOpcode() == ISD::SRA) {
-          // ashr(R, Amt) === sub(xor(lshr(R, Amt), Mask), Mask)
-          SDValue Res = DAG.getNode(ISD::SRL, dl, VT, R, Amt);
-
-          SDValue Mask = DAG.getConstant(128 >> ShiftAmt, dl, VT);
-          Res = DAG.getNode(ISD::XOR, dl, VT, Res, Mask);
-          Res = DAG.getNode(ISD::SUB, dl, VT, Res, Mask);
-          return Res;
-        }
-        llvm_unreachable("Unknown shift opcode.");
+        return DAG.getNode(X86ISD::PCMPGT, dl, VT, Zeros, R);
+      }
+
+      // XOP can shift v16i8 directly instead of as shift v8i16 + mask.
+      if (VT == MVT::v16i8 && Subtarget.hasXOP())
+        return SDValue();
+
+      if (Op.getOpcode() == ISD::SHL) {
+        // Make a large shift.
+        SDValue SHL = getTargetVShiftByConstNode(X86ISD::VSHLI, dl, ShiftVT, R,
+                                                 ShiftAmt, DAG);
+        SHL = DAG.getBitcast(VT, SHL);
+        // Zero out the rightmost bits.
+        return DAG.getNode(ISD::AND, dl, VT, SHL,
+                           DAG.getConstant(uint8_t(-1U << ShiftAmt), dl, VT));
+      }
+      if (Op.getOpcode() == ISD::SRL) {
+        // Make a large shift.
+        SDValue SRL = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, ShiftVT, R,
+                                                 ShiftAmt, DAG);
+        SRL = DAG.getBitcast(VT, SRL);
+        // Zero out the leftmost bits.
+        return DAG.getNode(ISD::AND, dl, VT, SRL,
+                           DAG.getConstant(uint8_t(-1U) >> ShiftAmt, dl, VT));
+      }
+      if (Op.getOpcode() == ISD::SRA) {
+        // ashr(R, Amt) === sub(xor(lshr(R, Amt), Mask), Mask)
+        SDValue Res = DAG.getNode(ISD::SRL, dl, VT, R, Amt);
+
+        SDValue Mask = DAG.getConstant(128 >> ShiftAmt, dl, VT);
+        Res = DAG.getNode(ISD::XOR, dl, VT, Res, Mask);
+        Res = DAG.getNode(ISD::SUB, dl, VT, Res, Mask);
+        return Res;
       }
+      llvm_unreachable("Unknown shift opcode.");
     }
   }
 




More information about the llvm-commits mailing list