[flang-commits] [flang] [flang][openacc] Use original input for base address with optional (PR #80931)

via flang-commits flang-commits at lists.llvm.org
Tue Feb 6 19:52:09 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-openacc

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

<details>
<summary>Changes</summary>

In #80317 the data op generation was updated to correctly use the #0 result from the hlfir.declare op. For optional arguments that are not descriptors, it is preferable to use the original input as the varPtr value of the OpenACC data op.
This patch also makes sure that the descriptor value of an optional is only accessed when it is present.

---
Full diff: https://github.com/llvm/llvm-project/pull/80931.diff


3 Files Affected:

- (modified) flang/lib/Lower/DirectivesCommon.h (+73-20) 
- (modified) flang/lib/Lower/OpenACC.cpp (+15-5) 
- (modified) flang/test/Lower/OpenACC/acc-bounds.f90 (+36-2) 


``````````diff
diff --git a/flang/lib/Lower/DirectivesCommon.h b/flang/lib/Lower/DirectivesCommon.h
index bd880376517dd..8d560db34e05b 100644
--- a/flang/lib/Lower/DirectivesCommon.h
+++ b/flang/lib/Lower/DirectivesCommon.h
@@ -52,10 +52,13 @@ namespace lower {
 /// operations.
 struct AddrAndBoundsInfo {
   explicit AddrAndBoundsInfo() {}
-  explicit AddrAndBoundsInfo(mlir::Value addr) : addr(addr) {}
-  explicit AddrAndBoundsInfo(mlir::Value addr, mlir::Value isPresent)
-      : addr(addr), isPresent(isPresent) {}
+  explicit AddrAndBoundsInfo(mlir::Value addr, mlir::Value rawInput)
+      : addr(addr), rawInput(rawInput) {}
+  explicit AddrAndBoundsInfo(mlir::Value addr, mlir::Value rawInput,
+                             mlir::Value isPresent)
+      : addr(addr), rawInput(rawInput), isPresent(isPresent) {}
   mlir::Value addr = nullptr;
+  mlir::Value rawInput = nullptr;
   mlir::Value isPresent = nullptr;
 };
 
@@ -615,20 +618,30 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
                        fir::FirOpBuilder &builder,
                        Fortran::lower::SymbolRef sym, mlir::Location loc) {
   mlir::Value symAddr = converter.getSymbolAddress(sym);
+  mlir::Value rawInput = symAddr;
   if (auto declareOp =
-          mlir::dyn_cast_or_null<hlfir::DeclareOp>(symAddr.getDefiningOp()))
+          mlir::dyn_cast_or_null<hlfir::DeclareOp>(symAddr.getDefiningOp())) {
     symAddr = declareOp.getResults()[0];
+    rawInput = declareOp.getResults()[1];
+  }
 
   // TODO: Might need revisiting to handle for non-shared clauses
   if (!symAddr) {
     if (const auto *details =
-            sym->detailsIf<Fortran::semantics::HostAssocDetails>())
+            sym->detailsIf<Fortran::semantics::HostAssocDetails>()) {
       symAddr = converter.getSymbolAddress(details->symbol());
+      rawInput = symAddr;
+    }
   }
 
   if (!symAddr)
     llvm::report_fatal_error("could not retrieve symbol address");
 
+  mlir::Value isPresent;
+  if (Fortran::semantics::IsOptional(sym))
+    isPresent =
+        builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), rawInput);
+
   if (auto boxTy =
           fir::unwrapRefType(symAddr.getType()).dyn_cast<fir::BaseBoxType>()) {
     if (boxTy.getEleTy().isa<fir::RecordType>())
@@ -638,8 +651,6 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
     // `fir.ref<fir.class<T>>` type.
     if (symAddr.getType().isa<fir::ReferenceType>()) {
       if (Fortran::semantics::IsOptional(sym)) {
-        mlir::Value isPresent =
-            builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), symAddr);
         mlir::Value addr =
             builder.genIfOp(loc, {boxTy}, isPresent, /*withElseRegion=*/true)
                 .genThen([&]() {
@@ -652,14 +663,13 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
                   builder.create<fir::ResultOp>(loc, mlir::ValueRange{absent});
                 })
                 .getResults()[0];
-        return AddrAndBoundsInfo(addr, isPresent);
+        return AddrAndBoundsInfo(addr, rawInput, isPresent);
       }
       mlir::Value addr = builder.create<fir::LoadOp>(loc, symAddr);
-      return AddrAndBoundsInfo(addr);
-      ;
+      return AddrAndBoundsInfo(addr, rawInput, isPresent);
     }
   }
-  return AddrAndBoundsInfo(symAddr);
+  return AddrAndBoundsInfo(symAddr, rawInput, isPresent);
 }
 
 template <typename BoundsOp, typename BoundsType>
