[flang-commits] [flang] [flang][cuda] Avoid runtime copies for scalar constant host reads (PR #204193)

Zhen Wang via flang-commits flang-commits at lists.llvm.org
Tue Jun 16 09:47:03 PDT 2026


https://github.com/wangzpgi updated https://github.com/llvm/llvm-project/pull/204193

>From 194e6349a8a7b0d81094ed711767a9159cf4f836 Mon Sep 17 00:00:00 2001
From: Zhen Wang <zhenw at nvidia.com>
Date: Mon, 15 Jun 2026 09:53:08 -0700
Subject: [PATCH 1/2] Fix host read-back from scalar CUDA constants

---
 .../Transforms/CUDA/CUFOpConversion.cpp       | 67 ++++++++++++-------
 flang/test/Fir/CUDA/cuda-global-addr.mlir     | 15 +++++
 2 files changed, 56 insertions(+), 26 deletions(-)

diff --git a/flang/lib/Optimizer/Transforms/CUDA/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUDA/CUFOpConversion.cpp
index 8fa578bd0617d..47270d5588f5c 100644
--- a/flang/lib/Optimizer/Transforms/CUDA/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUDA/CUFOpConversion.cpp
@@ -263,6 +263,47 @@ struct CUFDataTransferOpConversion
         return mlir::success();
       }
 
+      mlir::Value dst = op.getDst();
+      mlir::Value src = op.getSrc();
+      // Scalar CUDA constants keep a host shadow for host reads. Host-to-device
+      // assignments also update the device constant symbol.
+      auto getAddrOf = [](mlir::Value val) -> fir::AddrOfOp {
+        if (auto declareOp = val.getDefiningOp<fir::DeclareOp>())
+          return declareOp.getMemref().getDefiningOp<fir::AddrOfOp>();
+        if (auto declareOp = val.getDefiningOp<hlfir::DeclareOp>())
+          return declareOp.getMemref().getDefiningOp<fir::AddrOfOp>();
+        return {};
+      };
+      if (op.getTransferKind() == cuf::DataTransferKind::DeviceHost) {
+        if (fir::AddrOfOp addrOfOp = getAddrOf(src)) {
+          auto global = symtab.lookup<fir::GlobalOp>(
+              addrOfOp.getSymbol().getRootReference().getValue());
+          if (isScalarCudaConstantGlobal(global) &&
+              fir::isa_ref_type(dst.getType())) {
+            mlir::Value hostValue = fir::LoadOp::create(builder, loc, src);
+            hostValue = createConvertOp(rewriter, loc, dstTy, hostValue);
+            fir::StoreOp::create(builder, loc, hostValue, dst);
+            rewriter.eraseOp(op);
+            return mlir::success();
+          }
+        }
+      }
+      if (op.getTransferKind() == cuf::DataTransferKind::HostDevice) {
+        if (fir::AddrOfOp addrOfOp = getAddrOf(dst)) {
+          auto global = symtab.lookup<fir::GlobalOp>(
+              addrOfOp.getSymbol().getRootReference().getValue());
+          if (isScalarCudaConstantGlobal(global)) {
+            mlir::Value hostValue = src;
+            if (fir::isa_ref_type(src.getType()))
+              hostValue = fir::LoadOp::create(builder, loc, src);
+            hostValue = createConvertOp(rewriter, loc, dstTy, hostValue);
+            fir::StoreOp::create(builder, loc, hostValue, addrOfOp);
+            dst = cuf::DeviceAddressOp::create(rewriter, loc, dst.getType(),
+                                               addrOfOp.getSymbol());
+          }
+        }
+      }
+
       mlir::Type i64Ty = builder.getI64Type();
       mlir::Value nbElement =
           cuf::computeElementCount(rewriter, loc, op.getShape(), dstTy, i64Ty);
@@ -288,32 +329,6 @@ struct CUFDataTransferOpConversion
       mlir::Value sourceLine =
           fir::factory::locationToLineNo(builder, loc, fTy.getInput(5));
 
