[Mlir-commits] [mlir] [MLIR][buffer-deallocation] Introduce copies only for MemRef typed values. (PR #121582)

llvmlistbot at llvm.org
Fri Jan 3 08:53:01 PST 2025


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: None (erick-xanadu)

<details>
<summary>Changes</summary>

Hello,

I know that the buffer-deallocation pass is deprecated, but in case anyone else is still relying on it, here is an issue we found with it. The following test case segfaults on main:

```mlir
func.func public @jit_main() {
    %c0 = arith.constant 0 : i1

    %4 = scf.while (%arg1 = %c0) : (i1) -> (i1) {
      scf.condition(%c0) %arg1 : i1
    } do {
    ^bb0(%arg1: i1):
      %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<i32>

      %7 = func.call @jitted_fun_0(%alloc_1, %arg1) : (memref<i32>, i1) -> (i1)

      scf.yield %7#0: i1
    }

    return
}

func.func private @jitted_fun_0(%arg1: memref<i32>, %arg2: i1) -> (i1) {
    return %arg2 : i1
}
```

It looks like the pass attempts to introduce a clone for the `i1` value, which is not a buffer. With the submitted change, which skips every value in `valuesToFree` whose type is not a `BaseMemRefType`, I believe the error is corrected. I am happy to improve the test. It would be nice if this could still be merged (if it is deemed a fix) despite buffer-deallocation being deprecated, since other people may run into the same problem.
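To illustrate the idea, here is a minimal, MLIR-free C++ sketch of the filtering the patch adds. `TypeKind`, `Value`, `isBaseMemRef`, `valuesNeedingCopies` and `checkExample` are hypothetical stand-ins, not MLIR APIs. The actual change calls `isa<BaseMemRefType>(value.getType())` inside the loop over `valuesToFree` in `BufferDeallocation.cpp`, before `introduceBlockArgCopy` or `introduceValueCopyForRegionResult` is reached.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for the kinds of MLIR types a value can have.
// The real pass works on mlir::Type and checks isa<BaseMemRefType>(...).
enum class TypeKind { MemRef, UnrankedMemRef, Integer, Index };

// Hypothetical stand-in for mlir::Value: a printable name plus a type kind.
struct Value {
  std::string name;
  TypeKind type;
};

// Plays the role of isa<BaseMemRefType>: BaseMemRefType covers both
// ranked (memref<...>) and unranked (memref<*x...>) memrefs.
bool isBaseMemRef(TypeKind kind) {
  return kind == TypeKind::MemRef || kind == TypeKind::UnrankedMemRef;
}

// Returns the candidates that may receive a copy (clone). Values of any
// other type, such as the i1 carried by the scf.while in the reproducer,
// are skipped, mirroring the `continue` added in the patch.
std::vector<Value> valuesNeedingCopies(const std::vector<Value>& valuesToFree) {
  std::vector<Value> result;
  for (const Value& value : valuesToFree) {
    if (!isBaseMemRef(value.type))
      continue;
    result.push_back(value);
  }
  return result;
}

// Usage: an i1 value next to a memref buffer; only the buffer is kept.
void checkExample() {
  std::vector<Value> candidates = {{"%arg1", TypeKind::Integer},
                                   {"%buf", TypeKind::MemRef}};
  std::vector<Value> copies = valuesNeedingCopies(candidates);
  assert(copies.size() == 1);
  assert(copies[0].name == "%buf");
}
```

Skipping non-memref values before either copy helper runs means neither helper ever sees a value it cannot clone.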

Thanks, and happy to make changes to improve this fix.

---
Full diff: https://github.com/llvm/llvm-project/pull/121582.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp (+3) 
- (modified) mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir (+21) 


``````````diff
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
index a0a81d4add7121..36edbb3395eddf 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
@@ -308,6 +308,9 @@ class BufferDeallocation : public BufferPlacementTransformationBase {
 
     // Add new allocs and additional clone operations.
     for (Value value : valuesToFree) {
+      if (!isa<BaseMemRefType>(value.getType())) {
+          continue;
+      }
       if (failed(isa<BlockArgument>(value)
                      ? introduceBlockArgCopy(cast<BlockArgument>(value))
                      : introduceValueCopyForRegionResult(value)))
diff --git a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir
index 3fbe3913c6549e..a6e6f724a822ac 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir
@@ -1271,6 +1271,27 @@ func.func @while_two_arg(%arg0: index) {
 
 // -----
 
+// CHECK-LABEL: func @while_fun
+func.func @while_fun() {
+    %c0 = arith.constant 0 : i1
+    %4 = scf.while (%arg1 = %c0) : (i1) -> (i1) {
+      scf.condition(%c0) %arg1 : i1
+    } do {
+    ^bb0(%arg1: i1):
+      %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<i32>
+      %7 = func.call @foo(%alloc_1, %arg1) : (memref<i32>, i1) -> (i1)
+      scf.yield %7#0: i1
+    }
+    return
+}
+
+func.func private @foo(%arg1: memref<i32>, %arg2: i1) -> (i1) {
+    return %arg2 : i1
+}
+
+// -----
+
+// CHECK-LABEL: func @while_three_arg
 func.func @while_three_arg(%arg0: index) {
 // CHECK: %[[ALLOC:.*]] = memref.alloc
   %a = memref.alloc(%arg0) : memref<?xf32>

``````````

</details>


https://github.com/llvm/llvm-project/pull/121582

