[Mlir-commits] [mlir] [Interfaces] Migrate away from PointerUnion::{is, get} (NFC) (PR #120845)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Dec 21 09:06:45 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Kazu Hirata (kazutakahirata)
<details>
<summary>Changes</summary>
Note that PointerUnion::{is,get} have been soft-deprecated in
PointerUnion.h:
// FIXME: Replace the uses of is(), get() and dyn_cast() with
// isa<T>, cast<T> and the llvm::dyn_cast<T>
I'm not touching PointerUnion::dyn_cast for now because it's a bit
complicated; we could blindly migrate it to dyn_cast_if_present, but
we should probably use dyn_cast when the operand is known to be
non-null.
---
Full diff: https://github.com/llvm/llvm-project/pull/120845.diff
3 Files Affected:
- (modified) mlir/lib/Interfaces/CallInterfaces.cpp (+1-1)
- (modified) mlir/lib/Interfaces/DataLayoutInterfaces.cpp (+8-8)
- (modified) mlir/lib/Interfaces/InferTypeOpInterface.cpp (+10-10)
``````````diff
diff --git a/mlir/lib/Interfaces/CallInterfaces.cpp b/mlir/lib/Interfaces/CallInterfaces.cpp
index 9e5bc159dc8908..da0ca0e24630f0 100644
--- a/mlir/lib/Interfaces/CallInterfaces.cpp
+++ b/mlir/lib/Interfaces/CallInterfaces.cpp
@@ -22,7 +22,7 @@ call_interface_impl::resolveCallable(CallOpInterface call,
return symbolVal.getDefiningOp();
// If the callable isn't a value, lookup the symbol reference.
- auto symbolRef = callable.get<SymbolRefAttr>();
+ auto symbolRef = cast<SymbolRefAttr>(callable);
if (symbolTable)
return symbolTable->lookupNearestSymbolFrom(call.getOperation(), symbolRef);
return SymbolTable::lookupNearestSymbolFrom(call.getOperation(), symbolRef);
diff --git a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
index 1c661e3beea48e..049d7f123cec8f 100644
--- a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
+++ b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
@@ -95,7 +95,7 @@ findEntryForIntegerType(IntegerType intType,
std::map<unsigned, DataLayoutEntryInterface> sortedParams;
for (DataLayoutEntryInterface entry : params) {
sortedParams.insert(std::make_pair(
- entry.getKey().get<Type>().getIntOrFloatBitWidth(), entry));
+ cast<Type>(entry.getKey()).getIntOrFloatBitWidth(), entry));
}
auto iter = sortedParams.lower_bound(intType.getWidth());
if (iter == sortedParams.end())
@@ -315,9 +315,9 @@ DataLayoutEntryInterface
mlir::detail::filterEntryForIdentifier(DataLayoutEntryListRef entries,
StringAttr id) {
const auto *it = llvm::find_if(entries, [id](DataLayoutEntryInterface entry) {
- if (!entry.getKey().is<StringAttr>())
- return false;
- return entry.getKey().get<StringAttr>() == id;
+ if (auto attr = dyn_cast<StringAttr>(entry.getKey()))
+ return attr == id;
+ return false;
});
return it == entries.end() ? DataLayoutEntryInterface() : *it;
}
@@ -691,7 +691,7 @@ void DataLayoutSpecInterface::bucketEntriesByType(
if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey()))
types[type.getTypeID()].push_back(entry);
else
- ids[entry.getKey().get<StringAttr>()] = entry;
+ ids[llvm::cast<StringAttr>(entry.getKey())] = entry;
}
}
@@ -709,7 +709,7 @@ LogicalResult mlir::detail::verifyDataLayoutSpec(DataLayoutSpecInterface spec,
spec.bucketEntriesByType(types, ids);
for (const auto &kvp : types) {
- auto sampleType = kvp.second.front().getKey().get<Type>();
+ auto sampleType = cast<Type>(kvp.second.front().getKey());
if (isa<IndexType>(sampleType)) {
assert(kvp.second.size() == 1 &&
"expected one data layout entry for non-parametric 'index' type");
@@ -763,7 +763,7 @@ LogicalResult mlir::detail::verifyDataLayoutSpec(DataLayoutSpecInterface spec,
}
for (const auto &kvp : ids) {
- StringAttr identifier = kvp.second.getKey().get<StringAttr>();
+ StringAttr identifier = cast<StringAttr>(kvp.second.getKey());
Dialect *dialect = identifier.getReferencedDialect();
// Ignore attributes that belong to an unknown dialect, the dialect may
@@ -816,7 +816,7 @@ mlir::detail::verifyTargetSystemSpec(TargetSystemSpecInterface spec,
// targetDeviceSpec does not support Type as a key.
return failure();
} else {
- deviceDescKeys[entry.getKey().get<StringAttr>()] = entry;
+ deviceDescKeys[cast<StringAttr>(entry.getKey())] = entry;
}
}
}
diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
index 8cc4206dae6edf..3eb401c4499805 100644
--- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp
+++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
@@ -53,7 +53,7 @@ mlir::reifyResultShapes(OpBuilder &b, Operation *op,
// * Attribute for static dimensions
// * Value for dynamic dimensions
assert(shapedType.isDynamicDim(dim) ==
- reifiedReturnShapes[resultIdx][dim].is<Value>() &&
+ isa<Value>(reifiedReturnShapes[resultIdx][dim]) &&
"incorrect implementation of ReifyRankedShapedTypeOpInterface");
}
++resultIdx;
@@ -70,9 +70,9 @@ bool ShapeAdaptor::hasRank() const {
return false;
if (auto t = llvm::dyn_cast_if_present<Type>(val))
return cast<ShapedType>(t).hasRank();
- if (val.is<Attribute>())
+ if (isa<Attribute>(val))
return true;
- return val.get<ShapedTypeComponents *>()->hasRank();
+ return cast<ShapedTypeComponents *>(val)->hasRank();
}
Type ShapeAdaptor::getElementType() const {
@@ -80,9 +80,9 @@ Type ShapeAdaptor::getElementType() const {
return nullptr;
if (auto t = llvm::dyn_cast_if_present<Type>(val))
return cast<ShapedType>(t).getElementType();
- if (val.is<Attribute>())
+ if (isa<Attribute>(val))
return nullptr;
- return val.get<ShapedTypeComponents *>()->getElementType();
+ return cast<ShapedTypeComponents *>(val)->getElementType();
}
void ShapeAdaptor::getDims(SmallVectorImpl<int64_t> &res) const {
@@ -97,7 +97,7 @@ void ShapeAdaptor::getDims(SmallVectorImpl<int64_t> &res) const {
for (auto it : dattr.getValues<APInt>())
res.push_back(it.getSExtValue());
} else {
- auto vals = val.get<ShapedTypeComponents *>()->getDims();
+ auto vals = cast<ShapedTypeComponents *>(val)->getDims();
res.assign(vals.begin(), vals.end());
}
}
@@ -116,7 +116,7 @@ int64_t ShapeAdaptor::getDimSize(int index) const {
return cast<DenseIntElementsAttr>(attr)
.getValues<APInt>()[index]
.getSExtValue();
- auto *stc = val.get<ShapedTypeComponents *>();
+ auto *stc = cast<ShapedTypeComponents *>(val);
return stc->getDims()[index];
}
@@ -126,7 +126,7 @@ int64_t ShapeAdaptor::getRank() const {
return cast<ShapedType>(t).getRank();
if (auto attr = llvm::dyn_cast_if_present<Attribute>(val))
return cast<DenseIntElementsAttr>(attr).size();
- return val.get<ShapedTypeComponents *>()->getDims().size();
+ return cast<ShapedTypeComponents *>(val)->getDims().size();
}
bool ShapeAdaptor::hasStaticShape() const {
@@ -142,7 +142,7 @@ bool ShapeAdaptor::hasStaticShape() const {
return false;
return true;
}
- auto *stc = val.get<ShapedTypeComponents *>();
+ auto *stc = cast<ShapedTypeComponents *>(val);
return llvm::none_of(stc->getDims(), ShapedType::isDynamic);
}
@@ -162,7 +162,7 @@ int64_t ShapeAdaptor::getNumElements() const {
return num;
}
- auto *stc = val.get<ShapedTypeComponents *>();
+ auto *stc = cast<ShapedTypeComponents *>(val);
int64_t num = 1;
for (int64_t dim : stc->getDims()) {
num *= dim;
``````````
</details>
https://github.com/llvm/llvm-project/pull/120845
More information about the Mlir-commits mailing list