[Mlir-commits] [mlir] fff39b6 - [mlir][Attribute] Remove usages of Attribute::getKind

River Riddle llvmlistbot at llvm.org
Fri Aug 7 13:43:52 PDT 2020


Author: River Riddle
Date: 2020-08-07T13:43:25-07:00
New Revision: fff39b62bb4078ce78813f25c04e0da435a8feb3

URL: https://github.com/llvm/llvm-project/commit/fff39b62bb4078ce78813f25c04e0da435a8feb3
DIFF: https://github.com/llvm/llvm-project/commit/fff39b62bb4078ce78813f25c04e0da435a8feb3.diff

LOG: [mlir][Attribute] Remove usages of Attribute::getKind

This is in preparation for removing the use of "kinds" within attributes and types in MLIR.

Differential Revision: https://reviews.llvm.org/D85370

Added: 
    

Modified: 
    mlir/include/mlir/IR/Attributes.h
    mlir/include/mlir/IR/Location.h
    mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
    mlir/lib/IR/AsmPrinter.cpp
    mlir/lib/IR/Attributes.cpp
    mlir/lib/IR/Diagnostics.cpp
    mlir/lib/IR/Location.cpp
    mlir/lib/Target/LLVMIR/DebugTranslation.cpp
    mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index a57adb315bc3..75ac2adc302c 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -661,10 +661,7 @@ class ElementsAttr : public Attribute {
                          function_ref<APInt(const APFloat &)> mapping) const;
 
   /// Method for support type inquiry through isa, cast and dyn_cast.
-  static bool classof(Attribute attr) {
-    return attr.getKind() >= StandardAttributes::FIRST_ELEMENTS_ATTR &&
-           attr.getKind() <= StandardAttributes::LAST_ELEMENTS_ATTR;
-  }
+  static bool classof(Attribute attr);
 
 protected:
   /// Returns the 1 dimensional flattened row-major index from the given
@@ -729,10 +726,7 @@ class DenseElementsAttr : public ElementsAttr {
   using ElementsAttr::ElementsAttr;
 
   /// Method for support type inquiry through isa, cast and dyn_cast.
-  static bool classof(Attribute attr) {
-    return attr.getKind() == StandardAttributes::DenseIntOrFPElements ||
-           attr.getKind() == StandardAttributes::DenseStringElements;
-  }
+  static bool classof(Attribute attr);
 
   /// Constructs a dense elements attribute from an array of element values.
   /// Each element attribute value is expected to be an element of 'type'.
@@ -1513,12 +1507,10 @@ class ElementsAttrIterator
   template <typename RetT, template <typename> class ProcessFn,
             typename... Args>
   RetT process(Args &... args) const {
-    switch (attrKind) {
-    case StandardAttributes::DenseIntOrFPElements:
+    if (attr.isa<DenseElementsAttr>())
       return ProcessFn<DenseIteratorT>()(args...);
-    case StandardAttributes::SparseElements:
+    if (attr.isa<SparseElementsAttr>())
       return ProcessFn<SparseIteratorT>()(args...);
-    }
     llvm_unreachable("unexpected attribute kind");
   }
 
@@ -1543,22 +1535,21 @@ class ElementsAttrIterator
   };
 
 public:
