[Mlir-commits] [mlir] 8cf4c55 - [mlir][bufferization] Canonicalize to_memref(to_tensor(x)) to a CopyO… (#126692)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Feb 11 03:53:19 PST 2025


Author: Maya Amrami
Date: 2025-02-11T13:53:15+02:00
New Revision: 8cf4c5576d4b9252301c834239791f70f42d94b8

URL: https://github.com/llvm/llvm-project/commit/8cf4c5576d4b9252301c834239791f70f42d94b8
DIFF: https://github.com/llvm/llvm-project/commit/8cf4c5576d4b9252301c834239791f70f42d94b8.diff

LOG: [mlir][bufferization] Canonicalize to_memref(to_tensor(x)) to a CopyO… (#126692)

…p if memory spaces differ

Added: 
    

Modified: 
    mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
    mlir/test/Dialect/Bufferization/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 6be55a1d282240b..4fce9be390bd6c5 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -28,11 +28,9 @@ FailureOr<Value> mlir::bufferization::castOrReallocMemRefValue(
     const BufferizationOptions &options) {
   auto srcType = llvm::cast<MemRefType>(value.getType());
 
-  // Element type, rank and memory space must match.
+  // Element type and rank must match.
   if (srcType.getElementType() != destType.getElementType())
     return failure();
-  if (srcType.getMemorySpace() != destType.getMemorySpace())
-    return failure();
   if (srcType.getRank() != destType.getRank())
     return failure();
 

diff  --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir
index 3ebc1e4fa8dea34..b662e713e189cee 100644
--- a/mlir/test/Dialect/Bufferization/canonicalize.mlir
+++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir
@@ -28,17 +28,20 @@ func.func @buffer_cast_of_tensor_load(%arg0: memref<?xf32>) -> memref<?xf32> {
 // -----
 
 // If the memrefs are not the same type, don't fold them.
-// If the memrefs are not cast-compatible (e.g. different address space), don't
-// canonicalize them either.
-// CHECK-LABEL: func @no_fold_buffer_cast_of_tensor_load(
+// If the memrefs are not cast-compatible but one can be copied into the other
+// (e.g. different address space), canonicalize them to alloc + copy.
+// CHECK-LABEL: func @canonicalize_buffer_cast_of_tensor_load_
diff erent_address_space(
 //  CHECK-SAME:   %[[MEMREF_ADDRSPACE2:.*]]: memref<?xf32, 2>)
 //  CHECK-SAME:     -> memref<?xf32, 7> {
-//       CHECK: %[[TENSOR:.*]] = bufferization.to_tensor
-//  CHECK-SAME:   %[[MEMREF_ADDRSPACE2]] : memref<?xf32, 2> to tensor<?xf32, 7 : i64>
-//       CHECK: %[[MEMREF_ADDRSPACE7:.*]] = bufferization.to_memref
-//  CHECK-SAME:   %[[TENSOR]] : tensor<?xf32, 7 : i64> to memref<?xf32, 7>
-//       CHECK: return %[[MEMREF_ADDRSPACE7]]
-func.func @no_fold_buffer_cast_of_tensor_load(%arg0: memref<?xf32, 2>)
+//  CHECK-NOT: bufferization.to_tensor
+//  CHECK-NOT: bufferization.to_memref
+//      CHECK: %[[C0:.*]] = arith.constant 0 : index
+//      CHECK: %[[DIM:.*]] = memref.dim %[[MEMREF_ADDRSPACE2]], %[[C0]] : memref<?xf32, 2>
+//      CHECK: %[[MEMREF_ADDRSPACE7:.*]] = memref.alloc(%[[DIM]]) : memref<?xf32, 7>
+//      CHECK: memref.copy %[[MEMREF_ADDRSPACE2]], %[[MEMREF_ADDRSPACE7]]
+// CHECK-SAME:   memref<?xf32, 2> to memref<?xf32, 7>
+//      CHECK: return %[[MEMREF_ADDRSPACE7]]
+func.func @canonicalize_buffer_cast_of_tensor_load_
diff erent_address_space(%arg0: memref<?xf32, 2>)
     -> memref<?xf32, 7> {
   %0 = bufferization.to_tensor %arg0 : memref<?xf32, 2> to tensor<?xf32, 7>
   %1 = bufferization.to_memref %0 : tensor<?xf32, 7> to memref<?xf32, 7>


        


More information about the Mlir-commits mailing list