[llvm] DAG: Move scalarizeExtractedVectorLoad to TargetLowering (PR #122670)

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 29 08:33:38 PST 2025


https://github.com/arsenm updated https://github.com/llvm/llvm-project/pull/122670

>From d763962e444cfbb091bd9fbdd193381559fddedd Mon Sep 17 00:00:00 2001
From: Matt Arsenault <Matthew.Arsenault at amd.com>
Date: Mon, 13 Jan 2025 11:09:20 +0700
Subject: [PATCH 1/2] DAG: Move scalarizeExtractedVectorLoad to TargetLowering

SimplifyDemandedVectorElts should be able to use this on loads
---
 llvm/include/llvm/CodeGen/TargetLowering.h    | 12 +++
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp |  9 ++-
 .../CodeGen/SelectionDAG/TargetLowering.cpp   | 74 +++++++++++++++++++
 3 files changed, 93 insertions(+), 2 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 9fcd2ac9514e56..04ee24c0916e5f 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -5622,6 +5622,18 @@ class TargetLowering : public TargetLoweringBase {
   // joining their results. SDValue() is returned when expansion did not happen.
   SDValue expandVectorNaryOpBySplitting(SDNode *Node, SelectionDAG &DAG) const;
 
+  /// Replace an element extraction from a vector load with a narrowed load.
+  ///
+  /// \param ResultVT type of the extracted result.
+  /// \param InVecVT type of the input vector with bitcasts resolved.
+  /// \param EltNo index of the vector element to load.
+  /// \param OriginalLoad vector load that is to be replaced.
+  /// \returns a \p ResultVT load on success, or SDValue() on failure.
+  SDValue scalarizeExtractedVectorLoad(EVT ResultVT, const SDLoc &DL,
+                                       EVT InVecVT, SDValue EltNo,
+                                       LoadSDNode *OriginalLoad,
+                                       SelectionDAG &DAG) const;
+
 private:
   SDValue foldSetCCWithAnd(EVT VT, SDValue N0, SDValue N1, ISD::CondCode Cond,
                            const SDLoc &DL, DAGCombinerInfo &DCI) const;
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index a0c703d2df8a2f..50675fcd98d3d5 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -23268,8 +23268,13 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
       ISD::isNormalLoad(VecOp.getNode()) &&
       !Index->hasPredecessor(VecOp.getNode())) {
     auto *VecLoad = dyn_cast<LoadSDNode>(VecOp);
-    if (VecLoad && VecLoad->isSimple())
-      return scalarizeExtractedVectorLoad(N, VecVT, Index, VecLoad);
+    if (VecLoad && VecLoad->isSimple()) {
+      if (SDValue Scalarized = TLI.scalarizeExtractedVectorLoad(
+              ExtVT, SDLoc(N), VecVT, Index, VecLoad, DAG)) {
+        ++OpsNarrowed;
+        return Scalarized;
+      }
+    }
   }
 
   // Perform only after legalization to ensure build_vector / vector_shuffle
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 49ec47f4e8a702..5140b40166b99e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -12114,3 +12114,77 @@ SDValue TargetLowering::expandVectorNaryOpBySplitting(SDNode *Node,
   SDValue SplitOpHi = DAG.getNode(Opcode, DL, HiVT, HiOps);
   return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, SplitOpLo, SplitOpHi);
 }
