[flang-commits] [flang] [flang][cuda] Use getOrCreateGPUModule in CUFDeviceGlobal pass (PR #114468)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Thu Oct 31 14:37:49 PDT 2024


https://github.com/clementval created https://github.com/llvm/llvm-project/pull/114468

Make the pass functional even if the GPU module has not been created yet.

>From 845f501f6631bd6ad852d5cbd8206a823d1c947f Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Thu, 31 Oct 2024 14:36:46 -0700
Subject: [PATCH] [flang][cuda] Use getOrCreateGPUModule in CUFDeviceGlobal
 pass

---
 .../flang/Optimizer/Transforms/Passes.td      |  2 +-
 .../Optimizer/Transforms/CUFDeviceGlobal.cpp  | 38 +++++++++----------
 .../Fir/CUDA/cuda-implicit-device-global.f90  |  6 +++
 3 files changed, 26 insertions(+), 20 deletions(-)

diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index a41f0f348f27a6..d89713a9fc0b97 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -432,7 +432,7 @@ def CUFDeviceGlobal :
     Pass<"cuf-device-global", "mlir::ModuleOp"> {
   let summary = "Flag globals used in device function with data attribute";
   let dependentDialects = [
-    "cuf::CUFDialect"
+    "cuf::CUFDialect", "mlir::gpu::GPUDialect", "mlir::NVVM::NVVMDialect"
   ];
 }
 
diff --git a/flang/lib/Optimizer/Transforms/CUFDeviceGlobal.cpp b/flang/lib/Optimizer/Transforms/CUFDeviceGlobal.cpp
index dc39be8574f844..a69b47ff743911 100644
--- a/flang/lib/Optimizer/Transforms/CUFDeviceGlobal.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFDeviceGlobal.cpp
@@ -14,6 +14,7 @@
 #include "flang/Optimizer/Transforms/CUFCommon.h"
 #include "flang/Runtime/CUDA/common.h"
 #include "flang/Runtime/allocatable.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -62,27 +63,26 @@ class CUFDeviceGlobal : public fir::impl::CUFDeviceGlobalBase<CUFDeviceGlobal> {
 
     // Copying the device global variable into the gpu module
     mlir::SymbolTable parentSymTable(mod);
-    auto gpuMod =
-        parentSymTable.lookup<mlir::gpu::GPUModuleOp>(cudaDeviceModuleName);
-    if (gpuMod) {
-      mlir::SymbolTable gpuSymTable(gpuMod);
-      for (auto globalOp : mod.getOps<fir::GlobalOp>()) {
-        auto attr = globalOp.getDataAttrAttr();
-        if (!attr)
-          continue;
-        switch (attr.getValue()) {
-        case cuf::DataAttribute::Device:
-        case cuf::DataAttribute::Constant:
-        case cuf::DataAttribute::Managed: {
-          auto globalName{globalOp.getSymbol().getValue()};
-          if (gpuSymTable.lookup<fir::GlobalOp>(globalName)) {
-            break;
-          }
-          gpuSymTable.insert(globalOp->clone());
-        } break;
-        default:
+    auto gpuMod = cuf::getOrCreateGPUModule(mod, parentSymTable);
+    if (!gpuMod)
+      return signalPassFailure();
+    mlir::SymbolTable gpuSymTable(gpuMod);
+    for (auto globalOp : mod.getOps<fir::GlobalOp>()) {
+      auto attr = globalOp.getDataAttrAttr();
+      if (!attr)
+        continue;
+      switch (attr.getValue()) {
+      case cuf::DataAttribute::Device:
+      case cuf::DataAttribute::Constant:
+      case cuf::DataAttribute::Managed: {
+        auto globalName{globalOp.getSymbol().getValue()};
+        if (gpuSymTable.lookup<fir::GlobalOp>(globalName)) {
           break;
         }
+        gpuSymTable.insert(globalOp->clone());
+      } break;
+      default:
+        break;
       }
     }
   }
diff --git a/flang/test/Fir/CUDA/cuda-implicit-device-global.f90 b/flang/test/Fir/CUDA/cuda-implicit-device-global.f90
index 82a0c5948d9cb9..18b56a491cd65f 100644
--- a/flang/test/Fir/CUDA/cuda-implicit-device-global.f90
+++ b/flang/test/Fir/CUDA/cuda-implicit-device-global.f90
@@ -25,6 +25,9 @@ // Test that global used in device function are flagged with the correct
 // CHECK: fir.call @_FortranAioBeginExternalListOutput(%{{.*}}, %[[CONV]], %{{.*}}) fastmath<contract> : (i32, !fir.ref<i8>, i32) -> !fir.ref<i8>
 // CHECK: fir.global linkonce @_QQcl[[SYMBOL]] {data_attr = #cuf.cuda<constant>} constant : !fir.char<1,32>
 
+// CHECK-LABEL: gpu.module @cuda_device_mod [#nvvm.target]
+// CHECK: fir.global linkonce @_QQclX6995815537abaf90e86ce166af128f3a
+
 // -----
 
 func.func @_QMdataPsetvalue() {
@@ -47,3 +50,6 @@ // Test that global used in device function are flagged with the correct
 // CHECK: %[[CONV:.*]] = fir.convert %[[GLOBAL]] : (!fir.ref<!fir.char<1,32>>) -> !fir.ref<i8>
 // CHECK: fir.call @_FortranAioBeginExternalListOutput(%{{.*}}, %[[CONV]], %{{.*}}) fastmath<contract> : (i32, !fir.ref<i8>, i32) -> !fir.ref<i8>
 // CHECK: fir.global linkonce @_QQcl[[SYMBOL]] constant : !fir.char<1,32>
+
+// CHECK-LABEL: gpu.module @cuda_device_mod [#nvvm.target]
+// CHECK-NOT: fir.global linkonce @_QQclX6995815537abaf90e86ce166af128f3a



More information about the flang-commits mailing list