[Mlir-commits] [flang] [mlir] [acc][flang] Add isDeviceData APIs for device data detection (PR #176219)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jan 15 10:42:48 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Razvan Lupusoru (razvanlupusoru)

<details>
<summary>Changes</summary>

Add comprehensive APIs to detect device-resident data across OpenACC type and operation interfaces. This enables passes to identify data that is already on the device (e.g., CUF device/managed/constant memory, GPU address spaces) and handle it appropriately.

New interface methods:
- PointerLikeType::isDeviceData(Value): Returns true if the pointer points to device data.
- MappableType::isDeviceData(Value): Returns true if the variable represents device data.
- GlobalVariableOpInterface::isDeviceData(): Returns true if the global variable is device data.

New utilities in OpenACCUtils:
- acc::isDeviceValue(Value): Checks if a value represents device data by querying type interfaces, PartialEntityAccessOpInterface for base entities, and AddressOfGlobalOpInterface for global symbols.
- acc::isValidValueUse(Value, Region): Checks if a value is legal in an OpenACC region by verifying it comes from a data operation, is only used by private clauses, or is device data.

Updated isValidSymbolUse to check GlobalVariableOpInterface::isDeviceData() for symbols referencing device-resident globals.

FIR implementations check for CUF data attributes (device, managed, constant, shared, unified) on operations, block arguments, and globals. The implementation traces through fir.rebox, fir.embox, fir.declare, hlfir.declare, and fir.address_of to find the underlying data source.
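
The trace-through described above can be pictured with a small standalone model. This is only a sketch under stated assumptions: `Node`, `NodeKind`, `DataAttr`, and `isDeviceData` below are hypothetical stand-ins invented for illustration, not the FIR, CUF, or OpenACC dialect APIs, and the real implementation inspects MLIR operations and block arguments rather than a linked chain of structs.

```cpp
#include <cassert>
#include <optional>

// Toy stand-ins (hypothetical, not the cuf:: or fir:: types).
// HostOnly is an invented negative case for the sketch.
enum class DataAttr { HostOnly, Device, Managed, Constant, Shared, Unified };

// Kinds of wrapper the summary says are traced through, plus a terminal
// "Source" kind for whatever finally produces the data.
enum class NodeKind { Source, Rebox, Embox, Declare, AddressOf };

struct Node {
  NodeKind kind;
  std::optional<DataAttr> attr;  // data attribute attached here, if any
  const Node *operand = nullptr; // value this node wraps or refers to
};

// The five attributes the summary lists as marking device data.
bool isDeviceAttr(DataAttr a) {
  switch (a) {
  case DataAttr::Device:
  case DataAttr::Managed:
  case DataAttr::Constant:
  case DataAttr::Shared:
  case DataAttr::Unified:
    return true;
  case DataAttr::HostOnly:
    return false;
  }
  return false;
}

// Walk from a value toward its underlying source, stopping at the first
// node that carries a device attribute.
bool isDeviceData(const Node *n) {
  while (n) {
    if (n->attr && isDeviceAttr(*n->attr))
      return true;
    if (n->kind == NodeKind::Source)
      return false; // reached the source without finding a device attribute
    n = n->operand;
  }
  return false;
}
```

In this model, a box built over a declare of an address-of a managed global is reported as device data, while the same chain over an unattributed source is not.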

Memref implementations check for gpu::AddressSpaceAttr on the memref type.
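
As a rough illustration of the memref rule, the sketch below models a memory space that may be absent, a GPU address space, or some other attribute (here a plain integer). All names are hypothetical stand-ins rather than MLIR types. The point it tries to show is that only the GPU address-space case counts as device data, so a missing or differently typed memory space does not.

```cpp
#include <cassert>
#include <variant>

// Hypothetical stand-in for a GPU address-space attribute.
enum class GpuAddressSpace { Global, Workgroup, Private };

// Toy memref type: the memory space is either absent (monostate), a GPU
// address space, or an unrelated integer memory space.
struct MemRefType {
  std::variant<std::monostate, GpuAddressSpace, int> memorySpace;
};

// Device data only when the memory space is present and is a GPU address
// space; any other attribute kind (or none) yields false.
bool isDeviceData(const MemRefType &ty) {
  return std::holds_alternative<GpuAddressSpace>(ty.memorySpace);
}
```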

Updated ACCImplicitData to use acc::isDeviceValue so that device-resident data gets an acc.deviceptr clause instead of copyin/copyout.
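
The ordering of checks matters here: a device value has to be recognized before the general "valid use" test, otherwise it would be skipped and never receive a deviceptr clause. The following sketch models that ordering with plain booleans. `ValueInfo`, `ImplicitClause`, and `chooseImplicitClause` are hypothetical names for illustration only, and `OtherDataClause` stands in for whatever non-deviceptr clause the pass would otherwise pick.

```cpp
#include <cassert>

// Hypothetical summary of what the pass knows about a live-in value.
struct ValueInfo {
  bool fromDataEntryOp;   // already produced by an ACC data entry operation
  bool isDeviceValue;     // device-resident according to the new query
  bool otherwiseValidUse; // valid in the region for some other reason
};

enum class ImplicitClause { None, DevicePtr, OtherDataClause };

ImplicitClause chooseImplicitClause(const ValueInfo &v) {
  // Already covered by a data clause: nothing to generate.
  if (v.fromDataEntryOp)
    return ImplicitClause::None;
  // Device data is checked before the generic validity test so that it
  // still becomes a candidate and receives a deviceptr clause.
  if (v.isDeviceValue)
    return ImplicitClause::DevicePtr;
  // Otherwise-valid uses (for example, private-only uses) are skipped.
  if (v.otherwiseValidUse)
    return ImplicitClause::None;
  // Everything else gets one of the usual implicit data clauses.
  return ImplicitClause::OtherDataClause;
}
```

Note that in this model a value that is both device data and otherwise valid still maps to `DevicePtr`, which is the consequence of testing device data first.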

Updated OpenACCSupport::isValidValueUse to fall back to the new acc::isValidValueUse utility.

---

Patch is 36.32 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/176219.diff


16 Files Affected:

- (modified) flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h (+1) 
- (modified) flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.h (+5) 
- (modified) flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp (+14) 
- (modified) flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp (+129) 
- (modified) flang/test/Transforms/OpenACC/acc-implicit-data.fir (+38) 
- (modified) mlir/include/mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h (+3-9) 
- (modified) mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td (+4) 
- (modified) mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td (+24) 
- (modified) mlir/include/mlir/Dialect/OpenACC/OpenACCUtils.h (+17) 
- (modified) mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp (+1-1) 
- (modified) mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp (+12) 
- (modified) mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp (+17-6) 
- (modified) mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp (+56) 
- (modified) mlir/test/Dialect/OpenACC/acc-implicit-data.mlir (+37) 
- (modified) mlir/unittests/Dialect/OpenACC/CMakeLists.txt (+1) 
- (modified) mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp (+277-1) 


``````````diff
diff --git a/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h b/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h
index c6f52bbd0c64b..b0252fee7b5a6 100644
--- a/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h
+++ b/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h
@@ -66,6 +66,7 @@ struct GlobalVariableModel
           GlobalVariableModel, fir::GlobalOp> {
   bool isConstant(mlir::Operation *op) const;
   mlir::Region *getInitRegion(mlir::Operation *op) const;
+  bool isDeviceData(mlir::Operation *op) const;
 };
 
 template <typename Op>
diff --git a/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.h b/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.h
index 9db67afeda5e9..a659cb191f917 100644
--- a/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.h
+++ b/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.h
@@ -52,6 +52,9 @@ struct OpenACCPointerLikeModel
   bool genStore(mlir::Type pointer, mlir::OpBuilder &builder,
                 mlir::Location loc, mlir::Value valueToStore,
                 mlir::TypedValue<mlir::acc::PointerLikeType> destPtr) const;
+
+  bool isDeviceData(mlir::Type pointer, mlir::Value var) const;
+
 };
 
 template <typename T>
@@ -102,6 +105,8 @@ struct OpenACCMappableModel
                         mlir::ValueRange bounds,
                         mlir::acc::ReductionOperator op,
                         mlir::Attribute fastmathFlags) const;
+
+  bool isDeviceData(mlir::Type type, mlir::Value var) const;
 };
 
 } // namespace fir::acc
diff --git a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp
index dacafb1eeb4b2..0f906fc2e3fd7 100644
--- a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp
+++ b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp
@@ -12,6 +12,8 @@
 
 #include "flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h"
 
+#include "flang/Optimizer/Builder/CUFCommon.h"
+#include "flang/Optimizer/Dialect/CUF/Attributes/CUFAttr.h"
 #include "flang/Optimizer/Dialect/FIROps.h"
 #include "flang/Optimizer/HLFIR/HLFIROps.h"
 #include "flang/Optimizer/Support/InternalNames.h"
@@ -84,6 +86,18 @@ mlir::Region *GlobalVariableModel::getInitRegion(mlir::Operation *op) const {
   return globalOp.hasInitializationBody() ? &globalOp.getRegion() : nullptr;
 }
 
+bool GlobalVariableModel::isDeviceData(mlir::Operation *op) const {
+  if (auto dataAttr = cuf::getDataAttr(op)) {
+    auto attr = dataAttr.getValue();
+    return attr == cuf::DataAttribute::Device ||
+           attr == cuf::DataAttribute::Managed ||
+           attr == cuf::DataAttribute::Constant ||
+           attr == cuf::DataAttribute::Shared ||
+           attr == cuf::DataAttribute::Unified;
+  }
+  return false;
+}
+
 // Helper to recursively process address-of operations in derived type
 // descriptors and collect all needed fir.globals.
 static void processAddrOfOpInDerivedTypeDescriptor(
diff --git a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp
index 2997428b7b5d2..ea17ea81ce031 100644
--- a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp
+++ b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp
@@ -12,10 +12,12 @@
 
 #include "flang/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.h"
 #include "flang/Optimizer/Builder/BoxValue.h"
+#include "flang/Optimizer/Builder/CUFCommon.h"
 #include "flang/Optimizer/Builder/DirectivesCommon.h"
 #include "flang/Optimizer/Builder/FIRBuilder.h"
 #include "flang/Optimizer/Builder/HLFIRTools.h"
 #include "flang/Optimizer/Builder/IntrinsicCall.h"
+#include "flang/Optimizer/Dialect/CUF/Attributes/CUFAttr.h"
 #include "flang/Optimizer/Dialect/FIRCG/CGOps.h"
 #include "flang/Optimizer/Dialect/FIROps.h"
 #include "flang/Optimizer/Dialect/FIROpsSupport.h"
@@ -1487,4 +1489,131 @@ template bool OpenACCPointerLikeModel<fir::LLVMPointerType>::genStore(
     mlir::Value valueToStore,
     mlir::TypedValue<mlir::acc::PointerLikeType> destPtr) const;
 
+/// Helper function to check if a CUDA attribute represents device data.
+static bool isCUDADeviceAttribute(cuf::DataAttribute attr) {
+  return attr == cuf::DataAttribute::Device ||
+         attr == cuf::DataAttribute::Managed ||
+         attr == cuf::DataAttribute::Constant ||
+         attr == cuf::DataAttribute::Shared ||
+         attr == cuf::DataAttribute::Unified;
+}
+
+/// Helper function to check if an operation has CUDA device data attributes.
+static bool hasCUDADeviceDataAttr(mlir::Operation *op) {
+  if (!op)
+    return false;
+
+  // Check for CUF data attribute on the operation
+  if (auto dataAttr = cuf::getDataAttr(op)) {
+    if (isCUDADeviceAttribute(dataAttr.getValue()))
+      return true;
+  }
+
+  return false;
+}
+
+/// Check CUDA attributes on a function argument.
+static bool hasCUDADeviceAttrOnFuncArg(mlir::BlockArgument blockArg) {
+  auto *owner = blockArg.getOwner();
+  if (!owner)
+    return false;
+
+  auto *parentOp = owner->getParentOp();
+  if (!parentOp)
+    return false;
+
+  if (auto funcLike = mlir::dyn_cast<mlir::FunctionOpInterface>(parentOp)) {
+    unsigned argIndex = blockArg.getArgNumber();
+    if (argIndex < funcLike.getNumArguments()) {
+      if (auto attr = funcLike.getArgAttr(argIndex, cuf::getDataAttrName())) {
+        if (auto cudaAttr = mlir::dyn_cast<cuf::DataAttributeAttr>(attr))
+          return isCUDADeviceAttribute(cudaAttr.getValue());
+      }
+    }
+  }
+  return false;
+}
+
+/// Shared implementation for checking if a value represents device data.
+static bool isDeviceDataImpl(mlir::Value var) {
+  // Strip casts to find the underlying value.
+  mlir::Value currentVal = stripCasts(var, /*stripDeclare=*/false);
+
+  // Handle block arguments (function parameters)
+  if (auto blockArg = mlir::dyn_cast<mlir::BlockArgument>(currentVal))
+    return hasCUDADeviceAttrOnFuncArg(blockArg);
+
+  mlir::Operation *defOp = currentVal.getDefiningOp();
+  if (!defOp)
+    return false;
+
+  // Check for CUDA attributes on the defining operation.
+  if (hasCUDADeviceDataAttr(defOp))
+    return true;
+
+  // Handle operations that access a partial entity - check if the base entity
+  // is device data.
+  if (auto partialAccess =
+          mlir::dyn_cast<mlir::acc::PartialEntityAccessOpInterface>(defOp)) {
+    if (mlir::Value base = partialAccess.getBaseEntity())
+      return isDeviceDataImpl(base);
+  }
+
+  // Handle fir.rebox - if the underlying box is device data, so is the result.
+  if (auto rebox = mlir::dyn_cast<fir::ReboxOp>(defOp))
+    return isDeviceDataImpl(rebox.getBox());
+
+  // Handle fir.embox - check if the underlying memref is device data.
+  if (auto embox = mlir::dyn_cast<fir::EmboxOp>(defOp))
+    return isDeviceDataImpl(embox.getMemref());
+
+  // Handle address_of - check the referenced global.
+  if (auto addrOfIface =
+          mlir::dyn_cast<mlir::acc::AddressOfGlobalOpInterface>(defOp)) {
+    auto symbol = addrOfIface.getSymbol();
+    if (auto global = mlir::SymbolTable::lookupNearestSymbolFrom<
+            mlir::acc::GlobalVariableOpInterface>(defOp, symbol))
+      return global.isDeviceData();
+    return false;
+  }
+
+  return false;
+}
+
+template <typename Ty>
+bool OpenACCPointerLikeModel<Ty>::isDeviceData(mlir::Type pointer,
+                                               mlir::Value var) const {
+  return isDeviceDataImpl(var);
+}
+
+template bool OpenACCPointerLikeModel<fir::ReferenceType>::isDeviceData(
+    mlir::Type, mlir::Value) const;
+template bool
+    OpenACCPointerLikeModel<fir::PointerType>::isDeviceData(mlir::Type,
+                                                            mlir::Value) const;
+template bool
+    OpenACCPointerLikeModel<fir::HeapType>::isDeviceData(mlir::Type,
+                                                         mlir::Value) const;
+template bool OpenACCPointerLikeModel<fir::LLVMPointerType>::isDeviceData(
+    mlir::Type, mlir::Value) const;
+
+template <typename Ty>
+bool OpenACCMappableModel<Ty>::isDeviceData(mlir::Type type,
+                                            mlir::Value var) const {
+  return isDeviceDataImpl(var);
+}
+
+template bool
+    OpenACCMappableModel<fir::BaseBoxType>::isDeviceData(mlir::Type,
+                                                         mlir::Value) const;
+template bool
+    OpenACCMappableModel<fir::ReferenceType>::isDeviceData(mlir::Type,
+                                                           mlir::Value) const;
+template bool
+    OpenACCMappableModel<fir::HeapType>::isDeviceData(mlir::Type,
+                                                      mlir::Value) const;
+template bool
+    OpenACCMappableModel<fir::PointerType>::isDeviceData(mlir::Type,
+                                                         mlir::Value) const;
+
 } // namespace fir::acc
diff --git a/flang/test/Transforms/OpenACC/acc-implicit-data.fir b/flang/test/Transforms/OpenACC/acc-implicit-data.fir
index 058390ab05669..050fe55747d23 100644
--- a/flang/test/Transforms/OpenACC/acc-implicit-data.fir
+++ b/flang/test/Transforms/OpenACC/acc-implicit-data.fir
@@ -356,3 +356,41 @@ func.func @test_acc_declare_deviceptr() {
 // CHECK-NOT: acc.copyin
 // CHECK: acc.deviceptr
 
+
+// -----
+
+// Test that implicit deviceptr is generated for a symbol with CUDA device attribute
+func.func @test_cuda_device_implicit_deviceptr() {
+  %0 = fir.dummy_scope : !fir.dscope
+  %1 = cuf.alloc !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", data_attr = #cuf.cuda<device>, uniq_name = "_QFEa"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
+  %2 = fir.zero_bits !fir.heap<!fir.array<?xf32>>
+  %c0 = arith.constant 0 : index
+  %3 = fir.shape %c0 : (index) -> !fir.shape<1>
+  %4 = fir.embox %2(%3) {allocator_idx = 2 : i32} : (!fir.heap<!fir.array<?xf32>>, !fir.shape<1>) -> !fir.box<!fir.heap<!fir.array<?xf32>>>
+  fir.store %4 to %1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
+  %5:2 = hlfir.declare %1 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFEa"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
+  %c1 = arith.constant 1 : index
+  %c16_i32 = arith.constant 16 : i32
+  %c0_i32 = arith.constant 0 : i32
+  %6 = fir.convert %5#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
+  %7 = fir.convert %c1 : (index) -> i64
+  %8 = fir.convert %c16_i32 : (i32) -> i64
+  fir.call @_FortranAAllocatableSetBounds(%6, %c0_i32, %7, %8) fastmath<contract> : (!fir.ref<!fir.box<none>>, i32, i64, i64) -> ()
+  %9 = cuf.allocate %5#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>> {data_attr = #cuf.cuda<device>} -> i32
+  acc.serial {
+    %cst = arith.constant 1.000000e+02 : f32
+    %10 = fir.load %5#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
+    %c5 = arith.constant 5 : index
+    %11 = hlfir.designate %10 (%c5)  : (!fir.box<!fir.heap<!fir.array<?xf32>>>, index) -> !fir.ref<f32>
+    hlfir.assign %cst to %11 : f32, !fir.ref<f32>
+    acc.yield
+  }
+  return
+}
+
+func.func private @_FortranAAllocatableSetBounds(!fir.ref<!fir.box<none>>, i32, i64, i64) -> ()
+
+// CHECK-LABEL: func.func @test_cuda_device_implicit_deviceptr
+// CHECK-NOT: acc.copyin
+// CHECK: acc.deviceptr
+// CHECK-NOT: acc.copyout
diff --git a/mlir/include/mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h b/mlir/include/mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h
index 984eaa8b8d78b..c49452556f0c4 100644
--- a/mlir/include/mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h
+++ b/mlir/include/mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h
@@ -50,6 +50,7 @@
 #ifndef MLIR_DIALECT_OPENACC_ANALYSIS_OPENACCSUPPORT_H
 #define MLIR_DIALECT_OPENACC_ANALYSIS_OPENACCSUPPORT_H
 
+#include "mlir/Dialect/OpenACC/OpenACCUtils.h"
 #include "mlir/IR/Remarks.h"
 #include "mlir/IR/Value.h"
 #include "mlir/Pass/AnalysisManager.h"
@@ -60,13 +61,6 @@
 namespace mlir {
 namespace acc {
 
-// Forward declarations
-enum class RecipeKind : uint32_t;
-bool isValidSymbolUse(Operation *user, SymbolRefAttr symbol,
-                      Operation **definingOpPtr);
-remark::detail::InFlightRemark emitRemark(Operation *op, const Twine &message,
-                                          llvm::StringRef category);
-
 namespace detail {
 /// This class contains internal trait classes used by OpenACCSupport.
 /// It follows the Concept-Model pattern used throughout MLIR (e.g., in
@@ -170,10 +164,10 @@ struct OpenACCSupportTraits {
     }
 
     bool isValidValueUse(Value v, Region &region) final {
-      if constexpr (has_isValidSymbolUse<ImplT>::value)
+      if constexpr (has_isValidValueUse<ImplT>::value)
         return impl.isValidValueUse(v, region);
       else
-        return false;
+        return acc::isValidValueUse(v, region);
     }
 
   private:
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td
index 44632eb4cdac4..95a8f22a3ddfa 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td
@@ -75,6 +75,10 @@ def GlobalVariableOpInterface : OpInterface<"GlobalVariableOpInterface"> {
       }]>,
     InterfaceMethod<"Get the initialization region (returns nullptr if none)",
       "::mlir::Region*", "getInitRegion", (ins)>,
+    InterfaceMethod<"Check if the global variable is device data",
+      "bool", "isDeviceData", (ins), [{
+        return false;
+      }]>,
   ];
 }
 
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td
index 1d55762e2f7d9..ca60f9bca5cf1 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td
@@ -220,6 +220,18 @@ def OpenACC_PointerLikeTypeInterface : TypeInterface<"PointerLikeType"> {
         return false;
       }]
     >,
+    InterfaceMethod<
+      /*description=*/[{
+        Returns true if the pointer points to device data.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"isDeviceData",
+      /*args=*/(ins "::mlir::Value":$var),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return false;
+      }]
+    >,
   ];
 }
 
@@ -442,6 +454,18 @@ def OpenACC_MappableTypeInterface : TypeInterface<"MappableType"> {
         return true;
       }]
     >,
+    InterfaceMethod<
+      /*description=*/[{
+        Returns true if the variable represents device data.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"isDeviceData",
+      /*args=*/(ins "::mlir::Value":$var),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return false;
+      }]
+    >,
   ];
 }
 
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCUtils.h b/mlir/include/mlir/Dialect/OpenACC/OpenACCUtils.h
index e3f4e6889ffe8..a436050babce2 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCUtils.h
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCUtils.h
@@ -67,6 +67,23 @@ mlir::Value getBaseEntity(mlir::Value val);
 bool isValidSymbolUse(mlir::Operation *user, mlir::SymbolRefAttr symbol,
                       mlir::Operation **definingOpPtr = nullptr);
 
+/// Check if a value represents device data.
+/// This checks if the value represents device data via the
+/// MappableType, PointerLikeType, and GlobalVariableOpInterface interfaces.
+/// \param val The value to check
+/// \return true if the value is device data, false otherwise
+bool isDeviceValue(mlir::Value val);
+
+/// Check if a value use is valid in an OpenACC region.
+/// This is true if:
+/// - The value is produced by an ACC data entry operation
+/// - The value is device data
+/// - The value is only used by private clauses in the region
+/// \param val The value to check
+/// \param region The OpenACC region
+/// \return true if the value use is valid, false otherwise
+bool isValidValueUse(mlir::Value val, mlir::Region &region);
+
 /// Collects all data clauses that dominate the compute construct.
 /// This includes data clauses from:
 /// - The compute construct itself
diff --git a/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp b/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp
index c487c43e8369c..4fbe3e06c2532 100644
--- a/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp
+++ b/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp
@@ -59,7 +59,7 @@ bool OpenACCSupport::isValidSymbolUse(Operation *user, SymbolRefAttr symbol,
 bool OpenACCSupport::isValidValueUse(Value v, Region &region) {
   if (impl)
     return impl->isValidValueUse(v, region);
-  return false;
+  return acc::isValidValueUse(v, region);
 }
 
 } // namespace acc
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 6872aa9d26e29..55a438eb3906b 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -245,6 +245,12 @@ struct MemRefPointerLikeModel
     memref::StoreOp::create(builder, loc, valueToStore, memrefValue);
     return true;
   }
+
+  bool isDeviceData(Type pointer, Value var) const {
+    auto memrefTy = cast<T>(pointer);
+    Attribute memSpace = memrefTy.getMemorySpace();
+    return isa_and_nonnull<gpu::AddressSpaceAttr>(memSpace);
+  }
 };
 
 struct LLVMPointerPointerLikeModel
@@ -290,6 +296,12 @@ struct MemrefGlobalVariableModel
     // GlobalOp uses attributes for initialization, not regions
     return nullptr;
   }
+
+  bool isDeviceData(Operation *op) const {
+    auto globalOp = cast<memref::GlobalOp>(op);
+    Attribute memSpace = globalOp.getType().getMemorySpace();
+    return isa_and_nonnull<gpu::AddressSpaceAttr>(memSpace);
+  }
 };
 
 struct GPULaunchOffloadRegionModel
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp
index c4b43c5cb1b27..a5b5753cd414e 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp
@@ -285,16 +285,17 @@ static bool isCandidateForImplicitData(Value val, Region &accRegion,
       !acc::isMappableType(val.getType()))
     return false;
 
-  if (accSupport.isValidValueUse(val, accRegion))
-    return false;
-
   // If this is already coming from a data clause, we do not need to generate
   // another.
   if (isa_and_nonnull<ACC_DATA_ENTRY_OPS>(val.getDefiningOp()))
     return false;
 
-  // If this is only used by private clauses, it is not a real live-in.
-  if (acc::isOnlyUsedByPrivateClauses(val, accRegion))
+  // Device data is a candidate - it will get a deviceptr clause.
+  if (acc::isDeviceValue(val))
+    return true;
+
+  // If it is otherwise valid, skip it.
+  if (accSupport.isValidValueUse(val, accRegion))
     return false;
 
   return true;
@@ -472,7 +473,17 @@ Operation *ACCImplicitData::generateDataClauseOpForCandidate(
                                   /*structured=*/true, /*implicit=*/true,
                                   accSupport.getVariableName(var),
                                   acc::getBounds(op));
-  } else if (isScalar) {
+  }
+
+
+  if (acc::isDeviceValue(var)) {
+    // If the variable is device data, use deviceptr clause.
+    return acc::DevicePtrOp::create(builder, loc, var,
+                                    /*structured=*/true, /*implicit=*/true,
+                                    accSupport.getVariableName(var));
+  }
+
+  if (isScalar) {
     if (enableImplicitReductionCopy &&
         acc::isOnlyUsedByReductionClauses(var,
                                           computeConstructOp->getRegion(0))) {
diff --git a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
index e97fd1471e6e0..bd3dda48d44b4 100644
--- a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
+++ b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
@@ -182,6 +182,13 @@ bool mlir::acc::isValidSymbolUse(mlir::Operation *user,
                 mlir::acc::FirstprivateRecipeOp>(definingOp))
     return true;
 
+  // Check if the defining op is a global variable that is device data.
+  // Device data i...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/176219

