[Mlir-commits] [mlir] [MLIR][SROA] Reuse allocators to avoid rewalking the IR (PR #91971)

Christian Ulmann llvmlistbot at llvm.org
Mon May 13 23:18:04 PDT 2024


https://github.com/Dinistro updated https://github.com/llvm/llvm-project/pull/91971

>From c60b96b7e377569cbec608cde9ee9733062ea9aa Mon Sep 17 00:00:00 2001
From: Christian Ulmann <christian.ulmann at nextsilicon.com>
Date: Mon, 13 May 2024 14:19:28 +0000
Subject: [PATCH 1/3] [MLIR][SROA] Reuse allocators to avoid rewalking the IR

This commit extends the SROA interfaces so that implementations can
communicate newly created allocators to the algorithm. As a result, the
SROA implementation no longer needs to re-walk the IR to find new
allocators.
---
 .../mlir/Interfaces/MemorySlotInterfaces.td   | 10 ++-
 mlir/include/mlir/Transforms/SROA.h           |  6 +-
 mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp | 13 +--
 .../Dialect/MemRef/IR/MemRefMemorySlot.cpp    | 13 +--
 mlir/lib/Transforms/SROA.cpp                  | 86 ++++++++++++-------
 mlir/test/Transforms/sroa.mlir                | 31 +++++++
 mlir/test/lib/Dialect/Test/TestOpDefs.cpp     | 79 ++++++++++++++---
 mlir/test/lib/Dialect/Test/TestOps.td         |  5 +-
 8 files changed, 185 insertions(+), 58 deletions(-)
 create mode 100644 mlir/test/Transforms/sroa.mlir

diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
index e2409cbec5fde..29960f457373c 100644
--- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
@@ -298,14 +298,20 @@ def DestructurableAllocationOpInterface
       "destructure",
       (ins "const ::mlir::DestructurableMemorySlot &":$slot,
            "const ::llvm::SmallPtrSetImpl<::mlir::Attribute> &":$usedIndices,
-           "::mlir::OpBuilder &":$builder)
+           "::mlir::OpBuilder &":$builder,
+           "::mlir::SmallVectorImpl<::mlir::DestructurableAllocationOpInterface> &":
+             $newAllocators)
     >,
     InterfaceMethod<[{
         Hook triggered once the destructuring of a slot is complete, meaning the
         original slot is no longer being refered to and could be deleted.
         This will only be called for slots declared by this operation.
+
+        Must return a new destructurable allocation op if this operation
+        produced multiple destructurable slots, nullopt otherwise.
       }],
-      "void", "handleDestructuringComplete",
+      "::std::optional<::mlir::DestructurableAllocationOpInterface>",
+        "handleDestructuringComplete",
       (ins "const ::mlir::DestructurableMemorySlot &":$slot,
            "::mlir::OpBuilder &":$builder)
     >,
diff --git a/mlir/include/mlir/Transforms/SROA.h b/mlir/include/mlir/Transforms/SROA.h
index fa84fb1eae73a..d48f809d30766 100644
--- a/mlir/include/mlir/Transforms/SROA.h
+++ b/mlir/include/mlir/Transforms/SROA.h
@@ -27,8 +27,10 @@ struct SROAStatistics {
   llvm::Statistic *maxSubelementAmount = nullptr;
 };
 
-/// Attempts to destructure the slots of destructurable allocators. Returns
-/// failure if no slot was destructured.
+/// Attempts to destructure the slots of destructurable allocators. Iteratively
+/// retries the destructuring of all slots as destructuring one slot might
+/// enable subsequent destructuring. Returns failure if no slot was
+/// destructured.
 LogicalResult tryToDestructureMemorySlots(
     ArrayRef<DestructurableAllocationOpInterface> allocators,
     OpBuilder &builder, const DataLayout &dataLayout,
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index 4fdf847a559ce..3f1e5b1773bf7 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -77,10 +77,10 @@ SmallVector<DestructurableMemorySlot> LLVM::AllocaOp::getDestructurableSlots() {
                                    *destructuredType}};
 }
 
