[llvm] [DAG] Add wrappers for insert_vector_elt and extract_vector_elt (PR #139141)

via llvm-commits llvm-commits at lists.llvm.org
Thu May 8 13:07:03 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-risc-v

Author: Philip Reames (preames)

<details>
<summary>Changes</summary>

As with the recently added subvector variants, provide the unsigned index operand to simplify a bunch of code.

---
Full diff: https://github.com/llvm/llvm-project/pull/139141.diff


3 Files Affected:

- (modified) llvm/include/llvm/CodeGen/SelectionDAG.h (+16) 
- (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp (+9-15) 
- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+15-27) 


``````````diff
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 1b990e29158fd..251ae14c79068 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -924,6 +924,22 @@ class SelectionDAG {
   /// Example: shuffle A, B, <0,5,2,7> -> shuffle B, A, <4,1,6,3>
   SDValue getCommutedVectorShuffle(const ShuffleVectorSDNode &SV);
 
+  /// Extract element at \p Idx from \p Vec.  See EXTRACT_VECTOR_ELT
+  /// description for result type handling.
+  SDValue getExtractVectorElt(const SDLoc &DL, EVT VT, SDValue Vec,
+                              unsigned Idx) {
+    return getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Vec,
+                   getVectorIdxConstant(Idx, DL));
+  }
+
+  /// Insert \p Elt into \p Vec at offset \p Idx.  See INSERT_VECTOR_ELT
+  /// description for element type handling.
+  SDValue getInsertVectorElt(const SDLoc &DL, SDValue Vec, SDValue Elt,
+                             unsigned Idx) {
+    return getNode(ISD::INSERT_VECTOR_ELT, DL, Vec.getValueType(), Vec, Elt,
+                   getVectorIdxConstant(Idx, DL));
+  }
+
   /// Insert \p SubVec at the \p Idx element of \p Vec.
   SDValue getInsertSubvector(const SDLoc &DL, SDValue Vec, SDValue SubVec,
                              unsigned Idx) {
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index effe08cdd44f8..bbf1b0fd590ef 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -3244,8 +3244,7 @@ SDValue SelectionDAG::getSplatValue(SDValue V, bool LegalTypes) {
       if (LegalSVT.bitsLT(SVT))
         return SDValue();
     }
-    return getNode(ISD::EXTRACT_VECTOR_ELT, SDLoc(V), LegalSVT, SrcVector,
-                   getVectorIdxConstant(SplatIdx, SDLoc(V)));
+    return getExtractVectorElt(SDLoc(V), LegalSVT, SrcVector, SplatIdx);
   }
   return SDValue();
 }
@@ -7557,11 +7556,10 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
     // elements.
     if (N2C && N1.getOpcode() == ISD::CONCAT_VECTORS &&
         N1.getOperand(0).getValueType().isFixedLengthVector()) {
-      unsigned Factor =
-        N1.getOperand(0).getValueType().getVectorNumElements();
-      return getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT,
-                     N1.getOperand(N2C->getZExtValue() / Factor),
-                     getVectorIdxConstant(N2C->getZExtValue() % Factor, DL));
+      unsigned Factor = N1.getOperand(0).getValueType().getVectorNumElements();
+      return getExtractVectorElt(DL, VT,
+                                 N1.getOperand(N2C->getZExtValue() / Factor),
+                                 N2C->getZExtValue() % Factor);
     }
 
     // EXTRACT_VECTOR_ELT of BUILD_VECTOR or SPLAT_VECTOR is often formed while
@@ -8624,8 +8622,7 @@ static SDValue getMemsetStores(SelectionDAG &DAG, const SDLoc &dl,
         // Target which can combine store(extractelement VectorTy, Idx) can get
         // the smaller value for free.
         SDValue TailValue = DAG.getNode(ISD::BITCAST, dl, SVT, MemSetValue);
-        Value = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, VT, TailValue,
-                            DAG.getVectorIdxConstant(Index, dl));
+        Value = DAG.getExtractVectorElt(dl, VT, TailValue, Index);
       } else
         Value = getMemsetValue(Src, VT, DAG, dl);
     }
@@ -12775,8 +12772,7 @@ SDValue SelectionDAG::UnrollVectorOp(SDNode *N, unsigned ResNE) {
 
         // A vector operand; extract a single element.
         EVT OperandEltVT = OperandVT.getVectorElementType();
-        Operands[j] = getNode(ISD::EXTRACT_VECTOR_ELT, dl, OperandEltVT,
-                              Operand, getVectorIdxConstant(i, dl));
+        Operands[j] = getExtractVectorElt(dl, OperandEltVT, Operand, i);
       }
 
       SDValue EltOp = getNode(N->getOpcode(), dl, {EltVT, EltVT1}, Operands);
