[Mlir-commits] [mlir] Switch member calls to `isa/dyn_cast/cast/...` to free function calls. (PR #89356)
Christian Sigg
llvmlistbot at llvm.org
Fri Apr 19 02:13:13 PDT 2024
https://github.com/chsigg created https://github.com/llvm/llvm-project/pull/89356
This change cleans up call sites. The next step is to mark the member functions deprecated.
See https://mlir.llvm.org/deprecation and https://discourse.llvm.org/t/preferred-casting-style-going-forward.
>From ed7c0058e2dded63667fd6afa85c9f97462453d9 Mon Sep 17 00:00:00 2001
From: Christian Sigg <csigg at google.com>
Date: Fri, 19 Apr 2024 11:12:06 +0200
Subject: [PATCH] Switch member calls to `isa/dyn_cast/cast/...` to free
function calls.
This change cleans up call sites. Next step is to mark the member functions deprecated.
See https://mlir.llvm.org/deprecation and https://discourse.llvm.org/t/preferred-casting-style-going-forward.
---
.../transform/Ch4/lib/MyExtension.cpp | 2 +-
.../Mesh/Interfaces/ShardingInterfaceImpl.h | 6 +-
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h | 6 +-
mlir/include/mlir/IR/Location.h | 2 +-
mlir/lib/CAPI/Dialect/LLVM.cpp | 6 +-
mlir/lib/CAPI/IR/BuiltinTypes.cpp | 4 +-
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 11 ++--
.../ArithToAMDGPU/ArithToAMDGPU.cpp | 12 ++--
.../Conversion/GPUCommon/GPUOpsLowering.cpp | 7 +-
.../GPUCommon/GPUToLLVMConversion.cpp | 2 +-
.../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp | 18 ++---
.../Conversion/TosaToLinalg/TosaToLinalg.cpp | 20 +++---
.../Conversion/VectorToGPU/VectorToGPU.cpp | 4 +-
.../VectorToLLVM/ConvertVectorToLLVM.cpp | 3 +-
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 4 +-
.../BufferizableOpInterfaceImpl.cpp | 2 +-
.../Transforms/EmulateUnsupportedFloats.cpp | 2 +-
.../LowerContractionToSMMLAPattern.cpp | 14 ++--
.../IR/BufferDeallocationOpInterface.cpp | 6 +-
.../IR/BufferizableOpInterface.cpp | 6 +-
.../OwnershipBasedBufferDeallocation.cpp | 2 +-
mlir/lib/Dialect/Complex/IR/ComplexOps.cpp | 2 +-
.../BufferDeallocationOpInterfaceImpl.cpp | 2 +-
mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 20 +++---
.../GPU/TransformOps/GPUTransformOps.cpp | 6 +-
mlir/lib/Dialect/IRDL/IRDLLoading.cpp | 4 +-
.../LLVMIR/IR/BasicPtxBuilderInterface.cpp | 2 +-
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 10 +--
mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp | 10 ++-
.../LLVMIR/Transforms/TypeConsistency.cpp | 2 +-
.../Linalg/TransformOps/LinalgMatchOps.cpp | 28 ++++----
.../TransformOps/LinalgTransformOps.cpp | 6 +-
.../Transforms/ConvertToDestinationStyle.cpp | 9 ++-
.../Transforms/EliminateEmptyTensors.cpp | 3 +-
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 2 +-
.../Dialect/Linalg/Transforms/Transforms.cpp | 4 +-
.../Linalg/Transforms/Vectorization.cpp | 4 +-
.../TransformOps/MemRefTransformOps.cpp | 4 +-
.../MemRef/Transforms/EmulateNarrowType.cpp | 4 +-
.../MemRef/Transforms/ExpandRealloc.cpp | 6 +-
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp | 19 +++---
.../Mesh/Interfaces/ShardingInterface.cpp | 10 +--
.../Dialect/Mesh/Transforms/Spmdization.cpp | 66 ++++++++-----------
.../Dialect/Mesh/Transforms/Transforms.cpp | 9 ++-
.../NVGPU/TransformOps/NVGPUTransformOps.cpp | 8 +--
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 2 +-
.../Dialect/SparseTensor/IR/Detail/Var.cpp | 4 +-
.../SparseTensor/IR/SparseTensorDialect.cpp | 2 +-
.../IR/SparseTensorInterfaces.cpp | 2 +-
.../TransformOps/SparseTensorTransformOps.cpp | 2 +-
.../Transforms/SparseAssembler.cpp | 4 +-
.../Transforms/SparseReinterpretMap.cpp | 2 +-
.../Transforms/SparseTensorRewriting.cpp | 6 +-
.../Transforms/Utils/CodegenUtils.cpp | 2 +-
.../Transforms/Utils/IterationGraphSorter.cpp | 3 +-
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 7 +-
.../Dialect/Tosa/IR/ShardingInterfaceImpl.cpp | 4 +-
.../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 2 +-
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 4 +-
.../Dialect/Tosa/Transforms/TosaFolders.cpp | 2 +-
.../Tosa/Transforms/TosaValidation.cpp | 2 +-
.../DebugExtension/DebugExtensionOps.cpp | 7 +-
.../lib/Dialect/Transform/IR/TransformOps.cpp | 6 +-
.../Dialect/Transform/IR/TransformTypes.cpp | 4 +-
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 3 +-
.../Vector/Transforms/LowerVectorTransfer.cpp | 2 +-
.../Transforms/VectorDropLeadUnitDim.cpp | 2 +-
.../Transforms/VectorEmulateNarrowType.cpp | 2 +-
.../VectorTransferSplitRewritePatterns.cpp | 2 +-
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 2 +-
mlir/lib/IR/AffineMap.cpp | 4 +-
mlir/lib/IR/Operation.cpp | 2 +-
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 25 ++++---
.../Dialect/VCIX/VCIXToLLVMIRTranslation.cpp | 4 +-
.../MathToVCIX/TestMathToVCIXConversion.cpp | 4 +-
.../Mesh/TestReshardingSpmdization.cpp | 5 +-
.../Dialect/Test/TestToLLVMIRTranslation.cpp | 2 +-
mlir/test/lib/IR/TestAffineWalk.cpp | 2 +-
.../lib/IR/TestBuiltinAttributeInterfaces.cpp | 2 +-
mlir/test/lib/Rewrite/TestPDLByteCode.cpp | 2 +-
80 files changed, 241 insertions(+), 270 deletions(-)
diff --git a/mlir/examples/transform/Ch4/lib/MyExtension.cpp b/mlir/examples/transform/Ch4/lib/MyExtension.cpp
index 26e348f2a30ec6..83e2dcd750bb39 100644
--- a/mlir/examples/transform/Ch4/lib/MyExtension.cpp
+++ b/mlir/examples/transform/Ch4/lib/MyExtension.cpp
@@ -142,7 +142,7 @@ mlir::transform::HasOperandSatisfyingOp::apply(
transform::detail::prepareValueMappings(
yieldedMappings, getBody().front().getTerminator()->getOperands(),
state);
- results.setParams(getPosition().cast<OpResult>(),
+ results.setParams(cast<OpResult>(getPosition()),
{rewriter.getI32IntegerAttr(operand.getOperandNumber())});
for (auto &&[result, mapping] : llvm::zip(getResults(), yieldedMappings))
results.setMappedValues(result, mapping);
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
index ab4df2ab028d43..5e4b4f3a66af9d 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
@@ -87,7 +87,7 @@ struct IndependentParallelIteratorDomainShardingInterface
void
populateIteratorTypes(Type t,
SmallVector<utils::IteratorType> &iterTypes) const {
- RankedTensorType rankedTensorType = t.dyn_cast<RankedTensorType>();
+ RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(t);
if (!rankedTensorType) {
return;
}
@@ -106,7 +106,7 @@ struct ElementwiseShardingInterface
ElementwiseShardingInterface<ElemwiseOp>, ElemwiseOp> {
SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
Value val = op->getOperand(0);
- auto type = val.getType().dyn_cast<RankedTensorType>();
+ auto type = dyn_cast<RankedTensorType>(val.getType());
if (!type)
return {};
SmallVector<utils::IteratorType> types(type.getRank(),
@@ -117,7 +117,7 @@ struct ElementwiseShardingInterface
SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
MLIRContext *ctx = op->getContext();
Value val = op->getOperand(0);
- auto type = val.getType().dyn_cast<RankedTensorType>();
+ auto type = dyn_cast<RankedTensorType>(val.getType());
if (!type)
return {};
int64_t rank = type.getRank();
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
index a9bc3351f4cff0..ec3c2cb011c357 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -60,11 +60,11 @@ class MulOperandsAndResultElementType
if (llvm::isa<FloatType>(resElemType))
return impl::verifySameOperandsAndResultElementType(op);
- if (auto resIntType = resElemType.dyn_cast<IntegerType>()) {
+ if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
IntegerType lhsIntType =
- getElementTypeOrSelf(op->getOperand(0)).cast<IntegerType>();
+ cast<IntegerType>(getElementTypeOrSelf(op->getOperand(0)));
IntegerType rhsIntType =
- getElementTypeOrSelf(op->getOperand(1)).cast<IntegerType>();
+ cast<IntegerType>(getElementTypeOrSelf(op->getOperand(1)));
if (lhsIntType != rhsIntType)
return op->emitOpError(
"requires the same element type for all operands");
diff --git a/mlir/include/mlir/IR/Location.h b/mlir/include/mlir/IR/Location.h
index d4268e804f4f7a..aa8314f38cdfac 100644
--- a/mlir/include/mlir/IR/Location.h
+++ b/mlir/include/mlir/IR/Location.h
@@ -154,7 +154,7 @@ class FusedLocWith : public FusedLoc {
/// Support llvm style casting.
static bool classof(Attribute attr) {
auto fusedLoc = llvm::dyn_cast<FusedLoc>(attr);
- return fusedLoc && fusedLoc.getMetadata().isa_and_nonnull<MetadataT>();
+ return fusedLoc && mlir::isa_and_nonnull<MetadataT>(fusedLoc.getMetadata());
}
};
diff --git a/mlir/lib/CAPI/Dialect/LLVM.cpp b/mlir/lib/CAPI/Dialect/LLVM.cpp
index 4669c40f843d94..21c66f38a8af03 100644
--- a/mlir/lib/CAPI/Dialect/LLVM.cpp
+++ b/mlir/lib/CAPI/Dialect/LLVM.cpp
@@ -135,7 +135,7 @@ MlirAttribute mlirLLVMDIExpressionAttrGet(MlirContext ctx, intptr_t nOperations,
unwrap(ctx),
llvm::map_to_vector(
unwrapList(nOperations, operations, attrStorage),
- [](Attribute a) { return a.cast<DIExpressionElemAttr>(); })));
+ [](Attribute a) { return cast<DIExpressionElemAttr>(a); })));
}
MlirAttribute mlirLLVMDINullTypeAttrGet(MlirContext ctx) {
@@ -165,7 +165,7 @@ MlirAttribute mlirLLVMDICompositeTypeAttrGet(
cast<DIScopeAttr>(unwrap(scope)), cast<DITypeAttr>(unwrap(baseType)),
DIFlags(flags), sizeInBits, alignInBits,
llvm::map_to_vector(unwrapList(nElements, elements, elementsStorage),
- [](Attribute a) { return a.cast<DINodeAttr>(); })));
+ [](Attribute a) { return cast<DINodeAttr>(a); })));
}
MlirAttribute
@@ -259,7 +259,7 @@ MlirAttribute mlirLLVMDISubroutineTypeAttrGet(MlirContext ctx,
return wrap(DISubroutineTypeAttr::get(
unwrap(ctx), callingConvention,
llvm::map_to_vector(unwrapList(nTypes, types, attrStorage),
- [](Attribute a) { return a.cast<DITypeAttr>(); })));
+ [](Attribute a) { return cast<DITypeAttr>(a); })));
}
MlirAttribute mlirLLVMDISubprogramAttrGet(
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index e1a5d82587cf9e..c94c070144a7e9 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -311,11 +311,11 @@ MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
}
bool mlirVectorTypeIsScalable(MlirType type) {
- return unwrap(type).cast<VectorType>().isScalable();
+ return cast<VectorType>(unwrap(type)).isScalable();
}
bool mlirVectorTypeIsDimScalable(MlirType type, intptr_t dim) {
- return unwrap(type).cast<VectorType>().getScalableDims()[dim];
+ return cast<VectorType>(unwrap(type)).getScalableDims()[dim];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 7e073bae75c0c9..033e66c6118f30 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -371,7 +371,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
bool isUnsigned, Value llvmInput,
SmallVector<Value, 4> &operands) {
Type inputType = llvmInput.getType();
- auto vectorType = inputType.dyn_cast<VectorType>();
+ auto vectorType = dyn_cast<VectorType>(inputType);
Type elemType = vectorType.getElementType();
if (elemType.isBF16())
@@ -414,7 +414,7 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
Value output, int32_t subwordOffset,
bool clamp, SmallVector<Value, 4> &operands) {
Type inputType = output.getType();
- auto vectorType = inputType.dyn_cast<VectorType>();
+ auto vectorType = dyn_cast<VectorType>(inputType);
Type elemType = vectorType.getElementType();
if (elemType.isBF16())
output = rewriter.create<LLVM::BitcastOp>(
@@ -569,9 +569,8 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
/// on the architecture you are compiling for.
static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
Chipset chipset) {
-
- auto sourceVectorType = wmma.getSourceA().getType().dyn_cast<VectorType>();
- auto destVectorType = wmma.getDestC().getType().dyn_cast<VectorType>();
+ auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType());
+ auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType());
auto elemSourceType = sourceVectorType.getElementType();
auto elemDestType = destVectorType.getElementType();
@@ -727,7 +726,7 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
Type f32 = getTypeConverter()->convertType(op.getResult().getType());
Value source = adaptor.getSource();
- auto sourceVecType = op.getSource().getType().dyn_cast<VectorType>();
+ auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
Type sourceElemType = getElementTypeOrSelf(op.getSource());
// Extend to a v4i8
if (!sourceVecType || sourceVecType.getNumElements() < 4) {
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 0113a3df0b8e3d..3d3ff001c541b5 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -65,7 +65,7 @@ static Value castF32To(Type elementType, Value f32, Location loc,
LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
Type inType = op.getIn().getType();
- if (auto inVecType = inType.dyn_cast<VectorType>()) {
+ if (auto inVecType = dyn_cast<VectorType>(inType)) {
if (inVecType.isScalable())
return failure();
if (inVecType.getShape().size() > 1)
@@ -81,13 +81,13 @@ void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
Location loc = op.getLoc();
Value in = op.getIn();
Type outElemType = getElementTypeOrSelf(op.getOut().getType());
- if (!in.getType().isa<VectorType>()) {
+ if (!isa<VectorType>(in.getType())) {
Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
loc, rewriter.getF32Type(), in, 0);
Value result = castF32To(outElemType, asFloat, loc, rewriter);
return rewriter.replaceOp(op, result);
}
- VectorType inType = in.getType().cast<VectorType>();
+ VectorType inType = cast<VectorType>(in.getType());
int64_t numElements = inType.getNumElements();
Value zero = rewriter.create<arith::ConstantOp>(
loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
@@ -179,7 +179,7 @@ LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
if (op.getRoundingmodeAttr())
return failure();
Type outType = op.getOut().getType();
- if (auto outVecType = outType.dyn_cast<VectorType>()) {
+ if (auto outVecType = dyn_cast<VectorType>(outType)) {
if (outVecType.isScalable())
return failure();
if (outVecType.getShape().size() > 1)
@@ -202,7 +202,7 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
if (saturateFP8)
in = clampInput(rewriter, loc, outElemType, in);
VectorType truncResType = VectorType::get(4, outElemType);
- if (!in.getType().isa<VectorType>()) {
+ if (!isa<VectorType>(in.getType())) {
Value asFloat = castToF32(in, loc, rewriter);
Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
@@ -210,7 +210,7 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
Value result = rewriter.create<vector::ExtractOp>(loc, asF8s, 0);
return rewriter.replaceOp(op, result);
}
- VectorType outType = op.getOut().getType().cast<VectorType>();
+ VectorType outType = cast<VectorType>(op.getOut().getType());
int64_t numElements = outType.getNumElements();
Value zero = rewriter.create<arith::ConstantOp>(
loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 993c09b03c0fde..36e10372e4bc5b 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -214,7 +214,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
llvm::enumerate(gpuFuncOp.getArgumentTypes())) {
auto remapping = signatureConversion.getInputMapping(idx);
NamedAttrList argAttr =
- argAttrs ? argAttrs[idx].cast<DictionaryAttr>() : NamedAttrList();
+ argAttrs ? cast<DictionaryAttr>(argAttrs[idx]) : NamedAttrList();
auto copyAttribute = [&](StringRef attrName) {
Attribute attr = argAttr.erase(attrName);
if (!attr)
@@ -234,9 +234,8 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
return;
}
for (size_t i = 0, e = remapping->size; i < e; ++i) {
- if (llvmFuncOp.getArgument(remapping->inputNo + i)
- .getType()
- .isa<LLVM::LLVMPointerType>()) {
+ if (isa<LLVM::LLVMPointerType>(
+ llvmFuncOp.getArgument(remapping->inputNo + i).getType())) {
llvmFuncOp.setArgAttr(remapping->inputNo + i, attrName, attr);
}
}
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 78d4e806246872..3a4fc7d8063f40 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -668,7 +668,7 @@ static int32_t getCuSparseLtDataTypeFrom(Type type) {
static int32_t getCuSparseDataTypeFrom(Type type) {
if (llvm::isa<ComplexType>(type)) {
// get the element type
- auto elementType = type.cast<ComplexType>().getElementType();
+ auto elementType = cast<ComplexType>(type).getElementType();
if (elementType.isBF16())
return 15; // CUDA_C_16BF
if (elementType.isF16())
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 9b5d19ebd783a9..11d29754aa760e 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -1579,7 +1579,7 @@ struct NVGPUWarpgroupMmaStoreOpLowering
if (offset)
ti = makeAdd(ti, makeConst(offset));
- auto structType = matrixD.getType().cast<LLVM::LLVMStructType>();
+ auto structType = cast<LLVM::LLVMStructType>(matrixD.getType());
// Number of 32-bit registers owns per thread
constexpr unsigned numAdjacentRegisters = 2;
@@ -1606,9 +1606,9 @@ struct NVGPUWarpgroupMmaStoreOpLowering
int offset = 0;
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value matriDValue = adaptor.getMatrixD();
- auto stype = matriDValue.getType().cast<LLVM::LLVMStructType>();
+ auto stype = cast<LLVM::LLVMStructType>(matriDValue.getType());
for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) {
- auto structType = matrixD.cast<LLVM::LLVMStructType>();
+ auto structType = cast<LLVM::LLVMStructType>(matrixD);
Value innerStructValue = b.create<LLVM::ExtractValueOp>(matriDValue, idx);
storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
offset += structType.getBody().size();
@@ -1626,13 +1626,9 @@ struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
- LLVM::LLVMStructType packStructType =
- getTypeConverter()
- ->convertType(op.getMatrixC().getType())
- .cast<LLVM::LLVMStructType>();
- Type elemType = packStructType.getBody()
- .front()
- .cast<LLVM::LLVMStructType>()
+ LLVM::LLVMStructType packStructType = cast<LLVM::LLVMStructType>(
+ getTypeConverter()->convertType(op.getMatrixC().getType()));
+ Type elemType = cast<LLVM::LLVMStructType>(packStructType.getBody().front())
.getBody()
.front();
Value zero = b.create<LLVM::ConstantOp>(elemType, b.getZeroAttr(elemType));
@@ -1640,7 +1636,7 @@ struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
SmallVector<Value> innerStructs;
// Unpack the structs and set all values to zero
for (auto [idx, s] : llvm::enumerate(packStructType.getBody())) {
- auto structType = s.cast<LLVM::LLVMStructType>();
+ auto structType = cast<LLVM::LLVMStructType>(s);
Value structValue = b.create<LLVM::ExtractValueOp>(packStruct, idx);
for (unsigned i = 0; i < structType.getBody().size(); ++i) {
structValue = b.create<LLVM::InsertValueOp>(
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index ef8d59c9b26082..b6b85cab5a3821 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -618,7 +618,7 @@ static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
static SmallVector<Value> expandInputRanks(PatternRewriter &rewriter,
Location loc, Operation *operation) {
auto rank =
- operation->getResultTypes().front().cast<RankedTensorType>().getRank();
+ cast<RankedTensorType>(operation->getResultTypes().front()).getRank();
return llvm::map_to_vector(operation->getOperands(), [&](Value operand) {
return expandRank(rewriter, loc, operand, rank);
});
@@ -680,7 +680,7 @@ computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool,
// dimension, that is the target size. An occurrence of an additional static
// dimension greater than 1 with a different value is undefined behavior.
for (auto operand : operands) {
- auto size = operand.getType().cast<RankedTensorType>().getDimSize(dim);
+ auto size = cast<RankedTensorType>(operand.getType()).getDimSize(dim);
if (!ShapedType::isDynamic(size) && size > 1)
return {rewriter.getIndexAttr(size), operand};
}
@@ -688,7 +688,7 @@ computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool,
// Filter operands with dynamic dimension
auto operandsWithDynamicDim =
llvm::to_vector(llvm::make_filter_range(operands, [&](Value operand) {
- return operand.getType().cast<RankedTensorType>().isDynamicDim(dim);
+ return cast<RankedTensorType>(operand.getType()).isDynamicDim(dim);
}));
// If no operand has a dynamic dimension, it means all sizes were 1
@@ -718,7 +718,7 @@ static std::pair<SmallVector<OpFoldResult>, SmallVector<Value>>
computeTargetShape(PatternRewriter &rewriter, Location loc,
IndexPool &indexPool, ValueRange operands) {
assert(!operands.empty());
- auto rank = operands.front().getType().cast<RankedTensorType>().getRank();
+ auto rank = cast<RankedTensorType>(operands.front().getType()).getRank();
SmallVector<OpFoldResult> targetShape;
SmallVector<Value> masterOperands;
for (auto dim : llvm::seq<int64_t>(0, rank)) {
@@ -735,7 +735,7 @@ static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
int64_t dim, OpFoldResult targetSize,
Value masterOperand) {
// Nothing to do if this is a static dimension
- auto rankedTensorType = operand.getType().cast<RankedTensorType>();
+ auto rankedTensorType = cast<RankedTensorType>(operand.getType());
if (!rankedTensorType.isDynamicDim(dim))
return operand;
@@ -817,7 +817,7 @@ static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
IndexPool &indexPool, Value operand,
ArrayRef<OpFoldResult> targetShape,
ArrayRef<Value> masterOperands) {
- int64_t rank = operand.getType().cast<RankedTensorType>().getRank();
+ int64_t rank = cast<RankedTensorType>(operand.getType()).getRank();
assert((int64_t)targetShape.size() == rank);
assert((int64_t)masterOperands.size() == rank);
for (auto index : llvm::seq<int64_t>(0, rank))
@@ -848,8 +848,7 @@ emitElementwiseComputation(PatternRewriter &rewriter, Location loc,
Operation *operation, ValueRange operands,
ArrayRef<OpFoldResult> targetShape) {
// Generate output tensor
- auto resultType =
- operation->getResultTypes().front().cast<RankedTensorType>();
+ auto resultType = cast<RankedTensorType>(operation->getResultTypes().front());
Value outputTensor = rewriter.create<tensor::EmptyOp>(
loc, targetShape, resultType.getElementType());
@@ -2274,8 +2273,7 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
llvm::SmallVector<int64_t, 3> staticSizes;
dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);
- auto elementType =
- input.getType().cast<RankedTensorType>().getElementType();
+ auto elementType = cast<RankedTensorType>(input.getType()).getElementType();
return RankedTensorType::get(staticSizes, elementType);
}
@@ -2327,7 +2325,7 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
auto loc = rfft2d.getLoc();
auto input = rfft2d.getInput();
auto elementType =
- input.getType().cast<ShapedType>().getElementType().cast<FloatType>();
+ cast<FloatType>(cast<ShapedType>(input.getType()).getElementType());
// Compute the output type and set of dynamic sizes
llvm::SmallVector<Value> dynamicSizes;
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 399c0450824ee5..3f92372d7cea98 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -1204,10 +1204,10 @@ convertElementwiseOp(RewriterBase &rewriter, Operation *op,
return rewriter.notifyMatchFailure(op, "no mapping");
matrixOperands.push_back(it->second);
}
- auto resultType = matrixOperands[0].getType().cast<gpu::MMAMatrixType>();
+ auto resultType = cast<gpu::MMAMatrixType>(matrixOperands[0].getType());
if (opType == gpu::MMAElementwiseOp::EXTF) {
// The floating point extension case has a different result type.
- auto vectorType = op->getResultTypes()[0].cast<VectorType>();
+ auto vectorType = cast<VectorType>(op->getResultTypes()[0]);
resultType = gpu::MMAMatrixType::get(resultType.getShape(),
vectorType.getElementType(),
resultType.getOperand());
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 85d10f326e260e..1b9975237c699b 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -631,8 +631,7 @@ static Value createMaskNeutralValue(ConversionPatternRewriter &rewriter,
Type vectorType) {
const auto &floatSemantics = cast<FloatType>(llvmType).getFloatSemantics();
auto value = getMaskNeutralValue(MaskNeutral{}, floatSemantics);
- auto denseValue =
- DenseElementsAttr::get(vectorType.cast<ShapedType>(), value);
+ auto denseValue = DenseElementsAttr::get(cast<ShapedType>(vectorType), value);
return rewriter.create<LLVM::ConstantOp>(loc, vectorType, denseValue);
}
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 2575ad4984814b..e3beceaa3bbb5b 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -227,8 +227,8 @@ LogicalResult WMMAOp::verify() {
Type sourceAType = getSourceA().getType();
Type destType = getDestC().getType();
- VectorType sourceVectorAType = sourceAType.dyn_cast<VectorType>();
- VectorType destVectorType = destType.dyn_cast<VectorType>();
+ VectorType sourceVectorAType = dyn_cast<VectorType>(sourceAType);
+ VectorType destVectorType = dyn_cast<VectorType>(destType);
Type sourceAElemType = sourceVectorAType.getElementType();
Type destElemType = destVectorType.getElementType();
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index d7492c9e25db31..5e69a98db8f1ee 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -26,7 +26,7 @@ struct ConstantOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
auto constantOp = cast<arith::ConstantOp>(op);
- auto type = constantOp.getType().dyn_cast<RankedTensorType>();
+ auto type = dyn_cast<RankedTensorType>(constantOp.getType());
// Only ranked tensors are supported.
if (!type)
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index b9ab95b92496e3..4a50da3513f99b 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -106,7 +106,7 @@ void mlir::arith::populateEmulateUnsupportedFloatsConversions(
targetType](Type type) -> std::optional<Type> {
if (llvm::is_contained(sourceTypes, type))
return targetType;
- if (auto shaped = type.dyn_cast<ShapedType>())
+ if (auto shaped = dyn_cast<ShapedType>(type))
if (llvm::is_contained(sourceTypes, shaped.getElementType()))
return shaped.clone(targetType);
// All other types legal
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
index 3ae894692089b3..7e390aa551972a 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
@@ -99,7 +99,7 @@ class LowerContractionToSMMLAPattern
Value extsiLhs;
Value extsiRhs;
if (auto lhsExtInType =
- origLhsExtOp.getIn().getType().dyn_cast<mlir::VectorType>()) {
+ dyn_cast<mlir::VectorType>(origLhsExtOp.getIn().getType())) {
if (lhsExtInType.getElementTypeBitWidth() <= 8) {
Type targetLhsExtTy =
matchContainerType(rewriter.getI8Type(), lhsExtInType);
@@ -108,7 +108,7 @@ class LowerContractionToSMMLAPattern
}
}
if (auto rhsExtInType =
- origRhsExtOp.getIn().getType().dyn_cast<mlir::VectorType>()) {
+ dyn_cast<mlir::VectorType>(origRhsExtOp.getIn().getType())) {
if (rhsExtInType.getElementTypeBitWidth() <= 8) {
Type targetRhsExtTy =
matchContainerType(rewriter.getI8Type(), rhsExtInType);
@@ -161,9 +161,9 @@ class LowerContractionToSMMLAPattern
extractOperand(op.getAcc(), accPermutationMap, accOffsets);
auto inputElementType =
- tiledLhs.getType().cast<ShapedType>().getElementType();
+ cast<ShapedType>(tiledLhs.getType()).getElementType();
auto accElementType =
- tiledAcc.getType().cast<ShapedType>().getElementType();
+ cast<ShapedType>(tiledAcc.getType()).getElementType();
auto inputExpandedType = VectorType::get({2, 8}, inputElementType);
auto outputExpandedType = VectorType::get({2, 2}, accElementType);
@@ -175,9 +175,9 @@ class LowerContractionToSMMLAPattern
auto emptyOperand = rewriter.create<arith::ConstantOp>(
loc, expandedTypeType, rewriter.getZeroAttr(expandedTypeType));
SmallVector<int64_t> offsets(
- emptyOperand.getType().cast<ShapedType>().getRank(), 0);
+ cast<ShapedType>(emptyOperand.getType()).getRank(), 0);
SmallVector<int64_t> strides(
- tiledOperand.getType().cast<ShapedType>().getRank(), 1);
+ cast<ShapedType>(tiledOperand.getType()).getRank(), 1);
return rewriter.createOrFold<vector::InsertStridedSliceOp>(
loc, tiledOperand, emptyOperand, offsets, strides);
};
@@ -214,7 +214,7 @@ class LowerContractionToSMMLAPattern
// Insert the tiled result back into the non tiled result of the
// contract op.
SmallVector<int64_t> strides(
- tiledRes.getType().cast<ShapedType>().getRank(), 1);
+ cast<ShapedType>(tiledRes.getType()).getRank(), 1);
result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
loc, tiledRes, result, accOffsets, strides);
}
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
index a5ea42b7d701d0..b197786c320548 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
@@ -39,7 +39,7 @@ static Value buildBoolValue(OpBuilder &builder, Location loc, bool value) {
return builder.create<arith::ConstantOp>(loc, builder.getBoolAttr(value));
}
-static bool isMemref(Value v) { return v.getType().isa<BaseMemRefType>(); }
+static bool isMemref(Value v) { return isa<BaseMemRefType>(v.getType()); }
//===----------------------------------------------------------------------===//
// Ownership
@@ -222,8 +222,8 @@ bool ValueComparator::operator()(const Value &lhs, const Value &rhs) const {
return false;
// Block arguments are less than results.
- bool lhsIsBBArg = lhs.isa<BlockArgument>();
- if (lhsIsBBArg != rhs.isa<BlockArgument>()) {
+ bool lhsIsBBArg = isa<BlockArgument>(lhs);
+ if (lhsIsBBArg != isa<BlockArgument>(rhs)) {
return lhsIsBBArg;
}
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index c2b2b99fc0083b..d51d63f243ea0c 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -684,7 +684,7 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
// Op is not bufferizable.
auto memSpace =
- options.defaultMemorySpaceFn(value.getType().cast<TensorType>());
+ options.defaultMemorySpaceFn(cast<TensorType>(value.getType()));
if (!memSpace.has_value())
return op->emitError("could not infer memory space");
@@ -939,7 +939,7 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
// If we do not know the memory space and there is no default memory space,
// report a failure.
auto memSpace =
- options.defaultMemorySpaceFn(value.getType().cast<TensorType>());
+ options.defaultMemorySpaceFn(cast<TensorType>(value.getType()));
if (!memSpace.has_value())
return op->emitError("could not infer memory space");
@@ -987,7 +987,7 @@ bufferization::detail::unknownGetAliasingValues(OpOperand &opOperand) {
for (Region &region : opOperand.getOwner()->getRegions())
if (!region.getBlocks().empty())
for (BlockArgument bbArg : region.getBlocks().front().getArguments())
- if (bbArg.getType().isa<TensorType>())
+ if (isa<TensorType>(bbArg.getType()))
r.addAlias({bbArg, BufferRelation::Unknown, /*isDefinite=*/false});
return r;
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
index c9fd110d48d9a9..a8ec111f8c304b 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
@@ -46,7 +46,7 @@ static Value buildBoolValue(OpBuilder &builder, Location loc, bool value) {
return builder.create<arith::ConstantOp>(loc, builder.getBoolAttr(value));
}
-static bool isMemref(Value v) { return v.getType().isa<BaseMemRefType>(); }
+static bool isMemref(Value v) { return isa<BaseMemRefType>(v.getType()); }
/// Return "true" if the given op is guaranteed to have neither "Allocate" nor
/// "Free" side effects.
diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
index 1c81433bc3e945..fb97045687d653 100644
--- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
@@ -378,7 +378,7 @@ OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
if (!rhs)
return {};
- ArrayAttr arrayAttr = rhs.dyn_cast<ArrayAttr>();
+ ArrayAttr arrayAttr = dyn_cast<ArrayAttr>(rhs);
if (!arrayAttr || arrayAttr.size() != 2)
return {};
diff --git a/mlir/lib/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.cpp b/mlir/lib/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.cpp
index 0dc357c2298fad..89546da428fa20 100644
--- a/mlir/lib/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.cpp
@@ -17,7 +17,7 @@
using namespace mlir;
using namespace mlir::bufferization;
-static bool isMemref(Value v) { return v.getType().isa<BaseMemRefType>(); }
+static bool isMemref(Value v) { return isa<BaseMemRefType>(v.getType()); }
namespace {
/// While CondBranchOp also implement the BranchOpInterface, we add a
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index b037ef3c0b4152..66a71df29a9bb2 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -160,13 +160,13 @@ LogicalResult AddOp::verify() {
Type lhsType = getLhs().getType();
Type rhsType = getRhs().getType();
- if (lhsType.isa<emitc::PointerType>() && rhsType.isa<emitc::PointerType>())
+ if (isa<emitc::PointerType>(lhsType) && isa<emitc::PointerType>(rhsType))
return emitOpError("requires that at most one operand is a pointer");
- if ((lhsType.isa<emitc::PointerType>() &&
- !rhsType.isa<IntegerType, emitc::OpaqueType>()) ||
- (rhsType.isa<emitc::PointerType>() &&
- !lhsType.isa<IntegerType, emitc::OpaqueType>()))
+ if ((isa<emitc::PointerType>(lhsType) &&
+ !isa<IntegerType, emitc::OpaqueType>(rhsType)) ||
+ (isa<emitc::PointerType>(rhsType) &&
+ !isa<IntegerType, emitc::OpaqueType>(lhsType)))
return emitOpError("requires that one operand is an integer or of opaque "
"type if the other is a pointer");
@@ -778,16 +778,16 @@ LogicalResult SubOp::verify() {
Type rhsType = getRhs().getType();
Type resultType = getResult().getType();
- if (rhsType.isa<emitc::PointerType>() && !lhsType.isa<emitc::PointerType>())
+ if (isa<emitc::PointerType>(rhsType) && !isa<emitc::PointerType>(lhsType))
return emitOpError("rhs can only be a pointer if lhs is a pointer");
- if (lhsType.isa<emitc::PointerType>() &&
- !rhsType.isa<IntegerType, emitc::OpaqueType, emitc::PointerType>())
+ if (isa<emitc::PointerType>(lhsType) &&
+ !isa<IntegerType, emitc::OpaqueType, emitc::PointerType>(rhsType))
return emitOpError("requires that rhs is an integer, pointer or of opaque "
"type if lhs is a pointer");
- if (lhsType.isa<emitc::PointerType>() && rhsType.isa<emitc::PointerType>() &&
- !resultType.isa<IntegerType, emitc::OpaqueType>())
+ if (isa<emitc::PointerType>(lhsType) && isa<emitc::PointerType>(rhsType) &&
+ !isa<IntegerType, emitc::OpaqueType>(resultType))
return emitOpError("requires that the result is an integer or of opaque "
"type if lhs and rhs are pointers");
return success();
diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
index b584f63f16e0aa..3661c5dea45259 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -196,7 +196,7 @@ getSubgroupMmaNativeVectorSize(Operation *op, int64_t m, int64_t n, int64_t k) {
auto extract = dyn_cast<vector::ExtractStridedSliceOp>(users);
if (!extract)
return std::nullopt;
- auto vecType = extract.getResult().getType().cast<VectorType>();
+ auto vecType = cast<VectorType>(extract.getResult().getType());
if (sliceType && sliceType != vecType)
return std::nullopt;
sliceType = vecType;
@@ -204,7 +204,7 @@ getSubgroupMmaNativeVectorSize(Operation *op, int64_t m, int64_t n, int64_t k) {
return llvm::to_vector(sliceType.getShape());
}
if ((OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1)) {
- if (auto vecType = op->getResultTypes()[0].dyn_cast<VectorType>()) {
+ if (auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0])) {
// TODO: The condition for unrolling elementwise should be restricted
// only to operations that need unrolling (connected to the contract).
if (vecType.getRank() < 2)
@@ -219,7 +219,7 @@ getSubgroupMmaNativeVectorSize(Operation *op, int64_t m, int64_t n, int64_t k) {
auto extract = dyn_cast<vector::ExtractStridedSliceOp>(users);
if (!extract)
return std::nullopt;
- auto vecType = extract.getResult().getType().cast<VectorType>();
+ auto vecType = cast<VectorType>(extract.getResult().getType());
if (sliceType && sliceType != vecType)
return std::nullopt;
sliceType = vecType;
diff --git a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
index 9ab7ae2a90820e..cfc8d092c8178a 100644
--- a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
+++ b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
@@ -354,7 +354,7 @@ static WalkResult loadOperation(
// Gather the variadicities of each result
for (Attribute attr : resultsOp->getVariadicity())
- resultVariadicity.push_back(attr.cast<VariadicityAttr>().getValue());
+ resultVariadicity.push_back(cast<VariadicityAttr>(attr).getValue());
}
// Gather which constraint slots correspond to attributes constraints
@@ -367,7 +367,7 @@ static WalkResult loadOperation(
for (const auto &[name, value] : llvm::zip(names, values)) {
for (auto [i, constr] : enumerate(constrToValue)) {
if (constr == value) {
- attributesContraints[name.cast<StringAttr>()] = i;
+ attributesContraints[cast<StringAttr>(name)] = i;
break;
}
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
index f3b674fdb50501..f7f1e944d637d0 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
@@ -42,7 +42,7 @@ static char getRegisterType(Type type) {
return 'f';
if (type.isF64())
return 'd';
- if (auto ptr = type.dyn_cast<LLVM::LLVMPointerType>()) {
+ if (auto ptr = dyn_cast<LLVM::LLVMPointerType>(type)) {
// Shared address spaces is addressed with 32-bit pointers.
if (ptr.getAddressSpace() == kSharedMemorySpace) {
return 'r';
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index f90240a67dcc5f..1db506e286b3c0 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -559,7 +559,7 @@ static void destructureIndices(Type currType, ArrayRef<GEPArg> indices,
// we don't do anything here. The verifier will catch it and emit a proper
// error. All other canonicalization is done in the fold method.
bool requiresConst = !rawConstantIndices.empty() &&
- currType.isa_and_nonnull<LLVMStructType>();
+ isa_and_nonnull<LLVMStructType>(currType);
if (Value val = llvm::dyn_cast_if_present<Value>(iter)) {
APInt intC;
if (requiresConst && matchPattern(val, m_ConstantInt(&intC)) &&
@@ -2564,14 +2564,14 @@ LogicalResult LLVM::ConstantOp::verify() {
}
// See the comment for getLLVMConstant for more details about why 8-bit
// floats can be represented by integers.
- if (getType().isa<IntegerType>() && !getType().isInteger(floatWidth)) {
+ if (isa<IntegerType>(getType()) && !getType().isInteger(floatWidth)) {
return emitOpError() << "expected integer type of width " << floatWidth;
}
}
if (auto splatAttr = dyn_cast<SplatElementsAttr>(getValue())) {
- if (!getType().isa<VectorType>() && !getType().isa<LLVM::LLVMArrayType>() &&
- !getType().isa<LLVM::LLVMFixedVectorType>() &&
- !getType().isa<LLVM::LLVMScalableVectorType>())
+ if (!isa<VectorType>(getType()) && !isa<LLVM::LLVMArrayType>(getType()) &&
+ !isa<LLVM::LLVMFixedVectorType>(getType()) &&
+ !isa<LLVM::LLVMScalableVectorType>(getType()))
return emitOpError() << "expected vector or array type";
}
return success();
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index 93901477b58204..f2ab3eae2c343e 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -319,7 +319,7 @@ LogicalResult LLVM::StoreOp::ensureOnlySafeAccesses(
static Type getTypeAtIndex(const DestructurableMemorySlot &slot,
Attribute index) {
auto subelementIndexMap =
- slot.elemType.cast<DestructurableTypeInterface>().getSubelementIndexMap();
+ cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap();
if (!subelementIndexMap)
return {};
assert(!subelementIndexMap->empty());
@@ -913,8 +913,7 @@ bool LLVM::MemsetOp::canRewire(const DestructurableMemorySlot &slot,
if (getIsVolatile())
return false;
- if (!slot.elemType.cast<DestructurableTypeInterface>()
- .getSubelementIndexMap())
+ if (!cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap())
return false;
if (!areAllIndicesI32(slot))
@@ -928,7 +927,7 @@ DeletionKind LLVM::MemsetOp::rewire(const DestructurableMemorySlot &slot,
RewriterBase &rewriter,
const DataLayout &dataLayout) {
std::optional<DenseMap<Attribute, Type>> types =
- slot.elemType.cast<DestructurableTypeInterface>().getSubelementIndexMap();
+ cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap();
IntegerAttr memsetLenAttr;
bool successfulMatch =
@@ -1047,8 +1046,7 @@ static bool memcpyCanRewire(MemcpyLike op, const DestructurableMemorySlot &slot,
if (op.getIsVolatile())
return false;
- if (!slot.elemType.cast<DestructurableTypeInterface>()
- .getSubelementIndexMap())
+ if (!cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap())
return false;
if (!areAllIndicesI32(slot))
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
index b264e9ff9283d4..0a372ad0c52fcd 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
@@ -475,7 +475,7 @@ LogicalResult SplitStores::matchAndRewrite(StoreOp store,
}
}
- auto destructurableType = typeHint.dyn_cast<DestructurableTypeInterface>();
+ auto destructurableType = dyn_cast<DestructurableTypeInterface>(typeHint);
if (!destructurableType)
return failure();
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
index 3e85559e1ec0c6..768df0953fc5c8 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
@@ -202,9 +202,9 @@ DiagnosedSilenceableFailure transform::MatchStructuredBodyOp::matchOperation(
body,
[&](Operation *elem, Operation *red) {
return elem->getName().getStringRef() ==
- (*contractionOps)[0].cast<StringAttr>().getValue() &&
+ cast<StringAttr>((*contractionOps)[0]).getValue() &&
red->getName().getStringRef() ==
- (*contractionOps)[1].cast<StringAttr>().getValue();
+ cast<StringAttr>((*contractionOps)[1]).getValue();
},
os);
if (result)
@@ -259,11 +259,11 @@ transform::MatchStructuredClassifyContractionDimsOp::matchOperation(
return builder.getI64IntegerAttr(value);
}));
};
- results.setParams(getBatch().cast<OpResult>(),
+ results.setParams(cast<OpResult>(getBatch()),
makeI64Attrs(contractionDims->batch));
- results.setParams(getM().cast<OpResult>(), makeI64Attrs(contractionDims->m));
- results.setParams(getN().cast<OpResult>(), makeI64Attrs(contractionDims->n));
- results.setParams(getK().cast<OpResult>(), makeI64Attrs(contractionDims->k));
+ results.setParams(cast<OpResult>(getM()), makeI64Attrs(contractionDims->m));
+ results.setParams(cast<OpResult>(getN()), makeI64Attrs(contractionDims->n));
+ results.setParams(cast<OpResult>(getK()), makeI64Attrs(contractionDims->k));
return DiagnosedSilenceableFailure::success();
}
@@ -288,17 +288,17 @@ transform::MatchStructuredClassifyConvolutionDimsOp::matchOperation(
return builder.getI64IntegerAttr(value);
}));
};
- results.setParams(getBatch().cast<OpResult>(),
+ results.setParams(cast<OpResult>(getBatch()),
makeI64Attrs(convolutionDims->batch));
- results.setParams(getOutputImage().cast<OpResult>(),
+ results.setParams(cast<OpResult>(getOutputImage()),
makeI64Attrs(convolutionDims->outputImage));
- results.setParams(getOutputChannel().cast<OpResult>(),
+ results.setParams(cast<OpResult>(getOutputChannel()),
makeI64Attrs(convolutionDims->outputChannel));
- results.setParams(getFilterLoop().cast<OpResult>(),
+ results.setParams(cast<OpResult>(getFilterLoop()),
makeI64Attrs(convolutionDims->filterLoop));
- results.setParams(getInputChannel().cast<OpResult>(),
+ results.setParams(cast<OpResult>(getInputChannel()),
makeI64Attrs(convolutionDims->inputChannel));
- results.setParams(getDepth().cast<OpResult>(),
+ results.setParams(cast<OpResult>(getDepth()),
makeI64Attrs(convolutionDims->depth));
auto makeI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
@@ -307,9 +307,9 @@ transform::MatchStructuredClassifyConvolutionDimsOp::matchOperation(
return builder.getI64IntegerAttr(value);
}));
};
- results.setParams(getStrides().cast<OpResult>(),
+ results.setParams(cast<OpResult>(getStrides()),
makeI64AttrsFromI64(convolutionDims->strides));
- results.setParams(getDilations().cast<OpResult>(),
+ results.setParams(cast<OpResult>(getDilations()),
makeI64AttrsFromI64(convolutionDims->dilations));
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 7e7cf1d0244613..3c3d968fbb865e 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1219,7 +1219,7 @@ transform::MatchOp::apply(transform::TransformRewriter &rewriter,
// All the operands must must be equal to the specified type
auto typeattr =
dyn_cast<mlir::TypeAttr>(getFilterOperandTypes().value()[0]);
- Type t = typeattr.getValue().cast<::mlir::Type>();
+ Type t = cast<::mlir::Type>(typeattr.getValue());
if (!llvm::all_of(op->getOperandTypes(),
[&](Type operandType) { return operandType == t; }))
return;
@@ -1234,7 +1234,7 @@ transform::MatchOp::apply(transform::TransformRewriter &rewriter,
for (auto [attr, operandType] :
llvm::zip_equal(getFilterOperandTypes().value(), operandTypes)) {
auto typeattr = cast<mlir::TypeAttr>(attr);
- Type type = typeattr.getValue().cast<::mlir::Type>();
+ Type type = cast<::mlir::Type>(typeattr.getValue());
if (type != operandType)
return;
@@ -2665,7 +2665,7 @@ transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
if (scalableSizes[ofrIdx]) {
auto val = b.create<arith::ConstantIndexOp>(
- getLoc(), attr.cast<IntegerAttr>().getInt());
+ getLoc(), cast<IntegerAttr>(attr).getInt());
Value vscale =
b.create<vector::VectorScaleOp>(getLoc(), b.getIndexType());
sizes.push_back(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
index b95677b7457e63..c8ea73427cebd9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -60,7 +60,7 @@ static void createMemcpy(OpBuilder &b, Location loc, Value tensorSource,
const linalg::BufferizeToAllocationOptions &options) {
auto tensorType = dyn_cast<RankedTensorType>(tensorSource.getType());
assert(tensorType && "expected ranked tensor");
- assert(memrefDest.getType().isa<MemRefType>() && "expected ranked memref");
+ assert(isa<MemRefType>(memrefDest.getType()) && "expected ranked memref");
switch (options.memcpyOp) {
case linalg::BufferizeToAllocationOptions::MemcpyOp::
@@ -496,10 +496,10 @@ Value linalg::bufferizeToAllocation(
if (op == nestedOp)
return;
if (llvm::any_of(nestedOp->getOperands(),
- [](Value v) { return v.getType().isa<TensorType>(); }))
+ [](Value v) { return isa<TensorType>(v.getType()); }))
llvm_unreachable("ops with nested tensor ops are not supported yet");
if (llvm::any_of(nestedOp->getResults(),
- [](Value v) { return v.getType().isa<TensorType>(); }))
+ [](Value v) { return isa<TensorType>(v.getType()); }))
llvm_unreachable("ops with nested tensor ops are not supported yet");
});
}
@@ -508,8 +508,7 @@ Value linalg::bufferizeToAllocation(
// Gather tensor results.
SmallVector<OpResult> tensorResults;
for (OpResult result : op->getResults()) {
- if (!result.getType().isa<TensorType>())
- continue;
+ if (!isa<TensorType>(result.getType())) continue;
// Unranked tensors are not supported
if (!isa<RankedTensorType>(result.getType()))
return nullptr;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
index 81669a1807796c..18e3cbe517f485 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
@@ -49,8 +49,7 @@ LogicalResult linalg::linalgOpAnchoredEmptyTensorEliminationStep(
for (OpOperand *in : op.getDpsInputOperands()) {
// Skip non-tensor operands.
- if (!in->get().getType().isa<RankedTensorType>())
- continue;
+ if (!isa<RankedTensorType>(in->get().getType())) continue;
// Find tensor.empty ops on the reverse SSA use-def chain. Only follow
// equivalent tensors. I.e., stop when there are ops such as extract_slice
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 462f692615faad..df4089d61bfd72 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -405,7 +405,7 @@ static FailureOr<ForallTilingResult> tileToForallOpImpl(
for (OpOperand &outOperand : destinationStyleOp.getDpsInitsMutable()) {
// Swap tensor inits with the corresponding block argument of the
// scf.forall op. Memref inits remain as is.
- if (outOperand.get().getType().isa<TensorType>()) {
+ if (isa<TensorType>(outOperand.get().getType())) {
auto *it = llvm::find(dest, outOperand.get());
assert(it != dest.end() && "could not find destination tensor");
unsigned destNum = std::distance(dest.begin(), it);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index a17bc8e4cd318f..2297bf5e355125 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -557,7 +557,7 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
Value dest = tensor::PackOp::createDestinationTensor(
rewriter, loc, operand, innerPackSizes, innerPos,
/*outerDimsPerm=*/{});
- ShapedType operandType = operand.getType().cast<ShapedType>();
+ ShapedType operandType = cast<ShapedType>(operand.getType());
bool areConstantTiles =
llvm::all_of(innerPackSizes, [](OpFoldResult tile) {
return getConstantIntValue(tile).has_value();
@@ -565,7 +565,7 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
if (areConstantTiles && operandType.hasStaticShape() &&
!tensor::PackOp::requirePaddingValue(
operandType.getShape(), innerPos,
- dest.getType().cast<ShapedType>().getShape(), {},
+ cast<ShapedType>(dest.getType()).getShape(), {},
innerPackSizes)) {
packOps.push_back(rewriter.create<tensor::PackOp>(
loc, operand, dest, innerPos, innerPackSizes));
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index df61381432921b..fbff2088637f44 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -3410,8 +3410,8 @@ struct Conv1DGenerator
// * shape_cast(broadcast(filter))
// * broadcast(shuffle(filter))
// Opt for the option without shape_cast to simplify the codegen.
- auto rhsSize = rhs.getType().cast<VectorType>().getShape()[0];
- auto resSize = res.getType().cast<VectorType>().getShape()[1];
+ auto rhsSize = cast<VectorType>(rhs.getType()).getShape()[0];
+ auto resSize = cast<VectorType>(res.getType()).getShape()[1];
SmallVector<int64_t, 16> indicies;
for (int i = 0; i < resSize / rhsSize; ++i) {
diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
index b3481ce1c56bbd..3c9475c2d143a6 100644
--- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
+++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
@@ -173,8 +173,8 @@ transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter,
}
// Assemble results.
- results.set(getGlobal().cast<OpResult>(), globalOps);
- results.set(getGetGlobal().cast<OpResult>(), getGlobalOps);
+ results.set(cast<OpResult>(getGlobal()), globalOps);
+ results.set(cast<OpResult>(getGetGlobal()), getGlobalOps);
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 8236a4c475f17c..4449733f0daf06 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -254,7 +254,7 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
LogicalResult
matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto convertedType = adaptor.getMemref().getType().cast<MemRefType>();
+ auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
auto convertedElementType = convertedType.getElementType();
auto oldElementType = op.getMemRefType().getElementType();
int srcBits = oldElementType.getIntOrFloatBitWidth();
@@ -351,7 +351,7 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
LogicalResult
matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto convertedType = adaptor.getMemref().getType().cast<MemRefType>();
+ auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
int srcBits = op.getMemRefType().getElementTypeBitWidth();
int dstBits = convertedType.getElementTypeBitWidth();
auto dstIntegerType = rewriter.getIntegerType(dstBits);
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandRealloc.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandRealloc.cpp
index 62a8f7e43c8675..dcc5eac916d034 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandRealloc.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandRealloc.cpp
@@ -68,7 +68,7 @@ struct ExpandReallocOpPattern : public OpRewritePattern<memref::ReallocOp> {
// Get the size of the original buffer.
int64_t inputSize =
- op.getSource().getType().cast<BaseMemRefType>().getDimSize(0);
+ cast<BaseMemRefType>(op.getSource().getType()).getDimSize(0);
OpFoldResult currSize = rewriter.getIndexAttr(inputSize);
if (ShapedType::isDynamic(inputSize)) {
Value dimZero = getValueOrCreateConstantIndexOp(rewriter, loc,
@@ -79,7 +79,7 @@ struct ExpandReallocOpPattern : public OpRewritePattern<memref::ReallocOp> {
// Get the requested size that the new buffer should have.
int64_t outputSize =
- op.getResult().getType().cast<BaseMemRefType>().getDimSize(0);
+ cast<BaseMemRefType>(op.getResult().getType()).getDimSize(0);
OpFoldResult targetSize = ShapedType::isDynamic(outputSize)
? OpFoldResult{op.getDynamicResultSize()}
: rewriter.getIndexAttr(outputSize);
@@ -127,7 +127,7 @@ struct ExpandReallocOpPattern : public OpRewritePattern<memref::ReallocOp> {
// is already bigger than the requested size, the cast represents a
// subview operation.
Value casted = builder.create<memref::ReinterpretCastOp>(
- loc, op.getResult().getType().cast<MemRefType>(), op.getSource(),
+ loc, cast<MemRefType>(op.getResult().getType()), op.getSource(),
rewriter.getIndexAttr(0), ArrayRef<OpFoldResult>{targetSize},
ArrayRef<OpFoldResult>{rewriter.getIndexAttr(1)});
builder.create<scf::YieldOp>(loc, casted);
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 03f11ad1f94965..d4329b401df191 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -169,7 +169,7 @@ ShapedType mesh::shardShapedType(ShapedType shape, MeshOp mesh,
}
Type mesh::shardType(Type type, MeshOp mesh, MeshShardingAttr sharding) {
- RankedTensorType rankedTensorType = type.dyn_cast<RankedTensorType>();
+ RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(type);
if (rankedTensorType) {
return shardShapedType(rankedTensorType, mesh, sharding);
}
@@ -281,7 +281,8 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
}
bool MeshShardingAttr::operator==(Attribute rhs) const {
- MeshShardingAttr rhsAsMeshShardingAttr = rhs.dyn_cast<MeshShardingAttr>();
+ MeshShardingAttr rhsAsMeshShardingAttr =
+ mlir::dyn_cast<MeshShardingAttr>(rhs);
return rhsAsMeshShardingAttr && *this == rhsAsMeshShardingAttr;
}
@@ -484,15 +485,15 @@ static LogicalResult verifyDimensionCompatibility(Location loc,
static LogicalResult verifyGatherOperandAndResultShape(
Value operand, Value result, int64_t gatherAxis,
ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
- auto resultRank = result.getType().template cast<ShapedType>().getRank();
+ auto resultRank = cast<ShapedType>(result.getType()).getRank();
if (gatherAxis < 0 || gatherAxis >= resultRank) {
return emitError(result.getLoc())
<< "Gather axis " << gatherAxis << " is out of bounds [0, "
<< resultRank << ").";
}
- ShapedType operandType = operand.getType().cast<ShapedType>();
- ShapedType resultType = result.getType().cast<ShapedType>();
+ ShapedType operandType = cast<ShapedType>(operand.getType());
+ ShapedType resultType = cast<ShapedType>(result.getType());
auto deviceGroupSize =
DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
@@ -511,8 +512,8 @@ static LogicalResult verifyGatherOperandAndResultShape(
static LogicalResult verifyAllToAllOperandAndResultShape(
Value operand, Value result, int64_t splitAxis, int64_t concatAxis,
ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
- ShapedType operandType = operand.getType().cast<ShapedType>();
- ShapedType resultType = result.getType().cast<ShapedType>();
+ ShapedType operandType = cast<ShapedType>(operand.getType());
+ ShapedType resultType = cast<ShapedType>(result.getType());
for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
if ((axis != splitAxis && axis != concatAxis) || splitAxis == concatAxis) {
if (failed(verifyDimensionCompatibility(
@@ -556,8 +557,8 @@ static LogicalResult verifyAllToAllOperandAndResultShape(
static LogicalResult verifyScatterOrSliceOperandAndResultShape(
Value operand, Value result, int64_t tensorAxis,
ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
- ShapedType operandType = operand.getType().cast<ShapedType>();
- ShapedType resultType = result.getType().cast<ShapedType>();
+ ShapedType operandType = cast<ShapedType>(operand.getType());
+ ShapedType resultType = cast<ShapedType>(result.getType());
for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
if (axis != tensorAxis) {
if (failed(verifyDimensionCompatibility(
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index 9acee5aa8d8604..dbb9e667d4709c 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -97,7 +97,7 @@ checkOperandAffineExpr(AffineExpr expr, unsigned numDims) {
FailureOr<std::pair<bool, MeshShardingAttr>>
mesh::getMeshShardingAttr(OpResult result) {
- Value val = result.cast<Value>();
+ Value val = cast<Value>(result);
bool anyShardedForDef = llvm::any_of(val.getUsers(), [](Operation *user) {
auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
if (!shardOp)
@@ -178,7 +178,7 @@ LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() {
return failure();
for (OpResult result : op->getResults()) {
- auto resultType = result.getType().dyn_cast<RankedTensorType>();
+ auto resultType = dyn_cast<RankedTensorType>(result.getType());
if (!resultType)
return failure();
AffineMap map = maps[numOperands + result.getResultNumber()];
@@ -404,7 +404,7 @@ static LogicalResult addShardOp(OpBuilder &b, OpResult result,
if (succeeded(maybeSharding) && !maybeSharding->first)
return success();
- auto resultType = result.getType().cast<RankedTensorType>();
+ auto resultType = cast<RankedTensorType>(result.getType());
SmallVector<SmallVector<MeshAxis>> splitAxes(resultType.getRank());
SmallVector<MeshAxis> partialAxes;
@@ -457,7 +457,7 @@ static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
if (succeeded(maybeShardingAttr) && maybeShardingAttr->first)
return success();
Value operand = opOperand.get();
- auto operandType = operand.getType().cast<RankedTensorType>();
+ auto operandType = cast<RankedTensorType>(operand.getType());
SmallVector<SmallVector<MeshAxis>> splitAxes(operandType.getRank());
unsigned numDims = map.getNumDims();
for (auto it : llvm::enumerate(map.getResults())) {
@@ -526,7 +526,7 @@ LogicalResult mesh::detail::defaultAddShardingAnnotations(
static bool
isValueCompatibleWithFullReplicationSharding(Value value,
MeshShardingAttr sharding) {
- if (value.getType().isa<RankedTensorType>()) {
+ if (isa<RankedTensorType>(value.getType())) {
return sharding && isFullReplication(sharding);
}
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index e4868435135ed1..6b1326d76bc4a4 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -86,14 +86,13 @@ handlePartialAxesDuringResharding(OpBuilder &builder,
}
builder.setInsertionPointAfterValue(sourceShard);
- TypedValue<ShapedType> resultValue =
+ TypedValue<ShapedType> resultValue = cast<TypedValue<ShapedType>>(
builder
.create<AllReduceOp>(sourceShard.getLoc(), sourceShard.getType(),
sourceSharding.getMesh().getLeafReference(),
allReduceMeshAxes, sourceShard,
sourceSharding.getPartialType())
- .getResult()
- .cast<TypedValue<ShapedType>>();
+ .getResult());
llvm::SmallVector<MeshAxis> remainingPartialAxes;
llvm::copy_if(sourceShardingPartialAxesSet,
@@ -135,13 +134,12 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
MeshShardingAttr sourceSharding,
TypedValue<ShapedType> sourceShard, MeshOp mesh,
int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
- TypedValue<ShapedType> targetShard =
+ TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
builder
.create<AllSliceOp>(sourceShard, mesh,
ArrayRef<MeshAxis>(splitMeshAxis),
splitTensorAxis)
- .getResult()
- .cast<TypedValue<ShapedType>>();
+ .getResult());
MeshShardingAttr targetSharding = targetShardingInSplitLastAxis(
builder.getContext(), sourceSharding, splitTensorAxis, splitMeshAxis);
return {targetShard, targetSharding};
@@ -278,10 +276,8 @@ unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
APInt(64, splitTensorAxis));
ShapedType targetShape =
shardShapedType(sourceUnshardedShape, mesh, targetSharding);
- TypedValue<ShapedType> targetShard =
- builder.create<tensor::CastOp>(targetShape, allGatherResult)
- .getResult()
- .cast<TypedValue<ShapedType>>();
+ TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
+ builder.create<tensor::CastOp>(targetShape, allGatherResult).getResult());
return {targetShard, targetSharding};
}
@@ -413,10 +409,8 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
ShapedType targetShape =
shardShapedType(sourceUnshardedShape, mesh, targetSharding);
- TypedValue<ShapedType> targetShard =
- builder.create<tensor::CastOp>(targetShape, allToAllResult)
- .getResult()
- .cast<TypedValue<ShapedType>>();
+ TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
+ builder.create<tensor::CastOp>(targetShape, allToAllResult).getResult());
return {targetShard, targetSharding};
}
@@ -505,7 +499,7 @@ TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
return reshard(
implicitLocOpBuilder, mesh, source.getShard(), target.getShard(),
- source.getSrc().cast<TypedValue<ShapedType>>(), sourceShardValue);
+ cast<TypedValue<ShapedType>>(source.getSrc()), sourceShardValue);
}
TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
@@ -533,23 +527,22 @@ SmallVector<Type>
shardedBlockArgumentTypes(Block &block,
SymbolTableCollection &symbolTableCollection) {
SmallVector<Type> res;
- llvm::transform(block.getArguments(), std::back_inserter(res),
- [&symbolTableCollection](BlockArgument arg) {
- auto rankedTensorArg =
- arg.dyn_cast<TypedValue<RankedTensorType>>();
- if (!rankedTensorArg) {
- return arg.getType();
- }
-
- assert(rankedTensorArg.hasOneUse());
- Operation *useOp = *rankedTensorArg.getUsers().begin();
- ShardOp shardOp = llvm::dyn_cast<ShardOp>(useOp);
- assert(shardOp);
- MeshOp mesh = getMesh(shardOp, symbolTableCollection);
- return shardShapedType(rankedTensorArg.getType(), mesh,
- shardOp.getShardAttr())
- .cast<Type>();
- });
+ llvm::transform(
+ block.getArguments(), std::back_inserter(res),
+ [&symbolTableCollection](BlockArgument arg) {
+ auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
+ if (!rankedTensorArg) {
+ return arg.getType();
+ }
+
+ assert(rankedTensorArg.hasOneUse());
+ Operation *useOp = *rankedTensorArg.getUsers().begin();
+ ShardOp shardOp = llvm::dyn_cast<ShardOp>(useOp);
+ assert(shardOp);
+ MeshOp mesh = getMesh(shardOp, symbolTableCollection);
+ return cast<Type>(shardShapedType(rankedTensorArg.getType(), mesh,
+ shardOp.getShardAttr()));
+ });
return res;
}
@@ -587,7 +580,7 @@ static SmallVector<MeshShardingAttr> getOperandShardings(Operation &op) {
res.reserve(op.getNumOperands());
llvm::transform(op.getOperands(), std::back_inserter(res), [](Value operand) {
TypedValue<RankedTensorType> rankedTensor =
- operand.dyn_cast<TypedValue<RankedTensorType>>();
+ dyn_cast<TypedValue<RankedTensorType>>(operand);
if (!rankedTensor) {
return MeshShardingAttr();
}
@@ -608,7 +601,7 @@ static SmallVector<MeshShardingAttr> getResultShardings(Operation &op) {
llvm::transform(op.getResults(), std::back_inserter(res),
[](OpResult result) {
TypedValue<RankedTensorType> rankedTensor =
- result.dyn_cast<TypedValue<RankedTensorType>>();
+ dyn_cast<TypedValue<RankedTensorType>>(result);
if (!rankedTensor) {
return MeshShardingAttr();
}
@@ -636,9 +629,8 @@ spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
} else {
// Insert resharding.
assert(!srcShardOp.getAnnotateForUsers() && shardOp.getAnnotateForUsers());
- TypedValue<ShapedType> srcSpmdValue =
- spmdizationMap.lookup(srcShardOp.getOperand())
- .cast<TypedValue<ShapedType>>();
+ TypedValue<ShapedType> srcSpmdValue = cast<TypedValue<ShapedType>>(
+ spmdizationMap.lookup(srcShardOp.getOperand()));
targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
symbolTableCollection);
}
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
index cb13ee404751ca..60c4e07a118cb8 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
@@ -133,7 +133,7 @@ struct AllSliceOpLowering
// insert tensor.extract_slice
RankedTensorType operandType =
- op.getOperand().getType().cast<RankedTensorType>();
+ cast<RankedTensorType>(op.getOperand().getType());
SmallVector<OpFoldResult> sizes;
for (int64_t i = 0; i < operandType.getRank(); ++i) {
if (i == sliceAxis) {
@@ -202,10 +202,9 @@ createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
ImplicitLocOpBuilder &builder) {
Operation::result_range meshShape =
builder.create<mesh::MeshShapeOp>(mesh, axes).getResults();
- return arith::createProduct(builder, builder.getLoc(),
- llvm::to_vector_of<Value>(meshShape),
- builder.getIndexType())
- .cast<TypedValue<IndexType>>();
+ return cast<TypedValue<IndexType>>(arith::createProduct(
+ builder, builder.getLoc(), llvm::to_vector_of<Value>(meshShape),
+ builder.getIndexType()));
}
TypedValue<IndexType> createProcessLinearIndex(StringRef mesh,
diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
index 1635297a5447d4..4e256aea0be37a 100644
--- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
@@ -651,7 +651,7 @@ struct MmaSyncBuilder {
template <typename ApplyFn, typename ReduceFn>
static void foreachIndividualVectorElement(Value vector, ApplyFn applyFn,
ReduceFn reduceFn) {
- VectorType vectorType = vector.getType().cast<VectorType>();
+ VectorType vectorType = cast<VectorType>(vector.getType());
auto vectorShape = vectorType.getShape();
auto strides = computeStrides(vectorShape);
for (int64_t idx = 0, e = vectorShape[0] * strides[0]; idx < e; ++idx) {
@@ -779,11 +779,11 @@ FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) {
Value lhsMemRef = linalgOp.getDpsInputOperand(0)->get();
Value rhsMemRef = linalgOp.getDpsInputOperand(1)->get();
Value resMemRef = linalgOp.getDpsInitOperand(0)->get();
- assert(lhsMemRef.getType().cast<MemRefType>().getRank() == 2 &&
+ assert(cast<MemRefType>(lhsMemRef.getType()).getRank() == 2 &&
"expected lhs to be a 2D memref");
- assert(rhsMemRef.getType().cast<MemRefType>().getRank() == 2 &&
+ assert(cast<MemRefType>(rhsMemRef.getType()).getRank() == 2 &&
"expected rhs to be a 2D memref");
- assert(resMemRef.getType().cast<MemRefType>().getRank() == 2 &&
+ assert(cast<MemRefType>(resMemRef.getType()).getRank() == 2 &&
"expected res to be a 2D memref");
int64_t m = cast<MemRefType>(lhsMemRef.getType()).getShape()[0];
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 1e480d6471cbce..f380926c4bce3f 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1318,7 +1318,7 @@ static LogicalResult verifyPrivateVarList(OpType &op) {
for (auto privateVarInfo : llvm::zip_equal(privateVars, privatizers)) {
Type varType = std::get<0>(privateVarInfo).getType();
SymbolRefAttr privatizerSym =
- std::get<1>(privateVarInfo).template cast<SymbolRefAttr>();
+ cast<SymbolRefAttr>(std::get<1>(privateVarInfo));
PrivateClauseOp privatizerOp =
SymbolTable::lookupNearestSymbolFrom<PrivateClauseOp>(op,
privatizerSym);
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
index 481275f052a3ce..187b3b71e3458b 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
@@ -145,9 +145,9 @@ void VarInfo::setNum(Var::Num n) {
/// mismatches.
LLVM_ATTRIBUTE_UNUSED static llvm::SMLoc
minSMLoc(AsmParser &parser, llvm::SMLoc sm1, llvm::SMLoc sm2) {
- const auto loc1 = parser.getEncodedSourceLoc(sm1).dyn_cast<FileLineColLoc>();
+ const auto loc1 = dyn_cast<FileLineColLoc>(parser.getEncodedSourceLoc(sm1));
assert(loc1 && "Could not get `FileLineColLoc` for first `SMLoc`");
- const auto loc2 = parser.getEncodedSourceLoc(sm2).dyn_cast<FileLineColLoc>();
+ const auto loc2 = dyn_cast<FileLineColLoc>(parser.getEncodedSourceLoc(sm2));
assert(loc2 && "Could not get `FileLineColLoc` for second `SMLoc`");
if (loc1.getFilename() != loc2.getFilename())
return SMLoc();
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 516b0943bdcfac..b1d44559fa5aba 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -2078,7 +2078,7 @@ struct SparseTensorAsmDialectInterface : public OpAsmDialectInterface {
using OpAsmDialectInterface::OpAsmDialectInterface;
AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
- if (attr.isa<SparseTensorEncodingAttr>()) {
+ if (isa<SparseTensorEncodingAttr>(attr)) {
os << "sparse";
return AliasResult::OverridableAlias;
}
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
index 4f9988d48d7710..9c84f4c25866fd 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp
@@ -29,7 +29,7 @@ LogicalResult sparse_tensor::detail::stageWithSortImpl(
Location loc = op.getLoc();
Type finalTp = op->getOpResult(0).getType();
- SparseTensorType dstStt(finalTp.cast<RankedTensorType>());
+ SparseTensorType dstStt(cast<RankedTensorType>(finalTp));
Type srcCOOTp = dstStt.getCOOType(/*ordered=*/false);
// Clones the original operation but changing the output to an unordered COO.
diff --git a/mlir/lib/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.cpp b/mlir/lib/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.cpp
index 5b7ea9360e2211..ca19259ebffa68 100644
--- a/mlir/lib/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.cpp
+++ b/mlir/lib/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.cpp
@@ -25,7 +25,7 @@ DiagnosedSilenceableFailure transform::MatchSparseInOut::matchOperation(
return emitSilenceableFailure(current->getLoc(),
"operation has no sparse input or output");
}
- results.set(getResult().cast<OpResult>(), state.getPayloadOps(getTarget()));
+ results.set(cast<OpResult>(getResult()), state.getPayloadOps(getTarget()));
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
index eafbe95b7aebe0..a53bce16dad860 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
@@ -42,7 +42,7 @@ static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
if (kind == SparseTensorFieldKind::PosMemRef ||
kind == SparseTensorFieldKind::CrdMemRef ||
kind == SparseTensorFieldKind::ValMemRef) {
- auto rtp = t.cast<ShapedType>();
+ auto rtp = cast<ShapedType>(t);
if (!directOut) {
rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
if (extraTypes)
@@ -97,7 +97,7 @@ static void convVals(OpBuilder &builder, Location loc, TypeRange types,
mem = builder.create<sparse_tensor::ToValuesOp>(loc, inputs[0]);
toVals.push_back(mem);
} else {
- ShapedType rtp = t.cast<ShapedType>();
+ ShapedType rtp = cast<ShapedType>(t);
rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
inputs.push_back(extraVals[extra++]);
retTypes.push_back(rtp);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index 9c0fc60877d8a3..36ecf692b02c51 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -502,7 +502,7 @@ struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
for (const AffineExpr l : order.getResults()) {
unsigned loopId = llvm::cast<AffineDimExpr>(l).getPosition();
auto itTp =
- linalgOp.getIteratorTypes()[loopId].cast<linalg::IteratorTypeAttr>();
+ cast<linalg::IteratorTypeAttr>(linalgOp.getIteratorTypes()[loopId]);
if (linalg::isReductionIterator(itTp.getValue()))
break; // terminate at first reduction
nest++;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index b117c1694e45b8..02375f54d7152f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -476,8 +476,8 @@ struct GenSemiRingSelect : public OpRewritePattern<GenericOp> {
if (!sel)
return std::nullopt;
- auto tVal = sel.getTrueValue().dyn_cast<BlockArgument>();
- auto fVal = sel.getFalseValue().dyn_cast<BlockArgument>();
+ auto tVal = dyn_cast<BlockArgument>(sel.getTrueValue());
+ auto fVal = dyn_cast<BlockArgument>(sel.getFalseValue());
// TODO: For simplicity, we only handle cases where both true/false value
// are directly loaded the input tensor. We can probably admit more cases
// in theory.
@@ -487,7 +487,7 @@ struct GenSemiRingSelect : public OpRewritePattern<GenericOp> {
// Helper lambda to determine whether the value is loaded from a dense input
// or is a loop invariant.
auto isValFromDenseInputOrInvariant = [&op](Value v) -> bool {
- if (auto bArg = v.dyn_cast<BlockArgument>();
+ if (auto bArg = dyn_cast<BlockArgument>(v);
bArg && !isSparseTensor(op.getDpsInputOperand(bArg.getArgNumber())))
return true;
// If the value is defined outside the loop, it is a loop invariant.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
index 89af75dea2a0f2..de553a5f9bf08c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
@@ -165,7 +165,7 @@ Value sparse_tensor::genCast(OpBuilder &builder, Location loc, Value value,
Value sparse_tensor::genScalarToTensor(OpBuilder &builder, Location loc,
Value elem, Type dstTp) {
- if (auto rtp = dstTp.dyn_cast<RankedTensorType>()) {
+ if (auto rtp = dyn_cast<RankedTensorType>(dstTp)) {
// Scalars can only be converted to 0-ranked tensors.
assert(rtp.getRank() == 0);
elem = sparse_tensor::genCast(builder, loc, elem, rtp.getElementType());
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
index 66f96ba08c0ed2..8981de58306da3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
@@ -157,8 +157,7 @@ IterationGraphSorter::IterationGraphSorter(
// The number of results of the map should match the rank of the tensor.
assert(llvm::all_of(llvm::zip(loop2InsLvl, ins), [](auto mvPair) {
auto [m, v] = mvPair;
- return m.getNumResults() ==
- v.getType().template cast<ShapedType>().getRank();
+ return m.getNumResults() == cast<ShapedType>(v.getType()).getRank();
}));
itGraph.resize(getNumLoops(), std::vector<bool>(getNumLoops(), false));
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 0ce40e81371209..9de4bd63800919 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -820,7 +820,7 @@ struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
if (!destOp)
return failure();
- auto resultIndex = source.cast<OpResult>().getResultNumber();
+ auto resultIndex = cast<OpResult>(source).getResultNumber();
auto *initOperand = destOp.getDpsInitOperand(resultIndex);
rewriter.modifyOpInPlace(
@@ -3475,8 +3475,7 @@ SplatOp::reifyResultShapes(OpBuilder &builder,
OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
auto constOperand = adaptor.getInput();
- if (!constOperand.isa_and_nonnull<IntegerAttr, FloatAttr>())
- return {};
+ if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand)) return {};
// Do not fold if the splat is not statically shaped
if (!getType().hasStaticShape())
@@ -4307,7 +4306,7 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
/// unpack(destinationStyleOp(x)) -> unpack(x)
if (auto dstStyleOp =
unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
- auto destValue = unPackOp.getDest().cast<OpResult>();
+ auto destValue = cast<OpResult>(unPackOp.getDest());
Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
rewriter.modifyOpInPlace(unPackOp,
[&]() { unPackOp.setDpsInitOperand(0, newDest); });
diff --git a/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
index 06a441dbeaf150..137156fe1a73e2 100644
--- a/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
@@ -32,7 +32,7 @@ namespace {
struct MatMulOpSharding
: public ShardingInterface::ExternalModel<MatMulOpSharding, MatMulOp> {
SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
- auto tensorType = op->getResult(0).getType().dyn_cast<RankedTensorType>();
+ auto tensorType = dyn_cast<RankedTensorType>(op->getResult(0).getType());
if (!tensorType)
return {};
@@ -48,7 +48,7 @@ struct MatMulOpSharding
}
SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
- auto tensorType = op->getResult(0).getType().dyn_cast<RankedTensorType>();
+ auto tensorType = dyn_cast<RankedTensorType>(op->getResult(0).getType());
if (!tensorType)
return {};
MLIRContext *ctx = op->getContext();
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index c8bf4c526b239f..c139d5f60024ca 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -285,7 +285,7 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
return failure();
}
- if (inputElementType.isa<FloatType>()) {
+ if (isa<FloatType>(inputElementType)) {
// Unlike integer types, floating point types can represent infinity.
auto minClamp = op.getMinFp();
auto maxClamp = op.getMaxFp();
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index e06ac9a27ae4cc..10e6016a1ed431 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -168,7 +168,7 @@ ParseResult mlir::tosa::parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
return parser.emitError(parser.getCurrentLocation())
<< "expected attribute";
}
- if (auto typedAttr = attr.dyn_cast<TypedAttr>()) {
+ if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
typeAttr = TypeAttr::get(typedAttr.getType());
}
return success();
@@ -186,7 +186,7 @@ ParseResult mlir::tosa::parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
Attribute attr) {
bool needsSpace = false;
- auto typedAttr = attr.dyn_cast_or_null<TypedAttr>();
+ auto typedAttr = dyn_cast_or_null<TypedAttr>(attr);
if (!typedAttr || typedAttr.getType() != type.getValue()) {
p << ": ";
p.printAttribute(type);
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
index 6575b39fd45a1f..6eef2c5018d6d7 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
@@ -371,7 +371,7 @@ struct ReduceConstantOptimization : public OpRewritePattern<OperationType> {
auto reductionAxis = op.getAxis();
const auto denseElementsAttr = constOp.getValue();
const auto shapedOldElementsValues =
- denseElementsAttr.getType().cast<ShapedType>();
+ cast<ShapedType>(denseElementsAttr.getType());
if (!llvm::isa<IntegerType>(shapedOldElementsValues.getElementType()))
return rewriter.notifyMatchFailure(
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 74ef6381f3d701..c99f62d5ae1124 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -357,7 +357,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
bool levelCheckTransposeConv2d(Operation *op) {
if (auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) {
if (ShapedType filterType =
- transpose.getFilter().getType().dyn_cast<ShapedType>()) {
+ dyn_cast<ShapedType>(transpose.getFilter().getType())) {
auto shape = filterType.getShape();
assert(shape.size() == 4);
// level check kernel sizes for kH and KW
diff --git a/mlir/lib/Dialect/Transform/DebugExtension/DebugExtensionOps.cpp b/mlir/lib/Dialect/Transform/DebugExtension/DebugExtensionOps.cpp
index 4b3e28e4313c64..94d4a96a07ad64 100644
--- a/mlir/lib/Dialect/Transform/DebugExtension/DebugExtensionOps.cpp
+++ b/mlir/lib/Dialect/Transform/DebugExtension/DebugExtensionOps.cpp
@@ -21,16 +21,15 @@ DiagnosedSilenceableFailure
transform::DebugEmitRemarkAtOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
- if (getAt().getType().isa<TransformHandleTypeInterface>()) {
+ if (isa<TransformHandleTypeInterface>(getAt().getType())) {
auto payload = state.getPayloadOps(getAt());
for (Operation *op : payload)
op->emitRemark() << getMessage();
return DiagnosedSilenceableFailure::success();
}
- assert(
- getAt().getType().isa<transform::TransformValueHandleTypeInterface>() &&
- "unhandled kind of transform type");
+ assert(isa<transform::TransformValueHandleTypeInterface>(getAt().getType()) &&
+ "unhandled kind of transform type");
auto describeValue = [](Diagnostic &os, Value value) {
os << "value handle points to ";
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 53f958caa0bdb7..7a5a6974700586 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1615,7 +1615,7 @@ transform::GetTypeOp::apply(transform::TransformRewriter &rewriter,
}
params.push_back(TypeAttr::get(type));
}
- results.setParams(getResult().cast<OpResult>(), params);
+ results.setParams(cast<OpResult>(getResult()), params);
return DiagnosedSilenceableFailure::success();
}
@@ -2217,14 +2217,14 @@ transform::NumAssociationsOp::apply(transform::TransformRewriter &rewriter,
llvm_unreachable("unknown kind of transform dialect type");
return 0;
});
- results.setParams(getNum().cast<OpResult>(),
+ results.setParams(cast<OpResult>(getNum()),
rewriter.getI64IntegerAttr(numAssociations));
return DiagnosedSilenceableFailure::success();
}
LogicalResult transform::NumAssociationsOp::verify() {
// Verify that the result type accepts an i64 attribute as payload.
- auto resultType = getNum().getType().cast<TransformParamTypeInterface>();
+ auto resultType = cast<TransformParamTypeInterface>(getNum().getType());
return resultType
.checkPayload(getLoc(), {Builder(getContext()).getI64IntegerAttr(0)})
.checkAndReport();
diff --git a/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp b/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp
index 8d9f105d1c5dba..9a24c2baebabb2 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp
@@ -44,7 +44,7 @@ DiagnosedSilenceableFailure
transform::AffineMapParamType::checkPayload(Location loc,
ArrayRef<Attribute> payload) const {
for (Attribute attr : payload) {
- if (!attr.isa<AffineMapAttr>()) {
+ if (!mlir::isa<AffineMapAttr>(attr)) {
return emitSilenceableError(loc)
<< "expected affine map attribute, got " << attr;
}
@@ -144,7 +144,7 @@ DiagnosedSilenceableFailure
transform::TypeParamType::checkPayload(Location loc,
ArrayRef<Attribute> payload) const {
for (Attribute attr : payload) {
- if (!attr.isa<TypeAttr>()) {
+ if (!mlir::isa<TypeAttr>(attr)) {
return emitSilenceableError(loc)
<< "expected type attribute, got " << attr;
}
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 3e6425879cc67f..19bbc163775a0f 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6169,8 +6169,7 @@ void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
auto constOperand = adaptor.getInput();
- if (!constOperand.isa_and_nonnull<IntegerAttr, FloatAttr>())
- return {};
+ if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand)) return {};
// SplatElementsAttr::get treats single value for second arg as being a splat.
return SplatElementsAttr::get(getType(), {constOperand});
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index 0693aa596cb28f..b30b43d70bf0f4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -57,7 +57,7 @@ static Value extendMaskRank(OpBuilder &builder, Location loc, Value vec,
Value broadcasted = extendVectorRank(builder, loc, vec, addedRank);
SmallVector<int64_t> permutation;
for (int64_t i = addedRank,
- e = broadcasted.getType().cast<VectorType>().getRank();
+ e = cast<VectorType>(broadcasted.getType()).getRank();
i < e; ++i)
permutation.push_back(i);
for (int64_t i = 0; i < addedRank; ++i)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 8d733c5a8849b6..7ed3dea42b7715 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -403,7 +403,7 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
// Such transposes do not materially effect the underlying vector and can
// be omitted. EG: perm [1, 0, 2] applied to vector<1x1x8xi32>
bool transposeNonOuterUnitDims = false;
- auto operandShape = operands[it.index()].getType().cast<ShapedType>();
+ auto operandShape = cast<ShapedType>(operands[it.index()].getType());
for (auto [index, dim] :
llvm::enumerate(ArrayRef<int64_t>(perm).drop_back(1))) {
if (dim != static_cast<int64_t>(index) &&
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index dc6f126aae4c87..d24721f3defa65 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -63,7 +63,7 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
// new mask index) only happens on the last dimension of the vectors.
Operation *newMask = nullptr;
SmallVector<int64_t> shape(
- maskOp->getResultTypes()[0].cast<VectorType>().getShape());
+ cast<VectorType>(maskOp->getResultTypes()[0]).getShape());
shape.back() = numElements;
auto newMaskType = VectorType::get(shape, rewriter.getI1Type());
if (createMaskOp) {
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
index b844c2bfa837ce..ee622e886f6185 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
@@ -171,7 +171,7 @@ static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
/// is first inserted, followed by a `memref.cast`.
static Value castToCompatibleMemRefType(OpBuilder &b, Value memref,
MemRefType compatibleMemRefType) {
- MemRefType sourceType = memref.getType().cast<MemRefType>();
+ MemRefType sourceType = cast<MemRefType>(memref.getType());
Value res = memref;
if (sourceType.getMemorySpace() != compatibleMemRefType.getMemorySpace()) {
sourceType = MemRefType::get(
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 530c50ef74f7a0..23c5749c2309de 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -127,7 +127,7 @@ LogicalResult CreateNdDescOp::verify() {
// check source type matches the rank if it is a memref.
// It also should have the same ElementType as TensorDesc.
- auto memrefTy = getSourceType().dyn_cast<MemRefType>();
+ auto memrefTy = dyn_cast<MemRefType>(getSourceType());
if (memrefTy) {
invalidRank |= (memrefTy.getRank() != rank);
invalidElemTy |= memrefTy.getElementType() != getElementType();
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index 6cdc2682753fc7..411ac656e4afbd 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -711,7 +711,7 @@ AffineMap mlir::foldAttributesIntoMap(Builder &b, AffineMap map,
for (int64_t i = 0; i < map.getNumDims(); ++i) {
if (auto attr = operands[i].dyn_cast<Attribute>()) {
dimReplacements.push_back(
- b.getAffineConstantExpr(attr.cast<IntegerAttr>().getInt()));
+ b.getAffineConstantExpr(cast<IntegerAttr>(attr).getInt()));
} else {
dimReplacements.push_back(b.getAffineDimExpr(numDims++));
remainingValues.push_back(operands[i].get<Value>());
@@ -721,7 +721,7 @@ AffineMap mlir::foldAttributesIntoMap(Builder &b, AffineMap map,
for (int64_t i = 0; i < map.getNumSymbols(); ++i) {
if (auto attr = operands[i + map.getNumDims()].dyn_cast<Attribute>()) {
symReplacements.push_back(
- b.getAffineConstantExpr(attr.cast<IntegerAttr>().getInt()));
+ b.getAffineConstantExpr(cast<IntegerAttr>(attr).getInt()));
} else {
symReplacements.push_back(b.getAffineSymbolExpr(numSymbols++));
remainingValues.push_back(operands[i + map.getNumDims()].get<Value>());
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index db903d540761b7..0feb078db297d3 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -1154,7 +1154,7 @@ LogicalResult OpTrait::impl::verifySameOperandsAndResultRank(Operation *op) {
// delegate function that returns rank of shaped type with known rank
auto getRank = [](const Type type) {
- return type.cast<ShapedType>().getRank();
+ return cast<ShapedType>(type).getRank();
};
auto rank = !rankedOperandTypes.empty() ? getRank(*rankedOperandTypes.begin())
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index e89ff9209b034a..9cf65223f9d6aa 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2489,8 +2489,7 @@ static void genMapInfos(llvm::IRBuilderBase &builder,
auto addDevInfos = [&, fail](auto devOperands, auto devOpType) -> void {
for (const auto &devOp : devOperands) {
// TODO: Only LLVMPointerTypes are handled.
- if (!devOp.getType().template isa<LLVM::LLVMPointerType>())
- return fail();
+ if (!isa<LLVM::LLVMPointerType>(devOp.getType())) return fail();
llvm::Value *mapOpValue = moduleTranslation.lookupValue(devOp);
@@ -3083,10 +3082,9 @@ convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute,
std::vector<llvm::GlobalVariable *> generatedRefs;
std::vector<llvm::Triple> targetTriple;
- auto targetTripleAttr =
- op->getParentOfType<mlir::ModuleOp>()
- ->getAttr(LLVM::LLVMDialect::getTargetTripleAttrName())
- .dyn_cast_or_null<mlir::StringAttr>();
+ auto targetTripleAttr = dyn_cast_or_null<mlir::StringAttr>(
+ op->getParentOfType<mlir::ModuleOp>()->getAttr(
+ LLVM::LLVMDialect::getTargetTripleAttrName()));
if (targetTripleAttr)
targetTriple.emplace_back(targetTripleAttr.data());
@@ -3328,7 +3326,7 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
attribute.getName())
.Case("omp.is_target_device",
[&](Attribute attr) {
- if (auto deviceAttr = attr.dyn_cast<BoolAttr>()) {
+ if (auto deviceAttr = dyn_cast<BoolAttr>(attr)) {
llvm::OpenMPIRBuilderConfig &config =
moduleTranslation.getOpenMPBuilder()->Config;
config.setIsTargetDevice(deviceAttr.getValue());
@@ -3338,7 +3336,7 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
})
.Case("omp.is_gpu",
[&](Attribute attr) {
- if (auto gpuAttr = attr.dyn_cast<BoolAttr>()) {
+ if (auto gpuAttr = dyn_cast<BoolAttr>(attr)) {
llvm::OpenMPIRBuilderConfig &config =
moduleTranslation.getOpenMPBuilder()->Config;
config.setIsGPU(gpuAttr.getValue());
@@ -3348,7 +3346,7 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
})
.Case("omp.host_ir_filepath",
[&](Attribute attr) {
- if (auto filepathAttr = attr.dyn_cast<StringAttr>()) {
+ if (auto filepathAttr = dyn_cast<StringAttr>(attr)) {
llvm::OpenMPIRBuilder *ompBuilder =
moduleTranslation.getOpenMPBuilder();
ompBuilder->loadOffloadInfoMetadata(filepathAttr.getValue());
@@ -3358,13 +3356,13 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
})
.Case("omp.flags",
[&](Attribute attr) {
- if (auto rtlAttr = attr.dyn_cast<omp::FlagsAttr>())
+ if (auto rtlAttr = dyn_cast<omp::FlagsAttr>(attr))
return convertFlagsAttr(op, rtlAttr, moduleTranslation);
return failure();
})
.Case("omp.version",
[&](Attribute attr) {
- if (auto versionAttr = attr.dyn_cast<omp::VersionAttr>()) {
+ if (auto versionAttr = dyn_cast<omp::VersionAttr>(attr)) {
llvm::OpenMPIRBuilder *ompBuilder =
moduleTranslation.getOpenMPBuilder();
ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp",
@@ -3376,15 +3374,14 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
.Case("omp.declare_target",
[&](Attribute attr) {
if (auto declareTargetAttr =
- attr.dyn_cast<omp::DeclareTargetAttr>())
+ dyn_cast<omp::DeclareTargetAttr>(attr))
return convertDeclareTargetAttr(op, declareTargetAttr,
moduleTranslation);
return failure();
})
.Case("omp.requires",
[&](Attribute attr) {
- if (auto requiresAttr =
- attr.dyn_cast<omp::ClauseRequiresAttr>()) {
+ if (auto requiresAttr = dyn_cast<omp::ClauseRequiresAttr>(attr)) {
using Requires = omp::ClauseRequires;
Requires flags = requiresAttr.getValue();
llvm::OpenMPIRBuilderConfig &config =
diff --git a/mlir/lib/Target/LLVMIR/Dialect/VCIX/VCIXToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/VCIX/VCIXToLLVMIRTranslation.cpp
index 8212725b5a58be..b78b002d322920 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/VCIX/VCIXToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/VCIX/VCIXToLLVMIRTranslation.cpp
@@ -29,8 +29,8 @@ using mlir::LLVM::detail::createIntrinsicCall;
/// option around.
static llvm::Type *getXlenType(Attribute opcodeAttr,
LLVM::ModuleTranslation &moduleTranslation) {
- auto intAttr = opcodeAttr.cast<IntegerAttr>();
- unsigned xlenWidth = intAttr.getType().cast<IntegerType>().getWidth();
+ auto intAttr = cast<IntegerAttr>(opcodeAttr);
+ unsigned xlenWidth = cast<IntegerType>(intAttr.getType()).getWidth();
return llvm::Type::getIntNTy(moduleTranslation.getLLVMContext(), xlenWidth);
}
diff --git a/mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp b/mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp
index c8bee817213d8d..e17fe12b9088bd 100644
--- a/mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp
+++ b/mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp
@@ -25,7 +25,7 @@ namespace {
/// according to LLVM's encoding:
/// https://lists.llvm.org/pipermail/llvm-dev/2020-October/145850.html
static std::pair<unsigned, VectorType> legalizeVectorType(const Type &type) {
- VectorType vt = type.cast<VectorType>();
+ VectorType vt = cast<VectorType>(type);
// To simplify test pass, avoid multi-dimensional vectors.
if (!vt || vt.getRank() != 1)
return {0, nullptr};
@@ -39,7 +39,7 @@ static std::pair<unsigned, VectorType> legalizeVectorType(const Type &type) {
sew = 32;
else if (eltTy.isF64())
sew = 64;
- else if (auto intTy = eltTy.dyn_cast<IntegerType>())
+ else if (auto intTy = dyn_cast<IntegerType>(eltTy))
sew = intTy.getWidth();
else
return {0, nullptr};
diff --git a/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp b/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp
index 9b3082a819224f..5e3918f79d1844 100644
--- a/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp
+++ b/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp
@@ -67,12 +67,11 @@ struct TestMeshReshardingRewritePattern : OpRewritePattern<ShardOp> {
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
ShapedType sourceShardShape =
shardShapedType(op.getResult().getType(), mesh, op.getShard());
- TypedValue<ShapedType> sourceShard =
+ TypedValue<ShapedType> sourceShard = cast<TypedValue<ShapedType>>(
builder
.create<UnrealizedConversionCastOp>(sourceShardShape,
op.getOperand())
- ->getResult(0)
- .cast<TypedValue<ShapedType>>();
+ ->getResult(0));
TypedValue<ShapedType> targetShard =
reshard(builder, mesh, op, targetShardOp, sourceShard);
Value newTargetUnsharded =
diff --git a/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp b/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp
index 2dd99c67c1439b..fa093cafcb0dc3 100644
--- a/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp
+++ b/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp
@@ -61,7 +61,7 @@ LogicalResult TestDialectLLVMIRTranslationInterface::amendOperation(
}
bool createSymbol = false;
- if (auto boolAttr = attr.dyn_cast<BoolAttr>())
+ if (auto boolAttr = dyn_cast<BoolAttr>(attr))
createSymbol = boolAttr.getValue();
if (createSymbol) {
diff --git a/mlir/test/lib/IR/TestAffineWalk.cpp b/mlir/test/lib/IR/TestAffineWalk.cpp
index 8361b48ce42857..e8b836888b459f 100644
--- a/mlir/test/lib/IR/TestAffineWalk.cpp
+++ b/mlir/test/lib/IR/TestAffineWalk.cpp
@@ -44,7 +44,7 @@ void TestAffineWalk::runOnOperation() {
// Test whether the walk is being correctly interrupted.
m.walk([](Operation *op) {
for (NamedAttribute attr : op->getAttrs()) {
- auto mapAttr = attr.getValue().dyn_cast<AffineMapAttr>();
+ auto mapAttr = dyn_cast<AffineMapAttr>(attr.getValue());
if (!mapAttr)
return;
checkMod(mapAttr.getAffineMap(), op->getLoc());
diff --git a/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp b/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp
index 71ed30bfbe34cd..93c4bcfe1424e9 100644
--- a/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp
+++ b/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp
@@ -51,7 +51,7 @@ struct TestElementsAttrInterface
InFlightDiagnostic diag = op->emitError()
<< "Test iterating `" << type << "`: ";
- if (!attr.getElementType().isa<mlir::IntegerType>()) {
+ if (!isa<mlir::IntegerType>(attr.getElementType())) {
diag << "expected element type to be an integer type";
return;
}
diff --git a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
index 56af3c15b905f1..77aa30f847dcd0 100644
--- a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
+++ b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
@@ -61,7 +61,7 @@ static LogicalResult customTypeRangeResultConstraint(PatternRewriter &rewriter,
PDLResultList &results,
ArrayRef<PDLValue> args) {
auto *op = args[0].cast<Operation *>();
- int numTypes = args[1].cast<Attribute>().cast<IntegerAttr>().getInt();
+ int numTypes = cast<IntegerAttr>(args[1].cast<Attribute>()).getInt();
if (op->getName().getStringRef() == "test.success_op") {
SmallVector<Type> types;
More information about the Mlir-commits
mailing list