[Mlir-commits] [mlir] [MLIR][SROA][Mem2Reg] Add data layout to interface methods (PR #85644)

Christian Ulmann llvmlistbot at llvm.org
Mon Mar 18 07:09:37 PDT 2024


https://github.com/Dinistro created https://github.com/llvm/llvm-project/pull/85644

This commit extends the Mem2Reg and SROA interface methods with passed-in handles to a `DataLayout` structure. This is done to avoid superfluously retrieving data layouts during each conversion of intrinsics.

Additionally, this change enables follow-up changes that make the LLVM dialect implementation of these interfaces type-agnostic.

>From 76415d98f72c6cbc6b2cf221caad16ee52150cd0 Mon Sep 17 00:00:00 2001
From: Christian Ulmann <christian.ulmann at nextsilicon.com>
Date: Mon, 18 Mar 2024 13:59:06 +0000
Subject: [PATCH] [MLIR][SROA][Mem2Reg] Add data layout to interface methods

This commit extends the Mem2Reg and SROA interface methods with
passed-in handles to a `DataLayout` structure. This is done to avoid
superfluously retrieving data layouts during each conversion of
intrinsics.

Additionally, this change enables follow-up changes that make the
LLVM dialect implementation of these interfaces type-agnostic.
---
 .../mlir/Interfaces/MemorySlotInterfaces.td   |  42 +++--
 mlir/include/mlir/Transforms/Mem2Reg.h        |   2 +-
 mlir/include/mlir/Transforms/SROA.h           |   3 +-
 mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp | 153 +++++++++++-------
 .../Dialect/MemRef/IR/MemRefMemorySlot.cpp    |  18 ++-
 mlir/lib/Transforms/Mem2Reg.cpp               |  24 ++-
 mlir/lib/Transforms/SROA.cpp                  |  30 ++--
 7 files changed, 163 insertions(+), 109 deletions(-)

diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
index 9ffa709cc5bfd4..adf324bf559791 100644
--- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
@@ -83,11 +83,7 @@ def PromotableAllocationOpInterface
 def PromotableMemOpInterface : OpInterface<"PromotableMemOpInterface"> {
   let description = [{
     Describes an operation that can load from memory slots and/or store
-    to memory slots. Loads and stores must be of whole values of the same
-    type as the slot itself.
-
-    For a memory operation on a slot to be valid, it must operate on the slot
-    pointer *only as a pointer to an element of the type of the slot*.
+    to memory slots.
 
     If the same operation does both loads and stores on the same slot, the
     load must semantically happen first.
@@ -142,7 +138,8 @@ def PromotableMemOpInterface : OpInterface<"PromotableMemOpInterface"> {
       }], "bool", "canUsesBeRemoved",
       (ins "const ::mlir::MemorySlot &":$slot,
            "const ::llvm::SmallPtrSetImpl<::mlir::OpOperand *> &":$blockingUses,
-           "::llvm::SmallVectorImpl<::mlir::OpOperand *> &":$newBlockingUses)
+           "::llvm::SmallVectorImpl<::mlir::OpOperand *> &":$newBlockingUses,
+           "const ::mlir::DataLayout &":$dataLayout)
     >,
     InterfaceMethod<[{
         Transforms IR to ensure that the current operation does not use the
@@ -197,7 +194,8 @@ def PromotableOpInterface : OpInterface<"PromotableOpInterface"> {
         No IR mutation is allowed in this method.
       }], "bool", "canUsesBeRemoved",
       (ins "const ::llvm::SmallPtrSetImpl<::mlir::OpOperand *> &":$blockingUses,
-           "::llvm::SmallVectorImpl<::mlir::OpOperand *> &":$newBlockingUses)
+           "::llvm::SmallVectorImpl<::mlir::OpOperand *> &":$newBlockingUses,
+           "const ::mlir::DataLayout &":$dataLayout)
     >,
     InterfaceMethod<[{
         Transforms IR to ensure that the current operation does not use the
@@ -285,29 +283,28 @@ def DestructurableAllocationOpInterface
 def SafeMemorySlotAccessOpInterface
   : OpInterface<"SafeMemorySlotAccessOpInterface"> {
   let description = [{
-    Describes operations using memory slots in a type-safe manner.
+    Describes operations using memory slots in a safe manner.
   }];
   let cppNamespace = "::mlir";
 
   let methods = [
     InterfaceMethod<[{
         Returns whether all accesses in this operation to the provided slot are
-        done in a type-safe manner. To be type-safe, the access must only load
-        the value in this type as the type of the slot, and without assuming any
-        context around the slot. For example, a type-safe load must not load
-        outside the bounds of the slot.
+        done in a safe manner. To be safe, the access must only access the slot
+        inside the bounds that its type implies.
 
-        If the type-safety of the accesses depends on the type-safety of the
-        accesses to further memory slots, the result of this method will be
-        conditioned to the type-safety of the accesses to the slots added by
-        this method to `mustBeSafelyUsed`.
+        If the safety of the accesses depends on the safety of the accesses to
+        further memory slots, the result of this method will be conditioned to
+        the safety of the accesses to the slots added by this method to
+        `mustBeSafelyUsed`.
 
         No IR mutation is allowed in this method.
       }],
       "::mlir::LogicalResult",
       "ensureOnlySafeAccesses",
       (ins "const ::mlir::MemorySlot &":$slot,
-           "::mlir::SmallVectorImpl<::mlir::MemorySlot> &":$mustBeSafelyUsed)
+           "::mlir::SmallVectorImpl<::mlir::MemorySlot> &":$mustBeSafelyUsed,
+           "const ::mlir::DataLayout &":$dataLayout)
     >
   ];
 }
@@ -323,13 +320,12 @@ def DestructurableAccessorOpInterface
     InterfaceMethod<[{
         For a given destructurable memory slot, returns whether this operation can
         rewire its uses of the slot to use the slots generated after
-        destructuring. This may involve creating new operations, and usually
-        amounts to checking if the pointer types match.
+        destructuring. This may involve creating new operations.
 
         This method must also register the indices it will access within the
         `usedIndices` set. If the accessor generates new slots mapping to
         subelements, they must be registered in `mustBeSafelyUsed` to ensure
-        they are used in a locally type-safe manner.
+        they are used in a safe manner.
 
         No IR mutation is allowed in this method.
       }],
@@ -337,7 +333,8 @@ def DestructurableAccessorOpInterface
       "canRewire",
       (ins "const ::mlir::DestructurableMemorySlot &":$slot,
            "::llvm::SmallPtrSetImpl<::mlir::Attribute> &":$usedIndices,
-           "::mlir::SmallVectorImpl<::mlir::MemorySlot> &":$mustBeSafelyUsed)
+           "::mlir::SmallVectorImpl<::mlir::MemorySlot> &":$mustBeSafelyUsed,
+           "const ::mlir::DataLayout &":$dataLayout)
     >,
     InterfaceMethod<[{
         Rewires the use of a slot to the generated subslots, without deleting
@@ -351,7 +348,8 @@ def DestructurableAccessorOpInterface
       "rewire",
       (ins "const ::mlir::DestructurableMemorySlot &":$slot,
            "::llvm::DenseMap<::mlir::Attribute, ::mlir::MemorySlot> &":$subslots,
-           "::mlir::RewriterBase &":$rewriter)
+           "::mlir::RewriterBase &":$rewriter,
+           "const ::mlir::DataLayout &":$dataLayout)
     >
   ];
 }
diff --git a/mlir/include/mlir/Transforms/Mem2Reg.h b/mlir/include/mlir/Transforms/Mem2Reg.h
index d145f7ed437582..ed10644e26a519 100644
--- a/mlir/include/mlir/Transforms/Mem2Reg.h
+++ b/mlir/include/mlir/Transforms/Mem2Reg.h
@@ -27,7 +27,7 @@ struct Mem2RegStatistics {
 /// at least one memory slot was promoted.
 LogicalResult
 tryToPromoteMemorySlots(ArrayRef<PromotableAllocationOpInterface> allocators,
-                        RewriterBase &rewriter,
+                        RewriterBase &rewriter, const DataLayout &dataLayout,
                         Mem2RegStatistics statistics = {});
 
 } // namespace mlir
diff --git a/mlir/include/mlir/Transforms/SROA.h b/mlir/include/mlir/Transforms/SROA.h
index 1af1fe930723f1..d09a7989edeab2 100644
--- a/mlir/include/mlir/Transforms/SROA.h
+++ b/mlir/include/mlir/Transforms/SROA.h
@@ -31,7 +31,8 @@ struct SROAStatistics {
 /// failure if no slot was destructured.
 LogicalResult tryToDestructureMemorySlots(
     ArrayRef<DestructurableAllocationOpInterface> allocators,
-    RewriterBase &rewriter, SROAStatistics statistics = {});
+    RewriterBase &rewriter, const DataLayout &dataLayout,
+    SROAStatistics statistics = {});
 
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index bf24194d03ddb2..a0eabc2250f7e5 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -132,7 +132,8 @@ Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter) {
 
 bool LLVM::LoadOp::canUsesBeRemoved(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    SmallVectorImpl<OpOperand *> &newBlockingUses) {
+    SmallVectorImpl<OpOperand *> &newBlockingUses,
+    const DataLayout &dataLayout) {
   if (blockingUses.size() != 1)
     return false;
   Value blockingUse = (*blockingUses.begin())->get();
@@ -155,7 +156,8 @@ DeletionKind LLVM::LoadOp::removeBlockingUses(
 
 bool LLVM::StoreOp::canUsesBeRemoved(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    SmallVectorImpl<OpOperand *> &newBlockingUses) {
+    SmallVectorImpl<OpOperand *> &newBlockingUses,
+    const DataLayout &dataLayout) {
   if (blockingUses.size() != 1)
     return false;
   Value blockingUse = (*blockingUses.begin())->get();
@@ -181,12 +183,14 @@ DeletionKind LLVM::StoreOp::removeBlockingUses(
 }
 
 LogicalResult LLVM::LoadOp::ensureOnlySafeAccesses(
-    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
+    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
+    const DataLayout &dataLayout) {
   return success(getAddr() != slot.ptr || getType() == slot.elemType);
 }
 
 LogicalResult LLVM::StoreOp::ensureOnlySafeAccesses(
-    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
+    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
+    const DataLayout &dataLayout) {
   return success(getAddr() != slot.ptr ||
                  getValue().getType() == slot.elemType);
 }
@@ -206,7 +210,8 @@ static bool forwardToUsers(Operation *op,
 
 bool LLVM::BitcastOp::canUsesBeRemoved(
     const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    SmallVectorImpl<OpOperand *> &newBlockingUses) {
+    SmallVectorImpl<OpOperand *> &newBlockingUses,
+    const DataLayout &dataLayout) {
   return forwardToUsers(*this, newBlockingUses);
 }
 
@@ -217,7 +222,8 @@ DeletionKind LLVM::BitcastOp::removeBlockingUses(
 
 bool LLVM::AddrSpaceCastOp::canUsesBeRemoved(
     const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    SmallVectorImpl<OpOperand *> &newBlockingUses) {
+    SmallVectorImpl<OpOperand *> &newBlockingUses,
+    const DataLayout &dataLayout) {
   return forwardToUsers(*this, newBlockingUses);
 }
 
@@ -228,7 +234,8 @@ DeletionKind LLVM::AddrSpaceCastOp::removeBlockingUses(
 
 bool LLVM::LifetimeStartOp::canUsesBeRemoved(
     const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    SmallVectorImpl<OpOperand *> &newBlockingUses) {
+    SmallVectorImpl<OpOperand *> &newBlockingUses,
+    const DataLayout &dataLayout) {
   return true;
 }
 
@@ -239,7 +246,8 @@ DeletionKind LLVM::LifetimeStartOp::removeBlockingUses(
 
 bool LLVM::LifetimeEndOp::canUsesBeRemoved(
     const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    SmallVectorImpl<OpOperand *> &newBlockingUses) {
+    SmallVectorImpl<OpOperand *> &newBlockingUses,
+    const DataLayout &dataLayout) {
   return true;
 }
 
@@ -250,7 +258,8 @@ DeletionKind LLVM::LifetimeEndOp::removeBlockingUses(
 
 bool LLVM::InvariantStartOp::canUsesBeRemoved(
     const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    SmallVectorImpl<OpOperand *> &newBlockingUses) {
+    SmallVectorImpl<OpOperand *> &newBlockingUses,
+    const DataLayout &dataLayout) {
   return true;
 }
 
@@ -261,7 +270,8 @@ DeletionKind LLVM::InvariantStartOp::removeBlockingUses(
 
 bool LLVM::InvariantEndOp::canUsesBeRemoved(
     const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    SmallVectorImpl<OpOperand *> &newBlockingUses) {
+    SmallVectorImpl<OpOperand *> &newBlockingUses,
+    const DataLayout &dataLayout) {
   return true;
 }
 
@@ -272,7 +282,8 @@ DeletionKind LLVM::InvariantEndOp::removeBlockingUses(
 
 bool LLVM::DbgDeclareOp::canUsesBeRemoved(
     const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    SmallVectorImpl<OpOperand *> &newBlockingUses) {
+    SmallVectorImpl<OpOperand *> &newBlockingUses,
+    const DataLayout &dataLayout) {
   return true;
 }
 
@@ -283,7 +294,8 @@ DeletionKind LLVM::DbgDeclareOp::removeBlockingUses(
 
 bool LLVM::DbgValueOp::canUsesBeRemoved(
     const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    SmallVectorImpl<OpOperand *> &newBlockingUses) {
+    SmallVectorImpl<OpOperand *> &newBlockingUses,
+    const DataLayout &dataLayout) {
   // There is only one operand that we can remove the use of.
   if (blockingUses.size() != 1)
     return false;
@@ -318,7 +330,8 @@ static bool hasAllZeroIndices(LLVM::GEPOp gepOp) {
 
 bool LLVM::GEPOp::canUsesBeRemoved(
     const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    SmallVectorImpl<OpOperand *> &newBlockingUses) {
+    SmallVectorImpl<OpOperand *> &newBlockingUses,
+    const DataLayout &dataLayout) {
   // GEP can be removed as long as it is a no-op and its users can be removed.
   if (!hasAllZeroIndices(*this))
     return false;
@@ -337,7 +350,8 @@ static bool isFirstIndexZero(LLVM::GEPOp gep) {
 }
 
 LogicalResult LLVM::GEPOp::ensureOnlySafeAccesses(
-    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
+    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
+    const DataLayout &dataLayout) {
   if (getBase() != slot.ptr)
     return success();
   if (slot.elemType != getElemType())
@@ -353,9 +367,9 @@ LogicalResult LLVM::GEPOp::ensureOnlySafeAccesses(
 
 bool LLVM::GEPOp::canRewire(const DestructurableMemorySlot &slot,
                             SmallPtrSetImpl<Attribute> &usedIndices,
-                            SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
-  auto basePtrType = llvm::dyn_cast<LLVM::LLVMPointerType>(getBase().getType());
-  if (!basePtrType)
+                            SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
+                            const DataLayout &dataLayout) {
+  if (!isa<LLVM::LLVMPointerType>(getBase().getType()))
     return false;
 
   if (getBase() != slot.ptr || slot.elemType != getElemType())
@@ -378,7 +392,8 @@ bool LLVM::GEPOp::canRewire(const DestructurableMemorySlot &slot,
 
 DeletionKind LLVM::GEPOp::rewire(const DestructurableMemorySlot &slot,
                                  DenseMap<Attribute, MemorySlot> &subslots,
-                                 RewriterBase &rewriter) {
+                                 RewriterBase &rewriter,
+                                 const DataLayout &dataLayout) {
   IntegerAttr firstLevelIndex =
       llvm::dyn_cast_if_present<IntegerAttr>(getIndices()[1]);
   const MemorySlot &newSlot = subslots.at(firstLevelIndex);
@@ -447,7 +462,7 @@ std::optional<uint64_t> getStaticMemIntrLen(LLVM::MemcpyInlineOp op) {
 /// of the bounds of the given slot, on a best-effort basis.
 template <class MemIntr>
 static bool definitelyWritesOnlyWithinSlot(MemIntr op, const MemorySlot &slot,
-                                           DataLayout &dataLayout) {
+                                           const DataLayout &dataLayout) {
   if (!isa<LLVM::LLVMPointerType>(slot.ptr.getType()) ||
       op.getDst() != slot.ptr)
     return false;
@@ -512,7 +527,8 @@ Value LLVM::MemsetOp::getStored(const MemorySlot &slot,
 
 bool LLVM::MemsetOp::canUsesBeRemoved(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    SmallVectorImpl<OpOperand *> &newBlockingUses) {
+    SmallVectorImpl<OpOperand *> &newBlockingUses,
+    const DataLayout &dataLayout) {
   // TODO: Support non-integer types.
   bool canConvertType =
       TypeSwitch<Type, bool>(slot.elemType)
@@ -526,8 +542,7 @@ bool LLVM::MemsetOp::canUsesBeRemoved(
   if (getIsVolatile())
     return false;
 
-  DataLayout layout = DataLayout::closest(*this);
-  return getStaticMemIntrLen(*this) == layout.getTypeSize(slot.elemType);
+  return getStaticMemIntrLen(*this) == dataLayout.getTypeSize(slot.elemType);
 }
 
 DeletionKind LLVM::MemsetOp::removeBlockingUses(
@@ -537,14 +552,15 @@ DeletionKind LLVM::MemsetOp::removeBlockingUses(
 }
 
 LogicalResult LLVM::MemsetOp::ensureOnlySafeAccesses(
-    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
-  DataLayout dataLayout = DataLayout::closest(*this);
+    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
+    const DataLayout &dataLayout) {
   return success(definitelyWritesOnlyWithinSlot(*this, slot, dataLayout));
 }
 
 bool LLVM::MemsetOp::canRewire(const DestructurableMemorySlot &slot,
                                SmallPtrSetImpl<Attribute> &usedIndices,
-                               SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
+                               SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
+                               const DataLayout &dataLayout) {
   if (&slot.elemType.getDialect() != getOperation()->getDialect())
     return false;
 
@@ -558,13 +574,13 @@ bool LLVM::MemsetOp::canRewire(const DestructurableMemorySlot &slot,
   if (!areAllIndicesI32(slot))
     return false;
 
-  DataLayout dataLayout = DataLayout::closest(*this);
   return definitelyWritesOnlyWithinSlot(*this, slot, dataLayout);
 }
 
 DeletionKind LLVM::MemsetOp::rewire(const DestructurableMemorySlot &slot,
                                     DenseMap<Attribute, MemorySlot> &subslots,
-                                    RewriterBase &rewriter) {
+                                    RewriterBase &rewriter,
+                                    const DataLayout &dataLayout) {
   std::optional<DenseMap<Attribute, Type>> types =
       slot.elemType.cast<DestructurableTypeInterface>().getSubelementIndexMap();
 
@@ -579,7 +595,6 @@ DeletionKind LLVM::MemsetOp::rewire(const DestructurableMemorySlot &slot,
     packed = structType.isPacked();
 
   Type i32 = IntegerType::get(getContext(), 32);
-  DataLayout dataLayout = DataLayout::closest(*this);
   uint64_t memsetLen = memsetLenAttr.getValue().getZExtValue();
   uint64_t covered = 0;
   for (size_t i = 0; i < types->size(); i++) {
@@ -642,7 +657,8 @@ template <class MemcpyLike>
 static bool
 memcpyCanUsesBeRemoved(MemcpyLike op, const MemorySlot &slot,
                        const SmallPtrSetImpl<OpOperand *> &blockingUses,
-                       SmallVectorImpl<OpOperand *> &newBlockingUses) {
+                       SmallVectorImpl<OpOperand *> &newBlockingUses,
+                       const DataLayout &dataLayout) {
   // If source and destination are the same, memcpy behavior is undefined and
   // memmove is a no-op. Because there is no memory change happening here,
   // simplifying such operations is left to canonicalization.
@@ -652,8 +668,7 @@ memcpyCanUsesBeRemoved(MemcpyLike op, const MemorySlot &slot,
   if (op.getIsVolatile())
     return false;
 
-  DataLayout layout = DataLayout::closest(op);
-  return getStaticMemIntrLen(op) == layout.getTypeSize(slot.elemType);
+  return getStaticMemIntrLen(op) == dataLayout.getTypeSize(slot.elemType);
 }
 
 template <class MemcpyLike>
@@ -681,7 +696,8 @@ memcpyEnsureOnlySafeAccesses(MemcpyLike op, const MemorySlot &slot,
 template <class MemcpyLike>
 static bool memcpyCanRewire(MemcpyLike op, const DestructurableMemorySlot &slot,
                             SmallPtrSetImpl<Attribute> &usedIndices,
-                            SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
+                            SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
+                            const DataLayout &dataLayout) {
   if (op.getIsVolatile())
     return false;
 
@@ -693,7 +709,6 @@ static bool memcpyCanRewire(MemcpyLike op, const DestructurableMemorySlot &slot,
     return false;
 
   // Only full copies are supported.
-  DataLayout dataLayout = DataLayout::closest(op);
   if (getStaticMemIntrLen(op) != dataLayout.getTypeSize(slot.elemType))
     return false;
 
@@ -733,15 +748,13 @@ void createMemcpyLikeToReplace(RewriterBase &rewriter, const DataLayout &layout,
 /// Rewires a memcpy-like operation. Only copies to or from the full slot are
 /// supported.
 template <class MemcpyLike>
-static DeletionKind memcpyRewire(MemcpyLike op,
-                                 const DestructurableMemorySlot &slot,
-                                 DenseMap<Attribute, MemorySlot> &subslots,
-                                 RewriterBase &rewriter) {
+static DeletionKind
+memcpyRewire(MemcpyLike op, const DestructurableMemorySlot &slot,
+             DenseMap<Attribute, MemorySlot> &subslots, RewriterBase &rewriter,
+             const DataLayout &dataLayout) {
   if (subslots.empty())
     return DeletionKind::Delete;
 
-  DataLayout layout = DataLayout::closest(op);
-
   assert((slot.ptr == op.getDst()) != (slot.ptr == op.getSrc()));
   bool isDst = slot.ptr == op.getDst();
 
@@ -772,7 +785,7 @@ static DeletionKind memcpyRewire(MemcpyLike op,
         isDst ? op.getSrc() : op.getDst(), gepIndices);
 
     // Then create a new memcpy out of this source pointer.
-    createMemcpyLikeToReplace(rewriter, layout, op,
+    createMemcpyLikeToReplace(rewriter, dataLayout, op,
                               isDst ? subslot.ptr : subslotPtrInOther,
                               isDst ? subslotPtrInOther : subslot.ptr,
                               subslot.elemType, op.getIsVolatile());
@@ -798,8 +811,10 @@ Value LLVM::MemcpyOp::getStored(const MemorySlot &slot,
 
 bool LLVM::MemcpyOp::canUsesBeRemoved(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    SmallVectorImpl<OpOperand *> &newBlockingUses) {
-  return memcpyCanUsesBeRemoved(*this, slot, blockingUses, newBlockingUses);
+    SmallVectorImpl<OpOperand *> &newBlockingUses,
+    const DataLayout &dataLayout) {
+  return memcpyCanUsesBeRemoved(*this, slot, blockingUses, newBlockingUses,
+                                dataLayout);
 }
 
 DeletionKind LLVM::MemcpyOp::removeBlockingUses(
@@ -810,20 +825,24 @@ DeletionKind LLVM::MemcpyOp::removeBlockingUses(
 }
 
 LogicalResult LLVM::MemcpyOp::ensureOnlySafeAccesses(
-    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
+    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
+    const DataLayout &dataLayout) {
   return memcpyEnsureOnlySafeAccesses(*this, slot, mustBeSafelyUsed);
 }
 
 bool LLVM::MemcpyOp::canRewire(const DestructurableMemorySlot &slot,
                                SmallPtrSetImpl<Attribute> &usedIndices,
-                               SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
-  return memcpyCanRewire(*this, slot, usedIndices, mustBeSafelyUsed);
+                               SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
+                               const DataLayout &dataLayout) {
+  return memcpyCanRewire(*this, slot, usedIndices, mustBeSafelyUsed,
+                         dataLayout);
 }
 
 DeletionKind LLVM::MemcpyOp::rewire(const DestructurableMemorySlot &slot,
                                     DenseMap<Attribute, MemorySlot> &subslots,
-                                    RewriterBase &rewriter) {
-  return memcpyRewire(*this, slot, subslots, rewriter);
+                                    RewriterBase &rewriter,
+                                    const DataLayout &dataLayout) {
+  return memcpyRewire(*this, slot, subslots, rewriter, dataLayout);
 }
 
 bool LLVM::MemcpyInlineOp::loadsFrom(const MemorySlot &slot) {
@@ -841,8 +860,10 @@ Value LLVM::MemcpyInlineOp::getStored(const MemorySlot &slot,
 
 bool LLVM::MemcpyInlineOp::canUsesBeRemoved(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    SmallVectorImpl<OpOperand *> &newBlockingUses) {
-  return memcpyCanUsesBeRemoved(*this, slot, blockingUses, newBlockingUses);
+    SmallVectorImpl<OpOperand *> &newBlockingUses,
+    const DataLayout &dataLayout) {
+  return memcpyCanUsesBeRemoved(*this, slot, blockingUses, newBlockingUses,
+                                dataLayout);
 }
 
 DeletionKind LLVM::MemcpyInlineOp::removeBlockingUses(
@@ -853,22 +874,26 @@ DeletionKind LLVM::MemcpyInlineOp::removeBlockingUses(
 }
 
 LogicalResult LLVM::MemcpyInlineOp::ensureOnlySafeAccesses(
-    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
+    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
+    const DataLayout &dataLayout) {
   return memcpyEnsureOnlySafeAccesses(*this, slot, mustBeSafelyUsed);
 }
 
 bool LLVM::MemcpyInlineOp::canRewire(
     const DestructurableMemorySlot &slot,
     SmallPtrSetImpl<Attribute> &usedIndices,
-    SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
-  return memcpyCanRewire(*this, slot, usedIndices, mustBeSafelyUsed);
+    SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
+    const DataLayout &dataLayout) {
+  return memcpyCanRewire(*this, slot, usedIndices, mustBeSafelyUsed,
+                         dataLayout);
 }
 
 DeletionKind
 LLVM::MemcpyInlineOp::rewire(const DestructurableMemorySlot &slot,
                              DenseMap<Attribute, MemorySlot> &subslots,
-                             RewriterBase &rewriter) {
-  return memcpyRewire(*this, slot, subslots, rewriter);
+                             RewriterBase &rewriter,
+                             const DataLayout &dataLayout) {
+  return memcpyRewire(*this, slot, subslots, rewriter, dataLayout);
 }
 
 bool LLVM::MemmoveOp::loadsFrom(const MemorySlot &slot) {
@@ -886,8 +911,10 @@ Value LLVM::MemmoveOp::getStored(const MemorySlot &slot,
 
 bool LLVM::MemmoveOp::canUsesBeRemoved(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    SmallVectorImpl<OpOperand *> &newBlockingUses) {
-  return memcpyCanUsesBeRemoved(*this, slot, blockingUses, newBlockingUses);
+    SmallVectorImpl<OpOperand *> &newBlockingUses,
+    const DataLayout &dataLayout) {
+  return memcpyCanUsesBeRemoved(*this, slot, blockingUses, newBlockingUses,
+                                dataLayout);
 }
 
 DeletionKind LLVM::MemmoveOp::removeBlockingUses(
@@ -898,20 +925,24 @@ DeletionKind LLVM::MemmoveOp::removeBlockingUses(
 }
 
 LogicalResult LLVM::MemmoveOp::ensureOnlySafeAccesses(
-    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
+    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
+    const DataLayout &dataLayout) {
   return memcpyEnsureOnlySafeAccesses(*this, slot, mustBeSafelyUsed);
 }
 
 bool LLVM::MemmoveOp::canRewire(const DestructurableMemorySlot &slot,
                                 SmallPtrSetImpl<Attribute> &usedIndices,
-                                SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
-  return memcpyCanRewire(*this, slot, usedIndices, mustBeSafelyUsed);
+                                SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
+                                const DataLayout &dataLayout) {
+  return memcpyCanRewire(*this, slot, usedIndices, mustBeSafelyUsed,
+                         dataLayout);
 }
 
 DeletionKind LLVM::MemmoveOp::rewire(const DestructurableMemorySlot &slot,
                                      DenseMap<Attribute, MemorySlot> &subslots,
-                                     RewriterBase &rewriter) {
-  return memcpyRewire(*this, slot, subslots, rewriter);
+                                     RewriterBase &rewriter,
+                                     const DataLayout &dataLayout) {
+  return memcpyRewire(*this, slot, subslots, rewriter, dataLayout);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
index 561b8619032cce..7be4056fb2fc80 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
@@ -170,7 +170,8 @@ Value memref::LoadOp::getStored(const MemorySlot &slot,
 
 bool memref::LoadOp::canUsesBeRemoved(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    SmallVectorImpl<OpOperand *> &newBlockingUses) {
+    SmallVectorImpl<OpOperand *> &newBlockingUses,
+    const DataLayout &dataLayout) {
   if (blockingUses.size() != 1)
     return false;
   Value blockingUse = (*blockingUses.begin())->get();
@@ -210,7 +211,8 @@ static Attribute getAttributeIndexFromIndexOperands(MLIRContext *ctx,
 
 bool memref::LoadOp::canRewire(const DestructurableMemorySlot &slot,
                                SmallPtrSetImpl<Attribute> &usedIndices,
-                               SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
+                               SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
+                               const DataLayout &dataLayout) {
   if (slot.ptr != getMemRef())
     return false;
   Attribute index = getAttributeIndexFromIndexOperands(
@@ -223,7 +225,8 @@ bool memref::LoadOp::canRewire(const DestructurableMemorySlot &slot,
 
 DeletionKind memref::LoadOp::rewire(const DestructurableMemorySlot &slot,
                                     DenseMap<Attribute, MemorySlot> &subslots,
-                                    RewriterBase &rewriter) {
+                                    RewriterBase &rewriter,
+                                    const DataLayout &dataLayout) {
   Attribute index = getAttributeIndexFromIndexOperands(
       getContext(), getIndices(), getMemRefType());
   const MemorySlot &memorySlot = subslots.at(index);
@@ -247,7 +250,8 @@ Value memref::StoreOp::getStored(const MemorySlot &slot,
 
 bool memref::StoreOp::canUsesBeRemoved(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    SmallVectorImpl<OpOperand *> &newBlockingUses) {
+    SmallVectorImpl<OpOperand *> &newBlockingUses,
+    const DataLayout &dataLayout) {
   if (blockingUses.size() != 1)
     return false;
   Value blockingUse = (*blockingUses.begin())->get();
@@ -263,7 +267,8 @@ DeletionKind memref::StoreOp::removeBlockingUses(
 
 bool memref::StoreOp::canRewire(const DestructurableMemorySlot &slot,
                                 SmallPtrSetImpl<Attribute> &usedIndices,
-                                SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
+                                SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
+                                const DataLayout &dataLayout) {
   if (slot.ptr != getMemRef() || getValue() == slot.ptr)
     return false;
   Attribute index = getAttributeIndexFromIndexOperands(
@@ -276,7 +281,8 @@ bool memref::StoreOp::canRewire(const DestructurableMemorySlot &slot,
 
 DeletionKind memref::StoreOp::rewire(const DestructurableMemorySlot &slot,
                                      DenseMap<Attribute, MemorySlot> &subslots,
-                                     RewriterBase &rewriter) {
+                                     RewriterBase &rewriter,
+                                     const DataLayout &dataLayout) {
   Attribute index = getAttributeIndexFromIndexOperands(
       getContext(), getIndices(), getMemRefType());
   const MemorySlot &memorySlot = subslots.at(index);
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index 84ac69b4514b4f..80e3b790163297 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Transforms/Mem2Reg.h"
+#include "mlir/Analysis/DataLayoutAnalysis.h"
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Dominance.h"
@@ -117,8 +118,9 @@ struct MemorySlotPromotionInfo {
 /// promotion. This does not mutate IR.
 class MemorySlotPromotionAnalyzer {
 public:
-  MemorySlotPromotionAnalyzer(MemorySlot slot, DominanceInfo &dominance)
-      : slot(slot), dominance(dominance) {}
+  MemorySlotPromotionAnalyzer(MemorySlot slot, DominanceInfo &dominance,
+                              const DataLayout &dataLayout)
+      : slot(slot), dominance(dominance), dataLayout(dataLayout) {}
 
   /// Computes the information for slot promotion if promotion is possible,
   /// returns nothing otherwise.
@@ -153,6 +155,7 @@ class MemorySlotPromotionAnalyzer {
 
   MemorySlot slot;
   DominanceInfo &dominance;
+  const DataLayout &dataLayout;
 };
 
 /// The MemorySlotPromoter handles the state of promoting a memory slot. It
@@ -267,10 +270,12 @@ LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
     // If the operation decides it cannot deal with removing the blocking uses,
     // promotion must fail.
     if (auto promotable = dyn_cast<PromotableOpInterface>(user)) {
-      if (!promotable.canUsesBeRemoved(blockingUses, newBlockingUses))
+      if (!promotable.canUsesBeRemoved(blockingUses, newBlockingUses,
+                                       dataLayout))
         return failure();
     } else if (auto promotable = dyn_cast<PromotableMemOpInterface>(user)) {
-      if (!promotable.canUsesBeRemoved(slot, blockingUses, newBlockingUses))
+      if (!promotable.canUsesBeRemoved(slot, blockingUses, newBlockingUses,
+                                       dataLayout))
         return failure();
     } else {
       // An operation that has blocking uses must be promoted. If it is not
@@ -610,7 +615,8 @@ void MemorySlotPromoter::promoteSlot() {
 
 LogicalResult mlir::tryToPromoteMemorySlots(
     ArrayRef<PromotableAllocationOpInterface> allocators,
-    RewriterBase &rewriter, Mem2RegStatistics statistics) {
+    RewriterBase &rewriter, const DataLayout &dataLayout,
+    Mem2RegStatistics statistics) {
   bool promotedAny = false;
 
   for (PromotableAllocationOpInterface allocator : allocators) {
@@ -619,7 +625,7 @@ LogicalResult mlir::tryToPromoteMemorySlots(
         continue;
 
       DominanceInfo dominance;
-      MemorySlotPromotionAnalyzer analyzer(slot, dominance);
+      MemorySlotPromotionAnalyzer analyzer(slot, dominance, dataLayout);
       std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo();
       if (info) {
         MemorySlotPromoter(slot, allocator, rewriter, dominance,
@@ -661,8 +667,12 @@ struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
           allocators.emplace_back(allocator);
         });
 
+        auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
+        const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(scopeOp);
+
         // Attempt promoting until no promotion succeeds.
-        if (failed(tryToPromoteMemorySlots(allocators, rewriter, statistics)))
+        if (failed(tryToPromoteMemorySlots(allocators, rewriter, dataLayout,
+                                           statistics)))
           break;
 
         changed = true;
diff --git a/mlir/lib/Transforms/SROA.cpp b/mlir/lib/Transforms/SROA.cpp
index 6111489bdebefd..f24cbb7b1725cc 100644
--- a/mlir/lib/Transforms/SROA.cpp
+++ b/mlir/lib/Transforms/SROA.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Transforms/SROA.h"
+#include "mlir/Analysis/DataLayoutAnalysis.h"
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Interfaces/MemorySlotInterfaces.h"
 #include "mlir/Transforms/Passes.h"
@@ -42,7 +43,8 @@ struct MemorySlotDestructuringInfo {
 /// nothing if the slot cannot be destructured or if there is no useful work to
 /// be done.
 static std::optional<MemorySlotDestructuringInfo>
-computeDestructuringInfo(DestructurableMemorySlot &slot) {
+computeDestructuringInfo(DestructurableMemorySlot &slot,
+                         const DataLayout &dataLayout) {
   assert(isa<DestructurableTypeInterface>(slot.elemType));
 
   if (slot.ptr.use_empty())
@@ -62,7 +64,8 @@ computeDestructuringInfo(DestructurableMemorySlot &slot) {
   for (OpOperand &use : slot.ptr.getUses()) {
     if (auto accessor =
             dyn_cast<DestructurableAccessorOpInterface>(use.getOwner())) {
-      if (accessor.canRewire(slot, info.usedIndices, usedSafelyWorklist)) {
+      if (accessor.canRewire(slot, info.usedIndices, usedSafelyWorklist,
+                             dataLayout)) {
         info.accessors.push_back(accessor);
         continue;
       }
@@ -82,8 +85,8 @@ computeDestructuringInfo(DestructurableMemorySlot &slot) {
       Operation *subslotUser = subslotUse.getOwner();
 
       if (auto memOp = dyn_cast<SafeMemorySlotAccessOpInterface>(subslotUser))
-        if (succeeded(memOp.ensureOnlySafeAccesses(mustBeUsedSafely,
-                                                   usedSafelyWorklist)))
+        if (succeeded(memOp.ensureOnlySafeAccesses(
+                mustBeUsedSafely, usedSafelyWorklist, dataLayout)))
           continue;
 
       // If it cannot be shown that the operation uses the slot safely, maybe it
@@ -110,7 +113,7 @@ computeDestructuringInfo(DestructurableMemorySlot &slot) {
     SmallVector<OpOperand *> newBlockingUses;
     // If the operation decides it cannot deal with removing the blocking uses,
     // destructuring must fail.
-    if (!promotable.canUsesBeRemoved(blockingUses, newBlockingUses))
+    if (!promotable.canUsesBeRemoved(blockingUses, newBlockingUses, dataLayout))
       return {};
 
     // Then, register any new blocking uses for coming operations.
@@ -132,6 +135,7 @@ computeDestructuringInfo(DestructurableMemorySlot &slot) {
 static void destructureSlot(DestructurableMemorySlot &slot,
                             DestructurableAllocationOpInterface allocator,
                             RewriterBase &rewriter,
+                            const DataLayout &dataLayout,
                             MemorySlotDestructuringInfo &info,
                             const SROAStatistics &statistics) {
   RewriterBase::InsertionGuard guard(rewriter);
@@ -158,7 +162,8 @@ static void destructureSlot(DestructurableMemorySlot &slot,
   for (Operation *toRewire : llvm::reverse(usersToRewire)) {
     rewriter.setInsertionPointAfter(toRewire);
     if (auto accessor = dyn_cast<DestructurableAccessorOpInterface>(toRewire)) {
-      if (accessor.rewire(slot, subslots, rewriter) == DeletionKind::Delete)
+      if (accessor.rewire(slot, subslots, rewriter, dataLayout) ==
+          DeletionKind::Delete)
         toErase.push_back(accessor);
       continue;
     }
@@ -186,17 +191,18 @@ static void destructureSlot(DestructurableMemorySlot &slot,
 
 LogicalResult mlir::tryToDestructureMemorySlots(
     ArrayRef<DestructurableAllocationOpInterface> allocators,
-    RewriterBase &rewriter, SROAStatistics statistics) {
+    RewriterBase &rewriter, const DataLayout &dataLayout,
+    SROAStatistics statistics) {
   bool destructuredAny = false;
 
   for (DestructurableAllocationOpInterface allocator : allocators) {
     for (DestructurableMemorySlot slot : allocator.getDestructurableSlots()) {
       std::optional<MemorySlotDestructuringInfo> info =
-          computeDestructuringInfo(slot);
+          computeDestructuringInfo(slot, dataLayout);
       if (!info)
         continue;
 
-      destructureSlot(slot, allocator, rewriter, *info, statistics);
+      destructureSlot(slot, allocator, rewriter, dataLayout, *info, statistics);
       destructuredAny = true;
     }
   }
@@ -215,6 +221,8 @@ struct SROA : public impl::SROABase<SROA> {
     SROAStatistics statistics{&destructuredAmount, &slotsWithMemoryBenefit,
                               &maxSubelementAmount};
 
+    auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
+    const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(scopeOp);
     bool changed = false;
 
     for (Region &region : scopeOp->getRegions()) {
@@ -235,8 +243,8 @@ struct SROA : public impl::SROABase<SROA> {
           allocators.emplace_back(allocator);
         });
 
-        if (failed(
-                tryToDestructureMemorySlots(allocators, rewriter, statistics)))
+        if (failed(tryToDestructureMemorySlots(allocators, rewriter, dataLayout,
+                                               statistics)))
           break;
 
         changed = true;



More information about the Mlir-commits mailing list