[llvm] r308598 - [DAGCombiner] Match non-uniform constant vectors using predicates.

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Thu Jul 20 03:13:40 PDT 2017


Author: rksimon
Date: Thu Jul 20 03:13:40 2017
New Revision: 308598

URL: http://llvm.org/viewvc/llvm-project?rev=308598&view=rev
Log:
[DAGCombiner] Match non-uniform constant vectors using predicates.

Most combines currently recognise scalar and splat-vector constants, but not non-uniform vector constants.

This patch introduces a matching mechanism that uses predicates to check every element of a BUILD_VECTOR of ConstantSDNodes, as well as scalar ConstantSDNode cases.

I've changed a couple of predicates to demonstrate - the combine-shl changes add currently unsupported cases, while the MatchRotate change replaces an existing mechanism (sumMatchConstant).

Differential Revision: https://reviews.llvm.org/D35492

Modified:
    llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/trunk/test/CodeGen/X86/combine-shl.ll

Modified: llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp?rev=308598&r1=308597&r2=308598&view=diff
==============================================================================
--- llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (original)
+++ llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp Thu Jul 20 03:13:40 2017
@@ -873,6 +873,56 @@ static bool isAnyConstantBuildVector(con
          ISD::isBuildVectorOfConstantFPSDNodes(N);
 }
 
