[flang-commits] [flang] [flang][MLIR][OpenMP] Emit `UpdateDataOp` from `!$omp target update` (PR #75345)

Kareem Ergawy via flang-commits flang-commits at lists.llvm.org
Thu Dec 21 11:49:29 PST 2023


https://github.com/ergawy updated https://github.com/llvm/llvm-project/pull/75345

>From a187a293d8160fb584f1ce97ad8cb61b4d425080 Mon Sep 17 00:00:00 2001
From: ergawy <kareem.ergawy at amd.com>
Date: Wed, 13 Dec 2023 05:09:08 -0600
Subject: [PATCH] [flang][MLIR][OpenMP] Emit `UpdateDataOp` from `!$omp target
 update`

Emits the MLIR op corresponding to the `!$omp target update` directive.
So far, only the `to` and `from` motion types are supported; the motion
modifiers `present`, `mapper`, and `iterator` are not supported yet.
---
 flang/lib/Lower/OpenMP.cpp             | 109 ++++++++++++++++++++++---
 flang/test/Lower/OpenMP/FIR/target.f90 |  94 +++++++++++++++++++++
 flang/test/Lower/OpenMP/target.f90     |  88 ++++++++++++++++++++
 3 files changed, 280 insertions(+), 11 deletions(-)

diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index 9213cff95d3f11..754e38d2ee0245 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -606,6 +606,16 @@ class ClauseProcessor {
                       llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
                           &useDeviceSymbols) const;
 
+  bool
+  processToMotionClauses(Fortran::semantics::SemanticsContext &semanticsContext,
+                         Fortran::lower::StatementContext &stmtCtx,
+                         llvm::SmallVectorImpl<mlir::Value> &mapOperands);
+
+  bool processFromMotionClauses(
+      Fortran::semantics::SemanticsContext &semanticsContext,
+      Fortran::lower::StatementContext &stmtCtx,
+      llvm::SmallVectorImpl<mlir::Value> &mapOperands);
+
   // Call this method for these clauses that should be supported but are not
   // implemented yet. It triggers a compilation error if any of the given
   // clauses is found.
@@ -672,6 +682,12 @@ class ClauseProcessor {
     return false;
   }
 
+  template <typename T>
+  bool
+  processMotionClauses(Fortran::semantics::SemanticsContext &semanticsContext,
+                       Fortran::lower::StatementContext &stmtCtx,
+                       llvm::SmallVectorImpl<mlir::Value> &mapOperands);
+
   Fortran::lower::AbstractConverter &converter;
   const Fortran::parser::OmpClauseList &clauses;
 };
@@ -1892,6 +1908,63 @@ bool ClauseProcessor::processUseDevicePtr(
       });
 }
 
