[flang-commits] [flang] [flang][openacc] Use original input for base address with optional (PR #80931)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Tue Feb 6 19:51:42 PST 2024


https://github.com/clementval created https://github.com/llvm/llvm-project/pull/80931

In #80317 the data op generation was updated to correctly use the #0 result from the hlfir.declare op. For optional arguments that are not descriptors, it is preferable to use the original input for the varPtr value of the OpenACC data op.
This patch also makes sure that the descriptor value of an optional argument is only accessed when the argument is present.

>From 36b5a26f6cad6049622e6f66d109dfbb3d7dfb03 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Mon, 5 Feb 2024 12:58:27 -0800
Subject: [PATCH] [flang][openacc] Use original input for base address with
 optional

---
 flang/lib/Lower/DirectivesCommon.h      | 93 +++++++++++++++++++------
 flang/lib/Lower/OpenACC.cpp             | 20 ++++--
 flang/test/Lower/OpenACC/acc-bounds.f90 | 38 +++++++++-
 3 files changed, 124 insertions(+), 27 deletions(-)

diff --git a/flang/lib/Lower/DirectivesCommon.h b/flang/lib/Lower/DirectivesCommon.h
index bd880376517dd8..8d560db34e05bf 100644
--- a/flang/lib/Lower/DirectivesCommon.h
+++ b/flang/lib/Lower/DirectivesCommon.h
@@ -52,10 +52,13 @@ namespace lower {
 /// operations.
 struct AddrAndBoundsInfo {
   explicit AddrAndBoundsInfo() {}
-  explicit AddrAndBoundsInfo(mlir::Value addr) : addr(addr) {}
-  explicit AddrAndBoundsInfo(mlir::Value addr, mlir::Value isPresent)
-      : addr(addr), isPresent(isPresent) {}
+  explicit AddrAndBoundsInfo(mlir::Value addr, mlir::Value rawInput)
+      : addr(addr), rawInput(rawInput) {}
+  explicit AddrAndBoundsInfo(mlir::Value addr, mlir::Value rawInput,
+                             mlir::Value isPresent)
+      : addr(addr), rawInput(rawInput), isPresent(isPresent) {}
   mlir::Value addr = nullptr;
+  mlir::Value rawInput = nullptr;
   mlir::Value isPresent = nullptr;
 };
 
@@ -615,20 +618,30 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
                        fir::FirOpBuilder &builder,
                        Fortran::lower::SymbolRef sym, mlir::Location loc) {
   mlir::Value symAddr = converter.getSymbolAddress(sym);
+  mlir::Value rawInput = symAddr;
   if (auto declareOp =
-          mlir::dyn_cast_or_null<hlfir::DeclareOp>(symAddr.getDefiningOp()))
+          mlir::dyn_cast_or_null<hlfir::DeclareOp>(symAddr.getDefiningOp())) {
     symAddr = declareOp.getResults()[0];
+    rawInput = declareOp.getResults()[1];
+  }
 
   // TODO: Might need revisiting to handle for non-shared clauses
   if (!symAddr) {
     if (const auto *details =
-            sym->detailsIf<Fortran::semantics::HostAssocDetails>())
+            sym->detailsIf<Fortran::semantics::HostAssocDetails>()) {
       symAddr = converter.getSymbolAddress(details->symbol());
+      rawInput = symAddr;
+    }
   }
 
   if (!symAddr)
     llvm::report_fatal_error("could not retrieve symbol address");
 
