[flang-commits] [flang] [flang] lower remaining cases of pointer assignments inside forall (PR #130772)

via flang-commits flang-commits at lists.llvm.org
Tue Mar 11 06:13:53 PDT 2025


https://github.com/jeanPerier created https://github.com/llvm/llvm-project/pull/130772

Implement handling of `NULL()` RHS, polymorphic pointers, as well as lower bounds or bounds remapping in pointer assignment inside FORALL.

In the end, these cases do not require updating hlfir.region_assign; lowering can simply prepare the new descriptor for the LHS inside the RHS region.

Looking more closely at the polymorphic cases, there is no need to call the runtime: fir.rebox and fir.embox already handle the dynamic type setting correctly.

After this patch, the last remaining TODO is the allocatable assignment inside FORALL, which, like some cases here, is more likely an accidental feature, given that FORALL was deprecated in F2003 at the same time as allocatable components were added.

>From bdf3229ec163b9407e84cce84c85b9d99bfad62d Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Tue, 11 Mar 2025 06:03:47 -0700
Subject: [PATCH] [flang] lower remaining cases of pointer assignments inside
 forall

---
 .../flang/Optimizer/Builder/FIRBuilder.h      |  10 ++
 flang/lib/Lower/Bridge.cpp                    |  99 ++++++++++------
 flang/lib/Lower/ConvertVariable.cpp           |  18 +--
 flang/lib/Optimizer/Builder/FIRBuilder.cpp    |  31 ++++-
 ...l-pointer-assignment-scheduling-bounds.f90 |  93 +++++++++++++++
 ...nter-assignment-scheduling-polymorphic.f90 | 110 ++++++++++++++++++
 .../forall-pointer-assignment-scheduling.f90  |  56 +++++++--
 ...all-proc-pointer-assignment-scheduling.f90 |  33 ++++++
 .../acc-enter-data-unwrap-defaultbounds.f90   |   4 +-
 flang/test/Lower/OpenACC/acc-enter-data.f90   |   4 +-
 10 files changed, 393 insertions(+), 65 deletions(-)
 create mode 100644 flang/test/HLFIR/order_assignments/forall-pointer-assignment-scheduling-bounds.f90
 create mode 100644 flang/test/HLFIR/order_assignments/forall-pointer-assignment-scheduling-polymorphic.f90

diff --git a/flang/include/flang/Optimizer/Builder/FIRBuilder.h b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
index 1675c15363868..003b4358572c1 100644
--- a/flang/include/flang/Optimizer/Builder/FIRBuilder.h
+++ b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
@@ -774,9 +774,19 @@ mlir::Value createZeroValue(fir::FirOpBuilder &builder, mlir::Location loc,
 std::optional<std::int64_t> getExtentFromTriplet(mlir::Value lb, mlir::Value ub,
                                                  mlir::Value stride);
 
+/// Compute the extent value given the lower bound \lb and upper bound \ub.
+/// All inputs must have the same SSA integer type.
+mlir::Value computeExtent(fir::FirOpBuilder &builder, mlir::Location loc,
+                          mlir::Value lb, mlir::Value ub);
+mlir::Value computeExtent(fir::FirOpBuilder &builder, mlir::Location loc,
+                          mlir::Value lb, mlir::Value ub, mlir::Value zero,
+                          mlir::Value one);
+
 /// Generate max(\p value, 0) where \p value is a scalar integer.
 mlir::Value genMaxWithZero(fir::FirOpBuilder &builder, mlir::Location loc,
                            mlir::Value value);
+mlir::Value genMaxWithZero(fir::FirOpBuilder &builder, mlir::Location loc,
+                           mlir::Value value, mlir::Value zero);
 
 /// The type(C_PTR/C_FUNPTR) is defined as the derived type with only one
 /// component of integer 64, and the component is the C address. Get the C
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 93f54d88a029d..d0b26ddc92133 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -4353,30 +4353,12 @@ class FirConverter : public Fortran::lower::AbstractConverter {
                                         stmtCtx);
   }
 
-  void genForallPointerAssignment(
-      mlir::Location loc, const Fortran::evaluate::Assignment &assign,
-      const Fortran::evaluate::Assignment::BoundsSpec &lbExprs) {
-    std::optional<Fortran::evaluate::DynamicType> lhsType =
-        assign.lhs.GetType();
-    // Polymorphic pointer assignment is delegated to the runtime, and
-    // PointerAssociateLowerBounds needs the lower bounds as arguments, so they
-    // must be preserved.
-    if (lhsType && lhsType->IsPolymorphic())
-      TODO(loc, "polymorphic pointer assignment in FORALL");
-    // Nullification is special, there is no RHS that can be prepared,
-    // need to encode it in HLFIR.
-    if (Fortran::evaluate::UnwrapExpr<Fortran::evaluate::NullPointer>(
-            assign.rhs))
-      TODO(loc, "NULL pointer assignment in FORALL");
-    // Lower bounds could be "applied" when preparing RHS, but in order
-    // to deal with the polymorphic case and to reuse existing pointer
-    // assignment helpers in HLFIR codegen, it is better to keep them
-    // separate.
-    if (!lbExprs.empty())
-      TODO(loc, "Pointer assignment with new lower bounds inside FORALL");
-    // Otherwise, this is a "dumb" pointer assignment that can be represented
-    // with hlfir.region_assign with descriptor address/value and later
-    // implemented with a store.
+  void genForallPointerAssignment(mlir::Location loc,
+                                  const Fortran::evaluate::Assignment &assign) {
+    // Lower pointer assignment inside forall with hlfir.region_assign with
+    // descriptor address/value and later implemented with a store.
+    // The RHS is fully prepared in lowering, so that all that is left
+    // in hlfir.region_assign code generation is the store.
     auto regionAssignOp = builder->create<hlfir::RegionAssignOp>(loc);
 
     // Lower LHS in its own region.
@@ -4400,22 +4382,74 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     builder->setInsertionPointAfter(regionAssignOp);
   }
 
+  mlir::Value lowerToIndexValue(mlir::Location loc,
+                                const Fortran::evaluate::ExtentExpr &expr,
+                                Fortran::lower::StatementContext &stmtCtx) {
+    mlir::Value val = fir::getBase(genExprValue(toEvExpr(expr), stmtCtx));
+    return builder->createConvert(loc, builder->getIndexType(), val);
+  }
+
   mlir::Value
   genForallPointerAssignmentRhs(mlir::Location loc, mlir::Value lhs,
                                 const Fortran::evaluate::Assignment &assign,
                                 Fortran::lower::StatementContext &rhsContext) {
-    if (Fortran::evaluate::IsProcedureDesignator(assign.rhs))
+    if (Fortran::evaluate::IsProcedureDesignator(assign.lhs)) {
+      if (Fortran::evaluate::UnwrapExpr<Fortran::evaluate::NullPointer>(
+              assign.rhs))
+        return fir::factory::createNullBoxProc(
+            *builder, loc, fir::unwrapRefType(lhs.getType()));
       return fir::getBase(Fortran::lower::convertExprToAddress(
           loc, *this, assign.rhs, localSymbols, rhsContext));
+    }
     // Data target.
+    auto lhsBoxType =
+        llvm::cast<fir::BaseBoxType>(fir::unwrapRefType(lhs.getType()));
+    // For NULL, create disassociated descriptor whose dynamic type is
+    // the static type of the LHS.
+    if (Fortran::evaluate::UnwrapExpr<Fortran::evaluate::NullPointer>(
+            assign.rhs))
+      return fir::factory::createUnallocatedBox(*builder, loc, lhsBoxType,
+                                                std::nullopt);
     hlfir::Entity rhs = Fortran::lower::convertExprToHLFIR(
         loc, *this, assign.rhs, localSymbols, rhsContext);
     // Create pointer descriptor value from the RHS.
     if (rhs.isMutableBox())
       rhs = hlfir::Entity{builder->create<fir::LoadOp>(loc, rhs)};
-    auto lhsBoxType =
-        llvm::cast<fir::BaseBoxType>(fir::unwrapRefType(lhs.getType()));
-    return hlfir::genVariableBox(loc, *builder, rhs, lhsBoxType);
+    mlir::Value rhsBox = hlfir::genVariableBox(
+        loc, *builder, rhs, lhsBoxType.getBoxTypeWithNewShape(rhs.getRank()));
+    mlir::Type indexTy = builder->getIndexType();
+    // Bounds
+    if (const auto *lbExprs =
+            std::get_if<Fortran::evaluate::Assignment::BoundsSpec>(&assign.u);
+        lbExprs && !lbExprs->empty()) {
+      // Override target lower bounds with the LHS bounds spec.
+      llvm::SmallVector<mlir::Value> lbounds;
+      for (const Fortran::evaluate::ExtentExpr &lbExpr : *lbExprs)
+        lbounds.push_back(lowerToIndexValue(loc, lbExpr, rhsContext));
+      mlir::Value shift = builder->genShift(loc, lbounds);
+      rhsBox = builder->create<fir::ReboxOp>(loc, lhsBoxType, rhsBox, shift,
+                                             /*slice=*/mlir::Value{});
+    } else if (const auto *boundExprs =
+                   std::get_if<Fortran::evaluate::Assignment::BoundsRemapping>(
+                       &assign.u);
+               boundExprs && !boundExprs->empty()) {
+      // Reshape the target according to the LHS bounds remapping.
+      llvm::SmallVector<mlir::Value> lbounds;
+      llvm::SmallVector<mlir::Value> extents;
+      mlir::Type indexTy = builder->getIndexType();
+      mlir::Value zero = builder->createIntegerConstant(loc, indexTy, 0);
+      mlir::Value one = builder->createIntegerConstant(loc, indexTy, 1);
+      for (const auto &[lbExpr, ubExpr] : *boundExprs) {
+        lbounds.push_back(lowerToIndexValue(loc, lbExpr, rhsContext));
+        mlir::Value ub = lowerToIndexValue(loc, ubExpr, rhsContext);
+        extents.push_back(fir::factory::computeExtent(
+            *builder, loc, lbounds.back(), ub, zero, one));
+      }
+      mlir::Value shape = builder->genShape(loc, lbounds, extents);
+      rhsBox = builder->create<fir::ReboxOp>(loc, lhsBoxType, rhsBox, shape,
+                                             /*slice=*/mlir::Value{});
+    }
+    return rhsBox;
   }
 
   // Create the 2 x newRank array with the bounds to be passed to the runtime as
@@ -4856,17 +4890,16 @@ class FirConverter : public Fortran::lower::AbstractConverter {
               },
               [&](const Fortran::evaluate::Assignment::BoundsSpec &lbExprs) {
                 if (isInsideHlfirForallOrWhere())
-                  genForallPointerAssignment(loc, assign, lbExprs);
+                  genForallPointerAssignment(loc, assign);
                 else
                   genPointerAssignment(loc, assign, lbExprs);
               },
               [&](const Fortran::evaluate::Assignment::BoundsRemapping
                       &boundExprs) {
                 if (isInsideHlfirForallOrWhere())
-                  TODO(
-                      loc,
-                      "pointer assignment with bounds remapping inside FORALL");
-                genPointerAssignment(loc, assign, boundExprs);
+                  genForallPointerAssignment(loc, assign);
+                else
+                  genPointerAssignment(loc, assign, boundExprs);
               },
           },
           assign.u);
