[llvm-branch-commits] [llvm] [mlir] [MLIR][OpenMP] Post-translate declare-target USM indirection in OpenMPIRBuilder (PR #194291)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Sun Apr 26 22:27:53 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-offload

Author: Kareem Ergawy (ergawy)

<details>
<summary>Changes</summary>

When lowering OpenMP to LLVM IR for the target device, use `OpenMPIRBuilder` to record each pair consisting of a `declare target` device global and the OMPIRBuilder "ref" pointer global created for it (used for unified shared memory). During `OpenMPIRBuilder::finalize`, a post-pass rewrites the remaining uses of the original global. Each use now loads the address from the ref global and casts it to the expected pointer type. The same rewrite path serves both addrspacecast/bitcast `ConstantExpr` users and direct instruction uses.

This mirrors what Clang does in similar cases: https://reviews.llvm.org/D63108.

Co-authored-by: Composer
Co-authored-by: Gemini Pro

---
Full diff: https://github.com/llvm/llvm-project/pull/194291.diff


5 Files Affected:

- (modified) llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h (+22) 
- (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+69) 
- (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+11-3) 
- (added) mlir/test/Target/LLVMIR/omptarget-declare-target-usm-ref-ptr.mlir (+24) 
- (added) offload/test/offloading/fortran/declare-target-usm-ref-ptr.f90 (+39) 


``````````diff
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index dbd8f0c6b8927..f4e5d5be35604 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -538,6 +538,32 @@ class OpenMPIRBuilder {
   /// used in the OpenMPIRBuilder generated from OMPKinds.def.
   LLVM_ABI void initialize();
 
+  /// Pairs of (original `declare target` global, reference pointer global)
+  /// registered through `addDeclareTargetUsmRefPair` and processed by
+  /// `finalizeDeclareTargetUsmIndirectLoads`.
+  SmallVector<std::pair<GlobalVariable *, GlobalVariable *>>
+      declareTargetUsmRefPtrPairs;
+
+  /// Replaces all uses of `origGV` with a load from `refPtrGV`.
+  /// This is used for OpenMP `declare target` global variables mapped under
+  /// Unified Shared Memory (USM) where access is routed through a reference
+  /// pointer. Returns an error if any use of `origGV` remains afterwards.
+  LLVM_ABI Error
+  rewriteDeclareTargetGlobalUsesWithRefPtr(GlobalVariable *origGV,
+                                           GlobalVariable *refPtrGV);
+
+  /// Registers a mapping between an original `declare target` global variable
+  /// and the corresponding reference pointer global variable generated for
+  /// Unified Shared Memory (USM).
+  LLVM_ABI void addDeclareTargetUsmRefPair(GlobalVariable *orig,
+                                           GlobalVariable *refPtr);
+
+  /// Rewrites the uses of all `declare target` global variables registered via
+  /// `addDeclareTargetUsmRefPair` to use their corresponding USM reference
+  /// pointers. Uses are only rewritten when compiling for the target device;
+  /// this is run from `finalize()`.
+  LLVM_ABI Error finalizeDeclareTargetUsmIndirectLoads();
+
   void setConfig(OpenMPIRBuilderConfig C) { Config = C; }
 
   /// Finalize the underlying module, e.g., by outlining regions.
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 125620bd49502..981036a955f72 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -743,6 +743,117 @@ CallInst *OpenMPIRBuilder::createRuntimeFunctionCall(FunctionCallee Callee,
   return Call;
 }
 
+void OpenMPIRBuilder::addDeclareTargetUsmRefPair(GlobalVariable *orig,
+                                                 GlobalVariable *refPtr) {
+  declareTargetUsmRefPtrPairs.emplace_back(orig, refPtr);
+}
+
+Error OpenMPIRBuilder::rewriteDeclareTargetGlobalUsesWithRefPtr(
+    GlobalVariable *origGV, GlobalVariable *refPtrGV) {
+  // Emits `load refPtrGV` right before `insertPt`, cast to `ty` if needed.
+  auto emitRefLoad = [refPtrGV](Instruction *insertPt, Type *ty) -> Value * {
+    IRBuilder<> b(insertPt);
+    Value *rep =
+        b.CreateLoad(refPtrGV->getValueType(), refPtrGV, "decltgt.ref");
+    if (rep->getType() != ty)
+      rep = b.CreatePointerBitCastOrAddrSpaceCast(rep, ty, "decltgt.as");
+    return rep;
+  };
+
+  // Replaces the operands of `inst` that equal `replaced` with a loaded
+  // reference pointer.
+  auto replaceUsesWithRefLoad = [&emitRefLoad](Instruction *inst,
+                                               Value *replaced) {
+    auto *phi = dyn_cast<PHINode>(inst);
+    if (!phi) {
+      Value *rep = emitRefLoad(inst, replaced->getType());
+      inst->replaceUsesOfWith(replaced, rep);
+      return;
+    }
+    // Nothing may be inserted before a PHI node, so load at the end of each
+    // incoming block instead. Entries for the same predecessor must carry
+    // the same value, hence a single load per block.
+    SmallDenseMap<BasicBlock *, Value *, 4> repPerBlock;
+    for (unsigned i = 0, e = phi->getNumIncomingValues(); i != e; ++i) {
+      if (phi->getIncomingValue(i) != replaced)
+        continue;
+      BasicBlock *pred = phi->getIncomingBlock(i);
+      Value *&rep = repPerBlock[pred];
+      if (!rep)
+        rep = emitRefLoad(pred->getTerminator(), replaced->getType());
+      phi->setIncomingValue(i, rep);
+    }
+  };
+
+  // Pointer casts of `origGV` that feed instructions directly are replaced
+  // as a whole, so no cast round-trip through `origGV`'s type is emitted.
+  SmallSetVector<User *, 8> users;
+  for (User *u : origGV->users())
+    users.insert(u);
+
+  for (User *u : users) {
+    auto *ce = dyn_cast<ConstantExpr>(u);
+    if (!ce)
+      continue;
+    const bool isPointerCast =
+        ce->getOpcode() == Instruction::AddrSpaceCast ||
+        (ce->getOpcode() == Instruction::BitCast &&
+         ce->getType()->isPointerTy());
+    if (ce->getOperand(0) != origGV || !isPointerCast)
+      continue;
+
+    // De-duplicated: an instruction may use `ce` more than once.
+    SmallSetVector<Instruction *, 8> ceInstUsers;
+    for (User *ceUser : ce->users())
+      if (auto *inst = dyn_cast<Instruction>(ceUser))
+        ceInstUsers.insert(inst);
+    for (Instruction *inst : ceInstUsers)
+      replaceUsesWithRefLoad(inst, ce);
+
+    if (ce->use_empty())
+      ce->destroyConstant();
+  }
+
+  // Expand every other constant expression that reaches code (GEPs, nested
+  // cast chains, ...) into instructions so that `origGV` becomes a direct
+  // instruction operand; dead constant users of `origGV` are dropped as well.
+  convertUsersOfConstantsToInstructions(origGV);
+
+  SmallSetVector<Instruction *, 8> instUsers;
+  for (User *u : origGV->users())
+    if (auto *inst = dyn_cast<Instruction>(u))
+      instUsers.insert(inst);
+  for (Instruction *inst : instUsers)
+    replaceUsesWithRefLoad(inst, origGV);
+
+  // What can remain are non-instruction users (e.g. initializers of other
+  // globals), which cannot be turned into a load.
+  if (!origGV->use_empty())
+    return createStringError(inconvertibleErrorCode(),
+                             "expected all uses of '%s' to be replaced",
+                             origGV->getName().str().c_str());
+
+  return Error::success();
+}
+
+Error OpenMPIRBuilder::finalizeDeclareTargetUsmIndirectLoads() {
+  // Nothing registered: do not query the configuration at all.
+  if (declareTargetUsmRefPtrPairs.empty())
+    return Error::success();
+
+  // Take the pending pairs so that a later finalize() call does not revisit
+  // globals that were already rewritten (or have been erased since).
+  auto pairs = std::move(declareTargetUsmRefPtrPairs);
+  declareTargetUsmRefPtrPairs.clear();
+  if (!Config.isTargetDevice())
+    return Error::success();
+
+  for (auto [orig, ref] : pairs)
+    if (Error Err = rewriteDeclareTargetGlobalUsesWithRefPtr(orig, ref))
+      return Err;
+  return Error::success();
+}
+
 void OpenMPIRBuilder::initialize() { initializeTypes(M); }
 
 static void raiseUserConstantDataAllocasToEntryBlock(IRBuilderBase &Builder,
@@ -948,6 +1059,9 @@ void OpenMPIRBuilder::finalize(Function *Fn) {
     emitUsed("llvm.compiler.used", LLVMCompilerUsed);
   }
 
+  if (Error Err = finalizeDeclareTargetUsmIndirectLoads())
+    report_fatal_error(std::move(Err));
+
   IsFinalized = true;
 }
 
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 8614aed1ab80c..97b94c4bae00f 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -30,6 +30,8 @@
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DebugInfoMetadata.h"
 #include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/GlobalValue.h"
+#include "llvm/IR/GlobalVariable.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/ReplaceConstant.h"
@@ -7475,12 +7477,23 @@ convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute,
           (attribute.getCaptureClause().getValue() !=
                mlir::omp::DeclareTargetCaptureClause::to ||
            ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
-        ompBuilder->getAddrOfDeclareTargetVar(
+        // Under unified shared memory the reference pointer is created with a
+        // default (address space 0) pointer type rather than the global's own
+        // pointer type; rewritten uses cast the loaded value back as needed.
+        llvm::Type *ptrTy = gVal->getType();
+        if (ompBuilder->Config.hasRequiresUnifiedSharedMemory())
+          ptrTy = llvm::PointerType::get(llvmModule->getContext(), 0);
+        llvm::Constant *refPtrConst = ompBuilder->getAddrOfDeclareTargetVar(
             captureClause, deviceClause, isDeclaration, isExternallyVisible,
             ompBuilder->getTargetEntryUniqueInfo(fileInfoCallBack, *vfs),
             mangledName, generatedRefs, /*OpenMPSimd*/ false, targetTriple,
-            gVal->getType(), /*GlobalInitializer*/ nullptr,
-            /*VariableLinkage*/ nullptr);
+            ptrTy, /*GlobalInitializer*/ nullptr, /*VariableLinkage*/ nullptr);
+        // Record the pair so OpenMPIRBuilder::finalize() can route the device
+        // uses of the original global through the reference pointer.
+        if (auto *origGV = llvm::dyn_cast<llvm::GlobalVariable>(gVal))
+          if (auto *refPtrGV =
+                  llvm::dyn_cast_if_present<llvm::GlobalVariable>(refPtrConst))
+            ompBuilder->addDeclareTargetUsmRefPair(origGV, refPtrGV);
       }
     }
   }
