[flang-commits] [flang] [flang][openacc/mp] Do not read bounds on absent box (PR #75252)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Tue Dec 12 15:02:33 PST 2023


https://github.com/clementval created https://github.com/llvm/llvm-project/pull/75252

Make sure we only load the box and read its bounds when it is present.

Also fix some template parameter ordering issues.

>From c927c793a06cb4b5dec987e91a87f0bdac5c9e4f Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Tue, 12 Dec 2023 13:36:23 -0800
Subject: [PATCH] [flang][openacc] Do not load optional box if not present

---
 flang/lib/Lower/DirectivesCommon.h      | 104 ++++++++++++++++++++----
 flang/lib/Lower/OpenACC.cpp             |  30 ++++---
 flang/lib/Lower/OpenMP.cpp              |   4 +-
 flang/test/Lower/OpenACC/acc-bounds.f90 |  31 +++++++
 flang/test/Lower/OpenACC/acc-data.f90   |   1 -
 5 files changed, 137 insertions(+), 33 deletions(-)

diff --git a/flang/lib/Lower/DirectivesCommon.h b/flang/lib/Lower/DirectivesCommon.h
index 88a8916663df75..39f87202f90f5f 100644
--- a/flang/lib/Lower/DirectivesCommon.h
+++ b/flang/lib/Lower/DirectivesCommon.h
@@ -620,25 +620,36 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
 
     // Load the box when baseAddr is a `fir.ref<fir.box<T>>` or a
     // `fir.ref<fir.class<T>>` type.
-    if (symAddr.getType().isa<fir::ReferenceType>())
+    if (symAddr.getType().isa<fir::ReferenceType>()) {
+      if (Fortran::semantics::IsOptional(sym)) {
+        mlir::Value isPresent =
+            builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), symAddr);
+        return builder.genIfOp(loc, {boxTy}, isPresent, /*withElseRegion=*/true)
+            .genThen([&]() {
+              mlir::Value load = builder.create<fir::LoadOp>(loc, symAddr);
+              builder.create<fir::ResultOp>(loc, mlir::ValueRange{load});
+            })
+            .genElse([&] {
+              mlir::Value absent = builder.create<fir::AbsentOp>(loc, boxTy);
+              builder.create<fir::ResultOp>(loc, mlir::ValueRange{absent});
+            })
+            .getResults()[0];
+      }
       return builder.create<fir::LoadOp>(loc, symAddr);
+    }
   }
   return symAddr;
 }
 
-/// Generate the bounds operation from the descriptor information.
 template <typename BoundsOp, typename BoundsType>
