[flang-commits] [flang] [mlir] [flang][acc] Update stride calculation to include inner-dimensions (PR #136613)

Razvan Lupusoru via flang-commits flang-commits at lists.llvm.org
Mon Apr 21 14:10:10 PDT 2025


https://github.com/razvanlupusoru created https://github.com/llvm/llvm-project/pull/136613

The acc.bounds operation allows specifying a stride, but the dialect did not define what that stride means. The dialect documentation now states explicitly that, when a stride is specified for an outer dimension, it must include the sizes of all inner dimensions.

Flang's OpenACC lowering has also been updated to follow this rule. Descriptor-based arrays already did so, because their stride is read from the descriptor (in bytes); now the stride is computed this way for all arrays.

>From f6d5d9d7801bea4c01b8aa68d12e167c930aa6ad Mon Sep 17 00:00:00 2001
From: Razvan Lupusoru <rlupusoru at nvidia.com>
Date: Mon, 21 Apr 2025 14:09:05 -0700
Subject: [PATCH] [flang][acc] Update stride calculation to include
 inner-dimensions

The acc.bounds operation allows specifying a stride, but the dialect did
not define what that stride means. The dialect documentation now states
explicitly that, when a stride is specified for an outer dimension, it
must include the sizes of all inner dimensions.

Flang's OpenACC lowering has also been updated to follow this rule.
Descriptor-based arrays already did so, because their stride is read
from the descriptor (in bytes); now the stride is computed this way for
all arrays.
---
 flang/include/flang/Lower/DirectivesCommon.h  | 22 ++++++++++++++-----
 .../Optimizer/Builder/DirectivesCommon.h      | 18 ++++++++++-----
 flang/lib/Lower/OpenACC.cpp                   | 18 +++++++++++----
 flang/test/Lower/OpenACC/acc-bounds.f90       | 15 +++++++++++++
 .../acc-enter-data-unwrap-defaultbounds.f90   | 12 +++++-----
 flang/test/Lower/OpenACC/acc-enter-data.f90   |  6 ++---
 .../mlir/Dialect/OpenACC/OpenACCOps.td        | 11 ++++++++--
 7 files changed, 77 insertions(+), 25 deletions(-)

diff --git a/flang/include/flang/Lower/DirectivesCommon.h b/flang/include/flang/Lower/DirectivesCommon.h
index 6e24343cebd3a..1020d32a07439 100644
--- a/flang/include/flang/Lower/DirectivesCommon.h
+++ b/flang/include/flang/Lower/DirectivesCommon.h
@@ -670,7 +670,7 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
              const std::vector<Fortran::evaluate::Subscript> &subscripts,
              std::stringstream &asFortran, fir::ExtendedValue &dataExv,
              bool dataExvIsAssumedSize, fir::factory::AddrAndBoundsInfo &info,
-             bool treatIndexAsSection = false) {
+             bool treatIndexAsSection = false, bool strideIncludeLowerExtent = false) {
   int dimension = 0;
   mlir::Type idxTy = builder.getIndexType();
   mlir::Type boundTy = builder.getType<BoundsType>();
@@ -679,6 +679,7 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
   mlir::Value zero = builder.createIntegerConstant(loc, idxTy, 0);
   mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
   const int dataExvRank = static_cast<int>(dataExv.rank());
+  mlir::Value cummulativeExtent = one;
   for (const auto &subscript : subscripts) {
     const auto *triplet{std::get_if<Fortran::evaluate::Triplet>(&subscript.u)};
     if (triplet || treatIndexAsSection) {
@@ -847,6 +848,15 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
           ubound = builder.create<mlir::arith::SubIOp>(loc, extent, one);
         }
       }
+
+      // When strideInBytes is true, the stride was read from the descriptor
+      // and already accounts for the extents of the inner dimensions.
+      if (strideIncludeLowerExtent && !strideInBytes) {
+        stride = cummulativeExtent;
+        cummulativeExtent = builder.createOrFold<mlir::arith::MulIOp>(
+            loc, cummulativeExtent, extent);
+      }
+
       mlir::Value bound = builder.create<BoundsOp>(
           loc, boundTy, lbound, ubound, extent, stride, strideInBytes, baseLb);
       bounds.push_back(bound);
@@ -882,7 +892,8 @@ fir::factory::AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
     const Fortran::semantics::MaybeExpr &maybeDesignator,
     mlir::Location operandLocation, std::stringstream &asFortran,
     llvm::SmallVector<mlir::Value> &bounds, bool treatIndexAsSection = false,
-    bool unwrapFirBox = true, bool genDefaultBounds = true) {
+    bool unwrapFirBox = true, bool genDefaultBounds = true,
+    bool strideIncludeLowerExtent = false) {
   using namespace Fortran;
 
   fir::factory::AddrAndBoundsInfo info;
@@ -943,7 +954,8 @@ fir::factory::AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
       asFortran << '(';
       bounds = genBoundsOps<BoundsOp, BoundsType>(
           builder, operandLocation, converter, stmtCtx, arrayRef->subscript(),
-          asFortran, dataExv, dataExvIsAssumedSize, info, treatIndexAsSection);
+          asFortran, dataExv, dataExvIsAssumedSize, info, treatIndexAsSection,
+          strideIncludeLowerExtent);
     }
     asFortran << ')';
   } else if (auto compRef = detail::getRef<evaluate::Component>(designator)) {
@@ -955,7 +967,7 @@ fir::factory::AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
         mlir::isa<fir::SequenceType>(fir::unwrapRefType(info.addr.getType())))
       bounds = fir::factory::genBaseBoundsOps<BoundsOp, BoundsType>(
           builder, operandLocation, compExv,
-          /*isAssumedSize=*/false);
+          /*isAssumedSize=*/false, strideIncludeLowerExtent);
     asFortran << designator.AsFortran();
 
     if (semantics::IsOptional(compRef->GetLastSymbol())) {
@@ -1012,7 +1024,7 @@ fir::factory::AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
       if (genDefaultBounds &&
           mlir::isa<fir::SequenceType>(fir::unwrapRefType(info.addr.getType())))
         bounds = fir::factory::genBaseBoundsOps<BoundsOp, BoundsType>(
-            builder, operandLocation, dataExv, dataExvIsAssumedSize);
+            builder, operandLocation, dataExv, dataExvIsAssumedSize, strideIncludeLowerExtent);
       asFortran << symRef->get().name().ToString();
     } else { // Unsupported
       llvm::report_fatal_error("Unsupported type of OpenACC operand");
diff --git a/flang/include/flang/Optimizer/Builder/DirectivesCommon.h b/flang/include/flang/Optimizer/Builder/DirectivesCommon.h
index c0ab557d26970..eec70a1d62a55 100644
--- a/flang/include/flang/Optimizer/Builder/DirectivesCommon.h
+++ b/flang/include/flang/Optimizer/Builder/DirectivesCommon.h
@@ -203,7 +203,8 @@ genBoundsOpsFromBox(fir::FirOpBuilder &builder, mlir::Location loc,
 template <typename BoundsOp, typename BoundsType>
 llvm::SmallVector<mlir::Value>
 genBaseBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
-                 fir::ExtendedValue dataExv, bool isAssumedSize) {
+                 fir::ExtendedValue dataExv, bool isAssumedSize,
+                 bool strideIncludeLowerExtent = false) {
   mlir::Type idxTy = builder.getIndexType();
   mlir::Type boundTy = builder.getType<BoundsType>();
   llvm::SmallVector<mlir::Value> bounds;
@@ -213,23 +214,30 @@ genBaseBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
 
   mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
   const unsigned rank = dataExv.rank();
+  mlir::Value cummulativeExtent = one;
   for (unsigned dim = 0; dim < rank; ++dim) {
     mlir::Value baseLb =
         fir::factory::readLowerBound(builder, loc, dataExv, dim, one);
     mlir::Value zero = builder.createIntegerConstant(loc, idxTy, 0);
     mlir::Value ub;
     mlir::Value lb = zero;
-    mlir::Value ext = fir::factory::readExtent(builder, loc, dataExv, dim);
+    mlir::Value extent = fir::factory::readExtent(builder, loc, dataExv, dim);
     if (isAssumedSize && dim + 1 == rank) {
-      ext = zero;
+      extent = zero;
       ub = lb;
     } else {
       // ub = extent - 1
-      ub = builder.create<mlir::arith::SubIOp>(loc, ext, one);
+      ub = builder.create<mlir::arith::SubIOp>(loc, extent, one);
+    }
+    mlir::Value stride = one;
+    if (strideIncludeLowerExtent) {
+      stride = cummulativeExtent;
+      cummulativeExtent = builder.createOrFold<mlir::arith::MulIOp>(
+          loc, cummulativeExtent, extent);
     }
 
     mlir::Value bound =
-        builder.create<BoundsOp>(loc, boundTy, lb, ub, ext, one, false, baseLb);
+        builder.create<BoundsOp>(loc, boundTy, lb, ub, extent, stride, false, baseLb);
     bounds.push_back(bound);
   }
   return bounds;
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index c4f2e27f69c3b..418bf4ee3d15f 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -52,6 +52,12 @@ static llvm::cl::opt<bool> generateDefaultBounds(
     llvm::cl::desc("Whether to generate default bounds for arrays."),
     llvm::cl::init(false));
 
+static llvm::cl::opt<bool> strideIncludeLowerExtent(
+    "openacc-stride-include-lower-extent",
+    llvm::cl::desc(
+        "Whether to include the lower dimensions extents in the stride."),
+    llvm::cl::init(true));
+
 // Special value for * passed in device_type or gang clauses.
 static constexpr std::int64_t starCst = -1;
 
@@ -396,7 +402,8 @@ genDataOperandOperations(const Fortran::parser::AccObjectList &objectList,
             converter, builder, semanticsContext, stmtCtx, symbol, designator,
             operandLocation, asFortran, bounds,
             /*treatIndexAsSection=*/true, /*unwrapFirBox=*/unwrapFirBox,
-            /*genDefaultBounds=*/generateDefaultBounds);
+            /*genDefaultBounds=*/generateDefaultBounds,
+            /*strideIncludeLowerExtent=*/strideIncludeLowerExtent);
     LLVM_DEBUG(llvm::dbgs() << __func__ << "\n"; info.dump(llvm::dbgs()));
 
     // If the input value is optional and is not a descriptor, we use the
@@ -437,7 +444,8 @@ static void genDeclareDataOperandOperations(
             converter, builder, semanticsContext, stmtCtx, symbol, designator,
             operandLocation, asFortran, bounds,
             /*treatIndexAsSection=*/true, /*unwrapFirBox=*/unwrapFirBox,
-            /*genDefaultBounds=*/generateDefaultBounds);
+            /*genDefaultBounds=*/generateDefaultBounds,
+            /*strideIncludeLowerExtent=*/strideIncludeLowerExtent);
     LLVM_DEBUG(llvm::dbgs() << __func__ << "\n"; info.dump(llvm::dbgs()));
     EntryOp op = createDataEntryOp<EntryOp>(
         builder, operandLocation, info.addr, asFortran, bounds, structured,
@@ -914,7 +922,8 @@ genPrivatizations(const Fortran::parser::AccObjectList &objectList,
             converter, builder, semanticsContext, stmtCtx, symbol, designator,
             operandLocation, asFortran, bounds,
             /*treatIndexAsSection=*/true, /*unwrapFirBox=*/unwrapFirBox,
-            /*genDefaultBounds=*/generateDefaultBounds);
+            /*genDefaultBounds=*/generateDefaultBounds,
+            /*strideIncludeLowerExtent=*/strideIncludeLowerExtent);
     LLVM_DEBUG(llvm::dbgs() << __func__ << "\n"; info.dump(llvm::dbgs()));
 
     RecipeOp recipe;
