[Mlir-commits] [mlir] fff39b6 - [mlir][Attribute] Remove usages of Attribute::getKind
River Riddle
llvmlistbot at llvm.org
Fri Aug 7 13:43:52 PDT 2020
Author: River Riddle
Date: 2020-08-07T13:43:25-07:00
New Revision: fff39b62bb4078ce78813f25c04e0da435a8feb3
URL: https://github.com/llvm/llvm-project/commit/fff39b62bb4078ce78813f25c04e0da435a8feb3
DIFF: https://github.com/llvm/llvm-project/commit/fff39b62bb4078ce78813f25c04e0da435a8feb3.diff
LOG: [mlir][Attribute] Remove usages of Attribute::getKind
This is in preparation for removing the use of "kinds" within attributes and types in MLIR.
Differential Revision: https://reviews.llvm.org/D85370
Added:
Modified:
mlir/include/mlir/IR/Attributes.h
mlir/include/mlir/IR/Location.h
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/Attributes.cpp
mlir/lib/IR/Diagnostics.cpp
mlir/lib/IR/Location.cpp
mlir/lib/Target/LLVMIR/DebugTranslation.cpp
mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index a57adb315bc3..75ac2adc302c 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -661,10 +661,7 @@ class ElementsAttr : public Attribute {
function_ref<APInt(const APFloat &)> mapping) const;
/// Method for support type inquiry through isa, cast and dyn_cast.
- static bool classof(Attribute attr) {
- return attr.getKind() >= StandardAttributes::FIRST_ELEMENTS_ATTR &&
- attr.getKind() <= StandardAttributes::LAST_ELEMENTS_ATTR;
- }
+ static bool classof(Attribute attr);
protected:
/// Returns the 1 dimensional flattened row-major index from the given
@@ -729,10 +726,7 @@ class DenseElementsAttr : public ElementsAttr {
using ElementsAttr::ElementsAttr;
/// Method for support type inquiry through isa, cast and dyn_cast.
- static bool classof(Attribute attr) {
- return attr.getKind() == StandardAttributes::DenseIntOrFPElements ||
- attr.getKind() == StandardAttributes::DenseStringElements;
- }
+ static bool classof(Attribute attr);
/// Constructs a dense elements attribute from an array of element values.
/// Each element attribute value is expected to be an element of 'type'.
@@ -1513,12 +1507,10 @@ class ElementsAttrIterator
template <typename RetT, template <typename> class ProcessFn,
typename... Args>
RetT process(Args &... args) const {
- switch (attrKind) {
- case StandardAttributes::DenseIntOrFPElements:
+ if (attr.isa<DenseElementsAttr>())
return ProcessFn<DenseIteratorT>()(args...);
- case StandardAttributes::SparseElements:
+ if (attr.isa<SparseElementsAttr>())
return ProcessFn<SparseIteratorT>()(args...);
- }
llvm_unreachable("unexpected attribute kind");
}
@@ -1543,22 +1535,21 @@ class ElementsAttrIterator
};
public:
- ElementsAttrIterator(const ElementsAttrIterator<T> &rhs)
- : attrKind(rhs.attrKind) {
+ ElementsAttrIterator(const ElementsAttrIterator<T> &rhs) : attr(rhs.attr) {
process<void, ConstructIter>(it, rhs.it);
}
~ElementsAttrIterator() { process<void, DestructIter>(it); }
/// Methods necessary to support random access iteration.
ptrdiff_t operator-(const ElementsAttrIterator<T> &rhs) const {
- assert(attrKind == rhs.attrKind && "incompatible iterators");
+ assert(attr == rhs.attr && "incompatible iterators");
return process<ptrdiff_t, Minus>(it, rhs.it);
}
bool operator==(const ElementsAttrIterator<T> &rhs) const {
- return rhs.attrKind == attrKind && process<bool, std::equal_to>(it, rhs.it);
+ return rhs.attr == attr && process<bool, std::equal_to>(it, rhs.it);
}
bool operator<(const ElementsAttrIterator<T> &rhs) const {
- assert(attrKind == rhs.attrKind && "incompatible iterators");
+ assert(attr == rhs.attr && "incompatible iterators");
return process<bool, std::less>(it, rhs.it);
}
ElementsAttrIterator<T> &operator+=(ptrdiff_t offset) {
@@ -1575,14 +1566,14 @@ class ElementsAttrIterator
private:
template <typename IteratorT>
- ElementsAttrIterator(unsigned attrKind, IteratorT &&it)
- : attrKind(attrKind), it(std::forward<IteratorT>(it)) {}
+ ElementsAttrIterator(Attribute attr, IteratorT &&it)
+ : attr(attr), it(std::forward<IteratorT>(it)) {}
/// Allow accessing the constructor.
friend ElementsAttr;
- /// The kind of derived elements attribute.
- unsigned attrKind;
+ /// The parent elements attribute.
+ Attribute attr;
/// A union containing the specific iterators for each derived kind.
Iterator it;
@@ -1599,13 +1590,13 @@ template <typename T>
auto ElementsAttr::getValues() const -> iterator_range<T> {
if (DenseElementsAttr denseAttr = dyn_cast<DenseElementsAttr>()) {
auto values = denseAttr.getValues<T>();
- return {iterator<T>(getKind(), values.begin()),
- iterator<T>(getKind(), values.end())};
+ return {iterator<T>(*this, values.begin()),
+ iterator<T>(*this, values.end())};
}
if (SparseElementsAttr sparseAttr = dyn_cast<SparseElementsAttr>()) {
auto values = sparseAttr.getValues<T>();
- return {iterator<T>(getKind(), values.begin()),
- iterator<T>(getKind(), values.end())};
+ return {iterator<T>(*this, values.begin()),
+ iterator<T>(*this, values.end())};
}
llvm_unreachable("unexpected attribute kind");
}
diff --git a/mlir/include/mlir/IR/Location.h b/mlir/include/mlir/IR/Location.h
index 5bbf003a4271..aed0c4239fb4 100644
--- a/mlir/include/mlir/IR/Location.h
+++ b/mlir/include/mlir/IR/Location.h
@@ -42,10 +42,7 @@ class LocationAttr : public Attribute {
using Attribute::Attribute;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
- static bool classof(Attribute attr) {
- return attr.getKind() >= StandardAttributes::FIRST_LOCATION_ATTR &&
- attr.getKind() <= StandardAttributes::LAST_LOCATION_ATTR;
- }
+ static bool classof(Attribute attr);
};
/// This class defines the main interface for locations in MLIR and acts as a
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index b7d36f4a9487..5e098e815d98 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -1472,18 +1472,15 @@ static LogicalResult verify(spirv::ConstantOp constOp) {
// ODS already generates checks to make sure the result type is valid. We just
// need to additionally check that the value's attribute type is consistent
// with the result type.
- switch (value.getKind()) {
- case StandardAttributes::Integer:
- case StandardAttributes::Float: {
+ if (value.isa<IntegerAttr, FloatAttr>()) {
if (valueType != opType)
return constOp.emitOpError("result type (")
<< opType << ") does not match value type (" << valueType << ")";
return success();
- } break;
- case StandardAttributes::DenseIntOrFPElements:
- case StandardAttributes::SparseElements: {
+ }
+ if (value.isa<DenseIntOrFPElementsAttr, SparseElementsAttr>()) {
if (valueType == opType)
- break;
+ return success();
auto arrayType = opType.dyn_cast<spirv::ArrayType>();
auto shapedType = valueType.dyn_cast<ShapedType>();
if (!arrayType) {
@@ -1497,9 +1494,8 @@ static LogicalResult verify(spirv::ConstantOp constOp) {
numElements *= t.getNumElements();
opElemType = t.getElementType();
}
- if (!opElemType.isIntOrFloat()) {
+ if (!opElemType.isIntOrFloat())
return constOp.emitOpError("only support nested array result type");
- }
auto valueElemType = shapedType.getElementType();
if (valueElemType != opElemType) {
@@ -1513,26 +1509,24 @@ static LogicalResult verify(spirv::ConstantOp constOp) {
<< numElements << ") does not match value number of elements ("
<< shapedType.getNumElements() << ")";
}
- } break;
- case StandardAttributes::Array: {
+ return success();
+ }
+ if (auto attayAttr = value.dyn_cast<ArrayAttr>()) {
auto arrayType = opType.dyn_cast<spirv::ArrayType>();
if (!arrayType)
return constOp.emitOpError(
"must have spv.array result type for array value");
- auto elemType = arrayType.getElementType();
- for (auto element : value.cast<ArrayAttr>().getValue()) {
+ Type elemType = arrayType.getElementType();
+ for (Attribute element : attayAttr.getValue()) {
if (element.getType() != elemType)
return constOp.emitOpError("has array element whose type (")
<< element.getType()
<< ") does not match the result element type (" << elemType
<< ')';
}
- } break;
- default:
- return constOp.emitOpError("cannot have value of type ") << valueType;
+ return success();
}
-
- return success();
+ return constOp.emitOpError("cannot have value of type ") << valueType;
}
bool spirv::ConstantOp::isBuildableWith(Type type) {
@@ -2619,19 +2613,14 @@ static LogicalResult verify(spirv::SpecConstantOp constOp) {
return constOp.emitOpError("SpecId cannot be negative");
auto value = constOp.default_value();
-
- switch (value.getKind()) {
- case StandardAttributes::Integer:
- case StandardAttributes::Float: {
+ if (value.isa<IntegerAttr, FloatAttr>()) {
// Make sure bitwidth is allowed.
if (!value.getType().isa<spirv::SPIRVType>())
return constOp.emitOpError("default value bitwidth disallowed");
return success();
}
- default:
- return constOp.emitOpError(
- "default value can only be a bool, integer, or float scalar");
- }
+ return constOp.emitOpError(
+ "default value can only be a bool, integer, or float scalar");
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 372e4c93dc37..4470f5b4b826 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -33,6 +33,7 @@
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
+#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Regex.h"
#include "llvm/Support/SaveAndRestore.h"
@@ -1019,76 +1020,67 @@ void ModulePrinter::printTrailingLocation(Location loc) {
}
void ModulePrinter::printLocationInternal(LocationAttr loc, bool pretty) {
- switch (loc.getKind()) {
- case StandardAttributes::OpaqueLocation:
- printLocationInternal(loc.cast<OpaqueLoc>().getFallbackLocation(), pretty);
- break;
- case StandardAttributes::UnknownLocation:
- if (pretty)
- os << "[unknown]";
- else
- os << "unknown";
- break;
- case StandardAttributes::FileLineColLocation: {
- auto fileLoc = loc.cast<FileLineColLoc>();
- auto mayQuote = pretty ? "" : "\"";
- os << mayQuote << fileLoc.getFilename() << mayQuote << ':'
- << fileLoc.getLine() << ':' << fileLoc.getColumn();
- break;
- }
- case StandardAttributes::NameLocation: {
- auto nameLoc = loc.cast<NameLoc>();
- os << '\"' << nameLoc.getName() << '\"';
-
- // Print the child if it isn't unknown.
- auto childLoc = nameLoc.getChildLoc();
- if (!childLoc.isa<UnknownLoc>()) {
- os << '(';
- printLocationInternal(childLoc, pretty);
- os << ')';
- }
- break;
- }
- case StandardAttributes::CallSiteLocation: {
- auto callLocation = loc.cast<CallSiteLoc>();
- auto caller = callLocation.getCaller();
- auto callee = callLocation.getCallee();
- if (!pretty)
- os << "callsite(";
- printLocationInternal(callee, pretty);
- if (pretty) {
- if (callee.isa<NameLoc>()) {
- if (caller.isa<FileLineColLoc>()) {
- os << " at ";
+ TypeSwitch<LocationAttr>(loc)
+ .Case<OpaqueLoc>([&](OpaqueLoc loc) {
+ printLocationInternal(loc.getFallbackLocation(), pretty);
+ })
+ .Case<UnknownLoc>([&](UnknownLoc loc) {
+ if (pretty)
+ os << "[unknown]";
+ else
+ os << "unknown";
+ })
+ .Case<FileLineColLoc>([&](FileLineColLoc loc) {
+ StringRef mayQuote = pretty ? "" : "\"";
+ os << mayQuote << loc.getFilename() << mayQuote << ':' << loc.getLine()
+ << ':' << loc.getColumn();
+ })
+ .Case<NameLoc>([&](NameLoc loc) {
+ os << '\"' << loc.getName() << '\"';
+
+ // Print the child if it isn't unknown.
+ auto childLoc = loc.getChildLoc();
+ if (!childLoc.isa<UnknownLoc>()) {
+ os << '(';
+ printLocationInternal(childLoc, pretty);
+ os << ')';
+ }
+ })
+ .Case<CallSiteLoc>([&](CallSiteLoc loc) {
+ Location caller = loc.getCaller();
+ Location callee = loc.getCallee();
+ if (!pretty)
+ os << "callsite(";
+ printLocationInternal(callee, pretty);
+ if (pretty) {
+ if (callee.isa<NameLoc>()) {
+ if (caller.isa<FileLineColLoc>()) {
+ os << " at ";
+ } else {
+ os << newLine << " at ";
+ }
+ } else {
+ os << newLine << " at ";
+ }
} else {
- os << newLine << " at ";
+ os << " at ";
}
- } else {
- os << newLine << " at ";
- }
- } else {
- os << " at ";
- }
- printLocationInternal(caller, pretty);
- if (!pretty)
- os << ")";
- break;
- }
- case StandardAttributes::FusedLocation: {
- auto fusedLoc = loc.cast<FusedLoc>();
- if (!pretty)
- os << "fused";
- if (auto metadata = fusedLoc.getMetadata())
- os << '<' << metadata << '>';
- os << '[';
- interleave(
- fusedLoc.getLocations(),
- [&](Location loc) { printLocationInternal(loc, pretty); },
- [&]() { os << ", "; });
- os << ']';
- break;
- }
- }
+ printLocationInternal(caller, pretty);
+ if (!pretty)
+ os << ")";
+ })
+ .Case<FusedLoc>([&](FusedLoc loc) {
+ if (!pretty)
+ os << "fused";
+ if (Attribute metadata = loc.getMetadata())
+ os << '<' << metadata << '>';
+ os << '[';
+ interleave(
+ loc.getLocations(),
+ [&](Location loc) { printLocationInternal(loc, pretty); },
+ [&]() { os << ", "; });
+ os << ']';
+ });
}
/// Print a floating point value in a way that the parser will be able to
@@ -1305,27 +1297,19 @@ void ModulePrinter::printAttribute(Attribute attr,
}
auto attrType = attr.getType();
- switch (attr.getKind()) {
- default:
- return printDialectAttribute(attr);
-
- case StandardAttributes::Opaque: {
- auto opaqueAttr = attr.cast<OpaqueAttr>();
+ if (auto opaqueAttr = attr.dyn_cast<OpaqueAttr>()) {
printDialectSymbol(os, "#", opaqueAttr.getDialectNamespace(),
opaqueAttr.getAttrData());
- break;
- }
- case StandardAttributes::Unit:
+ } else if (attr.isa<UnitAttr>()) {
os << "unit";
- break;
- case StandardAttributes::Dictionary:
+ return;
+ } else if (auto dictAttr = attr.dyn_cast<DictionaryAttr>()) {
os << '{';
- interleaveComma(attr.cast<DictionaryAttr>().getValue(),
+ interleaveComma(dictAttr.getValue(),
[&](NamedAttribute attr) { printNamedAttribute(attr); });
os << '}';
- break;
- case StandardAttributes::Integer: {
- auto intAttr = attr.cast<IntegerAttr>();
+
+ } else if (auto intAttr = attr.dyn_cast<IntegerAttr>()) {
if (attrType.isSignlessInteger(1)) {
os << (intAttr.getValue().getBoolValue() ? "true" : "false");
@@ -1343,114 +1327,98 @@ void ModulePrinter::printAttribute(Attribute attr,
// IntegerAttr elides the type if I64.
if (typeElision == AttrTypeElision::May && attrType.isSignlessInteger(64))
return;
- break;
- }
- case StandardAttributes::Float: {
- auto floatAttr = attr.cast<FloatAttr>();
+
+ } else if (auto floatAttr = attr.dyn_cast<FloatAttr>()) {
printFloatValue(floatAttr.getValue(), os);
// FloatAttr elides the type if F64.
if (typeElision == AttrTypeElision::May && attrType.isF64())
return;
- break;
- }
- case StandardAttributes::String:
+
+ } else if (auto strAttr = attr.dyn_cast<StringAttr>()) {
os << '"';
- printEscapedString(attr.cast<StringAttr>().getValue(), os);
+ printEscapedString(strAttr.getValue(), os);
os << '"';
- break;
- case StandardAttributes::Array:
+
+ } else if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
os << '[';
- interleaveComma(attr.cast<ArrayAttr>().getValue(), [&](Attribute attr) {
+ interleaveComma(arrayAttr.getValue(), [&](Attribute attr) {
printAttribute(attr, AttrTypeElision::May);
});
os << ']';
- break;
- case StandardAttributes::AffineMap:
+
+ } else if (auto affineMapAttr = attr.dyn_cast<AffineMapAttr>()) {
os << "affine_map<";
- attr.cast<AffineMapAttr>().getValue().print(os);
+ affineMapAttr.getValue().print(os);
os << '>';
// AffineMap always elides the type.
return;
- case StandardAttributes::IntegerSet:
+
+ } else if (auto integerSetAttr = attr.dyn_cast<IntegerSetAttr>()) {
os << "affine_set<";
- attr.cast<IntegerSetAttr>().getValue().print(os);
+ integerSetAttr.getValue().print(os);
os << '>';
// IntegerSet always elides the type.
return;
- case StandardAttributes::Type:
- printType(attr.cast<TypeAttr>().getValue());
- break;
- case StandardAttributes::SymbolRef: {
- auto refAttr = attr.dyn_cast<SymbolRefAttr>();
+
+ } else if (auto typeAttr = attr.dyn_cast<TypeAttr>()) {
+ printType(typeAttr.getValue());
+
+ } else if (auto refAttr = attr.dyn_cast<SymbolRefAttr>()) {
printSymbolReference(refAttr.getRootReference(), os);
for (FlatSymbolRefAttr nestedRef : refAttr.getNestedReferences()) {
os << "::";
printSymbolReference(nestedRef.getValue(), os);
}
- break;
- }
- case StandardAttributes::OpaqueElements: {
- auto eltsAttr = attr.cast<OpaqueElementsAttr>();
- if (printerFlags.shouldElideElementsAttr(eltsAttr)) {
+
+ } else if (auto opaqueAttr = attr.dyn_cast<OpaqueElementsAttr>()) {
+ if (printerFlags.shouldElideElementsAttr(opaqueAttr)) {
printElidedElementsAttr(os);
- break;
+ } else {
+ os << "opaque<\"" << opaqueAttr.getDialect()->getNamespace() << "\", ";
+ os << '"' << "0x" << llvm::toHex(opaqueAttr.getValue()) << "\">";
}
- os << "opaque<\"" << eltsAttr.getDialect()->getNamespace() << "\", ";
- os << '"' << "0x" << llvm::toHex(eltsAttr.getValue()) << "\">";
- break;
- }
- case StandardAttributes::DenseIntOrFPElements: {
- auto eltsAttr = attr.cast<DenseIntOrFPElementsAttr>();
- if (printerFlags.shouldElideElementsAttr(eltsAttr)) {
+
+ } else if (auto intOrFpEltAttr = attr.dyn_cast<DenseIntOrFPElementsAttr>()) {
+ if (printerFlags.shouldElideElementsAttr(intOrFpEltAttr)) {
printElidedElementsAttr(os);
- break;
+ } else {
+ os << "dense<";
+ printDenseIntOrFPElementsAttr(intOrFpEltAttr, /*allowHex=*/true);
+ os << '>';
}
- os << "dense<";
- printDenseIntOrFPElementsAttr(eltsAttr, /*allowHex=*/true);
- os << '>';
- break;
- }
- case StandardAttributes::DenseStringElements: {
- auto eltsAttr = attr.cast<DenseStringElementsAttr>();
- if (printerFlags.shouldElideElementsAttr(eltsAttr)) {
+
+ } else if (auto strEltAttr = attr.dyn_cast<DenseStringElementsAttr>()) {
+ if (printerFlags.shouldElideElementsAttr(strEltAttr)) {
printElidedElementsAttr(os);
- break;
+ } else {
+ os << "dense<";
+ printDenseStringElementsAttr(strEltAttr);
+ os << '>';
}
- os << "dense<";
- printDenseStringElementsAttr(eltsAttr);
- os << '>';
- break;
- }
- case StandardAttributes::SparseElements: {
- auto elementsAttr = attr.cast<SparseElementsAttr>();
- if (printerFlags.shouldElideElementsAttr(elementsAttr.getIndices()) ||
- printerFlags.shouldElideElementsAttr(elementsAttr.getValues())) {
+
+ } else if (auto sparseEltAttr = attr.dyn_cast<SparseElementsAttr>()) {
+ if (printerFlags.shouldElideElementsAttr(sparseEltAttr.getIndices()) ||
+ printerFlags.shouldElideElementsAttr(sparseEltAttr.getValues())) {
printElidedElementsAttr(os);
- break;
- }
- os << "sparse<";
- DenseIntElementsAttr indices = elementsAttr.getIndices();
- if (indices.getNumElements() != 0) {
- printDenseIntOrFPElementsAttr(indices, /*allowHex=*/false);
- os << ", ";
- printDenseElementsAttr(elementsAttr.getValues(), /*allowHex=*/true);
+ } else {
+ os << "sparse<";
+ DenseIntElementsAttr indices = sparseEltAttr.getIndices();
+ if (indices.getNumElements() != 0) {
+ printDenseIntOrFPElementsAttr(indices, /*allowHex=*/false);
+ os << ", ";
+ printDenseElementsAttr(sparseEltAttr.getValues(), /*allowHex=*/true);
+ }
+ os << '>';
}
- os << '>';
- break;
- }
- // Location attributes.
- case StandardAttributes::CallSiteLocation:
- case StandardAttributes::FileLineColLocation:
- case StandardAttributes::FusedLocation:
- case StandardAttributes::NameLocation:
- case StandardAttributes::OpaqueLocation:
- case StandardAttributes::UnknownLocation:
- printLocation(attr.cast<LocationAttr>());
- break;
+ } else if (auto locAttr = attr.dyn_cast<LocationAttr>()) {
+ printLocation(locAttr);
+
+ } else {
+ return printDialectAttribute(attr);
}
// Don't print the type if we must elide it, or if it is a None type.
diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp
index 8202465af0c7..e353b0b90a2a 100644
--- a/mlir/lib/IR/Attributes.cpp
+++ b/mlir/lib/IR/Attributes.cpp
@@ -460,16 +460,11 @@ int64_t ElementsAttr::getNumElements() const {
/// Return the value at the given index. If index does not refer to a valid
/// element, then a null attribute is returned.
Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const {
- switch (getKind()) {
- case StandardAttributes::DenseIntOrFPElements:
- return cast<DenseElementsAttr>().getValue(index);
- case StandardAttributes::OpaqueElements:
- return cast<OpaqueElementsAttr>().getValue(index);
- case StandardAttributes::SparseElements:
- return cast<SparseElementsAttr>().getValue(index);
- default:
- llvm_unreachable("unknown ElementsAttr kind");
- }
+ if (auto denseAttr = dyn_cast<DenseElementsAttr>())
+ return denseAttr.getValue(index);
+ if (auto opaqueAttr = dyn_cast<OpaqueElementsAttr>())
+ return opaqueAttr.getValue(index);
+ return cast<SparseElementsAttr>().getValue(index);
}
/// Return if the given 'index' refers to a valid element in this attribute.
@@ -491,23 +486,23 @@ bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const {
ElementsAttr
ElementsAttr::mapValues(Type newElementType,
function_ref<APInt(const APInt &)> mapping) const {
- switch (getKind()) {
- case StandardAttributes::DenseIntOrFPElements:
- return cast<DenseElementsAttr>().mapValues(newElementType, mapping);
- default:
- llvm_unreachable("unsupported ElementsAttr subtype");
- }
+ if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>())
+ return intOrFpAttr.mapValues(newElementType, mapping);
+ llvm_unreachable("unsupported ElementsAttr subtype");
}
ElementsAttr
ElementsAttr::mapValues(Type newElementType,
function_ref<APInt(const APFloat &)> mapping) const {
- switch (getKind()) {
- case StandardAttributes::DenseIntOrFPElements:
- return cast<DenseElementsAttr>().mapValues(newElementType, mapping);
- default:
- llvm_unreachable("unsupported ElementsAttr subtype");
- }
+ if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>())
+ return intOrFpAttr.mapValues(newElementType, mapping);
+ llvm_unreachable("unsupported ElementsAttr subtype");
+}
+
+/// Method for support type inquiry through isa, cast and dyn_cast.
+bool ElementsAttr::classof(Attribute attr) {
+ return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr,
+ OpaqueElementsAttr, SparseElementsAttr>();
}
/// Returns the 1 dimensional flattened row-major index from the given
@@ -718,6 +713,11 @@ DenseElementsAttr::ComplexFloatElementIterator::ComplexFloatElementIterator(
// DenseElementsAttr
//===----------------------------------------------------------------------===//
+/// Method for support type inquiry through isa, cast and dyn_cast.
+bool DenseElementsAttr::classof(Attribute attr) {
+ return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>();
+}
+
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
ArrayRef<Attribute> values) {
assert(hasSameElementsOrSplat(type, values));
diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp
index 34908c61639f..15dbac1f6cb8 100644
--- a/mlir/lib/IR/Diagnostics.cpp
+++ b/mlir/lib/IR/Diagnostics.cpp
@@ -366,43 +366,38 @@ struct SourceMgrDiagnosticHandlerImpl {
/// Return a processable FileLineColLoc from the given location.
static Optional<FileLineColLoc> getFileLineColLoc(Location loc) {
- switch (loc->getKind()) {
- case StandardAttributes::NameLocation:
+ if (auto nameLoc = loc.dyn_cast<NameLoc>())
return getFileLineColLoc(loc.cast<NameLoc>().getChildLoc());
- case StandardAttributes::FileLineColLocation:
- return loc.cast<FileLineColLoc>();
- case StandardAttributes::CallSiteLocation:
- // Process the callee of a callsite location.
+ if (auto fileLoc = loc.dyn_cast<FileLineColLoc>())
+ return fileLoc;
+ if (auto callLoc = loc.dyn_cast<CallSiteLoc>())
return getFileLineColLoc(loc.cast<CallSiteLoc>().getCallee());
- case StandardAttributes::FusedLocation:
+ if (auto fusedLoc = loc.dyn_cast<FusedLoc>()) {
for (auto subLoc : loc.cast<FusedLoc>().getLocations()) {
if (auto callLoc = getFileLineColLoc(subLoc)) {
return callLoc;
}
}
return llvm::None;
- default:
- return llvm::None;
}
+ return llvm::None;
}
/// Return a processable CallSiteLoc from the given location.
static Optional<CallSiteLoc> getCallSiteLoc(Location loc) {
- switch (loc->getKind()) {
- case StandardAttributes::NameLocation:
+ if (auto nameLoc = loc.dyn_cast<NameLoc>())
return getCallSiteLoc(loc.cast<NameLoc>().getChildLoc());
- case StandardAttributes::CallSiteLocation:
- return loc.cast<CallSiteLoc>();
- case StandardAttributes::FusedLocation:
+ if (auto callLoc = loc.dyn_cast<CallSiteLoc>())
+ return callLoc;
+ if (auto fusedLoc = loc.dyn_cast<FusedLoc>()) {
for (auto subLoc : loc.cast<FusedLoc>().getLocations()) {
if (auto callLoc = getCallSiteLoc(subLoc)) {
return callLoc;
}
}
return llvm::None;
- default:
- return llvm::None;
}
+ return llvm::None;
}
/// Given a diagnostic kind, returns the LLVM DiagKind.
diff --git a/mlir/lib/IR/Location.cpp b/mlir/lib/IR/Location.cpp
index f22fd5cb7852..48b05ba0eb40 100644
--- a/mlir/lib/IR/Location.cpp
+++ b/mlir/lib/IR/Location.cpp
@@ -13,6 +13,16 @@
using namespace mlir;
using namespace mlir::detail;
+//===----------------------------------------------------------------------===//
+// LocationAttr
+//===----------------------------------------------------------------------===//
+
+/// Methods for support type inquiry through isa, cast, and dyn_cast.
+bool LocationAttr::classof(Attribute attr) {
+ return attr.isa<CallSiteLoc, FileLineColLoc, FusedLoc, NameLoc, OpaqueLoc,
+ UnknownLoc>();
+}
+
//===----------------------------------------------------------------------===//
// CallSiteLoc
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/DebugTranslation.cpp b/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
index f40f44f2fbc6..af364ba9048f 100644
--- a/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
@@ -115,26 +115,19 @@ DebugTranslation::translateLoc(Location loc, llvm::DILocalScope *scope,
return existingIt->second;
const llvm::DILocation *llvmLoc = nullptr;
- switch (loc->getKind()) {
- case StandardAttributes::CallSiteLocation: {
- auto callLoc = loc.dyn_cast<CallSiteLoc>();
-
+ if (auto callLoc = loc.dyn_cast<CallSiteLoc>()) {
// For callsites, the caller is fed as the inlinedAt for the callee.
const auto *callerLoc = translateLoc(callLoc.getCaller(), scope, inlinedAt);
llvmLoc = translateLoc(callLoc.getCallee(), scope, callerLoc);
- break;
- }
- case StandardAttributes::FileLineColLocation: {
- auto fileLoc = loc.dyn_cast<FileLineColLoc>();
+
+ } else if (auto fileLoc = loc.dyn_cast<FileLineColLoc>()) {
auto *file = translateFile(fileLoc.getFilename());
auto *fileScope = builder.createLexicalBlockFile(scope, file);
llvmLoc = llvm::DILocation::get(llvmCtx, fileLoc.getLine(),
fileLoc.getColumn(), fileScope,
const_cast<llvm::DILocation *>(inlinedAt));
- break;
- }
- case StandardAttributes::FusedLocation: {
- auto fusedLoc = loc.dyn_cast<FusedLoc>();
+
+ } else if (auto fusedLoc = loc.dyn_cast<FusedLoc>()) {
ArrayRef<Location> locations = fusedLoc.getLocations();
// For fused locations, merge each of the nodes.
@@ -143,18 +136,17 @@ DebugTranslation::translateLoc(Location loc, llvm::DILocalScope *scope,
llvmLoc = llvm::DILocation::getMergedLocation(
llvmLoc, translateLoc(locIt, scope, inlinedAt));
}
- break;
- }
- case StandardAttributes::NameLocation:
+
+ } else if (auto nameLoc = loc.dyn_cast<NameLoc>()) {
llvmLoc = translateLoc(loc.cast<NameLoc>().getChildLoc(), scope, inlinedAt);
- break;
- case StandardAttributes::OpaqueLocation:
+
+ } else if (auto opaqueLoc = loc.dyn_cast<OpaqueLoc>()) {
llvmLoc = translateLoc(loc.cast<OpaqueLoc>().getFallbackLocation(), scope,
inlinedAt);
- break;
- default:
+ } else {
llvm_unreachable("unknown location kind");
}
+
locationToLoc.try_emplace(std::make_pair(loc, scope), llvmLoc);
return llvmLoc;
}
diff --git a/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp b/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp
index 68e43d45ab42..97c94a54ffc4 100644
--- a/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp
+++ b/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp
@@ -158,7 +158,7 @@ TEST(QuantizationUtilsTest, convertRankedSparseAttrUniform) {
auto expectedTensorType = realValue.getType().cast<TensorType>();
EXPECT_EQ(tensorType.getShape(), expectedTensorType.getShape());
EXPECT_EQ(tensorType.getElementType(), convertedType);
- EXPECT_EQ(returnedValue.getKind(), StandardAttributes::SparseElements);
+ EXPECT_TRUE(returnedValue.isa<SparseElementsAttr>());
// Check Elements attribute element value is expected.
auto firstValue = returnedValue.cast<ElementsAttr>().getValue({0, 0});
More information about the Mlir-commits
mailing list