[flang-commits] [flang] [mlir][acc] Create UseDeviceOp for both results of hlfir.declare (PR #148017)
via flang-commits
flang-commits at lists.llvm.org
Thu Jul 10 10:59:22 PDT 2025
https://github.com/nvptm created https://github.com/llvm/llvm-project/pull/148017
A sample such as
```
program test
integer :: N = 100
real*8 :: b(-1:N)
!$acc data copy(b)
!$acc host_data use_device(b)
call vadd(b)
!$acc end host_data
!$acc end data
end
```
is lowered to:
```
%13:2 = hlfir.declare %11(%12) {uniq_name = "_QFEb"} : (!fir.ref<!fir.array<?xf64>>, !fir.shapeshift<1>) -> (!fir.box<!fir.array<?xf64>>, !fir.ref<!fir.array<?xf64>>)
%14 = acc.copyin var(%13#0 : !fir.box<!fir.array<?xf64>>) -> !fir.box<!fir.array<?xf64>> {dataClause = #acc<data_clause acc_copy>, name = "b"}
acc.data dataOperands(%14 : !fir.box<!fir.array<?xf64>>) {
%15 = acc.use_device var(%13#0 : !fir.box<!fir.array<?xf64>>) -> !fir.box<!fir.array<?xf64>> {name = "b"}
acc.host_data dataOperands(%15 : !fir.box<!fir.array<?xf64>>) {
fir.call @_QPvadd(%13#1) fastmath<contract> : (!fir.ref<!fir.array<?xf64>>) -> ()
acc.terminator
}
acc.terminator
}
acc.copyout accVar(%14 : !fir.box<!fir.array<?xf64>>) to var(%13#0 : !fir.box<!fir.array<?xf64>>) {dataClause = #acc<data_clause acc_copy>, name = "b"}
```
Note that while the `use_device` clause is applied to `%13#0` (the first result of `hlfir.declare`), the argument passed to `vadd` is `%13#1` (the second result). To avoid problems later in lowering, this change additionally applies the `use_device` clause to `%13#1`, so that the resulting MLIR is:
```
%13:2 = hlfir.declare %11(%12) {uniq_name = "_QFEb"} : (!fir.ref<!fir.array<?xf64>>, !fir.shapeshift<1>) -> (!fir.box<!fir.array<?xf64>>, !fir.ref<!fir.array<?xf64>>)
%14 = acc.copyin var(%13#0 : !fir.box<!fir.array<?xf64>>) -> !fir.box<!fir.array<?xf64>> {dataClause = #acc<data_clause acc_copy>, name = "b"}
acc.data dataOperands(%14 : !fir.box<!fir.array<?xf64>>) {
%15 = acc.use_device var(%13#0 : !fir.box<!fir.array<?xf64>>) -> !fir.box<!fir.array<?xf64>> {name = "b"}
%16 = acc.use_device varPtr(%13#1 : !fir.ref<!fir.array<?xf64>>) -> !fir.ref<!fir.array<?xf64>> {name = "b"}
acc.host_data dataOperands(%15, %16 : !fir.box<!fir.array<?xf64>>, !fir.ref<!fir.array<?xf64>>) {
fir.call @_QPvadd(%13#1) fastmath<contract> : (!fir.ref<!fir.array<?xf64>>) -> ()
acc.terminator
}
acc.terminator
}
acc.copyout accVar(%14 : !fir.box<!fir.array<?xf64>>) to var(%13#0 : !fir.box<!fir.array<?xf64>>) {dataClause = #acc<data_clause acc_copy>, name = "b"}
```
From 0fb4bc6169d76081b641c78db1c032ab876a6905 Mon Sep 17 00:00:00 2001
From: nvpm <pmathew at nvidia.com>
Date: Wed, 9 Jul 2025 23:13:49 -0700
Subject: [PATCH 1/4] use_device for all other results of hlfir.declare
---
flang/lib/Lower/OpenACC.cpp | 17 ++++++++++++++++-
1 file changed, 16 insertions(+), 1 deletion(-)
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 42842bcb41a74..4f637b88fd269 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -724,7 +724,7 @@ genDataOperandOperations(const Fortran::parser::AccObjectList &objectList,
/*treatIndexAsSection=*/true, /*unwrapFirBox=*/unwrapFirBox,
/*genDefaultBounds=*/generateDefaultBounds,
/*strideIncludeLowerExtent=*/strideIncludeLowerExtent);
- LLVM_DEBUG(llvm::dbgs() << __func__ << "\n"; info.dump(llvm::dbgs()));
+ LLVM_DEBUG(llvm::dbgs() << __func__ << "Here \n"; info.dump(llvm::dbgs()));
// If the input value is optional and is not a descriptor, we use the
// rawInput directly.
@@ -738,6 +738,21 @@ genDataOperandOperations(const Fortran::parser::AccObjectList &objectList,
implicit, dataClause, baseAddr.getType(), async, asyncDeviceTypes,
asyncOnlyDeviceTypes, /*unwrapBoxAddr=*/true, info.isPresent);
dataOperands.push_back(op.getAccVar());
+ // If the input value has a descriptor, we need to create a use device op
+ // for the descriptor as well as the base address.
+ if constexpr (std::is_same_v<Op, mlir::acc::UseDeviceOp>) {
+ LLVM_DEBUG(llvm::dbgs() << __func__ << " found usedeviceop \n"; info.dump(llvm::dbgs()));
+ LLVM_DEBUG(llvm::dbgs() << __func__ << " had previously created and added usedeviceop \n"; op.dump());
+ if (mlir::isa<hlfir::DeclareOp>(baseAddr.getDefiningOp())) {
+ Op op = createDataEntryOp<Op>(
+ builder, operandLocation, baseAddr.getDefiningOp()->getResult(1), asFortran, bounds, structured,
+ implicit, dataClause, baseAddr.getDefiningOp()->getResult(1).getType(), async, asyncDeviceTypes,
+ asyncOnlyDeviceTypes, /*unwrapBoxAddr=*/true, info.isPresent);
+ LLVM_DEBUG(llvm::dbgs() << __func__ << " created usedeviceop \n"; op.dump());
+ dataOperands.push_back(op.getAccVar());
+ LLVM_DEBUG(llvm::dbgs() << __func__ << "added usedeviceop on descriptor\n"; info.dump(llvm::dbgs()));
+ }
+ }
}
}
From be91a9ab05a43dff381347f0e6f49dbbe019669f Mon Sep 17 00:00:00 2001
From: nvpm <pmathew at nvidia.com>
Date: Thu, 10 Jul 2025 09:19:58 -0700
Subject: [PATCH 2/4] update tests
---
.../acc-host-data-unwrap-defaultbounds.f90 | 14 +++++++------
flang/test/Lower/OpenACC/acc-host-data.f90 | 21 +++++++++++--------
2 files changed, 20 insertions(+), 15 deletions(-)
diff --git a/flang/test/Lower/OpenACC/acc-host-data-unwrap-defaultbounds.f90 b/flang/test/Lower/OpenACC/acc-host-data-unwrap-defaultbounds.f90
index 164eb32a8f684..2de7cc5761a2b 100644
--- a/flang/test/Lower/OpenACC/acc-host-data-unwrap-defaultbounds.f90
+++ b/flang/test/Lower/OpenACC/acc-host-data-unwrap-defaultbounds.f90
@@ -15,15 +15,17 @@ subroutine acc_host_data()
!$acc end host_data
! CHECK: %[[BOUND:.*]] = acc.bounds lowerbound(%{{.*}} : index) upperbound(%{{.*}} : index) stride(%{{.*}} : index) startIdx(%{{.*}} : index)
-! CHECK: %[[DA:.*]] = acc.use_device varPtr(%[[DECLA]]#0 : !fir.ref<!fir.array<10xf32>>) bounds(%[[BOUND]]) -> !fir.ref<!fir.array<10xf32>> {name = "a"}
-! CHECK: acc.host_data dataOperands(%[[DA]] : !fir.ref<!fir.array<10xf32>>)
+! CHECK: %[[DA0:.*]] = acc.use_device varPtr(%[[DECLA]]#0 : !fir.ref<!fir.array<10xf32>>) bounds(%[[BOUND]]) -> !fir.ref<!fir.array<10xf32>> {name = "a"}
+! CHECK: %[[DA1:.*]] = acc.use_device varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10xf32>>) bounds(%[[BOUND]]) -> !fir.ref<!fir.array<10xf32>> {name = "a"}
+ ! CHECK: acc.host_data dataOperands(%[[DA0]], %[[DA1]] : !fir.ref<!fir.array<10xf32>>, !fir.ref<!fir.array<10xf32>>)
!$acc host_data use_device(a) if_present
!$acc end host_data
! CHECK: %[[BOUND:.*]] = acc.bounds lowerbound(%{{.*}} : index) upperbound(%{{.*}} : index) stride(%{{.*}} : index) startIdx(%{{.*}} : index)
-! CHECK: %[[DA:.*]] = acc.use_device varPtr(%[[DECLA]]#0 : !fir.ref<!fir.array<10xf32>>) bounds(%[[BOUND]]) -> !fir.ref<!fir.array<10xf32>> {name = "a"}
-! CHECK: acc.host_data dataOperands(%[[DA]] : !fir.ref<!fir.array<10xf32>>) {
+! CHECK: %[[DA0:.*]] = acc.use_device varPtr(%[[DECLA]]#0 : !fir.ref<!fir.array<10xf32>>) bounds(%[[BOUND]]) -> !fir.ref<!fir.array<10xf32>> {name = "a"}
+! CHECK: %[[DA1:.*]] = acc.use_device varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10xf32>>) bounds(%[[BOUND]]) -> !fir.ref<!fir.array<10xf32>> {name = "a"}
+! CHECK: acc.host_data dataOperands(%[[DA0]], %[[DA1]] : !fir.ref<!fir.array<10xf32>>{{.*}}) {
! CHECK: } attributes {ifPresent}
!$acc host_data use_device(a) if(ifCondition)
@@ -33,14 +35,14 @@ subroutine acc_host_data()
! CHECK: %[[DA:.*]] = acc.use_device varPtr(%[[DECLA]]#0 : !fir.ref<!fir.array<10xf32>>) bounds(%[[BOUND]]) -> !fir.ref<!fir.array<10xf32>> {name = "a"}
! CHECK: %[[LOAD_IFCOND:.*]] = fir.load %[[DECLIFCOND]]#0 : !fir.ref<!fir.logical<4>>
! CHECK: %[[IFCOND_I1:.*]] = fir.convert %[[LOAD_IFCOND]] : (!fir.logical<4>) -> i1
-! CHECK: acc.host_data if(%[[IFCOND_I1]]) dataOperands(%[[DA]] : !fir.ref<!fir.array<10xf32>>)
+! CHECK: acc.host_data if(%[[IFCOND_I1]]) dataOperands(%[[DA]]{{.*}} : !fir.ref<!fir.array<10xf32>>{{.*}})
!$acc host_data use_device(a) if(.true.)
!$acc end host_data
! CHECK: %[[BOUND:.*]] = acc.bounds lowerbound(%{{.*}} : index) upperbound(%{{.*}} : index) stride(%{{.*}} : index) startIdx(%{{.*}} : index)
! CHECK: %[[DA:.*]] = acc.use_device varPtr(%[[DECLA]]#0 : !fir.ref<!fir.array<10xf32>>) bounds(%[[BOUND]]) -> !fir.ref<!fir.array<10xf32>> {name = "a"}
-! CHECK: acc.host_data dataOperands(%[[DA]] : !fir.ref<!fir.array<10xf32>>)
+! CHECK: acc.host_data dataOperands(%[[DA]]{{.*}} : !fir.ref<!fir.array<10xf32>>{{.*}})
!$acc host_data use_device(a) if(.false.)
a = 1.0
diff --git a/flang/test/Lower/OpenACC/acc-host-data.f90 b/flang/test/Lower/OpenACC/acc-host-data.f90
index 871eabd256ca6..4d09b25b983b9 100644
--- a/flang/test/Lower/OpenACC/acc-host-data.f90
+++ b/flang/test/Lower/OpenACC/acc-host-data.f90
@@ -14,34 +14,37 @@ subroutine acc_host_data()
!$acc host_data use_device(a)
!$acc end host_data
-! CHECK: %[[DA:.*]] = acc.use_device varPtr(%[[DECLA]]#0 : !fir.ref<!fir.array<10xf32>>) -> !fir.ref<!fir.array<10xf32>> {name = "a"}
-! CHECK: acc.host_data dataOperands(%[[DA]] : !fir.ref<!fir.array<10xf32>>)
+! CHECK: %[[DA0:.*]] = acc.use_device varPtr(%[[DECLA]]#0 : !fir.ref<!fir.array<10xf32>>) -> !fir.ref<!fir.array<10xf32>> {name = "a"}
+! CHECK: %[[DA1:.*]] = acc.use_device varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10xf32>>) -> !fir.ref<!fir.array<10xf32>> {name = "a"}
+! CHECK: acc.host_data dataOperands(%[[DA0]], %[[DA1]] : !fir.ref<!fir.array<10xf32>>, !fir.ref<!fir.array<10xf32>>)
!$acc host_data use_device(a) if_present
!$acc end host_data
-! CHECK: %[[DA:.*]] = acc.use_device varPtr(%[[DECLA]]#0 : !fir.ref<!fir.array<10xf32>>) -> !fir.ref<!fir.array<10xf32>> {name = "a"}
-! CHECK: acc.host_data dataOperands(%[[DA]] : !fir.ref<!fir.array<10xf32>>) {
+! CHECK: %[[DA0:.*]] = acc.use_device varPtr(%[[DECLA]]#0 : !fir.ref<!fir.array<10xf32>>) -> !fir.ref<!fir.array<10xf32>> {name = "a"}
+! CHECK: %[[DA1:.*]] = acc.use_device varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10xf32>>) -> !fir.ref<!fir.array<10xf32>> {name = "a"}
+! CHECK: acc.host_data dataOperands(%[[DA0]], %[[DA1]] : !fir.ref<!fir.array<10xf32>>, !fir.ref<!fir.array<10xf32>>)
! CHECK: } attributes {ifPresent}
- !$acc host_data use_device(a) if_present if_present
+ !$acc host_data use_device(a) if_present
!$acc end host_data
-! CHECK: acc.host_data dataOperands(%{{.*}} : !fir.ref<!fir.array<10xf32>>) {
+! CHECK: acc.host_data dataOperands(%{{.*}}{{.*}} : !fir.ref<!fir.array<10xf32>>{{.*}}) {
! CHECK: } attributes {ifPresent}
!$acc host_data use_device(a) if(ifCondition)
!$acc end host_data
-! CHECK: %[[DA:.*]] = acc.use_device varPtr(%[[DECLA]]#0 : !fir.ref<!fir.array<10xf32>>) -> !fir.ref<!fir.array<10xf32>> {name = "a"}
+! CHECK: %[[DA0:.*]] = acc.use_device varPtr(%[[DECLA]]#0 : !fir.ref<!fir.array<10xf32>>) -> !fir.ref<!fir.array<10xf32>> {name = "a"}
+! CHECK: %[[DA1:.*]] = acc.use_device varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10xf32>>) -> !fir.ref<!fir.array<10xf32>> {name = "a"}
! CHECK: %[[LOAD_IFCOND:.*]] = fir.load %[[DECLIFCOND]]#0 : !fir.ref<!fir.logical<4>>
! CHECK: %[[IFCOND_I1:.*]] = fir.convert %[[LOAD_IFCOND]] : (!fir.logical<4>) -> i1
-! CHECK: acc.host_data if(%[[IFCOND_I1]]) dataOperands(%[[DA]] : !fir.ref<!fir.array<10xf32>>)
+! CHECK: acc.host_data if(%[[IFCOND_I1]]) dataOperands(%[[DA0]]{{.*}} : !fir.ref<!fir.array<10xf32>>{{.*}})
!$acc host_data use_device(a) if(.true.)
!$acc end host_data
! CHECK: %[[DA:.*]] = acc.use_device varPtr(%[[DECLA]]#0 : !fir.ref<!fir.array<10xf32>>) -> !fir.ref<!fir.array<10xf32>> {name = "a"}
-! CHECK: acc.host_data dataOperands(%[[DA]] : !fir.ref<!fir.array<10xf32>>)
+! CHECK: acc.host_data dataOperands(%[[DA]]{{.*}} : !fir.ref<!fir.array<10xf32>>{{.*}})
!$acc host_data use_device(a) if(.false.)
a = 1.0
From ec7ad5b11cc46322c6188bafb23f89cb3767ad70 Mon Sep 17 00:00:00 2001
From: nvpm <pmathew at nvidia.com>
Date: Thu, 10 Jul 2025 09:58:36 -0700
Subject: [PATCH 3/4] Remove debug messages. Format file.
---
flang/lib/Lower/OpenACC.cpp | 16 +++++++---------
1 file changed, 7 insertions(+), 9 deletions(-)
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 4f637b88fd269..d43345021063b 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -738,19 +738,17 @@ genDataOperandOperations(const Fortran::parser::AccObjectList &objectList,
implicit, dataClause, baseAddr.getType(), async, asyncDeviceTypes,
asyncOnlyDeviceTypes, /*unwrapBoxAddr=*/true, info.isPresent);
dataOperands.push_back(op.getAccVar());
- // If the input value has a descriptor, we need to create a use device op
- // for the descriptor as well as the base address.
+ // For UseDeviceOp, if operand is one of a pair resulting from a
+ // declare operation, create a UseDeviceOp for the other operand as well.
if constexpr (std::is_same_v<Op, mlir::acc::UseDeviceOp>) {
- LLVM_DEBUG(llvm::dbgs() << __func__ << " found usedeviceop \n"; info.dump(llvm::dbgs()));
- LLVM_DEBUG(llvm::dbgs() << __func__ << " had previously created and added usedeviceop \n"; op.dump());
if (mlir::isa<hlfir::DeclareOp>(baseAddr.getDefiningOp())) {
Op op = createDataEntryOp<Op>(
- builder, operandLocation, baseAddr.getDefiningOp()->getResult(1), asFortran, bounds, structured,
- implicit, dataClause, baseAddr.getDefiningOp()->getResult(1).getType(), async, asyncDeviceTypes,
- asyncOnlyDeviceTypes, /*unwrapBoxAddr=*/true, info.isPresent);
- LLVM_DEBUG(llvm::dbgs() << __func__ << " created usedeviceop \n"; op.dump());
+ builder, operandLocation, baseAddr.getDefiningOp()->getResult(1),
+ asFortran, bounds, structured, implicit, dataClause,
+ baseAddr.getDefiningOp()->getResult(1).getType(), async,
+ asyncDeviceTypes, asyncOnlyDeviceTypes, /*unwrapBoxAddr=*/true,
+ info.isPresent);
dataOperands.push_back(op.getAccVar());
- LLVM_DEBUG(llvm::dbgs() << __func__ << "added usedeviceop on descriptor\n"; info.dump(llvm::dbgs()));
}
}
}
From c385240e757bc2e6c127304a36bae1f1e9d02d60 Mon Sep 17 00:00:00 2001
From: nvpm <pmathew at nvidia.com>
Date: Thu, 10 Jul 2025 09:59:57 -0700
Subject: [PATCH 4/4] Remove debug messages.
---
flang/lib/Lower/OpenACC.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index d43345021063b..23481d0ef7935 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -724,7 +724,7 @@ genDataOperandOperations(const Fortran::parser::AccObjectList &objectList,
/*treatIndexAsSection=*/true, /*unwrapFirBox=*/unwrapFirBox,
/*genDefaultBounds=*/generateDefaultBounds,
/*strideIncludeLowerExtent=*/strideIncludeLowerExtent);
- LLVM_DEBUG(llvm::dbgs() << __func__ << "Here \n"; info.dump(llvm::dbgs()));
+ LLVM_DEBUG(llvm::dbgs() << __func__ << "\n"; info.dump(llvm::dbgs()));
// If the input value is optional and is not a descriptor, we use the
// rawInput directly.
More information about the flang-commits
mailing list