[Mlir-commits] [mlir] [mlir][scf] Implement Conversion from scf.parallel to Nested scf.for (PR #147692)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jul 9 03:44:49 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Michael Marjieh (mmarjieh)
<details>
<summary>Changes</summary>
Add a utility function/transform operation to convert `scf.parallel` loops to nested `scf.for` loops.
---
Full diff: https://github.com/llvm/llvm-project/pull/147692.diff
9 Files Affected:
- (modified) mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td (+28)
- (modified) mlir/include/mlir/Dialect/SCF/Transforms/Passes.h (+3)
- (modified) mlir/include/mlir/Dialect/SCF/Transforms/Passes.td (+11)
- (modified) mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h (+6)
- (modified) mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp (+38)
- (modified) mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt (+1)
- (added) mlir/lib/Dialect/SCF/Transforms/ParallelForToNestedFors.cpp (+91)
- (added) mlir/test/Dialect/SCF/parallel-to-nested-fors.mlir (+80)
- (added) mlir/test/Dialect/SCF/transform-op-parallel-to-nested-fors.mlir (+62)
``````````diff
diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index 5dba8c5e57ba8..e2b42208f3f8e 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -105,6 +105,34 @@ def ForallToParallelOp : Op<Transform_Dialect, "loop.forall_to_parallel",
let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
}
+def ParallelForToNestedForOps : Op<Transform_Dialect, "loop.parallel_for_to_nested_fors",
+ [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+ DeclareOpInterfaceMethods<TransformOpInterface>]> {
+ let summary = "Converts scf.parallel into a nest of scf.for operations";
+ let description = [{
+ Converts the `scf.parallel` operation pointed to by the given handle into a
+ set of nested `scf.for` operations. Each new operation corresponds to one
+ dimension of the original parallel loop.
+
+ The operand handle must be associated with exactly one payload operation.
+
+ Parallel loops that produce results (reductions) are currently not supported.
+
+ #### Return Modes
+
+ Consumes the operand handle. Produces a silenceable failure if the operand
+ is not associated with a single `scf.parallel` payload operation.
+ Returns a single handle associated with the outermost generated `scf.for`
+ loop.
+ Produces a silenceable failure if another number of resulting handles is
+ requested.
+ }];
+ let arguments = (ins TransformHandleTypeInterface:$target);
+ let results = (outs Variadic<TransformHandleTypeInterface>:$transformed);
+
+ let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
+}
+
def LoopOutlineOp : Op<Transform_Dialect, "loop.outline",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
DeclareOpInterfaceMethods<TransformOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
index b70599df6f503..54b0118507184 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
@@ -62,6 +62,9 @@ std::unique_ptr<Pass> createForallToForLoopPass();
/// Creates a pass that converts SCF forall loops to SCF parallel loops.
std::unique_ptr<Pass> createForallToParallelLoopPass();
+/// Creates a pass that converts SCF parallel loops to nested SCF for loops.
+std::unique_ptr<Pass> createParallelForToNestedForsPass();
+
// Creates a pass which lowers for loops into while loops.
std::unique_ptr<Pass> createForToWhileLoopPass();
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index 6e5ef96c450aa..afa4ef460c219 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -124,6 +124,17 @@ def SCFForallToParallelLoop : Pass<"scf-forall-to-parallel"> {
let constructor = "mlir::createForallToParallelLoopPass()";
}
+def SCFParallelForToNestedFors : Pass<"scf-parallel-for-to-nested-fors"> {
+ let summary = "Convert SCF parallel for loops to nested SCF for loops";
+ let constructor = "mlir::createParallelForToNestedForsPass()";
+ let description = [{
+ This pass transforms SCF.ParallelOp operations into a nest of SCF.ForOp
+ operations. The transformation is useful for cases where the parallel loop
+ can be expressed as a series of sequential iterations, allowing for more
+ fine-grained control over the loop execution.
+ }];
+}
+
def SCFForToWhileLoop : Pass<"scf-for-to-while"> {
let summary = "Convert SCF for loops to SCF while loops";
let constructor = "mlir::createForToWhileLoopPass()";
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
index 63163b77f7f16..5e613238d016d 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
@@ -42,6 +42,12 @@ LogicalResult forallToForLoop(RewriterBase &rewriter, ForallOp forallOp,
LogicalResult forallToParallelLoop(RewriterBase &rewriter, ForallOp forallOp,
ParallelOp *result = nullptr);
+/// Try converting scf.parallel into a nest of scf.for loops.
+/// The conversion is only supported for parallel operations with no results.
+LogicalResult parallelForToNestedFors(RewriterBase &rewriter,
+ ParallelOp parallelOp,
+ ForOp *result = nullptr);
+
/// Fuses all adjacent scf.parallel operations with identical bounds and step
/// into one scf.parallel operations. Uses a naive aliasing and dependency
/// analysis.
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 57c27231f2144..7fd9255c490ef 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -149,6 +149,44 @@ transform::ForallToParallelOp::apply(transform::TransformRewriter &rewriter,
return DiagnosedSilenceableFailure::success();
}
+//===----------------------------------------------------------------------===//
+// ParallelForToNestedForOps
+//===----------------------------------------------------------------------===//
+
+/// Applies the transform: expects the target handle to map to exactly one
+/// `scf.parallel` payload op and exactly one result handle to be requested,
+/// rewrites the payload via `scf::parallelForToNestedFors`, and associates the
+/// result handle with the loop that utility reports. Every precondition
+/// violation or rewrite failure is reported as a silenceable error.
+DiagnosedSilenceableFailure transform::ParallelForToNestedForOps::apply(
+    transform::TransformRewriter &rewriter,
+    transform::TransformResults &results, transform::TransformState &state) {
+  auto payloadOps = state.getPayloadOps(getTarget());
+  if (!llvm::hasSingleElement(payloadOps))
+    return emitSilenceableError() << "expected a single payload op";
+
+  Operation *payloadOp = *payloadOps.begin();
+  auto parallelOp = dyn_cast<scf::ParallelOp>(payloadOp);
+  if (!parallelOp) {
+    DiagnosedSilenceableFailure diag =
+        emitSilenceableError() << "expected the payload to be scf.parallel";
+    diag.attachNote(payloadOp->getLoc()) << "payload op";
+    return diag;
+  }
+
+  // The op declares a variadic result list, so the count is checked here.
+  const unsigned numResults = getNumResults();
+  if (numResults != 1) {
+    DiagnosedSilenceableFailure diag =
+        emitSilenceableError() << "op expects one result, given " << numResults;
+    diag.attachNote(parallelOp.getLoc()) << "payload op";
+    return diag;
+  }
+
+  scf::ForOp loopNestRoot;
+  if (failed(scf::parallelForToNestedFors(rewriter, parallelOp, &loopNestRoot)))
+    return emitSilenceableError()
+           << "failed to convert parallel into nested fors";
+
+  results.set(cast<OpResult>(getTransformed()[0]), {loopNestRoot});
+  return DiagnosedSilenceableFailure::success();
+}
+
//===----------------------------------------------------------------------===//
// LoopOutlineOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index 84dd992bec53a..a9ffa9dc208a0 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
LoopPipelining.cpp
LoopRangeFolding.cpp
LoopSpecialization.cpp
+ ParallelForToNestedFors.cpp
ParallelLoopCollapsing.cpp
ParallelLoopFusion.cpp
ParallelLoopTiling.cpp
diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelForToNestedFors.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelForToNestedFors.cpp
new file mode 100644
index 0000000000000..75672f1c9239e
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelForToNestedFors.cpp
@@ -0,0 +1,91 @@
+//===- ParallelForToNestedFors.cpp - scf.parallel to nested scf.for ops --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Transforms SCF.ParallelOp to nested scf.for ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Passes.h"
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "llvm/Support/Debug.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_SCFPARALLELFORTONESTEDFORS
+#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+/// Rewrites `parallelOp` as a nest of `scf.for` loops built from its lower
+/// bounds, upper bounds and steps.
+///
+/// The ops in the body of `parallelOp` (all but its terminator) are moved into
+/// the body of the last loop of the nest, uses of the parallel induction
+/// variables are redirected to the induction variables of the new loops, and
+/// the parallel op is erased. Parallel ops that produce results are rejected:
+/// an error is emitted on the op and failure is returned.
+///
+/// On success, `*result` (when `result` is non-null) is set to the first loop
+/// of the generated nest.
+LogicalResult mlir::scf::parallelForToNestedFors(RewriterBase &rewriter,
+                                                 scf::ParallelOp parallelOp,
+                                                 scf::ForOp *result) {
+  if (!parallelOp.getResults().empty()) {
+    parallelOp->emitError("Currently ScfParallel to ScfFor conversion "
+                          "doesn't support ScfParallel with results.");
+    return failure();
+  }
+
+  rewriter.setInsertionPoint(parallelOp);
+
+  Location loc = parallelOp.getLoc();
+  auto lowerBounds = parallelOp.getLowerBound();
+  auto upperBounds = parallelOp.getUpperBound();
+  auto steps = parallelOp.getStep();
+
+  assert(lowerBounds.size() == upperBounds.size() &&
+         lowerBounds.size() == steps.size() &&
+         "Mismatched parallel loop bounds");
+
+  auto loopNest =
+      scf::buildLoopNest(rewriter, loc, lowerBounds, upperBounds, steps);
+
+  auto oldInductionVars = parallelOp.getInductionVars();
+  auto newInductionVars = llvm::map_to_vector(
+      loopNest.loops, [](scf::ForOp forOp) { return forOp.getInductionVar(); });
+  assert(oldInductionVars.size() == newInductionVars.size() &&
+         "Mismatched induction variables");
+  // Update uses through the rewriter so that listeners are notified.
+  for (auto [oldIV, newIV] : llvm::zip(oldInductionVars, newInductionVars))
+    rewriter.replaceAllUsesWith(oldIV, newIV);
+
+  Block *linearizedBody = loopNest.loops.back().getBody();
+  Block &parallelBody = *parallelOp.getBody();
+  // Early-inc iteration: each op is unlinked from `parallelBody` as it moves.
+  for (Operation &op : llvm::make_early_inc_range(parallelBody)) {
+    // Skip the terminator of the parallelOp body.
+    if (&op == parallelBody.getTerminator())
+      continue;
+    rewriter.moveOpBefore(&op, linearizedBody->getTerminator());
+  }
+  rewriter.eraseOp(parallelOp);
+  if (result)
+    *result = loopNest.loops.front();
+  return success();
+}
+
+namespace {
+/// Pass that rewrites every `scf.parallel` found while walking the root
+/// operation into nested `scf.for` loops. The pass is marked as failed when
+/// any individual rewrite fails.
+struct ParallelForToNestedFors final
+    : public impl::SCFParallelForToNestedForsBase<ParallelForToNestedFors> {
+  void runOnOperation() override {
+    Operation *root = getOperation();
+    IRRewriter rewriter(root->getContext());
+
+    root->walk([this, &rewriter](scf::ParallelOp op) {
+      LogicalResult status = scf::parallelForToNestedFors(rewriter, op);
+      if (failed(status))
+        signalPassFailure();
+    });
+  }
+};
+} // namespace
+
+/// Factory for the scf.parallel to nested scf.for conversion pass.
+std::unique_ptr<Pass> mlir::createParallelForToNestedForsPass() {
+  return std::make_unique<ParallelForToNestedFors>();
+}
diff --git a/mlir/test/Dialect/SCF/parallel-to-nested-fors.mlir b/mlir/test/Dialect/SCF/parallel-to-nested-fors.mlir
new file mode 100644
index 0000000000000..4df7bab790ea5
--- /dev/null
+++ b/mlir/test/Dialect/SCF/parallel-to-nested-fors.mlir
@@ -0,0 +1,80 @@
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(scf-parallel-for-to-nested-fors))' -split-input-file -verify-diagnostics | FileCheck %s
+
+func.func private @callee(%i: index, %j: index)
+
+func.func @two_iters(%lb1: index, %lb2: index, %ub1: index, %ub2: index, %step1: index, %step2: index) {
+ scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
+ func.call @callee(%i, %j) : (index, index) -> ()
+ }
+ // CHECK: scf.for %[[VAL_0:.*]] = %[[ARG0:.*]] to %[[ARG2:.*]] step %[[ARG4:.*]] {
+ // CHECK: scf.for %[[VAL_1:.*]] = %[[ARG1:.*]] to %[[ARG3:.*]] step %[[ARG5:.*]] {
+ // CHECK: func.call @callee(%[[VAL_0]], %[[VAL_1]]) : (index, index) -> ()
+ // CHECK: }
+ // CHECK: }
+ return
+}
+
+// -----
+
+func.func private @callee(%i: index, %j: index)
+
+func.func @repeated(%lb1: index, %lb2: index, %ub1: index, %ub2: index, %step1: index, %step2: index) {
+ scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
+ func.call @callee(%i, %j) : (index, index) -> ()
+ }
+
+ scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
+ func.call @callee(%i, %j) : (index, index) -> ()
+ }
+ // CHECK: scf.for %[[VAL_0:.*]] = %[[ARG0:.*]] to %[[ARG2:.*]] step %[[ARG4:.*]] {
+ // CHECK: scf.for %[[VAL_1:.*]] = %[[ARG1:.*]] to %[[ARG3:.*]] step %[[ARG5:.*]] {
+ // CHECK: func.call @callee(%[[VAL_0]], %[[VAL_1]]) : (index, index) -> ()
+ // CHECK: }
+ // CHECK: }
+ // CHECK: scf.for %[[VAL_2:.*]] = %[[ARG0]] to %[[ARG2]] step %[[ARG4]] {
+ // CHECK: scf.for %[[VAL_3:.*]] = %[[ARG1]] to %[[ARG3]] step %[[ARG5]] {
+ // CHECK: func.call @callee(%[[VAL_2]], %[[VAL_3]]) : (index, index) -> ()
+ // CHECK: }
+ // CHECK: }
+
+ return
+}
+
+// -----
+
+func.func private @callee(%i: index, %j: index, %k: index, %l: index)
+
+func.func @nested(%lb1: index, %lb2: index, %lb3: index, %lb4: index, %ub1: index, %ub2: index, %ub3: index, %ub4: index, %step1: index, %step2: index, %step3: index, %step4: index) {
+ scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
+ scf.parallel (%k, %l) = (%lb3, %lb4) to (%ub3, %ub4) step (%step3, %step4) {
+ func.call @callee(%i, %j, %k, %l) : (index, index, index, index) -> ()
+ }
+ }
+ // CHECK: scf.for %[[VAL_0:.*]] = %[[ARG0:.*]] to %[[ARG4:.*]] step %[[ARG8:.*]] {
+ // CHECK: scf.for %[[VAL_1:.*]] = %[[ARG1:.*]] to %[[ARG5:.*]] step %[[ARG9:.*]] {
+ // CHECK: scf.for %[[VAL_2:.*]] = %[[ARG2:.*]] to %[[ARG6:.*]] step %[[ARG10:.*]] {
+ // CHECK: scf.for %[[VAL_3:.*]] = %[[ARG3:.*]] to %[[ARG7:.*]] step %[[ARG11:.*]] {
+ // CHECK: func.call @callee(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[VAL_3]]) : (index, index, index, index) -> ()
+ // CHECK: }
+ // CHECK: }
+ // CHECK: }
+ // CHECK: }
+ return
+}
+
+// -----
+func.func private @callee(%i: index, %j: index) -> i32
+
+func.func @two_iters_with_reduce(%lb1: index, %lb2: index, %ub1: index, %ub2: index, %step1: index, %step2: index) -> i32 {
+ %c0 = arith.constant 0 : i32
+ // expected-error@+1 {{Currently ScfParallel to ScfFor conversion doesn't support ScfParallel with results}}
+ %0 = scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) init (%c0) -> i32 {
+ %curr = func.call @callee(%i, %j) : (index, index) -> i32
+ scf.reduce(%curr : i32) {
+ ^bb0(%arg3: i32, %arg4: i32):
+ %3 = arith.addi %arg3, %arg4 : i32
+ scf.reduce.return %3 : i32
+ }
+ }
+ return %0 : i32
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/SCF/transform-op-parallel-to-nested-fors.mlir b/mlir/test/Dialect/SCF/transform-op-parallel-to-nested-fors.mlir
new file mode 100644
index 0000000000000..496123b288038
--- /dev/null
+++ b/mlir/test/Dialect/SCF/transform-op-parallel-to-nested-fors.mlir
@@ -0,0 +1,62 @@
+// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics | FileCheck %s
+
+func.func private @callee(%i: index, %j: index)
+
+func.func @two_iters(%lb1: index, %lb2: index, %ub1: index, %ub2: index, %step1: index, %step2: index) {
+ scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
+ func.call @callee(%i, %j) : (index, index) -> ()
+ }
+ // CHECK: scf.for %[[VAL_0:.*]] = %[[ARG0:.*]] to %[[ARG2:.*]] step %[[ARG4:.*]] {
+ // CHECK: scf.for %[[VAL_1:.*]] = %[[ARG1:.*]] to %[[ARG3:.*]] step %[[ARG5:.*]] {
+ // CHECK: func.call @callee(%[[VAL_0]], %[[VAL_1]]) : (index, index) -> ()
+ // CHECK: }
+ // CHECK: }
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.loop.parallel_for_to_nested_fors %0 : (!transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
+func.func private @callee(%i: index, %j: index)
+
+func.func @repeated(%lb1: index, %lb2: index, %ub1: index, %ub2: index, %step1: index, %step2: index) {
+ scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
+ func.call @callee(%i, %j) : (index, index) -> ()
+ }
+
+ scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
+ func.call @callee(%i, %j) : (index, index) -> ()
+ }
+
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ // expected-error @below {{expected a single payload op}}
+ transform.loop.parallel_for_to_nested_fors %0 : (!transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
+// expected-note @below {{payload op}}
+func.func private @callee(%i: index, %j: index)
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ // expected-error @below {{expected the payload to be scf.parallel}}
+ transform.loop.parallel_for_to_nested_fors %0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/147692
More information about the Mlir-commits
mailing list