[flang-commits] [flang] [flang][cuda] Improve data transfer detection by filtering symbols (PR #98378)

via flang-commits flang-commits at lists.llvm.org
Wed Jul 10 13:26:34 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-semantics

Author: Valentin Clement (バレンタイン クレメン) (clementval)

<details>
<summary>Changes</summary>

The current data transfer detection was collecting too many symbols and, as a result, making wrong decisions. This patch introduces a new function, `CollectCudaSymbols`, which differs from `CollectSymbols` in that it collects only the symbols of interest for CUDA data transfer in an expression.

Currently, the two cases where symbols are filtered out are:
- array subscripts: only the array symbol is of interest; the symbols used for indexing are filtered out
- function arguments: symbols of the function arguments are filtered out.

---
Full diff: https://github.com/llvm/llvm-project/pull/98378.diff


3 Files Affected:

- (modified) flang/include/flang/Evaluate/tools.h (+12-2) 
- (modified) flang/lib/Evaluate/tools.cpp (+29) 
- (modified) flang/test/Lower/CUDA/cuda-data-transfer.cuf (+31-2) 


``````````diff
diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h
index 625f9e5f6576f..8555073a2d0d4 100644
--- a/flang/include/flang/Evaluate/tools.h
+++ b/flang/include/flang/Evaluate/tools.h
@@ -1073,6 +1073,16 @@ extern template semantics::UnorderedSymbolSet CollectSymbols(
 extern template semantics::UnorderedSymbolSet CollectSymbols(
     const Expr<SubscriptInteger> &);
 
+// Collects Symbols of interest for the CUDA data transfer in an expression
+template <typename A>
+semantics::UnorderedSymbolSet CollectCudaSymbols(const A &);
+extern template semantics::UnorderedSymbolSet CollectCudaSymbols(
+    const Expr<SomeType> &);
+extern template semantics::UnorderedSymbolSet CollectCudaSymbols(
+    const Expr<SomeInteger> &);
+extern template semantics::UnorderedSymbolSet CollectCudaSymbols(
+    const Expr<SubscriptInteger> &);
+
 // Predicate: does a variable contain a vector-valued subscript (not a triplet)?
 bool HasVectorSubscript(const Expr<SomeType> &);
 
@@ -1236,7 +1246,7 @@ bool CheckForCoindexedObject(parser::ContextualMessages &,
 // Get the number of distinct symbols with CUDA attribute in the expression.
 template <typename A> inline int GetNbOfCUDADeviceSymbols(const A &expr) {
   semantics::UnorderedSymbolSet symbols;
-  for (const Symbol &sym : CollectSymbols(expr)) {
+  for (const Symbol &sym : CollectCudaSymbols(expr)) {
     if (const auto *details =
             sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
       if (details->cudaDataAttr() &&
@@ -1259,7 +1269,7 @@ template <typename A> inline bool HasCUDADeviceAttrs(const A &expr) {
 inline bool HasCUDAImplicitTransfer(const Expr<SomeType> &expr) {
   unsigned hostSymbols{0};
   unsigned deviceSymbols{0};
-  for (const Symbol &sym : CollectSymbols(expr)) {
+  for (const Symbol &sym : CollectCudaSymbols(expr)) {
     if (const auto *details =
             sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
       if (details->cudaDataAttr() &&
diff --git a/flang/lib/Evaluate/tools.cpp b/flang/lib/Evaluate/tools.cpp
index a5f4faa0cef8f..34faba39ffd46 100644
--- a/flang/lib/Evaluate/tools.cpp
+++ b/flang/lib/Evaluate/tools.cpp
@@ -1000,6 +1000,35 @@ template semantics::UnorderedSymbolSet CollectSymbols(
 template semantics::UnorderedSymbolSet CollectSymbols(
     const Expr<SubscriptInteger> &);
 
+struct CollectCudaSymbolsHelper : public SetTraverse<CollectCudaSymbolsHelper,
+                                      semantics::UnorderedSymbolSet> {
+  using Base =
+      SetTraverse<CollectCudaSymbolsHelper, semantics::UnorderedSymbolSet>;
+  CollectCudaSymbolsHelper() : Base{*this} {}
+  using Base::operator();
+  semantics::UnorderedSymbolSet operator()(const Symbol &symbol) const {
+    return {symbol};
+  }
+  // Overload some of the operator() to filter out the symbols that are not
+  // of interest for CUDA data transfer logic.
+  semantics::UnorderedSymbolSet operator()(const Subscript &) const {
+    return {};
+  }
+  semantics::UnorderedSymbolSet operator()(const ProcedureRef &) const {
+    return {};
+  }
+};
+template <typename A>
+semantics::UnorderedSymbolSet CollectCudaSymbols(const A &x) {
+  return CollectCudaSymbolsHelper{}(x);
+}
+template semantics::UnorderedSymbolSet CollectCudaSymbols(
+    const Expr<SomeType> &);
+template semantics::UnorderedSymbolSet CollectCudaSymbols(
+    const Expr<SomeInteger> &);
+template semantics::UnorderedSymbolSet CollectCudaSymbols(
+    const Expr<SubscriptInteger> &);
+
 // HasVectorSubscript()
 struct HasVectorSubscriptHelper
     : public AnyTraverse<HasVectorSubscriptHelper, bool,
diff --git a/flang/test/Lower/CUDA/cuda-data-transfer.cuf b/flang/test/Lower/CUDA/cuda-data-transfer.cuf
index 0191de748d3eb..4929e1dcfabfc 100644
--- a/flang/test/Lower/CUDA/cuda-data-transfer.cuf
+++ b/flang/test/Lower/CUDA/cuda-data-transfer.cuf
@@ -6,6 +6,12 @@ module mod1
   type :: t1
     integer :: i
   end type
+
+contains
+  function dev1(a)
+    integer, device :: a(:)
+    integer :: dev1
+  end function
 end
 
 subroutine sub1()
@@ -213,11 +219,34 @@ subroutine sub10(a, b)
   res = a + b
 end subroutine
 
-
-
 ! CHECK-LABEL: func.func @_QPsub10(
 ! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<i32> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "a"}
 
 ! CHECK: %[[A:.*]]:2 = hlfir.declare %[[ARG0]] dummy_scope %1 {data_attr = #cuf.cuda<device>, uniq_name = "_QFsub10Ea"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
 ! CHECK: cuf.data_transfer %[[A]]#1 to %{{.*}}#0 {transfer_kind = #cuf.cuda_transfer<device_host>} : !fir.ref<i32>, !fir.ref<i32>
 ! CHECK-NOT: cuf.data_transfer
+
+subroutine sub11(n)
+  integer :: n
+  real, dimension(10) :: h
+  real, dimension(n), device :: d
+  do i=1,10
+    h(i) = d(i)
+  end do
+end subroutine
+
+! CHECK-LABEL: func.func @_QPsub11
+! CHECK: %[[RHS:.*]] = hlfir.designate %{{.*}} (%{{.*}})  : (!fir.box<!fir.array<?xf32>>, i64) -> !fir.ref<f32>
+! CHECK: %[[LHS:.*]] = hlfir.designate %{{.*}} (%{{.*}})  : (!fir.ref<!fir.array<10xf32>>, i64) -> !fir.ref<f32>
+! CHECK: cuf.data_transfer %[[RHS]] to %[[LHS]] {transfer_kind = #cuf.cuda_transfer<device_host>} : !fir.ref<f32>, !fir.ref<f32>
+
+subroutine sub12()
+  use mod1
+  integer, device :: a(10)
+  integer :: x
+  x = dev1(a)
+end subroutine
+
+! CHECK: %{{.*}} = fir.call @_QMmod1Pdev1
+! CHECK: hlfir.assign
+! CHECK-NOT: cuf.data_transfer

``````````

</details>


https://github.com/llvm/llvm-project/pull/98378


More information about the flang-commits mailing list