[flang-commits] [flang] [flang][cuda] Use getOrCreateGPUModule in CUFDeviceGlobal pass (PR #114468)
Valentin Clement バレンタイン クレメン via flang-commits
flang-commits at lists.llvm.org
Thu Oct 31 14:37:49 PDT 2024
https://github.com/clementval created https://github.com/llvm/llvm-project/pull/114468
Make the pass functional when the gpu module has not been created yet: instead of looking up the gpu module and silently skipping the work when it is missing, the pass now calls cuf::getOrCreateGPUModule so the module is created on demand.
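For context, here is a rough sketch of what a getOrCreateGPUModule-style helper could look like. This is illustrative only, not the code added by this patch: the cudaDeviceModuleName constant is assumed to come from CUFCommon.h (it is already referenced by the pass), and the exact GPUModuleOp/NVVMTargetAttr builder calls are assumptions consistent with the gpu.module @cuda_device_mod [#nvvm.target] expected by the updated test.

// Illustrative sketch only -- not the implementation from this patch.
#include "flang/Optimizer/Transforms/CUFCommon.h"   // assumed home of cudaDeviceModuleName
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/SymbolTable.h"
#include "llvm/ADT/SmallVector.h"

static mlir::gpu::GPUModuleOp
getOrCreateGPUModuleSketch(mlir::ModuleOp mod, mlir::SymbolTable &symTab) {
  // Reuse the gpu.module if an earlier pass already created it.
  if (auto gpuMod =
          symTab.lookup<mlir::gpu::GPUModuleOp>(cudaDeviceModuleName))
    return gpuMod;

  // Otherwise create gpu.module @cuda_device_mod [#nvvm.target] and insert
  // it into the parent module's symbol table.
  mlir::MLIRContext *ctx = mod.getContext();
  mlir::OpBuilder builder(ctx);
  llvm::SmallVector<mlir::Attribute> targets{
      mlir::NVVM::NVVMTargetAttr::get(ctx)};
  auto gpuMod = builder.create<mlir::gpu::GPUModuleOp>(
      mod.getLoc(), cudaDeviceModuleName, builder.getArrayAttr(targets));
  symTab.insert(gpuMod);
  return gpuMod;
}

In the updated pass below, a null result from the helper is treated as a hard error via signalPassFailure() rather than silently skipping the copy of device globals.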
From 845f501f6631bd6ad852d5cbd8206a823d1c947f Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Thu, 31 Oct 2024 14:36:46 -0700
Subject: [PATCH] [flang][cuda] Use getOrCreateGPUModule in CUFDeviceGlobal
pass
---
.../flang/Optimizer/Transforms/Passes.td | 2 +-
.../Optimizer/Transforms/CUFDeviceGlobal.cpp | 38 +++++++++----------
.../Fir/CUDA/cuda-implicit-device-global.f90 | 6 +++
3 files changed, 26 insertions(+), 20 deletions(-)
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index a41f0f348f27a6..d89713a9fc0b97 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -432,7 +432,7 @@ def CUFDeviceGlobal :
Pass<"cuf-device-global", "mlir::ModuleOp"> {
let summary = "Flag globals used in device function with data attribute";
let dependentDialects = [
- "cuf::CUFDialect"
+ "cuf::CUFDialect", "mlir::gpu::GPUDialect", "mlir::NVVM::NVVMDialect"
];
}
diff --git a/flang/lib/Optimizer/Transforms/CUFDeviceGlobal.cpp b/flang/lib/Optimizer/Transforms/CUFDeviceGlobal.cpp
index dc39be8574f844..a69b47ff743911 100644
--- a/flang/lib/Optimizer/Transforms/CUFDeviceGlobal.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFDeviceGlobal.cpp
@@ -14,6 +14,7 @@
#include "flang/Optimizer/Transforms/CUFCommon.h"
#include "flang/Runtime/CUDA/common.h"
#include "flang/Runtime/allocatable.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -62,27 +63,26 @@ class CUFDeviceGlobal : public fir::impl::CUFDeviceGlobalBase<CUFDeviceGlobal> {
// Copying the device global variable into the gpu module
mlir::SymbolTable parentSymTable(mod);
- auto gpuMod =
- parentSymTable.lookup<mlir::gpu::GPUModuleOp>(cudaDeviceModuleName);
- if (gpuMod) {
- mlir::SymbolTable gpuSymTable(gpuMod);
- for (auto globalOp : mod.getOps<fir::GlobalOp>()) {
- auto attr = globalOp.getDataAttrAttr();
- if (!attr)
- continue;
- switch (attr.getValue()) {
- case cuf::DataAttribute::Device:
- case cuf::DataAttribute::Constant:
- case cuf::DataAttribute::Managed: {
- auto globalName{globalOp.getSymbol().getValue()};
- if (gpuSymTable.lookup<fir::GlobalOp>(globalName)) {
- break;
- }
- gpuSymTable.insert(globalOp->clone());
- } break;
- default:
+ auto gpuMod = cuf::getOrCreateGPUModule(mod, parentSymTable);
+ if (!gpuMod)
+ return signalPassFailure();
+ mlir::SymbolTable gpuSymTable(gpuMod);
+ for (auto globalOp : mod.getOps<fir::GlobalOp>()) {
+ auto attr = globalOp.getDataAttrAttr();
+ if (!attr)
+ continue;
+ switch (attr.getValue()) {
+ case cuf::DataAttribute::Device:
+ case cuf::DataAttribute::Constant:
+ case cuf::DataAttribute::Managed: {
+ auto globalName{globalOp.getSymbol().getValue()};
+ if (gpuSymTable.lookup<fir::GlobalOp>(globalName)) {
break;
}
+ gpuSymTable.insert(globalOp->clone());
+ } break;
+ default:
+ break;
}
}
}
diff --git a/flang/test/Fir/CUDA/cuda-implicit-device-global.f90 b/flang/test/Fir/CUDA/cuda-implicit-device-global.f90
index 82a0c5948d9cb9..18b56a491cd65f 100644
--- a/flang/test/Fir/CUDA/cuda-implicit-device-global.f90
+++ b/flang/test/Fir/CUDA/cuda-implicit-device-global.f90
@@ -25,6 +25,9 @@ // Test that global used in device function are flagged with the correct
// CHECK: fir.call @_FortranAioBeginExternalListOutput(%{{.*}}, %[[CONV]], %{{.*}}) fastmath<contract> : (i32, !fir.ref<i8>, i32) -> !fir.ref<i8>
// CHECK: fir.global linkonce @_QQcl[[SYMBOL]] {data_attr = #cuf.cuda<constant>} constant : !fir.char<1,32>
+// CHECK-LABEL: gpu.module @cuda_device_mod [#nvvm.target]
+// CHECK: fir.global linkonce @_QQclX6995815537abaf90e86ce166af128f3a
+
// -----
func.func @_QMdataPsetvalue() {
@@ -47,3 +50,6 @@ // Test that global used in device function are flagged with the correct
// CHECK: %[[CONV:.*]] = fir.convert %[[GLOBAL]] : (!fir.ref<!fir.char<1,32>>) -> !fir.ref<i8>
// CHECK: fir.call @_FortranAioBeginExternalListOutput(%{{.*}}, %[[CONV]], %{{.*}}) fastmath<contract> : (i32, !fir.ref<i8>, i32) -> !fir.ref<i8>
// CHECK: fir.global linkonce @_QQcl[[SYMBOL]] constant : !fir.char<1,32>
+
+// CHECK-LABEL: gpu.module @cuda_device_mod [#nvvm.target]
+// CHECK-NOT: fir.global linkonce @_QQclX6995815537abaf90e86ce166af128f3a