[flang-commits] [flang] [flang][acc] Fix cache directive lowering for derived type designators (PR #176022)

via flang-commits flang-commits at lists.llvm.org
Wed Jan 14 11:49:41 PST 2026


https://github.com/khaki3 created https://github.com/llvm/llvm-project/pull/176022

This PR fixes the OpenACC `cache` directive to correctly handle derived type designators such as `data%array(i-4:i+4)`. Previously, the lowering would assert or silently fail because the implementation assumed all cache operands mapped directly to `hlfir.declare` operations.

Problem
-------

The original implementation had several issues:

1. Assertion failure: For derived type members, the symbol refers to the component definition in the type (e.g., `array`), not a variable with an `hlfir.declare`. The code asserted that every symbol was mapped to an `hlfir.declare`, which does not hold for component symbols.

2. DCE removal: Even after fixing the assertion, the `acc.cache` operation was being removed by Dead Code Elimination because:
   - The component symbol isn't in the symbol map (it's a type definition, not a variable)
   - Without proper rebinding, the cache result had no uses

3. Incorrect cache usage: Subsequent accesses like `data%array(i)` continued using the original variable instead of the cached reference.

Solution
--------

1. Use `info.addr` from `gatherDataOperandAddrAndBounds`: This properly handles derived type components by returning the address of the component directly, without requiring a symbol lookup.

2. Use function-level `StatementContext`: Prevents premature cleanup of generated operations.

3. Fallback base address resolution: When `info.addr` is null, fall back to `info.rawInput` or `genExprAddr()`.

4. Proper rebinding for both cases:
   - Simple variables: Rebind the symbol directly via `bindSymbol()`
   - Derived type components: Extract the `Component` reference from the designator and use `addComponentOverride()` so subsequent accesses use the cached value

Testing
-------

Added test cases for:
- `!$acc cache(data%array(i-4:i+4))` - derived type with array section
- `!$acc cache(readonly: data%array(i-4:i+4))` - readonly modifier
- `!$acc cache(obj%in%arr(i))` - nested derived type

All existing tests continue to pass.

Files Changed
-------------

- flang/lib/Lower/OpenACC.cpp - Fix cache directive lowering
- flang/test/Lower/OpenACC/acc-cache.f90 - Add derived type test cases

>From 6fcd943dc19d0047b7072bdc49d4acc7c9d20df1 Mon Sep 17 00:00:00 2001
From: Kazuaki Matsumura <kmatsumura at nvidia.com>
Date: Tue, 13 Jan 2026 16:04:22 -0800
Subject: [PATCH 1/2] [flang][acc] Fix cache directive lowering for derived
 type designators

Use info.addr from gatherDataOperandAddrAndBounds instead of looking up
the symbol's variable definition. This properly handles derived type
members like data%array(i-4:i+4).

Also skip symbol rebinding for array sections (when bounds is non-empty)
since we're caching a subset, not the whole symbol.
---
 flang/lib/Lower/OpenACC.cpp            | 35 ++++++------
 flang/test/Lower/OpenACC/acc-cache.f90 | 77 ++++++++++++++++++++++++++
 2 files changed, 96 insertions(+), 16 deletions(-)

diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 062366f87eb09..aa646716006ca 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -4869,19 +4869,16 @@ genACC(Fortran::lower::AbstractConverter &converter,
         [&](auto &&s) { return ea.Analyze(s); }, accObject.u);
 
     llvm::SmallVector<mlir::Value> bounds;
