[llvm-branch-commits] [llvm] [SelectionDAG] Fold extracts spanning concat operands (PR #200936)

Simon Pilgrim via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Tue Jun 2 02:18:54 PDT 2026


================
@@ -27545,6 +27545,69 @@ static SDValue foldExtractSubvectorFromShuffleVector(EVT NarrowVT, SDValue Src,
   return DAG.getVectorShuffle(NarrowVT, DL, NewOps[0], NewOps[1], NewMask);
 }
 
+static SDValue foldExtractSubvectorFromConcatVectors(EVT NVT, SDValue V,
+                                                     uint64_t ExtIdx,
+                                                     const SDLoc &DL,
+                                                     SelectionDAG &DAG,
+                                                     bool LegalOperations) {
+  if (V.getOpcode() != ISD::CONCAT_VECTORS)
+    return SDValue();
+
+  unsigned ExtNumElts = NVT.getVectorMinNumElements();
+  EVT ConcatSrcVT = V.getOperand(0).getValueType();
+  assert(ConcatSrcVT.getVectorElementType() == NVT.getVectorElementType() &&
+         "Concat and extract subvector do not change element type");
+
+  unsigned ConcatSrcNumElts = ConcatSrcVT.getVectorMinNumElements();
+  unsigned ConcatOpIdx = ExtIdx / ConcatSrcNumElts;
+  if (ConcatOpIdx >= V.getNumOperands())
+    return SDValue();
+
+  // If the concatenated source types match this extract, it's a direct
+  // simplification:
+  // extract_subvec (concat V1, V2, ...), i --> Vi
+  if (NVT.getVectorElementCount() == ConcatSrcVT.getVectorElementCount())
+    return V.getOperand(ConcatOpIdx);
+
+  if (!NVT.isFixedLengthVector() || !ConcatSrcVT.isFixedLengthVector())
+    return SDValue();
+
+  // If the concatenated source vectors are a multiple length of this extract,
+  // then extract a fraction of one of those source vectors directly from a
+  // concat operand. Example:
+  //   v2i8 extract_subvec (v16i8 concat (v8i8 X), (v8i8 Y)), 14 -->
+  //   v2i8 extract_subvec v8i8 Y, 6
+  if (ConcatSrcNumElts % ExtNumElts == 0) {
+    uint64_t NewExtIdx = ExtIdx - ConcatOpIdx * ConcatSrcNumElts;
+    if (NewExtIdx + ExtNumElts > ConcatSrcNumElts)
+      return SDValue();
+    assert(NewExtIdx % ExtNumElts == 0 &&
+           "Extract index is not a multiple of the input vector length.");
+    return DAG.getExtractSubvector(DL, NVT, V.getOperand(ConcatOpIdx),
+                                   NewExtIdx);
+  }
+
+  // If the extract covers multiple whole concat operands, rebuild that smaller
+  // concat directly.
+  if (ExtNumElts % ConcatSrcNumElts == 0 && ExtIdx % ConcatSrcNumElts == 0) {
+    if (LegalOperations &&
+        !DAG.getTargetLoweringInfo().isOperationLegalOrCustom(
+            ISD::CONCAT_VECTORS, NVT))
+      return SDValue();
+
+    unsigned NumConcatOps = ExtNumElts / ConcatSrcNumElts;
+    if (ConcatOpIdx + NumConcatOps > V.getNumOperands())
+      return SDValue();
+
+    SmallVector<SDValue, 8> Ops;
----------------
RKSimon wrote:

return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Op->ops().slice(I, NumConcatOps));

https://github.com/llvm/llvm-project/pull/200936


More information about the llvm-branch-commits mailing list