[Mlir-commits] [mlir] [mlir][bufferization] Clone simplify fails when input and result type not cast compatible (PR #71310)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Nov 5 02:02:51 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-bufferization

@llvm/pr-subscribers-mlir

Author: donald chen (cxy-1993)

<details>
<summary>Changes</summary>

The simplification of bufferization.clone generates a memref.cast op, but the checks in the simplification pattern do not verify whether the operand type and result type of the clone op are cast-compatible, leading to errors. This patch addresses this issue.

---
Full diff: https://github.com/llvm/llvm-project/pull/71310.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp (+4) 
- (modified) mlir/test/Dialect/Bufferization/canonicalize.mlir (+12-15) 


``````````diff
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index ec5feab1ed0d856..edaa3945b70eb50 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -457,6 +457,10 @@ struct SimplifyClones : public OpRewritePattern<CloneOp> {
     }
 
     Value source = cloneOp.getInput();
+    if (!memref::CastOp::areCastCompatible({source.getType()},
+                                           {cloneOp.getType()}))
+      return failure();
+
     // Aims to find the dealloc op for the canonical source
     // which otherwise could prevent removal of unnecessary allocs.
     Value canonicalSource = source;
diff --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir
index 3ba283928a83f0e..3e183c7d9449817 100644
--- a/mlir/test/Dialect/Bufferization/canonicalize.mlir
+++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir
@@ -156,6 +156,18 @@ func.func @clone_and_cast(%arg0: memref<?xf32>) -> memref<32xf32> {
 
 // -----
 
+// CHECK-LABEL: @clone_incompatible
+func.func @clone_incompatible(%arg0: memref<32xf32, strided<[2]>>) -> memref<32xf32> {
+  %0 = bufferization.clone %arg0 : memref<32xf32, strided<[2]>> to memref<32xf32>
+  memref.dealloc %arg0 : memref<32xf32, strided<[2]>>
+  return %0 : memref<32xf32>
+}
+// CHECK-SAME: %[[ARG:.*]]: memref<32xf32, strided<[2]>>
+// CHECK-NEXT: bufferization.clone %[[ARG]] : memref<32xf32, strided<[2]>> to memref<32xf32>
+// CHECK-NOT: memref.cast
+
+// -----
+
 // CHECK-LABEL: @alias_is_freed
 func.func @alias_is_freed(%arg0 : memref<?xf32>) {
   %0 = memref.cast %arg0 : memref<?xf32> to memref<32xf32>
@@ -267,21 +279,6 @@ func.func @alloc_tensor_canonicalize() -> (tensor<4x5x?xf32>) {
 
 // -----
 
-func.func @dealloc_canonicalize_clone_removal(%arg0: memref<?xindex>) -> memref<*xf32> {
-  %c1 = arith.constant 1 : index
-  %0 = memref.alloc(%c1) : memref<?xf32>
-  %1 = memref.reshape %0(%arg0) : (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
-  %2 = bufferization.clone %1 : memref<*xf32> to memref<*xf32>
-  memref.dealloc %0 : memref<?xf32>
-  return %2 : memref<*xf32>
-}
-// CHECK-LABEL: @dealloc_canonicalize_clone_removal
-//   CHECK-NOT:   bufferization.clone
-//   CHECK-NOT:   memref.dealloc
-//       CHECK:   return {{.*}}
-
-// -----
-
 func.func @dealloc_canonicalize_duplicates(%arg0: memref<2xi32>, %arg1: i1, %arg2: i1, %arg3: memref<2xi32>, %arg4: memref<2xi32>, %arg5: memref<2xi32>) -> (i1, i1, i1) {
   %0:3 = bufferization.dealloc (%arg4, %arg0, %arg0 : memref<2xi32>, memref<2xi32>, memref<2xi32>) if (%arg1, %arg1, %arg1) retain (%arg3, %arg5, %arg3 : memref<2xi32>, memref<2xi32>, memref<2xi32>)
   bufferization.dealloc (%arg0, %arg0 : memref<2xi32>, memref<2xi32>) if (%arg1, %arg2)

``````````

</details>


https://github.com/llvm/llvm-project/pull/71310


More information about the Mlir-commits mailing list