[flang-commits] [flang] [mlir][OpenMP] - Transform target offloading directives with dependencies during PFT to MLIR conversion (PR #85130)
via flang-commits
flang-commits at lists.llvm.org
Wed Mar 13 13:58:57 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-fir-hlfir
Author: Pranav Bhandarkar (bhandarkar-pranav)
<details>
<summary>Changes</summary>
This PR changes PFT to MLIR lowering for the following directives when they have the `depend` clause on them.
```
!$omp target depend(..)
!$omp target enter data depend(..)
!$omp target update depend(..)
!$omp target exit data depend(..)
```
With this PR, lowering now involves the creation of an `omp.task` operation that encloses the `omp.target` operation that is otherwise generated for the target construct. In addition, the depend clause from the target is moved to the enclosing new `omp.task`. The new `omp.task` is a mergeable task.
```
!$omp target map(..) depend(in:a)
b = a
```
is transformed to the following MLIR
```
omp.task mergeable depend(in:a) {
omp.target map(..) {
//MLIR for b = a;
}
omp.terminator
}
```
>[!NOTE]
>This PR is an alternative to https://github.com/llvm/llvm-project/pull/83966. Its benefits are that, unlike https://github.com/llvm/llvm-project/pull/83966, it does not do an entire pass over the IR, and the changes required for translation of the `depend` clause on target constructs into LLVM IR are entirely contained in this one change. By comparison, the approach of https://github.com/llvm/llvm-project/pull/83966 would require adding that pass to the pass pipeline in flang.
Further, all `depend`-clause-related operands and attributes of `omp.target` can now be removed from MLIR (patch coming soon after this PR is merged).
---
Full diff: https://github.com/llvm/llvm-project/pull/85130.diff
4 Files Affected:
- (modified) flang/lib/Lower/OpenMP/ClauseProcessor.cpp (+37)
- (modified) flang/lib/Lower/OpenMP/ClauseProcessor.h (+7)
- (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+6-14)
- (modified) flang/test/Lower/OpenMP/target.f90 (+9-4)
``````````diff
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index a41f8312a28c9e..8961268fd63a17 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -808,6 +808,43 @@ bool ClauseProcessor::processDepend(
});
}
+bool ClauseProcessor::processTargetDepend(
+ mlir::Location currentLocation) const {
+ llvm::SmallVector<mlir::Attribute> dependTypeOperands;
+ llvm::SmallVector<mlir::Value> dependOperands;
+
+ processDepend(dependTypeOperands, dependOperands);
+ if (dependTypeOperands.empty())
+ return false;
+
+ // If 'dependTypeOperands' is not empty, this means the depend
+ // clause was used and we create an omp.task operation that'll
+ // enclose the omp.target operation corresponding to the target
+ // construct used. This new omp.task will be a mergeable task
+ // on which the depend clause will be tacked on. The depend
+ // clause on the original target construct is dropped.
+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+
+ // Create the new omp.task op.
+ // As per the OpenMP Spec a target directive creates a mergeable 'target
+ // task'
+ mlir::omp::TaskOp taskOp = firOpBuilder.create<mlir::omp::TaskOp>(
+ currentLocation, /*if_expr*/ mlir::Value(),
+ /*final_expr*/ mlir::Value(), /*untied*/ mlir::UnitAttr(),
+ /*mergeable*/ firOpBuilder.getUnitAttr(),
+ /*in_reduction_vars*/ mlir::ValueRange(), /*in_reductions*/ nullptr,
+ /*priority*/ mlir::Value(),
+ mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
+ dependTypeOperands),
+ dependOperands, /*allocate_vars*/ mlir::ValueRange(),
+ /*allocators_vars*/ mlir::ValueRange());
+
+ firOpBuilder.createBlock(&taskOp.getRegion());
+ firOpBuilder.create<mlir::omp::TerminatorOp>(currentLocation);
+ firOpBuilder.setInsertionPointToStart(&taskOp.getRegion().front());
+ return true;
+}
+
bool ClauseProcessor::processIf(
Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName,
mlir::Value &result) const {
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index 11aff0be25053d..140c7a8d5cefe8 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -101,6 +101,13 @@ class ClauseProcessor {
llvm::SmallVectorImpl<mlir::Attribute> &copyPrivateFuncs) const;
bool processDepend(llvm::SmallVectorImpl<mlir::Attribute> &dependTypeOperands,
llvm::SmallVectorImpl<mlir::Value> &dependOperands) const;
+
+ // This is a special case of processDepend that processes the depend
+ // clause on Target ops - TargetOp, EnterDataOp, ExitDataOp, UpdateDataOp
+ // It sets up the generation of MLIR code for the target construct
+ // in question by first creating an enclosing omp.task operation and transfers
+ // the 'depend' clause and its arguments to this new omp.task operation.
+ bool processTargetDepend(mlir::Location currentLocation) const;
bool
processEnter(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
bool
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 25bb4d9cff5d16..011e7d75245061 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -933,7 +933,7 @@ genEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter,
ClauseProcessor cp(converter, semaCtx, clauseList);
cp.processIf(directiveName, ifClauseOperand);
cp.processDevice(stmtCtx, deviceOperand);
- cp.processDepend(dependTypeOperands, dependOperands);
+ cp.processTargetDepend(currentLocation);
cp.processNowait(nowaitAttr);
if constexpr (std::is_same_v<OpTy, mlir::omp::UpdateDataOp>) {
@@ -946,13 +946,9 @@ genEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter,
cp.processMap(currentLocation, directive, stmtCtx, mapOperands);
}
- return firOpBuilder.create<OpTy>(
- currentLocation, ifClauseOperand, deviceOperand,
- dependTypeOperands.empty()
- ? nullptr
- : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
- dependTypeOperands),
- dependOperands, nowaitAttr, mapOperands);
+ return firOpBuilder.create<OpTy>(currentLocation, ifClauseOperand,
+ deviceOperand, nullptr, mlir::ValueRange(),
+ nowaitAttr, mapOperands);
}
// This functions creates a block for the body of the targetOp's region. It adds
@@ -1130,7 +1126,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
ifClauseOperand);
cp.processDevice(stmtCtx, deviceOperand);
cp.processThreadLimit(stmtCtx, threadLimitOperand);
- cp.processDepend(dependTypeOperands, dependOperands);
+ cp.processTargetDepend(currentLocation);
cp.processNowait(nowaitAttr);
cp.processMap(currentLocation, directive, stmtCtx, mapOperands, &mapSymTypes,
&mapSymLocs, &mapSymbols);
@@ -1232,11 +1228,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
auto targetOp = converter.getFirOpBuilder().create<mlir::omp::TargetOp>(
currentLocation, ifClauseOperand, deviceOperand, threadLimitOperand,
- dependTypeOperands.empty()
- ? nullptr
- : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
- dependTypeOperands),
- dependOperands, nowaitAttr, mapOperands);
+ nullptr, mlir::ValueRange(), nowaitAttr, mapOperands);
genBodyOfTargetOp(converter, semaCtx, eval, genNested, targetOp, mapSymTypes,
mapSymLocs, mapSymbols, currentLocation);
diff --git a/flang/test/Lower/OpenMP/target.f90 b/flang/test/Lower/OpenMP/target.f90
index 030533e1a04553..10c928dda19b43 100644
--- a/flang/test/Lower/OpenMP/target.f90
+++ b/flang/test/Lower/OpenMP/target.f90
@@ -27,9 +27,10 @@ subroutine omp_target_enter_depend
!$omp task depend(out: a)
call foo(a)
!$omp end task
+ !CHECK: omp.task mergeable depend(taskdependin -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>) {
!CHECK: %[[BOUNDS:.*]] = omp.bounds lower_bound({{.*}}) upper_bound({{.*}}) extent({{.*}}) stride({{.*}}) start_idx({{.*}})
!CHECK: %[[MAP:.*]] = omp.map_info var_ptr({{.*}}) map_clauses(to) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
- !CHECK: omp.target_enter_data map_entries(%[[MAP]] : !fir.ref<!fir.array<1024xi32>>) depend(taskdependin -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>)
+ !CHECK: omp.target_enter_data map_entries(%[[MAP]] : !fir.ref<!fir.array<1024xi32>>)
!$omp target enter data map(to: a) depend(in: a)
return
end subroutine omp_target_enter_depend
@@ -166,9 +167,10 @@ subroutine omp_target_exit_depend
!$omp task depend(out: a)
call foo(a)
!$omp end task
+ !CHECK: omp.task mergeable depend(taskdependout -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>)
!CHECK: %[[BOUNDS:.*]] = omp.bounds lower_bound({{.*}}) upper_bound({{.*}}) extent({{.*}}) stride({{.*}}) start_idx({{.*}})
!CHECK: %[[MAP:.*]] = omp.map_info var_ptr({{.*}}) map_clauses(from) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
- !CHECK: omp.target_exit_data map_entries(%[[MAP]] : !fir.ref<!fir.array<1024xi32>>) depend(taskdependout -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>)
+ !CHECK: omp.target_exit_data map_entries(%[[MAP]] : !fir.ref<!fir.array<1024xi32>>)
!$omp target exit data map(from: a) depend(out: a)
end subroutine omp_target_exit_depend
@@ -187,9 +189,10 @@ subroutine omp_target_update_depend
call foo(a)
!$omp end task
+ !CHECK: omp.task mergeable depend(taskdependin -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>) {
!CHECK: %[[BOUNDS:.*]] = omp.bounds
!CHECK: %[[MAP:.*]] = omp.map_info var_ptr(%[[A]]#0 : !fir.ref<!fir.array<1024xi32>>, !fir.array<1024xi32>) map_clauses(to) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
- !CHECK: omp.target_update_data motion_entries(%[[MAP]] : !fir.ref<!fir.array<1024xi32>>) depend(taskdependin -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>)
+ !CHECK: omp.target_update_data motion_entries(%[[MAP]] : !fir.ref<!fir.array<1024xi32>>)
!$omp target update to(a) depend(in:a)
end subroutine omp_target_update_depend
@@ -367,12 +370,14 @@ subroutine omp_target_depend
!$omp task depend(out: a)
call foo(a)
!$omp end task
+
+ !CHECK: omp.task mergeable depend(taskdependin -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>) {
!CHECK: %[[STRIDE_A:.*]] = arith.constant 1 : index
!CHECK: %[[LBOUND_A:.*]] = arith.constant 0 : index
!CHECK: %[[UBOUND_A:.*]] = arith.subi %c1024, %c1 : index
!CHECK: %[[BOUNDS_A:.*]] = omp.bounds lower_bound(%[[LBOUND_A]] : index) upper_bound(%[[UBOUND_A]] : index) extent(%[[EXTENT_A]] : index) stride(%[[STRIDE_A]] : index) start_idx(%[[STRIDE_A]] : index)
!CHECK: %[[MAP_A:.*]] = omp.map_info var_ptr(%[[A]]#0 : !fir.ref<!fir.array<1024xi32>>, !fir.array<1024xi32>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS_A]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
- !CHECK: omp.target map_entries(%[[MAP_A]] -> %[[BB0_ARG:.*]] : !fir.ref<!fir.array<1024xi32>>) depend(taskdependin -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>) {
+ !CHECK: omp.target map_entries(%[[MAP_A]] -> %[[BB0_ARG:.*]] : !fir.ref<!fir.array<1024xi32>>)
!$omp target map(tofrom: a) depend(in: a)
a(1) = 10
!CHECK: omp.terminator
``````````
</details>
https://github.com/llvm/llvm-project/pull/85130
More information about the flang-commits
mailing list