[llvm] f60d3ec - [DAG] Add BuildVectorSDNode::getConstantRawBits helper

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Mon Nov 8 04:07:46 PST 2021


Author: Simon Pilgrim
Date: 2021-11-08T12:07:38Z
New Revision: f60d3ec0c7fd324bccf2275c8f28c390b2b5f069

URL: https://github.com/llvm/llvm-project/commit/f60d3ec0c7fd324bccf2275c8f28c390b2b5f069
DIFF: https://github.com/llvm/llvm-project/commit/f60d3ec0c7fd324bccf2275c8f28c390b2b5f069.diff

LOG: [DAG] Add BuildVectorSDNode::getConstantRawBits helper

We have several places where we need to extract the raw bits data from a BUILD_VECTOR node, so consolidate this to a single helper function that handles Undefs and Integer/FP constants, including implicit truncation.

This should make it easier to extend D113202 to handle more constant folding of bitcasted constant data.

Differential Revision: https://reviews.llvm.org/D113351

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/SelectionDAGNodes.h
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
    llvm/lib/Target/X86/X86ISelLowering.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index cc00af90ec67..c2c5dbc26478 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -2049,6 +2049,14 @@ class BuildVectorSDNode : public SDNode {
   int32_t getConstantFPSplatPow2ToLog2Int(BitVector *UndefElements,
                                           uint32_t BitWidth) const;
 
+  /// Extract the raw bit data from a build vector of Undef, Constant or
+  /// ConstantFP elements, repacked into \p DstEltSizeInBits wide elements in
+  /// \p IsLittleEndian order. Undef bits read as zero; all-undef dst elements
+  /// are flagged in \p UndefElements. Returns false for non-constant operands.
+  bool getConstantRawBits(bool IsLittleEndian, unsigned DstEltSizeInBits,
+                          SmallVectorImpl<APInt> &RawBitElements,
+                          BitVector &UndefElements) const;
+
   bool isConstant() const;
 
   static bool classof(const SDNode *N) {

diff  --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index d01e3b3fd663..9f40a0247c2a 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -13039,68 +13039,30 @@ ConstantFoldBITCASTofBUILD_VECTOR(SDNode *BV, EVT DstEltVT) {
     return ConstantFoldBITCASTofBUILD_VECTOR(Tmp, DstEltVT);
   }
 
-  SDLoc DL(BV);
-
   // Okay, we know the src/dst types are both integers of differing types.
-  // Handling growing first.
   assert(SrcEltVT.isInteger() && DstEltVT.isInteger());
-  if (SrcBitSize < DstBitSize) {
-    unsigned NumInputsPerOutput = DstBitSize/SrcBitSize;
 
-    SmallVector<SDValue, 8> Ops;
-    for (unsigned i = 0, e = BV->getNumOperands(); i != e;
-         i += NumInputsPerOutput) {
-      bool isLE = DAG.getDataLayout().isLittleEndian();
-      APInt NewBits = APInt(DstBitSize, 0);
-      bool EltIsUndef = true;
-      for (unsigned j = 0; j != NumInputsPerOutput; ++j) {
-        // Shift the previously computed bits over.
-        NewBits <<= SrcBitSize;
-        SDValue Op = BV->getOperand(i+ (isLE ? (NumInputsPerOutput-j-1) : j));
-        if (Op.isUndef()) continue;
-        EltIsUndef = false;
-
-        NewBits |= cast<ConstantSDNode>(Op)->getAPIntValue().
-                   zextOrTrunc(SrcBitSize).zext(DstBitSize);
-      }
+  // TODO: Should ConstantFoldBITCASTofBUILD_VECTOR always take a
+  // BuildVectorSDNode?
+  auto *BVN = cast<BuildVectorSDNode>(BV);
 
-      if (EltIsUndef)
-        Ops.push_back(DAG.getUNDEF(DstEltVT));
-      else
-        Ops.push_back(DAG.getConstant(NewBits, DL, DstEltVT));
-    }
-
-    EVT VT = EVT::getVectorVT(*DAG.getContext(), DstEltVT, Ops.size());
-    return DAG.getBuildVector(VT, DL, Ops);
-  }
+  // Extract the constant raw bit data.
+  BitVector UndefElements;
+  SmallVector<APInt> RawBits;
+  bool IsLE = DAG.getDataLayout().isLittleEndian();
+  if (!BVN->getConstantRawBits(IsLE, DstBitSize, RawBits, UndefElements))
+    return SDValue();
 
-  // Finally, this must be the case where we are shrinking elements: each input
-  // turns into multiple outputs.
-  unsigned NumOutputsPerInput = SrcBitSize/DstBitSize;
-  EVT VT = EVT::getVectorVT(*DAG.getContext(), DstEltVT,
-                            NumOutputsPerInput*BV->getNumOperands());
+  SDLoc DL(BV);
   SmallVector<SDValue, 8> Ops;
-
-  for (const SDValue &Op : BV->op_values()) {
-    if (Op.isUndef()) {
-      Ops.append(NumOutputsPerInput, DAG.getUNDEF(DstEltVT));
-      continue;
-    }
-
-    APInt OpVal = cast<ConstantSDNode>(Op)->
-                  getAPIntValue().zextOrTrunc(SrcBitSize);
-
-    for (unsigned j = 0; j != NumOutputsPerInput; ++j) {
-      APInt ThisVal = OpVal.trunc(DstBitSize);
-      Ops.push_back(DAG.getConstant(ThisVal, DL, DstEltVT));
-      OpVal.lshrInPlace(DstBitSize);
-    }
-
-    // For big endian targets, swap the order of the pieces of each element.
-    if (DAG.getDataLayout().isBigEndian())
-      std::reverse(Ops.end()-NumOutputsPerInput, Ops.end());
+  for (unsigned I = 0, E = RawBits.size(); I != E; ++I) {
+    if (UndefElements[I])
+      Ops.push_back(DAG.getUNDEF(DstEltVT));
+    else
+      Ops.push_back(DAG.getConstant(RawBits[I], DL, DstEltVT));
   }
 
+  EVT VT = EVT::getVectorVT(*DAG.getContext(), DstEltVT, Ops.size());
   return DAG.getBuildVector(VT, DL, Ops);
 }
 

diff  --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 96673897d0e6..6739f53eec23 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -10916,6 +10916,85 @@ BuildVectorSDNode::getConstantFPSplatPow2ToLog2Int(BitVector *UndefElements,
   return -1;
 }
 
+bool BuildVectorSDNode::getConstantRawBits(
+    bool IsLittleEndian, unsigned DstEltSizeInBits,
+    SmallVectorImpl<APInt> &RawBitElements, BitVector &UndefElements) const {
+  // Early-out if this contains anything but Undef/Constant/ConstantFP.
+  if (!isConstant())
+    return false;
+
+  unsigned NumSrcOps = getNumOperands();
+  unsigned SrcEltSizeInBits = getValueType(0).getScalarSizeInBits();
+  assert(((NumSrcOps * SrcEltSizeInBits) % DstEltSizeInBits) == 0 &&
+         "Invalid bitcast scale");
+  // Both paths below assume one element size is a whole multiple of the
+  // other; any other ratio would silently mis-pack the constant bits.
+  assert(((SrcEltSizeInBits % DstEltSizeInBits) == 0 ||
+          (DstEltSizeInBits % SrcEltSizeInBits) == 0) &&
+         "Element sizes must be multiples of each other");
+
+  unsigned NumDstOps = (NumSrcOps * SrcEltSizeInBits) / DstEltSizeInBits;
+  UndefElements.clear();
+  UndefElements.resize(NumDstOps, false);
+  RawBitElements.assign(NumDstOps, APInt::getNullValue(DstEltSizeInBits));
+
+  // Concatenate src elements constant bits together into dst element. On
+  // little-endian targets the lowest-indexed src element supplies the low
+  // bits of each dst element; on big-endian targets it supplies the high bits.
+  if (SrcEltSizeInBits <= DstEltSizeInBits) {
+    unsigned Scale = DstEltSizeInBits / SrcEltSizeInBits;
+    for (unsigned I = 0; I != NumDstOps; ++I) {
+      UndefElements.set(I);
+      APInt &RawBits = RawBitElements[I];
+      for (unsigned J = 0; J != Scale; ++J) {
+        unsigned Idx = (I * Scale) + (IsLittleEndian ? J : (Scale - J - 1));
+        SDValue Op = getOperand(Idx);
+        if (Op.isUndef())
+          continue;
+        UndefElements.reset(I);
+        auto *CInt = dyn_cast<ConstantSDNode>(Op);
+        auto *CFP = dyn_cast<ConstantFPSDNode>(Op);
+        assert((CInt || CFP) && "Unknown constant");
+        APInt EltBits =
+            CInt ? CInt->getAPIntValue().truncOrSelf(SrcEltSizeInBits)
+                 : CFP->getValueAPF().bitcastToAPInt();
+        assert(EltBits.getBitWidth() == SrcEltSizeInBits &&
+               "Illegal constant bitwidths");
+        RawBits.insertBits(EltBits, J * SrcEltSizeInBits);
+      }
+    }
+    return true;
+  }
+
+  // Split src element constant bits into dst elements.
+  unsigned Scale = SrcEltSizeInBits / DstEltSizeInBits;
+  for (unsigned I = 0; I != NumSrcOps; ++I) {
+    SDValue Op = getOperand(I);
+    if (Op.isUndef()) {
+      UndefElements.set(I * Scale, (I + 1) * Scale);
+      continue;
+    }
+    auto *CInt = dyn_cast<ConstantSDNode>(Op);
+    auto *CFP = dyn_cast<ConstantFPSDNode>(Op);
+    assert((CInt || CFP) && "Unknown constant");
+    // Drop any implicitly-truncated high bits of integer operands, exactly as
+    // the concatenation path does, so only the element's own bits are split.
+    APInt EltBits =
+        CInt ? CInt->getAPIntValue().truncOrSelf(SrcEltSizeInBits)
+             : CFP->getValueAPF().bitcastToAPInt();
+    assert(EltBits.getBitWidth() == SrcEltSizeInBits &&
+           "Illegal constant bitwidths");
+
+    for (unsigned J = 0; J != Scale; ++J) {
+      unsigned Idx = (I * Scale) + (IsLittleEndian ? J : (Scale - J - 1));
+      APInt &RawBits = RawBitElements[Idx];
+      RawBits = EltBits.extractBits(DstEltSizeInBits, J * DstEltSizeInBits);
+    }
+  }
+
+  return true;
+}
+
 bool BuildVectorSDNode::isConstant() const {
   for (const SDValue &Op : op_values()) {
     unsigned Opc = Op.getOpcode();

diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index ba2a16c531b0..cf7649138e46 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -6879,40 +6879,17 @@ static bool getTargetConstantBitsFromNode(SDValue Op, unsigned EltSizeInBits,
   }
 
   // Extract constant bits from build vector.
-  if (ISD::isBuildVectorOfConstantSDNodes(Op.getNode())) {
+  if (auto *BV = dyn_cast<BuildVectorSDNode>(Op)) {
+    BitVector Undefs;
+    SmallVector<APInt> SrcEltBits;
     unsigned SrcEltSizeInBits = VT.getScalarSizeInBits();
-    unsigned NumSrcElts = SizeInBits / SrcEltSizeInBits;
-
-    APInt UndefSrcElts(NumSrcElts, 0);
-    SmallVector<APInt, 64> SrcEltBits(NumSrcElts, APInt(SrcEltSizeInBits, 0));
-    for (unsigned i = 0, e = Op.getNumOperands(); i != e; ++i) {
-      const SDValue &Src = Op.getOperand(i);
-      if (Src.isUndef()) {
-        UndefSrcElts.setBit(i);
-        continue;
-      }
-      auto *Cst = cast<ConstantSDNode>(Src);
-      SrcEltBits[i] = Cst->getAPIntValue().zextOrTrunc(SrcEltSizeInBits);
+    if (BV->getConstantRawBits(true, SrcEltSizeInBits, SrcEltBits, Undefs)) {
+      APInt UndefSrcElts = APInt::getNullValue(SrcEltBits.size());
+      for (unsigned I = 0, E = SrcEltBits.size(); I != E; ++I)
+        if (Undefs[I])
+          UndefSrcElts.setBit(I);
+      return CastBitData(UndefSrcElts, SrcEltBits);
     }
-    return CastBitData(UndefSrcElts, SrcEltBits);
-  }
-  if (ISD::isBuildVectorOfConstantFPSDNodes(Op.getNode())) {
-    unsigned SrcEltSizeInBits = VT.getScalarSizeInBits();
-    unsigned NumSrcElts = SizeInBits / SrcEltSizeInBits;
-
-    APInt UndefSrcElts(NumSrcElts, 0);
-    SmallVector<APInt, 64> SrcEltBits(NumSrcElts, APInt(SrcEltSizeInBits, 0));
-    for (unsigned i = 0, e = Op.getNumOperands(); i != e; ++i) {
-      const SDValue &Src = Op.getOperand(i);
-      if (Src.isUndef()) {
-        UndefSrcElts.setBit(i);
-        continue;
-      }
-      auto *Cst = cast<ConstantFPSDNode>(Src);
-      APInt RawBits = Cst->getValueAPF().bitcastToAPInt();
-      SrcEltBits[i] = RawBits.zextOrTrunc(SrcEltSizeInBits);
-    }
-    return CastBitData(UndefSrcElts, SrcEltBits);
   }
 
   // Extract constant bits from constant pool vector.


        


More information about the llvm-commits mailing list