[llvm] r363792 - [DAGCombiner] Support (shl (ext (shl x, c1)), c2) -> 0 non-uniform folds.

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Wed Jun 19 05:25:29 PDT 2019


Author: rksimon
Date: Wed Jun 19 05:25:29 2019
New Revision: 363792

URL: http://llvm.org/viewvc/llvm-project?rev=363792&view=rev
Log:
[DAGCombiner] Support (shl (ext (shl x, c1)), c2) -> 0 non-uniform folds.

Use matchBinaryPredicate instead of isConstOrConstSplat so that the fold also handles non-uniform (per-element) constant vector shift amounts, not just scalar and splat constants.

This requires tweaking matchBinaryPredicate so that it can (optionally) handle constant operands whose types have different bit widths.

Modified:
    llvm/trunk/include/llvm/CodeGen/SelectionDAGNodes.h
    llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
    llvm/trunk/test/CodeGen/X86/combine-shl.ll

Modified: llvm/trunk/include/llvm/CodeGen/SelectionDAGNodes.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/CodeGen/SelectionDAGNodes.h?rev=363792&r1=363791&r2=363792&view=diff
==============================================================================
--- llvm/trunk/include/llvm/CodeGen/SelectionDAGNodes.h (original)
+++ llvm/trunk/include/llvm/CodeGen/SelectionDAGNodes.h Wed Jun 19 05:25:29 2019
@@ -2617,10 +2617,11 @@ namespace ISD {
   /// Attempt to match a binary predicate against a pair of scalar/splat
   /// constants or every element of a pair of constant BUILD_VECTORs.
   /// If AllowUndef is true, then UNDEF elements will pass nullptr to Match.
+  /// If AllowTypeMismatch is true then RetType + ArgTypes don't need to match.
   bool matchBinaryPredicate(
       SDValue LHS, SDValue RHS,
       std::function<bool(ConstantSDNode *, ConstantSDNode *)> Match,
-      bool AllowUndefs = false);
+      bool AllowUndefs = false, bool AllowTypeMismatch = false);
 } // end namespace ISD
 
 } // end namespace llvm

Modified: llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp?rev=363792&r1=363791&r2=363792&view=diff
==============================================================================
--- llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (original)
+++ llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp Wed Jun 19 05:25:29 2019
@@ -7210,19 +7210,35 @@ SDValue DAGCombiner::visitSHL(SDNode *N)
   // that are shifted out by the inner shift in the first form.  This means
   // the outer shift size must be >= the number of bits added by the ext.
   // As a corollary, we don't care what kind of ext it is.
