[Mlir-commits] [mlir] [MLIR][SROA] Reuse allocators to avoid rewalking the IR (PR #91971)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon May 13 07:25:57 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-core

Author: Christian Ulmann (Dinistro)

<details>
<summary>Changes</summary>

This commit extends the SROA interfaces so that interface implementations can communicate newly created allocators to the algorithm. This ensures that the SROA implementation no longer requires re-walking the IR to find new allocators.

---
Full diff: https://github.com/llvm/llvm-project/pull/91971.diff


8 Files Affected:

- (modified) mlir/include/mlir/Interfaces/MemorySlotInterfaces.td (+8-2) 
- (modified) mlir/include/mlir/Transforms/SROA.h (+4-2) 
- (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp (+8-5) 
- (modified) mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp (+8-5) 
- (modified) mlir/lib/Transforms/SROA.cpp (+56-30) 
- (added) mlir/test/Transforms/sroa.mlir (+31) 
- (modified) mlir/test/lib/Dialect/Test/TestOpDefs.cpp (+67-12) 
- (modified) mlir/test/lib/Dialect/Test/TestOps.td (+3-2) 


``````````diff
diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
index e2409cbec5fde..29960f457373c 100644
--- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
@@ -298,14 +298,20 @@ def DestructurableAllocationOpInterface
       "destructure",
       (ins "const ::mlir::DestructurableMemorySlot &":$slot,
            "const ::llvm::SmallPtrSetImpl<::mlir::Attribute> &":$usedIndices,
-           "::mlir::OpBuilder &":$builder)
+           "::mlir::OpBuilder &":$builder,
+           "::mlir::SmallVectorImpl<::mlir::DestructurableAllocationOpInterface> &":
+             $newAllocators)
     >,
     InterfaceMethod<[{
         Hook triggered once the destructuring of a slot is complete, meaning the
         original slot is no longer being refered to and could be deleted.
         This will only be called for slots declared by this operation.
+
+        Must return a new destructurable allocation op if this operation
+        produced multiple destructurable slots, nullopt otherwise.
       }],
-      "void", "handleDestructuringComplete",
+      "::std::optional<::mlir::DestructurableAllocationOpInterface>",
+        "handleDestructuringComplete",
       (ins "const ::mlir::DestructurableMemorySlot &":$slot,
            "::mlir::OpBuilder &":$builder)
     >,
diff --git a/mlir/include/mlir/Transforms/SROA.h b/mlir/include/mlir/Transforms/SROA.h
index fa84fb1eae73a..d48f809d30766 100644
--- a/mlir/include/mlir/Transforms/SROA.h
+++ b/mlir/include/mlir/Transforms/SROA.h
@@ -27,8 +27,10 @@ struct SROAStatistics {
   llvm::Statistic *maxSubelementAmount = nullptr;
 };
 
-/// Attempts to destructure the slots of destructurable allocators. Returns
-/// failure if no slot was destructured.
+/// Attempts to destructure the slots of destructurable allocators. Iteratively
+/// retries the destructuring of all slots as destructuring one slot might
+/// enable subsequent destructuring. Returns failure if no slot was
+/// destructured.
 LogicalResult tryToDestructureMemorySlots(
     ArrayRef<DestructurableAllocationOpInterface> allocators,
     OpBuilder &builder, const DataLayout &dataLayout,
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index 4fdf847a559ce..3f1e5b1773bf7 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -77,10 +77,10 @@ SmallVector<DestructurableMemorySlot> LLVM::AllocaOp::getDestructurableSlots() {
                                    *destructuredType}};
 }
 
