[Mlir-commits] [mlir] d955ca4 - [BufferDeallocation] Don't assume successor operands are unique

Benjamin Kramer llvmlistbot at llvm.org
Thu Feb 17 05:16:41 PST 2022


Author: Benjamin Kramer
Date: 2022-02-17T14:16:32+01:00
New Revision: d955ca49379e73485304eb7f500db53e33109b0f

URL: https://github.com/llvm/llvm-project/commit/d955ca49379e73485304eb7f500db53e33109b0f
DIFF: https://github.com/llvm/llvm-project/commit/d955ca49379e73485304eb7f500db53e33109b0f.diff

LOG: [BufferDeallocation] Don't assume successor operands are unique

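This would create a double free when a memref is passed twice to the
same op. The operand to rewire was looked up by value with `llvm::find`
over the parent op's operands, which always returns the first matching
operand. With duplicate operands, the wrong operand could therefore be
replaced by the clone. The operand index is now computed directly from the
block argument's position in the successor inputs, offset by the start of
the successor entry operands. This wasn't a problem at the time the pass
was written, but it has become common since the introduction of scf.while.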

There's a latent non-determinism that the new while_three_arg test
triggers (see the FIXME there), but this change is messy enough as-is, so
I'll leave that for later.

Differential Revision: https://reviews.llvm.org/D120044

Added: 
    

Modified: 
    mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
    mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
index f3646806639e3..6d04dd4e92e0d 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
@@ -376,17 +376,20 @@ class BufferDeallocation : public BufferPlacementTransformationBase {
 
     // Determine the actual operand to introduce a clone for and rewire the
     // operand to point to the clone instead.
-    Value operand =
-        regionInterface.getSuccessorEntryOperands(argRegion->getRegionNumber())
-            [llvm::find(it->getSuccessorInputs(), blockArg).getIndex()];
+    auto operands =
+        regionInterface.getSuccessorEntryOperands(argRegion->getRegionNumber());
+    size_t operandIndex =
+        llvm::find(it->getSuccessorInputs(), blockArg).getIndex() +
+        operands.getBeginOperandIndex();
+    Value operand = parentOp->getOperand(operandIndex);
+    assert(operand ==
+               operands[operandIndex - operands.getBeginOperandIndex()] &&
+           "region interface operands don't match parentOp operands");
     auto clone = introduceCloneBuffers(operand, parentOp);
     if (failed(clone))
       return failure();
 
-    auto op = llvm::find(parentOp->getOperands(), operand);
-    assert(op != parentOp->getOperands().end() &&
-           "parentOp does not contain operand");
-    parentOp->setOperand(op.getIndex(), *clone);
+    parentOp->setOperand(operandIndex, *clone);
     return success();
   }
 

diff  --git a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir
index 0a80265aba50f..e7219b5c8cb7a 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir
@@ -1222,3 +1222,61 @@ func @dealloc_existing_clones(%arg0: memref<?x?xf64>, %arg1: memref<?x?xf64>) ->
   %1 = bufferization.clone %arg1 : memref<?x?xf64> to memref<?x?xf64>
   return %0 : memref<?x?xf64>
 }
+
+// -----
+
+// CHECK-LABEL: func @while_two_arg
+func @while_two_arg(%arg0: index) {
+  %a = memref.alloc(%arg0) : memref<?xf32>
+// CHECK: %[[WHILE:.*]]:2 = scf.while (%[[ARG1:.*]] = %[[ALLOC:.*]], %[[ARG2:.*]] = %[[CLONE:.*]])
+  scf.while (%arg1 = %a, %arg2 = %a) : (memref<?xf32>, memref<?xf32>) -> (memref<?xf32>, memref<?xf32>) {
+// CHECK-NEXT: make_condition
+    %0 = "test.make_condition"() : () -> i1
+// CHECK-NEXT: bufferization.clone %[[ARG2]]
+// CHECK-NEXT: memref.dealloc %[[ARG2]]
+    scf.condition(%0) %arg1, %arg2 : memref<?xf32>, memref<?xf32>
+  } do {
+  ^bb0(%arg1: memref<?xf32>, %arg2: memref<?xf32>):
+// CHECK: %[[ALLOC2:.*]] = memref.alloc
+    %b = memref.alloc(%arg0) : memref<?xf32>
+// CHECK: memref.dealloc %[[ARG2]]
+// CHECK: %[[CLONE2:.*]] = bufferization.clone %[[ALLOC2]]
+// CHECK: memref.dealloc %[[ALLOC2]]
+    scf.yield %arg1, %b : memref<?xf32>, memref<?xf32>
+  }
+// CHECK: }
+// CHECK-NEXT: memref.dealloc %[[WHILE]]#1
+// CHECK-NEXT: memref.dealloc %[[ALLOC]]
+// CHECK-NEXT: return
+  return
+}
+
+// -----
+
+func @while_three_arg(%arg0: index) {
+// CHECK: %[[ALLOC:.*]] = memref.alloc
+  %a = memref.alloc(%arg0) : memref<?xf32>
+// CHECK-NEXT: %[[CLONE1:.*]] = bufferization.clone %[[ALLOC]]
+// CHECK-NEXT: %[[CLONE2:.*]] = bufferization.clone %[[ALLOC]]
+// CHECK-NEXT: %[[CLONE3:.*]] = bufferization.clone %[[ALLOC]]
+// CHECK-NEXT: memref.dealloc %[[ALLOC]]
+// CHECK-NEXT: %[[WHILE:.*]]:3 = scf.while
+// FIXME: This is non-deterministic
+// CHECK-SAME-DAG: [[CLONE1]]
+// CHECK-SAME-DAG: [[CLONE2]]
+// CHECK-SAME-DAG: [[CLONE3]]
+  scf.while (%arg1 = %a, %arg2 = %a, %arg3 = %a) : (memref<?xf32>, memref<?xf32>, memref<?xf32>) -> (memref<?xf32>, memref<?xf32>, memref<?xf32>) {
+    %0 = "test.make_condition"() : () -> i1
+    scf.condition(%0) %arg1, %arg2, %arg3 : memref<?xf32>, memref<?xf32>, memref<?xf32>
+  } do {
+  ^bb0(%arg1: memref<?xf32>, %arg2: memref<?xf32>, %arg3: memref<?xf32>):
+    %b = memref.alloc(%arg0) : memref<?xf32>
+    %q = memref.alloc(%arg0) : memref<?xf32>
+    scf.yield %q, %b, %arg2: memref<?xf32>, memref<?xf32>, memref<?xf32>
+  }
+// CHECK-DAG: memref.dealloc %[[WHILE]]#0
+// CHECK-DAG: memref.dealloc %[[WHILE]]#1
+// CHECK-DAG: memref.dealloc %[[WHILE]]#2
+// CHECK-NEXT: return
+  return
+}


        


More information about the Mlir-commits mailing list