[flang-commits] [flang] 0488ecb - [MLIR][OpenMP] Add Lowering support for OpenMP Target Data, Exit Data and Enter Data directives

Akash Banerjee via flang-commits flang-commits at lists.llvm.org
Thu Jan 26 03:00:06 PST 2023


Author: Akash Banerjee
Date: 2023-01-26T10:59:55Z
New Revision: 0488ecb5696d93c3d149bab45268992a4744e668

URL: https://github.com/llvm/llvm-project/commit/0488ecb5696d93c3d149bab45268992a4744e668
DIFF: https://github.com/llvm/llvm-project/commit/0488ecb5696d93c3d149bab45268992a4744e668.diff

LOG: [MLIR][OpenMP] Add Lowering support for OpenMP Target Data, Exit Data and Enter Data directives

This patch adds Fortran lowering support for the OpenMP Target Data, Target Exit Data and Target Enter Data constructs, emitting the corresponding OpenMP dialect operations.

Differential Revision: https://reviews.llvm.org/D142357

Added: 
    flang/test/Lower/OpenMP/target_data.f90

Modified: 
    flang/lib/Lower/OpenMP.cpp

Removed: 
    


################################################################################
diff  --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index 8e13c24ff9364..4bb734b90d99d 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -404,6 +404,19 @@ static void genObjectList(const Fortran::parser::OmpObjectList &objectList,
   }
 }
 
+/// Lowers the condition of an OpenMP `if` clause and converts it to i1.
+static mlir::Value
+getIfClauseOperand(Fortran::lower::AbstractConverter &converter,
+                   Fortran::lower::StatementContext &stmtCtx,
+                   const Fortran::parser::OmpClause::If *ifClause) {
+  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+  const mlir::Location loc = converter.getCurrentLocation();
+  const auto *condition = Fortran::semantics::GetExpr(
+      std::get<Fortran::parser::ScalarLogicalExpr>(ifClause->v.t));
+  auto lowered = converter.genExprValue(*condition, stmtCtx);
+  return builder.createConvert(loc, builder.getI1Type(), fir::getBase(lowered));
+}
+
 static mlir::Type getLoopVarType(Fortran::lower::AbstractConverter &converter,
                                  std::size_t loopVarTypeSize) {
   // OpenMP runtime requires 32-bit or 64-bit loop variables.
@@ -547,6 +560,149 @@ createBodyOfOp(Op &op, Fortran::lower::AbstractConverter &converter,
   }
 }
 
