[flang-commits] [flang] [flang][hlfir] optimize hlfir.eval_in_mem bufferization (PR #118069)

via flang-commits flang-commits at lists.llvm.org
Mon Dec 2 02:02:27 PST 2024


https://github.com/jeanPerier updated https://github.com/llvm/llvm-project/pull/118069

>From 5d43a6bd4fb27015e6917eb471bc725fe31b9ac2 Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Thu, 28 Nov 2024 08:25:05 -0800
Subject: [PATCH 1/3] [flang][hlfir] add hlfir.eval_in_mem operation

---
 .../flang/Optimizer/Builder/HLFIRTools.h      |  19 ++++
 .../include/flang/Optimizer/HLFIR/HLFIROps.td |  59 ++++++++++
 flang/lib/Optimizer/Builder/HLFIRTools.cpp    |  47 ++++++++
 flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp     |  76 ++++++++++---
 .../HLFIR/Transforms/BufferizeHLFIR.cpp       |  33 +++++-
 flang/test/HLFIR/eval_in_mem-codegen.fir      | 107 ++++++++++++++++++
 flang/test/HLFIR/eval_in_mem.fir              |  99 ++++++++++++++++
 flang/test/HLFIR/invalid.fir                  |  34 ++++++
 8 files changed, 454 insertions(+), 20 deletions(-)
 create mode 100644 flang/test/HLFIR/eval_in_mem-codegen.fir
 create mode 100644 flang/test/HLFIR/eval_in_mem.fir

diff --git a/flang/include/flang/Optimizer/Builder/HLFIRTools.h b/flang/include/flang/Optimizer/Builder/HLFIRTools.h
index f073f494b3fb21..efbd9e4f50d432 100644
--- a/flang/include/flang/Optimizer/Builder/HLFIRTools.h
+++ b/flang/include/flang/Optimizer/Builder/HLFIRTools.h
@@ -33,6 +33,7 @@ class AssociateOp;
 class ElementalOp;
 class ElementalOpInterface;
 class ElementalAddrOp;
+class EvaluateInMemoryOp;
 class YieldElementOp;
 
 /// Is this a Fortran variable for which the defining op carrying the Fortran
@@ -398,6 +399,24 @@ mlir::Value inlineElementalOp(
     mlir::IRMapping &mapper,
     const std::function<bool(hlfir::ElementalOp)> &mustRecursivelyInline);
 
+/// Create a new temporary with the shape and type parameters of the provided
+/// hlfir.eval_in_mem operation and clone the body of the hlfir.eval_in_mem so
+/// that it operates on this new temporary. Returns the temporary and a boolean
+/// that is true if the temporary is heap allocated (false if stack allocated).
+std::pair<hlfir::Entity, bool>
+computeEvaluateOpInNewTemp(mlir::Location, fir::FirOpBuilder &,
+                           hlfir::EvaluateInMemoryOp evalInMem,
+                           mlir::Value shape, mlir::ValueRange typeParams);
+
+/// Clone the body of the hlfir.eval_in_mem so that it operates on the provided
+/// storage. The provided storage must be a contiguous "raw" memory reference
+/// (not a fir.box) big enough to hold the value computed by hlfir.eval_in_mem.
+/// No runtime check is inserted by this utility to enforce that. It is also
+/// usually invalid to provide storage that is already addressed directly
+/// or indirectly inside the hlfir.eval_in_mem body.
+void computeEvaluateOpIn(mlir::Location, fir::FirOpBuilder &,
+                         hlfir::EvaluateInMemoryOp, mlir::Value storage);
+
 std::pair<fir::ExtendedValue, std::optional<hlfir::CleanupFunction>>
 convertToValue(mlir::Location loc, fir::FirOpBuilder &builder,
                hlfir::Entity entity);
diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
index 1ab8793f726523..a9826543f48b69 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
@@ -1755,4 +1755,63 @@ def hlfir_CharExtremumOp : hlfir_Op<"char_extremum",
   let hasVerifier = 1;
 }
 
+def hlfir_EvaluateInMemoryOp : hlfir_Op<"eval_in_mem", [AttrSizedOperandSegments,
+    RecursiveMemoryEffects, RecursivelySpeculatable,
+    SingleBlockImplicitTerminator<"fir::FirEndOp">]> {
+  let summary = "Wrap an in-memory implementation that computes expression value";
+  let description = [{
+    Returns a Fortran expression value for which the computation is
+    implemented inside the region operating on the block argument which
+    is a raw memory reference corresponding to the expression type.
+
+    The shape and type parameters of the expression are operands of the
+    operation.
+
+    The memory cannot escape the region, and how it is allocated is not
+    described. This facilitates later elision of the temporary storage for the
+    expression evaluation if it can be evaluated in some other storage (like a
+    left-hand side variable).
+
+    Example:
+
+    A function returning an array can be represented as:
+    ```
+      %1 = fir.shape %c10 : (index) -> !fir.shape<1>
+      %2 = hlfir.eval_in_mem shape %1 : (!fir.shape<1>) -> !hlfir.expr<10xf32> {
+      ^bb0(%arg0: !fir.ref<!fir.array<10xf32>>):
+        %3 = fir.call @_QParray_func() fastmath<contract> : () -> !fir.array<10xf32>
+        fir.save_result %3 to %arg0(%1) : !fir.array<10xf32>, !fir.ref<!fir.array<10xf32>>, !fir.shape<1>
+      }
+    ```
+  }];
+
+  let arguments = (ins
+    Optional<fir_ShapeType>:$shape,
+    Variadic<AnyIntegerType>:$typeparams
+  );
+
+  let results = (outs hlfir_ExprType);
+  let regions = (region  SizedRegion<1>:$body);
+
+  let assemblyFormat = [{
+    (`shape` $shape^)? (`typeparams` $typeparams^)?
+    attr-dict `:` functional-type(operands, results)
+    $body}];
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<(ins "mlir::Type":$result_type, "mlir::Value":$shape,
+      CArg<"mlir::ValueRange", "{}">:$typeparams)>
+  ];
+
+  let extraClassDeclaration = [{
+      // Return the body block argument representing the memory where the
+      // expression is evaluated.
+      mlir::Value getMemory() {return getBody().getArgument(0);}
+  }];
+
+  let hasVerifier = 1;
+}
+
+
 #endif // FORTRAN_DIALECT_HLFIR_OPS