+bool ClauseProcessor::processToMotionClauses(
+    Fortran::semantics::SemanticsContext &semanticsContext,
+    Fortran::lower::StatementContext &stmtCtx,
+    llvm::SmallVectorImpl<mlir::Value> &mapOperands) {
+  return processMotionClauses<ClauseProcessor::ClauseTy::To>(
+      semanticsContext, stmtCtx, mapOperands);
+}
+
+bool ClauseProcessor::processFromMotionClauses(
+    Fortran::semantics::SemanticsContext &semanticsContext,
+    Fortran::lower::StatementContext &stmtCtx,
+    llvm::SmallVectorImpl<mlir::Value> &mapOperands) {
+  return processMotionClauses<ClauseProcessor::ClauseTy::From>(
+      semanticsContext, stmtCtx, mapOperands);
+}
+
+template <typename T>
+bool ClauseProcessor::processMotionClauses(
+    Fortran::semantics::SemanticsContext &semanticsContext,
+    Fortran::lower::StatementContext &stmtCtx,
+    llvm::SmallVectorImpl<mlir::Value> &mapOperands) {
+  return findRepeatableClause<T>(
+      [&](const T *motionClause, const Fortran::parser::CharBlock &source) {
+        mlir::Location clauseLocation = converter.genLocation(source);
+        fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+
+        static_assert(std::is_same_v<T, ClauseProcessor::ClauseTy::To> ||
+                      std::is_same_v<T, ClauseProcessor::ClauseTy::From>);
+
+        // TODO Support motion modifiers: present, mapper, iterator.
+        constexpr llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
+            std::is_same_v<T, ClauseProcessor::ClauseTy::To>
+                ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO
+                : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
+
+        for (const Fortran::parser::OmpObject &ompObject : motionClause->v.v) {
+          llvm::SmallVector<mlir::Value> bounds;
+          std::stringstream asFortran;
+          Fortran::lower::AddrAndBoundsInfo info =
+              Fortran::lower::gatherDataOperandAddrAndBounds<
+                  Fortran::parser::OmpObject, mlir::omp::DataBoundsOp,
+                  mlir::omp::DataBoundsType>(
+                  converter, firOpBuilder, semanticsContext, stmtCtx, ompObject,
+                  clauseLocation, asFortran, bounds, treatIndexAsSection);
+
+          mlir::Value mapOp = createMapInfoOp(
+              firOpBuilder, clauseLocation, info.addr, asFortran, bounds,
+              static_cast<
+                  std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+                  mapTypeBits),
+              mlir::omp::VariableCaptureKind::ByRef, info.addr.getType());
+
+          mapOperands.push_back(mapOp);
+        }
+      });
+}
+
 template <typename... Ts>
 void ClauseProcessor::processTODO(mlir::Location currentLocation,
                                   llvm::omp::Directive directive) const {
@@ -2416,10 +2489,10 @@ genDataOp(Fortran::lower::AbstractConverter &converter,
 
 template <typename OpTy>
 static OpTy
-genEnterExitDataOp(Fortran::lower::AbstractConverter &converter,
-                   Fortran::semantics::SemanticsContext &semanticsContext,
-                   mlir::Location currentLocation,
-                   const Fortran::parser::OmpClauseList &clauseList) {
+genEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter,
+                         Fortran::semantics::SemanticsContext &semanticsContext,
+                         mlir::Location currentLocation,
+                         const Fortran::parser::OmpClauseList &clauseList) {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
   Fortran::lower::StatementContext stmtCtx;
   mlir::Value ifClauseOperand, deviceOperand;
@@ -2436,6 +2509,10 @@ genEnterExitDataOp(Fortran::lower::AbstractConverter &converter,
     directiveName =
         Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetExitData;
     directive = llvm::omp::Directive::OMPD_target_exit_data;
+  } else if constexpr (std::is_same_v<OpTy, mlir::omp::UpdateDataOp>) {
+    directiveName =
+        Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetUpdate;
+    directive = llvm::omp::Directive::OMPD_target_update;
   } else {
     return nullptr;
   }
@@ -2444,8 +2521,16 @@ genEnterExitDataOp(Fortran::lower::AbstractConverter &converter,
   cp.processIf(directiveName, ifClauseOperand);
   cp.processDevice(stmtCtx, deviceOperand);
   cp.processNowait(nowaitAttr);
-  cp.processMap(currentLocation, directive, semanticsContext, stmtCtx,
-                mapOperands);
+
+  if constexpr (std::is_same_v<OpTy, mlir::omp::UpdateDataOp>) {
+    cp.processToMotionClauses(semanticsContext, stmtCtx, mapOperands);
+    cp.processFromMotionClauses(semanticsContext, stmtCtx, mapOperands);
+
+  } else {
+    cp.processMap(currentLocation, directive, semanticsContext, stmtCtx,
+                  mapOperands);
+  }
+
   cp.processTODO<Fortran::parser::OmpClause::Depend>(currentLocation,
                                                      directive);
 
@@ -2848,15 +2933,17 @@ genOmpSimpleStandalone(Fortran::lower::AbstractConverter &converter,
     genDataOp(converter, eval, semanticsContext, currentLocation, opClauseList);
     break;
   case llvm::omp::Directive::OMPD_target_enter_data:
-    genEnterExitDataOp<mlir::omp::EnterDataOp>(converter, semanticsContext,
-                                               currentLocation, opClauseList);
+    genEnterExitUpdateDataOp<mlir::omp::EnterDataOp>(
+        converter, semanticsContext, currentLocation, opClauseList);
     break;
   case llvm::omp::Directive::OMPD_target_exit_data:
-    genEnterExitDataOp<mlir::omp::ExitDataOp>(converter, semanticsContext,
-                                              currentLocation, opClauseList);
+    genEnterExitUpdateDataOp<mlir::omp::ExitDataOp>(
+        converter, semanticsContext, currentLocation, opClauseList);
     break;
   case llvm::omp::Directive::OMPD_target_update:
-    TODO(currentLocation, "OMPD_target_update");
+    genEnterExitUpdateDataOp<mlir::omp::UpdateDataOp>(
+        converter, semanticsContext, currentLocation, opClauseList);
+    break;
   case llvm::omp::Directive::OMPD_ordered:
     TODO(currentLocation, "OMPD_ordered");
   }
diff --git a/flang/test/Lower/OpenMP/FIR/target.f90 b/flang/test/Lower/OpenMP/FIR/target.f90
index 2034ac84334e54..5d36699bf0e902 100644
--- a/flang/test/Lower/OpenMP/FIR/target.f90
+++ b/flang/test/Lower/OpenMP/FIR/target.f90
@@ -133,6 +133,100 @@ subroutine omp_target_exit_device
    !$omp target exit data map(from: a) device(d)
 end subroutine omp_target_exit_device
 
+!===============================================================================
+! Target_Update `to` clause
+!===============================================================================
+
+subroutine omp_target_update_to
+   integer :: a(1024)
+
+   !CHECK-DAG: %[[A_ALLOC:.*]] = fir.alloca !fir.array<1024xi32> {bindc_name = "a", uniq_name = "_QFomp_target_update_toEa"}
+   !CHECK-DAG: %[[BOUNDS:.*]] = omp.bounds
+
+   !CHECK: %[[TO_MAP:.*]] = omp.map_info var_ptr(%[[A_ALLOC]] : !fir.ref<!fir.array<1024xi32>>, !fir.array<1024xi32>)
+   !CHECK-SAME: map_clauses(to) capture(ByRef)
+   !CHECK-SAME: bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
+
+   !CHECK: omp.target_update_data
+   !CHECK-SAME: motion_entries(%[[TO_MAP]] : !fir.ref<!fir.array<1024xi32>>)
+   !$omp target update to(a)
+end subroutine omp_target_update_to
+
+!===============================================================================
+! Target_Update `from` clause
+!===============================================================================
+
+subroutine omp_target_update_from
+   integer :: a(1024)
+
+   !CHECK-DAG: %[[A_ALLOC:.*]] = fir.alloca !fir.array<1024xi32> {bindc_name = "a", uniq_name = "_QFomp_target_update_fromEa"}
+   !CHECK-DAG: %[[BOUNDS:.*]] = omp.bounds
+
+   !CHECK: %[[FROM_MAP:.*]] = omp.map_info var_ptr(%[[A_ALLOC]] : !fir.ref<!fir.array<1024xi32>>, !fir.array<1024xi32>)
+   !CHECK-SAME: map_clauses(from) capture(ByRef)
+   !CHECK-SAME: bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
+
+   !CHECK: omp.target_update_data
+   !CHECK-SAME: motion_entries(%[[FROM_MAP]] : !fir.ref<!fir.array<1024xi32>>)
+   !$omp target update from(a)
+end subroutine omp_target_update_from
+
+!===============================================================================
+! Target_Update `if` clause
+!===============================================================================
+
+subroutine omp_target_update_if
+   integer :: a(1024)
+   logical :: i
+
+   !CHECK-DAG: %[[A_ALLOC:.*]] = fir.alloca
+   !CHECK-DAG: %[[BOUNDS:.*]] = omp.bounds
+   !CHECK-DAG: %[[COND:.*]] = fir.convert %{{.*}} : (!fir.logical<4>) -> i1
+
+   !CHECK: %[[TO_MAP:.*]] = omp.map_info
+
+   !CHECK: omp.target_update_data if(%[[COND]] : i1)
+   !CHECK-SAME: motion_entries(%[[TO_MAP]] : !fir.ref<!fir.array<1024xi32>>)
+   !$omp target update to(a) if(i)
+end subroutine omp_target_update_if
+
+!===============================================================================
+! Target_Update `device` clause
+!===============================================================================
+
+subroutine omp_target_update_device
+   integer :: a(1024)
+
+   !CHECK-DAG: %[[A_ALLOC:.*]] = fir.alloca
+   !CHECK-DAG: %[[BOUNDS:.*]] = omp.bounds
+   !CHECK-DAG: %[[DEVICE:.*]] = arith.constant 1 : i32
+
+   !CHECK: %[[TO_MAP:.*]] = omp.map_info
+
+   !CHECK: omp.target_update_data
+   !CHECK-SAME: device(%[[DEVICE]] : i32)
+   !CHECK-SAME: motion_entries(%[[TO_MAP]] : !fir.ref<!fir.array<1024xi32>>)
+   !$omp target update to(a) device(1)
+end subroutine omp_target_update_device
+
+!===============================================================================
+! Target_Update `nowait` clause
+!===============================================================================
+
+subroutine omp_target_update_nowait
+   integer :: a(1024)
+
+   !CHECK-DAG: %[[A_ALLOC:.*]] = fir.alloca
+   !CHECK-DAG: %[[BOUNDS:.*]] = omp.bounds
+
+   !CHECK: %[[TO_MAP:.*]] = omp.map_info
+
+   !CHECK: omp.target_update_data
+   !CHECK-SAME: nowait
+   !CHECK-SAME: motion_entries(%[[TO_MAP]] : !fir.ref<!fir.array<1024xi32>>)
+   !$omp target update to(a) nowait
+end subroutine omp_target_update_nowait
+
 !===============================================================================
 ! Target_Data with region
 !===============================================================================
diff --git a/flang/test/Lower/OpenMP/target.f90 b/flang/test/Lower/OpenMP/target.f90
index 26b4d595d52291..5ca3a08d8a8b20 100644
--- a/flang/test/Lower/OpenMP/target.f90
+++ b/flang/test/Lower/OpenMP/target.f90
@@ -134,6 +134,94 @@ subroutine omp_target_exit_device
    !$omp target exit data map(from: a) device(d)
 end subroutine omp_target_exit_device
 
+!===============================================================================
+! Target_Update `to` clause
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_update_to() {
+subroutine omp_target_update_to
+   integer :: a(1024)
+
+   !CHECK-DAG: %[[A_DECL:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}})
+   !CHECK-DAG: %[[BOUNDS:.*]] = omp.bounds
+
+   !CHECK: %[[TO_MAP:.*]] = omp.map_info var_ptr(%[[A_DECL]]#1 : !fir.ref<!fir.array<1024xi32>>, !fir.array<1024xi32>)
+   !CHECK-SAME: map_clauses(to) capture(ByRef)
+   !CHECK-SAME: bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
+
+   !CHECK: omp.target_update_data motion_entries(%[[TO_MAP]] : !fir.ref<!fir.array<1024xi32>>)
+   !$omp target update to(a)
+end subroutine omp_target_update_to
+
+!===============================================================================
+! Target_Update `from` clause
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_update_from() {
+subroutine omp_target_update_from
+   integer :: a(1024)
+
+   !CHECK-DAG: %[[A_DECL:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}})
+   !CHECK-DAG: %[[BOUNDS:.*]] = omp.bounds
+
+   !CHECK: %[[FROM_MAP:.*]] = omp.map_info var_ptr(%[[A_DECL]]#1 : !fir.ref<!fir.array<1024xi32>>, !fir.array<1024xi32>)
+   !CHECK-SAME: map_clauses(from) capture(ByRef)
+   !CHECK-SAME: bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
+
+   !CHECK: omp.target_update_data motion_entries(%[[FROM_MAP]] : !fir.ref<!fir.array<1024xi32>>)
+   !$omp target update from(a)
+end subroutine omp_target_update_from
+
+!===============================================================================
+! Target_Update `if` clause
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_update_if() {
+subroutine omp_target_update_if
+   integer :: a(1024)
+   logical :: i
+
+   !CHECK-DAG: %[[A_DECL:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}})
+   !CHECK-DAG: %[[BOUNDS:.*]] = omp.bounds
+   !CHECK-DAG: %[[COND:.*]] = fir.convert %{{.*}} : (!fir.logical<4>) -> i1
+
+   !CHECK: omp.target_update_data if(%[[COND]] : i1) motion_entries
+   !$omp target update from(a) if(i)
+end subroutine omp_target_update_if
+
+!===============================================================================
+! Target_Update `device` clause
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_update_device() {
+subroutine omp_target_update_device
+   integer :: a(1024)
+   logical :: i
+
+   !CHECK-DAG: %[[A_DECL:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}})
+   !CHECK-DAG: %[[BOUNDS:.*]] = omp.bounds
+   !CHECK-DAG: %[[DEVICE:.*]] = arith.constant 1 : i32
+
+   !CHECK: omp.target_update_data device(%[[DEVICE]] : i32) motion_entries
+   !$omp target update from(a) device(1)
+end subroutine omp_target_update_device
+
+!===============================================================================
+! Target_Update `nowait` clause
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_update_nowait() {
+subroutine omp_target_update_nowait
+   integer :: a(1024)
+   logical :: i
+
+   !CHECK-DAG: %[[A_DECL:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}})
+   !CHECK-DAG: %[[BOUNDS:.*]] = omp.bounds
+
+   !CHECK: omp.target_update_data nowait motion_entries
+   !$omp target update from(a) nowait
+end subroutine omp_target_update_nowait
+
 !===============================================================================
 ! Target_Data with region
 !===============================================================================



More information about the flang-commits mailing list