[Mlir-commits] [mlir] [Dialect] Migrate away from PointerUnion::{is, get} (NFC) (#120679) (PR #120818)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Dec 20 19:23:20 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-dlti

Author: Kazu Hirata (kazutakahirata)

<details>
<summary>Changes</summary>

Note that PointerUnion::{is,get} have been soft-deprecated in
PointerUnion.h:

  // FIXME: Replace the uses of is(), get() and dyn_cast() with
  //        isa<T>, cast<T> and the llvm::dyn_cast<T>

I'm not touching PointerUnion::dyn_cast for now because migrating it
is a bit more complicated: we could blindly migrate every use to
dyn_cast_if_present, but we should probably use dyn_cast wherever the
operand is known to be non-null.


---
Full diff: https://github.com/llvm/llvm-project/pull/120818.diff


10 Files Affected:

- (modified) mlir/lib/Dialect/Arith/Utils/Utils.cpp (+1-1) 
- (modified) mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp (+1-1) 
- (modified) mlir/lib/Dialect/DLTI/DLTI.cpp (+3-3) 
- (modified) mlir/lib/Dialect/GPU/TransformOps/Utils.cpp (+13-13) 
- (modified) mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp (+1-1) 
- (modified) mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp (+3-3) 
- (modified) mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp (+1-1) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp (+1-1) 
- (modified) mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp (+4-4) 
- (modified) mlir/lib/Dialect/Utils/StaticValueUtils.cpp (+6-6) 


``````````diff
diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
index 67dcce454f028b..0fa7d321844113 100644
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -66,7 +66,7 @@ mlir::inferExpandShapeOutputShape(OpBuilder &b, Location loc,
     int64_t inputIndex = it.index();
     // Call get<Value>() under the assumption that we're not casting
     // dynamism.
-    Value indexGroupSize = inputShape[inputIndex].get<Value>();
+    Value indexGroupSize = cast<Value>(inputShape[inputIndex]);
     Value indexGroupStaticSizesProduct =
         b.create<arith::ConstantIndexOp>(loc, indexGroupStaticSizesProductInt);
     Value dynamicDimSize = b.createOrFold<arith::DivUIOp>(
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index f8a7a22787404b..349841f06959c3 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -174,7 +174,7 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
             resultDims[llvm::cast<OpResult>(shapedValue).getResultNumber()];
         for (const auto &dim : enumerate(tensorType.getShape()))
           if (ShapedType::isDynamic(dim.value()))
-            dynamicSizes.push_back(shape[dim.index()].get<Value>());
+            dynamicSizes.push_back(cast<Value>(shape[dim.index()]));
       }
     }
 
diff --git a/mlir/lib/Dialect/DLTI/DLTI.cpp b/mlir/lib/Dialect/DLTI/DLTI.cpp
index 508e50d42e4cf2..2510e774f2b2aa 100644
--- a/mlir/lib/Dialect/DLTI/DLTI.cpp
+++ b/mlir/lib/Dialect/DLTI/DLTI.cpp
@@ -312,7 +312,7 @@ combineOneSpec(DataLayoutSpecInterface spec,
       continue;
     }
 
-    Type typeSample = kvp.second.front().getKey().get<Type>();
+    Type typeSample = cast<Type>(kvp.second.front().getKey());
     assert(&typeSample.getDialect() !=
                typeSample.getContext()->getLoadedDialect<BuiltinDialect>() &&
            "unexpected data layout entry for built-in type");
@@ -325,7 +325,7 @@ combineOneSpec(DataLayoutSpecInterface spec,
   }
 
   for (const auto &kvp : newEntriesForID) {
-    StringAttr id = kvp.second.getKey().get<StringAttr>();
+    StringAttr id = cast<StringAttr>(kvp.second.getKey());
     Dialect *dialect = id.getReferencedDialect();
     if (!entriesForID.count(id)) {
       entriesForID[id] = kvp.second;
@@ -574,7 +574,7 @@ class TargetDataLayoutInterface : public DataLayoutDialectInterface {
 
   LogicalResult verifyEntry(DataLayoutEntryInterface entry,
                             Location loc) const final {
-    StringRef entryName = entry.getKey().get<StringAttr>().strref();
+    StringRef entryName = cast<StringAttr>(entry.getKey()).strref();
     if (entryName == DLTIDialect::kDataLayoutEndiannessKey) {
       auto value = dyn_cast<StringAttr>(entry.getValue());
       if (value &&
diff --git a/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp b/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp
index 17bda27b558110..f4d36129bae776 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp
@@ -113,16 +113,17 @@ static GpuIdBuilderFnType commonLinearIdBuilderFn(int64_t multiplicity = 1) {
     // clang-format on
 
     // Return n-D ids for indexing and 1-D size + id for predicate generation.
-    return IdBuilderResult{
-        /*mappingIdOps=*/ids,
-        /*availableMappingSizes=*/
-        SmallVector<int64_t>{computeProduct(originalBasis)},
-        // `forallMappingSizes` iterate in the scaled basis, they need to be
-        // scaled back into the original basis to provide tight
-        // activeMappingSizes quantities for predication.
-        /*activeMappingSizes=*/
-        SmallVector<int64_t>{computeProduct(forallMappingSizes) * multiplicity},
-        /*activeIdOps=*/SmallVector<Value>{linearId.get<Value>()}};
+      return IdBuilderResult{
+          /*mappingIdOps=*/ids,
+          /*availableMappingSizes=*/
+          SmallVector<int64_t>{computeProduct(originalBasis)},
+          // `forallMappingSizes` iterate in the scaled basis, they need to be
+          // scaled back into the original basis to provide tight
+          // activeMappingSizes quantities for predication.
+          /*activeMappingSizes=*/
+          SmallVector<int64_t>{computeProduct(forallMappingSizes) *
+                               multiplicity},
+          /*activeIdOps=*/SmallVector<Value>{cast<Value>(linearId)}};
   };
 
   return res;
@@ -144,9 +145,8 @@ static GpuIdBuilderFnType common3DIdBuilderFn(int64_t multiplicity = 1) {
     // In the 3-D mapping case, scale the first dimension by the multiplicity.
     SmallVector<Value> scaledIds = ids;
     AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
-    scaledIds[0] = affine::makeComposedFoldedAffineApply(
-                       rewriter, loc, d0.floorDiv(multiplicity), {scaledIds[0]})
-                       .get<Value>();
+    scaledIds[0] = cast<Value>(affine::makeComposedFoldedAffineApply(
+        rewriter, loc, d0.floorDiv(multiplicity), {scaledIds[0]}));
     // In the 3-D mapping case, unscale the first dimension by the multiplicity.
     SmallVector<int64_t> forallMappingSizeInOriginalBasis(forallMappingSizes);
     forallMappingSizeInOriginalBasis[0] *= multiplicity;
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
index 60c4e07a118cb8..447668cc0ea50f 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
@@ -217,7 +217,7 @@ TypedValue<IndexType> createProcessLinearIndex(StringRef mesh,
   OpFoldResult processInGroupLinearIndex = affine::linearizeIndex(
       llvm::to_vector_of<OpFoldResult>(processInGroupMultiIndex),
       llvm::to_vector_of<OpFoldResult>(processGroupShape), builder);
-  return cast<TypedValue<IndexType>>(processInGroupLinearIndex.get<Value>());
+  return cast<TypedValue<IndexType>>(cast<Value>(processInGroupLinearIndex));
 }
 
 } // namespace mlir::mesh
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp b/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp
index 2866d4eb10feb1..49c8ed977d50cc 100644
--- a/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp
+++ b/mlir/lib/Dialect/Ptr/IR/PtrTypes.cpp
@@ -33,7 +33,7 @@ static SpecAttr getPointerSpec(DataLayoutEntryListRef params, PtrType type) {
   for (DataLayoutEntryInterface entry : params) {
     if (!entry.isTypeEntry())
       continue;
-    if (cast<PtrType>(entry.getKey().get<Type>()).getMemorySpace() ==
+    if (cast<PtrType>(cast<Type>(entry.getKey())).getMemorySpace() ==
         type.getMemorySpace()) {
       if (auto spec = dyn_cast<SpecAttr>(entry.getValue()))
         return spec;
@@ -55,7 +55,7 @@ bool PtrType::areCompatible(DataLayoutEntryListRef oldLayout,
       continue;
     uint32_t size = kDefaultPointerSizeBits;
     uint32_t abi = kDefaultPointerAlignment;
-    auto newType = llvm::cast<PtrType>(newEntry.getKey().get<Type>());
+    auto newType = llvm::cast<PtrType>(llvm::cast<Type>(newEntry.getKey()));
     const auto *it =
         llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
           if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
@@ -134,7 +134,7 @@ LogicalResult PtrType::verifyEntries(DataLayoutEntryListRef entries,
   for (DataLayoutEntryInterface entry : entries) {
     if (!entry.isTypeEntry())
       continue;
-    auto key = entry.getKey().get<Type>();
+    auto key = llvm::cast<Type>(entry.getKey());
     if (!llvm::isa<SpecAttr>(entry.getValue())) {
       return emitError(loc) << "expected layout attribute for " << key
                             << " to be a #ptr.spec attribute";
diff --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
index 3e317319b68fc5..191bb330df7565 100644
--- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
@@ -205,7 +205,7 @@ CallInterfaceCallable FunctionCallOp::getCallableForCallee() {
 }
 
 void FunctionCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
-  (*this)->setAttr(getCalleeAttrName(), callee.get<SymbolRefAttr>());
+  (*this)->setAttr(getCalleeAttrName(), cast<SymbolRefAttr>(callee));
 }
 
 Operation::operand_range FunctionCallOp::getArgOperands() {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
index cb5874ff45068e..ea5533dfc6bacd 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
@@ -87,7 +87,7 @@ static Value unFoldOpIntResult(OpBuilder &builder, Location loc,
                                OpFoldResult ofr) {
   if (std::optional<int64_t> i = getConstantIntValue(ofr); i.has_value())
     return constantIndex(builder, loc, *i);
-  return ofr.get<Value>();
+  return cast<Value>(ofr);
 }
 
 static Value tryFoldTensors(Value t) {
diff --git a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
index fdd968238667e2..1e0ef5add358e3 100644
--- a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
@@ -1475,19 +1475,19 @@ transform::detail::checkApplyToOne(Operation *transformOp,
     if (ptr.isNull())
       continue;
     if (llvm::isa<TransformHandleTypeInterface>(res.getType()) &&
-        !ptr.is<Operation *>()) {
+        !isa<Operation *>(ptr)) {
       return emitDiag() << "application of " << transformOpName
                         << " expected to produce an Operation * for result #"
                         << res.getResultNumber();
     }
     if (llvm::isa<TransformParamTypeInterface>(res.getType()) &&
-        !ptr.is<Attribute>()) {
+        !isa<Attribute>(ptr)) {
       return emitDiag() << "application of " << transformOpName
                         << " expected to produce an Attribute for result #"
                         << res.getResultNumber();
     }
     if (llvm::isa<TransformValueHandleTypeInterface>(res.getType()) &&
-        !ptr.is<Value>()) {
+        !isa<Value>(ptr)) {
       return emitDiag() << "application of " << transformOpName
                         << " expected to produce a Value for result #"
                         << res.getResultNumber();
@@ -1499,7 +1499,7 @@ transform::detail::checkApplyToOne(Operation *transformOp,
 template <typename T>
 static SmallVector<T> castVector(ArrayRef<transform::MappedValue> range) {
   return llvm::to_vector(llvm::map_range(
-      range, [](transform::MappedValue value) { return value.get<T>(); }));
+      range, [](transform::MappedValue value) { return cast<T>(value); }));
 }
 
 void transform::detail::setApplyToOneResults(
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 0b399fba3f2635..5c8f6ded39ba4e 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -50,7 +50,7 @@ void dispatchIndexOpFoldResult(OpFoldResult ofr,
                                SmallVectorImpl<int64_t> &staticVec) {
   auto v = llvm::dyn_cast_if_present<Value>(ofr);
   if (!v) {
-    APInt apInt = cast<IntegerAttr>(ofr.get<Attribute>()).getValue();
+    APInt apInt = cast<IntegerAttr>(cast<Attribute>(ofr)).getValue();
     staticVec.push_back(apInt.getSExtValue());
     return;
   }
@@ -212,11 +212,11 @@ decomposeMixedValues(const SmallVectorImpl<OpFoldResult> &mixedValues) {
   SmallVector<int64_t> staticValues;
   SmallVector<Value> dynamicValues;
   for (const auto &it : mixedValues) {
-    if (it.is<Attribute>()) {
-      staticValues.push_back(cast<IntegerAttr>(it.get<Attribute>()).getInt());
+    if (auto attr = dyn_cast<Attribute>(it)) {
+      staticValues.push_back(cast<IntegerAttr>(attr).getInt());
     } else {
       staticValues.push_back(ShapedType::kDynamic);
-      dynamicValues.push_back(it.get<Value>());
+      dynamicValues.push_back(cast<Value>(it));
     }
   }
   return {staticValues, dynamicValues};
@@ -294,10 +294,10 @@ LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
                                    bool onlyNonNegative, bool onlyNonZero) {
   bool valuesChanged = false;
   for (OpFoldResult &ofr : ofrs) {
-    if (ofr.is<Attribute>())
+    if (isa<Attribute>(ofr))
       continue;
     Attribute attr;
-    if (matchPattern(ofr.get<Value>(), m_Constant(&attr))) {
+    if (matchPattern(cast<Value>(ofr), m_Constant(&attr))) {
       // Note: All ofrs have index type.
       if (onlyNonNegative && *getConstantIntValue(attr) < 0)
         continue;

``````````

</details>


https://github.com/llvm/llvm-project/pull/120818


More information about the Mlir-commits mailing list