[Mlir-commits] [mlir] [mlir][scf] Implement Conversion from scf.parallel to Nested scf.for (PR #147692)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jul 9 03:44:48 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-scf

Author: Michael Marjieh (mmarjieh)

<details>
<summary>Changes</summary>

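Add a utility function, a pass (`scf-parallel-for-to-nested-fors`), and a transform operation (`transform.loop.parallel_for_to_nested_fors`) that convert `scf.parallel` loops into nested `scf.for` loops.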

---
Full diff: https://github.com/llvm/llvm-project/pull/147692.diff


9 Files Affected:

- (modified) mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td (+28) 
- (modified) mlir/include/mlir/Dialect/SCF/Transforms/Passes.h (+3) 
- (modified) mlir/include/mlir/Dialect/SCF/Transforms/Passes.td (+11) 
- (modified) mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h (+6) 
- (modified) mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp (+38) 
- (modified) mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt (+1) 
- (added) mlir/lib/Dialect/SCF/Transforms/ParallelForToNestedFors.cpp (+91) 
- (added) mlir/test/Dialect/SCF/parallel-to-nested-fors.mlir (+80) 
- (added) mlir/test/Dialect/SCF/transform-op-parallel-to-nested-fors.mlir (+62) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index 5dba8c5e57ba8..e2b42208f3f8e 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -105,6 +105,34 @@ def ForallToParallelOp : Op<Transform_Dialect, "loop.forall_to_parallel",
   let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
 }
 
+def ParallelForToNestedForOps : Op<Transform_Dialect, "loop.parallel_for_to_nested_fors",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let summary = "Converts scf.parallel into a nest of scf.for operations";
+  let description = [{
+    Converts the `scf.parallel` operation pointed to by the given handle into a
+    set of nested `scf.for` operations. Each new loop corresponds to one
+    dimension of the original parallel loop, outermost dimension first.
+
+    The operand handle must be associated with exactly one payload operation.
+
+    Parallel loops with results (reductions) are currently not supported.
+
+    #### Return Modes
+
+    Consumes the operand handle. Produces a silenceable failure if the operand
+    is not associated with a single `scf.parallel` payload operation, or if
+    the conversion fails (e.g. the loop has results).
+    Returns exactly one handle, associated with the outermost generated
+    `scf.for` loop. Produces a silenceable failure if a different number of
+    result handles is requested.
+  }];
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs Variadic<TransformHandleTypeInterface>:$transformed);
+
+  let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
+}
+
 def LoopOutlineOp : Op<Transform_Dialect, "loop.outline",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
      DeclareOpInterfaceMethods<TransformOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
index b70599df6f503..54b0118507184 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
@@ -62,6 +62,9 @@ std::unique_ptr<Pass> createForallToForLoopPass();
 /// Creates a pass that converts SCF forall loops to SCF parallel loops.
 std::unique_ptr<Pass> createForallToParallelLoopPass();
 
+/// Creates a pass that converts SCF parallel loops to nested SCF for loops.
+std::unique_ptr<Pass> createParallelForToNestedForsPass();
+
 // Creates a pass which lowers for loops into while loops.
 std::unique_ptr<Pass> createForToWhileLoopPass();
 
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index 6e5ef96c450aa..afa4ef460c219 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -124,6 +124,17 @@ def SCFForallToParallelLoop : Pass<"scf-forall-to-parallel"> {
   let constructor = "mlir::createForallToParallelLoopPass()";
 }
 
+def SCFParallelForToNestedFors : Pass<"scf-parallel-for-to-nested-fors"> {
+  let summary = "Convert SCF parallel for loops to nested SCF for loops";
+  let constructor = "mlir::createParallelForToNestedForsPass()";
+  let description = [{
+    Converts each `scf.parallel` operation into a nest of `scf.for` operations,
+    with one loop per parallel dimension, outermost dimension first.
+    Parallel loops that produce results (reductions) are not supported: the
+    pass emits an error on such loops and signals failure.
+  }];
+}
+
 def SCFForToWhileLoop : Pass<"scf-for-to-while"> {
   let summary = "Convert SCF for loops to SCF while loops";
   let constructor = "mlir::createForToWhileLoopPass()";
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
index 63163b77f7f16..5e613238d016d 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
@@ -42,6 +42,15 @@ LogicalResult forallToForLoop(RewriterBase &rewriter, ForallOp forallOp,
 LogicalResult forallToParallelLoop(RewriterBase &rewriter, ForallOp forallOp,
                                    ParallelOp *result = nullptr);
 
+/// Try converting scf.parallel into a nest of scf.for loops, one loop per
+/// dimension (outermost dimension first). The conversion is only supported
+/// for parallel operations with no results; for other operations an error is
+/// emitted and the IR is left unchanged. On success, `parallelOp` is erased
+/// and, if `result` is non-null, it is set to the outermost scf.for loop.
+LogicalResult parallelForToNestedFors(RewriterBase &rewriter,
+                                      ParallelOp parallelOp,
+                                      ForOp *result = nullptr);
+
 /// Fuses all adjacent scf.parallel operations with identical bounds and step
 /// into one scf.parallel operations. Uses a naive aliasing and dependency
 /// analysis.
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 57c27231f2144..7fd9255c490ef 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -149,6 +149,51 @@ transform::ForallToParallelOp::apply(transform::TransformRewriter &rewriter,
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// ParallelForToNestedForOps
+//===----------------------------------------------------------------------===//
+
+/// Converts the single `scf.parallel` payload op associated with the target
+/// handle into a loop nest and maps the sole result handle to the outermost
+/// generated `scf.for`.
+DiagnosedSilenceableFailure transform::ParallelForToNestedForOps::apply(
+    transform::TransformRewriter &rewriter,
+    transform::TransformResults &results, transform::TransformState &state) {
+  auto payload = state.getPayloadOps(getTarget());
+  if (!llvm::hasSingleElement(payload))
+    return emitSilenceableError() << "expected a single payload op";
+
+  auto target = dyn_cast<scf::ParallelOp>(*payload.begin());
+  if (!target) {
+    DiagnosedSilenceableFailure diag =
+        emitSilenceableError() << "expected the payload to be scf.parallel";
+    diag.attachNote((*payload.begin())->getLoc()) << "payload op";
+    return diag;
+  }
+
+  // The result handle maps to the outermost loop of the generated nest only.
+  if (getNumResults() != 1) {
+    DiagnosedSilenceableFailure diag = emitSilenceableError()
+                                       << "op expects one result, given "
+                                       << getNumResults();
+    diag.attachNote(target.getLoc()) << "payload op";
+    return diag;
+  }
+
+  scf::ForOp opResult;
+  if (failed(scf::parallelForToNestedFors(rewriter, target, &opResult))) {
+    // The conversion bails out before mutating the IR, so `target` is still
+    // valid here and can anchor the note.
+    DiagnosedSilenceableFailure diag =
+        emitSilenceableError() << "failed to convert parallel into nested fors";
+    diag.attachNote(target.getLoc()) << "payload op";
+    return diag;
+  }
+
+  results.set(cast<OpResult>(getTransformed()[0]), {opResult});
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // LoopOutlineOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index 84dd992bec53a..a9ffa9dc208a0 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
   LoopPipelining.cpp
   LoopRangeFolding.cpp
   LoopSpecialization.cpp
+  ParallelForToNestedFors.cpp
   ParallelLoopCollapsing.cpp
   ParallelLoopFusion.cpp
   ParallelLoopTiling.cpp
diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelForToNestedFors.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelForToNestedFors.cpp
new file mode 100644
index 0000000000000..75672f1c9239e
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelForToNestedFors.cpp
@@ -0,0 +1,99 @@
+//===- ParallelForToNestedFors.cpp - scf.parallel to nested scf.for ops --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Transforms SCF.ParallelOp to nested scf.for ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Passes.h"
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "llvm/Support/Debug.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_SCFPARALLELFORTONESTEDFORS
+#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+/// Rewrites `parallelOp` into a nest of `scf.for` loops, one per dimension
+/// (outermost dimension first), and moves the parallel body into the
+/// innermost loop. All IR changes go through `rewriter` so that attached
+/// listeners (e.g. transform-dialect handle tracking) are notified.
+LogicalResult mlir::scf::parallelForToNestedFors(RewriterBase &rewriter,
+                                                 scf::ParallelOp parallelOp,
+                                                 scf::ForOp *result) {
+  // Reductions would require threading iter_args through every loop of the
+  // nest; reject them before touching the IR.
+  if (!parallelOp.getResults().empty()) {
+    parallelOp->emitError("Currently ScfParallel to ScfFor conversion "
+                          "doesn't support ScfParallel with results.");
+    return failure();
+  }
+
+  // Keep the caller's insertion point intact once this function returns.
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(parallelOp);
+
+  Location loc = parallelOp.getLoc();
+  ValueRange lowerBounds = parallelOp.getLowerBound();
+  ValueRange upperBounds = parallelOp.getUpperBound();
+  ValueRange steps = parallelOp.getStep();
+  assert(lowerBounds.size() == upperBounds.size() &&
+         lowerBounds.size() == steps.size() &&
+         "Mismatched parallel loop bounds");
+
+  scf::LoopNest loopNest =
+      scf::buildLoopNest(rewriter, loc, lowerBounds, upperBounds, steps);
+  assert(!loopNest.loops.empty() && "expected at least one loop");
+
+  SmallVector<Value> newInductionVars = llvm::map_to_vector(
+      loopNest.loops, [](scf::ForOp forOp) { return forOp.getInductionVar(); });
+  assert(parallelOp.getInductionVars().size() == newInductionVars.size() &&
+         "Mismatched induction variables");
+
+  // Drop the body terminator (an operand-less scf.reduce, since the op has no
+  // results) and splice the remaining ops in front of the innermost scf.yield.
+  // The body block arguments (the old induction variables) are replaced with
+  // the induction variables of the new loops.
+  Block *parallelBody = parallelOp.getBody();
+  rewriter.eraseOp(parallelBody->getTerminator());
+  Operation *innermostYield = loopNest.loops.back().getBody()->getTerminator();
+  rewriter.inlineBlockBefore(parallelBody, innermostYield, newInductionVars);
+
+  rewriter.eraseOp(parallelOp);
+  if (result)
+    *result = loopNest.loops.front();
+  return success();
+}
+
+namespace {
+/// Converts every scf.parallel nested under the pass anchor into nested
+/// scf.for loops. Loops that cannot be converted are diagnosed and make the
+/// pass fail; the walk still continues so that all of them are reported.
+struct ParallelForToNestedFors final
+    : public impl::SCFParallelForToNestedForsBase<ParallelForToNestedFors> {
+  void runOnOperation() override {
+    Operation *parentOp = getOperation();
+    IRRewriter rewriter(parentOp->getContext());
+
+    // The default post-order walk visits nested scf.parallel ops before their
+    // parents and allows erasing the op currently being visited.
+    parentOp->walk([&](scf::ParallelOp parallelOp) {
+      if (failed(scf::parallelForToNestedFors(rewriter, parallelOp)))
+        signalPassFailure();
+    });
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::createParallelForToNestedForsPass() {
+  return std::make_unique<ParallelForToNestedFors>();
+}
diff --git a/mlir/test/Dialect/SCF/parallel-to-nested-fors.mlir b/mlir/test/Dialect/SCF/parallel-to-nested-fors.mlir
new file mode 100644
index 0000000000000..4df7bab790ea5
--- /dev/null
+++ b/mlir/test/Dialect/SCF/parallel-to-nested-fors.mlir
@@ -0,0 +1,83 @@
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(scf-parallel-for-to-nested-fors))' -split-input-file -verify-diagnostics | FileCheck %s
+
+func.func private @callee(%i: index, %j: index)
+
+// CHECK-LABEL: func.func @two_iters(
+func.func @two_iters(%lb1: index, %lb2: index, %ub1: index, %ub2: index, %step1: index, %step2: index) {
+  scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+  // CHECK:           scf.for %[[VAL_0:.*]] = %[[ARG0:.*]] to %[[ARG2:.*]] step %[[ARG4:.*]] {
+  // CHECK:             scf.for %[[VAL_1:.*]] = %[[ARG1:.*]] to %[[ARG3:.*]] step %[[ARG5:.*]] {
+  // CHECK:               func.call @callee(%[[VAL_0]], %[[VAL_1]]) : (index, index) -> ()
+  // CHECK:             }
+  // CHECK:           }
+  return
+}
+
+// -----
+
+func.func private @callee(%i: index, %j: index)
+
+// CHECK-LABEL: func.func @repeated(
+func.func @repeated(%lb1: index, %lb2: index, %ub1: index, %ub2: index, %step1: index, %step2: index) {
+  scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+
+  scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+  // CHECK:           scf.for %[[VAL_0:.*]] = %[[ARG0:.*]] to %[[ARG2:.*]] step %[[ARG4:.*]] {
+  // CHECK:             scf.for %[[VAL_1:.*]] = %[[ARG1:.*]] to %[[ARG3:.*]] step %[[ARG5:.*]] {
+  // CHECK:               func.call @callee(%[[VAL_0]], %[[VAL_1]]) : (index, index) -> ()
+  // CHECK:             }
+  // CHECK:           }
+  // CHECK:           scf.for %[[VAL_2:.*]] = %[[ARG0]] to %[[ARG2]] step %[[ARG4]] {
+  // CHECK:             scf.for %[[VAL_3:.*]] = %[[ARG1]] to %[[ARG3]] step %[[ARG5]] {
+  // CHECK:               func.call @callee(%[[VAL_2]], %[[VAL_3]]) : (index, index) -> ()
+  // CHECK:             }
+  // CHECK:           }
+
+  return
+}
+
+// -----
+
+func.func private @callee(%i: index, %j: index, %k: index, %l: index)
+
+// CHECK-LABEL: func.func @nested(
+func.func @nested(%lb1: index, %lb2: index, %lb3: index, %lb4: index, %ub1: index, %ub2: index, %ub3: index, %ub4: index, %step1: index, %step2: index, %step3: index, %step4: index) {
+  scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
+    scf.parallel (%k, %l) = (%lb3, %lb4) to (%ub3, %ub4) step (%step3, %step4) {
+      func.call @callee(%i, %j, %k, %l) : (index, index, index, index) -> ()
+    }
+  }
+  // CHECK:           scf.for %[[VAL_0:.*]] = %[[ARG0:.*]] to %[[ARG4:.*]] step %[[ARG8:.*]] {
+  // CHECK:             scf.for %[[VAL_1:.*]] = %[[ARG1:.*]] to %[[ARG5:.*]] step %[[ARG9:.*]] {
+  // CHECK:               scf.for %[[VAL_2:.*]] = %[[ARG2:.*]] to %[[ARG6:.*]] step %[[ARG10:.*]] {
+  // CHECK:                 scf.for %[[VAL_3:.*]] = %[[ARG3:.*]] to %[[ARG7:.*]] step %[[ARG11:.*]] {
+  // CHECK:                   func.call @callee(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[VAL_3]]) : (index, index, index, index) -> ()
+  // CHECK:                 }
+  // CHECK:               }
+  // CHECK:             }
+  // CHECK:           }
+  return
+}
+
+// -----
+func.func private @callee(%i: index, %j: index) -> i32
+
+func.func @two_iters_with_reduce(%lb1: index, %lb2: index, %ub1: index, %ub2: index, %step1: index, %step2: index) -> i32 {
+  %c0 = arith.constant 0 : i32
+  // expected-error@+1 {{Currently ScfParallel to ScfFor conversion doesn't support ScfParallel with results}}
+  %0 = scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) init (%c0) -> i32 {
+    %curr = func.call @callee(%i, %j) : (index, index) -> i32
+    scf.reduce(%curr : i32) {
+      ^bb0(%arg3: i32, %arg4: i32):
+        %3 = arith.addi %arg3, %arg4 : i32
+        scf.reduce.return %3 : i32
+    }
+  }
+  return %0 : i32
+}
diff --git a/mlir/test/Dialect/SCF/transform-op-parallel-to-nested-fors.mlir b/mlir/test/Dialect/SCF/transform-op-parallel-to-nested-fors.mlir
new file mode 100644
index 0000000000000..496123b288038
--- /dev/null
+++ b/mlir/test/Dialect/SCF/transform-op-parallel-to-nested-fors.mlir
@@ -0,0 +1,84 @@
+// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics | FileCheck %s
+
+func.func private @callee(%i: index, %j: index)
+
+// CHECK-LABEL: func.func @two_iters(
+func.func @two_iters(%lb1: index, %lb2: index, %ub1: index, %ub2: index, %step1: index, %step2: index) {
+  scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+  // CHECK:           scf.for %[[VAL_0:.*]] = %[[ARG0:.*]] to %[[ARG2:.*]] step %[[ARG4:.*]] {
+  // CHECK:             scf.for %[[VAL_1:.*]] = %[[ARG1:.*]] to %[[ARG3:.*]] step %[[ARG5:.*]] {
+  // CHECK:               func.call @callee(%[[VAL_0]], %[[VAL_1]]) : (index, index) -> ()
+  // CHECK:             }
+  // CHECK:           }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.loop.parallel_for_to_nested_fors %0 : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+func.func private @callee(%i: index, %j: index)
+
+func.func @repeated(%lb1: index, %lb2: index, %ub1: index, %ub2: index, %step1: index, %step2: index) {
+  scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+
+  scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    // expected-error @below {{expected a single payload op}}
+    transform.loop.parallel_for_to_nested_fors %0 : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+// expected-note @below {{payload op}}
+func.func private @callee(%i: index, %j: index)
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    // expected-error @below {{expected the payload to be scf.parallel}}
+    transform.loop.parallel_for_to_nested_fors %0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func private @callee(%i: index, %j: index)
+
+func.func @too_many_results(%lb1: index, %lb2: index, %ub1: index, %ub2: index, %step1: index, %step2: index) {
+  // expected-note @below {{payload op}}
+  scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    // expected-error @below {{op expects one result, given 2}}
+    %1:2 = transform.loop.parallel_for_to_nested_fors %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/147692


More information about the Mlir-commits mailing list