[flang-commits] [flang] fa4b1e1 - [flang][OpenMP] Added allocate clause translation for OpenMP block constructs

Shraiysh Vaishay via flang-commits flang-commits at lists.llvm.org
Fri Apr 8 07:31:42 PDT 2022


Author: Shraiysh Vaishay
Date: 2022-04-08T20:01:22+05:30
New Revision: fa4b1e1e95d0dda8476aead84b9f4bb4bb416e49

URL: https://github.com/llvm/llvm-project/commit/fa4b1e1e95d0dda8476aead84b9f4bb4bb416e49
DIFF: https://github.com/llvm/llvm-project/commit/fa4b1e1e95d0dda8476aead84b9f4bb4bb416e49.diff

LOG: [flang][OpenMP] Added allocate clause translation for OpenMP block constructs

This patch adds translation of the allocate clause for the parallel and single
constructs.

It also adds tests for block constructs.

It also adds tests for the parallel construct, which were missing earlier.

Reviewed By: NimishMishra, peixin

Differential Revision: https://reviews.llvm.org/D122483

Co-authored-by: Sourabh Singh Tomar <SourabhSingh.Tomar at amd.com>

Added: 
    flang/test/Lower/OpenMP/parallel.f90

Modified: 
    flang/include/flang/Parser/parse-tree.h
    flang/lib/Lower/OpenMP.cpp
    flang/test/Lower/OpenMP/single.f90

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Parser/parse-tree.h b/flang/include/flang/Parser/parse-tree.h
index bc15b3c84606e..23c572c2a6b73 100644
--- a/flang/include/flang/Parser/parse-tree.h
+++ b/flang/include/flang/Parser/parse-tree.h
@@ -3308,7 +3308,7 @@ WRAPPER_CLASS(PauseStmt, std::optional<StopCode>);
 
 // 2.5 proc-bind-clause -> PROC_BIND (MASTER | CLOSE | SPREAD)
 struct OmpProcBindClause {
-  ENUM_CLASS(Type, Close, Master, Spread)
+  ENUM_CLASS(Type, Close, Master, Spread, Primary)
   WRAPPER_CLASS_BOILERPLATE(OmpProcBindClause, Type);
 };
 

