[flang-commits] [flang] b87e655 - [flang][hlfir] Lower forall to HLFIR

Jean Perier via flang-commits flang-commits at lists.llvm.org
Tue May 9 00:20:42 PDT 2023


Author: Jean Perier
Date: 2023-05-09T09:20:23+02:00
New Revision: b87e65531c58df55cfae4c06c7a68f84539aa779

URL: https://github.com/llvm/llvm-project/commit/b87e65531c58df55cfae4c06c7a68f84539aa779
DIFF: https://github.com/llvm/llvm-project/commit/b87e65531c58df55cfae4c06c7a68f84539aa779.diff

LOG: [flang][hlfir] Lower forall to HLFIR

Lower Forall to the previously added hlfir.forall, hlfir.forall_mask,
hlfir.forall_index, and hlfir.region_assign operations.

The HLFIR assignment lowering code is moved into genDataAssignment for
better readability, and so that user-defined assignment (still a TODO)
will be able to share most of the logic.

Differential Revision: https://reviews.llvm.org/D149878

Added: 
    flang/test/Lower/HLFIR/forall.f90

Modified: 
    flang/lib/Lower/Bridge.cpp

Removed: 
    


################################################################################
diff  --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index fdc6649972471..fe86fe8cb2dd9 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -1937,7 +1937,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
   }
 
   void genFIR(const Fortran::parser::EndForallStmt &) {
-    cleanupExplicitSpace();
+    if (!lowerToHighLevelFIR()) // Only FIR lowering prepared explicit space.
+      cleanupExplicitSpace();
   }
 
   template <typename A>
