[llvm] r348953 - [SelectionDAG] Add a generic isSplatValue function

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Wed Dec 12 10:32:29 PST 2018


Author: rksimon
Date: Wed Dec 12 10:32:29 2018
New Revision: 348953

URL: http://llvm.org/viewvc/llvm-project?rev=348953&view=rev
Log:
[SelectionDAG] Add a generic isSplatValue function

This patch introduces a generic function to determine whether a given vector value is known to be a splat for the specified demanded elements, recursing up the DAG (through EXTRACT_SUBVECTOR and ADD/SUB/AND nodes) looking for BUILD_VECTOR or VECTOR_SHUFFLE splat patterns.

It also keeps track of the elements that are known to be UNDEF. It returns true if all the demanded elements are UNDEF (as this may be useful in some circumstances), so callers must handle that case themselves.

A wrapper variant is also provided that doesn't take the DemandedElts or UndefElts arguments, for cases where we just want to know whether the SDValue is a splat or not (with or without UNDEFs).

I had hoped to completely remove the X86-local version of this function, but I'm seeing some regressions in shift/rotate codegen that will take a little longer to fix. I'd like to get this in sooner so I can continue work on PR38243, which needs more capable splat detection.

Differential Revision: https://reviews.llvm.org/D55426

Modified:
    llvm/trunk/include/llvm/CodeGen/SelectionDAG.h
    llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
    llvm/trunk/lib/Target/Mips/MipsSEISelLowering.cpp
    llvm/trunk/lib/Target/X86/X86ISelLowering.cpp

Modified: llvm/trunk/include/llvm/CodeGen/SelectionDAG.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/CodeGen/SelectionDAG.h?rev=348953&r1=348952&r2=348953&view=diff
==============================================================================
--- llvm/trunk/include/llvm/CodeGen/SelectionDAG.h (original)
+++ llvm/trunk/include/llvm/CodeGen/SelectionDAG.h Wed Dec 12 10:32:29 2018
@@ -1515,6 +1515,18 @@ public:
   /// allow an 'add' to be transformed into an 'or'.
   bool haveNoCommonBitsSet(SDValue A, SDValue B) const;
 
+  /// Test whether \p V has a splatted value for all the demanded elements.
+  ///
+  /// On success \p UndefElts will indicate the elements that have UNDEF
+  /// values instead of the splat value; this is only guaranteed to be correct
+  /// for \p DemandedElts. Returns false if \p DemandedElts is empty.
+  ///
+  /// NOTE: The function will return true for a demanded splat of UNDEF values.
+  bool isSplatValue(SDValue V, const APInt &DemandedElts, APInt &UndefElts);
+
+  /// Test whether \p V has a splatted value (UNDEF lanes only if AllowUndefs).
+  bool isSplatValue(SDValue V, bool AllowUndefs = false);
+
   /// Match a binop + shuffle pyramid that represents a horizontal reduction
   /// over the elements of a vector starting from the EXTRACT_VECTOR_ELT node /p
   /// Extract. The reduction must use one of the opcodes listed in /p

Modified: llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp?rev=348953&r1=348952&r2=348953&view=diff
==============================================================================
--- llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp (original)
+++ llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp Wed Dec 12 10:32:29 2018
@@ -2121,6 +2121,102 @@ bool SelectionDAG::MaskedValueIsZero(SDV
   return Mask.isSubsetOf(computeKnownBits(Op, Depth).Zero);
 }
 