-      mlir::Value dst = op.getDst();
-      mlir::Value src = op.getSrc();
-      // Host assignments to scalar CUDA constants update both the host-visible
-      // global and the device constant symbol.
-      auto getAddrOf = [](mlir::Value val) -> fir::AddrOfOp {
-        if (auto declareOp = val.getDefiningOp<fir::DeclareOp>())
-          return declareOp.getMemref().getDefiningOp<fir::AddrOfOp>();
-        if (auto declareOp = val.getDefiningOp<hlfir::DeclareOp>())
-          return declareOp.getMemref().getDefiningOp<fir::AddrOfOp>();
-        return {};
-      };
-      if (op.getTransferKind() == cuf::DataTransferKind::HostDevice) {
-        if (fir::AddrOfOp addrOfOp = getAddrOf(dst)) {
-          auto global = symtab.lookup<fir::GlobalOp>(
-              addrOfOp.getSymbol().getRootReference().getValue());
-          if (isScalarCudaConstantGlobal(global)) {
-            mlir::Value hostValue = src;
-            if (fir::isa_ref_type(src.getType()))
-              hostValue = fir::LoadOp::create(builder, loc, src);
-            hostValue = createConvertOp(rewriter, loc, dstTy, hostValue);
-            fir::StoreOp::create(builder, loc, hostValue, addrOfOp);
-            dst = cuf::DeviceAddressOp::create(rewriter, loc, dst.getType(),
-                                               addrOfOp.getSymbol());
-          }
-        }
-      }
       // Materialize the src if constant.
       if (matchPattern(src.getDefiningOp(), mlir::m_Constant())) {
         mlir::Value temp = builder.createTemporary(loc, srcTy);
diff --git a/flang/test/Fir/CUDA/cuda-global-addr.mlir b/flang/test/Fir/CUDA/cuda-global-addr.mlir
index 1bf1cd350a669..aed86312e7af0 100644
--- a/flang/test/Fir/CUDA/cuda-global-addr.mlir
+++ b/flang/test/Fir/CUDA/cuda-global-addr.mlir
@@ -151,6 +151,13 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<f80, dense<128> :
     fir.call @_QPuse_index(%3) : (index) -> ()
     return
   }