-  ElementsAttrIterator(const ElementsAttrIterator<T> &rhs)
-      : attrKind(rhs.attrKind) {
+  ElementsAttrIterator(const ElementsAttrIterator<T> &rhs) : attr(rhs.attr) {
     process<void, ConstructIter>(it, rhs.it);
   }
   ~ElementsAttrIterator() { process<void, DestructIter>(it); }
 
   /// Methods necessary to support random access iteration.
   ptrdiff_t operator-(const ElementsAttrIterator<T> &rhs) const {
-    assert(attrKind == rhs.attrKind && "incompatible iterators");
+    assert(attr == rhs.attr && "incompatible iterators");
     return process<ptrdiff_t, Minus>(it, rhs.it);
   }
   bool operator==(const ElementsAttrIterator<T> &rhs) const {
-    return rhs.attrKind == attrKind && process<bool, std::equal_to>(it, rhs.it);
+    return rhs.attr == attr && process<bool, std::equal_to>(it, rhs.it);
   }
   bool operator<(const ElementsAttrIterator<T> &rhs) const {
-    assert(attrKind == rhs.attrKind && "incompatible iterators");
+    assert(attr == rhs.attr && "incompatible iterators");
     return process<bool, std::less>(it, rhs.it);
   }
   ElementsAttrIterator<T> &operator+=(ptrdiff_t offset) {
@@ -1575,14 +1566,14 @@ class ElementsAttrIterator
 
 private:
   template <typename IteratorT>
-  ElementsAttrIterator(unsigned attrKind, IteratorT &&it)
-      : attrKind(attrKind), it(std::forward<IteratorT>(it)) {}
+  ElementsAttrIterator(Attribute attr, IteratorT &&it)
+      : attr(attr), it(std::forward<IteratorT>(it)) {}
 
   /// Allow accessing the constructor.
   friend ElementsAttr;
 
-  /// The kind of derived elements attribute.
-  unsigned attrKind;
+  /// The parent elements attribute.
+  Attribute attr;
 
   /// A union containing the specific iterators for each derived kind.
   Iterator it;
@@ -1599,13 +1590,13 @@ template <typename T>
 auto ElementsAttr::getValues() const -> iterator_range<T> {
   if (DenseElementsAttr denseAttr = dyn_cast<DenseElementsAttr>()) {
     auto values = denseAttr.getValues<T>();
-    return {iterator<T>(getKind(), values.begin()),
-            iterator<T>(getKind(), values.end())};
+    return {iterator<T>(*this, values.begin()),
+            iterator<T>(*this, values.end())};
   }
   if (SparseElementsAttr sparseAttr = dyn_cast<SparseElementsAttr>()) {
     auto values = sparseAttr.getValues<T>();
-    return {iterator<T>(getKind(), values.begin()),
-            iterator<T>(getKind(), values.end())};
+    return {iterator<T>(*this, values.begin()),
+            iterator<T>(*this, values.end())};
   }
   llvm_unreachable("unexpected attribute kind");
 }

diff --git a/mlir/include/mlir/IR/Location.h b/mlir/include/mlir/IR/Location.h
index 5bbf003a4271..aed0c4239fb4 100644
--- a/mlir/include/mlir/IR/Location.h
+++ b/mlir/include/mlir/IR/Location.h
@@ -42,10 +42,7 @@ class LocationAttr : public Attribute {
   using Attribute::Attribute;
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
-  static bool classof(Attribute attr) {
-    return attr.getKind() >= StandardAttributes::FIRST_LOCATION_ATTR &&
-           attr.getKind() <= StandardAttributes::LAST_LOCATION_ATTR;
-  }
+  static bool classof(Attribute attr);
 };
 
 /// This class defines the main interface for locations in MLIR and acts as a

diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index b7d36f4a9487..5e098e815d98 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -1472,18 +1472,15 @@ static LogicalResult verify(spirv::ConstantOp constOp) {
   // ODS already generates checks to make sure the result type is valid. We just
   // need to additionally check that the value's attribute type is consistent
   // with the result type.
-  switch (value.getKind()) {
-  case StandardAttributes::Integer:
-  case StandardAttributes::Float: {
+  if (value.isa<IntegerAttr, FloatAttr>()) {
     if (valueType != opType)
       return constOp.emitOpError("result type (")
              << opType << ") does not match value type (" << valueType << ")";
     return success();
-  } break;
-  case StandardAttributes::DenseIntOrFPElements:
-  case StandardAttributes::SparseElements: {
+  }
+  if (value.isa<DenseIntOrFPElementsAttr, SparseElementsAttr>()) {
     if (valueType == opType)
-      break;
+      return success();
     auto arrayType = opType.dyn_cast<spirv::ArrayType>();
     auto shapedType = valueType.dyn_cast<ShapedType>();
     if (!arrayType) {
@@ -1497,9 +1494,8 @@ static LogicalResult verify(spirv::ConstantOp constOp) {
       numElements *= t.getNumElements();
       opElemType = t.getElementType();
     }
-    if (!opElemType.isIntOrFloat()) {
+    if (!opElemType.isIntOrFloat())
       return constOp.emitOpError("only support nested array result type");
-    }
 
     auto valueElemType = shapedType.getElementType();
     if (valueElemType != opElemType) {
@@ -1513,26 +1509,24 @@ static LogicalResult verify(spirv::ConstantOp constOp) {
              << numElements << ") does not match value number of elements ("
              << shapedType.getNumElements() << ")";
     }
-  } break;
-  case StandardAttributes::Array: {
+    return success();
+  }
+  if (auto attayAttr = value.dyn_cast<ArrayAttr>()) {
     auto arrayType = opType.dyn_cast<spirv::ArrayType>();
     if (!arrayType)
       return constOp.emitOpError(
           "must have spv.array result type for array value");
-    auto elemType = arrayType.getElementType();
-    for (auto element : value.cast<ArrayAttr>().getValue()) {
+    Type elemType = arrayType.getElementType();
+    for (Attribute element : attayAttr.getValue()) {
       if (element.getType() != elemType)
         return constOp.emitOpError("has array element whose type (")
                << element.getType()
                << ") does not match the result element type (" << elemType
                << ')';
     }
-  } break;
-  default:
-    return constOp.emitOpError("cannot have value of type ") << valueType;
+    return success();
   }
-
-  return success();
+  return constOp.emitOpError("cannot have value of type ") << valueType;
 }
 
 bool spirv::ConstantOp::isBuildableWith(Type type) {
@@ -2619,19 +2613,14 @@ static LogicalResult verify(spirv::SpecConstantOp constOp) {
       return constOp.emitOpError("SpecId cannot be negative");
 
   auto value = constOp.default_value();
-
-  switch (value.getKind()) {
-  case StandardAttributes::Integer:
-  case StandardAttributes::Float: {
+  if (value.isa<IntegerAttr, FloatAttr>()) {
     // Make sure bitwidth is allowed.
     if (!value.getType().isa<spirv::SPIRVType>())
       return constOp.emitOpError("default value bitwidth disallowed");
     return success();
   }
-  default:
-    return constOp.emitOpError(
-        "default value can only be a bool, integer, or float scalar");
-  }
+  return constOp.emitOpError(
+      "default value can only be a bool, integer, or float scalar");
 }
 
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 372e4c93dc37..4470f5b4b826 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -33,6 +33,7 @@
 #include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringSet.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Regex.h"
 #include "llvm/Support/SaveAndRestore.h"
@@ -1019,76 +1020,67 @@ void ModulePrinter::printTrailingLocation(Location loc) {
 }
 
 void ModulePrinter::printLocationInternal(LocationAttr loc, bool pretty) {
-  switch (loc.getKind()) {
-  case StandardAttributes::OpaqueLocation:
-    printLocationInternal(loc.cast<OpaqueLoc>().getFallbackLocation(), pretty);
-    break;
-  case StandardAttributes::UnknownLocation:
-    if (pretty)
-      os << "[unknown]";
-    else
-      os << "unknown";
-    break;
-  case StandardAttributes::FileLineColLocation: {
-    auto fileLoc = loc.cast<FileLineColLoc>();
-    auto mayQuote = pretty ? "" : "\"";
-    os << mayQuote << fileLoc.getFilename() << mayQuote << ':'
-       << fileLoc.getLine() << ':' << fileLoc.getColumn();
-    break;
-  }
-  case StandardAttributes::NameLocation: {
-    auto nameLoc = loc.cast<NameLoc>();
-    os << '\"' << nameLoc.getName() << '\"';
-
-    // Print the child if it isn't unknown.
-    auto childLoc = nameLoc.getChildLoc();
-    if (!childLoc.isa<UnknownLoc>()) {
-      os << '(';
-      printLocationInternal(childLoc, pretty);
-      os << ')';
-    }
-    break;
-  }
-  case StandardAttributes::CallSiteLocation: {
-    auto callLocation = loc.cast<CallSiteLoc>();
-    auto caller = callLocation.getCaller();
-    auto callee = callLocation.getCallee();
-    if (!pretty)
-      os << "callsite(";
-    printLocationInternal(callee, pretty);
-    if (pretty) {
-      if (callee.isa<NameLoc>()) {
-        if (caller.isa<FileLineColLoc>()) {
-          os << " at ";
+  TypeSwitch<LocationAttr>(loc)
+      .Case<OpaqueLoc>([&](OpaqueLoc loc) {
+        printLocationInternal(loc.getFallbackLocation(), pretty);
+      })
+      .Case<UnknownLoc>([&](UnknownLoc loc) {
+        if (pretty)
+          os << "[unknown]";
+        else
+          os << "unknown";
+      })
+      .Case<FileLineColLoc>([&](FileLineColLoc loc) {
+        StringRef mayQuote = pretty ? "" : "\"";
+        os << mayQuote << loc.getFilename() << mayQuote << ':' << loc.getLine()
+           << ':' << loc.getColumn();
+      })
+      .Case<NameLoc>([&](NameLoc loc) {
+        os << '\"' << loc.getName() << '\"';
+
+        // Print the child if it isn't unknown.
+        auto childLoc = loc.getChildLoc();
+        if (!childLoc.isa<UnknownLoc>()) {
+          os << '(';
+          printLocationInternal(childLoc, pretty);
+          os << ')';
+        }
+      })
+      .Case<CallSiteLoc>([&](CallSiteLoc loc) {
+        Location caller = loc.getCaller();
+        Location callee = loc.getCallee();
+        if (!pretty)
+          os << "callsite(";
+        printLocationInternal(callee, pretty);
+        if (pretty) {
+          if (callee.isa<NameLoc>()) {
+            if (caller.isa<FileLineColLoc>()) {
+              os << " at ";
+            } else {
+              os << newLine << " at ";
+            }
+          } else {
+            os << newLine << " at ";
+          }
         } else {
-          os << newLine << " at ";
+          os << " at ";
         }
-      } else {
-        os << newLine << " at ";
-      }
-    } else {
-      os << " at ";
-    }
-    printLocationInternal(caller, pretty);
-    if (!pretty)
-      os << ")";
-    break;
-  }
-  case StandardAttributes::FusedLocation: {
-    auto fusedLoc = loc.cast<FusedLoc>();
-    if (!pretty)
-      os << "fused";
-    if (auto metadata = fusedLoc.getMetadata())
-      os << '<' << metadata << '>';
-    os << '[';
-    interleave(
-        fusedLoc.getLocations(),
-        [&](Location loc) { printLocationInternal(loc, pretty); },
-        [&]() { os << ", "; });
-    os << ']';
-    break;
-  }
-  }
+        printLocationInternal(caller, pretty);
+        if (!pretty)
+          os << ")";
+      })
+      .Case<FusedLoc>([&](FusedLoc loc) {
+        if (!pretty)
+          os << "fused";
+        if (Attribute metadata = loc.getMetadata())
+          os << '<' << metadata << '>';
+        os << '[';
+        interleave(
+            loc.getLocations(),
+            [&](Location loc) { printLocationInternal(loc, pretty); },
+            [&]() { os << ", "; });
+        os << ']';
+      });
 }
 
 /// Print a floating point value in a way that the parser will be able to
@@ -1305,27 +1297,19 @@ void ModulePrinter::printAttribute(Attribute attr,
   }
 
   auto attrType = attr.getType();
-  switch (attr.getKind()) {
-  default:
-    return printDialectAttribute(attr);
-
-  case StandardAttributes::Opaque: {
-    auto opaqueAttr = attr.cast<OpaqueAttr>();
+  if (auto opaqueAttr = attr.dyn_cast<OpaqueAttr>()) {
     printDialectSymbol(os, "#", opaqueAttr.getDialectNamespace(),
                        opaqueAttr.getAttrData());
-    break;
-  }
-  case StandardAttributes::Unit:
+  } else if (attr.isa<UnitAttr>()) {
     os << "unit";
-    break;
-  case StandardAttributes::Dictionary:
+    return;
+  } else if (auto dictAttr = attr.dyn_cast<DictionaryAttr>()) {
     os << '{';
-    interleaveComma(attr.cast<DictionaryAttr>().getValue(),
+    interleaveComma(dictAttr.getValue(),
                     [&](NamedAttribute attr) { printNamedAttribute(attr); });
     os << '}';
-    break;
-  case StandardAttributes::Integer: {
-    auto intAttr = attr.cast<IntegerAttr>();
+
+  } else if (auto intAttr = attr.dyn_cast<IntegerAttr>()) {
     if (attrType.isSignlessInteger(1)) {
       os << (intAttr.getValue().getBoolValue() ? "true" : "false");
 
@@ -1343,114 +1327,98 @@ void ModulePrinter::printAttribute(Attribute attr,
     // IntegerAttr elides the type if I64.
     if (typeElision == AttrTypeElision::May && attrType.isSignlessInteger(64))
       return;
-    break;
-  }
-  case StandardAttributes::Float: {
-    auto floatAttr = attr.cast<FloatAttr>();
+
+  } else if (auto floatAttr = attr.dyn_cast<FloatAttr>()) {
     printFloatValue(floatAttr.getValue(), os);
 
     // FloatAttr elides the type if F64.
     if (typeElision == AttrTypeElision::May && attrType.isF64())
       return;
-    break;
-  }
-  case StandardAttributes::String:
+
+  } else if (auto strAttr = attr.dyn_cast<StringAttr>()) {
     os << '"';
-    printEscapedString(attr.cast<StringAttr>().getValue(), os);
+    printEscapedString(strAttr.getValue(), os);
     os << '"';
-    break;
-  case StandardAttributes::Array:
+
+  } else if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
     os << '[';
-    interleaveComma(attr.cast<ArrayAttr>().getValue(), [&](Attribute attr) {
+    interleaveComma(arrayAttr.getValue(), [&](Attribute attr) {
       printAttribute(attr, AttrTypeElision::May);
     });
     os << ']';
-    break;
-  case StandardAttributes::AffineMap:
+
+  } else if (auto affineMapAttr = attr.dyn_cast<AffineMapAttr>()) {
     os << "affine_map<";
-    attr.cast<AffineMapAttr>().getValue().print(os);
+    affineMapAttr.getValue().print(os);
     os << '>';
 
     // AffineMap always elides the type.
     return;
-  case StandardAttributes::IntegerSet:
+
+  } else if (auto integerSetAttr = attr.dyn_cast<IntegerSetAttr>()) {
     os << "affine_set<";
-    attr.cast<IntegerSetAttr>().getValue().print(os);
+    integerSetAttr.getValue().print(os);
     os << '>';
 
     // IntegerSet always elides the type.
     return;
-  case StandardAttributes::Type:
-    printType(attr.cast<TypeAttr>().getValue());
-    break;
-  case StandardAttributes::SymbolRef: {
-    auto refAttr = attr.dyn_cast<SymbolRefAttr>();
+
+  } else if (auto typeAttr = attr.dyn_cast<TypeAttr>()) {
+    printType(typeAttr.getValue());
+
+  } else if (auto refAttr = attr.dyn_cast<SymbolRefAttr>()) {
     printSymbolReference(refAttr.getRootReference(), os);
     for (FlatSymbolRefAttr nestedRef : refAttr.getNestedReferences()) {
       os << "::";
       printSymbolReference(nestedRef.getValue(), os);
     }
-    break;
-  }
-  case StandardAttributes::OpaqueElements: {
-    auto eltsAttr = attr.cast<OpaqueElementsAttr>();
-    if (printerFlags.shouldElideElementsAttr(eltsAttr)) {
+
+  } else if (auto opaqueAttr = attr.dyn_cast<OpaqueElementsAttr>()) {
+    if (printerFlags.shouldElideElementsAttr(opaqueAttr)) {
       printElidedElementsAttr(os);
-      break;
+    } else {
+      os << "opaque<\"" << opaqueAttr.getDialect()->getNamespace() << "\", ";
+      os << '"' << "0x" << llvm::toHex(opaqueAttr.getValue()) << "\">";
     }
-    os << "opaque<\"" << eltsAttr.getDialect()->getNamespace() << "\", ";
-    os << '"' << "0x" << llvm::toHex(eltsAttr.getValue()) << "\">";
-    break;
-  }
-  case StandardAttributes::DenseIntOrFPElements: {
-    auto eltsAttr = attr.cast<DenseIntOrFPElementsAttr>();
-    if (printerFlags.shouldElideElementsAttr(eltsAttr)) {
+
+  } else if (auto intOrFpEltAttr = attr.dyn_cast<DenseIntOrFPElementsAttr>()) {
+    if (printerFlags.shouldElideElementsAttr(intOrFpEltAttr)) {
       printElidedElementsAttr(os);
-      break;
+    } else {
+      os << "dense<";
+      printDenseIntOrFPElementsAttr(intOrFpEltAttr, /*allowHex=*/true);
+      os << '>';
     }
-    os << "dense<";
-    printDenseIntOrFPElementsAttr(eltsAttr, /*allowHex=*/true);
-    os << '>';
-    break;
-  }
-  case StandardAttributes::DenseStringElements: {
-    auto eltsAttr = attr.cast<DenseStringElementsAttr>();
-    if (printerFlags.shouldElideElementsAttr(eltsAttr)) {
+
+  } else if (auto strEltAttr = attr.dyn_cast<DenseStringElementsAttr>()) {
+    if (printerFlags.shouldElideElementsAttr(strEltAttr)) {
       printElidedElementsAttr(os);
-      break;
+    } else {
+      os << "dense<";
+      printDenseStringElementsAttr(strEltAttr);
+      os << '>';
     }
-    os << "dense<";
-    printDenseStringElementsAttr(eltsAttr);
-    os << '>';
-    break;
-  }
-  case StandardAttributes::SparseElements: {
-    auto elementsAttr = attr.cast<SparseElementsAttr>();
-    if (printerFlags.shouldElideElementsAttr(elementsAttr.getIndices()) ||
-        printerFlags.shouldElideElementsAttr(elementsAttr.getValues())) {
+
+  } else if (auto sparseEltAttr = attr.dyn_cast<SparseElementsAttr>()) {
+    if (printerFlags.shouldElideElementsAttr(sparseEltAttr.getIndices()) ||
+        printerFlags.shouldElideElementsAttr(sparseEltAttr.getValues())) {
       printElidedElementsAttr(os);
-      break;
-    }
-    os << "sparse<";
-    DenseIntElementsAttr indices = elementsAttr.getIndices();
-    if (indices.getNumElements() != 0) {
-      printDenseIntOrFPElementsAttr(indices, /*allowHex=*/false);
-      os << ", ";
-      printDenseElementsAttr(elementsAttr.getValues(), /*allowHex=*/true);
+    } else {
+      os << "sparse<";
+      DenseIntElementsAttr indices = sparseEltAttr.getIndices();
+      if (indices.getNumElements() != 0) {
+        printDenseIntOrFPElementsAttr(indices, /*allowHex=*/false);
+        os << ", ";
+        printDenseElementsAttr(sparseEltAttr.getValues(), /*allowHex=*/true);
+      }
+      os << '>';
     }
-    os << '>';
-    break;
-  }
 
-  // Location attributes.
-  case StandardAttributes::CallSiteLocation:
-  case StandardAttributes::FileLineColLocation:
-  case StandardAttributes::FusedLocation:
-  case StandardAttributes::NameLocation:
-  case StandardAttributes::OpaqueLocation:
-  case StandardAttributes::UnknownLocation:
-    printLocation(attr.cast<LocationAttr>());
-    break;
+  } else if (auto locAttr = attr.dyn_cast<LocationAttr>()) {
+    printLocation(locAttr);
+
+  } else {
+    return printDialectAttribute(attr);
   }
 
   // Don't print the type if we must elide it, or if it is a None type.

diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp
index 8202465af0c7..e353b0b90a2a 100644
--- a/mlir/lib/IR/Attributes.cpp
+++ b/mlir/lib/IR/Attributes.cpp
@@ -460,16 +460,11 @@ int64_t ElementsAttr::getNumElements() const {
 /// Return the value at the given index. If index does not refer to a valid
 /// element, then a null attribute is returned.
 Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const {
-  switch (getKind()) {
-  case StandardAttributes::DenseIntOrFPElements:
-    return cast<DenseElementsAttr>().getValue(index);
-  case StandardAttributes::OpaqueElements:
-    return cast<OpaqueElementsAttr>().getValue(index);
-  case StandardAttributes::SparseElements:
-    return cast<SparseElementsAttr>().getValue(index);
-  default:
-    llvm_unreachable("unknown ElementsAttr kind");
-  }
+  if (auto denseAttr = dyn_cast<DenseElementsAttr>())
+    return denseAttr.getValue(index);
+  if (auto opaqueAttr = dyn_cast<OpaqueElementsAttr>())
+    return opaqueAttr.getValue(index);
+  return cast<SparseElementsAttr>().getValue(index);
 }
 
 /// Return if the given 'index' refers to a valid element in this attribute.
@@ -491,23 +486,23 @@ bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const {
 ElementsAttr
 ElementsAttr::mapValues(Type newElementType,
                         function_ref<APInt(const APInt &)> mapping) const {
-  switch (getKind()) {
-  case StandardAttributes::DenseIntOrFPElements:
-    return cast<DenseElementsAttr>().mapValues(newElementType, mapping);
-  default:
-    llvm_unreachable("unsupported ElementsAttr subtype");
-  }
+  if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>())
+    return intOrFpAttr.mapValues(newElementType, mapping);
+  llvm_unreachable("unsupported ElementsAttr subtype");
 }
 
 ElementsAttr
 ElementsAttr::mapValues(Type newElementType,
                         function_ref<APInt(const APFloat &)> mapping) const {
-  switch (getKind()) {
-  case StandardAttributes::DenseIntOrFPElements:
-    return cast<DenseElementsAttr>().mapValues(newElementType, mapping);
-  default:
-    llvm_unreachable("unsupported ElementsAttr subtype");
-  }
+  if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>())
+    return intOrFpAttr.mapValues(newElementType, mapping);
+  llvm_unreachable("unsupported ElementsAttr subtype");
+}
+
+/// Method for support type inquiry through isa, cast and dyn_cast.
+bool ElementsAttr::classof(Attribute attr) {
+  return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr,
+                  OpaqueElementsAttr, SparseElementsAttr>();
 }
 
 /// Returns the 1 dimensional flattened row-major index from the given
@@ -718,6 +713,11 @@ DenseElementsAttr::ComplexFloatElementIterator::ComplexFloatElementIterator(
 // DenseElementsAttr
 //===----------------------------------------------------------------------===//
 
+/// Method for support type inquiry through isa, cast and dyn_cast.
+bool DenseElementsAttr::classof(Attribute attr) {
+  return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>();
+}
+
 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
                                          ArrayRef<Attribute> values) {
   assert(hasSameElementsOrSplat(type, values));

diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp
index 34908c61639f..15dbac1f6cb8 100644
--- a/mlir/lib/IR/Diagnostics.cpp
+++ b/mlir/lib/IR/Diagnostics.cpp
@@ -366,43 +366,38 @@ struct SourceMgrDiagnosticHandlerImpl {
 
 /// Return a processable FileLineColLoc from the given location.
 static Optional<FileLineColLoc> getFileLineColLoc(Location loc) {
-  switch (loc->getKind()) {
-  case StandardAttributes::NameLocation:
+  if (auto nameLoc = loc.dyn_cast<NameLoc>())
     return getFileLineColLoc(loc.cast<NameLoc>().getChildLoc());
-  case StandardAttributes::FileLineColLocation:
-    return loc.cast<FileLineColLoc>();
-  case StandardAttributes::CallSiteLocation:
-    // Process the callee of a callsite location.
+  if (auto fileLoc = loc.dyn_cast<FileLineColLoc>())
+    return fileLoc;
+  if (auto callLoc = loc.dyn_cast<CallSiteLoc>())
     return getFileLineColLoc(loc.cast<CallSiteLoc>().getCallee());
-  case StandardAttributes::FusedLocation:
+  if (auto fusedLoc = loc.dyn_cast<FusedLoc>()) {
     for (auto subLoc : loc.cast<FusedLoc>().getLocations()) {
       if (auto callLoc = getFileLineColLoc(subLoc)) {
         return callLoc;
       }
     }
     return llvm::None;
-  default:
-    return llvm::None;
   }
+  return llvm::None;
 }
 
 /// Return a processable CallSiteLoc from the given location.
 static Optional<CallSiteLoc> getCallSiteLoc(Location loc) {
-  switch (loc->getKind()) {
-  case StandardAttributes::NameLocation:
+  if (auto nameLoc = loc.dyn_cast<NameLoc>())
     return getCallSiteLoc(loc.cast<NameLoc>().getChildLoc());
-  case StandardAttributes::CallSiteLocation:
-    return loc.cast<CallSiteLoc>();
-  case StandardAttributes::FusedLocation:
+  if (auto callLoc = loc.dyn_cast<CallSiteLoc>())
+    return callLoc;
+  if (auto fusedLoc = loc.dyn_cast<FusedLoc>()) {
     for (auto subLoc : loc.cast<FusedLoc>().getLocations()) {
       if (auto callLoc = getCallSiteLoc(subLoc)) {
         return callLoc;
       }
     }
     return llvm::None;
-  default:
-    return llvm::None;
   }
+  return llvm::None;
 }
 
 /// Given a diagnostic kind, returns the LLVM DiagKind.

diff --git a/mlir/lib/IR/Location.cpp b/mlir/lib/IR/Location.cpp
index f22fd5cb7852..48b05ba0eb40 100644
--- a/mlir/lib/IR/Location.cpp
+++ b/mlir/lib/IR/Location.cpp
@@ -13,6 +13,16 @@
 using namespace mlir;
 using namespace mlir::detail;
 
+//===----------------------------------------------------------------------===//
+// LocationAttr
+//===----------------------------------------------------------------------===//
+
+/// Methods for support type inquiry through isa, cast, and dyn_cast.
+bool LocationAttr::classof(Attribute attr) {
+  return attr.isa<CallSiteLoc, FileLineColLoc, FusedLoc, NameLoc, OpaqueLoc,
+                  UnknownLoc>();
+}
+
 //===----------------------------------------------------------------------===//
 // CallSiteLoc
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/Target/LLVMIR/DebugTranslation.cpp b/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
index f40f44f2fbc6..af364ba9048f 100644
--- a/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
@@ -115,26 +115,19 @@ DebugTranslation::translateLoc(Location loc, llvm::DILocalScope *scope,
     return existingIt->second;
 
   const llvm::DILocation *llvmLoc = nullptr;
-  switch (loc->getKind()) {
-  case StandardAttributes::CallSiteLocation: {
-    auto callLoc = loc.dyn_cast<CallSiteLoc>();
-
+  if (auto callLoc = loc.dyn_cast<CallSiteLoc>()) {
     // For callsites, the caller is fed as the inlinedAt for the callee.
     const auto *callerLoc = translateLoc(callLoc.getCaller(), scope, inlinedAt);
     llvmLoc = translateLoc(callLoc.getCallee(), scope, callerLoc);
-    break;
-  }
-  case StandardAttributes::FileLineColLocation: {
-    auto fileLoc = loc.dyn_cast<FileLineColLoc>();
+
+  } else if (auto fileLoc = loc.dyn_cast<FileLineColLoc>()) {
     auto *file = translateFile(fileLoc.getFilename());
     auto *fileScope = builder.createLexicalBlockFile(scope, file);
     llvmLoc = llvm::DILocation::get(llvmCtx, fileLoc.getLine(),
                                     fileLoc.getColumn(), fileScope,
                                     const_cast<llvm::DILocation *>(inlinedAt));
-    break;
-  }
-  case StandardAttributes::FusedLocation: {
-    auto fusedLoc = loc.dyn_cast<FusedLoc>();
+
+  } else if (auto fusedLoc = loc.dyn_cast<FusedLoc>()) {
     ArrayRef<Location> locations = fusedLoc.getLocations();
 
     // For fused locations, merge each of the nodes.
@@ -143,18 +136,17 @@ DebugTranslation::translateLoc(Location loc, llvm::DILocalScope *scope,
       llvmLoc = llvm::DILocation::getMergedLocation(
           llvmLoc, translateLoc(locIt, scope, inlinedAt));
     }
-    break;
-  }
-  case StandardAttributes::NameLocation:
+
+  } else if (auto nameLoc = loc.dyn_cast<NameLoc>()) {
     llvmLoc = translateLoc(loc.cast<NameLoc>().getChildLoc(), scope, inlinedAt);
-    break;
-  case StandardAttributes::OpaqueLocation:
+
+  } else if (auto opaqueLoc = loc.dyn_cast<OpaqueLoc>()) {
     llvmLoc = translateLoc(loc.cast<OpaqueLoc>().getFallbackLocation(), scope,
                            inlinedAt);
-    break;
-  default:
+  } else {
     llvm_unreachable("unknown location kind");
   }
+
   locationToLoc.try_emplace(std::make_pair(loc, scope), llvmLoc);
   return llvmLoc;
 }

diff --git a/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp b/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp
index 68e43d45ab42..97c94a54ffc4 100644
--- a/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp
+++ b/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp
@@ -158,7 +158,7 @@ TEST(QuantizationUtilsTest, convertRankedSparseAttrUniform) {
   auto expectedTensorType = realValue.getType().cast<TensorType>();
   EXPECT_EQ(tensorType.getShape(), expectedTensorType.getShape());
   EXPECT_EQ(tensorType.getElementType(), convertedType);
-  EXPECT_EQ(returnedValue.getKind(), StandardAttributes::SparseElements);
+  EXPECT_TRUE(returnedValue.isa<SparseElementsAttr>());
 
   // Check Elements attribute element value is expected.
   auto firstValue = returnedValue.cast<ElementsAttr>().getValue({0, 0});


        


More information about the Mlir-commits mailing list