[flang-commits] [flang] 2e89e6b - [flang][cuda] Flag globals used in device function (#109460)

via flang-commits flang-commits at lists.llvm.org
Fri Sep 20 18:03:28 PDT 2024


Author: Valentin Clement (バレンタイン クレメン)
Date: 2024-09-20T18:03:25-07:00
New Revision: 2e89e6b59a32450f43416bfdfb65748ea4606875

URL: https://github.com/llvm/llvm-project/commit/2e89e6b59a32450f43416bfdfb65748ea4606875
DIFF: https://github.com/llvm/llvm-project/commit/2e89e6b59a32450f43416bfdfb65748ea4606875.diff

LOG: [flang][cuda] Flag globals used in device function (#109460)

Added: 
    flang/lib/Optimizer/Transforms/CufImplicitDeviceGlobal.cpp
    flang/test/Fir/CUDA/cuda-implicit-device-global.f90

Modified: 
    flang/include/flang/Optimizer/Transforms/Passes.h
    flang/include/flang/Optimizer/Transforms/Passes.td
    flang/lib/Optimizer/Transforms/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/Transforms/Passes.h b/flang/include/flang/Optimizer/Transforms/Passes.h
index 59266a6adfe464..fcfb8677951a2d 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.h
+++ b/flang/include/flang/Optimizer/Transforms/Passes.h
@@ -39,6 +39,7 @@ namespace fir {
 #define GEN_PASS_DECL_ASSUMEDRANKOPCONVERSION
 #define GEN_PASS_DECL_CHARACTERCONVERSION
 #define GEN_PASS_DECL_CFGCONVERSION
+#define GEN_PASS_DECL_CUFIMPLICITDEVICEGLOBAL
 #define GEN_PASS_DECL_CUFOPCONVERSION
 #define GEN_PASS_DECL_EXTERNALNAMECONVERSION
 #define GEN_PASS_DECL_MEMREFDATAFLOWOPT

diff  --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index 925ada0f9d3507..ab98591c911cdf 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -428,4 +428,12 @@ def CufOpConversion : Pass<"cuf-convert", "mlir::ModuleOp"> {
   ];
 }
 
+def CufImplicitDeviceGlobal :
+    Pass<"cuf-implicit-device-global", "mlir::ModuleOp"> {
+  let summary = "Flag globals used in device functions with data attribute";
+  let description = [{
+    For every `func.func` in the module that carries a `cuf.proc_attr` other
+    than `host`, look at the `fir.global` operations it references through
+    `fir.address_of`. Each such global that is `constant` and does not
+    already have a `data_attr` is given `data_attr = #cuf.cuda<constant>`.
+  }];
+  let dependentDialects = [
+    "cuf::CUFDialect"
+  ];
+}
+
 #endif // FLANG_OPTIMIZER_TRANSFORMS_PASSES

diff  --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt
index b32f2ef86fca44..b68e3d68b9b83e 100644
--- a/flang/lib/Optimizer/Transforms/CMakeLists.txt
+++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt
@@ -9,6 +9,7 @@ add_flang_library(FIRTransforms
   CompilerGeneratedNames.cpp
   ConstantArgumentGlobalisation.cpp
   ControlFlowConverter.cpp
+  CufImplicitDeviceGlobal.cpp
   CufOpConversion.cpp
   ArrayValueCopy.cpp
   ExternalNameConversion.cpp

diff  --git a/flang/lib/Optimizer/Transforms/CufImplicitDeviceGlobal.cpp b/flang/lib/Optimizer/Transforms/CufImplicitDeviceGlobal.cpp
new file mode 100644
index 00000000000000..206400c2ef8e53
--- /dev/null
+++ b/flang/lib/Optimizer/Transforms/CufImplicitDeviceGlobal.cpp
@@ -0,0 +1,64 @@
+//===-- CufImplicitDeviceGlobal.cpp ---------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Common/Fortran.h"
+#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/HLFIR/HLFIROps.h"
+#include "flang/Runtime/CUDA/common.h"
+#include "flang/Runtime/allocatable.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace fir {
+#define GEN_PASS_DEF_CUFIMPLICITDEVICEGLOBAL
+#include "flang/Optimizer/Transforms/Passes.h.inc"
+} // namespace fir
+
+namespace {
+
+/// Attach a CUDA data attribute to the globals referenced by a device-side
+/// procedure.
+///
+/// \p funcOp is processed only when it carries a cuf.proc_attr whose value is
+/// not `host`; any other function is left untouched. Every fir.address_of in
+/// the function (at any nesting depth) is resolved through \p symbolTable, and
+/// the referenced fir.global is tagged when it has no data attribute yet.
+///
+/// \param funcOp       Function whose global references are inspected.
+/// \param symbolTable  Symbol table of the enclosing module, used to resolve
+///                     fir.address_of symbols to fir.global operations.
+/// \param onlyConstant When true (the default), only `constant` globals are
+///                     candidates; otherwise non-constant globals are tagged
+///                     as well.
+///
+/// Constant globals receive #cuf.cuda<constant>; non-constant globals (only
+/// reachable when \p onlyConstant is false) receive #cuf.cuda<device>.
+static void prepareImplicitDeviceGlobals(mlir::func::FuncOp funcOp,
+                                         mlir::SymbolTable &symbolTable,
+                                         bool onlyConstant = true) {
+  auto cudaProcAttr{
+      funcOp->getAttrOfType<cuf::ProcAttributeAttr>(cuf::getProcAttrName())};
+  // Only device-side procedures are processed: skip functions with no CUDA
+  // procedure attribute and those explicitly marked `host`.
+  if (!cudaProcAttr || cudaProcAttr.getValue() == cuf::ProcAttribute::Host)
+    return;
+  // Use walk() rather than getBody().getOps<>(): the latter only visits the
+  // top-level operations of the body's blocks and would miss address_of ops
+  // nested inside loops, conditionals, or other region-holding operations.
+  funcOp.walk([&](fir::AddrOfOp addrOfOp) {
+    auto globalOp{symbolTable.lookup<fir::GlobalOp>(
+        addrOfOp.getSymbol().getRootReference().getValue())};
+    // Not a fir.global in this module, or an attribute is already present.
+    if (!globalOp || globalOp.getDataAttr())
+      return;
+    if (onlyConstant && !globalOp.getConstant())
+      return;
+    globalOp.setDataAttrAttr(cuf::DataAttributeAttr::get(
+        funcOp.getContext(), globalOp.getConstant()
+                                 ? cuf::DataAttribute::Constant
+                                 : cuf::DataAttribute::Device));
+  });
+}
+
+class CufImplicitDeviceGlobal
+    : public fir::impl::CufImplicitDeviceGlobalBase<CufImplicitDeviceGlobal> {
+public:
+  /// Visit every func.func nested in the module and hand it to
+  /// prepareImplicitDeviceGlobals(), which tags the globals referenced from
+  /// device-side procedures.
+  void runOnOperation() override {
+    mlir::Operation *rootOp = getOperation();
+    auto moduleOp{mlir::dyn_cast<mlir::ModuleOp>(rootOp)};
+    // The pass is declared on mlir::ModuleOp; any other root is a failure.
+    if (!moduleOp) {
+      signalPassFailure();
+      return;
+    }
+
+    // One symbol table for the module, shared by every function visited.
+    mlir::SymbolTable moduleSymbols(moduleOp);
+    moduleOp.walk([&](mlir::func::FuncOp func) {
+      prepareImplicitDeviceGlobals(func, moduleSymbols);
+    });
+  }
+};
+} // namespace

diff  --git a/flang/test/Fir/CUDA/cuda-implicit-device-global.f90 b/flang/test/Fir/CUDA/cuda-implicit-device-global.f90
new file mode 100644
index 00000000000000..c8bee3c62e6443
--- /dev/null
+++ b/flang/test/Fir/CUDA/cuda-implicit-device-global.f90
@@ -0,0 +1,49 @@
+// RUN: fir-opt --split-input-file --cuf-implicit-device-global %s | FileCheck %s
+
+// Test that globals used in device functions are flagged with the correct
+// CUDA data attribute.
+
+// Case 1: a function with cuf.proc_attr = global references a constant
+// fir.global through fir.address_of; the global gets #cuf.cuda<constant>.
+func.func @_QMdataPsetvalue() attributes {cuf.proc_attr = #cuf.cuda_proc<global>} {
+  %c6_i32 = arith.constant 6 : i32
+  %21 = fir.address_of(@_QQclX6995815537abaf90e86ce166af128f3a) : !fir.ref<!fir.char<1,32>>
+  %22 = fir.convert %21 : (!fir.ref<!fir.char<1,32>>) -> !fir.ref<i8>
+  %c14_i32 = arith.constant 14 : i32
+  %23 = fir.call @_FortranAioBeginExternalListOutput(%c6_i32, %22, %c14_i32) fastmath<contract> : (i32, !fir.ref<i8>, i32) -> !fir.ref<i8>
+  return
+}
+
+func.func private @_FortranAioBeginExternalListOutput(i32, !fir.ref<i8>, i32) -> !fir.ref<i8> attributes {fir.io, fir.runtime}
+fir.global linkonce @_QQclX6995815537abaf90e86ce166af128f3a constant : !fir.char<1,32> {
+  %0 = fir.string_lit "cuda-implicit-device-global.fir\00"(32) : !fir.char<1,32>
+  fir.has_value %0 : !fir.char<1,32>
+}
+
+// CHECK-LABEL: func.func @_QMdataPsetvalue() attributes {cuf.proc_attr = #cuf.cuda_proc<global>}
+
+// CHECK: %[[GLOBAL:.*]] = fir.address_of(@_QQcl[[SYMBOL:.*]]) : !fir.ref<!fir.char<1,32>>
+// CHECK: %[[CONV:.*]] = fir.convert %[[GLOBAL]] : (!fir.ref<!fir.char<1,32>>) -> !fir.ref<i8>
+// CHECK: fir.call @_FortranAioBeginExternalListOutput(%{{.*}}, %[[CONV]], %{{.*}}) fastmath<contract> : (i32, !fir.ref<i8>, i32) -> !fir.ref<i8>
+// CHECK: fir.global linkonce @_QQcl[[SYMBOL]] {data_attr = #cuf.cuda<constant>} constant : !fir.char<1,32>
+
+// -----
+
+// Case 2: the same code in a function with no cuf.proc_attr (host code);
+// the referenced global must be left without a data attribute.
+func.func @_QMdataPsetvalue() {
+  %c6_i32 = arith.constant 6 : i32
+  %21 = fir.address_of(@_QQclX6995815537abaf90e86ce166af128f3a) : !fir.ref<!fir.char<1,32>>
+  %22 = fir.convert %21 : (!fir.ref<!fir.char<1,32>>) -> !fir.ref<i8>
+  %c14_i32 = arith.constant 14 : i32
+  %23 = fir.call @_FortranAioBeginExternalListOutput(%c6_i32, %22, %c14_i32) fastmath<contract> : (i32, !fir.ref<i8>, i32) -> !fir.ref<i8>
+  return
+}
+
+func.func private @_FortranAioBeginExternalListOutput(i32, !fir.ref<i8>, i32) -> !fir.ref<i8> attributes {fir.io, fir.runtime}
+fir.global linkonce @_QQclX6995815537abaf90e86ce166af128f3a constant : !fir.char<1,32> {
+  %0 = fir.string_lit "cuda-implicit-device-global.fir\00"(32) : !fir.char<1,32>
+  fir.has_value %0 : !fir.char<1,32>
+}
+
+// CHECK-LABEL: func.func @_QMdataPsetvalue()
+// CHECK: %[[GLOBAL:.*]] = fir.address_of(@_QQcl[[SYMBOL:.*]]) : !fir.ref<!fir.char<1,32>>
+// CHECK: %[[CONV:.*]] = fir.convert %[[GLOBAL]] : (!fir.ref<!fir.char<1,32>>) -> !fir.ref<i8>
+// CHECK: fir.call @_FortranAioBeginExternalListOutput(%{{.*}}, %[[CONV]], %{{.*}}) fastmath<contract> : (i32, !fir.ref<i8>, i32) -> !fir.ref<i8>
+// CHECK: fir.global linkonce @_QQcl[[SYMBOL]] constant : !fir.char<1,32>


        


More information about the flang-commits mailing list