[Mlir-commits] [llvm] [mlir] Mark `mlir::Value::isa/dyn_cast/cast/...` member functions deprecated. (PR #89238)

Christian Sigg llvmlistbot@llvm.org
Thu Apr 18 07:54:09 PDT 2024


https://github.com/chsigg updated https://github.com/llvm/llvm-project/pull/89238

From 1b31e44800498e864cff1c59e28f4a13b9066a88 Mon Sep 17 00:00:00 2001
From: Christian Sigg <csigg@google.com>
Date: Thu, 18 Apr 2024 16:53:48 +0200
Subject: [PATCH] Mark `isa/dyn_cast/cast/...` member functions deprecated.

See https://mlir.llvm.org/deprecation
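
This follows the same migration already completed for `Type` and `Attribute`:
the member-function form `value.cast<U>()` gives way to the free-function form
`cast<U>(value)`. A minimal before/after sketch (the `v` and `op` names are
illustrative, not taken from any file in this patch):

    // Before: member functions on mlir::Value, now deprecated.
    bool isArg = v.isa<mlir::BlockArgument>();
    mlir::OpResult res = op->getResult(0).cast<mlir::OpResult>();

    // After: free functions from LLVM's casting infrastructure.
    bool isArg2 = llvm::isa<mlir::BlockArgument>(v);
    mlir::OpResult res2 = llvm::cast<mlir::OpResult>(op->getResult(0));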
---
 llvm/include/llvm/ADT/TypeSwitch.h            |  3 +-
 .../transform/Ch4/lib/MyExtension.cpp         |  2 +-
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h   |  6 ++--
 mlir/include/mlir/IR/Value.h                  |  6 +++-
 .../IR/BufferDeallocationOpInterface.cpp      |  6 ++--
 .../Linalg/TransformOps/LinalgMatchOps.cpp    | 24 ++++++-------
 .../TransformOps/MemRefTransformOps.cpp       |  4 +--
 .../Mesh/Interfaces/ShardingInterface.cpp     |  2 +-
 .../Dialect/Mesh/Transforms/Spmdization.cpp   | 35 ++++++++-----------
 .../Dialect/Mesh/Transforms/Transforms.cpp    |  7 ++--
 .../TransformOps/SparseTensorTransformOps.cpp |  2 +-
 .../Transforms/SparseTensorRewriting.cpp      |  6 ++--
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp      |  4 +--
 .../lib/Dialect/Transform/IR/TransformOps.cpp |  4 +--
 .../Mesh/TestReshardingSpmdization.cpp        |  5 ++-
 15 files changed, 55 insertions(+), 61 deletions(-)

diff --git a/llvm/include/llvm/ADT/TypeSwitch.h b/llvm/include/llvm/ADT/TypeSwitch.h
index 10a2d48e918db9..2495855334eef9 100644
--- a/llvm/include/llvm/ADT/TypeSwitch.h
+++ b/llvm/include/llvm/ADT/TypeSwitch.h
@@ -64,8 +64,7 @@ template <typename DerivedT, typename T> class TypeSwitchBase {
   /// Trait to check whether `ValueT` provides a 'dyn_cast' method with type
   /// `CastT`.
   template <typename ValueT, typename CastT>
-  using has_dyn_cast_t =
-      decltype(std::declval<ValueT &>().template dyn_cast<CastT>());
+  using has_dyn_cast_t = decltype(dyn_cast<CastT>(std::declval<ValueT &>()));
 
   /// Attempt to dyn_cast the given `value` to `CastT`. This overload is
   /// selected if `value` already has a suitable dyn_cast method.
diff --git a/mlir/examples/transform/Ch4/lib/MyExtension.cpp b/mlir/examples/transform/Ch4/lib/MyExtension.cpp
index 26e348f2a30ec6..83e2dcd750bb39 100644
--- a/mlir/examples/transform/Ch4/lib/MyExtension.cpp
+++ b/mlir/examples/transform/Ch4/lib/MyExtension.cpp
@@ -142,7 +142,7 @@ mlir::transform::HasOperandSatisfyingOp::apply(
     transform::detail::prepareValueMappings(
         yieldedMappings, getBody().front().getTerminator()->getOperands(),
         state);
-    results.setParams(getPosition().cast<OpResult>(),
+    results.setParams(cast<OpResult>(getPosition()),
                       {rewriter.getI32IntegerAttr(operand.getOperandNumber())});
     for (auto &&[result, mapping] : llvm::zip(getResults(), yieldedMappings))
       results.setMappedValues(result, mapping);
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
index a9bc3351f4cff0..ec3c2cb011c357 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -60,11 +60,11 @@ class MulOperandsAndResultElementType
     if (llvm::isa<FloatType>(resElemType))
       return impl::verifySameOperandsAndResultElementType(op);
 
-    if (auto resIntType = resElemType.dyn_cast<IntegerType>()) {
+    if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
       IntegerType lhsIntType =
-          getElementTypeOrSelf(op->getOperand(0)).cast<IntegerType>();
+          cast<IntegerType>(getElementTypeOrSelf(op->getOperand(0)));
       IntegerType rhsIntType =
-          getElementTypeOrSelf(op->getOperand(1)).cast<IntegerType>();
+          cast<IntegerType>(getElementTypeOrSelf(op->getOperand(1)));
       if (lhsIntType != rhsIntType)
         return op->emitOpError(
             "requires the same element type for all operands");
diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h
index a74d0faa1dfc4b..cdbc6cc374368c 100644
--- a/mlir/include/mlir/IR/Value.h
+++ b/mlir/include/mlir/IR/Value.h
@@ -98,21 +98,25 @@ class Value {
   constexpr Value(detail::ValueImpl *impl = nullptr) : impl(impl) {}
 
   template <typename U>
+  [[deprecated("Use isa<U>() instead")]]
   bool isa() const {
     return llvm::isa<U>(*this);
   }
 
   template <typename U>
+  [[deprecated("Use dyn_cast<U>() instead")]]
   U dyn_cast() const {
     return llvm::dyn_cast<U>(*this);
   }
 
   template <typename U>
+  [[deprecated("Use dyn_cast_or_null<U>() instead")]]
   U dyn_cast_or_null() const {
-    return llvm::dyn_cast_if_present<U>(*this);
+    return llvm::dyn_cast_or_null<U>(*this);
   }
 
   template <typename U>
+  [[deprecated("Use cast<U>() instead")]]
   U cast() const {
     return llvm::cast<U>(*this);
   }
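
The members keep forwarding to the free functions, so runtime behavior is
unchanged; existing callers only gain a compile-time warning. Roughly what a
downstream caller sees (hypothetical function, for illustration only):

    void inspect(mlir::Value v) {
      // Warns: 'dyn_cast' is deprecated: Use dyn_cast<U>() instead.
      auto arg = v.dyn_cast<mlir::BlockArgument>();
      // Warning-free replacement; mlir re-exports the llvm free functions.
      auto arg2 = mlir::dyn_cast<mlir::BlockArgument>(v);
      (void)arg;
      (void)arg2;
    }

Projects that need time to migrate can build with
-Wno-deprecated-declarations in the interim.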
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
index a5ea42b7d701d0..b197786c320548 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
@@ -39,7 +39,7 @@ static Value buildBoolValue(OpBuilder &builder, Location loc, bool value) {
   return builder.create<arith::ConstantOp>(loc, builder.getBoolAttr(value));
 }
 
-static bool isMemref(Value v) { return v.getType().isa<BaseMemRefType>(); }
+static bool isMemref(Value v) { return isa<BaseMemRefType>(v.getType()); }
 
 //===----------------------------------------------------------------------===//
 // Ownership
@@ -222,8 +222,8 @@ bool ValueComparator::operator()(const Value &lhs, const Value &rhs) const {
     return false;
 
   // Block arguments are less than results.
-  bool lhsIsBBArg = lhs.isa<BlockArgument>();
-  if (lhsIsBBArg != rhs.isa<BlockArgument>()) {
+  bool lhsIsBBArg = isa<BlockArgument>(lhs);
+  if (lhsIsBBArg != isa<BlockArgument>(rhs)) {
     return lhsIsBBArg;
   }
 
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
index 3e85559e1ec0c6..dc77014abeb27f 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
@@ -259,11 +259,11 @@ transform::MatchStructuredClassifyContractionDimsOp::matchOperation(
           return builder.getI64IntegerAttr(value);
         }));
   };
-  results.setParams(getBatch().cast<OpResult>(),
+  results.setParams(cast<OpResult>(getBatch()),
                     makeI64Attrs(contractionDims->batch));
-  results.setParams(getM().cast<OpResult>(), makeI64Attrs(contractionDims->m));
-  results.setParams(getN().cast<OpResult>(), makeI64Attrs(contractionDims->n));
-  results.setParams(getK().cast<OpResult>(), makeI64Attrs(contractionDims->k));
+  results.setParams(cast<OpResult>(getM()), makeI64Attrs(contractionDims->m));
+  results.setParams(cast<OpResult>(getN()), makeI64Attrs(contractionDims->n));
+  results.setParams(cast<OpResult>(getK()), makeI64Attrs(contractionDims->k));
   return DiagnosedSilenceableFailure::success();
 }
 
@@ -288,17 +288,17 @@ transform::MatchStructuredClassifyConvolutionDimsOp::matchOperation(
           return builder.getI64IntegerAttr(value);
         }));
   };
-  results.setParams(getBatch().cast<OpResult>(),
+  results.setParams(cast<OpResult>(getBatch()),
                     makeI64Attrs(convolutionDims->batch));
-  results.setParams(getOutputImage().cast<OpResult>(),
+  results.setParams(cast<OpResult>(getOutputImage()),
                     makeI64Attrs(convolutionDims->outputImage));
-  results.setParams(getOutputChannel().cast<OpResult>(),
+  results.setParams(cast<OpResult>(getOutputChannel()),
                     makeI64Attrs(convolutionDims->outputChannel));
-  results.setParams(getFilterLoop().cast<OpResult>(),
+  results.setParams(cast<OpResult>(getFilterLoop()),
                     makeI64Attrs(convolutionDims->filterLoop));
-  results.setParams(getInputChannel().cast<OpResult>(),
+  results.setParams(cast<OpResult>(getInputChannel()),
                     makeI64Attrs(convolutionDims->inputChannel));
-  results.setParams(getDepth().cast<OpResult>(),
+  results.setParams(cast<OpResult>(getDepth()),
                     makeI64Attrs(convolutionDims->depth));
 
   auto makeI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
@@ -307,9 +307,9 @@ transform::MatchStructuredClassifyConvolutionDimsOp::matchOperation(
           return builder.getI64IntegerAttr(value);
         }));
   };
-  results.setParams(getStrides().cast<OpResult>(),
+  results.setParams(cast<OpResult>(getStrides()),
                     makeI64AttrsFromI64(convolutionDims->strides));
-  results.setParams(getDilations().cast<OpResult>(),
+  results.setParams(cast<OpResult>(getDilations()),
                     makeI64AttrsFromI64(convolutionDims->dilations));
   return DiagnosedSilenceableFailure::success();
 }
diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
index b3481ce1c56bbd..3c9475c2d143a6 100644
--- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
+++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
@@ -173,8 +173,8 @@ transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter,
   }
 
   // Assemble results.
-  results.set(getGlobal().cast<OpResult>(), globalOps);
-  results.set(getGetGlobal().cast<OpResult>(), getGlobalOps);
+  results.set(cast<OpResult>(getGlobal()), globalOps);
+  results.set(cast<OpResult>(getGetGlobal()), getGlobalOps);
 
   return DiagnosedSilenceableFailure::success();
 }
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index 9acee5aa8d8604..ddb50130c6b82f 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -97,7 +97,7 @@ checkOperandAffineExpr(AffineExpr expr, unsigned numDims) {
 
 FailureOr<std::pair<bool, MeshShardingAttr>>
 mesh::getMeshShardingAttr(OpResult result) {
-  Value val = result.cast<Value>();
+  Value val = cast<Value>(result);
   bool anyShardedForDef = llvm::any_of(val.getUsers(), [](Operation *user) {
     auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
     if (!shardOp)
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index e4868435135ed1..31297ece6c57f0 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -86,14 +86,13 @@ handlePartialAxesDuringResharding(OpBuilder &builder,
   }
 
   builder.setInsertionPointAfterValue(sourceShard);
-  TypedValue<ShapedType> resultValue =
+  TypedValue<ShapedType> resultValue = cast<TypedValue<ShapedType>>(
       builder
           .create<AllReduceOp>(sourceShard.getLoc(), sourceShard.getType(),
                                sourceSharding.getMesh().getLeafReference(),
                                allReduceMeshAxes, sourceShard,
                                sourceSharding.getPartialType())
-          .getResult()
-          .cast<TypedValue<ShapedType>>();
+          .getResult());
 
   llvm::SmallVector<MeshAxis> remainingPartialAxes;
   llvm::copy_if(sourceShardingPartialAxesSet,
@@ -135,13 +134,12 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
                           MeshShardingAttr sourceSharding,
                           TypedValue<ShapedType> sourceShard, MeshOp mesh,
                           int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
-  TypedValue<ShapedType> targetShard =
+  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
       builder
           .create<AllSliceOp>(sourceShard, mesh,
                               ArrayRef<MeshAxis>(splitMeshAxis),
                               splitTensorAxis)
-          .getResult()
-          .cast<TypedValue<ShapedType>>();
+          .getResult());
   MeshShardingAttr targetSharding = targetShardingInSplitLastAxis(
       builder.getContext(), sourceSharding, splitTensorAxis, splitMeshAxis);
   return {targetShard, targetSharding};
@@ -278,10 +276,8 @@ unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
       APInt(64, splitTensorAxis));
   ShapedType targetShape =
       shardShapedType(sourceUnshardedShape, mesh, targetSharding);
-  TypedValue<ShapedType> targetShard =
-      builder.create<tensor::CastOp>(targetShape, allGatherResult)
-          .getResult()
-          .cast<TypedValue<ShapedType>>();
+  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
+      builder.create<tensor::CastOp>(targetShape, allGatherResult).getResult());
   return {targetShard, targetSharding};
 }
 
@@ -413,10 +409,8 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
       APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
   ShapedType targetShape =
       shardShapedType(sourceUnshardedShape, mesh, targetSharding);
-  TypedValue<ShapedType> targetShard =
-      builder.create<tensor::CastOp>(targetShape, allToAllResult)
-          .getResult()
-          .cast<TypedValue<ShapedType>>();
+  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
+      builder.create<tensor::CastOp>(targetShape, allToAllResult).getResult());
   return {targetShard, targetSharding};
 }
 
@@ -505,7 +499,7 @@ TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
   ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
   return reshard(
       implicitLocOpBuilder, mesh, source.getShard(), target.getShard(),
-      source.getSrc().cast<TypedValue<ShapedType>>(), sourceShardValue);
+      cast<TypedValue<ShapedType>>(source.getSrc()), sourceShardValue);
 }
 
 TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
@@ -536,7 +530,7 @@ shardedBlockArgumentTypes(Block &block,
   llvm::transform(block.getArguments(), std::back_inserter(res),
                   [&symbolTableCollection](BlockArgument arg) {
                     auto rankedTensorArg =
-                        arg.dyn_cast<TypedValue<RankedTensorType>>();
+                        dyn_cast<TypedValue<RankedTensorType>>(arg);
                     if (!rankedTensorArg) {
                       return arg.getType();
                     }
@@ -587,7 +581,7 @@ static SmallVector<MeshShardingAttr> getOperandShardings(Operation &op) {
   res.reserve(op.getNumOperands());
   llvm::transform(op.getOperands(), std::back_inserter(res), [](Value operand) {
     TypedValue<RankedTensorType> rankedTensor =
-        operand.dyn_cast<TypedValue<RankedTensorType>>();
+        dyn_cast<TypedValue<RankedTensorType>>(operand);
     if (!rankedTensor) {
       return MeshShardingAttr();
     }
@@ -608,7 +602,7 @@ static SmallVector<MeshShardingAttr> getResultShardings(Operation &op) {
   llvm::transform(op.getResults(), std::back_inserter(res),
                   [](OpResult result) {
                     TypedValue<RankedTensorType> rankedTensor =
-                        result.dyn_cast<TypedValue<RankedTensorType>>();
+                        dyn_cast<TypedValue<RankedTensorType>>(result);
                     if (!rankedTensor) {
                       return MeshShardingAttr();
                     }
@@ -636,9 +630,8 @@ spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
   } else {
     // Insert resharding.
     assert(!srcShardOp.getAnnotateForUsers() && shardOp.getAnnotateForUsers());
-    TypedValue<ShapedType> srcSpmdValue =
-        spmdizationMap.lookup(srcShardOp.getOperand())
-            .cast<TypedValue<ShapedType>>();
+    TypedValue<ShapedType> srcSpmdValue = cast<TypedValue<ShapedType>>(
+        spmdizationMap.lookup(srcShardOp.getOperand()));
     targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
                               symbolTableCollection);
   }
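
One rewrite repeats throughout this file: a trailing
`.getResult().cast<TypedValue<ShapedType>>()` chain becomes a
`cast<TypedValue<ShapedType>>(...)` wrapped around the whole expression, so
the code now reads outside-in. The cast still asserts that the value's type
is a ShapedType, exactly as the member form did. A tiny hypothetical helper
showing the wrapped form:

    // Hypothetical helper, not part of this patch.
    mlir::TypedValue<mlir::ShapedType> asShaped(mlir::Value v) {
      // Asserts v.getType() is a ShapedType, like the old member cast.
      return llvm::cast<mlir::TypedValue<mlir::ShapedType>>(v);
    }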
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
index cb13ee404751ca..932a9fb3acf0a8 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
@@ -202,10 +202,9 @@ createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
                                  ImplicitLocOpBuilder &builder) {
   Operation::result_range meshShape =
       builder.create<mesh::MeshShapeOp>(mesh, axes).getResults();
-  return arith::createProduct(builder, builder.getLoc(),
-                              llvm::to_vector_of<Value>(meshShape),
-                              builder.getIndexType())
-      .cast<TypedValue<IndexType>>();
+  return cast<TypedValue<IndexType>>(arith::createProduct(
+      builder, builder.getLoc(), llvm::to_vector_of<Value>(meshShape),
+      builder.getIndexType()));
 }
 
 TypedValue<IndexType> createProcessLinearIndex(StringRef mesh,
diff --git a/mlir/lib/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.cpp b/mlir/lib/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.cpp
index 5b7ea9360e2211..ca19259ebffa68 100644
--- a/mlir/lib/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.cpp
+++ b/mlir/lib/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.cpp
@@ -25,7 +25,7 @@ DiagnosedSilenceableFailure transform::MatchSparseInOut::matchOperation(
     return emitSilenceableFailure(current->getLoc(),
                                   "operation has no sparse input or output");
   }
-  results.set(getResult().cast<OpResult>(), state.getPayloadOps(getTarget()));
+  results.set(cast<OpResult>(getResult()), state.getPayloadOps(getTarget()));
   return DiagnosedSilenceableFailure::success();
 }
 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index b117c1694e45b8..02375f54d7152f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -476,8 +476,8 @@ struct GenSemiRingSelect : public OpRewritePattern<GenericOp> {
     if (!sel)
       return std::nullopt;
 
-    auto tVal = sel.getTrueValue().dyn_cast<BlockArgument>();
-    auto fVal = sel.getFalseValue().dyn_cast<BlockArgument>();
+    auto tVal = dyn_cast<BlockArgument>(sel.getTrueValue());
+    auto fVal = dyn_cast<BlockArgument>(sel.getFalseValue());
     // TODO: For simplicity, we only handle cases where both true/false value
     // are directly loaded the input tensor. We can probably admit more cases
     // in theory.
@@ -487,7 +487,7 @@ struct GenSemiRingSelect : public OpRewritePattern<GenericOp> {
     // Helper lambda to determine whether the value is loaded from a dense input
     // or is a loop invariant.
     auto isValFromDenseInputOrInvariant = [&op](Value v) -> bool {
-      if (auto bArg = v.dyn_cast<BlockArgument>();
+      if (auto bArg = dyn_cast<BlockArgument>(v);
           bArg && !isSparseTensor(op.getDpsInputOperand(bArg.getArgNumber())))
         return true;
       // If the value is defined outside the loop, it is a loop invariant.
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 0ce40e81371209..4bac390cee8d5a 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -820,7 +820,7 @@ struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
     if (!destOp)
       return failure();
 
-    auto resultIndex = source.cast<OpResult>().getResultNumber();
+    auto resultIndex = cast<OpResult>(source).getResultNumber();
     auto *initOperand = destOp.getDpsInitOperand(resultIndex);
 
     rewriter.modifyOpInPlace(
@@ -4307,7 +4307,7 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
   /// unpack(destinationStyleOp(x)) -> unpack(x)
   if (auto dstStyleOp =
           unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
-    auto destValue = unPackOp.getDest().cast<OpResult>();
+    auto destValue = cast<OpResult>(unPackOp.getDest());
     Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
     rewriter.modifyOpInPlace(unPackOp,
                              [&]() { unPackOp.setDpsInitOperand(0, newDest); });
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index dc19022219e5b2..bbe05eda8f0db3 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1608,7 +1608,7 @@ transform::GetTypeOp::apply(transform::TransformRewriter &rewriter,
     }
     params.push_back(TypeAttr::get(type));
   }
-  results.setParams(getResult().cast<OpResult>(), params);
+  results.setParams(cast<OpResult>(getResult()), params);
   return DiagnosedSilenceableFailure::success();
 }
 
@@ -2210,7 +2210,7 @@ transform::NumAssociationsOp::apply(transform::TransformRewriter &rewriter,
             llvm_unreachable("unknown kind of transform dialect type");
             return 0;
           });
-  results.setParams(getNum().cast<OpResult>(),
+  results.setParams(cast<OpResult>(getNum()),
                     rewriter.getI64IntegerAttr(numAssociations));
   return DiagnosedSilenceableFailure::success();
 }
diff --git a/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp b/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp
index 9b3082a819224f..5e3918f79d1844 100644
--- a/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp
+++ b/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp
@@ -67,12 +67,11 @@ struct TestMeshReshardingRewritePattern : OpRewritePattern<ShardOp> {
       ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
       ShapedType sourceShardShape =
           shardShapedType(op.getResult().getType(), mesh, op.getShard());
-      TypedValue<ShapedType> sourceShard =
+      TypedValue<ShapedType> sourceShard = cast<TypedValue<ShapedType>>(
           builder
               .create<UnrealizedConversionCastOp>(sourceShardShape,
                                                   op.getOperand())
-              ->getResult(0)
-              .cast<TypedValue<ShapedType>>();
+              ->getResult(0));
       TypedValue<ShapedType> targetShard =
           reshard(builder, mesh, op, targetShardOp, sourceShard);
       Value newTargetUnsharded =


