[flang-commits] [flang] 211ed03 - [Flang][OpenMP][Lower] Support lowering of `teams` directive to MLIR

Sergio Afonso via flang-commits flang-commits at lists.llvm.org
Tue Aug 15 05:38:27 PDT 2023


Author: Sergio Afonso
Date: 2023-08-15T13:38:09+01:00
New Revision: 211ed03bfd3421414b29cd4d1d959d51a38966a8

URL: https://github.com/llvm/llvm-project/commit/211ed03bfd3421414b29cd4d1d959d51a38966a8
DIFF: https://github.com/llvm/llvm-project/commit/211ed03bfd3421414b29cd4d1d959d51a38966a8.diff

LOG: [Flang][OpenMP][Lower] Support lowering of `teams` directive to MLIR

This patch adds support for translating OpenMP `teams` directives to MLIR,
whether they appear as loop or block constructs, on their own or as part of a
combined construct.

The current Fortran parser does not accept the optional lower bound of the
`num_teams` clause, so this patch only sets the `num_teams_upper` MLIR
argument.

Depends on D156809

Differential Revision: https://reviews.llvm.org/D156884

Added: 
    flang/test/Lower/OpenMP/Todo/reduction-teams.f90
    flang/test/Lower/OpenMP/teams.f90

Modified: 
    flang/lib/Lower/OpenMP.cpp
    flang/test/Lower/OpenMP/if-clause.f90

Removed: 
    


################################################################################
diff  --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index 79d54e232b777d..ac911f988abcaf 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -496,6 +496,8 @@ class ClauseProcessor {
   bool processHint(mlir::IntegerAttr &result) const;
   bool processMergeable(mlir::UnitAttr &result) const;
   bool processNowait(mlir::UnitAttr &result) const;
+  bool processNumTeams(Fortran::lower::StatementContext &stmtCtx,
+                       mlir::Value &result) const;
   bool processNumThreads(Fortran::lower::StatementContext &stmtCtx,
                          mlir::Value &result) const;
   bool processOrdered(mlir::IntegerAttr &result) const;
@@ -1347,6 +1349,18 @@ bool ClauseProcessor::processNowait(mlir::UnitAttr &result) const {
   return markClauseOccurrence<ClauseTy::Nowait>(result);
 }
 
+bool ClauseProcessor::processNumTeams(Fortran::lower::StatementContext &stmtCtx,
+                                      mlir::Value &result) const {
+  // If present, lower the NUM_TEAMS expression (the upper bound) into `result`
+  // and return true. TODO: lower the lower bound when the parser supports it.
+  if (auto *numTeamsClause = findUniqueClause<ClauseTy::NumTeams>()) {
+    result = fir::getBase(converter.genExprValue(
+        *Fortran::semantics::GetExpr(numTeamsClause->v), stmtCtx));
+    return true;
+  }
+  return false;
+}
+
 bool ClauseProcessor::processNumThreads(
     Fortran::lower::StatementContext &stmtCtx, mlir::Value &result) const {
   if (auto *numThreadsClause = findUniqueClause<ClauseTy::NumThreads>()) {
@@ -2359,6 +2373,40 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
       mapOperands, mapTypesArrayAttr);
 }
 
+static mlir::omp::TeamsOp
+genTeamsOp(Fortran::lower::AbstractConverter &converter,
+           Fortran::lower::pft::Evaluation &eval,
+           mlir::Location currentLocation,
+           const Fortran::parser::OmpClauseList &clauseList,
+           bool outerCombined = false) {
+  Fortran::lower::StatementContext stmtCtx;
+  mlir::Value numTeamsClauseOperand, ifClauseOperand, threadLimitClauseOperand;
+  llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands,
+      reductionVars;
+  llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
+  // Collect operands for the supported TEAMS clauses; REDUCTION hits a TODO.
+  ClauseProcessor cp(converter, clauseList);
+  cp.processIf(stmtCtx,
+               Fortran::parser::OmpIfClause::DirectiveNameModifier::Teams,
+               ifClauseOperand);
+  cp.processAllocate(allocatorOperands, allocateOperands);
+  cp.processDefault();
+  cp.processNumTeams(stmtCtx, numTeamsClauseOperand);
+  cp.processThreadLimit(stmtCtx, threadLimitClauseOperand);
+  if (cp.processReduction(currentLocation, reductionVars, reductionDeclSymbols))
+    TODO(currentLocation, "Reduction of TEAMS directive");
+  // Only the num_teams upper bound is parsed, so the lower bound is null.
+  return genOpWithBody<mlir::omp::TeamsOp>(
+      converter, eval, currentLocation, outerCombined, &clauseList,
+      /*num_teams_lower=*/nullptr, numTeamsClauseOperand, ifClauseOperand,
+      threadLimitClauseOperand, allocateOperands, allocatorOperands,
+      reductionVars,
+      reductionDeclSymbols.empty()
+          ? nullptr
+          : mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
+                                 reductionDeclSymbols));
+}
+
 //===----------------------------------------------------------------------===//
 // genOMP() Code generation helper functions
 //===----------------------------------------------------------------------===//
