[Mlir-commits] [mlir] [Interfaces] Migrate away from PointerUnion::{is, get} (NFC) (PR #120845)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Dec 21 09:06:45 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Kazu Hirata (kazutakahirata)

<details>
<summary>Changes</summary>

Note that PointerUnion::{is,get} have been soft-deprecated in
PointerUnion.h:

  // FIXME: Replace the uses of is(), get() and dyn_cast() with
  //        isa<T>, cast<T> and the llvm::dyn_cast<T>

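I'm not touching PointerUnion::dyn_cast for now because the migration
is not purely mechanical. PointerUnion::dyn_cast tolerates a null
union, as dyn_cast_if_present does, whereas llvm::dyn_cast asserts on
null. We could blindly migrate every call to dyn_cast_if_present, but
we should probably use dyn_cast wherever the operand is known to be
non-null, and that requires auditing each call site individually.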


---
Full diff: https://github.com/llvm/llvm-project/pull/120845.diff


3 Files Affected:

- (modified) mlir/lib/Interfaces/CallInterfaces.cpp (+1-1) 
- (modified) mlir/lib/Interfaces/DataLayoutInterfaces.cpp (+8-8) 
- (modified) mlir/lib/Interfaces/InferTypeOpInterface.cpp (+10-10) 


``````````diff
diff --git a/mlir/lib/Interfaces/CallInterfaces.cpp b/mlir/lib/Interfaces/CallInterfaces.cpp
index 9e5bc159dc8908..da0ca0e24630f0 100644
--- a/mlir/lib/Interfaces/CallInterfaces.cpp
+++ b/mlir/lib/Interfaces/CallInterfaces.cpp
@@ -22,7 +22,7 @@ call_interface_impl::resolveCallable(CallOpInterface call,
     return symbolVal.getDefiningOp();
 
   // If the callable isn't a value, lookup the symbol reference.
-  auto symbolRef = callable.get<SymbolRefAttr>();
+  auto symbolRef = cast<SymbolRefAttr>(callable);
   if (symbolTable)
     return symbolTable->lookupNearestSymbolFrom(call.getOperation(), symbolRef);
   return SymbolTable::lookupNearestSymbolFrom(call.getOperation(), symbolRef);
diff --git a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
index 1c661e3beea48e..049d7f123cec8f 100644
--- a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
+++ b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
@@ -95,7 +95,7 @@ findEntryForIntegerType(IntegerType intType,
   std::map<unsigned, DataLayoutEntryInterface> sortedParams;
   for (DataLayoutEntryInterface entry : params) {
     sortedParams.insert(std::make_pair(
-        entry.getKey().get<Type>().getIntOrFloatBitWidth(), entry));
+        cast<Type>(entry.getKey()).getIntOrFloatBitWidth(), entry));
   }
   auto iter = sortedParams.lower_bound(intType.getWidth());
   if (iter == sortedParams.end())
@@ -315,9 +315,9 @@ DataLayoutEntryInterface
 mlir::detail::filterEntryForIdentifier(DataLayoutEntryListRef entries,
                                        StringAttr id) {
   const auto *it = llvm::find_if(entries, [id](DataLayoutEntryInterface entry) {
-    if (!entry.getKey().is<StringAttr>())
-      return false;
-    return entry.getKey().get<StringAttr>() == id;
+    if (auto attr = dyn_cast<StringAttr>(entry.getKey()))
+      return attr == id;
+    return false;
   });
   return it == entries.end() ? DataLayoutEntryInterface() : *it;
 }
@@ -691,7 +691,7 @@ void DataLayoutSpecInterface::bucketEntriesByType(
     if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey()))
       types[type.getTypeID()].push_back(entry);
     else
-      ids[entry.getKey().get<StringAttr>()] = entry;
+      ids[llvm::cast<StringAttr>(entry.getKey())] = entry;
   }
 }
 
@@ -709,7 +709,7 @@ LogicalResult mlir::detail::verifyDataLayoutSpec(DataLayoutSpecInterface spec,
   spec.bucketEntriesByType(types, ids);
 
   for (const auto &kvp : types) {
-    auto sampleType = kvp.second.front().getKey().get<Type>();
+    auto sampleType = cast<Type>(kvp.second.front().getKey());
     if (isa<IndexType>(sampleType)) {
       assert(kvp.second.size() == 1 &&
              "expected one data layout entry for non-parametric 'index' type");
@@ -763,7 +763,7 @@ LogicalResult mlir::detail::verifyDataLayoutSpec(DataLayoutSpecInterface spec,
   }
 
   for (const auto &kvp : ids) {
-    StringAttr identifier = kvp.second.getKey().get<StringAttr>();
+    StringAttr identifier = cast<StringAttr>(kvp.second.getKey());
     Dialect *dialect = identifier.getReferencedDialect();
 
     // Ignore attributes that belong to an unknown dialect, the dialect may
@@ -816,7 +816,7 @@ mlir::detail::verifyTargetSystemSpec(TargetSystemSpecInterface spec,
         // targetDeviceSpec does not support Type as a key.
         return failure();
       } else {
-        deviceDescKeys[entry.getKey().get<StringAttr>()] = entry;
+        deviceDescKeys[cast<StringAttr>(entry.getKey())] = entry;
       }
     }
   }
diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
index 8cc4206dae6edf..3eb401c4499805 100644
--- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp
+++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
@@ -53,7 +53,7 @@ mlir::reifyResultShapes(OpBuilder &b, Operation *op,
       // * Attribute for static dimensions
       // * Value for dynamic dimensions
       assert(shapedType.isDynamicDim(dim) ==
-                 reifiedReturnShapes[resultIdx][dim].is<Value>() &&
+                 isa<Value>(reifiedReturnShapes[resultIdx][dim]) &&
              "incorrect implementation of ReifyRankedShapedTypeOpInterface");
     }
     ++resultIdx;
@@ -70,9 +70,9 @@ bool ShapeAdaptor::hasRank() const {
     return false;
   if (auto t = llvm::dyn_cast_if_present<Type>(val))
     return cast<ShapedType>(t).hasRank();
-  if (val.is<Attribute>())
+  if (isa<Attribute>(val))
     return true;
-  return val.get<ShapedTypeComponents *>()->hasRank();
+  return cast<ShapedTypeComponents *>(val)->hasRank();
 }
 
 Type ShapeAdaptor::getElementType() const {
@@ -80,9 +80,9 @@ Type ShapeAdaptor::getElementType() const {
     return nullptr;
   if (auto t = llvm::dyn_cast_if_present<Type>(val))
     return cast<ShapedType>(t).getElementType();
-  if (val.is<Attribute>())
+  if (isa<Attribute>(val))
     return nullptr;
-  return val.get<ShapedTypeComponents *>()->getElementType();
+  return cast<ShapedTypeComponents *>(val)->getElementType();
 }
 
 void ShapeAdaptor::getDims(SmallVectorImpl<int64_t> &res) const {
@@ -97,7 +97,7 @@ void ShapeAdaptor::getDims(SmallVectorImpl<int64_t> &res) const {
     for (auto it : dattr.getValues<APInt>())
       res.push_back(it.getSExtValue());
   } else {
-    auto vals = val.get<ShapedTypeComponents *>()->getDims();
+    auto vals = cast<ShapedTypeComponents *>(val)->getDims();
     res.assign(vals.begin(), vals.end());
   }
 }
@@ -116,7 +116,7 @@ int64_t ShapeAdaptor::getDimSize(int index) const {
     return cast<DenseIntElementsAttr>(attr)
         .getValues<APInt>()[index]
         .getSExtValue();
-  auto *stc = val.get<ShapedTypeComponents *>();
+  auto *stc = cast<ShapedTypeComponents *>(val);
   return stc->getDims()[index];
 }
 
@@ -126,7 +126,7 @@ int64_t ShapeAdaptor::getRank() const {
     return cast<ShapedType>(t).getRank();
   if (auto attr = llvm::dyn_cast_if_present<Attribute>(val))
     return cast<DenseIntElementsAttr>(attr).size();
-  return val.get<ShapedTypeComponents *>()->getDims().size();
+  return cast<ShapedTypeComponents *>(val)->getDims().size();
 }
 
 bool ShapeAdaptor::hasStaticShape() const {
@@ -142,7 +142,7 @@ bool ShapeAdaptor::hasStaticShape() const {
         return false;
     return true;
   }
-  auto *stc = val.get<ShapedTypeComponents *>();
+  auto *stc = cast<ShapedTypeComponents *>(val);
   return llvm::none_of(stc->getDims(), ShapedType::isDynamic);
 }
 
@@ -162,7 +162,7 @@ int64_t ShapeAdaptor::getNumElements() const {
     return num;
   }
 
-  auto *stc = val.get<ShapedTypeComponents *>();
+  auto *stc = cast<ShapedTypeComponents *>(val);
   int64_t num = 1;
   for (int64_t dim : stc->getDims()) {
     num *= dim;

``````````

</details>


https://github.com/llvm/llvm-project/pull/120845


More information about the Mlir-commits mailing list