[Mlir-commits] [mlir] 0bf120a - [mlir] [sroa] Add support for MemRef.
Tobias Gysi
llvmlistbot at llvm.org
Wed May 24 00:39:05 PDT 2023
Author: Théo Degioanni
Date: 2023-05-24T07:33:28Z
New Revision: 0bf120a82040f7ffaba0f0ab72a983f1cd9343ab
URL: https://github.com/llvm/llvm-project/commit/0bf120a82040f7ffaba0f0ab72a983f1cd9343ab
DIFF: https://github.com/llvm/llvm-project/commit/0bf120a82040f7ffaba0f0ab72a983f1cd9343ab.diff
LOG: [mlir] [sroa] Add support for MemRef.
This patch implements SROA interfaces for MemRef, up to a given fixed
size.
Reviewed By: gysit, Dinistro
Differential Revision: https://reviews.llvm.org/D151102
Added:
mlir/include/mlir/Dialect/MemRef/IR/MemRefMemorySlot.h
mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
mlir/test/Dialect/MemRef/sroa.mlir
Modified:
mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
mlir/include/mlir/InitAllDialects.h
mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
Removed:
mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp
################################################################################
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefMemorySlot.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRefMemorySlot.h
new file mode 100644
index 0000000000000..6b56311ed8840
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefMemorySlot.h
@@ -0,0 +1,20 @@
+//===- MemRefMemorySlot.h - Implementation of Memory Slot Interfaces ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MEMREF_IR_MEMREFMEMORYSLOT_H
+#define MLIR_DIALECT_MEMREF_IR_MEMREFMEMORYSLOT_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace memref {
+/// Registers the MemRef memory-slot external models with `registry`: a
+/// BuiltinDialect extension that attaches the DestructurableTypeInterface
+/// implementation to MemRefType.
+void registerMemorySlotExternalModels(DialectRegistry &registry);
+} // namespace memref
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MEMREF_IR_MEMREFMEMORYSLOT_H
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 8500c4c26ab25..d4c14e24f627e 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -311,7 +311,8 @@ def MemRef_ReallocOp : MemRef_Op<"realloc"> {
def MemRef_AllocaOp : AllocLikeOp<"alloca", AutomaticAllocationScopeResource,[
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
- DeclareOpInterfaceMethods<PromotableAllocationOpInterface>]> {
+ DeclareOpInterfaceMethods<PromotableAllocationOpInterface>,
+ DeclareOpInterfaceMethods<DestructurableAllocationOpInterface>]> {
let summary = "stack memory allocation operation";
let description = [{
The `alloca` operation allocates memory on the stack, to be automatically
@@ -1162,7 +1163,8 @@ def LoadOp : MemRef_Op<"load",
"memref", "result",
"::llvm::cast<MemRefType>($_self).getElementType()">,
MemRefsNormalizable,
- DeclareOpInterfaceMethods<PromotableMemOpInterface>]> {
+ DeclareOpInterfaceMethods<PromotableMemOpInterface>,
+ DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
let summary = "load operation";
let description = [{
The `load` op reads an element from a memref specified by an index list. The
@@ -1752,7 +1754,8 @@ def MemRef_StoreOp : MemRef_Op<"store",
"memref", "value",
"::llvm::cast<MemRefType>($_self).getElementType()">,
MemRefsNormalizable,
- DeclareOpInterfaceMethods<PromotableMemOpInterface>]> {
+ DeclareOpInterfaceMethods<PromotableMemOpInterface>,
+ DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
let summary = "store operation";
let description = [{
Store a value to a memref location given by indices. The value stored should
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index e0f154824d8eb..b00de3f0a2002 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -48,6 +48,7 @@
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h"
#include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
#include "mlir/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.h"
@@ -148,6 +149,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
memref::registerBufferizableOpInterfaceExternalModels(registry);
memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
memref::registerValueBoundsOpInterfaceExternalModels(registry);
+ memref::registerMemorySlotExternalModels(registry);
scf::registerBufferizableOpInterfaceExternalModels(registry);
scf::registerValueBoundsOpInterfaceExternalModels(registry);
shape::registerBufferizableOpInterfaceExternalModels(registry);
diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
index 1b01f783a2224..fd2fed28badd6 100644
--- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
@@ -1,6 +1,6 @@
add_mlir_dialect_library(MLIRMemRefDialect
MemRefDialect.cpp
- MemRefMem2Reg.cpp
+ MemRefMemorySlot.cpp
MemRefOps.cpp
ValueBoundsOpInterfaceImpl.cpp
@@ -21,6 +21,7 @@ add_mlir_dialect_library(MLIRMemRefDialect
MLIRDialectUtils
MLIRInferTypeOpInterface
MLIRIR
+ MLIRMemorySlotInterfaces
MLIRShapedOpInterfaces
MLIRSideEffectInterfaces
MLIRValueBoundsOpInterface
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp
deleted file mode 100644
index 12d9ebd5a02ad..0000000000000
--- a/mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp
+++ /dev/null
@@ -1,119 +0,0 @@
-//===- MemRefMem2Reg.cpp - Mem2Reg Interfaces -------------------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements Mem2Reg-related interfaces for MemRef dialect
-// operations.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "llvm/ADT/TypeSwitch.h"
-
-using namespace mlir;
-
-//===----------------------------------------------------------------------===//
-// AllocaOp interfaces
-//===----------------------------------------------------------------------===//
-
-static bool isSupportedElementType(Type type) {
- return llvm::isa<MemRefType>(type) ||
- OpBuilder(type.getContext()).getZeroAttr(type);
-}
-
-SmallVector<MemorySlot> memref::AllocaOp::getPromotableSlots() {
- MemRefType type = getType();
- if (!isSupportedElementType(type.getElementType()))
- return {};
- if (!type.hasStaticShape())
- return {};
- // Make sure the memref contains only a single element.
- if (any_of(type.getShape(), [](uint64_t dim) { return dim != 1; }))
- return {};
-
- return {MemorySlot{getResult(), type.getElementType()}};
-}
-
-Value memref::AllocaOp::getDefaultValue(const MemorySlot &slot,
- RewriterBase &rewriter) {
- assert(isSupportedElementType(slot.elemType));
- // TODO: support more types.
- return TypeSwitch<Type, Value>(slot.elemType)
- .Case([&](MemRefType t) {
- return rewriter.create<memref::AllocaOp>(getLoc(), t);
- })
- .Default([&](Type t) {
- return rewriter.create<arith::ConstantOp>(getLoc(), t,
- rewriter.getZeroAttr(t));
- });
-}
-
-void memref::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
- Value defaultValue,
- RewriterBase &rewriter) {
- if (defaultValue.use_empty())
- rewriter.eraseOp(defaultValue.getDefiningOp());
- rewriter.eraseOp(*this);
-}
-
-void memref::AllocaOp::handleBlockArgument(const MemorySlot &slot,
- BlockArgument argument,
- RewriterBase &rewriter) {}
-
-//===----------------------------------------------------------------------===//
-// LoadOp/StoreOp interfaces
-//===----------------------------------------------------------------------===//
-
-bool memref::LoadOp::loadsFrom(const MemorySlot &slot) {
- return getMemRef() == slot.ptr;
-}
-
-Value memref::LoadOp::getStored(const MemorySlot &slot) { return {}; }
-
-bool memref::LoadOp::canUsesBeRemoved(
- const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
- SmallVectorImpl<OpOperand *> &newBlockingUses) {
- if (blockingUses.size() != 1)
- return false;
- Value blockingUse = (*blockingUses.begin())->get();
- return blockingUse == slot.ptr && getMemRef() == slot.ptr &&
- getResult().getType() == slot.elemType;
-}
-
-DeletionKind memref::LoadOp::removeBlockingUses(
- const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
- RewriterBase &rewriter, Value reachingDefinition) {
- // `canUsesBeRemoved` checked this blocking use must be the loaded slot
- // pointer.
- rewriter.replaceAllUsesWith(getResult(), reachingDefinition);
- return DeletionKind::Delete;
-}
-
-bool memref::StoreOp::loadsFrom(const MemorySlot &slot) { return false; }
-
-Value memref::StoreOp::getStored(const MemorySlot &slot) {
- if (getMemRef() != slot.ptr)
- return {};
- return getValue();
-}
-
-bool memref::StoreOp::canUsesBeRemoved(
- const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
- SmallVectorImpl<OpOperand *> &newBlockingUses) {
- if (blockingUses.size() != 1)
- return false;
- Value blockingUse = (*blockingUses.begin())->get();
- return blockingUse == slot.ptr && getMemRef() == slot.ptr &&
- getValue() != slot.ptr && getValue().getType() == slot.elemType;
-}
-
-DeletionKind memref::StoreOp::removeBlockingUses(
- const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
- RewriterBase &rewriter, Value reachingDefinition) {
- return DeletionKind::Delete;
-}
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
new file mode 100644
index 0000000000000..34fedec8d7e11
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
@@ -0,0 +1,331 @@
+//===- MemRefMemorySlot.cpp - Memory Slot Interfaces ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the memory slot interfaces (Mem2Reg promotion and SROA
+// destructuring) for MemRef dialect operations and for MemRefType.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Interfaces/MemorySlotInterfaces.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Utilities
+//===----------------------------------------------------------------------===//
+
+/// Advances `index` in place to the next coordinate of a memref with the given
+/// `shape`. Dimension 0 varies fastest; when a dimension wraps around it is
+/// reset to 0 and the increment carries into the next dimension. Returns
+/// failure once the carry runs past the last dimension, i.e. when `index` was
+/// the final coordinate (at which point `index` has been reset to all zeros).
+static LogicalResult nextIndex(ArrayRef<int64_t> shape,
+                               MutableArrayRef<int64_t> index) {
+  for (size_t dim = 0, rank = shape.size(); dim != rank; ++dim) {
+    if (++index[dim] < shape[dim])
+      return success();
+    // This dimension overflowed: wrap it and carry into the next one.
+    index[dim] = 0;
+  }
+  return failure();
+}
+
+/// Calls `walker` once for each in-bounds coordinate of a memref with the given
+/// static `shape`, passing the coordinate as an ArrayAttr holding one
+/// `index`-typed IntegerAttr per dimension. A rank-0 shape yields a single call
+/// with an empty array. A shape containing a zero-sized dimension has no
+/// elements, so `walker` is never called.
+template <typename CallableT>
+static void walkIndicesAsAttr(MLIRContext *ctx, ArrayRef<int64_t> shape,
+                              CallableT &&walker) {
+  // The do/while below always visits the all-zero coordinate first, which does
+  // not exist when any dimension is empty.
+  if (llvm::is_contained(shape, 0))
+    return;
+
+  Type indexType = IndexType::get(ctx);
+  SmallVector<int64_t> shapeIter(shape.size(), 0);
+  do {
+    SmallVector<Attribute> indexAsAttr;
+    for (int64_t dim : shapeIter)
+      indexAsAttr.push_back(IntegerAttr::get(indexType, dim));
+    walker(ArrayAttr::get(ctx, indexAsAttr));
+  } while (succeeded(nextIndex(shape, shapeIter)));
+}
+
+//===----------------------------------------------------------------------===//
+// Interfaces for AllocaOp
+//===----------------------------------------------------------------------===//
+
+/// Returns true if a default value can be materialized for a promoted slot of
+/// element type `type`: either `type` is a memref (a fresh alloca is used) or
+/// the builder can produce a zero attribute for it.
+static bool isSupportedElementType(Type type) {
+  // Use the free-function cast; the member form `Type::isa` is deprecated.
+  return llvm::isa<MemRefType>(type) ||
+         OpBuilder(type.getContext()).getZeroAttr(type);
+}
+
+/// An alloca is promotable when it statically holds exactly one element whose
+/// type admits a default value; the resulting slot stands for that element.
+SmallVector<MemorySlot> memref::AllocaOp::getPromotableSlots() {
+  MemRefType allocatedType = getType();
+  Type elementType = allocatedType.getElementType();
+  // `getNumElements` is only meaningful for static shapes, so the shape check
+  // must be evaluated first.
+  bool promotable = isSupportedElementType(elementType) &&
+                    allocatedType.hasStaticShape() &&
+                    allocatedType.getNumElements() == 1;
+  if (!promotable)
+    return {};
+  return {MemorySlot{getResult(), elementType}};
+}
+
+/// Materializes the initial content of a promoted slot: a new alloca when the
+/// element is itself a memref, otherwise an `arith.constant` zero.
+Value memref::AllocaOp::getDefaultValue(const MemorySlot &slot,
+                                        RewriterBase &rewriter) {
+  assert(isSupportedElementType(slot.elemType));
+  // TODO: support more types.
+  Type elemType = slot.elemType;
+  if (auto nestedMemRefType = llvm::dyn_cast<MemRefType>(elemType))
+    return rewriter.create<memref::AllocaOp>(getLoc(), nestedMemRefType)
+        .getResult();
+  return rewriter
+      .create<arith::ConstantOp>(getLoc(), elemType,
+                                 rewriter.getZeroAttr(elemType))
+      .getResult();
+}
+
+/// Finishes promotion: erases the materialized default value if nothing uses
+/// it, then erases this alloca.
+void memref::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
+                                               Value defaultValue,
+                                               RewriterBase &rewriter) {
+  const bool defaultValueIsUnused = defaultValue.use_empty();
+  if (defaultValueIsUnused)
+    rewriter.eraseOp(defaultValue.getDefiningOp());
+  rewriter.eraseOp(*this);
+}
+
+/// Block arguments created for this slot require no additional handling.
+void memref::AllocaOp::handleBlockArgument(const MemorySlot &slot,
+                                           BlockArgument argument,
+                                           RewriterBase &rewriter) {
+  // Intentionally a no-op.
+}
+
+/// Exposes this alloca as one destructurable slot when its memref type
+/// implements DestructurableTypeInterface and reports a subelement index map.
+/// Each subelement is addressed through a rank-0 memref of its element type.
+SmallVector<DestructurableMemorySlot>
+memref::AllocaOp::getDestructurableSlots() {
+  MemRefType memrefType = getType();
+  auto destructurable =
+      llvm::dyn_cast<DestructurableTypeInterface>(memrefType);
+  if (!destructurable)
+    return {};
+
+  // The interface reports its map as a std::optional; an empty optional means
+  // the type cannot (or should not) be destructured.
+  std::optional<DenseMap<Attribute, Type>> destructuredType =
+      destructurable.getSubelementIndexMap();
+  if (!destructuredType)
+    return {};
+
+  DenseMap<Attribute, Type> indexMap;
+  for (const auto &[index, type] : *destructuredType)
+    indexMap.insert({index, MemRefType::get({}, type)});
+
+  return {DestructurableMemorySlot{{getMemref(), memrefType}, indexMap}};
+}
+
+/// Creates, right after this alloca, one rank-0 alloca per element index in
+/// `usedIndices`, and returns the subslot created for each index.
+DenseMap<Attribute, MemorySlot>
+memref::AllocaOp::destructure(const DestructurableMemorySlot &slot,
+                              const SmallPtrSetImpl<Attribute> &usedIndices,
+                              RewriterBase &rewriter) {
+  rewriter.setInsertionPointAfter(*this);
+
+  DenseMap<Attribute, MemorySlot> slotMap;
+
+  auto memrefType = llvm::cast<DestructurableTypeInterface>(getType());
+  for (Attribute usedIndex : usedIndices) {
+    Type elemType = memrefType.getTypeAtIndex(usedIndex);
+    // Every used index is expected to be valid for the memref type; a null
+    // type means an accessor reported an index outside the slot's subelement
+    // map, and no rank-0 memref can be built from it.
+    assert(elemType && "used index must be a valid subelement index");
+    MemRefType elemPtr = MemRefType::get({}, elemType);
+    auto subAlloca = rewriter.create<memref::AllocaOp>(getLoc(), elemPtr);
+    slotMap.try_emplace<MemorySlot>(usedIndex,
+                                    {subAlloca.getResult(), elemType});
+  }
+
+  return slotMap;
+}
+
+/// Called once destructuring of `slot` is finished; erases this alloca, which
+/// must be the slot's pointer.
+void memref::AllocaOp::handleDestructuringComplete(
+    const DestructurableMemorySlot &slot, RewriterBase &rewriter) {
+  assert(slot.ptr == getResult() &&
+         "destructured slot must be the result of this alloca");
+  rewriter.eraseOp(*this);
+}
+
+//===----------------------------------------------------------------------===//
+// Interfaces for LoadOp/StoreOp
+//===----------------------------------------------------------------------===//
+
+/// A load reads from `slot` exactly when its memref operand is the slot pointer.
+bool memref::LoadOp::loadsFrom(const MemorySlot &slot) {
+  Value source = getMemRef();
+  return source == slot.ptr;
+}
+
+/// A load never stores anything, so it returns a null Value for any slot.
+Value memref::LoadOp::getStored(const MemorySlot &slot) {
+  return Value();
+}
+
+/// The load is removable only when its sole blocking use is the slot pointer,
+/// it loads from that pointer, and its result type is the slot element type.
+bool memref::LoadOp::canUsesBeRemoved(
+    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
+    SmallVectorImpl<OpOperand *> &newBlockingUses) {
+  if (blockingUses.size() != 1)
+    return false;
+  OpOperand *onlyUse = *blockingUses.begin();
+  if (onlyUse->get() != slot.ptr || getMemRef() != slot.ptr)
+    return false;
+  return getResult().getType() == slot.elemType;
+}
+
+/// Replaces every use of the loaded value with `reachingDefinition` and asks
+/// for the load to be deleted.
+DeletionKind memref::LoadOp::removeBlockingUses(
+    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
+    RewriterBase &rewriter, Value reachingDefinition) {
+  // `canUsesBeRemoved` already verified that the only blocking use is the slot
+  // pointer this load reads from.
+  Value loaded = getResult();
+  rewriter.replaceAllUsesWith(loaded, reachingDefinition);
+  return DeletionKind::Delete;
+}
+
+/// Converts the index operands of a memref access into an ArrayAttr holding
+/// one IntegerAttr per coordinate. Returns a null Attribute as soon as one of
+/// the operands is not a constant integer.
+static Attribute getAttributeIndexFromIndexOperands(MLIRContext *ctx,
+                                                    ValueRange indices) {
+  SmallVector<Attribute> coords;
+  coords.reserve(indices.size());
+  for (Value indexOperand : indices) {
+    IntegerAttr constantCoord;
+    bool isConstant =
+        matchPattern(indexOperand, m_Constant<IntegerAttr>(&constantCoord));
+    if (!isConstant)
+      return Attribute();
+    coords.push_back(constantCoord);
+  }
+  return ArrayAttr::get(ctx, coords);
+}
+
+/// A load can be rewired onto a subslot when it reads from `slot` through
+/// constant indices naming one of the slot's elements. Indices that are not
+/// in the slot's subelement map (e.g. constant out-of-bounds accesses) are
+/// rejected, matching StoreOp::canRewire: `destructure` would otherwise be
+/// handed an index for which the memref type reports no element type.
+bool memref::LoadOp::canRewire(const DestructurableMemorySlot &slot,
+                               SmallPtrSetImpl<Attribute> &usedIndices,
+                               SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
+  if (slot.ptr != getMemRef())
+    return false;
+  Attribute index =
+      getAttributeIndexFromIndexOperands(getContext(), getIndices());
+  if (!index || !slot.elementPtrs.contains(index))
+    return false;
+  usedIndices.insert(index);
+  return true;
+}
+
+/// Redirects the load to the subslot selected by its constant indices and
+/// drops its index operands; the load itself is kept.
+DeletionKind memref::LoadOp::rewire(const DestructurableMemorySlot &slot,
+                                    DenseMap<Attribute, MemorySlot> &subslots,
+                                    RewriterBase &rewriter) {
+  Attribute elementIndex =
+      getAttributeIndexFromIndexOperands(getContext(), getIndices());
+  Value subslotPtr = subslots.at(elementIndex).ptr;
+  rewriter.updateRootInPlace(*this, [&]() {
+    setMemRef(subslotPtr);
+    getIndicesMutable().clear();
+  });
+  return DeletionKind::Keep;
+}
+
+/// A store never reads from a slot.
+bool memref::StoreOp::loadsFrom(const MemorySlot &slot) {
+  return false;
+}
+
+/// Returns the stored value when this store writes to `slot`'s pointer, and a
+/// null Value otherwise.
+Value memref::StoreOp::getStored(const MemorySlot &slot) {
+  return getMemRef() == slot.ptr ? getValue() : Value();
+}
+
+/// The store is removable only when its sole blocking use is the slot pointer
+/// as destination, it does not store the slot pointer itself, and the stored
+/// value has the slot element type.
+bool memref::StoreOp::canUsesBeRemoved(
+    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
+    SmallVectorImpl<OpOperand *> &newBlockingUses) {
+  if (blockingUses.size() != 1)
+    return false;
+  Value usedValue = (*blockingUses.begin())->get();
+  if (usedValue != slot.ptr || getMemRef() != slot.ptr)
+    return false;
+  Value stored = getValue();
+  return stored != slot.ptr && stored.getType() == slot.elemType;
+}
+
+/// Nothing needs rewriting for a store into a promoted slot; the store is
+/// simply deleted.
+DeletionKind memref::StoreOp::removeBlockingUses(
+    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
+    RewriterBase &rewriter, Value reachingDefinition) {
+  DeletionKind outcome = DeletionKind::Delete;
+  return outcome;
+}
+
+/// A store can be rewired onto a subslot when it writes into `slot` (without
+/// storing the slot pointer itself) through constant indices that name one of
+/// the slot's elements.
+bool memref::StoreOp::canRewire(const DestructurableMemorySlot &slot,
+                                SmallPtrSetImpl<Attribute> &usedIndices,
+                                SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
+  bool writesIntoSlot = getMemRef() == slot.ptr && getValue() != slot.ptr;
+  if (!writesIntoSlot)
+    return false;
+  Attribute index =
+      getAttributeIndexFromIndexOperands(getContext(), getIndices());
+  bool isKnownElement = index && slot.elementPtrs.contains(index);
+  if (isKnownElement)
+    usedIndices.insert(index);
+  return isKnownElement;
+}
+
+/// Redirects the store to the subslot selected by its constant indices and
+/// drops its index operands; the store itself is kept.
+DeletionKind memref::StoreOp::rewire(const DestructurableMemorySlot &slot,
+                                     DenseMap<Attribute, MemorySlot> &subslots,
+                                     RewriterBase &rewriter) {
+  Attribute elementIndex =
+      getAttributeIndexFromIndexOperands(getContext(), getIndices());
+  Value subslotPtr = subslots.at(elementIndex).ptr;
+  rewriter.updateRootInPlace(*this, [&]() {
+    setMemRef(subslotPtr);
+    getIndicesMutable().clear();
+  });
+  return DeletionKind::Keep;
+}
+
+//===----------------------------------------------------------------------===//
+// Interfaces for destructurable types
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// External model letting SROA destructure MemRefType: a memref with a small
+/// static shape is viewed as a set of elements keyed by their coordinates,
+/// each coordinate being an ArrayAttr of `index`-typed IntegerAttrs.
+struct MemRefDestructurableTypeExternalModel
+    : public DestructurableTypeInterface::ExternalModel<
+          MemRefDestructurableTypeExternalModel, MemRefType> {
+  /// Returns a map from every element coordinate to the element type. Returns
+  /// std::nullopt for dynamic shapes, for more than
+  /// `maxMemrefSizeForDestructuring` elements, and for memrefs with at most
+  /// one element. A single element needs no destructuring. A zero-element
+  /// memref has no valid coordinate, and any index reported for it would be
+  /// rejected by `getTypeAtIndex`.
+  std::optional<DenseMap<Attribute, Type>>
+  getSubelementIndexMap(Type type) const {
+    auto memrefType = llvm::cast<MemRefType>(type);
+    constexpr int64_t maxMemrefSizeForDestructuring = 16;
+    // `getNumElements` requires a static shape, so check that first.
+    if (!memrefType.hasStaticShape())
+      return {};
+    int64_t numElements = memrefType.getNumElements();
+    if (numElements <= 1 || numElements > maxMemrefSizeForDestructuring)
+      return {};
+
+    DenseMap<Attribute, Type> destructured;
+    walkIndicesAsAttr(
+        memrefType.getContext(), memrefType.getShape(), [&](Attribute index) {
+          destructured.insert({index, memrefType.getElementType()});
+        });
+
+    return destructured;
+  }
+
+  /// Returns the element type if `index` is an in-bounds coordinate of `type`
+  /// (an ArrayAttr with one `index`-typed IntegerAttr per dimension), and a
+  /// null Type otherwise.
+  Type getTypeAtIndex(Type type, Attribute index) const {
+    auto memrefType = llvm::cast<MemRefType>(type);
+    auto coordArrAttr = llvm::dyn_cast<ArrayAttr>(index);
+    if (!coordArrAttr || coordArrAttr.size() != memrefType.getShape().size())
+      return {};
+
+    Type indexType = IndexType::get(memrefType.getContext());
+    for (const auto &[coordAttr, dimSize] :
+         llvm::zip(coordArrAttr, memrefType.getShape())) {
+      auto coord = llvm::dyn_cast<IntegerAttr>(coordAttr);
+      if (!coord || coord.getType() != indexType || coord.getInt() < 0 ||
+          coord.getInt() >= dimSize)
+        return {};
+    }
+
+    return memrefType.getElementType();
+  }
+};
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Register external models
+//===----------------------------------------------------------------------===//
+
+/// Registers a BuiltinDialect extension that, once the builtin dialect is
+/// loaded in a context, attaches MemRefDestructurableTypeExternalModel to
+/// MemRefType.
+void mlir::memref::registerMemorySlotExternalModels(DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) {
+    MemRefType::attachInterface<MemRefDestructurableTypeExternalModel>(*ctx);
+  });
+}
diff --git a/mlir/test/Dialect/MemRef/sroa.mlir b/mlir/test/Dialect/MemRef/sroa.mlir
new file mode 100644
index 0000000000000..d78053d8ea777
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/sroa.mlir
@@ -0,0 +1,154 @@
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(sroa))" --split-input-file | FileCheck %s
+
+// Stores and a load through constant, in-bounds indices: the 2-element alloca
+// is expected to be split into two rank-0 allocas, one per element.
+// CHECK-LABEL: func.func @basic
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32)
+func.func @basic(%arg0: i32, %arg1: i32) -> i32 {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ // CHECK-NOT: = memref.alloca()
+ // CHECK-COUNT-2: = memref.alloca() : memref<i32>
+ // CHECK-NOT: = memref.alloca()
+ %alloca = memref.alloca() : memref<2xi32>
+ // CHECK: memref.store %[[ARG0]], %[[ALLOCA0:.*]][]
+ memref.store %arg0, %alloca[%c0] : memref<2xi32>
+ // CHECK: memref.store %[[ARG1]], %[[ALLOCA1:.*]][]
+ memref.store %arg1, %alloca[%c1] : memref<2xi32>
+ // CHECK: %[[RES:.*]] = memref.load %[[ALLOCA0]][]
+ %res = memref.load %alloca[%c0] : memref<2xi32>
+ // CHECK: return %[[RES]] : i32
+ return %res : i32
+}
+
+// -----
+
+// A 2x2 alloca with only three of its four elements accessed: only the three
+// accessed elements are expected to get their own rank-0 alloca.
+// CHECK-LABEL: func.func @basic_high_dimensions
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32)
+func.func @basic_high_dimensions(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ // CHECK-NOT: = memref.alloca()
+ // CHECK-COUNT-3: = memref.alloca() : memref<i32>
+ // CHECK-NOT: = memref.alloca()
+ %alloca = memref.alloca() : memref<2x2xi32>
+ // CHECK: memref.store %[[ARG0]], %[[ALLOCA0:.*]][]
+ memref.store %arg0, %alloca[%c0, %c0] : memref<2x2xi32>
+ // CHECK: memref.store %[[ARG1]], %[[ALLOCA1:.*]][]
+ memref.store %arg1, %alloca[%c0, %c1] : memref<2x2xi32>
+ // CHECK: memref.store %[[ARG2]], %[[ALLOCA2:.*]][]
+ memref.store %arg2, %alloca[%c1, %c0] : memref<2x2xi32>
+ // CHECK: %[[RES:.*]] = memref.load %[[ALLOCA1]][]
+ %res = memref.load %alloca[%c0, %c1] : memref<2x2xi32>
+ // CHECK: return %[[RES]] : i32
+ return %res : i32
+}
+
+// -----
+
+// All accesses use the same constant index, so a single rank-0 alloca is
+// expected to serve both stores and the load.
+// CHECK-LABEL: func.func @resolve_alias
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32)
+func.func @resolve_alias(%arg0: i32, %arg1: i32) -> i32 {
+ %c0 = arith.constant 0 : index
+ // CHECK-NOT: = memref.alloca()
+ // CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref<i32>
+ // CHECK-NOT: = memref.alloca()
+ %alloca = memref.alloca() : memref<2xi32>
+ // CHECK: memref.store %[[ARG0]], %[[ALLOCA]][]
+ memref.store %arg0, %alloca[%c0] : memref<2xi32>
+ // CHECK: memref.store %[[ARG1]], %[[ALLOCA]][]
+ memref.store %arg1, %alloca[%c0] : memref<2xi32>
+ // CHECK: %[[RES:.*]] = memref.load %[[ALLOCA]][]
+ %res = memref.load %alloca[%c0] : memref<2xi32>
+ // CHECK: return %[[RES]] : i32
+ return %res : i32
+}
+
+// -----
+
+// The whole alloca is passed as an operand to a call, so the alloca and its
+// indexed accesses are expected to be left unchanged.
+// CHECK-LABEL: func.func @no_direct_use
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32)
+func.func @no_direct_use(%arg0: i32, %arg1: i32) -> i32 {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ %c0 = arith.constant 0 : index
+ // CHECK: %[[C1:.*]] = arith.constant 1 : index
+ %c1 = arith.constant 1 : index
+ // CHECK-NOT: = memref.alloca()
+ // CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref<2xi32>
+ // CHECK-NOT: = memref.alloca()
+ %alloca = memref.alloca() : memref<2xi32>
+ // CHECK: memref.store %[[ARG0]], %[[ALLOCA]][%[[C0]]]
+ memref.store %arg0, %alloca[%c0] : memref<2xi32>
+ // CHECK: memref.store %[[ARG1]], %[[ALLOCA]][%[[C1]]]
+ memref.store %arg1, %alloca[%c1] : memref<2xi32>
+ // CHECK: %[[RES:.*]] = memref.load %[[ALLOCA]][%[[C0]]]
+ %res = memref.load %alloca[%c0] : memref<2xi32>
+ call @use(%alloca) : (memref<2xi32>) -> ()
+ // CHECK: return %[[RES]] : i32
+ return %res : i32
+}
+
+// Callee used by @no_direct_use to take the whole alloca as an operand.
+func.func @use(%foo: memref<2xi32>) { return }
+
+// -----
+
+// One store goes through a non-constant index, so the alloca is expected not
+// to be destructured.
+// CHECK-LABEL: func.func @no_dynamic_indexing
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[INDEX:.*]]: index)
+func.func @no_dynamic_indexing(%arg0: i32, %arg1: i32, %index: index) -> i32 {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ %c0 = arith.constant 0 : index
+ // CHECK-NOT: = memref.alloca()
+ // CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref<2xi32>
+ // CHECK-NOT: = memref.alloca()
+ %alloca = memref.alloca() : memref<2xi32>
+ // CHECK: memref.store %[[ARG0]], %[[ALLOCA]][%[[C0]]]
+ memref.store %arg0, %alloca[%c0] : memref<2xi32>
+ // CHECK: memref.store %[[ARG1]], %[[ALLOCA]][%[[INDEX]]]
+ memref.store %arg1, %alloca[%index] : memref<2xi32>
+ // CHECK: %[[RES:.*]] = memref.load %[[ALLOCA]][%[[C0]]]
+ %res = memref.load %alloca[%c0] : memref<2xi32>
+ // CHECK: return %[[RES]] : i32
+ return %res : i32
+}
+
+// -----
+
+// The alloca has a dynamic dimension, so it is expected not to be
+// destructured even though all indices are constant.
+// CHECK-LABEL: func.func @no_dynamic_shape
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32)
+func.func @no_dynamic_shape(%arg0: i32, %arg1: i32) -> i32 {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ %c0 = arith.constant 0 : index
+ // CHECK: %[[C1:.*]] = arith.constant 1 : index
+ %c1 = arith.constant 1 : index
+ // CHECK-NOT: = memref.alloca()
+ // CHECK: %[[ALLOCA:.*]] = memref.alloca(%[[C1]]) : memref<?x2xi32>
+ // CHECK-NOT: = memref.alloca()
+ %alloca = memref.alloca(%c1) : memref<?x2xi32>
+ // CHECK: memref.store %[[ARG0]], %[[ALLOCA]][%[[C0]], %[[C0]]]
+ memref.store %arg0, %alloca[%c0, %c0] : memref<?x2xi32>
+ // CHECK: %[[RES:.*]] = memref.load %[[ALLOCA]][%[[C0]], %[[C0]]]
+ %res = memref.load %alloca[%c0, %c0] : memref<?x2xi32>
+ // CHECK: return %[[RES]] : i32
+ return %res : i32
+}
+
+// -----
+
+// A store at constant index 100 into a 2-element alloca is outside the
+// alloca's subelements, so the alloca is expected not to be destructured.
+// CHECK-LABEL: func.func @no_out_of_bounds
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32)
+func.func @no_out_of_bounds(%arg0: i32, %arg1: i32) -> i32 {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ %c0 = arith.constant 0 : index
+ // CHECK: %[[C100:.*]] = arith.constant 100 : index
+ %c100 = arith.constant 100 : index
+ // CHECK-NOT: = memref.alloca()
+ // CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref<2xi32>
+ // CHECK-NOT: = memref.alloca()
+ %alloca = memref.alloca() : memref<2xi32>
+ // CHECK: memref.store %[[ARG0]], %[[ALLOCA]][%[[C0]]]
+ memref.store %arg0, %alloca[%c0] : memref<2xi32>
+ // CHECK: memref.store %[[ARG1]], %[[ALLOCA]][%[[C100]]]
+ memref.store %arg1, %alloca[%c100] : memref<2xi32>
+ // CHECK: %[[RES:.*]] = memref.load %[[ALLOCA]][%[[C0]]]
+ %res = memref.load %alloca[%c0] : memref<2xi32>
+ // CHECK: return %[[RES]] : i32
+ return %res : i32
+}
More information about the Mlir-commits
mailing list