[llvm] r325815 - [SelectionDAG] Move matchUnaryPredicate/matchBinaryPredicate into SelectionDAGNodes.h

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 22 10:45:13 PST 2018


Author: rksimon
Date: Thu Feb 22 10:45:13 2018
New Revision: 325815

URL: http://llvm.org/viewvc/llvm-project?rev=325815&view=rev
Log:
[SelectionDAG] Move matchUnaryPredicate/matchBinaryPredicate into SelectionDAGNodes.h

This allows us to improve vector constant matching in more DAG code (backends, TargetLowering etc.).

Differential Revision: https://reviews.llvm.org/D43466

Modified:
    llvm/trunk/include/llvm/CodeGen/SelectionDAGNodes.h
    llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Modified: llvm/trunk/include/llvm/CodeGen/SelectionDAGNodes.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/CodeGen/SelectionDAGNodes.h?rev=325815&r1=325814&r2=325815&view=diff
==============================================================================
--- llvm/trunk/include/llvm/CodeGen/SelectionDAGNodes.h (original)
+++ llvm/trunk/include/llvm/CodeGen/SelectionDAGNodes.h Thu Feb 22 10:45:13 2018
@@ -2332,6 +2332,17 @@ namespace ISD {
       cast<StoreSDNode>(N)->getAddressingMode() == ISD::UNINDEXED;
   }
 
+  /// Attempt to match a unary predicate against a scalar/splat constant or
+  /// every element of a constant BUILD_VECTOR.
+  bool matchUnaryPredicate(SDValue Op,
+                           std::function<bool(ConstantSDNode *)> Match);
+
+  /// Attempt to match a binary predicate against a pair of scalar/splat
+  /// constants or every element of a pair of constant BUILD_VECTORs.
+  bool matchBinaryPredicate(
+      SDValue LHS, SDValue RHS,
+      std::function<bool(ConstantSDNode *, ConstantSDNode *)> Match);
+
 } // end namespace ISD
 
 } // end namespace llvm

Modified: llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp?rev=325815&r1=325814&r2=325815&view=diff
==============================================================================
--- llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (original)
+++ llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp Thu Feb 22 10:45:13 2018
@@ -920,56 +920,6 @@ static bool isAnyConstantBuildVector(con
          ISD::isBuildVectorOfConstantFPSDNodes(N);
 }
 
