[flang-commits] [flang] [flang][CodeGen] Fix address space mismatch for CUF globals in AddrOfOpConversion (PR #190408)
via flang-commits
flang-commits at lists.llvm.org
Fri Apr 3 14:19:28 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-fir-hlfir
Author: khaki3
<details>
<summary>Changes</summary>
When lowering fir.address_of to llvm.mlir.addressof, the AddrOfOpConversion
only looked up llvm.mlir.global to determine the address space. During
conversion, the global may still be a fir.global (not yet converted), causing
the lookup to fail and fall back to address space 0. For CUF constant memory
globals (address space 4), this creates a mismatch between the addressof
pointer and the global's address space, triggering a verification error.
Fix this by also looking up the fir::GlobalOp when no mlir::LLVM::GlobalOp is
found, and deriving the address space from its CUF data_attr (constant, shared,
or managed; any other value yields address space 0).
---
Full diff: https://github.com/llvm/llvm-project/pull/190408.diff
2 Files Affected:
- (modified) flang/lib/Optimizer/CodeGen/CodeGen.cpp (+34-4)
- (modified) flang/test/Fir/CUDA/cuda-code-gen.mlir (+21)
``````````diff
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 2d01463cf604d..e3340887ee5dc 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -187,10 +187,36 @@ mlir::Value replaceWithAddrOfOrASCast(mlir::ConversionPatternRewriter &rewriter,
return mlir::LLVM::AddressOfOp::create(rewriter, loc, type, symName);
}
+/// Return the NVVM address space implied by a CUF data attribute on a
+/// fir::GlobalOp that has not yet been converted to llvm.mlir.global.
+static unsigned getCUFAddrSpace(fir::GlobalOp global) {
+ if (auto dataAttr = global.getDataAttr()) {
+ if (*dataAttr == cuf::DataAttribute::Constant)
+ return static_cast<unsigned>(mlir::NVVM::NVVMMemorySpace::Constant);
+ if (*dataAttr == cuf::DataAttribute::Shared)
+ return static_cast<unsigned>(mlir::NVVM::NVVMMemorySpace::Shared);
+ if (*dataAttr == cuf::DataAttribute::Managed)
+ return static_cast<unsigned>(mlir::NVVM::NVVMMemorySpace::Global);
+ }
+ return 0;
+}
+
/// Lower `fir.address_of` operation to `llvm.address_of` operation.
struct AddrOfOpConversion : public fir::FIROpConversion<fir::AddrOfOp> {
using FIROpConversion::FIROpConversion;
+ /// Look up the address space for a symbol in \p mod, handling both
+ /// already-converted llvm.mlir.global and not-yet-converted fir.global.
+ template <typename ModOp>
+ unsigned getAddrSpaceForGlobal(ModOp mod, mlir::SymbolRefAttr sym,
+ unsigned fallback) const {
+ if (auto g = mod.template lookupSymbol<mlir::LLVM::GlobalOp>(sym))
+ return g.getAddrSpace();
+ if (auto g = mod.template lookupSymbol<fir::GlobalOp>(sym))
+ return getCUFAddrSpace(g);
+ return fallback;
+ }
+
llvm::LogicalResult
matchAndRewrite(fir::AddrOfOp addr, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
@@ -199,7 +225,9 @@ struct AddrOfOpConversion : public fir::FIROpConversion<fir::AddrOfOp> {
auto global = gpuMod.lookupSymbol<mlir::LLVM::GlobalOp>(addr.getSymbol());
replaceWithAddrOfOrASCast(
rewriter, addr->getLoc(),
- global ? global.getAddrSpace() : getGlobalAddressSpace(rewriter),
+ global ? global.getAddrSpace()
+ : getAddrSpaceForGlobal(gpuMod, addr.getSymbol(),
+ getGlobalAddressSpace(rewriter)),
getProgramAddressSpace(rewriter),
global ? global.getSymName()
: addr.getSymbol().getRootReference().getValue(),
@@ -207,11 +235,13 @@ struct AddrOfOpConversion : public fir::FIROpConversion<fir::AddrOfOp> {
return mlir::success();
}
- auto global = addr->getParentOfType<mlir::ModuleOp>()
- .lookupSymbol<mlir::LLVM::GlobalOp>(addr.getSymbol());
+ auto mod = addr->getParentOfType<mlir::ModuleOp>();
+ auto global = mod.lookupSymbol<mlir::LLVM::GlobalOp>(addr.getSymbol());
replaceWithAddrOfOrASCast(
rewriter, addr->getLoc(),
- global ? global.getAddrSpace() : getGlobalAddressSpace(rewriter),
+ global ? global.getAddrSpace()
+ : getAddrSpaceForGlobal(mod, addr.getSymbol(),
+ getGlobalAddressSpace(rewriter)),
getProgramAddressSpace(rewriter),
global ? global.getSymName()
: addr.getSymbol().getRootReference().getValue(),
diff --git a/flang/test/Fir/CUDA/cuda-code-gen.mlir b/flang/test/Fir/CUDA/cuda-code-gen.mlir
index fc962f8de5039..923c15e07edd4 100644
--- a/flang/test/Fir/CUDA/cuda-code-gen.mlir
+++ b/flang/test/Fir/CUDA/cuda-code-gen.mlir
@@ -328,3 +328,24 @@ module attributes {gpu.container_module, dlti.dl_spec = #dlti.dl_spec<#dlti.dl_e
}
// CHECK: llvm.mlir.global external @_QMtestEmanx() {addr_space = 1 : i32, nvvm.managed} : !llvm.array<100 x i32>
+
+// -----
+
+// Test that a host-side fir.address_of referencing a fir.global with CUF
+// constant data_attr produces an addrspacecast from ptr<4> to ptr.
+
+module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<f80, dense<128> : vector<2xi64>>, #dlti.dl_entry<i128, dense<128> : vector<2xi64>>, #dlti.dl_entry<i64, dense<64> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr<272>, dense<64> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<271>, dense<32> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<270>, dense<32> : vector<4xi64>>, #dlti.dl_entry<f128, dense<128> : vector<2xi64>>, #dlti.dl_entry<f64, dense<64> : vector<2xi64>>, #dlti.dl_entry<f16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i32, dense<32> : vector<2xi64>>, #dlti.dl_entry<i16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i8, dense<8> : vector<2xi64>>, #dlti.dl_entry<i1, dense<8> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr, dense<64> : vector<4xi64>>, #dlti.dl_entry<"dlti.endianness", "little">, #dlti.dl_entry<"dlti.stack_alignment", 128 : i64>>} {
+ fir.global @_QMmodEcval {data_attr = #cuf.cuda<constant>} : i32 {
+ %0 = fir.zero_bits i32
+ fir.has_value %0 : i32
+ }
+ func.func @_QQhost() {
+ %0 = fir.address_of(@_QMmodEcval) : !fir.ref<i32>
+ return
+ }
+}
+
+// CHECK: llvm.mlir.global external @_QMmodEcval() {addr_space = 4 : i32} : i32
+// CHECK-LABEL: llvm.func @_QQhost()
+// CHECK: %[[ADDR:.*]] = llvm.mlir.addressof @_QMmodEcval : !llvm.ptr<4>
+// CHECK: %{{.*}} = llvm.addrspacecast %[[ADDR]] : !llvm.ptr<4> to !llvm.ptr
``````````
</details>
https://github.com/llvm/llvm-project/pull/190408
More information about the flang-commits mailing list