[flang-commits] [flang] [flang][openmp] - depend clause support in target, target enter/update/exit data constructs (PR #81610)

Pranav Bhandarkar via flang-commits flang-commits at lists.llvm.org
Tue Feb 13 06:54:20 PST 2024


https://github.com/bhandarkar-pranav created https://github.com/llvm/llvm-project/pull/81610

This patch adds support in flang for the depend clause on the target construct and on the target enter data, target update, and target exit data constructs. Previously, the following line in a Fortran program would have resulted in the error shown below it.

    !$omp target map(to:a) depend(in:a)


    "not yet implemented: Unhandled clause DEPEND in TARGET construct"

>From 516aa3c3d6b490340547d64b307ddf1d7dc6b443 Mon Sep 17 00:00:00 2001
From: Pranav Bhandarkar <pranav.bhandarkar at amd.com>
Date: Wed, 7 Feb 2024 20:48:52 -0600
Subject: [PATCH] [flang][openmp] - depend clause support in target, target
 enter/update/exit data constructs

This patch adds support in flang for the depend clause on the target construct and on the
target enter data, target update, and target exit data constructs. Previously, the
following line in a Fortran program would have resulted in the error shown below it.

    !$omp target map(to:a) depend(in:a)

    "not yet implemented: Unhandled clause DEPEND in TARGET construct"
---
 flang/lib/Lower/OpenMP.cpp                    | 29 ++++---
 flang/lib/Semantics/check-omp-structure.cpp   |  8 ++
 flang/test/Lower/OpenMP/target.f90            | 85 +++++++++++++++++++
 .../Semantics/OpenMP/clause-validity01.f90    |  1 +
 4 files changed, 112 insertions(+), 11 deletions(-)

diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index 06850bebd7d05a..7e36fdac0c4dbb 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -2786,7 +2786,8 @@ genEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter,
   Fortran::lower::StatementContext stmtCtx;
   mlir::Value ifClauseOperand, deviceOperand;
   mlir::UnitAttr nowaitAttr;
-  llvm::SmallVector<mlir::Value> mapOperands;
+  llvm::SmallVector<mlir::Value> mapOperands, dependOperands;
+  llvm::SmallVector<mlir::Attribute> dependTypeOperands;
 
   Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName;
   llvm::omp::Directive directive;
@@ -2820,13 +2821,15 @@ genEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter,
   } else {
     cp.processMap(currentLocation, directive, stmtCtx, mapOperands);
   }
+  cp.processDepend(dependTypeOperands, dependOperands);
 
-  cp.processTODO<Fortran::parser::OmpClause::Depend>(currentLocation,
-                                                     directive);
-
-  return firOpBuilder.create<OpTy>(currentLocation, ifClauseOperand,
-                                   deviceOperand, nullptr, mlir::ValueRange(),
-                                   nowaitAttr, mapOperands);
+  return firOpBuilder.create<OpTy>(
+      currentLocation, ifClauseOperand, deviceOperand,
+      dependTypeOperands.empty()
+          ? nullptr
+          : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
+                                 dependTypeOperands),
+      dependOperands, nowaitAttr, mapOperands);
 }
 
 // This functions creates a block for the body of the targetOp's region. It adds
@@ -2993,7 +2996,8 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
   Fortran::lower::StatementContext stmtCtx;
   mlir::Value ifClauseOperand, deviceOperand, threadLimitOperand;
   mlir::UnitAttr nowaitAttr;
-  llvm::SmallVector<mlir::Value> mapOperands;
+  llvm::SmallVector<mlir::Attribute> dependTypeOperands;
+  llvm::SmallVector<mlir::Value> mapOperands, dependOperands;
   llvm::SmallVector<mlir::Type> mapSymTypes;
   llvm::SmallVector<mlir::Location> mapSymLocs;
   llvm::SmallVector<const Fortran::semantics::Symbol *> mapSymbols;
@@ -3006,8 +3010,8 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
   cp.processNowait(nowaitAttr);
   cp.processMap(currentLocation, directive, stmtCtx, mapOperands, &mapSymTypes,
                 &mapSymLocs, &mapSymbols);
+  cp.processDepend(dependTypeOperands, dependOperands);
   cp.processTODO<Fortran::parser::OmpClause::Private,
-                 Fortran::parser::OmpClause::Depend,
                  Fortran::parser::OmpClause::Firstprivate,
                  Fortran::parser::OmpClause::IsDevicePtr,
                  Fortran::parser::OmpClause::HasDeviceAddr,
@@ -3017,7 +3021,6 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
                  Fortran::parser::OmpClause::UsesAllocators,
                  Fortran::parser::OmpClause::Defaultmap>(
       currentLocation, llvm::omp::Directive::OMPD_target);