@@ -1545,7 +1554,8 @@ genReductions(const Fortran::parser::AccObjectListWithReduction &objectList,
             converter, builder, semanticsContext, stmtCtx, symbol, designator,
             operandLocation, asFortran, bounds,
             /*treatIndexAsSection=*/true, /*unwrapFirBox=*/unwrapFirBox,
-            /*genDefaultBounds=*/generateDefaultBounds);
+            /*genDefaultBounds=*/generateDefaultBounds,
+            /*strideIncludeLowerExtent=*/strideIncludeLowerExtent);
     LLVM_DEBUG(llvm::dbgs() << __func__ << "\n"; info.dump(llvm::dbgs()));
 
     mlir::Type reductionTy = fir::unwrapRefType(info.addr.getType());
diff --git a/flang/test/Lower/OpenACC/acc-bounds.f90 b/flang/test/Lower/OpenACC/acc-bounds.f90
index 8fea357f116a2..cff53a2bfd122 100644
--- a/flang/test/Lower/OpenACC/acc-bounds.f90
+++ b/flang/test/Lower/OpenACC/acc-bounds.f90
@@ -182,4 +182,19 @@ subroutine acc_optional_data3(a, n)
 ! CHECK: %[[NOCREATE:.*]] = acc.nocreate varPtr(%[[DECL_A]]#1 : !fir.ref<!fir.array<?xf32>>) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<?xf32>> {name = "a(1:n)"}
 ! CHECK: acc.data dataOperands(%[[NOCREATE]] : !fir.ref<!fir.array<?xf32>>) {
 
+  subroutine acc_explicit_shape_3d(arr)
+    real :: arr(1000,100,10)
+    !$acc data copyin(arr(:1000,:100,:10))
+    !$acc end data
+  end subroutine
+
+! Test that the stride of each dimension is the cumulative product of the inner extents
+! CHECK-LABEL: func.func @_QMopenacc_boundsPacc_explicit_shape_3d(
+! CHECK-SAME: %[[ARR:.*]]: !fir.ref<!fir.array<1000x100x10xf32>> {fir.bindc_name = "arr"}) {
+! CHECK: %[[BOUND1:.*]] = acc.bounds lowerbound(%c0 : index) upperbound(%c999 : index) extent(%c1000 : index) stride(%c1{{.*}} : index) startIdx(%c1 : index)
+! CHECK: %[[BOUND2:.*]] = acc.bounds lowerbound(%c0 : index) upperbound(%c99 : index) extent(%c100 : index) stride(%c1000{{.*}} : index) startIdx(%c1 : index)
+! CHECK: %[[BOUND3:.*]] = acc.bounds lowerbound(%c0 : index) upperbound(%c9 : index) extent(%c10 : index) stride(%c100000{{.*}} : index) startIdx(%c1 : index)
+! CHECK: %[[COPYIN:.*]] = acc.copyin varPtr({{.*}} : !fir.ref<!fir.array<1000x100x10xf32>>) bounds(%[[BOUND1]], %[[BOUND2]], %[[BOUND3]]) -> !fir.ref<!fir.array<1000x100x10xf32>> {name = "arr(:1000,:100,:10)"}
+! CHECK: acc.data dataOperands(%[[COPYIN]] : !fir.ref<!fir.array<1000x100x10xf32>>) {
+
 end module
diff --git a/flang/test/Lower/OpenACC/acc-enter-data-unwrap-defaultbounds.f90 b/flang/test/Lower/OpenACC/acc-enter-data-unwrap-defaultbounds.f90
index 3de00ebb9eb05..c42350a07c498 100644
--- a/flang/test/Lower/OpenACC/acc-enter-data-unwrap-defaultbounds.f90
+++ b/flang/test/Lower/OpenACC/acc-enter-data-unwrap-defaultbounds.f90
@@ -26,7 +26,7 @@ subroutine acc_enter_data
 !CHECK: %[[BOUND0:.*]] = acc.bounds lowerbound(%[[LB]] : index) upperbound(%[[UB]] : index) extent(%[[C10]] : index) stride(%[[ONE]] : index) startIdx(%[[ONE]] : index)
 !CHECK: %[[LB:.*]] = arith.constant 0 : index
 !CHECK: %[[UB:.*]] = arith.subi %[[EXTENT_C10]], %[[ONE]] : index
-!CHECK: %[[BOUND1:.*]] = acc.bounds lowerbound(%[[LB]] : index) upperbound(%[[UB]] : index) extent(%[[EXTENT_C10]] : index) stride(%[[ONE]] : index) startIdx(%[[ONE]] : index)
+!CHECK: %[[BOUND1:.*]] = acc.bounds lowerbound(%[[LB]] : index) upperbound(%[[UB]] : index) extent(%[[EXTENT_C10]] : index) stride(%c1{{.*}} : index) startIdx(%[[ONE]] : index)
 !CHECK: %[[CREATE_A:.*]] = acc.create varPtr(%[[DECLA]]#0 : !fir.ref<!fir.array<10x10xf32>>) bounds(%[[BOUND0]], %[[BOUND1]]) -> !fir.ref<!fir.array<10x10xf32>> {name = "a", structured = false}
 !CHECK: acc.enter_data dataOperands(%[[CREATE_A]] : !fir.ref<!fir.array<10x10xf32>>){{$}}
 
@@ -37,7 +37,7 @@ subroutine acc_enter_data
 !CHECK: %[[BOUND0:.*]] = acc.bounds lowerbound(%[[LB]] : index) upperbound(%[[UB]] : index) extent(%[[C10]] : index) stride(%[[ONE]] : index) startIdx(%[[ONE]] : index)
 !CHECK: %[[LB:.*]] = arith.constant 0 : index
 !CHECK: %[[UB:.*]] = arith.subi %[[EXTENT_C10]], %[[ONE]] : index
-!CHECK: %[[BOUND1:.*]] = acc.bounds lowerbound(%[[LB]] : index) upperbound(%[[UB]] : index) extent(%[[EXTENT_C10]] : index) stride(%[[ONE]] : index) startIdx(%[[ONE]] : index)
+!CHECK: %[[BOUND1:.*]] = acc.bounds lowerbound(%[[LB]] : index) upperbound(%[[UB]] : index) extent(%[[EXTENT_C10]] : index) stride(%c1{{.*}} : index) startIdx(%[[ONE]] : index)
 !CHECK: %[[CREATE_A:.*]] = acc.create varPtr(%[[DECLA]]#0 : !fir.ref<!fir.array<10x10xf32>>) bounds(%[[BOUND0]], %[[BOUND1]]) -> !fir.ref<!fir.array<10x10xf32>> {name = "a", structured = false}
 !CHECK: [[IF1:%.*]] = arith.constant true
 !CHECK: acc.enter_data if([[IF1]]) dataOperands(%[[CREATE_A]] : !fir.ref<!fir.array<10x10xf32>>){{$}}
@@ -49,7 +49,7 @@ subroutine acc_enter_data
 !CHECK: %[[BOUND0:.*]] = acc.bounds lowerbound(%[[LB]] : index) upperbound(%[[UB]] : index) extent(%[[C10]] : index) stride(%[[ONE]] : index) startIdx(%[[ONE]] : index)
 !CHECK: %[[LB:.*]] = arith.constant 0 : index
 !CHECK: %[[UB:.*]] = arith.subi %[[EXTENT_C10]], %[[ONE]] : index
-!CHECK: %[[BOUND1:.*]] = acc.bounds lowerbound(%[[LB]] : index) upperbound(%[[UB]] : index) extent(%[[EXTENT_C10]] : index) stride(%[[ONE]] : index) startIdx(%[[ONE]] : index)
+!CHECK: %[[BOUND1:.*]] = acc.bounds lowerbound(%[[LB]] : index) upperbound(%[[UB]] : index) extent(%[[EXTENT_C10]] : index) stride(%c1{{.*}} : index) startIdx(%[[ONE]] : index)
 !CHECK: %[[CREATE_A:.*]] = acc.create varPtr(%[[DECLA]]#0 : !fir.ref<!fir.array<10x10xf32>>) bounds(%[[BOUND0]], %[[BOUND1]]) -> !fir.ref<!fir.array<10x10xf32>> {name = "a", structured = false}
 !CHECK: [[IFCOND:%.*]] = fir.load %{{.*}} : !fir.ref<!fir.logical<4>>
 !CHECK: [[IF2:%.*]] = fir.convert [[IFCOND]] : (!fir.logical<4>) -> i1
@@ -161,7 +161,7 @@ subroutine acc_enter_data
 !CHECK: %[[BOUND1:.*]] = acc.bounds lowerbound(%[[LB1]] : index) upperbound(%[[UB1]] : index) extent(%c10{{.*}} : index) stride(%[[ONE]] : index) startIdx(%c1{{.*}} : index)
 !CHECK: %[[LB2:.*]] = arith.constant 0 : index
 !CHECK: %[[UB2:.*]] = arith.constant 4 : index
-!CHECK: %[[BOUND2:.*]] = acc.bounds lowerbound(%[[LB2]] : index) upperbound(%[[UB2]] : index) extent(%[[EXTENT_C10]] : index) stride(%[[ONE]] : index) startIdx(%c1{{.*}} : index)
+!CHECK: %[[BOUND2:.*]] = acc.bounds lowerbound(%[[LB2]] : index) upperbound(%[[UB2]] : index) extent(%[[EXTENT_C10]] : index) stride(%c1{{.*}} : index) startIdx(%c1{{.*}} : index)
 !CHECK: %[[COPYIN_A:.*]] = acc.copyin varPtr(%[[DECLA]]#0 : !fir.ref<!fir.array<10x10xf32>>) bounds(%[[BOUND1]], %[[BOUND2]]) -> !fir.ref<!fir.array<10x10xf32>> {name = "a(1:,1:5)", structured = false}
 !CHECK: acc.enter_data   dataOperands(%[[COPYIN_A]] : !fir.ref<!fir.array<10x10xf32>>)
 
@@ -172,7 +172,7 @@ subroutine acc_enter_data
 !CHECK: %[[BOUND1:.*]] = acc.bounds lowerbound(%[[LB]] : index) upperbound(%[[UB1]] : index) extent(%[[C10]] : index) stride(%[[ONE]] : index) startIdx(%[[ONE]] : index)
 !CHECK: %[[LB:.*]] = arith.constant 0 : index
 !CHECK: %[[UB2:.*]] = arith.constant 4 : index
-!CHECK: %[[BOUND2:.*]] = acc.bounds lowerbound(%[[LB]] : index) upperbound(%[[UB2]] : index) extent(%[[EXTENT_C10]] : index) stride(%[[ONE]] : index) startIdx(%[[ONE]] : index)
+!CHECK: %[[BOUND2:.*]] = acc.bounds lowerbound(%[[LB]] : index) upperbound(%[[UB2]] : index) extent(%[[EXTENT_C10]] : index) stride(%c1{{.*}} : index) startIdx(%[[ONE]] : index)
 !CHECK: %[[COPYIN_A:.*]] = acc.copyin varPtr(%[[DECLA]]#0 : !fir.ref<!fir.array<10x10xf32>>) bounds(%[[BOUND1]], %[[BOUND2]]) -> !fir.ref<!fir.array<10x10xf32>> {name = "a(:10,1:5)", structured = false}
 !CHECK: acc.enter_data dataOperands(%[[COPYIN_A]] : !fir.ref<!fir.array<10x10xf32>>)
 
@@ -182,7 +182,7 @@ subroutine acc_enter_data
 !CHECK: %[[UB:.*]] = arith.subi %c10{{.*}}, %[[ONE]] : index
 !CHECK: %[[BOUND1:.*]] = acc.bounds lowerbound(%[[LB]] : index) upperbound(%[[UB]] : index) extent(%c10{{.*}} : index) stride(%[[ONE]] : index) startIdx(%[[ONE]] : index)
 !CHECK: %[[UB:.*]] = arith.subi %c10{{.*}}, %[[ONE]] : index
-!CHECK: %[[BOUND2:.*]] = acc.bounds lowerbound(%[[LB]] : index) upperbound(%[[UB]] : index) extent(%c10{{.*}} : index) stride(%[[ONE]] : index) startIdx(%[[ONE]] : index)
+!CHECK: %[[BOUND2:.*]] = acc.bounds lowerbound(%[[LB]] : index) upperbound(%[[UB]] : index) extent(%c10{{.*}} : index) stride(%c1{{.*}} : index) startIdx(%[[ONE]] : index)
 !CHECK: %[[COPYIN_A:.*]] = acc.copyin varPtr(%[[DECLA]]#0 : !fir.ref<!fir.array<10x10xf32>>) bounds(%[[BOUND1]], %[[BOUND2]]) -> !fir.ref<!fir.array<10x10xf32>> {name = "a(:,:)", structured = false}
 end subroutine acc_enter_data
 
diff --git a/flang/test/Lower/OpenACC/acc-enter-data.f90 b/flang/test/Lower/OpenACC/acc-enter-data.f90
index d2cd097388828..3e49259c360eb 100644
--- a/flang/test/Lower/OpenACC/acc-enter-data.f90
+++ b/flang/test/Lower/OpenACC/acc-enter-data.f90
@@ -105,7 +105,7 @@ subroutine acc_enter_data
 !CHECK: %[[BOUND1:.*]] = acc.bounds lowerbound(%[[LB1]] : index) upperbound(%[[UB1]] : index) extent(%c10{{.*}} : index) stride(%[[ONE]] : index) startIdx(%c1{{.*}} : index)
 !CHECK: %[[LB2:.*]] = arith.constant 0 : index
 !CHECK: %[[UB2:.*]] = arith.constant 4 : index
-!CHECK: %[[BOUND2:.*]] = acc.bounds lowerbound(%[[LB2]] : index) upperbound(%[[UB2]] : index) extent(%[[EXTENT_C10]] : index) stride(%[[ONE]] : index) startIdx(%c1{{.*}} : index)
+!CHECK: %[[BOUND2:.*]] = acc.bounds lowerbound(%[[LB2]] : index) upperbound(%[[UB2]] : index) extent(%[[EXTENT_C10]] : index) stride(%c1{{.*}} : index) startIdx(%c1{{.*}} : index)
 !CHECK: %[[COPYIN_A:.*]] = acc.copyin varPtr(%[[DECLA]]#0 : !fir.ref<!fir.array<10x10xf32>>) bounds(%[[BOUND1]], %[[BOUND2]]) -> !fir.ref<!fir.array<10x10xf32>> {name = "a(1:,1:5)", structured = false}
 !CHECK: acc.enter_data   dataOperands(%[[COPYIN_A]] : !fir.ref<!fir.array<10x10xf32>>)
 
@@ -116,7 +116,7 @@ subroutine acc_enter_data
 !CHECK: %[[BOUND1:.*]] = acc.bounds lowerbound(%[[LB]] : index) upperbound(%[[UB1]] : index) extent(%[[C10]] : index) stride(%[[ONE]] : index) startIdx(%[[ONE]] : index)
 !CHECK: %[[LB:.*]] = arith.constant 0 : index
 !CHECK: %[[UB2:.*]] = arith.constant 4 : index
-!CHECK: %[[BOUND2:.*]] = acc.bounds lowerbound(%[[LB]] : index) upperbound(%[[UB2]] : index) extent(%[[EXTENT_C10]] : index) stride(%[[ONE]] : index) startIdx(%[[ONE]] : index)
+!CHECK: %[[BOUND2:.*]] = acc.bounds lowerbound(%[[LB]] : index) upperbound(%[[UB2]] : index) extent(%[[EXTENT_C10]] : index) stride(%c1{{.*}} : index) startIdx(%[[ONE]] : index)
 !CHECK: %[[COPYIN_A:.*]] = acc.copyin varPtr(%[[DECLA]]#0 : !fir.ref<!fir.array<10x10xf32>>) bounds(%[[BOUND1]], %[[BOUND2]]) -> !fir.ref<!fir.array<10x10xf32>> {name = "a(:10,1:5)", structured = false}
 !CHECK: acc.enter_data dataOperands(%[[COPYIN_A]] : !fir.ref<!fir.array<10x10xf32>>)
 
@@ -126,7 +126,7 @@ subroutine acc_enter_data
 !CHECK: %[[UB:.*]] = arith.subi %c10{{.*}}, %[[ONE]] : index
 !CHECK: %[[BOUND1:.*]] = acc.bounds lowerbound(%[[LB]] : index) upperbound(%[[UB]] : index) extent(%c10{{.*}} : index) stride(%[[ONE]] : index) startIdx(%[[ONE]] : index)
 !CHECK: %[[UB:.*]] = arith.subi %c10{{.*}}, %[[ONE]] : index
-!CHECK: %[[BOUND2:.*]] = acc.bounds lowerbound(%[[LB]] : index) upperbound(%[[UB]] : index) extent(%c10{{.*}} : index) stride(%[[ONE]] : index) startIdx(%[[ONE]] : index)
+!CHECK: %[[BOUND2:.*]] = acc.bounds lowerbound(%[[LB]] : index) upperbound(%[[UB]] : index) extent(%c10{{.*}} : index) stride(%c1{{.*}} : index) startIdx(%c1{{.*}} : index)
 !CHECK: %[[COPYIN_A:.*]] = acc.copyin varPtr(%[[DECLA]]#0 : !fir.ref<!fir.array<10x10xf32>>) bounds(%[[BOUND1]], %[[BOUND2]]) -> !fir.ref<!fir.array<10x10xf32>> {name = "a(:,:)", structured = false}
 end subroutine acc_enter_data
 
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 70a2ba0919952..275472bc5edd9 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -377,6 +377,13 @@ def OpenACC_DataBoundsOp : OpenACC_Op<"bounds",
     but not checked for consistency). When the source language's arrays are
     not zero-based, the `startIdx` must specify the zero-position index.
 
+    The `stride` is the distance between consecutive elements along this
+    dimension. For multi-dimensional arrays, the `stride` of an outer dimension
+    must include the full extents of all inner (faster-varying) dimensions.
+
+    The `strideInBytes` flag indicates that the `stride` is specified in bytes
+    rather than the number of elements.
+
     Examples below show copying a slice of 10-element array except first element.
     Note that the examples use extent in data clause for C++ and upperbound
     for Fortran (as per 2.7.1). To simplify examples, the constants are used
@@ -389,7 +396,7 @@ def OpenACC_DataBoundsOp : OpenACC_Op<"bounds",
     ```
     =>
     ```mlir
+    acc.bounds lb(1) ub(9) extent(9) stride(1) startIdx(0)
+    acc.bounds lb(1) ub(9) extent(9) startIdx(0) stride(1)
     ```
 
     Fortran:
@@ -399,7 +406,7 @@ def OpenACC_DataBoundsOp : OpenACC_Op<"bounds",
     ```
     =>
     ```mlir
+    acc.bounds lb(1) ub(9) extent(9) stride(1) startIdx(1)
+    acc.bounds lb(1) ub(9) extent(9) startIdx(1) stride(1)
     ```
   }];
 



More information about the flang-commits mailing list