-  if (N1C && (N0.getOpcode() == ISD::ZERO_EXTEND ||
-              N0.getOpcode() == ISD::ANY_EXTEND ||
-              N0.getOpcode() == ISD::SIGN_EXTEND) &&
+  if ((N0.getOpcode() == ISD::ZERO_EXTEND ||
+       N0.getOpcode() == ISD::ANY_EXTEND ||
+       N0.getOpcode() == ISD::SIGN_EXTEND) &&
       N0.getOperand(0).getOpcode() == ISD::SHL) {
     SDValue N0Op0 = N0.getOperand(0);
-    if (ConstantSDNode *N0Op0C1 = isConstOrConstSplat(N0Op0.getOperand(1))) {
+    SDValue InnerShiftAmt = N0Op0.getOperand(1);
+    EVT InnerVT = N0Op0.getValueType();
+    uint64_t InnerBitwidth = InnerVT.getScalarSizeInBits();
+
+    auto MatchOutOfRange = [OpSizeInBits, InnerBitwidth](ConstantSDNode *LHS,
+                                                         ConstantSDNode *RHS) {
+      APInt c1 = LHS->getAPIntValue();
+      APInt c2 = RHS->getAPIntValue();
+      zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
+      return c2.uge(OpSizeInBits - InnerBitwidth) &&
+             (c1 + c2).uge(OpSizeInBits);
+    };
+    if (ISD::matchBinaryPredicate(InnerShiftAmt, N1, MatchOutOfRange,
+                                  /*AllowUndefs*/ false,
+                                  /*AllowTypeMismatch*/ true))
+      return DAG.getConstant(0, SDLoc(N), VT);
+
+    ConstantSDNode *N0Op0C1 = isConstOrConstSplat(InnerShiftAmt);
+    if (N1C && N0Op0C1) {
       APInt c1 = N0Op0C1->getAPIntValue();
       APInt c2 = N1C->getAPIntValue();
       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
 
-      EVT InnerShiftVT = N0Op0.getValueType();
-      uint64_t InnerShiftSize = InnerShiftVT.getScalarSizeInBits();
-      if (c2.uge(OpSizeInBits - InnerShiftSize)) {
+      if (c2.uge(OpSizeInBits - InnerBitwidth)) {
         SDLoc DL(N0);
         APInt Sum = c1 + c2;
         if (Sum.uge(OpSizeInBits))

Modified: llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp?rev=363792&r1=363791&r2=363792&view=diff
==============================================================================
--- llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp (original)
+++ llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp Wed Jun 19 05:25:29 2019
@@ -294,8 +294,8 @@ bool ISD::matchUnaryPredicate(SDValue Op
 bool ISD::matchBinaryPredicate(
     SDValue LHS, SDValue RHS,
     std::function<bool(ConstantSDNode *, ConstantSDNode *)> Match,
-    bool AllowUndefs) {
-  if (LHS.getValueType() != RHS.getValueType())
+    bool AllowUndefs, bool AllowTypeMismatch) {
+  if (!AllowTypeMismatch && LHS.getValueType() != RHS.getValueType())
     return false;
 
   // TODO: Add support for scalar UNDEF cases?
@@ -318,8 +318,8 @@ bool ISD::matchBinaryPredicate(
     auto *RHSCst = dyn_cast<ConstantSDNode>(RHSOp);
     if ((!LHSCst && !LHSUndef) || (!RHSCst && !RHSUndef))
       return false;
-    if (LHSOp.getValueType() != SVT ||
-        LHSOp.getValueType() != RHSOp.getValueType())
+    if (!AllowTypeMismatch && (LHSOp.getValueType() != SVT ||
+                               LHSOp.getValueType() != RHSOp.getValueType()))
       return false;
     if (!Match(LHSCst, RHSCst))
       return false;

Modified: llvm/trunk/test/CodeGen/X86/combine-shl.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/CodeGen/X86/combine-shl.ll?rev=363792&r1=363791&r2=363792&view=diff
==============================================================================
--- llvm/trunk/test/CodeGen/X86/combine-shl.ll (original)
+++ llvm/trunk/test/CodeGen/X86/combine-shl.ll Wed Jun 19 05:25:29 2019
@@ -264,33 +264,16 @@ define <8 x i32> @combine_vec_shl_ext_sh
   ret <8 x i32> %3
 }
 
-; TODO - this should fold to ZERO.
 define <8 x i32> @combine_vec_shl_ext_shl1(<8 x i16> %x) {
-; SSE2-LABEL: combine_vec_shl_ext_shl1:
-; SSE2:       # %bb.0:
-; SSE2-NEXT:    pmullw {{.*}}(%rip), %xmm0
-; SSE2-NEXT:    punpcklwd {{.*#+}} xmm0 = xmm0[0,0,1,1,2,2,3,3]
-; SSE2-NEXT:    pslld $30, %xmm0
-; SSE2-NEXT:    xorpd %xmm1, %xmm1
-; SSE2-NEXT:    movsd {{.*#+}} xmm0 = xmm1[0],xmm0[1]
-; SSE2-NEXT:    movsd {{.*#+}} xmm1 = xmm1[0,1]
-; SSE2-NEXT:    retq
-;
-; SSE41-LABEL: combine_vec_shl_ext_shl1:
-; SSE41:       # %bb.0:
-; SSE41-NEXT:    pmullw {{.*}}(%rip), %xmm0
-; SSE41-NEXT:    pmovsxwd %xmm0, %xmm0
-; SSE41-NEXT:    pslld $30, %xmm0
-; SSE41-NEXT:    pxor %xmm1, %xmm1
-; SSE41-NEXT:    pblendw {{.*#+}} xmm0 = xmm1[0,1,2,3],xmm0[4,5,6,7]
-; SSE41-NEXT:    pxor %xmm1, %xmm1
-; SSE41-NEXT:    retq
+; SSE-LABEL: combine_vec_shl_ext_shl1:
+; SSE:       # %bb.0:
+; SSE-NEXT:    xorps %xmm0, %xmm0
+; SSE-NEXT:    xorps %xmm1, %xmm1
+; SSE-NEXT:    retq
 ;
 ; AVX-LABEL: combine_vec_shl_ext_shl1:
 ; AVX:       # %bb.0:
-; AVX-NEXT:    vpmullw {{.*}}(%rip), %xmm0, %xmm0
-; AVX-NEXT:    vpmovsxwd %xmm0, %ymm0
-; AVX-NEXT:    vpsllvd {{.*}}(%rip), %ymm0, %ymm0
+; AVX-NEXT:    vxorps %xmm0, %xmm0, %xmm0
 ; AVX-NEXT:    retq
   %1 = shl <8 x i16> %x, <i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7, i16 8>
   %2 = sext <8 x i16> %1 to <8 x i32>




More information about the llvm-commits mailing list