+  mlir::Value isPresent;
+  if (Fortran::semantics::IsOptional(sym))
+    isPresent =
+        builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), rawInput);
+
   if (auto boxTy =
           fir::unwrapRefType(symAddr.getType()).dyn_cast<fir::BaseBoxType>()) {
     if (boxTy.getEleTy().isa<fir::RecordType>())
@@ -638,8 +651,6 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
     // `fir.ref<fir.class<T>>` type.
     if (symAddr.getType().isa<fir::ReferenceType>()) {
       if (Fortran::semantics::IsOptional(sym)) {
-        mlir::Value isPresent =
-            builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), symAddr);
         mlir::Value addr =
             builder.genIfOp(loc, {boxTy}, isPresent, /*withElseRegion=*/true)
                 .genThen([&]() {
@@ -652,14 +663,13 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
                   builder.create<fir::ResultOp>(loc, mlir::ValueRange{absent});
                 })
                 .getResults()[0];
-        return AddrAndBoundsInfo(addr, isPresent);
+        return AddrAndBoundsInfo(addr, rawInput, isPresent);
       }
       mlir::Value addr = builder.create<fir::LoadOp>(loc, symAddr);
-      return AddrAndBoundsInfo(addr);
-      ;
+      return AddrAndBoundsInfo(addr, rawInput, isPresent);
     }
   }
-  return AddrAndBoundsInfo(symAddr);
+  return AddrAndBoundsInfo(symAddr, rawInput, isPresent);
 }
 
 template <typename BoundsOp, typename BoundsType>
@@ -807,7 +817,7 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
              Fortran::lower::StatementContext &stmtCtx,
              const std::list<Fortran::parser::SectionSubscript> &subscripts,
              std::stringstream &asFortran, fir::ExtendedValue &dataExv,
-             bool dataExvIsAssumedSize, mlir::Value baseAddr,
+             bool dataExvIsAssumedSize, AddrAndBoundsInfo &info,
              bool treatIndexAsSection = false) {
   int dimension = 0;
   mlir::Type idxTy = builder.getIndexType();
@@ -831,11 +841,30 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
       mlir::Value stride = one;
       bool strideInBytes = false;
 
-      if (fir::unwrapRefType(baseAddr.getType()).isa<fir::BaseBoxType>()) {
-        mlir::Value d = builder.createIntegerConstant(loc, idxTy, dimension);
-        auto dimInfo = builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy,
-                                                      baseAddr, d);
-        stride = dimInfo.getByteStride();
+      if (fir::unwrapRefType(info.addr.getType()).isa<fir::BaseBoxType>()) {
+        if (info.isPresent) {
+          stride =
+              builder
+                  .genIfOp(loc, idxTy, info.isPresent, /*withElseRegion=*/true)
+                  .genThen([&]() {
+                    mlir::Value d =
+                        builder.createIntegerConstant(loc, idxTy, dimension);
+                    auto dimInfo = builder.create<fir::BoxDimsOp>(
+                        loc, idxTy, idxTy, idxTy, info.addr, d);
+                    builder.create<fir::ResultOp>(loc, dimInfo.getByteStride());
+                  })
+                  .genElse([&] {
+                    mlir::Value zero =
+                        builder.createIntegerConstant(loc, idxTy, 0);
+                    builder.create<fir::ResultOp>(loc, zero);
+                  })
+                  .getResults()[0];
+        } else {
+          mlir::Value d = builder.createIntegerConstant(loc, idxTy, dimension);
+          auto dimInfo = builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy,
+                                                        idxTy, info.addr, d);
+          stride = dimInfo.getByteStride();
+        }
         strideInBytes = true;
       }
 
@@ -919,7 +948,26 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
           }
         }
 
