[Mlir-commits] [mlir] [mlir][Bufferization] Fix to_buffer(tensor.cast) folder (PR #150511)

Quinn Dawkins llvmlistbot at llvm.org
Thu Jul 24 13:10:57 PDT 2025


https://github.com/qedawkins created https://github.com/llvm/llvm-project/pull/150511

Previously, this folder ignored the layout and memory space of the `to_buffer` op's result type and reset both to their defaults. The pattern now retains both fields from the existing memref type while incorporating the static shape information from the tensor cast.
The pattern also used to drop the `read_only` attribute; it is now retained as well.

>From dd967db6a34d6c9b1651a64fcb8617e7fc63acc7 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Thu, 24 Jul 2025 16:05:41 -0400
Subject: [PATCH] [mlir][Bufferization] Fix to_buffer(tensor.cast) folder

Previously, this folder ignored the layout and memory space of the
to_buffer op's result type and reset both to their defaults. The
pattern now retains both fields from the existing memref type while
incorporating the static shape information from the tensor cast.
It also preserves the `read_only` attribute, which was previously
dropped.
---
 .../Bufferization/IR/BufferizationOps.cpp     | 12 +++++++++--
 .../Dialect/Bufferization/canonicalize.mlir   | 20 +++++++++++++++++--
 2 files changed, 28 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 875a06546c9f0..c3b5476029ee5 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -804,10 +804,18 @@ struct ToBufferOfCast : public OpRewritePattern<ToBufferOp> {
         tensorCastOperand.getOperand().getType());
     if (!srcTensorType)
       return failure();
+    auto currentOutputMemRefType =
+        dyn_cast<MemRefType>(toBuffer.getResult().getType());
+    if (!currentOutputMemRefType)
+      return failure();
+
     auto memrefType = MemRefType::get(srcTensorType.getShape(),
-                                      srcTensorType.getElementType());
+                                      srcTensorType.getElementType(),
+                                      currentOutputMemRefType.getLayout(),
+                                      currentOutputMemRefType.getMemorySpace());
     Value memref = rewriter.create<ToBufferOp>(toBuffer.getLoc(), memrefType,
-                                               tensorCastOperand.getOperand());
+                                               tensorCastOperand.getOperand(),
+                                               toBuffer.getReadOnly());
     rewriter.replaceOpWithNewOp<memref::CastOp>(toBuffer, toBuffer.getType(),
                                                 memref);
     return success();
diff --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir
index f44e29071796d..2acd19453a04d 100644
--- a/mlir/test/Dialect/Bufferization/canonicalize.mlir
+++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir
@@ -255,16 +255,32 @@ func.func @clone_and_preceding_dealloc(%arg0: memref<?xf32>) -> memref<32xf32> {
 func.func @tensor_cast_to_buffer(%arg0 : tensor<4x6x16x32xi8>) ->
   memref<?x?x16x32xi8> {
   %0 = tensor.cast %arg0 : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8>
-  %1 = bufferization.to_buffer %0 : tensor<?x?x16x32xi8> to memref<?x?x16x32xi8>
+  %1 = bufferization.to_buffer %0 read_only : tensor<?x?x16x32xi8> to memref<?x?x16x32xi8>
   return %1 : memref<?x?x16x32xi8>
 }
-// CHECK:   %[[M:.+]] = bufferization.to_buffer %[[ARG0]] : tensor<4x6x16x32xi8>
+// CHECK:   %[[M:.+]] = bufferization.to_buffer %[[ARG0]] read_only : tensor<4x6x16x32xi8>
 // CHECK:   %[[M1:.+]] = memref.cast %[[M]]
 // CHECK-SAME: memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
 // CHECK:   return %[[M1]] : memref<?x?x16x32xi8>
 
 // -----
 
+// CHECK-LABEL: func @tensor_cast_to_buffer_layout_and_memspace
+//  CHECK-SAME:   %[[ARG0:.+]]: tensor<4x6x16x32xi8>
+func.func @tensor_cast_to_buffer_layout_and_memspace(%arg0 : tensor<4x6x16x32xi8>) ->
+  memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1> {
+  %0 = tensor.cast %arg0 : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8>
+  %1 = bufferization.to_buffer %0 : tensor<?x?x16x32xi8> to memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>
+  return %1 : memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>
+}
+// CHECK:   %[[M:.+]] = bufferization.to_buffer %[[ARG0]] : tensor<4x6x16x32xi8>
+// CHECK:   %[[M1:.+]] = memref.cast %[[M]]
+// CHECK-SAME: memref<4x6x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>
+// CHECK-SAME: to memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>
+// CHECK:   return %[[M1]] : memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>
+
+// -----
+
 // Folding of memref.load(to_buffer(%v, %idxs)) -> tensor.extract(%v, %idx)
 // CHECK-LABEL: func @load_from_buffer_cast(
 func.func @load_from_buffer_cast(%arg0: index, %arg1: index,



More information about the Mlir-commits mailing list