[llvm] r250118 - [SelectionDAG] Add common vector constant folding helper function
Simon Pilgrim via llvm-commits
llvm-commits at lists.llvm.org
Mon Oct 12 16:00:11 PDT 2015
Author: rksimon
Date: Mon Oct 12 18:00:11 2015
New Revision: 250118
URL: http://llvm.org/viewvc/llvm-project?rev=250118&view=rev
Log:
[SelectionDAG] Add common vector constant folding helper function
We have a number of functions that implement constant folding of vectors (unary and binary ops) in nearly identical ways (and the differences don't appear to be critical).
This patch introduces a common implementation (SelectionDAG::FoldConstantVectorArithmetic) and calls this in both the unary and binary op cases.
After this initial patch I intend to begin enabling vector constant folding for a wider number of opcodes in SelectionDAG::getNode().
Differential Revision: http://reviews.llvm.org/D13665
Modified:
llvm/trunk/include/llvm/CodeGen/SelectionDAG.h
llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
Modified: llvm/trunk/include/llvm/CodeGen/SelectionDAG.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/CodeGen/SelectionDAG.h?rev=250118&r1=250117&r2=250118&view=diff
==============================================================================
--- llvm/trunk/include/llvm/CodeGen/SelectionDAG.h (original)
+++ llvm/trunk/include/llvm/CodeGen/SelectionDAG.h Mon Oct 12 18:00:11 2015
@@ -1166,6 +1166,10 @@ public:
const ConstantSDNode *Cst1,
const ConstantSDNode *Cst2);
+  /// Constant fold a vector operation lane by lane. Every operand must be a
+  /// vector with the same number of elements as \p VT and must be either
+  /// UNDEF or a BUILD_VECTOR of constant/UNDEF scalars; target-specific
+  /// opcodes are never folded. Returns the folded BUILD_VECTOR, or an empty
+  /// SDValue if any lane does not fold to a constant or UNDEF.
+  SDValue FoldConstantVectorArithmetic(unsigned Opcode, SDLoc DL,
+                                       EVT VT, ArrayRef<SDValue> Ops,
+                                       const SDNodeFlags *Flags = nullptr);
+
/// Constant fold a setcc to true or false.
SDValue FoldSetCC(EVT VT, SDValue N1,
SDValue N2, ISD::CondCode Cond, SDLoc dl);
Modified: llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp?rev=250118&r1=250117&r2=250118&view=diff
==============================================================================
--- llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (original)
+++ llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp Mon Oct 12 18:00:11 2015
@@ -13454,70 +13454,12 @@ SDValue DAGCombiner::SimplifyVBinOp(SDNo
SDValue LHS = N->getOperand(0);
SDValue RHS = N->getOperand(1);
+ SDValue Ops[] = {LHS, RHS};
- // If the LHS and RHS are BUILD_VECTOR nodes, see if we can constant fold
- // this operation.
- if (LHS.getOpcode() == ISD::BUILD_VECTOR &&
- RHS.getOpcode() == ISD::BUILD_VECTOR) {
- // Check if both vectors are constants. If not bail out.
- if (!(cast<BuildVectorSDNode>(LHS)->isConstant() &&
- cast<BuildVectorSDNode>(RHS)->isConstant()))
- return SDValue();
-
- SmallVector<SDValue, 8> Ops;
- for (unsigned i = 0, e = LHS.getNumOperands(); i != e; ++i) {
- SDValue LHSOp = LHS.getOperand(i);
- SDValue RHSOp = RHS.getOperand(i);
-
- // Can't fold divide by zero.
- if (N->getOpcode() == ISD::SDIV || N->getOpcode() == ISD::UDIV ||
- N->getOpcode() == ISD::FDIV) {
- if (isNullConstant(RHSOp) || (RHSOp.getOpcode() == ISD::ConstantFP &&
- cast<ConstantFPSDNode>(RHSOp.getNode())->isZero()))
- break;
- }
-
- EVT VT = LHSOp.getValueType();
- EVT RVT = RHSOp.getValueType();
- EVT ST = VT;
-
- if (RVT.getSizeInBits() < VT.getSizeInBits())
- ST = RVT;
-
- // Integer BUILD_VECTOR operands may have types larger than the element
- // size (e.g., when the element type is not legal). Prior to type
- // legalization, the types may not match between the two BUILD_VECTORS.
- // Truncate the operands to make them match.
- if (VT.getSizeInBits() != LHS.getValueType().getScalarSizeInBits()) {
- EVT ScalarT = LHS.getValueType().getScalarType();
- LHSOp = DAG.getNode(ISD::TRUNCATE, SDLoc(N), ScalarT, LHSOp);
- VT = LHSOp.getValueType();
- }
- if (RVT.getSizeInBits() != RHS.getValueType().getScalarSizeInBits()) {
- EVT ScalarT = RHS.getValueType().getScalarType();
- RHSOp = DAG.getNode(ISD::TRUNCATE, SDLoc(N), ScalarT, RHSOp);
- RVT = RHSOp.getValueType();
- }
-
- SDValue FoldOp = DAG.getNode(N->getOpcode(), SDLoc(LHS), VT,
- LHSOp, RHSOp, N->getFlags());
-
- // We need the resulting constant to be legal if we are in a phase after
- // legalization, so zero extend to the smallest operand type if required.
- if (ST != VT && Level != BeforeLegalizeTypes)
- FoldOp = DAG.getNode(ISD::ANY_EXTEND, SDLoc(LHS), ST, FoldOp);
-
- if (FoldOp.getOpcode() != ISD::UNDEF &&
- FoldOp.getOpcode() != ISD::Constant &&
- FoldOp.getOpcode() != ISD::ConstantFP)
- break;
- Ops.push_back(FoldOp);
- AddToWorklist(FoldOp.getNode());
- }
-
- if (Ops.size() == LHS.getNumOperands())
- return DAG.getNode(ISD::BUILD_VECTOR, SDLoc(N), LHS.getValueType(), Ops);
- }
+ // See if we can constant fold the vector operation.
+ if (SDValue Fold = DAG.FoldConstantVectorArithmetic(
+ N->getOpcode(), SDLoc(LHS), LHS.getValueType(), Ops, N->getFlags()))
+ return Fold;
// Try to convert a constant mask AND into a shuffle clear mask.
if (SDValue Shuffle = XformToShuffleWithZero(N))
Modified: llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp?rev=250118&r1=250117&r2=250118&view=diff
==============================================================================
--- llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp (original)
+++ llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp Mon Oct 12 18:00:11 2015
@@ -3010,44 +3010,9 @@ SDValue SelectionDAG::getNode(unsigned O
case ISD::CTTZ:
case ISD::CTTZ_ZERO_UNDEF:
case ISD::CTPOP: {
- EVT SVT = VT.getScalarType();
- EVT InVT = BV->getValueType(0);
- EVT InSVT = InVT.getScalarType();
-
- // Find legal integer scalar type for constant promotion and
- // ensure that its scalar size is at least as large as source.
- EVT LegalSVT = SVT;
- if (SVT.isInteger()) {
- LegalSVT = TLI->getTypeToTransformTo(*getContext(), SVT);
- if (LegalSVT.bitsLT(SVT)) break;
- }
-
- // Let the above scalar folding handle the folding of each element.
- SmallVector<SDValue, 8> Ops;
- for (int i = 0, e = VT.getVectorNumElements(); i != e; ++i) {
- SDValue OpN = BV->getOperand(i);
- EVT OpVT = OpN.getValueType();
-
- // Build vector (integer) scalar operands may need implicit
- // truncation - do this before constant folding.
- if (OpVT.isInteger() && OpVT.bitsGT(InSVT))
- OpN = getNode(ISD::TRUNCATE, DL, InSVT, OpN);
-
- OpN = getNode(Opcode, DL, SVT, OpN);
-
- // Legalize the (integer) scalar constant if necessary.
- if (LegalSVT != SVT)
- OpN = getNode(ISD::ANY_EXTEND, DL, LegalSVT, OpN);
-
- if (OpN.getOpcode() != ISD::UNDEF &&
- OpN.getOpcode() != ISD::Constant &&
- OpN.getOpcode() != ISD::ConstantFP)
- break;
- Ops.push_back(OpN);
- }
- if (Ops.size() == VT.getVectorNumElements())
- return getNode(ISD::BUILD_VECTOR, DL, VT, Ops);
- break;
+ SDValue Ops = { Operand };
+ if (SDValue Fold = FoldConstantVectorArithmetic(Opcode, DL, VT, Ops))
+ return Fold;
}
}
}
@@ -3348,6 +3313,93 @@ SDValue SelectionDAG::FoldConstantArithm
return getNode(ISD::BUILD_VECTOR, SDLoc(), VT, Outputs);
}
+/// Constant fold a vector operation by folding each scalar lane separately.
+///
+/// All operands must be vectors with the same element count as \p VT and
+/// must be either UNDEF or a BUILD_VECTOR whose elements are constants/UNDEF.
+/// Each lane is folded by building the scalar node with getNode() and
+/// accepting it only if the result is a Constant, ConstantFP or UNDEF; integer
+/// results are any-extended to the promoted scalar type when SVT is not legal.
+/// Returns the resulting BUILD_VECTOR, or an empty SDValue if folding fails.
+SDValue SelectionDAG::FoldConstantVectorArithmetic(unsigned Opcode, SDLoc DL,
+                                                   EVT VT,
+                                                   ArrayRef<SDValue> Ops,
+                                                   const SDNodeFlags *Flags) {
+  // If the opcode is a target-specific ISD node, there's nothing we can
+  // do here and the operand rules may not line up with the below, so
+  // bail early.
+  if (Opcode >= ISD::BUILTIN_OP_END)
+    return SDValue();
+
+  // We can only fold vectors - maybe merge with FoldConstantArithmetic someday?
+  if (!VT.isVector())
+    return SDValue();
+
+  unsigned NumElts = VT.getVectorNumElements();
+
+  auto IsSameVectorSize = [&](const SDValue &Op) {
+    return Op.getValueType().isVector() &&
+           Op.getValueType().getVectorNumElements() == NumElts;
+  };
+
+  auto IsConstantBuildVectorOrUndef = [&](const SDValue &Op) {
+    BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(Op);
+    return (Op.getOpcode() == ISD::UNDEF) || (BV && BV->isConstant());
+  };
+
+  // All operands must be vector types with the same number of elements as
+  // the result type and must be either UNDEF or a build vector of constant
+  // or UNDEF scalars.
+  if (!std::all_of(Ops.begin(), Ops.end(), IsConstantBuildVectorOrUndef) ||
+      !std::all_of(Ops.begin(), Ops.end(), IsSameVectorSize))
+    return SDValue();
+
+  // Find legal integer scalar type for constant promotion and
+  // ensure that its scalar size is at least as large as source.
+  EVT SVT = VT.getScalarType();
+  EVT LegalSVT = SVT;
+  if (SVT.isInteger()) {
+    LegalSVT = TLI->getTypeToTransformTo(*getContext(), SVT);
+    if (LegalSVT.bitsLT(SVT))
+      return SDValue();
+  }
+
+  // Constant fold each scalar lane separately.
+  SmallVector<SDValue, 4> ScalarResults;
+  ScalarResults.reserve(NumElts);
+  for (unsigned i = 0; i != NumElts; i++) {
+    SmallVector<SDValue, 4> ScalarOps;
+    for (SDValue Op : Ops) {
+      EVT InSVT = Op.getValueType().getScalarType();
+      BuildVectorSDNode *InBV = dyn_cast<BuildVectorSDNode>(Op);
+      if (!InBV) {
+        // We've checked that this is UNDEF above. Build the scalar UNDEF with
+        // the operand's own scalar type (the same type BUILD_VECTOR elements
+        // are truncated to below), not the promoted LegalSVT, so that all
+        // scalar operands of a lane agree in type.
+        ScalarOps.push_back(getUNDEF(InSVT));
+        continue;
+      }
+
+      SDValue ScalarOp = InBV->getOperand(i);
+      EVT ScalarVT = ScalarOp.getValueType();
+
+      // Build vector (integer) scalar operands may need implicit
+      // truncation - do this before constant folding.
+      if (ScalarVT.isInteger() && ScalarVT.bitsGT(InSVT))
+        ScalarOp = getNode(ISD::TRUNCATE, DL, InSVT, ScalarOp);
+
+      ScalarOps.push_back(ScalarOp);
+    }
+
+    // Constant fold the scalar operands.
+    SDValue ScalarResult = getNode(Opcode, DL, SVT, ScalarOps, Flags);
+
+    // Legalize the (integer) scalar constant if necessary.
+    if (LegalSVT != SVT)
+      ScalarResult = getNode(ISD::ANY_EXTEND, DL, LegalSVT, ScalarResult);
+
+    // Scalar folding only succeeded if the result is a constant or UNDEF.
+    if (ScalarResult.getOpcode() != ISD::UNDEF &&
+        ScalarResult.getOpcode() != ISD::Constant &&
+        ScalarResult.getOpcode() != ISD::ConstantFP)
+      return SDValue();
+    ScalarResults.push_back(ScalarResult);
+  }
+
+  assert(ScalarResults.size() == NumElts &&
+         "Unexpected number of scalar results for BUILD_VECTOR");
+  return getNode(ISD::BUILD_VECTOR, DL, VT, ScalarResults);
+}
+
SDValue SelectionDAG::getNode(unsigned Opcode, SDLoc DL, EVT VT, SDValue N1,
SDValue N2, const SDNodeFlags *Flags) {
ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
More information about the llvm-commits mailing list