-DenseMap<Attribute, MemorySlot>
-LLVM::AllocaOp::destructure(const DestructurableMemorySlot &slot,
-                            const SmallPtrSetImpl<Attribute> &usedIndices,
-                            OpBuilder &builder) {
+DenseMap<Attribute, MemorySlot> LLVM::AllocaOp::destructure(
+    const DestructurableMemorySlot &slot,
+    const SmallPtrSetImpl<Attribute> &usedIndices, OpBuilder &builder,
+    SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators) {
   assert(slot.ptr == getResult());
   builder.setInsertionPointAfter(*this);
 
@@ -92,16 +92,19 @@ LLVM::AllocaOp::destructure(const DestructurableMemorySlot &slot,
     auto subAlloca = builder.create<LLVM::AllocaOp>(
         getLoc(), LLVM::LLVMPointerType::get(getContext()), elemType,
         getArraySize());
+    newAllocators.push_back(subAlloca);
     slotMap.try_emplace<MemorySlot>(index, {subAlloca.getResult(), elemType});
   }
 
   return slotMap;
 }
 
-void LLVM::AllocaOp::handleDestructuringComplete(
+std::optional<DestructurableAllocationOpInterface>
+LLVM::AllocaOp::handleDestructuringComplete(
     const DestructurableMemorySlot &slot, OpBuilder &builder) {
   assert(slot.ptr == getResult());
   this->erase();
+  return std::nullopt;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
index e30598e6878f4..631dee2d40538 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
@@ -126,10 +126,10 @@ memref::AllocaOp::getDestructurableSlots() {
       DestructurableMemorySlot{{getMemref(), memrefType}, *destructuredType}};
 }
 
-DenseMap<Attribute, MemorySlot>
-memref::AllocaOp::destructure(const DestructurableMemorySlot &slot,
-                              const SmallPtrSetImpl<Attribute> &usedIndices,
-                              OpBuilder &builder) {
+DenseMap<Attribute, MemorySlot> memref::AllocaOp::destructure(
+    const DestructurableMemorySlot &slot,
+    const SmallPtrSetImpl<Attribute> &usedIndices, OpBuilder &builder,
+    SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators) {
   builder.setInsertionPointAfter(*this);
 
   DenseMap<Attribute, MemorySlot> slotMap;
@@ -139,6 +139,7 @@ memref::AllocaOp::destructure(const DestructurableMemorySlot &slot,
     Type elemType = memrefType.getTypeAtIndex(usedIndex);
     MemRefType elemPtr = MemRefType::get({}, elemType);
     auto subAlloca = builder.create<memref::AllocaOp>(getLoc(), elemPtr);
+    newAllocators.push_back(subAlloca);
     slotMap.try_emplace<MemorySlot>(usedIndex,
                                     {subAlloca.getResult(), elemType});
   }
@@ -146,10 +147,12 @@ memref::AllocaOp::destructure(const DestructurableMemorySlot &slot,
   return slotMap;
 }
 
-void memref::AllocaOp::handleDestructuringComplete(
+std::optional<DestructurableAllocationOpInterface>
+memref::AllocaOp::handleDestructuringComplete(
     const DestructurableMemorySlot &slot, OpBuilder &builder) {
   assert(slot.ptr == getResult());
   this->erase();
+  return std::nullopt;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/SROA.cpp b/mlir/lib/Transforms/SROA.cpp
index 4e28fa687ffd4..67cbade07bc94 100644
--- a/mlir/lib/Transforms/SROA.cpp
+++ b/mlir/lib/Transforms/SROA.cpp
@@ -132,16 +132,17 @@ computeDestructuringInfo(DestructurableMemorySlot &slot,
 /// Performs the destructuring of a destructible slot given associated
 /// destructuring information. The provided slot will be destructured in
 /// subslots as specified by its allocator.
-static void destructureSlot(DestructurableMemorySlot &slot,
-                            DestructurableAllocationOpInterface allocator,
-                            OpBuilder &builder, const DataLayout &dataLayout,
-                            MemorySlotDestructuringInfo &info,
-                            const SROAStatistics &statistics) {
+static void destructureSlot(
+    DestructurableMemorySlot &slot,
+    DestructurableAllocationOpInterface allocator, OpBuilder &builder,
+    const DataLayout &dataLayout, MemorySlotDestructuringInfo &info,
+    SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators,
+    const SROAStatistics &statistics) {
   OpBuilder::InsertionGuard guard(builder);
 
   builder.setInsertionPointToStart(slot.ptr.getParentBlock());
   DenseMap<Attribute, MemorySlot> subslots =
-      allocator.destructure(slot, info.usedIndices, builder);
+      allocator.destructure(slot, info.usedIndices, builder, newAllocators);
 
   if (statistics.slotsWithMemoryBenefit &&
       slot.elementPtrs.size() != info.usedIndices.size())
@@ -185,7 +186,11 @@ static void destructureSlot(DestructurableMemorySlot &slot,
   if (statistics.destructuredAmount)
     (*statistics.destructuredAmount)++;
 
-  allocator.handleDestructuringComplete(slot, builder);
+  std::optional<DestructurableAllocationOpInterface> newAllocator =
+      allocator.handleDestructuringComplete(slot, builder);
+  // Add newly created allocators to the worklist for further processing.
+  if (newAllocator)
+    newAllocators.push_back(*newAllocator);
 }
 
 LogicalResult mlir::tryToDestructureMemorySlots(
@@ -194,16 +199,44 @@ LogicalResult mlir::tryToDestructureMemorySlots(
     SROAStatistics statistics) {
   bool destructuredAny = false;
 
-  for (DestructurableAllocationOpInterface allocator : allocators) {
-    for (DestructurableMemorySlot slot : allocator.getDestructurableSlots()) {
-      std::optional<MemorySlotDestructuringInfo> info =
-          computeDestructuringInfo(slot, dataLayout);
-      if (!info)
-        continue;
+  SmallVector<DestructurableAllocationOpInterface> workList(allocators.begin(),
+                                                            allocators.end());
+  SmallVector<DestructurableAllocationOpInterface> newWorkList;
+  newWorkList.reserve(allocators.size());
+  // Destructuring a slot can allow for further destructuring of other
+  // slots, destructuring is tried until no destructuring succeeds.
+  while (true) {
+    bool changesInThisRound = false;
+
+    for (DestructurableAllocationOpInterface allocator : workList) {
+      bool destructuredAnySlot = false;
+      for (DestructurableMemorySlot slot : allocator.getDestructurableSlots()) {
+        std::optional<MemorySlotDestructuringInfo> info =
+            computeDestructuringInfo(slot, dataLayout);
+        if (!info)
+          continue;
 
-      destructureSlot(slot, allocator, builder, dataLayout, *info, statistics);
-      destructuredAny = true;
+        destructureSlot(slot, allocator, builder, dataLayout, *info,
+                        newWorkList, statistics);
+        destructuredAnySlot = true;
+
+        // A break is required, since destructuring a slot may invalidate the
+        // remaining slots of an allocator.
+        break;
+      }
+      if (!destructuredAnySlot)
+        newWorkList.push_back(allocator);
+      changesInThisRound |= destructuredAnySlot;
     }
+
+    if (!changesInThisRound)
+      break;
+    destructuredAny |= changesInThisRound;
+
+    // Swap the vector's backing memory and clear the entries in newWorkList
+    // afterwards. This ensures that additional heap allocations can be avoided.
+    workList.swap(newWorkList);
+    newWorkList.clear();
   }
 
   return success(destructuredAny);
@@ -230,23 +263,16 @@ struct SROA : public impl::SROABase<SROA> {
 
       OpBuilder builder(&region.front(), region.front().begin());
 
-      // Destructuring a slot can allow for further destructuring of other
-      // slots, destructuring is tried until no destructuring succeeds.
-      while (true) {
-        SmallVector<DestructurableAllocationOpInterface> allocators;
-        // Build a list of allocators to attempt to destructure the slots of.
-        // TODO: Update list on the fly to avoid repeated visiting of the same
-        // allocators.
-        region.walk([&](DestructurableAllocationOpInterface allocator) {
-          allocators.emplace_back(allocator);
-        });
-
-        if (failed(tryToDestructureMemorySlots(allocators, builder, dataLayout,
-                                               statistics)))
-          break;
+      SmallVector<DestructurableAllocationOpInterface> allocators;
+      // Build a list of allocators to attempt to destructure the slots of.
+      region.walk([&](DestructurableAllocationOpInterface allocator) {
+        allocators.emplace_back(allocator);
+      });
 
+      // Attempt to destructure as many slots as possible.
+      if (succeeded(tryToDestructureMemorySlots(allocators, builder, dataLayout,
+                                                statistics)))
         changed = true;
-      }
     }
     if (!changed)
       markAllAnalysesPreserved();
diff --git a/mlir/test/Transforms/sroa.mlir b/mlir/test/Transforms/sroa.mlir
new file mode 100644
index 0000000000000..c9e80a6cf8dd1
--- /dev/null
+++ b/mlir/test/Transforms/sroa.mlir
@@ -0,0 +1,31 @@
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(sroa))' --split-input-file | FileCheck %s
+
+// Verifies that allocators with multiple slots are handled properly.
+
+// CHECK-LABEL: func.func @multi_slot_alloca
+func.func @multi_slot_alloca() -> (i32, i32) {
+  %0 = arith.constant 0 : index
+  %1, %2 = test.multi_slot_alloca : () -> (memref<2xi32>, memref<4xi32>)
+  // CHECK-COUNT-2: test.multi_slot_alloca : () -> memref<i32>
+  %3 = memref.load %1[%0] {first}: memref<2xi32>
+  %4 = memref.load %2[%0] {second} : memref<4xi32>
+  return %3, %4 : i32, i32
+}
+
+// -----
+
+// Verifies that a multi slot allocator can be partially destructured.
+
+func.func private @consumer(memref<2xi32>)
+
+// CHECK-LABEL: func.func @multi_slot_alloca_only_second
+func.func @multi_slot_alloca_only_second() -> (i32, i32) {
+  %0 = arith.constant 0 : index
+  // CHECK: test.multi_slot_alloca : () -> memref<2xi32>
+  // CHECK: test.multi_slot_alloca : () -> memref<i32>
+  %1, %2 = test.multi_slot_alloca : () -> (memref<2xi32>, memref<4xi32>)
+  func.call @consumer(%1) : (memref<2xi32>) -> ()
+  %3 = memref.load %1[%0] : memref<2xi32>
+  %4 = memref.load %2[%0] : memref<4xi32>
+  return %3, %4 : i32, i32
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index d22d48b139a04..c8949e8d1af72 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -1199,22 +1199,20 @@ void TestMultiSlotAlloca::handleBlockArgument(const MemorySlot &slot,
   // Not relevant for testing.
 }
 
-std::optional<PromotableAllocationOpInterface>
-TestMultiSlotAlloca::handlePromotionComplete(const MemorySlot &slot,
-                                             Value defaultValue,
-                                             OpBuilder &builder) {
-  if (defaultValue && defaultValue.use_empty())
-    defaultValue.getDefiningOp()->erase();
+/// Creates a new TestMultiSlotAlloca operation, just without the `slot`.
+static std::optional<TestMultiSlotAlloca>
+createNewMultiAllocaWithoutSlot(const MemorySlot &slot, OpBuilder &builder,
+                                TestMultiSlotAlloca oldOp) {
 
-  if (getNumResults() == 1) {
-    erase();
+  if (oldOp.getNumResults() == 1) {
+    oldOp.erase();
     return std::nullopt;
   }
 
   SmallVector<Type> newTypes;
   SmallVector<Value> remainingValues;
 
-  for (Value oldResult : getResults()) {
+  for (Value oldResult : oldOp.getResults()) {
     if (oldResult == slot.ptr)
       continue;
     remainingValues.push_back(oldResult);
@@ -1222,12 +1220,69 @@ TestMultiSlotAlloca::handlePromotionComplete(const MemorySlot &slot,
   }
 
   OpBuilder::InsertionGuard guard(builder);
-  builder.setInsertionPoint(*this);
-  auto replacement = builder.create<TestMultiSlotAlloca>(getLoc(), newTypes);
+  builder.setInsertionPoint(oldOp);
+  auto replacement =
+      builder.create<TestMultiSlotAlloca>(oldOp->getLoc(), newTypes);
   for (auto [oldResult, newResult] :
        llvm::zip_equal(remainingValues, replacement.getResults()))
     oldResult.replaceAllUsesWith(newResult);
 
-  erase();
+  oldOp.erase();
   return replacement;
 }
+
+std::optional<PromotableAllocationOpInterface>
+TestMultiSlotAlloca::handlePromotionComplete(const MemorySlot &slot,
+                                             Value defaultValue,
+                                             OpBuilder &builder) {
+  if (defaultValue && defaultValue.use_empty())
+    defaultValue.getDefiningOp()->erase();
+  return createNewMultiAllocaWithoutSlot(slot, builder, *this);
+}
+
+SmallVector<DestructurableMemorySlot>
+TestMultiSlotAlloca::getDestructurableSlots() {
+  SmallVector<DestructurableMemorySlot> slots;
+  for (Value result : getResults()) {
+    auto memrefType = cast<MemRefType>(result.getType());
+    auto destructurable =
+        llvm::dyn_cast<DestructurableTypeInterface>(memrefType);
+    if (!destructurable)
+      continue;
+
+    std::optional<DenseMap<Attribute, Type>> destructuredType =
+        destructurable.getSubelementIndexMap();
+    if (!destructuredType)
+      continue;
+    slots.emplace_back(
+        DestructurableMemorySlot{{result, memrefType}, *destructuredType});
+  }
+  return slots;
+}
+
+DenseMap<Attribute, MemorySlot> TestMultiSlotAlloca::destructure(
+    const DestructurableMemorySlot &slot,
+    const SmallPtrSetImpl<Attribute> &usedIndices, OpBuilder &builder,
+    SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators) {
+  OpBuilder::InsertionGuard guard(builder);
+  builder.setInsertionPointAfter(*this);
+
+  DenseMap<Attribute, MemorySlot> slotMap;
+
+  for (Attribute usedIndex : usedIndices) {
+    Type elemType = slot.elementPtrs.lookup(usedIndex);
+    MemRefType elemPtr = MemRefType::get({}, elemType);
+    auto subAlloca = builder.create<TestMultiSlotAlloca>(getLoc(), elemPtr);
+    newAllocators.push_back(subAlloca);
+    slotMap.try_emplace<MemorySlot>(usedIndex,
+                                    {subAlloca.getResult(0), elemType});
+  }
+
+  return slotMap;
+}
+
+std::optional<DestructurableAllocationOpInterface>
+TestMultiSlotAlloca::handleDestructuringComplete(
+    const DestructurableMemorySlot &slot, OpBuilder &builder) {
+  return createNewMultiAllocaWithoutSlot(slot, builder, *this);
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index e16ea2407314e..7fc3d22d18958 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -3169,11 +3169,12 @@ def TestOpOptionallyImplementingInterface
 }
 
 //===----------------------------------------------------------------------===//
-// Test Mem2Reg
+// Test Mem2Reg & SROA
 //===----------------------------------------------------------------------===//
 
 def TestMultiSlotAlloca : TEST_Op<"multi_slot_alloca",
-    [DeclareOpInterfaceMethods<PromotableAllocationOpInterface>]> {
+    [DeclareOpInterfaceMethods<PromotableAllocationOpInterface>,
+     DeclareOpInterfaceMethods<DestructurableAllocationOpInterface>]> {
   let results = (outs Variadic<MemRefOf<[I32]>>:$results);
   let assemblyFormat = "attr-dict `:` functional-type(operands, results)";
 }

>From d7b0fd7c24aa7c400eeea9068c06ba009e8b1447 Mon Sep 17 00:00:00 2001
From: Christian Ulmann <christian.ulmann at nextsilicon.com>
Date: Mon, 13 May 2024 20:13:24 +0000
Subject: [PATCH 2/3] address review comments

---
 mlir/test/lib/Dialect/Test/TestOpDefs.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index c8949e8d1af72..0b676db18af41 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -1245,8 +1245,7 @@ TestMultiSlotAlloca::getDestructurableSlots() {
   SmallVector<DestructurableMemorySlot> slots;
   for (Value result : getResults()) {
     auto memrefType = cast<MemRefType>(result.getType());
-    auto destructurable =
-        llvm::dyn_cast<DestructurableTypeInterface>(memrefType);
+    auto destructurable = dyn_cast<DestructurableTypeInterface>(memrefType);
     if (!destructurable)
       continue;
 

>From 38efb8e5ce6825dd0dc2be8e2b63c9a53498c4fd Mon Sep 17 00:00:00 2001
From: Christian Ulmann <christian.ulmann at nextsilicon.com>
Date: Tue, 14 May 2024 06:17:43 +0000
Subject: [PATCH 3/3] address Theo's comments

---
 mlir/include/mlir/Interfaces/MemorySlotInterfaces.td | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
index 29960f457373c..6f023f0c5263f 100644
--- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
@@ -307,11 +307,11 @@ def DestructurableAllocationOpInterface
         original slot is no longer being refered to and could be deleted.
         This will only be called for slots declared by this operation.
 
-        Must return a new destructurable allocation op if this operation
-        produced multiple destructurable slots, nullopt otherwise.
+        Must return a new destructurable allocation op if this hook creates
+        a new destructurable op, nullopt otherwise.
       }],
       "::std::optional<::mlir::DestructurableAllocationOpInterface>",
-        "handleDestructuringComplete",
+      "handleDestructuringComplete",
       (ins "const ::mlir::DestructurableMemorySlot &":$slot,
            "::mlir::OpBuilder &":$builder)
     >,



More information about the Mlir-commits mailing list