@@ -2483,7 +2531,8 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
     if ((llvm::omp::allTeamsSet & llvm::omp::loopConstructSet)
             .test(ompDirective)) {
       validDirective = true;
-      TODO(currentLocation, "Teams construct");
+      genTeamsOp(converter, eval, currentLocation, loopOpClauseList,
+                 /*outerCombined=*/true);
     }
     if (llvm::omp::allDistributeSet.test(ompDirective)) {
       validDirective = true;
@@ -2628,7 +2677,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
         !std::get_if<Fortran::parser::OmpClause::Map>(&clause.u) &&
         !std::get_if<Fortran::parser::OmpClause::UseDevicePtr>(&clause.u) &&
         !std::get_if<Fortran::parser::OmpClause::UseDeviceAddr>(&clause.u) &&
-        !std::get_if<Fortran::parser::OmpClause::ThreadLimit>(&clause.u)) {
+        !std::get_if<Fortran::parser::OmpClause::ThreadLimit>(&clause.u) &&
+        !std::get_if<Fortran::parser::OmpClause::NumTeams>(&clause.u)) {
       TODO(clauseLocation, "OpenMP Block construct clause");
     }
   }
@@ -2667,7 +2717,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
     genTaskGroupOp(converter, eval, currentLocation, beginClauseList);
     break;
   case llvm::omp::Directive::OMPD_teams:
-    TODO(currentLocation, "Teams construct");
+    genTeamsOp(converter, eval, currentLocation, beginClauseList,
+               /*outerCombined=*/false);
     break;
   case llvm::omp::Directive::OMPD_workshare:
     TODO(currentLocation, "Workshare construct");
