[Mlir-commits] [mlir] [mlir][scf] Implement conversion from scf.forall to scf.parallel (PR #94109)
Spenser Bauman
llvmlistbot at llvm.org
Mon Jun 3 05:37:26 PDT 2024
https://github.com/sabauma updated https://github.com/llvm/llvm-project/pull/94109
>From 68baef86e69e97c2152bf82bb1d7c891f67f0a5c Mon Sep 17 00:00:00 2001
From: Spenser Bauman <sabauma at fastmail>
Date: Fri, 31 May 2024 17:56:38 -0400
Subject: [PATCH 1/2] [mlir][scf] Implement conversion from scf.forall to
scf.parallel
There is currently no path to lower scf.forall to scf.parallel with the
goal of targeting the OpenMP dialect.
In the SCF->ControlFlow conversion, scf.forall is briefly converted to
scf.parallel, but the scf.parallel is lowered directly to a sequential
loop. This makes experimenting with scf.forall for CPU execution
difficult.
This change factors the rewrite out of the SCF->ControlFlow pass into
a utility function, scf::forallToParallelLoop, which is now used by the
SCF->ControlFlow lowering, by a new -scf-forall-to-parallel pass, and by
a new transform.loop.forall_to_parallel transform op.
---
.../SCF/TransformOps/SCFTransformOps.td | 26 ++++++
.../mlir/Dialect/SCF/Transforms/Passes.h | 3 +
.../mlir/Dialect/SCF/Transforms/Passes.td | 5 ++
.../mlir/Dialect/SCF/Transforms/Transforms.h | 5 ++
.../SCFToControlFlow/SCFToControlFlow.cpp | 29 +------
.../SCF/TransformOps/SCFTransformOps.cpp | 44 ++++++++++
.../lib/Dialect/SCF/Transforms/CMakeLists.txt | 1 +
.../SCF/Transforms/ForallToParallel.cpp | 82 +++++++++++++++++++
mlir/test/Dialect/SCF/forall-to-parallel.mlir | 62 ++++++++++++++
.../SCF/transform-op-forall-to-parallel.mlir | 60 ++++++++++++++
10 files changed, 290 insertions(+), 27 deletions(-)
create mode 100644 mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp
create mode 100644 mlir/test/Dialect/SCF/forall-to-parallel.mlir
create mode 100644 mlir/test/Dialect/SCF/transform-op-forall-to-parallel.mlir
diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index 5eefe2664d0a1..3d7fe7b0f093f 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -68,6 +68,32 @@ def ForallToForOp : Op<Transform_Dialect, "loop.forall_to_for",
let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
}
+def ForallToParallelOp : Op<Transform_Dialect, "loop.forall_to_parallel",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let summary = "Converts scf.forall into an scf.parallel operation";
+  let description = [{
+    Converts the `scf.forall` operation pointed to by the given handle into an
+    `scf.parallel` operation.
+
+    The operand handle must be associated with exactly one payload operation.
+
+    Loops with shared outputs (i.e., loops that have not been bufferized) are
+    not supported.
+
+    #### Return Modes
+
+    Consumes the operand handle. Produces a silenceable failure if the operand
+    is not associated with a single `scf.forall` payload operation, or if that
+    operation has shared outputs.
+    Returns a handle to the new `scf.parallel` operation.
+    Produces a silenceable failure if another number of resulting handles is
+    requested.
+  }];
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs Variadic<TransformHandleTypeInterface>:$transformed);
+
+  let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
+}
+
def LoopOutlineOp : Op<Transform_Dialect, "loop.outline",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
DeclareOpInterfaceMethods<TransformOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
index 31c3d0eb629d2..fb8411418ff9a 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
@@ -62,6 +62,9 @@ std::unique_ptr<Pass> createForLoopRangeFoldingPass();
/// Creates a pass that converts SCF forall loops to SCF for loops.
std::unique_ptr<Pass> createForallToForLoopPass();
+/// Creates a pass that converts SCF forall loops to SCF parallel loops.
+/// The pass fails if it encounters an scf.forall that still has shared
+/// outputs (i.e. one that has not been bufferized).
+std::unique_ptr<Pass> createForallToParallelLoopPass();
+
// Creates a pass which lowers for loops into while loops.
std::unique_ptr<Pass> createForToWhileLoopPass();
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index a7aeb42d60c0e..9b29affb97c43 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -125,6 +125,11 @@ def SCFForallToForLoop : Pass<"scf-forall-to-for"> {
let constructor = "mlir::createForallToForLoopPass()";
}
+def SCFForallToParallelLoop : Pass<"scf-forall-to-parallel"> {
+  let summary = "Convert SCF forall loops to SCF parallel loops";
+  let description = [{
+    Rewrites the `scf.forall` ops nested in the target operation into
+    equivalent `scf.parallel` ops. Only fully bufferized loops (those without
+    shared outputs) can be converted; the pass fails on any other
+    `scf.forall`.
+  }];
+  let constructor = "mlir::createForallToParallelLoopPass()";
+}
+
def SCFForToWhileLoop : Pass<"scf-for-to-while"> {
let summary = "Convert SCF for loops to SCF while loops";
let constructor = "mlir::createForToWhileLoopPass()";
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
index b063e6e775e63..186331738d64b 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
@@ -39,6 +39,11 @@ class WhileOp;
LogicalResult forallToForLoop(RewriterBase &rewriter, ForallOp forallOp,
SmallVectorImpl<Operation *> *results = nullptr);
+/// Try converting scf.forall into an scf.parallel loop.
+/// The conversion is only supported for forall operations with no shared
+/// outputs (i.e. fully bufferized loops); otherwise a match failure is
+/// reported through `rewriter` and the IR is left untouched. On success the
+/// forall op is replaced and, if `result` is non-null, `*result` is set to the
+/// newly created scf.parallel op.
+LogicalResult forallToParallelLoop(RewriterBase &rewriter, ForallOp forallOp,
+ ParallelOp *result = nullptr);
+
/// Fuses all adjacent scf.parallel operations with identical bounds and step
/// into one scf.parallel operations. Uses a naive aliasing and dependency
/// analysis.
diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
index 9eb8a289d7d65..16f1db44acc35 100644
--- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
+++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
@@ -688,33 +689,7 @@ IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
LogicalResult ForallLowering::matchAndRewrite(ForallOp forallOp,
PatternRewriter &rewriter) const {
-  Location loc = forallOp.getLoc();
-  if (!forallOp.getOutputs().empty())
-    return rewriter.notifyMatchFailure(
-        forallOp,
-        "only fully bufferized scf.forall ops can be lowered to scf.parallel");
-
-  // Convert mixed bounds and steps to SSA values.
-  SmallVector<Value> lbs = getValueOrCreateConstantIndexOp(
-      rewriter, loc, forallOp.getMixedLowerBound());
-  SmallVector<Value> ubs = getValueOrCreateConstantIndexOp(
-      rewriter, loc, forallOp.getMixedUpperBound());
-  SmallVector<Value> steps =
-      getValueOrCreateConstantIndexOp(rewriter, loc, forallOp.getMixedStep());
-
-  // Create empty scf.parallel op.
-  auto parallelOp = rewriter.create<ParallelOp>(loc, lbs, ubs, steps);
-  rewriter.eraseBlock(&parallelOp.getRegion().front());
-  rewriter.inlineRegionBefore(forallOp.getRegion(), parallelOp.getRegion(),
-                              parallelOp.getRegion().begin());
-  // Replace the terminator.
-  rewriter.setInsertionPointToEnd(&parallelOp.getRegion().front());
-  rewriter.replaceOpWithNewOp<scf::ReduceOp>(
-      parallelOp.getRegion().front().getTerminator());
-
-  // Erase the scf.forall op.
-  rewriter.replaceOp(forallOp, parallelOp);
-  return success();
+  return scf::forallToParallelLoop(rewriter, forallOp);
}
void mlir::populateSCFToControlFlowConversionPatterns(
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 69f83d8bd70da..30699ecdde0a2 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -98,6 +98,50 @@ transform::ForallToForOp::apply(transform::TransformRewriter &rewriter,
return DiagnosedSilenceableFailure::success();
}
+//===----------------------------------------------------------------------===//
+// ForallToParallelOp
+//===----------------------------------------------------------------------===//
+
+/// Converts the single scf.forall payload op associated with `target` into an
+/// scf.parallel op and maps the new op to the sole `transformed` result.
+/// Every unmet precondition is reported as a silenceable failure; the payload
+/// IR is only modified after all of the checks below have passed.
+DiagnosedSilenceableFailure
+transform::ForallToParallelOp::apply(transform::TransformRewriter &rewriter,
+                                     transform::TransformResults &results,
+                                     transform::TransformState &state) {
+  auto payload = state.getPayloadOps(getTarget());
+  if (!llvm::hasSingleElement(payload))
+    return emitSilenceableError() << "expected a single payload op";
+
+  auto target = dyn_cast<scf::ForallOp>(*payload.begin());
+  if (!target) {
+    DiagnosedSilenceableFailure diag =
+        emitSilenceableError() << "expected the payload to be scf.forall";
+    diag.attachNote((*payload.begin())->getLoc()) << "payload op";
+    return diag;
+  }
+
+  // scf::forallToParallelLoop rejects this case as well; checking here yields
+  // a transform-specific diagnostic that hints at the likely cause.
+  if (!target.getOutputs().empty()) {
+    return emitSilenceableError()
+           << "unsupported shared outputs (didn't bufferize?)";
+  }
+
+  if (getNumResults() != 1) {
+    DiagnosedSilenceableFailure diag = emitSilenceableError()
+                                       << "op expects one result, given "
+                                       << getNumResults();
+    diag.attachNote(target.getLoc()) << "payload op";
+    return diag;
+  }
+
+  scf::ParallelOp opResult;
+  if (failed(scf::forallToParallelLoop(rewriter, target, &opResult)))
+    return emitSilenceableError() << "failed to convert forall into parallel";
+
+  results.set(cast<OpResult>(getTransformed()[0]), {opResult});
+  return DiagnosedSilenceableFailure::success();
+}
+
//===----------------------------------------------------------------------===//
// LoopOutlineOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index e7671c9cc28f8..d363ffe941fce 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
BufferizableOpInterfaceImpl.cpp
Bufferize.cpp
ForallToFor.cpp
+ ForallToParallel.cpp
ForToWhile.cpp
LoopCanonicalization.cpp
LoopPipelining.cpp
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp b/mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp
new file mode 100644
index 0000000000000..37ded4e2e3371
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp
@@ -0,0 +1,82 @@
+//===- ForallToParallel.cpp - scf.forall to scf.parallel loop conversion --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Transforms scf.forall ops into scf.parallel ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Passes.h"
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_SCFFORALLTOPARALLELLOOP
+#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+/// Rewrites `forallOp` into an equivalent `scf.parallel` op.
+///
+/// Bounds and steps are materialized as SSA index values, the forall body
+/// region is moved (not cloned) into the new op, and the body's terminator is
+/// replaced by an `scf.reduce`. Fails without modifying the IR when `forallOp`
+/// has shared outputs. On success `forallOp` is replaced and, if `result` is
+/// non-null, `*result` is set to the new op. The rewriter's insertion point is
+/// restored before returning.
+LogicalResult mlir::scf::forallToParallelLoop(RewriterBase &rewriter,
+                                              scf::ForallOp forallOp,
+                                              scf::ParallelOp *result) {
+  // Only fully bufferized foralls (no shared outputs) have a direct
+  // scf.parallel equivalent; bail out before touching the IR.
+  if (!forallOp.getOutputs().empty())
+    return rewriter.notifyMatchFailure(
+        forallOp,
+        "only fully bufferized scf.forall ops can be lowered to scf.parallel");
+
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(forallOp);
+  Location loc = forallOp.getLoc();
+
+  // Convert mixed bounds and steps to SSA values.
+  SmallVector<Value> lbs = getValueOrCreateConstantIndexOp(
+      rewriter, loc, forallOp.getMixedLowerBound());
+  SmallVector<Value> ubs = getValueOrCreateConstantIndexOp(
+      rewriter, loc, forallOp.getMixedUpperBound());
+  SmallVector<Value> steps =
+      getValueOrCreateConstantIndexOp(rewriter, loc, forallOp.getMixedStep());
+
+  // Create empty scf.parallel op and move the forall body into it.
+  auto parallelOp = rewriter.create<scf::ParallelOp>(loc, lbs, ubs, steps);
+  rewriter.eraseBlock(&parallelOp.getRegion().front());
+  rewriter.inlineRegionBefore(forallOp.getRegion(), parallelOp.getRegion(),
+                              parallelOp.getRegion().begin());
+  // Replace the terminator.
+  rewriter.setInsertionPointToEnd(&parallelOp.getRegion().front());
+  rewriter.replaceOpWithNewOp<scf::ReduceOp>(
+      parallelOp.getRegion().front().getTerminator());
+
+  // Erase the scf.forall op.
+  rewriter.replaceOp(forallOp, parallelOp);
+
+  if (result)
+    *result = parallelOp;
+
+  return success();
+}
+
+namespace {
+/// Pass converting the scf.forall ops nested under the target operation into
+/// scf.parallel ops. Fails on the first scf.forall that cannot be converted.
+struct ForallToParallelLoop final
+    : public impl::SCFForallToParallelLoopBase<ForallToParallelLoop> {
+  void runOnOperation() override {
+    Operation *parentOp = getOperation();
+    IRRewriter rewriter(parentOp->getContext());
+
+    // The default post-order walk converts nested foralls before their
+    // parents and permits erasing the op currently being visited.
+    WalkResult status = parentOp->walk([&](scf::ForallOp forallOp) {
+      if (succeeded(scf::forallToParallelLoop(rewriter, forallOp)))
+        return WalkResult::advance();
+      // This IRRewriter has no listener, so the utility's notifyMatchFailure
+      // message is dropped; emit a diagnostic so the pass does not fail
+      // silently. The utility leaves the forall untouched when it fails.
+      forallOp.emitError("failed to convert scf.forall into scf.parallel: "
+                         "only fully bufferized loops are supported");
+      return WalkResult::interrupt();
+    });
+    if (status.wasInterrupted())
+      signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::createForallToParallelLoopPass() {
+  return std::make_unique<ForallToParallelLoop>();
+}
diff --git a/mlir/test/Dialect/SCF/forall-to-parallel.mlir b/mlir/test/Dialect/SCF/forall-to-parallel.mlir
new file mode 100644
index 0000000000000..424ba01fc3a66
--- /dev/null
+++ b/mlir/test/Dialect/SCF/forall-to-parallel.mlir
@@ -0,0 +1,62 @@
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(scf-forall-to-parallel))' -split-input-file | FileCheck %s
+
+// Each split below runs scf.forall ops through -scf-forall-to-parallel and
+// verifies they become scf.parallel ops whose bodies end in scf.reduce.
+
+func.func private @callee(%i: index, %j: index)
+
+// A single two-dimensional forall maps to a two-dimensional scf.parallel
+// with the same upper bounds.
+// CHECK-LABEL: @two_iters
+// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index
+func.func @two_iters(%ub1: index, %ub2: index) {
+ scf.forall (%i, %j) in (%ub1, %ub2) {
+ func.call @callee(%i, %j) : (index, index) -> ()
+ }
+
+ // CHECK: scf.parallel (%[[IV1:.+]], %[[IV2:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB1]], %[[UB2]])
+ // CHECK: func.call @callee(%[[IV1]], %[[IV2]]) : (index, index) -> ()
+ // CHECK: scf.reduce
+ return
+}
+
+// -----
+
+func.func private @callee(%i: index, %j: index)
+
+// Sibling foralls in the same function are each converted.
+// CHECK-LABEL: @repeated
+// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index
+func.func @repeated(%ub1: index, %ub2: index) {
+ scf.forall (%i, %j) in (%ub1, %ub2) {
+ func.call @callee(%i, %j) : (index, index) -> ()
+ }
+
+ // CHECK: scf.parallel (%[[IV1:.+]], %[[IV2:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB1]], %[[UB2]])
+ // CHECK: func.call @callee(%[[IV1]], %[[IV2]]) : (index, index) -> ()
+ // CHECK: scf.reduce
+ scf.forall (%i, %j) in (%ub1, %ub2) {
+ func.call @callee(%i, %j) : (index, index) -> ()
+ }
+
+ // CHECK: scf.parallel (%[[IV3:.+]], %[[IV4:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB1]], %[[UB2]])
+ // CHECK: func.call @callee(%[[IV3]], %[[IV4]])
+ // CHECK: scf.reduce
+ return
+}
+
+// -----
+
+func.func private @callee(%i: index, %j: index, %k: index, %l: index)
+
+// Nested foralls are both converted; the inner scf.parallel stays inside the
+// body of the outer one.
+// CHECK-LABEL: @nested
+// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index, %[[UB3:.+]]: index, %[[UB4:.+]]: index
+func.func @nested(%ub1: index, %ub2: index, %ub3: index, %ub4: index) {
+ // CHECK: scf.parallel (%[[IV1:.+]], %[[IV2:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB1]], %[[UB2]]) step (%{{.*}}, %{{.*}}) {
+ // CHECK: scf.parallel (%[[IV3:.+]], %[[IV4:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB3]], %[[UB4]]) step (%{{.*}}, %{{.*}}) {
+ // CHECK: func.call @callee(%[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]])
+ // CHECK: scf.reduce
+ // CHECK: }
+ // CHECK: scf.reduce
+ // CHECK: }
+ scf.forall (%i, %j) in (%ub1, %ub2) {
+ scf.forall (%k, %l) in (%ub3, %ub4) {
+ func.call @callee(%i, %j, %k, %l) : (index, index, index, index) -> ()
+ }
+ }
+ return
+}
diff --git a/mlir/test/Dialect/SCF/transform-op-forall-to-parallel.mlir b/mlir/test/Dialect/SCF/transform-op-forall-to-parallel.mlir
new file mode 100644
index 0000000000000..b64798e06a4d1
--- /dev/null
+++ b/mlir/test/Dialect/SCF/transform-op-forall-to-parallel.mlir
@@ -0,0 +1,60 @@
+// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics | FileCheck %s
+
+// A single matched scf.forall is rewritten into an scf.parallel.
+
+func.func private @callee(%i: index, %j: index)
+
+// CHECK-LABEL: @two_iters
+// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index
+func.func @two_iters(%ub1: index, %ub2: index) {
+  scf.forall (%i, %j) in (%ub1, %ub2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+  // CHECK: scf.parallel (%[[IV1:.+]], %[[IV2:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB1]], %[[UB2]])
+  // CHECK: func.call @callee(%[[IV1]], %[[IV2]]) : (index, index) -> ()
+  // CHECK: scf.reduce
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.loop.forall_to_parallel %0 : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+// The op requires its handle to map to exactly one payload op; here the
+// match produces two scf.forall ops.
+
+func.func private @callee(%i: index, %j: index)
+
+func.func @repeated(%ub1: index, %ub2: index) {
+  scf.forall (%i, %j) in (%ub1, %ub2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+  scf.forall (%i, %j) in (%ub1, %ub2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    // expected-error @below {{expected a single payload op}}
+    transform.loop.forall_to_parallel %0 : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+// The single payload op is not an scf.forall.
+
+// expected-note @below {{payload op}}
+func.func private @callee(%i: index, %j: index)
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    // expected-error @below {{expected the payload to be scf.forall}}
+    transform.loop.forall_to_parallel %0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
>From 059261288170ad8a69bf7eb4de40ba38324fcb06 Mon Sep 17 00:00:00 2001
From: Spenser Bauman <sabauma at fastmail>
Date: Mon, 3 Jun 2024 08:37:12 -0400
Subject: [PATCH 2/2] Address feedback from @adam-smnk
---
mlir/lib/Conversion/SCFToControlFlow/CMakeLists.txt | 1 +
mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp | 4 ++++
2 files changed, 5 insertions(+)
diff --git a/mlir/lib/Conversion/SCFToControlFlow/CMakeLists.txt b/mlir/lib/Conversion/SCFToControlFlow/CMakeLists.txt
index 6217976159fbb..63c5199af9290 100644
--- a/mlir/lib/Conversion/SCFToControlFlow/CMakeLists.txt
+++ b/mlir/lib/Conversion/SCFToControlFlow/CMakeLists.txt
@@ -14,5 +14,6 @@ add_mlir_conversion_library(MLIRSCFToControlFlow
MLIRArithDialect
MLIRControlFlowDialect
MLIRSCFDialect
+ MLIRSCFTransforms
MLIRTransforms
)
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp b/mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp
index 37ded4e2e3371..1fc0331300379 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp
@@ -52,6 +52,10 @@ LogicalResult mlir::scf::forallToParallelLoop(RewriterBase &rewriter,
rewriter.replaceOpWithNewOp<scf::ReduceOp>(
parallelOp.getRegion().front().getTerminator());
+ // If the mapping attribute is present, propagate to the new parallelOp.
+ if (forallOp.getMapping())
+ parallelOp->setAttr("mapping", *forallOp.getMapping());
+
// Erase the scf.forall op.
rewriter.replaceOp(forallOp, parallelOp);
More information about the Mlir-commits
mailing list