[Mlir-commits] [mlir] 0bf120a - [mlir] [sroa] Add support for MemRef.

Tobias Gysi llvmlistbot at llvm.org
Wed May 24 00:39:05 PDT 2023


Author: Théo Degioanni
Date: 2023-05-24T07:33:28Z
New Revision: 0bf120a82040f7ffaba0f0ab72a983f1cd9343ab

URL: https://github.com/llvm/llvm-project/commit/0bf120a82040f7ffaba0f0ab72a983f1cd9343ab
DIFF: https://github.com/llvm/llvm-project/commit/0bf120a82040f7ffaba0f0ab72a983f1cd9343ab.diff

LOG: [mlir] [sroa] Add support for MemRef.

This patch implements SROA interfaces for MemRef, up to a given fixed
size.

Reviewed By: gysit, Dinistro

Differential Revision: https://reviews.llvm.org/D151102
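
For illustration, here is a rough sketch of the transformation the new interfaces enable, modeled on the `basic` test added in mlir/test/Dialect/MemRef/sroa.mlir (destructuring only applies to statically shaped memrefs with more than one and at most 16 elements, per maxMemrefSizeForDestructuring in MemRefMemorySlot.cpp). Before running the SROA pass:

  func.func @basic(%arg0: i32, %arg1: i32) -> i32 {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %alloca = memref.alloca() : memref<2xi32>
    memref.store %arg0, %alloca[%c0] : memref<2xi32>
    memref.store %arg1, %alloca[%c1] : memref<2xi32>
    %res = memref.load %alloca[%c0] : memref<2xi32>
    return %res : i32
  }

After mlir-opt --pass-pipeline="builtin.module(func.func(sroa))", the two-element alloca is split into two scalar slots and the load/store operations are rewired to them with empty index lists (SSA names are illustrative and dead index constants are omitted for brevity):

  func.func @basic(%arg0: i32, %arg1: i32) -> i32 {
    %a0 = memref.alloca() : memref<i32>
    %a1 = memref.alloca() : memref<i32>
    memref.store %arg0, %a0[] : memref<i32>
    memref.store %arg1, %a1[] : memref<i32>
    %res = memref.load %a0[] : memref<i32>
    return %res : i32
  }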

Added: 
    mlir/include/mlir/Dialect/MemRef/IR/MemRefMemorySlot.h
    mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
    mlir/test/Dialect/MemRef/sroa.mlir

Modified: 
    mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
    mlir/include/mlir/InitAllDialects.h
    mlir/lib/Dialect/MemRef/IR/CMakeLists.txt

Removed: 
    mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp


################################################################################
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefMemorySlot.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRefMemorySlot.h
new file mode 100644
index 0000000000000..6b56311ed8840
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefMemorySlot.h
@@ -0,0 +1,20 @@
+//===- MemRefMemorySlot.h - Implementation of Memory Slot Interfaces ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MEMREF_IR_MEMREFMEMORYSLOT_H
+#define MLIR_DIALECT_MEMREF_IR_MEMREFMEMORYSLOT_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace memref {
+void registerMemorySlotExternalModels(DialectRegistry &registry);
+} // namespace memref
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MEMREF_IR_MEMREFMEMORYSLOT_H

diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 8500c4c26ab25..d4c14e24f627e 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -311,7 +311,8 @@ def MemRef_ReallocOp : MemRef_Op<"realloc"> {
 
 def MemRef_AllocaOp : AllocLikeOp<"alloca", AutomaticAllocationScopeResource,[
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
-    DeclareOpInterfaceMethods<PromotableAllocationOpInterface>]> {
+    DeclareOpInterfaceMethods<PromotableAllocationOpInterface>,
+    DeclareOpInterfaceMethods<DestructurableAllocationOpInterface>]> {
   let summary = "stack memory allocation operation";
   let description = [{
     The `alloca` operation allocates memory on the stack, to be automatically
@@ -1162,7 +1163,8 @@ def LoadOp : MemRef_Op<"load",
                      "memref", "result",
                      "::llvm::cast<MemRefType>($_self).getElementType()">,
       MemRefsNormalizable,
-      DeclareOpInterfaceMethods<PromotableMemOpInterface>]> {
+      DeclareOpInterfaceMethods<PromotableMemOpInterface>,
+      DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
   let summary = "load operation";
   let description = [{
     The `load` op reads an element from a memref specified by an index list. The
@@ -1752,7 +1754,8 @@ def MemRef_StoreOp : MemRef_Op<"store",
                      "memref", "value",
                      "::llvm::cast<MemRefType>($_self).getElementType()">,
       MemRefsNormalizable,
-      DeclareOpInterfaceMethods<PromotableMemOpInterface>]> {
+      DeclareOpInterfaceMethods<PromotableMemOpInterface>,
+      DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
   let summary = "store operation";
   let description = [{
     Store a value to a memref location given by indices. The value stored should

diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index e0f154824d8eb..b00de3f0a2002 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -48,6 +48,7 @@
 #include "mlir/Dialect/MLProgram/IR/MLProgram.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h"
 #include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h"
 #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
 #include "mlir/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.h"
@@ -148,6 +149,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
   memref::registerBufferizableOpInterfaceExternalModels(registry);
   memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
   memref::registerValueBoundsOpInterfaceExternalModels(registry);
+  memref::registerMemorySlotExternalModels(registry);
   scf::registerBufferizableOpInterfaceExternalModels(registry);
   scf::registerValueBoundsOpInterfaceExternalModels(registry);
   shape::registerBufferizableOpInterfaceExternalModels(registry);

diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
index 1b01f783a2224..fd2fed28badd6 100644
--- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
@@ -1,6 +1,6 @@
 add_mlir_dialect_library(MLIRMemRefDialect
   MemRefDialect.cpp
-  MemRefMem2Reg.cpp
+  MemRefMemorySlot.cpp
   MemRefOps.cpp
   ValueBoundsOpInterfaceImpl.cpp
 
@@ -21,6 +21,7 @@ add_mlir_dialect_library(MLIRMemRefDialect
   MLIRDialectUtils
   MLIRInferTypeOpInterface
   MLIRIR
+  MLIRMemorySlotInterfaces
   MLIRShapedOpInterfaces
   MLIRSideEffectInterfaces
   MLIRValueBoundsOpInterface

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp
deleted file mode 100644
index 12d9ebd5a02ad..0000000000000
--- a/mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp
+++ /dev/null
@@ -1,119 +0,0 @@
-//===- MemRefMem2Reg.cpp - Mem2Reg Interfaces -------------------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements Mem2Reg-related interfaces for MemRef dialect
-// operations.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "llvm/ADT/TypeSwitch.h"
-
-using namespace mlir;
-
-//===----------------------------------------------------------------------===//
-//  AllocaOp interfaces
-//===----------------------------------------------------------------------===//
-
-static bool isSupportedElementType(Type type) {
-  return llvm::isa<MemRefType>(type) ||
-         OpBuilder(type.getContext()).getZeroAttr(type);
-}
-
-SmallVector<MemorySlot> memref::AllocaOp::getPromotableSlots() {
-  MemRefType type = getType();
-  if (!isSupportedElementType(type.getElementType()))
-    return {};
-  if (!type.hasStaticShape())
-    return {};
-  // Make sure the memref contains only a single element.
-  if (any_of(type.getShape(), [](uint64_t dim) { return dim != 1; }))
-    return {};
-
-  return {MemorySlot{getResult(), type.getElementType()}};
-}
-
-Value memref::AllocaOp::getDefaultValue(const MemorySlot &slot,
-                                        RewriterBase &rewriter) {
-  assert(isSupportedElementType(slot.elemType));
-  // TODO: support more types.
-  return TypeSwitch<Type, Value>(slot.elemType)
-      .Case([&](MemRefType t) {
-        return rewriter.create<memref::AllocaOp>(getLoc(), t);
-      })
-      .Default([&](Type t) {
-        return rewriter.create<arith::ConstantOp>(getLoc(), t,
-                                                  rewriter.getZeroAttr(t));
-      });
-}
-
-void memref::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
-                                               Value defaultValue,
-                                               RewriterBase &rewriter) {
-  if (defaultValue.use_empty())
-    rewriter.eraseOp(defaultValue.getDefiningOp());
-  rewriter.eraseOp(*this);
-}
-
-void memref::AllocaOp::handleBlockArgument(const MemorySlot &slot,
-                                           BlockArgument argument,
-                                           RewriterBase &rewriter) {}
-
-//===----------------------------------------------------------------------===//
-//  LoadOp/StoreOp interfaces
-//===----------------------------------------------------------------------===//
-
-bool memref::LoadOp::loadsFrom(const MemorySlot &slot) {
-  return getMemRef() == slot.ptr;
-}
-
-Value memref::LoadOp::getStored(const MemorySlot &slot) { return {}; }
-
-bool memref::LoadOp::canUsesBeRemoved(
-    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    SmallVectorImpl<OpOperand *> &newBlockingUses) {
-  if (blockingUses.size() != 1)
-    return false;
-  Value blockingUse = (*blockingUses.begin())->get();
-  return blockingUse == slot.ptr && getMemRef() == slot.ptr &&
-         getResult().getType() == slot.elemType;
-}
-
-DeletionKind memref::LoadOp::removeBlockingUses(
-    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    RewriterBase &rewriter, Value reachingDefinition) {
-  // `canUsesBeRemoved` checked this blocking use must be the loaded slot
-  // pointer.
-  rewriter.replaceAllUsesWith(getResult(), reachingDefinition);
-  return DeletionKind::Delete;
-}
-
-bool memref::StoreOp::loadsFrom(const MemorySlot &slot) { return false; }
-
-Value memref::StoreOp::getStored(const MemorySlot &slot) {
-  if (getMemRef() != slot.ptr)
-    return {};
-  return getValue();
-}
-
-bool memref::StoreOp::canUsesBeRemoved(
-    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    SmallVectorImpl<OpOperand *> &newBlockingUses) {
-  if (blockingUses.size() != 1)
-    return false;
-  Value blockingUse = (*blockingUses.begin())->get();
-  return blockingUse == slot.ptr && getMemRef() == slot.ptr &&
-         getValue() != slot.ptr && getValue().getType() == slot.elemType;
-}
-
-DeletionKind memref::StoreOp::removeBlockingUses(
-    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    RewriterBase &rewriter, Value reachingDefinition) {
-  return DeletionKind::Delete;
-}

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
new file mode 100644
index 0000000000000..34fedec8d7e11
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
@@ -0,0 +1,331 @@
+//===- MemRefMemorySlot.cpp - Memory Slot Interfaces ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements Mem2Reg-related interfaces for MemRef dialect
+// operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Interfaces/MemorySlotInterfaces.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+//  Utilities
+//===----------------------------------------------------------------------===//
+
+/// Walks over the indices of the elements of a tensor of a given `shape` by
+/// updating `index` in place to the next index. This returns failure if the
+/// provided index was the last index.
+static LogicalResult nextIndex(ArrayRef<int64_t> shape,
+                               MutableArrayRef<int64_t> index) {
+  for (size_t i = 0; i < shape.size(); ++i) {
+    index[i]++;
+    if (index[i] < shape[i])
+      return success();
+    index[i] = 0;
+  }
+  return failure();
+}
+
+/// Calls `walker` for each index within a tensor of a given `shape`, providing
+/// the index as an array attribute of the coordinates.
+template <typename CallableT>
+static void walkIndicesAsAttr(MLIRContext *ctx, ArrayRef<int64_t> shape,
+                              CallableT &&walker) {
+  Type indexType = IndexType::get(ctx);
+  SmallVector<int64_t> shapeIter(shape.size(), 0);
+  do {
+    SmallVector<Attribute> indexAsAttr;
+    for (int64_t dim : shapeIter)
+      indexAsAttr.push_back(IntegerAttr::get(indexType, dim));
+    walker(ArrayAttr::get(ctx, indexAsAttr));
+  } while (succeeded(nextIndex(shape, shapeIter)));
+}
+
+//===----------------------------------------------------------------------===//
+//  Interfaces for AllocaOp
+//===----------------------------------------------------------------------===//
+
+static bool isSupportedElementType(Type type) {
+  return type.isa<MemRefType>() ||
+         OpBuilder(type.getContext()).getZeroAttr(type);
+}
+
+SmallVector<MemorySlot> memref::AllocaOp::getPromotableSlots() {
+  MemRefType type = getType();
+  if (!isSupportedElementType(type.getElementType()))
+    return {};
+  if (!type.hasStaticShape())
+    return {};
+  // Make sure the memref contains only a single element.
+  if (type.getNumElements() != 1)
+    return {};
+
+  return {MemorySlot{getResult(), type.getElementType()}};
+}
+
+Value memref::AllocaOp::getDefaultValue(const MemorySlot &slot,
+                                        RewriterBase &rewriter) {
+  assert(isSupportedElementType(slot.elemType));
+  // TODO: support more types.
+  return TypeSwitch<Type, Value>(slot.elemType)
+      .Case([&](MemRefType t) {
+        return rewriter.create<memref::AllocaOp>(getLoc(), t);
+      })
+      .Default([&](Type t) {
+        return rewriter.create<arith::ConstantOp>(getLoc(), t,
+                                                  rewriter.getZeroAttr(t));
+      });
+}
+
+void memref::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
+                                               Value defaultValue,
+                                               RewriterBase &rewriter) {
+  if (defaultValue.use_empty())
+    rewriter.eraseOp(defaultValue.getDefiningOp());
+  rewriter.eraseOp(*this);
+}
+
+void memref::AllocaOp::handleBlockArgument(const MemorySlot &slot,
+                                           BlockArgument argument,
+                                           RewriterBase &rewriter) {}
+
+SmallVector<DestructurableMemorySlot>
+memref::AllocaOp::getDestructurableSlots() {
+  MemRefType memrefType = getType();
+  auto destructurable = memrefType.dyn_cast<DestructurableTypeInterface>();
+  if (!destructurable)
+    return {};
+
+  Optional<DenseMap<Attribute, Type>> destructuredType =
+      destructurable.getSubelementIndexMap();
+  if (!destructuredType)
+    return {};
+
+  DenseMap<Attribute, Type> indexMap;
+  for (auto const &[index, type] : *destructuredType)
+    indexMap.insert({index, MemRefType::get({}, type)});
+
+  return {DestructurableMemorySlot{{getMemref(), memrefType}, indexMap}};
+}
+
+DenseMap<Attribute, MemorySlot>
+memref::AllocaOp::destructure(const DestructurableMemorySlot &slot,
+                              const SmallPtrSetImpl<Attribute> &usedIndices,
+                              RewriterBase &rewriter) {
+  rewriter.setInsertionPointAfter(*this);
+
+  DenseMap<Attribute, MemorySlot> slotMap;
+
+  auto memrefType = getType().cast<DestructurableTypeInterface>();
+  for (Attribute usedIndex : usedIndices) {
+    Type elemType = memrefType.getTypeAtIndex(usedIndex);
+    MemRefType elemPtr = MemRefType::get({}, elemType);
+    auto subAlloca = rewriter.create<memref::AllocaOp>(getLoc(), elemPtr);
+    slotMap.try_emplace<MemorySlot>(usedIndex,
+                                    {subAlloca.getResult(), elemType});
+  }
+
+  return slotMap;
+}
+
+void memref::AllocaOp::handleDestructuringComplete(
+    const DestructurableMemorySlot &slot, RewriterBase &rewriter) {
+  assert(slot.ptr == getResult());
+  rewriter.eraseOp(*this);
+}
+
+//===----------------------------------------------------------------------===//
+//  Interfaces for LoadOp/StoreOp
+//===----------------------------------------------------------------------===//
+
+bool memref::LoadOp::loadsFrom(const MemorySlot &slot) {
+  return getMemRef() == slot.ptr;
+}
+
+Value memref::LoadOp::getStored(const MemorySlot &slot) { return {}; }
+
+bool memref::LoadOp::canUsesBeRemoved(
+    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
+    SmallVectorImpl<OpOperand *> &newBlockingUses) {
+  if (blockingUses.size() != 1)
+    return false;
+  Value blockingUse = (*blockingUses.begin())->get();
+  return blockingUse == slot.ptr && getMemRef() == slot.ptr &&
+         getResult().getType() == slot.elemType;
+}
+
+DeletionKind memref::LoadOp::removeBlockingUses(
+    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
+    RewriterBase &rewriter, Value reachingDefinition) {
+  // `canUsesBeRemoved` checked this blocking use must be the loaded slot
+  // pointer.
+  rewriter.replaceAllUsesWith(getResult(), reachingDefinition);
+  return DeletionKind::Delete;
+}
+
+/// Returns the index of a memref in attribute form, given its indices.
+static Attribute getAttributeIndexFromIndexOperands(MLIRContext *ctx,
+                                                    ValueRange indices) {
+  SmallVector<Attribute> index;
+  for (Value coord : indices) {
+    IntegerAttr coordAttr;
+    if (!matchPattern(coord, m_Constant<IntegerAttr>(&coordAttr)))
+      return {};
+    index.push_back(coordAttr);
+  }
+  return ArrayAttr::get(ctx, index);
+}
+
+bool memref::LoadOp::canRewire(const DestructurableMemorySlot &slot,
+                               SmallPtrSetImpl<Attribute> &usedIndices,
+                               SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
+  if (slot.ptr != getMemRef())
+    return false;
+  Attribute index =
+      getAttributeIndexFromIndexOperands(getContext(), getIndices());
+  if (!index)
+    return false;
+  usedIndices.insert(index);
+  return true;
+}
+
+DeletionKind memref::LoadOp::rewire(const DestructurableMemorySlot &slot,
+                                    DenseMap<Attribute, MemorySlot> &subslots,
+                                    RewriterBase &rewriter) {
+  Attribute index =
+      getAttributeIndexFromIndexOperands(getContext(), getIndices());
+  const MemorySlot &memorySlot = subslots.at(index);
+  rewriter.updateRootInPlace(*this, [&]() {
+    setMemRef(memorySlot.ptr);
+    getIndicesMutable().clear();
+  });
+  return DeletionKind::Keep;
+}
+
+bool memref::StoreOp::loadsFrom(const MemorySlot &slot) { return false; }
+
+Value memref::StoreOp::getStored(const MemorySlot &slot) {
+  if (getMemRef() != slot.ptr)
+    return {};
+  return getValue();
+}
+
+bool memref::StoreOp::canUsesBeRemoved(
+    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
+    SmallVectorImpl<OpOperand *> &newBlockingUses) {
+  if (blockingUses.size() != 1)
+    return false;
+  Value blockingUse = (*blockingUses.begin())->get();
+  return blockingUse == slot.ptr && getMemRef() == slot.ptr &&
+         getValue() != slot.ptr && getValue().getType() == slot.elemType;
+}
+
+DeletionKind memref::StoreOp::removeBlockingUses(
+    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
+    RewriterBase &rewriter, Value reachingDefinition) {
+  return DeletionKind::Delete;
+}
+
+bool memref::StoreOp::canRewire(const DestructurableMemorySlot &slot,
+                                SmallPtrSetImpl<Attribute> &usedIndices,
+                                SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
+  if (slot.ptr != getMemRef() || getValue() == slot.ptr)
+    return false;
+  Attribute index =
+      getAttributeIndexFromIndexOperands(getContext(), getIndices());
+  if (!index || !slot.elementPtrs.contains(index))
+    return false;
+  usedIndices.insert(index);
+  return true;
+}
+
+DeletionKind memref::StoreOp::rewire(const DestructurableMemorySlot &slot,
+                                     DenseMap<Attribute, MemorySlot> &subslots,
+                                     RewriterBase &rewriter) {
+  Attribute index =
+      getAttributeIndexFromIndexOperands(getContext(), getIndices());
+  const MemorySlot &memorySlot = subslots.at(index);
+  rewriter.updateRootInPlace(*this, [&]() {
+    setMemRef(memorySlot.ptr);
+    getIndicesMutable().clear();
+  });
+  return DeletionKind::Keep;
+}
+
+//===----------------------------------------------------------------------===//
+//  Interfaces for destructurable types
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+struct MemRefDestructurableTypeExternalModel
+    : public DestructurableTypeInterface::ExternalModel<
+          MemRefDestructurableTypeExternalModel, MemRefType> {
+  std::optional<DenseMap<Attribute, Type>>
+  getSubelementIndexMap(Type type) const {
+    auto memrefType = type.cast<MemRefType>();
+    constexpr int64_t maxMemrefSizeForDestructuring = 16;
+    if (!memrefType.hasStaticShape() ||
+        memrefType.getNumElements() > maxMemrefSizeForDestructuring ||
+        memrefType.getNumElements() == 1)
+      return {};
+
+    DenseMap<Attribute, Type> destructured;
+    walkIndicesAsAttr(
+        memrefType.getContext(), memrefType.getShape(), [&](Attribute index) {
+          destructured.insert({index, memrefType.getElementType()});
+        });
+
+    return destructured;
+  }
+
+  Type getTypeAtIndex(Type type, Attribute index) const {
+    auto memrefType = type.cast<MemRefType>();
+    auto coordArrAttr = index.dyn_cast<ArrayAttr>();
+    if (!coordArrAttr || coordArrAttr.size() != memrefType.getShape().size())
+      return {};
+
+    Type indexType = IndexType::get(memrefType.getContext());
+    for (const auto &[coordAttr, dimSize] :
+         llvm::zip(coordArrAttr, memrefType.getShape())) {
+      auto coord = coordAttr.dyn_cast<IntegerAttr>();
+      if (!coord || coord.getType() != indexType || coord.getInt() < 0 ||
+          coord.getInt() >= dimSize)
+        return {};
+    }
+
+    return memrefType.getElementType();
+  }
+};
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+//  Register external models
+//===----------------------------------------------------------------------===//
+
+void mlir::memref::registerMemorySlotExternalModels(DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) {
+    MemRefType::attachInterface<MemRefDestructurableTypeExternalModel>(*ctx);
+  });
+}

diff --git a/mlir/test/Dialect/MemRef/sroa.mlir b/mlir/test/Dialect/MemRef/sroa.mlir
new file mode 100644
index 0000000000000..d78053d8ea777
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/sroa.mlir
@@ -0,0 +1,154 @@
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(sroa))" --split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @basic
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32)
+func.func @basic(%arg0: i32, %arg1: i32) -> i32 {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK-NOT: = memref.alloca()
+  // CHECK-COUNT-2: = memref.alloca() : memref<i32>
+  // CHECK-NOT: = memref.alloca()
+  %alloca = memref.alloca() : memref<2xi32>
+  // CHECK: memref.store %[[ARG0]], %[[ALLOCA0:.*]][]
+  memref.store %arg0, %alloca[%c0] : memref<2xi32>
+  // CHECK: memref.store %[[ARG1]], %[[ALLOCA1:.*]][]
+  memref.store %arg1, %alloca[%c1] : memref<2xi32>
+  // CHECK: %[[RES:.*]] = memref.load %[[ALLOCA0]][]
+  %res = memref.load %alloca[%c0] : memref<2xi32>
+  // CHECK: return %[[RES]] : i32
+  return %res : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @basic_high_dimensions
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32)
+func.func @basic_high_dimensions(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK-NOT: = memref.alloca()
+  // CHECK-COUNT-3: = memref.alloca() : memref<i32>
+  // CHECK-NOT: = memref.alloca()
+  %alloca = memref.alloca() : memref<2x2xi32>
+  // CHECK: memref.store %[[ARG0]], %[[ALLOCA0:.*]][]
+  memref.store %arg0, %alloca[%c0, %c0] : memref<2x2xi32>
+  // CHECK: memref.store %[[ARG1]], %[[ALLOCA1:.*]][]
+  memref.store %arg1, %alloca[%c0, %c1] : memref<2x2xi32>
+  // CHECK: memref.store %[[ARG2]], %[[ALLOCA2:.*]][]
+  memref.store %arg2, %alloca[%c1, %c0] : memref<2x2xi32>
+  // CHECK: %[[RES:.*]] = memref.load %[[ALLOCA1]][]
+  %res = memref.load %alloca[%c0, %c1] : memref<2x2xi32>
+  // CHECK: return %[[RES]] : i32
+  return %res : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @resolve_alias
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32)
+func.func @resolve_alias(%arg0: i32, %arg1: i32) -> i32 {
+  %c0 = arith.constant 0 : index
+  // CHECK-NOT: = memref.alloca()
+  // CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref<i32>
+  // CHECK-NOT: = memref.alloca()
+  %alloca = memref.alloca() : memref<2xi32>
+  // CHECK: memref.store %[[ARG0]], %[[ALLOCA]][]
+  memref.store %arg0, %alloca[%c0] : memref<2xi32>
+  // CHECK: memref.store %[[ARG1]], %[[ALLOCA]][]
+  memref.store %arg1, %alloca[%c0] : memref<2xi32>
+  // CHECK: %[[RES:.*]] = memref.load %[[ALLOCA]][]
+  %res = memref.load %alloca[%c0] : memref<2xi32>
+  // CHECK: return %[[RES]] : i32
+  return %res : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @no_direct_use
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32)
+func.func @no_direct_use(%arg0: i32, %arg1: i32) -> i32 {
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+  // CHECK: %[[C1:.*]] = arith.constant 1 : index
+  %c1 = arith.constant 1 : index
+  // CHECK-NOT: = memref.alloca()
+  // CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref<2xi32>
+  // CHECK-NOT: = memref.alloca()
+  %alloca = memref.alloca() : memref<2xi32>
+  // CHECK: memref.store %[[ARG0]], %[[ALLOCA]][%[[C0]]]
+  memref.store %arg0, %alloca[%c0] : memref<2xi32>
+  // CHECK: memref.store %[[ARG1]], %[[ALLOCA]][%[[C1]]]
+  memref.store %arg1, %alloca[%c1] : memref<2xi32>
+  // CHECK: %[[RES:.*]] = memref.load %[[ALLOCA]][%[[C0]]]
+  %res = memref.load %alloca[%c0] : memref<2xi32>
+  call @use(%alloca) : (memref<2xi32>) -> ()
+  // CHECK: return %[[RES]] : i32
+  return %res : i32
+}
+
+func.func @use(%foo: memref<2xi32>) { return }
+
+// -----
+
+// CHECK-LABEL: func.func @no_dynamic_indexing
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[INDEX:.*]]: index)
+func.func @no_dynamic_indexing(%arg0: i32, %arg1: i32, %index: index) -> i32 {
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+  // CHECK-NOT: = memref.alloca()
+  // CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref<2xi32>
+  // CHECK-NOT: = memref.alloca()
+  %alloca = memref.alloca() : memref<2xi32>
+  // CHECK: memref.store %[[ARG0]], %[[ALLOCA]][%[[C0]]]
+  memref.store %arg0, %alloca[%c0] : memref<2xi32>
+  // CHECK: memref.store %[[ARG1]], %[[ALLOCA]][%[[INDEX]]]
+  memref.store %arg1, %alloca[%index] : memref<2xi32>
+  // CHECK: %[[RES:.*]] = memref.load %[[ALLOCA]][%[[C0]]]
+  %res = memref.load %alloca[%c0] : memref<2xi32>
+  // CHECK: return %[[RES]] : i32
+  return %res : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @no_dynamic_shape
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32)
+func.func @no_dynamic_shape(%arg0: i32, %arg1: i32) -> i32 {
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+  // CHECK: %[[C1:.*]] = arith.constant 1 : index
+  %c1 = arith.constant 1 : index
+  // CHECK-NOT: = memref.alloca()
+  // CHECK: %[[ALLOCA:.*]] = memref.alloca(%[[C1]]) : memref<?x2xi32>
+  // CHECK-NOT: = memref.alloca()
+  %alloca = memref.alloca(%c1) : memref<?x2xi32>
+  // CHECK: memref.store %[[ARG0]], %[[ALLOCA]][%[[C0]], %[[C0]]]
+  memref.store %arg0, %alloca[%c0, %c0] : memref<?x2xi32>
+  // CHECK: %[[RES:.*]] = memref.load %[[ALLOCA]][%[[C0]], %[[C0]]]
+  %res = memref.load %alloca[%c0, %c0] : memref<?x2xi32>
+  // CHECK: return %[[RES]] : i32
+  return %res : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @no_out_of_bounds
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32)
+func.func @no_out_of_bounds(%arg0: i32, %arg1: i32) -> i32 {
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+  // CHECK: %[[C100:.*]] = arith.constant 100 : index
+  %c100 = arith.constant 100 : index
+  // CHECK-NOT: = memref.alloca()
+  // CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref<2xi32>
+  // CHECK-NOT: = memref.alloca()
+  %alloca = memref.alloca() : memref<2xi32>
+  // CHECK: memref.store %[[ARG0]], %[[ALLOCA]][%[[C0]]]
+  memref.store %arg0, %alloca[%c0] : memref<2xi32>
+  // CHECK: memref.store %[[ARG1]], %[[ALLOCA]][%[[C100]]]
+  memref.store %arg1, %alloca[%c100] : memref<2xi32>
+  // CHECK: %[[RES:.*]] = memref.load %[[ALLOCA]][%[[C0]]]
+  %res = memref.load %alloca[%c0] : memref<2xi32>
+  // CHECK: return %[[RES]] : i32
+  return %res : i32
+}


        

