[Mlir-commits] [mlir] Mark `isa/dyn_cast/cast/...` member functions deprecated. (PR #89998)

Christian Sigg llvmlistbot at llvm.org
Wed Apr 24 14:47:56 PDT 2024


https://github.com/chsigg created https://github.com/llvm/llvm-project/pull/89998

See <https://mlir.llvm.org/deprecation> and
<https://discourse.llvm.org/t/preferred-casting-style-going-forward>.

>From d90b64838859993de356406697e010f129b04b53 Mon Sep 17 00:00:00 2001
From: Christian Sigg <csigg at google.com>
Date: Wed, 24 Apr 2024 23:46:33 +0200
Subject: [PATCH] Mark `isa/dyn_cast/cast/...` member functions deprecated.

See https://mlir.llvm.org/deprecation and
https://discourse.llvm.org/t/preferred-casting-style-going-forward
---
 mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td |  4 ++--
 .../mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td   |  6 +++---
 mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td   | 10 +++++-----
 mlir/include/mlir/IR/Attributes.h                |  5 +++++
 .../include/mlir/IR/BuiltinLocationAttributes.td | 13 ++++++++-----
 mlir/include/mlir/IR/Location.h                  |  3 +++
 mlir/include/mlir/IR/Types.h                     |  5 +++++
 mlir/include/mlir/IR/Value.h                     |  8 ++++----
 .../ComplexToStandard/ComplexToStandard.cpp      |  2 +-
 .../Polynomial/IR/PolynomialAttributes.cpp       |  2 +-
 .../Vector/Transforms/VectorLinearize.cpp        | 16 +++++++---------
 11 files changed, 44 insertions(+), 30 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index da12e7c83b22b8..64c538367267dc 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -138,10 +138,10 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
 
   let extraClassDeclaration = [{
     ShapedType getInputOperandType() {
-      return getInput().getType().cast<ShapedType>();
+      return cast<ShapedType>(getInput().getType());
     }
     ShapedType getOutputOperandType() {
-      return getOutput().getType().cast<ShapedType>();
+      return cast<ShapedType>(getOutput().getType());
     }
     int64_t getInputOperandRank() {
       return getInputOperandType().getRank();
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
index ab9b78e755d9d5..c23937cac7538c 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
@@ -234,8 +234,8 @@ def OffloadModuleInterface : OpInterface<"OffloadModuleInterface"> {
       /*methodName=*/"getIsTargetDevice",
       (ins), [{}], [{
         if (Attribute isTargetDevice = $_op->getAttr("omp.is_target_device"))
-          if (isTargetDevice.isa<mlir::BoolAttr>())
-           return isTargetDevice.dyn_cast<BoolAttr>().getValue();
+          if (isa<mlir::BoolAttr>(isTargetDevice))
+           return dyn_cast<BoolAttr>(isTargetDevice).getValue();
         return false;
       }]>,
     InterfaceMethod<
@@ -259,7 +259,7 @@ def OffloadModuleInterface : OpInterface<"OffloadModuleInterface"> {
       /*methodName=*/"getIsGPU",
       (ins), [{}], [{
         if (Attribute isTargetCGAttr = $_op->getAttr("omp.is_gpu"))
-          if (auto isTargetCGVal = isTargetCGAttr.dyn_cast<BoolAttr>())
+          if (auto isTargetCGVal = dyn_cast<BoolAttr>(isTargetCGAttr))
            return isTargetCGVal.getValue();
         return false;
       }]>,
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index c6f7f83441b96c..07b53a5a077da2 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -164,10 +164,10 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
     /// source operand. They overide static shape from source memref type.
     ArrayRef<int64_t> getStaticSizes() {
       auto attr = getConstShapeAttr();
-      if (getSourceType().isa<IntegerType>() || attr)
+      if (llvm::isa<IntegerType>(getSourceType()) || attr)
         return attr;
 
-      auto memrefType = getSourceType().dyn_cast<MemRefType>();
+      auto memrefType = llvm::dyn_cast<MemRefType>(getSourceType());
       assert(memrefType && "Incorrect use of getStaticSizes");
       return memrefType.getShape();
     }
@@ -179,10 +179,10 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
     /// source operand. They overide static strides from source memref type.
     ArrayRef<int64_t> getStaticStrides() {
       auto attr = getConstStridesAttr();
-      if (getSourceType().isa<IntegerType>() || attr)
+      if (llvm::isa<IntegerType>(getSourceType()) || attr)
         return attr;
 
-      auto memrefType = getSourceType().dyn_cast<MemRefType>();
+      auto memrefType = llvm::dyn_cast<MemRefType>(getSourceType());
       assert(memrefType && "Incorrect use of getStaticStrides");
       auto [strides, offset] = getStridesAndOffset(memrefType);
       // reuse the storage of ConstStridesAttr since strides from
@@ -196,7 +196,7 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
     /// `static_shape` and `static_strides` attributes.
     std::array<unsigned, 3> getArrayAttrMaxRanks() {
       unsigned rank;
-      if (auto ty = getSourceType().dyn_cast<MemRefType>()) {
+      if (auto ty = llvm::dyn_cast<MemRefType>(getSourceType())) {
         rank = ty.getRank();
       } else {
         rank = (unsigned)getMixedOffsets().size();
diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index cc0cee6a31183c..8a077865b51b5f 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -50,14 +50,19 @@ class Attribute {
   /// Casting utility functions. These are deprecated and will be removed,
   /// please prefer using the `llvm` namespace variants instead.
   template <typename... Tys>
+  [[deprecated("Use mlir::isa<U>() instead")]]
   bool isa() const;
   template <typename... Tys>
+  [[deprecated("Use mlir::isa_and_nonnull<U>() instead")]]
   bool isa_and_nonnull() const;
   template <typename U>
+  [[deprecated("Use mlir::dyn_cast<U>() instead")]]
   U dyn_cast() const;
   template <typename U>
+  [[deprecated("Use mlir::dyn_cast_or_null<U>() instead")]]
   U dyn_cast_or_null() const;
   template <typename U>
+  [[deprecated("Use mlir::cast<U>() instead")]]
   U cast() const;
 
   /// Return a unique identifier for the concrete attribute type. This is used
diff --git a/mlir/include/mlir/IR/BuiltinLocationAttributes.td b/mlir/include/mlir/IR/BuiltinLocationAttributes.td
index dfcc180071f72a..5a72404dea15bb 100644
--- a/mlir/include/mlir/IR/BuiltinLocationAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinLocationAttributes.td
@@ -228,7 +228,8 @@ def OpaqueLoc : Builtin_LocationAttr<"OpaqueLoc"> {
     template <typename T> static T getUnderlyingLocation(Location location) {
       assert(isa<T>(location));
       return reinterpret_cast<T>(
-          location.cast<mlir::OpaqueLoc>().getUnderlyingLocation());
+          mlir::cast<mlir::OpaqueLoc>(static_cast<LocationAttr>(location))
+              .getUnderlyingLocation());
     }
 
     /// Returns a pointer to some data structure that opaque location stores.
@@ -237,15 +238,17 @@ def OpaqueLoc : Builtin_LocationAttr<"OpaqueLoc"> {
     template <typename T>
     static T getUnderlyingLocationOrNull(Location location) {
       return isa<T>(location)
-                 ? reinterpret_cast<T>(
-                       location.cast<mlir::OpaqueLoc>().getUnderlyingLocation())
-                 : T(nullptr);
+                ? reinterpret_cast<T>(mlir::cast<mlir::OpaqueLoc>(
+                                          static_cast<LocationAttr>(location))
+                                          .getUnderlyingLocation())
+                : T(nullptr);
     }
 
     /// Checks whether provided location is opaque location and contains a
     /// pointer to an object of particular type.
     template <typename T> static bool isa(Location location) {
-      auto opaque_loc = location.dyn_cast<OpaqueLoc>();
+      auto opaque_loc =
+          mlir::dyn_cast<OpaqueLoc>(static_cast<LocationAttr>(location));
       return opaque_loc && opaque_loc.getUnderlyingTypeID() == TypeID::get<T>();
     }
   }];
diff --git a/mlir/include/mlir/IR/Location.h b/mlir/include/mlir/IR/Location.h
index aa8314f38cdfac..423b4d19b5b944 100644
--- a/mlir/include/mlir/IR/Location.h
+++ b/mlir/include/mlir/IR/Location.h
@@ -78,14 +78,17 @@ class Location {
 
   /// Type casting utilities on the underlying location.
   template <typename U>
+  [[deprecated("Use mlir::isa<U>() instead")]]
   bool isa() const {
     return llvm::isa<U>(*this);
   }
   template <typename U>
+  [[deprecated("Use mlir::dyn_cast<U>() instead")]]
   U dyn_cast() const {
     return llvm::dyn_cast<U>(*this);
   }
   template <typename U>
+  [[deprecated("Use mlir::cast<U>() instead")]]
   U cast() const {
     return llvm::cast<U>(*this);
   }
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index a89e13b625bf40..65824531fdc908 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -97,14 +97,19 @@ class Type {
   bool operator!() const { return impl == nullptr; }
 
   template <typename... Tys>
+  [[deprecated("Use mlir::isa<U>() instead")]]
   bool isa() const;
   template <typename... Tys>
+  [[deprecated("Use mlir::isa_and_nonnull<U>() instead")]]
   bool isa_and_nonnull() const;
   template <typename U>
+  [[deprecated("Use mlir::dyn_cast<U>() instead")]]
   U dyn_cast() const;
   template <typename U>
+  [[deprecated("Use mlir::dyn_cast_or_null<U>() instead")]]
   U dyn_cast_or_null() const;
   template <typename U>
+  [[deprecated("Use mlir::cast<U>() instead")]]
   U cast() const;
 
   /// Return a unique identifier for the concrete type. This is used to support
diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h
index cdbc6cc374368c..a7344c64e6730d 100644
--- a/mlir/include/mlir/IR/Value.h
+++ b/mlir/include/mlir/IR/Value.h
@@ -98,25 +98,25 @@ class Value {
   constexpr Value(detail::ValueImpl *impl = nullptr) : impl(impl) {}
 
   template <typename U>
-  [[deprecated("Use isa<U>() instead")]]
+  [[deprecated("Use mlir::isa<U>() instead")]]
   bool isa() const {
     return llvm::isa<U>(*this);
   }
 
   template <typename U>
-  [[deprecated("Use dyn_cast<U>() instead")]]
+  [[deprecated("Use mlir::dyn_cast<U>() instead")]]
   U dyn_cast() const {
     return llvm::dyn_cast<U>(*this);
   }
 
   template <typename U>
-  [[deprecated("Use dyn_cast_or_null<U>() instead")]]
+  [[deprecated("Use mlir::dyn_cast_or_null<U>() instead")]]
   U dyn_cast_or_null() const {
     return llvm::dyn_cast_or_null<U>(*this);
   }
 
   template <typename U>
-  [[deprecated("Use cast<U>() instead")]]
+  [[deprecated("Use mlir::cast<U>() instead")]]
   U cast() const {
     return llvm::cast<U>(*this);
   }
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 4a15976d40c763..c2a83f90bcbe9d 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -857,7 +857,7 @@ struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
 
     auto type = cast<ComplexType>(op.getType());
-    auto elementType = type.getElementType().cast<FloatType>();
+    auto elementType = cast<FloatType>(type.getElementType());
     arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
 
     auto cst = [&](APFloat v) {
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
index ee09c73bb3c4ae..f1ec2be72a33ab 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
@@ -172,7 +172,7 @@ Attribute RingAttr::parse(AsmParser &parser, Type type) {
     if (failed(parser.parseEqual()))
       return {};
 
-    IntegerType iType = ty.dyn_cast<IntegerType>();
+    IntegerType iType = mlir::dyn_cast<IntegerType>(ty);
     if (!iType) {
       parser.emitError(parser.getCurrentLocation(),
                        "coefficientType must specify an integer type");
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 69999f0918c103..802a64b0805ee4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -140,7 +140,7 @@ struct LinearizeVectorExtractStridedSlice final
                   ConversionPatternRewriter &rewriter) const override {
     Type dstType = getTypeConverter()->convertType(extractOp.getType());
     assert(!(extractOp.getVector().getType().isScalable() ||
-             dstType.cast<VectorType>().isScalable()) &&
+             cast<VectorType>(dstType).isScalable()) &&
            "scalable vectors are not supported.");
     if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
       return rewriter.notifyMatchFailure(
@@ -172,7 +172,7 @@ struct LinearizeVectorExtractStridedSlice final
     // Get total number of extracted slices.
     int64_t nExtractedSlices = 1;
     for (Attribute size : sizes) {
-      nExtractedSlices *= size.cast<IntegerAttr>().getInt();
+      nExtractedSlices *= cast<IntegerAttr>(size).getInt();
     }
     // Compute the strides of the source vector considering first k dimensions.
     llvm::SmallVector<int64_t, 4> sourceStrides(kD, extractGranularitySize);
@@ -189,7 +189,7 @@ struct LinearizeVectorExtractStridedSlice final
     // Compute extractedStrides.
     for (int i = kD - 2; i >= 0; --i) {
       extractedStrides[i] =
-          extractedStrides[i + 1] * sizes[i + 1].cast<IntegerAttr>().getInt();
+          extractedStrides[i + 1] * cast<IntegerAttr>(sizes[i + 1]).getInt();
     }
     // Iterate over all extracted slices from 0 to nExtractedSlices - 1
     // and compute the multi-dimensional index and the corresponding linearized
@@ -207,7 +207,7 @@ struct LinearizeVectorExtractStridedSlice final
       int64_t linearizedIndex = 0;
       for (int64_t j = 0; j < kD; ++j) {
         linearizedIndex +=
-            (offsets[j].cast<IntegerAttr>().getInt() + multiDimIndex[j]) *
+            (cast<IntegerAttr>(offsets[j]).getInt() + multiDimIndex[j]) *
             sourceStrides[j];
       }
       // Fill the indices array form linearizedIndex to linearizedIndex +
@@ -254,7 +254,7 @@ struct LinearizeVectorShuffle final
     Type dstType = getTypeConverter()->convertType(shuffleOp.getType());
     assert(!(shuffleOp.getV1VectorType().isScalable() ||
              shuffleOp.getV2VectorType().isScalable() ||
-             dstType.cast<VectorType>().isScalable()) &&
+             cast<VectorType>(dstType).isScalable()) &&
            "scalable vectors are not supported.");
     if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth))
       return rewriter.notifyMatchFailure(
@@ -324,7 +324,7 @@ struct LinearizeVectorExtract final
                   ConversionPatternRewriter &rewriter) const override {
     Type dstTy = getTypeConverter()->convertType(extractOp.getType());
     assert(!(extractOp.getVector().getType().isScalable() ||
-             dstTy.cast<VectorType>().isScalable()) &&
+             cast<VectorType>(dstTy).isScalable()) &&
            "scalable vectors are not supported.");
     if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
       return rewriter.notifyMatchFailure(
@@ -405,9 +405,7 @@ void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
       [=](vector::ShuffleOp shuffleOp) -> bool {
         return isLessThanTargetBitWidth(shuffleOp, targetBitWidth)
                    ? (typeConverter.isLegal(shuffleOp) &&
-                      shuffleOp.getResult()
-                              .getType()
-                              .cast<mlir::VectorType>()
+                      cast<mlir::VectorType>(shuffleOp.getResult().getType())
                               .getRank() == 1)
                    : true;
       });



More information about the Mlir-commits mailing list