[flang-commits] [flang] [flang][hlfir] add hlfir.eval_in_mem operation (PR #118067)

via flang-commits flang-commits at lists.llvm.org
Fri Nov 29 00:40:03 PST 2024


https://github.com/jeanPerier created https://github.com/llvm/llvm-project/pull/118067

See HLFIROps.td change for the description of the operation.

The goal is to ease temporary storage elision for expression evaluation (typically evaluating the RHS directly inside the LHS) for expressions that do not have abstractions in HLFIR and for which it is not clear that adding one would bring much. The case implemented in the following lowering patch is the array call case, where adding a new hlfir.call would add complexity (it would need to deal with dispatch, inlining, ...).

Of course the optimizer could also try to remove the temporaries created, but this is in general a harder problem (it needs to identify lifetimes, find the associated stack save/restore or free, if any, and make sure there are no captures). Encapsulating the temporary in a region where it cannot escape by design greatly eases the analysis and optimization.

This concept could be used for other expressions that are currently lowered in memory (at least some cases of array constructors and structure constructors).

>From 5d43a6bd4fb27015e6917eb471bc725fe31b9ac2 Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Thu, 28 Nov 2024 08:25:05 -0800
Subject: [PATCH] [flang][hlfir] add hlfir.eval_in_mem operation

---
 .../flang/Optimizer/Builder/HLFIRTools.h      |  19 ++++
 .../include/flang/Optimizer/HLFIR/HLFIROps.td |  59 ++++++++++
 flang/lib/Optimizer/Builder/HLFIRTools.cpp    |  47 ++++++++
 flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp     |  76 ++++++++++---
 .../HLFIR/Transforms/BufferizeHLFIR.cpp       |  33 +++++-
 flang/test/HLFIR/eval_in_mem-codegen.fir      | 107 ++++++++++++++++++
 flang/test/HLFIR/eval_in_mem.fir              |  99 ++++++++++++++++
 flang/test/HLFIR/invalid.fir                  |  34 ++++++
 8 files changed, 454 insertions(+), 20 deletions(-)
 create mode 100644 flang/test/HLFIR/eval_in_mem-codegen.fir
 create mode 100644 flang/test/HLFIR/eval_in_mem.fir

diff --git a/flang/include/flang/Optimizer/Builder/HLFIRTools.h b/flang/include/flang/Optimizer/Builder/HLFIRTools.h
index f073f494b3fb21..efbd9e4f50d432 100644
--- a/flang/include/flang/Optimizer/Builder/HLFIRTools.h
+++ b/flang/include/flang/Optimizer/Builder/HLFIRTools.h
@@ -33,6 +33,7 @@ class AssociateOp;
 class ElementalOp;
 class ElementalOpInterface;
 class ElementalAddrOp;
+class EvaluateInMemoryOp;
 class YieldElementOp;
 
 /// Is this a Fortran variable for which the defining op carrying the Fortran
@@ -398,6 +399,24 @@ mlir::Value inlineElementalOp(
     mlir::IRMapping &mapper,
     const std::function<bool(hlfir::ElementalOp)> &mustRecursivelyInline);
 
+/// Create a new temporary with the provided shape and type parameters and
+/// clone the body of \p evalInMem so that it evaluates into this temporary.
+/// Returns the temporary and whether it is heap (true) or stack (false)
+/// allocated; a heap temporary is not freed by this utility.
+std::pair<hlfir::Entity, bool>
+computeEvaluateOpInNewTemp(mlir::Location, fir::FirOpBuilder &,
+                           hlfir::EvaluateInMemoryOp evalInMem,
+                           mlir::Value shape, mlir::ValueRange typeParams);
+
+/// Clone the body of the hlfir.eval_in_mem so that it operates on the provided
+/// storage. The provided storage must be a contiguous "raw" memory reference
+/// (not a fir.box) big enough to hold the value computed by hlfir.eval_in_mem.
+/// No runtime check is inserted by this utility to enforce that. It is also
+/// usually invalid to provide some storage that is already addressed directly
+/// or indirectly inside the hlfir.eval_in_mem body.
+void computeEvaluateOpIn(mlir::Location, fir::FirOpBuilder &,
+                         hlfir::EvaluateInMemoryOp, mlir::Value storage);
+
 std::pair<fir::ExtendedValue, std::optional<hlfir::CleanupFunction>>
 convertToValue(mlir::Location loc, fir::FirOpBuilder &builder,
                hlfir::Entity entity);
diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
index 1ab8793f726523..a9826543f48b69 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
@@ -1755,4 +1755,63 @@ def hlfir_CharExtremumOp : hlfir_Op<"char_extremum",
   let hasVerifier = 1;
 }
 