+/// Lowers the clauses of a `target data`, `target enter data` or
+/// `target exit data` directive and creates the matching OpenMP dialect
+/// operation (omp::DataOp, omp::EnterDataOp or omp::ExitDataOp).
+///
+/// Handled clauses: if, device, nowait and map. use_device_ptr,
+/// use_device_addr, the `ancestor` device modifier and any other clause
+/// currently stop lowering with a TODO diagnostic.
+static void
+createTargetDataOp(Fortran::lower::AbstractConverter &converter,
+                   const Fortran::parser::OmpClauseList &opClauseList,
+                   const llvm::omp::Directive &directive) {
+  Fortran::lower::StatementContext stmtCtx;
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+
+  mlir::Value ifClauseOperand, deviceOperand;
+  mlir::UnitAttr nowaitAttr;
+  llvm::SmallVector<mlir::Value> useDevicePtrOperand, useDeviceAddrOperand,
+      mapOperands;
+  llvm::SmallVector<mlir::IntegerAttr> mapTypes;
+
+  // Translates one map clause into OpenMPOffloadMappingFlags and appends each
+  // object it lists to mapOperands, with the flags appended to mapTypes.
+  auto addMapClause = [&firOpBuilder, &converter, &mapOperands,
+                       &mapTypes](const auto &mapClause) {
+    const auto &optMapType =
+        std::get<std::optional<Fortran::parser::OmpMapType>>(mapClause->v.t);
+    llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
+        llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
+    if (!optMapType) {
+      // No map-type was written (e.g. `map(a)`): OpenMP defaults it to tofrom.
+      // The optional is empty here, so it must not be dereferenced.
+      mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
+                     llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
+    } else {
+      switch (std::get<Fortran::parser::OmpMapType::Type>(optMapType->t)) {
+      case Fortran::parser::OmpMapType::Type::To:
+        mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
+        break;
+      case Fortran::parser::OmpMapType::Type::From:
+        mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
+        break;
+      case Fortran::parser::OmpMapType::Type::Tofrom:
+        mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
+                       llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
+        break;
+      case Fortran::parser::OmpMapType::Type::Alloc:
+      case Fortran::parser::OmpMapType::Type::Release:
+        // alloc and release set no type bits: which of the two applies is
+        // implied by the directive (alloc for target data and target enter
+        // data, release for target exit data).
+        break;
+      case Fortran::parser::OmpMapType::Type::Delete:
+        mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
+        break;
+      }
+      if (std::get<std::optional<Fortran::parser::OmpMapType::Always>>(
+              optMapType->t)
+              .has_value())
+        mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
+    }
+
+    // TODO: Add support for the close, mapper, present and iterator
+    // map-type-modifiers.
+
+    mlir::IntegerAttr mapTypeAttr = firOpBuilder.getIntegerAttr(
+        firOpBuilder.getI64Type(),
+        static_cast<
+            std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+            mapTypeBits));
+
+    llvm::SmallVector<mlir::Value> mapOperand;
+    genObjectList(std::get<Fortran::parser::OmpObjectList>(mapClause->v.t),
+                  converter, mapOperand);
+
+    // Every object named by one clause shares that clause's map type.
+    for (mlir::Value mapOp : mapOperand) {
+      mapOperands.push_back(mapOp);
+      mapTypes.push_back(mapTypeAttr);
+    }
+  };
+
+  for (const Fortran::parser::OmpClause &clause : opClauseList.v) {
+    mlir::Location currentLocation = converter.genLocation(clause.source);
+    if (const auto &ifClause =
+            std::get_if<Fortran::parser::OmpClause::If>(&clause.u)) {
+      ifClauseOperand = getIfClauseOperand(converter, stmtCtx, ifClause);
+    } else if (const auto &deviceClause =
+                   std::get_if<Fortran::parser::OmpClause::Device>(&clause.u)) {
+      if (auto deviceModifier = std::get<
+              std::optional<Fortran::parser::OmpDeviceClause::DeviceModifier>>(
+              deviceClause->v.t)) {
+        if (deviceModifier ==
+            Fortran::parser::OmpDeviceClause::DeviceModifier::Ancestor) {
+          TODO(currentLocation, "OMPD_target Device Modifier Ancestor");
+        }
+      }
+      if (const auto *deviceExpr = Fortran::semantics::GetExpr(
+              std::get<Fortran::parser::ScalarIntExpr>(deviceClause->v.t))) {
+        deviceOperand =
+            fir::getBase(converter.genExprValue(*deviceExpr, stmtCtx));
+      }
+    } else if (std::get_if<Fortran::parser::OmpClause::UseDevicePtr>(
+                   &clause.u)) {
+      TODO(currentLocation, "OMPD_target Use Device Ptr");
+    } else if (std::get_if<Fortran::parser::OmpClause::UseDeviceAddr>(
+                   &clause.u)) {
+      TODO(currentLocation, "OMPD_target Use Device Addr");
+    } else if (std::get_if<Fortran::parser::OmpClause::Nowait>(&clause.u)) {
+      nowaitAttr = firOpBuilder.getUnitAttr();
+    } else if (const auto &mapClause =
+                   std::get_if<Fortran::parser::OmpClause::Map>(&clause.u)) {
+      addMapClause(mapClause);
+    } else {
+      TODO(currentLocation, "OMPD_target unhandled clause");
+    }
+  }
+
+  // The data ops take the per-operand map types as one ArrayAttr.
+  llvm::SmallVector<mlir::Attribute> mapTypesAttr(mapTypes.begin(),
+                                                  mapTypes.end());
+  mlir::ArrayAttr mapTypesArrayAttr =
+      mlir::ArrayAttr::get(firOpBuilder.getContext(), mapTypesAttr);
+  mlir::Location currentLocation = converter.getCurrentLocation();
+
+  if (directive == llvm::omp::Directive::OMPD_target_data) {
+    // NOTE(review): no body is lowered into this op even though `target data`
+    // is a block construct; confirm where its region is populated.
+    firOpBuilder.create<omp::DataOp>(
+        currentLocation, ifClauseOperand, deviceOperand, useDevicePtrOperand,
+        useDeviceAddrOperand, mapOperands, mapTypesArrayAttr);
+  } else if (directive == llvm::omp::Directive::OMPD_target_enter_data) {
+    firOpBuilder.create<omp::EnterDataOp>(currentLocation, ifClauseOperand,
+                                          deviceOperand, nowaitAttr,
+                                          mapOperands, mapTypesArrayAttr);
+  } else if (directive == llvm::omp::Directive::OMPD_target_exit_data) {
+    firOpBuilder.create<omp::ExitDataOp>(currentLocation, ifClauseOperand,
+                                         deviceOperand, nowaitAttr, mapOperands,
+                                         mapTypesArrayAttr);
+  } else {
+    TODO(currentLocation, "OMPD_target directive unknown");
+  }
+}
+
 static void genOMP(Fortran::lower::AbstractConverter &converter,
                    Fortran::lower::pft::Evaluation &eval,
                    const Fortran::parser::OpenMPSimpleStandaloneConstruct
@@ -554,25 +691,27 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
   const auto &directive =
       std::get<Fortran::parser::OmpSimpleStandaloneDirective>(
           simpleStandaloneConstruct.t);
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  const Fortran::parser::OmpClauseList &opClauseList =
+      std::get<Fortran::parser::OmpClauseList>(simpleStandaloneConstruct.t);
+
   switch (directive.v) {
   default:
     break;
   case llvm::omp::Directive::OMPD_barrier:
-    converter.getFirOpBuilder().create<mlir::omp::BarrierOp>(
-        converter.getCurrentLocation());
+    firOpBuilder.create<omp::BarrierOp>(converter.getCurrentLocation());
     break;
   case llvm::omp::Directive::OMPD_taskwait:
-    converter.getFirOpBuilder().create<mlir::omp::TaskwaitOp>(
-        converter.getCurrentLocation());
+    firOpBuilder.create<omp::TaskwaitOp>(converter.getCurrentLocation());
     break;
   case llvm::omp::Directive::OMPD_taskyield:
-    converter.getFirOpBuilder().create<mlir::omp::TaskyieldOp>(
-        converter.getCurrentLocation());
+    firOpBuilder.create<omp::TaskyieldOp>(converter.getCurrentLocation());
     break;
+  case llvm::omp::Directive::OMPD_target_data:
   case llvm::omp::Directive::OMPD_target_enter_data:
-    TODO(converter.getCurrentLocation(), "OMPD_target_enter_data");
   case llvm::omp::Directive::OMPD_target_exit_data:
-    TODO(converter.getCurrentLocation(), "OMPD_target_exit_data");
+    createTargetDataOp(converter, opClauseList, directive.v);
+    break;
   case llvm::omp::Directive::OMPD_target_update:
     TODO(converter.getCurrentLocation(), "OMPD_target_update");
   case llvm::omp::Directive::OMPD_ordered:
@@ -669,19 +808,6 @@ static omp::ClauseProcBindKindAttr genProcBindKindAttr(
   return omp::ClauseProcBindKindAttr::get(firOpBuilder.getContext(), pbKind);
 }
 
-static mlir::Value
-getIfClauseOperand(Fortran::lower::AbstractConverter &converter,
-                   Fortran::lower::StatementContext &stmtCtx,
-                   const Fortran::parser::OmpClause::If *ifClause) {
-  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-  mlir::Location currentLocation = converter.getCurrentLocation();
-  auto &expr = std::get<Fortran::parser::ScalarLogicalExpr>(ifClause->v.t);
-  mlir::Value ifVal = fir::getBase(
-      converter.genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx));
-  return firOpBuilder.createConvert(currentLocation, firOpBuilder.getI1Type(),
-                                    ifVal);
-}
-
 /* When parallel is used in a combined construct, then use this function to
  * create the parallel operation. It handles the parallel specific clauses
  * and leaves the rest for handling at the inner operations.

diff  --git a/flang/test/Lower/OpenMP/target_data.f90 b/flang/test/Lower/OpenMP/target_data.f90
new file mode 100644
index 0000000000000..d3994bde7d82f
--- /dev/null
+++ b/flang/test/Lower/OpenMP/target_data.f90
@@ -0,0 +1,105 @@
+!RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s
+
+!===============================================================================
+! Target_Enter Simple
+!===============================================================================
+
+
+!CHECK-LABEL: func.func @_QPomp_target_enter_simple() {
+subroutine omp_target_enter_simple
+   integer :: a(1024)
+   !CHECK: omp.target_enter_data   map((to -> {{.*}} : !fir.ref<!fir.array<1024xi32>>))
+   !$omp target enter data map(to: a)
+end subroutine omp_target_enter_simple
+
+!===============================================================================
+! Target_Enter Map types
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_enter_mt() {
+subroutine omp_target_enter_mt
+   integer :: a(1024)
+   integer :: b(1024)
+   integer :: c(1024)
+   integer :: d(1024)
+   !CHECK: omp.target_enter_data   map((to -> {{.*}} : !fir.ref<!fir.array<1024xi32>>), (to -> {{.*}} : !fir.ref<!fir.array<1024xi32>>), (always, alloc -> {{.*}} : !fir.ref<!fir.array<1024xi32>>), (to -> {{.*}} : !fir.ref<!fir.array<1024xi32>>))
+   !$omp target enter data map(to: a, b) map(always, alloc: c) map(to: d)
+end subroutine omp_target_enter_mt
+
+!===============================================================================
+! `Nowait` clause
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_enter_nowait() {
+subroutine omp_target_enter_nowait
+   integer :: a(1024)
+   !CHECK: omp.target_enter_data   nowait map((to -> {{.*}} : !fir.ref<!fir.array<1024xi32>>))
+   !$omp target enter data map(to: a) nowait
+end subroutine omp_target_enter_nowait
+
+!===============================================================================
+! `if` clause
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_enter_if() {
+subroutine omp_target_enter_if
+   integer :: a(1024)
+   integer :: i
+   i = 5
+   !CHECK: %[[VAL_3:.*]] = fir.load %[[VAL_1:.*]] : !fir.ref<i32>
+   !CHECK: %[[VAL_4:.*]] = arith.constant 10 : i32
+   !CHECK: %[[VAL_5:.*]] = arith.cmpi slt, %[[VAL_3]], %[[VAL_4]] : i32
+   !CHECK: omp.target_enter_data   if(%[[VAL_5]] : i1) map((to -> {{.*}} : !fir.ref<!fir.array<1024xi32>>))
+   !$omp target enter data if(i<10) map(to: a)
+end subroutine omp_target_enter_if
+
+!===============================================================================
+! Target_Enter `device` clause
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_enter_device() {
+subroutine omp_target_enter_device
+   integer :: a(1024)
+   !CHECK: %[[VAL_1:.*]] = arith.constant 2 : i32
+   !CHECK: omp.target_enter_data   device(%[[VAL_1]] : i32) map((to -> {{.*}} : !fir.ref<!fir.array<1024xi32>>))
+   !$omp target enter data map(to: a) device(2)
+end subroutine omp_target_enter_device
+
+!===============================================================================
+! Target_Exit Simple
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_exit_simple() {
+subroutine omp_target_exit_simple
+   integer :: a(1024)
+   !CHECK: omp.target_exit_data   map((from -> {{.*}} : !fir.ref<!fir.array<1024xi32>>))
+   !$omp target exit data map(from: a)
+end subroutine omp_target_exit_simple
+
+!===============================================================================
+! Target_Exit Map types
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_exit_mt() {
+subroutine omp_target_exit_mt
+   integer :: a(1024)
+   integer :: b(1024)
+   integer :: c(1024)
+   integer :: d(1024)
+   integer :: e(1024)
+   !CHECK: omp.target_exit_data   map((from -> {{.*}} : !fir.ref<!fir.array<1024xi32>>), (from -> {{.*}} : !fir.ref<!fir.array<1024xi32>>), (release -> {{.*}} : !fir.ref<!fir.array<1024xi32>>), (always, delete -> {{.*}} : !fir.ref<!fir.array<1024xi32>>), (from -> {{.*}} : !fir.ref<!fir.array<1024xi32>>))
+   !$omp target exit data map(from: a,b) map(release: c) map(always, delete: d) map(from: e)
+end subroutine omp_target_exit_mt
+
+!===============================================================================
+! Target_Exit `device` clause
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_exit_device() {
+subroutine omp_target_exit_device
+   integer :: a(1024)
+   integer :: d
+   !CHECK: %[[VAL_2:.*]] = fir.load %[[VAL_1:.*]] : !fir.ref<i32>
+   !CHECK: omp.target_exit_data   device(%[[VAL_2]] : i32) map((from -> {{.*}} : !fir.ref<!fir.array<1024xi32>>))
+   !$omp target exit data map(from: a) device(d)
+end subroutine omp_target_exit_device


        


More information about the flang-commits mailing list