[flang-commits] [flang] [mlir] [flang][acc] Improve acc lowering around fir.box and arrays (PR #125600)

via flang-commits flang-commits at lists.llvm.org
Mon Feb 3 15:57:33 PST 2025


llvmbot wrote:


@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-openacc

Author: Razvan Lupusoru (razvanlupusoru)

<details>
<summary>Changes</summary>

When the acc dialect was first introduced, the frontend's semantics had to be expanded explicitly. Specifically, OpenACC lowering included the following logic:
- Creation of `acc.bounds` operations for all arrays, including those whose dimensions are captured in the type (e.g., `!fir.array<100xf32>`).
- Explicit expansion of box types, placing only the box's address in the data clause. The address was extracted with a `fir.box_addr` operation, and the bounds were filled in using `fir.box_dims` operations.

However, with the new `MappableType` type interface, type-specific semantics can now be applied instead. This also considerably simplifies the IR representation. Consider the following example:
```
subroutine sub(arr)
  real :: arr(:)
  !$acc enter data copyin(arr)
end subroutine
```

Before this PR, the relevant acc dialect IR looked like this:
```
func.func @_QPsub(%arg0: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "arr"}) {
  ...
  %1:2 = hlfir.declare %arg0 dummy_scope %0 {uniq_name = "_QFsubEarr"} : (!fir.box<!fir.array<?xf32>>, !fir.dscope) -> (!fir.box<!fir.array<?xf32>>, !fir.box<!fir.array<?xf32>>)
  %c1 = arith.constant 1 : index
  %c0 = arith.constant 0 : index
  %2:3 = fir.box_dims %1#0, %c0 : (!fir.box<!fir.array<?xf32>>, index) -> (index, index, index)
  %c0_0 = arith.constant 0 : index
  %3 = arith.subi %2#1, %c1 : index
  %4 = acc.bounds lowerbound(%c0_0 : index) upperbound(%3 : index) extent(%2#1 : index) stride(%2#2 : index) startIdx(%c1 : index) {strideInBytes = true}
  %5 = fir.box_addr %1#0 : (!fir.box<!fir.array<?xf32>>) -> !fir.ref<!fir.array<?xf32>>
  %6 = acc.copyin varPtr(%5 : !fir.ref<!fir.array<?xf32>>) bounds(%4) -> !fir.ref<!fir.array<?xf32>> {name = "arr", structured = false}
  acc.enter_data dataOperands(%6 : !fir.ref<!fir.array<?xf32>>)
```

After this change, it looks like this:
```
func.func @_QPsub(%arg0: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "arr"}) {
  ...
  %1:2 = hlfir.declare %arg0 dummy_scope %0 {uniq_name = "_QFsubEarr"} : (!fir.box<!fir.array<?xf32>>, !fir.dscope) -> (!fir.box<!fir.array<?xf32>>, !fir.box<!fir.array<?xf32>>)
  %2 = acc.copyin var(%1#0 : !fir.box<!fir.array<?xf32>>) -> !fir.box<!fir.array<?xf32>> {name = "arr", structured = false}
  acc.enter_data dataOperands(%2 : !fir.box<!fir.array<?xf32>>)
```

The old behavior can be restored with the following command-line options:
`--openacc-unwrap-fir-box=true --openacc-generate-default-bounds=true`

---

Patch is 602.99 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/125600.diff


30 Files Affected:

- (modified) flang/include/flang/Lower/DirectivesCommon.h (+29-20) 
- (modified) flang/include/flang/Optimizer/Builder/DirectivesCommon.h (+3-2) 
- (modified) flang/lib/Lower/OpenACC.cpp (+278-170) 
- (modified) flang/test/Lower/OpenACC/acc-bounds.f90 (+1-1) 
- (added) flang/test/Lower/OpenACC/acc-data-operands-unwrap-defaultbounds.f90 (+152) 
- (modified) flang/test/Lower/OpenACC/acc-data-operands.f90 (+15-29) 
- (added) flang/test/Lower/OpenACC/acc-data-unwrap-defaultbounds.f90 (+205) 
- (modified) flang/test/Lower/OpenACC/acc-data.f90 (+53-53) 
- (added) flang/test/Lower/OpenACC/acc-declare-unwrap-defaultbounds.f90 (+476) 
- (modified) flang/test/Lower/OpenACC/acc-declare.f90 (+41-80) 
- (added) flang/test/Lower/OpenACC/acc-enter-data-unwrap-defaultbounds.f90 (+818) 
- (modified) flang/test/Lower/OpenACC/acc-enter-data.f90 (+82-247) 
- (added) flang/test/Lower/OpenACC/acc-exit-data-unwrap-defaultbounds.f90 (+107) 
- (modified) flang/test/Lower/OpenACC/acc-exit-data.f90 (+35-37) 
- (added) flang/test/Lower/OpenACC/acc-host-data-unwrap-defaultbounds.f90 (+52) 
- (modified) flang/test/Lower/OpenACC/acc-host-data.f90 (+4-8) 
- (modified) flang/test/Lower/OpenACC/acc-kernels-loop.f90 (+29-33) 
- (modified) flang/test/Lower/OpenACC/acc-kernels.f90 (+41-45) 
- (modified) flang/test/Lower/OpenACC/acc-loop.f90 (+1-1) 
- (modified) flang/test/Lower/OpenACC/acc-parallel-loop.f90 (+32-36) 
- (modified) flang/test/Lower/OpenACC/acc-parallel.f90 (+52-56) 
- (added) flang/test/Lower/OpenACC/acc-private-unwrap-defaultbounds.f90 (+403) 
- (modified) flang/test/Lower/OpenACC/acc-private.f90 (+32-44) 
- (added) flang/test/Lower/OpenACC/acc-reduction-unwrap-defaultbounds.f90 (+1213) 
- (modified) flang/test/Lower/OpenACC/acc-reduction.f90 (+113-88) 
- (modified) flang/test/Lower/OpenACC/acc-serial-loop.f90 (+33-37) 
- (modified) flang/test/Lower/OpenACC/acc-serial.f90 (+48-52) 
- (modified) flang/test/Lower/OpenACC/acc-update.f90 (+39-39) 
- (modified) mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td (+3-3) 
- (modified) mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp (+5-5) 


``````````diff
diff --git a/flang/include/flang/Lower/DirectivesCommon.h b/flang/include/flang/Lower/DirectivesCommon.h
index c7cac1357b22771..517bbff4f7515f5 100644
--- a/flang/include/flang/Lower/DirectivesCommon.h
+++ b/flang/include/flang/Lower/DirectivesCommon.h
@@ -584,10 +584,11 @@ void createEmptyRegionBlocks(
 inline fir::factory::AddrAndBoundsInfo
 getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
                        fir::FirOpBuilder &builder,
-                       Fortran::lower::SymbolRef sym, mlir::Location loc) {
+                       Fortran::lower::SymbolRef sym, mlir::Location loc,
+                       bool unwrapFirBox = true) {
   return fir::factory::getDataOperandBaseAddr(
       builder, converter.getSymbolAddress(sym),
-      Fortran::semantics::IsOptional(sym), loc);
+      Fortran::semantics::IsOptional(sym), loc, unwrapFirBox);
 }
 
 namespace detail {
@@ -880,13 +881,15 @@ fir::factory::AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
     Fortran::semantics::SymbolRef symbol,
     const Fortran::semantics::MaybeExpr &maybeDesignator,
     mlir::Location operandLocation, std::stringstream &asFortran,
-    llvm::SmallVector<mlir::Value> &bounds, bool treatIndexAsSection = false) {
+    llvm::SmallVector<mlir::Value> &bounds, bool treatIndexAsSection = false,
+    bool unwrapFirBox = true, bool genDefaultBounds = true) {
   using namespace Fortran;
 
   fir::factory::AddrAndBoundsInfo info;
 
   if (!maybeDesignator) {
-    info = getDataOperandBaseAddr(converter, builder, symbol, operandLocation);
+    info = getDataOperandBaseAddr(converter, builder, symbol, operandLocation,
+                                  unwrapFirBox);
     asFortran << symbol->name().ToString();
     return info;
   }
@@ -930,7 +933,8 @@ fir::factory::AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
       const semantics::Symbol &sym = arrayRef->GetLastSymbol();
       dataExvIsAssumedSize =
           Fortran::semantics::IsAssumedSizeArray(sym.GetUltimate());
-      info = getDataOperandBaseAddr(converter, builder, sym, operandLocation);
+      info = getDataOperandBaseAddr(converter, builder, sym, operandLocation,
+                                    unwrapFirBox);
       dataExv = converter.getSymbolExtendedValue(sym);
       asFortran << sym.name().ToString();
     }
