[flang-commits] [flang] [flang][cuda] Distinguish constant fir.global from globals with a #cuf.cuda<constant> attribute (PR #118912)

Renaud Kauffmann via flang-commits flang-commits at lists.llvm.org
Thu Dec 5 18:23:57 PST 2024


https://github.com/Renaud-K created https://github.com/llvm/llvm-project/pull/118912

1. In `CufOpConversion`, `isDeviceGlobal` was renamed `isRegisteredDeviceGlobal` and moved to the common file. `isRegisteredDeviceGlobal` excludes constant `fir.global` operations from registration. This is to avoid calls to `_FortranACUFGetDeviceAddress` on globals which do not have any symbols in the runtime. This was done for `_FortranACUFRegisterVariable` in #118582, but it also needs to be done here after #118591.
2. `CufDeviceGlobal` no longer adds the `#cuf.cuda<constant>` attribute to the constant global. As discussed in #118582, a module variable with the `#cuf.cuda<constant>` attribute is not a compile-time constant. Yet the compile-time constant also needs to be copied into the GPU module. The candidates for copying into the GPU module are therefore:

- the globals needing registration, regardless of their uses in device code (they can be referred to in host code as well)
- the compile-time constants, when they are used in device code

3. The registration of "constant" module device variables (`#cuf.cuda<constant>`) can be restored in `CufAddConstructor`.

>From 412971a4bc5df62506cb96571642f7027eeaf26a Mon Sep 17 00:00:00 2001
From: Renaud-K <rkauffmann at nvidia.com>
Date: Thu, 5 Dec 2024 17:58:47 -0800
Subject: [PATCH] Distinguish constant fir.global from globals with a
 #cuf.cuda<constant> attribute

---
 .../flang/Optimizer/Transforms/CUFCommon.h    |  2 +
 .../Transforms/CUFAddConstructor.cpp          |  3 +-
 flang/lib/Optimizer/Transforms/CUFCommon.cpp  | 11 ++++
 .../Optimizer/Transforms/CUFDeviceGlobal.cpp  | 63 ++++++++-----------
 .../Optimizer/Transforms/CUFOpConversion.cpp  | 13 +---
 flang/test/Fir/CUDA/cuda-constructor-2.f90    |  2 +-
 flang/test/Fir/CUDA/cuda-global-addr.mlir     | 32 +++++++++-
 .../Fir/CUDA/cuda-implicit-device-global.f90  | 12 ++--
 8 files changed, 82 insertions(+), 56 deletions(-)

diff --git a/flang/include/flang/Optimizer/Transforms/CUFCommon.h b/flang/include/flang/Optimizer/Transforms/CUFCommon.h
index f019d1893bda4c..df1b709dc86080 100644
--- a/flang/include/flang/Optimizer/Transforms/CUFCommon.h
+++ b/flang/include/flang/Optimizer/Transforms/CUFCommon.h
@@ -9,6 +9,7 @@
 #ifndef FORTRAN_OPTIMIZER_TRANSFORMS_CUFCOMMON_H_
 #define FORTRAN_OPTIMIZER_TRANSFORMS_CUFCOMMON_H_
 
+#include "flang/Optimizer/Dialect/FIROps.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/IR/BuiltinOps.h"
 
@@ -21,6 +22,7 @@ mlir::gpu::GPUModuleOp getOrCreateGPUModule(mlir::ModuleOp mod,
                                             mlir::SymbolTable &symTab);
 
 bool isInCUDADeviceContext(mlir::Operation *op);
+bool isRegisteredDeviceGlobal(fir::GlobalOp op);
 
 } // namespace cuf
 