diff --git a/mlir/test/Target/LLVMIR/omptarget-declare-target-usm-ref-ptr.mlir b/mlir/test/Target/LLVMIR/omptarget-declare-target-usm-ref-ptr.mlir
new file mode 100644
index 0000000000000..fdbf16914e25c
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/omptarget-declare-target-usm-ref-ptr.mlir
@@ -0,0 +1,42 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// This tests the replacement of uses of a declare target global variable with
+// the unified shared memory (USM) generated reference pointer in explicit
+// device functions, both for direct uses of the global and for uses through a
+// pointer-cast constant expression.
+
+module attributes {llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true, omp.requires = #omp<clause_requires unified_shared_memory>} {
+  // CHECK-DAG: @_QMmEnx_vals_decl_tgt_ref_ptr = weak global ptr null, align 8
+  llvm.mlir.global external @_QMmEnx_vals() {addr_space = 1 : i32, omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (link)>} : i32 {
+    %0 = llvm.mlir.zero : i32
+    llvm.return %0 : i32
+  }
+
+  // CHECK-LABEL: define void @_QMmPget_dims_noarg(ptr %0)
+  llvm.func @_QMmPget_dims_noarg(%arg0: !llvm.ptr) attributes {omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>} {
+    // CHECK: %[[REF_LOAD:.*]] = load ptr, ptr @_QMmEnx_vals_decl_tgt_ref_ptr, align 8
+    // CHECK: %[[AS_CAST:.*]] = addrspacecast ptr %[[REF_LOAD]] to ptr addrspace(1)
+    // CHECK: %[[VAL:.*]] = load i32, ptr addrspace(1) %[[AS_CAST]], align 4
+    // CHECK: store i32 %[[VAL]], ptr %0, align 4
+    %0 = llvm.mlir.addressof @_QMmEnx_vals : !llvm.ptr<1>
+    %1 = llvm.load %0 : !llvm.ptr<1> -> i32
+    llvm.store %1, %arg0 : i32, !llvm.ptr
+    llvm.return
+  }
+
+  // The address space cast of the global is expected to be folded by IRBuilder
+  // into a ConstantExpr, which is replaced as a whole by the loaded reference
+  // pointer (no cast is needed since both are `ptr`).
+  // CHECK-LABEL: define void @_QMmPget_dims_cast(ptr %0)
+  llvm.func @_QMmPget_dims_cast(%arg0: !llvm.ptr) attributes {omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>} {
+    // CHECK: %[[REF_LOAD:.*]] = load ptr, ptr @_QMmEnx_vals_decl_tgt_ref_ptr, align 8
+    // CHECK-NOT: addrspacecast
+    // CHECK: %[[VAL:.*]] = load i32, ptr %[[REF_LOAD]], align 4
+    // CHECK: store i32 %[[VAL]], ptr %0, align 4
+    %0 = llvm.mlir.addressof @_QMmEnx_vals : !llvm.ptr<1>
+    %1 = llvm.addrspacecast %0 : !llvm.ptr<1> to !llvm.ptr
+    %2 = llvm.load %1 : !llvm.ptr -> i32
+    llvm.store %2, %arg0 : i32, !llvm.ptr
+    llvm.return
+  }
+}
diff --git a/offload/test/offloading/fortran/declare-target-usm-ref-ptr.f90 b/offload/test/offloading/fortran/declare-target-usm-ref-ptr.f90
new file mode 100644
index 0000000000000..7d539a82af91b
--- /dev/null
+++ b/offload/test/offloading/fortran/declare-target-usm-ref-ptr.f90
@@ -0,0 +1,39 @@
+! Test declare target global replacement with USM reference pointer.
+!
+! REQUIRES: flang, amdgpu, unified_shared_memory
+! RUN: %libomptarget-compile-fortran-generic -fopenmp-force-usm
+! RUN: env LIBOMPTARGET_INFO=16 HSA_XNACK=1 %libomptarget-run-generic 2>&1 | %fcheck-generic
+
+module m
+    implicit none
+    integer :: nx_vals
+    !$omp declare target(nx_vals)
+contains
+    subroutine get_dims_noarg(kv)
+        !$omp declare target
+        integer, intent(out) :: kv
+        kv = nx_vals
+    end subroutine get_dims_noarg
+end module m
+
+program reproducer
+    use m
+    implicit none
+    integer :: kv, kv_debug
+
+    nx_vals = 6
+    !$omp target enter data map(always, to: nx_vals)
+
+    kv_debug = -1
+    !$omp target map(tofrom: kv_debug)
+    call get_dims_noarg(kv)
+    kv_debug = kv
+    !$omp end target
+
+    print *, 'kv_debug after target (host)', kv_debug
+
+    !$omp target exit data map(release: nx_vals)
+end program reproducer
+
+! CHECK: PluginInterface device {{[0-9]+}} info: Launching kernel
+! CHECK: kv_debug after target (host) 6

``````````

</details>


https://github.com/llvm/llvm-project/pull/194291


More information about the llvm-branch-commits mailing list