diff  --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index 4ecbb7d37f914..235159d8fdd7b 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -179,75 +179,79 @@ genOMP(Fortran::lower::AbstractConverter &converter,
   auto currentLocation = converter.getCurrentLocation();
   Fortran::lower::StatementContext stmtCtx;
   llvm::ArrayRef<mlir::Type> argTy;
-  if (blockDirective.v == llvm::omp::OMPD_parallel) {
-
-    mlir::Value ifClauseOperand, numThreadsClauseOperand;
-    Attribute procBindClauseOperand;
+  mlir::Value ifClauseOperand, numThreadsClauseOperand;
+  mlir::omp::ClauseProcBindKindAttr procBindKindAttr;
+  SmallVector<Value> allocateOperands, allocatorOperands;
+  mlir::UnitAttr nowaitAttr;
 
-    const auto &parallelOpClauseList =
-        std::get<Fortran::parser::OmpClauseList>(beginBlockDirective.t);
-    for (const auto &clause : parallelOpClauseList.v) {
-      if (const auto &ifClause =
-              std::get_if<Fortran::parser::OmpClause::If>(&clause.u)) {
-        auto &expr =
-            std::get<Fortran::parser::ScalarLogicalExpr>(ifClause->v.t);
-        ifClauseOperand = fir::getBase(converter.genExprValue(
-            *Fortran::semantics::GetExpr(expr), stmtCtx));
-      } else if (const auto &numThreadsClause =
-                     std::get_if<Fortran::parser::OmpClause::NumThreads>(
-                         &clause.u)) {
-        // OMPIRBuilder expects `NUM_THREAD` clause as a `Value`.
-        numThreadsClauseOperand = fir::getBase(converter.genExprValue(
-            *Fortran::semantics::GetExpr(numThreadsClause->v), stmtCtx));
+  for (const auto &clause :
+       std::get<Fortran::parser::OmpClauseList>(beginBlockDirective.t).v) {
+    if (const auto &ifClause =
+            std::get_if<Fortran::parser::OmpClause::If>(&clause.u)) {
+      auto &expr = std::get<Fortran::parser::ScalarLogicalExpr>(ifClause->v.t);
+      ifClauseOperand = fir::getBase(
+          converter.genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx));
+    } else if (const auto &numThreadsClause =
+                   std::get_if<Fortran::parser::OmpClause::NumThreads>(
+                       &clause.u)) {
+      // OMPIRBuilder expects `NUM_THREAD` clause as a `Value`.
+      numThreadsClauseOperand = fir::getBase(converter.genExprValue(
+          *Fortran::semantics::GetExpr(numThreadsClause->v), stmtCtx));
+    } else if (const auto &procBindClause =
+                   std::get_if<Fortran::parser::OmpClause::ProcBind>(
+                       &clause.u)) {
+      omp::ClauseProcBindKind pbKind;
+      switch (procBindClause->v.v) {
+      case Fortran::parser::OmpProcBindClause::Type::Master:
+        pbKind = omp::ClauseProcBindKind::Master;
+        break;
+      case Fortran::parser::OmpProcBindClause::Type::Close:
+        pbKind = omp::ClauseProcBindKind::Close;
+        break;
+      case Fortran::parser::OmpProcBindClause::Type::Spread:
+        pbKind = omp::ClauseProcBindKind::Spread;
+        break;
+      case Fortran::parser::OmpProcBindClause::Type::Primary:
+        pbKind = omp::ClauseProcBindKind::Primary;
+        break;
       }
-      // TODO: Handle private, firstprivate, shared and copyin
+      procBindKindAttr =
+          omp::ClauseProcBindKindAttr::get(firOpBuilder.getContext(), pbKind);
+    } else if (const auto &allocateClause =
+                   std::get_if<Fortran::parser::OmpClause::Allocate>(
+                       &clause.u)) {
+      genAllocateClause(converter, allocateClause->v, allocatorOperands,
+                        allocateOperands);
+    } else if (const auto &privateClause =
+                   std::get_if<Fortran::parser::OmpClause::Private>(
+                       &clause.u)) {
+      // TODO: Handle private. This cannot be a hard TODO because testing for
+      // allocate clause requires private variables.
+    } else {
+      TODO(currentLocation, "OpenMP Block construct clauses");
     }
+  }
+
+  for (const auto &clause :
+       std::get<Fortran::parser::OmpClauseList>(endBlockDirective.t).v) {
+    if (std::get_if<Fortran::parser::OmpClause::Nowait>(&clause.u))
+      nowaitAttr = firOpBuilder.getUnitAttr();
+  }
+
+  if (blockDirective.v == llvm::omp::OMPD_parallel) {
     // Create and insert the operation.
     auto parallelOp = firOpBuilder.create<mlir::omp::ParallelOp>(
         currentLocation, argTy, ifClauseOperand, numThreadsClauseOperand,
-        /*allocate_vars=*/ValueRange(), /*allocators_vars=*/ValueRange(),
-        /*reduction_vars=*/ValueRange(), /*reductions=*/nullptr,
-        procBindClauseOperand.dyn_cast_or_null<omp::ClauseProcBindKindAttr>());
-    // Handle attribute based clauses.
-    for (const auto &clause : parallelOpClauseList.v) {
-      // TODO: Handle default clause
-      if (const auto &procBindClause =
-              std::get_if<Fortran::parser::OmpClause::ProcBind>(&clause.u)) {
-        const auto &ompProcBindClause{procBindClause->v};
-        omp::ClauseProcBindKind pbKind;
-        switch (ompProcBindClause.v) {
-        case Fortran::parser::OmpProcBindClause::Type::Master:
-          pbKind = omp::ClauseProcBindKind::Master;
-          break;
-        case Fortran::parser::OmpProcBindClause::Type::Close:
-          pbKind = omp::ClauseProcBindKind::Close;
-          break;
-        case Fortran::parser::OmpProcBindClause::Type::Spread:
-          pbKind = omp::ClauseProcBindKind::Spread;
-          break;
-        }
-        parallelOp.proc_bind_valAttr(omp::ClauseProcBindKindAttr::get(
-            firOpBuilder.getContext(), pbKind));
-      }
-    }
+        allocateOperands, allocatorOperands, /*reduction_vars=*/ValueRange(),
+        /*reductions=*/nullptr, procBindKindAttr);
     createBodyOfOp<omp::ParallelOp>(parallelOp, firOpBuilder, currentLocation);
   } else if (blockDirective.v == llvm::omp::OMPD_master) {
     auto masterOp =
         firOpBuilder.create<mlir::omp::MasterOp>(currentLocation, argTy);
     createBodyOfOp<omp::MasterOp>(masterOp, firOpBuilder, currentLocation);
-
-    // Single Construct
   } else if (blockDirective.v == llvm::omp::OMPD_single) {
-    mlir::UnitAttr nowaitAttr;
-    for (const auto &clause :
-         std::get<Fortran::parser::OmpClauseList>(endBlockDirective.t).v) {
-      if (std::get_if<Fortran::parser::OmpClause::Nowait>(&clause.u))
-        nowaitAttr = firOpBuilder.getUnitAttr();
-      // TODO: Handle allocate clause (D122302)
-    }
     auto singleOp = firOpBuilder.create<mlir::omp::SingleOp>(
-        currentLocation, /*allocate_vars=*/ValueRange(),
-        /*allocators_vars=*/ValueRange(), nowaitAttr);
+        currentLocation, allocateOperands, allocatorOperands, nowaitAttr);
     createBodyOfOp(singleOp, firOpBuilder, currentLocation);
   }
 }