+def hlfir_EvaluateInMemoryOp : hlfir_Op<"eval_in_mem", [AttrSizedOperandSegments,
+    RecursiveMemoryEffects, RecursivelySpeculatable,
+    SingleBlockImplicitTerminator<"fir::FirEndOp">]> {
+  let summary = "Wrap an in-memory implementation that computes expression value";
+  let description = [{
+    Returns a Fortran expression value whose computation is implemented
+    inside the region, operating on the region block argument: a raw memory
+    reference corresponding to the expression type.
+
+    The shape and type parameters of the expression are operands of the
+    operation.
+
+    The memory cannot escape the region, and how it is allocated is not
+    specified. This facilitates later elision of the temporary storage for
+    the expression evaluation if it can be evaluated in some other storage
+    (like a left-hand side variable).
+
+    Example:
+
+    A function returning an array can be represented as:
+    ```
+      %1 = fir.shape %c10 : (index) -> !fir.shape<1>
+      %2 = hlfir.eval_in_mem shape %1 : (!fir.shape<1>) -> !hlfir.expr<10xf32> {
+      ^bb0(%arg0: !fir.ref<!fir.array<10xf32>>):
+        %3 = fir.call @_QParray_func() fastmath<contract> : () -> !fir.array<10xf32>
+        fir.save_result %3 to %arg0(%1) : !fir.array<10xf32>, !fir.ref<!fir.array<10xf32>>, !fir.shape<1>
+      }
+    ```
+  }];
+
+  let arguments = (ins
+    Optional<fir_ShapeType>:$shape,
+    Variadic<AnyIntegerType>:$typeparams
+  );
+
+  let results = (outs hlfir_ExprType);
+  let regions = (region  SizedRegion<1>:$body);
+
+  let assemblyFormat = [{
+    (`shape` $shape^)? (`typeparams` $typeparams^)?
+    attr-dict `:` functional-type(operands, results)
+    $body}];
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<(ins "mlir::Type":$result_type, "mlir::Value":$shape,
+      CArg<"mlir::ValueRange", "{}">:$typeparams)>
+  ];
+
+  let extraClassDeclaration = [{
+      // Return the block argument representing the memory where the
+      // expression is evaluated.
+      mlir::Value getMemory() {return getBody().getArgument(0);}
+  }];
+
+  let hasVerifier = 1;
+}
+
+
 #endif // FORTRAN_DIALECT_HLFIR_OPS
diff --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
index 7425ccf7fc0e30..1bd950f2445ee4 100644
--- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp
+++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
@@ -535,6 +535,8 @@ static mlir::Value tryRetrievingShapeOrShift(hlfir::Entity entity) {
   if (mlir::isa<hlfir::ExprType>(entity.getType())) {
     if (auto elemental = entity.getDefiningOp<hlfir::ElementalOp>())
       return elemental.getShape();
+    if (auto evalInMem = entity.getDefiningOp<hlfir::EvaluateInMemoryOp>())
+      return evalInMem.getShape();
     return mlir::Value{};
   }
   if (auto varIface = entity.getIfVariableInterface())
@@ -642,6 +644,11 @@ void hlfir::genLengthParameters(mlir::Location loc, fir::FirOpBuilder &builder,
       result.append(elemental.getTypeparams().begin(),
                     elemental.getTypeparams().end());
       return;
+    } else if (auto evalInMem =
+                   expr.getDefiningOp<hlfir::EvaluateInMemoryOp>()) {
+      result.append(evalInMem.getTypeparams().begin(),
+                    evalInMem.getTypeparams().end());
+      return;
     } else if (auto apply = expr.getDefiningOp<hlfir::ApplyOp>()) {
       result.append(apply.getTypeparams().begin(), apply.getTypeparams().end());
       return;
@@ -1313,3 +1320,54 @@ hlfir::genTypeAndKindConvert(mlir::Location loc, fir::FirOpBuilder &builder,
   };
   return {hlfir::Entity{convertedRhs}, cleanup};
 }
+
+/// Allocate a temporary with the provided \p shape and \p typeParams for the
+/// value of \p evalInMem (on the heap when its type has a dynamic size, on the
+/// stack otherwise), and clone the body of \p evalInMem so that it computes
+/// the value into this temporary. Returns the declared temporary and whether
+/// it is heap allocated; a heap temporary is not freed here.
+std::pair<hlfir::Entity, bool> hlfir::computeEvaluateOpInNewTemp(
+    mlir::Location loc, fir::FirOpBuilder &builder,
+    hlfir::EvaluateInMemoryOp evalInMem, mlir::Value shape,
+    mlir::ValueRange typeParams) {
+  llvm::StringRef tmpName{".tmp.expr_result"};
+  llvm::SmallVector<mlir::Value> extents =
+      hlfir::getIndexExtents(loc, builder, shape);
+  mlir::Type baseType =
+      hlfir::getFortranElementOrSequenceType(evalInMem.getType());
+  bool heapAllocated = fir::hasDynamicSize(baseType);
+  // Note: temporaries are stack allocated here when possible (do not require
+  // stack save/restore) because flang has always stack allocated function
+  // results.
+  mlir::Value temp = heapAllocated
+                         ? builder.createHeapTemporary(loc, baseType, tmpName,
+                                                       extents, typeParams)
+                         : builder.createTemporary(loc, baseType, tmpName,
+                                                   extents, typeParams);
+  // The body operates on a raw memory reference (e.g. a heap temporary is
+  // cast from !fir.heap to the !fir.ref type of the body argument).
+  mlir::Value innerMemory = evalInMem.getMemory();
+  temp = builder.createConvert(loc, innerMemory.getType(), temp);
+  auto declareOp = builder.create<hlfir::DeclareOp>(
+      loc, temp, tmpName, shape, typeParams,
+      /*dummy_scope=*/nullptr, fir::FortranVariableFlagsAttr{});
+  // The body is given the raw address, while the returned entity is the
+  // declared base (which may be a fir.box when the shape is dynamic).
+  computeEvaluateOpIn(loc, builder, evalInMem, declareOp.getOriginalBase());
+  return {hlfir::Entity{declareOp.getBase()}, /*heapAllocated=*/heapAllocated};
+}
+
+/// Clone the body of \p evalInMem at the current insertion point of
+/// \p builder, with its memory argument replaced by \p storage (converted to
+/// the argument type). The fir.end terminator is not cloned.
+void hlfir::computeEvaluateOpIn(mlir::Location loc, fir::FirOpBuilder &builder,
+                                hlfir::EvaluateInMemoryOp evalInMem,
+                                mlir::Value storage) {
+  mlir::Value innerMemory = evalInMem.getMemory();
+  mlir::Value storageCast =
+      builder.createConvert(loc, innerMemory.getType(), storage);
+  mlir::IRMapping mapper;
+  mapper.map(innerMemory, storageCast);
+  for (auto &op : evalInMem.getBody().front().without_terminator())
+    builder.clone(op, mapper);
+}
diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
index b593383ff2848d..87519882446485 100644
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
@@ -333,6 +333,30 @@ static void printDesignatorComplexPart(mlir::OpAsmPrinter &p,
       p << "real";
   }
 }