+
+SDValue TargetLowering::scalarizeExtractedVectorLoad(EVT ResultVT,
+                                                     const SDLoc &DL,
+                                                     EVT InVecVT, SDValue EltNo,
+                                                     LoadSDNode *OriginalLoad,
+                                                     SelectionDAG &DAG) const {
+  assert(OriginalLoad->isSimple());
+
+  EVT VecEltVT = InVecVT.getVectorElementType();
+
+  // If the vector element size is not a whole number of bytes then we are
+  // unable to correctly compute an address to load only the extracted element
+  // as a scalar.
+  if (!VecEltVT.isByteSized())
+    return SDValue();
+
+  ISD::LoadExtType ExtTy =
+      ResultVT.bitsGT(VecEltVT) ? ISD::EXTLOAD : ISD::NON_EXTLOAD;
+  if (!isOperationLegalOrCustom(ISD::LOAD, VecEltVT) ||
+      !shouldReduceLoadWidth(OriginalLoad, ExtTy, VecEltVT))
+    return SDValue();
+
+  Align Alignment = OriginalLoad->getAlign();
+  MachinePointerInfo MPI;
+  if (auto *ConstEltNo = dyn_cast<ConstantSDNode>(EltNo)) {
+    int Elt = ConstEltNo->getZExtValue();
+    unsigned PtrOff = VecEltVT.getSizeInBits() * Elt / 8;
+    MPI = OriginalLoad->getPointerInfo().getWithOffset(PtrOff);
+    Alignment = commonAlignment(Alignment, PtrOff);
+  } else {
+    // Discard the pointer info except the address space because the memory
+    // operand can't represent this new access since the offset is variable.
+    MPI = MachinePointerInfo(OriginalLoad->getPointerInfo().getAddrSpace());
+    Alignment = commonAlignment(Alignment, VecEltVT.getSizeInBits() / 8);
+  }
+
+  unsigned IsFast = 0;
+  if (!allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VecEltVT,
+                          OriginalLoad->getAddressSpace(), Alignment,
+                          OriginalLoad->getMemOperand()->getFlags(), &IsFast) ||
+      !IsFast)
+    return SDValue();
+
+  SDValue NewPtr =
+      getVectorElementPointer(DAG, OriginalLoad->getBasePtr(), InVecVT, EltNo);
+
+  // We are replacing a vector load with a scalar load. The new load must have
+  // identical memory op ordering to the original.
+  SDValue Load;
+  if (ResultVT.bitsGT(VecEltVT)) {
+    // If the result type of the extract is wider than the vector element,
+    // then issue an extending load instead.
+    ISD::LoadExtType ExtType = isLoadExtLegal(ISD::ZEXTLOAD, ResultVT, VecEltVT)
+                                   ? ISD::ZEXTLOAD
+                                   : ISD::EXTLOAD;
+    Load = DAG.getExtLoad(ExtType, DL, ResultVT, OriginalLoad->getChain(),
+                          NewPtr, MPI, VecEltVT, Alignment,
+                          OriginalLoad->getMemOperand()->getFlags(),
+                          OriginalLoad->getAAInfo());
+    DAG.makeEquivalentMemoryOrdering(OriginalLoad, Load);
+  } else {
+    // The result type is narrower than or the same width as the vector element.
+    Load = DAG.getLoad(VecEltVT, DL, OriginalLoad->getChain(), NewPtr, MPI,
+                       Alignment, OriginalLoad->getMemOperand()->getFlags(),
+                       OriginalLoad->getAAInfo());
+    DAG.makeEquivalentMemoryOrdering(OriginalLoad, Load);
+    if (ResultVT.bitsLT(VecEltVT))
+      Load = DAG.getNode(ISD::TRUNCATE, DL, ResultVT, Load);
+    else
+      Load = DAG.getBitcast(ResultVT, Load);
+  }
+
+  return Load;
+}

>From 65e9c1b6e762019e6ea2a19b919ad798ebf19b80 Mon Sep 17 00:00:00 2001
From: Matt Arsenault <Matthew.Arsenault at amd.com>
Date: Wed, 29 Jan 2025 22:59:41 +0700
Subject: [PATCH 2/2] Remove DAGCombiner copy

---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 88 +------------------
 1 file changed, 1 insertion(+), 87 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 50675fcd98d3d5..740b19baa7dd9b 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -385,17 +385,6 @@ namespace {
     bool getTruncatedStoreValue(StoreSDNode *ST, SDValue &Val);
     bool extendLoadedValueToExtension(LoadSDNode *LD, SDValue &Val);
 
-    /// Replace an ISD::EXTRACT_VECTOR_ELT of a load with a narrowed
-    ///   load.
-    ///
-    /// \param EVE ISD::EXTRACT_VECTOR_ELT to be replaced.
-    /// \param InVecVT type of the input vector to EVE with bitcasts resolved.
-    /// \param EltNo index of the vector element to load.
-    /// \param OriginalLoad load that EVE came from to be replaced.
-    /// \returns EVE on success SDValue() on failure.
-    SDValue scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT,
-                                         SDValue EltNo,
-                                         LoadSDNode *OriginalLoad);
     void ReplaceLoadWithPromotedLoad(SDNode *Load, SDNode *ExtLoad);
     SDValue PromoteOperand(SDValue Op, EVT PVT, bool &Replace);
     SDValue SExtPromoteOperand(SDValue Op, EVT PVT);
@@ -22720,81 +22709,6 @@ SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) {
   return SDValue();
 }
 