-DenseMap<Attribute, MemorySlot>
-LLVM::AllocaOp::destructure(const DestructurableMemorySlot &slot,
-                            const SmallPtrSetImpl<Attribute> &usedIndices,
-                            OpBuilder &builder) {
+DenseMap<Attribute, MemorySlot> LLVM::AllocaOp::destructure(
+    const DestructurableMemorySlot &slot,
+    const SmallPtrSetImpl<Attribute> &usedIndices, OpBuilder &builder,
+    SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators) {
   assert(slot.ptr == getResult());
   builder.setInsertionPointAfter(*this);
 
@@ -92,16 +92,19 @@ LLVM::AllocaOp::destructure(const DestructurableMemorySlot &slot,
     auto subAlloca = builder.create<LLVM::AllocaOp>(
         getLoc(), LLVM::LLVMPointerType::get(getContext()), elemType,
         getArraySize());
+    newAllocators.push_back(subAlloca);
     slotMap.try_emplace<MemorySlot>(index, {subAlloca.getResult(), elemType});
   }
 
   return slotMap;
 }
 
-void LLVM::AllocaOp::handleDestructuringComplete(
+std::optional<DestructurableAllocationOpInterface>
+LLVM::AllocaOp::handleDestructuringComplete(
     const DestructurableMemorySlot &slot, OpBuilder &builder) {
   assert(slot.ptr == getResult());
   this->erase();
+  return std::nullopt;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
index e30598e6878f4..631dee2d40538 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
@@ -126,10 +126,10 @@ memref::AllocaOp::getDestructurableSlots() {
       DestructurableMemorySlot{{getMemref(), memrefType}, *destructuredType}};
 }
 
-DenseMap<Attribute, MemorySlot>
-memref::AllocaOp::destructure(const DestructurableMemorySlot &slot,
-                              const SmallPtrSetImpl<Attribute> &usedIndices,
-                              OpBuilder &builder) {
+DenseMap<Attribute, MemorySlot> memref::AllocaOp::destructure(
+    const DestructurableMemorySlot &slot,
+    const SmallPtrSetImpl<Attribute> &usedIndices, OpBuilder &builder,
+    SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators) {
   builder.setInsertionPointAfter(*this);
 
   DenseMap<Attribute, MemorySlot> slotMap;
@@ -139,6 +139,7 @@ memref::AllocaOp::destructure(const DestructurableMemorySlot &slot,
     Type elemType = memrefType.getTypeAtIndex(usedIndex);
     MemRefType elemPtr = MemRefType::get({}, elemType);
     auto subAlloca = builder.create<memref::AllocaOp>(getLoc(), elemPtr);
+    newAllocators.push_back(subAlloca);
     slotMap.try_emplace<MemorySlot>(usedIndex,
                                     {subAlloca.getResult(), elemType});
   }
@@ -146,10 +147,12 @@ memref::AllocaOp::destructure(const DestructurableMemorySlot &slot,
   return slotMap;
 }
 
-void memref::AllocaOp::handleDestructuringComplete(
+std::optional<DestructurableAllocationOpInterface>
+memref::AllocaOp::handleDestructuringComplete(
     const DestructurableMemorySlot &slot, OpBuilder &builder) {
   assert(slot.ptr == getResult());
   this->erase();
+  return std::nullopt;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/SROA.cpp b/mlir/lib/Transforms/SROA.cpp
index 4e28fa687ffd4..67cbade07bc94 100644
--- a/mlir/lib/Transforms/SROA.cpp
+++ b/mlir/lib/Transforms/SROA.cpp
@@ -132,16 +132,17 @@ computeDestructuringInfo(DestructurableMemorySlot &slot,
 /// Performs the destructuring of a destructible slot given associated
 /// destructuring information. The provided slot will be destructured in
 /// subslots as specified by its allocator.
-static void destructureSlot(DestructurableMemorySlot &slot,
-                            DestructurableAllocationOpInterface allocator,
-                            OpBuilder &builder, const DataLayout &dataLayout,
-                            MemorySlotDestructuringInfo &info,
-                            const SROAStatistics &statistics) {
+static void destructureSlot(
+    DestructurableMemorySlot &slot,
+    DestructurableAllocationOpInterface allocator, OpBuilder &builder,
+    const DataLayout &dataLayout, MemorySlotDestructuringInfo &info,
+    SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators,
+    const SROAStatistics &statistics) {
   OpBuilder::InsertionGuard guard(builder);
 
   builder.setInsertionPointToStart(slot.ptr.getParentBlock());
   DenseMap<Attribute, MemorySlot> subslots =
-      allocator.destructure(slot, info.usedIndices, builder);
+      allocator.destructure(slot, info.usedIndices, builder, newAllocators);
 
   if (statistics.slotsWithMemoryBenefit &&
       slot.elementPtrs.size() != info.usedIndices.size())
@@ -185,7 +186,11 @@ static void destructureSlot(DestructurableMemorySlot &slot,
   if (statistics.destructuredAmount)
     (*statistics.destructuredAmount)++;
 
-  allocator.handleDestructuringComplete(slot, builder);
+  std::optional<DestructurableAllocationOpInterface> newAllocator =
+      allocator.handleDestructuringComplete(slot, builder);
+  // Add newly created allocators to the worklist for further processing.
+  if (newAllocator)
+    newAllocators.push_back(*newAllocator);
 }
 
 LogicalResult mlir::tryToDestructureMemorySlots(
@@ -194,16 +199,44 @@ LogicalResult mlir::tryToDestructureMemorySlots(
     SROAStatistics statistics) {
   bool destructuredAny = false;
 
-  for (DestructurableAllocationOpInterface allocator : allocators) {
-    for (DestructurableMemorySlot slot : allocator.getDestructurableSlots()) {
-      std::optional<MemorySlotDestructuringInfo> info =
-          computeDestructuringInfo(slot, dataLayout);
-      if (!info)
-        continue;
+  SmallVector<DestructurableAllocationOpInterface> workList(allocators.begin(),
+                                                            allocators.end());
+  SmallVector<DestructurableAllocationOpInterface> newWorkList;
+  newWorkList.reserve(allocators.size());
+  // Destructuring a slot can allow for further destructuring of other
+  // slots, destructuring is tried until no destructuring succeeds.
+  while (true) {
+    bool changesInThisRound = false;
+
+    for (DestructurableAllocationOpInterface allocator : workList) {
+      bool destructuredAnySlot = false;
+      for (DestructurableMemorySlot slot : allocator.getDestructurableSlots()) {
+        std::optional<MemorySlotDestructuringInfo> info =
+            computeDestructuringInfo(slot, dataLayout);
+        if (!info)
+          continue;
 
-      destructureSlot(slot, allocator, builder, dataLayout, *info, statistics);
-      destructuredAny = true;
+        destructureSlot(slot, allocator, builder, dataLayout, *info,
+                        newWorkList, statistics);
+        destructuredAnySlot = true;
+
+        // A break is required, since destructuring a slot may invalidate the
+        // remaining slots of an allocator.
+        break;
+      }
+      if (!destructuredAnySlot)
+        newWorkList.push_back(allocator);
+      changesInThisRound |= destructuredAnySlot;
     }
+
+    if (!changesInThisRound)
+      break;
+    destructuredAny |= changesInThisRound;
+
+    // Swap the vector's backing memory and clear the entries in newWorkList
+    // afterwards. This ensures that additional heap allocations can be avoided.
+    workList.swap(newWorkList);
+    newWorkList.clear();
   }
 
   return success(destructuredAny);
@@ -230,23 +263,16 @@ struct SROA : public impl::SROABase<SROA> {
 
       OpBuilder builder(&region.front(), region.front().begin());
 
-      // Destructuring a slot can allow for further destructuring of other
-      // slots, destructuring is tried until no destructuring succeeds.
-      while (true) {
-        SmallVector<DestructurableAllocationOpInterface> allocators;
-        // Build a list of allocators to attempt to destructure the slots of.
-        // TODO: Update list on the fly to avoid repeated visiting of the same
-        // allocators.
-        region.walk([&](DestructurableAllocationOpInterface allocator) {
-          allocators.emplace_back(allocator);
-        });
-
-        if (failed(tryToDestructureMemorySlots(allocators, builder, dataLayout,
-                                               statistics)))
-          break;
+      SmallVector<DestructurableAllocationOpInterface> allocators;
+      // Build a list of allocators to attempt to destructure the slots of.
+      region.walk([&](DestructurableAllocationOpInterface allocator) {
+        allocators.emplace_back(allocator);
+      });
 
+      // Attempt to destructure as many slots as possible.
+      if (succeeded(tryToDestructureMemorySlots(allocators, builder, dataLayout,
+                                                statistics)))
         changed = true;
-      }
     }
     if (!changed)
       markAllAnalysesPreserved();
diff --git a/mlir/test/Transforms/sroa.mlir b/mlir/test/Transforms/sroa.mlir
new file mode 100644
index 0000000000000..c9e80a6cf8dd1
--- /dev/null
+++ b/mlir/test/Transforms/sroa.mlir
@@ -0,0 +1,31 @@
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(sroa))' --split-input-file | FileCheck %s
+
+// Verifies that allocators with multiple slots are handled properly.
+
+// CHECK-LABEL: func.func @multi_slot_alloca
+func.func @multi_slot_alloca() -> (i32, i32) {
+  %0 = arith.constant 0 : index
+  %1, %2 = test.multi_slot_alloca : () -> (memref<2xi32>, memref<4xi32>)
+  // CHECK-COUNT-2: test.multi_slot_alloca : () -> memref<i32>
+  %3 = memref.load %1[%0] {first}: memref<2xi32>
+  %4 = memref.load %2[%0] {second} : memref<4xi32>
+  return %3, %4 : i32, i32
+}
+
+// -----
+
+// Verifies that a multi slot allocator can be partially destructured.
+
+func.func private @consumer(memref<2xi32>)
+
+// CHECK-LABEL: func.func @multi_slot_alloca_only_second
+func.func @multi_slot_alloca_only_second() -> (i32, i32) {
+  %0 = arith.constant 0 : index
+  // CHECK: test.multi_slot_alloca : () -> memref<2xi32>
+  // CHECK: test.multi_slot_alloca : () -> memref<i32>
+  %1, %2 = test.multi_slot_alloca : () -> (memref<2xi32>, memref<4xi32>)
+  func.call @consumer(%1) : (memref<2xi32>) -> ()
+  %3 = memref.load %1[%0] : memref<2xi32>
+  %4 = memref.load %2[%0] : memref<4xi32>
+  return %3, %4 : i32, i32
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index d22d48b139a04..c8949e8d1af72 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -1199,22 +1199,20 @@ void TestMultiSlotAlloca::handleBlockArgument(const MemorySlot &slot,
   // Not relevant for testing.
 }
 
-std::optional<PromotableAllocationOpInterface>
-TestMultiSlotAlloca::handlePromotionComplete(const MemorySlot &slot,
-                                             Value defaultValue,
-                                             OpBuilder &builder) {
-  if (defaultValue && defaultValue.use_empty())
-    defaultValue.getDefiningOp()->erase();
+/// Creates a new TestMultiSlotAlloca operation, just without the `slot`.
+static std::optional<TestMultiSlotAlloca>
+createNewMultiAllocaWithoutSlot(const MemorySlot &slot, OpBuilder &builder,
+                                TestMultiSlotAlloca oldOp) {
 
-  if (getNumResults() == 1) {
-    erase();
+  if (oldOp.getNumResults() == 1) {
+    oldOp.erase();
     return std::nullopt;
   }
 
   SmallVector<Type> newTypes;
   SmallVector<Value> remainingValues;
 
-  for (Value oldResult : getResults()) {
+  for (Value oldResult : oldOp.getResults()) {
     if (oldResult == slot.ptr)
       continue;
     remainingValues.push_back(oldResult);
@@ -1222,12 +1220,69 @@ TestMultiSlotAlloca::handlePromotionComplete(const MemorySlot &slot,
   }
 
   OpBuilder::InsertionGuard guard(builder);
-  builder.setInsertionPoint(*this);
-  auto replacement = builder.create<TestMultiSlotAlloca>(getLoc(), newTypes);
+  builder.setInsertionPoint(oldOp);
+  auto replacement =
+      builder.create<TestMultiSlotAlloca>(oldOp->getLoc(), newTypes);
   for (auto [oldResult, newResult] :
        llvm::zip_equal(remainingValues, replacement.getResults()))
     oldResult.replaceAllUsesWith(newResult);
 
-  erase();
+  oldOp.erase();
   return replacement;
 }
+
+std::optional<PromotableAllocationOpInterface>
+TestMultiSlotAlloca::handlePromotionComplete(const MemorySlot &slot,
+                                             Value defaultValue,
+                                             OpBuilder &builder) {
+  if (defaultValue && defaultValue.use_empty())
+    defaultValue.getDefiningOp()->erase();
+  return createNewMultiAllocaWithoutSlot(slot, builder, *this);
+}
+
+SmallVector<DestructurableMemorySlot>
+TestMultiSlotAlloca::getDestructurableSlots() {
+  SmallVector<DestructurableMemorySlot> slots;
+  for (Value result : getResults()) {
+    auto memrefType = cast<MemRefType>(result.getType());
+    auto destructurable =
+        llvm::dyn_cast<DestructurableTypeInterface>(memrefType);
+    if (!destructurable)
+      continue;
+
+    std::optional<DenseMap<Attribute, Type>> destructuredType =
+        destructurable.getSubelementIndexMap();
+    if (!destructuredType)
+      continue;
+    slots.emplace_back(
+        DestructurableMemorySlot{{result, memrefType}, *destructuredType});
+  }
+  return slots;
+}
+
+DenseMap<Attribute, MemorySlot> TestMultiSlotAlloca::destructure(
+    const DestructurableMemorySlot &slot,
+    const SmallPtrSetImpl<Attribute> &usedIndices, OpBuilder &builder,
+    SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators) {
+  OpBuilder::InsertionGuard guard(builder);
+  builder.setInsertionPointAfter(*this);
+
+  DenseMap<Attribute, MemorySlot> slotMap;
+
+  for (Attribute usedIndex : usedIndices) {
+    Type elemType = slot.elementPtrs.lookup(usedIndex);
+    MemRefType elemPtr = MemRefType::get({}, elemType);
+    auto subAlloca = builder.create<TestMultiSlotAlloca>(getLoc(), elemPtr);
+    newAllocators.push_back(subAlloca);
+    slotMap.try_emplace<MemorySlot>(usedIndex,
+                                    {subAlloca.getResult(0), elemType});
+  }
+
+  return slotMap;
+}
+
+std::optional<DestructurableAllocationOpInterface>
+TestMultiSlotAlloca::handleDestructuringComplete(
+    const DestructurableMemorySlot &slot, OpBuilder &builder) {
+  return createNewMultiAllocaWithoutSlot(slot, builder, *this);
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index e16ea2407314e..7fc3d22d18958 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -3169,11 +3169,12 @@ def TestOpOptionallyImplementingInterface
 }
 
 //===----------------------------------------------------------------------===//
-// Test Mem2Reg
+// Test Mem2Reg & SROA
 //===----------------------------------------------------------------------===//
 
 def TestMultiSlotAlloca : TEST_Op<"multi_slot_alloca",
-    [DeclareOpInterfaceMethods<PromotableAllocationOpInterface>]> {
+    [DeclareOpInterfaceMethods<PromotableAllocationOpInterface>,
+     DeclareOpInterfaceMethods<DestructurableAllocationOpInterface>]> {
   let results = (outs Variadic<MemRefOf<[I32]>>:$results);
   let assemblyFormat = "attr-dict `:` functional-type(operands, results)";
 }

``````````

</details>


https://github.com/llvm/llvm-project/pull/91971


More information about the Mlir-commits mailing list