[flang-commits] [flang] [openmp] Changes to support link clause of declare target with common block (PR #84825)

Anchu Rajendran S via flang-commits flang-commits at lists.llvm.org
Mon Mar 11 13:31:26 PDT 2024


https://github.com/anchuraj created https://github.com/llvm/llvm-project/pull/84825

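This change is a continuation of the changes in https://github.com/llvm/llvm-project/pull/83643 and extends them to variables that are declared target through the common block containing them (e.g. `!$omp declare target link(/cb/)`). As per the implicit mapping rules, if a variable is specified in the link clause of an OpenMP declare target directive, it needs to be mapped tofrom unless the device type is specified as nohost.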

>From 6d664d6f1e04d39cea5518883b4a300f1376cd8b Mon Sep 17 00:00:00 2001
From: Anchu Rajendran <asudhaku at amd.com>
Date: Mon, 11 Mar 2024 15:21:40 -0500
Subject: [PATCH] Changes to support link clause of declare target with common
 block

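This change is a continuation of the changes in https://github.com/llvm/llvm-project/pull/83643.
It extends them to variables that are declared target through the common
block containing them, e.g. `!$omp declare target link(/cb/)`. As per the
implicit mapping rules, if a variable is specified in the link clause of an
OpenMP declare target directive, it needs to be mapped tofrom unless the
device type is specified as nohost.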
---
 flang/lib/Lower/OpenMP/OpenMP.cpp             | 31 ++++++++++++++++---
 .../OpenMP/declare-target-link-tarop-cap.f90  | 18 +++++++++++
 .../declare-target-vars-in-target-region.f90  | 31 +++++++++++++++++++
 3 files changed, 75 insertions(+), 5 deletions(-)

diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 5cff95c7d125b0..83fb2d94508c26 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1037,6 +1037,31 @@ static void genBodyOfTargetOp(
     genNestedEvaluations(converter, eval);
 }
 
+// Returns the declare target operation for `sym` itself or, failing that, for
+// the common block containing `sym`. Returns a null interface when neither is
+// marked as declare target.
+static mlir::omp::DeclareTargetInterface
+getDeclareTargetOp(const Fortran::semantics::Symbol &sym,
+                   Fortran::lower::AbstractConverter &converter) {
+  mlir::ModuleOp mod = converter.getFirOpBuilder().getModule();
+  // Op named by mangleName(s) if it is marked declare target, otherwise null.
+  auto lookup = [&](const Fortran::semantics::Symbol &s)
+      -> mlir::omp::DeclareTargetInterface {
+    auto op = llvm::dyn_cast_if_present<mlir::omp::DeclareTargetInterface>(
+        mod.lookupSymbol(converter.mangleName(s)));
+    if (op && op.isDeclareTarget())
+      return op;
+    return mlir::omp::DeclareTargetInterface(nullptr);
+  };
+  if (mlir::omp::DeclareTargetInterface direct = lookup(sym))
+    return direct;
+  // Otherwise the symbol may be declare target through its common block.
+  if (const Fortran::semantics::Symbol *commonBlock =
+          Fortran::semantics::FindCommonBlockContaining(sym))
+    return lookup(*commonBlock);
+  return mlir::omp::DeclareTargetInterface(nullptr);
+}
+
 static mlir::omp::TargetOp
 genTargetOp(Fortran::lower::AbstractConverter &converter,
             Fortran::semantics::SemanticsContext &semaCtx,
@@ -1122,11 +1147,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
 
         // If a variable is specified in declare target link and if device
         // type is not specified as `nohost`, it needs to be mapped tofrom
-        mlir::ModuleOp mod = converter.getFirOpBuilder().getModule();
-        mlir::Operation *op = mod.lookupSymbol(converter.mangleName(sym));
-        auto declareTargetOp =
-            llvm::dyn_cast_if_present<mlir::omp::DeclareTargetInterface>(op);
-        if (declareTargetOp && declareTargetOp.isDeclareTarget()) {
+        if (auto declareTargetOp = getDeclareTargetOp(sym, converter)) {
           if (declareTargetOp.getDeclareTargetCaptureClause() ==
                   mlir::omp::DeclareTargetCaptureClause::link &&
               declareTargetOp.getDeclareTargetDeviceType() !=
diff --git a/flang/test/Lower/OpenMP/declare-target-link-tarop-cap.f90 b/flang/test/Lower/OpenMP/declare-target-link-tarop-cap.f90
index 7cd0597161578d..cd2615faba5463 100644
--- a/flang/test/Lower/OpenMP/declare-target-link-tarop-cap.f90
+++ b/flang/test/Lower/OpenMP/declare-target-link-tarop-cap.f90
@@ -20,6 +20,13 @@ program test_link
   integer, pointer :: test_ptr2
   !$omp declare target link(test_ptr2)
 
+  integer :: test_int_cb
+
+  integer :: test_int_array_cb(3) = (/1,2,3/)
+
+  common /test_cb/ test_int_cb, test_int_array_cb
+  !$omp declare target link(/test_cb/)
+
   !CHECK-DAG: {{%.*}} = omp.map_info var_ptr({{%.*}} : !fir.ref<i32>, i32) map_clauses(implicit, tofrom) capture(ByRef) -> !fir.ref<i32> {name = "test_int"}
   !$omp target
     test_int = test_int + 1
@@ -52,4 +59,15 @@ program test_link
     test_ptr2 = test_ptr2 + 1
   !$omp end target
 
+  !CHECK-DAG: {{%.*}} = omp.map_info var_ptr({{%.*}} : !fir.ref<i32>, i32) map_clauses(implicit, tofrom) capture(ByRef) -> !fir.ref<i32> {name = "test_int_cb"}
+  !$omp target
+    test_int_cb = test_int_cb + 1
+  !$omp end target
+
+  !CHECK-DAG: {{%.*}} = omp.map_info var_ptr({{%.*}} : !fir.ref<!fir.array<3xi32>>, !fir.array<3xi32>) map_clauses(implicit, tofrom) capture(ByRef) bounds({{%.*}}) -> !fir.ref<!fir.array<3xi32>> {name = "test_int_array_cb"}
+  !$omp target
+    do i = 1,3
+      test_int_array_cb(i) = i * 2
+    end do
+  !$omp end target
 end
diff --git a/openmp/libomptarget/test/offloading/fortran/declare-target-vars-in-target-region.f90 b/openmp/libomptarget/test/offloading/fortran/declare-target-vars-in-target-region.f90
index f524deac3bcce9..63343d504323b4 100644
--- a/openmp/libomptarget/test/offloading/fortran/declare-target-vars-in-target-region.f90
+++ b/openmp/libomptarget/test/offloading/fortran/declare-target-vars-in-target-region.f90
@@ -16,6 +16,10 @@ module test_0
   !$omp declare target link(arr1) enter(arr2)
   INTEGER :: scalar = 1
   !$omp declare target link(scalar)
+  INTEGER :: scalar_cb = 1
+  INTEGER :: arr_cb(10) = (/0,0,0,0,0,0,0,0,0,0/)
+  COMMON /CB/ scalar_cb, arr_cb
+  !$omp declare target link(/CB/)
 end module test_0
 
 subroutine test_with_array_link_and_tofrom()
@@ -73,9 +77,36 @@ subroutine test_with_scalar_link_only()
   PRINT *, scalar
 end subroutine test_with_scalar_link_only
 
+! arr_cb is not in the map clause; it reaches the device via declare target link(/CB/).
+subroutine test_with_array_cb_link_only()
+  use test_0, only: arr_cb
+  ! j must not exceed size(arr_cb) = 10, otherwise the loop stores out of bounds.
+  integer :: i = 1, j = 10
+  !$omp target map(i, j)
+      do while (i <= j)
+          arr_cb(i) = i + 1
+          i = i + 1
+      end do
+  !$omp end target
+  ! CHECK: 2 3 4 5 6 7 8 9 10 11
+  PRINT *, arr_cb(:)
+end subroutine test_with_array_cb_link_only
+
+! scalar_cb is not in a map clause; it reaches the device via declare target link(/CB/).
+subroutine test_with_scalar_cb_link_only()
+  use test_0, only: scalar_cb
+  !$omp target
+    scalar_cb = 10
+  !$omp end target
+  ! CHECK: 10
+  PRINT *, scalar_cb
+end subroutine test_with_scalar_cb_link_only
+
program main
  call test_with_array_link_and_tofrom()
  call test_with_array_link_only()
  call test_with_array_enter_only()
  call test_with_scalar_link_only()
+  call test_with_array_cb_link_only()   ! array member of declare target link(/CB/)
+  call test_with_scalar_cb_link_only()  ! scalar member of declare target link(/CB/)
 end program



More information about the flang-commits mailing list