[Mlir-commits] [mlir] 7c06f63 - [mlir][tensor][bufferize] Fix dealloc placement in scf.forall op

Matthias Springer llvmlistbot at llvm.org
Sat Apr 15 17:49:46 PDT 2023


Author: Matthias Springer
Date: 2023-04-16T09:34:43+09:00
New Revision: 7c06f63176da05ef45216c13b271a343b72d75d0

URL: https://github.com/llvm/llvm-project/commit/7c06f63176da05ef45216c13b271a343b72d75d0
DIFF: https://github.com/llvm/llvm-project/commit/7c06f63176da05ef45216c13b271a343b72d75d0.diff

LOG: [mlir][tensor][bufferize] Fix dealloc placement in scf.forall op

The terminator of the scf.forall op is special: it does not just yield a
value. The tensor.parallel_insert_slice ops inside it bufferize to a memcpy.
This requires special treatment to make sure that deallocs are placed after
the memcpy. (By default, deallocs are placed right before the terminator.)

Differential Revision: https://reviews.llvm.org/D148408

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/test/Dialect/Tensor/one-shot-bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 43dedf1c4ce06..cbf28073362f6 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -1058,6 +1058,21 @@ struct ParallelInsertSliceOpInterface
                                     *srcBuffer, subview)))
       return failure();
 
+    // In case the source was allocated in the same block, make sure that the
+    // deallocation op (if any) appears after the memcpy. By default, deallocs
+    // are placed before the terminator, but this does not work for ForallOp
+    // because the terminator does more than just yielding a value.
+    //
+    // Note: This is not a problem for the destination buffer because these are
+    // assumed to always bufferize in-place.
+    for (Operation *user : srcBuffer->getUsers()) {
+      if (hasEffect<MemoryEffects::Free>(user)) {
+        if (user->getBlock() == parallelCombiningParent->getBlock())
+          user->moveBefore(user->getBlock()->getTerminator());
+        break;
+      }
+    }
+
     // Delete the op.
     rewriter.eraseOp(op);
     return success();

diff  --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
index 89c6974aa9cbf..a4c868cf35d5c 100644
--- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
@@ -335,7 +335,7 @@ func.func @insert_slice_full_overwrite(%t: tensor<10xf32>, %b: tensor<10xf32>) -
 
 // CHECK-LABEL: func @dim_not_reading(
 //  CHECK-SAME:     %[[t:.*]]: memref<?xf32
-func.func @dim_not_reading(%t: tensor<?xf32>, %f: f32, %pos: index) 
+func.func @dim_not_reading(%t: tensor<?xf32>, %f: f32, %pos: index)
     -> (tensor<?xf32>, index)
 {
   %c0 = arith.constant 0 : index
@@ -370,3 +370,31 @@ func.func @cast_retains_buffer_layout(
   // in the caller.
   return %casted, %slice : tensor<10xf32>, tensor<?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @parallel_insert_slice_source_out_of_place
+func.func @parallel_insert_slice_source_out_of_place(%in: tensor<1xf32>, %out: tensor<100xf32>, %f: f32) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %num_threads = arith.constant 50 : index
+
+  // CHECK: scf.forall {{.*}} {
+  %result = scf.forall (%thread_idx) in (%num_threads) shared_outs (%o = %out) -> tensor<100xf32> {
+      // The tensor.insert must bufferize out-of-place.
+      // CHECK: memref.alloc
+      // CHECK: memref.store
+      %insert = tensor.insert %f into %in[%c0] : tensor<1xf32>
+      %r = tensor.extract %in[%c0] : tensor<1xf32>
+      vector.print %r : f32
+
+      // CHECK: memref.copy
+      // CHECK: memref.dealloc
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %insert into %o[%thread_idx][1][1] :
+          tensor<1xf32> into tensor<100xf32>
+      }
+  }
+  // CHECK: }
+  return
+}


        


More information about the Mlir-commits mailing list