[flang-commits] [flang] [flang][acc] Fix cache directive lowering for derived type designators (PR #176022)
via flang-commits
flang-commits at lists.llvm.org
Wed Jan 14 11:50:17 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-openacc
Author: None (khaki3)
<details>
<summary>Changes</summary>
This PR fixes the OpenACC `cache` directive to correctly handle derived type designators such as `data%array(i-4:i+4)`. Previously, the lowering would assert or silently fail because the implementation assumed all cache operands mapped directly to `hlfir.declare` operations.
Problem
-------
The original implementation had several issues:
1. Assertion failure: For derived type members, the symbol refers to the component definition in the type (e.g., `array`), not to a variable that has an `hlfir.declare`. The code asserted that every symbol was mapped to an `hlfir.declare`, which does not hold for component symbols.
2. DCE removal: Even after fixing the assertion, the `acc.cache` operation was being removed by Dead Code Elimination because:
- The component symbol isn't in the symbol map (it's a type definition, not a variable)
- Without proper rebinding, the cache result had no uses
3. Incorrect cache usage: Subsequent accesses like `data%array(i)` continued using the original variable instead of the cached reference.
Solution
--------
1. Use `info.addr` from `gatherDataOperandAddrAndBounds`: This properly handles derived type components by returning the address of the component, not requiring a symbol lookup.
2. Use the function-level `StatementContext` (`converter.getFctCtx()`) instead of a local one: Prevents premature cleanup of generated operations.
3. Fallback base address resolution: When `info.addr` is null, fall back to `info.rawInput`, and then to `genExprAddr()` on the designator. If no base address can be resolved, the operand is skipped.
4. Proper rebinding for both cases:
- Simple variables: Rebind the symbol directly via `bindSymbol()`
- Derived type components: Extract the `Component` reference from the designator, create an `hlfir.declare` for the cache result, and register it with `addComponentOverride()` so subsequent accesses use the cached value
Testing
-------
Added test cases for:
- `!$acc cache(data%array(i-4:i+4))` - derived type with array section
- `!$acc cache(readonly: data%array(i-4:i+4))` - readonly modifier
- `!$acc cache(obj%in%arr(i))` - nested derived type
All existing tests continue to pass.
Files Changed
-------------
- flang/lib/Lower/OpenACC.cpp - Fix cache directive lowering
- flang/test/Lower/OpenACC/acc-cache.f90 - Add derived type test cases
---
Full diff: https://github.com/llvm/llvm-project/pull/176022.diff
2 Files Affected:
- (modified) flang/lib/Lower/OpenACC.cpp (+67-18)
- (modified) flang/test/Lower/OpenACC/acc-cache.f90 (+77)
``````````diff
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 062366f87eb09..fb25991cbdcea 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -4856,12 +4856,11 @@ genACC(Fortran::lower::AbstractConverter &converter,
modifier &&
(*modifier).v == Fortran::parser::AccDataModifier::Modifier::ReadOnly;
- Fortran::lower::StatementContext stmtCtx;
+ Fortran::lower::StatementContext &stmtCtx = converter.getFctCtx();
for (const auto &accObject : accObjectList.v) {
mlir::Location operandLocation = genOperandLocation(converter, accObject);
Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject);
-
std::stringstream asFortran;
Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext};
@@ -4869,19 +4868,22 @@ genACC(Fortran::lower::AbstractConverter &converter,
[&](auto &&s) { return ea.Analyze(s); }, accObject.u);
llvm::SmallVector<mlir::Value> bounds;
- Fortran::lower::gatherDataOperandAddrAndBounds<mlir::acc::DataBoundsOp,
- mlir::acc::DataBoundsType>(
- converter, builder, semanticsContext, stmtCtx, symbol, designator,
- operandLocation, asFortran, bounds,
- /*treatIndexAsSection=*/true, /*unwrapFirBox=*/false,
- /*genDefaultBounds=*/false, /*strideIncludeLowerExtent=*/false,
- /*loadAllocatableAndPointerComponent=*/false);
-
- std::optional<fir::FortranVariableOpInterface> varDef =
- converter.getSymbolMap().lookupVariableDefinition(symbol);
- assert(varDef.has_value() && llvm::isa<hlfir::DeclareOp>(*varDef) &&
- "expected symbol to be mapped to hlfir.declare");
- mlir::Value base = varDef->getBase();
+ fir::factory::AddrAndBoundsInfo info =
+ Fortran::lower::gatherDataOperandAddrAndBounds<
+ mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>(
+ converter, builder, semanticsContext, stmtCtx, symbol, designator,
+ operandLocation, asFortran, bounds,
+ /*treatIndexAsSection=*/true, /*unwrapFirBox=*/false,
+ /*genDefaultBounds=*/false, /*strideIncludeLowerExtent=*/false,
+ /*loadAllocatableAndPointerComponent=*/false);
+
+ mlir::Value base = info.addr ? info.addr : info.rawInput;
+ if (!base && designator)
+ base = fir::getBase(
+ converter.genExprAddr(operandLocation, *designator, stmtCtx));
+
+ if (!base)
+ continue;
mlir::acc::CacheOp cacheOp = createDataEntryOp<mlir::acc::CacheOp>(
builder, operandLocation, base, asFortran, bounds,
@@ -4892,9 +4894,56 @@ genACC(Fortran::lower::AbstractConverter &converter,
isReadonly ? mlir::acc::DataClauseModifier::readonly
: mlir::acc::DataClauseModifier::none);
- fir::ExtendedValue hostExv = converter.getSymbolExtendedValue(symbol);
- fir::ExtendedValue cacheExv = fir::substBase(hostExv, cacheOp.getAccVar());
- converter.bindSymbol(symbol, cacheExv);
+ // Rebind the symbol so subsequent references use the cached value.
+ if (Fortran::lower::SymbolBox symBox =
+ converter.getSymbolMap().lookupSymbol(symbol)) {
+ // For simple variables, rebind the symbol directly.
+ fir::ExtendedValue hostExv = converter.getSymbolExtendedValue(symbol);
+ fir::ExtendedValue cacheExv =
+ fir::substBase(hostExv, cacheOp.getAccVar());
+ converter.bindSymbol(symbol, cacheExv);
+ } else if (designator) {
+ // For derived type components, extract the component reference and
+ // add a component override so subsequent accesses use the cached value.
+ std::optional<Fortran::evaluate::Component> componentRef;
+ if (std::optional<Fortran::evaluate::DataRef> dataRef =
+ Fortran::evaluate::ExtractDataRef(*designator)) {
+ Fortran::common::visit(
+ Fortran::common::visitors{
+ [&](const Fortran::evaluate::Component &component) {
+ componentRef = component;
+ },
+ [&](const Fortran::evaluate::ArrayRef &arrayRef) {
+ if (auto *comp = arrayRef.base().UnwrapComponent())
+ componentRef = *comp;
+ },
+ [&](const auto &) {}},
+ dataRef->u);
+ }
+
+ if (componentRef) {
+ // Create an hlfir.declare for the cache result and add component
+ // override so subsequent accesses use the cached value.
+ llvm::SmallVector<mlir::Value> lenParams;
+ mlir::Value shape;
+ if (auto arrTy = mlir::dyn_cast<fir::SequenceType>(
+ fir::unwrapRefType(base.getType()))) {
+ llvm::SmallVector<mlir::Value> extents;
+ for (int64_t ext : arrTy.getShape()) {
+ if (ext != fir::SequenceType::getUnknownExtent())
+ extents.push_back(builder.createIntegerConstant(
+ operandLocation, builder.getIndexType(), ext));
+ }
+ if (!extents.empty())
+ shape = builder.genShape(operandLocation, extents);
+ }
+ auto declareOp = hlfir::DeclareOp::create(
+ builder, operandLocation, cacheOp.getAccVar(), asFortran.str(),
+ shape, lenParams, /*dummyScope=*/nullptr, /*storage=*/nullptr,
+ /*storageOffset=*/0, fir::FortranVariableFlagsAttr{});
+ converter.getSymbolMap().addComponentOverride(*componentRef, declareOp);
+ }
+ }
}
}
diff --git a/flang/test/Lower/OpenACC/acc-cache.f90 b/flang/test/Lower/OpenACC/acc-cache.f90
index 1cfe064993160..3fe5131364679 100644
--- a/flang/test/Lower/OpenACC/acc-cache.f90
+++ b/flang/test/Lower/OpenACC/acc-cache.f90
@@ -560,3 +560,80 @@ subroutine test_cache_in_nested_do()
! CHECK: fir.do_loop
! CHECK: hlfir.designate %[[B_VAR]]#0
end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_cache_derived_type()
+subroutine test_cache_derived_type()
+ type :: dt
+ real :: array(100)
+ end type
+
+ integer, parameter :: n = 100
+ type(dt) :: data
+ real :: a(n)
+ integer :: i
+
+ !$acc loop
+ do i = 5, n - 4
+ !$acc cache(data%array(i-4:i+4))
+ a(i) = data%array(i)
+ end do
+
+! CHECK: acc.loop
+! CHECK: %[[ARRAY_COORD:.*]] = hlfir.designate %{{.*}}{"array"} shape %{{.*}} : (!fir.ref<!fir.type<_QFtest_cache_derived_typeTdt{array:!fir.array<100xf32>}>>, !fir.shape<1>) -> !fir.ref<!fir.array<100xf32>>
+! CHECK: %[[BOUND:.*]] = acc.bounds lowerbound(%{{.*}} : index) upperbound(%{{.*}} : index) extent(%{{.*}} : index) stride(%{{.*}} : index) startIdx(%{{.*}} : index)
+! CHECK: %[[CACHE:.*]] = acc.cache varPtr(%[[ARRAY_COORD]] : !fir.ref<!fir.array<100xf32>>) bounds(%[[BOUND]]) -> !fir.ref<!fir.array<100xf32>> {name = "data%array(i-4_4:i+4_4)", structured = false}
+! CHECK: acc.yield
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_cache_derived_type_readonly()
+subroutine test_cache_derived_type_readonly()
+ type :: dt
+ real :: array(100)
+ end type
+
+ integer, parameter :: n = 100
+ type(dt) :: data
+ real :: a(n)
+ integer :: i
+
+ !$acc loop
+ do i = 5, n - 4
+ !$acc cache(readonly: data%array(i-4:i+4))
+ a(i) = data%array(i)
+ end do
+
+! CHECK: acc.loop
+! CHECK: %[[ARRAY_COORD:.*]] = hlfir.designate %{{.*}}{"array"} shape %{{.*}} : (!fir.ref<!fir.type<_QFtest_cache_derived_type_readonlyTdt{array:!fir.array<100xf32>}>>, !fir.shape<1>) -> !fir.ref<!fir.array<100xf32>>
+! CHECK: %[[BOUND:.*]] = acc.bounds lowerbound(%{{.*}} : index) upperbound(%{{.*}} : index) extent(%{{.*}} : index) stride(%{{.*}} : index) startIdx(%{{.*}} : index)
+! CHECK: %[[CACHE:.*]] = acc.cache varPtr(%[[ARRAY_COORD]] : !fir.ref<!fir.array<100xf32>>) bounds(%[[BOUND]]) -> !fir.ref<!fir.array<100xf32>> {modifiers = #acc<data_clause_modifier readonly>, name = "data%array(i-4_4:i+4_4)", structured = false}
+! CHECK: acc.yield
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_cache_nested_derived_type()
+subroutine test_cache_nested_derived_type()
+ type :: inner
+ real :: arr(50)
+ end type
+
+ type :: outer
+ type(inner) :: in
+ end type
+
+ integer, parameter :: n = 50
+ type(outer) :: obj
+ real :: a(n)
+ integer :: i
+
+ !$acc loop
+ do i = 1, n
+ !$acc cache(obj%in%arr(i))
+ a(i) = obj%in%arr(i)
+ end do
+
+! CHECK: acc.loop
+! CHECK: %[[IN_COORD:.*]] = hlfir.designate %{{.*}}{"in"} : (!fir.ref<!fir.type<_QFtest_cache_nested_derived_typeTouter{in:!fir.type<_QFtest_cache_nested_derived_typeTinner{arr:!fir.array<50xf32>}>}>>) -> !fir.ref<!fir.type<_QFtest_cache_nested_derived_typeTinner{arr:!fir.array<50xf32>}>>
+! CHECK: %[[ARR_COORD:.*]] = hlfir.designate %[[IN_COORD]]{"arr"} shape %{{.*}} : (!fir.ref<!fir.type<_QFtest_cache_nested_derived_typeTinner{arr:!fir.array<50xf32>}>>, !fir.shape<1>) -> !fir.ref<!fir.array<50xf32>>
+! CHECK: %[[BOUND:.*]] = acc.bounds lowerbound(%{{.*}} : index) upperbound(%{{.*}} : index) extent(%{{.*}} : index) stride(%{{.*}} : index) startIdx(%{{.*}} : index)
+! CHECK: %[[CACHE:.*]] = acc.cache varPtr(%[[ARR_COORD]] : !fir.ref<!fir.array<50xf32>>) bounds(%[[BOUND]]) -> !fir.ref<!fir.array<50xf32>> {name = "obj%in%arr(i)", structured = false}
+! CHECK: acc.yield
+end subroutine
``````````
</details>
https://github.com/llvm/llvm-project/pull/176022
More information about the flang-commits
mailing list