+// Attempt to match a unary predicate against a scalar ConstantSDNode, or
+// against every element of a BUILD_VECTOR made solely of ConstantSDNodes.
+static bool matchUnaryPredicate(SDValue Op,
+                                std::function<bool(ConstantSDNode *)> Match) {
+  if (auto *Cst = dyn_cast<ConstantSDNode>(Op))
+    return Match(Cst);
+  // Vector constants, splat or not, are only recognised as BUILD_VECTORs.
+  if (ISD::BUILD_VECTOR != Op.getOpcode())
+    return false;
+  // Reject lanes that are not constants (e.g. undef) or not of type SVT.
+  EVT SVT = Op.getValueType().getScalarType();
+  for (unsigned i = 0, e = Op.getNumOperands(); i != e; ++i) {
+    auto *Cst = dyn_cast<ConstantSDNode>(Op.getOperand(i));
+    if (!Cst || Cst->getValueType(0) != SVT || !Match(Cst))
+      return false;
+  }
+  return true;
+}
+
+// Attempt to match a binary predicate against a pair of scalar constants, or
+// pairwise against the elements of two constant BUILD_VECTORs.
+static bool matchBinaryPredicate(
+    SDValue LHS, SDValue RHS,
+    std::function<bool(ConstantSDNode *, ConstantSDNode *)> Match) {
+  if (LHS.getValueType() != RHS.getValueType())
+    return false;
+  // Scalar case: both operands are plain ConstantSDNodes.
+  if (auto *LHSCst = dyn_cast<ConstantSDNode>(LHS))
+    if (auto *RHSCst = dyn_cast<ConstantSDNode>(RHS))
+      return Match(LHSCst, RHSCst);
+  // Vector case: both must be BUILD_VECTORs; any other combination fails.
+  if (ISD::BUILD_VECTOR != LHS.getOpcode() ||
+      ISD::BUILD_VECTOR != RHS.getOpcode())
+    return false;
+  // Same value type => same operand count, so RHS.getOperand(i) is in range.
+  EVT SVT = LHS.getValueType().getScalarType();
+  for (unsigned i = 0, e = LHS.getNumOperands(); i != e; ++i) {
+    auto *LHSCst = dyn_cast<ConstantSDNode>(LHS.getOperand(i));
+    auto *RHSCst = dyn_cast<ConstantSDNode>(RHS.getOperand(i));
+    if (!LHSCst || !RHSCst) // e.g. an undef or non-constant lane
+      return false;
+    if (LHSCst->getValueType(0) != SVT ||
+        LHSCst->getValueType(0) != RHSCst->getValueType(0))
+      return false;
+    if (!Match(LHSCst, RHSCst))
+      return false;
+  }
+  return true;
+}
+
 SDValue DAGCombiner::ReassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0,
                                     SDValue N1) {
   EVT VT = N0.getValueType();
@@ -4585,20 +4635,6 @@ SDNode *DAGCombiner::MatchRotatePosNeg(S
   return nullptr;
 }
 
-// if Left + Right == Sum (constant or constant splat vector)
-static bool sumMatchConstant(SDValue Left, SDValue Right, unsigned Sum,
-                             SelectionDAG &DAG, const SDLoc &DL) {
-  EVT ShiftVT = Left.getValueType();
-  if (ShiftVT != Right.getValueType()) return false;
-
-  SDValue ShiftSum = DAG.FoldConstantArithmetic(ISD::ADD, DL, ShiftVT,
-                         Left.getNode(), Right.getNode());
-  if (!ShiftSum) return false;
-
-  ConstantSDNode *CSum = isConstOrConstSplat(ShiftSum);
-  return CSum && CSum->getZExtValue() == Sum;
-}
-
 // MatchRotate - Handle an 'or' of two operands.  If this is one of the many
 // idioms for rotate, and if the target supports rotation instructions, generate
 // a rot[lr].
@@ -4644,7 +4680,11 @@ SDNode *DAGCombiner::MatchRotate(SDValue
 
   // fold (or (shl x, C1), (srl x, C2)) -> (rotl x, C1)
   // fold (or (shl x, C1), (srl x, C2)) -> (rotr x, C2)
-  if (sumMatchConstant(LHSShiftAmt, RHSShiftAmt, EltSizeInBits, DAG, DL)) {
+  auto MatchRotateSum = [EltSizeInBits](ConstantSDNode *LHS,
+                                        ConstantSDNode *RHS) {
+    return (LHS->getAPIntValue() + RHS->getAPIntValue()) == EltSizeInBits;
+  };
+  if (matchBinaryPredicate(LHSShiftAmt, RHSShiftAmt, MatchRotateSum)) {
     SDValue Rot = DAG.getNode(HasROTL ? ISD::ROTL : ISD::ROTR, DL, VT,
                               LHSShiftArg, HasROTL ? LHSShiftAmt : RHSShiftAmt);
 
@@ -5365,7 +5405,11 @@ SDValue DAGCombiner::visitSHL(SDNode *N)
   if (isNullConstantOrNullSplatConstant(N0))
     return N0;
   // fold (shl x, c >= size(x)) -> undef
-  if (N1C && N1C->getAPIntValue().uge(OpSizeInBits))
+  // NOTE: ALL vector elements must be too big to avoid partial UNDEFs.
+  auto MatchShiftTooBig = [OpSizeInBits](ConstantSDNode *Val) {
+    return Val->getAPIntValue().uge(OpSizeInBits);
+  };
+  if (matchUnaryPredicate(N1, MatchShiftTooBig))
     return DAG.getUNDEF(VT);
   // fold (shl x, 0) -> x
   if (N1C && N1C->isNullValue())
@@ -5392,20 +5436,29 @@ SDValue DAGCombiner::visitSHL(SDNode *N)
     return SDValue(N, 0);
 
   // fold (shl (shl x, c1), c2) -> 0 or (shl x, (add c1, c2))
-  if (N1C && N0.getOpcode() == ISD::SHL) {
-    if (ConstantSDNode *N0C1 = isConstOrConstSplat(N0.getOperand(1))) {
-      SDLoc DL(N);
-      APInt c1 = N0C1->getAPIntValue();
-      APInt c2 = N1C->getAPIntValue();
+  if (N0.getOpcode() == ISD::SHL) {
+    auto MatchOutOfRange = [OpSizeInBits](ConstantSDNode *LHS,
+                                          ConstantSDNode *RHS) {
+      APInt c1 = LHS->getAPIntValue();
+      APInt c2 = RHS->getAPIntValue();
       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
-
-      APInt Sum = c1 + c2;
-      if (Sum.uge(OpSizeInBits))
-        return DAG.getConstant(0, DL, VT);
-
-      return DAG.getNode(
-          ISD::SHL, DL, VT, N0.getOperand(0),
-          DAG.getConstant(Sum.getZExtValue(), DL, N1.getValueType()));
+      return (c1 + c2).uge(OpSizeInBits);
+    };
+    if (matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange))
+      return DAG.getConstant(0, SDLoc(N), VT);
+
+    auto MatchInRange = [OpSizeInBits](ConstantSDNode *LHS,
+                                       ConstantSDNode *RHS) {
+      APInt c1 = LHS->getAPIntValue();
+      APInt c2 = RHS->getAPIntValue();
+      zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
+      return (c1 + c2).ult(OpSizeInBits);
+    };
+    if (matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) {
+      SDLoc DL(N);
+      EVT ShiftVT = N1.getValueType();
+      SDValue Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, N1, N0.getOperand(1));
+      return DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0), Sum);
     }
   }
 

Modified: llvm/trunk/test/CodeGen/X86/combine-shl.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/CodeGen/X86/combine-shl.ll?rev=308598&r1=308597&r2=308598&view=diff
==============================================================================
--- llvm/trunk/test/CodeGen/X86/combine-shl.ll (original)
+++ llvm/trunk/test/CodeGen/X86/combine-shl.ll Thu Jul 20 03:13:40 2017
@@ -37,7 +37,6 @@ define <4 x i32> @combine_vec_shl_outofr
 ;
 ; AVX-LABEL: combine_vec_shl_outofrange1:
 ; AVX:       # BB#0:
-; AVX-NEXT:    vpsllvd {{.*}}(%rip), %xmm0, %xmm0
 ; AVX-NEXT:    retq
   %1 = shl <4 x i32> %x, <i32 33, i32 34, i32 35, i32 36>
   ret <4 x i32> %1
@@ -153,7 +152,6 @@ define <4 x i32> @combine_vec_shl_shl1(<
 ; AVX-LABEL: combine_vec_shl_shl1:
 ; AVX:       # BB#0:
 ; AVX-NEXT:    vpsllvd {{.*}}(%rip), %xmm0, %xmm0
-; AVX-NEXT:    vpsllvd {{.*}}(%rip), %xmm0, %xmm0
 ; AVX-NEXT:    retq
   %1 = shl <4 x i32> %x, <i32 0, i32 1, i32 2, i32 3>
   %2 = shl <4 x i32> %1, <i32 4, i32 5, i32 6, i32 7>
@@ -184,8 +182,7 @@ define <4 x i32> @combine_vec_shl_shl_ze
 ;
 ; AVX-LABEL: combine_vec_shl_shl_zero1:
 ; AVX:       # BB#0:
-; AVX-NEXT:    vpsllvd {{.*}}(%rip), %xmm0, %xmm0
-; AVX-NEXT:    vpsllvd {{.*}}(%rip), %xmm0, %xmm0
+; AVX-NEXT:    vxorps %xmm0, %xmm0, %xmm0
 ; AVX-NEXT:    retq
   %1 = shl <4 x i32> %x, <i32 17, i32 18, i32 19, i32 20>
   %2 = shl <4 x i32> %1, <i32 25, i32 26, i32 27, i32 28>




More information about the llvm-commits mailing list