diff --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
index 7425ccf7fc0e30..1bd950f2445ee4 100644
--- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp
+++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
@@ -535,6 +535,8 @@ static mlir::Value tryRetrievingShapeOrShift(hlfir::Entity entity) {
   if (mlir::isa<hlfir::ExprType>(entity.getType())) {
     if (auto elemental = entity.getDefiningOp<hlfir::ElementalOp>())
       return elemental.getShape();
+    if (auto evalInMem = entity.getDefiningOp<hlfir::EvaluateInMemoryOp>())
+      return evalInMem.getShape();
     return mlir::Value{};
   }
   if (auto varIface = entity.getIfVariableInterface())
@@ -642,6 +644,11 @@ void hlfir::genLengthParameters(mlir::Location loc, fir::FirOpBuilder &builder,
       result.append(elemental.getTypeparams().begin(),
                     elemental.getTypeparams().end());
       return;
+    } else if (auto evalInMem =
+                   expr.getDefiningOp<hlfir::EvaluateInMemoryOp>()) {
+      result.append(evalInMem.getTypeparams().begin(),
+                    evalInMem.getTypeparams().end());
+      return;
     } else if (auto apply = expr.getDefiningOp<hlfir::ApplyOp>()) {
       result.append(apply.getTypeparams().begin(), apply.getTypeparams().end());
       return;
@@ -1313,3 +1320,43 @@ hlfir::genTypeAndKindConvert(mlir::Location loc, fir::FirOpBuilder &builder,
   };
   return {hlfir::Entity{convertedRhs}, cleanup};
 }
+
+std::pair<hlfir::Entity, bool> hlfir::computeEvaluateOpInNewTemp(
+    mlir::Location loc, fir::FirOpBuilder &builder,
+    hlfir::EvaluateInMemoryOp evalInMem, mlir::Value shape,
+    mlir::ValueRange typeParams) {
+  llvm::StringRef tmpName{".tmp.expr_result"}; // alloc and declare name
+  llvm::SmallVector<mlir::Value> extents =
+      hlfir::getIndexExtents(loc, builder, shape);
+  mlir::Type baseType =
+      hlfir::getFortranElementOrSequenceType(evalInMem.getType());
+  bool heapAllocated = fir::hasDynamicSize(baseType);
+  // Dynamic-size temporaries are heap allocated; the others are stack
+  // allocated (they do not require stack save/restore) because flang has
+  // always stack allocated function results.
+  mlir::Value temp = heapAllocated
+                         ? builder.createHeapTemporary(loc, baseType, tmpName,
+                                                       extents, typeParams)
+                         : builder.createTemporary(loc, baseType, tmpName,
+                                                   extents, typeParams);
+  mlir::Value innerMemory = evalInMem.getMemory(); // region block argument
+  temp = builder.createConvert(loc, innerMemory.getType(), temp);
+  auto declareOp = builder.create<hlfir::DeclareOp>(
+      loc, temp, tmpName, shape, typeParams,
+      /*dummy_scope=*/nullptr, fir::FortranVariableFlagsAttr{});
+  computeEvaluateOpIn(loc, builder, evalInMem, declareOp.getOriginalBase());
+  return {hlfir::Entity{declareOp.getBase()}, /*heapAllocated=*/heapAllocated};
+}
+
+void hlfir::computeEvaluateOpIn(mlir::Location loc, fir::FirOpBuilder &builder,
+                                hlfir::EvaluateInMemoryOp evalInMem,
+                                mlir::Value storage) {
+  mlir::Value innerMemory = evalInMem.getMemory();
+  mlir::Value storageCast =
+      builder.createConvert(loc, innerMemory.getType(), storage);
+  mlir::IRMapping mapper;
+  mapper.map(innerMemory, storageCast);
+  // Clone the body at the insertion point; the fir.end terminator is skipped.
+  for (auto &op : evalInMem.getBody().front().without_terminator())
+    builder.clone(op, mapper);
+}
diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
index b593383ff2848d..87519882446485 100644
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
@@ -333,6 +333,25 @@ static void printDesignatorComplexPart(mlir::OpAsmPrinter &p,
       p << "real";
   }
 }
+template <typename Op> // Used by the designate and eval_in_mem verifiers.
+static llvm::LogicalResult verifyTypeparams(Op &op, mlir::Type elementType,
+                                            unsigned numLenParam) {
+  if (mlir::isa<fir::CharacterType>(elementType)) {
+    if (numLenParam != 1)
+      return op.emitOpError("must be provided one length parameter when the "
+                            "result is a character");
+  } else if (fir::isRecordWithTypeParameters(elementType)) {
+    if (numLenParam !=
+        mlir::cast<fir::RecordType>(elementType).getNumLenParams())
+      return op.emitOpError("must be provided the same number of length "
+                            "parameters as in the result derived type");
+  } else if (numLenParam != 0) { // No other element type takes length params.
+    return op.emitOpError(
+        "must not be provided length parameters if the result "
+        "type does not have length parameters");
+  }
+  return mlir::success();
+}
 
 llvm::LogicalResult hlfir::DesignateOp::verify() {
   mlir::Type memrefType = getMemref().getType();
@@ -462,20 +481,10 @@ llvm::LogicalResult hlfir::DesignateOp::verify() {
         return emitOpError("shape must be a fir.shape or fir.shapeshift with "
                            "the rank of the result");
     }
-    auto numLenParam = getTypeparams().size();
-    if (mlir::isa<fir::CharacterType>(outputElementType)) {
-      if (numLenParam != 1)
-        return emitOpError("must be provided one length parameter when the "
-                           "result is a character");
-    } else if (fir::isRecordWithTypeParameters(outputElementType)) {
-      if (numLenParam !=
-          mlir::cast<fir::RecordType>(outputElementType).getNumLenParams())
-        return emitOpError("must be provided the same number of length "
-                           "parameters as in the result derived type");
-    } else if (numLenParam != 0) {
-      return emitOpError("must not be provided length parameters if the result "
-                         "type does not have length parameters");
-    }
+    if (auto res =
+            verifyTypeparams(*this, outputElementType, getTypeparams().size());
+        failed(res))
+      return res;
   }
   return mlir::success();
 }
@@ -1989,6 +1998,45 @@ hlfir::GetLengthOp::canonicalize(GetLengthOp getLength,
   return mlir::success();
 }
 
+//===----------------------------------------------------------------------===//
+// EvaluateInMemoryOp
+//===----------------------------------------------------------------------===//
+
+void hlfir::EvaluateInMemoryOp::build(mlir::OpBuilder &builder,
+                                      mlir::OperationState &odsState,
+                                      mlir::Type resultType, mlir::Value shape,
+                                      mlir::ValueRange typeparams) {
+  odsState.addTypes(resultType);
+  if (shape) // Optional operand: omitted when the result is a scalar.
+    odsState.addOperands(shape);
+  odsState.addOperands(typeparams);
+  odsState.addAttribute( // Operand segment sizes: {shape, typeparams}.
+      getOperandSegmentSizeAttr(),
+      builder.getDenseI32ArrayAttr(
+          {shape ? 1 : 0, static_cast<int32_t>(typeparams.size())}));
+  mlir::Region *bodyRegion = odsState.addRegion(); // single-block body
+  bodyRegion->push_back(new mlir::Block{});
+  mlir::Type memType = fir::ReferenceType::get(
+      hlfir::getFortranElementOrSequenceType(resultType));
+  bodyRegion->front().addArgument(memType, odsState.location); // memory arg
+  EvaluateInMemoryOp::ensureTerminator(*bodyRegion, builder, odsState.location);
+}
+
+llvm::LogicalResult hlfir::EvaluateInMemoryOp::verify() {
+  unsigned shapeRank = 0; // The shape operand is absent for scalar results.
+  if (mlir::Value shape = getShape())
+    shapeRank = mlir::cast<fir::ShapeType>(shape.getType()).getRank();
+  auto exprType = mlir::cast<hlfir::ExprType>(getResult().getType());
+  if (shapeRank != exprType.getRank())
+    return emitOpError("`shape` rank must match the result rank");
+  if (getBody().getNumArguments() != 1) // getMemory() reads argument #0.
+    return emitOpError("body must have exactly one block argument");
+  if (failed(verifyTypeparams(*this, exprType.getElementType(),
+                              getTypeparams().size())))
+    return mlir::failure();
+  return mlir::success();
+}
+
 #include "flang/Optimizer/HLFIR/HLFIROpInterfaces.cpp.inc"
 #define GET_OP_CLASSES
 #include "flang/Optimizer/HLFIR/HLFIREnums.cpp.inc"
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
index 1848dbe2c7a2c2..347f0a5630777f 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
@@ -905,6 +905,26 @@ struct CharExtremumOpConversion
   }
 };
 
