[llvm] [DAG] visitEXTRACT_SUBVECTOR - change fold helper methods to take operands instead of EXTRACT_SUBVECTOR node. NFC. (PR #138279)

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Fri May 2 07:33:26 PDT 2025


https://github.com/RKSimon updated https://github.com/llvm/llvm-project/pull/138279

>From a6c602df7a5a9a2eb1f13c3ec69c4c9d15e44069 Mon Sep 17 00:00:00 2001
From: Simon Pilgrim <llvm-dev at redking.me.uk>
Date: Fri, 2 May 2025 14:44:00 +0100
Subject: [PATCH 1/2] [DAG] visitEXTRACT_SUBVECTOR - change fold helper methods
 to take raw operands instead of EXTRACT_SUBVECTOR node. NFC.

Call the fold helpers with the subvector type, source vector and constant index (as an unsigned value) instead of the original EXTRACT_SUBVECTOR node.

Some of these folds still assumed that EXTRACT_SUBVECTOR/INSERT_SUBVECTOR nodes could have variable indices, despite us moving to all constant indices some time ago - all of that code has now been simplified.

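I've moved the narrowExtractedVectorBinOp call higher up in visitEXTRACT_SUBVECTOR, but this won't affect fold order - it didn't rely on the caller's peekThroughBitcasts call (it peeks through bitcasts itself), and it only matches BinOps, not the BUILD_VECTOR/INSERT_SUBVECTOR nodes handled by the folds it now precedes.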

This is prep work to make it easier for more of these folds to look through BITCAST nodes.
---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 101 +++++++-----------
 1 file changed, 41 insertions(+), 60 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index ea1435c3934be..a9bf94e02b55b 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -25100,26 +25100,26 @@ SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) {
 
 // Helper that peeks through INSERT_SUBVECTOR/CONCAT_VECTORS to find
 // if the subvector can be sourced for free.
-static SDValue getSubVectorSrc(SDValue V, SDValue Index, EVT SubVT) {
+static SDValue getSubVectorSrc(SDValue V, unsigned Index, EVT SubVT) {
   if (V.getOpcode() == ISD::INSERT_SUBVECTOR &&
-      V.getOperand(1).getValueType() == SubVT && V.getOperand(2) == Index) {
+      V.getOperand(1).getValueType() == SubVT &&
+      V.getConstantOperandAPInt(2) == Index) {
     return V.getOperand(1);
   }
-  auto *IndexC = dyn_cast<ConstantSDNode>(Index);
-  if (IndexC && V.getOpcode() == ISD::CONCAT_VECTORS &&
+  if (V.getOpcode() == ISD::CONCAT_VECTORS &&
       V.getOperand(0).getValueType() == SubVT &&
-      (IndexC->getZExtValue() % SubVT.getVectorMinNumElements()) == 0) {
-    uint64_t SubIdx = IndexC->getZExtValue() / SubVT.getVectorMinNumElements();
+      (Index % SubVT.getVectorMinNumElements()) == 0) {
+    uint64_t SubIdx = Index / SubVT.getVectorMinNumElements();
     return V.getOperand(SubIdx);
   }
   return SDValue();
 }
 
-static SDValue narrowInsertExtractVectorBinOp(SDNode *Extract,
+static SDValue narrowInsertExtractVectorBinOp(EVT SubVT, SDValue BinOp,
+                                              unsigned Index, const SDLoc &DL,
                                               SelectionDAG &DAG,
                                               bool LegalOperations) {
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
-  SDValue BinOp = Extract->getOperand(0);
   unsigned BinOpcode = BinOp.getOpcode();
   if (!TLI.isBinOp(BinOpcode) || BinOp->getNumValues() != 1)
     return SDValue();
@@ -25128,9 +25128,6 @@ static SDValue narrowInsertExtractVectorBinOp(SDNode *Extract,
   SDValue Bop0 = BinOp.getOperand(0), Bop1 = BinOp.getOperand(1);
   if (VecVT != Bop0.getValueType() || VecVT != Bop1.getValueType())
     return SDValue();
-
-  SDValue Index = Extract->getOperand(1);
-  EVT SubVT = Extract->getValueType(0);
   if (!TLI.isOperationLegalOrCustom(BinOpcode, SubVT, LegalOperations))
     return SDValue();
 
@@ -25146,29 +25143,26 @@ static SDValue narrowInsertExtractVectorBinOp(SDNode *Extract,
   // We are inserting both operands of the wide binop only to extract back
   // to the narrow vector size. Eliminate all of the insert/extract:
   // ext (binop (ins ?, X, Index), (ins ?, Y, Index)), Index --> binop X, Y
-  return DAG.getNode(BinOpcode, SDLoc(Extract), SubVT, Sub0, Sub1,
-                     BinOp->getFlags());
+  return DAG.getNode(BinOpcode, DL, SubVT, Sub0, Sub1, BinOp->getFlags());
 }
 
 /// If we are extracting a subvector produced by a wide binary operator try
 /// to use a narrow binary operator and/or avoid concatenation and extraction.
-static SDValue narrowExtractedVectorBinOp(SDNode *Extract, SelectionDAG &DAG,
+static SDValue narrowExtractedVectorBinOp(EVT VT, SDValue Src, unsigned Index,
+                                          const SDLoc &DL, SelectionDAG &DAG,
                                           bool LegalOperations) {
   // TODO: Refactor with the caller (visitEXTRACT_SUBVECTOR), so we can share
   // some of these bailouts with other transforms.
 
-  if (SDValue V = narrowInsertExtractVectorBinOp(Extract, DAG, LegalOperations))
+  if (SDValue V = narrowInsertExtractVectorBinOp(VT, Src, Index, DL, DAG,
+                                                 LegalOperations))
     return V;
 
-  // The extract index must be a constant, so we can map it to a concat operand.
-  auto *ExtractIndexC = dyn_cast<ConstantSDNode>(Extract->getOperand(1));
-  if (!ExtractIndexC)
-    return SDValue();
 
   // We are looking for an optionally bitcasted wide vector binary operator
   // feeding an extract subvector.
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
-  SDValue BinOp = peekThroughBitcasts(Extract->getOperand(0));
+  SDValue BinOp = peekThroughBitcasts(Src);
   unsigned BOpcode = BinOp.getOpcode();
   if (!TLI.isBinOp(BOpcode) || BinOp->getNumValues() != 1)
     return SDValue();
@@ -25190,9 +25184,7 @@ static SDValue narrowExtractedVectorBinOp(SDNode *Extract, SelectionDAG &DAG,
   if (!WideBVT.isFixedLengthVector())
     return SDValue();
 
-  EVT VT = Extract->getValueType(0);
-  unsigned ExtractIndex = ExtractIndexC->getZExtValue();
-  assert(ExtractIndex % VT.getVectorNumElements() == 0 &&
+  assert((Index % VT.getVectorNumElements()) == 0 &&
          "Extract index is not a multiple of the vector length.");
 
   // Bail out if this is not a proper multiple width extraction.
@@ -25219,12 +25211,11 @@ static SDValue narrowExtractedVectorBinOp(SDNode *Extract, SelectionDAG &DAG,
   // for concat ops. The narrow binop alone makes this transform profitable.
   // We can't just reuse the original extract index operand because we may have
   // bitcasted.
-  unsigned ConcatOpNum = ExtractIndex / VT.getVectorNumElements();
+  unsigned ConcatOpNum = Index / VT.getVectorNumElements();
   unsigned ExtBOIdx = ConcatOpNum * NarrowBVT.getVectorNumElements();
   if (TLI.isExtractSubvectorCheap(NarrowBVT, WideBVT, ExtBOIdx) &&
-      BinOp.hasOneUse() && Extract->getOperand(0)->hasOneUse()) {
+      BinOp.hasOneUse() && Src->hasOneUse()) {
     // extract (binop B0, B1), N --> binop (extract B0, N), (extract B1, N)
-    SDLoc DL(Extract);
     SDValue NewExtIndex = DAG.getVectorIdxConstant(ExtBOIdx, DL);
     SDValue X = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
                             BinOp.getOperand(0), NewExtIndex);
@@ -25264,7 +25255,6 @@ static SDValue narrowExtractedVectorBinOp(SDNode *Extract, SelectionDAG &DAG,
     // extract (binop (concat X1, X2), (concat Y1, Y2)), N --> binop XN, YN
     // extract (binop (concat X1, X2), Y), N --> binop XN, (extract Y, IndexC)
     // extract (binop X, (concat Y1, Y2)), N --> binop (extract X, IndexC), YN
-    SDLoc DL(Extract);
     SDValue IndexC = DAG.getVectorIdxConstant(ExtBOIdx, DL);
     SDValue X = SubVecL ? DAG.getBitcast(NarrowBVT, SubVecL)
                         : DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NarrowBVT,
@@ -25284,24 +25274,24 @@ static SDValue narrowExtractedVectorBinOp(SDNode *Extract, SelectionDAG &DAG,
 /// If we are extracting a subvector from a wide vector load, convert to a
 /// narrow load to eliminate the extraction:
 /// (extract_subvector (load wide vector)) --> (load narrow vector)
-static SDValue narrowExtractedVectorLoad(SDNode *Extract, const SDLoc &DL,
-                                         SelectionDAG &DAG) {
+static SDValue narrowExtractedVectorLoad(EVT VT, SDValue Src, unsigned Index,
+                                         const SDLoc &DL, SelectionDAG &DAG) {
   // TODO: Add support for big-endian. The offset calculation must be adjusted.
   if (DAG.getDataLayout().isBigEndian())
     return SDValue();
 
-  auto *Ld = dyn_cast<LoadSDNode>(Extract->getOperand(0));
+  auto *Ld = dyn_cast<LoadSDNode>(Src);
   if (!Ld || !ISD::isNormalLoad(Ld) || !Ld->isSimple())
     return SDValue();
 
-  // Allow targets to opt-out.
-  EVT VT = Extract->getValueType(0);
-
   // We can only create byte sized loads.
   if (!VT.isByteSized())
     return SDValue();
 
-  unsigned Index = Extract->getConstantOperandVal(1);
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  if (!TLI.isOperationLegalOrCustomOrPromote(ISD::LOAD, VT))
+    return SDValue();
+
   unsigned NumElts = VT.getVectorMinNumElements();
   // A fixed length vector being extracted from a scalable vector
   // may not be any *smaller* than the scalable one.
@@ -25319,7 +25309,6 @@ static SDValue narrowExtractedVectorLoad(SDNode *Extract, const SDLoc &DL,
   if (Offset.isFixed())
     ByteOffset = Offset.getFixedValue();
 
-  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
   if (!TLI.shouldReduceLoadWidth(Ld, Ld->getExtensionType(), VT, ByteOffset))
     return SDValue();
 
@@ -25350,23 +25339,18 @@ static SDValue narrowExtractedVectorLoad(SDNode *Extract, const SDLoc &DL,
 /// iff it is legal and profitable to do so. Notably, the trimmed mask
 /// (containing only the elements that are extracted)
 /// must reference at most two subvectors.
-static SDValue foldExtractSubvectorFromShuffleVector(SDNode *N,
+static SDValue foldExtractSubvectorFromShuffleVector(EVT NarrowVT, SDValue Src,
+                                                     unsigned Index,
+                                                     const SDLoc &DL,
                                                      SelectionDAG &DAG,
-                                                     const TargetLowering &TLI,
                                                      bool LegalOperations) {
-  assert(N->getOpcode() == ISD::EXTRACT_SUBVECTOR &&
-         "Must only be called on EXTRACT_SUBVECTOR's");
-
-  SDValue N0 = N->getOperand(0);
-
   // Only deal with non-scalable vectors.
-  EVT NarrowVT = N->getValueType(0);
-  EVT WideVT = N0.getValueType();
+  EVT WideVT = Src.getValueType();
   if (!NarrowVT.isFixedLengthVector() || !WideVT.isFixedLengthVector())
     return SDValue();
 
   // The operand must be a shufflevector.
-  auto *WideShuffleVector = dyn_cast<ShuffleVectorSDNode>(N0);
+  auto *WideShuffleVector = dyn_cast<ShuffleVectorSDNode>(Src);
   if (!WideShuffleVector)
     return SDValue();
 
@@ -25375,13 +25359,13 @@ static SDValue foldExtractSubvectorFromShuffleVector(SDNode *N,
     return SDValue();
 
   // And the narrow shufflevector that we'll form must be legal.
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
   if (LegalOperations &&
       !TLI.isOperationLegalOrCustom(ISD::VECTOR_SHUFFLE, NarrowVT))
     return SDValue();
 
-  uint64_t FirstExtractedEltIdx = N->getConstantOperandVal(1);
   int NumEltsExtracted = NarrowVT.getVectorNumElements();
-  assert((FirstExtractedEltIdx % NumEltsExtracted) == 0 &&
+  assert((Index % NumEltsExtracted) == 0 &&
          "Extract index is not a multiple of the output vector length.");
 
   int WideNumElts = WideVT.getVectorNumElements();
@@ -25392,8 +25376,7 @@ static SDValue foldExtractSubvectorFromShuffleVector(SDNode *N,
       DemandedSubvectors;
 
   // Try to decode the wide mask into narrow mask from at most two subvectors.
-  for (int M : WideShuffleVector->getMask().slice(FirstExtractedEltIdx,
-                                                  NumEltsExtracted)) {
+  for (int M : WideShuffleVector->getMask().slice(Index, NumEltsExtracted)) {
     assert((M >= -1) && (M < (2 * WideNumElts)) &&
            "Out-of-bounds shuffle mask?");
 
@@ -25476,8 +25459,6 @@ static SDValue foldExtractSubvectorFromShuffleVector(SDNode *N,
       !TLI.isShuffleMaskLegal(NewMask, NarrowVT))
     return SDValue();
 
-  SDLoc DL(N);
-
   SmallVector<SDValue, 2> NewOps;
   for (const std::pair<SDValue /*Op*/, int /*SubvectorIndex*/>
            &DemandedSubvector : DemandedSubvectors) {
@@ -25507,9 +25488,8 @@ SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode *N) {
   if (V.isUndef())
     return DAG.getUNDEF(NVT);
 
-  if (TLI.isOperationLegalOrCustomOrPromote(ISD::LOAD, NVT))
-    if (SDValue NarrowLoad = narrowExtractedVectorLoad(N, DL, DAG))
-      return NarrowLoad;
+  if (SDValue NarrowLoad = narrowExtractedVectorLoad(NVT, V, ExtIdx, DL, DAG))
+    return NarrowLoad;
 
   // Combine an extract of an extract into a single extract_subvector.
   // ext (ext X, C), 0 --> ext X, C
@@ -25631,9 +25611,13 @@ SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode *N) {
     }
   }
 
-  if (SDValue V =
-          foldExtractSubvectorFromShuffleVector(N, DAG, TLI, LegalOperations))
-    return V;
+  if (SDValue Shuffle = foldExtractSubvectorFromShuffleVector(
+          NVT, V, ExtIdx, DL, DAG, LegalOperations))
+    return Shuffle;
+
+  if (SDValue NarrowBOp =
+          narrowExtractedVectorBinOp(NVT, V, ExtIdx, DL, DAG, LegalOperations))
+    return NarrowBOp;
 
   V = peekThroughBitcasts(V);
 
@@ -25694,9 +25678,6 @@ SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode *N) {
     }
   }
 
-  if (SDValue NarrowBOp = narrowExtractedVectorBinOp(N, DAG, LegalOperations))
-    return NarrowBOp;
-
   // If only EXTRACT_SUBVECTOR nodes use the source vector we can
   // simplify it based on the (valid) extractions.
   if (!V.getValueType().isScalableVector() &&

>From b169bd55a2dd6be6007f50e4677b7ed8cc3b4f7b Mon Sep 17 00:00:00 2001
From: Simon Pilgrim <llvm-dev at redking.me.uk>
Date: Fri, 2 May 2025 15:33:11 +0100
Subject: [PATCH 2/2] Fix clang-format warning

---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index a9bf94e02b55b..5c686f2f3907f 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -25158,7 +25158,6 @@ static SDValue narrowExtractedVectorBinOp(EVT VT, SDValue Src, unsigned Index,
                                                  LegalOperations))
     return V;
 
-
   // We are looking for an optionally bitcasted wide vector binary operator
   // feeding an extract subvector.
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();



More information about the llvm-commits mailing list