[llvm] r212602 - [x86] Re-apply a variant of the x86 side of r212324 now that the rest

Chandler Carruth chandlerc at gmail.com
Wed Jul 9 03:06:58 PDT 2014


Author: chandlerc
Date: Wed Jul  9 05:06:58 2014
New Revision: 212602

URL: http://llvm.org/viewvc/llvm-project?rev=212602&view=rev
Log:
[x86] Re-apply a variant of the x86 side of r212324 now that the rest
has settled without incident, removing the x86-specific and overly
strict 'isSplatVector' routine in favor of generic and more powerful
splat detection.

The primary motivation and result of this is that the x86 backend can
now see through splats which contain undef elements. This is essential
if we are using a widening form of legalization. I've updated a test
case to also run in that mode because, before this change, the generated
code for the test case was completely scalarized.

This version of the patch much more carefully handles the undef lanes.
- We aren't overly conservative about them in the shift lowering
  (where we will never use the splat itself).
- One place where the splat would have been re-used by the existing code
  now explicitly constructs a new constant splat that will be safe.
- The broadcast lowering is much more reasonable with undefs by doing
  a correct check of whether the splat is the only user of a loaded
  value, checking that the splat actually crosses multiple lanes before
  using a broadcast, and handling broadcasts of non-constant splats.

As a consequence of the last bullet, the weird usage of vpshufd instead
of vbroadcast is gone, and we actually can lower an AVX splat with
vbroadcastss where before we emitted a really strange pattern of
a vector load and a manual splat across the vector.

Modified:
    llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
    llvm/trunk/test/CodeGen/X86/avx-splat.ll
    llvm/trunk/test/CodeGen/X86/widen_cast-4.ll

Modified: llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/X86/X86ISelLowering.cpp?rev=212602&r1=212601&r2=212602&view=diff
==============================================================================
--- llvm/trunk/lib/Target/X86/X86ISelLowering.cpp (original)
+++ llvm/trunk/lib/Target/X86/X86ISelLowering.cpp Wed Jul  9 05:06:58 2014
@@ -4858,19 +4858,6 @@ static bool ShouldXformToMOVLP(SDNode *V
   return true;
 }
 
-/// isSplatVector - Returns true if N is a BUILD_VECTOR node whose elements are
-/// all the same.
-static bool isSplatVector(SDNode *N) {
-  if (N->getOpcode() != ISD::BUILD_VECTOR)
-    return false;
-
-  SDValue SplatValue = N->getOperand(0);
-  for (unsigned i = 1, e = N->getNumOperands(); i != e; ++i)
-    if (N->getOperand(i) != SplatValue)
-      return false;
-  return true;
-}
-
 /// isZeroShuffle - Returns true if N is a VECTOR_SHUFFLE that can be resolved
 /// to an zero vector.
 /// FIXME: move to dag combiner / method on ShuffleVectorSDNode
@@ -5779,18 +5766,22 @@ static SDValue LowerVectorBroadcast(SDVa
       return SDValue();
 
     case ISD::BUILD_VECTOR: {
-      // The BUILD_VECTOR node must be a splat.
-      if (!isSplatVector(Op.getNode()))
+      auto *BVOp = cast<BuildVectorSDNode>(Op.getNode());
+      BitVector UndefElements;
+      SDValue Splat = BVOp->getSplatValue(&UndefElements);
+
+      // We need a splat of a single value to use broadcast, and it doesn't
+      // make any sense if the value is only in one element of the vector.
+      if (!Splat || (VT.getVectorNumElements() - UndefElements.count()) <= 1)
         return SDValue();
 
-      Ld = Op.getOperand(0);
+      Ld = Splat;
       ConstSplatVal = (Ld.getOpcode() == ISD::Constant ||
-                     Ld.getOpcode() == ISD::ConstantFP);
+                       Ld.getOpcode() == ISD::ConstantFP);
 
-      // The suspected load node has several users. Make sure that all
-      // of its users are from the BUILD_VECTOR node.
-      // Constants may have multiple users.
-      if (!ConstSplatVal && !Ld->hasNUsesOfValue(VT.getVectorNumElements(), 0))
+      // Make sure that all of the users of a non-constant load are from the
+      // BUILD_VECTOR node.
+      if (!ConstSplatVal && !BVOp->isOnlyUserOf(Ld.getNode()))
         return SDValue();
       break;
     }
