[llvm-branch-commits] [flang] [mlir] [MLIR][OpenMP] Use map format to represent use_device_{addr, ptr} (PR #109810)
Sergio Afonso via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Mon Sep 30 04:43:44 PDT 2024
https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/109810
>From f61e3a60d6f494d08b58ded9b802f2b3d92b728f Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Fri, 20 Sep 2024 17:11:34 +0100
Subject: [PATCH] [MLIR][OpenMP] Use map format to represent
use_device_{addr,ptr}
This patch updates the `omp.target_data` operation to use the same formatting
as `map` clauses on `omp.target` for `use_device_addr` and `use_device_ptr`.
This is done so that the mapping enforced between operation arguments and their
associated entry block arguments is explicit.
This is achieved by marking these clauses as defining entry block arguments and
adjusting the printers/parsers accordingly.
As a result of this change, block arguments for `use_device_addr` now come before
those for `use_device_ptr`, reversing the previous (undocumented) ordering. Some
unit tests are updated to reflect this reordering, in addition to those updated
for the new format.
---
.../Fir/convert-to-llvm-openmp-and-fir.fir | 5 +-
flang/test/Lower/OpenMP/target.f90 | 6 +-
.../use-device-ptr-to-use-device-addr.f90 | 12 +--
.../mlir/Dialect/OpenMP/OpenMPClauses.td | 28 ++++++-
mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 6 ++
.../Dialect/OpenMP/OpenMPOpsInterfaces.td | 37 ++++++++-
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 43 +++++++++++
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 77 ++++++++++++-------
mlir/test/Dialect/OpenMP/ops.mlir | 6 +-
mlir/test/Target/LLVMIR/omptarget-llvm.mlir | 19 ++---
.../openmp-target-use-device-nested.mlir | 3 +-
11 files changed, 179 insertions(+), 63 deletions(-)
diff --git a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
index 4d226eaa754c12..61f18008633d50 100644
--- a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
+++ b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
@@ -429,13 +429,14 @@ func.func @_QPopenmp_target_data_region() {
func.func @_QPomp_target_data_empty() {
%0 = fir.alloca !fir.array<1024xi32> {bindc_name = "a", uniq_name = "_QFomp_target_data_emptyEa"}
- omp.target_data use_device_addr(%0 : !fir.ref<!fir.array<1024xi32>>) {
+ // %arg0 is the entry block argument defined by use_device_addr; it is unused here.
+ omp.target_data use_device_addr(%0 -> %arg0 : !fir.ref<!fir.array<1024xi32>>) {
+ omp.terminator
}
return
}
// CHECK-LABEL: llvm.func @_QPomp_target_data_empty
-// CHECK: omp.target_data use_device_addr(%1 : !llvm.ptr) {
+// CHECK: omp.target_data use_device_addr(%1 -> %{{.*}} : !llvm.ptr) {
// CHECK: }
// -----
diff --git a/flang/test/Lower/OpenMP/target.f90 b/flang/test/Lower/OpenMP/target.f90
index dedce581436490..ab33b6b3808315 100644
--- a/flang/test/Lower/OpenMP/target.f90
+++ b/flang/test/Lower/OpenMP/target.f90
@@ -506,9 +506,8 @@ subroutine omp_target_device_ptr
type(c_ptr) :: a
integer, target :: b
!CHECK: %[[MAP:.*]] = omp.map.info var_ptr({{.*}}) map_clauses(tofrom) capture(ByRef) -> {{.*}} {name = "a"}
- !CHECK: omp.target_data map_entries(%[[MAP]]{{.*}}) use_device_ptr({{.*}})
+ ! VAL_1 captures the entry block argument bound by use_device_ptr.
+ !CHECK: omp.target_data map_entries(%[[MAP]]{{.*}}) use_device_ptr({{.*}} -> %[[VAL_1:.*]] : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>)
!$omp target data map(tofrom: a) use_device_ptr(a)
- !CHECK: ^bb0(%[[VAL_1:.*]]: !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>):
!CHECK: {{.*}} = fir.coordinate_of %[[VAL_1:.*]], {{.*}} : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.field) -> !fir.ref<i64>
a = c_loc(b)
!CHECK: omp.terminator
@@ -529,9 +528,8 @@ subroutine omp_target_device_addr
!CHECK: %[[MAP:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.box<!fir.ptr<i32>>) map_clauses(tofrom) capture(ByRef) members(%[[MAP_MEMBERS]] : [0] : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.ref<!fir.box<!fir.ptr<i32>>> {name = "a"}
!CHECK: %[[DEV_ADDR_MEMBERS:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, i32) var_ptr_ptr({{.*}} : !fir.llvm_ptr<!fir.ref<i32>>) map_clauses(tofrom) capture(ByRef) -> !fir.llvm_ptr<!fir.ref<i32>> {name = ""}
!CHECK: %[[DEV_ADDR:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.box<!fir.ptr<i32>>) map_clauses(tofrom) capture(ByRef) members(%[[DEV_ADDR_MEMBERS]] : [0] : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.ref<!fir.box<!fir.ptr<i32>>> {name = "a"}
- !CHECK: omp.target_data map_entries(%[[MAP_MEMBERS]], %[[MAP]] : {{.*}}) use_device_addr(%[[DEV_ADDR_MEMBERS]], %[[DEV_ADDR]] : {{.*}}) {
+ ! ARG_0 and ARG_1 capture the entry block arguments bound by use_device_addr.
+ !CHECK: omp.target_data map_entries(%[[MAP_MEMBERS]], %[[MAP]] : {{.*}}) use_device_addr(%[[DEV_ADDR_MEMBERS]] -> %[[ARG_0:.*]], %[[DEV_ADDR]] -> %[[ARG_1:.*]] : !fir.llvm_ptr<!fir.ref<i32>>, !fir.ref<!fir.box<!fir.ptr<i32>>>) {
!$omp target data map(tofrom: a) use_device_addr(a)
- !CHECK: ^bb0(%[[ARG_0:.*]]: !fir.llvm_ptr<!fir.ref<i32>>, %[[ARG_1:.*]]: !fir.ref<!fir.box<!fir.ptr<i32>>>):
!CHECK: %[[VAL_1_DECL:.*]]:2 = hlfir.declare %[[ARG_1]] {fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QFomp_target_device_addrEa"} : (!fir.ref<!fir.box<!fir.ptr<i32>>>) -> (!fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.ref<!fir.box<!fir.ptr<i32>>>)
!CHECK: %[[C10:.*]] = arith.constant 10 : i32
!CHECK: %[[A_BOX:.*]] = fir.load %[[VAL_1_DECL]]#0 : !fir.ref<!fir.box<!fir.ptr<i32>>>
diff --git a/flang/test/Lower/OpenMP/use-device-ptr-to-use-device-addr.f90 b/flang/test/Lower/OpenMP/use-device-ptr-to-use-device-addr.f90
index 085f5419fa7f88..cb26246a6e80f0 100644
--- a/flang/test/Lower/OpenMP/use-device-ptr-to-use-device-addr.f90
+++ b/flang/test/Lower/OpenMP/use-device-ptr-to-use-device-addr.f90
@@ -6,8 +6,7 @@
! use_device_ptr to use_device_addr works, without breaking any functionality.
!CHECK: func.func @{{.*}}only_use_device_ptr()
-!CHECK: omp.target_data use_device_addr(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>) use_device_ptr(%{{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
-!CHECK: ^bb0(%{{.*}}: !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, %{{.*}}: !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, %{{.*}}: !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, %{{.*}}: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, %{{.*}}: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>):
+!CHECK: omp.target_data use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>) use_device_ptr(%{{.*}} -> %{{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
subroutine only_use_device_ptr
use iso_c_binding
integer, pointer, dimension(:) :: array
@@ -19,8 +18,7 @@ subroutine only_use_device_ptr
end subroutine
!CHECK: func.func @{{.*}}mix_use_device_ptr_and_addr()
-!CHECK: omp.target_data use_device_addr(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) use_device_ptr({{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
-!CHECK: ^bb0(%{{.*}}: !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, %{{.*}}: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, %{{.*}}: !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, %{{.*}}: !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, %{{.*}}: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>):
+!CHECK: omp.target_data use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) use_device_ptr(%{{.*}} -> %{{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
subroutine mix_use_device_ptr_and_addr
use iso_c_binding
integer, pointer, dimension(:) :: array
@@ -32,8 +30,7 @@ subroutine mix_use_device_ptr_and_addr
end subroutine
!CHECK: func.func @{{.*}}only_use_device_addr()
- !CHECK: omp.target_data use_device_addr(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>) {
- !CHECK: ^bb0(%{{.*}}: !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, %{{.*}}: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, %{{.*}}: !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, %{{.*}}: !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, %{{.*}}: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>):
+ !CHECK: omp.target_data use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>) {
subroutine only_use_device_addr
use iso_c_binding
integer, pointer, dimension(:) :: array
@@ -45,8 +42,7 @@ subroutine only_use_device_addr
end subroutine
!CHECK: func.func @{{.*}}mix_use_device_ptr_and_addr_and_map()
- !CHECK: omp.target_data map_entries(%{{.*}}, %{{.*}} : !fir.ref<i32>, !fir.ref<i32>) use_device_addr(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) use_device_ptr(%{{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
- !CHECK: ^bb0(%{{.*}}: !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, %{{.*}}: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, %{{.*}}: !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, %{{.*}}: !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, %{{.*}}: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>):
+ !CHECK: omp.target_data map_entries(%{{.*}}, %{{.*}} : !fir.ref<i32>, !fir.ref<i32>) use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) use_device_ptr(%{{.*}} -> %{{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
subroutine mix_use_device_ptr_and_addr_and_map
use iso_c_binding
integer :: i, j
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 97e8b368050725..886554f66afffc 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1209,18 +1209,28 @@ class OpenMP_UseDeviceAddrClauseSkip<
bit description = false, bit extraClassDeclaration = false
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
extraClassDeclaration> {
+ let traits = [
+ BlockArgOpenMPOpInterface
+ ];
+
let arguments = (ins
Variadic<OpenMP_PointerLikeType>:$use_device_addr_vars
);
- let optAssemblyFormat = [{
- `use_device_addr` `(` $use_device_addr_vars `:` type($use_device_addr_vars) `)`
+ let extraClassDeclaration = [{
+ // The number of entry block arguments defined by this clause equals the
+ // number of `use_device_addr_vars` operands.
+ unsigned numUseDeviceAddrBlockArgs() {
+ return getUseDeviceAddrVars().size();
+ }
}];
let description = [{
The optional `use_device_addr_vars` specifies the address of the objects in
the device data environment.
}];
+
+ // Assembly format not defined because this clause must be processed together
+ // with the first region of the operation, as it defines entry block
+ // arguments.
}
def OpenMP_UseDeviceAddrClause : OpenMP_UseDeviceAddrClauseSkip<>;
@@ -1234,18 +1244,28 @@ class OpenMP_UseDevicePtrClauseSkip<
bit description = false, bit extraClassDeclaration = false
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
extraClassDeclaration> {
+ let traits = [
+ BlockArgOpenMPOpInterface
+ ];
+
let arguments = (ins
Variadic<OpenMP_PointerLikeType>:$use_device_ptr_vars
);
- let optAssemblyFormat = [{
- `use_device_ptr` `(` $use_device_ptr_vars `:` type($use_device_ptr_vars) `)`
+ let extraClassDeclaration = [{
+ // The number of entry block arguments defined by this clause equals the
+ // number of `use_device_ptr_vars` operands.
+ unsigned numUseDevicePtrBlockArgs() {
+ return getUseDevicePtrVars().size();
+ }
}];
let description = [{
The optional `use_device_ptr_vars` specifies the device pointers to the
corresponding list items in the device data environment.
}];
+
+ // Assembly format not defined because this clause must be processed together
+ // with the first region of the operation, as it defines entry block
+ // arguments.
}
def OpenMP_UseDevicePtrClause : OpenMP_UseDevicePtrClauseSkip<>;
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index e58ccc4e930210..d2a2b44c042fb7 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -987,6 +987,12 @@ def TargetDataOp: OpenMP_Op<"target_data", traits = [
OpBuilder<(ins CArg<"const TargetDataOperands &">:$clauses)>
];
+ // `use_device_addr` and `use_device_ptr` define entry block arguments, so
+ // they are parsed/printed by a custom directive together with the region.
+ let assemblyFormat = clausesAssemblyFormat # [{
+ custom<UseDeviceAddrUseDevicePtrRegion>(
+ $region, $use_device_addr_vars, type($use_device_addr_vars),
+ $use_device_ptr_vars, type($use_device_ptr_vars)) attr-dict
+ }];
+
let hasVerifier = 1;
}
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
index 1aaa4060793995..93ffa35a636911 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
@@ -45,6 +45,14 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
"unsigned", "numTaskReductionBlockArgs", (ins), [{}], [{
return 0;
}]>,
+ InterfaceMethod<"Get number of block arguments defined by `use_device_addr`.",
+ "unsigned", "numUseDeviceAddrBlockArgs", (ins), [{}], [{
+ return 0;
+ }]>,
+ InterfaceMethod<"Get number of block arguments defined by `use_device_ptr`.",
+ "unsigned", "numUseDevicePtrBlockArgs", (ins), [{}], [{
+ return 0;
+ }]>,
// Unified access methods for clause-associated entry block arguments.
InterfaceMethod<"Get start index of block arguments defined by `in_reduction`.",
@@ -72,6 +80,16 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
return iface.getReductionBlockArgsStart() + $_op.numReductionBlockArgs();
}]>,
+ // `use_device_addr` block arguments follow the `task_reduction` ones, and
+ // `use_device_ptr` block arguments follow the `use_device_addr` ones.
+ InterfaceMethod<"Get start index of block arguments defined by `use_device_addr`.",
+ "unsigned", "getUseDeviceAddrBlockArgsStart", (ins), [{
+ auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
+ return iface.getTaskReductionBlockArgsStart() + $_op.numTaskReductionBlockArgs();
+ }]>,
+ InterfaceMethod<"Get start index of block arguments defined by `use_device_ptr`.",
+ "unsigned", "getUseDevicePtrBlockArgsStart", (ins), [{
+ auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
+ return iface.getUseDeviceAddrBlockArgsStart() + $_op.numUseDeviceAddrBlockArgs();
+ }]>,
InterfaceMethod<"Get block arguments defined by `in_reduction`.",
"::llvm::MutableArrayRef<::mlir::BlockArgument>",
@@ -109,13 +127,30 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
iface.getTaskReductionBlockArgsStart(),
$_op.numTaskReductionBlockArgs());
}]>,
+ InterfaceMethod<"Get block arguments defined by `use_device_addr`.",
+ "::llvm::MutableArrayRef<::mlir::BlockArgument>",
+ "getUseDeviceAddrBlockArgs", (ins), [{
+ auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
+ return $_op->getRegion(0).getArguments().slice(
+ iface.getUseDeviceAddrBlockArgsStart(),
+ $_op.numUseDeviceAddrBlockArgs());
+ }]>,
+ InterfaceMethod<"Get block arguments defined by `use_device_ptr`.",
+ "::llvm::MutableArrayRef<::mlir::BlockArgument>",
+ "getUseDevicePtrBlockArgs", (ins), [{
+ auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
+ return $_op->getRegion(0).getArguments().slice(
+ iface.getUseDevicePtrBlockArgsStart(),
+ $_op.numUseDevicePtrBlockArgs());
+ }]>,
];
let verify = [{
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>($_op);
unsigned expectedArgs = iface.numInReductionBlockArgs() +
iface.numMapBlockArgs() + iface.numPrivateBlockArgs() +
- iface.numReductionBlockArgs() + iface.numTaskReductionBlockArgs();
+ iface.numReductionBlockArgs() + iface.numTaskReductionBlockArgs() +
+ iface.numUseDeviceAddrBlockArgs() + iface.numUseDevicePtrBlockArgs();
if ($_op->getRegion(0).getNumArguments() < expectedArgs)
return $_op->emitOpError() << "expected at least " << expectedArgs
<< " entry block argument(s)";
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index a96f70807cc813..220eb848ab4de2 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -504,6 +504,8 @@ struct AllRegionParseArgs {
std::optional<PrivateParseArgs> privateArgs;
std::optional<ReductionParseArgs> reductionArgs;
std::optional<ReductionParseArgs> taskReductionArgs;
+ std::optional<MapParseArgs> useDeviceAddrArgs;
+ std::optional<MapParseArgs> useDevicePtrArgs;
};
} // namespace
@@ -648,6 +650,16 @@ static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region,
return parser.emitError(parser.getCurrentLocation())
<< "invalid `task_reduction` format";
+ if (failed(parseBlockArgClause(parser, entryBlockArgs, "use_device_addr",
+ args.useDeviceAddrArgs)))
+ return parser.emitError(parser.getCurrentLocation())
+ << "invalid `use_device_addr` format";
+
+ if (failed(parseBlockArgClause(parser, entryBlockArgs, "use_device_ptr",
+ args.useDevicePtrArgs)))
+ return parser.emitError(parser.getCurrentLocation())
+ << "invalid `use_device_ptr` format";
+
return parser.parseRegion(region, entryBlockArgs);
}
@@ -735,6 +747,18 @@ static ParseResult parseTaskReductionRegion(
return parseBlockArgRegion(parser, region, args);
}
+static ParseResult parseUseDeviceAddrUseDevicePtrRegion(
+ OpAsmParser &parser, Region &region,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &useDeviceAddrVars,
+ SmallVectorImpl<Type> &useDeviceAddrTypes,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &useDevicePtrVars,
+ SmallVectorImpl<Type> &useDevicePtrTypes) {
+ AllRegionParseArgs args;
+ args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
+ args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
+ return parseBlockArgRegion(parser, region, args);
+}
+
//===----------------------------------------------------------------------===//
// Printers for operations including clauses that define entry block arguments.
//===----------------------------------------------------------------------===//
@@ -767,6 +791,8 @@ struct AllRegionPrintArgs {
std::optional<PrivatePrintArgs> privateArgs;
std::optional<ReductionPrintArgs> reductionArgs;
std::optional<ReductionPrintArgs> taskReductionArgs;
+ std::optional<MapPrintArgs> useDeviceAddrArgs;
+ std::optional<MapPrintArgs> useDevicePtrArgs;
};
} // namespace
@@ -849,6 +875,11 @@ static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
printBlockArgClause(p, ctx, "task_reduction",
iface.getTaskReductionBlockArgs(),
args.taskReductionArgs);
+ printBlockArgClause(p, ctx, "use_device_addr",
+ iface.getUseDeviceAddrBlockArgs(),
+ args.useDeviceAddrArgs);
+ printBlockArgClause(p, ctx, "use_device_ptr",
+ iface.getUseDevicePtrBlockArgs(), args.useDevicePtrArgs);
p.printRegion(region, /*printEntryBlockArgs=*/false);
}
@@ -925,6 +956,18 @@ static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op,
printBlockArgRegion(p, op, region, args);
}
+static void printUseDeviceAddrUseDevicePtrRegion(OpAsmPrinter &p, Operation *op,
+ Region &region,
+ ValueRange useDeviceAddrVars,
+ TypeRange useDeviceAddrTypes,
+ ValueRange useDevicePtrVars,
+ TypeRange useDevicePtrTypes) {
+ AllRegionPrintArgs args;
+ args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
+ args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
+ printBlockArgRegion(p, op, region, args);
+}
+
/// Verifies Reduction Clause
static LogicalResult
verifyReductionVarList(Operation *op, std::optional<ArrayAttr> reductionSyms,
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index bbc0b518e99bfc..0a808160eef211 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2449,8 +2449,8 @@ static void collectMapDataFromMapOperands(
}
};
- addDevInfos(useDevPtrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);
addDevInfos(useDevAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
+ addDevInfos(useDevPtrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);
}
static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp) {
@@ -3056,6 +3056,31 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
return combinedInfo;
};
+ // Define a lambda to apply mappings between use_device_addr and
+ // use_device_ptr base pointers, and their associated block arguments.
+ auto mapUseDevice =
+ [&moduleTranslation](
+ llvm::OpenMPIRBuilder::DeviceInfoTy type,
+ llvm::ArrayRef<BlockArgument> blockArgs,
+ llvm::OpenMPIRBuilder::MapValuesArrayTy &basePointers,
+ llvm::OpenMPIRBuilder::MapDeviceInfoArrayTy &devicePointers,
+ llvm::function_ref<llvm::Value *(llvm::Value *)> mapper = nullptr) {
+ // Get a range to iterate over `basePointers` after filtering based on
+ // `devicePointers` and the given device info type.
+ auto basePtrRange = llvm::map_range(
+ llvm::make_filter_range(
+ llvm::zip_equal(basePointers, devicePointers),
+ [type](auto x) { return std::get<1>(x) == type; }),
+ [](auto x) { return std::get<0>(x); });
+
+ // Map block arguments to the corresponding processed base pointer. If
+ // a mapper is not specified, map the block argument to the base pointer
+ // directly.
+ for (auto [arg, basePointer] : llvm::zip_equal(blockArgs, basePtrRange))
+ moduleTranslation.mapValue(arg, mapper ? mapper(basePointer)
+ : basePointer);
+ };
+
llvm::OpenMPIRBuilder::TargetDataInfo info(/*RequiresDevicePointerInfo=*/true,
/*SeparateBeginEndCalls=*/true);
@@ -3064,29 +3089,28 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
auto bodyGenCB = [&](InsertPointTy codeGenIP, BodyGenTy bodyGenType) {
assert(isa<omp::TargetDataOp>(op) &&
"BodyGen requested for non TargetDataOp");
+ auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(op);
Region &region = cast<omp::TargetDataOp>(op).getRegion();
switch (bodyGenType) {
case BodyGenTy::Priv:
// Check if any device ptr/addr info is available
if (!info.DevicePtrInfoMap.empty()) {
builder.restoreIP(codeGenIP);
- unsigned argIndex = 0;
- for (auto [basePointer, devicePointer] : llvm::zip_equal(
- combinedInfo.BasePointers, combinedInfo.DevicePointers)) {
- if (devicePointer == llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer) {
- const auto &arg = region.front().getArgument(argIndex);
- moduleTranslation.mapValue(
- arg, info.DevicePtrInfoMap[basePointer].second);
- argIndex++;
- } else if (devicePointer ==
- llvm::OpenMPIRBuilder::DeviceInfoTy::Address) {
- const auto &arg = region.front().getArgument(argIndex);
- auto *loadInst = builder.CreateLoad(
- builder.getPtrTy(), info.DevicePtrInfoMap[basePointer].second);
- moduleTranslation.mapValue(arg, loadInst);
- argIndex++;
- }
- }
+
+ mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
+ blockArgIface.getUseDeviceAddrBlockArgs(),
+ combinedInfo.BasePointers, combinedInfo.DevicePointers,
+ [&](llvm::Value *basePointer) -> llvm::Value * {
+ return builder.CreateLoad(
+ builder.getPtrTy(),
+ info.DevicePtrInfoMap[basePointer].second);
+ });
+ mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
+ blockArgIface.getUseDevicePtrBlockArgs(),
+ combinedInfo.BasePointers, combinedInfo.DevicePointers,
+ [&](llvm::Value *basePointer) {
+ return info.DevicePtrInfoMap[basePointer].second;
+ });
bodyGenStatus = inlineConvertOmpRegions(region, "omp.data.region",
builder, moduleTranslation);
@@ -3101,17 +3125,14 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
// For device pass, if use_device_ptr(addr) mappings were present,
// we need to link them here before codegen.
if (ompBuilder->Config.IsTargetDevice.value_or(false)) {
- unsigned argIndex = 0;
- for (auto [basePointer, devicePointer] :
- llvm::zip_equal(mapData.BasePointers, mapData.DevicePointers)) {
- if (devicePointer == llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer ||
- devicePointer == llvm::OpenMPIRBuilder::DeviceInfoTy::Address) {
- const auto &arg = region.front().getArgument(argIndex);
- moduleTranslation.mapValue(arg, basePointer);
- argIndex++;
- }
- }
+ mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
+ blockArgIface.getUseDeviceAddrBlockArgs(),
+ mapData.BasePointers, mapData.DevicePointers);
+ mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
+ blockArgIface.getUseDevicePtrBlockArgs(),
+ mapData.BasePointers, mapData.DevicePointers);
}
+
bodyGenStatus = inlineConvertOmpRegions(region, "omp.data.region",
builder, moduleTranslation);
}
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 4b1468a6761e66..ce3351ba1149f3 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -864,9 +864,11 @@ func.func @omp_target_data (%if_cond : i1, %device : si32, %device_ptr: memref<i
omp.target_data if(%if_cond) device(%device : si32) map_entries(%mapv1 : memref<?xi32>){}
// CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%[[VAL_2:.*]] : memref<?xi32>, tensor<?xi32>) map_clauses(close, present, to) capture(ByRef) -> memref<?xi32> {name = ""}
- // CHECK: omp.target_data map_entries(%[[MAP_A]] : memref<?xi32>) use_device_addr(%[[VAL_4:.*]] : memref<?xi32>) use_device_ptr(%[[VAL_3:.*]] : memref<i32>)
+ // CHECK: omp.target_data map_entries(%[[MAP_A]] : memref<?xi32>) use_device_addr(%[[VAL_3:.*]] -> %{{.*}} : memref<?xi32>) use_device_ptr(%[[VAL_4:.*]] -> %{{.*}} : memref<i32>)
%mapv2 = omp.map.info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(close, present, to) capture(ByRef) -> memref<?xi32> {name = ""}
- omp.target_data use_device_ptr(%device_ptr : memref<i32>) use_device_addr(%device_addr : memref<?xi32>) map_entries(%mapv2 : memref<?xi32>) {}
+ // %arg0 is bound by use_device_addr and %arg1 by use_device_ptr.
+ omp.target_data map_entries(%mapv2 : memref<?xi32>) use_device_addr(%device_addr -> %arg0 : memref<?xi32>) use_device_ptr(%device_ptr -> %arg1 : memref<i32>) {
+ omp.terminator
+ }
// CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%[[VAL_1:.*]] : memref<?xi32>, tensor<?xi32>) map_clauses(tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
// CHECK: %[[MAP_B:.*]] = omp.map.info var_ptr(%[[VAL_2:.*]] : memref<?xi32>, tensor<?xi32>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32> {name = ""}
diff --git a/mlir/test/Target/LLVMIR/omptarget-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-llvm.mlir
index 458d2f28a78f8d..654763c577d1af 100644
--- a/mlir/test/Target/LLVMIR/omptarget-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-llvm.mlir
@@ -210,8 +210,7 @@ llvm.func @_QPopenmp_target_use_dev_ptr() {
%a = llvm.alloca %0 x !llvm.ptr : (i64) -> !llvm.ptr
%map1 = omp.map.info var_ptr(%a : !llvm.ptr, !llvm.ptr) map_clauses(from) capture(ByRef) -> !llvm.ptr {name = ""}
%map2 = omp.map.info var_ptr(%a : !llvm.ptr, !llvm.ptr) map_clauses(from) capture(ByRef) -> !llvm.ptr {name = ""}
- omp.target_data map_entries(%map1 : !llvm.ptr) use_device_ptr(%map2 : !llvm.ptr) {
- ^bb0(%arg0: !llvm.ptr):
+ omp.target_data map_entries(%map1 : !llvm.ptr) use_device_ptr(%map2 -> %arg0 : !llvm.ptr) {
%1 = llvm.mlir.constant(10 : i32) : i32
%2 = llvm.load %arg0 : !llvm.ptr -> !llvm.ptr
llvm.store %1, %2 : i32, !llvm.ptr
@@ -255,8 +254,7 @@ llvm.func @_QPopenmp_target_use_dev_addr() {
%a = llvm.alloca %0 x !llvm.ptr : (i64) -> !llvm.ptr
%map = omp.map.info var_ptr(%a : !llvm.ptr, !llvm.ptr) map_clauses(from) capture(ByRef) -> !llvm.ptr {name = ""}
%map2 = omp.map.info var_ptr(%a : !llvm.ptr, !llvm.ptr) map_clauses(from) capture(ByRef) -> !llvm.ptr {name = ""}
- omp.target_data map_entries(%map : !llvm.ptr) use_device_addr(%map2 : !llvm.ptr) {
- ^bb0(%arg0: !llvm.ptr):
+ omp.target_data map_entries(%map : !llvm.ptr) use_device_addr(%map2 -> %arg0 : !llvm.ptr) {
%1 = llvm.mlir.constant(10 : i32) : i32
%2 = llvm.load %arg0 : !llvm.ptr -> !llvm.ptr
llvm.store %1, %2 : i32, !llvm.ptr
@@ -298,8 +296,7 @@ llvm.func @_QPopenmp_target_use_dev_addr_no_ptr() {
%a = llvm.alloca %0 x i32 : (i64) -> !llvm.ptr
%map = omp.map.info var_ptr(%a : !llvm.ptr, i32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
%map2 = omp.map.info var_ptr(%a : !llvm.ptr, i32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
- omp.target_data map_entries(%map : !llvm.ptr) use_device_addr(%map2 : !llvm.ptr) {
- ^bb0(%arg0: !llvm.ptr):
+ omp.target_data map_entries(%map : !llvm.ptr) use_device_addr(%map2 -> %arg0 : !llvm.ptr) {
%1 = llvm.mlir.constant(10 : i32) : i32
llvm.store %1, %arg0 : i32, !llvm.ptr
omp.terminator
@@ -341,8 +338,7 @@ llvm.func @_QPopenmp_target_use_dev_addr_nomap() {
%b = llvm.alloca %0 x !llvm.ptr : (i64) -> !llvm.ptr
%map = omp.map.info var_ptr(%b : !llvm.ptr, !llvm.ptr) map_clauses(from) capture(ByRef) -> !llvm.ptr {name = ""}
%map2 = omp.map.info var_ptr(%a : !llvm.ptr, !llvm.ptr) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
- omp.target_data map_entries(%map : !llvm.ptr) use_device_addr(%map2 : !llvm.ptr) {
- ^bb0(%arg0: !llvm.ptr):
+ omp.target_data map_entries(%map : !llvm.ptr) use_device_addr(%map2 -> %arg0 : !llvm.ptr) {
%2 = llvm.mlir.constant(10 : i32) : i32
%3 = llvm.load %arg0 : !llvm.ptr -> !llvm.ptr
llvm.store %2, %3 : i32, !llvm.ptr
@@ -400,13 +396,12 @@ llvm.func @_QPopenmp_target_use_dev_both() {
%map1 = omp.map.info var_ptr(%b : !llvm.ptr, !llvm.ptr) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
%map2 = omp.map.info var_ptr(%a : !llvm.ptr, !llvm.ptr) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
%map3 = omp.map.info var_ptr(%b : !llvm.ptr, !llvm.ptr) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
- omp.target_data map_entries(%map, %map1 : !llvm.ptr, !llvm.ptr) use_device_ptr(%map2 : !llvm.ptr) use_device_addr(%map3 : !llvm.ptr) {
- ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
+ // %arg0 is bound by use_device_addr (%map3) and %arg1 by use_device_ptr (%map2).
+ omp.target_data map_entries(%map, %map1 : !llvm.ptr, !llvm.ptr) use_device_addr(%map3 -> %arg0 : !llvm.ptr) use_device_ptr(%map2 -> %arg1 : !llvm.ptr) {
%2 = llvm.mlir.constant(10 : i32) : i32
- %3 = llvm.load %arg0 : !llvm.ptr -> !llvm.ptr
+ %3 = llvm.load %arg1 : !llvm.ptr -> !llvm.ptr
llvm.store %2, %3 : i32, !llvm.ptr
%4 = llvm.mlir.constant(20 : i32) : i32
- %5 = llvm.load %arg1 : !llvm.ptr -> !llvm.ptr
+ %5 = llvm.load %arg0 : !llvm.ptr -> !llvm.ptr
llvm.store %4, %5 : i32, !llvm.ptr
omp.terminator
}
diff --git a/mlir/test/Target/LLVMIR/openmp-target-use-device-nested.mlir b/mlir/test/Target/LLVMIR/openmp-target-use-device-nested.mlir
index a4f8098879a9f8..3a71778e7d0a7e 100644
--- a/mlir/test/Target/LLVMIR/openmp-target-use-device-nested.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-target-use-device-nested.mlir
@@ -22,8 +22,7 @@ module attributes {omp.is_target_device = true } {
%0 = llvm.mlir.constant(1 : i64) : i64
%a = llvm.alloca %0 x !llvm.ptr : (i64) -> !llvm.ptr
%map = omp.map.info var_ptr(%a : !llvm.ptr, !llvm.ptr) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
- omp.target_data use_device_ptr(%map : !llvm.ptr) {
- ^bb0(%arg0: !llvm.ptr):
+ // %arg0 is the entry block argument bound by use_device_ptr.
+ omp.target_data use_device_ptr(%map -> %arg0 : !llvm.ptr) {
%map1 = omp.map.info var_ptr(%arg0 : !llvm.ptr, !llvm.ptr) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
omp.target map_entries(%map1 -> %arg1 : !llvm.ptr){
%1 = llvm.mlir.constant(999 : i32) : i32
More information about the llvm-branch-commits
mailing list