[llvm] r250118 - [SelectionDAG] Add common vector constant folding helper function

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Mon Oct 12 16:00:11 PDT 2015


Author: rksimon
Date: Mon Oct 12 18:00:11 2015
New Revision: 250118

URL: http://llvm.org/viewvc/llvm-project?rev=250118&view=rev
Log:
[SelectionDAG] Add common vector constant folding helper function

We have a number of functions that implement constant folding of vectors (unary and binary ops) in nearly identical ways (and the differences don't appear to be critical).

This patch introduces a common implementation (SelectionDAG::FoldConstantVectorArithmetic) and calls this in both the unary and binary op cases.

After this initial patch, I intend to begin enabling vector constant folding for a wider range of opcodes in SelectionDAG::getNode().

Differential Revision: http://reviews.llvm.org/D13665

Modified:
    llvm/trunk/include/llvm/CodeGen/SelectionDAG.h
    llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Modified: llvm/trunk/include/llvm/CodeGen/SelectionDAG.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/CodeGen/SelectionDAG.h?rev=250118&r1=250117&r2=250118&view=diff
==============================================================================
--- llvm/trunk/include/llvm/CodeGen/SelectionDAG.h (original)
+++ llvm/trunk/include/llvm/CodeGen/SelectionDAG.h Mon Oct 12 18:00:11 2015
@@ -1166,6 +1166,10 @@ public:
                                  const ConstantSDNode *Cst1,
                                  const ConstantSDNode *Cst2);
 
+  /// Constant fold a vector operation lane by lane when every operand is
+  /// UNDEF or a constant BUILD_VECTOR. Flags, if non-null, are forwarded to
+  /// the per-lane scalar getNode() calls. Returns a null SDValue on failure.
+  SDValue FoldConstantVectorArithmetic(unsigned Opcode, SDLoc DL,
+                                       EVT VT, ArrayRef<SDValue> Ops,
+                                       const SDNodeFlags *Flags = nullptr);
+
   /// Constant fold a setcc to true or false.
   SDValue FoldSetCC(EVT VT, SDValue N1,
                     SDValue N2, ISD::CondCode Cond, SDLoc dl);

