[flang-commits] [flang] [flang][cuda] Improve data transfer detection by filtering symbols (PR #98378)
via flang-commits
flang-commits at lists.llvm.org
Wed Jul 10 13:26:34 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-semantics
Author: Valentin Clement (バレンタイン クレメン) (clementval)
<details>
<summary>Changes</summary>
The current data transfer detection collected too many symbols and could therefore make the wrong decision. This patch introduces a new function, `CollectCudaSymbols`, which differs from `CollectSymbols` in that it collects only the symbols of interest for CUDA data transfer in an expression.
Currently, symbols are filtered out in two cases:
- array subscripts: only the array symbol is of interest; symbols used in the subscript (indexing) expressions are filtered out.
- function arguments: symbols appearing in the arguments of a function reference are filtered out.
---
Full diff: https://github.com/llvm/llvm-project/pull/98378.diff
3 Files Affected:
- (modified) flang/include/flang/Evaluate/tools.h (+12-2)
- (modified) flang/lib/Evaluate/tools.cpp (+29)
- (modified) flang/test/Lower/CUDA/cuda-data-transfer.cuf (+31-2)
``````````diff
diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h
index 625f9e5f6576f..8555073a2d0d4 100644
--- a/flang/include/flang/Evaluate/tools.h
+++ b/flang/include/flang/Evaluate/tools.h
@@ -1073,6 +1073,16 @@ extern template semantics::UnorderedSymbolSet CollectSymbols(
extern template semantics::UnorderedSymbolSet CollectSymbols(
const Expr<SubscriptInteger> &);
+// Collects Symbols of interest for the CUDA data transfer in an expression
+template <typename A>
+semantics::UnorderedSymbolSet CollectCudaSymbols(const A &);
+extern template semantics::UnorderedSymbolSet CollectCudaSymbols(
+ const Expr<SomeType> &);
+extern template semantics::UnorderedSymbolSet CollectCudaSymbols(
+ const Expr<SomeInteger> &);
+extern template semantics::UnorderedSymbolSet CollectCudaSymbols(
+ const Expr<SubscriptInteger> &);
+
// Predicate: does a variable contain a vector-valued subscript (not a triplet)?
bool HasVectorSubscript(const Expr<SomeType> &);
@@ -1236,7 +1246,7 @@ bool CheckForCoindexedObject(parser::ContextualMessages &,
// Get the number of distinct symbols with CUDA attribute in the expression.
template <typename A> inline int GetNbOfCUDADeviceSymbols(const A &expr) {
semantics::UnorderedSymbolSet symbols;
- for (const Symbol &sym : CollectSymbols(expr)) {
+ for (const Symbol &sym : CollectCudaSymbols(expr)) {
if (const auto *details =
sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
if (details->cudaDataAttr() &&
@@ -1259,7 +1269,7 @@ template <typename A> inline bool HasCUDADeviceAttrs(const A &expr) {
inline bool HasCUDAImplicitTransfer(const Expr<SomeType> &expr) {
unsigned hostSymbols{0};
unsigned deviceSymbols{0};
- for (const Symbol &sym : CollectSymbols(expr)) {
+ for (const Symbol &sym : CollectCudaSymbols(expr)) {
if (const auto *details =
sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
if (details->cudaDataAttr() &&
diff --git a/flang/lib/Evaluate/tools.cpp b/flang/lib/Evaluate/tools.cpp
index a5f4faa0cef8f..34faba39ffd46 100644
--- a/flang/lib/Evaluate/tools.cpp
+++ b/flang/lib/Evaluate/tools.cpp
@@ -1000,6 +1000,35 @@ template semantics::UnorderedSymbolSet CollectSymbols(
template semantics::UnorderedSymbolSet CollectSymbols(
const Expr<SubscriptInteger> &);
+struct CollectCudaSymbolsHelper : public SetTraverse<CollectCudaSymbolsHelper,
+ semantics::UnorderedSymbolSet> {
+ using Base =
+ SetTraverse<CollectCudaSymbolsHelper, semantics::UnorderedSymbolSet>;
+ CollectCudaSymbolsHelper() : Base{*this} {}
+ using Base::operator();
+ semantics::UnorderedSymbolSet operator()(const Symbol &symbol) const {
+ return {symbol};
+ }
+ // Overload some of the operator() to filter out the symbols that are not
+ // of interest for CUDA data transfer logic.
+ semantics::UnorderedSymbolSet operator()(const Subscript &) const {
+ return {};
+ }
+ semantics::UnorderedSymbolSet operator()(const ProcedureRef &) const {
+ return {};
+ }
+};
+template <typename A>
+semantics::UnorderedSymbolSet CollectCudaSymbols(const A &x) {
+ return CollectCudaSymbolsHelper{}(x);
+}
+template semantics::UnorderedSymbolSet CollectCudaSymbols(
+ const Expr<SomeType> &);
+template semantics::UnorderedSymbolSet CollectCudaSymbols(
+ const Expr<SomeInteger> &);
+template semantics::UnorderedSymbolSet CollectCudaSymbols(
+ const Expr<SubscriptInteger> &);
+
// HasVectorSubscript()
struct HasVectorSubscriptHelper
: public AnyTraverse<HasVectorSubscriptHelper, bool,
diff --git a/flang/test/Lower/CUDA/cuda-data-transfer.cuf b/flang/test/Lower/CUDA/cuda-data-transfer.cuf
index 0191de748d3eb..4929e1dcfabfc 100644
--- a/flang/test/Lower/CUDA/cuda-data-transfer.cuf
+++ b/flang/test/Lower/CUDA/cuda-data-transfer.cuf
@@ -6,6 +6,12 @@ module mod1
type :: t1
integer :: i
end type
+
+contains
+ function dev1(a)
+ integer, device :: a(:)
+ integer :: dev1
+ end function
end
subroutine sub1()
@@ -213,11 +219,34 @@ subroutine sub10(a, b)
res = a + b
end subroutine
-
-
! CHECK-LABEL: func.func @_QPsub10(
! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<i32> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "a"}
! CHECK: %[[A:.*]]:2 = hlfir.declare %[[ARG0]] dummy_scope %1 {data_attr = #cuf.cuda<device>, uniq_name = "_QFsub10Ea"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
! CHECK: cuf.data_transfer %[[A]]#1 to %{{.*}}#0 {transfer_kind = #cuf.cuda_transfer<device_host>} : !fir.ref<i32>, !fir.ref<i32>
! CHECK-NOT: cuf.data_transfer
+
+subroutine sub11(n)
+ integer :: n
+ real, dimension(10) :: h
+ real, dimension(n), device :: d
+ do i=1,10
+ h(i) = d(i)
+ end do
+end subroutine
+
+! CHECK-LABEL: func.func @_QPsub11
+! CHECK: %[[RHS:.*]] = hlfir.designate %{{.*}} (%{{.*}}) : (!fir.box<!fir.array<?xf32>>, i64) -> !fir.ref<f32>
+! CHECK: %[[LHS:.*]] = hlfir.designate %{{.*}} (%{{.*}}) : (!fir.ref<!fir.array<10xf32>>, i64) -> !fir.ref<f32>
+! CHECK: cuf.data_transfer %[[RHS]] to %[[LHS]] {transfer_kind = #cuf.cuda_transfer<device_host>} : !fir.ref<f32>, !fir.ref<f32>
+
+subroutine sub12()
+ use mod1
+ integer, device :: a(10)
+ integer :: x
+ x = dev1(a)
+end subroutine
+
+! CHECK: %{{.*}} = fir.call @_QMmod1Pdev1
+! CHECK: hlfir.assign
+! CHECK-NOT: cuf.data_transfer
``````````
</details>
https://github.com/llvm/llvm-project/pull/98378
More information about the flang-commits
mailing list