[flang-commits] [flang] [flang][cuda] Sync global descriptor when nullifying pointer (PR #121595)

via flang-commits flang-commits at lists.llvm.org
Fri Jan 3 11:54:35 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/121595.diff


5 Files Affected:

- (modified) flang/include/flang/Optimizer/Builder/CUFCommon.h (+6) 
- (modified) flang/lib/Lower/Allocatable.cpp (+2-17) 
- (modified) flang/lib/Lower/Bridge.cpp (+2) 
- (modified) flang/lib/Optimizer/Builder/CUFCommon.cpp (+17) 
- (modified) flang/test/Lower/CUDA/cuda-pointer-sync.cuf (+5-1) 


``````````diff
diff --git a/flang/include/flang/Optimizer/Builder/CUFCommon.h b/flang/include/flang/Optimizer/Builder/CUFCommon.h
index df1b709dc86080..b99e3304296220 100644
--- a/flang/include/flang/Optimizer/Builder/CUFCommon.h
+++ b/flang/include/flang/Optimizer/Builder/CUFCommon.h
@@ -15,6 +15,10 @@
 
 static constexpr llvm::StringRef cudaDeviceModuleName = "cuda_device_mod";
 
+namespace fir {
+class FirOpBuilder;
+} // namespace fir
+
 namespace cuf {
 
 /// Retrieve or create the CUDA Fortran GPU module in the given \p mod.
@@ -24,6 +28,8 @@ mlir::gpu::GPUModuleOp getOrCreateGPUModule(mlir::ModuleOp mod,
 bool isInCUDADeviceContext(mlir::Operation *op);
 bool isRegisteredDeviceGlobal(fir::GlobalOp op);
 
+void genPointerSync(const mlir::Value box, fir::FirOpBuilder &builder);
+
 } // namespace cuf
 
 #endif // FORTRAN_OPTIMIZER_TRANSFORMS_CUFCOMMON_H_
diff --git a/flang/lib/Lower/Allocatable.cpp b/flang/lib/Lower/Allocatable.cpp
index 4c64870675816e..5c63c79892f428 100644
--- a/flang/lib/Lower/Allocatable.cpp
+++ b/flang/lib/Lower/Allocatable.cpp
@@ -1088,22 +1088,6 @@ bool Fortran::lower::isArraySectionWithoutVectorSubscript(
          !Fortran::evaluate::HasVectorSubscript(expr);
 }
 
-static void genCUFPointerSync(const mlir::Value box,
-                              fir::FirOpBuilder &builder) {
-  if (auto declareOp = box.getDefiningOp<hlfir::DeclareOp>()) {
-    if (auto addrOfOp = declareOp.getMemref().getDefiningOp<fir::AddrOfOp>()) {
-      auto mod = addrOfOp->getParentOfType<mlir::ModuleOp>();
-      if (auto globalOp =
-              mod.lookupSymbol<fir::GlobalOp>(addrOfOp.getSymbol())) {
-        if (cuf::isRegisteredDeviceGlobal(globalOp)) {
-          builder.create<cuf::SyncDescriptorOp>(box.getLoc(),
-                                                addrOfOp.getSymbol());
-        }
-      }
-    }
-  }
-}
-
 void Fortran::lower::associateMutableBox(
     Fortran::lower::AbstractConverter &converter, mlir::Location loc,
     const fir::MutableBoxValue &box, const Fortran::lower::SomeExpr &source,
@@ -1111,12 +1095,13 @@ void Fortran::lower::associateMutableBox(
   fir::FirOpBuilder &builder = converter.getFirOpBuilder();
   if (Fortran::evaluate::UnwrapExpr<Fortran::evaluate::NullPointer>(source)) {
     fir::factory::disassociateMutableBox(builder, loc, box);
+    cuf::genPointerSync(box.getAddr(), builder);
     return;
   }
   if (converter.getLoweringOptions().getLowerToHighLevelFIR()) {
     fir::ExtendedValue rhs = converter.genExprAddr(loc, source, stmtCtx);
     fir::factory::associateMutableBox(builder, loc, box, rhs, lbounds);
-    genCUFPointerSync(box.getAddr(), builder);
+    cuf::genPointerSync(box.getAddr(), builder);
     return;
   }
   // The right hand side is not be evaluated into a temp. Array sections can
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index c7e2635230e988..c7bf4248155483 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -34,6 +34,7 @@
 #include "flang/Lower/StatementContext.h"
 #include "flang/Lower/Support/Utils.h"
 #include "flang/Optimizer/Builder/BoxValue.h"
+#include "flang/Optimizer/Builder/CUFCommon.h"
 #include "flang/Optimizer/Builder/Character.h"
 #include "flang/Optimizer/Builder/FIRBuilder.h"
 #include "flang/Optimizer/Builder/Runtime/Assign.h"
@@ -3952,6 +3953,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
       } else {
         fir::MutableBoxValue box = genExprMutableBox(loc, *expr);
         fir::factory::disassociateMutableBox(*builder, loc, box);
+        cuf::genPointerSync(box.getAddr(), *builder);
       }
     }
   }
diff --git a/flang/lib/Optimizer/Builder/CUFCommon.cpp b/flang/lib/Optimizer/Builder/CUFCommon.cpp
index 81a8a90ce394e0..39848205f47af9 100644
--- a/flang/lib/Optimizer/Builder/CUFCommon.cpp
+++ b/flang/lib/Optimizer/Builder/CUFCommon.cpp
@@ -7,7 +7,9 @@
 //===----------------------------------------------------------------------===//
 
 #include "flang/Optimizer/Builder/CUFCommon.h"
+#include "flang/Optimizer/Builder/FIRBuilder.h"
 #include "flang/Optimizer/Dialect/CUF/CUFOps.h"
+#include "flang/Optimizer/HLFIR/HLFIROps.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 
@@ -54,3 +56,18 @@ bool cuf::isRegisteredDeviceGlobal(fir::GlobalOp op) {
     return true;
   return false;
 }
+
+void cuf::genPointerSync(const mlir::Value box, fir::FirOpBuilder &builder) {
+  if (auto declareOp = box.getDefiningOp<hlfir::DeclareOp>()) {
+    if (auto addrOfOp = declareOp.getMemref().getDefiningOp<fir::AddrOfOp>()) {
+      auto mod = addrOfOp->getParentOfType<mlir::ModuleOp>();
+      if (auto globalOp =
+              mod.lookupSymbol<fir::GlobalOp>(addrOfOp.getSymbol())) {
+        if (cuf::isRegisteredDeviceGlobal(globalOp)) {
+          builder.create<cuf::SyncDescriptorOp>(box.getLoc(),
+                                                addrOfOp.getSymbol());
+        }
+      }
+    }
+  }
+}
diff --git a/flang/test/Lower/CUDA/cuda-pointer-sync.cuf b/flang/test/Lower/CUDA/cuda-pointer-sync.cuf
index e17869b2d63573..4c64f4bd34aa02 100644
--- a/flang/test/Lower/CUDA/cuda-pointer-sync.cuf
+++ b/flang/test/Lower/CUDA/cuda-pointer-sync.cuf
@@ -8,10 +8,14 @@ use devptr
 real, device, target, dimension(4) :: a_dev
 a_dev = 42.0
 dev_ptr => a_dev
+
+dev_ptr => null()
+
+nullify(dev_ptr)
 end
 
 ! CHECK: fir.global @_QMdevptrEdev_ptr {data_attr = #cuf.cuda<device>} : !fir.box<!fir.ptr<!fir.array<?xf32>>>
 ! CHECK-LABEL: func.func @_QQmain()
 ! CHECK: fir.embox
 ! CHECK: fir.store
-! CHECK: cuf.sync_descriptor @_QMdevptrEdev_ptr
+! CHECK-COUNT-3: cuf.sync_descriptor @_QMdevptrEdev_ptr

``````````

</details>


https://github.com/llvm/llvm-project/pull/121595


More information about the flang-commits mailing list