Modified: llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp?rev=250118&r1=250117&r2=250118&view=diff
==============================================================================
--- llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (original)
+++ llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp Mon Oct 12 18:00:11 2015
@@ -13454,70 +13454,12 @@ SDValue DAGCombiner::SimplifyVBinOp(SDNo
 
   SDValue LHS = N->getOperand(0);
   SDValue RHS = N->getOperand(1);
+  SDValue Ops[] = {LHS, RHS};
 
-  // If the LHS and RHS are BUILD_VECTOR nodes, see if we can constant fold
-  // this operation.
-  if (LHS.getOpcode() == ISD::BUILD_VECTOR &&
-      RHS.getOpcode() == ISD::BUILD_VECTOR) {
-    // Check if both vectors are constants. If not bail out.
-    if (!(cast<BuildVectorSDNode>(LHS)->isConstant() &&
-          cast<BuildVectorSDNode>(RHS)->isConstant()))
-      return SDValue();
-
-    SmallVector<SDValue, 8> Ops;
-    for (unsigned i = 0, e = LHS.getNumOperands(); i != e; ++i) {
-      SDValue LHSOp = LHS.getOperand(i);
-      SDValue RHSOp = RHS.getOperand(i);
-
-      // Can't fold divide by zero.
-      if (N->getOpcode() == ISD::SDIV || N->getOpcode() == ISD::UDIV ||
-          N->getOpcode() == ISD::FDIV) {
-        if (isNullConstant(RHSOp) || (RHSOp.getOpcode() == ISD::ConstantFP &&
-             cast<ConstantFPSDNode>(RHSOp.getNode())->isZero()))
-          break;
-      }
-
-      EVT VT = LHSOp.getValueType();
-      EVT RVT = RHSOp.getValueType();
-      EVT ST = VT;
-
-      if (RVT.getSizeInBits() < VT.getSizeInBits())
-        ST = RVT;
-
-      // Integer BUILD_VECTOR operands may have types larger than the element
-      // size (e.g., when the element type is not legal).  Prior to type
-      // legalization, the types may not match between the two BUILD_VECTORS.
-      // Truncate the operands to make them match.
-      if (VT.getSizeInBits() != LHS.getValueType().getScalarSizeInBits()) {
-        EVT ScalarT = LHS.getValueType().getScalarType();
-        LHSOp = DAG.getNode(ISD::TRUNCATE, SDLoc(N), ScalarT, LHSOp);
-        VT = LHSOp.getValueType();
-      }
-      if (RVT.getSizeInBits() != RHS.getValueType().getScalarSizeInBits()) {
-        EVT ScalarT = RHS.getValueType().getScalarType();
-        RHSOp = DAG.getNode(ISD::TRUNCATE, SDLoc(N), ScalarT, RHSOp);
-        RVT = RHSOp.getValueType();
-      }
-
-      SDValue FoldOp = DAG.getNode(N->getOpcode(), SDLoc(LHS), VT,
-                                   LHSOp, RHSOp, N->getFlags());
-
-      // We need the resulting constant to be legal if we are in a phase after
-      // legalization, so zero extend to the smallest operand type if required.
-      if (ST != VT && Level != BeforeLegalizeTypes)
-        FoldOp = DAG.getNode(ISD::ANY_EXTEND, SDLoc(LHS), ST, FoldOp);
-
-      if (FoldOp.getOpcode() != ISD::UNDEF &&
-          FoldOp.getOpcode() != ISD::Constant &&
-          FoldOp.getOpcode() != ISD::ConstantFP)
-        break;
-      Ops.push_back(FoldOp);
-      AddToWorklist(FoldOp.getNode());
-    }
-
-    if (Ops.size() == LHS.getNumOperands())
-      return DAG.getNode(ISD::BUILD_VECTOR, SDLoc(N), LHS.getValueType(), Ops);
-  }
+  // See if we can constant fold the vector operation.
+  if (SDValue Fold = DAG.FoldConstantVectorArithmetic(
+          N->getOpcode(), SDLoc(LHS), LHS.getValueType(), Ops, N->getFlags()))
+    return Fold;
 
   // Try to convert a constant mask AND into a shuffle clear mask.
   if (SDValue Shuffle = XformToShuffleWithZero(N))

Modified: llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp?rev=250118&r1=250117&r2=250118&view=diff
==============================================================================
--- llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp (original)
+++ llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp Mon Oct 12 18:00:11 2015
@@ -3010,44 +3010,9 @@ SDValue SelectionDAG::getNode(unsigned O
       case ISD::CTTZ:
       case ISD::CTTZ_ZERO_UNDEF:
       case ISD::CTPOP: {
-        EVT SVT = VT.getScalarType();
-        EVT InVT = BV->getValueType(0);
-        EVT InSVT = InVT.getScalarType();
-
-        // Find legal integer scalar type for constant promotion and
-        // ensure that its scalar size is at least as large as source.
-        EVT LegalSVT = SVT;
-        if (SVT.isInteger()) {
-          LegalSVT = TLI->getTypeToTransformTo(*getContext(), SVT);
-          if (LegalSVT.bitsLT(SVT)) break;
-        }
-
-        // Let the above scalar folding handle the folding of each element.
-        SmallVector<SDValue, 8> Ops;
-        for (int i = 0, e = VT.getVectorNumElements(); i != e; ++i) {
-          SDValue OpN = BV->getOperand(i);
-          EVT OpVT = OpN.getValueType();
-
-          // Build vector (integer) scalar operands may need implicit
-          // truncation - do this before constant folding.
-          if (OpVT.isInteger() && OpVT.bitsGT(InSVT))
-            OpN = getNode(ISD::TRUNCATE, DL, InSVT, OpN);
-
-          OpN = getNode(Opcode, DL, SVT, OpN);
-
-          // Legalize the (integer) scalar constant if necessary.
-          if (LegalSVT != SVT)
-            OpN = getNode(ISD::ANY_EXTEND, DL, LegalSVT, OpN);
-
-          if (OpN.getOpcode() != ISD::UNDEF &&
-              OpN.getOpcode() != ISD::Constant &&
-              OpN.getOpcode() != ISD::ConstantFP)
-            break;
-          Ops.push_back(OpN);
-        }
-        if (Ops.size() == VT.getVectorNumElements())
-          return getNode(ISD::BUILD_VECTOR, DL, VT, Ops);
-        break;
+        SDValue Ops = { Operand };
+        if (SDValue Fold = FoldConstantVectorArithmetic(Opcode, DL, VT, Ops))
+          return Fold;
       }
       }
     }
@@ -3348,6 +3313,93 @@ SDValue SelectionDAG::FoldConstantArithm
   return getNode(ISD::BUILD_VECTOR, SDLoc(), VT, Outputs);
 }
 
