[llvm-branch-commits] [flang] [flang] optimize array function calls using hlfir.eval_in_mem (PR #118070)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Mon Dec 2 07:22:00 PST 2024


https://github.com/jeanPerier updated https://github.com/llvm/llvm-project/pull/118070

>From 4debf91864773ba805a9439cde9b74fc44315acc Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Thu, 28 Nov 2024 08:28:21 -0800
Subject: [PATCH 1/3] [flang] optimize array function calls using
 hlfir.eval_in_mem

---
 flang/include/flang/Lower/ConvertCall.h       |   6 +-
 .../flang/Optimizer/HLFIR/HLFIRDialect.h      |   4 +
 flang/lib/Lower/ConvertCall.cpp               | 102 +++++++++-----
 flang/lib/Lower/ConvertExpr.cpp               |  13 +-
 flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp |  15 ++
 flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp     |  11 +-
 .../order_assignments/where-scheduling.f90    |   2 +-
 .../test/Lower/HLFIR/calls-array-results.f90  | 131 ++++++++++++++++++
 flang/test/Lower/HLFIR/where-nonelemental.f90 |  38 ++---
 .../Lower/explicit-interface-results-2.f90    |   2 -
 .../test/Lower/explicit-interface-results.f90 |   8 +-
 flang/test/Lower/forall/array-constructor.f90 |   2 +-
 12 files changed, 260 insertions(+), 74 deletions(-)
 create mode 100644 flang/test/Lower/HLFIR/calls-array-results.f90

diff --git a/flang/include/flang/Lower/ConvertCall.h b/flang/include/flang/Lower/ConvertCall.h
index bc082907e61760..2c51a887010c87 100644
--- a/flang/include/flang/Lower/ConvertCall.h
+++ b/flang/include/flang/Lower/ConvertCall.h
@@ -24,6 +24,10 @@
 
 namespace Fortran::lower {
 
+struct LoweredResult {
+  std::variant<fir::ExtendedValue, hlfir::EntityWithAttributes> result;
+};
+
 /// Given a call site for which the arguments were already lowered, generate
 /// the call and return the result. This function deals with explicit result
 /// allocation and lowering if needed. It also deals with passing the host
@@ -32,7 +36,7 @@ namespace Fortran::lower {
 /// It is only used for HLFIR.
 /// The returned boolean indicates if finalization has been emitted in
 /// \p stmtCtx for the result.
-std::pair<fir::ExtendedValue, bool> genCallOpAndResult(
+std::pair<LoweredResult, bool> genCallOpAndResult(
     mlir::Location loc, Fortran::lower::AbstractConverter &converter,
     Fortran::lower::SymMap &symMap, Fortran::lower::StatementContext &stmtCtx,
     Fortran::lower::CallerInterface &caller, mlir::FunctionType callSiteType,
diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h b/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
index 3830237f96f3cc..447d5fbab89998 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
@@ -61,6 +61,10 @@ inline mlir::Type getFortranElementOrSequenceType(mlir::Type type) {
   return type;
 }
 
+/// Build the hlfir.expr type for the value held in a variable of type \p
+/// variableType.
+mlir::Type getExprType(mlir::Type variableType);
+
 /// Is this a fir.box or fir.class address type?
 inline bool isBoxAddressType(mlir::Type type) {
   type = fir::dyn_cast_ptrEleTy(type);
diff --git a/flang/lib/Lower/ConvertCall.cpp b/flang/lib/Lower/ConvertCall.cpp
index e84e7afbe82e09..088d8f96caa410 100644
--- a/flang/lib/Lower/ConvertCall.cpp
+++ b/flang/lib/Lower/ConvertCall.cpp
@@ -284,7 +284,8 @@ static void remapActualToDummyDescriptors(
   }
 }
 
-std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
+std::pair<Fortran::lower::LoweredResult, bool>
+Fortran::lower::genCallOpAndResult(
     mlir::Location loc, Fortran::lower::AbstractConverter &converter,
     Fortran::lower::SymMap &symMap, Fortran::lower::StatementContext &stmtCtx,
     Fortran::lower::CallerInterface &caller, mlir::FunctionType callSiteType,
@@ -326,6 +327,11 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
     }
   }
 
+  const bool isExprCall =
+      converter.getLoweringOptions().getLowerToHighLevelFIR() &&
+      callSiteType.getNumResults() == 1 &&
+      llvm::isa<fir::SequenceType>(callSiteType.getResult(0));
+
   mlir::IndexType idxTy = builder.getIndexType();
   auto lowerSpecExpr = [&](const auto &expr) -> mlir::Value {
     mlir::Value convertExpr = builder.createConvert(
@@ -333,6 +339,8 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
     return fir::factory::genMaxWithZero(builder, loc, convertExpr);
   };
   llvm::SmallVector<mlir::Value> resultLengths;
+  mlir::Value arrayResultShape;
+  hlfir::EvaluateInMemoryOp evaluateInMemory;
   auto allocatedResult = [&]() -> std::optional<fir::ExtendedValue> {
     llvm::SmallVector<mlir::Value> extents;
     llvm::SmallVector<mlir::Value> lengths;
@@ -366,6 +374,18 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
       resultLengths = lengths;
     }
 
+    if (!extents.empty())
+      arrayResultShape = builder.genShape(loc, extents);
+
+    if (isExprCall) {
+      mlir::Type exprType = hlfir::getExprType(type);
+      evaluateInMemory = builder.create<hlfir::EvaluateInMemoryOp>(
+          loc, exprType, arrayResultShape, resultLengths);
+      builder.setInsertionPointToStart(&evaluateInMemory.getBody().front());
+      return toExtendedValue(loc, evaluateInMemory.getMemory(), extents,
+                             lengths);
+    }
+
     if ((!extents.empty() || !lengths.empty()) && !isElemental) {
       // Note: in the elemental context, the alloca ownership inside the
       // elemental region is implicit, and later pass in lowering (stack
@@ -384,8 +404,7 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
   if (mustPopSymMap)
     symMap.popScope();
 
-  // Place allocated result or prepare the fir.save_result arguments.
-  mlir::Value arrayResultShape;
+  // Place the allocated result.
   if (allocatedResult) {
     if (std::optional<Fortran::lower::CallInterface<
             Fortran::lower::CallerInterface>::PassedEntity>
@@ -399,16 +418,6 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
       else
         fir::emitFatalError(
             loc, "only expect character scalar result to be passed by ref");
-    } else {
-      assert(caller.mustSaveResult());
-      arrayResultShape = allocatedResult->match(
-          [&](const fir::CharArrayBoxValue &) {
-            return builder.createShape(loc, *allocatedResult);
-          },
-          [&](const fir::ArrayBoxValue &) {
-            return builder.createShape(loc, *allocatedResult);
-          },
-          [&](const auto &) { return mlir::Value{}; });
     }
   }
 
@@ -642,6 +651,19 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
       callResult = call.getResult(0);
   }
 
+  std::optional<Fortran::evaluate::DynamicType> retTy =
+      caller.getCallDescription().proc().GetType();
+  // With HLFIR lowering, isElemental must be set to true
+  // if we are producing an elemental call. In this case,
+  // the elemental results must not be destroyed; instead,
+  // the resulting array result will be finalized/destroyed
+  // as needed by hlfir.destroy.
+  const bool mustFinalizeResult =
+      !isElemental && callSiteType.getNumResults() > 0 &&
+      !fir::isPointerType(callSiteType.getResult(0)) && retTy.has_value() &&
+      (retTy->category() == Fortran::common::TypeCategory::Derived ||
+       retTy->IsPolymorphic() || retTy->IsUnlimitedPolymorphic());
+
   if (caller.mustSaveResult()) {
     assert(allocatedResult.has_value());
     builder.create<fir::SaveResultOp>(loc, callResult,
@@ -649,6 +671,19 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
                                       arrayResultShape, resultLengths);
   }
 
+  if (evaluateInMemory) {
+    builder.setInsertionPointAfter(evaluateInMemory);
+    mlir::Value expr = evaluateInMemory.getResult();
+    fir::FirOpBuilder *bldr = &converter.getFirOpBuilder();
+    if (!isElemental)
+      stmtCtx.attachCleanup([bldr, loc, expr, mustFinalizeResult]() {
+        bldr->create<hlfir::DestroyOp>(loc, expr,
+                                       /*finalize=*/mustFinalizeResult);
+      });
+    return {LoweredResult{hlfir::EntityWithAttributes{expr}},
+            mustFinalizeResult};
+  }
+
   if (allocatedResult) {
     // The result must be optionally destroyed (if it is of a derived type
     // that may need finalization or deallocation of the components).
@@ -679,17 +714,7 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
     // derived-type.
     // For polymorphic and unlimited polymorphic enities call the runtime
     // in any cases.
-    std::optional<Fortran::evaluate::DynamicType> retTy =
-        caller.getCallDescription().proc().GetType();
-    // With HLFIR lowering, isElemental must be set to true
-    // if we are producing an elemental call. In this case,
-    // the elemental results must not be destroyed, instead,
-    // the resulting array result will be finalized/destroyed
-    // as needed by hlfir.destroy.
-    if (!isElemental && !fir::isPointerType(funcType.getResults()[0]) &&
-        retTy &&
-        (retTy->category() == Fortran::common::TypeCategory::Derived ||
-         retTy->IsPolymorphic() || retTy->IsUnlimitedPolymorphic())) {
+    if (mustFinalizeResult) {
       if (retTy->IsPolymorphic() || retTy->IsUnlimitedPolymorphic()) {
         auto *bldr = &converter.getFirOpBuilder();
         stmtCtx.attachCleanup([bldr, loc, allocatedResult]() {
@@ -715,12 +740,13 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
         }
       }
     }
-    return {*allocatedResult, resultIsFinalized};
+    return {LoweredResult{*allocatedResult}, resultIsFinalized};
   }
 
   // subroutine call
   if (!resultType)
-    return {fir::ExtendedValue{mlir::Value{}}, /*resultIsFinalized=*/false};
+    return {LoweredResult{fir::ExtendedValue{mlir::Value{}}},
+            /*resultIsFinalized=*/false};
 
   // For now, Fortran return values are implemented with a single MLIR
   // function return value.
@@ -734,10 +760,13 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
         mlir::dyn_cast<fir::CharacterType>(funcType.getResults()[0]);
     mlir::Value len = builder.createIntegerConstant(
         loc, builder.getCharacterLengthType(), charTy.getLen());
-    return {fir::CharBoxValue{callResult, len}, /*resultIsFinalized=*/false};
+    return {
+        LoweredResult{fir::ExtendedValue{fir::CharBoxValue{callResult, len}}},
+        /*resultIsFinalized=*/false};
   }
 
-  return {callResult, /*resultIsFinalized=*/false};
+  return {LoweredResult{fir::ExtendedValue{callResult}},
+          /*resultIsFinalized=*/false};
 }
 
 static hlfir::EntityWithAttributes genStmtFunctionRef(
@@ -1661,19 +1690,26 @@ genUserCall(Fortran::lower::PreparedActualArguments &loweredActuals,
   // Prepare lowered arguments according to the interface
   // and map the lowered values to the dummy
   // arguments.
-  auto [result, resultIsFinalized] = Fortran::lower::genCallOpAndResult(
+  auto [loweredResult, resultIsFinalized] = Fortran::lower::genCallOpAndResult(
       loc, callContext.converter, callContext.symMap, callContext.stmtCtx,
       caller, callSiteType, callContext.resultType,
       callContext.isElementalProcWithArrayArgs());
-  // For procedure pointer function result, just return the call.
-  if (callContext.resultType &&
-      mlir::isa<fir::BoxProcType>(*callContext.resultType))
-    return hlfir::EntityWithAttributes(fir::getBase(result));
 
   /// Clean-up associations and copy-in.
   for (auto cleanUp : callCleanUps)
     cleanUp.genCleanUp(loc, builder);
 
+  if (auto *entity =
+          std::get_if<hlfir::EntityWithAttributes>(&loweredResult.result))
+    return *entity;
+
+  auto &result = std::get<fir::ExtendedValue>(loweredResult.result);
+
+  // For procedure pointer function result, just return the call.
+  if (callContext.resultType &&
+      mlir::isa<fir::BoxProcType>(*callContext.resultType))
+    return hlfir::EntityWithAttributes(fir::getBase(result));
+
   if (!fir::getBase(result))
     return std::nullopt; // subroutine call.
 
diff --git a/flang/lib/Lower/ConvertExpr.cpp b/flang/lib/Lower/ConvertExpr.cpp
index 46168b81dd3a03..926e45b807e0a0 100644
--- a/flang/lib/Lower/ConvertExpr.cpp
+++ b/flang/lib/Lower/ConvertExpr.cpp
@@ -2852,10 +2852,11 @@ class ScalarExprLowering {
       }
     }
 
-    ExtValue result =
+    auto loweredResult =
         Fortran::lower::genCallOpAndResult(loc, converter, symMap, stmtCtx,
                                            caller, callSiteType, resultType)
             .first;
+    auto &result = std::get<ExtValue>(loweredResult.result);
 
     // Sync pointers and allocatables that may have been modified during the
     // call.
@@ -4881,10 +4882,12 @@ class ArrayExprLowering {
             [&](const auto &) { return fir::getBase(exv); });
         caller.placeInput(argIface, arg);
       }
-      return Fortran::lower::genCallOpAndResult(loc, converter, symMap,
-                                                getElementCtx(), caller,
-                                                callSiteType, retTy)
-          .first;
+      Fortran::lower::LoweredResult res =
+          Fortran::lower::genCallOpAndResult(loc, converter, symMap,
+                                             getElementCtx(), caller,
+                                             callSiteType, retTy)
+              .first;
+      return std::get<ExtValue>(res.result);
     };
   }
 
diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
index 0b61c0edce622b..c66ba75f912fb7 100644
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
@@ -215,3 +215,18 @@ bool hlfir::mayHaveAllocatableComponent(mlir::Type ty) {
   return fir::isPolymorphicType(ty) || fir::isUnlimitedPolymorphicType(ty) ||
          fir::isRecordWithAllocatableMember(hlfir::getFortranElementType(ty));
 }
+
+mlir::Type hlfir::getExprType(mlir::Type variableType) {
+  hlfir::ExprType::Shape typeShape;
+  bool isPolymorphic = fir::isPolymorphicType(variableType);
+  mlir::Type type = getFortranElementOrSequenceType(variableType);
+  assert(!fir::isa_trivial(type) &&
+         "numerical and logical scalar should not be wrapped in hlfir.expr");
+  if (auto seqType = mlir::dyn_cast<fir::SequenceType>(type)) {
+    assert(!seqType.hasUnknownShape() && "assumed-rank cannot be expressions");
+    typeShape.append(seqType.getShape().begin(), seqType.getShape().end());
+    type = seqType.getEleTy();
+  }
+  return hlfir::ExprType::get(variableType.getContext(), typeShape, type,
+                              isPolymorphic);
+}
diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
index 87519882446485..3a172d1b8b5400 100644
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
@@ -1427,16 +1427,7 @@ llvm::LogicalResult hlfir::EndAssociateOp::verify() {
 void hlfir::AsExprOp::build(mlir::OpBuilder &builder,
                             mlir::OperationState &result, mlir::Value var,
                             mlir::Value mustFree) {
-  hlfir::ExprType::Shape typeShape;
-  bool isPolymorphic = fir::isPolymorphicType(var.getType());
-  mlir::Type type = getFortranElementOrSequenceType(var.getType());
-  if (auto seqType = mlir::dyn_cast<fir::SequenceType>(type)) {
-    typeShape.append(seqType.getShape().begin(), seqType.getShape().end());
-    type = seqType.getEleTy();
-  }
-
-  auto resultType = hlfir::ExprType::get(builder.getContext(), typeShape, type,
-                                         isPolymorphic);
+  mlir::Type resultType = hlfir::getExprType(var.getType());
   return build(builder, result, resultType, var, mustFree);
 }
 
diff --git a/flang/test/HLFIR/order_assignments/where-scheduling.f90 b/flang/test/HLFIR/order_assignments/where-scheduling.f90
index 3010476d4a188f..6feaba0d3389a0 100644
--- a/flang/test/HLFIR/order_assignments/where-scheduling.f90
+++ b/flang/test/HLFIR/order_assignments/where-scheduling.f90
@@ -134,7 +134,7 @@ end function f
 !CHECK-NEXT: run 1 save    : where/mask
 !CHECK-NEXT: run 2 evaluate: where/region_assign1
 !CHECK-LABEL: ------------ scheduling where in _QPonly_once ------------
-!CHECK-NEXT: unknown effect: %{{[0-9]+}} = llvm.intr.stacksave : !llvm.ptr
+!CHECK-NEXT: unknown effect: %{{[0-9]+}} = fir.call @_QPcall_me_only_once() fastmath<contract> : () -> !fir.array<10x!fir.logical<4>>
 !CHECK-NEXT: saving eval because write effect prevents re-evaluation
 !CHECK-NEXT: run 1 save  (w): where/mask
 !CHECK-NEXT: run 2 evaluate: where/region_assign1
diff --git a/flang/test/Lower/HLFIR/calls-array-results.f90 b/flang/test/Lower/HLFIR/calls-array-results.f90
new file mode 100644
index 00000000000000..d91844cc2e6f87
--- /dev/null
+++ b/flang/test/Lower/HLFIR/calls-array-results.f90
@@ -0,0 +1,131 @@
+! RUN: bbc -emit-hlfir -o - %s -I nowhere | FileCheck %s
+
+subroutine simple_test()
+  implicit none
+  interface
+    function array_func()
+      real :: array_func(10)
+    end function
+  end interface
+  real :: x(10)
+  x = array_func()
+end subroutine
+
+subroutine arg_test(n)
+  implicit none
+  interface
+    function array_func_2(n)
+      integer(8) :: n
+      real :: array_func_2(n)
+    end function
+  end interface
+  integer(8) :: n
+  real :: x(n)
+  x = array_func_2(n)
+end subroutine
+
+module type_defs
+  interface
+    function array_func()
+      real :: array_func(10)
+    end function
+  end interface
+  type t
+    contains
+    procedure, nopass :: array_func => array_func
+  end type
+end module
+
+subroutine dispatch_test(x, a)
+  use type_defs, only : t
+  implicit none
+  real :: x(10)
+  class(t) :: a
+  x = a%array_func()
+end subroutine
+
+! CHECK-LABEL:   func.func @_QPsimple_test() {
+! CHECK:           %[[VAL_0:.*]] = arith.constant 10 : index
+! CHECK:           %[[VAL_1:.*]] = fir.alloca !fir.array<10xf32> {bindc_name = "x", uniq_name = "_QFsimple_testEx"}
+! CHECK:           %[[VAL_2:.*]] = fir.shape %[[VAL_0]] : (index) -> !fir.shape<1>
+! CHECK:           %[[VAL_3:.*]]:2 = hlfir.declare %[[VAL_1]](%[[VAL_2]]) {uniq_name = "_QFsimple_testEx"} : (!fir.ref<!fir.array<10xf32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xf32>>, !fir.ref<!fir.array<10xf32>>)
+! CHECK:           %[[VAL_4:.*]] = arith.constant 10 : i64
+! CHECK:           %[[VAL_5:.*]] = arith.constant 1 : i64
+! CHECK:           %[[VAL_6:.*]] = arith.subi %[[VAL_4]], %[[VAL_5]] : i64
+! CHECK:           %[[VAL_7:.*]] = arith.constant 1 : i64
+! CHECK:           %[[VAL_8:.*]] = arith.addi %[[VAL_6]], %[[VAL_7]] : i64
+! CHECK:           %[[VAL_9:.*]] = fir.convert %[[VAL_8]] : (i64) -> index
+! CHECK:           %[[VAL_10:.*]] = arith.constant 0 : index
+! CHECK:           %[[VAL_11:.*]] = arith.cmpi sgt, %[[VAL_9]], %[[VAL_10]] : index
+! CHECK:           %[[VAL_12:.*]] = arith.select %[[VAL_11]], %[[VAL_9]], %[[VAL_10]] : index
+! CHECK:           %[[VAL_13:.*]] = fir.shape %[[VAL_12]] : (index) -> !fir.shape<1>
+! CHECK:           %[[VAL_14:.*]] = hlfir.eval_in_mem shape %[[VAL_13]] : (!fir.shape<1>) -> !hlfir.expr<10xf32> {
+! CHECK:           ^bb0(%[[VAL_15:.*]]: !fir.ref<!fir.array<10xf32>>):
+! CHECK:             %[[VAL_16:.*]] = fir.call @_QParray_func() fastmath<contract> : () -> !fir.array<10xf32>
+! CHECK:             fir.save_result %[[VAL_16]] to %[[VAL_15]](%[[VAL_13]]) : !fir.array<10xf32>, !fir.ref<!fir.array<10xf32>>, !fir.shape<1>
+! CHECK:           }
+! CHECK:           hlfir.assign %[[VAL_14]] to %[[VAL_3]]#0 : !hlfir.expr<10xf32>, !fir.ref<!fir.array<10xf32>>
+! CHECK:           hlfir.destroy %[[VAL_14]] : !hlfir.expr<10xf32>
+! CHECK:           return
+! CHECK:         }
+
+! CHECK-LABEL:   func.func @_QParg_test(
+! CHECK-SAME:                           %[[VAL_0:.*]]: !fir.ref<i64> {fir.bindc_name = "n"}) {
+! CHECK:           %[[VAL_1:.*]] = fir.dummy_scope : !fir.dscope
+! CHECK:           %[[VAL_2:.*]]:2 = hlfir.declare %[[VAL_0]] dummy_scope %[[VAL_1]] {uniq_name = "_QFarg_testEn"} : (!fir.ref<i64>, !fir.dscope) -> (!fir.ref<i64>, !fir.ref<i64>)
+! CHECK:           %[[VAL_3:.*]] = fir.load %[[VAL_2]]#0 : !fir.ref<i64>
+! CHECK:           %[[VAL_4:.*]] = fir.convert %[[VAL_3]] : (i64) -> index
+! CHECK:           %[[VAL_5:.*]] = arith.constant 0 : index
+! CHECK:           %[[VAL_6:.*]] = arith.cmpi sgt, %[[VAL_4]], %[[VAL_5]] : index
+! CHECK:           %[[VAL_7:.*]] = arith.select %[[VAL_6]], %[[VAL_4]], %[[VAL_5]] : index
+! CHECK:           %[[VAL_8:.*]] = fir.alloca !fir.array<?xf32>, %[[VAL_7]] {bindc_name = "x", uniq_name = "_QFarg_testEx"}
+! CHECK:           %[[VAL_9:.*]] = fir.shape %[[VAL_7]] : (index) -> !fir.shape<1>
+! CHECK:           %[[VAL_10:.*]]:2 = hlfir.declare %[[VAL_8]](%[[VAL_9]]) {uniq_name = "_QFarg_testEx"} : (!fir.ref<!fir.array<?xf32>>, !fir.shape<1>) -> (!fir.box<!fir.array<?xf32>>, !fir.ref<!fir.array<?xf32>>)
+! CHECK:           %[[VAL_11:.*]]:2 = hlfir.declare %[[VAL_2]]#1 {uniq_name = "_QFarg_testFarray_func_2En"} : (!fir.ref<i64>) -> (!fir.ref<i64>, !fir.ref<i64>)
+! CHECK:           %[[VAL_12:.*]] = fir.load %[[VAL_11]]#0 : !fir.ref<i64>
+! CHECK:           %[[VAL_13:.*]] = arith.constant 1 : i64
+! CHECK:           %[[VAL_14:.*]] = arith.subi %[[VAL_12]], %[[VAL_13]] : i64
+! CHECK:           %[[VAL_15:.*]] = arith.constant 1 : i64
+! CHECK:           %[[VAL_16:.*]] = arith.addi %[[VAL_14]], %[[VAL_15]] : i64
+! CHECK:           %[[VAL_17:.*]] = fir.convert %[[VAL_16]] : (i64) -> index
+! CHECK:           %[[VAL_18:.*]] = arith.constant 0 : index
+! CHECK:           %[[VAL_19:.*]] = arith.cmpi sgt, %[[VAL_17]], %[[VAL_18]] : index
+! CHECK:           %[[VAL_20:.*]] = arith.select %[[VAL_19]], %[[VAL_17]], %[[VAL_18]] : index
+! CHECK:           %[[VAL_21:.*]] = fir.shape %[[VAL_20]] : (index) -> !fir.shape<1>
+! CHECK:           %[[VAL_22:.*]] = hlfir.eval_in_mem shape %[[VAL_21]] : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
+! CHECK:           ^bb0(%[[VAL_23:.*]]: !fir.ref<!fir.array<?xf32>>):
+! CHECK:             %[[VAL_24:.*]] = fir.call @_QParray_func_2(%[[VAL_2]]#1) fastmath<contract> : (!fir.ref<i64>) -> !fir.array<?xf32>
+! CHECK:             fir.save_result %[[VAL_24]] to %[[VAL_23]](%[[VAL_21]]) : !fir.array<?xf32>, !fir.ref<!fir.array<?xf32>>, !fir.shape<1>
+! CHECK:           }
+! CHECK:           hlfir.assign %[[VAL_22]] to %[[VAL_10]]#0 : !hlfir.expr<?xf32>, !fir.box<!fir.array<?xf32>>
+! CHECK:           hlfir.destroy %[[VAL_22]] : !hlfir.expr<?xf32>
+! CHECK:           return
+! CHECK:         }
+
+! CHECK-LABEL:   func.func @_QPdispatch_test(
+! CHECK-SAME:                                %[[VAL_0:.*]]: !fir.ref<!fir.array<10xf32>> {fir.bindc_name = "x"},
+! CHECK-SAME:                                %[[VAL_1:.*]]: !fir.class<!fir.type<_QMtype_defsTt>> {fir.bindc_name = "a"}) {
+! CHECK:           %[[VAL_2:.*]] = fir.dummy_scope : !fir.dscope
+! CHECK:           %[[VAL_3:.*]]:2 = hlfir.declare %[[VAL_1]] dummy_scope %[[VAL_2]] {uniq_name = "_QFdispatch_testEa"} : (!fir.class<!fir.type<_QMtype_defsTt>>, !fir.dscope) -> (!fir.class<!fir.type<_QMtype_defsTt>>, !fir.class<!fir.type<_QMtype_defsTt>>)
+! CHECK:           %[[VAL_4:.*]] = arith.constant 10 : index
+! CHECK:           %[[VAL_5:.*]] = fir.shape %[[VAL_4]] : (index) -> !fir.shape<1>
+! CHECK:           %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_5]]) dummy_scope %[[VAL_2]] {uniq_name = "_QFdispatch_testEx"} : (!fir.ref<!fir.array<10xf32>>, !fir.shape<1>, !fir.dscope) -> (!fir.ref<!fir.array<10xf32>>, !fir.ref<!fir.array<10xf32>>)
+! CHECK:           %[[VAL_7:.*]] = arith.constant 10 : i64
+! CHECK:           %[[VAL_8:.*]] = arith.constant 1 : i64
+! CHECK:           %[[VAL_9:.*]] = arith.subi %[[VAL_7]], %[[VAL_8]] : i64
+! CHECK:           %[[VAL_10:.*]] = arith.constant 1 : i64
+! CHECK:           %[[VAL_11:.*]] = arith.addi %[[VAL_9]], %[[VAL_10]] : i64
+! CHECK:           %[[VAL_12:.*]] = fir.convert %[[VAL_11]] : (i64) -> index
+! CHECK:           %[[VAL_13:.*]] = arith.constant 0 : index
+! CHECK:           %[[VAL_14:.*]] = arith.cmpi sgt, %[[VAL_12]], %[[VAL_13]] : index
+! CHECK:           %[[VAL_15:.*]] = arith.select %[[VAL_14]], %[[VAL_12]], %[[VAL_13]] : index
+! CHECK:           %[[VAL_16:.*]] = fir.shape %[[VAL_15]] : (index) -> !fir.shape<1>
+! CHECK:           %[[VAL_17:.*]] = hlfir.eval_in_mem shape %[[VAL_16]] : (!fir.shape<1>) -> !hlfir.expr<10xf32> {
+! CHECK:           ^bb0(%[[VAL_18:.*]]: !fir.ref<!fir.array<10xf32>>):
+! CHECK:             %[[VAL_19:.*]] = fir.dispatch "array_func"(%[[VAL_3]]#1 : !fir.class<!fir.type<_QMtype_defsTt>>) -> !fir.array<10xf32>
+! CHECK:             fir.save_result %[[VAL_19]] to %[[VAL_18]](%[[VAL_16]]) : !fir.array<10xf32>, !fir.ref<!fir.array<10xf32>>, !fir.shape<1>
+! CHECK:           }
+! CHECK:           hlfir.assign %[[VAL_17]] to %[[VAL_6]]#0 : !hlfir.expr<10xf32>, !fir.ref<!fir.array<10xf32>>
+! CHECK:           hlfir.destroy %[[VAL_17]] : !hlfir.expr<10xf32>
+! CHECK:           return
+! CHECK:         }
diff --git a/flang/test/Lower/HLFIR/where-nonelemental.f90 b/flang/test/Lower/HLFIR/where-nonelemental.f90
index 643f417c476748..7be58318900124 100644
--- a/flang/test/Lower/HLFIR/where-nonelemental.f90
+++ b/flang/test/Lower/HLFIR/where-nonelemental.f90
@@ -26,11 +26,12 @@ real elemental function elem_func(x)
 ! CHECK-LABEL:   func.func @_QPtest_where(
 ! CHECK:           hlfir.where {
 ! CHECK-NOT: hlfir.exactly_once
-! CHECK:             %[[VAL_17:.*]] = llvm.intr.stacksave : !llvm.ptr
-! CHECK:             %[[VAL_19:.*]] = fir.call @_QPlogical_func1() fastmath<contract> : () -> !fir.array<100x!fir.logical<4>>
-! CHECK:             hlfir.yield %{{.*}} : !hlfir.expr<100x!fir.logical<4>> cleanup {
-! CHECK:               llvm.intr.stackrestore %[[VAL_17]] : !llvm.ptr
-! CHECK:             }
+! CHECK:               %[[VAL_19:.*]] = hlfir.eval_in_mem {{.*}} {
+! CHECK:                 fir.call @_QPlogical_func1() fastmath<contract> : () -> !fir.array<100x!fir.logical<4>>
+! CHECK:               }
+! CHECK:               hlfir.yield %[[VAL_19]] : !hlfir.expr<100x!fir.logical<4>> cleanup {
+! CHECK:                 hlfir.destroy %[[VAL_19]]
+! CHECK:               }
 ! CHECK:           } do {
 ! CHECK:             hlfir.region_assign {
 ! CHECK:               %[[VAL_24:.*]] = hlfir.exactly_once : f32 {
@@ -70,10 +71,11 @@ real elemental function elem_func(x)
 ! CHECK:             }
 ! CHECK:             hlfir.elsewhere mask {
 ! CHECK:               %[[VAL_62:.*]] = hlfir.exactly_once : !hlfir.expr<100x!fir.logical<4>> {
-! CHECK:                 %[[VAL_72:.*]] = llvm.intr.stacksave : !llvm.ptr
-! CHECK:                 fir.call @_QPlogical_func2() fastmath<contract> : () -> !fir.array<100x!fir.logical<4>>
-! CHECK:                 hlfir.yield %{{.*}} : !hlfir.expr<100x!fir.logical<4>> cleanup {
-! CHECK:                   llvm.intr.stackrestore %[[VAL_72]] : !llvm.ptr
+! CHECK:                 %[[VAL_72:.*]] = hlfir.eval_in_mem {{.*}} {
+! CHECK:                  fir.call @_QPlogical_func2() fastmath<contract> : () -> !fir.array<100x!fir.logical<4>>
+! CHECK:                 }
+! CHECK:                 hlfir.yield %[[VAL_72]] : !hlfir.expr<100x!fir.logical<4>> cleanup {
+! CHECK:                   hlfir.destroy %[[VAL_72]]
 ! CHECK:                 }
 ! CHECK:               }
 ! CHECK:               hlfir.yield %[[VAL_62]] : !hlfir.expr<100x!fir.logical<4>>
@@ -123,11 +125,12 @@ integer pure function pure_ifoo()
 ! CHECK:           }  (%[[VAL_10:.*]]: i32) {
 ! CHECK:             %[[VAL_11:.*]] = hlfir.forall_index "i" %[[VAL_10]] : (i32) -> !fir.ref<i32>
 ! CHECK:             hlfir.where {
-! CHECK:               %[[VAL_21:.*]] = llvm.intr.stacksave : !llvm.ptr
 ! CHECK-NOT: hlfir.exactly_once
-! CHECK:               %[[VAL_23:.*]] = fir.call @_QPpure_logical_func1() proc_attrs<pure> fastmath<contract> : () -> !fir.array<100x!fir.logical<4>>
-! CHECK:               hlfir.yield %{{.*}} : !hlfir.expr<100x!fir.logical<4>> cleanup {
-! CHECK:                 llvm.intr.stackrestore %[[VAL_21]] : !llvm.ptr
+! CHECK:               %[[VAL_23:.*]] = hlfir.eval_in_mem {{.*}} {
+! CHECK:                  fir.call @_QPpure_logical_func1() proc_attrs<pure> fastmath<contract> : () -> !fir.array<100x!fir.logical<4>>
+! CHECK:               }
+! CHECK:               hlfir.yield %[[VAL_23]] : !hlfir.expr<100x!fir.logical<4>> cleanup {
+! CHECK:                 hlfir.destroy %[[VAL_23]]
 ! CHECK:               }
 ! CHECK:             } do {
 ! CHECK:               hlfir.region_assign {
@@ -172,10 +175,11 @@ integer pure function pure_ifoo()
 ! CHECK:               }
 ! CHECK:               hlfir.elsewhere mask {
 ! CHECK:                 %[[VAL_129:.*]] = hlfir.exactly_once : !hlfir.expr<100x!fir.logical<4>> {
-! CHECK:                   %[[VAL_139:.*]] = llvm.intr.stacksave : !llvm.ptr
-! CHECK:                   %[[VAL_141:.*]] = fir.call @_QPpure_logical_func2() proc_attrs<pure> fastmath<contract> : () -> !fir.array<100x!fir.logical<4>>
-! CHECK:                   hlfir.yield %{{.*}} : !hlfir.expr<100x!fir.logical<4>> cleanup {
-! CHECK:                     llvm.intr.stackrestore %[[VAL_139]] : !llvm.ptr
+! CHECK:                   %[[VAL_139:.*]] = hlfir.eval_in_mem {{.*}} {
+! CHECK:                    fir.call @_QPpure_logical_func2() proc_attrs<pure> fastmath<contract> : () -> !fir.array<100x!fir.logical<4>>
+! CHECK:                   }
+! CHECK:                   hlfir.yield %[[VAL_139]] : !hlfir.expr<100x!fir.logical<4>> cleanup {
+! CHECK:                     hlfir.destroy %[[VAL_139]]
 ! CHECK:                   }
 ! CHECK:                 }
 ! CHECK:                 hlfir.yield %[[VAL_129]] : !hlfir.expr<100x!fir.logical<4>>
diff --git a/flang/test/Lower/explicit-interface-results-2.f90 b/flang/test/Lower/explicit-interface-results-2.f90
index 95aee84f4a6446..2336053c32a547 100644
--- a/flang/test/Lower/explicit-interface-results-2.f90
+++ b/flang/test/Lower/explicit-interface-results-2.f90
@@ -252,12 +252,10 @@ subroutine test_call_to_used_interface(dummy_proc)
   call takes_array(dummy_proc())
 ! CHECK:  %[[VAL_1:.*]] = arith.constant 100 : index
 ! CHECK:  %[[VAL_2:.*]] = fir.alloca !fir.array<100xf32> {bindc_name = ".result"}
-! CHECK:  %[[VAL_3:.*]] = llvm.intr.stacksave : !llvm.ptr
 ! CHECK:  %[[VAL_4:.*]] = fir.shape %[[VAL_1]] : (index) -> !fir.shape<1>
 ! CHECK:  %[[VAL_5:.*]] = fir.box_addr %[[VAL_0]] : (!fir.boxproc<() -> ()>) -> (() -> !fir.array<100xf32>)
 ! CHECK:  %[[VAL_6:.*]] = fir.call %[[VAL_5]]() {{.*}}: () -> !fir.array<100xf32>
 ! CHECK:  fir.save_result %[[VAL_6]] to %[[VAL_2]](%[[VAL_4]]) : !fir.array<100xf32>, !fir.ref<!fir.array<100xf32>>, !fir.shape<1>
 ! CHECK:  %[[VAL_7:.*]] = fir.convert %[[VAL_2]] : (!fir.ref<!fir.array<100xf32>>) -> !fir.ref<!fir.array<?xf32>>
 ! CHECK:  fir.call @_QPtakes_array(%[[VAL_7]]) {{.*}}: (!fir.ref<!fir.array<?xf32>>) -> ()
-! CHECK:  llvm.intr.stackrestore %[[VAL_3]] : !llvm.ptr
 end subroutine
diff --git a/flang/test/Lower/explicit-interface-results.f90 b/flang/test/Lower/explicit-interface-results.f90
index 623e875b5f9c9d..612d57be364481 100644
--- a/flang/test/Lower/explicit-interface-results.f90
+++ b/flang/test/Lower/explicit-interface-results.f90
@@ -195,8 +195,8 @@ subroutine dyn_array(m, n)
   ! CHECK-DAG: %[[ncast2:.*]] = fir.convert %[[nadd]] : (i64) -> index
   ! CHECK-DAG: %[[ncmpi:.*]] = arith.cmpi sgt, %[[ncast2]], %{{.*}} : index
   ! CHECK-DAG: %[[nselect:.*]] = arith.select %[[ncmpi]], %[[ncast2]], %{{.*}} : index
-  ! CHECK: %[[tmp:.*]] = fir.alloca !fir.array<?x?xf32>, %[[mselect]], %[[nselect]]
   ! CHECK: %[[shape:.*]] = fir.shape %[[mselect]], %[[nselect]] : (index, index) -> !fir.shape<2>
+  ! CHECK: %[[tmp:.*]] = fir.alloca !fir.array<?x?xf32>, %[[mselect]], %[[nselect]]
   ! CHECK: %[[res:.*]] = fir.call @_QMcalleePreturn_dyn_array(%[[m]], %[[n]]) {{.*}}: (!fir.ref<i32>, !fir.ref<i32>) -> !fir.array<?x?xf32>
   ! CHECK: fir.save_result %[[res]] to %[[tmp]](%[[shape]]) : !fir.array<?x?xf32>, !fir.ref<!fir.array<?x?xf32>>, !fir.shape<2>
   print *, return_dyn_array(m, n)
@@ -211,8 +211,8 @@ subroutine dyn_char_cst_array(l)
   ! CHECK: %[[lcast2:.*]] = fir.convert %[[lcast]] : (i64) -> index
   ! CHECK: %[[cmpi:.*]] = arith.cmpi sgt, %[[lcast2]], %{{.*}} : index
   ! CHECK: %[[select:.*]] = arith.select %[[cmpi]], %[[lcast2]], %{{.*}} : index
-  ! CHECK: %[[tmp:.*]] = fir.alloca !fir.array<20x30x!fir.char<1,?>>(%[[select]] : index)
   ! CHECK: %[[shape:.*]] = fir.shape %{{.*}}, %{{.*}} : (index, index) -> !fir.shape<2>
+  ! CHECK: %[[tmp:.*]] = fir.alloca !fir.array<20x30x!fir.char<1,?>>(%[[select]] : index)
   ! CHECK: %[[res:.*]] = fir.call @_QMcalleePreturn_dyn_char_cst_array(%[[l]]) {{.*}}: (!fir.ref<i32>) -> !fir.array<20x30x!fir.char<1,?>>
   ! CHECK: fir.save_result %[[res]] to %[[tmp]](%[[shape]]) typeparams %[[select]] : !fir.array<20x30x!fir.char<1,?>>, !fir.ref<!fir.array<20x30x!fir.char<1,?>>>, !fir.shape<2>, index
   print *, return_dyn_char_cst_array(l)
@@ -236,8 +236,8 @@ subroutine cst_char_dyn_array(m, n)
   ! CHECK-DAG: %[[ncast2:.*]] = fir.convert %[[nadd]] : (i64) -> index
   ! CHECK-DAG: %[[ncmpi:.*]] = arith.cmpi sgt, %[[ncast2]], %{{.*}} : index
   ! CHECK-DAG: %[[nselect:.*]] = arith.select %[[ncmpi]], %[[ncast2]], %{{.*}} : index
-  ! CHECK: %[[tmp:.*]] = fir.alloca !fir.array<?x?x!fir.char<1,10>>, %[[mselect]], %[[nselect]]
   ! CHECK: %[[shape:.*]] = fir.shape %[[mselect]], %[[nselect]] : (index, index) -> !fir.shape<2>
+  ! CHECK: %[[tmp:.*]] = fir.alloca !fir.array<?x?x!fir.char<1,10>>, %[[mselect]], %[[nselect]]
   ! CHECK: %[[res:.*]] = fir.call @_QMcalleePreturn_cst_char_dyn_array(%[[m]], %[[n]]) {{.*}}: (!fir.ref<i32>, !fir.ref<i32>) -> !fir.array<?x?x!fir.char<1,10>>
   ! CHECK: fir.save_result %[[res]] to %[[tmp]](%[[shape]]) typeparams {{.*}} : !fir.array<?x?x!fir.char<1,10>>, !fir.ref<!fir.array<?x?x!fir.char<1,10>>>, !fir.shape<2>, index
   print *, return_cst_char_dyn_array(m, n)
@@ -267,8 +267,8 @@ subroutine dyn_char_dyn_array(l, m, n)
   ! CHECK-DAG: %[[lcast2:.*]] = fir.convert %[[lcast]] : (i64) -> index
   ! CHECK-DAG: %[[lcmpi:.*]] = arith.cmpi sgt, %[[lcast2]], %{{.*}} : index
   ! CHECK-DAG: %[[lselect:.*]] = arith.select %[[lcmpi]], %[[lcast2]], %{{.*}} : index
-  ! CHECK: %[[tmp:.*]] = fir.alloca !fir.array<?x?x!fir.char<1,?>>(%[[lselect]] : index), %[[mselect]], %[[nselect]]
   ! CHECK: %[[shape:.*]] = fir.shape %[[mselect]], %[[nselect]] : (index, index) -> !fir.shape<2>
+  ! CHECK: %[[tmp:.*]] = fir.alloca !fir.array<?x?x!fir.char<1,?>>(%[[lselect]] : index), %[[mselect]], %[[nselect]]
   ! CHECK: %[[res:.*]] = fir.call @_QMcalleePreturn_dyn_char_dyn_array(%[[l]], %[[m]], %[[n]]) {{.*}}: (!fir.ref<i32>, !fir.ref<i32>, !fir.ref<i32>) -> !fir.array<?x?x!fir.char<1,?>>
   ! CHECK: fir.save_result %[[res]] to %[[tmp]](%[[shape]]) typeparams {{.*}} : !fir.array<?x?x!fir.char<1,?>>, !fir.ref<!fir.array<?x?x!fir.char<1,?>>>, !fir.shape<2>, index
   integer :: l, m, n
diff --git a/flang/test/Lower/forall/array-constructor.f90 b/flang/test/Lower/forall/array-constructor.f90
index 4c8c756ea689c1..6b6b46fdd4688e 100644
--- a/flang/test/Lower/forall/array-constructor.f90
+++ b/flang/test/Lower/forall/array-constructor.f90
@@ -232,8 +232,8 @@ end subroutine ac2
 ! CHECK:           %[[C0:.*]] = arith.constant 0 : index
 ! CHECK:           %[[CMPI:.*]] = arith.cmpi sgt, %[[VAL_80]], %[[C0]] : index
 ! CHECK:           %[[SELECT:.*]] = arith.select %[[CMPI]], %[[VAL_80]], %[[C0]] : index
-! CHECK:           %[[VAL_81:.*]] = llvm.intr.stacksave : !llvm.ptr
 ! CHECK:           %[[VAL_82:.*]] = fir.shape %[[SELECT]] : (index) -> !fir.shape<1>
+! CHECK:           %[[VAL_81:.*]] = llvm.intr.stacksave : !llvm.ptr
 ! CHECK:           %[[VAL_83:.*]] = fir.convert %[[VAL_74]] : (!fir.box<!fir.array<1xi32>>) -> !fir.box<!fir.array<?xi32>>
 ! CHECK:           %[[VAL_84:.*]] = fir.call @_QFac2Pfunc(%[[VAL_83]]) {{.*}}: (!fir.box<!fir.array<?xi32>>) -> !fir.array<3xi32>
 ! CHECK:           fir.save_result %[[VAL_84]] to %[[VAL_2]](%[[VAL_82]]) : !fir.array<3xi32>, !fir.ref<!fir.array<3xi32>>, !fir.shape<1>

>From 2002996a96326a28d11da797f323966c331b6fe8 Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Mon, 2 Dec 2024 05:52:09 -0800
Subject: [PATCH 2/3] PR118070 comments: simplify LoweredResult

---
 flang/include/flang/Lower/ConvertCall.h | 7 ++++---
 flang/lib/Lower/ConvertCall.cpp         | 5 ++---
 flang/lib/Lower/ConvertExpr.cpp         | 4 ++--
 3 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/flang/include/flang/Lower/ConvertCall.h b/flang/include/flang/Lower/ConvertCall.h
index 2c51a887010c87..f1cd4f938320bf 100644
--- a/flang/include/flang/Lower/ConvertCall.h
+++ b/flang/include/flang/Lower/ConvertCall.h
@@ -24,9 +24,10 @@
 
 namespace Fortran::lower {
 
-struct LoweredResult {
-  std::variant<fir::ExtendedValue, hlfir::EntityWithAttributes> result;
-};
+/// Data structure packaging the SSA value(s) produced for the result of lowered
+/// function calls.
+using LoweredResult =
+    std::variant<fir::ExtendedValue, hlfir::EntityWithAttributes>;
 
 /// Given a call site for which the arguments were already lowered, generate
 /// the call and return the result. This function deals with explicit result
diff --git a/flang/lib/Lower/ConvertCall.cpp b/flang/lib/Lower/ConvertCall.cpp
index 088d8f96caa410..40cd106e630182 100644
--- a/flang/lib/Lower/ConvertCall.cpp
+++ b/flang/lib/Lower/ConvertCall.cpp
@@ -1699,11 +1699,10 @@ genUserCall(Fortran::lower::PreparedActualArguments &loweredActuals,
   for (auto cleanUp : callCleanUps)
     cleanUp.genCleanUp(loc, builder);
 
-  if (auto *entity =
-          std::get_if<hlfir::EntityWithAttributes>(&loweredResult.result))
+  if (auto *entity = std::get_if<hlfir::EntityWithAttributes>(&loweredResult))
     return *entity;
 
-  auto &result = std::get<fir::ExtendedValue>(loweredResult.result);
+  auto &result = std::get<fir::ExtendedValue>(loweredResult);
 
   // For procedure pointer function result, just return the call.
   if (callContext.resultType &&
diff --git a/flang/lib/Lower/ConvertExpr.cpp b/flang/lib/Lower/ConvertExpr.cpp
index 926e45b807e0a0..7698fac89c2231 100644
--- a/flang/lib/Lower/ConvertExpr.cpp
+++ b/flang/lib/Lower/ConvertExpr.cpp
@@ -2856,7 +2856,7 @@ class ScalarExprLowering {
         Fortran::lower::genCallOpAndResult(loc, converter, symMap, stmtCtx,
                                            caller, callSiteType, resultType)
             .first;
-    auto &result = std::get<ExtValue>(loweredResult.result);
+    auto &result = std::get<ExtValue>(loweredResult);
 
     // Sync pointers and allocatables that may have been modified during the
     // call.
@@ -4887,7 +4887,7 @@ class ArrayExprLowering {
                                              getElementCtx(), caller,
                                              callSiteType, retTy)
               .first;
-      return std::get<ExtValue>(res.result);
+      return std::get<ExtValue>(res);
     };
   }
 

>From 4ab4992246fddf1eef2131d2a6d341095feba85f Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Mon, 2 Dec 2024 07:18:49 -0800
Subject: [PATCH 3/3] remove new over-assertive assert

I added an assert that prevents ExprType from wrapping "trivial types" (i32, f32, ...),
but it actually breaks existing lowering code that generates as_expr + associate
patterns to create temporaries:

https://github.com/llvm/llvm-project/blob/03b5f8f0f0d10c412842ed04b90e2217cf071218/flang/lib/Lower/ConvertCall.cpp#L1271

```
subroutine trivial_as_expr(i)
  integer, optional :: i
  interface
    subroutine foo(j)
      integer, optional, value :: j
    end subroutine
  end interface
  call foo(i)
end subroutine
```

This is unrelated to this patch, so remove the assert and leave the existing
lowering as is (it is not even a functional problem).
---
 flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp | 2 --
 1 file changed, 2 deletions(-)

diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
index c66ba75f912fb7..d67b5fa6598075 100644
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
@@ -220,8 +220,6 @@ mlir::Type hlfir::getExprType(mlir::Type variableType) {
   hlfir::ExprType::Shape typeShape;
   bool isPolymorphic = fir::isPolymorphicType(variableType);
   mlir::Type type = getFortranElementOrSequenceType(variableType);
-  assert(!fir::isa_trivial(type) &&
-         "numerical and logical scalar should not be wrapped in hlfir.expr");
   if (auto seqType = mlir::dyn_cast<fir::SequenceType>(type)) {
     assert(!seqType.hasUnknownShape() && "assumed-rank cannot be expressions");
     typeShape.append(seqType.getShape().begin(), seqType.getShape().end());



More information about the llvm-branch-commits mailing list