+
+/// Check that \p numLenParam type parameters were provided for a result whose
+/// element type is \p elementType: one for characters, the number of length
+/// parameters for parameterized derived types, and none otherwise. Errors are
+/// emitted on \p op.
+template <typename Op>
+static llvm::LogicalResult verifyTypeparams(Op &op, mlir::Type elementType,
+                                            unsigned numLenParam) {
+  if (mlir::isa<fir::CharacterType>(elementType)) {
+    if (numLenParam != 1)
+      return op.emitOpError("must be provided one length parameter when the "
+                            "result is a character");
+  } else if (fir::isRecordWithTypeParameters(elementType)) {
+    if (numLenParam !=
+        mlir::cast<fir::RecordType>(elementType).getNumLenParams())
+      return op.emitOpError("must be provided the same number of length "
+                            "parameters as in the result derived type");
+  } else if (numLenParam != 0) {
+    return op.emitOpError(
+        "must not be provided length parameters if the result "
+        "type does not have length parameters");
+  }
+  return mlir::success();
+}
 
 llvm::LogicalResult hlfir::DesignateOp::verify() {
   mlir::Type memrefType = getMemref().getType();
@@ -462,20 +486,10 @@ llvm::LogicalResult hlfir::DesignateOp::verify() {
         return emitOpError("shape must be a fir.shape or fir.shapeshift with "
                            "the rank of the result");
     }
-    auto numLenParam = getTypeparams().size();
-    if (mlir::isa<fir::CharacterType>(outputElementType)) {
-      if (numLenParam != 1)
-        return emitOpError("must be provided one length parameter when the "
-                           "result is a character");
-    } else if (fir::isRecordWithTypeParameters(outputElementType)) {
-      if (numLenParam !=
-          mlir::cast<fir::RecordType>(outputElementType).getNumLenParams())
-        return emitOpError("must be provided the same number of length "
-                           "parameters as in the result derived type");
-    } else if (numLenParam != 0) {
-      return emitOpError("must not be provided length parameters if the result "
-                         "type does not have length parameters");
-    }
+    if (auto res =
+            verifyTypeparams(*this, outputElementType, getTypeparams().size());
+        failed(res))
+      return res;
   }
   return mlir::success();
 }
@@ -1989,6 +2003,57 @@ hlfir::GetLengthOp::canonicalize(GetLengthOp getLength,
   return mlir::success();
 }
 
+//===----------------------------------------------------------------------===//
+// EvaluateInMemoryOp
+//===----------------------------------------------------------------------===//
+
+void hlfir::EvaluateInMemoryOp::build(mlir::OpBuilder &builder,
+                                      mlir::OperationState &odsState,
+                                      mlir::Type resultType, mlir::Value shape,
+                                      mlir::ValueRange typeparams) {
+  odsState.addTypes(resultType);
+  if (shape)
+    odsState.addOperands(shape);
+  odsState.addOperands(typeparams);
+  odsState.addAttribute(
+      getOperandSegmentSizeAttr(),
+      builder.getDenseI32ArrayAttr(
+          {shape ? 1 : 0, static_cast<int32_t>(typeparams.size())}));
+  // The body is a single block whose only argument is the memory where the
+  // expression value is evaluated (see getMemory()).
+  mlir::Region *bodyRegion = odsState.addRegion();
+  bodyRegion->push_back(new mlir::Block{});
+  mlir::Type memType = fir::ReferenceType::get(
+      hlfir::getFortranElementOrSequenceType(resultType));
+  bodyRegion->front().addArgument(memType, odsState.location);
+  EvaluateInMemoryOp::ensureTerminator(*bodyRegion, builder, odsState.location);
+}
+
+llvm::LogicalResult hlfir::EvaluateInMemoryOp::verify() {
+  unsigned shapeRank = 0;
+  if (mlir::Value shape = getShape())
+    if (auto shapeTy = mlir::dyn_cast<fir::ShapeType>(shape.getType()))
+      shapeRank = shapeTy.getRank();
+  auto exprType = mlir::cast<hlfir::ExprType>(getResult().getType());
+  if (shapeRank != exprType.getRank())
+    return emitOpError("`shape` rank must match the result rank");
+  mlir::Type elementType = exprType.getElementType();
+  if (auto res = verifyTypeparams(*this, elementType, getTypeparams().size());
+      failed(res))
+    return res;
+  // getMemory() requires the body block to have exactly one argument, which
+  // must be the memory reference type created by the builder.
+  mlir::Block &body = getBody().front();
+  if (body.getNumArguments() != 1)
+    return emitOpError("body must have exactly one block argument");
+  mlir::Type expectedMemType = fir::ReferenceType::get(
+      hlfir::getFortranElementOrSequenceType(exprType));
+  if (body.getArgument(0).getType() != expectedMemType)
+    return emitOpError("body block argument must have type ")
+           << expectedMemType;
+  return mlir::success();
+}
+
 #include "flang/Optimizer/HLFIR/HLFIROpInterfaces.cpp.inc"
 #define GET_OP_CLASSES
 #include "flang/Optimizer/HLFIR/HLFIREnums.cpp.inc"
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
index 1848dbe2c7a2c2..347f0a5630777f 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
@@ -905,6 +905,34 @@ struct CharExtremumOpConversion
   }
 };
 
