[Mlir-commits] [mlir] fd8f69d - [mlir][Bufferization] Fix to_buffer(tensor.cast) folder (#150511)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jul 24 14:50:17 PDT 2025
Author: Quinn Dawkins
Date: 2025-07-24T17:50:14-04:00
New Revision: fd8f69d3eb9be3a987b4044fa93dd9ed0aafe094
URL: https://github.com/llvm/llvm-project/commit/fd8f69d3eb9be3a987b4044fa93dd9ed0aafe094
DIFF: https://github.com/llvm/llvm-project/commit/fd8f69d3eb9be3a987b4044fa93dd9ed0aafe094.diff
LOG: [mlir][Bufferization] Fix to_buffer(tensor.cast) folder (#150511)
Previously, this canonicalization pattern ignored the layout and memory
space of the to_buffer op's result type and reset both to the defaults.
This change makes the pattern retain both fields from the existing memref
type while incorporating the static shape information from the tensor.cast.
The pattern also dropped the `read_only` attribute; it is now retained
as well.
Added:
Modified:
mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
mlir/test/Dialect/Bufferization/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index dbc7d0dd74a00..7eb729f349638 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -805,10 +805,18 @@ struct ToBufferOfCast : public OpRewritePattern<ToBufferOp> {
tensorCastOperand.getOperand().getType());
if (!srcTensorType)
return failure();
+ auto currentOutputMemRefType =
+ dyn_cast<MemRefType>(toBuffer.getResult().getType());
+ if (!currentOutputMemRefType)
+ return failure();
+
auto memrefType = MemRefType::get(srcTensorType.getShape(),
- srcTensorType.getElementType());
+ srcTensorType.getElementType(),
+ currentOutputMemRefType.getLayout(),
+ currentOutputMemRefType.getMemorySpace());
Value memref = ToBufferOp::create(rewriter, toBuffer.getLoc(), memrefType,
- tensorCastOperand.getOperand());
+ tensorCastOperand.getOperand(),
+ toBuffer.getReadOnly());
rewriter.replaceOpWithNewOp<memref::CastOp>(toBuffer, toBuffer.getType(),
memref);
return success();
diff --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir
index f44e29071796d..2acd19453a04d 100644
--- a/mlir/test/Dialect/Bufferization/canonicalize.mlir
+++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir
@@ -255,16 +255,32 @@ func.func @clone_and_preceding_dealloc(%arg0: memref<?xf32>) -> memref<32xf32> {
func.func @tensor_cast_to_buffer(%arg0 : tensor<4x6x16x32xi8>) ->
memref<?x?x16x32xi8> {
%0 = tensor.cast %arg0 : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8>
- %1 = bufferization.to_buffer %0 : tensor<?x?x16x32xi8> to memref<?x?x16x32xi8>
+ %1 = bufferization.to_buffer %0 read_only : tensor<?x?x16x32xi8> to memref<?x?x16x32xi8>
return %1 : memref<?x?x16x32xi8>
}
-// CHECK: %[[M:.+]] = bufferization.to_buffer %[[ARG0]] : tensor<4x6x16x32xi8>
+// CHECK: %[[M:.+]] = bufferization.to_buffer %[[ARG0]] read_only : tensor<4x6x16x32xi8>
// CHECK: %[[M1:.+]] = memref.cast %[[M]]
// CHECK-SAME: memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
// CHECK: return %[[M1]] : memref<?x?x16x32xi8>
// -----
+// CHECK-LABEL: func @tensor_cast_to_buffer
+// CHECK-SAME: %[[ARG0:.+]]: tensor<4x6x16x32xi8>
+func.func @tensor_cast_to_buffer_layout_and_memspace(%arg0 : tensor<4x6x16x32xi8>) ->
+ memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1> {
+ %0 = tensor.cast %arg0 : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8>
+ %1 = bufferization.to_buffer %0 : tensor<?x?x16x32xi8> to memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>
+ return %1 : memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>
+}
+// CHECK: %[[M:.+]] = bufferization.to_buffer %[[ARG0]] : tensor<4x6x16x32xi8>
+// CHECK: %[[M1:.+]] = memref.cast %[[M]]
+// CHECK-SAME: memref<4x6x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>
+// CHECK-SAME: to memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>
+// CHECK: return %[[M1]] : memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>
+
+// -----
+
// Folding of memref.load(to_buffer(%v, %idxs)) -> tensor.extract(%v, %idx)
// CHECK-LABEL: func @load_from_buffer_cast(
func.func @load_from_buffer_cast(%arg0: index, %arg1: index,
More information about the Mlir-commits mailing list