@@ -1956,11 +1957,24 @@ class FirConverter : public Fortran::lower::AbstractConverter {
 
   /// Generate FIR for a FORALL statement.
   void genFIR(const Fortran::parser::ForallStmt &stmt) {
+    const auto &concurrentHeader =
+        std::get<
+            Fortran::common::Indirection<Fortran::parser::ConcurrentHeader>>(
+            stmt.t)
+            .value();
+    if (lowerToHighLevelFIR()) {
+      mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
+      localSymbols.pushScope();
+      genForallNest(concurrentHeader);
+      genFIR(std::get<Fortran::parser::UnlabeledStatement<
+                 Fortran::parser::ForallAssignmentStmt>>(stmt.t)
+                 .statement);
+      localSymbols.popScope();
+      builder->restoreInsertionPoint(insertPt);
+      return;
+    }
     prepareExplicitSpace(stmt);
-    genFIR(std::get<
-               Fortran::common::Indirection<Fortran::parser::ConcurrentHeader>>(
-               stmt.t)
-               .value());
+    genFIR(concurrentHeader);
     genFIR(std::get<Fortran::parser::UnlabeledStatement<
                Fortran::parser::ForallAssignmentStmt>>(stmt.t)
                .statement);
@@ -1969,7 +1983,11 @@ class FirConverter : public Fortran::lower::AbstractConverter {
 
   /// Generate FIR for a FORALL construct.
   void genFIR(const Fortran::parser::ForallConstruct &forall) {
-    prepareExplicitSpace(forall);
+    mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
+    if (lowerToHighLevelFIR())
+      localSymbols.pushScope();
+    else
+      prepareExplicitSpace(forall);
     genNestedStatement(
         std::get<
             Fortran::parser::Statement<Fortran::parser::ForallConstructStmt>>(
@@ -1987,14 +2005,101 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     genNestedStatement(
         std::get<Fortran::parser::Statement<Fortran::parser::EndForallStmt>>(
             forall.t));
+    if (lowerToHighLevelFIR()) {
+      localSymbols.popScope();
+      builder->restoreInsertionPoint(insertPt);
+    }
   }
 
   /// Lower the concurrent header specification.
   void genFIR(const Fortran::parser::ForallConstructStmt &stmt) {
-    genFIR(std::get<
-               Fortran::common::Indirection<Fortran::parser::ConcurrentHeader>>(
-               stmt.t)
-               .value());
+    const auto &concurrentHeader =
+        std::get<
+            Fortran::common::Indirection<Fortran::parser::ConcurrentHeader>>(
+            stmt.t)
+            .value();
+    if (lowerToHighLevelFIR()) // HLFIR: open the hlfir.forall nest.
+      genForallNest(concurrentHeader);
+    else
+      genFIR(concurrentHeader);
+  }
+
+  /// Generate the hlfir.forall nest (plus hlfir.forall_mask when masked) for
+  /// \p header, leaving the insertion point inside the innermost body.
+  void genForallNest(const Fortran::parser::ConcurrentHeader &header) {
+    mlir::Location loc = getCurrentLocation();
+    const bool isOutermostForall = !isInsideHlfirForallOrWhere();
+    hlfir::ForallOp outerForall;
+    auto evaluateControl = [&](const auto &parserExpr, mlir::Region &region,
+                               bool isMask = false) {
+      if (region.empty())
+        builder->createBlock(&region);
+      Fortran::lower::StatementContext localStmtCtx;
+      const Fortran::semantics::SomeExpr *analyzedExpr =
+          Fortran::semantics::GetExpr(parserExpr);
+      assert(analyzedExpr && "expression semantics failed");
+      // Generate the controls of outer forall outside of the hlfir.forall
+      // region. They do not depend on any previous forall indices (C1123) and
+      // no assignment has been made yet that could modify their value. This
+      // will simplify hlfir.forall analysis because the SSA integer value
+      // yielded will obviously not depend on any variable modified by the
+      // forall when produced outside of it.
+      // This is not done for the mask because it may (and in usual code, does)
+      // depend on the forall indices that have just been defined as
+      // hlfir.forall block arguments.
+      mlir::OpBuilder::InsertPoint innerInsertionPoint;
+      if (outerForall && !isMask) {
+        innerInsertionPoint = builder->saveInsertionPoint();
+        builder->setInsertionPoint(outerForall);
+      }
+      mlir::Value exprVal =
+          fir::getBase(genExprValue(*analyzedExpr, localStmtCtx, &loc));
+      localStmtCtx.finalizeAndPop();
+      if (isMask)
+        exprVal = builder->createConvert(loc, builder->getI1Type(), exprVal);
+      if (innerInsertionPoint.isSet())
+        builder->restoreInsertionPoint(innerInsertionPoint);
+      builder->create<hlfir::YieldOp>(loc, exprVal);
+    };
+    for (const Fortran::parser::ConcurrentControl &control :
+         std::get<std::list<Fortran::parser::ConcurrentControl>>(header.t)) {
+      auto forallOp = builder->create<hlfir::ForallOp>(loc);
+      if (isOutermostForall && !outerForall) // Root of the outermost nest.
+        outerForall = forallOp;
+      evaluateControl(std::get<1>(control.t), forallOp.getLbRegion());
+      evaluateControl(std::get<2>(control.t), forallOp.getUbRegion());
+      if (const auto &optionalStep =
+              std::get<std::optional<Fortran::parser::ScalarIntExpr>>(
+                  control.t))
+        evaluateControl(*optionalStep, forallOp.getStepRegion());
+      // Create the index block argument and map the symbol to it through an
+      // hlfir.forall_index op (symbols must be mapped to in-memory values).
+      const Fortran::semantics::Symbol *controlVar =
+          std::get<Fortran::parser::Name>(control.t).symbol;
+      assert(controlVar && "symbol analysis failed");
+      mlir::Type controlVarType = genType(*controlVar);
+      mlir::Block *forallBody = builder->createBlock(&forallOp.getBody(), {},
+                                                     {controlVarType}, {loc});
+      auto forallIndex = builder->create<hlfir::ForallIndexOp>(
+          loc, fir::ReferenceType::get(controlVarType),
+          forallBody->getArguments()[0],
+          builder->getStringAttr(controlVar->name().ToString()));
+      localSymbols.addVariableDefinition(*controlVar, forallIndex,
+                                         /*force=*/true);
+      auto end = builder->create<fir::FirEndOp>(loc);
+      builder->setInsertionPoint(end);
+    }
+
+    if (const auto &maskExpr =
+            std::get<std::optional<Fortran::parser::ScalarLogicalExpr>>(
+                header.t)) {
+      // Create hlfir.forall_mask and set insertion point in its body.
+      auto forallMaskOp = builder->create<hlfir::ForallMaskOp>(loc);
+      evaluateControl(*maskExpr, forallMaskOp.getMaskRegion(), /*isMask=*/true);
+      builder->createBlock(&forallMaskOp.getBody());
+      auto end = builder->create<fir::FirEndOp>(loc);
+      builder->setInsertionPoint(end);
+    }
   }
 
   void genFIR(const Fortran::parser::CompilerDirective &) {
@@ -2991,13 +3096,14 @@ class FirConverter : public Fortran::lower::AbstractConverter {
   /// DestroyOp in case the returned value has hlfir::ExprType.
   mlir::Value
   genImplicitLogicalConvert(const Fortran::evaluate::Assignment &assign,
-                            hlfir::Entity lhs, hlfir::Entity rhs) {
+                            hlfir::Entity rhs,
+                            Fortran::lower::StatementContext &stmtCtx) {
     mlir::Type fromTy = rhs.getFortranElementType();
-    mlir::Type toTy = lhs.getFortranElementType();
-    if (fromTy == toTy)
+    if (!fromTy.isa<mlir::IntegerType, fir::LogicalType>())
       return nullptr;
 
-    if (!fromTy.isa<mlir::IntegerType, fir::LogicalType>())
+    mlir::Type toTy = hlfir::getFortranElementType(genType(assign.lhs));
+    if (fromTy == toTy)
       return nullptr;
     if (!toTy.isa<mlir::IntegerType, fir::LogicalType>())
       return nullptr;
@@ -3015,76 +3121,147 @@ class FirConverter : public Fortran::lower::AbstractConverter {
       auto val = hlfir::loadTrivialScalar(loc, builder, elementPtr);
       return hlfir::EntityWithAttributes{builder.createConvert(loc, toTy, val)};
     };
-    return hlfir::genElementalOp(loc, builder, toTy, shape, /*typeParams=*/{},
-                                 genKernel);
+    mlir::Value convertedRhs = hlfir::genElementalOp(
+        loc, builder, toTy, shape, /*typeParams=*/{}, genKernel);
+    fir::FirOpBuilder *bldr = &builder;
+    stmtCtx.attachCleanup([loc, bldr, convertedRhs]() {
+      bldr->create<hlfir::DestroyOp>(loc, convertedRhs);
+    });
+    return convertedRhs;
+  }
+
+  static void // Emit \p context cleanups, if any, at the end of \p region.
+  genCleanUpInRegionIfAny(mlir::Location loc, fir::FirOpBuilder &builder,
+                          mlir::Region &region,
+                          Fortran::lower::StatementContext &context) {
+    if (!context.hasCode()) // No cleanup: leave the region untouched.
+      return;
+    mlir::OpBuilder::InsertPoint insertPt = builder.saveInsertionPoint();
+    if (region.empty())
+      builder.createBlock(&region);
+    else
+      builder.setInsertionPointToEnd(&region.front());
+    context.finalizeAndPop();
+    hlfir::YieldOp::ensureTerminator(region, builder, loc);
+    builder.restoreInsertionPoint(insertPt); // Back to the caller's point.
+  }
+
+  void genDataAssignment( // Lower an intrinsic or user defined assignment.
+      const Fortran::evaluate::Assignment &assign,
+      const Fortran::evaluate::ProcedureRef *userDefinedAssignment) {
+    mlir::Location loc = getCurrentLocation();
+    fir::FirOpBuilder &builder = getFirOpBuilder();
+    // Gather some information about the assignment that will impact how it is
+    // lowered.
+    const bool isWholeAllocatableAssignment =
+        !userDefinedAssignment &&
+        Fortran::lower::isWholeAllocatable(assign.lhs);
+    std::optional<Fortran::evaluate::DynamicType> lhsType =
+        assign.lhs.GetType();
+    const bool keepLhsLengthInAllocatableAssignment =
+        isWholeAllocatableAssignment && lhsType.has_value() &&
+        lhsType->category() == Fortran::common::TypeCategory::Character &&
+        !lhsType->HasDeferredTypeParameter();
+    const bool lhsHasVectorSubscripts =
+        Fortran::evaluate::HasVectorSubscript(assign.lhs);
+
+    // Helper to generate the code evaluating the right-hand side.
+    auto evaluateRhs = [&](Fortran::lower::StatementContext &stmtCtx) {
+      hlfir::Entity rhs = Fortran::lower::convertExprToHLFIR(
+          loc, *this, assign.rhs, localSymbols, stmtCtx);
+      // Load trivial scalar RHS to allow the loads to be hoisted outside of
+      // loops early if possible. This also dereferences pointer and
+      // allocatable RHS: the target is being assigned from.
+      rhs = hlfir::loadTrivialScalar(loc, builder, rhs);
+      // In intrinsic assignments, Logical<->Integer assignments are allowed as
+      // an extension, but there is no explicit Convert expression for the RHS.
+      // Recognize the type mismatch here and insert explicit scalar convert or
+      // ElementalOp for array assignment.
+      if (!userDefinedAssignment)
+        if (mlir::Value conversion =
+                genImplicitLogicalConvert(assign, rhs, stmtCtx))
+          rhs = hlfir::Entity{conversion};
+      return rhs;
+    };
+
+    // Helper to generate the code evaluating the left-hand side.
+    auto evaluateLhs = [&](Fortran::lower::StatementContext &stmtCtx) {
+      hlfir::Entity lhs = Fortran::lower::convertExprToHLFIR(
+          loc, *this, assign.lhs, localSymbols, stmtCtx);
+      // Dereference pointer LHS: the target is being assigned to.
+      // Same for allocatables outside of whole allocatable assignments.
+      if (!isWholeAllocatableAssignment)
+        lhs = hlfir::derefPointersAndAllocatables(loc, builder, lhs);
+      return lhs;
+    };
+
+    if (!isInsideHlfirForallOrWhere() && !lhsHasVectorSubscripts &&
+        !userDefinedAssignment) { // Simple case: a single hlfir.assign.
+      Fortran::lower::StatementContext localStmtCtx;
+      hlfir::Entity rhs = evaluateRhs(localStmtCtx);
+      hlfir::Entity lhs = evaluateLhs(localStmtCtx);
+      builder.create<hlfir::AssignOp>(loc, rhs, lhs,
+                                      isWholeAllocatableAssignment,
+                                      keepLhsLengthInAllocatableAssignment);
+      return;
+    }
+    // Assignments inside FORALL or WHERE, and assignments to a vector
+    // subscripted left-hand side, require an hlfir.region_assign in HLFIR.
+    // The right-hand side and left-hand side must be evaluated inside the
+    // hlfir.region_assign regions.
+    auto regionAssignOp = builder.create<hlfir::RegionAssignOp>(loc);
+
+    // Lower RHS in its own region.
+    builder.createBlock(&regionAssignOp.getRhsRegion());
+    Fortran::lower::StatementContext rhsContext;
+    hlfir::Entity rhs = evaluateRhs(rhsContext);
+    auto rhsYieldOp = builder.create<hlfir::YieldOp>(loc, rhs);
+    genCleanUpInRegionIfAny(loc, builder, rhsYieldOp.getCleanup(), rhsContext);
+    // Lower LHS in its own region.
+    builder.createBlock(&regionAssignOp.getLhsRegion());
+    Fortran::lower::StatementContext lhsContext;
+    if (!lhsHasVectorSubscripts) {
+      hlfir::Entity lhs = evaluateLhs(lhsContext);
+      auto lhsYieldOp = builder.create<hlfir::YieldOp>(loc, lhs);
+      genCleanUpInRegionIfAny(loc, builder, lhsYieldOp.getCleanup(),
+                              lhsContext);
+    } else {
+      TODO(loc, "assignment to vector subscripted entity");
+    }
+
+    // TODO: add a "realloc" flag to hlfir.region_assign (not supported yet).
+    if (isWholeAllocatableAssignment)
+      TODO(loc, "assignment to a whole allocatable inside FORALL");
+    // TODO: generate the hlfir.region_assign user defined assignment region.
+    if (userDefinedAssignment)
+      TODO(loc, "HLFIR user defined assignment");
+
+    builder.setInsertionPointAfter(regionAssignOp);
   }
 
   /// Shared for both assignments and pointer assignments.
   void genAssignment(const Fortran::evaluate::Assignment &assign) {
     mlir::Location loc = toLocation();
     if (lowerToHighLevelFIR()) {
-      if (explicitIterationSpace() || !implicitIterSpace.empty())
-        TODO(loc, "HLFIR assignment inside FORALL or WHERE");
-      auto &builder = getFirOpBuilder();
+      if (!implicitIterSpace.empty())
+        TODO(loc, "HLFIR assignment inside WHERE");
       std::visit(
           Fortran::common::visitors{
-              // [1] Plain old assignment.
               [&](const Fortran::evaluate::Assignment::Intrinsic &) {
-                if (Fortran::evaluate::HasVectorSubscript(assign.lhs))
-                  TODO(loc, "assignment to vector subscripted entity");
-                Fortran::lower::StatementContext stmtCtx;
-                hlfir::Entity rhs = Fortran::lower::convertExprToHLFIR(
-                    loc, *this, assign.rhs, localSymbols, stmtCtx);
-                // Load trivial scalar LHS to allow the loads to be hoisted
-                // outside of loops early if possible. This also dereferences
-                // pointer and allocatable RHS: the target is being assigned
-                // from.
-                rhs = hlfir::loadTrivialScalar(loc, builder, rhs);
-                hlfir::Entity lhs = Fortran::lower::convertExprToHLFIR(
-                    loc, *this, assign.lhs, localSymbols, stmtCtx);
-                bool isWholeAllocatableAssignment = false;
-                bool keepLhsLengthInAllocatableAssignment = false;
-                if (Fortran::lower::isWholeAllocatable(assign.lhs)) {
-                  isWholeAllocatableAssignment = true;
-                  if (std::optional<Fortran::evaluate::DynamicType> lhsType =
-                          assign.lhs.GetType())
-                    keepLhsLengthInAllocatableAssignment =
-                        lhsType->category() ==
-                            Fortran::common::TypeCategory::Character &&
-                        !lhsType->HasDeferredTypeParameter();
-                } else {
-                  // Dereference pointer LHS: the target is being assigned to.
-                  lhs = hlfir::derefPointersAndAllocatables(loc, builder, lhs);
-                }
-
-                // Logical<->Integer assignments are allowed as an extension,
-                // but there is no explicit Convert expression for the RHS.
-                // Recognize the type mismatch here and insert explicit
-                // scalar convert or ElementalOp for array assignment.
-                mlir::Value logicalConvert =
-                    genImplicitLogicalConvert(assign, lhs, rhs);
-                if (logicalConvert)
-                  rhs = hlfir::EntityWithAttributes{logicalConvert};
-
-                builder.create<hlfir::AssignOp>(
-                    loc, rhs, lhs, isWholeAllocatableAssignment,
-                    keepLhsLengthInAllocatableAssignment);
-
-                // Mark the end of life range of the ElementalOp's result.
-                if (logicalConvert &&
-                    logicalConvert.getType().isa<hlfir::ExprType>())
-                  builder.create<hlfir::DestroyOp>(loc, rhs);
+                genDataAssignment(assign, /*userDefinedAssignment=*/nullptr);
               },
-              // [2] User defined assignment. If the context is a scalar
-              // expression then call the procedure.
               [&](const Fortran::evaluate::ProcedureRef &procRef) {
-                TODO(loc, "HLFIR user defined assignment");
+                genDataAssignment(assign, /*userDefinedAssignment=*/&procRef);
               },
               [&](const Fortran::evaluate::Assignment::BoundsSpec &lbExprs) {
+                if (isInsideHlfirForallOrWhere())
+                  TODO(loc, "pointer assignment inside FORALL");
                 genPointerAssignment(loc, assign, lbExprs);
               },
               [&](const Fortran::evaluate::Assignment::BoundsRemapping
                       &boundExprs) {
+                if (isInsideHlfirForallOrWhere())
+                  TODO(loc, "pointer assignment inside FORALL");
                 genPointerAssignment(loc, assign, boundExprs);
               },
           },
@@ -3275,6 +3452,17 @@ class FirConverter : public Fortran::lower::AbstractConverter {
       Fortran::lower::createArrayMergeStores(*this, explicitIterSpace);
   }
 
+  bool isInsideHlfirForallOrWhere() const {
+    mlir::Block *block = builder->getInsertionBlock();
+    mlir::Operation *op = block ? block->getParentOp() : nullptr;
+    while (op) { // Walk outward from the insertion block's parent op.
+      if (mlir::isa<hlfir::ForallOp, hlfir::WhereOp>(op))
+        return true;
+      op = op->getParentOp();
+    }
+    return false; // No enclosing hlfir.forall or hlfir.where.
+  }
+
   void genFIR(const Fortran::parser::WhereConstruct &c) {
     implicitIterSpace.growStack();
     genNestedStatement(

diff  --git a/flang/test/Lower/HLFIR/forall.f90 b/flang/test/Lower/HLFIR/forall.f90
new file mode 100644
index 0000000000000..5091dde7bc79c
--- /dev/null
+++ b/flang/test/Lower/HLFIR/forall.f90
@@ -0,0 +1,185 @@
+! Test lowering of Forall to HLFIR.
+! RUN: bbc --hlfir -o - %s | FileCheck %s
+
+module forall_defs ! Data and pure interfaces shared by the tests below.
+  integer :: x(10, 10), y(10)
+  interface
+    pure integer(8) function ifoo2(i, j)
+      integer(8), value :: i, j
+    end function
+    pure integer(8) function jfoo()
+    end function
+    pure integer(8) function jbar()
+    end function
+    pure logical function predicate(i)
+      integer(8), intent(in) :: i
+    end function
+  end interface
+end module
+
+subroutine test_simple_forall() ! One index, constant bounds, no step/mask.
+  use forall_defs
+  forall (integer(8)::i=1:10) x(i, i) = y(i)
+end subroutine
+! CHECK-LABEL:   func.func @_QPtest_simple_forall() {
+! CHECK:  %[[VAL_0:.*]] = arith.constant 10 : i32
+! CHECK:  %[[VAL_1:.*]] = arith.constant 1 : i32
+! CHECK:  %[[VAL_5:.*]]:2 = hlfir.declare {{.*}}Ex
+! CHECK:  %[[VAL_8:.*]]:2 = hlfir.declare {{.*}}Ey
+! CHECK:  hlfir.forall lb {
+! CHECK:    hlfir.yield %[[VAL_1]] : i32
+! CHECK:  } ub {
+! CHECK:    hlfir.yield %[[VAL_0]] : i32
+! CHECK:  }  (%[[VAL_9:.*]]: i64) {
+! CHECK:    hlfir.region_assign {
+! CHECK:      %[[VAL_10:.*]] = hlfir.designate %[[VAL_8]]#0 (%[[VAL_9]])  : (!fir.ref<!fir.array<10xi32>>, i64) -> !fir.ref<i32>
+! CHECK:      %[[VAL_11:.*]] = fir.load %[[VAL_10]] : !fir.ref<i32>
+! CHECK:      hlfir.yield %[[VAL_11]] : i32
+! CHECK:    } to {
+! CHECK:      %[[VAL_12:.*]] = hlfir.designate %[[VAL_5]]#0 (%[[VAL_9]], %[[VAL_9]])  : (!fir.ref<!fir.array<10x10xi32>>, i64, i64) -> !fir.ref<i32>
+! CHECK:      hlfir.yield %[[VAL_12]] : !fir.ref<i32>
+! CHECK:    }
+! CHECK:  }
+
+subroutine test_forall_step(step) ! Forall with a runtime step value.
+  use forall_defs
+  integer :: step
+  forall (integer(8)::i=1:10:step) x(i, i) = y(i)
+end subroutine
+! CHECK-LABEL:   func.func @_QPtest_forall_step(
+! CHECK:  %[[VAL_1:.*]] = arith.constant 10 : i32
+! CHECK:  %[[VAL_2:.*]] = arith.constant 1 : i32
+! CHECK:  %[[VAL_4:.*]]:2 = hlfir.declare {{.*}}Estep
+! CHECK:  %[[VAL_7:.*]]:2 = hlfir.declare {{.*}}Ex
+! CHECK:  %[[VAL_10:.*]]:2 = hlfir.declare {{.*}}Ey
+! CHECK:  %[[VAL_11:.*]] = fir.load %[[VAL_4]]#1 : !fir.ref<i32>
+! CHECK:  hlfir.forall lb {
+! CHECK:    hlfir.yield %[[VAL_2]] : i32
+! CHECK:  } ub {
+! CHECK:    hlfir.yield %[[VAL_1]] : i32
+! CHECK:  } step {
+! CHECK:    hlfir.yield %[[VAL_11]] : i32
+! CHECK:  }  (%[[VAL_12:.*]]: i64) {
+! CHECK:    hlfir.region_assign {
+! CHECK:      %[[VAL_13:.*]] = hlfir.designate %[[VAL_10]]#0 (%[[VAL_12]])  : (!fir.ref<!fir.array<10xi32>>, i64) -> !fir.ref<i32>
+! CHECK:      %[[VAL_14:.*]] = fir.load %[[VAL_13]] : !fir.ref<i32>
+! CHECK:      hlfir.yield %[[VAL_14]] : i32
+! CHECK:    } to {
+! CHECK:      %[[VAL_15:.*]] = hlfir.designate %[[VAL_7]]#0 (%[[VAL_12]], %[[VAL_12]])  : (!fir.ref<!fir.array<10x10xi32>>, i64, i64) -> !fir.ref<i32>
+! CHECK:      hlfir.yield %[[VAL_15]] : !fir.ref<i32>
+! CHECK:    }
+! CHECK:  }
+
+subroutine test_forall_mask() ! Mask that depends on the forall index.
+  use forall_defs
+  forall (integer(8)::i=1:10, predicate(i)) x(i, i) = y(i)
+end subroutine
+! CHECK-LABEL:   func.func @_QPtest_forall_mask() {
+! CHECK:  %[[VAL_0:.*]] = arith.constant 10 : i32
+! CHECK:  %[[VAL_1:.*]] = arith.constant 1 : i32
+! CHECK:  %[[VAL_5:.*]]:2 = hlfir.declare {{.*}}Ex
+! CHECK:  %[[VAL_8:.*]]:2 = hlfir.declare {{.*}}Ey
+! CHECK:  hlfir.forall lb {
+! CHECK:    hlfir.yield %[[VAL_1]] : i32
+! CHECK:  } ub {
+! CHECK:    hlfir.yield %[[VAL_0]] : i32
+! CHECK:  }  (%[[VAL_9:.*]]: i64) {
+! CHECK:    %[[VAL_10:.*]] = hlfir.forall_index "i" %[[VAL_9]] : (i64) -> !fir.ref<i64>
+! CHECK:    hlfir.forall_mask {
+! CHECK:      %[[VAL_11:.*]] = fir.call @_QPpredicate(%[[VAL_10]]) fastmath<contract> : (!fir.ref<i64>) -> !fir.logical<4>
+! CHECK:      %[[VAL_12:.*]] = fir.convert %[[VAL_11]] : (!fir.logical<4>) -> i1
+! CHECK:      hlfir.yield %[[VAL_12]] : i1
+! CHECK:    } do {
+! CHECK:      hlfir.region_assign {
+! CHECK:        %[[VAL_13:.*]] = hlfir.designate %[[VAL_8]]#0 (%[[VAL_9]])  : (!fir.ref<!fir.array<10xi32>>, i64) -> !fir.ref<i32>
+! CHECK:        %[[VAL_14:.*]] = fir.load %[[VAL_13]] : !fir.ref<i32>
+! CHECK:        hlfir.yield %[[VAL_14]] : i32
+! CHECK:      } to {
+! CHECK:        %[[VAL_15:.*]] = hlfir.designate %[[VAL_5]]#0 (%[[VAL_9]], %[[VAL_9]])  : (!fir.ref<!fir.array<10x10xi32>>, i64, i64) -> !fir.ref<i32>
+! CHECK:        hlfir.yield %[[VAL_15]] : !fir.ref<i32>
+! CHECK:      }
+! CHECK:    }
+! CHECK:  }
+
+subroutine test_forall_several_indices() ! ibar/ifoo: implicit interfaces.
+  use forall_defs
+  ! Test that the controls of a non-nested forall are lowered before it.
+  forall (integer(8)::i=ibar():ifoo(), j=jfoo():jbar()) x(i, j) = y(ifoo2(i, j))
+end subroutine
+! CHECK-LABEL:   func.func @_QPtest_forall_several_indices() {
+! CHECK:  %[[VAL_3:.*]]:2 = hlfir.declare {{.*}}Ex
+! CHECK:  %[[VAL_6:.*]]:2 = hlfir.declare {{.*}}Ey
+! CHECK:  %[[VAL_7:.*]] = fir.call @_QPibar() fastmath<contract> : () -> i32
+! CHECK:  %[[VAL_8:.*]] = fir.call @_QPifoo() fastmath<contract> : () -> i32
+! CHECK:  %[[VAL_9:.*]] = fir.call @_QPjfoo() fastmath<contract> : () -> i64
+! CHECK:  %[[VAL_10:.*]] = fir.call @_QPjbar() fastmath<contract> : () -> i64
+! CHECK:  hlfir.forall lb {
+! CHECK:    hlfir.yield %[[VAL_7]] : i32
+! CHECK:  } ub {
+! CHECK:    hlfir.yield %[[VAL_8]] : i32
+! CHECK:  }  (%[[VAL_11:.*]]: i64) {
+! CHECK:    hlfir.forall lb {
+! CHECK:      hlfir.yield %[[VAL_9]] : i64
+! CHECK:    } ub {
+! CHECK:      hlfir.yield %[[VAL_10]] : i64
+! CHECK:    }  (%[[VAL_12:.*]]: i64) {
+! CHECK:      hlfir.region_assign {
+! CHECK:        %[[VAL_13:.*]] = fir.call @_QPifoo2(%[[VAL_11]], %[[VAL_12]]) fastmath<contract> : (i64, i64) -> i64
+! CHECK:        %[[VAL_14:.*]] = hlfir.designate %[[VAL_6]]#0 (%[[VAL_13]])  : (!fir.ref<!fir.array<10xi32>>, i64) -> !fir.ref<i32>
+! CHECK:        %[[VAL_15:.*]] = fir.load %[[VAL_14]] : !fir.ref<i32>
+! CHECK:        hlfir.yield %[[VAL_15]] : i32
+! CHECK:      } to {
+! CHECK:        %[[VAL_16:.*]] = hlfir.designate %[[VAL_3]]#0 (%[[VAL_11]], %[[VAL_12]])  : (!fir.ref<!fir.array<10x10xi32>>, i64, i64) -> !fir.ref<i32>
+! CHECK:        hlfir.yield %[[VAL_16]] : !fir.ref<i32>
+! CHECK:      }
+! CHECK:    }
+! CHECK:  }
+
+subroutine test_nested_foralls()
+  use forall_defs
+  forall (integer(8)::i=1:10)
+    x(i, i) = y(i)
+    ! jfoo and jbar could depend on x since it is a use-associated module
+    ! variable. The calls in the control value computation cannot be
+    ! hoisted out of the outer forall even though they do not depend on
+    ! the outer forall indices.
+    forall (integer(8)::j=jfoo():jbar())
+      x(i, j) = x(j, i)
+    end forall
+  end forall
+end subroutine
+! CHECK-LABEL:   func.func @_QPtest_nested_foralls() {
+! CHECK:  %[[VAL_0:.*]] = arith.constant 10 : i32
+! CHECK:  %[[VAL_1:.*]] = arith.constant 1 : i32
+! CHECK:  %[[VAL_5:.*]]:2 = hlfir.declare {{.*}}Ex
+! CHECK:  %[[VAL_8:.*]]:2 = hlfir.declare {{.*}}Ey
+! CHECK:  hlfir.forall lb {
+! CHECK:    hlfir.yield %[[VAL_1]] : i32
+! CHECK:  } ub {
+! CHECK:    hlfir.yield %[[VAL_0]] : i32
+! CHECK:  }  (%[[VAL_9:.*]]: i64) {
+! CHECK:    hlfir.region_assign {
+! CHECK:      %[[VAL_10:.*]] = hlfir.designate %[[VAL_8]]#0 (%[[VAL_9]])  : (!fir.ref<!fir.array<10xi32>>, i64) -> !fir.ref<i32>
+! CHECK:      %[[VAL_11:.*]] = fir.load %[[VAL_10]] : !fir.ref<i32>
+! CHECK:      hlfir.yield %[[VAL_11]] : i32
+! CHECK:    } to {
+! CHECK:      %[[VAL_12:.*]] = hlfir.designate %[[VAL_5]]#0 (%[[VAL_9]], %[[VAL_9]])  : (!fir.ref<!fir.array<10x10xi32>>, i64, i64) -> !fir.ref<i32>
+! CHECK:      hlfir.yield %[[VAL_12]] : !fir.ref<i32>
+! CHECK:    }
+! CHECK:    hlfir.forall lb {
+! CHECK:      %[[VAL_13:.*]] = fir.call @_QPjfoo() fastmath<contract> : () -> i64
+! CHECK:      hlfir.yield %[[VAL_13]] : i64
+! CHECK:    } ub {
+! CHECK:      %[[VAL_14:.*]] = fir.call @_QPjbar() fastmath<contract> : () -> i64
+! CHECK:      hlfir.yield %[[VAL_14]] : i64
+! CHECK:    }  (%[[VAL_15:.*]]: i64) {
+! CHECK:      hlfir.region_assign {
+! CHECK:        %[[VAL_16:.*]] = hlfir.designate %[[VAL_5]]#0 (%[[VAL_15]], %[[VAL_9]])  : (!fir.ref<!fir.array<10x10xi32>>, i64, i64) -> !fir.ref<i32>
+! CHECK:        %[[VAL_17:.*]] = fir.load %[[VAL_16]] : !fir.ref<i32>
+! CHECK:        hlfir.yield %[[VAL_17]] : i32
+! CHECK:      } to {
+! CHECK:        %[[VAL_18:.*]] = hlfir.designate %[[VAL_5]]#0 (%[[VAL_9]], %[[VAL_15]])  : (!fir.ref<!fir.array<10x10xi32>>, i64, i64) -> !fir.ref<i32>
+! CHECK:        hlfir.yield %[[VAL_18]] : !fir.ref<i32>
+! CHECK:      }
+! CHECK:    }
+! CHECK:  }


        


More information about the flang-commits mailing list