[llvm] [CodeGen] Add getTgtMemIntrinsic overload for multiple memory operands (NFC) (PR #175843)

Nicolai Hähnle via llvm-commits llvm-commits at lists.llvm.org
Tue Jan 13 13:47:50 PST 2026


https://github.com/nhaehnle updated https://github.com/llvm/llvm-project/pull/175843

From 29893007295c8928ff3b2eb0d297c9a072341301 Mon Sep 17 00:00:00 2001
From: Nicolai Hähnle <nicolai.haehnle at amd.com>
Date: Mon, 12 Jan 2026 09:42:04 -0800
Subject: [PATCH] [CodeGen] Add getTgtMemIntrinsic overload for multiple memory
 operands (NFC)

There are target intrinsics that logically require two MMOs, such as
llvm.amdgcn.global.load.lds, which copies from global memory to LDS and
therefore performs both a load and a store to different addresses.

Add an overload of getTgtMemIntrinsic that produces intrinsic info in a
vector, and implement it in terms of the existing (now protected)
overload.
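
As a rough sketch of what a target-side implementation of the new
overload could look like (the target, the intrinsic, and the argument
order are purely illustrative, not the actual AMDGPU lowering):

  void MyTargetLowering::getTgtMemIntrinsic(
      SmallVectorImpl<IntrinsicInfo> &Infos, const CallBase &I,
      MachineFunction &MF, unsigned Intrinsic) const {
    switch (Intrinsic) {
    case Intrinsic::mytarget_copy_a_to_b: { // hypothetical copy intrinsic
      // Entry 0: the load side of the copy.
      IntrinsicInfo Load;
      Load.opc = ISD::INTRINSIC_VOID;
      Load.memVT = MVT::i32;
      Load.ptrVal = I.getArgOperand(1); // source pointer (assumed operand 1)
      Load.flags = MachineMemOperand::MOLoad;
      Infos.push_back(Load);
      // Entry 1: the store side of the copy.
      IntrinsicInfo Store = Load;
      Store.ptrVal = I.getArgOperand(0); // destination pointer (assumed)
      Store.flags = MachineMemOperand::MOStore;
      Infos.push_back(Store);
      return;
    }
    default:
      // Everything else keeps the default behavior, which forwards to the
      // legacy single-info overload.
      TargetLowering::getTgtMemIntrinsic(Infos, I, MF, Intrinsic);
      return;
    }
  }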

GlobalISel and SelectionDAG paths are updated to support multiple MMOs.
The main part of this change is supporting multiple MMOs in
MemIntrinsicNodes.
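
Consumers that previously assumed exactly one MMO can iterate over the
new memoperands() accessor instead; a minimal sketch, assuming MemNode
is a MemIntrinsicSDNode*:

  // With a single MMO this visits exactly one operand, so the behavior
  // for existing nodes is unchanged.
  bool Reads = false, Writes = false;
  for (const MachineMemOperand *MMO : MemNode->memoperands()) {
    Reads |= MMO->isLoad();
    Writes |= MMO->isStore();
  }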

Converting the backends to use the new overload is a fairly mechanical
step that is done in a separate change, in the hope that this reduces
merge pain during review and for downstreams. A later change will then
enable using multiple MMOs in AMDGPU.

commit-id:b4a924aa
---
 .../llvm/CodeGen/GlobalISel/IRTranslator.h    |   2 +-
 llvm/include/llvm/CodeGen/SelectionDAG.h      |  20 ++-
 llvm/include/llvm/CodeGen/SelectionDAGNodes.h | 115 +++++++++++++-----
 llvm/include/llvm/CodeGen/TargetLowering.h    |  23 +++-
 llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp  |  47 +++----
 .../lib/CodeGen/SelectionDAG/SelectionDAG.cpp |  91 ++++++++++----
 .../SelectionDAG/SelectionDAGBuilder.cpp      |  67 +++++-----
 .../SelectionDAG/SelectionDAGDumper.cpp       |   8 +-
 .../CodeGen/SelectionDAG/SelectionDAGISel.cpp |   3 +-
 9 files changed, 252 insertions(+), 124 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/GlobalISel/IRTranslator.h b/llvm/include/llvm/CodeGen/GlobalISel/IRTranslator.h
index 5f5a6f5c72abf..de7bfda9f81da 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/IRTranslator.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/IRTranslator.h
@@ -299,7 +299,7 @@ class IRTranslator : public MachineFunctionPass {
 
   bool translateIntrinsic(
       const CallBase &CB, Intrinsic::ID ID, MachineIRBuilder &MIRBuilder,
-      const TargetLowering::IntrinsicInfo *TgtMemIntrinsicInfo = nullptr);
+      ArrayRef<TargetLowering::IntrinsicInfo> TgtMemIntrinsicInfos = {});
 
   /// When an invoke or a cleanupret unwinds to the next EH pad, there are
   /// many places it could ultimately go. In the IR, we have a single unwind
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 604319095e74f..649386014c66f 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -435,10 +435,18 @@ class SelectionDAG {
 
   template <typename SDNodeTy>
   static uint16_t getSyntheticNodeSubclassData(unsigned Opc, unsigned Order,
-                                                SDVTList VTs, EVT MemoryVT,
-                                                MachineMemOperand *MMO) {
+                                               SDVTList VTs, EVT MemoryVT,
+                                               MachineMemOperand *MMO) {
     return SDNodeTy(Opc, Order, DebugLoc(), VTs, MemoryVT, MMO)
-         .getRawSubclassData();
+        .getRawSubclassData();
+  }
+
+  template <typename SDNodeTy>
+  static uint16_t getSyntheticNodeSubclassData(
+      unsigned Opc, unsigned Order, SDVTList VTs, EVT MemoryVT,
+      PointerUnion<MachineMemOperand *, MachineMemOperand **> MemRefs) {
+    return SDNodeTy(Opc, Order, DebugLoc(), VTs, MemoryVT, MemRefs)
+        .getRawSubclassData();
   }
 
   void createOperands(SDNode *Node, ArrayRef<SDValue> Vals);
@@ -1456,6 +1464,12 @@ class SelectionDAG {
                                        SDVTList VTList, ArrayRef<SDValue> Ops,
                                        EVT MemVT, MachineMemOperand *MMO);
 
+  /// getMemIntrinsicNode - Creates a MemIntrinsicNode with multiple MMOs.
+  LLVM_ABI SDValue getMemIntrinsicNode(unsigned Opcode, const SDLoc &dl,
+                                       SDVTList VTList, ArrayRef<SDValue> Ops,
+                                       EVT MemVT,
+                                       ArrayRef<MachineMemOperand *> MMOs);
+
   /// Creates a LifetimeSDNode that starts (`IsStart==true`) or ends
   /// (`IsStart==false`) the lifetime of the `FrameIndex`.
   LLVM_ABI SDValue getLifetimeNode(bool IsStart, const SDLoc &dl, SDValue Chain,
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index aa72e81b2ab54..ecbcc0a87a4fc 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -1411,19 +1411,26 @@ class MemSDNode : public SDNode {
   EVT MemoryVT;
 
 protected:
-  /// Memory reference information.
-  MachineMemOperand *MMO;
+  /// Memory reference information. Must always have at least one MMO.
+  /// - MachineMemOperand*: exactly 1 MMO (common case)
+  /// - MachineMemOperand**: pointer to array, size at offset -1
+  PointerUnion<MachineMemOperand *, MachineMemOperand **> MemRefs;
 
 public:
-  LLVM_ABI MemSDNode(unsigned Opc, unsigned Order, const DebugLoc &dl,
-                     SDVTList VTs, EVT memvt, MachineMemOperand *MMO);
+  /// Constructor that supports one or more MMOs. For a single MMO, pass the
+  /// MMO pointer directly. For multiple MMOs, pre-allocate storage with the
+  /// count at offset -1 and pass a pointer to the array.
+  LLVM_ABI
+  MemSDNode(unsigned Opc, unsigned Order, const DebugLoc &dl, SDVTList VTs,
+            EVT memvt,
+            PointerUnion<MachineMemOperand *, MachineMemOperand **> memrefs);
 
-  bool readMem() const { return MMO->isLoad(); }
-  bool writeMem() const { return MMO->isStore(); }
+  bool readMem() const { return getMemOperand()->isLoad(); }
+  bool writeMem() const { return getMemOperand()->isStore(); }
 
   /// Returns alignment and volatility of the memory access
-  Align getBaseAlign() const { return MMO->getBaseAlign(); }
-  Align getAlign() const { return MMO->getAlign(); }
+  Align getBaseAlign() const { return getMemOperand()->getBaseAlign(); }
+  Align getAlign() const { return getMemOperand()->getAlign(); }
 
   /// Return the SubclassData value, without HasDebugValue. This contains an
   /// encoding of the volatile flag, as well as bits used by subclasses. This
@@ -1450,36 +1457,40 @@ class MemSDNode : public SDNode {
   bool isInvariant() const { return MemSDNodeBits.IsInvariant; }
 
   // Returns the offset from the location of the access.
-  int64_t getSrcValueOffset() const { return MMO->getOffset(); }
+  int64_t getSrcValueOffset() const { return getMemOperand()->getOffset(); }
 
   /// Returns the AA info that describes the dereference.
-  AAMDNodes getAAInfo() const { return MMO->getAAInfo(); }
+  AAMDNodes getAAInfo() const { return getMemOperand()->getAAInfo(); }
 
   /// Returns the Ranges that describes the dereference.
-  const MDNode *getRanges() const { return MMO->getRanges(); }
+  const MDNode *getRanges() const { return getMemOperand()->getRanges(); }
 
   /// Returns the synchronization scope ID for this memory operation.
-  SyncScope::ID getSyncScopeID() const { return MMO->getSyncScopeID(); }
+  SyncScope::ID getSyncScopeID() const {
+    return getMemOperand()->getSyncScopeID();
+  }
 
   /// Return the atomic ordering requirements for this memory operation. For
   /// cmpxchg atomic operations, return the atomic ordering requirements when
   /// store occurs.
   AtomicOrdering getSuccessOrdering() const {
-    return MMO->getSuccessOrdering();
+    return getMemOperand()->getSuccessOrdering();
   }
 
   /// Return a single atomic ordering that is at least as strong as both the
   /// success and failure orderings for an atomic operation.  (For operations
   /// other than cmpxchg, this is equivalent to getSuccessOrdering().)
-  AtomicOrdering getMergedOrdering() const { return MMO->getMergedOrdering(); }
+  AtomicOrdering getMergedOrdering() const {
+    return getMemOperand()->getMergedOrdering();
+  }
 
   /// Return true if the memory operation ordering is Unordered or higher.
-  bool isAtomic() const { return MMO->isAtomic(); }
+  bool isAtomic() const { return getMemOperand()->isAtomic(); }
 
   /// Returns true if the memory operation doesn't imply any ordering
   /// constraints on surrounding memory operations beyond the normal memory
   /// aliasing rules.
-  bool isUnordered() const { return MMO->isUnordered(); }
+  bool isUnordered() const { return getMemOperand()->isUnordered(); }
 
   /// Returns true if the memory operation is neither atomic or volatile.
   bool isSimple() const { return !isAtomic() && !isVolatile(); }
@@ -1487,12 +1498,34 @@ class MemSDNode : public SDNode {
   /// Return the type of the in-memory value.
   EVT getMemoryVT() const { return MemoryVT; }
 
-  /// Return a MachineMemOperand object describing the memory
+  /// Return the unique MachineMemOperand object describing the memory
   /// reference performed by operation.
-  MachineMemOperand *getMemOperand() const { return MMO; }
+  /// Asserts if multiple MMOs are present - use memoperands() instead.
+  MachineMemOperand *getMemOperand() const {
+    assert(!isa<MachineMemOperand **>(MemRefs) &&
+           "Use memoperands() for nodes with multiple memory operands");
+    return cast<MachineMemOperand *>(MemRefs);
+  }
+
+  /// Return the number of memory operands.
+  size_t getNumMemOperands() const {
+    if (isa<MachineMemOperand *>(MemRefs))
+      return 1;
+    MachineMemOperand **Array = cast<MachineMemOperand **>(MemRefs);
+    return reinterpret_cast<size_t *>(Array)[-1];
+  }
+
+  /// Return the memory operands for this node.
+  ArrayRef<MachineMemOperand *> memoperands() const {
+    if (isa<MachineMemOperand *>(MemRefs))
+      return ArrayRef(MemRefs.getAddrOfPtr1(), 1);
+    MachineMemOperand **Array = cast<MachineMemOperand **>(MemRefs);
+    size_t Count = reinterpret_cast<size_t *>(Array)[-1];
+    return ArrayRef(Array, Count);
+  }
 
   const MachinePointerInfo &getPointerInfo() const {
-    return MMO->getPointerInfo();
+    return getMemOperand()->getPointerInfo();
   }
 
   /// Return the address space for the associated pointer
@@ -1501,19 +1534,35 @@ class MemSDNode : public SDNode {
   }
 
   /// Update this MemSDNode's MachineMemOperand information
-  /// to reflect the alignment of NewMMO, if it has a greater alignment.
+  /// to reflect the alignment of NewMMOs, if they have greater alignment.
   /// This must only be used when the new alignment applies to all users of
-  /// this MachineMemOperand.
-  void refineAlignment(const MachineMemOperand *NewMMO) {
-    MMO->refineAlignment(NewMMO);
+  /// these MachineMemOperands. The NewMMOs array must parallel memoperands().
+  void refineAlignment(ArrayRef<MachineMemOperand *> NewMMOs) {
+    ArrayRef<MachineMemOperand *> MMOs = memoperands();
+    assert(NewMMOs.size() == MMOs.size() && "MMO count mismatch");
+    for (auto [MMO, NewMMO] : zip(MMOs, NewMMOs))
+      MMO->refineAlignment(NewMMO);
+  }
+
+  void refineAlignment(MachineMemOperand *NewMMO) {
+    refineAlignment(ArrayRef(NewMMO));
   }
 
-  void refineRanges(const MachineMemOperand *NewMMO) {
-    // If this node has range metadata that is different than NewMMO, clear the
-    // range metadata.
+  /// Refine range metadata for all MMOs. The NewMMOs array must parallel
+  /// memoperands(). For each pair, if ranges differ, the stored range is
+  /// cleared.
+  void refineRanges(ArrayRef<MachineMemOperand *> NewMMOs) {
+    ArrayRef<MachineMemOperand *> MMOs = memoperands();
+    assert(NewMMOs.size() == MMOs.size() && "MMO count mismatch");
     // FIXME: Union the ranges instead?
-    if (getRanges() && getRanges() != NewMMO->getRanges())
-      MMO->clearRanges();
+    for (auto [MMO, NewMMO] : zip(MMOs, NewMMOs)) {
+      if (MMO->getRanges() && MMO->getRanges() != NewMMO->getRanges())
+        MMO->clearRanges();
+    }
+  }
+
+  void refineRanges(MachineMemOperand *NewMMO) {
+    refineRanges(ArrayRef(NewMMO));
   }
 
   const SDValue &getChain() const { return getOperand(0); }
@@ -1626,7 +1675,7 @@ class AtomicSDNode : public MemSDNode {
   /// when store does not occur.
   AtomicOrdering getFailureOrdering() const {
     assert(isCompareAndSwap() && "Must be cmpxchg operation");
-    return MMO->getFailureOrdering();
+    return getMemOperand()->getFailureOrdering();
   }
 
   // Methods to support isa and dyn_cast
@@ -1666,9 +1715,11 @@ class AtomicSDNode : public MemSDNode {
 /// opcode (see `SelectionDAGTargetInfo::isTargetMemoryOpcode`).
 class MemIntrinsicSDNode : public MemSDNode {
 public:
-  MemIntrinsicSDNode(unsigned Opc, unsigned Order, const DebugLoc &dl,
-                     SDVTList VTs, EVT MemoryVT, MachineMemOperand *MMO)
-      : MemSDNode(Opc, Order, dl, VTs, MemoryVT, MMO) {
+  MemIntrinsicSDNode(
+      unsigned Opc, unsigned Order, const DebugLoc &dl, SDVTList VTs,
+      EVT MemoryVT,
+      PointerUnion<MachineMemOperand *, MachineMemOperand **> MemRefs)
+      : MemSDNode(Opc, Order, dl, VTs, MemoryVT, MemRefs) {
     SDNodeBits.IsMemIntrinsic = true;
   }
 
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 1df4ba582d1b9..5a2a32e72719f 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -1259,15 +1259,32 @@ class LLVM_ABI TargetLoweringBase {
   };
 
   /// Given an intrinsic, checks if on the target the intrinsic will need to map
-  /// to a MemIntrinsicNode (touches memory). If this is the case, it returns
-  /// true and store the intrinsic information into the IntrinsicInfo that was
-  /// passed to the function.
+  /// to a MemIntrinsicNode (touches memory). If this is the case, it stores
+  /// the intrinsic information into the IntrinsicInfo vector passed to the
+  /// function. The vector may contain multiple entries for intrinsics that
+  /// access multiple memory locations.
+  virtual void getTgtMemIntrinsic(SmallVectorImpl<IntrinsicInfo> &Infos,
+                                  const CallBase &I, MachineFunction &MF,
+                                  unsigned Intrinsic) const {
+    // The default implementation forwards to the legacy single-info overload
+    // for compatibility.
+    IntrinsicInfo Info;
+    if (getTgtMemIntrinsic(Info, I, MF, Intrinsic))
+      Infos.push_back(Info);
+  }
+
+protected:
+  /// This is a legacy single-info overload. New code should override the
+  /// SmallVectorImpl overload instead to support multiple memory operands.
+  ///
+  /// TODO: Remove this once the refactoring is complete.
   virtual bool getTgtMemIntrinsic(IntrinsicInfo &, const CallBase &,
                                   MachineFunction &,
                                   unsigned /*Intrinsic*/) const {
     return false;
   }
 
+public:
   /// Returns true if the target can instruction select the specified FP
   /// immediate natively. If false, the legalizer will materialize the FP
   /// immediate as a load from a constant pool.
diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index 12552bce3caaa..981be7492ff5c 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -2829,20 +2829,16 @@ bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
   if (translateKnownIntrinsic(CI, ID, MIRBuilder))
     return true;
 
-  TargetLowering::IntrinsicInfo Info;
-  bool IsTgtMemIntrinsic = TLI->getTgtMemIntrinsic(Info, CI, *MF, ID);
+  SmallVector<TargetLowering::IntrinsicInfo> Infos;
+  TLI->getTgtMemIntrinsic(Infos, CI, *MF, ID);
 
-  return translateIntrinsic(CI, ID, MIRBuilder,
-                            IsTgtMemIntrinsic ? &Info : nullptr);
+  return translateIntrinsic(CI, ID, MIRBuilder, Infos);
 }
 
 /// Translate a call or callbr to an intrinsic.
-/// Depending on whether TLI->getTgtMemIntrinsic() is true, TgtMemIntrinsicInfo
-/// is a pointer to the correspondingly populated IntrinsicInfo object.
-/// Otherwise, this pointer is null.
 bool IRTranslator::translateIntrinsic(
     const CallBase &CB, Intrinsic::ID ID, MachineIRBuilder &MIRBuilder,
-    const TargetLowering::IntrinsicInfo *TgtMemIntrinsicInfo) {
+    ArrayRef<TargetLowering::IntrinsicInfo> TgtMemIntrinsicInfos) {
   ArrayRef<Register> ResultRegs;
   if (!CB.getType()->isVoidTy())
     ResultRegs = getOrCreateVRegs(CB);
@@ -2884,30 +2880,25 @@ bool IRTranslator::translateIntrinsic(
     }
   }
 
-  // Add a MachineMemOperand if it is a target mem intrinsic.
-  if (TgtMemIntrinsicInfo) {
-    const Function *F = CB.getCalledFunction();
+  // Add MachineMemOperands for each memory access described by the target.
+  for (const auto &Info : TgtMemIntrinsicInfos) {
+    Align Alignment = Info.align.value_or(
+        DL->getABITypeAlign(Info.memVT.getTypeForEVT(CB.getContext())));
+    LLT MemTy = Info.memVT.isSimple()
+                    ? getLLTForMVT(Info.memVT.getSimpleVT())
+                    : LLT::scalar(Info.memVT.getStoreSizeInBits());
 
-    Align Alignment = TgtMemIntrinsicInfo->align.value_or(DL->getABITypeAlign(
-        TgtMemIntrinsicInfo->memVT.getTypeForEVT(F->getContext())));
-    LLT MemTy =
-        TgtMemIntrinsicInfo->memVT.isSimple()
-            ? getLLTForMVT(TgtMemIntrinsicInfo->memVT.getSimpleVT())
-            : LLT::scalar(TgtMemIntrinsicInfo->memVT.getStoreSizeInBits());
-
-    // TODO: We currently just fallback to address space 0 if getTgtMemIntrinsic
-    //       didn't yield anything useful.
+    // TODO: We currently just fallback to address space 0 if
+    // getTgtMemIntrinsic didn't yield anything useful.
     MachinePointerInfo MPI;
-    if (TgtMemIntrinsicInfo->ptrVal) {
-      MPI = MachinePointerInfo(TgtMemIntrinsicInfo->ptrVal,
-                               TgtMemIntrinsicInfo->offset);
-    } else if (TgtMemIntrinsicInfo->fallbackAddressSpace) {
-      MPI = MachinePointerInfo(*TgtMemIntrinsicInfo->fallbackAddressSpace);
+    if (Info.ptrVal) {
+      MPI = MachinePointerInfo(Info.ptrVal, Info.offset);
+    } else if (Info.fallbackAddressSpace) {
+      MPI = MachinePointerInfo(*Info.fallbackAddressSpace);
     }
     MIB.addMemOperand(MF->getMachineMemOperand(
-        MPI, TgtMemIntrinsicInfo->flags, MemTy, Alignment, CB.getAAMetadata(),
-        /*Ranges=*/nullptr, TgtMemIntrinsicInfo->ssid,
-        TgtMemIntrinsicInfo->order, TgtMemIntrinsicInfo->failureOrder));
+        MPI, Info.flags, MemTy, Alignment, CB.getAAMetadata(),
+        /*Ranges=*/nullptr, Info.ssid, Info.order, Info.failureOrder));
   }
 
   if (CB.isConvergent()) {
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index c8a4dc6e67908..6c3497e30f005 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -974,9 +974,11 @@ static void AddNodeIDCustom(FoldingSetNodeID &ID, const SDNode *N) {
   // to check.
   if (auto *MN = dyn_cast<MemIntrinsicSDNode>(N)) {
     ID.AddInteger(MN->getRawSubclassData());
-    ID.AddInteger(MN->getPointerInfo().getAddrSpace());
-    ID.AddInteger(MN->getMemOperand()->getFlags());
     ID.AddInteger(MN->getMemoryVT().getRawBits());
+    for (const MachineMemOperand *MMO : MN->memoperands()) {
+      ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
+      ID.AddInteger(MMO->getFlags());
+    }
   }
 }
 
@@ -1276,7 +1278,7 @@ SelectionDAG::AddModifiedNodeToCSEMaps(SDNode *N) {
       // recursive merging of other unrelated nodes down the line.
       Existing->intersectFlagsWith(N->getFlags());
       if (auto *MemNode = dyn_cast<MemSDNode>(Existing))
-        MemNode->refineRanges(cast<MemSDNode>(N)->getMemOperand());
+        MemNode->refineRanges(cast<MemSDNode>(N)->memoperands());
       ReplaceAllUsesWith(N, Existing);
 
       // N is now dead. Inform the listeners and delete it.
@@ -9677,6 +9679,14 @@ SDValue SelectionDAG::getMemIntrinsicNode(unsigned Opcode, const SDLoc &dl,
                                           SDVTList VTList,
                                           ArrayRef<SDValue> Ops, EVT MemVT,
                                           MachineMemOperand *MMO) {
+  return getMemIntrinsicNode(Opcode, dl, VTList, Ops, MemVT, ArrayRef(MMO));
+}
+
+SDValue SelectionDAG::getMemIntrinsicNode(unsigned Opcode, const SDLoc &dl,
+                                          SDVTList VTList,
+                                          ArrayRef<SDValue> Ops, EVT MemVT,
+                                          ArrayRef<MachineMemOperand *> MMOs) {
+  assert(!MMOs.empty() && "Must have at least one MMO");
   assert(
       (Opcode == ISD::INTRINSIC_VOID || Opcode == ISD::INTRINSIC_W_CHAIN ||
        Opcode == ISD::PREFETCH ||
@@ -9684,30 +9694,47 @@ SDValue SelectionDAG::getMemIntrinsicNode(unsigned Opcode, const SDLoc &dl,
         Opcode >= ISD::BUILTIN_OP_END && TSI->isTargetMemoryOpcode(Opcode))) &&
       "Opcode is not a memory-accessing opcode!");
 
+  PointerUnion<MachineMemOperand *, MachineMemOperand **> MemRefs;
+  if (MMOs.size() == 1) {
+    MemRefs = MMOs[0];
+  } else {
+    // Allocate: [size_t count][MMO*][MMO*]...
+    size_t AllocSize =
+        sizeof(size_t) + MMOs.size() * sizeof(MachineMemOperand *);
+    void *Buffer = Allocator.Allocate(AllocSize, alignof(size_t));
+    size_t *CountPtr = static_cast<size_t *>(Buffer);
+    *CountPtr = MMOs.size();
+    MachineMemOperand **Array =
+        reinterpret_cast<MachineMemOperand **>(CountPtr + 1);
+    llvm::copy(MMOs, Array);
+    MemRefs = Array;
+  }
+
   // Memoize the node unless it returns a glue result.
   MemIntrinsicSDNode *N;
   if (VTList.VTs[VTList.NumVTs-1] != MVT::Glue) {
     FoldingSetNodeID ID;
     AddNodeIDNode(ID, Opcode, VTList, Ops);
     ID.AddInteger(getSyntheticNodeSubclassData<MemIntrinsicSDNode>(
-        Opcode, dl.getIROrder(), VTList, MemVT, MMO));
-    ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
-    ID.AddInteger(MMO->getFlags());
+        Opcode, dl.getIROrder(), VTList, MemVT, MemRefs));
     ID.AddInteger(MemVT.getRawBits());
+    for (const MachineMemOperand *MMO : MMOs) {
+      ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
+      ID.AddInteger(MMO->getFlags());
+    }
     void *IP = nullptr;
     if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
-      cast<MemIntrinsicSDNode>(E)->refineAlignment(MMO);
+      cast<MemIntrinsicSDNode>(E)->refineAlignment(MMOs);
       return SDValue(E, 0);
     }
 
     N = newSDNode<MemIntrinsicSDNode>(Opcode, dl.getIROrder(), dl.getDebugLoc(),
-                                      VTList, MemVT, MMO);
+                                      VTList, MemVT, MemRefs);
     createOperands(N, Ops);
-
-  CSEMap.InsertNode(N, IP);
+    CSEMap.InsertNode(N, IP);
   } else {
     N = newSDNode<MemIntrinsicSDNode>(Opcode, dl.getIROrder(), dl.getDebugLoc(),
-                                      VTList, MemVT, MMO);
+                                      VTList, MemVT, MemRefs);
     createOperands(N, Ops);
   }
   InsertNode(N);
@@ -13131,21 +13158,33 @@ HandleSDNode::~HandleSDNode() {
   DropOperands();
 }
 
-MemSDNode::MemSDNode(unsigned Opc, unsigned Order, const DebugLoc &dl,
-                     SDVTList VTs, EVT memvt, MachineMemOperand *mmo)
-    : SDNode(Opc, Order, dl, VTs), MemoryVT(memvt), MMO(mmo) {
-  MemSDNodeBits.IsVolatile = MMO->isVolatile();
-  MemSDNodeBits.IsNonTemporal = MMO->isNonTemporal();
-  MemSDNodeBits.IsDereferenceable = MMO->isDereferenceable();
-  MemSDNodeBits.IsInvariant = MMO->isInvariant();
-
-  // We check here that the size of the memory operand fits within the size of
-  // the MMO. This is because the MMO might indicate only a possible address
-  // range instead of specifying the affected memory addresses precisely.
-  assert(
-      (!MMO->getType().isValid() ||
-       TypeSize::isKnownLE(memvt.getStoreSize(), MMO->getSize().getValue())) &&
-      "Size mismatch!");
+MemSDNode::MemSDNode(
+    unsigned Opc, unsigned Order, const DebugLoc &dl, SDVTList VTs, EVT memvt,
+    PointerUnion<MachineMemOperand *, MachineMemOperand **> memrefs)
+    : SDNode(Opc, Order, dl, VTs), MemoryVT(memvt), MemRefs(memrefs) {
+  bool IsVolatile = false;
+  bool IsNonTemporal = false;
+  bool IsDereferenceable = true;
+  bool IsInvariant = true;
+  for (const MachineMemOperand *MMO : memoperands()) {
+    IsVolatile |= MMO->isVolatile();
+    IsNonTemporal |= MMO->isNonTemporal();
+    IsDereferenceable &= MMO->isDereferenceable();
+    IsInvariant &= MMO->isInvariant();
+  }
+  MemSDNodeBits.IsVolatile = IsVolatile;
+  MemSDNodeBits.IsNonTemporal = IsNonTemporal;
+  MemSDNodeBits.IsDereferenceable = IsDereferenceable;
+  MemSDNodeBits.IsInvariant = IsInvariant;
+
+  // For the single-MMO case, we check here that the size of the memory operand
+  // fits within the size of the MMO. This is because the MMO might indicate
+  // only a possible address range instead of specifying the affected memory
+  // addresses precisely.
+  assert((getNumMemOperands() != 1 || !getMemOperand()->getType().isValid() ||
+          TypeSize::isKnownLE(memvt.getStoreSize(),
+                              getMemOperand()->getSize().getValue())) &&
+         "Size mismatch!");
 }
 
 /// Profile - Gather unique data for the node.
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 0739e8e73dfc2..44d6a5c20e0e8 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -3512,10 +3512,12 @@ void SelectionDAGBuilder::visitInvoke(const InvokeInst &I) {
 /// - they do not need custom argument handling (no
 /// TLI.CollectTargetIntrinsicOperands())
 void SelectionDAGBuilder::visitCallBrIntrinsic(const CallBrInst &I) {
-  TargetLowering::IntrinsicInfo Info;
-  assert(!DAG.getTargetLoweringInfo().getTgtMemIntrinsic(
-             Info, I, DAG.getMachineFunction(), I.getIntrinsicID()) &&
-         "Intrinsic touches memory");
+#ifndef NDEBUG
+  SmallVector<TargetLowering::IntrinsicInfo, 2> Infos;
+  DAG.getTargetLoweringInfo().getTgtMemIntrinsic(
+      Infos, I, DAG.getMachineFunction(), I.getIntrinsicID());
+  assert(Infos.empty() && "Intrinsic touches memory");
+#endif
 
   auto [HasChain, OnlyLoad] = getTargetIntrinsicCallProperties(I);
 
@@ -5483,14 +5485,15 @@ void SelectionDAGBuilder::visitTargetIntrinsic(const CallInst &I,
                                                unsigned Intrinsic) {
   auto [HasChain, OnlyLoad] = getTargetIntrinsicCallProperties(I);
 
-  // Info is set by getTgtMemIntrinsic
-  TargetLowering::IntrinsicInfo Info;
+  // Infos is set by getTgtMemIntrinsic.
+  SmallVector<TargetLowering::IntrinsicInfo> Infos;
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
-  bool IsTgtMemIntrinsic =
-      TLI.getTgtMemIntrinsic(Info, I, DAG.getMachineFunction(), Intrinsic);
+  TLI.getTgtMemIntrinsic(Infos, I, DAG.getMachineFunction(), Intrinsic);
+  // The first (primary) info determines the node opcode.
+  TargetLowering::IntrinsicInfo *Info = !Infos.empty() ? &Infos[0] : nullptr;
 
-  SmallVector<SDValue, 8> Ops = getTargetIntrinsicOperands(
-      I, HasChain, OnlyLoad, IsTgtMemIntrinsic ? &Info : nullptr);
+  SmallVector<SDValue, 8> Ops =
+      getTargetIntrinsicOperands(I, HasChain, OnlyLoad, Info);
   SDVTList VTs = getTargetIntrinsicVTList(I, HasChain);
 
   // Propagate fast-math-flags from IR to node(s).
@@ -5504,26 +5507,32 @@ void SelectionDAGBuilder::visitTargetIntrinsic(const CallInst &I,
 
   // In some cases, custom collection of operands from CallInst I may be needed.
   TLI.CollectTargetIntrinsicOperands(I, Ops, DAG);
-  if (IsTgtMemIntrinsic) {
+  if (!Infos.empty()) {
     // This is target intrinsic that touches memory
-    //
-    // TODO: We currently just fallback to address space 0 if getTgtMemIntrinsic
-    //       didn't yield anything useful.
-    MachinePointerInfo MPI;
-    if (Info.ptrVal)
-      MPI = MachinePointerInfo(Info.ptrVal, Info.offset);
-    else if (Info.fallbackAddressSpace)
-      MPI = MachinePointerInfo(*Info.fallbackAddressSpace);
-    EVT MemVT = Info.memVT;
-    LocationSize Size = LocationSize::precise(Info.size);
-    if (Size.hasValue() && !Size.getValue())
-      Size = LocationSize::precise(MemVT.getStoreSize());
-    Align Alignment = Info.align.value_or(DAG.getEVTAlign(MemVT));
-    MachineMemOperand *MMO = DAG.getMachineFunction().getMachineMemOperand(
-        MPI, Info.flags, Size, Alignment, I.getAAMetadata(), /*Ranges=*/nullptr,
-        Info.ssid, Info.order, Info.failureOrder);
-    Result =
-        DAG.getMemIntrinsicNode(Info.opc, getCurSDLoc(), VTs, Ops, MemVT, MMO);
+    // Create MachineMemOperands for each memory access described by the target.
+    MachineFunction &MF = DAG.getMachineFunction();
+    SmallVector<MachineMemOperand *> MMOs;
+    for (const auto &Info : Infos) {
+      // TODO: We currently just fallback to address space 0 if
+      // getTgtMemIntrinsic didn't yield anything useful.
+      MachinePointerInfo MPI;
+      if (Info.ptrVal)
+        MPI = MachinePointerInfo(Info.ptrVal, Info.offset);
+      else if (Info.fallbackAddressSpace)
+        MPI = MachinePointerInfo(*Info.fallbackAddressSpace);
+      EVT MemVT = Info.memVT;
+      LocationSize Size = LocationSize::precise(Info.size);
+      if (Size.hasValue() && !Size.getValue())
+        Size = LocationSize::precise(MemVT.getStoreSize());
+      Align Alignment = Info.align.value_or(DAG.getEVTAlign(MemVT));
+      MachineMemOperand *MMO = MF.getMachineMemOperand(
+          MPI, Info.flags, Size, Alignment, I.getAAMetadata(),
+          /*Ranges=*/nullptr, Info.ssid, Info.order, Info.failureOrder);
+      MMOs.push_back(MMO);
+    }
+
+    Result = DAG.getMemIntrinsicNode(Info->opc, getCurSDLoc(), VTs, Ops,
+                                     Info->memVT, MMOs);
   } else {
     Result = getTargetNonMemIntrinsicNode(*I.getType(), HasChain, Ops, VTs);
   }
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
index 965e4f61659db..62573a43e9ab0 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -929,7 +929,13 @@ void SDNode::print_details(raw_ostream &OS, const SelectionDAG *G) const {
     OS << ">";
   } else if (const MemSDNode *M = dyn_cast<MemSDNode>(this)) {
     OS << "<";
-    printMemOperand(OS, *M->getMemOperand(), G);
+    bool First = true;
+    for (const MachineMemOperand *MMO : M->memoperands()) {
+      if (!First)
+        OS << ", ";
+      First = false;
+      printMemOperand(OS, *MMO, G);
+    }
     if (auto *A = dyn_cast<AtomicSDNode>(M))
       if (A->getOpcode() == ISD::ATOMIC_LOAD) {
         bool doExt = true;
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
index e092061fb5e04..f2bb098c55a7a 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
@@ -3544,7 +3544,8 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch,
     }
     case OPC_RecordMemRef:
       if (auto *MN = dyn_cast<MemSDNode>(N))
-        MatchedMemRefs.push_back(MN->getMemOperand());
+        MatchedMemRefs.append(MN->memoperands().begin(),
+                              MN->memoperands().end());
       else {
         LLVM_DEBUG(dbgs() << "Expected MemSDNode "; N->dump(CurDAG);
                    dbgs() << '\n');


