[Mlir-commits] [mlir] [mlir][bufferization] Canonicalize to_memref(to_tensor(x)) to a CopyOp if memory spaces differ (PR #126692)
Maya Amrami
llvmlistbot at llvm.org
Mon Feb 10 23:27:58 PST 2025
https://github.com/amrami created https://github.com/llvm/llvm-project/pull/126692
[mlir][bufferization] Canonicalize to_memref(to_tensor(x)) to a CopyOp if memory spaces differ
>From 53aee45bd91831bd8b82a8bdc8da1f793374dc0a Mon Sep 17 00:00:00 2001
From: Maya Amrami <mayaam88 at gmail.com>
Date: Tue, 11 Feb 2025 09:27:12 +0200
Subject: [PATCH] [mlir][bufferization] Canonicalize to_memref(to_tensor(x)) to
a CopyOp if memory spaces differ
---
.../Bufferization/IR/BufferizationOps.cpp | 4 +---
.../Dialect/Bufferization/canonicalize.mlir | 22 ++++++++++---------
2 files changed, 13 insertions(+), 13 deletions(-)
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 6be55a1d282240b..4fce9be390bd6c5 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -28,11 +28,9 @@ FailureOr<Value> mlir::bufferization::castOrReallocMemRefValue(
const BufferizationOptions &options) {
auto srcType = llvm::cast<MemRefType>(value.getType());
- // Element type, rank and memory space must match.
+ // Element type and rank must match; memory spaces may differ (the value is then reallocated and copied).
if (srcType.getElementType() != destType.getElementType())
return failure();
- if (srcType.getMemorySpace() != destType.getMemorySpace())
- return failure();
if (srcType.getRank() != destType.getRank())
return failure();
diff --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir
index 3ebc1e4fa8dea34..833d83be89b5aab 100644
--- a/mlir/test/Dialect/Bufferization/canonicalize.mlir
+++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir
@@ -27,18 +27,20 @@ func.func @buffer_cast_of_tensor_load(%arg0: memref<?xf32>) -> memref<?xf32> {
// -----
-// If the memrefs are not the same type, don't fold them.
-// If the memrefs are not cast-compatible (e.g. different address space), don't
-// canonicalize them either.
-// CHECK-LABEL: func @no_fold_buffer_cast_of_tensor_load(
+// If the memrefs are not cast-compatible (e.g. different address spaces),
+// canonicalize the pair into a new allocation plus a memref.copy.
+// CHECK-LABEL: func @canonicalize_buffer_cast_of_tensor_load_different_address_space(
// CHECK-SAME: %[[MEMREF_ADDRSPACE2:.*]]: memref<?xf32, 2>)
// CHECK-SAME: -> memref<?xf32, 7> {
-// CHECK: %[[TENSOR:.*]] = bufferization.to_tensor
-// CHECK-SAME: %[[MEMREF_ADDRSPACE2]] : memref<?xf32, 2> to tensor<?xf32, 7 : i64>
-// CHECK: %[[MEMREF_ADDRSPACE7:.*]] = bufferization.to_memref
-// CHECK-SAME: %[[TENSOR]] : tensor<?xf32, 7 : i64> to memref<?xf32, 7>
-// CHECK: return %[[MEMREF_ADDRSPACE7]]
-func.func @no_fold_buffer_cast_of_tensor_load(%arg0: memref<?xf32, 2>)
+// CHECK-NOT: bufferization.to_tensor
+// CHECK-NOT: bufferization.to_memref
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM:.*]] = memref.dim %[[MEMREF_ADDRSPACE2]], %[[C0]] : memref<?xf32, 2>
+// CHECK: %[[MEMREF_ADDRSPACE7:.*]] = memref.alloc(%[[DIM]]) : memref<?xf32, 7>
+// CHECK: memref.copy %[[MEMREF_ADDRSPACE2]], %[[MEMREF_ADDRSPACE7]]
+// CHECK-SAME: memref<?xf32, 2> to memref<?xf32, 7>
+// CHECK: return %[[MEMREF_ADDRSPACE7]]
+func.func @canonicalize_buffer_cast_of_tensor_load_different_address_space(%arg0: memref<?xf32, 2>)
-> memref<?xf32, 7> {
%0 = bufferization.to_tensor %arg0 : memref<?xf32, 2> to tensor<?xf32, 7>
%1 = bufferization.to_memref %0 : tensor<?xf32, 7> to memref<?xf32, 7>
More information about the Mlir-commits
mailing list