-SDValue DAGCombiner::scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT,
-                                                  SDValue EltNo,
-                                                  LoadSDNode *OriginalLoad) {
-  assert(OriginalLoad->isSimple());
-
-  EVT ResultVT = EVE->getValueType(0);
-  EVT VecEltVT = InVecVT.getVectorElementType();
-
-  // If the vector element type is not a multiple of a byte then we are unable
-  // to correctly compute an address to load only the extracted element as a
-  // scalar.
-  if (!VecEltVT.isByteSized())
-    return SDValue();
-
-  ISD::LoadExtType ExtTy =
-      ResultVT.bitsGT(VecEltVT) ? ISD::EXTLOAD : ISD::NON_EXTLOAD;
-  if (!TLI.isOperationLegalOrCustom(ISD::LOAD, VecEltVT) ||
-      !TLI.shouldReduceLoadWidth(OriginalLoad, ExtTy, VecEltVT))
-    return SDValue();
-
-  Align Alignment = OriginalLoad->getAlign();
-  MachinePointerInfo MPI;
-  SDLoc DL(EVE);
-  if (auto *ConstEltNo = dyn_cast<ConstantSDNode>(EltNo)) {
-    int Elt = ConstEltNo->getZExtValue();
-    unsigned PtrOff = VecEltVT.getSizeInBits() * Elt / 8;
-    MPI = OriginalLoad->getPointerInfo().getWithOffset(PtrOff);
-    Alignment = commonAlignment(Alignment, PtrOff);
-  } else {
-    // Discard the pointer info except the address space because the memory
-    // operand can't represent this new access since the offset is variable.
-    MPI = MachinePointerInfo(OriginalLoad->getPointerInfo().getAddrSpace());
-    Alignment = commonAlignment(Alignment, VecEltVT.getSizeInBits() / 8);
-  }
-
-  unsigned IsFast = 0;
-  if (!TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VecEltVT,
-                              OriginalLoad->getAddressSpace(), Alignment,
-                              OriginalLoad->getMemOperand()->getFlags(),
-                              &IsFast) ||
-      !IsFast)
-    return SDValue();
-
-  SDValue NewPtr = TLI.getVectorElementPointer(DAG, OriginalLoad->getBasePtr(),
-                                               InVecVT, EltNo);
-
-  // We are replacing a vector load with a scalar load. The new load must have
-  // identical memory op ordering to the original.
-  SDValue Load;
-  if (ResultVT.bitsGT(VecEltVT)) {
-    // If the result type of vextract is wider than the load, then issue an
-    // extending load instead.
-    ISD::LoadExtType ExtType =
-        TLI.isLoadExtLegal(ISD::ZEXTLOAD, ResultVT, VecEltVT) ? ISD::ZEXTLOAD
-                                                              : ISD::EXTLOAD;
-    Load = DAG.getExtLoad(ExtType, DL, ResultVT, OriginalLoad->getChain(),
-                          NewPtr, MPI, VecEltVT, Alignment,
-                          OriginalLoad->getMemOperand()->getFlags(),
-                          OriginalLoad->getAAInfo());
-    DAG.makeEquivalentMemoryOrdering(OriginalLoad, Load);
-  } else {
-    // The result type is narrower or the same width as the vector element
-    Load = DAG.getLoad(VecEltVT, DL, OriginalLoad->getChain(), NewPtr, MPI,
-                       Alignment, OriginalLoad->getMemOperand()->getFlags(),
-                       OriginalLoad->getAAInfo());
-    DAG.makeEquivalentMemoryOrdering(OriginalLoad, Load);
-    if (ResultVT.bitsLT(VecEltVT))
-      Load = DAG.getNode(ISD::TRUNCATE, DL, ResultVT, Load);
-    else
-      Load = DAG.getBitcast(ResultVT, Load);
-  }
-  ++OpsNarrowed;
-  return Load;
-}
-
 /// Transform a vector binary operation into a scalar binary operation by moving
 /// the math/logic after an extract element of a vector.
 static SDValue scalarizeExtractedBinOp(SDNode *ExtElt, SelectionDAG &DAG,
@@ -23362,7 +23276,7 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
   if (Elt == -1)
     return DAG.getUNDEF(LVT);
 
-  return scalarizeExtractedVectorLoad(N, VecVT, Index, LN0);
+  return TLI.scalarizeExtractedVectorLoad(LVT, DL, VecVT, Index, LN0, DAG);
 }
 
 // Simplify (build_vec (ext )) to (bitcast (build_vec ))



More information about the llvm-commits mailing list