[flang-commits] [flang] [flang][cuda] Avoid runtime copies for scalar constant host reads (PR #204193)
Zhen Wang via flang-commits
flang-commits at lists.llvm.org
Tue Jun 16 09:47:03 PDT 2026
https://github.com/wangzpgi updated https://github.com/llvm/llvm-project/pull/204193
From 194e6349a8a7b0d81094ed711767a9159cf4f836 Mon Sep 17 00:00:00 2001
From: Zhen Wang <zhenw at nvidia.com>
Date: Mon, 15 Jun 2026 09:53:08 -0700
Subject: [PATCH 1/2] Fix host read-back from scalar CUDA constants
---
.../Transforms/CUDA/CUFOpConversion.cpp | 67 ++++++++++++-------
flang/test/Fir/CUDA/cuda-global-addr.mlir | 15 +++++
2 files changed, 56 insertions(+), 26 deletions(-)
diff --git a/flang/lib/Optimizer/Transforms/CUDA/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUDA/CUFOpConversion.cpp
index 8fa578bd0617d..47270d5588f5c 100644
--- a/flang/lib/Optimizer/Transforms/CUDA/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUDA/CUFOpConversion.cpp
@@ -263,6 +263,47 @@ struct CUFDataTransferOpConversion
return mlir::success();
}
+ mlir::Value dst = op.getDst();
+ mlir::Value src = op.getSrc();
+ // Scalar CUDA constants keep a host shadow for host reads. Host-to-device
+ // assignments also update the device constant symbol.
+ auto getAddrOf = [](mlir::Value val) -> fir::AddrOfOp {
+ if (auto declareOp = val.getDefiningOp<fir::DeclareOp>())
+ return declareOp.getMemref().getDefiningOp<fir::AddrOfOp>();
+ if (auto declareOp = val.getDefiningOp<hlfir::DeclareOp>())
+ return declareOp.getMemref().getDefiningOp<fir::AddrOfOp>();
+ return {};
+ };
+ if (op.getTransferKind() == cuf::DataTransferKind::DeviceHost) {
+ if (fir::AddrOfOp addrOfOp = getAddrOf(src)) {
+ auto global = symtab.lookup<fir::GlobalOp>(
+ addrOfOp.getSymbol().getRootReference().getValue());
+ if (isScalarCudaConstantGlobal(global) &&
+ fir::isa_ref_type(dst.getType())) {
+ mlir::Value hostValue = fir::LoadOp::create(builder, loc, src);
+ hostValue = createConvertOp(rewriter, loc, dstTy, hostValue);
+ fir::StoreOp::create(builder, loc, hostValue, dst);
+ rewriter.eraseOp(op);
+ return mlir::success();
+ }
+ }
+ }
+ if (op.getTransferKind() == cuf::DataTransferKind::HostDevice) {
+ if (fir::AddrOfOp addrOfOp = getAddrOf(dst)) {
+ auto global = symtab.lookup<fir::GlobalOp>(
+ addrOfOp.getSymbol().getRootReference().getValue());
+ if (isScalarCudaConstantGlobal(global)) {
+ mlir::Value hostValue = src;
+ if (fir::isa_ref_type(src.getType()))
+ hostValue = fir::LoadOp::create(builder, loc, src);
+ hostValue = createConvertOp(rewriter, loc, dstTy, hostValue);
+ fir::StoreOp::create(builder, loc, hostValue, addrOfOp);
+ dst = cuf::DeviceAddressOp::create(rewriter, loc, dst.getType(),
+ addrOfOp.getSymbol());
+ }
+ }
+ }
+
mlir::Type i64Ty = builder.getI64Type();
mlir::Value nbElement =
cuf::computeElementCount(rewriter, loc, op.getShape(), dstTy, i64Ty);
@@ -288,32 +329,6 @@ struct CUFDataTransferOpConversion
mlir::Value sourceLine =
fir::factory::locationToLineNo(builder, loc, fTy.getInput(5));
- mlir::Value dst = op.getDst();
- mlir::Value src = op.getSrc();
- // Host assignments to scalar CUDA constants update both the host-visible
- // global and the device constant symbol.
- auto getAddrOf = [](mlir::Value val) -> fir::AddrOfOp {
- if (auto declareOp = val.getDefiningOp<fir::DeclareOp>())
- return declareOp.getMemref().getDefiningOp<fir::AddrOfOp>();
- if (auto declareOp = val.getDefiningOp<hlfir::DeclareOp>())
- return declareOp.getMemref().getDefiningOp<fir::AddrOfOp>();
- return {};
- };
- if (op.getTransferKind() == cuf::DataTransferKind::HostDevice) {
- if (fir::AddrOfOp addrOfOp = getAddrOf(dst)) {
- auto global = symtab.lookup<fir::GlobalOp>(
- addrOfOp.getSymbol().getRootReference().getValue());
- if (isScalarCudaConstantGlobal(global)) {
- mlir::Value hostValue = src;
- if (fir::isa_ref_type(src.getType()))
- hostValue = fir::LoadOp::create(builder, loc, src);
- hostValue = createConvertOp(rewriter, loc, dstTy, hostValue);
- fir::StoreOp::create(builder, loc, hostValue, addrOfOp);
- dst = cuf::DeviceAddressOp::create(rewriter, loc, dst.getType(),
- addrOfOp.getSymbol());
- }
- }
- }
// Materialize the src if constant.
if (matchPattern(src.getDefiningOp(), mlir::m_Constant())) {
mlir::Value temp = builder.createTemporary(loc, srcTy);
diff --git a/flang/test/Fir/CUDA/cuda-global-addr.mlir b/flang/test/Fir/CUDA/cuda-global-addr.mlir
index 1bf1cd350a669..aed86312e7af0 100644
--- a/flang/test/Fir/CUDA/cuda-global-addr.mlir
+++ b/flang/test/Fir/CUDA/cuda-global-addr.mlir
@@ -151,6 +151,13 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<f80, dense<128> :
fir.call @_QPuse_index(%3) : (index) -> ()
return
}
+ func.func @_QQconstant_scalar_device_to_host() attributes {fir.bindc_name = "T"} {
+ %0 = fir.address_of(@_QMcon3Ezzz) : !fir.ref<i32>
+ %1 = fir.declare %0 {data_attr = #cuf.cuda<constant>, uniq_name = "_QMcon3Ezzz"} : (!fir.ref<i32>) -> !fir.ref<i32>
+ %2 = fir.alloca i32
+ cuf.data_transfer %1 to %2 {transfer_kind = #cuf.cuda_transfer<device_host>} : !fir.ref<i32>, !fir.ref<i32>
+ return
+ }
func.func private @_QPuse_index(index)
fir.global @_QMcon3Ezzz {data_attr = #cuf.cuda<constant>} : i32
}
@@ -162,3 +169,11 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<f80, dense<128> :
// CHECK: fir.store %{{.*}} to %[[ADDR]] : !fir.ref<i32>
// CHECK: fir.call @_FortranACUFGetDeviceAddress
// CHECK-NOT: fir.load %{{.*}} : !fir.ref<i32>
+// CHECK: fir.call @_QPuse_index
+// CHECK-LABEL: func.func @_QQconstant_scalar_device_to_host()
+// CHECK: %[[ADDR:.*]] = fir.address_of(@_QMcon3Ezzz) : !fir.ref<i32>
+// CHECK: %[[DECL:.*]] = fir.declare %[[ADDR]] {data_attr = #cuf.cuda<constant>, uniq_name = "_QMcon3Ezzz"} : (!fir.ref<i32>) -> !fir.ref<i32>
+// CHECK: %[[DST:.*]] = fir.alloca i32
+// CHECK: %[[VALUE:.*]] = fir.load %[[DECL]] : !fir.ref<i32>
+// CHECK: fir.store %[[VALUE]] to %[[DST]] : !fir.ref<i32>
+// CHECK-NOT: fir.call @_FortranACUFDataTransferPtrPtr
From ebd70d28314b2ec61fd8c072c62f29712d213e01 Mon Sep 17 00:00:00 2001
From: Zhen Wang <zhenw at nvidia.com>
Date: Tue, 16 Jun 2026 09:46:43 -0700
Subject: [PATCH 2/2] reorg: compute transfer byte size and runtime call setup before scalar constant handling
---
.../Transforms/CUDA/CUFOpConversion.cpp | 51 +++++++++----------
1 file changed, 25 insertions(+), 26 deletions(-)
diff --git a/flang/lib/Optimizer/Transforms/CUDA/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUDA/CUFOpConversion.cpp
index 47270d5588f5c..6ea3c1fbdddf9 100644
--- a/flang/lib/Optimizer/Transforms/CUDA/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUDA/CUFOpConversion.cpp
@@ -263,6 +263,31 @@ struct CUFDataTransferOpConversion
return mlir::success();
}
+ mlir::Type i64Ty = builder.getI64Type();
+ mlir::Value nbElement =
+ cuf::computeElementCount(rewriter, loc, op.getShape(), dstTy, i64Ty);
+ unsigned width = 0;
+ if (fir::isa_derived(fir::unwrapSequenceType(dstTy))) {
+ mlir::Type structTy =
+ typeConverter->convertType(fir::unwrapSequenceType(dstTy));
+ width = dl->getTypeSizeInBits(structTy) / 8;
+ } else {
+ width = cuf::computeElementByteSize(loc, dstTy, kindMap);
+ }
+ mlir::Value widthValue = mlir::arith::ConstantOp::create(
+ rewriter, loc, i64Ty, rewriter.getIntegerAttr(i64Ty, width));
+ mlir::Value bytes = nbElement ? mlir::arith::MulIOp::create(
+ rewriter, loc, nbElement, widthValue)
+ : widthValue;
+
+ mlir::func::FuncOp func =
+ fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferPtrPtr)>(loc,
+ builder);
+ auto fTy = func.getFunctionType();
+ mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
+ mlir::Value sourceLine =
+ fir::factory::locationToLineNo(builder, loc, fTy.getInput(5));
+
mlir::Value dst = op.getDst();
mlir::Value src = op.getSrc();
// Scalar CUDA constants keep a host shadow for host reads. Host-to-device
@@ -303,32 +328,6 @@ struct CUFDataTransferOpConversion
}
}
}
-
- mlir::Type i64Ty = builder.getI64Type();
- mlir::Value nbElement =
- cuf::computeElementCount(rewriter, loc, op.getShape(), dstTy, i64Ty);
- unsigned width = 0;
- if (fir::isa_derived(fir::unwrapSequenceType(dstTy))) {
- mlir::Type structTy =
- typeConverter->convertType(fir::unwrapSequenceType(dstTy));
- width = dl->getTypeSizeInBits(structTy) / 8;
- } else {
- width = cuf::computeElementByteSize(loc, dstTy, kindMap);
- }
- mlir::Value widthValue = mlir::arith::ConstantOp::create(
- rewriter, loc, i64Ty, rewriter.getIntegerAttr(i64Ty, width));
- mlir::Value bytes = nbElement ? mlir::arith::MulIOp::create(
- rewriter, loc, nbElement, widthValue)
- : widthValue;
-
- mlir::func::FuncOp func =
- fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferPtrPtr)>(loc,
- builder);
- auto fTy = func.getFunctionType();
- mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
- mlir::Value sourceLine =
- fir::factory::locationToLineNo(builder, loc, fTy.getInput(5));
-
// Materialize the src if constant.
if (matchPattern(src.getDefiningOp(), mlir::m_Constant())) {
mlir::Value temp = builder.createTemporary(loc, srcTy);
More information about the flang-commits mailing list