[Mlir-commits] [mlir] [SCF] Add interface methods to `ParallelCombiningOp` for promotion (PR #159840)
Alan Li
llvmlistbot at llvm.org
Fri Sep 19 12:45:21 PDT 2025
https://github.com/lialan created https://github.com/llvm/llvm-project/pull/159840
`ParallelCombiningOpInterface` makes the parallel insertion ops in an `scf.forall.in_parallel` terminator extensible.
This patch adds interface methods that let the `scf.forall` promotion logic (e.g. single-iteration folding) promote these ops out of the loop.
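* `canPromoteInParallelLoop` decides whether the op can be folded/promoted out of the loop in the trivial (single-iteration) case.
* `promoteInParallelLoop` performs the actual promotion and returns the value that replaces the updated destination.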
>From a9073c056efe40e7a36917aec3025c97b668ed35 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Thu, 18 Sep 2025 13:29:16 -0400
Subject: [PATCH 1/2] Add missing promotion methods to ParallelCombiningOpInterface
---
mlir/include/mlir/Dialect/SCF/IR/SCF.h | 2 +-
.../mlir/Dialect/Tensor/IR/TensorOps.td | 3 +-
.../Interfaces/ParallelCombiningOpInterface.h | 3 ++
.../ParallelCombiningOpInterface.td | 20 +++++++
mlir/lib/Dialect/SCF/IR/SCF.cpp | 52 ++++++++++++-------
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 15 ++++++
6 files changed, 73 insertions(+), 22 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
index ba648181daecb..830b49321c2e4 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
@@ -58,7 +58,7 @@ ForallOp getForallOpThreadIndexOwner(Value val);
bool insideMutuallyExclusiveBranches(Operation *a, Operation *b);
/// Promotes the loop body of a scf::ForallOp to its containing block.
-void promote(RewriterBase &rewriter, scf::ForallOp forallOp);
+LogicalResult promote(RewriterBase &rewriter, scf::ForallOp forallOp);
/// An owning vector of values, handy to return from functions.
using ValueVector = SmallVector<Value>;
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 2453cf5b5b5a4..be04c3a4aebbe 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1474,7 +1474,8 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
AttrSizedOperandSegments,
OffsetSizeAndStrideOpInterface,
DeclareOpInterfaceMethods<ParallelCombiningOpInterface,
- ["getUpdatedDestinations", "getIteratingParent"]>,
+ ["getUpdatedDestinations", "getIteratingParent",
+ "promoteInParallelLoop"]>,
// TODO: Cannot use an interface here atm, verify this manually for now.
// HasParent<"InParallelOpInterface">
]> {
diff --git a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h
index 82ab427699f64..ff4e5a87d05c7 100644
--- a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h
@@ -15,6 +15,9 @@
#define MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE_H_
#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/SmallVector.h"
namespace mlir {
namespace detail {
diff --git a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td
index ace26f723ef53..632371b2777fd 100644
--- a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td
+++ b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td
@@ -106,6 +106,26 @@ def ParallelCombiningOpInterface : OpInterface<"ParallelCombiningOpInterface"> {
/*methodName=*/"getIteratingParent",
/*args=*/(ins)
>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Promotes this parallel combining op out of its enclosing parallel loop
+ and returns the values that should replace the destinations updated by
+ this op.
+ }],
+ /*retTy=*/"::mlir::FailureOr<::llvm::SmallVector<::mlir::Value>>",
+ /*methodName=*/"promoteInParallelLoop",
+ /*args=*/(ins "::mlir::RewriterBase &":$rewriter)
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Returns true if this op can be promoted out of its enclosing parallel
+ loop.
+ }],
+ /*retTy=*/"bool",
+ /*methodName=*/"canPromoteInParallelLoop",
+ /*args=*/(ins "::mlir::RewriterBase &":$rewriter),
+ /*methodBody=*/[{ return true; }]
+ >,
];
}
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index c35989ecba6cd..4115ca00f64b5 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -651,8 +651,7 @@ LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
return failure();
}
- promote(rewriter, *this);
- return success();
+ return promote(rewriter, *this);
}
Block::BlockArgListType ForallOp::getRegionIterArgs() {
@@ -664,10 +663,23 @@ MutableArrayRef<OpOperand> ForallOp::getInitsMutable() {
}
/// Promotes the loop body of a scf::ForallOp to its containing block.
-void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
+LogicalResult mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
OpBuilder::InsertionGuard g(rewriter);
scf::InParallelOp terminator = forallOp.getTerminator();
+ // Make sure we can promote all parallel combining ops in terminator:
+ for (auto &yieldingOp : terminator.getYieldingOps()) {
+ auto parallelCombiningOp =
+ dyn_cast<ParallelCombiningOpInterface>(&yieldingOp);
+ if (!parallelCombiningOp)
+ return rewriter.notifyMatchFailure(
+ forallOp, "terminator has non-parallel-combining op");
+ if (!parallelCombiningOp.canPromoteInParallelLoop(rewriter))
+ return rewriter.notifyMatchFailure(
+ forallOp, "parallel combining op cannot be promoted");
+ }
+
+
// Replace block arguments with lower bounds (replacements for IVs) and
// outputs.
SmallVector<Value> bbArgReplacements = forallOp.getLowerBound(rewriter);
@@ -683,30 +695,29 @@ void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
SmallVector<Value> results;
results.reserve(forallOp.getResults().size());
for (auto &yieldingOp : terminator.getYieldingOps()) {
- auto parallelInsertSliceOp =
- dyn_cast<tensor::ParallelInsertSliceOp>(yieldingOp);
- if (!parallelInsertSliceOp)
+ auto parallelCombiningOp =
+ dyn_cast<ParallelCombiningOpInterface>(&yieldingOp);
+ if (!parallelCombiningOp)
continue;
- Value dst = parallelInsertSliceOp.getDest();
- Value src = parallelInsertSliceOp.getSource();
- if (llvm::isa<TensorType>(src.getType())) {
- results.push_back(tensor::InsertSliceOp::create(
- rewriter, forallOp.getLoc(), dst.getType(), src, dst,
- parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
- parallelInsertSliceOp.getStrides(),
- parallelInsertSliceOp.getStaticOffsets(),
- parallelInsertSliceOp.getStaticSizes(),
- parallelInsertSliceOp.getStaticStrides()));
- } else {
- llvm_unreachable("unsupported terminator");
- }
+ assert(parallelCombiningOp.canPromoteInParallelLoop(rewriter));
+
+ FailureOr<SmallVector<Value>> promotedValues =
+ parallelCombiningOp.promoteInParallelLoop(rewriter);
+ if (failed(promotedValues))
+ return failure();
+
+ results.append(promotedValues->begin(), promotedValues->end());
}
+ if (results.size() != forallOp.getResults().size())
+ return rewriter.notifyMatchFailure(
+ forallOp, "failed to materialize replacements for all results");
rewriter.replaceAllUsesWith(forallOp.getResults(), results);
// Erase the old terminator and the loop.
rewriter.eraseOp(terminator);
rewriter.eraseOp(forallOp);
+ return success();
}
LoopNest mlir::scf::buildLoopNest(
@@ -1789,7 +1800,8 @@ struct ForallOpSingleOrZeroIterationDimsFolder
// All of the loop dimensions perform a single iteration. Inline loop body.
if (newMixedLowerBounds.empty()) {
- promote(rewriter, op);
+ if (failed(promote(rewriter, op)))
+ return failure();
return success();
}
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index fa97b49a41d97..2932000b85b3b 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3947,6 +3947,21 @@ Operation *ParallelInsertSliceOp::getIteratingParent() {
return nullptr;
}
+FailureOr<SmallVector<Value>>
+ParallelInsertSliceOp::promoteInParallelLoop(RewriterBase &rewriter) {
+ Value dst = getDest();
+ Value src = getSource();
+ if (!isa<TensorType>(src.getType()))
+ return rewriter.notifyMatchFailure(getOperation(),
+ "expected tensor source");
+
+ Value inserted = tensor::InsertSliceOp::create(
+ rewriter, getLoc(), dst.getType(), src, dst, getOffsets(), getSizes(),
+ getStrides(), getStaticOffsets(), getStaticSizes(), getStaticStrides());
+
+ return SmallVector<Value>{inserted};
+}
+
//===----------------------------------------------------------------------===//
// ScatterOp
//===----------------------------------------------------------------------===//
>From 551f35e3d6c4890b2c5f16e48bd2c1d1d9645bed Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Thu, 18 Sep 2025 14:00:55 -0400
Subject: [PATCH 2/2] Return a single Value from promoteInParallelLoop and implement canPromoteInParallelLoop. Related test: test/Dialect/SCF/transform-ops.mlir
---
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td | 2 +-
.../mlir/Interfaces/ParallelCombiningOpInterface.h | 1 -
.../mlir/Interfaces/ParallelCombiningOpInterface.td | 7 ++++---
mlir/lib/Dialect/SCF/IR/SCF.cpp | 10 ++++------
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 12 ++++++++----
5 files changed, 17 insertions(+), 15 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index be04c3a4aebbe..4fb4cc8410230 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1475,7 +1475,7 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
OffsetSizeAndStrideOpInterface,
DeclareOpInterfaceMethods<ParallelCombiningOpInterface,
["getUpdatedDestinations", "getIteratingParent",
- "promoteInParallelLoop"]>,
+ "promoteInParallelLoop", "canPromoteInParallelLoop"]>,
// TODO: Cannot use an interface here atm, verify this manually for now.
// HasParent<"InParallelOpInterface">
]> {
diff --git a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h
index ff4e5a87d05c7..85cc18c47a527 100644
--- a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h
@@ -17,7 +17,6 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"
-#include "llvm/ADT/SmallVector.h"
namespace mlir {
namespace detail {
diff --git a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td
index 632371b2777fd..1a333d82d8468 100644
--- a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td
+++ b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td
@@ -109,10 +109,10 @@ def ParallelCombiningOpInterface : OpInterface<"ParallelCombiningOpInterface"> {
InterfaceMethod<
/*desc=*/[{
Promotes this parallel combining op out of its enclosing parallel loop
- and returns the values that should replace the destinations updated by
+ and returns the value that should replace the destination updated by
this op.
}],
- /*retTy=*/"::mlir::FailureOr<::llvm::SmallVector<::mlir::Value>>",
+ /*retTy=*/"::mlir::FailureOr<::mlir::Value>",
/*methodName=*/"promoteInParallelLoop",
/*args=*/(ins "::mlir::RewriterBase &":$rewriter)
>,
@@ -124,7 +124,8 @@ def ParallelCombiningOpInterface : OpInterface<"ParallelCombiningOpInterface"> {
/*retTy=*/"bool",
/*methodName=*/"canPromoteInParallelLoop",
/*args=*/(ins "::mlir::RewriterBase &":$rewriter),
- /*methodBody=*/[{ return true; }]
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{ return false; }]
>,
];
}
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 4115ca00f64b5..04737738d8593 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -672,14 +672,12 @@ LogicalResult mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp)
auto parallelCombiningOp =
dyn_cast<ParallelCombiningOpInterface>(&yieldingOp);
if (!parallelCombiningOp)
- return rewriter.notifyMatchFailure(
- forallOp, "terminator has non-parallel-combining op");
+ continue;
if (!parallelCombiningOp.canPromoteInParallelLoop(rewriter))
return rewriter.notifyMatchFailure(
forallOp, "parallel combining op cannot be promoted");
}
-
// Replace block arguments with lower bounds (replacements for IVs) and
// outputs.
SmallVector<Value> bbArgReplacements = forallOp.getLowerBound(rewriter);
@@ -702,12 +700,12 @@ LogicalResult mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp)
assert(parallelCombiningOp.canPromoteInParallelLoop(rewriter));
- FailureOr<SmallVector<Value>> promotedValues =
+ FailureOr<Value> promotedValue =
parallelCombiningOp.promoteInParallelLoop(rewriter);
- if (failed(promotedValues))
+ if (failed(promotedValue))
return failure();
- results.append(promotedValues->begin(), promotedValues->end());
+ results.push_back(*promotedValue);
}
if (results.size() != forallOp.getResults().size())
return rewriter.notifyMatchFailure(
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 2932000b85b3b..f05c58a40fde0 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3947,19 +3947,23 @@ Operation *ParallelInsertSliceOp::getIteratingParent() {
return nullptr;
}
-FailureOr<SmallVector<Value>>
+FailureOr<Value>
ParallelInsertSliceOp::promoteInParallelLoop(RewriterBase &rewriter) {
Value dst = getDest();
Value src = getSource();
if (!isa<TensorType>(src.getType()))
- return rewriter.notifyMatchFailure(getOperation(),
- "expected tensor source");
+ return failure();
Value inserted = tensor::InsertSliceOp::create(
rewriter, getLoc(), dst.getType(), src, dst, getOffsets(), getSizes(),
getStrides(), getStaticOffsets(), getStaticSizes(), getStaticStrides());
- return SmallVector<Value>{inserted};
+ return inserted;
+}
+
+bool ParallelInsertSliceOp::canPromoteInParallelLoop(RewriterBase &) {
+ return isa<TensorType>(getSource().getType()) &&
+ isa<TensorType>(getDest().getType());
}
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list