[Mlir-commits] [mlir] 5550c82 - [mlir] Move casting calls from methods to function calls
Tres Popp
llvmlistbot at llvm.org
Fri May 12 02:52:43 PDT 2023
Author: Tres Popp
Date: 2023-05-12T11:21:25+02:00
New Revision: 5550c821897ab77e664977121a0e90ad5be1ff59
URL: https://github.com/llvm/llvm-project/commit/5550c821897ab77e664977121a0e90ad5be1ff59
DIFF: https://github.com/llvm/llvm-project/commit/5550c821897ab77e664977121a0e90ad5be1ff59.diff
LOG: [mlir] Move casting calls from methods to function calls
The MLIR classes Type/Attribute/Operation/Op/Value support
cast/dyn_cast/isa/dyn_cast_or_null as free functions through LLVM's doCast
functionality, in addition to defining methods with the same names.
This change begins migrating uses of the methods to the corresponding
function calls, the style that was decided on as more consistent.
Note that some classes, such as AffineExpr, still define only the
methods, and this change does not include any work to support a
functional cast/isa call for them.
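To make the two spellings concrete, here is a self-contained sketch. The `Kind`, `Type`, `IntegerType` and `FloatType` definitions are hypothetical, simplified stand-ins: a real mlir::Type wraps a pointer to uniqued storage and reaches the free functions through LLVM's casting infrastructure (the doCast hooks mentioned above), none of which is reproduced here. The sketch only shows that `x.dyn_cast<T>()` and `dyn_cast<T>(x)` ask the same `classof` question.

```cpp
#include <cassert>

// Hypothetical stand-ins for MLIR's value-semantic handle classes. A real
// mlir::Type wraps a pointer to uniqued storage; this one wraps a kind tag.
enum class Kind { Null, Integer, Float };

class Type {
public:
  Type() = default;
  explicit Type(Kind k) : kind(k) {}
  explicit operator bool() const { return kind != Kind::Null; }
  Kind getKind() const { return kind; }

  // Old spelling: casting methods on the handle itself.
  template <typename... Us> bool isa() const {
    assert(*this && "isa<> used on a null type");
    return (Us::classof(*this) || ...);
  }
  template <typename U> U dyn_cast() const {
    return isa<U>() ? U(kind) : U();
  }

private:
  Kind kind = Kind::Null;
};

class IntegerType : public Type {
public:
  IntegerType() = default;
  explicit IntegerType(Kind k) : Type(k) {}
  static bool classof(Type t) { return t.getKind() == Kind::Integer; }
};

class FloatType : public Type {
public:
  FloatType() = default;
  explicit FloatType(Kind k) : Type(k) {}
  static bool classof(Type t) { return t.getKind() == Kind::Float; }
};

// New spelling: free functions that take the handle as an argument.
// isa<> accepts several candidate types, like isa<TokenType, ValueType,
// GroupType>(type) in the Async.h hunk of this commit.
template <typename... Tos> bool isa(Type t) {
  assert(t && "isa<> used on a null type");
  return (Tos::classof(t) || ...);
}

// cast<> asserts that the conversion is valid.
template <typename To> To cast(Type t) {
  assert(isa<To>(t) && "cast<> to an incompatible type");
  return To(t.getKind());
}

// dyn_cast<> returns a null handle when the check fails.
template <typename To> To dyn_cast(Type t) {
  return isa<To>(t) ? To(t.getKind()) : To();
}

// dyn_cast_or_null<> additionally accepts a null input handle.
template <typename To> To dyn_cast_or_null(Type t) {
  return t ? dyn_cast<To>(t) : To();
}
```

In a dependent context the method spelling also needs the `template` keyword, as in the `.template dyn_cast_or_null<arith::FastMathFlagsAttr>()` call that this commit rewrites in AttrToLLVMConverter.h; the free-function spelling does not.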
Caveats include:
- This clang-tidy script probably has more problems.
- This only touches handwritten C++ code, so nothing that is generated is changed.
Context:
- https://mlir.llvm.org/deprecation/, under the entry "Use the free
  function variants for dyn_cast/cast/isa/…"
- Original discussion at https://discourse.llvm.org/t/preferred-casting-style-going-forward/68443
Implementation:
This first patch was created with the following steps. The intention is
to make only automated changes at first, so I waste less time if it is
reverted, and so the first mass change is a clearer example for other
teams that will need to follow similar steps.
The steps are described here rather than as comments in the script,
because git removes comment lines from commit messages:
0. Retrieve the change at the following URL to build clang-tidy with an
   additional check:
   https://github.com/llvm/llvm-project/compare/main...tpopp:llvm-project:tidy-cast-check
1. Build clang-tidy
2. Run clang-tidy over your entire codebase with all checks disabled
   except the relevant one. Run it on all header files as well.
3. Delete the .inc files that were also modified, so that the next
   build regenerates them cleanly.
4. Revert some of the changes, for the following reasons:
   - Some files had a variable that was also named cast.
   - Some files did not include a header file that defines the cast
     functions.
   - Some files define the classes that have the casting methods, so
     inside them an unqualified call still refers to the method instead
     of the function unless a prefix is added or the method declaration
     is removed at the same time.
```
ninja -C "$BUILD_DIR" clang-tidy
run-clang-tidy -clang-tidy-binary="$BUILD_DIR/bin/clang-tidy" -checks='-*,misc-cast-functions' \
  -header-filter=mlir/ mlir/* -fix
shopt -s globstar  # without this, bash treats ** like * and does not recurse
rm -rf "$BUILD_DIR"/tools/mlir/**/*.inc
git restore mlir/lib/IR mlir/lib/Dialect/DLTI/DLTI.cpp \
  mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp \
  mlir/lib/**/IR/ \
  mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp \
  mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp \
  mlir/test/lib/Dialect/Test/TestTypes.cpp \
  mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp \
  mlir/test/lib/Dialect/Test/TestAttributes.cpp \
  mlir/unittests/TableGen/EnumsGenTest.cpp \
  mlir/test/python/lib/PythonTestCAPI.cpp \
  mlir/include/mlir/IR/
```
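Both the first and the third reason for reverting files in step 4 come down to C++ name lookup. A minimal sketch, assuming a hypothetical `demo::cast` in place of `llvm::cast` and a hypothetical `Handle` class in place of an MLIR class that still declares a casting method:

```cpp
#include <cassert>

namespace demo {
// Hypothetical free function standing in for llvm::cast; it only converts
// between arithmetic types so that the example stays self-contained.
template <typename To, typename From> To cast(From value) {
  return static_cast<To>(value);
}
} // namespace demo

using demo::cast;

// Outside such classes, the unqualified free-function spelling resolves
// to demo::cast as intended.
long plain(int x) { return cast<long>(x); }

// Third reason: the definition of a class that still has a `cast` member.
struct Handle {
  int value = 0;

  template <typename To> To cast() const { return static_cast<To>(value); }

  long widen(int x) const {
    // Unqualified lookup inside a member function finds Handle::cast first
    // and stops there, so `cast<long>(x)` would fail to compile (the member
    // takes no arguments). The prefix selects the free function instead.
    return demo::cast<long>(x);
  }
};

// First reason: a variable named `cast` hides the free function.
long doubled(int cast) {
  // Here `cast` names the parameter, so `cast<long>(cast)` would not be a
  // call to the function template; the call must be qualified as well.
  return demo::cast<long>(cast) * 2;
}
```

This is why those files were restored rather than rewritten: fixing them needs either a prefix on each call or removal of the method declaration in the same change.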
Differential Revision: https://reviews.llvm.org/D150123
Added:
Modified:
mlir/include/mlir/Bytecode/BytecodeImplementation.h
mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h
mlir/include/mlir/Debug/BreakpointManagers/FileLineColLocBreakpointManager.h
mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
mlir/include/mlir/Dialect/Arith/IR/Arith.h
mlir/include/mlir/Dialect/Async/IR/Async.h
mlir/include/mlir/Dialect/CommonFolders.h
mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h
mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
mlir/include/mlir/Dialect/Quant/UniformSupport.h
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
mlir/include/mlir/Dialect/Tosa/Utils/ShapeUtils.h
mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
mlir/include/mlir/Interfaces/InferTypeOpInterface.h
mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
mlir/lib/Analysis/Liveness.cpp
mlir/lib/Analysis/SliceAnalysis.cpp
mlir/lib/AsmParser/AsmParserState.cpp
mlir/lib/AsmParser/AttributeParser.cpp
mlir/lib/AsmParser/DialectSymbolParser.cpp
mlir/lib/AsmParser/Parser.cpp
mlir/lib/AsmParser/Parser.h
mlir/lib/AsmParser/TypeParser.cpp
mlir/lib/Bytecode/Reader/BytecodeReader.cpp
mlir/lib/Bytecode/Writer/IRNumbering.cpp
mlir/lib/CAPI/Dialect/PDL.cpp
mlir/lib/CAPI/Dialect/Quant.cpp
mlir/lib/CAPI/Dialect/SparseTensor.cpp
mlir/lib/CAPI/Dialect/Transform.cpp
mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp
mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
mlir/lib/Conversion/LLVMCommon/Pattern.cpp
mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp
mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp
mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp
mlir/lib/Dialect/Affine/Analysis/Utils.cpp
mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp
mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
mlir/lib/Dialect/Affine/Utils/Utils.cpp
mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
mlir/lib/Dialect/Arith/Utils/Utils.cpp
mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp
mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp
mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp
mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp
mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp
mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
mlir/lib/Dialect/Linalg/Transforms/Split.cpp
mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
mlir/lib/Dialect/Linalg/Transforms/SubsetHoisting.cpp
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/Dialect/Linalg/Utils/IndexingUtils.cpp
mlir/lib/Dialect/Linalg/Utils/Utils.cpp
mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp
mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp
mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp
mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp
mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp
mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp
mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp
mlir/lib/Dialect/Tensor/Utils/Utils.cpp
mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp
mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
mlir/lib/Dialect/Traits.cpp
mlir/lib/Dialect/Transform/Transforms/CheckUses.cpp
mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
mlir/lib/Dialect/Utils/StaticValueUtils.cpp
mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp
mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
mlir/lib/ExecutionEngine/JitRunner.cpp
mlir/lib/Interfaces/DataLayoutInterfaces.cpp
mlir/lib/Interfaces/DestinationStyleOpInterface.cpp
mlir/lib/Interfaces/InferIntRangeInterface.cpp
mlir/lib/Interfaces/InferTypeOpInterface.cpp
mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
mlir/lib/Rewrite/ByteCode.cpp
mlir/lib/Target/Cpp/TranslateToCpp.cpp
mlir/lib/Target/LLVMIR/DebugTranslation.cpp
mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
mlir/lib/Target/LLVMIR/Dialect/OpenMPCommon.cpp
mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp
mlir/lib/Target/LLVMIR/ModuleImport.cpp
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
mlir/lib/Target/SPIRV/Serialization/Serializer.h
mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp
mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
mlir/lib/Transforms/CSE.cpp
mlir/lib/Transforms/Inliner.cpp
mlir/lib/Transforms/Mem2Reg.cpp
mlir/lib/Transforms/Utils/DialectConversion.cpp
mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
mlir/lib/Transforms/Utils/RegionUtils.cpp
mlir/lib/Transforms/ViewOpGraph.cpp
mlir/test/lib/Analysis/TestAliasAnalysis.cpp
mlir/test/lib/Analysis/TestMemRefStrideCalculation.cpp
mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp
mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp
mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestPatterns.cpp
mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp
mlir/test/lib/IR/TestDiagnostics.cpp
mlir/test/lib/IR/TestFunc.cpp
mlir/test/lib/IR/TestInterfaces.cpp
mlir/test/lib/IR/TestOpaqueLoc.cpp
mlir/test/lib/IR/TestPrintDefUse.cpp
mlir/test/lib/Transforms/TestTopologicalSort.cpp
mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp
mlir/unittests/IR/AttributeTest.cpp
mlir/unittests/IR/InterfaceAttachmentTest.cpp
mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp
mlir/unittests/Pass/PassManagerTest.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Bytecode/BytecodeImplementation.h b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
index 60f5475609ac7..027df35135683 100644
--- a/mlir/include/mlir/Bytecode/BytecodeImplementation.h
+++ b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
@@ -83,7 +83,7 @@ class DialectBytecodeReader {
Attribute baseResult;
if (failed(readAttribute(baseResult)))
return failure();
- if ((result = baseResult.dyn_cast<T>()))
+ if ((result = dyn_cast<T>(baseResult)))
return success();
return emitError() << "expected " << llvm::getTypeName<T>()
<< ", but got: " << baseResult;
@@ -100,7 +100,7 @@ class DialectBytecodeReader {
Type baseResult;
if (failed(readType(baseResult)))
return failure();
- if ((result = baseResult.dyn_cast<T>()))
+ if ((result = dyn_cast<T>(baseResult)))
return success();
return emitError() << "expected " << llvm::getTypeName<T>()
<< ", but got: " << baseResult;
diff --git a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
index 95b12b69153aa..eea16b4da6a69 100644
--- a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
+++ b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
@@ -38,9 +38,8 @@ class AttrConvertFastMathToLLVM {
// Get the name of the arith fastmath attribute.
llvm::StringRef arithFMFAttrName = SourceOp::getFastMathAttrName();
// Remove the source fastmath attribute.
- auto arithFMFAttr =
- convertedAttr.erase(arithFMFAttrName)
- .template dyn_cast_or_null<arith::FastMathFlagsAttr>();
+ auto arithFMFAttr = dyn_cast_or_null<arith::FastMathFlagsAttr>(
+ convertedAttr.erase(arithFMFAttrName));
if (arithFMFAttr) {
llvm::StringRef targetAttrName = TargetOp::getFastmathAttrName();
convertedAttr.set(targetAttrName,
diff --git a/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h b/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h
index 8684d353aef86..a063623c8305b 100644
--- a/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h
+++ b/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h
@@ -31,7 +31,7 @@ struct AllocationOpLLVMLowering : public ConvertToLLVMPattern {
Value input, Value alignment);
static MemRefType getMemRefResultType(Operation *op) {
- return op->getResult(0).getType().cast<MemRefType>();
+ return cast<MemRefType>(op->getResult(0).getType());
}
/// Computes the alignment for the given memory allocation op.
diff --git a/mlir/include/mlir/Debug/BreakpointManagers/FileLineColLocBreakpointManager.h b/mlir/include/mlir/Debug/BreakpointManagers/FileLineColLocBreakpointManager.h
index d4f9a6ebe7fa4..7def9e25a69ae 100644
--- a/mlir/include/mlir/Debug/BreakpointManagers/FileLineColLocBreakpointManager.h
+++ b/mlir/include/mlir/Debug/BreakpointManagers/FileLineColLocBreakpointManager.h
@@ -95,7 +95,7 @@ class FileLineColLocBreakpointManager
std::optional<Breakpoint *> matchFromLocation(Location initialLoc) const {
std::optional<Breakpoint *> match = std::nullopt;
initialLoc->walk([&](Location loc) {
- auto fileLoc = loc.dyn_cast<FileLineColLoc>();
+ auto fileLoc = dyn_cast<FileLineColLoc>(loc);
if (!fileLoc)
return WalkResult::advance();
StringRef file = fileLoc.getFilename();
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
index 008d398a080d6..1409d52d64a6f 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
@@ -106,7 +106,7 @@ class AffineDmaStartOp
/// Returns the source MemRefType for this DMA operation.
Value getSrcMemRef() { return getOperand(getSrcMemRefOperandIndex()); }
MemRefType getSrcMemRefType() {
- return getSrcMemRef().getType().cast<MemRefType>();
+ return cast<MemRefType>(getSrcMemRef().getType());
}
/// Returns the rank (number of indices) of the source MemRefType.
@@ -115,7 +115,7 @@ class AffineDmaStartOp
/// Returns the affine map used to access the source memref.
AffineMap getSrcMap() { return getSrcMapAttr().getValue(); }
AffineMapAttr getSrcMapAttr() {
- return (*this)->getAttr(getSrcMapAttrStrName()).cast<AffineMapAttr>();
+ return cast<AffineMapAttr>((*this)->getAttr(getSrcMapAttrStrName()));
}
/// Returns the source memref affine map indices for this DMA operation.
@@ -127,7 +127,7 @@ class AffineDmaStartOp
/// Returns the memory space of the source memref.
unsigned getSrcMemorySpace() {
- return getSrcMemRef().getType().cast<MemRefType>().getMemorySpaceAsInt();
+ return cast<MemRefType>(getSrcMemRef().getType()).getMemorySpaceAsInt();
}
/// Returns the operand index of the destination memref.
@@ -138,23 +138,23 @@ class AffineDmaStartOp
/// Returns the destination MemRefType for this DMA operation.
Value getDstMemRef() { return getOperand(getDstMemRefOperandIndex()); }
MemRefType getDstMemRefType() {
- return getDstMemRef().getType().cast<MemRefType>();
+ return cast<MemRefType>(getDstMemRef().getType());
}
/// Returns the rank (number of indices) of the destination MemRefType.
unsigned getDstMemRefRank() {
- return getDstMemRef().getType().cast<MemRefType>().getRank();
+ return cast<MemRefType>(getDstMemRef().getType()).getRank();
}
/// Returns the memory space of the source memref.
unsigned getDstMemorySpace() {
- return getDstMemRef().getType().cast<MemRefType>().getMemorySpaceAsInt();
+ return cast<MemRefType>(getDstMemRef().getType()).getMemorySpaceAsInt();
}
/// Returns the affine map used to access the destination memref.
AffineMap getDstMap() { return getDstMapAttr().getValue(); }
AffineMapAttr getDstMapAttr() {
- return (*this)->getAttr(getDstMapAttrStrName()).cast<AffineMapAttr>();
+ return cast<AffineMapAttr>((*this)->getAttr(getDstMapAttrStrName()));
}
/// Returns the destination memref indices for this DMA operation.
@@ -172,18 +172,18 @@ class AffineDmaStartOp
/// Returns the Tag MemRef for this DMA operation.
Value getTagMemRef() { return getOperand(getTagMemRefOperandIndex()); }
MemRefType getTagMemRefType() {
- return getTagMemRef().getType().cast<MemRefType>();
+ return cast<MemRefType>(getTagMemRef().getType());
}
/// Returns the rank (number of indices) of the tag MemRefType.
unsigned getTagMemRefRank() {
- return getTagMemRef().getType().cast<MemRefType>().getRank();
+ return cast<MemRefType>(getTagMemRef().getType()).getRank();
}
/// Returns the affine map used to access the tag memref.
AffineMap getTagMap() { return getTagMapAttr().getValue(); }
AffineMapAttr getTagMapAttr() {
- return (*this)->getAttr(getTagMapAttrStrName()).cast<AffineMapAttr>();
+ return cast<AffineMapAttr>((*this)->getAttr(getTagMapAttrStrName()));
}
/// Returns the tag memref indices for this DMA operation.
@@ -299,13 +299,13 @@ class AffineDmaWaitOp
/// Returns the Tag MemRef associated with the DMA operation being waited on.
Value getTagMemRef() { return getOperand(0); }
MemRefType getTagMemRefType() {
- return getTagMemRef().getType().cast<MemRefType>();
+ return cast<MemRefType>(getTagMemRef().getType());
}
/// Returns the affine map used to access the tag memref.
AffineMap getTagMap() { return getTagMapAttr().getValue(); }
AffineMapAttr getTagMapAttr() {
- return (*this)->getAttr(getTagMapAttrStrName()).cast<AffineMapAttr>();
+ return cast<AffineMapAttr>((*this)->getAttr(getTagMapAttrStrName()));
}
/// Returns the tag memref index for this DMA operation.
@@ -316,7 +316,7 @@ class AffineDmaWaitOp
/// Returns the rank (number of indices) of the tag memref.
unsigned getTagMemRefRank() {
- return getTagMemRef().getType().cast<MemRefType>().getRank();
+ return cast<MemRefType>(getTagMemRef().getType()).getRank();
}
/// Impelements the AffineMapAccessInterface. Returns the AffineMapAttr
diff --git a/mlir/include/mlir/Dialect/Arith/IR/Arith.h b/mlir/include/mlir/Dialect/Arith/IR/Arith.h
index f285262982816..1b516ff1e5aa1 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/Arith.h
+++ b/mlir/include/mlir/Dialect/Arith/IR/Arith.h
@@ -63,7 +63,7 @@ class ConstantIntOp : public arith::ConstantOp {
Type type);
inline int64_t value() {
- return arith::ConstantOp::getValue().cast<IntegerAttr>().getInt();
+ return cast<IntegerAttr>(arith::ConstantOp::getValue()).getInt();
}
static bool classof(Operation *op);
@@ -79,7 +79,7 @@ class ConstantFloatOp : public arith::ConstantOp {
const APFloat &value, FloatType type);
inline APFloat value() {
- return arith::ConstantOp::getValue().cast<FloatAttr>().getValue();
+ return cast<FloatAttr>(arith::ConstantOp::getValue()).getValue();
}
static bool classof(Operation *op);
@@ -94,7 +94,7 @@ class ConstantIndexOp : public arith::ConstantOp {
static void build(OpBuilder &builder, OperationState &result, int64_t value);
inline int64_t value() {
- return arith::ConstantOp::getValue().cast<IntegerAttr>().getInt();
+ return cast<IntegerAttr>(arith::ConstantOp::getValue()).getInt();
}
static bool classof(Operation *op);
diff --git a/mlir/include/mlir/Dialect/Async/IR/Async.h b/mlir/include/mlir/Dialect/Async/IR/Async.h
index 585a231d24739..9265e2f8065a1 100644
--- a/mlir/include/mlir/Dialect/Async/IR/Async.h
+++ b/mlir/include/mlir/Dialect/Async/IR/Async.h
@@ -49,7 +49,7 @@ namespace async {
/// Returns true if the type is reference counted at runtime.
inline bool isRefCounted(Type type) {
- return type.isa<TokenType, ValueType, GroupType>();
+ return isa<TokenType, ValueType, GroupType>(type);
}
} // namespace async
diff --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h
index 8007027a86514..d3fbc723632a3 100644
--- a/mlir/include/mlir/Dialect/CommonFolders.h
+++ b/mlir/include/mlir/Dialect/CommonFolders.h
@@ -36,9 +36,9 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
if (!resultType || !operands[0] || !operands[1])
return {};
- if (operands[0].isa<AttrElementT>() && operands[1].isa<AttrElementT>()) {
- auto lhs = operands[0].cast<AttrElementT>();
- auto rhs = operands[1].cast<AttrElementT>();
+ if (isa<AttrElementT>(operands[0]) && isa<AttrElementT>(operands[1])) {
+ auto lhs = cast<AttrElementT>(operands[0]);
+ auto rhs = cast<AttrElementT>(operands[1]);
if (lhs.getType() != rhs.getType())
return {};
@@ -50,12 +50,12 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
return AttrElementT::get(resultType, *calRes);
}
- if (operands[0].isa<SplatElementsAttr>() &&
- operands[1].isa<SplatElementsAttr>()) {
+ if (isa<SplatElementsAttr>(operands[0]) &&
+ isa<SplatElementsAttr>(operands[1])) {
// Both operands are splats so we can avoid expanding the values out and
// just fold based on the splat value.
- auto lhs = operands[0].cast<SplatElementsAttr>();
- auto rhs = operands[1].cast<SplatElementsAttr>();
+ auto lhs = cast<SplatElementsAttr>(operands[0]);
+ auto rhs = cast<SplatElementsAttr>(operands[1]);
if (lhs.getType() != rhs.getType())
return {};
@@ -67,11 +67,11 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
return DenseElementsAttr::get(cast<ShapedType>(resultType), *elementResult);
}
- if (operands[0].isa<ElementsAttr>() && operands[1].isa<ElementsAttr>()) {
+ if (isa<ElementsAttr>(operands[0]) && isa<ElementsAttr>(operands[1])) {
// Operands are ElementsAttr-derived; perform an element-wise fold by
// expanding the values.
- auto lhs = operands[0].cast<ElementsAttr>();
- auto rhs = operands[1].cast<ElementsAttr>();
+ auto lhs = cast<ElementsAttr>(operands[0]);
+ auto rhs = cast<ElementsAttr>(operands[1]);
if (lhs.getType() != rhs.getType())
return {};
@@ -103,7 +103,7 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
const CalculationT &calculate) {
assert(operands.size() == 2 && "binary op takes two operands");
auto getResultType = [](Attribute attr) -> Type {
- if (auto typed = attr.dyn_cast_or_null<TypedAttr>())
+ if (auto typed = dyn_cast_or_null<TypedAttr>(attr))
return typed.getType();
return {};
};
@@ -158,27 +158,27 @@ Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
if (!operands[0])
return {};
- if (operands[0].isa<AttrElementT>()) {
- auto op = operands[0].cast<AttrElementT>();
+ if (isa<AttrElementT>(operands[0])) {
+ auto op = cast<AttrElementT>(operands[0]);
auto res = calculate(op.getValue());
if (!res)
return {};
return AttrElementT::get(op.getType(), *res);
}
- if (operands[0].isa<SplatElementsAttr>()) {
+ if (isa<SplatElementsAttr>(operands[0])) {
// Both operands are splats so we can avoid expanding the values out and
// just fold based on the splat value.
- auto op = operands[0].cast<SplatElementsAttr>();
+ auto op = cast<SplatElementsAttr>(operands[0]);
auto elementResult = calculate(op.getSplatValue<ElementValueT>());
if (!elementResult)
return {};
return DenseElementsAttr::get(op.getType(), *elementResult);
- } else if (operands[0].isa<ElementsAttr>()) {
+ } else if (isa<ElementsAttr>(operands[0])) {
// Operands are ElementsAttr-derived; perform an element-wise fold by
// expanding the values.
- auto op = operands[0].cast<ElementsAttr>();
+ auto op = cast<ElementsAttr>(operands[0]);
auto opIt = op.value_begin<ElementValueT>();
SmallVector<ElementValueT> elementResults;
@@ -216,18 +216,18 @@ Attribute constFoldCastOp(ArrayRef<Attribute> operands, Type resType,
if (!operands[0])
return {};
- if (operands[0].isa<AttrElementT>()) {
- auto op = operands[0].cast<AttrElementT>();
+ if (isa<AttrElementT>(operands[0])) {
+ auto op = cast<AttrElementT>(operands[0]);
bool castStatus = true;
auto res = calculate(op.getValue(), castStatus);
if (!castStatus)
return {};
return TargetAttrElementT::get(resType, res);
}
- if (operands[0].isa<SplatElementsAttr>()) {
+ if (isa<SplatElementsAttr>(operands[0])) {
// The operand is a splat so we can avoid expanding the values out and
// just fold based on the splat value.
- auto op = operands[0].cast<SplatElementsAttr>();
+ auto op = cast<SplatElementsAttr>(operands[0]);
bool castStatus = true;
auto elementResult =
calculate(op.getSplatValue<ElementValueT>(), castStatus);
@@ -235,10 +235,10 @@ Attribute constFoldCastOp(ArrayRef<Attribute> operands, Type resType,
return {};
return DenseElementsAttr::get(cast<ShapedType>(resType), elementResult);
}
- if (operands[0].isa<ElementsAttr>()) {
+ if (isa<ElementsAttr>(operands[0])) {
// Operand is ElementsAttr-derived; perform an element-wise fold by
// expanding the value.
- auto op = operands[0].cast<ElementsAttr>();
+ auto op = cast<ElementsAttr>(operands[0]);
bool castStatus = true;
auto opIt = op.value_begin<ElementValueT>();
SmallVector<TargetElementValueT> elementResults;
diff --git a/mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h b/mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h
index 29bab1dec8638..4af4c65148da3 100644
--- a/mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h
+++ b/mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h
@@ -73,7 +73,7 @@ class ValueDecomposer {
[callback = std::forward<FnT>(callback)](
OpBuilder &builder, Location loc, Type type, Value value,
SmallVectorImpl<Value> &newValues) -> std::optional<LogicalResult> {
- if (T derivedType = type.dyn_cast<T>())
+ if (T derivedType = dyn_cast<T>(type))
return callback(builder, loc, derivedType, value, newValues);
return std::nullopt;
};
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
index 728f95291a699..3725cdd838b9e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
@@ -222,7 +222,7 @@ SmallVector<IntT> convertArrayToIndices(ArrayRef<Attribute> attrs) {
SmallVector<IntT> indices;
indices.reserve(attrs.size());
for (Attribute attr : attrs)
- indices.push_back(attr.cast<IntegerAttr>().getInt());
+ indices.push_back(cast<IntegerAttr>(attr).getInt());
return indices;
}
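The hunks above replace member calls such as `attr.cast<IntegerAttr>()` with free functions such as `cast<IntegerAttr>(attr)`. To illustrate the mechanism the commit message refers to, here is a minimal, self-contained C++ sketch of LLVM-style free-function casting driven by a static `classof` predicate. `Shape`, `Circle`, and `Square` are hypothetical classes, and these `isa`/`cast`/`dyn_cast` are simplified pointer-only stand-ins. LLVM's real versions are implemented through trait classes (`CastInfo`), which lets them also work on MLIR value handles such as `Type`, `Attribute`, and `Value`.

```cpp
#include <cassert>

// Hypothetical base class carrying a kind tag (LLVM-style RTTI, no virtuals).
struct Shape {
  enum class Kind { Circle, Square };
  explicit Shape(Kind k) : kind(k) {}
  Kind getKind() const { return kind; }

private:
  Kind kind;
};

// Each hypothetical subclass answers "is this base pointer one of me?".
struct Circle : Shape {
  Circle() : Shape(Kind::Circle) {}
  static bool classof(const Shape *s) { return s->getKind() == Kind::Circle; }
  double radius = 1.0;
};

struct Square : Shape {
  Square() : Shape(Kind::Square) {}
  static bool classof(const Shape *s) { return s->getKind() == Kind::Square; }
};

// Simplified free-function casts in the spirit of llvm::isa/cast/dyn_cast.
template <typename To, typename From>
bool isa(const From *v) {
  assert(v && "isa<> used on a null pointer");
  return To::classof(v);
}

// cast: the caller guarantees the type; a mismatch trips the assertion.
template <typename To, typename From>
To *cast(From *v) {
  assert(isa<To>(v) && "cast<Ty>() argument of incompatible type!");
  return static_cast<To *>(v);
}

// dyn_cast: checks the type and returns nullptr on a mismatch.
template <typename To, typename From>
To *dyn_cast(From *v) {
  return isa<To>(v) ? static_cast<To *>(v) : nullptr;
}
```

With these in place, `isa<Circle>(s)`, `cast<Circle>(s)`, and `dyn_cast<Square>(s)` read the same way as the rewritten MLIR calls: `cast` asserts on a type mismatch, while `dyn_cast` returns null.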
diff --git a/mlir/include/mlir/Dialect/Quant/UniformSupport.h b/mlir/include/mlir/Dialect/Quant/UniformSupport.h
index 2a26aa8e557f1..9ea6d1729e960 100644
--- a/mlir/include/mlir/Dialect/Quant/UniformSupport.h
+++ b/mlir/include/mlir/Dialect/Quant/UniformSupport.h
@@ -67,7 +67,7 @@ class UniformQuantizedValueConverter {
static_cast<double>(uniformType.getStorageTypeMin()),
static_cast<double>(uniformType.getStorageTypeMax()),
uniformType.getStorageTypeIntegralWidth(), uniformType.isSigned()) {
- assert(uniformType.getExpressedType().isa<FloatType>());
+ assert(isa<FloatType>(uniformType.getExpressedType()));
assert(uniformType.getStorageType().isSignlessInteger());
}
@@ -184,7 +184,7 @@ class UniformQuantizedPerAxisValueConverter {
storageBitWidth(uniformType.getStorageTypeIntegralWidth()),
isSigned(uniformType.isSigned()),
quantizationDim(uniformType.getQuantizedDimension()) {
- assert(uniformType.getExpressedType().isa<FloatType>());
+ assert(isa<FloatType>(uniformType.getExpressedType()));
assert(uniformType.getStorageType().isSignlessInteger());
assert(scales.size() == zeroPoints.size());
}
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 481c2e629198c..72c1da8714b54 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -95,7 +95,7 @@ template <typename T>
inline RankedTensorType getRankedTensorType(T &&t) {
assert(static_cast<bool>(std::forward<T>(t)) &&
"getRankedTensorType got null argument");
- return std::forward<T>(t).getType().template cast<RankedTensorType>();
+ return cast<RankedTensorType>(std::forward<T>(t).getType());
}
/// Convenience method to abbreviate casting `getType()`.
@@ -103,7 +103,7 @@ template <typename T>
inline MemRefType getMemRefType(T &&t) {
assert(static_cast<bool>(std::forward<T>(t)) &&
"getMemRefType got null argument");
- return std::forward<T>(t).getType().template cast<MemRefType>();
+ return cast<MemRefType>(std::forward<T>(t).getType());
}
/// Convenience method to get a sparse encoding attribute from a type.
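The `SparseTensor.h` hunks also drop the `template` disambiguator. When the object expression's type depends on a template parameter, as with `std::forward<T>(t).getType()`, a member template call is generally written `.template cast<...>()`, whereas a free function template is found by ordinary lookup and needs no such keyword. Below is a small self-contained sketch of the two spellings, using hypothetical `Handle`, `Wrapped`, and `Holder` types in place of the MLIR ones.

```cpp
#include <cassert>

// Hypothetical target type produced by the casts below.
struct Wrapped {
  int kind;
};

// Hypothetical handle exposing a member-template cast (the old method style).
struct Handle {
  int kind;
  template <typename T> T cast() const { return T{kind}; }
};

// Hypothetical free-function equivalent (the new style).
template <typename T> T cast(const Handle &h) { return T{h.kind}; }

// Hypothetical owner whose accessor returns a Handle, similar to getType().
struct Holder {
  Handle h;
  Handle getHandle() const { return h; }
};

// U is a template parameter, so `u.getHandle()` has a dependent type; the
// `template` keyword makes `<` parse as the start of template arguments.
template <typename U> Wrapped viaMethod(const U &u) {
  return u.getHandle().template cast<Wrapped>();
}

// The free function template is found by unqualified lookup, so no
// disambiguator is needed.
template <typename U> Wrapped viaFreeFunction(const U &u) {
  return cast<Wrapped>(u.getHandle());
}
```

Both functions produce the same result; only the spelling differs, which is why the rewritten lines in `getRankedTensorType`, `getMemRefType`, and `foldReshapeOp` become shorter.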
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
index b36e40bb8a9ff..f425d376fbc7e 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
@@ -51,7 +51,7 @@ checkHasDynamicBatchDims(PatternRewriter &rewriter, Op op,
SmallVector<ShapedType> dynTypes;
SmallVector<Value> dynamicDims;
for (const Value &param : params) {
- auto paramTy = param.getType().cast<ShapedType>();
+ auto paramTy = cast<ShapedType>(param.getType());
if (!paramTy.hasStaticShape())
dynTypes.push_back(paramTy);
}
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/ShapeUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ShapeUtils.h
index cc846f2dd094a..bdacb98710d33 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/ShapeUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/ShapeUtils.h
@@ -44,7 +44,7 @@ struct ValueKnowledge {
// Get the static knowledge intrinsic to `type`.
static ValueKnowledge getKnowledgeFromType(Type type) {
ValueKnowledge result = getPessimisticValueState();
- if (auto shapedType = type.dyn_cast<ShapedType>()) {
+ if (auto shapedType = dyn_cast<ShapedType>(type)) {
if (shapedType.hasRank()) {
result.hasRank = true;
result.sizes.reserve(shapedType.getRank());
diff --git a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
index d362524528f1f..8f8adffae50e3 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
@@ -42,7 +42,7 @@ class SingleOpMatcherOpTrait
"SingleOpMatchOpTrait is only available on operations with "
"MatchOpInterface");
Value operandHandle = cast<OpTy>(op).getOperandHandle();
- if (!operandHandle.getType().isa<TransformHandleTypeInterface>()) {
+ if (!isa<TransformHandleTypeInterface>(operandHandle.getType())) {
return op->emitError() << "SingleOpMatchOpTrait requires the op handle "
"to be of TransformHandleTypeInterface";
}
@@ -82,7 +82,7 @@ class SingleValueMatcherOpTrait
"MatchOpInterface");
Value operandHandle = cast<OpTy>(op).getOperandHandle();
- if (!operandHandle.getType().isa<TransformValueHandleTypeInterface>()) {
+ if (!isa<TransformValueHandleTypeInterface>(operandHandle.getType())) {
return op->emitError() << "SingleValueMatchOpTrait requires an operand "
"of TransformValueHandleTypeInterface";
}
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index e2e2354876524..39a86d3828786 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -1144,9 +1144,9 @@ mlir::transform::TransformEachOpTrait<OpTy>::apply(
SmallVector<Operation *> emptyPayload;
SmallVector<Attribute> emptyParams;
for (OpResult r : this->getOperation()->getResults()) {
- if (r.getType().isa<TransformParamTypeInterface>())
+ if (isa<TransformParamTypeInterface>(r.getType()))
transformResults.setParams(r, emptyParams);
- else if (r.getType().isa<TransformValueHandleTypeInterface>())
+ else if (isa<TransformValueHandleTypeInterface>(r.getType()))
transformResults.setValues(r, ValueRange());
else
transformResults.set(r, emptyPayload);
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index b7776110d444f..61c929dee0f27 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -92,9 +92,8 @@ static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
if (reshapeSrcOp && reshapeSrcOp.getSrcType() == reshapeOp.getResultType())
return reshapeSrcOp.getSrc();
// Reshape of a constant can be replaced with a new constant.
- if (auto elements = operands.front().dyn_cast_or_null<DenseElementsAttr>()) {
- return elements.reshape(
- reshapeOp.getResult().getType().template cast<ShapedType>());
+ if (auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front())) {
+ return elements.reshape(cast<ShapedType>(reshapeOp.getResult().getType()));
}
return nullptr;
}
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index ed25021e421f4..2edb239d5a61b 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -169,12 +169,12 @@ Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind,
/// Returns true if `attr` has "parallel" iterator type semantics.
inline bool isParallelIterator(Attribute attr) {
- return attr.cast<IteratorTypeAttr>().getValue() == IteratorType::parallel;
+ return cast<IteratorTypeAttr>(attr).getValue() == IteratorType::parallel;
}
/// Returns true if `attr` has "reduction" iterator type semantics.
inline bool isReductionIterator(Attribute attr) {
- return attr.cast<IteratorTypeAttr>().getValue() == IteratorType::reduction;
+ return cast<IteratorTypeAttr>(attr).getValue() == IteratorType::reduction;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
index acfe40484cf63..4ead5ec40ed93 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
@@ -39,11 +39,11 @@ reifyResultShapes(OpBuilder &b, Operation *op,
class ShapeAdaptor {
public:
ShapeAdaptor(Type t) {
- if (auto st = t.dyn_cast<ShapedType>())
+ if (auto st = dyn_cast<ShapedType>(t))
val = st;
}
ShapeAdaptor(Attribute t) {
- if (auto da = t.dyn_cast<DenseIntElementsAttr>())
+ if (auto da = dyn_cast<DenseIntElementsAttr>(t))
val = da;
}
ShapeAdaptor(ShapedTypeComponents *components) : val(components) {}
diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
index ac71b73ae1e34..29e1cf53abbae 100644
--- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -244,7 +244,7 @@ struct DstValueBoundsOpInterfaceExternalModel
auto dstOp = cast<DestinationStyleOpInterface>(op);
assert(value.getDefiningOp() == dstOp);
- Value tiedOperand = dstOp.getTiedOpOperand(value.cast<OpResult>())->get();
+ Value tiedOperand = dstOp.getTiedOpOperand(cast<OpResult>(value))->get();
cstr.bound(value)[dim] == cstr.getExpr(tiedOperand, dim);
}
};
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index c31ac29a4cd84..020c8ce9ab4ed 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -358,7 +358,7 @@ class TypeConverter {
return [callback = std::forward<FnT>(callback)](
Type type, SmallVectorImpl<Type> &results,
ArrayRef<Type> callStack) -> std::optional<LogicalResult> {
- T derivedType = type.dyn_cast<T>();
+ T derivedType = dyn_cast<T>(type);
if (!derivedType)
return std::nullopt;
return callback(derivedType, results, callStack);
@@ -380,7 +380,7 @@ class TypeConverter {
return [callback = std::forward<FnT>(callback)](
OpBuilder &builder, Type resultType, ValueRange inputs,
Location loc) -> std::optional<Value> {
- if (T derivedType = resultType.dyn_cast<T>())
+ if (T derivedType = dyn_cast<T>(resultType))
return callback(builder, derivedType, inputs, loc);
return std::nullopt;
};
@@ -395,8 +395,8 @@ class TypeConverter {
wrapTypeAttributeConversion(FnT &&callback) {
return [callback = std::forward<FnT>(callback)](
Type type, Attribute attr) -> AttributeConversionResult {
- if (T derivedType = type.dyn_cast<T>()) {
- if (A derivedAttr = attr.dyn_cast_or_null<A>())
+ if (T derivedType = dyn_cast<T>(type)) {
+ if (A derivedAttr = dyn_cast_or_null<A>(attr))
return callback(derivedType, derivedAttr);
}
return AttributeConversionResult::na();
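Some rewritten call sites keep `dyn_cast_or_null` rather than `dyn_cast`, for example `dyn_cast_or_null<A>(attr)` in `wrapTypeAttributeConversion`, where `attr` may be null. The sketch below shows the distinction using hypothetical pointer-based `Node`/`Leaf` classes instead of MLIR's nullable value handles. `dyn_cast` expects a non-null argument (LLVM's version asserts on one), while `dyn_cast_or_null` passes null through.

```cpp
#include <cassert>

// Hypothetical minimal hierarchy using a kind tag and classof().
struct Node {
  enum class Kind { Leaf, Branch };
  explicit Node(Kind k) : kind(k) {}
  Kind kind;
};

struct Leaf : Node {
  Leaf() : Node(Kind::Leaf) {}
  static bool classof(const Node *n) { return n->kind == Kind::Leaf; }
};

// dyn_cast: the argument must be non-null; returns nullptr on a type mismatch.
template <typename To, typename From>
To *dyn_cast(From *v) {
  assert(v && "dyn_cast on a null pointer");
  return To::classof(v) ? static_cast<To *>(v) : nullptr;
}

// dyn_cast_or_null: additionally accepts a null argument and returns null.
template <typename To, typename From>
To *dyn_cast_or_null(From *v) {
  return v ? dyn_cast<To>(v) : nullptr;
}
```

Choosing `dyn_cast_or_null` only where the input can legitimately be null keeps the assertion in `dyn_cast` useful for catching unexpected nulls elsewhere.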
diff --git a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
index 73ddd81aa12b4..f205fabbac7c7 100644
--- a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
+++ b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
@@ -59,11 +59,11 @@ static void collectUnderlyingAddressValues(RegionBranchOpInterface branch,
}
unsigned firstInputIndex, lastInputIndex;
if (region) {
- firstInputIndex = inputs[0].cast<BlockArgument>().getArgNumber();
- lastInputIndex = inputs.back().cast<BlockArgument>().getArgNumber();
+ firstInputIndex = cast<BlockArgument>(inputs[0]).getArgNumber();
+ lastInputIndex = cast<BlockArgument>(inputs.back()).getArgNumber();
} else {
- firstInputIndex = inputs[0].cast<OpResult>().getResultNumber();
- lastInputIndex = inputs.back().cast<OpResult>().getResultNumber();
+ firstInputIndex = cast<OpResult>(inputs[0]).getResultNumber();
+ lastInputIndex = cast<OpResult>(inputs.back()).getResultNumber();
}
if (firstInputIndex > inputIndex || lastInputIndex < inputIndex) {
output.push_back(inputValue);
@@ -186,9 +186,9 @@ static void collectUnderlyingAddressValues(Value value, unsigned maxDepth,
}
--maxDepth;
- if (BlockArgument arg = value.dyn_cast<BlockArgument>())
+ if (BlockArgument arg = dyn_cast<BlockArgument>(value))
return collectUnderlyingAddressValues(arg, maxDepth, visited, output);
- collectUnderlyingAddressValues(value.cast<OpResult>(), maxDepth, visited,
+ collectUnderlyingAddressValues(cast<OpResult>(value), maxDepth, visited,
output);
}
@@ -216,10 +216,10 @@ getAllocEffectFor(Value value,
Operation *&allocScopeOp) {
// Try to get a memory effect interface for the parent operation.
Operation *op;
- if (BlockArgument arg = value.dyn_cast<BlockArgument>())
+ if (BlockArgument arg = dyn_cast<BlockArgument>(value))
op = arg.getOwner()->getParentOp();
else
- op = value.cast<OpResult>().getOwner();
+ op = cast<OpResult>(value).getOwner();
MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
if (!interface)
return failure();
@@ -305,7 +305,7 @@ AliasResult LocalAliasAnalysis::aliasImpl(Value lhs, Value rhs) {
if (rhsParentOp->isProperAncestor(lhsAllocScope))
return AliasResult::NoAlias;
if (rhsParentOp == lhsAllocScope) {
- BlockArgument rhsArg = rhs.dyn_cast<BlockArgument>();
+ BlockArgument rhsArg = dyn_cast<BlockArgument>(rhs);
if (rhsArg && rhs.getParentBlock()->isEntryBlock())
return AliasResult::NoAlias;
}
diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
index aa079cff5cd7c..c866fc610bc8e 100644
--- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
@@ -94,7 +94,7 @@ void IntegerRangeAnalysis::visitOperation(
}));
auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
- auto result = v.dyn_cast<OpResult>();
+ auto result = dyn_cast<OpResult>(v);
if (!result)
return;
assert(llvm::is_contained(op->getResults(), result));
@@ -139,7 +139,7 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
}));
auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
- auto arg = v.dyn_cast<BlockArgument>();
+ auto arg = dyn_cast<BlockArgument>(v);
if (!arg)
return;
if (!llvm::is_contained(successor.getSuccessor()->getArguments(), arg))
@@ -179,7 +179,7 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
if (loopBound.has_value()) {
if (loopBound->is<Attribute>()) {
if (auto bound =
- loopBound->get<Attribute>().dyn_cast_or_null<IntegerAttr>())
+ dyn_cast_or_null<IntegerAttr>(loopBound->get<Attribute>()))
return bound.getValue();
} else if (auto value = loopBound->dyn_cast<Value>()) {
const IntegerValueRangeLattice *lattice =
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index d3bc80603ad18..629c482edab22 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -240,7 +240,7 @@ void AbstractSparseDataFlowAnalysis::visitRegionSuccessors(
if (inputs.size() != lattices.size()) {
if (point.dyn_cast<Operation *>()) {
if (!inputs.empty())
- firstIndex = inputs.front().cast<OpResult>().getResultNumber();
+ firstIndex = cast<OpResult>(inputs.front()).getResultNumber();
visitNonControlFlowArgumentsImpl(
branch,
RegionSuccessor(
@@ -248,7 +248,7 @@ void AbstractSparseDataFlowAnalysis::visitRegionSuccessors(
lattices, firstIndex);
} else {
if (!inputs.empty())
- firstIndex = inputs.front().cast<BlockArgument>().getArgNumber();
+ firstIndex = cast<BlockArgument>(inputs.front()).getArgNumber();
Region *region = point.get<Block *>()->getParent();
visitNonControlFlowArgumentsImpl(
branch,
diff --git a/mlir/lib/Analysis/Liveness.cpp b/mlir/lib/Analysis/Liveness.cpp
index 7c04bb4ade79e..a8e0daeabf406 100644
--- a/mlir/lib/Analysis/Liveness.cpp
+++ b/mlir/lib/Analysis/Liveness.cpp
@@ -184,7 +184,7 @@ Liveness::OperationListT Liveness::resolveLiveness(Value value) const {
if (Operation *defOp = value.getDefiningOp())
currentBlock = defOp->getBlock();
else
- currentBlock = value.cast<BlockArgument>().getOwner();
+ currentBlock = cast<BlockArgument>(value).getOwner();
toProcess.push_back(currentBlock);
visited.insert(currentBlock);
@@ -280,7 +280,7 @@ void Liveness::print(raw_ostream &os) const {
if (value.getDefiningOp())
os << "val_" << valueIds[value];
else {
- auto blockArg = value.cast<BlockArgument>();
+ auto blockArg = cast<BlockArgument>(value);
os << "arg" << blockArg.getArgNumber() << "@"
<< blockIds[blockArg.getOwner()];
}
@@ -404,7 +404,7 @@ LivenessBlockInfo::currentlyLiveValues(Operation *op) const {
Operation *endOfLiveRange = nullptr;
// If it's a live in or a block argument, then the start is the beginning
// of the block.
- if (isLiveIn(value) || value.isa<BlockArgument>())
+ if (isLiveIn(value) || isa<BlockArgument>(value))
startOfLiveRange = &block->front();
else
startOfLiveRange = block->findAncestorOpInBlock(*startOfLiveRange);
diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp
index bcb23af8a9a22..7af6a65cef99a 100644
--- a/mlir/lib/Analysis/SliceAnalysis.cpp
+++ b/mlir/lib/Analysis/SliceAnalysis.cpp
@@ -95,7 +95,7 @@ static void getBackwardSliceImpl(Operation *op,
if (auto *definingOp = operand.getDefiningOp()) {
if (backwardSlice->count(definingOp) == 0)
getBackwardSliceImpl(definingOp, backwardSlice, filter);
- } else if (auto blockArg = operand.dyn_cast<BlockArgument>()) {
+ } else if (auto blockArg = dyn_cast<BlockArgument>(operand)) {
Block *block = blockArg.getOwner();
Operation *parentOp = block->getParentOp();
// TODO: determine whether we want to recurse backward into the other
@@ -132,7 +132,7 @@ void mlir::getBackwardSlice(Value root, SetVector<Operation *> *backwardSlice,
getBackwardSlice(definingOp, backwardSlice, filter, inclusive);
return;
}
- Operation *bbAargOwner = root.cast<BlockArgument>().getOwner()->getParentOp();
+ Operation *bbAargOwner = cast<BlockArgument>(root).getOwner()->getParentOp();
getBackwardSlice(bbAargOwner, backwardSlice, filter, inclusive);
}
diff --git a/mlir/lib/AsmParser/AsmParserState.cpp b/mlir/lib/AsmParser/AsmParserState.cpp
index 29e1f40c5de88..e61aba5a7fe7c 100644
--- a/mlir/lib/AsmParser/AsmParserState.cpp
+++ b/mlir/lib/AsmParser/AsmParserState.cpp
@@ -73,7 +73,7 @@ void AsmParserState::Impl::resolveSymbolUses() {
for (auto &it : *opAndUseMapIt.second) {
symbolOps.clear();
if (failed(symbolTable.lookupSymbolIn(
- opAndUseMapIt.first, it.first.cast<SymbolRefAttr>(), symbolOps)))
+ opAndUseMapIt.first, cast<SymbolRefAttr>(it.first), symbolOps)))
continue;
for (ArrayRef<SMRange> useRange : it.second) {
@@ -301,7 +301,7 @@ void AsmParserState::addUses(Value value, ArrayRef<SMLoc> locations) {
}
// Otherwise, this is a block argument.
- BlockArgument arg = value.cast<BlockArgument>();
+ BlockArgument arg = cast<BlockArgument>(value);
auto existingIt = impl->blocksToIdx.find(arg.getOwner());
assert(existingIt != impl->blocksToIdx.end() &&
"expected valid block definition for block argument");
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index 1a491bacae7e8..95017942165f7 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -348,7 +348,7 @@ Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
else if (!(type = parseType()))
return nullptr;
}
- if (!type.isa<FloatType>())
+ if (!isa<FloatType>(type))
return (emitError("floating point value not valid for specified type"),
nullptr);
return FloatAttr::get(type, isNegative ? -*val : *val);
@@ -416,7 +416,7 @@ Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
return nullptr;
}
- if (auto floatType = type.dyn_cast<FloatType>()) {
+ if (auto floatType = dyn_cast<FloatType>(type)) {
std::optional<APFloat> result;
if (failed(parseFloatFromIntegerLiteral(result, tok, isNegative,
floatType.getFloatSemantics(),
@@ -425,7 +425,7 @@ Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
return FloatAttr::get(floatType, *result);
}
- if (!type.isa<IntegerType, IndexType>())
+ if (!isa<IntegerType, IndexType>(type))
return emitError(loc, "integer literal not valid for specified type"),
nullptr;
@@ -543,7 +543,7 @@ DenseElementsAttr TensorLiteralParser::getAttr(SMLoc loc, ShapedType type) {
// Check to see if we parse the literal from a hex string.
if (hexStorage &&
- (eltType.isIntOrIndexOrFloat() || eltType.isa<ComplexType>()))
+ (eltType.isIntOrIndexOrFloat() || isa<ComplexType>(eltType)))
return getHexAttr(loc, type);
// Check that the parsed storage size has the same number of elements to the
@@ -563,7 +563,7 @@ DenseElementsAttr TensorLiteralParser::getAttr(SMLoc loc, ShapedType type) {
// Handle complex types in the specific element type cases below.
bool isComplex = false;
- if (ComplexType complexTy = eltType.dyn_cast<ComplexType>()) {
+ if (ComplexType complexTy = dyn_cast<ComplexType>(eltType)) {
eltType = complexTy.getElementType();
isComplex = true;
}
@@ -583,7 +583,7 @@ DenseElementsAttr TensorLiteralParser::getAttr(SMLoc loc, ShapedType type) {
return DenseElementsAttr::get(type, intValues);
}
// Handle floating point types.
- if (FloatType floatTy = eltType.dyn_cast<FloatType>()) {
+ if (FloatType floatTy = dyn_cast<FloatType>(eltType)) {
std::vector<APFloat> floatValues;
if (failed(getFloatAttrElements(loc, floatTy, floatValues)))
return nullptr;
@@ -711,7 +711,7 @@ DenseElementsAttr TensorLiteralParser::getStringAttr(SMLoc loc, ShapedType type,
/// Build a Dense attribute with hex data for the given type.
DenseElementsAttr TensorLiteralParser::getHexAttr(SMLoc loc, ShapedType type) {
Type elementType = type.getElementType();
- if (!elementType.isIntOrIndexOrFloat() && !elementType.isa<ComplexType>()) {
+ if (!elementType.isIntOrIndexOrFloat() && !isa<ComplexType>(elementType)) {
p.emitError(loc)
<< "expected floating-point, integer, or complex element type, got "
<< elementType;
@@ -904,7 +904,7 @@ ParseResult DenseArrayElementParser::parseFloatElement(Parser &p) {
Token token = p.getToken();
std::optional<APFloat> result;
- auto floatType = type.cast<FloatType>();
+ auto floatType = cast<FloatType>(type);
if (p.consumeIf(Token::integer)) {
// Parse an integer literal as a float.
if (p.parseFloatFromIntegerLiteral(result, token, isNegative,
@@ -1025,7 +1025,7 @@ Attribute Parser::parseDenseResourceElementsAttr(Type attrType) {
return nullptr;
}
- ShapedType shapedType = attrType.dyn_cast<ShapedType>();
+ ShapedType shapedType = dyn_cast<ShapedType>(attrType);
if (!shapedType) {
emitError(typeLoc, "`dense_resource` expected a shaped type");
return nullptr;
@@ -1048,7 +1048,7 @@ ShapedType Parser::parseElementsLiteralType(Type type) {
return nullptr;
}
- auto sType = type.dyn_cast<ShapedType>();
+ auto sType = dyn_cast<ShapedType>(type);
if (!sType) {
emitError("elements literal must be a shaped type");
return nullptr;
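The free-function form also accepts several candidate types at once, as in `isa<IntegerType, IndexType>(type)` in `parseDecOrHexAttr`. Here is a C++17 sketch of a variadic `isa` built on a fold expression, with hypothetical `Ty`, `IntegerTy`, `IndexTy`, and `FloatTy` standing in for value-style type handles. LLVM's own implementation may differ, but the meaning is the same: the result is true if the value matches any of the listed types.

```cpp
#include <cassert>

// Hypothetical value-style type handle carrying a kind tag.
struct Ty {
  enum class Kind { Integer, Index, Float };
  explicit Ty(Kind k) : kind(k) {}
  Kind kind;
};

// Hypothetical tag types; each only needs a static classof() predicate.
struct IntegerTy {
  static bool classof(const Ty &t) { return t.kind == Ty::Kind::Integer; }
};
struct IndexTy {
  static bool classof(const Ty &t) { return t.kind == Ty::Kind::Index; }
};
struct FloatTy {
  static bool classof(const Ty &t) { return t.kind == Ty::Kind::Float; }
};

// Variadic isa: true if `v` matches at least one of the listed types.
// The unary right fold expands to To1::classof(v) || To2::classof(v) || ...
template <typename... To, typename From>
bool isa(const From &v) {
  static_assert(sizeof...(To) > 0, "isa<> needs at least one target type");
  return (To::classof(v) || ...);
}
```

When such a call appears inside a macro like `assert`, the comma between template arguments splits the macro arguments, so the call needs an extra pair of parentheses, e.g. `assert((isa<IntegerTy, IndexTy>(t)))`.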
diff --git a/mlir/lib/AsmParser/DialectSymbolParser.cpp b/mlir/lib/AsmParser/DialectSymbolParser.cpp
index c98b36862fb53..27981451502d2 100644
--- a/mlir/lib/AsmParser/DialectSymbolParser.cpp
+++ b/mlir/lib/AsmParser/DialectSymbolParser.cpp
@@ -260,7 +260,7 @@ Attribute Parser::parseExtendedAttr(Type type) {
});
// Ensure that the attribute has the same type as requested.
- auto typedAttr = attr.dyn_cast_or_null<TypedAttr>();
+ auto typedAttr = dyn_cast_or_null<TypedAttr>(attr);
if (type && typedAttr && typedAttr.getType() != type) {
emitError("attribute type different than expected: expected ")
<< type << ", but got " << typedAttr.getType();
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index ade2465d7fd0a..69116ef39741b 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -1333,7 +1333,7 @@ ParseResult OperationParser::parseGenericOperationAfterOpName(
auto type = parseType();
if (!type)
return failure();
- auto fnType = type.dyn_cast<FunctionType>();
+ auto fnType = dyn_cast<FunctionType>(type);
if (!fnType)
return mlir::emitError(typeLoc, "expected function type");
@@ -2352,7 +2352,7 @@ ParseResult OperationParser::codeCompleteSSAUse() {
if (!forwardRefPlaceholders.count(result))
detailOS << result.getOwner()->getName() << ": ";
} else {
- detailOS << "arg #" << frontValue.cast<BlockArgument>().getArgNumber()
+ detailOS << "arg #" << cast<BlockArgument>(frontValue).getArgNumber()
<< ": ";
}
diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h
index c5e32975bb399..749b82c2ed4c6 100644
--- a/mlir/lib/AsmParser/Parser.h
+++ b/mlir/lib/AsmParser/Parser.h
@@ -241,7 +241,7 @@ class Parser {
return std::nullopt;
if (Attribute parsedAttr = parseAttribute(type)) {
- attr = parsedAttr.cast<AttributeT>();
+ attr = cast<AttributeT>(parsedAttr);
return success();
}
return failure();
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 737767ce9101b..211049204b268 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -129,7 +129,7 @@ Type Parser::parseComplexType() {
if (!elementType ||
parseToken(Token::greater, "expected '>' in complex type"))
return nullptr;
- if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>())
+ if (!isa<FloatType>(elementType) && !isa<IntegerType>(elementType))
return emitError(elementTypeLoc, "invalid element type for complex"),
nullptr;
@@ -207,8 +207,8 @@ Type Parser::parseMemRefType() {
if (!attr)
return failure();
- if (attr.isa<MemRefLayoutAttrInterface>()) {
- layout = attr.cast<MemRefLayoutAttrInterface>();
+ if (isa<MemRefLayoutAttrInterface>(attr)) {
+ layout = cast<MemRefLayoutAttrInterface>(attr);
} else if (memorySpace) {
return emitError("multiple memory spaces specified in memref type");
} else {
@@ -383,7 +383,7 @@ Type Parser::parseTensorType() {
Attribute encoding;
if (consumeIf(Token::comma)) {
encoding = parseAttribute();
- if (auto v = encoding.dyn_cast_or_null<VerifiableTensorEncoding>()) {
+ if (auto v = dyn_cast_or_null<VerifiableTensorEncoding>(encoding)) {
if (failed(v.verifyEncoding(dimensions, elementType,
[&] { return emitError(); })))
return nullptr;
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index e39c56885c205..9344ec9214c18 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -785,7 +785,7 @@ class AttrTypeReader {
Attribute baseResult;
if (failed(parseAttribute(reader, baseResult)))
return failure();
- if ((result = baseResult.dyn_cast<T>()))
+ if ((result = dyn_cast<T>(baseResult)))
return success();
return reader.emitError("expected attribute of type: ",
llvm::getTypeName<T>(), ", but got: ", baseResult);
diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
index 7f56e9a94d299..f3a153178e8d2 100644
--- a/mlir/lib/Bytecode/Writer/IRNumbering.cpp
+++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
@@ -180,7 +180,7 @@ void IRNumberingState::number(Attribute attr) {
// have a registered dialect when it got created. We don't want to encode this
// as the builtin OpaqueAttr, we want to encode it as if the dialect was
// actually loaded.
- if (OpaqueAttr opaqueAttr = attr.dyn_cast<OpaqueAttr>()) {
+ if (OpaqueAttr opaqueAttr = dyn_cast<OpaqueAttr>(attr)) {
numbering->dialect = &numberDialect(opaqueAttr.getDialectNamespace());
return;
}
@@ -310,7 +310,7 @@ void IRNumberingState::number(Type type) {
// registered dialect when it got created. We don't want to encode this as the
// builtin OpaqueType, we want to encode it as if the dialect was actually
// loaded.
- if (OpaqueType opaqueType = type.dyn_cast<OpaqueType>()) {
+ if (OpaqueType opaqueType = dyn_cast<OpaqueType>(type)) {
numbering->dialect = &numberDialect(opaqueType.getDialectNamespace());
return;
}
diff --git a/mlir/lib/CAPI/Dialect/PDL.cpp b/mlir/lib/CAPI/Dialect/PDL.cpp
index 497b2cb1fb3ff..bd8b13c6516e2 100644
--- a/mlir/lib/CAPI/Dialect/PDL.cpp
+++ b/mlir/lib/CAPI/Dialect/PDL.cpp
@@ -21,7 +21,7 @@ MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(PDL, pdl, pdl::PDLDialect)
//===---------------------------------------------------------------------===//
bool mlirTypeIsAPDLType(MlirType type) {
- return unwrap(type).isa<pdl::PDLType>();
+ return isa<pdl::PDLType>(unwrap(type));
}
//===---------------------------------------------------------------------===//
@@ -29,7 +29,7 @@ bool mlirTypeIsAPDLType(MlirType type) {
//===---------------------------------------------------------------------===//
bool mlirTypeIsAPDLAttributeType(MlirType type) {
- return unwrap(type).isa<pdl::AttributeType>();
+ return isa<pdl::AttributeType>(unwrap(type));
}
MlirType mlirPDLAttributeTypeGet(MlirContext ctx) {
@@ -41,7 +41,7 @@ MlirType mlirPDLAttributeTypeGet(MlirContext ctx) {
//===---------------------------------------------------------------------===//
bool mlirTypeIsAPDLOperationType(MlirType type) {
- return unwrap(type).isa<pdl::OperationType>();
+ return isa<pdl::OperationType>(unwrap(type));
}
MlirType mlirPDLOperationTypeGet(MlirContext ctx) {
@@ -53,7 +53,7 @@ MlirType mlirPDLOperationTypeGet(MlirContext ctx) {
//===---------------------------------------------------------------------===//
bool mlirTypeIsAPDLRangeType(MlirType type) {
- return unwrap(type).isa<pdl::RangeType>();
+ return isa<pdl::RangeType>(unwrap(type));
}
MlirType mlirPDLRangeTypeGet(MlirType elementType) {
@@ -61,7 +61,7 @@ MlirType mlirPDLRangeTypeGet(MlirType elementType) {
}
MlirType mlirPDLRangeTypeGetElementType(MlirType type) {
- return wrap(unwrap(type).cast<pdl::RangeType>().getElementType());
+ return wrap(cast<pdl::RangeType>(unwrap(type)).getElementType());
}
//===---------------------------------------------------------------------===//
@@ -69,7 +69,7 @@ MlirType mlirPDLRangeTypeGetElementType(MlirType type) {
//===---------------------------------------------------------------------===//
bool mlirTypeIsAPDLTypeType(MlirType type) {
- return unwrap(type).isa<pdl::TypeType>();
+ return isa<pdl::TypeType>(unwrap(type));
}
MlirType mlirPDLTypeTypeGet(MlirContext ctx) {
@@ -81,7 +81,7 @@ MlirType mlirPDLTypeTypeGet(MlirContext ctx) {
//===---------------------------------------------------------------------===//
bool mlirTypeIsAPDLValueType(MlirType type) {
- return unwrap(type).isa<pdl::ValueType>();
+ return isa<pdl::ValueType>(unwrap(type));
}
MlirType mlirPDLValueTypeGet(MlirContext ctx) {
diff --git a/mlir/lib/CAPI/Dialect/Quant.cpp b/mlir/lib/CAPI/Dialect/Quant.cpp
index 065ab3e366351..0a7181d8bc17c 100644
--- a/mlir/lib/CAPI/Dialect/Quant.cpp
+++ b/mlir/lib/CAPI/Dialect/Quant.cpp
@@ -20,7 +20,7 @@ MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(quant, quant, quant::QuantizationDialect)
//===---------------------------------------------------------------------===//
bool mlirTypeIsAQuantizedType(MlirType type) {
- return unwrap(type).isa<quant::QuantizedType>();
+ return isa<quant::QuantizedType>(unwrap(type));
}
unsigned mlirQuantizedTypeGetSignedFlag() {
@@ -40,39 +40,37 @@ int64_t mlirQuantizedTypeGetDefaultMaximumForInteger(bool isSigned,
}
MlirType mlirQuantizedTypeGetExpressedType(MlirType type) {
- return wrap(unwrap(type).cast<quant::QuantizedType>().getExpressedType());
+ return wrap(cast<quant::QuantizedType>(unwrap(type)).getExpressedType());
}
unsigned mlirQuantizedTypeGetFlags(MlirType type) {
- return unwrap(type).cast<quant::QuantizedType>().getFlags();
+ return cast<quant::QuantizedType>(unwrap(type)).getFlags();
}
bool mlirQuantizedTypeIsSigned(MlirType type) {
- return unwrap(type).cast<quant::QuantizedType>().isSigned();
+ return cast<quant::QuantizedType>(unwrap(type)).isSigned();
}
MlirType mlirQuantizedTypeGetStorageType(MlirType type) {
- return wrap(unwrap(type).cast<quant::QuantizedType>().getStorageType());
+ return wrap(cast<quant::QuantizedType>(unwrap(type)).getStorageType());
}
int64_t mlirQuantizedTypeGetStorageTypeMin(MlirType type) {
- return unwrap(type).cast<quant::QuantizedType>().getStorageTypeMin();
+ return cast<quant::QuantizedType>(unwrap(type)).getStorageTypeMin();
}
int64_t mlirQuantizedTypeGetStorageTypeMax(MlirType type) {
- return unwrap(type).cast<quant::QuantizedType>().getStorageTypeMax();
+ return cast<quant::QuantizedType>(unwrap(type)).getStorageTypeMax();
}
unsigned mlirQuantizedTypeGetStorageTypeIntegralWidth(MlirType type) {
- return unwrap(type)
- .cast<quant::QuantizedType>()
- .getStorageTypeIntegralWidth();
+ return cast<quant::QuantizedType>(unwrap(type)).getStorageTypeIntegralWidth();
}
bool mlirQuantizedTypeIsCompatibleExpressedType(MlirType type,
MlirType candidate) {
- return unwrap(type).cast<quant::QuantizedType>().isCompatibleExpressedType(
- unwrap(candidate));
+ return cast<quant::QuantizedType>(unwrap(type))
+ .isCompatibleExpressedType(unwrap(candidate));
}
MlirType mlirQuantizedTypeGetQuantizedElementType(MlirType type) {
@@ -81,19 +79,19 @@ MlirType mlirQuantizedTypeGetQuantizedElementType(MlirType type) {
MlirType mlirQuantizedTypeCastFromStorageType(MlirType type,
MlirType candidate) {
- return wrap(unwrap(type).cast<quant::QuantizedType>().castFromStorageType(
- unwrap(candidate)));
+ return wrap(cast<quant::QuantizedType>(unwrap(type))
+ .castFromStorageType(unwrap(candidate)));
}
MlirType mlirQuantizedTypeCastToStorageType(MlirType type) {
return wrap(quant::QuantizedType::castToStorageType(
- unwrap(type).cast<quant::QuantizedType>()));
+ cast<quant::QuantizedType>(unwrap(type))));
}
MlirType mlirQuantizedTypeCastFromExpressedType(MlirType type,
MlirType candidate) {
- return wrap(unwrap(type).cast<quant::QuantizedType>().castFromExpressedType(
- unwrap(candidate)));
+ return wrap(cast<quant::QuantizedType>(unwrap(type))
+ .castFromExpressedType(unwrap(candidate)));
}
MlirType mlirQuantizedTypeCastToExpressedType(MlirType type) {
@@ -102,9 +100,8 @@ MlirType mlirQuantizedTypeCastToExpressedType(MlirType type) {
MlirType mlirQuantizedTypeCastExpressedToStorageType(MlirType type,
MlirType candidate) {
- return wrap(
- unwrap(type).cast<quant::QuantizedType>().castExpressedToStorageType(
- unwrap(candidate)));
+ return wrap(cast<quant::QuantizedType>(unwrap(type))
+ .castExpressedToStorageType(unwrap(candidate)));
}
//===---------------------------------------------------------------------===//
@@ -112,7 +109,7 @@ MlirType mlirQuantizedTypeCastExpressedToStorageType(MlirType type,
//===---------------------------------------------------------------------===//
bool mlirTypeIsAAnyQuantizedType(MlirType type) {
- return unwrap(type).isa<quant::AnyQuantizedType>();
+ return isa<quant::AnyQuantizedType>(unwrap(type));
}
MlirType mlirAnyQuantizedTypeGet(unsigned flags, MlirType storageType,
@@ -128,7 +125,7 @@ MlirType mlirAnyQuantizedTypeGet(unsigned flags, MlirType storageType,
//===---------------------------------------------------------------------===//
bool mlirTypeIsAUniformQuantizedType(MlirType type) {
- return unwrap(type).isa<quant::UniformQuantizedType>();
+ return isa<quant::UniformQuantizedType>(unwrap(type));
}
MlirType mlirUniformQuantizedTypeGet(unsigned flags, MlirType storageType,
@@ -141,15 +138,15 @@ MlirType mlirUniformQuantizedTypeGet(unsigned flags, MlirType storageType,
}
double mlirUniformQuantizedTypeGetScale(MlirType type) {
- return unwrap(type).cast<quant::UniformQuantizedType>().getScale();
+ return cast<quant::UniformQuantizedType>(unwrap(type)).getScale();
}
int64_t mlirUniformQuantizedTypeGetZeroPoint(MlirType type) {
- return unwrap(type).cast<quant::UniformQuantizedType>().getZeroPoint();
+ return cast<quant::UniformQuantizedType>(unwrap(type)).getZeroPoint();
}
bool mlirUniformQuantizedTypeIsFixedPoint(MlirType type) {
- return unwrap(type).cast<quant::UniformQuantizedType>().isFixedPoint();
+ return cast<quant::UniformQuantizedType>(unwrap(type)).isFixedPoint();
}
//===---------------------------------------------------------------------===//
@@ -157,7 +154,7 @@ bool mlirUniformQuantizedTypeIsFixedPoint(MlirType type) {
//===---------------------------------------------------------------------===//
bool mlirTypeIsAUniformQuantizedPerAxisType(MlirType type) {
- return unwrap(type).isa<quant::UniformQuantizedPerAxisType>();
+ return isa<quant::UniformQuantizedPerAxisType>(unwrap(type));
}
MlirType mlirUniformQuantizedPerAxisTypeGet(
@@ -172,33 +169,29 @@ MlirType mlirUniformQuantizedPerAxisTypeGet(
}
intptr_t mlirUniformQuantizedPerAxisTypeGetNumDims(MlirType type) {
- return unwrap(type)
- .cast<quant::UniformQuantizedPerAxisType>()
+ return cast<quant::UniformQuantizedPerAxisType>(unwrap(type))
.getScales()
.size();
}
double mlirUniformQuantizedPerAxisTypeGetScale(MlirType type, intptr_t pos) {
- return unwrap(type)
- .cast<quant::UniformQuantizedPerAxisType>()
+ return cast<quant::UniformQuantizedPerAxisType>(unwrap(type))
.getScales()[pos];
}
int64_t mlirUniformQuantizedPerAxisTypeGetZeroPoint(MlirType type,
intptr_t pos) {
- return unwrap(type)
- .cast<quant::UniformQuantizedPerAxisType>()
+ return cast<quant::UniformQuantizedPerAxisType>(unwrap(type))
.getZeroPoints()[pos];
}
int32_t mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(MlirType type) {
- return unwrap(type)
- .cast<quant::UniformQuantizedPerAxisType>()
+ return cast<quant::UniformQuantizedPerAxisType>(unwrap(type))
.getQuantizedDimension();
}
bool mlirUniformQuantizedPerAxisTypeIsFixedPoint(MlirType type) {
- return unwrap(type).cast<quant::UniformQuantizedPerAxisType>().isFixedPoint();
+ return cast<quant::UniformQuantizedPerAxisType>(unwrap(type)).isFixedPoint();
}
//===---------------------------------------------------------------------===//
@@ -206,7 +199,7 @@ bool mlirUniformQuantizedPerAxisTypeIsFixedPoint(MlirType type) {
//===---------------------------------------------------------------------===//
bool mlirTypeIsACalibratedQuantizedType(MlirType type) {
- return unwrap(type).isa<quant::CalibratedQuantizedType>();
+ return isa<quant::CalibratedQuantizedType>(unwrap(type));
}
MlirType mlirCalibratedQuantizedTypeGet(MlirType expressedType, double min,
@@ -216,9 +209,9 @@ MlirType mlirCalibratedQuantizedTypeGet(MlirType expressedType, double min,
}
double mlirCalibratedQuantizedTypeGetMin(MlirType type) {
- return unwrap(type).cast<quant::CalibratedQuantizedType>().getMin();
+ return cast<quant::CalibratedQuantizedType>(unwrap(type)).getMin();
}
double mlirCalibratedQuantizedTypeGetMax(MlirType type) {
- return unwrap(type).cast<quant::CalibratedQuantizedType>().getMax();
+ return cast<quant::CalibratedQuantizedType>(unwrap(type)).getMax();
}
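The Quant.cpp hunks above all apply one mechanical rewrite. `unwrap(type).isa<T>()` becomes `isa<T>(unwrap(type))`, and `unwrap(type).cast<T>().getX()` becomes `cast<T>(unwrap(type)).getX()`. The rough sketch below shows why both spellings behave the same. It models a tiny value-handle hierarchy with LLVM-style `classof`-based free functions and a member `isa` that just forwards to them. `Kind`, `IntegerType`, `FloatType`, and `getRawWidth` are hypothetical stand-ins, not MLIR's real classes, and `unwrap` is not modeled. LLVM's actual `Casting.h` machinery is considerably more general than this.

```cpp
#include <cassert>

// Hypothetical, simplified model of an MLIR-style value handle. Real MLIR
// types wrap a pointer to uniqued storage; a kind tag is enough here.
enum class Kind { Integer, Float };

class Type {
public:
  Type(Kind k, unsigned w) : kind(k), width(w) {}
  Kind getKind() const { return kind; }
  unsigned getRawWidth() const { return width; }

  // Old member spelling, kept only so both forms can be compared.
  template <typename U> bool isa() const;

private:
  Kind kind;
  unsigned width;
};

class IntegerType : public Type {
public:
  explicit IntegerType(Type t) : Type(t) {}
  static bool classof(Type t) { return t.getKind() == Kind::Integer; }
  unsigned getWidth() const { return getRawWidth(); }
};

class FloatType : public Type {
public:
  explicit FloatType(Type t) : Type(t) {}
  static bool classof(Type t) { return t.getKind() == Kind::Float; }
};

// Free-function forms: dispatch on the target type's static classof().
template <typename To> bool isa(Type t) { return To::classof(t); }

template <typename To> To cast(Type t) {
  // cast<> asserts on a mismatch instead of returning a null handle.
  assert(isa<To>(t) && "cast<Ty>() argument of incompatible type!");
  return To(t);
}

// The member form simply forwards to the free function.
template <typename U> bool Type::isa() const { return ::isa<U>(*this); }
```

Because the member is only a thin forwarder, a rewrite like the one in this patch changes spelling, not behavior.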
diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
index 1aa6d329d41ca..795ce51ff9f07 100644
--- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp
+++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
@@ -42,7 +42,7 @@ static_assert(
"MlirSparseTensorDimLevelType (C-API) and DimLevelType (C++) mismatch");
bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) {
- return unwrap(attr).isa<SparseTensorEncodingAttr>();
+ return isa<SparseTensorEncodingAttr>(unwrap(attr));
}
MlirAttribute mlirSparseTensorEncodingAttrGet(
@@ -60,29 +60,28 @@ MlirAttribute mlirSparseTensorEncodingAttrGet(
}
MlirAffineMap mlirSparseTensorEncodingAttrGetDimOrdering(MlirAttribute attr) {
- return wrap(unwrap(attr).cast<SparseTensorEncodingAttr>().getDimOrdering());
+ return wrap(cast<SparseTensorEncodingAttr>(unwrap(attr)).getDimOrdering());
}
MlirAffineMap
mlirSparseTensorEncodingAttrGetHigherOrdering(MlirAttribute attr) {
- return wrap(
- unwrap(attr).cast<SparseTensorEncodingAttr>().getHigherOrdering());
+ return wrap(cast<SparseTensorEncodingAttr>(unwrap(attr)).getHigherOrdering());
}
intptr_t mlirSparseTensorEncodingGetLvlRank(MlirAttribute attr) {
- return unwrap(attr).cast<SparseTensorEncodingAttr>().getLvlRank();
+ return cast<SparseTensorEncodingAttr>(unwrap(attr)).getLvlRank();
}
MlirSparseTensorDimLevelType
mlirSparseTensorEncodingAttrGetDimLevelType(MlirAttribute attr, intptr_t lvl) {
return static_cast<MlirSparseTensorDimLevelType>(
- unwrap(attr).cast<SparseTensorEncodingAttr>().getLvlType(lvl));
+ cast<SparseTensorEncodingAttr>(unwrap(attr)).getLvlType(lvl));
}
int mlirSparseTensorEncodingAttrGetPosWidth(MlirAttribute attr) {
- return unwrap(attr).cast<SparseTensorEncodingAttr>().getPosWidth();
+ return cast<SparseTensorEncodingAttr>(unwrap(attr)).getPosWidth();
}
int mlirSparseTensorEncodingAttrGetCrdWidth(MlirAttribute attr) {
- return unwrap(attr).cast<SparseTensorEncodingAttr>().getCrdWidth();
+ return cast<SparseTensorEncodingAttr>(unwrap(attr)).getCrdWidth();
}
diff --git a/mlir/lib/CAPI/Dialect/Transform.cpp b/mlir/lib/CAPI/Dialect/Transform.cpp
index 606b301ccb746..90594b67aacfb 100644
--- a/mlir/lib/CAPI/Dialect/Transform.cpp
+++ b/mlir/lib/CAPI/Dialect/Transform.cpp
@@ -22,7 +22,7 @@ MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Transform, transform,
//===---------------------------------------------------------------------===//
bool mlirTypeIsATransformAnyOpType(MlirType type) {
- return unwrap(type).isa<transform::AnyOpType>();
+ return isa<transform::AnyOpType>(unwrap(type));
}
MlirType mlirTransformAnyOpTypeGet(MlirContext ctx) {
@@ -34,7 +34,7 @@ MlirType mlirTransformAnyOpTypeGet(MlirContext ctx) {
//===---------------------------------------------------------------------===//
bool mlirTypeIsATransformOperationType(MlirType type) {
- return unwrap(type).isa<transform::OperationType>();
+ return isa<transform::OperationType>(unwrap(type));
}
MlirType mlirTransformOperationTypeGet(MlirContext ctx,
@@ -44,5 +44,5 @@ MlirType mlirTransformOperationTypeGet(MlirContext ctx,
}
MlirStringRef mlirTransformOperationTypeGetOperationName(MlirType type) {
- return wrap(unwrap(type).cast<transform::OperationType>().getOperationName());
+ return wrap(cast<transform::OperationType>(unwrap(type)).getOperationName());
}
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index eeac499f49196..1769b1fa82624 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -48,7 +48,7 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
Location loc = gpuOp.getLoc();
Value memref = adaptor.getMemref();
Value unconvertedMemref = gpuOp.getMemref();
- MemRefType memrefType = unconvertedMemref.getType().cast<MemRefType>();
+ MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());
if (chipset.majorVersion < 9)
return gpuOp.emitOpError("Raw buffer ops require GCN or higher");
@@ -85,13 +85,13 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
// so bitcast any floats to integers.
Type llvmBufferValType = llvmWantedDataType;
if (atomicCmpData) {
- if (wantedDataType.isa<VectorType>())
+ if (isa<VectorType>(wantedDataType))
return gpuOp.emitOpError("vector compare-and-swap does not exist");
- if (auto floatType = wantedDataType.dyn_cast<FloatType>())
+ if (auto floatType = dyn_cast<FloatType>(wantedDataType))
llvmBufferValType = this->getTypeConverter()->convertType(
rewriter.getIntegerType(floatType.getWidth()));
}
- if (auto dataVector = wantedDataType.dyn_cast<VectorType>()) {
+ if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
uint32_t elemBits = dataVector.getElementTypeBitWidth();
uint32_t totalBits = elemBits * dataVector.getNumElements();
if (totalBits > maxVectorOpWidth)
@@ -312,7 +312,7 @@ struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
static Value mfmaConcatIfNeeded(ConversionPatternRewriter &rewriter,
Location loc, Value input) {
Type inputType = input.getType();
- if (auto vectorType = inputType.dyn_cast<VectorType>()) {
+ if (auto vectorType = dyn_cast<VectorType>(inputType)) {
if (!vectorType.getElementType().isInteger(8))
return input;
int64_t numBytes = vectorType.getNumElements();
@@ -342,10 +342,10 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
b = mfma.getBlocks();
Type sourceElem = mfma.getSourceA().getType();
- if (auto sourceType = sourceElem.dyn_cast<VectorType>())
+ if (auto sourceType = dyn_cast<VectorType>(sourceElem))
sourceElem = sourceType.getElementType();
Type destElem = mfma.getDestC().getType();
- if (auto destType = destElem.dyn_cast<VectorType>())
+ if (auto destType = dyn_cast<VectorType>(destElem))
destElem = destType.getElementType();
if (sourceElem.isF32() && destElem.isF32()) {
@@ -406,7 +406,7 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
}
- if (sourceElem.isa<IntegerType>() && destElem.isInteger(32)) {
+ if (isa<IntegerType>(sourceElem) && destElem.isInteger(32)) {
if (m == 32 && n == 32 && k == 4 && b == 2)
return ROCDL::mfma_i32_32x32x4i8::getOperationName();
if (m == 16 && n == 16 && k == 4 && b == 4)
@@ -435,7 +435,7 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
// Known to be correct because there are no scalar f8 instructions and
// because a length mismatch will have been caught by the verifier.
Type sourceBElem =
- mfma.getSourceB().getType().cast<VectorType>().getElementType();
+ cast<VectorType>(mfma.getSourceB().getType()).getElementType();
if (m == 16 && n == 16 && k == 32 && b == 1) {
if (sourceBElem.isFloat8E5M2FNUZ())
return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
@@ -453,7 +453,7 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
if (sourceElem.isFloat8E4M3FNUZ() && destElem.isF32() &&
chipset.minorVersion >= 0x40) {
Type sourceBElem =
- mfma.getSourceB().getType().cast<VectorType>().getElementType();
+ cast<VectorType>(mfma.getSourceB().getType()).getElementType();
if (m == 16 && n == 16 && k == 32 && b == 1) {
if (sourceBElem.isFloat8E5M2FNUZ())
return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
index debb7e804b652..783745a51822e 100644
--- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
+++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
@@ -226,7 +226,7 @@ class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
Type resultType = std::get<1>(pair);
std::optional<arith::AtomicRMWKind> reductionOp =
arith::symbolizeAtomicRMWKind(
- static_cast<uint64_t>(reduction.cast<IntegerAttr>().getInt()));
+ static_cast<uint64_t>(cast<IntegerAttr>(reduction).getInt()));
assert(reductionOp && "Reduction operation cannot be of None Type");
arith::AtomicRMWKind reductionOpValue = *reductionOp;
identityVals.push_back(
@@ -246,7 +246,7 @@ class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
// For each of the reduction operations get the respective mlir::Value.
std::optional<arith::AtomicRMWKind> reductionOp =
arith::symbolizeAtomicRMWKind(
- reductions[i].cast<IntegerAttr>().getInt());
+ cast<IntegerAttr>(reductions[i]).getInt());
assert(reductionOp && "Reduction Operation cannot be of None Type");
arith::AtomicRMWKind reductionOpValue = *reductionOp;
rewriter.setInsertionPoint(&parOp.getBody()->back());
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 4651c29997f8e..3b4b6452a8ec3 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -210,7 +210,7 @@ LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite(
// Handle the scalar and 1D vector cases.
Type operandType = adaptor.getIn().getType();
- if (!operandType.isa<LLVM::LLVMArrayType>()) {
+ if (!isa<LLVM::LLVMArrayType>(operandType)) {
Type targetType = this->typeConverter->convertType(resultType);
if (targetBits < sourceBits)
rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType,
@@ -220,7 +220,7 @@ LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite(
return success();
}
- if (!resultType.isa<VectorType>())
+ if (!isa<VectorType>(resultType))
return rewriter.notifyMatchFailure(op, "expected vector result type");
return LLVM::detail::handleMultidimensionalVectors(
@@ -255,7 +255,7 @@ LogicalResult AddUIExtendedOpLowering::matchAndRewrite(
Location loc = op.getLoc();
// Handle the scalar and 1D vector cases.
- if (!operandType.isa<LLVM::LLVMArrayType>()) {
+ if (!isa<LLVM::LLVMArrayType>(operandType)) {
Type newOverflowType = typeConverter->convertType(overflowResultType);
Type structType =
LLVM::LLVMStructType::getLiteral(ctx, {sumResultType, newOverflowType});
@@ -269,7 +269,7 @@ LogicalResult AddUIExtendedOpLowering::matchAndRewrite(
return success();
}
- if (!sumResultType.isa<VectorType>())
+ if (!isa<VectorType>(sumResultType))
return rewriter.notifyMatchFailure(loc, "expected vector result types");
return rewriter.notifyMatchFailure(loc,
@@ -295,16 +295,16 @@ LogicalResult MulIExtendedOpLowering<ArithMulOp, IsSigned>::matchAndRewrite(
// matching extended multiplication intrinsic, perform regular multiplication
// on operands zero-extended to i(2*N) bits, and truncate the results back to
// iN types.
- if (!resultType.isa<LLVM::LLVMArrayType>()) {
+ if (!isa<LLVM::LLVMArrayType>(resultType)) {
// Shift amount necessary to extract the high bits from widened result.
TypedAttr shiftValAttr;
- if (auto intTy = resultType.dyn_cast<IntegerType>()) {
+ if (auto intTy = dyn_cast<IntegerType>(resultType)) {
unsigned resultBitwidth = intTy.getWidth();
auto attrTy = rewriter.getIntegerType(resultBitwidth * 2);
shiftValAttr = rewriter.getIntegerAttr(attrTy, resultBitwidth);
} else {
- auto vecTy = resultType.cast<VectorType>();
+ auto vecTy = cast<VectorType>(resultType);
unsigned resultBitwidth = vecTy.getElementTypeBitWidth();
auto attrTy = VectorType::get(
vecTy.getShape(), rewriter.getIntegerType(resultBitwidth * 2));
@@ -330,7 +330,7 @@ LogicalResult MulIExtendedOpLowering<ArithMulOp, IsSigned>::matchAndRewrite(
return success();
}
- if (!resultType.isa<VectorType>())
+ if (!isa<VectorType>(resultType))
return rewriter.notifyMatchFailure(op, "expected vector result type");
return rewriter.notifyMatchFailure(op,
@@ -355,7 +355,7 @@ CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
Type resultType = op.getResult().getType();
// Handle the scalar and 1D vector cases.
- if (!operandType.isa<LLVM::LLVMArrayType>()) {
+ if (!isa<LLVM::LLVMArrayType>(operandType)) {
rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
op, typeConverter->convertType(resultType),
convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
@@ -363,7 +363,7 @@ CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
return success();
}
- if (!resultType.isa<VectorType>())
+ if (!isa<VectorType>(resultType))
return rewriter.notifyMatchFailure(op, "expected vector result type");
return LLVM::detail::handleMultidimensionalVectors(
@@ -389,7 +389,7 @@ CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
Type resultType = op.getResult().getType();
// Handle the scalar and 1D vector cases.
- if (!operandType.isa<LLVM::LLVMArrayType>()) {
+ if (!isa<LLVM::LLVMArrayType>(operandType)) {
rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
op, typeConverter->convertType(resultType),
convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
@@ -397,7 +397,7 @@ CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
return success();
}
- if (!resultType.isa<VectorType>())
+ if (!isa<VectorType>(resultType))
return rewriter.notifyMatchFailure(op, "expected vector result type");
return LLVM::detail::handleMultidimensionalVectors(
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index b6ed24490bb0e..5d2c1f3543214 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -261,9 +261,9 @@ class MinMaxFOpPattern final : public OpConversionPattern<Op> {
/// Converts the given `srcAttr` into a boolean attribute if it holds an
/// integral value. Returns null attribute if conversion fails.
static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) {
- if (auto boolAttr = srcAttr.dyn_cast<BoolAttr>())
+ if (auto boolAttr = dyn_cast<BoolAttr>(srcAttr))
return boolAttr;
- if (auto intAttr = srcAttr.dyn_cast<IntegerAttr>())
+ if (auto intAttr = dyn_cast<IntegerAttr>(srcAttr))
return builder.getBoolAttr(intAttr.getValue().getBoolValue());
return {};
}
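`convertBoolAttr` uses the `dyn_cast` idiom that many hunks in this patch touch. `dyn_cast<T>(x)` yields a null handle when the types do not match, so the result can be declared and tested in a single `if`. Below is a minimal self-contained sketch of that pattern. It assumes a hypothetical `Attribute`/`BoolAttr`/`IntegerAttr` trio that only stores a kind tag and an integer; real MLIR attributes are uniqued in an `MLIRContext`. `toBoolAttr` is an illustrative analogue of `convertBoolAttr`, not the real function.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical kind-tagged attribute handle; a default-constructed one is null.
enum class AttrKind { Bool, Integer, String };

class Attribute {
public:
  Attribute() = default;
  Attribute(AttrKind k, int64_t v) : kind(k), value(v), valid(true) {}
  AttrKind getKind() const { return kind; }
  int64_t getRaw() const { return value; }
  explicit operator bool() const { return valid; }

private:
  AttrKind kind = AttrKind::Bool;
  int64_t value = 0;
  bool valid = false;
};

class BoolAttr : public Attribute {
public:
  BoolAttr() = default;
  explicit BoolAttr(Attribute a) : Attribute(a) {}
  static bool classof(Attribute a) { return a.getKind() == AttrKind::Bool; }
  bool getValue() const { return getRaw() != 0; }
};

class IntegerAttr : public Attribute {
public:
  IntegerAttr() = default;
  explicit IntegerAttr(Attribute a) : Attribute(a) {}
  static bool classof(Attribute a) { return a.getKind() == AttrKind::Integer; }
  int64_t getInt() const { return getRaw(); }
};

// Returns a null To on a kind mismatch; the input itself must be non-null.
template <typename To> To dyn_cast(Attribute a) {
  assert(a && "dyn_cast on a null attribute");
  return To::classof(a) ? To(a) : To();
}

// Hypothetical analogue of convertBoolAttr: pass BoolAttr through, map any
// nonzero integer to true, and return a null attribute otherwise.
BoolAttr toBoolAttr(Attribute src) {
  if (auto boolAttr = dyn_cast<BoolAttr>(src))
    return boolAttr;
  if (auto intAttr = dyn_cast<IntegerAttr>(src))
    return BoolAttr(Attribute(AttrKind::Bool, intAttr.getInt() != 0));
  return {};
}
```

As in LLVM, this sketch's `dyn_cast` asserts on a null input. `dyn_cast_or_null`, mentioned in the commit message, is the variant that tolerates one.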
@@ -324,7 +324,7 @@ static bool isBoolScalarOrVector(Type type) {
if (type.isInteger(1))
return true;
- if (auto vecType = type.dyn_cast<VectorType>())
+ if (auto vecType = dyn_cast<VectorType>(type))
return vecType.getElementType().isInteger(1);
return false;
@@ -337,7 +337,7 @@ static bool hasSameBitwidth(Type a, Type b) {
unsigned bw = 0;
if (type.isIntOrFloat())
bw = type.getIntOrFloatBitWidth();
- else if (auto vecType = type.dyn_cast<VectorType>())
+ else if (auto vecType = dyn_cast<VectorType>(type))
bw = vecType.getElementTypeBitWidth() * vecType.getNumElements();
return bw;
};
@@ -369,18 +369,18 @@ getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op) {
LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
arith::ConstantOp constOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
- auto srcType = constOp.getType().dyn_cast<ShapedType>();
+ auto srcType = dyn_cast<ShapedType>(constOp.getType());
if (!srcType || srcType.getNumElements() == 1)
return failure();
// arith.constant should only have vector or tenor types.
- assert((srcType.isa<VectorType, RankedTensorType>()));
+ assert((isa<VectorType, RankedTensorType>(srcType)));
Type dstType = getTypeConverter()->convertType(srcType);
if (!dstType)
return failure();
- auto dstElementsAttr = constOp.getValue().dyn_cast<DenseElementsAttr>();
+ auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
if (!dstElementsAttr)
return failure();
@@ -388,7 +388,7 @@ LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
// If the composite type has more than one dimensions, perform linearization.
if (srcType.getRank() > 1) {
- if (srcType.isa<RankedTensorType>()) {
+ if (isa<RankedTensorType>(srcType)) {
dstAttrType = RankedTensorType::get(srcType.getNumElements(),
srcType.getElementType());
dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
@@ -402,19 +402,19 @@ LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
Type dstElemType;
// Tensor types are converted to SPIR-V array types; vector types are
// converted to SPIR-V vector/array types.
- if (auto arrayType = dstType.dyn_cast<spirv::ArrayType>())
+ if (auto arrayType = dyn_cast<spirv::ArrayType>(dstType))
dstElemType = arrayType.getElementType();
else
- dstElemType = dstType.cast<VectorType>().getElementType();
+ dstElemType = cast<VectorType>(dstType).getElementType();
// If the source and destination element types are different, perform
// attribute conversion.
if (srcElemType != dstElemType) {
SmallVector<Attribute, 8> elements;
- if (srcElemType.isa<FloatType>()) {
+ if (isa<FloatType>(srcElemType)) {
for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
FloatAttr dstAttr =
- convertFloatAttr(srcAttr, dstElemType.cast<FloatType>(), rewriter);
+ convertFloatAttr(srcAttr, cast<FloatType>(dstElemType), rewriter);
if (!dstAttr)
return failure();
elements.push_back(dstAttr);
@@ -424,7 +424,7 @@ LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
} else {
for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) {
IntegerAttr dstAttr = convertIntegerAttr(
- srcAttr, dstElemType.cast<IntegerType>(), rewriter);
+ srcAttr, cast<IntegerType>(dstElemType), rewriter);
if (!dstAttr)
return failure();
elements.push_back(dstAttr);
@@ -435,7 +435,7 @@ LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
// attributes; element attributes only works with builtin types. So we need
// to prepare another converted builtin types for the destination elements
// attribute.
- if (dstAttrType.isa<RankedTensorType>())
+ if (isa<RankedTensorType>(dstAttrType))
dstAttrType = RankedTensorType::get(dstAttrType.getShape(), dstElemType);
else
dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
@@ -456,7 +456,7 @@ LogicalResult ConstantScalarOpPattern::matchAndRewrite(
arith::ConstantOp constOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Type srcType = constOp.getType();
- if (auto shapedType = srcType.dyn_cast<ShapedType>()) {
+ if (auto shapedType = dyn_cast<ShapedType>(srcType)) {
if (shapedType.getNumElements() != 1)
return failure();
srcType = shapedType.getElementType();
@@ -465,7 +465,7 @@ LogicalResult ConstantScalarOpPattern::matchAndRewrite(
return failure();
Attribute cstAttr = constOp.getValue();
- if (auto elementsAttr = cstAttr.dyn_cast<DenseElementsAttr>())
+ if (auto elementsAttr = dyn_cast<DenseElementsAttr>(cstAttr))
cstAttr = elementsAttr.getSplatValue<Attribute>();
Type dstType = getTypeConverter()->convertType(srcType);
@@ -473,14 +473,14 @@ LogicalResult ConstantScalarOpPattern::matchAndRewrite(
return failure();
// Floating-point types.
- if (srcType.isa<FloatType>()) {
- auto srcAttr = cstAttr.cast<FloatAttr>();
+ if (isa<FloatType>(srcType)) {
+ auto srcAttr = cast<FloatAttr>(cstAttr);
auto dstAttr = srcAttr;
// Floating-point types not supported in the target environment are all
// converted to float type.
if (srcType != dstType) {
- dstAttr = convertFloatAttr(srcAttr, dstType.cast<FloatType>(), rewriter);
+ dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter);
if (!dstAttr)
return failure();
}
@@ -502,9 +502,9 @@ LogicalResult ConstantScalarOpPattern::matchAndRewrite(
// IndexType or IntegerType. Index values are converted to 32-bit integer
// values when converting to SPIR-V.
- auto srcAttr = cstAttr.cast<IntegerAttr>();
+ auto srcAttr = cast<IntegerAttr>(cstAttr);
IntegerAttr dstAttr =
- convertIntegerAttr(srcAttr, dstType.cast<IntegerType>(), rewriter);
+ convertIntegerAttr(srcAttr, cast<IntegerType>(dstType), rewriter);
if (!dstAttr)
return failure();
rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
@@ -678,12 +678,12 @@ ExtSII1Pattern::matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
return getTypeConversionFailure(rewriter, op);
Value allOnes;
- if (auto intTy = dstType.dyn_cast<IntegerType>()) {
+ if (auto intTy = dyn_cast<IntegerType>(dstType)) {
unsigned componentBitwidth = intTy.getWidth();
allOnes = rewriter.create<spirv::ConstantOp>(
loc, intTy,
rewriter.getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth)));
- } else if (auto vectorTy = dstType.dyn_cast<VectorType>()) {
+ } else if (auto vectorTy = dyn_cast<VectorType>(dstType)) {
unsigned componentBitwidth = vectorTy.getElementTypeBitWidth();
allOnes = rewriter.create<spirv::ConstantOp>(
loc, vectorTy,
@@ -810,7 +810,7 @@ LogicalResult CmpIOpBooleanPattern::matchAndRewrite(
// There are no direct corresponding instructions in SPIR-V for such cases.
// Extend them to 32-bit and do comparision then.
Type type = rewriter.getI32Type();
- if (auto vectorType = dstType.dyn_cast<VectorType>())
+ if (auto vectorType = dyn_cast<VectorType>(dstType))
type = VectorType::get(vectorType.getShape(), type);
Value extLhs =
rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs());
diff --git a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
index 691cd23c6ed12..bdbf276d79b22 100644
--- a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
+++ b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
@@ -33,8 +33,8 @@ class Sdot2dLoweringPattern : public OpRewritePattern<Sdot2dOp> {
/// arm.neon.intr.sdot
LogicalResult matchAndRewrite(Sdot2dOp op,
PatternRewriter &rewriter) const override {
- Type elemType = op.getB().getType().cast<VectorType>().getElementType();
- int length = op.getB().getType().cast<VectorType>().getShape()[0] *
+ Type elemType = cast<VectorType>(op.getB().getType()).getElementType();
+ int length = cast<VectorType>(op.getB().getType()).getShape()[0] *
Sdot2dOp::kReductionSize;
VectorType flattenedVectorType = VectorType::get({length}, elemType);
Value b2d = op.getB();
diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index 38041bdec474a..d1998cfb4b642 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -366,12 +366,12 @@ class AsyncRuntimeTypeConverter : public TypeConverter {
static std::optional<Type> convertAsyncTypes(Type type,
bool useOpaquePointers) {
- if (type.isa<TokenType, GroupType, ValueType>())
+ if (isa<TokenType, GroupType, ValueType>(type))
return AsyncAPI::opaquePointerType(type.getContext(), useOpaquePointers);
- if (type.isa<CoroIdType, CoroStateType>())
+ if (isa<CoroIdType, CoroStateType>(type))
return AsyncAPI::tokenType(type.getContext());
- if (type.isa<CoroHandleType>())
+ if (isa<CoroHandleType>(type))
return AsyncAPI::opaquePointerType(type.getContext(), useOpaquePointers);
return std::nullopt;
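`convertAsyncTypes` relies on the variadic form `isa<A, B, C>(x)`, which is true if `x` is any of the listed types. One way to express that is a recursive pair of overloads. The sketch below assumes a hypothetical kind-tagged `Type` and returns string labels in place of real LLVM types. `lowerAsyncType` and the `KindType` helper are illustrative names only.

```cpp
#include <cassert>
#include <optional>
#include <string>

// Hypothetical kind-tagged handle standing in for mlir::Type.
enum class Kind { Token, Group, Value, CoroId, CoroState, CoroHandle, Other };

class Type {
public:
  explicit Type(Kind k) : kind(k) {}
  Kind getKind() const { return kind; }

private:
  Kind kind;
};

// Each hypothetical "subclass" only needs a static classof() for isa<>.
template <Kind K> struct KindType {
  static bool classof(Type t) { return t.getKind() == K; }
};
using TokenType = KindType<Kind::Token>;
using GroupType = KindType<Kind::Group>;
using ValueType = KindType<Kind::Value>;
using CoroIdType = KindType<Kind::CoroId>;
using CoroStateType = KindType<Kind::CoroState>;
using CoroHandleType = KindType<Kind::CoroHandle>;

// Single-type check.
template <typename To> bool isa(Type t) { return To::classof(t); }

// Variadic check: true if t matches any of the listed types. With only one
// explicit argument, Second cannot be deduced, so the overload above is used.
template <typename First, typename Second, typename... Rest>
bool isa(Type t) {
  return isa<First>(t) || isa<Second, Rest...>(t);
}

// Hypothetical analogue of convertAsyncTypes, returning a label instead of
// constructing LLVM dialect types.
std::optional<std::string> lowerAsyncType(Type type) {
  if (isa<TokenType, GroupType, ValueType>(type))
    return "opaque-pointer";
  if (isa<CoroIdType, CoroStateType>(type))
    return "token";
  if (isa<CoroHandleType>(type))
    return "opaque-pointer";
  return std::nullopt;
}
```

The template argument list contains commas, so a variadic `isa` inside `assert` needs an extra pair of parentheses. That is why the `ConstantCompositeOpPattern` hunk above writes `assert((isa<VectorType, RankedTensorType>(srcType)));`.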
@@ -656,14 +656,14 @@ class RuntimeCreateOpLowering : public ConvertOpToLLVMPattern<RuntimeCreateOp> {
Type resultType = op->getResultTypes()[0];
// Tokens creation maps to a simple function call.
- if (resultType.isa<TokenType>()) {
+ if (isa<TokenType>(resultType)) {
rewriter.replaceOpWithNewOp<func::CallOp>(
op, kCreateToken, converter->convertType(resultType));
return success();
}
// To create a value we need to compute the storage requirement.
- if (auto value = resultType.dyn_cast<ValueType>()) {
+ if (auto value = dyn_cast<ValueType>(resultType)) {
// Returns the size requirements for the async value storage.
auto sizeOf = [&](ValueType valueType) -> Value {
auto loc = op->getLoc();
@@ -994,7 +994,7 @@ class RuntimeAddToGroupOpLowering
matchAndRewrite(RuntimeAddToGroupOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Currently we can only add tokens to the group.
- if (!op.getOperand().getType().isa<TokenType>())
+ if (!isa<TokenType>(op.getOperand().getType()))
return rewriter.notifyMatchFailure(op, "only token type is supported");
// Replace with a runtime API function call.
diff --git a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
index 5f649813c78b5..f498d2c359e56 100644
--- a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
+++ b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
@@ -41,11 +41,11 @@ struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> {
ConversionPatternRewriter &rewriter) const override {
// Check for unranked memref types which are currently not supported.
Type type = op.getType();
- if (type.isa<UnrankedMemRefType>()) {
+ if (isa<UnrankedMemRefType>(type)) {
return rewriter.notifyMatchFailure(
op, "UnrankedMemRefType is not supported.");
}
- MemRefType memrefType = type.cast<MemRefType>();
+ MemRefType memrefType = cast<MemRefType>(type);
MemRefLayoutAttrInterface layout;
auto allocType =
MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
diff --git a/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp b/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp
index d35165fee2e0b..3b8338673a5e3 100644
--- a/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp
+++ b/mlir/lib/Conversion/ComplexToLibm/ComplexToLibm.cpp
@@ -26,9 +26,9 @@ namespace {
// result type.
struct ComplexTypeResolver {
std::optional<bool> operator()(Type type) const {
- auto complexType = type.cast<ComplexType>();
+ auto complexType = cast<ComplexType>(type);
auto elementType = complexType.getElementType();
- if (!elementType.isa<Float32Type, Float64Type>())
+ if (!isa<Float32Type, Float64Type>(elementType))
return {};
return elementType.getIntOrFloatBitWidth() == 64;
@@ -39,8 +39,8 @@ struct ComplexTypeResolver {
// type.
struct FloatTypeResolver {
std::optional<bool> operator()(Type type) const {
- auto elementType = type.cast<FloatType>();
- if (!elementType.isa<Float32Type, Float64Type>())
+ auto elementType = cast<FloatType>(type);
+ if (!isa<Float32Type, Float64Type>(elementType))
return {};
return elementType.getIntOrFloatBitWidth() == 64;
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 536497607ebeb..9c05cadc2f07e 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -57,7 +57,7 @@ struct Atan2OpConversion : public OpConversionPattern<complex::Atan2Op> {
ConversionPatternRewriter &rewriter) const override {
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
- auto type = op.getType().cast<ComplexType>();
+ auto type = cast<ComplexType>(op.getType());
Type elementType = type.getElementType();
Value lhs = adaptor.getLhs();
@@ -102,10 +102,7 @@ struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
- auto type = adaptor.getLhs()
- .getType()
- .template cast<ComplexType>()
- .getElementType();
+ auto type = cast<ComplexType>(adaptor.getLhs().getType()).getElementType();
Value realLhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getLhs());
Value imagLhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getLhs());
@@ -132,8 +129,8 @@ struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
LogicalResult
matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto type = adaptor.getLhs().getType().template cast<ComplexType>();
- auto elementType = type.getElementType().template cast<FloatType>();
+ auto type = cast<ComplexType>(adaptor.getLhs().getType());
+ auto elementType = cast<FloatType>(type.getElementType());
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
Value realLhs = b.create<complex::ReOp>(elementType, adaptor.getLhs());
@@ -160,8 +157,8 @@ struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
matchAndRewrite(TrigonometricOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
- auto type = adaptor.getComplex().getType().template cast<ComplexType>();
- auto elementType = type.getElementType().template cast<FloatType>();
+ auto type = cast<ComplexType>(adaptor.getComplex().getType());
+ auto elementType = cast<FloatType>(type.getElementType());
Value real =
rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
@@ -222,8 +219,8 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
- auto type = adaptor.getLhs().getType().cast<ComplexType>();
- auto elementType = type.getElementType().cast<FloatType>();
+ auto type = cast<ComplexType>(adaptor.getLhs().getType());
+ auto elementType = cast<FloatType>(type.getElementType());
Value lhsReal =
rewriter.create<complex::ReOp>(loc, elementType, adaptor.getLhs());
@@ -441,8 +438,8 @@ struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
- auto type = adaptor.getComplex().getType().cast<ComplexType>();
- auto elementType = type.getElementType().cast<FloatType>();
+ auto type = cast<ComplexType>(adaptor.getComplex().getType());
+ auto elementType = cast<FloatType>(type.getElementType());
Value real =
rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
@@ -466,8 +463,8 @@ struct Expm1OpConversion : public OpConversionPattern<complex::Expm1Op> {
LogicalResult
matchAndRewrite(complex::Expm1Op op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto type = adaptor.getComplex().getType().cast<ComplexType>();
- auto elementType = type.getElementType().cast<FloatType>();
+ auto type = cast<ComplexType>(adaptor.getComplex().getType());
+ auto elementType = cast<FloatType>(type.getElementType());
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
Value exp = b.create<complex::ExpOp>(adaptor.getComplex());
@@ -490,8 +487,8 @@ struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
LogicalResult
matchAndRewrite(complex::LogOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto type = adaptor.getComplex().getType().cast<ComplexType>();
- auto elementType = type.getElementType().cast<FloatType>();
+ auto type = cast<ComplexType>(adaptor.getComplex().getType());
+ auto elementType = cast<FloatType>(type.getElementType());
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
Value abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex());
@@ -511,8 +508,8 @@ struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
LogicalResult
matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto type = adaptor.getComplex().getType().cast<ComplexType>();
- auto elementType = type.getElementType().cast<FloatType>();
+ auto type = cast<ComplexType>(adaptor.getComplex().getType());
+ auto elementType = cast<FloatType>(type.getElementType());
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
@@ -550,8 +547,8 @@ struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
- auto type = adaptor.getLhs().getType().cast<ComplexType>();
- auto elementType = type.getElementType().cast<FloatType>();
+ auto type = cast<ComplexType>(adaptor.getLhs().getType());
+ auto elementType = cast<FloatType>(type.getElementType());
Value lhsReal = b.create<complex::ReOp>(elementType, adaptor.getLhs());
Value lhsRealAbs = b.create<math::AbsFOp>(lhsReal);
@@ -727,8 +724,8 @@ struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
matchAndRewrite(complex::NegOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
- auto type = adaptor.getComplex().getType().cast<ComplexType>();
- auto elementType = type.getElementType().cast<FloatType>();
+ auto type = cast<ComplexType>(adaptor.getComplex().getType());
+ auto elementType = cast<FloatType>(type.getElementType());
Value real =
rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
@@ -773,7 +770,7 @@ struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
ConversionPatternRewriter &rewriter) const override {
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
- auto type = op.getType().cast<ComplexType>();
+ auto type = cast<ComplexType>(op.getType());
Type elementType = type.getElementType();
Value arg = adaptor.getComplex();
@@ -837,8 +834,8 @@ struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
LogicalResult
matchAndRewrite(complex::SignOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto type = adaptor.getComplex().getType().cast<ComplexType>();
- auto elementType = type.getElementType().cast<FloatType>();
+ auto type = cast<ComplexType>(adaptor.getComplex().getType());
+ auto elementType = cast<FloatType>(type.getElementType());
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
@@ -881,8 +878,8 @@ struct TanhOpConversion : public OpConversionPattern<complex::TanhOp> {
matchAndRewrite(complex::TanhOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
- auto type = adaptor.getComplex().getType().cast<ComplexType>();
- auto elementType = type.getElementType().cast<FloatType>();
+ auto type = cast<ComplexType>(adaptor.getComplex().getType());
+ auto elementType = cast<FloatType>(type.getElementType());
// The hyperbolic tangent for complex number can be calculated as follows.
// tanh(x + i * y) = (tanh(x) + i * tan(y)) / (1 + tanh(x) * tan(y))
@@ -913,8 +910,8 @@ struct ConjOpConversion : public OpConversionPattern<complex::ConjOp> {
matchAndRewrite(complex::ConjOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
- auto type = adaptor.getComplex().getType().cast<ComplexType>();
- auto elementType = type.getElementType().cast<FloatType>();
+ auto type = cast<ComplexType>(adaptor.getComplex().getType());
+ auto elementType = cast<FloatType>(type.getElementType());
Value real =
rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
Value imag =
@@ -933,7 +930,7 @@ struct ConjOpConversion : public OpConversionPattern<complex::ConjOp> {
static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder,
ComplexType type, Value a, Value b, Value c,
Value d) {
- auto elementType = type.getElementType().cast<FloatType>();
+ auto elementType = cast<FloatType>(type.getElementType());
// Compute (a*a+b*b)^(0.5c).
Value aaPbb = builder.create<arith::AddFOp>(
@@ -995,8 +992,8 @@ struct PowOpConversion : public OpConversionPattern<complex::PowOp> {
matchAndRewrite(complex::PowOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
- auto type = adaptor.getLhs().getType().cast<ComplexType>();
- auto elementType = type.getElementType().cast<FloatType>();
+ auto type = cast<ComplexType>(adaptor.getLhs().getType());
+ auto elementType = cast<FloatType>(type.getElementType());
Value a = builder.create<complex::ReOp>(elementType, adaptor.getLhs());
Value b = builder.create<complex::ImOp>(elementType, adaptor.getLhs());
@@ -1015,8 +1012,8 @@ struct RsqrtOpConversion : public OpConversionPattern<complex::RsqrtOp> {
matchAndRewrite(complex::RsqrtOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
- auto type = adaptor.getComplex().getType().cast<ComplexType>();
- auto elementType = type.getElementType().cast<FloatType>();
+ auto type = cast<ComplexType>(adaptor.getComplex().getType());
+ auto elementType = cast<FloatType>(type.getElementType());
Value a = builder.create<complex::ReOp>(elementType, adaptor.getComplex());
Value b = builder.create<complex::ImOp>(elementType, adaptor.getComplex());
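Several ComplexToStandard hunks also drop the `.template` keyword, for example `adaptor.getLhs().getType().template cast<ComplexType>()`. Inside a pattern templated on the op type, the object expression is type-dependent. Calling a member template on it then generally needs the `template` disambiguator. The free function `cast<ComplexType>(...)` is an ordinary function template found by normal lookup, so the keyword goes away.

The sketch below shows both spellings side by side. `Type`, `ComplexType`, `Value` and `FakeAdaptor` are hypothetical miniature stand-ins, not MLIR classes.

```cpp
#include <cassert>

// Hypothetical handle that still offers the old member-function cast.
struct Type {
  int kind = 0; // in this sketch, 1 marks a "complex" type
  template <typename To> To cast() const;
};

struct ComplexType : Type {
  static bool classof(Type t) { return t.kind == 1; }
  int getElementWidth() const { return 64; } // hypothetical accessor
};

template <typename To> To Type::cast() const {
  assert(To::classof(*this) && "cast<> to an incompatible type");
  To result;
  static_cast<Type &>(result) = *this; // copy the underlying handle
  return result;
}

// Free-function style; it forwards to the member for brevity.
template <typename To> To cast(Type t) { return t.cast<To>(); }

// Hypothetical stand-ins for an op adaptor and a Value.
struct Value {
  Type type;
  Type getType() const { return type; }
};
struct FakeAdaptor {
  Value lhs;
  Value getLhs() const { return lhs; }
};

// `a` has a dependent type, so the member template needs `template`.
template <typename Adaptor> int widthViaMember(const Adaptor &a) {
  return a.getLhs().getType().template cast<ComplexType>().getElementWidth();
}

// The free function parses as a template-id without any disambiguator.
template <typename Adaptor> int widthViaFree(const Adaptor &a) {
  return cast<ComplexType>(a.getLhs().getType()).getElementWidth();
}
```

This is why the migrated ComparisonOp hunk collapses four lines into one.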
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 4fd5b9c43a547..5867d9fa02b06 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -144,13 +144,13 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
size_t argOffset = resultStructType ? 1 : 0;
for (auto [index, argType] : llvm::enumerate(type.getInputs())) {
Value arg = wrapperFuncOp.getArgument(index + argOffset);
- if (auto memrefType = argType.dyn_cast<MemRefType>()) {
+ if (auto memrefType = dyn_cast<MemRefType>(argType)) {
Value loaded = rewriter.create<LLVM::LoadOp>(
loc, typeConverter.convertType(memrefType), arg);
MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args);
continue;
}
- if (argType.isa<UnrankedMemRefType>()) {
+ if (isa<UnrankedMemRefType>(argType)) {
Value loaded = rewriter.create<LLVM::LoadOp>(
loc, typeConverter.convertType(argType), arg);
UnrankedMemRefDescriptor::unpack(rewriter, loc, loaded, args);
@@ -218,8 +218,7 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
if (resultStructType) {
// Allocate the struct on the stack and pass the pointer.
- Type resultType =
- wrapperType.cast<LLVM::LLVMFunctionType>().getParamType(0);
+ Type resultType = cast<LLVM::LLVMFunctionType>(wrapperType).getParamType(0);
Value one = builder.create<LLVM::ConstantOp>(
loc, typeConverter.convertType(builder.getIndexType()),
builder.getIntegerAttr(builder.getIndexType(), 1));
@@ -233,8 +232,8 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
for (Type input : type.getInputs()) {
Value arg;
int numToDrop = 1;
- auto memRefType = input.dyn_cast<MemRefType>();
- auto unrankedMemRefType = input.dyn_cast<UnrankedMemRefType>();
+ auto memRefType = dyn_cast<MemRefType>(input);
+ auto unrankedMemRefType = dyn_cast<UnrankedMemRefType>(input);
if (memRefType || unrankedMemRefType) {
numToDrop = memRefType
? MemRefDescriptor::getNumUnpackedValues(memRefType)
@@ -301,9 +300,9 @@ static void modifyFuncOpToUseBarePtrCallingConv(
// Unranked memrefs are not supported in the bare pointer calling
// convention. We should have bailed out before in the presence of
// unranked memrefs.
- assert(!argTy.isa<UnrankedMemRefType>() &&
+ assert(!isa<UnrankedMemRefType>(argTy) &&
"Unranked memref is not supported");
- auto memrefTy = argTy.dyn_cast<MemRefType>();
+ auto memrefTy = dyn_cast<MemRefType>(argTy);
if (!memrefTy)
continue;
@@ -360,18 +359,18 @@ struct FuncOpConversionBase : public ConvertOpToLLVMPattern<func::FuncOp> {
}
if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) {
SmallVector<Attribute, 4> newArgAttrs(
- llvmType.cast<LLVM::LLVMFunctionType>().getNumParams());
+ cast<LLVM::LLVMFunctionType>(llvmType).getNumParams());
for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
// Some LLVM IR attribute have a type attached to them. During FuncOp ->
// LLVMFuncOp conversion these types may have changed. Account for that
// change by converting attributes' types as well.
SmallVector<NamedAttribute, 4> convertedAttrs;
- auto attrsDict = argAttrDicts[i].cast<DictionaryAttr>();
+ auto attrsDict = cast<DictionaryAttr>(argAttrDicts[i]);
convertedAttrs.reserve(attrsDict.size());
for (const NamedAttribute &attr : attrsDict) {
const auto convert = [&](const NamedAttribute &attr) {
return TypeAttr::get(getTypeConverter()->convertType(
- attr.getValue().cast<TypeAttr>().getValue()));
+ cast<TypeAttr>(attr.getValue()).getValue()));
};
if (attr.getName().getValue() ==
LLVM::LLVMDialect::getByValAttrName()) {
@@ -418,7 +417,7 @@ struct FuncOpConversionBase : public ConvertOpToLLVMPattern<func::FuncOp> {
LLVM::Linkage linkage = LLVM::Linkage::External;
if (funcOp->hasAttr(linkageAttrName)) {
auto attr =
- funcOp->getAttr(linkageAttrName).dyn_cast<mlir::LLVM::LinkageAttr>();
+ dyn_cast<mlir::LLVM::LinkageAttr>(funcOp->getAttr(linkageAttrName));
if (!attr) {
funcOp->emitError() << "Contains " << linkageAttrName
<< " attribute not of type LLVM::LinkageAttr";
@@ -545,7 +544,7 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
if (useBarePtrCallConv) {
for (auto it : callOp->getOperands()) {
Type operandType = it.getType();
- if (operandType.isa<UnrankedMemRefType>()) {
+ if (isa<UnrankedMemRefType>(operandType)) {
// Unranked memref is not supported in the bare pointer calling
// convention.
return failure();
@@ -669,11 +668,11 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) {
Type oldTy = std::get<0>(it).getType();
Value newOperand = std::get<1>(it);
- if (oldTy.isa<MemRefType>() && getTypeConverter()->canConvertToBarePtr(
- oldTy.cast<BaseMemRefType>())) {
+ if (isa<MemRefType>(oldTy) && getTypeConverter()->canConvertToBarePtr(
+ cast<BaseMemRefType>(oldTy))) {
MemRefDescriptor memrefDesc(newOperand);
newOperand = memrefDesc.allocatedPtr(rewriter, loc);
- } else if (oldTy.isa<UnrankedMemRefType>()) {
+ } else if (isa<UnrankedMemRefType>(oldTy)) {
// Unranked memref is not supported in the bare pointer calling
// convention.
return failure();
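The FuncToLLVM hunks use both `cast<...>` and `dyn_cast<...>`, and the two differ in intent. `cast` states that the type is known to match and asserts otherwise. `dyn_cast` returns a null handle on a mismatch, which converts to `false`. That lets it drive an `if (auto x = dyn_cast<...>(...))` or, as in `wrapExternalFunction`, be tried for several candidate types in turn.

The sketch below models that contract with a hypothetical handle whose "null" state is a flag. `Kind`, `classifyArg` and the miniature `MemRefType`/`UnrankedMemRefType` are illustrative, not MLIR API.

```cpp
#include <cassert>

// Hypothetical value-semantic handle; a default-constructed Type is null.
enum class Kind { MemRef, UnrankedMemRef, Integer };

class Type {
public:
  Type() = default;
  explicit Type(Kind k) : kind(k), valid(true) {}
  explicit operator bool() const { return valid; }
  Kind getKind() const {
    assert(valid && "querying a null type");
    return kind;
  }

private:
  Kind kind = Kind::Integer;
  bool valid = false;
};

class MemRefType : public Type {
public:
  MemRefType() = default;
  explicit MemRefType(Type t) : Type(t) {}
  static bool classof(Type t) { return t.getKind() == Kind::MemRef; }
};

class UnrankedMemRefType : public Type {
public:
  UnrankedMemRefType() = default;
  explicit UnrankedMemRefType(Type t) : Type(t) {}
  static bool classof(Type t) { return t.getKind() == Kind::UnrankedMemRef; }
};

template <typename To> bool isa(Type t) {
  assert(t && "isa<> on a null type");
  return To::classof(t);
}

// cast: the caller guarantees the kind; a mismatch is a programming error.
template <typename To> To cast(Type t) {
  assert(isa<To>(t) && "cast<> to an incompatible type");
  return To(t);
}

// dyn_cast: a null handle on mismatch, so the result can be tested directly.
template <typename To> To dyn_cast(Type t) {
  return isa<To>(t) ? To(t) : To();
}

// Mirrors the dispatch in wrapExternalFunction.
// Returns 1 for ranked, 2 for unranked, 0 for neither.
int classifyArg(Type input) {
  auto memRefType = dyn_cast<MemRefType>(input);
  auto unrankedMemRefType = dyn_cast<UnrankedMemRefType>(input);
  if (memRefType)
    return 1;
  if (unrankedMemRefType)
    return 2;
  return 0;
}
```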
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index a13acc691ae35..664d077c58875 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -26,22 +26,20 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
for (const auto &en : llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) {
BlockArgument attribution = en.value();
- auto type = attribution.getType().dyn_cast<MemRefType>();
+ auto type = dyn_cast<MemRefType>(attribution.getType());
assert(type && type.hasStaticShape() && "unexpected type in attribution");
uint64_t numElements = type.getNumElements();
auto elementType =
- typeConverter->convertType(type.getElementType()).template cast<Type>();
+ cast<Type>(typeConverter->convertType(type.getElementType()));
auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements);
std::string name = std::string(
llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), en.index()));
uint64_t alignment = 0;
if (auto alignAttr =
- gpuFuncOp
- .getWorkgroupAttributionAttr(
- en.index(), LLVM::LLVMDialect::getAlignAttrName())
- .dyn_cast_or_null<IntegerAttr>())
+ dyn_cast_or_null<IntegerAttr>(gpuFuncOp.getWorkgroupAttributionAttr(
+ en.index(), LLVM::LLVMDialect::getAlignAttrName())))
alignment = alignAttr.getInt();
auto globalOp = rewriter.create<LLVM::GlobalOp>(
gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false,
@@ -100,7 +98,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
global.getAddrSpace()),
global.getSymNameAttr());
auto elementType =
- global.getType().cast<LLVM::LLVMArrayType>().getElementType();
+ cast<LLVM::LLVMArrayType>(global.getType()).getElementType();
Value memory = rewriter.create<LLVM::GEPOp>(
loc,
getTypeConverter()->getPointerType(elementType,
@@ -112,7 +110,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
// otherwise necessary given that memref sizes are fixed, but we can try
// and canonicalize that away later.
Value attribution = gpuFuncOp.getWorkgroupAttributions()[en.index()];
- auto type = attribution.getType().cast<MemRefType>();
+ auto type = cast<MemRefType>(attribution.getType());
auto descr = MemRefDescriptor::fromStaticShape(
rewriter, loc, *getTypeConverter(), type, memory);
signatureConversion.remapInput(numProperArguments + en.index(), descr);
@@ -123,7 +121,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
for (const auto &en : llvm::enumerate(gpuFuncOp.getPrivateAttributions())) {
Value attribution = en.value();
- auto type = attribution.getType().cast<MemRefType>();
+ auto type = cast<MemRefType>(attribution.getType());
assert(type && type.hasStaticShape() && "unexpected type in attribution");
// Explicitly drop memory space when lowering private memory
@@ -136,10 +134,8 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
gpuFuncOp.getLoc(), int64Ty, type.getNumElements());
uint64_t alignment = 0;
if (auto alignAttr =
- gpuFuncOp
- .getPrivateAttributionAttr(
- en.index(), LLVM::LLVMDialect::getAlignAttrName())
- .dyn_cast_or_null<IntegerAttr>())
+ dyn_cast_or_null<IntegerAttr>(gpuFuncOp.getPrivateAttributionAttr(
+ en.index(), LLVM::LLVMDialect::getAlignAttrName())))
alignment = alignAttr.getInt();
Value allocated = rewriter.create<LLVM::AllocaOp>(
gpuFuncOp.getLoc(), ptrType, elementType, numElements, alignment);
@@ -164,7 +160,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&llvmFuncOp.getBody().front());
for (const auto &en : llvm::enumerate(gpuFuncOp.getArgumentTypes())) {
- auto memrefTy = en.value().dyn_cast<MemRefType>();
+ auto memrefTy = dyn_cast<MemRefType>(en.value());
if (!memrefTy)
continue;
assert(memrefTy.hasStaticShape() &&
@@ -302,7 +298,7 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
rewriter.create<LLVM::ConstantOp>(loc, llvmI32, numArgsThisCall));
for (size_t i = group; i < bound; ++i) {
Value arg = adaptor.getArgs()[i];
- if (auto floatType = arg.getType().dyn_cast<FloatType>()) {
+ if (auto floatType = dyn_cast<FloatType>(arg.getType())) {
if (!floatType.isF64())
arg = rewriter.create<LLVM::FPExtOp>(
loc, typeConverter->convertType(rewriter.getF64Type()), arg);
@@ -428,7 +424,7 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
Type type = arg.getType();
Value promotedArg = arg;
assert(type.isIntOrFloat());
- if (type.isa<FloatType>()) {
+ if (isa<FloatType>(type)) {
type = rewriter.getF64Type();
promotedArg = rewriter.create<LLVM::FPExtOp>(loc, type, arg);
}
@@ -462,14 +458,14 @@ LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
LLVMTypeConverter &converter) {
TypeRange operandTypes(operands);
if (llvm::none_of(operandTypes,
- [](Type type) { return type.isa<VectorType>(); })) {
+ [](Type type) { return isa<VectorType>(type); })) {
return rewriter.notifyMatchFailure(op, "expected vector operand");
}
if (op->getNumRegions() != 0 || op->getNumSuccessors() != 0)
return rewriter.notifyMatchFailure(op, "expected no region/successor");
if (op->getNumResults() != 1)
return rewriter.notifyMatchFailure(op, "expected single result");
- VectorType vectorType = op->getResult(0).getType().dyn_cast<VectorType>();
+ VectorType vectorType = dyn_cast<VectorType>(op->getResult(0).getType());
if (!vectorType)
return rewriter.notifyMatchFailure(op, "expected vector result");
@@ -482,7 +478,7 @@ LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
for (int64_t i = 0; i < vectorType.getNumElements(); ++i) {
Value index = rewriter.create<LLVM::ConstantOp>(loc, indexType, i);
auto extractElement = [&](Value operand) -> Value {
- if (!operand.getType().isa<VectorType>())
+ if (!isa<VectorType>(operand.getType()))
return operand;
return rewriter.create<LLVM::ExtractElementOp>(loc, operand, index);
};
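The GPUOpsLowering hunks keep `dyn_cast_or_null` rather than `dyn_cast` for the alignment lookups. `getWorkgroupAttributionAttr` and `getPrivateAttributionAttr` can return a null attribute when no alignment was set. LLVM's `dyn_cast` asserts on a null input, whereas `dyn_cast_or_null` passes the null through.

A minimal sketch of that distinction follows. `Attribute`, `IntegerAttr` and `alignmentFrom` are hypothetical miniatures, not the MLIR classes.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical attribute handle; a default-constructed Attribute is null,
// modelling an attribute lookup that found nothing.
class Attribute {
public:
  Attribute() = default;
  Attribute(bool isInteger, int64_t v)
      : present(true), integer(isInteger), value(v) {}
  explicit operator bool() const { return present; }
  bool isIntegerKind() const {
    assert(present && "querying a null attribute");
    return integer;
  }

protected:
  bool present = false;
  bool integer = false;
  int64_t value = 0;
};

class IntegerAttr : public Attribute {
public:
  IntegerAttr() = default;
  explicit IntegerAttr(Attribute a) : Attribute(a) {}
  static bool classof(Attribute a) { return a.isIntegerKind(); }
  int64_t getInt() const { return value; }
};

// dyn_cast rejects a null input outright, as LLVM's version does.
template <typename To> To dyn_cast(Attribute a) {
  assert(a && "dyn_cast<> on a null attribute; use dyn_cast_or_null");
  return To::classof(a) ? To(a) : To();
}

// dyn_cast_or_null tolerates a null input and returns a null result.
template <typename To> To dyn_cast_or_null(Attribute a) {
  return a ? dyn_cast<To>(a) : To();
}

// Mirrors the alignment lookup in GPUFuncOpLowering: a missing or
// non-integer attribute leaves the default alignment of 0.
uint64_t alignmentFrom(Attribute maybeAlign) {
  uint64_t alignment = 0;
  if (auto alignAttr = dyn_cast_or_null<IntegerAttr>(maybeAlign))
    alignment = alignAttr.getInt();
  return alignment;
}
```

Rewriting these call sites as plain `dyn_cast` would turn "no alignment attribute" from a default into an assertion failure. The mechanical rewrite therefore preserves `dyn_cast_or_null`.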
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 3687bd6718bf1..43dff49e1caea 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -454,7 +454,7 @@ LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
Location loc = op->getLoc();
auto memRefType = hostRegisterOp.getValue().getType();
- auto elementType = memRefType.cast<UnrankedMemRefType>().getElementType();
+ auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
auto elementSize = getSizeInBytes(loc, elementType, rewriter);
auto arguments = getTypeConverter()->promoteOperands(
@@ -476,7 +476,7 @@ LogicalResult ConvertHostUnregisterOpToGpuRuntimeCallPattern::matchAndRewrite(
Location loc = op->getLoc();
auto memRefType = hostUnregisterOp.getValue().getType();
- auto elementType = memRefType.cast<UnrankedMemRefType>().getElementType();
+ auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
auto elementSize = getSizeInBytes(loc, elementType, rewriter);
auto arguments = getTypeConverter()->promoteOperands(
@@ -555,7 +555,7 @@ LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
}
static bool isGpuAsyncTokenType(Value value) {
- return value.getType().isa<gpu::AsyncTokenType>();
+ return isa<gpu::AsyncTokenType>(value.getType());
}
// Converts !gpu.async.token operands of `async.yield` to runtime calls. The
@@ -591,7 +591,7 @@ LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
// Returns whether `value` is the result of an LLVM::CallOp to `functionName`.
static bool isDefinedByCallTo(Value value, StringRef functionName) {
- assert(value.getType().isa<LLVM::LLVMPointerType>());
+ assert(isa<LLVM::LLVMPointerType>(value.getType()));
if (auto defOp = value.getDefiningOp<LLVM::CallOp>())
return defOp.getCallee()->equals(functionName);
return false;
@@ -862,7 +862,7 @@ static Value bitAndAddrspaceCast(Location loc,
LLVM::LLVMPointerType destinationType,
Value sourcePtr,
LLVMTypeConverter &typeConverter) {
- auto sourceTy = sourcePtr.getType().cast<LLVM::LLVMPointerType>();
+ auto sourceTy = cast<LLVM::LLVMPointerType>(sourcePtr.getType());
if (destinationType.getAddressSpace() != sourceTy.getAddressSpace())
sourcePtr = rewriter.create<LLVM::AddrSpaceCastOp>(
loc,
@@ -879,7 +879,7 @@ static Value bitAndAddrspaceCast(Location loc,
LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
- auto memRefType = memcpyOp.getSrc().getType().cast<MemRefType>();
+ auto memRefType = cast<MemRefType>(memcpyOp.getSrc().getType());
if (failed(areAllLLVMTypes(memcpyOp, adaptor.getOperands(), rewriter)) ||
!isConvertibleAndHasIdentityMaps(memRefType) ||
@@ -919,7 +919,7 @@ LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite(
gpu::MemsetOp memsetOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
- auto memRefType = memsetOp.getDst().getType().cast<MemRefType>();
+ auto memRefType = cast<MemRefType>(memsetOp.getDst().getType());
if (failed(areAllLLVMTypes(memsetOp, adaptor.getOperands(), rewriter)) ||
!isConvertibleAndHasIdentityMaps(memRefType) ||
diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
index 6858569862535..ebce2d77310ae 100644
--- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
@@ -54,8 +54,8 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
Type resultType = castedOperands.front().getType();
Type funcType = getFunctionType(resultType, castedOperands);
- StringRef funcName = getFunctionName(
- funcType.cast<LLVM::LLVMFunctionType>().getReturnType());
+ StringRef funcName =
+ getFunctionName(cast<LLVM::LLVMFunctionType>(funcType).getReturnType());
if (funcName.empty())
return failure();
@@ -78,7 +78,7 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
private:
Value maybeCast(Value operand, PatternRewriter &rewriter) const {
Type type = operand.getType();
- if (!type.isa<Float16Type>())
+ if (!isa<Float16Type>(type))
return operand;
return rewriter.create<LLVM::FPExtOp>(
@@ -91,9 +91,9 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
}
StringRef getFunctionName(Type type) const {
- if (type.isa<Float32Type>())
+ if (isa<Float32Type>(type))
return f32Func;
- if (type.isa<Float64Type>())
+ if (isa<Float64Type>(type))
return f64Func;
return "";
}
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index bf5be54f593e9..775dd1e609037 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -91,7 +91,7 @@ struct WmmaLoadOpToNVVMLowering
? NVVM::MMALayout::col
: NVVM::MMALayout::row;
gpu::MMAMatrixType retType =
- subgroupMmaLoadMatrixOp.getRes().getType().cast<gpu::MMAMatrixType>();
+ cast<gpu::MMAMatrixType>(subgroupMmaLoadMatrixOp.getRes().getType());
ArrayRef<int64_t> retTypeShape = retType.getShape();
int64_t m = 0;
int64_t n = 0;
@@ -122,8 +122,7 @@ struct WmmaLoadOpToNVVMLowering
// Create nvvm.mma_load op according to the operand types.
Value dataPtr = getStridedElementPtr(
- loc,
- subgroupMmaLoadMatrixOp.getSrcMemref().getType().cast<MemRefType>(),
+ loc, cast<MemRefType>(subgroupMmaLoadMatrixOp.getSrcMemref().getType()),
adaptor.getSrcMemref(), adaptor.getIndices(), rewriter);
Value leadingDim = rewriter.create<LLVM::ConstantOp>(
@@ -158,7 +157,7 @@ struct WmmaStoreOpToNVVMLowering
// Get the shape of the MMAMatrix type being stored. The shape will
// choose which intrinsic this op will be lowered to.
gpu::MMAMatrixType srcType =
- subgroupMmaStoreMatrixOp.getSrc().getType().cast<gpu::MMAMatrixType>();
+ cast<gpu::MMAMatrixType>(subgroupMmaStoreMatrixOp.getSrc().getType());
ArrayRef<int64_t> srcTypeShape = srcType.getShape();
NVVM::MMALayout layout = subgroupMmaStoreMatrixOp.getTranspose()
? NVVM::MMALayout::col
@@ -170,7 +169,7 @@ struct WmmaStoreOpToNVVMLowering
if (NVVM::WMMAStoreOp::getIntrinsicID(m, n, k, layout, eltype) == 0)
return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
- auto matrixType = adaptor.getSrc().getType().cast<LLVM::LLVMStructType>();
+ auto matrixType = cast<LLVM::LLVMStructType>(adaptor.getSrc().getType());
for (unsigned i = 0, e = matrixType.getBody().size(); i < e; ++i) {
Value toUse =
rewriter.create<LLVM::ExtractValueOp>(loc, adaptor.getSrc(), i);
@@ -179,7 +178,7 @@ struct WmmaStoreOpToNVVMLowering
Value dataPtr = getStridedElementPtr(
loc,
- subgroupMmaStoreMatrixOp.getDstMemref().getType().cast<MemRefType>(),
+ cast<MemRefType>(subgroupMmaStoreMatrixOp.getDstMemref().getType()),
adaptor.getDstMemref(), adaptor.getIndices(), rewriter);
Value leadingDim = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI32Type(),
@@ -214,7 +213,7 @@ struct WmmaMmaOpToNVVMLowering
SmallVector<Value> unpackedOps;
auto unpackOp = [&](Value operand) {
- auto structType = operand.getType().cast<LLVM::LLVMStructType>();
+ auto structType = cast<LLVM::LLVMStructType>(operand.getType());
for (size_t i = 0, e = structType.getBody().size(); i < e; ++i) {
Value toUse = rewriter.create<LLVM::ExtractValueOp>(loc, operand, i);
unpackedOps.push_back(toUse);
@@ -224,10 +223,10 @@ struct WmmaMmaOpToNVVMLowering
// Get the shapes of the MMAMatrix type being used. The shapes will
// choose which intrinsic this op will be lowered to.
gpu::MMAMatrixType aType =
- subgroupMmaComputeOp.getOpA().getType().cast<gpu::MMAMatrixType>();
+ cast<gpu::MMAMatrixType>(subgroupMmaComputeOp.getOpA().getType());
ArrayRef<int64_t> aTypeShape = aType.getShape();
gpu::MMAMatrixType cType =
- subgroupMmaComputeOp.getOpC().getType().cast<gpu::MMAMatrixType>();
+ cast<gpu::MMAMatrixType>(subgroupMmaComputeOp.getOpC().getType());
ArrayRef<int64_t> cTypeShape = cType.getShape();
int64_t m = cTypeShape[0];
int64_t n = cTypeShape[1];
@@ -245,7 +244,7 @@ struct WmmaMmaOpToNVVMLowering
return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
NVVM::MMATypes bElementType = getElementType(
- subgroupMmaComputeOp.getOpB().getType().cast<gpu::MMAMatrixType>());
+ cast<gpu::MMAMatrixType>(subgroupMmaComputeOp.getOpB().getType()));
if (bElementType != sourceType)
return rewriter.notifyMatchFailure(
op, "WMMA compute op input matrix element types must match.");
@@ -277,9 +276,9 @@ struct WmmaConstantOpToNVVMLowering
Location loc = subgroupMmaConstantOp.getLoc();
Value cst = adaptor.getOperands()[0];
LLVM::LLVMStructType type = convertMMAToLLVMType(
- subgroupMmaConstantOp.getType().cast<gpu::MMAMatrixType>());
+ cast<gpu::MMAMatrixType>(subgroupMmaConstantOp.getType()));
// If the element type is a vector create a vector from the operand.
- if (auto vecType = type.getBody()[0].dyn_cast<VectorType>()) {
+ if (auto vecType = dyn_cast<VectorType>(type.getBody()[0])) {
Value vecCst = rewriter.create<LLVM::UndefOp>(loc, vecType);
for (int64_t vecEl = 0; vecEl < vecType.getNumElements(); vecEl++) {
Value idx = rewriter.create<LLVM::ConstantOp>(
@@ -301,9 +300,9 @@ struct WmmaConstantOpToNVVMLowering
static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
Value rhs, bool isMin) {
- auto floatType = getElementTypeOrSelf(lhs.getType()).cast<FloatType>();
+ auto floatType = cast<FloatType>(getElementTypeOrSelf(lhs.getType()));
Type i1Type = builder.getI1Type();
- if (auto vecType = lhs.getType().dyn_cast<VectorType>())
+ if (auto vecType = dyn_cast<VectorType>(lhs.getType()))
i1Type = VectorType::get(vecType.getShape(), i1Type);
Value cmp = builder.create<LLVM::FCmpOp>(
loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
@@ -355,7 +354,7 @@ struct WmmaElementwiseOpToNVVMLowering
Location loc = subgroupMmaElementwiseOp.getLoc();
size_t numOperands = adaptor.getOperands().size();
LLVM::LLVMStructType destType = convertMMAToLLVMType(
- subgroupMmaElementwiseOp.getType().cast<gpu::MMAMatrixType>());
+ cast<gpu::MMAMatrixType>(subgroupMmaElementwiseOp.getType()));
Value matrixStruct = rewriter.create<LLVM::UndefOp>(loc, destType);
for (size_t i = 0, e = destType.getBody().size(); i < e; ++i) {
SmallVector<Value> extractedOperands;
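The WMMA hunks above all apply one mechanical rewrite: `x.cast<T>()` becomes `cast<T>(x)`, `x.dyn_cast<T>()` becomes `dyn_cast<T>(x)`, and `x.isa<T>()` becomes `isa<T>(x)`. The following is a minimal sketch, under stated assumptions, of how free functions can serve value-typed handles like `Type`. It uses LLVM-style RTTI in which `isa`, `cast`, and `dyn_cast` are free templates that dispatch to each subclass's static `classof` hook. `Kind`, `TypeStorage`, `IntegerType`, and `FloatType` are hypothetical stand-ins, not MLIR's classes. The real machinery in `llvm/Support/Casting.h`, reached through its `CastInfo` customization point, is far more general.

```cpp
#include <cassert>

// Hypothetical kind tag standing in for MLIR's TypeID-based dispatch.
enum class Kind { Integer, Float };

// Hypothetical uniqued storage that a handle points to.
struct TypeStorage {
  Kind kind;
  unsigned width;
};

// A cheap, copyable, nullable handle, loosely modeled on mlir::Type.
class Type {
public:
  Type() = default;
  explicit Type(const TypeStorage *impl) : impl(impl) {}
  explicit operator bool() const { return impl != nullptr; }
  Kind getKind() const { return impl->kind; }
  const TypeStorage *getImpl() const { return impl; }

private:
  const TypeStorage *impl = nullptr;
};

class IntegerType : public Type {
public:
  IntegerType() = default;
  using Type::Type;
  // Hook consulted by the free functions below.
  static bool classof(Type t) { return t.getKind() == Kind::Integer; }
  unsigned getWidth() const { return getImpl()->width; }
};

class FloatType : public Type {
public:
  FloatType() = default;
  using Type::Type;
  static bool classof(Type t) { return t.getKind() == Kind::Float; }
};

// Free-function casting API: one definition serves every handle subclass.
template <typename To> bool isa(Type t) {
  assert(t && "isa<> used on a null type");
  return To::classof(t);
}

template <typename To> To cast(Type t) {
  assert(isa<To>(t) && "cast<> argument of incompatible type");
  return To(t.getImpl());
}

template <typename To> To dyn_cast(Type t) {
  // Returns a null handle on a kind mismatch instead of asserting.
  return isa<To>(t) ? To(t.getImpl()) : To();
}
```

With this shape, `cast<IntegerType>(t).getWidth()` and `dyn_cast<FloatType>(t)` are spelled the same way as the pointer-based LLVM casts. That is the kind of uniformity the commit message has in mind when it calls the free functions "more consistent".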
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index 98f90a3a8aef5..1ac4e8ed00051 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -54,7 +54,7 @@ using namespace mlir;
static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) {
bool canBeBare = true;
for (Type type : func.getArgumentTypes())
- if (auto memrefTy = type.dyn_cast<BaseMemRefType>())
+ if (auto memrefTy = dyn_cast<BaseMemRefType>(type))
canBeBare &= LLVMTypeConverter::canConvertToBarePtr(memrefTy);
return canBeBare;
}
@@ -166,9 +166,8 @@ struct LowerGpuOpsToROCDLOpsPass
// Manually rewrite known block size attributes so the LLVMIR translation
// infrastructure can pick them up.
m.walk([ctx](LLVM::LLVMFuncOp op) {
- if (auto blockSizes =
- op->removeAttr(gpu::GPUFuncOp::getKnownBlockSizeAttrName())
- .dyn_cast_or_null<DenseI32ArrayAttr>()) {
+ if (auto blockSizes = dyn_cast_or_null<DenseI32ArrayAttr>(
+ op->removeAttr(gpu::GPUFuncOp::getKnownBlockSizeAttrName()))) {
op->setAttr(ROCDL::ROCDLDialect::getReqdWorkGroupSizeAttrName(),
blockSizes);
// Also set up the rocdl.flat_work_group_size attribute to prevent
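The ROCDL hunk uses `dyn_cast_or_null` because `Operation::removeAttr` returns a null `Attribute` when no attribute with that name exists, and plain `dyn_cast` asserts on a null input. The self-contained sketch below shows that behavior. `AttrStorage`, `I32ArrayAttr`, and the map-based `removeAttr` helper are hypothetical simplifications, not MLIR's API.

```cpp
#include <cassert>
#include <map>
#include <string>

// Hypothetical attribute kinds and storage.
enum class AttrKind { I32Array, String };

struct AttrStorage {
  AttrKind kind;
  int payload;
};

// Nullable handle, loosely modeled on mlir::Attribute.
class Attribute {
public:
  Attribute() = default;
  explicit Attribute(const AttrStorage *impl) : impl(impl) {}
  explicit operator bool() const { return impl != nullptr; }
  AttrKind getKind() const { return impl->kind; }
  const AttrStorage *getImpl() const { return impl; }

private:
  const AttrStorage *impl = nullptr;
};

class I32ArrayAttr : public Attribute {
public:
  I32ArrayAttr() = default;
  using Attribute::Attribute;
  static bool classof(Attribute a) { return a.getKind() == AttrKind::I32Array; }
  int getPayload() const { return getImpl()->payload; }
};

// Like dyn_cast, but a null input simply yields a null result.
template <typename To> To dyn_cast_or_null(Attribute a) {
  if (!a || !To::classof(a))
    return To();
  return To(a.getImpl());
}

// Hypothetical stand-in for Operation::removeAttr: erases the entry and
// returns it, or a null Attribute when the name was not present.
Attribute removeAttr(std::map<std::string, Attribute> &attrs,
                     const std::string &name) {
  auto it = attrs.find(name);
  if (it == attrs.end())
    return Attribute();
  Attribute removed = it->second;
  attrs.erase(it);
  return removed;
}
```

Calling `removeAttr` a second time with the same name yields a null handle, which `dyn_cast_or_null` turns into a null `I32ArrayAttr` rather than an assertion failure.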
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index becb28e61fd5d..feea1e34f1b43 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -495,9 +495,9 @@ static std::optional<Value> createGroupReduceOp(OpBuilder &builder,
Type type = arg.getType();
using MembptrT = FuncT OpHandler::*;
MembptrT handlerPtr;
- if (type.isa<FloatType>()) {
+ if (isa<FloatType>(type)) {
handlerPtr = &OpHandler::floatFunc;
- } else if (type.isa<IntegerType>()) {
+ } else if (isa<IntegerType>(type)) {
handlerPtr = &OpHandler::intFunc;
} else {
return std::nullopt;
diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
index f7e135620133c..d64fa6ac4ece2 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -81,9 +81,9 @@ struct WmmaLoadOpToSPIRVLowering
ConversionPatternRewriter &rewriter) const override {
Location loc = subgroupMmaLoadMatrixOp->getLoc();
gpu::MMAMatrixType retType =
- subgroupMmaLoadMatrixOp.getRes().getType().cast<gpu::MMAMatrixType>();
+ cast<gpu::MMAMatrixType>(subgroupMmaLoadMatrixOp.getRes().getType());
auto memrefType =
- subgroupMmaLoadMatrixOp.getSrcMemref().getType().cast<MemRefType>();
+ cast<MemRefType>(subgroupMmaLoadMatrixOp.getSrcMemref().getType());
Value bufferPtr = spirv::getElementPtr(
*getTypeConverter<SPIRVTypeConverter>(), memrefType,
adaptor.getSrcMemref(), adaptor.getIndices(), loc, rewriter);
@@ -114,7 +114,7 @@ struct WmmaStoreOpToSPIRVLowering
ConversionPatternRewriter &rewriter) const override {
Location loc = subgroupMmaStoreMatrixOp->getLoc();
auto memrefType =
- subgroupMmaStoreMatrixOp.getDstMemref().getType().cast<MemRefType>();
+ cast<MemRefType>(subgroupMmaStoreMatrixOp.getDstMemref().getType());
Value bufferPtr = spirv::getElementPtr(
*getTypeConverter<SPIRVTypeConverter>(), memrefType,
adaptor.getDstMemref(), adaptor.getIndices(), loc, rewriter);
@@ -161,7 +161,7 @@ struct WmmaConstantOpToSPIRVLowering
ConversionPatternRewriter &rewriter) const override {
Value cst = adaptor.getOperands()[0];
auto coopType = convertMMAToSPIRVType(
- subgroupMmaConstantMatrixOp.getType().cast<gpu::MMAMatrixType>());
+ cast<gpu::MMAMatrixType>(subgroupMmaConstantMatrixOp.getType()));
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
subgroupMmaConstantMatrixOp, coopType, cst);
return success();
@@ -180,11 +180,11 @@ struct WmmaElementwiseOpToSPIRVDefaultLowering
ConversionPatternRewriter &rewriter) const override {
// All operands should be of cooperative matrix types.
for (Value operand : adaptor.getOperands()) {
- if (!operand.getType().isa<spirv::CooperativeMatrixNVType>())
+ if (!isa<spirv::CooperativeMatrixNVType>(operand.getType()))
return failure();
}
auto coopType = convertMMAToSPIRVType(
- elementwiseOp.getType().cast<gpu::MMAMatrixType>());
+ cast<gpu::MMAMatrixType>(elementwiseOp.getType()));
return success(createElementwiseOp(rewriter, elementwiseOp, coopType,
adaptor.getOperands()));
}
@@ -204,7 +204,7 @@ struct WmmaElementwiseOpToSPIRVScalarMulLowering
return failure();
// All operands should be of cooperative matrix types.
for (Value operand : adaptor.getOperands()) {
- if (!operand.getType().isa<spirv::CooperativeMatrixNVType>())
+ if (!isa<spirv::CooperativeMatrixNVType>(operand.getType()))
return failure();
}
@@ -236,7 +236,7 @@ struct WmmaElementwiseOpToSPIRVScalarMulLowering
scalar = cc.getConstituents().front();
auto coopType = convertMMAToSPIRVType(
- elementwiseOp.getType().cast<gpu::MMAMatrixType>());
+ cast<gpu::MMAMatrixType>(elementwiseOp.getType()));
rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>(
elementwiseOp, coopType, ValueRange{matrix, scalar});
return success();
diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
index e4ac64252acc4..2d2251672230b 100644
--- a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
+++ b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
@@ -61,7 +61,7 @@ class ConvertGpuLaunchFuncToVulkanLaunchFunc
/// Checks whether the given type is supported by the Vulkan runtime.
bool isSupportedType(Type type) {
- if (auto memRefType = type.dyn_cast_or_null<MemRefType>()) {
+ if (auto memRefType = dyn_cast_or_null<MemRefType>(type)) {
auto elementType = memRefType.getElementType();
return memRefType.hasRank() &&
(memRefType.getRank() >= 1 && memRefType.getRank() <= 3) &&
@@ -197,7 +197,7 @@ void ConvertGpuLaunchFuncToVulkanLaunchFunc::convertGpuLaunchFunc(
// The below cast always succeeds as it has already been verified in
// 'declareVulkanLaunchFunc' that these are MemRefs with compatible element
// types.
- elementTypes.push_back(type.cast<MemRefType>().getElementType());
+ elementTypes.push_back(cast<MemRefType>(type).getElementType());
}
vulkanLaunchCallOp->setAttr(kSPIRVElementTypesAttrName,
builder.getTypeArrayAttr(elementTypes));
diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
index 78d1f6790c859..036eb0200ca71 100644
--- a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
+++ b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
@@ -142,11 +142,11 @@ class VulkanLaunchFuncToVulkanCallsPass
/// Returns a string representation from the given `type`.
StringRef stringifyType(Type type) {
- if (type.isa<Float32Type>())
+ if (isa<Float32Type>(type))
return "Float";
- if (type.isa<Float16Type>())
+ if (isa<Float16Type>(type))
return "Half";
- if (auto intType = type.dyn_cast<IntegerType>()) {
+ if (auto intType = dyn_cast<IntegerType>(type)) {
if (intType.getWidth() == 32)
return "Int32";
if (intType.getWidth() == 16)
@@ -282,7 +282,7 @@ void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
llvm::formatv("bindMemRef{0}D{1}", rank, stringifyType(type)).str();
// Special case for fp16 type. Since it is not a supported type in C we use
// int16_t and bitcast the descriptor.
- if (!useOpaquePointers && type.isa<Float16Type>()) {
+ if (!useOpaquePointers && isa<Float16Type>(type)) {
auto memRefTy = getMemRefType(rank, IntegerType::get(&getContext(), 16));
ptrToMemRefDescriptor = builder.create<LLVM::BitcastOp>(
loc, LLVM::LLVMPointerType::get(memRefTy), ptrToMemRefDescriptor);
@@ -328,9 +328,8 @@ VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRank(Value launchCallArg,
rank = 0;
return success();
}
- rank = llvmDescriptorTy.getBody()[3]
- .cast<LLVM::LLVMArrayType>()
- .getNumElements();
+ rank =
+ cast<LLVM::LLVMArrayType>(llvmDescriptorTy.getBody()[3]).getNumElements();
return success();
}
@@ -375,7 +374,7 @@ void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
for (auto type : types) {
std::string fnName = "bindMemRef" + std::to_string(i) + "D" +
std::string(stringifyType(type));
- if (type.isa<Float16Type>())
+ if (isa<Float16Type>(type))
type = IntegerType::get(&getContext(), 16);
if (!module.lookupSymbol(fnName)) {
auto fnType = LLVM::LLVMFunctionType::get(
diff --git a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
index 2373765dae007..df9dafc2d696c 100644
--- a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
@@ -24,8 +24,7 @@ using namespace mlir;
MemRefDescriptor::MemRefDescriptor(Value descriptor)
: StructBuilder(descriptor) {
assert(value != nullptr && "value cannot be null");
- indexType = value.getType()
- .cast<LLVM::LLVMStructType>()
+ indexType = cast<LLVM::LLVMStructType>(value.getType())
.getBody()[kOffsetPosInMemRefDescriptor];
}
@@ -193,10 +192,9 @@ void MemRefDescriptor::setConstantStride(OpBuilder &builder, Location loc,
}
LLVM::LLVMPointerType MemRefDescriptor::getElementPtrType() {
- return value.getType()
- .cast<LLVM::LLVMStructType>()
- .getBody()[kAlignedPtrPosInMemRefDescriptor]
- .cast<LLVM::LLVMPointerType>();
+ return cast<LLVM::LLVMPointerType>(
+ cast<LLVM::LLVMStructType>(value.getType())
+ .getBody()[kAlignedPtrPosInMemRefDescriptor]);
}
Value MemRefDescriptor::bufferPtr(OpBuilder &builder, Location loc,
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 67a2898c02050..c55a62e1edb61 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -235,7 +235,7 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
SmallVector<UnrankedMemRefDescriptor> unrankedMemrefs;
SmallVector<unsigned> unrankedAddressSpaces;
for (unsigned i = 0, e = operands.size(); i < e; ++i) {
- if (auto memRefType = origTypes[i].dyn_cast<UnrankedMemRefType>()) {
+ if (auto memRefType = dyn_cast<UnrankedMemRefType>(origTypes[i])) {
unrankedMemrefs.emplace_back(operands[i]);
FailureOr<unsigned> addressSpace =
getTypeConverter()->getMemRefAddressSpace(memRefType);
@@ -276,7 +276,7 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
unsigned unrankedMemrefPos = 0;
for (unsigned i = 0, e = operands.size(); i < e; ++i) {
Type type = origTypes[i];
- if (!type.isa<UnrankedMemRefType>())
+ if (!isa<UnrankedMemRefType>(type))
continue;
Value allocationSize = sizes[unrankedMemrefPos++];
UnrankedMemRefDescriptor desc(operands[i]);
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 88d7eaf524bad..cf0d5068c02d7 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -260,7 +260,7 @@ LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) {
if (!resultType)
return {};
- auto structType = resultType.dyn_cast<LLVM::LLVMStructType>();
+ auto structType = dyn_cast<LLVM::LLVMStructType>(resultType);
if (structType) {
// Struct types cannot be safely returned via C interface. Make this a
// pointer argument, instead.
@@ -272,7 +272,7 @@ LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) {
auto converted = convertType(t);
if (!converted || !LLVM::isCompatibleType(converted))
return {};
- if (t.isa<MemRefType, UnrankedMemRefType>())
+ if (isa<MemRefType, UnrankedMemRefType>(t))
converted = getPointerType(converted);
inputs.push_back(converted);
}
@@ -412,13 +412,13 @@ LLVMTypeConverter::getMemRefAddressSpace(BaseMemRefType type) {
// Check if a memref type can be converted to a bare pointer.
bool LLVMTypeConverter::canConvertToBarePtr(BaseMemRefType type) {
- if (type.isa<UnrankedMemRefType>())
+ if (isa<UnrankedMemRefType>(type))
// Unranked memref is not supported in the bare pointer calling convention.
return false;
// Check that the memref has static shape, strides and offset. Otherwise, it
// cannot be lowered to a bare pointer.
- auto memrefTy = type.cast<MemRefType>();
+ auto memrefTy = cast<MemRefType>(type);
if (!memrefTy.hasStaticShape())
return false;
@@ -476,7 +476,7 @@ Type LLVMTypeConverter::convertVectorType(VectorType type) {
Type LLVMTypeConverter::convertCallingConventionType(Type type,
bool useBarePtrCallConv) {
if (useBarePtrCallConv)
- if (auto memrefTy = type.dyn_cast<BaseMemRefType>())
+ if (auto memrefTy = dyn_cast<BaseMemRefType>(type))
return convertMemRefToBarePtr(memrefTy);
return convertType(type);
@@ -491,7 +491,7 @@ void LLVMTypeConverter::promoteBarePtrsToDescriptors(
assert(stdTypes.size() == values.size() &&
"The number of types and values doesn't match");
for (unsigned i = 0, end = values.size(); i < end; ++i)
- if (auto memrefTy = stdTypes[i].dyn_cast<MemRefType>())
+ if (auto memrefTy = dyn_cast<MemRefType>(stdTypes[i]))
values[i] = MemRefDescriptor::fromStaticShape(rewriter, loc, *this,
memrefTy, values[i]);
}
@@ -569,19 +569,19 @@ LLVMTypeConverter::promoteOperands(Location loc, ValueRange opOperands,
if (useBarePtrCallConv) {
// For the bare-ptr calling convention, we only have to extract the
// aligned pointer of a memref.
- if (auto memrefType = operand.getType().dyn_cast<MemRefType>()) {
+ if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
MemRefDescriptor desc(llvmOperand);
llvmOperand = desc.alignedPtr(builder, loc);
- } else if (operand.getType().isa<UnrankedMemRefType>()) {
+ } else if (isa<UnrankedMemRefType>(operand.getType())) {
llvm_unreachable("Unranked memrefs are not supported");
}
} else {
- if (operand.getType().isa<UnrankedMemRefType>()) {
+ if (isa<UnrankedMemRefType>(operand.getType())) {
UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand,
promotedOperands);
continue;
}
- if (auto memrefType = operand.getType().dyn_cast<MemRefType>()) {
+ if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
MemRefDescriptor::unpack(builder, loc, llvmOperand, memrefType,
promotedOperands);
continue;
@@ -600,7 +600,7 @@ LLVMTypeConverter::promoteOperands(Location loc, ValueRange opOperands,
LogicalResult mlir::structFuncArgTypeConverter(LLVMTypeConverter &converter,
Type type,
SmallVectorImpl<Type> &result) {
- if (auto memref = type.dyn_cast<MemRefType>()) {
+ if (auto memref = dyn_cast<MemRefType>(type)) {
// In signatures, Memref descriptors are expanded into lists of
// non-aggregate values.
auto converted =
@@ -610,7 +610,7 @@ LogicalResult mlir::structFuncArgTypeConverter(LLVMTypeConverter &converter,
result.append(converted.begin(), converted.end());
return success();
}
- if (type.isa<UnrankedMemRefType>()) {
+ if (isa<UnrankedMemRefType>(type)) {
auto converted = converter.getUnrankedMemRefDescriptorFields();
if (converted.empty())
return failure();
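Some hunks use the variadic form of `isa`, which holds if the value matches any of the listed types. Examples are `isa<MemRefType, UnrankedMemRefType>(t)` here and `isa<Float16Type, BFloat16Type>(opType)` in MathToLibm below. The following minimal sketch implements such an overload with a C++17 fold expression. The `Kind`, `Type`, and marker structs are hypothetical, and LLVM's own implementation may be structured differently.

```cpp
#include <cassert>

// Hypothetical kinds standing in for MLIR type classes.
enum class Kind { MemRef, UnrankedMemRef, Vector };

// Hypothetical type handle carrying only a kind tag.
struct Type {
  Kind kind;
};

// Hypothetical marker types, each exposing a classof hook.
struct MemRefType {
  static bool classof(Type t) { return t.kind == Kind::MemRef; }
};
struct UnrankedMemRefType {
  static bool classof(Type t) { return t.kind == Kind::UnrankedMemRef; }
};
struct VectorType {
  static bool classof(Type t) { return t.kind == Kind::Vector; }
};

// True if `t` is any of the listed types (fold over ||, short-circuiting).
template <typename... Ts> bool isa(Type t) {
  static_assert(sizeof...(Ts) > 0, "isa<> needs at least one type");
  return (Ts::classof(t) || ...);
}
```

The single-type call `isa<VectorType>(v)` is just the one-element case of the same fold. When the multi-type form is placed inside the `assert` macro, it needs an extra pair of parentheses because of the comma between template arguments.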
diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
index e95c702d79f38..732f6c578c8b5 100644
--- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
@@ -27,10 +27,10 @@ LLVM::detail::extractNDVectorTypeInfo(VectorType vectorType,
}
info.arraySizes.reserve(vectorType.getRank() - 1);
auto llvmTy = info.llvmNDVectorTy;
- while (llvmTy.isa<LLVM::LLVMArrayType>()) {
+ while (isa<LLVM::LLVMArrayType>(llvmTy)) {
info.arraySizes.push_back(
- llvmTy.cast<LLVM::LLVMArrayType>().getNumElements());
- llvmTy = llvmTy.cast<LLVM::LLVMArrayType>().getElementType();
+ cast<LLVM::LLVMArrayType>(llvmTy).getNumElements());
+ llvmTy = cast<LLVM::LLVMArrayType>(llvmTy).getElementType();
}
if (!LLVM::isCompatibleVectorType(llvmTy))
return info;
@@ -81,7 +81,7 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors(
Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter,
std::function<Value(Type, ValueRange)> createOperand,
ConversionPatternRewriter &rewriter) {
- auto resultNDVectorType = op->getResult(0).getType().cast<VectorType>();
+ auto resultNDVectorType = cast<VectorType>(op->getResult(0).getType());
auto resultTypeInfo =
extractNDVectorTypeInfo(resultNDVectorType, typeConverter);
auto result1DVectorTy = resultTypeInfo.llvm1DVectorTy;
@@ -114,7 +114,7 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite(
return failure();
auto llvmNDVectorTy = operands[0].getType();
- if (!llvmNDVectorTy.isa<LLVM::LLVMArrayType>())
+ if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy))
return oneToOneRewrite(op, targetOp, operands, targetAttrs, typeConverter,
rewriter);
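`extractNDVectorTypeInfo` walks through nested `LLVM::LLVMArrayType` layers and records each array length until it reaches the innermost vector type. This assumes the usual lowering of an n-D vector such as `vector<2x3x4xf32>` to `!llvm.array<2 x array<3 x vector<4xf32>>>`, in which case the loop collects `[2, 3]`. Below is a self-contained sketch of the same peeling loop over a hypothetical `TypeNode` struct, which is not an MLIR class.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical type node: an array layer wraps an element type; a
// non-array node is the innermost 1-D vector type.
struct TypeNode {
  bool isArray;
  int64_t numElements;     // meaningful only when isArray is true
  const TypeNode *element; // meaningful only when isArray is true
};

// Peels array layers from the outside in, appending each length to
// `sizes`, and returns the innermost (non-array) type.
const TypeNode *peelArrays(const TypeNode *ty, std::vector<int64_t> &sizes) {
  while (ty->isArray) {
    sizes.push_back(ty->numElements);
    ty = ty->element;
  }
  return ty;
}
```

The outermost length is recorded first, so `sizes` lists the dimensions in the same order as the original vector shape, minus the innermost one.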
diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
index f94da68e87b34..4d1f35c767304 100644
--- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
+++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
@@ -42,7 +42,7 @@ static SmallVector<Type, 4> extractOperandTypes(Operation *op) {
// The underlying descriptor type (e.g. LLVM) does not have layout
// information. Canonicalizing the type at the level of std when going into
// a library call avoids needing to introduce DialectCastOp.
- if (auto memrefType = type.dyn_cast<MemRefType>())
+ if (auto memrefType = dyn_cast<MemRefType>(type))
result.push_back(makeStridedLayoutDynamic(memrefType));
else
result.push_back(type);
@@ -96,7 +96,7 @@ createTypeCanonicalizedMemRefOperands(OpBuilder &b, Location loc,
SmallVector<Value, 4> res;
res.reserve(operands.size());
for (auto op : operands) {
- auto memrefType = op.getType().dyn_cast<MemRefType>();
+ auto memrefType = dyn_cast<MemRefType>(op.getType());
if (!memrefType) {
res.push_back(op);
continue;
diff --git a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
index 10832a15f01a8..3a567643ffdb8 100644
--- a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
+++ b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
@@ -106,7 +106,7 @@ LogicalResult
VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
Type opType = op.getType();
Location loc = op.getLoc();
- auto vecType = opType.template dyn_cast<VectorType>();
+ auto vecType = dyn_cast<VectorType>(opType);
if (!vecType)
return rewriter.notifyMatchFailure(op, "not a vector operation");
@@ -117,7 +117,7 @@ VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
Type resultElementType = vecType.getElementType();
Attribute initValueAttr;
- if (resultElementType.isa<FloatType>())
+ if (isa<FloatType>(resultElementType))
initValueAttr = FloatAttr::get(resultElementType, 0.0);
else
initValueAttr = IntegerAttr::get(resultElementType, 0);
@@ -183,7 +183,7 @@ static FunctionType getElementalFuncTypeForOp(Operation *op) {
/// }
/// }
static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType) {
- assert(elementType.isa<IntegerType>() &&
+ assert(isa<IntegerType>(elementType) &&
"non-integer element type for IPowIOp");
ImplicitLocOpBuilder builder =
@@ -361,7 +361,7 @@ static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType) {
LogicalResult
IPowIOpLowering::matchAndRewrite(math::IPowIOp op,
PatternRewriter &rewriter) const {
- auto baseType = op.getOperands()[0].getType().dyn_cast<IntegerType>();
+ auto baseType = dyn_cast<IntegerType>(op.getOperands()[0].getType());
if (!baseType)
return rewriter.notifyMatchFailure(op, "non-integer base operand");
@@ -411,8 +411,8 @@ IPowIOpLowering::matchAndRewrite(math::IPowIOp op,
/// }
static func::FuncOp createElementFPowIFunc(ModuleOp *module,
FunctionType funcType) {
- auto baseType = funcType.getInput(0).cast<FloatType>();
- auto powType = funcType.getInput(1).cast<IntegerType>();
+ auto baseType = cast<FloatType>(funcType.getInput(0));
+ auto powType = cast<IntegerType>(funcType.getInput(1));
ImplicitLocOpBuilder builder =
ImplicitLocOpBuilder::atBlockEnd(module->getLoc(), module->getBody());
@@ -586,7 +586,7 @@ static func::FuncOp createElementFPowIFunc(ModuleOp *module,
LogicalResult
FPowIOpLowering::matchAndRewrite(math::FPowIOp op,
PatternRewriter &rewriter) const {
- if (op.getType().template dyn_cast<VectorType>())
+ if (dyn_cast<VectorType>(op.getType()))
return rewriter.notifyMatchFailure(op, "non-scalar operation");
FunctionType funcType = getElementalFuncTypeForOp(op);
@@ -649,7 +649,7 @@ FPowIOpLowering::matchAndRewrite(math::FPowIOp op,
/// return %out: i32
/// }
static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) {
- if (!elementType.isa<IntegerType>()) {
+ if (!isa<IntegerType>(elementType)) {
LLVM_DEBUG({
DBGS() << "non-integer element type for CtlzFunc; type was: ";
elementType.print(llvm::dbgs());
@@ -751,7 +751,7 @@ static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) {
/// operation.
LogicalResult CtlzOpLowering::matchAndRewrite(math::CountLeadingZerosOp op,
PatternRewriter &rewriter) const {
- if (op.getType().template dyn_cast<VectorType>())
+ if (dyn_cast<VectorType>(op.getType()))
return rewriter.notifyMatchFailure(op, "non-scalar operation");
Type type = getElementTypeOrSelf(op.getResult().getType());
@@ -794,7 +794,7 @@ struct ConvertMathToFuncsPass
bool ConvertMathToFuncsPass::isFPowIConvertible(math::FPowIOp op) {
auto expTy =
- getElementTypeOrSelf(op.getRhs().getType()).dyn_cast<IntegerType>();
+ dyn_cast<IntegerType>(getElementTypeOrSelf(op.getRhs().getType()));
return (expTy && expTy.getWidth() >= minWidthOfFPowIExponent);
}
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index c331f4f2163bd..6dc5c418d4173 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -79,14 +79,14 @@ struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
auto resultType = op.getResult().getType();
auto boolZero = rewriter.getBoolAttr(false);
- if (!operandType.template isa<LLVM::LLVMArrayType>()) {
+ if (!isa<LLVM::LLVMArrayType>(operandType)) {
LLVM::ConstantOp zero = rewriter.create<LLVM::ConstantOp>(loc, boolZero);
rewriter.replaceOpWithNewOp<LLVMOp>(op, resultType, adaptor.getOperand(),
zero);
return success();
}
- auto vectorType = resultType.template dyn_cast<VectorType>();
+ auto vectorType = dyn_cast<VectorType>(resultType);
if (!vectorType)
return failure();
@@ -122,17 +122,17 @@ struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
auto loc = op.getLoc();
auto resultType = op.getResult().getType();
- auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
+ auto floatType = cast<FloatType>(getElementTypeOrSelf(resultType));
auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
ConvertFastMath<math::ExpM1Op, LLVM::ExpOp> expAttrs(op);
ConvertFastMath<math::ExpM1Op, LLVM::FSubOp> subAttrs(op);
- if (!operandType.isa<LLVM::LLVMArrayType>()) {
+ if (!isa<LLVM::LLVMArrayType>(operandType)) {
LLVM::ConstantOp one;
if (LLVM::isCompatibleVectorType(operandType)) {
one = rewriter.create<LLVM::ConstantOp>(
loc, operandType,
- SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
+ SplatElementsAttr::get(cast<ShapedType>(resultType), floatOne));
} else {
one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
}
@@ -143,7 +143,7 @@ struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
return success();
}
- auto vectorType = resultType.dyn_cast<VectorType>();
+ auto vectorType = dyn_cast<VectorType>(resultType);
if (!vectorType)
return rewriter.notifyMatchFailure(op, "expected vector result type");
@@ -180,17 +180,17 @@ struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
auto loc = op.getLoc();
auto resultType = op.getResult().getType();
- auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
+ auto floatType = cast<FloatType>(getElementTypeOrSelf(resultType));
auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
ConvertFastMath<math::Log1pOp, LLVM::FAddOp> addAttrs(op);
ConvertFastMath<math::Log1pOp, LLVM::LogOp> logAttrs(op);
- if (!operandType.isa<LLVM::LLVMArrayType>()) {
+ if (!isa<LLVM::LLVMArrayType>(operandType)) {
LLVM::ConstantOp one =
LLVM::isCompatibleVectorType(operandType)
? rewriter.create<LLVM::ConstantOp>(
loc, operandType,
- SplatElementsAttr::get(resultType.cast<ShapedType>(),
+ SplatElementsAttr::get(cast<ShapedType>(resultType),
floatOne))
: rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
@@ -202,7 +202,7 @@ struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
return success();
}
- auto vectorType = resultType.dyn_cast<VectorType>();
+ auto vectorType = dyn_cast<VectorType>(resultType);
if (!vectorType)
return rewriter.notifyMatchFailure(op, "expected vector result type");
@@ -240,17 +240,17 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
auto loc = op.getLoc();
auto resultType = op.getResult().getType();
- auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
+ auto floatType = cast<FloatType>(getElementTypeOrSelf(resultType));
auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
ConvertFastMath<math::RsqrtOp, LLVM::SqrtOp> sqrtAttrs(op);
ConvertFastMath<math::RsqrtOp, LLVM::FDivOp> divAttrs(op);
- if (!operandType.isa<LLVM::LLVMArrayType>()) {
+ if (!isa<LLVM::LLVMArrayType>(operandType)) {
LLVM::ConstantOp one;
if (LLVM::isCompatibleVectorType(operandType)) {
one = rewriter.create<LLVM::ConstantOp>(
loc, operandType,
- SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
+ SplatElementsAttr::get(cast<ShapedType>(resultType), floatOne));
} else {
one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
}
@@ -261,7 +261,7 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
return success();
}
- auto vectorType = resultType.dyn_cast<VectorType>();
+ auto vectorType = dyn_cast<VectorType>(resultType);
if (!vectorType)
return failure();
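The ExpM1, Log1p, and Rsqrt patterns above expand each op into primitive LLVM ops around a constant `one`. That constant is a scalar, or a `SplatElementsAttr` splat for 1-D vector operands. The expansions are expm1(x) = exp(x) - 1, log1p(x) = log(1 + x), and rsqrt(x) = 1 / sqrt(x). The sketch below models an operand as `std::vector<float>`, with a scalar treated as one lane. It covers only that scalar/1-D path, not the separate `!llvm.array` path for n-D vectors. The helper names are hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// An operand modeled as lanes; a scalar is a single lane.
using Lanes = std::vector<float>;

// Splat of a constant across all lanes, playing the role of SplatElementsAttr.
Lanes splat(float value, std::size_t n) { return Lanes(n, value); }

// expm1(x) = exp(x) - 1, mirroring the ExpOp + FSubOp expansion.
Lanes expm1Lowered(const Lanes &x) {
  Lanes one = splat(1.0f, x.size());
  Lanes out(x.size());
  for (std::size_t i = 0; i < x.size(); ++i)
    out[i] = std::exp(x[i]) - one[i];
  return out;
}

// log1p(x) = log(1 + x), mirroring the FAddOp + LogOp expansion.
Lanes log1pLowered(const Lanes &x) {
  Lanes one = splat(1.0f, x.size());
  Lanes out(x.size());
  for (std::size_t i = 0; i < x.size(); ++i)
    out[i] = std::log(one[i] + x[i]);
  return out;
}

// rsqrt(x) = 1 / sqrt(x), mirroring the SqrtOp + FDivOp expansion.
Lanes rsqrtLowered(const Lanes &x) {
  Lanes one = splat(1.0f, x.size());
  Lanes out(x.size());
  for (std::size_t i = 0; i < x.size(); ++i)
    out[i] = one[i] / std::sqrt(x[i]);
  return out;
}
```

These direct expansions are less accurate than `std::expm1` and `std::log1p` for very small `x`, because `exp(x) - 1` and `1 + x` lose low-order bits to cancellation and rounding.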
diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
index d6834441e92bb..7fd94116ef7a6 100644
--- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -75,7 +75,7 @@ LogicalResult
VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
auto opType = op.getType();
auto loc = op.getLoc();
- auto vecType = opType.template dyn_cast<VectorType>();
+ auto vecType = dyn_cast<VectorType>(opType);
if (!vecType)
return failure();
@@ -107,7 +107,7 @@ template <typename Op>
LogicalResult
PromoteOpToF32<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
auto opType = op.getType();
- if (!opType.template isa<Float16Type, BFloat16Type>())
+ if (!isa<Float16Type, BFloat16Type>(opType))
return failure();
auto loc = op.getLoc();
@@ -127,7 +127,7 @@ ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
PatternRewriter &rewriter) const {
auto module = SymbolTable::getNearestSymbolTable(op);
auto type = op.getType();
- if (!type.template isa<Float32Type, Float64Type>())
+ if (!isa<Float32Type, Float64Type>(type))
return failure();
auto name = type.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc;
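Inside templated patterns such as `VecOpToScalarOp<Op>` and `PromoteOpToF32<Op>`, `auto opType = op.getType();` has a type that depends on the template parameter. The old member form was therefore spelled with the `.template` disambiguator, as in the removed `opType.template dyn_cast<VectorType>()` lines. The free-function call needs no disambiguator. In the minimal sketch below, `Handle`, `Vec`, `FakeOp`, and both `isVector*` helpers are hypothetical.

```cpp
#include <cassert>

// Hypothetical handle offering both a member-template and a free query.
struct Handle {
  int kind;
  // Member form, analogous to the old `type.isa<T>()` methods.
  template <typename To> bool isa() const { return To::classof(*this); }
};

// Hypothetical marker type with a classof hook.
struct Vec {
  static bool classof(const Handle &h) { return h.kind == 1; }
};

// Free form, analogous to `isa<T>(type)`.
template <typename To> bool isa(const Handle &h) { return To::classof(h); }

// Hypothetical op whose getType() returns a Handle.
struct FakeOp {
  Handle type;
  Handle getType() const { return type; }
};

// Op is a template parameter, so `t` below has a dependent type.
template <typename Op> bool isVectorMember(const Op &op) {
  auto t = op.getType();
  // `template` marks `isa` as a member template of a dependent type.
  return t.template isa<Vec>();
}

template <typename Op> bool isVectorFree(const Op &op) {
  auto t = op.getType();
  // The free function template is found by ordinary lookup; no keyword needed.
  return isa<Vec>(t);
}
```

Both helpers return the same answer. The free-function version simply reads like the non-template code, which is why these hunks can drop the `template` keyword outright.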
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 412f99ce042e9..6630aaf7e3a23 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -34,7 +34,7 @@ using namespace mlir;
/// given type is not a 32-bit scalar/vector type.
static Value getScalarOrVectorI32Constant(Type type, int value,
OpBuilder &builder, Location loc) {
- if (auto vectorType = type.dyn_cast<VectorType>()) {
+ if (auto vectorType = dyn_cast<VectorType>(type)) {
if (!vectorType.getElementType().isInteger(32))
return nullptr;
SmallVector<int> values(vectorType.getNumElements(), value);
@@ -55,7 +55,7 @@ static bool isSupportedSourceType(Type originalType) {
if (originalType.isIntOrIndexOrFloat())
return true;
- if (auto vecTy = originalType.dyn_cast<VectorType>()) {
+ if (auto vecTy = dyn_cast<VectorType>(originalType)) {
if (!vecTy.getElementType().isIntOrIndexOrFloat())
return false;
if (vecTy.isScalable())
@@ -133,10 +133,10 @@ struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
return failure();
FloatType floatType;
- if (auto scalarType = copySignOp.getType().dyn_cast<FloatType>()) {
+ if (auto scalarType = dyn_cast<FloatType>(copySignOp.getType())) {
floatType = scalarType;
- } else if (auto vectorType = copySignOp.getType().dyn_cast<VectorType>()) {
- floatType = vectorType.getElementType().cast<FloatType>();
+ } else if (auto vectorType = dyn_cast<VectorType>(copySignOp.getType())) {
+ floatType = cast<FloatType>(vectorType.getElementType());
} else {
return failure();
}
@@ -151,7 +151,7 @@ struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
Value valueMask = rewriter.create<spirv::ConstantOp>(
loc, intType, rewriter.getIntegerAttr(intType, intValue - 1u));
- if (auto vectorType = type.dyn_cast<VectorType>()) {
+ if (auto vectorType = dyn_cast<VectorType>(type)) {
assert(vectorType.getRank() == 1);
int count = vectorType.getNumElements();
intType = VectorType::get(count, intType);
@@ -203,9 +203,9 @@ struct CountLeadingZerosPattern final
// We can only support 32-bit integer types for now.
unsigned bitwidth = 0;
- if (type.isa<IntegerType>())
+ if (isa<IntegerType>(type))
bitwidth = type.getIntOrFloatBitWidth();
- if (auto vectorType = type.dyn_cast<VectorType>())
+ if (auto vectorType = dyn_cast<VectorType>(type))
bitwidth = vectorType.getElementTypeBitWidth();
if (bitwidth != 32)
return failure();
@@ -307,10 +307,10 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
// Get the scalar float type.
FloatType scalarFloatType;
- if (auto scalarType = powfOp.getType().dyn_cast<FloatType>()) {
+ if (auto scalarType = dyn_cast<FloatType>(powfOp.getType())) {
scalarFloatType = scalarType;
- } else if (auto vectorType = powfOp.getType().dyn_cast<VectorType>()) {
- scalarFloatType = vectorType.getElementType().cast<FloatType>();
+ } else if (auto vectorType = dyn_cast<VectorType>(powfOp.getType())) {
+ scalarFloatType = cast<FloatType>(vectorType.getElementType());
} else {
return failure();
}
@@ -318,7 +318,7 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
// Get int type of the same shape as the float type.
Type scalarIntType = rewriter.getIntegerType(32);
Type intType = scalarIntType;
- if (auto vectorType = adaptor.getRhs().getType().dyn_cast<VectorType>()) {
+ if (auto vectorType = dyn_cast<VectorType>(adaptor.getRhs().getType())) {
auto shape = vectorType.getShape();
intType = VectorType::get(shape, scalarIntType);
}
@@ -374,7 +374,7 @@ struct RoundOpPattern final : public OpConversionPattern<math::RoundOp> {
auto zero = spirv::ConstantOp::getZero(ty, loc, rewriter);
auto one = spirv::ConstantOp::getOne(ty, loc, rewriter);
Value half;
- if (VectorType vty = ty.dyn_cast<VectorType>()) {
+ if (VectorType vty = dyn_cast<VectorType>(ty)) {
half = rewriter.create<spirv::ConstantOp>(
loc, vty,
DenseElementsAttr::get(vty,
diff --git a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
index 37aa6cf53c206..2fa43151e2c1c 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
@@ -58,7 +58,7 @@ static Value castAllocFuncResult(ConversionPatternRewriter &rewriter,
Location loc, Value allocatedPtr,
MemRefType memRefType, Type elementPtrType,
LLVMTypeConverter &typeConverter) {
- auto allocatedPtrTy = allocatedPtr.getType().cast<LLVM::LLVMPointerType>();
+ auto allocatedPtrTy = cast<LLVM::LLVMPointerType>(allocatedPtr.getType());
unsigned memrefAddrSpace = *typeConverter.getMemRefAddressSpace(memRefType);
if (allocatedPtrTy.getAddressSpace() != memrefAddrSpace)
allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
@@ -114,10 +114,10 @@ unsigned AllocationOpLLVMLowering::getMemRefEltSizeInBytes(
layout = &analysis->getAbove(op);
}
Type elementType = memRefType.getElementType();
- if (auto memRefElementType = elementType.dyn_cast<MemRefType>())
+ if (auto memRefElementType = dyn_cast<MemRefType>(elementType))
return getTypeConverter()->getMemRefDescriptorSize(memRefElementType,
*layout);
- if (auto memRefElementType = elementType.dyn_cast<UnrankedMemRefType>())
+ if (auto memRefElementType = dyn_cast<UnrankedMemRefType>(elementType))
return getTypeConverter()->getUnrankedMemRefDescriptorSize(
memRefElementType, *layout);
return layout->getTypeSize(elementType);
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index e9fbad30783c3..1a6e5a4e8dbd0 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -184,10 +184,10 @@ struct ReallocOpLoweringBase : public AllocationOpLLVMLowering {
rewriter.setInsertionPointToEnd(currentBlock);
Value src = op.getSource();
- auto srcType = src.getType().dyn_cast<MemRefType>();
+ auto srcType = dyn_cast<MemRefType>(src.getType());
Value srcNumElements = computeNumElements(
srcType, [&]() -> Value { return desc.size(rewriter, loc, 0); });
- auto dstType = op.getType().cast<MemRefType>();
+ auto dstType = cast<MemRefType>(op.getType());
Value dstNumElements = computeNumElements(
dstType, [&]() -> Value { return op.getDynamicResultSize(); });
Value cond = rewriter.create<LLVM::ICmpOp>(
@@ -342,7 +342,7 @@ struct AssumeAlignmentOpLowering
unsigned alignment = op.getAlignment();
auto loc = op.getLoc();
- auto srcMemRefType = op.getMemref().getType().cast<MemRefType>();
+ auto srcMemRefType = cast<MemRefType>(op.getMemref().getType());
Value ptr = getStridedElementPtr(loc, srcMemRefType, memref, /*indices=*/{},
rewriter);
@@ -417,7 +417,7 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type operandType = dimOp.getSource().getType();
- if (operandType.isa<UnrankedMemRefType>()) {
+ if (isa<UnrankedMemRefType>(operandType)) {
FailureOr<Value> extractedSize = extractSizeOfUnrankedMemRef(
operandType, dimOp, adaptor.getOperands(), rewriter);
if (failed(extractedSize))
@@ -425,7 +425,7 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
rewriter.replaceOp(dimOp, {*extractedSize});
return success();
}
- if (operandType.isa<MemRefType>()) {
+ if (isa<MemRefType>(operandType)) {
rewriter.replaceOp(
dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
adaptor.getOperands(), rewriter)});
@@ -441,7 +441,7 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
ConversionPatternRewriter &rewriter) const {
Location loc = dimOp.getLoc();
- auto unrankedMemRefType = operandType.cast<UnrankedMemRefType>();
+ auto unrankedMemRefType = cast<UnrankedMemRefType>(operandType);
auto scalarMemRefType =
MemRefType::get({}, unrankedMemRefType.getElementType());
FailureOr<unsigned> maybeAddressSpace =
@@ -492,10 +492,7 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
return idx;
if (auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
- return constantOp.getValue()
- .cast<IntegerAttr>()
- .getValue()
- .getSExtValue();
+ return cast<IntegerAttr>(constantOp.getValue()).getValue().getSExtValue();
return std::nullopt;
}
@@ -506,7 +503,7 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
Location loc = dimOp.getLoc();
// Take advantage if index is constant.
- MemRefType memRefType = operandType.cast<MemRefType>();
+ MemRefType memRefType = cast<MemRefType>(operandType);
if (std::optional<int64_t> index = getConstantDimIndex(dimOp)) {
int64_t i = *index;
if (i >= 0 && i < memRefType.getRank()) {
@@ -589,7 +586,7 @@ struct GenericAtomicRMWOpLowering
// Compute the loaded value and branch to the loop block.
rewriter.setInsertionPointToEnd(initBlock);
- auto memRefType = atomicOp.getMemref().getType().cast<MemRefType>();
+ auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType());
auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(),
adaptor.getIndices(), rewriter);
Value init = rewriter.create<LLVM::LoadOp>(
@@ -712,7 +709,7 @@ struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
Location loc, Value sizeBytes,
Operation *op) const override {
auto getGlobalOp = cast<memref::GetGlobalOp>(op);
- MemRefType type = getGlobalOp.getResult().getType().cast<MemRefType>();
+ MemRefType type = cast<MemRefType>(getGlobalOp.getResult().getType());
// This is called after a type conversion, which would have failed if this
// call fails.
@@ -823,12 +820,12 @@ struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Type operandType = op.getMemref().getType();
- if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) {
+ if (auto unrankedMemRefType = dyn_cast<UnrankedMemRefType>(operandType)) {
UnrankedMemRefDescriptor desc(adaptor.getMemref());
rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
return success();
}
- if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) {
+ if (auto rankedMemRefType = dyn_cast<MemRefType>(operandType)) {
rewriter.replaceOp(
op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
return success();
@@ -849,17 +846,17 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
// and require source and result type to have the same rank. Therefore,
// perform a sanity check that the underlying structs are the same. Once op
// semantics are relaxed we can revisit.
- if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
+ if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
return success(typeConverter->convertType(srcType) ==
typeConverter->convertType(dstType));
// At least one of the operands is unranked type
- assert(srcType.isa<UnrankedMemRefType>() ||
- dstType.isa<UnrankedMemRefType>());
+ assert(isa<UnrankedMemRefType>(srcType) ||
+ isa<UnrankedMemRefType>(dstType));
// Unranked to unranked cast is disallowed
- return !(srcType.isa<UnrankedMemRefType>() &&
- dstType.isa<UnrankedMemRefType>())
+ return !(isa<UnrankedMemRefType>(srcType) &&
+ isa<UnrankedMemRefType>(dstType))
? success()
: failure();
}
@@ -872,15 +869,15 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
auto loc = memRefCastOp.getLoc();
// For ranked/ranked case, just keep the original descriptor.
- if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
+ if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
return rewriter.replaceOp(memRefCastOp, {adaptor.getSource()});
- if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
+ if (isa<MemRefType>(srcType) && isa<UnrankedMemRefType>(dstType)) {
// Casting ranked to unranked memref type
// Set the rank in the destination from the memref type
// Allocate space on the stack and copy the src memref descriptor
// Set the ptr in the destination to the stack space
- auto srcMemRefType = srcType.cast<MemRefType>();
+ auto srcMemRefType = cast<MemRefType>(srcType);
int64_t rank = srcMemRefType.getRank();
// ptr = AllocaOp sizeof(MemRefDescriptor)
auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
@@ -905,7 +902,7 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);
- } else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) {
+ } else if (isa<UnrankedMemRefType>(srcType) && isa<MemRefType>(dstType)) {
// Casting from unranked type to ranked.
// The operation is assumed to be doing a correct cast. If the destination
// type mismatches the unranked the type, it is undefined behavior.
@@ -942,7 +939,7 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = op.getLoc();
- auto srcType = op.getSource().getType().dyn_cast<MemRefType>();
+ auto srcType = dyn_cast<MemRefType>(op.getSource().getType());
MemRefDescriptor srcDesc(adaptor.getSource());
@@ -984,8 +981,8 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = op.getLoc();
- auto srcType = op.getSource().getType().cast<BaseMemRefType>();
- auto targetType = op.getTarget().getType().cast<BaseMemRefType>();
+ auto srcType = cast<BaseMemRefType>(op.getSource().getType());
+ auto targetType = cast<BaseMemRefType>(op.getTarget().getType());
// First make sure we have an unranked memref descriptor representation.
auto makeUnranked = [&, this](Value ranked, MemRefType type) {
@@ -1012,11 +1009,11 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
auto stackSaveOp =
rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
- auto srcMemRefType = srcType.dyn_cast<MemRefType>();
+ auto srcMemRefType = dyn_cast<MemRefType>(srcType);
Value unrankedSource =
srcMemRefType ? makeUnranked(adaptor.getSource(), srcMemRefType)
: adaptor.getSource();
- auto targetMemRefType = targetType.dyn_cast<MemRefType>();
+ auto targetMemRefType = dyn_cast<MemRefType>(targetType);
Value unrankedTarget =
targetMemRefType ? makeUnranked(adaptor.getTarget(), targetMemRefType)
: adaptor.getTarget();
@@ -1055,8 +1052,8 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
LogicalResult
matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto srcType = op.getSource().getType().cast<BaseMemRefType>();
- auto targetType = op.getTarget().getType().cast<BaseMemRefType>();
+ auto srcType = cast<BaseMemRefType>(op.getSource().getType());
+ auto targetType = cast<BaseMemRefType>(op.getTarget().getType());
auto isStaticShapeAndContiguousRowMajor = [](MemRefType type) {
if (!type.hasStaticShape())
@@ -1077,7 +1074,7 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
};
auto isContiguousMemrefType = [&](BaseMemRefType type) {
- auto memrefType = type.dyn_cast<mlir::MemRefType>();
+ auto memrefType = dyn_cast<mlir::MemRefType>(type);
// We can use memcpy for memrefs if they have an identity layout or are
// contiguous with an arbitrary offset. Ignore empty memrefs, which is a
// special case handled by memrefCopy.
@@ -1105,9 +1102,9 @@ struct MemorySpaceCastOpLowering
Location loc = op.getLoc();
Type resultType = op.getDest().getType();
- if (auto resultTypeR = resultType.dyn_cast<MemRefType>()) {
+ if (auto resultTypeR = dyn_cast<MemRefType>(resultType)) {
auto resultDescType =
- typeConverter->convertType(resultTypeR).cast<LLVM::LLVMStructType>();
+ cast<LLVM::LLVMStructType>(typeConverter->convertType(resultTypeR));
Type newPtrType = resultDescType.getBody()[0];
SmallVector<Value> descVals;
@@ -1122,10 +1119,10 @@ struct MemorySpaceCastOpLowering
rewriter.replaceOp(op, result);
return success();
}
- if (auto resultTypeU = resultType.dyn_cast<UnrankedMemRefType>()) {
+ if (auto resultTypeU = dyn_cast<UnrankedMemRefType>(resultType)) {
// Since the type converter won't be doing this for us, get the address
// space.
- auto sourceType = op.getSource().getType().cast<UnrankedMemRefType>();
+ auto sourceType = cast<UnrankedMemRefType>(op.getSource().getType());
FailureOr<unsigned> maybeSourceAddrSpace =
getTypeConverter()->getMemRefAddressSpace(sourceType);
if (failed(maybeSourceAddrSpace))
@@ -1217,7 +1214,7 @@ static void extractPointersAndOffset(Location loc,
Value *allocatedPtr, Value *alignedPtr,
Value *offset = nullptr) {
Type operandType = originalOperand.getType();
- if (operandType.isa<MemRefType>()) {
+ if (isa<MemRefType>(operandType)) {
MemRefDescriptor desc(convertedOperand);
*allocatedPtr = desc.allocatedPtr(rewriter, loc);
*alignedPtr = desc.alignedPtr(rewriter, loc);
@@ -1228,8 +1225,8 @@ static void extractPointersAndOffset(Location loc,
// These will all cause assert()s on unconvertible types.
unsigned memorySpace = *typeConverter.getMemRefAddressSpace(
- operandType.cast<UnrankedMemRefType>());
- Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
+ cast<UnrankedMemRefType>(operandType));
+ Type elementType = cast<UnrankedMemRefType>(operandType).getElementType();
Type llvmElementType = typeConverter.convertType(elementType);
LLVM::LLVMPointerType elementPtrType =
typeConverter.getPointerType(llvmElementType, memorySpace);
@@ -1273,9 +1270,9 @@ struct MemRefReinterpretCastOpLowering
memref::ReinterpretCastOp castOp,
memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
MemRefType targetMemRefType =
- castOp.getResult().getType().cast<MemRefType>();
- auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
- .dyn_cast_or_null<LLVM::LLVMStructType>();
+ cast<MemRefType>(castOp.getResult().getType());
+ auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
+ typeConverter->convertType(targetMemRefType));
if (!llvmTargetDescriptorTy)
return failure();
@@ -1339,13 +1336,12 @@ struct MemRefReshapeOpLowering
Type srcType, memref::ReshapeOp reshapeOp,
memref::ReshapeOp::Adaptor adaptor,
Value *descriptor) const {
- auto shapeMemRefType = reshapeOp.getShape().getType().cast<MemRefType>();
+ auto shapeMemRefType = cast<MemRefType>(reshapeOp.getShape().getType());
if (shapeMemRefType.hasStaticShape()) {
MemRefType targetMemRefType =
- reshapeOp.getResult().getType().cast<MemRefType>();
- auto llvmTargetDescriptorTy =
- typeConverter->convertType(targetMemRefType)
- .dyn_cast_or_null<LLVM::LLVMStructType>();
+ cast<MemRefType>(reshapeOp.getResult().getType());
+ auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
+ typeConverter->convertType(targetMemRefType));
if (!llvmTargetDescriptorTy)
return failure();
@@ -1426,8 +1422,7 @@ struct MemRefReshapeOpLowering
Value resultRank = shapeDesc.size(rewriter, loc, 0);
// Extract address space and element type.
- auto targetType =
- reshapeOp.getResult().getType().cast<UnrankedMemRefType>();
+ auto targetType = cast<UnrankedMemRefType>(reshapeOp.getResult().getType());
unsigned addressSpace =
*getTypeConverter()->getMemRefAddressSpace(targetType);
Type elementType = targetType.getElementType();
@@ -1695,7 +1690,7 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
// Field 1: Copy the allocated pointer, used for malloc/free.
Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
- auto srcMemRefType = viewOp.getSource().getType().cast<MemRefType>();
+ auto srcMemRefType = cast<MemRefType>(viewOp.getSource().getType());
unsigned sourceMemorySpace =
*getTypeConverter()->getMemRefAddressSpace(srcMemRefType);
Value bitcastPtr;
@@ -1848,7 +1843,7 @@ class ExtractStridedMetadataOpLowering
Location loc = extractStridedMetadataOp.getLoc();
Value source = extractStridedMetadataOp.getSource();
- auto sourceMemRefType = source.getType().cast<MemRefType>();
+ auto sourceMemRefType = cast<MemRefType>(source.getType());
int64_t rank = sourceMemRefType.getRank();
SmallVector<Value> results;
results.reserve(2 + rank * 2);
@@ -1858,7 +1853,7 @@ class ExtractStridedMetadataOpLowering
Value alignedBuffer = sourceMemRef.alignedPtr(rewriter, loc);
MemRefDescriptor dstMemRef = MemRefDescriptor::fromStaticShape(
rewriter, loc, *getTypeConverter(),
- extractStridedMetadataOp.getBaseBuffer().getType().cast<MemRefType>(),
+ cast<MemRefType>(extractStridedMetadataOp.getBaseBuffer().getType()),
baseBuffer, alignedBuffer);
results.push_back((Value)dstMemRef);
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
index 3f92c6f8e55a8..55c23d7c53e22 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
@@ -64,7 +64,7 @@ spirv::mapMemorySpaceToVulkanStorageClass(Attribute memorySpaceAttr) {
// Unknown dialect custom attributes are not supported by default.
// Downstream callers should plug in more specialized ones.
- auto intAttr = memorySpaceAttr.dyn_cast<IntegerAttr>();
+ auto intAttr = dyn_cast<IntegerAttr>(memorySpaceAttr);
if (!intAttr)
return std::nullopt;
unsigned memorySpace = intAttr.getInt();
@@ -118,7 +118,7 @@ spirv::mapMemorySpaceToOpenCLStorageClass(Attribute memorySpaceAttr) {
// Unknown dialect custom attributes are not supported by default.
// Downstream callers should plug in more specialized ones.
- auto intAttr = memorySpaceAttr.dyn_cast<IntegerAttr>();
+ auto intAttr = dyn_cast<IntegerAttr>(memorySpaceAttr);
if (!intAttr)
return std::nullopt;
unsigned memorySpace = intAttr.getInt();
@@ -177,7 +177,7 @@ spirv::MemorySpaceToStorageClassConverter::MemorySpaceToStorageClassConverter(
auto storageAttr =
spirv::StorageClassAttr::get(memRefType.getContext(), *storage);
- if (auto rankedType = memRefType.dyn_cast<MemRefType>()) {
+ if (auto rankedType = dyn_cast<MemRefType>(memRefType)) {
return MemRefType::get(memRefType.getShape(), memRefType.getElementType(),
rankedType.getLayout(), storageAttr);
}
@@ -203,9 +203,9 @@ spirv::MemorySpaceToStorageClassConverter::MemorySpaceToStorageClassConverter(
/// Returns true if the given `type` is considered as legal for SPIR-V
/// conversion.
static bool isLegalType(Type type) {
- if (auto memRefType = type.dyn_cast<BaseMemRefType>()) {
+ if (auto memRefType = dyn_cast<BaseMemRefType>(type)) {
Attribute spaceAttr = memRefType.getMemorySpace();
- return spaceAttr && spaceAttr.isa<spirv::StorageClassAttr>();
+ return spaceAttr && isa<spirv::StorageClassAttr>(spaceAttr);
}
return true;
}
@@ -213,7 +213,7 @@ static bool isLegalType(Type type) {
/// Returns true if the given `attr` is considered as legal for SPIR-V
/// conversion.
static bool isLegalAttr(Attribute attr) {
- if (auto typeAttr = attr.dyn_cast<TypeAttr>())
+ if (auto typeAttr = dyn_cast<TypeAttr>(attr))
return isLegalType(typeAttr.getValue());
return true;
}
@@ -266,7 +266,7 @@ LogicalResult MapMemRefStoragePattern::matchAndRewrite(
llvm::SmallVector<NamedAttribute, 4> newAttrs;
newAttrs.reserve(op->getAttrs().size());
for (auto attr : op->getAttrs()) {
- if (auto typeAttr = attr.getValue().dyn_cast<TypeAttr>()) {
+ if (auto typeAttr = dyn_cast<TypeAttr>(attr.getValue())) {
auto newAttr = getTypeConverter()->convertType(typeAttr.getValue());
newAttrs.emplace_back(attr.getName(), TypeAttr::get(newAttr));
} else {
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 9c74feb597c67..efd541b46d8fe 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -93,11 +93,11 @@ static Value shiftValue(Location loc, Value value, Value offset, Value mask,
/// can be lowered to SPIR-V.
static bool isAllocationSupported(Operation *allocOp, MemRefType type) {
if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) {
- auto sc = type.getMemorySpace().dyn_cast_or_null<spirv::StorageClassAttr>();
+ auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
if (!sc || sc.getValue() != spirv::StorageClass::Workgroup)
return false;
} else if (isa<memref::AllocaOp>(allocOp)) {
- auto sc = type.getMemorySpace().dyn_cast_or_null<spirv::StorageClassAttr>();
+ auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
if (!sc || sc.getValue() != spirv::StorageClass::Function)
return false;
} else {
@@ -110,7 +110,7 @@ static bool isAllocationSupported(Operation *allocOp, MemRefType type) {
return false;
Type elementType = type.getElementType();
- if (auto vecType = elementType.dyn_cast<VectorType>())
+ if (auto vecType = dyn_cast<VectorType>(elementType))
elementType = vecType.getElementType();
return elementType.isIntOrFloat();
}
@@ -119,7 +119,7 @@ static bool isAllocationSupported(Operation *allocOp, MemRefType type) {
/// operations of unsupported integer bitwidths, based on the memref
/// type. Returns std::nullopt on failure.
static std::optional<spirv::Scope> getAtomicOpScope(MemRefType type) {
- auto sc = type.getMemorySpace().dyn_cast_or_null<spirv::StorageClassAttr>();
+ auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
switch (sc.getValue()) {
case spirv::StorageClass::StorageBuffer:
return spirv::Scope::Device;
@@ -324,11 +324,11 @@ LogicalResult
AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
- if (atomicOp.getType().isa<FloatType>())
+ if (isa<FloatType>(atomicOp.getType()))
return rewriter.notifyMatchFailure(atomicOp,
"unimplemented floating-point case");
- auto memrefType = atomicOp.getMemref().getType().cast<MemRefType>();
+ auto memrefType = cast<MemRefType>(atomicOp.getMemref().getType());
std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
if (!scope)
return rewriter.notifyMatchFailure(atomicOp,
@@ -380,7 +380,7 @@ LogicalResult
DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
- MemRefType deallocType = operation.getMemref().getType().cast<MemRefType>();
+ MemRefType deallocType = cast<MemRefType>(operation.getMemref().getType());
if (!isAllocationSupported(operation, deallocType))
return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
rewriter.eraseOp(operation);
@@ -395,7 +395,7 @@ LogicalResult
IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = loadOp.getLoc();
- auto memrefType = loadOp.getMemref().getType().cast<MemRefType>();
+ auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
if (!memrefType.getElementType().isSignlessInteger())
return failure();
@@ -419,18 +419,18 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
Type pointeeType = pointerType.getPointeeType();
Type dstType;
if (typeConverter.allows(spirv::Capability::Kernel)) {
- if (auto arrayType = pointeeType.dyn_cast<spirv::ArrayType>())
+ if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
dstType = arrayType.getElementType();
else
dstType = pointeeType;
} else {
// For Vulkan we need to extract element from wrapping struct and array.
Type structElemType =
- pointeeType.cast<spirv::StructType>().getElementType(0);
- if (auto arrayType = structElemType.dyn_cast<spirv::ArrayType>())
+ cast<spirv::StructType>(pointeeType).getElementType(0);
+ if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
dstType = arrayType.getElementType();
else
- dstType = structElemType.cast<spirv::RuntimeArrayType>().getElementType();
+ dstType = cast<spirv::RuntimeArrayType>(structElemType).getElementType();
}
int dstBits = dstType.getIntOrFloatBitWidth();
assert(dstBits % srcBits == 0);
@@ -509,7 +509,7 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
LogicalResult
LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
- auto memrefType = loadOp.getMemref().getType().cast<MemRefType>();
+ auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
if (memrefType.getElementType().isSignlessInteger())
return failure();
auto loadPtr = spirv::getElementPtr(
@@ -526,7 +526,7 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
LogicalResult
IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
- auto memrefType = storeOp.getMemref().getType().cast<MemRefType>();
+ auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
if (!memrefType.getElementType().isSignlessInteger())
return failure();
@@ -553,18 +553,18 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
Type pointeeType = pointerType.getPointeeType();
Type dstType;
if (typeConverter.allows(spirv::Capability::Kernel)) {
- if (auto arrayType = pointeeType.dyn_cast<spirv::ArrayType>())
+ if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
dstType = arrayType.getElementType();
else
dstType = pointeeType;
} else {
// For Vulkan we need to extract element from wrapping struct and array.
Type structElemType =
- pointeeType.cast<spirv::StructType>().getElementType(0);
- if (auto arrayType = structElemType.dyn_cast<spirv::ArrayType>())
+ cast<spirv::StructType>(pointeeType).getElementType(0);
+ if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
dstType = arrayType.getElementType();
else
- dstType = structElemType.cast<spirv::RuntimeArrayType>().getElementType();
+ dstType = cast<spirv::RuntimeArrayType>(structElemType).getElementType();
}
int dstBits = dstType.getIntOrFloatBitWidth();
@@ -651,21 +651,21 @@ LogicalResult MemorySpaceCastOpPattern::matchAndRewrite(
return rewriter.notifyMatchFailure(
loc, "address space casts require kernel capability");
- auto sourceType = addrCastOp.getSource().getType().dyn_cast<MemRefType>();
+ auto sourceType = dyn_cast<MemRefType>(addrCastOp.getSource().getType());
if (!sourceType)
return rewriter.notifyMatchFailure(
loc, "SPIR-V lowering requires ranked memref types");
- auto resultType = addrCastOp.getResult().getType().cast<MemRefType>();
+ auto resultType = cast<MemRefType>(addrCastOp.getResult().getType());
auto sourceStorageClassAttr =
- sourceType.getMemorySpace().dyn_cast_or_null<spirv::StorageClassAttr>();
+ dyn_cast_or_null<spirv::StorageClassAttr>(sourceType.getMemorySpace());
if (!sourceStorageClassAttr)
return rewriter.notifyMatchFailure(loc, [sourceType](Diagnostic &diag) {
diag << "source address space " << sourceType.getMemorySpace()
<< " must be a SPIR-V storage class";
});
auto resultStorageClassAttr =
- resultType.getMemorySpace().dyn_cast_or_null<spirv::StorageClassAttr>();
+ dyn_cast_or_null<spirv::StorageClassAttr>(resultType.getMemorySpace());
if (!resultStorageClassAttr)
return rewriter.notifyMatchFailure(loc, [resultType](Diagnostic &diag) {
diag << "result address space " << resultType.getMemorySpace()
@@ -709,7 +709,7 @@ LogicalResult MemorySpaceCastOpPattern::matchAndRewrite(
LogicalResult
StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
- auto memrefType = storeOp.getMemref().getType().cast<MemRefType>();
+ auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
if (memrefType.getElementType().isSignlessInteger())
return failure();
auto storePtr = spirv::getElementPtr(
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 4a923fac76c88..3d898e5af19c1 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -28,7 +28,7 @@ using namespace mlir;
/// `gpu.mma.sync` operation.
static Type inferIntrinsicResultType(Type vectorResultType) {
MLIRContext *ctx = vectorResultType.getContext();
- auto a = vectorResultType.cast<LLVM::LLVMArrayType>();
+ auto a = cast<LLVM::LLVMArrayType>(vectorResultType);
auto f16x2Ty = LLVM::getFixedVectorType(Float16Type::get(ctx), 2);
auto i32Ty = IntegerType::get(ctx, 32);
auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
@@ -69,8 +69,8 @@ static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
Type resultType, Value intrinsicResult,
RewriterBase &rewriter) {
MLIRContext *ctx = rewriter.getContext();
- auto structType = intrinsicResultType.dyn_cast<LLVM::LLVMStructType>();
- auto arrayType = resultType.dyn_cast<LLVM::LLVMArrayType>();
+ auto structType = dyn_cast<LLVM::LLVMStructType>(intrinsicResultType);
+ auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType);
Type i32Ty = rewriter.getI32Type();
Type f32Ty = rewriter.getF32Type();
Type f64Ty = rewriter.getF64Type();
@@ -153,7 +153,7 @@ static SmallVector<Value> unpackOperandVector(RewriterBase &rewriter,
Type i8x4Ty = LLVM::getFixedVectorType(i8Ty, 4);
Type i4x8Ty = LLVM::getFixedVectorType(i4Ty, 8);
Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
- auto arrayTy = operand.getType().cast<LLVM::LLVMArrayType>();
+ auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType());
for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
Value toUse = rewriter.create<LLVM::ExtractValueOp>(loc, operand, i);
@@ -172,7 +172,7 @@ static SmallVector<Value> unpackOperandVector(RewriterBase &rewriter,
// For some element types (i32, f32, f64), we need to unpack the inner
// vector/array type as well because the intrinsic expects individual
// scalars to be provided.
- VectorType innerArrayTy = arrayTy.getElementType().dyn_cast<VectorType>();
+ VectorType innerArrayTy = dyn_cast<VectorType>(arrayTy.getElementType());
if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
innerArrayTy.getElementType() == f64Ty ||
innerArrayTy.getElementType() == f32Ty)) {
@@ -207,7 +207,7 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
// of shape (NumRegisters, VectorRegister) where VectorRegister is the
// vector type of the result and always 32 bits long. We bitcast the result
// of the NVVM::LdMatrix to this vector type.
- auto vectorResultType = op->getResultTypes()[0].dyn_cast<VectorType>();
+ auto vectorResultType = dyn_cast<VectorType>(op->getResultTypes()[0]);
if (!vectorResultType) {
return failure();
}
@@ -224,7 +224,7 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
ldMatrixResultType = rewriter.getI32Type();
}
- auto srcMemrefType = op.getSrcMemref().getType().cast<MemRefType>();
+ auto srcMemrefType = cast<MemRefType>(op.getSrcMemref().getType());
Value srcPtr =
getStridedElementPtr(loc, srcMemrefType, adaptor.getSrcMemref(),
adaptor.getIndices(), rewriter);
@@ -307,7 +307,7 @@ struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
// TODO: add an attribute to the op to customize this behavior.
std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
- if (aType.getElementType().isa<IntegerType>())
+ if (isa<IntegerType>(aType.getElementType()))
overflow = NVVM::MMAIntOverflow::satfinite;
SmallVector<Value> matA =
@@ -388,7 +388,7 @@ static void emitCpAsyncOpZfillAsm(Location loc, Value dstPtr, Value srcPtr,
// constant.
auto dstByteConstOp =
dyn_cast<mlir::LLVM::ConstantOp>(dstBytes.getDefiningOp());
- auto dstByteAttr = dstByteConstOp.getValue().dyn_cast<mlir::IntegerAttr>();
+ auto dstByteAttr = dyn_cast<mlir::IntegerAttr>(dstByteConstOp.getValue());
int64_t dstByteVal = dstByteAttr.getValue().getSExtValue();
assert((dstByteVal == 4 || dstByteVal == 8 || dstByteVal == 16) &&
@@ -537,7 +537,7 @@ struct NVGPUMmaSparseSyncLowering
// TODO: add an attribute to the op to customize this behavior.
std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
- if (aType.getElementType().isa<IntegerType>())
+ if (isa<IntegerType>(aType.getElementType()))
overflow = NVVM::MMAIntOverflow::satfinite;
SmallVector<Value> matA =
@@ -585,7 +585,7 @@ struct NVGPUAsyncCopyLowering
matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
- auto dstMemrefType = op.getDst().getType().cast<MemRefType>();
+ auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
Value dstPtr = getStridedElementPtr(loc, dstMemrefType, adaptor.getDst(),
adaptor.getDstIndices(), rewriter);
auto i8Ty = IntegerType::get(op.getContext(), 8);
@@ -599,7 +599,7 @@ struct NVGPUAsyncCopyLowering
if (!getTypeConverter()->useOpaquePointers())
dstPtr = rewriter.create<LLVM::BitcastOp>(loc, dstPointerType, dstPtr);
- auto srcMemrefType = op.getSrc().getType().cast<MemRefType>();
+ auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
FailureOr<unsigned> srcAddressSpace =
getTypeConverter()->getMemRefAddressSpace(srcMemrefType);
if (failed(srcAddressSpace))
diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
index 4f6763b4558d6..34ec00f5f827a 100644
--- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
+++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
@@ -70,7 +70,7 @@ struct RegionLessOpWithVarOperandsConversion
Value originalVariableOperand = curOp.getVariableOperand(idx);
if (!originalVariableOperand)
return failure();
- if (originalVariableOperand.getType().isa<MemRefType>()) {
+ if (isa<MemRefType>(originalVariableOperand.getType())) {
// TODO: Support memref type in variable operands
return rewriter.notifyMatchFailure(curOp,
"memref is not supported yet");
@@ -101,7 +101,7 @@ struct RegionOpWithVarOperandsConversion : public ConvertOpToLLVMPattern<T> {
Value originalVariableOperand = curOp.getVariableOperand(idx);
if (!originalVariableOperand)
return failure();
- if (originalVariableOperand.getType().isa<MemRefType>()) {
+ if (isa<MemRefType>(originalVariableOperand.getType())) {
// TODO: Support memref type in variable operands
return rewriter.notifyMatchFailure(curOp,
"memref is not supported yet");
@@ -143,7 +143,7 @@ struct ReductionOpConversion : public ConvertOpToLLVMPattern<omp::ReductionOp> {
LogicalResult
matchAndRewrite(omp::ReductionOp curOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- if (curOp.getAccumulator().getType().isa<MemRefType>()) {
+ if (isa<MemRefType>(curOp.getAccumulator().getType())) {
// TODO: Support memref type in variable operands
return rewriter.notifyMatchFailure(curOp, "memref is not supported yet");
}
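Every hunk in this part of the patch applies the same mechanical rewrite. `x.isa<T>()`, `x.cast<T>()` and `x.dyn_cast<T>()` become `isa<T>(x)`, `cast<T>(x)` and `dyn_cast<T>(x)`. The rewrite preserves behavior as long as the member spelling only forwards to the free function. The sketch below assumes exactly that forwarding.

It is a minimal, self-contained illustration built on hypothetical toy types: `toy::Ty`, `toy::IntTy`, `toy::VecTy` and `toy::TyStorage`. It is not MLIR's `Type` class, and it does not use LLVM's `CastInfo` machinery. Only the `classof`-based dispatch idea is shared with them.

```cpp
#include <cassert>

namespace toy {

// Hypothetical kind tags; real MLIR types use TypeID and uniqued storage.
enum class Kind { Int, Vec };
struct TyStorage { Kind kind; };

// Nullable, value-semantic handle, loosely modeled on a type handle.
class Ty {
public:
  Ty() = default;
  explicit Ty(const TyStorage *impl) : impl(impl) {}
  explicit operator bool() const { return impl != nullptr; }
  const TyStorage *getImpl() const { return impl; }
  Kind getKind() const { return impl->kind; }

  // Member spelling (the style being migrated away from); defined below
  // in terms of the free functions.
  template <typename U> bool isa() const;
  template <typename U> U cast() const;
  template <typename U> U dyn_cast() const;

private:
  const TyStorage *impl = nullptr;
};

class IntTy : public Ty {
public:
  using Ty::Ty;
  static bool classof(Ty t) { return t.getKind() == Kind::Int; }
};

class VecTy : public Ty {
public:
  using Ty::Ty;
  static bool classof(Ty t) { return t.getKind() == Kind::Vec; }
};

// Free-function spelling (the style being migrated to). Like LLVM's
// versions, these expect a non-null handle.
template <typename U> bool isa(Ty t) {
  assert(t && "isa<> on a null handle");
  return U::classof(t);
}
template <typename U> U cast(Ty t) {
  assert(isa<U>(t) && "cast<> to an incompatible type");
  return U(t.getImpl());
}
template <typename U> U dyn_cast(Ty t) {
  // On mismatch, return a null handle instead of asserting.
  return isa<U>(t) ? U(t.getImpl()) : U();
}

// The member forms just forward, so both spellings agree by construction.
template <typename U> bool Ty::isa() const { return toy::isa<U>(*this); }
template <typename U> U Ty::cast() const { return toy::cast<U>(*this); }
template <typename U> U Ty::dyn_cast() const {
  return toy::dyn_cast<U>(*this);
}

} // namespace toy
```

Usage: with `toy::TyStorage s{toy::Kind::Int}; toy::Ty i(&s);`, the calls `toy::isa<toy::IntTy>(i)` and `i.isa<toy::IntTy>()` return the same value. `toy::dyn_cast<toy::VecTy>(i)` yields a null handle.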
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
index fc0c845b69987..8cd180d450eb0 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
@@ -219,7 +219,7 @@ Block *PatternLowering::generateMatcher(MatcherNode &node, Region &region) {
// If this value corresponds to an operation, record that we are going to use
// its location as part of a fused location.
- bool isOperationValue = val && val.getType().isa<pdl::OperationType>();
+ bool isOperationValue = val && isa<pdl::OperationType>(val.getType());
if (isOperationValue)
locOps.insert(val);
@@ -280,7 +280,7 @@ Value PatternLowering::getValueAt(Block *&currentBlock, Position *pos) {
// The first operation retrieves the representative value of a range.
// This applies only when the parent is a range of values and we were
// requested to use a representative value (e.g., upward traversal).
- if (parentVal.getType().isa<pdl::RangeType>() &&
+ if (isa<pdl::RangeType>(parentVal.getType()) &&
usersPos->useRepresentative())
value = builder.create<pdl_interp::ExtractOp>(loc, parentVal, 0);
else
@@ -327,7 +327,7 @@ Value PatternLowering::getValueAt(Block *&currentBlock, Position *pos) {
break;
}
case Predicates::TypePos: {
- if (parentVal.getType().isa<pdl::AttributeType>())
+ if (isa<pdl::AttributeType>(parentVal.getType()))
value = builder.create<pdl_interp::GetAttributeTypeOp>(loc, parentVal);
else
value = builder.create<pdl_interp::GetValueTypeOp>(loc, parentVal);
@@ -357,11 +357,11 @@ Value PatternLowering::getValueAt(Block *&currentBlock, Position *pos) {
case Predicates::TypeLiteralPos: {
auto *typePos = cast<TypeLiteralPosition>(pos);
Attribute rawTypeAttr = typePos->getValue();
- if (TypeAttr typeAttr = rawTypeAttr.dyn_cast<TypeAttr>())
+ if (TypeAttr typeAttr = dyn_cast<TypeAttr>(rawTypeAttr))
value = builder.create<pdl_interp::CreateTypeOp>(loc, typeAttr);
else
value = builder.create<pdl_interp::CreateTypesOp>(
- loc, rawTypeAttr.cast<ArrayAttr>());
+ loc, cast<ArrayAttr>(rawTypeAttr));
break;
}
default:
@@ -410,7 +410,7 @@ void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
}
case Predicates::TypeQuestion: {
auto *ans = cast<TypeAnswer>(answer);
- if (val.getType().isa<pdl::RangeType>())
+ if (isa<pdl::RangeType>(val.getType()))
builder.create<pdl_interp::CheckTypesOp>(
loc, val, ans->getValue().cast<ArrayAttr>(), success, failure);
else
@@ -554,7 +554,7 @@ void PatternLowering::generate(SwitchNode *switchNode, Block *currentBlock,
OperationNameAnswer>(val, defaultDest, builder,
children);
case Predicates::TypeQuestion:
- if (val.getType().isa<pdl::RangeType>()) {
+ if (isa<pdl::RangeType>(val.getType())) {
return createSwitchOp<pdl_interp::SwitchTypesOp, TypeAnswer>(
val, defaultDest, builder, children);
}
@@ -745,7 +745,7 @@ void PatternLowering::generateRewriter(
// Handle the case where there is a single range representing all of the
// result types.
OperandRange resultTys = operationOp.getTypeValues();
- if (resultTys.size() == 1 && resultTys[0].getType().isa<pdl::RangeType>()) {
+ if (resultTys.size() == 1 && isa<pdl::RangeType>(resultTys[0].getType())) {
Value &type = rewriteValues[resultTys[0]];
if (!type) {
auto results = builder.create<pdl_interp::GetResultsOp>(loc, createdOp);
@@ -762,7 +762,7 @@ void PatternLowering::generateRewriter(
Value &type = rewriteValues[it.value()];
if (type)
continue;
- bool isVariadic = it.value().getType().isa<pdl::RangeType>();
+ bool isVariadic = isa<pdl::RangeType>(it.value().getType());
seenVariableLength |= isVariadic;
// After a variable length result has been seen, we need to use result
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
index 034291440ad2c..7078e238b86df 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
@@ -41,14 +41,14 @@ static bool comparePosDepth(Position *lhs, Position *rhs) {
/// Returns the number of non-range elements within `values`.
static unsigned getNumNonRangeValues(ValueRange values) {
return llvm::count_if(values.getTypes(),
- [](Type type) { return !type.isa<pdl::RangeType>(); });
+ [](Type type) { return !isa<pdl::RangeType>(type); });
}
static void getTreePredicates(std::vector<PositionalPredicate> &predList,
Value val, PredicateBuilder &builder,
DenseMap<Value, Position *> &inputs,
AttributePosition *pos) {
- assert(val.getType().isa<pdl::AttributeType>() && "expected attribute type");
+ assert(isa<pdl::AttributeType>(val.getType()) && "expected attribute type");
pdl::AttributeOp attr = cast<pdl::AttributeOp>(val.getDefiningOp());
predList.emplace_back(pos, builder.getIsNotNull());
@@ -65,7 +65,7 @@ static void getOperandTreePredicates(std::vector<PositionalPredicate> &predList,
DenseMap<Value, Position *> &inputs,
Position *pos) {
Type valueType = val.getType();
- bool isVariadic = valueType.isa<pdl::RangeType>();
+ bool isVariadic = isa<pdl::RangeType>(valueType);
// If this is a typed operand, add a type constraint.
TypeSwitch<Operation *>(val.getDefiningOp())
@@ -111,7 +111,7 @@ getTreePredicates(std::vector<PositionalPredicate> &predList, Value val,
PredicateBuilder &builder,
DenseMap<Value, Position *> &inputs, OperationPosition *pos,
std::optional<unsigned> ignoreOperand = std::nullopt) {
- assert(val.getType().isa<pdl::OperationType>() && "expected operation");
+ assert(isa<pdl::OperationType>(val.getType()) && "expected operation");
pdl::OperationOp op = cast<pdl::OperationOp>(val.getDefiningOp());
OperationPosition *opPos = cast<OperationPosition>(pos);
@@ -148,7 +148,7 @@ getTreePredicates(std::vector<PositionalPredicate> &predList, Value val,
llvm::zip(op.getAttributeValueNames(), op.getAttributeValues())) {
getTreePredicates(
predList, attr, builder, inputs,
- builder.getAttribute(opPos, attrName.cast<StringAttr>().getValue()));
+ builder.getAttribute(opPos, cast<StringAttr>(attrName).getValue()));
}
// Process the operands and results of the operation. For all values up to
@@ -157,7 +157,7 @@ getTreePredicates(std::vector<PositionalPredicate> &predList, Value val,
// concrete indices until runtime. If there is only one variadic operand
// group, we treat it as all of the operands/results of the operation.
/// Operands.
- if (operands.size() == 1 && operands[0].getType().isa<pdl::RangeType>()) {
+ if (operands.size() == 1 && isa<pdl::RangeType>(operands[0].getType())) {
// Ignore the operands if we are performing an upward traversal (in that
// case, they have already been visited).
if (opPos->isRoot() || opPos->isOperandDefiningOp())
@@ -166,7 +166,7 @@ getTreePredicates(std::vector<PositionalPredicate> &predList, Value val,
} else {
bool foundVariableLength = false;
for (const auto &operandIt : llvm::enumerate(operands)) {
- bool isVariadic = operandIt.value().getType().isa<pdl::RangeType>();
+ bool isVariadic = isa<pdl::RangeType>(operandIt.value().getType());
foundVariableLength |= isVariadic;
// Ignore the specified operand, usually because this position was
@@ -182,7 +182,7 @@ getTreePredicates(std::vector<PositionalPredicate> &predList, Value val,
}
}
/// Results.
- if (types.size() == 1 && types[0].getType().isa<pdl::RangeType>()) {
+ if (types.size() == 1 && isa<pdl::RangeType>(types[0].getType())) {
getTreePredicates(predList, types.front(), builder, inputs,
builder.getType(builder.getAllResults(opPos)));
return;
@@ -190,7 +190,7 @@ getTreePredicates(std::vector<PositionalPredicate> &predList, Value val,
bool foundVariableLength = false;
for (auto [idx, typeValue] : llvm::enumerate(types)) {
- bool isVariadic = typeValue.getType().isa<pdl::RangeType>();
+ bool isVariadic = isa<pdl::RangeType>(typeValue.getType());
foundVariableLength |= isVariadic;
auto *resultPos = foundVariableLength
@@ -301,7 +301,7 @@ static void getResultPredicates(pdl::ResultsOp op,
// Ensure that the result isn't null if the result has an index.
auto *parentPos = cast<OperationPosition>(inputs.lookup(op.getParent()));
- bool isVariadic = op.getType().isa<pdl::RangeType>();
+ bool isVariadic = isa<pdl::RangeType>(op.getType());
std::optional<unsigned> index = op.getIndex();
resultPos = builder.getResultGroup(parentPos, index, isVariadic);
if (index)
@@ -458,7 +458,7 @@ static void buildCostGraph(ArrayRef<Value> roots, RootOrderingGraph &graph,
// Special case when we pass all the operands in one range.
// For those, the index is empty.
if (operands.size() == 1 &&
- operands[0].getType().isa<pdl::RangeType>()) {
+ isa<pdl::RangeType>(operands[0].getType())) {
toVisit.emplace(operands[0], entry.value, std::nullopt,
entry.depth + 1);
return;
@@ -514,7 +514,7 @@ static bool useOperandGroup(pdl::OperationOp op, unsigned index) {
OperandRange operands = op.getOperandValues();
assert(index < operands.size() && "operand index out of range");
for (unsigned i = 0; i <= index; ++i)
- if (operands[i].getType().isa<pdl::RangeType>())
+ if (isa<pdl::RangeType>(operands[i].getType()))
return true;
return false;
}
@@ -542,7 +542,7 @@ static void visitUpward(std::vector<PositionalPredicate> &predList,
} else if (useOperandGroup(operationOp, *opIndex.index)) {
// We are querying an operand group.
Type type = operationOp.getOperandValues()[*opIndex.index].getType();
- bool variadic = type.isa<pdl::RangeType>();
+ bool variadic = isa<pdl::RangeType>(type);
operandPos = builder.getOperandGroup(opPos, opIndex.index, variadic);
} else {
// We are querying an individual operand.
@@ -578,7 +578,7 @@ static void visitUpward(std::vector<PositionalPredicate> &predList,
// Traverse up a group of results.
auto *opPos = dyn_cast<OperationPosition>(pos);
assert(opPos && "operations and results must be interleaved");
- bool isVariadic = value.getType().isa<pdl::RangeType>();
+ bool isVariadic = isa<pdl::RangeType>(value.getType());
if (opIndex.index)
pos = builder.getResultGroup(opPos, opIndex.index, isVariadic);
else
diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
index c1a57d37a20e0..f9245ad6be9c7 100644
--- a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
+++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
@@ -441,7 +441,7 @@ static LogicalResult processParallelLoop(
Value iv, lowerBound, upperBound, step;
std::tie(mappingAttribute, iv, lowerBound, upperBound, step) = config;
auto annotation =
- mappingAttribute.dyn_cast<gpu::ParallelLoopDimMappingAttr>();
+ dyn_cast<gpu::ParallelLoopDimMappingAttr>(mappingAttribute);
if (!annotation)
return parallelOp.emitOpError()
<< "expected mapping attribute for lowering to GPU";
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index e91cd0c24171e..50087751053ba 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -51,8 +51,7 @@ static bool matchSimpleReduction(Block &block) {
Value reducedVal = matchReduction({block.getArguments()[1]},
/*redPos=*/0, combinerOps);
- if (!reducedVal || !reducedVal.isa<BlockArgument>() ||
- combinerOps.size() != 1)
+ if (!reducedVal || !isa<BlockArgument>(reducedVal) || combinerOps.size() != 1)
return false;
return isa<OpTy...>(combinerOps[0]) &&
@@ -155,7 +154,7 @@ static const llvm::fltSemantics &fltSemanticsForType(FloatType type) {
/// Returns an attribute with the minimum (if `min` is set) or the maximum value
/// (otherwise) for the given float type.
static Attribute minMaxValueForFloat(Type type, bool min) {
- auto fltType = type.cast<FloatType>();
+ auto fltType = cast<FloatType>(type);
return FloatAttr::get(
type, llvm::APFloat::getLargest(fltSemanticsForType(fltType), min));
}
@@ -164,7 +163,7 @@ static Attribute minMaxValueForFloat(Type type, bool min) {
/// the maximum value (otherwise) for the given integer type, regardless of its
/// signedness semantics (only the width is considered).
static Attribute minMaxValueForSignedInt(Type type, bool min) {
- auto intType = type.cast<IntegerType>();
+ auto intType = cast<IntegerType>(type);
unsigned bitwidth = intType.getWidth();
return IntegerAttr::get(type, min ? llvm::APInt::getSignedMinValue(bitwidth)
: llvm::APInt::getSignedMaxValue(bitwidth));
@@ -174,7 +173,7 @@ static Attribute minMaxValueForSignedInt(Type type, bool min) {
/// the maximum value (otherwise) for the given integer type, regardless of its
/// signedness semantics (only the width is considered).
static Attribute minMaxValueForUnsignedInt(Type type, bool min) {
- auto intType = type.cast<IntegerType>();
+ auto intType = cast<IntegerType>(type);
unsigned bitwidth = intType.getWidth();
return IntegerAttr::get(type, min ? llvm::APInt::getZero(bitwidth)
: llvm::APInt::getAllOnes(bitwidth));
@@ -388,7 +387,7 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
reductionVariables.reserve(parallelOp.getNumReductions());
for (Value init : parallelOp.getInitVals()) {
assert((LLVM::isCompatibleType(init.getType()) ||
- init.getType().isa<LLVM::PointerElementTypeInterface>()) &&
+ isa<LLVM::PointerElementTypeInterface>(init.getType())) &&
"cannot create a reduction variable if the type is not an LLVM "
"pointer element");
Value storage = rewriter.create<LLVM::AllocaOp>(
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
index 08da805fe1295..4e17c966cc429 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
@@ -220,9 +220,8 @@ class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
auto kernelOperands = adaptor.getOperands().take_back(numKernelOperands);
for (const auto &operand : llvm::enumerate(kernelOperands)) {
// Check if the kernel's operand is a ranked memref.
- auto memRefType = launchOp.getKernelOperand(operand.index())
- .getType()
- .dyn_cast<MemRefType>();
+ auto memRefType = dyn_cast<MemRefType>(
+ launchOp.getKernelOperand(operand.index()).getType());
if (!memRefType)
return failure();
@@ -241,7 +240,7 @@ class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
// LLVM dialect global variable.
spirv::GlobalVariableOp spirvGlobal = globalVariableMap[operand.index()];
auto pointeeType =
- spirvGlobal.getType().cast<spirv::PointerType>().getPointeeType();
+ cast<spirv::PointerType>(spirvGlobal.getType()).getPointeeType();
auto dstGlobalType = typeConverter->convertType(pointeeType);
if (!dstGlobalType)
return failure();
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index b93894757daa5..8b4380870a26d 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -37,7 +37,7 @@ using namespace mlir;
static bool isSignedIntegerOrVector(Type type) {
if (type.isSignedInteger())
return true;
- if (auto vecType = type.dyn_cast<VectorType>())
+ if (auto vecType = dyn_cast<VectorType>(type))
return vecType.getElementType().isSignedInteger();
return false;
}
@@ -46,18 +46,18 @@ static bool isSignedIntegerOrVector(Type type) {
static bool isUnsignedIntegerOrVector(Type type) {
if (type.isUnsignedInteger())
return true;
- if (auto vecType = type.dyn_cast<VectorType>())
+ if (auto vecType = dyn_cast<VectorType>(type))
return vecType.getElementType().isUnsignedInteger();
return false;
}
/// Returns the bit width of integer, float or vector of float or integer values
static unsigned getBitWidth(Type type) {
- assert((type.isIntOrFloat() || type.isa<VectorType>()) &&
+ assert((type.isIntOrFloat() || isa<VectorType>(type)) &&
"bitwidth is not supported for this type");
if (type.isIntOrFloat())
return type.getIntOrFloatBitWidth();
- auto vecType = type.dyn_cast<VectorType>();
+ auto vecType = dyn_cast<VectorType>(type);
auto elementType = vecType.getElementType();
assert(elementType.isIntOrFloat() &&
"only integers and floats have a bitwidth");
@@ -66,29 +66,29 @@ static unsigned getBitWidth(Type type) {
/// Returns the bit width of LLVMType integer or vector.
static unsigned getLLVMTypeBitWidth(Type type) {
- return (LLVM::isCompatibleVectorType(type) ? LLVM::getVectorElementType(type)
- : type)
- .cast<IntegerType>()
+ return cast<IntegerType>((LLVM::isCompatibleVectorType(type)
+ ? LLVM::getVectorElementType(type)
+ : type))
.getWidth();
}
/// Creates `IntegerAttribute` with all bits set for given type
static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) {
- if (auto vecType = type.dyn_cast<VectorType>()) {
- auto integerType = vecType.getElementType().cast<IntegerType>();
+ if (auto vecType = dyn_cast<VectorType>(type)) {
+ auto integerType = cast<IntegerType>(vecType.getElementType());
return builder.getIntegerAttr(integerType, -1);
}
- auto integerType = type.cast<IntegerType>();
+ auto integerType = cast<IntegerType>(type);
return builder.getIntegerAttr(integerType, -1);
}
/// Creates `llvm.mlir.constant` with all bits set for the given type.
static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType,
PatternRewriter &rewriter) {
- if (srcType.isa<VectorType>()) {
+ if (isa<VectorType>(srcType)) {
return rewriter.create<LLVM::ConstantOp>(
loc, dstType,
- SplatElementsAttr::get(srcType.cast<ShapedType>(),
+ SplatElementsAttr::get(cast<ShapedType>(srcType),
minusOneIntegerAttribute(srcType, rewriter)));
}
return rewriter.create<LLVM::ConstantOp>(
@@ -98,14 +98,14 @@ static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType,
/// Creates `llvm.mlir.constant` with a floating-point scalar or vector value.
static Value createFPConstant(Location loc, Type srcType, Type dstType,
PatternRewriter &rewriter, double value) {
- if (auto vecType = srcType.dyn_cast<VectorType>()) {
- auto floatType = vecType.getElementType().cast<FloatType>();
+ if (auto vecType = dyn_cast<VectorType>(srcType)) {
+ auto floatType = cast<FloatType>(vecType.getElementType());
return rewriter.create<LLVM::ConstantOp>(
loc, dstType,
SplatElementsAttr::get(vecType,
rewriter.getFloatAttr(floatType, value)));
}
- auto floatType = srcType.cast<FloatType>();
+ auto floatType = cast<FloatType>(srcType);
return rewriter.create<LLVM::ConstantOp>(
loc, dstType, rewriter.getFloatAttr(floatType, value));
}
@@ -157,7 +157,7 @@ static Value broadcast(Location loc, Value toBroadcast, unsigned numElements,
static Value optionallyBroadcast(Location loc, Value value, Type srcType,
LLVMTypeConverter &typeConverter,
ConversionPatternRewriter &rewriter) {
- if (auto vectorType = srcType.dyn_cast<VectorType>()) {
+ if (auto vectorType = dyn_cast<VectorType>(srcType)) {
unsigned numElements = vectorType.getNumElements();
return broadcast(loc, value, numElements, typeConverter, rewriter);
}
@@ -251,7 +251,7 @@ static std::optional<Type> convertArrayType(spirv::ArrayType type,
TypeConverter &converter) {
unsigned stride = type.getArrayStride();
Type elementType = type.getElementType();
- auto sizeInBytes = elementType.cast<spirv::SPIRVType>().getSizeInBytes();
+ auto sizeInBytes = cast<spirv::SPIRVType>(elementType).getSizeInBytes();
if (stride != 0 && (!sizeInBytes || *sizeInBytes != stride))
return std::nullopt;
@@ -319,10 +319,9 @@ class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> {
indices.insert(indices.begin(), zero);
rewriter.replaceOpWithNewOp<LLVM::GEPOp>(
op, dstType,
- typeConverter.convertType(op.getBasePtr()
- .getType()
- .cast<spirv::PointerType>()
- .getPointeeType()),
+ typeConverter.convertType(
+ cast<spirv::PointerType>(op.getBasePtr().getType())
+ .getPointeeType()),
adaptor.getBasePtr(), indices);
return success();
}
@@ -397,7 +396,7 @@ class ConstantScalarAndVectorPattern
matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto srcType = constOp.getType();
- if (!srcType.isa<VectorType>() && !srcType.isIntOrFloat())
+ if (!isa<VectorType>(srcType) && !srcType.isIntOrFloat())
return failure();
auto dstType = typeConverter.convertType(srcType);
@@ -413,15 +412,15 @@ class ConstantScalarAndVectorPattern
isUnsignedIntegerOrVector(srcType)) {
auto signlessType = rewriter.getIntegerType(getBitWidth(srcType));
- if (srcType.isa<VectorType>()) {
- auto dstElementsAttr = constOp.getValue().cast<DenseIntElementsAttr>();
+ if (isa<VectorType>(srcType)) {
+ auto dstElementsAttr = cast<DenseIntElementsAttr>(constOp.getValue());
rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
constOp, dstType,
dstElementsAttr.mapValues(
signlessType, [&](const APInt &value) { return value; }));
return success();
}
- auto srcAttr = constOp.getValue().cast<IntegerAttr>();
+ auto srcAttr = cast<IntegerAttr>(constOp.getValue());
auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue());
rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr);
return success();
@@ -454,17 +453,17 @@ class BitFieldSExtractPattern
// Create a constant that holds the size of the `Base`.
IntegerType integerType;
- if (auto vecType = srcType.dyn_cast<VectorType>())
- integerType = vecType.getElementType().cast<IntegerType>();
+ if (auto vecType = dyn_cast<VectorType>(srcType))
+ integerType = cast<IntegerType>(vecType.getElementType());
else
- integerType = srcType.cast<IntegerType>();
+ integerType = cast<IntegerType>(srcType);
auto baseSize = rewriter.getIntegerAttr(integerType, getBitWidth(srcType));
Value size =
- srcType.isa<VectorType>()
+ isa<VectorType>(srcType)
? rewriter.create<LLVM::ConstantOp>(
loc, dstType,
- SplatElementsAttr::get(srcType.cast<ShapedType>(), baseSize))
+ SplatElementsAttr::get(cast<ShapedType>(srcType), baseSize))
: rewriter.create<LLVM::ConstantOp>(loc, dstType, baseSize);
// Shift `Base` left by [sizeof(Base) - (Count + Offset)], so that the bit
@@ -573,9 +572,9 @@ class CompositeExtractPattern
return failure();
Type containerType = op.getComposite().getType();
- if (containerType.isa<VectorType>()) {
+ if (isa<VectorType>(containerType)) {
Location loc = op.getLoc();
- IntegerAttr value = op.getIndices()[0].cast<IntegerAttr>();
+ IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
Value index = createI32ConstantOf(loc, rewriter, value.getInt());
rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
op, dstType, adaptor.getComposite(), index);
@@ -605,9 +604,9 @@ class CompositeInsertPattern
return failure();
Type containerType = op.getComposite().getType();
- if (containerType.isa<VectorType>()) {
+ if (isa<VectorType>(containerType)) {
Location loc = op.getLoc();
- IntegerAttr value = op.getIndices()[0].cast<IntegerAttr>();
+ IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
Value index = createI32ConstantOf(loc, rewriter, value.getInt());
rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
op, dstType, adaptor.getComposite(), adaptor.getObject(), index);
@@ -732,7 +731,7 @@ class GlobalVariablePattern
if (op.getInitializer())
return failure();
- auto srcType = op.getType().cast<spirv::PointerType>();
+ auto srcType = cast<spirv::PointerType>(op.getType());
auto dstType = typeConverter.convertType(srcType.getPointeeType());
if (!dstType)
return failure();
@@ -946,12 +945,12 @@ class NotPattern : public SPIRVToLLVMConversion<SPIRVOp> {
Location loc = notOp.getLoc();
IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter);
- auto mask = srcType.template isa<VectorType>()
- ? rewriter.create<LLVM::ConstantOp>(
- loc, dstType,
- SplatElementsAttr::get(
- srcType.template cast<VectorType>(), minusOne))
- : rewriter.create<LLVM::ConstantOp>(loc, dstType, minusOne);
+ auto mask =
+ isa<VectorType>(srcType)
+ ? rewriter.create<LLVM::ConstantOp>(
+ loc, dstType,
+ SplatElementsAttr::get(cast<VectorType>(srcType), minusOne))
+ : rewriter.create<LLVM::ConstantOp>(loc, dstType, minusOne);
rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType,
notOp.getOperand(), mask);
return success();
@@ -1262,9 +1261,9 @@ class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> {
ConversionPatternRewriter &rewriter) const override {
auto srcType = varOp.getType();
// Initialization is supported for scalars and vectors only.
- auto pointerTo = srcType.cast<spirv::PointerType>().getPointeeType();
+ auto pointerTo = cast<spirv::PointerType>(srcType).getPointeeType();
auto init = varOp.getInitializer();
- if (init && !pointerTo.isIntOrFloat() && !pointerTo.isa<VectorType>())
+ if (init && !pointerTo.isIntOrFloat() && !isa<VectorType>(pointerTo))
return failure();
auto dstType = typeConverter.convertType(srcType);
@@ -1303,7 +1302,7 @@ class BitcastConversionPattern
return failure();
if (typeConverter.useOpaquePointers() &&
- dstType.isa<LLVM::LLVMPointerType>()) {
+ isa<LLVM::LLVMPointerType>(dstType)) {
rewriter.replaceOp(bitcastOp, adaptor.getOperand());
return success();
}
@@ -1416,8 +1415,8 @@ class VectorShufflePattern
auto components = adaptor.getComponents();
auto vector1 = adaptor.getVector1();
auto vector2 = adaptor.getVector2();
- int vector1Size = vector1.getType().cast<VectorType>().getNumElements();
- int vector2Size = vector2.getType().cast<VectorType>().getNumElements();
+ int vector1Size = cast<VectorType>(vector1.getType()).getNumElements();
+ int vector2Size = cast<VectorType>(vector2.getType()).getNumElements();
if (vector1Size == vector2Size) {
rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(
op, vector1, vector2,
@@ -1426,16 +1425,16 @@ class VectorShufflePattern
}
auto dstType = typeConverter.convertType(op.getType());
- auto scalarType = dstType.cast<VectorType>().getElementType();
+ auto scalarType = cast<VectorType>(dstType).getElementType();
auto componentsArray = components.getValue();
auto *context = rewriter.getContext();
auto llvmI32Type = IntegerType::get(context, 32);
Value targetOp = rewriter.create<LLVM::UndefOp>(loc, dstType);
for (unsigned i = 0; i < componentsArray.size(); i++) {
- if (!componentsArray[i].isa<IntegerAttr>())
+ if (!isa<IntegerAttr>(componentsArray[i]))
return op.emitError("unable to support non-constant component");
- int indexVal = componentsArray[i].cast<IntegerAttr>().getInt();
+ int indexVal = cast<IntegerAttr>(componentsArray[i]).getInt();
if (indexVal == -1)
continue;
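Some hunks also drop the `template` disambiguator. Examples are the `NotPattern` change in SPIRVToLLVM.cpp and the `BinaryOpConversion` change in ShapeToStandard.cpp. Inside a class template, a call such as `srcType.template isa<VectorType>()` needs that keyword, because the receiver's type depends on a template parameter. The free-function call `isa<VectorType>(srcType)` does not need it, since ordinary name lookup already finds a function template named `isa`.

The sketch below is self-contained. `toy::Handle`, `toy::VectorTag`, `toy::ScalarTag` and `toy::FakeOp` are hypothetical stand-ins for MLIR types and ops.

```cpp
#include <cassert>

namespace toy {

// Hypothetical type tags standing in for type classes such as VectorType.
struct VectorTag { static constexpr int id = 1; };
struct ScalarTag { static constexpr int id = 2; };

// Minimal handle that offers both spellings of `isa`.
struct Handle {
  int id;
  template <typename Tag> bool isa() const { return id == Tag::id; }
};
template <typename Tag> bool isa(const Handle &h) { return h.id == Tag::id; }

// A helper templated over an op type, in the spirit of NotPattern<SPIRVOp>.
template <typename OpTy> bool resultIsVectorMember(const OpTy &op) {
  auto ty = op.getType(); // the type of `ty` depends on OpTy
  // Member spelling: `template` tells the parser that `isa` names a member
  // template, so the following `<` opens a template argument list.
  return ty.template isa<VectorTag>();
}

template <typename OpTy> bool resultIsVectorFree(const OpTy &op) {
  auto ty = op.getType();
  // Free spelling: unqualified lookup finds the function template toy::isa,
  // so no disambiguator is needed.
  return isa<VectorTag>(ty);
}

// Hypothetical op exposing getType(), as MLIR ops do.
struct FakeOp {
  Handle ty;
  Handle getType() const { return ty; }
};

} // namespace toy
```

The free form reads the same inside and outside templates. That removes one source of noise from templated conversion patterns such as the ones touched here.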
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index 16cbfca3f3e2a..a3e51aeed0735 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -59,7 +59,7 @@ class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
matchAndRewrite(SrcOpTy op, typename SrcOpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// For now, only error-free types are supported by this lowering.
- if (op.getType().template isa<SizeType>())
+ if (isa<SizeType>(op.getType()))
return failure();
rewriter.replaceOpWithNewOp<DstOpTy>(op, adaptor.getLhs(),
@@ -127,7 +127,7 @@ LogicalResult BroadcastOpConverter::matchAndRewrite(
ConversionPatternRewriter &rewriter) const {
// For now, this lowering is only defined on `tensor<?xindex>` operands, not
// on shapes.
- if (op.getType().isa<ShapeType>())
+ if (isa<ShapeType>(op.getType()))
return failure();
auto loc = op.getLoc();
@@ -189,7 +189,7 @@ LogicalResult ConstShapeOpConverter::matchAndRewrite(
// For now, this lowering supports only extent tensors, not `shape.shape`
// types.
- if (op.getType().isa<ShapeType>())
+ if (isa<ShapeType>(op.getType()))
return failure();
auto loc = op.getLoc();
@@ -242,7 +242,7 @@ LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
// For now, this lowering is only defined on `tensor<?xindex>` operands, not
// on shapes.
if (!llvm::all_of(op.getShapes(),
- [](Value v) { return !v.getType().isa<ShapeType>(); }))
+ [](Value v) { return !isa<ShapeType>(v.getType()); }))
return failure();
auto loc = op.getLoc();
@@ -363,13 +363,13 @@ LogicalResult GetExtentOpConverter::matchAndRewrite(
GetExtentOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// For now, only error-free types are supported by this lowering.
- if (op.getType().isa<SizeType>())
+ if (isa<SizeType>(op.getType()))
return failure();
// Derive shape extent directly from shape origin if possible. This
// circumvents the necessity to materialize the shape in memory.
if (auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>()) {
- if (shapeOfOp.getArg().getType().isa<ShapedType>()) {
+ if (isa<ShapedType>(shapeOfOp.getArg().getType())) {
rewriter.replaceOpWithNewOp<tensor::DimOp>(op, shapeOfOp.getArg(),
adaptor.getDim());
return success();
@@ -397,7 +397,7 @@ LogicalResult
RankOpConverter::matchAndRewrite(shape::RankOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// For now, this lowering supports only error-free types.
- if (op.getType().isa<SizeType>())
+ if (isa<SizeType>(op.getType()))
return failure();
rewriter.replaceOpWithNewOp<tensor::DimOp>(op, adaptor.getShape(), 0);
@@ -420,7 +420,7 @@ LogicalResult
ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// For now, this lowering is only defined on `tensor<?xindex>` operands.
- if (op.getShape().getType().isa<ShapeType>())
+ if (isa<ShapeType>(op.getShape().getType()))
return failure();
auto loc = op.getLoc();
@@ -499,7 +499,7 @@ LogicalResult
ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
if (!llvm::all_of(op.getShapes(),
- [](Value v) { return !v.getType().isa<ShapeType>(); }))
+ [](Value v) { return !isa<ShapeType>(v.getType()); }))
return failure();
Type i1Ty = rewriter.getI1Type();
@@ -570,18 +570,18 @@ LogicalResult ShapeOfOpConversion::matchAndRewrite(
ConversionPatternRewriter &rewriter) const {
// For now, only error-free types are supported by this lowering.
- if (op.getType().isa<ShapeType>())
+ if (isa<ShapeType>(op.getType()))
return failure();
// For ranked tensor arguments, lower to `tensor.from_elements`.
auto loc = op.getLoc();
Value tensor = adaptor.getArg();
Type tensorTy = tensor.getType();
- if (tensorTy.isa<RankedTensorType>()) {
+ if (isa<RankedTensorType>(tensorTy)) {
// Build values for individual extents.
SmallVector<Value, 8> extentValues;
- RankedTensorType rankedTensorTy = tensorTy.cast<RankedTensorType>();
+ RankedTensorType rankedTensorTy = cast<RankedTensorType>(tensorTy);
int64_t rank = rankedTensorTy.getRank();
for (int64_t i = 0; i < rank; i++) {
if (rankedTensorTy.isDynamicDim(i)) {
@@ -634,7 +634,7 @@ LogicalResult SplitAtOpConversion::matchAndRewrite(
// Error conditions are not implemented, only lower if all operands and
// results are extent tensors.
if (llvm::any_of(ValueRange{op.getOperand(), op.getHead(), op.getTail()},
- [](Value v) { return v.getType().isa<ShapeType>(); }))
+ [](Value v) { return isa<ShapeType>(v.getType()); }))
return failure();
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
@@ -667,7 +667,7 @@ class ToExtentTensorOpConversion
LogicalResult
matchAndRewrite(ToExtentTensorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- if (!adaptor.getInput().getType().isa<RankedTensorType>())
+ if (!isa<RankedTensorType>(adaptor.getInput().getType()))
return rewriter.notifyMatchFailure(op, "input needs to be a tensor");
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
diff --git a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp
index ed13ab3fd8c0d..373952c5201b2 100644
--- a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp
+++ b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp
@@ -44,7 +44,7 @@ class TensorExtractPattern final
LogicalResult
matchAndRewrite(tensor::ExtractOp extractOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto tensorType = extractOp.getTensor().getType().cast<RankedTensorType>();
+ auto tensorType = cast<RankedTensorType>(extractOp.getTensor().getType());
if (!tensorType.hasStaticShape())
return rewriter.notifyMatchFailure(extractOp, "non-static tensor");
diff --git a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
index 1790e3d0212c4..c025fb9e1367d 100644
--- a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
+++ b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
@@ -34,14 +34,14 @@ class ConstOpConverter : public OpRewritePattern<tosa::ConstOp> {
};
Type matchContainerType(Type element, Type container) {
- if (auto shapedTy = container.dyn_cast<ShapedType>())
+ if (auto shapedTy = dyn_cast<ShapedType>(container))
return shapedTy.clone(element);
return element;
}
TypedAttr getConstantAttr(Type type, int64_t value, PatternRewriter &rewriter) {
- if (auto shapedTy = type.dyn_cast<ShapedType>()) {
+ if (auto shapedTy = dyn_cast<ShapedType>(type)) {
Type eTy = shapedTy.getElementType();
APInt valueInt(eTy.getIntOrFloatBitWidth(), value);
return DenseIntElementsAttr::get(shapedTy, valueInt);
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 3f970befa38dc..6aa075158e7fd 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -36,7 +36,7 @@ static arith::ConstantOp
createConstFromIntAttribute(Operation *op, const std::string &attrName,
Type requiredAttrType, OpBuilder &rewriter) {
auto castedN = static_cast<T>(
- op->getAttr(attrName).cast<IntegerAttr>().getValue().getSExtValue());
+ cast<IntegerAttr>(op->getAttr(attrName)).getValue().getSExtValue());
return rewriter.create<arith::ConstantOp>(
op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
}
@@ -47,13 +47,13 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
PatternRewriter &rewriter) {
Location loc = op->getLoc();
auto elementTy =
- op->getOperand(0).getType().cast<ShapedType>().getElementType();
+ cast<ShapedType>(op->getOperand(0).getType()).getElementType();
// tosa::AbsOp
- if (isa<tosa::AbsOp>(op) && elementTy.isa<FloatType>())
+ if (isa<tosa::AbsOp>(op) && isa<FloatType>(elementTy))
return rewriter.create<math::AbsFOp>(loc, resultTypes, args);
- if (isa<tosa::AbsOp>(op) && elementTy.isa<IntegerType>()) {
+ if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
auto zero = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(elementTy));
auto cmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
@@ -63,21 +63,21 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
}
// tosa::AddOp
- if (isa<tosa::AddOp>(op) && elementTy.isa<FloatType>())
+ if (isa<tosa::AddOp>(op) && isa<FloatType>(elementTy))
return rewriter.create<arith::AddFOp>(loc, resultTypes, args);
- if (isa<tosa::AddOp>(op) && elementTy.isa<IntegerType>())
+ if (isa<tosa::AddOp>(op) && isa<IntegerType>(elementTy))
return rewriter.create<arith::AddIOp>(loc, resultTypes, args);
// tosa::SubOp
- if (isa<tosa::SubOp>(op) && elementTy.isa<FloatType>())
+ if (isa<tosa::SubOp>(op) && isa<FloatType>(elementTy))
return rewriter.create<arith::SubFOp>(loc, resultTypes, args);
- if (isa<tosa::SubOp>(op) && elementTy.isa<IntegerType>())
+ if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
return rewriter.create<arith::SubIOp>(loc, resultTypes, args);
// tosa::MulOp
- if (isa<tosa::MulOp>(op) && elementTy.isa<FloatType>()) {
+ if (isa<tosa::MulOp>(op) && isa<FloatType>(elementTy)) {
if (dyn_cast<tosa::MulOp>(op).getShift() != 0) {
(void)rewriter.notifyMatchFailure(op,
"Cannot have shift value for float");
@@ -87,21 +87,21 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
}
// tosa::DivOp
- if (isa<tosa::DivOp>(op) && elementTy.isa<IntegerType>())
+ if (isa<tosa::DivOp>(op) && isa<IntegerType>(elementTy))
return rewriter.create<arith::DivSIOp>(loc, resultTypes, args);
// tosa::ReciprocalOp
- if (isa<tosa::ReciprocalOp>(op) && elementTy.isa<FloatType>()) {
+ if (isa<tosa::ReciprocalOp>(op) && isa<FloatType>(elementTy)) {
auto one =
rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
return rewriter.create<arith::DivFOp>(loc, resultTypes, one, args[0]);
}
- if (isa<tosa::MulOp>(op) && elementTy.isa<IntegerType>()) {
+ if (isa<tosa::MulOp>(op) && isa<IntegerType>(elementTy)) {
Value a = args[0];
Value b = args[1];
auto shift =
- op->getAttr("shift").cast<IntegerAttr>().getValue().getSExtValue();
+ cast<IntegerAttr>(op->getAttr("shift")).getValue().getSExtValue();
if (shift > 0) {
auto shiftConst =
rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
@@ -134,17 +134,17 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
}
// tosa::NegateOp
- if (isa<tosa::NegateOp>(op) && elementTy.isa<FloatType>())
+ if (isa<tosa::NegateOp>(op) && isa<FloatType>(elementTy))
return rewriter.create<arith::NegFOp>(loc, resultTypes, args);
- if (isa<tosa::NegateOp>(op) && elementTy.isa<IntegerType>() &&
+ if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy) &&
!cast<tosa::NegateOp>(op).getQuantizationInfo()) {
auto constant =
rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
return rewriter.create<arith::SubIOp>(loc, resultTypes, constant, args[0]);
}
- if (isa<tosa::NegateOp>(op) && elementTy.isa<IntegerType>() &&
+ if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy) &&
cast<tosa::NegateOp>(op).getQuantizationInfo()) {
auto quantizationInfo = cast<tosa::NegateOp>(op).getQuantizationInfo();
int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
@@ -190,15 +190,15 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
}
// tosa::BitwiseAndOp
- if (isa<tosa::BitwiseAndOp>(op) && elementTy.isa<IntegerType>())
+ if (isa<tosa::BitwiseAndOp>(op) && isa<IntegerType>(elementTy))
return rewriter.create<arith::AndIOp>(loc, resultTypes, args);
// tosa::BitwiseOrOp
- if (isa<tosa::BitwiseOrOp>(op) && elementTy.isa<IntegerType>())
+ if (isa<tosa::BitwiseOrOp>(op) && isa<IntegerType>(elementTy))
return rewriter.create<arith::OrIOp>(loc, resultTypes, args);
// tosa::BitwiseNotOp
- if (isa<tosa::BitwiseNotOp>(op) && elementTy.isa<IntegerType>()) {
+ if (isa<tosa::BitwiseNotOp>(op) && isa<IntegerType>(elementTy)) {
auto allOnesAttr = rewriter.getIntegerAttr(
elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth()));
auto allOnes = rewriter.create<arith::ConstantOp>(loc, allOnesAttr);
@@ -206,21 +206,21 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
}
// tosa::BitwiseXOrOp
- if (isa<tosa::BitwiseXorOp>(op) && elementTy.isa<IntegerType>())
+ if (isa<tosa::BitwiseXorOp>(op) && isa<IntegerType>(elementTy))
return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);
// tosa::LogicalLeftShiftOp
- if (isa<tosa::LogicalLeftShiftOp>(op) && elementTy.isa<IntegerType>())
+ if (isa<tosa::LogicalLeftShiftOp>(op) && isa<IntegerType>(elementTy))
return rewriter.create<arith::ShLIOp>(loc, resultTypes, args);
// tosa::LogicalRightShiftOp
- if (isa<tosa::LogicalRightShiftOp>(op) && elementTy.isa<IntegerType>())
+ if (isa<tosa::LogicalRightShiftOp>(op) && isa<IntegerType>(elementTy))
return rewriter.create<arith::ShRUIOp>(loc, resultTypes, args);
// tosa::ArithmeticRightShiftOp
- if (isa<tosa::ArithmeticRightShiftOp>(op) && elementTy.isa<IntegerType>()) {
+ if (isa<tosa::ArithmeticRightShiftOp>(op) && isa<IntegerType>(elementTy)) {
auto result = rewriter.create<arith::ShRSIOp>(loc, resultTypes, args);
- auto round = op->getAttr("round").cast<BoolAttr>().getValue();
+ auto round = cast<BoolAttr>(op->getAttr("round")).getValue();
if (!round) {
return result;
}
@@ -256,7 +256,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
}
// tosa::ClzOp
- if (isa<tosa::ClzOp>(op) && elementTy.isa<IntegerType>()) {
+ if (isa<tosa::ClzOp>(op) && isa<IntegerType>(elementTy)) {
return rewriter.create<math::CountLeadingZerosOp>(loc, elementTy, args[0]);
}
@@ -280,27 +280,27 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);
// tosa::PowOp
- if (isa<tosa::PowOp>(op) && elementTy.isa<FloatType>())
+ if (isa<tosa::PowOp>(op) && isa<FloatType>(elementTy))
return rewriter.create<mlir::math::PowFOp>(loc, resultTypes, args);
// tosa::RsqrtOp
- if (isa<tosa::RsqrtOp>(op) && elementTy.isa<FloatType>())
+ if (isa<tosa::RsqrtOp>(op) && isa<FloatType>(elementTy))
return rewriter.create<mlir::math::RsqrtOp>(loc, resultTypes, args);
// tosa::LogOp
- if (isa<tosa::LogOp>(op) && elementTy.isa<FloatType>())
+ if (isa<tosa::LogOp>(op) && isa<FloatType>(elementTy))
return rewriter.create<mlir::math::LogOp>(loc, resultTypes, args);
// tosa::ExpOp
- if (isa<tosa::ExpOp>(op) && elementTy.isa<FloatType>())
+ if (isa<tosa::ExpOp>(op) && isa<FloatType>(elementTy))
return rewriter.create<mlir::math::ExpOp>(loc, resultTypes, args);
// tosa::TanhOp
- if (isa<tosa::TanhOp>(op) && elementTy.isa<FloatType>())
+ if (isa<tosa::TanhOp>(op) && isa<FloatType>(elementTy))
return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);
// tosa::GreaterOp
- if (isa<tosa::GreaterOp>(op) && elementTy.isa<FloatType>())
+ if (isa<tosa::GreaterOp>(op) && isa<FloatType>(elementTy))
return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT,
args[0], args[1]);
@@ -309,7 +309,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
args[0], args[1]);
// tosa::GreaterEqualOp
- if (isa<tosa::GreaterEqualOp>(op) && elementTy.isa<FloatType>())
+ if (isa<tosa::GreaterEqualOp>(op) && isa<FloatType>(elementTy))
return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
args[0], args[1]);
@@ -318,7 +318,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
args[0], args[1]);
// tosa::EqualOp
- if (isa<tosa::EqualOp>(op) && elementTy.isa<FloatType>())
+ if (isa<tosa::EqualOp>(op) && isa<FloatType>(elementTy))
return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ,
args[0], args[1]);
@@ -328,13 +328,13 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
// tosa::SelectOp
if (isa<tosa::SelectOp>(op)) {
- elementTy = op->getOperand(1).getType().cast<ShapedType>().getElementType();
- if (elementTy.isa<FloatType>() || elementTy.isa<IntegerType>())
+ elementTy = cast<ShapedType>(op->getOperand(1).getType()).getElementType();
+ if (isa<FloatType>(elementTy) || isa<IntegerType>(elementTy))
return rewriter.create<arith::SelectOp>(loc, args[0], args[1], args[2]);
}
// tosa::MaximumOp
- if (isa<tosa::MaximumOp>(op) && elementTy.isa<FloatType>()) {
+ if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
return rewriter.create<arith::MaxFOp>(loc, args[0], args[1]);
}
@@ -345,7 +345,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
}
// tosa::MinimumOp
- if (isa<tosa::MinimumOp>(op) && elementTy.isa<FloatType>()) {
+ if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
return rewriter.create<arith::MinFOp>(loc, args[0], args[1]);
}
@@ -356,21 +356,21 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
}
// tosa::CeilOp
- if (isa<tosa::CeilOp>(op) && elementTy.isa<FloatType>())
+ if (isa<tosa::CeilOp>(op) && isa<FloatType>(elementTy))
return rewriter.create<math::CeilOp>(loc, resultTypes, args);
// tosa::FloorOp
- if (isa<tosa::FloorOp>(op) && elementTy.isa<FloatType>())
+ if (isa<tosa::FloorOp>(op) && isa<FloatType>(elementTy))
return rewriter.create<math::FloorOp>(loc, resultTypes, args);
// tosa::ClampOp
- if (isa<tosa::ClampOp>(op) && elementTy.isa<FloatType>()) {
+ if (isa<tosa::ClampOp>(op) && isa<FloatType>(elementTy)) {
bool losesInfo = false;
- APFloat minApf = op->getAttr("min_fp").cast<FloatAttr>().getValue();
- APFloat maxApf = op->getAttr("max_fp").cast<FloatAttr>().getValue();
- minApf.convert(elementTy.cast<FloatType>().getFloatSemantics(),
+ APFloat minApf = cast<FloatAttr>(op->getAttr("min_fp")).getValue();
+ APFloat maxApf = cast<FloatAttr>(op->getAttr("max_fp")).getValue();
+ minApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
APFloat::rmNearestTiesToEven, &losesInfo);
- maxApf.convert(elementTy.cast<FloatType>().getFloatSemantics(),
+ maxApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
APFloat::rmNearestTiesToEven, &losesInfo);
auto min = rewriter.create<arith::ConstantOp>(
loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
@@ -379,12 +379,12 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
return clampFloatHelper(loc, args[0], min, max, rewriter);
}
- if (isa<tosa::ClampOp>(op) && elementTy.isa<IntegerType>()) {
- auto intTy = elementTy.cast<IntegerType>();
+ if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
+ auto intTy = cast<IntegerType>(elementTy);
int32_t min = static_cast<int32_t>(
- op->getAttr("min_int").cast<IntegerAttr>().getValue().getSExtValue());
+ cast<IntegerAttr>(op->getAttr("min_int")).getValue().getSExtValue());
int32_t max = static_cast<int32_t>(
- op->getAttr("max_int").cast<IntegerAttr>().getValue().getSExtValue());
+ cast<IntegerAttr>(op->getAttr("max_int")).getValue().getSExtValue());
if (intTy.isUnsignedInteger()) {
min = std::max<int32_t>(min, 0);
@@ -408,7 +408,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
}
// tosa::SigmoidOp
- if (isa<tosa::SigmoidOp>(op) && elementTy.isa<FloatType>()) {
+ if (isa<tosa::SigmoidOp>(op) && isa<FloatType>(elementTy)) {
auto one =
rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
auto negate = rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]);
@@ -427,11 +427,11 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
if (srcTy == dstTy)
return args.front();
- if (srcTy.isa<FloatType>() && dstTy.isa<FloatType>() && bitExtend)
+ if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && bitExtend)
return rewriter.create<arith::ExtFOp>(loc, resultTypes, args,
std::nullopt);
- if (srcTy.isa<FloatType>() && dstTy.isa<FloatType>() && !bitExtend)
+ if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && !bitExtend)
return rewriter.create<arith::TruncFOp>(loc, resultTypes, args,
std::nullopt);
@@ -440,13 +440,13 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
return rewriter.create<arith::UIToFPOp>(loc, resultTypes, args,
std::nullopt);
- if (srcTy.isInteger(1) && dstTy.isa<IntegerType>() && bitExtend)
+ if (srcTy.isInteger(1) && isa<IntegerType>(dstTy) && bitExtend)
return rewriter.create<arith::ExtUIOp>(loc, resultTypes, args,
std::nullopt);
// Unsigned integers need an unrealized cast so that they can be passed
// to UIToFP.
- if (srcTy.isUnsignedInteger() && dstTy.isa<FloatType>()) {
+ if (srcTy.isUnsignedInteger() && isa<FloatType>(dstTy)) {
auto unrealizedCast =
rewriter
.create<UnrealizedConversionCastOp>(
@@ -463,7 +463,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
std::nullopt);
// Casting to boolean, floats need to only be checked as not-equal to zero.
- if (srcTy.isa<FloatType>() && dstTy.isInteger(1)) {
+ if (isa<FloatType>(srcTy) && dstTy.isInteger(1)) {
Value zero = rewriter.create<arith::ConstantOp>(
loc, rewriter.getFloatAttr(srcTy, 0.0));
return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE,
@@ -490,18 +490,18 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
// Casting to boolean, integers need to only be checked as not-equal to
// zero.
- if (srcTy.isa<IntegerType>() && dstTy.isInteger(1)) {
+ if (isa<IntegerType>(srcTy) && dstTy.isInteger(1)) {
Value zero = rewriter.create<arith::ConstantIntOp>(
loc, 0, srcTy.getIntOrFloatBitWidth());
return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
args.front(), zero);
}
- if (srcTy.isa<IntegerType>() && dstTy.isa<IntegerType>() && bitExtend)
+ if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && bitExtend)
return rewriter.create<arith::ExtSIOp>(loc, resultTypes, args,
std::nullopt);
- if (srcTy.isa<IntegerType>() && dstTy.isa<IntegerType>() && !bitExtend) {
+ if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && !bitExtend) {
return rewriter.create<arith::TruncIOp>(loc, dstTy, args[0]);
}
}
@@ -520,7 +520,7 @@ elementwiseMatchAndRewriteHelper(Operation *operation,
"All TOSA elementwise ops should only return a single result.");
auto results = operation->getResults();
- auto resultTy = operation->getResult(0).getType().dyn_cast<ShapedType>();
+ auto resultTy = dyn_cast<ShapedType>(operation->getResult(0).getType());
if (!resultTy)
return rewriter.notifyMatchFailure(operation,
@@ -538,10 +538,10 @@ elementwiseMatchAndRewriteHelper(Operation *operation,
SmallVector<Value> emptyTensors;
SmallVector<Value> dynDims;
- dynDims.resize(results.front().getType().cast<ShapedType>().getRank());
+ dynDims.resize(cast<ShapedType>(results.front().getType()).getRank());
for (auto arg : operation->getOperands()) {
- auto operandTy = arg.getType().cast<ShapedType>();
+ auto operandTy = cast<ShapedType>(arg.getType());
for (int i = 0; i < operandTy.getRank(); i++) {
if (operandTy.isDynamicDim(i) && !dynDims[i])
dynDims[i] = rewriter.create<tensor::DimOp>(loc, arg, i);
@@ -551,7 +551,7 @@ elementwiseMatchAndRewriteHelper(Operation *operation,
SmallVector<Value> filteredDims = condenseValues(dynDims);
for (auto result : results) {
- auto resultTy = result.getType().template cast<ShapedType>();
+ auto resultTy = cast<ShapedType>(result.getType());
emptyTensors.push_back(rewriter.create<tensor::EmptyOp>(
loc, resultTy.getShape(), resultTy.getElementType(), filteredDims));
opResultTypes.push_back(result.getType());
@@ -566,7 +566,7 @@ elementwiseMatchAndRewriteHelper(Operation *operation,
// Input indexing maps may be broadcasted.
for (Value operand : operation->getOperands()) {
- ShapedType type = operand.getType().cast<ShapedType>();
+ ShapedType type = cast<ShapedType>(operand.getType());
if (type.getShape() == resultTy.getShape()) {
operands.push_back(operand);
@@ -627,33 +627,33 @@ elementwiseMatchAndRewriteHelper(Operation *operation,
// attribute type varies depending on the element type required.
static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy,
PatternRewriter &rewriter) {
- if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<FloatType>())
+ if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy))
return rewriter.getFloatAttr(elementTy, 0.0);
- if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<IntegerType>())
+ if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy))
return rewriter.getIntegerAttr(elementTy, 0);
- if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<FloatType>())
+ if (isa<tosa::ReduceProdOp>(op) && isa<FloatType>(elementTy))
return rewriter.getFloatAttr(elementTy, 1.0);
- if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<IntegerType>())
+ if (isa<tosa::ReduceProdOp>(op) && isa<IntegerType>(elementTy))
return rewriter.getIntegerAttr(elementTy, 1);
- if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<FloatType>())
+ if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy))
return rewriter.getFloatAttr(
elementTy, APFloat::getLargest(
- elementTy.cast<FloatType>().getFloatSemantics(), false));
+ cast<FloatType>(elementTy).getFloatSemantics(), false));
- if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<IntegerType>())
+ if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy))
return rewriter.getIntegerAttr(
elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));
- if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<FloatType>())
+ if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy))
return rewriter.getFloatAttr(
elementTy, APFloat::getLargest(
- elementTy.cast<FloatType>().getFloatSemantics(), true));
+ cast<FloatType>(elementTy).getFloatSemantics(), true));
- if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<IntegerType>())
+ if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy))
return rewriter.getIntegerAttr(
elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));
@@ -663,12 +663,12 @@ static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy,
if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
return rewriter.getIntegerAttr(elementTy, APInt::getZero(1));
- if (isa<tosa::ArgMaxOp>(op) && elementTy.isa<FloatType>())
+ if (isa<tosa::ArgMaxOp>(op) && isa<FloatType>(elementTy))
return rewriter.getFloatAttr(
elementTy, APFloat::getLargest(
- elementTy.cast<FloatType>().getFloatSemantics(), true));
+ cast<FloatType>(elementTy).getFloatSemantics(), true));
- if (isa<tosa::ArgMaxOp>(op) && elementTy.isa<IntegerType>())
+ if (isa<tosa::ArgMaxOp>(op) && isa<IntegerType>(elementTy))
return rewriter.getIntegerAttr(
elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));
@@ -682,37 +682,37 @@ static Value createLinalgBodyCalculationForReduceOp(Operation *op,
Type elementTy,
PatternRewriter &rewriter) {
Location loc = op->getLoc();
- if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<FloatType>()) {
+ if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy)) {
return rewriter.create<arith::AddFOp>(loc, args);
}
- if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<IntegerType>()) {
+ if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy)) {
return rewriter.create<arith::AddIOp>(loc, args);
}
- if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<FloatType>()) {
+ if (isa<tosa::ReduceProdOp>(op) && isa<FloatType>(elementTy)) {
return rewriter.create<arith::MulFOp>(loc, args);
}
- if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<IntegerType>()) {
+ if (isa<tosa::ReduceProdOp>(op) && isa<IntegerType>(elementTy)) {
return rewriter.create<arith::MulIOp>(loc, args);
}
- if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<FloatType>()) {
+ if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy)) {
return rewriter.create<arith::MinFOp>(loc, args[0], args[1]);
}
- if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<IntegerType>()) {
+ if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
auto predicate = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, args[0], args[1]);
return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
}
- if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<FloatType>()) {
+ if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
return rewriter.create<arith::MaxFOp>(loc, args[0], args[1]);
}
- if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<IntegerType>()) {
+ if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
auto predicate = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sgt, args[0], args[1]);
return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
@@ -733,8 +733,8 @@ static Value createLinalgBodyCalculationForReduceOp(Operation *op,
static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
PatternRewriter &rewriter) {
auto loc = op->getLoc();
- auto inputTy = op->getOperand(0).getType().template cast<ShapedType>();
- auto resultTy = op->getResult(0).getType().template cast<ShapedType>();
+ auto inputTy = cast<ShapedType>(op->getOperand(0).getType());
+ auto resultTy = cast<ShapedType>(op->getResult(0).getType());
auto elementTy = resultTy.getElementType();
Value input = op->getOperand(0);
@@ -799,7 +799,7 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
SmallVector<ReassociationExprs, 4> reassociationMap;
uint64_t expandInputRank =
- linalgOp.getResults()[0].getType().cast<ShapedType>().getRank();
+ cast<ShapedType>(linalgOp.getResults()[0].getType()).getRank();
reassociationMap.resize(expandInputRank);
for (uint64_t i = 0; i < expandInputRank; i++) {
@@ -848,14 +848,14 @@ class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
auto loc = op.getLoc();
auto input = op->getOperand(0);
- auto resultTy = op.getType().cast<ShapedType>();
+ auto resultTy = cast<ShapedType>(op.getType());
SmallVector<Value> dynDims;
- dynDims.resize(op->getResult(0).getType().cast<ShapedType>().getRank());
+ dynDims.resize(cast<ShapedType>(op->getResult(0).getType()).getRank());
SmallVector<AffineExpr, 2> inputExprs;
inputExprs.resize(resultTy.getRank());
- auto operandTy = input.getType().cast<ShapedType>();
+ auto operandTy = cast<ShapedType>(input.getType());
for (const auto &permutation : llvm::enumerate(perms.getValues<APInt>())) {
auto index = permutation.index();
auto value = permutation.value().getZExtValue();
@@ -893,8 +893,8 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
PatternRewriter &rewriter) const final {
auto loc = op.getLoc();
auto input = op.getInput();
- auto inputTy = op.getInput().getType().cast<ShapedType>();
- auto outputTy = op.getOutput().getType().cast<ShapedType>();
+ auto inputTy = cast<ShapedType>(op.getInput().getType());
+ auto outputTy = cast<ShapedType>(op.getOutput().getType());
unsigned rank = inputTy.getRank();
// This is an illegal configuration. terminate and log an error
@@ -1036,7 +1036,7 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
// Saturate to the output size.
IntegerType outIntType =
- blockArgs.back().getType().cast<IntegerType>();
+ cast<IntegerType>(blockArgs.back().getType());
unsigned outBitWidth = outIntType.getWidth();
int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
@@ -1089,8 +1089,8 @@ class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
Location loc = op.getLoc();
ImplicitLocOpBuilder builder(loc, rewriter);
auto input = op.getInput();
- auto inputTy = input.getType().cast<RankedTensorType>();
- auto resultTy = op.getType().cast<RankedTensorType>();
+ auto inputTy = cast<RankedTensorType>(input.getType());
+ auto resultTy = cast<RankedTensorType>(op.getType());
const bool isBilinear = op.getMode() == "BILINEAR";
auto inputH = inputTy.getDimSize(1);
@@ -1186,8 +1186,8 @@ class MaterializeResizeBroadcast : public OpRewritePattern<tosa::ResizeOp> {
Location loc = op.getLoc();
ImplicitLocOpBuilder builder(loc, rewriter);
auto input = op.getInput();
- auto inputTy = input.getType().dyn_cast<RankedTensorType>();
- auto resultTy = op.getType().dyn_cast<RankedTensorType>();
+ auto inputTy = dyn_cast<RankedTensorType>(input.getType());
+ auto resultTy = dyn_cast<RankedTensorType>(op.getType());
if (!inputTy || !resultTy)
return rewriter.notifyMatchFailure(op,
@@ -1282,8 +1282,8 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
Location loc = op.getLoc();
ImplicitLocOpBuilder b(loc, rewriter);
auto input = op.getInput();
- auto inputTy = input.getType().cast<ShapedType>();
- auto resultTy = op.getType().cast<ShapedType>();
+ auto inputTy = cast<ShapedType>(input.getType());
+ auto resultTy = cast<ShapedType>(op.getType());
auto resultETy = resultTy.getElementType();
auto imageH = inputTy.getShape()[1];
@@ -1573,8 +1573,8 @@ class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
PatternRewriter &rewriter) const final {
auto loc = op.getLoc();
Value input = op.getInput();
- auto inputTy = input.getType().template cast<ShapedType>();
- auto resultTy = op.getType().template cast<ShapedType>();
+ auto inputTy = cast<ShapedType>(input.getType());
+ auto resultTy = cast<ShapedType>(op.getType());
auto axis = op.getAxis();
SmallVector<Value> dynDims;
@@ -1635,9 +1635,9 @@ struct TileConverter : public OpConversionPattern<tosa::TileOp> {
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto input = op.getInput1();
- auto inputTy = input.getType().cast<ShapedType>();
+ auto inputTy = cast<ShapedType>(input.getType());
auto inputShape = inputTy.getShape();
- auto resultTy = op.getType().cast<ShapedType>();
+ auto resultTy = cast<ShapedType>(op.getType());
auto elementTy = inputTy.getElementType();
int64_t rank = inputTy.getRank();
@@ -1710,14 +1710,14 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
PatternRewriter &rewriter) const final {
auto loc = argmaxOp.getLoc();
Value input = argmaxOp.getInput();
- auto inputTy = input.getType().cast<ShapedType>();
- auto resultTy = argmaxOp.getOutput().getType().cast<ShapedType>();
+ auto inputTy = cast<ShapedType>(input.getType());
+ auto resultTy = cast<ShapedType>(argmaxOp.getOutput().getType());
auto inElementTy = inputTy.getElementType();
auto outElementTy = resultTy.getElementType();
int axis = argmaxOp.getAxis();
auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);
- if (!outElementTy.isa<IntegerType>())
+ if (!isa<IntegerType>(outElementTy))
return rewriter.notifyMatchFailure(
argmaxOp,
"tosa.arg_max to linalg.* requires integer-like result type");
@@ -1792,10 +1792,10 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
rewriter.create<linalg::IndexOp>(loc, axis));
Value predicate;
- if (inElementTy.isa<FloatType>()) {
+ if (isa<FloatType>(inElementTy)) {
predicate = rewriter.create<arith::CmpFOp>(
nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
- } else if (inElementTy.isa<IntegerType>()) {
+ } else if (isa<IntegerType>(inElementTy)) {
predicate = rewriter.create<arith::CmpIOp>(
nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue);
} else {
@@ -1830,8 +1830,8 @@ class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
auto indices = adaptor.getOperands()[1];
auto valuesTy =
- op.getValues().getType().dyn_cast_or_null<RankedTensorType>();
- auto resultTy = op.getType().cast<ShapedType>();
+ dyn_cast_or_null<RankedTensorType>(op.getValues().getType());
+ auto resultTy = cast<ShapedType>(op.getType());
if (!valuesTy)
return rewriter.notifyMatchFailure(op, "unranked tensors not supported");
@@ -1904,9 +1904,9 @@ class TableConverter : public OpRewritePattern<tosa::TableOp> {
auto loc = op.getLoc();
Value input = op.getInput();
Value table = op.getTable();
- auto inputTy = input.getType().cast<ShapedType>();
- auto tableTy = table.getType().cast<ShapedType>();
- auto resultTy = op.getType().cast<ShapedType>();
+ auto inputTy = cast<ShapedType>(input.getType());
+ auto tableTy = cast<ShapedType>(table.getType());
+ auto resultTy = cast<ShapedType>(op.getType());
auto inputElementTy = inputTy.getElementType();
auto tableElementTy = tableTy.getElementType();
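The hunks above replace member calls such as `outElementTy.isa<IntegerType>()` with free calls such as `isa<IntegerType>(outElementTy)`. As a rough, self-contained illustration of how free functions of this kind can dispatch on a static `classof()` predicate (LLVM-style RTTI), here is a minimal sketch. Every name in it (`sketch::Type`, `TypeStorage`, `Kind`, and the toy `isa`/`cast`/`dyn_cast`/`dyn_cast_or_null`) is hypothetical; it is not how LLVM's `Casting.h` or MLIR actually implement these templates.

```cpp
#include <cassert>

namespace sketch {

enum class Kind { Integer, Float };

// Hypothetical storage object; a real MLIR type points at context-owned storage.
struct TypeStorage {
  Kind kind;
  unsigned width;
};

// Value-semantic handle that may be null, loosely modeled on mlir::Type.
class Type {
public:
  Type() = default;
  explicit Type(const TypeStorage *storage) : impl(storage) {}
  explicit operator bool() const { return impl != nullptr; }
  Kind getKind() const { return impl->kind; }
  const TypeStorage *getImpl() const { return impl; }

protected:
  const TypeStorage *impl = nullptr;
};

class IntegerType : public Type {
public:
  using Type::Type;
  static bool classof(Type t) { return t.getKind() == Kind::Integer; }
  unsigned getWidth() const { return impl->width; }
};

class FloatType : public Type {
public:
  using Type::Type;
  static bool classof(Type t) { return t.getKind() == Kind::Float; }
};

// Free functions in the spirit of llvm::isa/cast/dyn_cast: each one defers
// to the target class's static classof() predicate.
template <typename To> bool isa(Type t) {
  assert(t && "isa<> used on a null type");
  return To::classof(t);
}

template <typename To> To cast(Type t) {
  assert(isa<To>(t) && "cast<> argument of incompatible type");
  return To(t.getImpl());
}

// Returns a null handle instead of asserting when the kind does not match.
template <typename To> To dyn_cast(Type t) {
  return isa<To>(t) ? To(t.getImpl()) : To();
}

// Unlike dyn_cast, also tolerates a null input and returns a null handle.
template <typename To> To dyn_cast_or_null(Type t) {
  return t ? dyn_cast<To>(t) : To();
}

} // namespace sketch
```

The `dyn_cast_or_null` variant corresponds to spots like the `GatherConverter` hunk above, where a null result from `dyn_cast_or_null<RankedTensorType>` is reported as a match failure rather than tripping an assertion.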
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 61413b2a6d6ac..c55a548dae0c5 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -36,7 +36,7 @@ static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
if (llvm::all_of(pad, [](int64_t p) { return p == 0; }))
return input;
- ShapedType inputTy = input.getType().cast<ShapedType>();
+ ShapedType inputTy = cast<ShapedType>(input.getType());
Type inputETy = inputTy.getElementType();
auto inputShape = inputTy.getShape();
@@ -67,7 +67,7 @@ static mlir::Value
linalgIntBroadcastExtSIAdd(PatternRewriter &rewriter, Location loc, Value bias,
Value conv, Value result,
ArrayRef<AffineMap> indexingMaps) {
- ShapedType resultTy = conv.getType().cast<ShapedType>();
+ ShapedType resultTy = cast<ShapedType>(conv.getType());
return rewriter
.create<linalg::GenericOp>(
loc, resultTy, ValueRange({bias, conv}), result, indexingMaps,
@@ -125,7 +125,7 @@ static SmallVector<Value> inferDynamicDimsForConv(
ArrayRef<int64_t> padAttr, ArrayRef<int64_t> strideAttr,
ArrayRef<int64_t> dilationAttr, ArrayRef<int64_t> inputSizeDims,
ArrayRef<int64_t> kernelSizeDims, OpBuilder &rewriter) {
- ShapedType inputTy = input.getType().cast<ShapedType>();
+ ShapedType inputTy = cast<ShapedType>(input.getType());
Type inputETy = inputTy.getElementType();
int64_t inputRank = inputTy.getRank();
@@ -187,11 +187,10 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
Value weight = op->getOperand(1);
Value bias = op->getOperand(2);
- ShapedType inputTy = input.getType().template cast<ShapedType>();
- ShapedType weightTy = weight.getType().template cast<ShapedType>();
- ShapedType biasTy = bias.getType().template cast<ShapedType>();
- ShapedType resultTy =
- op->getResult(0).getType().template cast<ShapedType>();
+ ShapedType inputTy = cast<ShapedType>(input.getType());
+ ShapedType weightTy = cast<ShapedType>(weight.getType());
+ ShapedType biasTy = cast<ShapedType>(bias.getType());
+ ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());
Type inputETy = inputTy.getElementType();
Type resultETy = resultTy.getElementType();
@@ -353,18 +352,18 @@ class DepthwiseConvConverter
Value weight = op->getOperand(1);
Value bias = op->getOperand(2);
- ShapedType inputTy = input.getType().cast<ShapedType>();
- ShapedType weightTy = weight.getType().cast<ShapedType>();
- ShapedType biasTy = bias.getType().cast<ShapedType>();
- ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();
+ ShapedType inputTy = cast<ShapedType>(input.getType());
+ ShapedType weightTy = cast<ShapedType>(weight.getType());
+ ShapedType biasTy = cast<ShapedType>(bias.getType());
+ ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());
int64_t resultRank = resultTy.getRank();
Type inputETy = inputTy.getElementType();
Type resultETy = resultTy.getElementType();
- auto padAttr = op->getAttr("pad").cast<DenseI64ArrayAttr>();
- auto strideTosaAttr = op->getAttr("stride").cast<DenseI64ArrayAttr>();
- auto dilationTosaAttr = op->getAttr("dilation").cast<DenseI64ArrayAttr>();
+ auto padAttr = cast<DenseI64ArrayAttr>(op->getAttr("pad"));
+ auto strideTosaAttr = cast<DenseI64ArrayAttr>(op->getAttr("stride"));
+ auto dilationTosaAttr = cast<DenseI64ArrayAttr>(op->getAttr("dilation"));
if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
return rewriter.notifyMatchFailure(
@@ -382,7 +381,7 @@ class DepthwiseConvConverter
IntegerAttr kZp;
if (isQuantized) {
auto quantizationInfo =
- op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
+ cast<tosa::ConvOpQuantizationAttr>(op->getAttr("quantization_info"));
iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp());
kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp());
}
@@ -394,7 +393,7 @@ class DepthwiseConvConverter
TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
if (isQuantized) {
auto quantizationInfo =
- op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
+ cast<tosa::ConvOpQuantizationAttr>(op->getAttr("quantization_info"));
int64_t iZp = quantizationInfo.getInputZp();
int64_t intMin =
@@ -505,14 +504,14 @@ class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
ConversionPatternRewriter &rewriter) const final {
Location loc = op.getLoc();
- auto outputTy = op.getType().cast<ShapedType>();
+ auto outputTy = cast<ShapedType>(op.getType());
auto outputElementTy = outputTy.getElementType();
- auto firstOperandTy = op->getOperand(0).getType().cast<ShapedType>();
- auto secondOperandTy = op->getOperand(1).getType().cast<ShapedType>();
+ auto firstOperandTy = cast<ShapedType>(op->getOperand(0).getType());
+ auto secondOperandTy = cast<ShapedType>(op->getOperand(1).getType());
SmallVector<Value> dynDims;
- dynDims.resize(op->getResult(0).getType().cast<ShapedType>().getRank());
+ dynDims.resize(cast<ShapedType>(op->getResult(0).getType()).getRank());
if (!firstOperandTy.hasRank() || firstOperandTy.isDynamicDim(0)) {
dynDims[0] = rewriter.create<tensor::DimOp>(loc, op->getOperand(0), 0);
@@ -564,20 +563,20 @@ class FullyConnectedConverter
matchAndRewrite(tosa::FullyConnectedOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
Location loc = op.getLoc();
- auto outputTy = op.getType().cast<ShapedType>();
+ auto outputTy = cast<ShapedType>(op.getType());
auto input = op.getInput();
- auto inputTy = input.getType().cast<ShapedType>();
+ auto inputTy = cast<ShapedType>(input.getType());
auto bias = op.getBias();
auto weight = op.getWeight();
- auto weightTy = weight.getType().cast<ShapedType>();
+ auto weightTy = cast<ShapedType>(weight.getType());
auto weightShape = weightTy.getShape();
auto outputETy = outputTy.getElementType();
SmallVector<Value> dynDims;
- dynDims.resize(op->getResult(0).getType().cast<ShapedType>().getRank());
+ dynDims.resize(cast<ShapedType>(op->getResult(0).getType()).getRank());
if (!inputTy.hasRank() || inputTy.isDynamicDim(0)) {
dynDims[0] = rewriter.create<tensor::DimOp>(loc, input, 0);
@@ -676,9 +675,9 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
PatternRewriter &rewriter) const final {
Location loc = op.getLoc();
Value input = op.getInput();
- ShapedType inputTy = input.getType().cast<ShapedType>();
+ ShapedType inputTy = cast<ShapedType>(input.getType());
- ShapedType resultTy = op.getType().template cast<ShapedType>();
+ ShapedType resultTy = cast<ShapedType>(op.getType());
Type resultETy = inputTy.getElementType();
auto dynamicDimsOr =
@@ -691,11 +690,10 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
TypedAttr initialAttr;
if (resultETy.isF32())
initialAttr = rewriter.getFloatAttr(
- resultETy,
- APFloat::getLargest(resultETy.cast<FloatType>().getFloatSemantics(),
- true));
+ resultETy, APFloat::getLargest(
+ cast<FloatType>(resultETy).getFloatSemantics(), true));
- if (resultETy.isa<IntegerType>())
+ if (isa<IntegerType>(resultETy))
initialAttr = rewriter.getIntegerAttr(
resultETy,
APInt::getSignedMinValue(resultETy.getIntOrFloatBitWidth()));
@@ -747,14 +745,14 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
PatternRewriter &rewriter) const final {
Location loc = op.getLoc();
Value input = op.getInput();
- ShapedType inputTy = input.getType().cast<ShapedType>();
+ ShapedType inputTy = cast<ShapedType>(input.getType());
Type inElementTy = inputTy.getElementType();
- ShapedType resultTy = op.getType().template cast<ShapedType>();
- Type resultETy = op.getType().cast<ShapedType>().getElementType();
+ ShapedType resultTy = cast<ShapedType>(op.getType());
+ Type resultETy = cast<ShapedType>(op.getType()).getElementType();
Type accETy =
- inElementTy.isa<IntegerType>() ? rewriter.getI32Type() : inElementTy;
+ isa<IntegerType>(inElementTy) ? rewriter.getI32Type() : inElementTy;
ShapedType accTy = resultTy.clone(accETy);
auto dynamicDimsOr =
@@ -872,7 +870,7 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
// a div however for quantized values input normalization had
// to be applied.
Value poolVal = args[0];
- if (accETy.isa<FloatType>()) {
+ if (isa<FloatType>(accETy)) {
auto countF = rewriter.create<arith::SIToFPOp>(loc, accETy, count);
poolVal = rewriter.create<arith::DivFOp>(loc, poolVal, countF)
->getResult(0);
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index 5e46fab0a1ecd..5e6e971f1f1ce 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -134,8 +134,8 @@ class ReshapeConverterCollapse : public OpConversionPattern<tosa::ReshapeOp> {
LogicalResult
matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
- ShapedType operandTy = adaptor.getInput1().getType().cast<ShapedType>();
- ShapedType resultTy = reshape.getType().template cast<ShapedType>();
+ ShapedType operandTy = cast<ShapedType>(adaptor.getInput1().getType());
+ ShapedType resultTy = cast<ShapedType>(reshape.getType());
bool isDynamic = !operandTy.hasStaticShape();
if (isDynamic && resultTy.getRank() != 1) {
@@ -172,8 +172,8 @@ class ReshapeConverterExpand : public OpConversionPattern<tosa::ReshapeOp> {
LogicalResult
matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
- ShapedType operandTy = adaptor.getInput1().getType().cast<ShapedType>();
- ShapedType resultTy = reshape.getType().template cast<ShapedType>();
+ ShapedType operandTy = cast<ShapedType>(adaptor.getInput1().getType());
+ ShapedType resultTy = cast<ShapedType>(reshape.getType());
bool isDynamic = !operandTy.hasStaticShape();
if (isDynamic && operandTy.getRank() != 1) {
@@ -211,8 +211,8 @@ class ReshapeConverterCollapseExpand
LogicalResult
matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
- ShapedType operandTy = adaptor.getInput1().getType().cast<ShapedType>();
- ShapedType resultTy = reshape.getType().template cast<ShapedType>();
+ ShapedType operandTy = cast<ShapedType>(adaptor.getInput1().getType());
+ ShapedType resultTy = cast<ShapedType>(reshape.getType());
bool isDynamic = !operandTy.hasStaticShape();
SmallVector<int64_t> intermediateShape;
@@ -247,7 +247,7 @@ class SliceConverter : public OpConversionPattern<tosa::SliceOp> {
Value input = adaptor.getInput();
SmallVector<int64_t> strides, sizes;
ArrayRef<int64_t> starts = sliceOp.getStart();
- strides.resize(sliceOp.getType().template cast<ShapedType>().getRank(), 1);
+ strides.resize(cast<ShapedType>(sliceOp.getType()).getRank(), 1);
SmallVector<Value> dynSizes;
for (const auto &i : llvm::enumerate(sliceOp.getSize())) {
@@ -284,7 +284,7 @@ class PadConverter : public OpRewritePattern<tosa::PadOp> {
auto input = padOp.getInput1();
auto padding = padOp.getPadding();
- ShapedType inputTy = input.getType().cast<ShapedType>();
+ ShapedType inputTy = cast<ShapedType>(input.getType());
Type elementTy = inputTy.getElementType();
int64_t rank = inputTy.getRank();
@@ -297,11 +297,11 @@ class PadConverter : public OpRewritePattern<tosa::PadOp> {
loc, padOp.getPadConst(), ValueRange({}));
} else {
TypedAttr constantAttr;
- if (elementTy.isa<FloatType>()) {
+ if (isa<FloatType>(elementTy)) {
constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
- } else if (elementTy.isa<IntegerType>() && !padOp.getQuantizationInfo()) {
+ } else if (isa<IntegerType>(elementTy) && !padOp.getQuantizationInfo()) {
constantAttr = rewriter.getIntegerAttr(elementTy, 0);
- } else if (elementTy.isa<IntegerType>() && padOp.getQuantizationInfo()) {
+ } else if (isa<IntegerType>(elementTy) && padOp.getQuantizationInfo()) {
int64_t value = padOp.getQuantizationInfo()->getInputZp();
constantAttr = rewriter.getIntegerAttr(elementTy, value);
}
@@ -355,8 +355,8 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
LogicalResult
matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto inputType = op.getOperand(0).getType().template cast<ShapedType>();
- auto resultType = op.getType().dyn_cast<RankedTensorType>();
+ auto inputType = cast<ShapedType>(op.getOperand(0).getType());
+ auto resultType = dyn_cast<RankedTensorType>(op.getType());
Location loc = op.getLoc();
int axis = op.getAxis();
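Several hunks also drop the `template` keyword from calls like `op->getResult(0).getType().template cast<ShapedType>()` in `ConvConverter`, where `op` has the template parameter type `TosaConvOp`. When the object's type depends on a template parameter, a member template call needs that disambiguator, otherwise `<` is parsed as less-than; a free function template found by ordinary lookup needs no such keyword. Elsewhere, as in the non-template reshape converters, the keyword was redundant, and the free-function form reads the same in both cases. A minimal sketch with hypothetical `Shaped`, `Ranked` and `ToyOp` types:

```cpp
#include <cassert>

namespace sketch {

// Hypothetical shaped-type handle that still offers the old member templates.
struct Shaped {
  int rank; // -1 stands for "unranked" in this sketch
  template <typename T> bool isa() const { return T::classof(*this); }
  template <typename T> T cast() const {
    assert(isa<T>() && "cast to incompatible type");
    return T{rank};
  }
};

struct Ranked {
  int rank;
  static bool classof(const Shaped &s) { return s.rank >= 0; }
};

// Free-function spelling; `from` is not dependent here, so no `template`.
template <typename To> To cast(const Shaped &from) { return from.cast<To>(); }

// Hypothetical op whose type accessor the templates below call.
struct ToyOp {
  Shaped type;
  Shaped getType() const { return type; }
};

// Method style: `op.getType()` depends on OpTy, so the member template call
// must carry the `template` disambiguator.
template <typename OpTy> int rankWithMethod(const OpTy &op) {
  return op.getType().template cast<Ranked>().rank;
}

// Free-function style: `cast` is found by ordinary lookup, so the call is
// spelled the same way in templated and non-templated code.
template <typename OpTy> int rankWithFreeFunction(const OpTy &op) {
  return cast<Ranked>(op.getType()).rank;
}

} // namespace sketch
```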
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 5de3bef3af9d6..a78402eb16428 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -123,7 +123,7 @@ static bool isTransposeMatrixLoadMap(AffineMap permutationMap) {
// constant stride.
static std::optional<int64_t>
getMemrefConstantHorizontalStride(ShapedType type) {
- auto memrefType = type.dyn_cast<MemRefType>();
+ auto memrefType = dyn_cast<MemRefType>(type);
if (!memrefType)
return false;
// If the memref is 0 or 1D the horizontal stride is 0.
@@ -193,10 +193,10 @@ transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
/// Return true if the constant is a splat to a 2D vector so that it can be
/// converted to a MMA constant matrix op.
static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) {
- auto vecType = constantOp.getType().dyn_cast<VectorType>();
+ auto vecType = dyn_cast<VectorType>(constantOp.getType());
if (!vecType || vecType.getRank() != 2)
return false;
- return constantOp.getValue().isa<SplatElementsAttr>();
+ return isa<SplatElementsAttr>(constantOp.getValue());
}
/// Return true if this is a broadcast from scalar to a 2D vector.
@@ -268,11 +268,11 @@ extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op) {
// matrixB and matrixC operands. vector.extract_strided_slice op
// is not supported on registers containing matrixA operands.
if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B)
- return (op->getResult(0).getType().cast<VectorType>() ==
- (*contractOp).getRhs().getType().cast<VectorType>());
+ return (cast<VectorType>(op->getResult(0).getType()) ==
+ cast<VectorType>((*contractOp).getRhs().getType()));
if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::C)
- return (op->getResult(0).getType().cast<VectorType>() ==
- (*contractOp).getAcc().getType().cast<VectorType>());
+ return (cast<VectorType>(op->getResult(0).getType()) ==
+ cast<VectorType>((*contractOp).getAcc().getType()));
return false;
}
@@ -344,11 +344,11 @@ static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
bool useNvGpu) {
auto hasVectorDest = [](Operation *op) {
return llvm::any_of(op->getResultTypes(),
- [](Type t) { return t.isa<VectorType>(); });
+ [](Type t) { return isa<VectorType>(t); });
};
auto hasVectorSrc = [](Operation *op) {
return llvm::any_of(op->getOperandTypes(),
- [](Type t) { return t.isa<VectorType>(); });
+ [](Type t) { return isa<VectorType>(t); });
};
SetVector<Operation *> opToConvert;
op->walk([&](vector::ContractionOp contract) {
@@ -448,8 +448,8 @@ struct CombineTransferReadOpTranspose final
(extOp = source.getDefiningOp<arith::ExtUIOp>())) {
source = extOp->getOperand(0);
resultType =
- VectorType::get(resultType.cast<VectorType>().getShape(),
- source.getType().cast<VectorType>().getElementType());
+ VectorType::get(cast<VectorType>(resultType).getShape(),
+ cast<VectorType>(source.getType()).getElementType());
}
auto transferReadOp = source.getDefiningOp<vector::TransferReadOp>();
@@ -553,7 +553,7 @@ convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
bool isSignedExtend = isa<arith::ExtSIOp>(user);
if (isSignedExtend || isa<arith::ExtUIOp>(user)) {
elType = IntegerType::get(
- op.getContext(), elType.cast<IntegerType>().getWidth(),
+ op.getContext(), cast<IntegerType>(elType).getWidth(),
isSignedExtend ? IntegerType::Signed : IntegerType::Unsigned);
mappingResult = user->getResult(0);
fragType = inferFragType(user);
@@ -610,7 +610,7 @@ getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo &regInfo) {
SmallVector<int64_t> shape{regInfo.numRegistersPerFragment,
regInfo.elementsPerRegister};
Type elType = regInfo.registerLLVMType;
- if (auto vecType = elType.dyn_cast<VectorType>())
+ if (auto vecType = dyn_cast<VectorType>(elType))
elType = vecType.getElementType();
return VectorType::get(shape, elType);
}
@@ -637,7 +637,7 @@ convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op,
}
VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
- auto dense = op.getValue().dyn_cast<SplatElementsAttr>();
+ auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
if (!dense) {
LLVM_DEBUG(DBGS() << "not a splat\n");
return rewriter.notifyMatchFailure(op, "not a splat");
@@ -782,7 +782,7 @@ createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
// If we are not transposing, then we can use vectorized loads. Otherwise, we
// must load each element individually.
if (!isTransposeLoad) {
- if (!loadedElType.isa<VectorType>()) {
+ if (!isa<VectorType>(loadedElType)) {
loadedElType = VectorType::get({1}, loadedElType);
}
@@ -805,7 +805,7 @@ createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
rewriter.getI64ArrayAttr(i));
}
} else {
- if (auto vecType = loadedElType.dyn_cast<VectorType>()) {
+ if (auto vecType = dyn_cast<VectorType>(loadedElType)) {
loadedElType = vecType.getElementType();
}
for (int i = 0; i < vectorType.getShape()[0]; i++) {
@@ -838,7 +838,7 @@ createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
/// Return true if this is a shared memory memref type.
static bool isSharedMemory(MemRefType type) {
auto addressSpace =
- type.getMemorySpace().dyn_cast_or_null<gpu::AddressSpaceAttr>();
+ dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
if (addressSpace &&
addressSpace.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace())
return true;
@@ -860,7 +860,7 @@ convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op,
return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
bool isLdMatrixCompatible =
- isSharedMemory(op.getSource().getType().cast<MemRefType>()) &&
+ isSharedMemory(cast<MemRefType>(op.getSource().getType())) &&
nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128;
VectorType vecTy = op.getVectorType();
@@ -929,7 +929,7 @@ convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
SmallVectorImpl<int64_t> &results) {
for (auto attr : arrayAttr)
- results.push_back(attr.cast<IntegerAttr>().getInt());
+ results.push_back(cast<IntegerAttr>(attr).getInt());
}
static LogicalResult
@@ -1041,9 +1041,9 @@ convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op,
itC == valueMapping.end())
return rewriter.notifyMatchFailure(op, "no mapping");
Value opA = itA->second, opB = itB->second, opC = itC->second;
- int64_t m = op.getLhs().getType().cast<VectorType>().getShape()[0];
- int64_t n = op.getRhs().getType().cast<VectorType>().getShape()[0];
- int64_t k = op.getLhs().getType().cast<VectorType>().getShape()[1];
+ int64_t m = cast<VectorType>(op.getLhs().getType()).getShape()[0];
+ int64_t n = cast<VectorType>(op.getRhs().getType()).getShape()[0];
+ int64_t k = cast<VectorType>(op.getLhs().getType()).getShape()[1];
Value matmul = rewriter.create<nvgpu::MmaSyncOp>(
op.getLoc(), opA, opB, opC, rewriter.getI64ArrayAttr({m, n, k}));
valueMapping[op.getResult()] = matmul;
@@ -1060,11 +1060,11 @@ convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op,
assert(constantSupportsMMAMatrixType(op));
auto splat =
- op.getValue().cast<SplatElementsAttr>().getSplatValue<TypedAttr>();
+ cast<SplatElementsAttr>(op.getValue()).getSplatValue<TypedAttr>();
auto scalarConstant =
rewriter.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat);
const char *fragType = inferFragType(op);
- auto vecType = op.getType().cast<VectorType>();
+ auto vecType = cast<VectorType>(op.getType());
gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 05def0f45d7fb..4175f8fcafe72 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -256,8 +256,8 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
return failure();
// Resolve address.
- auto vtype = this->typeConverter->convertType(loadOrStoreOp.getVectorType())
- .template cast<VectorType>();
+ auto vtype = cast<VectorType>(
+ this->typeConverter->convertType(loadOrStoreOp.getVectorType()));
Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
adaptor.getIndices(), rewriter);
Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype,
@@ -277,7 +277,7 @@ class VectorGatherOpConversion
LogicalResult
matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- MemRefType memRefType = gather.getBaseType().dyn_cast<MemRefType>();
+ MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType());
assert(memRefType && "The base should be bufferized");
if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
@@ -296,7 +296,7 @@ class VectorGatherOpConversion
auto llvmNDVectorTy = adaptor.getIndexVec().getType();
// Handle the simple case of 1-D vector.
- if (!llvmNDVectorTy.isa<LLVM::LLVMArrayType>()) {
+ if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy)) {
auto vType = gather.getVectorType();
// Resolve address.
Value ptrs = getIndexedPtrs(rewriter, loc, *this->getTypeConverter(),
@@ -501,7 +501,7 @@ static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral,
static Value createReductionNeutralValue(ReductionNeutralFPMin neutral,
ConversionPatternRewriter &rewriter,
Location loc, Type llvmType) {
- auto floatType = llvmType.cast<FloatType>();
+ auto floatType = cast<FloatType>(llvmType);
return rewriter.create<LLVM::ConstantOp>(
loc, llvmType,
rewriter.getFloatAttr(
@@ -513,7 +513,7 @@ static Value createReductionNeutralValue(ReductionNeutralFPMin neutral,
static Value createReductionNeutralValue(ReductionNeutralFPMax neutral,
ConversionPatternRewriter &rewriter,
Location loc, Type llvmType) {
- auto floatType = llvmType.cast<FloatType>();
+ auto floatType = cast<FloatType>(llvmType);
return rewriter.create<LLVM::ConstantOp>(
loc, llvmType,
rewriter.getFloatAttr(
@@ -585,9 +585,9 @@ static Value createIntegerReductionComparisonOpLowering(
/// with vector types.
static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
Value rhs, bool isMin) {
- auto floatType = getElementTypeOrSelf(lhs.getType()).cast<FloatType>();
+ auto floatType = cast<FloatType>(getElementTypeOrSelf(lhs.getType()));
Type i1Type = builder.getI1Type();
- if (auto vecType = lhs.getType().dyn_cast<VectorType>())
+ if (auto vecType = dyn_cast<VectorType>(lhs.getType()))
i1Type = VectorType::get(vecType.getShape(), i1Type);
Value cmp = builder.create<LLVM::FCmpOp>(
loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
@@ -768,7 +768,7 @@ class VectorReductionOpConversion
return success();
}
- if (!eltType.isa<FloatType>())
+ if (!isa<FloatType>(eltType))
return failure();
// Floating-point reductions: add/mul/min/max
@@ -966,14 +966,14 @@ class VectorShuffleOpConversion
// For all other cases, insert the individual values individually.
int64_t v1Dim = v1Type.getDimSize(0);
Type eltType;
- if (auto arrayType = llvmType.dyn_cast<LLVM::LLVMArrayType>())
+ if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(llvmType))
eltType = arrayType.getElementType();
else
- eltType = llvmType.cast<VectorType>().getElementType();
+ eltType = cast<VectorType>(llvmType).getElementType();
Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
int64_t insPos = 0;
for (const auto &en : llvm::enumerate(maskArrayAttr)) {
- int64_t extPos = en.value().cast<IntegerAttr>().getInt();
+ int64_t extPos = cast<IntegerAttr>(en.value()).getInt();
Value value = adaptor.getV1();
if (extPos >= v1Dim) {
extPos -= v1Dim;
@@ -1046,7 +1046,7 @@ class VectorExtractOpConversion
}
// One-shot extraction of vector from array (only requires extractvalue).
- if (resultType.isa<VectorType>()) {
+ if (isa<VectorType>(resultType)) {
SmallVector<int64_t> indices;
for (auto idx : positionArrayAttr.getAsRange<IntegerAttr>())
indices.push_back(idx.getInt());
@@ -1062,13 +1062,13 @@ class VectorExtractOpConversion
if (positionAttrs.size() > 1) {
SmallVector<int64_t> nMinusOnePosition;
for (auto idx : positionAttrs.drop_back())
- nMinusOnePosition.push_back(idx.cast<IntegerAttr>().getInt());
+ nMinusOnePosition.push_back(cast<IntegerAttr>(idx).getInt());
extracted = rewriter.create<LLVM::ExtractValueOp>(loc, extracted,
nMinusOnePosition);
}
// Remaining extraction of element from 1-D LLVM vector
- auto position = positionAttrs.back().cast<IntegerAttr>();
+ auto position = cast<IntegerAttr>(positionAttrs.back());
auto i64Type = IntegerType::get(rewriter.getContext(), 64);
auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
extracted =
@@ -1169,7 +1169,7 @@ class VectorInsertOpConversion
}
// One-shot insertion of a vector into an array (only requires insertvalue).
- if (sourceType.isa<VectorType>()) {
+ if (isa<VectorType>(sourceType)) {
Value inserted = rewriter.create<LLVM::InsertValueOp>(
loc, adaptor.getDest(), adaptor.getSource(),
LLVM::convertArrayToIndices(positionArrayAttr));
@@ -1180,7 +1180,7 @@ class VectorInsertOpConversion
// Potential extraction of 1-D vector from array.
Value extracted = adaptor.getDest();
auto positionAttrs = positionArrayAttr.getValue();
- auto position = positionAttrs.back().cast<IntegerAttr>();
+ auto position = cast<IntegerAttr>(positionAttrs.back());
auto oneDVectorType = destVectorType;
if (positionAttrs.size() > 1) {
oneDVectorType = reducedVectorTypeBack(destVectorType);
@@ -1333,7 +1333,7 @@ class VectorTypeCastOpConversion
ConversionPatternRewriter &rewriter) const override {
auto loc = castOp->getLoc();
MemRefType sourceMemRefType =
- castOp.getOperand().getType().cast<MemRefType>();
+ cast<MemRefType>(castOp.getOperand().getType());
MemRefType targetMemRefType = castOp.getType();
// Only static shape casts supported atm.
@@ -1342,13 +1342,13 @@ class VectorTypeCastOpConversion
return failure();
auto llvmSourceDescriptorTy =
- adaptor.getOperands()[0].getType().dyn_cast<LLVM::LLVMStructType>();
+ dyn_cast<LLVM::LLVMStructType>(adaptor.getOperands()[0].getType());
if (!llvmSourceDescriptorTy)
return failure();
MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);
- auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
- .dyn_cast_or_null<LLVM::LLVMStructType>();
+ auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
+ typeConverter->convertType(targetMemRefType));
if (!llvmTargetDescriptorTy)
return failure();
@@ -1418,7 +1418,7 @@ class VectorCreateMaskOpRewritePattern
LogicalResult matchAndRewrite(vector::CreateMaskOp op,
PatternRewriter &rewriter) const override {
auto dstType = op.getType();
- if (dstType.getRank() != 1 || !dstType.cast<VectorType>().isScalable())
+ if (dstType.getRank() != 1 || !cast<VectorType>(dstType).isScalable())
return failure();
IntegerType idxType =
force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
@@ -1465,7 +1465,7 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
// Make sure element type has runtime support.
PrintConversion conversion = PrintConversion::None;
- VectorType vectorType = printType.dyn_cast<VectorType>();
+ VectorType vectorType = dyn_cast<VectorType>(printType);
Type eltType = vectorType ? vectorType.getElementType() : printType;
auto parent = printOp->getParentOfType<ModuleOp>();
Operation *printer;
@@ -1481,7 +1481,7 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
printer = LLVM::lookupOrCreatePrintBF16Fn(parent);
} else if (eltType.isIndex()) {
printer = LLVM::lookupOrCreatePrintU64Fn(parent);
- } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
+ } else if (auto intTy = dyn_cast<IntegerType>(eltType)) {
// Integers need a zero or sign extension on the operand
// (depending on the source type) as well as a signed or
// unsigned print method. Up to 64-bit is supported.
@@ -1536,7 +1536,7 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
Value value, Type type, Operation *printer, int64_t rank,
PrintConversion conversion) const {
- VectorType vectorType = type.dyn_cast<VectorType>();
+ VectorType vectorType = dyn_cast<VectorType>(type);
Location loc = op->getLoc();
if (!vectorType) {
assert(rank == 0 && "The scalar case expects rank == 0");
@@ -1610,7 +1610,7 @@ struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> {
LogicalResult
matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- VectorType resultType = splatOp.getType().cast<VectorType>();
+ VectorType resultType = cast<VectorType>(splatOp.getType());
if (resultType.getRank() > 1)
return failure();
@@ -1633,7 +1633,7 @@ struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> {
auto v = rewriter.create<LLVM::InsertElementOp>(
splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero);
- int64_t width = splatOp.getType().cast<VectorType>().getDimSize(0);
+ int64_t width = cast<VectorType>(splatOp.getType()).getDimSize(0);
SmallVector<int32_t> zeroValues(width, 0);
// Shuffle the value across the desired number of elements.
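The hunks above replace member calls such as `printType.dyn_cast<VectorType>()` with free functions such as `dyn_cast<VectorType>(printType)`. The sketch below is a rough, self-contained illustration of the mechanism those free functions rely on. It reimplements `isa`, `cast` and `dyn_cast` over a tiny pointer-based hierarchy using static `classof` predicates. It is not LLVM's `llvm/Support/Casting.h`. The `sketch` namespace and its `Type`, `IntegerType` and `VectorType` classes are hypothetical. MLIR's value-semantic `Type` and `Attribute` classes reach the real free functions through LLVM's casting infrastructure (the `doCast` machinery mentioned in the commit message), not through raw pointers.

```cpp
#include <cassert>

namespace sketch {

// Hypothetical base class carrying a kind tag that classof() inspects.
struct Type {
  enum class Kind { Integer, Vector };
  explicit Type(Kind k) : kind(k) {}
  Kind getKind() const { return kind; }

private:
  Kind kind;
};

struct IntegerType : Type {
  explicit IntegerType(unsigned w) : Type(Kind::Integer), width(w) {}
  static bool classof(const Type *t) { return t->getKind() == Kind::Integer; }
  unsigned width;
};

struct VectorType : Type {
  explicit VectorType(long n) : Type(Kind::Vector), numElements(n) {}
  static bool classof(const Type *t) { return t->getKind() == Kind::Vector; }
  long numElements;
};

// isa<To>(v): pure kind query, delegated to To::classof.
template <typename To, typename From>
bool isa(const From *v) {
  return To::classof(v);
}

// cast<To>(v): the caller asserts the kind is already known.
template <typename To, typename From>
const To *cast(const From *v) {
  assert(isa<To>(v) && "cast<Ty>() argument of incompatible type!");
  return static_cast<const To *>(v);
}

// dyn_cast<To>(v): checked cast that yields nullptr on a mismatch.
template <typename To, typename From>
const To *dyn_cast(const From *v) {
  return isa<To>(v) ? static_cast<const To *>(v) : nullptr;
}

} // namespace sketch
```

Because `dyn_cast` yields null on a mismatch, a caller can test and bind in one step, as in `if (auto intTy = dyn_cast<IntegerType>(eltType))`. `cast` is for cases where the kind is already guaranteed.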
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 1a47dd1610bec..9456b892f3f2a 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -258,7 +258,7 @@ static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp,
/// Return true if this transfer op operates on a source tensor.
template <typename OpTy>
static bool isTensorOp(OpTy xferOp) {
- if (xferOp.getShapedType().template isa<RankedTensorType>()) {
+ if (isa<RankedTensorType>(xferOp.getShapedType())) {
if (xferOp.getOperationName().equals(TransferWriteOp::getOperationName())) {
// TransferWriteOps on tensors have a result.
assert(xferOp->getNumResults() > 0);
@@ -314,7 +314,7 @@ static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) {
///
/// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>>
static MemRefType unpackOneDim(MemRefType type) {
- auto vectorType = type.getElementType().dyn_cast<VectorType>();
+ auto vectorType = dyn_cast<VectorType>(type.getElementType());
auto memrefShape = type.getShape();
SmallVector<int64_t, 8> newMemrefShape;
newMemrefShape.append(memrefShape.begin(), memrefShape.end());
@@ -408,8 +408,8 @@ struct Strategy<TransferReadOp> {
getXferIndices(b, xferOp, iv, xferIndices);
Location loc = xferOp.getLoc();
- auto bufferType = buffer.getType().dyn_cast<ShapedType>();
- auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
+ auto bufferType = dyn_cast<ShapedType>(buffer.getType());
+ auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
auto newXferOp = b.create<vector::TransferReadOp>(
loc, vecType, xferOp.getSource(), xferIndices,
@@ -432,8 +432,8 @@ struct Strategy<TransferReadOp> {
storeIndices.push_back(iv);
Location loc = xferOp.getLoc();
- auto bufferType = buffer.getType().dyn_cast<ShapedType>();
- auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
+ auto bufferType = dyn_cast<ShapedType>(buffer.getType());
+ auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
auto vec = b.create<vector::SplatOp>(loc, vecType, xferOp.getPadding());
b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);
@@ -698,7 +698,7 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
// Find and cast data buffer. How the buffer can be found depends on OpTy.
ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter);
auto dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
- auto dataBufferType = dataBuffer.getType().template dyn_cast<MemRefType>();
+ auto dataBufferType = dyn_cast<MemRefType>(dataBuffer.getType());
auto castedDataType = unpackOneDim(dataBufferType);
auto castedDataBuffer =
locB.create<vector::TypeCastOp>(castedDataType, dataBuffer);
@@ -707,8 +707,7 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
Value castedMaskBuffer;
if (xferOp.getMask()) {
auto maskBuffer = getMaskBuffer(xferOp);
- auto maskBufferType =
- maskBuffer.getType().template dyn_cast<MemRefType>();
+ auto maskBufferType = dyn_cast<MemRefType>(maskBuffer.getType());
if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
// Do not unpack a dimension of the mask, if:
// * To-be-unpacked transfer op dimension is a broadcast.
@@ -889,7 +888,7 @@ struct UnrollTransferReadConversion
SmallVector<int64_t, 8> &indices) const {
if (auto insertOp = getInsertOp(xferOp)) {
for (Attribute attr : insertOp.getPosition())
- indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
+ indices.push_back(dyn_cast<IntegerAttr>(attr).getInt());
}
}
@@ -908,7 +907,7 @@ struct UnrollTransferReadConversion
auto insertOp = getInsertOp(xferOp);
auto vec = getResultVector(xferOp, rewriter);
- auto vecType = vec.getType().dyn_cast<VectorType>();
+ auto vecType = dyn_cast<VectorType>(vec.getType());
auto xferVecType = xferOp.getVectorType();
auto newXferVecType = VectorType::get(xferVecType.getShape().drop_front(),
xferVecType.getElementType());
@@ -1016,7 +1015,7 @@ struct UnrollTransferWriteConversion
SmallVector<int64_t, 8> &indices) const {
if (auto extractOp = getExtractOp(xferOp)) {
for (Attribute attr : extractOp.getPosition())
- indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
+ indices.push_back(dyn_cast<IntegerAttr>(attr).getInt());
}
}
@@ -1235,7 +1234,7 @@ struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
if (xferOp.getTransferRank() == 0)
return failure();
auto map = xferOp.getPermutationMap();
- auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();
+ auto memRefType = dyn_cast<MemRefType>(xferOp.getShapedType());
if (!memRefType)
return failure();
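The `unpackOneDim` comment in VectorToSCF.cpp describes a shape rewrite: `memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>>`. The sketch below models only that shape arithmetic. The leading vector dimension is appended to the memref shape and dropped from the vector shape. `VecShape`, `MemRefShape` and `unpackOneDimSketch` are hypothetical plain-data stand-ins. They ignore element type, layout and memory space, and they are not the MLIR implementation.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-ins: only the shapes of the memref and its vector
// element type are modelled.
struct VecShape {
  std::vector<int64_t> dims;
};
struct MemRefShape {
  std::vector<int64_t> dims;
  VecShape elt;
};

// Moves the leading dimension of the vector element type to the end of the
// memref shape, e.g. memref<9xvector<5x6xf32>> -> memref<9x5xvector<6xf32>>.
MemRefShape unpackOneDimSketch(const MemRefShape &type) {
  assert(!type.elt.dims.empty() && "element must be a vector of rank >= 1");
  MemRefShape result;
  result.dims = type.dims;
  result.dims.push_back(type.elt.dims.front());
  result.elt.dims.assign(type.elt.dims.begin() + 1, type.elt.dims.end());
  return result;
}
```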
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 50017b7bcef9b..35171b3e077ee 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -43,7 +43,7 @@ static int getNumBits(Type type) {
// TODO: This does not take into account any memory layout or widening
// constraints. E.g., a vector<3xi57> may report to occupy 3x57=171 bit, even
// though in practice it will likely be stored as in a 4xi64 vector register.
- if (auto vectorType = type.dyn_cast<VectorType>())
+ if (auto vectorType = dyn_cast<VectorType>(type))
return vectorType.getNumElements() * vectorType.getElementTypeBitWidth();
return type.getIntOrFloatBitWidth();
}
@@ -95,7 +95,7 @@ struct VectorBroadcastConvert final
if (!resultType)
return failure();
- if (resultType.isa<spirv::ScalarType>()) {
+ if (isa<spirv::ScalarType>(resultType)) {
rewriter.replaceOp(castOp, adaptor.getSource());
return success();
}
@@ -116,7 +116,7 @@ struct VectorExtractOpConvert final
matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Only support extracting a scalar value now.
- VectorType resultVectorType = extractOp.getType().dyn_cast<VectorType>();
+ VectorType resultVectorType = dyn_cast<VectorType>(extractOp.getType());
if (resultVectorType && resultVectorType.getNumElements() > 1)
return failure();
@@ -124,7 +124,7 @@ struct VectorExtractOpConvert final
if (!dstType)
return failure();
- if (adaptor.getVector().getType().isa<spirv::ScalarType>()) {
+ if (isa<spirv::ScalarType>(adaptor.getVector().getType())) {
rewriter.replaceOp(extractOp, adaptor.getVector());
return success();
}
@@ -156,7 +156,7 @@ struct VectorExtractStridedSliceOpConvert final
Value srcVector = adaptor.getOperands().front();
// Extract vector<1xT> case.
- if (dstType.isa<spirv::ScalarType>()) {
+ if (isa<spirv::ScalarType>(dstType)) {
rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(extractOp,
srcVector, offset);
return success();
@@ -203,7 +203,7 @@ struct VectorInsertOpConvert final
return success();
}
- if (insertOp.getSourceType().isa<VectorType>() ||
+ if (isa<VectorType>(insertOp.getSourceType()) ||
!spirv::CompositeType::isValid(insertOp.getDestVectorType()))
return failure();
int32_t id = getFirstIntValue(insertOp.getPosition());
@@ -224,7 +224,7 @@ struct VectorExtractElementOpConvert final
if (!resultType)
return failure();
- if (adaptor.getVector().getType().isa<spirv::ScalarType>()) {
+ if (isa<spirv::ScalarType>(adaptor.getVector().getType())) {
rewriter.replaceOp(extractOp, adaptor.getVector());
return success();
}
@@ -252,7 +252,7 @@ struct VectorInsertElementOpConvert final
if (!vectorType)
return failure();
- if (vectorType.isa<spirv::ScalarType>()) {
+ if (isa<spirv::ScalarType>(vectorType)) {
rewriter.replaceOp(insertOp, adaptor.getSource());
return success();
}
@@ -285,18 +285,17 @@ struct VectorInsertStridedSliceOpConvert final
return failure();
uint64_t offset = getFirstIntValue(insertOp.getOffsets());
- if (srcVector.getType().isa<spirv::ScalarType>()) {
- assert(!dstVector.getType().isa<spirv::ScalarType>());
+ if (isa<spirv::ScalarType>(srcVector.getType())) {
+ assert(!isa<spirv::ScalarType>(dstVector.getType()));
rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
insertOp, dstVector.getType(), srcVector, dstVector,
rewriter.getI32ArrayAttr(offset));
return success();
}
- uint64_t totalSize =
- dstVector.getType().cast<VectorType>().getNumElements();
+ uint64_t totalSize = cast<VectorType>(dstVector.getType()).getNumElements();
uint64_t insertSize =
- srcVector.getType().cast<VectorType>().getNumElements();
+ cast<VectorType>(srcVector.getType()).getNumElements();
SmallVector<int32_t, 2> indices(totalSize);
std::iota(indices.begin(), indices.end(), 0);
@@ -324,7 +323,7 @@ struct VectorReductionPattern final
if (!resultType)
return failure();
- auto srcVectorType = adaptor.getVector().getType().dyn_cast<VectorType>();
+ auto srcVectorType = dyn_cast<VectorType>(adaptor.getVector().getType());
if (!srcVectorType || srcVectorType.getRank() != 1)
return rewriter.notifyMatchFailure(reduceOp, "not 1-D vector source");
@@ -393,10 +392,10 @@ class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
Type dstType = getTypeConverter()->convertType(op.getType());
if (!dstType)
return failure();
- if (dstType.isa<spirv::ScalarType>()) {
+ if (isa<spirv::ScalarType>(dstType)) {
rewriter.replaceOp(op, adaptor.getInput());
} else {
- auto dstVecType = dstType.cast<VectorType>();
+ auto dstVecType = cast<VectorType>(dstType);
SmallVector<Value, 4> source(dstVecType.getNumElements(),
adaptor.getInput());
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstType,
@@ -422,7 +421,7 @@ struct VectorShuffleOpConvert final
if (oldSourceType.getNumElements() > 1) {
SmallVector<int32_t, 4> components = llvm::to_vector<4>(
llvm::map_range(shuffleOp.getMask(), [](Attribute attr) -> int32_t {
- return attr.cast<IntegerAttr>().getValue().getZExtValue();
+ return cast<IntegerAttr>(attr).getValue().getZExtValue();
}));
rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(),
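The `getNumBits` hunk in VectorToSPIRV.cpp computes a vector's size as `getNumElements() * getElementTypeBitWidth()`. It falls back to the scalar bit width otherwise. Its TODO notes that padding and widening are ignored, so `vector<3xi57>` reports 3x57 = 171 bits. The following sketch reproduces just that arithmetic on a hypothetical `TypeInfo` record, where `numElements == 0` marks a scalar. It is not an MLIR type.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for a scalar or vector type; only widths are kept.
struct TypeInfo {
  int64_t numElements;  // 0 for a scalar, >0 for a vector
  int64_t elementBits;  // scalar width, or vector element width
};

// Mirrors the getNumBits logic above: vectors report
// numElements * elementBitWidth, scalars their own width. Like the original,
// it ignores memory layout and register widening.
int64_t getNumBitsSketch(const TypeInfo &t) {
  if (t.numElements > 0)
    return t.numElements * t.elementBits;
  return t.elementBits;
}
```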
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp
index 0c69cdc027912..d07d6518d57c0 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp
@@ -65,7 +65,7 @@ static void patchOperandSegmentSizes(ArrayRef<NamedAttribute> attrs,
newAttrs.push_back(attr);
continue;
}
- auto segmentAttr = attr.getValue().cast<DenseI32ArrayAttr>();
+ auto segmentAttr = cast<DenseI32ArrayAttr>(attr.getValue());
MLIRContext *context = segmentAttr.getContext();
DenseI32ArrayAttr newSegments;
switch (action) {
@@ -128,7 +128,7 @@ LogicalResult RawBufferAtomicByCasPattern<AtomicOp, ArithOp>::matchAndRewrite(
Value prevLoadForCompare = prevLoad;
Value atomicResForCompare = atomicRes;
- if (auto floatDataTy = dataType.dyn_cast<FloatType>()) {
+ if (auto floatDataTy = dyn_cast<FloatType>(dataType)) {
Type equivInt = rewriter.getIntegerType(floatDataTy.getWidth());
prevLoadForCompare =
rewriter.create<arith::BitcastOp>(loc, equivInt, prevLoad);
diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
index 9b0ad3c119d04..4b3730a4aa397 100644
--- a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
@@ -136,7 +136,7 @@ static bool isLocallyDefined(Value v, Operation *enclosingOp) {
bool mlir::affine::isLoopMemoryParallel(AffineForOp forOp) {
// Any memref-typed iteration arguments are treated as serializing.
if (llvm::any_of(forOp.getResultTypes(),
- [](Type type) { return type.isa<BaseMemRefType>(); }))
+ [](Type type) { return isa<BaseMemRefType>(type); }))
return false;
// Collect all load and store ops in loop nest rooted at 'forOp'.
diff --git a/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp b/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp
index 9db1e998bb165..c97e99c0a0c19 100644
--- a/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp
@@ -162,7 +162,7 @@ uint64_t mlir::affine::getLargestDivisorOfTripCount(AffineForOp forOp) {
/// conservative.
static bool isAccessIndexInvariant(Value iv, Value index) {
assert(isAffineForInductionVar(iv) && "iv must be a AffineForOp");
- assert(index.getType().isa<IndexType>() && "index must be of IndexType");
+ assert(isa<IndexType>(index.getType()) && "index must be of IndexType");
SmallVector<Operation *, 4> affineApplyOps;
getReachableAffineApplyOps({index}, affineApplyOps);
@@ -262,7 +262,7 @@ static bool isContiguousAccess(Value iv, LoadOrStoreOp memoryOp,
template <typename LoadOrStoreOp>
static bool isVectorElement(LoadOrStoreOp memoryOp) {
auto memRefType = memoryOp.getMemRefType();
- return memRefType.getElementType().template isa<VectorType>();
+ return isa<VectorType>(memRefType.getElementType());
}
using VectorizableOpFun = std::function<bool(AffineForOp, Operation &)>;
diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
index 438296f25096f..4433d94eb1454 100644
--- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
@@ -190,7 +190,7 @@ void MemRefDependenceGraph::addEdge(unsigned srcId, unsigned dstId,
if (!hasEdge(srcId, dstId, value)) {
outEdges[srcId].push_back({dstId, value});
inEdges[dstId].push_back({srcId, value});
- if (value.getType().isa<MemRefType>())
+ if (isa<MemRefType>(value.getType()))
memrefEdgeCount[value]++;
}
}
@@ -200,7 +200,7 @@ void MemRefDependenceGraph::removeEdge(unsigned srcId, unsigned dstId,
Value value) {
assert(inEdges.count(dstId) > 0);
assert(outEdges.count(srcId) > 0);
- if (value.getType().isa<MemRefType>()) {
+ if (isa<MemRefType>(value.getType())) {
assert(memrefEdgeCount.count(value) > 0);
memrefEdgeCount[value]--;
}
@@ -289,7 +289,7 @@ void MemRefDependenceGraph::gatherDefiningNodes(
// By definition of edge, if the edge value is a non-memref value,
// then the dependence is between a graph node which defines an SSA value
// and another graph node which uses the SSA value.
- if (!edge.value.getType().isa<MemRefType>())
+ if (!isa<MemRefType>(edge.value.getType()))
definingNodes.insert(edge.id);
}
@@ -473,7 +473,7 @@ void MemRefDependenceGraph::forEachMemRefEdge(
ArrayRef<Edge> edges, const std::function<void(Edge)> &callback) {
for (const auto &edge : edges) {
// Skip if 'edge' is not a memref dependence edge.
- if (!edge.value.getType().isa<MemRefType>())
+ if (!isa<MemRefType>(edge.value.getType()))
continue;
assert(nodes.count(edge.id) > 0);
// Skip if 'edge.id' is not a loop nest.
@@ -808,13 +808,13 @@ std::optional<bool> ComputationSliceState::isMaximal() const {
}
unsigned MemRefRegion::getRank() const {
- return memref.getType().cast<MemRefType>().getRank();
+ return cast<MemRefType>(memref.getType()).getRank();
}
std::optional<int64_t> MemRefRegion::getConstantBoundingSizeAndShape(
SmallVectorImpl<int64_t> *shape, std::vector<SmallVector<int64_t, 4>> *lbs,
SmallVectorImpl<int64_t> *lbDivisors) const {
- auto memRefType = memref.getType().cast<MemRefType>();
+ auto memRefType = cast<MemRefType>(memref.getType());
unsigned rank = memRefType.getRank();
if (shape)
shape->reserve(rank);
@@ -875,7 +875,7 @@ std::optional<int64_t> MemRefRegion::getConstantBoundingSizeAndShape(
void MemRefRegion::getLowerAndUpperBound(unsigned pos, AffineMap &lbMap,
AffineMap &ubMap) const {
assert(pos < cst.getNumDimVars() && "invalid position");
- auto memRefType = memref.getType().cast<MemRefType>();
+ auto memRefType = cast<MemRefType>(memref.getType());
unsigned rank = memRefType.getRank();
assert(rank == cst.getNumDimVars() && "inconsistent memref region");
@@ -1049,7 +1049,7 @@ LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth,
// to guard against potential over-approximation from projection.
// TODO: Support dynamic memref dimensions.
if (addMemRefDimBounds) {
- auto memRefType = memref.getType().cast<MemRefType>();
+ auto memRefType = cast<MemRefType>(memref.getType());
for (unsigned r = 0; r < rank; r++) {
cst.addBound(BoundType::LB, /*pos=*/r, /*value=*/0);
if (memRefType.isDynamicDim(r))
@@ -1071,7 +1071,7 @@ mlir::affine::getMemRefIntOrFloatEltSizeInBytes(MemRefType memRefType) {
unsigned sizeInBits;
if (elementType.isIntOrFloat()) {
sizeInBits = elementType.getIntOrFloatBitWidth();
- } else if (auto vectorType = elementType.dyn_cast<VectorType>()) {
+ } else if (auto vectorType = dyn_cast<VectorType>(elementType)) {
if (vectorType.getElementType().isIntOrFloat())
sizeInBits =
vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
@@ -1085,7 +1085,7 @@ mlir::affine::getMemRefIntOrFloatEltSizeInBytes(MemRefType memRefType) {
// Returns the size of the region.
std::optional<int64_t> MemRefRegion::getRegionSize() {
- auto memRefType = memref.getType().cast<MemRefType>();
+ auto memRefType = cast<MemRefType>(memref.getType());
if (!memRefType.getLayout().isIdentity()) {
LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n");
@@ -1119,7 +1119,7 @@ mlir::affine::getIntOrFloatMemRefSizeInBytes(MemRefType memRefType) {
if (!memRefType.hasStaticShape())
return std::nullopt;
auto elementType = memRefType.getElementType();
- if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>())
+ if (!elementType.isIntOrFloat() && !isa<VectorType>(elementType))
return std::nullopt;
auto sizeInBytes = getMemRefIntOrFloatEltSizeInBytes(memRefType);
@@ -1708,7 +1708,7 @@ MemRefAccess::MemRefAccess(Operation *loadOrStoreOpInst) {
}
unsigned MemRefAccess::getRank() const {
- return memref.getType().cast<MemRefType>().getRank();
+ return cast<MemRefType>(memref.getType()).getRank();
}
bool MemRefAccess::isStore() const {
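The `getMemRefIntOrFloatEltSizeInBytes` hunk in Affine/Analysis/Utils.cpp derives an element's bit size in one of two ways. For an int/float element, it uses the element's own width. For a vector of int/float, it multiplies the element width by the number of elements. The sketch below restates that rule on a hypothetical `ElementInfo` record. The hunk shows only the bit-width computation, so the conversion to bytes by rounding up to whole bytes is an assumption of this sketch.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical description of a memref element type.
struct ElementInfo {
  bool isIntOrFloat;     // scalar element, or the vector's element, is int/float
  uint64_t bitWidth;     // scalar width, or vector element width
  uint64_t numElements;  // 0 for a scalar element, >0 for a vector element
};

// Returns the element size in bytes, or nullopt for unsupported elements.
std::optional<uint64_t> eltSizeInBytesSketch(const ElementInfo &e) {
  if (!e.isIntOrFloat)
    return std::nullopt;
  uint64_t sizeInBits =
      e.numElements == 0 ? e.bitWidth : e.bitWidth * e.numElements;
  return (sizeInBits + 7) / 8;  // round up to whole bytes (assumption)
}
```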
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index 89f0a9e92279e..2a9416f39f2fd 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -289,7 +289,7 @@ bool MemRefDependenceGraph::init() {
// memref type. Call Op that returns one or more memref type results
// is already taken care of, by the previous conditions.
if (llvm::any_of(op.getOperandTypes(),
- [&](Type t) { return t.isa<MemRefType>(); })) {
+ [&](Type t) { return isa<MemRefType>(t); })) {
Node node(nextNodeId++, &op);
nodes.insert({node.id, node});
}
@@ -379,7 +379,7 @@ static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
OpBuilder top(forInst->getParentRegion());
// Create new memref type based on slice bounds.
auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOpInst).getMemRef();
- auto oldMemRefType = oldMemRef.getType().cast<MemRefType>();
+ auto oldMemRefType = cast<MemRefType>(oldMemRef.getType());
unsigned rank = oldMemRefType.getRank();
// Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'.
@@ -516,7 +516,7 @@ static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
return WalkResult::advance();
for (Value v : op->getOperands())
// Collect memref values only.
- if (v.getType().isa<MemRefType>())
+ if (isa<MemRefType>(v.getType()))
memRefValues.insert(v);
return WalkResult::advance();
});
diff --git a/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp
index 7d815f742541c..7029251a3720c 100644
--- a/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp
@@ -88,7 +88,7 @@ static bool doubleBuffer(Value oldMemRef, AffineForOp forOp) {
return MemRefType::Builder(oldMemRefType).setShape(newShape).setLayout({});
};
- auto oldMemRefType = oldMemRef.getType().cast<MemRefType>();
+ auto oldMemRefType = cast<MemRefType>(oldMemRef.getType());
auto newMemRefType = doubleShape(oldMemRefType);
// The double buffer is allocated right before 'forOp'.
diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
index 8987a82b7206c..49618074ec224 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
@@ -100,9 +100,9 @@ void SimplifyAffineStructures::runOnOperation() {
SmallVector<Operation *> opsToSimplify;
func.walk([&](Operation *op) {
for (auto attr : op->getAttrs()) {
- if (auto mapAttr = attr.getValue().dyn_cast<AffineMapAttr>())
+ if (auto mapAttr = dyn_cast<AffineMapAttr>(attr.getValue()))
simplifyAndUpdateAttribute(op, attr.getName(), mapAttr);
- else if (auto setAttr = attr.getValue().dyn_cast<IntegerSetAttr>())
+ else if (auto setAttr = dyn_cast<IntegerSetAttr>(attr.getValue()))
simplifyAndUpdateAttribute(op, attr.getName(), setAttr);
}
diff --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
index 1d347329c0005..b23a2cce35cee 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
@@ -838,7 +838,7 @@ void VectorizationState::registerValueVectorReplacementImpl(Value replaced,
Value replacement) {
assert(!valueVectorReplacement.contains(replaced) &&
"Vector replacement already registered");
- assert(replacement.getType().isa<VectorType>() &&
+ assert(isa<VectorType>(replacement.getType()) &&
"Expected vector type in vector replacement");
valueVectorReplacement.map(replaced, replacement);
}
@@ -883,7 +883,7 @@ void VectorizationState::registerValueScalarReplacementImpl(Value replaced,
Value replacement) {
assert(!valueScalarReplacement.contains(replaced) &&
"Scalar value replacement already registered");
- assert(!replacement.getType().isa<VectorType>() &&
+ assert(!isa<VectorType>(replacement.getType()) &&
"Expected scalar type in scalar replacement");
valueScalarReplacement.map(replaced, replacement);
}
@@ -946,7 +946,7 @@ isVectorizableLoopPtrFactory(const DenseSet<Operation *> &parallelLoops,
/// strategy on the scalar type.
static VectorType getVectorType(Type scalarTy,
const VectorizationStrategy *strategy) {
- assert(!scalarTy.isa<VectorType>() && "Expected scalar type");
+ assert(!isa<VectorType>(scalarTy) && "Expected scalar type");
return VectorType::get(strategy->vectorSizes, scalarTy);
}
@@ -1137,7 +1137,7 @@ static Value vectorizeOperand(Value operand, VectorizationState &state) {
// An vector operand that is not in the replacement map should never reach
// this point. Reaching this point could mean that the code was already
// vectorized and we shouldn't try to vectorize already vectorized code.
- assert(!operand.getType().isa<VectorType>() &&
+ assert(!isa<VectorType>(operand.getType()) &&
"Vector op not found in replacement map");
// Vectorize constant.
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index 94203ec942749..01c7c77319f0c 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -1852,7 +1852,7 @@ static void getMultiLevelStrides(const MemRefRegion &region,
int64_t numEltPerStride = 1;
int64_t stride = 1;
for (int d = bufferShape.size() - 1; d >= 1; d--) {
- int64_t dimSize = region.memref.getType().cast<MemRefType>().getDimSize(d);
+ int64_t dimSize = cast<MemRefType>(region.memref.getType()).getDimSize(d);
stride *= dimSize;
numEltPerStride *= bufferShape[d];
// A stride is needed only if the region has a shorter extent than the
@@ -1891,7 +1891,7 @@ generatePointWiseCopy(Location loc, Value memref, Value fastMemRef,
return ubMap.getNumInputs() == ubOperands.size();
}));
- unsigned rank = memref.getType().cast<MemRefType>().getRank();
+ unsigned rank = cast<MemRefType>(memref.getType()).getRank();
assert(lbMaps.size() == rank && "wrong number of lb maps");
assert(ubMaps.size() == rank && "wrong number of ub maps");
@@ -2003,7 +2003,7 @@ static LogicalResult generateCopy(
auto loc = region.loc;
auto memref = region.memref;
- auto memRefType = memref.getType().cast<MemRefType>();
+ auto memRefType = cast<MemRefType>(memref.getType());
if (!memRefType.getLayout().isIdentity()) {
LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n");
@@ -2276,7 +2276,7 @@ static bool getFullMemRefAsRegion(Operation *op, unsigned numParamLoopIVs,
assert(false && "expected load or store op");
return false;
}
- auto memRefType = region->memref.getType().cast<MemRefType>();
+ auto memRefType = cast<MemRefType>(region->memref.getType());
if (!memRefType.hasStaticShape())
return false;
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index e454567c9213a..4e02b612b9bfe 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -1119,9 +1119,9 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
ArrayRef<Value> extraIndices, AffineMap indexRemap,
ArrayRef<Value> extraOperands, ArrayRef<Value> symbolOperands,
bool allowNonDereferencingOps) {
- unsigned newMemRefRank = newMemRef.getType().cast<MemRefType>().getRank();
+ unsigned newMemRefRank = cast<MemRefType>(newMemRef.getType()).getRank();
(void)newMemRefRank; // unused in opt mode
- unsigned oldMemRefRank = oldMemRef.getType().cast<MemRefType>().getRank();
+ unsigned oldMemRefRank = cast<MemRefType>(oldMemRef.getType()).getRank();
(void)oldMemRefRank; // unused in opt mode
if (indexRemap) {
assert(indexRemap.getNumSymbols() == symbolOperands.size() &&
@@ -1134,8 +1134,8 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
}
// Assert same elemental type.
- assert(oldMemRef.getType().cast<MemRefType>().getElementType() ==
- newMemRef.getType().cast<MemRefType>().getElementType());
+ assert(cast<MemRefType>(oldMemRef.getType()).getElementType() ==
+ cast<MemRefType>(newMemRef.getType()).getElementType());
SmallVector<unsigned, 2> usePositions;
for (const auto &opEntry : llvm::enumerate(op->getOperands())) {
@@ -1172,7 +1172,7 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
// Perform index rewrites for the dereferencing op and then replace the op
NamedAttribute oldMapAttrPair =
affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef);
- AffineMap oldMap = oldMapAttrPair.getValue().cast<AffineMapAttr>().getValue();
+ AffineMap oldMap = cast<AffineMapAttr>(oldMapAttrPair.getValue()).getValue();
unsigned oldMapNumInputs = oldMap.getNumInputs();
SmallVector<Value, 4> oldMapOperands(
op->operand_begin() + memRefOperandPos + 1,
@@ -1294,9 +1294,9 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
ArrayRef<Value> symbolOperands, Operation *domOpFilter,
Operation *postDomOpFilter, bool allowNonDereferencingOps,
bool replaceInDeallocOp) {
- unsigned newMemRefRank = newMemRef.getType().cast<MemRefType>().getRank();
+ unsigned newMemRefRank = cast<MemRefType>(newMemRef.getType()).getRank();
(void)newMemRefRank; // unused in opt mode
- unsigned oldMemRefRank = oldMemRef.getType().cast<MemRefType>().getRank();
+ unsigned oldMemRefRank = cast<MemRefType>(oldMemRef.getType()).getRank();
(void)oldMemRefRank;
if (indexRemap) {
assert(indexRemap.getNumSymbols() == symbolOperands.size() &&
@@ -1309,8 +1309,8 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
}
// Assert same elemental type.
- assert(oldMemRef.getType().cast<MemRefType>().getElementType() ==
- newMemRef.getType().cast<MemRefType>().getElementType());
+ assert(cast<MemRefType>(oldMemRef.getType()).getElementType() ==
+ cast<MemRefType>(newMemRef.getType()).getElementType());
std::unique_ptr<DominanceInfo> domInfo;
std::unique_ptr<PostDominanceInfo> postDomInfo;
@@ -1734,7 +1734,7 @@ LogicalResult mlir::affine::normalizeMemRef(memref::AllocOp *allocOp) {
SmallVector<std::tuple<AffineExpr, unsigned, unsigned>> tileSizePos;
(void)getTileSizePos(layoutMap, tileSizePos);
if (newMemRefType.getNumDynamicDims() > 0 && !tileSizePos.empty()) {
- MemRefType oldMemRefType = oldMemRef.getType().cast<MemRefType>();
+ MemRefType oldMemRefType = cast<MemRefType>(oldMemRef.getType());
SmallVector<Value, 4> newDynamicSizes;
createNewDynamicSizes(oldMemRefType, newMemRefType, layoutMap, allocOp, b,
newDynamicSizes);
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index 9602d530cf826..85e07253c488f 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -34,7 +34,7 @@ struct ConstantOpInterface
return constantOp->emitError("could not infer memory space");
// Only ranked tensors are supported.
- if (!constantOp.getType().isa<RankedTensorType>())
+ if (!isa<RankedTensorType>(constantOp.getType()))
return failure();
// Only constants inside a module are supported.
@@ -58,7 +58,7 @@ struct ConstantOpInterface
bool isWritable(Operation *op, Value value,
const AnalysisState &state) const {
// Memory locations returned by memref::GetGlobalOp may not be written to.
- assert(value.isa<OpResult>());
+ assert(isa<OpResult>(value));
return false;
}
};
@@ -84,21 +84,21 @@ struct IndexCastOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
auto castOp = cast<arith::IndexCastOp>(op);
- auto resultTensorType = castOp.getType().cast<TensorType>();
+ auto resultTensorType = cast<TensorType>(castOp.getType());
FailureOr<Value> source = getBuffer(rewriter, castOp.getIn(), options);
if (failed(source))
return failure();
- auto sourceType = source->getType().cast<BaseMemRefType>();
+ auto sourceType = cast<BaseMemRefType>(source->getType());
// Result type should have same layout and address space as the source type.
BaseMemRefType resultType;
- if (auto rankedMemRefType = sourceType.dyn_cast<MemRefType>()) {
+ if (auto rankedMemRefType = dyn_cast<MemRefType>(sourceType)) {
resultType = MemRefType::get(
rankedMemRefType.getShape(), resultTensorType.getElementType(),
rankedMemRefType.getLayout(), rankedMemRefType.getMemorySpace());
} else {
- auto unrankedMemrefType = sourceType.cast<UnrankedMemRefType>();
+ auto unrankedMemrefType = cast<UnrankedMemRefType>(sourceType);
resultType = UnrankedMemRefType::get(resultTensorType.getElementType(),
unrankedMemrefType.getMemorySpace());
}
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
index 22ec425b4730c..1a50b4ad5598f 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
@@ -63,10 +63,10 @@ static Value createScalarOrSplatConstant(ConversionPatternRewriter &rewriter,
Location loc, Type type,
const APInt &value) {
TypedAttr attr;
- if (auto intTy = type.dyn_cast<IntegerType>()) {
+ if (auto intTy = dyn_cast<IntegerType>(type)) {
attr = rewriter.getIntegerAttr(type, value);
} else {
- auto vecTy = type.cast<VectorType>();
+ auto vecTy = cast<VectorType>(type);
attr = SplatElementsAttr::get(vecTy, value);
}
@@ -78,10 +78,10 @@ static Value createScalarOrSplatConstant(ConversionPatternRewriter &rewriter,
Location loc, Type type,
int64_t value) {
unsigned elementBitWidth = 0;
- if (auto intTy = type.dyn_cast<IntegerType>())
+ if (auto intTy = dyn_cast<IntegerType>(type))
elementBitWidth = intTy.getWidth();
else
- elementBitWidth = type.cast<VectorType>().getElementTypeBitWidth();
+ elementBitWidth = cast<VectorType>(type).getElementTypeBitWidth();
return createScalarOrSplatConstant(rewriter, loc, type,
APInt(elementBitWidth, value));
@@ -95,7 +95,7 @@ static Value createScalarOrSplatConstant(ConversionPatternRewriter &rewriter,
static Value extractLastDimSlice(ConversionPatternRewriter &rewriter,
Location loc, Value input,
int64_t lastOffset) {
- ArrayRef<int64_t> shape = input.getType().cast<VectorType>().getShape();
+ ArrayRef<int64_t> shape = cast<VectorType>(input.getType()).getShape();
assert(lastOffset < shape.back() && "Offset out of bounds");
// Scalarize the result in case of 1D vectors.
@@ -125,7 +125,7 @@ extractLastDimHalves(ConversionPatternRewriter &rewriter, Location loc,
// `input` is a scalar, this is a noop.
static Value dropTrailingX1Dim(ConversionPatternRewriter &rewriter,
Location loc, Value input) {
- auto vecTy = input.getType().dyn_cast<VectorType>();
+ auto vecTy = dyn_cast<VectorType>(input.getType());
if (!vecTy)
return input;
@@ -142,7 +142,7 @@ static Value dropTrailingX1Dim(ConversionPatternRewriter &rewriter,
/// `input` is a scalar, this is a noop.
static Value appendX1Dim(ConversionPatternRewriter &rewriter, Location loc,
Value input) {
- auto vecTy = input.getType().dyn_cast<VectorType>();
+ auto vecTy = dyn_cast<VectorType>(input.getType());
if (!vecTy)
return input;
@@ -159,11 +159,11 @@ static Value appendX1Dim(ConversionPatternRewriter &rewriter, Location loc,
static Value insertLastDimSlice(ConversionPatternRewriter &rewriter,
Location loc, Value source, Value dest,
int64_t lastOffset) {
- ArrayRef<int64_t> shape = dest.getType().cast<VectorType>().getShape();
+ ArrayRef<int64_t> shape = cast<VectorType>(dest.getType()).getShape();
assert(lastOffset < shape.back() && "Offset out of bounds");
// Handle scalar source.
- if (source.getType().isa<IntegerType>())
+ if (isa<IntegerType>(source.getType()))
return rewriter.create<vector::InsertOp>(loc, source, dest, lastOffset);
SmallVector<int64_t> offsets(shape.size(), 0);
@@ -215,14 +215,14 @@ struct ConvertConstant final : OpConversionPattern<arith::ConstantOp> {
unsigned newBitWidth = newType.getElementTypeBitWidth();
Attribute oldValue = op.getValueAttr();
- if (auto intAttr = oldValue.dyn_cast<IntegerAttr>()) {
+ if (auto intAttr = dyn_cast<IntegerAttr>(oldValue)) {
auto [low, high] = getHalves(intAttr.getValue(), newBitWidth);
auto newAttr = DenseElementsAttr::get(newType, {low, high});
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
return success();
}
- if (auto splatAttr = oldValue.dyn_cast<SplatElementsAttr>()) {
+ if (auto splatAttr = dyn_cast<SplatElementsAttr>(oldValue)) {
auto [low, high] =
getHalves(splatAttr.getSplatValue<APInt>(), newBitWidth);
int64_t numSplatElems = splatAttr.getNumElements();
@@ -238,7 +238,7 @@ struct ConvertConstant final : OpConversionPattern<arith::ConstantOp> {
return success();
}
- if (auto elemsAttr = oldValue.dyn_cast<DenseElementsAttr>()) {
+ if (auto elemsAttr = dyn_cast<DenseElementsAttr>(oldValue)) {
int64_t numElems = elemsAttr.getNumElements();
SmallVector<APInt> values;
values.reserve(numElems * 2);
@@ -527,9 +527,8 @@ struct ConvertMaxMin final : OpConversionPattern<SourceOp> {
Location loc = op->getLoc();
Type oldTy = op.getType();
- auto newTy = this->getTypeConverter()
- ->convertType(oldTy)
- .template dyn_cast_or_null<VectorType>();
+ auto newTy = dyn_cast_or_null<VectorType>(
+ this->getTypeConverter()->convertType(oldTy));
if (!newTy)
return rewriter.notifyMatchFailure(
loc, llvm::formatv("unsupported type: {0}", op.getType()));
@@ -549,11 +548,11 @@ struct ConvertMaxMin final : OpConversionPattern<SourceOp> {
/// Returns true iff the type is `index` or `vector<...index>`.
static bool isIndexOrIndexVector(Type type) {
- if (type.isa<IndexType>())
+ if (isa<IndexType>(type))
return true;
- if (auto vectorTy = type.dyn_cast<VectorType>())
- if (vectorTy.getElementType().isa<IndexType>())
+ if (auto vectorTy = dyn_cast<VectorType>(type))
+ if (isa<IndexType>(vectorTy.getElementType()))
return true;
return false;
@@ -610,7 +609,7 @@ struct ConvertIndexCastIndexToInt final : OpConversionPattern<CastOp> {
// Emit an index cast over the matching narrow type.
Type narrowTy =
rewriter.getIntegerType(typeConverter->getMaxTargetIntBitWidth());
- if (auto vecTy = resultType.dyn_cast<VectorType>())
+ if (auto vecTy = dyn_cast<VectorType>(resultType))
narrowTy = VectorType::get(vecTy.getShape(), narrowTy);
// Sign or zero-extend the result. Let the matching conversion pattern
@@ -1116,7 +1115,7 @@ arith::WideIntEmulationConverter::WideIntEmulationConverter(
// Vector case.
addConversion([this](VectorType ty) -> std::optional<Type> {
- auto intTy = ty.getElementType().dyn_cast<IntegerType>();
+ auto intTy = dyn_cast<IntegerType>(ty.getElementType());
if (!intTy)
return ty;
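Two patterns in this file are worth a note. `ConvertConstant` tests `SplatElementsAttr` before `DenseElementsAttr`; a splat dense attribute also satisfies the `DenseElementsAttr` check, so the more specific case has to come first. `ConvertMaxMin` uses `dyn_cast_or_null` because the converted type may be null, whereas plain `dyn_cast`, like LLVM's, asserts on a null input. Below is a standalone sketch with hypothetical attribute classes (not MLIR's) showing both behaviours.

```cpp
#include <cassert>
#include <string>

// Hypothetical attribute kinds; "splat" marks a dense attribute with one repeated value.
enum class AttrKind { None, Integer, Dense };

class Attribute {
public:
  Attribute() = default;
  explicit Attribute(AttrKind k, bool splat = false) : kind(k), isSplatValue(splat) {}
  AttrKind getKind() const { return kind; }
  bool isSplat() const { return isSplatValue; }
  explicit operator bool() const { return kind != AttrKind::None; }

private:
  AttrKind kind = AttrKind::None;
  bool isSplatValue = false;
};

class IntegerAttr : public Attribute {
public:
  IntegerAttr() = default;
  explicit IntegerAttr(Attribute a) : Attribute(a) {}
  static bool classof(Attribute a) { return a.getKind() == AttrKind::Integer; }
};

class DenseElementsAttr : public Attribute {
public:
  DenseElementsAttr() = default;
  explicit DenseElementsAttr(Attribute a) : Attribute(a) {}
  static bool classof(Attribute a) { return a.getKind() == AttrKind::Dense; }
};

// A refinement: every splat also passes DenseElementsAttr::classof.
class SplatElementsAttr : public DenseElementsAttr {
public:
  SplatElementsAttr() = default;
  explicit SplatElementsAttr(Attribute a) : DenseElementsAttr(a) {}
  static bool classof(Attribute a) {
    return DenseElementsAttr::classof(a) && a.isSplat();
  }
};

// Like llvm::dyn_cast, reject a null input outright.
template <typename To> To dyn_cast(Attribute a) {
  assert(a && "dyn_cast on a null attribute");
  return To::classof(a) ? To(a) : To();
}

// Like llvm::dyn_cast_or_null, a null input simply yields a null result.
template <typename To> To dyn_cast_or_null(Attribute a) {
  if (!a)
    return To();
  return dyn_cast<To>(a);
}

// Mirrors the order in ConvertConstant: splat is tested before dense.
// Swapping the last two checks would report every splat as "dense".
std::string classifyConstant(Attribute value) {
  if (dyn_cast<IntegerAttr>(value))
    return "integer";
  if (dyn_cast<SplatElementsAttr>(value))
    return "splat";
  if (dyn_cast<DenseElementsAttr>(value))
    return "dense";
  return "unsupported";
}
```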
diff --git a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
index 787d4989bbabb..8eddd811dbea4 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
@@ -86,12 +86,12 @@ reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
continue;
}
- assert(value.getType().cast<ShapedType>().isDynamicDim(*dim) &&
+ assert(cast<ShapedType>(value.getType()).isDynamicDim(*dim) &&
"expected dynamic dim");
- if (value.getType().isa<RankedTensorType>()) {
+ if (isa<RankedTensorType>(value.getType())) {
// A tensor dimension is used: generate a tensor.dim.
operands.push_back(b.create<tensor::DimOp>(loc, value, *dim));
- } else if (value.getType().isa<MemRefType>()) {
+ } else if (isa<MemRefType>(value.getType())) {
// A memref dimension is used: generate a memref.dim.
operands.push_back(b.create<memref::DimOp>(loc, value, *dim));
} else {
diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
index 45a4bf74b9154..fb363c82a069f 100644
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -58,7 +58,7 @@ Value mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
OpFoldResult ofr) {
if (auto value = ofr.dyn_cast<Value>())
return value;
- auto attr = ofr.dyn_cast<Attribute>().dyn_cast<IntegerAttr>();
+ auto attr = dyn_cast<IntegerAttr>(ofr.dyn_cast<Attribute>());
assert(attr && "expect the op fold result casts to an integer attribute");
return b.create<arith::ConstantIndexOp>(loc, attr.getValue().getSExtValue());
}
@@ -73,8 +73,8 @@ Value mlir::getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc,
if (targetIsIndex ^ valueIsIndex)
return b.create<arith::IndexCastOp>(loc, targetType, value);
- auto targetIntegerType = targetType.dyn_cast<IntegerType>();
- auto valueIntegerType = value.getType().dyn_cast<IntegerType>();
+ auto targetIntegerType = dyn_cast<IntegerType>(targetType);
+ auto valueIntegerType = dyn_cast<IntegerType>(value.getType());
assert(targetIntegerType && valueIntegerType &&
"unexpected cast between types other than integers and index");
assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());
@@ -88,9 +88,9 @@ Value mlir::convertScalarToDtype(OpBuilder &b, Location loc, Value operand,
Type toType, bool isUnsignedCast) {
if (operand.getType() == toType)
return operand;
- if (auto toIntType = toType.dyn_cast<IntegerType>()) {
+ if (auto toIntType = dyn_cast<IntegerType>(toType)) {
// If operand is floating point, cast directly to the int type.
- if (operand.getType().isa<FloatType>()) {
+ if (isa<FloatType>(operand.getType())) {
if (isUnsignedCast)
return b.create<arith::FPToUIOp>(loc, toType, operand);
return b.create<arith::FPToSIOp>(loc, toType, operand);
@@ -98,7 +98,7 @@ Value mlir::convertScalarToDtype(OpBuilder &b, Location loc, Value operand,
// Cast index operands directly to the int type.
if (operand.getType().isIndex())
return b.create<arith::IndexCastOp>(loc, toType, operand);
- if (auto fromIntType = operand.getType().dyn_cast<IntegerType>()) {
+ if (auto fromIntType = dyn_cast<IntegerType>(operand.getType())) {
// Either extend or truncate.
if (toIntType.getWidth() > fromIntType.getWidth()) {
if (isUnsignedCast)
@@ -108,15 +108,15 @@ Value mlir::convertScalarToDtype(OpBuilder &b, Location loc, Value operand,
if (toIntType.getWidth() < fromIntType.getWidth())
return b.create<arith::TruncIOp>(loc, toType, operand);
}
- } else if (auto toFloatType = toType.dyn_cast<FloatType>()) {
+ } else if (auto toFloatType = dyn_cast<FloatType>(toType)) {
// If operand is integer, cast directly to the float type.
// Note that it is unclear how to cast from BF16<->FP16.
- if (operand.getType().isa<IntegerType>()) {
+ if (isa<IntegerType>(operand.getType())) {
if (isUnsignedCast)
return b.create<arith::UIToFPOp>(loc, toFloatType, operand);
return b.create<arith::SIToFPOp>(loc, toFloatType, operand);
}
- if (auto fromFloatType = operand.getType().dyn_cast<FloatType>()) {
+ if (auto fromFloatType = dyn_cast<FloatType>(operand.getType())) {
if (toFloatType.getWidth() > fromFloatType.getWidth())
return b.create<arith::ExtFOp>(loc, toFloatType, operand);
if (toFloatType.getWidth() < fromFloatType.getWidth())
@@ -141,27 +141,27 @@ Value ArithBuilder::_and(Value lhs, Value rhs) {
return b.create<arith::AndIOp>(loc, lhs, rhs);
}
Value ArithBuilder::add(Value lhs, Value rhs) {
- if (lhs.getType().isa<FloatType>())
+ if (isa<FloatType>(lhs.getType()))
return b.create<arith::AddFOp>(loc, lhs, rhs);
return b.create<arith::AddIOp>(loc, lhs, rhs);
}
Value ArithBuilder::sub(Value lhs, Value rhs) {
- if (lhs.getType().isa<FloatType>())
+ if (isa<FloatType>(lhs.getType()))
return b.create<arith::SubFOp>(loc, lhs, rhs);
return b.create<arith::SubIOp>(loc, lhs, rhs);
}
Value ArithBuilder::mul(Value lhs, Value rhs) {
- if (lhs.getType().isa<FloatType>())
+ if (isa<FloatType>(lhs.getType()))
return b.create<arith::MulFOp>(loc, lhs, rhs);
return b.create<arith::MulIOp>(loc, lhs, rhs);
}
Value ArithBuilder::sgt(Value lhs, Value rhs) {
- if (lhs.getType().isa<FloatType>())
+ if (isa<FloatType>(lhs.getType()))
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT, lhs, rhs);
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, lhs, rhs);
}
Value ArithBuilder::slt(Value lhs, Value rhs) {
- if (lhs.getType().isa<FloatType>())
+ if (isa<FloatType>(lhs.getType()))
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT, lhs, rhs);
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, lhs, rhs);
}
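`convertScalarToDtype` picks a cast op from the source and target kinds and bit widths. The sketch below restates the branches visible in these hunks as a pure function over a hypothetical `ScalarType` model, returning op mnemonics instead of building ops. Two lines are cut off by the hunk boundaries and are assumed here to create `arith.extui` and `arith.truncf`. The `"unsupported"` fallback belongs to this sketch only, since the function's tail is not part of the diff.

```cpp
#include <cassert>
#include <string>

// Hypothetical scalar type model: an integer or float of some bit width, or index.
enum class ScalarKind { Integer, Float, Index };

struct ScalarType {
  ScalarKind kind;
  unsigned width; // 0 for index
};

ScalarType intTy(unsigned w) { return {ScalarKind::Integer, w}; }
ScalarType floatTy(unsigned w) { return {ScalarKind::Float, w}; }
ScalarType indexTy() { return {ScalarKind::Index, 0}; }

// Equal kind and width count as the same type. This model cannot tell
// apart distinct float formats of one width (e.g. bf16 vs f16).
bool sameType(ScalarType a, ScalarType b) {
  return a.kind == b.kind && a.width == b.width;
}

// Returns the arith op the branches in the hunks would create; "" means no-op.
std::string castOpFor(ScalarType from, ScalarType to, bool isUnsignedCast) {
  if (sameType(from, to))
    return "";
  if (to.kind == ScalarKind::Integer) {
    if (from.kind == ScalarKind::Float)
      return isUnsignedCast ? "arith.fptoui" : "arith.fptosi";
    if (from.kind == ScalarKind::Index)
      return "arith.index_cast";
    if (from.kind == ScalarKind::Integer) {
      if (to.width > from.width)
        return isUnsignedCast ? "arith.extui" : "arith.extsi";
      if (to.width < from.width)
        return "arith.trunci";
    }
  } else if (to.kind == ScalarKind::Float) {
    if (from.kind == ScalarKind::Integer)
      return isUnsignedCast ? "arith.uitofp" : "arith.sitofp";
    if (from.kind == ScalarKind::Float) {
      if (to.width > from.width)
        return "arith.extf";
      if (to.width < from.width)
        return "arith.truncf";
    }
  }
  return "unsupported"; // this sketch's own fallback
}
```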
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
index 7db078ad3f0a5..04f131ec51cb4 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
@@ -528,9 +528,9 @@ void AsyncRuntimePolicyBasedRefCountingPass::initializeDefaultPolicy() {
Operation *op = operand.getOwner();
Type type = operand.get().getType();
- bool isToken = type.isa<TokenType>();
- bool isGroup = type.isa<GroupType>();
- bool isValue = type.isa<ValueType>();
+ bool isToken = isa<TokenType>(type);
+ bool isGroup = isa<GroupType>(type);
+ bool isValue = isa<ValueType>(type);
// Drop reference after async token or group error check (coro await).
if (auto await = dyn_cast<RuntimeIsErrorOp>(op))
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
index 25cb61857a10e..db7550d7d99fa 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
@@ -161,7 +161,7 @@ static CoroMachinery setupCoroMachinery(func::FuncOp func) {
// We treat TokenType as state update marker to represent side-effects of
// async computations
- bool isStateful = func.getCallableResults().front().isa<TokenType>();
+ bool isStateful = isa<TokenType>(func.getCallableResults().front());
std::optional<Value> retToken;
if (isStateful)
@@ -535,7 +535,7 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
ConversionPatternRewriter &rewriter) const override {
// We can only await on one the `AwaitableType` (for `await` it can be
// a `token` or a `value`, for `await_all` it must be a `group`).
- if (!op.getOperand().getType().template isa<AwaitableType>())
+ if (!isa<AwaitableType>(op.getOperand().getType()))
return rewriter.notifyMatchFailure(op, "unsupported awaitable type");
// Check if await operation is inside the coroutine function.
@@ -646,7 +646,7 @@ class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> {
getReplacementValue(AwaitOp op, Value operand,
ConversionPatternRewriter &rewriter) const override {
// Load from the async value storage.
- auto valueType = operand.getType().cast<ValueType>().getValueType();
+ auto valueType = cast<ValueType>(operand.getType()).getValueType();
return rewriter.create<RuntimeLoadOp>(op->getLoc(), valueType, operand);
}
};
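Several removed lines spelled the member call as `.template isa<...>()` or `.template dyn_cast_or_null<...>()`. That keyword disambiguates a member template when the object expression's type depends on a template parameter, as in `AwaitOpLoweringBase<AwaitType, AwaitableType>`. The free function needs no such keyword. Below is a standalone sketch with hypothetical types that puts both spellings side by side.

```cpp
#include <cassert>

enum class Kind { Token, Group, Value };

// Hypothetical handle with a member-template query, like the old Type::isa<T>().
struct Type {
  Kind kind;
  template <typename T> bool isa() const { return T::classof(*this); }
};

struct TokenType { static bool classof(const Type &t) { return t.kind == Kind::Token; } };
struct GroupType { static bool classof(const Type &t) { return t.kind == Kind::Group; } };
struct ValueType { static bool classof(const Type &t) { return t.kind == Kind::Value; } };

// Free-function query in the style of llvm::isa<T>(value).
template <typename T, typename From> bool isa(const From &v) {
  return T::classof(v);
}

// Hypothetical stand-in for an await op with a single operand.
struct AwaitOp {
  Type operandType;
  Type getOperandType() const { return operandType; }
};

AwaitOp makeAwait(Kind k) { return AwaitOp{Type{k}}; }

// OpT is a template parameter, so op.getOperandType() is type-dependent and
// the member template is written with the `template` disambiguator.
template <typename OpT, typename AwaitableType>
bool canAwaitMember(const OpT &op) {
  return op.getOperandType().template isa<AwaitableType>();
}

// The free function reads the same in dependent and non-dependent code.
template <typename OpT, typename AwaitableType>
bool canAwaitFree(const OpT &op) {
  return isa<AwaitableType>(op.getOperandType());
}
```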
diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
index ed95a62b9b6f8..5e36b55cff840 100644
--- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
@@ -59,7 +59,7 @@ transform::OneShotBufferizeOp::apply(TransformResults &transformResults,
// This transform op is currently restricted to ModuleOps and function ops.
// Such ops are modified in-place.
- transformResults.set(getTransformed().cast<OpResult>(), payloadOps);
+ transformResults.set(cast<OpResult>(getTransformed()), payloadOps);
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
index cf51aa58a93a9..b813b2425bdd5 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
@@ -280,7 +280,7 @@ class BufferDeallocation : public BufferPlacementTransformationBase {
// defined in a non-dominated block or it is defined in the same block
// but the current value is not dominated by the source value.
if (!dominators.dominates(definingBlock, parentBlock) ||
- (definingBlock == parentBlock && value.isa<BlockArgument>())) {
+ (definingBlock == parentBlock && isa<BlockArgument>(value))) {
toProcess.emplace_back(value, parentBlock);
valuesToFree.insert(value);
} else if (visitedValues.insert(std::make_tuple(value, definingBlock))
@@ -307,8 +307,8 @@ class BufferDeallocation : public BufferPlacementTransformationBase {
// Add new allocs and additional clone operations.
for (Value value : valuesToFree) {
- if (failed(value.isa<BlockArgument>()
- ? introduceBlockArgCopy(value.cast<BlockArgument>())
+ if (failed(isa<BlockArgument>(value)
+ ? introduceBlockArgCopy(cast<BlockArgument>(value))
: introduceValueCopyForRegionResult(value)))
return failure();
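Every `Value` is either a `BlockArgument` or an `OpResult`. That is why code such as the `introduceBlockArgCopy` call above (and `annotateConflict` further down) can test for one kind and `cast` to the other in the `else` branch. Below is a standalone sketch of that exhaustive dispatch. It uses hypothetical `Value`/`BlockArgument`/`OpResult` stand-ins rather than MLIR's classes.

```cpp
#include <cassert>
#include <string>

// Hypothetical SSA value handle: a block argument or an op result, plus a number.
class Value {
public:
  enum class Kind { None, BlockArgument, OpResult };
  Value() = default;
  Value(Kind k, unsigned n) : kind(k), number(n) {}
  Kind getKind() const { return kind; }
  unsigned getNumber() const { return number; }
  explicit operator bool() const { return kind != Kind::None; }

private:
  Kind kind = Kind::None;
  unsigned number = 0;
};

class BlockArgument : public Value {
public:
  BlockArgument() = default;
  explicit BlockArgument(Value v) : Value(v) {}
  unsigned getArgNumber() const { return getNumber(); }
  static bool classof(Value v) { return v.getKind() == Value::Kind::BlockArgument; }
};

class OpResult : public Value {
public:
  OpResult() = default;
  explicit OpResult(Value v) : Value(v) {}
  unsigned getResultNumber() const { return getNumber(); }
  static bool classof(Value v) { return v.getKind() == Value::Kind::OpResult; }
};

template <typename To> bool isa(Value v) {
  assert(v && "isa<> used on a null value");
  return To::classof(v);
}

template <typename To> To cast(Value v) {
  assert(isa<To>(v) && "cast<> argument of incompatible kind");
  return To(v);
}

template <typename To> To dyn_cast(Value v) {
  return isa<To>(v) ? To(v) : To();
}

// Two kinds only, so dyn_cast to one plus cast to the other is exhaustive.
std::string describeDefinition(Value definition) {
  if (auto opResult = dyn_cast<OpResult>(definition))
    return "result " + std::to_string(opResult.getResultNumber());
  auto bbArg = cast<BlockArgument>(definition);
  return "bbArg " + std::to_string(bbArg.getArgNumber());
}
```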
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp
index 83b2ef6a6dac7..278664abcf49d 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp
@@ -43,7 +43,7 @@ static bool isKnownControlFlowInterface(Operation *op) {
/// exceed the stack space.
static bool defaultIsSmallAlloc(Value alloc, unsigned maximumSizeInBytes,
unsigned maxRankOfAllocatedMemRef) {
- auto type = alloc.getType().dyn_cast<ShapedType>();
+ auto type = dyn_cast<ShapedType>(alloc.getType());
if (!type || !alloc.getDefiningOp<memref::AllocOp>())
return false;
if (!type.hasStaticShape()) {
@@ -355,7 +355,7 @@ class BufferPlacementPromotion : BufferPlacementTransformationBase {
OpBuilder builder(startOperation);
Operation *allocOp = alloc.getDefiningOp();
Operation *alloca = builder.create<memref::AllocaOp>(
- alloc.getLoc(), alloc.getType().cast<MemRefType>(),
+ alloc.getLoc(), cast<MemRefType>(alloc.getType()),
allocOp->getOperands(), allocOp->getAttrs());
// Replace the original alloc by a newly created alloca.
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index 7b63335d43fb6..dd359c2dcca5d 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -53,7 +53,7 @@ updateFuncOp(func::FuncOp func,
SmallVector<Type, 6> erasedResultTypes;
BitVector erasedResultIndices(functionType.getNumResults());
for (const auto &resultType : llvm::enumerate(functionType.getResults())) {
- if (auto memrefType = resultType.value().dyn_cast<MemRefType>()) {
+ if (auto memrefType = dyn_cast<MemRefType>(resultType.value())) {
if (!hasStaticIdentityLayout(memrefType) &&
!hasFullyDynamicLayoutMap(memrefType)) {
// Only buffers with static identity layout can be allocated. These can
@@ -103,7 +103,7 @@ static void updateReturnOps(func::FuncOp func,
SmallVector<Value, 6> copyIntoOutParams;
SmallVector<Value, 6> keepAsReturnOperands;
for (Value operand : op.getOperands()) {
- if (operand.getType().isa<MemRefType>())
+ if (isa<MemRefType>(operand.getType()))
copyIntoOutParams.push_back(operand);
else
keepAsReturnOperands.push_back(operand);
@@ -137,7 +137,7 @@ updateCalls(ModuleOp module,
SmallVector<Value, 6> replaceWithNewCallResults;
SmallVector<Value, 6> replaceWithOutParams;
for (OpResult result : op.getResults()) {
- if (result.getType().isa<MemRefType>())
+ if (isa<MemRefType>(result.getType()))
replaceWithOutParams.push_back(result);
else
replaceWithNewCallResults.push_back(result);
@@ -145,13 +145,13 @@ updateCalls(ModuleOp module,
SmallVector<Value, 6> outParams;
OpBuilder builder(op);
for (Value memref : replaceWithOutParams) {
- if (!memref.getType().cast<MemRefType>().hasStaticShape()) {
+ if (!cast<MemRefType>(memref.getType()).hasStaticShape()) {
op.emitError()
<< "cannot create out param for dynamically shaped result";
didFail = true;
return;
}
- auto memrefType = memref.getType().cast<MemRefType>();
+ auto memrefType = cast<MemRefType>(memref.getType());
auto allocType =
MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
AffineMap(), memrefType.getMemorySpace());
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
index b9776e2fb2095..f8231cac778af 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
@@ -68,7 +68,7 @@ void BufferPlacementAllocs::build(Operation *op) {
[=](MemoryEffects::EffectInstance &it) {
Value value = it.getValue();
return isa<MemoryEffects::Allocate>(it.getEffect()) && value &&
- value.isa<OpResult>() &&
+ isa<OpResult>(value) &&
it.getResource() !=
SideEffects::AutomaticAllocationScopeResource::get();
});
@@ -149,7 +149,7 @@ bool BufferPlacementTransformationBase::isLoop(Operation *op) {
FailureOr<memref::GlobalOp>
bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment,
Attribute memorySpace) {
- auto type = constantOp.getType().cast<RankedTensorType>();
+ auto type = cast<RankedTensorType>(constantOp.getType());
auto moduleOp = constantOp->getParentOfType<ModuleOp>();
if (!moduleOp)
return failure();
@@ -185,14 +185,14 @@ bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment,
: IntegerAttr();
BufferizeTypeConverter typeConverter;
- auto memrefType = typeConverter.convertType(type).cast<MemRefType>();
+ auto memrefType = cast<MemRefType>(typeConverter.convertType(type));
if (memorySpace)
memrefType = MemRefType::Builder(memrefType).setMemorySpace(memorySpace);
auto global = globalBuilder.create<memref::GlobalOp>(
constantOp.getLoc(), (Twine("__constant_") + os.str()).str(),
/*sym_visibility=*/globalBuilder.getStringAttr("private"),
/*type=*/memrefType,
- /*initial_value=*/constantOp.getValue().cast<ElementsAttr>(),
+ /*initial_value=*/cast<ElementsAttr>(constantOp.getValue()),
/*constant=*/true,
/*alignment=*/memrefAlignment);
symbolTable.insert(global);
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 4eabfccf2514b..24aaff0e48826 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -44,7 +44,7 @@ using namespace mlir::bufferization;
static Value materializeToTensor(OpBuilder &builder, TensorType type,
ValueRange inputs, Location loc) {
assert(inputs.size() == 1);
- assert(inputs[0].getType().isa<BaseMemRefType>());
+ assert(isa<BaseMemRefType>(inputs[0].getType()));
return builder.create<bufferization::ToTensorOp>(loc, type, inputs[0]);
}
@@ -66,11 +66,11 @@ BufferizeTypeConverter::BufferizeTypeConverter() {
ValueRange inputs, Location loc) -> Value {
assert(inputs.size() == 1 && "expected exactly one input");
- if (auto inputType = inputs[0].getType().dyn_cast<MemRefType>()) {
+ if (auto inputType = dyn_cast<MemRefType>(inputs[0].getType())) {
// MemRef to MemRef cast.
assert(inputType != type && "expected different types");
// Unranked to ranked and ranked to unranked casts must be explicit.
- auto rankedDestType = type.dyn_cast<MemRefType>();
+ auto rankedDestType = dyn_cast<MemRefType>(type);
if (!rankedDestType)
return nullptr;
FailureOr<Value> replacement =
@@ -80,7 +80,7 @@ BufferizeTypeConverter::BufferizeTypeConverter() {
return *replacement;
}
- if (inputs[0].getType().isa<TensorType>()) {
+ if (isa<TensorType>(inputs[0].getType())) {
// Tensor to MemRef cast.
return builder.create<bufferization::ToMemrefOp>(loc, type, inputs[0]);
}
@@ -222,7 +222,7 @@ struct OneShotBufferizePass
parseLayoutMapOption(unknownTypeConversion);
opt.unknownTypeConverterFn = [=](Value value, Attribute memorySpace,
const BufferizationOptions &options) {
- auto tensorType = value.getType().cast<TensorType>();
+ auto tensorType = cast<TensorType>(value.getType());
if (unknownTypeConversionOption == LayoutMapOption::IdentityLayoutMap)
return bufferization::getMemRefTypeWithStaticIdentityLayout(
tensorType, memorySpace);
@@ -325,7 +325,7 @@ mlir::bufferization::createFinalizingBufferizePass() {
// BufferizableOpInterface-based Bufferization
//===----------------------------------------------------------------------===//
-static bool isaTensor(Type t) { return t.isa<TensorType>(); }
+static bool isaTensor(Type t) { return isa<TensorType>(t); }
/// Return true if the given op has a tensor result or a tensor operand.
static bool hasTensorSemantics(Operation *op) {
@@ -549,7 +549,7 @@ BufferizationOptions bufferization::getPartialBufferizationOptions() {
options.unknownTypeConverterFn = [](Value value, Attribute memorySpace,
const BufferizationOptions &options) {
return getMemRefTypeWithStaticIdentityLayout(
- value.getType().cast<TensorType>(), memorySpace);
+ cast<TensorType>(value.getType()), memorySpace);
};
options.opFilter.allowDialect<BufferizationDialect>();
return options;
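`isaTensor` exists because `llvm::isa` deduces its source type as a template parameter. As a result, `isa<TensorType>` on its own does not name a single function that can be handed to an algorithm such as `any_of`, but a small non-template wrapper does. Below is a sketch with hypothetical types. Its `hasTensorSemantics` is only a guess at the shape of the function declared above, whose body lies outside this hunk.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

enum class Kind { Integer, RankedTensor, UnrankedTensor, MemRef };

struct Type { Kind kind; };

// Hypothetical stand-in for MLIR's TensorType, covering ranked and unranked tensors.
struct TensorType {
  static bool classof(Type t) {
    return t.kind == Kind::RankedTensor || t.kind == Kind::UnrankedTensor;
  }
};

// Like llvm::isa, the source type From is a deduced template parameter, so
// `isa<TensorType>` alone cannot be passed as an algorithm's predicate.
template <typename To, typename From> bool isa(const From &value) {
  return To::classof(value);
}

// A plain function has one fixed signature and can be passed directly.
static bool isaTensor(Type t) { return isa<TensorType>(t); }

// Hypothetical op model: only operand and result types.
struct Operation {
  std::vector<Type> operandTypes;
  std::vector<Type> resultTypes;
};

bool hasTensorSemantics(const Operation &op) {
  bool hasTensorOperand =
      std::any_of(op.operandTypes.begin(), op.operandTypes.end(), isaTensor);
  bool hasTensorResult =
      std::any_of(op.resultTypes.begin(), op.resultTypes.end(), isaTensor);
  return hasTensorOperand || hasTensorResult;
}
```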
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
index 5fc12573912f3..58475d225ce8b 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -33,12 +33,12 @@ neededValuesDominateInsertionPoint(const DominanceInfo &domInfo,
Operation *insertionPoint,
const SmallVector<Value> &neededValues) {
for (Value val : neededValues) {
- if (auto bbArg = val.dyn_cast<BlockArgument>()) {
+ if (auto bbArg = dyn_cast<BlockArgument>(val)) {
Block *owner = bbArg.getOwner();
if (!owner->findAncestorOpInBlock(*insertionPoint))
return false;
} else {
- auto opResult = val.cast<OpResult>();
+ auto opResult = cast<OpResult>(val);
if (!domInfo.dominates(opResult.getOwner(), insertionPoint))
return false;
}
@@ -75,7 +75,7 @@ findValidInsertionPoint(Operation *emptyTensorOp,
// * in case of an OpResult: There must be at least one op right after the
// defining op (the anchor op or one of its
// parents).
- if (auto bbArg = val.dyn_cast<BlockArgument>()) {
+ if (auto bbArg = dyn_cast<BlockArgument>(val)) {
insertionPointCandidates.push_back(
&bbArg.getOwner()->getOperations().front());
} else {
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index bf14e466190b4..f73efc120d377 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -60,7 +60,7 @@ static BaseMemRefType
getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
const BufferizationOptions &options) {
auto tensorType =
- funcOp.getFunctionType().getInput(index).dyn_cast<TensorType>();
+ dyn_cast<TensorType>(funcOp.getFunctionType().getInput(index));
assert(tensorType && "expected TensorType");
BaseMemRefType memrefType = options.functionArgTypeConverterFn(
@@ -71,7 +71,7 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
if (!layoutAttr)
return memrefType;
- auto rankedMemrefType = memrefType.dyn_cast<MemRefType>();
+ auto rankedMemrefType = dyn_cast<MemRefType>(memrefType);
assert(rankedMemrefType && "buffer layout not supported on unranked tensors");
return MemRefType::get(
rankedMemrefType.getShape(), rankedMemrefType.getElementType(),
@@ -224,7 +224,7 @@ struct CallOpInterface
for (const auto &it : llvm::enumerate(callOp.getResultTypes())) {
unsigned returnValIdx = it.index();
Type returnType = it.value();
- if (!returnType.isa<TensorType>()) {
+ if (!isa<TensorType>(returnType)) {
// Non-tensor values are returned.
retValMapping[returnValIdx] = resultTypes.size();
resultTypes.push_back(returnType);
@@ -242,7 +242,7 @@ struct CallOpInterface
Value tensorOperand = opOperand.get();
// Non-tensor operands are just copied.
- if (!tensorOperand.getType().isa<TensorType>()) {
+ if (!isa<TensorType>(tensorOperand.getType())) {
newOperands[idx] = tensorOperand;
continue;
}
@@ -342,7 +342,7 @@ struct FuncOpInterface
SmallVector<Type> argTypes;
for (const auto &it : llvm::enumerate(funcType.getInputs())) {
Type argType = it.value();
- if (auto tensorType = argType.dyn_cast<TensorType>()) {
+ if (auto tensorType = dyn_cast<TensorType>(argType)) {
argTypes.push_back(
getBufferizedFunctionArgType(funcOp, it.index(), options));
continue;
@@ -356,7 +356,7 @@ struct FuncOpInterface
if (funcOp.getBody().empty()) {
SmallVector<Type> retTypes;
for (Type resultType : funcType.getResults()) {
- if (resultType.isa<TensorType>())
+ if (isa<TensorType>(resultType))
return funcOp->emitError() << "cannot bufferize bodiless function "
<< "that returns a tensor";
retTypes.push_back(resultType);
@@ -373,7 +373,7 @@ struct FuncOpInterface
// 1. Rewrite the bbArgs. Turn every tensor bbArg into a memref bbArg.
Block &frontBlock = funcOp.getBody().front();
for (BlockArgument &bbArg : frontBlock.getArguments()) {
- auto tensorType = bbArg.getType().dyn_cast<TensorType>();
+ auto tensorType = dyn_cast<TensorType>(bbArg.getType());
// Non-tensor types stay the same.
if (!tensorType)
continue;
@@ -404,7 +404,7 @@ struct FuncOpInterface
SmallVector<Value> returnValues;
for (OpOperand &returnOperand : returnOp->getOpOperands()) {
Value returnVal = returnOperand.get();
- auto tensorType = returnVal.getType().dyn_cast<TensorType>();
+ auto tensorType = dyn_cast<TensorType>(returnVal.getType());
rewriter.setInsertionPoint(returnOp);
// If not a tensor type just forward it.
@@ -436,7 +436,7 @@ struct FuncOpInterface
bool isWritable(Operation *op, Value value,
const AnalysisState &state) const {
auto funcOp = cast<FuncOp>(op);
- BlockArgument bbArg = value.dyn_cast<BlockArgument>();
+ BlockArgument bbArg = dyn_cast<BlockArgument>(value);
assert(bbArg && "expected BlockArgument");
// "bufferization.writable" overrides other writability decisions. This is
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index db7d4533cd9a3..6da512699cc7b 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -66,7 +66,7 @@ MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::bufferization::OneShotAnalysisState)
using namespace mlir;
using namespace mlir::bufferization;
-static bool isaTensor(Type t) { return t.isa<TensorType>(); }
+static bool isaTensor(Type t) { return isa<TensorType>(t); }
//===----------------------------------------------------------------------===//
// Bufferization-specific attribute manipulation.
@@ -85,11 +85,11 @@ static void setInPlaceOpOperand(OpOperand &opOperand, bool inPlace) {
SmallVector<StringRef> inPlaceVector;
if (auto attr = op->getAttr(kInPlaceOperandsAttrName)) {
inPlaceVector = SmallVector<StringRef>(llvm::to_vector<4>(
- attr.cast<ArrayAttr>().getAsValueRange<StringAttr>()));
+ cast<ArrayAttr>(attr).getAsValueRange<StringAttr>()));
} else {
inPlaceVector = SmallVector<StringRef>(op->getNumOperands(), "none");
for (OpOperand &opOperand : op->getOpOperands())
- if (opOperand.get().getType().isa<TensorType>())
+ if (isa<TensorType>(opOperand.get().getType()))
inPlaceVector[opOperand.getOperandNumber()] = "false";
}
inPlaceVector[opOperand.getOperandNumber()] = inPlace ? "true" : "false";
@@ -107,12 +107,12 @@ OneShotAnalysisState::OneShotAnalysisState(
// Set up alias sets.
op->walk([&](Operation *op) {
for (Value v : op->getResults())
- if (v.getType().isa<TensorType>())
+ if (isa<TensorType>(v.getType()))
createAliasInfoEntry(v);
for (Region &r : op->getRegions())
for (Block &b : r.getBlocks())
for (auto bbArg : b.getArguments())
- if (bbArg.getType().isa<TensorType>())
+ if (isa<TensorType>(bbArg.getType()))
createAliasInfoEntry(bbArg);
});
@@ -121,7 +121,7 @@ OneShotAnalysisState::OneShotAnalysisState(
if (!options.isOpAllowed(bufferizableOp))
return WalkResult::skip();
for (OpOperand &opOperand : bufferizableOp->getOpOperands())
- if (opOperand.get().getType().isa<TensorType>())
+ if (isa<TensorType>(opOperand.get().getType()))
if (bufferizableOp.mustBufferizeInPlace(opOperand, *this))
bufferizeInPlace(opOperand);
return WalkResult::advance();
@@ -187,13 +187,13 @@ void OneShotAnalysisState::gatherYieldedTensors(Operation *op) {
for (OpOperand &returnValOperand : returnOp->getOpOperands()) {
Value returnVal = returnValOperand.get();
// Skip non-tensor values.
- if (!returnVal.getType().isa<TensorType>())
+ if (!isa<TensorType>(returnVal.getType()))
continue;
// Add all aliases of the returned value. But only the ones that are in
// the same block.
applyOnAliases(returnVal, [&](Value v) {
- if (auto bbArg = v.dyn_cast<BlockArgument>()) {
+ if (auto bbArg = dyn_cast<BlockArgument>(v)) {
if (bbArg.getOwner()->getParentOp() == returnOp->getParentOp())
yieldedTensors.insert(bbArg);
return;
@@ -217,7 +217,7 @@ void OneShotAnalysisState::gatherUndefinedTensorUses(Operation *op) {
// Check all tensor OpResults.
for (OpResult opResult : op->getOpResults()) {
- if (!opResult.getType().isa<TensorType>())
+ if (!isa<TensorType>(opResult.getType()))
continue;
// If there is no preceding definition, the tensor contents are
@@ -259,7 +259,7 @@ bool OneShotAnalysisState::isWritable(Value value) const {
return bufferizableOp.isWritable(value, *this);
// Query BufferizableOpInterface to see if the BlockArgument is writable.
- if (auto bbArg = value.dyn_cast<BlockArgument>())
+ if (auto bbArg = dyn_cast<BlockArgument>(value))
if (auto bufferizableOp =
getOptions().dynCastBufferizableOp(bbArg.getOwner()->getParentOp()))
return bufferizableOp.isWritable(bbArg, *this);
@@ -431,12 +431,12 @@ static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite,
id + "[READ: " + std::to_string(uRead->getOperandNumber()) + "]";
readingOp->setAttr(readAttr, b.getUnitAttr());
- if (auto opResult = definition.dyn_cast<OpResult>()) {
+ if (auto opResult = dyn_cast<OpResult>(definition)) {
std::string defAttr =
id + "[DEF: result " + std::to_string(opResult.getResultNumber()) + "]";
opResult.getDefiningOp()->setAttr(defAttr, b.getUnitAttr());
} else {
- auto bbArg = definition.cast<BlockArgument>();
+ auto bbArg = cast<BlockArgument>(definition);
std::string defAttr =
id + "[DEF: bbArg " + std::to_string(bbArg.getArgNumber()) + "]";
bbArg.getOwner()->getParentOp()->setAttr(defAttr, b.getUnitAttr());
@@ -581,7 +581,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
continue;
}
} else {
- auto bbArg = definition.cast<BlockArgument>();
+ auto bbArg = cast<BlockArgument>(definition);
Block *block = bbArg.getOwner();
if (!block->findAncestorOpInBlock(*conflictingWritingOp)) {
LLVM_DEBUG(llvm::dbgs() << " no conflict: definition is bbArg "
@@ -715,12 +715,12 @@ static void annotateNonWritableTensor(Value value) {
static int64_t counter = 0;
OpBuilder b(value.getContext());
std::string id = "W_" + std::to_string(counter++);
- if (auto opResult = value.dyn_cast<OpResult>()) {
+ if (auto opResult = dyn_cast<OpResult>(value)) {
std::string attr = id + "[NOT-WRITABLE: result " +
std::to_string(opResult.getResultNumber()) + "]";
opResult.getDefiningOp()->setAttr(attr, b.getUnitAttr());
} else {
- auto bbArg = value.cast<BlockArgument>();
+ auto bbArg = cast<BlockArgument>(value);
std::string attr = id + "[NOT-WRITABLE: bbArg " +
std::to_string(bbArg.getArgNumber()) + "]";
bbArg.getOwner()->getParentOp()->setAttr(attr, b.getUnitAttr());
@@ -812,7 +812,7 @@ LogicalResult
OneShotAnalysisState::analyzeSingleOp(Operation *op,
const DominanceInfo &domInfo) {
for (OpOperand &opOperand : op->getOpOperands())
- if (opOperand.get().getType().isa<TensorType>())
+ if (isa<TensorType>(opOperand.get().getType()))
if (failed(bufferizableInPlaceAnalysisImpl(opOperand, *this, domInfo)))
return failure();
return success();
@@ -831,7 +831,7 @@ static void equivalenceAnalysis(SmallVector<Operation *> &ops,
for (Operation *op : ops) {
if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op)) {
for (OpResult opResult : op->getOpResults()) {
- if (!opResult.getType().isa<TensorType>())
+ if (!isa<TensorType>(opResult.getType()))
continue;
AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult);
if (aliases.getNumAliases() == 0)
@@ -958,7 +958,7 @@ static LogicalResult checkAliasInfoConsistency(Operation *op,
}
for (OpOperand &opOperand : op->getOpOperands()) {
- if (opOperand.get().getType().isa<TensorType>()) {
+ if (isa<TensorType>(opOperand.get().getType())) {
if (wouldCreateReadAfterWriteInterference(
opOperand, domInfo, state,
/*checkConsistencyOnly=*/true)) {
@@ -984,7 +984,7 @@ annotateOpsWithBufferizationMarkers(Operation *op,
// Add __inplace_operands_attr__.
op->walk([&](Operation *op) {
for (OpOperand &opOperand : op->getOpOperands())
- if (opOperand.get().getType().isa<TensorType>())
+ if (isa<TensorType>(opOperand.get().getType()))
setInPlaceOpOperand(opOperand, state.isInPlace(opOperand));
});
}
@@ -1031,12 +1031,12 @@ static LogicalResult assertNoAllocsReturned(Operation *op,
for (OpOperand &returnValOperand : returnOp->getOpOperands()) {
Value returnVal = returnValOperand.get();
// Skip non-tensor values.
- if (!returnVal.getType().isa<TensorType>())
+ if (!isa<TensorType>(returnVal.getType()))
continue;
bool foundEquivValue = false;
state.applyOnEquivalenceClass(returnVal, [&](Value equivVal) {
- if (auto bbArg = equivVal.dyn_cast<BlockArgument>()) {
+ if (auto bbArg = dyn_cast<BlockArgument>(equivVal)) {
Operation *definingOp = bbArg.getOwner()->getParentOp();
if (definingOp->isProperAncestor(returnOp))
foundEquivValue = true;
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index 27b560afdbb34..d0af1c278c146 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -109,9 +109,9 @@ static void annotateEquivalentReturnBbArg(OpOperand &returnVal,
SmallVector<int64_t> equivBbArgs;
if (op->hasAttr(kEquivalentArgsAttr)) {
- auto attr = op->getAttr(kEquivalentArgsAttr).cast<ArrayAttr>();
+ auto attr = cast<ArrayAttr>(op->getAttr(kEquivalentArgsAttr));
equivBbArgs = llvm::to_vector<4>(llvm::map_range(attr, [](Attribute a) {
- return a.cast<IntegerAttr>().getValue().getSExtValue();
+ return cast<IntegerAttr>(a).getValue().getSExtValue();
}));
} else {
equivBbArgs.append(op->getNumOperands(), -1);
@@ -132,10 +132,10 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
// return value may alias with any tensor bbArg.
FunctionType type = funcOp.getFunctionType();
for (const auto &inputIt : llvm::enumerate(type.getInputs())) {
- if (!inputIt.value().isa<TensorType>())
+ if (!isa<TensorType>(inputIt.value()))
continue;
for (const auto &resultIt : llvm::enumerate(type.getResults())) {
- if (!resultIt.value().isa<TensorType>())
+ if (!isa<TensorType>(resultIt.value()))
continue;
int64_t returnIdx = resultIt.index();
int64_t bbArgIdx = inputIt.index();
@@ -150,9 +150,9 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
assert(returnOp && "expected func with single return op");
for (OpOperand &returnVal : returnOp->getOpOperands())
- if (returnVal.get().getType().isa<RankedTensorType>())
+ if (isa<RankedTensorType>(returnVal.get().getType()))
for (BlockArgument bbArg : funcOp.getArguments())
- if (bbArg.getType().isa<RankedTensorType>()) {
+ if (isa<RankedTensorType>(bbArg.getType())) {
int64_t returnIdx = returnVal.getOperandNumber();
int64_t bbArgIdx = bbArg.getArgNumber();
if (state.areEquivalentBufferizedValues(returnVal.get(), bbArg)) {
@@ -193,7 +193,7 @@ funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
for (int64_t idx = 0, e = funcOp.getFunctionType().getNumInputs(); idx < e;
++idx) {
// Skip non-tensor arguments.
- if (!funcOp.getFunctionType().getInput(idx).isa<TensorType>())
+ if (!isa<TensorType>(funcOp.getFunctionType().getInput(idx)))
continue;
bool isRead;
bool isWritten;
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
index 4cd19b4efc636..b12ea25396b22 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
@@ -58,7 +58,7 @@ resolveUsesInRepetitiveRegions(Operation *op,
for (OpOperand &opOperand : bufferizableOp->getOpOperands()) {
Value operand = opOperand.get();
// Skip non-tensor operands.
- if (!operand.getType().isa<TensorType>())
+ if (!isa<TensorType>(operand.getType()))
continue;
// Skip operands that do not bufferize to memory writes.
if (!bufferizableOp.bufferizesToMemoryWrite(opOperand, state))
@@ -85,7 +85,7 @@ resolveUsesInRepetitiveRegions(Operation *op,
// Insert a tensor copy and replace all uses inside of repetitive regions.
rewriter.setInsertionPoint(bufferizableOp);
auto tensorCopy = rewriter.create<AllocTensorOp>(
- bufferizableOp->getLoc(), operand.getType().cast<TensorType>(),
+ bufferizableOp->getLoc(), cast<TensorType>(operand.getType()),
/*dynamicSizes=*/ValueRange(),
/*copy=*/operand, /*memory_space=*/IntegerAttr());
for (OpOperand *use : usesInsideRegion)
@@ -137,7 +137,7 @@ mlir::bufferization::insertTensorCopies(Operation *op,
SmallVector<bool> escapeAttrValue;
bool foundTensorResult = false;
for (OpResult opResult : op->getOpResults()) {
- if (!opResult.getType().isa<TensorType>() ||
+ if (!isa<TensorType>(opResult.getType()) ||
!bufferizableOp.bufferizesToAllocation(opResult)) {
escapeAttrValue.push_back(false);
continue;
diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
index 39c9e5e1725a8..8cd2ccfcf882b 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -257,19 +257,19 @@ checkMappingAttributeTypes(std::optional<TransformOpInterface> transformOp,
bool hasBlockMapping =
llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
- return attr.isa<GPUBlockMappingAttr>();
+ return isa<GPUBlockMappingAttr>(attr);
});
bool hasThreadMapping =
llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
- return attr.isa<GPUThreadMappingAttr>();
+ return isa<GPUThreadMappingAttr>(attr);
});
bool hasWarpMapping =
llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
- return attr.isa<GPUWarpMappingAttr>();
+ return isa<GPUWarpMappingAttr>(attr);
});
bool hasLinearMapping =
llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
- return attr.isa<GPULinearIdMappingAttr>();
+ return isa<GPULinearIdMappingAttr>(attr);
});
int64_t countMappingTypes = 0;
countMappingTypes += hasBlockMapping ? 1 : 0;
@@ -520,7 +520,7 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
ArrayRef<Attribute>{forallMappingAttrs}.take_front(
forallOp.getInductionVars().size()))) {
Value peIdOp = mappingIdOps[static_cast<int64_t>(
- dim.cast<DeviceMappingAttrInterface>().getMappingId())];
+ cast<DeviceMappingAttrInterface>(dim).getMappingId())];
bvm.map(iv, peIdOp);
}
diff --git a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
index 0a584a7920e02..ca9f2ac254c58 100644
--- a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
@@ -214,7 +214,7 @@ struct GpuAllReduceRewriter {
/// Returns an accumulator factory that creates an op specified by opName.
AccumulatorFactory getFactory(gpu::AllReduceOperation opName) {
- bool isFloatingPoint = valueType.isa<FloatType>();
+ bool isFloatingPoint = isa<FloatType>(valueType);
switch (opName) {
case gpu::AllReduceOperation::ADD:
return isFloatingPoint ? getFactory<arith::AddFOp>()
diff --git a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
index 0890bf2677626..1fbe66ff98d7a 100644
--- a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
@@ -158,9 +158,9 @@ async::ExecuteOp addExecuteResults(async::ExecuteOp executeOp,
transform(executeOp.getResultTypes(), std::back_inserter(resultTypes),
[](Type type) {
// Extract value type from !async.value.
- if (auto valueType = type.dyn_cast<async::ValueType>())
+ if (auto valueType = dyn_cast<async::ValueType>(type))
return valueType.getValueType();
- assert(type.isa<async::TokenType>() && "expected token type");
+ assert(isa<async::TokenType>(type) && "expected token type");
return type;
});
transform(results, std::back_inserter(resultTypes),
@@ -305,9 +305,9 @@ struct GpuAsyncRegionPass::SingleTokenUseCallback {
executeOp.getBodyResults(), [](OpResult result) {
if (result.use_empty() || result.hasOneUse())
return false;
- auto valueType = result.getType().dyn_cast<async::ValueType>();
+ auto valueType = dyn_cast<async::ValueType>(result.getType());
return valueType &&
- valueType.getValueType().isa<gpu::AsyncTokenType>();
+ isa<gpu::AsyncTokenType>(valueType.getValueType());
});
if (multiUseResults.empty())
return;
diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
index 91c1c763f070d..b1e2f914db4cb 100644
--- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
@@ -338,7 +338,7 @@ class GpuKernelOutliningPass
if (!resultAttr)
return failure();
- dataLayoutSpec = resultAttr.dyn_cast<DataLayoutSpecInterface>();
+ dataLayoutSpec = dyn_cast<DataLayoutSpecInterface>(resultAttr);
if (!dataLayoutSpec)
return failure();
}
@@ -410,7 +410,7 @@ class GpuKernelOutliningPass
SymbolTable::getSymbolUses(symbolDefWorklist.pop_back_val())) {
for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
StringRef symbolName =
- symbolUse.getSymbolRef().cast<FlatSymbolRefAttr>().getValue();
+ cast<FlatSymbolRefAttr>(symbolUse.getSymbolRef()).getValue();
if (symbolTable.lookup(symbolName))
continue;
diff --git a/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp b/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp
index ea9c3969c413f..21de15e250881 100644
--- a/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp
@@ -30,7 +30,7 @@ using namespace mlir::gpu;
/// single-iteration loops. Maps the innermost loops to thread dimensions, in
/// reverse order to enable access coalescing in the innermost loop.
static void insertCopyLoops(ImplicitLocOpBuilder &b, Value from, Value to) {
- auto memRefType = from.getType().cast<MemRefType>();
+ auto memRefType = cast<MemRefType>(from.getType());
auto rank = memRefType.getRank();
SmallVector<Value, 4> lbs, ubs, steps;
@@ -121,8 +121,8 @@ static void insertCopyLoops(ImplicitLocOpBuilder &b, Value from, Value to) {
/// pointed to by "from". In case a smaller block would be sufficient, the
/// caller can create a subview of the memref and promote it instead.
static void insertCopies(Region &region, Location loc, Value from, Value to) {
- auto fromType = from.getType().cast<MemRefType>();
- auto toType = to.getType().cast<MemRefType>();
+ auto fromType = cast<MemRefType>(from.getType());
+ auto toType = cast<MemRefType>(to.getType());
(void)fromType;
(void)toType;
assert(fromType.getShape() == toType.getShape());
@@ -143,7 +143,7 @@ static void insertCopies(Region &region, Location loc, Value from, Value to) {
/// copies will be inserted in the beginning and in the end of the function.
void mlir::promoteToWorkgroupMemory(GPUFuncOp op, unsigned arg) {
Value value = op.getArgument(arg);
- auto type = value.getType().dyn_cast<MemRefType>();
+ auto type = dyn_cast<MemRefType>(value.getType());
assert(type && type.hasStaticShape() && "can only promote memrefs");
// Get the type of the buffer in the workgroup memory.
diff --git a/mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp b/mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp
index 71d27764f437d..8b09f441a4bd1 100644
--- a/mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp
+++ b/mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp
@@ -67,7 +67,7 @@ LogicalResult DynParametricAttrConstraint::verify(
ConstraintVerifier &context) const {
// Check that the base is the expected one.
- auto dynAttr = attr.dyn_cast<DynamicAttr>();
+ auto dynAttr = dyn_cast<DynamicAttr>(attr);
if (!dynAttr || dynAttr.getAttrDef() != attrDef) {
if (emitError) {
StringRef dialectName = attrDef->getDialect()->getNamespace();
@@ -102,7 +102,7 @@ LogicalResult DynParametricTypeConstraint::verify(
function_ref<InFlightDiagnostic()> emitError, Attribute attr,
ConstraintVerifier &context) const {
// Check that the base is a TypeAttr.
- auto typeAttr = attr.dyn_cast<TypeAttr>();
+ auto typeAttr = dyn_cast<TypeAttr>(attr);
if (!typeAttr) {
if (emitError)
return emitError() << "expected type, got attribute '" << attr;
@@ -110,7 +110,7 @@ LogicalResult DynParametricTypeConstraint::verify(
}
// Check that the type base is the expected one.
- auto dynType = typeAttr.getValue().dyn_cast<DynamicType>();
+ auto dynType = dyn_cast<DynamicType>(typeAttr.getValue());
if (!dynType || dynType.getTypeDef() != typeDef) {
if (emitError) {
StringRef dialectName = typeDef->getDialect()->getNamespace();
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp
index 9f38b0caddb7d..ecdadd3062d36 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp
@@ -25,11 +25,11 @@ using namespace mlir;
/// Attempt to extract a filename for the given loc.
static FileLineColLoc extractFileLoc(Location loc) {
- if (auto fileLoc = loc.dyn_cast<FileLineColLoc>())
+ if (auto fileLoc = dyn_cast<FileLineColLoc>(loc))
return fileLoc;
- if (auto nameLoc = loc.dyn_cast<NameLoc>())
+ if (auto nameLoc = dyn_cast<NameLoc>(loc))
return extractFileLoc(nameLoc.getChildLoc());
- if (auto opaqueLoc = loc.dyn_cast<OpaqueLoc>())
+ if (auto opaqueLoc = dyn_cast<OpaqueLoc>(loc))
return extractFileLoc(opaqueLoc.getFallbackLocation());
return FileLineColLoc();
}
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
index 1936a53201f2d..02909bb69977f 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
@@ -607,7 +607,7 @@ DiagnosedSilenceableFailure transform::MatchStructuredResultOp::matchOperation(
return diag;
Value result = linalgOp.getTiedOpResult(linalgOp.getDpsInitOperand(position));
- if (getResult().getType().isa<TransformValueHandleTypeInterface>()) {
+ if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
results.setValues(cast<OpResult>(getResult()), result);
return DiagnosedSilenceableFailure::success();
}
@@ -648,7 +648,7 @@ transform::MatchStructuredResultOp::getPositionFor(linalg::LinalgOp op,
LogicalResult transform::MatchStructuredResultOp::verify() {
if ((getAny() || getSingle()) ^
- getResult().getType().isa<TransformHandleTypeInterface>()) {
+ isa<TransformHandleTypeInterface>(getResult().getType())) {
return emitOpError() << "expects either the any/single keyword or the type "
"value handle result type";
}
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 74eb3a2df0f96..ea8d285cf52b7 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -87,7 +87,7 @@ static DiagnosedSilenceableFailure unpackSingleIndexResultPDLOperations(
SmallVector<OpFoldResult> &result, ArrayRef<OpFoldResult> ofrs) {
for (OpFoldResult ofr : ofrs) {
if (ofr.is<Attribute>()) {
- if (!ofr.get<Attribute>().isa<IntegerAttr>())
+ if (!isa<IntegerAttr>(ofr.get<Attribute>()))
return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
result.push_back(ofr);
continue;
@@ -155,7 +155,7 @@ transform::BufferizeToAllocationOp::apply(transform::TransformResults &results,
llvm::map_range(state.getPayloadValues(getTarget()), [&](Value v) {
return linalg::bufferizeToAllocation(rewriter, v, memorySpace);
}));
- results.setValues(getTransformed().cast<OpResult>(), transformed);
+ results.setValues(cast<OpResult>(getTransformed()), transformed);
return DiagnosedSilenceableFailure::success();
}
@@ -276,7 +276,7 @@ static ParseResult parseTileLikeOp(OpAsmParser &parser, OperationState &result,
if (!sizesAttr)
return parser.emitError(opLoc)
<< "expected '" << sizesAttrName << "' attribute";
- auto sizesArrayAttr = sizesAttr.dyn_cast<ArrayAttr>();
+ auto sizesArrayAttr = dyn_cast<ArrayAttr>(sizesAttr);
if (!sizesArrayAttr)
return parser.emitError(opLoc)
<< "'" << sizesAttrName << "' attribute must be an array";
@@ -389,7 +389,7 @@ tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
// Tile the producer.
int64_t resultNumber =
- sliceOpToTile.getSource().cast<OpResult>().getResultNumber();
+ cast<OpResult>(sliceOpToTile.getSource()).getResultNumber();
LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
FailureOr<TilingResult> tileAndFuseResult =
@@ -411,10 +411,7 @@ tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
// Replace the extract op.
auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
- sliceOpToTile->getResult(0)
- .getType()
- .cast<RankedTensorType>()
- .getShape());
+ cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
assert(succeeded(maybeRankReduced) && "unexpected shape");
rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
return tileAndFuseResult->tiledOps;
@@ -482,7 +479,7 @@ tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
// Replace the use in the tileableProducer before tiling: clone, replace and
// then tile.
- int64_t resultNumber = pUse->get().cast<OpResult>().getResultNumber();
+ int64_t resultNumber = cast<OpResult>(pUse->get()).getResultNumber();
LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
// Gather destination tensors.
@@ -516,10 +513,7 @@ tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
// Replace the extract op.
auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
- sliceOpToTile->getResult(0)
- .getType()
- .cast<RankedTensorType>()
- .getShape());
+ cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
assert(succeeded(maybeRankReduced) && "unexpected shape");
rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
@@ -568,7 +562,7 @@ static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag,
// TODO: Generalize to other type of ops.
assert(!isa<tensor::ParallelInsertSliceOp>(use->getOwner()) &&
"Parallel insert slice is not a valid clone destination");
- unsigned resultNumber = use->get().cast<OpResult>().getResultNumber();
+ unsigned resultNumber = cast<OpResult>(use->get()).getResultNumber();
LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
OpBuilder::InsertionGuard guard(rewriter);
@@ -587,8 +581,7 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
ArrayRef<Operation *> producerOps = state.getPayloadOps(getProducerOp());
// If nothing to fuse, propagate success.
if (producerOps.empty()) {
- results.set(getFusedOp().cast<OpResult>(),
- SmallVector<mlir::Operation *>{});
+ results.set(cast<OpResult>(getFusedOp()), SmallVector<mlir::Operation *>{});
return DiagnosedSilenceableFailure::success();
}
ArrayRef<Operation *> containingOps = state.getPayloadOps(getContainingOp());
@@ -671,7 +664,7 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
}
- results.set(getFusedOp().cast<OpResult>(), fusedOps);
+ results.set(cast<OpResult>(getFusedOp()), fusedOps);
return DiagnosedSilenceableFailure::success();
}
@@ -865,7 +858,7 @@ transform::MatchOp::apply(transform::TransformResults &results,
};
payloadOps.front()->walk(matchFun);
- results.set(getResult().cast<OpResult>(), res);
+ results.set(cast<OpResult>(getResult()), res);
return DiagnosedSilenceableFailure::success();
}
@@ -901,7 +894,7 @@ static ParseResult parseMultitileSizesTypes(OpAsmParser &parser,
DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne(
LinalgOp target, transform::ApplyToEachResultList &results,
TransformState &state) {
- if (getLowSize().getType().isa<TransformParamTypeInterface>()) {
+ if (isa<TransformParamTypeInterface>(getLowSize().getType())) {
if (target.hasDynamicShape()) {
auto diag = emitSilenceableError()
<< "cannot compute parametric tile sizes for dynamically "
@@ -923,7 +916,7 @@ DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne(
spec->lowTileSize * spec->lowTripCount}),
[&builder, this](int64_t value) {
return builder.getIntegerAttr(
- getLowSize().getType().cast<ParamType>().getType(), value);
+ cast<ParamType>(getLowSize().getType()).getType(), value);
}));
return DiagnosedSilenceableFailure::success();
}
@@ -958,7 +951,7 @@ void transform::MultiTileSizesOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
onlyReadsHandle(getTarget(), effects);
producesHandle(getResults(), effects);
- if (getLowSize().getType().isa<TransformParamTypeInterface>())
+ if (isa<TransformParamTypeInterface>(getLowSize().getType()))
onlyReadsPayload(effects);
else
modifiesPayload(effects);
@@ -1006,7 +999,7 @@ transform::PackOp::apply(transform::TransformResults &transformResults,
ArrayRef<Operation *> targetOps = state.getPayloadOps(getTarget());
// If nothing to pack, propagate success.
if (targetOps.empty()) {
- transformResults.set(getPackedOp().cast<OpResult>(), {});
+ transformResults.set(cast<OpResult>(getPackedOp()), {});
return DiagnosedSilenceableFailure::success();
}
// Fail on multi-op handles.
@@ -1036,7 +1029,7 @@ transform::PackOp::apply(transform::TransformResults &transformResults,
if (failed(maybeResult))
return emitDefiniteFailure("data tiling failed");
- transformResults.set(getPackedOp().cast<OpResult>(),
+ transformResults.set(cast<OpResult>(getPackedOp()),
maybeResult->packedLinalgOp.getOperation());
return DiagnosedSilenceableFailure::success();
}
@@ -1242,7 +1235,7 @@ PackGreedilyOp::apply(transform::TransformResults &transformResults,
}
results.push_back(linalgOp);
}
- transformResults.set(getPackedOp().cast<OpResult>(), results);
+ transformResults.set(cast<OpResult>(getPackedOp()), results);
return DiagnosedSilenceableFailure::success();
}
@@ -1322,9 +1315,9 @@ transform::PackTransposeOp::apply(transform::TransformResults &transformResults,
ArrayRef<Operation *> linalgOps = state.getPayloadOps(getTargetLinalgOp());
// Step 1. If nothing to pack, propagate success.
if (packOrUnpackOps.empty()) {
- transformResults.set(getPackedOp().cast<OpResult>(), {});
- transformResults.set(getPackOp().cast<OpResult>(), {});
- transformResults.set(getUnPackOp().cast<OpResult>(), {});
+ transformResults.set(cast<OpResult>(getPackedOp()), {});
+ transformResults.set(cast<OpResult>(getPackOp()), {});
+ transformResults.set(cast<OpResult>(getUnPackOp()), {});
return DiagnosedSilenceableFailure::success();
}
@@ -1366,7 +1359,7 @@ transform::PackTransposeOp::apply(transform::TransformResults &transformResults,
if (unPackOp) {
assert(!packOp && "packOp must be null on entry when unPackOp is not null");
OpOperand *packUse = linalgOp.getDpsInitOperand(
- unPackOp.getSource().cast<OpResult>().getResultNumber());
+ cast<OpResult>(unPackOp.getSource()).getResultNumber());
packOp = dyn_cast_or_null<tensor::PackOp>(packUse->get().getDefiningOp());
if (!packOp || !packOp.getResult().hasOneUse())
return emitSilenceableError() << "could not find matching pack op";
@@ -1400,14 +1393,14 @@ transform::PackTransposeOp::apply(transform::TransformResults &transformResults,
assert(succeeded(res) && "unexpected packTranspose failure");
// Step 4. Return results.
- transformResults.set(getPackOp().cast<OpResult>(), {res->transposedPackOp});
- transformResults.set(getPackedOp().cast<OpResult>(),
+ transformResults.set(cast<OpResult>(getPackOp()), {res->transposedPackOp});
+ transformResults.set(cast<OpResult>(getPackedOp()),
{res->transposedLinalgOp});
if (unPackOp) {
- transformResults.set(getUnPackOp().cast<OpResult>(),
+ transformResults.set(cast<OpResult>(getUnPackOp()),
{res->transposedUnPackOp});
} else {
- transformResults.set(getUnPackOp().cast<OpResult>(), {});
+ transformResults.set(cast<OpResult>(getUnPackOp()), {});
}
return DiagnosedSilenceableFailure::success();
@@ -1430,14 +1423,14 @@ transform::PadOp::applyToOne(LinalgOp target,
SmallVector<Attribute> paddingValues;
for (auto const &it :
llvm::zip(getPaddingValues(), target->getOperandTypes())) {
- auto attr = std::get<0>(it).dyn_cast<TypedAttr>();
+ auto attr = dyn_cast<TypedAttr>(std::get<0>(it));
if (!attr) {
emitOpError("expects padding values to be typed attributes");
return DiagnosedSilenceableFailure::definiteFailure();
}
Type elementType = getElementTypeOrSelf(std::get<1>(it));
// Try to parse string attributes to obtain an attribute of element type.
- if (auto stringAttr = attr.dyn_cast<StringAttr>()) {
+ if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
auto parsedAttr = dyn_cast_if_present<TypedAttr>(
parseAttribute(stringAttr, getContext(), elementType,
/*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
@@ -1462,9 +1455,9 @@ transform::PadOp::applyToOne(LinalgOp target,
// Extract the transpose vectors.
SmallVector<SmallVector<int64_t>> transposePaddings;
- for (Attribute transposeVector : getTransposePaddings().cast<ArrayAttr>())
+ for (Attribute transposeVector : cast<ArrayAttr>(getTransposePaddings()))
transposePaddings.push_back(
- extractFromI64ArrayAttr(transposeVector.cast<ArrayAttr>()));
+ extractFromI64ArrayAttr(cast<ArrayAttr>(transposeVector)));
TrackingListener listener(state, *this);
IRRewriter rewriter(getContext(), &listener);
@@ -1549,13 +1542,13 @@ DiagnosedSilenceableFailure transform::HoistPadBuildPackingLoopNestOp::apply(
return emitDefiniteFailure() << "could not build packing loop nest";
if (result->clonedLoopIvs.empty()) {
- transformResults.set(getPackingLoop().cast<OpResult>(),
+ transformResults.set(cast<OpResult>(getPackingLoop()),
result->hoistedPadOp.getOperation());
return DiagnosedSilenceableFailure::success();
}
auto outerPackedLoop =
scf::getForInductionVarOwner(result->clonedLoopIvs.front());
- transformResults.set(getPackingLoop().cast<OpResult>(),
+ transformResults.set(cast<OpResult>(getPackingLoop()),
outerPackedLoop.getOperation());
return DiagnosedSilenceableFailure::success();
}
@@ -1643,7 +1636,7 @@ transform::PromoteOp::applyToOne(LinalgOp target,
if (mapping.size() > 1)
return emitDefaultDefiniteFailure(target);
- auto addressSpace = mapping[0].cast<gpu::GPUMemorySpaceMappingAttr>();
+ auto addressSpace = cast<gpu::GPUMemorySpaceMappingAttr>(mapping[0]);
if (addressSpace.getAddressSpace() ==
gpu::GPUDialect::getWorkgroupAddressSpace()) {
@@ -1711,7 +1704,7 @@ transform::ReplaceOp::apply(TransformResults &transformResults,
rewriter.replaceOp(target, replacement->getResults());
replacements.push_back(replacement);
}
- transformResults.set(getReplacement().cast<OpResult>(), replacements);
+ transformResults.set(cast<OpResult>(getReplacement()), replacements);
return DiagnosedSilenceableFailure::success();
}
@@ -1828,7 +1821,7 @@ DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results,
splitPoints.reserve(payload.size());
if (getDynamicSplitPoint()) {
auto diag = DiagnosedSilenceableFailure::success();
- if (getDynamicSplitPoint().getType().isa<TransformHandleTypeInterface>()) {
+ if (isa<TransformHandleTypeInterface>(getDynamicSplitPoint().getType())) {
splitPoints = llvm::to_vector(llvm::map_range(
state.getPayloadOps(getDynamicSplitPoint()), [&](Operation *op) {
if (op->getNumResults() != 1 ||
@@ -1909,8 +1902,8 @@ DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results,
return diag;
}
- results.set(getFirst().cast<OpResult>(), first);
- results.set(getSecond().cast<OpResult>(), second);
+ results.set(cast<OpResult>(getFirst()), first);
+ results.set(cast<OpResult>(getSecond()), second);
return DiagnosedSilenceableFailure::success();
}
@@ -2212,12 +2205,12 @@ transform::TileOp::apply(TransformResults &transformResults,
dynamicSizeProducers.reserve(getDynamicSizes().size());
paramSizes.reserve(getDynamicSizes().size());
for (Value transformValue : getDynamicSizes()) {
- if (transformValue.getType().isa<ParamType>()) {
+ if (isa<ParamType>(transformValue.getType())) {
dynamicSizeProducers.push_back({});
ArrayRef<Attribute> params = state.getParams(transformValue);
paramSizes.push_back(
llvm::to_vector(llvm::map_range(params, [](Attribute attr) {
- return attr.cast<IntegerAttr>().getValue().getSExtValue();
+ return cast<IntegerAttr>(attr).getValue().getSExtValue();
})));
if (paramSizes.back().size() != targets.size()) {
@@ -2247,7 +2240,7 @@ transform::TileOp::apply(TransformResults &transformResults,
for (Operation *op : dynamicSizeProducers.back()) {
if (op->getNumResults() == 1 &&
- op->getResult(0).getType().isa<IndexType>())
+ isa<IndexType>(op->getResult(0).getType()))
continue;
DiagnosedSilenceableFailure diag =
@@ -2283,7 +2276,7 @@ transform::TileOp::apply(TransformResults &transformResults,
for (OpFoldResult ofr : getMixedSizes()) {
if (auto attr = ofr.dyn_cast<Attribute>()) {
sizes.push_back(b.create<arith::ConstantIndexOp>(
- getLoc(), attr.cast<IntegerAttr>().getInt()));
+ getLoc(), cast<IntegerAttr>(attr).getInt()));
continue;
}
ArrayRef<Operation *> dynamicSizes = dynamicSizeProducers[dynamicIdx];
@@ -2320,9 +2313,9 @@ transform::TileOp::apply(TransformResults &transformResults,
loops[en2.index()].push_back(en2.value());
}
- transformResults.set(getTiledLinalgOp().cast<OpResult>(), tiled);
+ transformResults.set(cast<OpResult>(getTiledLinalgOp()), tiled);
for (const auto &en : llvm::enumerate(loops))
- transformResults.set(getLoops()[en.index()].cast<OpResult>(), en.value());
+ transformResults.set(cast<OpResult>(getLoops()[en.index()]), en.value());
return DiagnosedSilenceableFailure::success();
}
@@ -2582,8 +2575,8 @@ transform::TileToForallOp::apply(transform::TransformResults &transformResults,
tiledOps.push_back(tilingResult.tiledOp);
}
- transformResults.set(getForallOp().cast<OpResult>(), tileOps);
- transformResults.set(getTiledOp().cast<OpResult>(), tiledOps);
+ transformResults.set(cast<OpResult>(getForallOp()), tileOps);
+ transformResults.set(cast<OpResult>(getTiledOp()), tiledOps);
return DiagnosedSilenceableFailure::success();
}
@@ -2678,7 +2671,7 @@ transform::TileToScfForOp::apply(TransformResults &transformResults,
for (Operation *op : dynamicSizeProducers.back()) {
if (op->getNumResults() == 1 &&
- op->getResult(0).getType().isa<IndexType>())
+ isa<IndexType>(op->getResult(0).getType()))
continue;
DiagnosedSilenceableFailure diag =
emitSilenceableError() << "expected sizes to be produced by ops "
@@ -2712,7 +2705,7 @@ transform::TileToScfForOp::apply(TransformResults &transformResults,
for (OpFoldResult ofr : getMixedSizes()) {
if (auto attr = ofr.dyn_cast<Attribute>()) {
sizes.push_back(b.create<arith::ConstantIndexOp>(
- getLoc(), attr.cast<IntegerAttr>().getInt()));
+ getLoc(), cast<IntegerAttr>(attr).getInt()));
} else {
sizes.push_back(
dynamicSizeProducers[dynamicIdx++][index]->getResult(0));
@@ -2737,9 +2730,9 @@ transform::TileToScfForOp::apply(TransformResults &transformResults,
loops[en2.index()].push_back(en2.value());
}
- transformResults.set(getTiledLinalgOp().cast<OpResult>(), tiled);
+ transformResults.set(cast<OpResult>(getTiledLinalgOp()), tiled);
for (const auto &en : llvm::enumerate(loops))
- transformResults.set(getLoops()[en.index()].cast<OpResult>(), en.value());
+ transformResults.set(cast<OpResult>(getLoops()[en.index()]), en.value());
return DiagnosedSilenceableFailure::success();
}
@@ -2899,7 +2892,7 @@ DiagnosedSilenceableFailure transform::MaskedVectorizeOp::apply(
for (OpFoldResult sz : getMixedVectorSizes()) {
if (sz.is<Attribute>()) {
auto attr = sz.get<Attribute>();
- vectorSizes.push_back(attr.cast<IntegerAttr>().getInt());
+ vectorSizes.push_back(cast<IntegerAttr>(attr).getInt());
continue;
}
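Every hunk above follows the same rewrite: `x.cast<T>()` becomes `cast<T>(x)`, and likewise for `isa` and `dyn_cast`. The free functions work because they only need the target type to provide a static `classof` predicate. The sketch below is a hypothetical, much-simplified stand-in for that dispatch, not LLVM's implementation (LLVM routes through `CastInfo` and related traits). `sketch::Type`, `IntegerType`, `FloatType` and `TypeStorage` are illustrative names that only loosely mimic MLIR's value-semantic handles.

```cpp
#include <cassert>

namespace sketch {

enum class Kind { Integer, Float };

// Storage a handle points at (MLIR uniques this in the MLIRContext;
// here it is just a plain struct).
struct TypeStorage {
  Kind kind;
};

// Value-semantic handle, loosely modeled on mlir::Type (assumption).
class Type {
public:
  Type() = default;
  explicit Type(const TypeStorage *storage) : impl(storage) {}
  explicit operator bool() const { return impl != nullptr; }
  Kind getKind() const { return impl->kind; }
  const TypeStorage *getImpl() const { return impl; }

private:
  const TypeStorage *impl = nullptr;
};

// Subclasses add no data; they only supply the classof predicate.
class IntegerType : public Type {
public:
  using Type::Type;
  static bool classof(Type t) { return t.getKind() == Kind::Integer; }
};

class FloatType : public Type {
public:
  using Type::Type;
  static bool classof(Type t) { return t.getKind() == Kind::Float; }
};

// isa<To>(v): asks the target type whether v belongs to it.
template <typename To, typename From>
bool isa(From value) {
  assert(value && "isa<> used on a null handle");
  return To::classof(value);
}

// cast<To>(v): asserts on mismatch, then rewraps the same storage.
template <typename To, typename From>
To cast(From value) {
  assert(isa<To>(value) && "cast<> argument of incompatible type");
  return To(value.getImpl());
}

// dyn_cast<To>(v): returns a null handle on mismatch (v itself must be
// non-null, mirroring the dyn_cast / dyn_cast_or_null split).
template <typename To, typename From>
To dyn_cast(From value) {
  return isa<To>(value) ? To(value.getImpl()) : To();
}

} // namespace sketch
```

Because the dispatch lives in free templates keyed on `classof`, any handle class opts in without declaring member `cast`/`isa` methods, which is what makes the member versions removable once call sites migrate.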
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
index 1a7d7a113b22b..6b06c32d22eba 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
@@ -64,20 +64,20 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
if (genericOp.getNumDpsInits() != 1)
return failure();
- auto outputType = genericOp.getResultTypes().front().dyn_cast<ShapedType>();
+ auto outputType = dyn_cast<ShapedType>(genericOp.getResultTypes().front());
// Require the output types to be static given that we are generating
// constants.
if (!outputType || !outputType.hasStaticShape())
return failure();
if (!llvm::all_of(genericOp.getInputs(), [](Value input) {
- return input.getType().isa<ShapedType>();
+ return isa<ShapedType>(input.getType());
}))
return failure();
// Make sure all element types are the same.
auto getOperandElementType = [](Value value) {
- return value.getType().cast<ShapedType>().getElementType();
+ return cast<ShapedType>(value.getType()).getElementType();
};
if (!llvm::all_equal(
llvm::map_range(genericOp->getOperands(), getOperandElementType)))
@@ -138,7 +138,7 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
// unify the following cases but they have lifetime as the MLIRContext.
SmallVector<APInt> intOutputValues;
SmallVector<APFloat> fpOutputValues;
- if (elementType.template isa<FloatType>())
+ if (isa<FloatType>(elementType))
fpOutputValues.resize(numElements, APFloat(0.f));
else
intOutputValues.resize(numElements);
@@ -174,7 +174,7 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
auto inputShapes = llvm::to_vector<4>(
llvm::map_range(genericOp.getInputs(), [](Value value) {
- return value.getType().cast<ShapedType>().getShape();
+ return cast<ShapedType>(value.getType()).getShape();
}));
// Given a `linearIndex`, remap it to a linear index to access linalg op
@@ -205,7 +205,7 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
}
};
- bool isFloat = elementType.isa<FloatType>();
+ bool isFloat = isa<FloatType>(elementType);
if (isFloat) {
SmallVector<DenseElementsAttr::iterator_range<APFloat>> inFpRanges;
for (int i = 0; i < numInputs; ++i)
@@ -282,7 +282,7 @@ struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> {
// The yield op should return the block argument corresponds to the input.
for (Value yieldVal : yieldOp.getValues()) {
- auto yieldArg = yieldVal.dyn_cast<BlockArgument>();
+ auto yieldArg = dyn_cast<BlockArgument>(yieldVal);
if (!yieldArg || yieldArg.getOwner() != &body)
return nullptr;
if (yieldArg.getArgNumber() != 0)
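One practical benefit shows up in the `FoldConstantBase` class template above: the removed line spells the member call as `elementType.template isa<FloatType>()`. Inside templates, calling a member template on an object of dependent type requires the `template` disambiguator, while the free-function spelling never does. The sketch below demonstrates the rule with hypothetical `Handle` and `EvenKind` types; it is not MLIR code.

```cpp
#include <cassert>

// Hypothetical handle offering the old member-style query.
struct Handle {
  int kind;
  template <typename To>
  bool isa() const { return To::classof(*this); }
};

// Hypothetical target "type" that only provides a classof predicate.
struct EvenKind {
  static bool classof(const Handle &h) { return h.kind % 2 == 0; }
};

// Free-function form of the same query.
template <typename To, typename From>
bool isa(const From &value) { return To::classof(value); }

// `value` has a dependent type, so the member template must be introduced
// with `template`; without it, `isa <` would parse as a less-than comparison.
template <typename T>
bool viaMember(const T &value) { return value.template isa<EvenKind>(); }

// The free function needs no disambiguator.
template <typename T>
bool viaFreeFunction(const T &value) { return isa<EvenKind>(value); }
```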
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
index 5423cf8d750fc..48c24598f628f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
@@ -29,7 +29,7 @@ static bool hasAllOneValues(DenseIntElementsAttr attr) {
}
static Value createAdd(Location loc, Value x, Value y, OpBuilder &builder) {
- bool isInt = x.getType().isa<IntegerType>();
+ bool isInt = isa<IntegerType>(x.getType());
if (isInt)
return builder.create<arith::AddIOp>(loc, x, y);
return builder.create<arith::AddFOp>(loc, x, y);
@@ -42,7 +42,7 @@ static Value createMul(Location loc, Value x, Value y, Type accType,
convertScalarToDtype(builder, loc, x, accType, /*isUnsignedCast=*/false);
Value yConvert =
convertScalarToDtype(builder, loc, y, accType, /*isUnsignedCast=*/false);
- if (accType.isa<IntegerType>())
+ if (isa<IntegerType>(accType))
return builder.create<arith::MulIOp>(loc, xConvert, yConvert);
return builder.create<arith::MulFOp>(loc, xConvert, yConvert);
}
@@ -74,9 +74,9 @@ static Value getConvolvedIndex(OpBuilder &b, Location loc, Value oIndex,
FailureOr<std::pair<Operation *, Operation *>>
rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
- auto inputType = convOp.getInputs()[0].getType().cast<ShapedType>();
- auto filterType = convOp.getInputs()[1].getType().cast<ShapedType>();
- auto outputType = convOp.getOutputs()[0].getType().cast<ShapedType>();
+ auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
+ auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
+ auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
if (!filterType.hasStaticShape())
return rewriter.notifyMatchFailure(
@@ -210,9 +210,9 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
FailureOr<std::pair<Operation *, Operation *>>
rewriteInIm2Col(RewriterBase &rewriter,
linalg::DepthwiseConv2DNhwcHwcOp convOp) {
- auto inputType = convOp.getInputs()[0].getType().cast<RankedTensorType>();
- auto filterType = convOp.getInputs()[1].getType().cast<RankedTensorType>();
- auto outputType = convOp.getOutputs()[0].getType().cast<RankedTensorType>();
+ auto inputType = cast<RankedTensorType>(convOp.getInputs()[0].getType());
+ auto filterType = cast<RankedTensorType>(convOp.getInputs()[1].getType());
+ auto outputType = cast<RankedTensorType>(convOp.getOutputs()[0].getType());
if (!filterType.hasStaticShape())
return rewriter.notifyMatchFailure(
@@ -230,7 +230,7 @@ rewriteInIm2Col(RewriterBase &rewriter,
Location loc = convOp.getLoc();
auto transposeOperand = [&](Value operand, ArrayRef<int64_t> indices) {
- auto operandTensorType = operand.getType().cast<RankedTensorType>();
+ auto operandTensorType = cast<RankedTensorType>(operand.getType());
auto nloops = indices.size();
ArrayRef<int64_t> inputShape = operandTensorType.getShape();
@@ -272,7 +272,7 @@ rewriteInIm2Col(RewriterBase &rewriter,
Value inputT = transposeOperand(input, {0, 3, 1, 2});
Value filterT = transposeOperand(filter, {2, 0, 1});
ArrayRef<int64_t> filterTShape =
- filterT.getType().cast<RankedTensorType>().getShape();
+ cast<RankedTensorType>(filterT.getType()).getShape();
ArrayRef<int64_t> outputShape = outputType.getShape();
int n = outputShape[0];
@@ -360,9 +360,9 @@ rewriteInIm2Col(RewriterBase &rewriter,
FailureOr<std::pair<Operation *, Operation *>>
rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
- auto inputType = convOp.getInputs()[0].getType().cast<ShapedType>();
- auto filterType = convOp.getInputs()[1].getType().cast<ShapedType>();
- auto outputType = convOp.getOutputs()[0].getType().cast<ShapedType>();
+ auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
+ auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
+ auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
if (!filterType.hasStaticShape())
return rewriter.notifyMatchFailure(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
index 3ec5094ed90b5..a81a48df00b69 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -66,12 +66,12 @@ static Operation *movePaddingToFillOrGenericOp(RewriterBase &rewriter,
Attribute constYieldedValue;
// Is the yielded value a bbArg defined outside of the PadOp?
bool outsideBbArg =
- yieldedValue.isa<BlockArgument>() &&
- yieldedValue.cast<BlockArgument>().getOwner()->getParentOp() !=
+ isa<BlockArgument>(yieldedValue) &&
+ cast<BlockArgument>(yieldedValue).getOwner()->getParentOp() !=
padOp.getOperation();
// Is the yielded value an OpResult defined outside of the PadOp?
bool outsideOpResult =
- yieldedValue.isa<OpResult>() &&
+ isa<OpResult>(yieldedValue) &&
yieldedValue.getDefiningOp()->getParentOp() != padOp.getOperation();
bool invariantYieldedValue = outsideBbArg || outsideOpResult;
if (matchPattern(yieldedValue, m_Constant(&constYieldedValue))) {
@@ -120,19 +120,19 @@ static Operation *movePaddingToFillOrGenericOp(RewriterBase &rewriter,
static SmallVector<Value> reifyOrComputeDynamicSizes(OpBuilder &b,
Value value) {
- auto tensorType = value.getType().cast<RankedTensorType>();
+ auto tensorType = cast<RankedTensorType>(value.getType());
if (tensorType.hasStaticShape())
return {};
// Try to reify dynamic sizes.
ReifiedRankedShapedTypeDims reifiedShape;
- if (value.isa<OpResult>() &&
+ if (isa<OpResult>(value) &&
succeeded(reifyResultShapes(b, value.getDefiningOp(), reifiedShape))) {
SmallVector<Value> dynSizes;
for (int64_t i = 0; i < tensorType.getRank(); ++i) {
if (tensorType.isDynamicDim(i))
dynSizes.push_back(
- reifiedShape[value.cast<OpResult>().getResultNumber()][i]
+ reifiedShape[cast<OpResult>(value).getResultNumber()][i]
.get<Value>());
}
return dynSizes;
@@ -153,12 +153,12 @@ static Value createAllocationForTensor(RewriterBase &rewriter, Location loc,
Value value,
Attribute memorySpace = {}) {
OpBuilder::InsertionGuard g(rewriter);
- auto tensorType = value.getType().cast<RankedTensorType>();
+ auto tensorType = cast<RankedTensorType>(value.getType());
// Create buffer allocation.
- auto memrefType = bufferization::getMemRefTypeWithStaticIdentityLayout(
- tensorType, memorySpace)
- .cast<MemRefType>();
+ auto memrefType =
+ cast<MemRefType>(bufferization::getMemRefTypeWithStaticIdentityLayout(
+ tensorType, memorySpace));
SmallVector<Value> dynamicSizes = reifyOrComputeDynamicSizes(rewriter, value);
Value alloc = rewriter.create<memref::AllocOp>(loc, memrefType, dynamicSizes);
@@ -206,7 +206,7 @@ FailureOr<Operation *> mlir::linalg::rewriteInDestinationPassingStyle(
RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp) {
Location loc = fromElementsOp.getLoc();
RankedTensorType tensorType =
- fromElementsOp.getType().cast<RankedTensorType>();
+ cast<RankedTensorType>(fromElementsOp.getType());
auto shape = tensorType.getShape();
// Create tensor.empty.
@@ -247,7 +247,7 @@ mlir::linalg::rewriteInDestinationPassingStyle(RewriterBase &rewriter,
return failure();
Location loc = generateOp.getLoc();
- RankedTensorType tensorType = generateOp.getType().cast<RankedTensorType>();
+ RankedTensorType tensorType = cast<RankedTensorType>(generateOp.getType());
// Create tensor.empty.
auto emptyOp =
@@ -339,7 +339,7 @@ Value linalg::bufferizeToAllocation(RewriterBase &rewriter, Value value,
llvm::map_range(value.getUses(), [](OpOperand &use) { return &use; }));
OpBuilder::InsertionGuard g(rewriter);
- if (auto bbArg = value.dyn_cast<BlockArgument>()) {
+ if (auto bbArg = dyn_cast<BlockArgument>(value)) {
rewriter.setInsertionPointToStart(bbArg.getOwner());
} else {
rewriter.setInsertionPointAfter(value.getDefiningOp());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index e5764cb30e031..1ddd8b144c60e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -640,7 +640,7 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp) {
auto loc = genericOp.getLoc();
Value unPackDest = producerUnPackOp.getDest();
auto genericOutType =
- genericOp.getDpsInitOperand(0)->get().getType().cast<RankedTensorType>();
+ cast<RankedTensorType>(genericOp.getDpsInitOperand(0)->get().getType());
if (producerUnPackOp.getDestType() != genericOutType ||
!genericOutType.hasStaticShape()) {
unPackDest = tensor::UnPackOp::createDestinationTensor(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
index e381b0aa011cb..42f87a16c92f3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
@@ -132,12 +132,12 @@ SmallVector<OpFoldResult> permuteValues(ArrayRef<OpFoldResult> values,
static Value getZero(OpBuilder &b, Location loc, Type elementType) {
assert(elementType.isIntOrIndexOrFloat() &&
"expected scalar type while computing zero value");
- if (elementType.isa<IntegerType>())
+ if (isa<IntegerType>(elementType))
return b.create<arith::ConstantIntOp>(loc, 0, elementType);
if (elementType.isIndex())
return b.create<arith::ConstantIndexOp>(loc, 0);
// Assume float.
- auto floatType = elementType.cast<FloatType>();
+ auto floatType = cast<FloatType>(elementType);
return b.create<arith::ConstantFloatOp>(
loc, APFloat::getZero(floatType.getFloatSemantics()), floatType);
}
@@ -179,7 +179,7 @@ DecomposeLinalgOp::createPeeledGenericOp(GenericOp genericOp,
if (resultNumber) {
newInitValues.push_back(
genericOp.getDpsInitOperand(*resultNumber)->get());
- OpResult result = genericOp.getResult(*resultNumber).cast<OpResult>();
+ OpResult result = cast<OpResult>(genericOp.getResult(*resultNumber));
newResultTypes.push_back(result.getType());
peeledGenericOpIndexingMaps.push_back(
genericOp.getIndexingMapMatchingResult(result));
@@ -231,7 +231,7 @@ DecomposeLinalgOp::createResidualGenericOp(GenericOp genericOp,
}));
for (auto resultNum :
llvm::seq<unsigned>(origNumResults, peeledGenericOpNumResults)) {
- OpResult result = peeledGenericOp.getResult(resultNum).cast<OpResult>();
+ OpResult result = cast<OpResult>(peeledGenericOp.getResult(resultNum));
indexingMaps.push_back(
peeledGenericOp.getIndexingMapMatchingResult(result));
}
@@ -348,7 +348,7 @@ DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp,
/// the peeled operation.
SmallVector<Value> replacements;
for (const auto &yieldValue : llvm::enumerate(yieldOp->getOperands())) {
- OpResult opr = yieldValue.value().dyn_cast<OpResult>();
+ OpResult opr = dyn_cast<OpResult>(yieldValue.value());
if (!opr || opr.getOwner() != peeledScalarOperation)
replacements.push_back(residualGenericOp.getResult(yieldValue.index()));
else
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
index 5fd48853875ce..bf91a708ae158 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -32,7 +32,7 @@ static Value sourceMaterializationCallback(OpBuilder &builder, Type type,
ValueRange inputs, Location loc) {
assert(inputs.size() == 1);
auto inputType = inputs[0].getType();
- if (inputType.isa<TensorType>())
+ if (isa<TensorType>(inputType))
return nullptr;
// A detensored value is converted back by creating a new tensor from its
@@ -320,9 +320,9 @@ struct LinalgDetensorize
// * Add the argument to blockArgsToDetensor.
// * Walk the use-def chain backwards to add each predecessor's
// terminator-operands corresponding to currentItem to workList.
- if (currentItem.dyn_cast<BlockArgument>()) {
+ if (dyn_cast<BlockArgument>(currentItem)) {
BlockArgument currentItemBlockArgument =
- currentItem.cast<BlockArgument>();
+ cast<BlockArgument>(currentItem);
Block *ownerBlock = currentItemBlockArgument.getOwner();
// Function arguments are not detensored/converted.
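Since this patch is a purely automated rewrite, some converted sites keep a check-then-cast shape: in `LinalgDetensorize`, `dyn_cast<BlockArgument>(currentItem)` is used only as a boolean and is followed by `cast<BlockArgument>(currentItem)`, and `movePaddingToFillOrGenericOp` pairs `isa<BlockArgument>` with `cast<BlockArgument>`. A common follow-up cleanup (not part of this commit) tests and binds once with `dyn_cast` in the `if` condition. The sketch below shows that idiom with a hypothetical pointer-based `Node`/`ArgNode`/`ResultNode` hierarchy and a hand-written `dyn_cast`, not MLIR's `Value` classes (which are value-semantic handles).

```cpp
#include <cassert>

// Hypothetical pointer-based hierarchy using LLVM-style classof RTTI.
struct Node {
  enum class Kind { Arg, Result };
  explicit Node(Kind k) : kind(k) {}
  Kind kind;
};

struct ArgNode : Node {
  explicit ArgNode(unsigned n) : Node(Kind::Arg), argNumber(n) {}
  static bool classof(const Node *n) { return n->kind == Kind::Arg; }
  unsigned argNumber;
};

struct ResultNode : Node {
  ResultNode() : Node(Kind::Result) {}
  static bool classof(const Node *n) { return n->kind == Kind::Result; }
};

// Minimal dyn_cast: a single classof check, nullptr on mismatch.
template <typename To>
To *dyn_cast(Node *n) {
  return To::classof(n) ? static_cast<To *>(n) : nullptr;
}

// Test-and-bind: one kind check, and the typed pointer is scoped to the if,
// instead of dyn_cast-as-bool followed by a second cast.
int argNumberOr(Node *n, int fallback) {
  if (auto *arg = dyn_cast<ArgNode>(n))
    return static_cast<int>(arg->argNumber);
  return fallback;
}
```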
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 4a2c0a64fc07a..d8eccb9675894 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -308,7 +308,7 @@ struct MoveInitOperandsToInput : public OpRewritePattern<GenericOp> {
for (OpOperand *op : candidates) {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointAfterValue(op->get());
- auto elemType = op->get().getType().cast<ShapedType>().getElementType();
+ auto elemType = cast<ShapedType>(op->get().getType()).getElementType();
auto empty = rewriter.create<tensor::EmptyOp>(
loc, tensor::createDimValues(rewriter, loc, op->get()), elemType);
@@ -387,7 +387,7 @@ replaceUnitExtents(GenericOp genericOp, OpOperand *opOperand,
// Early return for memrefs with affine maps to represent that we will always
// leave them unchanged.
Type actualType = opOperand->get().getType();
- if (auto memref = actualType.dyn_cast<MemRefType>()) {
+ if (auto memref = dyn_cast<MemRefType>(actualType)) {
if (!memref.getLayout().isIdentity())
return std::nullopt;
}
@@ -437,7 +437,7 @@ struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
ArrayRef<ReassociationIndices> reassociation, Location loc,
PatternRewriter &rewriter) const {
// There are no results for memref outputs.
- auto origResultType = origOutput.getType().cast<RankedTensorType>();
+ auto origResultType = cast<RankedTensorType>(origOutput.getType());
if (rankReductionStrategy == RankReductionStrategy::ExtractInsertSlice) {
unsigned rank = origResultType.getRank();
SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
@@ -459,7 +459,7 @@ struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
Value collapseValue(Value operand, ArrayRef<int64_t> targetShape,
ArrayRef<ReassociationIndices> reassociation,
Location loc, PatternRewriter &rewriter) const {
- if (auto memrefType = operand.getType().dyn_cast<MemRefType>()) {
+ if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
if (rankReductionStrategy == RankReductionStrategy::ExtractInsertSlice) {
FailureOr<Value> rankReducingExtract =
memref::SubViewOp::rankReduceIfNeeded(rewriter, loc, operand,
@@ -478,7 +478,7 @@ struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
return rewriter.create<memref::CollapseShapeOp>(loc, targetType, operand,
reassociation);
}
- if (auto tensorType = operand.getType().dyn_cast<RankedTensorType>()) {
+ if (auto tensorType = dyn_cast<RankedTensorType>(operand.getType())) {
if (rankReductionStrategy == RankReductionStrategy::ExtractInsertSlice) {
FailureOr<Value> rankReducingExtract =
tensor::ExtractSliceOp::rankReduceIfNeeded(rewriter, loc, operand,
@@ -502,7 +502,7 @@ struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
PatternRewriter &rewriter) const override {
// Skip the pattern if the op has any tensor with special encoding.
if (llvm::any_of(genericOp->getOperandTypes(), [](Type type) {
- auto tensorType = type.dyn_cast<RankedTensorType>();
+ auto tensorType = dyn_cast<RankedTensorType>(type);
return tensorType && tensorType.getEncoding() != nullptr;
}))
return failure();
@@ -607,11 +607,10 @@ struct RankReducedExtractSliceOp
if (!reassociation ||
reassociation->size() == static_cast<size_t>(resultType.getRank()))
return failure();
- auto rankReducedType =
+ auto rankReducedType = cast<RankedTensorType>(
tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
reassociation->size(), sliceOp.getSourceType(), offsets, sizes,
- strides)
- .cast<RankedTensorType>();
+ strides));
Location loc = sliceOp.getLoc();
Value newSlice = rewriter.create<tensor::ExtractSliceOp>(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index bf728a6ec319b..33ff4a3ecc091 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -87,7 +87,7 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
// type. Producer must have full tensor semantics to avoid potential
// aliasing between producer and consumer memrefs.
if (!producer.hasTensorSemantics() ||
- !fusedOperand->get().getType().isa<RankedTensorType>())
+ !isa<RankedTensorType>(fusedOperand->get().getType()))
return false;
// Verify that
@@ -232,14 +232,14 @@ static void generateFusedElementwiseOpRegion(
// forward the yield operand.
auto producerYieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
unsigned producerResultNumber =
- fusedOperand->get().cast<OpResult>().getResultNumber();
+ cast<OpResult>(fusedOperand->get()).getResultNumber();
Value replacement =
mapper.lookupOrDefault(producerYieldOp.getOperand(producerResultNumber));
// Sanity checks, if replacement is not already in the mapper then it must be
// produced outside.
if (replacement == producerYieldOp.getOperand(producerResultNumber)) {
- if (auto bb = replacement.dyn_cast<BlockArgument>())
+ if (auto bb = dyn_cast<BlockArgument>(replacement))
assert(bb.getOwner() != &producerBlock &&
"yielded block argument must have been mapped");
else
@@ -278,7 +278,7 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
OpOperand *fusedOperand) {
assert(areElementwiseOpsFusable(fusedOperand) &&
"expected elementwise operation pre-conditions to pass");
- auto producerResult = fusedOperand->get().cast<OpResult>();
+ auto producerResult = cast<OpResult>(fusedOperand->get());
auto producer = cast<GenericOp>(producerResult.getOwner());
auto consumer = cast<GenericOp>(fusedOperand->getOwner());
// TODO: allow fusing the producer of an output operand.
@@ -357,7 +357,7 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
fusedOutputOperands.push_back(opOperand->get());
fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
Type resultType = opOperand->get().getType();
- if (!resultType.isa<MemRefType>())
+ if (!isa<MemRefType>(resultType))
fusedResultTypes.push_back(resultType);
}
@@ -512,7 +512,7 @@ static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
return genericOp.hasTensorSemantics() &&
llvm::all_of(genericOp.getIndexingMaps().getValue(),
[](Attribute attr) {
- return attr.cast<AffineMapAttr>()
+ return cast<AffineMapAttr>(attr)
.getValue()
.isProjectedPermutation();
}) &&
@@ -776,7 +776,7 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
continue;
}
if (auto opOperandType =
- opOperand->get().getType().dyn_cast<RankedTensorType>()) {
+ dyn_cast<RankedTensorType>(opOperand->get().getType())) {
AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
RankedTensorType expandedOperandType =
getExpandedType(opOperandType, indexingMap, expansionInfo);
@@ -805,7 +805,7 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
SmallVector<Value> outputs;
for (OpOperand *opOperand : genericOp.getDpsInitOperands()) {
AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
- auto opOperandType = opOperand->get().getType().cast<RankedTensorType>();
+ auto opOperandType = cast<RankedTensorType>(opOperand->get().getType());
RankedTensorType expandedOutputType =
getExpandedType(opOperandType, indexingMap, expansionInfo);
if (expandedOutputType != opOperand->get().getType()) {
@@ -921,7 +921,7 @@ struct FoldReshapeWithGenericOpByExpansion
LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
PatternRewriter &rewriter) const override {
// Fold only if all constraints of fusing with reshape by expansion are met.
- auto producerResult = reshapeOp.getSrc().dyn_cast<OpResult>();
+ auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
if (!producerResult) {
return rewriter.notifyMatchFailure(reshapeOp,
"source not produced by an operation");
@@ -959,8 +959,9 @@ struct FoldReshapeWithGenericOpByExpansion
// same type as the returns of the original generic op, the consumer reshape
// op can be replaced by the source of the collapse_shape op that defines
// the replacement.
- Value reshapeReplacement = (*replacementValues)
- [reshapeOp.getSrc().cast<OpResult>().getResultNumber()];
+ Value reshapeReplacement =
+ (*replacementValues)[cast<OpResult>(reshapeOp.getSrc())
+ .getResultNumber()];
if (auto collapseOp =
reshapeReplacement.getDefiningOp<tensor::CollapseShapeOp>()) {
reshapeReplacement = collapseOp.getSrc();
@@ -1447,7 +1448,7 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
.createLoopRanges(rewriter, genericOp.getLoc());
auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
if (auto attr = ofr.dyn_cast<Attribute>())
- return attr.cast<IntegerAttr>().getInt() == value;
+ return cast<IntegerAttr>(attr).getInt() == value;
llvm::APInt actual;
return matchPattern(ofr.get<Value>(), m_ConstantInt(&actual)) &&
actual.getSExtValue() == value;
@@ -1521,8 +1522,8 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
Value collapsedOpResult =
collapsedGenericOp->getResult(originalResult.index());
auto originalResultType =
- originalResult.value().getType().cast<ShapedType>();
- auto collapsedOpResultType = collapsedOpResult.getType().cast<ShapedType>();
+ cast<ShapedType>(originalResult.value().getType());
+ auto collapsedOpResultType = cast<ShapedType>(collapsedOpResult.getType());
if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
AffineMap indexingMap =
genericOp.getIndexingMapMatchingResult(originalResult.value());
@@ -1671,7 +1672,7 @@ class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
return false;
};
- auto resultValue = opOperand->get().dyn_cast<OpResult>();
+ auto resultValue = dyn_cast<OpResult>(opOperand->get());
if (!def || !resultValue || !isScalarOrSplatConstantOp(def))
continue;
@@ -1756,7 +1757,7 @@ struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
for (OpOperand *opOperand : op.getDpsInitOperands()) {
if (!op.payloadUsesValueFromOperand(opOperand)) {
Value operandVal = opOperand->get();
- auto operandType = operandVal.getType().dyn_cast<RankedTensorType>();
+ auto operandType = dyn_cast<RankedTensorType>(operandVal.getType());
if (!operandType)
continue;
@@ -1810,7 +1811,7 @@ struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
fillFound = true;
Value fillVal = fillOp.value();
auto resultType =
- fillOp.result().getType().cast<RankedTensorType>().getElementType();
+ cast<RankedTensorType>(fillOp.result().getType()).getElementType();
Value convertedVal =
convertScalarToDtype(rewriter, fillOp.getLoc(), fillVal, resultType,
/*isUnsignedCast =*/false);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
index 549764de593bd..18026cc150337 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
@@ -28,7 +28,7 @@ static bool isElementwiseMappableOpOnRankedTensors(Operation *op) {
// TODO: The conversion pattern can be made to work for `any_of` here, but
// it's more complex as it requires tracking which operands are scalars.
return llvm::all_of(op->getOperandTypes(),
- [](Type type) { return type.isa<RankedTensorType>(); });
+ [](Type type) { return isa<RankedTensorType>(type); });
}
/// Given `op` assumed `isElementwiseMappableOpOnRankedTensors`, iterate over
@@ -67,7 +67,7 @@ getOrCreateOperandsMatchingResultTypes(OpBuilder &b, Operation *op) {
// Extract static / dynamic shape mix from the first operand.
Value firstOperand = operands.front();
- auto rankedTensorType = t.cast<RankedTensorType>();
+ auto rankedTensorType = cast<RankedTensorType>(t);
auto staticShape = llvm::to_vector<4>(rankedTensorType.getShape());
auto dynamicShape = linalg::createDynamicDimensions(b, loc, firstOperand);
@@ -87,7 +87,7 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
return rewriter.notifyMatchFailure(
op, "requires elementwise op on ranked tensors");
- auto rank = op->getResult(0).getType().cast<RankedTensorType>().getRank();
+ auto rank = cast<RankedTensorType>(op->getResult(0).getType()).getRank();
SmallVector<AffineMap, 3> indexingMaps(
op->getNumResults() + op->getNumOperands(),
rewriter.getMultiDimIdentityMap(rank));
@@ -104,7 +104,7 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
[&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
auto resultTypes = llvm::to_vector<6>(
llvm::map_range(op->getResultTypes(), [](Type type) {
- return type.cast<TensorType>().getElementType();
+ return cast<TensorType>(type).getElementType();
}));
auto *scalarOp =
builder.create(loc, op->getName().getIdentifier(),
diff --git a/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp b/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp
index defa027517584..c89fc5b9da8d3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp
@@ -89,7 +89,7 @@ struct DeduplicateAndRemoveDeadOperandsAndResults
Location loc = genericOp.getLoc();
SmallVector<Type> newResultTypes;
for (Value v : newOutputOperands)
- if (v.getType().isa<TensorType>())
+ if (isa<TensorType>(v.getType()))
newResultTypes.push_back(v.getType());
auto newOp = rewriter.create<GenericOp>(
loc, newResultTypes, newInputOperands, newOutputOperands,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp
index b6e2ffcbba368..703db8373c31d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp
@@ -86,12 +86,12 @@ struct FusePadOp : OpRewritePattern<tensor::PadOp> {
// result of the generic op. The low pad values are the offsets, the size of
// the source is the size of the slice.
// TODO: This insert/extract could be potentially made a utility method.
- unsigned resultNumber = source.cast<OpResult>().getResultNumber();
+ unsigned resultNumber = cast<OpResult>(source).getResultNumber();
SmallVector<OpFoldResult> offsets = padOp.getMixedLowPad();
SmallVector<OpFoldResult> sizes;
sizes.reserve(offsets.size());
- for (const auto &shape : llvm::enumerate(
- source.getType().cast<RankedTensorType>().getShape())) {
+ for (const auto &shape :
+ llvm::enumerate(cast<RankedTensorType>(source.getType()).getShape())) {
if (ShapedType::isDynamic(shape.value())) {
sizes.push_back(
rewriter.create<tensor::DimOp>(loc, source, shape.index())
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 6f9b60843d6d5..cf3fd4ba0a0b5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -151,7 +151,7 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
SmallVector<Type, 4> resultTypes;
resultTypes.reserve(producer->getNumResults());
for (OpOperand *operand : producer.getDpsInitOperands()) {
- auto tensorType = operand->get().getType().dyn_cast<RankedTensorType>();
+ auto tensorType = dyn_cast<RankedTensorType>(operand->get().getType());
if (!tensorType)
continue;
unsigned rank = tensorType.getRank();
@@ -210,20 +210,20 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producerOp, AffineMap producerMap,
// dependence tracking since the dependence tracking is similar to what is done
// w.r.t to buffers.
static void getProducerOfTensor(Value tensor, OpResult &opResult) {
- if (!tensor.getType().isa<RankedTensorType>())
+ if (!isa<RankedTensorType>(tensor.getType()))
return;
while (true) {
LLVM_DEBUG(llvm::dbgs() << "\ngetProducerOfTensor: " << tensor);
if (auto linalgOp = tensor.getDefiningOp<LinalgOp>()) {
- opResult = tensor.cast<OpResult>();
+ opResult = cast<OpResult>(tensor);
return;
}
if (auto sliceOp = tensor.getDefiningOp<tensor::ExtractSliceOp>()) {
tensor = sliceOp.getSource();
continue;
}
- if (auto blockArg = tensor.dyn_cast<BlockArgument>()) {
+ if (auto blockArg = dyn_cast<BlockArgument>(tensor)) {
if (auto forOp = blockArg.getDefiningOp<scf::ForOp>()) {
tensor = *(forOp.getIterOperands().begin() + blockArg.getArgNumber());
continue;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index d8ecc807ea051..87aade3a3eec5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -227,7 +227,7 @@ SmallVector<BlockArgument> TileLoopNest::getTiedBBArgs(BlockArgument bbArg) {
return {};
bbArgs.push_back(bbArg);
OpOperand *iterArg = &tileLoop.getOpOperandForRegionIterArg(bbArg);
- bbArg = iterArg->get().dyn_cast<BlockArgument>();
+ bbArg = dyn_cast<BlockArgument>(iterArg->get());
}
// Reverse the block arguments to order them from outer to inner.
@@ -358,13 +358,13 @@ FailureOr<LinalgOp> TileLoopNest::fuseProducer(OpBuilder &b,
// Check if the producer is a LinalgOp possibly passed by iteration argument.
OpOperand *iterArg = nullptr;
- auto producerResult = sliceOp.getSource().dyn_cast<OpResult>();
- if (auto bbArg = sliceOp.getSource().dyn_cast<BlockArgument>()) {
+ auto producerResult = dyn_cast<OpResult>(sliceOp.getSource());
+ if (auto bbArg = dyn_cast<BlockArgument>(sliceOp.getSource())) {
iterArg = getTiedIterArg(bbArg);
// Check the iteration argument may be used to pass in the producer output.
if (!iterArg || hasOtherUses(bbArg, sliceOp))
return failure();
- producerResult = iterArg->get().dyn_cast<OpResult>();
+ producerResult = dyn_cast<OpResult>(iterArg->get());
}
if (!producerResult || !isa<LinalgOp>(producerResult.getOwner()))
return failure();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
index 251f7d8575f75..21d83d225d705 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -549,7 +549,7 @@ static FailureOr<PackingResult> buildPackingLoopNestImpl(
int paddedRank = paddedTensorType.getRank();
// Step 0. Populate bvm with opToHoist.getSource if relevant.
- BlockArgument bbArg = opToHoist.getSource().dyn_cast<BlockArgument>();
+ BlockArgument bbArg = dyn_cast<BlockArgument>(opToHoist.getSource());
while (bbArg) {
auto forOp = dyn_cast<scf::ForOp>(bbArg.getOwner()->getParentOp());
if (!forOp)
@@ -558,7 +558,7 @@ static FailureOr<PackingResult> buildPackingLoopNestImpl(
break;
OpOperand &operand = forOp.getOpOperandForRegionIterArg(bbArg);
bvm.map(bbArg, operand.get());
- bbArg = operand.get().dyn_cast<BlockArgument>();
+ bbArg = dyn_cast<BlockArgument>(operand.get());
}
// Step 1. iteratively clone loops and push `hoistedPackedTensor`.
@@ -754,9 +754,8 @@ static bool tracesBackToExpectedValue(tensor::ExtractSliceOp extractSliceOp,
if (!destOp)
break;
LLVM_DEBUG(DBGS() << "--step dest op: " << destOp << "\n");
- source =
- destOp.getDpsInitOperand(source.cast<OpResult>().getResultNumber())
- ->get();
+ source = destOp.getDpsInitOperand(cast<OpResult>(source).getResultNumber())
+ ->get();
}
LLVM_DEBUG(DBGS() << "--final source: " << source << "\n");
LLVM_DEBUG(DBGS() << "--expected source: " << expectedSource << "\n");
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 13ec4d92ad26d..01b893a0e0a52 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -86,7 +86,7 @@ void mlir::linalg::hoistRedundantVectorTransfers(func::FuncOp func) {
[&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); });
func.walk([&](vector::TransferReadOp transferRead) {
- if (!transferRead.getShapedType().isa<MemRefType>())
+ if (!isa<MemRefType>(transferRead.getShapedType()))
return WalkResult::advance();
LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index 23c831f0a018a..d91d8c4bf6100 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -162,7 +162,7 @@ static void emitScalarImplementation(OpBuilder &b, Location loc,
SmallVector<SmallVector<Value>, 8> indexing;
SmallVector<Value> outputBuffers;
for (OpOperand *outputOperand : linalgOp.getDpsInitOperands()) {
- if (!outputOperand->get().getType().isa<MemRefType>())
+ if (!isa<MemRefType>(outputOperand->get().getType()))
continue;
indexing.push_back(makeCanonicalAffineApplies(
b, loc, linalgOp.getMatchingIndexingMap(outputOperand),
@@ -242,7 +242,7 @@ static FailureOr<LinalgLoops> linalgOpToLoopsImpl(RewriterBase &rewriter,
return failure();
// The induction variable is a block argument of the entry block of the
// loop operation.
- BlockArgument ivVal = iv.dyn_cast<BlockArgument>();
+ BlockArgument ivVal = dyn_cast<BlockArgument>(iv);
if (!ivVal)
return failure();
loopSet.insert(ivVal.getOwner()->getParentOp());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp b/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
index cabd342d86c09..93fa5ff24ac6a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
@@ -44,9 +44,9 @@ matchAndReplaceDepthwiseConv(Operation *operation, Value input, Value kernel,
auto result = operation->getResult(0);
- auto kernelTy = kernel.getType().dyn_cast<RankedTensorType>();
- auto initTy = init.getType().dyn_cast<RankedTensorType>();
- auto resultTy = result.getType().template dyn_cast<RankedTensorType>();
+ auto kernelTy = dyn_cast<RankedTensorType>(kernel.getType());
+ auto initTy = dyn_cast<RankedTensorType>(init.getType());
+ auto resultTy = dyn_cast<RankedTensorType>(result.getType());
if (!kernelTy || !initTy || !resultTy)
return failure();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index 4fcffea14e035..d39cd0e686e00 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -292,9 +292,9 @@ promoteSubViews(ImplicitLocOpBuilder &b,
})
.Case([&](ComplexType t) {
Value tmp;
- if (auto et = t.getElementType().dyn_cast<FloatType>())
+ if (auto et = dyn_cast<FloatType>(t.getElementType()))
tmp = b.create<arith::ConstantOp>(FloatAttr::get(et, 0.0));
- else if (auto et = t.getElementType().cast<IntegerType>())
+ else if (auto et = cast<IntegerType>(t.getElementType()))
tmp = b.create<arith::ConstantOp>(IntegerAttr::get(et, 0));
return b.create<complex::CreateOp>(t, tmp, tmp);
})
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
index 344b2893c906e..203ae437a2a5a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
@@ -93,7 +93,7 @@ linalg::splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension,
{iterationSpace[dimension].offset, iterationSpace[dimension].size,
minSplitPoint});
if (auto attr = remainingSize.dyn_cast<Attribute>()) {
- if (attr.cast<IntegerAttr>().getValue().isZero())
+ if (cast<IntegerAttr>(attr).getValue().isZero())
return {op, TilingInterface()};
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
index b4d95b70de839..982b0243e953a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
@@ -113,7 +113,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
}
Type newType = RankedTensorType::get(
newShape,
- operand->get().getType().cast<RankedTensorType>().getElementType());
+ cast<RankedTensorType>(operand->get().getType()).getElementType());
Value newInput = b.create<tensor::ExpandShapeOp>(
loc, newType, operand->get(), reassociation);
newInputs.push_back(newInput);
@@ -309,7 +309,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
fillOps.reserve(op.getNumDpsInits());
for (auto it : llvm::zip(op.getDpsInitOperands(), neutralElements)) {
Value rankedTensor = std::get<0>(it)->get();
- auto t = rankedTensor.getType().cast<RankedTensorType>();
+ auto t = cast<RankedTensorType>(rankedTensor.getType());
RankedTensorType newT = RankedTensorType::Builder(t).insertDim(
reductionDimSize / splitFactor, insertSplitDimension);
SmallVector<Value> dims =
@@ -383,7 +383,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
combinerOps)) {
Value reindexedOutput = std::get<0>(it);
Value originalOutput = std::get<1>(it)->get();
- auto originalOutputType = originalOutput.getType().cast<RankedTensorType>();
+ auto originalOutputType = cast<RankedTensorType>(originalOutput.getType());
Operation *combinerOp = std::get<2>(it);
AffineMap map = b.getMultiDimIdentityMap(originalOutputType.getRank() + 1);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SubsetHoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/SubsetHoisting.cpp
index c0355a14d366b..f4556787668d4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SubsetHoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SubsetHoisting.cpp
@@ -65,7 +65,7 @@ static FailureOr<tensor::ExtractSliceOp>
findHoistableMatchingExtractSlice(RewriterBase &rewriter,
tensor::InsertSliceOp insertSliceOp,
BlockArgument srcTensor) {
- assert(srcTensor.getType().isa<RankedTensorType>() && "not a ranked tensor");
+ assert(isa<RankedTensorType>(srcTensor.getType()) && "not a ranked tensor");
auto forOp = cast<scf::ForOp>(srcTensor.getOwner()->getParentOp());
@@ -92,7 +92,7 @@ findHoistableMatchingExtractSlice(RewriterBase &rewriter,
// Skip insert_slice whose vector is defined within the loop: we need to
// hoist that definition first otherwise dominance violations trigger.
- if (!extractSliceOp.getSource().isa<BlockArgument>() &&
+ if (!isa<BlockArgument>(extractSliceOp.getSource()) &&
!forOp.isDefinedOutsideOfLoop(extractSliceOp.getSource())) {
LLVM_DEBUG(DBGS() << "------transfer_read vector is loop-dependent\n");
continue;
@@ -119,7 +119,7 @@ static FailureOr<vector::TransferReadOp>
findHoistableMatchingTransferRead(RewriterBase &rewriter,
vector::TransferWriteOp transferWriteOp,
BlockArgument srcTensor) {
- if (!srcTensor.getType().isa<RankedTensorType>())
+ if (!isa<RankedTensorType>(srcTensor.getType()))
return failure();
auto forOp = cast<scf::ForOp>(srcTensor.getOwner()->getParentOp());
@@ -152,7 +152,7 @@ findHoistableMatchingTransferRead(RewriterBase &rewriter,
// transfer_read may be of a vector that is defined within the loop: we
// traverse it by virtue of bypassing disjoint subset operations rooted at
// a bbArg and yielding a matching yield.
- if (!read.getSource().isa<BlockArgument>() &&
+ if (!isa<BlockArgument>(read.getSource()) &&
!forOp.isDefinedOutsideOfLoop(read.getSource())) {
LLVM_DEBUG(DBGS() << "------transfer_read vector appears loop "
"dependent but will be tested for disjointness as "
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 1ff11665b402c..57798fc78ea4b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -49,7 +49,7 @@ static bool isZero(OpFoldResult v) {
if (!v)
return false;
if (auto attr = v.dyn_cast<Attribute>()) {
- IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>();
+ IntegerAttr intAttr = dyn_cast<IntegerAttr>(attr);
return intAttr && intAttr.getValue().isZero();
}
if (auto cst = v.get<Value>().getDefiningOp<arith::ConstantIndexOp>())
@@ -105,7 +105,7 @@ void mlir::linalg::transformIndexOps(
static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b,
OpFoldResult value) {
if (auto attr = value.dyn_cast<Attribute>()) {
- assert(attr.cast<IntegerAttr>().getValue().isStrictlyPositive() &&
+ assert(cast<IntegerAttr>(attr).getValue().isStrictlyPositive() &&
"expected strictly positive tile size and divisor");
return;
}
@@ -587,8 +587,8 @@ tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef<OpFoldResult> tileSizes,
SmallVector<Operation *, 8> loops;
loops.reserve(ivs.size());
for (auto iv : ivs) {
- if (iv.isa<BlockArgument>()) {
- loops.push_back(iv.cast<BlockArgument>().getOwner()->getParentOp());
+ if (isa<BlockArgument>(iv)) {
+ loops.push_back(cast<BlockArgument>(iv).getOwner()->getParentOp());
assert(loops.back() && "no owner found for induction variable!");
} else {
// TODO: Instead of doing this, try to recover the ops used instead of the
@@ -712,7 +712,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
outOffsets[reductionDim] = forallOp.getInductionVars().front();
// TODO: use SubsetExtractOpInterface once it is available.
tiledDpsInitOperands.push_back(b.create<tensor::ExtractSliceOp>(
- loc, initOperand->get().getType().cast<RankedTensorType>(),
+ loc, cast<RankedTensorType>(initOperand->get().getType()),
destBbArgs[destNum], outOffsets, sizes, strides));
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 1c3745f66cbf1..36f13fa64dccb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -365,8 +365,7 @@ struct LinalgOpPartialReductionInterface
// Then create a new reduction that only reduce the newly added dimension
// from the previous op.
- int64_t intermRank =
- partialReduce[0].getType().cast<ShapedType>().getRank();
+ int64_t intermRank = cast<ShapedType>(partialReduce[0].getType()).getRank();
AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
SmallVector<utils::IteratorType> reductionIteratorTypes;
SmallVector<AffineExpr> exprs;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index a9e8ac0bbabbb..230089582f257 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -89,7 +89,7 @@ static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
// Follow the use-def chain if `currOpOperand` is defined by a LinalgOp.
OpOperand *currOpOperand = opOperand;
while (auto linalgOp = currOpOperand->get().getDefiningOp<LinalgOp>()) {
- OpResult result = currOpOperand->get().cast<OpResult>();
+ OpResult result = cast<OpResult>(currOpOperand->get());
currOpOperand = linalgOp.getDpsInitOperand(result.getResultNumber());
}
@@ -133,7 +133,7 @@ static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
// If the size is an attribute add it directly to `paddedShape`.
if (en.value().is<Attribute>()) {
paddedShape[shapeIdx++] =
- en.value().get<Attribute>().dyn_cast<IntegerAttr>().getInt();
+ dyn_cast<IntegerAttr>(en.value().get<Attribute>()).getInt();
LLVM_DEBUG(
DBGS() << "------dim is an attr, add it to padded shape, SKIP\n");
continue;
@@ -232,7 +232,7 @@ linalg::rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
for (const auto &en : llvm::enumerate(paddedOp->getResults())) {
Value paddedResult = en.value();
int64_t resultNumber = en.index();
- int64_t rank = paddedResult.getType().cast<RankedTensorType>().getRank();
+ int64_t rank = cast<RankedTensorType>(paddedResult.getType()).getRank();
SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> sizes;
SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
@@ -476,7 +476,7 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
tensor::PackOp packOp) {
// 1. Filter out NYI cases.
auto packedTensorType =
- packOp->getResultTypes().front().cast<RankedTensorType>();
+ cast<RankedTensorType>(packOp->getResultTypes().front());
if (llvm::any_of(packOp.getStaticInnerTiles(),
[](int64_t size) { return ShapedType::isDynamic(size); })) {
return rewriter.notifyMatchFailure(
@@ -639,7 +639,7 @@ FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
int64_t packedRank = packedTensorType.getRank();
OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
- auto destTensorType = unPackOp.getDest().getType().cast<RankedTensorType>();
+ auto destTensorType = cast<RankedTensorType>(unPackOp.getDest().getType());
if (unPackOp.isLikeUnPad()) {
// This unpack is just a plain unpad.
// Just extract the slice from the higher ranked tensor.
@@ -889,7 +889,7 @@ static LinalgOp transposeOneLinalgOperandAndReplace(
// Sanity check of the expected transposed tensor type.
auto tensorType = permuteShape(
- opOperand.get().getType().cast<RankedTensorType>(), permutation);
+ cast<RankedTensorType>(opOperand.get().getType()), permutation);
(void)tensorType;
assert(tensorType == transposedValue.getType() &&
"expected tensor type mismatch");
@@ -1050,8 +1050,8 @@ LogicalResult
PadOpTransformationPattern::matchAndRewrite(tensor::PadOp padOp,
PatternRewriter &rewriter) const {
- auto inputShapedType = padOp.getSource().getType().cast<ShapedType>();
- auto resultShapedType = padOp.getResult().getType().cast<ShapedType>();
+ auto inputShapedType = cast<ShapedType>(padOp.getSource().getType());
+ auto resultShapedType = cast<ShapedType>(padOp.getResult().getType());
// Bail on non-static shapes.
if (!inputShapedType.hasStaticShape())
@@ -1068,7 +1068,7 @@ PadOpTransformationPattern::matchAndRewrite(tensor::PadOp padOp,
Operation *definingOp = padValue.getDefiningOp();
if (definingOp && definingOp->getBlock() == &block)
return failure();
- if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
+ if (!definingOp && cast<BlockArgument>(padValue).getOwner() == &block)
return failure();
// Create tensor with the padded shape
@@ -1134,7 +1134,7 @@ GeneralizePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
return val;
return rewriter
.create<arith::ConstantIndexOp>(
- padOp.getLoc(), ofr.get<Attribute>().cast<IntegerAttr>().getInt())
+ padOp.getLoc(), cast<IntegerAttr>(ofr.get<Attribute>()).getInt())
.getResult();
};
@@ -1514,9 +1514,9 @@ FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
Value kernel = convOp.getInputs().back();
Value output = convOp.getOutputs().front();
- auto inputType = input.getType().dyn_cast<RankedTensorType>();
- auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
- auto outputType = output.getType().dyn_cast<RankedTensorType>();
+ auto inputType = dyn_cast<RankedTensorType>(input.getType());
+ auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
+ auto outputType = dyn_cast<RankedTensorType>(output.getType());
auto kernelShape = kernelType.getShape();
auto outputShape = outputType.getShape();
@@ -1638,9 +1638,9 @@ DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
Value kernel = convOp.getInputs().back();
Value output = convOp.getOutputs().front();
- auto inputType = input.getType().dyn_cast<RankedTensorType>();
- auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
- auto outputType = output.getType().dyn_cast<RankedTensorType>();
+ auto inputType = dyn_cast<RankedTensorType>(input.getType());
+ auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
+ auto outputType = dyn_cast<RankedTensorType>(output.getType());
auto kernelShape = kernelType.getShape();
auto outputShape = outputType.getShape();
@@ -1706,9 +1706,9 @@ DownscaleConv2DOp::returningMatchAndRewrite(Conv2DOp convOp,
Value kernel = convOp.getInputs().back();
Value output = convOp.getOutputs().front();
- auto inputType = input.getType().dyn_cast<RankedTensorType>();
- auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
- auto outputType = output.getType().dyn_cast<RankedTensorType>();
+ auto inputType = dyn_cast<RankedTensorType>(input.getType());
+ auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
+ auto outputType = dyn_cast<RankedTensorType>(output.getType());
auto kernelShape = kernelType.getShape();
auto outputShape = outputType.getShape();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 56b4516452a11..2236d1bff1118 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -563,7 +563,7 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value,
loc, value, outputOperand->get(), indices, writeMap);
} else {
// 0-d case is still special: do not invert the reindexing writeMap.
- if (!value.getType().isa<VectorType>())
+ if (!isa<VectorType>(value.getType()))
value = rewriter.create<vector::BroadcastOp>(loc, vectorType, value);
assert(value.getType() == vectorType && "incorrect type");
write = rewriter.create<vector::TransferWriteOp>(
@@ -864,7 +864,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
targetShape.back() == 1)
return VectorMemoryAccessKind::Gather;
- auto inputShape = extractOp.getTensor().getType().cast<ShapedType>();
+ auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
// 2. Assume that it's a gather load when reading _from_ a tensor for which
// the trailing dimension is 1, e.g. `tensor<1x4x1xi32>`.
@@ -1024,8 +1024,8 @@ static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
const IRMapping &bvm) {
Value reduceVec = bvm.lookup(reduceValue);
Value outputVec = bvm.lookup(initialValue);
- auto reduceType = reduceVec.getType().dyn_cast<VectorType>();
- auto outputType = outputVec.getType().dyn_cast<VectorType>();
+ auto reduceType = dyn_cast<VectorType>(reduceVec.getType());
+ auto outputType = dyn_cast<VectorType>(outputVec.getType());
// Reduce only if needed as the value may already have been reduce for
// contraction vectorization.
if (!reduceType ||
@@ -1082,7 +1082,7 @@ vectorizeOneOp(RewriterBase &rewriter, LinalgOp linalgOp, Operation *op,
// 4 . Check if the operation is a reduction.
SmallVector<std::pair<Value, Value>> reductionOperands;
for (Value operand : op->getOperands()) {
- auto blockArg = operand.dyn_cast<BlockArgument>();
+ auto blockArg = dyn_cast<BlockArgument>(operand);
if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
continue;
@@ -1107,7 +1107,7 @@ vectorizeOneOp(RewriterBase &rewriter, LinalgOp linalgOp, Operation *op,
// a. first get the first max ranked shape.
SmallVector<int64_t, 4> firstMaxRankedShape;
for (Value operand : op->getOperands()) {
- auto vt = bvm.lookup(operand).getType().dyn_cast<VectorType>();
+ auto vt = dyn_cast<VectorType>(bvm.lookup(operand).getType());
if (vt && firstMaxRankedShape.size() < vt.getShape().size())
firstMaxRankedShape.assign(vt.getShape().begin(), vt.getShape().end());
}
@@ -1230,7 +1230,7 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
// 3.c. Not all ops support 0-d vectors, extract the scalar for now.
// TODO: remove this.
- if (readValue.getType().cast<VectorType>().getRank() == 0)
+ if (cast<VectorType>(readValue.getType()).getRank() == 0)
readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue
@@ -1528,8 +1528,8 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, LinalgOp linalgOp,
LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
memref::CopyOp copyOp) {
- auto srcType = copyOp.getSource().getType().cast<MemRefType>();
- auto dstType = copyOp.getTarget().getType().cast<MemRefType>();
+ auto srcType = cast<MemRefType>(copyOp.getSource().getType());
+ auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
return failure();
@@ -1549,7 +1549,7 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
Value readValue = rewriter.create<vector::TransferReadOp>(
loc, readType, copyOp.getSource(), indices,
rewriter.getMultiDimIdentityMap(srcType.getRank()));
- if (readValue.getType().cast<VectorType>().getRank() == 0) {
+ if (cast<VectorType>(readValue.getType()).getRank() == 0) {
readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
}
@@ -1566,7 +1566,7 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
/// Helper function that retrieves the value of an IntegerAttr.
static int64_t getIntFromAttr(Attribute attr) {
- return attr.cast<IntegerAttr>().getInt();
+ return cast<IntegerAttr>(attr).getInt();
}
/// Given an ArrayRef of OpFoldResults, return a vector of Values.
@@ -1836,8 +1836,8 @@ struct PadOpVectorizationWithTransferWritePattern
if (hasSameTensorSize(castOp.getSource(), afterTrimming))
return true;
- auto t1 = beforePadding.getType().dyn_cast<RankedTensorType>();
- auto t2 = afterTrimming.getType().dyn_cast<RankedTensorType>();
+ auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType());
+ auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
// Only RankedTensorType supported.
if (!t1 || !t2)
return false;
@@ -1946,7 +1946,7 @@ struct PadOpVectorizationWithInsertSlicePattern
if (!padValue)
return failure();
// Dynamic shapes not supported.
- if (!padOp.getResult().getType().cast<ShapedType>().hasStaticShape())
+ if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
return failure();
// Pad result not used as destination.
if (insertOp.getDest() == padOp.getResult())
@@ -2074,7 +2074,7 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
memref::CopyOp copyOp;
for (auto &u : subView.getUses()) {
if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
- assert(newCopyOp.getTarget().getType().isa<MemRefType>());
+ assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
if (newCopyOp.getTarget() != subView)
continue;
if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
@@ -2091,7 +2091,7 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
FillOp maybeFillOp;
for (auto &u : viewOrAlloc.getUses()) {
if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
- assert(newFillOp.output().getType().isa<MemRefType>());
+ assert(isa<MemRefType>(newFillOp.output().getType()));
if (newFillOp.output() != viewOrAlloc)
continue;
if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
@@ -2162,7 +2162,7 @@ LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
return rewriter.notifyMatchFailure(xferOp, "no copy found");
// `out` is the subview copied into that we replace.
- assert(copyOp.getTarget().getType().isa<MemRefType>());
+ assert(isa<MemRefType>(copyOp.getTarget().getType()));
Value out = copyOp.getTarget();
// Forward vector.transfer into copy.
@@ -2204,7 +2204,7 @@ static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
namespace {
bool isCastOfBlockArgument(Operation *op) {
return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
- op->getOperand(0).isa<BlockArgument>();
+ isa<BlockArgument>(op->getOperand(0));
}
bool isSupportedPoolKind(vector::CombiningKind kind) {
@@ -2268,9 +2268,9 @@ struct Conv1DGenerator
lhsShaped = linalgOp.getDpsInputOperand(0)->get();
rhsShaped = linalgOp.getDpsInputOperand(1)->get();
resShaped = linalgOp.getDpsInitOperand(0)->get();
- lhsShapedType = lhsShaped.getType().dyn_cast<ShapedType>();
- rhsShapedType = rhsShaped.getType().dyn_cast<ShapedType>();
- resShapedType = resShaped.getType().dyn_cast<ShapedType>();
+ lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
+ rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
+ resShapedType = dyn_cast<ShapedType>(resShaped.getType());
if (!lhsShapedType || !rhsShapedType || !resShapedType)
return;
// (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
@@ -2717,8 +2717,8 @@ struct Conv1DGenerator
/// Lower lhs{n, w, c} * rhs{c} -> res{n, w, c} to MulAcc
Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
Value lhs, Value rhs, Value res) {
- auto rhsTy = rhs.getType().cast<ShapedType>();
- auto resTy = res.getType().cast<ShapedType>();
+ auto rhsTy = cast<ShapedType>(rhs.getType());
+ auto resTy = cast<ShapedType>(res.getType());
// TODO(suderman): Change this to use a vector.ima intrinsic.
lhs = promote(rewriter, loc, lhs, resTy);
@@ -2730,7 +2730,7 @@ struct Conv1DGenerator
if (!lhs || !rhs)
return nullptr;
- if (resTy.getElementType().isa<FloatType>())
+ if (isa<FloatType>(resTy.getElementType()))
return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);
auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
@@ -2863,15 +2863,14 @@ struct Conv1DGenerator
// Otherwise, check for one or zero `ext` predecessor. The `ext` operands
// must be block arguments or extension of block arguments.
bool setOperKind(Operation *reduceOp) {
- int numBlockArguments =
- llvm::count_if(reduceOp->getOperands(),
- [](Value v) { return v.isa<BlockArgument>(); });
+ int numBlockArguments = llvm::count_if(
+ reduceOp->getOperands(), [](Value v) { return isa<BlockArgument>(v); });
switch (numBlockArguments) {
case 1: {
// Will be convolution if feeder is a MulOp.
// Otherwise, if it can be pooling.
auto feedValIt = llvm::find_if(reduceOp->getOperands(), [](Value v) {
- return !v.isa<BlockArgument>();
+ return !isa<BlockArgument>(v);
});
Operation *feedOp = (*feedValIt).getDefiningOp();
if (isCastOfBlockArgument(feedOp)) {
@@ -2880,7 +2879,7 @@ struct Conv1DGenerator
poolExtOp = feedOp->getName().getIdentifier();
} else if (!(isa<arith::MulIOp, arith::MulFOp>(feedOp) &&
llvm::all_of(feedOp->getOperands(), [](Value v) {
- if (v.isa<BlockArgument>())
+ if (isa<BlockArgument>(v))
return true;
if (Operation *op = v.getDefiningOp())
return isCastOfBlockArgument(op);
diff --git a/mlir/lib/Dialect/Linalg/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Linalg/Utils/IndexingUtils.cpp
index 12b55ef2e6603..f7376c0d9602c 100644
--- a/mlir/lib/Dialect/Linalg/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/IndexingUtils.cpp
@@ -43,16 +43,16 @@
namespace mlir {
namespace linalg {
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim) {
- if (val.getType().isa<UnrankedMemRefType, MemRefType>())
+ if (isa<UnrankedMemRefType, MemRefType>(val.getType()))
return b.createOrFold<memref::DimOp>(loc, val, dim);
- if (val.getType().isa<UnrankedTensorType, RankedTensorType>())
+ if (isa<UnrankedTensorType, RankedTensorType>(val.getType()))
return b.createOrFold<tensor::DimOp>(loc, val, dim);
llvm_unreachable("Expected MemRefType or TensorType");
}
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val,
int64_t dim) {
- auto shapedType = val.getType().cast<ShapedType>();
+ auto shapedType = cast<ShapedType>(val.getType());
if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
return createOrFoldDimOp(b, loc, val, dim);
return b.getIndexAttr(shapedType.getDimSize(dim));
@@ -60,7 +60,7 @@ OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val,
SmallVector<Value> createDynamicDimensions(OpBuilder &b, Location loc,
Value val) {
- auto shapedType = val.getType().cast<ShapedType>();
+ auto shapedType = cast<ShapedType>(val.getType());
assert(shapedType.hasRank() && "`val` must have a static rank");
SmallVector<Value> res;
res.reserve(shapedType.getRank());
@@ -73,7 +73,7 @@ SmallVector<Value> createDynamicDimensions(OpBuilder &b, Location loc,
SmallVector<OpFoldResult> getMixedDimensions(OpBuilder &b, Location loc,
Value val) {
- auto shapedType = val.getType().cast<ShapedType>();
+ auto shapedType = cast<ShapedType>(val.getType());
assert(shapedType.hasRank() && "`val` must have a static rank");
SmallVector<Value> dynamicDims = createDynamicDimensions(b, loc, val);
return getMixedValues(shapedType.getShape(), dynamicDims, b);
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 5e3413accf7c3..ef31668ed25b1 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -281,7 +281,7 @@ Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
auto linalgOp = current.getDefiningOp<LinalgOp>();
if (!linalgOp)
break;
- OpResult opResult = current.cast<OpResult>();
+ OpResult opResult = cast<OpResult>(current);
current = linalgOp.getDpsInitOperand(opResult.getResultNumber())->get();
}
auto padOp = current ? current.getDefiningOp<tensor::PadOp>() : nullptr;
@@ -331,7 +331,7 @@ Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
GenericOp makeTransposeOp(OpBuilder &b, Location loc, Value inputTensor,
Value outputTensor,
ArrayRef<int64_t> transposeVector) {
- auto resultTensorType = outputTensor.getType().cast<RankedTensorType>();
+ auto resultTensorType = cast<RankedTensorType>(outputTensor.getType());
Type elementType = resultTensorType.getElementType();
assert(isPermutationVector(transposeVector) &&
@@ -366,9 +366,9 @@ GenericOp makeTransposeOp(OpBuilder &b, Location loc, Value inputTensor,
}
GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to) {
- auto memrefTypeTo = to.getType().cast<MemRefType>();
+ auto memrefTypeTo = cast<MemRefType>(to.getType());
#ifndef NDEBUG
- auto memrefTypeFrom = from.getType().cast<MemRefType>();
+ auto memrefTypeFrom = cast<MemRefType>(from.getType());
assert(memrefTypeFrom.getRank() == memrefTypeTo.getRank() &&
"`from` and `to` memref must have the same rank");
#endif // NDEBUG
@@ -650,7 +650,7 @@ void GenerateLoopNest<scf::ParallelOp>::doit(
static Value materializeTiledShape(OpBuilder &builder, Location loc,
Value valueToTile,
const SliceParameters &sliceParams) {
- auto shapedType = valueToTile.getType().dyn_cast<ShapedType>();
+ auto shapedType = dyn_cast<ShapedType>(valueToTile.getType());
auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
.Case([&](MemRefType) {
return builder.create<memref::SubViewOp>(
@@ -685,7 +685,7 @@ computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
ArrayRef<OpFoldResult> subShapeSizes,
bool omitPartialTileCheck) {
- auto shapedType = valueToTile.getType().dyn_cast<ShapedType>();
+ auto shapedType = dyn_cast<ShapedType>(valueToTile.getType());
assert(shapedType && "only shaped types can be tiled");
ArrayRef<int64_t> shape = shapedType.getShape();
int64_t rank = shapedType.getRank();
@@ -889,7 +889,7 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
// subdomains explicit.
Type operandType = opOperand.get().getType();
- if (!isTiled(map, tileSizes) && !(operandType.isa<RankedTensorType>() &&
+ if (!isTiled(map, tileSizes) && !(isa<RankedTensorType>(operandType) &&
linalgOp.isDpsInit(&opOperand))) {
allSliceParams.push_back(std::nullopt);
LLVM_DEBUG(llvm::dbgs()
@@ -971,7 +971,7 @@ getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
auto size = it.value();
curr.push_back(dim);
auto attr = size.dyn_cast<Attribute>();
- if (attr && attr.cast<IntegerAttr>().getInt() == 1)
+ if (attr && cast<IntegerAttr>(attr).getInt() == 1)
continue;
reassociation.emplace_back(ReassociationIndices{});
std::swap(reassociation.back(), curr);
@@ -989,7 +989,7 @@ std::optional<TypedAttr> getNeutralElement(Operation *op) {
// Builder only used as helper for attribute creation.
OpBuilder b(op->getContext());
Type resultType = op->getResult(0).getType();
- if (auto floatType = resultType.dyn_cast<FloatType>()) {
+ if (auto floatType = dyn_cast<FloatType>(resultType)) {
const llvm::fltSemantics &semantic = floatType.getFloatSemantics();
if (isa<arith::AddFOp>(op))
return b.getFloatAttr(resultType, llvm::APFloat::getZero(semantic));
diff --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
index c5e008e520473..dcace489673f0 100644
--- a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
@@ -64,7 +64,7 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
// Maybe broadcasts scalar value into vector type compatible with `op`.
auto bcast = [&](Value value) -> Value {
- if (auto vec = op.getType().dyn_cast<VectorType>())
+ if (auto vec = dyn_cast<VectorType>(op.getType()))
return rewriter.create<vector::BroadcastOp>(op.getLoc(), vec, value);
return value;
};
@@ -167,7 +167,7 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
// Maybe broadcasts scalar value into vector type compatible with `op`.
auto bcast = [&loc, &op, &rewriter](Value value) -> Value {
- if (auto vec = op.getType().template dyn_cast<VectorType>())
+ if (auto vec = dyn_cast<VectorType>(op.getType()))
return rewriter.create<vector::BroadcastOp>(loc, vec, value);
return value;
};
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 6d286a31290e6..a3efc6ef41a95 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -40,7 +40,7 @@ using namespace mlir::vector;
// Returns vector shape if the type is a vector. Returns an empty shape if it is
// not a vector.
static ArrayRef<int64_t> vectorShape(Type type) {
- auto vectorType = type.dyn_cast<VectorType>();
+ auto vectorType = dyn_cast<VectorType>(type);
return vectorType ? vectorType.getShape() : ArrayRef<int64_t>();
}
@@ -54,14 +54,14 @@ static ArrayRef<int64_t> vectorShape(Value value) {
// Broadcasts scalar type into vector type (iff shape is non-scalar).
static Type broadcast(Type type, ArrayRef<int64_t> shape) {
- assert(!type.isa<VectorType>() && "must be scalar type");
+ assert(!isa<VectorType>(type) && "must be scalar type");
return !shape.empty() ? VectorType::get(shape, type) : type;
}
// Broadcasts scalar value into vector (iff shape is non-scalar).
static Value broadcast(ImplicitLocOpBuilder &builder, Value value,
ArrayRef<int64_t> shape) {
- assert(!value.getType().isa<VectorType>() && "must be scalar value");
+ assert(!isa<VectorType>(value.getType()) && "must be scalar value");
auto type = broadcast(value.getType(), shape);
return !shape.empty() ? builder.create<BroadcastOp>(type, value) : value;
}
@@ -92,7 +92,7 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
assert(!operands.empty() && "operands must be not empty");
assert(vectorWidth > 0 && "vector width must be larger than 0");
- VectorType inputType = operands[0].getType().cast<VectorType>();
+ VectorType inputType = cast<VectorType>(operands[0].getType());
ArrayRef<int64_t> inputShape = inputType.getShape();
// If input shape matches target vector width, we can just call the
@@ -118,7 +118,7 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
for (unsigned i = 0; i < operands.size(); ++i) {
auto operand = operands[i];
- auto eltType = operand.getType().cast<VectorType>().getElementType();
+ auto eltType = cast<VectorType>(operand.getType()).getElementType();
auto expandedType = VectorType::get(expandedShape, eltType);
expandedOperands[i] =
builder.create<vector::ShapeCastOp>(expandedType, operand);
@@ -145,7 +145,7 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
}
// Stitch results together into one large vector.
- Type resultEltType = results[0].getType().cast<VectorType>().getElementType();
+ Type resultEltType = cast<VectorType>(results[0].getType()).getElementType();
Type resultExpandedType = VectorType::get(expandedShape, resultEltType);
Value result = builder.create<arith::ConstantOp>(
resultExpandedType, builder.getZeroAttr(resultExpandedType));
@@ -318,9 +318,9 @@ LogicalResult insertCasts(Operation *op, PatternRewriter &rewriter) {
// Create F32 equivalent type.
Type newType;
- if (auto shaped = origType.dyn_cast<ShapedType>()) {
+ if (auto shaped = dyn_cast<ShapedType>(origType)) {
newType = shaped.clone(rewriter.getF32Type());
- } else if (origType.isa<FloatType>()) {
+ } else if (isa<FloatType>(origType)) {
newType = rewriter.getF32Type();
} else {
return rewriter.notifyMatchFailure(op,
diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
index 7f702e1978540..ae2472db4f862 100644
--- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
+++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
@@ -69,7 +69,7 @@ DiagnosedSilenceableFailure transform::MemRefMultiBufferOp::apply(
results.push_back(*newBuffer);
}
- transformResults.set(getResult().cast<OpResult>(), results);
+ transformResults.set(cast<OpResult>(getResult()), results);
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
index 369f22521895d..9b1d85b290274 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
@@ -57,7 +57,7 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
// always 1.
if (llvm::all_of(strides, [](OpFoldResult &valueOrAttr) {
Attribute attr = valueOrAttr.dyn_cast<Attribute>();
- return attr && attr.cast<IntegerAttr>().getInt() == 1;
+ return attr && cast<IntegerAttr>(attr).getInt() == 1;
})) {
strides = SmallVector<OpFoldResult>(sourceOp.getMixedStrides().size(),
rewriter.getI64IntegerAttr(1));
@@ -93,8 +93,8 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
// If both offsets are static we can simply calculate the combined
// offset statically.
offsets.push_back(rewriter.getI64IntegerAttr(
- opOffsetAttr.cast<IntegerAttr>().getInt() +
- sourceOffsetAttr.cast<IntegerAttr>().getInt()));
+ cast<IntegerAttr>(opOffsetAttr).getInt() +
+ cast<IntegerAttr>(sourceOffsetAttr).getInt()));
} else {
// When either offset is dynamic, we must emit an additional affine
// transformation to add the two offsets together dynamically.
@@ -102,7 +102,7 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
SmallVector<Value> affineApplyOperands;
for (auto valueOrAttr : {opOffset, sourceOffset}) {
if (auto attr = valueOrAttr.dyn_cast<Attribute>()) {
- expr = expr + attr.cast<IntegerAttr>().getInt();
+ expr = expr + cast<IntegerAttr>(attr).getInt();
} else {
expr =
expr + rewriter.getAffineSymbolExpr(affineApplyOperands.size());
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
index 6202b5730c218..57f0141c95dc5 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
@@ -149,7 +149,7 @@ void memref::populateMemRefWideIntEmulationConversions(
arith::WideIntEmulationConverter &typeConverter) {
typeConverter.addConversion(
[&typeConverter](MemRefType ty) -> std::optional<Type> {
- auto intTy = ty.getElementType().dyn_cast<IntegerType>();
+ auto intTy = dyn_cast<IntegerType>(ty.getElementType());
if (!intTy)
return ty;
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
index 38fb11348f285..8a276ebbff6a9 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
@@ -89,11 +89,11 @@ struct MemRefReshapeOpConverter : public OpRewritePattern<memref::ReshapeOp> {
LogicalResult matchAndRewrite(memref::ReshapeOp op,
PatternRewriter &rewriter) const final {
- auto shapeType = op.getShape().getType().cast<MemRefType>();
+ auto shapeType = cast<MemRefType>(op.getShape().getType());
if (!shapeType.hasStaticShape())
return failure();
- int64_t rank = shapeType.cast<MemRefType>().getDimSize(0);
+ int64_t rank = cast<MemRefType>(shapeType).getDimSize(0);
SmallVector<OpFoldResult, 4> sizes, strides;
sizes.resize(rank);
strides.resize(rank);
@@ -106,7 +106,7 @@ struct MemRefReshapeOpConverter : public OpRewritePattern<memref::ReshapeOp> {
if (op.getType().isDynamicDim(i)) {
Value index = rewriter.create<arith::ConstantIndexOp>(loc, i);
size = rewriter.create<memref::LoadOp>(loc, op.getShape(), index);
- if (!size.getType().isa<IndexType>())
+ if (!isa<IndexType>(size.getType()))
size = rewriter.create<arith::IndexCastOp>(
loc, rewriter.getIndexType(), size);
sizes[i] = size;
@@ -141,7 +141,7 @@ struct ExpandOpsPass : public memref::impl::ExpandOpsBase<ExpandOpsPass> {
op.getKind() != arith::AtomicRMWKind::minf;
});
target.addDynamicallyLegalOp<memref::ReshapeOp>([](memref::ReshapeOp op) {
- return !op.getShape().getType().cast<MemRefType>().hasStaticShape();
+ return !cast<MemRefType>(op.getShape().getType()).hasStaticShape();
});
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index ea372bffbc0b7..ff2c4107ee46d 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -62,7 +62,7 @@ resolveSubviewStridedMetadata(RewriterBase &rewriter,
// Build a plain extract_strided_metadata(memref) from subview(memref).
Location origLoc = subview.getLoc();
Value source = subview.getSource();
- auto sourceType = source.getType().cast<MemRefType>();
+ auto sourceType = cast<MemRefType>(source.getType());
unsigned sourceRank = sourceType.getRank();
auto newExtractStridedMetadata =
@@ -115,7 +115,7 @@ resolveSubviewStridedMetadata(RewriterBase &rewriter,
// The final result is <baseBuffer, offset, sizes, strides>.
// Thus we need 1 + 1 + subview.getRank() + subview.getRank(), to hold all
// the values.
- auto subType = subview.getType().cast<MemRefType>();
+ auto subType = cast<MemRefType>(subview.getType());
unsigned subRank = subType.getRank();
// The sizes of the final type are defined directly by the input sizes of
@@ -338,7 +338,7 @@ SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
// Collect the statically known information about the original stride.
Value source = expandShape.getSrc();
- auto sourceType = source.getType().cast<MemRefType>();
+ auto sourceType = cast<MemRefType>(source.getType());
auto [strides, offset] = getStridesAndOffset(sourceType);
OpFoldResult origStride = ShapedType::isDynamic(strides[groupId])
@@ -358,10 +358,9 @@ SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
AffineExpr s0 = builder.getAffineSymbolExpr(0);
AffineExpr s1 = builder.getAffineSymbolExpr(1);
for (; doneStrideIdx < *dynSizeIdx; ++doneStrideIdx) {
- int64_t baseExpandedStride = expandedStrides[doneStrideIdx]
- .get<Attribute>()
- .cast<IntegerAttr>()
- .getInt();
+ int64_t baseExpandedStride =
+ cast<IntegerAttr>(expandedStrides[doneStrideIdx].get<Attribute>())
+ .getInt();
expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
builder, expandShape.getLoc(),
(s0 * baseExpandedStride).floorDiv(productOfAllStaticSizes) * s1,
@@ -372,10 +371,9 @@ SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
// Now apply the origStride to the remaining dimensions.
AffineExpr s0 = builder.getAffineSymbolExpr(0);
for (; doneStrideIdx < groupSize; ++doneStrideIdx) {
- int64_t baseExpandedStride = expandedStrides[doneStrideIdx]
- .get<Attribute>()
- .cast<IntegerAttr>()
- .getInt();
+ int64_t baseExpandedStride =
+ cast<IntegerAttr>(expandedStrides[doneStrideIdx].get<Attribute>())
+ .getInt();
expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
builder, expandShape.getLoc(), s0 * baseExpandedStride, {origStride});
}
@@ -445,7 +443,7 @@ getCollapsedSize(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
// Build the affine expr of the product of the original sizes involved in that
// group.
Value source = collapseShape.getSrc();
- auto sourceType = source.getType().cast<MemRefType>();
+ auto sourceType = cast<MemRefType>(source.getType());
SmallVector<int64_t, 2> reassocGroup =
collapseShape.getReassociationIndices()[groupId];
@@ -479,7 +477,7 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
"Reassociation group should have at least one dimension");
Value source = collapseShape.getSrc();
- auto sourceType = source.getType().cast<MemRefType>();
+ auto sourceType = cast<MemRefType>(source.getType());
auto [strides, offset] = getStridesAndOffset(sourceType);
@@ -562,7 +560,7 @@ struct ReshapeFolder : public OpRewritePattern<ReassociativeReshapeLikeOp> {
// extract_strided_metadata(reassociative_reshape_like(memref)).
Location origLoc = reshape.getLoc();
Value source = reshape.getSrc();
- auto sourceType = source.getType().cast<MemRefType>();
+ auto sourceType = cast<MemRefType>(source.getType());
unsigned sourceRank = sourceType.getRank();
auto newExtractStridedMetadata =
@@ -650,8 +648,7 @@ struct ExtractStridedMetadataOpAllocFolder
if (!allocLikeOp)
return failure();
- auto memRefType =
- allocLikeOp.getResult().getType().template cast<MemRefType>();
+ auto memRefType = cast<MemRefType>(allocLikeOp.getResult().getType());
if (!memRefType.getLayout().isIdentity())
return rewriter.notifyMatchFailure(
allocLikeOp, "alloc-like operations should have been normalized");
@@ -688,7 +685,7 @@ struct ExtractStridedMetadataOpAllocFolder
SmallVector<Value> results;
results.reserve(rank * 2 + 2);
- auto baseBufferType = op.getBaseBuffer().getType().cast<MemRefType>();
+ auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType());
int64_t offset = 0;
if (allocLikeOp.getType() == baseBufferType)
results.push_back(allocLikeOp);
@@ -737,7 +734,7 @@ struct ExtractStridedMetadataOpGetGlobalFolder
if (!getGlobalOp)
return failure();
- auto memRefType = getGlobalOp.getResult().getType().cast<MemRefType>();
+ auto memRefType = cast<MemRefType>(getGlobalOp.getResult().getType());
if (!memRefType.getLayout().isIdentity()) {
return rewriter.notifyMatchFailure(
getGlobalOp,
@@ -759,7 +756,7 @@ struct ExtractStridedMetadataOpGetGlobalFolder
SmallVector<Value> results;
results.reserve(rank * 2 + 2);
- auto baseBufferType = op.getBaseBuffer().getType().cast<MemRefType>();
+ auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType());
int64_t offset = 0;
if (getGlobalOp.getType() == baseBufferType)
results.push_back(getGlobalOp);
@@ -838,8 +835,7 @@ class ExtractStridedMetadataOpReinterpretCastFolder
return rewriter.notifyMatchFailure(
reinterpretCastOp, "reinterpret_cast source's type is incompatible");
- auto memrefType =
- reinterpretCastOp.getResult().getType().cast<MemRefType>();
+ auto memrefType = cast<MemRefType>(reinterpretCastOp.getResult().getType());
unsigned rank = memrefType.getRank();
SmallVector<OpFoldResult> results;
results.resize_for_overwrite(rank * 2 + 2);
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp
index 5141b5f33cfa2..05ba6a3f38708 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp
@@ -120,7 +120,7 @@ template <typename TransferLikeOp>
static FailureOr<Value>
getTransferLikeOpSrcMemRef(TransferLikeOp transferLikeOp) {
Value src = transferLikeOp.getSource();
- if (src.getType().isa<MemRefType>())
+ if (isa<MemRefType>(src.getType()))
return src;
return failure();
}
@@ -240,7 +240,7 @@ struct LoadStoreLikeOpRewriter : public OpRewritePattern<LoadStoreLikeOp> {
return rewriter.notifyMatchFailure(loadStoreLikeOp,
"source is not a memref");
Value srcMemRef = *failureOrSrcMemRef;
- auto ldStTy = srcMemRef.getType().cast<MemRefType>();
+ auto ldStTy = cast<MemRefType>(srcMemRef.getType());
unsigned loadStoreRank = ldStTy.getRank();
// Don't waste compile time if there is nothing to rewrite.
if (loadStoreRank == 0)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index 72675b03abf65..2c30e98dd1070 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -148,7 +148,7 @@ resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
if (collapseShapeOp.getReassociationIndices().empty()) {
auto zeroAffineMap = rewriter.getConstantAffineMap(0);
int64_t srcRank =
- collapseShapeOp.getViewSource().getType().cast<MemRefType>().getRank();
+ cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
for (int64_t i = 0; i < srcRank; i++) {
OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
rewriter, loc, zeroAffineMap, dynamicIndices);
diff --git a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
index aa1d27dc863e6..68b72eff8c973 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
@@ -71,11 +71,9 @@ propagateSubViewOp(RewriterBase &rewriter,
UnrealizedConversionCastOp conversionOp, SubViewOp op) {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(op);
- auto newResultType =
- SubViewOp::inferRankReducedResultType(
- op.getType().getShape(), op.getSourceType(), op.getMixedOffsets(),
- op.getMixedSizes(), op.getMixedStrides())
- .cast<MemRefType>();
+ auto newResultType = cast<MemRefType>(SubViewOp::inferRankReducedResultType(
+ op.getType().getShape(), op.getSourceType(), op.getMixedOffsets(),
+ op.getMixedSizes(), op.getMixedStrides()));
Value newSubview = rewriter.create<SubViewOp>(
op.getLoc(), newResultType, conversionOp.getOperand(0),
op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides());
diff --git a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
index ee1adcce80e5f..eb1df2a87b99a 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
@@ -61,11 +61,11 @@ static void replaceUsesAndPropagateType(RewriterBase &rewriter,
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(subviewUse);
Type newType = memref::SubViewOp::inferRankReducedResultType(
- subviewUse.getType().getShape(), val.getType().cast<MemRefType>(),
+ subviewUse.getType().getShape(), cast<MemRefType>(val.getType()),
subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(),
subviewUse.getStaticStrides());
Value newSubview = rewriter.create<memref::SubViewOp>(
- subviewUse->getLoc(), newType.cast<MemRefType>(), val,
+ subviewUse->getLoc(), cast<MemRefType>(newType), val,
subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(),
subviewUse.getMixedStrides());
@@ -209,9 +209,9 @@ mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp,
for (int64_t i = 0, e = originalShape.size(); i != e; ++i)
sizes[1 + i] = rewriter.getIndexAttr(originalShape[i]);
// Strides is [1, 1 ... 1 ].
- auto dstMemref = memref::SubViewOp::inferRankReducedResultType(
- originalShape, mbMemRefType, offsets, sizes, strides)
- .cast<MemRefType>();
+ auto dstMemref =
+ cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
+ originalShape, mbMemRefType, offsets, sizes, strides));
Value subview = rewriter.create<memref::SubViewOp>(loc, dstMemref, mbAlloc,
offsets, sizes, strides);
LLVM_DEBUG(DBGS() << "--multi-buffered slice: " << subview << "\n");
diff --git a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
index c252433d16fa1..aa21497fad8f8 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
@@ -180,7 +180,7 @@ bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) {
llvm::seq<unsigned>(0, callOp.getNumResults())) {
Value oldMemRef = callOp.getResult(resIndex);
if (auto oldMemRefType =
- oldMemRef.getType().dyn_cast<MemRefType>())
+ dyn_cast<MemRefType>(oldMemRef.getType()))
if (!oldMemRefType.getLayout().isIdentity() &&
!isMemRefNormalizable(oldMemRef.getUsers()))
return WalkResult::interrupt();
@@ -192,7 +192,7 @@ bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) {
for (unsigned argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
BlockArgument oldMemRef = funcOp.getArgument(argIndex);
- if (auto oldMemRefType = oldMemRef.getType().dyn_cast<MemRefType>())
+ if (auto oldMemRefType = dyn_cast<MemRefType>(oldMemRef.getType()))
if (!oldMemRefType.getLayout().isIdentity() &&
!isMemRefNormalizable(oldMemRef.getUsers()))
return false;
@@ -226,7 +226,7 @@ void NormalizeMemRefs::updateFunctionSignature(func::FuncOp funcOp,
funcOp.walk([&](func::ReturnOp returnOp) {
for (const auto &operandEn : llvm::enumerate(returnOp.getOperands())) {
Type opType = operandEn.value().getType();
- MemRefType memrefType = opType.dyn_cast<MemRefType>();
+ MemRefType memrefType = dyn_cast<MemRefType>(opType);
// If type is not memref or if the memref type is same as that in
// function's return signature then no update is required.
if (!memrefType || memrefType == resultTypes[operandEn.index()])
@@ -284,7 +284,7 @@ void NormalizeMemRefs::updateFunctionSignature(func::FuncOp funcOp,
if (oldResult.getType() == newResult.getType())
continue;
AffineMap layoutMap =
- oldResult.getType().cast<MemRefType>().getLayout().getAffineMap();
+ cast<MemRefType>(oldResult.getType()).getLayout().getAffineMap();
if (failed(replaceAllMemRefUsesWith(oldResult, /*newMemRef=*/newResult,
/*extraIndices=*/{},
/*indexRemap=*/layoutMap,
@@ -358,7 +358,7 @@ void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp,
for (unsigned argIndex :
llvm::seq<unsigned>(0, functionType.getNumInputs())) {
Type argType = functionType.getInput(argIndex);
- MemRefType memrefType = argType.dyn_cast<MemRefType>();
+ MemRefType memrefType = dyn_cast<MemRefType>(argType);
// Check whether argument is of MemRef type. Any other argument type can
// simply be part of the final function signature.
if (!memrefType) {
@@ -422,11 +422,11 @@ void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp,
// Replace all uses of the old memrefs.
Value oldMemRef = op->getResult(resIndex);
Value newMemRef = newOp->getResult(resIndex);
- MemRefType oldMemRefType = oldMemRef.getType().dyn_cast<MemRefType>();
+ MemRefType oldMemRefType = dyn_cast<MemRefType>(oldMemRef.getType());
// Check whether the operation result is MemRef type.
if (!oldMemRefType)
continue;
- MemRefType newMemRefType = newMemRef.getType().cast<MemRefType>();
+ MemRefType newMemRefType = cast<MemRefType>(newMemRef.getType());
if (oldMemRefType == newMemRefType)
continue;
// TODO: Assume single layout map. Multiple maps not supported.
@@ -466,7 +466,7 @@ void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp,
for (unsigned resIndex :
llvm::seq<unsigned>(0, functionType.getNumResults())) {
Type resType = functionType.getResult(resIndex);
- MemRefType memrefType = resType.dyn_cast<MemRefType>();
+ MemRefType memrefType = dyn_cast<MemRefType>(resType);
// Check whether result is of MemRef type. Any other argument type can
// simply be part of the final function signature.
if (!memrefType) {
@@ -507,7 +507,7 @@ Operation *NormalizeMemRefs::createOpResultsNormalized(func::FuncOp funcOp,
bool resultTypeNormalized = false;
for (unsigned resIndex : llvm::seq<unsigned>(0, oldOp->getNumResults())) {
auto resultType = oldOp->getResult(resIndex).getType();
- MemRefType memrefType = resultType.dyn_cast<MemRefType>();
+ MemRefType memrefType = dyn_cast<MemRefType>(resultType);
// Check whether the operation result is MemRef type.
if (!memrefType) {
resultTypes.push_back(resultType);
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
index 8c544bbd9fb0d..526c1c6e198ff 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -40,7 +40,7 @@ struct DimOfShapedTypeOpInterface : public OpRewritePattern<OpTy> {
LogicalResult matchAndRewrite(OpTy dimOp,
PatternRewriter &rewriter) const override {
- OpResult dimValue = dimOp.getSource().template dyn_cast<OpResult>();
+ OpResult dimValue = dyn_cast<OpResult>(dimOp.getSource());
if (!dimValue)
return failure();
auto shapedTypeOp =
@@ -61,8 +61,8 @@ struct DimOfShapedTypeOpInterface : public OpRewritePattern<OpTy> {
return failure();
Value resultShape = reifiedResultShapes[dimValue.getResultNumber()];
- auto resultShapeType = resultShape.getType().dyn_cast<RankedTensorType>();
- if (!resultShapeType || !resultShapeType.getElementType().isa<IndexType>())
+ auto resultShapeType = dyn_cast<RankedTensorType>(resultShape.getType());
+ if (!resultShapeType || !isa<IndexType>(resultShapeType.getElementType()))
return failure();
Location loc = dimOp->getLoc();
@@ -82,7 +82,7 @@ struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
LogicalResult matchAndRewrite(OpTy dimOp,
PatternRewriter &rewriter) const override {
- OpResult dimValue = dimOp.getSource().template dyn_cast<OpResult>();
+ OpResult dimValue = dyn_cast<OpResult>(dimOp.getSource());
if (!dimValue)
return failure();
std::optional<int64_t> dimIndex = dimOp.getConstantIndex();
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index 9ffb315587e3b..05a069d98ef35 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -38,14 +38,14 @@ struct CastOpInterface
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
Location loc) const {
auto castOp = cast<CastOp>(op);
- auto srcType = castOp.getSource().getType().cast<BaseMemRefType>();
+ auto srcType = cast<BaseMemRefType>(castOp.getSource().getType());
// Nothing to check if the result is an unranked memref.
- auto resultType = castOp.getType().dyn_cast<MemRefType>();
+ auto resultType = dyn_cast<MemRefType>(castOp.getType());
if (!resultType)
return;
- if (srcType.isa<UnrankedMemRefType>()) {
+ if (isa<UnrankedMemRefType>(srcType)) {
// Check rank.
Value srcRank = builder.create<RankOp>(loc, castOp.getSource());
Value resultRank =
@@ -75,7 +75,7 @@ struct CastOpInterface
// Check dimension sizes.
for (const auto &it : llvm::enumerate(resultType.getShape())) {
// Static dim size -> static/dynamic dim size does not need verification.
- if (auto rankedSrcType = srcType.dyn_cast<MemRefType>())
+ if (auto rankedSrcType = dyn_cast<MemRefType>(srcType))
if (!rankedSrcType.isDynamicDim(it.index()))
continue;
diff --git a/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp b/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp
index 292738de4b52a..b9dd174a6b253 100644
--- a/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp
+++ b/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp
@@ -42,7 +42,7 @@ struct MmaSyncF32ToTF32Pattern : public OpRewritePattern<nvgpu::MmaSyncOp> {
Location location = op->getLoc();
if (op->hasAttr(op.getTf32EnabledAttrName()) ||
- !op.getMatrixA().getType().cast<VectorType>().getElementType().isF32())
+ !cast<VectorType>(op.getMatrixA().getType()).getElementType().isF32())
return failure();
if (precision == MmaSyncF32Lowering::Unkown)
diff --git a/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp
index 07e9ae9f8650d..486c786892c28 100644
--- a/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp
+++ b/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp
@@ -180,7 +180,7 @@ getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
mlir::LogicalResult
mlir::nvgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
Value memrefValue) {
- auto memRefType = memrefValue.getType().dyn_cast<MemRefType>();
+ auto memRefType = dyn_cast<MemRefType>(memrefValue.getType());
if (!memRefType || !NVGPUDialect::hasSharedMemoryAddressSpace(memRefType))
return failure();
diff --git a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
index 7525f9f57bc5f..5a0018c315176 100644
--- a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
+++ b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
@@ -63,7 +63,7 @@ FailureOr<WarpMatrixInfo> nvgpu::getWarpMatrixInfo(Operation *op) {
info.vectorType = writeOp.getVectorType();
} else if (isa<vector::TransferReadOp, vector::ContractionOp,
vector::ExtractStridedSliceOp, arith::ConstantOp>(op)) {
- info.vectorType = op->getResult(0).getType().cast<VectorType>();
+ info.vectorType = cast<VectorType>(op->getResult(0).getType());
} else {
return op->emitError()
<< "unhandled operation type in nvgpu.mma.sync conversion path";
diff --git a/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp b/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
index ddd8ae0c0fd35..5ee53eaad5850 100644
--- a/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
+++ b/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
@@ -14,13 +14,13 @@ using namespace mlir;
using namespace mlir::quant;
static bool isQuantizablePrimitiveType(Type inputType) {
- return inputType.isa<FloatType>();
+ return isa<FloatType>(inputType);
}
ExpressedToQuantizedConverter
ExpressedToQuantizedConverter::forInputType(Type inputType) {
- if (inputType.isa<TensorType, VectorType>()) {
- Type elementType = inputType.cast<ShapedType>().getElementType();
+ if (isa<TensorType, VectorType>(inputType)) {
+ Type elementType = cast<ShapedType>(inputType).getElementType();
if (!isQuantizablePrimitiveType(elementType))
return ExpressedToQuantizedConverter{inputType, nullptr};
return ExpressedToQuantizedConverter{inputType, elementType};
@@ -34,11 +34,11 @@ ExpressedToQuantizedConverter::forInputType(Type inputType) {
Type ExpressedToQuantizedConverter::convert(QuantizedType elementalType) const {
assert(expressedType && "convert() on unsupported conversion");
- if (auto tensorType = inputType.dyn_cast<RankedTensorType>())
+ if (auto tensorType = dyn_cast<RankedTensorType>(inputType))
return RankedTensorType::get(tensorType.getShape(), elementalType);
- if (auto tensorType = inputType.dyn_cast<UnrankedTensorType>())
+ if (auto tensorType = dyn_cast<UnrankedTensorType>(inputType))
return UnrankedTensorType::get(elementalType);
- if (auto vectorType = inputType.dyn_cast<VectorType>())
+ if (auto vectorType = dyn_cast<VectorType>(inputType))
return VectorType::get(vectorType.getShape(), elementalType);
// If the expressed types match, just use the new elemental type.
@@ -50,7 +50,7 @@ Type ExpressedToQuantizedConverter::convert(QuantizedType elementalType) const {
ElementsAttr
UniformQuantizedPerAxisValueConverter::convert(Attribute realValue) {
- if (auto attr = realValue.dyn_cast<DenseFPElementsAttr>()) {
+ if (auto attr = dyn_cast<DenseFPElementsAttr>(realValue)) {
return convert(attr);
}
// TODO: handles sparse elements attribute
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 18425dea7b19f..2da7473bf6595 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -49,7 +49,7 @@ transform::GetParentForOp::apply(transform::TransformResults &results,
}
parents.insert(loop);
}
- results.set(getResult().cast<OpResult>(), parents.getArrayRef());
+ results.set(cast<OpResult>(getResult()), parents.getArrayRef());
return DiagnosedSilenceableFailure::success();
}
@@ -116,8 +116,8 @@ transform::LoopOutlineOp::apply(transform::TransformResults &results,
functions.push_back(*outlined);
calls.push_back(call);
}
- results.set(getFunction().cast<OpResult>(), functions);
- results.set(getCall().cast<OpResult>(), calls);
+ results.set(cast<OpResult>(getFunction()), functions);
+ results.set(cast<OpResult>(getCall()), calls);
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 13f0d769ef4cc..ad395a9ac457b 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -30,8 +30,8 @@ namespace {
/// Helper function for loop bufferization. Cast the given buffer to the given
/// memref type.
static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
- assert(type.isa<BaseMemRefType>() && "expected BaseMemRefType");
- assert(buffer.getType().isa<BaseMemRefType>() && "expected BaseMemRefType");
+ assert(isa<BaseMemRefType>(type) && "expected BaseMemRefType");
+ assert(isa<BaseMemRefType>(buffer.getType()) && "expected BaseMemRefType");
// If the buffer already has the correct type, no cast is needed.
if (buffer.getType() == type)
return buffer;
@@ -78,7 +78,7 @@ struct ConditionOpInterface
SmallVector<Value> newArgs;
for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
Value value = it.value();
- if (value.getType().isa<TensorType>()) {
+ if (isa<TensorType>(value.getType())) {
FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
if (failed(maybeBuffer))
return failure();
@@ -141,7 +141,7 @@ struct ExecuteRegionOpInterface
rewriter.setInsertionPointAfter(newOp);
SmallVector<Value> newResults;
for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
- if (it.value().isa<TensorType>()) {
+ if (isa<TensorType>(it.value())) {
newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
executeRegionOp.getLoc(), newOp->getResult(it.index())));
} else {
@@ -183,7 +183,7 @@ struct IfOpInterface
// Compute bufferized result types.
SmallVector<Type> newTypes;
for (Value result : ifOp.getResults()) {
- if (!result.getType().isa<TensorType>()) {
+ if (!isa<TensorType>(result.getType())) {
newTypes.push_back(result.getType());
continue;
}
@@ -218,13 +218,13 @@ struct IfOpInterface
assert(value.getDefiningOp() == op && "invalid valid");
// Determine buffer types of the true/false branches.
- auto opResult = value.cast<OpResult>();
+ auto opResult = cast<OpResult>(value);
auto thenValue = thenYieldOp.getOperand(opResult.getResultNumber());
auto elseValue = elseYieldOp.getOperand(opResult.getResultNumber());
BaseMemRefType thenBufferType, elseBufferType;
- if (thenValue.getType().isa<BaseMemRefType>()) {
+ if (isa<BaseMemRefType>(thenValue.getType())) {
// True branch was already bufferized.
- thenBufferType = thenValue.getType().cast<BaseMemRefType>();
+ thenBufferType = cast<BaseMemRefType>(thenValue.getType());
} else {
auto maybeBufferType =
bufferization::getBufferType(thenValue, options, fixedTypes);
@@ -232,9 +232,9 @@ struct IfOpInterface
return failure();
thenBufferType = *maybeBufferType;
}
- if (elseValue.getType().isa<BaseMemRefType>()) {
+ if (isa<BaseMemRefType>(elseValue.getType())) {
// False branch was already bufferized.
- elseBufferType = elseValue.getType().cast<BaseMemRefType>();
+ elseBufferType = cast<BaseMemRefType>(elseValue.getType());
} else {
auto maybeBufferType =
bufferization::getBufferType(elseValue, options, fixedTypes);
@@ -253,7 +253,7 @@ struct IfOpInterface
// Layout maps are different: Promote to fully dynamic layout map.
return getMemRefTypeWithFullyDynamicLayout(
- opResult.getType().cast<TensorType>(), thenBufferType.getMemorySpace());
+ cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace());
}
};
@@ -262,7 +262,7 @@ struct IfOpInterface
static DenseSet<int64_t> getTensorIndices(ValueRange values) {
DenseSet<int64_t> result;
for (const auto &it : llvm::enumerate(values))
- if (it.value().getType().isa<TensorType>())
+ if (isa<TensorType>(it.value().getType()))
result.insert(it.index());
return result;
}
@@ -275,8 +275,8 @@ DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size());
DenseSet<int64_t> result;
for (unsigned int i = 0; i < minSize; ++i) {
- if (!bbArgs[i].getType().isa<TensorType>() ||
- !yieldedValues[i].getType().isa<TensorType>())
+ if (!isa<TensorType>(bbArgs[i].getType()) ||
+ !isa<TensorType>(yieldedValues[i].getType()))
continue;
if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i]))
result.insert(i);
@@ -291,7 +291,7 @@ getBuffers(RewriterBase &rewriter, MutableArrayRef<OpOperand> operands,
const BufferizationOptions &options) {
SmallVector<Value> result;
for (OpOperand &opOperand : operands) {
- if (opOperand.get().getType().isa<TensorType>()) {
+ if (isa<TensorType>(opOperand.get().getType())) {
FailureOr<Value> resultBuffer =
getBuffer(rewriter, opOperand.get(), options);
if (failed(resultBuffer))
@@ -361,9 +361,9 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
// Compute the buffer type of the yielded value.
BaseMemRefType yieldedValueBufferType;
- if (yieldedValue.getType().isa<BaseMemRefType>()) {
+ if (isa<BaseMemRefType>(yieldedValue.getType())) {
// scf.yield was already bufferized.
- yieldedValueBufferType = yieldedValue.getType().cast<BaseMemRefType>();
+ yieldedValueBufferType = cast<BaseMemRefType>(yieldedValue.getType());
} else {
auto maybeBufferType =
bufferization::getBufferType(yieldedValue, options, newFixedTypes);
@@ -379,7 +379,7 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
// If there is a mismatch between the yielded buffer type and the iter_arg
// buffer type, the buffer type must be promoted to a fully dynamic layout
// map.
- auto yieldedRanked = yieldedValueBufferType.cast<MemRefType>();
+ auto yieldedRanked = cast<MemRefType>(yieldedValueBufferType);
#ifndef NDEBUG
auto iterRanked = initArgBufferType->cast<MemRefType>();
assert(llvm::equal(yieldedRanked.getShape(), iterRanked.getShape()) &&
@@ -388,7 +388,7 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
"expected same memory space");
#endif // NDEBUG
return getMemRefTypeWithFullyDynamicLayout(
- iterArg.getType().cast<RankedTensorType>(),
+ cast<RankedTensorType>(iterArg.getType()),
yieldedRanked.getMemorySpace());
}
@@ -516,16 +516,16 @@ struct ForOpInterface
const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
auto forOp = cast<scf::ForOp>(op);
assert(getOwnerOfValue(value) == op && "invalid value");
- assert(value.getType().isa<TensorType>() && "expected tensor type");
+ assert(isa<TensorType>(value.getType()) && "expected tensor type");
// Get result/argument number.
unsigned resultNum;
- if (auto bbArg = value.dyn_cast<BlockArgument>()) {
+ if (auto bbArg = dyn_cast<BlockArgument>(value)) {
resultNum =
forOp.getResultForOpOperand(forOp.getOpOperandForRegionIterArg(bbArg))
.getResultNumber();
} else {
- resultNum = value.cast<OpResult>().getResultNumber();
+ resultNum = cast<OpResult>(value).getResultNumber();
}
// Compute the bufferized type.
@@ -560,7 +560,7 @@ struct ForOpInterface
Value initArg = it.value();
Value result = forOp->getResult(it.index());
// If the type is not a tensor, bufferization doesn't need to touch it.
- if (!result.getType().isa<TensorType>()) {
+ if (!isa<TensorType>(result.getType())) {
castedInitArgs.push_back(initArg);
continue;
}
@@ -611,7 +611,7 @@ struct ForOpInterface
auto yieldOp =
cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
for (OpResult opResult : op->getOpResults()) {
- if (!opResult.getType().isa<TensorType>())
+ if (!isa<TensorType>(opResult.getType()))
continue;
// Note: This is overly strict. We should check for aliasing bufferized
@@ -736,7 +736,7 @@ struct WhileOpInterface
for (int64_t idx = 0;
idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) {
Value value = conditionOp.getArgs()[idx];
- if (!value.getType().isa<TensorType>() ||
+ if (!isa<TensorType>(value.getType()) ||
(equivalentYieldsAfter.contains(idx) &&
equivalentYieldsBefore.contains(idx))) {
beforeYieldValues.push_back(value);
@@ -786,7 +786,7 @@ struct WhileOpInterface
Value initArg = it.value();
Value beforeArg = whileOp.getBeforeArguments()[it.index()];
// If the type is not a tensor, bufferization doesn't need to touch it.
- if (!beforeArg.getType().isa<TensorType>()) {
+ if (!isa<TensorType>(beforeArg.getType())) {
castedInitArgs.push_back(initArg);
continue;
}
@@ -799,7 +799,7 @@ struct WhileOpInterface
// The result types of a WhileOp are the same as the "after" bbArg types.
SmallVector<Type> argsTypesAfter = llvm::to_vector(
llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
- if (!bbArg.getType().isa<TensorType>())
+ if (!isa<TensorType>(bbArg.getType()))
return bbArg.getType();
// TODO: error handling
return bufferization::getBufferType(bbArg, options)->cast<Type>();
@@ -848,10 +848,10 @@ struct WhileOpInterface
const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
auto whileOp = cast<scf::WhileOp>(op);
assert(getOwnerOfValue(value) == op && "invalid value");
- assert(value.getType().isa<TensorType>() && "expected tensor type");
+ assert(isa<TensorType>(value.getType()) && "expected tensor type");
// Case 1: Block argument of the "before" region.
- if (auto bbArg = value.dyn_cast<BlockArgument>()) {
+ if (auto bbArg = dyn_cast<BlockArgument>(value)) {
if (bbArg.getOwner()->getParent() == &whileOp.getBefore()) {
Value initArg = whileOp.getInits()[bbArg.getArgNumber()];
auto yieldOp = whileOp.getYieldOp();
@@ -865,18 +865,18 @@ struct WhileOpInterface
// The bufferized "after" bbArg type can be directly computed from the
// bufferized "before" bbArg type.
unsigned resultNum;
- if (auto opResult = value.dyn_cast<OpResult>()) {
+ if (auto opResult = dyn_cast<OpResult>(value)) {
resultNum = opResult.getResultNumber();
- } else if (value.cast<BlockArgument>().getOwner()->getParent() ==
+ } else if (cast<BlockArgument>(value).getOwner()->getParent() ==
&whileOp.getAfter()) {
- resultNum = value.cast<BlockArgument>().getArgNumber();
+ resultNum = cast<BlockArgument>(value).getArgNumber();
} else {
llvm_unreachable("invalid value");
}
Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
- if (!conditionYieldedVal.getType().isa<TensorType>()) {
+ if (!isa<TensorType>(conditionYieldedVal.getType())) {
// scf.condition was already bufferized.
- return conditionYieldedVal.getType().cast<BaseMemRefType>();
+ return cast<BaseMemRefType>(conditionYieldedVal.getType());
}
return bufferization::getBufferType(conditionYieldedVal, options,
fixedTypes);
@@ -902,7 +902,7 @@ struct WhileOpInterface
auto conditionOp = whileOp.getConditionOp();
for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
- if (!it.value().getType().isa<TensorType>())
+ if (!isa<TensorType>(it.value().getType()))
continue;
if (!state.areEquivalentBufferizedValues(
it.value(), conditionOp->getBlock()->getArgument(it.index())))
@@ -913,7 +913,7 @@ struct WhileOpInterface
auto yieldOp = whileOp.getYieldOp();
for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
- if (!it.value().getType().isa<TensorType>())
+ if (!isa<TensorType>(it.value().getType()))
continue;
if (!state.areEquivalentBufferizedValues(
it.value(), yieldOp->getBlock()->getArgument(it.index())))
@@ -971,7 +971,7 @@ struct YieldOpInterface
SmallVector<Value> newResults;
for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
Value value = it.value();
- if (value.getType().isa<TensorType>()) {
+ if (isa<TensorType>(value.getType())) {
FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
if (failed(maybeBuffer))
return failure();
@@ -1110,7 +1110,7 @@ struct ForallOpInterface
const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
auto forallOp = cast<ForallOp>(op);
- if (auto bbArg = value.dyn_cast<BlockArgument>())
+ if (auto bbArg = dyn_cast<BlockArgument>(value))
// A tensor block argument has the same bufferized type as the
// corresponding output operand.
return bufferization::getBufferType(
@@ -1119,8 +1119,8 @@ struct ForallOpInterface
// The bufferized result type is the same as the bufferized type of the
// corresponding output operand.
return bufferization::getBufferType(
- forallOp.getOutputs()[value.cast<OpResult>().getResultNumber()],
- options, fixedTypes);
+ forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()], options,
+ fixedTypes);
}
bool isRepetitiveRegion(Operation *op, unsigned index) const {
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
index 2450a0e5fb347..99591493d132c 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
@@ -43,7 +43,7 @@ static bool isShapePreserving(ForOp forOp, int64_t arg) {
while (value) {
if (value == forOp.getRegionIterArgs()[arg])
return true;
- OpResult opResult = value.dyn_cast<OpResult>();
+ OpResult opResult = dyn_cast<OpResult>(value);
if (!opResult)
return false;
@@ -91,7 +91,7 @@ struct DimOfIterArgFolder : public OpRewritePattern<OpTy> {
LogicalResult matchAndRewrite(OpTy dimOp,
PatternRewriter &rewriter) const override {
- auto blockArg = dimOp.getSource().template dyn_cast<BlockArgument>();
+ auto blockArg = dyn_cast<BlockArgument>(dimOp.getSource());
if (!blockArg)
return failure();
auto forOp = dyn_cast<ForOp>(blockArg.getParentBlock()->getParentOp());
@@ -139,7 +139,7 @@ struct DimOfLoopResultFolder : public OpRewritePattern<OpTy> {
auto forOp = dimOp.getSource().template getDefiningOp<scf::ForOp>();
if (!forOp)
return failure();
- auto opResult = dimOp.getSource().template cast<OpResult>();
+ auto opResult = cast<OpResult>(dimOp.getSource());
unsigned resultNumber = opResult.getResultNumber();
if (!isShapePreserving(forOp, resultNumber))
return failure();
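These `DimOfIterArgFolder`/`DimOfLoopResultFolder` hunks also drop the `.template` disambiguator. Inside a pattern templated on `OpTy`, `dimOp.getSource()` has a dependent type, so calling a member template on it generally has to be spelled `.template dyn_cast<...>()`. The free function needs no such keyword. Below is a minimal illustration using hypothetical `Handle`, `IntHandle`, and `FakeOp` types, which are not MLIR classes.

```cpp
#include <cassert>

// Hypothetical value handle offering a member-template check, mimicking the
// old `value.isa<T>()` method style.
struct Handle {
  int kind;
  template <typename T> bool isa() const { return T::classof(*this); }
};

// Hypothetical target class with the classof() hook.
struct IntHandle {
  static bool classof(const Handle &h) { return h.kind == 1; }
};

// Free-function form of the same check.
template <typename T> bool isa(const Handle &h) { return T::classof(h); }

// In a template, op.get() is type-dependent: the member call is written with
// `template` so that `<` is parsed as a template argument list.
template <typename OpTy> bool viaMethod(const OpTy &op) {
  return op.get().template isa<IntHandle>();
}

// The free function is found by ordinary lookup; no disambiguator needed.
template <typename OpTy> bool viaFreeFunction(const OpTy &op) {
  return isa<IntHandle>(op.get());
}

// Hypothetical op type used to instantiate the templates.
struct FakeOp {
  Handle h;
  Handle get() const { return h; }
};
```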
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index 6a9f725211427..a85985b84a037 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -164,8 +164,7 @@ cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op,
clone->walk([&](Operation *nested) {
for (OpOperand &operand : nested->getOpOperands()) {
Operation *def = operand.get().getDefiningOp();
- if ((def && !clone->isAncestor(def)) ||
- operand.get().isa<BlockArgument>())
+ if ((def && !clone->isAncestor(def)) || isa<BlockArgument>(operand.get()))
callback(&operand);
}
});
@@ -346,7 +345,7 @@ void LoopPipelinerInternal::createKernel(
rewriter.setInsertionPointAfter(newOp);
continue;
}
- auto arg = operand->get().dyn_cast<BlockArgument>();
+ auto arg = dyn_cast<BlockArgument>(operand->get());
if (arg && arg.getOwner() == forOp.getBody()) {
// If the value is a loop carried value coming from stage N + 1 remap,
// it will become a direct use.
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 131e8216ef5d2..224bec3b26d29 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -496,7 +496,7 @@ getUntiledProducerFromSliceSource(OpOperand *source,
ArrayRef<scf::ForOp> loops) {
std::optional<OpOperand *> destinationIterArg;
auto loopIt = loops.rbegin();
- while (auto iterArg = source->get().dyn_cast<BlockArgument>()) {
+ while (auto iterArg = dyn_cast<BlockArgument>(source->get())) {
scf::ForOp loop = *loopIt;
if (iterArg.getOwner()->getParentOp() != loop)
break;
@@ -505,7 +505,7 @@ getUntiledProducerFromSliceSource(OpOperand *source,
}
if (loopIt == loops.rend())
destinationIterArg = source;
- return {source->get().dyn_cast<OpResult>(), destinationIterArg};
+ return {dyn_cast<OpResult>(source->get()), destinationIterArg};
}
/// Implementation of fusing producer of a single slice by computing the
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp
index f154840b6f659..c22cb6710a7e5 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp
@@ -42,8 +42,8 @@ class SPIRVGlobalVariableOpLayoutInfoDecoration
PatternRewriter &rewriter) const override {
SmallVector<NamedAttribute, 4> globalVarAttrs;
- auto ptrType = op.getType().cast<spirv::PointerType>();
- auto pointeeType = ptrType.getPointeeType().cast<spirv::StructType>();
+ auto ptrType = cast<spirv::PointerType>(op.getType());
+ auto pointeeType = cast<spirv::StructType>(ptrType.getPointeeType());
spirv::StructType structType = VulkanLayoutUtils::decorateType(pointeeType);
if (!structType)
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index c0ab2152675ee..9f2755da09229 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -51,19 +51,19 @@ createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp,
// info create a variable of type !spirv.ptr<!spirv.struct<elementType>>. If
// not it must already be a !spirv.ptr<!spirv.struct<...>>.
auto varType = funcOp.getFunctionType().getInput(argIndex);
- if (varType.cast<spirv::SPIRVType>().isScalarOrVector()) {
+ if (cast<spirv::SPIRVType>(varType).isScalarOrVector()) {
auto storageClass = abiInfo.getStorageClass();
if (!storageClass)
return nullptr;
varType =
spirv::PointerType::get(spirv::StructType::get(varType), *storageClass);
}
- auto varPtrType = varType.cast<spirv::PointerType>();
- auto varPointeeType = varPtrType.getPointeeType().cast<spirv::StructType>();
+ auto varPtrType = cast<spirv::PointerType>(varType);
+ auto varPointeeType = cast<spirv::StructType>(varPtrType.getPointeeType());
// Set the offset information.
varPointeeType =
- VulkanLayoutUtils::decorateType(varPointeeType).cast<spirv::StructType>();
+ cast<spirv::StructType>(VulkanLayoutUtils::decorateType(varPointeeType));
if (!varPointeeType)
return nullptr;
@@ -98,7 +98,7 @@ getInterfaceVariables(spirv::FuncOp funcOp,
// Starting with version 1.4, the interface’s storage classes are all
// storage classes used in declaring all global variables referenced by the
// entry point’s call tree." We should consider the target environment here.
- switch (var.getType().cast<spirv::PointerType>().getStorageClass()) {
+ switch (cast<spirv::PointerType>(var.getType()).getStorageClass()) {
case spirv::StorageClass::Input:
case spirv::StorageClass::Output:
interfaceVarSet.insert(var.getOperation());
@@ -247,7 +247,7 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
// at the start of the function. It is probably better to do the load just
// before the use. There might be multiple loads and currently there is no
// easy way to replace all uses with a sequence of operations.
- if (argType.value().cast<spirv::SPIRVType>().isScalarOrVector()) {
+ if (cast<spirv::SPIRVType>(argType.value()).isScalarOrVector()) {
auto zero =
spirv::ConstantOp::getZero(indexType, funcOp.getLoc(), rewriter);
auto loadPtr = rewriter.create<spirv::AccessChainOp>(
@@ -287,7 +287,7 @@ void LowerABIAttributesPass::runOnOperation() {
typeConverter.addSourceMaterialization([](OpBuilder &builder,
spirv::PointerType type,
ValueRange inputs, Location loc) {
- if (inputs.size() != 1 || !inputs[0].getType().isa<spirv::PointerType>())
+ if (inputs.size() != 1 || !isa<spirv::PointerType>(inputs[0].getType()))
return Value();
return builder.create<spirv::BitcastOp>(loc, type, inputs[0]).getResult();
});
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
index 51c36bd12db19..f38282f57a2c3 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
@@ -84,15 +84,13 @@ void RewriteInsertsPass::runOnOperation() {
LogicalResult RewriteInsertsPass::collectInsertionChain(
spirv::CompositeInsertOp op,
SmallVectorImpl<spirv::CompositeInsertOp> &insertions) {
- auto indicesArrayAttr = op.getIndices().cast<ArrayAttr>();
+ auto indicesArrayAttr = cast<ArrayAttr>(op.getIndices());
// TODO: handle nested composite object.
if (indicesArrayAttr.size() == 1) {
- auto numElements = op.getComposite()
- .getType()
- .cast<spirv::CompositeType>()
+ auto numElements = cast<spirv::CompositeType>(op.getComposite().getType())
.getNumElements();
- auto index = indicesArrayAttr[0].cast<IntegerAttr>().getInt();
+ auto index = cast<IntegerAttr>(indicesArrayAttr[0]).getInt();
// Need a last index to collect a sequential chain.
if (index + 1 != numElements)
return failure();
@@ -109,9 +107,9 @@ LogicalResult RewriteInsertsPass::collectInsertionChain(
return failure();
--index;
- indicesArrayAttr = op.getIndices().cast<ArrayAttr>();
+ indicesArrayAttr = cast<ArrayAttr>(op.getIndices());
if ((indicesArrayAttr.size() != 1) ||
- (indicesArrayAttr[0].cast<IntegerAttr>().getInt() != index))
+ (cast<IntegerAttr>(indicesArrayAttr[0]).getInt() != index))
return failure();
}
}
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 5a5cdfe341942..793b02520f235 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -139,7 +139,7 @@ bool SPIRVTypeConverter::allows(spirv::Capability capability) {
// SPIR-V dialect. Keeping it local till the use case arises.
static std::optional<int64_t>
getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
- if (type.isa<spirv::ScalarType>()) {
+ if (isa<spirv::ScalarType>(type)) {
auto bitWidth = type.getIntOrFloatBitWidth();
// According to the SPIR-V spec:
// "There is no physical size or bit pattern defined for values with boolean
@@ -152,21 +152,21 @@ getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
return bitWidth / 8;
}
- if (auto complexType = type.dyn_cast<ComplexType>()) {
+ if (auto complexType = dyn_cast<ComplexType>(type)) {
auto elementSize = getTypeNumBytes(options, complexType.getElementType());
if (!elementSize)
return std::nullopt;
return 2 * *elementSize;
}
- if (auto vecType = type.dyn_cast<VectorType>()) {
+ if (auto vecType = dyn_cast<VectorType>(type)) {
auto elementSize = getTypeNumBytes(options, vecType.getElementType());
if (!elementSize)
return std::nullopt;
return vecType.getNumElements() * *elementSize;
}
- if (auto memRefType = type.dyn_cast<MemRefType>()) {
+ if (auto memRefType = dyn_cast<MemRefType>(type)) {
// TODO: Layout should also be controlled by the ABI attributes. For now
// using the layout from MemRef.
int64_t offset;
@@ -198,7 +198,7 @@ getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
return (offset + memrefSize) * *elementSize;
}
- if (auto tensorType = type.dyn_cast<TensorType>()) {
+ if (auto tensorType = dyn_cast<TensorType>(type)) {
if (!tensorType.hasStaticShape())
return std::nullopt;
@@ -246,12 +246,12 @@ convertScalarType(const spirv::TargetEnv &targetEnv,
return nullptr;
}
- if (auto floatType = type.dyn_cast<FloatType>()) {
+ if (auto floatType = dyn_cast<FloatType>(type)) {
LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
return Builder(targetEnv.getContext()).getF32Type();
}
- auto intType = type.cast<IntegerType>();
+ auto intType = cast<IntegerType>(type);
LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
return IntegerType::get(targetEnv.getContext(), /*width=*/32,
intType.getSignedness());
@@ -319,8 +319,8 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
// Get extension and capability requirements for the given type.
SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
SmallVector<ArrayRef<spirv::Capability>, 2> capabilities;
- type.cast<spirv::CompositeType>().getExtensions(extensions, storageClass);
- type.cast<spirv::CompositeType>().getCapabilities(capabilities, storageClass);
+ cast<spirv::CompositeType>(type).getExtensions(extensions, storageClass);
+ cast<spirv::CompositeType>(type).getCapabilities(capabilities, storageClass);
// If all requirements are met, then we can accept this type as-is.
if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
@@ -415,8 +415,8 @@ static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
<< "using non-8-bit storage for bool types unimplemented");
return nullptr;
}
- auto elementType = IntegerType::get(type.getContext(), numBoolBits)
- .dyn_cast<spirv::ScalarType>();
+ auto elementType = dyn_cast<spirv::ScalarType>(
+ IntegerType::get(type.getContext(), numBoolBits));
if (!elementType)
return nullptr;
Type arrayElemType =
@@ -487,7 +487,7 @@ static Type convertSubByteMemrefType(const spirv::TargetEnv &targetEnv,
static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
const SPIRVConversionOptions &options,
MemRefType type) {
- auto attr = type.getMemorySpace().dyn_cast_or_null<spirv::StorageClassAttr>();
+ auto attr = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
if (!attr) {
LLVM_DEBUG(
llvm::dbgs()
@@ -499,7 +499,7 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
}
spirv::StorageClass storageClass = attr.getValue();
- if (type.getElementType().isa<IntegerType>()) {
+ if (isa<IntegerType>(type.getElementType())) {
if (type.getElementTypeBitWidth() == 1)
return convertBoolMemrefType(targetEnv, options, type, storageClass);
if (type.getElementTypeBitWidth() < 8)
@@ -508,17 +508,17 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
Type arrayElemType;
Type elementType = type.getElementType();
- if (auto vecType = elementType.dyn_cast<VectorType>()) {
+ if (auto vecType = dyn_cast<VectorType>(elementType)) {
arrayElemType =
convertVectorType(targetEnv, options, vecType, storageClass);
- } else if (auto complexType = elementType.dyn_cast<ComplexType>()) {
+ } else if (auto complexType = dyn_cast<ComplexType>(elementType)) {
arrayElemType =
convertComplexType(targetEnv, options, complexType, storageClass);
- } else if (auto scalarType = elementType.dyn_cast<spirv::ScalarType>()) {
+ } else if (auto scalarType = dyn_cast<spirv::ScalarType>(elementType)) {
arrayElemType =
convertScalarType(targetEnv, options, scalarType, storageClass);
- } else if (auto indexType = elementType.dyn_cast<IndexType>()) {
- type = convertIndexElementType(type, options).cast<MemRefType>();
+ } else if (auto indexType = dyn_cast<IndexType>(elementType)) {
+ type = cast<MemRefType>(convertIndexElementType(type, options));
arrayElemType = type.getElementType();
} else {
LLVM_DEBUG(
@@ -583,7 +583,7 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
addConversion([this](IndexType /*indexType*/) { return getIndexType(); });
addConversion([this](IntegerType intType) -> std::optional<Type> {
- if (auto scalarType = intType.dyn_cast<spirv::ScalarType>())
+ if (auto scalarType = dyn_cast<spirv::ScalarType>(intType))
return convertScalarType(this->targetEnv, this->options, scalarType);
if (intType.getWidth() < 8)
return convertSubByteIntegerType(this->options, intType);
@@ -591,7 +591,7 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
});
addConversion([this](FloatType floatType) -> std::optional<Type> {
- if (auto scalarType = floatType.dyn_cast<spirv::ScalarType>())
+ if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
return convertScalarType(this->targetEnv, this->options, scalarType);
return Type();
});
@@ -784,7 +784,7 @@ static spirv::PointerType getPushConstantStorageType(unsigned elementCount,
static spirv::GlobalVariableOp getPushConstantVariable(Block &body,
unsigned elementCount) {
for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
- auto ptrType = varOp.getType().dyn_cast<spirv::PointerType>();
+ auto ptrType = dyn_cast<spirv::PointerType>(varOp.getType());
if (!ptrType)
continue;
@@ -792,10 +792,9 @@ static spirv::GlobalVariableOp getPushConstantVariable(Block &body,
// block statically used per shader entry point." So we should always reuse
// the existing one.
if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) {
- auto numElements = ptrType.getPointeeType()
- .cast<spirv::StructType>()
- .getElementType(0)
- .cast<spirv::ArrayType>()
+ auto numElements = cast<spirv::ArrayType>(
+ cast<spirv::StructType>(ptrType.getPointeeType())
+ .getElementType(0))
.getNumElements();
if (numElements == elementCount)
return varOp;
@@ -926,8 +925,8 @@ Value mlir::spirv::getOpenCLElementPtr(SPIRVTypeConverter &typeConverter,
linearizeIndex(indices, strides, offset, indexType, loc, builder);
}
Type pointeeType =
- basePtr.getType().cast<spirv::PointerType>().getPointeeType();
- if (pointeeType.isa<spirv::ArrayType>()) {
+ cast<spirv::PointerType>(basePtr.getType()).getPointeeType();
+ if (isa<spirv::ArrayType>(pointeeType)) {
linearizedIndices.push_back(linearIndex);
return builder.create<spirv::AccessChainOp>(loc, basePtr,
linearizedIndices);
@@ -1015,7 +1014,7 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) {
// Ensure that all types have been converted to SPIRV types.
if (llvm::any_of(valueTypes,
- [](Type t) { return !t.isa<spirv::SPIRVType>(); }))
+ [](Type t) { return !isa<spirv::SPIRVType>(t); }))
return false;
// Special treatment for global variables, whose type requirements are
@@ -1029,13 +1028,13 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) {
SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
for (Type valueType : valueTypes) {
typeExtensions.clear();
- valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions);
+ cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
typeExtensions)))
return false;
typeCapabilities.clear();
- valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities);
+ cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities);
if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
typeCapabilities)))
return false;
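For readers following the migration, here is a minimal sketch of what the free-function forms do. Each target class exposes a static classof() predicate. isa<> calls that predicate, cast<> asserts on it, and dyn_cast<> returns null when it fails. Everything below (Type, IntegerType, IndexType, allIndexTypesConverted) is hypothetical and reduced to raw pointers. The real templates in llvm/Support/Casting.h are more general, and value-typed handles such as mlir::Type reach them through llvm::CastInfo specializations, which is the doCast path mentioned in the commit message. The last helper mirrors the llvm::any_of check in isLegalOp() above.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical stand-ins; not mlir::Type or any real MLIR class.
struct Type {
  enum class Kind { Integer, Float, Index };
  explicit Type(Kind k) : kind(k) {}
  Kind getKind() const { return kind; }

private:
  Kind kind;
};

struct IntegerType : Type {
  explicit IntegerType(unsigned w) : Type(Kind::Integer), width(w) {}
  static bool classof(const Type *t) { return t->getKind() == Kind::Integer; }
  unsigned getWidth() const { return width; }

private:
  unsigned width;
};

struct IndexType : Type {
  IndexType() : Type(Kind::Index) {}
  static bool classof(const Type *t) { return t->getKind() == Kind::Index; }
};

// Simplified free functions in the style of llvm/Support/Casting.h:
// pointer-only, no variadic isa<> and no CastInfo customization.
template <typename To, typename From> bool isa(const From *v) {
  assert(v && "isa<> used on a null pointer");
  return To::classof(v);
}

template <typename To, typename From> To *cast(From *v) {
  assert(isa<To>(v) && "cast<Ty>() argument of incompatible type!");
  return static_cast<To *>(v);
}

template <typename To, typename From> To *dyn_cast(From *v) {
  return isa<To>(v) ? static_cast<To *>(v) : nullptr;
}

// Mirrors the isLegalOp() check: the op is illegal if any type is still
// of a kind that should have been converted (here, an index type).
bool allIndexTypesConverted(const std::vector<Type *> &types) {
  return !std::any_of(types.begin(), types.end(),
                      [](Type *t) { return isa<IndexType>(t); });
}
```

Because isa<> is an ordinary free function, the lambda passed to the algorithm reads the same way as the rewritten `[](Type t) { return !isa<spirv::SPIRVType>(t); }` above.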
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
index 3cd4937e96f26..44fea86785593 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
@@ -41,7 +41,7 @@ namespace {
//===----------------------------------------------------------------------===//
Attribute getScalarOrSplatAttr(Type type, int64_t value) {
APInt sizedValue(getElementTypeOrSelf(type).getIntOrFloatBitWidth(), value);
- if (auto intTy = type.dyn_cast<IntegerType>())
+ if (auto intTy = dyn_cast<IntegerType>(type))
return IntegerAttr::get(intTy, sizedValue);
return SplatElementsAttr::get(cast<ShapedType>(type), sizedValue);
@@ -149,7 +149,7 @@ struct ExpandMulExtendedPattern final : OpRewritePattern<MulExtendedOp> {
// Currently, WGSL only supports 32-bit integer types. Any other integer
// types should already have been promoted/demoted to i32.
- auto elemTy = getElementTypeOrSelf(lhs.getType()).cast<IntegerType>();
+ auto elemTy = cast<IntegerType>(getElementTypeOrSelf(lhs.getType()));
if (elemTy.getIntOrFloatBitWidth() != 32)
return rewriter.notifyMatchFailure(
loc,
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
index 97f16d1b1b95f..ea856c7486777 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
@@ -65,16 +65,16 @@ static AliasedResourceMap collectAliasedResources(spirv::ModuleOp moduleOp) {
/// `!spirv.ptr<!spirv.struct<!spirv.rtarray<...>>>`. Returns null type
/// otherwise.
static Type getRuntimeArrayElementType(Type type) {
- auto ptrType = type.dyn_cast<spirv::PointerType>();
+ auto ptrType = dyn_cast<spirv::PointerType>(type);
if (!ptrType)
return {};
- auto structType = ptrType.getPointeeType().dyn_cast<spirv::StructType>();
+ auto structType = dyn_cast<spirv::StructType>(ptrType.getPointeeType());
if (!structType || structType.getNumElements() != 1)
return {};
auto rtArrayType =
- structType.getElementType(0).dyn_cast<spirv::RuntimeArrayType>();
+ dyn_cast<spirv::RuntimeArrayType>(structType.getElementType(0));
if (!rtArrayType)
return {};
@@ -97,7 +97,7 @@ deduceCanonicalResource(ArrayRef<spirv::SPIRVType> types) {
for (const auto &indexedTypes : llvm::enumerate(types)) {
spirv::SPIRVType type = indexedTypes.value();
assert(type.isScalarOrVector());
- if (auto vectorType = type.dyn_cast<VectorType>()) {
+ if (auto vectorType = dyn_cast<VectorType>(type)) {
if (vectorType.getNumElements() % 2 != 0)
return std::nullopt; // Odd-sized vector has special layout
// requirements.
@@ -277,7 +277,7 @@ void ResourceAliasAnalysis::recordIfUnifiable(
if (!elementType)
return; // Unexpected resource variable type.
- auto type = elementType.cast<spirv::SPIRVType>();
+ auto type = cast<spirv::SPIRVType>(elementType);
if (!type.isScalarOrVector())
return; // Unexpected resource element type.
@@ -370,7 +370,7 @@ struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
Location loc = acOp.getLoc();
- if (srcElemType.isIntOrFloat() && dstElemType.isa<VectorType>()) {
+ if (srcElemType.isIntOrFloat() && isa<VectorType>(dstElemType)) {
// The source indices are for a buffer with scalar element types. Rewrite
// them into a buffer with vector element types. We need to scale the last
// index for the vector as a whole, then add one level of index for inside
@@ -398,7 +398,7 @@ struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
}
if ((srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat()) ||
- (srcElemType.isa<VectorType>() && dstElemType.isa<VectorType>())) {
+ (isa<VectorType>(srcElemType) && isa<VectorType>(dstElemType))) {
// The source indices are for a buffer with larger bitwidth scalar/vector
// element types. Rewrite them into a buffer with smaller bitwidth element
// types. We only need to scale the last index.
@@ -433,10 +433,10 @@ struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
LogicalResult
matchAndRewrite(spirv::LoadOp loadOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto srcPtrType = loadOp.getPtr().getType().cast<spirv::PointerType>();
- auto srcElemType = srcPtrType.getPointeeType().cast<spirv::SPIRVType>();
- auto dstPtrType = adaptor.getPtr().getType().cast<spirv::PointerType>();
- auto dstElemType = dstPtrType.getPointeeType().cast<spirv::SPIRVType>();
+ auto srcPtrType = cast<spirv::PointerType>(loadOp.getPtr().getType());
+ auto srcElemType = cast<spirv::SPIRVType>(srcPtrType.getPointeeType());
+ auto dstPtrType = cast<spirv::PointerType>(adaptor.getPtr().getType());
+ auto dstElemType = cast<spirv::SPIRVType>(dstPtrType.getPointeeType());
Location loc = loadOp.getLoc();
auto newLoadOp = rewriter.create<spirv::LoadOp>(loc, adaptor.getPtr());
@@ -454,7 +454,7 @@ struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
}
if ((srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat()) ||
- (srcElemType.isa<VectorType>() && dstElemType.isa<VectorType>())) {
+ (isa<VectorType>(srcElemType) && isa<VectorType>(dstElemType))) {
// The source and destination have scalar types of different bitwidths, or
// vector types of different component counts. For such cases, we load
// multiple smaller bitwidth values and construct a larger bitwidth one.
@@ -495,13 +495,13 @@ struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
// type.
Type vectorType = srcElemType;
- if (!srcElemType.isa<VectorType>())
+ if (!isa<VectorType>(srcElemType))
vectorType = VectorType::get({ratio}, dstElemType);
// If both the source and destination are vector types, we need to make
// sure the scalar type is the same for composite construction later.
- if (auto srcElemVecType = srcElemType.dyn_cast<VectorType>())
- if (auto dstElemVecType = dstElemType.dyn_cast<VectorType>()) {
+ if (auto srcElemVecType = dyn_cast<VectorType>(srcElemType))
+ if (auto dstElemVecType = dyn_cast<VectorType>(dstElemType)) {
if (srcElemVecType.getElementType() !=
dstElemVecType.getElementType()) {
int64_t count =
@@ -515,7 +515,7 @@ struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
Value vectorValue = rewriter.create<spirv::CompositeConstructOp>(
loc, vectorType, components);
- if (!srcElemType.isa<VectorType>())
+ if (!isa<VectorType>(srcElemType))
vectorValue =
rewriter.create<spirv::BitcastOp>(loc, srcElemType, vectorValue);
rewriter.replaceOp(loadOp, vectorValue);
@@ -534,9 +534,9 @@ struct ConvertStore : public ConvertAliasResource<spirv::StoreOp> {
matchAndRewrite(spirv::StoreOp storeOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto srcElemType =
- storeOp.getPtr().getType().cast<spirv::PointerType>().getPointeeType();
+ cast<spirv::PointerType>(storeOp.getPtr().getType()).getPointeeType();
auto dstElemType =
- adaptor.getPtr().getType().cast<spirv::PointerType>().getPointeeType();
+ cast<spirv::PointerType>(adaptor.getPtr().getType()).getPointeeType();
if (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat())
return rewriter.notifyMatchFailure(storeOp, "not scalar type");
if (!areSameBitwidthScalarType(srcElemType, dstElemType))
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
index 6e09a848c494e..095db6b815f51 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
@@ -159,13 +159,13 @@ void UpdateVCEPass::runOnOperation() {
SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
for (Type valueType : valueTypes) {
typeExtensions.clear();
- valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions);
+ cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
if (failed(checkAndUpdateExtensionRequirements(
op, targetEnv, typeExtensions, deducedExtensions)))
return WalkResult::interrupt();
typeCapabilities.clear();
- valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities);
+ cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities);
if (failed(checkAndUpdateCapabilityRequirements(
op, targetEnv, typeCapabilities, deducedCapabilities)))
return WalkResult::interrupt();
diff --git a/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp b/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp
index 67d61f820b629..b19495bc37445 100644
--- a/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp
+++ b/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp
@@ -53,7 +53,7 @@ VulkanLayoutUtils::decorateType(spirv::StructType structType,
// must be a runtime array.
assert(memberSize != std::numeric_limits<Size>().max() ||
(i + 1 == e &&
- structType.getElementType(i).isa<spirv::RuntimeArrayType>()));
+ isa<spirv::RuntimeArrayType>(structType.getElementType(i))));
// According to the Vulkan spec:
// "A structure has a base alignment equal to the largest base alignment of
// any of its members."
@@ -79,23 +79,23 @@ VulkanLayoutUtils::decorateType(spirv::StructType structType,
Type VulkanLayoutUtils::decorateType(Type type, VulkanLayoutUtils::Size &size,
VulkanLayoutUtils::Size &alignment) {
- if (type.isa<spirv::ScalarType>()) {
+ if (isa<spirv::ScalarType>(type)) {
alignment = getScalarTypeAlignment(type);
// Vulkan spec does not specify any padding for a scalar type.
size = alignment;
return type;
}
- if (auto structType = type.dyn_cast<spirv::StructType>())
+ if (auto structType = dyn_cast<spirv::StructType>(type))
return decorateType(structType, size, alignment);
- if (auto arrayType = type.dyn_cast<spirv::ArrayType>())
+ if (auto arrayType = dyn_cast<spirv::ArrayType>(type))
return decorateType(arrayType, size, alignment);
- if (auto vectorType = type.dyn_cast<VectorType>())
+ if (auto vectorType = dyn_cast<VectorType>(type))
return decorateType(vectorType, size, alignment);
- if (auto arrayType = type.dyn_cast<spirv::RuntimeArrayType>()) {
+ if (auto arrayType = dyn_cast<spirv::RuntimeArrayType>(type)) {
size = std::numeric_limits<Size>().max();
return decorateType(arrayType, alignment);
}
- if (type.isa<spirv::PointerType>()) {
+ if (isa<spirv::PointerType>(type)) {
// TODO: Add support for `PhysicalStorageBufferAddresses`.
return nullptr;
}
@@ -161,13 +161,13 @@ VulkanLayoutUtils::getScalarTypeAlignment(Type scalarType) {
}
bool VulkanLayoutUtils::isLegalType(Type type) {
- auto ptrType = type.dyn_cast<spirv::PointerType>();
+ auto ptrType = dyn_cast<spirv::PointerType>(type);
if (!ptrType) {
return true;
}
auto storageClass = ptrType.getStorageClass();
- auto structType = ptrType.getPointeeType().dyn_cast<spirv::StructType>();
+ auto structType = dyn_cast<spirv::StructType>(ptrType.getPointeeType());
if (!structType) {
return true;
}
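Several hunks above, such as convertMemrefType() and VulkanLayoutUtils::decorateType(), dispatch on a type by trying dyn_cast<> against each candidate in turn. The sketch below shows that pattern. It assumes a hypothetical three-kind hierarchy and a simplified pointer-only dyn_cast. These are not the spirv:: or builtin MLIR types, and getSizeInBytes() is an illustrative helper only.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical miniature hierarchy; not the spirv:: or builtin MLIR types.
struct Type {
  enum class Kind { Scalar, Vector, Opaque };
  explicit Type(Kind k) : kind(k) {}
  Kind kind;
};

struct ScalarType : Type {
  explicit ScalarType(unsigned b) : Type(Kind::Scalar), bitWidth(b) {}
  static bool classof(const Type *t) { return t->kind == Kind::Scalar; }
  unsigned bitWidth;
};

struct VectorType : Type {
  VectorType(unsigned n, ScalarType *elt)
      : Type(Kind::Vector), numElements(n), elementType(elt) {}
  static bool classof(const Type *t) { return t->kind == Kind::Vector; }
  unsigned numElements;
  ScalarType *elementType;
};

// A kind the dispatch below deliberately does not handle.
struct OpaqueType : Type {
  OpaqueType() : Type(Kind::Opaque) {}
};

// Simplified pointer-only dyn_cast; null on a kind mismatch.
template <typename To, typename From> To *dyn_cast(From *v) {
  assert(v && "dyn_cast on a null pointer");
  return To::classof(v) ? static_cast<To *>(v) : nullptr;
}

// Tries each supported kind in turn, in the same shape as decorateType();
// returns 0 for kinds it does not handle.
uint64_t getSizeInBytes(Type *type) {
  if (auto *scalar = dyn_cast<ScalarType>(type))
    return scalar->bitWidth / 8;
  if (auto *vector = dyn_cast<VectorType>(type))
    return vector->numElements * getSizeInBytes(vector->elementType);
  return 0;
}
```

Each `if (auto *x = dyn_cast<...>(type))` binds the narrowed pointer only when the check succeeds. The branch body can therefore use the derived-type members without a second cast.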
diff --git a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
index fc67fea1a4931..4a567f48aeb42 100644
--- a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -64,7 +64,7 @@ struct AssumingOpInterface
rewriter.setInsertionPointAfter(newOp);
SmallVector<Value> newResults;
for (const auto &it : llvm::enumerate(assumingOp->getResultTypes())) {
- if (it.value().isa<TensorType>()) {
+ if (isa<TensorType>(it.value())) {
newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
assumingOp.getLoc(), newOp->getResult(it.index())));
} else {
@@ -116,7 +116,7 @@ struct AssumingYieldOpInterface
auto yieldOp = cast<shape::AssumingYieldOp>(op);
SmallVector<Value> newResults;
for (Value value : yieldOp.getOperands()) {
- if (value.getType().isa<TensorType>()) {
+ if (isa<TensorType>(value.getType())) {
FailureOr<Value> buffer = getBuffer(rewriter, value, options);
if (failed(buffer))
return failure();
diff --git a/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp b/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp
index f23a090a25a09..1a6f868cf21df 100644
--- a/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp
@@ -133,7 +133,7 @@ void constructShapeFunc(
for (shape::WithOp withOp : allWithOps) {
Value value = withOp.getOperand();
Value shape = withOp.getShape();
- RankedTensorType rankedType = value.getType().dyn_cast<RankedTensorType>();
+ RankedTensorType rankedType = dyn_cast<RankedTensorType>(value.getType());
if (rankedType == nullptr)
continue;
diff --git a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
index 99a619cda7b63..990f8f7327d80 100644
--- a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
@@ -41,7 +41,7 @@ getBufferizationOptions(bool analysisOnly) {
options.unknownTypeConverterFn = [](Value value, Attribute memorySpace,
const BufferizationOptions &options) {
return getMemRefTypeWithStaticIdentityLayout(
- value.getType().cast<TensorType>(), memorySpace);
+ cast<TensorType>(value.getType()), memorySpace);
};
if (analysisOnly) {
options.testAnalysisOnly = true;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index 6fd55c7799306..ace8a8867081d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -28,7 +28,7 @@ using namespace mlir::sparse_tensor;
static std::optional<std::pair<Value, Value>>
genSplitSparseConstant(OpBuilder &builder, Location loc, Value tensor) {
if (auto constOp = tensor.getDefiningOp<arith::ConstantOp>()) {
- if (auto a = constOp.getValue().dyn_cast<SparseElementsAttr>()) {
+ if (auto a = dyn_cast<SparseElementsAttr>(constOp.getValue())) {
auto coordinates = builder.create<arith::ConstantOp>(loc, a.getIndices());
auto values = builder.create<arith::ConstantOp>(loc, a.getValues());
return std::make_pair(coordinates, values);
@@ -94,7 +94,7 @@ OverheadType mlir::sparse_tensor::overheadTypeEncoding(unsigned width) {
OverheadType mlir::sparse_tensor::overheadTypeEncoding(Type tp) {
if (tp.isIndex())
return OverheadType::kIndex;
- if (auto intTp = tp.dyn_cast<IntegerType>())
+ if (auto intTp = dyn_cast<IntegerType>(tp))
return overheadTypeEncoding(intTp.getWidth());
llvm_unreachable("Unknown overhead type");
}
@@ -169,7 +169,7 @@ PrimaryType mlir::sparse_tensor::primaryTypeEncoding(Type elemTp) {
return PrimaryType::kI16;
if (elemTp.isInteger(8))
return PrimaryType::kI8;
- if (auto complexTp = elemTp.dyn_cast<ComplexType>()) {
+ if (auto complexTp = dyn_cast<ComplexType>(elemTp)) {
auto complexEltTp = complexTp.getElementType();
if (complexEltTp.isF64())
return PrimaryType::kC64;
@@ -205,10 +205,10 @@ Value sparse_tensor::genCast(OpBuilder &builder, Location loc, Value value,
return value;
// int <=> index
- if (srcTp.isa<IndexType>() || dstTp.isa<IndexType>())
+ if (isa<IndexType>(srcTp) || isa<IndexType>(dstTp))
return builder.create<arith::IndexCastOp>(loc, dstTp, value);
- const auto srcIntTp = srcTp.dyn_cast_or_null<IntegerType>();
+ const auto srcIntTp = dyn_cast_or_null<IntegerType>(srcTp);
const bool isUnsignedCast = srcIntTp ? srcIntTp.isUnsigned() : false;
return mlir::convertScalarToDtype(builder, loc, value, dstTp, isUnsignedCast);
}
@@ -216,7 +216,7 @@ Value sparse_tensor::genCast(OpBuilder &builder, Location loc, Value value,
Value sparse_tensor::genIndexLoad(OpBuilder &builder, Location loc, Value mem,
Value s) {
Value load = builder.create<memref::LoadOp>(loc, mem, s);
- if (!load.getType().isa<IndexType>()) {
+ if (!isa<IndexType>(load.getType())) {
if (load.getType().getIntOrFloatBitWidth() < 64)
load = builder.create<arith::ExtUIOp>(loc, builder.getI64Type(), load);
load =
@@ -226,14 +226,14 @@ Value sparse_tensor::genIndexLoad(OpBuilder &builder, Location loc, Value mem,
}
mlir::TypedAttr mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) {
- if (tp.isa<FloatType>())
+ if (isa<FloatType>(tp))
return builder.getFloatAttr(tp, 1.0);
- if (tp.isa<IndexType>())
+ if (isa<IndexType>(tp))
return builder.getIndexAttr(1);
- if (auto intTp = tp.dyn_cast<IntegerType>())
+ if (auto intTp = dyn_cast<IntegerType>(tp))
return builder.getIntegerAttr(tp, APInt(intTp.getWidth(), 1));
- if (tp.isa<RankedTensorType, VectorType>()) {
- auto shapedTp = tp.cast<ShapedType>();
+ if (isa<RankedTensorType, VectorType>(tp)) {
+ auto shapedTp = cast<ShapedType>(tp);
if (auto one = getOneAttr(builder, shapedTp.getElementType()))
return DenseElementsAttr::get(shapedTp, one);
}
@@ -244,13 +244,13 @@ Value mlir::sparse_tensor::genIsNonzero(OpBuilder &builder, mlir::Location loc,
Value v) {
Type tp = v.getType();
Value zero = constantZero(builder, loc, tp);
- if (tp.isa<FloatType>())
+ if (isa<FloatType>(tp))
return builder.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v,
zero);
if (tp.isIntOrIndex())
return builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, v,
zero);
- if (tp.dyn_cast<ComplexType>())
+ if (dyn_cast<ComplexType>(tp))
return builder.create<complex::NotEqualOp>(loc, v, zero);
llvm_unreachable("Non-numeric type");
}
@@ -580,12 +580,12 @@ void sparse_tensor::foreachInSparseConstant(
}
// Remap value.
Value val;
- if (attr.getElementType().isa<ComplexType>()) {
- auto valAttr = elems[i].second.cast<ArrayAttr>();
+ if (isa<ComplexType>(attr.getElementType())) {
+ auto valAttr = cast<ArrayAttr>(elems[i].second);
val = builder.create<complex::ConstantOp>(loc, attr.getElementType(),
valAttr);
} else {
- auto valAttr = elems[i].second.cast<TypedAttr>();
+ auto valAttr = cast<TypedAttr>(elems[i].second);
val = builder.create<arith::ConstantOp>(loc, valAttr);
}
assert(val);
@@ -597,7 +597,7 @@ SmallVector<Value> sparse_tensor::loadAll(OpBuilder &builder, Location loc,
size_t size, Value mem,
size_t offsetIdx, Value offsetVal) {
#ifndef NDEBUG
- const auto memTp = mem.getType().cast<MemRefType>();
+ const auto memTp = cast<MemRefType>(mem.getType());
assert(memTp.getRank() == 1);
const DynSize memSh = memTp.getDimSize(0);
assert(ShapedType::isDynamic(memSh) || memSh >= static_cast<DynSize>(size));
@@ -619,7 +619,7 @@ void sparse_tensor::storeAll(OpBuilder &builder, Location loc, Value mem,
ValueRange vs, size_t offsetIdx, Value offsetVal) {
#ifndef NDEBUG
const size_t vsize = vs.size();
- const auto memTp = mem.getType().cast<MemRefType>();
+ const auto memTp = cast<MemRefType>(mem.getType());
assert(memTp.getRank() == 1);
const DynSize memSh = memTp.getDimSize(0);
assert(ShapedType::isDynamic(memSh) || memSh >= static_cast<DynSize>(vsize));
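Two spellings in this file go beyond the single-type form. In getOneAttr(), isa<RankedTensorType, VectorType>(tp) accepts any of several types. In genCast(), dyn_cast_or_null<IntegerType>(srcTp) also accepts a null input, whereas plain dyn_cast expects a non-null value. The sketch below implements both with a C++17 fold expression. The hierarchy and the helpers isRankedTensorOrVector() and isUnsignedCast() are hypothetical illustrations, not the code in CodegenUtils.cpp.

```cpp
#include <cassert>

// Hypothetical miniature hierarchy; names echo MLIR but these are not MLIR types.
struct Type {
  enum class Kind { Integer, RankedTensor, Vector, Float };
  explicit Type(Kind k) : kind(k) {}
  Kind kind;
};

struct IntegerType : Type {
  explicit IntegerType(bool isUnsignedInt)
      : Type(Kind::Integer), unsignedness(isUnsignedInt) {}
  static bool classof(const Type *t) { return t->kind == Kind::Integer; }
  bool isUnsigned() const { return unsignedness; }
  bool unsignedness;
};

struct RankedTensorType : Type {
  RankedTensorType() : Type(Kind::RankedTensor) {}
  static bool classof(const Type *t) { return t->kind == Kind::RankedTensor; }
};

struct VectorType : Type {
  VectorType() : Type(Kind::Vector) {}
  static bool classof(const Type *t) { return t->kind == Kind::Vector; }
};

struct FloatType : Type {
  FloatType() : Type(Kind::Float) {}
};

// isa<A, B, ...>(v): true if v is any of the listed kinds (C++17 fold).
template <typename... Tos, typename From> bool isa(const From *v) {
  assert(v && "isa<> used on a null pointer");
  return (Tos::classof(v) || ...);
}

// dyn_cast_or_null<To>(v): like dyn_cast, but a null input simply yields null.
template <typename To, typename From> To *dyn_cast_or_null(From *v) {
  return (v && To::classof(v)) ? static_cast<To *>(v) : nullptr;
}

// Shaped like getOneAttr(): accept either shaped kind with one isa<> call.
bool isRankedTensorOrVector(const Type *t) {
  return isa<RankedTensorType, VectorType>(t);
}

// Shaped like genCast(): non-integer and null types both count as signed.
bool isUnsignedCast(Type *srcTp) {
  const auto *srcIntTp = dyn_cast_or_null<IntegerType>(srcTp);
  return srcIntTp ? srcIntTp->isUnsigned() : false;
}
```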
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
index e04475ea2e8f1..9e762892e8648 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -260,7 +260,7 @@ Value reshapeValuesToLevels(OpBuilder &builder, Location loc,
/// `IntegerType`), this also works for `RankedTensorType` and `VectorType`
/// (for which it generates a constant `DenseElementsAttr` of zeros).
inline Value constantZero(OpBuilder &builder, Location loc, Type tp) {
- if (auto ctp = tp.dyn_cast<ComplexType>()) {
+ if (auto ctp = dyn_cast<ComplexType>(tp)) {
auto zeroe = builder.getZeroAttr(ctp.getElementType());
auto zeroa = builder.getArrayAttr({zeroe, zeroe});
return builder.create<complex::ConstantOp>(loc, tp, zeroa);
@@ -271,7 +271,7 @@ inline Value constantZero(OpBuilder &builder, Location loc, Type tp) {
/// Generates a 1-valued constant of the given type. This supports all
/// the same types as `constantZero`.
inline Value constantOne(OpBuilder &builder, Location loc, Type tp) {
- if (auto ctp = tp.dyn_cast<ComplexType>()) {
+ if (auto ctp = dyn_cast<ComplexType>(tp)) {
auto zeroe = builder.getZeroAttr(ctp.getElementType());
auto onee = getOneAttr(builder, ctp.getElementType());
auto zeroa = builder.getArrayAttr({onee, zeroe});
@@ -350,7 +350,7 @@ inline Value constantDimLevelTypeEncoding(OpBuilder &builder, Location loc,
}
inline bool isZeroRankedTensorOrScalar(Type type) {
- auto rtp = type.dyn_cast<RankedTensorType>();
+ auto rtp = dyn_cast<RankedTensorType>(type);
return !rtp || rtp.getRank() == 0;
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index 731a1a9e460e5..d61e545056788 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -350,7 +350,7 @@ void LoopEmitter::initializeLoopEmit(OpBuilder &builder, Location loc,
// on positions.
for (TensorId t = 0, numTensors = getNumTensors(); t < numTensors; t++) {
const Value tensor = tensors[t];
- const auto rtp = tensor.getType().dyn_cast<RankedTensorType>();
+ const auto rtp = dyn_cast<RankedTensorType>(tensor.getType());
if (!rtp)
// Skips only scalar, zero ranked tensor still need to be bufferized and
// (probably) filled with zeros by users.
@@ -432,7 +432,7 @@ void LoopEmitter::initializeLoopEmit(OpBuilder &builder, Location loc,
Type indexType = builder.getIndexType();
Value c0 = constantZero(builder, loc, indexType);
for (TensorId t = 0, e = tensors.size(); t < e; t++) {
- auto rtp = tensors[t].getType().dyn_cast<RankedTensorType>();
+ auto rtp = dyn_cast<RankedTensorType>(tensors[t].getType());
if (!rtp)
continue;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
index 67a3f3d038a1d..03715785d2844 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
@@ -415,11 +415,11 @@ class LoopEmitter {
// check `dstLvl < dstLvlRank` at the top; and only here need to
// assert that `reassoc.size() == dstLvlRank`.
assert(dstLvl < reassoc.size() && "Level is out-of-bounds");
- const auto srcLvls = reassoc[dstLvl].cast<ArrayAttr>();
+ const auto srcLvls = cast<ArrayAttr>(reassoc[dstLvl]);
return llvm::to_vector<2>(
llvm::map_range(srcLvls, [&](Attribute srcLvl) -> Level {
// TODO: replace this with the converter for `LevelAttr`.
- return srcLvl.cast<IntegerAttr>().getValue().getZExtValue();
+ return cast<IntegerAttr>(srcLvl).getValue().getZExtValue();
}));
}
return {dstLvl};
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
index c99c26b9c98cd..bb52e08686fe5 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -100,7 +100,7 @@ static Value genLaunchGPUFunc(OpBuilder &builder, gpu::GPUFuncOp gpuFunc,
/// completion. Needs to cast the buffer to a unranked buffer.
static Value genHostRegisterMemref(OpBuilder &builder, Location loc,
Value mem) {
- MemRefType memTp = mem.getType().cast<MemRefType>();
+ MemRefType memTp = cast<MemRefType>(mem.getType());
UnrankedMemRefType resTp =
UnrankedMemRefType::get(memTp.getElementType(), /*memorySpace=*/0);
Value cast = builder.create<memref::CastOp>(loc, resTp, mem);
@@ -133,7 +133,7 @@ static void genBlockingWait(OpBuilder &builder, Location loc,
/// that feature does not seem to be fully supported yet.
static gpu::AllocOp genAllocMemRef(OpBuilder &builder, Location loc, Value mem,
Value token) {
- auto tp = mem.getType().cast<ShapedType>();
+ auto tp = cast<ShapedType>(mem.getType());
auto elemTp = tp.getElementType();
auto shape = tp.getShape();
auto memTp = MemRefType::get(shape, elemTp);
@@ -304,7 +304,7 @@ struct ForallRewriter : public OpRewritePattern<scf::ParallelOp> {
for (OpOperand &o : op->getOpOperands()) {
Value val = o.get();
Block *block;
- if (auto arg = val.dyn_cast<BlockArgument>())
+ if (auto arg = dyn_cast<BlockArgument>(val))
block = arg.getOwner();
else
block = val.getDefiningOp()->getBlock();
@@ -321,7 +321,7 @@ struct ForallRewriter : public OpRewritePattern<scf::ParallelOp> {
Type tp = val.getType();
if (val.getDefiningOp<arith::ConstantOp>())
constants.push_back(val);
- else if (tp.isa<FloatType>() || tp.isIntOrIndex())
+ else if (isa<FloatType>(tp) || tp.isIntOrIndex())
scalars.push_back(val);
else if (isa<MemRefType>(tp))
buffers.push_back(val);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
index 0c68c4db4fe95..f34ed9779cfd3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
@@ -111,9 +111,9 @@ Value SpecifierStructBuilder::getInitValue(OpBuilder &builder, Location loc,
Value metaData = builder.create<LLVM::UndefOp>(loc, structType);
SpecifierStructBuilder md(metaData);
if (!source) {
- auto memSizeArrayType = structType.cast<LLVM::LLVMStructType>()
- .getBody()[kMemSizePosInSpecifier]
- .cast<LLVM::LLVMArrayType>();
+ auto memSizeArrayType =
+ cast<LLVM::LLVMArrayType>(cast<LLVM::LLVMStructType>(structType)
+ .getBody()[kMemSizePosInSpecifier]);
Value zero = constantZero(builder, loc, memSizeArrayType.getElementType());
// Fill memSizes array with zero.
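The old code chained methods left to right, for example ptrType.getPointeeType().cast<StructType>().getElementType(0).cast<ArrayType>(). With the free functions, those calls nest instead, so the chain reads inside-out, as in getPushConstantVariable() and SpecifierStructBuilder::getInitValue() above. The sketch below compares the nested form with an equivalent one that names each intermediate. The types and the two numElements*() helpers are hypothetical, and cast<> is a simplified pointer-only version.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical miniature SPIR-V-like types; not the real spirv:: classes.
struct Type {
  enum class Kind { Pointer, Struct, Array };
  explicit Type(Kind k) : kind(k) {}
  Kind kind;
};

struct ArrayType : Type {
  explicit ArrayType(int64_t n) : Type(Kind::Array), numElements(n) {}
  static bool classof(const Type *t) { return t->kind == Kind::Array; }
  int64_t getNumElements() const { return numElements; }
  int64_t numElements;
};

struct StructType : Type {
  explicit StructType(std::vector<Type *> elts)
      : Type(Kind::Struct), elements(std::move(elts)) {}
  static bool classof(const Type *t) { return t->kind == Kind::Struct; }
  Type *getElementType(unsigned i) const { return elements[i]; }
  std::vector<Type *> elements;
};

struct PointerType : Type {
  explicit PointerType(Type *p) : Type(Kind::Pointer), pointee(p) {}
  static bool classof(const Type *t) { return t->kind == Kind::Pointer; }
  Type *getPointeeType() const { return pointee; }
  Type *pointee;
};

// Simplified pointer-only cast: asserts on a null or mismatched input.
template <typename To, typename From> To *cast(From *v) {
  assert(v && To::classof(v) && "cast<Ty>() argument of incompatible type!");
  return static_cast<To *>(v);
}

// Nested form, in the shape of the rewritten getPushConstantVariable():
// the innermost cast runs first, so the expression reads inside-out.
int64_t numElementsNested(PointerType *ptrType) {
  return cast<ArrayType>(
             cast<StructType>(ptrType->getPointeeType())->getElementType(0))
      ->getNumElements();
}

// Equivalent step-by-step form: naming each intermediate restores the
// left-to-right reading order of the old method chain.
int64_t numElementsStepwise(PointerType *ptrType) {
  auto *structType = cast<StructType>(ptrType->getPointeeType());
  auto *arrayType = cast<ArrayType>(structType->getElementType(0));
  return arrayType->getNumElements();
}
```

Both helpers compute the same value. The step-by-step form costs a couple of extra locals, but it can be easier to review when a chain holds more than two casts.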
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index aebf0542b3332..88f79bf3e8d46 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -80,7 +80,7 @@ static void genStore(OpBuilder &builder, Location loc, Value val, Value mem,
Value idx) {
idx = genCast(builder, loc, idx, builder.getIndexType());
val = genCast(builder, loc, val,
- mem.getType().cast<ShapedType>().getElementType());
+ cast<ShapedType>(mem.getType()).getElementType());
builder.create<memref::StoreOp>(loc, val, mem, idx);
}
@@ -253,7 +253,7 @@ static void createAllocFields(OpBuilder &builder, Location loc,
case SparseTensorFieldKind::CrdMemRef:
case SparseTensorFieldKind::ValMemRef:
field = createAllocation(
- builder, loc, fType.cast<MemRefType>(),
+ builder, loc, cast<MemRefType>(fType),
(fKind == SparseTensorFieldKind::PosMemRef) ? posHeuristic
: (fKind == SparseTensorFieldKind::CrdMemRef) ? crdHeuristic
: valHeuristic,
@@ -779,7 +779,7 @@ class SparseTensorAllocConverter
fields.reserve(desc.getNumFields());
// Memcpy on memref fields.
for (auto field : desc.getMemRefFields()) {
- auto memrefTp = field.getType().cast<MemRefType>();
+ auto memrefTp = cast<MemRefType>(field.getType());
auto size = rewriter.create<memref::DimOp>(loc, field, 0);
auto copied =
rewriter.create<memref::AllocOp>(loc, memrefTp, ValueRange{size});
@@ -1128,7 +1128,7 @@ class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource());
SmallVector<Value> fields;
foreachFieldAndTypeInSparseTensor(
- SparseTensorType(op.getResult().getType().cast<RankedTensorType>()),
+ SparseTensorType(cast<RankedTensorType>(op.getResult().getType())),
[&rewriter, &fields, srcDesc,
loc](Type fTp, FieldIndex fIdx, SparseTensorFieldKind fKind, Level lvl,
DimLevelType /*dlt*/) -> bool {
@@ -1143,7 +1143,7 @@ class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
// values.
Value sz = linalg::createOrFoldDimOp(rewriter, loc, srcMem, 0);
auto dstMem = rewriter.create<memref::AllocOp>(
- loc, fTp.cast<MemRefType>(), sz);
+ loc, cast<MemRefType>(fTp), sz);
if (fTp != srcMem.getType()) {
// Converts elements type.
scf::buildLoopNest(
@@ -1397,7 +1397,7 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
}
assert(field);
- if (auto memrefTp = field.getType().dyn_cast<MemRefType>();
+ if (auto memrefTp = dyn_cast<MemRefType>(field.getType());
memrefTp && memrefTp.getRank() > 1) {
ReassociationIndices reassociation;
for (int i = 0, e = memrefTp.getRank(); i < e; i++)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 8d0c8548097f1..906f700cfc475 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -399,7 +399,7 @@ static void genAddEltCall(OpBuilder &builder, Location loc, Type eltType,
/// (which can be either dim- or lvl-coords, depending on context).
static Value genGetNextCall(OpBuilder &builder, Location loc, Value iter,
Value coords, Value elemPtr) {
- Type elemTp = elemPtr.getType().cast<ShapedType>().getElementType();
+ Type elemTp = cast<ShapedType>(elemPtr.getType()).getElementType();
SmallString<10> name{"getNext", primaryTypeFunctionSuffix(elemTp)};
SmallVector<Value, 3> params{iter, coords, elemPtr};
Type i1 = builder.getI1Type();
@@ -1045,7 +1045,7 @@ class SparseTensorToPositionsConverter
matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type resTp = op.getType();
- Type posTp = resTp.cast<ShapedType>().getElementType();
+ Type posTp = cast<ShapedType>(resTp).getElementType();
SmallString<17> name{"sparsePositions", overheadTypeFunctionSuffix(posTp)};
Value lvl = constantIndex(rewriter, op->getLoc(), op.getLevel());
replaceOpWithFuncCall(rewriter, op, name, resTp, {adaptor.getTensor(), lvl},
@@ -1064,7 +1064,7 @@ class SparseTensorToCoordinatesConverter
ConversionPatternRewriter &rewriter) const override {
// TODO: use `SparseTensorType::getCrdType` instead.
Type resType = op.getType();
- const Type crdTp = resType.cast<ShapedType>().getElementType();
+ const Type crdTp = cast<ShapedType>(resType).getElementType();
SmallString<19> name{"sparseCoordinates",
overheadTypeFunctionSuffix(crdTp)};
Location loc = op->getLoc();
@@ -1096,7 +1096,7 @@ class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
LogicalResult
matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto resType = op.getType().cast<ShapedType>();
+ auto resType = cast<ShapedType>(op.getType());
rewriter.replaceOp(op, genValuesCall(rewriter, op.getLoc(), resType,
adaptor.getOperands()));
return success();
@@ -1113,7 +1113,7 @@ class SparseNumberOfEntriesConverter
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
// Query values array size for the actually stored values size.
- Type eltType = op.getTensor().getType().cast<ShapedType>().getElementType();
+ Type eltType = cast<ShapedType>(op.getTensor().getType()).getElementType();
auto resTp = MemRefType::get({ShapedType::kDynamic}, eltType);
Value values = genValuesCall(rewriter, loc, resTp, adaptor.getOperands());
rewriter.replaceOpWithNewOp<memref::DimOp>(op, values,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 2a4bbb06eb507..ca27794b64c1f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -79,7 +79,7 @@ static bool isSampling(GenericOp op) {
// Helper to detect chain of multiplications that do not involve x.
static bool isMulChain(Value val, Value x) {
- if (auto arg = val.dyn_cast<BlockArgument>())
+ if (auto arg = dyn_cast<BlockArgument>(val))
return arg != x;
if (auto *def = val.getDefiningOp()) {
if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def))
@@ -105,7 +105,7 @@ static bool isSumOfMul(GenericOp op) {
// Helper to detect direct yield of a zero value.
static bool isZeroYield(GenericOp op) {
auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
- if (auto arg = yieldOp.getOperand(0).dyn_cast<BlockArgument>()) {
+ if (auto arg = dyn_cast<BlockArgument>(yieldOp.getOperand(0))) {
if (arg.getOwner()->getParentOp() == op) {
return isZeroValue(op->getOperand(arg.getArgNumber()));
}
@@ -719,7 +719,7 @@ struct ConvertRewriter : public OpRewritePattern<ConvertOp> {
bool fromSparseConst = false;
if (auto constOp = op.getSource().getDefiningOp<arith::ConstantOp>()) {
- if (constOp.getValue().dyn_cast<SparseElementsAttr>()) {
+ if (dyn_cast<SparseElementsAttr>(constOp.getValue())) {
fromSparseConst = true;
}
}
@@ -972,7 +972,7 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
// Special-case: for each over a sparse constant uses its own rewriting
// rule.
if (auto constOp = input.getDefiningOp<arith::ConstantOp>()) {
- if (auto attr = constOp.getValue().dyn_cast<SparseElementsAttr>()) {
+ if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue())) {
return genForeachOnSparseConstant(op, rewriter, attr);
}
}
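The `isMulChain` hunk keeps `isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def)`. LLVM's free `isa` also accepts a list of types, so the same test can be written as `isa<arith::MulFOp, arith::MulIOp>(def)`. The sketch below shows how a variadic overload can be built on top of a single-type `isa`. It uses hypothetical `Op`, `MulFOp`, `MulIOp` and `AddFOp` classes and is not LLVM's implementation.

```cpp
#include <cassert>

// Hypothetical pointer-based hierarchy; these are not MLIR's Operation classes.
struct Op {
  enum class Kind { MulF, MulI, AddF };
  explicit Op(Kind k) : kind(k) {}
  Kind kind;
};

struct MulFOp : Op {
  MulFOp() : Op(Kind::MulF) {}
  static bool classof(const Op *op) { return op->kind == Kind::MulF; }
};

struct MulIOp : Op {
  MulIOp() : Op(Kind::MulI) {}
  static bool classof(const Op *op) { return op->kind == Kind::MulI; }
};

struct AddFOp : Op {
  AddFOp() : Op(Kind::AddF) {}
  static bool classof(const Op *op) { return op->kind == Kind::AddF; }
};

// Single-type isa: defer to the target class's classof predicate.
template <typename To> bool isa(const Op *op) {
  assert(op && "isa<> used on a null pointer");
  return To::classof(op);
}

// Variadic isa: true if op is an instance of any of the listed classes.
// It peels off one type at a time and recurses until one type remains.
template <typename First, typename Second, typename... Rest>
bool isa(const Op *op) {
  return isa<First>(op) || isa<Second, Rest...>(op);
}
```

If you call a variadic `isa` inside a macro such as `assert`, wrap the call in an extra pair of parentheses. Otherwise the comma between the template arguments splits the macro arguments.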
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
index 788ad28ee4221..a51fcc598ea9f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
@@ -450,7 +450,7 @@ inline Value genTuple(OpBuilder &builder, Location loc,
inline SparseTensorDescriptor getDescriptorFromTensorTuple(Value tensor) {
auto tuple = getTuple(tensor);
- SparseTensorType stt(tuple.getResultTypes()[0].cast<RankedTensorType>());
+ SparseTensorType stt(cast<RankedTensorType>(tuple.getResultTypes()[0]));
return SparseTensorDescriptor(stt, tuple.getInputs());
}
@@ -458,7 +458,7 @@ inline MutSparseTensorDescriptor
getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl<Value> &fields) {
auto tuple = getTuple(tensor);
fields.assign(tuple.getInputs().begin(), tuple.getInputs().end());
- SparseTensorType stt(tuple.getResultTypes()[0].cast<RankedTensorType>());
+ SparseTensorType stt(cast<RankedTensorType>(tuple.getResultTypes()[0]));
return MutSparseTensorDescriptor(stt, fields);
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index afeabb33fcd78..681ba21dd4a35 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -88,9 +88,9 @@ class AffineDimFinder : public AffineExprVisitor<AffineDimFinder> {
// Overrides method from AffineExprVisitor.
void visitDimExpr(AffineDimExpr expr) {
if (pickedDim == nullptr ||
- pickIterType == iterTypes[expr.getPosition()]
- .cast<linalg::IteratorTypeAttr>()
- .getValue()) {
+ pickIterType ==
+ cast<linalg::IteratorTypeAttr>(iterTypes[expr.getPosition()])
+ .getValue()) {
pickedDim = expr;
}
}
@@ -344,7 +344,7 @@ static unsigned getNumNonTrivialIdxExpOnSparseLvls(AffineMap map,
// we can't use `getRankedTensorType`/`getSparseTensorType` here.
// However, we don't need to handle `StorageSpecifierType`, so we
// can use `SparseTensorType` once we guard against non-tensors.
- const auto rtp = tensor.getType().dyn_cast<RankedTensorType>();
+ const auto rtp = dyn_cast<RankedTensorType>(tensor.getType());
if (!rtp)
return 0;
const SparseTensorType stt(rtp);
@@ -1243,7 +1243,7 @@ static void genExpand(CodegenEnv &env, OpBuilder &builder, LoopOrd at,
Location loc = op.getLoc();
if (atStart) {
auto dynShape = {ShapedType::kDynamic};
- Type etp = tensor.getType().cast<ShapedType>().getElementType();
+ Type etp = cast<ShapedType>(tensor.getType()).getElementType();
Type t1 = MemRefType::get(dynShape, etp);
Type t2 = MemRefType::get(dynShape, builder.getI1Type());
Type t3 = MemRefType::get(dynShape, builder.getIndexType());
@@ -1833,7 +1833,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
// required for sparse tensor slice rank reducing too.
Level maxLvlRank = 0;
for (auto operand : op.getOperands()) {
- if (auto rtp = operand.getType().dyn_cast<RankedTensorType>()) {
+ if (auto rtp = dyn_cast<RankedTensorType>(operand.getType())) {
maxLvlRank = std::max(maxLvlRank, SparseTensorType(rtp).getLvlRank());
}
}
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index e8558c1d8d9d4..ae31af0cc572c 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -1061,8 +1061,8 @@ bool Merger::maybeZero(ExprId e) const {
if (expr.kind == TensorExp::Kind::kInvariant) {
if (auto c = expr.val.getDefiningOp<complex::ConstantOp>()) {
ArrayAttr arrayAttr = c.getValue();
- return arrayAttr[0].cast<FloatAttr>().getValue().isZero() &&
- arrayAttr[1].cast<FloatAttr>().getValue().isZero();
+ return cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
+ cast<FloatAttr>(arrayAttr[1]).getValue().isZero();
}
if (auto c = expr.val.getDefiningOp<arith::ConstantIntOp>())
return c.value() == 0;
@@ -1077,7 +1077,7 @@ Type Merger::inferType(ExprId e, Value src) const {
Type dtp = exp(e).val.getType();
// Inspect source type. For vector types, apply the same
// vectorization to the destination type.
- if (auto vtp = src.getType().dyn_cast<VectorType>())
+ if (auto vtp = dyn_cast<VectorType>(src.getType()))
return VectorType::get(vtp.getNumElements(), dtp, vtp.getNumScalableDims());
return dtp;
}
@@ -1085,7 +1085,7 @@ Type Merger::inferType(ExprId e, Value src) const {
/// Ensures that sparse compiler can generate code for expression.
static bool isAdmissibleBranchExp(Operation *op, Block *block, Value v) {
// Arguments are always admissible.
- if (v.isa<BlockArgument>())
+ if (isa<BlockArgument>(v))
return true;
// Accept index anywhere.
Operation *def = v.getDefiningOp();
@@ -1113,7 +1113,7 @@ static bool isAdmissibleBranch(Operation *op, Region &region) {
}
std::optional<ExprId> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
- if (auto arg = v.dyn_cast<BlockArgument>()) {
+ if (auto arg = dyn_cast<BlockArgument>(v)) {
const TensorId tid = makeTensorId(arg.getArgNumber());
// Any argument of the generic op that is not marked as a scalar
// argument is considered a tensor, indexed by the implicit loop
@@ -1346,8 +1346,8 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0,
case TensorExp::Kind::kAbsF:
return rewriter.create<math::AbsFOp>(loc, v0);
case TensorExp::Kind::kAbsC: {
- auto type = v0.getType().cast<ComplexType>();
- auto eltType = type.getElementType().cast<FloatType>();
+ auto type = cast<ComplexType>(v0.getType());
+ auto eltType = cast<FloatType>(type.getElementType());
return rewriter.create<complex::AbsOp>(loc, eltType, v0);
}
case TensorExp::Kind::kAbsI:
@@ -1407,13 +1407,13 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0,
case TensorExp::Kind::kTruncI:
return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
case TensorExp::Kind::kCIm: {
- auto type = v0.getType().cast<ComplexType>();
- auto eltType = type.getElementType().cast<FloatType>();
+ auto type = cast<ComplexType>(v0.getType());
+ auto eltType = cast<FloatType>(type.getElementType());
return rewriter.create<complex::ImOp>(loc, eltType, v0);
}
case TensorExp::Kind::kCRe: {
- auto type = v0.getType().cast<ComplexType>();
- auto eltType = type.getElementType().cast<FloatType>();
+ auto type = cast<ComplexType>(v0.getType());
+ auto eltType = cast<FloatType>(type.getElementType());
return rewriter.create<complex::ReOp>(loc, eltType, v0);
}
case TensorExp::Kind::kBitCast:
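In `isAdmissibleBranchExp`, `v.getDefiningOp()` returns null when `v` is a block argument. For that reason, block arguments are handled first with `isa<BlockArgument>(v)`. For pointers that may legitimately be null, LLVM also provides `dyn_cast_if_present`, the newer spelling of the `dyn_cast_or_null` mentioned in the log, which returns null for a null input instead of asserting. The sketch below uses hypothetical `Operation` and `ConstantOp` classes to show the difference. It is not MLIR's implementation.

```cpp
#include <cassert>

// Hypothetical stand-ins for an operation hierarchy; these are not MLIR's classes.
struct Operation {
  enum class Kind { Constant, MulF };
  explicit Operation(Kind k) : kind(k) {}
  Kind kind;
};

struct ConstantOp : Operation {
  ConstantOp() : Operation(Kind::Constant) {}
  static bool classof(const Operation *op) {
    return op->kind == Kind::Constant;
  }
};

// dyn_cast: the argument must be non-null; a kind mismatch yields nullptr.
template <typename To> To *dyn_cast(Operation *op) {
  assert(op && "dyn_cast on a null pointer");
  return To::classof(op) ? static_cast<To *>(op) : nullptr;
}

// dyn_cast_if_present: like dyn_cast, but a null argument simply yields nullptr,
// which folds the usual `if (auto *def = v.getDefiningOp())` check into the cast.
template <typename To> To *dyn_cast_if_present(Operation *op) {
  return op ? dyn_cast<To>(op) : nullptr;
}
```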
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 57e5df4633438..d93d88630fd86 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -60,20 +60,20 @@ struct CastOpInterface
// type in case the input is an unranked tensor type.
// Case 1: Casting an unranked tensor
- if (castOp.getSource().getType().isa<UnrankedTensorType>()) {
+ if (isa<UnrankedTensorType>(castOp.getSource().getType())) {
// When casting to a ranked tensor, we cannot infer any static offset or
// strides from the source. Assume fully dynamic.
return getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace);
}
// Case 2: Casting to an unranked tensor type
- if (castOp.getType().isa<UnrankedTensorType>()) {
+ if (isa<UnrankedTensorType>(castOp.getType())) {
return getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace);
}
// Case 3: Ranked tensor -> ranked tensor. The offsets and strides do not
// change.
- auto rankedResultType = castOp.getType().cast<RankedTensorType>();
+ auto rankedResultType = cast<RankedTensorType>(castOp.getType());
return MemRefType::get(
rankedResultType.getShape(), rankedResultType.getElementType(),
maybeSrcBufferType->cast<MemRefType>().getLayout(), memorySpace);
@@ -158,7 +158,7 @@ struct CollapseShapeOpInterface
if (failed(maybeBuffer))
return failure();
Value buffer = *maybeBuffer;
- auto bufferType = buffer.getType().cast<MemRefType>();
+ auto bufferType = cast<MemRefType>(buffer.getType());
if (tensorResultType.getRank() == 0) {
// 0-d collapses must go through a different op builder.
@@ -383,11 +383,9 @@ struct ExtractSliceOpInterface
SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
- return memref::SubViewOp::inferRankReducedResultType(
- extractSliceOp.getType().getShape(),
- srcMemrefType->cast<MemRefType>(), mixedOffsets, mixedSizes,
- mixedStrides)
- .cast<BaseMemRefType>();
+ return cast<BaseMemRefType>(memref::SubViewOp::inferRankReducedResultType(
+ extractSliceOp.getType().getShape(), srcMemrefType->cast<MemRefType>(),
+ mixedOffsets, mixedSizes, mixedStrides));
}
};
@@ -459,7 +457,7 @@ struct FromElementsOpInterface
auto fromElementsOp = cast<tensor::FromElementsOp>(op);
// Should the buffer be deallocated?
bool dealloc = shouldDeallocateOpResult(
- fromElementsOp.getResult().cast<OpResult>(), options);
+ cast<OpResult>(fromElementsOp.getResult()), options);
// TODO: Implement memory space for this op.
if (options.defaultMemorySpace != Attribute())
@@ -467,7 +465,7 @@ struct FromElementsOpInterface
// Allocate a buffer for the result.
Location loc = op->getLoc();
- auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
+ auto tensorType = cast<RankedTensorType>(fromElementsOp.getType());
auto shape = tensorType.getShape();
// TODO: Create alloc_tensor ops during TensorCopyInsertion.
FailureOr<Value> tensorAlloc =
@@ -540,7 +538,7 @@ static Value lowerGenerateLikeOpBody(RewriterBase &rewriter, Location loc,
ValueRange dynamicSizes,
Region &generateBody) {
assert(generateBody.hasOneBlock() && "expected body with single block");
- auto tensorType = tensorDestination.getType().cast<RankedTensorType>();
+ auto tensorType = cast<RankedTensorType>(tensorDestination.getType());
assert(generateBody.getNumArguments() == tensorType.getRank() &&
"rank mismatch");
@@ -579,7 +577,7 @@ struct GenerateOpInterface
auto generateOp = cast<tensor::GenerateOp>(op);
// Should the buffer be deallocated?
bool dealloc = shouldDeallocateOpResult(
- generateOp.getResult().cast<OpResult>(), options);
+ cast<OpResult>(generateOp.getResult()), options);
// TODO: Implement memory space for this op.
if (options.defaultMemorySpace != Attribute())
@@ -800,12 +798,11 @@ struct InsertSliceOpInterface
return failure();
// Take a subview of the destination buffer.
- auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
+ auto dstMemrefType = cast<MemRefType>(dstMemref->getType());
auto subviewMemRefType =
- memref::SubViewOp::inferRankReducedResultType(
+ cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
insertSliceOp.getSourceType().getShape(), dstMemrefType,
- mixedOffsets, mixedSizes, mixedStrides)
- .cast<MemRefType>();
+ mixedOffsets, mixedSizes, mixedStrides));
Value subView = rewriter.create<memref::SubViewOp>(
loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
mixedStrides);
@@ -900,7 +897,7 @@ struct PadOpInterface
// Should the buffer be deallocated?
bool dealloc =
- shouldDeallocateOpResult(padOp.getResult().cast<OpResult>(), options);
+ shouldDeallocateOpResult(cast<OpResult>(padOp.getResult()), options);
// Allocate a buffer for the padded result.
FailureOr<Value> tensorAlloc =
allocateTensorForShapedValue(rewriter, loc, padOp.getResult(),
@@ -992,7 +989,7 @@ struct ReshapeOpInterface
return failure();
auto resultMemRefType = getMemRefType(
reshapeOp.getResult(), options, /*layout=*/{},
- srcBuffer->getType().cast<BaseMemRefType>().getMemorySpace());
+ cast<BaseMemRefType>(srcBuffer->getType()).getMemorySpace());
replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
rewriter, op, resultMemRefType, *srcBuffer, *shapeBuffer);
return success();
@@ -1039,14 +1036,13 @@ struct ParallelInsertSliceOpInterface
return failure();
// Take a subview of the destination buffer.
- auto destBufferType = destBuffer->getType().cast<MemRefType>();
+ auto destBufferType = cast<MemRefType>(destBuffer->getType());
auto subviewMemRefType =
- memref::SubViewOp::inferRankReducedResultType(
+ cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
parallelInsertSliceOp.getSourceType().getShape(), destBufferType,
parallelInsertSliceOp.getMixedOffsets(),
parallelInsertSliceOp.getMixedSizes(),
- parallelInsertSliceOp.getMixedStrides())
- .cast<MemRefType>();
+ parallelInsertSliceOp.getMixedStrides()));
Value subview = rewriter.create<memref::SubViewOp>(
parallelInsertSliceOp.getLoc(), subviewMemRefType, *destBuffer,
parallelInsertSliceOp.getMixedOffsets(),
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp b/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp
index b5e75e081886a..968d68e143fe1 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp
@@ -29,7 +29,7 @@ using namespace mlir::tensor;
/// Get the dimension size of a value of RankedTensor type at the
static OpFoldResult getShapeDimSize(OpBuilder &b, Location loc,
Value rankedTensor, int64_t dimIdx) {
- RankedTensorType tensorType = rankedTensor.getType().cast<RankedTensorType>();
+ RankedTensorType tensorType = cast<RankedTensorType>(rankedTensor.getType());
if (!tensorType.isDynamicDim(dimIdx)) {
return b.getIndexAttr(tensorType.getDimSize(dimIdx));
}
@@ -41,7 +41,7 @@ static OpFoldResult getShapeDimSize(OpBuilder &b, Location loc,
static SmallVector<OpFoldResult> getShapeDimSizes(OpBuilder &b, Location loc,
Value rankedTensor) {
SmallVector<OpFoldResult> dimSizes;
- RankedTensorType tensorType = rankedTensor.getType().cast<RankedTensorType>();
+ RankedTensorType tensorType = cast<RankedTensorType>(rankedTensor.getType());
for (unsigned i = 0; i < tensorType.getRank(); i++)
dimSizes.push_back(getShapeDimSize(b, loc, rankedTensor, i));
return dimSizes;
diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
index 71dddd1363794..4ecb800caab42 100644
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -44,7 +44,7 @@ PadOp mlir::tensor::createPadHighOp(RankedTensorType type, Value source,
SmallVector<Value> mlir::tensor::createDynamicDimValues(OpBuilder &b,
Location loc,
Value rankedTensor) {
- auto tensorTy = rankedTensor.getType().cast<RankedTensorType>();
+ auto tensorTy = cast<RankedTensorType>(rankedTensor.getType());
SmallVector<Value> dynamicDims;
for (const auto &en : llvm::enumerate(tensorTy.getShape())) {
if (en.value() == ShapedType::kDynamic)
@@ -57,7 +57,7 @@ SmallVector<Value> mlir::tensor::createDynamicDimValues(OpBuilder &b,
FailureOr<OpFoldResult> mlir::tensor::createDimValue(OpBuilder &b, Location loc,
Value rankedTensor,
int64_t dim) {
- auto tensorTy = rankedTensor.getType().dyn_cast<RankedTensorType>();
+ auto tensorTy = dyn_cast<RankedTensorType>(rankedTensor.getType());
if (!tensorTy)
return failure();
auto shape = tensorTy.getShape();
@@ -70,7 +70,7 @@ FailureOr<OpFoldResult> mlir::tensor::createDimValue(OpBuilder &b, Location loc,
SmallVector<OpFoldResult>
mlir::tensor::createDimValues(OpBuilder &b, Location loc, Value rankedTensor) {
- auto tensorTy = rankedTensor.getType().cast<RankedTensorType>();
+ auto tensorTy = cast<RankedTensorType>(rankedTensor.getType());
SmallVector<OpFoldResult> dims;
for (const auto &en : llvm::enumerate(tensorTy.getShape())) {
if (ShapedType::isDynamic(en.value())) {
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
index 7b47338649724..44f64f76e9b02 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
@@ -34,9 +34,9 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
PatternRewriter &rewriter) const override {
Value input = op.getInput();
Value weight = op.getWeight();
- ShapedType inputType = input.getType().cast<ShapedType>();
- ShapedType weightType = weight.getType().cast<ShapedType>();
- ShapedType resultType = op.getType().cast<ShapedType>();
+ ShapedType inputType = cast<ShapedType>(input.getType());
+ ShapedType weightType = cast<ShapedType>(weight.getType());
+ ShapedType resultType = cast<ShapedType>(op.getType());
auto numDynamic =
llvm::count_if(inputType.getShape(), ShapedType::isDynamic);
@@ -66,7 +66,7 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
auto quantizationInfo = op.getQuantizationInfo();
int64_t iZp = quantizationInfo->getInputZp();
- if (!validIntegerRange(inputETy.cast<IntegerType>(), iZp))
+ if (!validIntegerRange(cast<IntegerType>(inputETy), iZp))
return rewriter.notifyMatchFailure(
op, "tosa.conv op quantization has zp outside of input range");
@@ -116,7 +116,7 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
weightShape[3]};
auto revisedWeightShapeType = RankedTensorType::get(
revisedWeightShape,
- weight.getType().dyn_cast<RankedTensorType>().getElementType());
+ dyn_cast<RankedTensorType>(weight.getType()).getElementType());
auto reshapedWeight = rewriter
.create<tosa::ReshapeOp>(
op.getLoc(), revisedWeightShapeType, weight,
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
index 81ec7fd663791..488e46d1339a1 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
@@ -28,9 +28,9 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
PatternRewriter &rewriter) const override {
Value input = op.getInput();
Value weight = op.getWeight();
- ShapedType inputType = input.getType().cast<ShapedType>();
- ShapedType weightType = weight.getType().cast<ShapedType>();
- ShapedType resultType = op.getOutput().getType().cast<ShapedType>();
+ ShapedType inputType = cast<ShapedType>(input.getType());
+ ShapedType weightType = cast<ShapedType>(weight.getType());
+ ShapedType resultType = cast<ShapedType>(op.getOutput().getType());
if (!(inputType.hasStaticShape() && weightType.hasStaticShape() &&
resultType.hasStaticShape())) {
@@ -52,7 +52,7 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
inputShape[0], inputShape[1], inputShape[2], inputShape[3], 1};
inputType = RankedTensorType::get(
revisedInputShape,
- input.getType().dyn_cast<RankedTensorType>().getElementType());
+ dyn_cast<RankedTensorType>(input.getType()).getElementType());
input = rewriter
.create<tosa::ReshapeOp>(
op.getLoc(), inputType, input,
@@ -76,7 +76,7 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
auto applyZp = [&](Value val, int64_t zp) -> Value {
if (zp == 0)
return val;
- auto ety = val.getType().cast<ShapedType>().getElementType();
+ auto ety = cast<ShapedType>(val.getType()).getElementType();
auto zpTy = RankedTensorType::get({}, ety);
auto zpAttr =
DenseElementsAttr::get(zpTy, rewriter.getIntegerAttr(ety, zp));
@@ -126,17 +126,17 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
inputType.getDimSize(2), inputType.getDimSize(3), weightShape[3]};
auto mulShapeType = RankedTensorType::get(
mulShape,
- weight.getType().dyn_cast<RankedTensorType>().getElementType());
+ dyn_cast<RankedTensorType>(weight.getType()).getElementType());
Value mulValue = rewriter
.create<tosa::MulOp>(op.getLoc(), mulShapeType, input,
weight, /*shift=*/0)
.getResult();
// Reshape output to [N, H, W, C * M].
- auto outputShape = op.getOutput().getType().cast<ShapedType>().getShape();
+ auto outputShape = cast<ShapedType>(op.getOutput().getType()).getShape();
auto outputShapeType = RankedTensorType::get(
outputShape,
- input.getType().dyn_cast<RankedTensorType>().getElementType());
+ dyn_cast<RankedTensorType>(input.getType()).getElementType());
auto outputValue = rewriter.create<tosa::ReshapeOp>(
op.getLoc(), outputShapeType, mulValue,
rewriter.getDenseI64ArrayAttr(outputShape));
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
index 74533defd055f..87563c1761a8d 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -56,7 +56,7 @@ TosaOp createOpAndInfer(PatternRewriter &rewriter, Location loc, Type resultTy,
// Compute the knowledge based on the inferred type.
auto inferredKnowledge =
mlir::tosa::ValueKnowledge::getPessimisticValueState();
- inferredKnowledge.dtype = resultTy.cast<ShapedType>().getElementType();
+ inferredKnowledge.dtype = cast<ShapedType>(resultTy).getElementType();
inferredKnowledge.hasRank = predictedShape.hasRank();
if (predictedShape.hasRank()) {
for (auto dim : predictedShape.getDims()) {
@@ -83,10 +83,10 @@ class TransposeConvNonStridedConverter
Value weight = op->getOperand(1);
Value bias = op->getOperand(2);
- ShapedType inputTy = input.getType().cast<ShapedType>();
- ShapedType weightTy = weight.getType().cast<ShapedType>();
- ShapedType biasTy = bias.getType().cast<ShapedType>();
- ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();
+ ShapedType inputTy = cast<ShapedType>(input.getType());
+ ShapedType weightTy = cast<ShapedType>(weight.getType());
+ ShapedType biasTy = cast<ShapedType>(bias.getType());
+ ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());
llvm::ArrayRef<int64_t> stride = op.getStride();
llvm::ArrayRef<int64_t> pad = op.getOutPad();
@@ -146,10 +146,10 @@ class TransposeConvStridedConverter
Value weight = op->getOperand(1);
Value bias = op->getOperand(2);
- ShapedType inputTy = input.getType().cast<ShapedType>();
- ShapedType weightTy = weight.getType().cast<ShapedType>();
- ShapedType biasTy = bias.getType().cast<ShapedType>();
- ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();
+ ShapedType inputTy = cast<ShapedType>(input.getType());
+ ShapedType weightTy = cast<ShapedType>(weight.getType());
+ ShapedType biasTy = cast<ShapedType>(bias.getType());
+ ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());
Type inputETy = inputTy.getElementType();
Type weightETy = weightTy.getElementType();
@@ -202,7 +202,7 @@ class TransposeConvStridedConverter
weight, weightPaddingVal);
}
- weightTy = weight.getType().cast<ShapedType>();
+ weightTy = cast<ShapedType>(weight.getType());
weightHeight = weightTy.getDimSize(1);
weightWidth = weightTy.getDimSize(2);
@@ -231,7 +231,7 @@ class TransposeConvStridedConverter
weight = createOpAndInfer<tosa::ReshapeOp>(
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
rewriter.getDenseI64ArrayAttr(weightReshapeDims1));
- ShapedType restridedWeightTy = weight.getType().cast<ShapedType>();
+ ShapedType restridedWeightTy = cast<ShapedType>(weight.getType());
weight = createOpAndInfer<tosa::ReverseOp>(
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
@@ -297,7 +297,7 @@ class TransposeConvStridedConverter
}
// Factor the resulting width / height.
- ShapedType convTy = conv2d.getType().cast<ShapedType>();
+ ShapedType convTy = cast<ShapedType>(conv2d.getType());
Type convETy = convTy.getElementType();
int64_t convHeight = convTy.getDimSize(1);
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp
index 9e2102ee1d0ab..302e2793f0a32 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp
@@ -72,7 +72,7 @@ DenseElementsAttr transpose(ElementsAttr attr, ShapedType inputType,
auto baseType = inputType.getElementType();
// Handle possible integer types
- if (auto intType = baseType.dyn_cast<IntegerType>()) {
+ if (auto intType = dyn_cast<IntegerType>(baseType)) {
switch (intType.getWidth()) {
case 1:
return transposeType<bool>(attr, inputType, outputType, permValues);
@@ -102,7 +102,7 @@ struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
LogicalResult matchAndRewrite(tosa::TransposeOp op,
PatternRewriter &rewriter) const override {
- auto outputType = op.getType().cast<ShapedType>();
+ auto outputType = cast<ShapedType>(op.getType());
// TOSA supports quantized types.
if (!outputType.getElementType().isIntOrIndexOrFloat())
return failure();
@@ -122,7 +122,7 @@ struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
permAttr.getValues<APInt>(),
[](const APInt &val) { return val.getSExtValue(); }));
- auto inputType = op.getInput1().getType().cast<ShapedType>();
+ auto inputType = cast<ShapedType>(op.getInput1().getType());
auto resultAttr = transpose(inputValues, inputType, outputType, permValues);
rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputType, resultAttr);
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
index 0c03cecf61bc4..3e2da9df3f94b 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -54,7 +54,7 @@ void propagateShapesToTosaIf(
for (unsigned int i = 1, s = op.getNumOperands(); i < s; i++) {
auto inferredTy = shapesStorage[op.getOperand(i)];
auto blockArg = frontBlock.getArgument(i - 1);
- auto oldType = blockArg.getType().cast<ShapedType>();
+ auto oldType = cast<ShapedType>(blockArg.getType());
if (inferredTy.hasRank()) {
Type newType = oldType.clone(inferredTy.getDims());
@@ -89,7 +89,7 @@ void propagateShapesToTosaWhile(
// loop body / condition for tosa.while.
llvm::SmallVector<Type> argTypes;
for (auto operand : op.getOperands()) {
- auto operandTy = operand.getType().cast<ShapedType>();
+ auto operandTy = cast<ShapedType>(operand.getType());
auto shapedTypeComponent = shapesStorage[operand];
if (shapedTypeComponent.hasRank()) {
auto newTy = operandTy.clone(shapedTypeComponent.getDims());
@@ -188,7 +188,7 @@ void propagateShapesToTosaWhile(
void propagateShapesInRegion(Region &region) {
DenseMap<Value, ShapedTypeComponents> shapesStorage;
auto setShapes = [&](Value val, Type t) {
- if (auto st = t.dyn_cast<ShapedType>())
+ if (auto st = dyn_cast<ShapedType>(t))
shapesStorage[val] = st;
else
shapesStorage[val] = t;
@@ -247,8 +247,7 @@ void propagateShapesInRegion(Region &region) {
// Compute the knowledge based on the inferred type.
auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
- inferredKnowledge.dtype =
- resultTy.cast<ShapedType>().getElementType();
+ inferredKnowledge.dtype = cast<ShapedType>(resultTy).getElementType();
inferredKnowledge.hasRank = predictedShape.hasRank();
if (predictedShape.hasRank()) {
for (auto dim : predictedShape.getDims()) {
@@ -274,7 +273,7 @@ void propagateShapesInRegion(Region &region) {
for (auto it : shapesStorage) {
auto result = it.second;
if (result.hasRank()) {
- Type t = it.first.getType().cast<ShapedType>().clone(result.getDims());
+ Type t = cast<ShapedType>(it.first.getType()).clone(result.getDims());
it.first.setType(t);
}
}
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
index b18e3b4bd2777..bcfcbbbbcee69 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
@@ -82,8 +82,8 @@ static LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter,
Location loc,
RankedTensorType outputType,
Value &input1, Value &input2) {
- auto input1Ty = input1.getType().dyn_cast<RankedTensorType>();
- auto input2Ty = input2.getType().dyn_cast<RankedTensorType>();
+ auto input1Ty = dyn_cast<RankedTensorType>(input1.getType());
+ auto input2Ty = dyn_cast<RankedTensorType>(input2.getType());
if (!input1Ty || !input2Ty) {
return rewriter.notifyMatchFailure(loc, "input not a ranked tensor");
@@ -106,9 +106,9 @@ static LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter,
}
ArrayRef<int64_t> higherRankShape =
- higherTensorValue.getType().cast<RankedTensorType>().getShape();
+ cast<RankedTensorType>(higherTensorValue.getType()).getShape();
ArrayRef<int64_t> lowerRankShape =
- lowerTensorValue.getType().cast<RankedTensorType>().getShape();
+ cast<RankedTensorType>(lowerTensorValue.getType()).getShape();
SmallVector<int64_t, 4> reshapeOutputShape;
@@ -116,7 +116,7 @@ static LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter,
.failed())
return rewriter.notifyMatchFailure(loc, "fail to compute a reshape type");
- auto reshapeInputType = lowerTensorValue.getType().cast<RankedTensorType>();
+ auto reshapeInputType = cast<RankedTensorType>(lowerTensorValue.getType());
auto reshapeOutputType = RankedTensorType::get(
ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType());
@@ -155,7 +155,7 @@ struct ConvertTosaOp : public OpRewritePattern<OpTy> {
Value input2 = tosaBinaryOp.getInput2();
Value output = tosaBinaryOp.getResult();
- auto outputType = output.getType().dyn_cast<RankedTensorType>();
+ auto outputType = dyn_cast<RankedTensorType>(output.getType());
if (!outputType)
return failure();
@@ -183,7 +183,7 @@ struct ConvertTosaOp<tosa::MulOp> : public OpRewritePattern<tosa::MulOp> {
Value input2 = tosaBinaryOp.getInput2();
int32_t shift = tosaBinaryOp.getShift();
Value output = tosaBinaryOp.getResult();
- auto outputType = output.getType().dyn_cast<RankedTensorType>();
+ auto outputType = dyn_cast<RankedTensorType>(output.getType());
if (!outputType)
return failure();
@@ -214,7 +214,7 @@ struct ConvertTosaOp<tosa::ArithmeticRightShiftOp>
Value input2 = tosaBinaryOp.getInput2();
int32_t round = tosaBinaryOp.getRound();
Value output = tosaBinaryOp.getResult();
- auto outputType = output.getType().dyn_cast<RankedTensorType>();
+ auto outputType = dyn_cast<RankedTensorType>(output.getType());
if (!outputType)
return failure();
@@ -242,7 +242,7 @@ struct ConvertTosaOp<tosa::SelectOp> : public OpRewritePattern<tosa::SelectOp> {
Value input3 = tosaOp.getOnFalse();
Value output = tosaOp.getResult();
- auto outputType = output.getType().dyn_cast<RankedTensorType>();
+ auto outputType = dyn_cast<RankedTensorType>(output.getType());
if (!outputType)
return rewriter.notifyMatchFailure(tosaOp, "output not a ranked tensor");
@@ -265,9 +265,9 @@ struct ConvertTosaOp<tosa::SelectOp> : public OpRewritePattern<tosa::SelectOp> {
tosaOp,
"cannot rewrite as the rank of all operands is already aligned");
- int32_t result1Rank = input1.getType().cast<RankedTensorType>().getRank();
- int32_t result2Rank = input2.getType().cast<RankedTensorType>().getRank();
- int32_t result3Rank = input3.getType().cast<RankedTensorType>().getRank();
+ int32_t result1Rank = cast<RankedTensorType>(input1.getType()).getRank();
+ int32_t result2Rank = cast<RankedTensorType>(input2.getType()).getRank();
+ int32_t result3Rank = cast<RankedTensorType>(input3.getType()).getRank();
if ((result1Rank != result2Rank) || (result2Rank != result3Rank))
return rewriter.notifyMatchFailure(
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 4cb727b00ca0c..5605080384bd7 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -106,7 +106,7 @@ void TosaValidation::runOnOperation() {
getOperation().walk([&](Operation *op) {
for (Value operand : op->getOperands()) {
if ((profileType == TosaProfileEnum::BaseInference) &&
- getElementTypeOrSelf(operand).isa<FloatType>()) {
+ isa<FloatType>(getElementTypeOrSelf(operand))) {
return signalPassFailure();
}
if (getElementTypeOrSelf(operand).isF64()) {
diff --git a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
index 0b5fc451115cc..1c4ae1f27319f 100644
--- a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
@@ -116,16 +116,16 @@ ConvOpQuantizationAttr
mlir::tosa::buildConvOpQuantizationAttr(OpBuilder &builder, Value input,
Value weight) {
- auto inputType = input.getType().dyn_cast<ShapedType>();
- auto weightType = weight.getType().dyn_cast<ShapedType>();
+ auto inputType = dyn_cast<ShapedType>(input.getType());
+ auto weightType = dyn_cast<ShapedType>(weight.getType());
if (!inputType || !weightType)
return nullptr;
auto inputQType = GET_UQTYPE(inputType);
auto weightPerTensorQType = GET_UQTYPE(weightType);
- auto weightPerAxisQType = weightType.getElementType()
- .dyn_cast<quant::UniformQuantizedPerAxisType>();
+ auto weightPerAxisQType =
+ dyn_cast<quant::UniformQuantizedPerAxisType>(weightType.getElementType());
// Weights must be either per-tensor quantized or per-axis quantized.
assert(!((bool)weightPerTensorQType && (bool)weightPerAxisQType) &&
@@ -160,8 +160,8 @@ MatMulOpQuantizationAttr
mlir::tosa::buildMatMulOpQuantizationAttr(OpBuilder &builder, Value a,
Value b) {
- auto aType = a.getType().dyn_cast<ShapedType>();
- auto bType = b.getType().dyn_cast<ShapedType>();
+ auto aType = dyn_cast<ShapedType>(a.getType());
+ auto bType = dyn_cast<ShapedType>(b.getType());
if (!aType || !bType)
return nullptr;
@@ -189,8 +189,8 @@ UnaryOpQuantizationAttr
mlir::tosa::buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input,
Type outputRawType) {
- auto inputType = input.getType().dyn_cast<ShapedType>();
- auto outputType = outputRawType.dyn_cast<ShapedType>();
+ auto inputType = dyn_cast<ShapedType>(input.getType());
+ auto outputType = dyn_cast<ShapedType>(outputRawType);
if (!inputType || !outputType)
return nullptr;
@@ -215,7 +215,7 @@ mlir::tosa::buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input,
PadOpQuantizationAttr mlir::tosa::buildPadOpQuantizationAttr(OpBuilder &builder,
Value input) {
- auto inputType = input.getType().dyn_cast<ShapedType>();
+ auto inputType = dyn_cast<ShapedType>(input.getType());
if (!inputType)
return nullptr;
@@ -235,8 +235,8 @@ PadOpQuantizationAttr mlir::tosa::buildPadOpQuantizationAttr(OpBuilder &builder,
Type mlir::tosa::buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType,
Value input, Value weight) {
- auto inputType = input.getType().dyn_cast<ShapedType>();
- auto weightType = weight.getType().dyn_cast<ShapedType>();
+ auto inputType = dyn_cast<ShapedType>(input.getType());
+ auto weightType = dyn_cast<ShapedType>(weight.getType());
assert(inputType && weightType &&
"Could not extract input or weight tensors from Conv op");
@@ -250,7 +250,7 @@ Type mlir::tosa::buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType,
unsigned inputBits = inputQType.getStorageTypeIntegralWidth();
unsigned weightBits = weightQType.getStorageTypeIntegralWidth();
- auto outputShapedType = outputType.dyn_cast<ShapedType>();
+ auto outputShapedType = dyn_cast<ShapedType>(outputType);
assert(outputShapedType &&
"Could not extract output shape type from Conv op");
@@ -274,8 +274,8 @@ Type mlir::tosa::buildQTypeFromMinMax(OpBuilder builder, Type inputDType,
auto convfunc =
quant::ExpressedToQuantizedConverter::forInputType(inputDType);
- auto minElems = minAttr.dyn_cast<DenseFPElementsAttr>();
- auto maxElems = maxAttr.dyn_cast<DenseFPElementsAttr>();
+ auto minElems = dyn_cast<DenseFPElementsAttr>(minAttr);
+ auto maxElems = dyn_cast<DenseFPElementsAttr>(maxAttr);
SmallVector<double, 2> min, max;
@@ -291,12 +291,12 @@ Type mlir::tosa::buildQTypeFromMinMax(OpBuilder builder, Type inputDType,
for (auto i : maxElems)
max.push_back(FloatAttr::getValueAsDouble(i));
} else { // Just a single FP value.
- auto minVal = minAttr.dyn_cast<FloatAttr>();
+ auto minVal = dyn_cast<FloatAttr>(minAttr);
if (minVal)
min.push_back(minVal.getValueAsDouble());
else
return {};
- auto maxVal = maxAttr.dyn_cast<FloatAttr>();
+ auto maxVal = dyn_cast<FloatAttr>(maxAttr);
if (maxVal)
max.push_back(maxVal.getValueAsDouble());
else
@@ -309,7 +309,7 @@ Type mlir::tosa::buildQTypeFromMinMax(OpBuilder builder, Type inputDType,
builder.getUnknownLoc(), quantBits.getInt(), min[0], max[0],
narrowRange.getValue(), convfunc.expressedType, isSigned);
} else if (min.size() > 1) { // Per-axis quant on filterQuantDim.
- auto shape = inputDType.dyn_cast<ShapedType>();
+ auto shape = dyn_cast<ShapedType>(inputDType);
if (!shape)
return {};
if ((filterQuantDim) >= 0 && (shape.getRank() > filterQuantDim)) {
diff --git a/mlir/lib/Dialect/Traits.cpp b/mlir/lib/Dialect/Traits.cpp
index a1a30325258cf..2ae67dcfd0bef 100644
--- a/mlir/lib/Dialect/Traits.cpp
+++ b/mlir/lib/Dialect/Traits.cpp
@@ -116,7 +116,7 @@ bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
/// Returns the shape of the given type. Scalars will be considered as having a
/// shape with zero dimensions.
static ArrayRef<int64_t> getShape(Type type) {
- if (auto sType = type.dyn_cast<ShapedType>())
+ if (auto sType = dyn_cast<ShapedType>(type))
return sType.getShape();
return {};
}
@@ -142,8 +142,8 @@ Type OpTrait::util::getBroadcastedType(Type type1, Type type2,
// If one of the types is unranked tensor, then the other type shouldn't be
// vector and the result should have unranked tensor type.
- if (type1.isa<UnrankedTensorType>() || type2.isa<UnrankedTensorType>()) {
- if (type1.isa<VectorType>() || type2.isa<VectorType>())
+ if (isa<UnrankedTensorType>(type1) || isa<UnrankedTensorType>(type2)) {
+ if (isa<VectorType>(type1) || isa<VectorType>(type2))
return {};
return UnrankedTensorType::get(elementType);
}
@@ -151,7 +151,7 @@ Type OpTrait::util::getBroadcastedType(Type type1, Type type2,
// Returns the type kind if the given type is a vector or ranked tensor type.
// Returns std::nullopt otherwise.
auto getCompositeTypeKind = [](Type type) -> std::optional<TypeID> {
- if (type.isa<VectorType, RankedTensorType>())
+ if (isa<VectorType, RankedTensorType>(type))
return type.getTypeID();
return std::nullopt;
};
@@ -189,8 +189,8 @@ Type OpTrait::util::getBroadcastedType(Type type1, Type type2,
template <typename iterator_range>
static std::tuple<bool, bool> hasTensorOrVectorType(iterator_range types) {
return std::make_tuple(
- llvm::any_of(types, [](Type t) { return t.isa<TensorType>(); }),
- llvm::any_of(types, [](Type t) { return t.isa<VectorType>(); }));
+ llvm::any_of(types, [](Type t) { return isa<TensorType>(t); }),
+ llvm::any_of(types, [](Type t) { return isa<VectorType>(t); }));
}
static bool isCompatibleInferredReturnShape(ArrayRef<int64_t> inferred,
@@ -242,7 +242,7 @@ LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
return op->emitError("cannot broadcast vector with tensor");
auto rankedOperands = make_filter_range(
- op->getOperandTypes(), [](Type t) { return t.isa<RankedTensorType>(); });
+ op->getOperandTypes(), [](Type t) { return isa<RankedTensorType>(t); });
// If all operands are unranked, then all result shapes are possible.
if (rankedOperands.empty())
@@ -261,7 +261,7 @@ LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
}
auto rankedResults = make_filter_range(
- op->getResultTypes(), [](Type t) { return t.isa<RankedTensorType>(); });
+ op->getResultTypes(), [](Type t) { return isa<RankedTensorType>(t); });
// If all of the results are unranked then no further verification.
if (rankedResults.empty())
diff --git a/mlir/lib/Dialect/Transform/Transforms/CheckUses.cpp b/mlir/lib/Dialect/Transform/Transforms/CheckUses.cpp
index 05fba01d689cb..45fa644f42ec3 100644
--- a/mlir/lib/Dialect/Transform/Transforms/CheckUses.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/CheckUses.cpp
@@ -148,14 +148,14 @@ class TransformOpMemFreeAnalysis {
// TODO: when this ported to the dataflow analysis infra, we should have
// proper support for region-based control flow.
Operation *valueSource =
- operand.get().isa<OpResult>()
+ isa<OpResult>(operand.get())
? operand.get().getDefiningOp()
: operand.get().getParentBlock()->getParentOp();
auto iface = cast<MemoryEffectOpInterface>(valueSource);
SmallVector<MemoryEffects::EffectInstance> instances;
iface.getEffectsOnResource(transform::TransformMappingResource::get(),
instances);
- assert((operand.get().isa<BlockArgument>() ||
+ assert((isa<BlockArgument>(operand.get()) ||
hasEffect<MemoryEffects::Allocate>(instances, operand.get())) &&
"expected the op defining the value to have an allocation effect "
"on it");
@@ -182,7 +182,7 @@ class TransformOpMemFreeAnalysis {
// value is defined in the middle of the block, i.e., is not a block
// argument.
bool isOutermost = ancestor == ancestors.front();
- bool isFromBlockPartial = isOutermost && operand.get().isa<OpResult>();
+ bool isFromBlockPartial = isOutermost && isa<OpResult>(operand.get());
// Check if the value may be freed by operations between its definition
// (allocation) point in its block and the terminator of the block or the
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 94fa2d3de22fb..853889269d0fb 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -162,7 +162,7 @@ ArrayAttr mlir::getReassociationIndicesAttribute(
SmallVector<Attribute, 4> reassociationAttr =
llvm::to_vector<4>(llvm::map_range(
reassociation, [&](const ReassociationIndices &indices) -> Attribute {
- return b.getI64ArrayAttr(indices).cast<Attribute>();
+ return cast<Attribute>(b.getI64ArrayAttr(indices));
}));
return b.getArrayAttr(reassociationAttr);
}
@@ -267,7 +267,7 @@ LogicalResult mlir::reshapeLikeShapesAreCompatible(
}
bool mlir::hasNonIdentityLayout(Type type) {
- if (auto memrefType = type.dyn_cast<MemRefType>())
+ if (auto memrefType = dyn_cast<MemRefType>(type))
return !memrefType.getLayout().isIdentity();
return false;
}
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 45edd5f89ffed..09137d3336cc0 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -19,7 +19,7 @@ bool isZeroIndex(OpFoldResult v) {
if (!v)
return false;
if (auto attr = v.dyn_cast<Attribute>()) {
- IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>();
+ IntegerAttr intAttr = dyn_cast<IntegerAttr>(attr);
return intAttr && intAttr.getValue().isZero();
}
if (auto cst = v.get<Value>().getDefiningOp<arith::ConstantIndexOp>())
@@ -53,7 +53,7 @@ void dispatchIndexOpFoldResult(OpFoldResult ofr,
SmallVectorImpl<int64_t> &staticVec) {
auto v = ofr.dyn_cast<Value>();
if (!v) {
- APInt apInt = ofr.get<Attribute>().cast<IntegerAttr>().getValue();
+ APInt apInt = cast<IntegerAttr>(ofr.get<Attribute>()).getValue();
staticVec.push_back(apInt.getSExtValue());
return;
}
@@ -71,8 +71,8 @@ void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
return llvm::to_vector<4>(
- llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
- return a.cast<IntegerAttr>().getInt();
+ llvm::map_range(cast<ArrayAttr>(attr), [](Attribute a) -> int64_t {
+ return cast<IntegerAttr>(a).getInt();
}));
}
@@ -124,7 +124,7 @@ std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
}
// Case 2: Check for IntegerAttr.
Attribute attr = ofr.dyn_cast<Attribute>();
- if (auto intAttr = attr.dyn_cast_or_null<IntegerAttr>())
+ if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr))
return intAttr.getValue().getSExtValue();
return std::nullopt;
}
@@ -184,7 +184,7 @@ decomposeMixedValues(Builder &b,
SmallVector<Value> dynamicValues;
for (const auto &it : mixedValues) {
if (it.is<Attribute>()) {
- staticValues.push_back(it.get<Attribute>().cast<IntegerAttr>().getInt());
+ staticValues.push_back(cast<IntegerAttr>(it.get<Attribute>()).getInt());
} else {
staticValues.push_back(ShapedType::kDynamic);
dynamicValues.push_back(it.get<Value>());
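The hunks above replace member calls such as `t.dyn_cast<ShapedType>()` with free-function calls such as `dyn_cast<ShapedType>(t)`. The following is a minimal, self-contained sketch of how free `isa`/`cast`/`dyn_cast`/`dyn_cast_or_null` can be built on a static `classof` hook over nullable value handles. All names in it (`Kind`, `TypeStorage`, `Type`, `RankedTensorType`, `VectorType`) and the simplified templates are hypothetical stand-ins. They are not the real MLIR classes or `llvm/Support/Casting.h`, which is considerably more general.

```cpp
#include <cassert>

// Hypothetical type kinds; the real MLIR classes do not use an enum like this.
enum class Kind { RankedTensor, UnrankedTensor, Vector };

// Hypothetical storage object that a handle points at.
struct TypeStorage {
  Kind kind;
  int rank;
};

// A nullable, copyable handle, loosely modelled on mlir::Type.
class Type {
public:
  Type() = default;
  explicit Type(const TypeStorage *impl) : impl(impl) {}
  explicit operator bool() const { return impl != nullptr; }
  Kind getKind() const { return impl->kind; }
  const TypeStorage *getImpl() const { return impl; }

protected:
  const TypeStorage *impl = nullptr;
};

// Subclasses add no state; classof() decides which handles they accept.
class RankedTensorType : public Type {
public:
  using Type::Type;
  static bool classof(Type t) { return t.getKind() == Kind::RankedTensor; }
  int getRank() const { return impl->rank; }
};

class VectorType : public Type {
public:
  using Type::Type;
  static bool classof(Type t) { return t.getKind() == Kind::Vector; }
};

// Simplified free functions (hypothetical, not LLVM's implementation).
template <typename To>
bool isa(Type t) {
  assert(t && "isa<> used on a null type");
  return To::classof(t);
}

// Variadic form: true if `t` is any of the listed types.
template <typename First, typename Second, typename... Rest>
bool isa(Type t) {
  return isa<First>(t) || isa<Second, Rest...>(t);
}

// cast<> asserts that the conversion is valid.
template <typename To>
To cast(Type t) {
  assert(isa<To>(t) && "cast<> to an incompatible type");
  return To(t.getImpl());
}

// dyn_cast<> returns a null handle on mismatch (input must be non-null).
template <typename To>
To dyn_cast(Type t) {
  return isa<To>(t) ? To(t.getImpl()) : To();
}

// dyn_cast_or_null<> also accepts a null input and returns a null handle.
template <typename To>
To dyn_cast_or_null(Type t) {
  return (t && isa<To>(t)) ? To(t.getImpl()) : To();
}
```

In this sketch, `dyn_cast` returns a null handle on mismatch. That is what lets a pattern like `if (auto st = dyn_cast<ShapedType>(t))` serve as both a type test and a conversion. `dyn_cast_or_null` additionally tolerates a null input, matching its use on a possibly-null `Attribute` in `getConstantIntValue`.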
diff --git a/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp b/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp
index aed39f8644008..a2977901f4751 100644
--- a/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp
@@ -21,9 +21,9 @@ bool mlir::isRowMajorMatmul(ArrayAttr indexingMaps) {
if (indexingMaps.size() != 3)
return false;
- auto map0 = indexingMaps[0].cast<AffineMapAttr>().getValue();
- auto map1 = indexingMaps[1].cast<AffineMapAttr>().getValue();
- auto map2 = indexingMaps[2].cast<AffineMapAttr>().getValue();
+ auto map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
+ auto map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
+ auto map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
if (map0.getNumResults() != 2 || map1.getNumResults() != 2 ||
map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||
@@ -47,9 +47,9 @@ bool mlir::isColumnMajorMatmul(ArrayAttr indexingMaps) {
if (indexingMaps.size() != 3)
return false;
- auto map0 = indexingMaps[0].cast<AffineMapAttr>().getValue();
- auto map1 = indexingMaps[1].cast<AffineMapAttr>().getValue();
- auto map2 = indexingMaps[2].cast<AffineMapAttr>().getValue();
+ auto map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
+ auto map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
+ auto map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
if (map0.getNumResults() != 2 || map1.getNumResults() != 2 ||
map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||
@@ -73,9 +73,9 @@ bool mlir::isRowMajorBatchMatmul(ArrayAttr indexingMaps) {
if (indexingMaps.size() != 3)
return false;
- auto map0 = indexingMaps[0].cast<AffineMapAttr>().getValue();
- auto map1 = indexingMaps[1].cast<AffineMapAttr>().getValue();
- auto map2 = indexingMaps[2].cast<AffineMapAttr>().getValue();
+ auto map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
+ auto map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
+ auto map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
if (map0.getNumResults() != 3 || map1.getNumResults() != 3 ||
map2.getNumResults() != 3 || map0.getNumInputs() != 4 ||
diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
index a6431043475aa..ad7e367c71ab7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -30,14 +30,14 @@ struct TransferReadOpInterface
vector::TransferReadOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
- assert(opOperand.get().getType().isa<RankedTensorType>() &&
+ assert(isa<RankedTensorType>(opOperand.get().getType()) &&
"only tensor types expected");
return true;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
- assert(opOperand.get().getType().isa<RankedTensorType>() &&
+ assert(isa<RankedTensorType>(opOperand.get().getType()) &&
"only tensor types expected");
return false;
}
@@ -50,7 +50,7 @@ struct TransferReadOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
auto readOp = cast<vector::TransferReadOp>(op);
- assert(readOp.getShapedType().isa<TensorType>() &&
+ assert(isa<TensorType>(readOp.getShapedType()) &&
"only tensor types expected");
FailureOr<Value> buffer = getBuffer(rewriter, readOp.getSource(), options);
if (failed(buffer))
@@ -74,7 +74,7 @@ struct TransferWriteOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
auto writeOp = cast<vector::TransferWriteOp>(op);
- assert(writeOp.getShapedType().isa<TensorType>() &&
+ assert(isa<TensorType>(writeOp.getShapedType()) &&
"only tensor types expected");
// Create a new transfer_write on buffer that doesn't have a return value.
@@ -99,14 +99,14 @@ struct GatherOpInterface
vector::GatherOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
- assert(opOperand.get().getType().isa<RankedTensorType>() &&
+ assert(isa<RankedTensorType>(opOperand.get().getType()) &&
"only tensor types expected");
return true;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
- assert(opOperand.get().getType().isa<RankedTensorType>() &&
+ assert(isa<RankedTensorType>(opOperand.get().getType()) &&
"only tensor types expected");
return false;
}
@@ -119,7 +119,7 @@ struct GatherOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
auto gatherOp = cast<vector::GatherOp>(op);
- assert(gatherOp.getBaseType().isa<TensorType>() &&
+ assert(isa<TensorType>(gatherOp.getBaseType()) &&
"only tensor types expected");
FailureOr<Value> buffer = getBuffer(rewriter, gatherOp.getBase(), options);
if (failed(buffer))
@@ -266,7 +266,7 @@ struct YieldOpInterface
// may get dropped during the bufferization of vector.mask.
SmallVector<Value> newResults;
for (Value value : yieldOp.getOperands()) {
- if (value.getType().isa<TensorType>()) {
+ if (isa<TensorType>(value.getType())) {
FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
if (failed(maybeBuffer))
return failure();
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
index ad538fe4a6828..7c606e0c35f08 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
@@ -49,7 +49,7 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
VectorType dstType = op.getResultVectorType();
- VectorType srcType = op.getSourceType().dyn_cast<VectorType>();
+ VectorType srcType = dyn_cast<VectorType>(op.getSourceType());
Type eltType = dstType.getElementType();
// Scalar to any vector can use splat.
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 16751f82dad27..986c5f81d60c2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -96,9 +96,9 @@ static Value reshapeLoad(Location loc, Value val, VectorType type,
return rewriter.create<vector::ExtractOp>(loc, lowType, val, posAttr);
}
// Unroll leading dimensions.
- VectorType vType = lowType.cast<VectorType>();
+ VectorType vType = cast<VectorType>(lowType);
Type resType = VectorType::Builder(type).dropDim(index);
- auto resVectorType = resType.cast<VectorType>();
+ auto resVectorType = cast<VectorType>(resType);
Value result = rewriter.create<arith::ConstantOp>(
loc, resVectorType, rewriter.getZeroAttr(resVectorType));
for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) {
@@ -126,7 +126,7 @@ static Value reshapeStore(Location loc, Value val, Value result,
}
// Unroll leading dimensions.
Type lowType = VectorType::Builder(type).dropDim(0);
- VectorType vType = lowType.cast<VectorType>();
+ VectorType vType = cast<VectorType>(lowType);
Type insType = VectorType::Builder(vType).dropDim(0);
for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
auto posAttr = rewriter.getI64ArrayAttr(d);
@@ -160,7 +160,7 @@ createContractArithOp(Location loc, Value x, Value y, Value acc,
// Only valid for integer types.
return std::nullopt;
// Special case for fused multiply-add.
- if (acc && acc.getType().isa<VectorType>() && kind == CombiningKind::ADD) {
+ if (acc && isa<VectorType>(acc.getType()) && kind == CombiningKind::ADD) {
Value fma = rewriter.create<vector::FMAOp>(loc, x, y, acc);
if (mask)
// The fma op doesn't need explicit masking. However, fma ops used in
@@ -418,7 +418,7 @@ struct UnrolledOuterProductGenerator
Value promote(Value v, Type dstElementType) {
Type elementType = v.getType();
- auto vecType = elementType.dyn_cast<VectorType>();
+ auto vecType = dyn_cast<VectorType>(elementType);
if (vecType)
elementType = vecType.getElementType();
if (elementType == dstElementType)
@@ -426,7 +426,7 @@ struct UnrolledOuterProductGenerator
Type promotedType = dstElementType;
if (vecType)
promotedType = VectorType::get(vecType.getShape(), promotedType);
- if (dstElementType.isa<FloatType>())
+ if (isa<FloatType>(dstElementType))
return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
}
@@ -438,7 +438,7 @@ struct UnrolledOuterProductGenerator
if (mask && !maybeMask.has_value())
return failure();
- Type resElementType = res.getType().cast<VectorType>().getElementType();
+ Type resElementType = cast<VectorType>(res.getType()).getElementType();
for (int64_t k = 0; k < reductionSize; ++k) {
Value extractA = rewriter.create<vector::ExtractOp>(loc, lhs, k);
Value extractB = rewriter.create<vector::ExtractOp>(loc, rhs, k);
@@ -684,7 +684,7 @@ ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
return failure();
}
- VectorType dstType = op.getResultType().cast<VectorType>();
+ VectorType dstType = cast<VectorType>(op.getResultType());
assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
"Expected dst type of rank 1 or 2");
@@ -695,7 +695,7 @@ ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
// ExtractOp does not allow dynamic indexing, we must unroll explicitly.
Value res = rewriter.create<arith::ConstantOp>(loc, dstType,
rewriter.getZeroAttr(dstType));
- bool isInt = dstType.getElementType().isa<IntegerType>();
+ bool isInt = isa<IntegerType>(dstType.getElementType());
for (unsigned r = 0; r < dstRows; ++r) {
Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
for (unsigned c = 0; c < dstColumns; ++c) {
@@ -789,7 +789,7 @@ struct ContractOpToElementwise
} else {
// If the parallel dimension doesn't exist we will have to broadcast it.
lhsDims.push_back(
- contractOp.getResultType().cast<VectorType>().getDimSize(i));
+ cast<VectorType>(contractOp.getResultType()).getDimSize(i));
lhsTranspose.push_back(lhsDims.size() - 1);
}
std::optional<unsigned> rhsDim =
@@ -799,7 +799,7 @@ struct ContractOpToElementwise
} else {
// If the parallel dimension doesn't exist we will have to broadcast it.
rhsDims.push_back(
- contractOp.getResultType().cast<VectorType>().getDimSize(i));
+ cast<VectorType>(contractOp.getResultType()).getDimSize(i));
rhsTranspose.push_back(rhsDims.size() - 1);
}
}
@@ -969,7 +969,7 @@ FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter,
Value mask) const {
VectorType lhsType = op.getLhsType();
VectorType rhsType = op.getRhsType();
- VectorType resType = op.getResultType().cast<VectorType>();
+ VectorType resType = cast<VectorType>(op.getResultType());
// Find the iterator type index and result index.
SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
int64_t iterIndex = -1;
@@ -1044,10 +1044,10 @@ FailureOr<Value> ContractionOpLowering::lowerReduction(
VectorType lhsType = op.getLhsType();
VectorType rhsType = op.getRhsType();
Type resType = op.getResultType();
- if (resType.isa<VectorType>())
+ if (isa<VectorType>(resType))
return rewriter.notifyMatchFailure(op,
"did not expect a VectorType result");
- bool isInt = resType.isa<IntegerType>();
+ bool isInt = isa<IntegerType>(resType);
// Use iterator index 0.
int64_t iterIndex = 0;
SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
@@ -1133,10 +1133,10 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
auto loc = op.getLoc();
VectorType lhsType = op.getOperandVectorTypeLHS();
- VectorType rhsType = op.getOperandTypeRHS().dyn_cast<VectorType>();
+ VectorType rhsType = dyn_cast<VectorType>(op.getOperandTypeRHS());
VectorType resType = op.getResultVectorType();
Type eltType = resType.getElementType();
- bool isInt = eltType.isa<IntegerType, IndexType>();
+ bool isInt = isa<IntegerType, IndexType>(eltType);
Value acc = (op.getAcc().empty()) ? nullptr : op.getAcc()[0];
vector::CombiningKind kind = op.getKind();
@@ -1231,7 +1231,7 @@ ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
return failure();
Type dstElementType = op.getType();
- if (auto vecType = dstElementType.dyn_cast<VectorType>())
+ if (auto vecType = dyn_cast<VectorType>(dstElementType))
dstElementType = vecType.getElementType();
if (elementType != dstElementType)
return failure();
@@ -1259,8 +1259,8 @@ ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
return failure();
// At this point lhs and rhs are in row-major.
- VectorType lhsType = lhs.getType().cast<VectorType>();
- VectorType rhsType = rhs.getType().cast<VectorType>();
+ VectorType lhsType = cast<VectorType>(lhs.getType());
+ VectorType rhsType = cast<VectorType>(rhs.getType());
int64_t lhsRows = lhsType.getDimSize(0);
int64_t lhsColumns = lhsType.getDimSize(1);
int64_t rhsColumns = rhsType.getDimSize(1);
@@ -1289,7 +1289,7 @@ ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
llvm_unreachable("invalid contraction semantics");
Value res =
- elementType.isa<IntegerType>()
+ isa<IntegerType>(elementType)
? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.getAcc(), mul))
: static_cast<Value>(
rew.create<arith::AddFOp>(loc, op.getAcc(), mul));
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
index 3f26558237a2f..a0ed056fc7a32 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
@@ -52,7 +52,7 @@ class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
LogicalResult matchAndRewrite(vector::CreateMaskOp op,
PatternRewriter &rewriter) const override {
- auto dstType = op.getResult().getType().cast<VectorType>();
+ auto dstType = cast<VectorType>(op.getResult().getType());
int64_t rank = dstType.getRank();
if (rank <= 1)
return rewriter.notifyMatchFailure(
@@ -112,7 +112,7 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
if (rank == 0) {
assert(dimSizes.size() == 1 &&
"Expected exactly one dim size for a 0-D vector");
- bool value = dimSizes[0].cast<IntegerAttr>().getInt() == 1;
+ bool value = cast<IntegerAttr>(dimSizes[0]).getInt() == 1;
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, dstType,
DenseIntElementsAttr::get(
@@ -122,14 +122,14 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
}
// Scalable constant masks can only be lowered for the "none set" case.
- if (dstType.cast<VectorType>().isScalable()) {
+ if (cast<VectorType>(dstType).isScalable()) {
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, DenseElementsAttr::get(dstType, false));
return success();
}
int64_t trueDim = std::min(dstType.getDimSize(0),
- dimSizes[0].cast<IntegerAttr>().getInt());
+ cast<IntegerAttr>(dimSizes[0]).getInt());
if (rank == 1) {
// Express constant 1-D case in explicit vector form:
@@ -146,7 +146,7 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
VectorType::get(dstType.getShape().drop_front(), eltType);
SmallVector<int64_t> newDimSizes;
for (int64_t r = 1; r < rank; r++)
- newDimSizes.push_back(dimSizes[r].cast<IntegerAttr>().getInt());
+ newDimSizes.push_back(cast<IntegerAttr>(dimSizes[r]).getInt());
Value trueVal = rewriter.create<vector::ConstantMaskOp>(
loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
Value result = rewriter.create<arith::ConstantOp>(
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
index eb2deba7bc46b..463aab1ead38f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
@@ -48,7 +48,7 @@ static Value genOperator(Location loc, Value x, Value y,
PatternRewriter &rewriter) {
using vector::CombiningKind;
- auto elType = x.getType().cast<VectorType>().getElementType();
+ auto elType = cast<VectorType>(x.getType()).getElementType();
bool isInt = elType.isIntOrIndex();
Value combinedResult{nullptr};
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index f15d0c85fd195..4f68526ac401e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -29,7 +29,7 @@ inverseTransposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
size_t index = 0;
for (unsigned pos : permutation)
newInBoundsValues[pos] =
- attr.getValue()[index++].cast<BoolAttr>().getValue();
+ cast<BoolAttr>(attr.getValue()[index++]).getValue();
return builder.getBoolArrayAttr(newInBoundsValues);
}
@@ -37,7 +37,7 @@ inverseTransposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
/// dimensions.
static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec,
int64_t addedRank) {
- auto originalVecType = vec.getType().cast<VectorType>();
+ auto originalVecType = cast<VectorType>(vec.getType());
SmallVector<int64_t> newShape(addedRank, 1);
newShape.append(originalVecType.getShape().begin(),
originalVecType.getShape().end());
@@ -257,7 +257,7 @@ struct TransferWriteNonPermutationLowering
// All the new dimensions added are inbound.
SmallVector<bool> newInBoundsValues(missingInnerDim.size(), true);
for (Attribute attr : op.getInBounds().value().getValue()) {
- newInBoundsValues.push_back(attr.cast<BoolAttr>().getValue());
+ newInBoundsValues.push_back(cast<BoolAttr>(attr).getValue());
}
newInBoundsAttr = rewriter.getBoolArrayAttr(newInBoundsValues);
}
@@ -315,7 +315,7 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
// In the meantime, lower these to a scalar load when they pop up.
if (reducedShapeRank == 0) {
Value newRead;
- if (op.getShapedType().isa<TensorType>()) {
+ if (isa<TensorType>(op.getShapedType())) {
newRead = rewriter.create<tensor::ExtractOp>(
op.getLoc(), op.getSource(), op.getIndices());
} else {
@@ -397,7 +397,7 @@ struct TransferReadToVectorLoadLowering
&broadcastedDims))
return rewriter.notifyMatchFailure(read, "not minor identity + bcast");
- auto memRefType = read.getShapedType().dyn_cast<MemRefType>();
+ auto memRefType = dyn_cast<MemRefType>(read.getShapedType());
if (!memRefType)
return rewriter.notifyMatchFailure(read, "not a memref source");
@@ -418,11 +418,11 @@ struct TransferReadToVectorLoadLowering
// `vector.load` supports vector types as memref's elements only when the
// resulting vector type is the same as the element type.
auto memrefElTy = memRefType.getElementType();
- if (memrefElTy.isa<VectorType>() && memrefElTy != unbroadcastedVectorType)
+ if (isa<VectorType>(memrefElTy) && memrefElTy != unbroadcastedVectorType)
return rewriter.notifyMatchFailure(read, "incompatible element type");
// Otherwise, element types of the memref and the vector must match.
- if (!memrefElTy.isa<VectorType>() &&
+ if (!isa<VectorType>(memrefElTy) &&
memrefElTy != read.getVectorType().getElementType())
return rewriter.notifyMatchFailure(read, "non-matching element type");
@@ -543,7 +543,7 @@ struct TransferWriteToVectorStoreLowering
diag << "permutation map is not minor identity: " << write;
});
- auto memRefType = write.getShapedType().dyn_cast<MemRefType>();
+ auto memRefType = dyn_cast<MemRefType>(write.getShapedType());
if (!memRefType)
return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
diag << "not a memref type: " << write;
@@ -558,13 +558,13 @@ struct TransferWriteToVectorStoreLowering
// `vector.store` supports vector types as memref's elements only when the
// type of the vector value being written is the same as the element type.
auto memrefElTy = memRefType.getElementType();
- if (memrefElTy.isa<VectorType>() && memrefElTy != write.getVectorType())
+ if (isa<VectorType>(memrefElTy) && memrefElTy != write.getVectorType())
return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
diag << "elemental type mismatch: " << write;
});
// Otherwise, element types of the memref and the vector must match.
- if (!memrefElTy.isa<VectorType>() &&
+ if (!isa<VectorType>(memrefElTy) &&
memrefElTy != write.getVectorType().getElementType())
return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
diag << "elemental type mismatch: " << write;
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 42c1aa58c5e5e..7d804ddcfa42f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -156,7 +156,7 @@ static Value createUnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2,
/// dst[511:384] := SELECT4(v2[511:0], mask[7:6])
static Value create4x128BitSuffle(ImplicitLocOpBuilder &b, Value v1, Value v2,
uint8_t mask) {
- assert(v1.getType().cast<VectorType>().getShape()[0] == 16 &&
+ assert(cast<VectorType>(v1.getType()).getShape()[0] == 16 &&
"expected a vector with length=16");
SmallVector<int64_t> shuffleMask;
auto appendToMask = [&](int64_t base, uint8_t control) {
@@ -291,7 +291,7 @@ static Value transposeToShuffle16x16(OpBuilder &builder, Value source, int m,
vs[0xf] = create4x128BitSuffle(b, t7, tf, 0xdd);
auto reshInputType = VectorType::get(
- {m, n}, source.getType().cast<VectorType>().getElementType());
+ {m, n}, cast<VectorType>(source.getType()).getElementType());
Value res =
b.create<arith::ConstantOp>(reshInputType, b.getZeroAttr(reshInputType));
for (int64_t i = 0; i < m; ++i)
@@ -329,7 +329,7 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
// Set up convenience transposition table.
SmallVector<int64_t> transp;
for (auto attr : op.getTransp())
- transp.push_back(attr.cast<IntegerAttr>().getInt());
+ transp.push_back(cast<IntegerAttr>(attr).getInt());
if (isShuffleLike(vectorTransformOptions.vectorTransposeLowering) &&
succeeded(isTranspose2DSlice(op)))
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 2b5706aaa7748..e56aa62a1871e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -62,8 +62,8 @@ struct DistributedLoadStoreHelper {
Value laneId, Value zero)
: sequentialVal(sequentialVal), distributedVal(distributedVal),
laneId(laneId), zero(zero) {
- sequentialVectorType = sequentialVal.getType().dyn_cast<VectorType>();
- distributedVectorType = distributedVal.getType().dyn_cast<VectorType>();
+ sequentialVectorType = dyn_cast<VectorType>(sequentialVal.getType());
+ distributedVectorType = dyn_cast<VectorType>(distributedVal.getType());
if (sequentialVectorType && distributedVectorType)
distributionMap =
calculateImplicitMap(sequentialVectorType, distributedVectorType);
@@ -89,7 +89,7 @@ struct DistributedLoadStoreHelper {
"Must store either the preregistered distributed or the "
"preregistered sequential value.");
// Scalar case can directly use memref.store.
- if (!val.getType().isa<VectorType>())
+ if (!isa<VectorType>(val.getType()))
return b.create<memref::StoreOp>(loc, val, buffer, zero);
// Vector case must use vector::TransferWriteOp which will later lower to
@@ -131,7 +131,7 @@ struct DistributedLoadStoreHelper {
Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer) {
// Scalar case can directly use memref.store.
- if (!type.isa<VectorType>())
+ if (!isa<VectorType>(type))
return b.create<memref::LoadOp>(loc, buffer, zero);
// Other cases must be vector atm.
@@ -149,7 +149,7 @@ struct DistributedLoadStoreHelper {
}
SmallVector<bool> inBounds(indices.size(), true);
return b.create<vector::TransferReadOp>(
- loc, type.cast<VectorType>(), buffer, indices,
+ loc, cast<VectorType>(type), buffer, indices,
ArrayRef<bool>(inBounds.begin(), inBounds.end()));
}
@@ -630,14 +630,14 @@ struct WarpOpElementwise : public OpRewritePattern<WarpExecuteOnLane0Op> {
Location loc = warpOp.getLoc();
for (OpOperand &operand : elementWise->getOpOperands()) {
Type targetType;
- if (auto vecType = distributedVal.getType().dyn_cast<VectorType>()) {
+ if (auto vecType = dyn_cast<VectorType>(distributedVal.getType())) {
// If the result type is a vector, the operands must also be vectors.
- auto operandType = operand.get().getType().cast<VectorType>();
+ auto operandType = cast<VectorType>(operand.get().getType());
targetType =
VectorType::get(vecType.getShape(), operandType.getElementType());
} else {
auto operandType = operand.get().getType();
- assert(!operandType.isa<VectorType>() &&
+ assert(!isa<VectorType>(operandType) &&
"unexpected yield of vector from op with scalar result type");
targetType = operandType;
}
@@ -687,7 +687,7 @@ struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> {
if (!yieldOperand)
return failure();
auto constantOp = yieldOperand->get().getDefiningOp<arith::ConstantOp>();
- auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>();
+ auto dense = dyn_cast<SplatElementsAttr>(constantOp.getValue());
if (!dense)
return failure();
unsigned operandIndex = yieldOperand->getOperandNumber();
@@ -737,8 +737,8 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
SmallVector<Value, 4> indices(read.getIndices().begin(),
read.getIndices().end());
- auto sequentialType = read.getResult().getType().cast<VectorType>();
- auto distributedType = distributedVal.getType().cast<VectorType>();
+ auto sequentialType = cast<VectorType>(read.getResult().getType());
+ auto distributedType = cast<VectorType>(distributedVal.getType());
AffineMap map = calculateImplicitMap(sequentialType, distributedType);
AffineMap indexMap = map.compose(read.getPermutationMap());
OpBuilder::InsertionGuard g(rewriter);
@@ -752,7 +752,7 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
unsigned indexPos = indexExpr.getPosition();
unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
int64_t scale =
- distributedVal.getType().cast<VectorType>().getDimSize(vectorPos);
+ cast<VectorType>(distributedVal.getType()).getDimSize(vectorPos);
indices[indexPos] = affine::makeComposedAffineApply(
rewriter, read.getLoc(), d0 + scale * d1,
{indices[indexPos], warpOp.getLaneid()});
@@ -845,7 +845,7 @@ struct WarpOpForwardOperand : public OpRewritePattern<WarpExecuteOnLane0Op> {
resultIndex = operand.getOperandNumber();
break;
}
- auto arg = operand.get().dyn_cast<BlockArgument>();
+ auto arg = dyn_cast<BlockArgument>(operand.get());
if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
continue;
Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
@@ -874,7 +874,7 @@ struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>();
Location loc = broadcastOp.getLoc();
auto destVecType =
- warpOp->getResultTypes()[operandNumber].cast<VectorType>();
+ cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, {broadcastOp.getSource()},
@@ -914,7 +914,7 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
// Rewrite vector.extract with 1d source to vector.extractelement.
if (extractSrcType.getRank() == 1) {
assert(extractOp.getPosition().size() == 1 && "expected 1 index");
- int64_t pos = extractOp.getPosition()[0].cast<IntegerAttr>().getInt();
+ int64_t pos = cast<IntegerAttr>(extractOp.getPosition()[0]).getInt();
rewriter.setInsertionPoint(extractOp);
rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(
extractOp, extractOp.getVector(),
@@ -946,8 +946,8 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
// Find the distributed dimension. There should be exactly one.
auto distributedType =
- warpOp.getResult(operandNumber).getType().cast<VectorType>();
- auto yieldedType = operand->get().getType().cast<VectorType>();
+ cast<VectorType>(warpOp.getResult(operandNumber).getType());
+ auto yieldedType = cast<VectorType>(operand->get().getType());
int64_t distributedDim = -1;
for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
if (distributedType.getDimSize(i) != yieldedType.getDimSize(i)) {
@@ -1083,7 +1083,7 @@ struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
auto insertOp = operand->get().getDefiningOp<vector::InsertElementOp>();
VectorType vecType = insertOp.getDestVectorType();
VectorType distrType =
- warpOp.getResult(operandNumber).getType().cast<VectorType>();
+ cast<VectorType>(warpOp.getResult(operandNumber).getType());
bool hasPos = static_cast<bool>(insertOp.getPosition());
// Yield destination vector, source scalar and position from warp op.
@@ -1171,7 +1171,7 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
// Rewrite vector.insert with 1d dest to vector.insertelement.
if (insertOp.getDestVectorType().getRank() == 1) {
assert(insertOp.getPosition().size() == 1 && "expected 1 index");
- int64_t pos = insertOp.getPosition()[0].cast<IntegerAttr>().getInt();
+ int64_t pos = cast<IntegerAttr>(insertOp.getPosition()[0]).getInt();
rewriter.setInsertionPoint(insertOp);
rewriter.replaceOpWithNewOp<vector::InsertElementOp>(
insertOp, insertOp.getSource(), insertOp.getDest(),
@@ -1199,8 +1199,8 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
// Find the distributed dimension. There should be exactly one.
auto distrDestType =
- warpOp.getResult(operandNumber).getType().cast<VectorType>();
- auto yieldedType = operand->get().getType().cast<VectorType>();
+ cast<VectorType>(warpOp.getResult(operandNumber).getType());
+ auto yieldedType = cast<VectorType>(operand->get().getType());
int64_t distrDestDim = -1;
for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) {
@@ -1213,7 +1213,7 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
assert(distrDestDim != -1 && "could not find distributed dimension");
// Compute the distributed source vector type.
- VectorType srcVecType = insertOp.getSourceType().cast<VectorType>();
+ VectorType srcVecType = cast<VectorType>(insertOp.getSourceType());
SmallVector<int64_t> distrSrcShape(srcVecType.getShape().begin(),
srcVecType.getShape().end());
// E.g.: vector.insert %s, %d [2] : vector<96xf32> into vector<128x96xf32>
@@ -1248,7 +1248,7 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
SmallVector<int64_t> newPos = llvm::to_vector(
llvm::map_range(insertOp.getPosition(), [](Attribute attr) {
- return attr.cast<IntegerAttr>().getInt();
+ return cast<IntegerAttr>(attr).getInt();
}));
// tid of inserting lane: pos / elementsPerLane
Value insertingLane = rewriter.create<arith::ConstantIndexOp>(
@@ -1337,7 +1337,7 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
if (!escapingValues.insert(operand->get()))
return;
Type distType = operand->get().getType();
- if (auto vecType = distType.cast<VectorType>()) {
+ if (auto vecType = cast<VectorType>(distType)) {
AffineMap map = distributionMapFn(operand->get());
distType = getDistributedType(vecType, map, warpOp.getWarpSize());
}
@@ -1359,7 +1359,7 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
for (OpOperand &yieldOperand : yield->getOpOperands()) {
if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
continue;
- auto forResult = yieldOperand.get().cast<OpResult>();
+ auto forResult = cast<OpResult>(yieldOperand.get());
newOperands.push_back(
newWarpOp.getResult(yieldOperand.getOperandNumber()));
yieldOperand.set(forOp.getIterOperands()[forResult.getResultNumber()]);
@@ -1463,7 +1463,7 @@ struct WarpOpReduction : public OpRewritePattern<WarpExecuteOnLane0Op> {
auto reductionOp =
cast<vector::ReductionOp>(yieldOperand->get().getDefiningOp());
- auto vectorType = reductionOp.getVector().getType().cast<VectorType>();
+ auto vectorType = cast<VectorType>(reductionOp.getVector().getType());
// Only rank 1 vectors supported.
if (vectorType.getRank() != 1)
return rewriter.notifyMatchFailure(
@@ -1564,7 +1564,7 @@ void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
// operations from there.
for (auto &op : body->without_terminator()) {
bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
- return result.getType().isa<VectorType>();
+ return isa<VectorType>(result.getType());
});
if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
opsToMove.insert(&op);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 6105e87573c23..8b2444199a501 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -136,10 +136,10 @@ struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
Type oldSrcType = insertOp.getSourceType();
Type newSrcType = oldSrcType;
int64_t oldSrcRank = 0, newSrcRank = 0;
- if (auto type = oldSrcType.dyn_cast<VectorType>()) {
+ if (auto type = dyn_cast<VectorType>(oldSrcType)) {
newSrcType = trimLeadingOneDims(type);
oldSrcRank = type.getRank();
- newSrcRank = newSrcType.cast<VectorType>().getRank();
+ newSrcRank = cast<VectorType>(newSrcType).getRank();
}
VectorType oldDstType = insertOp.getDestVectorType();
@@ -199,7 +199,7 @@ struct CastAwayTransferReadLeadingOneDim
if (read.getMask())
return failure();
- auto shapedType = read.getSource().getType().cast<ShapedType>();
+ auto shapedType = cast<ShapedType>(read.getSource().getType());
if (shapedType.getElementType() != read.getVectorType().getElementType())
return failure();
@@ -247,7 +247,7 @@ struct CastAwayTransferWriteLeadingOneDim
if (write.getMask())
return failure();
- auto shapedType = write.getSource().getType().dyn_cast<ShapedType>();
+ auto shapedType = dyn_cast<ShapedType>(write.getSource().getType());
if (shapedType.getElementType() != write.getVectorType().getElementType())
return failure();
@@ -284,7 +284,7 @@ struct CastAwayTransferWriteLeadingOneDim
LogicalResult
mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
RewriterBase &rewriter) {
- VectorType oldAccType = contractOp.getAccType().dyn_cast<VectorType>();
+ VectorType oldAccType = dyn_cast<VectorType>(contractOp.getAccType());
if (oldAccType == nullptr)
return failure();
if (oldAccType.getRank() < 2)
@@ -418,7 +418,7 @@ class CastAwayElementwiseLeadingOneDim : public RewritePattern {
PatternRewriter &rewriter) const override {
if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
return failure();
- auto vecType = op->getResultTypes()[0].dyn_cast<VectorType>();
+ auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0]);
if (!vecType)
return failure();
VectorType newVecType = trimLeadingOneDims(vecType);
@@ -427,7 +427,7 @@ class CastAwayElementwiseLeadingOneDim : public RewritePattern {
int64_t dropDim = vecType.getRank() - newVecType.getRank();
SmallVector<Value, 4> newOperands;
for (Value operand : op->getOperands()) {
- if (auto opVecType = operand.getType().dyn_cast<VectorType>()) {
+ if (auto opVecType = dyn_cast<VectorType>(operand.getType())) {
newOperands.push_back(rewriter.create<vector::ExtractOp>(
op->getLoc(), operand, splatZero(dropDim)));
} else {
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index 313a3f9a9c090..37216cea7b615 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -21,7 +21,7 @@ using namespace mlir::vector;
// Helper that picks the proper sequence for inserting.
static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
Value into, int64_t offset) {
- auto vectorType = into.getType().cast<VectorType>();
+ auto vectorType = cast<VectorType>(into.getType());
if (vectorType.getRank() > 1)
return rewriter.create<InsertOp>(loc, from, into, offset);
return rewriter.create<vector::InsertElementOp>(
@@ -32,7 +32,7 @@ static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
// Helper that picks the proper sequence for extracting.
static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
int64_t offset) {
- auto vectorType = vector.getType().cast<VectorType>();
+ auto vectorType = cast<VectorType>(vector.getType());
if (vectorType.getRank() > 1)
return rewriter.create<ExtractOp>(loc, vector, offset);
return rewriter.create<vector::ExtractElementOp>(
@@ -134,10 +134,10 @@ class ConvertSameRankInsertStridedSliceIntoShuffle
}
int64_t offset =
- op.getOffsets().getValue().front().cast<IntegerAttr>().getInt();
+ cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
int64_t size = srcType.getShape().front();
int64_t stride =
- op.getStrides().getValue().front().cast<IntegerAttr>().getInt();
+ cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
auto loc = op.getLoc();
Value res = op.getDest();
@@ -174,7 +174,7 @@ class ConvertSameRankInsertStridedSliceIntoShuffle
off += stride, ++idx) {
// 1. extract the proper subvector (or element) from source
Value extractedSource = extractOne(rewriter, loc, op.getSource(), idx);
- if (extractedSource.getType().isa<VectorType>()) {
+ if (isa<VectorType>(extractedSource.getType())) {
// 2. If we have a vector, extract the proper subvector from destination
// Otherwise we are at the element level and no need to recurse.
Value extractedDest = extractOne(rewriter, loc, op.getDest(), off);
@@ -208,11 +208,10 @@ class Convert1DExtractStridedSliceIntoShuffle
assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets");
int64_t offset =
- op.getOffsets().getValue().front().cast<IntegerAttr>().getInt();
- int64_t size =
- op.getSizes().getValue().front().cast<IntegerAttr>().getInt();
+ cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
+ int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt();
int64_t stride =
- op.getStrides().getValue().front().cast<IntegerAttr>().getInt();
+ cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
assert(dstType.getElementType().isSignlessIntOrIndexOrFloat());
@@ -254,11 +253,10 @@ class Convert1DExtractStridedSliceIntoExtractInsertChain final
return failure();
int64_t offset =
- op.getOffsets().getValue().front().cast<IntegerAttr>().getInt();
- int64_t size =
- op.getSizes().getValue().front().cast<IntegerAttr>().getInt();
+ cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
+ int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt();
int64_t stride =
- op.getStrides().getValue().front().cast<IntegerAttr>().getInt();
+ cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
Location loc = op.getLoc();
SmallVector<Value> elements;
@@ -300,11 +298,10 @@ class DecomposeNDExtractStridedSlice
assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets");
int64_t offset =
- op.getOffsets().getValue().front().cast<IntegerAttr>().getInt();
- int64_t size =
- op.getSizes().getValue().front().cast<IntegerAttr>().getInt();
+ cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
+ int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt();
int64_t stride =
- op.getStrides().getValue().front().cast<IntegerAttr>().getInt();
+ cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
auto loc = op.getLoc();
auto elemType = dstType.getElementType();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 3a06d9bdea1f0..68d8c92a94df4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -261,7 +261,7 @@ static MemRefType dropUnitDims(MemRefType inputType, ArrayRef<int64_t> offsets,
llvm::make_filter_range(sizes, [](int64_t sz) { return sz != 1; }));
Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
targetShape, inputType, offsets, sizes, strides);
- return canonicalizeStridedLayout(rankReducedType.cast<MemRefType>());
+ return canonicalizeStridedLayout(cast<MemRefType>(rankReducedType));
}
/// Creates a rank-reducing memref.subview op that drops unit dims from its
@@ -269,7 +269,7 @@ static MemRefType dropUnitDims(MemRefType inputType, ArrayRef<int64_t> offsets,
static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
mlir::Location loc,
Value input) {
- MemRefType inputType = input.getType().cast<MemRefType>();
+ MemRefType inputType = cast<MemRefType>(input.getType());
assert(inputType.hasStaticShape());
SmallVector<int64_t> subViewOffsets(inputType.getRank(), 0);
SmallVector<int64_t> subViewStrides(inputType.getRank(), 1);
@@ -304,9 +304,9 @@ class TransferReadDropUnitDimsPattern
PatternRewriter &rewriter) const override {
auto loc = transferReadOp.getLoc();
Value vector = transferReadOp.getVector();
- VectorType vectorType = vector.getType().cast<VectorType>();
+ VectorType vectorType = cast<VectorType>(vector.getType());
Value source = transferReadOp.getSource();
- MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
+ MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
// TODO: support tensor types.
if (!sourceType || !sourceType.hasStaticShape())
return failure();
@@ -347,9 +347,9 @@ class TransferWriteDropUnitDimsPattern
PatternRewriter &rewriter) const override {
auto loc = transferWriteOp.getLoc();
Value vector = transferWriteOp.getVector();
- VectorType vectorType = vector.getType().cast<VectorType>();
+ VectorType vectorType = cast<VectorType>(vector.getType());
Value source = transferWriteOp.getSource();
- MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
+ MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
// TODO: support tensor type.
if (!sourceType || !sourceType.hasStaticShape())
return failure();
@@ -406,7 +406,7 @@ static int64_t hasMatchingInnerContigousShape(MemRefType memrefType,
/// input starting at `firstDimToCollapse`.
static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
Value input, int64_t firstDimToCollapse) {
- ShapedType inputType = input.getType().cast<ShapedType>();
+ ShapedType inputType = cast<ShapedType>(input.getType());
if (inputType.getRank() == 1)
return input;
SmallVector<ReassociationIndices> reassociation;
@@ -451,9 +451,9 @@ class FlattenContiguousRowMajorTransferReadPattern
PatternRewriter &rewriter) const override {
auto loc = transferReadOp.getLoc();
Value vector = transferReadOp.getVector();
- VectorType vectorType = vector.getType().cast<VectorType>();
+ VectorType vectorType = cast<VectorType>(vector.getType());
Value source = transferReadOp.getSource();
- MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
+ MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
// Contiguity check is valid on tensors only.
if (!sourceType)
return failure();
@@ -481,7 +481,7 @@ class FlattenContiguousRowMajorTransferReadPattern
Value collapsedSource =
collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
MemRefType collapsedSourceType =
- collapsedSource.getType().dyn_cast<MemRefType>();
+ dyn_cast<MemRefType>(collapsedSource.getType());
int64_t collapsedRank = collapsedSourceType.getRank();
assert(collapsedRank == firstContiguousInnerDim + 1);
SmallVector<AffineExpr, 1> dimExprs{
@@ -494,7 +494,7 @@ class FlattenContiguousRowMajorTransferReadPattern
loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
- transferReadOp, vector.getType().cast<VectorType>(), flatRead);
+ transferReadOp, cast<VectorType>(vector.getType()), flatRead);
return success();
}
};
@@ -511,9 +511,9 @@ class FlattenContiguousRowMajorTransferWritePattern
PatternRewriter &rewriter) const override {
auto loc = transferWriteOp.getLoc();
Value vector = transferWriteOp.getVector();
- VectorType vectorType = vector.getType().cast<VectorType>();
+ VectorType vectorType = cast<VectorType>(vector.getType());
Value source = transferWriteOp.getSource();
- MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
+ MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
// Contiguity check is valid on tensors only.
if (!sourceType)
return failure();
@@ -541,7 +541,7 @@ class FlattenContiguousRowMajorTransferWritePattern
Value collapsedSource =
collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
MemRefType collapsedSourceType =
- collapsedSource.getType().cast<MemRefType>();
+ cast<MemRefType>(collapsedSource.getType());
int64_t collapsedRank = collapsedSourceType.getRank();
assert(collapsedRank == firstContiguousInnerDim + 1);
SmallVector<AffineExpr, 1> dimExprs{
@@ -610,7 +610,7 @@ class RewriteScalarExtractElementOfTransferRead
*getConstantIntValue(ofr));
}
}
- if (xferOp.getSource().getType().isa<MemRefType>()) {
+ if (isa<MemRefType>(xferOp.getSource().getType())) {
rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
newIndices);
} else {
@@ -637,7 +637,7 @@ class RewriteScalarExtractOfTransferRead
LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
PatternRewriter &rewriter) const override {
// Only match scalar extracts.
- if (extractOp.getType().isa<VectorType>())
+ if (isa<VectorType>(extractOp.getType()))
return failure();
auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
if (!xferOp)
@@ -660,7 +660,7 @@ class RewriteScalarExtractOfTransferRead
SmallVector<Value> newIndices(xferOp.getIndices().begin(),
xferOp.getIndices().end());
for (const auto &it : llvm::enumerate(extractOp.getPosition())) {
- int64_t offset = it.value().cast<IntegerAttr>().getInt();
+ int64_t offset = cast<IntegerAttr>(it.value()).getInt();
int64_t idx =
newIndices.size() - extractOp.getPosition().size() + it.index();
OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
@@ -673,7 +673,7 @@ class RewriteScalarExtractOfTransferRead
extractOp.getLoc(), *getConstantIntValue(ofr));
}
}
- if (xferOp.getSource().getType().isa<MemRefType>()) {
+ if (isa<MemRefType>(xferOp.getSource().getType())) {
rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
newIndices);
} else {
@@ -714,7 +714,7 @@ class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
xferOp.getVector(), pos);
}
// Construct a scalar store.
- if (xferOp.getSource().getType().isa<MemRefType>()) {
+ if (isa<MemRefType>(xferOp.getSource().getType())) {
rewriter.replaceOpWithNewOp<memref::StoreOp>(
xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
} else {
@@ -732,12 +732,12 @@ void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
// Run store to load forwarding first since it can expose more dead store
// opportunity.
rootOp->walk([&](vector::TransferReadOp read) {
- if (read.getShapedType().isa<MemRefType>())
+ if (isa<MemRefType>(read.getShapedType()))
opt.storeToLoadForwarding(read);
});
opt.removeDeadOp();
rootOp->walk([&](vector::TransferWriteOp write) {
- if (write.getShapedType().isa<MemRefType>())
+ if (isa<MemRefType>(write.getShapedType()))
opt.deadStoreOp(write);
});
opt.removeDeadOp();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
index 34a7ce16ce983..6dacb1e199f3f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
@@ -190,7 +190,7 @@ createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp,
Location loc = xferOp.getLoc();
int64_t memrefRank = xferOp.getShapedType().getRank();
// TODO: relax this precondition, will require rank-reducing subviews.
- assert(memrefRank == alloc.getType().cast<MemRefType>().getRank() &&
+ assert(memrefRank == cast<MemRefType>(alloc.getType()).getRank() &&
"Expected memref rank to match the alloc rank");
ValueRange leadingIndices =
xferOp.indices().take_front(xferOp.getLeadingShapedRank());
@@ -571,8 +571,8 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
}
MemRefType compatibleMemRefType =
- getCastCompatibleMemRefType(xferOp.getShapedType().cast<MemRefType>(),
- alloc.getType().cast<MemRefType>());
+ getCastCompatibleMemRefType(cast<MemRefType>(xferOp.getShapedType()),
+ cast<MemRefType>(alloc.getType()));
if (!compatibleMemRefType)
return failure();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 44f3a10c4da5c..d634d6a19030d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -93,9 +93,9 @@ struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
PatternRewriter &rewriter) const override {
// Check if 'shapeCastOp' has vector source/result type.
auto sourceVectorType =
- shapeCastOp.getSource().getType().dyn_cast_or_null<VectorType>();
+ dyn_cast_or_null<VectorType>(shapeCastOp.getSource().getType());
auto resultVectorType =
- shapeCastOp.getResult().getType().dyn_cast_or_null<VectorType>();
+ dyn_cast_or_null<VectorType>(shapeCastOp.getResult().getType());
if (!sourceVectorType || !resultVectorType)
return failure();
@@ -105,7 +105,7 @@ struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
if (!sourceShapeCastOp)
return failure();
auto operandSourceVectorType =
- sourceShapeCastOp.getSource().getType().cast<VectorType>();
+ cast<VectorType>(sourceShapeCastOp.getSource().getType());
auto operandResultVectorType = sourceShapeCastOp.getType();
// Check if shape cast operations invert each other.
@@ -342,7 +342,7 @@ struct CombineContractBroadcast
if (!broadcast)
continue;
// contractionOp can only take vector as operands.
- auto srcType = broadcast.getSourceType().dyn_cast<VectorType>();
+ auto srcType = dyn_cast<VectorType>(broadcast.getSourceType());
if (!srcType ||
srcType.getRank() == broadcast.getResultVectorType().getRank())
continue;
@@ -455,7 +455,7 @@ struct ReorderCastOpsOnBroadcast
return failure();
Type castResTy = getElementTypeOrSelf(op->getResult(0));
- if (auto vecTy = bcastOp.getSourceType().dyn_cast<VectorType>())
+ if (auto vecTy = dyn_cast<VectorType>(bcastOp.getSourceType()))
castResTy = VectorType::get(vecTy.getShape(), castResTy);
auto *castOp =
rewriter.create(op->getLoc(), op->getName().getIdentifier(),
@@ -530,7 +530,7 @@ struct ReorderElementwiseOpsOnTranspose final
// This is a constant. Create a reverse transpose op for it.
auto vectorType = VectorType::get(
srcType.getShape(),
- operand.getType().cast<VectorType>().getElementType());
+ cast<VectorType>(operand.getType()).getElementType());
srcValues.push_back(rewriter.create<vector::TransposeOp>(
operand.getLoc(), vectorType, operand,
rewriter.getI64ArrayAttr(invOrder)));
@@ -539,7 +539,7 @@ struct ReorderElementwiseOpsOnTranspose final
auto vectorType = VectorType::get(
srcType.getShape(),
- op->getResultTypes()[0].cast<VectorType>().getElementType());
+ cast<VectorType>(op->getResultTypes()[0]).getElementType());
Operation *elementwiseOp =
rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
vectorType, op->getAttrs());
@@ -693,7 +693,7 @@ struct BubbleDownBitCastForStridedSliceExtract
}
SmallVector<int64_t> dims =
- llvm::to_vector<4>(extractOp.getType().cast<VectorType>().getShape());
+ llvm::to_vector<4>(cast<VectorType>(extractOp.getType()).getShape());
dims.back() = dims.back() / expandRatio;
VectorType newExtractType =
VectorType::get(dims, castSrcType.getElementType());
@@ -996,7 +996,7 @@ class VectorCreateMaskOpConversion
LogicalResult matchAndRewrite(vector::CreateMaskOp op,
PatternRewriter &rewriter) const override {
auto dstType = op.getType();
- if (dstType.cast<VectorType>().isScalable())
+ if (cast<VectorType>(dstType).isScalable())
return failure();
int64_t rank = dstType.getRank();
if (rank > 1)
@@ -1026,7 +1026,7 @@ class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
if (readOp.getMask())
return failure();
- auto srcType = readOp.getSource().getType().dyn_cast<MemRefType>();
+ auto srcType = dyn_cast<MemRefType>(readOp.getSource().getType());
if (!srcType || !srcType.hasStaticShape())
return failure();
@@ -1060,13 +1060,13 @@ class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
MemRefType resultMemrefType;
MemRefLayoutAttrInterface layout = srcType.getLayout();
- if (layout.isa<AffineMapAttr>() && layout.isIdentity()) {
+ if (isa<AffineMapAttr>(layout) && layout.isIdentity()) {
resultMemrefType = MemRefType::get(
srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
nullptr, srcType.getMemorySpace());
} else {
MemRefLayoutAttrInterface updatedLayout;
- if (auto strided = layout.dyn_cast<StridedLayoutAttr>()) {
+ if (auto strided = dyn_cast<StridedLayoutAttr>(layout)) {
auto strides =
llvm::to_vector(strided.getStrides().drop_back(dimsToDrop));
updatedLayout = StridedLayoutAttr::get(strided.getContext(),
@@ -1099,7 +1099,7 @@ class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
loc, resultMemrefType, readOp.getSource(), offsets, srcType.getShape(),
strides);
auto permMap = getTransferMinorIdentityMap(
- rankedReducedView.getType().cast<ShapedType>(), resultTargetVecType);
+ cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
Value result = rewriter.create<vector::TransferReadOp>(
loc, resultTargetVecType, rankedReducedView,
readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index f56e7cf256033..5eee318b51b33 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -316,7 +316,7 @@ struct UnrollContractionPattern
auto targetShape = getTargetShape(options, contractOp);
if (!targetShape)
return failure();
- auto dstVecType = contractOp.getResultType().cast<VectorType>();
+ auto dstVecType = cast<VectorType>(contractOp.getResultType());
SmallVector<int64_t> originalSize = *contractOp.getShapeForUnroll();
Location loc = contractOp.getLoc();
@@ -491,7 +491,7 @@ struct UnrollElementwisePattern : public RewritePattern {
auto targetShape = getTargetShape(options, op);
if (!targetShape)
return failure();
- auto dstVecType = op->getResult(0).getType().cast<VectorType>();
+ auto dstVecType = cast<VectorType>(op->getResult(0).getType());
SmallVector<int64_t> originalSize =
*cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
SmallVector<int64_t> ratio = *computeShapeRatio(originalSize, *targetShape);
@@ -512,7 +512,7 @@ struct UnrollElementwisePattern : public RewritePattern {
getVectorOffset(ratioStrides, i, *targetShape);
SmallVector<Value> extractOperands;
for (OpOperand &operand : op->getOpOperands()) {
- auto vecType = operand.get().getType().template dyn_cast<VectorType>();
+ auto vecType = dyn_cast<VectorType>(operand.get().getType());
if (!vecType) {
extractOperands.push_back(operand.get());
continue;
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index e77a13a9c653a..a1451fbf7f31d 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -36,9 +36,9 @@ using namespace mlir;
/// the type of `source`.
Value mlir::vector::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
int64_t dim) {
- if (source.getType().isa<UnrankedMemRefType, MemRefType>())
+ if (isa<UnrankedMemRefType, MemRefType>(source.getType()))
return b.createOrFold<memref::DimOp>(loc, source, dim);
- if (source.getType().isa<UnrankedTensorType, RankedTensorType>())
+ if (isa<UnrankedTensorType, RankedTensorType>(source.getType()))
return b.createOrFold<tensor::DimOp>(loc, source, dim);
llvm_unreachable("Expected MemRefType or TensorType");
}
@@ -89,7 +89,7 @@ mlir::vector::isTranspose2DSlice(vector::TransposeOp op) {
SmallVector<int64_t> transp;
for (auto attr : op.getTransp())
- transp.push_back(attr.cast<IntegerAttr>().getInt());
+ transp.push_back(cast<IntegerAttr>(attr).getInt());
// Check whether the two source vector dimensions that are greater than one
// must be transposed with each other so that we can apply one of the 2-D
@@ -223,7 +223,7 @@ bool matcher::operatesOnSuperVectorsOf(Operation &op,
}
return false;
} else if (op.getNumResults() == 1) {
- if (auto v = op.getResult(0).getType().dyn_cast<VectorType>()) {
+ if (auto v = dyn_cast<VectorType>(op.getResult(0).getType())) {
superVectorType = v;
} else {
// Not a vector type.
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
index e806db7b7cdef..b36f2978d20e3 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
@@ -22,11 +22,11 @@ using namespace mlir::x86vector;
/// Extracts the "main" vector element type from the given X86Vector operation.
template <typename OpTy>
static Type getSrcVectorElementType(OpTy op) {
- return op.getSrc().getType().template cast<VectorType>().getElementType();
+ return cast<VectorType>(op.getSrc().getType()).getElementType();
}
template <>
Type getSrcVectorElementType(Vp2IntersectOp op) {
- return op.getA().getType().template cast<VectorType>().getElementType();
+ return cast<VectorType>(op.getA().getType()).getElementType();
}
namespace {
diff --git a/mlir/lib/ExecutionEngine/JitRunner.cpp b/mlir/lib/ExecutionEngine/JitRunner.cpp
index 18405a92702b1..9464ce849b857 100644
--- a/mlir/lib/ExecutionEngine/JitRunner.cpp
+++ b/mlir/lib/ExecutionEngine/JitRunner.cpp
@@ -288,30 +288,27 @@ template <typename Type>
Error checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction);
template <>
Error checkCompatibleReturnType<int32_t>(LLVM::LLVMFuncOp mainFunction) {
- auto resultType = mainFunction.getFunctionType()
- .cast<LLVM::LLVMFunctionType>()
- .getReturnType()
- .dyn_cast<IntegerType>();
+ auto resultType = dyn_cast<IntegerType>(
+ cast<LLVM::LLVMFunctionType>(mainFunction.getFunctionType())
+ .getReturnType());
if (!resultType || resultType.getWidth() != 32)
return makeStringError("only single i32 function result supported");
return Error::success();
}
template <>
Error checkCompatibleReturnType<int64_t>(LLVM::LLVMFuncOp mainFunction) {
- auto resultType = mainFunction.getFunctionType()
- .cast<LLVM::LLVMFunctionType>()
- .getReturnType()
- .dyn_cast<IntegerType>();
+ auto resultType = dyn_cast<IntegerType>(
+ cast<LLVM::LLVMFunctionType>(mainFunction.getFunctionType())
+ .getReturnType());
if (!resultType || resultType.getWidth() != 64)
return makeStringError("only single i64 function result supported");
return Error::success();
}
template <>
Error checkCompatibleReturnType<float>(LLVM::LLVMFuncOp mainFunction) {
- if (!mainFunction.getFunctionType()
- .cast<LLVM::LLVMFunctionType>()
- .getReturnType()
- .isa<Float32Type>())
+ if (!isa<Float32Type>(
+ cast<LLVM::LLVMFunctionType>(mainFunction.getFunctionType())
+ .getReturnType()))
return makeStringError("only single f32 function result supported");
return Error::success();
}
@@ -324,8 +321,7 @@ Error compileAndExecuteSingleReturnFunction(
if (!mainFunction || mainFunction.isExternal())
return makeStringError("entry point not found");
- if (mainFunction.getFunctionType()
- .cast<LLVM::LLVMFunctionType>()
+ if (cast<LLVM::LLVMFunctionType>(mainFunction.getFunctionType())
.getNumParams() != 0)
return makeStringError("function inputs not supported");
diff --git a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
index cc04fa551d5c0..e335e156e89df 100644
--- a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
+++ b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
@@ -37,7 +37,7 @@ using namespace mlir;
static unsigned getIndexBitwidth(DataLayoutEntryListRef params) {
if (params.empty())
return 64;
- auto attr = params.front().getValue().cast<IntegerAttr>();
+ auto attr = cast<IntegerAttr>(params.front().getValue());
return attr.getValue().getZExtValue();
}
@@ -51,10 +51,10 @@ mlir::detail::getDefaultTypeSize(Type type, const DataLayout &dataLayout,
unsigned mlir::detail::getDefaultTypeSizeInBits(Type type,
const DataLayout &dataLayout,
DataLayoutEntryListRef params) {
- if (type.isa<IntegerType, FloatType>())
+ if (isa<IntegerType, FloatType>(type))
return type.getIntOrFloatBitWidth();
- if (auto ctype = type.dyn_cast<ComplexType>()) {
+ if (auto ctype = dyn_cast<ComplexType>(type)) {
auto et = ctype.getElementType();
auto innerAlignment =
getDefaultPreferredAlignment(et, dataLayout, params) * 8;
@@ -66,7 +66,7 @@ unsigned mlir::detail::getDefaultTypeSizeInBits(Type type,
}
// Index is an integer of some bitwidth.
- if (type.isa<IndexType>())
+ if (isa<IndexType>(type))
return dataLayout.getTypeSizeInBits(
IntegerType::get(type.getContext(), getIndexBitwidth(params)));
@@ -75,12 +75,12 @@ unsigned mlir::detail::getDefaultTypeSizeInBits(Type type,
// there is no bit-packing at the moment element sizes are taken in bytes and
// multiplied with 8 bits.
// TODO: make this extensible.
- if (auto vecType = type.dyn_cast<VectorType>())
+ if (auto vecType = dyn_cast<VectorType>(type))
return vecType.getNumElements() / vecType.getShape().back() *
llvm::PowerOf2Ceil(vecType.getShape().back()) *
dataLayout.getTypeSize(vecType.getElementType()) * 8;
- if (auto typeInterface = type.dyn_cast<DataLayoutTypeInterface>())
+ if (auto typeInterface = dyn_cast<DataLayoutTypeInterface>(type))
return typeInterface.getTypeSizeInBits(dataLayout, params);
reportMissingDataLayout(type);
@@ -104,7 +104,7 @@ findEntryForIntegerType(IntegerType intType,
static unsigned extractABIAlignment(DataLayoutEntryInterface entry) {
auto values =
- entry.getValue().cast<DenseIntElementsAttr>().getValues<int32_t>();
+ cast<DenseIntElementsAttr>(entry.getValue()).getValues<int32_t>();
return *values.begin() / 8u;
}
@@ -134,24 +134,24 @@ unsigned mlir::detail::getDefaultABIAlignment(
Type type, const DataLayout &dataLayout,
ArrayRef<DataLayoutEntryInterface> params) {
// Natural alignment is the closest power-of-two number above.
- if (type.isa<VectorType>())
+ if (isa<VectorType>(type))
return llvm::PowerOf2Ceil(dataLayout.getTypeSize(type));
- if (auto fltType = type.dyn_cast<FloatType>())
+ if (auto fltType = dyn_cast<FloatType>(type))
return getFloatTypeABIAlignment(fltType, dataLayout, params);
// Index is an integer of some bitwidth.
- if (type.isa<IndexType>())
+ if (isa<IndexType>(type))
return dataLayout.getTypeABIAlignment(
IntegerType::get(type.getContext(), getIndexBitwidth(params)));
- if (auto intType = type.dyn_cast<IntegerType>())
+ if (auto intType = dyn_cast<IntegerType>(type))
return getIntegerTypeABIAlignment(intType, params);
- if (auto ctype = type.dyn_cast<ComplexType>())
+ if (auto ctype = dyn_cast<ComplexType>(type))
return getDefaultABIAlignment(ctype.getElementType(), dataLayout, params);
- if (auto typeInterface = type.dyn_cast<DataLayoutTypeInterface>())
+ if (auto typeInterface = dyn_cast<DataLayoutTypeInterface>(type))
return typeInterface.getABIAlignment(dataLayout, params);
reportMissingDataLayout(type);
@@ -159,7 +159,7 @@ unsigned mlir::detail::getDefaultABIAlignment(
static unsigned extractPreferredAlignment(DataLayoutEntryInterface entry) {
auto values =
- entry.getValue().cast<DenseIntElementsAttr>().getValues<int32_t>();
+ cast<DenseIntElementsAttr>(entry.getValue()).getValues<int32_t>();
return *std::next(values.begin(), values.size() - 1) / 8u;
}
@@ -187,27 +187,27 @@ unsigned mlir::detail::getDefaultPreferredAlignment(
Type type, const DataLayout &dataLayout,
ArrayRef<DataLayoutEntryInterface> params) {
// Preferred alignment is same as natural for floats and vectors.
- if (type.isa<VectorType>())
+ if (isa<VectorType>(type))
return dataLayout.getTypeABIAlignment(type);
- if (auto fltType = type.dyn_cast<FloatType>())
+ if (auto fltType = dyn_cast<FloatType>(type))
return getFloatTypePreferredAlignment(fltType, dataLayout, params);
// Preferred alignment is the closest power-of-two number above for integers
// (ABI alignment may be smaller).
- if (auto intType = type.dyn_cast<IntegerType>())
+ if (auto intType = dyn_cast<IntegerType>(type))
return getIntegerTypePreferredAlignment(intType, dataLayout, params);
- if (type.isa<IndexType>()) {
+ if (isa<IndexType>(type)) {
return dataLayout.getTypePreferredAlignment(
IntegerType::get(type.getContext(), getIndexBitwidth(params)));
}
- if (auto ctype = type.dyn_cast<ComplexType>())
+ if (auto ctype = dyn_cast<ComplexType>(type))
return getDefaultPreferredAlignment(ctype.getElementType(), dataLayout,
params);
- if (auto typeInterface = type.dyn_cast<DataLayoutTypeInterface>())
+ if (auto typeInterface = dyn_cast<DataLayoutTypeInterface>(type))
return typeInterface.getPreferredAlignment(dataLayout, params);
reportMissingDataLayout(type);
@@ -232,7 +232,7 @@ mlir::detail::getDefaultStackAlignment(DataLayoutEntryInterface entry) {
if (entry == DataLayoutEntryInterface())
return 0;
- auto value = entry.getValue().cast<IntegerAttr>();
+ auto value = cast<IntegerAttr>(entry.getValue());
return value.getValue().getZExtValue();
}
@@ -543,19 +543,19 @@ LogicalResult mlir::detail::verifyDataLayoutSpec(DataLayoutSpecInterface spec,
for (const auto &kvp : types) {
auto sampleType = kvp.second.front().getKey().get<Type>();
- if (sampleType.isa<IndexType>()) {
+ if (isa<IndexType>(sampleType)) {
assert(kvp.second.size() == 1 &&
"expected one data layout entry for non-parametric 'index' type");
- if (!kvp.second.front().getValue().isa<IntegerAttr>())
+ if (!isa<IntegerAttr>(kvp.second.front().getValue()))
return emitError(loc)
<< "expected integer attribute in the data layout entry for "
<< sampleType;
continue;
}
- if (sampleType.isa<IntegerType, FloatType>()) {
+ if (isa<IntegerType, FloatType>(sampleType)) {
for (DataLayoutEntryInterface entry : kvp.second) {
- auto value = entry.getValue().dyn_cast<DenseIntElementsAttr>();
+ auto value = dyn_cast<DenseIntElementsAttr>(entry.getValue());
if (!value || !value.getElementType().isSignlessInteger(32)) {
emitError(loc) << "expected a dense i32 elements attribute in the "
"data layout entry "
@@ -587,7 +587,7 @@ LogicalResult mlir::detail::verifyDataLayoutSpec(DataLayoutSpecInterface spec,
if (isa<BuiltinDialect>(&sampleType.getDialect()))
return emitError(loc) << "unexpected data layout for a built-in type";
- auto dlType = sampleType.dyn_cast<DataLayoutTypeInterface>();
+ auto dlType = dyn_cast<DataLayoutTypeInterface>(sampleType);
if (!dlType)
return emitError(loc)
<< "data layout specified for a type that does not support it";
diff --git a/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp b/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp
index aff6a8fc1925e..a9bab23f1a72c 100644
--- a/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp
+++ b/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp
@@ -29,9 +29,9 @@ LogicalResult detail::verifyDestinationStyleOpInterface(Operation *op) {
SmallVector<OpOperand *> outputBufferOperands, outputTensorOperands;
for (OpOperand *operand : dstStyleOp.getDpsInitOperands()) {
Type type = operand->get().getType();
- if (type.isa<MemRefType>()) {
+ if (isa<MemRefType>(type)) {
outputBufferOperands.push_back(operand);
- } else if (type.isa<RankedTensorType>()) {
+ } else if (isa<RankedTensorType>(type)) {
outputTensorOperands.push_back(operand);
} else {
return op->emitOpError("expected that operand #")
diff --git a/mlir/lib/Interfaces/InferIntRangeInterface.cpp b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
index 9248b1149ff2f..cc31104ce3335 100644
--- a/mlir/lib/Interfaces/InferIntRangeInterface.cpp
+++ b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
@@ -30,7 +30,7 @@ const APInt &ConstantIntRanges::smax() const { return smaxVal; }
unsigned ConstantIntRanges::getStorageBitwidth(Type type) {
if (type.isIndex())
return IndexType::kInternalStorageBitWidth;
- if (auto integerType = type.dyn_cast<IntegerType>())
+ if (auto integerType = dyn_cast<IntegerType>(type))
return integerType.getWidth();
// Non-integer types have their bounds stored in width 0 `APInt`s.
return 0;
diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
index ebb10e07cebd2..80ed2cc3f22ed 100644
--- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp
+++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
@@ -36,7 +36,7 @@ mlir::reifyResultShapes(OpBuilder &b, Operation *op,
// a correct result.
int64_t resultIdx = 0;
for (OpResult result : op->getResults()) {
- auto shapedType = result.getType().dyn_cast<ShapedType>();
+ auto shapedType = dyn_cast<ShapedType>(result.getType());
if (!shapedType)
continue;
if (!shapedType.hasRank()) {
@@ -69,7 +69,7 @@ bool ShapeAdaptor::hasRank() const {
if (val.isNull())
return false;
if (auto t = val.dyn_cast<Type>())
- return t.cast<ShapedType>().hasRank();
+ return cast<ShapedType>(t).hasRank();
if (val.is<Attribute>())
return true;
return val.get<ShapedTypeComponents *>()->hasRank();
@@ -79,7 +79,7 @@ Type ShapeAdaptor::getElementType() const {
if (val.isNull())
return nullptr;
if (auto t = val.dyn_cast<Type>())
- return t.cast<ShapedType>().getElementType();
+ return cast<ShapedType>(t).getElementType();
if (val.is<Attribute>())
return nullptr;
return val.get<ShapedTypeComponents *>()->getElementType();
@@ -88,10 +88,10 @@ Type ShapeAdaptor::getElementType() const {
void ShapeAdaptor::getDims(SmallVectorImpl<int64_t> &res) const {
assert(hasRank());
if (auto t = val.dyn_cast<Type>()) {
- ArrayRef<int64_t> vals = t.cast<ShapedType>().getShape();
+ ArrayRef<int64_t> vals = cast<ShapedType>(t).getShape();
res.assign(vals.begin(), vals.end());
} else if (auto attr = val.dyn_cast<Attribute>()) {
- auto dattr = attr.cast<DenseIntElementsAttr>();
+ auto dattr = cast<DenseIntElementsAttr>(attr);
res.clear();
res.reserve(dattr.size());
for (auto it : dattr.getValues<APInt>())
@@ -111,9 +111,9 @@ void ShapeAdaptor::getDims(ShapedTypeComponents &res) const {
int64_t ShapeAdaptor::getDimSize(int index) const {
assert(hasRank());
if (auto t = val.dyn_cast<Type>())
- return t.cast<ShapedType>().getDimSize(index);
+ return cast<ShapedType>(t).getDimSize(index);
if (auto attr = val.dyn_cast<Attribute>())
- return attr.cast<DenseIntElementsAttr>()
+ return cast<DenseIntElementsAttr>(attr)
.getValues<APInt>()[index]
.getSExtValue();
auto *stc = val.get<ShapedTypeComponents *>();
@@ -123,9 +123,9 @@ int64_t ShapeAdaptor::getDimSize(int index) const {
int64_t ShapeAdaptor::getRank() const {
assert(hasRank());
if (auto t = val.dyn_cast<Type>())
- return t.cast<ShapedType>().getRank();
+ return cast<ShapedType>(t).getRank();
if (auto attr = val.dyn_cast<Attribute>())
- return attr.cast<DenseIntElementsAttr>().size();
+ return cast<DenseIntElementsAttr>(attr).size();
return val.get<ShapedTypeComponents *>()->getDims().size();
}
@@ -134,9 +134,9 @@ bool ShapeAdaptor::hasStaticShape() const {
return false;
if (auto t = val.dyn_cast<Type>())
- return t.cast<ShapedType>().hasStaticShape();
+ return cast<ShapedType>(t).hasStaticShape();
if (auto attr = val.dyn_cast<Attribute>()) {
- auto dattr = attr.cast<DenseIntElementsAttr>();
+ auto dattr = cast<DenseIntElementsAttr>(attr);
for (auto index : dattr.getValues<APInt>())
if (ShapedType::isDynamic(index.getSExtValue()))
return false;
@@ -150,10 +150,10 @@ int64_t ShapeAdaptor::getNumElements() const {
assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
if (auto t = val.dyn_cast<Type>())
- return t.cast<ShapedType>().getNumElements();
+ return cast<ShapedType>(t).getNumElements();
if (auto attr = val.dyn_cast<Attribute>()) {
- auto dattr = attr.cast<DenseIntElementsAttr>();
+ auto dattr = cast<DenseIntElementsAttr>(attr);
int64_t num = 1;
for (auto index : dattr.getValues<APInt>()) {
num *= index.getZExtValue();
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index 95fb785defdd9..1fbe42cd114b0 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -34,7 +34,7 @@ static std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
}
// Case 2: Check for IntegerAttr.
Attribute attr = ofr.dyn_cast<Attribute>();
- if (auto intAttr = attr.dyn_cast_or_null<IntegerAttr>())
+ if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr))
return intAttr.getValue().getSExtValue();
return std::nullopt;
}
@@ -137,8 +137,8 @@ int64_t ValueBoundsConstraintSet::getPos(Value value,
std::optional<int64_t> dim) const {
#ifndef NDEBUG
assertValidValueDim(value, dim);
- assert((value.isa<OpResult>() ||
- value.cast<BlockArgument>().getOwner()->isEntryBlock()) &&
+ assert((isa<OpResult>(value) ||
+ cast<BlockArgument>(value).getOwner()->isEntryBlock()) &&
"unstructured control flow is not supported");
#endif // NDEBUG
@@ -149,7 +149,7 @@ int64_t ValueBoundsConstraintSet::getPos(Value value,
}
static Operation *getOwnerOfValue(Value value) {
- if (auto bbArg = value.dyn_cast<BlockArgument>())
+ if (auto bbArg = dyn_cast<BlockArgument>(value))
return bbArg.getOwner()->getParentOp();
return value.getDefiningOp();
}
diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp
index eca0297733e7d..c8c442823781b 100644
--- a/mlir/lib/Rewrite/ByteCode.cpp
+++ b/mlir/lib/Rewrite/ByteCode.cpp
@@ -402,7 +402,7 @@ struct ByteCodeWriter {
.Case<pdl::OperationType>(
[](Type) { return PDLValue::Kind::Operation; })
.Case<pdl::RangeType>([](pdl::RangeType rangeTy) {
- if (rangeTy.getElementType().isa<pdl::TypeType>())
+ if (isa<pdl::TypeType>(rangeTy.getElementType()))
return PDLValue::Kind::TypeRange;
return PDLValue::Kind::ValueRange;
})
@@ -538,11 +538,11 @@ void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0;
auto processRewriterValue = [&](Value val) {
valueToMemIndex.try_emplace(val, index++);
- if (pdl::RangeType rangeType = val.getType().dyn_cast<pdl::RangeType>()) {
+ if (pdl::RangeType rangeType = dyn_cast<pdl::RangeType>(val.getType())) {
Type elementTy = rangeType.getElementType();
- if (elementTy.isa<pdl::TypeType>())
+ if (isa<pdl::TypeType>(elementTy))
valueToRangeIndex.try_emplace(val, typeRangeIndex++);
- else if (elementTy.isa<pdl::ValueType>())
+ else if (isa<pdl::ValueType>(elementTy))
valueToRangeIndex.try_emplace(val, valueRangeIndex++);
}
};
@@ -611,13 +611,13 @@ void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
/*dummyValue*/ 0);
// Check to see if this value is a range type.
- if (auto rangeTy = value.getType().dyn_cast<pdl::RangeType>()) {
+ if (auto rangeTy = dyn_cast<pdl::RangeType>(value.getType())) {
Type eleType = rangeTy.getElementType();
- if (eleType.isa<pdl::OperationType>())
+ if (isa<pdl::OperationType>(eleType))
defRangeIt->second.opRangeIndex = 0;
- else if (eleType.isa<pdl::TypeType>())
+ else if (isa<pdl::TypeType>(eleType))
defRangeIt->second.typeRangeIndex = 0;
- else if (eleType.isa<pdl::ValueType>())
+ else if (isa<pdl::ValueType>(eleType))
defRangeIt->second.valueRangeIndex = 0;
}
};
@@ -792,14 +792,14 @@ void Generator::generate(pdl_interp::ApplyRewriteOp op,
#endif
// Range results also need to append the range storage index.
- if (result.getType().isa<pdl::RangeType>())
+ if (isa<pdl::RangeType>(result.getType()))
writer.append(getRangeStorageIndex(result));
writer.append(result);
}
}
void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
Value lhs = op.getLhs();
- if (lhs.getType().isa<pdl::RangeType>()) {
+ if (isa<pdl::RangeType>(lhs.getType())) {
writer.append(OpCode::AreRangesEqual);
writer.appendPDLValueKind(lhs);
writer.append(op.getLhs(), op.getRhs(), op.getSuccessors());
@@ -945,7 +945,7 @@ void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
writer.append(OpCode::GetOperands,
index.value_or(std::numeric_limits<uint32_t>::max()),
op.getInputOp());
- if (result.getType().isa<pdl::RangeType>())
+ if (isa<pdl::RangeType>(result.getType()))
writer.append(getRangeStorageIndex(result));
else
writer.append(std::numeric_limits<ByteCodeField>::max());
@@ -965,7 +965,7 @@ void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
writer.append(OpCode::GetResults,
index.value_or(std::numeric_limits<uint32_t>::max()),
op.getInputOp());
- if (result.getType().isa<pdl::RangeType>())
+ if (isa<pdl::RangeType>(result.getType()))
writer.append(getRangeStorageIndex(result));
else
writer.append(std::numeric_limits<ByteCodeField>::max());
@@ -979,7 +979,7 @@ void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) {
}
void Generator::generate(pdl_interp::GetValueTypeOp op,
ByteCodeWriter &writer) {
- if (op.getType().isa<pdl::RangeType>()) {
+ if (isa<pdl::RangeType>(op.getType())) {
Value result = op.getResult();
writer.append(OpCode::GetValueRangeTypes, result,
getRangeStorageIndex(result), op.getValue());
@@ -1016,7 +1016,7 @@ void Generator::generate(pdl_interp::SwitchOperandCountOp op,
void Generator::generate(pdl_interp::SwitchOperationNameOp op,
ByteCodeWriter &writer) {
auto cases = llvm::map_range(op.getCaseValuesAttr(), [&](Attribute attr) {
- return OperationName(attr.cast<StringAttr>().getValue(), ctx);
+ return OperationName(cast<StringAttr>(attr).getValue(), ctx);
});
writer.append(OpCode::SwitchOperationName, op.getInputOp(), cases,
op.getSuccessors());
@@ -1566,7 +1566,7 @@ void ByteCodeExecutor::executeCheckTypes() {
Attribute rhs = read<Attribute>();
LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
- selectJump(*lhs == rhs.cast<ArrayAttr>().getAsValueRange<TypeAttr>());
+ selectJump(*lhs == cast<ArrayAttr>(rhs).getAsValueRange<TypeAttr>());
}
void ByteCodeExecutor::executeContinue() {
@@ -1581,7 +1581,7 @@ void ByteCodeExecutor::executeCreateConstantTypeRange() {
LLVM_DEBUG(llvm::dbgs() << "Executing CreateConstantTypeRange:\n");
unsigned memIndex = read();
unsigned rangeIndex = read();
- ArrayAttr typesAttr = read<Attribute>().cast<ArrayAttr>();
+ ArrayAttr typesAttr = cast<ArrayAttr>(read<Attribute>());
LLVM_DEBUG(llvm::dbgs() << " * Types: " << typesAttr << "\n\n");
assignRangeToMemory(typesAttr.getAsValueRange<TypeAttr>(), memIndex,
@@ -1743,7 +1743,7 @@ void ByteCodeExecutor::executeGetAttributeType() {
unsigned memIndex = read();
Attribute attr = read<Attribute>();
Type type;
- if (auto typedAttr = attr.dyn_cast<TypedAttr>())
+ if (auto typedAttr = dyn_cast<TypedAttr>(attr))
type = typedAttr.getType();
LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n"
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index b0e278cdddb55..b7eb1f07f2d9b 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -190,7 +190,7 @@ static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
// the FuncOp.
if (emitter.shouldDeclareVariablesAtTop()) {
// Skip the assignment if the emitc.constant has no value.
- if (auto oAttr = value.dyn_cast<emitc::OpaqueAttr>()) {
+ if (auto oAttr = dyn_cast<emitc::OpaqueAttr>(value)) {
if (oAttr.getValue().empty())
return success();
}
@@ -201,7 +201,7 @@ static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
}
// Emit a variable declaration for an emitc.constant op without value.
- if (auto oAttr = value.dyn_cast<emitc::OpaqueAttr>()) {
+ if (auto oAttr = dyn_cast<emitc::OpaqueAttr>(value)) {
if (oAttr.getValue().empty())
// The semicolon gets printed by the emitOperation function.
return emitter.emitVariableDeclaration(result,
@@ -333,7 +333,7 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::CallOp callOp) {
os << callOp.getCallee();
auto emitArgs = [&](Attribute attr) -> LogicalResult {
- if (auto t = attr.dyn_cast<IntegerAttr>()) {
+ if (auto t = dyn_cast<IntegerAttr>(attr)) {
// Index attributes are treated specially as operand index.
if (t.getType().isIndex()) {
int64_t idx = t.getInt();
@@ -759,11 +759,11 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
};
// Print floating point attributes.
- if (auto fAttr = attr.dyn_cast<FloatAttr>()) {
+ if (auto fAttr = dyn_cast<FloatAttr>(attr)) {
printFloat(fAttr.getValue());
return success();
}
- if (auto dense = attr.dyn_cast<DenseFPElementsAttr>()) {
+ if (auto dense = dyn_cast<DenseFPElementsAttr>(attr)) {
os << '{';
interleaveComma(dense, os, [&](const APFloat &val) { printFloat(val); });
os << '}';
@@ -771,21 +771,19 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
}
// Print integer attributes.
- if (auto iAttr = attr.dyn_cast<IntegerAttr>()) {
- if (auto iType = iAttr.getType().dyn_cast<IntegerType>()) {
+ if (auto iAttr = dyn_cast<IntegerAttr>(attr)) {
+ if (auto iType = dyn_cast<IntegerType>(iAttr.getType())) {
printInt(iAttr.getValue(), shouldMapToUnsigned(iType.getSignedness()));
return success();
}
- if (auto iType = iAttr.getType().dyn_cast<IndexType>()) {
+ if (auto iType = dyn_cast<IndexType>(iAttr.getType())) {
printInt(iAttr.getValue(), false);
return success();
}
}
- if (auto dense = attr.dyn_cast<DenseIntElementsAttr>()) {
- if (auto iType = dense.getType()
- .cast<TensorType>()
- .getElementType()
- .dyn_cast<IntegerType>()) {
+ if (auto dense = dyn_cast<DenseIntElementsAttr>(attr)) {
+ if (auto iType = dyn_cast<IntegerType>(
+ cast<TensorType>(dense.getType()).getElementType())) {
os << '{';
interleaveComma(dense, os, [&](const APInt &val) {
printInt(val, shouldMapToUnsigned(iType.getSignedness()));
@@ -793,10 +791,8 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
os << '}';
return success();
}
- if (auto iType = dense.getType()
- .cast<TensorType>()
- .getElementType()
- .dyn_cast<IndexType>()) {
+ if (auto iType = dyn_cast<IndexType>(
+ cast<TensorType>(dense.getType()).getElementType())) {
os << '{';
interleaveComma(dense, os,
[&](const APInt &val) { printInt(val, false); });
@@ -806,13 +802,13 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
}
// Print opaque attributes.
- if (auto oAttr = attr.dyn_cast<emitc::OpaqueAttr>()) {
+ if (auto oAttr = dyn_cast<emitc::OpaqueAttr>(attr)) {
os << oAttr.getValue();
return success();
}
// Print symbolic reference attributes.
- if (auto sAttr = attr.dyn_cast<SymbolRefAttr>()) {
+ if (auto sAttr = dyn_cast<SymbolRefAttr>(attr)) {
if (sAttr.getNestedReferences().size() > 1)
return emitError(loc, "attribute has more than 1 nested reference");
os << sAttr.getRootReference().getValue();
@@ -820,7 +816,7 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
}
// Print type attributes.
- if (auto type = attr.dyn_cast<TypeAttr>())
+ if (auto type = dyn_cast<TypeAttr>(attr))
return emitType(loc, type.getValue());
return emitError(loc, "cannot emit attribute: ") << attr;
@@ -957,7 +953,7 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
}
LogicalResult CppEmitter::emitType(Location loc, Type type) {
- if (auto iType = type.dyn_cast<IntegerType>()) {
+ if (auto iType = dyn_cast<IntegerType>(type)) {
switch (iType.getWidth()) {
case 1:
return (os << "bool"), success();
@@ -973,7 +969,7 @@ LogicalResult CppEmitter::emitType(Location loc, Type type) {
return emitError(loc, "cannot emit integer type ") << type;
}
}
- if (auto fType = type.dyn_cast<FloatType>()) {
+ if (auto fType = dyn_cast<FloatType>(type)) {
switch (fType.getWidth()) {
case 32:
return (os << "float"), success();
@@ -983,9 +979,9 @@ LogicalResult CppEmitter::emitType(Location loc, Type type) {
return emitError(loc, "cannot emit float type ") << type;
}
}
- if (auto iType = type.dyn_cast<IndexType>())
+ if (auto iType = dyn_cast<IndexType>(type))
return (os << "size_t"), success();
- if (auto tType = type.dyn_cast<TensorType>()) {
+ if (auto tType = dyn_cast<TensorType>(type)) {
if (!tType.hasRank())
return emitError(loc, "cannot emit unranked tensor type");
if (!tType.hasStaticShape())
@@ -1001,13 +997,13 @@ LogicalResult CppEmitter::emitType(Location loc, Type type) {
os << ">";
return success();
}
- if (auto tType = type.dyn_cast<TupleType>())
+ if (auto tType = dyn_cast<TupleType>(type))
return emitTupleType(loc, tType.getTypes());
- if (auto oType = type.dyn_cast<emitc::OpaqueType>()) {
+ if (auto oType = dyn_cast<emitc::OpaqueType>(type)) {
os << oType.getValue();
return success();
}
- if (auto pType = type.dyn_cast<emitc::PointerType>()) {
+ if (auto pType = dyn_cast<emitc::PointerType>(type)) {
if (failed(emitType(loc, pType.getPointee())))
return failure();
os << "*";
diff --git a/mlir/lib/Target/LLVMIR/DebugTranslation.cpp b/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
index f409aa413cf7a..87d02f84e4b4c 100644
--- a/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
@@ -21,8 +21,8 @@ using namespace mlir::LLVM::detail;
/// A utility walker that interrupts if the operation has valid debug
/// information.
static WalkResult interruptIfValidLocation(Operation *op) {
- return op->getLoc().isa<UnknownLoc>() ? WalkResult::advance()
- : WalkResult::interrupt();
+ return isa<UnknownLoc>(op->getLoc()) ? WalkResult::advance()
+ : WalkResult::interrupt();
}
DebugTranslation::DebugTranslation(Operation *module, llvm::Module &llvmModule)
@@ -45,7 +45,7 @@ DebugTranslation::DebugTranslation(Operation *module, llvm::Module &llvmModule)
if (auto targetTripleAttr =
module->getAttr(LLVM::LLVMDialect::getTargetTripleAttrName())) {
auto targetTriple =
- llvm::Triple(targetTripleAttr.cast<StringAttr>().getValue());
+ llvm::Triple(cast<StringAttr>(targetTripleAttr).getValue());
if (targetTriple.isKnownWindowsMSVCEnvironment()) {
// Dwarf debugging files will be generated by default, unless "CodeView"
// is set explicitly. Windows/MSVC should use CodeView instead.
@@ -68,8 +68,8 @@ void DebugTranslation::translate(LLVMFuncOp func, llvm::Function &llvmFunc) {
const bool hasCallWithoutDebugInfo =
func.walk([&](LLVM::CallOp call) {
return call.getLoc()->walk([](Location l) {
- return l.isa<UnknownLoc>() ? WalkResult::interrupt()
- : WalkResult::advance();
+ return isa<UnknownLoc>(l) ? WalkResult::interrupt()
+ : WalkResult::advance();
});
})
.wasInterrupted();
@@ -273,7 +273,7 @@ const llvm::DILocation *
DebugTranslation::translateLoc(Location loc, llvm::DILocalScope *scope,
const llvm::DILocation *inlinedAt) {
// LLVM doesn't have a representation for unknown.
- if (!scope || loc.isa<UnknownLoc>())
+ if (!scope || isa<UnknownLoc>(loc))
return nullptr;
// Check for a cached instance.
@@ -282,12 +282,12 @@ DebugTranslation::translateLoc(Location loc, llvm::DILocalScope *scope,
return existingIt->second;
const llvm::DILocation *llvmLoc = nullptr;
- if (auto callLoc = loc.dyn_cast<CallSiteLoc>()) {
+ if (auto callLoc = dyn_cast<CallSiteLoc>(loc)) {
// For callsites, the caller is fed as the inlinedAt for the callee.
const auto *callerLoc = translateLoc(callLoc.getCaller(), scope, inlinedAt);
llvmLoc = translateLoc(callLoc.getCallee(), scope, callerLoc);
- } else if (auto fileLoc = loc.dyn_cast<FileLineColLoc>()) {
+ } else if (auto fileLoc = dyn_cast<FileLineColLoc>(loc)) {
llvm::DILocalScope *locationScope = scope;
// Only construct a new DIFile when no local scope is present. This
// prioritizes existing DI information when it's present.
@@ -300,12 +300,12 @@ DebugTranslation::translateLoc(Location loc, llvm::DILocalScope *scope,
fileLoc.getColumn(), locationScope,
const_cast<llvm::DILocation *>(inlinedAt));
- } else if (auto fusedLoc = loc.dyn_cast<FusedLoc>()) {
+ } else if (auto fusedLoc = dyn_cast<FusedLoc>(loc)) {
ArrayRef<Location> locations = fusedLoc.getLocations();
// Check for a scope encoded with the location.
if (auto scopedAttr =
- fusedLoc.getMetadata().dyn_cast_or_null<LLVM::DILocalScopeAttr>())
+ dyn_cast_or_null<LLVM::DILocalScopeAttr>(fusedLoc.getMetadata()))
scope = translate(scopedAttr);
// For fused locations, merge each of the nodes.
@@ -315,10 +315,10 @@ DebugTranslation::translateLoc(Location loc, llvm::DILocalScope *scope,
llvmLoc, translateLoc(locIt, scope, inlinedAt));
}
- } else if (auto nameLoc = loc.dyn_cast<NameLoc>()) {
+ } else if (auto nameLoc = dyn_cast<NameLoc>(loc)) {
llvmLoc = translateLoc(nameLoc.getChildLoc(), scope, inlinedAt);
- } else if (auto opaqueLoc = loc.dyn_cast<OpaqueLoc>()) {
+ } else if (auto opaqueLoc = dyn_cast<OpaqueLoc>(loc)) {
llvmLoc = translateLoc(opaqueLoc.getFallbackLocation(), scope, inlinedAt);
} else {
llvm_unreachable("unknown location kind");
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index 5c98f922b22e2..c12d7f5166a5d 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -231,9 +231,9 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
Attribute attr = it.value();
if (!attr)
continue;
- DictionaryAttr dAttr = attr.cast<DictionaryAttr>();
+ DictionaryAttr dAttr = cast<DictionaryAttr>(attr);
TypeAttr tAttr =
- dAttr.get(InlineAsmOp::getElementTypeAttrName()).cast<TypeAttr>();
+ cast<TypeAttr>(dAttr.get(InlineAsmOp::getElementTypeAttrName()));
llvm::AttrBuilder b(moduleTranslation.getLLVMContext());
llvm::Type *ty = moduleTranslation.convertType(tAttr.getValue());
b.addTypeAttr(llvm::Attribute::ElementType, ty);
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index d7f1bb6a7be7c..eec84569a21a1 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -162,7 +162,7 @@ class NVVMDialectLLVMIRTranslationInterface
->addOperand(llvmMetadataNode);
};
if (attribute.getName() == NVVM::NVVMDialect::getMaxntidAttrName()) {
- if (!attribute.getValue().dyn_cast<ArrayAttr>())
+ if (!dyn_cast<ArrayAttr>(attribute.getValue()))
return failure();
SmallVector<int64_t> values =
extractFromI64ArrayAttr(attribute.getValue());
@@ -172,7 +172,7 @@ class NVVMDialectLLVMIRTranslationInterface
if (values.size() > 2)
generateMetadata(values[2], NVVM::NVVMDialect::getMaxntidZName());
} else if (attribute.getName() == NVVM::NVVMDialect::getReqntidAttrName()) {
- if (!attribute.getValue().dyn_cast<ArrayAttr>())
+ if (!dyn_cast<ArrayAttr>(attribute.getValue()))
return failure();
SmallVector<int64_t> values =
extractFromI64ArrayAttr(attribute.getValue());
@@ -183,10 +183,10 @@ class NVVMDialectLLVMIRTranslationInterface
generateMetadata(values[2], NVVM::NVVMDialect::getReqntidZName());
} else if (attribute.getName() ==
NVVM::NVVMDialect::getMinctasmAttrName()) {
- auto value = attribute.getValue().dyn_cast<IntegerAttr>();
+ auto value = dyn_cast<IntegerAttr>(attribute.getValue());
generateMetadata(value.getInt(), "minctasm");
} else if (attribute.getName() == NVVM::NVVMDialect::getMaxnregAttrName()) {
- auto value = attribute.getValue().dyn_cast<IntegerAttr>();
+ auto value = dyn_cast<IntegerAttr>(attribute.getValue());
generateMetadata(value.getInt(), "maxnreg");
} else if (attribute.getName() ==
NVVM::NVVMDialect::getKernelFuncAttrName()) {
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
index 91ff17437bbcf..392d34cd6f913 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
@@ -99,7 +99,7 @@ processOperands(llvm::IRBuilderBase &builder,
llvm::Value *dataPtr;
llvm::Value *dataSize;
- if (data.getType().isa<LLVM::LLVMPointerType>()) {
+ if (isa<LLVM::LLVMPointerType>(data.getType())) {
dataPtrBase = dataValue;
dataPtr = dataValue;
dataSize = accBuilder->getSizeInBytes(dataValue);
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 84c39c07816af..750f7157bf53f 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -367,7 +367,7 @@ convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder,
if (criticalOp.getNameAttr()) {
// The verifiers in OpenMP Dialect guarentee that all the pointers are
// non-null
- auto symbolRef = criticalOp.getNameAttr().cast<SymbolRefAttr>();
+ auto symbolRef = cast<SymbolRefAttr>(criticalOp.getNameAttr());
auto criticalDeclareOp =
SymbolTable::lookupNearestSymbolFrom<omp::CriticalDeclareOp>(criticalOp,
symbolRef);
@@ -389,7 +389,7 @@ static omp::ReductionDeclareOp findReductionDecl(omp::WsLoopOp container,
for (unsigned i = 0, e = container.getNumReductionVars(); i < e; ++i) {
if (container.getReductionVars()[i] != reduction.getAccumulator())
continue;
- reductionSymbol = (*container.getReductions())[i].cast<SymbolRefAttr>();
+ reductionSymbol = cast<SymbolRefAttr>((*container.getReductions())[i]);
break;
}
assert(reductionSymbol &&
@@ -705,7 +705,7 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
llvm::zip(taskOp.getDependVars(), taskOp.getDepends()->getValue())) {
llvm::omp::RTLDependenceKindTy type;
switch (
- std::get<1>(dep).cast<mlir::omp::ClauseTaskDependAttr>().getValue()) {
+ cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue()) {
case mlir::omp::ClauseTaskDepend::taskdependin:
type = llvm::omp::RTLDependenceKindTy::DepIn;
break;
@@ -1379,7 +1379,7 @@ static LogicalResult processMapOperand(
llvm::Value *mapOpPtr;
llvm::Value *mapOpSize;
- if (mapOp.getType().isa<LLVM::LLVMPointerType>()) {
+ if (isa<LLVM::LLVMPointerType>(mapOp.getType())) {
mapOpPtrBase = mapOpValue;
mapOpPtr = mapOpValue;
mapOpSize = ompBuilder->getSizeInBytes(mapOpValue);
@@ -1410,7 +1410,7 @@ static LogicalResult processMapOperand(
{builder.getInt32(0), builder.getInt32(index)});
builder.CreateStore(mapOpSize, sizeGEP);
- mapTypeFlags.push_back(mapTypeOp.dyn_cast<mlir::IntegerAttr>().getInt());
+ mapTypeFlags.push_back(dyn_cast<mlir::IntegerAttr>(mapTypeOp).getInt());
llvm::Constant *mapName =
mlir::LLVM::createMappingInformation(mapOp.getLoc(), *ompBuilder);
mapNames.push_back(mapName);
@@ -1445,7 +1445,7 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
if (auto constOp = mlir::dyn_cast<mlir::LLVM::ConstantOp>(
devId.getDefiningOp()))
if (auto intAttr =
- constOp.getValue().dyn_cast<mlir::IntegerAttr>())
+ dyn_cast<mlir::IntegerAttr>(constOp.getValue()))
deviceID = intAttr.getInt();
numMapOperands = dataOp.getMapOperands().size();
@@ -1464,7 +1464,7 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
if (auto constOp = mlir::dyn_cast<mlir::LLVM::ConstantOp>(
devId.getDefiningOp()))
if (auto intAttr =
- constOp.getValue().dyn_cast<mlir::IntegerAttr>())
+ dyn_cast<mlir::IntegerAttr>(constOp.getValue()))
deviceID = intAttr.getInt();
numMapOperands = enterDataOp.getMapOperands().size();
@@ -1483,7 +1483,7 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
if (auto constOp = mlir::dyn_cast<mlir::LLVM::ConstantOp>(
devId.getDefiningOp()))
if (auto intAttr =
- constOp.getValue().dyn_cast<mlir::IntegerAttr>())
+ dyn_cast<mlir::IntegerAttr>(constOp.getValue()))
deviceID = intAttr.getInt();
numMapOperands = exitDataOp.getMapOperands().size();
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMPCommon.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMPCommon.cpp
index fd739cf4d7e96..2145b953b64e6 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMPCommon.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMPCommon.cpp
@@ -16,7 +16,7 @@ llvm::Constant *
mlir::LLVM::createSourceLocStrFromLocation(Location loc,
llvm::OpenMPIRBuilder &builder,
StringRef name, uint32_t &strLen) {
- if (auto fileLoc = loc.dyn_cast<FileLineColLoc>()) {
+ if (auto fileLoc = dyn_cast<FileLineColLoc>(loc)) {
StringRef fileName = fileLoc.getFilename();
unsigned lineNo = fileLoc.getLine();
unsigned colNo = fileLoc.getColumn();
@@ -32,7 +32,7 @@ llvm::Constant *
mlir::LLVM::createMappingInformation(Location loc,
llvm::OpenMPIRBuilder &builder) {
uint32_t strLen;
- if (auto nameLoc = loc.dyn_cast<NameLoc>()) {
+ if (auto nameLoc = dyn_cast<NameLoc>(loc)) {
StringRef name = nameLoc.getName();
return createSourceLocStrFromLocation(nameLoc.getChildLoc(), builder, name,
strLen);
diff --git a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
index 826bac9d73e5a..5ab70280f6c81 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
@@ -109,7 +109,7 @@ class ROCDLDialectLLVMIRTranslationInterface
auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
if (!func)
return failure();
- auto value = attribute.getValue().dyn_cast<IntegerAttr>();
+ auto value = dyn_cast<IntegerAttr>(attribute.getValue());
if (!value)
return failure();
@@ -125,7 +125,7 @@ class ROCDLDialectLLVMIRTranslationInterface
auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
if (!func)
return failure();
- auto value = attribute.getValue().dyn_cast<StringAttr>();
+ auto value = dyn_cast<StringAttr>(attribute.getValue());
if (!value)
return failure();
@@ -142,7 +142,7 @@ class ROCDLDialectLLVMIRTranslationInterface
auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
if (!func)
return failure();
- auto value = attribute.getValue().dyn_cast<DenseI32ArrayAttr>();
+ auto value = dyn_cast<DenseI32ArrayAttr>(attribute.getValue());
if (!value)
return failure();
llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
diff --git a/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp b/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp
index 8e6906a215325..d3c5bb4466392 100644
--- a/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp
@@ -190,7 +190,7 @@ void LoopAnnotationConversion::convertLoopOptions(LoopUnswitchAttr options) {
void LoopAnnotationConversion::convertLocation(FusedLoc location) {
auto localScopeAttr =
- location.getMetadata().dyn_cast_or_null<DILocalScopeAttr>();
+ dyn_cast_or_null<DILocalScopeAttr>(location.getMetadata());
if (!localScopeAttr)
return;
auto *localScope = dyn_cast<llvm::DILocalScope>(
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index ba378930863b0..f4ea8017ac8ec 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -623,7 +623,7 @@ void ModuleImport::setFastmathFlagsAttr(llvm::Instruction *inst,
/// Returns if `type` is a scalar integer or floating-point type.
static bool isScalarType(Type type) {
- return type.isa<IntegerType, FloatType>();
+ return isa<IntegerType, FloatType>(type);
}
/// Returns `type` if it is a builtin integer or floating-point vector type that
@@ -970,7 +970,7 @@ FailureOr<Value> ModuleImport::convertConstant(llvm::Constant *constant) {
// Convert constants that can be represented as attributes.
if (Attribute attr = getConstantAsAttr(constant)) {
Type type = convertType(constant->getType());
- if (auto symbolRef = attr.dyn_cast<FlatSymbolRefAttr>()) {
+ if (auto symbolRef = dyn_cast<FlatSymbolRefAttr>(attr)) {
return builder.create<AddressOfOp>(loc, type, symbolRef.getValue())
.getResult();
}
@@ -1047,7 +1047,7 @@ FailureOr<Value> ModuleImport::convertConstant(llvm::Constant *constant) {
// Generate an UndefOp as root value and insert the aggregate elements.
Type rootType = convertType(constant->getType());
- bool isArrayOrStruct = rootType.isa<LLVMArrayType, LLVMStructType>();
+ bool isArrayOrStruct = isa<LLVMArrayType, LLVMStructType>(rootType);
assert((isArrayOrStruct || LLVM::isCompatibleVectorType(rootType)) &&
"unrecognized aggregate type");
Value root = builder.create<UndefOp>(loc, rootType);
@@ -1609,7 +1609,7 @@ LogicalResult ModuleImport::processFunction(llvm::Function *func) {
clearBlockAndValueMapping();
auto functionType =
- convertType(func->getFunctionType()).dyn_cast<LLVMFunctionType>();
+ dyn_cast<LLVMFunctionType>(convertType(func->getFunctionType()));
if (func->isIntrinsic() &&
iface.isConvertibleIntrinsic(func->getIntrinsicID()))
return success();
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index f8854d75f972b..9d796c55cb14d 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -73,7 +73,7 @@ translateDataLayout(DataLayoutSpecInterface attribute,
if (!key)
continue;
if (key.getValue() == DLTIDialect::kDataLayoutEndiannessKey) {
- auto value = entry.getValue().cast<StringAttr>();
+ auto value = cast<StringAttr>(entry.getValue());
bool isLittleEndian =
value.getValue() == DLTIDialect::kDataLayoutEndiannessLittle;
layoutStream << "-" << (isLittleEndian ? "e" : "E");
@@ -81,7 +81,7 @@ translateDataLayout(DataLayoutSpecInterface attribute,
continue;
}
if (key.getValue() == DLTIDialect::kDataLayoutAllocaMemorySpaceKey) {
- auto value = entry.getValue().cast<IntegerAttr>();
+ auto value = cast<IntegerAttr>(entry.getValue());
uint64_t space = value.getValue().getZExtValue();
// Skip the default address space.
if (space == 0)
@@ -91,7 +91,7 @@ translateDataLayout(DataLayoutSpecInterface attribute,
continue;
}
if (key.getValue() == DLTIDialect::kDataLayoutStackAlignmentKey) {
- auto value = entry.getValue().cast<IntegerAttr>();
+ auto value = cast<IntegerAttr>(entry.getValue());
uint64_t alignment = value.getValue().getZExtValue();
// Skip the default stack alignment.
if (alignment == 0)
@@ -112,14 +112,14 @@ translateDataLayout(DataLayoutSpecInterface attribute,
if (!type)
continue;
// Data layout for the index type is irrelevant at this point.
- if (type.isa<IndexType>())
+ if (isa<IndexType>(type))
continue;
layoutStream << "-";
LogicalResult result =
llvm::TypeSwitch<Type, LogicalResult>(type)
.Case<IntegerType, Float16Type, Float32Type, Float64Type,
Float80Type, Float128Type>([&](Type type) -> LogicalResult {
- if (auto intType = type.dyn_cast<IntegerType>()) {
+ if (auto intType = dyn_cast<IntegerType>(type)) {
if (intType.getSignedness() != IntegerType::Signless)
return emitError(*loc)
<< "unsupported data layout for non-signless integer "
@@ -250,7 +250,7 @@ convertDenseElementsAttr(Location loc, DenseElementsAttr denseElementsAttr,
// Compute the shape of all dimensions but the innermost. Note that the
// innermost dimension may be that of the vector element type.
- bool hasVectorElementType = type.getElementType().isa<VectorType>();
+ bool hasVectorElementType = isa<VectorType>(type.getElementType());
unsigned numAggregates =
denseElementsAttr.getNumElements() /
(hasVectorElementType ? 1
@@ -261,7 +261,7 @@ convertDenseElementsAttr(Location loc, DenseElementsAttr denseElementsAttr,
// Handle the case of vector splat, LLVM has special support for it.
if (denseElementsAttr.isSplat() &&
- (type.isa<VectorType>() || hasVectorElementType)) {
+ (isa<VectorType>(type) || hasVectorElementType)) {
llvm::Constant *splatValue = LLVM::detail::getLLVMConstant(
innermostLLVMType, denseElementsAttr.getSplatValue<Attribute>(), loc,
moduleTranslation);
@@ -277,8 +277,8 @@ convertDenseElementsAttr(Location loc, DenseElementsAttr denseElementsAttr,
// In case of non-splat, create a constructor for the innermost constant from
// a piece of raw data.
std::function<llvm::Constant *(StringRef)> buildCstData;
- if (type.isa<TensorType>()) {
- auto vectorElementType = type.getElementType().dyn_cast<VectorType>();
+ if (isa<TensorType>(type)) {
+ auto vectorElementType = dyn_cast<VectorType>(type.getElementType());
if (vectorElementType && vectorElementType.getRank() == 1) {
buildCstData = [&](StringRef data) {
return llvm::ConstantDataVector::getRaw(
@@ -290,7 +290,7 @@ convertDenseElementsAttr(Location loc, DenseElementsAttr denseElementsAttr,
innermostLLVMType);
};
}
- } else if (type.isa<VectorType>()) {
+ } else if (isa<VectorType>(type)) {
buildCstData = [&](StringRef data) {
return llvm::ConstantDataVector::getRaw(data, type.getShape().back(),
innermostLLVMType);
@@ -326,7 +326,7 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
if (!attr)
return llvm::UndefValue::get(llvmType);
if (auto *structType = dyn_cast<::llvm::StructType>(llvmType)) {
- auto arrayAttr = attr.dyn_cast<ArrayAttr>();
+ auto arrayAttr = dyn_cast<ArrayAttr>(attr);
if (!arrayAttr || arrayAttr.size() != 2) {
emitError(loc, "expected struct type to be a complex number");
return nullptr;
@@ -344,11 +344,11 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
}
// For integer types, we allow a mismatch in sizes as the index type in
// MLIR might have a different size than the index type in the LLVM module.
- if (auto intAttr = attr.dyn_cast<IntegerAttr>())
+ if (auto intAttr = dyn_cast<IntegerAttr>(attr))
return llvm::ConstantInt::get(
llvmType,
intAttr.getValue().sextOrTrunc(llvmType->getIntegerBitWidth()));
- if (auto floatAttr = attr.dyn_cast<FloatAttr>()) {
+ if (auto floatAttr = dyn_cast<FloatAttr>(attr)) {
if (llvmType !=
llvm::Type::getFloatingPointTy(llvmType->getContext(),
floatAttr.getValue().getSemantics())) {
@@ -357,10 +357,10 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
}
return llvm::ConstantFP::get(llvmType, floatAttr.getValue());
}
- if (auto funcAttr = attr.dyn_cast<FlatSymbolRefAttr>())
+ if (auto funcAttr = dyn_cast<FlatSymbolRefAttr>(attr))
return llvm::ConstantExpr::getBitCast(
moduleTranslation.lookupFunction(funcAttr.getValue()), llvmType);
- if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
+ if (auto splatAttr = dyn_cast<SplatElementsAttr>(attr)) {
llvm::Type *elementType;
uint64_t numElements;
bool isScalable = false;
@@ -401,13 +401,13 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
// Try using raw elements data if possible.
if (llvm::Constant *result =
- convertDenseElementsAttr(loc, attr.dyn_cast<DenseElementsAttr>(),
+ convertDenseElementsAttr(loc, dyn_cast<DenseElementsAttr>(attr),
llvmType, moduleTranslation)) {
return result;
}
// Fall back to element-by-element construction otherwise.
- if (auto elementsAttr = attr.dyn_cast<ElementsAttr>()) {
+ if (auto elementsAttr = dyn_cast<ElementsAttr>(attr)) {
assert(elementsAttr.getShapedType().hasStaticShape());
assert(!elementsAttr.getShapedType().getShape().empty() &&
"unexpected empty elements attribute shape");
@@ -428,7 +428,7 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
return result;
}
- if (auto stringAttr = attr.dyn_cast<StringAttr>()) {
+ if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
return llvm::ConstantDataArray::get(
moduleTranslation.getLLVMContext(),
ArrayRef<char>{stringAttr.getValue().data(),
@@ -685,7 +685,7 @@ LogicalResult ModuleTranslation::convertGlobals() {
if (op.getValueOrNull()) {
// String attributes are treated separately because they cannot appear as
// in-function constants and are thus not supported by getLLVMConstant.
- if (auto strAttr = op.getValueOrNull().dyn_cast_or_null<StringAttr>()) {
+ if (auto strAttr = dyn_cast_or_null<StringAttr>(op.getValueOrNull())) {
cst = llvm::ConstantDataArray::getString(
llvmModule->getContext(), strAttr.getValue(), /*AddNull=*/false);
type = cst->getType();
@@ -763,11 +763,10 @@ LogicalResult ModuleTranslation::convertGlobals() {
ctorOp ? llvm::appendToGlobalCtors : llvm::appendToGlobalDtors;
for (auto symbolAndPriority : range) {
llvm::Function *f = lookupFunction(
- std::get<0>(symbolAndPriority).cast<FlatSymbolRefAttr>().getValue());
- appendGlobalFn(
- *llvmModule, f,
- std::get<1>(symbolAndPriority).cast<IntegerAttr>().getInt(),
- /*Data=*/nullptr);
+ cast<FlatSymbolRefAttr>(std::get<0>(symbolAndPriority)).getValue());
+ appendGlobalFn(*llvmModule, f,
+ cast<IntegerAttr>(std::get<1>(symbolAndPriority)).getInt(),
+ /*Data=*/nullptr);
}
}
@@ -830,20 +829,20 @@ forwardPassthroughAttributes(Location loc, std::optional<ArrayAttr> attributes,
return success();
for (Attribute attr : *attributes) {
- if (auto stringAttr = attr.dyn_cast<StringAttr>()) {
+ if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
if (failed(
checkedAddLLVMFnAttribute(loc, llvmFunc, stringAttr.getValue())))
return failure();
continue;
}
- auto arrayAttr = attr.dyn_cast<ArrayAttr>();
+ auto arrayAttr = dyn_cast<ArrayAttr>(attr);
if (!arrayAttr || arrayAttr.size() != 2)
return emitError(loc)
<< "expected 'passthrough' to contain string or array attributes";
- auto keyAttr = arrayAttr[0].dyn_cast<StringAttr>();
- auto valueAttr = arrayAttr[1].dyn_cast<StringAttr>();
+ auto keyAttr = dyn_cast<StringAttr>(arrayAttr[0]);
+ auto valueAttr = dyn_cast<StringAttr>(arrayAttr[1]);
if (!keyAttr || !valueAttr)
return emitError(loc)
<< "expected arrays within 'passthrough' to contain two strings";
@@ -985,7 +984,7 @@ LogicalResult ModuleTranslation::convertFunctionSignatures() {
// Convert result attributes.
if (ArrayAttr allResultAttrs = function.getAllResultAttrs()) {
- DictionaryAttr resultAttrs = allResultAttrs[0].cast<DictionaryAttr>();
+ DictionaryAttr resultAttrs = cast<DictionaryAttr>(allResultAttrs[0]);
llvmFunc->addRetAttrs(convertParameterAttrs(resultAttrs));
}
@@ -1133,7 +1132,7 @@ void ModuleTranslation::setTBAAMetadata(AliasAnalysisOpInterface op,
return;
}
- SymbolRefAttr tagRef = tagRefs[0].cast<SymbolRefAttr>();
+ SymbolRefAttr tagRef = cast<SymbolRefAttr>(tagRefs[0]);
llvm::MDNode *node = getTBAANode(op, tagRef);
inst->setMetadata(llvm::LLVMContext::MD_tbaa, node);
}
@@ -1192,7 +1191,7 @@ LogicalResult ModuleTranslation::createTBAAMetadata() {
// The type references are in 1, 3, 5, etc. positions.
unsigned opNum = 1;
for (Attribute typeAttr : tdOp.getMembers()) {
- refNames.push_back(typeAttr.cast<FlatSymbolRefAttr>().getValue());
+ refNames.push_back(cast<FlatSymbolRefAttr>(typeAttr).getValue());
operandIndices.push_back(opNum);
opNum += 2;
}
@@ -1299,7 +1298,7 @@ prepareLLVMModule(Operation *m, llvm::LLVMContext &llvmContext,
auto llvmModule = std::make_unique<llvm::Module>(name, llvmContext);
if (auto dataLayoutAttr =
m->getAttr(LLVM::LLVMDialect::getDataLayoutAttrName())) {
- llvmModule->setDataLayout(dataLayoutAttr.cast<StringAttr>().getValue());
+ llvmModule->setDataLayout(cast<StringAttr>(dataLayoutAttr).getValue());
} else {
FailureOr<llvm::DataLayout> llvmDataLayout(llvm::DataLayout(""));
if (auto iface = dyn_cast<DataLayoutOpInterface>(m)) {
@@ -1319,7 +1318,7 @@ prepareLLVMModule(Operation *m, llvm::LLVMContext &llvmContext,
}
if (auto targetTripleAttr =
m->getAttr(LLVM::LLVMDialect::getTargetTripleAttrName()))
- llvmModule->setTargetTriple(targetTripleAttr.cast<StringAttr>().getValue());
+ llvmModule->setTargetTriple(cast<StringAttr>(targetTripleAttr).getValue());
// Inject declarations for `malloc` and `free` functions that can be used in
// memref allocation/deallocation coming from standard ops lowering.
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 4c3713fa9a75b..1724808f0ba8e 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -364,11 +364,11 @@ spirv::Deserializer::processFunction(ArrayRef<uint32_t> operands) {
}
Type fnType = getType(operands[3]);
- if (!fnType || !fnType.isa<FunctionType>()) {
+ if (!fnType || !isa<FunctionType>(fnType)) {
return emitError(unknownLoc, "unknown function type from <id> ")
<< operands[3];
}
- auto functionType = fnType.cast<FunctionType>();
+ auto functionType = cast<FunctionType>(fnType);
if ((isVoidType(resultType) && functionType.getNumResults() != 0) ||
(functionType.getNumResults() == 1 &&
@@ -562,7 +562,7 @@ spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
return emitError(unknownLoc, "unknown result type <id> : ")
<< operands[wordIndex];
}
- auto ptrType = type.dyn_cast<spirv::PointerType>();
+ auto ptrType = dyn_cast<spirv::PointerType>(type);
if (!ptrType) {
return emitError(unknownLoc,
"expected a result type <id> to be a spirv.ptr, found : ")
@@ -623,7 +623,7 @@ IntegerAttr spirv::Deserializer::getConstantInt(uint32_t id) {
if (!constInfo) {
return nullptr;
}
- return constInfo->first.dyn_cast<IntegerAttr>();
+ return dyn_cast<IntegerAttr>(constInfo->first);
}
LogicalResult spirv::Deserializer::processName(ArrayRef<uint32_t> operands) {
@@ -825,7 +825,7 @@ spirv::Deserializer::processArrayType(ArrayRef<uint32_t> operands) {
<< operands[2] << "can only come from normal constant right now";
}
- if (auto intVal = countInfo->first.dyn_cast<IntegerAttr>()) {
+ if (auto intVal = dyn_cast<IntegerAttr>(countInfo->first)) {
count = intVal.getValue().getZExtValue();
} else {
return emitError(unknownLoc, "OpTypeArray count must come from a "
@@ -1172,7 +1172,7 @@ LogicalResult spirv::Deserializer::processConstant(ArrayRef<uint32_t> operands,
auto resultID = operands[1];
- if (auto intType = resultType.dyn_cast<IntegerType>()) {
+ if (auto intType = dyn_cast<IntegerType>(resultType)) {
auto bitwidth = intType.getWidth();
if (failed(checkOperandSizeForBitwidth(bitwidth))) {
return failure();
@@ -1205,7 +1205,7 @@ LogicalResult spirv::Deserializer::processConstant(ArrayRef<uint32_t> operands,
return success();
}
- if (auto floatType = resultType.dyn_cast<FloatType>()) {
+ if (auto floatType = dyn_cast<FloatType>(resultType)) {
auto bitwidth = floatType.getWidth();
if (failed(checkOperandSizeForBitwidth(bitwidth))) {
return failure();
@@ -1295,12 +1295,12 @@ spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
}
auto resultID = operands[1];
- if (auto vectorType = resultType.dyn_cast<VectorType>()) {
+ if (auto vectorType = dyn_cast<VectorType>(resultType)) {
auto attr = DenseElementsAttr::get(vectorType, elements);
// For normal constants, we just record the attribute (and its type) for
// later materialization at use sites.
constantMap.try_emplace(resultID, attr, resultType);
- } else if (auto arrayType = resultType.dyn_cast<spirv::ArrayType>()) {
+ } else if (auto arrayType = dyn_cast<spirv::ArrayType>(resultType)) {
auto attr = opBuilder.getArrayAttr(elements);
constantMap.try_emplace(resultID, attr, resultType);
} else {
@@ -1444,7 +1444,7 @@ spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
}
auto resultID = operands[1];
- if (resultType.isIntOrFloat() || resultType.isa<VectorType>()) {
+ if (resultType.isIntOrFloat() || isa<VectorType>(resultType)) {
auto attr = opBuilder.getZeroAttr(resultType);
// For normal constants, we just record the attribute (and its type) for
// later materialization at use sites.
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index 487b66769390e..613e4f6738df6 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -244,7 +244,7 @@ class Deserializer {
Type getUndefType(uint32_t id) { return undefMap.lookup(id); }
/// Returns true if the given `type` is for SPIR-V void type.
- bool isVoidType(Type type) const { return type.isa<NoneType>(); }
+ bool isVoidType(Type type) const { return isa<NoneType>(type); }
/// Processes a SPIR-V type instruction with given `opcode` and `operands` and
/// registers the type into `module`.
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index d863ab475176b..f3e8a4b84e892 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -98,7 +98,7 @@ Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) {
auto constituents = op.getConstituents();
for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
- auto constituent = constituents[index].dyn_cast<FlatSymbolRefAttr>();
+ auto constituent = dyn_cast<FlatSymbolRefAttr>(constituents[index]);
auto constituentName = constituent.getValue();
auto constituentID = getSpecConstID(constituentName);
@@ -280,7 +280,7 @@ LogicalResult Serializer::processVariableOp(spirv::VariableOp op) {
auto attr = op->getAttr(spirv::attributeName<spirv::StorageClass>());
if (attr) {
operands.push_back(
- static_cast<uint32_t>(attr.cast<spirv::StorageClassAttr>().getValue()));
+ static_cast<uint32_t>(cast<spirv::StorageClassAttr>(attr).getValue()));
}
elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
for (auto arg : op.getODSOperands(0)) {
@@ -491,7 +491,7 @@ LogicalResult Serializer::processBranchConditionalOp(
if (auto weights = condBranchOp.getBranchWeights()) {
for (auto val : weights->getValue())
- arguments.push_back(val.cast<IntegerAttr>().getInt());
+ arguments.push_back(cast<IntegerAttr>(val).getInt());
}
if (failed(emitDebugLine(functionBody, condBranchOp.getLoc())))
@@ -554,7 +554,7 @@ Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) {
// Add the interface values.
if (auto interface = op.getInterface()) {
for (auto var : interface.getValue()) {
- auto id = getVariableID(var.cast<FlatSymbolRefAttr>().getValue());
+ auto id = getVariableID(cast<FlatSymbolRefAttr>(var).getValue());
if (!id) {
return op.emitError(
"referencing undefined global variable."
@@ -617,7 +617,7 @@ Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
operands.push_back(valueID);
}
- if (!resultTy.isa<NoneType>())
+ if (!isa<NoneType>(resultTy))
valueIDMap[op.getResult(0)] = funcCallID;
encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionCall, operands);
@@ -638,28 +638,28 @@ Serializer::processOp<spirv::CopyMemoryOp>(spirv::CopyMemoryOp op) {
if (auto attr = op->getAttr("memory_access")) {
operands.push_back(
- static_cast<uint32_t>(attr.cast<spirv::MemoryAccessAttr>().getValue()));
+ static_cast<uint32_t>(cast<spirv::MemoryAccessAttr>(attr).getValue()));
}
elidedAttrs.push_back("memory_access");
if (auto attr = op->getAttr("alignment")) {
operands.push_back(static_cast<uint32_t>(
- attr.cast<IntegerAttr>().getValue().getZExtValue()));
+ cast<IntegerAttr>(attr).getValue().getZExtValue()));
}
elidedAttrs.push_back("alignment");
if (auto attr = op->getAttr("source_memory_access")) {
operands.push_back(
- static_cast<uint32_t>(attr.cast<spirv::MemoryAccessAttr>().getValue()));
+ static_cast<uint32_t>(cast<spirv::MemoryAccessAttr>(attr).getValue()));
}
elidedAttrs.push_back("source_memory_access");
if (auto attr = op->getAttr("source_alignment")) {
operands.push_back(static_cast<uint32_t>(
- attr.cast<IntegerAttr>().getValue().getZExtValue()));
+ cast<IntegerAttr>(attr).getValue().getZExtValue()));
}
elidedAttrs.push_back("source_alignment");
@@ -689,7 +689,7 @@ LogicalResult Serializer::processOp<spirv::GenericCastToPtrExplicitOp>(
for (Value operand : op->getOperands())
operands.push_back(getValueID(operand));
spirv::StorageClass resultStorage =
- resultTy.cast<spirv::PointerType>().getStorageClass();
+ cast<spirv::PointerType>(resultTy).getStorageClass();
operands.push_back(static_cast<uint32_t>(resultStorage));
encodeInstructionInto(functionBody, spirv::Opcode::OpGenericCastToPtrExplicit,
operands);
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 292ed97cfc066..b1c5dfd5e6bc9 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -144,7 +144,7 @@ void Serializer::printValueIDMap(raw_ostream &os) {
<< "id = " << valueIDPair.second << ' ';
if (auto *op = val.getDefiningOp()) {
os << "from op '" << op->getName() << "'";
- } else if (auto arg = val.dyn_cast<BlockArgument>()) {
+ } else if (auto arg = dyn_cast<BlockArgument>(val)) {
Block *block = arg.getOwner();
os << "from argument of block " << block << ' ';
os << " in op '" << block->getParentOp()->getName() << "'";
@@ -176,7 +176,7 @@ void Serializer::processCapability() {
void Serializer::processDebugInfo() {
if (!options.emitDebugInfo)
return;
- auto fileLoc = module.getLoc().dyn_cast<FileLineColLoc>();
+ auto fileLoc = dyn_cast<FileLineColLoc>(module.getLoc());
auto fileName = fileLoc ? fileLoc.getFilename().strref() : "<unknown>";
fileID = getNextID();
SmallVector<uint32_t, 16> operands;
@@ -221,13 +221,13 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
case spirv::Decoration::Binding:
case spirv::Decoration::DescriptorSet:
case spirv::Decoration::Location:
- if (auto intAttr = attr.getValue().dyn_cast<IntegerAttr>()) {
+ if (auto intAttr = dyn_cast<IntegerAttr>(attr.getValue())) {
args.push_back(intAttr.getValue().getZExtValue());
break;
}
return emitError(loc, "expected integer attribute for ") << attrName;
case spirv::Decoration::BuiltIn:
- if (auto strAttr = attr.getValue().dyn_cast<StringAttr>()) {
+ if (auto strAttr = dyn_cast<StringAttr>(attr.getValue())) {
auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue());
if (enumVal) {
args.push_back(static_cast<uint32_t>(*enumVal));
@@ -245,7 +245,7 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
case spirv::Decoration::Restrict:
case spirv::Decoration::RelaxedPrecision:
// For unit attributes, the args list has no values so we do nothing
- if (auto unitAttr = attr.getValue().dyn_cast<UnitAttr>())
+ if (auto unitAttr = dyn_cast<UnitAttr>(attr.getValue()))
break;
return emitError(loc, "expected unit attribute for ") << attrName;
default:
@@ -307,13 +307,13 @@ LogicalResult Serializer::processMemberDecoration(
// "Composite objects in the StorageBuffer, PhysicalStorageBuffer, Uniform, and
// PushConstant Storage Classes must be explicitly laid out."
bool Serializer::isInterfaceStructPtrType(Type type) const {
- if (auto ptrType = type.dyn_cast<spirv::PointerType>()) {
+ if (auto ptrType = dyn_cast<spirv::PointerType>(type)) {
switch (ptrType.getStorageClass()) {
case spirv::StorageClass::PhysicalStorageBuffer:
case spirv::StorageClass::PushConstant:
case spirv::StorageClass::StorageBuffer:
case spirv::StorageClass::Uniform:
- return ptrType.getPointeeType().isa<spirv::StructType>();
+ return isa<spirv::StructType>(ptrType.getPointeeType());
default:
break;
}
@@ -343,8 +343,8 @@ Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID,
auto typeEnum = spirv::Opcode::OpTypeVoid;
bool deferSerialization = false;
- if ((type.isa<FunctionType>() &&
- succeeded(prepareFunctionType(loc, type.cast<FunctionType>(), typeEnum,
+ if ((isa<FunctionType>(type) &&
+ succeeded(prepareFunctionType(loc, cast<FunctionType>(type), typeEnum,
operands))) ||
succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands,
deferSerialization, serializationCtx))) {
@@ -390,7 +390,7 @@ LogicalResult Serializer::prepareBasicType(
return success();
}
- if (auto intType = type.dyn_cast<IntegerType>()) {
+ if (auto intType = dyn_cast<IntegerType>(type)) {
if (intType.getWidth() == 1) {
typeEnum = spirv::Opcode::OpTypeBool;
return success();
@@ -406,13 +406,13 @@ LogicalResult Serializer::prepareBasicType(
return success();
}
- if (auto floatType = type.dyn_cast<FloatType>()) {
+ if (auto floatType = dyn_cast<FloatType>(type)) {
typeEnum = spirv::Opcode::OpTypeFloat;
operands.push_back(floatType.getWidth());
return success();
}
- if (auto vectorType = type.dyn_cast<VectorType>()) {
+ if (auto vectorType = dyn_cast<VectorType>(type)) {
uint32_t elementTypeID = 0;
if (failed(processTypeImpl(loc, vectorType.getElementType(), elementTypeID,
serializationCtx))) {
@@ -424,7 +424,7 @@ LogicalResult Serializer::prepareBasicType(
return success();
}
- if (auto imageType = type.dyn_cast<spirv::ImageType>()) {
+ if (auto imageType = dyn_cast<spirv::ImageType>(type)) {
typeEnum = spirv::Opcode::OpTypeImage;
uint32_t sampledTypeID = 0;
if (failed(processType(loc, imageType.getElementType(), sampledTypeID)))
@@ -440,7 +440,7 @@ LogicalResult Serializer::prepareBasicType(
return success();
}
- if (auto arrayType = type.dyn_cast<spirv::ArrayType>()) {
+ if (auto arrayType = dyn_cast<spirv::ArrayType>(type)) {
typeEnum = spirv::Opcode::OpTypeArray;
uint32_t elementTypeID = 0;
if (failed(processTypeImpl(loc, arrayType.getElementType(), elementTypeID,
@@ -455,10 +455,10 @@ LogicalResult Serializer::prepareBasicType(
return processTypeDecoration(loc, arrayType, resultID);
}
- if (auto ptrType = type.dyn_cast<spirv::PointerType>()) {
+ if (auto ptrType = dyn_cast<spirv::PointerType>(type)) {
uint32_t pointeeTypeID = 0;
spirv::StructType pointeeStruct =
- ptrType.getPointeeType().dyn_cast<spirv::StructType>();
+ dyn_cast<spirv::StructType>(ptrType.getPointeeType());
if (pointeeStruct && pointeeStruct.isIdentified() &&
serializationCtx.count(pointeeStruct.getIdentifier()) != 0) {
@@ -510,7 +510,7 @@ LogicalResult Serializer::prepareBasicType(
return success();
}
- if (auto runtimeArrayType = type.dyn_cast<spirv::RuntimeArrayType>()) {
+ if (auto runtimeArrayType = dyn_cast<spirv::RuntimeArrayType>(type)) {
uint32_t elementTypeID = 0;
if (failed(processTypeImpl(loc, runtimeArrayType.getElementType(),
elementTypeID, serializationCtx))) {
@@ -521,7 +521,7 @@ LogicalResult Serializer::prepareBasicType(
return processTypeDecoration(loc, runtimeArrayType, resultID);
}
- if (auto sampledImageType = type.dyn_cast<spirv::SampledImageType>()) {
+ if (auto sampledImageType = dyn_cast<spirv::SampledImageType>(type)) {
typeEnum = spirv::Opcode::OpTypeSampledImage;
uint32_t imageTypeID = 0;
if (failed(
@@ -532,7 +532,7 @@ LogicalResult Serializer::prepareBasicType(
return success();
}
- if (auto structType = type.dyn_cast<spirv::StructType>()) {
+ if (auto structType = dyn_cast<spirv::StructType>(type)) {
if (structType.isIdentified()) {
if (failed(processName(resultID, structType.getIdentifier())))
return failure();
@@ -581,7 +581,7 @@ LogicalResult Serializer::prepareBasicType(
}
if (auto cooperativeMatrixType =
- type.dyn_cast<spirv::CooperativeMatrixNVType>()) {
+ dyn_cast<spirv::CooperativeMatrixNVType>(type)) {
uint32_t elementTypeID = 0;
if (failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(),
elementTypeID, serializationCtx))) {
@@ -600,7 +600,7 @@ LogicalResult Serializer::prepareBasicType(
return success();
}
- if (auto jointMatrixType = type.dyn_cast<spirv::JointMatrixINTELType>()) {
+ if (auto jointMatrixType = dyn_cast<spirv::JointMatrixINTELType>(type)) {
uint32_t elementTypeID = 0;
if (failed(processTypeImpl(loc, jointMatrixType.getElementType(),
elementTypeID, serializationCtx))) {
@@ -621,7 +621,7 @@ LogicalResult Serializer::prepareBasicType(
return success();
}
- if (auto matrixType = type.dyn_cast<spirv::MatrixType>()) {
+ if (auto matrixType = dyn_cast<spirv::MatrixType>(type)) {
uint32_t elementTypeID = 0;
if (failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID,
serializationCtx))) {
@@ -684,12 +684,12 @@ uint32_t Serializer::prepareConstant(Location loc, Type constType,
}
uint32_t resultID = 0;
- if (auto attr = valueAttr.dyn_cast<DenseElementsAttr>()) {
- int rank = attr.getType().dyn_cast<ShapedType>().getRank();
+ if (auto attr = dyn_cast<DenseElementsAttr>(valueAttr)) {
+ int rank = dyn_cast<ShapedType>(attr.getType()).getRank();
SmallVector<uint64_t, 4> index(rank);
resultID = prepareDenseElementsConstant(loc, constType, attr,
/*dim=*/0, index);
- } else if (auto arrayAttr = valueAttr.dyn_cast<ArrayAttr>()) {
+ } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
resultID = prepareArrayConstant(loc, constType, arrayAttr);
}
@@ -712,7 +712,7 @@ uint32_t Serializer::prepareArrayConstant(Location loc, Type constType,
uint32_t resultID = getNextID();
SmallVector<uint32_t, 4> operands = {typeID, resultID};
operands.reserve(attr.size() + 2);
- auto elementType = constType.cast<spirv::ArrayType>().getElementType();
+ auto elementType = cast<spirv::ArrayType>(constType).getElementType();
for (Attribute elementAttr : attr) {
if (auto elementID = prepareConstant(loc, elementType, elementAttr)) {
operands.push_back(elementID);
@@ -732,16 +732,16 @@ uint32_t
Serializer::prepareDenseElementsConstant(Location loc, Type constType,
DenseElementsAttr valueAttr, int dim,
MutableArrayRef<uint64_t> index) {
- auto shapedType = valueAttr.getType().dyn_cast<ShapedType>();
+ auto shapedType = dyn_cast<ShapedType>(valueAttr.getType());
assert(dim <= shapedType.getRank());
if (shapedType.getRank() == dim) {
- if (auto attr = valueAttr.dyn_cast<DenseIntElementsAttr>()) {
+ if (auto attr = dyn_cast<DenseIntElementsAttr>(valueAttr)) {
return attr.getType().getElementType().isInteger(1)
? prepareConstantBool(loc, attr.getValues<BoolAttr>()[index])
: prepareConstantInt(loc,
attr.getValues<IntegerAttr>()[index]);
}
- if (auto attr = valueAttr.dyn_cast<DenseFPElementsAttr>()) {
+ if (auto attr = dyn_cast<DenseFPElementsAttr>(valueAttr)) {
return prepareConstantFp(loc, attr.getValues<FloatAttr>()[index]);
}
return 0;
@@ -755,7 +755,7 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType,
uint32_t resultID = getNextID();
SmallVector<uint32_t, 4> operands = {typeID, resultID};
operands.reserve(shapedType.getDimSize(dim) + 2);
- auto elementType = constType.cast<spirv::CompositeType>().getElementType(0);
+ auto elementType = cast<spirv::CompositeType>(constType).getElementType(0);
for (int i = 0; i < shapedType.getDimSize(dim); ++i) {
index[dim] = i;
if (auto elementID = prepareDenseElementsConstant(
@@ -773,13 +773,13 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType,
uint32_t Serializer::prepareConstantScalar(Location loc, Attribute valueAttr,
bool isSpec) {
- if (auto floatAttr = valueAttr.dyn_cast<FloatAttr>()) {
+ if (auto floatAttr = dyn_cast<FloatAttr>(valueAttr)) {
return prepareConstantFp(loc, floatAttr, isSpec);
}
- if (auto boolAttr = valueAttr.dyn_cast<BoolAttr>()) {
+ if (auto boolAttr = dyn_cast<BoolAttr>(valueAttr)) {
return prepareConstantBool(loc, boolAttr, isSpec);
}
- if (auto intAttr = valueAttr.dyn_cast<IntegerAttr>()) {
+ if (auto intAttr = dyn_cast<IntegerAttr>(valueAttr)) {
return prepareConstantInt(loc, intAttr, isSpec);
}
@@ -797,8 +797,7 @@ uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr,
// Process the type for this bool literal
uint32_t typeID = 0;
- if (failed(
- processType(loc, boolAttr.cast<IntegerAttr>().getType(), typeID))) {
+ if (failed(processType(loc, cast<IntegerAttr>(boolAttr).getType(), typeID))) {
return 0;
}
@@ -1246,7 +1245,7 @@ LogicalResult Serializer::emitDebugLine(SmallVectorImpl<uint32_t> &binary,
return success();
}
- auto fileLoc = loc.dyn_cast<FileLineColLoc>();
+ auto fileLoc = dyn_cast<FileLineColLoc>(loc);
if (fileLoc)
encodeInstructionInto(binary, spirv::Opcode::OpLine,
{fileID, fileLoc.getLine(), fileLoc.getColumn()});
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.h b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
index ab9b901fa2691..4b2ebf610bd72 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.h
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
@@ -156,7 +156,7 @@ class Serializer {
Type getVoidType() { return mlirBuilder.getNoneType(); }
- bool isVoidType(Type type) const { return type.isa<NoneType>(); }
+ bool isVoidType(Type type) const { return isa<NoneType>(type); }
/// Returns true if the given type is a pointer type to a struct in some
/// interface storage class.
diff --git a/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp b/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp
index 8225680aac91e..00b68162a1ed2 100644
--- a/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp
+++ b/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp
@@ -239,7 +239,7 @@ void CodeGen::genImpl(const ast::ReplaceStmt *stmt) {
// replacement values.
bool usesReplOperation =
replValues.size() == 1 &&
- replValues.front().getType().isa<pdl::OperationType>();
+ isa<pdl::OperationType>(replValues.front().getType());
builder.create<pdl::ReplaceOp>(
loc, rootExpr, usesReplOperation ? replValues[0] : Value(),
usesReplOperation ? ValueRange() : ValueRange(replValues));
@@ -441,7 +441,7 @@ Value CodeGen::genExprImpl(const ast::MemberAccessExpr *expr) {
if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) {
if (isa<ast::AllResultsMemberAccessExpr>(expr)) {
Type mlirType = genType(expr->getType());
- if (mlirType.isa<pdl::ValueType>())
+ if (isa<pdl::ValueType>(mlirType))
return builder.create<pdl::ResultOp>(loc, mlirType, parentExprs[0],
builder.getI32IntegerAttr(0));
return builder.create<pdl::ResultsOp>(loc, mlirType, parentExprs[0]);
diff --git a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
index 5e4dc07c4c3f1..7278aba438b96 100644
--- a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
+++ b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
@@ -58,7 +58,7 @@ getLocationFromLoc(llvm::SourceMgr &sourceMgr, Location loc,
StringRef uriScheme, const lsp::URIForFile *uri = nullptr) {
std::optional<lsp::Location> location;
loc->walk([&](Location nestedLoc) {
- FileLineColLoc fileLoc = nestedLoc.dyn_cast<FileLineColLoc>();
+ FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(nestedLoc);
if (!fileLoc)
return WalkResult::advance();
@@ -91,7 +91,7 @@ static void collectLocationsFromLoc(Location loc,
const lsp::URIForFile &uri) {
SetVector<Location> visitedLocs;
loc->walk([&](Location nestedLoc) {
- FileLineColLoc fileLoc = nestedLoc.dyn_cast<FileLineColLoc>();
+ FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(nestedLoc);
if (!fileLoc || !visitedLocs.insert(nestedLoc))
return WalkResult::advance();
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
index e98ccccd68d2d..2884295c215af 100644
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -136,7 +136,7 @@ void CSE::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
// If the existing operation has an unknown location and the current
// operation doesn't, then set the existing op's location to that of the
// current op.
- if (existing->getLoc().isa<UnknownLoc>() && !op->getLoc().isa<UnknownLoc>())
+ if (isa<UnknownLoc>(existing->getLoc()) && !isa<UnknownLoc>(op->getLoc()))
existing->setLoc(op->getLoc());
++numCSE;
diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp
index 5d6eaad62eed4..57ccb3b2057c1 100644
--- a/mlir/lib/Transforms/Inliner.cpp
+++ b/mlir/lib/Transforms/Inliner.cpp
@@ -345,7 +345,7 @@ static void collectCallOps(iterator_range<Region::iterator> blocks,
// TODO: Support inlining nested call references.
CallInterfaceCallable callable = call.getCallableForCallee();
if (SymbolRefAttr symRef = dyn_cast<SymbolRefAttr>(callable)) {
- if (!symRef.isa<FlatSymbolRefAttr>())
+ if (!isa<FlatSymbolRefAttr>(symRef))
continue;
}
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index a4bf97ca5ffdb..45d6f7d0c1ed8 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -99,7 +99,7 @@ MemorySlotPromoter::MemorySlotPromoter(
info(std::move(info)) {
#ifndef NDEBUG
auto isResultOrNewBlockArgument = [&]() {
- if (BlockArgument arg = slot.ptr.dyn_cast<BlockArgument>())
+ if (BlockArgument arg = dyn_cast<BlockArgument>(slot.ptr))
return arg.getOwner()->getParentOp() == allocator;
return slot.ptr.getDefiningOp() == allocator;
};
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index f812d4d72ca98..615c8e4a99ceb 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -401,7 +401,7 @@ static Value buildUnresolvedTargetMaterialization(
SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) {
Block *insertBlock = input.getParentBlock();
Block::iterator insertPt = insertBlock->begin();
- if (OpResult inputRes = input.dyn_cast<OpResult>())
+ if (OpResult inputRes = dyn_cast<OpResult>(input))
insertPt = ++inputRes.getOwner()->getIterator();
return buildUnresolvedMaterialization(
@@ -1033,7 +1033,7 @@ void ConversionPatternRewriterImpl::applyRewrites() {
if (!repl)
continue;
- if (repl.isa<BlockArgument>()) {
+ if (isa<BlockArgument>(repl)) {
arg.replaceAllUsesWith(repl);
continue;
}
@@ -1041,7 +1041,7 @@ void ConversionPatternRewriterImpl::applyRewrites() {
// If the replacement value is an operation, we check to make sure that we
// don't replace uses that are within the parent operation of the
// replacement value.
- Operation *replOp = repl.cast<OpResult>().getOwner();
+ Operation *replOp = cast<OpResult>(repl).getOwner();
Block *replBlock = replOp->getBlock();
arg.replaceUsesWithIf(repl, [&](OpOperand &operand) {
Operation *user = operand.getOwner();
@@ -2615,7 +2615,7 @@ static void computeNecessaryMaterializations(
}
// Check to see if this is an argument materialization.
- auto isBlockArg = [](Value v) { return v.isa<BlockArgument>(); };
+ auto isBlockArg = [](Value v) { return isa<BlockArgument>(v); };
if (llvm::any_of(op->getOperands(), isBlockArg) ||
llvm::any_of(inverseMapping[op->getResult(0)], isBlockArg)) {
mat->setKind(UnresolvedMaterialization::Argument);
diff --git a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
index 13e7fa8df2972..44908510205f7 100644
--- a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
+++ b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
@@ -384,7 +384,7 @@ applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
assert(castKind == getCastKindName(CastKind::Argument) &&
"unexpected value of cast kind attribute");
assert(llvm::all_of(operands,
- [&](Value v) { return v.isa<BlockArgument>(); }));
+ [&](Value v) { return isa<BlockArgument>(v); }));
maybeResult = typeConverter.materializeArgumentConversion(
rewriter, castOp->getLoc(), resultTypes.front(),
castOp.getOperands());
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 2933758db62b9..b95af9ca02996 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -244,17 +244,17 @@ class LiveMap {
bool wasProvenLive(Value value) {
// TODO: For results that are removable, e.g. for region based control flow,
// we could allow for these values to be tracked independently.
- if (OpResult result = value.dyn_cast<OpResult>())
+ if (OpResult result = dyn_cast<OpResult>(value))
return wasProvenLive(result.getOwner());
- return wasProvenLive(value.cast<BlockArgument>());
+ return wasProvenLive(cast<BlockArgument>(value));
}
bool wasProvenLive(BlockArgument arg) { return liveValues.count(arg); }
void setProvedLive(Value value) {
// TODO: For results that are removable, e.g. for region based control flow,
// we could allow for these values to be tracked independently.
- if (OpResult result = value.dyn_cast<OpResult>())
+ if (OpResult result = dyn_cast<OpResult>(value))
return setProvedLive(result.getOwner());
- setProvedLive(value.cast<BlockArgument>());
+ setProvedLive(cast<BlockArgument>(value));
}
void setProvedLive(BlockArgument arg) {
changed |= liveValues.insert(arg).second;
@@ -538,11 +538,11 @@ unsigned BlockEquivalenceData::getOrderOf(Value value) const {
assert(value.getParentBlock() == block && "expected value of this block");
// Arguments use the argument number as the order index.
- if (BlockArgument arg = value.dyn_cast<BlockArgument>())
+ if (BlockArgument arg = dyn_cast<BlockArgument>(value))
return arg.getArgNumber();
// Otherwise, the result order is offset from the parent op's order.
- OpResult result = value.cast<OpResult>();
+ OpResult result = cast<OpResult>(value);
auto opOrderIt = opOrderIndex.find(result.getDefiningOp());
assert(opOrderIt != opOrderIndex.end() && "expected op to have an order");
return opOrderIt->second + result.getResultNumber();
diff --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp
index 4598b56a901d5..def8a1443b1aa 100644
--- a/mlir/lib/Transforms/ViewOpGraph.cpp
+++ b/mlir/lib/Transforms/ViewOpGraph.cpp
@@ -145,13 +145,13 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
int64_t largeAttrLimit = getLargeAttributeSizeLimit();
// Always emit splat attributes.
- if (attr.isa<SplatElementsAttr>()) {
+ if (isa<SplatElementsAttr>(attr)) {
attr.print(os);
return;
}
// Elide "big" elements attributes.
- auto elements = attr.dyn_cast<ElementsAttr>();
+ auto elements = dyn_cast<ElementsAttr>(attr);
if (elements && elements.getNumElements() > largeAttrLimit) {
os << std::string(elements.getShapedType().getRank(), '[') << "..."
<< std::string(elements.getShapedType().getRank(), ']') << " : "
@@ -159,7 +159,7 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
return;
}
- auto array = attr.dyn_cast<ArrayAttr>();
+ auto array = dyn_cast<ArrayAttr>(attr);
if (array && static_cast<int64_t>(array.size()) > largeAttrLimit) {
os << "[...]";
return;
diff --git a/mlir/test/lib/Analysis/TestAliasAnalysis.cpp b/mlir/test/lib/Analysis/TestAliasAnalysis.cpp
index b563be4a57d40..95c16a6cc5893 100644
--- a/mlir/test/lib/Analysis/TestAliasAnalysis.cpp
+++ b/mlir/test/lib/Analysis/TestAliasAnalysis.cpp
@@ -24,7 +24,7 @@ static void printAliasOperand(Operation *op) {
llvm::errs() << op->getAttrOfType<StringAttr>("test.ptr").getValue();
}
static void printAliasOperand(Value value) {
- if (BlockArgument arg = value.dyn_cast<BlockArgument>()) {
+ if (BlockArgument arg = dyn_cast<BlockArgument>(value)) {
Region *region = arg.getParentRegion();
unsigned parentBlockNumber =
std::distance(region->begin(), arg.getOwner()->getIterator());
@@ -37,7 +37,7 @@ static void printAliasOperand(Value value) {
llvm::errs() << "#" << arg.getArgNumber();
return;
}
- OpResult result = value.cast<OpResult>();
+ OpResult result = cast<OpResult>(value);
printAliasOperand(result.getOwner());
llvm::errs() << "#" << result.getResultNumber();
}
@@ -156,7 +156,7 @@ struct TestAliasAnalysisModRefPass
/// Check if value is function argument.
static bool isFuncArg(Value val) {
- auto blockArg = val.dyn_cast<BlockArgument>();
+ auto blockArg = dyn_cast<BlockArgument>(val);
if (!blockArg)
return false;
@@ -166,7 +166,7 @@ static bool isFuncArg(Value val) {
/// Check if value has "restrict" attribute. Value must be a function argument.
static bool isRestrict(Value val) {
- auto blockArg = val.cast<BlockArgument>();
+ auto blockArg = cast<BlockArgument>(val);
auto func =
mlir::cast<FunctionOpInterface>(blockArg.getOwner()->getParentOp());
return !!func.getArgAttr(blockArg.getArgNumber(),
diff --git a/mlir/test/lib/Analysis/TestMemRefStrideCalculation.cpp b/mlir/test/lib/Analysis/TestMemRefStrideCalculation.cpp
index c9e72f844a1fa..968e10b8d0cab 100644
--- a/mlir/test/lib/Analysis/TestMemRefStrideCalculation.cpp
+++ b/mlir/test/lib/Analysis/TestMemRefStrideCalculation.cpp
@@ -32,7 +32,7 @@ struct TestMemRefStrideCalculation
void TestMemRefStrideCalculation::runOnOperation() {
llvm::outs() << "Testing: " << getOperation().getName() << "\n";
getOperation().walk([&](memref::AllocOp allocOp) {
- auto memrefType = allocOp.getResult().getType().cast<MemRefType>();
+ auto memrefType = cast<MemRefType>(allocOp.getResult().getType());
int64_t offset;
SmallVector<int64_t, 4> strides;
if (failed(getStridesAndOffset(memrefType, strides, offset))) {
diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
index c3f20989dbd6b..e1ccc1b900df5 100644
--- a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
+++ b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
@@ -102,7 +102,7 @@ class ConvertGetTupleElementOp
matchAndRewrite(::test::GetTupleElementOp op, OpAdaptor adaptor,
OneToNPatternRewriter &rewriter) const override {
// Construct mapping for tuple element types.
- auto stateType = op->getOperand(0).getType().cast<TupleType>();
+ auto stateType = cast<TupleType>(op->getOperand(0).getType());
TypeRange originalElementTypes = stateType.getTypes();
OneToNTypeMapping elementMapping(originalElementTypes);
if (failed(typeConverter->convertSignatureArgs(originalElementTypes,
@@ -148,7 +148,7 @@ static void populateDecomposeTuplesTestPatterns(TypeConverter &typeConverter,
static std::optional<SmallVector<Value>>
buildGetTupleElementOps(OpBuilder &builder, TypeRange resultTypes, Value input,
Location loc) {
- TupleType inputType = input.getType().dyn_cast<TupleType>();
+ TupleType inputType = dyn_cast<TupleType>(input.getType());
if (!inputType)
return {};
@@ -156,7 +156,7 @@ buildGetTupleElementOps(OpBuilder &builder, TypeRange resultTypes, Value input,
for (auto [idx, elementType] : llvm::enumerate(inputType.getTypes())) {
Value element = builder.create<::test::GetTupleElementOp>(
loc, elementType, input, builder.getI32IntegerAttr(idx));
- if (auto nestedTupleType = elementType.dyn_cast<TupleType>()) {
+ if (auto nestedTupleType = dyn_cast<TupleType>(elementType)) {
// Recurse if the current element is also a tuple.
SmallVector<Type> flatRecursiveTypes;
nestedTupleType.getFlattenedTypes(flatRecursiveTypes);
@@ -186,7 +186,7 @@ static std::optional<Value> buildMakeTupleOp(OpBuilder &builder,
elements.reserve(resultType.getTypes().size());
ValueRange::iterator inputIt = inputs.begin();
for (Type elementType : resultType.getTypes()) {
- if (auto nestedTupleType = elementType.dyn_cast<TupleType>()) {
+ if (auto nestedTupleType = dyn_cast<TupleType>(elementType)) {
// Determine how many input values are needed for the nested elements of
// the nested TupleType and advance inputIt by that number.
// TODO: We only need the *number* of nested types, not the types itself.
diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
index 1bf3ce4ceb329..dff619efda28d 100644
--- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
@@ -81,7 +81,7 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
return WalkResult::skip();
}
Value value = op->getOperand(0);
- if (value.getType().isa<IndexType>() !=
+ if (isa<IndexType>(value.getType()) !=
!op->hasAttrOfType<IntegerAttr>("dim")) {
// Op should have "dim" attribute if and only if the operand is an
// index-typed value.
@@ -119,7 +119,7 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
if (reifyToFuncArgs) {
// Reify in terms of function block arguments.
stopCondition = stopCondition = [](Value v, std::optional<int64_t> d) {
- auto bbArg = v.dyn_cast<BlockArgument>();
+ auto bbArg = dyn_cast<BlockArgument>(v);
if (!bbArg)
return false;
return isa<FunctionOpInterface>(
@@ -166,7 +166,7 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
return WalkResult::skip();
}
Value constOp = rewriter.create<arith::ConstantIndexOp>(
- op->getLoc(), reified->get<Attribute>().cast<IntegerAttr>().getInt());
+ op->getLoc(), cast<IntegerAttr>(reified->get<Attribute>()).getInt());
rewriter.replaceOp(op, constOp);
return WalkResult::skip();
}
diff --git a/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp b/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp
index 85dd0718c9f04..f8588fab3aef7 100644
--- a/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp
@@ -127,7 +127,7 @@ void VectorizerTestPass::testVectorShapeRatio(llvm::raw_ostream &outs) {
// As a consequence we write only Ops with a single return type for the
// purpose of this test. If we need to test more intricate behavior in the
// future we can always extend.
- auto superVectorType = opInst->getResult(0).getType().cast<VectorType>();
+ auto superVectorType = cast<VectorType>(opInst->getResult(0).getType());
auto ratio =
computeShapeRatio(superVectorType.getShape(), subVectorType.getShape());
if (!ratio) {
@@ -211,8 +211,8 @@ void VectorizerTestPass::testComposeMaps(llvm::raw_ostream &outs) {
maps.reserve(matches.size());
for (auto m : llvm::reverse(matches)) {
auto *opInst = m.getMatchedOperation();
- auto map = opInst->getAttr(VectorizerTestPass::kTestAffineMapAttrName)
- .cast<AffineMapAttr>()
+ auto map = cast<AffineMapAttr>(
+ opInst->getAttr(VectorizerTestPass::kTestAffineMapAttrName))
.getValue();
maps.push_back(map);
}
diff --git a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
index 41e166600c433..10aba733bd569 100644
--- a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
+++ b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
@@ -27,7 +27,7 @@ static LogicalResult buildDecomposeTuple(OpBuilder &builder, Location loc,
Type elementType = resultType.getType(i);
Value element = builder.create<test::GetTupleElementOp>(
loc, elementType, value, builder.getI32IntegerAttr(i));
- if (auto nestedTupleType = elementType.dyn_cast<TupleType>()) {
+ if (auto nestedTupleType = dyn_cast<TupleType>(elementType)) {
// Recurse if the current element is also a tuple.
if (failed(buildDecomposeTuple(builder, loc, nestedTupleType, element,
values)))
@@ -50,7 +50,7 @@ static std::optional<Value> buildMakeTupleOp(OpBuilder &builder,
elements.reserve(resultType.getTypes().size());
ValueRange::iterator inputIt = inputs.begin();
for (Type elementType : resultType.getTypes()) {
- if (auto nestedTupleType = elementType.dyn_cast<TupleType>()) {
+ if (auto nestedTupleType = dyn_cast<TupleType>(elementType)) {
// Determine how many input values are needed for the nested elements of
// the nested TupleType and advance inputIt by that number.
// TODO: We only need the *number* of nested types, not the types itself.
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
index 50504988689b0..2231e427007a8 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
@@ -38,9 +38,9 @@ static LogicalResult fuseLinalgOpsGreedily(func::FuncOp f) {
bool changed = false;
for (LinalgOp linalgOp : llvm::reverse(linalgOps)) {
for (OpOperand &opOperand : linalgOp->getOpOperands()) {
- if (opOperand.get().getType().isa<MemRefType>())
+ if (isa<MemRefType>(opOperand.get().getType()))
continue;
- if (opOperand.get().getType().isa<RankedTensorType>()) {
+ if (isa<RankedTensorType>(opOperand.get().getType())) {
// Tile and Fuse tensor input.
if (opOperand.getOperandNumber() >= linalgOp.getNumDpsInputs())
continue;
diff --git a/mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp b/mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp
index 449a3e92b7da9..0f0875874c498 100644
--- a/mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp
+++ b/mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp
@@ -61,9 +61,9 @@ void ReportShapeFnPass::runOnOperation() {
if (attr) {
auto lookup = [&](Attribute attr) {
return cast<shape::FunctionLibraryOp>(
- SymbolTable::lookupSymbolIn(module, attr.cast<SymbolRefAttr>()));
+ SymbolTable::lookupSymbolIn(module, cast<SymbolRefAttr>(attr)));
};
- if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
+ if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
libraries.reserve(arrayAttr.size());
for (auto attr : arrayAttr)
libraries.push_back(lookup(attr));
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
index 6dc8b4a27d29b..46fe86524797e 100644
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -113,7 +113,7 @@ static void applyFoldConstantExtractSlicePatterns(Operation *rootOp) {
if (!op.getSource().hasOneUse())
return false;
- auto resultType = op.getResult().getType().cast<ShapedType>();
+ auto resultType = cast<ShapedType>(op.getResult().getType());
constexpr int64_t kConstantFoldingMaxNumElements = 1024;
return resultType.getNumElements() <= kConstantFoldingMaxNumElements;
};
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 4660d9abe6769..715c77b9a3963 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -49,7 +49,7 @@ Attribute MyPropStruct::asAttribute(MLIRContext *ctx) const {
}
LogicalResult MyPropStruct::setFromAttr(MyPropStruct &prop, Attribute attr,
InFlightDiagnostic *diag) {
- StringAttr strAttr = attr.dyn_cast<StringAttr>();
+ StringAttr strAttr = dyn_cast<StringAttr>(attr);
if (!strAttr) {
if (diag)
*diag << "Expect StringAttr but got " << attr;
@@ -221,7 +221,7 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
//===------------------------------------------------------------------===//
AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
- StringAttr strAttr = attr.dyn_cast<StringAttr>();
+ StringAttr strAttr = dyn_cast<StringAttr>(attr);
if (!strAttr)
return AliasResult::NoAlias;
@@ -246,16 +246,16 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
}
AliasResult getAlias(Type type, raw_ostream &os) const final {
- if (auto tupleType = type.dyn_cast<TupleType>()) {
+ if (auto tupleType = dyn_cast<TupleType>(type)) {
if (tupleType.size() > 0 &&
llvm::all_of(tupleType.getTypes(), [](Type elemType) {
- return elemType.isa<SimpleAType>();
+ return isa<SimpleAType>(elemType);
})) {
os << "test_tuple";
return AliasResult::FinalAlias;
}
}
- if (auto intType = type.dyn_cast<TestIntegerType>()) {
+ if (auto intType = dyn_cast<TestIntegerType>(type)) {
if (intType.getSignedness() ==
TestIntegerType::SignednessSemantics::Unsigned &&
intType.getWidth() == 8) {
@@ -263,7 +263,7 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
return AliasResult::FinalAlias;
}
}
- if (auto recType = type.dyn_cast<TestRecursiveType>()) {
+ if (auto recType = dyn_cast<TestRecursiveType>(type)) {
if (recType.getName() == "type_to_alias") {
// We only make alias for a specific recursive type.
os << "testrec";
@@ -1230,7 +1230,7 @@ void PolyForOp::getAsmBlockArgumentNames(Region &region,
auto args = getRegion().front().getArguments();
auto e = std::min(arrayAttr.size(), args.size());
for (unsigned i = 0; i < e; ++i) {
- if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
+ if (auto strAttr = dyn_cast<StringAttr>(arrayAttr[i]))
setNameFn(args[i], strAttr.getValue());
}
}
@@ -1252,7 +1252,7 @@ static ParseResult parseOptionalLoc(OpAsmParser &p, Attribute &loc) {
}
static void printOptionalLoc(OpAsmPrinter &p, Operation *op, Attribute loc) {
- p.printOptionalLocationSpecifier(loc.cast<LocationAttr>());
+ p.printOptionalLocationSpecifier(cast<LocationAttr>(loc));
}
//===----------------------------------------------------------------------===//
@@ -1376,7 +1376,7 @@ LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
// Create return type consisting of the last element of the first operand.
auto operandType = operands.front().getType();
- auto sval = operandType.dyn_cast<ShapedType>();
+ auto sval = dyn_cast<ShapedType>(operandType);
if (!sval) {
return emitOptionalError(location, "only shaped type operands allowed");
}
@@ -1384,7 +1384,7 @@ LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
auto type = IntegerType::get(context, 17);
Attribute encoding;
- if (auto rankedTy = sval.dyn_cast<RankedTensorType>())
+ if (auto rankedTy = dyn_cast<RankedTensorType>(sval))
encoding = rankedTy.getEncoding();
inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type, encoding));
return success();
@@ -1404,7 +1404,7 @@ LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
Location loc = getLoc();
shapes.reserve(operands.size());
for (Value operand : llvm::reverse(operands)) {
- auto rank = operand.getType().cast<RankedTensorType>().getRank();
+ auto rank = cast<RankedTensorType>(operand.getType()).getRank();
auto currShape = llvm::to_vector<4>(
llvm::map_range(llvm::seq<int64_t>(0, rank), [&](int64_t dim) -> Value {
return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
@@ -1421,7 +1421,7 @@ LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes(
Location loc = getLoc();
shapes.reserve(getNumOperands());
for (Value operand : llvm::reverse(getOperands())) {
- auto tensorType = operand.getType().cast<RankedTensorType>();
+ auto tensorType = cast<RankedTensorType>(operand.getType());
auto currShape = llvm::to_vector<4>(llvm::map_range(
llvm::seq<int64_t>(0, tensorType.getRank()),
[&](int64_t dim) -> OpFoldResult {
@@ -1471,12 +1471,12 @@ void SideEffectOp::getEffects(
// If there is one, it is an array of dictionary attributes that hold
// information on the effects of this operation.
for (Attribute element : effectsAttr) {
- DictionaryAttr effectElement = element.cast<DictionaryAttr>();
+ DictionaryAttr effectElement = cast<DictionaryAttr>(element);
// Get the specific memory effect.
MemoryEffects::Effect *effect =
StringSwitch<MemoryEffects::Effect *>(
- effectElement.get("effect").cast<StringAttr>().getValue())
+ cast<StringAttr>(effectElement.get("effect")).getValue())
.Case("allocate", MemoryEffects::Allocate::get())
.Case("free", MemoryEffects::Free::get())
.Case("read", MemoryEffects::Read::get())
@@ -1491,7 +1491,7 @@ void SideEffectOp::getEffects(
if (effectElement.get("on_result"))
effects.emplace_back(effect, getResult(), resource);
else if (Attribute ref = effectElement.get("on_reference"))
- effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource);
+ effects.emplace_back(effect, cast<SymbolRefAttr>(ref), resource);
else
effects.emplace_back(effect, resource);
}
@@ -1556,7 +1556,7 @@ void StringAttrPrettyNameOp::print(OpAsmPrinter &p) {
llvm::raw_svector_ostream tmpStream(resultNameStr);
p.printOperand(getResult(i), tmpStream);
- auto expectedName = getNames()[i].dyn_cast<StringAttr>();
+ auto expectedName = dyn_cast<StringAttr>(getNames()[i]);
if (!expectedName ||
tmpStream.str().drop_front() != expectedName.getValue()) {
namesDisagree = true;
@@ -1576,7 +1576,7 @@ void StringAttrPrettyNameOp::getAsmResultNames(
auto value = getNames();
for (size_t i = 0, e = value.size(); i != e; ++i)
- if (auto str = value[i].dyn_cast<StringAttr>())
+ if (auto str = dyn_cast<StringAttr>(value[i]))
if (!str.getValue().empty())
setNameFn(getResult(i), str.getValue());
}
@@ -1585,7 +1585,7 @@ void CustomResultsNameOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
ArrayAttr value = getNames();
for (size_t i = 0, e = value.size(); i != e; ++i)
- if (auto str = value[i].dyn_cast<StringAttr>())
+ if (auto str = dyn_cast<StringAttr>(value[i]))
if (!str.getValue().empty())
setNameFn(getResult(i), str.getValue());
}
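For readers new to the free-function style, here is a minimal, hypothetical sketch of how `isa`/`cast`/`dyn_cast` can work on value-semantic handle types like `mlir::Type`. The free function asks the target class's static `classof` hook, then rebuilds a handle around the same storage pointer. The names `type_sketch`, `TypeStorage`, `Kind` and `widthOrZero` are invented for illustration. LLVM's actual machinery in `llvm/Support/Casting.h` is considerably more general than this.

```cpp
#include <cassert>

namespace type_sketch {

enum class Kind { Integer, Tuple };

// Hypothetical stand-in for uniqued type storage owned by a context.
struct TypeStorage {
  Kind kind;
  unsigned width; // only meaningful for Kind::Integer
};

// Value-semantic handle, loosely modelled on mlir::Type (an assumption).
class Type {
public:
  Type() = default;
  explicit Type(const TypeStorage *storage) : impl(storage) {}
  explicit operator bool() const { return impl != nullptr; }
  Kind getKind() const { return impl->kind; }
  const TypeStorage *getImpl() const { return impl; }

protected:
  const TypeStorage *impl = nullptr;
};

class IntegerType : public Type {
public:
  using Type::Type;
  static bool classof(Type t) { return t.getKind() == Kind::Integer; }
  unsigned getWidth() const { return impl->width; }
};

class TupleType : public Type {
public:
  using Type::Type;
  static bool classof(Type t) { return t.getKind() == Kind::Tuple; }
};

// Free functions: each one defers to the target class's classof hook.
template <typename To> bool isa(Type t) {
  assert(t && "isa<> used on a null type");
  return To::classof(t);
}

template <typename To> To cast(Type t) {
  assert(isa<To>(t) && "cast<> argument of incompatible type");
  return To(t.getImpl());
}

// Returns a null handle (not an assertion) when the kind does not match.
template <typename To> To dyn_cast(Type t) {
  return isa<To>(t) ? To(t.getImpl()) : To();
}

} // namespace type_sketch

// Written in the new style; the old spelling was t.dyn_cast<IntegerType>().
unsigned widthOrZero(type_sketch::Type t) {
  using namespace type_sketch;
  if (auto intTy = dyn_cast<IntegerType>(t))
    return intTy.getWidth();
  return 0;
}
```

In this sketch `dyn_cast` still asserts on a null handle, following the convention that only the `_or_null` variants tolerate null inputs.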
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index adaa6e1558999..a61ba8e47e3e8 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -153,7 +153,7 @@ struct IncrementIntAttribute : public OpRewritePattern<AnyAttrOfOp> {
LogicalResult matchAndRewrite(AnyAttrOfOp op,
PatternRewriter &rewriter) const override {
- auto intAttr = op.getAttr().dyn_cast<IntegerAttr>();
+ auto intAttr = dyn_cast<IntegerAttr>(op.getAttr());
if (!intAttr)
return failure();
int64_t val = intAttr.getInt();
@@ -1271,11 +1271,11 @@ struct TestTypeConversionProducer
Type convertedType = getTypeConverter()
? getTypeConverter()->convertType(resultType)
: resultType;
- if (resultType.isa<FloatType>())
+ if (isa<FloatType>(resultType))
resultType = rewriter.getF64Type();
else if (resultType.isInteger(16))
resultType = rewriter.getIntegerType(64);
- else if (resultType.isa<test::TestRecursiveType>() &&
+ else if (isa<test::TestRecursiveType>(resultType) &&
convertedType != resultType)
resultType = convertedType;
else
@@ -1430,8 +1430,8 @@ struct TestTypeConversionDriver
inputs.empty())
return builder.create<TestTypeProducerOp>(loc, resultType);
// Allow producing an i64 from an integer.
- if (resultType.isa<IntegerType>() && inputs.size() == 1 &&
- inputs[0].getType().isa<IntegerType>())
+ if (isa<IntegerType>(resultType) && inputs.size() == 1 &&
+ isa<IntegerType>(inputs[0].getType()))
return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
// Otherwise, fail.
return nullptr;
@@ -1440,7 +1440,7 @@ struct TestTypeConversionDriver
// Initialize the conversion target.
mlir::ConversionTarget target(getContext());
target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) {
- auto recursiveType = op.getType().dyn_cast<test::TestRecursiveType>();
+ auto recursiveType = dyn_cast<test::TestRecursiveType>(op.getType());
return op.getType().isF64() || op.getType().isInteger(64) ||
(recursiveType &&
recursiveType.getName() == "outer_converted_type");
diff --git a/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp b/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
index c147ff471d281..9642301e8111c 100644
--- a/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
+++ b/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
@@ -42,20 +42,20 @@ ConvertTosaNegateOp::matchAndRewrite(Operation *op,
auto tosaNegateOp = cast<tosa::NegateOp>(op);
auto inputType =
- tosaNegateOp.getInput1().getType().dyn_cast<mlir::RankedTensorType>();
+ dyn_cast<mlir::RankedTensorType>(tosaNegateOp.getInput1().getType());
// skip if input is not ranked tensor type
if (!inputType)
return failure();
// skip if it's not ranked tensor type.
auto outputType =
- tosaNegateOp.getResult().getType().dyn_cast<mlir::RankedTensorType>();
+ dyn_cast<mlir::RankedTensorType>(tosaNegateOp.getResult().getType());
if (!outputType)
return failure();
// skip if output is not per-tensor quantized type.
auto outputElementType =
- outputType.getElementType().dyn_cast<mlir::quant::UniformQuantizedType>();
+ dyn_cast<mlir::quant::UniformQuantizedType>(outputType.getElementType());
if (!outputElementType)
return failure();
@@ -112,14 +112,14 @@ ConvertTosaConv2DOp::matchAndRewrite(Operation *op,
auto tosaConv2DOp = cast<tosa::Conv2DOp>(op);
auto inputType =
- tosaConv2DOp.getInput().getType().dyn_cast<mlir::RankedTensorType>();
+ dyn_cast<mlir::RankedTensorType>(tosaConv2DOp.getInput().getType());
// skip if input is not ranked tensor type
if (!inputType)
return failure();
auto weightType =
- tosaConv2DOp.getWeight().getType().dyn_cast<mlir::RankedTensorType>();
+ dyn_cast<mlir::RankedTensorType>(tosaConv2DOp.getWeight().getType());
// skip if wt is not ranked tensor type
if (!weightType)
@@ -127,16 +127,16 @@ ConvertTosaConv2DOp::matchAndRewrite(Operation *op,
// skip if it's not ranked tensor type.
auto outputType =
- tosaConv2DOp.getResult().getType().dyn_cast<mlir::RankedTensorType>();
+ dyn_cast<mlir::RankedTensorType>(tosaConv2DOp.getResult().getType());
if (!outputType)
return failure();
auto inputQType =
- inputType.getElementType().dyn_cast<mlir::quant::UniformQuantizedType>();
+ dyn_cast<mlir::quant::UniformQuantizedType>(inputType.getElementType());
auto weightQType =
- weightType.getElementType().dyn_cast<mlir::quant::UniformQuantizedType>();
+ dyn_cast<mlir::quant::UniformQuantizedType>(weightType.getElementType());
auto outputQType =
- outputType.getElementType().dyn_cast<mlir::quant::UniformQuantizedType>();
+ dyn_cast<mlir::quant::UniformQuantizedType>(outputType.getElementType());
// Works on quantized type only.
if (!(inputQType && weightQType && outputQType))
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index dd853aa1dc3cc..d0c79ab989151 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -89,7 +89,7 @@ struct TestVectorToVectorLowering
auto extract = dyn_cast<ExtractStridedSliceOp>(users);
if (!extract)
return std::nullopt;
- auto vecType = extract.getResult().getType().cast<VectorType>();
+ auto vecType = cast<VectorType>(extract.getResult().getType());
if (dstVec && dstVec != vecType)
return std::nullopt;
dstVec = vecType;
@@ -430,7 +430,7 @@ static Value allocateGlobalSharedMemory(Location loc, OpBuilder &builder,
static constexpr int64_t kSharedMemorySpace = 3;
// Compute type of shared memory buffer.
MemRefType memrefType;
- if (auto vectorType = type.dyn_cast<VectorType>()) {
+ if (auto vectorType = dyn_cast<VectorType>(type)) {
memrefType =
MemRefType::get(vectorType.getShape(), vectorType.getElementType(), {},
kSharedMemorySpace);
@@ -535,7 +535,7 @@ struct TestVectorDistribution
// Create a map (d0, d1) -> (d1) to distribute along the inner
// dimension. Once we support n-d distribution we can add more
// complex cases.
- VectorType vecType = val.getType().dyn_cast<VectorType>();
+ VectorType vecType = dyn_cast<VectorType>(val.getType());
int64_t vecRank = vecType ? vecType.getRank() : 0;
OpBuilder builder(val.getContext());
if (vecRank == 0)
@@ -642,9 +642,9 @@ struct TestCreateVectorBroadcast
if (op->getName().getStringRef() != "test_create_broadcast")
return;
auto targetShape =
- op->getResult(0).getType().cast<VectorType>().getShape();
+ cast<VectorType>(op->getResult(0).getType()).getShape();
auto arrayAttr =
- op->getAttr("broadcast_dims").cast<DenseI64ArrayAttr>().asArrayRef();
+ cast<DenseI64ArrayAttr>(op->getAttr("broadcast_dims")).asArrayRef();
llvm::SetVector<int64_t> broadcastedDims;
broadcastedDims.insert(arrayAttr.begin(), arrayAttr.end());
OpBuilder b(op);
diff --git a/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp b/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp
index 9313f403ce1c5..498de3d87bd4b 100644
--- a/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp
+++ b/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp
@@ -34,7 +34,7 @@ struct TestElementsAttrInterface
void runOnOperation() override {
getOperation().walk([&](Operation *op) {
for (NamedAttribute attr : op->getAttrs()) {
- auto elementsAttr = attr.getValue().dyn_cast<ElementsAttr>();
+ auto elementsAttr = dyn_cast<ElementsAttr>(attr.getValue());
if (!elementsAttr)
continue;
testElementsAttrIteration<int64_t>(op, elementsAttr, "int64_t");
diff --git a/mlir/test/lib/IR/TestDiagnostics.cpp b/mlir/test/lib/IR/TestDiagnostics.cpp
index 1f5b29d00de75..578486c0a3b14 100644
--- a/mlir/test/lib/IR/TestDiagnostics.cpp
+++ b/mlir/test/lib/IR/TestDiagnostics.cpp
@@ -36,7 +36,7 @@ struct TestDiagnosticFilterPass
// Build a diagnostic handler that has filtering capabilities.
auto filterFn = [&](Location loc) {
// Ignore non-file locations.
- FileLineColLoc fileLoc = loc.dyn_cast<FileLineColLoc>();
+ FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(loc);
if (!fileLoc)
return true;
diff --git a/mlir/test/lib/IR/TestFunc.cpp b/mlir/test/lib/IR/TestFunc.cpp
index 171d46abfde63..45897882b0032 100644
--- a/mlir/test/lib/IR/TestFunc.cpp
+++ b/mlir/test/lib/IR/TestFunc.cpp
@@ -35,13 +35,13 @@ struct TestFuncInsertArg
SmallVector<Location, 4> locsToInsert;
for (auto insert : inserts.getAsRange<ArrayAttr>()) {
indicesToInsert.push_back(
- insert[0].cast<IntegerAttr>().getValue().getZExtValue());
- typesToInsert.push_back(insert[1].cast<TypeAttr>().getValue());
+ cast<IntegerAttr>(insert[0]).getValue().getZExtValue());
+ typesToInsert.push_back(cast<TypeAttr>(insert[1]).getValue());
attrsToInsert.push_back(insert.size() > 2
- ? insert[2].cast<DictionaryAttr>()
+ ? cast<DictionaryAttr>(insert[2])
: DictionaryAttr::get(&getContext()));
locsToInsert.push_back(insert.size() > 3
- ? Location(insert[3].cast<LocationAttr>())
+ ? Location(cast<LocationAttr>(insert[3]))
: unknownLoc);
}
func->removeAttr("test.insert_args");
@@ -72,10 +72,10 @@ struct TestFuncInsertResult
SmallVector<DictionaryAttr, 4> attrsToInsert;
for (auto insert : inserts.getAsRange<ArrayAttr>()) {
indicesToInsert.push_back(
- insert[0].cast<IntegerAttr>().getValue().getZExtValue());
- typesToInsert.push_back(insert[1].cast<TypeAttr>().getValue());
+ cast<IntegerAttr>(insert[0]).getValue().getZExtValue());
+ typesToInsert.push_back(cast<TypeAttr>(insert[1]).getValue());
attrsToInsert.push_back(insert.size() > 2
- ? insert[2].cast<DictionaryAttr>()
+ ? cast<DictionaryAttr>(insert[2])
: DictionaryAttr::get(&getContext()));
}
func->removeAttr("test.insert_results");
diff --git a/mlir/test/lib/IR/TestInterfaces.cpp b/mlir/test/lib/IR/TestInterfaces.cpp
index 633d5304bc9b3..2dd3fe245e220 100644
--- a/mlir/test/lib/IR/TestInterfaces.cpp
+++ b/mlir/test/lib/IR/TestInterfaces.cpp
@@ -27,7 +27,7 @@ struct TestTypeInterfaces
void runOnOperation() override {
getOperation().walk([](Operation *op) {
for (Type type : op->getResultTypes()) {
- if (auto testInterface = type.dyn_cast<TestTypeInterface>()) {
+ if (auto testInterface = dyn_cast<TestTypeInterface>(type)) {
testInterface.printTypeA(op->getLoc());
testInterface.printTypeB(op->getLoc());
testInterface.printTypeC(op->getLoc());
@@ -37,7 +37,7 @@ struct TestTypeInterfaces
TestTypeInterface result = testInterface.printTypeRet(op->getLoc());
(void)result;
}
- if (auto testType = type.dyn_cast<TestType>())
+ if (auto testType = dyn_cast<TestType>(type))
testType.printTypeE(op->getLoc());
}
});
diff --git a/mlir/test/lib/IR/TestOpaqueLoc.cpp b/mlir/test/lib/IR/TestOpaqueLoc.cpp
index 977d2b001a181..c0ce8965868ab 100644
--- a/mlir/test/lib/IR/TestOpaqueLoc.cpp
+++ b/mlir/test/lib/IR/TestOpaqueLoc.cpp
@@ -74,7 +74,7 @@ struct TestOpaqueLoc
ScopedDiagnosticHandler diagHandler(&getContext(), [](Diagnostic &diag) {
auto &os = llvm::outs();
- if (diag.getLocation().isa<OpaqueLoc>()) {
+ if (isa<OpaqueLoc>(diag.getLocation())) {
MyLocation *loc = OpaqueLoc::getUnderlyingLocationOrNull<MyLocation *>(
diag.getLocation());
if (loc)
diff --git a/mlir/test/lib/IR/TestPrintDefUse.cpp b/mlir/test/lib/IR/TestPrintDefUse.cpp
index 0656036731a12..5d489a342f57d 100644
--- a/mlir/test/lib/IR/TestPrintDefUse.cpp
+++ b/mlir/test/lib/IR/TestPrintDefUse.cpp
@@ -34,7 +34,7 @@ struct TestPrintDefUsePass
} else {
// If there is no defining op, the Value is necessarily a Block
// argument.
- auto blockArg = operand.cast<BlockArgument>();
+ auto blockArg = cast<BlockArgument>(operand);
llvm::outs() << " - Operand produced by Block argument, number "
<< blockArg.getArgNumber() << "\n";
}
diff --git a/mlir/test/lib/Transforms/TestTopologicalSort.cpp b/mlir/test/lib/Transforms/TestTopologicalSort.cpp
index 4ad5b5c2608fc..a8cc7a5af60d8 100644
--- a/mlir/test/lib/Transforms/TestTopologicalSort.cpp
+++ b/mlir/test/lib/Transforms/TestTopologicalSort.cpp
@@ -42,7 +42,7 @@ struct TestTopologicalSortAnalysisPass
// If the root has an "ordered" attribute, we fill the selectedOps
// vector in a certain order.
int64_t pos =
- selected->getAttr("selected").cast<IntegerAttr>().getInt();
+ cast<IntegerAttr>(selected->getAttr("selected")).getInt();
if (pos >= static_cast<int64_t>(selectedOps.size()))
selectedOps.append(pos + 1 - selectedOps.size(), nullptr);
selectedOps[pos] = selected;
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index 52ea148179357..35f901519920e 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -317,10 +317,10 @@ struct ScalarTraits<SerializedAffineMap> {
SerializedAffineMap &value) {
assert(rawYamlContext);
auto *yamlContext = static_cast<LinalgYAMLContext *>(rawYamlContext);
- if (auto attr = mlir::parseAttribute(scalar, yamlContext->mlirContext)
- .dyn_cast_or_null<AffineMapAttr>())
+ if (auto attr = dyn_cast_or_null<AffineMapAttr>(
+ mlir::parseAttribute(scalar, yamlContext->mlirContext)))
value.affineMapAttr = attr;
- else if (!value.affineMapAttr || !value.affineMapAttr.isa<AffineMapAttr>())
+ else if (!value.affineMapAttr || !isa<AffineMapAttr>(value.affineMapAttr))
return "could not parse as an affine map attribute";
return StringRef();
}
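The YAML hunk above is the one place in this stretch that uses `dyn_cast_or_null`. That is because `mlir::parseAttribute` hands back a null `Attribute` when parsing fails. The difference from plain `dyn_cast` can be sketched with another hypothetical handle type. `attr_sketch`, `AttrStorage` and `isAffineMap` are invented names, and this is not MLIR's implementation.

```cpp
#include <cassert>

namespace attr_sketch {

enum class Kind { AffineMap, String };

// Hypothetical stand-in for uniqued attribute storage.
struct AttrStorage {
  Kind kind;
};

// Value-semantic handle, loosely modelled on mlir::Attribute (an assumption).
class Attribute {
public:
  Attribute() = default;
  explicit Attribute(const AttrStorage *storage) : impl(storage) {}
  explicit operator bool() const { return impl != nullptr; }
  Kind getKind() const { return impl->kind; }
  const AttrStorage *getImpl() const { return impl; }

private:
  const AttrStorage *impl = nullptr;
};

class AffineMapAttr : public Attribute {
public:
  using Attribute::Attribute;
  static bool classof(Attribute a) { return a.getKind() == Kind::AffineMap; }
};

// dyn_cast requires a non-null input.
template <typename To> To dyn_cast(Attribute a) {
  assert(a && "dyn_cast on a non-existent value");
  return To::classof(a) ? To(a.getImpl()) : To();
}

// Like dyn_cast, but a null input yields a null result instead of asserting.
template <typename To> To dyn_cast_or_null(Attribute a) {
  return a ? dyn_cast<To>(a) : To();
}

} // namespace attr_sketch

// Mirrors the hunk: the parsed value may be null, so use dyn_cast_or_null.
bool isAffineMap(attr_sketch::Attribute parsed) {
  return bool(attr_sketch::dyn_cast_or_null<attr_sketch::AffineMapAttr>(parsed));
}
```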
diff --git a/mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp b/mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp
index f39093498e09d..aa19b5c651f55 100644
--- a/mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp
+++ b/mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp
@@ -36,18 +36,18 @@ TEST_F(LLVMIRTest, MutualReferencedSubElementTypes) {
ASSERT_EQ(subElementTypes.size(), 4U);
// !llvm.ptr<struct<"foo",...>>
- ASSERT_TRUE(subElementTypes[0].isa<LLVMPointerType>());
+ ASSERT_TRUE(isa<LLVMPointerType>(subElementTypes[0]));
// !llvm.struct<"bar",...>
- auto structType = subElementTypes[1].dyn_cast<LLVMStructType>();
+ auto structType = dyn_cast<LLVMStructType>(subElementTypes[1]);
ASSERT_TRUE(bool(structType));
ASSERT_TRUE(structType.getName().equals("bar"));
// !llvm.ptr<struct<"bar",...>>
- ASSERT_TRUE(subElementTypes[2].isa<LLVMPointerType>());
+ ASSERT_TRUE(isa<LLVMPointerType>(subElementTypes[2]));
// !llvm.struct<"foo",...>
- structType = subElementTypes[3].dyn_cast<LLVMStructType>();
+ structType = dyn_cast<LLVMStructType>(subElementTypes[3]);
ASSERT_TRUE(bool(structType));
ASSERT_TRUE(structType.getName().equals("foo"));
}
diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp
index 94345d00f1130..f01cc026b72cc 100644
--- a/mlir/unittests/IR/AttributeTest.cpp
+++ b/mlir/unittests/IR/AttributeTest.cpp
@@ -278,7 +278,7 @@ static void checkNativeAccess(MLIRContext *ctx, ArrayRef<T> data,
// Check that we cast to this attribute when possible.
Attribute genericAttr = attr;
- EXPECT_TRUE(genericAttr.template isa<AttrT>());
+ EXPECT_TRUE(isa<AttrT>(genericAttr));
}
template <typename AttrT, typename T>
static void checkNativeIntAccess(Builder &builder, size_t intWidth) {
@@ -330,9 +330,9 @@ TEST(DenseResourceElementsAttrTest, CheckNoCast) {
Attribute i32ResourceAttr = DenseI32ResourceElementsAttr::get(
type, "resource", UnmanagedAsmResourceBlob::allocateInferAlign(data));
- EXPECT_TRUE(i32ResourceAttr.isa<DenseI32ResourceElementsAttr>());
- EXPECT_FALSE(i32ResourceAttr.isa<DenseF32ResourceElementsAttr>());
- EXPECT_FALSE(i32ResourceAttr.isa<DenseBoolResourceElementsAttr>());
+ EXPECT_TRUE(isa<DenseI32ResourceElementsAttr>(i32ResourceAttr));
+ EXPECT_FALSE(isa<DenseF32ResourceElementsAttr>(i32ResourceAttr));
+ EXPECT_FALSE(isa<DenseBoolResourceElementsAttr>(i32ResourceAttr));
}
TEST(DenseResourceElementsAttrTest, CheckInvalidData) {
@@ -407,17 +407,17 @@ TEST(SparseElementsAttrTest, GetZero) {
// Only index (0, 0) contains an element, others are supposed to return
// the zero/empty value.
auto zeroIntValue =
- sparseInt.getValues<Attribute>()[{1, 1}].cast<IntegerAttr>();
+ cast<IntegerAttr>(sparseInt.getValues<Attribute>()[{1, 1}]);
EXPECT_EQ(zeroIntValue.getInt(), 0);
EXPECT_TRUE(zeroIntValue.getType() == intTy);
auto zeroFloatValue =
- sparseFloat.getValues<Attribute>()[{1, 1}].cast<FloatAttr>();
+ cast<FloatAttr>(sparseFloat.getValues<Attribute>()[{1, 1}]);
EXPECT_EQ(zeroFloatValue.getValueAsDouble(), 0.0f);
EXPECT_TRUE(zeroFloatValue.getType() == floatTy);
auto zeroStringValue =
- sparseString.getValues<Attribute>()[{1, 1}].cast<StringAttr>();
+ cast<StringAttr>(sparseString.getValues<Attribute>()[{1, 1}]);
EXPECT_TRUE(zeroStringValue.getValue().empty());
EXPECT_TRUE(zeroStringValue.getType() == stringTy);
}
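The hunks above replace member calls such as `genericAttr.template isa<AttrT>()` with free calls such as `isa<AttrT>(genericAttr)`. The toy model below is a rough sketch of how free `isa`/`cast`/`dyn_cast` functions can sit on top of a `classof` hook on a value-semantic handle. Every name in it (`Kind`, `Storage`, `Attr`, `IntAttr`, `StrAttr`, `oldStyle`, `newStyle`) is hypothetical. The real MLIR/LLVM free functions are routed through LLVM's `CastInfo`/`doCast` machinery rather than calling `classof` directly as this sketch does.

```cpp
#include <cassert>
#include <string>

// Hypothetical toy model of a value-semantic handle like mlir::Attribute.
// Kind, Storage, Attr, IntAttr and StrAttr are illustrative names only.
enum class Kind { Int, Str };

struct Storage {
  Kind kind;
  long intValue;
  std::string strValue;
};

class Attr {
public:
  Attr() = default;
  explicit Attr(const Storage *impl) : impl(impl) {}
  explicit operator bool() const { return impl != nullptr; }
  const Storage *getImpl() const { return impl; }

  // Old style: member template, used as `attr.isa<IntAttr>()`.
  template <typename U> bool isa() const { return U::classof(*this); }

protected:
  const Storage *impl = nullptr;
};

class IntAttr : public Attr {
public:
  IntAttr() = default;
  explicit IntAttr(const Storage *s) : Attr(s) {}
  static bool classof(Attr a) { return a.getImpl()->kind == Kind::Int; }
  long getInt() const { return impl->intValue; }
};

class StrAttr : public Attr {
public:
  StrAttr() = default;
  explicit StrAttr(const Storage *s) : Attr(s) {}
  static bool classof(Attr a) { return a.getImpl()->kind == Kind::Str; }
  const std::string &getValue() const { return impl->strValue; }
};

// New style: free functions, used as `isa<IntAttr>(attr)`.
template <typename To> bool isa(Attr a) {
  assert(a && "isa<> used on a null handle");
  return To::classof(a);
}

// Checked conversion: asserts that the handle really is a `To`.
template <typename To> To cast(Attr a) {
  assert(isa<To>(a) && "cast<> to an incompatible kind");
  return To(a.getImpl());
}

// Conditional conversion: returns a null `To` when the kind does not match.
template <typename To> To dyn_cast(Attr a) {
  return isa<To>(a) ? To(a.getImpl()) : To();
}

// When the handle's type is a template parameter, the member form needs the
// `template` disambiguator; the free function does not.
template <typename AttrT, typename HandleT> bool oldStyle(HandleT h) {
  return h.template isa<AttrT>();
}
template <typename AttrT, typename HandleT> bool newStyle(HandleT h) {
  return isa<AttrT>(h);
}
```

In this sketch both spellings answer the same question; the free-function form simply reads the same way in templated and non-templated code.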
diff --git a/mlir/unittests/IR/InterfaceAttachmentTest.cpp b/mlir/unittests/IR/InterfaceAttachmentTest.cpp
index d5e19d27f8eb5..fe855164f8748 100644
--- a/mlir/unittests/IR/InterfaceAttachmentTest.cpp
+++ b/mlir/unittests/IR/InterfaceAttachmentTest.cpp
@@ -61,11 +61,11 @@ TEST(InterfaceAttachment, Type) {
// Check that the type has no interface.
IntegerType i8 = IntegerType::get(&context, 8);
- ASSERT_FALSE(i8.isa<TestExternalTypeInterface>());
+ ASSERT_FALSE(isa<TestExternalTypeInterface>(i8));
// Attach an interface and check that the type now has the interface.
IntegerType::attachInterface<Model>(context);
- TestExternalTypeInterface iface = i8.dyn_cast<TestExternalTypeInterface>();
+ TestExternalTypeInterface iface = dyn_cast<TestExternalTypeInterface>(i8);
ASSERT_TRUE(iface != nullptr);
EXPECT_EQ(iface.getBitwidthPlusArg(10), 18u);
EXPECT_EQ(iface.staticGetSomeValuePlusArg(0), 42u);
@@ -74,9 +74,9 @@ TEST(InterfaceAttachment, Type) {
// Same, but with the default implementation overridden.
FloatType flt = Float32Type::get(&context);
- ASSERT_FALSE(flt.isa<TestExternalTypeInterface>());
+ ASSERT_FALSE(isa<TestExternalTypeInterface>(flt));
Float32Type::attachInterface<OverridingModel>(context);
- iface = flt.dyn_cast<TestExternalTypeInterface>();
+ iface = dyn_cast<TestExternalTypeInterface>(flt);
ASSERT_TRUE(iface != nullptr);
EXPECT_EQ(iface.getBitwidthPlusArg(10), 42u);
EXPECT_EQ(iface.staticGetSomeValuePlusArg(10), 52u);
@@ -86,7 +86,7 @@ TEST(InterfaceAttachment, Type) {
// Other contexts shouldn't have the attribute attached.
MLIRContext other;
IntegerType i8other = IntegerType::get(&other, 8);
- EXPECT_FALSE(i8other.isa<TestExternalTypeInterface>());
+ EXPECT_FALSE(isa<TestExternalTypeInterface>(i8other));
}
/// External interface model for the test type from the test dialect.
@@ -111,7 +111,7 @@ TEST(InterfaceAttachment, TypeDelayedContextConstruct) {
MLIRContext context(registry);
context.loadDialect<test::TestDialect>();
test::TestType testType = test::TestType::get(&context);
- auto iface = testType.dyn_cast<TestExternalTypeInterface>();
+ auto iface = dyn_cast<TestExternalTypeInterface>(testType);
ASSERT_TRUE(iface != nullptr);
EXPECT_EQ(iface.getBitwidthPlusArg(42), 42u);
EXPECT_EQ(iface.staticGetSomeValuePlusArg(10), 20u);
@@ -130,9 +130,9 @@ TEST(InterfaceAttachment, TypeDelayedContextAppend) {
MLIRContext context;
context.loadDialect<test::TestDialect>();
test::TestType testType = test::TestType::get(&context);
- EXPECT_FALSE(testType.isa<TestExternalTypeInterface>());
+ EXPECT_FALSE(isa<TestExternalTypeInterface>(testType));
context.appendDialectRegistry(registry);
- EXPECT_TRUE(testType.isa<TestExternalTypeInterface>());
+ EXPECT_TRUE(isa<TestExternalTypeInterface>(testType));
}
TEST(InterfaceAttachment, RepeatedRegistration) {
@@ -156,13 +156,13 @@ TEST(InterfaceAttachment, TypeBuiltinDelayed) {
MLIRContext context(registry);
IntegerType i16 = IntegerType::get(&context, 16);
- EXPECT_TRUE(i16.isa<TestExternalTypeInterface>());
+ EXPECT_TRUE(isa<TestExternalTypeInterface>(i16));
MLIRContext initiallyEmpty;
IntegerType i32 = IntegerType::get(&initiallyEmpty, 32);
- EXPECT_FALSE(i32.isa<TestExternalTypeInterface>());
+ EXPECT_FALSE(isa<TestExternalTypeInterface>(i32));
initiallyEmpty.appendDialectRegistry(registry);
- EXPECT_TRUE(i32.isa<TestExternalTypeInterface>());
+ EXPECT_TRUE(isa<TestExternalTypeInterface>(i32));
}
/// The interface provides a default implementation that expects
@@ -181,9 +181,8 @@ struct TestExternalFallbackTypeVectorModel
: public TestExternalFallbackTypeInterface::FallbackModel<
TestExternalFallbackTypeVectorModel> {
unsigned getBitwidth(Type type) const {
- IntegerType elementType = type.cast<VectorType>()
- .getElementType()
- .dyn_cast_or_null<IntegerType>();
+ IntegerType elementType =
+ dyn_cast_or_null<IntegerType>(cast<VectorType>(type).getElementType());
return elementType ? elementType.getWidth() : 0;
}
};
@@ -193,16 +192,16 @@ TEST(InterfaceAttachment, Fallback) {
// Just check that we can attach the interface.
IntegerType i8 = IntegerType::get(&context, 8);
- ASSERT_FALSE(i8.isa<TestExternalFallbackTypeInterface>());
+ ASSERT_FALSE(isa<TestExternalFallbackTypeInterface>(i8));
IntegerType::attachInterface<TestExternalFallbackTypeIntegerModel>(context);
- ASSERT_TRUE(i8.isa<TestExternalFallbackTypeInterface>());
+ ASSERT_TRUE(isa<TestExternalFallbackTypeInterface>(i8));
// Call the method so it is guaranteed not to be instantiated.
VectorType vec = VectorType::get({42}, i8);
- ASSERT_FALSE(vec.isa<TestExternalFallbackTypeInterface>());
+ ASSERT_FALSE(isa<TestExternalFallbackTypeInterface>(vec));
VectorType::attachInterface<TestExternalFallbackTypeVectorModel>(context);
- ASSERT_TRUE(vec.isa<TestExternalFallbackTypeInterface>());
- EXPECT_EQ(vec.cast<TestExternalFallbackTypeInterface>().getBitwidth(), 8u);
+ ASSERT_TRUE(isa<TestExternalFallbackTypeInterface>(vec));
+ EXPECT_EQ(cast<TestExternalFallbackTypeInterface>(vec).getBitwidth(), 8u);
}
/// External model for attribute interfaces.
@@ -210,7 +209,7 @@ struct TestExternalIntegerAttrModel
: public TestExternalAttrInterface::ExternalModel<
TestExternalIntegerAttrModel, IntegerAttr> {
const Dialect *getDialectPtr(Attribute attr) const {
- return &attr.cast<IntegerAttr>().getDialect();
+ return &cast<IntegerAttr>(attr).getDialect();
}
static int getSomeNumber() { return 42; }
@@ -222,9 +221,9 @@ TEST(InterfaceAttachment, Attribute) {
// Attribute interfaces use the exact same mechanism as types, so just check
// that the basics work for attributes.
IntegerAttr attr = IntegerAttr::get(IntegerType::get(&context, 32), 42);
- ASSERT_FALSE(attr.isa<TestExternalAttrInterface>());
+ ASSERT_FALSE(isa<TestExternalAttrInterface>(attr));
IntegerAttr::attachInterface<TestExternalIntegerAttrModel>(context);
- auto iface = attr.dyn_cast<TestExternalAttrInterface>();
+ auto iface = dyn_cast<TestExternalAttrInterface>(attr);
ASSERT_TRUE(iface != nullptr);
EXPECT_EQ(iface.getDialectPtr(), &attr.getDialect());
EXPECT_EQ(iface.getSomeNumber(), 42);
@@ -253,14 +252,14 @@ TEST(InterfaceAttachmentTest, AttributeDelayed) {
MLIRContext context(registry);
context.loadDialect<test::TestDialect>();
auto attr = test::SimpleAAttr::get(&context);
- EXPECT_TRUE(attr.isa<TestExternalAttrInterface>());
+ EXPECT_TRUE(isa<TestExternalAttrInterface>(attr));
MLIRContext initiallyEmpty;
initiallyEmpty.loadDialect<test::TestDialect>();
attr = test::SimpleAAttr::get(&initiallyEmpty);
- EXPECT_FALSE(attr.isa<TestExternalAttrInterface>());
+ EXPECT_FALSE(isa<TestExternalAttrInterface>(attr));
initiallyEmpty.appendDialectRegistry(registry);
- EXPECT_TRUE(attr.isa<TestExternalAttrInterface>());
+ EXPECT_TRUE(isa<TestExternalAttrInterface>(attr));
}
/// External interface model for the module operation. Only provides non-default
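The fallback model's `getBitwidth` shows how a method chain becomes nested calls: `type.cast<VectorType>().getElementType().dyn_cast_or_null<IntegerType>()` now reads inside-out as `dyn_cast_or_null<IntegerType>(cast<VectorType>(type).getElementType())`. The sketch below uses hypothetical toy stand-ins named `Type`, `IntegerType` and `VectorType` (not MLIR's classes; `TypeKind`, `TypeStorage` and the free `getBitwidth` are likewise made up) to show the contract the `_or_null` variant is meant to provide in this model: a null input yields a null result, whereas the toy `dyn_cast` asserts on null.

```cpp
#include <cassert>

// Hypothetical toy handles; these are NOT MLIR's real Type classes.
enum class TypeKind { Integer, Vector };

struct TypeStorage {
  TypeKind kind;
  unsigned width;                  // meaningful for integer types
  const TypeStorage *elementType;  // meaningful for vectors; may be null here
};

class Type {
public:
  Type() = default;
  explicit Type(const TypeStorage *impl) : impl(impl) {}
  explicit operator bool() const { return impl != nullptr; }
  const TypeStorage *getImpl() const { return impl; }

protected:
  const TypeStorage *impl = nullptr;
};

class IntegerType : public Type {
public:
  IntegerType() = default;
  explicit IntegerType(const TypeStorage *s) : Type(s) {}
  static bool classof(Type t) { return t.getImpl()->kind == TypeKind::Integer; }
  unsigned getWidth() const { return impl->width; }
};

class VectorType : public Type {
public:
  VectorType() = default;
  explicit VectorType(const TypeStorage *s) : Type(s) {}
  static bool classof(Type t) { return t.getImpl()->kind == TypeKind::Vector; }
  Type getElementType() const { return Type(impl->elementType); }
};

template <typename To> bool isa(Type t) {
  assert(t && "isa<> used on a null handle");
  return To::classof(t);
}

template <typename To> To cast(Type t) {
  assert(isa<To>(t) && "cast<> to an incompatible kind");
  return To(t.getImpl());
}

template <typename To> To dyn_cast(Type t) {
  return isa<To>(t) ? To(t.getImpl()) : To();
}

// Like dyn_cast, but a null input yields a null result instead of asserting.
template <typename To> To dyn_cast_or_null(Type t) {
  return t ? dyn_cast<To>(t) : To();
}

// Hypothetical free function with the same shape as the model's getBitwidth.
unsigned getBitwidth(Type type) {
  IntegerType elementType =
      dyn_cast_or_null<IntegerType>(cast<VectorType>(type).getElementType());
  return elementType ? elementType.getWidth() : 0;
}
```

Here a vector of integers reports the element width, while a vector whose element is missing or non-integer reports 0.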
diff --git a/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp b/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp
index 7e0a8f55a5526..6601f32f3288a 100644
--- a/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp
+++ b/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp
@@ -152,16 +152,16 @@ struct OpWithLayout : public Op<OpWithLayout, DataLayoutOpInterface::Trait> {
static unsigned getTypeSizeInBits(Type type, const DataLayout &dataLayout,
DataLayoutEntryListRef params) {
// Make a recursive query.
- if (type.isa<FloatType>())
+ if (isa<FloatType>(type))
return dataLayout.getTypeSizeInBits(
IntegerType::get(type.getContext(), type.getIntOrFloatBitWidth()));
// Handle built-in types that are not handled by the default process.
- if (auto iType = type.dyn_cast<IntegerType>()) {
+ if (auto iType = dyn_cast<IntegerType>(type)) {
for (DataLayoutEntryInterface entry : params)
if (entry.getKey().dyn_cast<Type>() == type)
return 8 *
- entry.getValue().cast<IntegerAttr>().getValue().getZExtValue();
+ cast<IntegerAttr>(entry.getValue()).getValue().getZExtValue();
return 8 * iType.getIntOrFloatBitWidth();
}
@@ -217,7 +217,7 @@ struct DLTestDialect : Dialect {
void printAttribute(Attribute attr,
DialectAsmPrinter &printer) const override {
printer << "spec<";
- llvm::interleaveComma(attr.cast<CustomDataLayoutSpec>().getEntries(),
+ llvm::interleaveComma(cast<CustomDataLayoutSpec>(attr).getEntries(),
printer);
printer << ">";
}
@@ -244,7 +244,7 @@ struct DLTestDialect : Dialect {
}
void printType(Type type, DialectAsmPrinter &printer) const override {
- if (type.isa<SingleQueryType>())
+ if (isa<SingleQueryType>(type))
printer << "single_query";
else
printer << "no_layout";
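`getTypeSizeInBits` in the data layout test uses the free functions for two different jobs: `isa<FloatType>(type)` when only the category matters, and `dyn_cast<IntegerType>(type)` inside the `if` when the code also needs a typed handle. The following is a minimal sketch of that control flow on hypothetical toy handles (`TypeKind`, `TypeStorage`, `Type`, `IntegerType`, `FloatType`, `sizeInBits`; none of them are MLIR's real classes or the `DataLayout` API), with a made-up size rule that rounds integers up to whole bytes.

```cpp
#include <cassert>

// Hypothetical toy handles, not MLIR's Type classes or its DataLayout API.
enum class TypeKind { Integer, Float, Index };

struct TypeStorage {
  TypeKind kind;
  unsigned width;
};

class Type {
public:
  Type() = default;
  explicit Type(const TypeStorage *impl) : impl(impl) {}
  explicit operator bool() const { return impl != nullptr; }
  const TypeStorage *getImpl() const { return impl; }
  unsigned getBitWidth() const { return impl->width; }

protected:
  const TypeStorage *impl = nullptr;
};

class IntegerType : public Type {
public:
  IntegerType() = default;
  explicit IntegerType(const TypeStorage *s) : Type(s) {}
  static bool classof(Type t) { return t.getImpl()->kind == TypeKind::Integer; }
};

class FloatType : public Type {
public:
  FloatType() = default;
  explicit FloatType(const TypeStorage *s) : Type(s) {}
  static bool classof(Type t) { return t.getImpl()->kind == TypeKind::Float; }
};

template <typename To> bool isa(Type t) {
  assert(t && "isa<> used on a null handle");
  return To::classof(t);
}

template <typename To> To dyn_cast(Type t) {
  return isa<To>(t) ? To(t.getImpl()) : To();
}

// Hypothetical size query mirroring the shape of getTypeSizeInBits.
unsigned sizeInBits(Type type) {
  // Category check only: no typed handle is needed, so isa<> suffices.
  if (isa<FloatType>(type)) {
    // Toy recursive query: size a float like an integer of the same width.
    TypeStorage asInt{TypeKind::Integer, type.getBitWidth()};
    return sizeInBits(Type(&asInt));
  }
  // Check and bind in one step: iType converts to false if not an integer.
  if (auto iType = dyn_cast<IntegerType>(type))
    return 8 * ((iType.getBitWidth() + 7) / 8);  // round up to whole bytes
  return 0;  // toy rule: other kinds report no size
}
```

Binding with `dyn_cast` in the `if` condition avoids a separate `isa` check followed by a `cast` of the same value.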
diff --git a/mlir/unittests/Pass/PassManagerTest.cpp b/mlir/unittests/Pass/PassManagerTest.cpp
index 24e87020c9329..97349d681c3a0 100644
--- a/mlir/unittests/Pass/PassManagerTest.cpp
+++ b/mlir/unittests/Pass/PassManagerTest.cpp
@@ -75,12 +75,12 @@ TEST(PassManagerTest, OpSpecificAnalysis) {
// Verify that each function got annotated with expected attributes.
for (func::FuncOp func : module->getOps<func::FuncOp>()) {
- ASSERT_TRUE(func->getAttr("isFunc").isa<BoolAttr>());
- EXPECT_TRUE(func->getAttr("isFunc").cast<BoolAttr>().getValue());
+ ASSERT_TRUE(isa<BoolAttr>(func->getAttr("isFunc")));
+ EXPECT_TRUE(cast<BoolAttr>(func->getAttr("isFunc")).getValue());
bool isSecret = func.getName() == "secret";
- ASSERT_TRUE(func->getAttr("isSecret").isa<BoolAttr>());
- EXPECT_EQ(func->getAttr("isSecret").cast<BoolAttr>().getValue(), isSecret);
+ ASSERT_TRUE(isa<BoolAttr>(func->getAttr("isSecret")));
+ EXPECT_EQ(cast<BoolAttr>(func->getAttr("isSecret")).getValue(), isSecret);
}
}