[Mlir-commits] [mlir] b0746c6 - [mlir-tblgen] Migrate away from PointerUnion::{is, get} (NFC) (#119994)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Dec 15 09:18:17 PST 2024


Author: Kazu Hirata
Date: 2024-12-15T09:18:13-08:00
New Revision: b0746c68629c26567cd123d8f9b28e796ef26f47

URL: https://github.com/llvm/llvm-project/commit/b0746c68629c26567cd123d8f9b28e796ef26f47
DIFF: https://github.com/llvm/llvm-project/commit/b0746c68629c26567cd123d8f9b28e796ef26f47.diff

LOG: [mlir-tblgen] Migrate away from PointerUnion::{is,get} (NFC) (#119994)

Note that PointerUnion::{is,get} have been soft deprecated in
PointerUnion.h:

  // FIXME: Replace the uses of is(), get() and dyn_cast() with
  //        isa<T>, cast<T> and the llvm::dyn_cast<T>

I'm not touching PointerUnion::dyn_cast for now because it's a bit
complicated; we could blindly migrate it to dyn_cast_if_present, but
we should probably use dyn_cast when the operand is known to be
non-null.

Added: 
    

Modified: 
    mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
    mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
    mlir/tools/mlir-tblgen/RewriterGen.cpp
    mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 59ae7777da5675..30a2d4da573dc5 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -3162,7 +3162,7 @@ void OpEmitter::buildParamList(SmallVectorImpl<MethodParameter> &paramList,
                              isOptional);
       continue;
     }
-    const NamedAttribute &namedAttr = *arg.get<NamedAttribute *>();
+    const NamedAttribute &namedAttr = *cast<NamedAttribute *>(arg);
     const Attribute &attr = namedAttr.attr;
 
     // Inferred attributes don't need to be added to the param list.
@@ -3499,14 +3499,14 @@ void OpEmitter::genSideEffectInterfaceMethods() {
   /// Attributes and Operands.
   for (unsigned i = 0, operandIt = 0, e = op.getNumArgs(); i != e; ++i) {
     Argument arg = op.getArg(i);
-    if (arg.is<NamedTypeConstraint *>()) {
+    if (isa<NamedTypeConstraint *>(arg)) {
       resolveDecorators(op.getArgDecorators(i), operandIt, EffectKind::Operand);
       ++operandIt;
       continue;
     }
-    if (arg.is<NamedProperty *>())
+    if (isa<NamedProperty *>(arg))
       continue;
-    const NamedAttribute *attr = arg.get<NamedAttribute *>();
+    const NamedAttribute *attr = cast<NamedAttribute *>(arg);
     if (attr->attr.getBaseAttr().isSymbolRefAttr())
       resolveDecorators(op.getArgDecorators(i), i, EffectKind::Symbol);
   }
@@ -3547,7 +3547,7 @@ void OpEmitter::genSideEffectInterfaceMethods() {
                     .str();
       } else if (location.kind == EffectKind::Symbol) {
         // A symbol reference requires adding the proper attribute.
-        const auto *attr = op.getArg(location.index).get<NamedAttribute *>();
+        const auto *attr = cast<NamedAttribute *>(op.getArg(location.index));
         std::string argName = op.getGetterName(attr->name);
         if (attr->attr.isOptional()) {
           body << "  if (auto symbolRef = " << argName << "Attr())\n  "
@@ -3648,7 +3648,7 @@ void OpEmitter::genTypeInterfaceMethods() {
           // If this is an attribute, index into the attribute dictionary.
         } else {
           auto *attr =
-              op.getArg(arg.operandOrAttributeIndex()).get<NamedAttribute *>();
+              cast<NamedAttribute *>(op.getArg(arg.operandOrAttributeIndex()));
           body << "  ::mlir::TypedAttr odsInferredTypeAttr" << inferredTypeIdx
                << " = ";
           if (op.getDialect().usePropertiesForAttributes()) {

diff  --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 5019b69d91127e..300d9977f03994 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -628,7 +628,7 @@ static void populateBuilderArgs(const Operator &op,
       name = formatv("_gen_arg_{0}", i);
     name = sanitizeName(name);
     builderArgs.push_back(name);
-    if (!op.getArg(i).is<NamedAttribute *>())
+    if (!isa<NamedAttribute *>(op.getArg(i)))
       operandNames.push_back(name);
   }
 }

diff  --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index e2764f8493cd8a..a041c4d3277798 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -655,7 +655,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
     }
 
     // Next handle DAG leaf: operand or attribute
-    if (opArg.is<NamedTypeConstraint *>()) {
+    if (isa<NamedTypeConstraint *>(opArg)) {
       auto operandName =
           formatv("{0}.getODSOperands({1})", castedName, nextOperand);
       emitOperandMatch(tree, castedName, operandName.str(), opArgIdx,
@@ -663,7 +663,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
                        /*argName=*/tree.getArgName(i), opArgIdx,
                        /*variadicSubIndex=*/std::nullopt);
       ++nextOperand;
-    } else if (opArg.is<NamedAttribute *>()) {
+    } else if (isa<NamedAttribute *>(opArg)) {
       emitAttributeMatch(tree, opName, opArgIdx, depth);
     } else {
       PrintFatalError(loc, "unhandled case when matching op");
@@ -680,7 +680,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
                                       int argIndex,
                                       std::optional<int> variadicSubIndex) {
   Operator &op = tree.getDialectOp(opMap);
-  auto *operand = op.getArg(operandIndex).get<NamedTypeConstraint *>();
+  auto *operand = cast<NamedTypeConstraint *>(op.getArg(operandIndex));
 
   // If a constraint is specified, we need to generate C++ statements to
   // check the constraint.
@@ -770,7 +770,7 @@ void PatternEmitter::emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree,
       // need to queue the operation only if the matching success. Thus we emit
       // the code at the end.
       tblgenOps << formatv("tblgen_ops.push_back({0});\n", argName);
-    } else if (op.getArg(argIndex).is<NamedTypeConstraint *>()) {
+    } else if (isa<NamedTypeConstraint *>(op.getArg(argIndex))) {
       emitOperandMatch(tree, opName, /*operandName=*/formatv("v{0}", i).str(),
                        operandIndex,
                        /*operandMatcher=*/eitherArgTree.getArgAsLeaf(i),
@@ -851,7 +851,7 @@ void PatternEmitter::emitVariadicOperandMatch(DagNode tree,
       os << formatv("tblgen_ops.push_back({0});\n", argName);
 
       os.unindent() << "}\n";
-    } else if (op.getArg(argIndex).is<NamedTypeConstraint *>()) {
+    } else if (isa<NamedTypeConstraint *>(op.getArg(argIndex))) {
       auto operandName = formatv("variadic_operand_range.slice({0}, 1)", i);
       emitOperandMatch(tree, opName, operandName.str(), operandIndex,
                        /*operandMatcher=*/variadicArgTree.getArgAsLeaf(i),
@@ -867,7 +867,7 @@ void PatternEmitter::emitVariadicOperandMatch(DagNode tree,
 void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
                                         int argIndex, int depth) {
   Operator &op = tree.getDialectOp(opMap);
-  auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>();
+  auto *namedAttr = cast<NamedAttribute *>(op.getArg(argIndex));
   const auto &attr = namedAttr->attr;
 
   os << "{\n";
@@ -1775,7 +1775,7 @@ void PatternEmitter::supplyValuesForOpArgs(
       auto patArgName = node.getArgName(argIndex);
       if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
         // TODO: Refactor out into map to avoid recomputing these.
-        if (!opArg.is<NamedAttribute *>())
+        if (!isa<NamedAttribute *>(opArg))
           PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex));
         if (!patArgName.empty())
           os << "/*" << patArgName << "=*/";
@@ -1805,7 +1805,7 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
   bool hasOperandSegmentSizes = false;
   std::vector<std::string> sizes;
   for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
-    if (resultOp.getArg(argIndex).is<NamedAttribute *>()) {
+    if (isa<NamedAttribute *>(resultOp.getArg(argIndex))) {
       // The argument in the op definition.
       auto opArgName = resultOp.getArgName(argIndex);
       hasOperandSegmentSizes =
@@ -1826,7 +1826,7 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
     }
 
     const auto *operand =
-        resultOp.getArg(argIndex).get<NamedTypeConstraint *>();
+        cast<NamedTypeConstraint *>(resultOp.getArg(argIndex));
     std::string varName;
     if (operand->isVariadic()) {
       ++numVariadic;

diff  --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index 75286231f5902f..75b8829be4da59 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -33,7 +33,9 @@
 #include <optional>
 
 using llvm::ArrayRef;
+using llvm::cast;
 using llvm::formatv;
+using llvm::isa;
 using llvm::raw_ostream;
 using llvm::raw_string_ostream;
 using llvm::Record;
@@ -607,11 +609,11 @@ static void emitArgumentSerialization(const Operator &op, ArrayRef<SMLoc> loc,
   bool areOperandsAheadOfAttrs = true;
   // Find the first attribute.
   const Argument *it = llvm::find_if(op.getArgs(), [](const Argument &arg) {
-    return arg.is<NamedAttribute *>();
+    return isa<NamedAttribute *>(arg);
   });
   // Check whether all following arguments are attributes.
   for (const Argument *ie = op.arg_end(); it != ie; ++it) {
-    if (!it->is<NamedAttribute *>()) {
+    if (!isa<NamedAttribute *>(*it)) {
       areOperandsAheadOfAttrs = false;
       break;
     }
@@ -642,7 +644,7 @@ static void emitArgumentSerialization(const Operator &op, ArrayRef<SMLoc> loc,
   for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) {
     auto argument = op.getArg(i);
     os << tabs << "{\n";
-    if (argument.is<NamedTypeConstraint *>()) {
+    if (isa<NamedTypeConstraint *>(argument)) {
       os << tabs
          << formatv("  for (auto arg : {0}.getODSOperands({1})) {{\n", opVar,
                     operandNum);
@@ -657,7 +659,7 @@ static void emitArgumentSerialization(const Operator &op, ArrayRef<SMLoc> loc,
       os << "    }\n";
       operandNum++;
     } else {
-      NamedAttribute *attr = argument.get<NamedAttribute *>();
+      NamedAttribute *attr = cast<NamedAttribute *>(argument);
       auto newtabs = tabs.str() + "  ";
       emitAttributeSerialization(
           (attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr),
@@ -962,7 +964,7 @@ static void emitOperandDeserialization(const Operator &op, ArrayRef<SMLoc> loc,
       os << tabs << "}\n";
     } else {
       os << tabs << formatv("if ({0} < {1}.size()) {{\n", wordIndex, words);
-      auto *attr = argument.get<NamedAttribute *>();
+      auto *attr = cast<NamedAttribute *>(argument);
       auto newtabs = tabs.str() + "  ";
       emitAttributeDeserialization(
           (attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr),


        


More information about the Mlir-commits mailing list