[Mlir-commits] [mlir] [mlir][Bufferization] Fix to_buffer(tensor.cast) folder (PR #150511)

llvmlistbot at llvm.org
Thu Jul 24 13:11:30 PDT 2025


llvmbot wrote:


@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-bufferization

Author: Quinn Dawkins (qedawkins)

<details>
<summary>Changes</summary>

Previously, this folder ignored the layout and memory space of the `to_buffer` result type and reset both to their defaults. The pattern now retains both fields from the existing memref type while taking the static shape information from the source of the `tensor.cast`.
The pattern also dropped the `read_only` attribute; it is now retained as well.
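
As a rough illustration of the intended result, the following sketch is hand-written from the new test case in the diff, not captured `mlir-opt` output; the SSA names `%m` and `%m1` are placeholders:

```mlir
// Input: the to_buffer result carries a strided layout and memory space 1.
%0 = tensor.cast %arg0 : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8>
%1 = bufferization.to_buffer %0 : tensor<?x?x16x32xi8> to memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>

// Expected after canonicalization: the new to_buffer takes the static shape
// from the cast source and keeps the layout and memory space of the original
// result type; memref.cast then restores the dynamic shape.
%m = bufferization.to_buffer %arg0 : tensor<4x6x16x32xi8> to memref<4x6x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>
%m1 = memref.cast %m : memref<4x6x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1> to memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>
```

With the old folder, the intermediate type would have been plain `memref<4x6x16x32xi8>`, without the layout and memory space.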

---
Full diff: https://github.com/llvm/llvm-project/pull/150511.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp (+10-2) 
- (modified) mlir/test/Dialect/Bufferization/canonicalize.mlir (+18-2) 


``````````diff
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 875a06546c9f0..c3b5476029ee5 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -804,10 +804,18 @@ struct ToBufferOfCast : public OpRewritePattern<ToBufferOp> {
         tensorCastOperand.getOperand().getType());
     if (!srcTensorType)
       return failure();
+    auto currentOutputMemRefType =
+        dyn_cast<MemRefType>(toBuffer.getResult().getType());
+    if (!currentOutputMemRefType)
+      return failure();
+
     auto memrefType = MemRefType::get(srcTensorType.getShape(),
-                                      srcTensorType.getElementType());
+                                      srcTensorType.getElementType(),
+                                      currentOutputMemRefType.getLayout(),
+                                      currentOutputMemRefType.getMemorySpace());
     Value memref = rewriter.create<ToBufferOp>(toBuffer.getLoc(), memrefType,
-                                               tensorCastOperand.getOperand());
+                                               tensorCastOperand.getOperand(),
+                                               toBuffer.getReadOnly());
     rewriter.replaceOpWithNewOp<memref::CastOp>(toBuffer, toBuffer.getType(),
                                                 memref);
     return success();
diff --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir
index f44e29071796d..2acd19453a04d 100644
--- a/mlir/test/Dialect/Bufferization/canonicalize.mlir
+++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir
@@ -255,16 +255,32 @@ func.func @clone_and_preceding_dealloc(%arg0: memref<?xf32>) -> memref<32xf32> {
 func.func @tensor_cast_to_buffer(%arg0 : tensor<4x6x16x32xi8>) ->
   memref<?x?x16x32xi8> {
   %0 = tensor.cast %arg0 : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8>
-  %1 = bufferization.to_buffer %0 : tensor<?x?x16x32xi8> to memref<?x?x16x32xi8>
+  %1 = bufferization.to_buffer %0 read_only : tensor<?x?x16x32xi8> to memref<?x?x16x32xi8>
   return %1 : memref<?x?x16x32xi8>
 }
-// CHECK:   %[[M:.+]] = bufferization.to_buffer %[[ARG0]] : tensor<4x6x16x32xi8>
+// CHECK:   %[[M:.+]] = bufferization.to_buffer %[[ARG0]] read_only : tensor<4x6x16x32xi8>
 // CHECK:   %[[M1:.+]] = memref.cast %[[M]]
 // CHECK-SAME: memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
 // CHECK:   return %[[M1]] : memref<?x?x16x32xi8>
 
 // -----
 
+// CHECK-LABEL: func @tensor_cast_to_buffer
+//  CHECK-SAME:   %[[ARG0:.+]]: tensor<4x6x16x32xi8>
+func.func @tensor_cast_to_buffer_layout_and_memspace(%arg0 : tensor<4x6x16x32xi8>) ->
+  memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1> {
+  %0 = tensor.cast %arg0 : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8>
+  %1 = bufferization.to_buffer %0 : tensor<?x?x16x32xi8> to memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>
+  return %1 : memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>
+}
+// CHECK:   %[[M:.+]] = bufferization.to_buffer %[[ARG0]] : tensor<4x6x16x32xi8>
+// CHECK:   %[[M1:.+]] = memref.cast %[[M]]
+// CHECK-SAME: memref<4x6x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>
+// CHECK-SAME: to memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>
+// CHECK:   return %[[M1]] : memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>
+
+// -----
+
 // Folding of memref.load(to_buffer(%v, %idxs)) -> tensor.extract(%v, %idx)
 // CHECK-LABEL: func @load_from_buffer_cast(
 func.func @load_from_buffer_cast(%arg0: index, %arg1: index,

``````````

</details>


https://github.com/llvm/llvm-project/pull/150511