+/// Constant fold a vector operation lane by lane. Each operand must be UNDEF
+/// or a constant BUILD_VECTOR with VT's element count; every lane is folded
+/// via the scalar getNode() and the results are rebuilt as a BUILD_VECTOR of
+/// type VT. Returns a null SDValue if any check or lane fold fails.
+SDValue SelectionDAG::FoldConstantVectorArithmetic(unsigned Opcode, SDLoc DL,
+                                                   EVT VT,
+                                                   ArrayRef<SDValue> Ops,
+                                                   const SDNodeFlags *Flags) {
+  // Bail on target-specific opcodes, whose operand rules may not match the
+  // folding below, and on scalar VTs (TODO: merge FoldConstantArithmetic?).
+  if (Opcode >= ISD::BUILTIN_OP_END || !VT.isVector())
+    return SDValue();
+
+  unsigned NumElts = VT.getVectorNumElements();
+
+  auto IsSameVectorSize = [&](const SDValue &Op) {
+    return Op.getValueType().isVector() &&
+           Op.getValueType().getVectorNumElements() == NumElts;
+  };
+
+  auto IsConstantBuildVectorOrUndef = [&](const SDValue &Op) {
+    BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(Op);
+    return (Op.getOpcode() == ISD::UNDEF) || (BV && BV->isConstant());
+  };
+
+  // All operands must be vector types with the same number of elements as
+  // the result type and must be either UNDEF or a build vector of constant
+  // or UNDEF scalars.
+  if (!std::all_of(Ops.begin(), Ops.end(), IsConstantBuildVectorOrUndef) ||
+      !std::all_of(Ops.begin(), Ops.end(), IsSameVectorSize))
+    return SDValue();
+
+  // Find legal integer scalar type for constant promotion and
+  // ensure that its scalar size is at least as large as source.
+  EVT SVT = VT.getScalarType();
+  EVT LegalSVT = SVT;
+  if (SVT.isInteger()) {
+    LegalSVT = TLI->getTypeToTransformTo(*getContext(), SVT);
+    if (LegalSVT.bitsLT(SVT))
+      return SDValue();
+  }
+
+  // Constant fold each scalar lane separately.
+  SmallVector<SDValue, 4> ScalarResults;
+  for (unsigned i = 0; i != NumElts; i++) {
+    SmallVector<SDValue, 4> ScalarOps;
+    for (SDValue Op : Ops) {
+      EVT InSVT = Op->getValueType(0).getScalarType();
+      BuildVectorSDNode *InBV = dyn_cast<BuildVectorSDNode>(Op);
+      if (!InBV) {
+        // Known UNDEF (checked above); give it the operand's scalar type to
+        // match the (possibly truncated) BUILD_VECTOR scalars, not LegalSVT.
+        ScalarOps.push_back(getUNDEF(InSVT));
+        continue;
+      }
+
+      SDValue ScalarOp = InBV->getOperand(i);
+      EVT ScalarVT = ScalarOp.getValueType();
+
+      // Build vector (integer) scalar operands may need implicit
+      // truncation - do this before constant folding.
+      if (ScalarVT.isInteger() && ScalarVT.bitsGT(InSVT))
+        ScalarOp = getNode(ISD::TRUNCATE, DL, InSVT, ScalarOp);
+
+      ScalarOps.push_back(ScalarOp);
+    }
+
+    // Constant fold the scalar operands.
+    SDValue ScalarResult = getNode(Opcode, DL, SVT, ScalarOps, Flags);
+
+    // Legalize the (integer) scalar constant if necessary.
+    if (LegalSVT != SVT)
+      ScalarResult = getNode(ISD::ANY_EXTEND, DL, LegalSVT, ScalarResult);
+
+    // Scalar folding only succeeded if the result is a constant or UNDEF.
+    if (ScalarResult.getOpcode() != ISD::UNDEF &&
+        ScalarResult.getOpcode() != ISD::Constant &&
+        ScalarResult.getOpcode() != ISD::ConstantFP)
+      return SDValue();
+    ScalarResults.push_back(ScalarResult);
+  }
+
+  assert(ScalarResults.size() == NumElts &&
+         "Unexpected number of scalar results for BUILD_VECTOR");
+  return getNode(ISD::BUILD_VECTOR, DL, VT, ScalarResults);
+}
+
 SDValue SelectionDAG::getNode(unsigned Opcode, SDLoc DL, EVT VT, SDValue N1,
                               SDValue N2, const SDNodeFlags *Flags) {
   ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);




More information about the llvm-commits mailing list