[Mlir-commits] [mlir] cd1363b - [mlir][Bufferization] Support cast from ranked to unranked in canonicalization (#152257)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Aug 5 23:58:23 PDT 2025


Author: Adrian Kuegel
Date: 2025-08-06T08:58:20+02:00
New Revision: cd1363bf42aad4cc55f6fe6892f63de9b32977ae

URL: https://github.com/llvm/llvm-project/commit/cd1363bf42aad4cc55f6fe6892f63de9b32977ae
DIFF: https://github.com/llvm/llvm-project/commit/cd1363bf42aad4cc55f6fe6892f63de9b32977ae.diff

LOG: [mlir][Bufferization] Support cast from ranked to unranked in canonicalization (#152257)

https://github.com/llvm/llvm-project/pull/150511 (#150511) changed the
ToBufferOfCast canonicalization pattern so that it no longer folded casts
from ranked to unranked types. This patch restores that functionality
while keeping the #150511 fix that preserves the memory space and layout
of the result buffer type.

Added: 
    

Modified: 
    mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
    mlir/test/Dialect/Bufferization/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 7eb729f349638..f1f12f4bca70e 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -806,14 +806,12 @@ struct ToBufferOfCast : public OpRewritePattern<ToBufferOp> {
     if (!srcTensorType)
       return failure();
     auto currentOutputMemRefType =
-        dyn_cast<MemRefType>(toBuffer.getResult().getType());
+        dyn_cast<BaseMemRefType>(toBuffer.getResult().getType());
     if (!currentOutputMemRefType)
       return failure();
 
-    auto memrefType = MemRefType::get(srcTensorType.getShape(),
-                                      srcTensorType.getElementType(),
-                                      currentOutputMemRefType.getLayout(),
-                                      currentOutputMemRefType.getMemorySpace());
+    auto memrefType = currentOutputMemRefType.cloneWith(
+        srcTensorType.getShape(), srcTensorType.getElementType());
     Value memref = ToBufferOp::create(rewriter, toBuffer.getLoc(), memrefType,
                                       tensorCastOperand.getOperand(),
                                       toBuffer.getReadOnly());

diff  --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir
index 2acd19453a04d..ae1d1fcfc19dc 100644
--- a/mlir/test/Dialect/Bufferization/canonicalize.mlir
+++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir
@@ -263,6 +263,19 @@ func.func @tensor_cast_to_buffer(%arg0 : tensor<4x6x16x32xi8>) ->
 // CHECK-SAME: memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
 // CHECK:   return %[[M1]] : memref<?x?x16x32xi8>
 
+// CHECK-LABEL: func @tensor_cast_to_unranked_buffer
+//  CHECK-SAME:   %[[ARG0:.+]]: tensor<4x6x16x32xi8>
+func.func @tensor_cast_to_unranked_buffer(%arg0 : tensor<4x6x16x32xi8>) ->
+  memref<*xi8> {
+  %0 = tensor.cast %arg0 : tensor<4x6x16x32xi8> to tensor<*xi8>
+  %1 = bufferization.to_buffer %0 read_only : tensor<*xi8> to memref<*xi8>
+  return %1 : memref<*xi8>
+}
+// CHECK:   %[[M:.+]] = bufferization.to_buffer %[[ARG0]] read_only : tensor<4x6x16x32xi8>
+// CHECK:   %[[M1:.+]] = memref.cast %[[M]]
+// CHECK-SAME: memref<4x6x16x32xi8> to memref<*xi8>
+// CHECK:   return %[[M1]] : memref<*xi8>
+
 // -----
 
 // CHECK-LABEL: func @tensor_cast_to_buffer


        


More information about the Mlir-commits mailing list