@@ -9375,8 +9366,13 @@ X86TargetLowering::LowerVECTOR_SHUFFLE(S
   bool Commuted = false;
   // FIXME: This should also accept a bitcast of a splat?  Be careful, not
   // 1,1,1,1 -> v8i16 though.
-  V1IsSplat = isSplatVector(V1.getNode());
-  V2IsSplat = isSplatVector(V2.getNode());
+  BitVector UndefElements;
+  if (auto *BVOp = dyn_cast<BuildVectorSDNode>(V1.getNode()))
+    if (BVOp->getConstantSplatNode(&UndefElements) && UndefElements.none())
+      V1IsSplat = true;
+  if (auto *BVOp = dyn_cast<BuildVectorSDNode>(V2.getNode()))
+    if (BVOp->getConstantSplatNode(&UndefElements) && UndefElements.none())
+      V2IsSplat = true;
 
   // Canonicalize the splat or undef, if present, to be on the RHS.
   if (!V2IsUndef && V1IsSplat && !V2IsSplat) {
@@ -15171,10 +15167,9 @@ static SDValue LowerScalarImmediateShift
   SDValue Amt = Op.getOperand(1);
 
   // Optimize shl/srl/sra with constant shift amount.
-  if (isSplatVector(Amt.getNode())) {
-    SDValue SclrAmt = Amt->getOperand(0);
-    if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(SclrAmt)) {
-      uint64_t ShiftAmt = C->getZExtValue();
+  if (auto *BVAmt = dyn_cast<BuildVectorSDNode>(Amt)) {
+    if (auto *ShiftConst = BVAmt->getConstantSplatNode()) {
+      uint64_t ShiftAmt = ShiftConst->getZExtValue();
 
       if (VT == MVT::v2i64 || VT == MVT::v4i32 || VT == MVT::v8i16 ||
           (Subtarget->hasInt256() &&
@@ -19423,28 +19418,34 @@ static SDValue PerformSELECTCombine(SDNo
           Other->getOpcode() == ISD::SUB && DAG.isEqualTo(OpRHS, CondRHS))
         return DAG.getNode(X86ISD::SUBUS, DL, VT, OpLHS, OpRHS);
 
-      // If the RHS is a constant we have to reverse the const canonicalization.
-      // x > C-1 ? x+-C : 0 --> subus x, C
-      if (CC == ISD::SETUGT && Other->getOpcode() == ISD::ADD &&
-          isSplatVector(CondRHS.getNode()) && isSplatVector(OpRHS.getNode())) {
-        APInt A = cast<ConstantSDNode>(OpRHS.getOperand(0))->getAPIntValue();
-        if (CondRHS.getConstantOperandVal(0) == -A-1)
-          return DAG.getNode(X86ISD::SUBUS, DL, VT, OpLHS,
-                             DAG.getConstant(-A, VT));
-      }
-
-      // Another special case: If C was a sign bit, the sub has been
-      // canonicalized into a xor.
-      // FIXME: Would it be better to use computeKnownBits to determine whether
-      //        it's safe to decanonicalize the xor?
-      // x s< 0 ? x^C : 0 --> subus x, C
-      if (CC == ISD::SETLT && Other->getOpcode() == ISD::XOR &&
-          ISD::isBuildVectorAllZeros(CondRHS.getNode()) &&
-          isSplatVector(OpRHS.getNode())) {
-        APInt A = cast<ConstantSDNode>(OpRHS.getOperand(0))->getAPIntValue();
-        if (A.isSignBit())
-          return DAG.getNode(X86ISD::SUBUS, DL, VT, OpLHS, OpRHS);
-      }
+      if (auto *OpRHSBV = dyn_cast<BuildVectorSDNode>(OpRHS))
+        if (auto *OpRHSConst = OpRHSBV->getConstantSplatNode()) {
+          if (auto *CondRHSBV = dyn_cast<BuildVectorSDNode>(CondRHS))
+            if (auto *CondRHSConst = CondRHSBV->getConstantSplatNode())
+              // If the RHS is a constant we have to reverse the const
+              // canonicalization.
+              // x > C-1 ? x+-C : 0 --> subus x, C
+              if (CC == ISD::SETUGT && Other->getOpcode() == ISD::ADD &&
+                  CondRHSConst->getAPIntValue() ==
+                      (-OpRHSConst->getAPIntValue() - 1))
+                return DAG.getNode(
+                    X86ISD::SUBUS, DL, VT, OpLHS,
+                    DAG.getConstant(-OpRHSConst->getAPIntValue(), VT));
+
+          // Another special case: If C was a sign bit, the sub has been
+          // canonicalized into a xor.
+          // FIXME: Would it be better to use computeKnownBits to determine
+          //        whether it's safe to decanonicalize the xor?
+          // x s< 0 ? x^C : 0 --> subus x, C
+          if (CC == ISD::SETLT && Other->getOpcode() == ISD::XOR &&
+              ISD::isBuildVectorAllZeros(CondRHS.getNode()) &&
+              OpRHSConst->getAPIntValue().isSignBit())
+            // Note that we have to rebuild the RHS constant here to ensure we
+            // don't rely on particular values of undef lanes.
+            return DAG.getNode(
+                X86ISD::SUBUS, DL, VT, OpLHS,
+                DAG.getConstant(OpRHSConst->getAPIntValue(), VT));
+        }
     }
   }
 
@@ -20152,16 +20153,15 @@ static SDValue PerformSHLCombine(SDNode
   // vector operations in many cases. Also, on sandybridge ADD is faster than
   // shl.
   // (shl V, 1) -> add V,V
-  if (isSplatVector(N1.getNode())) {
-    assert(N0.getValueType().isVector() && "Invalid vector shift type");
-    ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1->getOperand(0));
-    // We shift all of the values by one. In many cases we do not have
-    // hardware support for this operation. This is better expressed as an ADD
-    // of two values.
-    if (N1C && (1 == N1C->getZExtValue())) {
-      return DAG.getNode(ISD::ADD, SDLoc(N), VT, N0, N0);
+  if (auto *N1BV = dyn_cast<BuildVectorSDNode>(N1))
+    if (auto *N1SplatC = N1BV->getConstantSplatNode()) {
+      assert(N0.getValueType().isVector() && "Invalid vector shift type");
+      // We shift all of the values by one. In many cases we do not have
+      // hardware support for this operation. This is better expressed as an ADD
+      // of two values.
+      if (N1SplatC->getZExtValue() == 1)
+        return DAG.getNode(ISD::ADD, SDLoc(N), VT, N0, N0);
     }
-  }
 
   return SDValue();
 }
@@ -20180,10 +20180,9 @@ static SDValue performShiftToAllZeros(SD
 
   SDValue Amt = N->getOperand(1);
   SDLoc DL(N);
-  if (isSplatVector(Amt.getNode())) {
-    SDValue SclrAmt = Amt->getOperand(0);
-    if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(SclrAmt)) {
-      APInt ShiftAmt = C->getAPIntValue();
+  if (auto *AmtBV = dyn_cast<BuildVectorSDNode>(Amt))
+    if (auto *AmtSplat = AmtBV->getConstantSplatNode()) {
+      APInt ShiftAmt = AmtSplat->getAPIntValue();
       unsigned MaxAmount = VT.getVectorElementType().getSizeInBits();
 
       // SSE2/AVX2 logical shifts always return a vector of 0s
@@ -20193,7 +20192,6 @@ static SDValue performShiftToAllZeros(SD
       if (ShiftAmt.trunc(8).uge(MaxAmount))
         return getZeroVector(VT, Subtarget, DAG, DL);
     }
-  }
 
   return SDValue();
 }
@@ -20387,9 +20385,10 @@ static SDValue WidenMaskArithmetic(SDNod
 
   // The right side has to be a 'trunc' or a constant vector.
   bool RHSTrunc = N1.getOpcode() == ISD::TRUNCATE;
-  bool RHSConst = (isSplatVector(N1.getNode()) &&
-                   isa<ConstantSDNode>(N1->getOperand(0)));
-  if (!RHSTrunc && !RHSConst)
+  ConstantSDNode *RHSConstSplat;
+  if (auto *RHSBV = dyn_cast<BuildVectorSDNode>(N1))
+    RHSConstSplat = RHSBV->getConstantSplatNode();
+  if (!RHSTrunc && !RHSConstSplat)
     return SDValue();
 
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
@@ -20399,9 +20398,9 @@ static SDValue WidenMaskArithmetic(SDNod
 
   // Set N0 and N1 to hold the inputs to the new wide operation.
   N0 = N0->getOperand(0);
-  if (RHSConst) {
+  if (RHSConstSplat) {
     N1 = DAG.getNode(ISD::ZERO_EXTEND, DL, WideVT.getScalarType(),
-                     N1->getOperand(0));
+                     SDValue(RHSConstSplat, 0));
     SmallVector<SDValue, 8> C(WideVT.getVectorNumElements(), N1);
     N1 = DAG.getNode(ISD::BUILD_VECTOR, DL, WideVT, C);
   } else if (RHSTrunc) {
@@ -20547,12 +20546,9 @@ static SDValue PerformOrCombine(SDNode *
       unsigned EltBits = MaskVT.getVectorElementType().getSizeInBits();
       unsigned SraAmt = ~0;
       if (Mask.getOpcode() == ISD::SRA) {
-        SDValue Amt = Mask.getOperand(1);
-        if (isSplatVector(Amt.getNode())) {
-          SDValue SclrAmt = Amt->getOperand(0);
-          if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(SclrAmt))
-            SraAmt = C->getZExtValue();
-        }
+        if (auto *AmtBV = dyn_cast<BuildVectorSDNode>(Mask.getOperand(1)))
+          if (auto *AmtConst = AmtBV->getConstantSplatNode())
+            SraAmt = AmtConst->getZExtValue();
       } else if (Mask.getOpcode() == X86ISD::VSRAI) {
         SDValue SraC = Mask.getOperand(1);
         SraAmt  = cast<ConstantSDNode>(SraC)->getZExtValue();

Modified: llvm/trunk/test/CodeGen/X86/avx-splat.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/CodeGen/X86/avx-splat.ll?rev=212602&r1=212601&r2=212602&view=diff
==============================================================================
--- llvm/trunk/test/CodeGen/X86/avx-splat.ll (original)
+++ llvm/trunk/test/CodeGen/X86/avx-splat.ll Wed Jul  9 05:06:58 2014
@@ -43,13 +43,10 @@ entry:
   ret <4 x double> %vecinit6.i
 }
 
-; Test this simple opt:
+; Test this turns into a broadcast:
 ;   shuffle (scalar_to_vector (load (ptr + 4))), undef, <0, 0, 0, 0>
-; To:
-;   shuffle (vload ptr)), undef, <1, 1, 1, 1>
-; CHECK: vmovdqa
-; CHECK-NEXT: vpshufd $-1
-; CHECK-NEXT: vinsertf128  $1
+;   
+; CHECK: vbroadcastss
 define <8 x float> @funcE() nounwind {
 allocas:
   %udx495 = alloca [18 x [18 x float]], align 32

Modified: llvm/trunk/test/CodeGen/X86/widen_cast-4.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/CodeGen/X86/widen_cast-4.ll?rev=212602&r1=212601&r2=212602&view=diff
==============================================================================
--- llvm/trunk/test/CodeGen/X86/widen_cast-4.ll (original)
+++ llvm/trunk/test/CodeGen/X86/widen_cast-4.ll Wed Jul  9 05:06:58 2014
@@ -1,8 +1,9 @@
 ; RUN: llc < %s -march=x86 -mattr=+sse4.2 | FileCheck %s
-; CHECK: psraw
-; CHECK: psraw
+; RUN: llc < %s -march=x86 -mattr=+sse4.2 -x86-experimental-vector-widening-legalization | FileCheck %s --check-prefix=CHECK-WIDE
 
 define void @update(i64* %dst_i, i64* %src_i, i32 %n) nounwind {
+; CHECK-LABEL: update:
+; CHECK-WIDE-LABEL: update:
 entry:
 	%dst_i.addr = alloca i64*		; <i64**> [#uses=2]
 	%src_i.addr = alloca i64*		; <i64**> [#uses=2]
@@ -44,6 +45,26 @@ forbody:		; preds = %forcond
 	%shr = ashr <8 x i8> %add, < i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2 >		; <<8 x i8>> [#uses=1]
 	store <8 x i8> %shr, <8 x i8>* %arrayidx10
 	br label %forinc
+; CHECK: %forbody
+; CHECK:      pmovzxbw
+; CHECK-NEXT: paddw
+; CHECK-NEXT: psllw $8
+; CHECK-NEXT: psraw $8
+; CHECK-NEXT: psraw $2
+; CHECK-NEXT: pshufb
+; CHECK-NEXT: movlpd
+;
+; FIXME: We shouldn't require both a movd and an insert.
+; CHECK-WIDE: %forbody
+; CHECK-WIDE:      movd
+; CHECK-WIDE-NEXT: pinsrd
+; CHECK-WIDE-NEXT: paddb
+; CHECK-WIDE-NEXT: psrlw $2
+; CHECK-WIDE-NEXT: pand
+; CHECK-WIDE-NEXT: pxor
+; CHECK-WIDE-NEXT: psubb
+; CHECK-WIDE-NEXT: pextrd
+; CHECK-WIDE-NEXT: movd
 
 forinc:		; preds = %forbody
 	%tmp15 = load i32* %i		; <i32> [#uses=1]





More information about the llvm-commits mailing list