[Mlir-commits] [mlir] b87f1b2 - [MLIR] Add `InParallelOpInterface` for parallel combining operations (#157736)
llvmlistbot at llvm.org
Fri Sep 12 14:23:04 PDT 2025
Author: Alan Li
Date: 2025-09-12T14:23:00-07:00
New Revision: b87f1b22a8d8a77d5360f201af5ba08adbb0a974
URL: https://github.com/llvm/llvm-project/commit/b87f1b22a8d8a77d5360f201af5ba08adbb0a974
DIFF: https://github.com/llvm/llvm-project/commit/b87f1b22a8d8a77d5360f201af5ba08adbb0a974.diff
LOG: [MLIR] Add `InParallelOpInterface` for parallel combining operations (#157736)
This commit:
- Introduces a new `InParallelOpInterface` which, together with the
`ParallelCombiningOpInterface`, represents the parallel updating
operations in a parallel loop such as `scf.forall`.
- Renames the existing `ParallelCombiningOpInterface` to
`InParallelOpInterface`, as the previous naming was quite confusing.
- `ParallelCombiningOpInterface` is now used to generalize operations
that insert into shared tensors within parallel combining regions.
Previously, only `tensor.parallel_insert_slice` was supported directly
in `scf.InParallelOp` regions.
- `tensor.parallel_insert_slice` now implements
`ParallelCombiningOpInterface`.
This change enables future extensions that support additional parallel
combining operations beyond `tensor.parallel_insert_slice` with
different update semantics, so that the `in_parallel` region can
correctly and safely represent these kinds of operations without
potential mistakes such as races.
Author credits: @qedawkins
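For illustration, a minimal sketch of the structure these interfaces model
(adapted from the tests in this patch; the function name and shapes are
illustrative only): the `scf.forall.in_parallel` terminator implements
`InParallelOpInterface`, while each op nested in its region, such as
`tensor.parallel_insert_slice`, implements `ParallelCombiningOpInterface`.

func.func @parallel_update(%in: tensor<100xf32>, %out: tensor<100xf32>) -> tensor<100xf32> {
  %num_threads = arith.constant 100 : index
  %result = scf.forall (%tid) in (%num_threads) shared_outs(%o = %out) -> (tensor<100xf32>) {
    %slice = tensor.extract_slice %in[%tid][1][1] : tensor<100xf32> to tensor<1xf32>
    // Implements `InParallelOpInterface`.
    scf.forall.in_parallel {
      // Implements `ParallelCombiningOpInterface`: `getUpdatedDestinations`
      // returns the destination %o, and `getIteratingParent` returns the
      // enclosing `scf.forall`.
      tensor.parallel_insert_slice %slice into %o[%tid][1][1] :
        tensor<1xf32> into tensor<100xf32>
    }
  }
  return %result : tensor<100xf32>
}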
Added:
Modified:
mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h
mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/SCF/IR/SCF.cpp
mlir/lib/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
mlir/lib/Interfaces/ParallelCombiningOpInterface.cpp
mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
mlir/test/Dialect/SCF/invalid.mlir
mlir/test/Dialect/SCF/one-shot-bufferize-analysis.mlir
mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 88df54174da24..d3c01c31636a7 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -654,7 +654,7 @@ def ForallOp : SCF_Op<"forall", [
def InParallelOp : SCF_Op<"forall.in_parallel", [
Pure,
Terminator,
- DeclareOpInterfaceMethods<ParallelCombiningOpInterface>,
+ DeclareOpInterfaceMethods<InParallelOpInterface>,
HasParent<"ForallOp">,
] # GraphRegionNoTerminator.traits> {
let summary = "terminates a `forall` block";
@@ -679,8 +679,6 @@ def InParallelOp : SCF_Op<"forall.in_parallel", [
OpBuilder<(ins)>,
];
- // TODO: Add a `InParallelOpInterface` interface for ops that can
- // appear inside in_parallel.
let extraClassDeclaration = [{
::llvm::SmallVector<::mlir::BlockArgument> getDests();
::llvm::iterator_range<::mlir::Block::iterator> getYieldingOps();
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 7d396e5c64c28..2453cf5b5b5a4 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1470,24 +1470,25 @@ def Tensor_PadOp : Tensor_Op<"pad", [
// ParallelInsertSliceOp
//===----------------------------------------------------------------------===//
-// TODO: Implement InParallelOpInterface.
def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
AttrSizedOperandSegments,
OffsetSizeAndStrideOpInterface,
+ DeclareOpInterfaceMethods<ParallelCombiningOpInterface,
+ ["getUpdatedDestinations", "getIteratingParent"]>,
// TODO: Cannot use an interface here atm, verify this manually for now.
- // HasParent<"ParallelCombiningOpInterface">
+ // HasParent<"InParallelOpInterface">
]> {
let summary = [{
Specify the tensor slice update of a single thread of a parent
- ParallelCombiningOpInterface op.
+ InParallelOpInterface op.
}];
let description = [{
The `parallel_insert_slice` yields a subset tensor value to its parent
- ParallelCombiningOpInterface. These subset tensor values are aggregated to
+ InParallelOpInterface. These subset tensor values are aggregated
in some unspecified order into a full tensor value returned by the parent
parallel iterating op.
The `parallel_insert_slice` is one such op allowed in the
- ParallelCombiningOpInterface op.
+ InParallelOpInterface op.
Conflicting writes result in undefined semantics, in that the indices written
to by multiple parallel updates might contain data from any of the updates,
@@ -1569,8 +1570,8 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
return ::llvm::cast<RankedTensorType>(getDest().getType());
}
- ParallelCombiningOpInterface getParallelCombiningParent() {
- return dyn_cast<ParallelCombiningOpInterface>(
+ InParallelOpInterface getParallelCombiningParent() {
+ return dyn_cast<InParallelOpInterface>(
getOperation()->getParentOp());
}
diff --git a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h
index 72db06163df37..82ab427699f64 100644
--- a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h
@@ -19,7 +19,7 @@
namespace mlir {
namespace detail {
// TODO: Single region single block interface on interfaces ?
-LogicalResult verifyParallelCombiningOpInterface(Operation *op);
+LogicalResult verifyInParallelOpInterface(Operation *op);
} // namespace detail
} // namespace mlir
diff --git a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td
index 424b4cf0a0a58..ace26f723ef53 100644
--- a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td
+++ b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td
@@ -6,7 +6,8 @@
//
//===----------------------------------------------------------------------===//
//
-// Defines the interface for ops that perform parallel combining operations.
-// Defines the interface for ops that perform in-parallel combining
+// operations.
//
//===----------------------------------------------------------------------===//
@@ -15,9 +16,9 @@
include "mlir/IR/OpBase.td"
-def ParallelCombiningOpInterface : OpInterface<"ParallelCombiningOpInterface"> {
+def InParallelOpInterface : OpInterface<"InParallelOpInterface"> {
let description = [{
- A parallel combining op is an op with a region.
+ An in-parallel op is an op with a region.
This is useful as a terminator to parallel operations that iterate over
some set and return tensors while avoiding tight coupling between the
@@ -52,8 +53,60 @@ def ParallelCombiningOpInterface : OpInterface<"ParallelCombiningOpInterface"> {
];
// TODO: Single region single block interface on interfaces ?
let verify = [{
- return verifyParallelCombiningOpInterface($_op);
+ return verifyInParallelOpInterface($_op);
+ }];
+}
+
+def ParallelCombiningOpInterface : OpInterface<"ParallelCombiningOpInterface"> {
+ let description = [{
+ A parallel combining op is an operation that models parallel contributions
+ to result tensors within the context of a parent iterating operation.
+
+ This interface is designed for operations that need to coordinate parallel
+ insertions or contributions to tensors that are being constructed across
+ multiple parallel iterations. The destination refers to a tensor value that
+ is assembled by aggregating results from parallel computations; each
+ parallel iteration may contribute a slice, element, or region to the final
+ result. No in-place mutation of tensors is implied.
+
+ One significant use case for this interface is `tensor.parallel_insert_slice`
+ which allows parallel insertion of slices that are aggregated into a
+ destination tensor. With this interface, other operations that express
+ similar parallel contributions can also be defined.
+
+ This op works within an op implementing the `InParallelOpInterface` that
+ specifies how the parallel results are combined.
+
+ Key semantics:
+ - The operation identifies destination tensors to which iterations
+ contribute through the `getUpdatedDestinations` method
+ - Each parallel iteration may produce elements or regions that are
+ incorporated into the destination tensor
+ - The parent iterating operation manages the coordination and ensures
+ proper synchronization of these contributions
+
+ Note: This interface does not verify itself; it is up to the implementing
+ operation to verify the correctness of the op.
}];
+ let cppNamespace = "::mlir";
+
+ let methods = [
+ InterfaceMethod<[{
+ Returns the list of destination values this op contributes to.
+ }],
+ /*retTy=*/"::mlir::MutableOperandRange",
+ /*methodName=*/"getUpdatedDestinations",
+ /*args=*/(ins)
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Returns the iterating parent for this op.
+ }],
+ /*retTy=*/"::mlir::Operation*",
+ /*methodName=*/"getIteratingParent",
+ /*args=*/(ins)
+ >,
+ ];
}
#endif // MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index f3db8f7ccfaa1..715eebb3c4a13 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -36,6 +36,7 @@
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -4147,12 +4148,11 @@ DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target,
return DiagnosedSilenceableFailure::success();
}
- // If we are inside an InParallel region, temporarily set the insertion point
- // outside: only tensor.parallel_insert_slice ops are allowed in there.
- if constexpr (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
- rewriter.setInsertionPoint(
- target->template getParentOfType<scf::InParallelOp>());
- }
+ // If we are inside a `ParallelCombiningOp` region, temporarily set the
+ // insertion point outside: only ops implementing ParallelCombiningOpInterface
+ // are allowed in there.
+ if (isa<mlir::ParallelCombiningOpInterface>(target.getOperation()))
+ rewriter.setInsertionPoint(target->getParentOp());
Value extracted = tensor::ExtractSliceOp::create(
rewriter, target.getLoc(), target.getDest(), target.getMixedOffsets(),
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 84f9777a443fd..45b14fcf8aadd 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -21,6 +21,7 @@
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/MapVector.h"
@@ -681,7 +682,9 @@ void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
results.reserve(forallOp.getResults().size());
for (auto &yieldingOp : terminator.getYieldingOps()) {
auto parallelInsertSliceOp =
- cast<tensor::ParallelInsertSliceOp>(yieldingOp);
+ dyn_cast<tensor::ParallelInsertSliceOp>(yieldingOp);
+ if (!parallelInsertSliceOp)
+ continue;
Value dst = parallelInsertSliceOp.getDest();
Value src = parallelInsertSliceOp.getSource();
@@ -1439,12 +1442,9 @@ InParallelOp ForallOp::getTerminator() {
SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
SmallVector<Operation *> storeOps;
- InParallelOp inParallelOp = getTerminator();
- for (Operation &yieldOp : inParallelOp.getYieldingOps()) {
- if (auto parallelInsertSliceOp =
- dyn_cast<tensor::ParallelInsertSliceOp>(yieldOp);
- parallelInsertSliceOp && parallelInsertSliceOp.getDest() == bbArg) {
- storeOps.push_back(parallelInsertSliceOp);
+ for (Operation *user : bbArg.getUsers()) {
+ if (auto parallelOp = dyn_cast<ParallelCombiningOpInterface>(user)) {
+ storeOps.push_back(parallelOp);
}
}
return storeOps;
@@ -1911,8 +1911,10 @@ struct FoldTensorCastOfOutputIntoForallOp
auto terminator = newForallOp.getTerminator();
for (auto [yieldingOp, outputBlockArg] : llvm::zip(
terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
- auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(yieldingOp);
- insertSliceOp.getDestMutable().assign(outputBlockArg);
+ if (auto parallelCombiningOp =
+ dyn_cast<ParallelCombiningOpInterface>(yieldingOp)) {
+ parallelCombiningOp.getUpdatedDestinations().assign(outputBlockArg);
+ }
}
// Cast results back to the original types.
@@ -1971,19 +1973,22 @@ LogicalResult InParallelOp::verify() {
if (!forallOp)
return this->emitOpError("expected forall op parent");
- // TODO: InParallelOpInterface.
for (Operation &op : getRegion().front().getOperations()) {
- if (!isa<tensor::ParallelInsertSliceOp>(op)) {
- return this->emitOpError("expected only ")
- << tensor::ParallelInsertSliceOp::getOperationName() << " ops";
+ auto parallelCombiningOp = dyn_cast<ParallelCombiningOpInterface>(&op);
+ if (!parallelCombiningOp) {
+ return this->emitOpError("expected only ParallelCombiningOpInterface")
+ << " ops";
}
// Verify that inserts are into out block arguments.
- Value dest = cast<tensor::ParallelInsertSliceOp>(op).getDest();
+ MutableOperandRange dests = parallelCombiningOp.getUpdatedDestinations();
ArrayRef<BlockArgument> regionOutArgs = forallOp.getRegionOutArgs();
- if (!llvm::is_contained(regionOutArgs, dest))
- return op.emitOpError("may only insert into an output block argument");
+ for (OpOperand &dest : dests) {
+ if (!llvm::is_contained(regionOutArgs, dest.get()))
+ return op.emitOpError("may only insert into an output block argument");
+ }
}
+
return success();
}
@@ -2018,12 +2023,17 @@ OpResult InParallelOp::getParentResult(int64_t idx) {
}
SmallVector<BlockArgument> InParallelOp::getDests() {
- return llvm::to_vector<4>(
- llvm::map_range(getYieldingOps(), [](Operation &op) {
- // Add new ops here as needed.
- auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(&op);
- return llvm::cast<BlockArgument>(insertSliceOp.getDest());
- }));
+ SmallVector<BlockArgument> updatedDests;
+ for (Operation &yieldingOp : getYieldingOps()) {
+ auto parallelCombiningOp =
+ dyn_cast<ParallelCombiningOpInterface>(&yieldingOp);
+ if (!parallelCombiningOp)
+ continue;
+ for (OpOperand &updatedOperand :
+ parallelCombiningOp.getUpdatedDestinations())
+ updatedDests.push_back(cast<BlockArgument>(updatedOperand.get()));
+ }
+ return updatedDests;
}
llvm::iterator_range<Block::iterator> InParallelOp::getYieldingOps() {
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.cpp
index a44612410bdee..63216e7cc7fba 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.cpp
@@ -16,7 +16,7 @@ using namespace mlir::bufferization;
namespace {
/// The `scf.forall.in_parallel` terminator is special in a few ways:
/// * It does not implement the BranchOpInterface or
-/// RegionBranchTerminatorOpInterface, but the ParallelCombiningOpInterface
+/// RegionBranchTerminatorOpInterface, but the InParallelOpInterface
/// which is not supported by BufferDeallocation.
/// * It has a graph-like region which only allows one specific tensor op
/// * After bufferization the nested region is always empty
@@ -40,9 +40,9 @@ namespace {
/// <implicit in_parallel terminator here>
/// }
/// ```
-struct InParallelOpInterface
- : public BufferDeallocationOpInterface::ExternalModel<InParallelOpInterface,
- scf::InParallelOp> {
+struct InParallelDeallocOpInterface
+ : public BufferDeallocationOpInterface::ExternalModel<
+ InParallelDeallocOpInterface, scf::InParallelOp> {
FailureOr<Operation *> process(Operation *op, DeallocationState &state,
const DeallocationOptions &options) const {
auto inParallelOp = cast<scf::InParallelOp>(op);
@@ -75,7 +75,7 @@ struct ReduceReturnOpInterface
void mlir::scf::registerBufferDeallocationOpInterfaceExternalModels(
DialectRegistry ®istry) {
registry.addExtension(+[](MLIRContext *ctx, SCFDialect *dialect) {
- InParallelOp::attachInterface<InParallelOpInterface>(*ctx);
+ InParallelOp::attachInterface<InParallelDeallocOpInterface>(*ctx);
ReduceReturnOp::attachInterface<ReduceReturnOpInterface>(*ctx);
});
}
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 68584ec4fd814..fa97b49a41d97 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2976,9 +2976,9 @@ class InsertSliceOpConstantArgumentFolder final
if (sourceType != insertSliceOp.getSourceType()) {
OpBuilder::InsertionGuard g(rewriter);
// The only difference between InsertSliceOp and ParallelInsertSliceOp
- // is that the insertion point is just before the ParallelCombiningOp in
+ // is that the insertion point is just before the InParallelOp in
// the parallel case.
- if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
+ if (isa<InParallelOpInterface>(insertSliceOp->getParentOp()))
rewriter.setInsertionPoint(insertSliceOp->getParentOp());
toInsert = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
sourceType, toInsert);
@@ -3153,9 +3153,9 @@ struct InsertSliceOpSourceCastInserter final
// Insert the cast.
OpBuilder::InsertionGuard g(rewriter);
// The only difference between InsertSliceOp and ParallelInsertSliceOp is
- // that the insertion point is just before the ParallelCombiningOp in the
+ // that the insertion point is just before the InParallelOp in the
// parallel case.
- if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
+ if (isa<ParallelCombiningOpInterface>(insertSliceOp->getParentOp()))
rewriter.setInsertionPoint(insertSliceOp->getParentOp());
Value cast = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
newSrcType, insertSliceOp.getSource());
@@ -3846,8 +3846,7 @@ OpFoldResult PadOp::fold(FoldAdaptor) {
//===----------------------------------------------------------------------===//
OpResult ParallelInsertSliceOp::getTiedOpResult() {
- ParallelCombiningOpInterface parallelCombiningParent =
- getParallelCombiningParent();
+ InParallelOpInterface parallelCombiningParent = getParallelCombiningParent();
for (const auto &it :
llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
Operation &nextOp = it.value();
@@ -3901,8 +3900,8 @@ void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
}
LogicalResult ParallelInsertSliceOp::verify() {
- if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
- return this->emitError("expected ParallelCombiningOpInterface parent, got:")
+ if (!isa<InParallelOpInterface>(getOperation()->getParentOp()))
+ return this->emitError("expected InParallelOpInterface parent, got:")
<< *(getOperation()->getParentOp());
// Verify result type against inferred type.
@@ -3935,6 +3934,19 @@ llvm::SmallBitVector ParallelInsertSliceOp::getDroppedDims() {
return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
}
+// ParallelCombiningOpInterface implementation.
+MutableOperandRange ParallelInsertSliceOp::getUpdatedDestinations() {
+ return getDestMutable();
+}
+
+Operation *ParallelInsertSliceOp::getIteratingParent() {
+ // Return the parent InParallelOpInterface's parent.
+ if (auto combiningOp =
+ dyn_cast<InParallelOpInterface>(getOperation()->getParentOp()))
+ return combiningOp->getParentOp();
+ return nullptr;
+}
+
//===----------------------------------------------------------------------===//
// ScatterOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index c3356c1e4b9d8..bce964e47a3be 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -970,10 +970,10 @@ struct ParallelInsertSliceOpInterface
BufferizationState &state) const {
OpBuilder::InsertionGuard g(rewriter);
auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
- ParallelCombiningOpInterface parallelCombiningParent =
+ InParallelOpInterface parallelCombiningParent =
parallelInsertSliceOp.getParallelCombiningParent();
- // Bufferize the op outside of the parallel combining terminator.
+ // Bufferize the op outside of the in_parallel terminator.
rewriter.setInsertionPoint(parallelCombiningParent);
// Get source and destination buffers.
diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
index d76c02af7ab16..b32faf481af80 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
@@ -215,12 +215,11 @@ struct InsertSliceOfInsertSliceFolder : public OpRewritePattern<OpTy> {
sourceInsertSliceOp.getMixedSizes(),
droppedDims, resolvedSizes);
- // If we are inside an InParallel region, temporarily set the insertion
- // point outside: only tensor.parallel_insert_slice ops are allowed in
- // there.
- if (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
- rewriter.setInsertionPoint(
- insertSliceOp->template getParentOfType<scf::InParallelOp>());
+ // If we are inside a ParallelCombining region, temporarily set the
+ // insertion point outside: only ops implementing
+ // ParallelCombiningOpInterface are allowed in there.
+ if (isa<mlir::ParallelCombiningOpInterface>(insertSliceOp.getOperation())) {
+ rewriter.setInsertionPoint(insertSliceOp->getParentOp());
}
// Resolve offsets according to source offsets and strides.
diff --git a/mlir/lib/Interfaces/ParallelCombiningOpInterface.cpp b/mlir/lib/Interfaces/ParallelCombiningOpInterface.cpp
index 2b6703543bbd3..30b8191bf34b0 100644
--- a/mlir/lib/Interfaces/ParallelCombiningOpInterface.cpp
+++ b/mlir/lib/Interfaces/ParallelCombiningOpInterface.cpp
@@ -11,11 +11,11 @@
using namespace mlir;
//===----------------------------------------------------------------------===//
-// ParallelCombiningOpInterface
+// InParallelOpInterface (formerly ParallelCombiningOpInterface)
//===----------------------------------------------------------------------===//
// TODO: Single region single block interface on interfaces ?
-LogicalResult mlir::detail::verifyParallelCombiningOpInterface(Operation *op) {
+LogicalResult mlir::detail::verifyInParallelOpInterface(Operation *op) {
if (op->getNumRegions() != 1)
return op->emitError("expected single region op");
if (!op->getRegion(0).hasOneBlock())
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index 5f42938244db6..9005110205630 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -915,7 +915,7 @@ func.func @sparse_case(%arg0: tensor<8x8xf32, #CSR>, %arg1: tensor<8xf32>) -> te
// -----
-func.func @reduce_dispatch_0() -> tensor<4x2xf32> {
+func.func @parallel_insert_slice() -> tensor<4x2xf32> {
%c2 = arith.constant 2 : index
%c4 = arith.constant 4 : index
%cst = arith.constant 0.000000e+00 : f32
@@ -923,6 +923,7 @@ func.func @reduce_dispatch_0() -> tensor<4x2xf32> {
%res = scf.forall (%arg0, %arg1) in (%c4, %c2) shared_outs(%o = %0) -> (tensor<4x2xf32>) {
%1 = tensor.empty() : tensor<1x1xf32>
%2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<1x1xf32>) -> tensor<1x1xf32>
+ // CHECK: scf.forall.in_parallel
scf.forall.in_parallel {
// CHECK: tensor.parallel_insert_slice %{{[0-9a-z]*}} into %{{[0-9a-z]*}}
// CHECK-SAME: [%{{.*}}, %{{.*}}] [1, 1] [1, 1] : tensor<f32> into tensor<4x2xf32>
diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
index bb7958083e55c..37fc86b18e7f0 100644
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -645,7 +645,7 @@ func.func @wrong_terminator_op(%in: tensor<100xf32>, %out: tensor<100xf32>) {
%result = scf.forall (%thread_idx) in (%num_threads) shared_outs(%o = %out) -> (tensor<100xf32>) {
%1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32>
- // expected-error @+1 {{expected only tensor.parallel_insert_slice ops}}
+ // expected-error @+1 {{expected only ParallelCombiningOpInterface ops}}
scf.forall.in_parallel {
tensor.parallel_insert_slice %1 into %o[%thread_idx][1][1] :
tensor<1xf32> into tensor<100xf32>
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize-analysis.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize-analysis.mlir
index 9bb87ffbb2090..ed3685514dd0d 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize-analysis.mlir
@@ -908,3 +908,111 @@ func.func @parallel_region_no_read()
}
return
}
+
+// -----
+
+// CHECK-LABEL: func @in_order_multiple_parallel_writes
+func.func @in_order_multiple_parallel_writes(%2: tensor<320xf32> {bufferization.writable = true},
+ %3: tensor<320xf32> {bufferization.writable = true})
+ -> (tensor<320xf32>, tensor<320xf32>)
+{
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant -0.000000e+00 : f32
+ %c320 = arith.constant 320 : index
+ %4:2 = scf.forall (%arg0) in (%c320) shared_outs(%arg1 = %2, %arg2 = %3) -> (tensor<320xf32>, tensor<320xf32>) {
+ // CHECK: tensor.extract_slice {{.*}} {__inplace_operands_attr__ = ["true", "none"]}
+ %6 = tensor.extract_slice %arg1[%arg0] [1] [1] : tensor<320xf32> to tensor<1xf32>
+ // CHECK: tensor.extract_slice {{.*}} {__inplace_operands_attr__ = ["true", "none"]}
+ %7 = tensor.extract_slice %arg2[%arg0] [1] [1] : tensor<320xf32> to tensor<1xf32>
+ // CHECK: linalg.fill {__inplace_operands_attr__ = ["none", "true"]}
+ %8 = linalg.fill ins(%cst : f32) outs(%7 : tensor<1xf32>) -> tensor<1xf32>
+
+ // CHECK: tensor.parallel_insert_slice {{.*}} {__inplace_operands_attr__ = ["true", "true", "none"]}
+ // CHECK: tensor.parallel_insert_slice {{.*}} {__inplace_operands_attr__ = ["true", "true", "none"]}
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %6 into %arg2[%arg0] [1] [1] : tensor<1xf32> into tensor<320xf32>
+ tensor.parallel_insert_slice %8 into %arg1[%arg0] [1] [1] : tensor<1xf32> into tensor<320xf32>
+ }
+ }
+ return %4#0, %4#1 : tensor<320xf32>, tensor<320xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @out_of_order_parallel_write
+func.func @out_of_order_parallel_write(%2: tensor<320xf32> {bufferization.writable = true},
+ %3: tensor<320xf32> {bufferization.writable = true})
+ -> (tensor<320xf32>, tensor<320xf32>)
+{
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant -0.000000e+00 : f32
+ %c320 = arith.constant 320 : index
+ %4:2 = scf.forall (%arg0) in (%c320) shared_outs(%arg1 = %2, %arg2 = %3) -> (tensor<320xf32>, tensor<320xf32>) {
+ // The extract_slice cannot operate in place because it is used after the
+ // first write.
+ // CHECK: tensor.extract_slice {{.*}} {__inplace_operands_attr__ = ["true", "none"]}
+ %6 = tensor.extract_slice %arg1[%arg0] [1] [1] : tensor<320xf32> to tensor<1xf32>
+
+ // Additionally the fill aliases the thread local slice.
+ // CHECK: linalg.fill {__inplace_operands_attr__ = ["none", "false"]}
+ %7 = linalg.fill ins(%cst : f32) outs(%6 : tensor<1xf32>) -> tensor<1xf32>
+
+ scf.forall.in_parallel {
+ // CHECK: tensor.parallel_insert_slice {{.*}} {__inplace_operands_attr__ = ["true", "true", "none"]}
+ tensor.parallel_insert_slice %7 into %arg1[%arg0] [1] [1] : tensor<1xf32> into tensor<320xf32>
+ // CHECK: tensor.parallel_insert_slice {{.*}} {__inplace_operands_attr__ = ["true", "true", "none"]}
+ tensor.parallel_insert_slice %6 into %arg2[%arg0] [1] [1] : tensor<1xf32> into tensor<320xf32>
+ }
+ }
+ return %4#0, %4#1 : tensor<320xf32>, tensor<320xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @out_of_order_parallel_write_multiple_reads
+func.func @out_of_order_parallel_write_multiple_reads(%2: tensor<320xf32> {bufferization.writable = true},
+ %3: tensor<320xf32> {bufferization.writable = true})
+ -> (tensor<320xf32>, tensor<320xf32>)
+{
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant -0.000000e+00 : f32
+ %c320 = arith.constant 320 : index
+ %4:2 = scf.forall (%arg0) in (%c320) shared_outs(%arg1 = %2, %arg2 = %3) -> (tensor<320xf32>, tensor<320xf32>) {
+ // CHECK: tensor.extract_slice {{.*}} {__inplace_operands_attr__ = ["false", "none"]}
+ %6 = tensor.extract_slice %arg1[%arg0] [1] [1] : tensor<320xf32> to tensor<1xf32>
+ // CHECK: linalg.fill {__inplace_operands_attr__ = ["none", "true"]}
+ %7 = linalg.fill ins(%cst : f32) outs(%6 : tensor<1xf32>) -> tensor<1xf32>
+
+ %reverse = arith.subi %c320, %arg0 : index
+ // CHECK: tensor.extract_slice {{.*}} {__inplace_operands_attr__ = ["true", "none"]}
+ %8 = tensor.extract_slice %arg1[%reverse] [1] [1] : tensor<320xf32> to tensor<1xf32>
+ scf.forall.in_parallel {
+ // Also cannot operate in place due to subsequent conflicting reads.
+ // CHECK: tensor.parallel_insert_slice {{.*}} {__inplace_operands_attr__ = ["true", "true", "none"]}
+ tensor.parallel_insert_slice %7 into %arg1[%arg0] [1] [1] : tensor<1xf32> into tensor<320xf32>
+ // CHECK: tensor.parallel_insert_slice {{.*}} {__inplace_operands_attr__ = ["true", "true", "none"]}
+ tensor.parallel_insert_slice %8 into %arg2[%reverse] [1] [1] : tensor<1xf32> into tensor<320xf32>
+ }
+ }
+ return %4#0, %4#1 : tensor<320xf32>, tensor<320xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @in_order_multiple_parallel_writes
+func.func @in_order_multiple_parallel_writes(%2: tensor<320xf32> {bufferization.writable = true})
+ -> (tensor<320xf32>)
+{
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant -0.000000e+00 : f32
+ %c320 = arith.constant 320 : index
+ %4 = scf.forall (%arg0) in (%c320) shared_outs(%arg1 = %2) -> (tensor<320xf32>) {
+ // CHECK: tensor.extract_slice {{.*}} {__inplace_operands_attr__ = ["true", "none"]}
+ %6 = tensor.extract_slice %arg1[%arg0] [1] [1] : tensor<320xf32> to tensor<1xf32>
+ %reverse = arith.subi %c320, %arg0 : index
+ // CHECK: tensor.parallel_insert_slice {{.*}} {__inplace_operands_attr__ = ["true", "true", "none"]}
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %6 into %arg1[%reverse] [1] [1] : tensor<1xf32> into tensor<320xf32>
+ }
+ }
+ return %4 : tensor<320xf32>
+}
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir
index 8f4b924cfd3cc..92486b8ed7208 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir
@@ -112,7 +112,7 @@ func.func @scf_while_non_equiv_condition_and_body(%A: tensor<5xi1>,
// CHECK-SAME: %[[arg0:.*]]: tensor<100xf32>, %[[arg1:.*]]: tensor<100xf32>
// CHECK-FUNC-LABEL: func @scf_forall_out_of_place(
func.func @scf_forall_out_of_place(%in: tensor<100xf32>,
- %out: tensor<100xf32>) {
+ %out: tensor<100xf32>) {
%c1 = arith.constant 1 : index
%num_threads = arith.constant 100 : index
@@ -132,3 +132,31 @@ func.func @scf_forall_out_of_place(%in: tensor<100xf32>,
} {mapping = [#gpu.thread<x>]}
return
}
+
+// -----
+
+// CHECK-LABEL: func @in_order_multiple_parallel_writes
+func.func @in_order_multiple_parallel_writes(%2: tensor<320xf32>,
+ %3: tensor<320xf32>)
+ -> (tensor<320xf32>, tensor<320xf32>)
+{
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant -0.000000e+00 : f32
+ %c320 = arith.constant 320 : index
+ %4:2 = scf.forall (%arg0) in (%c320) shared_outs(%arg1 = %2, %arg2 = %3) -> (tensor<320xf32>, tensor<320xf32>) {
+ // CHECK: tensor.extract_slice {{.*}}
+ %6 = tensor.extract_slice %arg1[%arg0] [1] [1] : tensor<320xf32> to tensor<1xf32>
+ // CHECK: tensor.extract_slice {{.*}}
+ %7 = tensor.extract_slice %arg2[%arg0] [1] [1] : tensor<320xf32> to tensor<1xf32>
+ // CHECK: linalg.fill {{.*}}
+ %8 = linalg.fill ins(%cst : f32) outs(%7 : tensor<1xf32>) -> tensor<1xf32>
+
+ // CHECK: tensor.parallel_insert_slice {{.*}}
+ // CHECK: tensor.parallel_insert_slice {{.*}}
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %6 into %arg2[%arg0] [1] [1] : tensor<1xf32> into tensor<320xf32>
+ tensor.parallel_insert_slice %8 into %arg1[%arg0] [1] [1] : tensor<1xf32> into tensor<320xf32>
+ }
+ }
+ return %4#0, %4#1 : tensor<320xf32>, tensor<320xf32>
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 5042198d78b74..66cb7956c89f2 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -10819,6 +10819,7 @@ cc_library(
":LinalgTransformOpsIncGen",
":LinalgTransforms",
":LinalgUtils",
+ ":ParallelCombiningOpInterface",
":SCFDialect",
":SCFTransforms",
":Support",