[Mlir-commits] [mlir] [mlir][scf] Implement conversion from scf.forall to scf.parallel (PR #94109)

Spenser Bauman llvmlistbot at llvm.org
Mon Jun 3 05:37:26 PDT 2024


https://github.com/sabauma updated https://github.com/llvm/llvm-project/pull/94109

>From 68baef86e69e97c2152bf82bb1d7c891f67f0a5c Mon Sep 17 00:00:00 2001
From: Spenser Bauman <sabauma at fastmail>
Date: Fri, 31 May 2024 17:56:38 -0400
Subject: [PATCH 1/2] [mlir][scf] Implement conversion from scf.forall to
 scf.parallel

There is currently no path to lower scf.forall to scf.parallel with the
goal of targeting the OpenMP dialect.

In the SCF->ControlFlow conversion, scf.forall is briefly converted to
scf.parallel, but the scf.parallel is lowered directly to a sequential
loop. This makes experimenting with scf.forall for CPU execution
difficult.

This change factors out the rewrite in the SCF->ControlFlow pass into
a utility function that is used by the SCF->ControlFlow lowering, by a
new -scf-forall-to-parallel pass, and by a new
transform.loop.forall_to_parallel transform op.
---
 .../SCF/TransformOps/SCFTransformOps.td       | 26 ++++++
 .../mlir/Dialect/SCF/Transforms/Passes.h      |  3 +
 .../mlir/Dialect/SCF/Transforms/Passes.td     |  5 ++
 .../mlir/Dialect/SCF/Transforms/Transforms.h  |  5 ++
 .../SCFToControlFlow/SCFToControlFlow.cpp     | 29 +------
 .../SCF/TransformOps/SCFTransformOps.cpp      | 44 ++++++++++
 .../lib/Dialect/SCF/Transforms/CMakeLists.txt |  1 +
 .../SCF/Transforms/ForallToParallel.cpp       | 82 +++++++++++++++++++
 mlir/test/Dialect/SCF/forall-to-parallel.mlir | 62 ++++++++++++++
 .../SCF/transform-op-forall-to-parallel.mlir  | 60 ++++++++++++++
 10 files changed, 290 insertions(+), 27 deletions(-)
 create mode 100644 mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp
 create mode 100644 mlir/test/Dialect/SCF/forall-to-parallel.mlir
 create mode 100644 mlir/test/Dialect/SCF/transform-op-forall-to-parallel.mlir

diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index 5eefe2664d0a1..3d7fe7b0f093f 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -68,6 +68,32 @@ def ForallToForOp : Op<Transform_Dialect, "loop.forall_to_for",
   let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
 }
 
+def ForallToParallelOp : Op<Transform_Dialect, "loop.forall_to_parallel",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let summary = "Converts scf.forall into an scf.parallel operation";
+  let description = [{
+    Converts the `scf.forall` operation pointed to by the given handle into an
+    `scf.parallel` operation.
+
+    The operand handle must be associated with exactly one payload operation.
+
+    Loops with shared outputs are not supported; the loop must be fully
+    bufferized.
+
+    #### Return Modes
+
+    Consumes the operand handle. Produces a silenceable failure if the operand
+    is not associated with a single `scf.forall` payload operation, or if that
+    operation has shared outputs.
+    Returns a handle to the new `scf.parallel` operation.
+    Produces a silenceable failure if another number of resulting handles is
+    requested.
+  }];
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs Variadic<TransformHandleTypeInterface>:$transformed);
+
+  let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
+}
+
 def LoopOutlineOp : Op<Transform_Dialect, "loop.outline",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
      DeclareOpInterfaceMethods<TransformOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
index 31c3d0eb629d2..fb8411418ff9a 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
@@ -62,6 +62,9 @@ std::unique_ptr<Pass> createForLoopRangeFoldingPass();
 /// Creates a pass that converts SCF forall loops to SCF for loops.
 std::unique_ptr<Pass> createForallToForLoopPass();
 
+/// Creates a pass that converts SCF forall loops to SCF parallel loops.
+std::unique_ptr<Pass> createForallToParallelLoopPass();
+
 // Creates a pass which lowers for loops into while loops.
 std::unique_ptr<Pass> createForToWhileLoopPass();
 
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index a7aeb42d60c0e..9b29affb97c43 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -125,6 +125,11 @@ def SCFForallToForLoop : Pass<"scf-forall-to-for"> {
   let constructor = "mlir::createForallToForLoopPass()";
 }
 