@@ -2683,7 +2734,7 @@ genOMP(Fortran::lower::AbstractConverter &converter,
     }
     if ((llvm::omp::allTeamsSet & llvm::omp::blockConstructSet)
             .test(directive.v)) {
-      TODO(currentLocation, "Teams construct");
+      genTeamsOp(converter, eval, currentLocation, beginClauseList);
       combinedDirective = true;
     }
     if ((llvm::omp::allParallelSet & llvm::omp::blockConstructSet)

diff  --git a/flang/test/Lower/OpenMP/Todo/reduction-teams.f90 b/flang/test/Lower/OpenMP/Todo/reduction-teams.f90
new file mode 100644
index 00000000000000..f045c2f323f715
--- /dev/null
+++ b/flang/test/Lower/OpenMP/Todo/reduction-teams.f90
@@ -0,0 +1,12 @@
+! RUN: %not_todo_cmd bbc -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
+! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
+! TEAMS reductions are not lowered yet; lowering is expected to hit a TODO.
+! CHECK: not yet implemented: Reduction of TEAMS directive
+subroutine reduction_teams()
+  integer :: i
+  i = 0
+
+  !$omp teams reduction(+:i)
+  i = i + 1
+  !$omp end teams
+end subroutine reduction_teams

diff  --git a/flang/test/Lower/OpenMP/if-clause.f90 b/flang/test/Lower/OpenMP/if-clause.f90
index 9b8b5b7b439387..ef98a00f10dbd2 100644
--- a/flang/test/Lower/OpenMP/if-clause.f90
+++ b/flang/test/Lower/OpenMP/if-clause.f90
@@ -13,7 +13,6 @@ program main
   ! - PARALLEL SECTIONS
   ! - PARALLEL WORKSHARE
   ! - TARGET PARALLEL
-  ! - TARGET TEAMS
   ! - TARGET TEAMS DISTRIBUTE
   ! - TARGET TEAMS DISTRIBUTE PARALLEL DO
   ! - TARGET TEAMS DISTRIBUTE PARALLEL DO SIMD
@@ -21,7 +20,6 @@ program main
   ! - TARGET UPDATE
   ! - TASKLOOP
   ! - TASKLOOP SIMD
-  ! - TEAMS
   ! - TEAMS DISTRIBUTE
   ! - TEAMS DISTRIBUTE PARALLEL DO
   ! - TEAMS DISTRIBUTE PARALLEL DO SIMD
@@ -416,6 +414,54 @@ program main
   end do
   !$omp end target simd
 
+  ! ----------------------------------------------------------------------------
+  ! TARGET TEAMS
+  ! ----------------------------------------------------------------------------
+
+  ! CHECK:      omp.target
+  ! CHECK-NOT:  if({{.*}})
+  ! CHECK-SAME: {
+  ! CHECK:      omp.teams
+  ! CHECK-NOT:  if({{.*}})
+  ! CHECK-SAME: {
+  !$omp target teams
+  i = 1
+  !$omp end target teams
+
+  ! CHECK:      omp.target
+  ! CHECK-SAME: if({{.*}})
+  ! CHECK:      omp.teams
+  ! CHECK-SAME: if({{.*}})
+  !$omp target teams if(.true.)
+  i = 1
+  !$omp end target teams
+
+  ! CHECK:      omp.target
+  ! CHECK-SAME: if({{.*}})
+  ! CHECK:      omp.teams
+  ! CHECK-SAME: if({{.*}})
+  !$omp target teams if(target: .true.) if(teams: .false.)
+  i = 1
+  !$omp end target teams
+
+  ! CHECK:      omp.target
+  ! CHECK-SAME: if({{.*}})
+  ! CHECK:      omp.teams
+  ! CHECK-NOT:  if({{.*}})
+  ! CHECK-SAME: {
+  !$omp target teams if(target: .true.)
+  i = 1
+  !$omp end target teams
+
+  ! CHECK:      omp.target
+  ! CHECK-NOT:  if({{.*}})
+  ! CHECK-SAME: {
+  ! CHECK:      omp.teams
+  ! CHECK-SAME: if({{.*}})
+  !$omp target teams if(teams: .true.)
+  i = 1
+  !$omp end target teams
+
   ! ----------------------------------------------------------------------------
   ! TASK
   ! ----------------------------------------------------------------------------
@@ -434,4 +480,26 @@ program main
   ! CHECK-SAME: if({{.*}})
   !$omp task if(task: .true.)
   !$omp end task
+
+  ! ----------------------------------------------------------------------------
+  ! TEAMS
+  ! ----------------------------------------------------------------------------
+  ! CHECK:      omp.teams
+  ! CHECK-NOT:  if({{.*}})
+  ! CHECK-SAME: {
+  !$omp teams
+  i = 1
+  !$omp end teams
+
+  ! CHECK:      omp.teams
+  ! CHECK-SAME: if({{.*}})
+  !$omp teams if(.true.)
+  i = 1
+  !$omp end teams
+
+  ! CHECK:      omp.teams
+  ! CHECK-SAME: if({{.*}})
+  !$omp teams if(teams: .true.)
+  i = 1
+  !$omp end teams
 end program main

diff  --git a/flang/test/Lower/OpenMP/teams.f90 b/flang/test/Lower/OpenMP/teams.f90
new file mode 100644
index 00000000000000..51f9087acb0b59
--- /dev/null
+++ b/flang/test/Lower/OpenMP/teams.f90
@@ -0,0 +1,114 @@
+! RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s
+! Lowering of the TEAMS construct to omp.teams, with and without clauses.
+! CHECK-LABEL: func @_QPteams_simple
+subroutine teams_simple()
+  ! CHECK: omp.teams
+  !$omp teams
+  ! CHECK: fir.call
+  call f1()
+  ! CHECK: omp.terminator
+  !$omp end teams
+end subroutine teams_simple
+
+!===============================================================================
+! `num_teams` clause
+!===============================================================================
+! A literal and a dummy argument are both lowered as the num_teams upper bound.
+! CHECK-LABEL: func @_QPteams_numteams
+subroutine teams_numteams(num_teams)
+  integer, intent(inout) :: num_teams
+
+  ! CHECK: omp.teams
+  ! CHECK-SAME: num_teams( to %{{.*}}: i32)
+  !$omp teams num_teams(4)
+  ! CHECK: fir.call
+  call f1()
+  ! CHECK: omp.terminator
+  !$omp end teams
+
+  ! CHECK: omp.teams
+  ! CHECK-SAME: num_teams( to %{{.*}}: i32)
+  !$omp teams num_teams(num_teams)
+  ! CHECK: fir.call
+  call f2()
+  ! CHECK: omp.terminator
+  !$omp end teams
+
+end subroutine teams_numteams
+
+!===============================================================================
+! `if` clause
+!===============================================================================
+! Constant, relational-expression and logical-variable conditions are lowered.
+! CHECK-LABEL: func @_QPteams_if
+subroutine teams_if(alpha)
+  integer, intent(in) :: alpha
+  logical :: condition
+
+  ! CHECK: omp.teams
+  ! CHECK-SAME: if(%{{.*}})
+  !$omp teams if(.false.)
+  ! CHECK: fir.call
+  call f1()
+  ! CHECK: omp.terminator
+  !$omp end teams
+
+  ! CHECK: omp.teams
+  ! CHECK-SAME: if(%{{.*}})
+  !$omp teams if(alpha .le. 0)
+  ! CHECK: fir.call
+  call f2()
+  ! CHECK: omp.terminator
+  !$omp end teams
+
+  ! CHECK: omp.teams
+  ! CHECK-SAME: if(%{{.*}})
+  !$omp teams if(condition)
+  ! CHECK: fir.call
+  call f3()
+  ! CHECK: omp.terminator
+  !$omp end teams
+end subroutine teams_if
+
+!===============================================================================
+! `thread_limit` clause
+!===============================================================================
+! A literal and a dummy argument are both lowered to a thread_limit operand.
+! CHECK-LABEL: func @_QPteams_threadlimit
+subroutine teams_threadlimit(thread_limit)
+  integer, intent(inout) :: thread_limit
+
+  ! CHECK: omp.teams
+  ! CHECK-SAME: thread_limit(%{{.*}}: i32)
+  !$omp teams thread_limit(4)
+  ! CHECK: fir.call
+  call f1()
+  ! CHECK: omp.terminator
+  !$omp end teams
+
+  ! CHECK: omp.teams
+  ! CHECK-SAME: thread_limit(%{{.*}}: i32)
+  !$omp teams thread_limit(thread_limit)
+  ! CHECK: fir.call
+  call f2()
+  ! CHECK: omp.terminator
+  !$omp end teams
+
+end subroutine teams_threadlimit
+
+!===============================================================================
+! `allocate` clause
+!===============================================================================
+! The allocator and the allocated list item are both passed as operands.
+! CHECK-LABEL: func @_QPteams_allocate
+subroutine teams_allocate()
+   use omp_lib
+   integer :: x
+   ! CHECK: omp.teams
+   ! CHECK-SAME: allocate(%{{.+}} : i32 -> %{{.+}} : !fir.ref<i32>)
+   !$omp teams allocate(omp_high_bw_mem_alloc: x) private(x)
+   ! CHECK: arith.addi
+   x = x + 12
+   ! CHECK: omp.terminator
+   !$omp end teams
+end subroutine teams_allocate


        


More information about the flang-commits mailing list