@@ -947,7 +951,7 @@ fir::factory::AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
         converter.genExprAddr(operandLocation, designator, stmtCtx);
     info.addr = fir::getBase(compExv);
     info.rawInput = info.addr;
-    if (mlir::isa<fir::SequenceType>(fir::unwrapRefType(info.addr.getType())))
+    if (genDefaultBounds && mlir::isa<fir::SequenceType>(fir::unwrapRefType(info.addr.getType())))
       bounds = fir::factory::genBaseBoundsOps<BoundsOp, BoundsType>(
           builder, operandLocation, compExv,
           /*isAssumedSize=*/false);
@@ -958,14 +962,17 @@ fir::factory::AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
           operandLocation, builder.getI1Type(), info.rawInput);
     }
 
-    if (auto loadOp =
-            mlir::dyn_cast_or_null<fir::LoadOp>(info.addr.getDefiningOp())) {
-      if (fir::isAllocatableType(loadOp.getType()) ||
-          fir::isPointerType(loadOp.getType())) {
-        info.boxType = info.addr.getType();
-        info.addr = builder.create<fir::BoxAddrOp>(operandLocation, info.addr);
+    if (unwrapFirBox) {
+      if (auto loadOp =
+              mlir::dyn_cast_or_null<fir::LoadOp>(info.addr.getDefiningOp())) {
+        if (fir::isAllocatableType(loadOp.getType()) ||
+            fir::isPointerType(loadOp.getType())) {
+          info.boxType = info.addr.getType();
+          info.addr =
+              builder.create<fir::BoxAddrOp>(operandLocation, info.addr);
+        }
+        info.rawInput = info.addr;
       }
-      info.rawInput = info.addr;
     }
 
     // If the component is an allocatable or pointer the result of
@@ -977,8 +984,9 @@ fir::factory::AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
       info.addr = boxAddrOp.getVal();
       info.boxType = info.addr.getType();
       info.rawInput = info.addr;
-      bounds = fir::factory::genBoundsOpsFromBox<BoundsOp, BoundsType>(
-          builder, operandLocation, compExv, info);
+      if (genDefaultBounds)
+        bounds = fir::factory::genBoundsOpsFromBox<BoundsOp, BoundsType>(
+            builder, operandLocation, compExv, info);
     }
   } else {
     if (detail::getRef<evaluate::ArrayRef>(designator)) {
@@ -990,17 +998,18 @@ fir::factory::AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
     } else if (auto symRef = detail::getRef<semantics::SymbolRef>(designator)) {
       // Scalar or full array.
       fir::ExtendedValue dataExv = converter.getSymbolExtendedValue(*symRef);
-      info =
-          getDataOperandBaseAddr(converter, builder, *symRef, operandLocation);
-      if (mlir::isa<fir::BaseBoxType>(
-              fir::unwrapRefType(info.addr.getType()))) {
+      info = getDataOperandBaseAddr(converter, builder, *symRef,
+                                    operandLocation, unwrapFirBox);
+      if (genDefaultBounds && mlir::isa<fir::BaseBoxType>(
+                                  fir::unwrapRefType(info.addr.getType()))) {
         info.boxType = fir::unwrapRefType(info.addr.getType());
         bounds = fir::factory::genBoundsOpsFromBox<BoundsOp, BoundsType>(
             builder, operandLocation, dataExv, info);
       }
       bool dataExvIsAssumedSize =
           Fortran::semantics::IsAssumedSizeArray(symRef->get().GetUltimate());
-      if (mlir::isa<fir::SequenceType>(fir::unwrapRefType(info.addr.getType())))
+      if (genDefaultBounds &&
+          mlir::isa<fir::SequenceType>(fir::unwrapRefType(info.addr.getType())))
         bounds = fir::factory::genBaseBoundsOps<BoundsOp, BoundsType>(
             builder, operandLocation, dataExv, dataExvIsAssumedSize);
       asFortran << symRef->get().name().ToString();
diff --git a/flang/include/flang/Optimizer/Builder/DirectivesCommon.h b/flang/include/flang/Optimizer/Builder/DirectivesCommon.h
index 443b0ee59007fa2..4802e346a078eb1 100644
--- a/flang/include/flang/Optimizer/Builder/DirectivesCommon.h
+++ b/flang/include/flang/Optimizer/Builder/DirectivesCommon.h
@@ -54,7 +54,8 @@ struct AddrAndBoundsInfo {
 inline AddrAndBoundsInfo getDataOperandBaseAddr(fir::FirOpBuilder &builder,
                                                 mlir::Value symAddr,
                                                 bool isOptional,
-                                                mlir::Location loc) {
+                                                mlir::Location loc,
+                                                bool unwrapFirBox = true) {
   mlir::Value rawInput = symAddr;
   if (auto declareOp =
           mlir::dyn_cast_or_null<hlfir::DeclareOp>(symAddr.getDefiningOp())) {
@@ -80,7 +81,7 @@ inline AddrAndBoundsInfo getDataOperandBaseAddr(fir::FirOpBuilder &builder,
     // all address/dimension retrievals. For Fortran optional though, leave
     // the load generation for later so it can be done in the appropriate
     // if branches.
-    if (mlir::isa<fir::ReferenceType>(symAddr.getType()) && !isOptional) {
+    if (unwrapFirBox && mlir::isa<fir::ReferenceType>(symAddr.getType()) && !isOptional) {
       mlir::Value addr = builder.create<fir::LoadOp>(loc, symAddr);
       return AddrAndBoundsInfo(addr, rawInput, isPresent, boxTy);
     }
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 456c30264d0684d..96f21d3c7474f0b 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -26,17 +26,32 @@
 #include "flang/Optimizer/Builder/HLFIRTools.h"
 #include "flang/Optimizer/Builder/IntrinsicCall.h"
 #include "flang/Optimizer/Builder/Todo.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
 #include "flang/Parser/parse-tree-visitor.h"
 #include "flang/Parser/parse-tree.h"
 #include "flang/Semantics/expression.h"
 #include "flang/Semantics/scope.h"
 #include "flang/Semantics/tools.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/Frontend/OpenACC/ACC.h.inc"
+#include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 
 #define DEBUG_TYPE "flang-lower-openacc"
 
+static llvm::cl::opt<bool> unwrapFirBox(
+    "openacc-unwrap-fir-box",
+    llvm::cl::desc(
+        "Whether to use the address from fir.box in data clause operations."),
+    llvm::cl::init(false));
+
+static llvm::cl::opt<bool> generateDefaultBounds(
+    "openacc-generate-default-bounds",
+    llvm::cl::desc("Whether to generate default bounds for arrays."),
+    llvm::cl::init(false));
+
 // Special value for * passed in device_type or gang clauses.
 static constexpr std::int64_t starCst = -1;
 
@@ -94,8 +109,9 @@ createDataEntryOp(fir::FirOpBuilder &builder, mlir::Location loc,
   // The data clause may apply to either the box reference itself or the
   // pointer to the data it holds. So use `unwrapBoxAddr` to decide.
   // When we have a box value - assume it refers to the data inside box.
-  if ((fir::isBoxAddress(baseAddr.getType()) && unwrapBoxAddr) ||
-      fir::isa_box_type(baseAddr.getType())) {
+  if (unwrapFirBox &&
+      ((fir::isBoxAddress(baseAddr.getType()) && unwrapBoxAddr) ||
+       fir::isa_box_type(baseAddr.getType()))) {
     if (isPresent) {
       mlir::Type ifRetTy =
           mlir::cast<fir::BaseBoxType>(fir::unwrapRefType(baseAddr.getType()))
@@ -140,8 +156,16 @@ createDataEntryOp(fir::FirOpBuilder &builder, mlir::Location loc,
   op.setStructured(structured);
   op.setImplicit(implicit);
   op.setDataClause(dataClause);
-  op.setVarType(mlir::cast<mlir::acc::PointerLikeType>(baseAddr.getType())
-                    .getElementType());
+  if (auto mappableTy =
+          mlir::dyn_cast<mlir::acc::MappableType>(baseAddr.getType())) {
+    op.setVarType(baseAddr.getType());
+  } else {
+    assert(mlir::isa<mlir::acc::PointerLikeType>(baseAddr.getType()) &&
+           "expected pointer-like");
+    op.setVarType(mlir::cast<mlir::acc::PointerLikeType>(baseAddr.getType())
+                      .getElementType());
+  }
+
   op->setAttr(Op::getOperandSegmentSizeAttr(),
               builder.getDenseI32ArrayAttr(operandSegments));
   if (!asyncDeviceTypes.empty())
@@ -208,7 +232,9 @@ static void createDeclareAllocFuncWithArg(mlir::OpBuilder &modBuilder,
 
   llvm::SmallVector<mlir::Value> bounds;
   std::stringstream asFortranDesc;
-  asFortranDesc << asFortran.str() << accFirDescriptorPostfix.str();
+  asFortranDesc << asFortran.str();
+  if (unwrapFirBox)
+    asFortranDesc << accFirDescriptorPostfix.str();
 
   // Updating descriptor must occur before the mapping of the data so that
   // attached data pointer is not overwritten.
@@ -222,17 +248,19 @@ static void createDeclareAllocFuncWithArg(mlir::OpBuilder &modBuilder,
   llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
   createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
 
-  mlir::Value desc =
-      builder.create<fir::LoadOp>(loc, registerFuncOp.getArgument(0));
-  fir::BoxAddrOp boxAddrOp = builder.create<fir::BoxAddrOp>(loc, desc);
-  addDeclareAttr(builder, boxAddrOp.getOperation(), clause);
-  EntryOp entryOp = createDataEntryOp<EntryOp>(
-      builder, loc, boxAddrOp.getResult(), asFortran, bounds,
-      /*structured=*/false, /*implicit=*/false, clause, boxAddrOp.getType(),
-      /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{});
-  builder.create<mlir::acc::DeclareEnterOp>(
-      loc, mlir::acc::DeclareTokenType::get(entryOp.getContext()),
-      mlir::ValueRange(entryOp.getAccPtr()));
+  if (unwrapFirBox) {
+    mlir::Value desc =
+        builder.create<fir::LoadOp>(loc, registerFuncOp.getArgument(0));
+    fir::BoxAddrOp boxAddrOp = builder.create<fir::BoxAddrOp>(loc, desc);
+    addDeclareAttr(builder, boxAddrOp.getOperation(), clause);
+    EntryOp entryOp = createDataEntryOp<EntryOp>(
+        builder, loc, boxAddrOp.getResult(), asFortran, bounds,
+        /*structured=*/false, /*implicit=*/false, clause, boxAddrOp.getType(),
+        /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{});
+    builder.create<mlir::acc::DeclareEnterOp>(
+        loc, mlir::acc::DeclareTokenType::get(entryOp.getContext()),
+        mlir::ValueRange(entryOp.getAccVar()));
+  }
 
   modBuilder.setInsertionPointAfter(registerFuncOp);
   builder.restoreInsertionPoint(crtInsPt);
@@ -252,31 +280,36 @@ static void createDeclareDeallocFuncWithArg(
     descTy = fir::ReferenceType::get(descTy);
   auto preDeallocOp = createDeclareFunc(
       modBuilder, builder, loc, preDeallocFuncName.str(), {descTy}, {loc});
-  mlir::Value loadOp =
-      builder.create<fir::LoadOp>(loc, preDeallocOp.getArgument(0));
-  fir::BoxAddrOp boxAddrOp = builder.create<fir::BoxAddrOp>(loc, loadOp);
-  addDeclareAttr(builder, boxAddrOp.getOperation(), clause);
+
+  mlir::Value var = preDeallocOp.getArgument(0);
+  if (unwrapFirBox) {
+    mlir::Value loadOp =
+        builder.create<fir::LoadOp>(loc, preDeallocOp.getArgument(0));
+    fir::BoxAddrOp boxAddrOp = builder.create<fir::BoxAddrOp>(loc, loadOp);
+    addDeclareAttr(builder, boxAddrOp.getOperation(), clause);
+    var = boxAddrOp.getResult();
+  }
 
   llvm::SmallVector<mlir::Value> bounds;
   mlir::acc::GetDevicePtrOp entryOp =
       createDataEntryOp<mlir::acc::GetDevicePtrOp>(
-          builder, loc, boxAddrOp.getResult(), asFortran, bounds,
-          /*structured=*/false, /*implicit=*/false, clause, boxAddrOp.getType(),
+          builder, loc, var, asFortran, bounds,
+          /*structured=*/false, /*implicit=*/false, clause, var.getType(),
           /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{});
   builder.create<mlir::acc::DeclareExitOp>(
-      loc, mlir::Value{}, mlir::ValueRange(entryOp.getAccPtr()));
+      loc, mlir::Value{}, mlir::ValueRange(entryOp.getAccVar()));
 
   if constexpr (std::is_same_v<ExitOp, mlir::acc::CopyoutOp> ||
                 std::is_same_v<ExitOp, mlir::acc::UpdateHostOp>)
-    builder.create<ExitOp>(entryOp.getLoc(), entryOp.getAccPtr(),
-                           entryOp.getVarPtr(), entryOp.getVarType(),
+    builder.create<ExitOp>(entryOp.getLoc(), entryOp.getAccVar(),
+                           entryOp.getVar(), entryOp.getVarType(),
                            entryOp.getBounds(), entryOp.getAsyncOperands(),
                            entryOp.getAsyncOperandsDeviceTypeAttr(),
                            entryOp.getAsyncOnlyAttr(), entryOp.getDataClause(),
                            /*structured=*/false, /*implicit=*/false,
                            builder.getStringAttr(*entryOp.getName()));
   else
-    builder.create<ExitOp>(entryOp.getLoc(), entryOp.getAccPtr(),
+    builder.create<ExitOp>(entryOp.getLoc(), entryOp.getAccVar(),
                            entryOp.getBounds(), entryOp.getAsyncOperands(),
                            entryOp.getAsyncOperandsDeviceTypeAttr(),
                            entryOp.getAsyncOnlyAttr(), entryOp.getDataClause(),
@@ -290,13 +323,18 @@ static void createDeclareDeallocFuncWithArg(
                       << Fortran::lower::declarePostDeallocSuffix.str();
   auto postDeallocOp = createDeclareFunc(
       modBuilder, builder, loc, postDeallocFuncName.str(), {descTy}, {loc});
-  loadOp = builder.create<fir::LoadOp>(loc, postDeallocOp.getArgument(0));
-  asFortran << accFirDescriptorPostfix.str();
+
+  var = postDeallocOp.getArgument(0);
+  if (unwrapFirBox) {
+    var = builder.create<fir::LoadOp>(loc, postDeallocOp.getArgument(0));
+    asFortran << accFirDescriptorPostfix.str();
+  }
+
   mlir::acc::UpdateDeviceOp updateDeviceOp =
       createDataEntryOp<mlir::acc::UpdateDeviceOp>(
-          builder, loc, loadOp, asFortran, bounds,
+          builder, loc, var, asFortran, bounds,
           /*structured=*/false, /*implicit=*/true,
-          mlir::acc::DataClause::acc_update_device, loadOp.getType(),
+          mlir::acc::DataClause::acc_update_device, var.getType(),
           /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{});
   llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 1};
   llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
@@ -357,7 +395,8 @@ genDataOperandOperations(const Fortran::parser::AccObjectList &objectList,
             mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>(
             converter, builder, semanticsContext, stmtCtx, symbol, designator,
             operandLocation, asFortran, bounds,
-            /*treatIndexAsSection=*/true);
+            /*treatIndexAsSection=*/true, /*unwrapFirBox=*/unwrapFirBox,
+            /*genDefaultBounds=*/generateDefaultBounds);
     LLVM_DEBUG(llvm::dbgs() << __func__ << "\n"; info.dump(llvm::dbgs()));
 
     // If the input value is optional and is not a descriptor, we use the
@@ -371,7 +410,7 @@ genDataOperandOperations(const Fortran::parser::AccObjectList &objectList,
         builder, operandLocation, baseAddr, asFortran, bounds, structured,
         implicit, dataClause, baseAddr.getType(), async, asyncDeviceTypes,
         asyncOnlyDeviceTypes, /*unwrapBoxAddr=*/true, info.isPresent);
-    dataOperands.push_back(op.getAccPtr());
+    dataOperands.push_back(op.getAccVar());
   }
 }
 
@@ -396,14 +435,16 @@ static void genDeclareDataOperandOperations(
         Fortran::lower::gatherDataOperandAddrAndBounds<
             mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>(
             converter, builder, semanticsContext, stmtCtx, symbol, designator,
-            operandLocation, asFortran, bounds);
+            operandLocation, asFortran, bounds,
+            /*treatIndexAsSection=*/true, /*unwrapFirBox=*/unwrapFirBox,
+            /*genDefaultBounds=*/generateDefaultBounds);
     LLVM_DEBUG(llvm::dbgs() << __func__ << "\n"; info.dump(llvm::dbgs()));
     EntryOp op = createDataEntryOp<EntryOp>(
         builder, operandLocation, info.addr, asFortran, bounds, structured,
         implicit, dataClause, info.addr.getType(),
         /*async=*/{}, /*asyncDeviceTypes=*/{}, /*asyncOnlyDeviceTypes=*/{});
-    dataOperands.push_back(op.getAccPtr());
-    addDeclareAttr(builder, op.getVarPtr().getDefiningOp(), dataClause);
+    dataOperands.push_back(op.getAccVar());
+    addDeclareAttr(builder, op.getVar().getDefiningOp(), dataClause);
     if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(info.addr.getType()))) {
       mlir::OpBuilder modBuilder(builder.getModule().getBodyRegion());
       modBuilder.setInsertionPointAfter(builder.getFunction());
@@ -452,14 +493,14 @@ static void genDataExitOperations(fir::FirOpBuilder &builder,
     if constexpr (std::is_same_v<ExitOp, mlir::acc::CopyoutOp> ||
                   std::is_same_v<ExitOp, mlir::acc::UpdateHostOp>)
       builder.create<ExitOp>(
-          entryOp.getLoc(), entryOp.getAccPtr(), entryOp.getVarPtr(),
+          entryOp.getLoc(), entryOp.getAccVar(), entryOp.getVar(),
           entryOp.getVarType(), entryOp.getBounds(), entryOp.getAsyncOperands(),
           entryOp.getAsyncOperandsDeviceTypeAttr(), entryOp.getAsyncOnlyAttr(),
           entryOp.getDataClause(), structured, entryOp.getImplicit(),
           builder.getStringAttr(*entryOp.getName()));
     else
       builder.create<ExitOp>(
-          entryOp.getLoc(), entryOp.getAccPtr(), entryOp.getBounds(),
+          entryOp.getLoc(), entryOp.getAccVar(), entryOp.getBounds(),
           entryOp.getAsyncOperands(), entryOp.getAsyncOperandsDeviceTypeAttr(),
           entryOp.getAsyncOnlyAttr(), entryOp.getDataClause(), structured,
           entryOp.getImplicit(), builder.getStringAttr(*entryOp.getName()));
@@ -480,41 +521,38 @@ template <typename RecipeOp>
 static void genPrivateLikeInitRegion(mlir::OpBuilder &builder, RecipeOp recipe,
                                      mlir::Type ty, mlir::Location loc) {
   mlir::Value retVal = recipe.getInitRegion().front().getArgument(0);
-  if (auto refTy = mlir::dyn_cast_or_null<fir::ReferenceType>(ty)) {
-    if (fir::isa_trivial(refTy.getEleTy())) {
-      auto alloca = builder.create<fir::AllocaOp>(loc, refTy.getEleTy());
+  ty = fir::unwrapRefType(ty);
+  if (fir::isa_trivial(ty)) {
+    auto alloca = builder.create<fir::AllocaOp>(loc, ty);
+    auto declareOp = builder.create<hlfir::DeclareOp>(
+        loc, alloca, accPrivateInitName, /*shape=*/nullp...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/125600

