[flang-commits] [flang] Ajm/flang volatile attr (PR #132469)

Asher Mancinelli via flang-commits flang-commits at lists.llvm.org
Fri Mar 21 13:28:52 PDT 2025


https://github.com/ashermancinelli created https://github.com/llvm/llvm-project/pull/132469

None

>From 82077c46c1f26ab3234711ece19d3e8102cd7fa5 Mon Sep 17 00:00:00 2001
From: Asher Mancinelli <ashermancinelli at gmail.com>
Date: Thu, 20 Mar 2025 08:01:56 -0700
Subject: [PATCH 1/3] init

---
 flang/include/flang/Optimizer/Dialect/FIRTypes.td | 13 ++++++++++++-
 flang/lib/Optimizer/Dialect/FIRType.cpp           |  2 +-
 2 files changed, 13 insertions(+), 2 deletions(-)

diff --git a/flang/include/flang/Optimizer/Dialect/FIRTypes.td b/flang/include/flang/Optimizer/Dialect/FIRTypes.td
index fd5bbbe44751f..2546bc4f9af6e 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRTypes.td
+++ b/flang/include/flang/Optimizer/Dialect/FIRTypes.td
@@ -14,6 +14,7 @@
 #define FIR_DIALECT_FIR_TYPES
 
 include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/BuiltinAttributes.td"
 include "flang/Optimizer/Dialect/FIRDialect.td"
 
 //===----------------------------------------------------------------------===//
@@ -363,7 +364,9 @@ def fir_ReferenceType : FIR_Type<"Reference", "ref"> {
     The type of a reference to an entity in memory.
   }];
 
-  let parameters = (ins "mlir::Type":$eleTy);
+  let parameters = (ins
+    "mlir::Type":$eleTy,
+    "mlir::UnitAttr":$isVol);
 
   let skipDefaultBuilders = 1;
 
@@ -371,10 +374,18 @@ def fir_ReferenceType : FIR_Type<"Reference", "ref"> {
     TypeBuilderWithInferredContext<(ins "mlir::Type":$elementType), [{
       return Base::get(elementType.getContext(), elementType);
     }]>,
+    TypeBuilderWithInferredContext<(ins "mlir::Type":$elementType, "bool":$isVol)>,
+    // [{
+    //  if (isVol)
+    //    return Base::get(elementType.getContext(), elementType, mlir::UnitAttr::get(elementType.getContext()));
+    //  else
+    //    return Base::get(elementType.getContext(), elementType);
+    //}]>,
   ];
 
   let extraClassDeclaration = [{
     mlir::Type getElementType() const { return getEleTy(); }
+    bool isVolatile() const { return (bool)getIsVol(); }
   }];
 
   let genVerifyDecl = 1;