-    Fortran::lower::gatherDataOperandAddrAndBounds<mlir::acc::DataBoundsOp,
-                                                   mlir::acc::DataBoundsType>(
-        converter, builder, semanticsContext, stmtCtx, symbol, designator,
-        operandLocation, asFortran, bounds,
-        /*treatIndexAsSection=*/true, /*unwrapFirBox=*/false,
-        /*genDefaultBounds=*/false, /*strideIncludeLowerExtent=*/false,
-        /*loadAllocatableAndPointerComponent=*/false);
-
-    std::optional<fir::FortranVariableOpInterface> varDef =
-        converter.getSymbolMap().lookupVariableDefinition(symbol);
-    assert(varDef.has_value() && llvm::isa<hlfir::DeclareOp>(*varDef) &&
-           "expected symbol to be mapped to hlfir.declare");
-    mlir::Value base = varDef->getBase();
+    fir::factory::AddrAndBoundsInfo info =
+        Fortran::lower::gatherDataOperandAddrAndBounds<
+            mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>(
+            converter, builder, semanticsContext, stmtCtx, symbol, designator,
+            operandLocation, asFortran, bounds,
+            /*treatIndexAsSection=*/true, /*unwrapFirBox=*/false,
+            /*genDefaultBounds=*/false, /*strideIncludeLowerExtent=*/false,
+            /*loadAllocatableAndPointerComponent=*/false);
+
+    mlir::Value base = info.addr;
 
     mlir::acc::CacheOp cacheOp = createDataEntryOp<mlir::acc::CacheOp>(
         builder, operandLocation, base, asFortran, bounds,
@@ -4892,9 +4889,15 @@ genACC(Fortran::lower::AbstractConverter &converter,
         isReadonly ? mlir::acc::DataClauseModifier::readonly
                    : mlir::acc::DataClauseModifier::none);
 
-    fir::ExtendedValue hostExv = converter.getSymbolExtendedValue(symbol);
-    fir::ExtendedValue cacheExv = fir::substBase(hostExv, cacheOp.getAccVar());
-    converter.bindSymbol(symbol, cacheExv);
+    // Only rebind the symbol when caching the whole variable (no bounds).
+    // For array sections or component references, the symbol binding doesn't
+    // apply since we're caching a subset, not the whole symbol.
+    if (bounds.empty()) {
+      fir::ExtendedValue hostExv = converter.getSymbolExtendedValue(symbol);
+      fir::ExtendedValue cacheExv =
+          fir::substBase(hostExv, cacheOp.getAccVar());
+      converter.bindSymbol(symbol, cacheExv);
+    }
   }
 }
 
diff --git a/flang/test/Lower/OpenACC/acc-cache.f90 b/flang/test/Lower/OpenACC/acc-cache.f90
index 1cfe064993160..9cc6737f56d8a 100644
--- a/flang/test/Lower/OpenACC/acc-cache.f90
+++ b/flang/test/Lower/OpenACC/acc-cache.f90
@@ -560,3 +560,80 @@ subroutine test_cache_in_nested_do()
 ! CHECK: fir.do_loop
 ! CHECK: hlfir.designate %[[B_VAR]]#0
 end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_cache_derived_type()
