[Mlir-commits] [mlir] [MLIR] Add more ops support for flattening memref operands (PR #159841)
Alan Li
llvmlistbot at llvm.org
Fri Sep 19 14:59:08 PDT 2025
https://github.com/lialan updated https://github.com/llvm/llvm-project/pull/159841
>From bc742d1f6289ba275a0b61e34654dc198846c6e0 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Tue, 16 Sep 2025 20:57:41 -0400
Subject: [PATCH 1/6] Flatten memref.dealloc operands
---
.../MemRef/Transforms/FlattenMemRefs.cpp | 13 +++++--
mlir/test/Dialect/MemRef/flatten_memref.mlir | 38 +++++++++++++++++++
2 files changed, 48 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
index 1208fddf37e0b..e2d5d62f64750 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
@@ -99,7 +99,8 @@ namespace {
static Value getTargetMemref(Operation *op) {
return llvm::TypeSwitch<Operation *, Value>(op)
.template Case<memref::LoadOp, memref::StoreOp, memref::AllocaOp,
- memref::AllocOp>([](auto op) { return op.getMemref(); })
+ memref::AllocOp, memref::DeallocOp>(
+ [](auto op) { return op.getMemref(); })
.template Case<vector::LoadOp, vector::StoreOp, vector::MaskedLoadOp,
vector::MaskedStoreOp, vector::TransferReadOp,
vector::TransferWriteOp>(
@@ -189,6 +190,10 @@ static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref,
rewriter, loc, op.getVector(), flatMemref, ValueRange{offset});
rewriter.replaceOp(op, newTransferWrite);
})
+ .template Case<memref::DeallocOp>([&](auto op) {
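+        // Dealloc takes no indices; it just frees the flattened buffer.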
+ auto newDealloc = memref::DeallocOp::create(rewriter, loc, flatMemref);
+ rewriter.replaceOp(op, newDealloc);
+ })
.Default([&](auto op) {
op->emitOpError("unimplemented: do not know how to replace op.");
});
@@ -197,7 +202,8 @@ static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref,
template <typename T>
static ValueRange getIndices(T op) {
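+  // Allocation and deallocation ops carry no indices to linearize.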
if constexpr (std::is_same_v<T, memref::AllocaOp> ||
- std::is_same_v<T, memref::AllocOp>) {
+ std::is_same_v<T, memref::AllocOp> ||
+ std::is_same_v<T, memref::DeallocOp>) {
return ValueRange{};
} else {
return op.getIndices();
@@ -286,7 +292,8 @@ void memref::populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns) {
patterns.insert<MemRefRewritePattern<memref::LoadOp>,
MemRefRewritePattern<memref::StoreOp>,
MemRefRewritePattern<memref::AllocOp>,
- MemRefRewritePattern<memref::AllocaOp>>(
+ MemRefRewritePattern<memref::AllocaOp>,
+ MemRefRewritePattern<memref::DeallocOp>>(
patterns.getContext());
}
diff --git a/mlir/test/Dialect/MemRef/flatten_memref.mlir b/mlir/test/Dialect/MemRef/flatten_memref.mlir
index e45a10ca0d431..88f46b07dad93 100644
--- a/mlir/test/Dialect/MemRef/flatten_memref.mlir
+++ b/mlir/test/Dialect/MemRef/flatten_memref.mlir
@@ -298,3 +298,41 @@ func.func @load_scalar_from_memref_static_dim_col_major(%input: memref<4x8xf32,
// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG1]]]
// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [100], sizes: [32], strides: [1] : memref<4x8xf32, strided<[1, 4], offset: 100>> to memref<32xf32, strided<[1], offset: 100>>
// CHECK: memref.load %[[REINT]][%[[IDX]]] : memref<32xf32, strided<[1], offset: 100>>
+
+// -----
+
+func.func @dealloc_static_memref(%input: memref<4x8xf32>) {
+ memref.dealloc %input : memref<4x8xf32>
+ return
+}
+
+// CHECK-LABEL: func @dealloc_static_memref
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xf32>)
+// CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [32], strides: [1] : memref<4x8xf32> to memref<32xf32, strided<[1]>>
+// CHECK-NEXT: memref.dealloc %[[REINT]] : memref<32xf32, strided<[1]>>
+
+// -----
+
+func.func @dealloc_dynamic_memref(%input: memref<?x?xf32>) {
+ memref.dealloc %input : memref<?x?xf32>
+ return
+}
+
+// CHECK-LABEL: func @dealloc_dynamic_memref
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>)
+// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG0]]
+// CHECK: %[[SIZE:.*]] = affine.max #{{.*}}()[%[[STRIDES]]#0, %[[SIZES]]#0, %[[SIZES]]#1]
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [%[[SIZE]]], strides: [1] : memref<?x?xf32> to memref<?xf32, strided<[1]>>
+// CHECK: memref.dealloc %[[REINT]] : memref<?xf32, strided<[1]>>
+
+// -----
+
+func.func @dealloc_strided_memref(%input: memref<4x8xf32, strided<[8, 1], offset: 100>>) {
+ memref.dealloc %input : memref<4x8xf32, strided<[8, 1], offset: 100>>
+ return
+}
+
+// CHECK-LABEL: func @dealloc_strided_memref
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xf32, strided<[8, 1], offset: 100>>)
+// CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [100], sizes: [32], strides: [1] : memref<4x8xf32, strided<[8, 1], offset: 100>> to memref<32xf32, strided<[1], offset: 100>>
+// CHECK-NEXT: memref.dealloc %[[REINT]] : memref<32xf32, strided<[1], offset: 100>>
>From f549e4fbc6215d775bd67d1ec1d679378d111e3a Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Wed, 17 Sep 2025 09:30:55 -0400
Subject: [PATCH 2/6] Flatten memref.global, get_global and collapse_shape ops
---
.../MemRef/Transforms/FlattenMemRefs.cpp | 212 +++++++++++++++++-
mlir/test/Dialect/MemRef/flatten_memref.mlir | 48 ++++
2 files changed, 259 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
index e2d5d62f64750..e0dbc8d8641bd 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
@@ -21,6 +21,7 @@
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Attributes.h"
+#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
@@ -95,6 +96,99 @@ static bool checkLayout(Value val) {
isa<StridedLayoutAttr>(type.getLayout());
}
+/// Produce an OpFoldResult representing the product of the values or constants
+/// referenced by `indices`. `staticShape` provides the statically known sizes
+/// for the source memref, while `values` contains the mixed (value/attribute)
+/// representation produced by `memref.extract_strided_metadata`.
+static OpFoldResult getProductOfValues(ArrayRef<int64_t> indices,
+ OpBuilder &builder, Location loc,
+ ArrayRef<int64_t> staticShape,
+ ArrayRef<OpFoldResult> values) {
+ AffineExpr product = builder.getAffineConstantExpr(1);
+ SmallVector<OpFoldResult> inputs;
+ unsigned numSymbols = 0;
+ for (int64_t idx : indices) {
+ product = product * builder.getAffineSymbolExpr(numSymbols++);
+ if (ShapedType::isDynamic(staticShape[idx]))
+ inputs.push_back(values[idx]);
+ else
+ inputs.push_back(builder.getIndexAttr(staticShape[idx]));
+ }
+ return affine::makeComposedFoldedAffineApply(builder, loc, product, inputs);
+}
+
+/// Return the collapsed size (as OpFoldResult) for the reassociation group
+/// `groupId` of `collapseShapeOp`.
+static SmallVector<OpFoldResult>
+getCollapsedSize(memref::CollapseShapeOp collapseShapeOp, OpBuilder &builder,
+ ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
+ SmallVector<OpFoldResult> collapsedSize;
+
+ MemRefType resultType = collapseShapeOp.getResultType();
+ int64_t dimSize = resultType.getDimSize(groupId);
+ if (!ShapedType::isDynamic(dimSize)) {
+ collapsedSize.push_back(builder.getIndexAttr(dimSize));
+ return collapsedSize;
+ }
+
+ auto sourceType = collapseShapeOp.getSrcType();
+ ArrayRef<int64_t> staticShape = sourceType.getShape();
+ ArrayRef<int64_t> reassocGroup =
+ collapseShapeOp.getReassociationIndices()[groupId];
+
+ collapsedSize.push_back(getProductOfValues(reassocGroup, builder,
+ collapseShapeOp.getLoc(),
+ staticShape, origSizes));
+ return collapsedSize;
+}
+
+/// Return the collapsed stride (as OpFoldResult) for the reassociation group
+/// `groupId` of `collapseShapeOp`.
+static SmallVector<OpFoldResult> getCollapsedStride(
+ memref::CollapseShapeOp collapseShapeOp, OpBuilder &builder,
+ ArrayRef<OpFoldResult> origSizes, ArrayRef<OpFoldResult> origStrides,
+ unsigned groupId) {
+ ArrayRef<int64_t> reassocGroup =
+ collapseShapeOp.getReassociationIndices()[groupId];
+ assert(!reassocGroup.empty() &&
+ "reassociation group must contain at least one dimension");
+
+ auto sourceType = collapseShapeOp.getSrcType();
+ auto [strides, offset] = sourceType.getStridesAndOffset();
+ (void)offset;
+ ArrayRef<int64_t> srcShape = sourceType.getShape();
+
+ OpFoldResult lastValidStride = nullptr;
+ for (int64_t dim : reassocGroup) {
+ if (srcShape[dim] == 1)
+ continue;
+ int64_t currentStride = strides[dim];
+ if (ShapedType::isDynamic(currentStride))
+ lastValidStride = origStrides[dim];
+ else
+ lastValidStride = builder.getIndexAttr(currentStride);
+ }
+
+ if (!lastValidStride) {
+ MemRefType collapsedType = collapseShapeOp.getResultType();
+ auto [collapsedStrides, collapsedOffset] =
+ collapsedType.getStridesAndOffset();
+ (void)collapsedOffset;
+ int64_t finalStride = collapsedStrides[groupId];
+ if (ShapedType::isDynamic(finalStride)) {
+ for (int64_t dim : reassocGroup) {
+ assert(srcShape[dim] == 1 && "expected size-one dimensions");
+ if (ShapedType::isDynamic(strides[dim]))
+ return {origStrides[dim]};
+ }
+ llvm_unreachable("expected to find a dynamic stride");
+ }
+ return {builder.getIndexAttr(finalStride)};
+ }
+
+ return {lastValidStride};
+}
+
namespace {
static Value getTargetMemref(Operation *op) {
return llvm::TypeSwitch<Operation *, Value>(op)
@@ -256,6 +350,82 @@ struct MemRefRewritePattern : public OpRewritePattern<T> {
}
};
+/// Flattens multi-dimensional memref.global ops into one dimension.
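+/// E.g., a constant global of type memref<3x3x1x1xf32> becomes a memref<9xf32>
+/// global, with any dense initial value reshaped to match.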
+struct FlattenGlobal final : public OpRewritePattern<memref::GlobalOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+  static Attribute flattenAttribute(Attribute value, ShapedType newType) {
+    if (!value)
+      return value;
+    // A unit attribute marks an uninitialized global; nothing to reshape.
+    if (llvm::isa<UnitAttr>(value))
+      return value;
+    if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(value))
+      return splatAttr.reshape(newType);
+    if (auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(value))
+      return denseAttr.reshape(newType);
+    if (auto denseResourceAttr =
+            llvm::dyn_cast<DenseResourceElementsAttr>(value))
+      return DenseResourceElementsAttr::get(newType,
+                                            denseResourceAttr.getRawHandle());
+    return {};
+  }
+
+ LogicalResult
+ matchAndRewrite(memref::GlobalOp globalOp,
+ PatternRewriter &rewriter) const override {
+ auto oldType = llvm::dyn_cast<MemRefType>(globalOp.getType());
+ if (!oldType || !oldType.getLayout().isIdentity() || oldType.getRank() <= 1)
+ return failure();
+
+ auto tensorType = RankedTensorType::get({oldType.getNumElements()},
+ oldType.getElementType());
+ auto memRefType =
+ MemRefType::get({oldType.getNumElements()}, oldType.getElementType(),
+ AffineMap(), oldType.getMemorySpace());
+    auto newInitialValue =
+        flattenAttribute(globalOp.getInitialValueAttr(), tensorType);
+    // Bail out rather than silently dropping an initializer we cannot
+    // reshape.
+    if (globalOp.getInitialValueAttr() && !newInitialValue)
+      return failure();
+ rewriter.replaceOpWithNewOp<memref::GlobalOp>(
+ globalOp, globalOp.getSymName(), globalOp.getSymVisibilityAttr(),
+ memRefType, newInitialValue, globalOp.getConstant(),
+        /*alignment=*/globalOp.getAlignmentAttr());
+ return success();
+ }
+};
+
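+/// Rewrites memref.collapse_shape into a memref.reinterpret_cast on its
+/// source, with the collapsed sizes and strides recomputed from the source's
+/// strided metadata.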
+struct FlattenCollapseShape final
+ : public OpRewritePattern<memref::CollapseShapeOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(memref::CollapseShapeOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ memref::ExtractStridedMetadataOp metadata =
+ memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getSrc());
+
+ SmallVector<OpFoldResult> origSizes = metadata.getConstifiedMixedSizes();
+    SmallVector<OpFoldResult> origStrides =
+        metadata.getConstifiedMixedStrides();
+ OpFoldResult offset = metadata.getConstifiedMixedOffset();
+
+ SmallVector<OpFoldResult> collapsedSizes;
+ SmallVector<OpFoldResult> collapsedStrides;
+ unsigned numGroups = op.getReassociationIndices().size();
+ collapsedSizes.reserve(numGroups);
+ collapsedStrides.reserve(numGroups);
+ for (unsigned i = 0; i < numGroups; ++i) {
+ SmallVector<OpFoldResult> groupSizes =
+ getCollapsedSize(op, rewriter, origSizes, i);
+ SmallVector<OpFoldResult> groupStrides =
+ getCollapsedStride(op, rewriter, origSizes, origStrides, i);
+ collapsedSizes.append(groupSizes.begin(), groupSizes.end());
+ collapsedStrides.append(groupStrides.begin(), groupStrides.end());
+ }
+
+ rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+ op, op.getType(), op.getSrc(), offset, collapsedSizes,
+ collapsedStrides);
+ return success();
+ }
+};
+
struct FlattenMemrefsPass
: public mlir::memref::impl::FlattenMemrefsPassBase<FlattenMemrefsPass> {
using Base::Base;
@@ -288,12 +458,52 @@ void memref::populateFlattenVectorOpsOnMemrefPatterns(
patterns.getContext());
}
+/// Rewrites get_global ops whose referenced global has already been
+/// flattened; the rank check keeps this pattern from looping forever.
+struct FlattenGetGlobal : public OpRewritePattern<memref::GetGlobalOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(memref::GetGlobalOp op,
+ PatternRewriter &rewriter) const override {
+    // Check if this get_global references a multi-dimensional global.
+    auto module = op->getParentOfType<ModuleOp>();
+    auto globalOp = module.lookupSymbol<memref::GlobalOp>(op.getName());
+ if (!globalOp) {
+ return failure();
+ }
+
+ auto globalType = globalOp.getType();
+ auto resultType = op.getType();
+
+ // Only apply if the global has been flattened but the get_global hasn't
+ if (globalType.getRank() == 1 && resultType.getRank() > 1) {
+ auto newGetGlobal = memref::GetGlobalOp::create(
+ rewriter, op.getLoc(), globalType, op.getName());
+
+      // Cast the flattened result back to the original shape. The original
+      // result type is static with an identity layout, so its sizes and
+      // strides are compile-time constants; this avoids extracting metadata
+      // from the op that is about to be replaced.
+      auto [stridesVec, offsetVal] = resultType.getStridesAndOffset();
+      SmallVector<OpFoldResult> sizes, strides;
+      for (int64_t size : resultType.getShape())
+        sizes.push_back(rewriter.getIndexAttr(size));
+      for (int64_t stride : stridesVec)
+        strides.push_back(rewriter.getIndexAttr(stride));
+      auto castResult = memref::ReinterpretCastOp::create(
+          rewriter, op.getLoc(), resultType, newGetGlobal,
+          /*offset=*/rewriter.getIndexAttr(offsetVal), sizes, strides);
+ rewriter.replaceOp(op, castResult);
+ return success();
+ }
+
+ return failure();
+ }
+};
+
void memref::populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns) {
patterns.insert<MemRefRewritePattern<memref::LoadOp>,
MemRefRewritePattern<memref::StoreOp>,
MemRefRewritePattern<memref::AllocOp>,
MemRefRewritePattern<memref::AllocaOp>,
- MemRefRewritePattern<memref::DeallocOp>>(
+ MemRefRewritePattern<memref::DeallocOp>,
+ FlattenCollapseShape,
+ FlattenGetGlobal,
+ FlattenGlobal>(
patterns.getContext());
}
diff --git a/mlir/test/Dialect/MemRef/flatten_memref.mlir b/mlir/test/Dialect/MemRef/flatten_memref.mlir
index 88f46b07dad93..5c61bc87418bd 100644
--- a/mlir/test/Dialect/MemRef/flatten_memref.mlir
+++ b/mlir/test/Dialect/MemRef/flatten_memref.mlir
@@ -194,6 +194,35 @@ func.func @mask_load_vector_from_memref_dynamic(%input: memref<3x7xi2>, %row: in
// -----
+func.func @collapse_shape_static(%arg0: memref<2x3x4xf32>) -> memref<6x4xf32> {
+ %0 = memref.collapse_shape %arg0 [[0, 1], [2]]
+ : memref<2x3x4xf32> into memref<6x4xf32>
+ return %0 : memref<6x4xf32>
+}
+// CHECK-LABEL: func @collapse_shape_static
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %arg0 to offset: [0], sizes: [6, 4], strides: [4, 1]
+// CHECK: return %[[REINT]]
+
+// -----
+
+func.func @collapse_shape_dynamic(
+ %arg0: memref<2x?x4xf32, strided<[?, ?, ?], offset: ?>>) ->
+ memref<?x4xf32, strided<[?, ?], offset: ?>> {
+ %0 = memref.collapse_shape %arg0 [[0, 1], [2]]
+ : memref<2x?x4xf32, strided<[?, ?, ?], offset: ?>>
+ into memref<?x4xf32, strided<[?, ?], offset: ?>>
+ return %0 : memref<?x4xf32, strided<[?, ?], offset: ?>>
+}
+// CHECK: #map = affine_map<()[s0] -> (s0 * 2)>
+// CHECK: #map1 = affine_map<()[s0, s1] -> (s0 * 8 + s1)>
+// CHECK-LABEL: func @collapse_shape_dynamic
+// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %arg0
+// CHECK: %[[SIZE:.*]] = affine.apply #map()[%[[SIZES]]#1]
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %arg0 to offset: [%[[OFFSET]]], sizes: [%[[SIZE]], 4], strides: [%[[STRIDES]]#1, %[[STRIDES]]#2]
+// CHECK: return %[[REINT]]
+
+// -----
+
func.func @transfer_read_memref(%input: memref<4x8xi2>, %value: vector<8xi2>, %row: index, %col: index) -> vector<8xi2> {
%c0 = arith.constant 0 : i2
%0 = vector.transfer_read %input[%col, %row], %c0 {in_bounds = [true]} : memref<4x8xi2>, vector<8xi2>
@@ -336,3 +365,22 @@ func.func @dealloc_strided_memref(%input: memref<4x8xf32, strided<[8, 1], offset
// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xf32, strided<[8, 1], offset: 100>>)
// CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [100], sizes: [32], strides: [1] : memref<4x8xf32, strided<[8, 1], offset: 100>> to memref<32xf32, strided<[1], offset: 100>>
// CHECK-NEXT: memref.dealloc %[[REINT]] : memref<32xf32, strided<[1], offset: 100>>
+
+// -----
+
+memref.global "private" constant @constant_3x3x1x1xf32 : memref<3x3x1x1xf32> = dense<[[[[-1.000000e+00]], [[0.000000e+00]], [[1.000000e+00]]], [[[-2.000000e+00]], [[0.000000e+00]], [[2.000000e+00]]], [[[-1.000000e+00]], [[0.000000e+00]], [[1.000000e+00]]]]>
+func.func @load_global_with_offset(%i0: index, %i1: index, %i2: index, %i3: index) -> f32 {
+ %global = memref.get_global @constant_3x3x1x1xf32 : memref<3x3x1x1xf32>
+ %val = memref.load %global[%i0, %i1, %i2, %i3] : memref<3x3x1x1xf32>
+ return %val: f32
+}
+
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0, s1, s2, s3] -> (s0 * 3 + s1 + s2 + s3)>
+// CHECK: memref.global "private" constant @constant_3x3x1x1xf32 : memref<9xf32> = dense<[-1.000000e+00, 0.000000e+00, 1.000000e+00, -2.000000e+00, 0.000000e+00, 2.000000e+00, -1.000000e+00, 0.000000e+00, 1.000000e+00]>
+// CHECK-LABEL: func.func @load_global_with_offset
+// CHECK-SAME: (%[[I0:.+]]: index, %[[I1:.+]]: index, %[[I2:.+]]: index, %[[I3:.+]]: index)
+// CHECK: %[[GLOBAL:.+]] = memref.get_global @constant_3x3x1x1xf32 : memref<9xf32>
+// CHECK: %[[INDEX:.+]] = affine.apply #[[$MAP]]()[%[[I0]], %[[I1]], %[[I2]], %[[I3]]]
+// CHECK: %[[REINTERPRET:.+]] = memref.reinterpret_cast %[[GLOBAL]] to offset: [0], sizes: [9], strides: [1] : memref<9xf32> to memref<9xf32, strided<[1]>>
+// CHECK: %[[LOAD:.+]] = memref.load %[[REINTERPRET]][%[[INDEX]]] : memref<9xf32, strided<[1]>>
+// CHECK: return %[[LOAD]]
>From 79d2edfdb5788c13a8f580c551ca712b32f8b0b7 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Wed, 17 Sep 2025 09:40:18 -0400
Subject: [PATCH 3/6] Flatten memref.expand_shape ops
---
.../MemRef/Transforms/FlattenMemRefs.cpp | 144 ++++++++++++++++++
mlir/test/Dialect/MemRef/flatten_memref.mlir | 31 +++-
2 files changed, 174 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
index e0dbc8d8641bd..5394df5b46fb2 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
@@ -96,6 +96,115 @@ static bool checkLayout(Value val) {
isa<StridedLayoutAttr>(type.getLayout());
}
+/// Compute the expanded sizes of the given expand_shape for the reassociation
+/// group `groupId`. Duplicated from
+/// `lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp` until these
+/// helpers are exposed through a shared header.
+static SmallVector<OpFoldResult>
+getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
+ ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
+ ArrayRef<int64_t> reassocGroup =
+ expandShape.getReassociationIndices()[groupId];
+ assert(!reassocGroup.empty() &&
+ "Reassociation group should have at least one dimension");
+
+ unsigned groupSize = reassocGroup.size();
+ SmallVector<OpFoldResult> expandedSizes(groupSize);
+
+ uint64_t productOfAllStaticSizes = 1;
+ std::optional<unsigned> dynSizeIdx;
+ MemRefType expandShapeType = expandShape.getResultType();
+
+ for (unsigned i = 0; i < groupSize; ++i) {
+ uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
+ if (ShapedType::isDynamic(dimSize)) {
+ assert(!dynSizeIdx && "there must be at most one dynamic size per group");
+ dynSizeIdx = i;
+ continue;
+ }
+ productOfAllStaticSizes *= dimSize;
+ expandedSizes[i] = builder.getIndexAttr(dimSize);
+ }
+
+ if (dynSizeIdx) {
+ AffineExpr s0 = builder.getAffineSymbolExpr(0);
+ expandedSizes[*dynSizeIdx] = affine::makeComposedFoldedAffineApply(
+ builder, expandShape.getLoc(), s0.floorDiv(productOfAllStaticSizes),
+ origSizes[groupId]);
+ }
+
+ return expandedSizes;
+}
+
+/// Compute the expanded strides of the given expand_shape for the reassociation
+/// group `groupId`.
+static SmallVector<OpFoldResult>
+getExpandedStrides(memref::ExpandShapeOp expandShape, OpBuilder &builder,
+ ArrayRef<OpFoldResult> origSizes,
+ ArrayRef<OpFoldResult> origStrides, unsigned groupId) {
+ ArrayRef<int64_t> reassocGroup =
+ expandShape.getReassociationIndices()[groupId];
+ assert(!reassocGroup.empty() &&
+ "Reassociation group should have at least one dimension");
+
+ unsigned groupSize = reassocGroup.size();
+ MemRefType expandShapeType = expandShape.getResultType();
+
+ std::optional<int64_t> dynSizeIdx;
+ uint64_t currentStride = 1;
+ SmallVector<OpFoldResult> expandedStrides(groupSize);
+ for (int i = groupSize - 1; i >= 0; --i) {
+ expandedStrides[i] = builder.getIndexAttr(currentStride);
+ uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
+ if (ShapedType::isDynamic(dimSize)) {
+ assert(!dynSizeIdx && "there must be at most one dynamic size per group");
+ dynSizeIdx = i;
+ continue;
+ }
+ currentStride *= dimSize;
+ }
+
+ auto sourceType = expandShape.getSrcType();
+ auto [strides, offset] = sourceType.getStridesAndOffset();
+ (void)offset;
+
+ OpFoldResult origStride = ShapedType::isDynamic(strides[groupId])
+ ? origStrides[groupId]
+ : builder.getIndexAttr(strides[groupId]);
+
+ int64_t doneStrideIdx = 0;
+ if (dynSizeIdx) {
+ int64_t productOfAllStaticSizes = currentStride;
+ assert(ShapedType::isDynamic(sourceType.getDimSize(groupId)) &&
+ "dynamic reassociation must originate from dynamic source dim");
+ OpFoldResult origSize = origSizes[groupId];
+
+ AffineExpr s0 = builder.getAffineSymbolExpr(0);
+ AffineExpr s1 = builder.getAffineSymbolExpr(1);
+ for (; doneStrideIdx < *dynSizeIdx; ++doneStrideIdx) {
+      auto baseAttr = dyn_cast<Attribute>(expandedStrides[doneStrideIdx]);
+ assert(baseAttr && "expected attribute stride");
+ int64_t baseExpandedStride = cast<IntegerAttr>(baseAttr).getInt();
+ expandedStrides[doneStrideIdx] = affine::makeComposedFoldedAffineApply(
+ builder, expandShape.getLoc(),
+ (s0 * baseExpandedStride).floorDiv(productOfAllStaticSizes) * s1,
+ {origSize, origStride});
+ }
+ }
+
+ AffineExpr s0 = builder.getAffineSymbolExpr(0);
+ for (; doneStrideIdx < groupSize; ++doneStrideIdx) {
+    auto baseAttr = dyn_cast<Attribute>(expandedStrides[doneStrideIdx]);
+ assert(baseAttr && "expected attribute stride");
+ int64_t baseExpandedStride = cast<IntegerAttr>(baseAttr).getInt();
+ expandedStrides[doneStrideIdx] = affine::makeComposedFoldedAffineApply(
+ builder, expandShape.getLoc(), s0 * baseExpandedStride,
+ {origStride});
+ }
+
+ return expandedStrides;
+}
+
/// Produce an OpFoldResult representing the product of the values or constants
/// referenced by `indices`. `staticShape` provides the statically known sizes
/// for the source memref, while `values` contains the mixed (value/attribute)
@@ -426,6 +535,40 @@ struct FlattenCollapseShape final
}
};
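+/// Rewrites memref.expand_shape into a memref.reinterpret_cast on its source,
+/// with the expanded sizes and strides recomputed from the source's strided
+/// metadata.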
+struct FlattenExpandShape final
+    : public OpRewritePattern<memref::ExpandShapeOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(memref::ExpandShapeOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ memref::ExtractStridedMetadataOp metadata =
+ memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getSrc());
+
+ SmallVector<OpFoldResult> origSizes = metadata.getConstifiedMixedSizes();
+    SmallVector<OpFoldResult> origStrides =
+        metadata.getConstifiedMixedStrides();
+ OpFoldResult offset = metadata.getConstifiedMixedOffset();
+
+ SmallVector<OpFoldResult> expandedSizes;
+ SmallVector<OpFoldResult> expandedStrides;
+ unsigned numGroups = op.getReassociationIndices().size();
+ expandedSizes.reserve(op.getResultType().getRank());
+ expandedStrides.reserve(op.getResultType().getRank());
+
+ for (unsigned i = 0; i < numGroups; ++i) {
+ SmallVector<OpFoldResult> groupSizes =
+ getExpandedSizes(op, rewriter, origSizes, i);
+ SmallVector<OpFoldResult> groupStrides =
+ getExpandedStrides(op, rewriter, origSizes, origStrides, i);
+ expandedSizes.append(groupSizes.begin(), groupSizes.end());
+ expandedStrides.append(groupStrides.begin(), groupStrides.end());
+ }
+
+ rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+ op, op.getType(), op.getSrc(), offset, expandedSizes, expandedStrides);
+ return success();
+ }
+};
+
struct FlattenMemrefsPass
: public mlir::memref::impl::FlattenMemrefsPassBase<FlattenMemrefsPass> {
using Base::Base;
@@ -501,6 +644,7 @@ void memref::populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns) {
MemRefRewritePattern<memref::AllocOp>,
MemRefRewritePattern<memref::AllocaOp>,
MemRefRewritePattern<memref::DeallocOp>,
+ FlattenExpandShape,
FlattenCollapseShape,
FlattenGetGlobal,
FlattenGlobal>(
diff --git a/mlir/test/Dialect/MemRef/flatten_memref.mlir b/mlir/test/Dialect/MemRef/flatten_memref.mlir
index 5c61bc87418bd..2a5f141dbe328 100644
--- a/mlir/test/Dialect/MemRef/flatten_memref.mlir
+++ b/mlir/test/Dialect/MemRef/flatten_memref.mlir
@@ -214,7 +214,6 @@ func.func @collapse_shape_dynamic(
return %0 : memref<?x4xf32, strided<[?, ?], offset: ?>>
}
// CHECK: #map = affine_map<()[s0] -> (s0 * 2)>
-// CHECK: #map1 = affine_map<()[s0, s1] -> (s0 * 8 + s1)>
// CHECK-LABEL: func @collapse_shape_dynamic
// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %arg0
// CHECK: %[[SIZE:.*]] = affine.apply #map()[%[[SIZES]]#1]
@@ -223,6 +222,36 @@ func.func @collapse_shape_dynamic(
// -----
+func.func @expand_shape_static(%arg0: memref<6x4xf32>) -> memref<2x3x4xf32> {
+ %0 = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [2, 3, 4]
+ : memref<6x4xf32> into memref<2x3x4xf32>
+ return %0 : memref<2x3x4xf32>
+}
+// CHECK-LABEL: func @expand_shape_static
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %arg0 to offset: [0], sizes: [2, 3, 4], strides: [12, 4, 1]
+// CHECK: return %[[REINT]]
+
+// -----
+
+func.func @expand_shape_dynamic(
+ %arg0: memref<?x4xf32, strided<[?, ?], offset: ?>>, %size: index) ->
+ memref<?x3x4xf32, strided<[?, ?, ?], offset: ?>> {
+ %0 = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [%size, 3, 4]
+ : memref<?x4xf32, strided<[?, ?], offset: ?>>
+ into memref<?x3x4xf32, strided<[?, ?, ?], offset: ?>>
+ return %0 : memref<?x3x4xf32, strided<[?, ?, ?], offset: ?>>
+}
+// CHECK: #map = affine_map<()[s0] -> (s0 floordiv 3)>
+// CHECK: #map1 = affine_map<()[s0] -> (s0 * 3)>
+// CHECK-LABEL: func @expand_shape_dynamic
+// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %arg0
+// CHECK: %[[SIZE:.*]] = affine.apply #map()[%[[SIZES]]#0]
+// CHECK: %[[STRIDE:.*]] = affine.apply #map1()[%[[STRIDES]]#0]
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %arg0 to offset: [%[[OFFSET]]], sizes: [%[[SIZE]], 3, 4], strides: [%[[STRIDE]], %[[STRIDES]]#0, %[[STRIDES]]#1]
+// CHECK: return %[[REINT]]
+
+// -----
+
func.func @transfer_read_memref(%input: memref<4x8xi2>, %value: vector<8xi2>, %row: index, %col: index) -> vector<8xi2> {
%c0 = arith.constant 0 : i2
%0 = vector.transfer_read %input[%col, %row], %c0 {in_bounds = [true]} : memref<4x8xi2>, vector<8xi2>
>From d5cda23901336741cc70577bbe15c1b2f1336af6 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Wed, 17 Sep 2025 15:32:52 -0400
Subject: [PATCH 4/6] Share reshape metadata helpers through MemRef Transforms.h
---
.../Dialect/MemRef/Transforms/Transforms.h | 99 +++++++
.../Transforms/ExpandStridedMetadata.cpp | 72 ++---
.../MemRef/Transforms/FlattenMemRefs.cpp | 245 +++---------------
3 files changed, 156 insertions(+), 260 deletions(-)
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index 8b76930aed35a..562b8c11225e8 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -14,10 +14,15 @@
#ifndef MLIR_DIALECT_MEMREF_TRANSFORMS_TRANSFORMS_H
#define MLIR_DIALECT_MEMREF_TRANSFORMS_TRANSFORMS_H
+#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/STLFunctionalExtras.h"
namespace mlir {
+class Location;
class OpBuilder;
class RewritePatternSet;
class RewriterBase;
@@ -33,7 +38,9 @@ class NarrowTypeEmulationConverter;
namespace memref {
class AllocOp;
class AllocaOp;
+class CollapseShapeOp;
class DeallocOp;
+class ExpandShapeOp;
//===----------------------------------------------------------------------===//
// Patterns
@@ -213,6 +220,98 @@ FailureOr<Value> replaceWithIndependentOp(RewriterBase &rewriter,
memref::AllocaOp allocToAlloca(
RewriterBase &rewriter, memref::AllocOp alloc,
function_ref<bool(memref::AllocOp, memref::DeallocOp)> filter = nullptr);
+
+/// Compute the expanded sizes of the given \p expandShape for the
+/// \p groupId-th reassociation group.
+/// \p origSizes hold the sizes of the source shape as values.
+/// This is used to compute the new sizes in cases of dynamic shapes.
+///
+/// sizes#i =
+/// baseSizes#groupId / product(expandShapeSizes#j,
+/// for j in group excluding reassIdx#i)
+/// Where reassIdx#i is the reassociation index at index i in \p groupId.
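+///
+/// E.g., expanding the group {0, 1} of memref<?x4xf32> into memref<?x3x4xf32>
+/// yields sizes {baseSizes#0 floordiv 3, 3}.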
+///
+/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
+///
+/// TODO: Move this utility function directly within ExpandShapeOp. For now,
+/// this is not possible because this function uses the Affine dialect and the
+/// MemRef dialect cannot depend on the Affine dialect.
+SmallVector<OpFoldResult>
+getExpandedSizes(ExpandShapeOp expandShape, OpBuilder &builder,
+ ArrayRef<OpFoldResult> origSizes, unsigned groupId);
+
+/// Compute the expanded strides of the given \p expandShape for the
+/// \p groupId-th reassociation group.
+/// \p origStrides and \p origSizes hold respectively the strides and sizes
+/// of the source shape as values.
+/// This is used to compute the strides in cases of dynamic shapes and/or
+/// dynamic stride for this reassociation group.
+///
+/// strides#i =
+/// origStrides#reassDim * product(expandShapeSizes#j, for j in
+/// reassIdx#i+1..reassIdx#i+group.size-1)
+///
+/// Where reassIdx#i is the reassociation index at index i in \p groupId
+/// and expandShapeSizes#j is either:
+/// - The constant size at dimension j, derived directly from the result type of
+/// the expand_shape op, or
+/// - An affine expression: baseSizes#reassDim / product of all constant sizes
+/// in expandShapeSizes. (Remember expandShapeSizes has at most one dynamic
+/// element.)
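+///
+/// E.g., for the expansion above, the strides of group {0, 1} are
+/// {origStrides#0 * 3, origStrides#0}.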
+///
+/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
+///
+/// TODO: Move this utility function directly within ExpandShapeOp. For now,
+/// this is not possible because this function uses the Affine dialect and the
+/// MemRef dialect cannot depend on the Affine dialect.
+SmallVector<OpFoldResult> getExpandedStrides(ExpandShapeOp expandShape,
+ OpBuilder &builder,
+ ArrayRef<OpFoldResult> origSizes,
+ ArrayRef<OpFoldResult> origStrides,
+ unsigned groupId);
+
+/// Produce an OpFoldResult object with \p builder at \p loc representing
+/// `prod(valueOrConstant#i, for i in {indices})`,
+/// where valueOrConstant#i is maybeConstants[i] when \p isDynamic returns
+/// false for it, and values[i] otherwise.
+///
+/// \pre for all index in indices: index < values.size()
+/// \pre for all index in indices: index < maybeConstants.size()
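+///
+/// E.g., indices = {0, 1} with maybeConstants = {2, ShapedType::kDynamic}
+/// produces `2 * values[1]`, folded into a single affine.apply.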
+OpFoldResult
+getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
+ ArrayRef<int64_t> maybeConstants,
+ ArrayRef<OpFoldResult> values,
+ llvm::function_ref<bool(int64_t)> isDynamic);
+
+/// Compute the collapsed size of the given \p collapseShape for the
+/// \p groupId-th reassociation group.
+/// \p origSizes hold the sizes of the source shape as values.
+/// This is used to compute the new sizes in cases of dynamic shapes.
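+///
+/// E.g., collapsing the group {0, 1} of memref<2x?x4xf32> produces the single
+/// size `2 * baseSizes#1`.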
+///
+/// TODO: Move this utility function directly within CollapseShapeOp. For now,
+/// this is not possible because this function uses the Affine dialect and the
+/// MemRef dialect cannot depend on the Affine dialect.
+SmallVector<OpFoldResult>
+getCollapsedSize(CollapseShapeOp collapseShape, OpBuilder &builder,
+ ArrayRef<OpFoldResult> origSizes, unsigned groupId);
+
+/// Compute the collapsed stride of the given \p collapseShape for the
+/// \p groupId-th reassociation group.
+/// \p origStrides and \p origSizes hold respectively the strides and sizes
+/// of the source shape as values.
+/// This is used to compute the strides in cases of dynamic shapes and/or
+/// dynamic stride for this reassociation group.
+///
+/// Conceptually this helper function returns the stride of the innermost
+/// dimension of that group in the original shape.
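+///
+/// E.g., for the group {0, 1} of memref<2x?x4xf32, strided<[?, ?, 1]>>, the
+/// result is origStrides#1, the stride of the innermost non-unit dimension.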
+///
+/// \post result.size() == 1, in other words, each group collapses to one
+/// dimension.
+SmallVector<OpFoldResult>
+getCollapsedStride(CollapseShapeOp collapseShape, OpBuilder &builder,
+ ArrayRef<OpFoldResult> origSizes,
+ ArrayRef<OpFoldResult> origStrides, unsigned groupId);
+
} // namespace memref
} // namespace mlir
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index d35566a9c0d29..6b69d0e366903 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -24,6 +24,7 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
+#include "mlir/IR/OpDefinition.h"
#include <optional>
namespace mlir {
@@ -35,6 +36,7 @@ namespace memref {
using namespace mlir;
using namespace mlir::affine;
+using namespace mlir::memref;
namespace {
@@ -250,23 +252,12 @@ struct ExtractStridedMetadataOpSubviewFolder
}
};
-/// Compute the expanded sizes of the given \p expandShape for the
-/// \p groupId-th reassociation group.
-/// \p origSizes hold the sizes of the source shape as values.
-/// This is used to compute the new sizes in cases of dynamic shapes.
-///
-/// sizes#i =
-/// baseSizes#groupId / product(expandShapeSizes#j,
-/// for j in group excluding reassIdx#i)
-/// Where reassIdx#i is the reassociation index at index i in \p groupId.
-///
-/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
-///
-/// TODO: Move this utility function directly within ExpandShapeOp. For now,
-/// this is not possible because this function uses the Affine dialect and the
-/// MemRef dialect cannot depend on the Affine dialect.
-static SmallVector<OpFoldResult>
-getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
+} // namespace
+
+namespace mlir {
+namespace memref {
+SmallVector<OpFoldResult>
+getExpandedSizes(ExpandShapeOp expandShape, OpBuilder &builder,
ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
SmallVector<int64_t, 2> reassocGroup =
expandShape.getReassociationIndices()[groupId];
@@ -305,31 +296,7 @@ getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
return expandedSizes;
}
-/// Compute the expanded strides of the given \p expandShape for the
-/// \p groupId-th reassociation group.
-/// \p origStrides and \p origSizes hold respectively the strides and sizes
-/// of the source shape as values.
-/// This is used to compute the strides in cases of dynamic shapes and/or
-/// dynamic stride for this reassociation group.
-///
-/// strides#i =
-/// origStrides#reassDim * product(expandShapeSizes#j, for j in
-/// reassIdx#i+1..reassIdx#i+group.size-1)
-///
-/// Where reassIdx#i is the reassociation index for at index i in \p groupId
-/// and expandShapeSizes#j is either:
-/// - The constant size at dimension j, derived directly from the result type of
-/// the expand_shape op, or
-/// - An affine expression: baseSizes#reassDim / product of all constant sizes
-/// in expandShapeSizes. (Remember expandShapeSizes has at most one dynamic
-/// element.)
-///
-/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
-///
-/// TODO: Move this utility function directly within ExpandShapeOp. For now,
-/// this is not possible because this function uses the Affine dialect and the
-/// MemRef dialect cannot depend on the Affine dialect.
-SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
+SmallVector<OpFoldResult> getExpandedStrides(ExpandShapeOp expandShape,
OpBuilder &builder,
ArrayRef<OpFoldResult> origSizes,
ArrayRef<OpFoldResult> origStrides,
@@ -405,14 +372,7 @@ SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
return expandedStrides;
}
-/// Produce an OpFoldResult object with \p builder at \p loc representing
-/// `prod(valueOrConstant#i, for i in {indices})`,
-/// where valueOrConstant#i is maybeConstant[i] when \p isDymamic is false,
-/// values[i] otherwise.
-///
-/// \pre for all index in indices: index < values.size()
-/// \pre for all index in indices: index < maybeConstants.size()
-static OpFoldResult
+OpFoldResult
getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
ArrayRef<int64_t> maybeConstants,
ArrayRef<OpFoldResult> values,
@@ -450,8 +410,8 @@ getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
/// TODO: Move this utility function directly within CollapseShapeOp. For now,
/// this is not possible because this function uses the Affine dialect and the
/// MemRef dialect cannot depend on the Affine dialect.
-static SmallVector<OpFoldResult>
-getCollapsedSize(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
+SmallVector<OpFoldResult>
+getCollapsedSize(CollapseShapeOp collapseShape, OpBuilder &builder,
ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
SmallVector<OpFoldResult> collapsedSize;
@@ -491,8 +451,8 @@ getCollapsedSize(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
///
/// \post result.size() == 1, in other words, each group collapse to one
/// dimension.
-static SmallVector<OpFoldResult>
-getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
+SmallVector<OpFoldResult>
+getCollapsedStride(CollapseShapeOp collapseShape, OpBuilder &builder,
ArrayRef<OpFoldResult> origSizes,
ArrayRef<OpFoldResult> origStrides, unsigned groupId) {
SmallVector<int64_t, 2> reassocGroup =
@@ -546,6 +506,10 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
return {lastValidStride};
}
+} // namespace memref
+} // namespace mlir
+
+namespace {
/// From `reshape_like(memref, subSizes, subStrides))` compute
///
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
index 5394df5b46fb2..f17a0582919fc 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
@@ -96,208 +96,6 @@ static bool checkLayout(Value val) {
isa<StridedLayoutAttr>(type.getLayout());
}
-/// Compute the expanded sizes of the given expand_shape for the reassociation
-/// group `groupId`. Duplicated from
-/// `lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp` until these
-/// helpers are exposed through a shared header.
-static SmallVector<OpFoldResult>
-getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
- ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
- ArrayRef<int64_t> reassocGroup =
- expandShape.getReassociationIndices()[groupId];
- assert(!reassocGroup.empty() &&
- "Reassociation group should have at least one dimension");
-
- unsigned groupSize = reassocGroup.size();
- SmallVector<OpFoldResult> expandedSizes(groupSize);
-
- uint64_t productOfAllStaticSizes = 1;
- std::optional<unsigned> dynSizeIdx;
- MemRefType expandShapeType = expandShape.getResultType();
-
- for (unsigned i = 0; i < groupSize; ++i) {
- uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
- if (ShapedType::isDynamic(dimSize)) {
- assert(!dynSizeIdx && "there must be at most one dynamic size per group");
- dynSizeIdx = i;
- continue;
- }
- productOfAllStaticSizes *= dimSize;
- expandedSizes[i] = builder.getIndexAttr(dimSize);
- }
-
- if (dynSizeIdx) {
- AffineExpr s0 = builder.getAffineSymbolExpr(0);
- expandedSizes[*dynSizeIdx] = affine::makeComposedFoldedAffineApply(
- builder, expandShape.getLoc(), s0.floorDiv(productOfAllStaticSizes),
- origSizes[groupId]);
- }
-
- return expandedSizes;
-}
-
-/// Compute the expanded strides of the given expand_shape for the reassociation
-/// group `groupId`.
-static SmallVector<OpFoldResult>
-getExpandedStrides(memref::ExpandShapeOp expandShape, OpBuilder &builder,
- ArrayRef<OpFoldResult> origSizes,
- ArrayRef<OpFoldResult> origStrides, unsigned groupId) {
- ArrayRef<int64_t> reassocGroup =
- expandShape.getReassociationIndices()[groupId];
- assert(!reassocGroup.empty() &&
- "Reassociation group should have at least one dimension");
-
- unsigned groupSize = reassocGroup.size();
- MemRefType expandShapeType = expandShape.getResultType();
-
- std::optional<int64_t> dynSizeIdx;
- uint64_t currentStride = 1;
- SmallVector<OpFoldResult> expandedStrides(groupSize);
- for (int i = groupSize - 1; i >= 0; --i) {
- expandedStrides[i] = builder.getIndexAttr(currentStride);
- uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
- if (ShapedType::isDynamic(dimSize)) {
- assert(!dynSizeIdx && "there must be at most one dynamic size per group");
- dynSizeIdx = i;
- continue;
- }
- currentStride *= dimSize;
- }
-
- auto sourceType = expandShape.getSrcType();
- auto [strides, offset] = sourceType.getStridesAndOffset();
- (void)offset;
-
- OpFoldResult origStride = ShapedType::isDynamic(strides[groupId])
- ? origStrides[groupId]
- : builder.getIndexAttr(strides[groupId]);
-
- int64_t doneStrideIdx = 0;
- if (dynSizeIdx) {
- int64_t productOfAllStaticSizes = currentStride;
- assert(ShapedType::isDynamic(sourceType.getDimSize(groupId)) &&
- "dynamic reassociation must originate from dynamic source dim");
- OpFoldResult origSize = origSizes[groupId];
-
- AffineExpr s0 = builder.getAffineSymbolExpr(0);
- AffineExpr s1 = builder.getAffineSymbolExpr(1);
- for (; doneStrideIdx < *dynSizeIdx; ++doneStrideIdx) {
-      auto baseAttr = dyn_cast<Attribute>(expandedStrides[doneStrideIdx]);
- assert(baseAttr && "expected attribute stride");
- int64_t baseExpandedStride = cast<IntegerAttr>(baseAttr).getInt();
- expandedStrides[doneStrideIdx] = affine::makeComposedFoldedAffineApply(
- builder, expandShape.getLoc(),
- (s0 * baseExpandedStride).floorDiv(productOfAllStaticSizes) * s1,
- {origSize, origStride});
- }
- }
-
- AffineExpr s0 = builder.getAffineSymbolExpr(0);
- for (; doneStrideIdx < groupSize; ++doneStrideIdx) {
-    auto baseAttr = dyn_cast<Attribute>(expandedStrides[doneStrideIdx]);
- assert(baseAttr && "expected attribute stride");
- int64_t baseExpandedStride = cast<IntegerAttr>(baseAttr).getInt();
- expandedStrides[doneStrideIdx] = affine::makeComposedFoldedAffineApply(
- builder, expandShape.getLoc(), s0 * baseExpandedStride,
- {origStride});
- }
-
- return expandedStrides;
-}
-
-/// Produce an OpFoldResult representing the product of the values or constants
-/// referenced by `indices`. `staticShape` provides the statically known sizes
-/// for the source memref, while `values` contains the mixed (value/attribute)
-/// representation produced by `memref.extract_strided_metadata`.
-static OpFoldResult getProductOfValues(ArrayRef<int64_t> indices,
- OpBuilder &builder, Location loc,
- ArrayRef<int64_t> staticShape,
- ArrayRef<OpFoldResult> values) {
- AffineExpr product = builder.getAffineConstantExpr(1);
- SmallVector<OpFoldResult> inputs;
- unsigned numSymbols = 0;
- for (int64_t idx : indices) {
- product = product * builder.getAffineSymbolExpr(numSymbols++);
- if (ShapedType::isDynamic(staticShape[idx]))
- inputs.push_back(values[idx]);
- else
- inputs.push_back(builder.getIndexAttr(staticShape[idx]));
- }
- return affine::makeComposedFoldedAffineApply(builder, loc, product, inputs);
-}
-
-/// Return the collapsed size (as OpFoldResult) for the reassociation group
-/// `groupId` of `collapseShapeOp`.
-static SmallVector<OpFoldResult>
-getCollapsedSize(memref::CollapseShapeOp collapseShapeOp, OpBuilder &builder,
- ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
- SmallVector<OpFoldResult> collapsedSize;
-
- MemRefType resultType = collapseShapeOp.getResultType();
- int64_t dimSize = resultType.getDimSize(groupId);
- if (!ShapedType::isDynamic(dimSize)) {
- collapsedSize.push_back(builder.getIndexAttr(dimSize));
- return collapsedSize;
- }
-
- auto sourceType = collapseShapeOp.getSrcType();
- ArrayRef<int64_t> staticShape = sourceType.getShape();
- ArrayRef<int64_t> reassocGroup =
- collapseShapeOp.getReassociationIndices()[groupId];
-
- collapsedSize.push_back(getProductOfValues(reassocGroup, builder,
- collapseShapeOp.getLoc(),
- staticShape, origSizes));
- return collapsedSize;
-}
-
-/// Return the collapsed stride (as OpFoldResult) for the reassociation group
-/// `groupId` of `collapseShapeOp`.
-static SmallVector<OpFoldResult> getCollapsedStride(
- memref::CollapseShapeOp collapseShapeOp, OpBuilder &builder,
- ArrayRef<OpFoldResult> origSizes, ArrayRef<OpFoldResult> origStrides,
- unsigned groupId) {
- ArrayRef<int64_t> reassocGroup =
- collapseShapeOp.getReassociationIndices()[groupId];
- assert(!reassocGroup.empty() &&
- "reassociation group must contain at least one dimension");
-
- auto sourceType = collapseShapeOp.getSrcType();
- auto [strides, offset] = sourceType.getStridesAndOffset();
- (void)offset;
- ArrayRef<int64_t> srcShape = sourceType.getShape();
-
- OpFoldResult lastValidStride = nullptr;
- for (int64_t dim : reassocGroup) {
- if (srcShape[dim] == 1)
- continue;
- int64_t currentStride = strides[dim];
- if (ShapedType::isDynamic(currentStride))
- lastValidStride = origStrides[dim];
- else
- lastValidStride = builder.getIndexAttr(currentStride);
- }
-
- if (!lastValidStride) {
- MemRefType collapsedType = collapseShapeOp.getResultType();
- auto [collapsedStrides, collapsedOffset] =
- collapsedType.getStridesAndOffset();
- (void)collapsedOffset;
- int64_t finalStride = collapsedStrides[groupId];
- if (ShapedType::isDynamic(finalStride)) {
- for (int64_t dim : reassocGroup) {
- assert(srcShape[dim] == 1 && "expected size-one dimensions");
- if (ShapedType::isDynamic(strides[dim]))
- return {origStrides[dim]};
- }
- llvm_unreachable("expected to find a dynamic stride");
- }
- return {builder.getIndexAttr(finalStride)};
- }
-
- return {lastValidStride};
-}
-
namespace {
static Value getTargetMemref(Operation *op) {
return llvm::TypeSwitch<Operation *, Value>(op)
@@ -521,9 +319,9 @@ struct FlattenCollapseShape final
collapsedStrides.reserve(numGroups);
for (unsigned i = 0; i < numGroups; ++i) {
SmallVector<OpFoldResult> groupSizes =
- getCollapsedSize(op, rewriter, origSizes, i);
+ memref::getCollapsedSize(op, rewriter, origSizes, i);
SmallVector<OpFoldResult> groupStrides =
- getCollapsedStride(op, rewriter, origSizes, origStrides, i);
+ memref::getCollapsedStride(op, rewriter, origSizes, origStrides, i);
collapsedSizes.append(groupSizes.begin(), groupSizes.end());
collapsedStrides.append(groupStrides.begin(), groupStrides.end());
}
@@ -556,9 +354,9 @@ struct FlattenExpandShape final : public OpRewritePattern<memref::ExpandShapeOp>
for (unsigned i = 0; i < numGroups; ++i) {
SmallVector<OpFoldResult> groupSizes =
- getExpandedSizes(op, rewriter, origSizes, i);
+ memref::getExpandedSizes(op, rewriter, origSizes, i);
SmallVector<OpFoldResult> groupStrides =
- getExpandedStrides(op, rewriter, origSizes, origStrides, i);
+ memref::getExpandedStrides(op, rewriter, origSizes, origStrides, i);
expandedSizes.append(groupSizes.begin(), groupSizes.end());
expandedStrides.append(groupStrides.begin(), groupStrides.end());
}
@@ -569,6 +367,40 @@ struct FlattenExpandShape final : public OpRewritePattern<memref::ExpandShapeOp>
}
};
+
+/*
+// Flattens memref subview ops with more than one dimension to one dimension.
+struct FlattenSubView final : public OpConversionPattern<memref::SubViewOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::SubViewOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ if (!isRankZeroOrOneMemRef(adaptor.getSource().getType())) {
+ return rewriter.notifyMatchFailure(
+ op, "expected converted memref of rank <= 1");
+ }
+ Type neededResultType =
+ getTypeConverter()->convertType(op.getResult().getType());
+ if (!neededResultType || !isRankZeroOrOneMemRef(neededResultType))
+ return failure();
+ Value size = createTotalElementCountValue(op.getType(), op.getSizes(),
+ op.getLoc(), rewriter);
+ SmallVector<Value> offsets = mlir::getValueOrCreateConstantIndexOp(
+ rewriter, op.getLoc(), op.getMixedOffsets());
+ Value linearOffset =
+ linearizeIndices(op.getSource(), offsets, op.getLoc(), rewriter);
+ Value stride = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 1);
+ Value newSubView = memref::SubViewOp::create(
+ rewriter, op.getLoc(), adaptor.getSource(), ValueRange({linearOffset}),
+ ValueRange({size}), ValueRange({stride}));
+ rewriter.replaceOpWithNewOp<memref::CastOp>(op, neededResultType,
+ newSubView);
+ return success();
+ }
+};
+*/
+
struct FlattenMemrefsPass
: public mlir::memref::impl::FlattenMemrefsPassBase<FlattenMemrefsPass> {
using Base::Base;
@@ -646,6 +478,7 @@ void memref::populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns) {
MemRefRewritePattern<memref::DeallocOp>,
FlattenExpandShape,
FlattenCollapseShape,
+ //FlattenSubView,
FlattenGetGlobal,
FlattenGlobal>(
patterns.getContext());
>From 05cdabb23bf25498c9b4ea5e3bb3b7b1102d037b Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Wed, 17 Sep 2025 21:56:10 -0400
Subject: [PATCH 5/6] Flatten memref.subview ops
---
.../MemRef/Transforms/FlattenMemRefs.cpp | 181 ++++++++++++++----
mlir/test/Dialect/MemRef/flatten_memref.mlir | 13 ++
2 files changed, 152 insertions(+), 42 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
index f17a0582919fc..43a67f1fab2be 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
@@ -27,6 +27,7 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
namespace mlir {
@@ -47,6 +48,7 @@ static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc,
return cast<Value>(in);
}
+
/// Returns a collapsed memref and the linearized index to access the element
/// at the specified indices.
static std::pair<Value, Value> getFlattenMemrefAndOffset(OpBuilder &rewriter,
@@ -90,12 +92,15 @@ static bool needFlattening(Value val) {
return type.getRank() > 1;
}
-static bool checkLayout(Value val) {
- auto type = cast<MemRefType>(val.getType());
+static bool checkLayout(MemRefType type) {
return type.getLayout().isIdentity() ||
isa<StridedLayoutAttr>(type.getLayout());
}
+static bool checkLayout(Value val) {
+ return checkLayout(cast<MemRefType>(val.getType()));
+}
+
namespace {
static Value getTargetMemref(Operation *op) {
return llvm::TypeSwitch<Operation *, Value>(op)
@@ -368,38 +373,131 @@ struct FlattenExpandShape final : public OpRewritePattern<memref::ExpandShapeOp>
};
-/*
-// Flattens memref subview ops with more than one dimension to one dimension.
-struct FlattenSubView final : public OpConversionPattern<memref::SubViewOp> {
- using OpConversionPattern::OpConversionPattern;
+// Flattens memref subview ops with more than 1 dimension into 1-D accesses.
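+// The source is flattened to one dimension, the subview offsets are
+// linearized into it, and a reinterpret_cast restores the original
+// multi-dimensional result type.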
+struct FlattenSubView final : public OpRewritePattern<memref::SubViewOp> {
+ using OpRewritePattern::OpRewritePattern;
- LogicalResult
- matchAndRewrite(memref::SubViewOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- if (!isRankZeroOrOneMemRef(adaptor.getSource().getType())) {
- return rewriter.notifyMatchFailure(
- op, "expected converted memref of rank <= 1");
- }
- Type neededResultType =
- getTypeConverter()->convertType(op.getResult().getType());
- if (!neededResultType || !isRankZeroOrOneMemRef(neededResultType))
+ LogicalResult matchAndRewrite(memref::SubViewOp op,
+ PatternRewriter &rewriter) const override {
+ auto sourceType = dyn_cast<MemRefType>(op.getSource().getType());
+ if (!sourceType || sourceType.getRank() <= 1)
+ return failure();
+ if (!checkLayout(sourceType))
return failure();
- Value size = createTotalElementCountValue(op.getType(), op.getSizes(),
- op.getLoc(), rewriter);
- SmallVector<Value> offsets = mlir::getValueOrCreateConstantIndexOp(
- rewriter, op.getLoc(), op.getMixedOffsets());
- Value linearOffset =
- linearizeIndices(op.getSource(), offsets, op.getLoc(), rewriter);
- Value stride = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 1);
- Value newSubView = memref::SubViewOp::create(
- rewriter, op.getLoc(), adaptor.getSource(), ValueRange({linearOffset}),
- ValueRange({size}), ValueRange({stride}));
- rewriter.replaceOpWithNewOp<memref::CastOp>(op, neededResultType,
- newSubView);
+
+ MemRefType resultType = op.getType();
+ if (resultType.getRank() <= 1 || !checkLayout(resultType))
+ return failure();
+
+ unsigned elementBitWidth = sourceType.getElementTypeBitWidth();
+ if (!elementBitWidth)
+ return failure();
+
+ Location loc = op.getLoc();
+
+ // Materialize offsets as values so they can participate in linearization.
+ SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
+ SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
+ SmallVector<OpFoldResult> mixedStrides = op.getMixedStrides();
+
+ SmallVector<Value> offsetValues;
+ offsetValues.reserve(mixedOffsets.size());
+ for (OpFoldResult ofr : mixedOffsets)
+ offsetValues.push_back(getValueFromOpFoldResult(rewriter, loc, ofr));
+
+ auto [flatSource, linearOffset] =
+ getFlattenMemrefAndOffset(rewriter, loc, op.getSource(),
+ ValueRange(offsetValues));
+
+ memref::ExtractStridedMetadataOp sourceMetadata =
+ memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getSource());
+
+ SmallVector<OpFoldResult> sourceStrides =
+ sourceMetadata.getConstifiedMixedStrides();
+ OpFoldResult sourceOffset = sourceMetadata.getConstifiedMixedOffset();
+
+ llvm::SmallBitVector droppedDims = op.getDroppedDims();
+
+ SmallVector<OpFoldResult> resultSizes;
+ SmallVector<OpFoldResult> resultStrides;
+ resultSizes.reserve(resultType.getRank());
+ resultStrides.reserve(resultType.getRank());
+
+ OpFoldResult resultOffset = sourceOffset;
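+    // Accumulate offset#i * sourceStride#i into the result offset and derive
+    // the result sizes/strides, folding statically whenever possible.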
+ for (auto [idx, it] : llvm::enumerate(llvm::zip_equal(
+ mixedOffsets, sourceStrides, mixedSizes, mixedStrides))) {
+ auto [offsetOfr, strideOfr, sizeOfr, relativeStrideOfr] = it;
+ OpFoldResult contribution = [&]() -> OpFoldResult {
+ if (Attribute offsetAttr = dyn_cast<Attribute>(offsetOfr)) {
+ if (Attribute strideAttr = dyn_cast<Attribute>(strideOfr)) {
+ auto offsetInt = cast<IntegerAttr>(offsetAttr).getInt();
+ auto strideInt = cast<IntegerAttr>(strideAttr).getInt();
+ return rewriter.getIndexAttr(offsetInt * strideInt);
+ }
+ }
+ Value offsetVal = getValueFromOpFoldResult(rewriter, loc, offsetOfr);
+ Value strideVal = getValueFromOpFoldResult(rewriter, loc, strideOfr);
+        return arith::MulIOp::create(rewriter, loc, offsetVal, strideVal)
+            .getResult();
+ }();
+ resultOffset = [&]() -> OpFoldResult {
+ if (Attribute offsetAttr = dyn_cast<Attribute>(resultOffset)) {
+ if (Attribute contribAttr = dyn_cast<Attribute>(contribution)) {
+ auto offsetInt = cast<IntegerAttr>(offsetAttr).getInt();
+ auto contribInt = cast<IntegerAttr>(contribAttr).getInt();
+ return rewriter.getIndexAttr(offsetInt + contribInt);
+ }
+ }
+ Value offsetVal = getValueFromOpFoldResult(rewriter, loc, resultOffset);
+ Value contribVal = getValueFromOpFoldResult(rewriter, loc, contribution);
+        return arith::AddIOp::create(rewriter, loc, offsetVal, contribVal)
+            .getResult();
+ }();
+
+ if (droppedDims.test(idx))
+ continue;
+
+ resultSizes.push_back(sizeOfr);
+ OpFoldResult combinedStride = [&]() -> OpFoldResult {
+ if (Attribute relStrideAttr = dyn_cast<Attribute>(relativeStrideOfr)) {
+ if (Attribute strideAttr = dyn_cast<Attribute>(strideOfr)) {
+ auto relStrideInt = cast<IntegerAttr>(relStrideAttr).getInt();
+ auto strideInt = cast<IntegerAttr>(strideAttr).getInt();
+ return rewriter.getIndexAttr(relStrideInt * strideInt);
+ }
+ }
+ Value relStrideVal =
+ getValueFromOpFoldResult(rewriter, loc, relativeStrideOfr);
+ Value strideVal = getValueFromOpFoldResult(rewriter, loc, strideOfr);
+ return rewriter.create<arith::MulIOp>(loc, relStrideVal, strideVal)
+ .getResult();
+ }();
+ resultStrides.push_back(combinedStride);
+ }
+
+ memref::LinearizedMemRefInfo linearizedInfo;
+ [[maybe_unused]] OpFoldResult linearizedIndex;
+ std::tie(linearizedInfo, linearizedIndex) =
+ memref::getLinearizedMemRefOffsetAndSize(
+ rewriter, loc, elementBitWidth, elementBitWidth, resultOffset,
+ resultSizes, resultStrides);
+
+ Value flattenedSize = getValueFromOpFoldResult(
+ rewriter, loc, linearizedInfo.linearizedSize);
+ Value strideOne = arith::ConstantIndexOp::create(rewriter, loc, 1);
+
+ Value flattenedSubview = memref::SubViewOp::create(
+ rewriter, loc, flatSource, ValueRange{linearOffset},
+ ValueRange{flattenedSize}, ValueRange{strideOne});
+
+ Value replacement = memref::ReinterpretCastOp::create(
+ rewriter, loc, resultType, flattenedSubview, resultOffset, resultSizes,
+ resultStrides);
+
+ rewriter.replaceOp(op, replacement);
return success();
}
};
-*/
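The rewritten FlattenSubView above proceeds in three steps: linearize the source memref and the subview offsets into a flat buffer plus a scalar offset, carve the accessed span out of that buffer with a rank-1 subview, and reinterpret_cast the span back to the original strided result type. A minimal before/after sketch of the intended rewrite (shapes and SSA names are illustrative and mirror the new test added below; the dynamic types on the intermediate ops are an assumption):

  // Before:
  %s = memref.subview %m[0, 1] [2, 2] [1, 1]
      : memref<3x4xf32> to memref<2x2xf32, strided<[4, 1], offset: 1>>
  // After (conceptually):
  %flat = memref.reinterpret_cast %m to offset: [0], sizes: [12], strides: [1]
      : memref<3x4xf32> to memref<12xf32, strided<[1]>>
  %span = memref.subview %flat[%c1] [%c8] [%c1]
      : memref<12xf32, strided<[1]>> to memref<?xf32, strided<[?], offset: ?>>
  %s = memref.reinterpret_cast %span to offset: [1], sizes: [2, 2], strides: [4, 1]
      : memref<?xf32, strided<[?], offset: ?>> to memref<2x2xf32, strided<[4, 1], offset: 1>>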
struct FlattenMemrefsPass
: public mlir::memref::impl::FlattenMemrefsPassBase<FlattenMemrefsPass> {
@@ -422,18 +520,6 @@ struct FlattenMemrefsPass
} // namespace
-void memref::populateFlattenVectorOpsOnMemrefPatterns(
- RewritePatternSet &patterns) {
- patterns.insert<MemRefRewritePattern<vector::LoadOp>,
- MemRefRewritePattern<vector::StoreOp>,
- MemRefRewritePattern<vector::TransferReadOp>,
- MemRefRewritePattern<vector::TransferWriteOp>,
- MemRefRewritePattern<vector::MaskedLoadOp>,
- MemRefRewritePattern<vector::MaskedStoreOp>>(
- patterns.getContext());
-}
-
-/// Special pattern for GetGlobalOp to avoid infinite loops
struct FlattenGetGlobal : public OpRewritePattern<memref::GetGlobalOp> {
using OpRewritePattern::OpRewritePattern;
@@ -470,6 +556,17 @@ struct FlattenGetGlobal : public OpRewritePattern<memref::GetGlobalOp> {
}
};
+void memref::populateFlattenVectorOpsOnMemrefPatterns(
+ RewritePatternSet &patterns) {
+ patterns.insert<MemRefRewritePattern<vector::LoadOp>,
+ MemRefRewritePattern<vector::StoreOp>,
+ MemRefRewritePattern<vector::TransferReadOp>,
+ MemRefRewritePattern<vector::TransferWriteOp>,
+ MemRefRewritePattern<vector::MaskedLoadOp>,
+ MemRefRewritePattern<vector::MaskedStoreOp>>(
+ patterns.getContext());
+}
+
void memref::populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns) {
patterns.insert<MemRefRewritePattern<memref::LoadOp>,
MemRefRewritePattern<memref::StoreOp>,
@@ -478,7 +575,7 @@ void memref::populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns) {
MemRefRewritePattern<memref::DeallocOp>,
FlattenExpandShape,
FlattenCollapseShape,
- //FlattenSubView,
+ FlattenSubView,
FlattenGetGlobal,
FlattenGlobal>(
patterns.getContext());
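With both populate entry points exposed, a downstream pass can combine them freely. A minimal sketch, assuming the standard greedy-driver setup (the wrapper function and its name are illustrative, not part of this patch):

  #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
  #include "mlir/IR/PatternMatch.h"
  #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

  // Illustrative: gather the memref flattening patterns plus the
  // vector-on-memref ones and apply them to a fixpoint.
  static void flattenAllMemrefAccesses(mlir::Operation *root) {
    mlir::RewritePatternSet patterns(root->getContext());
    mlir::memref::populateFlattenMemrefOpsPatterns(patterns);
    mlir::memref::populateFlattenVectorOpsOnMemrefPatterns(patterns);
    (void)mlir::applyPatternsGreedily(root, std::move(patterns));
  }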
diff --git a/mlir/test/Dialect/MemRef/flatten_memref.mlir b/mlir/test/Dialect/MemRef/flatten_memref.mlir
index 2a5f141dbe328..05290f96f45e9 100644
--- a/mlir/test/Dialect/MemRef/flatten_memref.mlir
+++ b/mlir/test/Dialect/MemRef/flatten_memref.mlir
@@ -194,6 +194,19 @@ func.func @mask_load_vector_from_memref_dynamic(%input: memref<3x7xi2>, %row: in
// -----
+func.func @flatten_subview_static(%arg0: memref<3x4xf32, strided<[4, 1], offset: 0>>) -> memref<2x2xf32, strided<[4, 1], offset: 1>> {
+ %sub = memref.subview %arg0[0, 1] [2, 2] [1, 1]
+ : memref<3x4xf32, strided<[4, 1], offset: 0>> to memref<2x2xf32, strided<[4, 1], offset: 1>>
+ return %sub : memref<2x2xf32, strided<[4, 1], offset: 1>>
+}
+// CHECK-LABEL: func @flatten_subview_static
+// CHECK: %[[C8:.*]] = arith.constant 8 : index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[FLAT:.*]] = memref.reinterpret_cast %arg0 to offset: [0], sizes: [12], strides: [1]
+// CHECK: %[[SUB:.*]] = memref.subview %[[FLAT]][%[[C1]]] [%[[C8]]] [%[[C1]]]
+// CHECK: %[[CAST:.*]] = memref.reinterpret_cast %[[SUB]] to offset: [1], sizes: [2, 2], strides: [4, 1]
+// CHECK: return %[[CAST]]
+
func.func @collapse_shape_static(%arg0: memref<2x3x4xf32>) -> memref<6x4xf32> {
%0 = memref.collapse_shape %arg0 [[0, 1], [2]]
: memref<2x3x4xf32> into memref<6x4xf32>
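A quick reading of the constants in the new subview test above: the flattened source holds 3 * 4 = 12 elements; the linearized offset of subview indices [0, 1] against the source strides [4, 1] is 0 * 4 + 1 * 1 = 1; and the rank-1 subview carves out a span of 8 elements (here the outer size times the outer stride, 2 * 4) starting at that offset, which the final reinterpret_cast then views as the 2x2 strided result.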
>From 70690626399a48b9801f854ec5d311b4c0e8106b Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Fri, 19 Sep 2025 17:58:44 -0400
Subject: [PATCH 6/6] Fix format and C++20 issue
---
.../Dialect/MemRef/Transforms/Transforms.h | 34 +++++-----
.../Transforms/ExpandStridedMetadata.cpp | 34 +++++-----
.../MemRef/Transforms/FlattenMemRefs.cpp | 67 ++++++++++---------
3 files changed, 71 insertions(+), 64 deletions(-)
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index 562b8c11225e8..d40cb5ee8a064 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -18,8 +18,8 @@
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/ArrayRef.h"
-#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/STLFunctionalExtras.h"
+#include "llvm/ADT/SmallVector.h"
namespace mlir {
class Location;
@@ -236,9 +236,10 @@ memref::AllocaOp allocToAlloca(
/// TODO: Move this utility function directly within ExpandShapeOp. For now,
/// this is not possible because this function uses the Affine dialect and the
/// MemRef dialect cannot depend on the Affine dialect.
-SmallVector<OpFoldResult>
-getExpandedSizes(ExpandShapeOp expandShape, OpBuilder &builder,
- ArrayRef<OpFoldResult> origSizes, unsigned groupId);
+SmallVector<OpFoldResult> getExpandedSizes(ExpandShapeOp expandShape,
+ OpBuilder &builder,
+ ArrayRef<OpFoldResult> origSizes,
+ unsigned groupId);
/// Compute the expanded strides of the given \p expandShape for the
/// \p groupId-th reassociation group.
@@ -277,11 +278,10 @@ SmallVector<OpFoldResult> getExpandedStrides(ExpandShapeOp expandShape,
///
/// \pre for all index in indices: index < values.size()
/// \pre for all index in indices: index < maybeConstants.size()
-OpFoldResult
-getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
- ArrayRef<int64_t> maybeConstants,
- ArrayRef<OpFoldResult> values,
- llvm::function_ref<bool(int64_t)> isDynamic);
+OpFoldResult getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder,
+ Location loc, ArrayRef<int64_t> maybeConstants,
+ ArrayRef<OpFoldResult> values,
+ llvm::function_ref<bool(int64_t)> isDynamic);
/// Compute the collapsed size of the given \p collapseShape for the
/// \p groupId-th reassociation group.
@@ -291,9 +291,10 @@ getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
/// TODO: Move this utility function directly within CollapseShapeOp. For now,
/// this is not possible because this function uses the Affine dialect and the
/// MemRef dialect cannot depend on the Affine dialect.
-SmallVector<OpFoldResult>
-getCollapsedSize(CollapseShapeOp collapseShape, OpBuilder &builder,
- ArrayRef<OpFoldResult> origSizes, unsigned groupId);
+SmallVector<OpFoldResult> getCollapsedSize(CollapseShapeOp collapseShape,
+ OpBuilder &builder,
+ ArrayRef<OpFoldResult> origSizes,
+ unsigned groupId);
/// Compute the collapsed stride of the given \p collapseShape for the
/// \p groupId-th reassociation group.
@@ -307,10 +308,11 @@ getCollapsedSize(CollapseShapeOp collapseShape, OpBuilder &builder,
///
/// \post result.size() == 1, in other words, each group collapses to one
/// dimension.
-SmallVector<OpFoldResult>
-getCollapsedStride(CollapseShapeOp collapseShape, OpBuilder &builder,
- ArrayRef<OpFoldResult> origSizes,
- ArrayRef<OpFoldResult> origStrides, unsigned groupId);
+SmallVector<OpFoldResult> getCollapsedStride(CollapseShapeOp collapseShape,
+ OpBuilder &builder,
+ ArrayRef<OpFoldResult> origSizes,
+ ArrayRef<OpFoldResult> origStrides,
+ unsigned groupId);
} // namespace memref
} // namespace mlir
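Since the \pre contracts above are easy to violate, a concrete (hypothetical) call to the re-wrapped getProductOfValues may help; the variable names are illustrative and ShapedType::isDynamic is assumed as the dynamic-ness predicate:

  // Illustrative: product of dims 1 and 2 of a memref type, where `sizes`
  // holds one OpFoldResult per dim (dynamic dims as SSA Values), so both
  // \pre conditions (index < values.size(), index < maybeConstants.size())
  // hold by construction.
  mlir::OpFoldResult product = mlir::memref::getProductOfValues(
      /*indices=*/{1, 2}, builder, loc,
      /*maybeConstants=*/memrefType.getShape(),
      /*values=*/sizes,
      /*isDynamic=*/[](int64_t d) { return mlir::ShapedType::isDynamic(d); });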
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index 6b69d0e366903..96bceae88af9d 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -21,10 +21,10 @@
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
-#include "mlir/IR/OpDefinition.h"
#include <optional>
namespace mlir {
@@ -256,9 +256,10 @@ struct ExtractStridedMetadataOpSubviewFolder
namespace mlir {
namespace memref {
-SmallVector<OpFoldResult>
-getExpandedSizes(ExpandShapeOp expandShape, OpBuilder &builder,
- ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
+SmallVector<OpFoldResult> getExpandedSizes(ExpandShapeOp expandShape,
+ OpBuilder &builder,
+ ArrayRef<OpFoldResult> origSizes,
+ unsigned groupId) {
SmallVector<int64_t, 2> reassocGroup =
expandShape.getReassociationIndices()[groupId];
assert(!reassocGroup.empty() &&
@@ -372,11 +373,10 @@ SmallVector<OpFoldResult> getExpandedStrides(ExpandShapeOp expandShape,
return expandedStrides;
}
-OpFoldResult
-getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
- ArrayRef<int64_t> maybeConstants,
- ArrayRef<OpFoldResult> values,
- llvm::function_ref<bool(int64_t)> isDynamic) {
+OpFoldResult getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder,
+ Location loc, ArrayRef<int64_t> maybeConstants,
+ ArrayRef<OpFoldResult> values,
+ llvm::function_ref<bool(int64_t)> isDynamic) {
AffineExpr productOfValues = builder.getAffineConstantExpr(1);
SmallVector<OpFoldResult> inputValues;
unsigned numberOfSymbols = 0;
@@ -410,9 +410,10 @@ getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
/// TODO: Move this utility function directly within CollapseShapeOp. For now,
/// this is not possible because this function uses the Affine dialect and the
/// MemRef dialect cannot depend on the Affine dialect.
-SmallVector<OpFoldResult>
-getCollapsedSize(CollapseShapeOp collapseShape, OpBuilder &builder,
- ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
+SmallVector<OpFoldResult> getCollapsedSize(CollapseShapeOp collapseShape,
+ OpBuilder &builder,
+ ArrayRef<OpFoldResult> origSizes,
+ unsigned groupId) {
SmallVector<OpFoldResult> collapsedSize;
MemRefType collapseShapeType = collapseShape.getResultType();
@@ -451,10 +452,11 @@ getCollapsedSize(CollapseShapeOp collapseShape, OpBuilder &builder,
///
/// \post result.size() == 1, in other words, each group collapses to one
/// dimension.
-SmallVector<OpFoldResult>
-getCollapsedStride(CollapseShapeOp collapseShape, OpBuilder &builder,
- ArrayRef<OpFoldResult> origSizes,
- ArrayRef<OpFoldResult> origStrides, unsigned groupId) {
+SmallVector<OpFoldResult> getCollapsedStride(CollapseShapeOp collapseShape,
+ OpBuilder &builder,
+ ArrayRef<OpFoldResult> origSizes,
+ ArrayRef<OpFoldResult> origStrides,
+ unsigned groupId) {
SmallVector<int64_t, 2> reassocGroup =
collapseShape.getReassociationIndices()[groupId];
assert(!reassocGroup.empty() &&
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
index 43a67f1fab2be..1e6805139f7ff 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
@@ -21,9 +21,9 @@
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Attributes.h"
-#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -48,7 +48,6 @@ static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc,
return cast<Value>(in);
}
-
/// Returns a collapsed memref and the linearized index to access the element
/// at the specified indices.
static std::pair<Value, Value> getFlattenMemrefAndOffset(OpBuilder &rewriter,
@@ -281,9 +280,8 @@ struct FlattenGlobal final : public OpRewritePattern<memref::GlobalOp> {
return {};
}
- LogicalResult
- matchAndRewrite(memref::GlobalOp globalOp,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(memref::GlobalOp globalOp,
+ PatternRewriter &rewriter) const override {
auto oldType = llvm::dyn_cast<MemRefType>(globalOp.getType());
if (!oldType || !oldType.getLayout().isIdentity() || oldType.getRank() <= 1)
return failure();
@@ -314,7 +312,8 @@ struct FlattenCollapseShape final
memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getSrc());
SmallVector<OpFoldResult> origSizes = metadata.getConstifiedMixedSizes();
- SmallVector<OpFoldResult> origStrides = metadata.getConstifiedMixedStrides();
+ SmallVector<OpFoldResult> origStrides =
+ metadata.getConstifiedMixedStrides();
OpFoldResult offset = metadata.getConstifiedMixedOffset();
SmallVector<OpFoldResult> collapsedSizes;
@@ -338,7 +337,8 @@ struct FlattenCollapseShape final
}
};
-struct FlattenExpandShape final : public OpRewritePattern<memref::ExpandShapeOp> {
+struct FlattenExpandShape final
+ : public OpRewritePattern<memref::ExpandShapeOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(memref::ExpandShapeOp op,
@@ -348,7 +348,8 @@ struct FlattenExpandShape final : public OpRewritePattern<memref::ExpandShapeOp>
memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getSrc());
SmallVector<OpFoldResult> origSizes = metadata.getConstifiedMixedSizes();
- SmallVector<OpFoldResult> origStrides = metadata.getConstifiedMixedStrides();
+ SmallVector<OpFoldResult> origStrides =
+ metadata.getConstifiedMixedStrides();
OpFoldResult offset = metadata.getConstifiedMixedOffset();
SmallVector<OpFoldResult> expandedSizes;
@@ -372,7 +373,6 @@ struct FlattenExpandShape final : public OpRewritePattern<memref::ExpandShapeOp>
}
};
-
// Flattens memref subview ops with more than 1 dimension into 1-D accesses.
struct FlattenSubView final : public OpRewritePattern<memref::SubViewOp> {
using OpRewritePattern::OpRewritePattern;
@@ -405,9 +405,8 @@ struct FlattenSubView final : public OpRewritePattern<memref::SubViewOp> {
for (OpFoldResult ofr : mixedOffsets)
offsetValues.push_back(getValueFromOpFoldResult(rewriter, loc, ofr));
- auto [flatSource, linearOffset] =
- getFlattenMemrefAndOffset(rewriter, loc, op.getSource(),
- ValueRange(offsetValues));
+ auto [flatSource, linearOffset] = getFlattenMemrefAndOffset(
+ rewriter, loc, op.getSource(), ValueRange(offsetValues));
memref::ExtractStridedMetadataOp sourceMetadata =
memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getSource());
@@ -424,9 +423,14 @@ struct FlattenSubView final : public OpRewritePattern<memref::SubViewOp> {
resultStrides.reserve(resultType.getRank());
OpFoldResult resultOffset = sourceOffset;
- for (auto [idx, it] : llvm::enumerate(llvm::zip_equal(
+ for (auto zipped : llvm::enumerate(llvm::zip_equal(
mixedOffsets, sourceStrides, mixedSizes, mixedStrides))) {
- auto [offsetOfr, strideOfr, sizeOfr, relativeStrideOfr] = it;
+ auto idx = zipped.index();
+ auto it = zipped.value();
+ auto offsetOfr = std::get<0>(it);
+ auto strideOfr = std::get<1>(it);
+ auto sizeOfr = std::get<2>(it);
+ auto relativeStrideOfr = std::get<3>(it);
OpFoldResult contribution = [&]() -> OpFoldResult {
if (Attribute offsetAttr = dyn_cast<Attribute>(offsetOfr)) {
if (Attribute strideAttr = dyn_cast<Attribute>(strideOfr)) {
@@ -449,7 +453,8 @@ struct FlattenSubView final : public OpRewritePattern<memref::SubViewOp> {
}
}
Value offsetVal = getValueFromOpFoldResult(rewriter, loc, resultOffset);
- Value contribVal = getValueFromOpFoldResult(rewriter, loc, contribution);
+ Value contribVal =
+ getValueFromOpFoldResult(rewriter, loc, contribution);
return rewriter.create<arith::AddIOp>(loc, offsetVal, contribVal)
.getResult();
}();
@@ -478,12 +483,12 @@ struct FlattenSubView final : public OpRewritePattern<memref::SubViewOp> {
memref::LinearizedMemRefInfo linearizedInfo;
[[maybe_unused]] OpFoldResult linearizedIndex;
std::tie(linearizedInfo, linearizedIndex) =
- memref::getLinearizedMemRefOffsetAndSize(
- rewriter, loc, elementBitWidth, elementBitWidth, resultOffset,
- resultSizes, resultStrides);
+ memref::getLinearizedMemRefOffsetAndSize(rewriter, loc, elementBitWidth,
+ elementBitWidth, resultOffset,
+ resultSizes, resultStrides);
- Value flattenedSize = getValueFromOpFoldResult(
- rewriter, loc, linearizedInfo.linearizedSize);
+ Value flattenedSize =
+ getValueFromOpFoldResult(rewriter, loc, linearizedInfo.linearizedSize);
Value strideOne = arith::ConstantIndexOp::create(rewriter, loc, 1);
Value flattenedSubview = memref::SubViewOp::create(
@@ -524,10 +529,11 @@ struct FlattenGetGlobal : public OpRewritePattern<memref::GetGlobalOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(memref::GetGlobalOp op,
- PatternRewriter &rewriter) const override {
+ PatternRewriter &rewriter) const override {
// Check if this get_global references a multi-dimensional global
auto module = op->template getParentOfType<ModuleOp>();
- auto globalOp = module.template lookupSymbol<memref::GlobalOp>(op.getName());
+ auto globalOp =
+ module.template lookupSymbol<memref::GlobalOp>(op.getName());
if (!globalOp) {
return failure();
}
@@ -537,12 +543,13 @@ struct FlattenGetGlobal : public OpRewritePattern<memref::GetGlobalOp> {
// Only apply if the global has been flattened but the get_global hasn't
if (globalType.getRank() == 1 && resultType.getRank() > 1) {
- auto newGetGlobal = memref::GetGlobalOp::create(
- rewriter, op.getLoc(), globalType, op.getName());
+ auto newGetGlobal = memref::GetGlobalOp::create(rewriter, op.getLoc(),
+ globalType, op.getName());
// Cast the flattened result back to the original shape
memref::ExtractStridedMetadataOp stridedMetadata =
- memref::ExtractStridedMetadataOp::create(rewriter, op.getLoc(), op.getResult());
+ memref::ExtractStridedMetadataOp::create(rewriter, op.getLoc(),
+ op.getResult());
auto castResult = memref::ReinterpretCastOp::create(
rewriter, op.getLoc(), resultType, newGetGlobal,
/*offset=*/rewriter.getIndexAttr(0),
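The comment above captures the intended interplay: FlattenGlobal rewrites the global symbol to rank 1 first, leaving any multi-dimensional get_global transiently stale, and FlattenGetGlobal then re-targets it and casts the flat result back. A conceptual sketch of that transient mid-rewrite state (names and shapes illustrative):

  memref.global "private" @g : memref<6xf32>   // already flattened from memref<2x3xf32>
  %0 = memref.get_global @g : memref<2x3xf32>  // stale, still the old rank-2 type
  // ... FlattenGetGlobal rewrites it to:
  %flat = memref.get_global @g : memref<6xf32>
  %0 = memref.reinterpret_cast %flat to offset: [0], sizes: [2, 3], strides: [3, 1]
      : memref<6xf32> to memref<2x3xf32>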
@@ -572,13 +579,9 @@ void memref::populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns) {
MemRefRewritePattern<memref::StoreOp>,
MemRefRewritePattern<memref::AllocOp>,
MemRefRewritePattern<memref::AllocaOp>,
- MemRefRewritePattern<memref::DeallocOp>,
- FlattenExpandShape,
- FlattenCollapseShape,
- FlattenSubView,
- FlattenGetGlobal,
- FlattenGlobal>(
- patterns.getContext());
+ MemRefRewritePattern<memref::DeallocOp>, FlattenExpandShape,
+ FlattenCollapseShape, FlattenSubView, FlattenGetGlobal,
+ FlattenGlobal>(patterns.getContext());
}
void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {