[flang] [llvm] [mlir] [flang][cuda] Support non-allocatable module-level managed variables (PR #188526)
Zhen Wang via llvm-commits
llvm-commits at lists.llvm.org
Wed Mar 25 20:13:25 PDT 2026
https://github.com/wangzpgi updated https://github.com/llvm/llvm-project/pull/188526
>From e40bb2f7c6209979a7821c413f7d8293fac39ca1 Mon Sep 17 00:00:00 2001
From: Zhen Wang <zhenw at nvidia.com>
Date: Mon, 23 Mar 2026 14:59:35 -0700
Subject: [PATCH 1/8] Initial draft of supporting non-allocatable managed var
---
.../Transforms/CUDA/CUFAddConstructor.cpp | 77 +++++++++++++++----
.../Transforms/CUDA/CUFOpConversionLate.cpp | 24 +++++-
flang/test/Fir/CUDA/cuda-constructor-2.f90 | 33 ++++++++
flang/test/Fir/CUDA/cuda-device-address.mlir | 37 +++++++++
4 files changed, 155 insertions(+), 16 deletions(-)
diff --git a/flang/lib/Optimizer/Transforms/CUDA/CUFAddConstructor.cpp b/flang/lib/Optimizer/Transforms/CUDA/CUFAddConstructor.cpp
index baa8e591ee162..cd471c7713e53 100644
--- a/flang/lib/Optimizer/Transforms/CUDA/CUFAddConstructor.cpp
+++ b/flang/lib/Optimizer/Transforms/CUDA/CUFAddConstructor.cpp
@@ -41,6 +41,42 @@ namespace {
static constexpr llvm::StringRef cudaFortranCtorName{
"__cudaFortranConstructor"};
+static constexpr llvm::StringRef managedPtrSuffix{".managed.ptr"};
+
+/// Create an 8-byte pointer global in the __nv_managed_data__ section.
+/// The CUDA runtime fills this pointer with the unified memory address
+/// during __cudaRegisterManagedVar.
+static fir::GlobalOp createManagedPointerGlobal(fir::FirOpBuilder &builder,
+ mlir::ModuleOp mod,
+ fir::GlobalOp globalOp) {
+ auto *ctx = mod.getContext();
+ std::string ptrGlobalName =
+ (globalOp.getSymName() + managedPtrSuffix).str();
+ auto ptrTy =
+ fir::LLVMPointerType::get(ctx, mlir::IntegerType::get(ctx, 8));
+
+ mlir::OpBuilder::InsertionGuard guard(builder);
+ builder.setInsertionPointAfter(globalOp);
+
+ llvm::SmallVector<mlir::NamedAttribute> attrs;
+ attrs.push_back(mlir::NamedAttribute(
+ mlir::StringAttr::get(ctx, "section"),
+ mlir::StringAttr::get(ctx, "__nv_managed_data__")));
+
+ mlir::DenseElementsAttr initAttr = {};
+ auto ptrGlobal = fir::GlobalOp::create(
+ builder, globalOp.getLoc(), ptrGlobalName, /*isConstant=*/false,
+ /*isTarget=*/false, ptrTy, initAttr,
+ /*linkName=*/builder.createInternalLinkage(), attrs);
+
+  mlir::Region &region = ptrGlobal.getRegion();
+  mlir::Block *block = builder.createBlock(&region);
+ builder.setInsertionPointToStart(block);
+ mlir::Value zero = fir::ZeroOp::create(builder, globalOp.getLoc(), ptrTy);
+ fir::HasValueOp::create(builder, globalOp.getLoc(), zero);
+
+ return ptrGlobal;
+}
struct CUFAddConstructor
: public fir::impl::CUFAddConstructorBase<CUFAddConstructor> {
@@ -108,19 +144,15 @@ struct CUFAddConstructor
if (!attr)
continue;
- if (attr.getValue() == cuf::DataAttribute::Managed &&
- !mlir::isa<fir::BaseBoxType>(globalOp.getType()))
- TODO(loc, "registration of non-allocatable managed variables");
+ bool isNonAllocManagedGlobal =
+ attr.getValue() == cuf::DataAttribute::Managed &&
+ !mlir::isa<fir::BaseBoxType>(globalOp.getType());
mlir::func::FuncOp func;
switch (attr.getValue()) {
case cuf::DataAttribute::Device:
case cuf::DataAttribute::Constant:
case cuf::DataAttribute::Managed: {
- func = fir::runtime::getRuntimeFunc<mkRTKey(CUFRegisterVariable)>(
- loc, builder);
- auto fTy = func.getFunctionType();
-
// Global variable name
std::string gblNameStr = globalOp.getSymbol().getValue().str();
gblNameStr += '\0';
@@ -141,13 +173,30 @@ struct CUFAddConstructor
}
auto sizeVal = builder.createIntegerConstant(loc, idxTy, *size);
- // Global variable address
- mlir::Value addr = fir::AddrOfOp::create(
- builder, loc, globalOp.resultType(), globalOp.getSymbol());
-
- llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
- builder, loc, fTy, registeredMod, addr, gblName, sizeVal)};
- fir::CallOp::create(builder, loc, func, args);
+ if (isNonAllocManagedGlobal) {
+ // Non-allocatable managed globals use nvcc-style unified memory:
+ // create an 8-byte pointer global in __nv_managed_data__ and
+ // register with __cudaRegisterManagedVar via the runtime wrapper.
+ fir::GlobalOp ptrGlobal =
+ createManagedPointerGlobal(builder, mod, globalOp);
+ func = fir::runtime::getRuntimeFunc<
+ mkRTKey(CUFRegisterManagedVariable)>(loc, builder);
+ auto fTy = func.getFunctionType();
+ mlir::Value addr = fir::AddrOfOp::create(
+ builder, loc, ptrGlobal.resultType(), ptrGlobal.getSymbol());
+ llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
+ builder, loc, fTy, registeredMod, addr, gblName, sizeVal)};
+ fir::CallOp::create(builder, loc, func, args);
+ } else {
+ func = fir::runtime::getRuntimeFunc<mkRTKey(CUFRegisterVariable)>(
+ loc, builder);
+ auto fTy = func.getFunctionType();
+ mlir::Value addr = fir::AddrOfOp::create(
+ builder, loc, globalOp.resultType(), globalOp.getSymbol());
+ llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
+ builder, loc, fTy, registeredMod, addr, gblName, sizeVal)};
+ fir::CallOp::create(builder, loc, func, args);
+ }
} break;
default:
break;
diff --git a/flang/lib/Optimizer/Transforms/CUDA/CUFOpConversionLate.cpp b/flang/lib/Optimizer/Transforms/CUDA/CUFOpConversionLate.cpp
index fe459712a6ba4..73a385591437c 100644
--- a/flang/lib/Optimizer/Transforms/CUDA/CUFOpConversionLate.cpp
+++ b/flang/lib/Optimizer/Transforms/CUDA/CUFOpConversionLate.cpp
@@ -13,6 +13,7 @@
#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "flang/Runtime/CUDA/common.h"
#include "flang/Runtime/CUDA/descriptor.h"
@@ -48,6 +49,8 @@ static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter,
return val;
}
+static constexpr llvm::StringRef managedPtrSuffix{".managed.ptr"};
+
struct CUFDeviceAddressOpConversion
: public mlir::OpRewritePattern<cuf::DeviceAddressOp> {
using OpRewritePattern::OpRewritePattern;
@@ -59,10 +62,27 @@ struct CUFDeviceAddressOpConversion
mlir::LogicalResult
matchAndRewrite(cuf::DeviceAddressOp op,
mlir::PatternRewriter &rewriter) const override {
- if (auto global = symTab.lookup<fir::GlobalOp>(
- op.getHostSymbol().getRootReference().getValue())) {
+ auto symName = op.getHostSymbol().getRootReference().getValue();
+ if (auto global = symTab.lookup<fir::GlobalOp>(symName)) {
auto mod = op->getParentOfType<mlir::ModuleOp>();
mlir::Location loc = op.getLoc();
+
+ // For non-allocatable managed globals, CUFAddConstructor created a
+ // companion pointer global (@sym.managed.ptr) that holds the unified
+ // memory address. Load from it instead of calling CUFGetDeviceAddress.
+ std::string ptrGlobalName = (symName + managedPtrSuffix).str();
+ if (auto ptrGlobal =
+ symTab.lookup<fir::GlobalOp>(ptrGlobalName)) {
+ auto ptrRef = fir::AddrOfOp::create(rewriter, loc,
+ ptrGlobal.resultType(),
+ ptrGlobal.getSymbol());
+ auto rawPtr = fir::LoadOp::create(rewriter, loc, ptrRef);
+ auto converted =
+ fir::ConvertOp::create(rewriter, loc, op.getType(), rawPtr);
+ rewriter.replaceOp(op, converted);
+ return success();
+ }
+
auto hostAddr = fir::AddrOfOp::create(
rewriter, loc, fir::ReferenceType::get(global.getType()),
op.getHostSymbol());
diff --git a/flang/test/Fir/CUDA/cuda-constructor-2.f90 b/flang/test/Fir/CUDA/cuda-constructor-2.f90
index f21d8f9c37637..0fdd5966550e9 100644
--- a/flang/test/Fir/CUDA/cuda-constructor-2.f90
+++ b/flang/test/Fir/CUDA/cuda-constructor-2.f90
@@ -78,3 +78,36 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<i8 = dense<8> : vector<2xi64>, i
// CHECK: llvm.func internal @__cudaFortranConstructor()
// CHECK: fir.address_of(@_QMmEa00)
// CHECK: fir.call @_FortranACUFRegisterVariable
+
+// -----
+
+// Non-allocatable managed global: should create pointer global in
+// __nv_managed_data__ and register with CUFRegisterManagedVariable.
+//
+// Fortran source:
+// module test
+// integer*4, managed :: manx(100)
+// contains
+// attributes(global) subroutine kernel()
+// end subroutine
+// end module
+
+module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<!llvm.ptr, dense<64> : vector<4xi64>>, #dlti.dl_entry<i64, dense<64> : vector<2xi64>>, #dlti.dl_entry<i32, dense<32> : vector<2xi64>>, #dlti.dl_entry<i8, dense<8> : vector<2xi64>>, #dlti.dl_entry<i1, dense<8> : vector<2xi64>>, #dlti.dl_entry<f64, dense<64> : vector<2xi64>>, #dlti.dl_entry<f32, dense<32> : vector<2xi64>>, #dlti.dl_entry<"dlti.endianness", "little">, #dlti.dl_entry<"dlti.stack_alignment", 128 : i64>>, fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", gpu.container_module, llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", llvm.target_triple = "x86_64-unknown-linux-gnu"} {
+
+ fir.global @_QMtestEmanx {data_attr = #cuf.cuda<managed>} : !fir.array<100xi32> {
+ %0 = fir.zero_bits !fir.array<100xi32>
+ fir.has_value %0 : !fir.array<100xi32>
+ }
+
+ gpu.module @cuda_device_mod {
+ }
+}
+
+// Pointer global should be created with section attribute.
+// CHECK: fir.global internal @_QMtestEmanx.managed.ptr {section = "__nv_managed_data__"} : !fir.llvm_ptr<i8>
+// CHECK: fir.zero_bits !fir.llvm_ptr<i8>
+
+// Constructor should register with CUFRegisterManagedVariable.
+// CHECK: llvm.func internal @__cudaFortranConstructor()
+// CHECK: fir.address_of(@_QMtestEmanx.managed.ptr) : !fir.ref<!fir.llvm_ptr<i8>>
+// CHECK: fir.call @_FortranACUFRegisterManagedVariable
diff --git a/flang/test/Fir/CUDA/cuda-device-address.mlir b/flang/test/Fir/CUDA/cuda-device-address.mlir
index e86208321b8ab..d87e5467ab313 100644
--- a/flang/test/Fir/CUDA/cuda-device-address.mlir
+++ b/flang/test/Fir/CUDA/cuda-device-address.mlir
@@ -12,3 +12,40 @@ func.func @_QPxa(%arg0: !fir.ref<!fir.array<?xi32>> {cuf.data_attr = #cuf.cuda<d
// CHECK-LABEL: func.func @_QPxa
// CHECK: fir.call @_FortranACUFGetDeviceAddress
+
+// -----
+
+// Non-allocatable managed global with companion pointer global:
+// cuf.device_address should load from the pointer global instead of
+// calling CUFGetDeviceAddress.
+//
+// Fortran source:
+// module test
+// integer*4, managed :: manx(100)
+// end module
+// subroutine user()
+// use test
+// manx(1) = 42
+// end subroutine
+
+fir.global @_QMtestEmanx {data_attr = #cuf.cuda<managed>} : !fir.array<100xi32> {
+ %0 = fir.zero_bits !fir.array<100xi32>
+ fir.has_value %0 : !fir.array<100xi32>
+}
+
+fir.global internal @_QMtestEmanx.managed.ptr {section = "__nv_managed_data__"} : !fir.llvm_ptr<i8> {
+ %0 = fir.zero_bits !fir.llvm_ptr<i8>
+ fir.has_value %0 : !fir.llvm_ptr<i8>
+}
+
+func.func @_QPuser() {
+ %0 = cuf.device_address @_QMtestEmanx -> !fir.ref<!fir.array<100xi32>>
+ %1 = fir.declare %0 {uniq_name = "_QMtestEmanx"} : (!fir.ref<!fir.array<100xi32>>) -> !fir.ref<!fir.array<100xi32>>
+ return
+}
+
+// CHECK-LABEL: func.func @_QPuser
+// CHECK-NOT: fir.call @_FortranACUFGetDeviceAddress
+// CHECK: %[[PTR_REF:.*]] = fir.address_of(@_QMtestEmanx.managed.ptr) : !fir.ref<!fir.llvm_ptr<i8>>
+// CHECK: %[[RAW_PTR:.*]] = fir.load %[[PTR_REF]] : !fir.ref<!fir.llvm_ptr<i8>>
+// CHECK: %[[ADDR:.*]] = fir.convert %[[RAW_PTR]] : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.array<100xi32>>
>From f6d6f913719aace62aa7d43c887f5f849ee8ff9f Mon Sep 17 00:00:00 2001
From: Zhen Wang <zhenw at nvidia.com>
Date: Mon, 23 Mar 2026 15:31:18 -0700
Subject: [PATCH 2/8] NVVM Managed Annotation
---
flang/lib/Optimizer/CodeGen/CodeGen.cpp | 9 +++++++++
flang/test/Fir/CUDA/cuda-code-gen.mlir | 16 +++++++++++++++
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 3 +++
.../Dialect/NVVM/NVVMToLLVMIRTranslation.cpp | 20 ++++++++++++++++++-
mlir/test/Target/LLVMIR/nvvmir.mlir | 10 ++++++++++
5 files changed, 57 insertions(+), 1 deletion(-)
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 25eb6194efa99..8ac536d64ad96 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -3448,6 +3448,15 @@ struct GlobalOpConversion : public fir::FIROpConversion<fir::GlobalOp> {
g.setAddrSpace(
static_cast<unsigned>(mlir::NVVM::NVVMMemorySpace::Constant));
+ if (gpuMod && global.getDataAttr() &&
+ *global.getDataAttr() == cuf::DataAttribute::Managed &&
+ !mlir::isa<fir::BaseBoxType>(global.getType())) {
+ g.setAddrSpace(
+ static_cast<unsigned>(mlir::NVVM::NVVMMemorySpace::Global));
+ g->setAttr(NVVM::NVVMDialect::getManagedAttrName(),
+ mlir::UnitAttr::get(global.getContext()));
+ }
+
rewriter.eraseOp(global);
return mlir::success();
}
diff --git a/flang/test/Fir/CUDA/cuda-code-gen.mlir b/flang/test/Fir/CUDA/cuda-code-gen.mlir
index e83648f21bdf1..a21592cff7990 100644
--- a/flang/test/Fir/CUDA/cuda-code-gen.mlir
+++ b/flang/test/Fir/CUDA/cuda-code-gen.mlir
@@ -312,3 +312,19 @@ module attributes {gpu.container_module, dlti.dl_spec = #dlti.dl_spec<#dlti.dl_e
// CHECK-LABEL: gpu.func @_QMkernelsPassign
// CHECK: %[[ADDROF:.*]] = llvm.mlir.addressof @_QMkernelsEinitial_val : !llvm.ptr<4>
// CHECK: %{{.*}} = llvm.addrspacecast %[[ADDROF]] : !llvm.ptr<4> to !llvm.ptr
+
+// -----
+
+// Test that non-allocatable managed globals inside gpu.module get
+// addr_space = 1 (Global) and the nvvm.managed annotation.
+
+module attributes {gpu.container_module, dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<f80, dense<128> : vector<2xi64>>, #dlti.dl_entry<i128, dense<128> : vector<2xi64>>, #dlti.dl_entry<i64, dense<64> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr<272>, dense<64> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<271>, dense<32> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<270>, dense<32> : vector<4xi64>>, #dlti.dl_entry<f128, dense<128> : vector<2xi64>>, #dlti.dl_entry<f64, dense<64> : vector<2xi64>>, #dlti.dl_entry<f16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i32, dense<32> : vector<2xi64>>, #dlti.dl_entry<i16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i8, dense<8> : vector<2xi64>>, #dlti.dl_entry<i1, dense<8> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr, dense<64> : vector<4xi64>>, #dlti.dl_entry<"dlti.endianness", "little">, #dlti.dl_entry<"dlti.stack_alignment", 128 : i64>>} {
+ gpu.module @cuda_device_mod {
+ fir.global @_QMtestEmanx {data_attr = #cuf.cuda<Managed>} : !fir.array<100xi32> {
+ %0 = fir.zero_bits !fir.array<100xi32>
+ fir.has_value %0 : !fir.array<100xi32>
+ }
+ }
+}
+
+// CHECK: llvm.mlir.global external @_QMtestEmanx() {addr_space = 1 : i32, nvvm.managed} : !llvm.array<100 x i32>
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 0c5dae265e2ca..87fd75f5a3e19 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -96,6 +96,9 @@ def NVVM_Dialect : Dialect {
/// nvvm.cluster_dim attributes.
static StringRef getBlocksAreClustersAttrName() { return "nvvm.blocksareclusters"; }
+ /// Get the name of the attribute used to annotate managed global variables.
+ static StringRef getManagedAttrName() { return "nvvm.managed"; }
+
/// Verify an attribute from this dialect on the argument at 'argIndex' for
/// the region at 'regionIndex' on the given operation. Returns failure if
/// the verification failed, success otherwise. This hook may optionally be
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index 18d42e9577095..83aef44210ca3 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -698,11 +698,29 @@ class NVVMDialectLLVMIRTranslationInterface
return failure();
}
- /// Attaches module-level metadata for functions marked as kernels.
+ /// Attaches module-level metadata for functions marked as kernels
+ /// and managed annotations for global variables.
LogicalResult
amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
NamedAttribute attribute,
LLVM::ModuleTranslation &moduleTranslation) const final {
+ if (auto globalOp = dyn_cast<LLVM::GlobalOp>(op)) {
+ if (attribute.getName() == NVVM::NVVMDialect::getManagedAttrName()) {
+ auto *gv = cast<llvm::GlobalVariable>(
+ moduleTranslation.lookupGlobal(globalOp));
+ llvm::Module *m = gv->getParent();
+ llvm::LLVMContext &ctx = m->getContext();
+ llvm::NamedMDNode *md =
+ m->getOrInsertNamedMetadata("nvvm.annotations");
+ md->addOperand(llvm::MDNode::get(
+ ctx, {llvm::ConstantAsMetadata::get(gv),
+ llvm::MDString::get(ctx, "managed"),
+ llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(
+ llvm::Type::getInt32Ty(ctx), 1))}));
+ }
+ return success();
+ }
+
auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
if (!func)
return failure();
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 8a7e9bae4ec2e..686a9a481f5ac 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -915,3 +915,13 @@ llvm.func @nanosleep(%duration: i32) {
nvvm.nanosleep %duration
llvm.return
}
+
+// -----
+
+// CHECK: @managed_g = global i32 0
+// CHECK: !nvvm.annotations = !{![[MANAGED:[0-9]+]]}
+// CHECK: ![[MANAGED]] = !{ptr @managed_g, !"managed", i32 1}
+llvm.mlir.global external @managed_g() {addr_space = 1 : i32, nvvm.managed} : i32 {
+ %0 = llvm.mlir.constant(0 : i32) : i32
+ llvm.return %0 : i32
+}
>From ec9c1ccbbb692532cead34378eb475bc5be3fc69 Mon Sep 17 00:00:00 2001
From: Zhen Wang <zhenw at nvidia.com>
Date: Mon, 23 Mar 2026 17:39:13 -0700
Subject: [PATCH 3/8] modify tests
---
flang/lib/Optimizer/CodeGen/CodeGen.cpp | 2 +-
flang/test/Fir/CUDA/cuda-code-gen.mlir | 2 +-
flang/test/Fir/CUDA/cuda-device-address.mlir | 4 +++-
mlir/test/Target/LLVMIR/nvvmir.mlir | 4 ++--
4 files changed, 7 insertions(+), 5 deletions(-)
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 8ac536d64ad96..2d01463cf604d 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -3453,7 +3453,7 @@ struct GlobalOpConversion : public fir::FIROpConversion<fir::GlobalOp> {
!mlir::isa<fir::BaseBoxType>(global.getType())) {
g.setAddrSpace(
static_cast<unsigned>(mlir::NVVM::NVVMMemorySpace::Global));
- g->setAttr(NVVM::NVVMDialect::getManagedAttrName(),
+ g->setAttr(mlir::NVVM::NVVMDialect::getManagedAttrName(),
mlir::UnitAttr::get(global.getContext()));
}
diff --git a/flang/test/Fir/CUDA/cuda-code-gen.mlir b/flang/test/Fir/CUDA/cuda-code-gen.mlir
index a21592cff7990..fc962f8de5039 100644
--- a/flang/test/Fir/CUDA/cuda-code-gen.mlir
+++ b/flang/test/Fir/CUDA/cuda-code-gen.mlir
@@ -320,7 +320,7 @@ module attributes {gpu.container_module, dlti.dl_spec = #dlti.dl_spec<#dlti.dl_e
module attributes {gpu.container_module, dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<f80, dense<128> : vector<2xi64>>, #dlti.dl_entry<i128, dense<128> : vector<2xi64>>, #dlti.dl_entry<i64, dense<64> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr<272>, dense<64> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<271>, dense<32> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<270>, dense<32> : vector<4xi64>>, #dlti.dl_entry<f128, dense<128> : vector<2xi64>>, #dlti.dl_entry<f64, dense<64> : vector<2xi64>>, #dlti.dl_entry<f16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i32, dense<32> : vector<2xi64>>, #dlti.dl_entry<i16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i8, dense<8> : vector<2xi64>>, #dlti.dl_entry<i1, dense<8> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr, dense<64> : vector<4xi64>>, #dlti.dl_entry<"dlti.endianness", "little">, #dlti.dl_entry<"dlti.stack_alignment", 128 : i64>>} {
gpu.module @cuda_device_mod {
- fir.global @_QMtestEmanx {data_attr = #cuf.cuda<Managed>} : !fir.array<100xi32> {
+ fir.global @_QMtestEmanx {data_attr = #cuf.cuda<managed>} : !fir.array<100xi32> {
%0 = fir.zero_bits !fir.array<100xi32>
fir.has_value %0 : !fir.array<100xi32>
}
diff --git a/flang/test/Fir/CUDA/cuda-device-address.mlir b/flang/test/Fir/CUDA/cuda-device-address.mlir
index d87e5467ab313..a2dae71557869 100644
--- a/flang/test/Fir/CUDA/cuda-device-address.mlir
+++ b/flang/test/Fir/CUDA/cuda-device-address.mlir
@@ -39,8 +39,10 @@ fir.global internal @_QMtestEmanx.managed.ptr {section = "__nv_managed_data__"}
}
func.func @_QPuser() {
+ %c100 = arith.constant 100 : index
%0 = cuf.device_address @_QMtestEmanx -> !fir.ref<!fir.array<100xi32>>
- %1 = fir.declare %0 {uniq_name = "_QMtestEmanx"} : (!fir.ref<!fir.array<100xi32>>) -> !fir.ref<!fir.array<100xi32>>
+ %1 = fir.shape %c100 : (index) -> !fir.shape<1>
+ %2 = fir.declare %0(%1) {uniq_name = "_QMtestEmanx"} : (!fir.ref<!fir.array<100xi32>>, !fir.shape<1>) -> !fir.ref<!fir.array<100xi32>>
return
}
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 686a9a481f5ac..deecf09491ebe 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -918,9 +918,9 @@ llvm.func @nanosleep(%duration: i32) {
// -----
-// CHECK: @managed_g = global i32 0
+// CHECK: @managed_g = addrspace(1) global i32 0
// CHECK: !nvvm.annotations = !{![[MANAGED:[0-9]+]]}
-// CHECK: ![[MANAGED]] = !{ptr @managed_g, !"managed", i32 1}
+// CHECK: ![[MANAGED]] = !{ptr addrspace(1) @managed_g, !"managed", i32 1}
llvm.mlir.global external @managed_g() {addr_space = 1 : i32, nvvm.managed} : i32 {
%0 = llvm.mlir.constant(0 : i32) : i32
llvm.return %0 : i32
>From 323ff1ff76585c6ece427c5c1e338b41d5351a0e Mon Sep 17 00:00:00 2001
From: Zhen Wang <zhenw at nvidia.com>
Date: Mon, 23 Mar 2026 18:32:59 -0700
Subject: [PATCH 4/8] skip data transfer for module managed
---
flang/include/flang/Evaluate/tools.h | 25 ++++++++++++++++++++
flang/test/Lower/CUDA/cuda-data-transfer.cuf | 14 +++++++++++
2 files changed, 39 insertions(+)
diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h
index 963452755064d..52f39d1e72cf6 100644
--- a/flang/include/flang/Evaluate/tools.h
+++ b/flang/include/flang/Evaluate/tools.h
@@ -1311,6 +1311,27 @@ inline bool IsCUDAManagedOrUnifiedSymbol(const Symbol &sym) {
return false;
}
+// Non-allocatable module-level managed/unified variables use nvcc-style
+// unified memory with shadow pointer indirection. cudaMemcpy targeting
+// these variables would use the shadow address rather than the actual
+// unified memory address, so data transfers must be avoided.
+inline bool IsNonAllocatableModuleCUDAManagedSymbol(const Symbol &sym) {
+ const Symbol &ultimate = sym.GetUltimate();
+ if (!IsCUDAManagedOrUnifiedSymbol(ultimate))
+ return false;
+ if (ultimate.attrs().test(semantics::Attr::ALLOCATABLE))
+ return false;
+ return ultimate.owner().IsModule();
+}
+
+template <typename A>
+inline bool HasNonAllocatableModuleCUDAManagedSymbols(const A &expr) {
+ for (const Symbol &sym : CollectCudaSymbols(expr))
+ if (IsNonAllocatableModuleCUDAManagedSymbol(sym))
+ return true;
+ return false;
+}
+
// Get the number of distinct symbols with CUDA device
// attribute in the expression.
template <typename A> inline int GetNbOfCUDADeviceSymbols(const A &expr) {
@@ -1350,6 +1371,10 @@ inline bool IsCUDADataTransfer(const A &lhs, const B &rhs) {
int rhsNbManagedSymbols{GetNbOfCUDAManagedOrUnifiedSymbols(rhs)};
int rhsNbSymbols{GetNbOfCUDADeviceSymbols(rhs)};
+ if (HasNonAllocatableModuleCUDAManagedSymbols(lhs) ||
+ HasNonAllocatableModuleCUDAManagedSymbols(rhs))
+ return false;
+
if (lhsNbManagedSymbols >= 1 && lhs.Rank() > 0 && rhsNbSymbols == 0 &&
rhsNbManagedSymbols == 0 && (IsVariable(rhs) || IsConstantExpr(rhs))) {
return true; // Managed arrays initialization is performed on the device.
diff --git a/flang/test/Lower/CUDA/cuda-data-transfer.cuf b/flang/test/Lower/CUDA/cuda-data-transfer.cuf
index 66c3a28f9aec4..7d238107772ba 100644
--- a/flang/test/Lower/CUDA/cuda-data-transfer.cuf
+++ b/flang/test/Lower/CUDA/cuda-data-transfer.cuf
@@ -637,3 +637,17 @@ end subroutine
! CHECK-LABEL: func.func @_QPsub34
! CHECK: cuf.data_transfer %{{.*}} to %{{.*}} {hasManagedOrUnifedSymbols, transfer_kind = #cuf.cuda_transfer<host_device>} : f16, !fir.box<!fir.array<?xf16>>
+
+module managed_mod
+ integer, managed :: marray(10)
+end module
+
+subroutine sub35()
+ use managed_mod
+ integer :: host_arr(10)
+ marray = host_arr
+ marray = 0
+end subroutine
+
+! CHECK-LABEL: func.func @_QPsub35()
+! CHECK-NOT: cuf.data_transfer
>From 303770cbf30851dda4310fcecba34bef96bbae83 Mon Sep 17 00:00:00 2001
From: Zhen Wang <zhenw at nvidia.com>
Date: Tue, 24 Mar 2026 09:29:23 -0700
Subject: [PATCH 5/8] add __cudaInitModule and modify IsCUDADataTransfer
---
flang-rt/lib/cuda/registration.cpp | 2 ++
flang/include/flang/Evaluate/tools.h | 6 +++---
2 files changed, 5 insertions(+), 3 deletions(-)
diff --git a/flang-rt/lib/cuda/registration.cpp b/flang-rt/lib/cuda/registration.cpp
index 8123220c2624c..c4de75b9f9711 100644
--- a/flang-rt/lib/cuda/registration.cpp
+++ b/flang-rt/lib/cuda/registration.cpp
@@ -27,6 +27,7 @@ extern void __cudaRegisterVar(void **fatCubinHandle, char *hostVar,
extern void __cudaRegisterManagedVar(void **fatCubinHandle,
void **hostVarPtrAddress, char *deviceAddress, const char *deviceName,
int ext, size_t size, int constant, int global);
+extern char __cudaInitModule(void **fatCubinHandle);
void *RTDECL(CUFRegisterModule)(void *data) {
void **fatHandle{__cudaRegisterFatBinary(data)};
@@ -48,6 +49,7 @@ void RTDEF(CUFRegisterVariable)(
void RTDEF(CUFRegisterManagedVariable)(
void **module, void **varSym, char *varName, int64_t size) {
__cudaRegisterManagedVar(module, varSym, varName, varName, 0, size, 0, 0);
+ __cudaInitModule(module);
}
} // extern "C"
diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h
index 52f39d1e72cf6..10d541755b810 100644
--- a/flang/include/flang/Evaluate/tools.h
+++ b/flang/include/flang/Evaluate/tools.h
@@ -1371,8 +1371,7 @@ inline bool IsCUDADataTransfer(const A &lhs, const B &rhs) {
int rhsNbManagedSymbols{GetNbOfCUDAManagedOrUnifiedSymbols(rhs)};
int rhsNbSymbols{GetNbOfCUDADeviceSymbols(rhs)};
- if (HasNonAllocatableModuleCUDAManagedSymbols(lhs) ||
- HasNonAllocatableModuleCUDAManagedSymbols(rhs))
+ if (HasNonAllocatableModuleCUDAManagedSymbols(lhs))
return false;
if (lhsNbManagedSymbols >= 1 && lhs.Rank() > 0 && rhsNbSymbols == 0 &&
@@ -1384,7 +1383,8 @@ inline bool IsCUDADataTransfer(const A &lhs, const B &rhs) {
// - Only managed or unifed symbols are involved on RHS and LHS.
// - LHS is managed or unified and the RHS is host only.
if ((lhsNbManagedSymbols >= 1 && rhsNbManagedSymbols == rhsNbSymbols) ||
- (lhsNbManagedSymbols == 0 && rhsNbManagedSymbols >= 1 &&
+ (lhsNbManagedSymbols == 0 && !HasCUDADeviceAttrs(lhs) &&
+ rhsNbManagedSymbols >= 1 &&
rhsNbManagedSymbols == rhsNbSymbols) ||
(lhsNbManagedSymbols >= 1 && rhsNbSymbols == 0)) {
return false;
>From 81bf4ad099b8e80b31dc40248e3ab70f8eb6c229 Mon Sep 17 00:00:00 2001
From: Zhen Wang <zhenw at nvidia.com>
Date: Tue, 24 Mar 2026 09:31:01 -0700
Subject: [PATCH 6/8] format
---
flang/include/flang/Evaluate/tools.h | 3 +--
.../Transforms/CUDA/CUFAddConstructor.cpp | 16 +++++++---------
.../Transforms/CUDA/CUFOpConversionLate.cpp | 8 +++-----
.../Dialect/NVVM/NVVMToLLVMIRTranslation.cpp | 3 +--
4 files changed, 12 insertions(+), 18 deletions(-)
diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h
index 10d541755b810..791ec28a9800a 100644
--- a/flang/include/flang/Evaluate/tools.h
+++ b/flang/include/flang/Evaluate/tools.h
@@ -1384,8 +1384,7 @@ inline bool IsCUDADataTransfer(const A &lhs, const B &rhs) {
// - LHS is managed or unified and the RHS is host only.
if ((lhsNbManagedSymbols >= 1 && rhsNbManagedSymbols == rhsNbSymbols) ||
(lhsNbManagedSymbols == 0 && !HasCUDADeviceAttrs(lhs) &&
- rhsNbManagedSymbols >= 1 &&
- rhsNbManagedSymbols == rhsNbSymbols) ||
+ rhsNbManagedSymbols >= 1 && rhsNbManagedSymbols == rhsNbSymbols) ||
(lhsNbManagedSymbols >= 1 && rhsNbSymbols == 0)) {
return false;
}
diff --git a/flang/lib/Optimizer/Transforms/CUDA/CUFAddConstructor.cpp b/flang/lib/Optimizer/Transforms/CUDA/CUFAddConstructor.cpp
index cd471c7713e53..7f47d6271ecbd 100644
--- a/flang/lib/Optimizer/Transforms/CUDA/CUFAddConstructor.cpp
+++ b/flang/lib/Optimizer/Transforms/CUDA/CUFAddConstructor.cpp
@@ -50,18 +50,16 @@ static fir::GlobalOp createManagedPointerGlobal(fir::FirOpBuilder &builder,
mlir::ModuleOp mod,
fir::GlobalOp globalOp) {
auto *ctx = mod.getContext();
- std::string ptrGlobalName =
- (globalOp.getSymName() + managedPtrSuffix).str();
- auto ptrTy =
- fir::LLVMPointerType::get(ctx, mlir::IntegerType::get(ctx, 8));
+ std::string ptrGlobalName = (globalOp.getSymName() + managedPtrSuffix).str();
+ auto ptrTy = fir::LLVMPointerType::get(ctx, mlir::IntegerType::get(ctx, 8));
mlir::OpBuilder::InsertionGuard guard(builder);
builder.setInsertionPointAfter(globalOp);
llvm::SmallVector<mlir::NamedAttribute> attrs;
- attrs.push_back(mlir::NamedAttribute(
- mlir::StringAttr::get(ctx, "section"),
- mlir::StringAttr::get(ctx, "__nv_managed_data__")));
+ attrs.push_back(
+ mlir::NamedAttribute(mlir::StringAttr::get(ctx, "section"),
+ mlir::StringAttr::get(ctx, "__nv_managed_data__")));
mlir::DenseElementsAttr initAttr = {};
auto ptrGlobal = fir::GlobalOp::create(
@@ -179,8 +177,8 @@ struct CUFAddConstructor
// register with __cudaRegisterManagedVar via the runtime wrapper.
fir::GlobalOp ptrGlobal =
createManagedPointerGlobal(builder, mod, globalOp);
- func = fir::runtime::getRuntimeFunc<
- mkRTKey(CUFRegisterManagedVariable)>(loc, builder);
+ func = fir::runtime::getRuntimeFunc<mkRTKey(
+ CUFRegisterManagedVariable)>(loc, builder);
auto fTy = func.getFunctionType();
mlir::Value addr = fir::AddrOfOp::create(
builder, loc, ptrGlobal.resultType(), ptrGlobal.getSymbol());
diff --git a/flang/lib/Optimizer/Transforms/CUDA/CUFOpConversionLate.cpp b/flang/lib/Optimizer/Transforms/CUDA/CUFOpConversionLate.cpp
index 73a385591437c..62f95f5d23c34 100644
--- a/flang/lib/Optimizer/Transforms/CUDA/CUFOpConversionLate.cpp
+++ b/flang/lib/Optimizer/Transforms/CUDA/CUFOpConversionLate.cpp
@@ -71,11 +71,9 @@ struct CUFDeviceAddressOpConversion
// companion pointer global (@sym.managed.ptr) that holds the unified
// memory address. Load from it instead of calling CUFGetDeviceAddress.
std::string ptrGlobalName = (symName + managedPtrSuffix).str();
- if (auto ptrGlobal =
- symTab.lookup<fir::GlobalOp>(ptrGlobalName)) {
- auto ptrRef = fir::AddrOfOp::create(rewriter, loc,
- ptrGlobal.resultType(),
- ptrGlobal.getSymbol());
+ if (auto ptrGlobal = symTab.lookup<fir::GlobalOp>(ptrGlobalName)) {
+ auto ptrRef = fir::AddrOfOp::create(
+ rewriter, loc, ptrGlobal.resultType(), ptrGlobal.getSymbol());
auto rawPtr = fir::LoadOp::create(rewriter, loc, ptrRef);
auto converted =
fir::ConvertOp::create(rewriter, loc, op.getType(), rawPtr);
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index 83aef44210ca3..e56952cf565fd 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -710,8 +710,7 @@ class NVVMDialectLLVMIRTranslationInterface
moduleTranslation.lookupGlobal(globalOp));
llvm::Module *m = gv->getParent();
llvm::LLVMContext &ctx = m->getContext();
- llvm::NamedMDNode *md =
- m->getOrInsertNamedMetadata("nvvm.annotations");
+ llvm::NamedMDNode *md = m->getOrInsertNamedMetadata("nvvm.annotations");
md->addOperand(llvm::MDNode::get(
ctx, {llvm::ConstantAsMetadata::get(gv),
llvm::MDString::get(ctx, "managed"),
>From af377bd5c3dd271c86c30fd875761a2634cb48b6 Mon Sep 17 00:00:00 2001
From: Zhen Wang <zhenw at nvidia.com>
Date: Tue, 24 Mar 2026 09:47:29 -0700
Subject: [PATCH 7/8] add more tests
---
flang/include/flang/Evaluate/tools.h | 16 ++++++++------
.../Transforms/CUDA/CUFAddConstructor.cpp | 10 ++++-----
flang/test/Lower/CUDA/cuda-data-transfer.cuf | 22 +++++++++++++++++++
3 files changed, 36 insertions(+), 12 deletions(-)
diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h
index 791ec28a9800a..51dc0582fcdea 100644
--- a/flang/include/flang/Evaluate/tools.h
+++ b/flang/include/flang/Evaluate/tools.h
@@ -1311,10 +1311,11 @@ inline bool IsCUDAManagedOrUnifiedSymbol(const Symbol &sym) {
return false;
}
-// Non-allocatable module-level managed/unified variables use nvcc-style
-// unified memory with shadow pointer indirection. cudaMemcpy targeting
-// these variables would use the shadow address rather than the actual
-// unified memory address, so data transfers must be avoided.
+// Non-allocatable module-level managed/unified variables use pointer
+// indirection through a companion global in __nv_managed_data__.
+// Explicit data transfers (cudaMemcpy) must be avoided for these
+// variables since they would target the shadow address rather than
+// the actual unified memory address.
inline bool IsNonAllocatableModuleCUDAManagedSymbol(const Symbol &sym) {
const Symbol &ultimate = sym.GetUltimate();
if (!IsCUDAManagedOrUnifiedSymbol(ultimate))
@@ -1379,9 +1380,10 @@ inline bool IsCUDADataTransfer(const A &lhs, const B &rhs) {
return true; // Managed arrays initialization is performed on the device.
}
- // Special cases performed on the host:
- // - Only managed or unifed symbols are involved on RHS and LHS.
- // - LHS is managed or unified and the RHS is host only.
+ // Cases where no explicit data transfer is needed:
+ // - Both sides involve only managed/unified symbols (host-accessible).
+ // - LHS is host-only and RHS has only managed/unified symbols.
+ // - LHS is managed/unified and RHS is host-only.
if ((lhsNbManagedSymbols >= 1 && rhsNbManagedSymbols == rhsNbSymbols) ||
(lhsNbManagedSymbols == 0 && !HasCUDADeviceAttrs(lhs) &&
rhsNbManagedSymbols >= 1 && rhsNbManagedSymbols == rhsNbSymbols) ||
diff --git a/flang/lib/Optimizer/Transforms/CUDA/CUFAddConstructor.cpp b/flang/lib/Optimizer/Transforms/CUDA/CUFAddConstructor.cpp
index 7f47d6271ecbd..59b0992679f36 100644
--- a/flang/lib/Optimizer/Transforms/CUDA/CUFAddConstructor.cpp
+++ b/flang/lib/Optimizer/Transforms/CUDA/CUFAddConstructor.cpp
@@ -44,8 +44,8 @@ static constexpr llvm::StringRef cudaFortranCtorName{
static constexpr llvm::StringRef managedPtrSuffix{".managed.ptr"};
/// Create an 8-byte pointer global in the __nv_managed_data__ section.
-/// The CUDA runtime fills this pointer with the unified memory address
-/// during __cudaRegisterManagedVar.
+/// The CUDA runtime populates this pointer with the unified memory address
+/// when the module is initialized via __cudaInitModule.
static fir::GlobalOp createManagedPointerGlobal(fir::FirOpBuilder &builder,
mlir::ModuleOp mod,
fir::GlobalOp globalOp) {
@@ -172,9 +172,9 @@ struct CUFAddConstructor
auto sizeVal = builder.createIntegerConstant(loc, idxTy, *size);
if (isNonAllocManagedGlobal) {
- // Non-allocatable managed globals use nvcc-style unified memory:
- // create an 8-byte pointer global in __nv_managed_data__ and
- // register with __cudaRegisterManagedVar via the runtime wrapper.
+ // Non-allocatable managed globals use pointer indirection:
+ // a companion pointer in __nv_managed_data__ holds the unified
+ // memory address; the variable is registered via __cudaRegisterManagedVar.
fir::GlobalOp ptrGlobal =
createManagedPointerGlobal(builder, mod, globalOp);
func = fir::runtime::getRuntimeFunc<mkRTKey(
diff --git a/flang/test/Lower/CUDA/cuda-data-transfer.cuf b/flang/test/Lower/CUDA/cuda-data-transfer.cuf
index 7d238107772ba..1d0e510c110ee 100644
--- a/flang/test/Lower/CUDA/cuda-data-transfer.cuf
+++ b/flang/test/Lower/CUDA/cuda-data-transfer.cuf
@@ -651,3 +651,25 @@ end subroutine
! CHECK-LABEL: func.func @_QPsub35()
! CHECK-NOT: cuf.data_transfer
+
+! Test that host_var = managed_module_var does NOT generate cuf.data_transfer
+! (managed memory is host-accessible, so direct assignment suffices).
+subroutine sub36()
+ use managed_mod
+ integer :: host_arr(10)
+ host_arr = marray
+end subroutine
+
+! CHECK-LABEL: func.func @_QPsub36()
+! CHECK-NOT: cuf.data_transfer
+
+! Test that device_var = managed_module_var DOES generate cuf.data_transfer
+! (device memory requires explicit cudaMemcpy).
+subroutine sub37()
+ use managed_mod
+ integer, device :: dev_arr(10)
+ dev_arr = marray
+end subroutine
+
+! CHECK-LABEL: func.func @_QPsub37()
+! CHECK: cuf.data_transfer
>From a41b4536a6cda18cd476ccb164672515ac64bc95 Mon Sep 17 00:00:00 2001
From: Zhen Wang <zhenw at nvidia.com>
Date: Wed, 25 Mar 2026 20:12:51 -0700
Subject: [PATCH 8/8] Call __cudaInitModule once after all managed variables
are registered
---
flang-rt/lib/cuda/registration.cpp | 3 ++-
flang/include/flang/Runtime/CUDA/registration.h | 5 +++++
.../Transforms/CUDA/CUFAddConstructor.cpp | 15 ++++++++++++++-
flang/test/Fir/CUDA/cuda-constructor-2.f90 | 3 ++-
4 files changed, 23 insertions(+), 3 deletions(-)
diff --git a/flang-rt/lib/cuda/registration.cpp b/flang-rt/lib/cuda/registration.cpp
index c4de75b9f9711..58077d6a6a52b 100644
--- a/flang-rt/lib/cuda/registration.cpp
+++ b/flang-rt/lib/cuda/registration.cpp
@@ -49,9 +49,10 @@ void RTDEF(CUFRegisterVariable)(
void RTDEF(CUFRegisterManagedVariable)(
void **module, void **varSym, char *varName, int64_t size) {
__cudaRegisterManagedVar(module, varSym, varName, varName, 0, size, 0, 0);
- __cudaInitModule(module);
}
+void RTDEF(CUFInitModule)(void **module) { __cudaInitModule(module); }
+
} // extern "C"
} // namespace Fortran::runtime::cuda
diff --git a/flang/include/flang/Runtime/CUDA/registration.h b/flang/include/flang/Runtime/CUDA/registration.h
index 15f013432fa04..74dbf9e189076 100644
--- a/flang/include/flang/Runtime/CUDA/registration.h
+++ b/flang/include/flang/Runtime/CUDA/registration.h
@@ -32,6 +32,11 @@ void RTDECL(CUFRegisterVariable)(
void RTDECL(CUFRegisterManagedVariable)(
void **module, void **varSym, char *varName, int64_t size);
+/// Initialize a CUDA module after all variables have been registered.
+/// Triggers the runtime to populate managed variable pointers with
+/// unified memory addresses.
+void RTDECL(CUFInitModule)(void **module);
+
} // extern "C"
} // namespace Fortran::runtime::cuda
diff --git a/flang/lib/Optimizer/Transforms/CUDA/CUFAddConstructor.cpp b/flang/lib/Optimizer/Transforms/CUDA/CUFAddConstructor.cpp
index 59b0992679f36..a63c95d2d36b0 100644
--- a/flang/lib/Optimizer/Transforms/CUDA/CUFAddConstructor.cpp
+++ b/flang/lib/Optimizer/Transforms/CUDA/CUFAddConstructor.cpp
@@ -49,7 +49,7 @@ static constexpr llvm::StringRef managedPtrSuffix{".managed.ptr"};
static fir::GlobalOp createManagedPointerGlobal(fir::FirOpBuilder &builder,
mlir::ModuleOp mod,
fir::GlobalOp globalOp) {
- auto *ctx = mod.getContext();
+ mlir::MLIRContext *ctx = mod.getContext();
std::string ptrGlobalName = (globalOp.getSymName() + managedPtrSuffix).str();
auto ptrTy = fir::LLVMPointerType::get(ctx, mlir::IntegerType::get(ctx, 8));
@@ -137,6 +137,7 @@ struct CUFAddConstructor
}
// Register variables
+ bool hasNonAllocManagedGlobal = false;
for (fir::GlobalOp globalOp : mod.getOps<fir::GlobalOp>()) {
auto attr = globalOp.getDataAttrAttr();
if (!attr)
@@ -175,6 +176,7 @@ struct CUFAddConstructor
// Non-allocatable managed globals use pointer indirection:
// a companion pointer in __nv_managed_data__ holds the unified
// memory address, registered via __cudaRegisterManagedVar.
+ hasNonAllocManagedGlobal = true;
fir::GlobalOp ptrGlobal =
createManagedPointerGlobal(builder, mod, globalOp);
func = fir::runtime::getRuntimeFunc<mkRTKey(
@@ -200,6 +202,17 @@ struct CUFAddConstructor
break;
}
}
+
+ // Initialize the module once after all managed variables are
+ // registered so the runtime populates their unified memory pointers.
+ if (hasNonAllocManagedGlobal) {
+ mlir::func::FuncOp initFunc =
+ fir::runtime::getRuntimeFunc<mkRTKey(CUFInitModule)>(loc, builder);
+ auto initFTy = initFunc.getFunctionType();
+ llvm::SmallVector<mlir::Value> initArgs{
+ fir::runtime::createArguments(builder, loc, initFTy, registeredMod)};
+ fir::CallOp::create(builder, loc, initFunc, initArgs);
+ }
}
mlir::LLVM::ReturnOp::create(builder, loc, mlir::ValueRange{});
diff --git a/flang/test/Fir/CUDA/cuda-constructor-2.f90 b/flang/test/Fir/CUDA/cuda-constructor-2.f90
index 0fdd5966550e9..9504bcb39f0ea 100644
--- a/flang/test/Fir/CUDA/cuda-constructor-2.f90
+++ b/flang/test/Fir/CUDA/cuda-constructor-2.f90
@@ -107,7 +107,8 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<!llvm.ptr, dense<
// CHECK: fir.global internal @_QMtestEmanx.managed.ptr {section = "__nv_managed_data__"} : !fir.llvm_ptr<i8>
// CHECK: fir.zero_bits !fir.llvm_ptr<i8>
-// Constructor should register with CUFRegisterManagedVariable.
+// Constructor should register with CUFRegisterManagedVariable then init module.
// CHECK: llvm.func internal @__cudaFortranConstructor()
// CHECK: fir.address_of(@_QMtestEmanx.managed.ptr) : !fir.ref<!fir.llvm_ptr<i8>>
// CHECK: fir.call @_FortranACUFRegisterManagedVariable
+// CHECK: fir.call @_FortranACUFInitModule
More information about the llvm-commits
mailing list