[Mlir-commits] [mlir] [MemRef] Propagate strided layout through view-like ops in multiBuffer (PR #176941)

Zhuoran Yin llvmlistbot at llvm.org
Tue Jan 20 07:25:05 PST 2026


https://github.com/jerryyin created https://github.com/llvm/llvm-project/pull/176941

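The memref::multiBuffer transformation replaces an allocation with a larger, multi-buffered allocation and takes a strided memref.subview of it at each loop iteration. When the original allocation is used through view-like ops, the existing code only rewrites SubViewOp users. Other view-like users (for example memref.expand_shape) keep result types that were computed for the original identity layout and no longer match the new strided, dynamically offset source.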
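
To make the per-iteration indexing concrete: in the tests added by this patch, a loop from 1 to 1024 with step 3 and 5 buffers selects slot ((iv - 1) floordiv 3) mod 5 of a memref<5x4x128xf32> allocation. Each slot starts at a different element offset, which is why the per-iteration view is typed memref<4x128xf32, strided<[128, 1], offset: ?>>. The standalone sketch below models only that arithmetic; slotIndex and slotOffset are hypothetical helpers (not MLIR APIs), and it assumes a positive step and iv >= lb.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper (not an MLIR API): the buffer slot used at induction
// value `iv`, mirroring affine_map<(d0) -> (((d0 - lb) floordiv step) mod n)>.
// Assumes step > 0 and iv >= lb, so C++ integer division matches floordiv.
inline int64_t slotIndex(int64_t iv, int64_t lb, int64_t step,
                         int64_t numBuffers) {
  assert(step > 0 && iv >= lb && numBuffers > 0);
  return ((iv - lb) / step) % numBuffers;
}

// Hypothetical helper: linear element offset of the per-iteration view
// alloc[slot, 0, 0] [1, rows, cols] [1, 1, 1] inside a row-major
// numBuffers x rows x cols allocation. Because it varies with `iv`, the
// subview's type can only record it as a dynamic offset.
inline int64_t slotOffset(int64_t slot, int64_t rows, int64_t cols) {
  return slot * rows * cols;
}
```

For example, with lb = 1, step = 3 and 5 buffers, iv = 4 selects slot 1 at element offset 512, and iv = 16 wraps back to slot 0.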

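This patch extends replaceUsesAndPropagateType to handle ExpandShapeOp, CollapseShapeOp, and CastOp using TypeSwitch. For each supported view-like op, it creates a replacement op on the new source: subview, expand_shape and collapse_shape recompute their result type from the new source type, while cast keeps its original result type after checking that it is still cast-compatible. The updated type is then propagated recursively through chains of view-like ops; all other users simply have their operand replaced. If a result type cannot be computed, the code asserts in debug builds and otherwise emits an error. replaceUsesAndPropagateType now returns LogicalResult so that multiBuffer returns failure in that case. New FileCheck tests cover expand_shape, collapse_shape, cast, and a chained expand_shape -> cast case.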
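
The layout arithmetic behind the expand_shape and collapse_shape cases can be checked by hand against the new tests. Expanding memref<4x128xf32, strided<[128, 1], offset: ?>> with reassociation [[0, 1], [2, 3]] into 2x2x64x2 yields strides [256, 128, 2, 1], and collapsing it with [[0, 1]] yields strides [1]; the dynamic offset carries over unchanged in both cases. The static-only C++ sketch below reproduces those numbers. expandStrides and collapseStrides are hypothetical helpers, not MLIR's computeExpandedType or computeCollapsedType, which also handle dynamic sizes, strides and offsets.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// One reassociation group: for expand_shape, the result dims produced from a
// single source dim; for collapse_shape, the source dims merged into one.
using Group = std::vector<size_t>;

// Hypothetical sketch (not MLIR's computeExpandedType): strides after
// expand_shape, assuming all sizes and strides are static. Within a group the
// innermost result dim inherits the source stride, and each outer dim's
// stride is the next-inner stride times that inner dim's size.
inline std::optional<std::vector<int64_t>>
expandStrides(const std::vector<int64_t> &srcShape,
              const std::vector<int64_t> &srcStrides,
              const std::vector<int64_t> &resultShape,
              const std::vector<Group> &reassociation) {
  if (reassociation.size() != srcShape.size() ||
      srcStrides.size() != srcShape.size())
    return std::nullopt;
  std::vector<int64_t> strides(resultShape.size(), 0);
  for (size_t g = 0; g < reassociation.size(); ++g) {
    const Group &group = reassociation[g];
    int64_t running = srcStrides[g];
    int64_t product = 1;
    // Walk the group from innermost to outermost dim.
    for (size_t k = group.size(); k-- > 0;) {
      size_t dim = group[k];
      if (dim >= resultShape.size())
        return std::nullopt;
      strides[dim] = running;
      running *= resultShape[dim];
      product *= resultShape[dim];
    }
    // The expanded sizes must multiply back to the source size.
    if (group.empty() || product != srcShape[g])
      return std::nullopt;
  }
  return strides;
}

// Hypothetical sketch (not MLIR's computeCollapsedType): strides after
// collapse_shape, assuming static sizes and strides. A group can be merged
// only if it is contiguous, i.e. stride[d] == stride[d + 1] * size[d + 1] for
// consecutive dims; the merged dim keeps the innermost stride.
inline std::optional<std::vector<int64_t>>
collapseStrides(const std::vector<int64_t> &srcShape,
                const std::vector<int64_t> &srcStrides,
                const std::vector<Group> &reassociation) {
  if (srcStrides.size() != srcShape.size())
    return std::nullopt;
  std::vector<int64_t> strides;
  for (const Group &group : reassociation) {
    if (group.empty())
      return std::nullopt;
    for (size_t d : group)
      if (d >= srcShape.size())
        return std::nullopt;
    for (size_t k = 0; k + 1 < group.size(); ++k) {
      size_t outer = group[k], inner = group[k + 1];
      if (srcStrides[outer] != srcStrides[inner] * srcShape[inner])
        return std::nullopt; // Non-contiguous group: cannot collapse.
    }
    strides.push_back(srcStrides[group.back()]);
  }
  return strides;
}
```

If a group is not contiguous (say the source strides were [256, 1]), collapseStrides returns std::nullopt, which corresponds to the failure path the patch guards with an assertion.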

A single ViewLikeOpInterface hook is not practical here: view-like ops have distinct type inference and validity rules (e.g., subview infers its type from offsets/sizes/strides, expand_shape and collapse_shape depend on the reassociation, and cast requires a compatibility check). Ops like memref.view or memref.reinterpret_cast would need additional layout/size validation beyond what multi-buffering currently tracks, so this patch handles only the common, safe cases directly.
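
Cast is the one supported case where no new result type is computed: the original result type is kept, and the patch only checks that it is still cast-compatible with the new strided source. In the cast test, memref<4x128xf32, strided<[128, 1], offset: ?>> is cast to memref<?x128xf32>, whose identity layout corresponds to strides [128, 1] and offset 0. The sketch below models a simplified version of that rule in plain C++. StridedMemRef, kDyn and castCompatible are hypothetical names; the real memref::CastOp::areCastCompatible covers more, for example element types, memory spaces and unranked memrefs.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// Assumption: a sentinel standing for a dynamic ('?') size, stride or offset,
// similar in spirit to MLIR's ShapedType::kDynamic.
constexpr int64_t kDyn = std::numeric_limits<int64_t>::min();

// Simplified model of a ranked memref type with a strided layout; element
// type and memory space are deliberately left out.
struct StridedMemRef {
  std::vector<int64_t> shape;
  std::vector<int64_t> strides;
  int64_t offset;
};

// Two values are compatible if either one is dynamic or both are equal.
inline bool compatibleValue(int64_t a, int64_t b) {
  return a == kDyn || b == kDyn || a == b;
}

// Hypothetical, simplified cast-compatibility rule: equal rank, and every
// size, every stride, and the offset pairwise compatible.
inline bool castCompatible(const StridedMemRef &a, const StridedMemRef &b) {
  if (a.shape.size() != b.shape.size() || a.strides.size() != a.shape.size() ||
      b.strides.size() != b.shape.size())
    return false;
  for (size_t i = 0; i < a.shape.size(); ++i)
    if (!compatibleValue(a.shape[i], b.shape[i]) ||
        !compatibleValue(a.strides[i], b.strides[i]))
      return false;
  return compatibleValue(a.offset, b.offset);
}
```

Compatibility is only a static property: per the memref.cast documentation, a cast that turns a dynamic size, stride or offset into a static one acts as an assertion about the runtime value.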

From a403d42aa9351934edd844218030f22cecbb31ee Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Mon, 19 Jan 2026 22:17:21 +0000
Subject: [PATCH] [MemRef] Propagate strided layout through view-like ops in
 multiBuffer

---
 .../Dialect/MemRef/Transforms/MultiBuffer.cpp | 134 ++++++++++++++----
 mlir/test/Dialect/MemRef/multibuffer.mlir     | 120 ++++++++++++++++
 2 files changed, 225 insertions(+), 29 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
index 860384f954536..f0db878fa0e8f 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
@@ -21,6 +21,7 @@
 #include "mlir/IR/ValueRange.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
 
 using namespace mlir;
@@ -37,40 +38,114 @@ static bool overrideBuffer(Operation *op, Value buffer) {
   return copyOp.getTarget() == buffer;
 }
 
-/// Replace the uses of `oldOp` with the given `val` and for subview uses
+/// Replace the uses of `oldOp` with the given `val` and for view-like uses
 /// propagate the type change. Changing the memref type may require propagating
-/// it through subview ops so we cannot just do a replaceAllUse but need to
-/// propagate the type change and erase old subview ops.
-static void replaceUsesAndPropagateType(RewriterBase &rewriter,
-                                        Operation *oldOp, Value val) {
+/// it through view-like ops (subview, expand_shape, collapse_shape, cast) so
+/// we need to propagate the type change and erase old view ops.
+///
+/// Only view-like ops whose result type can be recomputed from the new source
+/// type and existing op attributes are handled here. Other ops fall back to
+/// operand replacement without type propagation.
+static LogicalResult replaceUsesAndPropagateType(RewriterBase &rewriter,
+                                                 Operation *oldOp, Value val) {
+  SmallVector<Operation *> opsToErase;
   // Iterate with early_inc to erase current user inside the loop.
   for (OpOperand &use : llvm::make_early_inc_range(oldOp->getUses())) {
     Operation *user = use.getOwner();
-    if (auto subviewUse = dyn_cast<memref::SubViewOp>(user)) {
-      // `subview(old_op)` is replaced by a new `subview(val)`.
-      OpBuilder::InsertionGuard g(rewriter);
-      rewriter.setInsertionPoint(subviewUse);
-      MemRefType newType = memref::SubViewOp::inferRankReducedResultType(
-          subviewUse.getType().getShape(), cast<MemRefType>(val.getType()),
-          subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(),
-          subviewUse.getStaticStrides());
-      Value newSubview = memref::SubViewOp::create(
-          rewriter, subviewUse->getLoc(), newType, val,
-          subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(),
-          subviewUse.getMixedStrides());
-
-      // Ouch recursion ... is this really necessary?
-      replaceUsesAndPropagateType(rewriter, subviewUse, newSubview);
-
-      // Safe to erase.
-      rewriter.eraseOp(subviewUse);
-      continue;
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(user);
+    MemRefType srcType = cast<MemRefType>(val.getType());
+
+    // Try to create a new view-like op with updated result type.
+    // Each view-like op has its own method to compute the result type.
+    bool typeInferenceFailed = false;
+    Value replacement =
+        llvm::TypeSwitch<Operation *, Value>(user)
+            .Case<memref::SubViewOp>([&](memref::SubViewOp subview) -> Value {
+              MemRefType newType =
+                  memref::SubViewOp::inferRankReducedResultType(
+                      subview.getType().getShape(), srcType,
+                      subview.getStaticOffsets(), subview.getStaticSizes(),
+                      subview.getStaticStrides());
+              return memref::SubViewOp::create(
+                  rewriter, subview->getLoc(), newType, val,
+                  subview.getMixedOffsets(), subview.getMixedSizes(),
+                  subview.getMixedStrides());
+            })
+            .Case<memref::ExpandShapeOp>(
+                [&](memref::ExpandShapeOp expand) -> Value {
+                  FailureOr<MemRefType> newType =
+                      memref::ExpandShapeOp::computeExpandedType(
+                          srcType, expand.getResultType().getShape(),
+                          expand.getReassociationIndices());
+                  assert(succeeded(newType) &&
+                         "expected to compute expanded type after "
+                         "multi-buffering");
+                  if (failed(newType)) {
+                    typeInferenceFailed = true;
+                    return Value();
+                  }
+                  return memref::ExpandShapeOp::create(
+                      rewriter, expand->getLoc(), *newType, val,
+                      expand.getReassociationIndices(),
+                      expand.getMixedOutputShape());
+                })
+            .Case<memref::CollapseShapeOp>(
+                [&](memref::CollapseShapeOp collapse) -> Value {
+                  FailureOr<MemRefType> newType =
+                      memref::CollapseShapeOp::computeCollapsedType(
+                          srcType, collapse.getReassociationIndices());
+                  assert(succeeded(newType) &&
+                         "expected to compute collapsed type after "
+                         "multi-buffering");
+                  if (failed(newType)) {
+                    typeInferenceFailed = true;
+                    return Value();
+                  }
+                  return memref::CollapseShapeOp::create(
+                      rewriter, collapse->getLoc(), *newType, val,
+                      collapse.getReassociationIndices());
+                })
+            .Case<memref::CastOp>([&](memref::CastOp cast) -> Value {
+              bool isCastCompatible =
+                  memref::CastOp::areCastCompatible(srcType, cast.getType());
+              assert(isCastCompatible &&
+                     "expected cast to remain compatible after "
+                     "multi-buffering");
+              if (!isCastCompatible) {
+                typeInferenceFailed = true;
+                return Value();
+              }
+              return memref::CastOp::create(rewriter, cast->getLoc(),
+                                            cast.getType(), val);
+            })
+            .Default([&](Operation *) -> Value { return Value(); });
+
+    if (typeInferenceFailed) {
+      user->emitError(
+          "failed to compute view-like result type after multi-buffering");
+      return failure();
+    }
+
+    if (replacement) {
+      // Recursively propagate through view-like ops and mark old op for
+      // erasure.
+      if (failed(replaceUsesAndPropagateType(rewriter, user, replacement)))
+        return failure();
+      opsToErase.push_back(user);
+    } else {
+      // Not a supported view-like op: just replace the operand.
+      rewriter.startOpModification(user);
+      use.set(val);
+      rewriter.finalizeOpModification(user);
     }
-    // Non-subview: replace with new value.
-    rewriter.startOpModification(user);
-    use.set(val);
-    rewriter.finalizeOpModification(user);
   }
+
+  for (Operation *op : opsToErase) {
+    rewriter.eraseOp(op);
+  }
+
+  return success();
 }
 
 // Transformation to do multi-buffering/array expansion to remove dependencies
@@ -216,7 +291,8 @@ mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp,
   }
 
   // 6. RAUW with the particular slice, taking modular rotation into account.
-  replaceUsesAndPropagateType(rewriter, allocOp, subview);
+  if (failed(replaceUsesAndPropagateType(rewriter, allocOp, subview)))
+    return failure();
 
   // 7. Finally, erase the old allocOp.
   rewriter.eraseOp(allocOp);
diff --git a/mlir/test/Dialect/MemRef/multibuffer.mlir b/mlir/test/Dialect/MemRef/multibuffer.mlir
index 4ab7d993e6fd1..b004ebfa1abd0 100644
--- a/mlir/test/Dialect/MemRef/multibuffer.mlir
+++ b/mlir/test/Dialect/MemRef/multibuffer.mlir
@@ -104,3 +104,123 @@ func.func @multi_buffer_negative(%a: memref<1024x1024xf32>) {
   return
 }
 
+// -----
+
+// Test that multi-buffering correctly propagates strided layout through expand_shape.
+
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (((d0 - 1) floordiv 3) mod 5)>
+
+// CHECK-LABEL: func @multi_buffer_expand_shape
+func.func @multi_buffer_expand_shape(%a: memref<1024x1024xf32>) {
+// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc() : memref<5x4x128xf32>
+  %0 = memref.alloc() : memref<4x128xf32>
+  %c1024 = arith.constant 1024 : index
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+// CHECK: scf.for %[[IV:.*]] = %{{.*}}
+  scf.for %arg2 = %c1 to %c1024 step %c3 {
+// CHECK: %[[I:.*]] = affine.apply #[[$MAP1]](%[[IV]])
+// CHECK: %[[SV:.*]] = memref.subview %[[ALLOC]][%[[I]], 0, 0] [1, 4, 128] [1, 1, 1] : memref<5x4x128xf32> to memref<4x128xf32, strided<[128, 1], offset: ?>>
+    %1 = memref.subview %a[%arg2, 0] [4, 128] [1, 1] :
+        memref<1024x1024xf32> to memref<4x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>
+// CHECK: memref.copy %{{.*}}, %[[SV]] : memref<4x128xf32, #{{.*}}> to memref<4x128xf32, strided<[128, 1], offset: ?>>
+    memref.copy %1, %0 : memref<4x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>> to memref<4x128xf32>
+// CHECK: %[[EXPANDED:.*]] = memref.expand_shape %[[SV]] {{\[\[}}0, 1], [2, 3]] output_shape [2, 2, 64, 2] : memref<4x128xf32, strided<[128, 1], offset: ?>> into memref<2x2x64x2xf32, strided<[256, 128, 2, 1], offset: ?>>
+    %expanded = memref.expand_shape %0 [[0, 1], [2, 3]] output_shape [2, 2, 64, 2]
+        : memref<4x128xf32> into memref<2x2x64x2xf32>
+// CHECK: "some_use"(%[[EXPANDED]]) : (memref<2x2x64x2xf32, strided<[256, 128, 2, 1], offset: ?>>) -> ()
+    "some_use"(%expanded) : (memref<2x2x64x2xf32>) -> ()
+  }
+  return
+}
+
+// -----
+
+// Test that multi-buffering correctly propagates strided layout through collapse_shape.
+
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (((d0 - 1) floordiv 3) mod 5)>
+
+// CHECK-LABEL: func @multi_buffer_collapse_shape
+func.func @multi_buffer_collapse_shape(%a: memref<1024x1024xf32>) {
+// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc() : memref<5x4x128xf32>
+  %0 = memref.alloc() : memref<4x128xf32>
+  %c1024 = arith.constant 1024 : index
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+// CHECK: scf.for %[[IV:.*]] = %{{.*}}
+  scf.for %arg2 = %c1 to %c1024 step %c3 {
+// CHECK: %[[I:.*]] = affine.apply #[[$MAP1]](%[[IV]])
+// CHECK: %[[SV:.*]] = memref.subview %[[ALLOC]][%[[I]], 0, 0] [1, 4, 128] [1, 1, 1] : memref<5x4x128xf32> to memref<4x128xf32, strided<[128, 1], offset: ?>>
+    %1 = memref.subview %a[%arg2, 0] [4, 128] [1, 1] :
+        memref<1024x1024xf32> to memref<4x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>
+// CHECK: memref.copy %{{.*}}, %[[SV]] : memref<4x128xf32, #{{.*}}> to memref<4x128xf32, strided<[128, 1], offset: ?>>
+    memref.copy %1, %0 : memref<4x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>> to memref<4x128xf32>
+// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[SV]] {{\[\[}}0, 1]] : memref<4x128xf32, strided<[128, 1], offset: ?>> into memref<512xf32, strided<[1], offset: ?>>
+    %collapsed = memref.collapse_shape %0 [[0, 1]]
+        : memref<4x128xf32> into memref<512xf32>
+// CHECK: "some_use"(%[[COLLAPSED]]) : (memref<512xf32, strided<[1], offset: ?>>) -> ()
+    "some_use"(%collapsed) : (memref<512xf32>) -> ()
+  }
+  return
+}
+
+// -----
+
+// Test that multi-buffering correctly propagates through memref.cast.
+
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (((d0 - 1) floordiv 3) mod 5)>
+
+// CHECK-LABEL: func @multi_buffer_cast
+func.func @multi_buffer_cast(%a: memref<1024x1024xf32>) {
+// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc() : memref<5x4x128xf32>
+  %0 = memref.alloc() : memref<4x128xf32>
+  %c1024 = arith.constant 1024 : index
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+// CHECK: scf.for %[[IV:.*]] = %{{.*}}
+  scf.for %arg2 = %c1 to %c1024 step %c3 {
+// CHECK: %[[I:.*]] = affine.apply #[[$MAP1]](%[[IV]])
+// CHECK: %[[SV:.*]] = memref.subview %[[ALLOC]][%[[I]], 0, 0] [1, 4, 128] [1, 1, 1] : memref<5x4x128xf32> to memref<4x128xf32, strided<[128, 1], offset: ?>>
+    %1 = memref.subview %a[%arg2, 0] [4, 128] [1, 1] :
+        memref<1024x1024xf32> to memref<4x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>
+// CHECK: memref.copy %{{.*}}, %[[SV]] : memref<4x128xf32, #{{.*}}> to memref<4x128xf32, strided<[128, 1], offset: ?>>
+    memref.copy %1, %0 : memref<4x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>> to memref<4x128xf32>
+// CHECK: %[[CAST:.*]] = memref.cast %[[SV]] : memref<4x128xf32, strided<[128, 1], offset: ?>> to memref<?x128xf32>
+    %casted = memref.cast %0 : memref<4x128xf32> to memref<?x128xf32>
+// CHECK: "some_use"(%[[CAST]]) : (memref<?x128xf32>) -> ()
+    "some_use"(%casted) : (memref<?x128xf32>) -> ()
+  }
+  return
+}
+
+// -----
+
+// Test that multi-buffering correctly propagates through chained view-like ops.
+
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (((d0 - 1) floordiv 3) mod 5)>
+
+// CHECK-LABEL: func @multi_buffer_chained_view_ops
+func.func @multi_buffer_chained_view_ops(%a: memref<1024x1024xf32>) {
+// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc() : memref<5x4x128xf32>
+  %0 = memref.alloc() : memref<4x128xf32>
+  %c1024 = arith.constant 1024 : index
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+// CHECK: scf.for %[[IV:.*]] = %{{.*}}
+  scf.for %arg2 = %c1 to %c1024 step %c3 {
+// CHECK: %[[I:.*]] = affine.apply #[[$MAP1]](%[[IV]])
+// CHECK: %[[SV:.*]] = memref.subview %[[ALLOC]][%[[I]], 0, 0] [1, 4, 128] [1, 1, 1] : memref<5x4x128xf32> to memref<4x128xf32, strided<[128, 1], offset: ?>>
+    %1 = memref.subview %a[%arg2, 0] [4, 128] [1, 1] :
+        memref<1024x1024xf32> to memref<4x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>
+// CHECK: memref.copy %{{.*}}, %[[SV]] : memref<4x128xf32, #{{.*}}> to memref<4x128xf32, strided<[128, 1], offset: ?>>
+    memref.copy %1, %0 : memref<4x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>> to memref<4x128xf32>
+// CHECK: %[[EXPANDED:.*]] = memref.expand_shape %[[SV]] {{\[\[}}0, 1], [2, 3]] output_shape [2, 2, 64, 2] : memref<4x128xf32, strided<[128, 1], offset: ?>> into memref<2x2x64x2xf32, strided<[256, 128, 2, 1], offset: ?>>
+    %expanded = memref.expand_shape %0 [[0, 1], [2, 3]] output_shape [2, 2, 64, 2]
+        : memref<4x128xf32> into memref<2x2x64x2xf32>
+// CHECK: %[[CAST:.*]] = memref.cast %[[EXPANDED]] : memref<2x2x64x2xf32, strided<[256, 128, 2, 1], offset: ?>> to memref<?x2x64x2xf32>
+    %casted = memref.cast %expanded : memref<2x2x64x2xf32> to memref<?x2x64x2xf32>
+// CHECK: "some_use"(%[[CAST]]) : (memref<?x2x64x2xf32>) -> ()
+    "some_use"(%casted) : (memref<?x2x64x2xf32>) -> ()
+  }
+  return
+}
