[flang-commits] [flang] [mlir] [mlir][acc] Improve implicit deviceptr detection for alias (PR #195934)
Razvan Lupusoru via flang-commits
flang-commits at lists.llvm.org
Tue May 5 13:48:14 PDT 2026
https://github.com/razvanlupusoru created https://github.com/llvm/llvm-project/pull/195934
The ACCImplicitData pass can automatically use a deviceptr clause when a variable is detected as being device data. However, it was missing a check for the variable's own `acc declare deviceptr` attribute.
>From 123050845961c8adac39e3233975668960b867be Mon Sep 17 00:00:00 2001
From: Razvan Lupusoru <rlupusoru at nvidia.com>
Date: Tue, 5 May 2026 13:47:15 -0700
Subject: [PATCH] [mlir][acc] Improve implicit deviceptr detection for alias
The ACCImplicitData pass can automatically use a deviceptr clause
when a variable is detected as being device data. However, it was
missing a check for the variable's own `acc declare deviceptr` attribute.
---
.../Transforms/OpenACC/acc-implicit-data.fir | 33 +++++++++++++
.../mlir/Dialect/OpenACC/OpenACCUtils.h | 4 +-
.../OpenACC/Transforms/ACCImplicitData.cpp | 23 +++++----
.../Dialect/OpenACC/Utils/OpenACCUtils.cpp | 38 +++++++++------
.../Dialect/OpenACC/acc-implicit-data.mlir | 28 +++++++++++
.../Dialect/OpenACC/OpenACCUtilsTest.cpp | 47 +++++++++++++++++++
6 files changed, 150 insertions(+), 23 deletions(-)
diff --git a/flang/test/Transforms/OpenACC/acc-implicit-data.fir b/flang/test/Transforms/OpenACC/acc-implicit-data.fir
index 050fe55747d23..18d9f310eabee 100644
--- a/flang/test/Transforms/OpenACC/acc-implicit-data.fir
+++ b/flang/test/Transforms/OpenACC/acc-implicit-data.fir
@@ -394,3 +394,36 @@ func.func private @_FortranAAllocatableSetBounds(!fir.ref<!fir.box<none>>, i32,
// CHECK-NOT: acc.copyin
// CHECK: acc.deviceptr
// CHECK-NOT: acc.copyout
+
+// -----
+
+// Test an argument declared with deviceptr that is used directly in the region rather than through its data mapping.
+func.func @test_fir_declare_deviceptr_arg_in_parallel(%arg0: !fir.ref<!fir.array<10xf64>>) {
+ %c10 = arith.constant 10 : index
+ %shape = fir.shape %c10 : (index) -> !fir.shape<1>
+ %arr_decl = fir.declare %arg0(%shape) {acc.declare = #acc.declare<dataClause = acc_deviceptr>, uniq_name = "_QFtestEa"} : (!fir.ref<!fir.array<10xf64>>, !fir.shape<1>) -> !fir.ref<!fir.array<10xf64>>
+ %arr_box = fir.embox %arr_decl(%shape) : (!fir.ref<!fir.array<10xf64>>, !fir.shape<1>) -> !fir.box<!fir.array<10xf64>>
+ %devptr = acc.deviceptr var(%arr_box : !fir.box<!fir.array<10xf64>>) -> !fir.box<!fir.array<10xf64>> {name = "a"}
+ %token = acc.declare_enter dataOperands(%devptr : !fir.box<!fir.array<10xf64>>)
+ acc.parallel {
+ %addr = fir.box_addr %arr_box : (!fir.box<!fir.array<10xf64>>) -> !fir.ref<!fir.array<10xf64>>
+ %elem = fir.array_coor %arr_decl(%shape) %c10 : (!fir.ref<!fir.array<10xf64>>, !fir.shape<1>, index) -> !fir.ref<f64>
+ acc.yield
+ }
+ acc.declare_exit token(%token) dataOperands(%devptr : !fir.box<!fir.array<10xf64>>)
+ return
+}
+
+// CHECK-LABEL: func.func @test_fir_declare_deviceptr_arg_in_parallel
+// CHECK: %[[DECL:.*]] = fir.declare %{{.*}}{{.*}}{acc.declare = #acc.declare<dataClause = acc_deviceptr>{{.*}}
+// CHECK: %[[BOX:.*]] = fir.embox %[[DECL]]
+// CHECK: %[[DEVPTR:.*]] = acc.deviceptr var(%[[BOX]] : !fir.box<!fir.array<10xf64>>) -> !fir.box<!fir.array<10xf64>> {name = "a"}
+// CHECK: %[[TOKEN:.*]] = acc.declare_enter dataOperands(%[[DEVPTR]] : !fir.box<!fir.array<10xf64>>)
+// CHECK: %[[IMPLICIT_BOX:.*]] = acc.deviceptr var(%[[BOX]] : !fir.box<!fir.array<10xf64>>) -> !fir.box<!fir.array<10xf64>> {implicit = true, name = "a"}
+// CHECK: %[[IMPLICIT_REF:.*]] = acc.deviceptr varPtr(%[[DECL]] : !fir.ref<!fir.array<10xf64>>) -> !fir.ref<!fir.array<10xf64>> {implicit = true, name = "a"}
+// CHECK: acc.parallel dataOperands(%[[IMPLICIT_BOX]], %[[IMPLICIT_REF]] : !fir.box<!fir.array<10xf64>>, !fir.ref<!fir.array<10xf64>>) {
+// CHECK: fir.box_addr %[[IMPLICIT_BOX]] : (!fir.box<!fir.array<10xf64>>) -> !fir.ref<!fir.array<10xf64>>
+// CHECK: fir.array_coor %[[IMPLICIT_REF]]
+// CHECK: acc.declare_exit token(%[[TOKEN]]) dataOperands(%[[DEVPTR]] : !fir.box<!fir.array<10xf64>>)
+// CHECK-NOT: acc.copyin
+// CHECK-NOT: acc.copyout
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCUtils.h b/mlir/include/mlir/Dialect/OpenACC/OpenACCUtils.h
index c26ddbd54f1b9..5a8e4e362108d 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCUtils.h
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCUtils.h
@@ -78,7 +78,9 @@ bool isValidSymbolUse(mlir::Operation *user, mlir::SymbolRefAttr symbol,
/// Check if a value represents device data.
/// This checks if the value represents device data via the
-/// MappableType, PointerLikeType, and GlobalVariableOpInterface interfaces.
+/// MappableType, PointerLikeType, and GlobalVariableOpInterface interfaces,
+/// and whether the defining operation carries `acc.declare` with the deviceptr
+/// clause.
/// \param val The value to check
/// \return true if the value is device data, false otherwise
bool isDeviceValue(mlir::Value val);
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp
index 95c8d1076ccb0..2de714ffcbc35 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp
@@ -312,8 +312,13 @@ Operation *ACCImplicitData::getOriginalDataClauseOpForAlias(
// Only accept clauses that guarantee that the alias is present.
if (isa<acc::CopyinOp, acc::CreateOp, acc::PresentOp, acc::NoCreateOp,
acc::DevicePtrOp>(dataClauseOp))
- if (aliasAnalysis.alias(acc::getVar(dataClauseOp), var).isMust())
+ if (aliasAnalysis.alias(acc::getVar(dataClauseOp), var).isMust()) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "Using existing data clause:\n\t" << *dataClauseOp
+ << "\n\tas reference when processing var:\n\t" << var
+ << "\n";);
return dataClauseOp;
+ }
}
}
return nullptr;
@@ -452,6 +457,15 @@ Operation *ACCImplicitData::generateDataClauseOpForCandidate(
typeCategory, acc::VariableTypeCategory::aggregate);
Location loc = computeConstructOp->getLoc();
+ if (acc::isDeviceValue(var)) {
+ // If the variable is device data, use deviceptr clause.
+ LLVM_DEBUG(llvm::dbgs() << "Using deviceptr clause because variable is "
+ "device data\n");
+ return acc::DevicePtrOp::create(builder, loc, var,
+ /*structured=*/true, /*implicit=*/true,
+ accSupport.getVariableName(var));
+ }
+
Operation *op = nullptr;
op = getOriginalDataClauseOpForAlias(var, builder, computeConstructOp,
dominatingDataClauses);
@@ -476,13 +490,6 @@ Operation *ACCImplicitData::generateDataClauseOpForCandidate(
acc::getBounds(op));
}
- if (acc::isDeviceValue(var)) {
- // If the variable is device data, use deviceptr clause.
- return acc::DevicePtrOp::create(builder, loc, var,
- /*structured=*/true, /*implicit=*/true,
- accSupport.getVariableName(var));
- }
-
if (isScalar) {
if (enableImplicitReductionCopy &&
acc::isOnlyUsedByReductionClauses(var,
diff --git a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
index f20ace4398696..411b0a7a4457d 100644
--- a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
+++ b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
@@ -245,23 +245,33 @@ bool mlir::acc::isDeviceValue(mlir::Value val) {
if (pointerLikeTy.isDeviceData(val))
return true;
+ mlir::Operation *defOp = val.getDefiningOp();
+ if (!defOp)
+ return false;
+
+ // `acc.declare` with deviceptr marks data that is already associated with
+ // the device.
+ if (auto declareAttr = defOp->getAttrOfType<mlir::acc::DeclareAttr>(
+ mlir::acc::getDeclareAttrName()))
+ if (declareAttr.getDataClause().getValue() ==
+ mlir::acc::DataClause::acc_deviceptr)
+ return true;
+
// Handle operations that access a partial entity - check if the base entity
// is device data.
- if (auto *defOp = val.getDefiningOp()) {
- if (auto partialAccess =
- dyn_cast<mlir::acc::PartialEntityAccessOpInterface>(defOp)) {
- if (mlir::Value base = partialAccess.getBaseEntity())
- return isDeviceValue(base);
- }
+ if (auto partialAccess =
+ dyn_cast<mlir::acc::PartialEntityAccessOpInterface>(defOp)) {
+ if (mlir::Value base = partialAccess.getBaseEntity())
+ return isDeviceValue(base);
+ }
- // Handle address_of - check if the referenced global is device data.
- if (auto addrOfIface =
- dyn_cast<mlir::acc::AddressOfGlobalOpInterface>(defOp)) {
- auto symbol = addrOfIface.getSymbol();
- if (auto global = mlir::SymbolTable::lookupNearestSymbolFrom<
- mlir::acc::GlobalVariableOpInterface>(defOp, symbol))
- return global.isDeviceData();
- }
+ // Handle address_of - check if the referenced global is device data.
+ if (auto addrOfIface =
+ dyn_cast<mlir::acc::AddressOfGlobalOpInterface>(defOp)) {
+ auto symbol = addrOfIface.getSymbol();
+ if (auto global = mlir::SymbolTable::lookupNearestSymbolFrom<
+ mlir::acc::GlobalVariableOpInterface>(defOp, symbol))
+ return global.isDeviceData();
}
return false;
diff --git a/mlir/test/Dialect/OpenACC/acc-implicit-data.mlir b/mlir/test/Dialect/OpenACC/acc-implicit-data.mlir
index df0dbbfee8b1d..3b6b5e1ade5e0 100644
--- a/mlir/test/Dialect/OpenACC/acc-implicit-data.mlir
+++ b/mlir/test/Dialect/OpenACC/acc-implicit-data.mlir
@@ -259,3 +259,31 @@ func.func @test_device_global_in_parallel() {
// CHECK: acc.deviceptr varPtr({{.*}} : memref<10xf32, #gpu.address_space<global>>) -> memref<10xf32, #gpu.address_space<global>> {implicit = true, name = ""}
// CHECK-NOT: acc.copyin
// CHECK-NOT: acc.copyout
+
+// -----
+
+// Test memref.view tagged with acc.declare deviceptr and used directly in region.
+func.func @test_declare_deviceptr_arg_in_parallel(%arg0: memref<?xi8>) {
+ %c0 = arith.constant 0 : index
+ %view = memref.view %arg0[%c0][] {acc.declare = #acc.declare<dataClause = acc_deviceptr>} : memref<?xi8> to memref<10xf32>
+ %devptr = acc.deviceptr varPtr(%view : memref<10xf32>) -> memref<10xf32> {name = "arg0"}
+ %token = acc.declare_enter dataOperands(%devptr : memref<10xf32>)
+ acc.parallel {
+ %c0_1 = arith.constant 0 : index
+ %load = memref.load %arg0[%c0_1] : memref<?xi8>
+ acc.yield
+ }
+ acc.declare_exit token(%token) dataOperands(%devptr : memref<10xf32>)
+ return
+}
+
+// CHECK-LABEL: func.func @test_declare_deviceptr_arg_in_parallel
+// CHECK: %[[VIEW:.*]] = memref.view %{{.*}}[{{.*}}][] {acc.declare = #acc.declare<dataClause = acc_deviceptr>} : memref<?xi8> to memref<10xf32>
+// CHECK: %[[DEVPTR:.*]] = acc.deviceptr varPtr(%[[VIEW]] : memref<10xf32>) -> memref<10xf32> {name = "arg0"}
+// CHECK: %[[TOKEN:.*]] = acc.declare_enter dataOperands(%[[DEVPTR]] : memref<10xf32>)
+// CHECK: %[[IMPLICIT_DEVPTR:.*]] = acc.deviceptr varPtr(%{{.*}} : memref<?xi8>) -> memref<?xi8> {implicit = true, name = ""}
+// CHECK: acc.parallel dataOperands(%[[IMPLICIT_DEVPTR]] : memref<?xi8>) {
+// CHECK: memref.load %[[IMPLICIT_DEVPTR]][{{.*}}] : memref<?xi8>
+// CHECK: acc.declare_exit token(%[[TOKEN]]) dataOperands(%[[DEVPTR]] : memref<10xf32>)
+// CHECK-NOT: acc.copyin
+// CHECK-NOT: acc.copyout
diff --git a/mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp b/mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp
index 489fe9108c04e..30c6f4312efaf 100644
--- a/mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp
+++ b/mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp
@@ -1398,6 +1398,33 @@ TEST_F(OpenACCUtilsTest, getDominatingDataClausesEmpty) {
// isDeviceValue Tests
//===----------------------------------------------------------------------===//
+namespace {
+static Value memrefViewFromBlockArgWithDeclare(OpBuilder &builder, Location loc,
+ MLIRContext *ctx,
+ DataClause clause,
+ ModuleOp module,
+ StringRef funcName) {
+ OpBuilder::InsertionGuard guard(builder);
+ builder.setInsertionPointToStart(module.getBody());
+
+ auto i8BufTy = MemRefType::get({40}, builder.getI8Type());
+ auto viewTy = MemRefType::get({10}, builder.getI32Type());
+ auto funcType = builder.getFunctionType({i8BufTy}, {});
+ func::FuncOp funcOp = func::FuncOp::create(builder, loc, funcName, funcType);
+ Block *entry = funcOp.addEntryBlock();
+
+ builder.setInsertionPointToStart(entry);
+ Value buf = entry->getArgument(0);
+ Value c0 = arith::ConstantIndexOp::create(builder, loc, 0);
+ memref::ViewOp viewOp =
+ memref::ViewOp::create(builder, loc, viewTy, buf, c0, ValueRange{});
+ viewOp->setAttr(getDeclareAttrName(),
+ DeclareAttr::get(ctx, DataClauseAttr::get(ctx, clause)));
+ func::ReturnOp::create(builder, loc);
+ return viewOp.getResult();
+}
+} // namespace
+
TEST_F(OpenACCUtilsTest, isDeviceValueMemrefGlobalAddressSpace) {
// Test that a memref with GPU global address space is considered device data
auto gpuAddressSpace =
@@ -1525,6 +1552,26 @@ TEST_F(OpenACCUtilsTest, isDeviceValueGlobalWithoutGPUAddressSpace) {
EXPECT_FALSE(isDeviceValue(val));
}
+TEST_F(OpenACCUtilsTest, isDeviceValueAccDeclareDeviceptr) {
+ OwningOpRef<ModuleOp> module = ModuleOp::create(loc);
+ OpBuilder::InsertionGuard guard(b);
+ b.setInsertionPointToStart(module->getBody());
+ Value val = memrefViewFromBlockArgWithDeclare(
+ b, loc, &context, DataClause::acc_deviceptr, module.get(),
+ "test_memref_view_declare_devptr");
+ EXPECT_TRUE(isDeviceValue(val));
+}
+
+TEST_F(OpenACCUtilsTest, isDeviceValueAccDeclareNonDeviceptr) {
+ OwningOpRef<ModuleOp> module = ModuleOp::create(loc);
+ OpBuilder::InsertionGuard guard(b);
+ b.setInsertionPointToStart(module->getBody());
+ Value val = memrefViewFromBlockArgWithDeclare(
+ b, loc, &context, DataClause::acc_copyin, module.get(),
+ "test_memref_view_declare_copyin");
+ EXPECT_FALSE(isDeviceValue(val));
+}
+
//===----------------------------------------------------------------------===//
// isValidValueUse Tests
//===----------------------------------------------------------------------===//
More information about the flang-commits
mailing list