[flang-commits] [flang] [flang][openacc] Add support for allocatable and pointer arrays in reduction (PR #68261)
via flang-commits
flang-commits at lists.llvm.org
Wed Oct 4 14:12:32 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-fir-hlfir
<details>
<summary>Changes</summary>
This patch adds support for allocatable and pointer arrays in the reduction recipe lowering.
---
Full diff: https://github.com/llvm/llvm-project/pull/68261.diff
4 Files Affected:
- (modified) flang/include/flang/Optimizer/Dialect/FIRType.h (+3)
- (modified) flang/lib/Lower/OpenACC.cpp (+42-8)
- (modified) flang/lib/Optimizer/Dialect/FIRType.cpp (+12)
- (modified) flang/test/Lower/OpenACC/acc-reduction.f90 (+71-6)
``````````diff
diff --git a/flang/include/flang/Optimizer/Dialect/FIRType.h b/flang/include/flang/Optimizer/Dialect/FIRType.h
index 4eea6d2bbb3ee36..2abcc6547bbabf6 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRType.h
+++ b/flang/include/flang/Optimizer/Dialect/FIRType.h
@@ -336,6 +336,9 @@ bool isAssumedType(mlir::Type ty);
/// Return true iff `ty` is the type of an assumed shape array.
bool isAssumedShape(mlir::Type ty);
+/// Return true iff `ty` is the type of an allocatable or pointer array.
+bool isAllocatableOrPointerArray(mlir::Type ty);
+
/// Return true iff `boxTy` wraps a record type or an unlimited polymorphic
/// entity. Polymorphic entities with intrinsic type spec do not have addendum
inline bool boxHasAddendum(fir::BaseBoxType boxTy) {
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 84ce4cde94200f9..ee8869d5a11f879 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -742,9 +742,28 @@ static mlir::Value getReductionInitValue(fir::FirOpBuilder &builder,
if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty))
return getReductionInitValue(builder, loc, boxTy.getEleTy(), op);
+ if (auto heapTy = mlir::dyn_cast<fir::HeapType>(ty))
+ return getReductionInitValue(builder, loc, heapTy.getEleTy(), op);
+
+ if (auto ptrTy = mlir::dyn_cast<fir::PointerType>(ty))
+ return getReductionInitValue(builder, loc, ptrTy.getEleTy(), op);
+
llvm::report_fatal_error("Unsupported OpenACC reduction type");
}
+/// Return the nested sequence type if any.
+static mlir::Type extractSequenceType(mlir::Type ty) {
+ if (mlir::isa<fir::SequenceType>(ty))
+ return ty;
+ if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty))
+ return extractSequenceType(boxTy.getEleTy());
+ if (auto heapTy = mlir::dyn_cast<fir::HeapType>(ty))
+ return extractSequenceType(heapTy.getEleTy());
+ if (auto ptrTy = mlir::dyn_cast<fir::PointerType>(ty))
+ return extractSequenceType(ptrTy.getEleTy());
+ return mlir::Type{};
+}
+
static mlir::Value genReductionInitRegion(fir::FirOpBuilder &builder,
mlir::Location loc, mlir::Type ty,
mlir::acc::ReductionOperator op) {
@@ -788,7 +807,8 @@ static mlir::Value genReductionInitRegion(fir::FirOpBuilder &builder,
return declareOp.getBase();
}
} else if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(ty)) {
- if (!mlir::isa<fir::SequenceType>(boxTy.getEleTy()))
+ mlir::Type innerTy = extractSequenceType(boxTy);
+ if (!mlir::isa<fir::SequenceType>(innerTy))
TODO(loc, "Unsupported boxed type for reduction");
// Create the private copy from the initial fir.box.
hlfir::Entity source = hlfir::Entity{builder.getBlock()->getArgument(0)};
@@ -993,8 +1013,9 @@ static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
builder.setInsertionPointAfter(loops[0]);
} else if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) {
llvm::SmallVector<mlir::Value> tripletArgs;
+ mlir::Type innerTy = extractSequenceType(boxTy);
fir::SequenceType seqTy =
- mlir::dyn_cast_or_null<fir::SequenceType>(boxTy.getEleTy());
+ mlir::dyn_cast_or_null<fir::SequenceType>(innerTy);
if (!seqTy)
TODO(loc, "Unsupported boxed type in OpenACC reduction");
@@ -1110,6 +1131,19 @@ mlir::acc::ReductionRecipeOp Fortran::lower::createOrGetReductionRecipe(
return recipe;
}
+static bool isSupportedReductionType(mlir::Type ty) {
+ ty = fir::unwrapRefType(ty);
+ if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty))
+ return isSupportedReductionType(boxTy.getEleTy());
+ if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(ty))
+ return isSupportedReductionType(seqTy.getEleTy());
+ if (auto heapTy = mlir::dyn_cast<fir::HeapType>(ty))
+ return isSupportedReductionType(heapTy.getEleTy());
+ if (auto ptrTy = mlir::dyn_cast<fir::PointerType>(ty))
+ return isSupportedReductionType(ptrTy.getEleTy());
+ return fir::isa_trivial(ty);
+}
+
static void
genReductions(const Fortran::parser::AccObjectListWithReduction &objectList,
Fortran::lower::AbstractConverter &converter,
@@ -1135,10 +1169,7 @@ genReductions(const Fortran::parser::AccObjectListWithReduction &objectList,
if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(reductionTy))
reductionTy = seqTy.getEleTy();
- if (!fir::isa_trivial(reductionTy) &&
- ((fir::isAllocatableType(reductionTy) ||
- fir::isPointerType(reductionTy)) &&
- !bounds.empty()))
+ if (!isSupportedReductionType(reductionTy))
TODO(operandLocation, "reduction with unsupported type");
auto op = createDataEntryOp<mlir::acc::ReductionOp>(
@@ -1146,13 +1177,16 @@ genReductions(const Fortran::parser::AccObjectListWithReduction &objectList,
/*structured=*/true, /*implicit=*/false,
mlir::acc::DataClause::acc_reduction, baseAddr.getType());
mlir::Type ty = op.getAccPtr().getType();
+ if (!areAllBoundConstant(bounds) ||
+ fir::isAssumedShape(baseAddr.getType()) ||
+ fir::isAllocatableOrPointerArray(baseAddr.getType()))
+ ty = baseAddr.getType();
std::string suffix =
areAllBoundConstant(bounds) ? getBoundsString(bounds) : "";
std::string recipeName = fir::getTypeAsString(
ty, converter.getKindMap(),
("reduction_" + stringifyReductionOperator(mlirOp)).str() + suffix);
- if (!areAllBoundConstant(bounds) || fir::isAssumedShape(baseAddr.getType()))
- ty = baseAddr.getType();
+
mlir::acc::ReductionRecipeOp recipe =
Fortran::lower::createOrGetReductionRecipe(
builder, recipeName, operandLocation, ty, mlirOp, bounds);
diff --git a/flang/lib/Optimizer/Dialect/FIRType.cpp b/flang/lib/Optimizer/Dialect/FIRType.cpp
index bc35f9b44e73e9f..323b589cc7f2eb1 100644
--- a/flang/lib/Optimizer/Dialect/FIRType.cpp
+++ b/flang/lib/Optimizer/Dialect/FIRType.cpp
@@ -318,6 +318,18 @@ bool isAssumedShape(mlir::Type ty) {
return false;
}
+bool isAllocatableOrPointerArray(mlir::Type ty) {
+ if (auto refTy = fir::dyn_cast_ptrEleTy(ty))
+ ty = refTy;
+ if (auto boxTy = mlir::dyn_cast<fir::BoxType>(ty)) {
+ if (auto heapTy = mlir::dyn_cast<fir::HeapType>(boxTy.getEleTy()))
+ return mlir::isa<fir::SequenceType>(heapTy.getEleTy());
+ if (auto ptrTy = mlir::dyn_cast<fir::PointerType>(boxTy.getEleTy()))
+ return mlir::isa<fir::SequenceType>(ptrTy.getEleTy());
+ }
+ return false;
+}
+
bool isPolymorphicType(mlir::Type ty) {
if (auto refTy = fir::dyn_cast_ptrEleTy(ty))
ty = refTy;
diff --git a/flang/test/Lower/OpenACC/acc-reduction.f90 b/flang/test/Lower/OpenACC/acc-reduction.f90
index c102991b48632d2..07979445394d929 100644
--- a/flang/test/Lower/OpenACC/acc-reduction.f90
+++ b/flang/test/Lower/OpenACC/acc-reduction.f90
@@ -3,7 +3,42 @@
! RUN: bbc -fopenacc -emit-fir %s -o - | FileCheck %s --check-prefixes=CHECK,FIR
! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s --check-prefixes=CHECK,HLFIR
-! CHECK-LABEL: acc.reduction.recipe @reduction_add_section_lb1.ub3_ref_Uxi32 : !fir.box<!fir.array<?xi32>> reduction_operator <add> init {
+! CHECK-LABEL: acc.reduction.recipe @reduction_max_box_ptr_Uxf32 : !fir.box<!fir.ptr<!fir.array<?xf32>>> reduction_operator <max> init {
+! CHECK: ^bb0(%{{.*}}: !fir.box<!fir.ptr<!fir.array<?xf32>>>):
+! CHECK: } combiner {
+! CHECK: ^bb0(%{{.*}}: !fir.box<!fir.ptr<!fir.array<?xf32>>>, %{{.*}}: !fir.box<!fir.ptr<!fir.array<?xf32>>>, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index):
+! CHECK: }
+
+! CHECK-LABEL: acc.reduction.recipe @reduction_max_box_heap_Uxf32 : !fir.box<!fir.heap<!fir.array<?xf32>>> reduction_operator <max> init {
+! CHECK: ^bb0(%[[ARG0:.*]]: !fir.box<!fir.heap<!fir.array<?xf32>>>):
+! HLFIR: %[[CST:.*]] = arith.constant -1.401300e-45 : f32
+! HLFIR: %[[C0:.*]] = arith.constant 0 : index
+! HLFIR: %[[BOX_DIMS:.*]]:3 = fir.box_dims %[[ARG0]], %[[C0]] : (!fir.box<!fir.heap<!fir.array<?xf32>>>, index) -> (index, index, index)
+! HLFIR: %[[SHAPE:.*]] = fir.shape %[[BOX_DIMS]]#1 : (index) -> !fir.shape<1>
+! HLFIR: %[[TEMP:.*]] = fir.allocmem !fir.array<?xf32>, %[[BOX_DIMS]]#1 {bindc_name = ".tmp", uniq_name = ""}
+! HLFIR: %[[DECLARE:.*]]:2 = hlfir.declare %2(%1) {uniq_name = ".tmp"} : (!fir.heap<!fir.array<?xf32>>, !fir.shape<1>) -> (!fir.box<!fir.array<?xf32>>, !fir.heap<!fir.array<?xf32>>)
+! HLFIR: hlfir.assign %[[CST]] to %[[DECLARE]]#0 : f32, !fir.box<!fir.array<?xf32>>
+! HLFIR: acc.yield %[[DECLARE]]#0 : !fir.box<!fir.array<?xf32>>
+! CHECK: } combiner {
+! HLFIR: ^bb0(%[[ARG0:.*]]: !fir.box<!fir.heap<!fir.array<?xf32>>>, %[[ARG1:.*]]: !fir.box<!fir.heap<!fir.array<?xf32>>>, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index):
+! HLFIR: %[[SHAPE:.*]] = fir.shape %{{.*}} : (index) -> !fir.shape<1>
+! HLFIR: %[[DES_V1:.*]] = hlfir.designate %[[ARG0]] (%[[ARG2]]:%[[ARG3]]:%[[ARG4]]) shape %[[SHAPE]] : (!fir.box<!fir.heap<!fir.array<?xf32>>>, index, index, index, !fir.shape<1>) -> !fir.box<!fir.heap<!fir.array<?xf32>>>
+! HLFIR: %[[DES_V2:.*]] = hlfir.designate %[[ARG1]] (%[[ARG2]]:%[[ARG3]]:%[[ARG4]]) shape %[[SHAPE]] : (!fir.box<!fir.heap<!fir.array<?xf32>>>, index, index, index, !fir.shape<1>) -> !fir.box<!fir.heap<!fir.array<?xf32>>>
+! HLFIR: %[[ELEMENTAL:.*]] = hlfir.elemental %[[SHAPE]] unordered : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
+! HLFIR: ^bb0(%[[IV:.*]]: index):
+! HLFIR: %[[V1:.*]] = hlfir.designate %[[DES_V1]] (%[[IV]]) : (!fir.box<!fir.heap<!fir.array<?xf32>>>, index) -> !fir.ref<f32>
+! HLFIR: %[[V2:.*]] = hlfir.designate %[[DES_V2]] (%[[IV]]) : (!fir.box<!fir.heap<!fir.array<?xf32>>>, index) -> !fir.ref<f32>
+! HLFIR: %[[LOAD_V1:.*]] = fir.load %[[V1]] : !fir.ref<f32>
+! HLFIR: %[[LOAD_V2:.*]] = fir.load %[[V2]] : !fir.ref<f32>
+! HLFIR: %[[CMP:.*]] = arith.cmpf ogt, %[[LOAD_V1]], %[[LOAD_V2]] : f32
+! HLFIR: %[[SELECT:.*]] = arith.select %[[CMP]], %[[LOAD_V1]], %[[LOAD_V2]] : f32
+! HLFIR: hlfir.yield_element %[[SELECT]] : f32
+! HLFIR: }
+! HLFIR: hlfir.assign %[[ELEMENTAL]] to %[[ARG0]] : !hlfir.expr<?xf32>, !fir.box<!fir.heap<!fir.array<?xf32>>>
+! HLFIR: acc.yield %[[ARG0]] : !fir.box<!fir.heap<!fir.array<?xf32>>>
+! CHECK: }
+
+! CHECK-LABEL: acc.reduction.recipe @reduction_add_section_lb1.ub3_box_Uxi32 : !fir.box<!fir.array<?xi32>> reduction_operator <add> init {
! CHECK: ^bb0(%[[ARG0:.*]]: !fir.box<!fir.array<?xi32>>):
! HLFIR: %[[BOX_DIMS:.*]]:3 = fir.box_dims %[[ARG0]], %c0{{.*}} : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
! HLFIR: %[[SHAPE:.*]] = fir.shape %[[BOX_DIMS]]#1 : (index) -> !fir.shape<1>
@@ -29,7 +64,7 @@
! HLFIR: acc.yield %[[ARG0]] : !fir.box<!fir.array<?xi32>>
! HLFIR: }
-! CHECK-LABEL: acc.reduction.recipe @reduction_max_ref_Uxf32 : !fir.box<!fir.array<?xf32>> reduction_operator <max> init {
+! CHECK-LABEL: acc.reduction.recipe @reduction_max_box_Uxf32 : !fir.box<!fir.array<?xf32>> reduction_operator <max> init {
! CHECK: ^bb0(%[[ARG0:.*]]: !fir.box<!fir.array<?xf32>>):
! CHECK: %[[INIT_VALUE:.*]] = arith.constant -1.401300e-45 : f32
! HLFIR: %[[C0:.*]] = arith.constant 0 : index
@@ -57,7 +92,7 @@
! CHECK: acc.yield %[[ARG0]] : !fir.box<!fir.array<?xf32>>
! CHECK: }
-! CHECK-LABEL: acc.reduction.recipe @reduction_add_ref_Uxi32 : !fir.box<!fir.array<?xi32>> reduction_operator <add> init {
+! CHECK-LABEL: acc.reduction.recipe @reduction_add_box_Uxi32 : !fir.box<!fir.array<?xi32>> reduction_operator <add> init {
! CHECK: ^bb0(%[[ARG0:.*]]: !fir.box<!fir.array<?xi32>>):
! HLFIR: %[[INIT_VALUE:.*]] = arith.constant 0 : i32
! HLFIR: %[[C0:.*]] = arith.constant 0 : index
@@ -1097,7 +1132,7 @@ subroutine acc_reduction_add_dynamic_extent_add(a)
! CHECK-SAME: %[[ARG0:.*]]: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "a"})
! HLFIR: %[[DECLARG0:.*]]:2 = hlfir.declare %[[ARG0]]
! HLFIR: %[[RED:.*]] = acc.reduction varPtr(%{{.*}} : !fir.ref<!fir.array<?xi32>>) bounds(%{{.*}}) -> !fir.ref<!fir.array<?xi32>> {name = "a"}
-! HLFIR: acc.parallel reduction(@reduction_add_ref_Uxi32 -> %[[RED:.*]] : !fir.ref<!fir.array<?xi32>>)
+! HLFIR: acc.parallel reduction(@reduction_add_box_Uxi32 -> %[[RED:.*]] : !fir.ref<!fir.array<?xi32>>)
subroutine acc_reduction_add_dynamic_extent_max(a)
real :: a(:)
@@ -1109,7 +1144,7 @@ subroutine acc_reduction_add_dynamic_extent_max(a)
! CHECK-SAME: %[[ARG0:.*]]: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "a"})
! HLFIR: %[[DECLARG0:.*]]:2 = hlfir.declare %[[ARG0]]
! HLFIR: %[[RED:.*]] = acc.reduction varPtr(%{{.*}} : !fir.ref<!fir.array<?xf32>>) bounds(%{{.*}}) -> !fir.ref<!fir.array<?xf32>> {name = "a"}
-! HLFIR: acc.parallel reduction(@reduction_max_ref_Uxf32 -> %[[RED]] : !fir.ref<!fir.array<?xf32>>) {
+! HLFIR: acc.parallel reduction(@reduction_max_box_Uxf32 -> %[[RED]] : !fir.ref<!fir.array<?xf32>>) {
subroutine acc_reduction_add_dynamic_extent_add_with_section(a)
integer :: a(:)
@@ -1123,4 +1158,34 @@ subroutine acc_reduction_add_dynamic_extent_add_with_section(a)
! HLFIR: %[[BOUND:.*]] = acc.bounds lowerbound(%c1{{.*}} : index) upperbound(%c3{{.*}} : index) stride(%{{.*}}#2 : index) startIdx(%{{.*}} : index) {strideInBytes = true}
! HLFIR: %[[BOX_ADDR:.*]] = fir.box_addr %[[DECL]]#1 : (!fir.box<!fir.array<?xi32>>) -> !fir.ref<!fir.array<?xi32>>
! HLFIR: %[[RED:.*]] = acc.reduction varPtr(%[[BOX_ADDR]] : !fir.ref<!fir.array<?xi32>>) bounds(%[[BOUND]]) -> !fir.ref<!fir.array<?xi32>> {name = "a(2:4)"}
-! HLFIR: acc.parallel reduction(@reduction_add_section_lb1.ub3_ref_Uxi32 -> %[[RED]] : !fir.ref<!fir.array<?xi32>>)
+! HLFIR: acc.parallel reduction(@reduction_add_section_lb1.ub3_box_Uxi32 -> %[[RED]] : !fir.ref<!fir.array<?xi32>>)
+
+subroutine acc_reduction_add_allocatable(a)
+ real, allocatable :: a(:)
+ !$acc parallel reduction(max:a)
+ !$acc end parallel
+end subroutine
+
+! CHECK-LABEL: func.func @_QPacc_reduction_add_allocatable(
+! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>> {fir.bindc_name = "a"})
+! HLFIR: %[[DECL:.*]]:2 = hlfir.declare %[[ARG0]] {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFacc_reduction_add_allocatableEa"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
+! HLFIR: %[[BOX:.*]] = fir.load %[[DECL]]#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
+! HLFIR: %[[BOUND:.*]] = acc.bounds lowerbound(%c0{{.*}} : index) upperbound(%{{.*}} : index) stride(%{{.*}}#2 : index) startIdx(%{{.*}}#0 : index) {strideInBytes = true}
+! HLFIR: %[[BOX_ADDR:.*]] = fir.box_addr %[[BOX]] : (!fir.box<!fir.heap<!fir.array<?xf32>>>) -> !fir.heap<!fir.array<?xf32>>
+! HLFIR: %[[RED:.*]] = acc.reduction varPtr(%[[BOX_ADDR]] : !fir.heap<!fir.array<?xf32>>) bounds(%6) -> !fir.heap<!fir.array<?xf32>> {name = "a"}
+! HLFIR: acc.parallel reduction(@reduction_max_box_heap_Uxf32 -> %[[RED]] : !fir.heap<!fir.array<?xf32>>)
+
+subroutine acc_reduction_add_pointer_array(a)
+ real, pointer :: a(:)
+ !$acc parallel reduction(max:a)
+ !$acc end parallel
+end subroutine
+
+! CHECK-LABEL: func.func @_QPacc_reduction_add_pointer_array(
+! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>> {fir.bindc_name = "a"})
+! HLFIR: %[[DECL:.*]]:2 = hlfir.declare %[[ARG0]] {fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QFacc_reduction_add_pointer_arrayEa"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>)
+! HLFIR: %[[BOX:.*]] = fir.load %[[DECL]]#1 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
+! HLFIR: %[[BOUND:.*]] = acc.bounds lowerbound(%c0{{.*}} : index) upperbound(%{{.*}} : index) stride(%{{.*}}#2 : index) startIdx(%{{.*}}#0 : index) {strideInBytes = true}
+! HLFIR: %[[BOX_ADDR:.*]] = fir.box_addr %[[BOX]] : (!fir.box<!fir.ptr<!fir.array<?xf32>>>) -> !fir.ptr<!fir.array<?xf32>>
+! HLFIR: %[[RED:.*]] = acc.reduction varPtr(%[[BOX_ADDR]] : !fir.ptr<!fir.array<?xf32>>) bounds(%[[BOUND]]) -> !fir.ptr<!fir.array<?xf32>> {name = "a"}
+! HLFIR: acc.parallel reduction(@reduction_max_box_ptr_Uxf32 -> %[[RED]] : !fir.ptr<!fir.array<?xf32>>)
``````````
</details>
https://github.com/llvm/llvm-project/pull/68261
More information about the flang-commits mailing list