[Mlir-commits] [mlir] 9b7193f - [mlir][SCF] Add parallel abstraction on tensors.
Nicolas Vasilache
llvmlistbot at llvm.org
Wed Jun 1 02:02:22 PDT 2022
Author: Nicolas Vasilache
Date: 2022-06-01T09:02:16Z
New Revision: 9b7193f852874a035d9ede1f7464c9fc5b7dca7a
URL: https://github.com/llvm/llvm-project/commit/9b7193f852874a035d9ede1f7464c9fc5b7dca7a
DIFF: https://github.com/llvm/llvm-project/commit/9b7193f852874a035d9ede1f7464c9fc5b7dca7a.diff
LOG: [mlir][SCF] Add parallel abstraction on tensors.
This revision adds `scf.foreach_thread` and supporting abstractions
that connect parallel execution with tensors.
Discussion is available [here](https://discourse.llvm.org/t/rfc-parallel-abstraction-for-tensors-and-buffers/62607).
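As a quick illustration, here is a minimal use of the new op adapted from the test added in this revision (`%in`, `%out` and `%num_threads` are assumed to be defined in the enclosing function; the shapes are illustrative only):
```
// Each thread extracts its own slice of %in; the terminator then describes
// how all per-thread slices are concurrently inserted into the result tensor.
%result = scf.foreach_thread (%thread_idx) in (%num_threads) -> tensor<100xf32> {
  %slice = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32>
  scf.foreach_thread.perform_concurrently {
    scf.foreach_thread.parallel_insert_slice %slice into %out[%thread_idx][1][1] :
      tensor<1xf32> into tensor<100xf32>
  }
}
```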
Added:
Modified:
mlir/include/mlir/Dialect/SCF/SCF.h
mlir/include/mlir/Dialect/SCF/SCFOps.td
mlir/lib/Dialect/SCF/SCF.cpp
mlir/test/Dialect/SCF/invalid.mlir
mlir/test/Dialect/SCF/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/SCF.h b/mlir/include/mlir/Dialect/SCF/SCF.h
index dbe505ccaa74d..dab58cb0b963c 100644
--- a/mlir/include/mlir/Dialect/SCF/SCF.h
+++ b/mlir/include/mlir/Dialect/SCF/SCF.h
@@ -14,9 +14,11 @@
#define MLIR_DIALECT_SCF_SCF_H
#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
namespace mlir {
namespace scf {
diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
index 887b8323f2e6b..89e156223d08e 100644
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -16,6 +16,7 @@
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/ViewLikeInterface.td"
def SCF_Dialect : Dialect {
let name = "scf";
@@ -312,6 +313,245 @@ def ForOp : SCF_Op<"for",
let hasRegionVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// ForeachThreadOp
+//===----------------------------------------------------------------------===//
+
+def ForeachThreadOp : SCF_Op<"foreach_thread", [
+ SingleBlockImplicitTerminator<"scf::PerformConcurrentlyOp">,
+ RecursiveSideEffects,
+ AutomaticAllocationScope,
+ ]> {
+ let summary = "evaluate a block multiple times in parallel";
+ let description = [{
+ `scf.foreach_thread` is a target-independent multi-dimensional parallel
+ function application operation. It has exactly one block that represents the
+ parallel function body and it takes index operands that indicate how many
+ parallel instances of that function are instantiated.
+
+ The only allowed terminator is `scf.foreach_thread.perform_concurrently`,
+ which dictates how the partial results of all parallel invocations should be
+ reconciled into a full value.
+
+    `scf.foreach_thread` returns values that are formed by aggregating the
+    actions of the `perform_concurrently` terminators of all the threads,
+    in some unspecified order.
+    In other words, `scf.foreach_thread` performs all actions specified in the
+    `perform_concurrently` terminator after it receives control back from
+    its body on each thread.
+
+ `scf.foreach_thread` acts as an implicit synchronization point.
+
+ Multi-value returns are encoded by including multiple operations inside the
+ `perform_concurrently` block.
+
+ When the parallel function body has side effects, the order of reads and
+ writes to memory is unspecified across threads.
+
+ Example:
+ ```
+ //
+ // Sequential context.
+ //
+ %matmul_and_pointwise:2 = scf.foreach_thread (%thread_id_1, %thread_id_2) in
+        (%num_threads_1, %num_threads_2) -> (tensor<?x?xT>, tensor<?xT>) {
+ //
+ // Parallel context, each thread with id = (%thread_id_1, %thread_id_2)
+ // runs its version of the code.
+ //
+ %sA = tensor.extract_slice %A[f((%thread_id_1, %thread_id_2))]:
+ tensor<?x?xT> to tensor<?x?xT>
+ %sB = tensor.extract_slice %B[g((%thread_id_1, %thread_id_2))]:
+ tensor<?x?xT> to tensor<?x?xT>
+ %sC = tensor.extract_slice %C[h((%thread_id_1, %thread_id_2))]:
+ tensor<?x?xT> to tensor<?x?xT>
+ %sD = matmul ins(%sA, %sB) outs(%sC)
+
+      %spointwise = tensor.extract_slice %pointwise[i((%thread_id_1, %thread_id_2))]:
+ tensor<?xT> to tensor<?xT>
+ %sE = add ins(%spointwise) outs(%sD)
+
+ scf.foreach_thread.perform_concurrently {
+ // First op within the parallel terminator contributes to producing %matmul_and_pointwise#0.
+ scf.foreach_thread.parallel_insert_slice %sD into %C[h((%thread_id_1, %thread_id_2))]:
+ tensor<?x?xT> into tensor<?x?xT>
+
+ // Second op within the parallel terminator contributes to producing %matmul_and_pointwise#1.
+ scf.foreach_thread.parallel_insert_slice %spointwise into %pointwise[i((%thread_id_1, %thread_id_2))]:
+ tensor<?xT> into tensor<?xT>
+ }
+ }
+ // Implicit synchronization point.
+ // Sequential context.
+ //
+```
+
+ }];
+ let arguments = (ins Variadic<Index>:$num_threads);
+
+ let results = (outs Variadic<AnyType>:$results);
+ let regions = (region SizedRegion<1>:$region);
+
+ let hasCustomAssemblyFormat = 1;
+ let hasVerifier = 1;
+
+ // The default builder does not add the proper body BBargs, roll our own.
+ let skipDefaultBuilders = 1;
+ let builders = [
+ // Bodyless builder, result types must be specified.
+ OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$num_threads)>,
+ // Builder that takes a bodyBuilder lambda, result types are inferred from
+ // the terminator.
+ OpBuilder<(ins "ValueRange":$num_threads,
+ "function_ref<void(OpBuilder &, Location, ValueRange)>":$bodyBuilder)>
+ ];
+ let extraClassDeclaration = [{
+ int64_t getRank() { return getNumThreads().size(); }
+ ValueRange getThreadIndices() { return getBody()->getArguments(); }
+ Value getThreadIndex(int64_t idx) { return getBody()->getArgument(idx); }
+
+    static void ensureTerminator(Region &region, Builder &builder, Location loc);
+
+ PerformConcurrentlyOp getTerminator();
+ }];
+}
+
+def PerformConcurrentlyOp : SCF_Op<"foreach_thread.perform_concurrently", [
+ NoSideEffect,
+ Terminator,
+ SingleBlockImplicitTerminator<"scf::EndPerformConcurrentlyOp">,
+ HasParent<"ForeachThreadOp">,
+ ]> {
+ let summary = "terminates a `foreach_thread` block";
+ let description = [{
+ `scf.foreach_thread.perform_concurrently` is a designated terminator for
+ the `scf.foreach_thread` operation.
+
+ It has a single region with a single block that contains a flat list of ops.
+ Each such op participates in the aggregate formation of a single result of
+ the enclosing `scf.foreach_thread`.
+ The result number corresponds to the position of the op in the terminator.
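+
+    For illustration only (the value names, offsets and tensor shapes below
+    are made up), the terminator of a two-result `scf.foreach_thread` could
+    look like:
+
+    ```
+    scf.foreach_thread.perform_concurrently {
+      // This op contributes to result #0 of the enclosing foreach_thread.
+      scf.foreach_thread.parallel_insert_slice %sA into %A[%idx][1][1] :
+        tensor<1xf32> into tensor<100xf32>
+      // This op contributes to result #1 of the enclosing foreach_thread.
+      scf.foreach_thread.parallel_insert_slice %sB into %B[%idx][4][1] :
+        tensor<4xf32> into tensor<400xf32>
+    }
+    ```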
+ }];
+
+ let regions = (region SizedRegion<1>:$region);
+
+ let hasCustomAssemblyFormat = 1;
+ let hasVerifier = 1;
+
+ // TODO: Add a `PerformConcurrentlyOpInterface` interface for ops that can
+ // appear inside perform_concurrently.
+ let extraClassDeclaration = [{
+ SmallVector<Type> yieldedTypes();
+ SmallVector<ParallelInsertSliceOp> yieldingOps();
+ }];
+}
+
+def EndPerformConcurrentlyOp : SCF_Op<"foreach_thread.end_perform_concurrently", [
+ NoSideEffect, Terminator, HasParent<"PerformConcurrentlyOp">]> {
+ let summary = "terminates a `foreach_thread.perform_concurrently` block";
+ let description = [{
+ A designated terminator for `foreach_thread.perform_concurrently`.
+ It is not expected to appear in the textual form of the IR.
+ }];
+}
+
+// TODO: Implement PerformConcurrentlyOpInterface.
+def ParallelInsertSliceOp : SCF_Op<"foreach_thread.parallel_insert_slice", [
+ AttrSizedOperandSegments,
+ OffsetSizeAndStrideOpInterface,
+ // PerformConcurrentlyOpInterface,
+ HasParent<"PerformConcurrentlyOp">]> {
+ let summary = [{
+ Specify the tensor slice update of a single thread within the terminator of
+ an `scf.foreach_thread`.
+ }];
+ let description = [{
+ The parent `scf.foreach_thread` returns values that are formed by aggregating
+ the actions of all the ops contained within the `perform_concurrently`
+    terminators of all the threads, in some unspecified order.
+ The `scf.foreach_thread.parallel_insert_slice` is one such op allowed in
+ the `scf.foreach_thread.perform_concurrently` terminator.
+
+ Conflicting writes result in undefined semantics, in that the indices written
+ to by multiple parallel updates might contain data from any of the updates, or
+ even a malformed bit pattern.
+
+    If an index is updated by exactly one update, the value contained at that index
+    in the resulting tensor will be equal to the value at the corresponding index of
+    the slice that was used for that update. If an index is not updated at all, its
+    value will be equal to the one in the original tensor.
+
+ This op does not create a new value, which allows maintaining a clean
+ separation between the subset and full tensor.
+    Note that we cannot mark this operation as pure (NoSideEffect): even
+    though it has no observable side effects, it produces no results and would
+    therefore be DCE'd during canonicalization if it were marked pure.
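+
+    For illustration only (names, offset and shapes are made up), a single
+    unit-stride update performed by one thread could be written as:
+
+    ```
+    scf.foreach_thread.parallel_insert_slice %thread_slice into %dest[%offset][16][1] :
+      tensor<16xf32> into tensor<128xf32>
+    ```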
+ }];
+
+ let arguments = (ins
+ AnyRankedTensor:$source,
+ AnyRankedTensor:$dest,
+ Variadic<Index>:$offsets,
+ Variadic<Index>:$sizes,
+ Variadic<Index>:$strides,
+ I64ArrayAttr:$static_offsets,
+ I64ArrayAttr:$static_sizes,
+ I64ArrayAttr:$static_strides
+ );
+ let assemblyFormat = [{
+ $source `into` $dest ``
+ custom<OperandsOrIntegersOffsetsOrStridesList>($offsets, $static_offsets)
+ custom<OperandsOrIntegersSizesList>($sizes, $static_sizes)
+ custom<OperandsOrIntegersOffsetsOrStridesList>($strides, $static_strides)
+ attr-dict `:` type($source) `into` type($dest)
+ }];
+
+ let extraClassDeclaration = [{
+ ::mlir::Operation::operand_range offsets() { return getOffsets(); }
+ ::mlir::Operation::operand_range sizes() { return getSizes(); }
+ ::mlir::Operation::operand_range strides() { return getStrides(); }
+ ArrayAttr static_offsets() { return getStaticOffsets(); }
+ ArrayAttr static_sizes() { return getStaticSizes(); }
+ ArrayAttr static_strides() { return getStaticStrides(); }
+
+ Type yieldedType() { return getDest().getType(); }
+
+ RankedTensorType getSourceType() {
+ return getSource().getType().cast<RankedTensorType>();
+ }
+
+ /// Return the expected rank of each of the `static_offsets`, `static_sizes`
+ /// and `static_strides` attributes.
+ std::array<unsigned, 3> getArrayAttrMaxRanks() {
+ unsigned rank = getSourceType().getRank();
+ return {rank, rank, rank};
+ }
+
+ /// Return the number of leading operands before `offsets`, `sizes` and
+ /// `strides` operands.
+ static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 1; }
+ }];
+
+ let builders = [
+ // Build a ParallelInsertSliceOp with mixed static and dynamic entries.
+ OpBuilder<(ins "Value":$source, "Value":$dest,
+ "ArrayRef<OpFoldResult>":$offsets, "ArrayRef<OpFoldResult>":$sizes,
+ "ArrayRef<OpFoldResult>":$strides,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+
+ // Build a ParallelInsertSliceOp with dynamic entries.
+ OpBuilder<(ins "Value":$source, "Value":$dest,
+ "ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
+ ];
+
+ // let hasCanonicalizer = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// IfOp
+//===----------------------------------------------------------------------===//
+
def IfOp : SCF_Op<"if",
[DeclareOpInterfaceMethods<RegionBranchOpInterface,
["getNumRegionInvocations",
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index 9d3e61e69d9a5..4a682952c853c 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -8,6 +8,7 @@
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -1044,6 +1045,286 @@ void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
LastTensorLoadCanonicalization, ForOpTensorCastFolder>(context);
}
+//===----------------------------------------------------------------------===//
+// ForeachThreadOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult ForeachThreadOp::verify() {
+  // Check that the body defines one block argument per thread index.
+ auto *body = getBody();
+ if (body->getNumArguments() != getRank())
+ return emitOpError("region expects ") << getRank() << " arguments";
+ if (!llvm::all_of(body->getArgumentTypes(),
+ [](Type t) { return t.isIndex(); }))
+    return emitOpError(
+        "expected all region arguments to be of `index` type");
+
+ // Verify consistency between the result types and the terminator.
+ auto terminatorTypes = getTerminator().yieldedTypes();
+ auto opResults = getResults();
+ if (opResults.size() != terminatorTypes.size())
+ return emitOpError("produces ")
+ << opResults.size() << " results, but its terminator yields "
+ << terminatorTypes.size() << " values";
+ unsigned i = 0;
+ for (auto e : llvm::zip(terminatorTypes, opResults)) {
+ if (std::get<0>(e) != std::get<1>(e).getType())
+ return emitOpError() << "type mismatch between " << i
+ << "th result of foreach_thread (" << std::get<0>(e)
+ << ") and " << i << "th result yielded by its "
+ << "terminator (" << std::get<1>(e).getType() << ")";
+ i++;
+ }
+ return success();
+}
+
+void ForeachThreadOp::print(OpAsmPrinter &p) {
+ p << '(';
+ llvm::interleaveComma(getThreadIndices(), p);
+ p << ") in (";
+ llvm::interleaveComma(getNumThreads(), p);
+ p << ") -> (" << getResultTypes() << ") ";
+ p.printRegion(getRegion(),
+ /*printEntryBlockArgs=*/false,
+ /*printBlockTerminators=*/getNumResults() > 0);
+ p.printOptionalAttrDict(getOperation()->getAttrs());
+}
+
+ParseResult ForeachThreadOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ auto &builder = parser.getBuilder();
+ // Parse an opening `(` followed by thread index variables followed by `)`
+  // Parse an opening `(`, the thread index variables, and a closing `)`.
+ if (parser.parseArgumentList(threadIndices, OpAsmParser::Delimiter::Paren))
+ return failure();
+
+ // Parse `in` threadNums.
+ SmallVector<OpAsmParser::UnresolvedOperand, 4> threadNums;
+ if (parser.parseKeyword("in") ||
+ parser.parseOperandList(threadNums, threadIndices.size(),
+ OpAsmParser::Delimiter::Paren) ||
+ parser.resolveOperands(threadNums, builder.getIndexType(),
+ result.operands))
+ return failure();
+
+ // Parse optional results.
+ if (parser.parseOptionalArrowTypeList(result.types))
+ return failure();
+
+ // Parse region.
+ std::unique_ptr<Region> region = std::make_unique<Region>();
+ for (auto &idx : threadIndices)
+ idx.type = builder.getIndexType();
+ if (parser.parseRegion(*region, threadIndices))
+ return failure();
+
+ // Ensure terminator and move region.
+ ForeachThreadOp::ensureTerminator(*region, builder, result.location);
+ result.addRegion(std::move(region));
+
+ // Parse the optional attribute list.
+ if (parser.parseOptionalAttrDict(result.attributes))
+ return failure();
+
+ return success();
+}
+
+// Bodyless builder, result types must be specified.
+void ForeachThreadOp::build(mlir::OpBuilder &builder,
+ mlir::OperationState &result, TypeRange resultTypes,
+ ValueRange numThreads) {
+ result.addOperands(numThreads);
+
+ Region *bodyRegion = result.addRegion();
+ bodyRegion->push_back(new Block);
+ Block &bodyBlock = bodyRegion->front();
+ bodyBlock.addArguments(
+ SmallVector<Type>(numThreads.size(), builder.getIndexType()),
+ SmallVector<Location>(numThreads.size(), result.location));
+ ForeachThreadOp::ensureTerminator(*bodyRegion, builder, result.location);
+ result.addTypes(resultTypes);
+}
+
+// Builder that takes a bodyBuilder lambda, result types are inferred from
+// the terminator.
+void ForeachThreadOp::build(
+ mlir::OpBuilder &builder, mlir::OperationState &result,
+ ValueRange numThreads,
+ function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
+ result.addOperands(numThreads);
+
+ Region *bodyRegion = result.addRegion();
+ bodyRegion->push_back(new Block);
+ Block &bodyBlock = bodyRegion->front();
+ bodyBlock.addArguments(
+ SmallVector<Type>(numThreads.size(), builder.getIndexType()),
+ SmallVector<Location>(numThreads.size(), result.location));
+
+ OpBuilder::InsertionGuard guard(builder);
+ builder.setInsertionPointToStart(&bodyBlock);
+  bodyBuilder(builder, result.location, bodyBlock.getArguments());
+ auto terminator =
+ llvm::cast<PerformConcurrentlyOp>(bodyBlock.getTerminator());
+ result.addTypes(terminator.yieldedTypes());
+}
+
+// The ensureTerminator method generated by SingleBlockImplicitTerminator is
+// unaware that our terminator also needs a region to be well-formed. We
+// override it here to ensure that we do the right thing.
+void ForeachThreadOp::ensureTerminator(Region &region, Builder &builder,
+                                       Location loc) {
+ OpTrait::SingleBlockImplicitTerminator<PerformConcurrentlyOp>::Impl<
+ ForeachThreadOp>::ensureTerminator(region, builder, loc);
+ auto terminator =
+ llvm::dyn_cast<PerformConcurrentlyOp>(region.front().getTerminator());
+ PerformConcurrentlyOp::ensureTerminator(terminator.getRegion(), builder, loc);
+}
+
+PerformConcurrentlyOp ForeachThreadOp::getTerminator() {
+ return cast<PerformConcurrentlyOp>(getBody()->getTerminator());
+}
+
+//===----------------------------------------------------------------------===//
+// ParallelInsertSliceOp
+//===----------------------------------------------------------------------===//
+
+// Build a ParallelInsertSliceOp with mixed static and dynamic entries.
+void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
+ Value source, Value dest,
+ ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes,
+ ArrayRef<OpFoldResult> strides,
+ ArrayRef<NamedAttribute> attrs) {
+ SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
+ SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
+ dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
+ ShapedType::kDynamicStrideOrOffset);
+ dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
+ ShapedType::kDynamicSize);
+ dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
+ ShapedType::kDynamicStrideOrOffset);
+ build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
+ dynamicStrides, b.getI64ArrayAttr(staticOffsets),
+ b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
+ result.addAttributes(attrs);
+}
+
+// Build a ParallelInsertSliceOp with dynamic entries.
+void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
+ Value source, Value dest, ValueRange offsets,
+ ValueRange sizes, ValueRange strides,
+ ArrayRef<NamedAttribute> attrs) {
+ SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
+ llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
+ SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
+ llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
+ SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
+ llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
+ build(b, result, source, dest, offsetValues, sizeValues, strideValues);
+}
+
+// namespace {
+// /// Pattern to rewrite a parallel_insert_slice op with constant arguments.
+// class ParallelInsertSliceOpConstantArgumentFolder final
+// : public OpRewritePattern<ParallelInsertSliceOp> {
+// public:
+// using OpRewritePattern<ParallelInsertSliceOp>::OpRewritePattern;
+
+// LogicalResult matchAndRewrite(ParallelInsertSliceOp insertSliceOp,
+// PatternRewriter &rewriter) const override {
+// // No constant operand, just return.
+// if (llvm::none_of(insertSliceOp.getOperands(), [](Value operand) {
+// return matchPattern(operand, matchConstantIndex());
+// }))
+// return failure();
+
+// // At least one of offsets/sizes/strides is a new constant.
+// // Form the new list of operands and constant attributes from the
+// // existing.
+// SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
+// SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
+// SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
+//     canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamicStrideOrOffset);
+//     canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic);
+//     canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);
+
+// // Create the new op in canonical form.
+// rewriter.replaceOpWithNewOp<ParallelInsertSliceOp>(
+// insertSliceOp, insertSliceOp.source(), insertSliceOp.dest(),
+// mixedOffsets, mixedSizes, mixedStrides);
+// return success();
+// }
+// };
+// } // namespace
+
+// void ParallelInsertSliceOp::getCanonicalizationPatterns(
+// RewritePatternSet &results, MLIRContext *context) {
+// results.add<ParallelInsertSliceOpConstantArgumentFolder>(context);
+// }
+
+//===----------------------------------------------------------------------===//
+// PerformConcurrentlyOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult PerformConcurrentlyOp::verify() {
+ // TODO: PerformConcurrentlyOpInterface.
+ for (const Operation &op : getRegion().front().getOperations())
+ if (!isa<ParallelInsertSliceOp, EndPerformConcurrentlyOp>(op))
+ return emitOpError(
+ "expected only scf.foreach_thread.parallel_insert_slice ops");
+ return success();
+}
+
+void PerformConcurrentlyOp::print(OpAsmPrinter &p) {
+ p << " ";
+ p.printRegion(getRegion(),
+ /*printEntryBlockArgs=*/false,
+ /*printBlockTerminators=*/false);
+ p.printOptionalAttrDict(getOperation()->getAttrs());
+}
+
+ParseResult PerformConcurrentlyOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ auto &builder = parser.getBuilder();
+
+ SmallVector<OpAsmParser::Argument, 8> regionOperands;
+ std::unique_ptr<Region> region = std::make_unique<Region>();
+ if (parser.parseRegion(*region, regionOperands))
+ return failure();
+
+ PerformConcurrentlyOp::ensureTerminator(*region, builder, result.location);
+ result.addRegion(std::move(region));
+
+ // Parse the optional attribute list.
+ if (parser.parseOptionalAttrDict(result.attributes))
+ return failure();
+ return success();
+}
+
+SmallVector<Type> PerformConcurrentlyOp::yieldedTypes() {
+ return llvm::to_vector(
+ llvm::map_range(this->yieldingOps(), [](ParallelInsertSliceOp op) {
+ return op.yieldedType();
+ }));
+}
+
+SmallVector<ParallelInsertSliceOp> PerformConcurrentlyOp::yieldingOps() {
+ SmallVector<ParallelInsertSliceOp> ret;
+ for (Operation &op : *getBody()) {
+ // TODO: PerformConcurrentlyOpInterface interface when this grows up.
+ if (auto sliceOp = llvm::dyn_cast<ParallelInsertSliceOp>(op)) {
+ ret.push_back(sliceOp);
+ continue;
+ }
+ if (auto endPerformOp = llvm::dyn_cast<EndPerformConcurrentlyOp>(op)) {
+ continue;
+ }
+ llvm_unreachable("Unexpected operation in perform_concurrently");
+ }
+ return ret;
+}
+
//===----------------------------------------------------------------------===//
// IfOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
index 402d1b67a2630..29c4545047d9d 100644
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -520,3 +520,13 @@ func.func @execute_region() {
}) : () -> ()
return
}
+
+// -----
+
+func.func @wrong_number_of_arguments() -> () {
+ %num_threads = arith.constant 100 : index
+ // expected-error @+1 {{region expects 2 arguments}}
+ scf.foreach_thread (%thread_idx) in (%num_threads, %num_threads) -> () {
+ }
+ return
+}
diff --git a/mlir/test/Dialect/SCF/ops.mlir b/mlir/test/Dialect/SCF/ops.mlir
index b732b1ede38de..ac43844187569 100644
--- a/mlir/test/Dialect/SCF/ops.mlir
+++ b/mlir/test/Dialect/SCF/ops.mlir
@@ -310,3 +310,50 @@ func.func @execute_region() -> i64 {
}) : () -> ()
return %res : i64
}
+
+// CHECK-LABEL: func.func @simple_example
+func.func @simple_example(%in: tensor<100xf32>, %out: tensor<100xf32>) {
+ %c1 = arith.constant 1 : index
+ %num_threads = arith.constant 100 : index
+
+ // CHECK: scf.foreach_thread
+ // CHECK-NEXT: tensor.extract_slice
+ // CHECK-NEXT: scf.foreach_thread.perform_concurrently
+ // CHECK-NEXT: scf.foreach_thread.parallel_insert_slice
+ // CHECK-NEXT: }
+ // CHECK-NEXT: }
+ // CHECK-NEXT: return
+ %result = scf.foreach_thread (%thread_idx) in (%num_threads) -> tensor<100xf32> {
+ %1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32>
+ scf.foreach_thread.perform_concurrently {
+ scf.foreach_thread.parallel_insert_slice %1 into %out[%thread_idx][1][1] :
+ tensor<1xf32> into tensor<100xf32>
+ }
+ }
+ return
+}
+
+// CHECK-LABEL: func.func @elide_terminator
+func.func @elide_terminator() -> () {
+ %num_threads = arith.constant 100 : index
+
+ // CHECK: scf.foreach_thread
+ // CHECK-NEXT: }
+ // CHECK-NEXT: return
+ scf.foreach_thread (%thread_idx) in (%num_threads) -> () {
+ scf.foreach_thread.perform_concurrently {
+ }
+ }
+ return
+}
+
+// CHECK-LABEL: func.func @no_terminator
+func.func @no_terminator() -> () {
+ %num_threads = arith.constant 100 : index
+ // CHECK: scf.foreach_thread
+ // CHECK-NEXT: }
+ // CHECK-NEXT: return
+ scf.foreach_thread (%thread_idx) in (%num_threads) -> () {
+ }
+ return
+}