[Mlir-commits] [mlir] [MLIR][MemRef] Fix AllocOp/AllocaOp flattening domination violation (PR #188980)
Mehdi Amini
llvmlistbot at llvm.org
Thu Apr 2 06:29:45 PDT 2026
https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/188980
>From 4692a60aeaf6fc8b84960b8469b1ec05db27fb7c Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Thu, 26 Mar 2026 15:57:26 -0700
Subject: [PATCH] [MLIR][MemRef] Fix AllocOp/AllocaOp flattening domination
violation
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
The generic MemRefRewritePattern handles AllocOp/AllocaOp by calling
getFlattenMemrefAndOffset with the op's own result as the source memref.
This inserts ExtractStridedMetadataOp and ReinterpretCastOp that consume
op.result before the alloc op itself in the block. After
replaceOpWithNewOp, op.result is RAUW'd to the new ReinterpretCastOp
result, leaving those earlier ops with forward references — a domination
violation caught by MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS.
Replace the AllocOp/AllocaOp cases in MemRefRewritePattern with a
dedicated AllocLikeFlattenPattern that never touches op.result until the
final replaceOpWithNewOp:
- sizes come from op.getMixedSizes() (operands, not the result)
- strides and offset come from getStridesAndOffset on the MemRefType
- the flat allocation size is computed via getLinearizedMemRefOffsetAndSize,
plus the static base offset so the buffer covers [0, offset+extent)
- castAllocResult takes pre-computed offset/sizes/strides instead of
inserting an ExtractStridedMetadataOp on the original op
- non-zero static base offsets are now correctly preserved in the
reinterpret_cast (the old code hardcoded offset=0, which was a
verifier error for layouts with offset \!= 0)
- dynamic offsets or strides bail out via notifyMatchFailure
Also remove the now-dead AllocOp/AllocaOp branches from replaceOp() and
the constexpr specialisation in getIndices().
Add tests for: dynamic single-dim alloc/alloca, explicit row-major strides,
non-zero static offset, padded strides, and multi-dynamic (rejected) cases.
Fixes a failure reproducible with MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS=ON.
Assisted-by: Claude Code
---
.../MemRef/Transforms/FlattenMemRefs.cpp | 130 ++++++++++++++----
mlir/test/Dialect/MemRef/flatten_memref.mlir | 75 ++++++++++
2 files changed, 175 insertions(+), 30 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
index 32244728ff333..1bc921a55ffe6 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
@@ -108,15 +108,11 @@ static Value getTargetMemref(Operation *op) {
}
template <typename T>
-static void castAllocResult(T oper, T newOper, Location loc,
- PatternRewriter &rewriter) {
- memref::ExtractStridedMetadataOp stridedMetadata =
- memref::ExtractStridedMetadataOp::create(rewriter, loc, oper);
+static void castAllocResult(T oper, T newOper, PatternRewriter &rewriter,
+ OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
+ ArrayRef<OpFoldResult> strides) {
rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
- oper, cast<MemRefType>(oper.getType()), newOper,
- /*offset=*/rewriter.getIndexAttr(0),
- stridedMetadata.getConstifiedMixedSizes(),
- stridedMetadata.getConstifiedMixedStrides());
+ oper, cast<MemRefType>(oper.getType()), newOper, offset, sizes, strides);
}
template <typename T>
@@ -124,18 +120,6 @@ static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref,
Value offset) {
Location loc = op->getLoc();
llvm::TypeSwitch<Operation *>(op.getOperation())
- .Case([&](memref::AllocOp oper) {
- auto newAlloc = memref::AllocOp::create(
- rewriter, loc, cast<MemRefType>(flatMemref.getType()),
- oper.getAlignmentAttr());
- castAllocResult(oper, newAlloc, loc, rewriter);
- })
- .Case([&](memref::AllocaOp oper) {
- auto newAlloca = memref::AllocaOp::create(
- rewriter, loc, cast<MemRefType>(flatMemref.getType()),
- oper.getAlignmentAttr());
- castAllocResult(oper, newAlloca, loc, rewriter);
- })
.Case([&](memref::LoadOp op) {
auto newLoad =
memref::LoadOp::create(rewriter, loc, op->getResultTypes(),
@@ -196,12 +180,7 @@ static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref,
template <typename T>
static ValueRange getIndices(T op) {
- if constexpr (std::is_same_v<T, memref::AllocaOp> ||
- std::is_same_v<T, memref::AllocOp>) {
- return ValueRange{};
- } else {
- return op.getIndices();
- }
+ return op.getIndices();
}
template <typename T>
@@ -230,19 +209,110 @@ static LogicalResult canBeFlattened(T op, PatternRewriter &rewriter) {
.Default([&](auto op) { return success(); });
}
+// Pattern for memref::AllocOp and memref::AllocaOp.
+//
+// The "source" memref for these ops IS the op's own result, so the generic
+// MemRefRewritePattern cannot be used: getFlattenMemrefAndOffset would insert
+// ExtractStridedMetadataOp and ReinterpretCastOp that use op.result BEFORE op
+// in the block. After replaceOpWithNewOp the original result is RAUW'd to the
+// new ReinterpretCastOp, leaving the earlier ops with forward references
+// (domination violations) caught by MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS.
+//
+// Instead, sizes and strides are computed from the op's operands and type
+// (which all dominate the op), avoiding any reference to op.result until the
+// final replaceOpWithNewOp inside castAllocResult.
+template <typename AllocLikeOp>
+struct AllocLikeFlattenPattern : public OpRewritePattern<AllocLikeOp> {
+ using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(AllocLikeOp op,
+ PatternRewriter &rewriter) const override {
+ if (!needFlattening(op.getMemref()) || !checkLayout(op.getMemref()))
+ return failure();
+
+ Location loc = op->getLoc();
+ auto memrefType = cast<MemRefType>(op.getType());
+ auto elemType = memrefType.getElementType();
+ if (!elemType.isIntOrFloat())
+ return failure();
+ unsigned elemBitWidth = elemType.getIntOrFloatBitWidth();
+
+ SmallVector<OpFoldResult> sizes = op.getMixedSizes();
+
+ int64_t staticOffset;
+ SmallVector<int64_t> staticStrides;
+ if (failed(memrefType.getStridesAndOffset(staticStrides, staticOffset)))
+ return failure();
+ if (staticOffset == ShapedType::kDynamic)
+ return rewriter.notifyMatchFailure(op, "dynamic offset not supported");
+ SmallVector<OpFoldResult> strides;
+ strides.reserve(staticStrides.size());
+ for (int64_t stride : staticStrides) {
+ if (stride == ShapedType::kDynamic)
+ return rewriter.notifyMatchFailure(op,
+ "dynamic stride cannot be computed");
+ strides.push_back(rewriter.getIndexAttr(stride));
+ }
+
+ // Compute the linearized flat extent from sizes and strides (no SSA ops
+ // referencing op.result are created here).
+ memref::LinearizedMemRefInfo linearizedInfo;
+ OpFoldResult linearizedOffset;
+ std::tie(linearizedInfo, linearizedOffset) =
+ memref::getLinearizedMemRefOffsetAndSize(
+ rewriter, loc, elemBitWidth, elemBitWidth, rewriter.getIndexAttr(0),
+ sizes, strides);
+ (void)linearizedOffset;
+
+ // The total allocation must cover [0, staticOffset + linearizedExtent).
+ // When the offset is non-zero, add it to the computed extent so that the
+ // buffer is large enough for elements accessed at positions
+ // [staticOffset, staticOffset + linearizedExtent).
+ OpFoldResult flatSizeOfr = linearizedInfo.linearizedSize;
+ if (staticOffset != 0) {
+ AffineExpr s0;
+ bindSymbols(rewriter.getContext(), s0);
+ flatSizeOfr = affine::makeComposedFoldedAffineApply(
+ rewriter, loc, s0 + staticOffset, {flatSizeOfr});
+ }
+
+ // Build the flat 1-D MemRefType. The linearized size may be static or
+ // dynamic (OpFoldResult of either IntegerAttr or a Value).
+ int64_t flatDimSize = ShapedType::kDynamic;
+ if (auto attr = dyn_cast<Attribute>(flatSizeOfr))
+ if (auto intAttr = dyn_cast<IntegerAttr>(attr))
+ flatDimSize = intAttr.getInt();
+
+ auto flatMemrefType =
+ MemRefType::get({flatDimSize}, memrefType.getElementType(),
+ StridedLayoutAttr::get(rewriter.getContext(), 0, {1}),
+ memrefType.getMemorySpace());
+
+ // Collect the flat dynamic-size operand (empty for fully-static case).
+ SmallVector<Value, 1> dynSizes;
+ if (flatDimSize == ShapedType::kDynamic)
+ dynSizes.push_back(getValueFromOpFoldResult(rewriter, loc, flatSizeOfr));
+
+ auto newOp = AllocLikeOp::create(rewriter, loc, flatMemrefType, dynSizes,
+ op.getAlignmentAttr());
+ castAllocResult(op, newOp, rewriter, rewriter.getIndexAttr(staticOffset),
+ sizes, strides);
+ return success();
+ }
+};
+
template <typename T>
struct MemRefRewritePattern : public OpRewritePattern<T> {
using OpRewritePattern<T>::OpRewritePattern;
LogicalResult matchAndRewrite(T op,
PatternRewriter &rewriter) const override {
LogicalResult canFlatten = canBeFlattened(op, rewriter);
- if (failed(canFlatten)) {
+ if (failed(canFlatten))
return canFlatten;
- }
Value memref = getTargetMemref(op);
if (!needFlattening(memref) || !checkLayout(memref))
return failure();
+
auto &&[flatMemref, offset] = getFlattenMemrefAndOffset(
rewriter, op->getLoc(), memref, getIndices<T>(op));
replaceOp<T>(op, rewriter, flatMemref, offset);
@@ -285,8 +355,8 @@ void memref::populateFlattenVectorOpsOnMemrefPatterns(
void memref::populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns) {
patterns.insert<MemRefRewritePattern<memref::LoadOp>,
MemRefRewritePattern<memref::StoreOp>,
- MemRefRewritePattern<memref::AllocOp>,
- MemRefRewritePattern<memref::AllocaOp>>(
+ AllocLikeFlattenPattern<memref::AllocOp>,
+ AllocLikeFlattenPattern<memref::AllocaOp>>(
patterns.getContext());
}
diff --git a/mlir/test/Dialect/MemRef/flatten_memref.mlir b/mlir/test/Dialect/MemRef/flatten_memref.mlir
index e45a10ca0d431..a76da860a725e 100644
--- a/mlir/test/Dialect/MemRef/flatten_memref.mlir
+++ b/mlir/test/Dialect/MemRef/flatten_memref.mlir
@@ -271,6 +271,81 @@ func.func @alloca() -> memref<4x8xf32> {
// -----
+func.func @alloc_dynamic(%n: index) -> memref<?x4xf32> {
+ %0 = memref.alloc(%n) : memref<?x4xf32>
+ return %0 : memref<?x4xf32>
+}
+
+// CHECK-LABEL: func @alloc_dynamic
+// CHECK-SAME: (%[[N:.*]]: index)
+// CHECK: %[[ALLOC:.*]] = memref.alloc(%{{.*}}) : memref<?xf32, strided<[1]>>
+// CHECK: memref.reinterpret_cast %[[ALLOC]] to offset: [0], sizes: [%[[N]], 4], strides: [4, 1]
+
+// -----
+
+func.func @alloca_dynamic(%n: index) -> memref<?x4xf32> {
+ %0 = memref.alloca(%n) : memref<?x4xf32>
+ return %0 : memref<?x4xf32>
+}
+
+// CHECK-LABEL: func @alloca_dynamic
+// CHECK-SAME: (%[[N:.*]]: index)
+// CHECK: %[[ALLOCA:.*]] = memref.alloca(%{{.*}}) : memref<?xf32, strided<[1]>>
+// CHECK: memref.reinterpret_cast %[[ALLOCA]] to offset: [0], sizes: [%[[N]], 4], strides: [4, 1]
+
+// -----
+
+// Explicit row-major strides: same as the default layout, should flatten.
+func.func @flatten_alloc_strided_row_major() -> memref<4x8xf32, strided<[8, 1]>> {
+ %0 = memref.alloc() : memref<4x8xf32, strided<[8, 1]>>
+ return %0 : memref<4x8xf32, strided<[8, 1]>>
+}
+
+// CHECK-LABEL: func @flatten_alloc_strided_row_major
+// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<32xf32, strided<[1]>>
+// CHECK: memref.reinterpret_cast %[[ALLOC]] to offset: [0], sizes: [4, 8], strides: [8, 1] : memref<32xf32, strided<[1]>> to memref<4x8xf32, strided<[8, 1]>>
+
+// -----
+
+// Non-zero static offset: the flat allocation covers [0, offset+extent) = [0, 82)
+// and the reinterpret_cast restores the original offset in the result type.
+// The old castAllocResult hardcoded offset=0 in the reinterpret_cast regardless
+// of the source layout, producing a type-mismatched (and thus invalid) cast.
+func.func @flatten_alloc_strided_offset() -> memref<4x8xf32, strided<[8, 1], offset: 50>> {
+ %0 = memref.alloc() : memref<4x8xf32, strided<[8, 1], offset: 50>>
+ return %0 : memref<4x8xf32, strided<[8, 1], offset: 50>>
+}
+
+// CHECK-LABEL: func @flatten_alloc_strided_offset
+// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<82xf32, strided<[1]>>
+// CHECK: memref.reinterpret_cast %[[ALLOC]] to offset: [50], sizes: [4, 8], strides: [8, 1] : memref<82xf32, strided<[1]>> to memref<4x8xf32, strided<[8, 1], offset: 50>>
+
+// -----
+
+// Padded strides: flatten to the maximum extent (max(18*4, 2*8) = 72).
+func.func @flatten_alloc_strided_padded() -> memref<4x8xf32, strided<[18, 2]>> {
+ %0 = memref.alloc() : memref<4x8xf32, strided<[18, 2]>>
+ return %0 : memref<4x8xf32, strided<[18, 2]>>
+}
+
+// CHECK-LABEL: func @flatten_alloc_strided_padded
+// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<72xf32, strided<[1]>>
+// CHECK: memref.reinterpret_cast %[[ALLOC]] to offset: [0], sizes: [4, 8], strides: [18, 2] : memref<72xf32, strided<[1]>> to memref<4x8xf32, strided<[18, 2]>>
+
+// -----
+
+// Multi-dynamic alloc: strides are dynamic so the pattern bails out.
+func.func @alloc_multi_dynamic(%m: index, %n: index) -> memref<?x?xf32> {
+ %0 = memref.alloc(%m, %n) : memref<?x?xf32>
+ return %0 : memref<?x?xf32>
+}
+
+// CHECK-LABEL: func @alloc_multi_dynamic
+// CHECK: memref.alloc(%{{.*}}, %{{.*}}) : memref<?x?xf32>
+// CHECK-NOT: memref.reinterpret_cast
+
+// -----
+
func.func @chained_alloc_load() -> vector<8xf32> {
%c3 = arith.constant 3 : index
%c6 = arith.constant 6 : index
More information about the Mlir-commits
mailing list