[llvm] [DAG] Add wrappers for insert_vector_elt and extract_vector_elt (PR #139141)
via llvm-commits
llvm-commits at lists.llvm.org
Thu May 8 13:07:03 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-risc-v
Author: Philip Reames (preames)
<details>
<summary>Changes</summary>
As with the recently added subvector variants, these wrappers take the element index as a plain unsigned value (building the index constant internally), which simplifies a bunch of code.
---
Full diff: https://github.com/llvm/llvm-project/pull/139141.diff
3 Files Affected:
- (modified) llvm/include/llvm/CodeGen/SelectionDAG.h (+16)
- (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp (+9-15)
- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+15-27)
``````````diff
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 1b990e29158fd..251ae14c79068 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -924,6 +924,22 @@ class SelectionDAG {
/// Example: shuffle A, B, <0,5,2,7> -> shuffle B, A, <4,1,6,3>
SDValue getCommutedVectorShuffle(const ShuffleVectorSDNode &SV);
+ /// Extract element at \p Idx from \p Vec. See EXTRACT_VECTOR_ELT
+ /// description for result type handling.
+ SDValue getExtractVectorElt(const SDLoc &DL, EVT VT, SDValue Vec,
+ unsigned Idx) {
+ return getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Vec,
+ getVectorIdxConstant(Idx, DL));
+ }
+
+ /// Insert \p Elt into \p Vec at offset \p Idx. See INSERT_VECTOR_ELT
+ /// description for element type handling.
+ SDValue getInsertVectorElt(const SDLoc &DL, SDValue Vec, SDValue Elt,
+ unsigned Idx) {
+ return getNode(ISD::INSERT_VECTOR_ELT, DL, Vec.getValueType(), Vec, Elt,
+ getVectorIdxConstant(Idx, DL));
+ }
+
/// Insert \p SubVec at the \p Idx element of \p Vec.
SDValue getInsertSubvector(const SDLoc &DL, SDValue Vec, SDValue SubVec,
unsigned Idx) {
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index effe08cdd44f8..bbf1b0fd590ef 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -3244,8 +3244,7 @@ SDValue SelectionDAG::getSplatValue(SDValue V, bool LegalTypes) {
if (LegalSVT.bitsLT(SVT))
return SDValue();
}
- return getNode(ISD::EXTRACT_VECTOR_ELT, SDLoc(V), LegalSVT, SrcVector,
- getVectorIdxConstant(SplatIdx, SDLoc(V)));
+ return getExtractVectorElt(SDLoc(V), LegalSVT, SrcVector, SplatIdx);
}
return SDValue();
}
@@ -7557,11 +7556,10 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
// elements.
if (N2C && N1.getOpcode() == ISD::CONCAT_VECTORS &&
N1.getOperand(0).getValueType().isFixedLengthVector()) {
- unsigned Factor =
- N1.getOperand(0).getValueType().getVectorNumElements();
- return getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT,
- N1.getOperand(N2C->getZExtValue() / Factor),
- getVectorIdxConstant(N2C->getZExtValue() % Factor, DL));
+ unsigned Factor = N1.getOperand(0).getValueType().getVectorNumElements();
+ return getExtractVectorElt(DL, VT,
+ N1.getOperand(N2C->getZExtValue() / Factor),
+ N2C->getZExtValue() % Factor);
}
// EXTRACT_VECTOR_ELT of BUILD_VECTOR or SPLAT_VECTOR is often formed while
@@ -8624,8 +8622,7 @@ static SDValue getMemsetStores(SelectionDAG &DAG, const SDLoc &dl,
// Target which can combine store(extractelement VectorTy, Idx) can get
// the smaller value for free.
SDValue TailValue = DAG.getNode(ISD::BITCAST, dl, SVT, MemSetValue);
- Value = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, VT, TailValue,
- DAG.getVectorIdxConstant(Index, dl));
+ Value = DAG.getExtractVectorElt(dl, VT, TailValue, Index);
} else
Value = getMemsetValue(Src, VT, DAG, dl);
}
@@ -12775,8 +12772,7 @@ SDValue SelectionDAG::UnrollVectorOp(SDNode *N, unsigned ResNE) {
// A vector operand; extract a single element.
EVT OperandEltVT = OperandVT.getVectorElementType();
- Operands[j] = getNode(ISD::EXTRACT_VECTOR_ELT, dl, OperandEltVT,
- Operand, getVectorIdxConstant(i, dl));
+ Operands[j] = getExtractVectorElt(dl, OperandEltVT, Operand, i);
}
SDValue EltOp = getNode(N->getOpcode(), dl, {EltVT, EltVT1}, Operands);
@@ -12810,8 +12806,7 @@ SDValue SelectionDAG::UnrollVectorOp(SDNode *N, unsigned ResNE) {
if (OperandVT.isVector()) {
// A vector operand; extract a single element.
EVT OperandEltVT = OperandVT.getVectorElementType();
- Operands[j] = getNode(ISD::EXTRACT_VECTOR_ELT, dl, OperandEltVT,
- Operand, getVectorIdxConstant(i, dl));
+ Operands[j] = getExtractVectorElt(dl, OperandEltVT, Operand, i);
} else {
// A scalar operand; just use it as is.
Operands[j] = Operand;
@@ -13090,8 +13085,7 @@ void SelectionDAG::ExtractVectorElements(SDValue Op,
EltVT = VT.getVectorElementType();
SDLoc SL(Op);
for (unsigned i = Start, e = Start + Count; i != e; ++i) {
- Args.push_back(getNode(ISD::EXTRACT_VECTOR_ELT, SL, EltVT, Op,
- getVectorIdxConstant(i, SL)));
+ Args.push_back(getExtractVectorElt(SL, EltVT, Op, i));
}
}
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index d65e921dfc660..68f01d1d675b7 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -3805,8 +3805,7 @@ static SDValue lowerBuildVectorViaDominantValues(SDValue Op, SelectionDAG &DAG,
if (V.isUndef() || !Processed.insert(V).second)
continue;
if (ValueCounts[V] == 1) {
- Vec = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VT, Vec, V,
- DAG.getVectorIdxConstant(OpIdx.index(), DL));
+ Vec = DAG.getInsertVectorElt(DL, Vec, V, OpIdx.index());
} else {
// Blend in all instances of this value using a VSELECT, using a
// mask where each bit signals whether that element is the one
@@ -3963,10 +3962,9 @@ static SDValue lowerBuildVectorOfConstants(SDValue Op, SelectionDAG &DAG,
if (ViaIntVT == MVT::i32)
SplatValue = SignExtend64<32>(SplatValue);
- SDValue Vec = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, ViaVecVT,
- DAG.getUNDEF(ViaVecVT),
- DAG.getSignedConstant(SplatValue, DL, XLenVT),
- DAG.getVectorIdxConstant(0, DL));
+ SDValue Vec = DAG.getInsertVectorElt(
+ DL, DAG.getUNDEF(ViaVecVT),
+ DAG.getSignedConstant(SplatValue, DL, XLenVT), 0);
if (ViaVecLen != 1)
Vec = DAG.getExtractSubvector(DL, MVT::getVectorVT(ViaIntVT, 1), Vec, 0);
return DAG.getBitcast(VT, Vec);
@@ -7180,9 +7178,8 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
EVT BVT = EVT::getVectorVT(*DAG.getContext(), Op0VT, 1);
if (!isTypeLegal(BVT))
return SDValue();
- return DAG.getBitcast(VT, DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, BVT,
- DAG.getUNDEF(BVT), Op0,
- DAG.getVectorIdxConstant(0, DL)));
+ return DAG.getBitcast(
+ VT, DAG.getInsertVectorElt(DL, DAG.getUNDEF(BVT), Op0, 0));
}
return SDValue();
}
@@ -7194,8 +7191,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
if (!isTypeLegal(BVT))
return SDValue();
SDValue BVec = DAG.getBitcast(BVT, Op0);
- return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, BVec,
- DAG.getVectorIdxConstant(0, DL));
+ return DAG.getExtractVectorElt(DL, VT, BVec, 0);
}
return SDValue();
}
@@ -9916,8 +9912,7 @@ SDValue RISCVTargetLowering::lowerEXTRACT_VECTOR_ELT(SDValue Op,
if (!EltVT.isInteger()) {
// Floating-point extracts are handled in TableGen.
- return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Vec,
- DAG.getVectorIdxConstant(0, DL));
+ return DAG.getExtractVectorElt(DL, EltVT, Vec, 0);
}
SDValue Elt0 = DAG.getNode(RISCVISD::VMV_X_S, DL, XLenVT, Vec);
@@ -10321,8 +10316,7 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
return DAG.getNode(ISD::TRUNCATE, DL, Op.getValueType(), Res);
}
case Intrinsic::riscv_vfmv_f_s:
- return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, Op.getValueType(),
- Op.getOperand(1), DAG.getVectorIdxConstant(0, DL));
+ return DAG.getExtractVectorElt(DL, Op.getValueType(), Op.getOperand(1), 0);
case Intrinsic::riscv_vmv_v_x:
return lowerScalarSplat(Op.getOperand(1), Op.getOperand(2),
Op.getOperand(3), Op.getSimpleValueType(), DL, DAG,
@@ -10856,8 +10850,7 @@ static SDValue lowerReductionSeq(unsigned RVVOpcode, MVT ResVT,
SDValue Policy = DAG.getTargetConstant(RISCVVType::TAIL_AGNOSTIC, DL, XLenVT);
SDValue Ops[] = {PassThru, Vec, InitialValue, Mask, VL, Policy};
SDValue Reduction = DAG.getNode(RVVOpcode, DL, M1VT, Ops);
- return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResVT, Reduction,
- DAG.getVectorIdxConstant(0, DL));
+ return DAG.getExtractVectorElt(DL, ResVT, Reduction, 0);
}
SDValue RISCVTargetLowering::lowerVECREDUCE(SDValue Op,
@@ -10902,8 +10895,7 @@ SDValue RISCVTargetLowering::lowerVECREDUCE(SDValue Op,
case ISD::UMIN:
case ISD::SMAX:
case ISD::SMIN:
- StartV = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VecEltVT, Vec,
- DAG.getVectorIdxConstant(0, DL));
+ StartV = DAG.getExtractVectorElt(DL, VecEltVT, Vec, 0);
}
return lowerReductionSeq(RVVOpcode, Op.getSimpleValueType(), StartV, Vec,
Mask, VL, DL, DAG, Subtarget);
@@ -10934,9 +10926,7 @@ getRVVFPReductionOpAndOperands(SDValue Op, SelectionDAG &DAG, EVT EltVT,
case ISD::VECREDUCE_FMAXIMUM:
case ISD::VECREDUCE_FMIN:
case ISD::VECREDUCE_FMAX: {
- SDValue Front =
- DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Op.getOperand(0),
- DAG.getVectorIdxConstant(0, DL));
+ SDValue Front = DAG.getExtractVectorElt(DL, EltVT, Op.getOperand(0), 0);
unsigned RVVOpc =
(Opcode == ISD::VECREDUCE_FMIN || Opcode == ISD::VECREDUCE_FMINIMUM)
? RISCVISD::VECREDUCE_FMIN_VL
@@ -14055,8 +14045,7 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
EVT BVT = EVT::getVectorVT(*DAG.getContext(), VT, 1);
if (isTypeLegal(BVT)) {
SDValue BVec = DAG.getBitcast(BVT, Op0);
- Results.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, BVec,
- DAG.getVectorIdxConstant(0, DL)));
+ Results.push_back(DAG.getExtractVectorElt(DL, VT, BVec, 0));
}
}
break;
@@ -18204,12 +18193,11 @@ static SDValue performINSERT_VECTOR_ELTCombine(SDNode *N, SelectionDAG &DAG,
if (ConcatVT.getVectorElementType() != InVal.getValueType())
return SDValue();
unsigned ConcatNumElts = ConcatVT.getVectorNumElements();
- SDValue NewIdx = DAG.getVectorIdxConstant(Elt % ConcatNumElts, DL);
+ unsigned NewIdx = Elt % ConcatNumElts;
unsigned ConcatOpIdx = Elt / ConcatNumElts;
SDValue ConcatOp = InVec.getOperand(ConcatOpIdx);
- ConcatOp = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, ConcatVT,
- ConcatOp, InVal, NewIdx);
+ ConcatOp = DAG.getInsertVectorElt(DL, ConcatOp, InVal, NewIdx);
SmallVector<SDValue> ConcatOps(InVec->ops());
ConcatOps[ConcatOpIdx] = ConcatOp;
``````````
</details>
https://github.com/llvm/llvm-project/pull/139141
More information about the llvm-commits
mailing list