[llvm] DAG: Move scalarizeExtractedVectorLoad to TargetLowering (PR #122670)

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Sun Jan 12 23:09:44 PST 2025


https://github.com/arsenm created https://github.com/llvm/llvm-project/pull/122670

SimplifyDemandedVectorElts should be able to use this on loads

>From 9bedb1442c1b1e6288544f9887668146b4ab0327 Mon Sep 17 00:00:00 2001
From: Matt Arsenault <Matthew.Arsenault at amd.com>
Date: Mon, 13 Jan 2025 11:09:20 +0700
Subject: [PATCH] DAG: Move scalarizeExtractedVectorLoad to TargetLowering

SimplifyDemandedVectorElts should be able to use this on loads
---
 llvm/include/llvm/CodeGen/TargetLowering.h    | 12 +++
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp |  9 ++-
 .../CodeGen/SelectionDAG/TargetLowering.cpp   | 74 +++++++++++++++++++
 3 files changed, 93 insertions(+), 2 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index ce58777655e063..f898dfaa89915f 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -5618,6 +5618,18 @@ class TargetLowering : public TargetLoweringBase {
   // joining their results. SDValue() is returned when expansion did not happen.
   SDValue expandVectorNaryOpBySplitting(SDNode *Node, SelectionDAG &DAG) const;
 
+  /// Replace an extraction of a load with a narrowed load.
+  ///
+  /// \param ResultVT type of the result extraction.
+  /// \param InVecVT type of the input vector with bitcasts resolved.
+  /// \param EltNo index of the vector element to load.
+  /// \param OriginalLoad vector load that is to be replaced.
+  /// \returns \p ResultVT load on success, SDValue() on failure.
+  SDValue scalarizeExtractedVectorLoad(EVT ResultVT, const SDLoc &DL,
+                                       EVT InVecVT, SDValue EltNo,
+                                       LoadSDNode *OriginalLoad,
+                                       SelectionDAG &DAG) const;
+
 private:
   SDValue foldSetCCWithAnd(EVT VT, SDValue N0, SDValue N1, ISD::CondCode Cond,
                            const SDLoc &DL, DAGCombinerInfo &DCI) const;
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index da3c834417d6b2..6497531047476f 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -23246,8 +23246,13 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
       ISD::isNormalLoad(VecOp.getNode()) &&
       !Index->hasPredecessor(VecOp.getNode())) {
     auto *VecLoad = dyn_cast<LoadSDNode>(VecOp);
-    if (VecLoad && VecLoad->isSimple())
-      return scalarizeExtractedVectorLoad(N, VecVT, Index, VecLoad);
+    if (VecLoad && VecLoad->isSimple()) {
+      if (SDValue Scalarized = TLI.scalarizeExtractedVectorLoad(
+              ExtVT, SDLoc(N), VecVT, Index, VecLoad, DAG)) {
+        ++OpsNarrowed;
+        return Scalarized;
+      }
+    }
   }
 
   // Perform only after legalization to ensure build_vector / vector_shuffle
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 56194e2614af2d..b1fb4947fb9451 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -12069,3 +12069,77 @@ SDValue TargetLowering::expandVectorNaryOpBySplitting(SDNode *Node,
   SDValue SplitOpHi = DAG.getNode(Opcode, DL, HiVT, HiOps);
   return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, SplitOpLo, SplitOpHi);
 }
+
+SDValue TargetLowering::scalarizeExtractedVectorLoad(EVT ResultVT,
+                                                     const SDLoc &DL,
+                                                     EVT InVecVT, SDValue EltNo,
+                                                     LoadSDNode *OriginalLoad,
+                                                     SelectionDAG &DAG) const {
+  assert(OriginalLoad->isSimple());
+
+  EVT VecEltVT = InVecVT.getVectorElementType();
+
+  // If the vector element type is not a multiple of a byte then we are unable
+  // to correctly compute an address to load only the extracted element as a
+  // scalar.
+  if (!VecEltVT.isByteSized())
+    return SDValue();
+
+  ISD::LoadExtType ExtTy =
+      ResultVT.bitsGT(VecEltVT) ? ISD::EXTLOAD : ISD::NON_EXTLOAD;
+  if (!isOperationLegalOrCustom(ISD::LOAD, VecEltVT) ||
+      !shouldReduceLoadWidth(OriginalLoad, ExtTy, VecEltVT))
+    return SDValue();
+
+  Align Alignment = OriginalLoad->getAlign();
+  MachinePointerInfo MPI;
+  if (auto *ConstEltNo = dyn_cast<ConstantSDNode>(EltNo)) {
+    int Elt = ConstEltNo->getZExtValue();
+    unsigned PtrOff = VecEltVT.getSizeInBits() * Elt / 8;
+    MPI = OriginalLoad->getPointerInfo().getWithOffset(PtrOff);
+    Alignment = commonAlignment(Alignment, PtrOff);
+  } else {
+    // Discard the pointer info except the address space because the memory
+    // operand can't represent this new access since the offset is variable.
+    MPI = MachinePointerInfo(OriginalLoad->getPointerInfo().getAddrSpace());
+    Alignment = commonAlignment(Alignment, VecEltVT.getSizeInBits() / 8);
+  }
+
+  unsigned IsFast = 0;
+  if (!allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VecEltVT,
+                          OriginalLoad->getAddressSpace(), Alignment,
+                          OriginalLoad->getMemOperand()->getFlags(), &IsFast) ||
+      !IsFast)
+    return SDValue();
+
+  SDValue NewPtr =
+      getVectorElementPointer(DAG, OriginalLoad->getBasePtr(), InVecVT, EltNo);
+
+  // We are replacing a vector load with a scalar load. The new load must have
+  // identical memory op ordering to the original.
+  SDValue Load;
+  if (ResultVT.bitsGT(VecEltVT)) {
+    // If the result type of vextract is wider than the load, then issue an
+    // extending load instead.
+    ISD::LoadExtType ExtType = isLoadExtLegal(ISD::ZEXTLOAD, ResultVT, VecEltVT)
+                                   ? ISD::ZEXTLOAD
+                                   : ISD::EXTLOAD;
+    Load = DAG.getExtLoad(ExtType, DL, ResultVT, OriginalLoad->getChain(),
+                          NewPtr, MPI, VecEltVT, Alignment,
+                          OriginalLoad->getMemOperand()->getFlags(),
+                          OriginalLoad->getAAInfo());
+    DAG.makeEquivalentMemoryOrdering(OriginalLoad, Load);
+  } else {
+    // The result type is narrower or the same width as the vector element.
+    Load = DAG.getLoad(VecEltVT, DL, OriginalLoad->getChain(), NewPtr, MPI,
+                       Alignment, OriginalLoad->getMemOperand()->getFlags(),
+                       OriginalLoad->getAAInfo());
+    DAG.makeEquivalentMemoryOrdering(OriginalLoad, Load);
+    if (ResultVT.bitsLT(VecEltVT))
+      Load = DAG.getNode(ISD::TRUNCATE, DL, ResultVT, Load);
+    else
+      Load = DAG.getBitcast(ResultVT, Load);
+  }
+
+  return Load;
+}



More information about the llvm-commits mailing list