+struct EvaluateInMemoryOpConversion // Default eval_in_mem bufferization.
+    : public mlir::OpConversionPattern<hlfir::EvaluateInMemoryOp> {
+  using mlir::OpConversionPattern<
+      hlfir::EvaluateInMemoryOp>::OpConversionPattern;
+  explicit EvaluateInMemoryOpConversion(mlir::MLIRContext *ctx)
+      : mlir::OpConversionPattern<hlfir::EvaluateInMemoryOp>{ctx} {}
+  llvm::LogicalResult
+  matchAndRewrite(hlfir::EvaluateInMemoryOp evalInMemOp, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    mlir::Location loc = evalInMemOp->getLoc();
+    fir::FirOpBuilder builder(rewriter, evalInMemOp.getOperation());
+    auto [temp, isHeapAlloc] = hlfir::computeEvaluateOpInNewTemp(
+        loc, builder, evalInMemOp, adaptor.getShape(), adaptor.getTypeparams());
+    mlir::Value bufferizedExpr =
+        packageBufferizedExpr(loc, builder, temp, isHeapAlloc);
+    rewriter.replaceOp(evalInMemOp, bufferizedExpr); // (temp, isHeap) tuple
+    return mlir::success();
+  }
+};
+
 class BufferizeHLFIR : public hlfir::impl::BufferizeHLFIRBase<BufferizeHLFIR> {
 public:
   void runOnOperation() override {
@@ -918,12 +938,13 @@ class BufferizeHLFIR : public hlfir::impl::BufferizeHLFIRBase<BufferizeHLFIR> {
     auto module = this->getOperation();
     auto *context = &getContext();
     mlir::RewritePatternSet patterns(context);
-    patterns.insert<ApplyOpConversion, AsExprOpConversion, AssignOpConversion,
-                    AssociateOpConversion, CharExtremumOpConversion,
-                    ConcatOpConversion, DestroyOpConversion,
-                    ElementalOpConversion, EndAssociateOpConversion,
-                    NoReassocOpConversion, SetLengthOpConversion,
-                    ShapeOfOpConversion, GetLengthOpConversion>(context);
+    patterns
+        .insert<ApplyOpConversion, AsExprOpConversion, AssignOpConversion,
+                AssociateOpConversion, CharExtremumOpConversion,
+                ConcatOpConversion, DestroyOpConversion, ElementalOpConversion,
+                EndAssociateOpConversion, EvaluateInMemoryOpConversion,
+                NoReassocOpConversion, SetLengthOpConversion,
+                ShapeOfOpConversion, GetLengthOpConversion>(context);
     mlir::ConversionTarget target(*context);
     // Note that YieldElementOp is not marked as an illegal operation.
     // It must be erased by its parent converter and there is no explicit
diff --git a/flang/test/HLFIR/eval_in_mem-codegen.fir b/flang/test/HLFIR/eval_in_mem-codegen.fir
new file mode 100644
index 00000000000000..26a989832ca927
--- /dev/null
+++ b/flang/test/HLFIR/eval_in_mem-codegen.fir
@@ -0,0 +1,107 @@
+// Test hlfir.eval_in_mem default code generation.
+
+// RUN: fir-opt %s --bufferize-hlfir -o - | FileCheck %s
+
+func.func @_QPtest() {
+  %c10 = arith.constant 10 : index
+  %0 = fir.address_of(@_QFtestEx) : !fir.ref<!fir.array<10xf32>>
+  %1 = fir.shape %c10 : (index) -> !fir.shape<1>
+  %2 = hlfir.eval_in_mem shape %1 : (!fir.shape<1>) -> !hlfir.expr<10xf32> {
+  ^bb0(%arg0: !fir.ref<!fir.array<10xf32>>):
+    %3 = fir.call @_QParray_func() fastmath<contract> : () -> !fir.array<10xf32>
+    fir.save_result %3 to %arg0(%1) : !fir.array<10xf32>, !fir.ref<!fir.array<10xf32>>, !fir.shape<1>
+  }
+  hlfir.assign %2 to %0 : !hlfir.expr<10xf32>, !fir.ref<!fir.array<10xf32>>
+  hlfir.destroy %2 : !hlfir.expr<10xf32>
+  return
+}
+fir.global internal @_QFtestEx : !fir.array<10xf32>
+func.func private @_QParray_func() -> !fir.array<10xf32>
+
+
+func.func @_QPtest_char() {
+  %c10 = arith.constant 10 : index
+  %c5 = arith.constant 5 : index
+  %0 = fir.address_of(@_QFtest_charEx) : !fir.ref<!fir.array<10x!fir.char<1,5>>>
+  %1 = fir.shape %c10 : (index) -> !fir.shape<1>
+  %2 = hlfir.eval_in_mem shape %1 typeparams %c5 : (!fir.shape<1>, index) -> !hlfir.expr<10x!fir.char<1,5>> {
+  ^bb0(%arg0: !fir.ref<!fir.array<10x!fir.char<1,5>>>):
+    %3 = fir.call @_QPchar_array_func() fastmath<contract> : () -> !fir.array<10x!fir.char<1,5>>
+    fir.save_result %3 to %arg0(%1) typeparams %c5 : !fir.array<10x!fir.char<1,5>>, !fir.ref<!fir.array<10x!fir.char<1,5>>>, !fir.shape<1>, index
+  }
+  hlfir.assign %2 to %0 : !hlfir.expr<10x!fir.char<1,5>>, !fir.ref<!fir.array<10x!fir.char<1,5>>>
+  hlfir.destroy %2 : !hlfir.expr<10x!fir.char<1,5>>
+  return
+}
+
+fir.global internal @_QFtest_charEx : !fir.array<10x!fir.char<1,5>>
+func.func private @_QPchar_array_func() -> !fir.array<10x!fir.char<1,5>>
+
+func.func @test_dynamic(%arg0: !fir.box<!fir.array<?xf32>>, %arg1: index) {
+   %0 = fir.shape %arg1 : (index) -> !fir.shape<1>
+   %1 = hlfir.eval_in_mem shape %0 : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
+   ^bb0(%arg2: !fir.ref<!fir.array<?xf32>>):
+     %2 = fir.call @_QPdyn_array_func(%arg1) : (index) -> !fir.array<?xf32>
+     fir.save_result %2 to %arg2(%0) : !fir.array<?xf32>, !fir.ref<!fir.array<?xf32>>, !fir.shape<1>
+   }
+   hlfir.assign %1 to %arg0 : !hlfir.expr<?xf32>, !fir.box<!fir.array<?xf32>>
+   hlfir.destroy %1 : !hlfir.expr<?xf32>
+   return
+}
+func.func private @_QPdyn_array_func(index) -> !fir.array<?xf32>
+
+// CHECK-LABEL:   func.func @_QPtest() {
+// CHECK:           %[[VAL_0:.*]] = fir.alloca !fir.array<10xf32> {bindc_name = ".tmp.expr_result"}
+// CHECK:           %[[VAL_1:.*]] = arith.constant 10 : index
+// CHECK:           %[[VAL_2:.*]] = fir.address_of(@_QFtestEx) : !fir.ref<!fir.array<10xf32>>
+// CHECK:           %[[VAL_3:.*]] = fir.shape %[[VAL_1]] : (index) -> !fir.shape<1>
+// CHECK:           %[[VAL_4:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_3]]) {uniq_name = ".tmp.expr_result"} : (!fir.ref<!fir.array<10xf32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xf32>>, !fir.ref<!fir.array<10xf32>>)
+// CHECK:           %[[VAL_5:.*]] = fir.call @_QParray_func() fastmath<contract> : () -> !fir.array<10xf32>
+// CHECK:           fir.save_result %[[VAL_5]] to %[[VAL_4]]#1(%[[VAL_3]]) : !fir.array<10xf32>, !fir.ref<!fir.array<10xf32>>, !fir.shape<1>
+// CHECK:           %[[VAL_6:.*]] = arith.constant false
+// CHECK:           %[[VAL_7:.*]] = fir.undefined tuple<!fir.ref<!fir.array<10xf32>>, i1>
+// CHECK:           %[[VAL_8:.*]] = fir.insert_value %[[VAL_7]], %[[VAL_6]], [1 : index] : (tuple<!fir.ref<!fir.array<10xf32>>, i1>, i1) -> tuple<!fir.ref<!fir.array<10xf32>>, i1>
+// CHECK:           %[[VAL_9:.*]] = fir.insert_value %[[VAL_8]], %[[VAL_4]]#0, [0 : index] : (tuple<!fir.ref<!fir.array<10xf32>>, i1>, !fir.ref<!fir.array<10xf32>>) -> tuple<!fir.ref<!fir.array<10xf32>>, i1>
+// CHECK:           hlfir.assign %[[VAL_4]]#0 to %[[VAL_2]] : !fir.ref<!fir.array<10xf32>>, !fir.ref<!fir.array<10xf32>>
+// CHECK:           return
+// CHECK:         }
+// CHECK:         fir.global internal @_QFtestEx : !fir.array<10xf32>
+// CHECK:         func.func private @_QParray_func() -> !fir.array<10xf32>
+
+// CHECK-LABEL:   func.func @_QPtest_char() {
+// CHECK:           %[[VAL_0:.*]] = fir.alloca !fir.array<10x!fir.char<1,5>> {bindc_name = ".tmp.expr_result"}
+// CHECK:           %[[VAL_1:.*]] = arith.constant 10 : index
+// CHECK:           %[[VAL_2:.*]] = arith.constant 5 : index
+// CHECK:           %[[VAL_3:.*]] = fir.address_of(@_QFtest_charEx) : !fir.ref<!fir.array<10x!fir.char<1,5>>>
+// CHECK:           %[[VAL_4:.*]] = fir.shape %[[VAL_1]] : (index) -> !fir.shape<1>
+// CHECK:           %[[VAL_5:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_4]]) typeparams %[[VAL_2]] {uniq_name = ".tmp.expr_result"} : (!fir.ref<!fir.array<10x!fir.char<1,5>>>, !fir.shape<1>, index) -> (!fir.ref<!fir.array<10x!fir.char<1,5>>>, !fir.ref<!fir.array<10x!fir.char<1,5>>>)
+// CHECK:           %[[VAL_6:.*]] = fir.call @_QPchar_array_func() fastmath<contract> : () -> !fir.array<10x!fir.char<1,5>>
+// CHECK:           fir.save_result %[[VAL_6]] to %[[VAL_5]]#1(%[[VAL_4]]) typeparams %[[VAL_2]] : !fir.array<10x!fir.char<1,5>>, !fir.ref<!fir.array<10x!fir.char<1,5>>>, !fir.shape<1>, index
+// CHECK:           %[[VAL_7:.*]] = arith.constant false
+// CHECK:           %[[VAL_8:.*]] = fir.undefined tuple<!fir.ref<!fir.array<10x!fir.char<1,5>>>, i1>
+// CHECK:           %[[VAL_9:.*]] = fir.insert_value %[[VAL_8]], %[[VAL_7]], [1 : index] : (tuple<!fir.ref<!fir.array<10x!fir.char<1,5>>>, i1>, i1) -> tuple<!fir.ref<!fir.array<10x!fir.char<1,5>>>, i1>
+// CHECK:           %[[VAL_10:.*]] = fir.insert_value %[[VAL_9]], %[[VAL_5]]#0, [0 : index] : (tuple<!fir.ref<!fir.array<10x!fir.char<1,5>>>, i1>, !fir.ref<!fir.array<10x!fir.char<1,5>>>) -> tuple<!fir.ref<!fir.array<10x!fir.char<1,5>>>, i1>
+// CHECK:           hlfir.assign %[[VAL_5]]#0 to %[[VAL_3]] : !fir.ref<!fir.array<10x!fir.char<1,5>>>, !fir.ref<!fir.array<10x!fir.char<1,5>>>
+// CHECK:           return
+// CHECK:         }
+// CHECK:         fir.global internal @_QFtest_charEx : !fir.array<10x!fir.char<1,5>>
+// CHECK:         func.func private @_QPchar_array_func() -> !fir.array<10x!fir.char<1,5>>
+
+// CHECK-LABEL:   func.func @test_dynamic(
+// CHECK-SAME:                            %[[VAL_0:.*]]: !fir.box<!fir.array<?xf32>>,
+// CHECK-SAME:                            %[[VAL_1:.*]]: index) {
+// CHECK:           %[[VAL_2:.*]] = fir.shape %[[VAL_1]] : (index) -> !fir.shape<1>
+// CHECK:           %[[VAL_3:.*]] = fir.allocmem !fir.array<?xf32>, %[[VAL_1]] {bindc_name = ".tmp.expr_result", uniq_name = ""}
+// CHECK:           %[[VAL_4:.*]] = fir.convert %[[VAL_3]] : (!fir.heap<!fir.array<?xf32>>) -> !fir.ref<!fir.array<?xf32>>
+// CHECK:           %[[VAL_5:.*]]:2 = hlfir.declare %[[VAL_4]](%[[VAL_2]]) {uniq_name = ".tmp.expr_result"} : (!fir.ref<!fir.array<?xf32>>, !fir.shape<1>) -> (!fir.box<!fir.array<?xf32>>, !fir.ref<!fir.array<?xf32>>)
+// CHECK:           %[[VAL_6:.*]] = fir.call @_QPdyn_array_func(%[[VAL_1]]) : (index) -> !fir.array<?xf32>
+// CHECK:           fir.save_result %[[VAL_6]] to %[[VAL_5]]#1(%[[VAL_2]]) : !fir.array<?xf32>, !fir.ref<!fir.array<?xf32>>, !fir.shape<1>
+// CHECK:           %[[VAL_7:.*]] = arith.constant true
+// CHECK:           %[[VAL_8:.*]] = fir.undefined tuple<!fir.box<!fir.array<?xf32>>, i1>
+// CHECK:           %[[VAL_9:.*]] = fir.insert_value %[[VAL_8]], %[[VAL_7]], [1 : index] : (tuple<!fir.box<!fir.array<?xf32>>, i1>, i1) -> tuple<!fir.box<!fir.array<?xf32>>, i1>
+// CHECK:           %[[VAL_10:.*]] = fir.insert_value %[[VAL_9]], %[[VAL_5]]#0, [0 : index] : (tuple<!fir.box<!fir.array<?xf32>>, i1>, !fir.box<!fir.array<?xf32>>) -> tuple<!fir.box<!fir.array<?xf32>>, i1>
+// CHECK:           hlfir.assign %[[VAL_5]]#0 to %[[VAL_0]] : !fir.box<!fir.array<?xf32>>, !fir.box<!fir.array<?xf32>>
+// CHECK:           %[[VAL_11:.*]] = fir.box_addr %[[VAL_5]]#0 : (!fir.box<!fir.array<?xf32>>) -> !fir.heap<!fir.array<?xf32>>
+// CHECK:           fir.freemem %[[VAL_11]] : !fir.heap<!fir.array<?xf32>>
+// CHECK:           return
+// CHECK:         }
diff --git a/flang/test/HLFIR/eval_in_mem.fir b/flang/test/HLFIR/eval_in_mem.fir
new file mode 100644
index 00000000000000..34e48ed5be5452
--- /dev/null
+++ b/flang/test/HLFIR/eval_in_mem.fir
@@ -0,0 +1,99 @@
+// Test hlfir.eval_in_mem operation parse, verify (no errors), and unparse.
+
+// RUN: fir-opt %s | fir-opt | FileCheck %s
+
+func.func @_QPtest() {
+  %c10 = arith.constant 10 : index
+  %0 = fir.address_of(@_QFtestEx) : !fir.ref<!fir.array<10xf32>>
+  %1 = fir.shape %c10 : (index) -> !fir.shape<1>
+  %2 = hlfir.eval_in_mem shape %1 : (!fir.shape<1>) -> !hlfir.expr<10xf32> {
+  ^bb0(%arg0: !fir.ref<!fir.array<10xf32>>):
+    %3 = fir.call @_QParray_func() fastmath<contract> : () -> !fir.array<10xf32>
+    fir.save_result %3 to %arg0(%1) : !fir.array<10xf32>, !fir.ref<!fir.array<10xf32>>, !fir.shape<1>
+  }
+  hlfir.assign %2 to %0 : !hlfir.expr<10xf32>, !fir.ref<!fir.array<10xf32>>
+  hlfir.destroy %2 : !hlfir.expr<10xf32>
+  return
+}
+fir.global internal @_QFtestEx : !fir.array<10xf32>
+func.func private @_QParray_func() -> !fir.array<10xf32>
+
+
+func.func @_QPtest_char() {
+  %c10 = arith.constant 10 : index
+  %c5 = arith.constant 5 : index
+  %0 = fir.address_of(@_QFtest_charEx) : !fir.ref<!fir.array<10x!fir.char<1,5>>>
+  %1 = fir.shape %c10 : (index) -> !fir.shape<1>
+  %2 = hlfir.eval_in_mem shape %1 typeparams %c5 : (!fir.shape<1>, index) -> !hlfir.expr<10x!fir.char<1,5>> {
+  ^bb0(%arg0: !fir.ref<!fir.array<10x!fir.char<1,5>>>):
+    %3 = fir.call @_QPchar_array_func() fastmath<contract> : () -> !fir.array<10x!fir.char<1,5>>
+    fir.save_result %3 to %arg0(%1) typeparams %c5 : !fir.array<10x!fir.char<1,5>>, !fir.ref<!fir.array<10x!fir.char<1,5>>>, !fir.shape<1>, index
+  }
+  hlfir.assign %2 to %0 : !hlfir.expr<10x!fir.char<1,5>>, !fir.ref<!fir.array<10x!fir.char<1,5>>>
+  hlfir.destroy %2 : !hlfir.expr<10x!fir.char<1,5>>
+  return
+}
+
+fir.global internal @_QFtest_charEx : !fir.array<10x!fir.char<1,5>>
+func.func private @_QPchar_array_func() -> !fir.array<10x!fir.char<1,5>>
+
+func.func @test_dynamic(%arg0: !fir.box<!fir.array<?xf32>>, %arg1: index) {
+   %0 = fir.shape %arg1 : (index) -> !fir.shape<1>
+   %1 = hlfir.eval_in_mem shape %0 : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
+   ^bb0(%arg2: !fir.ref<!fir.array<?xf32>>):
+     %2 = fir.call @_QPdyn_array_func(%arg1) : (index) -> !fir.array<?xf32>
+     fir.save_result %2 to %arg2(%0) : !fir.array<?xf32>, !fir.ref<!fir.array<?xf32>>, !fir.shape<1>
+   }
+   hlfir.assign %1 to %arg0 : !hlfir.expr<?xf32>, !fir.box<!fir.array<?xf32>>
+   hlfir.destroy %1 : !hlfir.expr<?xf32>
+   return
+}
+func.func private @_QPdyn_array_func(index) -> !fir.array<?xf32>
+
+// CHECK-LABEL:   func.func @_QPtest() {
+// CHECK:           %[[VAL_0:.*]] = arith.constant 10 : index
+// CHECK:           %[[VAL_1:.*]] = fir.address_of(@_QFtestEx) : !fir.ref<!fir.array<10xf32>>
+// CHECK:           %[[VAL_2:.*]] = fir.shape %[[VAL_0]] : (index) -> !fir.shape<1>
+// CHECK:           %[[VAL_3:.*]] = hlfir.eval_in_mem shape %[[VAL_2]] : (!fir.shape<1>) -> !hlfir.expr<10xf32> {
+// CHECK:           ^bb0(%[[VAL_4:.*]]: !fir.ref<!fir.array<10xf32>>):
+// CHECK:             %[[VAL_5:.*]] = fir.call @_QParray_func() fastmath<contract> : () -> !fir.array<10xf32>
+// CHECK:             fir.save_result %[[VAL_5]] to %[[VAL_4]](%[[VAL_2]]) : !fir.array<10xf32>, !fir.ref<!fir.array<10xf32>>, !fir.shape<1>
+// CHECK:           }
+// CHECK:           hlfir.assign %[[VAL_3]] to %[[VAL_1]] : !hlfir.expr<10xf32>, !fir.ref<!fir.array<10xf32>>
+// CHECK:           hlfir.destroy %[[VAL_3]] : !hlfir.expr<10xf32>
+// CHECK:           return
+// CHECK:         }
+// CHECK:         fir.global internal @_QFtestEx : !fir.array<10xf32>
+// CHECK:         func.func private @_QParray_func() -> !fir.array<10xf32>
+
+// CHECK-LABEL:   func.func @_QPtest_char() {
+// CHECK:           %[[VAL_0:.*]] = arith.constant 10 : index
+// CHECK:           %[[VAL_1:.*]] = arith.constant 5 : index
+// CHECK:           %[[VAL_2:.*]] = fir.address_of(@_QFtest_charEx) : !fir.ref<!fir.array<10x!fir.char<1,5>>>
+// CHECK:           %[[VAL_3:.*]] = fir.shape %[[VAL_0]] : (index) -> !fir.shape<1>
+// CHECK:           %[[VAL_4:.*]] = hlfir.eval_in_mem shape %[[VAL_3]] typeparams %[[VAL_1]] : (!fir.shape<1>, index) -> !hlfir.expr<10x!fir.char<1,5>> {
+// CHECK:           ^bb0(%[[VAL_5:.*]]: !fir.ref<!fir.array<10x!fir.char<1,5>>>):
+// CHECK:             %[[VAL_6:.*]] = fir.call @_QPchar_array_func() fastmath<contract> : () -> !fir.array<10x!fir.char<1,5>>
+// CHECK:             fir.save_result %[[VAL_6]] to %[[VAL_5]](%[[VAL_3]]) typeparams %[[VAL_1]] : !fir.array<10x!fir.char<1,5>>, !fir.ref<!fir.array<10x!fir.char<1,5>>>, !fir.shape<1>, index
+// CHECK:           }
+// CHECK:           hlfir.assign %[[VAL_4]] to %[[VAL_2]] : !hlfir.expr<10x!fir.char<1,5>>, !fir.ref<!fir.array<10x!fir.char<1,5>>>
+// CHECK:           hlfir.destroy %[[VAL_4]] : !hlfir.expr<10x!fir.char<1,5>>
+// CHECK:           return
+// CHECK:         }
+// CHECK:         fir.global internal @_QFtest_charEx : !fir.array<10x!fir.char<1,5>>
+// CHECK:         func.func private @_QPchar_array_func() -> !fir.array<10x!fir.char<1,5>>
+
+// CHECK-LABEL:   func.func @test_dynamic(
+// CHECK-SAME:                            %[[VAL_0:.*]]: !fir.box<!fir.array<?xf32>>,
+// CHECK-SAME:                            %[[VAL_1:.*]]: index) {
+// CHECK:           %[[VAL_2:.*]] = fir.shape %[[VAL_1]] : (index) -> !fir.shape<1>
+// CHECK:           %[[VAL_3:.*]] = hlfir.eval_in_mem shape %[[VAL_2]] : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
+// CHECK:           ^bb0(%[[VAL_4:.*]]: !fir.ref<!fir.array<?xf32>>):
+// CHECK:             %[[VAL_5:.*]] = fir.call @_QPdyn_array_func(%[[VAL_1]]) : (index) -> !fir.array<?xf32>
+// CHECK:             fir.save_result %[[VAL_5]] to %[[VAL_4]](%[[VAL_2]]) : !fir.array<?xf32>, !fir.ref<!fir.array<?xf32>>, !fir.shape<1>
+// CHECK:           }
+// CHECK:           hlfir.assign %[[VAL_3]] to %[[VAL_0]] : !hlfir.expr<?xf32>, !fir.box<!fir.array<?xf32>>
+// CHECK:           hlfir.destroy %[[VAL_3]] : !hlfir.expr<?xf32>
+// CHECK:           return
+// CHECK:         }
+// CHECK:         func.func private @_QPdyn_array_func(index) -> !fir.array<?xf32>
diff --git a/flang/test/HLFIR/invalid.fir b/flang/test/HLFIR/invalid.fir
index c390dddcf3f387..5c5db7aac06970 100644
--- a/flang/test/HLFIR/invalid.fir
+++ b/flang/test/HLFIR/invalid.fir
@@ -1314,3 +1314,37 @@ func.func @end_associate_with_alloc_comp(%var: !hlfir.expr<?x!fir.type<_QMtypesT
   hlfir.end_associate %4#1, %4#2 : !fir.ref<!fir.array<?x!fir.type<_QMtypesTt{x:!fir.box<!fir.heap<f32>>}>>>, i1
   return
 }
+
+// -----
+
+func.func @bad_eval_in_mem_1() {
+  %c10 = arith.constant 10 : index
+  %1 = fir.shape %c10 : (index) -> !fir.shape<1>
+  // expected-error@+1 {{'hlfir.eval_in_mem' op result #0 must be The type of an array, character, or derived type Fortran expression, but got '!fir.array<10xf32>'}}
+  %2 = hlfir.eval_in_mem shape %1 : (!fir.shape<1>) -> !fir.array<10xf32> {
+  ^bb0(%arg0: !fir.ref<!fir.array<10xf32>>):
+  }
+  return
+}
+
+// -----
+
+func.func @bad_eval_in_mem_2() {
+  %c10 = arith.constant 10 : index
+  %1 = fir.shape %c10, %c10 : (index, index) -> !fir.shape<2>
+  // expected-error@+1 {{'hlfir.eval_in_mem' op `shape` rank must match the result rank}}
+  %2 = hlfir.eval_in_mem shape %1 : (!fir.shape<2>) -> !hlfir.expr<10xf32> {
+  ^bb0(%arg0: !fir.ref<!fir.array<10xf32>>):
+  }
+  return
+}
+
+// -----
+
+func.func @bad_eval_in_mem_3() {
+  // expected-error@+1 {{'hlfir.eval_in_mem' op must be provided one length parameter when the result is a character}}
+  %1 = hlfir.eval_in_mem  : () -> !hlfir.expr<!fir.char<1,?>> {
+  ^bb0(%arg0: !fir.ref<!fir.char<1,?>>):
+  }
+  return
+}

>From 6be998ec6f74937820e350f51796490b41a76d46 Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Thu, 28 Nov 2024 08:26:39 -0800
Subject: [PATCH 2/3] [flang][hlfir] optimize hlfir.eval_in_mem bufferization

---
 .../lib/Optimizer/Analysis/AliasAnalysis.cpp  |  14 ++-
 .../Transforms/OptimizedBufferization.cpp     | 108 ++++++++++++++++++
 .../HLFIR/opt-bufferization-eval_in_mem.fir   |  67 +++++++++++
 3 files changed, 188 insertions(+), 1 deletion(-)
 create mode 100644 flang/test/HLFIR/opt-bufferization-eval_in_mem.fir

diff --git a/flang/lib/Optimizer/Analysis/AliasAnalysis.cpp b/flang/lib/Optimizer/Analysis/AliasAnalysis.cpp
index 2b24791d6c7c52..c561285b9feef5 100644
--- a/flang/lib/Optimizer/Analysis/AliasAnalysis.cpp
+++ b/flang/lib/Optimizer/Analysis/AliasAnalysis.cpp
@@ -91,6 +91,13 @@ bool AliasAnalysis::Source::isDummyArgument() const {
   return false;
 }
 
+/// Is \p v the memory argument of the hlfir.eval_in_mem owning its region?
+static bool isEvaluateInMemoryBlockArg(mlir::Value v) {
+  mlir::Operation *owner = v.getParentRegion()->getParentOp();
+  auto evalInMem = llvm::dyn_cast_or_null<hlfir::EvaluateInMemoryOp>(owner);
+  return evalInMem && evalInMem.getMemory() == v;
+}
+
 bool AliasAnalysis::Source::isData() const { return origin.isData; }
 bool AliasAnalysis::Source::isBoxData() const {
   return mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(valueType)) &&
@@ -698,7 +705,7 @@ AliasAnalysis::Source AliasAnalysis::getSource(mlir::Value v,
           breakFromLoop = true;
         });
   }