+/// Bufferize hlfir.eval_in_mem by evaluating its body into a new temporary
+/// and replacing the expression with the bufferized temporary.
+struct EvaluateInMemoryOpConversion
+    : public mlir::OpConversionPattern<hlfir::EvaluateInMemoryOp> {
+  using mlir::OpConversionPattern<
+      hlfir::EvaluateInMemoryOp>::OpConversionPattern;
+  explicit EvaluateInMemoryOpConversion(mlir::MLIRContext *ctx)
+      : mlir::OpConversionPattern<hlfir::EvaluateInMemoryOp>{ctx} {}
+  llvm::LogicalResult
+  matchAndRewrite(hlfir::EvaluateInMemoryOp evalInMemOp, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    fir::FirOpBuilder builder(rewriter, evalInMemOp.getOperation());
+    const mlir::Location loc = evalInMemOp->getLoc();
+    // Allocate a temporary and clone the body so that it computes the
+    // expression value into it.
+    std::pair<hlfir::Entity, bool> tempAndIsHeap =
+        hlfir::computeEvaluateOpInNewTemp(loc, builder, evalInMemOp,
+                                          adaptor.getShape(),
+                                          adaptor.getTypeparams());
+    // Package the temporary with a flag telling whether it is heap allocated
+    // so that it can be freed when the expression is destroyed.
+    rewriter.replaceOp(evalInMemOp,
+                       packageBufferizedExpr(loc, builder, tempAndIsHeap.first,
+                                             tempAndIsHeap.second));
+    return mlir::success();
+  }
+};
+
 class BufferizeHLFIR : public hlfir::impl::BufferizeHLFIRBase<BufferizeHLFIR> {
 public:
   void runOnOperation() override {
@@ -918,12 +946,13 @@ class BufferizeHLFIR : public hlfir::impl::BufferizeHLFIRBase<BufferizeHLFIR> {
     auto module = this->getOperation();
     auto *context = &getContext();
     mlir::RewritePatternSet patterns(context);
-    patterns.insert<ApplyOpConversion, AsExprOpConversion, AssignOpConversion,
-                    AssociateOpConversion, CharExtremumOpConversion,
-                    ConcatOpConversion, DestroyOpConversion,
-                    ElementalOpConversion, EndAssociateOpConversion,
-                    NoReassocOpConversion, SetLengthOpConversion,
-                    ShapeOfOpConversion, GetLengthOpConversion>(context);
+    patterns
+        .insert<ApplyOpConversion, AsExprOpConversion, AssignOpConversion,
+                AssociateOpConversion, CharExtremumOpConversion,
+                ConcatOpConversion, DestroyOpConversion, ElementalOpConversion,
+                EndAssociateOpConversion, EvaluateInMemoryOpConversion,
+                NoReassocOpConversion, SetLengthOpConversion,
+                ShapeOfOpConversion, GetLengthOpConversion>(context);
     mlir::ConversionTarget target(*context);
     // Note that YieldElementOp is not marked as an illegal operation.
     // It must be erased by its parent converter and there is no explicit
diff --git a/flang/test/HLFIR/eval_in_mem-codegen.fir b/flang/test/HLFIR/eval_in_mem-codegen.fir
new file mode 100644
index 00000000000000..26a989832ca927
--- /dev/null
+++ b/flang/test/HLFIR/eval_in_mem-codegen.fir
@@ -0,0 +1,107 @@
+// Test hlfir.eval_in_mem default code generation: constant size results are
+// evaluated in a stack temporary, dynamic size results in a heap temporary.
+// RUN: fir-opt %s --bufferize-hlfir -o - | FileCheck %s
+
+func.func @_QPtest() {
+  %c10 = arith.constant 10 : index
+  %0 = fir.address_of(@_QFtestEx) : !fir.ref<!fir.array<10xf32>>
+  %1 = fir.shape %c10 : (index) -> !fir.shape<1>
+  %2 = hlfir.eval_in_mem shape %1 : (!fir.shape<1>) -> !hlfir.expr<10xf32> {
+  ^bb0(%arg0: !fir.ref<!fir.array<10xf32>>):
+    %3 = fir.call @_QParray_func() fastmath<contract> : () -> !fir.array<10xf32>
+    fir.save_result %3 to %arg0(%1) : !fir.array<10xf32>, !fir.ref<!fir.array<10xf32>>, !fir.shape<1>
+  }
+  hlfir.assign %2 to %0 : !hlfir.expr<10xf32>, !fir.ref<!fir.array<10xf32>>
+  hlfir.destroy %2 : !hlfir.expr<10xf32>
+  return
+}
+fir.global internal @_QFtestEx : !fir.array<10xf32>
+func.func private @_QParray_func() -> !fir.array<10xf32>
+
+
+func.func @_QPtest_char() {
+  %c10 = arith.constant 10 : index
+  %c5 = arith.constant 5 : index
+  %0 = fir.address_of(@_QFtest_charEx) : !fir.ref<!fir.array<10x!fir.char<1,5>>>
+  %1 = fir.shape %c10 : (index) -> !fir.shape<1>
+  %2 = hlfir.eval_in_mem shape %1 typeparams %c5 : (!fir.shape<1>, index) -> !hlfir.expr<10x!fir.char<1,5>> {
+  ^bb0(%arg0: !fir.ref<!fir.array<10x!fir.char<1,5>>>):
+    %3 = fir.call @_QPchar_array_func() fastmath<contract> : () -> !fir.array<10x!fir.char<1,5>>
+    fir.save_result %3 to %arg0(%1) typeparams %c5 : !fir.array<10x!fir.char<1,5>>, !fir.ref<!fir.array<10x!fir.char<1,5>>>, !fir.shape<1>, index
+  }
+  hlfir.assign %2 to %0 : !hlfir.expr<10x!fir.char<1,5>>, !fir.ref<!fir.array<10x!fir.char<1,5>>>
+  hlfir.destroy %2 : !hlfir.expr<10x!fir.char<1,5>>
+  return
+}
+
+fir.global internal @_QFtest_charEx : !fir.array<10x!fir.char<1,5>>
+func.func private @_QPchar_array_func() -> !fir.array<10x!fir.char<1,5>>
+
+func.func @test_dynamic(%arg0: !fir.box<!fir.array<?xf32>>, %arg1: index) {
+   %0 = fir.shape %arg1 : (index) -> !fir.shape<1>
+   %1 = hlfir.eval_in_mem shape %0 : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
+   ^bb0(%arg2: !fir.ref<!fir.array<?xf32>>):
+     %2 = fir.call @_QPdyn_array_func(%arg1) : (index) -> !fir.array<?xf32>
+     fir.save_result %2 to %arg2(%0) : !fir.array<?xf32>, !fir.ref<!fir.array<?xf32>>, !fir.shape<1>
+   }
+   hlfir.assign %1 to %arg0 : !hlfir.expr<?xf32>, !fir.box<!fir.array<?xf32>>
+   hlfir.destroy %1 : !hlfir.expr<?xf32>
+   return
+}
+func.func private @_QPdyn_array_func(index) -> !fir.array<?xf32>
+
+// CHECK-LABEL:   func.func @_QPtest() {
+// CHECK:           %[[VAL_0:.*]] = fir.alloca !fir.array<10xf32> {bindc_name = ".tmp.expr_result"}
+// CHECK:           %[[VAL_1:.*]] = arith.constant 10 : index
+// CHECK:           %[[VAL_2:.*]] = fir.address_of(@_QFtestEx) : !fir.ref<!fir.array<10xf32>>
+// CHECK:           %[[VAL_3:.*]] = fir.shape %[[VAL_1]] : (index) -> !fir.shape<1>
+// CHECK:           %[[VAL_4:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_3]]) {uniq_name = ".tmp.expr_result"} : (!fir.ref<!fir.array<10xf32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xf32>>, !fir.ref<!fir.array<10xf32>>)
+// CHECK:           %[[VAL_5:.*]] = fir.call @_QParray_func() fastmath<contract> : () -> !fir.array<10xf32>
+// CHECK:           fir.save_result %[[VAL_5]] to %[[VAL_4]]#1(%[[VAL_3]]) : !fir.array<10xf32>, !fir.ref<!fir.array<10xf32>>, !fir.shape<1>
+// CHECK:           %[[VAL_6:.*]] = arith.constant false
+// CHECK:           %[[VAL_7:.*]] = fir.undefined tuple<!fir.ref<!fir.array<10xf32>>, i1>
+// CHECK:           %[[VAL_8:.*]] = fir.insert_value %[[VAL_7]], %[[VAL_6]], [1 : index] : (tuple<!fir.ref<!fir.array<10xf32>>, i1>, i1) -> tuple<!fir.ref<!fir.array<10xf32>>, i1>
+// CHECK:           %[[VAL_9:.*]] = fir.insert_value %[[VAL_8]], %[[VAL_4]]#0, [0 : index] : (tuple<!fir.ref<!fir.array<10xf32>>, i1>, !fir.ref<!fir.array<10xf32>>) -> tuple<!fir.ref<!fir.array<10xf32>>, i1>
+// CHECK:           hlfir.assign %[[VAL_4]]#0 to %[[VAL_2]] : !fir.ref<!fir.array<10xf32>>, !fir.ref<!fir.array<10xf32>>
+// CHECK:           return
+// CHECK:         }
+// CHECK:         fir.global internal @_QFtestEx : !fir.array<10xf32>
+// CHECK:         func.func private @_QParray_func() -> !fir.array<10xf32>
+
+// CHECK-LABEL:   func.func @_QPtest_char() {
+// CHECK:           %[[VAL_0:.*]] = fir.alloca !fir.array<10x!fir.char<1,5>> {bindc_name = ".tmp.expr_result"}
+// CHECK:           %[[VAL_1:.*]] = arith.constant 10 : index
+// CHECK:           %[[VAL_2:.*]] = arith.constant 5 : index
+// CHECK:           %[[VAL_3:.*]] = fir.address_of(@_QFtest_charEx) : !fir.ref<!fir.array<10x!fir.char<1,5>>>
+// CHECK:           %[[VAL_4:.*]] = fir.shape %[[VAL_1]] : (index) -> !fir.shape<1>
+// CHECK:           %[[VAL_5:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_4]]) typeparams %[[VAL_2]] {uniq_name = ".tmp.expr_result"} : (!fir.ref<!fir.array<10x!fir.char<1,5>>>, !fir.shape<1>, index) -> (!fir.ref<!fir.array<10x!fir.char<1,5>>>, !fir.ref<!fir.array<10x!fir.char<1,5>>>)
+// CHECK:           %[[VAL_6:.*]] = fir.call @_QPchar_array_func() fastmath<contract> : () -> !fir.array<10x!fir.char<1,5>>
+// CHECK:           fir.save_result %[[VAL_6]] to %[[VAL_5]]#1(%[[VAL_4]]) typeparams %[[VAL_2]] : !fir.array<10x!fir.char<1,5>>, !fir.ref<!fir.array<10x!fir.char<1,5>>>, !fir.shape<1>, index
+// CHECK:           %[[VAL_7:.*]] = arith.constant false
+// CHECK:           %[[VAL_8:.*]] = fir.undefined tuple<!fir.ref<!fir.array<10x!fir.char<1,5>>>, i1>
+// CHECK:           %[[VAL_9:.*]] = fir.insert_value %[[VAL_8]], %[[VAL_7]], [1 : index] : (tuple<!fir.ref<!fir.array<10x!fir.char<1,5>>>, i1>, i1) -> tuple<!fir.ref<!fir.array<10x!fir.char<1,5>>>, i1>
+// CHECK:           %[[VAL_10:.*]] = fir.insert_value %[[VAL_9]], %[[VAL_5]]#0, [0 : index] : (tuple<!fir.ref<!fir.array<10x!fir.char<1,5>>>, i1>, !fir.ref<!fir.array<10x!fir.char<1,5>>>) -> tuple<!fir.ref<!fir.array<10x!fir.char<1,5>>>, i1>
+// CHECK:           hlfir.assign %[[VAL_5]]#0 to %[[VAL_3]] : !fir.ref<!fir.array<10x!fir.char<1,5>>>, !fir.ref<!fir.array<10x!fir.char<1,5>>>
+// CHECK:           return
+// CHECK:         }
+// CHECK:         fir.global internal @_QFtest_charEx : !fir.array<10x!fir.char<1,5>>
+// CHECK:         func.func private @_QPchar_array_func() -> !fir.array<10x!fir.char<1,5>>
+
+// CHECK-LABEL:   func.func @test_dynamic(
+// CHECK-SAME:                            %[[VAL_0:.*]]: !fir.box<!fir.array<?xf32>>,
+// CHECK-SAME:                            %[[VAL_1:.*]]: index) {
+// CHECK:           %[[VAL_2:.*]] = fir.shape %[[VAL_1]] : (index) -> !fir.shape<1>
+// CHECK:           %[[VAL_3:.*]] = fir.allocmem !fir.array<?xf32>, %[[VAL_1]] {bindc_name = ".tmp.expr_result", uniq_name = ""}
+// CHECK:           %[[VAL_4:.*]] = fir.convert %[[VAL_3]] : (!fir.heap<!fir.array<?xf32>>) -> !fir.ref<!fir.array<?xf32>>
+// CHECK:           %[[VAL_5:.*]]:2 = hlfir.declare %[[VAL_4]](%[[VAL_2]]) {uniq_name = ".tmp.expr_result"} : (!fir.ref<!fir.array<?xf32>>, !fir.shape<1>) -> (!fir.box<!fir.array<?xf32>>, !fir.ref<!fir.array<?xf32>>)
+// CHECK:           %[[VAL_6:.*]] = fir.call @_QPdyn_array_func(%[[VAL_1]]) : (index) -> !fir.array<?xf32>
+// CHECK:           fir.save_result %[[VAL_6]] to %[[VAL_5]]#1(%[[VAL_2]]) : !fir.array<?xf32>, !fir.ref<!fir.array<?xf32>>, !fir.shape<1>
+// CHECK:           %[[VAL_7:.*]] = arith.constant true
+// CHECK:           %[[VAL_8:.*]] = fir.undefined tuple<!fir.box<!fir.array<?xf32>>, i1>
+// CHECK:           %[[VAL_9:.*]] = fir.insert_value %[[VAL_8]], %[[VAL_7]], [1 : index] : (tuple<!fir.box<!fir.array<?xf32>>, i1>, i1) -> tuple<!fir.box<!fir.array<?xf32>>, i1>
+// CHECK:           %[[VAL_10:.*]] = fir.insert_value %[[VAL_9]], %[[VAL_5]]#0, [0 : index] : (tuple<!fir.box<!fir.array<?xf32>>, i1>, !fir.box<!fir.array<?xf32>>) -> tuple<!fir.box<!fir.array<?xf32>>, i1>
+// CHECK:           hlfir.assign %[[VAL_5]]#0 to %[[VAL_0]] : !fir.box<!fir.array<?xf32>>, !fir.box<!fir.array<?xf32>>
+// CHECK:           %[[VAL_11:.*]] = fir.box_addr %[[VAL_5]]#0 : (!fir.box<!fir.array<?xf32>>) -> !fir.heap<!fir.array<?xf32>>
+// CHECK:           fir.freemem %[[VAL_11]] : !fir.heap<!fir.array<?xf32>>
+// CHECK:           return
+// CHECK:         }
diff --git a/flang/test/HLFIR/eval_in_mem.fir b/flang/test/HLFIR/eval_in_mem.fir
new file mode 100644
index 00000000000000..34e48ed5be5452
--- /dev/null
+++ b/flang/test/HLFIR/eval_in_mem.fir
@@ -0,0 +1,99 @@
+// Test hlfir.eval_in_mem parsing, verification (no errors), and printing.
+
+// RUN: fir-opt %s | fir-opt | FileCheck %s
+
+func.func @_QPtest() {
+  %c10 = arith.constant 10 : index
+  %0 = fir.address_of(@_QFtestEx) : !fir.ref<!fir.array<10xf32>>
+  %1 = fir.shape %c10 : (index) -> !fir.shape<1>
+  %2 = hlfir.eval_in_mem shape %1 : (!fir.shape<1>) -> !hlfir.expr<10xf32> {
+  ^bb0(%arg0: !fir.ref<!fir.array<10xf32>>):
+    %3 = fir.call @_QParray_func() fastmath<contract> : () -> !fir.array<10xf32>
+    fir.save_result %3 to %arg0(%1) : !fir.array<10xf32>, !fir.ref<!fir.array<10xf32>>, !fir.shape<1>
+  }
+  hlfir.assign %2 to %0 : !hlfir.expr<10xf32>, !fir.ref<!fir.array<10xf32>>
+  hlfir.destroy %2 : !hlfir.expr<10xf32>
+  return
+}
+fir.global internal @_QFtestEx : !fir.array<10xf32>
+func.func private @_QParray_func() -> !fir.array<10xf32>
+
+
+func.func @_QPtest_char() {
+  %c10 = arith.constant 10 : index
+  %c5 = arith.constant 5 : index
+  %0 = fir.address_of(@_QFtest_charEx) : !fir.ref<!fir.array<10x!fir.char<1,5>>>
+  %1 = fir.shape %c10 : (index) -> !fir.shape<1>
+  %2 = hlfir.eval_in_mem shape %1 typeparams %c5 : (!fir.shape<1>, index) -> !hlfir.expr<10x!fir.char<1,5>> {
+  ^bb0(%arg0: !fir.ref<!fir.array<10x!fir.char<1,5>>>):
+    %3 = fir.call @_QPchar_array_func() fastmath<contract> : () -> !fir.array<10x!fir.char<1,5>>
+    fir.save_result %3 to %arg0(%1) typeparams %c5 : !fir.array<10x!fir.char<1,5>>, !fir.ref<!fir.array<10x!fir.char<1,5>>>, !fir.shape<1>, index
+  }
+  hlfir.assign %2 to %0 : !hlfir.expr<10x!fir.char<1,5>>, !fir.ref<!fir.array<10x!fir.char<1,5>>>
+  hlfir.destroy %2 : !hlfir.expr<10x!fir.char<1,5>>
+  return
+}
+
+fir.global internal @_QFtest_charEx : !fir.array<10x!fir.char<1,5>>
+func.func private @_QPchar_array_func() -> !fir.array<10x!fir.char<1,5>>
+
+func.func @test_dynamic(%arg0: !fir.box<!fir.array<?xf32>>, %arg1: index) {
+   %0 = fir.shape %arg1 : (index) -> !fir.shape<1>
+   %1 = hlfir.eval_in_mem shape %0 : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
+   ^bb0(%arg2: !fir.ref<!fir.array<?xf32>>):
+     %2 = fir.call @_QPdyn_array_func(%arg1) : (index) -> !fir.array<?xf32>
+     fir.save_result %2 to %arg2(%0) : !fir.array<?xf32>, !fir.ref<!fir.array<?xf32>>, !fir.shape<1>
+   }
+   hlfir.assign %1 to %arg0 : !hlfir.expr<?xf32>, !fir.box<!fir.array<?xf32>>
+   hlfir.destroy %1 : !hlfir.expr<?xf32>
+   return
+}
+func.func private @_QPdyn_array_func(index) -> !fir.array<?xf32>
+
+// CHECK-LABEL:   func.func @_QPtest() {
+// CHECK:           %[[VAL_0:.*]] = arith.constant 10 : index
+// CHECK:           %[[VAL_1:.*]] = fir.address_of(@_QFtestEx) : !fir.ref<!fir.array<10xf32>>
+// CHECK:           %[[VAL_2:.*]] = fir.shape %[[VAL_0]] : (index) -> !fir.shape<1>
+// CHECK:           %[[VAL_3:.*]] = hlfir.eval_in_mem shape %[[VAL_2]] : (!fir.shape<1>) -> !hlfir.expr<10xf32> {
+// CHECK:           ^bb0(%[[VAL_4:.*]]: !fir.ref<!fir.array<10xf32>>):
+// CHECK:             %[[VAL_5:.*]] = fir.call @_QParray_func() fastmath<contract> : () -> !fir.array<10xf32>
+// CHECK:             fir.save_result %[[VAL_5]] to %[[VAL_4]](%[[VAL_2]]) : !fir.array<10xf32>, !fir.ref<!fir.array<10xf32>>, !fir.shape<1>
+// CHECK:           }
+// CHECK:           hlfir.assign %[[VAL_3]] to %[[VAL_1]] : !hlfir.expr<10xf32>, !fir.ref<!fir.array<10xf32>>
+// CHECK:           hlfir.destroy %[[VAL_3]] : !hlfir.expr<10xf32>
+// CHECK:           return
+// CHECK:         }
+// CHECK:         fir.global internal @_QFtestEx : !fir.array<10xf32>
+// CHECK:         func.func private @_QParray_func() -> !fir.array<10xf32>
+
+// CHECK-LABEL:   func.func @_QPtest_char() {
+// CHECK:           %[[VAL_0:.*]] = arith.constant 10 : index
+// CHECK:           %[[VAL_1:.*]] = arith.constant 5 : index
+// CHECK:           %[[VAL_2:.*]] = fir.address_of(@_QFtest_charEx) : !fir.ref<!fir.array<10x!fir.char<1,5>>>
+// CHECK:           %[[VAL_3:.*]] = fir.shape %[[VAL_0]] : (index) -> !fir.shape<1>
+// CHECK:           %[[VAL_4:.*]] = hlfir.eval_in_mem shape %[[VAL_3]] typeparams %[[VAL_1]] : (!fir.shape<1>, index) -> !hlfir.expr<10x!fir.char<1,5>> {
+// CHECK:           ^bb0(%[[VAL_5:.*]]: !fir.ref<!fir.array<10x!fir.char<1,5>>>):
+// CHECK:             %[[VAL_6:.*]] = fir.call @_QPchar_array_func() fastmath<contract> : () -> !fir.array<10x!fir.char<1,5>>
+// CHECK:             fir.save_result %[[VAL_6]] to %[[VAL_5]](%[[VAL_3]]) typeparams %[[VAL_1]] : !fir.array<10x!fir.char<1,5>>, !fir.ref<!fir.array<10x!fir.char<1,5>>>, !fir.shape<1>, index
+// CHECK:           }
+// CHECK:           hlfir.assign %[[VAL_4]] to %[[VAL_2]] : !hlfir.expr<10x!fir.char<1,5>>, !fir.ref<!fir.array<10x!fir.char<1,5>>>
+// CHECK:           hlfir.destroy %[[VAL_4]] : !hlfir.expr<10x!fir.char<1,5>>
+// CHECK:           return
+// CHECK:         }
+// CHECK:         fir.global internal @_QFtest_charEx : !fir.array<10x!fir.char<1,5>>
+// CHECK:         func.func private @_QPchar_array_func() -> !fir.array<10x!fir.char<1,5>>
+
+// CHECK-LABEL:   func.func @test_dynamic(
+// CHECK-SAME:                            %[[VAL_0:.*]]: !fir.box<!fir.array<?xf32>>,
+// CHECK-SAME:                            %[[VAL_1:.*]]: index) {
+// CHECK:           %[[VAL_2:.*]] = fir.shape %[[VAL_1]] : (index) -> !fir.shape<1>
+// CHECK:           %[[VAL_3:.*]] = hlfir.eval_in_mem shape %[[VAL_2]] : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
+// CHECK:           ^bb0(%[[VAL_4:.*]]: !fir.ref<!fir.array<?xf32>>):
+// CHECK:             %[[VAL_5:.*]] = fir.call @_QPdyn_array_func(%[[VAL_1]]) : (index) -> !fir.array<?xf32>
+// CHECK:             fir.save_result %[[VAL_5]] to %[[VAL_4]](%[[VAL_2]]) : !fir.array<?xf32>, !fir.ref<!fir.array<?xf32>>, !fir.shape<1>
+// CHECK:           }
+// CHECK:           hlfir.assign %[[VAL_3]] to %[[VAL_0]] : !hlfir.expr<?xf32>, !fir.box<!fir.array<?xf32>>
+// CHECK:           hlfir.destroy %[[VAL_3]] : !hlfir.expr<?xf32>
+// CHECK:           return
+// CHECK:         }
+// CHECK:         func.func private @_QPdyn_array_func(index) -> !fir.array<?xf32>
diff --git a/flang/test/HLFIR/invalid.fir b/flang/test/HLFIR/invalid.fir
index c390dddcf3f387..5c5db7aac06970 100644
--- a/flang/test/HLFIR/invalid.fir
+++ b/flang/test/HLFIR/invalid.fir
@@ -1314,3 +1314,37 @@ func.func @end_associate_with_alloc_comp(%var: !hlfir.expr<?x!fir.type<_QMtypesT
   hlfir.end_associate %4#1, %4#2 : !fir.ref<!fir.array<?x!fir.type<_QMtypesTt{x:!fir.box<!fir.heap<f32>>}>>>, i1
   return
 }
+
+// -----
+
+func.func @bad_eval_in_mem_1() {
+  %c10 = arith.constant 10 : index
+  %1 = fir.shape %c10 : (index) -> !fir.shape<1>
+  // expected-error@+1 {{'hlfir.eval_in_mem' op result #0 must be The type of an array, character, or derived type Fortran expression, but got '!fir.array<10xf32>'}}
+  %2 = hlfir.eval_in_mem shape %1 : (!fir.shape<1>) -> !fir.array<10xf32> {
+  ^bb0(%arg0: !fir.ref<!fir.array<10xf32>>):
+  }
+  return
+}
+
+// -----
+
+func.func @bad_eval_in_mem_2() {
+  %c10 = arith.constant 10 : index
+  %1 = fir.shape %c10, %c10 : (index, index) -> !fir.shape<2>
+  // expected-error@+1 {{'hlfir.eval_in_mem' op `shape` rank must match the result rank}}
+  %2 = hlfir.eval_in_mem shape %1 : (!fir.shape<2>) -> !hlfir.expr<10xf32> {
+  ^bb0(%arg0: !fir.ref<!fir.array<10xf32>>):
+  }
+  return
+}
+
+// -----
+
+func.func @bad_eval_in_mem_3() {
+  // expected-error@+1 {{'hlfir.eval_in_mem' op must be provided one length parameter when the result is a character}}
+  %1 = hlfir.eval_in_mem  : () -> !hlfir.expr<!fir.char<1,?>> {
+  ^bb0(%arg0: !fir.ref<!fir.char<1,?>>):
+  }
+  return
+}



More information about the flang-commits mailing list