-        extent = fir::factory::readExtent(builder, loc, dataExv, dimension);
+        if (info.isPresent &&
+            fir::unwrapRefType(info.addr.getType()).isa<fir::BaseBoxType>()) {
+          extent =
+              builder
+                  .genIfOp(loc, idxTy, info.isPresent, /*withElseRegion=*/true)
+                  .genThen([&]() {
+                    mlir::Value ext = fir::factory::readExtent(
+                        builder, loc, dataExv, dimension);
+                    builder.create<fir::ResultOp>(loc, ext);
+                  })
+                  .genElse([&] {
+                    mlir::Value zero =
+                        builder.createIntegerConstant(loc, idxTy, 0);
+                    builder.create<fir::ResultOp>(loc, zero);
+                  })
+                  .getResults()[0];
+        } else {
+          extent = fir::factory::readExtent(builder, loc, dataExv, dimension);
+        }
+
         if (dataExvIsAssumedSize && dimension + 1 == dataExvRank) {
           extent = zero;
           if (ubound && lbound) {
@@ -976,6 +1024,7 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
                   dataExv = converter.genExprAddr(operandLocation, *exprBase,
                                                   stmtCtx);
                   info.addr = fir::getBase(dataExv);
+                  info.rawInput = info.addr;
                   asFortran << (*exprBase).AsFortran();
                 } else {
                   const Fortran::parser::Name &name =
@@ -993,7 +1042,7 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
                   bounds = genBoundsOps<BoundsOp, BoundsType>(
                       builder, operandLocation, converter, stmtCtx,
                       arrayElement->subscripts, asFortran, dataExv,
-                      dataExvIsAssumedSize, info.addr, treatIndexAsSection);
+                      dataExvIsAssumedSize, info, treatIndexAsSection);
                 }
                 asFortran << ')';
               } else if (auto structComp = Fortran::parser::Unwrap<
@@ -1001,6 +1050,7 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
                 fir::ExtendedValue compExv =
                     converter.genExprAddr(operandLocation, *expr, stmtCtx);
                 info.addr = fir::getBase(compExv);
+                info.rawInput = info.addr;
                 if (fir::unwrapRefType(info.addr.getType())
                         .isa<fir::SequenceType>())
                   bounds = genBaseBoundsOps<BoundsOp, BoundsType>(
@@ -1012,7 +1062,7 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
                     *Fortran::parser::GetLastName(*structComp).symbol);
                 if (isOptional)
                   info.isPresent = builder.create<fir::IsPresentOp>(
-                      operandLocation, builder.getI1Type(), info.addr);
+                      operandLocation, builder.getI1Type(), info.rawInput);
 
                 if (auto loadOp = mlir::dyn_cast_or_null<fir::LoadOp>(
                         info.addr.getDefiningOp())) {
@@ -1020,6 +1070,7 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
                       fir::isPointerType(loadOp.getType()))
                     info.addr = builder.create<fir::BoxAddrOp>(operandLocation,
                                                                info.addr);
+                  info.rawInput = info.addr;
                 }
 
                 // If the component is an allocatable or pointer the result of
@@ -1029,6 +1080,7 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
                 if (auto boxAddrOp = mlir::dyn_cast_or_null<fir::BoxAddrOp>(
                         info.addr.getDefiningOp())) {
                   info.addr = boxAddrOp.getVal();
+                  info.rawInput = info.addr;
                   bounds = genBoundsOpsFromBox<BoundsOp, BoundsType>(
                       builder, operandLocation, converter, compExv, info);
                 }
@@ -1043,6 +1095,7 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
                   fir::ExtendedValue compExv =
                       converter.genExprAddr(operandLocation, *expr, stmtCtx);
                   info.addr = fir::getBase(compExv);
