[Mlir-commits] [mlir] 4cd7362 - [mlir][SCF] foreach_thread: Capture shared output tensors explicitly

Matthias Springer llvmlistbot at llvm.org
Fri Sep 2 05:54:21 PDT 2022


Author: Matthias Springer
Date: 2022-09-02T14:54:04+02:00
New Revision: 4cd7362083c8801bbc84d2c43b086d1f8f0de93f

URL: https://github.com/llvm/llvm-project/commit/4cd7362083c8801bbc84d2c43b086d1f8f0de93f
DIFF: https://github.com/llvm/llvm-project/commit/4cd7362083c8801bbc84d2c43b086d1f8f0de93f.diff

LOG: [mlir][SCF] foreach_thread: Capture shared output tensors explicitly

This change refines the semantics of scf.foreach_thread. Tensors that the terminator inserts into must now be passed to the region explicitly via `shared_outs`. Inside the body of the op, those tensors are accessed via their corresponding block arguments.
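The operand, result, and block-argument bookkeeping implied by `shared_outs` can be sketched with a small C++ model that uses only the standard library. It is hypothetical: `ForeachThreadLayout` and its methods are not MLIR API. It only mirrors the index arithmetic of `getTiedOpResult`, `getTiedBlockArgument`, and `getTiedOpOperand` added to SCFOps.td in this commit, assuming (as the new `arguments` list states) that the `num_threads` operands come first, followed by the `shared_outs` operands.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical model (not the MLIR API): index bookkeeping for an
// scf.foreach_thread with `rank` thread-count operands followed by
// `numOutputs` shared_outs operands. The body block mirrors this layout:
// `rank` index block arguments, then one block argument per shared_out.
struct ForeachThreadLayout {
  int64_t rank;        // number of num_threads operands / thread-index bbArgs
  int64_t numOutputs;  // number of shared_outs operands (== number of results)

  int64_t numOperands() const { return rank + numOutputs; }
  int64_t numBlockArgs() const { return rank + numOutputs; }
  int64_t numResults() const { return numOutputs; }

  bool isSharedOutOperand(int64_t operandNumber) const {
    return operandNumber >= rank && operandNumber < numOperands();
  }

  // Mirrors getTiedOpResult: shared_out operand (i + rank) <-> result i.
  int64_t tiedResult(int64_t operandNumber) const {
    assert(isSharedOutOperand(operandNumber) && "invalid operand");
    return operandNumber - rank;
  }

  // Mirrors getTiedBlockArgument: the operand and its bbArg share an index
  // because both lists start with `rank` thread-related entries.
  int64_t tiedBlockArg(int64_t operandNumber) const {
    assert(isSharedOutOperand(operandNumber) && "invalid operand");
    return operandNumber;
  }

  // Mirrors getTiedOpOperand, the inverse of tiedBlockArg.
  int64_t tiedOperand(int64_t bbArgNumber) const {
    assert(bbArgNumber >= rank && bbArgNumber < numBlockArgs() &&
           "invalid bbArg");
    return bbArgNumber;
  }
};
```

For example, with two thread counts and two shared outputs, operand 3 (the second shared output) maps to result 1 and to block argument 3.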

The body of an scf.foreach_thread is now treated as a repetitive region, i.e., op dominance can no longer be used in conflict detection for uses of a value that is defined outside of the body. Such uses may now be considered conflicts (if there is at least one read and one write in the body), which effectively privatizes the tensor. Shared outputs are not privatized when they are used via their corresponding block arguments.
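A deliberately simplified sketch of the privatization rule stated above, in standard-library C++. The enum `Access` and the function `wouldPrivatize` are hypothetical; this is not how One-Shot Bufferize implements its analysis, and it ignores every other source of conflicts. A single use that both reads and writes is counted as both a read and a write.

```cpp
#include <cassert>
#include <vector>

// Hypothetical, simplified model of the rule; NOT the One-Shot Bufferize
// analysis. It only looks at how a tensor is accessed inside the body.
enum class Access { Read, Write, ReadWrite };

// `viaSharedOutBbArg` is true when the tensor is accessed through its
// shared_outs block argument; `usesInBody` lists how each use in the body
// accesses memory after bufferization.
bool wouldPrivatize(bool viaSharedOutBbArg,
                    const std::vector<Access>& usesInBody) {
  // Shared outputs used via their block argument are never privatized.
  if (viaSharedOutBbArg)
    return false;
  bool sawRead = false;
  bool sawWrite = false;
  for (Access a : usesInBody) {
    if (a != Access::Write)
      sawRead = true;
    if (a != Access::Read)
      sawWrite = true;
  }
  // The body is a repetitive region: a write by one thread may conflict with
  // a read by another, so a read plus a write forces a thread-local copy.
  return sawRead && sawWrite;
}
```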

As part of this change, the "tiling to scf.foreach_thread" transformation also had to be updated so that the generated tensor.extract_slice ops use the scf.foreach_thread's block arguments. This is implemented by cloning the TilingInterface op inside the scf.foreach_thread, replacing all of its outputs with the corresponding block arguments, and then calling the tiling implementation on the clone. Afterwards, the clone is erased.
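The operand-rewriting step can be modeled in isolation. In the sketch below (hypothetical names, standard library only), SSA values are plain strings; the loop mirrors the one added to `tileToForeachThreadOpImpl` in Tiling.cpp, which looks up each output operand of the clone in `dest` and replaces it with the block argument at the same position.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <string>
#include <vector>

// Hypothetical model: SSA values are strings. For each output operand of the
// cloned op, find its position in the destination tensors (`dest`) and
// replace it with the shared_outs block argument at that position.
std::vector<std::string> remapOutputsToBbArgs(
    std::vector<std::string> outputs, const std::vector<std::string>& dest,
    const std::vector<std::string>& destBbArgs) {
  assert(dest.size() == destBbArgs.size() && "expected one bbArg per dest");
  for (std::string& out : outputs) {
    auto it = std::find(dest.begin(), dest.end(), out);
    assert(it != dest.end() && "dest operand not found in dest");
    auto destNum = static_cast<std::size_t>(std::distance(dest.begin(), it));
    out = destBbArgs[destNum];
  }
  return outputs;
}
```

For a matmul whose output `%C` is passed as `shared_outs(%o1 = %C)`, the clone's output `%C` becomes `%o1`, so the slices created while tiling the clone are taken from the block argument rather than from the captured value.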

Differential Revision: https://reviews.llvm.org/D133114

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
    mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td
    mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
    mlir/lib/Dialect/SCF/IR/SCF.cpp
    mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
    mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
    mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
    mlir/test/Dialect/SCF/invalid.mlir
    mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir
    mlir/test/Dialect/SCF/one-shot-bufferize.mlir
    mlir/test/Dialect/SCF/ops.mlir
    mlir/test/Dialect/Tensor/canonicalize.mlir
    mlir/test/Dialect/Tensor/one-shot-bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 655380e4ec786..1cd1cb59e0965 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -324,6 +324,7 @@ def ForOp : SCF_Op<"for",
 //===----------------------------------------------------------------------===//
 
 def ForeachThreadOp : SCF_Op<"foreach_thread", [
+       AttrSizedOperandSegments,
        SingleBlockImplicitTerminator<"scf::PerformConcurrentlyOp">,
        RecursiveSideEffects,
        AutomaticAllocationScope,
@@ -335,6 +336,17 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
     parallel body and it takes index operands that indicate how many parallel
     instances of that function are created.
 
+    The op also takes a variadic number of tensor operands (`shared_outs`).
+    The future buffers corresponding to these tensors are shared among all
+    threads. Shared tensors should be accessed via their corresponding block
+    arguments. If multiple threads write to a shared buffer in a racy
+    fashion, these writes will execute in some unspecified order. Tensors that
+    are not shared can be used inside the body (i.e., the op is not isolated
+    from above); however, if a use of such a tensor bufferizes to a memory
+    write, the tensor is privatized, i.e., a thread-local copy of the tensor is
+    used. This ensures that memory side effects of a thread are not visible to
+    other threads (or in the parent body), apart from explicitly shared tensors.
+
     The name "thread" conveys the fact that the parallel execution is mapped
     (i.e. distributed) to a set of virtual threads of execution, one function
     application per thread. Further lowerings are responsible for specifying
@@ -349,26 +361,20 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
     context of the concrete target the op is lowered to, or to ignore it when
     the specification is ill-formed or unsupported for a particular target.
 
-    The only allowed terminator is `scf.foreach_thread.perform_concurrently`,
-    which dictates how the partial results of all parallel invocations should be
-    reconciled into a full value.
+    The only allowed terminator is `scf.foreach_thread.perform_concurrently`.
+    `scf.foreach_thread` returns one value per `shared_out` operand. The
+    actions of the `perform_concurrently` terminators specify how to combine the
+    partial results of all parallel invocations into a full value, in some
+    unspecified order. The "destination" of each such op must be a `shared_out`
+    block argument of the `scf.foreach_thread` op.
 
-    `scf.foreach_thread` returns values that are formed by aggregating the
-    actions of all the `perform_concurrently` terminator of all the virtual
-    threads, in some unspecified order.
-    In other words, `scf.foreach_thread` performs all actions specified in the
-    `perform_concurrently` terminator, after it receives the control back from
-    its body along each virtual thread of execution.
     The actions involved in constructing the return values are further described
-    by [parallel_insert_slice](#parallelinsertslice-parallelinsertsliceop).
+    by `tensor.parallel_insert_slice`.
 
     `scf.foreach_thread` acts as an implicit synchronization point.
 
-    Multi-value returns are encoded by including multiple operations inside the
-    `perform_concurrently` block.
-
-    When the parallel function body has side effects, the order of reads and
-    writes to memory is unspecified across threads.
+    When the parallel function body has side effects, their order is unspecified
+    across threads.
 
     Example:
 
@@ -377,7 +383,8 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
     // Sequential context.
     //
     %matmul_and_pointwise:2 = scf.foreach_thread (%thread_id_1, %thread_id_2) in
-         (%num_threads_1, %numthread_id_2) -> (tensor<?x?xT>, tensor<?xT>) {
+        (%num_threads_1, %numthread_id_2) shared_outs(%o1 = %C, %o2 = %pointwise)
+      -> (tensor<?x?xT>, tensor<?xT>) {
       //
       // Parallel context, each thread with id = (%thread_id_1, %thread_id_2)
       // runs its version of the code.
@@ -386,21 +393,19 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
         tensor<?x?xT> to tensor<?x?xT>
       %sB = tensor.extract_slice %B[g((%thread_id_1, %thread_id_2))]:
         tensor<?x?xT> to tensor<?x?xT>
-      %sC = tensor.extract_slice %C[h((%thread_id_1, %thread_id_2))]:
+      %sC = tensor.extract_slice %o1[h((%thread_id_1, %thread_id_2))]:
         tensor<?x?xT> to tensor<?x?xT>
       %sD = matmul ins(%sA, %sB) outs(%sC)
 
-      %spointwise = subtensor %pointwise[i((%thread_id_1, %thread_id_2))]:
+      %spointwise = subtensor %o2[i((%thread_id_1, %thread_id_2))]:
         tensor<?xT> to tensor<?xT>
       %sE = add ins(%spointwise) outs(%sD)
 
       scf.foreach_thread.perform_concurrently {
-        // First op within the parallel terminator contributes to producing %matmul_and_pointwise#0.
-        scf.foreach_thread.parallel_insert_slice %sD into %C[h((%thread_id_1, %thread_id_2))]:
+        scf.foreach_thread.parallel_insert_slice %sD into %o1[h((%thread_id_1, %thread_id_2))]:
           tensor<?x?xT> into tensor<?x?xT>
 
-        // Second op within the parallel terminator contributes to producing %matmul_and_pointwise#1.
-        scf.foreach_thread.parallel_insert_slice %spointwise into %pointwise[i((%thread_id_1, %thread_id_2))]:
+        scf.foreach_thread.parallel_insert_slice %spointwise into %o2[i((%thread_id_1, %thread_id_2))]:
           tensor<?xT> into tensor<?xT>
       }
     }
@@ -414,7 +419,8 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
     // Sequential context.
     //
     %matmul_and_pointwise:2 = scf.foreach_thread (%thread_id_1, %thread_id_2) in
-         (%num_threads_1, %numthread_id_2) -> (tensor<?x?xT>, tensor<?xT>) {
+        (%num_threads_1, %numthread_id_2) shared_outs(...)
+      -> (tensor<?x?xT>, tensor<?xT>) {
       //
       // Parallel context, each thread with id = **(%thread_id_2, %thread_id_1)**
       // runs its version of the code.
@@ -426,9 +432,23 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
     // Implicit synchronization point.
     // Sequential context.
     //
+
+    Example with privatized tensors:
+    %t0 = ...
+    %t1 = ...
+    %r = scf.foreach_thread ... shared_outs(%o = t0) -> tensor<?xf32> {
+      // %t0 and %t1 are privatized. %t0 is definitely copied for each thread
+      // because the scf.foreach_thread op's %t0 use bufferizes to a memory
+      // write. In the absence of other conflicts, %t1 is copied only if there
+      // are uses of %t1 in the body that bufferize to a memory read and to a
+      // memory write.
+      "some_use"(%t0)
+      "some_use"(%t1)
+    }
   }];
   let arguments = (ins Variadic<Index>:$num_threads,
-                   DefaultValuedAttr<I64ArrayAttr, "{}">:$thread_dim_mapping);
+                       Variadic<AnyRankedTensor>:$outputs,
+                       DefaultValuedAttr<I64ArrayAttr, "{}">:$thread_dim_mapping);
 
   let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$region);
@@ -439,19 +459,48 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
   // The default builder does not add the proper body BBargs, roll our own.
   let skipDefaultBuilders = 1;
   let builders = [
-    // Bodyless builder, result types must be specified.
-    OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$num_threads,
+    // Bodyless builder, outputs must be specified.
+    OpBuilder<(ins "ValueRange":$outputs, "ValueRange":$num_threads,
                    CArg<"ArrayRef<int64_t>", "{}">:$thread_dim_mapping)>,
-    // Builder that takes a bodyBuilder lambda, result types are inferred from
-    // the terminator.
-    OpBuilder<(ins "ValueRange":$num_threads,
+    // Builder that takes a bodyBuilder lambda.
+    OpBuilder<(ins "ValueRange":$outputs, "ValueRange":$num_threads,
                    "ArrayRef<int64_t>":$thread_dim_mapping,
                    "function_ref<void(OpBuilder &, Location, ValueRange)>":$bodyBuilder)>
   ];
   let extraClassDeclaration = [{
     int64_t getRank() { return getNumThreads().size(); }
-    ::mlir::ValueRange getThreadIndices() { return getBody()->getArguments(); }
-    ::mlir::Value getThreadIndex(int64_t idx) { return getBody()->getArgument(idx); }
+
+    OpResult getTiedOpResult(OpOperand *opOperand) {
+      assert(opOperand->getOperandNumber() >= getRank() && "invalid operand");
+      return getOperation()->getOpResult(
+          opOperand->getOperandNumber() - getRank());
+    }
+
+    OpOperand *getTiedOpOperand(BlockArgument bbArg) {
+      assert(bbArg.getArgNumber() >= getRank() && "invalid bbArg");
+      return &getOperation()->getOpOperand(bbArg.getArgNumber());
+    }
+
+    BlockArgument getTiedBlockArgument(OpOperand *opOperand) {
+      assert(opOperand->getOperandNumber() >= getRank() && "invalid operand");
+      return getBody()->getArgument(opOperand->getOperandNumber());
+    }
+
+    ArrayRef<BlockArgument> getOutputBlockArguments() {
+      return getBody()->getArguments().drop_front(getRank());
+    }
+
+    ::mlir::ValueRange getThreadIndices() {
+      return getBody()->getArguments().take_front(getRank());
+    }
+
+    ::mlir::Value getThreadIndex(int64_t idx) {
+      return getThreadIndices()[idx];
+    }
+
+    ::mlir::Block::BlockArgListType getRegionOutArgs() {
+      return getBody()->getArguments().drop_front(getRank());
+    }
 
     // The ensureTerminator method generated by SingleBlockImplicitTerminator is
     // unaware of the fact that our terminator also needs a region to be
@@ -497,7 +546,7 @@ def PerformConcurrentlyOp : SCF_Op<"foreach_thread.perform_concurrently", [
   // TODO: Add a `PerformConcurrentlyOpInterface` interface for ops that can
   // appear inside perform_concurrently.
   let extraClassDeclaration = [{
-    ::llvm::SmallVector<::mlir::Type> getYieldedTypes();
+    ::llvm::SmallVector<::mlir::BlockArgument> getDests();
     ::llvm::iterator_range<::mlir::Block::iterator> getYieldingOps();
     ::mlir::OpResult getParentResult(int64_t idx);
   }];

diff  --git a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td
index 45497fbe038db..29870c58adeff 100644
--- a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td
+++ b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td
@@ -17,11 +17,7 @@ include "mlir/IR/OpBase.td"
 
 def ParallelCombiningOpInterface : OpInterface<"ParallelCombiningOpInterface"> {
   let description = [{
-    A parallel combining op is an op with a region, that is not isolated from
-    above and yields values to its parent op without itself returning an SSA
-    value. The yielded values are determined by subvalues produced by the ops 
-    contained in the region (the `yieldingOps`) and combined in any unspecified
-    order to produce the values yielded to the parent op.
+    A parallel combining op is an op with a region.
 
     This is useful as a terminator to parallel operations that iterate over 
     some set and return tensors while avoiding tight coupling between the 
@@ -53,18 +49,6 @@ def ParallelCombiningOpInterface : OpInterface<"ParallelCombiningOpInterface"> {
         return $_op.getYieldingOps();
       }]
     >,
-    InterfaceMethod<
-      /*desc=*/[{
-        Return the contained ops that yield subvalues that this op combines to
-        yield to its parent.
-      }],
-      /*retTy=*/"::llvm::SmallVector<::mlir::Type>",
-      /*methodName=*/"getYieldedTypes",
-      /*args=*/(ins),
-      /*methodBody=*/[{
-        return $_op.getYieldedTypes();
-      }]
-    >,
   ];
   // TODO: Single region single block interface on interfaces ?
   let verify = [{

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index c6a252c674261..3d7f3212a68a3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -235,8 +235,8 @@ static FailureOr<ForeachThreadTilingResult> tileToForeachThreadOpImpl(
   if (llvm::any_of(loopRanges, hasStrideOne))
     return op->emitOpError("only stride-1 supported atm");
   // TODO: support `getTiledImplementation` with >1 produced tiled ops.
-  auto destOperands = op.getDestinationOperands(b);
-  if (destOperands.size() != 1)
+  auto dest = op.getDestinationOperands(b);
+  if (dest.size() != 1)
     return op->emitOpError("only single dest operand supported atm");
 
   SmallVector<OpFoldResult> nonZeroNumThreads =
@@ -255,8 +255,7 @@ static FailureOr<ForeachThreadTilingResult> tileToForeachThreadOpImpl(
   // version because we require the use of RewriterBase in the body, so we
   // manually move the insertion point to the body below.
   scf::ForeachThreadOp foreachThreadOp = b.create<scf::ForeachThreadOp>(
-      loc, op->getResultTypes(), ValueRange(materializedNonZeroNumThreads),
-      threadDimMapping);
+      loc, dest, ValueRange(materializedNonZeroNumThreads), threadDimMapping);
 
   // Fill out the ForeachThreadOp body.
   b.setInsertionPointToStart(foreachThreadOp.getBody(0));
@@ -317,17 +316,34 @@ static FailureOr<ForeachThreadTilingResult> tileToForeachThreadOpImpl(
     ++threadIdIdx;
   }
 
+  // Clone the tileable op and update its destination operands to use the output
+  // bbArgs of the ForeachThreadOp.
+  ArrayRef<BlockArgument> destBbArgs =
+      foreachThreadOp.getOutputBlockArguments();
+  Operation *clonedOp = b.clone(*op.getOperation());
+  auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp);
+  if (destinationStyleOp) {
+    for (OpOperand *outOperand : destinationStyleOp.getOutputOperands()) {
+      auto it = llvm::find(dest, outOperand->get());
+      assert(it != dest.end() && "dest operand not found in dest");
+      unsigned destNum = std::distance(dest.begin(), it);
+      outOperand->set(destBbArgs[destNum]);
+    }
+  }
+
+  // Tile the cloned op and delete the clone.
   SmallVector<Operation *> tiledOps =
-      op.getTiledImplementation(b, tiledOffsets, tiledSizes);
+      cast<TilingInterface>(clonedOp).getTiledImplementation(b, tiledOffsets,
+                                                             tiledSizes);
+  b.eraseOp(clonedOp);
   assert(tiledOps.size() == 1 && "expected a single produced tiled op");
   tiledOp = tiledOps.front();
 
   auto tilingInterfaceOp = dyn_cast<TilingInterface>(tiledOp);
   assert(tilingInterfaceOp && "Tiled op does not implement TilingInterface");
   OpBuilder::InsertPoint insertPt = b.saveInsertionPoint();
-  for (auto it :
-       llvm::zip(llvm::seq(unsigned(0), unsigned(destOperands.size())),
-                 tilingInterfaceOp->getResults(), destOperands)) {
+  for (auto it : llvm::zip(llvm::seq(unsigned(0), unsigned(dest.size())),
+                           tilingInterfaceOp->getResults(), destBbArgs)) {
     b.setInsertionPoint(insertPt.getBlock(), insertPt.getPoint());
     SmallVector<OpFoldResult> resultOffsets, resultSizes;
     if (failed(op.getResultTilePosition(b, std::get<0>(it), tiledOffsets,

diff  --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 2de15969288d2..548bee85e24c2 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1055,26 +1055,25 @@ LogicalResult ForeachThreadOp::verify() {
   if (failed(getTerminator().verify()))
     return failure();
 
-  // Check that the body defines as single block argument for the thread index.
+  // Check number of outputs.
+  if (getNumResults() != getOutputs().size())
+    return emitOpError("produces ")
+           << getNumResults() << " results, but has only "
+           << getOutputs().size() << " outputs";
+
+  // Check that the body defines block arguments for thread indices and outputs.
   auto *body = getBody();
-  if (body->getNumArguments() != getRank())
+  if (body->getNumArguments() != getRank() + getOutputs().size())
     return emitOpError("region expects ") << getRank() << " arguments";
+  for (int64_t i = 0; i < getRank(); ++i)
+    if (!body->getArgument(i).getType().isIndex())
+      return emitOpError("expects ")
+             << i << "-th block argument to be an index";
+  for (unsigned i = 0; i < getOutputs().size(); ++i)
+    if (body->getArgument(i + getRank()).getType() != getOutputs()[i].getType())
+      return emitOpError("type mismatch between ")
+             << i << "-th output and corresponding block argument";
 
-  // Verify consistency between the result types and the terminator.
-  auto terminatorTypes = getTerminator().getYieldedTypes();
-  auto opResults = getResults();
-  if (opResults.size() != terminatorTypes.size())
-    return emitOpError("produces ")
-           << opResults.size() << " results, but its terminator yields "
-           << terminatorTypes.size() << " value(s)";
-  unsigned i = 0;
-  for (auto e : llvm::zip(terminatorTypes, opResults)) {
-    if (std::get<0>(e) != std::get<1>(e).getType())
-      return emitOpError() << "type mismatch between result " << i << " ("
-                           << std::get<1>(e).getType() << ") and terminator ("
-                           << std::get<0>(e) << ")";
-    i++;
-  }
   return success();
 }
 
@@ -1083,11 +1082,16 @@ void ForeachThreadOp::print(OpAsmPrinter &p) {
   llvm::interleaveComma(getThreadIndices(), p);
   p << ") in (";
   llvm::interleaveComma(getNumThreads(), p);
-  p << ") -> (" << getResultTypes() << ") ";
+  p << ")";
+  printInitializationList(p, getRegionOutArgs(), getOutputs(), " shared_outs");
+  p << " ";
+  if (!getRegionOutArgs().empty())
+    p << "-> (" << getResultTypes() << ") ";
   p.printRegion(getRegion(),
                 /*printEntryBlockArgs=*/false,
                 /*printBlockTerminators=*/getNumResults() > 0);
-  p.printOptionalAttrDict(getOperation()->getAttrs());
+  p.printOptionalAttrDict(getOperation()->getAttrs(),
+                          {"operand_segment_sizes"});
 }
 
 ParseResult ForeachThreadOp::parse(OpAsmParser &parser,
@@ -1109,15 +1113,34 @@ ParseResult ForeachThreadOp::parse(OpAsmParser &parser,
                              result.operands))
     return failure();
 
-  // Parse optional results.
-  if (parser.parseOptionalArrowTypeList(result.types))
-    return failure();
+  // Parse out operands and results.
+  SmallVector<OpAsmParser::Argument, 4> regionOutArgs;
+  SmallVector<OpAsmParser::UnresolvedOperand, 4> outOperands;
+  SMLoc outOperandsLoc = parser.getCurrentLocation();
+  if (succeeded(parser.parseOptionalKeyword("shared_outs"))) {
+    if (outOperands.size() != result.types.size())
+      return parser.emitError(outOperandsLoc,
+                              "mismatch between out operands and types");
+    if (parser.parseAssignmentList(regionOutArgs, outOperands) ||
+        parser.parseOptionalArrowTypeList(result.types) ||
+        parser.resolveOperands(outOperands, result.types, outOperandsLoc,
+                               result.operands))
+      return failure();
+  }
 
   // Parse region.
+  SmallVector<OpAsmParser::Argument, 4> regionArgs;
   std::unique_ptr<Region> region = std::make_unique<Region>();
-  for (auto &idx : threadIndices)
+  for (auto &idx : threadIndices) {
     idx.type = builder.getIndexType();
-  if (parser.parseRegion(*region, threadIndices))
+    regionArgs.push_back(idx);
+  }
+  for (const auto &it : llvm::enumerate(regionOutArgs)) {
+    auto &out = it.value();
+    out.type = result.types[it.index()];
+    regionArgs.push_back(out);
+  }
+  if (parser.parseRegion(*region, regionArgs))
     return failure();
 
   // Ensure terminator and move region.
@@ -1128,19 +1151,27 @@ ParseResult ForeachThreadOp::parse(OpAsmParser &parser,
   // Parse the optional attribute list.
   if (parser.parseOptionalAttrDict(result.attributes))
     return failure();
-
+  result.addAttribute("operand_segment_sizes",
+                      parser.getBuilder().getDenseI32ArrayAttr(
+                          {static_cast<int32_t>(threadNums.size()),
+                           static_cast<int32_t>(outOperands.size())}));
   return success();
 }
 
-// Bodyless builder, result types must be specified.
+// Bodyless builder, outputs must be specified.
 void ForeachThreadOp::build(mlir::OpBuilder &builder,
-                            mlir::OperationState &result, TypeRange resultTypes,
+                            mlir::OperationState &result, ValueRange outputs,
                             ValueRange numThreads,
                             ArrayRef<int64_t> threadDimMapping) {
   result.addOperands(numThreads);
+  result.addOperands(outputs);
+  result.addAttribute(ForeachThreadOp::getThreadDimMappingAttrName(result.name),
+                      builder.getI64ArrayAttr(threadDimMapping));
   result.addAttribute(
-      // TODO: getThreadDimMappingAttrName() but it is not a static member.
-      "thread_dim_mapping", builder.getI64ArrayAttr(threadDimMapping));
+      "operand_segment_sizes",
+      builder.getDenseI32ArrayAttr({static_cast<int32_t>(numThreads.size()),
+                                    static_cast<int32_t>(outputs.size())}));
+  result.addTypes(TypeRange(outputs));
 
   Region *bodyRegion = result.addRegion();
   OpBuilder::InsertionGuard g(builder);
@@ -1149,40 +1180,51 @@ void ForeachThreadOp::build(mlir::OpBuilder &builder,
   // expects it ..
   builder.createBlock(bodyRegion);
   Block &bodyBlock = bodyRegion->front();
+  // Add block arguments for indices and outputs.
   bodyBlock.addArguments(
       SmallVector<Type>(numThreads.size(), builder.getIndexType()),
       SmallVector<Location>(numThreads.size(), result.location));
+  bodyBlock.addArguments(
+      TypeRange(outputs),
+      SmallVector<Location>(outputs.size(), result.location));
   ForeachThreadOp::ensureTerminator(*bodyRegion, builder, result.location);
-  result.addTypes(resultTypes);
 }
 
-// Builder that takes a bodyBuilder lambda, result types are inferred from
-// the terminator.
+// Builder that takes a bodyBuilder lambda.
 void ForeachThreadOp::build(
-    mlir::OpBuilder &builder, mlir::OperationState &result,
+    mlir::OpBuilder &builder, mlir::OperationState &result, ValueRange outputs,
     ValueRange numThreads, ArrayRef<int64_t> threadDimMapping,
     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
   result.addOperands(numThreads);
+  result.addOperands(outputs);
+  result.addAttribute(ForeachThreadOp::getThreadDimMappingAttrName(result.name),
+                      builder.getI64ArrayAttr(threadDimMapping));
   result.addAttribute(
-      // TODO: getThreadDimMappingAttrName() but it is not a static member.
-      "thread_dim_mapping", builder.getI64ArrayAttr(threadDimMapping));
+      "operand_segment_sizes",
+      builder.getDenseI32ArrayAttr({static_cast<int32_t>(numThreads.size()),
+                                    static_cast<int32_t>(outputs.size())}));
+  result.addTypes(TypeRange(outputs));
 
-  OpBuilder::InsertionGuard g(builder);
   Region *bodyRegion = result.addRegion();
+  OpBuilder::InsertionGuard g(builder);
   builder.createBlock(bodyRegion);
   Block &bodyBlock = bodyRegion->front();
+  // Add block arguments for indices and outputs.
   bodyBlock.addArguments(
       SmallVector<Type>(numThreads.size(), builder.getIndexType()),
       SmallVector<Location>(numThreads.size(), result.location));
+  bodyBlock.addArguments(
+      TypeRange(outputs),
+      SmallVector<Location>(outputs.size(), result.location));
 
-  OpBuilder::InsertionGuard guard(builder);
   builder.setInsertionPointToStart(&bodyBlock);
   bodyBuilder(builder, result.location, bodyBlock.getArguments());
+#ifndef NDEBUG
   auto terminator =
       llvm::dyn_cast<PerformConcurrentlyOp>(bodyBlock.getTerminator());
   assert(terminator &&
          "expected bodyBuilder to create PerformConcurrentlyOp terminator");
-  result.addTypes(terminator.getYieldedTypes());
+#endif // NDEBUG
 }
 
 // The ensureTerminator method generated by SingleBlockImplicitTerminator is
@@ -1223,12 +1265,23 @@ void PerformConcurrentlyOp::build(OpBuilder &b, OperationState &result) {
 }
 
 LogicalResult PerformConcurrentlyOp::verify() {
+  scf::ForeachThreadOp foreachThreadOp =
+      dyn_cast<scf::ForeachThreadOp>(getOperation()->getParentOp());
+  if (!foreachThreadOp)
+    return this->emitOpError("expected foreach_thread op parent");
+
   // TODO: PerformConcurrentlyOpInterface.
-  for (const Operation &op : getRegion().front().getOperations()) {
+  for (Operation &op : getRegion().front().getOperations()) {
     if (!isa<tensor::ParallelInsertSliceOp>(op)) {
       return this->emitOpError("expected only ")
              << tensor::ParallelInsertSliceOp::getOperationName() << " ops";
     }
+
+    // Verify that inserts are into out block arguments.
+    Value dest = cast<tensor::ParallelInsertSliceOp>(op).getDest();
+    ArrayRef<BlockArgument> regionOutArgs = foreachThreadOp.getRegionOutArgs();
+    if (llvm::find(regionOutArgs, dest) == regionOutArgs.end())
+      return op.emitOpError("may only insert into an output block argument");
   }
   return success();
 }
@@ -1264,11 +1317,12 @@ OpResult PerformConcurrentlyOp::getParentResult(int64_t idx) {
   return getOperation()->getParentOp()->getResult(idx);
 }
 
-SmallVector<Type> PerformConcurrentlyOp::getYieldedTypes() {
+SmallVector<BlockArgument> PerformConcurrentlyOp::getDests() {
   return llvm::to_vector<4>(
       llvm::map_range(getYieldingOps(), [](Operation &op) {
-        auto insertSliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(&op);
-        return insertSliceOp ? insertSliceOp.yieldedType() : Type();
+        // Add new ops here as needed.
+        auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(&op);
+        return insertSliceOp.getDest().cast<BlockArgument>();
       }));
 }
 

diff  --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index a13badaba7cd9..fd0ff88657900 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -1054,18 +1054,6 @@ struct YieldOpInterface
   }
 };
 
-/// Return the destinations that an ForeachThreadOp is inserting into. One per
-/// ParallelInsertSliceOp.
-static SmallVector<OpOperand *>
-getInsertionDest(ForeachThreadOp foreachThreadOp) {
-  PerformConcurrentlyOp terminator = foreachThreadOp.getTerminator();
-  SmallVector<OpOperand *> result;
-  terminator.walk([&](tensor::ParallelInsertSliceOp insertOp) {
-    result.push_back(&insertOp->getOpOperand(1) /*dest*/);
-  });
-  return result;
-}
-
 /// Bufferization of ForeachThreadOp. This also bufferizes the terminator of the
 /// region. There are op interfaces for the terminators (PerformConcurrentlyOp
 /// and ParallelInsertSliceOp), but these are only used during analysis. Not
@@ -1073,57 +1061,114 @@ getInsertionDest(ForeachThreadOp foreachThreadOp) {
 struct ForeachThreadOpInterface
     : public BufferizableOpInterface::ExternalModel<ForeachThreadOpInterface,
                                                     ForeachThreadOp> {
-  SmallVector<OpOperand *>
-  getAliasingOpOperand(Operation *op, OpResult opResult,
-                       const AnalysisState &state) const {
-    // Get OpOperand (dest) from corresponding ParallelInsertSliceOp.
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    // scf::ForeachThreadOp alone doesn't bufferize to a memory read, one of the
+    // uses of its matching bbArg may.
     auto foreachThreadOp = cast<ForeachThreadOp>(op);
-    return {getInsertionDest(foreachThreadOp)[opResult.getResultNumber()]};
+    return state.isValueRead(foreachThreadOp.getTiedBlockArgument(&opOperand));
   }
 
-  bool isMemoryWrite(Operation *op, OpResult opResult,
-                     const AnalysisState &state) const {
-    // This op is a memory write. Stop lookup here to avoid finding false
-    // conflicts involving this op and one of the ops in the region. This is
-    // similar to how scf.if ops are analyzed.
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const AnalysisState &state) const {
+    // Outputs of scf::ForeachThreadOps are always considered as a write.
     return true;
   }
 
+  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                                            const AnalysisState &state) const {
+    auto foreachThreadOp = cast<ForeachThreadOp>(op);
+    return {foreachThreadOp.getTiedOpResult(&opOperand)};
+  }
+
   BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                 const AnalysisState &state) const {
     return BufferRelation::Equivalent;
   }
 
+  bool isWritable(Operation *op, Value value,
+                  const AnalysisState &state) const {
+    return true;
+  }
+
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
+    OpBuilder::InsertionGuard guard(rewriter);
     auto foreachThreadOp = cast<ForeachThreadOp>(op);
+    int64_t rank = foreachThreadOp.getRank();
 
-#ifndef NDEBUG
-    // ParallelInsertSliceOpInterface replaces all uses.
-    for (OpResult opResult : foreachThreadOp->getOpResults())
-      assert(opResult.getUses().empty() &&
-             "expected that all uses were already replaced");
-#endif // NDEBUG
+    // Get buffers for all output operands.
+    SmallVector<Value> buffers;
+    for (Value out : foreachThreadOp.getOutputs()) {
+      FailureOr<Value> buffer = getBuffer(rewriter, out, options);
+      if (failed(buffer))
+        return failure();
+      buffers.push_back(*buffer);
+    }
+
+    // Use buffers instead of block arguments.
+    rewriter.setInsertionPointToStart(foreachThreadOp.getBody());
+    for (const auto &it :
+         llvm::zip(foreachThreadOp.getBody()->getArguments().drop_front(rank),
+                   buffers)) {
+      BlockArgument bbArg = std::get<0>(it);
+      Value buffer = std::get<1>(it);
+      Value bufferAsTensor =
+          rewriter.create<ToTensorOp>(foreachThreadOp.getLoc(), buffer);
+      bbArg.replaceAllUsesWith(bufferAsTensor);
+    }
 
     // Create new ForeachThreadOp without any results and drop the automatically
     // introduced terminator.
-    TypeRange newResultTypes;
+    rewriter.setInsertionPoint(foreachThreadOp);
     auto newForeachThreadOp = rewriter.create<ForeachThreadOp>(
-        foreachThreadOp.getLoc(), newResultTypes,
+        foreachThreadOp.getLoc(), /*outputs=*/ValueRange(),
         foreachThreadOp.getNumThreads(),
         extractFromI64ArrayAttr(foreachThreadOp.getThreadDimMapping()));
     newForeachThreadOp.getBody()->getTerminator()->erase();
 
     // Move over block contents of the old op.
+    SmallVector<Value> replacementBbArgs;
+    replacementBbArgs.append(
+        newForeachThreadOp.getBody()->getArguments().begin(),
+        newForeachThreadOp.getBody()->getArguments().end());
+    replacementBbArgs.append(foreachThreadOp.getOutputs().size(), Value());
     rewriter.mergeBlocks(foreachThreadOp.getBody(),
-                         newForeachThreadOp.getBody(),
-                         {newForeachThreadOp.getBody()->getArguments()});
+                         newForeachThreadOp.getBody(), replacementBbArgs);
 
-    // Remove the old op.
-    rewriter.eraseOp(op);
+    // Remove the old op and replace all of its uses.
+    replaceOpWithBufferizedValues(rewriter, op, buffers);
 
     return success();
   }
+
+  FailureOr<BaseMemRefType>
+  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+                const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
+    auto foreachThreadOp = cast<ForeachThreadOp>(op);
+
+    if (auto bbArg = value.dyn_cast<BlockArgument>())
+      // A tensor block argument has the same bufferized type as the
+      // corresponding output operand.
+      return bufferization::getBufferType(
+          foreachThreadOp.getTiedOpOperand(bbArg)->get(), options, fixedTypes);
+
+    // The bufferized result type is the same as the bufferized type of the
+    // corresponding output operand.
+    return bufferization::getBufferType(
+        foreachThreadOp.getOutputs()[value.cast<OpResult>().getResultNumber()],
+        options, fixedTypes);
+  }
+
+  bool isRepetitiveRegion(Operation *op, unsigned index) const {
+    auto foreachThreadOp = cast<ForeachThreadOp>(op);
+    // This op is not repetitive if it has just a single thread.
+    if (llvm::all_of(foreachThreadOp.getNumThreads(), [](Value v) {
+          return getConstantIntValue(v) == static_cast<int64_t>(1);
+        }))
+      return false;
+    return true;
+  }
 };
 
 /// Nothing to do for PerformConcurrentlyOp.
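The block-argument bookkeeping in the new `ForeachThreadOp` `bufferize` can be modeled outside of MLIR. This is a minimal sketch under stated assumptions: block arguments are modeled as names, `std::nullopt` stands in for a null `mlir::Value`, and `numSharedOuts` and `buildReplacementArgs` are hypothetical helpers, not MLIR API. The body carries `rank` thread-index arguments followed by one argument per shared output. Once the shared-output arguments have been rewritten to `to_tensor` ops of the output buffers, the block merge only needs the new op's thread indices plus a null placeholder for each output.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Hypothetical model: a block argument is identified by its name, and
// std::nullopt plays the role of a null mlir::Value.
using Arg = std::optional<std::string>;

// The first `rank` block arguments are thread indices; the remaining ones
// are the shared_outs block arguments.
std::size_t numSharedOuts(std::size_t numBlockArgs, std::size_t rank) {
  assert(rank <= numBlockArgs && "more thread indices than block arguments");
  return numBlockArgs - rank;
}

// Replacement values for merging the old body into the new, result-less op:
// reuse the new op's thread-index arguments and pass a null placeholder for
// every shared output (its uses were already redirected to the output buffer).
std::vector<Arg> buildReplacementArgs(
    const std::vector<std::string> &newThreadArgs, std::size_t numOutputs) {
  std::vector<Arg> replacements(newThreadArgs.begin(), newThreadArgs.end());
  replacements.resize(newThreadArgs.size() + numOutputs, std::nullopt);
  return replacements;
}
```

For a 2-D loop with one shared output, the old body has three arguments, `numSharedOuts(3, 2)` is 1, and the replacement list is the two thread indices followed by one null placeholder.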

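The rule implemented by `isRepetitiveRegion` can be restated as a standalone predicate. A minimal sketch, assuming each thread count is modeled as `std::optional<int64_t>` (a known constant, or `std::nullopt` when the count is not a constant, roughly what `getConstantIntValue` reports); `isRepetitive` is a hypothetical name, not MLIR API. The body is only treated as non-repetitive when every thread count is the constant 1; a dynamic count conservatively makes it repetitive.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical model of the num_threads operands: the constant value of each
// thread count, or std::nullopt when the count is not a known constant.
using ThreadCounts = std::vector<std::optional<int64_t>>;

// Mirrors isRepetitiveRegion: the body runs exactly once only if every thread
// count is the constant 1; anything else (including dynamic counts) is
// conservatively treated as repetitive.
bool isRepetitive(const ThreadCounts &numThreads) {
  bool singleThread = std::all_of(
      numThreads.begin(), numThreads.end(),
      [](const std::optional<int64_t> &n) { return n == int64_t{1}; });
  return !singleThread;
}
```

Treating unknown counts as repetitive is the safe direction: it can only add conflicts (and copies), never hide one.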
diff  --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 35010f520e022..65a9375b1123c 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -922,12 +922,7 @@ struct ParallelInsertSliceOpInterface
           ParallelInsertSliceOpInterface, ParallelInsertSliceOp> {
   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                             const AnalysisState &state) const {
-    if (&opOperand != &op->getOpOperand(1) /*dest*/)
       return {};
-
-    // ParallelInsertSliceOp itself has no results, query its tied op results.
-    auto insertOp = cast<ParallelInsertSliceOp>(op);
-    return {insertOp.getTiedOpResult()};
   }
 
   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
@@ -940,84 +935,21 @@ struct ParallelInsertSliceOpInterface
     return &opOperand == &op->getOpOperand(1) /*dest*/;
   }
 
-  BufferRelation bufferRelation(Operation *op, OpResult opResult,
-                                const AnalysisState &state) const {
-    return BufferRelation::Equivalent;
-  }
-
-  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
-                                 const AnalysisState &state) const {
-    // This interface method is overridden because we want to set a custom
-    // insertion point for tensor copies. They should be inserted right before
-    // the ForeachThreadOp. E.g.:
-    //
-    // %r0, %r1 = foreach_thead ... {
-    //   ...
-    //   perform_concurrently {
-    //     parallel_insert_slice %a into %b ... {inplace = ["true", "true"]}
-    //     parallel_insert_slice %c into %d ... {inplace = ["true", "false"]}
-    //   }
-    // }
-    //
-    // After TensorCopyInsertion:
-    //
-    // %copy = bufferization.alloc_tensor() copy(%d)
-    // %r0, %r1 = foreach_thead ... {
-    //   ...
-    //   perform_concurrently {
-    //     parallel_insert_slice %a into %b ...
-    //     parallel_insert_slice %c into %copy ...
-    //   }
-    // }
-
-    OpBuilder::InsertionGuard g(rewriter);
-    auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
-    ParallelCombiningOpInterface parallelCombiningParent =
-        parallelInsertSliceOp.getParallelCombiningParent();
-    Operation *parallelIteratingOp = parallelCombiningParent->getParentOp();
-
-    // Nothing to do if the destination tensor is inplace.
-    assert(state.isInPlace(op->getOpOperand(0) /*src*/) &&
-           "source is always in-place");
-    if (state.isInPlace(op->getOpOperand(1) /*dest*/))
-      return success();
-
-    // Find corresponding OpResult.
-    OpResult opResult = parallelInsertSliceOp.getTiedOpResult();
-
-    // Insert tensor allocation right before the ForeachThreadOp.
-    rewriter.setInsertionPoint(parallelIteratingOp);
-    bool isYielded = state.isTensorYielded(opResult);
-    FailureOr<Value> alloc = allocateTensorForShapedValue(
-        rewriter, op->getLoc(), parallelInsertSliceOp.getDest(),
-        /*escape=*/isYielded, state.getOptions());
-    if (failed(alloc))
-      return failure();
-
-    // Update destination operand.
-    rewriter.updateRootInPlace(parallelInsertSliceOp, [&]() {
-      parallelInsertSliceOp.getDestMutable().assign(*alloc);
-    });
-
-    return success();
-  }
-
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     OpBuilder::InsertionGuard g(rewriter);
     auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
     ParallelCombiningOpInterface parallelCombiningParent =
         parallelInsertSliceOp.getParallelCombiningParent();
-    Operation *parallelIteratingOp = parallelCombiningParent->getParentOp();
 
-    // Get destination buffer.
+    // Bufferize the op outside of the parallel combining terminator.
+    rewriter.setInsertionPoint(parallelCombiningParent);
+
+    // Get source and destination buffers.
     FailureOr<Value> destBuffer =
         getBuffer(rewriter, parallelInsertSliceOp.getDest(), options);
     if (failed(destBuffer))
       return failure();
-
-    // Bufferize the ParallelInsertSliceOp outside of `parallelCombiningParent`.
-    rewriter.setInsertionPoint(parallelCombiningParent);
     FailureOr<Value> srcBuffer =
         getBuffer(rewriter, parallelInsertSliceOp.getSource(), options);
     if (failed(srcBuffer))
@@ -1043,18 +975,7 @@ struct ParallelInsertSliceOpInterface
                                     *srcBuffer, subview)))
       return failure();
 
-    // Replace all uses of parallelIteratingOp (just the corresponding result).
-    rewriter.setInsertionPointAfter(parallelIteratingOp);
-    Value toTensorOp =
-        rewriter.create<ToTensorOp>(parallelIteratingOp->getLoc(), *destBuffer);
-    // PerformConcurrentlyOp can have multiple ParallelInsertSliceOps.
-    SmallVector<OpOperand *> resultUses = llvm::to_vector(
-        llvm::map_range(parallelInsertSliceOp.getTiedOpResult().getUses(),
-                        [](OpOperand &use) { return &use; }));
-    for (OpOperand *use : resultUses) {
-      rewriter.updateRootInPlace(use->getOwner(),
-                                 [&]() { use->set(toTensorOp); });
-    }
+    // Delete the op.
     rewriter.eraseOp(op);
     return success();
   }

diff  --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index cab14698e57c4..8941b03391f71 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -835,16 +835,16 @@ func.func @reduce_dispatch_0() -> tensor<4x2xf32> {
   %c4 = arith.constant 4 : index
   %cst = arith.constant 0.000000e+00 : f32
   %0 = linalg.init_tensor [4, 2] : tensor<4x2xf32>
-  %res = scf.foreach_thread (%arg0, %arg1) in (%c4, %c2) -> (tensor<4x2xf32>) {
+  %res = scf.foreach_thread (%arg0, %arg1) in (%c4, %c2) shared_outs(%o = %0) -> (tensor<4x2xf32>) {
     %1 = linalg.init_tensor [1, 1] : tensor<1x1xf32>
     %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<1x1xf32>) -> tensor<1x1xf32>
     scf.foreach_thread.perform_concurrently {
       //      CHECK: tensor.parallel_insert_slice %{{[0-9a-z]*}} into %{{[0-9a-z]*}}
       // CHECK-SAME: [%{{.*}}, %{{.*}}] [1, 1] [1, 1] : tensor<f32> into tensor<4x2xf32>
-      tensor.parallel_insert_slice %2 into %0[%arg0, %arg1] [1, 1] [1, 1] :
+      tensor.parallel_insert_slice %2 into %o[%arg0, %arg1] [1, 1] [1, 1] :
         tensor<1x1xf32> into tensor<4x2xf32>
     }
-  }  
+  }
   return %res: tensor<4x2xf32>
 }
 

diff  --git a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
index d519adb4dfd2c..bf1a2cdeae41c 100644
--- a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
+++ b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
@@ -15,15 +15,15 @@ module {
   func.func @matmul(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
   //  CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index
   //  CHECK-DAG: %[[C20:.*]] = arith.constant 20 : index
-  //      CHECK: scf.foreach_thread ({{.*}}) in (%[[C10]], %[[C20]]) -> (tensor<?x?xf32>) {
+  //      CHECK: scf.foreach_thread ({{.*}}) in (%[[C10]], %[[C20]]) shared_outs(%[[C_BLK:.*]] = %[[C]]) -> (tensor<?x?xf32>) {
   //      CHECK:   %[[tA:.*]] = tensor.extract_slice %[[A]]{{.*}} : tensor<?x?xf32> to tensor<?x?xf32>
   //      CHECK:   %[[tB:.*]] = tensor.extract_slice %[[B]]{{.*}} : tensor<?x?xf32> to tensor<?x?xf32>
-  //      CHECK:   %[[tC:.*]] = tensor.extract_slice %[[C]]{{.*}} : tensor<?x?xf32> to tensor<?x?xf32>
+  //      CHECK:   %[[tC:.*]] = tensor.extract_slice %[[C_BLK]]{{.*}} : tensor<?x?xf32> to tensor<?x?xf32>
   //      CHECK:   %[[RES:.*]] = linalg.matmul
   // CHECK-SAME:      ins(%[[tA]], %[[tB]] : tensor<?x?xf32>, tensor<?x?xf32>)
   // CHECK-SAME:     outs(%[[tC]] : tensor<?x?xf32>) -> tensor<?x?xf32>
   //      CHECK:   scf.foreach_thread.perform_concurrently {
-  // CHECK-NEXT:     tensor.parallel_insert_slice %[[RES]] into %[[C]]{{.*}} :
+  // CHECK-NEXT:     tensor.parallel_insert_slice %[[RES]] into %[[C_BLK]]{{.*}} :
   // CHECK-SAME:       tensor<?x?xf32> into tensor<?x?xf32>
   // CHECK-NEXT:   }
   // CHECK-NEXT: } {thread_dim_mapping = [1, 0]}
@@ -55,10 +55,10 @@ module {
 //  CHECK-SAME:   %[[A:[0-9a-z]+]]: tensor
 //  CHECK-SAME:   %[[B:[0-9a-z]+]]: tensor
 //  CHECK-SAME:   %[[C:[0-9a-z]+]]: tensor
-func.func @matmul_static(%A: tensor<100x200xf32>, %B: tensor<200x300xf32>, %C: tensor<100x300xf32>) -> tensor<100x300xf32> {  
+func.func @matmul_static(%A: tensor<100x200xf32>, %B: tensor<200x300xf32>, %C: tensor<100x300xf32>) -> tensor<100x300xf32> {
   //  CHECK-DAG: %[[c10:.+]] = arith.constant 10 : index
   //  CHECK-DAG: %[[c21:.+]] = arith.constant 21 : index
-  //      CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[c10]], %[[c21]])
+  //      CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[c10]], %[[c21]]) shared_outs(%[[C_BLK:.*]] = %[[C]])
   //      CHECK:   %[[TSMIN:.+]] = affine.min #[[$map0]](%[[IV1]])
   //      CHECK:   %[[TS:.+]] = affine.max #[[$map1]](%[[TSMIN]])
   //  CHECK-NOT:   affine.min
@@ -67,7 +67,7 @@ func.func @matmul_static(%A: tensor<100x200xf32>, %B: tensor<200x300xf32>, %C: t
   //      CHECK:   %[[LB1:.+]] = affine.apply #[[$map3]](%[[IV1]])
   //      CHECK:   %[[tA:.+]] = tensor.extract_slice %[[A]][%[[LB0]], 0] [10, 200] [1, 1] :
   //      CHECK:   %[[tB:.+]] = tensor.extract_slice %[[B]][0, %[[LB1]]] [200, %[[TS]]] [1, 1] :
-  //      CHECK:   %[[tC:.+]] = tensor.extract_slice %[[C]][%[[LB0]], %[[LB1]]] [10, %[[TS]]] [1, 1] :
+  //      CHECK:   %[[tC:.+]] = tensor.extract_slice %[[C_BLK]][%[[LB0]], %[[LB1]]] [10, %[[TS]]] [1, 1] :
   //      CHECK:   linalg.matmul
   //      CHECK:   scf.foreach_thread.perform_concurrently
   // CHECK-NEXT:    tensor.parallel_insert_slice
@@ -104,14 +104,14 @@ func.func @matmul_tile_size_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C
   //      CHECK: %[[N:.+]] = tensor.dim %[[B]], %c1 : 
   //      CHECK: %[[NT0:.+]] = affine.apply #map0()[%[[M]]]
   //      CHECK: %[[NT1:.+]] = affine.apply #map1()[%[[N]]]
-  //      CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]])
+  //      CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]]) shared_outs(%[[C_BLK:.*]] = %[[C]])
   //      CHECK:   %[[TS0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]]  
   //      CHECK:   %[[TS1:.+]] = affine.min #[[$map4]](%[[IV1]])[%[[N]]]
   //      CHECK:   %[[LB0:.+]] = affine.apply #[[$map5]](%[[IV0]])
-  //      CHECK    tensor.extract_slice %[[A]]
   //      CHECK:   %[[LB1:.+]] = affine.apply #[[$map6]](%[[IV1]])
-  //      CHECK    tensor.extract_slice %[[B]]
-  //      CHECK    tensor.extract_slice %[[C]]
+  //      CHECK:   tensor.extract_slice %[[A]]
+  //      CHECK:   tensor.extract_slice %[[B]]
+  //      CHECK:   tensor.extract_slice %[[C_BLK]]
   //      CHECK:   linalg.matmul
   //      CHECK:   scf.foreach_thread.perform_concurrently
   // CHECK-NEXT:    tensor.parallel_insert_slice
@@ -144,7 +144,7 @@ transform.with_pdl_patterns {
 func.func @matmul_tile_size_static(%A: tensor<100x200xf32>, %B: tensor<200x300xf32>, %C: tensor<100x300xf32>) -> tensor<100x300xf32> {
   //  CHECK-DAG: %[[c10:.+]] = arith.constant 10 :
   //  CHECK-DAG: %[[c15:.+]] = arith.constant 15 :
-  //      CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[c10]], %[[c15]])
+  //      CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[c10]], %[[c15]]) shared_outs(%[[C_BLK:.*]] = %[[C]])
   //      CHECK:   %[[TS:.+]] = affine.min #[[$map0]](%[[IV1]])  
   //  CHECK-NOT:   affine.max
   //  CHECK-NOT:   affine.min
@@ -152,7 +152,7 @@ func.func @matmul_tile_size_static(%A: tensor<100x200xf32>, %B: tensor<200x300xf
   //      CHECK:   %[[LB1:.+]] = affine.apply #[[$map3]](%[[IV1]])
   //      CHECK:   %[[tA:.+]] = tensor.extract_slice %[[A]][%[[LB0]], 0] [10, 200] [1, 1] :
   //      CHECK:   %[[tB:.+]] = tensor.extract_slice %[[B]][0, %[[LB1]]] [200, %[[TS]]] [1, 1] :
-  //      CHECK:   %[[tC:.+]] = tensor.extract_slice %[[C]][%[[LB0]], %[[LB1]]] [10, %[[TS]]] [1, 1] :
+  //      CHECK:   %[[tC:.+]] = tensor.extract_slice %[[C_BLK]][%[[LB0]], %[[LB1]]] [10, %[[TS]]] [1, 1] :
   //      CHECK:   linalg.matmul
   //      CHECK:   scf.foreach_thread.perform_concurrently
   // CHECK-NEXT:    tensor.parallel_insert_slice
@@ -199,7 +199,7 @@ module {
 
 // CHECK-LABEL: extract_source(
 //       CHECK:  %[[C2:.*]] = arith.constant 2 : index
-//       CHECK:  scf.foreach_thread (%[[ARG:.*]]) in (%[[C2]]) -> (tensor<4xf32>) {
+//       CHECK:  scf.foreach_thread (%[[ARG:.*]]) in (%[[C2]]) shared_outs(%{{.*}} = %{{.*}}) -> (tensor<4xf32>) {
 //       CHECK:    %[[OFF:.*]] = affine.apply #[[$map0]](%[[ARG]])
 //       CHECK:    scf.foreach_thread.perform_concurrently {
 //       CHECK:      tensor.parallel_insert_slice %{{.*}} into %{{.*}}[%[[OFF]]] [2] [1] : tensor<2xf32> into tensor<4xf32>
@@ -227,10 +227,10 @@ func.func @matmul_tile_size_dynamic_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?x
   //  CHECK-DAG: %[[N:.+]] = tensor.dim %[[B]], %c1 :
   //  CHECK-DAG: %[[NT0:.+]] = affine.apply #[[$map0]]()[%[[M]], %[[tile_size]]]
   //  CHECK-DAG: %[[NT1:.+]] = affine.apply #[[$map1]]()[%[[N]]]
-  //      CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]])
-  //      CHECK    tensor.extract_slice %[[A]]
-  //      CHECK    tensor.extract_slice %[[B]]
-  //      CHECK    tensor.extract_slice %[[C]]
+  //      CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]]) shared_outs(%[[C_BLK:.*]] = %[[C]])
+  //      CHECK:   tensor.extract_slice %[[A]]
+  //      CHECK:   tensor.extract_slice %[[B]]
+  //      CHECK:   tensor.extract_slice %[[C_BLK]]
   //      CHECK:   linalg.matmul
   //      CHECK:   scf.foreach_thread.perform_concurrently
   // CHECK-NEXT:    tensor.parallel_insert_slice

diff  --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
index 32f8950d24bb7..b3cd3283286ce 100644
--- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
@@ -17,10 +17,10 @@ module {
     %1 = affine.apply #map0()[%d0, %arg0]
 
     // CHECK: scf.foreach_thread {{.*}} {
-    %2 = scf.foreach_thread (%arg3) in (%1)  -> (tensor<?xf32>) {
+    %2 = scf.foreach_thread (%arg3) in (%1) shared_outs(%o = %arg2) -> (tensor<?xf32>) {
       %3 = affine.apply #map1(%arg3)[%arg0]
       %4 = affine.min #map2(%arg3)[%d0, %arg0]
-      %5 = tensor.extract_slice %arg2[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+      %5 = tensor.extract_slice %o[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
 
       // CHECK: %[[T0:.*]] = tensor.extract_slice %[[IN]][%{{.*}}] [%{{.*}}] [{{.*}}]
       // CHECK: %[[T1:.*]] = linalg.fill {{.*}} outs(%[[T0]]
@@ -29,7 +29,7 @@ module {
       // CHECK: %[[T2:.*]] = linalg.elemwise_unary ins(%[[T1]]
       %7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
       scf.foreach_thread.perform_concurrently {
-        tensor.parallel_insert_slice %7 into %arg2[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
+        tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
       }
     }
     // CHECK: }
@@ -70,16 +70,16 @@ module {
     %1 = affine.apply #map0()[%arg0]
 
     // CHECK: scf.foreach_thread {{.*}} {
-    %2 = scf.foreach_thread (%arg3) in (%1)  -> (tensor<64xf32>) {
+    %2 = scf.foreach_thread (%arg3) in (%1) shared_outs(%o = %arg2) -> (tensor<64xf32>) {
       // CHECK: %[[INIT_TENSOR:.*]] = linalg.init_tensor
       %3 = affine.apply #map1(%arg3)[%arg0]
       %4 = affine.min #map2(%arg3)[%arg0]
-      %5 = tensor.extract_slice %arg2[%3] [%4] [1] : tensor<64xf32> to tensor<?xf32>
+      %5 = tensor.extract_slice %o[%3] [%4] [1] : tensor<64xf32> to tensor<?xf32>
 
       // CHECK: %[[T2:.*]] = linalg.elemwise_unary ins(%[[INIT_TENSOR]]
       %7 = linalg.elemwise_unary ins(%0 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
       scf.foreach_thread.perform_concurrently {
-        tensor.parallel_insert_slice %7 into %arg2[%3] [%4] [1] : tensor<?xf32> into tensor<64xf32>
+        tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<64xf32>
       }
     }
     // CHECK: }

diff  --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
index 86d1cb65f539f..b79ecb48d7d7f 100644
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -527,11 +527,11 @@ func.func @wrong_num_results(%in: tensor<100xf32>, %out: tensor<100xf32>) {
   %c1 = arith.constant 1 : index
   %num_threads = arith.constant 100 : index
 
-  // expected-error @+1 {{produces 2 results, but its terminator yields 1 value(s)}}
-  %result:2 = scf.foreach_thread (%thread_idx) in (%num_threads) -> (tensor<100xf32>, tensor<100xf32>) {
+  // expected-error @+1 {{1 operands present, but expected 2}}
+  %result:2 = scf.foreach_thread (%thread_idx) in (%num_threads) shared_outs(%o = %out) -> (tensor<100xf32>, tensor<100xf32>) {
       %1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32>
       scf.foreach_thread.perform_concurrently {
-        tensor.parallel_insert_slice %1 into %out[%thread_idx][1][1] :
+        tensor.parallel_insert_slice %1 into %o[%thread_idx][1][1] :
           tensor<1xf32> into tensor<100xf32>
       }
   }
@@ -540,14 +540,14 @@ func.func @wrong_num_results(%in: tensor<100xf32>, %out: tensor<100xf32>) {
 
 // -----
 
-func.func @wrong_type_result(%in: tensor<100xf32>, %out: tensor<100xf32>) {
+func.func @invalid_insert_dest(%in: tensor<100xf32>, %out: tensor<100xf32>) {
   %c1 = arith.constant 1 : index
   %num_threads = arith.constant 100 : index
 
-  // expected-error @+1 {{type mismatch between result 0 ('tensor<?xf32>') and terminator ('tensor<100xf32>')}}
-  %result = scf.foreach_thread (%thread_idx) in (%num_threads) -> (tensor<?xf32>) {
+  %result = scf.foreach_thread (%thread_idx) in (%num_threads) shared_outs(%o = %out) -> (tensor<100xf32>) {
       %1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32>
       scf.foreach_thread.perform_concurrently {
+        // expected-error @+1 {{may only insert into an output block argument}}
         tensor.parallel_insert_slice %1 into %out[%thread_idx][1][1] :
           tensor<1xf32> into tensor<100xf32>
       }
@@ -561,11 +561,11 @@ func.func @wrong_terminator_op(%in: tensor<100xf32>, %out: tensor<100xf32>) {
   %c1 = arith.constant 1 : index
   %num_threads = arith.constant 100 : index
 
-  %result = scf.foreach_thread (%thread_idx) in (%num_threads) -> (tensor<100xf32>) {
+  %result = scf.foreach_thread (%thread_idx) in (%num_threads) shared_outs(%o = %out) -> (tensor<100xf32>) {
       %1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32>
       // expected-error @+1 {{expected only tensor.parallel_insert_slice ops}}
       scf.foreach_thread.perform_concurrently {
-        tensor.parallel_insert_slice %1 into %out[%thread_idx][1][1] :
+        tensor.parallel_insert_slice %1 into %o[%thread_idx][1][1] :
           tensor<1xf32> into tensor<100xf32>
         %0 = arith.constant 1: index
       }

diff  --git a/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir
index 4efa5513b4200..ec0ffa657d876 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir
@@ -120,14 +120,14 @@ func.func @scf_foreach_thread_out_of_place(%in: tensor<100xf32>,
 
   // CHECK-FUNC-NOT: alloc_tensor
   // CHECK: %[[alloc:.*]] = bufferization.alloc_tensor() copy(%[[arg1]]) {bufferization.escape = [false]} : tensor<100xf32>
-  // CHECK: scf.foreach_thread
-  %result = scf.foreach_thread (%thread_idx) in (%num_threads) -> tensor<100xf32> {
+  // CHECK: scf.foreach_thread {{.*}} shared_outs(%[[o:.*]] = %[[alloc]])
+  %result = scf.foreach_thread (%thread_idx) in (%num_threads) shared_outs(%o = %out) -> tensor<100xf32> {
       // CHECK: tensor.extract_slice
       // CHECK: scf.foreach_thread.perform_concurrently
-      // CHECK: tensor.parallel_insert_slice %{{.*}} into %[[alloc]]
+      // CHECK: tensor.parallel_insert_slice %{{.*}} into %[[o]]
       %1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32>
       scf.foreach_thread.perform_concurrently {
-        tensor.parallel_insert_slice %1 into %out[%thread_idx][1][1] :
+        tensor.parallel_insert_slice %1 into %o[%thread_idx][1][1] :
           tensor<1xf32> into tensor<100xf32>
       }
   // CHECK: } {thread_dim_mapping = [5]}

diff  --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
index 61ec6faae86ab..a72c6d3714aba 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -525,10 +525,10 @@ func.func @parallel_insert_slice_no_conflict(
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
 
-  // CHECK: scf.foreach_thread (%[[tidx:.*]]) in (%[[idx2]])  -> ()
-  %2 = scf.foreach_thread (%arg3) in (%idx2)  -> (tensor<?xf32>) {
+  // CHECK: scf.foreach_thread (%[[tidx:.*]]) in (%[[idx2]])
+  %2 = scf.foreach_thread (%arg3) in (%idx2) shared_outs(%o = %arg2) -> (tensor<?xf32>) {
       // CHECK: %[[subview:.*]] = memref.subview %[[arg2]][5] [%[[idx]]] [1]
-      %6 = tensor.extract_slice %arg2[5] [%idx] [%c1] : tensor<?xf32> to tensor<?xf32>
+      %6 = tensor.extract_slice %o[5] [%idx] [%c1] : tensor<?xf32> to tensor<?xf32>
       // CHECK: linalg.fill ins(%{{.*}}) outs(%[[subview]] : memref<?xf32
       %8 = linalg.fill ins(%cst : f32) outs(%6 : tensor<?xf32>) -> tensor<?xf32>
       // Self-copy will DCE away later.
@@ -538,7 +538,7 @@ func.func @parallel_insert_slice_no_conflict(
       // CHECK-NOT: scf.foreach_thread.perform_concurrently
       // CHECK-NOT: parallel_insert_slice
       scf.foreach_thread.perform_concurrently {
-        tensor.parallel_insert_slice %8 into %arg2[5] [%idx] [%c1] : 
+        tensor.parallel_insert_slice %8 into %o[5] [%idx] [%c1] :
           tensor<?xf32> into tensor<?xf32>
       }
   }
@@ -571,26 +571,22 @@ func.func @parallel_insert_slice_with_conflict(
   // CHECK: %[[alloc1:.*]] = memref.alloc
   // CHECK: memref.copy %[[arg2]], %[[alloc1]]
 
-  // CHECK: scf.foreach_thread (%[[tidx:.*]]) in (%[[idx2]])  -> ()
-  %2 = scf.foreach_thread (%arg3) in (%idx2)  -> (tensor<?xf32>) {
-      // Another alloc for the extract_slice op.
-      // CHECK: %[[alloc2:.*]] = memref.alloc
-      %6 = tensor.extract_slice %arg2[5] [%idx] [%c1] : tensor<?xf32> to tensor<?xf32>
+  // CHECK: scf.foreach_thread (%[[tidx:.*]]) in (%[[idx2]])
+  %2 = scf.foreach_thread (%arg3) in (%idx2) shared_outs(%o = %arg2) -> (tensor<?xf32>) {
+      // CHECK: %[[subview1:.*]] = memref.subview %[[alloc1]][5] [%[[idx]]] [1]
+      %6 = tensor.extract_slice %o[5] [%idx] [%c1] : tensor<?xf32> to tensor<?xf32>
 
-      // CHECK: linalg.fill ins(%{{.*}}) outs(%[[alloc2]] : memref<?xf32
+      // CHECK: linalg.fill ins(%{{.*}}) outs(%[[subview1]] : memref<?xf32
       %8 = linalg.fill ins(%cst : f32) outs(%6 : tensor<?xf32>) -> tensor<?xf32>
 
-      // Now the copy of the actual insert_slice.
-      // CHECK: %[[subview1:.*]] = memref.subview %[[alloc1]][5] [%[[idx]]] [1]
-      //
-      // CHECK: memref.copy %[[alloc2]], %[[subview1]]
-      // CHECK: memref.dealloc %[[alloc2]]
+      // Now the copy of the actual insert_slice. (It will fold away.)
+      // CHECK: memref.copy %[[subview1]], %[[subview1]]
 
       // Empty terminator is elided from pretty-printing.
       // CHECK-NOT: scf.foreach_thread.perform_concurrently
       // CHECK-NOT: parallel_insert_slice
       scf.foreach_thread.perform_concurrently {
-        tensor.parallel_insert_slice %8 into %arg2[5] [%idx] [%c1] :
+        tensor.parallel_insert_slice %8 into %o[5] [%idx] [%c1] :
           tensor<?xf32> into tensor<?xf32>
       }
   }
@@ -617,18 +613,18 @@ func.func @matmul(%arg0: tensor<8x8xf32>, %arg1: tensor<8x8xf32>, %arg2: tensor<
   %c2 = arith.constant 2 : index
   %c4 = arith.constant 4 : index
 
-  // CHECK: scf.foreach_thread {{.*}}  -> ()
-  %0 = scf.foreach_thread (%arg3, %arg4) in (%c2, %c4) -> (tensor<8x8xf32>) {
+  // CHECK: scf.foreach_thread {{.*}}
+  %0 = scf.foreach_thread (%arg3, %arg4) in (%c2, %c4) shared_outs(%o = %arg2) -> (tensor<8x8xf32>) {
     %1 = affine.apply #map0(%arg3)
     %3 = tensor.extract_slice %arg0[%1, 0] [4, 8] [1, 1] : tensor<8x8xf32> to tensor<4x8xf32>
     %4 = affine.apply #map1(%arg4)
     %6 = tensor.extract_slice %arg1[0, %4] [8, 4] [1, 1] : tensor<8x8xf32> to tensor<8x4xf32>
-    %7 = tensor.extract_slice %arg2[%1, %4] [4, 4] [1, 1] : tensor<8x8xf32> to tensor<4x4xf32>
- 
+    %7 = tensor.extract_slice %o[%1, %4] [4, 4] [1, 1] : tensor<8x8xf32> to tensor<4x4xf32>
+
     //      CHECK: linalg.matmul ins({{.*}}memref<4x8xf32, #[[$DYN_LAYOUT_MAP]]>, memref<8x4xf32, #[[$DYN_LAYOUT_MAP]]>) outs({{.*}} : memref<4x4xf32, #[[$DYN_LAYOUT_MAP]]>)
     %8 = linalg.matmul ins(%3, %6 : tensor<4x8xf32>, tensor<8x4xf32>) outs(%7 : tensor<4x4xf32>) -> tensor<4x4xf32>
     scf.foreach_thread.perform_concurrently {
-      tensor.parallel_insert_slice %8 into %arg2[%1, %4] [4, 4] [1, 1] : tensor<4x4xf32> into tensor<8x8xf32>
+      tensor.parallel_insert_slice %8 into %o[%1, %4] [4, 4] [1, 1] : tensor<4x4xf32> into tensor<8x8xf32>
     }
   }
   return %0 : tensor<8x8xf32>
@@ -636,6 +632,71 @@ func.func @matmul(%arg0: tensor<8x8xf32>, %arg1: tensor<8x8xf32>, %arg2: tensor<
 
 // -----
 
+// CHECK-LABEL: func @scf_foreach_private_var(
+//  CHECK-SAME:     %[[t:.*]]: memref<10xf32
+func.func @scf_foreach_private_var(%t: tensor<10xf32>) -> f32 {
+  %c2 = arith.constant 2 : index
+  %c5 = arith.constant 5 : index
+
+  // A copy is inserted for the uses of %t in the loop.
+  // CHECK: %[[t_copy:.*]] = memref.alloc() {{.*}} : memref<10xf32>
+  // CHECK: memref.copy %[[t]], %[[t_copy]]
+
+  // CHECK: scf.foreach_thread (%{{.*}}) in (%{{.*}}) {
+
+  // Load from the copy and store into the shared output.
+  // CHECK:   %[[subview:.*]] = memref.subview %[[t]]
+  // CHECK:   memref.load %[[t_copy]]
+  // CHECK:   memref.store %{{.*}}, %[[subview]]
+  %0 = scf.foreach_thread (%tid) in (%c2) shared_outs(%o = %t) -> tensor<10xf32> {
+    %offset = arith.muli %c5, %tid : index
+    %slice = tensor.extract_slice %o[%offset] [5] [1]
+        : tensor<10xf32> to tensor<5xf32>
+    %r2 = tensor.extract %t[%tid] : tensor<10xf32>
+    %i = tensor.insert %r2 into %slice[%c2] : tensor<5xf32>
+    scf.foreach_thread.perform_concurrently {
+      tensor.parallel_insert_slice %i into %o[%offset] [5] [1]
+          : tensor<5xf32> into tensor<10xf32>
+    }
+  }
+
+  %r = tensor.extract %0[%c2] : tensor<10xf32>
+  return %r : f32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @scf_foreach_privatized_but_not_copied(
+//  CHECK-SAME:     %[[t0:.*]]: memref<10xf32, {{.*}}>, %[[t1:.*]]: memref<10xf32
+func.func @scf_foreach_privatized_but_not_copied(
+    %t0: tensor<10xf32>, %t1: tensor<10xf32>) -> f32 {
+  %c2 = arith.constant 2 : index
+  %c5 = arith.constant 5 : index
+
+  // CHECK-NOT: memref.alloc
+  // CHECK-NOT: memref.copy
+  // CHECK: scf.foreach_thread {{.*}} {
+  %0 = scf.foreach_thread (%tid) in (%c2) shared_outs(%o = %t0) -> tensor<10xf32> {
+    %offset = arith.muli %c5, %tid : index
+    %slice = tensor.extract_slice %o[%offset] [5] [1]
+        : tensor<10xf32> to tensor<5xf32>
+
+    // %t1 is never written in here, so no copy is needed
+    // CHECK: memref.load %[[t1]]
+    %r2 = tensor.extract %t1[%tid] : tensor<10xf32>
+    %i = tensor.insert %r2 into %slice[%c2] : tensor<5xf32>
+    scf.foreach_thread.perform_concurrently {
+      tensor.parallel_insert_slice %i into %o[%offset] [5] [1]
+          : tensor<5xf32> into tensor<10xf32>
+    }
+  }
+
+  %r = tensor.extract %0[%c2] : tensor<10xf32>
+  return %r : f32
+}
+
+// -----
+
 // CHECK-LABEL: func @scf_if_memory_space
 func.func @scf_if_memory_space(%c: i1, %f: f32) -> (f32, f32)
 {

diff  --git a/mlir/test/Dialect/SCF/ops.mlir b/mlir/test/Dialect/SCF/ops.mlir
index 8d0a42fde9165..c1fa4e65f6f5e 100644
--- a/mlir/test/Dialect/SCF/ops.mlir
+++ b/mlir/test/Dialect/SCF/ops.mlir
@@ -323,10 +323,10 @@ func.func @simple_example(%in: tensor<100xf32>, %out: tensor<100xf32>) {
   // CHECK-NEXT:  }
   // CHECK-NEXT:  }
   // CHECK-NEXT:  return
-  %result = scf.foreach_thread (%thread_idx) in (%num_threads) -> tensor<100xf32> {
+  %result = scf.foreach_thread (%thread_idx) in (%num_threads) shared_outs(%o = %out) -> tensor<100xf32> {
       %1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32>
       scf.foreach_thread.perform_concurrently {
-        tensor.parallel_insert_slice %1 into %out[%thread_idx][1][1] :
+        tensor.parallel_insert_slice %1 into %o[%thread_idx][1][1] :
           tensor<1xf32> into tensor<100xf32>
       }
   }
@@ -340,7 +340,7 @@ func.func @elide_terminator() -> () {
   //      CHECK:    scf.foreach_thread
   // CHECK-NEXT:  } {thread_dim_mapping = [42]}
   // CHECK-NEXT:  return
-  scf.foreach_thread (%thread_idx) in (%num_threads) -> () {
+  scf.foreach_thread (%thread_idx) in (%num_threads) {
     scf.foreach_thread.perform_concurrently {
     }
   } {thread_dim_mapping = [42]}

diff  --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index ad50ecb40db2f..83ef943abb7df 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1455,13 +1455,13 @@ func.func @canonicalize_parallel_insert_slice_indices(
   %c1 = arith.constant 1 : index
 
   //  CHECK-NOT: tensor.cast
-  //      CHECK: scf.foreach_thread (%[[tidx:[0-9a-z]*]]) in (%[[num_threads]]) -> (tensor<?x?xf32>) {
+  //      CHECK: scf.foreach_thread (%[[tidx:[0-9a-z]*]]) in (%[[num_threads]]) shared_outs(%[[o:.*]] = %[[arg1]]) -> (tensor<?x?xf32>) {
   // CHECK-NEXT:   scf.foreach_thread.perform_concurrently {
-  // CHECK-NEXT:     tensor.parallel_insert_slice %[[arg0]] into %[[arg1]][%[[tidx]], 0] [1, 5] [1, 1]
-  %2 = scf.foreach_thread (%tidx) in (%num_threads)  -> (tensor<?x?xf32>) {
+  // CHECK-NEXT:     tensor.parallel_insert_slice %[[arg0]] into %[[o]][%[[tidx]], 0] [1, 5] [1, 1]
+  %2 = scf.foreach_thread (%tidx) in (%num_threads) shared_outs(%o = %arg1) -> (tensor<?x?xf32>) {
     %3 = tensor.cast %arg0 : tensor<1x5xf32> to tensor<?x5xf32>
     scf.foreach_thread.perform_concurrently {
-      tensor.parallel_insert_slice %3 into %arg1[%tidx, %c0] [%c1, 5] [%c1, %c1] : tensor<?x5xf32> into tensor<?x?xf32>
+      tensor.parallel_insert_slice %3 into %o[%tidx, %c0] [%c1, 5] [%c1, %c1] : tensor<?x5xf32> into tensor<?x?xf32>
     }
   }
   return %2 : tensor<?x?xf32>
@@ -1477,12 +1477,12 @@ func.func @dont_fold_parallel_insert_slice(
 {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
-  //      CHECK: scf.foreach_thread () in () -> (tensor<1x5xf32>) {
+  //      CHECK: scf.foreach_thread () in () shared_outs(%[[o:.*]] = %[[arg1]]) -> (tensor<1x5xf32>) {
   // CHECK-NEXT:   scf.foreach_thread.perform_concurrently {
-  // CHECK-NEXT:     tensor.parallel_insert_slice %[[arg0]] into %[[arg1]][0, 0] [1, 5] [1, 1] : tensor<1x5xf32> into tensor<1x5xf32>
-  %2 = scf.foreach_thread () in ()  -> (tensor<1x5xf32>) {
+  // CHECK-NEXT:     tensor.parallel_insert_slice %[[arg0]] into %[[o]][0, 0] [1, 5] [1, 1] : tensor<1x5xf32> into tensor<1x5xf32>
+  %2 = scf.foreach_thread () in () shared_outs(%o = %arg1) -> (tensor<1x5xf32>) {
     scf.foreach_thread.perform_concurrently {
-      tensor.parallel_insert_slice %arg0 into %arg1[%c0, %c0] [1, 5] [%c1, %c1] : tensor<1x5xf32> into tensor<1x5xf32>
+      tensor.parallel_insert_slice %arg0 into %o[%c0, %c0] [1, 5] [%c1, %c1] : tensor<1x5xf32> into tensor<1x5xf32>
     }
   }
   return %2 : tensor<1x5xf32>

diff  --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
index 220d18d2011c7..1f1936c1df347 100644
--- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
@@ -205,12 +205,12 @@ func.func @rank_reducing_parallel_insert_slice(%in: tensor<100xf32>, %out: tenso
   %num_threads = arith.constant 100 : index
 
   // CHECK: scf.foreach_thread {{.*}} {
-  %result = scf.foreach_thread (%thread_idx) in (%num_threads) -> tensor<200x100xf32> {
+  %result = scf.foreach_thread (%thread_idx) in (%num_threads) shared_outs (%o = %out) -> tensor<200x100xf32> {
       %1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32>
       scf.foreach_thread.perform_concurrently {
         // CHECK: memref.subview %{{.*}}[%{{.*}}] [1] [1] : memref<100xf32, #[[$MAP0]]> to memref<1xf32, #[[$MAP0]]>
         // CHECK: memref.subview %{{.*}}[1, %{{.*}}] [1, 1] [1, 1] : memref<200x100xf32, #[[$MAP1]]> to memref<1xf32, #[[$MAP0]]>
-        tensor.parallel_insert_slice %1 into %out[1, %thread_idx][1, 1][1, 1] :
+        tensor.parallel_insert_slice %1 into %o[1, %thread_idx][1, 1][1, 1] :
           tensor<1xf32> into tensor<200x100xf32>
       }
   }

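The privatization rule described in the commit message is exercised by the two new One-Shot Bufferize tests, `@scf_foreach_private_var` (copy inserted) and `@scf_foreach_privatized_but_not_copied` (no copy). Below is a toy Python model of that copy-or-not decision. It is a hypothetical illustration, not the actual bufferization analysis. `LoopBody` and `tensors_to_privatize` are made-up names. The model assumes that shared outputs bufferize in place, and it ignores every alias other than the `shared_outs` mapping.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class LoopBody:
    # Hypothetical model: maps each shared_outs block argument to the
    # outside tensor it is initialized from (e.g. "%o" -> "%t").
    shared_outs: dict[str, str]
    # Outside-defined tensors read directly in the body, not via a block argument.
    direct_reads: set[str]
    # Block arguments written in the body (e.g. via tensor.parallel_insert_slice).
    written_block_args: set[str]


def tensors_to_privatize(body: LoopBody) -> set[str]:
    # Assuming in-place bufferization of shared outputs, a write to a block
    # argument is a write to the buffer of the tensor it was initialized from.
    written_buffers = {body.shared_outs[arg] for arg in body.written_block_args}
    # The body is a repetitive region (one instance per thread), so op dominance
    # cannot prove that a direct read happens before every write to the same
    # buffer. A read of an outside value whose buffer is also written in the
    # body is treated as a conflict, and that value gets a private copy.
    return {t for t in body.direct_reads if t in written_buffers}


# @scf_foreach_private_var: %t is both the shared output and read directly.
private_var = LoopBody({"%o": "%t"}, {"%t"}, {"%o"})
assert tensors_to_privatize(private_var) == {"%t"}

# @scf_foreach_privatized_but_not_copied: %t1 is only read, never written.
not_copied = LoopBody({"%o": "%t0"}, {"%t1"}, {"%o"})
assert tensors_to_privatize(not_copied) == set()
```

Under these assumptions, the model reproduces the CHECK lines. The first test gets a `memref.alloc` and `memref.copy` of `%t`. The second test loads from `%[[t1]]` directly.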