+/// Return true if vector V has the same value in every DemandedElts lane,
+/// skipping UNDEF lanes, whose positions are reported back in UndefElts.
+bool SelectionDAG::isSplatValue(SDValue V, const APInt &DemandedElts,
+                                APInt &UndefElts) {
+  if (!DemandedElts)
+    return false; // No demanded elts, better to assume we don't know anything.
+
+  EVT VT = V.getValueType();
+  assert(VT.isVector() && "Vector type expected");
+
+  unsigned NumElts = VT.getVectorNumElements();
+  assert(NumElts == DemandedElts.getBitWidth() && "Vector size mismatch");
+  UndefElts = APInt::getNullValue(NumElts); // Default: no UNDEF lanes.
+
+  switch (V.getOpcode()) {
+  case ISD::BUILD_VECTOR: {
+    SDValue Scl;
+    for (unsigned i = 0; i != NumElts; ++i) {
+      SDValue Op = V.getOperand(i);
+      if (Op.isUndef()) { // Record UNDEF lanes even if not demanded.
+        UndefElts.setBit(i);
+        continue;
+      }
+      if (!DemandedElts[i])
+        continue;
+      if (Scl && Scl != Op) // Same node required; relies on CSE.
+        return false;
+      Scl = Op;
+    }
+    return true; // Also true when every demanded lane is UNDEF.
+  }
+  case ISD::VECTOR_SHUFFLE: {
+    // Check if this is a shuffle node doing a splat.
+    // TODO: Do we need to handle shuffle(splat, undef, mask)?
+    int SplatIndex = -1;
+    ArrayRef<int> Mask = cast<ShuffleVectorSDNode>(V)->getMask();
+    for (int i = 0; i != (int)NumElts; ++i) {
+      int M = Mask[i];
+      if (M < 0) { // Negative mask index means an UNDEF lane.
+        UndefElts.setBit(i);
+        continue;
+      }
+      if (!DemandedElts[i])
+        continue;
+      if (0 <= SplatIndex && SplatIndex != M) // M may index either operand.
+        return false;
+      SplatIndex = M;
+    }
+    return true; // Also true when every demanded lane is UNDEF.
+  }
+  case ISD::EXTRACT_SUBVECTOR: {
+    SDValue Src = V.getOperand(0);
+    ConstantSDNode *SubIdx = dyn_cast<ConstantSDNode>(V.getOperand(1));
+    unsigned NumSrcElts = Src.getValueType().getVectorNumElements();
+    if (SubIdx && SubIdx->getAPIntValue().ule(NumSrcElts - NumElts)) {
+      // Offset the demanded elts by the subvector index.
+      uint64_t Idx = SubIdx->getZExtValue();
+      APInt UndefSrcElts;
+      APInt DemandedSrc = DemandedElts.zextOrSelf(NumSrcElts).shl(Idx);
+      if (isSplatValue(Src, DemandedSrc, UndefSrcElts)) { // NOTE: no depth limit.
+        UndefElts = UndefSrcElts.extractBits(NumElts, Idx);
+        return true;
+      }
+    }
+    break;
+  }
+  case ISD::ADD:
+  case ISD::SUB:
+  case ISD::AND: {
+    APInt UndefLHS, UndefRHS;
+    SDValue LHS = V.getOperand(0);
+    SDValue RHS = V.getOperand(1);
+    if (isSplatValue(LHS, DemandedElts, UndefLHS) && // NOTE: no depth limit.
+        isSplatValue(RHS, DemandedElts, UndefRHS)) {
+      UndefElts = UndefLHS | UndefRHS; // UNDEF if UNDEF in either operand.
+      return true;
+    }
+    break;
+  }
+  }
+
+  return false; // Unhandled opcode, or no splat proven.
+}
+
+/// Test every lane of V; UNDEF lanes are accepted only when AllowUndefs.
+bool SelectionDAG::isSplatValue(SDValue V, bool AllowUndefs) {
+  EVT VT = V.getValueType();
+  assert(VT.isVector() && "Vector type expected");
+  unsigned NumElts = VT.getVectorNumElements();
+
+  APInt UndefElts;
+  APInt DemandedElts = APInt::getAllOnesValue(NumElts); // Demand every lane.
+  return isSplatValue(V, DemandedElts, UndefElts) &&
+         (AllowUndefs || !UndefElts); // All-UNDEF passes if AllowUndefs.
+}
+
 /// Helper function that checks to see if a node is a constant or a
 /// build vector of splat constants at least within the demanded elts.
 static ConstantSDNode *isConstOrDemandedConstSplat(SDValue N,

Modified: llvm/trunk/lib/Target/Mips/MipsSEISelLowering.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/Mips/MipsSEISelLowering.cpp?rev=348953&r1=348952&r2=348953&view=diff
==============================================================================
--- llvm/trunk/lib/Target/Mips/MipsSEISelLowering.cpp (original)
+++ llvm/trunk/lib/Target/Mips/MipsSEISelLowering.cpp Wed Dec 12 10:32:29 2018
@@ -2360,24 +2360,6 @@ SDValue MipsSETargetLowering::lowerINTRI
   }
 }
 
-/// Check if the given BuildVectorSDNode is a splat.
-/// This method currently relies on DAG nodes being reused when equivalent,
-/// so it's possible for this to return false even when isConstantSplat returns
-/// true.
-static bool isSplatVector(const BuildVectorSDNode *N) {
-  unsigned int nOps = N->getNumOperands();
-  assert(nOps > 1 && "isSplatVector has 0 or 1 sized build vector");
-
-  SDValue Operand0 = N->getOperand(0);
-
-  for (unsigned int i = 1; i < nOps; ++i) {
-    if (N->getOperand(i) != Operand0)
-      return false;
-  }
-
-  return true;
-}
-
 // Lower ISD::EXTRACT_VECTOR_ELT into MipsISD::VEXTRACT_SEXT_ELT.
 //
 // The non-value bits resulting from ISD::EXTRACT_VECTOR_ELT are undefined. We
@@ -2488,7 +2470,7 @@ SDValue MipsSETargetLowering::lowerBUILD
       Result = DAG.getNode(ISD::BITCAST, SDLoc(Node), ResTy, Result);
 
     return Result;
