[Mlir-commits] [llvm] [mlir] Mark `mlir::Value::isa/dyn_cast/cast/...` member functions deprecated. (PR #89238)
Christian Sigg
llvmlistbot at llvm.org
Thu Apr 18 07:54:09 PDT 2024
https://github.com/chsigg updated https://github.com/llvm/llvm-project/pull/89238
>From 1b31e44800498e864cff1c59e28f4a13b9066a88 Mon Sep 17 00:00:00 2001
From: Christian Sigg <csigg at google.com>
Date: Thu, 18 Apr 2024 16:53:48 +0200
Subject: [PATCH] Mark `isa/dyn_cast/cast/...` member functions deprecated.
See https://mlir.llvm.org/deprecation
---
llvm/include/llvm/ADT/TypeSwitch.h | 3 +-
.../transform/Ch4/lib/MyExtension.cpp | 2 +-
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h | 6 ++--
mlir/include/mlir/IR/Value.h | 6 +++-
.../IR/BufferDeallocationOpInterface.cpp | 6 ++--
.../Linalg/TransformOps/LinalgMatchOps.cpp | 24 ++++++-------
.../TransformOps/MemRefTransformOps.cpp | 4 +--
.../Mesh/Interfaces/ShardingInterface.cpp | 2 +-
.../Dialect/Mesh/Transforms/Spmdization.cpp | 35 ++++++++-----------
.../Dialect/Mesh/Transforms/Transforms.cpp | 7 ++--
.../TransformOps/SparseTensorTransformOps.cpp | 2 +-
.../Transforms/SparseTensorRewriting.cpp | 6 ++--
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 4 +--
.../lib/Dialect/Transform/IR/TransformOps.cpp | 4 +--
.../Mesh/TestReshardingSpmdization.cpp | 5 ++-
15 files changed, 55 insertions(+), 61 deletions(-)
diff --git a/llvm/include/llvm/ADT/TypeSwitch.h b/llvm/include/llvm/ADT/TypeSwitch.h
index 10a2d48e918db9..2495855334eef9 100644
--- a/llvm/include/llvm/ADT/TypeSwitch.h
+++ b/llvm/include/llvm/ADT/TypeSwitch.h
@@ -64,8 +64,7 @@ template <typename DerivedT, typename T> class TypeSwitchBase {
/// Trait to check whether `ValueT` provides a 'dyn_cast' method with type
/// `CastT`.
template <typename ValueT, typename CastT>
- using has_dyn_cast_t =
- decltype(std::declval<ValueT &>().template dyn_cast<CastT>());
+ using has_dyn_cast_t = decltype(dyn_cast<CastT>(std::declval<ValueT &>()));
/// Attempt to dyn_cast the given `value` to `CastT`. This overload is
/// selected if `value` already has a suitable dyn_cast method.
diff --git a/mlir/examples/transform/Ch4/lib/MyExtension.cpp b/mlir/examples/transform/Ch4/lib/MyExtension.cpp
index 26e348f2a30ec6..83e2dcd750bb39 100644
--- a/mlir/examples/transform/Ch4/lib/MyExtension.cpp
+++ b/mlir/examples/transform/Ch4/lib/MyExtension.cpp
@@ -142,7 +142,7 @@ mlir::transform::HasOperandSatisfyingOp::apply(
transform::detail::prepareValueMappings(
yieldedMappings, getBody().front().getTerminator()->getOperands(),
state);
- results.setParams(getPosition().cast<OpResult>(),
+ results.setParams(cast<OpResult>(getPosition()),
{rewriter.getI32IntegerAttr(operand.getOperandNumber())});
for (auto &&[result, mapping] : llvm::zip(getResults(), yieldedMappings))
results.setMappedValues(result, mapping);
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
index a9bc3351f4cff0..ec3c2cb011c357 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -60,11 +60,11 @@ class MulOperandsAndResultElementType
if (llvm::isa<FloatType>(resElemType))
return impl::verifySameOperandsAndResultElementType(op);
- if (auto resIntType = resElemType.dyn_cast<IntegerType>()) {
+ if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
IntegerType lhsIntType =
- getElementTypeOrSelf(op->getOperand(0)).cast<IntegerType>();
+ cast<IntegerType>(getElementTypeOrSelf(op->getOperand(0)));
IntegerType rhsIntType =
- getElementTypeOrSelf(op->getOperand(1)).cast<IntegerType>();
+ cast<IntegerType>(getElementTypeOrSelf(op->getOperand(1)));
if (lhsIntType != rhsIntType)
return op->emitOpError(
"requires the same element type for all operands");
diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h
index a74d0faa1dfc4b..cdbc6cc374368c 100644
--- a/mlir/include/mlir/IR/Value.h
+++ b/mlir/include/mlir/IR/Value.h
@@ -98,21 +98,25 @@ class Value {
constexpr Value(detail::ValueImpl *impl = nullptr) : impl(impl) {}
template <typename U>
+ [[deprecated("Use mlir::isa<U>() instead")]]
bool isa() const {
return llvm::isa<U>(*this);
}
template <typename U>
+ [[deprecated("Use mlir::dyn_cast<U>() instead")]]
U dyn_cast() const {
return llvm::dyn_cast<U>(*this);
}
template <typename U>
+ [[deprecated("Use mlir::dyn_cast_or_null<U>() instead")]]
U dyn_cast_or_null() const {
- return llvm::dyn_cast_if_present<U>(*this);
+ return llvm::dyn_cast_or_null<U>(*this);
}
template <typename U>
+ [[deprecated("Use mlir::cast<U>() instead")]]
U cast() const {
return llvm::cast<U>(*this);
}
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
index a5ea42b7d701d0..b197786c320548 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
@@ -39,7 +39,7 @@ static Value buildBoolValue(OpBuilder &builder, Location loc, bool value) {
return builder.create<arith::ConstantOp>(loc, builder.getBoolAttr(value));
}
-static bool isMemref(Value v) { return v.getType().isa<BaseMemRefType>(); }
+static bool isMemref(Value v) { return isa<BaseMemRefType>(v.getType()); }
//===----------------------------------------------------------------------===//
// Ownership
@@ -222,8 +222,8 @@ bool ValueComparator::operator()(const Value &lhs, const Value &rhs) const {
return false;
// Block arguments are less than results.
- bool lhsIsBBArg = lhs.isa<BlockArgument>();
- if (lhsIsBBArg != rhs.isa<BlockArgument>()) {
+ bool lhsIsBBArg = isa<BlockArgument>(lhs);
+ if (lhsIsBBArg != isa<BlockArgument>(rhs)) {
return lhsIsBBArg;
}
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
index 3e85559e1ec0c6..dc77014abeb27f 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
@@ -259,11 +259,11 @@ transform::MatchStructuredClassifyContractionDimsOp::matchOperation(
return builder.getI64IntegerAttr(value);
}));
};
- results.setParams(getBatch().cast<OpResult>(),
+ results.setParams(cast<OpResult>(getBatch()),
makeI64Attrs(contractionDims->batch));
- results.setParams(getM().cast<OpResult>(), makeI64Attrs(contractionDims->m));
- results.setParams(getN().cast<OpResult>(), makeI64Attrs(contractionDims->n));
- results.setParams(getK().cast<OpResult>(), makeI64Attrs(contractionDims->k));
+ results.setParams(cast<OpResult>(getM()), makeI64Attrs(contractionDims->m));
+ results.setParams(cast<OpResult>(getN()), makeI64Attrs(contractionDims->n));
+ results.setParams(cast<OpResult>(getK()), makeI64Attrs(contractionDims->k));
return DiagnosedSilenceableFailure::success();
}
@@ -288,17 +288,17 @@ transform::MatchStructuredClassifyConvolutionDimsOp::matchOperation(
return builder.getI64IntegerAttr(value);
}));
};
- results.setParams(getBatch().cast<OpResult>(),
+ results.setParams(cast<OpResult>(getBatch()),
makeI64Attrs(convolutionDims->batch));
- results.setParams(getOutputImage().cast<OpResult>(),
+ results.setParams(cast<OpResult>(getOutputImage()),
makeI64Attrs(convolutionDims->outputImage));
- results.setParams(getOutputChannel().cast<OpResult>(),
+ results.setParams(cast<OpResult>(getOutputChannel()),
makeI64Attrs(convolutionDims->outputChannel));
- results.setParams(getFilterLoop().cast<OpResult>(),
+ results.setParams(cast<OpResult>(getFilterLoop()),
makeI64Attrs(convolutionDims->filterLoop));
- results.setParams(getInputChannel().cast<OpResult>(),
+ results.setParams(cast<OpResult>(getInputChannel()),
makeI64Attrs(convolutionDims->inputChannel));
- results.setParams(getDepth().cast<OpResult>(),
+ results.setParams(cast<OpResult>(getDepth()),
makeI64Attrs(convolutionDims->depth));
auto makeI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
@@ -307,9 +307,9 @@ transform::MatchStructuredClassifyConvolutionDimsOp::matchOperation(
return builder.getI64IntegerAttr(value);
}));
};
- results.setParams(getStrides().cast<OpResult>(),
+ results.setParams(cast<OpResult>(getStrides()),
makeI64AttrsFromI64(convolutionDims->strides));
- results.setParams(getDilations().cast<OpResult>(),
+ results.setParams(cast<OpResult>(getDilations()),
makeI64AttrsFromI64(convolutionDims->dilations));
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
index b3481ce1c56bbd..3c9475c2d143a6 100644
--- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
+++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
@@ -173,8 +173,8 @@ transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter,
}
// Assemble results.
- results.set(getGlobal().cast<OpResult>(), globalOps);
- results.set(getGetGlobal().cast<OpResult>(), getGlobalOps);
+ results.set(cast<OpResult>(getGlobal()), globalOps);
+ results.set(cast<OpResult>(getGetGlobal()), getGlobalOps);
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index 9acee5aa8d8604..ddb50130c6b82f 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -97,7 +97,7 @@ checkOperandAffineExpr(AffineExpr expr, unsigned numDims) {
FailureOr<std::pair<bool, MeshShardingAttr>>
mesh::getMeshShardingAttr(OpResult result) {
- Value val = result.cast<Value>();
+ Value val = cast<Value>(result);
bool anyShardedForDef = llvm::any_of(val.getUsers(), [](Operation *user) {
auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
if (!shardOp)
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index e4868435135ed1..31297ece6c57f0 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -86,14 +86,13 @@ handlePartialAxesDuringResharding(OpBuilder &builder,
}
builder.setInsertionPointAfterValue(sourceShard);
- TypedValue<ShapedType> resultValue =
+ TypedValue<ShapedType> resultValue = cast<TypedValue<ShapedType>>(
builder
.create<AllReduceOp>(sourceShard.getLoc(), sourceShard.getType(),
sourceSharding.getMesh().getLeafReference(),
allReduceMeshAxes, sourceShard,
sourceSharding.getPartialType())
- .getResult()
- .cast<TypedValue<ShapedType>>();
+ .getResult());
llvm::SmallVector<MeshAxis> remainingPartialAxes;
llvm::copy_if(sourceShardingPartialAxesSet,
@@ -135,13 +134,12 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
MeshShardingAttr sourceSharding,
TypedValue<ShapedType> sourceShard, MeshOp mesh,
int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
- TypedValue<ShapedType> targetShard =
+ TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
builder
.create<AllSliceOp>(sourceShard, mesh,
ArrayRef<MeshAxis>(splitMeshAxis),
splitTensorAxis)
- .getResult()
- .cast<TypedValue<ShapedType>>();
+ .getResult());
MeshShardingAttr targetSharding = targetShardingInSplitLastAxis(
builder.getContext(), sourceSharding, splitTensorAxis, splitMeshAxis);
return {targetShard, targetSharding};
@@ -278,10 +276,8 @@ unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
APInt(64, splitTensorAxis));
ShapedType targetShape =
shardShapedType(sourceUnshardedShape, mesh, targetSharding);
- TypedValue<ShapedType> targetShard =
- builder.create<tensor::CastOp>(targetShape, allGatherResult)
- .getResult()
- .cast<TypedValue<ShapedType>>();
+ TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
+ builder.create<tensor::CastOp>(targetShape, allGatherResult).getResult());
return {targetShard, targetSharding};
}
@@ -413,10 +409,8 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
ShapedType targetShape =
shardShapedType(sourceUnshardedShape, mesh, targetSharding);
- TypedValue<ShapedType> targetShard =
- builder.create<tensor::CastOp>(targetShape, allToAllResult)
- .getResult()
- .cast<TypedValue<ShapedType>>();
+ TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
+ builder.create<tensor::CastOp>(targetShape, allToAllResult).getResult());
return {targetShard, targetSharding};
}
@@ -505,7 +499,7 @@ TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
return reshard(
implicitLocOpBuilder, mesh, source.getShard(), target.getShard(),
- source.getSrc().cast<TypedValue<ShapedType>>(), sourceShardValue);
+ cast<TypedValue<ShapedType>>(source.getSrc()), sourceShardValue);
}
TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
@@ -536,7 +530,7 @@ shardedBlockArgumentTypes(Block &block,
llvm::transform(block.getArguments(), std::back_inserter(res),
[&symbolTableCollection](BlockArgument arg) {
auto rankedTensorArg =
- arg.dyn_cast<TypedValue<RankedTensorType>>();
+ dyn_cast<TypedValue<RankedTensorType>>(arg);
if (!rankedTensorArg) {
return arg.getType();
}
@@ -587,7 +581,7 @@ static SmallVector<MeshShardingAttr> getOperandShardings(Operation &op) {
res.reserve(op.getNumOperands());
llvm::transform(op.getOperands(), std::back_inserter(res), [](Value operand) {
TypedValue<RankedTensorType> rankedTensor =
- operand.dyn_cast<TypedValue<RankedTensorType>>();
+ dyn_cast<TypedValue<RankedTensorType>>(operand);
if (!rankedTensor) {
return MeshShardingAttr();
}
@@ -608,7 +602,7 @@ static SmallVector<MeshShardingAttr> getResultShardings(Operation &op) {
llvm::transform(op.getResults(), std::back_inserter(res),
[](OpResult result) {
TypedValue<RankedTensorType> rankedTensor =
- result.dyn_cast<TypedValue<RankedTensorType>>();
+ dyn_cast<TypedValue<RankedTensorType>>(result);
if (!rankedTensor) {
return MeshShardingAttr();
}
@@ -636,9 +630,8 @@ spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
} else {
// Insert resharding.
assert(!srcShardOp.getAnnotateForUsers() && shardOp.getAnnotateForUsers());
- TypedValue<ShapedType> srcSpmdValue =
- spmdizationMap.lookup(srcShardOp.getOperand())
- .cast<TypedValue<ShapedType>>();
+ TypedValue<ShapedType> srcSpmdValue = cast<TypedValue<ShapedType>>(
+ spmdizationMap.lookup(srcShardOp.getOperand()));
targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
symbolTableCollection);
}
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
index cb13ee404751ca..932a9fb3acf0a8 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
@@ -202,10 +202,9 @@ createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
ImplicitLocOpBuilder &builder) {
Operation::result_range meshShape =
builder.create<mesh::MeshShapeOp>(mesh, axes).getResults();
- return arith::createProduct(builder, builder.getLoc(),
- llvm::to_vector_of<Value>(meshShape),
- builder.getIndexType())
- .cast<TypedValue<IndexType>>();
+ return cast<TypedValue<IndexType>>(arith::createProduct(
+ builder, builder.getLoc(), llvm::to_vector_of<Value>(meshShape),
+ builder.getIndexType()));
}
TypedValue<IndexType> createProcessLinearIndex(StringRef mesh,
diff --git a/mlir/lib/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.cpp b/mlir/lib/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.cpp
index 5b7ea9360e2211..ca19259ebffa68 100644
--- a/mlir/lib/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.cpp
+++ b/mlir/lib/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.cpp
@@ -25,7 +25,7 @@ DiagnosedSilenceableFailure transform::MatchSparseInOut::matchOperation(
return emitSilenceableFailure(current->getLoc(),
"operation has no sparse input or output");
}
- results.set(getResult().cast<OpResult>(), state.getPayloadOps(getTarget()));
+ results.set(cast<OpResult>(getResult()), state.getPayloadOps(getTarget()));
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index b117c1694e45b8..02375f54d7152f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -476,8 +476,8 @@ struct GenSemiRingSelect : public OpRewritePattern<GenericOp> {
if (!sel)
return std::nullopt;
- auto tVal = sel.getTrueValue().dyn_cast<BlockArgument>();
- auto fVal = sel.getFalseValue().dyn_cast<BlockArgument>();
+ auto tVal = dyn_cast<BlockArgument>(sel.getTrueValue());
+ auto fVal = dyn_cast<BlockArgument>(sel.getFalseValue());
// TODO: For simplicity, we only handle cases where both true/false value
// are directly loaded the input tensor. We can probably admit more cases
// in theory.
@@ -487,7 +487,7 @@ struct GenSemiRingSelect : public OpRewritePattern<GenericOp> {
// Helper lambda to determine whether the value is loaded from a dense input
// or is a loop invariant.
auto isValFromDenseInputOrInvariant = [&op](Value v) -> bool {
- if (auto bArg = v.dyn_cast<BlockArgument>();
+ if (auto bArg = dyn_cast<BlockArgument>(v);
bArg && !isSparseTensor(op.getDpsInputOperand(bArg.getArgNumber())))
return true;
// If the value is defined outside the loop, it is a loop invariant.
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 0ce40e81371209..4bac390cee8d5a 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -820,7 +820,7 @@ struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
if (!destOp)
return failure();
- auto resultIndex = source.cast<OpResult>().getResultNumber();
+ auto resultIndex = cast<OpResult>(source).getResultNumber();
auto *initOperand = destOp.getDpsInitOperand(resultIndex);
rewriter.modifyOpInPlace(
@@ -4307,7 +4307,7 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
/// unpack(destinationStyleOp(x)) -> unpack(x)
if (auto dstStyleOp =
unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
- auto destValue = unPackOp.getDest().cast<OpResult>();
+ auto destValue = cast<OpResult>(unPackOp.getDest());
Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
rewriter.modifyOpInPlace(unPackOp,
[&]() { unPackOp.setDpsInitOperand(0, newDest); });
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index dc19022219e5b2..bbe05eda8f0db3 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1608,7 +1608,7 @@ transform::GetTypeOp::apply(transform::TransformRewriter &rewriter,
}
params.push_back(TypeAttr::get(type));
}
- results.setParams(getResult().cast<OpResult>(), params);
+ results.setParams(cast<OpResult>(getResult()), params);
return DiagnosedSilenceableFailure::success();
}
@@ -2210,7 +2210,7 @@ transform::NumAssociationsOp::apply(transform::TransformRewriter &rewriter,
llvm_unreachable("unknown kind of transform dialect type");
return 0;
});
- results.setParams(getNum().cast<OpResult>(),
+ results.setParams(cast<OpResult>(getNum()),
rewriter.getI64IntegerAttr(numAssociations));
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp b/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp
index 9b3082a819224f..5e3918f79d1844 100644
--- a/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp
+++ b/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp
@@ -67,12 +67,11 @@ struct TestMeshReshardingRewritePattern : OpRewritePattern<ShardOp> {
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
ShapedType sourceShardShape =
shardShapedType(op.getResult().getType(), mesh, op.getShard());
- TypedValue<ShapedType> sourceShard =
+ TypedValue<ShapedType> sourceShard = cast<TypedValue<ShapedType>>(
builder
.create<UnrealizedConversionCastOp>(sourceShardShape,
op.getOperand())
- ->getResult(0)
- .cast<TypedValue<ShapedType>>();
+ ->getResult(0));
TypedValue<ShapedType> targetShard =
reshard(builder, mesh, op, targetShardOp, sourceShard);
Value newTargetUnsharded =
More information about the Mlir-commits
mailing list