+subroutine test_cache_derived_type()
+  type :: dt
+    real :: array(100)
+  end type
+
+  integer, parameter :: n = 100
+  type(dt) :: data
+  real :: a(n)
+  integer :: i
+
+  !$acc loop
+  do i = 5, n - 4
+    !$acc cache(data%array(i-4:i+4))
+    a(i) = data%array(i)
+  end do
+
+! CHECK: acc.loop
+! CHECK: %[[ARRAY_COORD:.*]] = hlfir.designate %{{.*}}{"array"} shape %{{.*}} : (!fir.ref<!fir.type<_QFtest_cache_derived_typeTdt{array:!fir.array<100xf32>}>>, !fir.shape<1>) -> !fir.ref<!fir.array<100xf32>>
+! CHECK: %[[BOUND:.*]] = acc.bounds lowerbound(%{{.*}} : index) upperbound(%{{.*}} : index) extent(%{{.*}} : index) stride(%{{.*}} : index) startIdx(%{{.*}} : index)
+! CHECK: %[[CACHE:.*]] = acc.cache varPtr(%[[ARRAY_COORD]] : !fir.ref<!fir.array<100xf32>>) bounds(%[[BOUND]]) -> !fir.ref<!fir.array<100xf32>> {name = "data%array(i-4:i+4)"}
+! CHECK: acc.yield
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_cache_derived_type_readonly()
+subroutine test_cache_derived_type_readonly()
+  type :: dt
+    real :: array(100)
+  end type
+
+  integer, parameter :: n = 100
+  type(dt) :: data
+  real :: a(n)
+  integer :: i
+
+  !$acc loop
+  do i = 5, n - 4
+    !$acc cache(readonly: data%array(i-4:i+4))
+    a(i) = data%array(i)
+  end do
+
+! CHECK: acc.loop
+! CHECK: %[[ARRAY_COORD:.*]] = hlfir.designate %{{.*}}{"array"} shape %{{.*}} : (!fir.ref<!fir.type<_QFtest_cache_derived_type_readonlyTdt{array:!fir.array<100xf32>}>>, !fir.shape<1>) -> !fir.ref<!fir.array<100xf32>>
+! CHECK: %[[BOUND:.*]] = acc.bounds lowerbound(%{{.*}} : index) upperbound(%{{.*}} : index) extent(%{{.*}} : index) stride(%{{.*}} : index) startIdx(%{{.*}} : index)
+! CHECK: %[[CACHE:.*]] = acc.cache varPtr(%[[ARRAY_COORD]] : !fir.ref<!fir.array<100xf32>>) bounds(%[[BOUND]]) -> !fir.ref<!fir.array<100xf32>> {modifiers = #acc<data_clause_modifier readonly>, name = "data%array(i-4:i+4)"}
+! CHECK: acc.yield
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_cache_nested_derived_type()
+subroutine test_cache_nested_derived_type()
+  type :: inner
+    real :: arr(50)
+  end type
+
+  type :: outer
+    type(inner) :: in
+  end type
+
+  integer, parameter :: n = 50
+  type(outer) :: obj
+  real :: a(n)
+  integer :: i
+
+  !$acc loop
+  do i = 1, n
+    !$acc cache(obj%in%arr(i))
+    a(i) = obj%in%arr(i)
+  end do
+
+! CHECK: acc.loop
+! CHECK: %[[IN_COORD:.*]] = hlfir.designate %{{.*}}{"in"} : (!fir.ref<!fir.type<_QFtest_cache_nested_derived_typeTouter{in:!fir.type<_QFtest_cache_nested_derived_typeTinner{arr:!fir.array<50xf32>}>}>>) -> !fir.ref<!fir.type<_QFtest_cache_nested_derived_typeTinner{arr:!fir.array<50xf32>}>>
+! CHECK: %[[ARR_COORD:.*]] = hlfir.designate %[[IN_COORD]]{"arr"} shape %{{.*}} : (!fir.ref<!fir.type<_QFtest_cache_nested_derived_typeTinner{arr:!fir.array<50xf32>}>>, !fir.shape<1>) -> !fir.ref<!fir.array<50xf32>>
+! CHECK: %[[BOUND:.*]] = acc.bounds lowerbound(%{{.*}} : index) upperbound(%{{.*}} : index) extent(%{{.*}} : index) stride(%{{.*}} : index) startIdx(%{{.*}} : index)
+! CHECK: %[[CACHE:.*]] = acc.cache varPtr(%[[ARR_COORD]] : !fir.ref<!fir.array<50xf32>>) bounds(%[[BOUND]]) -> !fir.ref<!fir.array<50xf32>> {name = "obj%in%arr(i)"}
+! CHECK: acc.yield
+end subroutine

>From d6dcf430accbd58c3d4c591aea060eee72cfc240 Mon Sep 17 00:00:00 2001
From: Kazuaki Matsumura <kmatsumura at nvidia.com>
Date: Wed, 14 Jan 2026 11:43:57 -0800
Subject: [PATCH 2/2] [flang][OpenACC] Fix cache directive for derived type
 designators

The cache directive was not working correctly for derived type
designators (e.g., data%array(i-4:i+4)). The acc.cache operation was
being created but removed by DCE because:

1. For derived type components, the symbol refers to the component
   definition in the type, not a variable in the symbol map
2. Without proper rebinding, the acc.cache result had no uses

This patch fixes the issue by:

1. Using the function-level StatementContext to prevent premature
   cleanup of generated operations
2. Adding fallback logic for base address resolution when info.addr
   is null (common for derived type components)
3. For simple variables: rebinding the symbol directly via bindSymbol()
4. For derived type components: extracting the Component reference
   from the designator and using addComponentOverride() so subsequent
   accesses use the cached value

The component override mechanism ensures that when subsequent
statements like 'a(i) = data%array(i)' are lowered, ConvertExprToHLFIR
finds the override and uses the cached reference instead of
re-accessing the original component.
---
 flang/lib/Lower/OpenACC.cpp            | 60 +++++++++++++++++++++++---
 flang/test/Lower/OpenACC/acc-cache.f90 |  6 +--
 2 files changed, 56 insertions(+), 10 deletions(-)

diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index aa646716006ca..fb25991cbdcea 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -4856,12 +4856,11 @@ genACC(Fortran::lower::AbstractConverter &converter,
       modifier &&
       (*modifier).v == Fortran::parser::AccDataModifier::Modifier::ReadOnly;
 
-  Fortran::lower::StatementContext stmtCtx;
+  Fortran::lower::StatementContext &stmtCtx = converter.getFctCtx();
 
   for (const auto &accObject : accObjectList.v) {
     mlir::Location operandLocation = genOperandLocation(converter, accObject);
     Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject);
-
     std::stringstream asFortran;
 
     Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext};
@@ -4878,7 +4877,13 @@ genACC(Fortran::lower::AbstractConverter &converter,
             /*genDefaultBounds=*/false, /*strideIncludeLowerExtent=*/false,
             /*loadAllocatableAndPointerComponent=*/false);
 
-    mlir::Value base = info.addr;
+    mlir::Value base = info.addr ? info.addr : info.rawInput;
+    if (!base && designator)
+      base = fir::getBase(
+          converter.genExprAddr(operandLocation, *designator, stmtCtx));
+
+    if (!base)
+      continue;
 
     mlir::acc::CacheOp cacheOp = createDataEntryOp<mlir::acc::CacheOp>(
         builder, operandLocation, base, asFortran, bounds,
@@ -4889,14 +4894,55 @@ genACC(Fortran::lower::AbstractConverter &converter,
         isReadonly ? mlir::acc::DataClauseModifier::readonly
                    : mlir::acc::DataClauseModifier::none);
 
-    // Only rebind the symbol when caching the whole variable (no bounds).
-    // For array sections or component references, the symbol binding doesn't
-    // apply since we're caching a subset, not the whole symbol.
-    if (bounds.empty()) {
+    // Rebind the symbol so subsequent references use the cached value.
+    if (Fortran::lower::SymbolBox symBox =
+            converter.getSymbolMap().lookupSymbol(symbol)) {
+      // For simple variables, rebind the symbol directly.
       fir::ExtendedValue hostExv = converter.getSymbolExtendedValue(symbol);
       fir::ExtendedValue cacheExv =
           fir::substBase(hostExv, cacheOp.getAccVar());
       converter.bindSymbol(symbol, cacheExv);
+    } else if (designator) {
+      // For derived type components, extract the component reference and
+      // add a component override so subsequent accesses use the cached value.
+      std::optional<Fortran::evaluate::Component> componentRef;
+      if (std::optional<Fortran::evaluate::DataRef> dataRef =
+              Fortran::evaluate::ExtractDataRef(*designator)) {
+        Fortran::common::visit(
+            Fortran::common::visitors{
+                [&](const Fortran::evaluate::Component &component) {
+                  componentRef = component;
+                },
+                [&](const Fortran::evaluate::ArrayRef &arrayRef) {
+                  if (auto *comp = arrayRef.base().UnwrapComponent())
+                    componentRef = *comp;
+                },
+                [&](const auto &) {}},
+            dataRef->u);
+      }
+
+      if (componentRef) {
+        // Create an hlfir.declare for the cache result and add component
+        // override so subsequent accesses use the cached value.
+        llvm::SmallVector<mlir::Value> lenParams;
+        mlir::Value shape;
+        if (auto arrTy = mlir::dyn_cast<fir::SequenceType>(
+                fir::unwrapRefType(base.getType()))) {
+          llvm::SmallVector<mlir::Value> extents;
+          for (int64_t ext : arrTy.getShape()) {
+            if (ext != fir::SequenceType::getUnknownExtent())
+              extents.push_back(builder.createIntegerConstant(
+                  operandLocation, builder.getIndexType(), ext));
+          }
+          if (!extents.empty())
+            shape = builder.genShape(operandLocation, extents);
+        }
+        auto declareOp = hlfir::DeclareOp::create(
+            builder, operandLocation, cacheOp.getAccVar(), asFortran.str(),
+            shape, lenParams, /*dummyScope=*/nullptr, /*storage=*/nullptr,
+            /*storageOffset=*/0, fir::FortranVariableFlagsAttr{});
+        converter.getSymbolMap().addComponentOverride(*componentRef, declareOp);
+      }
     }
   }
 }
diff --git a/flang/test/Lower/OpenACC/acc-cache.f90 b/flang/test/Lower/OpenACC/acc-cache.f90
index 9cc6737f56d8a..3fe5131364679 100644
--- a/flang/test/Lower/OpenACC/acc-cache.f90
+++ b/flang/test/Lower/OpenACC/acc-cache.f90
@@ -581,7 +581,7 @@ subroutine test_cache_derived_type()
 ! CHECK: acc.loop
 ! CHECK: %[[ARRAY_COORD:.*]] = hlfir.designate %{{.*}}{"array"} shape %{{.*}} : (!fir.ref<!fir.type<_QFtest_cache_derived_typeTdt{array:!fir.array<100xf32>}>>, !fir.shape<1>) -> !fir.ref<!fir.array<100xf32>>
 ! CHECK: %[[BOUND:.*]] = acc.bounds lowerbound(%{{.*}} : index) upperbound(%{{.*}} : index) extent(%{{.*}} : index) stride(%{{.*}} : index) startIdx(%{{.*}} : index)
-! CHECK: %[[CACHE:.*]] = acc.cache varPtr(%[[ARRAY_COORD]] : !fir.ref<!fir.array<100xf32>>) bounds(%[[BOUND]]) -> !fir.ref<!fir.array<100xf32>> {name = "data%array(i-4:i+4)"}
+! CHECK: %[[CACHE:.*]] = acc.cache varPtr(%[[ARRAY_COORD]] : !fir.ref<!fir.array<100xf32>>) bounds(%[[BOUND]]) -> !fir.ref<!fir.array<100xf32>> {name = "data%array(i-4_4:i+4_4)", structured = false}
 ! CHECK: acc.yield
 end subroutine
 
@@ -605,7 +605,7 @@ subroutine test_cache_derived_type_readonly()
 ! CHECK: acc.loop
 ! CHECK: %[[ARRAY_COORD:.*]] = hlfir.designate %{{.*}}{"array"} shape %{{.*}} : (!fir.ref<!fir.type<_QFtest_cache_derived_type_readonlyTdt{array:!fir.array<100xf32>}>>, !fir.shape<1>) -> !fir.ref<!fir.array<100xf32>>
 ! CHECK: %[[BOUND:.*]] = acc.bounds lowerbound(%{{.*}} : index) upperbound(%{{.*}} : index) extent(%{{.*}} : index) stride(%{{.*}} : index) startIdx(%{{.*}} : index)
-! CHECK: %[[CACHE:.*]] = acc.cache varPtr(%[[ARRAY_COORD]] : !fir.ref<!fir.array<100xf32>>) bounds(%[[BOUND]]) -> !fir.ref<!fir.array<100xf32>> {modifiers = #acc<data_clause_modifier readonly>, name = "data%array(i-4:i+4)"}
+! CHECK: %[[CACHE:.*]] = acc.cache varPtr(%[[ARRAY_COORD]] : !fir.ref<!fir.array<100xf32>>) bounds(%[[BOUND]]) -> !fir.ref<!fir.array<100xf32>> {modifiers = #acc<data_clause_modifier readonly>, name = "data%array(i-4_4:i+4_4)", structured = false}
 ! CHECK: acc.yield
 end subroutine
 
@@ -634,6 +634,6 @@ subroutine test_cache_nested_derived_type()
 ! CHECK: %[[IN_COORD:.*]] = hlfir.designate %{{.*}}{"in"} : (!fir.ref<!fir.type<_QFtest_cache_nested_derived_typeTouter{in:!fir.type<_QFtest_cache_nested_derived_typeTinner{arr:!fir.array<50xf32>}>}>>) -> !fir.ref<!fir.type<_QFtest_cache_nested_derived_typeTinner{arr:!fir.array<50xf32>}>>
 ! CHECK: %[[ARR_COORD:.*]] = hlfir.designate %[[IN_COORD]]{"arr"} shape %{{.*}} : (!fir.ref<!fir.type<_QFtest_cache_nested_derived_typeTinner{arr:!fir.array<50xf32>}>>, !fir.shape<1>) -> !fir.ref<!fir.array<50xf32>>
 ! CHECK: %[[BOUND:.*]] = acc.bounds lowerbound(%{{.*}} : index) upperbound(%{{.*}} : index) extent(%{{.*}} : index) stride(%{{.*}} : index) startIdx(%{{.*}} : index)
-! CHECK: %[[CACHE:.*]] = acc.cache varPtr(%[[ARR_COORD]] : !fir.ref<!fir.array<50xf32>>) bounds(%[[BOUND]]) -> !fir.ref<!fir.array<50xf32>> {name = "obj%in%arr(i)"}
+! CHECK: %[[CACHE:.*]] = acc.cache varPtr(%[[ARR_COORD]] : !fir.ref<!fir.array<50xf32>>) bounds(%[[BOUND]]) -> !fir.ref<!fir.array<50xf32>> {name = "obj%in%arr(i)", structured = false}
 ! CHECK: acc.yield
 end subroutine



More information about the flang-commits mailing list