[llvm-branch-commits] [llvm] [mlir] [MLIR][OpenMP] Post-translate declare-target USM indirection in OpenMPIRBuilder (PR #194291)
Kareem Ergawy via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Sun Apr 26 22:29:10 PDT 2026
https://github.com/ergawy updated https://github.com/llvm/llvm-project/pull/194291
>From f1155516b0a85d1bdb92ec56951915c6bf3065e1 Mon Sep 17 00:00:00 2001
From: ergawy <kareem.ergawy at amd.com>
Date: Thu, 23 Apr 2026 00:39:22 -0500
Subject: [PATCH] [MLIR][OpenMP] Post-translate declare-target USM indirection
in OpenMPIRBuilder
When lowering OpenMP to LLVM IR for the target device, record each pair of a
`declare target` device global and its OMPIRBuilder "ref" pointer global
(used for unified shared memory) in `OpenMPIRBuilder`. At the end of
`OpenMPIRBuilder::finalize`, run a post-pass that rewrites the remaining
instruction uses of the original global: each use is replaced by a load from
the ref global, cast to the type of the replaced pointer when the types
differ. The same rewrite is shared by uses that go through a pointer-cast
`ConstantExpr` (addrspacecast/bitcast) of the global and by direct
instruction uses.
This follows what is done by clang for similar cases:
https://reviews.llvm.org/D63108.
Co-authored-by: Composer
Co-authored-by: Gemini Pro
---
.../llvm/Frontend/OpenMP/OMPIRBuilder.h | 20 ++++++
llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 68 +++++++++++++++++++
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 14 +++-
.../omptarget-declare-target-usm-ref-ptr.mlir | 24 +++++++
.../fortran/declare-target-usm-ref-ptr.f90 | 39 +++++++++++
5 files changed, 162 insertions(+), 3 deletions(-)
create mode 100644 mlir/test/Target/LLVMIR/omptarget-declare-target-usm-ref-ptr.mlir
create mode 100644 offload/test/offloading/fortran/declare-target-usm-ref-ptr.f90
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index dbd8f0c6b8927..3a184da7a0855 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -538,6 +538,26 @@ class OpenMPIRBuilder {
/// used in the OpenMPIRBuilder generated from OMPKinds.def.
LLVM_ABI void initialize();
+ SmallVector<std::pair<GlobalVariable *, GlobalVariable *>>
+ declareTargetUsmRefPtrPairs;
+
+ /// Replaces instruction uses of `origGV` (direct, or through a pointer-cast
+ /// constant expression) with a load from `refPtrGV`, and returns an error
+ /// if any use of `origGV` remains. Used for OpenMP `declare target` globals
+ /// under Unified Shared Memory (USM), accessed through a reference pointer.
+ Error rewriteDeclareTargetGlobalUsesWithRefPtr(GlobalVariable *origGV,
+ GlobalVariable *refPtrGV);
+
+ /// Registers a mapping between an original `declare target` global variable
+ /// and the corresponding reference pointer global variable generated for
+ /// Unified Shared Memory (USM).
+ void addDeclareTargetUsmRefPair(GlobalVariable *orig, GlobalVariable *refPtr);
+
+ /// Rewrites the uses of all `declare target` global variables registered via
+ /// `addDeclareTargetUsmRefPair` to go through their USM reference pointers.
+ /// Runs from `finalize` after translation; no-op unless `isTargetDevice()`.
+ Error finalizeDeclareTargetUsmIndirectLoads();
+
void setConfig(OpenMPIRBuilderConfig C) { Config = C; }
/// Finalize the underlying module, e.g., by outlining regions.
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 125620bd49502..778cc63d74abc 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -743,6 +743,71 @@ CallInst *OpenMPIRBuilder::createRuntimeFunctionCall(FunctionCallee Callee,
return Call;
}
+void OpenMPIRBuilder::addDeclareTargetUsmRefPair(GlobalVariable *orig,
+ GlobalVariable *refPtr) {
+ declareTargetUsmRefPtrPairs.emplace_back(orig, refPtr);
+}
+
+Error OpenMPIRBuilder::rewriteDeclareTargetGlobalUsesWithRefPtr(
+ GlobalVariable *origGV, GlobalVariable *refPtrGV) {
+ auto replaceUsesWithRefLoad = [refPtrGV](Instruction *inst, Value *replaced) {
+ IRBuilder<> b(inst);
+ Value *rep =
+ b.CreateLoad(refPtrGV->getValueType(), refPtrGV, "decltgt.ref");
+ if (rep->getType() != replaced->getType())
+ rep = b.CreatePointerBitCastOrAddrSpaceCast(rep, replaced->getType(),
+ "decltgt.as");
+ inst->replaceUsesOfWith(replaced, rep);
+ };
+
+ SmallSetVector<User *, 8> users;
+ for (User *u : origGV->users())
+ users.insert(u);
+
+ for (User *u : users) {
+ if (auto *ce = dyn_cast<ConstantExpr>(u)) {
+ const bool isPointerCast =
+ ce->getOpcode() == Instruction::AddrSpaceCast ||
+ (ce->getOpcode() == Instruction::BitCast &&
+ ce->getType()->isPointerTy());
+
+ if (ce->getOperand(0) != origGV || !isPointerCast)
+ continue;
+
+ SmallVector<User *, 8> instUsers;
+ for (User *ceUser : ce->users())
+ if (isa<Instruction>(ceUser))
+ instUsers.push_back(ceUser);
+
+ for (User *ceUser : instUsers) {
+ auto *inst = cast<Instruction>(ceUser);
+ replaceUsesWithRefLoad(inst, ce);
+ }
+
+ if (ce->use_empty())
+ ce->destroyConstant();
+ } else if (auto *insn = dyn_cast<Instruction>(u)) {
+ replaceUsesWithRefLoad(insn, origGV);
+ }
+ }
+
+ if (!origGV->use_empty())
+ return createStringError(inconvertibleErrorCode(),
+ "expected all uses of '%s' to be replaced",
+ origGV->getName().str().c_str());
+
+ return Error::success();
+}
+
+Error OpenMPIRBuilder::finalizeDeclareTargetUsmIndirectLoads() {
+ if (!Config.isTargetDevice() || declareTargetUsmRefPtrPairs.empty())
+ return Error::success();
+ for (auto [orig, ref] : declareTargetUsmRefPtrPairs)
+ if (Error Err = rewriteDeclareTargetGlobalUsesWithRefPtr(orig, ref))
+ return Err;
+ return Error::success();
+}
+
void OpenMPIRBuilder::initialize() { initializeTypes(M); }
static void raiseUserConstantDataAllocasToEntryBlock(IRBuilderBase &Builder,
@@ -948,6 +1013,9 @@ void OpenMPIRBuilder::finalize(Function *Fn) {
emitUsed("llvm.compiler.used", LLVMCompilerUsed);
}
+ if (Error Err = finalizeDeclareTargetUsmIndirectLoads())
+ report_fatal_error(std::move(Err));
+
IsFinalized = true;
}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 8614aed1ab80c..bdf738eacc113 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -30,6 +30,8 @@
#include "llvm/IR/Constants.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/GlobalValue.h"
+#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/ReplaceConstant.h"
@@ -7475,12 +7477,18 @@ convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute,
(attribute.getCaptureClause().getValue() !=
mlir::omp::DeclareTargetCaptureClause::to ||
ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
- ompBuilder->getAddrOfDeclareTargetVar(
+ llvm::Type *ptrTy = gVal->getType();
+ if (ompBuilder->Config.hasRequiresUnifiedSharedMemory())
+ ptrTy = llvm::PointerType::get(llvmModule->getContext(), 0);
+ llvm::Constant *refPtrConst = ompBuilder->getAddrOfDeclareTargetVar(
captureClause, deviceClause, isDeclaration, isExternallyVisible,
ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack, *vfs),
mangledName, generatedRefs, /*OpenMPSimd*/ false, targetTriple,
- gVal->getType(), /*GlobalInitializer*/ nullptr,
- /*VariableLinkage*/ nullptr);
+ ptrTy, /*GlobalInitializer*/ nullptr, /*VariableLinkage*/ nullptr);
+ if (auto *origGV = llvm::dyn_cast<llvm::GlobalVariable>(gVal))
+ if (auto *refPtrGV =
+ llvm::dyn_cast_or_null<llvm::GlobalVariable>(refPtrConst))
+ ompBuilder->addDeclareTargetUsmRefPair(origGV, refPtrGV);
}
}
}
diff --git a/mlir/test/Target/LLVMIR/omptarget-declare-target-usm-ref-ptr.mlir b/mlir/test/Target/LLVMIR/omptarget-declare-target-usm-ref-ptr.mlir
new file mode 100644
index 0000000000000..fdbf16914e25c
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/omptarget-declare-target-usm-ref-ptr.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// Tests that, under unified shared memory (USM), uses of a declare target
+// global in an explicit device function go through the USM reference pointer.
+
+module attributes {llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true, omp.requires = #omp<clause_requires unified_shared_memory>} {
+ // CHECK-DAG: @_QMmEnx_vals_decl_tgt_ref_ptr = weak global ptr null, align 8
+ llvm.mlir.global external @_QMmEnx_vals() {addr_space = 1 : i32, omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (link)>} : i32 {
+ %0 = llvm.mlir.zero : i32
+ llvm.return %0 : i32
+ }
+
+ // CHECK-LABEL: define void @_QMmPget_dims_noarg(ptr %0)
+ llvm.func @_QMmPget_dims_noarg(%arg0: !llvm.ptr) attributes {omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>} {
+ // CHECK: %[[REF_LOAD:.*]] = load ptr, ptr @_QMmEnx_vals_decl_tgt_ref_ptr, align 8
+ // CHECK: %[[AS_CAST:.*]] = addrspacecast ptr %[[REF_LOAD]] to ptr addrspace(1)
+ // CHECK: %[[VAL:.*]] = load i32, ptr addrspace(1) %[[AS_CAST]], align 4
+ // CHECK: store i32 %[[VAL]], ptr %0, align 4
+ %0 = llvm.mlir.addressof @_QMmEnx_vals : !llvm.ptr<1>
+ %1 = llvm.load %0 : !llvm.ptr<1> -> i32
+ llvm.store %1, %arg0 : i32, !llvm.ptr
+ llvm.return
+ }
+}
\ No newline at end of file
diff --git a/offload/test/offloading/fortran/declare-target-usm-ref-ptr.f90 b/offload/test/offloading/fortran/declare-target-usm-ref-ptr.f90
new file mode 100644
index 0000000000000..7d539a82af91b
--- /dev/null
+++ b/offload/test/offloading/fortran/declare-target-usm-ref-ptr.f90
@@ -0,0 +1,39 @@
+! Test declare target global replacement with USM reference pointer.
+!
+! REQUIRES: flang, amdgpu
+! RUN: %libomptarget-compile-fortran-generic -fopenmp-force-usm
+! RUN: env LIBOMPTARGET_INFO=16 HSA_XNACK=1 %libomptarget-run-generic 2>&1 | %fcheck-generic
+
+module m
+ implicit none
+ integer :: nx_vals
+ !$omp declare target(nx_vals)
+contains
+ subroutine get_dims_noarg(kv)
+ !$omp declare target
+ integer, intent(out) :: kv
+ kv = nx_vals
+ end subroutine get_dims_noarg
+end module m
+
+program reproducer
+ use m
+ implicit none
+ integer :: kv, kv_debug
+
+ nx_vals = 6
+ !$omp target enter data map(always, to: nx_vals)
+
+ kv_debug = -1
+ !$omp target map(tofrom: kv_debug)
+ call get_dims_noarg(kv)
+ kv_debug = kv
+ !$omp end target
+
+ print *, 'kv_debug after target (host)', kv_debug
+
+ !$omp target exit data map(release: nx_vals)
+end program reproducer
+
+! CHECK: PluginInterface device {{[0-9]+}} info: Launching kernel
+ ! CHECK: kv_debug after target (host) 6
More information about the llvm-branch-commits
mailing list