[flang-commits] [flang] [flang][rfc] Add representation of volatile references (PR #132486)
via flang-commits
flang-commits at lists.llvm.org
Fri Mar 21 16:13:27 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-fir-hlfir
Author: Asher Mancinelli (ashermancinelli)
<details>
<summary>Changes</summary>
[RFC on discourse](https://discourse.llvm.org/t/rfc-volatile-representation-in-flang/85404/1)
Flang currently lacks support for volatile variables: in some cases the compiler emits TODO error messages, and in others the VOLATILE attribute is ignored. Some of our tests resemble the example in _C.4 Clause 8 notes: The VOLATILE attribute (8.5.20)_ and require volatile variables.
This change is a minimal draft of volatility support in Fortran. This PR does not yet include some important features, such as volatility on boxes and on other reference-like types that are not plain references. This commit supports volatility only for `!fir.ref<T>`. It is the minimum needed to get end-to-end examples working, so that we can judge whether this is the right direction.
If this is the right direction, I'll break this change into a few smaller chunks, add more tests, and submit them as separate, smaller PRs.
---
Patch is 23.70 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/132486.diff
15 Files Affected:
- (modified) flang/include/flang/Optimizer/Builder/FIRBuilder.h (+1-1)
- (modified) flang/include/flang/Optimizer/Dialect/FIRType.h (+6)
- (modified) flang/include/flang/Optimizer/Dialect/FIRTypes.td (+7-3)
- (modified) flang/lib/Lower/CallInterface.cpp (-1)
- (modified) flang/lib/Lower/ConvertExprToHLFIR.cpp (+42-6)
- (modified) flang/lib/Optimizer/Builder/FIRBuilder.cpp (+2-2)
- (modified) flang/lib/Optimizer/Builder/HLFIRTools.cpp (+2-1)
- (modified) flang/lib/Optimizer/CodeGen/CodeGen.cpp (+12-5)
- (modified) flang/lib/Optimizer/Dialect/FIRType.cpp (+43-14)
- (modified) flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp (+7)
- (modified) flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp (+5-2)
- (modified) flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp (+6-3)
- (added) flang/test/Fir/volatile.fir (+18)
- (added) flang/test/Integration/volatile.f90 (+11)
- (added) flang/test/Lower/volatile.fir (+21)
``````````diff
diff --git a/flang/include/flang/Optimizer/Builder/FIRBuilder.h b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
index 003b4358572c1..870709a5d55b6 100644
--- a/flang/include/flang/Optimizer/Builder/FIRBuilder.h
+++ b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
@@ -150,7 +150,7 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
mlir::Block *getAllocaBlock();
/// Safely create a reference type to the type `eleTy`.
- mlir::Type getRefType(mlir::Type eleTy);
+ mlir::Type getRefType(mlir::Type eleTy, bool isVolatile = false);
/// Create a sequence of `eleTy` with `rank` dimensions of unknown size.
mlir::Type getVarLenSeqTy(mlir::Type eleTy, unsigned rank = 1);
diff --git a/flang/include/flang/Optimizer/Dialect/FIRType.h b/flang/include/flang/Optimizer/Dialect/FIRType.h
index 76e0aa352bcd9..8261c67e4559d 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRType.h
+++ b/flang/include/flang/Optimizer/Dialect/FIRType.h
@@ -111,6 +111,12 @@ inline bool isa_ref_type(mlir::Type t) {
fir::LLVMPointerType>(t);
}
+inline bool isa_volatile_ref_type(mlir::Type t) {
+ if (auto refTy = mlir::dyn_cast_or_null<fir::ReferenceType>(t))
+ return refTy.isVolatile();
+ return false;
+}
+
/// Is `t` a boxed type?
inline bool isa_box_type(mlir::Type t) {
return mlir::isa<fir::BaseBoxType, fir::BoxCharType, fir::BoxProcType>(t);
diff --git a/flang/include/flang/Optimizer/Dialect/FIRTypes.td b/flang/include/flang/Optimizer/Dialect/FIRTypes.td
index fd5bbbe44751f..0584c175b36ff 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRTypes.td
+++ b/flang/include/flang/Optimizer/Dialect/FIRTypes.td
@@ -363,18 +363,22 @@ def fir_ReferenceType : FIR_Type<"Reference", "ref"> {
The type of a reference to an entity in memory.
}];
- let parameters = (ins "mlir::Type":$eleTy);
+ let parameters = (ins
+ "mlir::Type":$eleTy,
+ DefaultValuedParameter<"bool", "false">:$isVol);
let skipDefaultBuilders = 1;
let builders = [
- TypeBuilderWithInferredContext<(ins "mlir::Type":$elementType), [{
- return Base::get(elementType.getContext(), elementType);
+ TypeBuilderWithInferredContext<(ins "mlir::Type":$elementType, CArg<"bool", "false">:$isVol), [{
+ return Base::get(elementType.getContext(), elementType, isVol);
}]>,
];
let extraClassDeclaration = [{
mlir::Type getElementType() const { return getEleTy(); }
+ bool isVolatile() const { return (bool)getIsVol(); }
+ static llvm::StringRef getVolatileKeyword() { return "volatile"; }
}];
let genVerifyDecl = 1;
diff --git a/flang/lib/Lower/CallInterface.cpp b/flang/lib/Lower/CallInterface.cpp
index 226ba1e52c968..4ee28fbeb9a0c 100644
--- a/flang/lib/Lower/CallInterface.cpp
+++ b/flang/lib/Lower/CallInterface.cpp
@@ -1112,7 +1112,6 @@ class Fortran::lower::CallInterfaceImpl {
if (obj.attrs.test(Attrs::Value))
isValueAttr = true; // TODO: do we want an mlir::Attribute as well?
if (obj.attrs.test(Attrs::Volatile)) {
- TODO(loc, "VOLATILE in procedure interface");
addMLIRAttr(fir::getVolatileAttrName());
}
// obj.attrs.test(Attrs::Asynchronous) does not impact the way the argument
diff --git a/flang/lib/Lower/ConvertExprToHLFIR.cpp b/flang/lib/Lower/ConvertExprToHLFIR.cpp
index dc00e0b13f583..3ac10596df5ae 100644
--- a/flang/lib/Lower/ConvertExprToHLFIR.cpp
+++ b/flang/lib/Lower/ConvertExprToHLFIR.cpp
@@ -223,8 +223,37 @@ class HlfirDesignatorBuilder {
designatorNode, getConverter().getFoldingContext(),
/*namedConstantSectionsAreAlwaysContiguous=*/false))
return fir::BoxType::get(resultValueType);
+
+ bool isVolatile = false;
+
+ // Check if the base type is volatile
+ if (partInfo.base.has_value()) {
+ mlir::Type baseType = partInfo.base.value().getType();
+ isVolatile = fir::isa_volatile_ref_type(baseType);
+ }
+
+ auto isVolatileSymbol = [](const Fortran::semantics::Symbol &symbol) {
+ return symbol.GetUltimate().attrs().test(
+ Fortran::semantics::Attr::VOLATILE);
+ };
+
+ // Check if this should be a volatile reference
+ if constexpr (std::is_same_v<std::decay_t<T>,
+ Fortran::evaluate::SymbolRef>) {
+ if (isVolatileSymbol(designatorNode.get()))
+ isVolatile = true;
+ } else if constexpr (std::is_same_v<std::decay_t<T>,
+ Fortran::evaluate::Component>) {
+ if (isVolatileSymbol(designatorNode.GetLastSymbol()))
+ isVolatile = true;
+ }
+
+ // If it's a reference to a ref, account for it
+ if (auto refTy = mlir::dyn_cast<fir::ReferenceType>(resultValueType))
+ resultValueType = refTy.getEleTy();
+
// Other designators can be handled as raw addresses.
- return fir::ReferenceType::get(resultValueType);
+ return fir::ReferenceType::get(resultValueType, isVolatile);
}
template <typename T>
@@ -414,10 +443,16 @@ class HlfirDesignatorBuilder {
.Case<fir::SequenceType>([&](fir::SequenceType seqTy) -> mlir::Type {
return fir::SequenceType::get(seqTy.getShape(), newEleTy);
})
- .Case<fir::PointerType, fir::HeapType, fir::ReferenceType, fir::BoxType,
- fir::ClassType>([&](auto t) -> mlir::Type {
- using FIRT = decltype(t);
- return FIRT::get(changeElementType(t.getEleTy(), newEleTy));
+ // TODO: handle volatility for other types
+ .Case<fir::PointerType, fir::HeapType, fir::BoxType, fir::ClassType>(
+ [&](auto t) -> mlir::Type {
+ using FIRT = decltype(t);
+ return FIRT::get(changeElementType(t.getEleTy(), newEleTy));
+ })
+ .Case<fir::ReferenceType>([&](fir::ReferenceType refTy) -> mlir::Type {
+ return fir::ReferenceType::get(
+ changeElementType(refTy.getEleTy(), newEleTy),
+ refTy.isVolatile());
})
.Default([newEleTy](mlir::Type t) -> mlir::Type { return newEleTy; });
}
@@ -1808,6 +1843,7 @@ class HlfirBuilder {
auto &expr = std::get<const Fortran::lower::SomeExpr &>(iter);
auto &baseOp = std::get<hlfir::EntityWithAttributes>(iter);
std::string name = converter.getRecordTypeFieldName(sym);
+ const bool isVolatile = fir::isa_volatile_ref_type(baseOp.getType());
// Generate DesignateOp for the component.
// The designator's result type is just a reference to the component type,
@@ -1818,7 +1854,7 @@ class HlfirBuilder {
assert(compType && "failed to retrieve component type");
mlir::Value compShape =
designatorBuilder.genComponentShape(sym, compType);
- mlir::Type designatorType = builder.getRefType(compType);
+ mlir::Type designatorType = builder.getRefType(compType, isVolatile);
mlir::Type fieldElemType = hlfir::getFortranElementType(compType);
llvm::SmallVector<mlir::Value, 1> typeParams;
diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
index b7f8a8d3a9d56..02ded29606885 100644
--- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp
+++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
@@ -104,9 +104,9 @@ fir::FirOpBuilder::getNamedGlobal(mlir::ModuleOp modOp,
return modOp.lookupSymbol<fir::GlobalOp>(name);
}
-mlir::Type fir::FirOpBuilder::getRefType(mlir::Type eleTy) {
+mlir::Type fir::FirOpBuilder::getRefType(mlir::Type eleTy, bool isVolatile) {
assert(!mlir::isa<fir::ReferenceType>(eleTy) && "cannot be a reference type");
- return fir::ReferenceType::get(eleTy);
+ return fir::ReferenceType::get(eleTy, isVolatile);
}
mlir::Type fir::FirOpBuilder::getVarLenSeqTy(mlir::Type eleTy, unsigned rank) {
diff --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
index 85fd742db6beb..aec88ec97b514 100644
--- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp
+++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
@@ -819,7 +819,8 @@ mlir::Type hlfir::getVariableElementType(hlfir::Entity variable) {
} else if (fir::isRecordWithTypeParameters(eleTy)) {
return fir::BoxType::get(eleTy);
}
- return fir::ReferenceType::get(eleTy);
+ const bool isVolatile = fir::isa_volatile_ref_type(variable.getType());
+ return fir::ReferenceType::get(eleTy, isVolatile);
}
mlir::Type hlfir::getEntityElementType(hlfir::Entity entity) {
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index b54b497ee4ba1..90f2474dafca3 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -3224,6 +3224,8 @@ struct LoadOpConversion : public fir::FIROpConversion<fir::LoadOp> {
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::Type llvmLoadTy = convertObjectType(load.getType());
+ const bool isVolatile =
+ fir::isa_volatile_ref_type(load.getMemref().getType());
if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(load.getType())) {
// fir.box is a special case because it is considered an ssa value in
// fir, but it is lowered as a pointer to a descriptor. So
@@ -3253,7 +3255,7 @@ struct LoadOpConversion : public fir::FIROpConversion<fir::LoadOp> {
mlir::Value boxSize =
computeBoxSize(loc, boxTypePair, inputBoxStorage, rewriter);
auto memcpy = rewriter.create<mlir::LLVM::MemcpyOp>(
- loc, newBoxStorage, inputBoxStorage, boxSize, /*isVolatile=*/false);
+ loc, newBoxStorage, inputBoxStorage, boxSize, isVolatile);
if (std::optional<mlir::ArrayAttr> optionalTag = load.getTbaa())
memcpy.setTBAATags(*optionalTag);
@@ -3261,8 +3263,10 @@ struct LoadOpConversion : public fir::FIROpConversion<fir::LoadOp> {
attachTBAATag(memcpy, boxTy, boxTy, nullptr);
rewriter.replaceOp(load, newBoxStorage);
} else {
+ // TODO: are we losing any attributes from the load op?
+ auto memref = adaptor.getOperands()[0];
auto loadOp = rewriter.create<mlir::LLVM::LoadOp>(
- load.getLoc(), llvmLoadTy, adaptor.getOperands(), load->getAttrs());
+ load.getLoc(), llvmLoadTy, memref, /*alignment=*/0, isVolatile);
if (std::optional<mlir::ArrayAttr> optionalTag = load.getTbaa())
loadOp.setTBAATags(*optionalTag);
else
@@ -3540,6 +3544,8 @@ struct StoreOpConversion : public fir::FIROpConversion<fir::StoreOp> {
mlir::Value llvmValue = adaptor.getValue();
mlir::Value llvmMemref = adaptor.getMemref();
mlir::LLVM::AliasAnalysisOpInterface newOp;
+ const bool isVolatile =
+ fir::isa_volatile_ref_type(store.getMemref().getType());
if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(storeTy)) {
mlir::Type llvmBoxTy = lowerTy().convertBoxTypeAsStruct(boxTy);
// Always use memcpy because LLVM is not as effective at optimizing
@@ -3547,10 +3553,11 @@ struct StoreOpConversion : public fir::FIROpConversion<fir::StoreOp> {
TypePair boxTypePair{boxTy, llvmBoxTy};
mlir::Value boxSize =
computeBoxSize(loc, boxTypePair, llvmValue, rewriter);
- newOp = rewriter.create<mlir::LLVM::MemcpyOp>(
- loc, llvmMemref, llvmValue, boxSize, /*isVolatile=*/false);
+ newOp = rewriter.create<mlir::LLVM::MemcpyOp>(loc, llvmMemref, llvmValue,
+ boxSize, isVolatile);
} else {
- newOp = rewriter.create<mlir::LLVM::StoreOp>(loc, llvmValue, llvmMemref);
+ newOp = rewriter.create<mlir::LLVM::StoreOp>(loc, llvmValue, llvmMemref,
+ /*alignment=*/0, isVolatile);
}
if (std::optional<mlir::ArrayAttr> optionalTag = store.getTbaa())
newOp.setTBAATags(*optionalTag);
diff --git a/flang/lib/Optimizer/Dialect/FIRType.cpp b/flang/lib/Optimizer/Dialect/FIRType.cpp
index f3f969ba401e5..90942522d9073 100644
--- a/flang/lib/Optimizer/Dialect/FIRType.cpp
+++ b/flang/lib/Optimizer/Dialect/FIRType.cpp
@@ -649,12 +649,17 @@ mlir::Type changeElementType(mlir::Type type, mlir::Type newElementType,
.Case<fir::SequenceType>([&](fir::SequenceType seqTy) -> mlir::Type {
return fir::SequenceType::get(seqTy.getShape(), newElementType);
})
- .Case<fir::PointerType, fir::HeapType, fir::ReferenceType,
- fir::ClassType>([&](auto t) -> mlir::Type {
- using FIRT = decltype(t);
- return FIRT::get(
- changeElementType(t.getEleTy(), newElementType, turnBoxIntoClass));
+ .Case<fir::ReferenceType>([&](fir::ReferenceType refTy) -> mlir::Type {
+ auto newEleTy = changeElementType(refTy.getEleTy(), newElementType,
+ turnBoxIntoClass);
+ return fir::ReferenceType::get(newEleTy, refTy.isVolatile());
})
+ .Case<fir::PointerType, fir::HeapType, fir::ClassType>(
+ [&](auto t) -> mlir::Type {
+ using FIRT = decltype(t);
+ return FIRT::get(changeElementType(t.getEleTy(), newElementType,
+ turnBoxIntoClass));
+ })
.Case<fir::BoxType>([&](fir::BoxType t) -> mlir::Type {
mlir::Type newInnerType =
changeElementType(t.getEleTy(), newElementType, false);
@@ -1057,18 +1062,38 @@ unsigned fir::RecordType::getFieldIndex(llvm::StringRef ident) {
// ReferenceType
//===----------------------------------------------------------------------===//
-// `ref` `<` type `>`
+// `ref` `<` type (`, volatile` $volatile^)? `>`
mlir::Type fir::ReferenceType::parse(mlir::AsmParser &parser) {
- return parseTypeSingleton<fir::ReferenceType>(parser);
+ if (parser.parseLess())
+ return {};
+
+ mlir::Type eleTy;
+ if (parser.parseType(eleTy))
+ return {};
+
+ bool isVolatile = false;
+ if (!parser.parseOptionalComma()) {
+ if (parser.parseKeyword(getVolatileKeyword())) {
+ return {};
+ }
+ isVolatile = true;
+ }
+
+ if (parser.parseGreater())
+ return {};
+ return get(eleTy, isVolatile);
}
void fir::ReferenceType::print(mlir::AsmPrinter &printer) const {
- printer << "<" << getEleTy() << '>';
+ printer << "<" << getEleTy();
+ if (isVolatile())
+ printer << ", " << getVolatileKeyword();
+ printer << '>';
}
llvm::LogicalResult fir::ReferenceType::verify(
- llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
- mlir::Type eleTy) {
+ llvm::function_ref<mlir::InFlightDiagnostic()> emitError, mlir::Type eleTy,
+ bool isVolatile) {
if (mlir::isa<ShapeType, ShapeShiftType, SliceType, FieldType, LenType,
ReferenceType, TypeDescType>(eleTy))
return emitError() << "cannot build a reference to type: " << eleTy << '\n';
@@ -1319,11 +1344,15 @@ changeTypeShape(mlir::Type type,
return fir::SequenceType::get(*newShape, seqTy.getEleTy());
return seqTy.getEleTy();
})
- .Case<fir::PointerType, fir::HeapType, fir::ReferenceType, fir::BoxType,
- fir::ClassType>([&](auto t) -> mlir::Type {
- using FIRT = decltype(t);
- return FIRT::get(changeTypeShape(t.getEleTy(), newShape));
+ .Case<fir::ReferenceType>([&](fir::ReferenceType rt) -> mlir::Type {
+ return fir::ReferenceType::get(changeTypeShape(rt.getEleTy(), newShape),
+ rt.isVolatile());
})
+ .Case<fir::PointerType, fir::HeapType, fir::BoxType, fir::ClassType>(
+ [&](auto t) -> mlir::Type {
+ using FIRT = decltype(t);
+ return FIRT::get(changeTypeShape(t.getEleTy(), newShape));
+ })
.Default([&](mlir::Type t) -> mlir::Type {
assert((fir::isa_trivial(t) || llvm::isa<fir::RecordType>(t) ||
llvm::isa<mlir::NoneType>(t) ||
diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
index 8851a3a7187b9..4a3308ff4e747 100644
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
@@ -214,6 +214,13 @@ void hlfir::DeclareOp::build(mlir::OpBuilder &builder,
auto nameAttr = builder.getStringAttr(uniq_name);
mlir::Type inputType = memref.getType();
bool hasExplicitLbs = hasExplicitLowerBounds(shape);
+ if (fortran_attrs && mlir::isa<fir::ReferenceType>(inputType) &&
+ bitEnumContainsAny(fortran_attrs.getFlags(),
+ fir::FortranVariableFlagsEnum::fortran_volatile)) {
+ auto refType = mlir::cast<fir::ReferenceType>(inputType);
+ inputType = fir::ReferenceType::get(refType.getEleTy(), true);
+ memref = builder.create<fir::ConvertOp>(memref.getLoc(), inputType, memref);
+ }
mlir::Type hlfirVariableType =
getHLFIRVariableType(inputType, hasExplicitLbs);
build(builder, result, {hlfirVariableType, inputType}, memref, shape,
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp
index 496a5560ac615..aa151f90ed0d1 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp
@@ -418,7 +418,9 @@ class DesignateOpConversion
firstElementIndices.push_back(indices[i]);
i = i + (isTriplet ? 3 : 1);
}
- mlir::Type arrayCoorType = fir::ReferenceType::get(baseEleTy);
+ mlir::Type originalDesignateType = designate.getResult().getType();
+ const bool isVolatile = fir::isa_volatile_ref_type(originalDesignateType);
+ mlir::Type arrayCoorType = fir::ReferenceType::get(baseEleTy, isVolatile);
base = builder.create<fir::ArrayCoorOp>(
loc, arrayCoorType, base, shape,
/*slice=*/mlir::Value{}, firstElementIndices, firBaseTypeParameters);
@@ -441,6 +443,7 @@ class DesignateOpConversion
TODO(loc, "hlfir::designate load of pointer or allocatable");
mlir::Type designateResultType = designate.getResult().getType();
+ const bool isVolatile = fir::isa_volatile_ref_type(designateResultType);
llvm::SmallVector<mlir::Value> firBaseTypeParameters;
auto [base, shape] = hlfir::genVariableFirBaseShapeAndParams(
loc, builder, baseEntity, firBaseTypeParameters);
@@ -464,7 +467,7 @@ class DesignateOpConversion
mlir::Type componentType =
mlir::cast<fir::RecordType>(baseEleTy).getType(
designate.getComponent().value());
- mlir::Type coorTy = fir::ReferenceType::get(componentType);
+ mlir::Type coorTy = fir::ReferenceType::get(componentType, isVolatile);
base = builder.create<fir::CoordinateOp>(loc, coorTy, base, fieldIndex);
if (mlir::isa<fir::BaseBoxType>(componentType)) {
auto variableInterface = mlir::cast<fir::FortranVariableOpInterface>(
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
index 96a3622f4afee..020915179a670 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
@@ -1126,7 +1126,8 @@ class ReductionMaskConversion : public mlir::OpRewritePattern<Op> {
builder.create<fir::StoreOp>(loc, flagSet, flagRef);
mlir::Type resultElemTy =
hlfir::getFortranElementType(resultArr.getType());
- mlir::Type returnRefTy = builder.getRefType(resultElemTy);
+ mlir::Type returnRefTy = builder.getRefType(
+ resultElemTy, fir::isa_volatile_ref_type(flagRef.getType()));
mlir::IndexType idxTy = builder.getIndexType();
for (unsigned int i = 0; i < rank; ++i) {
@@ -1153,7 +1154,8 @@ class ReductionMaskConversion : public mlir::OpRewritePattern<Op> {
auto getAddrFn = [](fir::FirOpBuilder builder, mlir::Location loc,
const mlir::Type &resultElemType, mlir::Value resultArr,
mlir::Value index) {
- mlir::Type resultRefTy = builder.getRefType(resultElemType);
+ mlir::Type resultRefTy = builder.getRefType(
+ resultElemType, fir::isa_volatile_ref_type(resultArr.getType()));
mlir::Value oneIdx =
builder.createIntegerConstant(loc, builder.getIndexType(), 1);
index = builder.create<mlir::arith::AddIOp>(loc, index, oneIdx);
@@ -1162,8 +1164,9 @@ class ReductionMaskConversion : public mlir::OpRewritePattern<Op> {
};
// Initialize the resu...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/132486
More information about the flang-commits
mailing list