[Mlir-commits] [mlir] [mlir][bufferization] Clone simplify fails when input and result type not cast compatible (PR #71310)
donald chen
llvmlistbot at llvm.org
Thu Jan 11 06:31:26 PST 2024
https://github.com/cxy-1993 updated https://github.com/llvm/llvm-project/pull/71310
>From 44efb00fc082c36f0cc4c52b72b0463e66f261fd Mon Sep 17 00:00:00 2001
From: cxy <chenxunyu1993 at gmail.com>
Date: Sun, 5 Nov 2023 17:31:44 +0800
Subject: [PATCH] [mlir] Clone simplify fails when input and result type not
cast compatible
Fixed a bug where simplifying a clone operation created a memref.cast
between source and result types that are not cast-compatible.
---
.../Bufferization/IR/BufferizationOps.cpp | 4 ++++
.../Dialect/Bufferization/canonicalize.mlir | 19 +++++++++++++++----
2 files changed, 19 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index ca0d2f407c2d83..eeeb2a92ca5afe 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -457,6 +457,10 @@ struct SimplifyClones : public OpRewritePattern<CloneOp> {
}
Value source = cloneOp.getInput();
+ if (!memref::CastOp::areCastCompatible({source.getType()},
+ {cloneOp.getType()}))
+ return failure();
+
// Aims to find the dealloc op for the canonical source
// which otherwise could prevent removal of unnecessary allocs.
Value canonicalSource = source;
diff --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir
index 3ba283928a83f0..67a06a6b15f996 100644
--- a/mlir/test/Dialect/Bufferization/canonicalize.mlir
+++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir
@@ -156,6 +156,18 @@ func.func @clone_and_cast(%arg0: memref<?xf32>) -> memref<32xf32> {
// -----
+// CHECK-LABEL: @clone_incompatible
+func.func @clone_incompatible(%arg0: memref<32xf32, strided<[2]>>) -> memref<32xf32> {
+ %0 = bufferization.clone %arg0 : memref<32xf32, strided<[2]>> to memref<32xf32>
+ memref.dealloc %arg0 : memref<32xf32, strided<[2]>>
+ return %0 : memref<32xf32>
+}
+// CHECK-SAME: %[[ARG:.*]]: memref<32xf32, strided<[2]>>
+// CHECK-NEXT: bufferization.clone %[[ARG]] : memref<32xf32, strided<[2]>> to memref<32xf32>
+// CHECK-NOT: memref.cast
+
+// -----
+
// CHECK-LABEL: @alias_is_freed
func.func @alias_is_freed(%arg0 : memref<?xf32>) {
%0 = memref.cast %arg0 : memref<?xf32> to memref<32xf32>
@@ -267,13 +279,12 @@ func.func @alloc_tensor_canonicalize() -> (tensor<4x5x?xf32>) {
// -----
-func.func @dealloc_canonicalize_clone_removal(%arg0: memref<?xindex>) -> memref<*xf32> {
+func.func @dealloc_canonicalize_clone_removal() -> memref<?xf32> {
%c1 = arith.constant 1 : index
%0 = memref.alloc(%c1) : memref<?xf32>
- %1 = memref.reshape %0(%arg0) : (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
- %2 = bufferization.clone %1 : memref<*xf32> to memref<*xf32>
+ %1 = bufferization.clone %0 : memref<?xf32> to memref<?xf32>
memref.dealloc %0 : memref<?xf32>
- return %2 : memref<*xf32>
+ return %1 : memref<?xf32>
}
// CHECK-LABEL: @dealloc_canonicalize_clone_removal
// CHECK-NOT: bufferization.clone
More information about the Mlir-commits
mailing list