[Mlir-commits] [mlir] 4f4e2ab - [mlir] Migrate away from PointerUnion::{is,get} (NFC) (#122591)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Jan 11 13:16:46 PST 2025


Author: Kazu Hirata
Date: 2025-01-11T13:16:43-08:00
New Revision: 4f4e2abb1a5ff1225d32410fd02b732d077aa056

URL: https://github.com/llvm/llvm-project/commit/4f4e2abb1a5ff1225d32410fd02b732d077aa056
DIFF: https://github.com/llvm/llvm-project/commit/4f4e2abb1a5ff1225d32410fd02b732d077aa056.diff

LOG: [mlir] Migrate away from PointerUnion::{is,get} (NFC) (#122591)

Note that PointerUnion::{is,get} have been soft deprecated in
PointerUnion.h:

  // FIXME: Replace the uses of is(), get() and dyn_cast() with
  //        isa<T>, cast<T> and the llvm::dyn_cast<T>

I'm not touching PointerUnion::dyn_cast for now because it's a bit
complicated; we could blindly migrate it to dyn_cast_if_present, but
we should probably use dyn_cast when the operand is known to be
non-null.

Added: 
    

Modified: 
    mlir/examples/transform-opt/mlir-transform-opt.cpp
    mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
    mlir/include/mlir/IR/Matchers.h
    mlir/include/mlir/IR/OpDefinition.h
    mlir/include/mlir/Interfaces/DataLayoutInterfaces.td
    mlir/include/mlir/Pass/AnalysisManager.h
    mlir/lib/Analysis/DataFlowFramework.cpp
    mlir/lib/AsmParser/Parser.cpp
    mlir/lib/Bytecode/Writer/IRNumbering.h
    mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
    mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp
    mlir/lib/Transforms/Utils/FoldUtils.cpp
    mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/examples/transform-opt/mlir-transform-opt.cpp b/mlir/examples/transform-opt/mlir-transform-opt.cpp
index 65615fb25bff6b..10e16096211ad7 100644
--- a/mlir/examples/transform-opt/mlir-transform-opt.cpp
+++ b/mlir/examples/transform-opt/mlir-transform-opt.cpp
@@ -131,7 +131,7 @@ class DiagnosticHandlerWrapper {
     if (auto *ptr = handler.dyn_cast<mlir::SourceMgrDiagnosticHandler *>()) {
       delete ptr;
     } else {
-      delete handler.get<mlir::SourceMgrDiagnosticVerifierHandler *>();
+      delete cast<mlir::SourceMgrDiagnosticVerifierHandler *>(handler);
     }
   }
 

diff  --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index 507087d5575e9e..387b9ee707179b 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -37,7 +37,7 @@ class AbstractSparseLattice : public AnalysisState {
   AbstractSparseLattice(Value value) : AnalysisState(value) {}
 
   /// Return the value this lattice is located at.
-  Value getAnchor() const { return AnalysisState::getAnchor().get<Value>(); }
+  Value getAnchor() const { return cast<Value>(AnalysisState::getAnchor()); }
 
   /// Join the information contained in 'rhs' into this lattice. Returns
   /// if the value of the lattice changed.

diff  --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index d218206e50f8f1..1dce055db1b4a7 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -92,7 +92,7 @@ struct constant_op_binder {
     (void)result;
     assert(succeeded(result) && "expected ConstantLike op to be foldable");
 
-    if (auto attr = llvm::dyn_cast<AttrT>(foldedOp.front().get<Attribute>())) {
+    if (auto attr = llvm::dyn_cast<AttrT>(cast<Attribute>(foldedOp.front()))) {
       if (bind_value)
         *bind_value = attr;
       return true;

diff  --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 59f094d6690991..d91c573c03efe5 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -272,8 +272,9 @@ class OpFoldResult : public PointerUnion<Attribute, Value> {
   void dump() const { llvm::errs() << *this << "\n"; }
 
   MLIRContext *getContext() const {
-    return is<Attribute>() ? get<Attribute>().getContext()
-                           : get<Value>().getContext();
+    PointerUnion pu = *this;
+    return isa<Attribute>(pu) ? cast<Attribute>(pu).getContext()
+                              : cast<Value>(pu).getContext();
   }
 };
 

diff  --git a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td
index 3532116700af51..0d09b92928fe33 100644
--- a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td
+++ b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.td
@@ -86,7 +86,7 @@ def DataLayoutEntryInterface : AttrInterface<"DataLayoutEntryInterface"> {
   let extraClassDeclaration = [{
     /// Returns `true` if the key of this entry is a type.
     bool isTypeEntry() {
-      return getKey().is<::mlir::Type>();
+      return llvm::isa<::mlir::Type>(getKey());
     }
   }];
 }

diff  --git a/mlir/include/mlir/Pass/AnalysisManager.h b/mlir/include/mlir/Pass/AnalysisManager.h
index f9db26140259ba..199ffee792bb54 100644
--- a/mlir/include/mlir/Pass/AnalysisManager.h
+++ b/mlir/include/mlir/Pass/AnalysisManager.h
@@ -262,7 +262,7 @@ struct NestedAnalysisMap {
   PassInstrumentor *getPassInstrumentor() const {
     if (auto *parent = getParent())
       return parent->getPassInstrumentor();
-    return parentOrInstrumentor.get<PassInstrumentor *>();
+    return cast<PassInstrumentor *>(parentOrInstrumentor);
   }
 
   /// The cached analyses for nested operations.

diff  --git a/mlir/lib/Analysis/DataFlowFramework.cpp b/mlir/lib/Analysis/DataFlowFramework.cpp
index 7e83668c067652..d2742c6e4b966d 100644
--- a/mlir/lib/Analysis/DataFlowFramework.cpp
+++ b/mlir/lib/Analysis/DataFlowFramework.cpp
@@ -84,7 +84,7 @@ void LatticeAnchor::print(raw_ostream &os) const {
     return value.print(os, OpPrintingFlags().skipRegions());
   }
 
-  return get<ProgramPoint *>()->print(os);
+  return llvm::cast<ProgramPoint *>(*this)->print(os);
 }
 
 Location LatticeAnchor::getLoc() const {
@@ -93,7 +93,7 @@ Location LatticeAnchor::getLoc() const {
   if (auto value = llvm::dyn_cast<Value>(*this))
     return value.getLoc();
 
-  ProgramPoint *pp = get<ProgramPoint *>();
+  ProgramPoint *pp = llvm::cast<ProgramPoint *>(*this);
   if (!pp->isBlockStart())
     return pp->getPrevOp()->getLoc();
   return pp->getBlock()->getParent()->getLoc();

diff  --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index e3db248164672c..eccb3241012a24 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -2162,7 +2162,7 @@ OperationParser::parseTrailingLocationSpecifier(OpOrArgument opOrArgument) {
   if (auto *op = llvm::dyn_cast_if_present<Operation *>(opOrArgument))
     op->setLoc(directLoc);
   else
-    opOrArgument.get<BlockArgument>().setLoc(directLoc);
+    cast<BlockArgument>(opOrArgument).setLoc(directLoc);
   return success();
 }
 

diff  --git a/mlir/lib/Bytecode/Writer/IRNumbering.h b/mlir/lib/Bytecode/Writer/IRNumbering.h
index eab75f50d2ee4f..9b7ac0d3688e37 100644
--- a/mlir/lib/Bytecode/Writer/IRNumbering.h
+++ b/mlir/lib/Bytecode/Writer/IRNumbering.h
@@ -50,11 +50,11 @@ struct AttrTypeNumbering {
 };
 struct AttributeNumbering : public AttrTypeNumbering {
   AttributeNumbering(Attribute value) : AttrTypeNumbering(value) {}
-  Attribute getValue() const { return value.get<Attribute>(); }
+  Attribute getValue() const { return cast<Attribute>(value); }
 };
 struct TypeNumbering : public AttrTypeNumbering {
   TypeNumbering(Type value) : AttrTypeNumbering(value) {}
-  Type getValue() const { return value.get<Type>(); }
+  Type getValue() const { return cast<Type>(value); }
 };
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
index 76b066feb69308..873ecb5b70e02d 100644
--- a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
+++ b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
@@ -142,7 +142,7 @@ struct PDLIndexSymbol {
       const ast::Name *declName = decl->getName();
       return declName ? declName->getLoc() : decl->getLoc();
     }
-    return definition.get<const ods::Operation *>()->getLoc();
+    return cast<const ods::Operation *>(definition)->getLoc();
   }
 
   /// The main definition of the symbol.
@@ -470,7 +470,7 @@ PDLDocument::findHover(const lsp::URIForFile &uri,
   if (const auto *op =
           llvm::dyn_cast_if_present<const ods::Operation *>(symbol->definition))
     return buildHoverForOpName(op, hoverRange);
-  const auto *decl = symbol->definition.get<const ast::Decl *>();
+  const auto *decl = cast<const ast::Decl *>(symbol->definition);
   return findHover(decl, hoverRange);
 }
 

diff  --git a/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp b/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp
index 5c59d94a061fa0..6c4ec06fffb32b 100644
--- a/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp
+++ b/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp
@@ -165,11 +165,11 @@ struct TableGenRecordSymbol : public TableGenIndexSymbol {
   ~TableGenRecordSymbol() override = default;
 
   static bool classof(const TableGenIndexSymbol *symbol) {
-    return symbol->definition.is<const Record *>();
+    return isa<const Record *>(symbol->definition);
   }
 
   /// Return the value of this symbol.
-  const Record *getValue() const { return definition.get<const Record *>(); }
+  const Record *getValue() const { return cast<const Record *>(definition); }
 };
 /// This class represents a single record value symbol.
 struct TableGenRecordValSymbol : public TableGenIndexSymbol {
@@ -178,12 +178,12 @@ struct TableGenRecordValSymbol : public TableGenIndexSymbol {
   ~TableGenRecordValSymbol() override = default;
 
   static bool classof(const TableGenIndexSymbol *symbol) {
-    return symbol->definition.is<const RecordVal *>();
+    return isa<const RecordVal *>(symbol->definition);
   }
 
   /// Return the value of this symbol.
   const RecordVal *getValue() const {
-    return definition.get<const RecordVal *>();
+    return cast<const RecordVal *>(definition);
   }
 
   /// The parent record of this symbol.

diff  --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
index c43f439525526b..e9adda0cd01db5 100644
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -260,7 +260,7 @@ OperationFolder::processFoldResults(Operation *op,
 
     // Check to see if there is a canonicalized version of this constant.
     auto res = op->getResult(i);
-    Attribute attrRepl = foldResults[i].get<Attribute>();
+    Attribute attrRepl = cast<Attribute>(foldResults[i]);
     if (auto *constOp =
             tryGetOrCreateConstant(uniquedConstants, dialect, attrRepl,
                                    res.getType(), erasedFoldedLocation)) {

diff  --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 99f3569b767b1c..969c560c99ab7c 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -523,7 +523,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
           }
           // Materialize Attributes as SSA values.
           Operation *constOp = op->getDialect()->materializeConstant(
-              rewriter, ofr.get<Attribute>(), resultType, op->getLoc());
+              rewriter, cast<Attribute>(ofr), resultType, op->getLoc());
 
           if (!constOp) {
             // If materialization fails, cleanup any operations generated for


        


More information about the Mlir-commits mailing list