[Mlir-commits] [mlir] [mlir][bufferization] Canonicalize to_memref(to_tensor(x)) to a CopyOp if memory spaces differ (PR #126692)

Maya Amrami llvmlistbot at llvm.org
Tue Feb 11 01:56:45 PST 2025


https://github.com/amrami updated https://github.com/llvm/llvm-project/pull/126692

>From 160a65806a4ec485c00974e8a6d4171918dd830f Mon Sep 17 00:00:00 2001
From: Maya Amrami <mayaam88 at gmail.com>
Date: Tue, 11 Feb 2025 09:27:12 +0200
Subject: [PATCH] [mlir][bufferization] Canonicalize to_memref(to_tensor(x)) to
 a CopyOp if memory spaces differ

---
 .../Bufferization/IR/BufferizationOps.cpp     |  4 +---
 .../Dialect/Bufferization/canonicalize.mlir   | 21 +++++++++++--------
 2 files changed, 13 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 6be55a1d282240b..4fce9be390bd6c5 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -28,11 +28,9 @@ FailureOr<Value> mlir::bufferization::castOrReallocMemRefValue(
     const BufferizationOptions &options) {
   auto srcType = llvm::cast<MemRefType>(value.getType());
 
-  // Element type, rank and memory space must match.
+  // Element type and rank must match.
   if (srcType.getElementType() != destType.getElementType())
     return failure();
-  if (srcType.getMemorySpace() != destType.getMemorySpace())
-    return failure();
   if (srcType.getRank() != destType.getRank())
     return failure();
 
diff --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir
index 3ebc1e4fa8dea34..b662e713e189cee 100644
--- a/mlir/test/Dialect/Bufferization/canonicalize.mlir
+++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir
@@ -28,17 +28,20 @@ func.func @buffer_cast_of_tensor_load(%arg0: memref<?xf32>) -> memref<?xf32> {
 // -----
 
 // If the memrefs are not the same type, don't fold them.
-// If the memrefs are not cast-compatible (e.g. different address space), don't
-// canonicalize them either.
-// CHECK-LABEL: func @no_fold_buffer_cast_of_tensor_load(
+// If the memrefs are not cast-compatible (e.g. different address space) but
+// one can be copied into the other, canonicalize them to alloc + copy.
+// CHECK-LABEL: func @canonicalize_buffer_cast_of_tensor_load_different_address_space(
 //  CHECK-SAME:   %[[MEMREF_ADDRSPACE2:.*]]: memref<?xf32, 2>)
 //  CHECK-SAME:     -> memref<?xf32, 7> {
-//       CHECK: %[[TENSOR:.*]] = bufferization.to_tensor
-//  CHECK-SAME:   %[[MEMREF_ADDRSPACE2]] : memref<?xf32, 2> to tensor<?xf32, 7 : i64>
-//       CHECK: %[[MEMREF_ADDRSPACE7:.*]] = bufferization.to_memref
-//  CHECK-SAME:   %[[TENSOR]] : tensor<?xf32, 7 : i64> to memref<?xf32, 7>
-//       CHECK: return %[[MEMREF_ADDRSPACE7]]
-func.func @no_fold_buffer_cast_of_tensor_load(%arg0: memref<?xf32, 2>)
+//  CHECK-NOT: bufferization.to_tensor
+//  CHECK-NOT: bufferization.to_memref
+//      CHECK: %[[C0:.*]] = arith.constant 0 : index
+//      CHECK: %[[DIM:.*]] = memref.dim %[[MEMREF_ADDRSPACE2]], %[[C0]] : memref<?xf32, 2>
+//      CHECK: %[[MEMREF_ADDRSPACE7:.*]] = memref.alloc(%[[DIM]]) : memref<?xf32, 7>
+//      CHECK: memref.copy %[[MEMREF_ADDRSPACE2]], %[[MEMREF_ADDRSPACE7]]
+// CHECK-SAME:   memref<?xf32, 2> to memref<?xf32, 7>
+//      CHECK: return %[[MEMREF_ADDRSPACE7]]
+func.func @canonicalize_buffer_cast_of_tensor_load_different_address_space(%arg0: memref<?xf32, 2>)
     -> memref<?xf32, 7> {
   %0 = bufferization.to_tensor %arg0 : memref<?xf32, 2> to tensor<?xf32, 7>
   %1 = bufferization.to_memref %0 : tensor<?xf32, 7> to memref<?xf32, 7>



More information about the Mlir-commits mailing list