-  if (!defOp && type == SourceKind::Unknown)
+  if (!defOp && type == SourceKind::Unknown) {
     // Check if the memory source is coming through a dummy argument.
     if (isDummyArgument(v)) {
       type = SourceKind::Argument;
@@ -708,7 +715,12 @@ AliasAnalysis::Source AliasAnalysis::getSource(mlir::Value v,
 
       if (isPointerReference(ty))
         attributes.set(Attribute::Pointer);
+    } else if (isEvaluateInMemoryBlockArg(v)) {
+      // hlfir.eval_in_mem block operands is allocated by the operation.
+      type = SourceKind::Allocate;
+      ty = v.getType();
     }
+  }
 
   if (type == SourceKind::Global) {
     return {{global, instantiationPoint, followingData},
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
index a0160b233e3cd1..e8c15a256b9da0 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
@@ -1108,6 +1108,113 @@ class ReductionMaskConversion : public mlir::OpRewritePattern<Op> {
   }
 };
 
+class EvaluateIntoMemoryAssignBufferization
+    : public mlir::OpRewritePattern<hlfir::EvaluateInMemoryOp> {
+  // Bufferize eval_in_mem into its hlfir.assign LHS if possible, else a temp.
+public:
+  using mlir::OpRewritePattern<hlfir::EvaluateInMemoryOp>::OpRewritePattern;
+
+  llvm::LogicalResult
+  matchAndRewrite(hlfir::EvaluateInMemoryOp,
+                  mlir::PatternRewriter &rewriter) const override;
+};
+
+static bool mayReadOrWrite(mlir::Region &region, mlir::Value var) {
+  fir::AliasAnalysis aliasAnalysis;
+  for (mlir::Operation &op : region.getOps()) {
+    if (op.hasTrait<mlir::OpTrait::HasRecursiveMemoryEffects>()) {
+      for (mlir::Region &subRegion : op.getRegions())
+        if (mayReadOrWrite(subRegion, var))
+          return true;
+      // In MLIR, RecursiveMemoryEffects can be combined with
+      // MemoryEffectOpInterface to describe extra effects on top of the
+      // effects of the nested operations.  However, the presence of
+      // RecursiveMemoryEffects and the absence of MemoryEffectOpInterface
+      // implies the operation has no other memory effects than the one of its
+      // nested operations.
+      if (!mlir::isa<mlir::MemoryEffectOpInterface>(op))
+        continue;
+    }
+    if (!aliasAnalysis.getModRef(&op, var).isNoModRef())
+      return true;
+  }
+  return false;
+}
+
+static llvm::LogicalResult
+tryUsingAssignLhsDirectly(hlfir::EvaluateInMemoryOp evalInMem,
+                          mlir::PatternRewriter &rewriter) {
+  mlir::Location loc = evalInMem.getLoc();
+  hlfir::DestroyOp destroy;
+  hlfir::AssignOp assign;
+  for (auto user : llvm::enumerate(evalInMem->getUsers())) {
+    if (user.index() >= 2) // At most two users: the assign and the destroy.
+      return mlir::failure();
+    mlir::TypeSwitch<mlir::Operation *, void>(user.value())
+        .Case([&](hlfir::AssignOp op) { assign = op; })
+        .Case([&](hlfir::DestroyOp op) { destroy = op; });
+  }
+  if (!assign || !destroy || destroy.mustFinalizeExpr() ||
+      assign.isAllocatableAssignment())
+    return mlir::failure();
+
+  hlfir::Entity lhs{assign.getLhs()};
+  // EvaluateInMemoryOp memory is contiguous, so in general, it can only be
+  // replaced by the LHS if the LHS is contiguous.
+  if (!lhs.isSimplyContiguous())
+    return mlir::failure();
+  // Character assignment may involve truncation/padding, so the LHS
+  // cannot be used to evaluate RHS in place without proving the LHS and
+  // RHS lengths are the same.
+  if (lhs.isCharacter())
+    return mlir::failure();
+
+  // The region must not read or write the LHS.
+  if (mayReadOrWrite(evalInMem.getBody(), lhs))
+    return mlir::failure();
+  // Any variables affected between the hlfir.evalInMem and assignment must not
+  // be read or written inside the region since it will be moved at the
+  // assignment insertion point.
+  auto effects = getEffectsBetween(evalInMem->getNextNode(), assign);
+  if (!effects) {
+    LLVM_DEBUG(
+        llvm::dbgs()
+        << "operation with unknown effects between eval_in_mem and assign\n");
+    return mlir::failure();
+  }
+  for (const mlir::MemoryEffects::EffectInstance &effect : *effects) {
+    mlir::Value affected = effect.getValue();
+    if (!affected || mayReadOrWrite(evalInMem.getBody(), affected))
+      return mlir::failure();
+  }
+
+  rewriter.setInsertionPoint(assign);
+  fir::FirOpBuilder builder(rewriter, evalInMem.getOperation());
+  mlir::Value rawLhs = hlfir::genVariableRawAddress(loc, builder, lhs);
+  hlfir::computeEvaluateOpIn(loc, builder, evalInMem, rawLhs);
+  rewriter.eraseOp(assign);
+  rewriter.eraseOp(destroy);
+  rewriter.eraseOp(evalInMem);
+  return mlir::success();
+}
+
+llvm::LogicalResult EvaluateIntoMemoryAssignBufferization::matchAndRewrite(
+    hlfir::EvaluateInMemoryOp evalInMem,
+    mlir::PatternRewriter &rewriter) const {
+  // Preferred: evaluate directly into the LHS of the hlfir.assign user.
+  if (mlir::succeeded(tryUsingAssignLhsDirectly(evalInMem, rewriter)))
+    return mlir::success();
+  // Otherwise use a temp + hlfir.as_expr so that the assign + as_expr pattern
+  // can inline simple-type assignments instead of calling the Assign runtime.
+  mlir::Location loc = evalInMem.getLoc();
+  fir::FirOpBuilder builder(rewriter, evalInMem.getOperation());
+  auto [temp, isHeapAllocated] = hlfir::computeEvaluateOpInNewTemp(
+      loc, builder, evalInMem, evalInMem.getShape(), evalInMem.getTypeparams());
+  mlir::Value mustFree = builder.createBool(loc, isHeapAllocated);
+  rewriter.replaceOpWithNewOp<hlfir::AsExprOp>(evalInMem, temp, mustFree);
+  return mlir::success();
+}
+
 class OptimizedBufferizationPass
     : public hlfir::impl::OptimizedBufferizationBase<
           OptimizedBufferizationPass> {
@@ -1130,6 +1237,7 @@ class OptimizedBufferizationPass
     patterns.insert<ElementalAssignBufferization>(context);
     patterns.insert<BroadcastAssignBufferization>(context);
     patterns.insert<VariableAssignBufferization>(context);
+    patterns.insert<EvaluateIntoMemoryAssignBufferization>(context);
     patterns.insert<ReductionConversion<hlfir::CountOp>>(context);
     patterns.insert<ReductionConversion<hlfir::AnyOp>>(context);
     patterns.insert<ReductionConversion<hlfir::AllOp>>(context);
diff --git a/flang/test/HLFIR/opt-bufferization-eval_in_mem.fir b/flang/test/HLFIR/opt-bufferization-eval_in_mem.fir
new file mode 100644
index 00000000000000..984c0bcbaddcc3
--- /dev/null
+++ b/flang/test/HLFIR/opt-bufferization-eval_in_mem.fir
@@ -0,0 +1,67 @@
+// RUN: fir-opt --opt-bufferization %s | FileCheck %s
+
+// Fortran 2023 section 15.5.2.14, point 4, ensures that _QPfoo cannot access
+// _QFtestEx, so the temporary storage for the result can be avoided.
+func.func @_QPtest(%x_ref: !fir.ref<!fir.array<10xf32>> {fir.bindc_name = "x"}) {
+  %c10 = arith.constant 10 : index
+  %scope = fir.dummy_scope : !fir.dscope
+  %shape = fir.shape %c10 : (index) -> !fir.shape<1>
+  %x:2 = hlfir.declare %x_ref(%shape) dummy_scope %scope {uniq_name = "_QFtestEx"} : (!fir.ref<!fir.array<10xf32>>, !fir.shape<1>, !fir.dscope) -> (!fir.ref<!fir.array<10xf32>>, !fir.ref<!fir.array<10xf32>>)
+  %foo_res = hlfir.eval_in_mem shape %shape : (!fir.shape<1>) -> !hlfir.expr<10xf32> {
+  ^bb0(%res_mem: !fir.ref<!fir.array<10xf32>>):
+    %foo_val = fir.call @_QPfoo() fastmath<contract> : () -> !fir.array<10xf32>
+    fir.save_result %foo_val to %res_mem(%shape) : !fir.array<10xf32>, !fir.ref<!fir.array<10xf32>>, !fir.shape<1>
+  }
+  hlfir.assign %foo_res to %x#0 : !hlfir.expr<10xf32>, !fir.ref<!fir.array<10xf32>>
+  hlfir.destroy %foo_res : !hlfir.expr<10xf32>
+  return
+}
+func.func private @_QPfoo() -> !fir.array<10xf32>
+
+// CHECK-LABEL: func.func @_QPtest(
+// CHECK-SAME:                     %[[VAL_0:.*]]: !fir.ref<!fir.array<10xf32>> {fir.bindc_name = "x"}) {
+// CHECK:         %[[VAL_1:.*]] = arith.constant 10 : index
+// CHECK:         %[[VAL_2:.*]] = fir.dummy_scope : !fir.dscope
+// CHECK:         %[[VAL_3:.*]] = fir.shape %[[VAL_1]] : (index) -> !fir.shape<1>
+// CHECK:         %[[VAL_4:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_3]]) dummy_scope %[[VAL_2]] {uniq_name = "_QFtestEx"} : (!fir.ref<!fir.array<10xf32>>, !fir.shape<1>, !fir.dscope) -> (!fir.ref<!fir.array<10xf32>>, !fir.ref<!fir.array<10xf32>>)
+// CHECK:         %[[VAL_5:.*]] = fir.call @_QPfoo() fastmath<contract> : () -> !fir.array<10xf32>
+// CHECK:         fir.save_result %[[VAL_5]] to %[[VAL_4]]#1(%[[VAL_3]]) : !fir.array<10xf32>, !fir.ref<!fir.array<10xf32>>, !fir.shape<1>
+// CHECK:         return
+// CHECK:       }
+
+
+// Temporary storage cannot be avoided in this case since
+// _QFnegative_test_is_targetEx has the TARGET attribute.
+func.func @_QPnegative_test_is_target(%x_ref: !fir.ref<!fir.array<10xf32>> {fir.bindc_name = "x", fir.target}) {
+  %c10 = arith.constant 10 : index
+  %scope = fir.dummy_scope : !fir.dscope
+  %shape = fir.shape %c10 : (index) -> !fir.shape<1>
+  %x:2 = hlfir.declare %x_ref(%shape) dummy_scope %scope {fortran_attrs = #fir.var_attrs<target>, uniq_name = "_QFnegative_test_is_targetEx"} : (!fir.ref<!fir.array<10xf32>>, !fir.shape<1>, !fir.dscope) -> (!fir.ref<!fir.array<10xf32>>, !fir.ref<!fir.array<10xf32>>)
+  %foo_res = hlfir.eval_in_mem shape %shape : (!fir.shape<1>) -> !hlfir.expr<10xf32> {
+  ^bb0(%res_mem: !fir.ref<!fir.array<10xf32>>):
+    %foo_val = fir.call @_QPfoo() fastmath<contract> : () -> !fir.array<10xf32>
+    fir.save_result %foo_val to %res_mem(%shape) : !fir.array<10xf32>, !fir.ref<!fir.array<10xf32>>, !fir.shape<1>
+  }
+  hlfir.assign %foo_res to %x#0 : !hlfir.expr<10xf32>, !fir.ref<!fir.array<10xf32>>
+  hlfir.destroy %foo_res : !hlfir.expr<10xf32>
+  return
+}
+// CHECK-LABEL: func.func @_QPnegative_test_is_target(
+// CHECK-SAME:                                        %[[VAL_0:.*]]: !fir.ref<!fir.array<10xf32>> {fir.bindc_name = "x", fir.target}) {
+// CHECK:         %[[VAL_1:.*]] = arith.constant 1 : index
+// CHECK:         %[[VAL_2:.*]] = arith.constant false
+// CHECK:         %[[VAL_3:.*]] = arith.constant 10 : index
+// CHECK:         %[[VAL_4:.*]] = fir.alloca !fir.array<10xf32>
+// CHECK:         %[[VAL_7:.*]]:2 = hlfir.declare %[[VAL_0]]{{.*}}
+// CHECK:         %[[VAL_8:.*]]:2 = hlfir.declare %[[VAL_4]]{{.*}}
+// CHECK:         %[[VAL_9:.*]] = fir.call @_QPfoo() fastmath<contract> : () -> !fir.array<10xf32>
+// CHECK:         fir.save_result %[[VAL_9]] to %[[VAL_8]]#1{{.*}}
+// CHECK:         %[[VAL_10:.*]] = hlfir.as_expr %[[VAL_8]]#0 move %[[VAL_2]] : (!fir.ref<!fir.array<10xf32>>, i1) -> !hlfir.expr<10xf32>
+// CHECK:         fir.do_loop %[[VAL_11:.*]] = %[[VAL_1]] to %[[VAL_3]] step %[[VAL_1]] unordered {
+// CHECK:           %[[VAL_12:.*]] = hlfir.apply %[[VAL_10]], %[[VAL_11]] : (!hlfir.expr<10xf32>, index) -> f32
+// CHECK:           %[[VAL_13:.*]] = hlfir.designate %[[VAL_7]]#0 (%[[VAL_11]])  : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+// CHECK:           hlfir.assign %[[VAL_12]] to %[[VAL_13]] : f32, !fir.ref<f32>
+// CHECK:         }
+// CHECK:         hlfir.destroy %[[VAL_10]] : !hlfir.expr<10xf32>
+// CHECK:         return
+// CHECK:       }

>From d68b5b2652831cda053c00465b33039bc645bc02 Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Mon, 2 Dec 2024 01:59:33 -0800
Subject: [PATCH 3/3] PR118069 comment: mayReadOrWrite to getModRef

---
 .../flang/Optimizer/Analysis/AliasAnalysis.h  |  6 +++
 .../lib/Optimizer/Analysis/AliasAnalysis.cpp  | 27 ++++++++++++++
 .../Transforms/OptimizedBufferization.cpp     | 37 ++++++-------------
 3 files changed, 45 insertions(+), 25 deletions(-)

diff --git a/flang/include/flang/Optimizer/Analysis/AliasAnalysis.h b/flang/include/flang/Optimizer/Analysis/AliasAnalysis.h
index e410831c0fc3eb..8d17e4e476d10d 100644
--- a/flang/include/flang/Optimizer/Analysis/AliasAnalysis.h
+++ b/flang/include/flang/Optimizer/Analysis/AliasAnalysis.h
@@ -198,6 +198,12 @@ struct AliasAnalysis {
   /// Return the modify-reference behavior of `op` on `location`.
   mlir::ModRefResult getModRef(mlir::Operation *op, mlir::Value location);
 
+  /// Return the modify-reference behavior of operations inside `region` on
+  /// `location`. Unlike getModRef(operation, location), this also visits the
+  /// regions of nested operations that have the HasRecursiveMemoryEffects
+  /// trait, recursively.
+  mlir::ModRefResult getModRef(mlir::Region &region, mlir::Value location);
+
   /// Return the memory source of a value.
   /// If getLastInstantiationPoint is true, the search for the source
   /// will stop at [hl]fir.declare if it represents a dummy
diff --git a/flang/lib/Optimizer/Analysis/AliasAnalysis.cpp b/flang/lib/Optimizer/Analysis/AliasAnalysis.cpp
index c561285b9feef5..0b0f83d024ce33 100644
--- a/flang/lib/Optimizer/Analysis/AliasAnalysis.cpp
+++ b/flang/lib/Optimizer/Analysis/AliasAnalysis.cpp
@@ -464,6 +464,33 @@ ModRefResult AliasAnalysis::getModRef(Operation *op, Value location) {
   return result;
 }
 
+ModRefResult AliasAnalysis::getModRef(mlir::Region &region,
+                                      mlir::Value location) {
+  ModRefResult result = ModRefResult::getNoModRef();
+  for (mlir::Operation &op : region.getOps()) {
+    if (op.hasTrait<mlir::OpTrait::HasRecursiveMemoryEffects>()) {
+      for (mlir::Region &subRegion : op.getRegions()) {
+        result = result.merge(getModRef(subRegion, location));
+        // Fast return if already mod and ref.
+        if (result.isModAndRef())
+          return result;
+      }
+      // In MLIR, RecursiveMemoryEffects can be combined with
+      // MemoryEffectOpInterface to describe extra effects on top of the
+      // effects of the nested operations.  However, the presence of
+      // RecursiveMemoryEffects and the absence of MemoryEffectOpInterface
+      // implies the operation has no memory effects other than those of its
+      // nested operations.
+      if (!mlir::isa<mlir::MemoryEffectOpInterface>(op))
+        continue;
+    }
+    result = result.merge(getModRef(&op, location));
+    if (result.isModAndRef())
+      return result;
+  }
+  return result;
+}
+
 AliasAnalysis::Source::Attributes
 getAttrsFromVariable(fir::FortranVariableOpInterface var) {
   AliasAnalysis::Source::Attributes attrs;
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
index e8c15a256b9da0..9327e7ad5875cf 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
@@ -1119,28 +1119,6 @@ class EvaluateIntoMemoryAssignBufferization
                   mlir::PatternRewriter &rewriter) const override;
 };
 
-static bool mayReadOrWrite(mlir::Region &region, mlir::Value var) {
-  fir::AliasAnalysis aliasAnalysis;
-  for (mlir::Operation &op : region.getOps()) {
-    if (op.hasTrait<mlir::OpTrait::HasRecursiveMemoryEffects>()) {
-      for (mlir::Region &subRegion : op.getRegions())
-        if (mayReadOrWrite(subRegion, var))
-          return true;
-      // In MLIR, RecursiveMemoryEffects can be combined with
-      // MemoryEffectOpInterface to describe extra effects on top of the
-      // effects of the nested operations.  However, the presence of
-      // RecursiveMemoryEffects and the absence of MemoryEffectOpInterface
-      // implies the operation has no other memory effects than the one of its
-      // nested operations.
-      if (!mlir::isa<mlir::MemoryEffectOpInterface>(op))
-        continue;
-    }
-    if (!aliasAnalysis.getModRef(&op, var).isNoModRef())
-      return true;
-  }
-  return false;
-}
-
 static llvm::LogicalResult
 tryUsingAssignLhsDirectly(hlfir::EvaluateInMemoryOp evalInMem,
                           mlir::PatternRewriter &rewriter) {
@@ -1168,9 +1146,17 @@ tryUsingAssignLhsDirectly(hlfir::EvaluateInMemoryOp evalInMem,
   // RHS lengths are the same.
   if (lhs.isCharacter())
     return mlir::failure();
-
+  fir::AliasAnalysis aliasAnalysis;
   // The region must not read or write the LHS.
-  if (mayReadOrWrite(evalInMem.getBody(), lhs))
+  // Note that getModRef is used instead of mlir::MemoryEffects because
+  // EvaluateInMemoryOp is typically expected to hold fir.calls and that
+  // Fortran calls cannot be modeled in a useful way with mlir::MemoryEffects:
+  // it is hard/impossible to list all the read/written SSA values in a call,
+  // but it is often possible to tell that an SSA value cannot be accessed,
+  // hence getModRef is needed here and below. Also note that getModRef uses
+  // mlir::MemoryEffects for operations that do not have special handling in
+  // getModRef.
+  if (aliasAnalysis.getModRef(evalInMem.getBody(), lhs).isModOrRef())
     return mlir::failure();
   // Any variables affected between the hlfir.evalInMem and assignment must not
   // be read or written inside the region since it will be moved at the
@@ -1184,7 +1170,8 @@ tryUsingAssignLhsDirectly(hlfir::EvaluateInMemoryOp evalInMem,
   }
   for (const mlir::MemoryEffects::EffectInstance &effect : *effects) {
     mlir::Value affected = effect.getValue();
-    if (!affected || mayReadOrWrite(evalInMem.getBody(), affected))
+    if (!affected ||
+        aliasAnalysis.getModRef(evalInMem.getBody(), affected).isModOrRef())
       return mlir::failure();
   }
 



More information about the flang-commits mailing list