[Mlir-commits] [flang] [mlir] [flang][OpenACC] Fix implicit data mapping for deviceptr inside host_data (PR #192710)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Apr 23 18:58:44 PDT 2026
https://github.com/khaki3 updated https://github.com/llvm/llvm-project/pull/192710
>From 8ecbeebb978c609b8b0d521c044a348e2f3d6547 Mon Sep 17 00:00:00 2001
From: Kazuaki Matsumura <kmatsumura at nvidia.com>
Date: Fri, 17 Apr 2026 11:14:36 -0700
Subject: [PATCH 1/3] [flang][OpenACC] Fix implicit data mapping for deviceptr
inside host_data
When a subroutine with `!$acc data deviceptr(b)` enclosing `!$acc serial`
is inlined, the ACCImplicitData pass fails to recognize that `b` is
already covered by the enclosing deviceptr clause. This happens because
the deviceptr operates on a box (fir.embox result) while the serial
construct uses the underlying ref, and alias analysis cannot match them
across the type boundary. The pass falls back to generating an implicit
acc.copyin/acc.copyout which tries to copy from a device pointer on the
host, causing a segfault at runtime.
Fix this with two changes:
1. Register fir::BoxAddrOp as implementing PartialEntityAccessOpInterface
so that isDeviceValue and other utilities can trace through box_addr
operations.
2. In getOriginalDataClauseOpForAlias, for deviceptr clauses additionally
check if the clause variable is directly derived from the candidate
(e.g., deviceptr operates on embox(candidate)), catching the case
where a fir.embox wraps the candidate ref into a box used by
acc.deviceptr.
Made-with: Cursor
---
.../OpenACC/Support/FIROpenACCOpsInterfaces.h | 1 +
.../Support/FIROpenACCOpsInterfaces.cpp | 6 ++++
.../Support/RegisterOpenACCExtensions.cpp | 2 ++
.../Transforms/OpenACC/acc-implicit-data.fir | 34 +++++++++++++++++++
.../OpenACC/Transforms/ACCImplicitData.cpp | 17 ++++++++--
5 files changed, 58 insertions(+), 2 deletions(-)
diff --git a/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h b/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h
index 4ffa0877ff190..854faba45bd2e 100644
--- a/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h
+++ b/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h
@@ -19,6 +19,7 @@
namespace fir {
class AddrOfOp;
+class BoxAddrOp;
class DeclareOp;
class GlobalOp;
} // namespace fir
diff --git a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp
index 6d2c6ea5c8e57..eeb1b0bdfa4db 100644
--- a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp
+++ b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp
@@ -64,6 +64,12 @@ mlir::Value PartialEntityAccessModel<fir::CoordinateOp>::getBaseEntity(
return mlir::cast<fir::CoordinateOp>(op).getRef();
}
+template <>
+mlir::Value PartialEntityAccessModel<fir::BoxAddrOp>::getBaseEntity(
+ mlir::Operation *op) const {
+ return mlir::cast<fir::BoxAddrOp>(op).getVal();
+}
+
template <>
mlir::Value PartialEntityAccessModel<hlfir::DesignateOp>::getBaseEntity(
mlir::Operation *op) const {
diff --git a/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp b/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp
index f2fa5bf38872d..1766b7cc1d675 100644
--- a/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp
+++ b/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp
@@ -53,6 +53,8 @@ void registerOpenACCExtensions(mlir::DialectRegistry &registry) {
PartialEntityAccessModel<fir::CoordinateOp>>(*ctx);
fir::DeclareOp::attachInterface<PartialEntityAccessModel<fir::DeclareOp>>(
*ctx);
+ fir::BoxAddrOp::attachInterface<PartialEntityAccessModel<fir::BoxAddrOp>>(
+ *ctx);
fir::AddrOfOp::attachInterface<AddressOfGlobalModel>(*ctx);
fir::GlobalOp::attachInterface<GlobalVariableModel>(*ctx);
diff --git a/flang/test/Transforms/OpenACC/acc-implicit-data.fir b/flang/test/Transforms/OpenACC/acc-implicit-data.fir
index 050fe55747d23..58c567c831715 100644
--- a/flang/test/Transforms/OpenACC/acc-implicit-data.fir
+++ b/flang/test/Transforms/OpenACC/acc-implicit-data.fir
@@ -394,3 +394,37 @@ func.func private @_FortranAAllocatableSetBounds(!fir.ref<!fir.box<none>>, i32,
// CHECK-NOT: acc.copyin
// CHECK: acc.deviceptr
// CHECK-NOT: acc.copyout
+
+// -----
+
+// Test that acc.serial inside acc.data deviceptr generates implicit deviceptr
+// (not copyin) when the deviceptr var is an embox of the ref used by serial.
+// This pattern arises when a subroutine with !$acc data deviceptr(b) wrapping
+// !$acc serial is inlined and the deviceptr var is an embox of the ref used
+// by the serial construct.
+func.func @test_serial_inside_data_deviceptr_embox() {
+ %c10 = arith.constant 10 : index
+ %c1 = arith.constant 1 : index
+ %arr = fir.alloca !fir.array<10xf32> {bindc_name = "b"}
+ %shape = fir.shape %c10 : (index) -> !fir.shape<1>
+ %arr_decl = fir.declare %arr(%shape) {uniq_name = "b"} : (!fir.ref<!fir.array<10xf32>>, !fir.shape<1>) -> !fir.ref<!fir.array<10xf32>>
+ %box = fir.embox %arr_decl(%shape) : (!fir.ref<!fir.array<10xf32>>, !fir.shape<1>) -> !fir.box<!fir.array<10xf32>>
+ %devptr = acc.deviceptr var(%box : !fir.box<!fir.array<10xf32>>) -> !fir.box<!fir.array<10xf32>> {name = "b(1:10)"}
+ acc.data dataOperands(%devptr : !fir.box<!fir.array<10xf32>>) {
+ acc.serial {
+ %elem = fir.array_coor %arr_decl(%shape) %c1 : (!fir.ref<!fir.array<10xf32>>, !fir.shape<1>, index) -> !fir.ref<f32>
+ %val = fir.load %elem : !fir.ref<f32>
+ acc.yield
+ }
+ acc.terminator
+ }
+ return
+}
+
+// CHECK-LABEL: func.func @test_serial_inside_data_deviceptr_embox
+// CHECK: acc.deviceptr var({{.*}} : !fir.box<!fir.array<10xf32>>) -> !fir.box<!fir.array<10xf32>> {name = "b(1:10)"}
+// CHECK: acc.data
+// CHECK: acc.deviceptr varPtr({{.*}} : !fir.ref<!fir.array<10xf32>>) -> !fir.ref<!fir.array<10xf32>> {implicit = true, name = "b"}
+// CHECK: acc.serial dataOperands({{.*}} : !fir.ref<!fir.array<10xf32>>)
+// CHECK-NOT: acc.copyin
+// CHECK-NOT: acc.copyout
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp
index 95c8d1076ccb0..d543b4f6b99d6 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp
@@ -311,9 +311,22 @@ Operation *ACCImplicitData::getOriginalDataClauseOpForAlias(
if (auto *dataClauseOp = dataClause.getDefiningOp()) {
// Only accept clauses that guarantee that the alias is present.
if (isa<acc::CopyinOp, acc::CreateOp, acc::PresentOp, acc::NoCreateOp,
- acc::DevicePtrOp>(dataClauseOp))
- if (aliasAnalysis.alias(acc::getVar(dataClauseOp), var).isMust())
+ acc::DevicePtrOp>(dataClauseOp)) {
+ Value clauseVar = acc::getVar(dataClauseOp);
+ if (aliasAnalysis.alias(clauseVar, var).isMust())
return dataClauseOp;
+ // For deviceptr clauses, also check if the clause variable is
+ // directly derived from 'var' (e.g., deviceptr operates on
+ // embox(var) — the box wrapping var). This arises when a
+ // subroutine with deviceptr is inlined and the deviceptr's box
+ // and the compute region's ref are different SSA values.
+ if (isa<acc::DevicePtrOp>(dataClauseOp)) {
+ for (Operation *user : var.getUsers()) {
+ if (llvm::is_contained(user->getResults(), clauseVar))
+ return dataClauseOp;
+ }
+ }
+ }
}
}
return nullptr;
>From d4e84659d73f48daa9661312e1b94586b7711412 Mon Sep 17 00:00:00 2001
From: Kazuaki Matsumura <kmatsumura at nvidia.com>
Date: Tue, 21 Apr 2026 15:59:32 -0700
Subject: [PATCH 2/3] Address review: remove FIR-specific references from
upstream MLIR comment
Move the detailed inlining/embox explanation to the FIR test file where
it belongs, and keep the ACCImplicitData.cpp comment generic.
Made-with: Cursor
---
.../Transforms/OpenACC/acc-implicit-data.fir | 17 +++++++++++++----
.../OpenACC/Transforms/ACCImplicitData.cpp | 6 ++----
2 files changed, 15 insertions(+), 8 deletions(-)
diff --git a/flang/test/Transforms/OpenACC/acc-implicit-data.fir b/flang/test/Transforms/OpenACC/acc-implicit-data.fir
index 58c567c831715..e20e3231438d8 100644
--- a/flang/test/Transforms/OpenACC/acc-implicit-data.fir
+++ b/flang/test/Transforms/OpenACC/acc-implicit-data.fir
@@ -398,10 +398,19 @@ func.func private @_FortranAAllocatableSetBounds(!fir.ref<!fir.box<none>>, i32,
// -----
// Test that acc.serial inside acc.data deviceptr generates implicit deviceptr
-// (not copyin) when the deviceptr var is an embox of the ref used by serial.
-// This pattern arises when a subroutine with !$acc data deviceptr(b) wrapping
-// !$acc serial is inlined and the deviceptr var is an embox of the ref used
-// by the serial construct.
+// (not copyin) when the deviceptr clause variable is derived from the ref used
+// by the serial construct (here, via fir.embox wrapping the declared ref).
+// This pattern arises when a subroutine containing:
+// !$acc data deviceptr(b) ← deviceptr operates on embox(b) (a box type)
+// !$acc serial
+// ... uses b ... ← serial uses b directly (a ref type)
+// !$acc end serial
+// !$acc end data
+// is inlined into the caller. After inlining, the deviceptr's box and the
+// serial's ref are different SSA values with different types, so alias
+// analysis returns NoAlias. The pass must recognize that the deviceptr's
+// variable is derived from the ref and generate an implicit deviceptr clause
+// instead of falling back to copyin/copyout.
func.func @test_serial_inside_data_deviceptr_embox() {
%c10 = arith.constant 10 : index
%c1 = arith.constant 1 : index
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp
index d543b4f6b99d6..a62804c5e3130 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp
@@ -316,10 +316,8 @@ Operation *ACCImplicitData::getOriginalDataClauseOpForAlias(
if (aliasAnalysis.alias(clauseVar, var).isMust())
return dataClauseOp;
// For deviceptr clauses, also check if the clause variable is
- // directly derived from 'var' (e.g., deviceptr operates on
- // embox(var) — the box wrapping var). This arises when a
- // subroutine with deviceptr is inlined and the deviceptr's box
- // and the compute region's ref are different SSA values.
+ // directly derived from 'var' (e.g., through a wrapping
+ // operation that produces the clause variable from 'var').
if (isa<acc::DevicePtrOp>(dataClauseOp)) {
for (Operation *user : var.getUsers()) {
if (llvm::is_contained(user->getResults(), clauseVar))
>From 2c22ede67c5adab7457a8ea413f7c19bb8047a80 Mon Sep 17 00:00:00 2001
From: Kazuaki Matsumura <kmatsumura at nvidia.com>
Date: Thu, 23 Apr 2026 18:56:27 -0700
Subject: [PATCH 3/3] Address review: use backward def-chain walk instead of
forward user-walk
Replace the unsound forward user-walk in getOriginalDataClauseOpForAlias
with a backward walk from clauseVar through PartialEntityAccessOpInterface.
Register fir::EmboxOp with PartialEntityAccessModel so the backward walk
can see through embox(declare(ref)) to reach the underlying ref.
Made-with: Cursor
---
.../OpenACC/Support/FIROpenACCOpsInterfaces.h | 1 +
.../OpenACC/Support/FIROpenACCOpsInterfaces.cpp | 6 ++++++
.../Support/RegisterOpenACCExtensions.cpp | 2 ++
.../OpenACC/Transforms/ACCImplicitData.cpp | 16 +++++++++++-----
4 files changed, 20 insertions(+), 5 deletions(-)
diff --git a/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h b/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h
index 854faba45bd2e..3f203531cf00b 100644
--- a/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h
+++ b/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h
@@ -21,6 +21,7 @@ namespace fir {
class AddrOfOp;
class BoxAddrOp;
class DeclareOp;
+class EmboxOp;
class GlobalOp;
} // namespace fir
diff --git a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp
index eeb1b0bdfa4db..c5697a2b46e2c 100644
--- a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp
+++ b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp
@@ -70,6 +70,12 @@ mlir::Value PartialEntityAccessModel<fir::BoxAddrOp>::getBaseEntity(
return mlir::cast<fir::BoxAddrOp>(op).getVal();
}
+template <>
+mlir::Value PartialEntityAccessModel<fir::EmboxOp>::getBaseEntity(
+ mlir::Operation *op) const {
+ return mlir::cast<fir::EmboxOp>(op).getMemref();
+}
+
template <>
mlir::Value PartialEntityAccessModel<hlfir::DesignateOp>::getBaseEntity(
mlir::Operation *op) const {
diff --git a/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp b/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp
index 1766b7cc1d675..b0d7ae2dd6883 100644
--- a/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp
+++ b/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp
@@ -55,6 +55,8 @@ void registerOpenACCExtensions(mlir::DialectRegistry &registry) {
*ctx);
fir::BoxAddrOp::attachInterface<PartialEntityAccessModel<fir::BoxAddrOp>>(
*ctx);
+ fir::EmboxOp::attachInterface<PartialEntityAccessModel<fir::EmboxOp>>(
+ *ctx);
fir::AddrOfOp::attachInterface<AddressOfGlobalModel>(*ctx);
fir::GlobalOp::attachInterface<GlobalVariableModel>(*ctx);
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp
index a62804c5e3130..58306151cef01 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp
@@ -315,12 +315,18 @@ Operation *ACCImplicitData::getOriginalDataClauseOpForAlias(
Value clauseVar = acc::getVar(dataClauseOp);
if (aliasAnalysis.alias(clauseVar, var).isMust())
return dataClauseOp;
- // For deviceptr clauses, also check if the clause variable is
- // directly derived from 'var' (e.g., through a wrapping
- // operation that produces the clause variable from 'var').
+ // For deviceptr clauses, walk the def-chain of the clause
+ // variable backward through PartialEntityAccessOpInterface to
+ // check if 'var' is its base entity.
if (isa<acc::DevicePtrOp>(dataClauseOp)) {
- for (Operation *user : var.getUsers()) {
- if (llvm::is_contained(user->getResults(), clauseVar))
+ Value v = clauseVar;
+ while (auto *defOp = v.getDefiningOp()) {
+ if (auto partialOp =
+ dyn_cast<acc::PartialEntityAccessOpInterface>(defOp))
+ v = partialOp.getBaseEntity();
+ else
+ break;
+ if (v == var)
return dataClauseOp;
}
}
More information about the Mlir-commits
mailing list