diff --git a/flang/lib/Optimizer/Transforms/CUFAddConstructor.cpp b/flang/lib/Optimizer/Transforms/CUFAddConstructor.cpp
index 73a46843f0320b..9591f48c5d4177 100644
--- a/flang/lib/Optimizer/Transforms/CUFAddConstructor.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFAddConstructor.cpp
@@ -106,7 +106,8 @@ struct CUFAddConstructor
 
         mlir::func::FuncOp func;
         switch (attr.getValue()) {
-        case cuf::DataAttribute::Device: {
+        case cuf::DataAttribute::Device:
+        case cuf::DataAttribute::Constant: {
           func = fir::runtime::getRuntimeFunc<mkRTKey(CUFRegisterVariable)>(
               loc, builder);
           auto fTy = func.getFunctionType();
diff --git a/flang/lib/Optimizer/Transforms/CUFCommon.cpp b/flang/lib/Optimizer/Transforms/CUFCommon.cpp
index 5b7631bbacb5f2..bbe33217e8f455 100644
--- a/flang/lib/Optimizer/Transforms/CUFCommon.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFCommon.cpp
@@ -43,3 +43,14 @@ bool cuf::isInCUDADeviceContext(mlir::Operation *op) {
   }
   return false;
 }
+
+bool cuf::isRegisteredDeviceGlobal(fir::GlobalOp op) {
+  if (op.getConstant())
+    return false;
+  auto attr = op.getDataAttr();
+  if (attr && (*attr == cuf::DataAttribute::Device ||
+               *attr == cuf::DataAttribute::Managed ||
+               *attr == cuf::DataAttribute::Constant))
+    return true;
+  return false;
+}
diff --git a/flang/lib/Optimizer/Transforms/CUFDeviceGlobal.cpp b/flang/lib/Optimizer/Transforms/CUFDeviceGlobal.cpp
index 714b0b291be1e9..18150c4e595d40 100644
--- a/flang/lib/Optimizer/Transforms/CUFDeviceGlobal.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFDeviceGlobal.cpp
@@ -18,6 +18,7 @@
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/DenseSet.h"
 
 namespace fir {
 #define GEN_PASS_DEF_CUFDEVICEGLOBAL
@@ -27,36 +28,30 @@ namespace fir {
 namespace {
 
 static void processAddrOfOp(fir::AddrOfOp addrOfOp,
-                            mlir::SymbolTable &symbolTable, bool onlyConstant) {
+                            mlir::SymbolTable &symbolTable,
+                            llvm::DenseSet<fir::GlobalOp> &candidates) {
   if (auto globalOp = symbolTable.lookup<fir::GlobalOp>(
           addrOfOp.getSymbol().getRootReference().getValue())) {
-    bool isCandidate{(onlyConstant ? globalOp.getConstant() : true) &&
-                     !globalOp.getDataAttr()};
-    if (isCandidate)
-      globalOp.setDataAttrAttr(cuf::DataAttributeAttr::get(
-          addrOfOp.getContext(), globalOp.getConstant()
-                                     ? cuf::DataAttribute::Constant
-                                     : cuf::DataAttribute::Device));
+    // TO DO: limit candidates to non-scalars. Scalars appear to have been
+    // folded in already.
+    if (globalOp.getConstant()) {
+      candidates.insert(globalOp);
+    }
   }
 }
 
-static void prepareImplicitDeviceGlobals(mlir::func::FuncOp funcOp,
-                                         mlir::SymbolTable &symbolTable,
-                                         bool onlyConstant = true) {
+static void
+prepareImplicitDeviceGlobals(mlir::func::FuncOp funcOp,
+                             mlir::SymbolTable &symbolTable,
+                             llvm::DenseSet<fir::GlobalOp> &candidates) {
+
   auto cudaProcAttr{
       funcOp->getAttrOfType<cuf::ProcAttributeAttr>(cuf::getProcAttrName())};
-  if (!cudaProcAttr || cudaProcAttr.getValue() == cuf::ProcAttribute::Host) {
-    // Look for globlas in CUF KERNEL DO operations.
-    for (auto cufKernelOp : funcOp.getBody().getOps<cuf::KernelOp>()) {
-      cufKernelOp.walk([&](fir::AddrOfOp addrOfOp) {
-        processAddrOfOp(addrOfOp, symbolTable, onlyConstant);
-      });
-    }
-    return;
+  if (cudaProcAttr && cudaProcAttr.getValue() != cuf::ProcAttribute::Host) {
+    funcOp.walk([&](fir::AddrOfOp addrOfOp) {
+      processAddrOfOp(addrOfOp, symbolTable, candidates);
+    });
   }
-  funcOp.walk([&](fir::AddrOfOp addrOfOp) {
-    processAddrOfOp(addrOfOp, symbolTable, onlyConstant);
-  });
 }
 
 class CUFDeviceGlobal : public fir::impl::CUFDeviceGlobalBase<CUFDeviceGlobal> {
@@ -67,9 +62,10 @@ class CUFDeviceGlobal : public fir::impl::CUFDeviceGlobalBase<CUFDeviceGlobal> {
     if (!mod)
       return signalPassFailure();
 
+    llvm::DenseSet<fir::GlobalOp> candidates;
     mlir::SymbolTable symTable(mod);
     mod.walk([&](mlir::func::FuncOp funcOp) {
-      prepareImplicitDeviceGlobals(funcOp, symTable);
+      prepareImplicitDeviceGlobals(funcOp, symTable, candidates);
       return mlir::WalkResult::advance();
     });
 
@@ -80,22 +76,15 @@ class CUFDeviceGlobal : public fir::impl::CUFDeviceGlobalBase<CUFDeviceGlobal> {
       return signalPassFailure();
     mlir::SymbolTable gpuSymTable(gpuMod);
     for (auto globalOp : mod.getOps<fir::GlobalOp>()) {
-      auto attr = globalOp.getDataAttrAttr();
-      if (!attr)
-        continue;
-      switch (attr.getValue()) {
-      case cuf::DataAttribute::Device:
-      case cuf::DataAttribute::Constant:
-      case cuf::DataAttribute::Managed: {
-        auto globalName{globalOp.getSymbol().getValue()};
-        if (gpuSymTable.lookup<fir::GlobalOp>(globalName)) {
-          break;
-        }
-        gpuSymTable.insert(globalOp->clone());
-      } break;
-      default:
+      if (cuf::isRegisteredDeviceGlobal(globalOp))
+        candidates.insert(globalOp);
+    }
+    for (auto globalOp : candidates) {
+      auto globalName{globalOp.getSymbol().getValue()};
+      if (gpuSymTable.lookup<fir::GlobalOp>(globalName)) {
         break;
       }
+      gpuSymTable.insert(globalOp->clone());
     }
   }
 };
diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
index 7f6843d66d39f8..1df82e6accfeda 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -81,15 +81,6 @@ static bool hasDoubleDescriptors(OpTy op) {
   return false;
 }
 
-bool isDeviceGlobal(fir::GlobalOp op) {
-  auto attr = op.getDataAttr();
-  if (attr && (*attr == cuf::DataAttribute::Device ||
-               *attr == cuf::DataAttribute::Managed ||
-               *attr == cuf::DataAttribute::Constant))
-    return true;
-  return false;
-}
-
 static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter,
                                    mlir::Location loc, mlir::Type toTy,
                                    mlir::Value val) {
@@ -388,7 +379,7 @@ struct DeclareOpConversion : public mlir::OpRewritePattern<fir::DeclareOp> {
     if (auto addrOfOp = op.getMemref().getDefiningOp<fir::AddrOfOp>()) {
       if (auto global = symTab.lookup<fir::GlobalOp>(
               addrOfOp.getSymbol().getRootReference().getValue())) {
-        if (isDeviceGlobal(global)) {
+        if (cuf::isRegisteredDeviceGlobal(global)) {
           rewriter.setInsertionPointAfter(addrOfOp);
           auto mod = op->getParentOfType<mlir::ModuleOp>();
           fir::FirOpBuilder builder(rewriter, mod);
@@ -833,7 +824,7 @@ class CUFOpConversion : public fir::impl::CUFOpConversionBase<CUFOpConversion> {
                 addrOfOp.getSymbol().getRootReference().getValue())) {
           if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(global.getType())))
             return true;
-          if (isDeviceGlobal(global))
+          if (cuf::isRegisteredDeviceGlobal(global))
             return false;
         }
       }
diff --git a/flang/test/Fir/CUDA/cuda-constructor-2.f90 b/flang/test/Fir/CUDA/cuda-constructor-2.f90
index 29efdb083878af..eb118ccee311c0 100644
--- a/flang/test/Fir/CUDA/cuda-constructor-2.f90
+++ b/flang/test/Fir/CUDA/cuda-constructor-2.f90
@@ -39,7 +39,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<!llvm.ptr, dense<
 // CHECK-NOT: fir.call @_FortranACUFRegisterVariable
 
 module attributes {dlti.dl_spec = #dlti.dl_spec<i8 = dense<8> : vector<2xi64>, i16 = dense<16> : vector<2xi64>, i1 = dense<8> : vector<2xi64>, !llvm.ptr = dense<64> : vector<4xi64>, f80 = dense<128> : vector<2xi64>, i128 = dense<128> : vector<2xi64>, i64 = dense<64> : vector<2xi64>, !llvm.ptr<271> = dense<32> : vector<4xi64>, !llvm.ptr<272> = dense<64> : vector<4xi64>, f128 = dense<128> : vector<2xi64>, !llvm.ptr<270> = dense<32> : vector<4xi64>, f16 = dense<16> : vector<2xi64>, f64 = dense<64> : vector<2xi64>, i32 = dense<32> : vector<2xi64>, "dlti.stack_alignment" = 128 : i64, "dlti.endianness" = "little">, fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", gpu.container_module, llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", llvm.ident = "flang version 20.0.0 (https://github.com/llvm/llvm-project.git 3372303188df0f7f8ac26e7ab610cf8b0f716d42)", llvm.target_triple = "x86_64-unknown-linux-gnu"} {
-  fir.global @_QMiso_c_bindingECc_int {data_attr = #cuf.cuda<constant>} constant : i32
+  fir.global @_QMiso_c_bindingECc_int constant : i32
   
 
   fir.type_info @_QM__fortran_builtinsT__builtin_c_ptr noinit nodestroy nofinal : !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>
diff --git a/flang/test/Fir/CUDA/cuda-global-addr.mlir b/flang/test/Fir/CUDA/cuda-global-addr.mlir
index 2baead4010f5c5..94ee74736f6508 100644
--- a/flang/test/Fir/CUDA/cuda-global-addr.mlir
+++ b/flang/test/Fir/CUDA/cuda-global-addr.mlir
@@ -1,4 +1,4 @@
-// RUN: fir-opt --cuf-convert %s | FileCheck %s
+// RUN: fir-opt --split-input-file --cuf-convert %s | FileCheck %s
 
 module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<f80, dense<128> : vector<2xi64>>, #dlti.dl_entry<i128, dense<128> : vector<2xi64>>, #dlti.dl_entry<i64, dense<64> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr<272>, dense<64> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<271>, dense<32> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<270>, dense<32> : vector<4xi64>>, #dlti.dl_entry<f128, dense<128> : vector<2xi64>>, #dlti.dl_entry<f64, dense<64> : vector<2xi64>>, #dlti.dl_entry<f16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i32, dense<32> : vector<2xi64>>, #dlti.dl_entry<i16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i8, dense<8> : vector<2xi64>>, #dlti.dl_entry<i1, dense<8> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr, dense<64> : vector<4xi64>>, #dlti.dl_entry<"dlti.endianness", "little">, #dlti.dl_entry<"dlti.stack_alignment", 128 : i64>>} {
 fir.global @_QMmod1Eadev {data_attr = #cuf.cuda<device>} : !fir.array<10xi32> {
@@ -34,3 +34,33 @@ func.func @_QQmain() attributes {fir.bindc_name = "test"} {
 // CHECK: %[[ARRAY_COOR:.*]] = fir.array_coor %[[DECL]](%{{.*}}) %c4{{.*}} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
 // CHECK: %[[ARRAY_COOR_PTR:.*]] = fir.convert %[[ARRAY_COOR]] : (!fir.ref<i32>) -> !fir.llvm_ptr<i8>
 // CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%[[ARRAY_COOR_PTR]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
+
+// -----
+
+module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<f80, dense<128> : vector<2xi64>>, #dlti.dl_entry<i128, dense<128> : vector<2xi64>>, #dlti.dl_entry<i64, dense<64> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr<272>, dense<64> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<271>, dense<32> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<270>, dense<32> : vector<4xi64>>, #dlti.dl_entry<f128, dense<128> : vector<2xi64>>, #dlti.dl_entry<f64, dense<64> : vector<2xi64>>, #dlti.dl_entry<f16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i32, dense<32> : vector<2xi64>>, #dlti.dl_entry<i16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i8, dense<8> : vector<2xi64>>, #dlti.dl_entry<i1, dense<8> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr, dense<64> : vector<4xi64>>, #dlti.dl_entry<"dlti.endianness", "little">, #dlti.dl_entry<"dlti.stack_alignment", 128 : i64>>} {
+
+  fir.global @_QMdevmodEdarray {data_attr = #cuf.cuda<device>} : !fir.box<!fir.heap<!fir.array<?xf32>>> {
+    %c0 = arith.constant 0 : index
+    %0 = fir.zero_bits !fir.heap<!fir.array<?xf32>>
+    %1 = fir.shape %c0 : (index) -> !fir.shape<1>
+    %2 = fir.embox %0(%1) {allocator_idx = 2 : i32} : (!fir.heap<!fir.array<?xf32>>, !fir.shape<1>) -> !fir.box<!fir.heap<!fir.array<?xf32>>>
+    fir.has_value %2 : !fir.box<!fir.heap<!fir.array<?xf32>>>
+  }
+  func.func @_QQmain() attributes {fir.bindc_name = "arraysize"} {
+    %0 = fir.address_of(@_QMiso_c_bindingECc_int) : !fir.ref<i32>
+    %1 = fir.declare %0 {fortran_attrs = #fir.var_attrs<parameter>, uniq_name = "_QMiso_c_bindingECc_int"} : (!fir.ref<i32>) -> !fir.ref<i32>
+    %2 = fir.address_of(@_QMdevmodEdarray) : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
+    %3 = fir.declare %2 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMdevmodEdarray"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
+    %4 = fir.alloca i32 {bindc_name = "exp", uniq_name = "_QFEexp"}
+    %5 = fir.declare %4 {uniq_name = "_QFEexp"} : (!fir.ref<i32>) -> !fir.ref<i32>
+    %6 = fir.alloca i32 {bindc_name = "hsize", uniq_name = "_QFEhsize"}
+    %7 = fir.declare %6 {uniq_name = "_QFEhsize"} : (!fir.ref<i32>) -> !fir.ref<i32>
+    return
+  }
+  fir.global @_QMiso_c_bindingECc_int constant : i32
+}
+
+// We cannot call _FortranACUFGetDeviceAddress on a constant global. 
+// There is no symbol for it and the call would result into an unresolved reference.
+// CHECK-NOT: fir.call {{.*}}GetDeviceAddress
+
diff --git a/flang/test/Fir/CUDA/cuda-implicit-device-global.f90 b/flang/test/Fir/CUDA/cuda-implicit-device-global.f90
index 772e2696171a63..5a4cc8590f4168 100644
--- a/flang/test/Fir/CUDA/cuda-implicit-device-global.f90
+++ b/flang/test/Fir/CUDA/cuda-implicit-device-global.f90
@@ -23,7 +23,7 @@ // Test that global used in device function are flagged with the correct
 // CHECK: %[[GLOBAL:.*]] = fir.address_of(@_QQcl[[SYMBOL:.*]]) : !fir.ref<!fir.char<1,32>>
 // CHECK: %[[CONV:.*]] = fir.convert %[[GLOBAL]] : (!fir.ref<!fir.char<1,32>>) -> !fir.ref<i8>
 // CHECK: fir.call @_FortranAioBeginExternalListOutput(%{{.*}}, %[[CONV]], %{{.*}}) fastmath<contract> : (i32, !fir.ref<i8>, i32) -> !fir.ref<i8>
-// CHECK: fir.global linkonce @_QQcl[[SYMBOL]] {data_attr = #cuf.cuda<constant>} constant : !fir.char<1,32>
+// CHECK: fir.global linkonce @_QQcl[[SYMBOL]] constant : !fir.char<1,32>
 
 // CHECK-LABEL: gpu.module @cuda_device_mod
 // CHECK: fir.global linkonce @_QQclX6995815537abaf90e86ce166af128f3a
@@ -99,10 +99,11 @@ // Test that global used in device function are flagged with the correct
   fir.has_value %0 : !fir.char<1,11>
 }
 
-// CHECK: fir.global linkonce @_QQclX5465737420504153534544 {data_attr = #cuf.cuda<constant>} constant : !fir.char<1,11>
+// Checking that a constant fir.global that is only used in host code is not copied over to the device
+// CHECK: fir.global linkonce @_QQclX5465737420504153534544 constant : !fir.char<1,11>
 
 // CHECK-LABEL: gpu.module @cuda_device_mod
-// CHECK: fir.global linkonce @_QQclX5465737420504153534544 {data_attr = #cuf.cuda<constant>} constant
+// CHECK-NOT: fir.global linkonce @_QQclX5465737420504153534544
 
 // -----
 
@@ -140,7 +141,8 @@ // Test that global used in device function are flagged with the correct
 }
 func.func private @_FortranAioEndIoStatement(!fir.ref<i8>) -> i32 attributes {fir.io, fir.runtime}
 
-// CHECK: fir.global linkonce @_QQclX5465737420504153534544 {data_attr = #cuf.cuda<constant>} constant : !fir.char<1,11>
+// Checking that a constant fir.global that is used in device code is copied over to the device
+// CHECK: fir.global linkonce @_QQclX5465737420504153534544 constant : !fir.char<1,11>
 
 // CHECK-LABEL: gpu.module @cuda_device_mod 
-// CHECK: fir.global linkonce @_QQclX5465737420504153534544 {data_attr = #cuf.cuda<constant>} constant
+// CHECK: fir.global linkonce @_QQclX5465737420504153534544 constant



More information about the flang-commits mailing list