[Mlir-commits] [mlir] [mlir] Migrate away from PointerUnion::dyn_cast (NFC) (PR #123693)

Kazu Hirata llvmlistbot at llvm.org
Mon Jan 20 22:37:50 PST 2025


https://github.com/kazutakahirata created https://github.com/llvm/llvm-project/pull/123693

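Note that PointerUnion::dyn_cast has been soft-deprecated in
PointerUnion.h:

  // FIXME: Replace the uses of is(), get() and dyn_cast() with
  //        isa<T>, cast<T> and the llvm::dyn_cast<T>

This patch uses dyn_cast_if_present<T> rather than dyn_cast<T>.
PointerUnion::dyn_cast forwards to dyn_cast_if_present and so tolerates
a null union, whereas llvm::dyn_cast<T> asserts on null. Using
dyn_cast_if_present keeps the change NFC.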


From cc7008c4b3a6bd2f8a20d2ea5cb7402263d10368 Mon Sep 17 00:00:00 2001
From: Kazu Hirata <kazu at google.com>
Date: Mon, 20 Jan 2025 10:23:29 -0800
Subject: [PATCH] [mlir] Migrate away from PointerUnion::dyn_cast (NFC)

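Note that PointerUnion::dyn_cast has been soft-deprecated in
PointerUnion.h:

  // FIXME: Replace the uses of is(), get() and dyn_cast() with
  //        isa<T>, cast<T> and the llvm::dyn_cast<T>

This patch uses dyn_cast_if_present<T> rather than dyn_cast<T>.
PointerUnion::dyn_cast forwards to dyn_cast_if_present and so tolerates
a null union, whereas llvm::dyn_cast<T> asserts on null. Using
dyn_cast_if_present keeps the change NFC.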
---
 mlir/examples/transform-opt/mlir-transform-opt.cpp          | 6 ++++--
 mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp    | 2 +-
 .../Transform/Transforms/TransformInterpreterUtils.cpp      | 2 +-
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp                    | 4 ++--
 .../Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp   | 5 +++--
 mlir/lib/IR/AffineMap.cpp                                   | 5 +++--
 mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp    | 2 +-
 mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp                 | 5 +++--
 mlir/unittests/IR/SymbolTableTest.cpp                       | 6 +++---
 9 files changed, 21 insertions(+), 16 deletions(-)

diff --git a/mlir/examples/transform-opt/mlir-transform-opt.cpp b/mlir/examples/transform-opt/mlir-transform-opt.cpp
index 10e16096211ad7..73cb0319bfd087 100644
--- a/mlir/examples/transform-opt/mlir-transform-opt.cpp
+++ b/mlir/examples/transform-opt/mlir-transform-opt.cpp
@@ -120,7 +120,8 @@ class DiagnosticHandlerWrapper {
   /// Verifies the captured "expected-*" diagnostics if required.
   llvm::LogicalResult verify() const {
     if (auto *ptr =
-            handler.dyn_cast<mlir::SourceMgrDiagnosticVerifierHandler *>()) {
+            dyn_cast_if_present<mlir::SourceMgrDiagnosticVerifierHandler *>(
+                handler)) {
       return ptr->verify();
     }
     return mlir::success();
@@ -128,7 +129,8 @@ class DiagnosticHandlerWrapper {
 
   /// Destructs the object of the same type as allocated.
   ~DiagnosticHandlerWrapper() {
-    if (auto *ptr = handler.dyn_cast<mlir::SourceMgrDiagnosticHandler *>()) {
+    if (auto *ptr =
+            dyn_cast_if_present<mlir::SourceMgrDiagnosticHandler *>(handler)) {
       delete ptr;
     } else {
       delete cast<mlir::SourceMgrDiagnosticVerifierHandler *>(handler);
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index d688d8e2ab6588..7bd6201d4608cf 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -119,7 +119,7 @@ static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
 /// an LLVM constant op.
 static Value getAsLLVMValue(OpBuilder &builder, Location loc,
                             OpFoldResult foldResult) {
-  if (auto attr = foldResult.dyn_cast<Attribute>()) {
+  if (auto attr = dyn_cast_if_present<Attribute>(foldResult)) {
     auto intAttr = cast<IntegerAttr>(attr);
     return builder.create<LLVM::ConstantOp>(loc, intAttr).getResult();
   }
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
index 232c9c96dd09fc..4868ab8e49178f 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp
@@ -210,7 +210,7 @@ LogicalResult transform::applyTransformNamedSequence(
            << "expected one payload to be bound to the first argument, got "
            << bindings.at(0).size();
   }
-  auto *payloadRoot = bindings.at(0).front().dyn_cast<Operation *>();
+  auto *payloadRoot = dyn_cast_if_present<Operation *>(bindings.at(0).front());
   if (!payloadRoot) {
     return transformRoot->emitError() << "expected the object bound to the "
                                          "first argument to be an operation";
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 696d1e0f9b1e68..c04ddd1922127c 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -340,7 +340,7 @@ SmallVector<Value> vector::getAsValues(OpBuilder &builder, Location loc,
   SmallVector<Value> values;
   llvm::transform(foldResults, std::back_inserter(values),
                   [&](OpFoldResult foldResult) {
-                    if (auto attr = foldResult.dyn_cast<Attribute>())
+                    if (auto attr = dyn_cast_if_present<Attribute>(foldResult))
                       return builder
                           .create<arith::ConstantIndexOp>(
                               loc, cast<IntegerAttr>(attr).getInt())
@@ -2880,7 +2880,7 @@ LogicalResult InsertOp::verify() {
     return emitOpError(
         "expected position attribute rank to match the dest vector rank");
   for (auto [idx, pos] : llvm::enumerate(position)) {
-    if (auto attr = pos.dyn_cast<Attribute>()) {
+    if (auto attr = dyn_cast_if_present<Attribute>(pos)) {
       int64_t constIdx = cast<IntegerAttr>(attr).getInt();
       if (constIdx < 0 || constIdx >= destVectorType.getDimSize(idx)) {
         return emitOpError("expected position attribute #")
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 95064083b21d44..2481b3e44e7a2e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -242,9 +242,10 @@ static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
                                          int64_t numElementsToExtract) {
   for (int i = 0; i < numElementsToExtract; ++i) {
     Value extractLoc =
-        (i == 0) ? offset.dyn_cast<Value>()
+        (i == 0) ? dyn_cast_if_present<Value>(offset)
                  : rewriter.create<arith::AddIOp>(
-                       loc, rewriter.getIndexType(), offset.dyn_cast<Value>(),
+                       loc, rewriter.getIndexType(),
+                       dyn_cast_if_present<Value>(offset),
                        rewriter.create<arith::ConstantIndexOp>(loc, i));
     auto extractOp =
         rewriter.create<vector::ExtractOp>(loc, source, extractLoc);
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index 8e8a433f331df5..f9cbaa9d26740b 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -748,7 +748,7 @@ AffineMap mlir::foldAttributesIntoMap(Builder &b, AffineMap map,
   SmallVector<AffineExpr> dimReplacements, symReplacements;
   int64_t numDims = 0;
   for (int64_t i = 0; i < map.getNumDims(); ++i) {
-    if (auto attr = operands[i].dyn_cast<Attribute>()) {
+    if (auto attr = dyn_cast_if_present<Attribute>(operands[i])) {
       dimReplacements.push_back(
           b.getAffineConstantExpr(cast<IntegerAttr>(attr).getInt()));
     } else {
@@ -758,7 +758,8 @@ AffineMap mlir::foldAttributesIntoMap(Builder &b, AffineMap map,
   }
   int64_t numSymbols = 0;
   for (int64_t i = 0; i < map.getNumSymbols(); ++i) {
-    if (auto attr = operands[i + map.getNumDims()].dyn_cast<Attribute>()) {
+    if (auto attr =
+            dyn_cast_if_present<Attribute>(operands[i + map.getNumDims()])) {
       symReplacements.push_back(
           b.getAffineConstantExpr(cast<IntegerAttr>(attr).getInt()));
     } else {
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 969c560c99ab7c..e7620d93697afc 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -515,7 +515,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
         bool materializationSucceeded = true;
         for (auto [ofr, resultType] :
              llvm::zip_equal(foldResults, op->getResultTypes())) {
-          if (auto value = ofr.dyn_cast<Value>()) {
+          if (auto value = dyn_cast_if_present<Value>(ofr)) {
             assert(value.getType() == resultType &&
                    "folder produced value of incorrect type");
             replacements.push_back(value);
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index a970cbc5cacebe..39735cd5646a14 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -1781,7 +1781,7 @@ void OpEmitter::genPropertiesSupportForBytecode(
       writePropertiesMethod << tgfmt(writeBytecodeSegmentSizeLegacy, &fmtCtxt);
     }
     if (const auto *namedProperty =
-            attrOrProp.dyn_cast<const NamedProperty *>()) {
+            dyn_cast_if_present<const NamedProperty *>(attrOrProp)) {
       StringRef name = namedProperty->name;
       readPropertiesMethod << formatv(
           R"(
@@ -1807,7 +1807,8 @@ void OpEmitter::genPropertiesSupportForBytecode(
           name, tgfmt(namedProperty->prop.getWriteToMlirBytecodeCall(), &fctx));
       continue;
     }
-    const auto *namedAttr = attrOrProp.dyn_cast<const AttributeMetadata *>();
+    const auto *namedAttr =
+        dyn_cast_if_present<const AttributeMetadata *>(attrOrProp);
     StringRef name = namedAttr->attrName;
     if (namedAttr->isRequired) {
       readPropertiesMethod << formatv(R"(
diff --git a/mlir/unittests/IR/SymbolTableTest.cpp b/mlir/unittests/IR/SymbolTableTest.cpp
index 5dcec749f0f425..192d2e273c2d01 100644
--- a/mlir/unittests/IR/SymbolTableTest.cpp
+++ b/mlir/unittests/IR/SymbolTableTest.cpp
@@ -49,9 +49,9 @@ class ReplaceAllSymbolUsesTest : public ::testing::Test {
     // Check that it got renamed.
     bool calleeFound = false;
     fooOp->walk([&](CallOpInterface callOp) {
-      StringAttr callee = callOp.getCallableForCallee()
-                              .dyn_cast<SymbolRefAttr>()
-                              .getLeafReference();
+      StringAttr callee =
+          dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee())
+              .getLeafReference();
       EXPECT_EQ(callee, "baz");
       calleeFound = true;
     });



More information about the Mlir-commits mailing list