[flang-commits] [flang] [flang] Do not hoist all scalar sub-expressions from WHERE constructs (PR #91395)

via flang-commits flang-commits at lists.llvm.org
Mon May 13 02:03:46 PDT 2024


https://github.com/jeanPerier updated https://github.com/llvm/llvm-project/pull/91395

From 60a5b9b2e233ea1a12d8c558b4ca6dedbcbb44a5 Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Tue, 7 May 2024 03:17:46 -0700
Subject: [PATCH 1/3] [flang] Do not hoist all scalar expressions from WHERE
 constructs

The HLFIR pass lowering WHERE (hlfir.where op) was too aggressive
in its hoisting of scalar sub-expressions from LHS/RHS/MASKS outside
of the loops generated for the WHERE construct.
This violated F'2023 10.2.3.2 point 10, which stipulates that elemental
operations must be evaluated only for elements corresponding to true
values, because scalar operations are still elemental, and hoisting
them is invalid if they could have side effects (e.g., division by zero)
and if the MASK is always false (i.e., the loop body is never evaluated).
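
To make the failure mode concrete, here is a minimal sketch in plain C++ (no
MLIR involved) of a masked assignment such as `WHERE (mask) a = a + 1/n`. All
names below (`checkedDiv`, `whereHoisted`, `whereMasked`) are hypothetical and
only illustrate the issue; they are not part of flang. A throwing division
stands in for the side effect of dividing by zero.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for a scalar operation with a side effect:
// an integer division that traps on a zero divisor.
int checkedDiv(int num, int den) {
  if (den == 0)
    throw std::domain_error("division by zero");
  return num / den;
}

// Models the overly aggressive lowering: the scalar `1/n` is hoisted
// out of the loop, so it is evaluated even if no mask element is true.
void whereHoisted(const std::vector<bool> &mask, std::vector<int> &a, int n) {
  assert(mask.size() == a.size());
  int scalar = checkedDiv(1, n); // evaluated unconditionally
  for (std::size_t i = 0; i < a.size(); ++i)
    if (mask[i])
      a[i] += scalar;
}

// Models the intended behavior: the scalar operation stays under mask
// control and is only evaluated for elements whose mask value is true.
void whereMasked(const std::vector<bool> &mask, std::vector<int> &a, int n) {
  assert(mask.size() == a.size());
  for (std::size_t i = 0; i < a.size(); ++i)
    if (mask[i])
      a[i] += checkedDiv(1, n);
}
```

With an all-false mask and `n == 0`, `whereMasked` leaves `a` untouched, while
`whereHoisted` throws even though the loop body is never executed.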

The difficulty is that 10.2.3.2 point 9 mandates that nonelemental functions
must be evaluated before the loops. So it is not possible to simply stop
hoisting operations that are not hlfir.elemental.
Marking calls with an elemental/nonelemental attribute would not allow
the pass to be correct if inlining runs first and drops this information.
Besides, extracting the argument tree, which may have been CSE-ed with the rest
of the expression evaluation, would be a bit cumbersome.

Instead, lower nonelemental calls into a new hlfir.exactly_once operation
that will allow retaining the information that the operations contained inside
its region must be hoisted. This allows inlining to operate before if desired
in order to improve alias analysis.

The LowerHLFIROrderedAssignments pass is updated to only hoist the operations
contained inside hlfir.exactly_once bodies.
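
The resulting evaluation order can be sketched with a small plain C++ model,
under the assumption that one callable plays the role of an
hlfir.exactly_once body and another plays the role of the remaining scalar
elemental computation. The function and parameter names are hypothetical and
do not correspond to the pass's actual API.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Toy model of the updated scheduling for a masked assignment.
// `exactlyOnceBody` stands for the contents of an hlfir.exactly_once
// region (a nonelemental call and its arguments); `maskedScalarOp`
// stands for scalar elemental work that must remain under mask control.
void whereWithExactlyOnce(const std::vector<bool> &mask, std::vector<int> &a,
                          const std::function<int()> &exactlyOnceBody,
                          const std::function<int(int)> &maskedScalarOp) {
  assert(mask.size() == a.size());
  // Hoisted: evaluated once before the loop, regardless of the mask.
  int hoisted = exactlyOnceBody();
  for (std::size_t i = 0; i < a.size(); ++i)
    if (mask[i])
      // Not hoisted: only evaluated where the mask is true.
      a[i] += maskedScalarOp(hoisted);
}
```

Counting calls in the two callables shows the intended split: with an
all-false mask the first callable still runs once and the second never runs.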
---
 flang/include/flang/Lower/StatementContext.h  |  14 ++
 .../include/flang/Optimizer/HLFIR/HLFIROps.td |  24 ++-
 flang/lib/Lower/Bridge.cpp                    |  45 ++--
 flang/lib/Lower/ConvertCall.cpp               |  37 ++++
 .../LowerHLFIROrderedAssignments.cpp          |  93 ++++++--
 .../HLFIR/order_assignments/impure-where.fir  |   9 +-
 .../order_assignments/inlined-stack-temp.fir  |   2 +-
 .../user-defined-assignment-finalization.fir  |  31 +--
 .../where-codegen-no-conflict.fir             |   4 +-
 .../order_assignments/where-hoisting.f90      |  50 +++++
 flang/test/Lower/HLFIR/where-nonelemental.f90 | 198 ++++++++++++++++++
 11 files changed, 448 insertions(+), 59 deletions(-)
 create mode 100644 flang/test/HLFIR/order_assignments/where-hoisting.f90
 create mode 100644 flang/test/Lower/HLFIR/where-nonelemental.f90

diff --git a/flang/include/flang/Lower/StatementContext.h b/flang/include/flang/Lower/StatementContext.h
index cec9641d43a08..7776edc93ed73 100644
--- a/flang/include/flang/Lower/StatementContext.h
+++ b/flang/include/flang/Lower/StatementContext.h
@@ -18,6 +18,15 @@
 #include <functional>
 #include <optional>
 
+namespace mlir {
+class Location;
+class Region;
+} // namespace mlir
+
+namespace fir {
+class FirOpBuilder;
+}
+
 namespace Fortran::lower {
 
 /// When lowering a statement, temporaries for intermediate results may be
@@ -105,6 +114,11 @@ class StatementContext {
   llvm::SmallVector<std::optional<CleanupFunction>> cufs;
 };
 
+/// If \p context contains any cleanups, ensure \p region has a block, and
+/// generate the cleanup inside that block.
+void genCleanUpInRegionIfAny(mlir::Location loc, fir::FirOpBuilder &builder,
+                             mlir::Region &region, StatementContext &context);
+
 } // namespace Fortran::lower
 
 #endif // FORTRAN_LOWER_STATEMENTCONTEXT_H
diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
index ee3c26800ae3a..fdb21656a27fc 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
@@ -1329,7 +1329,8 @@ def hlfir_RegionAssignOp : hlfir_Op<"region_assign", [hlfir_OrderedAssignmentTre
 }
 
 def hlfir_YieldOp : hlfir_Op<"yield", [Terminator, ParentOneOf<["RegionAssignOp",
-    "ElementalAddrOp", "ForallOp", "ForallMaskOp", "WhereOp", "ElseWhereOp"]>,
+    "ElementalAddrOp", "ForallOp", "ForallMaskOp", "WhereOp", "ElseWhereOp",
+    "ExactlyOnceOp"]>,
     SingleBlockImplicitTerminator<"fir::FirEndOp">, RecursivelySpeculatable,
         RecursiveMemoryEffects]> {
 
@@ -1594,6 +1595,27 @@ def hlfir_ForallMaskOp : hlfir_AssignmentMaskOp<"forall_mask"> {
   let hasVerifier = 1;
 }
 
+def hlfir_ExactlyOnceOp : hlfir_Op<"exactly_once", [RecursiveMemoryEffects]> {
+  let summary = "Execute exactly once its region in a WhereOp";
+  let description = [{
+    Inside a Where assignment, Fortran requires a non elemental call and its
+    arguments to be executed exactly once, regardless of the mask values.
+    This operation allows holding these evaluations that cannot be hoisted
+    until potential parent Forall loops have been created.
+    It also allows inlining the calls without losing the information that
+    these calls must be hoisted.
+  }];
+
+  let regions = (region SizedRegion<1>:$body);
+
+  let results = (outs AnyFortranEntity:$result);
+
+  let assemblyFormat = [{
+    attr-dict `:` type($result)
+    $body
+    }];
+}
+
 def hlfir_WhereOp : hlfir_AssignmentMaskOp<"where"> {
   let summary = "Represent a Fortran where construct or statement";
   let description = [{
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index ae8679afc603f..beda4d7328525 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -3677,22 +3677,6 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     return hlfir::Entity{valueAndPair.first};
   }
 
-  static void
-  genCleanUpInRegionIfAny(mlir::Location loc, fir::FirOpBuilder &builder,
-                          mlir::Region &region,
-                          Fortran::lower::StatementContext &context) {
-    if (!context.hasCode())
-      return;
-    mlir::OpBuilder::InsertPoint insertPt = builder.saveInsertionPoint();
-    if (region.empty())
-      builder.createBlock(&region);
-    else
-      builder.setInsertionPointToEnd(&region.front());
-    context.finalizeAndPop();
-    hlfir::YieldOp::ensureTerminator(region, builder, loc);
-    builder.restoreInsertionPoint(insertPt);
-  }
-
   bool firstDummyIsPointerOrAllocatable(
       const Fortran::evaluate::ProcedureRef &userDefinedAssignment) {
     using DummyAttr = Fortran::evaluate::characteristics::DummyDataObject::Attr;
@@ -3918,7 +3902,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     Fortran::lower::StatementContext rhsContext;
     hlfir::Entity rhs = evaluateRhs(rhsContext);
     auto rhsYieldOp = builder.create<hlfir::YieldOp>(loc, rhs);
-    genCleanUpInRegionIfAny(loc, builder, rhsYieldOp.getCleanup(), rhsContext);
+    Fortran::lower::genCleanUpInRegionIfAny(
+        loc, builder, rhsYieldOp.getCleanup(), rhsContext);
     // Lower LHS in its own region.
     builder.createBlock(&regionAssignOp.getLhsRegion());
     Fortran::lower::StatementContext lhsContext;
@@ -3926,15 +3911,15 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     if (!lhsHasVectorSubscripts) {
       hlfir::Entity lhs = evaluateLhs(lhsContext);
       auto lhsYieldOp = builder.create<hlfir::YieldOp>(loc, lhs);
-      genCleanUpInRegionIfAny(loc, builder, lhsYieldOp.getCleanup(),
-                              lhsContext);
+      Fortran::lower::genCleanUpInRegionIfAny(
+          loc, builder, lhsYieldOp.getCleanup(), lhsContext);
       lhsYield = lhs;
     } else {
       hlfir::ElementalAddrOp elementalAddr =
           Fortran::lower::convertVectorSubscriptedExprToElementalAddr(
               loc, *this, assign.lhs, localSymbols, lhsContext);
-      genCleanUpInRegionIfAny(loc, builder, elementalAddr.getCleanup(),
-                              lhsContext);
+      Fortran::lower::genCleanUpInRegionIfAny(
+          loc, builder, elementalAddr.getCleanup(), lhsContext);
       lhsYield = elementalAddr.getYieldOp().getEntity();
     }
     assert(lhsYield && "must have been set");
@@ -4289,7 +4274,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
         loc, *this, *maskExpr, localSymbols, maskContext);
     mask = hlfir::loadTrivialScalar(loc, *builder, mask);
     auto yieldOp = builder->create<hlfir::YieldOp>(loc, mask);
-    genCleanUpInRegionIfAny(loc, *builder, yieldOp.getCleanup(), maskContext);
+    Fortran::lower::genCleanUpInRegionIfAny(loc, *builder, yieldOp.getCleanup(),
+                                            maskContext);
   }
   void genFIR(const Fortran::parser::WhereConstructStmt &stmt) {
     const Fortran::semantics::SomeExpr *maskExpr = Fortran::semantics::GetExpr(
@@ -5545,3 +5531,18 @@ Fortran::lower::LoweringBridge::LoweringBridge(
   fir::support::setMLIRDataLayout(*module.get(),
                                   targetMachine.createDataLayout());
 }
+
+void Fortran::lower::genCleanUpInRegionIfAny(
+    mlir::Location loc, fir::FirOpBuilder &builder, mlir::Region &region,
+    Fortran::lower::StatementContext &context) {
+  if (!context.hasCode())
+    return;
+  mlir::OpBuilder::InsertPoint insertPt = builder.saveInsertionPoint();
+  if (region.empty())
+    builder.createBlock(&region);
+  else
+    builder.setInsertionPointToEnd(&region.front());
+  context.finalizeAndPop();
+  hlfir::YieldOp::ensureTerminator(region, builder, loc);
+  builder.restoreInsertionPoint(insertPt);
+}
diff --git a/flang/lib/Lower/ConvertCall.cpp b/flang/lib/Lower/ConvertCall.cpp
index 3659dad367b42..f989bc7e017f3 100644
--- a/flang/lib/Lower/ConvertCall.cpp
+++ b/flang/lib/Lower/ConvertCall.cpp
@@ -2682,10 +2682,47 @@ bool Fortran::lower::isIntrinsicModuleProcRef(
   return module && module->attrs().test(Fortran::semantics::Attr::INTRINSIC);
 }
 
+static bool isInWhereMaskedExpression(fir::FirOpBuilder& builder) {
+  // The MASK of the outer WHERE is not masked itself.
+  mlir::Operation* op = builder.getRegion().getParentOp();
+  return op && op->getParentOfType<hlfir::WhereOp>();
+}
+
 std::optional<hlfir::EntityWithAttributes> Fortran::lower::convertCallToHLFIR(
     mlir::Location loc, Fortran::lower::AbstractConverter &converter,
     const evaluate::ProcedureRef &procRef, std::optional<mlir::Type> resultType,
     Fortran::lower::SymMap &symMap, Fortran::lower::StatementContext &stmtCtx) {
+  auto &builder = converter.getFirOpBuilder();
+  if (resultType && !procRef.IsElemental() && isInWhereMaskedExpression(builder) &&
+      !builder.getRegion().getParentOfType<hlfir::ExactlyOnceOp>()) {
+    // Non elemental calls inside a where-assignment-stmt must be executed
+    // exactly once without mask control. Lower them in a special region so that
+    // this can be enforced when scheduling forall/where expression evaluation.
+    Fortran::lower::StatementContext localStmtCtx;
+    mlir::Type bogusType = builder.getIndexType();
+    auto exactlyOnce = builder.create<hlfir::ExactlyOnceOp>(loc, bogusType);
+    mlir::Block *block = builder.createBlock(&exactlyOnce.getBody());
+    builder.setInsertionPointToStart(block);
+    CallContext callContext(procRef, resultType, loc, converter, symMap,
+                            localStmtCtx);
+    std::optional<hlfir::EntityWithAttributes> res =
+        genProcedureRef(callContext);
+    assert(res.has_value() && "must be a function");
+    auto yield = builder.create<hlfir::YieldOp>(loc, *res);
+    Fortran::lower::genCleanUpInRegionIfAny(loc, builder, yield.getCleanup(),
+                                            localStmtCtx);
+    builder.setInsertionPointAfter(exactlyOnce);
+    exactlyOnce->getResult(0).setType(res->getType());
+    if (hlfir::isFortranValue(exactlyOnce.getResult()))
+      return hlfir::EntityWithAttributes{exactlyOnce.getResult()};
+    // Create hlfir.declare for the result to satisfy
+    // hlfir::EntityWithAttributes requirements.
+    auto [exv, cleanup] = hlfir::translateToExtendedValue(
+        loc, builder, hlfir::Entity{exactlyOnce});
+    assert(!cleanup && "result is a variable");
+    return hlfir::genDeclare(loc, builder, exv, ".func.pointer.result",
+                             fir::FortranVariableFlagsAttr{});
+  }
   CallContext callContext(procRef, resultType, loc, converter, symMap, stmtCtx);
   return genProcedureRef(callContext);
 }
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp
index 63b52c0cd0bc4..e4a9999d48c10 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp
@@ -56,7 +56,8 @@ namespace {
 /// expression and allows splitting the generation of the none elemental part
 /// from the elemental part.
 struct MaskedArrayExpr {
-  MaskedArrayExpr(mlir::Location loc, mlir::Region &region);
+  MaskedArrayExpr(mlir::Location loc, mlir::Region &region,
+                  bool isOuterMaskExpr);
 
   /// Generate the none elemental part. Must be called outside of the
   /// loops created for the WHERE construct.
@@ -81,14 +82,17 @@ struct MaskedArrayExpr {
 
   mlir::Location loc;
   mlir::Region &region;
-  /// Was generateNoneElementalPart called?
-  bool noneElementalPartWasGenerated = false;
   /// Set of operations that form the elemental parts of the
   /// expression evaluation. These are the hlfir.elemental and
   /// hlfir.elemental_addr that form the elemental tree producing
   /// the expression value. hlfir.elemental that produce values
   /// used inside transformational operations are not part of this set.
   llvm::SmallSet<mlir::Operation *, 4> elementalParts{};
+  /// Was generateNoneElementalPart called?
+  bool noneElementalPartWasGenerated = false;
+  /// Is this expression the mask expression of the outer where statement?
+  /// It is special because its evaluation is not masked by anything yet.
+  bool isOuterMaskExpr = false;
 };
 } // namespace
 
@@ -202,7 +206,7 @@ class OrderedAssignmentRewriter {
   /// This method returns the scalar element (that may have been previously
   /// saved) for the current indices inside the where loop.
   mlir::Value generateMaskedEntity(mlir::Location loc, mlir::Region &region) {
-    MaskedArrayExpr maskedExpr(loc, region);
+    MaskedArrayExpr maskedExpr(loc, region, /*isOuterMaskExpr=*/!whereLoopNest);
     return generateMaskedEntity(maskedExpr);
   }
   mlir::Value generateMaskedEntity(MaskedArrayExpr &maskedExpr);
@@ -524,7 +528,8 @@ void OrderedAssignmentRewriter::pre(hlfir::WhereOp whereOp) {
       return;
     }
     // The mask was not evaluated yet or can be safely re-evaluated.
-    MaskedArrayExpr mask(loc, whereOp.getMaskRegion());
+    MaskedArrayExpr mask(loc, whereOp.getMaskRegion(),
+                         /*isOuterMaskExpr=*/true);
     mask.generateNoneElementalPart(builder, mapper);
     mlir::Value shape = mask.generateShape(builder, mapper);
     whereLoopNest = hlfir::genLoopNest(loc, builder, shape);
@@ -628,6 +633,13 @@ OrderedAssignmentRewriter::getIfSaved(mlir::Region &region) {
   return std::nullopt;
 }
 
+static hlfir::YieldOp getYield(mlir::Region &region) {
+  auto yield = mlir::dyn_cast_or_null<hlfir::YieldOp>(
+      region.back().getOperations().back());
+  assert(yield && "region computing entities must end with a YieldOp");
+  return yield;
+}
+
 OrderedAssignmentRewriter::ValueAndCleanUp
 OrderedAssignmentRewriter::generateYieldedEntity(
     mlir::Region &region, std::optional<mlir::Type> castToType) {
@@ -644,9 +656,7 @@ OrderedAssignmentRewriter::generateYieldedEntity(
   }
 
   assert(region.hasOneBlock() && "region must contain one block");
-  auto oldYield = mlir::dyn_cast_or_null<hlfir::YieldOp>(
-      region.back().getOperations().back());
-  assert(oldYield && "region computing entities must end with a YieldOp");
+  auto oldYield = getYield(region);
   mlir::Block::OpListType &ops = region.back().getOperations();
 
   // Inside Forall, scalars that do not depend on forall indices can be hoisted
@@ -887,8 +897,9 @@ gatherElementalTree(hlfir::ElementalOpInterface elemental,
   }
 }
 
-MaskedArrayExpr::MaskedArrayExpr(mlir::Location loc, mlir::Region &region)
-    : loc{loc}, region{region} {
+MaskedArrayExpr::MaskedArrayExpr(mlir::Location loc, mlir::Region &region,
+                                 bool isOuterMaskExpr)
+    : loc{loc}, region{region}, isOuterMaskExpr{isOuterMaskExpr} {
   mlir::Operation &terminator = region.back().back();
   if (auto elementalAddr =
           mlir::dyn_cast<hlfir::ElementalOpInterface>(terminator)) {
@@ -907,13 +918,36 @@ void MaskedArrayExpr::generateNoneElementalPart(fir::FirOpBuilder &builder,
                                                 mlir::IRMapping &mapper) {
   assert(!noneElementalPartWasGenerated &&
          "none elemental parts already generated");
-  // Clone all operations, except the elemental and the final yield.
-  mlir::Block::OpListType &ops = region.back().getOperations();
-  assert(!ops.empty() && "yield block cannot be empty");
-  auto end = ops.end();
-  for (auto opIt = ops.begin(); std::next(opIt) != end; ++opIt)
-    if (!elementalParts.contains(&*opIt))
-      (void)builder.clone(*opIt, mapper);
+  if (isOuterMaskExpr) {
+    // The outer mask expression is actually not masked, it is dealt as
+    // such so that its elemental part, if any, can be inlined in the WHERE
+    // loops. But all of the operations outside of hlfir.elemental/
+    // hlfir.elemental_addr must be emitted now because their value may be
+    // required to deduce the mask shape and the WHERE loop bounds.
+    for (mlir::Operation &op : region.back().without_terminator())
+      if (!elementalParts.contains(&op))
+        (void)builder.clone(op, mapper);
+  } else {
+    // For actual masked expressions, Fortran requires elemental expressions,
+    // even the scalar ones that are not encoded with hlfir.elemental, to be
+    // evaluated only when the mask is true. Blindly hoisting all scalar SSA
+    // tree could be wrong if the scalar computation has side effects and
+    // would never have been evaluated (e.g. division by zero) if the mask
+    // is fully false. See F'2023 10.2.3.2 point 10.
+    // Clone only the bodies of all hlfir.exactly_once operations, which contain
+    // the evaluation of sub-expression tree whose root was a non elemental
+    // function call at the Fortran level (the call itself may have been inlined
+    // since). These must be evaluated only once as per F'2023 10.2.3.2 point 9.
+    for (mlir::Operation &op : region.back().without_terminator())
+      if (auto exactlyOnce = mlir::dyn_cast<hlfir::ExactlyOnceOp>(op)) {
+        for (mlir::Operation &subOp :
+             exactlyOnce.getBody().back().without_terminator())
+          (void)builder.clone(subOp, mapper);
+        mlir::Value oldYield = getYield(exactlyOnce.getBody()).getEntity();
+        auto newYield = mapper.lookupOrDefault(oldYield);
+        mapper.map(exactlyOnce.getResult(), newYield);
+      }
+  }
   noneElementalPartWasGenerated = true;
 }
 
@@ -942,6 +976,15 @@ MaskedArrayExpr::generateElementalParts(fir::FirOpBuilder &builder,
                                         mlir::IRMapping &mapper) {
   assert(noneElementalPartWasGenerated &&
          "non elemental part must have been generated");
+  if (!isOuterMaskExpr) {
+    // Clone all operations that are not hlfir.exactly_once and that are not
+    // hlfir.elemental/hlfir.elemental_addr.
+    for (mlir::Operation &op : region.back().without_terminator())
+      if (!mlir::isa<hlfir::ExactlyOnceOp>(op) && !elementalParts.contains(&op))
+        (void)builder.clone(op, mapper);
+    // For the outer mask, this was already done outside of the loop.
+  }
+  // Clone and "index" bodies of hlfir.elemental/hlfir.elemental_addr.
   mlir::Operation &terminator = region.back().back();
   hlfir::ElementalOpInterface elemental =
       mlir::dyn_cast<hlfir::ElementalAddrOp>(terminator);
@@ -968,6 +1011,22 @@ MaskedArrayExpr::generateElementalParts(fir::FirOpBuilder &builder,
 
 void MaskedArrayExpr::generateNoneElementalCleanupIfAny(
     fir::FirOpBuilder &builder, mlir::IRMapping &mapper) {
+  if (!isOuterMaskExpr) {
+    // Clone clean-ups of hlfir.exactly_once operations (in reverse order
+    // to properly deal with stack restores).
+    for (mlir::Operation &op :
+         llvm::reverse(region.back().without_terminator()))
+      if (auto exactlyOnce = mlir::dyn_cast<hlfir::ExactlyOnceOp>(op)) {
+        mlir::Region &cleanupRegion =
+            getYield(exactlyOnce.getBody()).getCleanup();
+        if (!cleanupRegion.empty())
+          for (mlir::Operation &cleanupOp :
+               cleanupRegion.front().without_terminator())
+            (void)builder.clone(cleanupOp, mapper);
+      }
+  }
+  // Clone the clean-ups from the region itself, except for the destroy
+  // of the hlfir.elemental that have been inlined.
   mlir::Operation &terminator = region.back().back();
   mlir::Region *cleanupRegion = nullptr;
   if (auto elementalAddr = mlir::dyn_cast<hlfir::ElementalAddrOp>(terminator)) {
diff --git a/flang/test/HLFIR/order_assignments/impure-where.fir b/flang/test/HLFIR/order_assignments/impure-where.fir
index 9399ea83d1822..011a486b2baf7 100644
--- a/flang/test/HLFIR/order_assignments/impure-where.fir
+++ b/flang/test/HLFIR/order_assignments/impure-where.fir
@@ -13,10 +13,13 @@ func.func @test_elsewhere_impure_mask(%x: !fir.ref<!fir.array<10xi32>>, %y: !fir
     hlfir.yield %mask : !fir.ref<!fir.array<10x!fir.logical<4>>>
   } do {
     hlfir.elsewhere mask {
-      %mask2 = fir.call @impure() : () -> !fir.heap<!fir.array<10x!fir.logical<4>>>
-      hlfir.yield %mask2 : !fir.heap<!fir.array<10x!fir.logical<4>>> cleanup {
-        fir.freemem %mask2 : !fir.heap<!fir.array<10x!fir.logical<4>>>
+      %mask2 = hlfir.exactly_once : !fir.heap<!fir.array<10x!fir.logical<4>>> {
+        %imp = fir.call @impure() : () -> !fir.heap<!fir.array<10x!fir.logical<4>>>
+        hlfir.yield %imp : !fir.heap<!fir.array<10x!fir.logical<4>>> cleanup {
+          fir.freemem %imp : !fir.heap<!fir.array<10x!fir.logical<4>>>
+        }
       }
+      hlfir.yield %mask2 : !fir.heap<!fir.array<10x!fir.logical<4>>>
     } do {
       hlfir.region_assign {
         hlfir.yield %y : !fir.ref<!fir.array<10xi32>>
diff --git a/flang/test/HLFIR/order_assignments/inlined-stack-temp.fir b/flang/test/HLFIR/order_assignments/inlined-stack-temp.fir
index 66ff55558ea67..0724d019537c0 100644
--- a/flang/test/HLFIR/order_assignments/inlined-stack-temp.fir
+++ b/flang/test/HLFIR/order_assignments/inlined-stack-temp.fir
@@ -282,7 +282,6 @@ func.func @test_where_rhs_save(%x: !fir.ref<!fir.array<10xi32>>, %mask: !fir.ref
 // CHECK:           %[[VAL_7:.*]] = arith.constant 10 : index
 // CHECK:           %[[VAL_8:.*]] = fir.shape %[[VAL_7]] : (index) -> !fir.shape<1>
 // CHECK:           %[[VAL_9:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_10:.*]] = hlfir.designate %[[VAL_0]] (%[[VAL_5]]:%[[VAL_4]]:%[[VAL_3]])  shape %[[VAL_6]] : (!fir.ref<!fir.array<10xi32>>, index, index, index, !fir.shape<1>) -> !fir.ref<!fir.array<10xi32>>
 // CHECK:           %[[VAL_11:.*]] = arith.constant 0 : index
 // CHECK:           %[[VAL_12:.*]] = arith.subi %[[VAL_7]], %[[VAL_9]] : index
 // CHECK:           %[[VAL_13:.*]] = arith.addi %[[VAL_12]], %[[VAL_9]] : index
@@ -300,6 +299,7 @@ func.func @test_where_rhs_save(%x: !fir.ref<!fir.array<10xi32>>, %mask: !fir.ref
 // CHECK:             %[[VAL_24:.*]] = fir.load %[[VAL_23]] : !fir.ref<!fir.logical<4>>
 // CHECK:             %[[VAL_25:.*]] = fir.convert %[[VAL_24]] : (!fir.logical<4>) -> i1
 // CHECK:             fir.if %[[VAL_25]] {
+// CHECK:           %[[VAL_10:.*]] = hlfir.designate %[[VAL_0]] (%[[VAL_5]]:%[[VAL_4]]:%[[VAL_3]])  shape %[[VAL_6]] : (!fir.ref<!fir.array<10xi32>>, index, index, index, !fir.shape<1>) -> !fir.ref<!fir.array<10xi32>>
 // CHECK:               %[[VAL_26:.*]] = hlfir.designate %[[VAL_10]] (%[[VAL_22]])  : (!fir.ref<!fir.array<10xi32>>, index) -> !fir.ref<i32>
 // CHECK:               %[[VAL_27:.*]] = fir.load %[[VAL_26]] : !fir.ref<i32>
 // CHECK:               %[[VAL_28:.*]] = fir.load %[[VAL_2]] : !fir.ref<index>
diff --git a/flang/test/HLFIR/order_assignments/user-defined-assignment-finalization.fir b/flang/test/HLFIR/order_assignments/user-defined-assignment-finalization.fir
index bbb589169a80a..ae5329a2d2433 100644
--- a/flang/test/HLFIR/order_assignments/user-defined-assignment-finalization.fir
+++ b/flang/test/HLFIR/order_assignments/user-defined-assignment-finalization.fir
@@ -59,7 +59,7 @@ func.func @_QPtest1() {
       %7 = fir.call @_FortranADestroy(%6) fastmath<contract> : (!fir.box<none>) -> none
     }
   } to {
-    hlfir.yield %2#0 : !fir.ref<!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>> 
+    hlfir.yield %2#0 : !fir.ref<!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>
   } user_defined_assign  (%arg0: !fir.ref<!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>) to (%arg1: !fir.ref<!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>) {
     %3 = fir.embox %arg1 : (!fir.ref<!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>) -> !fir.box<!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>
     %4 = fir.convert %3 : (!fir.box<!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>) -> !fir.class<!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>
@@ -119,7 +119,7 @@ func.func @_QPtest2() {
       fir.call @llvm.stackrestore.p0(%4) fastmath<contract> : (!fir.ref<i8>) -> ()
     }
   } to {
-    hlfir.yield %3#0 : !fir.ref<!fir.array<2x!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>> 
+    hlfir.yield %3#0 : !fir.ref<!fir.array<2x!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>>
   } user_defined_assign  (%arg0: !fir.ref<!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>) to (%arg1: !fir.ref<!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>) {
     %4 = fir.embox %arg1 : (!fir.ref<!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>) -> !fir.box<!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>
     %5 = fir.convert %4 : (!fir.box<!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>) -> !fir.class<!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>
@@ -193,18 +193,22 @@ func.func @_QPtest3(%arg0: !fir.ref<!fir.array<2xi32>> {fir.bindc_name = "y"}) {
     }
   } do {
     hlfir.region_assign {
-      %5 = fir.call @llvm.stacksave.p0() fastmath<contract> : () -> !fir.ref<i8>
-      %6 = fir.call @_QPnew_obja() fastmath<contract> : () -> !fir.array<2x!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>
-      fir.save_result %6 to %0(%2) : !fir.array<2x!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>, !fir.ref<!fir.array<2x!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>>, !fir.shape<1>
-      %7:2 = hlfir.declare %0(%2) {uniq_name = ".tmp.func_result"} : (!fir.ref<!fir.array<2x!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>>, !fir.shape<1>) -> (!fir.ref<!fir.array<2x!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>>, !fir.ref<!fir.array<2x!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>>)
-      hlfir.yield %7#0 : !fir.ref<!fir.array<2x!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>> cleanup {
-        %8 = fir.embox %0(%2) : (!fir.ref<!fir.array<2x!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>>, !fir.shape<1>) -> !fir.box<!fir.array<2x!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>>
-        %9 = fir.convert %8 : (!fir.box<!fir.array<2x!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>>) -> !fir.box<none>
-        %10 = fir.call @_FortranADestroy(%9) fastmath<contract> : (!fir.box<none>) -> none
-        fir.call @llvm.stackrestore.p0(%5) fastmath<contract> : (!fir.ref<i8>) -> ()
+      %5 = hlfir.exactly_once : !fir.ref<!fir.array<2x!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>> {
+        %7 = fir.call @llvm.stacksave.p0() fastmath<contract> : () -> !fir.ref<i8>
+        %8 = fir.call @_QPnew_obja() fastmath<contract> : () -> !fir.array<2x!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>
+        fir.save_result %8 to %0(%2) : !fir.array<2x!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>, !fir.ref<!fir.array<2x!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>>, !fir.shape<1>
+        %9:2 = hlfir.declare %0(%2) {uniq_name = ".tmp.func_result"} : (!fir.ref<!fir.array<2x!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>>, !fir.shape<1>) -> (!fir.ref<!fir.array<2x!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>>, !fir.ref<!fir.array<2x!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>>)
+        hlfir.yield %9#0 : !fir.ref<!fir.array<2x!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>> cleanup {
+          %10 = fir.embox %0(%2) : (!fir.ref<!fir.array<2x!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>>, !fir.shape<1>) -> !fir.box<!fir.array<2x!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>>
+          %11 = fir.convert %10 : (!fir.box<!fir.array<2x!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>>) -> !fir.box<none>
+          %12 = fir.call @_FortranADestroy(%11) fastmath<contract> : (!fir.box<none>) -> none
+          fir.call @llvm.stackrestore.p0(%7) fastmath<contract> : (!fir.ref<i8>) -> ()
+        }
       }
+      %6:2 = hlfir.declare %5(%2) {uniq_name = ".func.pointer.result"} : (!fir.ref<!fir.array<2x!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>>, !fir.shape<1>) -> (!fir.ref<!fir.array<2x!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>>, !fir.ref<!fir.array<2x!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>>)
+      hlfir.yield %6#0 : !fir.ref<!fir.array<2x!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>>
     } to {
-      hlfir.yield %3#0 : !fir.ref<!fir.array<2x!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>> 
+      hlfir.yield %3#0 : !fir.ref<!fir.array<2x!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>>
     } user_defined_assign  (%arg1: !fir.ref<!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>) to (%arg2: !fir.ref<!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>) {
       %5 = fir.embox %arg2 : (!fir.ref<!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>) -> !fir.box<!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>
       %6 = fir.convert %5 : (!fir.box<!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>) -> !fir.class<!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>
@@ -246,7 +250,8 @@ func.func @_QPtest3(%arg0: !fir.ref<!fir.array<2xi32>> {fir.bindc_name = "y"}) {
 // CHECK:             %[[VAL_30:.*]] = fir.load %[[VAL_29]] : !fir.ref<!fir.logical<4>>
 // CHECK:             %[[VAL_31:.*]] = fir.convert %[[VAL_30]] : (!fir.logical<4>) -> i1
 // CHECK:             fir.if %[[VAL_31]] {
-// CHECK:               %[[VAL_32:.*]] = hlfir.designate %[[VAL_20]]#0 (%[[VAL_28]])  : (!fir.ref<!fir.array<2x!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>>, index) -> !fir.ref<!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>
+// CHECK:               %[[VAL_20B:.*]]:2 = hlfir.declare %[[VAL_20]]#0(%[[VAL_7]]) {uniq_name = ".func.pointer.result"}
+// CHECK:               %[[VAL_32:.*]] = hlfir.designate %[[VAL_20B]]#0 (%[[VAL_28]])  : (!fir.ref<!fir.array<2x!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>>, index) -> !fir.ref<!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>
 // CHECK:               %[[VAL_33:.*]] = fir.embox %[[VAL_32]] : (!fir.ref<!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>) -> !fir.box<!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>
 // CHECK:               %[[VAL_34:.*]] = fir.convert %[[VAL_33]] : (!fir.box<!fir.type<_QMtypesTud_assign{x:!fir.box<!fir.ptr<i32>>}>>) -> !fir.box<none>
 // CHECK:               %[[VAL_35:.*]] = fir.call @_FortranAPushValue(%[[VAL_27]], %[[VAL_34]]) : (!fir.llvm_ptr<i8>, !fir.box<none>) -> none
diff --git a/flang/test/HLFIR/order_assignments/where-codegen-no-conflict.fir b/flang/test/HLFIR/order_assignments/where-codegen-no-conflict.fir
index ac93e6828096a..a1a357b45a64e 100644
--- a/flang/test/HLFIR/order_assignments/where-codegen-no-conflict.fir
+++ b/flang/test/HLFIR/order_assignments/where-codegen-no-conflict.fir
@@ -290,8 +290,6 @@ func.func @inside_forall(%arg0: !fir.ref<!fir.array<10x20xf32>>, %arg1: !fir.ref
 // CHECK:           fir.do_loop %[[VAL_15:.*]] = %[[VAL_12]] to %[[VAL_13]] step %[[VAL_14]] {
 // CHECK:             %[[VAL_16:.*]] = fir.convert %[[VAL_15]] : (index) -> i32
 // CHECK:             %[[VAL_17:.*]] = arith.constant 1 : index
-// CHECK:             %[[VAL_18:.*]] = fir.convert %[[VAL_16]] : (i32) -> i64
-// CHECK:             %[[VAL_19:.*]] = hlfir.designate %[[VAL_9]]#0 (%[[VAL_18]], %[[VAL_2]]:%[[VAL_7]]:%[[VAL_2]])  shape %[[VAL_10]] : (!fir.ref<!fir.array<10x20xf32>>, i64, index, index, index, !fir.shape<1>) -> !fir.box<!fir.array<20xf32>>
 // CHECK:             fir.do_loop %[[VAL_20:.*]] = %[[VAL_17]] to %[[VAL_7]] step %[[VAL_17]] {
 // CHECK:               %[[VAL_21:.*]] = hlfir.designate %[[VAL_11]]#0 (%[[VAL_20]])  : (!fir.ref<!fir.array<20xf32>>, index) -> !fir.ref<f32>
 // CHECK:               %[[VAL_22:.*]] = fir.load %[[VAL_21]] : !fir.ref<f32>
@@ -300,6 +298,8 @@ func.func @inside_forall(%arg0: !fir.ref<!fir.array<10x20xf32>>, %arg1: !fir.ref
 // CHECK:               %[[VAL_25:.*]] = fir.convert %[[VAL_24]] : (!fir.logical<4>) -> i1
 // CHECK:               fir.if %[[VAL_25]] {
 // CHECK:                 %[[VAL_26:.*]] = hlfir.designate %[[VAL_11]]#0 (%[[VAL_20]])  : (!fir.ref<!fir.array<20xf32>>, index) -> !fir.ref<f32>
+// CHECK:                 %[[VAL_18:.*]] = fir.convert %[[VAL_16]] : (i32) -> i64
+// CHECK:                 %[[VAL_19:.*]] = hlfir.designate %[[VAL_9]]#0 (%[[VAL_18]], %[[VAL_2]]:%[[VAL_7]]:%[[VAL_2]])  shape %[[VAL_10]] : (!fir.ref<!fir.array<10x20xf32>>, i64, index, index, index, !fir.shape<1>) -> !fir.box<!fir.array<20xf32>>
 // CHECK:                 %[[VAL_27:.*]] = hlfir.designate %[[VAL_19]] (%[[VAL_20]])  : (!fir.box<!fir.array<20xf32>>, index) -> !fir.ref<f32>
 // CHECK:                 hlfir.assign %[[VAL_26]] to %[[VAL_27]] : !fir.ref<f32>, !fir.ref<f32>
 // CHECK:               }
diff --git a/flang/test/HLFIR/order_assignments/where-hoisting.f90 b/flang/test/HLFIR/order_assignments/where-hoisting.f90
new file mode 100644
index 0000000000000..6ed2ecb3624b0
--- /dev/null
+++ b/flang/test/HLFIR/order_assignments/where-hoisting.f90
@@ -0,0 +1,50 @@
+! Test that scalar expressions are not hoisted from WHERE loops
+! when they do not appear inside nonelemental function references.
+! RUN: bbc -hlfir -o - -pass-pipeline="builtin.module(lower-hlfir-ordered-assignments)" %s | FileCheck %s
+
+subroutine do_not_hoist_div(n, mask, a)
+  integer :: a(10), n
+  logical :: mask(10)
+  where(mask) a=1/n
+end subroutine
+! CHECK-LABEL:   func.func @_QPdo_not_hoist_div(
+! CHECK-NOT:       arith.divsi
+! CHECK:           fir.do_loop {{.*}} {
+! CHECK:             fir.if {{.*}} {
+! CHECK:               arith.divsi
+! CHECK:             }
+! CHECK:           }
+
+subroutine do_not_hoist_optional(n, mask, a)
+  integer :: a(10)
+  integer, optional :: n
+  logical :: mask(10)
+  where(mask) a=n
+end subroutine
+! CHECK-LABEL:   func.func @_QPdo_not_hoist_optional(
+! CHECK:           %[[VAL_9:.*]]:2 = hlfir.declare {{.*}}"_QFdo_not_hoist_optionalEn"
+! CHECK-NOT:       fir.load %[[VAL_9]]
+! CHECK:           fir.do_loop {{.*}} {
+! CHECK:             fir.if {{.*}} {
+! CHECK:               %[[VAL_15:.*]] = fir.load %[[VAL_9]]#0 : !fir.ref<i32>
+! CHECK:             }
+! CHECK:           }
+
+subroutine hoist_function(n, mask, a)
+  integer :: a(10, 10)
+  integer, optional :: n
+  logical :: mask(10, 10)
+  forall (i=1:10)
+  where(mask(i, :)) a(i,:)=ihoist_me(i)
+  end forall
+end subroutine
+! CHECK-LABEL:   func.func @_QPhoist_function(
+! CHECK:           fir.do_loop {{.*}} {
+! CHECK:             fir.call @_QPihoist_me
+! CHECK:             fir.do_loop {{.*}} {
+! CHECK:               fir.if %{{.*}} {
+! CHECK-NOT:             fir.call @_QPihoist_me
+! CHECK:               }
+! CHECK:             }
+! CHECK:           }
+! CHECK-NOT:       fir.call @_QPihoist_me
diff --git a/flang/test/Lower/HLFIR/where-nonelemental.f90 b/flang/test/Lower/HLFIR/where-nonelemental.f90
new file mode 100644
index 0000000000000..f0a6857f0f4b9
--- /dev/null
+++ b/flang/test/Lower/HLFIR/where-nonelemental.f90
@@ -0,0 +1,198 @@
+! Test lowering of non elemental calls and their inputs inside WHERE
+! constructs. These must be lowered inside hlfir.exactly_once so that
+! they are properly hoisted once the loops are materialized and
+! expression evaluations are scheduled.
+! RUN: bbc -emit-hlfir -o - %s | FileCheck %s
+
+subroutine test_where(a, b, c)
+ real, dimension(:) :: a, b, c
+ interface
+  function logical_func1()
+    logical :: logical_func1(100)
+  end function
+  function logical_func2()
+    logical :: logical_func2(100)
+  end function
+  real elemental function elem_func(x)
+    real, intent(in) :: x
+  end function
+ end interface
+ where (logical_func1())
+  a = b + real_func(a+b+real_func2()) + elem_func(a)
+ elsewhere(logical_func2())
+  a(1:ifoo()) = c
+ end where
+end subroutine
+! CHECK-LABEL:   func.func @_QPtest_where(
+! CHECK:           hlfir.where {
+! CHECK-NOT: hlfir.exactly_once
+! CHECK:             %[[VAL_17:.*]] = fir.call @llvm.stacksave.p0() fastmath<contract> : () -> !fir.ref<i8>
+! CHECK:             %[[VAL_19:.*]] = fir.call @_QPlogical_func1() fastmath<contract> : () -> !fir.array<100x!fir.logical<4>>
+! CHECK:             hlfir.yield %{{.*}} : !hlfir.expr<100x!fir.logical<4>> cleanup {
+! CHECK:               fir.call @llvm.stackrestore.p0(%[[VAL_17]]) fastmath<contract> : (!fir.ref<i8>) -> ()
+! CHECK:             }
+! CHECK:           } do {
+! CHECK:             hlfir.region_assign {
+! CHECK:               %[[VAL_24:.*]] = hlfir.exactly_once : f32 {
+! CHECK:                 %[[VAL_28:.*]] = hlfir.elemental %{{.*}} unordered : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
+! CHECK:                 }
+! CHECK-NOT: hlfir.exactly_once
+! CHECK:                 %[[VAL_35:.*]] = fir.call @_QPreal_func2() fastmath<contract> : () -> f32
+! CHECK:                 %[[VAL_36:.*]] = hlfir.elemental %{{.*}} unordered : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
+! CHECK:                 ^bb0(%[[VAL_37:.*]]: index):
+! CHECK:                   %[[VAL_38:.*]] = hlfir.apply %[[VAL_28]], %[[VAL_37]] : (!hlfir.expr<?xf32>, index) -> f32
+! CHECK:                   %[[VAL_39:.*]] = arith.addf %[[VAL_38]], %[[VAL_35]] fastmath<contract> : f32
+! CHECK:                   hlfir.yield_element %[[VAL_39]] : f32
+! CHECK:                 }
+! CHECK:                 %[[VAL_41:.*]] = fir.call @_QPreal_func
+! CHECK:                 hlfir.yield %[[VAL_41]] : f32 cleanup {
+! CHECK:                   hlfir.destroy %[[VAL_36]] : !hlfir.expr<?xf32>
+! CHECK:                   hlfir.destroy %[[VAL_28]] : !hlfir.expr<?xf32>
+! CHECK:                 }
+! CHECK:               }
+! CHECK:               %[[VAL_45:.*]] = hlfir.elemental %{{.*}} unordered : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
+! CHECK:                 arith.addf
+! CHECK-NOT: hlfir.exactly_once
+! CHECK:               }
+! CHECK:               %[[VAL_53:.*]] = hlfir.elemental %{{.*}} unordered : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
+! CHECK:                 fir.call @_QPelem_func
+! CHECK:               }
+! CHECK:               %[[VAL_57:.*]] = hlfir.elemental %{{.*}} unordered : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
+! CHECK:                 arith.addf
+! CHECK:               }
+! CHECK:               hlfir.yield %[[VAL_57]] : !hlfir.expr<?xf32> cleanup {
+! CHECK:                 hlfir.destroy %[[VAL_57]] : !hlfir.expr<?xf32>
+! CHECK:                 hlfir.destroy %[[VAL_53]] : !hlfir.expr<?xf32>
+! CHECK:                 hlfir.destroy %[[VAL_45]] : !hlfir.expr<?xf32>
+! CHECK:               }
+! CHECK:             } to {
+! CHECK:               hlfir.yield %{{.*}} : !fir.box<!fir.array<?xf32>>
+! CHECK:             }
+! CHECK:             hlfir.elsewhere mask {
+! CHECK:               %[[VAL_62:.*]] = hlfir.exactly_once : !hlfir.expr<100x!fir.logical<4>> {
+! CHECK:                 %[[VAL_72:.*]] = fir.call @llvm.stacksave.p0() fastmath<contract> : () -> !fir.ref<i8>
+! CHECK:                 fir.call @_QPlogical_func2() fastmath<contract> : () -> !fir.array<100x!fir.logical<4>>
+! CHECK:                 hlfir.yield %{{.*}} : !hlfir.expr<100x!fir.logical<4>> cleanup {
+! CHECK:                   fir.call @llvm.stackrestore.p0(%[[VAL_72]]) fastmath<contract> : (!fir.ref<i8>) -> ()
+! CHECK:                 }
+! CHECK:               }
+! CHECK:               hlfir.yield %[[VAL_62]] : !hlfir.expr<100x!fir.logical<4>>
+! CHECK:             } do {
+! CHECK:               hlfir.region_assign {
+! CHECK:                 hlfir.yield %{{.*}} : !fir.box<!fir.array<?xf32>>
+! CHECK:               } to {
+! CHECK:                 %[[VAL_80:.*]] = hlfir.exactly_once : i32 {
+! CHECK:                   %[[VAL_81:.*]] = fir.call @_QPifoo() fastmath<contract> : () -> i32
+! CHECK:                   hlfir.yield %[[VAL_81]] : i32
+! CHECK:                 }
+! CHECK:                 hlfir.yield %{{.*}} : !fir.box<!fir.array<?xf32>>
+! CHECK:               }
+! CHECK:             }
+! CHECK:           }
+! CHECK:           return
+! CHECK:         }
+
+subroutine test_where_in_forall(a, b, c)
+ real, dimension(:, :) :: a, b, c
+ interface
+  pure function pure_logical_func1()
+    logical :: pure_logical_func1(100)
+  end function
+  pure function pure_logical_func2()
+    logical :: pure_logical_func2(100)
+  end function
+  real pure elemental function pure_elem_func(x)
+    real, intent(in) :: x
+  end function
+  integer pure function pure_ifoo()
+  end function
+ end interface
+ forall(i=1:10)
+   where (pure_logical_func1())
+    a(2*i, :) = b(i, :) + pure_real_func(a(i,:)+b(i,:)+pure_real_func2()) + pure_elem_func(a(i,:))
+   elsewhere(pure_logical_func2())
+    a(2*i, 1:pure_ifoo()) = c(i, :)
+   end where
+ end forall
+end subroutine
+! CHECK-LABEL:   func.func @_QPtest_where_in_forall(
+! CHECK:           hlfir.forall lb {
+! CHECK:             hlfir.yield %{{.*}} : i32
+! CHECK:           } ub {
+! CHECK:             hlfir.yield %{{.*}} : i32
+! CHECK:           }  (%[[VAL_10:.*]]: i32) {
+! CHECK:             %[[VAL_11:.*]] = hlfir.forall_index "i" %[[VAL_10]] : (i32) -> !fir.ref<i32>
+! CHECK:             hlfir.where {
+! CHECK:               %[[VAL_21:.*]] = fir.call @llvm.stacksave.p0() fastmath<contract> : () -> !fir.ref<i8>
+! CHECK-NOT: hlfir.exactly_once
+! CHECK:               %[[VAL_23:.*]] = fir.call @_QPpure_logical_func1() fastmath<contract> : () -> !fir.array<100x!fir.logical<4>>
+! CHECK:               hlfir.yield %{{.*}} : !hlfir.expr<100x!fir.logical<4>> cleanup {
+! CHECK:                 fir.call @llvm.stackrestore.p0(%[[VAL_21]]) fastmath<contract> : (!fir.ref<i8>) -> ()
+! CHECK:               }
+! CHECK:             } do {
+! CHECK:               hlfir.region_assign {
+! CHECK:                 %[[VAL_41:.*]] = hlfir.designate
+! CHECK:                 %[[VAL_42:.*]] = hlfir.exactly_once : f32 {
+! CHECK:                                    hlfir.designate
+! CHECK:                                    hlfir.designate
+! CHECK:                   %[[VAL_71:.*]] = hlfir.elemental %{{.*}} unordered : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
+! CHECK:                     arith.addf
+! CHECK:                   }
+! CHECK-NOT: hlfir.exactly_once
+! CHECK:                   %[[VAL_78:.*]] = fir.call @_QPpure_real_func2() fastmath<contract> : () -> f32
+! CHECK:                   %[[VAL_79:.*]] = hlfir.elemental %{{.*}} unordered : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
+! CHECK:                     arith.addf
+! CHECK:                   }
+! CHECK:                   %[[VAL_84:.*]] = fir.call @_QPpure_real_func(
+! CHECK:                   hlfir.yield %[[VAL_84]] : f32 cleanup {
+! CHECK:                     hlfir.destroy %[[VAL_79]] : !hlfir.expr<?xf32>
+! CHECK:                     hlfir.destroy %[[VAL_71]] : !hlfir.expr<?xf32>
+! CHECK:                   }
+! CHECK:                 }
+! CHECK:                 %[[VAL_85:.*]] = hlfir.elemental %{{.*}} unordered : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
+! CHECK:                     arith.addf
+! CHECK:                 }
+! CHECK-NOT: hlfir.exactly_once
+! CHECK:                 %[[VAL_104:.*]] = hlfir.elemental %{{.*}} unordered : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
+! CHECK:                 ^bb0(%[[VAL_105:.*]]: index):
+! CHECK-NOT: hlfir.exactly_once
+! CHECK:                   fir.call @_QPpure_elem_func
+! CHECK:                 }
+! CHECK:                 %[[VAL_108:.*]] = hlfir.elemental %{{.*}} unordered : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
+! CHECK:                   arith.addf
+! CHECK:                 }
+! CHECK:                 hlfir.yield %[[VAL_108]] : !hlfir.expr<?xf32> cleanup {
+! CHECK:                   hlfir.destroy %[[VAL_108]] : !hlfir.expr<?xf32>
+! CHECK:                   hlfir.destroy %[[VAL_104]] : !hlfir.expr<?xf32>
+! CHECK:                   hlfir.destroy %[[VAL_85]] : !hlfir.expr<?xf32>
+! CHECK:                 }
+! CHECK:               } to {
+! CHECK:                 hlfir.designate
+! CHECK:                 hlfir.yield %{{.*}} : !fir.box<!fir.array<?xf32>>
+! CHECK:               }
+! CHECK:               hlfir.elsewhere mask {
+! CHECK:                 %[[VAL_129:.*]] = hlfir.exactly_once : !hlfir.expr<100x!fir.logical<4>> {
+! CHECK:                   %[[VAL_139:.*]] = fir.call @llvm.stacksave.p0() fastmath<contract> : () -> !fir.ref<i8>
+! CHECK:                   %[[VAL_141:.*]] = fir.call @_QPpure_logical_func2() fastmath<contract> : () -> !fir.array<100x!fir.logical<4>>
+! CHECK:                   hlfir.yield %{{.*}} : !hlfir.expr<100x!fir.logical<4>> cleanup {
+! CHECK:                     fir.call @llvm.stackrestore.p0(%[[VAL_139]]) fastmath<contract> : (!fir.ref<i8>) -> ()
+! CHECK:                   }
+! CHECK:                 }
+! CHECK:                 hlfir.yield %[[VAL_129]] : !hlfir.expr<100x!fir.logical<4>>
+! CHECK:               } do {
+! CHECK:                 hlfir.region_assign {
+! CHECK:                   hlfir.designate
+! CHECK:                   hlfir.yield %{{.*}} : !fir.box<!fir.array<?xf32>>
+! CHECK:                 } to {
+! CHECK:                   %[[VAL_165:.*]] = hlfir.exactly_once : i32 {
+! CHECK:                     %[[VAL_166:.*]] = fir.call @_QPpure_ifoo() fastmath<contract> : () -> i32
+! CHECK:                     hlfir.yield %[[VAL_166]] : i32
+! CHECK:                   }
+! CHECK:                   hlfir.designate
+! CHECK:                   hlfir.yield %{{.*}} : !fir.box<!fir.array<?xf32>>
+! CHECK:                 }
+! CHECK:               }
+! CHECK:             }
+! CHECK:           }
+! CHECK:           return
+! CHECK:         }

>From fb8dd24256254159ea5cb579804b6e9c2a9ab8ee Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Tue, 7 May 2024 13:43:45 -0700
Subject: [PATCH 2/3] apply clang-format fix

---
 flang/lib/Lower/ConvertCall.cpp | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/flang/lib/Lower/ConvertCall.cpp b/flang/lib/Lower/ConvertCall.cpp
index f989bc7e017f3..c6bfe35921699 100644
--- a/flang/lib/Lower/ConvertCall.cpp
+++ b/flang/lib/Lower/ConvertCall.cpp
@@ -2682,9 +2682,9 @@ bool Fortran::lower::isIntrinsicModuleProcRef(
   return module && module->attrs().test(Fortran::semantics::Attr::INTRINSIC);
 }
 
-static bool isInWhereMaskedExpression(fir::FirOpBuilder& builder) {
+static bool isInWhereMaskedExpression(fir::FirOpBuilder &builder) {
   // The MASK of the outer WHERE is not masked itself.
-  mlir::Operation* op = builder.getRegion().getParentOp();
+  mlir::Operation *op = builder.getRegion().getParentOp();
   return op && op->getParentOfType<hlfir::WhereOp>();
 }
 
@@ -2693,7 +2693,8 @@ std::optional<hlfir::EntityWithAttributes> Fortran::lower::convertCallToHLFIR(
     const evaluate::ProcedureRef &procRef, std::optional<mlir::Type> resultType,
     Fortran::lower::SymMap &symMap, Fortran::lower::StatementContext &stmtCtx) {
   auto &builder = converter.getFirOpBuilder();
-  if (resultType && !procRef.IsElemental() && isInWhereMaskedExpression(builder) &&
+  if (resultType && !procRef.IsElemental() &&
+      isInWhereMaskedExpression(builder) &&
       !builder.getRegion().getParentOfType<hlfir::ExactlyOnceOp>()) {
     // Non elemental calls inside a where-assignment-stmt must be executed
     // exactly once without mask control. Lower them in a special region so that

>From e681f43d20e057a0b8704412c5ba3ffb3fa2b5b5 Mon Sep 17 00:00:00 2001
From: jeanPerier <jean.perier.polytechnique at gmail.com>
Date: Mon, 13 May 2024 11:03:38 +0200
Subject: [PATCH 3/3] Apply suggestions from code review
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Valentin Clement (バレンタイン クレメン) <clementval at gmail.com>
---
 flang/include/flang/Optimizer/HLFIR/HLFIROps.td                 | 2 +-
 .../Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
index fdb21656a27fc..35476118fe9ae 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
@@ -1613,7 +1613,7 @@ def hlfir_ExactlyOnceOp : hlfir_Op<"exactly_once", [RecursiveMemoryEffects]> {
   let assemblyFormat = [{
     attr-dict `:` type($result)
     $body
-    }];
+  }];
 }
 
 def hlfir_WhereOp : hlfir_AssignmentMaskOp<"where"> {
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp
index e4a9999d48c10..d46baa92c3eaa 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp
@@ -929,7 +929,7 @@ void MaskedArrayExpr::generateNoneElementalPart(fir::FirOpBuilder &builder,
         (void)builder.clone(op, mapper);
   } else {
     // For actual masked expressions, Fortran requires elemental expressions,
-    // even the scalar ones that are no encoded with hlfir.elemental, to be
+    // even the scalar ones that are not encoded with hlfir.elemental, to be
     // evaluated only when the mask is true. Blindly hoisting all scalar SSA
     // tree could be wrong if the scalar computation has side effects and
     // would never have been evaluated (e.g. division by zero) if the mask