+                  info.rawInput = info.addr;
                   asFortran << (*expr).AsFortran();
                 } else if (const auto *dataRef{
                                std::get_if<Fortran::parser::DataRef>(
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 43f54c6d2a71bb..6ae270f63f5cf4 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -67,9 +67,12 @@ static Op createDataEntryOp(fir::FirOpBuilder &builder, mlir::Location loc,
   mlir::Value varPtrPtr;
   if (auto boxTy = baseAddr.getType().dyn_cast<fir::BaseBoxType>()) {
     if (isPresent) {
+      mlir::Type ifRetTy = boxTy.getEleTy();
+      if (!fir::isa_ref_type(ifRetTy))
+        ifRetTy = fir::ReferenceType::get(ifRetTy);
       baseAddr =
           builder
-              .genIfOp(loc, {boxTy.getEleTy()}, isPresent,
+              .genIfOp(loc, {ifRetTy}, isPresent,
                        /*withElseRegion=*/true)
               .genThen([&]() {
                 mlir::Value boxAddr =
@@ -78,7 +81,7 @@ static Op createDataEntryOp(fir::FirOpBuilder &builder, mlir::Location loc,
               })
               .genElse([&] {
                 mlir::Value absent =
-                    builder.create<fir::AbsentOp>(loc, boxTy.getEleTy());
+                    builder.create<fir::AbsentOp>(loc, ifRetTy);
                 builder.create<fir::ResultOp>(loc, mlir::ValueRange{absent});
               })
               .getResults()[0];
@@ -295,9 +298,16 @@ genDataOperandOperations(const Fortran::parser::AccObjectList &objectList,
                                        asFortran, bounds,
                                        /*treatIndexAsSection=*/true);
 
-    Op op = createDataEntryOp<Op>(
-        builder, operandLocation, info.addr, asFortran, bounds, structured,
-        implicit, dataClause, info.addr.getType(), info.isPresent);
+    // If the input value is optional and is not a descriptor, we use the
+    // rawInput directly.
+    mlir::Value baseAddr =
+        ((info.addr.getType() != fir::unwrapRefType(info.rawInput.getType())) &&
+         info.isPresent)
+            ? info.rawInput
+            : info.addr;
+    Op op = createDataEntryOp<Op>(builder, operandLocation, baseAddr, asFortran,
+                                  bounds, structured, implicit, dataClause,
+                                  baseAddr.getType(), info.isPresent);
     dataOperands.push_back(op.getAccPtr());
   }
 }
diff --git a/flang/test/Lower/OpenACC/acc-bounds.f90 b/flang/test/Lower/OpenACC/acc-bounds.f90
index bd96bc8bcba359..df97cbcd187d2b 100644
--- a/flang/test/Lower/OpenACC/acc-bounds.f90
+++ b/flang/test/Lower/OpenACC/acc-bounds.f90
@@ -126,8 +126,8 @@ subroutine acc_optional_data(a)
   
 ! CHECK-LABEL: func.func @_QMopenacc_boundsPacc_optional_data(
 ! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>> {fir.bindc_name = "a", fir.optional}) {
-! CHECK: %[[ARG0_DECL:.*]]:2 = hlfir.declare %arg0 {fortran_attrs = #fir.var_attrs<optional, pointer>, uniq_name = "_QMopenacc_boundsFacc_optional_dataEa"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>)
-! CHECK: %[[IS_PRESENT:.*]] = fir.is_present %[[ARG0_DECL]]#0 : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> i1
+! CHECK: %[[ARG0_DECL:.*]]:2 = hlfir.declare %[[ARG0]] {fortran_attrs = #fir.var_attrs<optional, pointer>, uniq_name = "_QMopenacc_boundsFacc_optional_dataEa"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>)
+! CHECK: %[[IS_PRESENT:.*]] = fir.is_present %[[ARG0_DECL]]#1 : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> i1
 ! CHECK: %[[BOX:.*]] = fir.if %[[IS_PRESENT]] -> (!fir.box<!fir.ptr<!fir.array<?xf32>>>) {
 ! CHECK:   %[[LOAD:.*]] = fir.load %[[ARG0_DECL]]#0 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
 ! CHECK:   fir.result %[[LOAD]] : !fir.box<!fir.ptr<!fir.array<?xf32>>>
@@ -153,4 +153,38 @@ subroutine acc_optional_data(a)
 ! CHECK: %[[ATTACH:.*]] = acc.attach varPtr(%[[BOX_ADDR]] : !fir.ptr<!fir.array<?xf32>>) bounds(%[[BOUND]]) -> !fir.ptr<!fir.array<?xf32>> {name = "a"}
 ! CHECK: acc.data dataOperands(%[[ATTACH]] : !fir.ptr<!fir.array<?xf32>>)
 
+  subroutine acc_optional_data2(a, n)
+    integer :: n
+    real, optional :: a(n)
+    !$acc data no_create(a)
+    !$acc end data
+  end subroutine
+
+! CHECK-LABEL: func.func @_QMopenacc_boundsPacc_optional_data2(
+! CHECK-SAME: %[[A:.*]]: !fir.ref<!fir.array<?xf32>> {fir.bindc_name = "a", fir.optional}, %[[N:.*]]: !fir.ref<i32> {fir.bindc_name = "n"}) {
+! CHECK: %[[DECL_A:.*]]:2 = hlfir.declare %[[A]](%{{.*}}) {fortran_attrs = #fir.var_attrs<optional>, uniq_name = "_QMopenacc_boundsFacc_optional_data2Ea"} : (!fir.ref<!fir.array<?xf32>>, !fir.shape<1>) -> (!fir.box<!fir.array<?xf32>>, !fir.ref<!fir.array<?xf32>>)
+! CHECK: %[[NO_CREATE:.*]] = acc.nocreate varPtr(%[[DECL_A]]#1 : !fir.ref<!fir.array<?xf32>>) bounds(%10) -> !fir.ref<!fir.array<?xf32>> {name = "a"}
+! CHECK: acc.data dataOperands(%[[NO_CREATE]] : !fir.ref<!fir.array<?xf32>>) {
+
+  subroutine acc_optional_data3(a, n)
+    integer :: n
+    real, optional :: a(n)
+    !$acc data no_create(a(1:n))
+    !$acc end data
+  end subroutine
+
+! CHECK-LABEL: func.func @_QMopenacc_boundsPacc_optional_data3(
+! CHECK-SAME: %[[A:.*]]: !fir.ref<!fir.array<?xf32>> {fir.bindc_name = "a", fir.optional}, %[[N:.*]]: !fir.ref<i32> {fir.bindc_name = "n"}) {
+! CHECK: %[[DECL_A:.*]]:2 = hlfir.declare %[[A]](%{{.*}}) {fortran_attrs = #fir.var_attrs<optional>, uniq_name = "_QMopenacc_boundsFacc_optional_data3Ea"} : (!fir.ref<!fir.array<?xf32>>, !fir.shape<1>) -> (!fir.box<!fir.array<?xf32>>, !fir.ref<!fir.array<?xf32>>)
+! CHECK: %[[PRES:.*]] = fir.is_present %[[DECL_A]]#1 : (!fir.ref<!fir.array<?xf32>>) -> i1
+! CHECK: %[[STRIDE:.*]] = fir.if %[[PRES]] -> (index) {
+! CHECK:   %[[DIMS:.*]]:3 = fir.box_dims %[[DECL_A]]#0, %c0{{.*}} : (!fir.box<!fir.array<?xf32>>, index) -> (index, index, index)
+! CHECK:   fir.result %[[DIMS]]#2 : index
+! CHECK: } else {
+! CHECK:   fir.result %c0{{.*}} : index
+! CHECK: }
+! CHECK: %[[BOUNDS:.*]] = acc.bounds lowerbound(%c0{{.*}} : index) upperbound(%{{.*}} : index) extent(%{{.*}} : index) stride(%[[STRIDE]] : index) startIdx(%c1 : index) {strideInBytes = true}
+! CHECK: %[[NOCREATE:.*]] = acc.nocreate varPtr(%[[DECL_A]]#1 : !fir.ref<!fir.array<?xf32>>) bounds(%14) -> !fir.ref<!fir.array<?xf32>> {name = "a(1:n)"}
+! CHECK: acc.data dataOperands(%[[NOCREATE]] : !fir.ref<!fir.array<?xf32>>) {
+
 end module



More information about the flang-commits mailing list