+  func.func @_QQconstant_scalar_device_to_host() attributes {fir.bindc_name = "T"} {
+    %0 = fir.address_of(@_QMcon3Ezzz) : !fir.ref<i32>
+    %1 = fir.declare %0 {data_attr = #cuf.cuda<constant>, uniq_name = "_QMcon3Ezzz"} : (!fir.ref<i32>) -> !fir.ref<i32>
+    %2 = fir.alloca i32
+    cuf.data_transfer %1 to %2 {transfer_kind = #cuf.cuda_transfer<device_host>} : !fir.ref<i32>, !fir.ref<i32>
+    return
+  }
   func.func private @_QPuse_index(index)
   fir.global @_QMcon3Ezzz {data_attr = #cuf.cuda<constant>} : i32
 }
@@ -162,3 +169,11 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<f80, dense<128> :
 // CHECK: fir.store %{{.*}} to %[[ADDR]] : !fir.ref<i32>
 // CHECK: fir.call @_FortranACUFGetDeviceAddress
 // CHECK-NOT: fir.load %{{.*}} : !fir.ref<i32>
+// CHECK: fir.call @_QPuse_index
+// CHECK-LABEL: func.func @_QQconstant_scalar_device_to_host()
+// CHECK: %[[ADDR:.*]] = fir.address_of(@_QMcon3Ezzz) : !fir.ref<i32>
+// CHECK: %[[DECL:.*]] = fir.declare %[[ADDR]] {data_attr = #cuf.cuda<constant>, uniq_name = "_QMcon3Ezzz"} : (!fir.ref<i32>) -> !fir.ref<i32>
+// CHECK: %[[DST:.*]] = fir.alloca i32
+// CHECK: %[[VALUE:.*]] = fir.load %[[DECL]] : !fir.ref<i32>
+// CHECK: fir.store %[[VALUE]] to %[[DST]] : !fir.ref<i32>
+// CHECK-NOT: fir.call @_FortranACUFDataTransferPtrPtr

>From ebd70d28314b2ec61fd8c072c62f29712d213e01 Mon Sep 17 00:00:00 2001
From: Zhen Wang <zhenw at nvidia.com>
Date: Tue, 16 Jun 2026 09:46:43 -0700
Subject: [PATCH 2/2] Move transfer-size and runtime-call setup back ahead of the scalar-constant handling

---
 .../Transforms/CUDA/CUFOpConversion.cpp       | 51 +++++++++----------
 1 file changed, 25 insertions(+), 26 deletions(-)

diff --git a/flang/lib/Optimizer/Transforms/CUDA/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUDA/CUFOpConversion.cpp
index 47270d5588f5c..6ea3c1fbdddf9 100644
--- a/flang/lib/Optimizer/Transforms/CUDA/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUDA/CUFOpConversion.cpp
@@ -263,6 +263,31 @@ struct CUFDataTransferOpConversion
         return mlir::success();
       }
 
+      mlir::Type i64Ty = builder.getI64Type();
+      mlir::Value nbElement =
+          cuf::computeElementCount(rewriter, loc, op.getShape(), dstTy, i64Ty);
+      unsigned width = 0;
+      if (fir::isa_derived(fir::unwrapSequenceType(dstTy))) {
+        mlir::Type structTy =
+            typeConverter->convertType(fir::unwrapSequenceType(dstTy));
+        width = dl->getTypeSizeInBits(structTy) / 8;
+      } else {
+        width = cuf::computeElementByteSize(loc, dstTy, kindMap);
+      }
+      mlir::Value widthValue = mlir::arith::ConstantOp::create(
+          rewriter, loc, i64Ty, rewriter.getIntegerAttr(i64Ty, width));
+      mlir::Value bytes = nbElement ? mlir::arith::MulIOp::create(
+                                          rewriter, loc, nbElement, widthValue)
+                                    : widthValue;
+
+      mlir::func::FuncOp func =
+          fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferPtrPtr)>(loc,
+                                                                       builder);
+      auto fTy = func.getFunctionType();
+      mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
+      mlir::Value sourceLine =
+          fir::factory::locationToLineNo(builder, loc, fTy.getInput(5));
+
       mlir::Value dst = op.getDst();
       mlir::Value src = op.getSrc();
       // Scalar CUDA constants keep a host shadow for host reads. Host-to-device
@@ -303,32 +328,6 @@ struct CUFDataTransferOpConversion
           }
         }
       }
-
-      mlir::Type i64Ty = builder.getI64Type();
-      mlir::Value nbElement =
-          cuf::computeElementCount(rewriter, loc, op.getShape(), dstTy, i64Ty);
-      unsigned width = 0;
-      if (fir::isa_derived(fir::unwrapSequenceType(dstTy))) {
-        mlir::Type structTy =
-            typeConverter->convertType(fir::unwrapSequenceType(dstTy));
-        width = dl->getTypeSizeInBits(structTy) / 8;
-      } else {
-        width = cuf::computeElementByteSize(loc, dstTy, kindMap);
-      }
-      mlir::Value widthValue = mlir::arith::ConstantOp::create(
-          rewriter, loc, i64Ty, rewriter.getIntegerAttr(i64Ty, width));
-      mlir::Value bytes = nbElement ? mlir::arith::MulIOp::create(
-                                          rewriter, loc, nbElement, widthValue)
-                                    : widthValue;
-
-      mlir::func::FuncOp func =
-          fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferPtrPtr)>(loc,
-                                                                       builder);
-      auto fTy = func.getFunctionType();
-      mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
-      mlir::Value sourceLine =
-          fir::factory::locationToLineNo(builder, loc, fTy.getInput(5));
-
       // Materialize the src if constant.
       if (matchPattern(src.getDefiningOp(), mlir::m_Constant())) {
         mlir::Value temp = builder.createTemporary(loc, srcTy);



More information about the flang-commits mailing list