[Mlir-commits] [mlir] 9b7193f - [mlir][SCF] Add parallel abstraction on tensors.

Nicolas Vasilache llvmlistbot at llvm.org
Wed Jun 1 02:02:22 PDT 2022


Author: Nicolas Vasilache
Date: 2022-06-01T09:02:16Z
New Revision: 9b7193f852874a035d9ede1f7464c9fc5b7dca7a

URL: https://github.com/llvm/llvm-project/commit/9b7193f852874a035d9ede1f7464c9fc5b7dca7a
DIFF: https://github.com/llvm/llvm-project/commit/9b7193f852874a035d9ede1f7464c9fc5b7dca7a.diff

LOG: [mlir][SCF] Add parallel abstraction on tensors.

This revision adds `scf.foreach_thread` and other supporting abstractions
that allow connecting parallel abstractions and tensors.

Discussion is available [here](https://discourse.llvm.org/t/rfc-parallel-abstraction-for-tensors-and-buffers/62607).

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SCF/SCF.h
    mlir/include/mlir/Dialect/SCF/SCFOps.td
    mlir/lib/Dialect/SCF/SCF.cpp
    mlir/test/Dialect/SCF/invalid.mlir
    mlir/test/Dialect/SCF/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/SCF.h b/mlir/include/mlir/Dialect/SCF/SCF.h
index dbe505ccaa74d..dab58cb0b963c 100644
--- a/mlir/include/mlir/Dialect/SCF/SCF.h
+++ b/mlir/include/mlir/Dialect/SCF/SCF.h
@@ -14,9 +14,11 @@
 #define MLIR_DIALECT_SCF_SCF_H
 
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
 
 namespace mlir {
 namespace scf {

diff  --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
index 887b8323f2e6b..89e156223d08e 100644
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -16,6 +16,7 @@
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/LoopLikeInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/ViewLikeInterface.td"
 
 def SCF_Dialect : Dialect {
   let name = "scf";
@@ -312,6 +313,245 @@ def ForOp : SCF_Op<"for",
   let hasRegionVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// ForeachThreadOp
+//===----------------------------------------------------------------------===//
+
+def ForeachThreadOp : SCF_Op<"foreach_thread", [
+       SingleBlockImplicitTerminator<"scf::PerformConcurrentlyOp">,
+       RecursiveSideEffects,
+       AutomaticAllocationScope,
+      ]> {
+  let summary = "evaluate a block multiple times in parallel";
+  let description = [{
+    `scf.foreach_thread` is a target-independent multi-dimensional parallel
+    function application operation. It has exactly one block that represents the
+    parallel function body and it takes index operands that indicate how many
+    parallel instances of that function are instantiated.
+
+    The only allowed terminator is `scf.foreach_thread.perform_concurrently`,
+    which dictates how the partial results of all parallel invocations should be
+    reconciled into a full value.
+
+    `scf.foreach_thread` returns values that are formed by aggregating the
+    actions of all the `perform_concurrently` terminators of all the threads,
+    in some unspecified order.
+    In other words, `scf.foreach_thread` performs all actions specified in the
+    `perform_concurrently` terminator, after it receives the control back from
+    its body along each thread.
+
+    `scf.foreach_thread` acts as an implicit synchronization point.
+
+    Multi-value returns are encoded by including multiple operations inside the
+    `perform_concurrently` block.
+
+    When the parallel function body has side effects, the order of reads and
+    writes to memory is unspecified across threads.
+
+    Example:
+    ```
+    //
+    // Sequential context.
+    //
+    %matmul_and_pointwise:2 = scf.foreach_thread (%thread_id_1, %thread_id_2) in
+         (%num_threads_1, %num_threads_2) -> (tensor<?x?xT>, tensor<?xT>) {
+      //
+      // Parallel context, each thread with id = (%thread_id_1, %thread_id_2)
+      // runs its version of the code.
+      //
+      %sA = tensor.extract_slice %A[f((%thread_id_1, %thread_id_2))]:
+        tensor<?x?xT> to tensor<?x?xT>
+      %sB = tensor.extract_slice %B[g((%thread_id_1, %thread_id_2))]:
+        tensor<?x?xT> to tensor<?x?xT>
+      %sC = tensor.extract_slice %C[h((%thread_id_1, %thread_id_2))]:
+        tensor<?x?xT> to tensor<?x?xT>
+      %sD = matmul ins(%sA, %sB) outs(%sC)
+
+      %spointwise = tensor.extract_slice %pointwise[i((%thread_id_1, %thread_id_2))]:
+        tensor<?xT> to tensor<?xT>
+      %sE = add ins(%spointwise) outs(%sD)
+
+      scf.foreach_thread.perform_concurrently {
+        // First op within the parallel terminator contributes to producing %matmul_and_pointwise#0.
+        scf.foreach_thread.parallel_insert_slice %sD into %C[h((%thread_id_1, %thread_id_2))]:
+          tensor<?x?xT> into tensor<?x?xT>
+
+        // Second op within the parallel terminator contributes to producing %matmul_and_pointwise#1.
+        scf.foreach_thread.parallel_insert_slice %spointwise into %pointwise[i((%thread_id_1, %thread_id_2))]:
+          tensor<?xT> into tensor<?xT>
+      }
+    }
+    // Implicit synchronization point.
+    // Sequential context.
+    //
+```
+
+  }];
+  let arguments = (ins Variadic<Index>:$num_threads);
+
+  let results = (outs Variadic<AnyType>:$results);
+  let regions = (region SizedRegion<1>:$region);
+
+  let hasCustomAssemblyFormat = 1;
+  let hasVerifier = 1;
+
+  // The default builder does not add the proper body BBargs, roll our own.
+  let skipDefaultBuilders = 1;
+  let builders = [
+    // Bodyless builder, result types must be specified.
+    OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$num_threads)>,
+    // Builder that takes a bodyBuilder lambda, result types are inferred from
+    // the terminator.
+    OpBuilder<(ins "ValueRange":$num_threads,
+              "function_ref<void(OpBuilder &, Location, ValueRange)>":$bodyBuilder)>
+  ];
+  let extraClassDeclaration = [{
+    int64_t getRank() { return getNumThreads().size(); }
+    ValueRange getThreadIndices() { return getBody()->getArguments(); }
+    Value getThreadIndex(int64_t idx) { return getBody()->getArgument(idx); }
+
+    static void ensureTerminator(Region &region, Builder &builder, Location loc);
+
+    PerformConcurrentlyOp getTerminator();
+  }];
+}
+
+def PerformConcurrentlyOp : SCF_Op<"foreach_thread.perform_concurrently", [
+       NoSideEffect,
+       Terminator,
+       SingleBlockImplicitTerminator<"scf::EndPerformConcurrentlyOp">,
+       HasParent<"ForeachThreadOp">,
+      ]> {
+  let summary = "terminates a `foreach_thread` block";
+  let description = [{
+    `scf.foreach_thread.perform_concurrently` is a designated terminator for
+    the `scf.foreach_thread` operation.
+
+    It has a single region with a single block that contains a flat list of ops.
+    Each such op participates in the aggregate formation of a single result of
+    the enclosing `scf.foreach_thread`.
+    The result number corresponds to the position of the op in the terminator.
+  }];
+
+  let regions = (region SizedRegion<1>:$region);
+
+  let hasCustomAssemblyFormat = 1;
+  let hasVerifier = 1;
+
+  // TODO: Add a `PerformConcurrentlyOpInterface` interface for ops that can
+  // appear inside perform_concurrently.
+  let extraClassDeclaration = [{
+    SmallVector<Type> yieldedTypes();
+    SmallVector<ParallelInsertSliceOp> yieldingOps();
+  }];
+}
+
+def EndPerformConcurrentlyOp : SCF_Op<"foreach_thread.end_perform_concurrently", [
+       NoSideEffect, Terminator, HasParent<"PerformConcurrentlyOp">]> {
+  let summary = "terminates a `foreach_thread.perform_concurrently` block";
+  let description = [{
+    A designated terminator for `foreach_thread.perform_concurrently`.
+    It is not expected to appear in the textual form of the IR.
+  }];
+}
+
+// TODO: Implement PerformConcurrentlyOpInterface.
+def ParallelInsertSliceOp : SCF_Op<"foreach_thread.parallel_insert_slice", [
+       AttrSizedOperandSegments,
+       OffsetSizeAndStrideOpInterface,
+       // PerformConcurrentlyOpInterface,
+       HasParent<"PerformConcurrentlyOp">]> {
+  let summary = [{
+    Specify the tensor slice update of a single thread within the terminator of
+    an `scf.foreach_thread`.
+  }];
+  let description = [{
+    The parent `scf.foreach_thread` returns values that are formed by aggregating
+    the actions of all the ops contained within the `perform_concurrently`
+    terminator of all the threads, in some unspecified order.
+    The `scf.foreach_thread.parallel_insert_slice` is one such op allowed in
+    the `scf.foreach_thread.perform_concurrently` terminator.
+
+    Conflicting writes result in undefined semantics, in that the indices written
+    to by multiple parallel updates might contain data from any of the updates, or
+    even a malformed bit pattern.
+
+    If an index is updated by exactly one update, the value contained at that index
+    in the resulting tensor will be equal to the value at the corresponding index of
+    the slice that was used for that update. If an index is not updated at all, its
+    value will be equal to the one in the original tensor.
+
+    This op does not create a new value, which allows maintaining a clean
+    separation between the subset and full tensor.
+    Note that we cannot mark this operation as pure (`NoSideEffect`), even
+    though it has no side effects, because it would then get DCEd during
+    canonicalization.
+  }];
+
+  let arguments = (ins
+    AnyRankedTensor:$source,
+    AnyRankedTensor:$dest,
+    Variadic<Index>:$offsets,
+    Variadic<Index>:$sizes,
+    Variadic<Index>:$strides,
+    I64ArrayAttr:$static_offsets,
+    I64ArrayAttr:$static_sizes,
+    I64ArrayAttr:$static_strides
+  );
+  let assemblyFormat = [{
+    $source `into` $dest ``
+    custom<OperandsOrIntegersOffsetsOrStridesList>($offsets, $static_offsets)
+    custom<OperandsOrIntegersSizesList>($sizes, $static_sizes)
+    custom<OperandsOrIntegersOffsetsOrStridesList>($strides, $static_strides)
+    attr-dict `:` type($source) `into` type($dest)
+  }];
+
+  let extraClassDeclaration = [{
+    ::mlir::Operation::operand_range offsets() { return getOffsets(); }
+    ::mlir::Operation::operand_range sizes() { return getSizes(); }
+    ::mlir::Operation::operand_range strides() { return getStrides(); }
+    ArrayAttr static_offsets() { return getStaticOffsets(); }
+    ArrayAttr static_sizes() { return getStaticSizes(); }
+    ArrayAttr static_strides() { return getStaticStrides(); }
+
+    Type yieldedType() { return getDest().getType(); }
+
+    RankedTensorType getSourceType() {
+      return getSource().getType().cast<RankedTensorType>();
+    }
+
+    /// Return the expected rank of each of the `static_offsets`, `static_sizes`
+    /// and `static_strides` attributes.
+    std::array<unsigned, 3> getArrayAttrMaxRanks() {
+      unsigned rank = getSourceType().getRank();
+      return {rank, rank, rank};
+    }
+
+    /// Return the number of leading operands before `offsets`, `sizes` and
+    /// `strides` operands. Both `source` and `dest` precede them.
+    static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 2; }
+  }];
+
+  let builders = [
+    // Build a ParallelInsertSliceOp with mixed static and dynamic entries.
+    OpBuilder<(ins "Value":$source, "Value":$dest,
+      "ArrayRef<OpFoldResult>":$offsets, "ArrayRef<OpFoldResult>":$sizes,
+      "ArrayRef<OpFoldResult>":$strides,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+
+    // Build a ParallelInsertSliceOp with dynamic entries.
+    OpBuilder<(ins "Value":$source, "Value":$dest,
+      "ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
+  ];
+
+  // let hasCanonicalizer = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// IfOp
+//===----------------------------------------------------------------------===//
+
 def IfOp : SCF_Op<"if",
       [DeclareOpInterfaceMethods<RegionBranchOpInterface,
                                  ["getNumRegionInvocations",

diff  --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index 9d3e61e69d9a5..4a682952c853c 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -1044,6 +1045,286 @@ void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
               LastTensorLoadCanonicalization, ForOpTensorCastFolder>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// ForeachThreadOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies that:
+///   * the body block has exactly `getRank()` arguments, all of `index` type;
+///   * the op has as many results as its terminator yields types, and the
+///     types match pairwise.
+LogicalResult ForeachThreadOp::verify() {
+  // Check that the body defines one block argument per thread dimension.
+  auto *body = getBody();
+  if (body->getNumArguments() != getRank())
+    return emitOpError("region expects ") << getRank() << " arguments";
+  if (!llvm::all_of(body->getArgumentTypes(),
+                    [](Type t) { return t.isIndex(); }))
+    return emitOpError(
+        "expected all region arguments to be of index type `index`");
+
+  // Verify consistency between the result types and the terminator.
+  auto terminatorTypes = getTerminator().yieldedTypes();
+  auto opResults = getResults();
+  if (opResults.size() != terminatorTypes.size())
+    return emitOpError("produces ")
+           << opResults.size() << " results, but its terminator yields "
+           << terminatorTypes.size() << " values";
+  unsigned i = 0;
+  for (auto e : llvm::zip(terminatorTypes, opResults)) {
+    // std::get<0>(e) is the type yielded by the terminator and std::get<1>(e)
+    // is the op result: print each one next to the label that describes it.
+    if (std::get<0>(e) != std::get<1>(e).getType())
+      return emitOpError() << "type mismatch between " << i
+                           << "th result of foreach_thread ("
+                           << std::get<1>(e).getType() << ") and " << i
+                           << "th result yielded by its "
+                           << "terminator (" << std::get<0>(e) << ")";
+    i++;
+  }
+  return success();
+}
+
+/// Custom printer. Emits, in order: the parenthesized thread index block
+/// arguments, `in` followed by the parenthesized `num_threads` operands,
+/// `-> (` result types `)`, the body region, and the optional attribute
+/// dictionary.
+void ForeachThreadOp::print(OpAsmPrinter &p) {
+  p << '(';
+  llvm::interleaveComma(getThreadIndices(), p);
+  p << ") in (";
+  llvm::interleaveComma(getNumThreads(), p);
+  p << ") -> (" << getResultTypes() << ") ";
+  // The entry block arguments were already printed above as the thread
+  // indices. The terminator is only printed when the op has results.
+  p.printRegion(getRegion(),
+                /*printEntryBlockArgs=*/false,
+                /*printBlockTerminators=*/getNumResults() > 0);
+  p.printOptionalAttrDict(getOperation()->getAttrs());
+}
+
+/// Custom parser, the counterpart of `ForeachThreadOp::print`. Parses the
+/// thread index arguments, the `in` operand list, an optional arrow type
+/// list, the body region and an optional attribute dictionary, in that order.
+ParseResult ForeachThreadOp::parse(OpAsmParser &parser,
+                                   OperationState &result) {
+  auto &builder = parser.getBuilder();
+  // Parse an opening `(` followed by thread index variables followed by `)`
+  SmallVector<OpAsmParser::Argument, 4> threadIndices;
+  if (parser.parseArgumentList(threadIndices, OpAsmParser::Delimiter::Paren))
+    return failure();
+
+  // Parse `in` threadNums.
+  // The number of thread indices is passed as the operand count, and every
+  // operand is resolved against the `index` type.
+  SmallVector<OpAsmParser::UnresolvedOperand, 4> threadNums;
+  if (parser.parseKeyword("in") ||
+      parser.parseOperandList(threadNums, threadIndices.size(),
+                              OpAsmParser::Delimiter::Paren) ||
+      parser.resolveOperands(threadNums, builder.getIndexType(),
+                             result.operands))
+    return failure();
+
+  // Parse optional results.
+  if (parser.parseOptionalArrowTypeList(result.types))
+    return failure();
+
+  // Parse region.
+  // The thread indices carry no type in the textual form; they are all given
+  // the `index` type before being used as the region's entry block arguments.
+  std::unique_ptr<Region> region = std::make_unique<Region>();
+  for (auto &idx : threadIndices)
+    idx.type = builder.getIndexType();
+  if (parser.parseRegion(*region, threadIndices))
+    return failure();
+
+  // Ensure terminator and move region.
+  ForeachThreadOp::ensureTerminator(*region, builder, result.location);
+  result.addRegion(std::move(region));
+
+  // Parse the optional attribute list.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
+  return success();
+}
+
+// Bodyless builder, result types must be specified.
+// Creates the body block with one `index` argument per entry of `numThreads`
+// (all located at `result.location`) and ensures the block is terminated.
+void ForeachThreadOp::build(mlir::OpBuilder &builder,
+                            mlir::OperationState &result, TypeRange resultTypes,
+                            ValueRange numThreads) {
+  result.addOperands(numThreads);
+
+  Region *bodyRegion = result.addRegion();
+  bodyRegion->push_back(new Block);
+  Block &bodyBlock = bodyRegion->front();
+  bodyBlock.addArguments(
+      SmallVector<Type>(numThreads.size(), builder.getIndexType()),
+      SmallVector<Location>(numThreads.size(), result.location));
+  ForeachThreadOp::ensureTerminator(*bodyRegion, builder, result.location);
+  result.addTypes(resultTypes);
+}
+
+// Builder that takes a bodyBuilder lambda, result types are inferred from
+// the terminator.
+// `bodyBuilder` is invoked with the insertion point at the start of the body
+// block and receives all the thread index block arguments (one per entry of
+// `numThreads`). It must create the `PerformConcurrentlyOp` terminator, since
+// the result types are read back from it.
+void ForeachThreadOp::build(
+    mlir::OpBuilder &builder, mlir::OperationState &result,
+    ValueRange numThreads,
+    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
+  result.addOperands(numThreads);
+
+  Region *bodyRegion = result.addRegion();
+  bodyRegion->push_back(new Block);
+  Block &bodyBlock = bodyRegion->front();
+  bodyBlock.addArguments(
+      SmallVector<Type>(numThreads.size(), builder.getIndexType()),
+      SmallVector<Location>(numThreads.size(), result.location));
+
+  OpBuilder::InsertionGuard guard(builder);
+  builder.setInsertionPointToStart(&bodyBlock);
+  // Forward every thread index, not just the first one: the op is
+  // multi-dimensional and may also have rank 0.
+  bodyBuilder(builder, result.location, bodyBlock.getArguments());
+  auto terminator =
+      llvm::cast<PerformConcurrentlyOp>(bodyBlock.getTerminator());
+  result.addTypes(terminator.yieldedTypes());
+}
+
+// The ensureTerminator method generated by SingleBlockImplicitTerminator is
+// unaware of the fact that our terminator also needs a region to be
+// well-formed. We override it here to ensure that we do the right thing.
+void ForeachThreadOp::ensureTerminator(Region &region, Builder &builder,
+                                       Location loc) {
+  OpTrait::SingleBlockImplicitTerminator<PerformConcurrentlyOp>::Impl<
+      ForeachThreadOp>::ensureTerminator(region, builder, loc);
+  auto terminator =
+      llvm::dyn_cast<PerformConcurrentlyOp>(region.front().getTerminator());
+  // The block may already end with a terminator of another kind (e.g. when
+  // parsing invalid IR). Leave it untouched instead of dereferencing a null
+  // op, so that verification can report the problem.
+  if (!terminator)
+    return;
+  PerformConcurrentlyOp::ensureTerminator(terminator.getRegion(), builder, loc);
+}
+
+/// Returns the terminator of the body block, cast to `PerformConcurrentlyOp`.
+PerformConcurrentlyOp ForeachThreadOp::getTerminator() {
+  return cast<PerformConcurrentlyOp>(getBody()->getTerminator());
+}
+
+//===----------------------------------------------------------------------===//
+// ParallelInsertSliceOp
+//===----------------------------------------------------------------------===//
+
+// Build a ParallelInsertSliceOp with mixed static and dynamic entries.
+// NOTE(review): `dispatchIndexOpFoldResults` presumably splits each
+// OpFoldResult into either a dynamic Value or a static integer, recording the
+// given sentinel for dynamic entries -- confirm against its declaration.
+void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
+                                  Value source, Value dest,
+                                  ArrayRef<OpFoldResult> offsets,
+                                  ArrayRef<OpFoldResult> sizes,
+                                  ArrayRef<OpFoldResult> strides,
+                                  ArrayRef<NamedAttribute> attrs) {
+  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
+  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
+  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
+                             ShapedType::kDynamicStrideOrOffset);
+  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
+                             ShapedType::kDynamicSize);
+  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
+                             ShapedType::kDynamicStrideOrOffset);
+  // The op has no results, hence the empty result type list.
+  build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
+        dynamicStrides, b.getI64ArrayAttr(staticOffsets),
+        b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
+  result.addAttributes(attrs);
+}
+
+// Build a ParallelInsertSliceOp with dynamic entries.
+// Each offset/size/stride Value is wrapped into an OpFoldResult and the
+// construction is delegated to the mixed static/dynamic builder, forwarding
+// `attrs` so that caller-provided attributes are not dropped.
+void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
+                                  Value source, Value dest, ValueRange offsets,
+                                  ValueRange sizes, ValueRange strides,
+                                  ArrayRef<NamedAttribute> attrs) {
+  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
+      llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
+  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
+      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
+  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
+      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
+  build(b, result, source, dest, offsetValues, sizeValues, strideValues,
+        attrs);
+}
+
+// namespace {
+// /// Pattern to rewrite a parallel_insert_slice op with constant arguments.
+// class ParallelInsertSliceOpConstantArgumentFolder final
+//     : public OpRewritePattern<ParallelInsertSliceOp> {
+// public:
+//   using OpRewritePattern<ParallelInsertSliceOp>::OpRewritePattern;
+
+//   LogicalResult matchAndRewrite(ParallelInsertSliceOp insertSliceOp,
+//                                 PatternRewriter &rewriter) const override {
+//     // No constant operand, just return.
+//     if (llvm::none_of(insertSliceOp.getOperands(), [](Value operand) {
+//           return matchPattern(operand, matchConstantIndex());
+//         }))
+//       return failure();
+
+//     // At least one of offsets/sizes/strides is a new constant.
+//     // Form the new list of operands and constant attributes from the
+//     // existing.
+//     SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
+//     SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
+//     SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
+//     canonicalizeSubViewPart(mixedOffsets,
+//     ShapedType::isDynamicStrideOrOffset); canonicalizeSubViewPart(mixedSizes,
+//     ShapedType::isDynamic); canonicalizeSubViewPart(mixedStrides,
+//     ShapedType::isDynamicStrideOrOffset);
+
+//     // Create the new op in canonical form.
+//     rewriter.replaceOpWithNewOp<ParallelInsertSliceOp>(
+//         insertSliceOp, insertSliceOp.source(), insertSliceOp.dest(),
+//         mixedOffsets, mixedSizes, mixedStrides);
+//     return success();
+//   }
+// };
+// } // namespace
+
+// void ParallelInsertSliceOp::getCanonicalizationPatterns(
+//     RewritePatternSet &results, MLIRContext *context) {
+//   results.add<ParallelInsertSliceOpConstantArgumentFolder>(context);
+// }
+
+//===----------------------------------------------------------------------===//
+// PerformConcurrentlyOp
+//===----------------------------------------------------------------------===//
+
+/// Checks that the single block of the region only holds
+/// `ParallelInsertSliceOp`s and the `EndPerformConcurrentlyOp` terminator.
+LogicalResult PerformConcurrentlyOp::verify() {
+  // TODO: PerformConcurrentlyOpInterface.
+  Block &block = getRegion().front();
+  bool onlyAllowedOps = llvm::all_of(block, [](Operation &nested) {
+    return isa<ParallelInsertSliceOp, EndPerformConcurrentlyOp>(nested);
+  });
+  if (onlyAllowedOps)
+    return success();
+  return emitOpError(
+      "expected only scf.foreach_thread.parallel_insert_slice ops");
+}
+
+/// Custom printer: a space, the region (printed without entry block arguments
+/// and without its block terminator), then the optional attribute dictionary.
+void PerformConcurrentlyOp::print(OpAsmPrinter &p) {
+  p << " ";
+  p.printRegion(getRegion(),
+                /*printEntryBlockArgs=*/false,
+                /*printBlockTerminators=*/false);
+  p.printOptionalAttrDict(getOperation()->getAttrs());
+}
+
+/// Custom parser, the counterpart of `PerformConcurrentlyOp::print`: parses
+/// the region, ensures it is terminated, then parses the optional attribute
+/// dictionary.
+ParseResult PerformConcurrentlyOp::parse(OpAsmParser &parser,
+                                         OperationState &result) {
+  auto &builder = parser.getBuilder();
+
+  // `regionOperands` is left empty: the region is parsed without any entry
+  // block arguments.
+  SmallVector<OpAsmParser::Argument, 8> regionOperands;
+  std::unique_ptr<Region> region = std::make_unique<Region>();
+  if (parser.parseRegion(*region, regionOperands))
+    return failure();
+
+  // The terminator is not printed by `print`, so it is re-created here when
+  // absent.
+  PerformConcurrentlyOp::ensureTerminator(*region, builder, result.location);
+  result.addRegion(std::move(region));
+
+  // Parse the optional attribute list.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+  return success();
+}
+
+/// Returns the `yieldedType()` of every op returned by `yieldingOps()`, in the
+/// same order.
+SmallVector<Type> PerformConcurrentlyOp::yieldedTypes() {
+  SmallVector<Type> types;
+  for (ParallelInsertSliceOp sliceOp : this->yieldingOps())
+    types.push_back(sliceOp.yieldedType());
+  return types;
+}
+
+/// Returns the `ParallelInsertSliceOp`s contained in the body, in block order.
+/// The `EndPerformConcurrentlyOp` terminator is skipped; any other operation
+/// is treated as unreachable.
+SmallVector<ParallelInsertSliceOp> PerformConcurrentlyOp::yieldingOps() {
+  SmallVector<ParallelInsertSliceOp> ret;
+  for (Operation &op : *getBody()) {
+    // TODO: PerformConcurrentlyOpInterface interface when this grows up.
+    if (auto sliceOp = llvm::dyn_cast<ParallelInsertSliceOp>(op)) {
+      ret.push_back(sliceOp);
+      continue;
+    }
+    // Only the kind of the op matters here, so use `isa` rather than binding
+    // an unused `dyn_cast` result.
+    if (llvm::isa<EndPerformConcurrentlyOp>(op))
+      continue;
+    llvm_unreachable("Unexpected operation in perform_concurrently");
+  }
+  return ret;
+}
+
 //===----------------------------------------------------------------------===//
 // IfOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
index 402d1b67a2630..29c4545047d9d 100644
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -520,3 +520,13 @@ func.func @execute_region() {
   }) : () -> ()
   return
 }
