[flang-commits] [flang] Ajm/flang volatile attr (PR #132469)

via flang-commits flang-commits at lists.llvm.org
Fri Mar 21 13:29:27 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-codegen

Author: Asher Mancinelli (ashermancinelli)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/132469.diff


11 Files Affected:

- (modified) flang/include/flang/Optimizer/Builder/FIRBuilder.h (+1-1) 
- (modified) flang/include/flang/Optimizer/Dialect/FIRType.h (+6) 
- (modified) flang/include/flang/Optimizer/Dialect/FIRTypes.td (+8-3) 
- (modified) flang/lib/Lower/CallInterface.cpp (-1) 
- (modified) flang/lib/Lower/ConvertExprToHLFIR.cpp (+38-6) 
- (modified) flang/lib/Optimizer/Builder/FIRBuilder.cpp (+2-2) 
- (modified) flang/lib/Optimizer/Builder/HLFIRTools.cpp (+2-1) 
- (modified) flang/lib/Optimizer/CodeGen/CodeGen.cpp (+4-2) 
- (modified) flang/lib/Optimizer/Dialect/FIRType.cpp (+25-4) 
- (modified) flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp (+7) 
- (modified) flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp (+4-3) 


``````````diff
diff --git a/flang/include/flang/Optimizer/Builder/FIRBuilder.h b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
index 1675c15363868..d7ddb37480ebb 100644
--- a/flang/include/flang/Optimizer/Builder/FIRBuilder.h
+++ b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
@@ -150,7 +150,7 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
   mlir::Block *getAllocaBlock();
 
   /// Safely create a reference type to the type `eleTy`.
-  mlir::Type getRefType(mlir::Type eleTy);
+  mlir::Type getRefType(mlir::Type eleTy, bool isVolatile = false);
 
   /// Create a sequence of `eleTy` with `rank` dimensions of unknown size.
   mlir::Type getVarLenSeqTy(mlir::Type eleTy, unsigned rank = 1);
diff --git a/flang/include/flang/Optimizer/Dialect/FIRType.h b/flang/include/flang/Optimizer/Dialect/FIRType.h
index 76e0aa352bcd9..8261c67e4559d 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRType.h
+++ b/flang/include/flang/Optimizer/Dialect/FIRType.h
@@ -111,6 +111,12 @@ inline bool isa_ref_type(mlir::Type t) {
                    fir::LLVMPointerType>(t);
 }
 
+inline bool isa_volatile_ref_type(mlir::Type t) {
+  if (auto refTy = mlir::dyn_cast_or_null<fir::ReferenceType>(t))
+    return refTy.isVolatile();
+  return false;
+}
+
 /// Is `t` a boxed type?
 inline bool isa_box_type(mlir::Type t) {
   return mlir::isa<fir::BaseBoxType, fir::BoxCharType, fir::BoxProcType>(t);
diff --git a/flang/include/flang/Optimizer/Dialect/FIRTypes.td b/flang/include/flang/Optimizer/Dialect/FIRTypes.td
index fd5bbbe44751f..c11758cfe9244 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRTypes.td
+++ b/flang/include/flang/Optimizer/Dialect/FIRTypes.td
@@ -14,6 +14,7 @@
 #define FIR_DIALECT_FIR_TYPES
 
 include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/BuiltinAttributes.td"
 include "flang/Optimizer/Dialect/FIRDialect.td"
 
 //===----------------------------------------------------------------------===//
@@ -363,18 +364,22 @@ def fir_ReferenceType : FIR_Type<"Reference", "ref"> {
     The type of a reference to an entity in memory.
   }];
 
-  let parameters = (ins "mlir::Type":$eleTy);
+  let parameters = (ins
+    "mlir::Type":$eleTy,
+    DefaultValuedParameter<"bool", "false">:$isVol);
 
   let skipDefaultBuilders = 1;
 
   let builders = [
-    TypeBuilderWithInferredContext<(ins "mlir::Type":$elementType), [{
-      return Base::get(elementType.getContext(), elementType);
+    TypeBuilderWithInferredContext<(ins "mlir::Type":$elementType, CArg<"bool", "false">:$isVol), [{
+      return Base::get(elementType.getContext(), elementType, isVol);
     }]>,
   ];
 
   let extraClassDeclaration = [{
     mlir::Type getElementType() const { return getEleTy(); }
+    bool isVolatile() const { return (bool)getIsVol(); }
+    static llvm::StringRef getVolatileKeyword() { return "volatile"; }
   }];
 
   let genVerifyDecl = 1;
diff --git a/flang/lib/Lower/CallInterface.cpp b/flang/lib/Lower/CallInterface.cpp
index 226ba1e52c968..4ee28fbeb9a0c 100644
--- a/flang/lib/Lower/CallInterface.cpp
+++ b/flang/lib/Lower/CallInterface.cpp
@@ -1112,7 +1112,6 @@ class Fortran::lower::CallInterfaceImpl {
     if (obj.attrs.test(Attrs::Value))
       isValueAttr = true; // TODO: do we want an mlir::Attribute as well?
     if (obj.attrs.test(Attrs::Volatile)) {
-      TODO(loc, "VOLATILE in procedure interface");
       addMLIRAttr(fir::getVolatileAttrName());
     }
     // obj.attrs.test(Attrs::Asynchronous) does not impact the way the argument
diff --git a/flang/lib/Lower/ConvertExprToHLFIR.cpp b/flang/lib/Lower/ConvertExprToHLFIR.cpp
index dc00e0b13f583..79906c81ecc68 100644
--- a/flang/lib/Lower/ConvertExprToHLFIR.cpp
+++ b/flang/lib/Lower/ConvertExprToHLFIR.cpp
@@ -223,8 +223,36 @@ class HlfirDesignatorBuilder {
             designatorNode, getConverter().getFoldingContext(),
             /*namedConstantSectionsAreAlwaysContiguous=*/false))
       return fir::BoxType::get(resultValueType);
+
+    bool isVolatile = false;
+
+    // Check if the base type is volatile
+    if (partInfo.base.has_value()) {
+      mlir::Type baseType = partInfo.base.value().getType();
+      isVolatile = fir::isa_volatile_ref_type(baseType);
+    }
+
+    auto isVolatileSymbol = [&](const Fortran::semantics::Symbol &symbol) {
+      return symbol.GetUltimate().attrs().test(Fortran::semantics::Attr::VOLATILE);
+    };
+
+    // Check if this should be a volatile reference
+    if constexpr (std::is_same_v<std::decay_t<T>,
+                                 Fortran::evaluate::SymbolRef>) {
+      if (isVolatileSymbol(designatorNode.get()))
+        isVolatile = true;
+    } else if constexpr (std::is_same_v<std::decay_t<T>,
+                                        Fortran::evaluate::Component>) {
+      if (isVolatileSymbol(designatorNode.GetLastSymbol()))
+        isVolatile = true;
+    }
+
+    // Unwrap an existing reference so we never build a ref-to-ref type.
+    if (auto refTy = mlir::dyn_cast<fir::ReferenceType>(resultValueType))
+      resultValueType = refTy.getEleTy();
+
     // Other designators can be handled as raw addresses.
-    return fir::ReferenceType::get(resultValueType);
+    return fir::ReferenceType::get(resultValueType, isVolatile);
   }
 
   template <typename T>
@@ -414,10 +442,13 @@ class HlfirDesignatorBuilder {
         .Case<fir::SequenceType>([&](fir::SequenceType seqTy) -> mlir::Type {
           return fir::SequenceType::get(seqTy.getShape(), newEleTy);
         })
-        .Case<fir::PointerType, fir::HeapType, fir::ReferenceType, fir::BoxType,
-              fir::ClassType>([&](auto t) -> mlir::Type {
-          using FIRT = decltype(t);
-          return FIRT::get(changeElementType(t.getEleTy(), newEleTy));
+        .Case<fir::PointerType, fir::HeapType, fir::BoxType, fir::ClassType>(
+            [&](auto t) -> mlir::Type {
+              using FIRT = decltype(t);
+              return FIRT::get(changeElementType(t.getEleTy(), newEleTy));
+            })
+        .Case<fir::ReferenceType>([&](fir::ReferenceType refTy) -> mlir::Type {
+          return fir::ReferenceType::get(changeElementType(refTy.getEleTy(), newEleTy), refTy.isVolatile());
         })
         .Default([newEleTy](mlir::Type t) -> mlir::Type { return newEleTy; });
   }
@@ -1808,6 +1839,7 @@ class HlfirBuilder {
       auto &expr = std::get<const Fortran::lower::SomeExpr &>(iter);
       auto &baseOp = std::get<hlfir::EntityWithAttributes>(iter);
       std::string name = converter.getRecordTypeFieldName(sym);
+      const bool isVolatile = fir::isa_volatile_ref_type(baseOp.getType());
 
       // Generate DesignateOp for the component.
       // The designator's result type is just a reference to the component type,
@@ -1818,7 +1850,7 @@ class HlfirBuilder {
       assert(compType && "failed to retrieve component type");
       mlir::Value compShape =
           designatorBuilder.genComponentShape(sym, compType);
-      mlir::Type designatorType = builder.getRefType(compType);
+      mlir::Type designatorType = builder.getRefType(compType, isVolatile);
 
       mlir::Type fieldElemType = hlfir::getFortranElementType(compType);
       llvm::SmallVector<mlir::Value, 1> typeParams;
diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
index b3d440cedee07..cfae25f8fe4b9 100644
--- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp
+++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
@@ -104,9 +104,9 @@ fir::FirOpBuilder::getNamedGlobal(mlir::ModuleOp modOp,
   return modOp.lookupSymbol<fir::GlobalOp>(name);
 }
 
-mlir::Type fir::FirOpBuilder::getRefType(mlir::Type eleTy) {
+mlir::Type fir::FirOpBuilder::getRefType(mlir::Type eleTy, bool isVolatile) {
   assert(!mlir::isa<fir::ReferenceType>(eleTy) && "cannot be a reference type");
-  return fir::ReferenceType::get(eleTy);
+  return fir::ReferenceType::get(eleTy, isVolatile);
 }
 
 mlir::Type fir::FirOpBuilder::getVarLenSeqTy(mlir::Type eleTy, unsigned rank) {
diff --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
index 1a31ca33e9465..cf8bb7eaddf70 100644
--- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp
+++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
@@ -809,7 +809,8 @@ mlir::Type hlfir::getVariableElementType(hlfir::Entity variable) {
   } else if (fir::isRecordWithTypeParameters(eleTy)) {
     return fir::BoxType::get(eleTy);
   }
-  return fir::ReferenceType::get(eleTy);
+  const bool isVolatile = fir::isa_volatile_ref_type(variable.getType());
+  return fir::ReferenceType::get(eleTy, isVolatile);
 }
 
 mlir::Type hlfir::getEntityElementType(hlfir::Entity entity) {
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 2cb4cea58c2b0..2ef9fc79403c7 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -3218,6 +3218,7 @@ struct LoadOpConversion : public fir::FIROpConversion<fir::LoadOp> {
                   mlir::ConversionPatternRewriter &rewriter) const override {
 
     mlir::Type llvmLoadTy = convertObjectType(load.getType());
+    const bool isVolatile = fir::isa_volatile_ref_type(load.getMemref().getType());
     if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(load.getType())) {
       // fir.box is a special case because it is considered an ssa value in
       // fir, but it is lowered as a pointer to a descriptor. So
@@ -3247,7 +3248,7 @@ struct LoadOpConversion : public fir::FIROpConversion<fir::LoadOp> {
       mlir::Value boxSize =
           computeBoxSize(loc, boxTypePair, inputBoxStorage, rewriter);
       auto memcpy = rewriter.create<mlir::LLVM::MemcpyOp>(
-          loc, newBoxStorage, inputBoxStorage, boxSize, /*isVolatile=*/false);
+          loc, newBoxStorage, inputBoxStorage, boxSize, isVolatile);
 
       if (std::optional<mlir::ArrayAttr> optionalTag = load.getTbaa())
         memcpy.setTBAATags(*optionalTag);
@@ -3255,8 +3256,9 @@ struct LoadOpConversion : public fir::FIROpConversion<fir::LoadOp> {
         attachTBAATag(memcpy, boxTy, boxTy, nullptr);
       rewriter.replaceOp(load, newBoxStorage);
     } else {
+      auto memref = adaptor.getOperands()[0];
       auto loadOp = rewriter.create<mlir::LLVM::LoadOp>(
-          load.getLoc(), llvmLoadTy, adaptor.getOperands(), load->getAttrs());
+          load.getLoc(), llvmLoadTy, memref, /*alignment=*/0, isVolatile);
       if (std::optional<mlir::ArrayAttr> optionalTag = load.getTbaa())
         loadOp.setTBAATags(*optionalTag);
       else
diff --git a/flang/lib/Optimizer/Dialect/FIRType.cpp b/flang/lib/Optimizer/Dialect/FIRType.cpp
index dc0bee9b060c9..e2dc1ed3f3ecb 100644
--- a/flang/lib/Optimizer/Dialect/FIRType.cpp
+++ b/flang/lib/Optimizer/Dialect/FIRType.cpp
@@ -1057,18 +1057,39 @@ unsigned fir::RecordType::getFieldIndex(llvm::StringRef ident) {
 // ReferenceType
 //===----------------------------------------------------------------------===//
 
-// `ref` `<` type `>`
+// `ref` `<` type (`, volatile` $volatile^)? `>`
 mlir::Type fir::ReferenceType::parse(mlir::AsmParser &parser) {
-  return parseTypeSingleton<fir::ReferenceType>(parser);
+  if (parser.parseLess())
+    return {};
+
+  mlir::Type eleTy;
+  if (parser.parseType(eleTy))
+    return {};
+
+  bool isVolatile = false;
+  if (parser.parseOptionalComma()) {
+    if (parser.parseOptionalKeyword(getVolatileKeyword())) {
+      isVolatile = true;
+    } else {
+      return {};
+    }
+  }
+
+  if (parser.parseGreater())
+    return {};
+  return ReferenceType::get(eleTy, isVolatile);
 }
 
 void fir::ReferenceType::print(mlir::AsmPrinter &printer) const {
-  printer << "<" << getEleTy() << '>';
+  printer << "<" << getEleTy();
+  if (isVolatile())
+    printer << ", volatile";
+  printer << '>';
 }
 
 llvm::LogicalResult fir::ReferenceType::verify(
     llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
-    mlir::Type eleTy) {
+    mlir::Type eleTy, bool isVolatile) {
   if (mlir::isa<ShapeType, ShapeShiftType, SliceType, FieldType, LenType,
                 ReferenceType, TypeDescType>(eleTy))
     return emitError() << "cannot build a reference to type: " << eleTy << '\n';
diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
index 8851a3a7187b9..4a3308ff4e747 100644
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
@@ -214,6 +214,13 @@ void hlfir::DeclareOp::build(mlir::OpBuilder &builder,
   auto nameAttr = builder.getStringAttr(uniq_name);
   mlir::Type inputType = memref.getType();
   bool hasExplicitLbs = hasExplicitLowerBounds(shape);
+  if (fortran_attrs && mlir::isa<fir::ReferenceType>(inputType) &&
+      bitEnumContainsAny(fortran_attrs.getFlags(),
+                         fir::FortranVariableFlagsEnum::fortran_volatile)) {
+    auto refType = mlir::cast<fir::ReferenceType>(inputType);
+    inputType = fir::ReferenceType::get(refType.getEleTy(), true);
+    memref = builder.create<fir::ConvertOp>(memref.getLoc(), inputType, memref);
+  }
   mlir::Type hlfirVariableType =
       getHLFIRVariableType(inputType, hasExplicitLbs);
   build(builder, result, {hlfirVariableType, inputType}, memref, shape,
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
index 96a3622f4afee..e22b3d224ca1f 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
@@ -1126,7 +1126,7 @@ class ReductionMaskConversion : public mlir::OpRewritePattern<Op> {
       builder.create<fir::StoreOp>(loc, flagSet, flagRef);
       mlir::Type resultElemTy =
           hlfir::getFortranElementType(resultArr.getType());
-      mlir::Type returnRefTy = builder.getRefType(resultElemTy);
+      mlir::Type returnRefTy = builder.getRefType(resultElemTy, fir::isa_volatile_ref_type(flagRef.getType()));
       mlir::IndexType idxTy = builder.getIndexType();
 
       for (unsigned int i = 0; i < rank; ++i) {
@@ -1153,7 +1153,7 @@ class ReductionMaskConversion : public mlir::OpRewritePattern<Op> {
     auto getAddrFn = [](fir::FirOpBuilder builder, mlir::Location loc,
                         const mlir::Type &resultElemType, mlir::Value resultArr,
                         mlir::Value index) {
-      mlir::Type resultRefTy = builder.getRefType(resultElemType);
+      mlir::Type resultRefTy = builder.getRefType(resultElemType, fir::isa_volatile_ref_type(resultArr.getType()));
       mlir::Value oneIdx =
           builder.createIntegerConstant(loc, builder.getIndexType(), 1);
       index = builder.create<mlir::arith::AddIOp>(loc, index, oneIdx);
@@ -1162,8 +1162,9 @@ class ReductionMaskConversion : public mlir::OpRewritePattern<Op> {
     };
 
     // Initialize the result
+    const bool isVolatile = fir::isa_volatile_ref_type(resultArr.getType());
     mlir::Type resultElemTy = hlfir::getFortranElementType(resultArr.getType());
-    mlir::Type resultRefTy = builder.getRefType(resultElemTy);
+    mlir::Type resultRefTy = builder.getRefType(resultElemTy, isVolatile);
     mlir::Value returnValue =
         builder.createIntegerConstant(loc, resultElemTy, 0);
     for (unsigned int i = 0; i < rank; ++i) {

``````````

</details>


https://github.com/llvm/llvm-project/pull/132469


More information about the flang-commits mailing list