[Mlir-commits] [mlir] 12b9c0d - [mlir][scf] Implement Conversion from scf.parallel to Nested scf.for (#147692)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Aug 4 06:21:54 PDT 2025


Author: Michael Marjieh
Date: 2025-08-04T06:21:50-07:00
New Revision: 12b9c0da04a0b34cb12000bccf5b90e1f98d23d6

URL: https://github.com/llvm/llvm-project/commit/12b9c0da04a0b34cb12000bccf5b90e1f98d23d6
DIFF: https://github.com/llvm/llvm-project/commit/12b9c0da04a0b34cb12000bccf5b90e1f98d23d6.diff

LOG: [mlir][scf] Implement Conversion from scf.parallel to Nested scf.for (#147692)

Add a utility function/transform operation to convert `scf.parallel`
loops to nested `scf.for` loops.

Added: 
    mlir/lib/Dialect/SCF/Transforms/ParallelForToNestedFors.cpp
    mlir/test/Dialect/SCF/parallel-to-nested-fors.mlir
    mlir/test/Dialect/SCF/transform-op-parallel-to-nested-fors.mlir

Modified: 
    mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
    mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
    mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
    mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
    mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
    mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index 5dba8c5e57ba8..e2b42208f3f8e 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -105,6 +105,34 @@ def ForallToParallelOp : Op<Transform_Dialect, "loop.forall_to_parallel",
   let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
 }
 
+def ParallelForToNestedForOps : Op<Transform_Dialect, "loop.parallel_for_to_nested_fors",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let summary = "Converts scf.parallel into a nest of scf.for operations";
+  let description = [{
+    Converts the `scf.parallel` operation pointed to by the given handle into a
+    set of nested `scf.for` operations. Each new operation corresponds to one
+    dimension of the original parallel loop.
+
+    The operand handle must be associated with exactly one payload operation.
+
+    `scf.parallel` operations with results are currently not supported.
+
+    #### Return Modes
+
+    Consumes the operand handle. Produces a silenceable failure if the operand
+    is not associated with a single `scf.parallel` payload operation.
+    Returns exactly one handle, associated with the outermost generated
+    `scf.for` loop; the remaining loops are nested inside it.
+    Produces a silenceable failure if any other number of result handles is
+    requested, or if the conversion itself fails.
+  }];
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs Variadic<TransformHandleTypeInterface>:$transformed);
+
+  let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
+}
+
 def LoopOutlineOp : Op<Transform_Dialect, "loop.outline",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
      DeclareOpInterfaceMethods<TransformOpInterface>]> {

diff  --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
index b70599df6f503..54b0118507184 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
@@ -62,6 +62,9 @@ std::unique_ptr<Pass> createForallToForLoopPass();
 /// Creates a pass that converts SCF forall loops to SCF parallel loops.
 std::unique_ptr<Pass> createForallToParallelLoopPass();
 
+/// Creates a pass that converts SCF parallel loops to nested SCF for loops.
+std::unique_ptr<Pass> createParallelForToNestedForsPass();
+
 // Creates a pass which lowers for loops into while loops.
 std::unique_ptr<Pass> createForToWhileLoopPass();
 

diff  --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index ca2510bb53af9..8b891aa374b58 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -130,6 +130,17 @@ def SCFForallToParallelLoop : Pass<"scf-forall-to-parallel"> {
   let constructor = "mlir::createForallToParallelLoopPass()";
 }
 
+def SCFParallelForToNestedFors : Pass<"scf-parallel-for-to-nested-fors"> {
+  let summary = "Convert SCF parallel for loops to nested SCF for loops";
+  let constructor = "mlir::createParallelForToNestedForsPass()";
+  let description = [{
+    This pass rewrites each `scf.parallel` operation into a nest of `scf.for`
+    operations, one loop per parallel dimension, so that the loop is expressed
+    as a series of sequential iterations. `scf.parallel` operations that
+    produce results (reductions) are not supported and are left unchanged.
+  }];
+}
+
 def SCFForToWhileLoop : Pass<"scf-for-to-while"> {
   let summary = "Convert SCF for loops to SCF while loops";
   let constructor = "mlir::createForToWhileLoopPass()";

diff  --git a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
index 63163b77f7f16..00e8572307151 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
@@ -13,6 +13,7 @@
 #ifndef MLIR_DIALECT_SCF_TRANSFORMS_TRANSFORMS_H_
 #define MLIR_DIALECT_SCF_TRANSFORMS_TRANSFORMS_H_
 
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/ArrayRef.h"
@@ -42,6 +43,11 @@ LogicalResult forallToForLoop(RewriterBase &rewriter, ForallOp forallOp,
 LogicalResult forallToParallelLoop(RewriterBase &rewriter, ForallOp forallOp,
                                    ParallelOp *result = nullptr);
 
+/// Try converting scf.parallel into a nest of scf.for loops.
+/// The conversion is only supported for parallel operations with no results.
+FailureOr<scf::LoopNest> parallelForToNestedFors(RewriterBase &rewriter,
+                                                 ParallelOp parallelOp);
+
 /// Fuses all adjacent scf.parallel operations with identical bounds and step
 /// into one scf.parallel operations. Uses a naive aliasing and dependency
 /// analysis.

diff  --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index aea842dc59a39..71fe9870ac170 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -146,6 +146,45 @@ transform::ForallToParallelOp::apply(transform::TransformRewriter &rewriter,
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// ParallelForToNestedForOps
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::ParallelForToNestedForOps::apply(
+    transform::TransformRewriter &rewriter,
+    transform::TransformResults &results, transform::TransformState &state) {
+  auto payload = state.getPayloadOps(getTarget());
+  if (!llvm::hasSingleElement(payload))
+    return emitSilenceableError() << "expected a single payload op";
+  auto target = dyn_cast<scf::ParallelOp>(*payload.begin());
+  if (!target) {
+    DiagnosedSilenceableFailure diag =
+        emitSilenceableError() << "expected the payload to be scf.parallel";
+    diag.attachNote((*payload.begin())->getLoc()) << "payload op";
+    return diag;
+  }
+
+  if (getNumResults() != 1) {
+    DiagnosedSilenceableFailure diag = emitSilenceableError()
+                                       << "op expects one result, given "
+                                       << getNumResults();
+    diag.attachNote(target.getLoc()) << "payload op";
+    return diag;
+  }
+
+  FailureOr<scf::LoopNest> loopNest =
+      scf::parallelForToNestedFors(rewriter, target);
+  if (failed(loopNest)) {
+    DiagnosedSilenceableFailure diag =
+        emitSilenceableError() << "failed to convert parallel into nested fors";
+    diag.attachNote(target.getLoc()) << "payload op";
+    return diag;
+  }
+  // The single result handle is mapped to the first loop of the new nest.
+  results.set(cast<OpResult>(getTransformed()[0]), {loopNest->loops.front()});
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // LoopOutlineOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index 6d3bafbbc90e4..a07d9d4953d19 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -9,6 +9,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
   LoopPipelining.cpp
   LoopRangeFolding.cpp
   LoopSpecialization.cpp
+  ParallelForToNestedFors.cpp
   ParallelLoopCollapsing.cpp
   ParallelLoopFusion.cpp
   ParallelLoopTiling.cpp

diff  --git a/mlir/lib/Dialect/SCF/Transforms/ParallelForToNestedFors.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelForToNestedFors.cpp
new file mode 100644
index 0000000000000..8f7d5e308f433
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelForToNestedFors.cpp
@@ -0,0 +1,86 @@
+//===- ParallelForToNestedFors.cpp - scf.parallel to nested scf.for ops --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Transforms SCF.ParallelOp to nested scf.for ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Passes.h"
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "llvm/Support/Debug.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_SCFPARALLELFORTONESTEDFORS
+#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
+} // namespace mlir
+
+#define DEBUG_TYPE "parallel-for-to-nested-fors"
+using namespace mlir;
+
+FailureOr<scf::LoopNest>
+mlir::scf::parallelForToNestedFors(RewriterBase &rewriter,
+                                   scf::ParallelOp parallelOp) {
+  // Only result-free loops are handled; bail out before touching the IR.
+  if (!parallelOp.getResults().empty())
+    return rewriter.notifyMatchFailure(
+        parallelOp, "Currently scf.parallel to scf.for conversion doesn't "
+                    "support scf.parallel with results.");
+
+  rewriter.setInsertionPoint(parallelOp);
+
+  Location loc = parallelOp.getLoc();
+  SmallVector<Value> lowerBounds = parallelOp.getLowerBound();
+  SmallVector<Value> upperBounds = parallelOp.getUpperBound();
+  SmallVector<Value> steps = parallelOp.getStep();
+
+  assert(lowerBounds.size() == upperBounds.size() &&
+         lowerBounds.size() == steps.size() &&
+         "Mismatched parallel loop bounds");
+
+  // Build the scf.for nest right before the scf.parallel being replaced.
+  scf::LoopNest loopNest =
+      scf::buildLoopNest(rewriter, loc, lowerBounds, upperBounds, steps);
+  // Move the old body into the last loop, replacing its IVs with the new ones.
+  SmallVector<Value> newInductionVars = llvm::map_to_vector(
+      loopNest.loops, [](scf::ForOp forOp) { return forOp.getInductionVar(); });
+  Block *innermostBody = loopNest.loops.back().getBody();
+  Block *parallelBody = parallelOp.getBody();
+  rewriter.eraseOp(parallelBody->getTerminator());
+  rewriter.inlineBlockBefore(parallelBody, innermostBody->getTerminator(),
+                             newInductionVars);
+  rewriter.eraseOp(parallelOp);
+  return loopNest;
+}
+
+namespace {
+/// Pass that converts every scf.parallel under the root op to scf.for nests.
+struct ParallelForToNestedFors final
+    : public impl::SCFParallelForToNestedForsBase<ParallelForToNestedFors> {
+  void runOnOperation() override {
+    Operation *parentOp = getOperation();
+    IRRewriter rewriter(parentOp->getContext());
+
+    // A failed conversion does not fail the pass: the loop is kept as is and
+    // the failure is only reported in debug output.
+    parentOp->walk([&](scf::ParallelOp parallelOp) {
+      if (failed(scf::parallelForToNestedFors(rewriter, parallelOp))) {
+        LLVM_DEBUG(
+            llvm::dbgs()
+            << "Failed to convert scf.parallel to nested scf.for ops for:\n"
+            << parallelOp << "\n");
+      }
+    });
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::createParallelForToNestedForsPass() {
+  return std::make_unique<ParallelForToNestedFors>();
+}

diff  --git a/mlir/test/Dialect/SCF/parallel-to-nested-fors.mlir b/mlir/test/Dialect/SCF/parallel-to-nested-fors.mlir
new file mode 100644
index 0000000000000..47a8da8c244b2
--- /dev/null
+++ b/mlir/test/Dialect/SCF/parallel-to-nested-fors.mlir
@@ -0,0 +1,80 @@
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(scf-parallel-for-to-nested-fors))' -split-input-file -verify-diagnostics | FileCheck %s
+
+func.func private @callee(%i: index, %j: index)
+
+func.func @two_iters(%lb1: index, %lb2: index, %ub1: index, %ub2: index, %step1: index, %step2: index) {
+  scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+  // CHECK:           scf.for %[[VAL_0:.*]] = %[[ARG0:.*]] to %[[ARG2:.*]] step %[[ARG4:.*]] {
+  // CHECK:             scf.for %[[VAL_1:.*]] = %[[ARG1:.*]] to %[[ARG3:.*]] step %[[ARG5:.*]] {
+  // CHECK:               func.call @callee(%[[VAL_0]], %[[VAL_1]]) : (index, index) -> ()
+  // CHECK:             }
+  // CHECK:           }
+  return
+}
+
+// -----
+
+func.func private @callee(%i: index, %j: index)
+
+func.func @repeated(%lb1: index, %lb2: index, %ub1: index, %ub2: index, %step1: index, %step2: index) {
+  scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+
+  scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+  // CHECK:           scf.for %[[VAL_0:.*]] = %[[ARG0:.*]] to %[[ARG2:.*]] step %[[ARG4:.*]] {
+  // CHECK:             scf.for %[[VAL_1:.*]] = %[[ARG1:.*]] to %[[ARG3:.*]] step %[[ARG5:.*]] {
+  // CHECK:               func.call @callee(%[[VAL_0]], %[[VAL_1]]) : (index, index) -> ()
+  // CHECK:             }
+  // CHECK:           }
+  // CHECK:           scf.for %[[VAL_2:.*]] = %[[ARG0]] to %[[ARG2]] step %[[ARG4]] {
+  // CHECK:             scf.for %[[VAL_3:.*]] = %[[ARG1]] to %[[ARG3]] step %[[ARG5]] {
+  // CHECK:               func.call @callee(%[[VAL_2]], %[[VAL_3]]) : (index, index) -> ()
+  // CHECK:             }
+  // CHECK:           }
+
+  return
+}
+
+// -----
+
+func.func private @callee(%i: index, %j: index, %k: index, %l: index)
+
+func.func @nested(%lb1: index, %lb2: index, %lb3: index, %lb4: index, %ub1: index, %ub2: index, %ub3: index, %ub4: index, %step1: index, %step2: index, %step3: index, %step4: index) {
+  scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
+    scf.parallel (%k, %l) = (%lb3, %lb4) to (%ub3, %ub4) step (%step3, %step4) {
+      func.call @callee(%i, %j, %k, %l) : (index, index, index, index) -> ()
+    }
+  }
+  // CHECK:           scf.for %[[VAL_0:.*]] = %[[ARG0:.*]] to %[[ARG4:.*]] step %[[ARG8:.*]] {
+  // CHECK:             scf.for %[[VAL_1:.*]] = %[[ARG1:.*]] to %[[ARG5:.*]] step %[[ARG9:.*]] {
+  // CHECK:               scf.for %[[VAL_2:.*]] = %[[ARG2:.*]] to %[[ARG6:.*]] step %[[ARG10:.*]] {
+  // CHECK:                 scf.for %[[VAL_3:.*]] = %[[ARG3:.*]] to %[[ARG7:.*]] step %[[ARG11:.*]] {
+  // CHECK:                   func.call @callee(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[VAL_3]]) : (index, index, index, index) -> ()
+  // CHECK:                 }
+  // CHECK:               }
+  // CHECK:             }
+  // CHECK:           }
+  return
+}
+
+// -----
+func.func private @callee(%i: index, %j: index) -> i32
+
+func.func @two_iters_with_reduce(%lb1: index, %lb2: index, %ub1: index, %ub2: index, %step1: index, %step2: index) -> i32 {
+  %c0 = arith.constant 0 : i32
+  // CHECK: scf.parallel
+  %0 = scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) init (%c0) -> i32 {
+    %curr = func.call @callee(%i, %j) : (index, index) -> i32
+    scf.reduce(%curr : i32) {
+      ^bb0(%arg3: i32, %arg4: i32):
+        %3 = arith.addi %arg3, %arg4 : i32
+        scf.reduce.return %3 : i32
+    }
+  }
+  return %0 : i32
+}

diff  --git a/mlir/test/Dialect/SCF/transform-op-parallel-to-nested-fors.mlir b/mlir/test/Dialect/SCF/transform-op-parallel-to-nested-fors.mlir
new file mode 100644
index 0000000000000..496123b288038
--- /dev/null
+++ b/mlir/test/Dialect/SCF/transform-op-parallel-to-nested-fors.mlir
@@ -0,0 +1,62 @@
+// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics | FileCheck %s
+
+func.func private @callee(%i: index, %j: index)
+
+func.func @two_iters(%lb1: index, %lb2: index, %ub1: index, %ub2: index, %step1: index, %step2: index) {
+  scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+  // CHECK:           scf.for %[[VAL_0:.*]] = %[[ARG0:.*]] to %[[ARG2:.*]] step %[[ARG4:.*]] {
+  // CHECK:             scf.for %[[VAL_1:.*]] = %[[ARG1:.*]] to %[[ARG3:.*]] step %[[ARG5:.*]] {
+  // CHECK:               func.call @callee(%[[VAL_0]], %[[VAL_1]]) : (index, index) -> ()
+  // CHECK:             }
+  // CHECK:           }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.loop.parallel_for_to_nested_fors %0 : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+func.func private @callee(%i: index, %j: index)
+
+func.func @repeated(%lb1: index, %lb2: index, %ub1: index, %ub2: index, %step1: index, %step2: index) {
+  scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+
+  scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    // expected-error @below {{expected a single payload op}}
+    transform.loop.parallel_for_to_nested_fors %0 : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+// expected-note @below {{payload op}}
+func.func private @callee(%i: index, %j: index)
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    // expected-error @below {{expected the payload to be scf.parallel}}
+    transform.loop.parallel_for_to_nested_fors %0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}


        


More information about the Mlir-commits mailing list