[Mlir-commits] [mlir] f0bf972 - [MemRef] Propagate strided layout through view-like ops in multiBuffer (#176941)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jan 27 07:00:27 PST 2026
Author: Zhuoran Yin
Date: 2026-01-27T10:00:22-05:00
New Revision: f0bf97281f981f130eee0b7672a5eb223bb8bc01
URL: https://github.com/llvm/llvm-project/commit/f0bf97281f981f130eee0b7672a5eb223bb8bc01
DIFF: https://github.com/llvm/llvm-project/commit/f0bf97281f981f130eee0b7672a5eb223bb8bc01.diff
LOG: [MemRef] Propagate strided layout through view-like ops in multiBuffer (#176941)
The memref::multiBuffer transformation replaces an allocation with a
multi-buffered allocation and creates a strided memref.subview at each
loop iteration. When the original allocation is used through view-like
ops, the existing code only handles SubViewOp; other view-like ops keep
result types that no longer match their new, strided source operand.
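This patch extends replaceUsesAndPropagateType to handle ExpandShapeOp,
CollapseShapeOp, and CastOp using TypeSwitch. For each view-like op, we
compute the result type for the new source (for CastOp, the original
result type is kept after checking cast compatibility), create a new
operation, and then recursively propagate the updated type through chains
of view-like ops. If a result type cannot be computed, the op emits an
error and replaceUsesAndPropagateType, which now returns LogicalResult,
propagates the failure to multiBuffer.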
New FileCheck tests cover expand_shape, collapse_shape, cast, and a
chained expand_shape->cast case.
A single ViewLikeOpInterface hook is not practical here: view-like ops
have distinct type inference and validity rules (e.g., subview uses
offset/size/stride inference, expand/collapse use reassociation, cast
requires compatibility checks). Ops like memref.view or
memref.reinterpret_cast need additional layout/size validation beyond
what multi-buffering currently tracks, so this patch handles the common
safe cases directly.
Added:
Modified:
mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
mlir/test/Dialect/MemRef/multibuffer.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
index 860384f954536..ce45f847ccaed 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
@@ -21,6 +21,7 @@
#include "mlir/IR/ValueRange.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
using namespace mlir;
@@ -37,40 +38,101 @@ static bool overrideBuffer(Operation *op, Value buffer) {
return copyOp.getTarget() == buffer;
}
-/// Replace the uses of `oldOp` with the given `val` and for subview uses
+/// Replace the uses of `oldOp` with the given `val` and for view-like uses
/// propagate the type change. Changing the memref type may require propagating
-/// it through subview ops so we cannot just do a replaceAllUse but need to
-/// propagate the type change and erase old subview ops.
-static void replaceUsesAndPropagateType(RewriterBase &rewriter,
- Operation *oldOp, Value val) {
+/// it through view-like ops (subview, expand_shape, collapse_shape, cast) so
+/// we need to propagate the type change and erase old view ops.
+///
+/// Only view-like ops whose result type can be recomputed from the new source
+/// type and existing op attributes are handled here. Other ops fall back to
+/// operand replacement without type propagation.
+static LogicalResult replaceUsesAndPropagateType(RewriterBase &rewriter,
+ Operation *oldOp, Value val) {
+ SmallVector<Operation *> opsToErase;
// Iterate with early_inc to erase current user inside the loop.
for (OpOperand &use : llvm::make_early_inc_range(oldOp->getUses())) {
Operation *user = use.getOwner();
- if (auto subviewUse = dyn_cast<memref::SubViewOp>(user)) {
- // `subview(old_op)` is replaced by a new `subview(val)`.
- OpBuilder::InsertionGuard g(rewriter);
- rewriter.setInsertionPoint(subviewUse);
- MemRefType newType = memref::SubViewOp::inferRankReducedResultType(
- subviewUse.getType().getShape(), cast<MemRefType>(val.getType()),
- subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(),
- subviewUse.getStaticStrides());
- Value newSubview = memref::SubViewOp::create(
- rewriter, subviewUse->getLoc(), newType, val,
- subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(),
- subviewUse.getMixedStrides());
-
- // Ouch recursion ... is this really necessary?
- replaceUsesAndPropagateType(rewriter, subviewUse, newSubview);
-
- // Safe to erase.
- rewriter.eraseOp(subviewUse);
- continue;
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(user);
+ MemRefType srcType = cast<MemRefType>(val.getType());
+
+ // Try to create a new view-like op with updated result type.
+ // Each view-like op has its own method to compute the result type.
+ bool typeInferenceFailed = false;
+ Value replacement =
+ llvm::TypeSwitch<Operation *, Value>(user)
+ .Case([&](memref::SubViewOp subview) -> Value {
+ MemRefType newType =
+ memref::SubViewOp::inferRankReducedResultType(
+ subview.getType().getShape(), srcType,
+ subview.getStaticOffsets(), subview.getStaticSizes(),
+ subview.getStaticStrides());
+ return memref::SubViewOp::create(
+ rewriter, subview->getLoc(), newType, val,
+ subview.getMixedOffsets(), subview.getMixedSizes(),
+ subview.getMixedStrides());
+ })
+ .Case([&](memref::ExpandShapeOp expand) -> Value {
+ FailureOr<MemRefType> newType =
+ memref::ExpandShapeOp::computeExpandedType(
+ srcType, expand.getResultType().getShape(),
+ expand.getReassociationIndices());
+ if (failed(newType)) {
+ typeInferenceFailed = true;
+ return Value();
+ }
+ return memref::ExpandShapeOp::create(
+ rewriter, expand->getLoc(), *newType, val,
+ expand.getReassociationIndices(),
+ expand.getMixedOutputShape());
+ })
+ .Case([&](memref::CollapseShapeOp collapse) -> Value {
+ FailureOr<MemRefType> newType =
+ memref::CollapseShapeOp::computeCollapsedType(
+ srcType, collapse.getReassociationIndices());
+ if (failed(newType)) {
+ typeInferenceFailed = true;
+ return Value();
+ }
+ return memref::CollapseShapeOp::create(
+ rewriter, collapse->getLoc(), *newType, val,
+ collapse.getReassociationIndices());
+ })
+ .Case([&](memref::CastOp cast) -> Value {
+ if (!memref::CastOp::areCastCompatible(srcType, cast.getType())) {
+ typeInferenceFailed = true;
+ return Value();
+ }
+ return memref::CastOp::create(rewriter, cast->getLoc(),
+ cast.getType(), val);
+ })
+ .Default([&](Operation *) -> Value { return Value(); });
+
+ if (typeInferenceFailed) {
+ user->emitOpError(
+ "failed to compute view-like result type after multi-buffering");
+ return failure();
+ }
+
+ if (replacement) {
+ // Recursively propagate through view-like ops and mark old op for
+ // erasure.
+ if (failed(replaceUsesAndPropagateType(rewriter, user, replacement)))
+ return failure();
+ opsToErase.push_back(user);
+ } else {
+ // Not a view-like op: just replace operand.
+ rewriter.startOpModification(user);
+ use.set(val);
+ rewriter.finalizeOpModification(user);
}
- // Non-subview: replace with new value.
- rewriter.startOpModification(user);
- use.set(val);
- rewriter.finalizeOpModification(user);
}
+
+ for (Operation *op : opsToErase) {
+ rewriter.eraseOp(op);
+ }
+
+ return success();
}
// Transformation to do multi-buffering/array expansion to remove dependencies
@@ -216,7 +278,8 @@ mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp,
}
// 6. RAUW with the particular slice, taking modular rotation into account.
- replaceUsesAndPropagateType(rewriter, allocOp, subview);
+ if (failed(replaceUsesAndPropagateType(rewriter, allocOp, subview)))
+ return failure();
// 7. Finally, erase the old allocOp.
rewriter.eraseOp(allocOp);
diff --git a/mlir/test/Dialect/MemRef/multibuffer.mlir b/mlir/test/Dialect/MemRef/multibuffer.mlir
index 4ab7d993e6fd1..b004ebfa1abd0 100644
--- a/mlir/test/Dialect/MemRef/multibuffer.mlir
+++ b/mlir/test/Dialect/MemRef/multibuffer.mlir
@@ -104,3 +104,123 @@ func.func @multi_buffer_negative(%a: memref<1024x1024xf32>) {
return
}
+// -----
+
+// Test that multi-buffering correctly propagates strided layout through expand_shape.
+
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (((d0 - 1) floordiv 3) mod 5)>
+
+// CHECK-LABEL: func @multi_buffer_expand_shape
+func.func @multi_buffer_expand_shape(%a: memref<1024x1024xf32>) {
+// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc() : memref<5x4x128xf32>
+ %0 = memref.alloc() : memref<4x128xf32>
+ %c1024 = arith.constant 1024 : index
+ %c1 = arith.constant 1 : index
+ %c3 = arith.constant 3 : index
+// CHECK: scf.for %[[IV:.*]] = %{{.*}}
+ scf.for %arg2 = %c1 to %c1024 step %c3 {
+// CHECK: %[[I:.*]] = affine.apply #[[$MAP1]](%[[IV]])
+// CHECK: %[[SV:.*]] = memref.subview %[[ALLOC]][%[[I]], 0, 0] [1, 4, 128] [1, 1, 1] : memref<5x4x128xf32> to memref<4x128xf32, strided<[128, 1], offset: ?>>
+ %1 = memref.subview %a[%arg2, 0] [4, 128] [1, 1] :
+ memref<1024x1024xf32> to memref<4x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>
+// CHECK: memref.copy %{{.*}}, %[[SV]] : memref<4x128xf32, #{{.*}}> to memref<4x128xf32, strided<[128, 1], offset: ?>>
+ memref.copy %1, %0 : memref<4x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>> to memref<4x128xf32>
+// CHECK: %[[EXPANDED:.*]] = memref.expand_shape %[[SV]] {{\[\[}}0, 1], [2, 3]] output_shape [2, 2, 64, 2] : memref<4x128xf32, strided<[128, 1], offset: ?>> into memref<2x2x64x2xf32, strided<[256, 128, 2, 1], offset: ?>>
+ %expanded = memref.expand_shape %0 [[0, 1], [2, 3]] output_shape [2, 2, 64, 2]
+ : memref<4x128xf32> into memref<2x2x64x2xf32>
+// CHECK: "some_use"(%[[EXPANDED]]) : (memref<2x2x64x2xf32, strided<[256, 128, 2, 1], offset: ?>>) -> ()
+ "some_use"(%expanded) : (memref<2x2x64x2xf32>) -> ()
+ }
+ return
+}
+
+// -----
+
+// Test that multi-buffering correctly propagates strided layout through collapse_shape.
+
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (((d0 - 1) floordiv 3) mod 5)>
+
+// CHECK-LABEL: func @multi_buffer_collapse_shape
+func.func @multi_buffer_collapse_shape(%a: memref<1024x1024xf32>) {
+// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc() : memref<5x4x128xf32>
+ %0 = memref.alloc() : memref<4x128xf32>
+ %c1024 = arith.constant 1024 : index
+ %c1 = arith.constant 1 : index
+ %c3 = arith.constant 3 : index
+// CHECK: scf.for %[[IV:.*]] = %{{.*}}
+ scf.for %arg2 = %c1 to %c1024 step %c3 {
+// CHECK: %[[I:.*]] = affine.apply #[[$MAP1]](%[[IV]])
+// CHECK: %[[SV:.*]] = memref.subview %[[ALLOC]][%[[I]], 0, 0] [1, 4, 128] [1, 1, 1] : memref<5x4x128xf32> to memref<4x128xf32, strided<[128, 1], offset: ?>>
+ %1 = memref.subview %a[%arg2, 0] [4, 128] [1, 1] :
+ memref<1024x1024xf32> to memref<4x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>
+// CHECK: memref.copy %{{.*}}, %[[SV]] : memref<4x128xf32, #{{.*}}> to memref<4x128xf32, strided<[128, 1], offset: ?>>
+ memref.copy %1, %0 : memref<4x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>> to memref<4x128xf32>
+// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[SV]] {{\[\[}}0, 1]] : memref<4x128xf32, strided<[128, 1], offset: ?>> into memref<512xf32, strided<[1], offset: ?>>
+ %collapsed = memref.collapse_shape %0 [[0, 1]]
+ : memref<4x128xf32> into memref<512xf32>
+// CHECK: "some_use"(%[[COLLAPSED]]) : (memref<512xf32, strided<[1], offset: ?>>) -> ()
+ "some_use"(%collapsed) : (memref<512xf32>) -> ()
+ }
+ return
+}
+
+// -----
+
+// Test that multi-buffering correctly propagates through memref.cast.
+
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (((d0 - 1) floordiv 3) mod 5)>
+
+// CHECK-LABEL: func @multi_buffer_cast
+func.func @multi_buffer_cast(%a: memref<1024x1024xf32>) {
+// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc() : memref<5x4x128xf32>
+ %0 = memref.alloc() : memref<4x128xf32>
+ %c1024 = arith.constant 1024 : index
+ %c1 = arith.constant 1 : index
+ %c3 = arith.constant 3 : index
+// CHECK: scf.for %[[IV:.*]] = %{{.*}}
+ scf.for %arg2 = %c1 to %c1024 step %c3 {
+// CHECK: %[[I:.*]] = affine.apply #[[$MAP1]](%[[IV]])
+// CHECK: %[[SV:.*]] = memref.subview %[[ALLOC]][%[[I]], 0, 0] [1, 4, 128] [1, 1, 1] : memref<5x4x128xf32> to memref<4x128xf32, strided<[128, 1], offset: ?>>
+ %1 = memref.subview %a[%arg2, 0] [4, 128] [1, 1] :
+ memref<1024x1024xf32> to memref<4x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>
+// CHECK: memref.copy %{{.*}}, %[[SV]] : memref<4x128xf32, #{{.*}}> to memref<4x128xf32, strided<[128, 1], offset: ?>>
+ memref.copy %1, %0 : memref<4x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>> to memref<4x128xf32>
+// CHECK: %[[CAST:.*]] = memref.cast %[[SV]] : memref<4x128xf32, strided<[128, 1], offset: ?>> to memref<?x128xf32>
+ %casted = memref.cast %0 : memref<4x128xf32> to memref<?x128xf32>
+// CHECK: "some_use"(%[[CAST]]) : (memref<?x128xf32>) -> ()
+ "some_use"(%casted) : (memref<?x128xf32>) -> ()
+ }
+ return
+}
+
+// -----
+
+// Test that multi-buffering correctly propagates through chained view-like ops.
+
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (((d0 - 1) floordiv 3) mod 5)>
+
+// CHECK-LABEL: func @multi_buffer_chained_view_ops
+func.func @multi_buffer_chained_view_ops(%a: memref<1024x1024xf32>) {
+// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc() : memref<5x4x128xf32>
+ %0 = memref.alloc() : memref<4x128xf32>
+ %c1024 = arith.constant 1024 : index
+ %c1 = arith.constant 1 : index
+ %c3 = arith.constant 3 : index
+// CHECK: scf.for %[[IV:.*]] = %{{.*}}
+ scf.for %arg2 = %c1 to %c1024 step %c3 {
+// CHECK: %[[I:.*]] = affine.apply #[[$MAP1]](%[[IV]])
+// CHECK: %[[SV:.*]] = memref.subview %[[ALLOC]][%[[I]], 0, 0] [1, 4, 128] [1, 1, 1] : memref<5x4x128xf32> to memref<4x128xf32, strided<[128, 1], offset: ?>>
+ %1 = memref.subview %a[%arg2, 0] [4, 128] [1, 1] :
+ memref<1024x1024xf32> to memref<4x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>
+// CHECK: memref.copy %{{.*}}, %[[SV]] : memref<4x128xf32, #{{.*}}> to memref<4x128xf32, strided<[128, 1], offset: ?>>
+ memref.copy %1, %0 : memref<4x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>> to memref<4x128xf32>
+// CHECK: %[[EXPANDED:.*]] = memref.expand_shape %[[SV]] {{\[\[}}0, 1], [2, 3]] output_shape [2, 2, 64, 2] : memref<4x128xf32, strided<[128, 1], offset: ?>> into memref<2x2x64x2xf32, strided<[256, 128, 2, 1], offset: ?>>
+ %expanded = memref.expand_shape %0 [[0, 1], [2, 3]] output_shape [2, 2, 64, 2]
+ : memref<4x128xf32> into memref<2x2x64x2xf32>
+// CHECK: %[[CAST:.*]] = memref.cast %[[EXPANDED]] : memref<2x2x64x2xf32, strided<[256, 128, 2, 1], offset: ?>> to memref<?x2x64x2xf32>
+ %casted = memref.cast %expanded : memref<2x2x64x2xf32> to memref<?x2x64x2xf32>
+// CHECK: "some_use"(%[[CAST]]) : (memref<?x2x64x2xf32>) -> ()
+ "some_use"(%casted) : (memref<?x2x64x2xf32>) -> ()
+ }
+ return
+}
More information about the Mlir-commits
mailing list