[Mlir-commits] [mlir] 7c06f63 - [mlir][tensor][bufferize] Fix dealloc placement in scf.forall op
Matthias Springer
llvmlistbot at llvm.org
Sat Apr 15 17:49:46 PDT 2023
Author: Matthias Springer
Date: 2023-04-16T09:34:43+09:00
New Revision: 7c06f63176da05ef45216c13b271a343b72d75d0
URL: https://github.com/llvm/llvm-project/commit/7c06f63176da05ef45216c13b271a343b72d75d0
DIFF: https://github.com/llvm/llvm-project/commit/7c06f63176da05ef45216c13b271a343b72d75d0.diff
LOG: [mlir][tensor][bufferize] Fix dealloc placement in scf.forall op
The terminator of this op is special: it does not just yield a value,
but bufferizes to a memcpy. This requires special treatment to make sure
that deallocs are placed after the memcpy. (By default, deallocs are
placed right before the terminator.)
Differential Revision: https://reviews.llvm.org/D148408
Added:
Modified:
mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 43dedf1c4ce06..cbf28073362f6 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -1058,6 +1058,21 @@ struct ParallelInsertSliceOpInterface
*srcBuffer, subview)))
return failure();
+ // In case the source was allocated in the same block, make sure that the
+ // deallocation op (if any) appears after the memcpy. By default, deallocs
+ // are placed before the terminator, but this does not work for ForallOp
+ // because the terminator does more than just yielding a value.
+ //
+ // Note: This is not a problem for the destination buffer because these are
+ // assumed to always bufferize in-place.
+ for (Operation *user : srcBuffer->getUsers()) {
+ if (hasEffect<MemoryEffects::Free>(user)) {
+ if (user->getBlock() == parallelCombiningParent->getBlock())
+ user->moveBefore(user->getBlock()->getTerminator());
+ break;
+ }
+ }
+
// Delete the op.
rewriter.eraseOp(op);
return success();
diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
index 89c6974aa9cbf..a4c868cf35d5c 100644
--- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
@@ -335,7 +335,7 @@ func.func @insert_slice_full_overwrite(%t: tensor<10xf32>, %b: tensor<10xf32>) -
// CHECK-LABEL: func @dim_not_reading(
// CHECK-SAME: %[[t:.*]]: memref<?xf32
-func.func @dim_not_reading(%t: tensor<?xf32>, %f: f32, %pos: index)
+func.func @dim_not_reading(%t: tensor<?xf32>, %f: f32, %pos: index)
-> (tensor<?xf32>, index)
{
%c0 = arith.constant 0 : index
@@ -370,3 +370,31 @@ func.func @cast_retains_buffer_layout(
// in the caller.
return %casted, %slice : tensor<10xf32>, tensor<?xf32>
}
+
+// -----
+
+// CHECK-LABEL: func.func @parallel_insert_slice_source_out_of_place
+func.func @parallel_insert_slice_source_out_of_place(%in: tensor<1xf32>, %out: tensor<100xf32>, %f: f32) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %num_threads = arith.constant 50 : index
+
+ // CHECK: scf.forall {{.*}} {
+ %result = scf.forall (%thread_idx) in (%num_threads) shared_outs (%o = %out) -> tensor<100xf32> {
+ // The tensor.insert must bufferize out-of-place.
+ // CHECK: memref.alloc
+ // CHECK: memref.store
+ %insert = tensor.insert %f into %in[%c0] : tensor<1xf32>
+ %r = tensor.extract %in[%c0] : tensor<1xf32>
+ vector.print %r : f32
+
+ // CHECK: memref.copy
+ // CHECK: memref.dealloc
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %insert into %o[%thread_idx][1][1] :
+ tensor<1xf32> into tensor<100xf32>
+ }
+ }
+ // CHECK: }
+ return
+}
More information about the Mlir-commits mailing list