@@ -807,7 +817,7 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
              Fortran::lower::StatementContext &stmtCtx,
              const std::list<Fortran::parser::SectionSubscript> &subscripts,
              std::stringstream &asFortran, fir::ExtendedValue &dataExv,
-             bool dataExvIsAssumedSize, mlir::Value baseAddr,
+             bool dataExvIsAssumedSize, AddrAndBoundsInfo &info,
              bool treatIndexAsSection = false) {
   int dimension = 0;
   mlir::Type idxTy = builder.getIndexType();
@@ -831,11 +841,30 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
       mlir::Value stride = one;
       bool strideInBytes = false;
 
-      if (fir::unwrapRefType(baseAddr.getType()).isa<fir::BaseBoxType>()) {
-        mlir::Value d = builder.createIntegerConstant(loc, idxTy, dimension);
-        auto dimInfo = builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy,
-                                                      baseAddr, d);
-        stride = dimInfo.getByteStride();
+      if (fir::unwrapRefType(info.addr.getType()).isa<fir::BaseBoxType>()) {
+        if (info.isPresent) {
+          stride =
+              builder
+                  .genIfOp(loc, idxTy, info.isPresent, /*withElseRegion=*/true)
+                  .genThen([&]() {
+                    mlir::Value d =
+                        builder.createIntegerConstant(loc, idxTy, dimension);
+                    auto dimInfo = builder.create<fir::BoxDimsOp>(
+                        loc, idxTy, idxTy, idxTy, info.addr, d);
+                    builder.create<fir::ResultOp>(loc, dimInfo.getByteStride());
+                  })
+                  .genElse([&] {
+                    mlir::Value zero =
+                        builder.createIntegerConstant(loc, idxTy, 0);
+                    builder.create<fir::ResultOp>(loc, zero);
+                  })
+                  .getResults()[0];
+        } else {
+          mlir::Value d = builder.createIntegerConstant(loc, idxTy, dimension);
+          auto dimInfo = builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy,
+                                                        idxTy, info.addr, d);
+          stride = dimInfo.getByteStride();
+        }
         strideInBytes = true;
       }
 
@@ -919,7 +948,26 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
           }
         }
 
-        extent = fir::factory::readExtent(builder, loc, dataExv, dimension);
+        if (info.isPresent &&
+            fir::unwrapRefType(info.addr.getType()).isa<fir::BaseBoxType>()) {
+          extent =
+              builder
+                  .genIfOp(loc, idxTy, info.isPresent, /*withElseRegion=*/true)
+                  .genThen([&]() {
+                    mlir::Value ext = fir::factory::readExtent(
+                        builder, loc, dataExv, dimension);
+                    builder.create<fir::ResultOp>(loc, ext);
+                  })
+                  .genElse([&] {
+                    mlir::Value zero =
+                        builder.createIntegerConstant(loc, idxTy, 0);
+                    builder.create<fir::ResultOp>(loc, zero);
+                  })
+                  .getResults()[0];
+        } else {
+          extent = fir::factory::readExtent(builder, loc, dataExv, dimension);
+        }
+
         if (dataExvIsAssumedSize && dimension + 1 == dataExvRank) {
           extent = zero;
           if (ubound && lbound) {
@@ -976,6 +1024,7 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
                   dataExv = converter.genExprAddr(operandLocation, *exprBase,
                                                   stmtCtx);
                   info.addr = fir::getBase(dataExv);
+                  info.rawInput = info.addr;
                   asFortran << (*exprBase).AsFortran();
                 } else {
                   const Fortran::parser::Name &name =
@@ -993,7 +1042,7 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
                   bounds = genBoundsOps<BoundsOp, BoundsType>(
                       builder, operandLocation, converter, stmtCtx,
                       arrayElement->subscripts, asFortran, dataExv,
-                      dataExvIsAssumedSize, info.addr, treatIndexAsSection);
+                      dataExvIsAssumedSize, info, treatIndexAsSection);
                 }
                 asFortran << ')';
               } else if (auto structComp = Fortran::parser::Unwrap<
@@ -1001,6 +1050,7 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
                 fir::ExtendedValue compExv =
                     converter.genExprAddr(operandLocation, *expr, stmtCtx);
                 info.addr = fir::getBase(compExv);
+                info.rawInput = info.addr;
                 if (fir::unwrapRefType(info.addr.getType())
                         .isa<fir::SequenceType>())
                   bounds = genBaseBoundsOps<BoundsOp, BoundsType>(
@@ -1012,7 +1062,7 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
                     *Fortran::parser::GetLastName(*structComp).symbol);
                 if (isOptional)
                   info.isPresent = builder.create<fir::IsPresentOp>(
-                      operandLocation, builder.getI1Type(), info.addr);
+                      operandLocation, builder.getI1Type(), info.rawInput);
 
                 if (auto loadOp = mlir::dyn_cast_or_null<fir::LoadOp>(
                         info.addr.getDefiningOp())) {
@@ -1020,6 +1070,7 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
                       fir::isPointerType(loadOp.getType()))
                     info.addr = builder.create<fir::BoxAddrOp>(operandLocation,
                                                                info.addr);
