[Mlir-commits] [mlir] [SCF] Add interface methods to `ParallelCombiningOp` for promotion (PR #159840)
Alan Li
llvmlistbot at llvm.org
Mon Dec 8 19:03:35 PST 2025
https://github.com/lialan updated https://github.com/llvm/llvm-project/pull/159840
>From 020e2c1d6b4a615d291100dff44c06031d1c8f11 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Thu, 18 Sep 2025 13:29:16 -0400
Subject: [PATCH] [MLIR][SCF] Add `promoteInParallelLoop` to
`ParallelCombiningOpInterface`
This is a follow-up to #157736, which introduced `ParallelCombiningOpInterface`.
This patch extends `ParallelCombiningOpInterface` with two new methods:
* `canPromoteInParallelLoop`: returns whether an op can be promoted
* `promoteInParallelLoop`: promotes the op and returns the replacement value
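The `scf::promote` function is refactored to use these interface methods
instead of hardcoding `tensor::ParallelInsertSliceOp` handling, and now returns
`LogicalResult` so that it can fail, without modifying the IR, when the ops in
the terminator cannot be promoted. This makes the promotion logic extensible to
other parallel combining ops.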
---
mlir/include/mlir/Dialect/SCF/IR/SCF.h | 2 +-
.../mlir/Dialect/Tensor/IR/TensorOps.td | 3 +-
.../Interfaces/ParallelCombiningOpInterface.h | 2 +
.../ParallelCombiningOpInterface.td | 21 ++++++++
mlir/lib/Dialect/SCF/IR/SCF.cpp | 54 ++++++++++++-------
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 19 +++++++
mlir/test/Dialect/SCF/canonicalize.mlir | 25 +++++++++
7 files changed, 104 insertions(+), 22 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
index e754a04b0903a..ac6f034dba728 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
@@ -58,7 +58,9 @@ ForallOp getForallOpThreadIndexOwner(Value val);
 bool insideMutuallyExclusiveBranches(Operation *a, Operation *b);

 /// Promotes the loop body of a scf::ForallOp to its containing block.
-void promote(RewriterBase &rewriter, scf::ForallOp forallOp);
+/// Returns failure, without modifying the IR, when the ops in the terminator
+/// cannot be promoted through `ParallelCombiningOpInterface`.
+LogicalResult promote(RewriterBase &rewriter, scf::ForallOp forallOp);

 /// An owning vector of values, handy to return from functions.
 using ValueVector = SmallVector<Value>;
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 35d2b6007c628..b03380bf65b8b 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1483,7 +1483,8 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
AttrSizedOperandSegments,
OffsetSizeAndStrideOpInterface,
DeclareOpInterfaceMethods<ParallelCombiningOpInterface,
- ["getUpdatedDestinations", "getIteratingParent"]>,
+ ["getUpdatedDestinations", "getIteratingParent",
+ "promoteInParallelLoop", "canPromoteInParallelLoop"]>,
// TODO: Cannot use an interface here atm, verify this manually for now.
// HasParent<"InParallelOpInterface">
]> {
diff --git a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h
index 82ab427699f64..85cc18c47a527 100644
--- a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h
@@ -15,6 +15,8 @@
#define MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE_H_
#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LogicalResult.h"
namespace mlir {
namespace detail {
diff --git a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td
index ace26f723ef53..1a333d82d8468 100644
--- a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td
+++ b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td
@@ -106,6 +106,34 @@ def ParallelCombiningOpInterface : OpInterface<"ParallelCombiningOpInterface"> {
       /*methodName=*/"getIteratingParent",
       /*args=*/(ins)
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Promotes this parallel combining op out of its enclosing parallel loop
+        and returns the value that replaces the loop result tied to the
+        destination updated by this op. The replacement IR is created at the
+        current insertion point of `rewriter`. Callers only invoke this after
+        `canPromoteInParallelLoop` returned true. The default implementation
+        returns failure.
+      }],
+      /*retTy=*/"::mlir::FailureOr<::mlir::Value>",
+      /*methodName=*/"promoteInParallelLoop",
+      /*args=*/(ins "::mlir::RewriterBase &":$rewriter),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{ return ::mlir::failure(); }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns true if this op can be promoted out of its enclosing parallel
+        loop with `promoteInParallelLoop`. Callers query every op before
+        modifying the IR, so implementations must not modify the IR. The
+        default implementation conservatively returns false.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"canPromoteInParallelLoop",
+      /*args=*/(ins "::mlir::RewriterBase &":$rewriter),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{ return false; }]
+    >,
   ];
 }

diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index c75528a76c999..b069ae90d4e68 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -777,8 +777,7 @@ LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
       return failure();
   }

-  promote(rewriter, *this);
-  return success();
+  return promote(rewriter, *this);
 }

 Block::BlockArgListType ForallOp::getRegionIterArgs() {
@@ -790,10 +789,47 @@ MutableArrayRef<OpOperand> ForallOp::getInitsMutable() {
 }

 /// Promotes the loop body of a scf::ForallOp to its containing block.
-void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
+LogicalResult mlir::scf::promote(RewriterBase &rewriter,
+                                 scf::ForallOp forallOp) {
   OpBuilder::InsertionGuard g(rewriter);
   scf::InParallelOp terminator = forallOp.getTerminator();

+  // Before touching the IR, map every shared output (and hence every loop
+  // result) to the single parallel combining op that updates it. Ops are
+  // matched through the destination they update rather than through their
+  // position in the terminator, which need not follow the output order.
+  ArrayRef<BlockArgument> outArgs = forallOp.getRegionOutArgs();
+  SmallVector<ParallelCombiningOpInterface> combiningOps(outArgs.size());
+  for (Operation &yieldingOp : terminator.getYieldingOps()) {
+    auto parallelCombiningOp =
+        dyn_cast<ParallelCombiningOpInterface>(&yieldingOp);
+    if (!parallelCombiningOp)
+      continue;
+    if (!parallelCombiningOp.canPromoteInParallelLoop(rewriter))
+      return rewriter.notifyMatchFailure(
+          forallOp, "parallel combining op cannot be promoted");
+    // promoteInParallelLoop produces a single value, so the op must update
+    // exactly one shared output.
+    MutableOperandRange dests = parallelCombiningOp.getUpdatedDestinations();
+    if (dests.size() != 1)
+      return rewriter.notifyMatchFailure(
+          forallOp, "parallel combining op must update exactly one output");
+    const auto *outArgIt = llvm::find(outArgs, dests[0].get());
+    if (outArgIt == outArgs.end())
+      return rewriter.notifyMatchFailure(
+          forallOp, "parallel combining op does not update a shared output");
+    ParallelCombiningOpInterface &slot =
+        combiningOps[outArgIt - outArgs.begin()];
+    if (slot)
+      return rewriter.notifyMatchFailure(
+          forallOp, "shared output is updated by more than one op");
+    slot = parallelCombiningOp;
+  }
+  if (llvm::any_of(combiningOps,
+                   [](ParallelCombiningOpInterface op) { return !op; }))
+    return rewriter.notifyMatchFailure(
+        forallOp, "shared output is not updated by any op");
+
   // Replace block arguments with lower bounds (replacements for IVs) and
   // outputs.
   SmallVector<Value> bbArgReplacements = forallOp.getLowerBound(rewriter);
@@ -808,31 +844,24 @@ void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
   rewriter.setInsertionPointAfter(forallOp);
   SmallVector<Value> results;
   results.reserve(forallOp.getResults().size());
-  for (auto &yieldingOp : terminator.getYieldingOps()) {
-    auto parallelInsertSliceOp =
-        dyn_cast<tensor::ParallelInsertSliceOp>(yieldingOp);
-    if (!parallelInsertSliceOp)
-      continue;
-
-    Value dst = parallelInsertSliceOp.getDest();
-    Value src = parallelInsertSliceOp.getSource();
-    if (llvm::isa<TensorType>(src.getType())) {
-      results.push_back(tensor::InsertSliceOp::create(
-          rewriter, forallOp.getLoc(), dst.getType(), src, dst,
-          parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
-          parallelInsertSliceOp.getStrides(),
-          parallelInsertSliceOp.getStaticOffsets(),
-          parallelInsertSliceOp.getStaticSizes(),
-          parallelInsertSliceOp.getStaticStrides()));
-    } else {
-      llvm_unreachable("unsupported terminator");
-    }
+  // Promote in result order so that `results[i]` replaces result `i`.
+  for (ParallelCombiningOpInterface parallelCombiningOp : combiningOps) {
+    // The IR has already been modified, so a failure here can no longer be
+    // reported as a match failure. Every op reported that it can be promoted
+    // above, so a failure is a violation of the interface contract.
+    FailureOr<Value> promotedValue =
+        parallelCombiningOp.promoteInParallelLoop(rewriter);
+    if (failed(promotedValue))
+      llvm::report_fatal_error("promoteInParallelLoop failed after "
+                               "canPromoteInParallelLoop succeeded");
+    results.push_back(*promotedValue);
   }
   rewriter.replaceAllUsesWith(forallOp.getResults(), results);

   // Erase the old terminator and the loop.
   rewriter.eraseOp(terminator);
   rewriter.eraseOp(forallOp);
+  return success();
 }

 LoopNest mlir::scf::buildLoopNest(
@@ -1890,7 +1919,9 @@ struct ForallOpSingleOrZeroIterationDimsFolder

     // All of the loop dimensions perform a single iteration. Inline loop body.
-    if (newMixedLowerBounds.empty()) {
-      promote(rewriter, op);
-      return success();
-    }
+    // NOTE(review): `promote` can fail now. If the code above has already
+    // created IR (e.g. constants for the folded induction variables), a failed
+    // promotion leaves that IR behind while the pattern reports failure.
+    // Verify, and if so, decide promotability before creating any IR.
+    if (newMixedLowerBounds.empty())
+      return promote(rewriter, op);

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 204e9bb73e12c..bf3790cd023af 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3962,6 +3962,33 @@ Operation *ParallelInsertSliceOp::getIteratingParent() {
   return nullptr;
 }

+/// Promotes this op out of its enclosing parallel loop by creating the
+/// equivalent sequential `tensor.insert_slice` at the current insertion point
+/// of `rewriter`. Fails, without creating any IR, when
+/// `canPromoteInParallelLoop` returns false.
+FailureOr<Value>
+ParallelInsertSliceOp::promoteInParallelLoop(RewriterBase &rewriter) {
+  // Reuse the query so that it and the transformation cannot disagree about
+  // which ops are promotable.
+  if (!canPromoteInParallelLoop(rewriter))
+    return failure();
+
+  Value dst = getDest();
+  Value inserted = tensor::InsertSliceOp::create(
+      rewriter, getLoc(), dst.getType(), getSource(), dst, getOffsets(),
+      getSizes(), getStrides(), getStaticOffsets(), getStaticSizes(),
+      getStaticStrides());
+  return inserted;
+}
+
+/// A parallel_insert_slice can be promoted when both its source and its
+/// destination are tensors, since promotion turns it into a
+/// `tensor.insert_slice`.
+bool ParallelInsertSliceOp::canPromoteInParallelLoop(RewriterBase &) {
+  return isa<TensorType>(getSource().getType()) &&
+         isa<TensorType>(getDest().getType());
+}
+
 //===----------------------------------------------------------------------===//
 // ScatterOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index ac590fc0c47b9..e99948fdc6926 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -2171,3 +2171,28 @@ func.func @scf_for_all_step_size_0() {
}
return
}
+
+// -----
+
+// Test single-iteration forall with multiple parallel_insert_slice ops.
+func.func @inline_forall_loop_multiple_results(
+ %arg0: tensor<8x8xf32>, %arg1: tensor<4x4xf32>,
+ %s0: tensor<2x3xf32>, %s1: tensor<2x2xf32>) -> (tensor<8x8xf32>, tensor<4x4xf32>) {
+ %0:2 = scf.forall (%i) in (1) shared_outs (%out0 = %arg0, %out1 = %arg1)
+ -> (tensor<8x8xf32>, tensor<4x4xf32>) {
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %s0 into %out0[0, 0] [2, 3] [1, 1]
+ : tensor<2x3xf32> into tensor<8x8xf32>
+ tensor.parallel_insert_slice %s1 into %out1[0, 0] [2, 2] [1, 1]
+ : tensor<2x2xf32> into tensor<4x4xf32>
+ }
+ }
+ return %0#0, %0#1 : tensor<8x8xf32>, tensor<4x4xf32>
+}
+// CHECK-LABEL: @inline_forall_loop_multiple_results
+// CHECK-SAME: %[[ARG0:.*]]: tensor<8x8xf32>, %[[ARG1:.*]]: tensor<4x4xf32>,
+// CHECK-SAME: %[[S0:.*]]: tensor<2x3xf32>, %[[S1:.*]]: tensor<2x2xf32>
+// CHECK-NOT: scf.forall
+// CHECK-DAG: %[[R0:.*]] = tensor.insert_slice %[[S0]] into %[[ARG0]]
+// CHECK-DAG: %[[R1:.*]] = tensor.insert_slice %[[S1]] into %[[ARG1]]
+// CHECK: return %[[R0]], %[[R1]]
More information about the Mlir-commits
mailing list