[Mlir-commits] [mlir] [MLIR] Add more ops support for flattening memref operands (PR #159841)
llvmlistbot at llvm.org
Fri Sep 19 13:03:13 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir-memref
Author: Alan Li (lialan)
<details>
<summary>Changes</summary>
This patch makes the `flatten-memref` pass more complete by adding support for the memref operands of more ops.
Some patterns are from: https://github.com/iree-org/iree/blob/main/compiler/src/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp
---
Patch is 34.85 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/159841.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h (+99)
- (modified) mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp (+18-54)
- (modified) mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp (+296-5)
- (modified) mlir/test/Dialect/MemRef/flatten_memref.mlir (+128)
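For context, this is roughly what the pass already does for the ops it supports: a multi-dimensional access is rewritten into an access on a rank-1 view with a linearized index. A minimal sketch follows; the function and value names are illustrative only, and the exact ops the pass emits for the index arithmetic may differ. The patch extends the same treatment to the additional ops touched in the diff below.

```mlir
// Input: a rank-2 load. After -flatten-memref, %mem is accessed through a
// rank-1 memref<32xf32> view and the index becomes roughly %i * 8 + %j.
func.func @load_2d(%mem: memref<4x8xf32>, %i: index, %j: index) -> f32 {
  %v = memref.load %mem[%i, %j] : memref<4x8xf32>
  return %v : f32
}
```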
``````````diff
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index 8b76930aed35a..562b8c11225e8 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -14,10 +14,15 @@
#ifndef MLIR_DIALECT_MEMREF_TRANSFORMS_TRANSFORMS_H
#define MLIR_DIALECT_MEMREF_TRANSFORMS_TRANSFORMS_H
+#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/STLFunctionalExtras.h"
namespace mlir {
+class Location;
class OpBuilder;
class RewritePatternSet;
class RewriterBase;
@@ -33,7 +38,9 @@ class NarrowTypeEmulationConverter;
namespace memref {
class AllocOp;
class AllocaOp;
+class CollapseShapeOp;
class DeallocOp;
+class ExpandShapeOp;
//===----------------------------------------------------------------------===//
// Patterns
@@ -213,6 +220,98 @@ FailureOr<Value> replaceWithIndependentOp(RewriterBase &rewriter,
memref::AllocaOp allocToAlloca(
RewriterBase &rewriter, memref::AllocOp alloc,
function_ref<bool(memref::AllocOp, memref::DeallocOp)> filter = nullptr);
+
+/// Compute the expanded sizes of the given \p expandShape for the
+/// \p groupId-th reassociation group.
+/// \p origSizes hold the sizes of the source shape as values.
+/// This is used to compute the new sizes in cases of dynamic shapes.
+///
+/// sizes#i =
+/// baseSizes#groupId / product(expandShapeSizes#j,
+/// for j in group excluding reassIdx#i)
+/// Where reassIdx#i is the reassociation index at index i in \p groupId.
+///
+/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
+///
+/// TODO: Move this utility function directly within ExpandShapeOp. For now,
+/// this is not possible because this function uses the Affine dialect and the
+/// MemRef dialect cannot depend on the Affine dialect.
+SmallVector<OpFoldResult>
+getExpandedSizes(ExpandShapeOp expandShape, OpBuilder &builder,
+ ArrayRef<OpFoldResult> origSizes, unsigned groupId);
+
+/// Compute the expanded strides of the given \p expandShape for the
+/// \p groupId-th reassociation group.
+/// \p origStrides and \p origSizes hold respectively the strides and sizes
+/// of the source shape as values.
+/// This is used to compute the strides in cases of dynamic shapes and/or
+/// dynamic stride for this reassociation group.
+///
+/// strides#i =
+/// origStrides#reassDim * product(expandShapeSizes#j, for j in
+/// reassIdx#i+1..reassIdx#i+group.size-1)
+///
+/// Where reassIdx#i is the reassociation index at index i in \p groupId
+/// and expandShapeSizes#j is either:
+/// - The constant size at dimension j, derived directly from the result type of
+/// the expand_shape op, or
+/// - An affine expression: baseSizes#reassDim / product of all constant sizes
+/// in expandShapeSizes. (Remember expandShapeSizes has at most one dynamic
+/// element.)
+///
+/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
+///
+/// TODO: Move this utility function directly within ExpandShapeOp. For now,
+/// this is not possible because this function uses the Affine dialect and the
+/// MemRef dialect cannot depend on the Affine dialect.
+SmallVector<OpFoldResult> getExpandedStrides(ExpandShapeOp expandShape,
+ OpBuilder &builder,
+ ArrayRef<OpFoldResult> origSizes,
+ ArrayRef<OpFoldResult> origStrides,
+ unsigned groupId);
+
+/// Produce an OpFoldResult object with \p builder at \p loc representing
+/// `prod(valueOrConstant#i, for i in {indices})`,
+/// where valueOrConstant#i is maybeConstants[i] when \p isDynamic is false,
+/// values[i] otherwise.
+///
+/// \pre for all index in indices: index < values.size()
+/// \pre for all index in indices: index < maybeConstants.size()
+OpFoldResult
+getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
+ ArrayRef<int64_t> maybeConstants,
+ ArrayRef<OpFoldResult> values,
+ llvm::function_ref<bool(int64_t)> isDynamic);
+
+/// Compute the collapsed size of the given \p collapseShape for the
+/// \p groupId-th reassociation group.
+/// \p origSizes hold the sizes of the source shape as values.
+/// This is used to compute the new sizes in cases of dynamic shapes.
+///
+/// TODO: Move this utility function directly within CollapseShapeOp. For now,
+/// this is not possible because this function uses the Affine dialect and the
+/// MemRef dialect cannot depend on the Affine dialect.
+SmallVector<OpFoldResult>
+getCollapsedSize(CollapseShapeOp collapseShape, OpBuilder &builder,
+ ArrayRef<OpFoldResult> origSizes, unsigned groupId);
+
+/// Compute the collapsed stride of the given \p collapseShape for the
+/// \p groupId-th reassociation group.
+/// \p origStrides and \p origSizes hold respectively the strides and sizes
+/// of the source shape as values.
+/// This is used to compute the strides in cases of dynamic shapes and/or
+/// dynamic stride for this reassociation group.
+///
+/// Conceptually, this helper function returns the stride of the innermost
+/// dimension of that group in the original shape.
+///
+/// \post result.size() == 1; in other words, each group collapses to one
+/// dimension.
+SmallVector<OpFoldResult>
+getCollapsedStride(CollapseShapeOp collapseShape, OpBuilder &builder,
+ ArrayRef<OpFoldResult> origSizes,
+ ArrayRef<OpFoldResult> origStrides, unsigned groupId);
+
} // namespace memref
} // namespace mlir
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index d35566a9c0d29..6b69d0e366903 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -24,6 +24,7 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
+#include "mlir/IR/OpDefinition.h"
#include <optional>
namespace mlir {
@@ -35,6 +36,7 @@ namespace memref {
using namespace mlir;
using namespace mlir::affine;
+using namespace mlir::memref;
namespace {
@@ -250,23 +252,12 @@ struct ExtractStridedMetadataOpSubviewFolder
}
};
-/// Compute the expanded sizes of the given \p expandShape for the
-/// \p groupId-th reassociation group.
-/// \p origSizes hold the sizes of the source shape as values.
-/// This is used to compute the new sizes in cases of dynamic shapes.
-///
-/// sizes#i =
-/// baseSizes#groupId / product(expandShapeSizes#j,
-/// for j in group excluding reassIdx#i)
-/// Where reassIdx#i is the reassociation index at index i in \p groupId.
-///
-/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
-///
-/// TODO: Move this utility function directly within ExpandShapeOp. For now,
-/// this is not possible because this function uses the Affine dialect and the
-/// MemRef dialect cannot depend on the Affine dialect.
-static SmallVector<OpFoldResult>
-getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
+} // namespace
+
+namespace mlir {
+namespace memref {
+SmallVector<OpFoldResult>
+getExpandedSizes(ExpandShapeOp expandShape, OpBuilder &builder,
ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
SmallVector<int64_t, 2> reassocGroup =
expandShape.getReassociationIndices()[groupId];
@@ -305,31 +296,7 @@ getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
return expandedSizes;
}
-/// Compute the expanded strides of the given \p expandShape for the
-/// \p groupId-th reassociation group.
-/// \p origStrides and \p origSizes hold respectively the strides and sizes
-/// of the source shape as values.
-/// This is used to compute the strides in cases of dynamic shapes and/or
-/// dynamic stride for this reassociation group.
-///
-/// strides#i =
-/// origStrides#reassDim * product(expandShapeSizes#j, for j in
-/// reassIdx#i+1..reassIdx#i+group.size-1)
-///
-/// Where reassIdx#i is the reassociation index for at index i in \p groupId
-/// and expandShapeSizes#j is either:
-/// - The constant size at dimension j, derived directly from the result type of
-/// the expand_shape op, or
-/// - An affine expression: baseSizes#reassDim / product of all constant sizes
-/// in expandShapeSizes. (Remember expandShapeSizes has at most one dynamic
-/// element.)
-///
-/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
-///
-/// TODO: Move this utility function directly within ExpandShapeOp. For now,
-/// this is not possible because this function uses the Affine dialect and the
-/// MemRef dialect cannot depend on the Affine dialect.
-SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
+SmallVector<OpFoldResult> getExpandedStrides(ExpandShapeOp expandShape,
OpBuilder &builder,
ArrayRef<OpFoldResult> origSizes,
ArrayRef<OpFoldResult> origStrides,
@@ -405,14 +372,7 @@ SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
return expandedStrides;
}
-/// Produce an OpFoldResult object with \p builder at \p loc representing
-/// `prod(valueOrConstant#i, for i in {indices})`,
-/// where valueOrConstant#i is maybeConstant[i] when \p isDymamic is false,
-/// values[i] otherwise.
-///
-/// \pre for all index in indices: index < values.size()
-/// \pre for all index in indices: index < maybeConstants.size()
-static OpFoldResult
+OpFoldResult
getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
ArrayRef<int64_t> maybeConstants,
ArrayRef<OpFoldResult> values,
@@ -450,8 +410,8 @@ getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
/// TODO: Move this utility function directly within CollapseShapeOp. For now,
/// this is not possible because this function uses the Affine dialect and the
/// MemRef dialect cannot depend on the Affine dialect.
-static SmallVector<OpFoldResult>
-getCollapsedSize(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
+SmallVector<OpFoldResult>
+getCollapsedSize(CollapseShapeOp collapseShape, OpBuilder &builder,
ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
SmallVector<OpFoldResult> collapsedSize;
@@ -491,8 +451,8 @@ getCollapsedSize(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
///
/// \post result.size() == 1, in other words, each group collapse to one
/// dimension.
-static SmallVector<OpFoldResult>
-getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
+SmallVector<OpFoldResult>
+getCollapsedStride(CollapseShapeOp collapseShape, OpBuilder &builder,
ArrayRef<OpFoldResult> origSizes,
ArrayRef<OpFoldResult> origStrides, unsigned groupId) {
SmallVector<int64_t, 2> reassocGroup =
@@ -546,6 +506,10 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
return {lastValidStride};
}
+} // namespace memref
+} // namespace mlir
+
+namespace {
/// From `reshape_like(memref, subSizes, subStrides))` compute
///
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
index 1208fddf37e0b..43a67f1fab2be 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
@@ -21,11 +21,13 @@
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Attributes.h"
+#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
namespace mlir {
@@ -46,6 +48,7 @@ static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc,
return cast<Value>(in);
}
+
/// Returns a collapsed memref and the linearized index to access the element
/// at the specified indices.
static std::pair<Value, Value> getFlattenMemrefAndOffset(OpBuilder &rewriter,
@@ -89,17 +92,21 @@ static bool needFlattening(Value val) {
return type.getRank() > 1;
}
-static bool checkLayout(Value val) {
- auto type = cast<MemRefType>(val.getType());
+static bool checkLayout(MemRefType type) {
return type.getLayout().isIdentity() ||
isa<StridedLayoutAttr>(type.getLayout());
}
+static bool checkLayout(Value val) {
+ return checkLayout(cast<MemRefType>(val.getType()));
+}
+
namespace {
static Value getTargetMemref(Operation *op) {
return llvm::TypeSwitch<Operation *, Value>(op)
.template Case<memref::LoadOp, memref::StoreOp, memref::AllocaOp,
- memref::AllocOp>([](auto op) { return op.getMemref(); })
+ memref::AllocOp, memref::DeallocOp>(
+ [](auto op) { return op.getMemref(); })
.template Case<vector::LoadOp, vector::StoreOp, vector::MaskedLoadOp,
vector::MaskedStoreOp, vector::TransferReadOp,
vector::TransferWriteOp>(
@@ -189,6 +196,10 @@ static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref,
rewriter, loc, op.getVector(), flatMemref, ValueRange{offset});
rewriter.replaceOp(op, newTransferWrite);
})
+ .template Case<memref::DeallocOp>([&](auto op) {
+ auto newDealloc = memref::DeallocOp::create(rewriter, loc, flatMemref);
+ rewriter.replaceOp(op, newDealloc);
+ })
.Default([&](auto op) {
op->emitOpError("unimplemented: do not know how to replace op.");
});
@@ -197,7 +208,8 @@ static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref,
template <typename T>
static ValueRange getIndices(T op) {
if constexpr (std::is_same_v<T, memref::AllocaOp> ||
- std::is_same_v<T, memref::AllocOp>) {
+ std::is_same_v<T, memref::AllocOp> ||
+ std::is_same_v<T, memref::DeallocOp>) {
return ValueRange{};
} else {
return op.getIndices();
@@ -250,6 +262,243 @@ struct MemRefRewritePattern : public OpRewritePattern<T> {
}
};
+/// Flattens memref global ops with more than one dimension to one dimension.
+struct FlattenGlobal final : public OpRewritePattern<memref::GlobalOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ static Attribute flattenAttribute(Attribute value, ShapedType newType) {
+ if (!value)
+ return value;
+ if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(value)) {
+ return splatAttr.reshape(newType);
+ } else if (auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(value)) {
+ return denseAttr.reshape(newType);
+ } else if (auto denseResourceAttr =
+ llvm::dyn_cast<DenseResourceElementsAttr>(value)) {
+ return DenseResourceElementsAttr::get(newType,
+ denseResourceAttr.getRawHandle());
+ }
+ return {};
+ }
+
+ LogicalResult
+ matchAndRewrite(memref::GlobalOp globalOp,
+ PatternRewriter &rewriter) const override {
+ auto oldType = llvm::dyn_cast<MemRefType>(globalOp.getType());
+ if (!oldType || !oldType.getLayout().isIdentity() || oldType.getRank() <= 1)
+ return failure();
+
+ auto tensorType = RankedTensorType::get({oldType.getNumElements()},
+ oldType.getElementType());
+ auto memRefType =
+ MemRefType::get({oldType.getNumElements()}, oldType.getElementType(),
+ AffineMap(), oldType.getMemorySpace());
+ auto newInitialValue =
+ flattenAttribute(globalOp.getInitialValueAttr(), tensorType);
+ rewriter.replaceOpWithNewOp<memref::GlobalOp>(
+ globalOp, globalOp.getSymName(), globalOp.getSymVisibilityAttr(),
+ memRefType, newInitialValue, globalOp.getConstant(),
+ /*alignment=*/IntegerAttr());
+ return success();
+ }
+};
+
+struct FlattenCollapseShape final
+ : public OpRewritePattern<memref::CollapseShapeOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(memref::CollapseShapeOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ memref::ExtractStridedMetadataOp metadata =
+ memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getSrc());
+
+ SmallVector<OpFoldResult> origSizes = metadata.getConstifiedMixedSizes();
+ SmallVector<OpFoldResult> origStrides = metadata.getConstifiedMixedStrides();
+ OpFoldResult offset = metadata.getConstifiedMixedOffset();
+
+ SmallVector<OpFoldResult> collapsedSizes;
+ SmallVector<OpFoldResult> collapsedStrides;
+ unsigned numGroups = op.getReassociationIndices().size();
+ collapsedSizes.reserve(numGroups);
+ collapsedStrides.reserve(numGroups);
+ for (unsigned i = 0; i < numGroups; ++i) {
+ SmallVector<OpFoldResult> groupSizes =
+ memref::getCollapsedSize(op, rewriter, origSizes, i);
+ SmallVector<OpFoldResult> groupStrides =
+ memref::getCollapsedStride(op, rewriter, origSizes, origStrides, i);
+ collapsedSizes.append(groupSizes.begin(), groupSizes.end());
+ collapsedStrides.append(groupStrides.begin(), groupStrides.end());
+ }
+
+ rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+ op, op.getType(), op.getSrc(), offset, collapsedSizes,
+ collapsedStrides);
+ return success();
+ }
+};
+
+struct FlattenExpandShape final : public OpRewritePattern<memref::ExpandShapeOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(memref::ExpandShapeOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ memref::ExtractStridedMetadataOp metadata =
+ memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getSrc());
+
+ SmallVector<OpFoldResult> origSizes = metadata.getConstifiedMixedSizes();
+ SmallVector<OpFoldResult> origStrides = metadata.getConstifiedMixedStrides();
+ OpFoldResult offset = metadata.getConstifiedMixedOffset();
+
+ SmallVector<OpFoldResult> expandedSizes;
+ SmallVector<OpFoldResult> expandedStrides;
+ unsigned numGroups = op.getReassociationIndices().size();
+ expandedSizes.reserve(op.getResultType().getRank());
+ expandedStrides.reserve(op.getResultType().getRank());
+
+ for (unsigned i = 0; i < numGroups; ++i) {
+ SmallVector<OpFoldResult> groupSizes =
+ memref::getExpandedSizes(op, rewriter, origSizes, i);
+ SmallVector<OpFoldResult> groupStrides =
+ memref::getExpandedStrides(op, rewriter, origSizes, origStrides, i);
+ expandedSizes.append(groupSizes.begin(), groupSizes.end());
+ expandedStrides.append(groupStrides.begin(), groupStrides.end());
+ }
+
+ rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+ op, op.getType(), op.getSrc(), offset, expandedSizes, expandedStrides);
+ return success();
+ }
+};
+
+
+// Flattens memref subview ops with more than 1 dimension into 1-D accesses.
+struct FlattenSubView final : public OpRewritePattern<memref::SubViewOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(memref::SubViewOp op,
+ PatternRewriter &rewriter) const override {
+ auto sourceType = dyn_cast<MemRefType>(op.getSource().getType());
+ if (!sourceType || sourceType.getRank() <= 1)
+ return failure();
+ if (!checkLayout(sourceType))
+ return failure();
+
+ MemRefType resultType = op.getType();
+ if (resultType.getRank() <= 1 || !checkLayout(res...
[truncated]
``````````
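To make the new reshape patterns above concrete: `FlattenCollapseShape` (and, analogously, `FlattenExpandShape`) materializes the source metadata with `memref.extract_strided_metadata`, computes per-group sizes and strides with the utilities now declared in the header, and replaces the reshape with a `memref.reinterpret_cast`. A before/after sketch for a fully static case, where the constified metadata folds to constants (dynamic shapes would instead go through affine computations on the extracted sizes and strides):

```mlir
// Before: collapse both dimensions into one.
%flat = memref.collapse_shape %src [[0, 1]]
    : memref<4x8xf32> into memref<32xf32>

// After (sketch): the same view expressed as an explicit strided cast.
%flat = memref.reinterpret_cast %src to offset: [0], sizes: [32], strides: [1]
    : memref<4x8xf32> to memref<32xf32>
```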
</details>
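Similarly, the `FlattenGlobal` pattern in the diff rewrites a multi-dimensional `memref.global` into a rank-1 global whose shape is the total element count, reshaping the initial value to match. A sketch, assuming a static identity-layout type (note that the rewrite drops any alignment attribute, as the pattern does):

```mlir
// Before: a rank-2 constant global.
memref.global "private" constant @cst : memref<2x3xf32> =
    dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]>

// After (sketch): a rank-1 global with the initial value reshaped.
memref.global "private" constant @cst : memref<6xf32> =
    dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]>
```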
https://github.com/llvm/llvm-project/pull/159841