-// Attempt to match a unary predicate against a scalar/splat constant or
-// every element of a constant BUILD_VECTOR.
-static bool matchUnaryPredicate(SDValue Op,
-                                std::function<bool(ConstantSDNode *)> Match) {
-  if (auto *Cst = dyn_cast<ConstantSDNode>(Op))
-    return Match(Cst);
-
-  if (ISD::BUILD_VECTOR != Op.getOpcode())
-    return false;
-
-  EVT SVT = Op.getValueType().getScalarType();
-  for (unsigned i = 0, e = Op.getNumOperands(); i != e; ++i) {
-    auto *Cst = dyn_cast<ConstantSDNode>(Op.getOperand(i));
-    if (!Cst || Cst->getValueType(0) != SVT || !Match(Cst))
-      return false;
-  }
-  return true;
-}
-
-// Attempt to match a binary predicate against a pair of scalar/splat constants
-// or every element of a pair of constant BUILD_VECTORs.
-static bool matchBinaryPredicate(
-    SDValue LHS, SDValue RHS,
-    std::function<bool(ConstantSDNode *, ConstantSDNode *)> Match) {
-  if (LHS.getValueType() != RHS.getValueType())
-    return false;
-
-  if (auto *LHSCst = dyn_cast<ConstantSDNode>(LHS))
-    if (auto *RHSCst = dyn_cast<ConstantSDNode>(RHS))
-      return Match(LHSCst, RHSCst);
-
-  if (ISD::BUILD_VECTOR != LHS.getOpcode() ||
-      ISD::BUILD_VECTOR != RHS.getOpcode())
-    return false;
-
-  EVT SVT = LHS.getValueType().getScalarType();
-  for (unsigned i = 0, e = LHS.getNumOperands(); i != e; ++i) {
-    auto *LHSCst = dyn_cast<ConstantSDNode>(LHS.getOperand(i));
-    auto *RHSCst = dyn_cast<ConstantSDNode>(RHS.getOperand(i));
-    if (!LHSCst || !RHSCst)
-      return false;
-    if (LHSCst->getValueType(0) != SVT ||
-        LHSCst->getValueType(0) != RHSCst->getValueType(0))
-      return false;
-    if (!Match(LHSCst, RHSCst))
-      return false;
-  }
-  return true;
-}
-
 SDValue DAGCombiner::ReassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0,
                                     SDValue N1) {
   EVT VT = N0.getValueType();
@@ -4067,7 +4017,7 @@ SDValue DAGCombiner::visitAND(SDNode *N)
     return RHS->getAPIntValue().isSubsetOf(LHS->getAPIntValue());
   };
   if (N0.getOpcode() == ISD::OR &&
-      matchBinaryPredicate(N0.getOperand(1), N1, MatchSubset))
+      ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchSubset))
     return N1;
   // fold (and (any_ext V), c) -> (zero_ext V) if 'and' only clears top bits.
   if (N1C && N0.getOpcode() == ISD::ANY_EXTEND) {
@@ -4756,7 +4706,7 @@ SDValue DAGCombiner::visitOR(SDNode *N)
     return LHS->getAPIntValue().intersects(RHS->getAPIntValue());
   };
   if (N0.getOpcode() == ISD::AND && N0.getNode()->hasOneUse() &&
-      matchBinaryPredicate(N0.getOperand(1), N1, MatchIntersect)) {
+      ISD::matchBinaryPredicate(N0.getOperand(1), N1, MatchIntersect)) {
     if (SDValue COR = DAG.FoldConstantArithmetic(
             ISD::OR, SDLoc(N1), VT, N1.getNode(), N0.getOperand(1).getNode())) {
       SDValue IOR = DAG.getNode(ISD::OR, SDLoc(N0), VT, N0.getOperand(0), N1);
@@ -4991,7 +4941,7 @@ SDNode *DAGCombiner::MatchRotate(SDValue
                                         ConstantSDNode *RHS) {
     return (LHS->getAPIntValue() + RHS->getAPIntValue()) == EltSizeInBits;
   };
-  if (matchBinaryPredicate(LHSShiftAmt, RHSShiftAmt, MatchRotateSum)) {
+  if (ISD::matchBinaryPredicate(LHSShiftAmt, RHSShiftAmt, MatchRotateSum)) {
     SDValue Rot = DAG.getNode(HasROTL ? ISD::ROTL : ISD::ROTR, DL, VT,
                               LHSShiftArg, HasROTL ? LHSShiftAmt : RHSShiftAmt);
 
@@ -5704,7 +5654,7 @@ SDValue DAGCombiner::visitSHL(SDNode *N)
   auto MatchShiftTooBig = [OpSizeInBits](ConstantSDNode *Val) {
     return Val->getAPIntValue().uge(OpSizeInBits);
   };
-  if (matchUnaryPredicate(N1, MatchShiftTooBig))
+  if (ISD::matchUnaryPredicate(N1, MatchShiftTooBig))
     return DAG.getUNDEF(VT);
   // fold (shl x, 0) -> x
   if (N1C && N1C->isNullValue())
@@ -5739,7 +5689,7 @@ SDValue DAGCombiner::visitSHL(SDNode *N)
       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
       return (c1 + c2).uge(OpSizeInBits);
     };
-    if (matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange))
+    if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange))
       return DAG.getConstant(0, SDLoc(N), VT);
 
     auto MatchInRange = [OpSizeInBits](ConstantSDNode *LHS,
@@ -5749,7 +5699,7 @@ SDValue DAGCombiner::visitSHL(SDNode *N)
       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
       return (c1 + c2).ult(OpSizeInBits);
     };
-    if (matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) {
+    if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) {
       SDLoc DL(N);
       EVT ShiftVT = N1.getValueType();
       SDValue Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, N1, N0.getOperand(1));
@@ -5925,7 +5875,7 @@ SDValue DAGCombiner::visitSRA(SDNode *N)
   auto MatchShiftTooBig = [OpSizeInBits](ConstantSDNode *Val) {
     return Val->getAPIntValue().uge(OpSizeInBits);
   };
-  if (matchUnaryPredicate(N1, MatchShiftTooBig))
+  if (ISD::matchUnaryPredicate(N1, MatchShiftTooBig))
     return DAG.getUNDEF(VT);
   // fold (sra x, 0) -> x
   if (N1C && N1C->isNullValue())
@@ -5960,7 +5910,7 @@ SDValue DAGCombiner::visitSRA(SDNode *N)
       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
       return (c1 + c2).uge(OpSizeInBits);
     };
