[llvm] r284953 - [DAG] enhance computeKnownBits to handle SRL/SRA with vector splat constant

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Sun Oct 23 16:13:31 PDT 2016


Author: spatel
Date: Sun Oct 23 18:13:31 2016
New Revision: 284953

URL: http://llvm.org/viewvc/llvm-project?rev=284953&view=rev
Log:
[DAG] enhance computeKnownBits to handle SRL/SRA with vector splat constant
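
Previously, the SRL and SRA cases only handled a shift amount that was a single
ConstantSDNode; the new getValidShiftAmountConstant() helper also accepts a
splat-constant vector shift amount (and is shared with the SHL case), so DAG
combines that rely on computeKnownBits now fire on the vector forms as well. As
an illustrative walk-through of the effect, mirroring the updated
combine-srl.ll test below: each lane of (x & <15, 15, 15, 15>) has bits 4..31
known zero, so a logical shift right by the splat amount 4 leaves every bit
known zero and the whole expression folds to a zero vector.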

Modified:
    llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
    llvm/trunk/test/CodeGen/X86/combine-sra.ll
    llvm/trunk/test/CodeGen/X86/combine-srl.ll

Modified: llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp?rev=284953&r1=284952&r2=284953&view=diff
==============================================================================
--- llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp (original)
+++ llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp Sun Oct 23 18:13:31 2016
@@ -1996,6 +1996,18 @@ bool SelectionDAG::MaskedValueIsZero(SDV
   return (KnownZero & Mask) == Mask;
 }
 
+/// If a SHL/SRA/SRL node has a constant or splat constant shift amount that
+/// is less than the element bit-width of the shift node, return it.
+static const APInt *getValidShiftAmountConstant(SDValue V) {
+  if (ConstantSDNode *SA = isConstOrConstSplat(V.getOperand(1))) {
+    // Shifting more than the bitwidth is not valid.
+    const APInt &ShAmt = SA->getAPIntValue();
+    if (ShAmt.ult(V.getScalarValueSizeInBits()))
+      return &ShAmt;
+  }
+  return nullptr;
+}
+
 /// Determine which bits of Op are known to be either zero or one and return
 /// them in the KnownZero/KnownOne bitsets.
 void SelectionDAG::computeKnownBits(SDValue Op, APInt &KnownZero,
@@ -2144,57 +2156,34 @@ void SelectionDAG::computeKnownBits(SDVa
       KnownZero |= APInt::getHighBitsSet(BitWidth, BitWidth - 1);
     break;
   case ISD::SHL:
-    if (ConstantSDNode *SA = isConstOrConstSplat(Op.getOperand(1))) {
-      // If the shift count is an invalid immediate, don't do anything.
-      APInt ShAmt = SA->getAPIntValue();
-      if (ShAmt.uge(BitWidth))
-        break;
-
+    if (const APInt *ShAmt = getValidShiftAmountConstant(Op)) {
       computeKnownBits(Op.getOperand(0), KnownZero, KnownOne, Depth + 1);
-      KnownZero = KnownZero << ShAmt;
-      KnownOne = KnownOne << ShAmt;
-      // low bits known zero.
-      KnownZero |= APInt::getLowBitsSet(BitWidth, ShAmt.getZExtValue());
+      KnownZero = KnownZero << *ShAmt;
+      KnownOne = KnownOne << *ShAmt;
+      // Low bits are known zero.
+      KnownZero |= APInt::getLowBitsSet(BitWidth, ShAmt->getZExtValue());
     }
     break;
   case ISD::SRL:
-    // FIXME: Reuse isConstOrConstSplat + APInt from above.
-    if (ConstantSDNode *SA = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
-      unsigned ShAmt = SA->getZExtValue();
-
-      // If the shift count is an invalid immediate, don't do anything.
-      if (ShAmt >= BitWidth)
-        break;
-
-      computeKnownBits(Op.getOperand(0), KnownZero, KnownOne, Depth+1);
-      KnownZero = KnownZero.lshr(ShAmt);
-      KnownOne  = KnownOne.lshr(ShAmt);
-
-      APInt HighBits = APInt::getHighBitsSet(BitWidth, ShAmt);
-      KnownZero |= HighBits;  // High bits known zero.
+    if (const APInt *ShAmt = getValidShiftAmountConstant(Op)) {
+      computeKnownBits(Op.getOperand(0), KnownZero, KnownOne, Depth + 1);
+      KnownZero = KnownZero.lshr(*ShAmt);
+      KnownOne  = KnownOne.lshr(*ShAmt);
+      // High bits are known zero.
+      APInt HighBits = APInt::getHighBitsSet(BitWidth, ShAmt->getZExtValue());
+      KnownZero |= HighBits;
     }
     break;
   case ISD::SRA:
-    // FIXME: Reuse isConstOrConstSplat + APInt from above.
-    if (ConstantSDNode *SA = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
-      unsigned ShAmt = SA->getZExtValue();
-
-      // If the shift count is an invalid immediate, don't do anything.
-      if (ShAmt >= BitWidth)
-        break;
-
-      // If any of the demanded bits are produced by the sign extension, we also
-      // demand the input sign bit.
-      APInt HighBits = APInt::getHighBitsSet(BitWidth, ShAmt);
-
-      computeKnownBits(Op.getOperand(0), KnownZero, KnownOne, Depth+1);
-      KnownZero = KnownZero.lshr(ShAmt);
-      KnownOne  = KnownOne.lshr(ShAmt);
-
-      // Handle the sign bits.
+    if (const APInt *ShAmt = getValidShiftAmountConstant(Op)) {
+      computeKnownBits(Op.getOperand(0), KnownZero, KnownOne, Depth + 1);
+      KnownZero = KnownZero.lshr(*ShAmt);
+      KnownOne  = KnownOne.lshr(*ShAmt);
+      // If we know the value of the sign bit, then we know it is copied across
+      // the high bits by the shift amount.
+      APInt HighBits = APInt::getHighBitsSet(BitWidth, ShAmt->getZExtValue());
       APInt SignBit = APInt::getSignBit(BitWidth);
-      SignBit = SignBit.lshr(ShAmt);  // Adjust to where it is now in the mask.
-
+      SignBit = SignBit.lshr(*ShAmt);  // Adjust to where it is now in the mask.
       if (KnownZero.intersects(SignBit)) {
         KnownZero |= HighBits;  // New bits are known zero.
       } else if (KnownOne.intersects(SignBit)) {

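As a minimal, self-contained scalar model of the rules above (illustrative
only; KnownBits32, knownBitsLshr, and knownBitsAshr are hypothetical names, not
LLVM's API), the following sketch shows why the masked splat shifts in the
tests below fold to all-zero vectors:

#include <cassert>
#include <cstdint>

struct KnownBits32 {
  uint32_t Zero; // bits known to be 0
  uint32_t One;  // bits known to be 1
};

// SRL case: shift both masks right; the vacated high bits become known zero.
static KnownBits32 knownBitsLshr(KnownBits32 In, unsigned ShAmt) {
  assert(ShAmt < 32 && "shift amount must be less than the bit width");
  uint32_t HighBits = ShAmt ? ~(UINT32_MAX >> ShAmt) : 0;
  return {(In.Zero >> ShAmt) | HighBits, In.One >> ShAmt};
}

// SRA case: shift both masks right; a known sign bit is copied across the
// vacated high bits.
static KnownBits32 knownBitsAshr(KnownBits32 In, unsigned ShAmt) {
  assert(ShAmt < 32 && "shift amount must be less than the bit width");
  uint32_t HighBits = ShAmt ? ~(UINT32_MAX >> ShAmt) : 0;
  uint32_t SignBit = UINT32_C(1) << 31;
  KnownBits32 Out = {In.Zero >> ShAmt, In.One >> ShAmt};
  if (In.Zero & SignBit)
    Out.Zero |= HighBits; // sign bit known zero -> high bits known zero
  else if (In.One & SignBit)
    Out.One |= HighBits;  // sign bit known one -> high bits known one
  return Out;
}

int main() {
  // combine_vec_lshr_known_zero0: (x & 15) has bits 4..31 known zero, so a
  // logical shift right by 4 leaves every bit known zero.
  assert(knownBitsLshr({~UINT32_C(15), 0}, 4).Zero == UINT32_MAX);
  // combine_vec_ashr_positive_splat: (x & 1023) has a known-zero sign bit, so
  // an arithmetic shift right by 10 also leaves every bit known zero.
  assert(knownBitsAshr({~UINT32_C(1023), 0}, 10).Zero == UINT32_MAX);
  return 0;
}
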
Modified: llvm/trunk/test/CodeGen/X86/combine-sra.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/CodeGen/X86/combine-sra.ll?rev=284953&r1=284952&r2=284953&view=diff
==============================================================================
--- llvm/trunk/test/CodeGen/X86/combine-sra.ll (original)
+++ llvm/trunk/test/CodeGen/X86/combine-sra.ll Sun Oct 23 18:13:31 2016
@@ -300,15 +300,12 @@ define <4 x i32> @combine_vec_ashr_posit
 define <4 x i32> @combine_vec_ashr_positive_splat(<4 x i32> %x, <4 x i32> %y) {
 ; SSE-LABEL: combine_vec_ashr_positive_splat:
 ; SSE:       # BB#0:
-; SSE-NEXT:    pand {{.*}}(%rip), %xmm0
-; SSE-NEXT:    psrld $10, %xmm0
+; SSE-NEXT:    xorps %xmm0, %xmm0
 ; SSE-NEXT:    retq
 ;
 ; AVX-LABEL: combine_vec_ashr_positive_splat:
 ; AVX:       # BB#0:
-; AVX-NEXT:    vpbroadcastd {{.*}}(%rip), %xmm1
-; AVX-NEXT:    vpand %xmm1, %xmm0, %xmm0
-; AVX-NEXT:    vpsrld $10, %xmm0, %xmm0
+; AVX-NEXT:    vxorps %xmm0, %xmm0, %xmm0
 ; AVX-NEXT:    retq
   %1 = and <4 x i32> %x, <i32 1023, i32 1023, i32 1023, i32 1023>
   %2 = ashr <4 x i32> %1, <i32 10, i32 10, i32 10, i32 10>

Modified: llvm/trunk/test/CodeGen/X86/combine-srl.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/CodeGen/X86/combine-srl.ll?rev=284953&r1=284952&r2=284953&view=diff
==============================================================================
--- llvm/trunk/test/CodeGen/X86/combine-srl.ll (original)
+++ llvm/trunk/test/CodeGen/X86/combine-srl.ll Sun Oct 23 18:13:31 2016
@@ -79,15 +79,12 @@ define <4 x i32> @combine_vec_lshr_by_ze
 define <4 x i32> @combine_vec_lshr_known_zero0(<4 x i32> %x) {
 ; SSE-LABEL: combine_vec_lshr_known_zero0:
 ; SSE:       # BB#0:
-; SSE-NEXT:    pand {{.*}}(%rip), %xmm0
-; SSE-NEXT:    psrld $4, %xmm0
+; SSE-NEXT:    xorps %xmm0, %xmm0
 ; SSE-NEXT:    retq
 ;
 ; AVX-LABEL: combine_vec_lshr_known_zero0:
 ; AVX:       # BB#0:
-; AVX-NEXT:    vpbroadcastd {{.*}}(%rip), %xmm1
-; AVX-NEXT:    vpand %xmm1, %xmm0, %xmm0
-; AVX-NEXT:    vpsrld $4, %xmm0, %xmm0
+; AVX-NEXT:    vxorps %xmm0, %xmm0, %xmm0
 ; AVX-NEXT:    retq
   %1 = and <4 x i32> %x, <i32 15, i32 15, i32 15, i32 15>
   %2 = lshr <4 x i32> %1, <i32 4, i32 4, i32 4, i32 4>
@@ -292,21 +289,12 @@ define <4 x i32> @combine_vec_lshr_trunc
 define <4 x i32> @combine_vec_lshr_trunc_lshr_zero0(<4 x i64> %x) {
 ; SSE-LABEL: combine_vec_lshr_trunc_lshr_zero0:
 ; SSE:       # BB#0:
-; SSE-NEXT:    psrlq $48, %xmm0
-; SSE-NEXT:    psrlq $48, %xmm1
-; SSE-NEXT:    pshufd {{.*#+}} xmm1 = xmm1[0,1,0,2]
-; SSE-NEXT:    pshufd {{.*#+}} xmm0 = xmm0[0,2,2,3]
-; SSE-NEXT:    pblendw {{.*#+}} xmm0 = xmm0[0,1,2,3],xmm1[4,5,6,7]
-; SSE-NEXT:    psrld $24, %xmm0
+; SSE-NEXT:    xorps %xmm0, %xmm0
 ; SSE-NEXT:    retq
 ;
 ; AVX-LABEL: combine_vec_lshr_trunc_lshr_zero0:
 ; AVX:       # BB#0:
-; AVX-NEXT:    vpsrlq $48, %ymm0, %ymm0
-; AVX-NEXT:    vpshufd {{.*#+}} ymm0 = ymm0[0,2,2,3,4,6,6,7]
-; AVX-NEXT:    vpermq {{.*#+}} ymm0 = ymm0[0,2,2,3]
-; AVX-NEXT:    vpsrld $24, %xmm0, %xmm0
-; AVX-NEXT:    vzeroupper
+; AVX-NEXT:    vxorps %xmm0, %xmm0, %xmm0
 ; AVX-NEXT:    retq
   %1 = lshr <4 x i64> %x, <i64 48, i64 48, i64 48, i64 48>
   %2 = trunc <4 x i64> %1 to <4 x i32>