-llvm::SmallVector<mlir::Value>
-genBoundsOpsFromBox(fir::FirOpBuilder &builder, mlir::Location loc,
-                    Fortran::lower::AbstractConverter &converter,
+static llvm::SmallVector<mlir::Value>
+gatherBoundsFromBox(fir::FirOpBuilder &builder, mlir::Location loc,
                     fir::ExtendedValue dataExv, mlir::Value box) {
+  mlir::Value byteStride;
   llvm::SmallVector<mlir::Value> bounds;
   mlir::Type idxTy = builder.getIndexType();
   mlir::Type boundTy = builder.getType<BoundsType>();
   mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
-  assert(box.getType().isa<fir::BaseBoxType>() &&
-         "expect fir.box or fir.class");
-  mlir::Value byteStride;
   for (unsigned dim = 0; dim < dataExv.rank(); ++dim) {
     mlir::Value d = builder.createIntegerConstant(loc, idxTy, dim);
     mlir::Value baseLb =
@@ -660,6 +671,58 @@ genBoundsOpsFromBox(fir::FirOpBuilder &builder, mlir::Location loc,
   return bounds;
 }
 
+/// Generate the bounds operation from the descriptor information.
+template <typename BoundsOp, typename BoundsType>
+llvm::SmallVector<mlir::Value>
+genBoundsOpsFromBox(fir::FirOpBuilder &builder, mlir::Location loc,
+                    Fortran::lower::AbstractConverter &converter,
+                    fir::ExtendedValue dataExv, mlir::Value box,
+                    bool isOptional = false) {
+  llvm::SmallVector<mlir::Value> bounds;
+  mlir::Type idxTy = builder.getIndexType();
+  mlir::Type boundTy = builder.getType<BoundsType>();
+
+  assert(box.getType().isa<fir::BaseBoxType>() &&
+         "expect fir.box or fir.class");
+
+  if (isOptional) {
+    mlir::Value isPresent =
+        builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), box);
+
+    llvm::SmallVector<mlir::Type> resTypes;
+    for (unsigned dim = 0; dim < dataExv.rank(); ++dim)
+      resTypes.push_back(boundTy);
+
+    auto ifOp =
+        builder.genIfOp(loc, resTypes, isPresent, /*withElseRegion=*/true)
+            .genThen([&]() {
+              llvm::SmallVector<mlir::Value> tempBounds =
+                  gatherBoundsFromBox<BoundsOp, BoundsType>(builder, loc,
+                                                            dataExv, box);
+              builder.create<fir::ResultOp>(loc, tempBounds);
+            })
+            .genElse([&] {
+              llvm::SmallVector<mlir::Value> tempBounds;
+              mlir::Value zero = builder.createIntegerConstant(loc, idxTy, 0);
+              mlir::Value minusOne =
+                  builder.createIntegerConstant(loc, idxTy, -1);
+              for (unsigned dim = 0; dim < dataExv.rank(); ++dim) {
+                mlir::Value bound = builder.create<BoundsOp>(
+                    loc, boundTy, zero, minusOne, zero, mlir::Value(), false,
+                    mlir::Value{});
+                tempBounds.push_back(bound);
+              }
+              builder.create<fir::ResultOp>(loc, tempBounds);
+            });
+    bounds.append(ifOp.getResults().begin(), ifOp.getResults().end());
+  } else {
+    llvm::SmallVector<mlir::Value> tempBounds =
+        gatherBoundsFromBox<BoundsOp, BoundsType>(builder, loc, dataExv, box);
+    bounds.append(tempBounds.begin(), tempBounds.end());
+  }
+  return bounds;
+}
+
 /// Generate bounds operation for base array without any subscripts
 /// provided.
 template <typename BoundsOp, typename BoundsType>
@@ -885,20 +948,20 @@ mlir::Value gatherDataOperandAddrAndBounds(
 
                 if (!arrayElement->subscripts.empty()) {
                   asFortran << '(';
-                  bounds = genBoundsOps<BoundsType, BoundsOp>(
+                  bounds = genBoundsOps<BoundsOp, BoundsType>(
                       builder, operandLocation, converter, stmtCtx,
                       arrayElement->subscripts, asFortran, dataExv, baseAddr,
                       treatIndexAsSection);
                 }
                 asFortran << ')';
-              } else if (Fortran::parser::Unwrap<
+              } else if (auto structComp = Fortran::parser::Unwrap<
                              Fortran::parser::StructureComponent>(designator)) {
                 fir::ExtendedValue compExv =
                     converter.genExprAddr(operandLocation, *expr, stmtCtx);
                 baseAddr = fir::getBase(compExv);
                 if (fir::unwrapRefType(baseAddr.getType())
                         .isa<fir::SequenceType>())
-                  bounds = genBaseBoundsOps<BoundsType, BoundsOp>(
+                  bounds = genBaseBoundsOps<BoundsOp, BoundsType>(
                       builder, operandLocation, converter, compExv, baseAddr);
                 asFortran << (*expr).AsFortran();
 
@@ -917,8 +980,11 @@ mlir::Value gatherDataOperandAddrAndBounds(
                 if (auto boxAddrOp = mlir::dyn_cast_or_null<fir::BoxAddrOp>(
                         baseAddr.getDefiningOp())) {
                   baseAddr = boxAddrOp.getVal();
-                  bounds = genBoundsOpsFromBox<BoundsType, BoundsOp>(
-                      builder, operandLocation, converter, compExv, baseAddr);
+                  bool isOptional = Fortran::semantics::IsOptional(
+                      *Fortran::parser::GetLastName(*structComp).symbol);
+                  bounds = genBoundsOpsFromBox<BoundsOp, BoundsType>(
+                      builder, operandLocation, converter, compExv, baseAddr,
+                      isOptional);
                 }
               } else {
                 if (Fortran::parser::Unwrap<Fortran::parser::ArrayElement>(
@@ -943,12 +1009,16 @@ mlir::Value gatherDataOperandAddrAndBounds(
                   baseAddr = getDataOperandBaseAddr(
                       converter, builder, *name.symbol, operandLocation);
                   if (fir::unwrapRefType(baseAddr.getType())
-                          .isa<fir::BaseBoxType>())
-                    bounds = genBoundsOpsFromBox<BoundsType, BoundsOp>(
-                        builder, operandLocation, converter, dataExv, baseAddr);
+                          .isa<fir::BaseBoxType>()) {
+                    bool isOptional =
+                        Fortran::semantics::IsOptional(*name.symbol);
+                    bounds = genBoundsOpsFromBox<BoundsOp, BoundsType>(
+                        builder, operandLocation, converter, dataExv, baseAddr,
+                        isOptional);
+                  }
                   if (fir::unwrapRefType(baseAddr.getType())
                           .isa<fir::SequenceType>())
-                    bounds = genBaseBoundsOps<BoundsType, BoundsOp>(
+                    bounds = genBaseBoundsOps<BoundsOp, BoundsType>(
                         builder, operandLocation, converter, dataExv, baseAddr);
                   asFortran << name.ToString();
                 } else { // Unsupported
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index e2abed1b9f4f67..531685948bc843 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -266,10 +266,11 @@ genDataOperandOperations(const Fortran::parser::AccObjectList &objectList,
     std::stringstream asFortran;
     mlir::Location operandLocation = genOperandLocation(converter, accObject);
     mlir::Value baseAddr = Fortran::lower::gatherDataOperandAddrAndBounds<
-        Fortran::parser::AccObject, mlir::acc::DataBoundsType,
-        mlir::acc::DataBoundsOp>(converter, builder, semanticsContext, stmtCtx,
-                                 accObject, operandLocation, asFortran, bounds,
-                                 /*treatIndexAsSection=*/true);
+        Fortran::parser::AccObject, mlir::acc::DataBoundsOp,
+        mlir::acc::DataBoundsType>(converter, builder, semanticsContext,
+                                   stmtCtx, accObject, operandLocation,
+                                   asFortran, bounds,
+                                   /*treatIndexAsSection=*/true);
     Op op = createDataEntryOp<Op>(builder, operandLocation, baseAddr, asFortran,
                                   bounds, structured, implicit, dataClause,
                                   baseAddr.getType());
@@ -291,9 +292,10 @@ static void genDeclareDataOperandOperations(
     std::stringstream asFortran;
     mlir::Location operandLocation = genOperandLocation(converter, accObject);
     mlir::Value baseAddr = Fortran::lower::gatherDataOperandAddrAndBounds<
-        Fortran::parser::AccObject, mlir::acc::DataBoundsType,
-        mlir::acc::DataBoundsOp>(converter, builder, semanticsContext, stmtCtx,
-                                 accObject, operandLocation, asFortran, bounds);
+        Fortran::parser::AccObject, mlir::acc::DataBoundsOp,
+        mlir::acc::DataBoundsType>(converter, builder, semanticsContext,
+                                   stmtCtx, accObject, operandLocation,
+                                   asFortran, bounds);
     EntryOp op = createDataEntryOp<EntryOp>(
         builder, operandLocation, baseAddr, asFortran, bounds, structured,
         implicit, dataClause, baseAddr.getType());
@@ -748,9 +750,10 @@ genPrivatizations(const Fortran::parser::AccObjectList &objectList,
     std::stringstream asFortran;
     mlir::Location operandLocation = genOperandLocation(converter, accObject);
     mlir::Value baseAddr = Fortran::lower::gatherDataOperandAddrAndBounds<
-        Fortran::parser::AccObject, mlir::acc::DataBoundsType,
-        mlir::acc::DataBoundsOp>(converter, builder, semanticsContext, stmtCtx,
-                                 accObject, operandLocation, asFortran, bounds);
+        Fortran::parser::AccObject, mlir::acc::DataBoundsOp,
+        mlir::acc::DataBoundsType>(converter, builder, semanticsContext,
+                                   stmtCtx, accObject, operandLocation,
+                                   asFortran, bounds);
 
     RecipeOp recipe;
     mlir::Type retTy = getTypeFromBounds(bounds, baseAddr.getType());
@@ -1324,9 +1327,10 @@ genReductions(const Fortran::parser::AccObjectListWithReduction &objectList,
     std::stringstream asFortran;
     mlir::Location operandLocation = genOperandLocation(converter, accObject);
     mlir::Value baseAddr = Fortran::lower::gatherDataOperandAddrAndBounds<
-        Fortran::parser::AccObject, mlir::acc::DataBoundsType,
-        mlir::acc::DataBoundsOp>(converter, builder, semanticsContext, stmtCtx,
-                                 accObject, operandLocation, asFortran, bounds);
+        Fortran::parser::AccObject, mlir::acc::DataBoundsOp,
+        mlir::acc::DataBoundsType>(converter, builder, semanticsContext,
+                                   stmtCtx, accObject, operandLocation,
+                                   asFortran, bounds);
 
     mlir::Type reductionTy = fir::unwrapRefType(baseAddr.getType());
     if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(reductionTy))
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index eeba87fcd15116..59e06e8458e6c0 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -1794,8 +1794,8 @@ bool ClauseProcessor::processMap(
           llvm::SmallVector<mlir::Value> bounds;
           std::stringstream asFortran;
           mlir::Value baseAddr = Fortran::lower::gatherDataOperandAddrAndBounds<
-              Fortran::parser::OmpObject, mlir::omp::DataBoundsType,
-              mlir::omp::DataBoundsOp>(
+              Fortran::parser::OmpObject, mlir::omp::DataBoundsOp,
+              mlir::omp::DataBoundsType>(
               converter, firOpBuilder, semanticsContext, stmtCtx, ompObject,
               clauseLocation, asFortran, bounds, treatIndexAsSection);
 
diff --git a/flang/test/Lower/OpenACC/acc-bounds.f90 b/flang/test/Lower/OpenACC/acc-bounds.f90
index 8db18ab5aa9c4b..c8787c5e118f97 100644
--- a/flang/test/Lower/OpenACC/acc-bounds.f90
+++ b/flang/test/Lower/OpenACC/acc-bounds.f90
@@ -116,4 +116,35 @@ subroutine acc_multi_strides(a)
 ! CHECK: %[[PRESENT:.*]] = acc.present varPtr(%[[BOX_ADDR]] : !fir.ref<!fir.array<?x?x?xf32>>) bounds(%29, %33, %37) -> !fir.ref<!fir.array<?x?x?xf32>> {name = "a"}
 ! CHECK: acc.kernels dataOperands(%[[PRESENT]] : !fir.ref<!fir.array<?x?x?xf32>>) {
 
+  subroutine acc_optional_data(a)
+    real, pointer, optional :: a(:)
+    !$acc data attach(a)
+    !$acc end data
+  end subroutine
+  
+  ! CHECK-LABEL: func.func @_QMopenacc_boundsPacc_optional_data(
+  ! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>> {fir.bindc_name = "a", fir.optional}) {
+  ! CHECK: %[[ARG0_DECL:.*]]:2 = hlfir.declare %arg0 {fortran_attrs = #fir.var_attrs<optional, pointer>, uniq_name = "_QMopenacc_boundsFacc_optional_dataEa"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>)
+  ! CHECK: %[[IS_PRESENT:.*]] = fir.is_present %[[ARG0_DECL]]#1 : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> i1
+  ! CHECK: %[[ADDR:.*]] = fir.if %[[IS_PRESENT]] -> (!fir.box<!fir.ptr<!fir.array<?xf32>>>) {
+  ! CHECK:   %[[LOAD:.*]] = fir.load %[[ARG0_DECL]]#1 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
+  ! CHECK:   fir.result %[[LOAD]] : !fir.box<!fir.ptr<!fir.array<?xf32>>>
+  ! CHECK: } else {
+  ! CHECK:   %[[ABSENT:.*]] = fir.absent !fir.box<!fir.ptr<!fir.array<?xf32>>>
+  ! CHECK:   fir.result %[[ABSENT]] : !fir.box<!fir.ptr<!fir.array<?xf32>>>
+  ! CHECK: }
+  ! CHECK: %[[BOUNDS:.*]] = fir.if %{{.*}} -> (!acc.data_bounds_ty) {
+  ! CHECK:   %[[BOUND:.*]] = acc.bounds lowerbound(%{{.*}} : index) upperbound(%{{.*}} : index) extent(%{{.*}}#1 : index) stride(%{{.*}}#2 : index) startIdx(%{{.*}}#0 : index) {strideInBytes = true}
+  ! CHECK:   fir.result %[[BOUND]] : !acc.data_bounds_ty
+  ! CHECK: } else {
+  ! CHECK:   %[[C0:.*]] = arith.constant 0 : index
+  ! CHECK:   %[[CM1:.*]] = arith.constant -1 : index
+  ! CHECK:   %[[BOUND:.*]] = acc.bounds lowerbound(%[[C0]] : index) upperbound(%[[CM1]] : index) extent(%[[C0]] : index)
+  ! CHECK:   fir.result %[[BOUND]] : !acc.data_bounds_ty
+  ! CHECK: }
+  ! CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[ADDR]] : (!fir.box<!fir.ptr<!fir.array<?xf32>>>) -> !fir.ptr<!fir.array<?xf32>>
+  ! CHECK: %[[ATTACH:.*]] = acc.attach varPtr(%[[BOX_ADDR]] : !fir.ptr<!fir.array<?xf32>>) bounds(%[[BOUNDS]]) -> !fir.ptr<!fir.array<?xf32>> {name = "a"}
+  ! CHECK: acc.data dataOperands(%[[ATTACH]] : !fir.ptr<!fir.array<?xf32>>)
+  
+
 end module
diff --git a/flang/test/Lower/OpenACC/acc-data.f90 b/flang/test/Lower/OpenACC/acc-data.f90
index d302be85c5df46..a6572e14707606 100644
--- a/flang/test/Lower/OpenACC/acc-data.f90
+++ b/flang/test/Lower/OpenACC/acc-data.f90
@@ -198,4 +198,3 @@ subroutine acc_data
 ! CHECK-NOT: acc.data
 
 end subroutine acc_data
-



More information about the flang-commits mailing list