+
+// -----
+
+func.func @wrong_number_of_arguments() -> () {
+  %num_threads = arith.constant 100 : index
+  // expected-error @+1 {{region expects 2 arguments}}
+  scf.foreach_thread (%thread_idx) in (%num_threads, %num_threads) -> () {
+  }
+  return
+}

diff  --git a/mlir/test/Dialect/SCF/ops.mlir b/mlir/test/Dialect/SCF/ops.mlir
index b732b1ede38de..ac43844187569 100644
--- a/mlir/test/Dialect/SCF/ops.mlir
+++ b/mlir/test/Dialect/SCF/ops.mlir
@@ -310,3 +310,50 @@ func.func @execute_region() -> i64 {
   }) : () -> ()
   return %res : i64
 }
+
+// CHECK-LABEL: func.func @simple_example
+func.func @simple_example(%in: tensor<100xf32>, %out: tensor<100xf32>) {
+  %c1 = arith.constant 1 : index
+  %num_threads = arith.constant 100 : index
+
+  //      CHECK:    scf.foreach_thread
+  // CHECK-NEXT:  tensor.extract_slice
+  // CHECK-NEXT:  scf.foreach_thread.perform_concurrently
+  // CHECK-NEXT:  scf.foreach_thread.parallel_insert_slice
+  // CHECK-NEXT:  }
+  // CHECK-NEXT:  }
+  // CHECK-NEXT:  return
+  %result = scf.foreach_thread (%thread_idx) in (%num_threads) -> tensor<100xf32> {
+      %1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32>
+      scf.foreach_thread.perform_concurrently {
+        scf.foreach_thread.parallel_insert_slice %1 into %out[%thread_idx][1][1] :
+          tensor<1xf32> into tensor<100xf32>
+      }
+  }
+  return
+}
+
+// CHECK-LABEL: func.func @elide_terminator
+func.func @elide_terminator() -> () {
+  %num_threads = arith.constant 100 : index
+
+  //      CHECK:    scf.foreach_thread
+  // CHECK-NEXT:  }
+  // CHECK-NEXT:  return
+  scf.foreach_thread (%thread_idx) in (%num_threads) -> () {
+    scf.foreach_thread.perform_concurrently {
+    }
+  }
+  return
+}
+
+// CHECK-LABEL: func.func @no_terminator
+func.func @no_terminator() -> () {
+  %num_threads = arith.constant 100 : index
+  //      CHECK:    scf.foreach_thread
+  // CHECK-NEXT:  }
+  // CHECK-NEXT:  return
+  scf.foreach_thread (%thread_idx) in (%num_threads) -> () {
+  }
+  return
+}


        


More information about the Mlir-commits mailing list