-
   // 5.8.1 Implicit Data-Mapping Attribute Rules
   // The following code follows the implicit data-mapping rules to map all the
   // symbols used inside the region that have not been explicitly mapped using
@@ -3091,7 +3094,11 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
 
   auto targetOp = converter.getFirOpBuilder().create<mlir::omp::TargetOp>(
       currentLocation, ifClauseOperand, deviceOperand, threadLimitOperand,
-      nullptr, mlir::ValueRange(), nowaitAttr, mapOperands);
+      dependTypeOperands.empty()
+          ? nullptr
+          : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
+                                 dependTypeOperands),
+      dependOperands, nowaitAttr, mapOperands);
 
   genBodyOfTargetOp(converter, semaCtx, eval, genNested, targetOp, mapSymTypes,
                     mapSymLocs, mapSymbols, currentLocation);
diff --git a/flang/lib/Semantics/check-omp-structure.cpp b/flang/lib/Semantics/check-omp-structure.cpp
index 03423de0c6104d..54101ab8a42bbf 100644
--- a/flang/lib/Semantics/check-omp-structure.cpp
+++ b/flang/lib/Semantics/check-omp-structure.cpp
@@ -2815,6 +2815,14 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Device &x) {
 
 void OmpStructureChecker::Enter(const parser::OmpClause::Depend &x) {
   CheckAllowed(llvm::omp::Clause::OMPC_depend);
+  if ((std::holds_alternative<parser::OmpDependClause::Source>(x.v.u) ||
+          std::holds_alternative<parser::OmpDependClause::Sink>(x.v.u)) &&
+      GetContext().directive != llvm::omp::OMPD_ordered) {
+    context_.Say(GetContext().clauseSource,
+        "DEPEND(SOURCE) or DEPEND(SINK : vec) can be used only with the ordered"
+        " directive. Used here in the %s construct."_err_en_US,
+        parser::ToUpperCaseLetters(getDirectiveName(GetContext().directive)));
+  }
   if (const auto *inOut{std::get_if<parser::OmpDependClause::InOut>(&x.v.u)}) {
     const auto &designators{std::get<std::list<parser::Designator>>(inOut->t)};
     for (const auto &ele : designators) {
diff --git a/flang/test/Lower/OpenMP/target.f90 b/flang/test/Lower/OpenMP/target.f90
index fa07b7f71d514e..030533e1a04553 100644
--- a/flang/test/Lower/OpenMP/target.f90
+++ b/flang/test/Lower/OpenMP/target.f90
@@ -14,6 +14,26 @@ subroutine omp_target_enter_simple
     return
 end subroutine omp_target_enter_simple
 
+!===============================================================================
+! Target_Enter `depend` clause
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_enter_depend() {
+subroutine omp_target_enter_depend
+   !CHECK: %[[A:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}}) {uniq_name = "_QFomp_target_enter_dependEa"} : (!fir.ref<!fir.array<1024xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<1024xi32>>, !fir.ref<!fir.array<1024xi32>>)
+   integer :: a(1024)
+
+   !CHECK: omp.task depend(taskdependout -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>) {
+   !$omp task depend(out: a)
+   call foo(a)
+   !$omp end task
+   !CHECK: %[[BOUNDS:.*]] = omp.bounds   lower_bound({{.*}}) upper_bound({{.*}}) extent({{.*}}) stride({{.*}}) start_idx({{.*}})
+   !CHECK: %[[MAP:.*]] = omp.map_info var_ptr({{.*}})   map_clauses(to) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
+   !CHECK: omp.target_enter_data   map_entries(%[[MAP]] : !fir.ref<!fir.array<1024xi32>>) depend(taskdependin -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>)
+   !$omp target enter data map(to: a) depend(in: a)
+    return
+end subroutine omp_target_enter_depend
+
 !===============================================================================
 ! Target_Enter Map types
 !===============================================================================
@@ -134,6 +154,45 @@ subroutine omp_target_exit_device
    !$omp target exit data map(from: a) device(d)
 end subroutine omp_target_exit_device
 
+!===============================================================================
+! Target_Exit `depend` clause
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_exit_depend() {
+subroutine omp_target_exit_depend
+   !CHECK: %[[A:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}}) {uniq_name = "_QFomp_target_exit_dependEa"} : (!fir.ref<!fir.array<1024xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<1024xi32>>, !fir.ref<!fir.array<1024xi32>>)
+   integer :: a(1024)
+   !CHECK: omp.task depend(taskdependout -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>) {
+   !$omp task depend(out: a)
+   call foo(a)
+   !$omp end task
+   !CHECK: %[[BOUNDS:.*]] = omp.bounds   lower_bound({{.*}}) upper_bound({{.*}}) extent({{.*}}) stride({{.*}}) start_idx({{.*}})
+   !CHECK: %[[MAP:.*]] = omp.map_info var_ptr({{.*}})   map_clauses(from) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
+   !CHECK: omp.target_exit_data   map_entries(%[[MAP]] : !fir.ref<!fir.array<1024xi32>>) depend(taskdependout -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>)
+   !$omp target exit data map(from: a) depend(out: a)
+end subroutine omp_target_exit_depend
+
+
+!===============================================================================
+! Target_Update `depend` clause
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_update_depend() {
+subroutine omp_target_update_depend
+   !CHECK: %[[A:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}}) {uniq_name = "_QFomp_target_update_dependEa"} : (!fir.ref<!fir.array<1024xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<1024xi32>>, !fir.ref<!fir.array<1024xi32>>)
+   integer :: a(1024)
+
+   !CHECK: omp.task depend(taskdependout -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>) {
+   !$omp task depend(out: a)
+   call foo(a)
+   !$omp end task
+
+   !CHECK: %[[BOUNDS:.*]] = omp.bounds
+   !CHECK: %[[MAP:.*]] = omp.map_info var_ptr(%[[A]]#0 : !fir.ref<!fir.array<1024xi32>>, !fir.array<1024xi32>) map_clauses(to) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
+   !CHECK: omp.target_update_data motion_entries(%[[MAP]] : !fir.ref<!fir.array<1024xi32>>) depend(taskdependin -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>)
+   !$omp target update to(a) depend(in:a)
+end subroutine omp_target_update_depend
+
 !===============================================================================
 ! Target_Update `to` clause
 !===============================================================================
@@ -295,6 +354,32 @@ subroutine omp_target
    !CHECK: }
 end subroutine omp_target
 
+!===============================================================================
+! Target with region `depend` clause
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_depend() {
+subroutine omp_target_depend
+   !CHECK: %[[EXTENT_A:.*]] = arith.constant 1024 : index
+   !CHECK: %[[A:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}}) {uniq_name = "_QFomp_target_dependEa"} : (!fir.ref<!fir.array<1024xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<1024xi32>>, !fir.ref<!fir.array<1024xi32>>)
+   integer :: a(1024)
+   !CHECK: omp.task depend(taskdependout -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>) {
+   !$omp task depend(out: a)
+   call foo(a)
+   !$omp end task
+   !CHECK: %[[STRIDE_A:.*]] = arith.constant 1 : index
+   !CHECK: %[[LBOUND_A:.*]] = arith.constant 0 : index
+   !CHECK: %[[UBOUND_A:.*]] = arith.subi %c1024, %c1 : index
+   !CHECK: %[[BOUNDS_A:.*]] = omp.bounds lower_bound(%[[LBOUND_A]] : index) upper_bound(%[[UBOUND_A]] : index) extent(%[[EXTENT_A]] : index) stride(%[[STRIDE_A]] : index) start_idx(%[[STRIDE_A]] : index)
+   !CHECK: %[[MAP_A:.*]] = omp.map_info var_ptr(%[[A]]#0 : !fir.ref<!fir.array<1024xi32>>, !fir.array<1024xi32>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS_A]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
+   !CHECK: omp.target map_entries(%[[MAP_A]] -> %[[BB0_ARG:.*]] : !fir.ref<!fir.array<1024xi32>>) depend(taskdependin -> %[[A]]#1 : !fir.ref<!fir.array<1024xi32>>) {
+   !$omp target map(tofrom: a) depend(in: a)
+      a(1) = 10
+      !CHECK: omp.terminator
+   !$omp end target
+   !CHECK: }
+ end subroutine omp_target_depend
+
 !===============================================================================
 ! Target implicit capture
 !===============================================================================
diff --git a/flang/test/Semantics/OpenMP/clause-validity01.f90 b/flang/test/Semantics/OpenMP/clause-validity01.f90
index 3fa86ed105a292..d9573a81821f32 100644
--- a/flang/test/Semantics/OpenMP/clause-validity01.f90
+++ b/flang/test/Semantics/OpenMP/clause-validity01.f90
@@ -481,6 +481,7 @@
   !$omp taskyield
   !$omp barrier
   !$omp taskwait
+  !ERROR: DEPEND(SOURCE) or DEPEND(SINK : vec) can be used only with the ordered directive. Used here in the TASKWAIT construct.
   !$omp taskwait depend(source)
   ! !$omp taskwait depend(sink:i-1)
   ! !$omp target enter data map(to:arrayA) map(alloc:arrayB)



More information about the flang-commits mailing list