[flang-commits] [flang] e66ea43 - [flang][cuda] Improve data transfer detection by filtering symbols (#98378)

via flang-commits flang-commits at lists.llvm.org
Thu Jul 11 09:36:38 PDT 2024


Author: Valentin Clement (バレンタイン クレメン)
Date: 2024-07-11T09:36:35-07:00
New Revision: e66ea43a399a1d70cbe3e4ed6adc77b2570cc51f

URL: https://github.com/llvm/llvm-project/commit/e66ea43a399a1d70cbe3e4ed6adc77b2570cc51f
DIFF: https://github.com/llvm/llvm-project/commit/e66ea43a399a1d70cbe3e4ed6adc77b2570cc51f.diff

LOG: [flang][cuda] Improve data transfer detection by filtering symbols (#98378)

The current data transfer detection was collecting too many symbols and
made wrong decisions. This patch introduces a new function,
`CollectCudaSymbols`, that differs from `CollectSymbols` and collects
only the symbols of interest for CUDA data transfer in an expression.

Currently, symbols are filtered out in two cases:
- array subscripts: only the array symbol is of interest; the indexing
symbols can be filtered out
- function arguments: symbols of the function arguments are filtered
out.

This fixes some false-positive data transfers and implicit data transfers.

More filtering might be needed and will be added as follow up patches.

Added: 
    

Modified: 
    flang/include/flang/Evaluate/tools.h
    flang/lib/Evaluate/tools.cpp
    flang/test/Lower/CUDA/cuda-data-transfer.cuf

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h
index 625f9e5f6576f..8555073a2d0d4 100644
--- a/flang/include/flang/Evaluate/tools.h
+++ b/flang/include/flang/Evaluate/tools.h
@@ -1073,6 +1073,16 @@ extern template semantics::UnorderedSymbolSet CollectSymbols(
 extern template semantics::UnorderedSymbolSet CollectSymbols(
     const Expr<SubscriptInteger> &);
 
+// Collects Symbols of interest for the CUDA data transfer in an expression
+template <typename A>
+semantics::UnorderedSymbolSet CollectCudaSymbols(const A &);
+extern template semantics::UnorderedSymbolSet CollectCudaSymbols(
+    const Expr<SomeType> &);
+extern template semantics::UnorderedSymbolSet CollectCudaSymbols(
+    const Expr<SomeInteger> &);
+extern template semantics::UnorderedSymbolSet CollectCudaSymbols(
+    const Expr<SubscriptInteger> &);
+
 // Predicate: does a variable contain a vector-valued subscript (not a triplet)?
 bool HasVectorSubscript(const Expr<SomeType> &);
 
@@ -1236,7 +1246,7 @@ bool CheckForCoindexedObject(parser::ContextualMessages &,
 // Get the number of distinct symbols with CUDA attribute in the expression.
 template <typename A> inline int GetNbOfCUDADeviceSymbols(const A &expr) {
   semantics::UnorderedSymbolSet symbols;
-  for (const Symbol &sym : CollectSymbols(expr)) {
+  for (const Symbol &sym : CollectCudaSymbols(expr)) {
     if (const auto *details =
             sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
       if (details->cudaDataAttr() &&
@@ -1259,7 +1269,7 @@ template <typename A> inline bool HasCUDADeviceAttrs(const A &expr) {
 inline bool HasCUDAImplicitTransfer(const Expr<SomeType> &expr) {
   unsigned hostSymbols{0};
   unsigned deviceSymbols{0};
-  for (const Symbol &sym : CollectSymbols(expr)) {
+  for (const Symbol &sym : CollectCudaSymbols(expr)) {
     if (const auto *details =
             sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
       if (details->cudaDataAttr() &&

diff  --git a/flang/lib/Evaluate/tools.cpp b/flang/lib/Evaluate/tools.cpp
index a5f4faa0cef8f..34faba39ffd46 100644
--- a/flang/lib/Evaluate/tools.cpp
+++ b/flang/lib/Evaluate/tools.cpp
@@ -1000,6 +1000,35 @@ template semantics::UnorderedSymbolSet CollectSymbols(
 template semantics::UnorderedSymbolSet CollectSymbols(
     const Expr<SubscriptInteger> &);
 
+struct CollectCudaSymbolsHelper : public SetTraverse<CollectCudaSymbolsHelper,
+                                      semantics::UnorderedSymbolSet> {
+  using Base =
+      SetTraverse<CollectCudaSymbolsHelper, semantics::UnorderedSymbolSet>;
+  CollectCudaSymbolsHelper() : Base{*this} {}
+  using Base::operator();
+  semantics::UnorderedSymbolSet operator()(const Symbol &symbol) const {
+    return {symbol};
+  }
+  // Overload some of the operator() to filter out the symbols that are not
+  // of interest for CUDA data transfer logic.
+  semantics::UnorderedSymbolSet operator()(const Subscript &) const {
+    return {};
+  }
+  semantics::UnorderedSymbolSet operator()(const ProcedureRef &) const {
+    return {};
+  }
+};
+template <typename A>
+semantics::UnorderedSymbolSet CollectCudaSymbols(const A &x) {
+  return CollectCudaSymbolsHelper{}(x);
+}
+template semantics::UnorderedSymbolSet CollectCudaSymbols(
+    const Expr<SomeType> &);
+template semantics::UnorderedSymbolSet CollectCudaSymbols(
+    const Expr<SomeInteger> &);
+template semantics::UnorderedSymbolSet CollectCudaSymbols(
+    const Expr<SubscriptInteger> &);
+
 // HasVectorSubscript()
 struct HasVectorSubscriptHelper
     : public AnyTraverse<HasVectorSubscriptHelper, bool,

diff  --git a/flang/test/Lower/CUDA/cuda-data-transfer.cuf b/flang/test/Lower/CUDA/cuda-data-transfer.cuf
index 0191de748d3eb..8b04a2c202dc0 100644
--- a/flang/test/Lower/CUDA/cuda-data-transfer.cuf
+++ b/flang/test/Lower/CUDA/cuda-data-transfer.cuf
@@ -6,6 +6,12 @@ module mod1
   type :: t1
     integer :: i
   end type
+
+contains
+  function dev1(a)
+    integer, device :: a(:)
+    integer :: dev1
+  end function
 end
 
 subroutine sub1()
@@ -213,11 +219,35 @@ subroutine sub10(a, b)
   res = a + b
 end subroutine
 
-
-
 ! CHECK-LABEL: func.func @_QPsub10(
 ! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<i32> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "a"}
 
 ! CHECK: %[[A:.*]]:2 = hlfir.declare %[[ARG0]] dummy_scope %1 {data_attr = #cuf.cuda<device>, uniq_name = "_QFsub10Ea"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
 ! CHECK: cuf.data_transfer %[[A]]#1 to %{{.*}}#0 {transfer_kind = #cuf.cuda_transfer<device_host>} : !fir.ref<i32>, !fir.ref<i32>
 ! CHECK-NOT: cuf.data_transfer
+
+subroutine sub11(n)
+  integer :: n
+  real, dimension(10) :: h
+  real, dimension(n), device :: d
+  do i=1,10
+    h(i) = d(i)
+  end do
+end subroutine
+
+! CHECK-LABEL: func.func @_QPsub11
+! CHECK: %[[RHS:.*]] = hlfir.designate %{{.*}} (%{{.*}})  : (!fir.box<!fir.array<?xf32>>, i64) -> !fir.ref<f32>
+! CHECK: %[[LHS:.*]] = hlfir.designate %{{.*}} (%{{.*}})  : (!fir.ref<!fir.array<10xf32>>, i64) -> !fir.ref<f32>
+! CHECK: cuf.data_transfer %[[RHS]] to %[[LHS]] {transfer_kind = #cuf.cuda_transfer<device_host>} : !fir.ref<f32>, !fir.ref<f32>
+
+subroutine sub12()
+  use mod1
+  integer, device :: a(10)
+  integer :: x
+  x = dev1(a)
+end subroutine
+
+! CHECK-LABEL: func.func @_QPsub12
+! CHECK: %{{.*}} = fir.call @_QMmod1Pdev1
+! CHECK: hlfir.assign
+! CHECK-NOT: cuf.data_transfer


        


More information about the flang-commits mailing list