[llvm] [DAGCombiner] Move handling of atomic loads from SystemZ to DAGCombiner (NFC). (PR #86484)

via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 25 03:39:26 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-selectiondag

Author: Jonas Paulsson (JonPsson1)

<details>
<summary>Changes</summary>

The folding of sign/zero extensions into an atomic load by specifying an extension type is not target-specific, and therefore belongs in the DAGCombiner rather than in the SystemZ backend.

- Handle atomic loads similarly to regular loads by adding AtomicLoadExtActions with set/get methods.
- Move SystemZ extendAtomicLoad() to DAGCombiner.cpp.



---
Full diff: https://github.com/llvm/llvm-project/pull/86484.diff


4 Files Affected:

- (modified) llvm/include/llvm/CodeGen/TargetLowering.h (+50) 
- (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+41) 
- (modified) llvm/lib/CodeGen/TargetLoweringBase.cpp (+6) 
- (modified) llvm/lib/Target/SystemZ/SystemZISelLowering.cpp (+9-33) 


``````````diff
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 59fad88f91b1d1..a4dc097446186a 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -1454,6 +1454,28 @@ class TargetLoweringBase {
            getLoadExtAction(ExtType, ValVT, MemVT) == Custom;
   }
 
+  /// As getLoadExtAction, but for atomic loads; returns only Legal or Expand.
+  LegalizeAction getAtomicLoadExtAction(unsigned ExtType, EVT ValVT,
+                                        EVT MemVT) const {
+    if (ValVT.isExtended() || MemVT.isExtended()) return Expand;
+    unsigned ValI = (unsigned)ValVT.getSimpleVT().SimpleTy;
+    unsigned MemI = (unsigned)MemVT.getSimpleVT().SimpleTy;
+    assert(ExtType < ISD::LAST_LOADEXT_TYPE && ValI < MVT::VALUETYPE_SIZE &&
+           MemI < MVT::VALUETYPE_SIZE && "Table isn't big enough!");
+    unsigned Shift = 4 * ExtType;
+    LegalizeAction Action =
+        (LegalizeAction)((AtomicLoadExtActions[ValI][MemI] >> Shift) & 0xf);
+    assert((Action == Legal || Action == Expand) &&
+           "Unsupported atomic load extension action.");
+    return Action;
+  }
+
+  /// Return true if the specified atomic load with extension is legal on
+  /// this target.
+  bool isAtomicLoadExtLegal(unsigned ExtType, EVT ValVT, EVT MemVT) const {
+    return getAtomicLoadExtAction(ExtType, ValVT, MemVT) == Legal;
+  }
+
   /// Return how this store with truncation should be treated: either it is
   /// legal, needs to be promoted to a larger size, needs to be expanded to some
   /// other code sequence, or the target has a custom expander for it.
@@ -2536,6 +2558,30 @@ class TargetLoweringBase {
       setLoadExtAction(ExtTypes, ValVT, MemVT, Action);
   }
 
+  /// Set the action (only Legal or Expand) for an extending atomic load.
+  void setAtomicLoadExtAction(unsigned ExtType, MVT ValVT, MVT MemVT,
+                              LegalizeAction Action) {
+    assert(ExtType < ISD::LAST_LOADEXT_TYPE && ValVT.isValid() &&
+           MemVT.isValid() && "Table isn't big enough!");
+    assert((Action == Legal || Action == Expand) &&
+           "Unsupported atomic load extension action.");
+    unsigned Shift = 4 * ExtType;
+    AtomicLoadExtActions[ValVT.SimpleTy][MemVT.SimpleTy] &=
+        ~((uint16_t)0xF << Shift);
+    AtomicLoadExtActions[ValVT.SimpleTy][MemVT.SimpleTy] |=
+        ((uint16_t)Action << Shift);
+  }
+  void setAtomicLoadExtAction(ArrayRef<unsigned> ExtTypes, MVT ValVT, MVT MemVT,
+                              LegalizeAction Action) {
+    for (auto ExtType : ExtTypes)
+      setAtomicLoadExtAction(ExtType, ValVT, MemVT, Action);
+  }
+  void setAtomicLoadExtAction(ArrayRef<unsigned> ExtTypes, MVT ValVT,
+                              ArrayRef<MVT> MemVTs, LegalizeAction Action) {
+    for (auto MemVT : MemVTs)
+      setAtomicLoadExtAction(ExtTypes, ValVT, MemVT, Action);
+  }
+
   /// Indicate that the specified truncating store does not work with the
   /// specified type and indicate what to do about it.
   void setTruncStoreAction(MVT ValVT, MVT MemVT, LegalizeAction Action) {
@@ -3521,6 +3567,10 @@ class TargetLoweringBase {
   /// for each of the 4 load ext types.
   uint16_t LoadExtActions[MVT::VALUETYPE_SIZE][MVT::VALUETYPE_SIZE];
 
+  /// Like LoadExtActions, but for atomic loads; only Legal or Expand (default)
+  /// are supported. Zeroed up front since the setter read-modify-writes it.
+  uint16_t AtomicLoadExtActions[MVT::VALUETYPE_SIZE][MVT::VALUETYPE_SIZE] = {};
+
   /// For each value type pair keep a LegalizeAction that indicates whether a
   /// truncating store of a specific value type and truncating type is legal.
   LegalizeAction TruncStoreActions[MVT::VALUETYPE_SIZE][MVT::VALUETYPE_SIZE];
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index e27a8bb8fdacda..7b5dc15eac53b4 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -13131,6 +13131,37 @@ tryToFoldExtOfMaskedLoad(SelectionDAG &DAG, const TargetLowering &TLI, EVT VT,
   return NewLoad;
 }
 
+// fold ([s|z]ext (atomic_load)) -> ([s|z]ext (truncate ([s|z]ext atomic_load)))
+static SDValue tryToFoldExtOfAtomicLoad(SelectionDAG &DAG,
+                                        const TargetLowering &TLI, EVT VT,
+                                        SDValue N0,
+                                        ISD::LoadExtType ExtLoadType) {
+  auto *ALoad = dyn_cast<AtomicSDNode>(N0);
+  if (!ALoad || ALoad->getOpcode() != ISD::ATOMIC_LOAD)
+    return {};
+  EVT MemoryVT = ALoad->getMemoryVT();
+  if (!TLI.isAtomicLoadExtLegal(ExtLoadType, VT, MemoryVT))
+    return {};
+  // Can't fold into atomic_load if it is already extending differently.
+  ISD::LoadExtType ALoadExtTy = ALoad->getExtensionType();
+  if ((ALoadExtTy == ISD::ZEXTLOAD && ExtLoadType == ISD::SEXTLOAD) ||
+      (ALoadExtTy == ISD::SEXTLOAD && ExtLoadType == ISD::ZEXTLOAD))
+    return {};
+
+  EVT OrigVT = ALoad->getValueType(0);
+  assert(OrigVT.getSizeInBits() < VT.getSizeInBits() && "VT should be wider.");
+  auto *NewALoad = cast<AtomicSDNode>(DAG.getAtomic(
+      ISD::ATOMIC_LOAD, SDLoc(ALoad), MemoryVT, VT, ALoad->getChain(),
+      ALoad->getBasePtr(), ALoad->getMemOperand()));
+  NewALoad->setExtensionType(ExtLoadType);
+  DAG.ReplaceAllUsesOfValueWith(
+      SDValue(ALoad, 0),
+      DAG.getNode(ISD::TRUNCATE, SDLoc(ALoad), OrigVT, SDValue(NewALoad, 0)));
+  // Update the chain uses.
+  DAG.ReplaceAllUsesOfValueWith(SDValue(ALoad, 1), SDValue(NewALoad, 1));
+  return SDValue(NewALoad, 0);
+}
+
 static SDValue foldExtendedSignBitTest(SDNode *N, SelectionDAG &DAG,
                                        bool LegalOperations) {
   assert((N->getOpcode() == ISD::SIGN_EXTEND ||
@@ -13402,6 +13433,11 @@ SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) {
           DAG, *this, TLI, VT, LegalOperations, N, N0, ISD::SEXTLOAD))
     return foldedExt;
 
+  // Try to fold (sext (atomic_load x)) into a sign-extending atomic_load.
+  if (SDValue foldedExt =
+          tryToFoldExtOfAtomicLoad(DAG, TLI, VT, N0, ISD::SEXTLOAD))
+    return foldedExt;
+
   // fold (sext (and/or/xor (load x), cst)) ->
   //      (and/or/xor (sextload x), (sext cst))
   if (ISD::isBitwiseLogicOp(N0.getOpcode()) &&
@@ -13713,6 +13749,11 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
   if (SDValue ExtLoad = CombineExtLoad(N))
     return ExtLoad;
 
+  // Try to fold (zext (atomic_load x)) into a zero-extending atomic_load.
+  if (SDValue foldedExt =
+          tryToFoldExtOfAtomicLoad(DAG, TLI, VT, N0, ISD::ZEXTLOAD))
+    return foldedExt;
+
   // fold (zext (and/or/xor (load x), cst)) ->
   //      (and/or/xor (zextload x), (zext cst))
   // Unless (and (load x) cst) will match as a zextload already and has
diff --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp
index 9990556f89ed8b..6250607f7092ef 100644
--- a/llvm/lib/CodeGen/TargetLoweringBase.cpp
+++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp
@@ -823,6 +823,12 @@ void TargetLoweringBase::initActions() {
   std::fill(std::begin(TargetDAGCombineArray),
             std::end(TargetDAGCombineArray), 0);
 
+  // Extending atomic loads are Expand (unsupported) unless a target opts in.
+  for (MVT ValVT : MVT::all_valuetypes())
+    for (MVT MemVT : MVT::all_valuetypes())
+      setAtomicLoadExtAction({ISD::SEXTLOAD, ISD::ZEXTLOAD}, ValVT, MemVT,
+                             Expand);
+
   // We're somewhat special casing MVT::i2 and MVT::i4. Ideally we want to
   // remove this and targets should individually set these types if not legal.
   for (ISD::NodeType NT : enum_seq(ISD::DELETED_NODE, ISD::BUILTIN_OP_END,
diff --git a/llvm/lib/Target/SystemZ/SystemZISelLowering.cpp b/llvm/lib/Target/SystemZ/SystemZISelLowering.cpp
index da4bcd7f0c66ed..6496fe766101fc 100644
--- a/llvm/lib/Target/SystemZ/SystemZISelLowering.cpp
+++ b/llvm/lib/Target/SystemZ/SystemZISelLowering.cpp
@@ -293,6 +293,15 @@ SystemZTargetLowering::SystemZTargetLowering(const TargetMachine &TM,
   setOperationAction(ISD::ATOMIC_LOAD,     MVT::i128, Custom);
   setOperationAction(ISD::ATOMIC_STORE,    MVT::i128, Custom);
 
+  // Mark sign/zero extending atomic loads as legal so that DAGCombiner folds
+  // a [s|z]ext of an ATOMIC_LOAD into an extending ATOMIC_LOAD.
+  setAtomicLoadExtAction({ISD::SEXTLOAD, ISD::ZEXTLOAD}, MVT::i64,
+                         {MVT::i8, MVT::i16, MVT::i32}, Legal);
+  setAtomicLoadExtAction({ISD::SEXTLOAD, ISD::ZEXTLOAD}, MVT::i32,
+                         {MVT::i8, MVT::i16}, Legal);
+  setAtomicLoadExtAction({ISD::SEXTLOAD, ISD::ZEXTLOAD}, MVT::i16,
+                         MVT::i8, Legal);
+
   // We can use the CC result of compare-and-swap to implement
   // the "success" result of ATOMIC_CMP_SWAP_WITH_SUCCESS.
   setOperationAction(ISD::ATOMIC_CMP_SWAP_WITH_SUCCESS, MVT::i32, Custom);
@@ -6614,27 +6623,6 @@ SDValue SystemZTargetLowering::combineTruncateExtract(
   return SDValue();
 }
 
-// Replace ALoad with a new ATOMIC_LOAD with a result that is extended to VT
-// per ETy.
-static SDValue extendAtomicLoad(AtomicSDNode *ALoad, EVT VT, SelectionDAG &DAG,
-                                ISD::LoadExtType ETy) {
-  if (VT.getSizeInBits() > 64)
-    return SDValue();
-  EVT OrigVT = ALoad->getValueType(0);
-  assert(OrigVT.getSizeInBits() < VT.getSizeInBits() && "VT should be wider.");
-  EVT MemoryVT = ALoad->getMemoryVT();
-  auto *NewALoad = dyn_cast<AtomicSDNode>(DAG.getAtomic(
-      ISD::ATOMIC_LOAD, SDLoc(ALoad), MemoryVT, VT, ALoad->getChain(),
-      ALoad->getBasePtr(), ALoad->getMemOperand()));
-  NewALoad->setExtensionType(ETy);
-  DAG.ReplaceAllUsesOfValueWith(
-      SDValue(ALoad, 0),
-      DAG.getNode(ISD::TRUNCATE, SDLoc(ALoad), OrigVT, SDValue(NewALoad, 0)));
-  // Update the chain uses.
-  DAG.ReplaceAllUsesOfValueWith(SDValue(ALoad, 1), SDValue(NewALoad, 1));
-  return SDValue(NewALoad, 0);
-}
-
 SDValue SystemZTargetLowering::combineZERO_EXTEND(
     SDNode *N, DAGCombinerInfo &DCI) const {
   // Convert (zext (select_ccmask C1, C2)) into (select_ccmask C1', C2')
@@ -6681,12 +6669,6 @@ SDValue SystemZTargetLowering::combineZERO_EXTEND(
     }
   }
 
-  // Fold into ATOMIC_LOAD unless it is already sign extending.
-  if (auto *ALoad = dyn_cast<AtomicSDNode>(N0))
-    if (ALoad->getOpcode() == ISD::ATOMIC_LOAD &&
-        ALoad->getExtensionType() != ISD::SEXTLOAD)
-      return extendAtomicLoad(ALoad, VT, DAG, ISD::ZEXTLOAD);
-
   return SDValue();
 }
 
@@ -6739,12 +6721,6 @@ SDValue SystemZTargetLowering::combineSIGN_EXTEND(
     }
   }
 
-  // Fold into ATOMIC_LOAD unless it is already zero extending.
-  if (auto *ALoad = dyn_cast<AtomicSDNode>(N0))
-    if (ALoad->getOpcode() == ISD::ATOMIC_LOAD &&
-        ALoad->getExtensionType() != ISD::ZEXTLOAD)
-      return extendAtomicLoad(ALoad, VT, DAG, ISD::SEXTLOAD);
-
   return SDValue();
 }
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/86484


More information about the llvm-commits mailing list