[llvm-branch-commits] [mlir] [MLIR][SCF] Update scf.parallel lowering to OpenMP (3/5) (PR #89212)
Sergio Afonso via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Apr 18 03:52:36 PDT 2024
https://github.com/skatrak created https://github.com/llvm/llvm-project/pull/89212
This patch updates the `scf.parallel` to `omp.parallel` + `omp.wsloop` lowering pass to introduce a nested `omp.loop_nest` as well, and to follow the new loop wrapper role of `omp.wsloop`.
This PR on its own will not pass premerge tests. All patches in the stack are needed before it can compile and pass tests.
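For illustration, a minimal sketch of the lowered structure after this change (simplified from the updated tests below; the SSA names `%i`, `%j`, `%lb*`, `%ub*`, `%s*` are illustrative):

```mlir
omp.parallel {
  omp.wsloop {
    omp.loop_nest (%i, %j) : index = (%lb0, %lb1) to (%ub0, %ub1) step (%s0, %s1) {
      memref.alloca_scope {
        // body of the original scf.parallel
        memref.alloca_scope.return
      }
      omp.yield
    }
    omp.terminator
  }
  omp.terminator
}
```

Previously, `omp.wsloop` itself carried the bounds (`omp.wsloop for (...) : index = (...) to (...) step (...)`); with the wrapper role it only wraps the new `omp.loop_nest`, which now holds the induction variables and loop bounds.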
>From fdee8cf17cd7d2dbdb6cf872574776f02e70be7c Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Thu, 18 Apr 2024 10:55:16 +0100
Subject: [PATCH] [MLIR][SCF] Update scf.parallel lowering to OpenMP (3/5)
This patch updates the `scf.parallel` to `omp.parallel` + `omp.wsloop`
lowering pass to introduce a nested `omp.loop_nest` as well, and to follow the
new loop wrapper role of `omp.wsloop`.
This PR on its own will not pass premerge tests. All patches in the stack are
needed before it can compile and pass tests.
---
.../Conversion/SCFToOpenMP/SCFToOpenMP.cpp | 52 ++++++++++++++-----
.../Conversion/SCFToOpenMP/reductions.mlir | 5 ++
.../Conversion/SCFToOpenMP/scf-to-openmp.mlir | 31 ++++++++---
3 files changed, 68 insertions(+), 20 deletions(-)
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 7f91367ad427a2..f0b8d6c5309357 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -461,18 +461,51 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
// Replace the loop.
{
OpBuilder::InsertionGuard allocaGuard(rewriter);
- auto loop = rewriter.create<omp::WsloopOp>(
+ // Create worksharing loop wrapper.
+ auto wsloopOp = rewriter.create<omp::WsloopOp>(parallelOp.getLoc());
+ if (!reductionVariables.empty()) {
+ wsloopOp.setReductionsAttr(
+ ArrayAttr::get(rewriter.getContext(), reductionDeclSymbols));
+ wsloopOp.getReductionVarsMutable().append(reductionVariables);
+ }
+ rewriter.create<omp::TerminatorOp>(loc); // omp.parallel terminator.
+
+ // The wrapper's entry block arguments will define the reduction
+ // variables.
+ llvm::SmallVector<mlir::Type> reductionTypes;
+ reductionTypes.reserve(reductionVariables.size());
+ llvm::transform(reductionVariables, std::back_inserter(reductionTypes),
+ [](mlir::Value v) { return v.getType(); });
+ rewriter.createBlock(
+ &wsloopOp.getRegion(), {}, reductionTypes,
+ llvm::SmallVector<mlir::Location>(reductionVariables.size(),
+ parallelOp.getLoc()));
+
+ rewriter.setInsertionPoint(
+ rewriter.create<omp::TerminatorOp>(parallelOp.getLoc()));
+
+ // Create loop nest and populate region with contents of scf.parallel.
+ auto loopOp = rewriter.create<omp::LoopNestOp>(
parallelOp.getLoc(), parallelOp.getLowerBound(),
parallelOp.getUpperBound(), parallelOp.getStep());
- rewriter.create<omp::TerminatorOp>(loc);
- rewriter.inlineRegionBefore(parallelOp.getRegion(), loop.getRegion(),
- loop.getRegion().begin());
+ rewriter.inlineRegionBefore(parallelOp.getRegion(), loopOp.getRegion(),
+ loopOp.getRegion().begin());
+
+ // Remove reduction-related block arguments from omp.loop_nest and
+ // redirect uses to the corresponding omp.wsloop block argument.
+ mlir::Block &loopOpEntryBlock = loopOp.getRegion().front();
+ unsigned numLoops = parallelOp.getNumLoops();
+ rewriter.replaceAllUsesWith(
+ loopOpEntryBlock.getArguments().drop_front(numLoops),
+ wsloopOp.getRegion().getArguments());
+ loopOpEntryBlock.eraseArguments(
+ numLoops, loopOpEntryBlock.getNumArguments() - numLoops);
- Block *ops = rewriter.splitBlock(&*loop.getRegion().begin(),
- loop.getRegion().begin()->begin());
+ Block *ops = rewriter.splitBlock(&*loopOp.getRegion().begin(),
+ loopOp.getRegion().begin()->begin());
- rewriter.setInsertionPointToStart(&*loop.getRegion().begin());
+ rewriter.setInsertionPointToStart(&*loopOp.getRegion().begin());
auto scope = rewriter.create<memref::AllocaScopeOp>(parallelOp.getLoc(),
TypeRange());
@@ -481,11 +514,6 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
rewriter.mergeBlocks(ops, scopeBlock);
rewriter.setInsertionPointToEnd(&*scope.getBodyRegion().begin());
rewriter.create<memref::AllocaScopeReturnOp>(loc, ValueRange());
- if (!reductionVariables.empty()) {
- loop.setReductionsAttr(
- ArrayAttr::get(rewriter.getContext(), reductionDeclSymbols));
- loop.getReductionVarsMutable().append(reductionVariables);
- }
}
}
diff --git a/mlir/test/Conversion/SCFToOpenMP/reductions.mlir b/mlir/test/Conversion/SCFToOpenMP/reductions.mlir
index 3b6c145d62f1a8..fc6d56559c2618 100644
--- a/mlir/test/Conversion/SCFToOpenMP/reductions.mlir
+++ b/mlir/test/Conversion/SCFToOpenMP/reductions.mlir
@@ -28,6 +28,7 @@ func.func @reduction1(%arg0 : index, %arg1 : index, %arg2 : index,
// CHECK: omp.parallel
// CHECK: omp.wsloop
// CHECK-SAME: reduction(@[[$REDF]] %[[BUF]] -> %[[PVT_BUF:[a-z0-9]+]]
+ // CHECK: omp.loop_nest
// CHECK: memref.alloca_scope
scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
step (%arg4, %step) init (%zero) -> (f32) {
@@ -43,6 +44,7 @@ func.func @reduction1(%arg0 : index, %arg1 : index, %arg2 : index,
}
// CHECK: omp.yield
}
+ // CHECK: omp.terminator
// CHECK: omp.terminator
// CHECK: llvm.load %[[BUF]]
return
@@ -107,6 +109,7 @@ func.func @reduction_muli(%arg0 : index, %arg1 : index, %arg2 : index,
%one = arith.constant 1 : i32
// CHECK: %[[RED_VAR:.*]] = llvm.alloca %{{.*}} x i32 : (i64) -> !llvm.ptr
// CHECK: omp.wsloop reduction(@[[$REDI]] %[[RED_VAR]] -> %[[RED_PVT_VAR:.*]] : !llvm.ptr)
+ // CHECK: omp.loop_nest
scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
step (%arg4, %step) init (%one) -> (i32) {
// CHECK: %[[C2:.*]] = arith.constant 2 : i32
@@ -208,6 +211,7 @@ func.func @reduction4(%arg0 : index, %arg1 : index, %arg2 : index,
// CHECK: omp.wsloop
// CHECK-SAME: reduction(@[[$REDF1]] %[[BUF1]] -> %[[PVT_BUF1:[a-z0-9]+]]
// CHECK-SAME: @[[$REDF2]] %[[BUF2]] -> %[[PVT_BUF2:[a-z0-9]+]]
+ // CHECK: omp.loop_nest
// CHECK: memref.alloca_scope
%res:2 = scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
step (%arg4, %step) init (%zero, %ione) -> (f32, i64) {
@@ -236,6 +240,7 @@ func.func @reduction4(%arg0 : index, %arg1 : index, %arg2 : index,
}
// CHECK: omp.yield
}
+ // CHECK: omp.terminator
// CHECK: omp.terminator
// CHECK: %[[RES1:.*]] = llvm.load %[[BUF1]] : !llvm.ptr -> f32
// CHECK: %[[RES2:.*]] = llvm.load %[[BUF2]] : !llvm.ptr -> i64
diff --git a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
index acd2690c56e2e6..b2f19d294cb5fe 100644
--- a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
+++ b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
@@ -2,10 +2,11 @@
// CHECK-LABEL: @parallel
func.func @parallel(%arg0: index, %arg1: index, %arg2: index,
- %arg3: index, %arg4: index, %arg5: index) {
+ %arg3: index, %arg4: index, %arg5: index) {
// CHECK: %[[FOUR:.+]] = llvm.mlir.constant(4 : i32) : i32
// CHECK: omp.parallel num_threads(%[[FOUR]] : i32) {
- // CHECK: omp.wsloop for (%[[LVAR1:.*]], %[[LVAR2:.*]]) : index = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) {
+ // CHECK: omp.wsloop {
+ // CHECK: omp.loop_nest (%[[LVAR1:.*]], %[[LVAR2:.*]]) : index = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) {
// CHECK: memref.alloca_scope
scf.parallel (%i, %j) = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) {
// CHECK: "test.payload"(%[[LVAR1]], %[[LVAR2]]) : (index, index) -> ()
@@ -13,6 +14,8 @@ func.func @parallel(%arg0: index, %arg1: index, %arg2: index,
// CHECK: omp.yield
// CHECK: }
}
+ // CHECK: omp.terminator
+ // CHECK: }
// CHECK: omp.terminator
// CHECK: }
return
@@ -23,20 +26,26 @@ func.func @nested_loops(%arg0: index, %arg1: index, %arg2: index,
%arg3: index, %arg4: index, %arg5: index) {
// CHECK: %[[FOUR:.+]] = llvm.mlir.constant(4 : i32) : i32
// CHECK: omp.parallel num_threads(%[[FOUR]] : i32) {
- // CHECK: omp.wsloop for (%[[LVAR_OUT1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) {
- // CHECK: memref.alloca_scope
+ // CHECK: omp.wsloop {
+ // CHECK: omp.loop_nest (%[[LVAR_OUT1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) {
+ // CHECK: memref.alloca_scope
scf.parallel (%i) = (%arg0) to (%arg2) step (%arg4) {
// CHECK: omp.parallel
- // CHECK: omp.wsloop for (%[[LVAR_IN1:.*]]) : index = (%arg1) to (%arg3) step (%arg5) {
+ // CHECK: omp.wsloop {
+ // CHECK: omp.loop_nest (%[[LVAR_IN1:.*]]) : index = (%arg1) to (%arg3) step (%arg5) {
// CHECK: memref.alloca_scope
scf.parallel (%j) = (%arg1) to (%arg3) step (%arg5) {
// CHECK: "test.payload"(%[[LVAR_OUT1]], %[[LVAR_IN1]]) : (index, index) -> ()
"test.payload"(%i, %j) : (index, index) -> ()
// CHECK: }
}
- // CHECK: omp.yield
+ // CHECK: omp.yield
+ // CHECK: }
+ // CHECK: omp.terminator
// CHECK: }
}
+ // CHECK: omp.terminator
+ // CHECK: }
// CHECK: omp.terminator
// CHECK: }
return
@@ -47,7 +56,8 @@ func.func @adjacent_loops(%arg0: index, %arg1: index, %arg2: index,
%arg3: index, %arg4: index, %arg5: index) {
// CHECK: %[[FOUR:.+]] = llvm.mlir.constant(4 : i32) : i32
// CHECK: omp.parallel num_threads(%[[FOUR]] : i32) {
- // CHECK: omp.wsloop for (%[[LVAR_AL1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) {
+ // CHECK: omp.wsloop {
+ // CHECK: omp.loop_nest (%[[LVAR_AL1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) {
// CHECK: memref.alloca_scope
scf.parallel (%i) = (%arg0) to (%arg2) step (%arg4) {
// CHECK: "test.payload1"(%[[LVAR_AL1]]) : (index) -> ()
@@ -55,12 +65,15 @@ func.func @adjacent_loops(%arg0: index, %arg1: index, %arg2: index,
// CHECK: omp.yield
// CHECK: }
}
+ // CHECK: omp.terminator
+ // CHECK: }
// CHECK: omp.terminator
// CHECK: }
// CHECK: %[[FOUR:.+]] = llvm.mlir.constant(4 : i32) : i32
// CHECK: omp.parallel num_threads(%[[FOUR]] : i32) {
- // CHECK: omp.wsloop for (%[[LVAR_AL2:.*]]) : index = (%arg1) to (%arg3) step (%arg5) {
+ // CHECK: omp.wsloop {
+ // CHECK: omp.loop_nest (%[[LVAR_AL2:.*]]) : index = (%arg1) to (%arg3) step (%arg5) {
// CHECK: memref.alloca_scope
scf.parallel (%j) = (%arg1) to (%arg3) step (%arg5) {
// CHECK: "test.payload2"(%[[LVAR_AL2]]) : (index) -> ()
@@ -68,6 +81,8 @@ func.func @adjacent_loops(%arg0: index, %arg1: index, %arg2: index,
// CHECK: omp.yield
// CHECK: }
}
+ // CHECK: omp.terminator
+ // CHECK: }
// CHECK: omp.terminator
// CHECK: }
return