@@ -12810,8 +12806,7 @@ SDValue SelectionDAG::UnrollVectorOp(SDNode *N, unsigned ResNE) {
       if (OperandVT.isVector()) {
         // A vector operand; extract a single element.
         EVT OperandEltVT = OperandVT.getVectorElementType();
-        Operands[j] = getNode(ISD::EXTRACT_VECTOR_ELT, dl, OperandEltVT,
-                              Operand, getVectorIdxConstant(i, dl));
+        Operands[j] = getExtractVectorElt(dl, OperandEltVT, Operand, i);
       } else {
         // A scalar operand; just use it as is.
         Operands[j] = Operand;
@@ -13090,8 +13085,7 @@ void SelectionDAG::ExtractVectorElements(SDValue Op,
     EltVT = VT.getVectorElementType();
   SDLoc SL(Op);
   for (unsigned i = Start, e = Start + Count; i != e; ++i) {
-    Args.push_back(getNode(ISD::EXTRACT_VECTOR_ELT, SL, EltVT, Op,
-                           getVectorIdxConstant(i, SL)));
+    Args.push_back(getExtractVectorElt(SL, EltVT, Op, i));
   }
 }
 
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index d65e921dfc660..68f01d1d675b7 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -3805,8 +3805,7 @@ static SDValue lowerBuildVectorViaDominantValues(SDValue Op, SelectionDAG &DAG,
       if (V.isUndef() || !Processed.insert(V).second)
         continue;
       if (ValueCounts[V] == 1) {
-        Vec = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VT, Vec, V,
-                          DAG.getVectorIdxConstant(OpIdx.index(), DL));
+        Vec = DAG.getInsertVectorElt(DL, Vec, V, OpIdx.index());
       } else {
         // Blend in all instances of this value using a VSELECT, using a
         // mask where each bit signals whether that element is the one
@@ -3963,10 +3962,9 @@ static SDValue lowerBuildVectorOfConstants(SDValue Op, SelectionDAG &DAG,
     if (ViaIntVT == MVT::i32)
       SplatValue = SignExtend64<32>(SplatValue);
 
-    SDValue Vec = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, ViaVecVT,
-                              DAG.getUNDEF(ViaVecVT),
-                              DAG.getSignedConstant(SplatValue, DL, XLenVT),
-                              DAG.getVectorIdxConstant(0, DL));
+    SDValue Vec = DAG.getInsertVectorElt(
+        DL, DAG.getUNDEF(ViaVecVT),
+        DAG.getSignedConstant(SplatValue, DL, XLenVT), 0);
     if (ViaVecLen != 1)
       Vec = DAG.getExtractSubvector(DL, MVT::getVectorVT(ViaIntVT, 1), Vec, 0);
     return DAG.getBitcast(VT, Vec);
@@ -7180,9 +7178,8 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
         EVT BVT = EVT::getVectorVT(*DAG.getContext(), Op0VT, 1);
         if (!isTypeLegal(BVT))
           return SDValue();
-        return DAG.getBitcast(VT, DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, BVT,
-                                              DAG.getUNDEF(BVT), Op0,
-                                              DAG.getVectorIdxConstant(0, DL)));
+        return DAG.getBitcast(
+            VT, DAG.getInsertVectorElt(DL, DAG.getUNDEF(BVT), Op0, 0));
       }
       return SDValue();
     }
@@ -7194,8 +7191,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
       if (!isTypeLegal(BVT))
         return SDValue();
       SDValue BVec = DAG.getBitcast(BVT, Op0);
-      return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, BVec,
-                         DAG.getVectorIdxConstant(0, DL));
+      return DAG.getExtractVectorElt(DL, VT, BVec, 0);
     }
     return SDValue();
   }
@@ -9916,8 +9912,7 @@ SDValue RISCVTargetLowering::lowerEXTRACT_VECTOR_ELT(SDValue Op,
 
   if (!EltVT.isInteger()) {
     // Floating-point extracts are handled in TableGen.
-    return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Vec,
-                       DAG.getVectorIdxConstant(0, DL));
+    return DAG.getExtractVectorElt(DL, EltVT, Vec, 0);
   }
 
   SDValue Elt0 = DAG.getNode(RISCVISD::VMV_X_S, DL, XLenVT, Vec);