-  } else if (isSplatVector(Node))
+  } else if (DAG.isSplatValue(Op, /* AllowUndefs */ false))
     return Op;
   else if (!isConstantOrUndefBUILD_VECTOR(Node)) {
     // Use INSERT_VECTOR_ELT operations rather than expand to stores.

Modified: llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/X86/X86ISelLowering.cpp?rev=348953&r1=348952&r2=348953&view=diff
==============================================================================
--- llvm/trunk/lib/Target/X86/X86ISelLowering.cpp (original)
+++ llvm/trunk/lib/Target/X86/X86ISelLowering.cpp Wed Dec 12 10:32:29 2018
@@ -24072,26 +24072,30 @@ static SDValue LowerScalarImmediateShift
 }
 
 // If V is a splat value, return the source vector and splat index;
-// TODO - can we make this generic and move to SelectionDAG?
-static SDValue IsSplatVector(SDValue V, int &SplatIdx) {
+static SDValue IsSplatVector(SDValue V, int &SplatIdx, SelectionDAG &DAG) {
   V = peekThroughEXTRACT_SUBVECTORs(V);
 
+  EVT VT = V.getValueType();
   unsigned Opcode = V.getOpcode();
   switch (Opcode) {
-  case ISD::BUILD_VECTOR: {
-    BitVector UndefElts;
-    SDValue SplatAmt = cast<BuildVectorSDNode>(V)->getSplatValue(&UndefElts);
-    if (SplatAmt && !SplatAmt.isUndef()) {
-      for (int i = 0, e = UndefElts.size(); i != e; ++i)
-        if (!UndefElts[i]) {
-          SplatIdx = i;
-          return V;
-        }
+  default: {
+    APInt UndefElts;
+    APInt DemandedElts = APInt::getAllOnesValue(VT.getVectorNumElements());
+    if (DAG.isSplatValue(V, DemandedElts, UndefElts)) {
+      // Handle case where all demanded elements are UNDEF.
+      if (DemandedElts.isSubsetOf(UndefElts)) {
+        SplatIdx = 0;
+        return DAG.getUNDEF(VT);
+      }
+      SplatIdx = (UndefElts & DemandedElts).countTrailingOnes();
+      return V;
     }
     break;
   }
   case ISD::VECTOR_SHUFFLE: {
     // Check if this is a shuffle node doing a splat.
+    // TODO - remove this and rely purely on SelectionDAG::isSplatValue,
+    // getTargetVShiftNode currently struggles without the splat source.
     auto *SVN = cast<ShuffleVectorSDNode>(V);
     if (!SVN->isSplat())
       break;
@@ -24100,23 +24104,6 @@ static SDValue IsSplatVector(SDValue V,
     SplatIdx = Idx % NumElts;
     return V.getOperand(Idx / NumElts);
   }
-  case ISD::SUB: {
-    SDValue LHS = peekThroughEXTRACT_SUBVECTORs(V.getOperand(0));
-    SDValue RHS = peekThroughEXTRACT_SUBVECTORs(V.getOperand(1));
-
-    // Ensure that the corresponding splat BV element is not UNDEF.
-    BitVector UndefElts;
-    auto *BV0 = dyn_cast<BuildVectorSDNode>(LHS);
-    auto *SVN1 = dyn_cast<ShuffleVectorSDNode>(RHS);
-    if (BV0 && SVN1 && BV0->getSplatValue(&UndefElts) && SVN1->isSplat()) {
-      int Idx = SVN1->getSplatIndex();
-      if (!UndefElts[Idx]) {
-        SplatIdx = Idx;
-        return V;
-      }
-    }
-    break;
-  }
   }
 
   return SDValue();
@@ -24125,7 +24112,7 @@ static SDValue IsSplatVector(SDValue V,
 static SDValue GetSplatValue(SDValue V, const SDLoc &dl,
                              SelectionDAG &DAG) {
   int SplatIdx;
-  if (SDValue SrcVector = IsSplatVector(V, SplatIdx))
+  if (SDValue SrcVector = IsSplatVector(V, SplatIdx, DAG))
     return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl,
                        SrcVector.getValueType().getScalarType(), SrcVector,
                        DAG.getIntPtrConstant(SplatIdx, dl));
@@ -24850,8 +24837,7 @@ static SDValue LowerRotate(SDValue Op, c
   // Rotate by splat - expand back to shifts.
   // TODO - legalizers should be able to handle this.
   if (EltSizeInBits >= 16 || Subtarget.hasBWI()) {
-    int SplatIdx;
-    if (IsSplatVector(Amt, SplatIdx)) {
+    if (DAG.isSplatValue(Amt)) {
       SDValue AmtR = DAG.getConstant(EltSizeInBits, DL, VT);
       AmtR = DAG.getNode(ISD::SUB, DL, VT, AmtR, Amt);
       SDValue SHL = DAG.getNode(ISD::SHL, DL, VT, R, Amt);




More information about the llvm-commits mailing list