+                  info.rawInput = info.addr;
                 }
 
                 // If the component is an allocatable or pointer the result of
@@ -1029,6 +1080,7 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
                 if (auto boxAddrOp = mlir::dyn_cast_or_null<fir::BoxAddrOp>(
                         info.addr.getDefiningOp())) {
                   info.addr = boxAddrOp.getVal();
+                  info.rawInput = info.addr;
                   bounds = genBoundsOpsFromBox<BoundsOp, BoundsType>(
                       builder, operandLocation, converter, compExv, info);
                 }
@@ -1043,6 +1095,7 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
                   fir::ExtendedValue compExv =
                       converter.genExprAddr(operandLocation, *expr, stmtCtx);
                   info.addr = fir::getBase(compExv);
+                  info.rawInput = info.addr;
                   asFortran << (*expr).AsFortran();
                 } else if (const auto *dataRef{
                                std::get_if<Fortran::parser::DataRef>(
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 43f54c6d2a71b..6ae270f63f5cf 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -67,9 +67,12 @@ static Op createDataEntryOp(fir::FirOpBuilder &builder, mlir::Location loc,
   mlir::Value varPtrPtr;
   if (auto boxTy = baseAddr.getType().dyn_cast<fir::BaseBoxType>()) {
     if (isPresent) {
+      mlir::Type ifRetTy = boxTy.getEleTy();
+      if (!fir::isa_ref_type(ifRetTy))
+        ifRetTy = fir::ReferenceType::get(ifRetTy);
       baseAddr =
           builder
-              .genIfOp(loc, {boxTy.getEleTy()}, isPresent,
+              .genIfOp(loc, {ifRetTy}, isPresent,
                        /*withElseRegion=*/true)
               .genThen([&]() {
                 mlir::Value boxAddr =
@@ -78,7 +81,7 @@ static Op createDataEntryOp(fir::FirOpBuilder &builder, mlir::Location loc,
               })
               .genElse([&] {
                 mlir::Value absent =
-                    builder.create<fir::AbsentOp>(loc, boxTy.getEleTy());
+                    builder.create<fir::AbsentOp>(loc, ifRetTy);
                 builder.create<fir::ResultOp>(loc, mlir::ValueRange{absent});
               })
               .getResults()[0];
@@ -295,9 +298,16 @@ genDataOperandOperations(const Fortran::parser::AccObjectList &objectList,
                                        asFortran, bounds,
                                        /*treatIndexAsSection=*/true);
 
-    Op op = createDataEntryOp<Op>(
-        builder, operandLocation, info.addr, asFortran, bounds, structured,
-        implicit, dataClause, info.addr.getType(), info.isPresent);
+    // If the input value is optional and is not a descriptor, we use the
+    // rawInput directly.
+    mlir::Value baseAddr =
+        ((info.addr.getType() != fir::unwrapRefType(info.rawInput.getType())) &&
+         info.isPresent)
+            ? info.rawInput
+            : info.addr;
+    Op op = createDataEntryOp<Op>(builder, operandLocation, baseAddr, asFortran,
+                                  bounds, structured, implicit, dataClause,
+                                  baseAddr.getType(), info.isPresent);
     dataOperands.push_back(op.getAccPtr());
   }
 }
diff --git a/flang/test/Lower/OpenACC/acc-bounds.f90 b/flang/test/Lower/OpenACC/acc-bounds.f90
index bd96bc8bcba35..df97cbcd187d2 100644
--- a/flang/test/Lower/OpenACC/acc-bounds.f90
+++ b/flang/test/Lower/OpenACC/acc-bounds.f90
@@ -126,8 +126,8 @@ subroutine acc_optional_data(a)
   
 ! CHECK-LABEL: func.func @_QMopenacc_boundsPacc_optional_data(
 ! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>> {fir.bindc_name = "a", fir.optional}) {
-! CHECK: %[[ARG0_DECL:.*]]:2 = hlfir.declare %arg0 {fortran_attrs = #fir.var_attrs<optional, pointer>, uniq_name = "_QMopenacc_boundsFacc_optional_dataEa"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>)
-! CHECK: %[[IS_PRESENT:.*]] = fir.is_present %[[ARG0_DECL]]#0 : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> i1
+! CHECK: %[[ARG0_DECL:.*]]:2 = hlfir.declare %[[ARG0]] {fortran_attrs = #fir.var_attrs<optional, pointer>, uniq_name = "_QMopenacc_boundsFacc_optional_dataEa"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>)
+! CHECK: %[[IS_PRESENT:.*]] = fir.is_present %[[ARG0_DECL]]#1 : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> i1
 ! CHECK: %[[BOX:.*]] = fir.if %[[IS_PRESENT]] -> (!fir.box<!fir.ptr<!fir.array<?xf32>>>) {
 ! CHECK:   %[[LOAD:.*]] = fir.load %[[ARG0_DECL]]#0 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
 ! CHECK:   fir.result %[[LOAD]] : !fir.box<!fir.ptr<!fir.array<?xf32>>>
@@ -153,4 +153,38 @@ subroutine acc_optional_data(a)
 ! CHECK: %[[ATTACH:.*]] = acc.attach varPtr(%[[BOX_ADDR]] : !fir.ptr<!fir.array<?xf32>>) bounds(%[[BOUND]]) -> !fir.ptr<!fir.array<?xf32>> {name = "a"}
 ! CHECK: acc.data dataOperands(%[[ATTACH]] : !fir.ptr<!fir.array<?xf32>>)
 
+  subroutine acc_optional_data2(a, n)
+    integer :: n
+    real, optional :: a(n)
+    !$acc data no_create(a)
+    !$acc end data
+  end subroutine
+
+! CHECK-LABEL: func.func @_QMopenacc_boundsPacc_optional_data2(
+! CHECK-SAME: %[[A:.*]]: !fir.ref<!fir.array<?xf32>> {fir.bindc_name = "a", fir.optional}, %[[N:.*]]: !fir.ref<i32> {fir.bindc_name = "n"}) {
+! CHECK: %[[DECL_A:.*]]:2 = hlfir.declare %[[A]](%{{.*}}) {fortran_attrs = #fir.var_attrs<optional>, uniq_name = "_QMopenacc_boundsFacc_optional_data2Ea"} : (!fir.ref<!fir.array<?xf32>>, !fir.shape<1>) -> (!fir.box<!fir.array<?xf32>>, !fir.ref<!fir.array<?xf32>>)
+! CHECK: %[[NO_CREATE:.*]] = acc.nocreate varPtr(%[[DECL_A]]#1 : !fir.ref<!fir.array<?xf32>>) bounds(%10) -> !fir.ref<!fir.array<?xf32>> {name = "a"}
+! CHECK: acc.data dataOperands(%[[NO_CREATE]] : !fir.ref<!fir.array<?xf32>>) {
+
+  subroutine acc_optional_data3(a, n)
+    integer :: n
+    real, optional :: a(n)
+    !$acc data no_create(a(1:n))
+    !$acc end data
+  end subroutine
+
+! CHECK-LABEL: func.func @_QMopenacc_boundsPacc_optional_data3(
+! CHECK-SAME: %[[A:.*]]: !fir.ref<!fir.array<?xf32>> {fir.bindc_name = "a", fir.optional}, %[[N:.*]]: !fir.ref<i32> {fir.bindc_name = "n"}) {
+! CHECK: %[[DECL_A:.*]]:2 = hlfir.declare %[[A]](%{{.*}}) {fortran_attrs = #fir.var_attrs<optional>, uniq_name = "_QMopenacc_boundsFacc_optional_data3Ea"} : (!fir.ref<!fir.array<?xf32>>, !fir.shape<1>) -> (!fir.box<!fir.array<?xf32>>, !fir.ref<!fir.array<?xf32>>)
+! CHECK: %[[PRES:.*]] = fir.is_present %[[DECL_A]]#1 : (!fir.ref<!fir.array<?xf32>>) -> i1
+! CHECK: %[[STRIDE:.*]] = fir.if %[[PRES]] -> (index) {
+! CHECK:   %[[DIMS:.*]]:3 = fir.box_dims %[[DECL_A]]#0, %c0{{.*}} : (!fir.box<!fir.array<?xf32>>, index) -> (index, index, index)
+! CHECK:   fir.result %[[DIMS]]#2 : index
+! CHECK: } else {
+! CHECK:   fir.result %c0{{.*}} : index
+! CHECK: }
+! CHECK: %[[BOUNDS:.*]] = acc.bounds lowerbound(%c0{{.*}} : index) upperbound(%{{.*}} : index) extent(%{{.*}} : index) stride(%[[STRIDE]] : index) startIdx(%c1 : index) {strideInBytes = true}
+! CHECK: %[[NOCREATE:.*]] = acc.nocreate varPtr(%[[DECL_A]]#1 : !fir.ref<!fir.array<?xf32>>) bounds(%14) -> !fir.ref<!fir.array<?xf32>> {name = "a(1:n)"}
+! CHECK: acc.data dataOperands(%[[NOCREATE]] : !fir.ref<!fir.array<?xf32>>) {
+
 end module

``````````

</details>


https://github.com/llvm/llvm-project/pull/80931


More information about the flang-commits mailing list