diff --git a/flang/lib/Optimizer/Dialect/FIRType.cpp b/flang/lib/Optimizer/Dialect/FIRType.cpp
index dc0bee9b060c9..3cd5d8725f749 100644
--- a/flang/lib/Optimizer/Dialect/FIRType.cpp
+++ b/flang/lib/Optimizer/Dialect/FIRType.cpp
@@ -1068,7 +1068,7 @@ void fir::ReferenceType::print(mlir::AsmPrinter &printer) const {
 
 llvm::LogicalResult fir::ReferenceType::verify(
     llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
-    mlir::Type eleTy) {
+    mlir::Type eleTy, mlir::UnitAttr isVolatile) {
   if (mlir::isa<ShapeType, ShapeShiftType, SliceType, FieldType, LenType,
                 ReferenceType, TypeDescType>(eleTy))
     return emitError() << "cannot build a reference to type: " << eleTy << '\n';

>From e27bb5bcd18d8a7b3005f824c49ced6ecbb1f0e5 Mon Sep 17 00:00:00 2001
From: Asher Mancinelli <ashermancinelli at gmail.com>
Date: Thu, 20 Mar 2025 14:31:56 -0700
Subject: [PATCH 2/3] checkpoint

---
 .../flang/Optimizer/Builder/FIRBuilder.h      |  2 +-
 .../include/flang/Optimizer/Dialect/FIRType.h |  6 +++
 .../flang/Optimizer/Dialect/FIRTypes.td       | 17 +++----
 flang/lib/Lower/CallInterface.cpp             |  2 +-
 flang/lib/Lower/ConvertExprToHLFIR.cpp        | 48 ++++++++++++++++---
 flang/lib/Optimizer/Builder/FIRBuilder.cpp    |  4 +-
 flang/lib/Optimizer/Builder/HLFIRTools.cpp    |  5 +-
 flang/lib/Optimizer/CodeGen/CodeGen.cpp       |  6 ++-
 flang/lib/Optimizer/Dialect/FIRType.cpp       | 31 ++++++++++--
 flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp     |  7 +++
 .../Transforms/OptimizedBufferization.cpp     |  7 +--
 11 files changed, 105 insertions(+), 30 deletions(-)

diff --git a/flang/include/flang/Optimizer/Builder/FIRBuilder.h b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
index 1675c15363868..d7ddb37480ebb 100644
--- a/flang/include/flang/Optimizer/Builder/FIRBuilder.h
+++ b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
@@ -150,7 +150,7 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
   mlir::Block *getAllocaBlock();
 
   /// Safely create a reference type to the type `eleTy`.
-  mlir::Type getRefType(mlir::Type eleTy);
+  mlir::Type getRefType(mlir::Type eleTy, bool isVolatile = false);
 
   /// Create a sequence of `eleTy` with `rank` dimensions of unknown size.
   mlir::Type getVarLenSeqTy(mlir::Type eleTy, unsigned rank = 1);
diff --git a/flang/include/flang/Optimizer/Dialect/FIRType.h b/flang/include/flang/Optimizer/Dialect/FIRType.h
index 76e0aa352bcd9..8261c67e4559d 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRType.h
+++ b/flang/include/flang/Optimizer/Dialect/FIRType.h
@@ -111,6 +111,12 @@ inline bool isa_ref_type(mlir::Type t) {
                    fir::LLVMPointerType>(t);
 }
 
+inline bool isa_volatile_ref_type(mlir::Type t) {
+  if (auto refTy = mlir::dyn_cast_or_null<fir::ReferenceType>(t))
+    return refTy.isVolatile();
+  return false;
+}
+
 /// Is `t` a boxed type?
 inline bool isa_box_type(mlir::Type t) {
   return mlir::isa<fir::BaseBoxType, fir::BoxCharType, fir::BoxProcType>(t);
diff --git a/flang/include/flang/Optimizer/Dialect/FIRTypes.td b/flang/include/flang/Optimizer/Dialect/FIRTypes.td
index 2546bc4f9af6e..a08943f5067fc 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRTypes.td
+++ b/flang/include/flang/Optimizer/Dialect/FIRTypes.td
@@ -366,26 +366,23 @@ def fir_ReferenceType : FIR_Type<"Reference", "ref"> {
 
   let parameters = (ins
     "mlir::Type":$eleTy,
-    "mlir::UnitAttr":$isVol);
+    DefaultValuedParameter<"bool", "false">:$isVol,
+    DefaultValuedParameter<"bool", "false">:$isAsync);
 
   let skipDefaultBuilders = 1;
 
   let builders = [
-    TypeBuilderWithInferredContext<(ins "mlir::Type":$elementType), [{
-      return Base::get(elementType.getContext(), elementType);
+    TypeBuilderWithInferredContext<(ins "mlir::Type":$elementType, CArg<"bool", "false">:$isVol, CArg<"bool", "false">:$isAsync), [{
+      return Base::get(elementType.getContext(), elementType, isVol, isAsync);
     }]>,
-    TypeBuilderWithInferredContext<(ins "mlir::Type":$elementType, "bool":$isVol)>,
-    // [{
-    //  if (isVol)
-    //    return Base::get(elementType.getContext(), elementType, mlir::UnitAttr::get(elementType.getContext()));
-    //  else
-    //    return Base::get(elementType.getContext(), elementType);
-    //}]>,
   ];
 
   let extraClassDeclaration = [{
     mlir::Type getElementType() const { return getEleTy(); }
     bool isVolatile() const { return (bool)getIsVol(); }
+    bool isAsync() const { return (bool)getIsAsync(); }
+    static llvm::StringRef getVolatileKeyword() { return "volatile"; }
+    static llvm::StringRef getAsyncKeyword() { return "async"; }
   }];
 
   let genVerifyDecl = 1;
diff --git a/flang/lib/Lower/CallInterface.cpp b/flang/lib/Lower/CallInterface.cpp
index 226ba1e52c968..c741f1c1d2c76 100644
--- a/flang/lib/Lower/CallInterface.cpp
+++ b/flang/lib/Lower/CallInterface.cpp
@@ -1112,7 +1112,7 @@ class Fortran::lower::CallInterfaceImpl {
     if (obj.attrs.test(Attrs::Value))
       isValueAttr = true; // TODO: do we want an mlir::Attribute as well?
     if (obj.attrs.test(Attrs::Volatile)) {
-      TODO(loc, "VOLATILE in procedure interface");
+      // TODO(loc, "VOLATILE in procedure interface");
       addMLIRAttr(fir::getVolatileAttrName());
     }
     // obj.attrs.test(Attrs::Asynchronous) does not impact the way the argument
diff --git a/flang/lib/Lower/ConvertExprToHLFIR.cpp b/flang/lib/Lower/ConvertExprToHLFIR.cpp
index dc00e0b13f583..6908092399a8d 100644
--- a/flang/lib/Lower/ConvertExprToHLFIR.cpp
+++ b/flang/lib/Lower/ConvertExprToHLFIR.cpp
@@ -223,8 +223,37 @@ class HlfirDesignatorBuilder {
             designatorNode, getConverter().getFoldingContext(),
             /*namedConstantSectionsAreAlwaysContiguous=*/false))
       return fir::BoxType::get(resultValueType);
+
+    // TODO: handle async references
+    bool isVolatile = false, isAsync = false;
+
+    // Check if the base type is volatile
+    if (partInfo.base.has_value()) {
+      mlir::Type baseType = partInfo.base.value().getType();
+      isVolatile = fir::isa_volatile_ref_type(baseType);
+    }
+
+    auto isVolatileSymbol = [&](const Fortran::semantics::Symbol &symbol) {
+      return symbol.GetUltimate().attrs().test(Fortran::semantics::Attr::VOLATILE);
+    };
+
+    // Check if this should be a volatile reference
+    if constexpr (std::is_same_v<std::decay_t<T>,
+                                 Fortran::evaluate::SymbolRef>) {
+      if (isVolatileSymbol(designatorNode.get()))
+        isVolatile = true;
+    } else if constexpr (std::is_same_v<std::decay_t<T>,
+                                        Fortran::evaluate::Component>) {
+      if (isVolatileSymbol(designatorNode.GetLastSymbol()))
+        isVolatile = true;
+    }
+    
+    // If it's a reference to a ref, account for it
+    if (auto refTy = mlir::dyn_cast<fir::ReferenceType>(resultValueType))
+      resultValueType = refTy.getEleTy();
+
     // Other designators can be handled as raw addresses.
-    return fir::ReferenceType::get(resultValueType);
+    return fir::ReferenceType::get(resultValueType, isVolatile, isAsync);
   }
 
   template <typename T>
@@ -269,6 +298,7 @@ class HlfirDesignatorBuilder {
         partInfo.componentName, partInfo.componentShape, partInfo.subscripts,
         partInfo.substring, partInfo.complexPart, partInfo.resultShape,
         partInfo.typeParams, attributes);
+    llvm::dbgs() << __FILE__ << ":" << __LINE__ << "\n" << designate << "\n" << designatorType << "\n";
     if (auto elementalAddrOp = getVectorSubscriptElementAddrOp())
       builder.setInsertionPoint(*elementalAddrOp);
     return mlir::cast<fir::FortranVariableOpInterface>(
@@ -414,10 +444,13 @@ class HlfirDesignatorBuilder {
         .Case<fir::SequenceType>([&](fir::SequenceType seqTy) -> mlir::Type {
           return fir::SequenceType::get(seqTy.getShape(), newEleTy);
         })
-        .Case<fir::PointerType, fir::HeapType, fir::ReferenceType, fir::BoxType,
-              fir::ClassType>([&](auto t) -> mlir::Type {
-          using FIRT = decltype(t);
-          return FIRT::get(changeElementType(t.getEleTy(), newEleTy));
+        .Case<fir::PointerType, fir::HeapType, fir::BoxType, fir::ClassType>(
+            [&](auto t) -> mlir::Type {
+              using FIRT = decltype(t);
+              return FIRT::get(changeElementType(t.getEleTy(), newEleTy));
+            })
+        .Case<fir::ReferenceType>([&](fir::ReferenceType refTy) -> mlir::Type {
+          return fir::ReferenceType::get(changeElementType(refTy.getEleTy(), newEleTy), refTy.isVolatile());
         })
         .Default([newEleTy](mlir::Type t) -> mlir::Type { return newEleTy; });
   }
@@ -1796,6 +1829,7 @@ class HlfirBuilder {
             /*complexPart=*/std::nullopt,
             /*shape=*/mlir::Value{}, /*typeParams=*/mlir::ValueRange{},
             fir::FortranVariableFlagsAttr{});
+        llvm::dbgs() << __LINE__ << " " << newParent << "\n";
         currentParent = hlfir::EntityWithAttributes{newParent};
       }
       valuesAndParents.emplace_back(
@@ -1808,6 +1842,7 @@ class HlfirBuilder {
       auto &expr = std::get<const Fortran::lower::SomeExpr &>(iter);
       auto &baseOp = std::get<hlfir::EntityWithAttributes>(iter);
       std::string name = converter.getRecordTypeFieldName(sym);
+      const bool isVolatile = fir::isa_volatile_ref_type(baseOp.getType());
 
       // Generate DesignateOp for the component.
       // The designator's result type is just a reference to the component type,
@@ -1818,7 +1853,6 @@ class HlfirBuilder {
       assert(compType && "failed to retrieve component type");
       mlir::Value compShape =
           designatorBuilder.genComponentShape(sym, compType);
-      mlir::Type designatorType = builder.getRefType(compType);
 
       mlir::Type fieldElemType = hlfir::getFortranElementType(compType);
       llvm::SmallVector<mlir::Value, 1> typeParams;
@@ -1839,6 +1873,7 @@ class HlfirBuilder {
       // Convert component symbol attributes to variable attributes.
       fir::FortranVariableFlagsAttr attrs =
           Fortran::lower::translateSymbolAttributes(builder.getContext(), sym);
+      mlir::Type designatorType = builder.getRefType(compType, isVolatile);
 
       // Get the component designator.
       auto lhs = builder.create<hlfir::DesignateOp>(
@@ -1847,6 +1882,7 @@ class HlfirBuilder {
           /*substring=*/mlir::ValueRange{},
           /*complexPart=*/std::nullopt,
           /*shape=*/compShape, typeParams, attrs);
+      llvm::dbgs() << __LINE__ << " " << lhs << "\n";
 
       if (attrs && bitEnumContainsAny(attrs.getFlags(),
                                       fir::FortranVariableFlagsEnum::pointer)) {
diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
index b3d440cedee07..cfae25f8fe4b9 100644
--- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp
+++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
@@ -104,9 +104,9 @@ fir::FirOpBuilder::getNamedGlobal(mlir::ModuleOp modOp,
   return modOp.lookupSymbol<fir::GlobalOp>(name);
 }
 
-mlir::Type fir::FirOpBuilder::getRefType(mlir::Type eleTy) {
+mlir::Type fir::FirOpBuilder::getRefType(mlir::Type eleTy, bool isVolatile) {
   assert(!mlir::isa<fir::ReferenceType>(eleTy) && "cannot be a reference type");
-  return fir::ReferenceType::get(eleTy);
+  return fir::ReferenceType::get(eleTy, isVolatile);
 }
 
 mlir::Type fir::FirOpBuilder::getVarLenSeqTy(mlir::Type eleTy, unsigned rank) {
diff --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
index 1a31ca33e9465..7b9b33e8e4172 100644
--- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp
+++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
@@ -809,7 +809,10 @@ mlir::Type hlfir::getVariableElementType(hlfir::Entity variable) {
   } else if (fir::isRecordWithTypeParameters(eleTy)) {
     return fir::BoxType::get(eleTy);
   }
-  return fir::ReferenceType::get(eleTy);
+  const bool isVolatile = fir::isa_volatile_ref_type(variable.getType());
+  auto newty = fir::ReferenceType::get(eleTy, isVolatile);
+  llvm::dbgs() << __FILE__ << ":" << __LINE__ << "\n" << variable << " " << variable.getType() << " newty:" << newty << " isvol:" << isVolatile << "\n";
+  return newty;
 }
 
 mlir::Type hlfir::getEntityElementType(hlfir::Entity entity) {
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 2cb4cea58c2b0..2ef9fc79403c7 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -3218,6 +3218,7 @@ struct LoadOpConversion : public fir::FIROpConversion<fir::LoadOp> {
                   mlir::ConversionPatternRewriter &rewriter) const override {
 
     mlir::Type llvmLoadTy = convertObjectType(load.getType());
+    const bool isVolatile = fir::isa_volatile_ref_type(load.getMemref().getType());
     if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(load.getType())) {
       // fir.box is a special case because it is considered an ssa value in
       // fir, but it is lowered as a pointer to a descriptor. So
@@ -3247,7 +3248,7 @@ struct LoadOpConversion : public fir::FIROpConversion<fir::LoadOp> {
       mlir::Value boxSize =
           computeBoxSize(loc, boxTypePair, inputBoxStorage, rewriter);
       auto memcpy = rewriter.create<mlir::LLVM::MemcpyOp>(
-          loc, newBoxStorage, inputBoxStorage, boxSize, /*isVolatile=*/false);
+          loc, newBoxStorage, inputBoxStorage, boxSize, isVolatile);
 
       if (std::optional<mlir::ArrayAttr> optionalTag = load.getTbaa())
         memcpy.setTBAATags(*optionalTag);
@@ -3255,8 +3256,9 @@ struct LoadOpConversion : public fir::FIROpConversion<fir::LoadOp> {
         attachTBAATag(memcpy, boxTy, boxTy, nullptr);
       rewriter.replaceOp(load, newBoxStorage);
     } else {
+      auto memref = adaptor.getOperands()[0];
       auto loadOp = rewriter.create<mlir::LLVM::LoadOp>(
-          load.getLoc(), llvmLoadTy, adaptor.getOperands(), load->getAttrs());
+          load.getLoc(), llvmLoadTy, memref, /*alignment=*/0, isVolatile);
       if (std::optional<mlir::ArrayAttr> optionalTag = load.getTbaa())
         loadOp.setTBAATags(*optionalTag);
       else
diff --git a/flang/lib/Optimizer/Dialect/FIRType.cpp b/flang/lib/Optimizer/Dialect/FIRType.cpp
index 3cd5d8725f749..6e3cd71fb2ccf 100644
--- a/flang/lib/Optimizer/Dialect/FIRType.cpp
+++ b/flang/lib/Optimizer/Dialect/FIRType.cpp
@@ -1057,18 +1057,41 @@ unsigned fir::RecordType::getFieldIndex(llvm::StringRef ident) {
 // ReferenceType
 //===----------------------------------------------------------------------===//
 
-// `ref` `<` type `>`
+// `ref` `<` type (, volatile)? (, async)? `>`
 mlir::Type fir::ReferenceType::parse(mlir::AsmParser &parser) {
-  return parseTypeSingleton<fir::ReferenceType>(parser);
+  if (parser.parseLess())
+    return {};
+  mlir::Type eleTy;
+  if (parser.parseType(eleTy))
+    return {};
+  bool isVolatile = false;
+  bool isAsync = false;
+  while (parser.parseOptionalComma()) {
+    if (parser.parseOptionalKeyword(getVolatileKeyword())) {
+      isVolatile = true;
+    } else if (parser.parseOptionalKeyword(getAsyncKeyword())) {
+      isAsync = true;
+    } else {
+      return {};
+    }
+  }
+  if (parser.parseGreater())
+    return {};
+  return ReferenceType::get(eleTy, isVolatile, isAsync);
 }
 
 void fir::ReferenceType::print(mlir::AsmPrinter &printer) const {
-  printer << "<" << getEleTy() << '>';
+  printer << "<" << getEleTy();
+  if (isVolatile())
+    printer << ", volatile";
+  if (isAsync())
+    printer << ", async";
+  printer << '>';
 }
 
 llvm::LogicalResult fir::ReferenceType::verify(
     llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
-    mlir::Type eleTy, mlir::UnitAttr isVolatile) {
+    mlir::Type eleTy, bool isVolatile, bool isAsync) {
   if (mlir::isa<ShapeType, ShapeShiftType, SliceType, FieldType, LenType,
                 ReferenceType, TypeDescType>(eleTy))
     return emitError() << "cannot build a reference to type: " << eleTy << '\n';
diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
index 8851a3a7187b9..4a3308ff4e747 100644
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
@@ -214,6 +214,13 @@ void hlfir::DeclareOp::build(mlir::OpBuilder &builder,
   auto nameAttr = builder.getStringAttr(uniq_name);
   mlir::Type inputType = memref.getType();
   bool hasExplicitLbs = hasExplicitLowerBounds(shape);
+  if (fortran_attrs && mlir::isa<fir::ReferenceType>(inputType) &&
+      bitEnumContainsAny(fortran_attrs.getFlags(),
+                         fir::FortranVariableFlagsEnum::fortran_volatile)) {
+    auto refType = mlir::cast<fir::ReferenceType>(inputType);
+    inputType = fir::ReferenceType::get(refType.getEleTy(), true);
+    memref = builder.create<fir::ConvertOp>(memref.getLoc(), inputType, memref);
+  }
   mlir::Type hlfirVariableType =
       getHLFIRVariableType(inputType, hasExplicitLbs);
   build(builder, result, {hlfirVariableType, inputType}, memref, shape,
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
index 96a3622f4afee..e22b3d224ca1f 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
@@ -1126,7 +1126,7 @@ class ReductionMaskConversion : public mlir::OpRewritePattern<Op> {
       builder.create<fir::StoreOp>(loc, flagSet, flagRef);
       mlir::Type resultElemTy =
           hlfir::getFortranElementType(resultArr.getType());
-      mlir::Type returnRefTy = builder.getRefType(resultElemTy);
+      mlir::Type returnRefTy = builder.getRefType(resultElemTy, fir::isa_volatile_ref_type(flagRef.getType()));
       mlir::IndexType idxTy = builder.getIndexType();
 
       for (unsigned int i = 0; i < rank; ++i) {
@@ -1153,7 +1153,7 @@ class ReductionMaskConversion : public mlir::OpRewritePattern<Op> {
     auto getAddrFn = [](fir::FirOpBuilder builder, mlir::Location loc,
                         const mlir::Type &resultElemType, mlir::Value resultArr,
                         mlir::Value index) {
-      mlir::Type resultRefTy = builder.getRefType(resultElemType);
+      mlir::Type resultRefTy = builder.getRefType(resultElemType, fir::isa_volatile_ref_type(resultArr.getType()));
       mlir::Value oneIdx =
           builder.createIntegerConstant(loc, builder.getIndexType(), 1);
       index = builder.create<mlir::arith::AddIOp>(loc, index, oneIdx);
@@ -1162,8 +1162,9 @@ class ReductionMaskConversion : public mlir::OpRewritePattern<Op> {
     };
 
     // Initialize the result
+    const bool isVolatile = fir::isa_volatile_ref_type(resultArr.getType());
     mlir::Type resultElemTy = hlfir::getFortranElementType(resultArr.getType());
-    mlir::Type resultRefTy = builder.getRefType(resultElemTy);
+    mlir::Type resultRefTy = builder.getRefType(resultElemTy, isVolatile);
     mlir::Value returnValue =
         builder.createIntegerConstant(loc, resultElemTy, 0);
     for (unsigned int i = 0; i < rank; ++i) {

>From 4639ebb32ee246f857a0525d83e0465421131da0 Mon Sep 17 00:00:00 2001
From: Asher Mancinelli <ashermancinelli at gmail.com>
Date: Fri, 21 Mar 2025 13:07:36 -0700
Subject: [PATCH 3/3] checkpoint

---
 .../include/flang/Optimizer/Dialect/FIRTypes.td  |  9 +++------
 flang/lib/Lower/CallInterface.cpp                |  1 -
 flang/lib/Lower/ConvertExprToHLFIR.cpp           | 12 ++++--------
 flang/lib/Optimizer/Builder/HLFIRTools.cpp       |  4 +---
 flang/lib/Optimizer/Dialect/FIRType.cpp          | 16 +++++++---------
 5 files changed, 15 insertions(+), 27 deletions(-)

diff --git a/flang/include/flang/Optimizer/Dialect/FIRTypes.td b/flang/include/flang/Optimizer/Dialect/FIRTypes.td
index a08943f5067fc..c11758cfe9244 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRTypes.td
+++ b/flang/include/flang/Optimizer/Dialect/FIRTypes.td
@@ -366,23 +366,20 @@ def fir_ReferenceType : FIR_Type<"Reference", "ref"> {
 
   let parameters = (ins
     "mlir::Type":$eleTy,
-    DefaultValuedParameter<"bool", "false">:$isVol,
-    DefaultValuedParameter<"bool", "false">:$isAsync);
+    DefaultValuedParameter<"bool", "false">:$isVol);
 
   let skipDefaultBuilders = 1;
 
   let builders = [
-    TypeBuilderWithInferredContext<(ins "mlir::Type":$elementType, CArg<"bool", "false">:$isVol, CArg<"bool", "false">:$isAsync), [{
-      return Base::get(elementType.getContext(), elementType, isVol, isAsync);
+    TypeBuilderWithInferredContext<(ins "mlir::Type":$elementType, CArg<"bool", "false">:$isVol), [{
+      return Base::get(elementType.getContext(), elementType, isVol);
     }]>,
   ];
 
   let extraClassDeclaration = [{
     mlir::Type getElementType() const { return getEleTy(); }
     bool isVolatile() const { return (bool)getIsVol(); }
-    bool isAsync() const { return (bool)getIsAsync(); }
     static llvm::StringRef getVolatileKeyword() { return "volatile"; }
-    static llvm::StringRef getAsyncKeyword() { return "async"; }
   }];
 
   let genVerifyDecl = 1;
diff --git a/flang/lib/Lower/CallInterface.cpp b/flang/lib/Lower/CallInterface.cpp
index c741f1c1d2c76..4ee28fbeb9a0c 100644
--- a/flang/lib/Lower/CallInterface.cpp
+++ b/flang/lib/Lower/CallInterface.cpp
@@ -1112,7 +1112,6 @@ class Fortran::lower::CallInterfaceImpl {
     if (obj.attrs.test(Attrs::Value))
       isValueAttr = true; // TODO: do we want an mlir::Attribute as well?
     if (obj.attrs.test(Attrs::Volatile)) {
-      // TODO(loc, "VOLATILE in procedure interface");
       addMLIRAttr(fir::getVolatileAttrName());
     }
     // obj.attrs.test(Attrs::Asynchronous) does not impact the way the argument
diff --git a/flang/lib/Lower/ConvertExprToHLFIR.cpp b/flang/lib/Lower/ConvertExprToHLFIR.cpp
index 6908092399a8d..79906c81ecc68 100644
--- a/flang/lib/Lower/ConvertExprToHLFIR.cpp
+++ b/flang/lib/Lower/ConvertExprToHLFIR.cpp
@@ -224,8 +224,7 @@ class HlfirDesignatorBuilder {
             /*namedConstantSectionsAreAlwaysContiguous=*/false))
       return fir::BoxType::get(resultValueType);
 
-    // TODO: handle async references
-    bool isVolatile = false, isAsync = false;
+    bool isVolatile = false;
 
     // Check if the base type is volatile
     if (partInfo.base.has_value()) {
@@ -247,13 +246,13 @@ class HlfirDesignatorBuilder {
       if (isVolatileSymbol(designatorNode.GetLastSymbol()))
         isVolatile = true;
     }
-    
+
     // If it's a reference to a ref, account for it
     if (auto refTy = mlir::dyn_cast<fir::ReferenceType>(resultValueType))
       resultValueType = refTy.getEleTy();
 
     // Other designators can be handled as raw addresses.
-    return fir::ReferenceType::get(resultValueType, isVolatile, isAsync);
+    return fir::ReferenceType::get(resultValueType, isVolatile);
   }
 
   template <typename T>
@@ -298,7 +297,6 @@ class HlfirDesignatorBuilder {
         partInfo.componentName, partInfo.componentShape, partInfo.subscripts,
         partInfo.substring, partInfo.complexPart, partInfo.resultShape,
         partInfo.typeParams, attributes);
-    llvm::dbgs() << __FILE__ << ":" << __LINE__ << "\n" << designate << "\n" << designatorType << "\n";
     if (auto elementalAddrOp = getVectorSubscriptElementAddrOp())
       builder.setInsertionPoint(*elementalAddrOp);
     return mlir::cast<fir::FortranVariableOpInterface>(
@@ -1829,7 +1827,6 @@ class HlfirBuilder {
             /*complexPart=*/std::nullopt,
             /*shape=*/mlir::Value{}, /*typeParams=*/mlir::ValueRange{},
             fir::FortranVariableFlagsAttr{});
-        llvm::dbgs() << __LINE__ << " " << newParent << "\n";
         currentParent = hlfir::EntityWithAttributes{newParent};
       }
       valuesAndParents.emplace_back(
@@ -1853,6 +1850,7 @@ class HlfirBuilder {
       assert(compType && "failed to retrieve component type");
       mlir::Value compShape =
           designatorBuilder.genComponentShape(sym, compType);
+      mlir::Type designatorType = builder.getRefType(compType, isVolatile);
 
       mlir::Type fieldElemType = hlfir::getFortranElementType(compType);
       llvm::SmallVector<mlir::Value, 1> typeParams;
@@ -1873,7 +1871,6 @@ class HlfirBuilder {
       // Convert component symbol attributes to variable attributes.
       fir::FortranVariableFlagsAttr attrs =
           Fortran::lower::translateSymbolAttributes(builder.getContext(), sym);
-      mlir::Type designatorType = builder.getRefType(compType, isVolatile);
 
       // Get the component designator.
       auto lhs = builder.create<hlfir::DesignateOp>(
@@ -1882,7 +1879,6 @@ class HlfirBuilder {
           /*substring=*/mlir::ValueRange{},
           /*complexPart=*/std::nullopt,
           /*shape=*/compShape, typeParams, attrs);
-      llvm::dbgs() << __LINE__ << " " << lhs << "\n";
 
       if (attrs && bitEnumContainsAny(attrs.getFlags(),
                                       fir::FortranVariableFlagsEnum::pointer)) {
diff --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
index 7b9b33e8e4172..cf8bb7eaddf70 100644
--- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp
+++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
@@ -810,9 +810,7 @@ mlir::Type hlfir::getVariableElementType(hlfir::Entity variable) {
     return fir::BoxType::get(eleTy);
   }
   const bool isVolatile = fir::isa_volatile_ref_type(variable.getType());
-  auto newty = fir::ReferenceType::get(eleTy, isVolatile);
-  llvm::dbgs() << __FILE__ << ":" << __LINE__ << "\n" << variable << " " << variable.getType() << " newty:" << newty << " isvol:" << isVolatile << "\n";
-  return newty;
+  return fir::ReferenceType::get(eleTy, isVolatile);
 }
 
 mlir::Type hlfir::getEntityElementType(hlfir::Entity entity) {
diff --git a/flang/lib/Optimizer/Dialect/FIRType.cpp b/flang/lib/Optimizer/Dialect/FIRType.cpp
index 6e3cd71fb2ccf..e2dc1ed3f3ecb 100644
--- a/flang/lib/Optimizer/Dialect/FIRType.cpp
+++ b/flang/lib/Optimizer/Dialect/FIRType.cpp
@@ -1057,41 +1057,39 @@ unsigned fir::RecordType::getFieldIndex(llvm::StringRef ident) {
 // ReferenceType
 //===----------------------------------------------------------------------===//
 
-// `ref` `<` type (, volatile)? (, async)? `>`
+// `ref` `<` type (`, volatile` $volatile^)? `>`
 mlir::Type fir::ReferenceType::parse(mlir::AsmParser &parser) {
   if (parser.parseLess())
     return {};
+
   mlir::Type eleTy;
   if (parser.parseType(eleTy))
     return {};
+
   bool isVolatile = false;
-  bool isAsync = false;
-  while (parser.parseOptionalComma()) {
+  if (parser.parseOptionalComma()) {
     if (parser.parseOptionalKeyword(getVolatileKeyword())) {
       isVolatile = true;
-    } else if (parser.parseOptionalKeyword(getAsyncKeyword())) {
-      isAsync = true;
     } else {
       return {};
     }
   }
+
   if (parser.parseGreater())
     return {};
-  return ReferenceType::get(eleTy, isVolatile, isAsync);
+  return ReferenceType::get(eleTy, isVolatile);
 }
 
 void fir::ReferenceType::print(mlir::AsmPrinter &printer) const {
   printer << "<" << getEleTy();
   if (isVolatile())
     printer << ", volatile";
-  if (isAsync())
-    printer << ", async";
   printer << '>';
 }
 
 llvm::LogicalResult fir::ReferenceType::verify(
     llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
-    mlir::Type eleTy, bool isVolatile, bool isAsync) {
+    mlir::Type eleTy, bool isVolatile) {
   if (mlir::isa<ShapeType, ShapeShiftType, SliceType, FieldType, LenType,
                 ReferenceType, TypeDescType>(eleTy))
     return emitError() << "cannot build a reference to type: " << eleTy << '\n';



More information about the flang-commits mailing list