[llvm-branch-commits] [llvm] [SelectionDAG][X86] Split <2 x T> vector types for atomic load (PR #120640)
    via llvm-branch-commits 
    llvm-branch-commits at lists.llvm.org
       
    Mon Jan  6 11:25:21 PST 2025
    
    
  
https://github.com/jofrn updated https://github.com/llvm/llvm-project/pull/120640
>From 3be4fa05ccaa6b9a2485445211723ef5a4b47964 Mon Sep 17 00:00:00 2001
From: jofrn <jofernau at amd.com>
Date: Thu, 19 Dec 2024 16:25:55 -0500
Subject: [PATCH] [SelectionDAG][X86] Split <2 x T> vector types for atomic
 load
Atomic loads of 2-element vector types that aren't widened are split
so that they can be vectorized within SelectionDAG. This change
uses the existing load vectorization infrastructure to regroup
the split elements, which enables SelectionDAG to translate
atomic loads of vectors with element type bfloat or half.
commit-id:3a045357
---
 llvm/include/llvm/CodeGen/SelectionDAG.h      |  4 +-
 llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h |  1 +
 .../SelectionDAG/LegalizeVectorTypes.cpp      | 35 +++++++++++++
 .../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 24 ++++++++-
 .../SelectionDAGAddressAnalysis.cpp           | 30 ++++++-----
 .../SelectionDAG/SelectionDAGBuilder.cpp      |  6 ++-
 llvm/lib/Target/X86/X86ISelLowering.cpp       | 51 ++++++++++++-------
 llvm/test/CodeGen/X86/atomic-load-store.ll    | 18 +++++++
 8 files changed, 133 insertions(+), 36 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index ff7caec41855fd..54f2cd3fb12105 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1835,7 +1835,7 @@ class SelectionDAG {
   /// chain to the token factor. This ensures that the new memory node will have
   /// the same relative memory dependency position as the old load. Returns the
   /// new merged load chain.
-  SDValue makeEquivalentMemoryOrdering(LoadSDNode *OldLoad, SDValue NewMemOp);
+  SDValue makeEquivalentMemoryOrdering(MemSDNode *OldLoad, SDValue NewMemOp);
 
   /// Topological-sort the AllNodes list and a
   /// assign a unique node id for each node in the DAG based on their
@@ -2261,6 +2261,8 @@ class SelectionDAG {
   /// location that the 'Base' load is loading from.
   bool areNonVolatileConsecutiveLoads(LoadSDNode *LD, LoadSDNode *Base,
                                       unsigned Bytes, int Dist) const;
+  bool areNonVolatileConsecutiveLoads(AtomicSDNode *LD, AtomicSDNode *Base,
+                                      unsigned Bytes, int Dist) const;
 
   /// Infer alignment of a load / store address. Return std::nullopt if it
   /// cannot be inferred.
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index 3b3dddc44e3682..e0cd7319ac034b 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -946,6 +946,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   void SplitVecRes_FPOp_MultiType(SDNode *N, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_IS_FPCLASS(SDNode *N, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_INSERT_VECTOR_ELT(SDNode *N, SDValue &Lo, SDValue &Hi);
+  void SplitVecRes_ATOMIC_LOAD(AtomicSDNode *LD, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_LOAD(LoadSDNode *LD, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_VP_LOAD(VPLoadSDNode *LD, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_VP_STRIDED_LOAD(VPStridedLoadSDNode *SLD, SDValue &Lo,
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index bc0a3a4589b941..5401e910fa78e6 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -1148,6 +1148,9 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
     SplitVecRes_STEP_VECTOR(N, Lo, Hi);
     break;
   case ISD::SIGN_EXTEND_INREG: SplitVecRes_InregOp(N, Lo, Hi); break;
+  case ISD::ATOMIC_LOAD:
+    SplitVecRes_ATOMIC_LOAD(cast<AtomicSDNode>(N), Lo, Hi);
+    break;
   case ISD::LOAD:
     SplitVecRes_LOAD(cast<LoadSDNode>(N), Lo, Hi);
     break;
@@ -1391,6 +1394,38 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
     SetSplitVector(SDValue(N, ResNo), Lo, Hi);
 }
 
+void DAGTypeLegalizer::SplitVecRes_ATOMIC_LOAD(AtomicSDNode *LD, SDValue &Lo,
+                                               SDValue &Hi) {
+  EVT LoVT, HiVT;
+  SDLoc dl(LD);
+  std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(LD->getValueType(0));
+
+  SDValue Ch = LD->getChain();
+  SDValue Ptr = LD->getBasePtr();
+  EVT MemoryVT = LD->getMemoryVT();
+
+  EVT LoMemVT, HiMemVT;
+  std::tie(LoMemVT, HiMemVT) = DAG.GetSplitDestVTs(MemoryVT);
+
+  Lo = DAG.getAtomic(ISD::ATOMIC_LOAD, dl, LoMemVT, LoMemVT, Ch, Ptr,
+                     LD->getMemOperand());
+
+  MachinePointerInfo MPI;
+  IncrementPointer(LD, LoMemVT, MPI, Ptr);
+
+  Hi = DAG.getAtomic(ISD::ATOMIC_LOAD, dl, HiMemVT, HiMemVT, Ch, Ptr,
+                     LD->getMemOperand());
+
+  // Build a factor node to remember that this load is independent of the
+  // other one.
+  Ch = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, Lo.getValue(1),
+                   Hi.getValue(1));
+
+  // Legalize the chain result - switch anything that used the old chain to
+  // use the new one.
+  ReplaceValueWith(SDValue(LD, 1), Ch);
+}
+
 void DAGTypeLegalizer::IncrementPointer(MemSDNode *N, EVT MemVT,
                                         MachinePointerInfo &MPI, SDValue &Ptr,
                                         uint64_t *ScaledOffset) {
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 0dfd0302ae5438..572cce96e496dc 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -12161,7 +12161,7 @@ SDValue SelectionDAG::makeEquivalentMemoryOrdering(SDValue OldChain,
   return TokenFactor;
 }
 
-SDValue SelectionDAG::makeEquivalentMemoryOrdering(LoadSDNode *OldLoad,
+SDValue SelectionDAG::makeEquivalentMemoryOrdering(MemSDNode *OldLoad,
                                                    SDValue NewMemOp) {
   assert(isa<MemSDNode>(NewMemOp.getNode()) && "Expected a memop node");
   SDValue OldChain = SDValue(OldLoad, 1);
@@ -12873,13 +12873,33 @@ std::pair<SDValue, SDValue> SelectionDAG::UnrollVectorOverflowOp(
                         getBuildVector(NewOvVT, dl, OvScalars));
 }
 
+bool SelectionDAG::areNonVolatileConsecutiveLoads(AtomicSDNode *LD,
+                                                  AtomicSDNode *Base,
+                                                  unsigned Bytes,
+                                                  int Dist) const {
+  if (LD->isVolatile() || Base->isVolatile())
+    return false;
+  if (LD->getChain() != Base->getChain())
+    return false;
+  EVT VT = LD->getMemoryVT();
+  if (VT.getSizeInBits() / 8 != Bytes)
+    return false;
+
+  auto BaseLocDecomp = BaseIndexOffset::match(Base, *this);
+  auto LocDecomp = BaseIndexOffset::match(LD, *this);
+
+  int64_t Offset = 0;
+  if (BaseLocDecomp.equalBaseIndex(LocDecomp, *this, Offset))
+    return (Dist * (int64_t)Bytes == Offset);
+  return false;
+}
+
 bool SelectionDAG::areNonVolatileConsecutiveLoads(LoadSDNode *LD,
                                                   LoadSDNode *Base,
                                                   unsigned Bytes,
                                                   int Dist) const {
   if (LD->isVolatile() || Base->isVolatile())
     return false;
-  // TODO: probably too restrictive for atomics, revisit
   if (!LD->isSimple())
     return false;
   if (LD->isIndexed() || Base->isIndexed())
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGAddressAnalysis.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGAddressAnalysis.cpp
index f2ab88851b780e..a19af64a796229 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGAddressAnalysis.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGAddressAnalysis.cpp
@@ -194,8 +194,8 @@ bool BaseIndexOffset::contains(const SelectionDAG &DAG, int64_t BitSize,
   return false;
 }
 
-/// Parses tree in Ptr for base, index, offset addresses.
-static BaseIndexOffset matchLSNode(const LSBaseSDNode *N,
+template <typename T>
+static BaseIndexOffset matchSDNode(const T *N,
                                    const SelectionDAG &DAG) {
   SDValue Ptr = N->getBasePtr();
 
@@ -206,16 +206,18 @@ static BaseIndexOffset matchLSNode(const LSBaseSDNode *N,
   bool IsIndexSignExt = false;
 
   // pre-inc/pre-dec ops are components of EA.
-  if (N->getAddressingMode() == ISD::PRE_INC) {
-    if (auto *C = dyn_cast<ConstantSDNode>(N->getOffset()))
-      Offset += C->getSExtValue();
-    else // If unknown, give up now.
-      return BaseIndexOffset(SDValue(), SDValue(), 0, false);
-  } else if (N->getAddressingMode() == ISD::PRE_DEC) {
-    if (auto *C = dyn_cast<ConstantSDNode>(N->getOffset()))
-      Offset -= C->getSExtValue();
-    else // If unknown, give up now.
-      return BaseIndexOffset(SDValue(), SDValue(), 0, false);
+  if constexpr (std::is_same_v<T, LSBaseSDNode>) {
+    if (N->getAddressingMode() == ISD::PRE_INC) {
+      if (auto *C = dyn_cast<ConstantSDNode>(N->getOffset()))
+        Offset += C->getSExtValue();
+      else // If unknown, give up now.
+        return BaseIndexOffset(SDValue(), SDValue(), 0, false);
+    } else if (N->getAddressingMode() == ISD::PRE_DEC) {
+      if (auto *C = dyn_cast<ConstantSDNode>(N->getOffset()))
+        Offset -= C->getSExtValue();
+      else // If unknown, give up now.
+        return BaseIndexOffset(SDValue(), SDValue(), 0, false);
+    }
   }
 
   // Consume constant adds & ors with appropriate masking.
@@ -300,8 +302,10 @@ static BaseIndexOffset matchLSNode(const LSBaseSDNode *N,
 
 BaseIndexOffset BaseIndexOffset::match(const SDNode *N,
                                        const SelectionDAG &DAG) {
+  if (const auto *AN = dyn_cast<AtomicSDNode>(N))
+    return matchSDNode(AN, DAG);
   if (const auto *LS0 = dyn_cast<LSBaseSDNode>(N))
-    return matchLSNode(LS0, DAG);
+    return matchSDNode(LS0, DAG);
   if (const auto *LN = dyn_cast<LifetimeSDNode>(N)) {
     if (LN->hasOffset())
       return BaseIndexOffset(LN->getOperand(1), SDValue(), LN->getOffset(),
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index f8d7c3ef7bbe71..06af262814532a 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -5218,7 +5218,11 @@ void SelectionDAGBuilder::visitAtomicLoad(const LoadInst &I) {
     L = DAG.getPtrExtOrTrunc(L, dl, VT);
 
   setValue(&I, L);
-  DAG.setRoot(OutChain);
+
+  if (VT.isVector())
+    DAG.setRoot(InChain);
+  else
+    DAG.setRoot(OutChain);
 }
 
 void SelectionDAGBuilder::visitAtomicStore(const StoreInst &I) {
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index a7ee433351ff06..d9fbc48ed2a6de 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -7049,15 +7049,23 @@ static SDValue LowerAsSplatVectorLoad(SDValue SrcOp, MVT VT, const SDLoc &dl,
   return SDValue();
 }
 
-// Recurse to find a LoadSDNode source and the accumulated ByteOffest.
-static bool findEltLoadSrc(SDValue Elt, LoadSDNode *&Ld, int64_t &ByteOffset) {
-  if (ISD::isNON_EXTLoad(Elt.getNode())) {
-    auto *BaseLd = cast<LoadSDNode>(Elt);
-    if (!BaseLd->isSimple())
-      return false;
-    Ld = BaseLd;
-    ByteOffset = 0;
-    return true;
+template <typename T>
+static bool findEltLoadSrc(SDValue Elt, T *&Ld, int64_t &ByteOffset) {
+  if constexpr (std::is_same_v<T, AtomicSDNode>) {
+    if (auto *BaseLd = dyn_cast<AtomicSDNode>(Elt)) {
+      Ld = BaseLd;
+      ByteOffset = 0;
+      return true;
+    }
+  } else if constexpr (std::is_same_v<T, LoadSDNode>) {
+    if (ISD::isNON_EXTLoad(Elt.getNode())) {
+      auto *BaseLd = cast<LoadSDNode>(Elt);
+      if (!BaseLd->isSimple())
+        return false;
+      Ld = BaseLd;
+      ByteOffset = 0;
+      return true;
+    }
   }
 
   switch (Elt.getOpcode()) {
@@ -7097,6 +7105,7 @@ static bool findEltLoadSrc(SDValue Elt, LoadSDNode *&Ld, int64_t &ByteOffset) {
 /// a build_vector or insert_subvector whose loaded operands are 'Elts'.
 ///
 /// Example: <load i32 *a, load i32 *a+4, zero, undef> -> zextload a
+template <typename T>
 static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
                                         const SDLoc &DL, SelectionDAG &DAG,
                                         const X86Subtarget &Subtarget,
@@ -7111,7 +7120,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
   APInt ZeroMask = APInt::getZero(NumElems);
   APInt UndefMask = APInt::getZero(NumElems);
 
-  SmallVector<LoadSDNode*, 8> Loads(NumElems, nullptr);
+  SmallVector<T*, 8> Loads(NumElems, nullptr);
   SmallVector<int64_t, 8> ByteOffsets(NumElems, 0);
 
   // For each element in the initializer, see if we've found a load, zero or an
@@ -7161,7 +7170,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
   EVT EltBaseVT = EltBase.getValueType();
   assert(EltBaseVT.getSizeInBits() == EltBaseVT.getStoreSizeInBits() &&
          "Register/Memory size mismatch");
-  LoadSDNode *LDBase = Loads[FirstLoadedElt];
+  T *LDBase = Loads[FirstLoadedElt];
   assert(LDBase && "Did not find base load for merging consecutive loads");
   unsigned BaseSizeInBits = EltBaseVT.getStoreSizeInBits();
   unsigned BaseSizeInBytes = BaseSizeInBits / 8;
@@ -7175,8 +7184,8 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
 
   // Check to see if the element's load is consecutive to the base load
   // or offset from a previous (already checked) load.
-  auto CheckConsecutiveLoad = [&](LoadSDNode *Base, int EltIdx) {
-    LoadSDNode *Ld = Loads[EltIdx];
+  auto CheckConsecutiveLoad = [&](T *Base, int EltIdx) {
+    T *Ld = Loads[EltIdx];
     int64_t ByteOffset = ByteOffsets[EltIdx];
     if (ByteOffset && (ByteOffset % BaseSizeInBytes) == 0) {
       int64_t BaseIdx = EltIdx - (ByteOffset / BaseSizeInBytes);
@@ -7204,7 +7213,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
     }
   }
 
-  auto CreateLoad = [&DAG, &DL, &Loads](EVT VT, LoadSDNode *LDBase) {
+  auto CreateLoad = [&DAG, &DL, &Loads](EVT VT, T *LDBase) {
     auto MMOFlags = LDBase->getMemOperand()->getFlags();
     assert(LDBase->isSimple() &&
            "Cannot merge volatile or atomic loads.");
@@ -7274,7 +7283,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
       EVT HalfVT =
           EVT::getVectorVT(*DAG.getContext(), VT.getScalarType(), HalfNumElems);
       SDValue HalfLD =
-          EltsFromConsecutiveLoads(HalfVT, Elts.drop_back(HalfNumElems), DL,
+          EltsFromConsecutiveLoads<T>(HalfVT, Elts.drop_back(HalfNumElems), DL,
                                    DAG, Subtarget, IsAfterLegalize);
       if (HalfLD)
         return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, DAG.getUNDEF(VT),
@@ -7351,7 +7360,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
           EVT::getVectorVT(*DAG.getContext(), RepeatVT.getScalarType(),
                            VT.getSizeInBits() / ScalarSize);
       if (TLI.isTypeLegal(BroadcastVT)) {
-        if (SDValue RepeatLoad = EltsFromConsecutiveLoads(
+        if (SDValue RepeatLoad = EltsFromConsecutiveLoads<T>(
                 RepeatVT, RepeatedLoads, DL, DAG, Subtarget, IsAfterLegalize)) {
           SDValue Broadcast = RepeatLoad;
           if (RepeatSize > ScalarSize) {
@@ -7392,7 +7401,7 @@ static SDValue combineToConsecutiveLoads(EVT VT, SDValue Op, const SDLoc &DL,
     return SDValue();
   }
   assert(Elts.size() == VT.getVectorNumElements());
-  return EltsFromConsecutiveLoads(VT, Elts, DL, DAG, Subtarget,
+  return EltsFromConsecutiveLoads<LoadSDNode>(VT, Elts, DL, DAG, Subtarget,
                                   IsAfterLegalize);
 }
 
@@ -9247,8 +9256,12 @@ X86TargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const {
   {
     SmallVector<SDValue, 64> Ops(Op->op_begin(), Op->op_begin() + NumElems);
     if (SDValue LD =
-            EltsFromConsecutiveLoads(VT, Ops, dl, DAG, Subtarget, false))
+            EltsFromConsecutiveLoads<LoadSDNode>(VT, Ops, dl, DAG, Subtarget, false)) {
       return LD;
+    } else if (SDValue LD =
+            EltsFromConsecutiveLoads<AtomicSDNode>(VT, Ops, dl, DAG, Subtarget, false)) {
+      return LD;
+    }
   }
 
   // If this is a splat of pairs of 32-bit elements, we can use a narrower
@@ -57934,7 +57947,7 @@ static SDValue combineConcatVectorOps(const SDLoc &DL, MVT VT,
                                 *FirstLd->getMemOperand(), &Fast) &&
         Fast) {
       if (SDValue Ld =
-              EltsFromConsecutiveLoads(VT, Ops, DL, DAG, Subtarget, false))
+              EltsFromConsecutiveLoads<LoadSDNode>(VT, Ops, DL, DAG, Subtarget, false))
         return Ld;
     }
   }
diff --git a/llvm/test/CodeGen/X86/atomic-load-store.ll b/llvm/test/CodeGen/X86/atomic-load-store.ll
index 6c2a7e1d68c382..398f665484569a 100644
--- a/llvm/test/CodeGen/X86/atomic-load-store.ll
+++ b/llvm/test/CodeGen/X86/atomic-load-store.ll
@@ -195,6 +195,24 @@ define <2 x float> @atomic_vec2_float_align(ptr %x) {
   ret <2 x float> %ret
 }
 
+define <2 x half> @atomic_vec2_half(ptr %x) {
+; CHECK-LABEL: atomic_vec2_half:
+; CHECK:       ## %bb.0:
+; CHECK-NEXT:    movss {{.*#+}} xmm0 = mem[0],zero,zero,zero
+; CHECK-NEXT:    retq
+  %ret = load atomic <2 x half>, ptr %x acquire, align 4
+  ret <2 x half> %ret
+}
+
+define <2 x bfloat> @atomic_vec2_bfloat(ptr %x) {
+; CHECK-LABEL: atomic_vec2_bfloat:
+; CHECK:       ## %bb.0:
+; CHECK-NEXT:    movss {{.*#+}} xmm0 = mem[0],zero,zero,zero
+; CHECK-NEXT:    retq
+  %ret = load atomic <2 x bfloat>, ptr %x acquire, align 4
+  ret <2 x bfloat> %ret
+}
+
 define <1 x ptr> @atomic_vec1_ptr(ptr %x) nounwind {
 ; CHECK3-LABEL: atomic_vec1_ptr:
 ; CHECK3:       ## %bb.0:
    
    
More information about the llvm-branch-commits
mailing list