[flang-commits] [flang] [flang][cuda] Flag globals used in device function (PR #109460)
Valentin Clement バレンタイン クレメン via flang-commits
flang-commits at lists.llvm.org
Fri Sep 20 12:06:16 PDT 2024
https://github.com/clementval created https://github.com/llvm/llvm-project/pull/109460
This patch adds a pass that flags globals used in device functions with the device or constant data attribute, since they need to be available in the GPU module later on.
>From ee0d59b043d5999d62159aabec7f5dd074328429 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Thu, 19 Sep 2024 14:03:46 -0700
Subject: [PATCH] [flang][cuda] Flag globals used in device function
---
.../flang/Optimizer/Transforms/Passes.h | 1 +
.../flang/Optimizer/Transforms/Passes.td | 8 ++
flang/lib/Optimizer/Transforms/CMakeLists.txt | 1 +
.../Transforms/CufImplicitDeviceGlobal.cpp | 73 +++++++++++++++++++
.../Fir/CUDA/cuda-implicit-device-global.f90 | 49 +++++++++++++
5 files changed, 132 insertions(+)
create mode 100644 flang/lib/Optimizer/Transforms/CufImplicitDeviceGlobal.cpp
create mode 100644 flang/test/Fir/CUDA/cuda-implicit-device-global.f90
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.h b/flang/include/flang/Optimizer/Transforms/Passes.h
index 59266a6adfe464..fcfb8677951a2d 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.h
+++ b/flang/include/flang/Optimizer/Transforms/Passes.h
@@ -39,6 +39,7 @@ namespace fir {
#define GEN_PASS_DECL_ASSUMEDRANKOPCONVERSION
#define GEN_PASS_DECL_CHARACTERCONVERSION
#define GEN_PASS_DECL_CFGCONVERSION
+#define GEN_PASS_DECL_CUFIMPLICITDEVICEGLOBAL
#define GEN_PASS_DECL_CUFOPCONVERSION
#define GEN_PASS_DECL_EXTERNALNAMECONVERSION
#define GEN_PASS_DECL_MEMREFDATAFLOWOPT
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index 925ada0f9d3507..ab98591c911cdf 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -428,4 +428,12 @@ def CufOpConversion : Pass<"cuf-convert", "mlir::ModuleOp"> {
];
}
+def CufImplicitDeviceGlobal :
+ Pass<"cuf-implicit-device-global", "mlir::ModuleOp"> {
+ let summary = "Flag globals used in device function with data attribute";
+ let dependentDialects = [
+ "cuf::CUFDialect"
+ ];
+}
+
#endif // FLANG_OPTIMIZER_TRANSFORMS_PASSES
diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt
index b32f2ef86fca44..b68e3d68b9b83e 100644
--- a/flang/lib/Optimizer/Transforms/CMakeLists.txt
+++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt
@@ -9,6 +9,7 @@ add_flang_library(FIRTransforms
CompilerGeneratedNames.cpp
ConstantArgumentGlobalisation.cpp
ControlFlowConverter.cpp
+ CufImplicitDeviceGlobal.cpp
CufOpConversion.cpp
ArrayValueCopy.cpp
ExternalNameConversion.cpp
diff --git a/flang/lib/Optimizer/Transforms/CufImplicitDeviceGlobal.cpp b/flang/lib/Optimizer/Transforms/CufImplicitDeviceGlobal.cpp
new file mode 100644
index 00000000000000..5f78bf8f005765
--- /dev/null
+++ b/flang/lib/Optimizer/Transforms/CufImplicitDeviceGlobal.cpp
@@ -0,0 +1,73 @@
+//===-- CufImplicitDeviceGlobal.cpp ---------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Common/Fortran.h"
+#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/HLFIR/HLFIROps.h"
+#include "flang/Runtime/CUDA/common.h"
+#include "flang/Runtime/allocatable.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace fir {
+#define GEN_PASS_DEF_CUFIMPLICITDEVICEGLOBAL
+#include "flang/Optimizer/Transforms/Passes.h.inc"
+} // namespace fir
+
+namespace {
+
+/// Resolve \p v to the fir.global it designates.
+///
+/// Returns the fir::GlobalOp found by looking up the symbol of the
+/// fir.address_of operation that defines \p v. Returns a null op when \p v is
+/// not defined by a fir.address_of, or when the symbol lookup does not yield a
+/// fir.global.
+static fir::GlobalOp getGlobalOpFromValue(mlir::Value v) {
+  auto addressOf{mlir::dyn_cast_or_null<fir::AddrOfOp>(v.getDefiningOp())};
+  if (!addressOf)
+    return nullptr;
+  // The lookup starts from the fir.address_of itself and searches the nearest
+  // enclosing symbol table.
+  mlir::Operation *symbolOp{mlir::SymbolTable::lookupNearestSymbolFrom(
+      addressOf, addressOf.getSymbolAttr())};
+  return mlir::dyn_cast_or_null<fir::GlobalOp>(symbolOp);
+}
+
+/// Flag the globals referenced from a CUDA device procedure.
+///
+/// Does nothing unless \p funcOp carries a cuf.proc_attr whose value is not
+/// host. Otherwise, every fir.global referenced through a fir.address_of
+/// anywhere inside the function body, and that has no data attribute yet,
+/// receives one: `constant` when the global is constant, `device` otherwise.
+///
+/// \param funcOp       function whose body is scanned.
+/// \param onlyConstant when true (the default), only constant globals are
+///                     flagged; non-constant globals are left untouched.
+static void prepareImplicitDeviceGlobals(mlir::func::FuncOp funcOp,
+                                         bool onlyConstant = true) {
+  auto cudaProcAttr{
+      funcOp->getAttrOfType<cuf::ProcAttributeAttr>(cuf::getProcAttrName())};
+  if (!cudaProcAttr || cudaProcAttr.getValue() == cuf::ProcAttribute::Host)
+    return;
+  // Walk the whole body instead of only its top-level operations so that
+  // fir.address_of operations nested inside region-holding operations are
+  // visited as well.
+  funcOp.walk([&](fir::AddrOfOp addrOfOp) {
+    fir::GlobalOp globalOp{getGlobalOpFromValue(addrOfOp.getResult())};
+    if (!globalOp)
+      return;
+    const bool isConstant = globalOp.getConstant();
+    if (onlyConstant && !isConstant)
+      return;
+    // Never override a data attribute that is already present.
+    if (globalOp.getDataAttr())
+      return;
+    globalOp.setDataAttrAttr(cuf::DataAttributeAttr::get(
+        funcOp.getContext(), isConstant ? cuf::DataAttribute::Constant
+                                        : cuf::DataAttribute::Device));
+  });
+}
+
+/// Module pass that visits every func.func in the module and hands it to
+/// prepareImplicitDeviceGlobals, which flags the globals referenced from CUDA
+/// device procedures with a data attribute.
+class CufImplicitDeviceGlobal
+    : public fir::impl::CufImplicitDeviceGlobalBase<CufImplicitDeviceGlobal> {
+public:
+  /// Run the flagging on each function of the module. This only sets
+  /// attributes on existing globals, so no rewrite patterns or conversion
+  /// target are needed.
+  void runOnOperation() override {
+    // The pass is declared on mlir::ModuleOp, so getOperation() already
+    // yields the module; no cast or failure path is required.
+    mlir::ModuleOp mod{getOperation()};
+    mod.walk([](mlir::func::FuncOp funcOp) {
+      prepareImplicitDeviceGlobals(funcOp);
+    });
+  }
+};
+} // namespace
diff --git a/flang/test/Fir/CUDA/cuda-implicit-device-global.f90 b/flang/test/Fir/CUDA/cuda-implicit-device-global.f90
new file mode 100644
index 00000000000000..c8bee3c62e6443
--- /dev/null
+++ b/flang/test/Fir/CUDA/cuda-implicit-device-global.f90
@@ -0,0 +1,49 @@
+// RUN: fir-opt --split-input-file --cuf-implicit-device-global %s | FileCheck %s
+
+// Test that globals used in device functions are flagged with the correct
+// data attribute.
+
+func.func @_QMdataPsetvalue() attributes {cuf.proc_attr = #cuf.cuda_proc<global>} {
+ %c6_i32 = arith.constant 6 : i32
+ %21 = fir.address_of(@_QQclX6995815537abaf90e86ce166af128f3a) : !fir.ref<!fir.char<1,32>>
+ %22 = fir.convert %21 : (!fir.ref<!fir.char<1,32>>) -> !fir.ref<i8>
+ %c14_i32 = arith.constant 14 : i32
+ %23 = fir.call @_FortranAioBeginExternalListOutput(%c6_i32, %22, %c14_i32) fastmath<contract> : (i32, !fir.ref<i8>, i32) -> !fir.ref<i8>
+ return
+}
+
+func.func private @_FortranAioBeginExternalListOutput(i32, !fir.ref<i8>, i32) -> !fir.ref<i8> attributes {fir.io, fir.runtime}
+fir.global linkonce @_QQclX6995815537abaf90e86ce166af128f3a constant : !fir.char<1,32> {
+ %0 = fir.string_lit "cuda-implicit-device-global.fir\00"(32) : !fir.char<1,32>
+ fir.has_value %0 : !fir.char<1,32>
+}
+
+// CHECK-LABEL: func.func @_QMdataPsetvalue() attributes {cuf.proc_attr = #cuf.cuda_proc<global>}
+
+// CHECK: %[[GLOBAL:.*]] = fir.address_of(@_QQcl[[SYMBOL:.*]]) : !fir.ref<!fir.char<1,32>>
+// CHECK: %[[CONV:.*]] = fir.convert %[[GLOBAL]] : (!fir.ref<!fir.char<1,32>>) -> !fir.ref<i8>
+// CHECK: fir.call @_FortranAioBeginExternalListOutput(%{{.*}}, %[[CONV]], %{{.*}}) fastmath<contract> : (i32, !fir.ref<i8>, i32) -> !fir.ref<i8>
+// CHECK: fir.global linkonce @_QQcl[[SYMBOL]] {data_attr = #cuf.cuda<constant>} constant : !fir.char<1,32>
+
+// -----
+
+func.func @_QMdataPsetvalue() {
+ %c6_i32 = arith.constant 6 : i32
+ %21 = fir.address_of(@_QQclX6995815537abaf90e86ce166af128f3a) : !fir.ref<!fir.char<1,32>>
+ %22 = fir.convert %21 : (!fir.ref<!fir.char<1,32>>) -> !fir.ref<i8>
+ %c14_i32 = arith.constant 14 : i32
+ %23 = fir.call @_FortranAioBeginExternalListOutput(%c6_i32, %22, %c14_i32) fastmath<contract> : (i32, !fir.ref<i8>, i32) -> !fir.ref<i8>
+ return
+}
+
+func.func private @_FortranAioBeginExternalListOutput(i32, !fir.ref<i8>, i32) -> !fir.ref<i8> attributes {fir.io, fir.runtime}
+fir.global linkonce @_QQclX6995815537abaf90e86ce166af128f3a constant : !fir.char<1,32> {
+ %0 = fir.string_lit "cuda-implicit-device-global.fir\00"(32) : !fir.char<1,32>
+ fir.has_value %0 : !fir.char<1,32>
+}
+
+// CHECK-LABEL: func.func @_QMdataPsetvalue()
+// CHECK: %[[GLOBAL:.*]] = fir.address_of(@_QQcl[[SYMBOL:.*]]) : !fir.ref<!fir.char<1,32>>
+// CHECK: %[[CONV:.*]] = fir.convert %[[GLOBAL]] : (!fir.ref<!fir.char<1,32>>) -> !fir.ref<i8>
+// CHECK: fir.call @_FortranAioBeginExternalListOutput(%{{.*}}, %[[CONV]], %{{.*}}) fastmath<contract> : (i32, !fir.ref<i8>, i32) -> !fir.ref<i8>
+// CHECK: fir.global linkonce @_QQcl[[SYMBOL]] constant : !fir.char<1,32>
More information about the flang-commits
mailing list