@@ -10321,8 +10316,7 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
     return DAG.getNode(ISD::TRUNCATE, DL, Op.getValueType(), Res);
   }
   case Intrinsic::riscv_vfmv_f_s:
-    return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, Op.getValueType(),
-                       Op.getOperand(1), DAG.getVectorIdxConstant(0, DL));
+    return DAG.getExtractVectorElt(DL, Op.getValueType(), Op.getOperand(1), 0);
   case Intrinsic::riscv_vmv_v_x:
     return lowerScalarSplat(Op.getOperand(1), Op.getOperand(2),
                             Op.getOperand(3), Op.getSimpleValueType(), DL, DAG,
@@ -10856,8 +10850,7 @@ static SDValue lowerReductionSeq(unsigned RVVOpcode, MVT ResVT,
   SDValue Policy = DAG.getTargetConstant(RISCVVType::TAIL_AGNOSTIC, DL, XLenVT);
   SDValue Ops[] = {PassThru, Vec, InitialValue, Mask, VL, Policy};
   SDValue Reduction = DAG.getNode(RVVOpcode, DL, M1VT, Ops);
-  return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResVT, Reduction,
-                     DAG.getVectorIdxConstant(0, DL));
+  return DAG.getExtractVectorElt(DL, ResVT, Reduction, 0);
 }
 
 SDValue RISCVTargetLowering::lowerVECREDUCE(SDValue Op,
@@ -10902,8 +10895,7 @@ SDValue RISCVTargetLowering::lowerVECREDUCE(SDValue Op,
   case ISD::UMIN:
   case ISD::SMAX:
   case ISD::SMIN:
-    StartV = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VecEltVT, Vec,
-                         DAG.getVectorIdxConstant(0, DL));
+    StartV = DAG.getExtractVectorElt(DL, VecEltVT, Vec, 0);
   }
   return lowerReductionSeq(RVVOpcode, Op.getSimpleValueType(), StartV, Vec,
                            Mask, VL, DL, DAG, Subtarget);
@@ -10934,9 +10926,7 @@ getRVVFPReductionOpAndOperands(SDValue Op, SelectionDAG &DAG, EVT EltVT,
   case ISD::VECREDUCE_FMAXIMUM:
   case ISD::VECREDUCE_FMIN:
   case ISD::VECREDUCE_FMAX: {
-    SDValue Front =
-        DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Op.getOperand(0),
-                    DAG.getVectorIdxConstant(0, DL));
+    SDValue Front = DAG.getExtractVectorElt(DL, EltVT, Op.getOperand(0), 0);
     unsigned RVVOpc =
         (Opcode == ISD::VECREDUCE_FMIN || Opcode == ISD::VECREDUCE_FMINIMUM)
             ? RISCVISD::VECREDUCE_FMIN_VL
@@ -14055,8 +14045,7 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
       EVT BVT = EVT::getVectorVT(*DAG.getContext(), VT, 1);
       if (isTypeLegal(BVT)) {
         SDValue BVec = DAG.getBitcast(BVT, Op0);
-        Results.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, BVec,
-                                      DAG.getVectorIdxConstant(0, DL)));
+        Results.push_back(DAG.getExtractVectorElt(DL, VT, BVec, 0));
       }
     }
     break;
@@ -18204,12 +18193,11 @@ static SDValue performINSERT_VECTOR_ELTCombine(SDNode *N, SelectionDAG &DAG,
   if (ConcatVT.getVectorElementType() != InVal.getValueType())
     return SDValue();
   unsigned ConcatNumElts = ConcatVT.getVectorNumElements();
-  SDValue NewIdx = DAG.getVectorIdxConstant(Elt % ConcatNumElts, DL);
+  unsigned NewIdx = Elt % ConcatNumElts;
 
   unsigned ConcatOpIdx = Elt / ConcatNumElts;
   SDValue ConcatOp = InVec.getOperand(ConcatOpIdx);
-  ConcatOp = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, ConcatVT,
-                         ConcatOp, InVal, NewIdx);
+  ConcatOp = DAG.getInsertVectorElt(DL, ConcatOp, InVal, NewIdx);
 
   SmallVector<SDValue> ConcatOps(InVec->ops());
   ConcatOps[ConcatOpIdx] = ConcatOp;

``````````

</details>


https://github.com/llvm/llvm-project/pull/139141


More information about the llvm-commits mailing list