[Mlir-commits] [flang] [mlir] [acc][flang] Add isDeviceData APIs for device data detection (PR #176219)
Razvan Lupusoru
llvmlistbot at llvm.org
Thu Jan 15 10:42:17 PST 2026
https://github.com/razvanlupusoru created https://github.com/llvm/llvm-project/pull/176219
Add comprehensive APIs to detect device-resident data across OpenACC type and operation interfaces. This enables passes to identify data that is already on the device (e.g., CUF device/managed/constant memory, GPU address spaces) and handle it appropriately, for example by mapping it with acc.deviceptr instead of copyin/copyout.
New interface methods:
- PointerLikeType::isDeviceData(Value): Returns true if the pointer points to device data.
- MappableType::isDeviceData(Value): Returns true if the variable represents device data.
- GlobalVariableOpInterface::isDeviceData(): Returns true if the global variable is device data.
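A standalone sketch of the contract these hooks follow, assuming plain C++ virtual classes in place of the TableGen-generated interfaces: each hook has a default implementation returning false, so a type or global that does not implement it is never treated as device data. FakeValue and the *Model classes are hypothetical names, not MLIR APIs.

```cpp
#include <cassert>

// Hypothetical stand-in for an SSA value: it only records whether some
// dialect-specific marker (e.g. a CUF data_attr) is present.
struct FakeValue {
  bool hasDeviceMarker = false;
};

// Models a type interface method declared with a default implementation,
// as the new PointerLikeType/MappableType isDeviceData entries are.
struct PointerLikeTypeModel {
  virtual ~PointerLikeTypeModel() = default;
  virtual bool isDeviceData(const FakeValue &) const { return false; }
};

// A type that does not override the hook keeps the conservative answer.
struct OpaquePointerModel : PointerLikeTypeModel {};

// A type that can recognize device-resident data overrides the hook.
struct CudaAwarePointerModel : PointerLikeTypeModel {
  bool isDeviceData(const FakeValue &v) const override {
    return v.hasDeviceMarker;
  }
};

// GlobalVariableOpInterface::isDeviceData() takes no value argument:
// the answer depends only on the global op itself.
struct GlobalVariableModel {
  virtual ~GlobalVariableModel() = default;
  virtual bool isDeviceData() const { return false; }
};
```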
New utilities in OpenACCUtils:
- acc::isDeviceValue(Value): Checks if a value represents device data by querying type interfaces, PartialEntityAccessOpInterface for base entities, and AddressOfGlobalOpInterface for global symbols.
- acc::isValidValueUse(Value, Region): Checks if a value is legal in an OpenACC region by verifying it comes from a data operation, is only used by private clauses, or is device data.
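The decision order of the two utilities can be modeled without MLIR as below. This is a simplified sketch, not the patch's code: Value, Global, and SymbolTable are hypothetical stand-ins, and a single typeSaysDevice flag collapses the MappableType and PointerLikeType queries. A partial access (for example an array element) inherits the answer from its base entity, an address-of result inherits it from the referenced global, and a use is legal as soon as any one condition holds.

```cpp
#include <cassert>
#include <map>
#include <string>

// Hypothetical stand-in for a global: caches the answer of
// GlobalVariableOpInterface::isDeviceData().
struct Global {
  bool isDeviceData = false;
};

// Hypothetical, heavily simplified stand-in for an SSA value.
struct Value {
  bool typeSaysDevice = false;        // MappableType/PointerLikeType answer
  const Value *partialBase = nullptr; // set if defined by a partial access
  std::string addressOfSymbol;        // non-empty if defined by address-of
  bool fromDataEntryOp = false;       // produced by acc.copyin, acc.present...
  bool onlyPrivateUses = false;       // only used by private clauses
};

using SymbolTable = std::map<std::string, Global>;

// Mirrors the query order described for acc::isDeviceValue.
bool isDeviceValue(const Value &v, const SymbolTable &symbols) {
  if (v.typeSaysDevice)
    return true;
  if (v.partialBase)
    return isDeviceValue(*v.partialBase, symbols);
  if (!v.addressOfSymbol.empty()) {
    auto it = symbols.find(v.addressOfSymbol);
    return it != symbols.end() && it->second.isDeviceData;
  }
  return false;
}

// Mirrors acc::isValidValueUse: data entry op, private-only, or device data.
bool isValidValueUse(const Value &v, const SymbolTable &symbols) {
  return v.fromDataEntryOp || v.onlyPrivateUses || isDeviceValue(v, symbols);
}
```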
Updated isValidSymbolUse to check GlobalVariableOpInterface::isDeviceData() for symbols referencing device-resident globals.
FIR implementations check for CUF data attributes (device, managed, constant, shared, unified) on operations, block arguments, and globals. The implementation traces through fir.rebox, fir.embox, fir.declare, hlfir.declare, and fir.address_of to find the underlying data source.
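A sketch of the FIR-side logic, assuming a hypothetical DataAttr enum that mirrors cuf::DataAttribute (the real enum may have other members) and a toy Op record instead of mlir::Operation. It covers the attribute classification and the rebox/embox forwarding only; block arguments, cast stripping, and the address-of path are left out.

```cpp
#include <cassert>
#include <optional>

// Hypothetical mirror of CUF data attributes (not the real cuf enum).
enum class DataAttr { Device, Managed, Constant, Shared, Unified, Pinned };

// Same five-way classification as isCUDADeviceAttribute in the patch.
// Pinned (page-locked host memory) is not in the list, so it is not device.
bool isCUDADeviceAttribute(DataAttr a) {
  switch (a) {
  case DataAttr::Device:
  case DataAttr::Managed:
  case DataAttr::Constant:
  case DataAttr::Shared:
  case DataAttr::Unified:
    return true;
  case DataAttr::Pinned:
    return false;
  }
  return false;
}

// Hypothetical defining-op record: an optional data_attr plus, for the
// box-forwarding ops, the op that defines their source operand.
enum class OpKind { Alloc, Declare, Embox, Rebox, Other };

struct Op {
  OpKind kind = OpKind::Other;
  std::optional<DataAttr> dataAttr;
  const Op *source = nullptr; // operand definition for Embox/Rebox
};

// Walks like isDeviceDataImpl: a device attribute on the defining op wins;
// otherwise rebox/embox forward the question to their source.
bool isDeviceData(const Op *op) {
  if (!op)
    return false;
  if (op->dataAttr && isCUDADeviceAttribute(*op->dataAttr))
    return true;
  if ((op->kind == OpKind::Embox || op->kind == OpKind::Rebox) && op->source)
    return isDeviceData(op->source);
  return false;
}
```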
Memref implementations check for gpu::AddressSpaceAttr on the memref type.
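A model of the memref check, assuming a std::variant in place of the memref memory-space Attribute; GpuAddressSpace, IntegerMemorySpace, and MemRefTypeModel are hypothetical. Any GPU address space (global, workgroup, private) counts as device data, while a missing or non-GPU memory space does not, which is the behavior isa_and_nonnull provides in the real implementation.

```cpp
#include <cassert>
#include <optional>
#include <variant>

// Hypothetical stand-ins for the memory-space attribute a memref may carry.
enum class GpuAddressSpace { Global, Workgroup, Private };
struct IntegerMemorySpace {
  int value;
};

using MemorySpaceAttr = std::variant<GpuAddressSpace, IntegerMemorySpace>;

struct MemRefTypeModel {
  std::optional<MemorySpaceAttr> memorySpace; // empty == default memory space
};

// Mirrors isa_and_nonnull<gpu::AddressSpaceAttr>(memSpace): a null attribute
// or a non-GPU memory space means "not device data".
bool isDeviceData(const MemRefTypeModel &ty) {
  return ty.memorySpace &&
         std::holds_alternative<GpuAddressSpace>(*ty.memorySpace);
}
```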
Updated ACCImplicitData to use acc::isDeviceValue for generating acc.deviceptr clauses for device-resident data instead of copyin/copyout.
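The order of checks matters here: the default acc::isValidValueUse now returns true for device data, so testing isValidValueUse first would skip device values without emitting any clause. A minimal sketch of the resulting selection, with hypothetical LiveIn and ImplicitClause types; Default stands for the pass's pre-existing clause choice, which is not modeled.

```cpp
#include <cassert>

// Hypothetical summary of what the pass knows about a live-in value.
struct LiveIn {
  bool fromDataEntryOp = false; // already produced by a data entry op
  bool isDevice = false;        // result of acc::isDeviceValue(val)
  bool otherwiseValid = false;  // result of accSupport.isValidValueUse(...)
};

enum class ImplicitClause { None, DevicePtr, Default };

// Mirrors the reordered checks in isCandidateForImplicitData combined with
// the new deviceptr branch in generateDataClauseOpForCandidate.
ImplicitClause chooseImplicitClause(const LiveIn &v) {
  if (v.fromDataEntryOp)
    return ImplicitClause::None; // already mapped by an explicit clause
  if (v.isDevice)
    return ImplicitClause::DevicePtr; // checked before the validity query
  if (v.otherwiseValid)
    return ImplicitClause::None; // e.g. only used by private clauses
  return ImplicitClause::Default;
}
```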
Updated OpenACCSupport::isValidValueUse to fall back to the new acc::isValidValueUse utility.
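A sketch of the fallback pattern, assuming a std::function slot in place of the OpenACCSupport concept/model machinery; SupportModel, Query, and defaultIsValidValueUse are hypothetical. Without a client implementation, the query defers to the shared utility instead of returning false unconditionally.

```cpp
#include <cassert>
#include <functional>
#include <utility>

// Hypothetical stand-in for the information a validity query needs.
struct Query {
  bool deviceData = false;
};

// Stand-in for the generic acc::isValidValueUse utility.
bool defaultIsValidValueUse(const Query &q) { return q.deviceData; }

// Models OpenACCSupport: an optional client implementation, with the shared
// utility used when none is installed.
class SupportModel {
public:
  explicit SupportModel(std::function<bool(const Query &)> customImpl = nullptr)
      : impl(std::move(customImpl)) {}

  bool isValidValueUse(const Query &q) const {
    if (impl)
      return impl(q); // client-specific policy wins
    return defaultIsValidValueUse(q); // previously: return false
  }

private:
  std::function<bool(const Query &)> impl;
};
```

The OpenACCSupportTraits hunk in the diff applies the same fallback on the template side, and also corrects the trait test, which previously checked has_isValidSymbolUse before calling impl.isValidValueUse.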
From 613ad1dd7e921d714241493af0cc229744878b47 Mon Sep 17 00:00:00 2001
From: Razvan Lupusoru <rlupusoru at nvidia.com>
Date: Thu, 15 Jan 2026 10:40:55 -0800
Subject: [PATCH] [acc][flang] Add isDeviceData APIs for device data detection
Add comprehensive APIs to detect device-resident data across OpenACC
type and operation interfaces. This enables passes to identify data that
is already on the device (e.g., CUF device/managed/constant memory,
GPU address spaces) and handle it appropriately.
New interface methods:
- PointerLikeType::isDeviceData(Value): Returns true if the pointer
points to device data.
- MappableType::isDeviceData(Value): Returns true if the variable
represents device data.
- GlobalVariableOpInterface::isDeviceData(): Returns true if the global
variable is device data.
New utilities in OpenACCUtils:
- acc::isDeviceValue(Value): Checks if a value represents device data
by querying type interfaces, PartialEntityAccessOpInterface for base
entities, and AddressOfGlobalOpInterface for global symbols.
- acc::isValidValueUse(Value, Region): Checks if a value is legal in an
OpenACC region by verifying it comes from a data operation, is only
used by private clauses, or is device data.
Updated isValidSymbolUse to check
GlobalVariableOpInterface::isDeviceData()
for symbols referencing device-resident globals.
FIR implementations check for CUF data attributes (device, managed,
constant, shared, unified) on operations, block arguments, and globals.
The implementation traces through fir.rebox, fir.embox, fir.declare,
hlfir.declare, and fir.address_of to find the underlying data source.
Memref implementations check for gpu::AddressSpaceAttr on the memref
type.
Updated ACCImplicitData to use acc::isDeviceValue for generating
acc.deviceptr clauses for device-resident data instead of
copyin/copyout.
Updated OpenACCSupport::isValidValueUse to fall back to the new
acc::isValidValueUse utility.
---
.../OpenACC/Support/FIROpenACCOpsInterfaces.h | 1 +
.../Support/FIROpenACCTypeInterfaces.h | 5 +
.../Support/FIROpenACCOpsInterfaces.cpp | 14 +
.../Support/FIROpenACCTypeInterfaces.cpp | 129 ++++++++
.../Transforms/OpenACC/acc-implicit-data.fir | 38 +++
.../Dialect/OpenACC/Analysis/OpenACCSupport.h | 12 +-
.../Dialect/OpenACC/OpenACCOpsInterfaces.td | 4 +
.../Dialect/OpenACC/OpenACCTypeInterfaces.td | 24 ++
.../mlir/Dialect/OpenACC/OpenACCUtils.h | 17 ++
.../OpenACC/Analysis/OpenACCSupport.cpp | 2 +-
mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp | 12 +
.../OpenACC/Transforms/ACCImplicitData.cpp | 23 +-
.../Dialect/OpenACC/Utils/OpenACCUtils.cpp | 56 ++++
.../Dialect/OpenACC/acc-implicit-data.mlir | 37 +++
mlir/unittests/Dialect/OpenACC/CMakeLists.txt | 1 +
.../Dialect/OpenACC/OpenACCUtilsTest.cpp | 278 +++++++++++++++++-
16 files changed, 636 insertions(+), 17 deletions(-)
diff --git a/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h b/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h
index c6f52bbd0c64b..b0252fee7b5a6 100644
--- a/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h
+++ b/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h
@@ -66,6 +66,7 @@ struct GlobalVariableModel
GlobalVariableModel, fir::GlobalOp> {
bool isConstant(mlir::Operation *op) const;
mlir::Region *getInitRegion(mlir::Operation *op) const;
+ bool isDeviceData(mlir::Operation *op) const;
};
template <typename Op>
diff --git a/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.h b/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.h
index 9db67afeda5e9..a659cb191f917 100644
--- a/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.h
+++ b/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.h
@@ -52,6 +52,9 @@ struct OpenACCPointerLikeModel
bool genStore(mlir::Type pointer, mlir::OpBuilder &builder,
mlir::Location loc, mlir::Value valueToStore,
mlir::TypedValue<mlir::acc::PointerLikeType> destPtr) const;
+
+ bool isDeviceData(mlir::Type pointer, mlir::Value var) const;
+
};
template <typename T>
@@ -102,6 +105,8 @@ struct OpenACCMappableModel
mlir::ValueRange bounds,
mlir::acc::ReductionOperator op,
mlir::Attribute fastmathFlags) const;
+
+ bool isDeviceData(mlir::Type type, mlir::Value var) const;
};
} // namespace fir::acc
diff --git a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp
index dacafb1eeb4b2..0f906fc2e3fd7 100644
--- a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp
+++ b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp
@@ -12,6 +12,8 @@
#include "flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h"
+#include "flang/Optimizer/Builder/CUFCommon.h"
+#include "flang/Optimizer/Dialect/CUF/Attributes/CUFAttr.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/Support/InternalNames.h"
@@ -84,6 +86,18 @@ mlir::Region *GlobalVariableModel::getInitRegion(mlir::Operation *op) const {
return globalOp.hasInitializationBody() ? &globalOp.getRegion() : nullptr;
}
+bool GlobalVariableModel::isDeviceData(mlir::Operation *op) const {
+ if (auto dataAttr = cuf::getDataAttr(op)) {
+ auto attr = dataAttr.getValue();
+ return attr == cuf::DataAttribute::Device ||
+ attr == cuf::DataAttribute::Managed ||
+ attr == cuf::DataAttribute::Constant ||
+ attr == cuf::DataAttribute::Shared ||
+ attr == cuf::DataAttribute::Unified;
+ }
+ return false;
+}
+
// Helper to recursively process address-of operations in derived type
// descriptors and collect all needed fir.globals.
static void processAddrOfOpInDerivedTypeDescriptor(
diff --git a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp
index 2997428b7b5d2..ea17ea81ce031 100644
--- a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp
+++ b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.cpp
@@ -12,10 +12,12 @@
#include "flang/Optimizer/OpenACC/Support/FIROpenACCTypeInterfaces.h"
#include "flang/Optimizer/Builder/BoxValue.h"
+#include "flang/Optimizer/Builder/CUFCommon.h"
#include "flang/Optimizer/Builder/DirectivesCommon.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Builder/IntrinsicCall.h"
+#include "flang/Optimizer/Dialect/CUF/Attributes/CUFAttr.h"
#include "flang/Optimizer/Dialect/FIRCG/CGOps.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
@@ -1487,4 +1489,131 @@ template bool OpenACCPointerLikeModel<fir::LLVMPointerType>::genStore(
mlir::Value valueToStore,
mlir::TypedValue<mlir::acc::PointerLikeType> destPtr) const;
+/// Helper function to check if a CUDA attribute represents device data.
+static bool isCUDADeviceAttribute(cuf::DataAttribute attr) {
+ return attr == cuf::DataAttribute::Device ||
+ attr == cuf::DataAttribute::Managed ||
+ attr == cuf::DataAttribute::Constant ||
+ attr == cuf::DataAttribute::Shared ||
+ attr == cuf::DataAttribute::Unified;
+}
+
+/// Helper function to check if an operation has CUDA device data attributes.
+static bool hasCUDADeviceDataAttr(mlir::Operation *op) {
+ if (!op)
+ return false;
+
+ // Check for CUF data attribute on the operation
+ if (auto dataAttr = cuf::getDataAttr(op)) {
+ if (isCUDADeviceAttribute(dataAttr.getValue()))
+ return true;
+ }
+
+ return false;
+}
+
+/// Check CUDA attributes on a function argument.
+static bool hasCUDADeviceAttrOnFuncArg(mlir::BlockArgument blockArg) {
+ auto *owner = blockArg.getOwner();
+ if (!owner)
+ return false;
+
+ auto *parentOp = owner->getParentOp();
+ if (!parentOp)
+ return false;
+
+ if (auto funcLike = mlir::dyn_cast<mlir::FunctionOpInterface>(parentOp)) {
+ unsigned argIndex = blockArg.getArgNumber();
+ if (argIndex < funcLike.getNumArguments()) {
+ if (auto attr = funcLike.getArgAttr(argIndex, cuf::getDataAttrName())) {
+ if (auto cudaAttr = mlir::dyn_cast<cuf::DataAttributeAttr>(attr))
+ return isCUDADeviceAttribute(cudaAttr.getValue());
+ }
+ }
+ }
+ return false;
+}
+
+/// Shared implementation for checking if a value represents device data.
+static bool isDeviceDataImpl(mlir::Value var) {
+ // Strip casts to find the underlying value.
+ mlir::Value currentVal = stripCasts(var, /*stripDeclare=*/false);
+
+ // Handle block arguments (function parameters)
+ if (auto blockArg = mlir::dyn_cast<mlir::BlockArgument>(currentVal))
+ return hasCUDADeviceAttrOnFuncArg(blockArg);
+
+ mlir::Operation *defOp = currentVal.getDefiningOp();
+ if (!defOp)
+ return false;
+
+ // Check for CUDA attributes on the defining operation.
+ if (hasCUDADeviceDataAttr(defOp))
+ return true;
+
+ // Handle operations that access a partial entity - check if the base entity
+ // is device data.
+ if (auto partialAccess =
+ mlir::dyn_cast<mlir::acc::PartialEntityAccessOpInterface>(defOp)) {
+ if (mlir::Value base = partialAccess.getBaseEntity())
+ return isDeviceDataImpl(base);
+ }
+
+ // Handle fir.rebox - if the underlying box is device data, so is the result.
+ if (auto rebox = mlir::dyn_cast<fir::ReboxOp>(defOp))
+ return isDeviceDataImpl(rebox.getBox());
+
+ // Handle fir.embox - check if the underlying memref is device data.
+ if (auto embox = mlir::dyn_cast<fir::EmboxOp>(defOp))
+ return isDeviceDataImpl(embox.getMemref());
+
+ // Handle address_of - check the referenced global.
+ if (auto addrOfIface =
+ mlir::dyn_cast<mlir::acc::AddressOfGlobalOpInterface>(defOp)) {
+ auto symbol = addrOfIface.getSymbol();
+ if (auto global = mlir::SymbolTable::lookupNearestSymbolFrom<
+ mlir::acc::GlobalVariableOpInterface>(defOp, symbol))
+ return global.isDeviceData();
+ return false;
+ }
+
+ return false;
+}
+
+template <typename Ty>
+bool OpenACCPointerLikeModel<Ty>::isDeviceData(mlir::Type pointer,
+ mlir::Value var) const {
+ return isDeviceDataImpl(var);
+}
+
+template bool OpenACCPointerLikeModel<fir::ReferenceType>::isDeviceData(
+ mlir::Type, mlir::Value) const;
+template bool
+ OpenACCPointerLikeModel<fir::PointerType>::isDeviceData(mlir::Type,
+ mlir::Value) const;
+template bool
+ OpenACCPointerLikeModel<fir::HeapType>::isDeviceData(mlir::Type,
+ mlir::Value) const;
+template bool OpenACCPointerLikeModel<fir::LLVMPointerType>::isDeviceData(
+ mlir::Type, mlir::Value) const;
+
+template <typename Ty>
+bool OpenACCMappableModel<Ty>::isDeviceData(mlir::Type type,
+ mlir::Value var) const {
+ return isDeviceDataImpl(var);
+}
+
+template bool
+ OpenACCMappableModel<fir::BaseBoxType>::isDeviceData(mlir::Type,
+ mlir::Value) const;
+template bool
+ OpenACCMappableModel<fir::ReferenceType>::isDeviceData(mlir::Type,
+ mlir::Value) const;
+template bool
+ OpenACCMappableModel<fir::HeapType>::isDeviceData(mlir::Type,
+ mlir::Value) const;
+template bool
+ OpenACCMappableModel<fir::PointerType>::isDeviceData(mlir::Type,
+ mlir::Value) const;
+
} // namespace fir::acc
diff --git a/flang/test/Transforms/OpenACC/acc-implicit-data.fir b/flang/test/Transforms/OpenACC/acc-implicit-data.fir
index 058390ab05669..050fe55747d23 100644
--- a/flang/test/Transforms/OpenACC/acc-implicit-data.fir
+++ b/flang/test/Transforms/OpenACC/acc-implicit-data.fir
@@ -356,3 +356,41 @@ func.func @test_acc_declare_deviceptr() {
// CHECK-NOT: acc.copyin
// CHECK: acc.deviceptr
+
+// -----
+
+// Test that implicit deviceptr is generated for a symbol with CUDA device attribute
+func.func @test_cuda_device_implicit_deviceptr() {
+ %0 = fir.dummy_scope : !fir.dscope
+ %1 = cuf.alloc !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", data_attr = #cuf.cuda<device>, uniq_name = "_QFEa"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
+ %2 = fir.zero_bits !fir.heap<!fir.array<?xf32>>
+ %c0 = arith.constant 0 : index
+ %3 = fir.shape %c0 : (index) -> !fir.shape<1>
+ %4 = fir.embox %2(%3) {allocator_idx = 2 : i32} : (!fir.heap<!fir.array<?xf32>>, !fir.shape<1>) -> !fir.box<!fir.heap<!fir.array<?xf32>>>
+ fir.store %4 to %1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
+ %5:2 = hlfir.declare %1 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFEa"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
+ %c1 = arith.constant 1 : index
+ %c16_i32 = arith.constant 16 : i32
+ %c0_i32 = arith.constant 0 : i32
+ %6 = fir.convert %5#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
+ %7 = fir.convert %c1 : (index) -> i64
+ %8 = fir.convert %c16_i32 : (i32) -> i64
+ fir.call @_FortranAAllocatableSetBounds(%6, %c0_i32, %7, %8) fastmath<contract> : (!fir.ref<!fir.box<none>>, i32, i64, i64) -> ()
+ %9 = cuf.allocate %5#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>> {data_attr = #cuf.cuda<device>} -> i32
+ acc.serial {
+ %cst = arith.constant 1.000000e+02 : f32
+ %10 = fir.load %5#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
+ %c5 = arith.constant 5 : index
+ %11 = hlfir.designate %10 (%c5) : (!fir.box<!fir.heap<!fir.array<?xf32>>>, index) -> !fir.ref<f32>
+ hlfir.assign %cst to %11 : f32, !fir.ref<f32>
+ acc.yield
+ }
+ return
+}
+
+func.func private @_FortranAAllocatableSetBounds(!fir.ref<!fir.box<none>>, i32, i64, i64) -> ()
+
+// CHECK-LABEL: func.func @test_cuda_device_implicit_deviceptr
+// CHECK-NOT: acc.copyin
+// CHECK: acc.deviceptr
+// CHECK-NOT: acc.copyout
diff --git a/mlir/include/mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h b/mlir/include/mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h
index 984eaa8b8d78b..c49452556f0c4 100644
--- a/mlir/include/mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h
+++ b/mlir/include/mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h
@@ -50,6 +50,7 @@
#ifndef MLIR_DIALECT_OPENACC_ANALYSIS_OPENACCSUPPORT_H
#define MLIR_DIALECT_OPENACC_ANALYSIS_OPENACCSUPPORT_H
+#include "mlir/Dialect/OpenACC/OpenACCUtils.h"
#include "mlir/IR/Remarks.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/AnalysisManager.h"
@@ -60,13 +61,6 @@
namespace mlir {
namespace acc {
-// Forward declarations
-enum class RecipeKind : uint32_t;
-bool isValidSymbolUse(Operation *user, SymbolRefAttr symbol,
- Operation **definingOpPtr);
-remark::detail::InFlightRemark emitRemark(Operation *op, const Twine &message,
- llvm::StringRef category);
-
namespace detail {
/// This class contains internal trait classes used by OpenACCSupport.
/// It follows the Concept-Model pattern used throughout MLIR (e.g., in
@@ -170,10 +164,10 @@ struct OpenACCSupportTraits {
}
bool isValidValueUse(Value v, Region &region) final {
- if constexpr (has_isValidSymbolUse<ImplT>::value)
+ if constexpr (has_isValidValueUse<ImplT>::value)
return impl.isValidValueUse(v, region);
else
- return false;
+ return acc::isValidValueUse(v, region);
}
private:
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td
index 44632eb4cdac4..95a8f22a3ddfa 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td
@@ -75,6 +75,10 @@ def GlobalVariableOpInterface : OpInterface<"GlobalVariableOpInterface"> {
}]>,
InterfaceMethod<"Get the initialization region (returns nullptr if none)",
"::mlir::Region*", "getInitRegion", (ins)>,
+ InterfaceMethod<"Check if the global variable is device data",
+ "bool", "isDeviceData", (ins), [{
+ return false;
+ }]>,
];
}
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td
index 1d55762e2f7d9..ca60f9bca5cf1 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td
@@ -220,6 +220,18 @@ def OpenACC_PointerLikeTypeInterface : TypeInterface<"PointerLikeType"> {
return false;
}]
>,
+ InterfaceMethod<
+ /*description=*/[{
+ Returns true if the pointer points to device data.
+ }],
+ /*retTy=*/"bool",
+ /*methodName=*/"isDeviceData",
+ /*args=*/(ins "::mlir::Value":$var),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return false;
+ }]
+ >,
];
}
@@ -442,6 +454,18 @@ def OpenACC_MappableTypeInterface : TypeInterface<"MappableType"> {
return true;
}]
>,
+ InterfaceMethod<
+ /*description=*/[{
+ Returns true if the variable represents device data.
+ }],
+ /*retTy=*/"bool",
+ /*methodName=*/"isDeviceData",
+ /*args=*/(ins "::mlir::Value":$var),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return false;
+ }]
+ >,
];
}
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCUtils.h b/mlir/include/mlir/Dialect/OpenACC/OpenACCUtils.h
index e3f4e6889ffe8..a436050babce2 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCUtils.h
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCUtils.h
@@ -67,6 +67,23 @@ mlir::Value getBaseEntity(mlir::Value val);
bool isValidSymbolUse(mlir::Operation *user, mlir::SymbolRefAttr symbol,
mlir::Operation **definingOpPtr = nullptr);
+/// Check if a value represents device data.
+/// This checks if the value represents device data via the
+/// MappableType, PointerLikeType, and GlobalVariableOpInterface interfaces.
+/// \param val The value to check
+/// \return true if the value is device data, false otherwise
+bool isDeviceValue(mlir::Value val);
+
+/// Check if a value use is valid in an OpenACC region.
+/// This is true if:
+/// - The value is produced by an ACC data entry operation
+/// - The value is device data
+/// - The value is only used by private clauses in the region
+/// \param val The value to check
+/// \param region The OpenACC region
+/// \return true if the value use is valid, false otherwise
+bool isValidValueUse(mlir::Value val, mlir::Region &region);
+
/// Collects all data clauses that dominate the compute construct.
/// This includes data clauses from:
/// - The compute construct itself
diff --git a/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp b/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp
index c487c43e8369c..4fbe3e06c2532 100644
--- a/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp
+++ b/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp
@@ -59,7 +59,7 @@ bool OpenACCSupport::isValidSymbolUse(Operation *user, SymbolRefAttr symbol,
bool OpenACCSupport::isValidValueUse(Value v, Region &region) {
if (impl)
return impl->isValidValueUse(v, region);
- return false;
+ return acc::isValidValueUse(v, region);
}
} // namespace acc
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 6872aa9d26e29..55a438eb3906b 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -245,6 +245,12 @@ struct MemRefPointerLikeModel
memref::StoreOp::create(builder, loc, valueToStore, memrefValue);
return true;
}
+
+ bool isDeviceData(Type pointer, Value var) const {
+ auto memrefTy = cast<T>(pointer);
+ Attribute memSpace = memrefTy.getMemorySpace();
+ return isa_and_nonnull<gpu::AddressSpaceAttr>(memSpace);
+ }
};
struct LLVMPointerPointerLikeModel
@@ -290,6 +296,12 @@ struct MemrefGlobalVariableModel
// GlobalOp uses attributes for initialization, not regions
return nullptr;
}
+
+ bool isDeviceData(Operation *op) const {
+ auto globalOp = cast<memref::GlobalOp>(op);
+ Attribute memSpace = globalOp.getType().getMemorySpace();
+ return isa_and_nonnull<gpu::AddressSpaceAttr>(memSpace);
+ }
};
struct GPULaunchOffloadRegionModel
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp
index c4b43c5cb1b27..a5b5753cd414e 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp
@@ -285,16 +285,17 @@ static bool isCandidateForImplicitData(Value val, Region &accRegion,
!acc::isMappableType(val.getType()))
return false;
- if (accSupport.isValidValueUse(val, accRegion))
- return false;
-
// If this is already coming from a data clause, we do not need to generate
// another.
if (isa_and_nonnull<ACC_DATA_ENTRY_OPS>(val.getDefiningOp()))
return false;
- // If this is only used by private clauses, it is not a real live-in.
- if (acc::isOnlyUsedByPrivateClauses(val, accRegion))
+ // Device data is a candidate - it will get a deviceptr clause.
+ if (acc::isDeviceValue(val))
+ return true;
+
+ // If it is otherwise valid, skip it.
+ if (accSupport.isValidValueUse(val, accRegion))
return false;
return true;
@@ -472,7 +473,17 @@ Operation *ACCImplicitData::generateDataClauseOpForCandidate(
/*structured=*/true, /*implicit=*/true,
accSupport.getVariableName(var),
acc::getBounds(op));
- } else if (isScalar) {
+ }
+
+
+ if (acc::isDeviceValue(var)) {
+ // If the variable is device data, use deviceptr clause.
+ return acc::DevicePtrOp::create(builder, loc, var,
+ /*structured=*/true, /*implicit=*/true,
+ accSupport.getVariableName(var));
+ }
+
+ if (isScalar) {
if (enableImplicitReductionCopy &&
acc::isOnlyUsedByReductionClauses(var,
computeConstructOp->getRegion(0))) {
diff --git a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
index e97fd1471e6e0..bd3dda48d44b4 100644
--- a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
+++ b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
@@ -182,6 +182,13 @@ bool mlir::acc::isValidSymbolUse(mlir::Operation *user,
mlir::acc::FirstprivateRecipeOp>(definingOp))
return true;
+ // Check if the defining op is a global variable that is device data.
+ // Device data is already resident on the device and does not need mapping.
+ if (auto globalVar =
+ mlir::dyn_cast<mlir::acc::GlobalVariableOpInterface>(definingOp))
+ if (globalVar.isDeviceData())
+ return true;
+
// Check if the defining op is a function
if (auto func =
mlir::dyn_cast_if_present<mlir::FunctionOpInterface>(definingOp)) {
@@ -208,6 +215,55 @@ bool mlir::acc::isValidSymbolUse(mlir::Operation *user,
return hasDeclare;
}
+bool mlir::acc::isDeviceValue(mlir::Value val) {
+ // Check if the value is device data via type interfaces.
+ // Device data is already resident on the device and does not need mapping.
+ if (auto mappableTy = dyn_cast<mlir::acc::MappableType>(val.getType()))
+ if (mappableTy.isDeviceData(val))
+ return true;
+
+ if (auto pointerLikeTy = dyn_cast<mlir::acc::PointerLikeType>(val.getType()))
+ if (pointerLikeTy.isDeviceData(val))
+ return true;
+
+ // Handle operations that access a partial entity - check if the base entity
+ // is device data.
+ if (auto *defOp = val.getDefiningOp()) {
+ if (auto partialAccess =
+ dyn_cast<mlir::acc::PartialEntityAccessOpInterface>(defOp)) {
+ if (mlir::Value base = partialAccess.getBaseEntity())
+ return isDeviceValue(base);
+ }
+
+ // Handle address_of - check if the referenced global is device data.
+ if (auto addrOfIface =
+ dyn_cast<mlir::acc::AddressOfGlobalOpInterface>(defOp)) {
+ auto symbol = addrOfIface.getSymbol();
+ if (auto global = mlir::SymbolTable::lookupNearestSymbolFrom<
+ mlir::acc::GlobalVariableOpInterface>(defOp, symbol))
+ return global.isDeviceData();
+ }
+ }
+
+ return false;
+}
+
+bool mlir::acc::isValidValueUse(mlir::Value val, mlir::Region &region) {
+ // If this is produced by an ACC data entry operation, it is valid.
+ if (isa_and_nonnull<ACC_DATA_ENTRY_OPS>(val.getDefiningOp()))
+ return true;
+
+ // If the value is only used by private clauses, it is not a live-in.
+ if (isOnlyUsedByPrivateClauses(val, region))
+ return true;
+
+ // If this is device data, it is valid.
+ if (isDeviceValue(val))
+ return true;
+
+ return false;
+}
+
llvm::SmallVector<mlir::Value>
mlir::acc::getDominatingDataClauses(mlir::Operation *computeConstructOp,
mlir::DominanceInfo &domInfo,
diff --git a/mlir/test/Dialect/OpenACC/acc-implicit-data.mlir b/mlir/test/Dialect/OpenACC/acc-implicit-data.mlir
index 6909fe6a4eb84..df0dbbfee8b1d 100644
--- a/mlir/test/Dialect/OpenACC/acc-implicit-data.mlir
+++ b/mlir/test/Dialect/OpenACC/acc-implicit-data.mlir
@@ -222,3 +222,40 @@ func.func @test_memref_view(%size: index) {
// CHECK-LABEL: func.func @test_memref_view
// CHECK: acc.present varPtr({{.*}} : memref<8x64xf32>) -> memref<8x64xf32> {implicit = true, name = ""}
+// -----
+
+// Test device data (memref with GPU address space) - should generate deviceptr
+func.func @test_device_data_in_parallel() {
+ %alloc = memref.alloca() : memref<10xf32, #gpu.address_space<global>>
+ acc.parallel {
+ %c0 = arith.constant 0 : index
+ %load = memref.load %alloc[%c0] : memref<10xf32, #gpu.address_space<global>>
+ acc.yield
+ }
+ return
+}
+
+// CHECK-LABEL: func.func @test_device_data_in_parallel
+// CHECK: acc.deviceptr varPtr({{.*}} : memref<10xf32, #gpu.address_space<global>>) -> memref<10xf32, #gpu.address_space<global>> {implicit = true, name = ""}
+// CHECK-NOT: acc.copyin
+// CHECK-NOT: acc.copyout
+
+// -----
+
+// Test device global (memref.global with GPU address space) - should generate deviceptr
+memref.global @device_global : memref<10xf32, #gpu.address_space<global>>
+
+func.func @test_device_global_in_parallel() {
+ %global = memref.get_global @device_global : memref<10xf32, #gpu.address_space<global>>
+ acc.parallel {
+ %c0 = arith.constant 0 : index
+ %load = memref.load %global[%c0] : memref<10xf32, #gpu.address_space<global>>
+ acc.yield
+ }
+ return
+}
+
+// CHECK-LABEL: func.func @test_device_global_in_parallel
+// CHECK: acc.deviceptr varPtr({{.*}} : memref<10xf32, #gpu.address_space<global>>) -> memref<10xf32, #gpu.address_space<global>> {implicit = true, name = ""}
+// CHECK-NOT: acc.copyin
+// CHECK-NOT: acc.copyout
diff --git a/mlir/unittests/Dialect/OpenACC/CMakeLists.txt b/mlir/unittests/Dialect/OpenACC/CMakeLists.txt
index 29448d2af5537..ef0c8fb11d0c9 100644
--- a/mlir/unittests/Dialect/OpenACC/CMakeLists.txt
+++ b/mlir/unittests/Dialect/OpenACC/CMakeLists.txt
@@ -10,6 +10,7 @@ mlir_target_link_libraries(MLIROpenACCTests
MLIRIR
MLIRAffineDialect
MLIRFuncDialect
+ MLIRGPUDialect
MLIRMemRefDialect
MLIRArithDialect
MLIROpenACCDialect
diff --git a/mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp b/mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp
index 60d87326c0e9b..f0cd845168abf 100644
--- a/mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp
+++ b/mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp
@@ -9,6 +9,7 @@
#include "mlir/Dialect/OpenACC/OpenACCUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/IR/BuiltinOps.h"
@@ -30,7 +31,8 @@ class OpenACCUtilsTest : public ::testing::Test {
protected:
OpenACCUtilsTest() : b(&context), loc(UnknownLoc::get(&context)) {
context.loadDialect<acc::OpenACCDialect, arith::ArithDialect,
- memref::MemRefDialect, func::FuncDialect>();
+ gpu::GPUDialect, memref::MemRefDialect,
+ func::FuncDialect>();
}
MLIRContext context;
@@ -1367,3 +1369,277 @@ TEST_F(OpenACCUtilsTest, getDominatingDataClausesEmpty) {
// Should be empty
EXPECT_EQ(dataClauses.size(), 0ul);
}
+
+//===----------------------------------------------------------------------===//
+// isDeviceValue Tests
+//===----------------------------------------------------------------------===//
+
+TEST_F(OpenACCUtilsTest, isDeviceValueMemrefGlobalAddressSpace) {
+ // Test that a memref with GPU global address space is considered device data
+ auto gpuAddressSpace =
+ gpu::AddressSpaceAttr::get(&context, gpu::AddressSpace::Global);
+ auto memrefTy = MemRefType::get({10}, b.getI32Type(), AffineMap(), gpuAddressSpace);
+
+ OwningOpRef<memref::AllocaOp> allocOp =
+ memref::AllocaOp::create(b, loc, memrefTy);
+ Value val = allocOp->getResult();
+
+ // Should return true since memref has GPU global address space
+ EXPECT_TRUE(isDeviceValue(val));
+}
+
+TEST_F(OpenACCUtilsTest, isDeviceValueMemrefWorkgroupAddressSpace) {
+ // Test that a memref with GPU workgroup address space is considered device
+ // data
+ auto gpuAddressSpace =
+ gpu::AddressSpaceAttr::get(&context, gpu::AddressSpace::Workgroup);
+ auto memrefTy = MemRefType::get({10}, b.getI32Type(), AffineMap(), gpuAddressSpace);
+
+ OwningOpRef<memref::AllocaOp> allocOp =
+ memref::AllocaOp::create(b, loc, memrefTy);
+ Value val = allocOp->getResult();
+
+ // Should return true since memref has GPU workgroup address space
+ EXPECT_TRUE(isDeviceValue(val));
+}
+
+TEST_F(OpenACCUtilsTest, isDeviceValueMemrefPrivateAddressSpace) {
+ // Test that a memref with GPU private address space is considered device
+ // data
+ auto gpuAddressSpace =
+ gpu::AddressSpaceAttr::get(&context, gpu::AddressSpace::Private);
+ auto memrefTy = MemRefType::get({10}, b.getI32Type(), AffineMap(), gpuAddressSpace);
+
+ OwningOpRef<memref::AllocaOp> allocOp =
+ memref::AllocaOp::create(b, loc, memrefTy);
+ Value val = allocOp->getResult();
+
+ // Should return true since memref has GPU private address space
+ EXPECT_TRUE(isDeviceValue(val));
+}
+
+TEST_F(OpenACCUtilsTest, isDeviceValueMemrefNoAddressSpace) {
+ // Test that a regular memref without GPU address space is not device data
+ auto memrefTy = MemRefType::get({10}, b.getI32Type());
+
+ OwningOpRef<memref::AllocaOp> allocOp =
+ memref::AllocaOp::create(b, loc, memrefTy);
+ Value val = allocOp->getResult();
+
+ // Should return false since memref has no GPU address space
+ EXPECT_FALSE(isDeviceValue(val));
+}
+
+TEST_F(OpenACCUtilsTest, isDeviceValueNonMappableType) {
+ // Test with a non-mappable type (i32 value)
+ auto constOp = arith::ConstantOp::create(b, loc, b.getI32IntegerAttr(42));
+ Value val = constOp.getResult();
+
+ // Should return false since i32 is not a MappableType or PointerLikeType
+ EXPECT_FALSE(isDeviceValue(val));
+}
+
+TEST_F(OpenACCUtilsTest, isDeviceValueGlobalWithGPUAddressSpace) {
+ // Test that memref.get_global referencing a global with GPU address space
+ // is considered device data
+ OwningOpRef<ModuleOp> module = ModuleOp::create(loc);
+ Block *moduleBlock = module->getBody();
+
+ OpBuilder::InsertionGuard guard(b);
+ b.setInsertionPointToStart(moduleBlock);
+
+ // Create a memref type with GPU global address space
+ auto gpuAddressSpace =
+ gpu::AddressSpaceAttr::get(&context, gpu::AddressSpace::Global);
+ auto memrefTy =
+ MemRefType::get({10}, b.getI32Type(), AffineMap(), gpuAddressSpace);
+
+ // Create a global op with the GPU address space memref type
+ llvm::StringRef globalName = "device_global";
+ OwningOpRef<memref::GlobalOp> globalOp = memref::GlobalOp::create(
+ b, loc, globalName, /*sym_visibility=*/b.getStringAttr("public"),
+ /*type=*/memrefTy, /*initial_value=*/Attribute(),
+ /*constant=*/false, /*alignment=*/IntegerAttr());
+
+ // Create a get_global that references the device global
+ OwningOpRef<memref::GetGlobalOp> getGlobalOp =
+ memref::GetGlobalOp::create(b, loc, memrefTy, globalName);
+ Value val = getGlobalOp->getResult();
+
+ // Should return true since the global has GPU address space
+ EXPECT_TRUE(isDeviceValue(val));
+}
+
+TEST_F(OpenACCUtilsTest, isDeviceValueGlobalWithoutGPUAddressSpace) {
+ // Test that memref.get_global referencing a regular global is not device data
+ OwningOpRef<ModuleOp> module = ModuleOp::create(loc);
+ Block *moduleBlock = module->getBody();
+
+ OpBuilder::InsertionGuard guard(b);
+ b.setInsertionPointToStart(moduleBlock);
+
+ // Create a regular memref type without GPU address space
+ auto memrefTy = MemRefType::get({10}, b.getI32Type());
+
+ // Create a global op without GPU address space
+ llvm::StringRef globalName = "host_global";
+ OwningOpRef<memref::GlobalOp> globalOp = memref::GlobalOp::create(
+ b, loc, globalName, /*sym_visibility=*/b.getStringAttr("public"),
+ /*type=*/memrefTy, /*initial_value=*/Attribute(),
+ /*constant=*/false, /*alignment=*/IntegerAttr());
+
+ // Create a get_global that references the host global
+ OwningOpRef<memref::GetGlobalOp> getGlobalOp =
+ memref::GetGlobalOp::create(b, loc, memrefTy, globalName);
+ Value val = getGlobalOp->getResult();
+
+ // Should return false since the global has no GPU address space
+ EXPECT_FALSE(isDeviceValue(val));
+}
+
+//===----------------------------------------------------------------------===//
+// isValidValueUse Tests
+//===----------------------------------------------------------------------===//
+
+TEST_F(OpenACCUtilsTest, isValidValueUseFromDataEntryOp) {
+ // Create a module to hold a function
+ OwningOpRef<ModuleOp> module = ModuleOp::create(loc);
+ Block *moduleBlock = module->getBody();
+
+ OpBuilder::InsertionGuard guard(b);
+ b.setInsertionPointToStart(moduleBlock);
+
+ // Create a function with a serial region
+ auto funcType = b.getFunctionType({}, {});
+ OwningOpRef<func::FuncOp> funcOp =
+ func::FuncOp::create(b, loc, "test_func", funcType);
+ Block *entryBlock = funcOp->addEntryBlock();
+ b.setInsertionPointToStart(entryBlock);
+
+ // Create a memref and a copyin operation
+ auto memrefTy = MemRefType::get({10}, b.getI32Type());
+ OwningOpRef<memref::AllocaOp> allocOp =
+ memref::AllocaOp::create(b, loc, memrefTy);
+ TypedValue<PointerLikeType> varPtr =
+ cast<TypedValue<PointerLikeType>>(allocOp->getResult());
+
+ OwningOpRef<CopyinOp> copyinOp =
+ CopyinOp::create(b, loc, varPtr, /*structured=*/true, /*implicit=*/false,
+ /*name=*/"test_var");
+ Value dataClauseResult = copyinOp->getAccVar();
+
+ // Create a serial region
+ OwningOpRef<SerialOp> serialOp =
+ SerialOp::create(b, loc, TypeRange{}, ValueRange{});
+ Region &serialRegion = serialOp->getRegion();
+
+ // Value from data entry op should be valid
+ EXPECT_TRUE(isValidValueUse(dataClauseResult, serialRegion));
+}
+
+TEST_F(OpenACCUtilsTest, isValidValueUseDeviceData) {
+ // Create a module to hold a function
+ OwningOpRef<ModuleOp> module = ModuleOp::create(loc);
+ Block *moduleBlock = module->getBody();
+
+ OpBuilder::InsertionGuard guard(b);
+ b.setInsertionPointToStart(moduleBlock);
+
+ // Create a function
+ auto funcType = b.getFunctionType({}, {});
+ OwningOpRef<func::FuncOp> funcOp =
+ func::FuncOp::create(b, loc, "test_func", funcType);
+ Block *entryBlock = funcOp->addEntryBlock();
+ b.setInsertionPointToStart(entryBlock);
+
+ // Create a memref with GPU address space (device data)
+ auto gpuAddressSpace =
+ gpu::AddressSpaceAttr::get(&context, gpu::AddressSpace::Global);
+ auto memrefTy =
+ MemRefType::get({10}, b.getI32Type(), AffineMap(), gpuAddressSpace);
+ OwningOpRef<memref::AllocaOp> allocOp =
+ memref::AllocaOp::create(b, loc, memrefTy);
+ Value deviceVal = allocOp->getResult();
+
+ // Create a serial region
+ OwningOpRef<SerialOp> serialOp =
+ SerialOp::create(b, loc, TypeRange{}, ValueRange{});
+ Region &serialRegion = serialOp->getRegion();
+
+ // Device data should be valid
+ EXPECT_TRUE(isValidValueUse(deviceVal, serialRegion));
+}
+
+TEST_F(OpenACCUtilsTest, isValidValueUseOnlyUsedByPrivate) {
+ // Create a module to hold a function
+ OwningOpRef<ModuleOp> module = ModuleOp::create(loc);
+ Block *moduleBlock = module->getBody();
+
+ OpBuilder::InsertionGuard guard(b);
+ b.setInsertionPointToStart(moduleBlock);
+
+ // Create a function
+ auto funcType = b.getFunctionType({}, {});
+ OwningOpRef<func::FuncOp> funcOp =
+ func::FuncOp::create(b, loc, "test_func", funcType);
+ Block *entryBlock = funcOp->addEntryBlock();
+ b.setInsertionPointToStart(entryBlock);
+
+ // Create a memref
+ auto memrefTy = MemRefType::get({10}, b.getI32Type());
+ OwningOpRef<memref::AllocaOp> allocOp =
+ memref::AllocaOp::create(b, loc, memrefTy);
+ TypedValue<PointerLikeType> varPtr =
+ cast<TypedValue<PointerLikeType>>(allocOp->getResult());
+
+ // Create a serial region with a private clause using the variable
+ OwningOpRef<SerialOp> serialOp =
+ SerialOp::create(b, loc, TypeRange{}, ValueRange{});
+ Region &serialRegion = serialOp->getRegion();
+ Block *serialBlock = b.createBlock(&serialRegion);
+ b.setInsertionPointToStart(serialBlock);
+
+ OwningOpRef<PrivateOp> privateOp = PrivateOp::create(
+ b, loc, varPtr, /*structured=*/true, /*implicit=*/false);
+
+ // Value only used by private clause should be valid
+ EXPECT_TRUE(isValidValueUse(varPtr, serialRegion));
+}
+
+TEST_F(OpenACCUtilsTest, isValidValueUseRegularValue) {
+ // Create a module to hold a function
+ OwningOpRef<ModuleOp> module = ModuleOp::create(loc);
+ Block *moduleBlock = module->getBody();
+
+ OpBuilder::InsertionGuard guard(b);
+ b.setInsertionPointToStart(moduleBlock);
+
+ // Create a function
+ auto funcType = b.getFunctionType({}, {});
+ OwningOpRef<func::FuncOp> funcOp =
+ func::FuncOp::create(b, loc, "test_func", funcType);
+ Block *entryBlock = funcOp->addEntryBlock();
+ b.setInsertionPointToStart(entryBlock);
+
+ // Create a regular memref without GPU address space
+ auto memrefTy = MemRefType::get({10}, b.getI32Type());
+ OwningOpRef<memref::AllocaOp> allocOp =
+ memref::AllocaOp::create(b, loc, memrefTy);
+ Value regularVal = allocOp->getResult();
+
+ // Create a serial region with a non-private use of the value
+ OwningOpRef<SerialOp> serialOp =
+ SerialOp::create(b, loc, TypeRange{}, ValueRange{});
+ Region &serialRegion = serialOp->getRegion();
+ Block *serialBlock = b.createBlock(&serialRegion);
+ b.setInsertionPointToStart(serialBlock);
+
+ // Add a function call to create a synthetic use of the value inside the
+ // region
+ func::CallOp::create(b, loc, "some_func", TypeRange{},
+ ValueRange{regularVal});
+
+ // Regular value (not device data, not from data op, not private) should be
+ // invalid
+ EXPECT_FALSE(isValidValueUse(regularVal, serialRegion));
+}
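The expectations encoded in these tests can be condensed into a small decision model. The sketch below is plain C++ and deliberately does not use MLIR. `GpuAddrSpace`, `UseKind`, `ValueFacts`, `isDeviceValueModel`, and `isValidValueUseModel` are all hypothetical names introduced for illustration; they are not the `acc::isDeviceValue` / `acc::isValidValueUse` implementations from this PR. The model assumes that, for memrefs, "device data" means "the type carries a GPU address space". It also assumes that a value with no uses in the region does not count as "only used by private clauses". The real utility may treat that case differently.

```cpp
#include <cassert>
#include <optional>
#include <vector>

// Hypothetical stand-in for gpu::AddressSpace (not the MLIR enum itself).
enum class GpuAddrSpace { Global, Workgroup, Private };

// Hypothetical classification of a use found inside the compute region.
enum class UseKind { PrivateClause, Other };

// Hypothetical summary of the facts each test sets up for a value.
struct ValueFacts {
  std::optional<GpuAddrSpace> addrSpace; // GPU address space on the memref type
  bool fromDataEntryOp = false;          // e.g. the acc var of an acc.copyin
  std::vector<UseKind> usesInRegion;     // uses located inside the region
};

// Assumption: a memref is device data iff its type has a GPU address space.
bool isDeviceValueModel(const ValueFacts &v) { return v.addrSpace.has_value(); }

// Valid if the value comes from a data-entry op, is used only by private
// clauses, or is device data. Assumption: zero uses in the region is NOT
// treated as "only used by private clauses".
bool isValidValueUseModel(const ValueFacts &v) {
  if (v.fromDataEntryOp)
    return true;
  bool onlyPrivate = !v.usesInRegion.empty();
  for (UseKind k : v.usesInRegion)
    if (k != UseKind::PrivateClause)
      onlyPrivate = false;
  return onlyPrivate || isDeviceValueModel(v);
}

// Mirrors the EXPECT_TRUE/EXPECT_FALSE outcomes of the unit tests above.
void checkTestExpectations() {
  // isDeviceValue: any GPU address space means device data.
  assert(isDeviceValueModel({GpuAddrSpace::Workgroup, false, {}}));
  assert(isDeviceValueModel({GpuAddrSpace::Private, false, {}}));
  assert(!isDeviceValueModel({std::nullopt, false, {}}));
  // isValidValueUse: data-entry result, device data, private-only use.
  assert(isValidValueUseModel({std::nullopt, true, {}}));
  assert(isValidValueUseModel({GpuAddrSpace::Global, false, {}}));
  assert(isValidValueUseModel({std::nullopt, false, {UseKind::PrivateClause}}));
  // A host memref used by a func.call inside the region is rejected.
  assert(!isValidValueUseModel({std::nullopt, false, {UseKind::Other}}));
}
```

The order of checks follows the PR description, which lists the data-operation, private-only, and device-data rules. Because the three rules are OR-ed together, the order only affects which check short-circuits first, not the result.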