[llvm] r354797 - [SelectionDAG] Add demanded elts variants to isConstOrConstSplat helpers. NFCI.

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Mon Feb 25 08:31:59 PST 2019


Author: rksimon
Date: Mon Feb 25 08:31:58 2019
New Revision: 354797

URL: http://llvm.org/viewvc/llvm-project?rev=354797&view=rev
Log:
[SelectionDAG] Add demanded elts variants to isConstOrConstSplat helpers. NFCI.

These helpers extend the existing isConstOrConstSplat helper checks to support DemandedElts masks as well.

We already had a local version of this in SelectionDAG.cpp (isConstOrDemandedConstSplat) that computeKnownBits/ComputeNumSignBits made use of; this change adds the functionality directly to BuildVectorSDNode and extends isConstOrConstSplat and isConstOrConstSplatFP to use it.

This will allow us to reuse the functionality in SimplifyDemandedVectorElts/SimplifyDemandedBits.

Differential Revision: https://reviews.llvm.org/D58503

Modified:
    llvm/trunk/include/llvm/CodeGen/SelectionDAGNodes.h
    llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Modified: llvm/trunk/include/llvm/CodeGen/SelectionDAGNodes.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/CodeGen/SelectionDAGNodes.h?rev=354797&r1=354796&r2=354797&view=diff
==============================================================================
--- llvm/trunk/include/llvm/CodeGen/SelectionDAGNodes.h (original)
+++ llvm/trunk/include/llvm/CodeGen/SelectionDAGNodes.h Mon Feb 25 08:31:58 2019
@@ -1629,9 +1629,19 @@ bool isBitwiseNot(SDValue V);
 /// Returns the SDNode if it is a constant splat BuildVector or constant int.
 ConstantSDNode *isConstOrConstSplat(SDValue N, bool AllowUndefs = false);
 
+/// Returns the SDNode if it is a demanded constant splat BuildVector or
+/// constant int.
+ConstantSDNode *isConstOrConstSplat(SDValue N, const APInt &DemandedElts,
+                                    bool AllowUndefs = false);
+
 /// Returns the SDNode if it is a constant splat BuildVector or constant float.
 ConstantFPSDNode *isConstOrConstSplatFP(SDValue N, bool AllowUndefs = false);
 
+/// Returns the SDNode if it is a demanded constant splat BuildVector or
+/// constant float.
+ConstantFPSDNode *isConstOrConstSplatFP(SDValue N, const APInt &DemandedElts,
+                                        bool AllowUndefs = false);
+
 /// Return true if the value is a constant 0 integer or a splatted vector of
 /// a constant 0 integer (with no undefs by default).
 /// Build vector implicit truncation is not an issue for null values.
@@ -1868,12 +1878,31 @@ public:
                        unsigned MinSplatBits = 0,
                        bool isBigEndian = false) const;
 
+  /// Returns the demanded splatted value or a null value if this is not a
+  /// splat.
+  ///
+  /// The DemandedElts mask indicates the elements that must be in the splat.
+  /// If passed a non-null UndefElements bitvector, it will resize it to match
+  /// the vector width and set the bits where elements are undef.
+  SDValue getSplatValue(const APInt &DemandedElts,
+                        BitVector *UndefElements = nullptr) const;
+
   /// Returns the splatted value or a null value if this is not a splat.
   ///
   /// If passed a non-null UndefElements bitvector, it will resize it to match
   /// the vector width and set the bits where elements are undef.
   SDValue getSplatValue(BitVector *UndefElements = nullptr) const;
 
+  /// Returns the demanded splatted constant or null if this is not a constant
+  /// splat.
+  ///
+  /// The DemandedElts mask indicates the elements that must be in the splat.
+  /// If passed a non-null UndefElements bitvector, it will resize it to match
+  /// the vector width and set the bits where elements are undef.
+  ConstantSDNode *
+  getConstantSplatNode(const APInt &DemandedElts,
+                       BitVector *UndefElements = nullptr) const;
+
   /// Returns the splatted constant or null if this is not a constant
   /// splat.
   ///
@@ -1882,6 +1911,16 @@ public:
   ConstantSDNode *
   getConstantSplatNode(BitVector *UndefElements = nullptr) const;
 
+  /// Returns the demanded splatted constant FP or null if this is not a
+  /// constant FP splat.
+  ///
+  /// The DemandedElts mask indicates the elements that must be in the splat.
+  /// If passed a non-null UndefElements bitvector, it will resize it to match
+  /// the vector width and set the bits where elements are undef.
+  ConstantFPSDNode *
+  getConstantFPSplatNode(const APInt &DemandedElts,
+                         BitVector *UndefElements = nullptr) const;
+
   /// Returns the splatted constant FP or null if this is not a constant
   /// FP splat.
   ///

