[flang-commits] [flang] [mlir][OpenMP] - Transform target offloading directives with dependencies during PFT to MLIR conversion (PR #85130)
Pranav Bhandarkar via flang-commits
flang-commits at lists.llvm.org
Wed Mar 13 13:58:24 PDT 2024
https://github.com/bhandarkar-pranav created https://github.com/llvm/llvm-project/pull/85130
This PR changes PFT to MLIR lowering for the following directives when they have the `depend` clause on them.
```
!$omp target depend(..)
!$omp target enter data depend(..)
!$omp target update depend(..)
!$omp target exit data depend(..)
```
With this PR, lowering now creates an `omp.task` operation that encloses the operation otherwise generated for the construct (`omp.target`, `omp.target_enter_data`, `omp.target_exit_data`, or `omp.target_update_data`). The `depend` clause is moved from the target construct to this new enclosing `omp.task`, which is created as a mergeable task.
```
!$omp target map(..) depend(in:a)
b = a
```
is transformed to the following MLIR
```
omp.task mergeable depend(in:a) {
omp.target map(..) {
//MLIR for b = a;
}
omp.terminator
}
```
>[!NOTE]
>This PR is an alternative to https://github.com/llvm/llvm-project/pull/83966. Unlike that PR, it does not need a separate pass over the IR: all changes required to translate the `depend` clause on target constructs into LLVM IR are contained in this one change, whereas https://github.com/llvm/llvm-project/pull/83966 would also require adding its pass to the flang pass pipeline.
Further, all `depend`-clause-related operands and attributes of `omp.target` (and of the related target data operations) can now be removed from MLIR; a follow-up patch will do this once this PR is merged.
>From 1da3208d9ea66cb8b475ddefb0db759a802363a7 Mon Sep 17 00:00:00 2001
From: Pranav Bhandarkar <pranav.bhandarkar at amd.com>
Date: Wed, 13 Mar 2024 00:26:53 -0500
Subject: [PATCH] [mlir][OpenMP] - Transform target offloading directives with
dependencies during PFT to MLIR conversion
This patch changes PFT to MLIR lowering for the following directives
when they have the `depend` clause on them.
!$omp target depend(..)
!$omp target enter data depend(..)
!$omp target update depend(..)
!$omp target exit data depend(..)
Lowering now involves the creation of an `omp.task` operation that encloses
the operation otherwise generated for the construct (`omp.target`,
`omp.target_enter_data`, `omp.target_exit_data` or `omp.target_update_data`).
In addition, the depend clause is moved from the target construct to the
new enclosing `omp.task`, which is created as a mergeable task.
```
!$omp target map(..) depend(in:a)
b = a
```
is transformed to the following MLIR
```
omp.task mergeable depend(in:a) {
omp.target map(..) {
//MLIR for b = a;
}
omp.terminator
}
```
---
flang/lib/Lower/OpenMP/ClauseProcessor.cpp | 37 ++++++++++++++++++++++
flang/lib/Lower/OpenMP/ClauseProcessor.h | 7 ++++
flang/lib/Lower/OpenMP/OpenMP.cpp | 20 ++++--------
flang/test/Lower/OpenMP/target.f90 | 13 +++++---
4 files changed, 59 insertions(+), 18 deletions(-)
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index a41f8312a28c9e..8961268fd63a17 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -808,6 +808,43 @@ bool ClauseProcessor::processDepend(
});
}
+bool ClauseProcessor::processTargetDepend(
+    mlir::Location currentLocation) const {
+  llvm::SmallVector<mlir::Attribute> dependTypeOperands;
+  llvm::SmallVector<mlir::Value> dependOperands;
+
+  processDepend(dependTypeOperands, dependOperands);
+  if (dependTypeOperands.empty())
+    return false; // No 'depend' clause: no enclosing task is created.
+
+  // A 'depend' clause is present: create a mergeable omp.task that carries
+  // the dependences, and leave the builder inside it so that the ops the
+  // caller creates next (including the target op, which is now built with
+  // no depend operands of its own) are placed in the task body.
+  // NOTE(review): the insertion point is never moved back out of the task
+  // here; confirm callers reset it after the target op, otherwise later code
+  // is lowered into the task region as well. Also, the spec describes a
+  // target task as mergeable *and* untied, and as an included task when
+  // 'nowait' is absent, whereas this task is only marked mergeable.
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  mlir::omp::TaskOp taskOp = firOpBuilder.create<mlir::omp::TaskOp>(
+      currentLocation, /*if_expr=*/mlir::Value(),
+      /*final_expr=*/mlir::Value(), /*untied=*/mlir::UnitAttr(),
+      /*mergeable=*/firOpBuilder.getUnitAttr(),
+      /*in_reduction_vars=*/mlir::ValueRange(), /*in_reductions=*/nullptr,
+      /*priority=*/mlir::Value(),
+      /*depends=*/
+      mlir::ArrayAttr::get(firOpBuilder.getContext(), dependTypeOperands),
+      /*depend_vars=*/dependOperands, /*allocate_vars=*/mlir::ValueRange(),
+      /*allocators_vars=*/mlir::ValueRange());
+
+  // Give the task one block ending in omp.terminator, then move the builder
+  // to the start of that block so new ops are inserted before the terminator.
+  firOpBuilder.createBlock(&taskOp.getRegion());
+  firOpBuilder.create<mlir::omp::TerminatorOp>(currentLocation);
+  firOpBuilder.setInsertionPointToStart(&taskOp.getRegion().front());
+  return true;
+}
+
bool ClauseProcessor::processIf(
Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName,
mlir::Value &result) const {
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index 11aff0be25053d..140c7a8d5cefe8 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -101,6 +101,13 @@ class ClauseProcessor {
llvm::SmallVectorImpl<mlir::Attribute> &copyPrivateFuncs) const;
bool processDepend(llvm::SmallVectorImpl<mlir::Attribute> &dependTypeOperands,
llvm::SmallVectorImpl<mlir::Value> &dependOperands) const;
+
+ // Special case of processDepend for the target constructs (TargetOp,
+ // EnterDataOp, ExitDataOp, UpdateDataOp). If a 'depend' clause is present,
+ // it creates an enclosing mergeable omp.task that receives the clause's
+ // operands, leaves the builder's insertion point inside that task (it is
+ // not restored), and returns true; otherwise it returns false, no task.
+ bool processTargetDepend(mlir::Location currentLocation) const;
bool
processEnter(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
bool
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 25bb4d9cff5d16..011e7d75245061 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -933,7 +933,7 @@ genEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter,
ClauseProcessor cp(converter, semaCtx, clauseList);
cp.processIf(directiveName, ifClauseOperand);
cp.processDevice(stmtCtx, deviceOperand);
- cp.processDepend(dependTypeOperands, dependOperands);
+ cp.processTargetDepend(currentLocation);
cp.processNowait(nowaitAttr);
if constexpr (std::is_same_v<OpTy, mlir::omp::UpdateDataOp>) {
@@ -946,13 +946,9 @@ genEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter,
cp.processMap(currentLocation, directive, stmtCtx, mapOperands);
}
- return firOpBuilder.create<OpTy>(
- currentLocation, ifClauseOperand, deviceOperand,
- dependTypeOperands.empty()
- ? nullptr
- : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
- dependTypeOperands),
- dependOperands, nowaitAttr, mapOperands);
+ return firOpBuilder.create<OpTy>(currentLocation, ifClauseOperand,
+ deviceOperand, nullptr, mlir::ValueRange(),
+ nowaitAttr, mapOperands);
}
// This functions creates a block for the body of the targetOp's region. It adds
@@ -1130,7 +1126,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
ifClauseOperand);
cp.processDevice(stmtCtx, deviceOperand);
cp.processThreadLimit(stmtCtx, threadLimitOperand);
- cp.processDepend(dependTypeOperands, dependOperands);
+ cp.processTargetDepend(currentLocation);
cp.processNowait(nowaitAttr);
cp.processMap(currentLocation, directive, stmtCtx, mapOperands, &mapSymTypes,
&mapSymLocs, &mapSymbols);
@@ -1232,11 +1228,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
auto targetOp = converter.getFirOpBuilder().create<mlir::omp::TargetOp>(
currentLocation, ifClauseOperand, deviceOperand, threadLimitOperand,
- dependTypeOperands.empty()
- ? nullptr
- : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
- dependTypeOperands),
- dependOperands, nowaitAttr, mapOperands);
+ nullptr, mlir::ValueRange(), nowaitAttr, mapOperands);
genBodyOfTargetOp(converter, semaCtx, eval, genNested, targetOp, mapSymTypes,
mapSymLocs, mapSymbols, currentLocation);
diff --git a/flang/test/Lower/OpenMP/target.f90 b/flang/test/Lower/OpenMP/target.f90
index 030533e1a04553..10c928dda19b43 100644
--- a/flang/test/Lower/OpenMP/target.f90
+++ b/flang/test/Lower/OpenMP/target.f90
@@ -27,9 +27,10 @@ subroutine omp_target_enter_depend
!$omp task depend(out: a)
call foo(a)
!$omp end task
+ !CHECK: omp.task mergeable depend(taskdependin -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>) {
!CHECK: %[[BOUNDS:.*]] = omp.bounds lower_bound({{.*}}) upper_bound({{.*}}) extent({{.*}}) stride({{.*}}) start_idx({{.*}})
!CHECK: %[[MAP:.*]] = omp.map_info var_ptr({{.*}}) map_clauses(to) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
- !CHECK: omp.target_enter_data map_entries(%[[MAP]] : !fir.ref<!fir.array<1024xi32>>) depend(taskdependin -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>)
+ !CHECK: omp.target_enter_data map_entries(%[[MAP]] : !fir.ref<!fir.array<1024xi32>>)
!$omp target enter data map(to: a) depend(in: a)
return
end subroutine omp_target_enter_depend
@@ -166,9 +167,10 @@ subroutine omp_target_exit_depend
!$omp task depend(out: a)
call foo(a)
!$omp end task
+ !CHECK: omp.task mergeable depend(taskdependout -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>)
!CHECK: %[[BOUNDS:.*]] = omp.bounds lower_bound({{.*}}) upper_bound({{.*}}) extent({{.*}}) stride({{.*}}) start_idx({{.*}})
!CHECK: %[[MAP:.*]] = omp.map_info var_ptr({{.*}}) map_clauses(from) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
- !CHECK: omp.target_exit_data map_entries(%[[MAP]] : !fir.ref<!fir.array<1024xi32>>) depend(taskdependout -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>)
+ !CHECK: omp.target_exit_data map_entries(%[[MAP]] : !fir.ref<!fir.array<1024xi32>>)
!$omp target exit data map(from: a) depend(out: a)
end subroutine omp_target_exit_depend
@@ -187,9 +189,10 @@ subroutine omp_target_update_depend
call foo(a)
!$omp end task
+ !CHECK: omp.task mergeable depend(taskdependin -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>) {
!CHECK: %[[BOUNDS:.*]] = omp.bounds
!CHECK: %[[MAP:.*]] = omp.map_info var_ptr(%[[A]]#0 : !fir.ref<!fir.array<1024xi32>>, !fir.array<1024xi32>) map_clauses(to) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
- !CHECK: omp.target_update_data motion_entries(%[[MAP]] : !fir.ref<!fir.array<1024xi32>>) depend(taskdependin -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>)
+ !CHECK: omp.target_update_data motion_entries(%[[MAP]] : !fir.ref<!fir.array<1024xi32>>)
!$omp target update to(a) depend(in:a)
end subroutine omp_target_update_depend
@@ -367,12 +370,14 @@ subroutine omp_target_depend
!$omp task depend(out: a)
call foo(a)
!$omp end task
+
+ !CHECK: omp.task mergeable depend(taskdependin -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>) {
!CHECK: %[[STRIDE_A:.*]] = arith.constant 1 : index
!CHECK: %[[LBOUND_A:.*]] = arith.constant 0 : index
!CHECK: %[[UBOUND_A:.*]] = arith.subi %c1024, %c1 : index
!CHECK: %[[BOUNDS_A:.*]] = omp.bounds lower_bound(%[[LBOUND_A]] : index) upper_bound(%[[UBOUND_A]] : index) extent(%[[EXTENT_A]] : index) stride(%[[STRIDE_A]] : index) start_idx(%[[STRIDE_A]] : index)
!CHECK: %[[MAP_A:.*]] = omp.map_info var_ptr(%[[A]]#0 : !fir.ref<!fir.array<1024xi32>>, !fir.array<1024xi32>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS_A]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
- !CHECK: omp.target map_entries(%[[MAP_A]] -> %[[BB0_ARG:.*]] : !fir.ref<!fir.array<1024xi32>>) depend(taskdependin -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>) {
+ !CHECK: omp.target map_entries(%[[MAP_A]] -> %[[BB0_ARG:.*]] : !fir.ref<!fir.array<1024xi32>>)
!$omp target map(tofrom: a) depend(in: a)
a(1) = 10
!CHECK: omp.terminator
More information about the flang-commits
mailing list