+def SCFForallToParallelLoop : Pass<"scf-forall-to-parallel"> {
+  let summary = "Convert SCF forall loops to SCF parallel loops";
+  let description = [{
+    Rewrites each `scf.forall` op nested under the anchor op into an
+    `scf.parallel` op with the same bounds, steps and body, terminated by
+    `scf.reduce`. Only `scf.forall` ops without shared outputs (i.e. fully
+    bufferized loops) are supported; the pass fails on any other.
+  }];
+  let constructor = "mlir::createForallToParallelLoopPass()";
+}
+
 def SCFForToWhileLoop : Pass<"scf-for-to-while"> {
   let summary = "Convert SCF for loops to SCF while loops";
   let constructor = "mlir::createForToWhileLoopPass()";
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
index b063e6e775e63..186331738d64b 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
@@ -39,6 +39,11 @@ class WhileOp;
 LogicalResult forallToForLoop(RewriterBase &rewriter, ForallOp forallOp,
                               SmallVectorImpl<Operation *> *results = nullptr);
 
+/// Try converting scf.forall into an scf.parallel loop.
+/// The conversion is only supported for forall operations with no results,
+/// i.e. without shared outputs (fully bufferized loops). On failure the IR is
+/// left unchanged. On success `forallOp` is replaced by the new scf.parallel
+/// op, which is also stored in `*result` when `result` is non-null.
+LogicalResult forallToParallelLoop(RewriterBase &rewriter, ForallOp forallOp,
+                                   ParallelOp *result = nullptr);
+
 /// Fuses all adjacent scf.parallel operations with identical bounds and step
 /// into one scf.parallel operations. Uses a naive aliasing and dependency
 /// analysis.
diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
index 9eb8a289d7d65..16f1db44acc35 100644
--- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
+++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/IRMapping.h"
@@ -688,33 +689,7 @@ IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
 
 LogicalResult ForallLowering::matchAndRewrite(ForallOp forallOp,
                                               PatternRewriter &rewriter) const {
-  Location loc = forallOp.getLoc();
-  if (!forallOp.getOutputs().empty())
-    return rewriter.notifyMatchFailure(
-        forallOp,
-        "only fully bufferized scf.forall ops can be lowered to scf.parallel");
-
-  // Convert mixed bounds and steps to SSA values.
-  SmallVector<Value> lbs = getValueOrCreateConstantIndexOp(
-      rewriter, loc, forallOp.getMixedLowerBound());
-  SmallVector<Value> ubs = getValueOrCreateConstantIndexOp(
-      rewriter, loc, forallOp.getMixedUpperBound());
-  SmallVector<Value> steps =
-      getValueOrCreateConstantIndexOp(rewriter, loc, forallOp.getMixedStep());
-
-  // Create empty scf.parallel op.
-  auto parallelOp = rewriter.create<ParallelOp>(loc, lbs, ubs, steps);
-  rewriter.eraseBlock(&parallelOp.getRegion().front());
-  rewriter.inlineRegionBefore(forallOp.getRegion(), parallelOp.getRegion(),
-                              parallelOp.getRegion().begin());
-  // Replace the terminator.
-  rewriter.setInsertionPointToEnd(&parallelOp.getRegion().front());
-  rewriter.replaceOpWithNewOp<scf::ReduceOp>(
-      parallelOp.getRegion().front().getTerminator());
-
-  // Erase the scf.forall op.
-  rewriter.replaceOp(forallOp, parallelOp);
-  return success();
+  // Delegate to the shared SCF utility. It fails, without modifying the IR,
+  // when the forall still has shared outputs (i.e. is not bufferized).
+  return scf::forallToParallelLoop(rewriter, forallOp);
 }
 
 void mlir::populateSCFToControlFlowConversionPatterns(
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 69f83d8bd70da..30699ecdde0a2 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -98,6 +98,50 @@ transform::ForallToForOp::apply(transform::TransformRewriter &rewriter,
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// ForallToParallelOp
+//===----------------------------------------------------------------------===//
+
+/// Converts the single scf.forall payload op associated with the target
+/// handle into an scf.parallel op and binds the `transformed` result to it.
+/// Every failure path attaches a "payload op" note when a payload op exists.
+DiagnosedSilenceableFailure
+transform::ForallToParallelOp::apply(transform::TransformRewriter &rewriter,
+                                     transform::TransformResults &results,
+                                     transform::TransformState &state) {
+  auto payload = state.getPayloadOps(getTarget());
+  if (!llvm::hasSingleElement(payload))
+    return emitSilenceableError() << "expected a single payload op";
+
+  auto target = dyn_cast<scf::ForallOp>(*payload.begin());
+  if (!target) {
+    DiagnosedSilenceableFailure diag =
+        emitSilenceableError() << "expected the payload to be scf.forall";
+    diag.attachNote((*payload.begin())->getLoc()) << "payload op";
+    return diag;
+  }
+
+  // Only fully bufferized loops can be converted. forallToParallelLoop checks
+  // this too; checking here yields a transform-specific diagnostic.
+  if (!target.getOutputs().empty()) {
+    DiagnosedSilenceableFailure diag =
+        emitSilenceableError()
+        << "unsupported shared outputs (didn't bufferize?)";
+    diag.attachNote(target.getLoc()) << "payload op";
+    return diag;
+  }
+
+  if (getNumResults() != 1) {
+    DiagnosedSilenceableFailure diag = emitSilenceableError()
+                                       << "op expects one result, given "
+                                       << getNumResults();
+    diag.attachNote(target.getLoc()) << "payload op";
+    return diag;
+  }
+
+  scf::ParallelOp opResult;
+  if (failed(scf::forallToParallelLoop(rewriter, target, &opResult))) {
+    // forallToParallelLoop fails before touching the IR, so `target` is still
+    // valid here and can be referenced by the note.
+    DiagnosedSilenceableFailure diag =
+        emitSilenceableError() << "failed to convert forall into parallel";
+    diag.attachNote(target.getLoc()) << "payload op";
+    return diag;
+  }
+
+  results.set(cast<OpResult>(getTransformed()[0]), {opResult});
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // LoopOutlineOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index e7671c9cc28f8..d363ffe941fce 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
   ForallToFor.cpp
+  ForallToParallel.cpp
   ForToWhile.cpp
   LoopCanonicalization.cpp
   LoopPipelining.cpp
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp b/mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp
new file mode 100644
index 0000000000000..37ded4e2e3371
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp
@@ -0,0 +1,82 @@
+//===- ForallToParallel.cpp - scf.forall to scf.parallel loop conversion --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Transforms scf.forall ops into scf.parallel ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Passes.h"
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_SCFFORALLTOPARALLELLOOP
+#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+/// Rewrites `forallOp` into an scf.parallel op with the same bounds, steps and
+/// body. Returns failure, before touching the IR, if `forallOp` has shared
+/// outputs. On success `forallOp` is replaced by the new op, which is also
+/// written to `*result` when `result` is non-null.
+LogicalResult mlir::scf::forallToParallelLoop(RewriterBase &rewriter,
+                                              scf::ForallOp forallOp,
+                                              scf::ParallelOp *result) {
+  // New ops are created right before the forall; the guard restores the
+  // caller's insertion point on return.
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(forallOp);
+
+  Location loc = forallOp.getLoc();
+  if (!forallOp.getOutputs().empty())
+    return rewriter.notifyMatchFailure(
+        forallOp,
+        "only fully bufferized scf.forall ops can be lowered to scf.parallel");
+
+  // Convert mixed bounds and steps to SSA values.
+  SmallVector<Value> lbs = getValueOrCreateConstantIndexOp(
+      rewriter, loc, forallOp.getMixedLowerBound());
+  SmallVector<Value> ubs = getValueOrCreateConstantIndexOp(
+      rewriter, loc, forallOp.getMixedUpperBound());
+  SmallVector<Value> steps =
+      getValueOrCreateConstantIndexOp(rewriter, loc, forallOp.getMixedStep());
+
+  // Create empty scf.parallel op.
+  auto parallelOp = rewriter.create<scf::ParallelOp>(loc, lbs, ubs, steps);
+  // Discard the body block the builder created and move the forall's region
+  // in its place, so the forall's body becomes the parallel body.
+  rewriter.eraseBlock(&parallelOp.getRegion().front());
+  rewriter.inlineRegionBefore(forallOp.getRegion(), parallelOp.getRegion(),
+                              parallelOp.getRegion().begin());
+  // Replace the forall's terminator with scf.reduce, which ends the body of
+  // the new scf.parallel.
+  rewriter.setInsertionPointToEnd(&parallelOp.getRegion().front());
+  rewriter.replaceOpWithNewOp<scf::ReduceOp>(
+      parallelOp.getRegion().front().getTerminator());
+
+  // Erase the scf.forall op.
+  rewriter.replaceOp(forallOp, parallelOp);
+
+  // Hand the new op back to callers that asked for it.
+  if (result)
+    *result = parallelOp;
+
+  return success();
+}
+
+namespace {
+/// Pass that converts every scf.forall nested under the anchor op into an
+/// scf.parallel op via scf::forallToParallelLoop.
+struct ForallToParallelLoop final
+    : public impl::SCFForallToParallelLoopBase<ForallToParallelLoop> {
+  void runOnOperation() override {
+    Operation *parentOp = getOperation();
+    IRRewriter rewriter(parentOp->getContext());
+
+    // The default post-order walk visits nested foralls before their parents
+    // and allows erasing the op being visited, which the rewrite does.
+    // The IRRewriter above has no listener, so notifyMatchFailure inside the
+    // utility reports nothing. Emit an explicit error on the offending loop
+    // and stop, instead of failing the pass silently.
+    WalkResult status = parentOp->walk([&](scf::ForallOp forallOp) {
+      if (succeeded(scf::forallToParallelLoop(rewriter, forallOp)))
+        return WalkResult::advance();
+      forallOp.emitError("failed to convert scf.forall to scf.parallel "
+                         "(only fully bufferized loops are supported)");
+      return WalkResult::interrupt();
+    });
+    if (status.wasInterrupted())
+      signalPassFailure();
+  }
+};
+} // namespace
+
+/// Creates an instance of the -scf-forall-to-parallel pass.
+std::unique_ptr<Pass> mlir::createForallToParallelLoopPass() {
+  std::unique_ptr<Pass> pass = std::make_unique<ForallToParallelLoop>();
+  return pass;
+}
diff --git a/mlir/test/Dialect/SCF/forall-to-parallel.mlir b/mlir/test/Dialect/SCF/forall-to-parallel.mlir
new file mode 100644
index 0000000000000..424ba01fc3a66
--- /dev/null
+++ b/mlir/test/Dialect/SCF/forall-to-parallel.mlir
@@ -0,0 +1,62 @@
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(scf-forall-to-parallel))' -split-input-file | FileCheck %s
+
+func.func private @callee(%i: index, %j: index)
+
+// A 2-d scf.forall becomes a 2-d scf.parallel with the same upper bounds; the
+// body is kept and terminated by scf.reduce.
+// CHECK-LABEL: @two_iters
+// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index
+func.func @two_iters(%ub1: index, %ub2: index) {
+  scf.forall (%i, %j) in (%ub1, %ub2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+
+  // CHECK: scf.parallel (%[[IV1:.+]], %[[IV2:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB1]], %[[UB2]])
+  // CHECK:   func.call @callee(%[[IV1]], %[[IV2]]) : (index, index) -> ()
+  // CHECK:   scf.reduce
+  return
+}
+
+// -----
+
+func.func private @callee(%i: index, %j: index)
+
+// Two sibling scf.forall ops in one function are each converted.
+// CHECK-LABEL: @repeated
+// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index
+func.func @repeated(%ub1: index, %ub2: index) {
+  scf.forall (%i, %j) in (%ub1, %ub2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+
+  // CHECK: scf.parallel (%[[IV1:.+]], %[[IV2:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB1]], %[[UB2]])
+  // CHECK:   func.call @callee(%[[IV1]], %[[IV2]]) : (index, index) -> ()
+  // CHECK:   scf.reduce
+  scf.forall (%i, %j) in (%ub1, %ub2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+
+  // CHECK: scf.parallel (%[[IV3:.+]], %[[IV4:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB1]], %[[UB2]])
+  // CHECK:   func.call @callee(%[[IV3]], %[[IV4]])
+  // CHECK:   scf.reduce
+  return
+}
+
+// -----
+
+func.func private @callee(%i: index, %j: index, %k: index, %l: index)
+
+// Nested scf.forall ops are both converted and the nesting is preserved; the
+// inner body still sees the outer induction variables.
+// CHECK-LABEL: @nested
+// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index, %[[UB3:.+]]: index, %[[UB4:.+]]: index
+func.func @nested(%ub1: index, %ub2: index, %ub3: index, %ub4: index) {
+  // CHECK: scf.parallel (%[[IV1:.+]], %[[IV2:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB1]], %[[UB2]]) step (%{{.*}}, %{{.*}}) {
+  // CHECK:   scf.parallel (%[[IV3:.+]], %[[IV4:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB3]], %[[UB4]]) step (%{{.*}}, %{{.*}}) {
+  // CHECK:     func.call @callee(%[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]])
+  // CHECK:     scf.reduce
+  // CHECK:   }
+  // CHECK:   scf.reduce
+  // CHECK: }
+  scf.forall (%i, %j) in (%ub1, %ub2) {
+    scf.forall (%k, %l) in (%ub3, %ub4) {
+      func.call @callee(%i, %j, %k, %l) : (index, index, index, index) -> ()
+    }
+  }
+  return
+}
diff --git a/mlir/test/Dialect/SCF/transform-op-forall-to-parallel.mlir b/mlir/test/Dialect/SCF/transform-op-forall-to-parallel.mlir
new file mode 100644
index 0000000000000..b64798e06a4d1
--- /dev/null
+++ b/mlir/test/Dialect/SCF/transform-op-forall-to-parallel.mlir
@@ -0,0 +1,60 @@
+// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics | FileCheck %s
+
+func.func private @callee(%i: index, %j: index)
+
+// A single matched scf.forall is rewritten into one scf.parallel.
+// CHECK-LABEL: @two_iters
+// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index
+func.func @two_iters(%ub1: index, %ub2: index) {
+  scf.forall (%i, %j) in (%ub1, %ub2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+  // CHECK: scf.parallel (%[[IV1:.+]], %[[IV2:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB1]], %[[UB2]])
+  // CHECK:   func.call @callee(%[[IV1]], %[[IV2]]) : (index, index) -> ()
+  // CHECK:   scf.reduce
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.loop.forall_to_parallel %0 : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+func.func private @callee(%i: index, %j: index)
+
+// Two scf.forall ops are matched, but the op requires exactly one payload op.
+func.func @repeated(%ub1: index, %ub2: index) {
+  scf.forall (%i, %j) in (%ub1, %ub2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+  scf.forall (%i, %j) in (%ub1, %ub2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    // expected-error @below {{expected a single payload op}}
+    transform.loop.forall_to_parallel %0 : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+// A payload op that is not an scf.forall is rejected, with a note on it.
+// expected-note @below {{payload op}}
+func.func private @callee(%i: index, %j: index)
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    // expected-error @below {{expected the payload to be scf.forall}}
+    transform.loop.forall_to_parallel %0 : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}

>From 059261288170ad8a69bf7eb4de40ba38324fcb06 Mon Sep 17 00:00:00 2001
From: Spenser Bauman <sabauma at fastmail>
Date: Mon, 3 Jun 2024 08:37:12 -0400
Subject: [PATCH 2/2] Address feedback from @adam-smnk

---
 mlir/lib/Conversion/SCFToControlFlow/CMakeLists.txt  | 1 +
 mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp | 4 ++++
 2 files changed, 5 insertions(+)

diff --git a/mlir/lib/Conversion/SCFToControlFlow/CMakeLists.txt b/mlir/lib/Conversion/SCFToControlFlow/CMakeLists.txt
index 6217976159fbb..63c5199af9290 100644
--- a/mlir/lib/Conversion/SCFToControlFlow/CMakeLists.txt
+++ b/mlir/lib/Conversion/SCFToControlFlow/CMakeLists.txt
@@ -14,5 +14,6 @@ add_mlir_conversion_library(MLIRSCFToControlFlow
   MLIRArithDialect
   MLIRControlFlowDialect
   MLIRSCFDialect
+  MLIRSCFTransforms
   MLIRTransforms
   )
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp b/mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp
index 37ded4e2e3371..1fc0331300379 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp
@@ -52,6 +52,10 @@ LogicalResult mlir::scf::forallToParallelLoop(RewriterBase &rewriter,
   rewriter.replaceOpWithNewOp<scf::ReduceOp>(
       parallelOp.getRegion().front().getTerminator());
 
+  // If the mapping attribute is present, propagate it to the new parallelOp.
+  if (auto mapping = forallOp.getMapping())
+    parallelOp->setAttr("mapping", *mapping);
+
   // Erase the scf.forall op.
   rewriter.replaceOp(forallOp, parallelOp);
 



More information about the Mlir-commits mailing list