[Mlir-commits] [mlir] ff86be2 - [MLIR][MemRef] Fix AllocOp/AllocaOp flattening domination violation (#188980)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Apr 3 02:21:06 PDT 2026


Author: Mehdi Amini
Date: 2026-04-03T11:21:00+02:00
New Revision: ff86be21de109403175caf6d906be856210df494

URL: https://github.com/llvm/llvm-project/commit/ff86be21de109403175caf6d906be856210df494
DIFF: https://github.com/llvm/llvm-project/commit/ff86be21de109403175caf6d906be856210df494.diff

LOG: [MLIR][MemRef] Fix AllocOp/AllocaOp flattening domination violation (#188980)

The generic MemRefRewritePattern handles AllocOp/AllocaOp by calling
getFlattenMemrefAndOffset with the op's own result as the source memref.
This inserts ExtractStridedMetadataOp and ReinterpretCastOp that consume
op.result before the alloc op itself in the block. After
replaceOpWithNewOp, op.result is RAUW'd to the new ReinterpretCastOp
result, leaving those earlier ops with forward references — a domination
violation caught by MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS.

Replace the AllocOp/AllocaOp cases in MemRefRewritePattern with a
dedicated AllocLikeFlattenPattern that never touches op.result until the
final replaceOpWithNewOp:
- sizes come from op.getMixedSizes() (operands, not the result)
- strides come from getStridesAndOffset on the MemRefType
- the flat allocation size is computed via
getLinearizedMemRefOffsetAndSize plus the static base offset so the
buffer covers [0, offset+extent)
- castAllocResult is removed; the final reinterpret_cast is now built
directly from the pre-computed sizes and strides rather than inserting an
ExtractStridedMetadataOp on the original op
- non-zero static base offsets are now correctly preserved in the
reinterpret_cast (the old code hardcoded offset=0, which was a verifier
error for layouts with offset != 0)
- dynamic offsets or strides bail out via notifyMatchFailure

Also remove the now-dead AllocOp/AllocaOp branches from replaceOp() and
the constexpr specialisation in getIndices().

Assisted-by: Claude Code

Added: 
    

Modified: 
    mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
    mlir/test/Dialect/MemRef/flatten_memref.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
index 32244728ff333..6b56ea3ff5cac 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
@@ -107,35 +107,11 @@ static Value getTargetMemref(Operation *op) {
       .Default(nullptr);
 }
 
-template <typename T>
-static void castAllocResult(T oper, T newOper, Location loc,
-                            PatternRewriter &rewriter) {
-  memref::ExtractStridedMetadataOp stridedMetadata =
-      memref::ExtractStridedMetadataOp::create(rewriter, loc, oper);
-  rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
-      oper, cast<MemRefType>(oper.getType()), newOper,
-      /*offset=*/rewriter.getIndexAttr(0),
-      stridedMetadata.getConstifiedMixedSizes(),
-      stridedMetadata.getConstifiedMixedStrides());
-}
-
 template <typename T>
 static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref,
                       Value offset) {
   Location loc = op->getLoc();
   llvm::TypeSwitch<Operation *>(op.getOperation())
-      .Case([&](memref::AllocOp oper) {
-        auto newAlloc = memref::AllocOp::create(
-            rewriter, loc, cast<MemRefType>(flatMemref.getType()),
-            oper.getAlignmentAttr());
-        castAllocResult(oper, newAlloc, loc, rewriter);
-      })
-      .Case([&](memref::AllocaOp oper) {
-        auto newAlloca = memref::AllocaOp::create(
-            rewriter, loc, cast<MemRefType>(flatMemref.getType()),
-            oper.getAlignmentAttr());
-        castAllocResult(oper, newAlloca, loc, rewriter);
-      })
       .Case([&](memref::LoadOp op) {
         auto newLoad =
             memref::LoadOp::create(rewriter, loc, op->getResultTypes(),
@@ -196,12 +172,7 @@ static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref,
 
 template <typename T>
 static ValueRange getIndices(T op) {
-  if constexpr (std::is_same_v<T, memref::AllocaOp> ||
-                std::is_same_v<T, memref::AllocOp>) {
-    return ValueRange{};
-  } else {
-    return op.getIndices();
-  }
+  return op.getIndices();
 }
 
 template <typename T>
@@ -230,19 +201,111 @@ static LogicalResult canBeFlattened(T op, PatternRewriter &rewriter) {
       .Default([&](auto op) { return success(); });
 }
 
+// Pattern for memref::AllocOp and memref::AllocaOp.
+//
+// The "source" memref for these ops IS the op's own result, so the generic
+// MemRefRewritePattern cannot be used: getFlattenMemrefAndOffset would insert
+// ExtractStridedMetadataOp and ReinterpretCastOp that use op.result BEFORE op
+// in the block. After replaceOpWithNewOp the original result is RAUW'd to the
+// new ReinterpretCastOp, leaving the earlier ops with forward references
+// (domination violations) caught by MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS.
+//
+// Instead, sizes and strides are computed from the op's operands and type
+// (which all dominate the op), avoiding any reference to op.result until the
+// final replaceOpWithNewOp.
+template <typename AllocLikeOp>
+struct AllocLikeFlattenPattern : public OpRewritePattern<AllocLikeOp> {
+  using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AllocLikeOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!needFlattening(op.getMemref()) || !checkLayout(op.getMemref()))
+      return failure();
+
+    Location loc = op->getLoc();
+    auto memrefType = cast<MemRefType>(op.getType());
+    auto elemType = memrefType.getElementType();
+    if (!elemType.isIntOrFloat())
+      return failure();
+    unsigned elemBitWidth = elemType.getIntOrFloatBitWidth();
+
+    SmallVector<OpFoldResult> sizes = op.getMixedSizes();
+
+    int64_t staticOffset;
+    SmallVector<int64_t> staticStrides;
+    if (failed(memrefType.getStridesAndOffset(staticStrides, staticOffset)))
+      return failure();
+    if (staticOffset == ShapedType::kDynamic)
+      return rewriter.notifyMatchFailure(op, "dynamic offset not supported");
+    SmallVector<OpFoldResult> strides;
+    strides.reserve(staticStrides.size());
+    for (int64_t stride : staticStrides) {
+      if (stride == ShapedType::kDynamic)
+        return rewriter.notifyMatchFailure(op,
+                                           "dynamic stride cannot be computed");
+      strides.push_back(rewriter.getIndexAttr(stride));
+    }
+
+    // Compute the linearized flat extent from sizes and strides (no SSA ops
+    // referencing op.result are created here).
+    memref::LinearizedMemRefInfo linearizedInfo;
+    OpFoldResult linearizedOffset;
+    std::tie(linearizedInfo, linearizedOffset) =
+        memref::getLinearizedMemRefOffsetAndSize(
+            rewriter, loc, elemBitWidth, elemBitWidth, rewriter.getIndexAttr(0),
+            sizes, strides);
+    (void)linearizedOffset;
+
+    // The total allocation must cover [0, staticOffset + linearizedExtent).
+    // When the offset is non-zero, add it to the computed extent so that the
+    // buffer is large enough for elements accessed at positions
+    // [staticOffset, staticOffset + linearizedExtent).
+    OpFoldResult flatSizeOfr = linearizedInfo.linearizedSize;
+    if (staticOffset != 0) {
+      AffineExpr s0;
+      bindSymbols(rewriter.getContext(), s0);
+      flatSizeOfr = affine::makeComposedFoldedAffineApply(
+          rewriter, loc, s0 + staticOffset, {flatSizeOfr});
+    }
+
+    // Build the flat 1-D MemRefType. The linearized size may be static or
+    // dynamic (OpFoldResult of either IntegerAttr or a Value).
+    int64_t flatDimSize = ShapedType::kDynamic;
+    if (auto attr = dyn_cast<Attribute>(flatSizeOfr))
+      if (auto intAttr = dyn_cast<IntegerAttr>(attr))
+        flatDimSize = intAttr.getInt();
+
+    auto flatMemrefType =
+        MemRefType::get({flatDimSize}, memrefType.getElementType(),
+                        StridedLayoutAttr::get(rewriter.getContext(), 0, {1}),
+                        memrefType.getMemorySpace());
+
+    // Collect the flat dynamic-size operand (empty for fully-static case).
+    SmallVector<Value, 1> dynSizes;
+    if (flatDimSize == ShapedType::kDynamic)
+      dynSizes.push_back(getValueFromOpFoldResult(rewriter, loc, flatSizeOfr));
+
+    auto newOp = AllocLikeOp::create(rewriter, loc, flatMemrefType, dynSizes,
+                                     op.getAlignmentAttr());
+    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+        op, cast<MemRefType>(op.getType()), newOp,
+        rewriter.getIndexAttr(staticOffset), sizes, strides);
+    return success();
+  }
+};
+
 template <typename T>
 struct MemRefRewritePattern : public OpRewritePattern<T> {
   using OpRewritePattern<T>::OpRewritePattern;
   LogicalResult matchAndRewrite(T op,
                                 PatternRewriter &rewriter) const override {
     LogicalResult canFlatten = canBeFlattened(op, rewriter);
-    if (failed(canFlatten)) {
+    if (failed(canFlatten))
       return canFlatten;
-    }
 
     Value memref = getTargetMemref(op);
     if (!needFlattening(memref) || !checkLayout(memref))
       return failure();
+
     auto &&[flatMemref, offset] = getFlattenMemrefAndOffset(
         rewriter, op->getLoc(), memref, getIndices<T>(op));
     replaceOp<T>(op, rewriter, flatMemref, offset);
@@ -285,8 +348,8 @@ void memref::populateFlattenVectorOpsOnMemrefPatterns(
 void memref::populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns) {
   patterns.insert<MemRefRewritePattern<memref::LoadOp>,
                   MemRefRewritePattern<memref::StoreOp>,
-                  MemRefRewritePattern<memref::AllocOp>,
-                  MemRefRewritePattern<memref::AllocaOp>>(
+                  AllocLikeFlattenPattern<memref::AllocOp>,
+                  AllocLikeFlattenPattern<memref::AllocaOp>>(
       patterns.getContext());
 }
 

diff  --git a/mlir/test/Dialect/MemRef/flatten_memref.mlir b/mlir/test/Dialect/MemRef/flatten_memref.mlir
index e45a10ca0d431..c9166b11c8d13 100644
--- a/mlir/test/Dialect/MemRef/flatten_memref.mlir
+++ b/mlir/test/Dialect/MemRef/flatten_memref.mlir
@@ -271,6 +271,79 @@ func.func @alloca() -> memref<4x8xf32> {
 
 // -----
 
+func.func @alloc_dynamic(%n: index) -> memref<?x4xf32> {
+  %0 = memref.alloc(%n) : memref<?x4xf32>
+  return %0 : memref<?x4xf32>
+}
+
+// CHECK-LABEL: func @alloc_dynamic
+// CHECK-SAME: (%[[N:.*]]: index)
+// CHECK: %[[ALLOC:.*]] = memref.alloc(%{{.*}}) : memref<?xf32, strided<[1]>>
+// CHECK: memref.reinterpret_cast %[[ALLOC]] to offset: [0], sizes: [%[[N]], 4], strides: [4, 1]
+
+// -----
+
+func.func @alloca_dynamic(%n: index) -> memref<?x4xf32> {
+  %0 = memref.alloca(%n) : memref<?x4xf32>
+  return %0 : memref<?x4xf32>
+}
+
+// CHECK-LABEL: func @alloca_dynamic
+// CHECK-SAME: (%[[N:.*]]: index)
+// CHECK: %[[ALLOCA:.*]] = memref.alloca(%{{.*}}) : memref<?xf32, strided<[1]>>
+// CHECK: memref.reinterpret_cast %[[ALLOCA]] to offset: [0], sizes: [%[[N]], 4], strides: [4, 1]
+
+// -----
+
+// Explicit row-major strides: same as the default layout, should flatten.
+func.func @flatten_alloc_strided_row_major() -> memref<4x8xf32, strided<[8, 1]>> {
+  %0 = memref.alloc() : memref<4x8xf32, strided<[8, 1]>>
+  return %0 : memref<4x8xf32, strided<[8, 1]>>
+}
+
+// CHECK-LABEL: func @flatten_alloc_strided_row_major
+// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<32xf32, strided<[1]>>
+// CHECK: memref.reinterpret_cast %[[ALLOC]] to offset: [0], sizes: [4, 8], strides: [8, 1] : memref<32xf32, strided<[1]>> to memref<4x8xf32, strided<[8, 1]>>
+
+// -----
+
+// Non-zero static offset: the flat allocation covers [0, offset+extent) = [0, 82)
+// and the reinterpret_cast restores the original offset in the result type.
+func.func @flatten_alloc_strided_offset() -> memref<4x8xf32, strided<[8, 1], offset: 50>> {
+  %0 = memref.alloc() : memref<4x8xf32, strided<[8, 1], offset: 50>>
+  return %0 : memref<4x8xf32, strided<[8, 1], offset: 50>>
+}
+
+// CHECK-LABEL: func @flatten_alloc_strided_offset
+// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<82xf32, strided<[1]>>
+// CHECK: memref.reinterpret_cast %[[ALLOC]] to offset: [50], sizes: [4, 8], strides: [8, 1] : memref<82xf32, strided<[1]>> to memref<4x8xf32, strided<[8, 1], offset: 50>>
+
+// -----
+
+// Padded strides: flatten to the maximum extent (max(18*4, 2*8) = 72).
+func.func @flatten_alloc_strided_padded() -> memref<4x8xf32, strided<[18, 2]>> {
+  %0 = memref.alloc() : memref<4x8xf32, strided<[18, 2]>>
+  return %0 : memref<4x8xf32, strided<[18, 2]>>
+}
+
+// CHECK-LABEL: func @flatten_alloc_strided_padded
+// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<72xf32, strided<[1]>>
+// CHECK: memref.reinterpret_cast %[[ALLOC]] to offset: [0], sizes: [4, 8], strides: [18, 2] : memref<72xf32, strided<[1]>> to memref<4x8xf32, strided<[18, 2]>>
+
+// -----
+
+// Multi-dynamic alloc: strides are dynamic so the pattern bails out.
+func.func @alloc_multi_dynamic(%m: index, %n: index) -> memref<?x?xf32> {
+  %0 = memref.alloc(%m, %n) : memref<?x?xf32>
+  return %0 : memref<?x?xf32>
+}
+
+// CHECK-LABEL: func @alloc_multi_dynamic
+// CHECK: memref.alloc(%{{.*}}, %{{.*}}) : memref<?x?xf32>
+// CHECK-NOT: memref.reinterpret_cast
+
+// -----
+
 func.func @chained_alloc_load() -> vector<8xf32> {
   %c3 = arith.constant 3 : index
   %c6 = arith.constant 6 : index


        


More information about the Mlir-commits mailing list