[llvm-branch-commits] [llvm] [mlir] [MLIR][OpenMP] Post-translate declare-target USM indirection in OpenMPIRBuilder (PR #194291)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Sun Apr 26 22:27:53 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-offload
Author: Kareem Ergawy (ergawy)
<details>
<summary>Changes</summary>
When lowering OpenMP to LLVM IR for the target device, record pairs of the `declare target` device global and the OMPIRBuilder "ref" pointer global (used for unified shared memory) via `OpenMPIRBuilder`. During `OpenMPIRBuilder::finalize`, run a post-pass that rewrites the remaining uses of the original global to load from the ref global and, where the types differ, cast the loaded pointer back to the type of the replaced value (a single shared path handles both `ConstantExpr` addrspace/bitcast chains and direct instruction uses).
This follows what is done by clang for similar cases: https://reviews.llvm.org/D63108.
Co-authored-by: Composer
Co-authored-by: Gemini Pro
---
Full diff: https://github.com/llvm/llvm-project/pull/194291.diff
5 Files Affected:
- (modified) llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h (+22)
- (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+69)
- (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+11-3)
- (added) mlir/test/Target/LLVMIR/omptarget-declare-target-usm-ref-ptr.mlir (+24)
- (added) offload/test/offloading/fortran/declare-target-usm-ref-ptr.f90 (+39)
``````````diff
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index dbd8f0c6b8927..f4e5d5be35604 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -538,6 +538,28 @@ class OpenMPIRBuilder {
/// used in the OpenMPIRBuilder generated from OMPKinds.def.
LLVM_ABI void initialize();
+ SmallVector<std::pair<GlobalVariable *, GlobalVariable *>>
+ declareTargetUsmRefPtrPairs;
+
+ /// Replaces all uses of `origGV` with a load from `refPtrGV`.
+ /// This is used for OpenMP `declare target` global variables mapped under
+ /// Unified Shared Memory (USM) where access is routed through a reference
+ /// pointer.
+ Error
+ rewriteDeclareTargetGlobalUsesWithRefPtr(GlobalVariable *origGV,
+ GlobalVariable *refPtrGV);
+
+ /// Registers a mapping between an original `declare target` global variable
+ /// and the corresponding reference pointer global variable generated for
+ /// Unified Shared Memory (USM).
+ void addDeclareTargetUsmRefPair(GlobalVariable *orig,
+ GlobalVariable *refPtr);
+
+ /// Rewrites the uses of all `declare target` global variables registered via
+ /// `addDeclareTargetUsmRefPair` to use their corresponding USM reference
+ /// pointers. This pass is executed at the end of the module translation.
+ Error finalizeDeclareTargetUsmIndirectLoads();
+
void setConfig(OpenMPIRBuilderConfig C) { Config = C; }
/// Finalize the underlying module, e.g., by outlining regions.
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 125620bd49502..981036a955f72 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -743,6 +743,72 @@ CallInst *OpenMPIRBuilder::createRuntimeFunctionCall(FunctionCallee Callee,
return Call;
}
+void OpenMPIRBuilder::addDeclareTargetUsmRefPair(GlobalVariable *orig,
+ GlobalVariable *refPtr) {
+ declareTargetUsmRefPtrPairs.emplace_back(orig, refPtr);
+}
+
+Error OpenMPIRBuilder::rewriteDeclareTargetGlobalUsesWithRefPtr(
+ GlobalVariable *origGV, GlobalVariable *refPtrGV) {
+ auto replaceUsesWithRefLoad = [refPtrGV](Instruction *inst,
+ Value *replaced) {
+ IRBuilder<> b(inst);
+ Value *rep =
+ b.CreateLoad(refPtrGV->getValueType(), refPtrGV, "decltgt.ref");
+ if (rep->getType() != replaced->getType())
+ rep = b.CreatePointerBitCastOrAddrSpaceCast(
+ rep, replaced->getType(), "decltgt.as");
+ inst->replaceUsesOfWith(replaced, rep);
+ };
+
+ SmallSetVector<User *, 8> users;
+ for (User *u : origGV->users())
+ users.insert(u);
+
+ for (User *u : users) {
+ if (auto *ce = dyn_cast<ConstantExpr>(u)) {
+ const bool isPointerCast =
+ ce->getOpcode() == Instruction::AddrSpaceCast ||
+ (ce->getOpcode() == Instruction::BitCast &&
+ ce->getType()->isPointerTy());
+
+ if (ce->getOperand(0) != origGV || !isPointerCast)
+ continue;
+
+ SmallVector<User *, 8> instUsers;
+ for (User *ceUser : ce->users())
+ if (isa<Instruction>(ceUser))
+ instUsers.push_back(ceUser);
+
+ for (User *ceUser : instUsers) {
+ auto *inst = cast<Instruction>(ceUser);
+ replaceUsesWithRefLoad(inst, ce);
+ }
+
+ if (ce->use_empty())
+ ce->destroyConstant();
+ } else if (auto *insn = dyn_cast<Instruction>(u)) {
+ replaceUsesWithRefLoad(insn, origGV);
+ }
+ }
+
+ if (!origGV->use_empty())
+ return createStringError(inconvertibleErrorCode(),
+ "expected all uses of '%s' to be replaced",
+ origGV->getName().str().c_str());
+
+ return Error::success();
+}
+
+Error OpenMPIRBuilder::finalizeDeclareTargetUsmIndirectLoads() {
+ if (!Config.isTargetDevice() || declareTargetUsmRefPtrPairs.empty())
+ return Error::success();
+ for (auto [orig, ref] : declareTargetUsmRefPtrPairs)
+ if (Error Err = rewriteDeclareTargetGlobalUsesWithRefPtr(orig, ref))
+ return Err;
+ return Error::success();
+}
+
void OpenMPIRBuilder::initialize() { initializeTypes(M); }
static void raiseUserConstantDataAllocasToEntryBlock(IRBuilderBase &Builder,
@@ -948,6 +1014,9 @@ void OpenMPIRBuilder::finalize(Function *Fn) {
emitUsed("llvm.compiler.used", LLVMCompilerUsed);
}
+ if (Error Err = finalizeDeclareTargetUsmIndirectLoads())
+ report_fatal_error(std::move(Err));
+
IsFinalized = true;
}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 8614aed1ab80c..97b94c4bae00f 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -30,6 +30,8 @@
#include "llvm/IR/Constants.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/GlobalValue.h"
+#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/ReplaceConstant.h"
@@ -7475,12 +7477,18 @@ convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute,
(attribute.getCaptureClause().getValue() !=
mlir::omp::DeclareTargetCaptureClause::to ||
ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
- ompBuilder->getAddrOfDeclareTargetVar(
+ llvm::Type *ptrTy = gVal->getType();
+ if (ompBuilder->Config.hasRequiresUnifiedSharedMemory())
+ ptrTy = llvm::PointerType::get(llvmModule->getContext(), 0);
+ llvm::Constant *refPtrConst = ompBuilder->getAddrOfDeclareTargetVar(
captureClause, deviceClause, isDeclaration, isExternallyVisible,
ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack, *vfs),
mangledName, generatedRefs, /*OpenMPSimd*/ false, targetTriple,
- gVal->getType(), /*GlobalInitializer*/ nullptr,
- /*VariableLinkage*/ nullptr);
+ ptrTy, /*GlobalInitializer*/ nullptr, /*VariableLinkage*/ nullptr);
+ if (auto *origGV = llvm::dyn_cast<llvm::GlobalVariable>(gVal))
+ if (auto *refPtrGV = llvm::dyn_cast_or_null<llvm::GlobalVariable>(
+ refPtrConst))
+ ompBuilder->addDeclareTargetUsmRefPair(origGV, refPtrGV);
}
}
}
diff --git a/mlir/test/Target/LLVMIR/omptarget-declare-target-usm-ref-ptr.mlir b/mlir/test/Target/LLVMIR/omptarget-declare-target-usm-ref-ptr.mlir
new file mode 100644
index 0000000000000..fdbf16914e25c
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/omptarget-declare-target-usm-ref-ptr.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// This tests the replacement of uses of a declare target global variable with
+// the unified shared memory (USM) generated reference pointer in an explicit device function.
+
+module attributes {llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true, omp.requires = #omp<clause_requires unified_shared_memory>} {
+ // CHECK-DAG: @_QMmEnx_vals_decl_tgt_ref_ptr = weak global ptr null, align 8
+ llvm.mlir.global external @_QMmEnx_vals() {addr_space = 1 : i32, omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (link)>} : i32 {
+ %0 = llvm.mlir.zero : i32
+ llvm.return %0 : i32
+ }
+
+ // CHECK-LABEL: define void @_QMmPget_dims_noarg(ptr %0)
+ llvm.func @_QMmPget_dims_noarg(%arg0: !llvm.ptr) attributes {omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>} {
+ // CHECK: %[[REF_LOAD:.*]] = load ptr, ptr @_QMmEnx_vals_decl_tgt_ref_ptr, align 8
+ // CHECK: %[[AS_CAST:.*]] = addrspacecast ptr %[[REF_LOAD]] to ptr addrspace(1)
+ // CHECK: %[[VAL:.*]] = load i32, ptr addrspace(1) %[[AS_CAST]], align 4
+ // CHECK: store i32 %[[VAL]], ptr %0, align 4
+ %0 = llvm.mlir.addressof @_QMmEnx_vals : !llvm.ptr<1>
+ %1 = llvm.load %0 : !llvm.ptr<1> -> i32
+ llvm.store %1, %arg0 : i32, !llvm.ptr
+ llvm.return
+ }
+}
\ No newline at end of file
diff --git a/offload/test/offloading/fortran/declare-target-usm-ref-ptr.f90 b/offload/test/offloading/fortran/declare-target-usm-ref-ptr.f90
new file mode 100644
index 0000000000000..7d539a82af91b
--- /dev/null
+++ b/offload/test/offloading/fortran/declare-target-usm-ref-ptr.f90
@@ -0,0 +1,39 @@
+! Test declare target global replacement with USM reference pointer.
+!
+! REQUIRES: flang, amdgpu
+! RUN: %libomptarget-compile-fortran-generic -fopenmp-force-usm
+! RUN: env LIBOMPTARGET_INFO=16 HSA_XNACK=1 %libomptarget-run-generic 2>&1 | %fcheck-generic
+
+module m
+ implicit none
+ integer :: nx_vals
+ !$omp declare target(nx_vals)
+contains
+ subroutine get_dims_noarg(kv)
+ !$omp declare target
+ integer, intent(out) :: kv
+ kv = nx_vals
+ end subroutine get_dims_noarg
+end module m
+
+program reproducer
+ use m
+ implicit none
+ integer :: kv, kv_debug
+
+ nx_vals = 6
+ !$omp target enter data map(always, to: nx_vals)
+
+ kv_debug = -1
+ !$omp target map(tofrom: kv_debug)
+ call get_dims_noarg(kv)
+ kv_debug = kv
+ !$omp end target
+
+ print *, 'kv_debug after target (host)', kv_debug
+
+ !$omp target exit data map(release: nx_vals)
+end program reproducer
+
+! CHECK: PluginInterface device {{[0-9]+}} info: Launching kernel
+ ! CHECK: kv_debug after target (host) 6
``````````
</details>
https://github.com/llvm/llvm-project/pull/194291
More information about the llvm-branch-commits
mailing list