Modified: llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp?rev=354797&r1=354796&r2=354797&view=diff
==============================================================================
--- llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp (original)
+++ llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp Mon Feb 25 08:31:58 2019
@@ -2253,30 +2253,6 @@ bool SelectionDAG::isSplatValue(SDValue
          (AllowUndefs || !UndefElts);
 }
 
-/// Helper function that checks to see if a node is a constant or a
-/// build vector of splat constants at least within the demanded elts.
-static ConstantSDNode *isConstOrDemandedConstSplat(SDValue N,
-                                                   const APInt &DemandedElts) {
-  if (ConstantSDNode *CN = dyn_cast<ConstantSDNode>(N))
-    return CN;
-  if (N.getOpcode() != ISD::BUILD_VECTOR)
-    return nullptr;
-  EVT VT = N.getValueType();
-  ConstantSDNode *Cst = nullptr;
-  unsigned NumElts = VT.getVectorNumElements();
-  assert(DemandedElts.getBitWidth() == NumElts && "Unexpected vector size");
-  for (unsigned i = 0; i != NumElts; ++i) {
-    if (!DemandedElts[i])
-      continue;
-    ConstantSDNode *C = dyn_cast<ConstantSDNode>(N.getOperand(i));
-    if (!C || (Cst && Cst->getAPIntValue() != C->getAPIntValue()) ||
-        C->getValueType(0) != VT.getScalarType())
-      return nullptr;
-    Cst = C;
-  }
-  return Cst;
-}
-
 /// If a SHL/SRA/SRL node has a constant or splat constant shift amount that
 /// is less than the element bit-width of the shift node, return it.
 static const APInt *getValidShiftAmountConstant(SDValue V) {
@@ -2717,8 +2693,7 @@ KnownBits SelectionDAG::computeKnownBits
     break;
   case ISD::FSHL:
   case ISD::FSHR:
-    if (ConstantSDNode *C =
-            isConstOrDemandedConstSplat(Op.getOperand(2), DemandedElts)) {
+    if (ConstantSDNode *C = isConstOrConstSplat(Op.getOperand(2), DemandedElts)) {
       unsigned Amt = C->getAPIntValue().urem(BitWidth);
 
       // For fshl, 0-shift returns the 1st arg.
@@ -3155,10 +3130,10 @@ KnownBits SelectionDAG::computeKnownBits
     // the minimum of the clamp min/max range.
     bool IsMax = (Opcode == ISD::SMAX);
     ConstantSDNode *CstLow = nullptr, *CstHigh = nullptr;
-    if ((CstLow = isConstOrDemandedConstSplat(Op.getOperand(1), DemandedElts)))
+    if ((CstLow = isConstOrConstSplat(Op.getOperand(1), DemandedElts)))
       if (Op.getOperand(0).getOpcode() == (IsMax ? ISD::SMIN : ISD::SMAX))
-        CstHigh = isConstOrDemandedConstSplat(Op.getOperand(0).getOperand(1),
-                                              DemandedElts);
+        CstHigh =
+            isConstOrConstSplat(Op.getOperand(0).getOperand(1), DemandedElts);
     if (CstLow && CstHigh) {
       if (!IsMax)
         std::swap(CstLow, CstHigh);
@@ -3439,7 +3414,7 @@ unsigned SelectionDAG::ComputeNumSignBit
     Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth+1);
     // SRA X, C   -> adds C sign bits.
     if (ConstantSDNode *C =
-            isConstOrDemandedConstSplat(Op.getOperand(1), DemandedElts)) {
+            isConstOrConstSplat(Op.getOperand(1), DemandedElts)) {
       APInt ShiftVal = C->getAPIntValue();
       ShiftVal += Tmp;
       Tmp = ShiftVal.uge(VTBits) ? VTBits : ShiftVal.getZExtValue();
@@ -3447,7 +3422,7 @@ unsigned SelectionDAG::ComputeNumSignBit
     return Tmp;
   case ISD::SHL:
     if (ConstantSDNode *C =
-            isConstOrDemandedConstSplat(Op.getOperand(1), DemandedElts)) {
+            isConstOrConstSplat(Op.getOperand(1), DemandedElts)) {
       // shl destroys sign bits.
       Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth+1);
       if (C->getAPIntValue().uge(VTBits) ||      // Bad shift.
@@ -3487,10 +3462,10 @@ unsigned SelectionDAG::ComputeNumSignBit
     // the minimum of the clamp min/max range.
     bool IsMax = (Opcode == ISD::SMAX);
     ConstantSDNode *CstLow = nullptr, *CstHigh = nullptr;
-    if ((CstLow = isConstOrDemandedConstSplat(Op.getOperand(1), DemandedElts)))
+    if ((CstLow = isConstOrConstSplat(Op.getOperand(1), DemandedElts)))
       if (Op.getOperand(0).getOpcode() == (IsMax ? ISD::SMIN : ISD::SMAX))
-        CstHigh = isConstOrDemandedConstSplat(Op.getOperand(0).getOperand(1),
-                                              DemandedElts);
+        CstHigh =
+            isConstOrConstSplat(Op.getOperand(0).getOperand(1), DemandedElts);
     if (CstLow && CstHigh) {
       if (!IsMax)
         std::swap(CstLow, CstHigh);
@@ -8593,6 +8568,24 @@ ConstantSDNode *llvm::isConstOrConstSpla
   return nullptr;
 }
 
+ConstantSDNode *llvm::isConstOrConstSplat(SDValue N, const APInt &DemandedElts,
+                                          bool AllowUndefs) {
+  if (ConstantSDNode *CN = dyn_cast<ConstantSDNode>(N))
+    return CN;
+
+  if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(N)) {
+    BitVector UndefElements;
+    ConstantSDNode *CN = BV->getConstantSplatNode(DemandedElts, &UndefElements);
+
+    // BuildVectors can truncate their operands. Ignore that case here.
+    if (CN && (UndefElements.none() || AllowUndefs) &&
+        CN->getValueType(0) == N.getValueType().getScalarType())
+      return CN;
+  }
+
+  return nullptr;
+}
+
 ConstantFPSDNode *llvm::isConstOrConstSplatFP(SDValue N, bool AllowUndefs) {
   if (ConstantFPSDNode *CN = dyn_cast<ConstantFPSDNode>(N))
     return CN;
@@ -8607,6 +8600,23 @@ ConstantFPSDNode *llvm::isConstOrConstSp
   return nullptr;
 }
 
+ConstantFPSDNode *llvm::isConstOrConstSplatFP(SDValue N,
+                                              const APInt &DemandedElts,
+                                              bool AllowUndefs) {
+  if (ConstantFPSDNode *CN = dyn_cast<ConstantFPSDNode>(N))
+    return CN;
+
+  if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(N)) {
+    BitVector UndefElements;
+    ConstantFPSDNode *CN =
+        BV->getConstantFPSplatNode(DemandedElts, &UndefElements);
+    if (CN && (UndefElements.none() || AllowUndefs))
+      return CN;
+  }
+
+  return nullptr;
+}
+
 bool llvm::isNullOrNullSplat(SDValue N, bool AllowUndefs) {
   // TODO: may want to use peekThroughBitcast() here.
   ConstantSDNode *C = isConstOrConstSplat(N, AllowUndefs);
@@ -9193,13 +9203,20 @@ bool BuildVectorSDNode::isConstantSplat(
   return true;
 }
 
-SDValue BuildVectorSDNode::getSplatValue(BitVector *UndefElements) const {
+SDValue BuildVectorSDNode::getSplatValue(const APInt &DemandedElts,
+                                         BitVector *UndefElements) const {
   if (UndefElements) {
     UndefElements->clear();
     UndefElements->resize(getNumOperands());
   }
+  assert(getNumOperands() == DemandedElts.getBitWidth() &&
+         "Unexpected vector size");
+  if (!DemandedElts)
+    return SDValue();
   SDValue Splatted;
   for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
+    if (!DemandedElts[i])
+      continue;
     SDValue Op = getOperand(i);
     if (Op.isUndef()) {
       if (UndefElements)
@@ -9212,20 +9229,40 @@ SDValue BuildVectorSDNode::getSplatValue
   }
 
   if (!Splatted) {
-    assert(getOperand(0).isUndef() &&
+    unsigned FirstDemandedIdx = DemandedElts.countTrailingZeros();
+    assert(getOperand(FirstDemandedIdx).isUndef() &&
            "Can only have a splat without a constant for all undefs.");
-    return getOperand(0);
+    return getOperand(FirstDemandedIdx);
   }
 
   return Splatted;
 }
 
+SDValue BuildVectorSDNode::getSplatValue(BitVector *UndefElements) const {
+  APInt DemandedElts = APInt::getAllOnesValue(getNumOperands());
+  return getSplatValue(DemandedElts, UndefElements);
+}
+
+ConstantSDNode *
+BuildVectorSDNode::getConstantSplatNode(const APInt &DemandedElts,
+                                        BitVector *UndefElements) const {
+  return dyn_cast_or_null<ConstantSDNode>(
+      getSplatValue(DemandedElts, UndefElements));
+}
+
 ConstantSDNode *
 BuildVectorSDNode::getConstantSplatNode(BitVector *UndefElements) const {
   return dyn_cast_or_null<ConstantSDNode>(getSplatValue(UndefElements));
 }
 
 ConstantFPSDNode *
+BuildVectorSDNode::getConstantFPSplatNode(const APInt &DemandedElts,
+                                          BitVector *UndefElements) const {
+  return dyn_cast_or_null<ConstantFPSDNode>(
+      getSplatValue(DemandedElts, UndefElements));
+}
+
+ConstantFPSDNode *
 BuildVectorSDNode::getConstantFPSplatNode(BitVector *UndefElements) const {
   return dyn_cast_or_null<ConstantFPSDNode>(getSplatValue(UndefElements));
 }




More information about the llvm-commits mailing list