[flang-commits] [flang] [llvm] [mlir] [mlir] Mark `isa/dyn_cast/cast/...` member functions deprecated. (PR #89998)

Christian Sigg via flang-commits flang-commits at lists.llvm.org
Thu Apr 25 03:07:55 PDT 2024


https://github.com/chsigg updated https://github.com/llvm/llvm-project/pull/89998

>From d90b64838859993de356406697e010f129b04b53 Mon Sep 17 00:00:00 2001
From: Christian Sigg <csigg at google.com>
Date: Wed, 24 Apr 2024 23:46:33 +0200
Subject: [PATCH 1/3] Mark `isa/dyn_cast/cast/...` member functions deprecated.

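The member-function casts are superseded by the free-function variants
(e.g. `mlir::isa<U>(x)` instead of `x.isa<U>()`). For background, see
https://mlir.llvm.org/deprecation and
https://discourse.llvm.org/t/preferred-casting-style-going-forward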
---
 mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td |  4 ++--
 .../mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td   |  6 +++---
 mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td   | 10 +++++-----
 mlir/include/mlir/IR/Attributes.h                |  5 +++++
 .../include/mlir/IR/BuiltinLocationAttributes.td | 13 ++++++++-----
 mlir/include/mlir/IR/Location.h                  |  3 +++
 mlir/include/mlir/IR/Types.h                     |  5 +++++
 mlir/include/mlir/IR/Value.h                     |  8 ++++----
 .../ComplexToStandard/ComplexToStandard.cpp      |  2 +-
 .../Polynomial/IR/PolynomialAttributes.cpp       |  2 +-
 .../Vector/Transforms/VectorLinearize.cpp        | 16 +++++++---------
 11 files changed, 44 insertions(+), 30 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index da12e7c83b22b8..64c538367267dc 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -138,10 +138,10 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
 
   let extraClassDeclaration = [{
     ShapedType getInputOperandType() {
-      return getInput().getType().cast<ShapedType>();
+      return cast<ShapedType>(getInput().getType());
     }
     ShapedType getOutputOperandType() {
-      return getOutput().getType().cast<ShapedType>();
+      return cast<ShapedType>(getOutput().getType());
     }
     int64_t getInputOperandRank() {
       return getInputOperandType().getRank();
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
index ab9b78e755d9d5..c23937cac7538c 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
@@ -234,8 +234,8 @@ def OffloadModuleInterface : OpInterface<"OffloadModuleInterface"> {
       /*methodName=*/"getIsTargetDevice",
       (ins), [{}], [{
         if (Attribute isTargetDevice = $_op->getAttr("omp.is_target_device"))
-          if (isTargetDevice.isa<mlir::BoolAttr>())
-           return isTargetDevice.dyn_cast<BoolAttr>().getValue();
+          if (isa<mlir::BoolAttr>(isTargetDevice))
+           return dyn_cast<BoolAttr>(isTargetDevice).getValue();
         return false;
       }]>,
     InterfaceMethod<
@@ -259,7 +259,7 @@ def OffloadModuleInterface : OpInterface<"OffloadModuleInterface"> {
       /*methodName=*/"getIsGPU",
       (ins), [{}], [{
         if (Attribute isTargetCGAttr = $_op->getAttr("omp.is_gpu"))
-          if (auto isTargetCGVal = isTargetCGAttr.dyn_cast<BoolAttr>())
+          if (auto isTargetCGVal = dyn_cast<BoolAttr>(isTargetCGAttr))
            return isTargetCGVal.getValue();
         return false;
       }]>,
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index c6f7f83441b96c..07b53a5a077da2 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -164,10 +164,10 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
     /// source operand. They overide static shape from source memref type.
     ArrayRef<int64_t> getStaticSizes() {
       auto attr = getConstShapeAttr();
-      if (getSourceType().isa<IntegerType>() || attr)
+      if (llvm::isa<IntegerType>(getSourceType()) || attr)
         return attr;
 
-      auto memrefType = getSourceType().dyn_cast<MemRefType>();
+      auto memrefType = llvm::dyn_cast<MemRefType>(getSourceType());
       assert(memrefType && "Incorrect use of getStaticSizes");
       return memrefType.getShape();
     }
@@ -179,10 +179,10 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
     /// source operand. They overide static strides from source memref type.
     ArrayRef<int64_t> getStaticStrides() {
       auto attr = getConstStridesAttr();
-      if (getSourceType().isa<IntegerType>() || attr)
+      if (llvm::isa<IntegerType>(getSourceType()) || attr)
         return attr;
 
-      auto memrefType = getSourceType().dyn_cast<MemRefType>();
+      auto memrefType = llvm::dyn_cast<MemRefType>(getSourceType());
       assert(memrefType && "Incorrect use of getStaticStrides");
       auto [strides, offset] = getStridesAndOffset(memrefType);
       // reuse the storage of ConstStridesAttr since strides from
@@ -196,7 +196,7 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
     /// `static_shape` and `static_strides` attributes.
     std::array<unsigned, 3> getArrayAttrMaxRanks() {
       unsigned rank;
-      if (auto ty = getSourceType().dyn_cast<MemRefType>()) {
+      if (auto ty = llvm::dyn_cast<MemRefType>(getSourceType())) {
         rank = ty.getRank();
       } else {
         rank = (unsigned)getMixedOffsets().size();
diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index cc0cee6a31183c..8a077865b51b5f 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -50,14 +50,19 @@ class Attribute {
   /// Casting utility functions. These are deprecated and will be removed,
   /// please prefer using the `llvm` namespace variants instead.
   template <typename... Tys>
+  [[deprecated("Use mlir::isa<U>() instead")]]
   bool isa() const;
   template <typename... Tys>
+  [[deprecated("Use mlir::isa_and_nonnull<U>() instead")]]
   bool isa_and_nonnull() const;
   template <typename U>
+  [[deprecated("Use mlir::dyn_cast<U>() instead")]]
   U dyn_cast() const;
   template <typename U>
+  [[deprecated("Use mlir::dyn_cast_or_null<U>() instead")]]
   U dyn_cast_or_null() const;
   template <typename U>
+  [[deprecated("Use mlir::cast<U>() instead")]]
   U cast() const;
 
   /// Return a unique identifier for the concrete attribute type. This is used
diff --git a/mlir/include/mlir/IR/BuiltinLocationAttributes.td b/mlir/include/mlir/IR/BuiltinLocationAttributes.td
index dfcc180071f72a..5a72404dea15bb 100644
--- a/mlir/include/mlir/IR/BuiltinLocationAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinLocationAttributes.td
@@ -228,7 +228,8 @@ def OpaqueLoc : Builtin_LocationAttr<"OpaqueLoc"> {
     template <typename T> static T getUnderlyingLocation(Location location) {
       assert(isa<T>(location));
       return reinterpret_cast<T>(
-          location.cast<mlir::OpaqueLoc>().getUnderlyingLocation());
+          mlir::cast<mlir::OpaqueLoc>(static_cast<LocationAttr>(location))
+              .getUnderlyingLocation());
     }
 
     /// Returns a pointer to some data structure that opaque location stores.
@@ -237,15 +238,17 @@ def OpaqueLoc : Builtin_LocationAttr<"OpaqueLoc"> {
     template <typename T>
     static T getUnderlyingLocationOrNull(Location location) {
       return isa<T>(location)
-                 ? reinterpret_cast<T>(
-                       location.cast<mlir::OpaqueLoc>().getUnderlyingLocation())
-                 : T(nullptr);
+                ? reinterpret_cast<T>(mlir::cast<mlir::OpaqueLoc>(
+                                          static_cast<LocationAttr>(location))
+                                          .getUnderlyingLocation())
+                : T(nullptr);
     }
 
     /// Checks whether provided location is opaque location and contains a
     /// pointer to an object of particular type.
     template <typename T> static bool isa(Location location) {
-      auto opaque_loc = location.dyn_cast<OpaqueLoc>();
+      auto opaque_loc =
+          mlir::dyn_cast<OpaqueLoc>(static_cast<LocationAttr>(location));
       return opaque_loc && opaque_loc.getUnderlyingTypeID() == TypeID::get<T>();
     }
   }];
diff --git a/mlir/include/mlir/IR/Location.h b/mlir/include/mlir/IR/Location.h
index aa8314f38cdfac..423b4d19b5b944 100644
--- a/mlir/include/mlir/IR/Location.h
+++ b/mlir/include/mlir/IR/Location.h
@@ -78,14 +78,17 @@ class Location {
 
   /// Type casting utilities on the underlying location.
   template <typename U>
+  [[deprecated("Use mlir::isa<U>() instead")]]
   bool isa() const {
     return llvm::isa<U>(*this);
   }
   template <typename U>
+  [[deprecated("Use mlir::dyn_cast<U>() instead")]]
   U dyn_cast() const {
     return llvm::dyn_cast<U>(*this);
   }
   template <typename U>
+  [[deprecated("Use mlir::cast<U>() instead")]]
   U cast() const {
     return llvm::cast<U>(*this);
   }
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index a89e13b625bf40..65824531fdc908 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -97,14 +97,19 @@ class Type {
   bool operator!() const { return impl == nullptr; }
 
   template <typename... Tys>
+  [[deprecated("Use mlir::isa<U>() instead")]]
   bool isa() const;
   template <typename... Tys>
+  [[deprecated("Use mlir::isa_and_nonnull<U>() instead")]]
   bool isa_and_nonnull() const;
   template <typename U>
+  [[deprecated("Use mlir::dyn_cast<U>() instead")]]
   U dyn_cast() const;
   template <typename U>
+  [[deprecated("Use mlir::dyn_cast_or_null<U>() instead")]]
   U dyn_cast_or_null() const;
   template <typename U>
+  [[deprecated("Use mlir::cast<U>() instead")]]
   U cast() const;
 
   /// Return a unique identifier for the concrete type. This is used to support
diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h
index cdbc6cc374368c..a7344c64e6730d 100644
--- a/mlir/include/mlir/IR/Value.h
+++ b/mlir/include/mlir/IR/Value.h
@@ -98,25 +98,25 @@ class Value {
   constexpr Value(detail::ValueImpl *impl = nullptr) : impl(impl) {}
 
   template <typename U>
-  [[deprecated("Use isa<U>() instead")]]
+  [[deprecated("Use mlir::isa<U>() instead")]]
   bool isa() const {
     return llvm::isa<U>(*this);
   }
 
   template <typename U>
-  [[deprecated("Use dyn_cast<U>() instead")]]
+  [[deprecated("Use mlir::dyn_cast<U>() instead")]]
   U dyn_cast() const {
     return llvm::dyn_cast<U>(*this);
   }
 
   template <typename U>
-  [[deprecated("Use dyn_cast_or_null<U>() instead")]]
+  [[deprecated("Use mlir::dyn_cast_or_null<U>() instead")]]
   U dyn_cast_or_null() const {
     return llvm::dyn_cast_or_null<U>(*this);
   }
 
   template <typename U>
-  [[deprecated("Use cast<U>() instead")]]
+  [[deprecated("Use mlir::cast<U>() instead")]]
   U cast() const {
     return llvm::cast<U>(*this);
   }
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 4a15976d40c763..c2a83f90bcbe9d 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -857,7 +857,7 @@ struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
 
     auto type = cast<ComplexType>(op.getType());
-    auto elementType = type.getElementType().cast<FloatType>();
+    auto elementType = cast<FloatType>(type.getElementType());
     arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
 
     auto cst = [&](APFloat v) {
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
index ee09c73bb3c4ae..f1ec2be72a33ab 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
@@ -172,7 +172,7 @@ Attribute RingAttr::parse(AsmParser &parser, Type type) {
     if (failed(parser.parseEqual()))
       return {};
 
-    IntegerType iType = ty.dyn_cast<IntegerType>();
+    IntegerType iType = mlir::dyn_cast<IntegerType>(ty);
     if (!iType) {
       parser.emitError(parser.getCurrentLocation(),
                        "coefficientType must specify an integer type");
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 69999f0918c103..802a64b0805ee4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -140,7 +140,7 @@ struct LinearizeVectorExtractStridedSlice final
                   ConversionPatternRewriter &rewriter) const override {
     Type dstType = getTypeConverter()->convertType(extractOp.getType());
     assert(!(extractOp.getVector().getType().isScalable() ||
-             dstType.cast<VectorType>().isScalable()) &&
+             cast<VectorType>(dstType).isScalable()) &&
            "scalable vectors are not supported.");
     if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
       return rewriter.notifyMatchFailure(
@@ -172,7 +172,7 @@ struct LinearizeVectorExtractStridedSlice final
     // Get total number of extracted slices.
     int64_t nExtractedSlices = 1;
     for (Attribute size : sizes) {
-      nExtractedSlices *= size.cast<IntegerAttr>().getInt();
+      nExtractedSlices *= cast<IntegerAttr>(size).getInt();
     }
     // Compute the strides of the source vector considering first k dimensions.
     llvm::SmallVector<int64_t, 4> sourceStrides(kD, extractGranularitySize);
@@ -189,7 +189,7 @@ struct LinearizeVectorExtractStridedSlice final
     // Compute extractedStrides.
     for (int i = kD - 2; i >= 0; --i) {
       extractedStrides[i] =
-          extractedStrides[i + 1] * sizes[i + 1].cast<IntegerAttr>().getInt();
+          extractedStrides[i + 1] * cast<IntegerAttr>(sizes[i + 1]).getInt();
     }
     // Iterate over all extracted slices from 0 to nExtractedSlices - 1
     // and compute the multi-dimensional index and the corresponding linearized
@@ -207,7 +207,7 @@ struct LinearizeVectorExtractStridedSlice final
       int64_t linearizedIndex = 0;
       for (int64_t j = 0; j < kD; ++j) {
         linearizedIndex +=
-            (offsets[j].cast<IntegerAttr>().getInt() + multiDimIndex[j]) *
+            (cast<IntegerAttr>(offsets[j]).getInt() + multiDimIndex[j]) *
             sourceStrides[j];
       }
       // Fill the indices array form linearizedIndex to linearizedIndex +
@@ -254,7 +254,7 @@ struct LinearizeVectorShuffle final
     Type dstType = getTypeConverter()->convertType(shuffleOp.getType());
     assert(!(shuffleOp.getV1VectorType().isScalable() ||
              shuffleOp.getV2VectorType().isScalable() ||
-             dstType.cast<VectorType>().isScalable()) &&
+             cast<VectorType>(dstType).isScalable()) &&
            "scalable vectors are not supported.");
     if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth))
       return rewriter.notifyMatchFailure(
@@ -324,7 +324,7 @@ struct LinearizeVectorExtract final
                   ConversionPatternRewriter &rewriter) const override {
     Type dstTy = getTypeConverter()->convertType(extractOp.getType());
     assert(!(extractOp.getVector().getType().isScalable() ||
-             dstTy.cast<VectorType>().isScalable()) &&
+             cast<VectorType>(dstTy).isScalable()) &&
            "scalable vectors are not supported.");
     if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
       return rewriter.notifyMatchFailure(
@@ -405,9 +405,7 @@ void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
       [=](vector::ShuffleOp shuffleOp) -> bool {
         return isLessThanTargetBitWidth(shuffleOp, targetBitWidth)
                    ? (typeConverter.isLegal(shuffleOp) &&
-                      shuffleOp.getResult()
-                              .getType()
-                              .cast<mlir::VectorType>()
+                      cast<mlir::VectorType>(shuffleOp.getResult().getType())
                               .getRank() == 1)
                    : true;
       });

>From 644910a5b4ef1cd48f1c7936f8bc2e2398894a13 Mon Sep 17 00:00:00 2001
From: Christian Sigg <csigg at google.com>
Date: Thu, 25 Apr 2024 10:01:10 +0200
Subject: [PATCH 2/3] Update flang code, disable warning in TypeSwitch.

---
 .../include/flang/Optimizer/Dialect/FIRType.h | 70 ++++++++++---------
 .../Dialect/FortranVariableInterface.td       | 10 +--
 flang/lib/Optimizer/Dialect/FIRAttr.cpp       | 14 ++--
 .../Dialect/FortranVariableInterface.cpp      | 11 +--
 llvm/include/llvm/ADT/TypeSwitch.h            |  3 +
 5 files changed, 57 insertions(+), 51 deletions(-)

diff --git a/flang/include/flang/Optimizer/Dialect/FIRType.h b/flang/include/flang/Optimizer/Dialect/FIRType.h
index 7fcd9c1babf24f..219fc3459af3c8 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRType.h
+++ b/flang/include/flang/Optimizer/Dialect/FIRType.h
@@ -97,35 +97,36 @@ bool isa_fir_or_std_type(mlir::Type t);
 
 /// Is `t` a FIR dialect type that implies a memory (de)reference?
 inline bool isa_ref_type(mlir::Type t) {
-  return t.isa<fir::ReferenceType, fir::PointerType, fir::HeapType,
-               fir::LLVMPointerType>();
+  return mlir::isa<fir::ReferenceType, fir::PointerType, fir::HeapType,
+                   fir::LLVMPointerType>(t);
 }
 
 /// Is `t` a boxed type?
 inline bool isa_box_type(mlir::Type t) {
-  return t.isa<fir::BaseBoxType, fir::BoxCharType, fir::BoxProcType>();
+  return mlir::isa<fir::BaseBoxType, fir::BoxCharType, fir::BoxProcType>(t);
 }
 
 /// Is `t` a type that is always trivially pass-by-reference? Specifically, this
 /// is testing if `t` is a ReferenceType or any box type. Compare this to
 /// conformsWithPassByRef(), which includes pointers and allocatables.
 inline bool isa_passbyref_type(mlir::Type t) {
-  return t.isa<fir::ReferenceType, mlir::FunctionType>() || isa_box_type(t);
+  return mlir::isa<fir::ReferenceType, mlir::FunctionType>(t) ||
+         isa_box_type(t);
 }
 
 /// Is `t` a type that can conform to be pass-by-reference? Depending on the
 /// context, these types may simply demote to pass-by-reference or a reference
 /// to them may have to be passed instead. Functions are always referent.
 inline bool conformsWithPassByRef(mlir::Type t) {
-  return isa_ref_type(t) || isa_box_type(t) || t.isa<mlir::FunctionType>();
+  return isa_ref_type(t) || isa_box_type(t) || mlir::isa<mlir::FunctionType>(t);
 }
 
 /// Is `t` a derived (record) type?
-inline bool isa_derived(mlir::Type t) { return t.isa<fir::RecordType>(); }
+inline bool isa_derived(mlir::Type t) { return mlir::isa<fir::RecordType>(t); }
 
 /// Is `t` type(c_ptr) or type(c_funptr)?
 inline bool isa_builtin_cptr_type(mlir::Type t) {
-  if (auto recTy = t.dyn_cast_or_null<fir::RecordType>())
+  if (auto recTy = mlir::dyn_cast_or_null<fir::RecordType>(t))
     return recTy.getName().ends_with("T__builtin_c_ptr") ||
            recTy.getName().ends_with("T__builtin_c_funptr");
   return false;
@@ -133,7 +134,7 @@ inline bool isa_builtin_cptr_type(mlir::Type t) {
 
 /// Is `t` a FIR dialect aggregate type?
 inline bool isa_aggregate(mlir::Type t) {
-  return t.isa<SequenceType, mlir::TupleType>() || fir::isa_derived(t);
+  return mlir::isa<SequenceType, mlir::TupleType>(t) || fir::isa_derived(t);
 }
 
 /// Extract the `Type` pointed to from a FIR memory reference type. If `t` is
@@ -146,17 +147,17 @@ mlir::Type dyn_cast_ptrOrBoxEleTy(mlir::Type t);
 
 /// Is `t` a FIR Real or MLIR Float type?
 inline bool isa_real(mlir::Type t) {
-  return t.isa<fir::RealType, mlir::FloatType>();
+  return mlir::isa<fir::RealType, mlir::FloatType>(t);
 }
 
 /// Is `t` an integral type?
 inline bool isa_integer(mlir::Type t) {
-  return t.isa<mlir::IndexType, mlir::IntegerType, fir::IntegerType>();
+  return mlir::isa<mlir::IndexType, mlir::IntegerType, fir::IntegerType>(t);
 }
 
 /// Is `t` a vector type?
 inline bool isa_vector(mlir::Type t) {
-  return t.isa<mlir::VectorType, fir::VectorType>();
+  return mlir::isa<mlir::VectorType, fir::VectorType>(t);
 }
 
 mlir::Type parseFirType(FIROpsDialect *, mlir::DialectAsmParser &parser);
@@ -169,22 +170,22 @@ void verifyIntegralType(mlir::Type type);
 
 /// Is `t` a FIR or MLIR Complex type?
 inline bool isa_complex(mlir::Type t) {
-  return t.isa<fir::ComplexType, mlir::ComplexType>();
+  return mlir::isa<fir::ComplexType, mlir::ComplexType>(t);
 }
 
 /// Is `t` a CHARACTER type? Does not check the length.
-inline bool isa_char(mlir::Type t) { return t.isa<fir::CharacterType>(); }
+inline bool isa_char(mlir::Type t) { return mlir::isa<fir::CharacterType>(t); }
 
 /// Is `t` a trivial intrinsic type? CHARACTER is <em>excluded</em> because it
 /// is a dependent type.
 inline bool isa_trivial(mlir::Type t) {
   return isa_integer(t) || isa_real(t) || isa_complex(t) || isa_vector(t) ||
-         t.isa<fir::LogicalType>();
+         mlir::isa<fir::LogicalType>(t);
 }
 
 /// Is `t` a CHARACTER type with a LEN other than 1?
 inline bool isa_char_string(mlir::Type t) {
-  if (auto ct = t.dyn_cast_or_null<fir::CharacterType>())
+  if (auto ct = mlir::dyn_cast_or_null<fir::CharacterType>(t))
     return ct.getLen() != fir::CharacterType::singleton();
   return false;
 }
@@ -198,7 +199,7 @@ bool isa_unknown_size_box(mlir::Type t);
 
 /// Returns true iff `t` is a fir.char type and has an unknown length.
 inline bool characterWithDynamicLen(mlir::Type t) {
-  if (auto charTy = t.dyn_cast<fir::CharacterType>())
+  if (auto charTy = mlir::dyn_cast<fir::CharacterType>(t))
     return charTy.hasDynamicLen();
   return false;
 }
@@ -213,11 +214,11 @@ inline bool sequenceWithNonConstantShape(fir::SequenceType seqTy) {
 bool hasDynamicSize(mlir::Type t);
 
 inline unsigned getRankOfShapeType(mlir::Type t) {
-  if (auto shTy = t.dyn_cast<fir::ShapeType>())
+  if (auto shTy = mlir::dyn_cast<fir::ShapeType>(t))
     return shTy.getRank();
-  if (auto shTy = t.dyn_cast<fir::ShapeShiftType>())
+  if (auto shTy = mlir::dyn_cast<fir::ShapeShiftType>(t))
     return shTy.getRank();
-  if (auto shTy = t.dyn_cast<fir::ShiftType>())
+  if (auto shTy = mlir::dyn_cast<fir::ShiftType>(t))
     return shTy.getRank();
   return 0;
 }
@@ -225,14 +226,14 @@ inline unsigned getRankOfShapeType(mlir::Type t) {
 /// Get the memory reference type of the data pointer from the box type,
 inline mlir::Type boxMemRefType(fir::BaseBoxType t) {
   auto eleTy = t.getEleTy();
-  if (!eleTy.isa<fir::PointerType, fir::HeapType>())
+  if (!mlir::isa<fir::PointerType, fir::HeapType>(eleTy))
     eleTy = fir::ReferenceType::get(t);
   return eleTy;
 }
 
 /// If `t` is a SequenceType return its element type, otherwise return `t`.
 inline mlir::Type unwrapSequenceType(mlir::Type t) {
-  if (auto seqTy = t.dyn_cast<fir::SequenceType>())
+  if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(t))
     return seqTy.getEleTy();
   return t;
 }
@@ -278,7 +279,7 @@ inline fir::SequenceType unwrapUntilSeqType(mlir::Type t) {
       t = ty;
       continue;
     }
-    if (auto seqTy = t.dyn_cast<fir::SequenceType>())
+    if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(t))
       return seqTy;
     return {};
   }
@@ -377,7 +378,7 @@ bool isRecordWithDescriptorMember(mlir::Type ty);
 
 /// Return true iff `ty` is a RecordType with type parameters.
 inline bool isRecordWithTypeParameters(mlir::Type ty) {
-  if (auto recTy = ty.dyn_cast_or_null<fir::RecordType>())
+  if (auto recTy = mlir::dyn_cast_or_null<fir::RecordType>(ty))
     return recTy.isDependentType();
   return false;
 }
@@ -401,14 +402,14 @@ mlir::Type fromRealTypeID(mlir::MLIRContext *context, llvm::Type::TypeID typeID,
 int getTypeCode(mlir::Type ty, const KindMapping &kindMap);
 
 inline bool BaseBoxType::classof(mlir::Type type) {
-  return type.isa<fir::BoxType, fir::ClassType>();
+  return mlir::isa<fir::BoxType, fir::ClassType>(type);
 }
 
 /// Return true iff `ty` is none or fir.array<none>.
 inline bool isNoneOrSeqNone(mlir::Type type) {
-  if (auto seqTy = type.dyn_cast<fir::SequenceType>())
-    return seqTy.getEleTy().isa<mlir::NoneType>();
-  return type.isa<mlir::NoneType>();
+  if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(type))
+    return mlir::<mlir::NoneType>(seqTy.getEleTy());
+  return mlir::isa<mlir::NoneType>(type);
 }
 
 /// Return a fir.box<T> or fir.class<T> if the type is polymorphic. If the type
@@ -428,16 +429,16 @@ inline mlir::Type wrapInClassOrBoxType(mlir::Type eleTy,
 /// !fir.array<2xf32> -> !fir.array<2xnone>
 /// !fir.heap<!fir.array<2xf32>> -> !fir.heap<!fir.array<2xnone>>
 inline mlir::Type updateTypeForUnlimitedPolymorphic(mlir::Type ty) {
-  if (auto seqTy = ty.dyn_cast<fir::SequenceType>())
+  if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(ty))
     return fir::SequenceType::get(
         seqTy.getShape(), updateTypeForUnlimitedPolymorphic(seqTy.getEleTy()));
-  if (auto heapTy = ty.dyn_cast<fir::HeapType>())
+  if (auto heapTy = mlir::dyn_cast<fir::HeapType>(ty))
     return fir::HeapType::get(
         updateTypeForUnlimitedPolymorphic(heapTy.getEleTy()));
-  if (auto pointerTy = ty.dyn_cast<fir::PointerType>())
+  if (auto pointerTy = mlir::dyn_cast<fir::PointerType>(ty))
     return fir::PointerType::get(
         updateTypeForUnlimitedPolymorphic(pointerTy.getEleTy()));
-  if (!ty.isa<mlir::NoneType, fir::RecordType>())
+  if (!mlir::isa<mlir::NoneType, fir::RecordType>(ty))
     return mlir::NoneType::get(ty.getContext());
   return ty;
 }
@@ -451,18 +452,19 @@ mlir::Type changeElementType(mlir::Type type, mlir::Type newElementType,
 
 /// Is `t` an address to fir.box or class type?
 inline bool isBoxAddress(mlir::Type t) {
-  return fir::isa_ref_type(t) && fir::unwrapRefType(t).isa<fir::BaseBoxType>();
+  return fir::isa_ref_type(t) &&
+         mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(t));
 }
 
 /// Is `t` a fir.box or class address or value type?
 inline bool isBoxAddressOrValue(mlir::Type t) {
-  return fir::unwrapRefType(t).isa<fir::BaseBoxType>();
+  return mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(t));
 }
 
 /// Is this a fir.boxproc address type?
 inline bool isBoxProcAddressType(mlir::Type t) {
   t = fir::dyn_cast_ptrEleTy(t);
-  return t && t.isa<fir::BoxProcType>();
+  return t && mlir::isa<fir::BoxProcType>(t);
 }
 
 /// Return a string representation of `ty`.
diff --git a/flang/include/flang/Optimizer/Dialect/FortranVariableInterface.td b/flang/include/flang/Optimizer/Dialect/FortranVariableInterface.td
index 6405afbf1bfbc2..3f78a93a2515ef 100644
--- a/flang/include/flang/Optimizer/Dialect/FortranVariableInterface.td
+++ b/flang/include/flang/Optimizer/Dialect/FortranVariableInterface.td
@@ -75,7 +75,7 @@ def fir_FortranVariableOpInterface : OpInterface<"FortranVariableOpInterface"> {
     /// variable.
     mlir::Type getElementOrSequenceType() {
       mlir::Type type = fir::unwrapPassByRefType(fir::unwrapRefType(getBase().getType()));
-      if (auto boxCharType = type.dyn_cast<fir::BoxCharType>())
+      if (auto boxCharType = mlir::dyn_cast<fir::BoxCharType>(type))
         return boxCharType.getEleTy();
       return type;
     }
@@ -87,13 +87,13 @@ def fir_FortranVariableOpInterface : OpInterface<"FortranVariableOpInterface"> {
 
     /// Is the variable an array?
     bool isArray() {
-      return getElementOrSequenceType().isa<fir::SequenceType>();
+      return mlir::isa<fir::SequenceType>(getElementOrSequenceType());
     }
 
     /// Return the rank of the entity if it is known at compile time.
     std::optional<unsigned> getRank() {
       if (auto sequenceType =
-            getElementOrSequenceType().dyn_cast<fir::SequenceType>()) {
+            mlir::dyn_cast<fir::SequenceType>(getElementOrSequenceType())) {
         if (sequenceType.hasUnknownShape())
           return {};
         return sequenceType.getDimension();
@@ -133,7 +133,7 @@ def fir_FortranVariableOpInterface : OpInterface<"FortranVariableOpInterface"> {
 
     /// Is this a Fortran character variable?
     bool isCharacter() {
-      return getElementType().isa<fir::CharacterType>();
+      return mlir::isa<fir::CharacterType>(getElementType());
     }
 
     /// Is this a Fortran character variable with an explicit length?
@@ -149,7 +149,7 @@ def fir_FortranVariableOpInterface : OpInterface<"FortranVariableOpInterface"> {
 
     /// Is this variable represented as a fir.box or fir.class value?
     bool isBoxValue() {
-      return getBase().getType().isa<fir::BaseBoxType>();
+      return mlir::isa<fir::BaseBoxType>(getBase().getType());
     }
 
     /// Is this variable represented as a fir.box or fir.class address?
diff --git a/flang/lib/Optimizer/Dialect/FIRAttr.cpp b/flang/lib/Optimizer/Dialect/FIRAttr.cpp
index e43710f5627ee0..9ea3a0568f6916 100644
--- a/flang/lib/Optimizer/Dialect/FIRAttr.cpp
+++ b/flang/lib/Optimizer/Dialect/FIRAttr.cpp
@@ -264,23 +264,23 @@ void fir::FortranVariableFlagsAttr::print(mlir::AsmPrinter &printer) const {
 void fir::printFirAttribute(FIROpsDialect *dialect, mlir::Attribute attr,
                             mlir::DialectAsmPrinter &p) {
   auto &os = p.getStream();
-  if (auto exact = attr.dyn_cast<fir::ExactTypeAttr>()) {
+  if (auto exact = mlir::dyn_cast<fir::ExactTypeAttr>(attr)) {
     os << fir::ExactTypeAttr::getAttrName() << '<';
     p.printType(exact.getType());
     os << '>';
-  } else if (auto sub = attr.dyn_cast<fir::SubclassAttr>()) {
+  } else if (auto sub = mlir::dyn_cast<fir::SubclassAttr>(attr)) {
     os << fir::SubclassAttr::getAttrName() << '<';
     p.printType(sub.getType());
     os << '>';
-  } else if (attr.dyn_cast_or_null<fir::PointIntervalAttr>()) {
+  } else if (mlir::dyn_cast_or_null<fir::PointIntervalAttr>(attr)) {
     os << fir::PointIntervalAttr::getAttrName();
-  } else if (attr.dyn_cast_or_null<fir::ClosedIntervalAttr>()) {
+  } else if (mlir::dyn_cast_or_null<fir::ClosedIntervalAttr>(attr)) {
     os << fir::ClosedIntervalAttr::getAttrName();
-  } else if (attr.dyn_cast_or_null<fir::LowerBoundAttr>()) {
+  } else if (mlir::dyn_cast_or_null<fir::LowerBoundAttr>(attr)) {
     os << fir::LowerBoundAttr::getAttrName();
-  } else if (attr.dyn_cast_or_null<fir::UpperBoundAttr>()) {
+  } else if (mlir::dyn_cast_or_null<fir::UpperBoundAttr>(attr)) {
     os << fir::UpperBoundAttr::getAttrName();
-  } else if (auto a = attr.dyn_cast_or_null<fir::RealAttr>()) {
+  } else if (auto a = mlir::dyn_cast_or_null<fir::RealAttr>(attr)) {
     os << fir::RealAttr::getAttrName() << '<' << a.getFKind() << ", i x";
     llvm::SmallString<40> ss;
     a.getValue().bitcastToAPInt().toStringUnsigned(ss, 16);
diff --git a/flang/lib/Optimizer/Dialect/FortranVariableInterface.cpp b/flang/lib/Optimizer/Dialect/FortranVariableInterface.cpp
index 94f1689dfb0585..70b1a2f3d8446f 100644
--- a/flang/lib/Optimizer/Dialect/FortranVariableInterface.cpp
+++ b/flang/lib/Optimizer/Dialect/FortranVariableInterface.cpp
@@ -18,7 +18,7 @@ mlir::LogicalResult
 fir::FortranVariableOpInterface::verifyDeclareLikeOpImpl(mlir::Value memref) {
   const unsigned numExplicitTypeParams = getExplicitTypeParams().size();
   mlir::Type memType = memref.getType();
-  const bool sourceIsBoxValue = memType.isa<fir::BaseBoxType>();
+  const bool sourceIsBoxValue = mlir::isa<fir::BaseBoxType>(memType);
   const bool sourceIsBoxAddress = fir::isBoxAddress(memType);
   const bool sourceIsBox = sourceIsBoxValue || sourceIsBoxAddress;
   if (isCharacter()) {
@@ -29,7 +29,8 @@ fir::FortranVariableOpInterface::verifyDeclareLikeOpImpl(mlir::Value memref) {
       return emitOpError("must be provided exactly one type parameter when its "
                          "base is a character that is not a box");
 
-  } else if (auto recordType = getElementType().dyn_cast<fir::RecordType>()) {
+  } else if (auto recordType =
+                 mlir::dyn_cast<fir::RecordType>(getElementType())) {
     if (numExplicitTypeParams < recordType.getNumLenParams() && !sourceIsBox)
       return emitOpError("must be provided all the derived type length "
                          "parameters when the base is not a box");
@@ -45,16 +46,16 @@ fir::FortranVariableOpInterface::verifyDeclareLikeOpImpl(mlir::Value memref) {
       if (sourceIsBoxAddress)
         return emitOpError("for box address must not have a shape operand");
       unsigned shapeRank = 0;
-      if (auto shapeType = shape.getType().dyn_cast<fir::ShapeType>()) {
+      if (auto shapeType = mlir::dyn_cast<fir::ShapeType>(shape.getType())) {
         shapeRank = shapeType.getRank();
       } else if (auto shapeShiftType =
-                     shape.getType().dyn_cast<fir::ShapeShiftType>()) {
+                     mlir::dyn_cast<fir::ShapeShiftType>(shape.getType())) {
         shapeRank = shapeShiftType.getRank();
       } else {
         if (!sourceIsBoxValue)
           emitOpError("of array entity with a raw address base must have a "
                       "shape operand that is a shape or shapeshift");
-        shapeRank = shape.getType().cast<fir::ShiftType>().getRank();
+        shapeRank = mlir::cast<fir::ShiftType>(shape.getType()).getRank();
       }
 
       std::optional<unsigned> rank = getRank();
diff --git a/llvm/include/llvm/ADT/TypeSwitch.h b/llvm/include/llvm/ADT/TypeSwitch.h
index 10a2d48e918db9..14ad56ad575ffc 100644
--- a/llvm/include/llvm/ADT/TypeSwitch.h
+++ b/llvm/include/llvm/ADT/TypeSwitch.h
@@ -74,7 +74,10 @@ template <typename DerivedT, typename T> class TypeSwitchBase {
       ValueT &&value,
       std::enable_if_t<is_detected<has_dyn_cast_t, ValueT, CastT>::value> * =
           nullptr) {
+    // Silence warnings about MLIR's deprecated dyn_cast member functions.
+    LLVM_SUPPRESS_DEPRECATED_DECLARATIONS_PUSH
     return value.template dyn_cast<CastT>();
+    LLVM_SUPPRESS_DEPRECATED_DECLARATIONS_POP
   }
 
   /// Attempt to dyn_cast the given `value` to `CastT`. This overload is

>From 0967fa16dce5d0c49880830d9fe7fa19d875488c Mon Sep 17 00:00:00 2001
From: Christian Sigg <csigg at google.com>
Date: Thu, 25 Apr 2024 12:07:33 +0200
Subject: [PATCH 3/3] Fix typo: add missing `isa` in isNoneOrSeqNone.

---
 flang/include/flang/Optimizer/Dialect/FIRType.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/flang/include/flang/Optimizer/Dialect/FIRType.h b/flang/include/flang/Optimizer/Dialect/FIRType.h
index 219fc3459af3c8..1f200b2e81a4ee 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRType.h
+++ b/flang/include/flang/Optimizer/Dialect/FIRType.h
@@ -408,7 +408,7 @@ inline bool BaseBoxType::classof(mlir::Type type) {
 /// Return true iff `ty` is none or fir.array<none>.
 inline bool isNoneOrSeqNone(mlir::Type type) {
   if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(type))
-    return mlir::<mlir::NoneType>(seqTy.getEleTy());
+    return mlir::isa<mlir::NoneType>(seqTy.getEleTy());
   return mlir::isa<mlir::NoneType>(type);
 }
 



More information about the flang-commits mailing list