diff --git a/flang/lib/Lower/ConvertVariable.cpp b/flang/lib/Lower/ConvertVariable.cpp
index 295158f153121..ae6db34e6e06e 100644
--- a/flang/lib/Lower/ConvertVariable.cpp
+++ b/flang/lib/Lower/ConvertVariable.cpp
@@ -1519,17 +1519,6 @@ static bool lowerToBoxValue(const Fortran::semantics::Symbol &sym,
   return false;
 }
 
-/// Compute extent from lower and upper bound.
-static mlir::Value computeExtent(fir::FirOpBuilder &builder, mlir::Location loc,
-                                 mlir::Value lb, mlir::Value ub) {
-  mlir::IndexType idxTy = builder.getIndexType();
-  // Let the folder deal with the common `ub - <const> + 1` case.
-  auto diff = builder.create<mlir::arith::SubIOp>(loc, idxTy, ub, lb);
-  mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
-  auto rawExtent = builder.create<mlir::arith::AddIOp>(loc, idxTy, diff, one);
-  return fir::factory::genMaxWithZero(builder, loc, rawExtent);
-}
-
 /// Lower explicit lower bounds into \p result. Does nothing if this is not an
 /// array, or if the lower bounds are deferred, or all implicit or one.
 static void lowerExplicitLowerBounds(
@@ -1593,8 +1582,8 @@ lowerExplicitExtents(Fortran::lower::AbstractConverter &converter,
       if (lowerBounds.empty())
         result.emplace_back(fir::factory::genMaxWithZero(builder, loc, ub));
       else
-        result.emplace_back(
-            computeExtent(builder, loc, lowerBounds[spec.index()], ub));
+        result.emplace_back(fir::factory::computeExtent(
+            builder, loc, lowerBounds[spec.index()], ub));
     } else if (spec.value()->ubound().isStar()) {
       result.emplace_back(getAssumedSizeExtent(loc, builder));
     }
@@ -2214,7 +2203,8 @@ void Fortran::lower::mapSymbolAttributes(
         if (auto high = spec->ubound().GetExplicit()) {
           auto expr = Fortran::lower::SomeExpr{*high};
           ub = builder.createConvert(loc, idxTy, genValue(expr));
-          extents.emplace_back(computeExtent(builder, loc, lb, ub));
+          extents.emplace_back(
+              fir::factory::computeExtent(builder, loc, lb, ub));
         } else {
           // An assumed size array. The extent is not computed.
           assert(spec->ubound().isStar() && "expected assumed size");
diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
index b3d440cedee07..b7f8a8d3a9d56 100644
--- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp
+++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
@@ -1609,9 +1609,8 @@ fir::factory::getExtentFromTriplet(mlir::Value lb, mlir::Value ub,
 }
 
 mlir::Value fir::factory::genMaxWithZero(fir::FirOpBuilder &builder,
-                                         mlir::Location loc,
-                                         mlir::Value value) {
-  mlir::Value zero = builder.createIntegerConstant(loc, value.getType(), 0);
+                                         mlir::Location loc, mlir::Value value,
+                                         mlir::Value zero) {
   if (mlir::Operation *definingOp = value.getDefiningOp())
     if (auto cst = mlir::dyn_cast<mlir::arith::ConstantOp>(definingOp))
       if (auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(cst.getValue()))
@@ -1622,6 +1621,32 @@ mlir::Value fir::factory::genMaxWithZero(fir::FirOpBuilder &builder,
                                                zero);
 }
 
+mlir::Value fir::factory::genMaxWithZero(fir::FirOpBuilder &builder,
+                                         mlir::Location loc,
+                                         mlir::Value value) {
+  mlir::Value zero = builder.createIntegerConstant(loc, value.getType(), 0);
+  return genMaxWithZero(builder, loc, value, zero);
+}
+
+mlir::Value fir::factory::computeExtent(fir::FirOpBuilder &builder,
+                                        mlir::Location loc, mlir::Value lb,
+                                        mlir::Value ub, mlir::Value zero,
+                                        mlir::Value one) {
+  mlir::Type type = lb.getType();
+  // Let the folder deal with the common `ub - <const> + 1` case.
+  auto diff = builder.create<mlir::arith::SubIOp>(loc, type, ub, lb);
+  auto rawExtent = builder.create<mlir::arith::AddIOp>(loc, type, diff, one);
+  return fir::factory::genMaxWithZero(builder, loc, rawExtent, zero);
+}
+mlir::Value fir::factory::computeExtent(fir::FirOpBuilder &builder,
+                                        mlir::Location loc, mlir::Value lb,
+                                        mlir::Value ub) {
+  mlir::Type type = lb.getType();
+  mlir::Value one = builder.createIntegerConstant(loc, type, 1);
+  mlir::Value zero = builder.createIntegerConstant(loc, type, 0);
+  return computeExtent(builder, loc, lb, ub, zero, one);
+}
+
 static std::pair<mlir::Value, mlir::Type>
 genCPtrOrCFunptrFieldIndex(fir::FirOpBuilder &builder, mlir::Location loc,
                            mlir::Type cptrTy) {
diff --git a/flang/test/HLFIR/order_assignments/forall-pointer-assignment-scheduling-bounds.f90 b/flang/test/HLFIR/order_assignments/forall-pointer-assignment-scheduling-bounds.f90
new file mode 100644
index 0000000000000..00c94d25e7b11
--- /dev/null
+++ b/flang/test/HLFIR/order_assignments/forall-pointer-assignment-scheduling-bounds.f90
@@ -0,0 +1,93 @@
+! Test analysis of pointer assignment inside FORALL with lower bounds or bounds
+! remapping.
+! The analysis must detect if the evaluation of the LHS or RHS may be impacted
+! by the pointer assignments, or if the forall can be lowered into a single
+! loop without any temporary copy.
+
+! RUN: bbc -hlfir -o /dev/null -pass-pipeline="builtin.module(lower-hlfir-ordered-assignments)" \
+! RUN: --debug-only=flang-ordered-assignment -flang-dbg-order-assignment-schedule-only %s 2>&1 | FileCheck %s
+! REQUIRES: asserts
+module forall_pointers_bounds
+  type ptr_wrapper
+    integer, pointer :: p(:, :)
+  end type
+contains
+
+! Simple case that can be lowered into a single loop.
+subroutine test_lb_no_conflict(a, iarray)
+ type(ptr_wrapper) :: a(:)
+ integer, target :: iarray(:, :)
+ forall(i=lbound(a,1):ubound(a,1)) a(i)%p(2*(i-1)+1:,2*i:) => iarray
+end subroutine
+
+subroutine test_remapping_no_conflict(a, iarray)
+ type(ptr_wrapper) :: a(:)
+ integer, target :: iarray(6)
+ ! Reshaping 6 to 2x3 with custom lower bounds.
+ forall(i=lbound(a,1):ubound(a,1)) a(i)%p(2*(i-1)+1:2*(i-1)+2,2*i:2*i+2) => iarray
+end subroutine
+! CHECK: ------------ scheduling forall in _QMforall_pointers_boundsPtest_remapping_no_conflict ------------
+! CHECK-NEXT: run 1 evaluate: forall/region_assign1
+
+! Bounds expression conflict. Note that even though they are syntactically on
+! the LHS,they are saved with the RHS because they are applied when preparing the
+! new descriptor value pointing to the RHS.
+subroutine test_lb_conflict(a, iarray)
+ type(ptr_wrapper) :: a(:)
+ integer, target :: iarray(:, :)
+ integer :: n
+ n = ubound(a,1)
+ forall(i=lbound(a,1):ubound(a,1)) a(i)%p(a(n+1-i)%p(1,1):,a(n+1-i)%p(2,1):) => iarray
+end subroutine
+! CHECK: ------------ scheduling forall in _QMforall_pointers_boundsPtest_lb_conflict ------------
+! CHECK-NEXT: conflict: R/W
+! CHECK-NEXT: run 1 save    : forall/region_assign1/rhs
+! CHECK-NEXT: run 2 evaluate: forall/region_assign1
+
+end module
+
+! End to end test provided for debugging purpose (not run by lit).
+program end_to_end
+  use forall_pointers_bounds
+  integer, parameter :: n = 5
+  integer, target, save :: data(2, 2, n) = reshape([(i, i=1,size(data))], shape=shape(data))
+  integer, target, save :: data2(6) = reshape([(i, i=1,size(data2))], shape=shape(data2))
+  type(ptr_wrapper) :: pointers(n)
+  ! Print pointer/target mapping baseline.
+  call reset_pointers(pointers)
+  if (.not.check_equal(pointers, [17,18,19,20,13,14,15,16,9,10,11,12,5,6,7,8,1,2,3,4])) stop 1
+
+  call reset_pointers(pointers)
+  call test_lb_no_conflict(pointers, data(:, :, 1))
+  if (.not.check_equal(pointers, [([1,2,3,4],i=1,n)])) stop 2
+  if (.not.all([(lbound(pointers(i)%p), i=1,n)].eq.[(i, i=1,2*n)])) stop 3
+
+  call reset_pointers(pointers)
+  call test_remapping_no_conflict(pointers, data2)
+  if (.not.check_equal(pointers, [([1,2,3,4,5,6],i=1,n)])) stop 4
+  if (.not.all([(lbound(pointers(i)%p), i=1,n)].eq.[(i, i=1,2*n)])) stop 5
+  if (.not.all([(ubound(pointers(i)%p), i=1,n)].eq.[([2*(i-1)+2, 2*i+2], i=1,n)])) stop 6
+
+  call reset_pointers(pointers)
+  call test_lb_conflict(pointers, data(:, :, 1))
+  if (.not.check_equal(pointers, [([1,2,3,4],i=1,n)])) stop 7
+  if (.not.all([(lbound(pointers(i)%p), i=1,n)].eq.[([data(1,1,i), data(2,1,i)], i=1,n)])) stop 8
+
+  print *, "PASS"
+contains
+subroutine reset_pointers(a)
+  type(ptr_wrapper) :: a(:)
+  do i=1,n
+    a(i)%p => data(:, :, n+1-i)
+  end do
+end subroutine
+logical function check_equal(a, expected)
+  type(ptr_wrapper) :: a(:)
+  integer :: expected(:)
+  check_equal = all([(a(i)%p, i=1,n)].eq.expected)
+  if (.not.check_equal) then
+    print *, "expected:", expected
+    print *, "got:", [(a(i)%p, i=1,n)]
+  end if
+end function
+end
diff --git a/flang/test/HLFIR/order_assignments/forall-pointer-assignment-scheduling-polymorphic.f90 b/flang/test/HLFIR/order_assignments/forall-pointer-assignment-scheduling-polymorphic.f90
new file mode 100644
index 0000000000000..9ccba7acc1b08
--- /dev/null
+++ b/flang/test/HLFIR/order_assignments/forall-pointer-assignment-scheduling-polymorphic.f90
@@ -0,0 +1,110 @@
+! Test analysis of polymorphic pointer assignment inside FORALL.
+! The analysis must detect if the evaluation of the LHS or RHS may be impacted
+! by the pointer assignments, or if the forall can be lowered into a single
+! loop without any temporary copy.
+
+! RUN: bbc -hlfir -o /dev/null -pass-pipeline="builtin.module(lower-hlfir-ordered-assignments)" \
+! RUN: --debug-only=flang-ordered-assignment -flang-dbg-order-assignment-schedule-only %s 2>&1 | FileCheck %s
+! REQUIRES: asserts
+module forall_poly_pointers
+  type base
+    integer :: i
+  end type
+  type, extends(base) :: extension
+    integer :: j
+  end type
+  type ptr_wrapper
+    class(base), pointer :: p
+  end type
+contains
+
+! Simple case that can be lowered into a single loop.
+subroutine test_no_conflict(n, a, somet)
+ integer :: n
+ type(ptr_wrapper) :: a(:)
+ class(base), target :: somet
+ forall(i=1:n) a(i)%p => somet
+end subroutine
+! CHECK: ------------ scheduling forall in _QMforall_poly_pointersPtest_no_conflict ------------
+! CHECK-NEXT: run 1 evaluate: forall/region_assign1
+
+subroutine test_no_conflict2(n, a, somet)
+ integer :: n
+ type(ptr_wrapper) :: a(:)
+ type(base), target :: somet
+ forall(i=1:n) a(i)%p => somet
+end subroutine
+! CHECK: ------------ scheduling forall in _QMforall_poly_pointersPtest_no_conflict2 ------------
+! CHECK-NEXT: run 1 evaluate: forall/region_assign1
+
+subroutine test_rhs_conflict(n, a)
+ integer :: n
+ type(ptr_wrapper) :: a(:)
+ forall(i=1:n) a(i)%p => a(n+1-i)%p
+end subroutine
+! CHECK: ------------ scheduling forall in _QMforall_poly_pointersPtest_rhs_conflict ------------
+! CHECK-NEXT: conflict: R/W
+! CHECK-NEXT: run 1 save    : forall/region_assign1/rhs
+! CHECK-NEXT: run 2 evaluate: forall/region_assign1
+end module
+
+! End to end test provided for debugging purpose (not run by lit).
+program end_to_end
+  use forall_poly_pointers
+  integer, parameter :: n = 10
+  type(extension), target, save :: data(n) = [(extension(i, 100+i), i=1,n)]
+  type(ptr_wrapper) :: pointers(n)
+  ! Print pointer/target mapping baseline.
+  call reset_pointers(pointers)
+  if (.not.check_equal(pointers, [10,9,8,7,6,5,4,3,2,1])) stop 1
+  if (.not.check_type(pointers, [(modulo(i,3).eq.0, i=1,n)])) stop 2
+
+  ! Test dynamic type is correctly set.
+  call test_no_conflict(n, pointers, data(1))
+  if (.not.check_equal(pointers, [(1,i=1,10)])) stop 3
+  if (.not.check_type(pointers, [(.true.,i=1,10)])) stop 4
+  call test_no_conflict(n, pointers, data(1)%base)
+  if (.not.check_equal(pointers, [(1,i=1,10)])) stop 5
+  if (.not.check_type(pointers, [(.false.,i=1,10)])) stop 6
+
+  call test_no_conflict2(n, pointers, data(1)%base)
+  if (.not.check_equal(pointers, [(1,i=1,10)])) stop 7
+  if (.not.check_type(pointers, [(.false.,i=1,10)])) stop 8
+
+  ! Test RHS conflict.
+  call reset_pointers(pointers)
+  call test_rhs_conflict(n, pointers)
+  if (.not.check_equal(pointers, [(i, i=1,10)])) stop 9
+  if (.not.check_type(pointers, [(modulo(i,3).eq.2, i=1,n)])) stop 10
+
+  print *, "PASS"
+contains
+subroutine reset_pointers(a)
+  type(ptr_wrapper) :: a(:)
+  do i=1,n
+    if (modulo(i,3).eq.0) then
+      a(i)%p => data(n+1-i)
+    else
+      a(i)%p => data(n+1-i)%base
+    end if
+  end do
+end subroutine
+logical function check_equal(a, expected)
+  type(ptr_wrapper) :: a(:)
+  integer :: expected(:)
+  check_equal = all([(a(i)%p%i, i=1,10)].eq.expected)
+  if (.not.check_equal) then
+    print *, "expected:", expected
+    print *, "got:", [(a(i)%p%i, i=1,10)]
+  end if
+end function
+logical function check_type(a, expected)
+  type(ptr_wrapper) :: a(:)
+  logical :: expected(:)
+  check_type = all([(same_type_as(a(i)%p, extension(1,1)), i=1,10)].eqv.expected)
+  if (.not.check_type) then
+    print *, "expected:", expected
+    print *, "got:", [(same_type_as(a(i)%p, extension(1,1)), i=1,10)]
+  end if
+end function
+end
diff --git a/flang/test/HLFIR/order_assignments/forall-pointer-assignment-scheduling.f90 b/flang/test/HLFIR/order_assignments/forall-pointer-assignment-scheduling.f90
index 52a0105ce2b6a..cb5bff1020b3a 100644
--- a/flang/test/HLFIR/order_assignments/forall-pointer-assignment-scheduling.f90
+++ b/flang/test/HLFIR/order_assignments/forall-pointer-assignment-scheduling.f90
@@ -25,6 +25,14 @@ subroutine test_no_conflict(n, a, somet)
 ! CHECK: ------------ scheduling forall in _QMforall_pointersPtest_no_conflict ------------
 ! CHECK-NEXT: run 1 evaluate: forall/region_assign1
 
+subroutine test_null_no_conflict(n, a)
+ integer :: n
+ type(ptr_wrapper), allocatable :: a(:)
+ forall(i=1:n) a(i)%p => null()
+end subroutine
+! CHECK: ------------ scheduling forall in _QMforall_pointersPtest_null_no_conflict ------------
+! CHECK-NEXT: run 1 evaluate: forall/region_assign1
+
 ! Case where the pointer target evaluations are impacted by the pointer
 ! assignments and should be evaluated for each iteration before doing
 ! any pointer assignment.
@@ -53,6 +61,16 @@ subroutine test_need_to_save_lhs(n, a, somet)
 ! CHECK-NEXT: run 1 save    : forall/region_assign1/lhs
 ! CHECK-NEXT: run 2 evaluate: forall/region_assign1
 
+subroutine test_null_need_to_save_lhs(n, a)
+ integer :: n
+ type(ptr_wrapper) :: a(:)
+ forall(i=1:n) a(a(n+1-i)%p%i)%p => null()
+end subroutine
+! CHECK: ------------ scheduling forall in _QMforall_pointersPtest_null_need_to_save_lhs ------------
+! CHECK-NEXT: conflict: R/W
+! CHECK-NEXT: run 1 save    : forall/region_assign1/lhs
+! CHECK-NEXT: run 2 evaluate: forall/region_assign1
+
 ! Case where both the computation of the target and descriptor addresses are
 ! impacted by the assignment and need to be all evaluated before doing any
 ! assignment.
@@ -76,27 +94,29 @@ program end_to_end
   type(t), target, save :: data(n) = [(t(i), i=1,n)]
   type(ptr_wrapper) :: pointers(n)
   ! Print pointer/target mapping baseline.
-  ! Expect: 10 9 8 7 6 5 4 3 2 1
   call reset_pointers(pointers)
-  call print_pointers(pointers)
+  if (.not.check_equal(pointers, [10,9,8,7,6,5,4,3,2,1])) stop 1
 
   ! Test case where RHS target addresses must be saved in FORALL.
-  ! Expect: 1 2 3 4 5 6 7 8 9 10
   call test_need_to_save_rhs(n, pointers)
-  call print_pointers(pointers)
+  if (.not.check_equal(pointers, [1,2,3,4,5,6,7,8,9,10])) stop 2
 
   ! Test case where LHS pointer addresses must be saved in FORALL.
-  ! Expect: 1 1 1 1 1 1 1 1 1 1
   call reset_pointers(pointers)
   call test_need_to_save_lhs(n, pointers, data(1))
-  call print_pointers(pointers)
+  if (.not.check_equal(pointers, [(1,i=1,10)])) stop 3
 
   ! Test case where bot RHS target addresses and LHS pointer addresses must be
   ! saved in FORALL.
-  ! Expect: 2 4 6 8 10 1 3 5 7 9
   call reset_pointers(pointers)
   call test_need_to_save_lhs_and_rhs(n, pointers)
-  call print_pointers(pointers)
+  if (.not.check_equal(pointers, [2,4,6,8,10,1,3,5,7,9])) stop 4
+
+  call reset_pointers(pointers)
+  call test_null_need_to_save_lhs(n, pointers)
+  if (.not.check_associated(pointers, [(.false., i=1,n)])) stop 5
+
+  print *, "PASS"
 contains
 subroutine reset_pointers(a)
   type(ptr_wrapper) :: a(:)
@@ -104,8 +124,22 @@ subroutine reset_pointers(a)
     a(i)%p => data(n+1-i)
   end do
 end subroutine
-subroutine print_pointers(a)
+logical function check_equal(a, expected)
   type(ptr_wrapper) :: a(:)
-  print *, [(a(i)%p%i, i=lbound(a,1), ubound(a,1))]
-end subroutine
+  integer :: expected(:)
+  check_equal = all([(a(i)%p%i, i=1,10)].eq.expected)
+  if (.not.check_equal) then
+    print *, "expected:", expected
+    print *, "got:", [(a(i)%p%i, i=1,10)]
+  end if
+end function
+logical function check_associated(a, expected)
+  type(ptr_wrapper) :: a(:)
+  logical :: expected(:)
+  check_associated = all([(associated(a(i)%p), i=1,10)].eqv.expected)
+  if (.not.check_associated) then
+    print *, "expected:", expected
+    print *, "got:", [(associated(a(i)%p), i=1,10)]
+  end if
+end function
 end
diff --git a/flang/test/HLFIR/order_assignments/forall-proc-pointer-assignment-scheduling.f90 b/flang/test/HLFIR/order_assignments/forall-proc-pointer-assignment-scheduling.f90
index ba9c203453d95..0cce790470cb4 100644
--- a/flang/test/HLFIR/order_assignments/forall-proc-pointer-assignment-scheduling.f90
+++ b/flang/test/HLFIR/order_assignments/forall-proc-pointer-assignment-scheduling.f90
@@ -80,6 +80,23 @@ subroutine test_need_to_save_lhs_and_rhs(x)
 ! CHECK-NEXT: run 1 save    : forall/region_assign1/lhs
 ! CHECK-NEXT: run 2 evaluate: forall/region_assign1
 
+  subroutine test_null_no_conflict(x)
+    type(t) :: x(10)
+    forall(i=1:10) x(i)%p => null()
+  end subroutine
+! CHECK: ------------ scheduling forall in _QMproc_ptr_forallPtest_null_no_conflict ------------
+! CHECK-NEXT: run 1 evaluate: forall/region_assign1
+
+  subroutine test_null_need_to_save_lhs(x)
+    type(t) :: x(10)
+    forall(i=1:10) x(x(11-i)%p())%p => null()
+  end subroutine
+! CHECK: ------------ scheduling forall in _QMproc_ptr_forallPtest_null_need_to_save_lhs ------------
+! CHECK-NEXT: unknown effect: %{{.*}} = fir.call
+! CHECK-NEXT: unknown effect: %{{.*}} = fir.call
+! CHECK-NEXT: conflict: R/W
+! CHECK-NEXT: run 1 save    : forall/region_assign1/lhs
+! CHECK-NEXT: run 2 evaluate: forall/region_assign1
 
 ! End-to-end test utilities for debugging purposes.
 
@@ -102,6 +119,17 @@ logical function check_equal(a, expected)
       print *, "got:", [(a(i)%p(), i=1,10)]
     end if
   end function
+
+  logical function check_association(a, expected)
+    type(t) :: a(:)
+    logical :: expected(:)
+    check_association = all([(associated(a(i)%p), i=1,10)].eqv.expected)
+    if (.not.check_association) then
+      print *, "expected:", expected
+      print *, "got:", [(associated(a(i)%p), i=1,10)]
+    end if
+  end function
+
 end module
 
 ! End-to-end test for debugging purposes (not verified by lit).
@@ -119,5 +147,10 @@ logical function check_equal(a, expected)
   call reset(a)
   call test_need_to_save_lhs_and_rhs(a)
   if (.not.check_equal(a, [2, 4, 6, 8, 10, 1, 3, 5, 7, 9])) stop 3
+
+  call reset(a)
+  call test_null_need_to_save_lhs(a)
+  if (.not.check_association(a, [(.false., i=1,10)])) stop 4
+
   print *, "PASS"
 end
diff --git a/flang/test/Lower/OpenACC/acc-enter-data-unwrap-defaultbounds.f90 b/flang/test/Lower/OpenACC/acc-enter-data-unwrap-defaultbounds.f90
index 6bdd1031eeb4e..b6d76134f14af 100644
--- a/flang/test/Lower/OpenACC/acc-enter-data-unwrap-defaultbounds.f90
+++ b/flang/test/Lower/OpenACC/acc-enter-data-unwrap-defaultbounds.f90
@@ -203,10 +203,10 @@ subroutine acc_enter_data_dummy(a, b, n, m)
 !CHECK: %[[LOAD_M:.*]] = fir.load %[[DECLM]]#0 : !fir.ref<i32>
 !CHECK: %[[M_I64:.*]] = fir.convert %[[LOAD_M]] : (i32) -> i64
 !CHECK: %[[M_IDX:.*]] = fir.convert %[[M_I64]] : (i64) -> index
-!CHECK: %[[M_N:.*]] = arith.subi %[[M_IDX]], %[[N_IDX]] : index
 !CHECK: %[[C1:.*]] = arith.constant 1 : index
-!CHECK: %[[M_N_1:.*]] = arith.addi %[[M_N]], %[[C1]] : index
 !CHECK: %[[C0:.*]] = arith.constant 0 : index
+!CHECK: %[[M_N:.*]] = arith.subi %[[M_IDX]], %[[N_IDX]] : index
+!CHECK: %[[M_N_1:.*]] = arith.addi %[[M_N]], %[[C1]] : index
 !CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[M_N_1]], %[[C0]] : index
 !CHECK: %[[EXT_B:.*]] = arith.select %[[CMP]], %[[M_N_1]], %[[C0]] : index
 !CHECK: %[[DECLB:.*]]:2 = hlfir.declare %[[B]]
diff --git a/flang/test/Lower/OpenACC/acc-enter-data.f90 b/flang/test/Lower/OpenACC/acc-enter-data.f90
index 8892dec7d1197..2b7cda468f70f 100644
--- a/flang/test/Lower/OpenACC/acc-enter-data.f90
+++ b/flang/test/Lower/OpenACC/acc-enter-data.f90
@@ -147,10 +147,10 @@ subroutine acc_enter_data_dummy(a, b, n, m)
 !CHECK: %[[LOAD_M:.*]] = fir.load %[[DECLM]]#0 : !fir.ref<i32>
 !CHECK: %[[M_I64:.*]] = fir.convert %[[LOAD_M]] : (i32) -> i64
 !CHECK: %[[M_IDX:.*]] = fir.convert %[[M_I64]] : (i64) -> index
-!CHECK: %[[M_N:.*]] = arith.subi %[[M_IDX]], %[[N_IDX]] : index
 !CHECK: %[[C1:.*]] = arith.constant 1 : index
-!CHECK: %[[M_N_1:.*]] = arith.addi %[[M_N]], %[[C1]] : index
 !CHECK: %[[C0:.*]] = arith.constant 0 : index
+!CHECK: %[[M_N:.*]] = arith.subi %[[M_IDX]], %[[N_IDX]] : index
+!CHECK: %[[M_N_1:.*]] = arith.addi %[[M_N]], %[[C1]] : index
 !CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[M_N_1]], %[[C0]] : index
 !CHECK: %[[EXT_B:.*]] = arith.select %[[CMP]], %[[M_N_1]], %[[C0]] : index
 !CHECK: %[[DECLB:.*]]:2 = hlfir.declare %[[B]]



More information about the flang-commits mailing list