[flang-commits] [flang] a2a5cda - [flang][acc] Fix cache directive for derived type components (#176022)

via flang-commits flang-commits at lists.llvm.org
Tue Jan 20 01:14:57 PST 2026


Author: khaki3
Date: 2026-01-20T01:14:51-08:00
New Revision: a2a5cda0bcf0ad5893eacfea3daff90d597814b9

URL: https://github.com/llvm/llvm-project/commit/a2a5cda0bcf0ad5893eacfea3daff90d597814b9
DIFF: https://github.com/llvm/llvm-project/commit/a2a5cda0bcf0ad5893eacfea3daff90d597814b9.diff

LOG: [flang][acc] Fix cache directive for derived type components (#176022)

Fix the OpenACC cache directive so that it correctly handles derived-type
component designators (e.g., `data%array(i-4:i+4)`). The original
implementation assumed that every cache operand was a simple variable whose
symbol is present in the symbol map. For derived-type components, the
lowering now takes the component address from
`gatherDataOperandAddrAndBounds()` and registers the cached value with
`addComponentOverride()` so that subsequent accesses to the component use
the cached value.

Added: 
    

Modified: 
    flang/lib/Lower/OpenACC.cpp
    flang/test/Lower/OpenACC/acc-cache.f90

Removed: 
    


################################################################################
diff  --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index d4e1ace270b3f..1e313b20d464c 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -718,6 +718,31 @@ class AccDataMap {
 };
 } // namespace
 
+/// Extract the component reference from a designator expression, if any.
+/// For `data%array(i)`, this returns the Component for `data%array`.
+static std::optional<Fortran::evaluate::Component>
+extractComponentFromDesignator(
+    const Fortran::semantics::MaybeExpr &designator) {
+  if (!designator)
+    return std::nullopt;
+  std::optional<Fortran::evaluate::Component> componentRef;
+  if (std::optional<Fortran::evaluate::DataRef> dataRef =
+          Fortran::evaluate::ExtractDataRef(*designator)) {
+    Fortran::common::visit(
+        Fortran::common::visitors{
+            [&](const Fortran::evaluate::Component &component) {
+              componentRef = component;
+            },
+            [&](const Fortran::evaluate::ArrayRef &arrayRef) {
+              if (auto *comp = arrayRef.base().UnwrapComponent())
+                componentRef = *comp;
+            },
+            [](const auto &) {}},
+        dataRef->u);
+  }
+  return componentRef;
+}
+
 template <typename Op>
 static void
 genDataOperandOperations(const Fortran::parser::AccObjectList &objectList,
@@ -741,23 +766,10 @@ genDataOperandOperations(const Fortran::parser::AccObjectList &objectList,
 
     Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject);
 
-    std::optional<Fortran::evaluate::Component> componentRef;
     Fortran::semantics::MaybeExpr designator = Fortran::common::visit(
         [&](auto &&s) { return ea.Analyze(s); }, accObject.u);
-    if (std::optional<Fortran::evaluate::DataRef> dataRef =
-            Fortran::evaluate::ExtractDataRef(designator)) {
-      Fortran::common::visit(
-          Fortran::common::visitors{
-              [&](const Fortran::evaluate::Component &component) {
-                componentRef = component;
-              },
-              [&](const Fortran::evaluate::ArrayRef &arrayRef) {
-                if (auto *comp = arrayRef.base().UnwrapComponent())
-                  componentRef = *comp;
-              },
-              [](const auto &) {}},
-          dataRef->u);
-    }
+    std::optional<Fortran::evaluate::Component> componentRef =
+        extractComponentFromDesignator(designator);
 
     bool isPrivate = std::is_same_v<Op, mlir::acc::PrivateOp> ||
                      std::is_same_v<Op, mlir::acc::FirstprivateOp>;
@@ -4645,7 +4657,6 @@ genACC(Fortran::lower::AbstractConverter &converter,
   for (const auto &accObject : accObjectList.v) {
     mlir::Location operandLocation = genOperandLocation(converter, accObject);
     Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject);
-
     std::stringstream asFortran;
 
     Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext};
@@ -4653,19 +4664,16 @@ genACC(Fortran::lower::AbstractConverter &converter,
         [&](auto &&s) { return ea.Analyze(s); }, accObject.u);
 
     llvm::SmallVector<mlir::Value> bounds;
-    Fortran::lower::gatherDataOperandAddrAndBounds<mlir::acc::DataBoundsOp,
-                                                   mlir::acc::DataBoundsType>(
-        converter, builder, semanticsContext, stmtCtx, symbol, designator,
-        operandLocation, asFortran, bounds,
-        /*treatIndexAsSection=*/true, /*unwrapFirBox=*/false,
-        /*genDefaultBounds=*/false, /*strideIncludeLowerExtent=*/false,
-        /*loadAllocatableAndPointerComponent=*/false);
-
-    std::optional<fir::FortranVariableOpInterface> varDef =
-        converter.getSymbolMap().lookupVariableDefinition(symbol);
-    assert(varDef.has_value() && llvm::isa<hlfir::DeclareOp>(*varDef) &&
-           "expected symbol to be mapped to hlfir.declare");
-    mlir::Value base = varDef->getBase();
+    fir::factory::AddrAndBoundsInfo info =
+        Fortran::lower::gatherDataOperandAddrAndBounds<
+            mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>(
+            converter, builder, semanticsContext, stmtCtx, symbol, designator,
+            operandLocation, asFortran, bounds,
+            /*treatIndexAsSection=*/true, /*unwrapFirBox=*/false,
+            /*genDefaultBounds=*/false, /*strideIncludeLowerExtent=*/false,
+            /*loadAllocatableAndPointerComponent=*/false);
+
+    mlir::Value base = info.addr;
 
     mlir::acc::CacheOp cacheOp = createDataEntryOp<mlir::acc::CacheOp>(
         builder, operandLocation, base, asFortran, bounds,
@@ -4676,9 +4684,31 @@ genACC(Fortran::lower::AbstractConverter &converter,
         isReadonly ? mlir::acc::DataClauseModifier::readonly
                    : mlir::acc::DataClauseModifier::none);
 
-    fir::ExtendedValue hostExv = converter.getSymbolExtendedValue(symbol);
-    fir::ExtendedValue cacheExv = fir::substBase(hostExv, cacheOp.getAccVar());
-    converter.bindSymbol(symbol, cacheExv);
+    // Rebind the symbol so subsequent references use the cached value.
+    if (Fortran::lower::SymbolBox symBox =
+            converter.getSymbolMap().lookupSymbol(symbol)) {
+      // For simple variables, rebind the symbol directly.
+      fir::ExtendedValue hostExv = converter.getSymbolExtendedValue(symbol);
+      fir::ExtendedValue cacheExv =
+          fir::substBase(hostExv, cacheOp.getAccVar());
+      converter.bindSymbol(symbol, cacheExv);
+    } else {
+      // Must be a derived type component reference.
+      assert(designator && "expected designator for non-symbol cache operand");
+      std::optional<Fortran::evaluate::Component> componentRef =
+          extractComponentFromDesignator(designator);
+      assert(componentRef &&
+             "expected component reference for derived type cache operand");
+      // Component references are lowered to designate operations.
+      auto designate = base.getDefiningOp<hlfir::DesignateOp>();
+      assert(designate && "expected designate op for component reference");
+      auto declareOp = hlfir::DeclareOp::create(
+          builder, operandLocation, cacheOp.getAccVar(), asFortran.str(),
+          designate.getShape(), designate.getTypeparams(),
+          /*dummyScope=*/nullptr, /*storage=*/nullptr,
+          /*storageOffset=*/0, designate.getFortranAttrsAttr());
+      converter.getSymbolMap().addComponentOverride(*componentRef, declareOp);
+    }
   }
 }
 

diff  --git a/flang/test/Lower/OpenACC/acc-cache.f90 b/flang/test/Lower/OpenACC/acc-cache.f90
index 1cfe064993160..22dd0a84aee8a 100644
--- a/flang/test/Lower/OpenACC/acc-cache.f90
+++ b/flang/test/Lower/OpenACC/acc-cache.f90
@@ -560,3 +560,106 @@ subroutine test_cache_in_nested_do()
 ! CHECK: fir.do_loop
 ! CHECK: hlfir.designate %[[B_VAR]]#0
 end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_cache_derived_type()
+subroutine test_cache_derived_type()
+  type :: dt
+    real :: array(100)
+  end type
+
+  integer, parameter :: n = 100
+  type(dt) :: data
+  real :: a(n)
+  integer :: i
+
+  !$acc loop
+  do i = 5, n - 4
+    !$acc cache(data%array(i-4:i+4))
+    a(i) = data%array(i)
+  end do
+
+! CHECK: acc.loop
+! CHECK: %[[ARRAY_COORD:.*]] = hlfir.designate %{{.*}}{"array"} shape %{{.*}} : (!fir.ref<!fir.type<_QFtest_cache_derived_typeTdt{array:!fir.array<100xf32>}>>, !fir.shape<1>) -> !fir.ref<!fir.array<100xf32>>
+! CHECK: %[[BOUND:.*]] = acc.bounds lowerbound(%{{.*}} : index) upperbound(%{{.*}} : index) extent(%{{.*}} : index) stride(%{{.*}} : index) startIdx(%{{.*}} : index)
+! CHECK: %[[CACHE:.*]] = acc.cache varPtr(%[[ARRAY_COORD]] : !fir.ref<!fir.array<100xf32>>) bounds(%[[BOUND]]) -> !fir.ref<!fir.array<100xf32>> {name = "data%array(i-4_4:i+4_4)", structured = false}
+! CHECK: acc.yield
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_cache_derived_type_readonly()
+subroutine test_cache_derived_type_readonly()
+  type :: dt
+    real :: array(100)
+  end type
+
+  integer, parameter :: n = 100
+  type(dt) :: data
+  real :: a(n)
+  integer :: i
+
+  !$acc loop
+  do i = 5, n - 4
+    !$acc cache(readonly: data%array(i-4:i+4))
+    a(i) = data%array(i)
+  end do
+
+! CHECK: acc.loop
+! CHECK: %[[ARRAY_COORD:.*]] = hlfir.designate %{{.*}}{"array"} shape %{{.*}} : (!fir.ref<!fir.type<_QFtest_cache_derived_type_readonlyTdt{array:!fir.array<100xf32>}>>, !fir.shape<1>) -> !fir.ref<!fir.array<100xf32>>
+! CHECK: %[[BOUND:.*]] = acc.bounds lowerbound(%{{.*}} : index) upperbound(%{{.*}} : index) extent(%{{.*}} : index) stride(%{{.*}} : index) startIdx(%{{.*}} : index)
+! CHECK: %[[CACHE:.*]] = acc.cache varPtr(%[[ARRAY_COORD]] : !fir.ref<!fir.array<100xf32>>) bounds(%[[BOUND]]) -> !fir.ref<!fir.array<100xf32>> {modifiers = #acc<data_clause_modifier readonly>, name = "data%array(i-4_4:i+4_4)", structured = false}
+! CHECK: acc.yield
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_cache_nested_derived_type()
+subroutine test_cache_nested_derived_type()
+  type :: inner
+    real :: arr(50)
+  end type
+
+  type :: outer
+    type(inner) :: in
+  end type
+
+  integer, parameter :: n = 50
+  type(outer) :: obj
+  real :: a(n)
+  integer :: i
+
+  !$acc loop
+  do i = 1, n
+    !$acc cache(obj%in%arr(i))
+    a(i) = obj%in%arr(i)
+  end do
+
+! CHECK: acc.loop
+! CHECK: %[[IN_COORD:.*]] = hlfir.designate %{{.*}}{"in"} : (!fir.ref<!fir.type<_QFtest_cache_nested_derived_typeTouter{in:!fir.type<_QFtest_cache_nested_derived_typeTinner{arr:!fir.array<50xf32>}>}>>) -> !fir.ref<!fir.type<_QFtest_cache_nested_derived_typeTinner{arr:!fir.array<50xf32>}>>
+! CHECK: %[[ARR_COORD:.*]] = hlfir.designate %[[IN_COORD]]{"arr"} shape %{{.*}} : (!fir.ref<!fir.type<_QFtest_cache_nested_derived_typeTinner{arr:!fir.array<50xf32>}>>, !fir.shape<1>) -> !fir.ref<!fir.array<50xf32>>
+! CHECK: %[[BOUND:.*]] = acc.bounds lowerbound(%{{.*}} : index) upperbound(%{{.*}} : index) extent(%{{.*}} : index) stride(%{{.*}} : index) startIdx(%{{.*}} : index)
+! CHECK: %[[CACHE:.*]] = acc.cache varPtr(%[[ARR_COORD]] : !fir.ref<!fir.array<50xf32>>) bounds(%[[BOUND]]) -> !fir.ref<!fir.array<50xf32>> {name = "obj%in%arr(i)", structured = false}
+! CHECK: acc.yield
+end subroutine
+
+! Test cache with a temporary in the designator bounds: verifies that the local
+! statement context does not cause issues with cleanup of the temporary.
+! CHECK-LABEL: func.func @_QPtest_cache_temp_in_designator(
+subroutine test_cache_temp_in_designator(data, a)
+  integer, parameter :: n = 100
+  real :: data(n)
+  real :: a(n)
+  integer :: i
+
+  !$acc loop
+  do i = 5, n - 4
+    !$acc cache(readonly: data(1:maxloc(a+a, dim=1)))
+    a(i) = data(i)
+  end do
+
+! CHECK: acc.loop
+! CHECK: %[[ELEMENTAL:.*]] = hlfir.elemental
+! CHECK: %[[MAXLOC:.*]] = hlfir.maxloc %[[ELEMENTAL]]
+! CHECK: %[[BOUND:.*]] = acc.bounds lowerbound({{.*}}) upperbound({{.*}})
+! CHECK: %[[CACHE:.*]] = acc.cache varPtr(%{{.*}}) bounds(%[[BOUND]]) -> !fir.ref<!fir.array<100xf32>> {modifiers = #acc<data_clause_modifier readonly>, name = "data(1:maxloc(a+a,dim=1_4))", structured = false}
+! CHECK: %[[DECL:.*]]:2 = hlfir.declare %[[CACHE]]
+! CHECK: hlfir.destroy %[[ELEMENTAL]]
+! CHECK: hlfir.designate %[[DECL]]#0
+! CHECK: acc.yield
+end subroutine


        


More information about the flang-commits mailing list