[flang-commits] [flang] [flang][rfc] Add representation of volatile references (PR #132486)

via flang-commits flang-commits at lists.llvm.org
Fri Mar 21 16:13:27 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-fir-hlfir

Author: Asher Mancinelli (ashermancinelli)

<details>
<summary>Changes</summary>

[RFC on discourse](https://discourse.llvm.org/t/rfc-volatile-representation-in-flang/85404/1)

Flang currently lacks support for volatile variables. In some cases the compiler emits TODO error messages; in others, volatility is silently ignored. Some of our tests resemble the example in _C.4 Clause 8 notes: The VOLATILE attribute (8.5.20)_ and require volatile variables.

This change is a minimal draft of support for volatility in Fortran. This PR omits some important features, such as volatility on boxes and on other reference-like types that are not plain references. This commit supports volatility only for `!fir.ref<T>`; it is the minimum needed to get end-to-end examples working so we can judge whether this is the right direction.

If this is the right direction, I'll break this up into a few chunks, add more tests, and share them as smaller PRs.

---

Patch is 23.70 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/132486.diff


15 Files Affected:

- (modified) flang/include/flang/Optimizer/Builder/FIRBuilder.h (+1-1) 
- (modified) flang/include/flang/Optimizer/Dialect/FIRType.h (+6) 
- (modified) flang/include/flang/Optimizer/Dialect/FIRTypes.td (+7-3) 
- (modified) flang/lib/Lower/CallInterface.cpp (-1) 
- (modified) flang/lib/Lower/ConvertExprToHLFIR.cpp (+42-6) 
- (modified) flang/lib/Optimizer/Builder/FIRBuilder.cpp (+2-2) 
- (modified) flang/lib/Optimizer/Builder/HLFIRTools.cpp (+2-1) 
- (modified) flang/lib/Optimizer/CodeGen/CodeGen.cpp (+12-5) 
- (modified) flang/lib/Optimizer/Dialect/FIRType.cpp (+43-14) 
- (modified) flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp (+7) 
- (modified) flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp (+5-2) 
- (modified) flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp (+6-3) 
- (added) flang/test/Fir/volatile.fir (+18) 
- (added) flang/test/Integration/volatile.f90 (+11) 
- (added) flang/test/Lower/volatile.fir (+21) 


``````````diff
diff --git a/flang/include/flang/Optimizer/Builder/FIRBuilder.h b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
index 003b4358572c1..870709a5d55b6 100644
--- a/flang/include/flang/Optimizer/Builder/FIRBuilder.h
+++ b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
@@ -150,7 +150,7 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
   mlir::Block *getAllocaBlock();
 
   /// Safely create a reference type to the type `eleTy`.
-  mlir::Type getRefType(mlir::Type eleTy);
+  mlir::Type getRefType(mlir::Type eleTy, bool isVolatile = false);
 
   /// Create a sequence of `eleTy` with `rank` dimensions of unknown size.
   mlir::Type getVarLenSeqTy(mlir::Type eleTy, unsigned rank = 1);
diff --git a/flang/include/flang/Optimizer/Dialect/FIRType.h b/flang/include/flang/Optimizer/Dialect/FIRType.h
index 76e0aa352bcd9..8261c67e4559d 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRType.h
+++ b/flang/include/flang/Optimizer/Dialect/FIRType.h
@@ -111,6 +111,12 @@ inline bool isa_ref_type(mlir::Type t) {
                    fir::LLVMPointerType>(t);
 }
 
+inline bool isa_volatile_ref_type(mlir::Type t) {
+  if (auto refTy = mlir::dyn_cast_or_null<fir::ReferenceType>(t))
+    return refTy.isVolatile();
+  return false;
+}
+
 /// Is `t` a boxed type?
 inline bool isa_box_type(mlir::Type t) {
   return mlir::isa<fir::BaseBoxType, fir::BoxCharType, fir::BoxProcType>(t);
diff --git a/flang/include/flang/Optimizer/Dialect/FIRTypes.td b/flang/include/flang/Optimizer/Dialect/FIRTypes.td
index fd5bbbe44751f..0584c175b36ff 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRTypes.td
+++ b/flang/include/flang/Optimizer/Dialect/FIRTypes.td
@@ -363,18 +363,22 @@ def fir_ReferenceType : FIR_Type<"Reference", "ref"> {
     The type of a reference to an entity in memory.
   }];
 
-  let parameters = (ins "mlir::Type":$eleTy);
+  let parameters = (ins
+    "mlir::Type":$eleTy,
+    DefaultValuedParameter<"bool", "false">:$isVol);
 
   let skipDefaultBuilders = 1;
 
   let builders = [
-    TypeBuilderWithInferredContext<(ins "mlir::Type":$elementType), [{
-      return Base::get(elementType.getContext(), elementType);
+    TypeBuilderWithInferredContext<(ins "mlir::Type":$elementType, CArg<"bool", "false">:$isVol), [{
+      return Base::get(elementType.getContext(), elementType, isVol);
     }]>,
   ];
 
   let extraClassDeclaration = [{
     mlir::Type getElementType() const { return getEleTy(); }
+    bool isVolatile() const { return (bool)getIsVol(); }
+    static llvm::StringRef getVolatileKeyword() { return "volatile"; }
   }];
 
   let genVerifyDecl = 1;
diff --git a/flang/lib/Lower/CallInterface.cpp b/flang/lib/Lower/CallInterface.cpp
index 226ba1e52c968..4ee28fbeb9a0c 100644
--- a/flang/lib/Lower/CallInterface.cpp
+++ b/flang/lib/Lower/CallInterface.cpp
@@ -1112,7 +1112,6 @@ class Fortran::lower::CallInterfaceImpl {
     if (obj.attrs.test(Attrs::Value))
       isValueAttr = true; // TODO: do we want an mlir::Attribute as well?
     if (obj.attrs.test(Attrs::Volatile)) {
-      TODO(loc, "VOLATILE in procedure interface");
       addMLIRAttr(fir::getVolatileAttrName());
     }
     // obj.attrs.test(Attrs::Asynchronous) does not impact the way the argument
diff --git a/flang/lib/Lower/ConvertExprToHLFIR.cpp b/flang/lib/Lower/ConvertExprToHLFIR.cpp
index dc00e0b13f583..3ac10596df5ae 100644
--- a/flang/lib/Lower/ConvertExprToHLFIR.cpp
+++ b/flang/lib/Lower/ConvertExprToHLFIR.cpp
@@ -223,8 +223,37 @@ class HlfirDesignatorBuilder {
             designatorNode, getConverter().getFoldingContext(),
             /*namedConstantSectionsAreAlwaysContiguous=*/false))
       return fir::BoxType::get(resultValueType);
+
+    bool isVolatile = false;
+
+    // Check if the base type is volatile
+    if (partInfo.base.has_value()) {
+      mlir::Type baseType = partInfo.base.value().getType();
+      isVolatile = fir::isa_volatile_ref_type(baseType);
+    }
+
+    auto isVolatileSymbol = [](const Fortran::semantics::Symbol &symbol) {
+      return symbol.GetUltimate().attrs().test(
+          Fortran::semantics::Attr::VOLATILE);
+    };
+
+    // Check if this should be a volatile reference
+    if constexpr (std::is_same_v<std::decay_t<T>,
+                                 Fortran::evaluate::SymbolRef>) {
+      if (isVolatileSymbol(designatorNode.get()))
+        isVolatile = true;
+    } else if constexpr (std::is_same_v<std::decay_t<T>,
+                                        Fortran::evaluate::Component>) {
+      if (isVolatileSymbol(designatorNode.GetLastSymbol()))
+        isVolatile = true;
+    }
+
+    // If it's a reference to a ref, account for it
+    if (auto refTy = mlir::dyn_cast<fir::ReferenceType>(resultValueType))
+      resultValueType = refTy.getEleTy();
+
     // Other designators can be handled as raw addresses.
-    return fir::ReferenceType::get(resultValueType);
+    return fir::ReferenceType::get(resultValueType, isVolatile);
   }
 
   template <typename T>
@@ -414,10 +443,16 @@ class HlfirDesignatorBuilder {
         .Case<fir::SequenceType>([&](fir::SequenceType seqTy) -> mlir::Type {
           return fir::SequenceType::get(seqTy.getShape(), newEleTy);
         })
-        .Case<fir::PointerType, fir::HeapType, fir::ReferenceType, fir::BoxType,
-              fir::ClassType>([&](auto t) -> mlir::Type {
-          using FIRT = decltype(t);
-          return FIRT::get(changeElementType(t.getEleTy(), newEleTy));
+        // TODO: handle volatility for other types
+        .Case<fir::PointerType, fir::HeapType, fir::BoxType, fir::ClassType>(
+            [&](auto t) -> mlir::Type {
+              using FIRT = decltype(t);
+              return FIRT::get(changeElementType(t.getEleTy(), newEleTy));
+            })
+        .Case<fir::ReferenceType>([&](fir::ReferenceType refTy) -> mlir::Type {
+          return fir::ReferenceType::get(
+              changeElementType(refTy.getEleTy(), newEleTy),
+              refTy.isVolatile());
         })
         .Default([newEleTy](mlir::Type t) -> mlir::Type { return newEleTy; });
   }
@@ -1808,6 +1843,7 @@ class HlfirBuilder {
       auto &expr = std::get<const Fortran::lower::SomeExpr &>(iter);
       auto &baseOp = std::get<hlfir::EntityWithAttributes>(iter);
       std::string name = converter.getRecordTypeFieldName(sym);
+      const bool isVolatile = fir::isa_volatile_ref_type(baseOp.getType());
 
       // Generate DesignateOp for the component.
       // The designator's result type is just a reference to the component type,
@@ -1818,7 +1854,7 @@ class HlfirBuilder {
       assert(compType && "failed to retrieve component type");
       mlir::Value compShape =
           designatorBuilder.genComponentShape(sym, compType);
-      mlir::Type designatorType = builder.getRefType(compType);
+      mlir::Type designatorType = builder.getRefType(compType, isVolatile);
 
       mlir::Type fieldElemType = hlfir::getFortranElementType(compType);
       llvm::SmallVector<mlir::Value, 1> typeParams;
diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
index b7f8a8d3a9d56..02ded29606885 100644
--- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp
+++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
@@ -104,9 +104,9 @@ fir::FirOpBuilder::getNamedGlobal(mlir::ModuleOp modOp,
   return modOp.lookupSymbol<fir::GlobalOp>(name);
 }
 
-mlir::Type fir::FirOpBuilder::getRefType(mlir::Type eleTy) {
+mlir::Type fir::FirOpBuilder::getRefType(mlir::Type eleTy, bool isVolatile) {
   assert(!mlir::isa<fir::ReferenceType>(eleTy) && "cannot be a reference type");
-  return fir::ReferenceType::get(eleTy);
+  return fir::ReferenceType::get(eleTy, isVolatile);
 }
 
 mlir::Type fir::FirOpBuilder::getVarLenSeqTy(mlir::Type eleTy, unsigned rank) {
diff --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
index 85fd742db6beb..aec88ec97b514 100644
--- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp
+++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
@@ -819,7 +819,8 @@ mlir::Type hlfir::getVariableElementType(hlfir::Entity variable) {
   } else if (fir::isRecordWithTypeParameters(eleTy)) {
     return fir::BoxType::get(eleTy);
   }
-  return fir::ReferenceType::get(eleTy);
+  const bool isVolatile = fir::isa_volatile_ref_type(variable.getType());
+  return fir::ReferenceType::get(eleTy, isVolatile);
 }
 
 mlir::Type hlfir::getEntityElementType(hlfir::Entity entity) {
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index b54b497ee4ba1..90f2474dafca3 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -3224,6 +3224,8 @@ struct LoadOpConversion : public fir::FIROpConversion<fir::LoadOp> {
                   mlir::ConversionPatternRewriter &rewriter) const override {
 
     mlir::Type llvmLoadTy = convertObjectType(load.getType());
+    const bool isVolatile =
+        fir::isa_volatile_ref_type(load.getMemref().getType());
     if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(load.getType())) {
       // fir.box is a special case because it is considered an ssa value in
       // fir, but it is lowered as a pointer to a descriptor. So
@@ -3253,7 +3255,7 @@ struct LoadOpConversion : public fir::FIROpConversion<fir::LoadOp> {
       mlir::Value boxSize =
           computeBoxSize(loc, boxTypePair, inputBoxStorage, rewriter);
       auto memcpy = rewriter.create<mlir::LLVM::MemcpyOp>(
-          loc, newBoxStorage, inputBoxStorage, boxSize, /*isVolatile=*/false);
+          loc, newBoxStorage, inputBoxStorage, boxSize, isVolatile);
 
       if (std::optional<mlir::ArrayAttr> optionalTag = load.getTbaa())
         memcpy.setTBAATags(*optionalTag);
@@ -3261,8 +3263,10 @@ struct LoadOpConversion : public fir::FIROpConversion<fir::LoadOp> {
         attachTBAATag(memcpy, boxTy, boxTy, nullptr);
       rewriter.replaceOp(load, newBoxStorage);
     } else {
+      // TODO: are we losing any attributes from the load op?
+      auto memref = adaptor.getOperands()[0];
       auto loadOp = rewriter.create<mlir::LLVM::LoadOp>(
-          load.getLoc(), llvmLoadTy, adaptor.getOperands(), load->getAttrs());
+          load.getLoc(), llvmLoadTy, memref, /*alignment=*/0, isVolatile);
       if (std::optional<mlir::ArrayAttr> optionalTag = load.getTbaa())
         loadOp.setTBAATags(*optionalTag);
       else
@@ -3540,6 +3544,8 @@ struct StoreOpConversion : public fir::FIROpConversion<fir::StoreOp> {
     mlir::Value llvmValue = adaptor.getValue();
     mlir::Value llvmMemref = adaptor.getMemref();
     mlir::LLVM::AliasAnalysisOpInterface newOp;
+    const bool isVolatile =
+        fir::isa_volatile_ref_type(store.getMemref().getType());
     if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(storeTy)) {
       mlir::Type llvmBoxTy = lowerTy().convertBoxTypeAsStruct(boxTy);
       // Always use memcpy because LLVM is not as effective at optimizing
@@ -3547,10 +3553,11 @@ struct StoreOpConversion : public fir::FIROpConversion<fir::StoreOp> {
       TypePair boxTypePair{boxTy, llvmBoxTy};
       mlir::Value boxSize =
           computeBoxSize(loc, boxTypePair, llvmValue, rewriter);
-      newOp = rewriter.create<mlir::LLVM::MemcpyOp>(
-          loc, llvmMemref, llvmValue, boxSize, /*isVolatile=*/false);
+      newOp = rewriter.create<mlir::LLVM::MemcpyOp>(loc, llvmMemref, llvmValue,
+                                                    boxSize, isVolatile);
     } else {
-      newOp = rewriter.create<mlir::LLVM::StoreOp>(loc, llvmValue, llvmMemref);
+      newOp = rewriter.create<mlir::LLVM::StoreOp>(loc, llvmValue, llvmMemref,
+                                                   /*alignment=*/0, isVolatile);
     }
     if (std::optional<mlir::ArrayAttr> optionalTag = store.getTbaa())
       newOp.setTBAATags(*optionalTag);
diff --git a/flang/lib/Optimizer/Dialect/FIRType.cpp b/flang/lib/Optimizer/Dialect/FIRType.cpp
index f3f969ba401e5..90942522d9073 100644
--- a/flang/lib/Optimizer/Dialect/FIRType.cpp
+++ b/flang/lib/Optimizer/Dialect/FIRType.cpp
@@ -649,12 +649,17 @@ mlir::Type changeElementType(mlir::Type type, mlir::Type newElementType,
       .Case<fir::SequenceType>([&](fir::SequenceType seqTy) -> mlir::Type {
         return fir::SequenceType::get(seqTy.getShape(), newElementType);
       })
-      .Case<fir::PointerType, fir::HeapType, fir::ReferenceType,
-            fir::ClassType>([&](auto t) -> mlir::Type {
-        using FIRT = decltype(t);
-        return FIRT::get(
-            changeElementType(t.getEleTy(), newElementType, turnBoxIntoClass));
+      .Case<fir::ReferenceType>([&](fir::ReferenceType refTy) -> mlir::Type {
+        auto newEleTy = changeElementType(refTy.getEleTy(), newElementType,
+                                          turnBoxIntoClass);
+        return fir::ReferenceType::get(newEleTy, refTy.isVolatile());
       })
+      .Case<fir::PointerType, fir::HeapType, fir::ClassType>(
+          [&](auto t) -> mlir::Type {
+            using FIRT = decltype(t);
+            return FIRT::get(changeElementType(t.getEleTy(), newElementType,
+                                               turnBoxIntoClass));
+          })
       .Case<fir::BoxType>([&](fir::BoxType t) -> mlir::Type {
         mlir::Type newInnerType =
             changeElementType(t.getEleTy(), newElementType, false);
@@ -1057,18 +1062,38 @@ unsigned fir::RecordType::getFieldIndex(llvm::StringRef ident) {
 // ReferenceType
 //===----------------------------------------------------------------------===//
 
-// `ref` `<` type `>`
+// `ref` `<` type (`, volatile` $volatile^)? `>`
 mlir::Type fir::ReferenceType::parse(mlir::AsmParser &parser) {
-  return parseTypeSingleton<fir::ReferenceType>(parser);
+  if (parser.parseLess())
+    return {};
+
+  mlir::Type eleTy;
+  if (parser.parseType(eleTy))
+    return {};
+
+  bool isVolatile = false;
+  if (!parser.parseOptionalComma()) {
+    if (parser.parseKeyword(getVolatileKeyword())) {
+      return {};
+    }
+    isVolatile = true;
+  }
+
+  if (parser.parseGreater())
+    return {};
+  return get(eleTy, isVolatile);
 }
 
 void fir::ReferenceType::print(mlir::AsmPrinter &printer) const {
-  printer << "<" << getEleTy() << '>';
+  printer << "<" << getEleTy();
+  if (isVolatile())
+    printer << ", " << getVolatileKeyword();
+  printer << '>';
 }
 
 llvm::LogicalResult fir::ReferenceType::verify(
-    llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
-    mlir::Type eleTy) {
+    llvm::function_ref<mlir::InFlightDiagnostic()> emitError, mlir::Type eleTy,
+    bool isVolatile) {
   if (mlir::isa<ShapeType, ShapeShiftType, SliceType, FieldType, LenType,
                 ReferenceType, TypeDescType>(eleTy))
     return emitError() << "cannot build a reference to type: " << eleTy << '\n';
@@ -1319,11 +1344,15 @@ changeTypeShape(mlir::Type type,
           return fir::SequenceType::get(*newShape, seqTy.getEleTy());
         return seqTy.getEleTy();
       })
-      .Case<fir::PointerType, fir::HeapType, fir::ReferenceType, fir::BoxType,
-            fir::ClassType>([&](auto t) -> mlir::Type {
-        using FIRT = decltype(t);
-        return FIRT::get(changeTypeShape(t.getEleTy(), newShape));
+      .Case<fir::ReferenceType>([&](fir::ReferenceType rt) -> mlir::Type {
+        return fir::ReferenceType::get(changeTypeShape(rt.getEleTy(), newShape),
+                                       rt.isVolatile());
       })
+      .Case<fir::PointerType, fir::HeapType, fir::BoxType, fir::ClassType>(
+          [&](auto t) -> mlir::Type {
+            using FIRT = decltype(t);
+            return FIRT::get(changeTypeShape(t.getEleTy(), newShape));
+          })
       .Default([&](mlir::Type t) -> mlir::Type {
         assert((fir::isa_trivial(t) || llvm::isa<fir::RecordType>(t) ||
                 llvm::isa<mlir::NoneType>(t) ||
diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
index 8851a3a7187b9..4a3308ff4e747 100644
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
@@ -214,6 +214,13 @@ void hlfir::DeclareOp::build(mlir::OpBuilder &builder,
   auto nameAttr = builder.getStringAttr(uniq_name);
   mlir::Type inputType = memref.getType();
   bool hasExplicitLbs = hasExplicitLowerBounds(shape);
+  if (fortran_attrs && mlir::isa<fir::ReferenceType>(inputType) &&
+      bitEnumContainsAny(fortran_attrs.getFlags(),
+                         fir::FortranVariableFlagsEnum::fortran_volatile)) {
+    auto refType = mlir::cast<fir::ReferenceType>(inputType);
+    inputType = fir::ReferenceType::get(refType.getEleTy(), true);
+    memref = builder.create<fir::ConvertOp>(memref.getLoc(), inputType, memref);
+  }
   mlir::Type hlfirVariableType =
       getHLFIRVariableType(inputType, hasExplicitLbs);
   build(builder, result, {hlfirVariableType, inputType}, memref, shape,
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp
index 496a5560ac615..aa151f90ed0d1 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp
@@ -418,7 +418,9 @@ class DesignateOpConversion
       firstElementIndices.push_back(indices[i]);
       i = i + (isTriplet ? 3 : 1);
     }
-    mlir::Type arrayCoorType = fir::ReferenceType::get(baseEleTy);
+    mlir::Type originalDesignateType = designate.getResult().getType();
+    const bool isVolatile = fir::isa_volatile_ref_type(originalDesignateType);
+    mlir::Type arrayCoorType = fir::ReferenceType::get(baseEleTy, isVolatile);
     base = builder.create<fir::ArrayCoorOp>(
         loc, arrayCoorType, base, shape,
         /*slice=*/mlir::Value{}, firstElementIndices, firBaseTypeParameters);
@@ -441,6 +443,7 @@ class DesignateOpConversion
       TODO(loc, "hlfir::designate load of pointer or allocatable");
 
     mlir::Type designateResultType = designate.getResult().getType();
+    const bool isVolatile = fir::isa_volatile_ref_type(designateResultType);
     llvm::SmallVector<mlir::Value> firBaseTypeParameters;
     auto [base, shape] = hlfir::genVariableFirBaseShapeAndParams(
         loc, builder, baseEntity, firBaseTypeParameters);
@@ -464,7 +467,7 @@ class DesignateOpConversion
         mlir::Type componentType =
             mlir::cast<fir::RecordType>(baseEleTy).getType(
                 designate.getComponent().value());
-        mlir::Type coorTy = fir::ReferenceType::get(componentType);
+        mlir::Type coorTy = fir::ReferenceType::get(componentType, isVolatile);
         base = builder.create<fir::CoordinateOp>(loc, coorTy, base, fieldIndex);
         if (mlir::isa<fir::BaseBoxType>(componentType)) {
           auto variableInterface = mlir::cast<fir::FortranVariableOpInterface>(
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
index 96a3622f4afee..020915179a670 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
@@ -1126,7 +1126,8 @@ class ReductionMaskConversion : public mlir::OpRewritePattern<Op> {
       builder.create<fir::StoreOp>(loc, flagSet, flagRef);
       mlir::Type resultElemTy =
           hlfir::getFortranElementType(resultArr.getType());
-      mlir::Type returnRefTy = builder.getRefType(resultElemTy);
+      mlir::Type returnRefTy = builder.getRefType(
+          resultElemTy, fir::isa_volatile_ref_type(flagRef.getType()));
       mlir::IndexType idxTy = builder.getIndexType();
 
       for (unsigned int i = 0; i < rank; ++i) {
@@ -1153,7 +1154,8 @@ class ReductionMaskConversion : public mlir::OpRewritePattern<Op> {
     auto getAddrFn = [](fir::FirOpBuilder builder, mlir::Location loc,
                         const mlir::Type &resultElemType, mlir::Value resultArr,
                         mlir::Value index) {
-      mlir::Type resultRefTy = builder.getRefType(resultElemType);
+      mlir::Type resultRefTy = builder.getRefType(
+          resultElemType, fir::isa_volatile_ref_type(resultArr.getType()));
       mlir::Value oneIdx =
           builder.createIntegerConstant(loc, builder.getIndexType(), 1);
       index = builder.create<mlir::arith::AddIOp>(loc, index, oneIdx);
@@ -1162,8 +1164,9 @@ class ReductionMaskConversion : public mlir::OpRewritePattern<Op> {
     };
 
     // Initialize the resu...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/132486


More information about the flang-commits mailing list