[flang-commits] [flang] 067ce5c - [flang][cuda] Use getOrCreateGPUModule in CUFDeviceGlobal pass (#114468)

via flang-commits flang-commits at lists.llvm.org
Thu Oct 31 18:58:47 PDT 2024


Author: Valentin Clement (バレンタイン クレメン)
Date: 2024-10-31T18:58:43-07:00
New Revision: 067ce5ca1826b5e143a21ed79fb3d6e24474e2ff

URL: https://github.com/llvm/llvm-project/commit/067ce5ca1826b5e143a21ed79fb3d6e24474e2ff
DIFF: https://github.com/llvm/llvm-project/commit/067ce5ca1826b5e143a21ed79fb3d6e24474e2ff.diff

LOG: [flang][cuda] Use getOrCreateGPUModule in CUFDeviceGlobal pass (#114468)

Make the pass functional even if the GPU module has not been created yet.

Added: 
    

Modified: 
    flang/include/flang/Optimizer/Transforms/Passes.td
    flang/lib/Optimizer/Transforms/CUFDeviceGlobal.cpp
    flang/test/Fir/CUDA/cuda-implicit-device-global.f90

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index a41f0f348f27a6..d89713a9fc0b97 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -432,7 +432,7 @@ def CUFDeviceGlobal :
     Pass<"cuf-device-global", "mlir::ModuleOp"> {
   let summary = "Flag globals used in device function with data attribute";
   let dependentDialects = [
-    "cuf::CUFDialect"
+    "cuf::CUFDialect", "mlir::gpu::GPUDialect", "mlir::NVVM::NVVMDialect"
   ];
 }
 

diff  --git a/flang/lib/Optimizer/Transforms/CUFDeviceGlobal.cpp b/flang/lib/Optimizer/Transforms/CUFDeviceGlobal.cpp
index dc39be8574f844..a69b47ff743911 100644
--- a/flang/lib/Optimizer/Transforms/CUFDeviceGlobal.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFDeviceGlobal.cpp
@@ -14,6 +14,7 @@
 #include "flang/Optimizer/Transforms/CUFCommon.h"
 #include "flang/Runtime/CUDA/common.h"
 #include "flang/Runtime/allocatable.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -62,27 +63,26 @@ class CUFDeviceGlobal : public fir::impl::CUFDeviceGlobalBase<CUFDeviceGlobal> {
 
     // Copying the device global variable into the gpu module
     mlir::SymbolTable parentSymTable(mod);
-    auto gpuMod =
-        parentSymTable.lookup<mlir::gpu::GPUModuleOp>(cudaDeviceModuleName);
-    if (gpuMod) {
-      mlir::SymbolTable gpuSymTable(gpuMod);
-      for (auto globalOp : mod.getOps<fir::GlobalOp>()) {
-        auto attr = globalOp.getDataAttrAttr();
-        if (!attr)
-          continue;
-        switch (attr.getValue()) {
-        case cuf::DataAttribute::Device:
-        case cuf::DataAttribute::Constant:
-        case cuf::DataAttribute::Managed: {
-          auto globalName{globalOp.getSymbol().getValue()};
-          if (gpuSymTable.lookup<fir::GlobalOp>(globalName)) {
-            break;
-          }
-          gpuSymTable.insert(globalOp->clone());
-        } break;
-        default:
+    auto gpuMod = cuf::getOrCreateGPUModule(mod, parentSymTable);
+    if (!gpuMod)
+      return signalPassFailure();
+    mlir::SymbolTable gpuSymTable(gpuMod);
+    for (auto globalOp : mod.getOps<fir::GlobalOp>()) {
+      auto attr = globalOp.getDataAttrAttr();
+      if (!attr)
+        continue;
+      switch (attr.getValue()) {
+      case cuf::DataAttribute::Device:
+      case cuf::DataAttribute::Constant:
+      case cuf::DataAttribute::Managed: {
+        auto globalName{globalOp.getSymbol().getValue()};
+        if (gpuSymTable.lookup<fir::GlobalOp>(globalName)) {
           break;
         }
+        gpuSymTable.insert(globalOp->clone());
+      } break;
+      default:
+        break;
       }
     }
   }

diff  --git a/flang/test/Fir/CUDA/cuda-implicit-device-global.f90 b/flang/test/Fir/CUDA/cuda-implicit-device-global.f90
index 82a0c5948d9cb9..18b56a491cd65f 100644
--- a/flang/test/Fir/CUDA/cuda-implicit-device-global.f90
+++ b/flang/test/Fir/CUDA/cuda-implicit-device-global.f90
@@ -25,6 +25,9 @@ // Test that global used in device function are flagged with the correct
 // CHECK: fir.call @_FortranAioBeginExternalListOutput(%{{.*}}, %[[CONV]], %{{.*}}) fastmath<contract> : (i32, !fir.ref<i8>, i32) -> !fir.ref<i8>
 // CHECK: fir.global linkonce @_QQcl[[SYMBOL]] {data_attr = #cuf.cuda<constant>} constant : !fir.char<1,32>
 
+// CHECK-LABEL: gpu.module @cuda_device_mod [#nvvm.target]
+// CHECK: fir.global linkonce @_QQclX6995815537abaf90e86ce166af128f3a
+
 // -----
 
 func.func @_QMdataPsetvalue() {
@@ -47,3 +50,6 @@ // Test that global used in device function are flagged with the correct
 // CHECK: %[[CONV:.*]] = fir.convert %[[GLOBAL]] : (!fir.ref<!fir.char<1,32>>) -> !fir.ref<i8>
 // CHECK: fir.call @_FortranAioBeginExternalListOutput(%{{.*}}, %[[CONV]], %{{.*}}) fastmath<contract> : (i32, !fir.ref<i8>, i32) -> !fir.ref<i8>
 // CHECK: fir.global linkonce @_QQcl[[SYMBOL]] constant : !fir.char<1,32>
+
+// CHECK-LABEL: gpu.module @cuda_device_mod [#nvvm.target]
+// CHECK-NOT: fir.global linkonce @_QQclX6995815537abaf90e86ce166af128f3a


        


More information about the flang-commits mailing list