-    if (matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange))
+    if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange))
       return DAG.getNode(ISD::SRA, DL, VT, N0.getOperand(0),
                          DAG.getConstant(OpSizeInBits - 1, DL, ShiftVT));
 
@@ -5971,7 +5921,7 @@ SDValue DAGCombiner::visitSRA(SDNode *N)
       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
       return (c1 + c2).ult(OpSizeInBits);
     };
-    if (matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) {
+    if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) {
       SDValue Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, N1, N0.getOperand(1));
       return DAG.getNode(ISD::SRA, DL, VT, N0.getOperand(0), Sum);
     }
@@ -6089,7 +6039,7 @@ SDValue DAGCombiner::visitSRL(SDNode *N)
   auto MatchShiftTooBig = [OpSizeInBits](ConstantSDNode *Val) {
     return Val->getAPIntValue().uge(OpSizeInBits);
   };
-  if (matchUnaryPredicate(N1, MatchShiftTooBig))
+  if (ISD::matchUnaryPredicate(N1, MatchShiftTooBig))
     return DAG.getUNDEF(VT);
   // fold (srl x, 0) -> x
   if (N1C && N1C->isNullValue())
@@ -6112,7 +6062,7 @@ SDValue DAGCombiner::visitSRL(SDNode *N)
       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
       return (c1 + c2).uge(OpSizeInBits);
     };
-    if (matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange))
+    if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange))
       return DAG.getConstant(0, SDLoc(N), VT);
 
     auto MatchInRange = [OpSizeInBits](ConstantSDNode *LHS,
@@ -6122,7 +6072,7 @@ SDValue DAGCombiner::visitSRL(SDNode *N)
       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
       return (c1 + c2).ult(OpSizeInBits);
     };
-    if (matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) {
+    if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) {
       SDLoc DL(N);
       EVT ShiftVT = N1.getValueType();
       SDValue Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, N1, N0.getOperand(1));

Modified: llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp?rev=325815&r1=325814&r2=325815&view=diff
==============================================================================
--- llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp (original)
+++ llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp Thu Feb 22 10:45:13 2018
@@ -263,6 +263,52 @@ bool ISD::allOperandsUndef(const SDNode
   return true;
 }
 
+bool ISD::matchUnaryPredicate(SDValue Op,
+                              std::function<bool(ConstantSDNode *)> Match) {
+  if (auto *Cst = dyn_cast<ConstantSDNode>(Op))
+    return Match(Cst);
+
+  if (ISD::BUILD_VECTOR != Op.getOpcode())
+    return false;
+
+  EVT SVT = Op.getValueType().getScalarType();
+  for (unsigned i = 0, e = Op.getNumOperands(); i != e; ++i) {
+    auto *Cst = dyn_cast<ConstantSDNode>(Op.getOperand(i));
+    if (!Cst || Cst->getValueType(0) != SVT || !Match(Cst))
+      return false;
+  }
+  return true;
+}
+
+bool ISD::matchBinaryPredicate(
+    SDValue LHS, SDValue RHS,
+    std::function<bool(ConstantSDNode *, ConstantSDNode *)> Match) {
+  if (LHS.getValueType() != RHS.getValueType())
+    return false;
+
+  if (auto *LHSCst = dyn_cast<ConstantSDNode>(LHS))
+    if (auto *RHSCst = dyn_cast<ConstantSDNode>(RHS))
+      return Match(LHSCst, RHSCst);
+
+  if (ISD::BUILD_VECTOR != LHS.getOpcode() ||
+      ISD::BUILD_VECTOR != RHS.getOpcode())
+    return false;
+
+  EVT SVT = LHS.getValueType().getScalarType();
+  for (unsigned i = 0, e = LHS.getNumOperands(); i != e; ++i) {
+    auto *LHSCst = dyn_cast<ConstantSDNode>(LHS.getOperand(i));
+    auto *RHSCst = dyn_cast<ConstantSDNode>(RHS.getOperand(i));
+    if (!LHSCst || !RHSCst)
+      return false;
+    if (LHSCst->getValueType(0) != SVT ||
+        LHSCst->getValueType(0) != RHSCst->getValueType(0))
+      return false;
+    if (!Match(LHSCst, RHSCst))
+      return false;
+  }
+  return true;
+}
+
 ISD::NodeType ISD::getExtForLoadExtType(bool IsFP, ISD::LoadExtType ExtType) {
   switch (ExtType) {
   case ISD::EXTLOAD:




More information about the llvm-commits mailing list