[flang-commits] [flang] [flang][hlfir] optimize hlfir.eval_in_mem bufferization (PR #118069)
via flang-commits
flang-commits at lists.llvm.org
Mon Dec 2 02:02:48 PST 2024
https://github.com/jeanPerier updated https://github.com/llvm/llvm-project/pull/118069
>From 5d43a6bd4fb27015e6917eb471bc725fe31b9ac2 Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Thu, 28 Nov 2024 08:25:05 -0800
Subject: [PATCH 1/3] [flang][hlfir] add hlfir.eval_in_mem operation
---
.../flang/Optimizer/Builder/HLFIRTools.h | 19 ++++
.../include/flang/Optimizer/HLFIR/HLFIROps.td | 59 ++++++++++
flang/lib/Optimizer/Builder/HLFIRTools.cpp | 47 ++++++++
flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp | 76 ++++++++++---
.../HLFIR/Transforms/BufferizeHLFIR.cpp | 33 +++++-
flang/test/HLFIR/eval_in_mem-codegen.fir | 107 ++++++++++++++++++
flang/test/HLFIR/eval_in_mem.fir | 99 ++++++++++++++++
flang/test/HLFIR/invalid.fir | 34 ++++++
8 files changed, 454 insertions(+), 20 deletions(-)
create mode 100644 flang/test/HLFIR/eval_in_mem-codegen.fir
create mode 100644 flang/test/HLFIR/eval_in_mem.fir
diff --git a/flang/include/flang/Optimizer/Builder/HLFIRTools.h b/flang/include/flang/Optimizer/Builder/HLFIRTools.h
index f073f494b3fb21..efbd9e4f50d432 100644
--- a/flang/include/flang/Optimizer/Builder/HLFIRTools.h
+++ b/flang/include/flang/Optimizer/Builder/HLFIRTools.h
@@ -33,6 +33,7 @@ class AssociateOp;
class ElementalOp;
class ElementalOpInterface;
class ElementalAddrOp;
+class EvaluateInMemoryOp;
class YieldElementOp;
/// Is this a Fortran variable for which the defining op carrying the Fortran
@@ -398,6 +399,24 @@ mlir::Value inlineElementalOp(
mlir::IRMapping &mapper,
const std::function<bool(hlfir::ElementalOp)> &mustRecursivelyInline);
+/// Create a new temporary with the shape and type parameters of the provided
+/// hlfir.eval_in_mem operation, and clone the body of the hlfir.eval_in_mem
+/// so that it operates on this new temporary. \p shape and \p typeParams are
+/// the shape and type parameters used to create the temporary.
+/// Returns the temporary and a boolean that is true if the temporary is heap
+/// allocated (false if it is stack allocated).
+std::pair<hlfir::Entity, bool>
+computeEvaluateOpInNewTemp(mlir::Location, fir::FirOpBuilder &,
+ hlfir::EvaluateInMemoryOp evalInMem,
+ mlir::Value shape, mlir::ValueRange typeParams);
+
+/// Clone the body of the hlfir.eval_in_mem so that it operates on the
+/// provided \p storage instead of its block argument. The provided storage
+/// must be a contiguous "raw" memory reference (not a fir.box) big enough to
+/// hold the value computed by hlfir.eval_in_mem. No runtime check is inserted
+/// by this utility to enforce that. It is also usually invalid to provide
+/// storage that is already addressed directly or indirectly inside the
+/// hlfir.eval_in_mem body. The hlfir.eval_in_mem operation itself is neither
+/// modified nor erased.
+void computeEvaluateOpIn(mlir::Location, fir::FirOpBuilder &,
+ hlfir::EvaluateInMemoryOp, mlir::Value storage);
+
std::pair<fir::ExtendedValue, std::optional<hlfir::CleanupFunction>>
convertToValue(mlir::Location loc, fir::FirOpBuilder &builder,
hlfir::Entity entity);
diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
index 1ab8793f726523..a9826543f48b69 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
@@ -1755,4 +1755,63 @@ def hlfir_CharExtremumOp : hlfir_Op<"char_extremum",
let hasVerifier = 1;
}
+// hlfir.eval_in_mem: an expression value computed by a region that writes
+// into raw memory provided as the region block argument.
+def hlfir_EvaluateInMemoryOp : hlfir_Op<"eval_in_mem", [AttrSizedOperandSegments,
+    RecursiveMemoryEffects, RecursivelySpeculatable,
+    SingleBlockImplicitTerminator<"fir::FirEndOp">]> {
+  let summary = "Wrap an in-memory implementation that computes an expression value";
+  let description = [{
+    Returns a Fortran expression value for which the computation is
+    implemented inside the region, operating on the block argument, which
+    is a raw memory reference corresponding to the expression type.
+
+    The shape and type parameters of the expression are operands of the
+    operation.
+
+    The memory cannot escape the region, and how it is allocated is not
+    described. This facilitates later elision of the temporary storage for
+    the expression evaluation if it can be evaluated in some other storage
+    (like a left-hand side variable).
+
+    Example:
+
+    A function returning an array can be represented as:
+    ```
+      %1 = fir.shape %c10 : (index) -> !fir.shape<1>
+      %2 = hlfir.eval_in_mem shape %1 : (!fir.shape<1>) -> !hlfir.expr<10xf32> {
+      ^bb0(%arg0: !fir.ref<!fir.array<10xf32>>):
+        %3 = fir.call @_QParray_func() fastmath<contract> : () -> !fir.array<10xf32>
+        fir.save_result %3 to %arg0(%1) : !fir.array<10xf32>, !fir.ref<!fir.array<10xf32>>, !fir.shape<1>
+      }
+    ```
+  }];
+
+  let arguments = (ins
+    Optional<fir_ShapeType>:$shape,
+    Variadic<AnyIntegerType>:$typeparams
+  );
+
+  let results = (outs hlfir_ExprType);
+  let regions = (region SizedRegion<1>:$body);
+
+  let assemblyFormat = [{
+    (`shape` $shape^)? (`typeparams` $typeparams^)?
+    attr-dict `:` functional-type(operands, results)
+    $body}];
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<(ins "mlir::Type":$result_type, "mlir::Value":$shape,
+      CArg<"mlir::ValueRange", "{}">:$typeparams)>
+  ];
+
+  let extraClassDeclaration = [{
+    // Return block argument representing the memory where the expression
+    // is evaluated.
+    mlir::Value getMemory() {return getBody().getArgument(0);}
+  }];
+
+  let hasVerifier = 1;
+}
+
+
#endif // FORTRAN_DIALECT_HLFIR_OPS
diff --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
index 7425ccf7fc0e30..1bd950f2445ee4 100644
--- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp
+++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
@@ -535,6 +535,8 @@ static mlir::Value tryRetrievingShapeOrShift(hlfir::Entity entity) {
if (mlir::isa<hlfir::ExprType>(entity.getType())) {
if (auto elemental = entity.getDefiningOp<hlfir::ElementalOp>())
return elemental.getShape();
+ if (auto evalInMem = entity.getDefiningOp<hlfir::EvaluateInMemoryOp>())
+ return evalInMem.getShape();
return mlir::Value{};
}
if (auto varIface = entity.getIfVariableInterface())
@@ -642,6 +644,11 @@ void hlfir::genLengthParameters(mlir::Location loc, fir::FirOpBuilder &builder,
result.append(elemental.getTypeparams().begin(),
elemental.getTypeparams().end());
return;
+ } else if (auto evalInMem =
+ expr.getDefiningOp<hlfir::EvaluateInMemoryOp>()) {
+ result.append(evalInMem.getTypeparams().begin(),
+ evalInMem.getTypeparams().end());
+ return;
} else if (auto apply = expr.getDefiningOp<hlfir::ApplyOp>()) {
result.append(apply.getTypeparams().begin(), apply.getTypeparams().end());
return;
@@ -1313,3 +1320,43 @@ hlfir::genTypeAndKindConvert(mlir::Location loc, fir::FirOpBuilder &builder,
};
return {hlfir::Entity{convertedRhs}, cleanup};
}
+
+/// Allocate a temporary for the value of \p evalInMem (heap allocated when the
+/// type has a dynamic size, stack allocated otherwise), declare it, and clone
+/// the hlfir.eval_in_mem body so that it writes into this temporary.
+/// Returns the declared temporary and whether it is heap allocated.
+std::pair<hlfir::Entity, bool> hlfir::computeEvaluateOpInNewTemp(
+    mlir::Location loc, fir::FirOpBuilder &builder,
+    hlfir::EvaluateInMemoryOp evalInMem, mlir::Value shape,
+    mlir::ValueRange typeParams) {
+  llvm::StringRef tmpName{".tmp.expr_result"};
+  // The shape operand is optional (scalar results, e.g. characters, have
+  // none): only query extents when there is a shape.
+  llvm::SmallVector<mlir::Value> extents;
+  if (shape)
+    extents = hlfir::getIndexExtents(loc, builder, shape);
+  mlir::Type baseType =
+      hlfir::getFortranElementOrSequenceType(evalInMem.getType());
+  bool heapAllocated = fir::hasDynamicSize(baseType);
+  // Note: temporaries are stack allocated here when possible (do not require
+  // stack save/restore) because flang has always stack allocated function
+  // results.
+  mlir::Value temp = heapAllocated
+                         ? builder.createHeapTemporary(loc, baseType, tmpName,
+                                                       extents, typeParams)
+                         : builder.createTemporary(loc, baseType, tmpName,
+                                                   extents, typeParams);
+  // Declare the temporary with the type of the body block argument (e.g. a
+  // fir.heap is converted to a fir.ref) so that the cloned body can use the
+  // declared base directly.
+  mlir::Value innerMemory = evalInMem.getMemory();
+  temp = builder.createConvert(loc, innerMemory.getType(), temp);
+  auto declareOp = builder.create<hlfir::DeclareOp>(
+      loc, temp, tmpName, shape, typeParams,
+      /*dummy_scope=*/nullptr, fir::FortranVariableFlagsAttr{});
+  computeEvaluateOpIn(loc, builder, evalInMem, declareOp.getOriginalBase());
+  return {hlfir::Entity{declareOp.getBase()}, /*heapAllocated=*/heapAllocated};
+}
+
+void hlfir::computeEvaluateOpIn(mlir::Location loc, fir::FirOpBuilder &builder,
+                                hlfir::EvaluateInMemoryOp evalInMem,
+                                mlir::Value storage) {
+  // The body writes its result through its block argument. Map that argument
+  // to `storage` (converted to the argument type), then clone every body
+  // operation except the fir.end terminator at the builder insertion point.
+  mlir::Value blockArg = evalInMem.getMemory();
+  mlir::Value mappedStorage =
+      builder.createConvert(loc, blockArg.getType(), storage);
+  mlir::IRMapping mapping;
+  mapping.map(blockArg, mappedStorage);
+  mlir::Block &bodyBlock = evalInMem.getBody().front();
+  for (mlir::Operation &nestedOp : bodyBlock.without_terminator())
+    builder.clone(nestedOp, mapping);
+}
diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
index b593383ff2848d..87519882446485 100644
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
@@ -333,6 +333,25 @@ static void printDesignatorComplexPart(mlir::OpAsmPrinter &p,
p << "real";
}
}
+/// Check that \p numLenParam (the number of type parameter operands of \p op)
+/// agrees with \p elementType: exactly one for a character type, the number
+/// of length parameters for a derived type with length parameters, and zero
+/// otherwise. On mismatch, an error is emitted on \p op and failure returned.
+template <typename Op>
+static llvm::LogicalResult verifyTypeparams(Op &op, mlir::Type elementType,
+                                            unsigned numLenParam) {
+  if (mlir::isa<fir::CharacterType>(elementType)) {
+    if (numLenParam == 1)
+      return mlir::success();
+    return op.emitOpError("must be provided one length parameter when the "
+                          "result is a character");
+  }
+  if (fir::isRecordWithTypeParameters(elementType)) {
+    auto recordType = mlir::cast<fir::RecordType>(elementType);
+    if (numLenParam == recordType.getNumLenParams())
+      return mlir::success();
+    return op.emitOpError("must be provided the same number of length "
+                          "parameters as in the result derived type");
+  }
+  if (numLenParam == 0)
+    return mlir::success();
+  return op.emitOpError(
+      "must not be provided length parameters if the result "
+      "type does not have length parameters");
+}
llvm::LogicalResult hlfir::DesignateOp::verify() {
mlir::Type memrefType = getMemref().getType();
@@ -462,20 +481,10 @@ llvm::LogicalResult hlfir::DesignateOp::verify() {
return emitOpError("shape must be a fir.shape or fir.shapeshift with "
"the rank of the result");
}
- auto numLenParam = getTypeparams().size();
- if (mlir::isa<fir::CharacterType>(outputElementType)) {
- if (numLenParam != 1)
- return emitOpError("must be provided one length parameter when the "
- "result is a character");
- } else if (fir::isRecordWithTypeParameters(outputElementType)) {
- if (numLenParam !=
- mlir::cast<fir::RecordType>(outputElementType).getNumLenParams())
- return emitOpError("must be provided the same number of length "
- "parameters as in the result derived type");
- } else if (numLenParam != 0) {
- return emitOpError("must not be provided length parameters if the result "
- "type does not have length parameters");
- }
+ if (auto res =
+ verifyTypeparams(*this, outputElementType, getTypeparams().size());
+ failed(res))
+ return res;
}
return mlir::success();
}
@@ -1989,6 +1998,45 @@ hlfir::GetLengthOp::canonicalize(GetLengthOp getLength,
return mlir::success();
}
+//===----------------------------------------------------------------------===//
+// EvaluateInMemoryOp
+//===----------------------------------------------------------------------===//
+
+/// Build an hlfir.eval_in_mem with an optional \p shape and \p typeparams,
+/// and a body made of a single block whose only argument is a reference to
+/// the element (or sequence) type of \p resultType. The block is terminated
+/// by the implicit terminator.
+void hlfir::EvaluateInMemoryOp::build(mlir::OpBuilder &builder,
+                                      mlir::OperationState &odsState,
+                                      mlir::Type resultType, mlir::Value shape,
+                                      mlir::ValueRange typeparams) {
+  // Operands are the optional shape followed by the type parameters; the
+  // segment sizes attribute records how many belong to each group.
+  const int32_t numShapeOperands = shape ? 1 : 0;
+  const auto numTypeparams = static_cast<int32_t>(typeparams.size());
+  odsState.addTypes(resultType);
+  if (numShapeOperands != 0)
+    odsState.addOperands(shape);
+  odsState.addOperands(typeparams);
+  odsState.addAttribute(
+      getOperandSegmentSizeAttr(),
+      builder.getDenseI32ArrayAttr({numShapeOperands, numTypeparams}));
+
+  mlir::Region *body = odsState.addRegion();
+  auto *entryBlock = new mlir::Block{};
+  body->push_back(entryBlock);
+  mlir::Type storageType = fir::ReferenceType::get(
+      hlfir::getFortranElementOrSequenceType(resultType));
+  entryBlock->addArgument(storageType, odsState.location);
+  EvaluateInMemoryOp::ensureTerminator(*body, builder, odsState.location);
+}
+
+/// Verify hlfir.eval_in_mem. The checks are:
+///  - the shape rank must match the result rank;
+///  - the number of type parameters must match the result element type;
+///  - the body block must have exactly one argument whose type is a
+///    reference to the result element (or sequence) type, i.e. the storage
+///    the body writes into.
+llvm::LogicalResult hlfir::EvaluateInMemoryOp::verify() {
+  unsigned shapeRank = 0;
+  if (mlir::Value shape = getShape())
+    if (auto shapeTy = mlir::dyn_cast<fir::ShapeType>(shape.getType()))
+      shapeRank = shapeTy.getRank();
+  auto exprType = mlir::cast<hlfir::ExprType>(getResult().getType());
+  if (shapeRank != exprType.getRank())
+    return emitOpError("`shape` rank must match the result rank");
+  mlir::Type elementType = exprType.getElementType();
+  if (auto res = verifyTypeparams(*this, elementType, getTypeparams().size());
+      failed(res))
+    return res;
+  // getMemory() unconditionally reads block argument #0, and the
+  // bufferization utilities convert the storage to that argument type.
+  // Enforce the body signature created by the custom builder so that
+  // malformed parsed IR is rejected here instead of crashing later.
+  mlir::Block &body = getBody().front();
+  mlir::Type expectedMemType = fir::ReferenceType::get(
+      hlfir::getFortranElementOrSequenceType(exprType));
+  if (body.getNumArguments() != 1 ||
+      body.getArgument(0).getType() != expectedMemType)
+    return emitOpError("body must have exactly one argument of type ")
+           << expectedMemType;
+  return mlir::success();
+}
+
#include "flang/Optimizer/HLFIR/HLFIROpInterfaces.cpp.inc"
#define GET_OP_CLASSES
#include "flang/Optimizer/HLFIR/HLFIREnums.cpp.inc"
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
index 1848dbe2c7a2c2..347f0a5630777f 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
@@ -905,6 +905,26 @@ struct CharExtremumOpConversion
}
};
+/// Default bufferization of hlfir.eval_in_mem: evaluate the body into a new
+/// temporary (stack allocated when the size is static, heap allocated
+/// otherwise) and replace the operation by the packaged temporary.
+struct EvaluateInMemoryOpConversion
+    : public mlir::OpConversionPattern<hlfir::EvaluateInMemoryOp> {
+  using mlir::OpConversionPattern<
+      hlfir::EvaluateInMemoryOp>::OpConversionPattern;
+  explicit EvaluateInMemoryOpConversion(mlir::MLIRContext *ctx)
+      : mlir::OpConversionPattern<hlfir::EvaluateInMemoryOp>{ctx} {}
+  llvm::LogicalResult
+  matchAndRewrite(hlfir::EvaluateInMemoryOp evalInMemOp, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    const mlir::Location loc = evalInMemOp->getLoc();
+    fir::FirOpBuilder builder(rewriter, evalInMemOp.getOperation());
+    // Create the temporary from the operands as remapped by the conversion
+    // framework (the adaptor values).
+    std::pair<hlfir::Entity, bool> tempAndHeapFlag =
+        hlfir::computeEvaluateOpInNewTemp(loc, builder, evalInMemOp,
+                                          adaptor.getShape(),
+                                          adaptor.getTypeparams());
+    hlfir::Entity temp = tempAndHeapFlag.first;
+    const bool isHeapAllocated = tempAndHeapFlag.second;
+    mlir::Value packaged =
+        packageBufferizedExpr(loc, builder, temp, isHeapAllocated);
+    rewriter.replaceOp(evalInMemOp, packaged);
+    return mlir::success();
+  }
+};
+
class BufferizeHLFIR : public hlfir::impl::BufferizeHLFIRBase<BufferizeHLFIR> {
public:
void runOnOperation() override {
@@ -918,12 +938,13 @@ class BufferizeHLFIR : public hlfir::impl::BufferizeHLFIRBase<BufferizeHLFIR> {
auto module = this->getOperation();
auto *context = &getContext();
mlir::RewritePatternSet patterns(context);
- patterns.insert<ApplyOpConversion, AsExprOpConversion, AssignOpConversion,
- AssociateOpConversion, CharExtremumOpConversion,
- ConcatOpConversion, DestroyOpConversion,
- ElementalOpConversion, EndAssociateOpConversion,
- NoReassocOpConversion, SetLengthOpConversion,
- ShapeOfOpConversion, GetLengthOpConversion>(context);
+ patterns
+ .insert<ApplyOpConversion, AsExprOpConversion, AssignOpConversion,
+ AssociateOpConversion, CharExtremumOpConversion,
+ ConcatOpConversion, DestroyOpConversion, ElementalOpConversion,
+ EndAssociateOpConversion, EvaluateInMemoryOpConversion,
+ NoReassocOpConversion, SetLengthOpConversion,
+ ShapeOfOpConversion, GetLengthOpConversion>(context);
mlir::ConversionTarget target(*context);
// Note that YieldElementOp is not marked as an illegal operation.
// It must be erased by its parent converter and there is no explicit
diff --git a/flang/test/HLFIR/eval_in_mem-codegen.fir b/flang/test/HLFIR/eval_in_mem-codegen.fir
new file mode 100644
index 00000000000000..26a989832ca927
--- /dev/null
+++ b/flang/test/HLFIR/eval_in_mem-codegen.fir
@@ -0,0 +1,107 @@
+// Test hlfir.eval_in_mem default code generation.
+
+// RUN: fir-opt %s --bufferize-hlfir -o - | FileCheck %s
+
+// Statically sized array result: the bufferized temporary is expected to be
+// a stack fir.alloca, with the body cloned to write into it (no freemem).
+func.func @_QPtest() {
+ %c10 = arith.constant 10 : index
+ %0 = fir.address_of(@_QFtestEx) : !fir.ref<!fir.array<10xf32>>
+ %1 = fir.shape %c10 : (index) -> !fir.shape<1>
+ %2 = hlfir.eval_in_mem shape %1 : (!fir.shape<1>) -> !hlfir.expr<10xf32> {
+ ^bb0(%arg0: !fir.ref<!fir.array<10xf32>>):
+ %3 = fir.call @_QParray_func() fastmath<contract> : () -> !fir.array<10xf32>
+ fir.save_result %3 to %arg0(%1) : !fir.array<10xf32>, !fir.ref<!fir.array<10xf32>>, !fir.shape<1>
+ }
+ hlfir.assign %2 to %0 : !hlfir.expr<10xf32>, !fir.ref<!fir.array<10xf32>>
+ hlfir.destroy %2 : !hlfir.expr<10xf32>
+ return
+}
+fir.global internal @_QFtestEx : !fir.array<10xf32>
+func.func private @_QParray_func() -> !fir.array<10xf32>
+
+
+// Character array result with a constant length: the length is forwarded as
+// a type parameter to the temporary declaration, which is stack allocated.
+func.func @_QPtest_char() {
+ %c10 = arith.constant 10 : index
+ %c5 = arith.constant 5 : index
+ %0 = fir.address_of(@_QFtest_charEx) : !fir.ref<!fir.array<10x!fir.char<1,5>>>
+ %1 = fir.shape %c10 : (index) -> !fir.shape<1>
+ %2 = hlfir.eval_in_mem shape %1 typeparams %c5 : (!fir.shape<1>, index) -> !hlfir.expr<10x!fir.char<1,5>> {
+ ^bb0(%arg0: !fir.ref<!fir.array<10x!fir.char<1,5>>>):
+ %3 = fir.call @_QPchar_array_func() fastmath<contract> : () -> !fir.array<10x!fir.char<1,5>>
+ fir.save_result %3 to %arg0(%1) typeparams %c5 : !fir.array<10x!fir.char<1,5>>, !fir.ref<!fir.array<10x!fir.char<1,5>>>, !fir.shape<1>, index
+ }
+ hlfir.assign %2 to %0 : !hlfir.expr<10x!fir.char<1,5>>, !fir.ref<!fir.array<10x!fir.char<1,5>>>
+ hlfir.destroy %2 : !hlfir.expr<10x!fir.char<1,5>>
+ return
+}
+
+fir.global internal @_QFtest_charEx : !fir.array<10x!fir.char<1,5>>
+func.func private @_QPchar_array_func() -> !fir.array<10x!fir.char<1,5>>
+
+// Dynamically sized array result: the temporary is expected to be heap
+// allocated (fir.allocmem, converted to fir.ref) and freed after the assign.
+func.func @test_dynamic(%arg0: !fir.box<!fir.array<?xf32>>, %arg1: index) {
+ %0 = fir.shape %arg1 : (index) -> !fir.shape<1>
+ %1 = hlfir.eval_in_mem shape %0 : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
+ ^bb0(%arg2: !fir.ref<!fir.array<?xf32>>):
+ %2 = fir.call @_QPdyn_array_func(%arg1) : (index) -> !fir.array<?xf32>
+ fir.save_result %2 to %arg2(%0) : !fir.array<?xf32>, !fir.ref<!fir.array<?xf32>>, !fir.shape<1>
+ }
+ hlfir.assign %1 to %arg0 : !hlfir.expr<?xf32>, !fir.box<!fir.array<?xf32>>
+ hlfir.destroy %1 : !hlfir.expr<?xf32>
+ return
+}
+func.func private @_QPdyn_array_func(index) -> !fir.array<?xf32>
+
+// CHECK-LABEL: func.func @_QPtest() {
+// CHECK: %[[VAL_0:.*]] = fir.alloca !fir.array<10xf32> {bindc_name = ".tmp.expr_result"}
+// CHECK: %[[VAL_1:.*]] = arith.constant 10 : index
+// CHECK: %[[VAL_2:.*]] = fir.address_of(@_QFtestEx) : !fir.ref<!fir.array<10xf32>>
+// CHECK: %[[VAL_3:.*]] = fir.shape %[[VAL_1]] : (index) -> !fir.shape<1>
+// CHECK: %[[VAL_4:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_3]]) {uniq_name = ".tmp.expr_result"} : (!fir.ref<!fir.array<10xf32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xf32>>, !fir.ref<!fir.array<10xf32>>)
+// CHECK: %[[VAL_5:.*]] = fir.call @_QParray_func() fastmath<contract> : () -> !fir.array<10xf32>
+// CHECK: fir.save_result %[[VAL_5]] to %[[VAL_4]]#1(%[[VAL_3]]) : !fir.array<10xf32>, !fir.ref<!fir.array<10xf32>>, !fir.shape<1>
+// CHECK: %[[VAL_6:.*]] = arith.constant false
+// CHECK: %[[VAL_7:.*]] = fir.undefined tuple<!fir.ref<!fir.array<10xf32>>, i1>
+// CHECK: %[[VAL_8:.*]] = fir.insert_value %[[VAL_7]], %[[VAL_6]], [1 : index] : (tuple<!fir.ref<!fir.array<10xf32>>, i1>, i1) -> tuple<!fir.ref<!fir.array<10xf32>>, i1>
+// CHECK: %[[VAL_9:.*]] = fir.insert_value %[[VAL_8]], %[[VAL_4]]#0, [0 : index] : (tuple<!fir.ref<!fir.array<10xf32>>, i1>, !fir.ref<!fir.array<10xf32>>) -> tuple<!fir.ref<!fir.array<10xf32>>, i1>
+// CHECK: hlfir.assign %[[VAL_4]]#0 to %[[VAL_2]] : !fir.ref<!fir.array<10xf32>>, !fir.ref<!fir.array<10xf32>>
+// CHECK: return
+// CHECK: }
+// CHECK: fir.global internal @_QFtestEx : !fir.array<10xf32>
+// CHECK: func.func private @_QParray_func() -> !fir.array<10xf32>
+
+// CHECK-LABEL: func.func @_QPtest_char() {
+// CHECK: %[[VAL_0:.*]] = fir.alloca !fir.array<10x!fir.char<1,5>> {bindc_name = ".tmp.expr_result"}
+// CHECK: %[[VAL_1:.*]] = arith.constant 10 : index
+// CHECK: %[[VAL_2:.*]] = arith.constant 5 : index
+// CHECK: %[[VAL_3:.*]] = fir.address_of(@_QFtest_charEx) : !fir.ref<!fir.array<10x!fir.char<1,5>>>
+// CHECK: %[[VAL_4:.*]] = fir.shape %[[VAL_1]] : (index) -> !fir.shape<1>
+// CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_4]]) typeparams %[[VAL_2]] {uniq_name = ".tmp.expr_result"} : (!fir.ref<!fir.array<10x!fir.char<1,5>>>, !fir.shape<1>, index) -> (!fir.ref<!fir.array<10x!fir.char<1,5>>>, !fir.ref<!fir.array<10x!fir.char<1,5>>>)
+// CHECK: %[[VAL_6:.*]] = fir.call @_QPchar_array_func() fastmath<contract> : () -> !fir.array<10x!fir.char<1,5>>
+// CHECK: fir.save_result %[[VAL_6]] to %[[VAL_5]]#1(%[[VAL_4]]) typeparams %[[VAL_2]] : !fir.array<10x!fir.char<1,5>>, !fir.ref<!fir.array<10x!fir.char<1,5>>>, !fir.shape<1>, index
+// CHECK: %[[VAL_7:.*]] = arith.constant false
+// CHECK: %[[VAL_8:.*]] = fir.undefined tuple<!fir.ref<!fir.array<10x!fir.char<1,5>>>, i1>
+// CHECK: %[[VAL_9:.*]] = fir.insert_value %[[VAL_8]], %[[VAL_7]], [1 : index] : (tuple<!fir.ref<!fir.array<10x!fir.char<1,5>>>, i1>, i1) -> tuple<!fir.ref<!fir.array<10x!fir.char<1,5>>>, i1>
+// CHECK: %[[VAL_10:.*]] = fir.insert_value %[[VAL_9]], %[[VAL_5]]#0, [0 : index] : (tuple<!fir.ref<!fir.array<10x!fir.char<1,5>>>, i1>, !fir.ref<!fir.array<10x!fir.char<1,5>>>) -> tuple<!fir.ref<!fir.array<10x!fir.char<1,5>>>, i1>
+// CHECK: hlfir.assign %[[VAL_5]]#0 to %[[VAL_3]] : !fir.ref<!fir.array<10x!fir.char<1,5>>>, !fir.ref<!fir.array<10x!fir.char<1,5>>>
+// CHECK: return
+// CHECK: }
+// CHECK: fir.global internal @_QFtest_charEx : !fir.array<10x!fir.char<1,5>>
+// CHECK: func.func private @_QPchar_array_func() -> !fir.array<10x!fir.char<1,5>>
+
+// CHECK-LABEL: func.func @test_dynamic(
+// CHECK-SAME: %[[VAL_0:.*]]: !fir.box<!fir.array<?xf32>>,
+// CHECK-SAME: %[[VAL_1:.*]]: index) {
+// CHECK: %[[VAL_2:.*]] = fir.shape %[[VAL_1]] : (index) -> !fir.shape<1>
+// CHECK: %[[VAL_3:.*]] = fir.allocmem !fir.array<?xf32>, %[[VAL_1]] {bindc_name = ".tmp.expr_result", uniq_name = ""}
+// CHECK: %[[VAL_4:.*]] = fir.convert %[[VAL_3]] : (!fir.heap<!fir.array<?xf32>>) -> !fir.ref<!fir.array<?xf32>>
+// CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %[[VAL_4]](%[[VAL_2]]) {uniq_name = ".tmp.expr_result"} : (!fir.ref<!fir.array<?xf32>>, !fir.shape<1>) -> (!fir.box<!fir.array<?xf32>>, !fir.ref<!fir.array<?xf32>>)
+// CHECK: %[[VAL_6:.*]] = fir.call @_QPdyn_array_func(%[[VAL_1]]) : (index) -> !fir.array<?xf32>
+// CHECK: fir.save_result %[[VAL_6]] to %[[VAL_5]]#1(%[[VAL_2]]) : !fir.array<?xf32>, !fir.ref<!fir.array<?xf32>>, !fir.shape<1>
+// CHECK: %[[VAL_7:.*]] = arith.constant true
+// CHECK: %[[VAL_8:.*]] = fir.undefined tuple<!fir.box<!fir.array<?xf32>>, i1>
+// CHECK: %[[VAL_9:.*]] = fir.insert_value %[[VAL_8]], %[[VAL_7]], [1 : index] : (tuple<!fir.box<!fir.array<?xf32>>, i1>, i1) -> tuple<!fir.box<!fir.array<?xf32>>, i1>
+// CHECK: %[[VAL_10:.*]] = fir.insert_value %[[VAL_9]], %[[VAL_5]]#0, [0 : index] : (tuple<!fir.box<!fir.array<?xf32>>, i1>, !fir.box<!fir.array<?xf32>>) -> tuple<!fir.box<!fir.array<?xf32>>, i1>
+// CHECK: hlfir.assign %[[VAL_5]]#0 to %[[VAL_0]] : !fir.box<!fir.array<?xf32>>, !fir.box<!fir.array<?xf32>>
+// CHECK: %[[VAL_11:.*]] = fir.box_addr %[[VAL_5]]#0 : (!fir.box<!fir.array<?xf32>>) -> !fir.heap<!fir.array<?xf32>>
+// CHECK: fir.freemem %[[VAL_11]] : !fir.heap<!fir.array<?xf32>>
+// CHECK: return
+// CHECK: }
diff --git a/flang/test/HLFIR/eval_in_mem.fir b/flang/test/HLFIR/eval_in_mem.fir
new file mode 100644
index 00000000000000..34e48ed5be5452
--- /dev/null
+++ b/flang/test/HLFIR/eval_in_mem.fir
@@ -0,0 +1,99 @@
+// Test hlfir.eval_in_mem operation parse, verify (no errors), and unparse.
+
+// RUN: fir-opt %s | fir-opt | FileCheck %s
+
+// Round-trip of an hlfir.eval_in_mem with a static rank-1 shape.
+func.func @_QPtest() {
+ %c10 = arith.constant 10 : index
+ %0 = fir.address_of(@_QFtestEx) : !fir.ref<!fir.array<10xf32>>
+ %1 = fir.shape %c10 : (index) -> !fir.shape<1>
+ %2 = hlfir.eval_in_mem shape %1 : (!fir.shape<1>) -> !hlfir.expr<10xf32> {
+ ^bb0(%arg0: !fir.ref<!fir.array<10xf32>>):
+ %3 = fir.call @_QParray_func() fastmath<contract> : () -> !fir.array<10xf32>
+ fir.save_result %3 to %arg0(%1) : !fir.array<10xf32>, !fir.ref<!fir.array<10xf32>>, !fir.shape<1>
+ }
+ hlfir.assign %2 to %0 : !hlfir.expr<10xf32>, !fir.ref<!fir.array<10xf32>>
+ hlfir.destroy %2 : !hlfir.expr<10xf32>
+ return
+}
+fir.global internal @_QFtestEx : !fir.array<10xf32>
+func.func private @_QParray_func() -> !fir.array<10xf32>
+
+
+// Round-trip with both a shape and a `typeparams` operand (character length).
+func.func @_QPtest_char() {
+ %c10 = arith.constant 10 : index
+ %c5 = arith.constant 5 : index
+ %0 = fir.address_of(@_QFtest_charEx) : !fir.ref<!fir.array<10x!fir.char<1,5>>>
+ %1 = fir.shape %c10 : (index) -> !fir.shape<1>
+ %2 = hlfir.eval_in_mem shape %1 typeparams %c5 : (!fir.shape<1>, index) -> !hlfir.expr<10x!fir.char<1,5>> {
+ ^bb0(%arg0: !fir.ref<!fir.array<10x!fir.char<1,5>>>):
+ %3 = fir.call @_QPchar_array_func() fastmath<contract> : () -> !fir.array<10x!fir.char<1,5>>
+ fir.save_result %3 to %arg0(%1) typeparams %c5 : !fir.array<10x!fir.char<1,5>>, !fir.ref<!fir.array<10x!fir.char<1,5>>>, !fir.shape<1>, index
+ }
+ hlfir.assign %2 to %0 : !hlfir.expr<10x!fir.char<1,5>>, !fir.ref<!fir.array<10x!fir.char<1,5>>>
+ hlfir.destroy %2 : !hlfir.expr<10x!fir.char<1,5>>
+ return
+}
+
+fir.global internal @_QFtest_charEx : !fir.array<10x!fir.char<1,5>>
+func.func private @_QPchar_array_func() -> !fir.array<10x!fir.char<1,5>>
+
+// Round-trip with a shape whose extent is only known at runtime.
+func.func @test_dynamic(%arg0: !fir.box<!fir.array<?xf32>>, %arg1: index) {
+ %0 = fir.shape %arg1 : (index) -> !fir.shape<1>
+ %1 = hlfir.eval_in_mem shape %0 : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
+ ^bb0(%arg2: !fir.ref<!fir.array<?xf32>>):
+ %2 = fir.call @_QPdyn_array_func(%arg1) : (index) -> !fir.array<?xf32>
+ fir.save_result %2 to %arg2(%0) : !fir.array<?xf32>, !fir.ref<!fir.array<?xf32>>, !fir.shape<1>
+ }
+ hlfir.assign %1 to %arg0 : !hlfir.expr<?xf32>, !fir.box<!fir.array<?xf32>>
+ hlfir.destroy %1 : !hlfir.expr<?xf32>
+ return
+}
+func.func private @_QPdyn_array_func(index) -> !fir.array<?xf32>
+
+// CHECK-LABEL: func.func @_QPtest() {
+// CHECK: %[[VAL_0:.*]] = arith.constant 10 : index
+// CHECK: %[[VAL_1:.*]] = fir.address_of(@_QFtestEx) : !fir.ref<!fir.array<10xf32>>
+// CHECK: %[[VAL_2:.*]] = fir.shape %[[VAL_0]] : (index) -> !fir.shape<1>
+// CHECK: %[[VAL_3:.*]] = hlfir.eval_in_mem shape %[[VAL_2]] : (!fir.shape<1>) -> !hlfir.expr<10xf32> {
+// CHECK: ^bb0(%[[VAL_4:.*]]: !fir.ref<!fir.array<10xf32>>):
+// CHECK: %[[VAL_5:.*]] = fir.call @_QParray_func() fastmath<contract> : () -> !fir.array<10xf32>
+// CHECK: fir.save_result %[[VAL_5]] to %[[VAL_4]](%[[VAL_2]]) : !fir.array<10xf32>, !fir.ref<!fir.array<10xf32>>, !fir.shape<1>
+// CHECK: }
+// CHECK: hlfir.assign %[[VAL_3]] to %[[VAL_1]] : !hlfir.expr<10xf32>, !fir.ref<!fir.array<10xf32>>
+// CHECK: hlfir.destroy %[[VAL_3]] : !hlfir.expr<10xf32>
+// CHECK: return
+// CHECK: }
+// CHECK: fir.global internal @_QFtestEx : !fir.array<10xf32>
+// CHECK: func.func private @_QParray_func() -> !fir.array<10xf32>
+
+// CHECK-LABEL: func.func @_QPtest_char() {
+// CHECK: %[[VAL_0:.*]] = arith.constant 10 : index
+// CHECK: %[[VAL_1:.*]] = arith.constant 5 : index
+// CHECK: %[[VAL_2:.*]] = fir.address_of(@_QFtest_charEx) : !fir.ref<!fir.array<10x!fir.char<1,5>>>
+// CHECK: %[[VAL_3:.*]] = fir.shape %[[VAL_0]] : (index) -> !fir.shape<1>
+// CHECK: %[[VAL_4:.*]] = hlfir.eval_in_mem shape %[[VAL_3]] typeparams %[[VAL_1]] : (!fir.shape<1>, index) -> !hlfir.expr<10x!fir.char<1,5>> {
+// CHECK: ^bb0(%[[VAL_5:.*]]: !fir.ref<!fir.array<10x!fir.char<1,5>>>):
+// CHECK: %[[VAL_6:.*]] = fir.call @_QPchar_array_func() fastmath<contract> : () -> !fir.array<10x!fir.char<1,5>>
+// CHECK: fir.save_result %[[VAL_6]] to %[[VAL_5]](%[[VAL_3]]) typeparams %[[VAL_1]] : !fir.array<10x!fir.char<1,5>>, !fir.ref<!fir.array<10x!fir.char<1,5>>>, !fir.shape<1>, index
+// CHECK: }
+// CHECK: hlfir.assign %[[VAL_4]] to %[[VAL_2]] : !hlfir.expr<10x!fir.char<1,5>>, !fir.ref<!fir.array<10x!fir.char<1,5>>>
+// CHECK: hlfir.destroy %[[VAL_4]] : !hlfir.expr<10x!fir.char<1,5>>
+// CHECK: return
+// CHECK: }
+// CHECK: fir.global internal @_QFtest_charEx : !fir.array<10x!fir.char<1,5>>
+// CHECK: func.func private @_QPchar_array_func() -> !fir.array<10x!fir.char<1,5>>
+
+// CHECK-LABEL: func.func @test_dynamic(
+// CHECK-SAME: %[[VAL_0:.*]]: !fir.box<!fir.array<?xf32>>,
+// CHECK-SAME: %[[VAL_1:.*]]: index) {
+// CHECK: %[[VAL_2:.*]] = fir.shape %[[VAL_1]] : (index) -> !fir.shape<1>
+// CHECK: %[[VAL_3:.*]] = hlfir.eval_in_mem shape %[[VAL_2]] : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
+// CHECK: ^bb0(%[[VAL_4:.*]]: !fir.ref<!fir.array<?xf32>>):
+// CHECK: %[[VAL_5:.*]] = fir.call @_QPdyn_array_func(%[[VAL_1]]) : (index) -> !fir.array<?xf32>
+// CHECK: fir.save_result %[[VAL_5]] to %[[VAL_4]](%[[VAL_2]]) : !fir.array<?xf32>, !fir.ref<!fir.array<?xf32>>, !fir.shape<1>
+// CHECK: }
+// CHECK: hlfir.assign %[[VAL_3]] to %[[VAL_0]] : !hlfir.expr<?xf32>, !fir.box<!fir.array<?xf32>>
+// CHECK: hlfir.destroy %[[VAL_3]] : !hlfir.expr<?xf32>
+// CHECK: return
+// CHECK: }
+// CHECK: func.func private @_QPdyn_array_func(index) -> !fir.array<?xf32>
diff --git a/flang/test/HLFIR/invalid.fir b/flang/test/HLFIR/invalid.fir
index c390dddcf3f387..5c5db7aac06970 100644
--- a/flang/test/HLFIR/invalid.fir
+++ b/flang/test/HLFIR/invalid.fir
@@ -1314,3 +1314,37 @@ func.func @end_associate_with_alloc_comp(%var: !hlfir.expr<?x!fir.type<_QMtypesT
hlfir.end_associate %4#1, %4#2 : !fir.ref<!fir.array<?x!fir.type<_QMtypesTt{x:!fir.box<!fir.heap<f32>>}>>>, i1
return
}
+
+// -----
+
+// Rejected: the result type must be an hlfir.expr.
+func.func @bad_eval_in_mem_1() {
+  %c10 = arith.constant 10 : index
+  %1 = fir.shape %c10 : (index) -> !fir.shape<1>
+  // expected-error@+1 {{'hlfir.eval_in_mem' op result #0 must be The type of an array, character, or derived type Fortran expression, but got '!fir.array<10xf32>'}}
+  %2 = hlfir.eval_in_mem shape %1 : (!fir.shape<1>) -> !fir.array<10xf32> {
+  ^bb0(%arg0: !fir.ref<!fir.array<10xf32>>):
+  }
+  return
+}
+
+// -----
+
+// Rejected: a rank-2 shape is provided for a rank-1 result.
+func.func @bad_eval_in_mem_2() {
+  %c10 = arith.constant 10 : index
+  %1 = fir.shape %c10, %c10 : (index, index) -> !fir.shape<2>
+  // expected-error@+1 {{'hlfir.eval_in_mem' op `shape` rank must match the result rank}}
+  %2 = hlfir.eval_in_mem shape %1 : (!fir.shape<2>) -> !hlfir.expr<10xf32> {
+  ^bb0(%arg0: !fir.ref<!fir.array<10xf32>>):
+  }
+  return
+}
+
+// -----
+
+// Rejected: a character result requires exactly one length type parameter.
+func.func @bad_eval_in_mem_3() {
+  // expected-error@+1 {{'hlfir.eval_in_mem' op must be provided one length parameter when the result is a character}}
+  %1 = hlfir.eval_in_mem : () -> !hlfir.expr<!fir.char<1,?>> {
+  ^bb0(%arg0: !fir.ref<!fir.char<1,?>>):
+  }
+  return
+}
>From 6be998ec6f74937820e350f51796490b41a76d46 Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Thu, 28 Nov 2024 08:26:39 -0800
Subject: [PATCH 2/3] [flang][hlfir] optimize hlfir.eval_in_mem bufferization
---
.../lib/Optimizer/Analysis/AliasAnalysis.cpp | 14 ++-
.../Transforms/OptimizedBufferization.cpp | 108 ++++++++++++++++++
.../HLFIR/opt-bufferization-eval_in_mem.fir | 67 +++++++++++
3 files changed, 188 insertions(+), 1 deletion(-)
create mode 100644 flang/test/HLFIR/opt-bufferization-eval_in_mem.fir
diff --git a/flang/lib/Optimizer/Analysis/AliasAnalysis.cpp b/flang/lib/Optimizer/Analysis/AliasAnalysis.cpp
index 2b24791d6c7c52..c561285b9feef5 100644
--- a/flang/lib/Optimizer/Analysis/AliasAnalysis.cpp
+++ b/flang/lib/Optimizer/Analysis/AliasAnalysis.cpp
@@ -91,6 +91,13 @@ bool AliasAnalysis::Source::isDummyArgument() const {
return false;
}
+/// Return true if \p v is the block argument of an hlfir.eval_in_mem body,
+/// i.e. the memory into which that operation evaluates its expression.
+static bool isEvaluateInMemoryBlockArg(mlir::Value v) {
+  // Only a block argument can be the eval_in_mem memory. Going through the
+  // owner block (rather than v.getParentRegion()->getParentOp()) also avoids
+  // dereferencing a null region when the block is not attached to a region.
+  auto blockArg = mlir::dyn_cast<mlir::BlockArgument>(v);
+  if (!blockArg)
+    return false;
+  auto evalInMem = llvm::dyn_cast_or_null<hlfir::EvaluateInMemoryOp>(
+      blockArg.getOwner()->getParentOp());
+  return evalInMem && evalInMem.getMemory() == v;
+}
+
bool AliasAnalysis::Source::isData() const { return origin.isData; }
bool AliasAnalysis::Source::isBoxData() const {
return mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(valueType)) &&
@@ -698,7 +705,7 @@ AliasAnalysis::Source AliasAnalysis::getSource(mlir::Value v,
breakFromLoop = true;
});
}
- if (!defOp && type == SourceKind::Unknown)
+ if (!defOp && type == SourceKind::Unknown) {
// Check if the memory source is coming through a dummy argument.
if (isDummyArgument(v)) {
type = SourceKind::Argument;
@@ -708,7 +715,12 @@ AliasAnalysis::Source AliasAnalysis::getSource(mlir::Value v,
if (isPointerReference(ty))
attributes.set(Attribute::Pointer);
+ } else if (isEvaluateInMemoryBlockArg(v)) {
+ // hlfir.eval_in_mem block operands is allocated by the operation.
+ type = SourceKind::Allocate;
+ ty = v.getType();
}
+ }
if (type == SourceKind::Global) {
return {{global, instantiationPoint, followingData},
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
index a0160b233e3cd1..e8c15a256b9da0 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
@@ -1108,6 +1108,113 @@ class ReductionMaskConversion : public mlir::OpRewritePattern<Op> {
}
};
+/// Rewrite pattern bufferizing hlfir.eval_in_mem operations. It first tries
+/// to evaluate the operation body directly into the LHS of the hlfir.assign
+/// that consumes the result. Otherwise it evaluates the body into a new
+/// temporary wrapped in an hlfir.as_expr.
+class EvaluateIntoMemoryAssignBufferization
+ : public mlir::OpRewritePattern<hlfir::EvaluateInMemoryOp> {
+
+public:
+ using mlir::OpRewritePattern<hlfir::EvaluateInMemoryOp>::OpRewritePattern;
+
+ llvm::LogicalResult
+ matchAndRewrite(hlfir::EvaluateInMemoryOp,
+ mlir::PatternRewriter &rewriter) const override;
+};
+
+/// Return true if some operation in `region` may read or write `var`
+/// according to fir::AliasAnalysis::getModRef. Operations with the
+/// HasRecursiveMemoryEffects trait are handled by visiting their nested
+/// regions.
+static bool mayReadOrWrite(mlir::Region &region, mlir::Value var) {
+ fir::AliasAnalysis aliasAnalysis;
+ for (mlir::Operation &op : region.getOps()) {
+ if (op.hasTrait<mlir::OpTrait::HasRecursiveMemoryEffects>()) {
+ for (mlir::Region &subRegion : op.getRegions())
+ if (mayReadOrWrite(subRegion, var))
+ return true;
+ // In MLIR, RecursiveMemoryEffects can be combined with
+ // MemoryEffectOpInterface to describe extra effects on top of the
+ // effects of the nested operations. However, the presence of
+ // RecursiveMemoryEffects and the absence of MemoryEffectOpInterface
+ // implies the operation has no other memory effects than the one of its
+ // nested operations.
+ if (!mlir::isa<mlir::MemoryEffectOpInterface>(op))
+ continue;
+ }
+ if (!aliasAnalysis.getModRef(&op, var).isNoModRef())
+ return true;
+ }
+ return false;
+}
+
+static llvm::LogicalResult
+tryUsingAssignLhsDirectly(hlfir::EvaluateInMemoryOp evalInMem,
+ mlir::PatternRewriter &rewriter) {
+ // Try to evaluate the hlfir.eval_in_mem body directly into the LHS of the
+ // hlfir.assign consuming its result, so that no temporary is needed. This
+ // requires the result to have exactly two users: one hlfir.assign and one
+ // hlfir.destroy, both of which are erased together with evalInMem.
+ mlir::Location loc = evalInMem.getLoc();
+ hlfir::DestroyOp destroy;
+ hlfir::AssignOp assign;
+ for (auto user : llvm::enumerate(evalInMem->getUsers())) {
+ // Bail out as soon as a third user is seen. Only the assign and destroy
+ // users are tracked and erased below, so any extra user (e.g. a second
+ // hlfir.assign) would still be using the result when evalInMem is erased.
+ if (user.index() > 1)
+ return mlir::failure();
+ mlir::TypeSwitch<mlir::Operation *, void>(user.value())
+ .Case([&](hlfir::AssignOp op) { assign = op; })
+ .Case([&](hlfir::DestroyOp op) { destroy = op; });
+ }
+ if (!assign || !destroy || destroy.mustFinalizeExpr() ||
+ assign.isAllocatableAssignment())
+ return mlir::failure();
+
+ hlfir::Entity lhs{assign.getLhs()};
+ // EvaluateInMemoryOp memory is contiguous, so in general, it can only be
+ // replaced by the LHS if the LHS is contiguous.
+ if (!lhs.isSimplyContiguous())
+ return mlir::failure();
+ // Character assignment may involve truncation/padding, so the LHS
+ // cannot be used to evaluate RHS in place without proving the LHS and
+ // RHS lengths are the same.
+ if (lhs.isCharacter())
+ return mlir::failure();
+
+ // The region must not read or write the LHS.
+ if (mayReadOrWrite(evalInMem.getBody(), lhs))
+ return mlir::failure();
+ // Any variables affected between the hlfir.evalInMem and assignment must not
+ // be read or written inside the region since it will be moved at the
+ // assignment insertion point.
+ auto effects = getEffectsBetween(evalInMem->getNextNode(), assign);
+ if (!effects) {
+ LLVM_DEBUG(
+ llvm::dbgs()
+ << "operation with unknown effects between eval_in_mem and assign\n");
+ return mlir::failure();
+ }
+ for (const mlir::MemoryEffects::EffectInstance &effect : *effects) {
+ mlir::Value affected = effect.getValue();
+ if (!affected || mayReadOrWrite(evalInMem.getBody(), affected))
+ return mlir::failure();
+ }
+
+ rewriter.setInsertionPoint(assign);
+ fir::FirOpBuilder builder(rewriter, evalInMem.getOperation());
+ mlir::Value rawLhs = hlfir::genVariableRawAddress(loc, builder, lhs);
+ hlfir::computeEvaluateOpIn(loc, builder, evalInMem, rawLhs);
+ rewriter.eraseOp(assign);
+ rewriter.eraseOp(destroy);
+ rewriter.eraseOp(evalInMem);
+ return mlir::success();
+}
+
+llvm::LogicalResult EvaluateIntoMemoryAssignBufferization::matchAndRewrite(
+ hlfir::EvaluateInMemoryOp evalInMem,
+ mlir::PatternRewriter &rewriter) const {
+ // Preferred rewrite: evaluate the body straight into the LHS of the
+ // assignment consuming the result, avoiding any temporary.
+ const bool evaluatedInLhs =
+ mlir::succeeded(tryUsingAssignLhsDirectly(evalInMem, rewriter));
+ if (evaluatedInLhs)
+ return mlir::success();
+
+ // Fallback: evaluate the body into a new temporary and wrap it in an
+ // hlfir.as_expr. The assign + as_expr pattern can then kick in for simple
+ // types, so that the assignment is at least implemented inline instead of
+ // calling the Assign runtime.
+ mlir::Location evalLoc = evalInMem.getLoc();
+ fir::FirOpBuilder firBuilder(rewriter, evalInMem.getOperation());
+ auto [tempStorage, onHeap] = hlfir::computeEvaluateOpInNewTemp(
+ evalLoc, firBuilder, evalInMem, evalInMem.getShape(),
+ evalInMem.getTypeparams());
+ // The temporary must be freed by the as_expr only if it is heap allocated.
+ mlir::Value mustFree = firBuilder.createBool(evalLoc, onHeap);
+ rewriter.replaceOpWithNewOp<hlfir::AsExprOp>(evalInMem, tempStorage,
+ /*mustFree=*/mustFree);
+ return mlir::success();
+}
+
class OptimizedBufferizationPass
: public hlfir::impl::OptimizedBufferizationBase<
OptimizedBufferizationPass> {
@@ -1130,6 +1237,7 @@ class OptimizedBufferizationPass
patterns.insert<ElementalAssignBufferization>(context);
patterns.insert<BroadcastAssignBufferization>(context);
patterns.insert<VariableAssignBufferization>(context);
+ patterns.insert<EvaluateIntoMemoryAssignBufferization>(context);
patterns.insert<ReductionConversion<hlfir::CountOp>>(context);
patterns.insert<ReductionConversion<hlfir::AnyOp>>(context);
patterns.insert<ReductionConversion<hlfir::AllOp>>(context);
diff --git a/flang/test/HLFIR/opt-bufferization-eval_in_mem.fir b/flang/test/HLFIR/opt-bufferization-eval_in_mem.fir
new file mode 100644
index 00000000000000..984c0bcbaddcc3
--- /dev/null
+++ b/flang/test/HLFIR/opt-bufferization-eval_in_mem.fir
@@ -0,0 +1,67 @@
+// RUN: fir-opt --opt-bufferization %s | FileCheck %s
+
+// Fortran 2023 (F2023) 15.5.2.14 point 4 ensures that _QPfoo cannot access
+// _QFtestEx and the temporary storage for the result can be avoided.
+// The eval_in_mem result below has exactly two users (one hlfir.assign and
+// one hlfir.destroy), and x is a contiguous non-CHARACTER array. The call
+// result is therefore expected to be saved directly into x.
+func.func @_QPtest(%arg0: !fir.ref<!fir.array<10xf32>> {fir.bindc_name = "x"}) {
+ %c10 = arith.constant 10 : index
+ %0 = fir.dummy_scope : !fir.dscope
+ %1 = fir.shape %c10 : (index) -> !fir.shape<1>
+ %2:2 = hlfir.declare %arg0(%1) dummy_scope %0 {uniq_name = "_QFtestEx"} : (!fir.ref<!fir.array<10xf32>>, !fir.shape<1>, !fir.dscope) -> (!fir.ref<!fir.array<10xf32>>, !fir.ref<!fir.array<10xf32>>)
+ %3 = hlfir.eval_in_mem shape %1 : (!fir.shape<1>) -> !hlfir.expr<10xf32> {
+ ^bb0(%arg1: !fir.ref<!fir.array<10xf32>>):
+ %4 = fir.call @_QPfoo() fastmath<contract> : () -> !fir.array<10xf32>
+ fir.save_result %4 to %arg1(%1) : !fir.array<10xf32>, !fir.ref<!fir.array<10xf32>>, !fir.shape<1>
+ }
+ hlfir.assign %3 to %2#0 : !hlfir.expr<10xf32>, !fir.ref<!fir.array<10xf32>>
+ hlfir.destroy %3 : !hlfir.expr<10xf32>
+ return
+}
+// Body-less callee returning an array by value, shared by both tests.
+func.func private @_QPfoo() -> !fir.array<10xf32>
+
+// CHECK-LABEL: func.func @_QPtest(
+// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref<!fir.array<10xf32>> {fir.bindc_name = "x"}) {
+// CHECK: %[[VAL_1:.*]] = arith.constant 10 : index
+// CHECK: %[[VAL_2:.*]] = fir.dummy_scope : !fir.dscope
+// CHECK: %[[VAL_3:.*]] = fir.shape %[[VAL_1]] : (index) -> !fir.shape<1>
+// CHECK: %[[VAL_4:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_3]]) dummy_scope %[[VAL_2]] {uniq_name = "_QFtestEx"} : (!fir.ref<!fir.array<10xf32>>, !fir.shape<1>, !fir.dscope) -> (!fir.ref<!fir.array<10xf32>>, !fir.ref<!fir.array<10xf32>>)
+// CHECK: %[[VAL_5:.*]] = fir.call @_QPfoo() fastmath<contract> : () -> !fir.array<10xf32>
+// CHECK: fir.save_result %[[VAL_5]] to %[[VAL_4]]#1(%[[VAL_3]]) : !fir.array<10xf32>, !fir.ref<!fir.array<10xf32>>, !fir.shape<1>
+// CHECK: return
+// CHECK: }
+
+
+// Temporary storage cannot be avoided in this case since
+// _QFnegative_test_is_targetEx has the TARGET attribute.
+// The eval_in_mem is expected to be rewritten into a temporary (fir.alloca)
+// wrapped in hlfir.as_expr. Its assignment to x is then inlined as an
+// element-wise loop, as shown in the expected output below.
+func.func @_QPnegative_test_is_target(%arg0: !fir.ref<!fir.array<10xf32>> {fir.bindc_name = "x", fir.target}) {
+ %c10 = arith.constant 10 : index
+ %0 = fir.dummy_scope : !fir.dscope
+ %1 = fir.shape %c10 : (index) -> !fir.shape<1>
+ %2:2 = hlfir.declare %arg0(%1) dummy_scope %0 {fortran_attrs = #fir.var_attrs<target>, uniq_name = "_QFnegative_test_is_targetEx"} : (!fir.ref<!fir.array<10xf32>>, !fir.shape<1>, !fir.dscope) -> (!fir.ref<!fir.array<10xf32>>, !fir.ref<!fir.array<10xf32>>)
+ %3 = hlfir.eval_in_mem shape %1 : (!fir.shape<1>) -> !hlfir.expr<10xf32> {
+ ^bb0(%arg1: !fir.ref<!fir.array<10xf32>>):
+ %4 = fir.call @_QPfoo() fastmath<contract> : () -> !fir.array<10xf32>
+ fir.save_result %4 to %arg1(%1) : !fir.array<10xf32>, !fir.ref<!fir.array<10xf32>>, !fir.shape<1>
+ }
+ hlfir.assign %3 to %2#0 : !hlfir.expr<10xf32>, !fir.ref<!fir.array<10xf32>>
+ hlfir.destroy %3 : !hlfir.expr<10xf32>
+ return
+}
+// CHECK-LABEL: func.func @_QPnegative_test_is_target(
+// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref<!fir.array<10xf32>> {fir.bindc_name = "x", fir.target}) {
+// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_2:.*]] = arith.constant false
+// CHECK: %[[VAL_3:.*]] = arith.constant 10 : index
+// CHECK: %[[VAL_4:.*]] = fir.alloca !fir.array<10xf32>
+// CHECK: %[[VAL_7:.*]]:2 = hlfir.declare %[[VAL_0]]{{.*}}
+// CHECK: %[[VAL_8:.*]]:2 = hlfir.declare %[[VAL_4]]{{.*}}
+// CHECK: %[[VAL_9:.*]] = fir.call @_QPfoo() fastmath<contract> : () -> !fir.array<10xf32>
+// CHECK: fir.save_result %[[VAL_9]] to %[[VAL_8]]#1{{.*}}
+// CHECK: %[[VAL_10:.*]] = hlfir.as_expr %[[VAL_8]]#0 move %[[VAL_2]] : (!fir.ref<!fir.array<10xf32>>, i1) -> !hlfir.expr<10xf32>
+// CHECK: fir.do_loop %[[VAL_11:.*]] = %[[VAL_1]] to %[[VAL_3]] step %[[VAL_1]] unordered {
+// CHECK: %[[VAL_12:.*]] = hlfir.apply %[[VAL_10]], %[[VAL_11]] : (!hlfir.expr<10xf32>, index) -> f32
+// CHECK: %[[VAL_13:.*]] = hlfir.designate %[[VAL_7]]#0 (%[[VAL_11]]) : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
+// CHECK: hlfir.assign %[[VAL_12]] to %[[VAL_13]] : f32, !fir.ref<f32>
+// CHECK: }
+// CHECK: hlfir.destroy %[[VAL_10]] : !hlfir.expr<10xf32>
+// CHECK: return
+// CHECK: }
>From d68b5b2652831cda053c00465b33039bc645bc02 Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Mon, 2 Dec 2024 01:59:33 -0800
Subject: [PATCH 3/3] PR118069 comment: mayReadOrWrite to getModRef
---
.../flang/Optimizer/Analysis/AliasAnalysis.h | 6 +++
.../lib/Optimizer/Analysis/AliasAnalysis.cpp | 27 ++++++++++++++
.../Transforms/OptimizedBufferization.cpp | 37 ++++++-------------
3 files changed, 45 insertions(+), 25 deletions(-)
diff --git a/flang/include/flang/Optimizer/Analysis/AliasAnalysis.h b/flang/include/flang/Optimizer/Analysis/AliasAnalysis.h
index e410831c0fc3eb..8d17e4e476d10d 100644
--- a/flang/include/flang/Optimizer/Analysis/AliasAnalysis.h
+++ b/flang/include/flang/Optimizer/Analysis/AliasAnalysis.h
@@ -198,6 +198,12 @@ struct AliasAnalysis {
/// Return the modify-reference behavior of `op` on `location`.
mlir::ModRefResult getModRef(mlir::Operation *op, mlir::Value location);
+ /// Return the modify-reference behavior of operations inside `region` on
+ /// `location`. Unlike getModRef(operation, location), this visits nested
+ /// regions recursively according to the HasRecursiveMemoryEffects trait.
+ mlir::ModRefResult getModRef(mlir::Region &region, mlir::Value location);
+
/// Return the memory source of a value.
/// If getLastInstantiationPoint is true, the search for the source
/// will stop at [hl]fir.declare if it represents a dummy
diff --git a/flang/lib/Optimizer/Analysis/AliasAnalysis.cpp b/flang/lib/Optimizer/Analysis/AliasAnalysis.cpp
index c561285b9feef5..0b0f83d024ce33 100644
--- a/flang/lib/Optimizer/Analysis/AliasAnalysis.cpp
+++ b/flang/lib/Optimizer/Analysis/AliasAnalysis.cpp
@@ -464,6 +464,33 @@ ModRefResult AliasAnalysis::getModRef(Operation *op, Value location) {
return result;
}
+// Merge the mod-ref results of all operations in `region` on `location`,
+// recursing into the regions of operations that have the
+// HasRecursiveMemoryEffects trait.
+ModRefResult AliasAnalysis::getModRef(mlir::Region &region,
+ mlir::Value location) {
+ ModRefResult result = ModRefResult::getNoModRef();
+ for (mlir::Operation &op : region.getOps()) {
+ if (op.hasTrait<mlir::OpTrait::HasRecursiveMemoryEffects>()) {
+ for (mlir::Region &subRegion : op.getRegions()) {
+ result = result.merge(getModRef(subRegion, location));
+ // Fast return if already mod and ref: merging more results cannot
+ // make it any more conservative.
+ if (result.isModAndRef())
+ return result;
+ }
+ // In MLIR, RecursiveMemoryEffects can be combined with
+ // MemoryEffectOpInterface to describe extra effects on top of the
+ // effects of the nested operations. However, the presence of
+ // RecursiveMemoryEffects and the absence of MemoryEffectOpInterface
+ // implies the operation has no other memory effects than the one of its
+ // nested operations.
+ if (!mlir::isa<mlir::MemoryEffectOpInterface>(op))
+ continue;
+ }
+ result = result.merge(getModRef(&op, location));
+ if (result.isModAndRef())
+ return result;
+ }
+ return result;
+}
+
AliasAnalysis::Source::Attributes
getAttrsFromVariable(fir::FortranVariableOpInterface var) {
AliasAnalysis::Source::Attributes attrs;
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
index e8c15a256b9da0..9327e7ad5875cf 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
@@ -1119,28 +1119,6 @@ class EvaluateIntoMemoryAssignBufferization
mlir::PatternRewriter &rewriter) const override;
};
-/// Return true if some operation in `region` may read or write `var`
-/// according to fir::AliasAnalysis::getModRef. Operations with the
-/// HasRecursiveMemoryEffects trait are handled by visiting their nested
-/// regions.
-static bool mayReadOrWrite(mlir::Region &region, mlir::Value var) {
- fir::AliasAnalysis aliasAnalysis;
- for (mlir::Operation &op : region.getOps()) {
- if (op.hasTrait<mlir::OpTrait::HasRecursiveMemoryEffects>()) {
- for (mlir::Region &subRegion : op.getRegions())
- if (mayReadOrWrite(subRegion, var))
- return true;
- // In MLIR, RecursiveMemoryEffects can be combined with
- // MemoryEffectOpInterface to describe extra effects on top of the
- // effects of the nested operations. However, the presence of
- // RecursiveMemoryEffects and the absence of MemoryEffectOpInterface
- // implies the operation has no other memory effects than the one of its
- // nested operations.
- if (!mlir::isa<mlir::MemoryEffectOpInterface>(op))
- continue;
- }
- if (!aliasAnalysis.getModRef(&op, var).isNoModRef())
- return true;
- }
- return false;
-}
-
static llvm::LogicalResult
tryUsingAssignLhsDirectly(hlfir::EvaluateInMemoryOp evalInMem,
mlir::PatternRewriter &rewriter) {
@@ -1168,9 +1146,17 @@ tryUsingAssignLhsDirectly(hlfir::EvaluateInMemoryOp evalInMem,
// RHS lengths are the same.
if (lhs.isCharacter())
return mlir::failure();
-
+ fir::AliasAnalysis aliasAnalysis;
// The region must not read or write the LHS.
- if (mayReadOrWrite(evalInMem.getBody(), lhs))
+ // Note that getModRef is used instead of mlir::MemoryEffects because
+ // EvaluateInMemoryOp is typically expected to hold fir.calls and that
+ // Fortran calls cannot be modeled in a useful way with mlir::MemoryEffects:
+ // it is hard/impossible to list all the read/written SSA values in a call,
+ // but it is often possible to tell that an SSA value cannot be accessed,
+ // hence getModRef is needed here and below. Also note that getModRef uses
+ // mlir::MemoryEffects for operations that do not have special handling in
+ // getModRef.
+ if (aliasAnalysis.getModRef(evalInMem.getBody(), lhs).isModOrRef())
return mlir::failure();
// Any variables affected between the hlfir.evalInMem and assignment must not
// be read or written inside the region since it will be moved at the
@@ -1184,7 +1170,8 @@ tryUsingAssignLhsDirectly(hlfir::EvaluateInMemoryOp evalInMem,
}
for (const mlir::MemoryEffects::EffectInstance &effect : *effects) {
mlir::Value affected = effect.getValue();
- if (!affected || mayReadOrWrite(evalInMem.getBody(), affected))
+ if (!affected ||
+ aliasAnalysis.getModRef(evalInMem.getBody(), affected).isModOrRef())
return mlir::failure();
}
More information about the flang-commits
mailing list