[flang-commits] [flang] 3433e41 - [flang][cuda] Detect constant on the rhs of data transfer (#117806)

via flang-commits flang-commits at lists.llvm.org
Tue Nov 26 17:04:03 PST 2024


Author: Valentin Clement (バレンタイン クレメン)
Date: 2024-11-26T17:04:00-08:00
New Revision: 3433e4140d18865fe784061a3cd029c5980f4e2f

URL: https://github.com/llvm/llvm-project/commit/3433e4140d18865fe784061a3cd029c5980f4e2f
DIFF: https://github.com/llvm/llvm-project/commit/3433e4140d18865fe784061a3cd029c5980f4e2f.diff

LOG: [flang][cuda] Detect constant on the rhs of data transfer (#117806)

When the rhs expression has some constants and a device symbol, an
implicit data transfer needs to be generated for the device symbol and
the computation with the constant is done on the host.

Added: 
    

Modified: 
    flang/include/flang/Evaluate/tools.h
    flang/lib/Evaluate/tools.cpp
    flang/lib/Lower/Bridge.cpp
    flang/test/Lower/CUDA/cuda-data-transfer.cuf

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h
index 6261a4eec4a555..dafacdf1ba0c5a 100644
--- a/flang/include/flang/Evaluate/tools.h
+++ b/flang/include/flang/Evaluate/tools.h
@@ -1102,6 +1102,9 @@ extern template semantics::UnorderedSymbolSet CollectCudaSymbols(
 // Predicate: does a variable contain a vector-valued subscript (not a triplet)?
 bool HasVectorSubscript(const Expr<SomeType> &);
 
+// Predicate: does an expression contain a constant?
+bool HasConstant(const Expr<SomeType> &);
+
 // Utilities for attaching the location of the declaration of a symbol
 // of interest to a message.  Handles the case of USE association gracefully.
 parser::Message *AttachDeclaration(parser::Message &, const Symbol &);
@@ -1319,7 +1322,8 @@ inline bool HasCUDAImplicitTransfer(const Expr<SomeType> &expr) {
       ++hostSymbols;
     }
   }
-  return hostSymbols > 0 && deviceSymbols > 0;
+  bool hasConstant{HasConstant(expr)};
+  return (hasConstant || (hostSymbols > 0)) && deviceSymbols > 0;
 }
 
 } // namespace Fortran::evaluate

diff  --git a/flang/lib/Evaluate/tools.cpp b/flang/lib/Evaluate/tools.cpp
index 15e3e9452894de..a040f7ce79dc10 100644
--- a/flang/lib/Evaluate/tools.cpp
+++ b/flang/lib/Evaluate/tools.cpp
@@ -1051,6 +1051,23 @@ bool HasVectorSubscript(const Expr<SomeType> &expr) {
   return HasVectorSubscriptHelper{}(expr);
 }
 
+// HasConstant()
+struct HasConstantHelper : public AnyTraverse<HasConstantHelper, bool,
+                               /*TraverseAssocEntityDetails=*/false> {
+  using Base = AnyTraverse<HasConstantHelper, bool, false>;
+  HasConstantHelper() : Base{*this} {}
+  using Base::operator();
+  template <typename T> bool operator()(const Constant<T> &) const {
+    return true;
+  }
+  // Only look for constants that are not in subscripts.
+  bool operator()(const Subscript &) const { return false; }
+};
+
+bool HasConstant(const Expr<SomeType> &expr) {
+  return HasConstantHelper{}(expr);
+}
+
 parser::Message *AttachDeclaration(
     parser::Message &message, const Symbol &symbol) {
   const Symbol *unhosted{&symbol};

diff  --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index cbae6955e2a076..17b58604da3bf6 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -4416,6 +4416,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     bool hasCUDAImplicitTransfer =
         Fortran::evaluate::HasCUDAImplicitTransfer(assign.rhs);
     llvm::SmallVector<mlir::Value> implicitTemps;
+
     if (hasCUDAImplicitTransfer && !isInDeviceContext)
       implicitTemps = genCUDAImplicitDataTransfer(builder, loc, assign);
 

diff  --git a/flang/test/Lower/CUDA/cuda-data-transfer.cuf b/flang/test/Lower/CUDA/cuda-data-transfer.cuf
index 3b6cd67d9a8fa5..cbddcd79c63334 100644
--- a/flang/test/Lower/CUDA/cuda-data-transfer.cuf
+++ b/flang/test/Lower/CUDA/cuda-data-transfer.cuf
@@ -38,7 +38,6 @@ subroutine sub1()
   adev = 10
 
   cdev = 0
-
 end
 
 ! CHECK-LABEL: func.func @_QPsub1()
@@ -381,3 +380,17 @@ end subroutine
 
 ! CHECK-LABEL: func.func @_QPsub18
 ! CHECK-NOT: cuf.data_transfer
+
+subroutine sub19()
+  integer, device :: adev(10)
+  integer :: ahost(10)
+  ! Implicit data transfer of adev and then addition on the host
+  ahost = adev + 2
+end subroutine
+
+! CHECK-LABEL: func.func @_QPsub19()
+! CHECK: %[[ADEV_DECL:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}}) {data_attr = #cuf.cuda<device>, uniq_name = "_QFsub19Eadev"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>)
+! CHECK: %[[ALLOC_TMP:.*]] = fir.allocmem !fir.array<10xi32> {bindc_name = ".tmp", uniq_name = ""}
+! CHECK: %[[TMP:.*]]:2 = hlfir.declare %[[ALLOC_TMP]](%{{.*}}) {uniq_name = ".tmp"} : (!fir.heap<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<10xi32>>, !fir.heap<!fir.array<10xi32>>)
+! CHECK: cuf.data_transfer %[[ADEV_DECL]]#1 to %[[TMP]]#0 {transfer_kind = #cuf.cuda_transfer<device_host>} : !fir.ref<!fir.array<10xi32>>, !fir.heap<!fir.array<10xi32>>
+! CHECK: hlfir.assign


        


More information about the flang-commits mailing list