[llvm] a200b0f - [DAG] Introduce getSplat utility for common dispatch pattern [nfc]
Philip Reames via llvm-commits
llvm-commits at lists.llvm.org
Mon Oct 3 12:50:18 PDT 2022
Author: Philip Reames
Date: 2022-10-03T12:49:39-07:00
New Revision: a200b0fc256a890b3f72014d20fce9e49d75763b
URL: https://github.com/llvm/llvm-project/commit/a200b0fc256a890b3f72014d20fce9e49d75763b
DIFF: https://github.com/llvm/llvm-project/commit/a200b0fc256a890b3f72014d20fce9e49d75763b.diff
LOG: [DAG] Introduce getSplat utility for common dispatch pattern [nfc]
We have a very common pattern, repeated in many places in the code, of dispatching between BUILD_VECTOR and SPLAT_VECTOR creation. Factor the pattern out into a utility function.
Added:
Modified:
llvm/include/llvm/CodeGen/SelectionDAG.h
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 969199b6f381c..b6f71ab575b95 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -862,6 +862,16 @@ class SelectionDAG {
return getNode(ISD::SPLAT_VECTOR, DL, VT, Op);
}
+ /// Returns a node splatting Op into every lane of vector type VT: a
+ /// SPLAT_VECTOR (via getSplatVector) when VT is scalable, otherwise a
+ /// BUILD_VECTOR (via getSplatBuildVector). VT must be a vector type
+ /// (asserted); callers no longer need to dispatch on scalability.
+ SDValue getSplat(EVT VT, const SDLoc &DL, SDValue Op) {
+ assert(VT.isVector() && "Can't splat to non-vector type");
+ return VT.isScalableVector() ?
+ getSplatVector(VT, DL, Op) : getSplatBuildVector(VT, DL, Op);
+ }
+
/// Returns a vector of type ResVT whose elements contain the linear sequence
/// <0, Step, Step * 2, Step * 3, ...>
SDValue getStepVector(const SDLoc &DL, EVT ResVT, APInt StepVal);
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index a6c9c46201568..9281b4eb92b05 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -3469,11 +3469,8 @@ SDValue DAGCombiner::visitSUB(SDNode *N) {
if (VT.isVector()) {
SDValue N1S = DAG.getSplatValue(N1, true);
if (N1S && N1S.getOpcode() == ISD::SUB &&
- isNullConstant(N1S.getOperand(0))) {
- if (VT.isScalableVector())
- return DAG.getSplatVector(VT, DL, N1S.getOperand(1));
- return DAG.getSplatBuildVector(VT, DL, N1S.getOperand(1));
- }
+ isNullConstant(N1S.getOperand(0)))
+ return DAG.getSplat(VT, DL, N1S.getOperand(1));
}
}
@@ -19778,11 +19775,8 @@ SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) {
if (!IndexC) {
// If this is variable insert to undef vector, it might be better to splat:
// inselt undef, InVal, EltNo --> build_vector < InVal, InVal, ... >
- if (InVec.isUndef() && TLI.shouldSplatInsEltVarIndex(VT)) {
- if (VT.isScalableVector())
- return DAG.getSplatVector(VT, DL, InVal);
- return DAG.getSplatBuildVector(VT, DL, InVal);
- }
+ if (InVec.isUndef() && TLI.shouldSplatInsEltVarIndex(VT))
+ return DAG.getSplat(VT, DL, InVal);
return SDValue();
}
@@ -23817,9 +23811,7 @@ static SDValue scalarizeBinOpOfSplats(SDNode *N, SelectionDAG &DAG,
}
// bo (splat X, Index), (splat Y, Index) --> splat (bo X, Y), Index
- if (VT.isScalableVector())
- return DAG.getSplatVector(VT, DL, ScalarBO);
- return DAG.getSplatBuildVector(VT, DL, ScalarBO);
+ return DAG.getSplat(VT, DL, ScalarBO);
}
/// Visit a binary vector operation, like ADD.
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index 0132bf4a8affb..6f0cde6fbddb3 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -963,10 +963,7 @@ SDValue VectorLegalizer::ExpandSELECT(SDNode *Node) {
DAG.getConstant(0, DL, BitTy));
// Broadcast the mask so that the entire vector is all one or all zero.
- if (VT.isFixedLengthVector())
- Mask = DAG.getSplatBuildVector(MaskTy, DL, Mask);
- else
- Mask = DAG.getSplatVector(MaskTy, DL, Mask);
+ Mask = DAG.getSplat(MaskTy, DL, Mask);
// Bitcast the operands to be the same type as the mask.
// This is needed when we select between FP types because
@@ -1309,8 +1306,7 @@ SDValue VectorLegalizer::ExpandVP_MERGE(SDNode *Node) {
return DAG.UnrollVectorOp(Node);
SDValue StepVec = DAG.getStepVector(DL, EVLVecVT);
- SDValue SplatEVL = IsFixedLen ? DAG.getSplatBuildVector(EVLVecVT, DL, EVL)
- : DAG.getSplatVector(EVLVecVT, DL, EVL);
+ SDValue SplatEVL = DAG.getSplat(EVLVecVT, DL, EVL);
SDValue EVLMask =
DAG.getSetCC(DL, MaskVT, StepVec, SplatEVL, ISD::CondCode::SETULT);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 3c2a1166bb63c..80703066d5f02 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -1607,11 +1607,8 @@ SDValue SelectionDAG::getConstant(const ConstantInt &Val, const SDLoc &DL,
}
SDValue Result(N, 0);
- if (VT.isScalableVector())
- Result = getSplatVector(VT, DL, Result);
- else if (VT.isVector())
- Result = getSplatBuildVector(VT, DL, Result);
-
+ if (VT.isVector())
+ Result = getSplat(VT, DL, Result);
return Result;
}
@@ -1663,10 +1660,8 @@ SDValue SelectionDAG::getConstantFP(const ConstantFP &V, const SDLoc &DL,
}
SDValue Result(N, 0);
- if (VT.isScalableVector())
- Result = getSplatVector(VT, DL, Result);
- else if (VT.isVector())
- Result = getSplatBuildVector(VT, DL, Result);
+ if (VT.isVector())
+ Result = getSplat(VT, DL, Result);
NewSDValueDbgMsg(Result, "Creating fp constant: ", this);
return Result;
}
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 60e63035a8987..3308134044f27 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -1695,9 +1695,7 @@ SDValue SelectionDAGBuilder::getValueImpl(const Value *V) {
else
Op = DAG.getConstant(0, getCurSDLoc(), EltVT);
- if (isa<ScalableVectorType>(VecTy))
- return NodeMap[V] = DAG.getSplatVector(VT, getCurSDLoc(), Op);
- return NodeMap[V] = DAG.getSplatBuildVector(VT, getCurSDLoc(), Op);
+ return NodeMap[V] = DAG.getSplat(VT, getCurSDLoc(), Op);
}
llvm_unreachable("Unknown vector constant");
@@ -3904,10 +3902,7 @@ void SelectionDAGBuilder::visitGetElementPtr(const User &I) {
if (IsVectorGEP && !N.getValueType().isVector()) {
LLVMContext &Context = *DAG.getContext();
EVT VT = EVT::getVectorVT(Context, N.getValueType(), VectorElementCount);
- if (VectorElementCount.isScalable())
- N = DAG.getSplatVector(VT, dl, N);
- else
- N = DAG.getSplatBuildVector(VT, dl, N);
+ N = DAG.getSplat(VT, dl, N);
}
for (gep_type_iterator GTI = gep_type_begin(&I), E = gep_type_end(&I);
@@ -3979,10 +3974,7 @@ void SelectionDAGBuilder::visitGetElementPtr(const User &I) {
if (!IdxN.getValueType().isVector() && IsVectorGEP) {
EVT VT = EVT::getVectorVT(*Context, IdxN.getValueType(),
VectorElementCount);
- if (VectorElementCount.isScalable())
- IdxN = DAG.getSplatVector(VT, dl, IdxN);
- else
- IdxN = DAG.getSplatBuildVector(VT, dl, IdxN);
+ IdxN = DAG.getSplat(VT, dl, IdxN);
}
// If the index is smaller or larger than intptr_t, truncate or extend
@@ -7247,14 +7239,8 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
SDValue TripCount = getValue(I.getOperand(1));
auto VecTy = CCVT.changeVectorElementType(ElementVT);
- SDValue VectorIndex, VectorTripCount;
- if (VecTy.isScalableVector()) {
- VectorIndex = DAG.getSplatVector(VecTy, sdl, Index);
- VectorTripCount = DAG.getSplatVector(VecTy, sdl, TripCount);
- } else {
- VectorIndex = DAG.getSplatBuildVector(VecTy, sdl, Index);
- VectorTripCount = DAG.getSplatBuildVector(VecTy, sdl, TripCount);
- }
+ SDValue VectorIndex = DAG.getSplat(VecTy, sdl, Index);
+ SDValue VectorTripCount = DAG.getSplat(VecTy, sdl, TripCount);
SDValue VectorStep = DAG.getStepVector(sdl, VecTy);
SDValue VectorInduction = DAG.getNode(
ISD::UADDSAT, sdl, VecTy, VectorIndex, VectorStep);
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index af8fdb77c9dd1..3b36521463d02 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -4192,9 +4192,7 @@ SDValue RISCVTargetLowering::lowerSELECT(SDValue Op, SelectionDAG &DAG) const {
// Lower vector SELECTs to VSELECTs by splatting the condition.
if (VT.isVector()) {
MVT SplatCondVT = VT.changeVectorElementType(MVT::i1);
- SDValue CondSplat = VT.isScalableVector()
- ? DAG.getSplatVector(SplatCondVT, DL, CondV)
- : DAG.getSplatBuildVector(SplatCondVT, DL, CondV);
+ SDValue CondSplat = DAG.getSplat(SplatCondVT, DL, CondV);
return DAG.getNode(ISD::VSELECT, DL, VT, CondSplat, TrueV, FalseV);
}
More information about the llvm-commits mailing list