diff  --git a/flang/test/Lower/OpenMP/parallel.f90 b/flang/test/Lower/OpenMP/parallel.f90
new file mode 100644
index 0000000000000..849db5c3705f6
--- /dev/null
+++ b/flang/test/Lower/OpenMP/parallel.f90
@@ -0,0 +1,163 @@
+!RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s --check-prefixes="FIRDialect,OMPDialect"
+!RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | fir-opt --fir-to-llvm-ir | FileCheck %s --check-prefixes="OMPDialect"
+
+!FIRDialect-LABEL: func @_QPparallel_simple
+subroutine parallel_simple()
+   !OMPDialect: omp.parallel
+!$omp parallel
+   !FIRDialect: fir.call
+   call f1()
+!$omp end parallel
+end subroutine parallel_simple
+
+!===============================================================================
+! `if` clause
+!===============================================================================
+
+!FIRDialect-LABEL: func @_QPparallel_if
+subroutine parallel_if(alpha)
+   integer, intent(in) :: alpha
+
+   !OMPDialect: omp.parallel if(%{{.*}} : i1) {
+   !$omp parallel if(alpha .le. 0)
+   !FIRDialect: fir.call
+   call f1()
+   !OMPDialect: omp.terminator
+   !$omp end parallel
+
+   !OMPDialect: omp.parallel if(%{{.*}} : i1) {
+   !$omp parallel if(.false.)
+   !FIRDialect: fir.call
+   call f2()
+   !OMPDialect: omp.terminator
+   !$omp end parallel
+
+   !OMPDialect: omp.parallel if(%{{.*}} : i1) {
+   !$omp parallel if(alpha .ge. 0)
+   !FIRDialect: fir.call
+   call f3()
+   !OMPDialect: omp.terminator
+   !$omp end parallel
+
+   !OMPDialect: omp.parallel if(%{{.*}} : i1) {
+   !$omp parallel if(.true.)
+   !FIRDialect: fir.call
+   call f4()
+   !OMPDialect: omp.terminator
+   !$omp end parallel
+
+end subroutine parallel_if
+
+!===============================================================================
+! `num_threads` clause
+!===============================================================================
+
+!FIRDialect-LABEL: func @_QPparallel_numthreads
+subroutine parallel_numthreads(num_threads)
+   integer, intent(inout) :: num_threads
+
+   !OMPDialect: omp.parallel num_threads(%{{.*}} : i32) {
+   !$omp parallel num_threads(16)
+   !FIRDialect: fir.call
+   call f1()
+   !OMPDialect: omp.terminator
+   !$omp end parallel
+
+   num_threads = 4
+
+   !OMPDialect: omp.parallel num_threads(%{{.*}} : i32) {
+   !$omp parallel num_threads(num_threads)
+   !FIRDialect: fir.call
+   call f2()
+   !OMPDialect: omp.terminator
+   !$omp end parallel
+
+end subroutine parallel_numthreads
+
+!===============================================================================
+! `proc_bind` clause
+!===============================================================================
+
+!FIRDialect-LABEL: func @_QPparallel_proc_bind
+subroutine parallel_proc_bind()
+
+   !OMPDialect: omp.parallel proc_bind(master) {
+   !$omp parallel proc_bind(master)
+   !FIRDialect: fir.call
+   call f1()
+   !OMPDialect: omp.terminator
+   !$omp end parallel
+
+   !OMPDialect: omp.parallel proc_bind(close) {
+   !$omp parallel proc_bind(close)
+   !FIRDialect: fir.call
+   call f2()
+   !OMPDialect: omp.terminator
+   !$omp end parallel
+
+   !OMPDialect: omp.parallel proc_bind(spread) {
+   !$omp parallel proc_bind(spread)
+   !FIRDialect: fir.call
+   call f3()
+   !OMPDialect: omp.terminator
+   !$omp end parallel
+
+end subroutine parallel_proc_bind
+
+!===============================================================================
+! `allocate` clause
+!===============================================================================
+
+!FIRDialect-LABEL: func @_QPparallel_allocate
+subroutine parallel_allocate()
+   use omp_lib
+   integer :: x
+   !OMPDialect: omp.parallel allocate(%{{.+}} : i32 -> %{{.+}} : !fir.ref<i32>) {
+   !$omp parallel allocate(omp_high_bw_mem_alloc: x) private(x)
+   !FIRDialect: arith.addi
+   x = x + 12
+   !OMPDialect: omp.terminator
+   !$omp end parallel
+end subroutine parallel_allocate
+
+!===============================================================================
+! multiple clauses
+!===============================================================================
+
+!FIRDialect-LABEL: func @_QPparallel_multiple_clauses
+subroutine parallel_multiple_clauses(alpha, num_threads)
+   use omp_lib
+   integer, intent(inout) :: alpha
+   integer, intent(in) :: num_threads
+
+   !OMPDialect: omp.parallel if(%{{.*}} : i1) proc_bind(master) {
+   !$omp parallel if(alpha .le. 0) proc_bind(master)
+   !FIRDialect: fir.call
+   call f1()
+   !OMPDialect: omp.terminator
+   !$omp end parallel
+
+   !OMPDialect: omp.parallel num_threads(%{{.*}} : i32) proc_bind(close) {
+   !$omp parallel proc_bind(close) num_threads(num_threads)
+   !FIRDialect: fir.call
+   call f2()
+   !OMPDialect: omp.terminator
+   !$omp end parallel
+
+   !OMPDialect: omp.parallel if(%{{.*}} : i1) num_threads(%{{.*}} : i32) {
+   !$omp parallel num_threads(num_threads) if(alpha .le. 0)
+   !FIRDialect: fir.call
+   call f3()
+   !OMPDialect: omp.terminator
+   !$omp end parallel
+
+   !OMPDialect: omp.parallel if(%{{.*}} : i1) num_threads(%{{.*}} : i32) allocate(%{{.+}} : i32 -> %{{.+}} : !fir.ref<i32>) {
+   !$omp parallel num_threads(num_threads) if(alpha .le. 0) allocate(omp_high_bw_mem_alloc: alpha) private(alpha)
+   !FIRDialect: fir.call
+   call f3()
+   !FIRDialect: arith.addi
+   alpha = alpha + 12
+   !OMPDialect: omp.terminator
+   !$omp end parallel
+
+end subroutine parallel_multiple_clauses

diff  --git a/flang/test/Lower/OpenMP/single.f90 b/flang/test/Lower/OpenMP/single.f90
index 5d00bdd0897c7..e159dcf73d9e2 100644
--- a/flang/test/Lower/OpenMP/single.f90
+++ b/flang/test/Lower/OpenMP/single.f90
@@ -44,3 +44,23 @@ subroutine omp_single_nowait(x)
   !OMPDialect: omp.terminator
   !$omp end parallel
 end subroutine omp_single_nowait
+
+!===============================================================================
+! Single construct with allocate
+!===============================================================================
+
+!FIRDialect-LABEL: func @_QPsingle_allocate
+subroutine single_allocate()
+  use omp_lib
+  integer :: x
+  !OMPDialect: omp.parallel {
+  !$omp parallel
+  !OMPDialect: omp.single allocate(%{{.+}} : i32 -> %{{.+}} : !fir.ref<i32>) {
+  !$omp single allocate(omp_high_bw_mem_alloc: x) private(x)
+  !FIRDialect: arith.addi
+  x = x + 12
+  !OMPDialect: omp.terminator
+  !$omp end single
+  !OMPDialect: omp.terminator
+  !$omp end parallel
+end subroutine single_allocate


        


More information about the flang-commits mailing list