[Mlir-commits] [mlir] 68f5881 - [mlir] Move casting calls from methods to function calls
Tres Popp
llvmlistbot at llvm.org
Fri May 26 01:30:11 PDT 2023
Author: Tres Popp
Date: 2023-05-26T10:29:55+02:00
New Revision: 68f58812e3e99e31d77c0c23b6298489444dc0be
URL: https://github.com/llvm/llvm-project/commit/68f58812e3e99e31d77c0c23b6298489444dc0be
DIFF: https://github.com/llvm/llvm-project/commit/68f58812e3e99e31d77c0c23b6298489444dc0be.diff
LOG: [mlir] Move casting calls from methods to function calls
The MLIR classes Type/Attribute/Operation/Op/Value support
cast/dyn_cast/isa/dyn_cast_or_null functionality through llvm's doCast
functionality in addition to defining methods with the same name.
This change begins migrating uses of these methods to the
corresponding free-function calls, which were chosen as the more consistent style.
Note that some classes, such as AffineExpr, still only define these
methods directly; this change does not include any work to support
free-function cast/isa calls for them.
Context:
- https://mlir.llvm.org/deprecation/ at "Use the free function variants
for dyn_cast/cast/isa/…"
- Original discussion at https://discourse.llvm.org/t/preferred-casting-style-going-forward/68443
Implementation:
This patch updates all remaining uses of the deprecated functionality in
mlir/. This was done with clang-tidy as described below, plus additional
modifications to GPUBase.td and OpenMPOpsInterfaces.td.
The steps are described one per line below, because git strips comment lines from commit messages:
0. Retrieve the change from the following branch to build clang-tidy with the
additional check:
main...tpopp:llvm-project:tidy-cast-check
1. Build clang-tidy
2. Run clang-tidy over your entire codebase with all checks disabled except
the relevant one. Run it on all header files as well.
3. Delete the generated .inc files that were also modified, so that the next
build regenerates them in a clean state.
```
ninja -C $BUILD_DIR clang-tidy
run-clang-tidy -clang-tidy-binary=$BUILD_DIR/bin/clang-tidy -checks='-*,misc-cast-functions'\
-header-filter=mlir/ mlir/* -fix
rm -rf $BUILD_DIR/tools/mlir/**/*.inc
```
Differential Revision: https://reviews.llvm.org/D151542
Added:
Modified:
mlir/examples/toy/Ch2/mlir/Dialect.cpp
mlir/examples/toy/Ch3/mlir/Dialect.cpp
mlir/examples/toy/Ch4/mlir/Dialect.cpp
mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp
mlir/examples/toy/Ch5/mlir/Dialect.cpp
mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp
mlir/examples/toy/Ch6/mlir/Dialect.cpp
mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
mlir/examples/toy/Ch6/mlir/ShapeInferencePass.cpp
mlir/examples/toy/Ch7/mlir/Dialect.cpp
mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp
mlir/examples/toy/Ch7/mlir/ToyCombine.cpp
mlir/include/mlir/Debug/BreakpointManagers/FileLineColLocBreakpointManager.h
mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
mlir/include/mlir/IR/BuiltinTypes.h
mlir/include/mlir/IR/TypeRange.h
mlir/include/mlir/Interfaces/SideEffectInterfaces.h
mlir/include/mlir/Pass/AnalysisManager.h
mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp
mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
mlir/lib/Analysis/DataFlowFramework.cpp
mlir/lib/AsmParser/Parser.cpp
mlir/lib/CAPI/Interfaces/Interfaces.cpp
mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/lib/Debug/DebuggerExecutionContextHook.cpp
mlir/lib/Dialect/Affine/IR/AffineOps.cpp
mlir/lib/Dialect/Arith/IR/ArithOps.cpp
mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Arith/Utils/Utils.cpp
mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
mlir/lib/Dialect/DLTI/DLTI.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
mlir/lib/Dialect/Linalg/Transforms/Split.cpp
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/Dialect/Linalg/Utils/Utils.cpp
mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
mlir/lib/Dialect/Utils/StaticValueUtils.cpp
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/Block.cpp
mlir/lib/IR/Builders.cpp
mlir/lib/IR/BuiltinAttributes.cpp
mlir/lib/IR/BuiltinTypes.cpp
mlir/lib/IR/OperationSupport.cpp
mlir/lib/IR/Region.cpp
mlir/lib/IR/SymbolTable.cpp
mlir/lib/IR/TypeRange.cpp
mlir/lib/IR/Types.cpp
mlir/lib/IR/Unit.cpp
mlir/lib/IR/Value.cpp
mlir/lib/Interfaces/DataLayoutInterfaces.cpp
mlir/lib/Interfaces/InferTypeOpInterface.cpp
mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
mlir/lib/Pass/PassDetail.h
mlir/lib/TableGen/Operator.cpp
mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
mlir/lib/Target/LLVMIR/ModuleImport.cpp
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
mlir/lib/Transforms/Inliner.cpp
mlir/lib/Transforms/Utils/FoldUtils.cpp
mlir/lib/Transforms/Utils/InliningUtils.cpp
mlir/test/lib/Analysis/TestDataFlowFramework.cpp
mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
mlir/tools/mlir-tblgen/OpFormatGen.cpp
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
mlir/tools/mlir-tblgen/RewriterGen.cpp
mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp
Removed:
################################################################################
diff --git a/mlir/examples/toy/Ch2/mlir/Dialect.cpp b/mlir/examples/toy/Ch2/mlir/Dialect.cpp
index ef07af26ec435..df9105c8a8ce5 100644
--- a/mlir/examples/toy/Ch2/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch2/mlir/Dialect.cpp
@@ -54,7 +54,7 @@ static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser,
// If the type is a function type, it contains the input and result types of
// this operation.
- if (FunctionType funcType = type.dyn_cast<FunctionType>()) {
+ if (FunctionType funcType = llvm::dyn_cast<FunctionType>(type)) {
if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
result.operands))
return mlir::failure();
@@ -133,13 +133,13 @@ void ConstantOp::print(mlir::OpAsmPrinter &printer) {
mlir::LogicalResult ConstantOp::verify() {
// If the return type of the constant is not an unranked tensor, the shape
// must match the shape of the attribute holding the data.
- auto resultType = getResult().getType().dyn_cast<mlir::RankedTensorType>();
+ auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(getResult().getType());
if (!resultType)
return success();
// Check that the rank of the attribute type matches the rank of the constant
// result type.
- auto attrType = getValue().getType().cast<mlir::RankedTensorType>();
+ auto attrType = llvm::cast<mlir::RankedTensorType>(getValue().getType());
if (attrType.getRank() != resultType.getRank()) {
return emitOpError("return type must match the one of the attached value "
"attribute: ")
@@ -269,8 +269,8 @@ mlir::LogicalResult ReturnOp::verify() {
auto resultType = results.front();
// Check that the result type of the function matches the operand type.
- if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
- resultType.isa<mlir::UnrankedTensorType>())
+ if (inputType == resultType || llvm::isa<mlir::UnrankedTensorType>(inputType) ||
+ llvm::isa<mlir::UnrankedTensorType>(resultType))
return mlir::success();
return emitError() << "type of return operand (" << inputType
@@ -289,8 +289,8 @@ void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
}
mlir::LogicalResult TransposeOp::verify() {
- auto inputType = getOperand().getType().dyn_cast<RankedTensorType>();
- auto resultType = getType().dyn_cast<RankedTensorType>();
+ auto inputType = llvm::dyn_cast<RankedTensorType>(getOperand().getType());
+ auto resultType = llvm::dyn_cast<RankedTensorType>(getType());
if (!inputType || !resultType)
return mlir::success();
diff --git a/mlir/examples/toy/Ch3/mlir/Dialect.cpp b/mlir/examples/toy/Ch3/mlir/Dialect.cpp
index 43f8d5b1481d5..ca076f22a60a4 100644
--- a/mlir/examples/toy/Ch3/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch3/mlir/Dialect.cpp
@@ -54,7 +54,7 @@ static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser,
// If the type is a function type, it contains the input and result types of
// this operation.
- if (FunctionType funcType = type.dyn_cast<FunctionType>()) {
+ if (FunctionType funcType = llvm::dyn_cast<FunctionType>(type)) {
if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
result.operands))
return mlir::failure();
@@ -133,13 +133,13 @@ void ConstantOp::print(mlir::OpAsmPrinter &printer) {
mlir::LogicalResult ConstantOp::verify() {
// If the return type of the constant is not an unranked tensor, the shape
// must match the shape of the attribute holding the data.
- auto resultType = getResult().getType().dyn_cast<mlir::RankedTensorType>();
+ auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(getResult().getType());
if (!resultType)
return success();
// Check that the rank of the attribute type matches the rank of the constant
// result type.
- auto attrType = getValue().getType().cast<mlir::RankedTensorType>();
+ auto attrType = llvm::cast<mlir::RankedTensorType>(getValue().getType());
if (attrType.getRank() != resultType.getRank()) {
return emitOpError("return type must match the one of the attached value "
"attribute: ")
@@ -269,8 +269,8 @@ mlir::LogicalResult ReturnOp::verify() {
auto resultType = results.front();
// Check that the result type of the function matches the operand type.
- if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
- resultType.isa<mlir::UnrankedTensorType>())
+ if (inputType == resultType || llvm::isa<mlir::UnrankedTensorType>(inputType) ||
+ llvm::isa<mlir::UnrankedTensorType>(resultType))
return mlir::success();
return emitError() << "type of return operand (" << inputType
@@ -289,8 +289,8 @@ void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
}
mlir::LogicalResult TransposeOp::verify() {
- auto inputType = getOperand().getType().dyn_cast<RankedTensorType>();
- auto resultType = getType().dyn_cast<RankedTensorType>();
+ auto inputType = llvm::dyn_cast<RankedTensorType>(getOperand().getType());
+ auto resultType = llvm::dyn_cast<RankedTensorType>(getType());
if (!inputType || !resultType)
return mlir::success();
diff --git a/mlir/examples/toy/Ch4/mlir/Dialect.cpp b/mlir/examples/toy/Ch4/mlir/Dialect.cpp
index d533e5805081f..e84151884ad44 100644
--- a/mlir/examples/toy/Ch4/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch4/mlir/Dialect.cpp
@@ -114,7 +114,7 @@ static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser,
// If the type is a function type, it contains the input and result types of
// this operation.
- if (FunctionType funcType = type.dyn_cast<FunctionType>()) {
+ if (FunctionType funcType = llvm::dyn_cast<FunctionType>(type)) {
if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
result.operands))
return mlir::failure();
@@ -193,13 +193,13 @@ void ConstantOp::print(mlir::OpAsmPrinter &printer) {
mlir::LogicalResult ConstantOp::verify() {
// If the return type of the constant is not an unranked tensor, the shape
// must match the shape of the attribute holding the data.
- auto resultType = getResult().getType().dyn_cast<mlir::RankedTensorType>();
+ auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(getResult().getType());
if (!resultType)
return success();
// Check that the rank of the attribute type matches the rank of the constant
// result type.
- auto attrType = getValue().getType().cast<mlir::RankedTensorType>();
+ auto attrType = llvm::cast<mlir::RankedTensorType>(getValue().getType());
if (attrType.getRank() != resultType.getRank()) {
return emitOpError("return type must match the one of the attached value "
"attribute: ")
@@ -254,8 +254,8 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
if (inputs.size() != 1 || outputs.size() != 1)
return false;
// The inputs must be Tensors with the same element type.
- TensorType input = inputs.front().dyn_cast<TensorType>();
- TensorType output = outputs.front().dyn_cast<TensorType>();
+ TensorType input = llvm::dyn_cast<TensorType>(inputs.front());
+ TensorType output = llvm::dyn_cast<TensorType>(outputs.front());
if (!input || !output || input.getElementType() != output.getElementType())
return false;
// The shape is required to match if both types are ranked.
@@ -397,8 +397,8 @@ mlir::LogicalResult ReturnOp::verify() {
auto resultType = results.front();
// Check that the result type of the function matches the operand type.
- if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
- resultType.isa<mlir::UnrankedTensorType>())
+ if (inputType == resultType || llvm::isa<mlir::UnrankedTensorType>(inputType) ||
+ llvm::isa<mlir::UnrankedTensorType>(resultType))
return mlir::success();
return emitError() << "type of return operand (" << inputType
@@ -417,14 +417,14 @@ void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
}
void TransposeOp::inferShapes() {
- auto arrayTy = getOperand().getType().cast<RankedTensorType>();
+ auto arrayTy = llvm::cast<RankedTensorType>(getOperand().getType());
SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
}
mlir::LogicalResult TransposeOp::verify() {
- auto inputType = getOperand().getType().dyn_cast<RankedTensorType>();
- auto resultType = getType().dyn_cast<RankedTensorType>();
+ auto inputType = llvm::dyn_cast<RankedTensorType>(getOperand().getType());
+ auto resultType = llvm::dyn_cast<RankedTensorType>(getType());
if (!inputType || !resultType)
return mlir::success();
diff --git a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp
index cf3e492989b0f..d45baa14ab3e8 100644
--- a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp
+++ b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp
@@ -94,7 +94,7 @@ struct ShapeInferencePass
/// operands inferred.
static bool allOperandsInferred(Operation *op) {
return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
- return operandType.isa<RankedTensorType>();
+ return llvm::isa<RankedTensorType>(operandType);
});
}
@@ -102,7 +102,7 @@ struct ShapeInferencePass
/// shaped result.
static bool returnsDynamicShape(Operation *op) {
return llvm::any_of(op->getResultTypes(), [](Type resultType) {
- return !resultType.isa<RankedTensorType>();
+ return !llvm::isa<RankedTensorType>(resultType);
});
}
};
diff --git a/mlir/examples/toy/Ch5/mlir/Dialect.cpp b/mlir/examples/toy/Ch5/mlir/Dialect.cpp
index 4f0326682fbd7..c2a99aa2921b8 100644
--- a/mlir/examples/toy/Ch5/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch5/mlir/Dialect.cpp
@@ -114,7 +114,7 @@ static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser,
// If the type is a function type, it contains the input and result types of
// this operation.
- if (FunctionType funcType = type.dyn_cast<FunctionType>()) {
+ if (FunctionType funcType = llvm::dyn_cast<FunctionType>(type)) {
if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
result.operands))
return mlir::failure();
@@ -193,13 +193,13 @@ void ConstantOp::print(mlir::OpAsmPrinter &printer) {
mlir::LogicalResult ConstantOp::verify() {
// If the return type of the constant is not an unranked tensor, the shape
// must match the shape of the attribute holding the data.
- auto resultType = getResult().getType().dyn_cast<mlir::RankedTensorType>();
+ auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(getResult().getType());
if (!resultType)
return success();
// Check that the rank of the attribute type matches the rank of the constant
// result type.
- auto attrType = getValue().getType().cast<mlir::RankedTensorType>();
+ auto attrType = llvm::cast<mlir::RankedTensorType>(getValue().getType());
if (attrType.getRank() != resultType.getRank()) {
return emitOpError("return type must match the one of the attached value "
"attribute: ")
@@ -254,8 +254,8 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
if (inputs.size() != 1 || outputs.size() != 1)
return false;
// The inputs must be Tensors with the same element type.
- TensorType input = inputs.front().dyn_cast<TensorType>();
- TensorType output = outputs.front().dyn_cast<TensorType>();
+ TensorType input = llvm::dyn_cast<TensorType>(inputs.front());
+ TensorType output = llvm::dyn_cast<TensorType>(outputs.front());
if (!input || !output || input.getElementType() != output.getElementType())
return false;
// The shape is required to match if both types are ranked.
@@ -397,8 +397,8 @@ mlir::LogicalResult ReturnOp::verify() {
auto resultType = results.front();
// Check that the result type of the function matches the operand type.
- if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
- resultType.isa<mlir::UnrankedTensorType>())
+ if (inputType == resultType || llvm::isa<mlir::UnrankedTensorType>(inputType) ||
+ llvm::isa<mlir::UnrankedTensorType>(resultType))
return mlir::success();
return emitError() << "type of return operand (" << inputType
@@ -417,14 +417,14 @@ void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
}
void TransposeOp::inferShapes() {
- auto arrayTy = getOperand().getType().cast<RankedTensorType>();
+ auto arrayTy = llvm::cast<RankedTensorType>(getOperand().getType());
SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
}
mlir::LogicalResult TransposeOp::verify() {
- auto inputType = getOperand().getType().dyn_cast<RankedTensorType>();
- auto resultType = getType().dyn_cast<RankedTensorType>();
+ auto inputType = llvm::dyn_cast<RankedTensorType>(getOperand().getType());
+ auto resultType = llvm::dyn_cast<RankedTensorType>(getType());
if (!inputType || !resultType)
return mlir::success();
diff --git a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
index 988175508ca1b..fd589ddf84541 100644
--- a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
@@ -62,7 +62,7 @@ using LoopIterationFn = function_ref<Value(
static void lowerOpToLoops(Operation *op, ValueRange operands,
PatternRewriter &rewriter,
LoopIterationFn processIteration) {
- auto tensorType = (*op->result_type_begin()).cast<RankedTensorType>();
+ auto tensorType = llvm::cast<RankedTensorType>((*op->result_type_begin()));
auto loc = op->getLoc();
// Insert an allocation and deallocation for the result of this operation.
@@ -144,7 +144,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
// When lowering the constant operation, we allocate and assign the constant
// values to a corresponding memref allocation.
- auto tensorType = op.getType().cast<RankedTensorType>();
+ auto tensorType = llvm::cast<RankedTensorType>(op.getType());
auto memRefType = convertTensorToMemRef(tensorType);
auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
@@ -342,7 +342,7 @@ void ToyToAffineLoweringPass::runOnOperation() {
target.addIllegalDialect<toy::ToyDialect>();
target.addDynamicallyLegalOp<toy::PrintOp>([](toy::PrintOp op) {
return llvm::none_of(op->getOperandTypes(),
- [](Type type) { return type.isa<TensorType>(); });
+ [](Type type) { return llvm::isa<TensorType>(type); });
});
// Now that the conversion target has been defined, we just need to provide
diff --git a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp
index cf3e492989b0f..d45baa14ab3e8 100644
--- a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp
+++ b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp
@@ -94,7 +94,7 @@ struct ShapeInferencePass
/// operands inferred.
static bool allOperandsInferred(Operation *op) {
return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
- return operandType.isa<RankedTensorType>();
+ return llvm::isa<RankedTensorType>(operandType);
});
}
@@ -102,7 +102,7 @@ struct ShapeInferencePass
/// shaped result.
static bool returnsDynamicShape(Operation *op) {
return llvm::any_of(op->getResultTypes(), [](Type resultType) {
- return !resultType.isa<RankedTensorType>();
+ return !llvm::isa<RankedTensorType>(resultType);
});
}
};
diff --git a/mlir/examples/toy/Ch6/mlir/Dialect.cpp b/mlir/examples/toy/Ch6/mlir/Dialect.cpp
index 4f0326682fbd7..c2a99aa2921b8 100644
--- a/mlir/examples/toy/Ch6/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch6/mlir/Dialect.cpp
@@ -114,7 +114,7 @@ static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser,
// If the type is a function type, it contains the input and result types of
// this operation.
- if (FunctionType funcType = type.dyn_cast<FunctionType>()) {
+ if (FunctionType funcType = llvm::dyn_cast<FunctionType>(type)) {
if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
result.operands))
return mlir::failure();
@@ -193,13 +193,13 @@ void ConstantOp::print(mlir::OpAsmPrinter &printer) {
mlir::LogicalResult ConstantOp::verify() {
// If the return type of the constant is not an unranked tensor, the shape
// must match the shape of the attribute holding the data.
- auto resultType = getResult().getType().dyn_cast<mlir::RankedTensorType>();
+ auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(getResult().getType());
if (!resultType)
return success();
// Check that the rank of the attribute type matches the rank of the constant
// result type.
- auto attrType = getValue().getType().cast<mlir::RankedTensorType>();
+ auto attrType = llvm::cast<mlir::RankedTensorType>(getValue().getType());
if (attrType.getRank() != resultType.getRank()) {
return emitOpError("return type must match the one of the attached value "
"attribute: ")
@@ -254,8 +254,8 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
if (inputs.size() != 1 || outputs.size() != 1)
return false;
// The inputs must be Tensors with the same element type.
- TensorType input = inputs.front().dyn_cast<TensorType>();
- TensorType output = outputs.front().dyn_cast<TensorType>();
+ TensorType input = llvm::dyn_cast<TensorType>(inputs.front());
+ TensorType output = llvm::dyn_cast<TensorType>(outputs.front());
if (!input || !output || input.getElementType() != output.getElementType())
return false;
// The shape is required to match if both types are ranked.
@@ -397,8 +397,8 @@ mlir::LogicalResult ReturnOp::verify() {
auto resultType = results.front();
// Check that the result type of the function matches the operand type.
- if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
- resultType.isa<mlir::UnrankedTensorType>())
+ if (inputType == resultType || llvm::isa<mlir::UnrankedTensorType>(inputType) ||
+ llvm::isa<mlir::UnrankedTensorType>(resultType))
return mlir::success();
return emitError() << "type of return operand (" << inputType
@@ -417,14 +417,14 @@ void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
}
void TransposeOp::inferShapes() {
- auto arrayTy = getOperand().getType().cast<RankedTensorType>();
+ auto arrayTy = llvm::cast<RankedTensorType>(getOperand().getType());
SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
}
mlir::LogicalResult TransposeOp::verify() {
- auto inputType = getOperand().getType().dyn_cast<RankedTensorType>();
- auto resultType = getType().dyn_cast<RankedTensorType>();
+ auto inputType = llvm::dyn_cast<RankedTensorType>(getOperand().getType());
+ auto resultType = llvm::dyn_cast<RankedTensorType>(getType());
if (!inputType || !resultType)
return mlir::success();
diff --git a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
index 988175508ca1b..fd589ddf84541 100644
--- a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
@@ -62,7 +62,7 @@ using LoopIterationFn = function_ref<Value(
static void lowerOpToLoops(Operation *op, ValueRange operands,
PatternRewriter &rewriter,
LoopIterationFn processIteration) {
- auto tensorType = (*op->result_type_begin()).cast<RankedTensorType>();
+ auto tensorType = llvm::cast<RankedTensorType>((*op->result_type_begin()));
auto loc = op->getLoc();
// Insert an allocation and deallocation for the result of this operation.
@@ -144,7 +144,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
// When lowering the constant operation, we allocate and assign the constant
// values to a corresponding memref allocation.
- auto tensorType = op.getType().cast<RankedTensorType>();
+ auto tensorType = llvm::cast<RankedTensorType>(op.getType());
auto memRefType = convertTensorToMemRef(tensorType);
auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
@@ -342,7 +342,7 @@ void ToyToAffineLoweringPass::runOnOperation() {
target.addIllegalDialect<toy::ToyDialect>();
target.addDynamicallyLegalOp<toy::PrintOp>([](toy::PrintOp op) {
return llvm::none_of(op->getOperandTypes(),
- [](Type type) { return type.isa<TensorType>(); });
+ [](Type type) { return llvm::isa<TensorType>(type); });
});
// Now that the conversion target has been defined, we just need to provide
diff --git a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
index 06e509639a918..a10588e51d9b4 100644
--- a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
+++ b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
@@ -61,7 +61,7 @@ class PrintOpLowering : public ConversionPattern {
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto memRefType = (*op->operand_type_begin()).cast<MemRefType>();
+ auto memRefType = llvm::cast<MemRefType>((*op->operand_type_begin()));
auto memRefShape = memRefType.getShape();
auto loc = op->getLoc();
diff --git a/mlir/examples/toy/Ch6/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch6/mlir/ShapeInferencePass.cpp
index cf3e492989b0f..d45baa14ab3e8 100644
--- a/mlir/examples/toy/Ch6/mlir/ShapeInferencePass.cpp
+++ b/mlir/examples/toy/Ch6/mlir/ShapeInferencePass.cpp
@@ -94,7 +94,7 @@ struct ShapeInferencePass
/// operands inferred.
static bool allOperandsInferred(Operation *op) {
return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
- return operandType.isa<RankedTensorType>();
+ return llvm::isa<RankedTensorType>(operandType);
});
}
@@ -102,7 +102,7 @@ struct ShapeInferencePass
/// shaped result.
static bool returnsDynamicShape(Operation *op) {
return llvm::any_of(op->getResultTypes(), [](Type resultType) {
- return !resultType.isa<RankedTensorType>();
+ return !llvm::isa<RankedTensorType>(resultType);
});
}
};
diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
index 643240333d92e..1b77f8ce6d8a4 100644
--- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
@@ -101,7 +101,7 @@ static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser,
// If the type is a function type, it contains the input and result types of
// this operation.
- if (FunctionType funcType = type.dyn_cast<FunctionType>()) {
+ if (FunctionType funcType = llvm::dyn_cast<FunctionType>(type)) {
if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
result.operands))
return mlir::failure();
@@ -179,9 +179,9 @@ void ConstantOp::print(mlir::OpAsmPrinter &printer) {
static mlir::LogicalResult verifyConstantForType(mlir::Type type,
mlir::Attribute opaqueValue,
mlir::Operation *op) {
- if (type.isa<mlir::TensorType>()) {
+ if (llvm::isa<mlir::TensorType>(type)) {
// Check that the value is an elements attribute.
- auto attrValue = opaqueValue.dyn_cast<mlir::DenseFPElementsAttr>();
+ auto attrValue = llvm::dyn_cast<mlir::DenseFPElementsAttr>(opaqueValue);
if (!attrValue)
return op->emitError("constant of TensorType must be initialized by "
"a DenseFPElementsAttr, got ")
@@ -189,13 +189,13 @@ static mlir::LogicalResult verifyConstantForType(mlir::Type type,
// If the return type of the constant is not an unranked tensor, the shape
// must match the shape of the attribute holding the data.
- auto resultType = type.dyn_cast<mlir::RankedTensorType>();
+ auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(type);
if (!resultType)
return success();
// Check that the rank of the attribute type matches the rank of the
// constant result type.
- auto attrType = attrValue.getType().cast<mlir::RankedTensorType>();
+ auto attrType = llvm::cast<mlir::RankedTensorType>(attrValue.getType());
if (attrType.getRank() != resultType.getRank()) {
return op->emitOpError("return type must match the one of the attached "
"value attribute: ")
@@ -213,11 +213,11 @@ static mlir::LogicalResult verifyConstantForType(mlir::Type type,
}
return mlir::success();
}
- auto resultType = type.cast<StructType>();
+ auto resultType = llvm::cast<StructType>(type);
llvm::ArrayRef<mlir::Type> resultElementTypes = resultType.getElementTypes();
// Verify that the initializer is an Array.
- auto attrValue = opaqueValue.dyn_cast<ArrayAttr>();
+ auto attrValue = llvm::dyn_cast<ArrayAttr>(opaqueValue);
if (!attrValue || attrValue.getValue().size() != resultElementTypes.size())
return op->emitError("constant of StructType must be initialized by an "
"ArrayAttr with the same number of elements, got ")
@@ -283,8 +283,8 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
if (inputs.size() != 1 || outputs.size() != 1)
return false;
// The inputs must be Tensors with the same element type.
- TensorType input = inputs.front().dyn_cast<TensorType>();
- TensorType output = outputs.front().dyn_cast<TensorType>();
+ TensorType input = llvm::dyn_cast<TensorType>(inputs.front());
+ TensorType output = llvm::dyn_cast<TensorType>(outputs.front());
if (!input || !output || input.getElementType() != output.getElementType())
return false;
// The shape is required to match if both types are ranked.
@@ -426,8 +426,8 @@ mlir::LogicalResult ReturnOp::verify() {
auto resultType = results.front();
// Check that the result type of the function matches the operand type.
- if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
- resultType.isa<mlir::UnrankedTensorType>())
+ if (inputType == resultType || llvm::isa<mlir::UnrankedTensorType>(inputType) ||
+ llvm::isa<mlir::UnrankedTensorType>(resultType))
return mlir::success();
return emitError() << "type of return operand (" << inputType
@@ -442,7 +442,7 @@ mlir::LogicalResult ReturnOp::verify() {
void StructAccessOp::build(mlir::OpBuilder &b, mlir::OperationState &state,
mlir::Value input, size_t index) {
// Extract the result type from the input type.
- StructType structTy = input.getType().cast<StructType>();
+ StructType structTy = llvm::cast<StructType>(input.getType());
assert(index < structTy.getNumElementTypes());
mlir::Type resultType = structTy.getElementTypes()[index];
@@ -451,7 +451,7 @@ void StructAccessOp::build(mlir::OpBuilder &b, mlir::OperationState &state,
}
mlir::LogicalResult StructAccessOp::verify() {
- StructType structTy = getInput().getType().cast<StructType>();
+ StructType structTy = llvm::cast<StructType>(getInput().getType());
size_t indexValue = getIndex();
if (indexValue >= structTy.getNumElementTypes())
return emitOpError()
@@ -474,14 +474,14 @@ void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
}
void TransposeOp::inferShapes() {
- auto arrayTy = getOperand().getType().cast<RankedTensorType>();
+ auto arrayTy = llvm::cast<RankedTensorType>(getOperand().getType());
SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
}
mlir::LogicalResult TransposeOp::verify() {
- auto inputType = getOperand().getType().dyn_cast<RankedTensorType>();
- auto resultType = getType().dyn_cast<RankedTensorType>();
+ auto inputType = llvm::dyn_cast<RankedTensorType>(getOperand().getType());
+ auto resultType = llvm::dyn_cast<RankedTensorType>(getType());
if (!inputType || !resultType)
return mlir::success();
@@ -598,7 +598,7 @@ mlir::Type ToyDialect::parseType(mlir::DialectAsmParser &parser) const {
return nullptr;
// Check that the type is either a TensorType or another StructType.
- if (!elementType.isa<mlir::TensorType, StructType>()) {
+ if (!llvm::isa<mlir::TensorType, StructType>(elementType)) {
parser.emitError(typeLoc, "element type for a struct must either "
"be a TensorType or a StructType, got: ")
<< elementType;
@@ -619,7 +619,7 @@ mlir::Type ToyDialect::parseType(mlir::DialectAsmParser &parser) const {
void ToyDialect::printType(mlir::Type type,
mlir::DialectAsmPrinter &printer) const {
// Currently the only toy type is a struct type.
- StructType structType = type.cast<StructType>();
+ StructType structType = llvm::cast<StructType>(type);
// Print the struct type according to the parser format.
printer << "struct<";
@@ -653,9 +653,9 @@ mlir::Operation *ToyDialect::materializeConstant(mlir::OpBuilder &builder,
mlir::Attribute value,
mlir::Type type,
mlir::Location loc) {
- if (type.isa<StructType>())
+ if (llvm::isa<StructType>(type))
return builder.create<StructConstantOp>(loc, type,
- value.cast<mlir::ArrayAttr>());
+ llvm::cast<mlir::ArrayAttr>(value));
return builder.create<ConstantOp>(loc, type,
- value.cast<mlir::DenseElementsAttr>());
+ llvm::cast<mlir::DenseElementsAttr>(value));
}
diff --git a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
index 988175508ca1b..fd589ddf84541 100644
--- a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
@@ -62,7 +62,7 @@ using LoopIterationFn = function_ref<Value(
static void lowerOpToLoops(Operation *op, ValueRange operands,
PatternRewriter &rewriter,
LoopIterationFn processIteration) {
- auto tensorType = (*op->result_type_begin()).cast<RankedTensorType>();
+ auto tensorType = llvm::cast<RankedTensorType>(*op->result_type_begin());
auto loc = op->getLoc();
// Insert an allocation and deallocation for the result of this operation.
@@ -144,7 +144,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
// When lowering the constant operation, we allocate and assign the constant
// values to a corresponding memref allocation.
- auto tensorType = op.getType().cast<RankedTensorType>();
+ auto tensorType = llvm::cast<RankedTensorType>(op.getType());
auto memRefType = convertTensorToMemRef(tensorType);
auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
@@ -342,7 +342,7 @@ void ToyToAffineLoweringPass::runOnOperation() {
target.addIllegalDialect<toy::ToyDialect>();
target.addDynamicallyLegalOp<toy::PrintOp>([](toy::PrintOp op) {
return llvm::none_of(op->getOperandTypes(),
- [](Type type) { return type.isa<TensorType>(); });
+ [](Type type) { return llvm::isa<TensorType>(type); });
});
// Now that the conversion target has been defined, we just need to provide
diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
index 06e509639a918..a10588e51d9b4 100644
--- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
+++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
@@ -61,7 +61,7 @@ class PrintOpLowering : public ConversionPattern {
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- auto memRefType = (*op->operand_type_begin()).cast<MemRefType>();
+ auto memRefType = llvm::cast<MemRefType>(*op->operand_type_begin());
auto memRefShape = memRefType.getShape();
auto loc = op->getLoc();
diff --git a/mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp
index cf3e492989b0f..d45baa14ab3e8 100644
--- a/mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp
+++ b/mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp
@@ -94,7 +94,7 @@ struct ShapeInferencePass
/// operands inferred.
static bool allOperandsInferred(Operation *op) {
return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
- return operandType.isa<RankedTensorType>();
+ return llvm::isa<RankedTensorType>(operandType);
});
}
@@ -102,7 +102,7 @@ struct ShapeInferencePass
/// shaped result.
static bool returnsDynamicShape(Operation *op) {
return llvm::any_of(op->getResultTypes(), [](Type resultType) {
- return !resultType.isa<RankedTensorType>();
+ return !llvm::isa<RankedTensorType>(resultType);
});
}
};
diff --git a/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp
index 62b00d99476a0..09be97ea3c015 100644
--- a/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp
+++ b/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp
@@ -31,7 +31,8 @@ OpFoldResult StructConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
/// Fold simple struct access operations that access into a constant.
OpFoldResult StructAccessOp::fold(FoldAdaptor adaptor) {
- auto structAttr = adaptor.getInput().dyn_cast_or_null<mlir::ArrayAttr>();
+ auto structAttr =
+ llvm::dyn_cast_if_present<mlir::ArrayAttr>(adaptor.getInput());
if (!structAttr)
return nullptr;
diff --git a/mlir/include/mlir/Debug/BreakpointManagers/FileLineColLocBreakpointManager.h b/mlir/include/mlir/Debug/BreakpointManagers/FileLineColLocBreakpointManager.h
index 7def9e25a69ae..e62b9c0bc0de5 100644
--- a/mlir/include/mlir/Debug/BreakpointManagers/FileLineColLocBreakpointManager.h
+++ b/mlir/include/mlir/Debug/BreakpointManagers/FileLineColLocBreakpointManager.h
@@ -62,19 +62,19 @@ class FileLineColLocBreakpointManager
public:
Breakpoint *match(const Action &action) const override {
for (const IRUnit &unit : action.getContextIRUnits()) {
- if (auto *op = unit.dyn_cast<Operation *>()) {
+ if (auto *op = llvm::dyn_cast_if_present<Operation *>(unit)) {
if (auto match = matchFromLocation(op->getLoc()))
return *match;
continue;
}
- if (auto *block = unit.dyn_cast<Block *>()) {
+ if (auto *block = llvm::dyn_cast_if_present<Block *>(unit)) {
for (auto &op : block->getOperations()) {
if (auto match = matchFromLocation(op.getLoc()))
return *match;
}
continue;
}
- if (Region *region = unit.dyn_cast<Region *>()) {
+ if (Region *region = llvm::dyn_cast_if_present<Region *>(unit)) {
if (auto match = matchFromLocation(region->getLoc()))
return *match;
continue;
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
index ddef02095c64f..63f18ebb46a93 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
@@ -110,27 +110,27 @@ class MMAMatrixOf<list<Type> allowedTypes> :
"gpu.mma_matrix", "::mlir::gpu::MMAMatrixType">;
// Types for all sparse handles.
-def GPU_SparseEnvHandle :
- DialectType<GPU_Dialect,
- CPred<"$_self.isa<::mlir::gpu::SparseEnvHandleType>()">,
- "sparse environment handle type">,
+def GPU_SparseEnvHandle :
+ DialectType<GPU_Dialect,
+ CPred<"llvm::isa<::mlir::gpu::SparseEnvHandleType>($_self)">,
+ "sparse environment handle type">,
BuildableType<"mlir::gpu::SparseEnvHandleType::get($_builder.getContext())">;
-def GPU_SparseDnVecHandle :
- DialectType<GPU_Dialect,
- CPred<"$_self.isa<::mlir::gpu::SparseDnVecHandleType>()">,
+def GPU_SparseDnVecHandle :
+ DialectType<GPU_Dialect,
+ CPred<"llvm::isa<::mlir::gpu::SparseDnVecHandleType>($_self)">,
"dense vector handle type">,
BuildableType<"mlir::gpu::SparseDnVecHandleType::get($_builder.getContext())">;
-def GPU_SparseDnMatHandle :
- DialectType<GPU_Dialect,
- CPred<"$_self.isa<::mlir::gpu::SparseDnMatHandleType>()">,
+def GPU_SparseDnMatHandle :
+ DialectType<GPU_Dialect,
+ CPred<"llvm::isa<::mlir::gpu::SparseDnMatHandleType>($_self)">,
"dense matrix handle type">,
BuildableType<"mlir::gpu::SparseDnMatHandleType::get($_builder.getContext())">;
-def GPU_SparseSpMatHandle :
- DialectType<GPU_Dialect,
- CPred<"$_self.isa<::mlir::gpu::SparseSpMatHandleType>()">,
+def GPU_SparseSpMatHandle :
+ DialectType<GPU_Dialect,
+ CPred<"llvm::isa<::mlir::gpu::SparseSpMatHandleType>($_self)">,
"sparse matrix handle type">,
BuildableType<"mlir::gpu::SparseSpMatHandleType::get($_builder.getContext())">;
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
index 6f83c053bdd5c..0331f9ff1eb29 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
@@ -95,7 +95,7 @@ def DeclareTargetInterface : OpInterface<"DeclareTargetInterface"> {
/*methodName=*/"getDeclareTargetDeviceType",
(ins), [{}], [{
if (mlir::Attribute dTar = $_op->getAttr("omp.declare_target"))
- if (auto dAttr = dTar.dyn_cast_or_null<mlir::omp::DeclareTargetAttr>())
+ if (auto dAttr = llvm::dyn_cast<mlir::omp::DeclareTargetAttr>(dTar))
return dAttr.getDeviceType().getValue();
return {};
}]>,
@@ -108,7 +108,7 @@ def DeclareTargetInterface : OpInterface<"DeclareTargetInterface"> {
/*methodName=*/"getDeclareTargetCaptureClause",
(ins), [{}], [{
if (mlir::Attribute dTar = $_op->getAttr("omp.declare_target"))
- if (auto dAttr = dTar.dyn_cast_or_null<mlir::omp::DeclareTargetAttr>())
+ if (auto dAttr = llvm::dyn_cast<mlir::omp::DeclareTargetAttr>(dTar))
return dAttr.getCaptureClause().getValue();
return {};
}]>
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 79313b6facda0..acb355654ef71 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -115,7 +115,7 @@ class TensorType : public Type, public ShapedType::Trait<TensorType> {
static bool classof(Type type);
/// Allow implicit conversion to ShapedType.
- operator ShapedType() const { return cast<ShapedType>(); }
+ operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
};
//===----------------------------------------------------------------------===//
@@ -169,7 +169,7 @@ class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
unsigned getMemorySpaceAsInt() const;
/// Allow implicit conversion to ShapedType.
- operator ShapedType() const { return cast<ShapedType>(); }
+ operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
};
} // namespace mlir
diff --git a/mlir/include/mlir/IR/TypeRange.h b/mlir/include/mlir/IR/TypeRange.h
index 37a7c242ee84c..99fabab334f92 100644
--- a/mlir/include/mlir/IR/TypeRange.h
+++ b/mlir/include/mlir/IR/TypeRange.h
@@ -217,13 +217,15 @@ struct DenseMapInfo<mlir::TypeRange> {
}
static bool isEmptyKey(mlir::TypeRange range) {
- if (const auto *type = range.getBase().dyn_cast<const mlir::Type *>())
+ if (const auto *type =
+ llvm::dyn_cast_if_present<const mlir::Type *>(range.getBase()))
return type == getEmptyKeyPointer();
return false;
}
static bool isTombstoneKey(mlir::TypeRange range) {
- if (const auto *type = range.getBase().dyn_cast<const mlir::Type *>())
+ if (const auto *type =
+ llvm::dyn_cast_if_present<const mlir::Type *>(range.getBase()))
return type == getTombstoneKeyPointer();
return false;
}
diff --git a/mlir/include/mlir/Interfaces/SideEffectInterfaces.h b/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
index 306f4cf973181..ac42f38eb05e4 100644
--- a/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
+++ b/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
@@ -163,12 +163,12 @@ class EffectInstance {
/// Return the value the effect is applied on, or nullptr if there isn't a
/// known value being affected.
- Value getValue() const { return value ? value.dyn_cast<Value>() : Value(); }
+ Value getValue() const { return llvm::dyn_cast_if_present<Value>(value); }
/// Return the symbol reference the effect is applied on, or nullptr if there
/// isn't a known smbol being affected.
SymbolRefAttr getSymbolRef() const {
- return value ? value.dyn_cast<SymbolRefAttr>() : SymbolRefAttr();
+ return llvm::dyn_cast_if_present<SymbolRefAttr>(value);
}
/// Return the resource that the effect applies to.
diff --git a/mlir/include/mlir/Pass/AnalysisManager.h b/mlir/include/mlir/Pass/AnalysisManager.h
index 9821a6817466f..f9db26140259b 100644
--- a/mlir/include/mlir/Pass/AnalysisManager.h
+++ b/mlir/include/mlir/Pass/AnalysisManager.h
@@ -254,7 +254,7 @@ struct NestedAnalysisMap {
/// Returns the parent analysis map for this analysis map, or null if this is
/// the top-level map.
const NestedAnalysisMap *getParent() const {
- return parentOrInstrumentor.dyn_cast<NestedAnalysisMap *>();
+ return llvm::dyn_cast_if_present<NestedAnalysisMap *>(parentOrInstrumentor);
}
/// Returns a pass instrumentation object for the current operation. This
diff --git a/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp b/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp
index db25c239204f7..a6d53b2e22119 100644
--- a/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp
@@ -89,7 +89,7 @@ void SparseConstantPropagation::visitOperation(
// Merge in the result of the fold, either a constant or a value.
OpFoldResult foldResult = std::get<1>(it);
- if (Attribute attr = foldResult.dyn_cast<Attribute>()) {
+ if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(foldResult)) {
LLVM_DEBUG(llvm::dbgs() << "Folded to constant: " << attr << "\n");
propagateIfChanged(lattice,
lattice->join(ConstantValue(attr, op->getDialect())));
diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
index 8ff71b59750a8..d681604aaff64 100644
--- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
@@ -31,7 +31,7 @@ void Executable::print(raw_ostream &os) const {
}
void Executable::onUpdate(DataFlowSolver *solver) const {
- if (auto *block = point.dyn_cast<Block *>()) {
+ if (auto *block = llvm::dyn_cast_if_present<Block *>(point)) {
// Re-invoke the analyses on the block itself.
for (DataFlowAnalysis *analysis : subscribers)
solver->enqueue({block, analysis});
@@ -39,7 +39,7 @@ void Executable::onUpdate(DataFlowSolver *solver) const {
for (DataFlowAnalysis *analysis : subscribers)
for (Operation &op : *block)
solver->enqueue({&op, analysis});
- } else if (auto *programPoint = point.dyn_cast<GenericProgramPoint *>()) {
+ } else if (auto *programPoint = llvm::dyn_cast_if_present<GenericProgramPoint *>(point)) {
// Re-invoke the analysis on the successor block.
if (auto *edge = dyn_cast<CFGEdge>(programPoint)) {
for (DataFlowAnalysis *analysis : subscribers)
@@ -219,7 +219,7 @@ void DeadCodeAnalysis::markEntryBlocksLive(Operation *op) {
LogicalResult DeadCodeAnalysis::visit(ProgramPoint point) {
if (point.is<Block *>())
return success();
- auto *op = point.dyn_cast<Operation *>();
+ auto *op = llvm::dyn_cast_if_present<Operation *>(point);
if (!op)
return emitError(point.getLoc(), "unknown program point kind");
diff --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
index 6450891f7dcc7..77ef87d90b75a 100644
--- a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
@@ -33,9 +33,9 @@ LogicalResult AbstractDenseDataFlowAnalysis::initialize(Operation *top) {
}
LogicalResult AbstractDenseDataFlowAnalysis::visit(ProgramPoint point) {
- if (auto *op = point.dyn_cast<Operation *>())
+ if (auto *op = llvm::dyn_cast_if_present<Operation *>(point))
processOperation(op);
- else if (auto *block = point.dyn_cast<Block *>())
+ else if (auto *block = llvm::dyn_cast_if_present<Block *>(point))
visitBlock(block);
else
return failure();
diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
index c866fc610bc8e..c832405e51f64 100644
--- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
@@ -181,7 +181,7 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
if (auto bound =
dyn_cast_or_null<IntegerAttr>(loopBound->get<Attribute>()))
return bound.getValue();
- } else if (auto value = loopBound->dyn_cast<Value>()) {
+ } else if (auto value = llvm::dyn_cast_if_present<Value>(*loopBound)) {
const IntegerValueRangeLattice *lattice =
getLatticeElementFor(op, value);
if (lattice != nullptr)
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index 629c482edab22..f5cf866d0d2a5 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -66,9 +66,9 @@ AbstractSparseDataFlowAnalysis::initializeRecursively(Operation *op) {
}
LogicalResult AbstractSparseDataFlowAnalysis::visit(ProgramPoint point) {
- if (Operation *op = point.dyn_cast<Operation *>())
+ if (Operation *op = llvm::dyn_cast_if_present<Operation *>(point))
visitOperation(op);
- else if (Block *block = point.dyn_cast<Block *>())
+ else if (Block *block = llvm::dyn_cast_if_present<Block *>(point))
visitBlock(block);
else
return failure();
@@ -238,7 +238,7 @@ void AbstractSparseDataFlowAnalysis::visitRegionSuccessors(
unsigned firstIndex = 0;
if (inputs.size() != lattices.size()) {
- if (point.dyn_cast<Operation *>()) {
+ if (llvm::isa_and_present<Operation *>(point)) {
if (!inputs.empty())
firstIndex = cast<OpResult>(inputs.front()).getResultNumber();
visitNonControlFlowArgumentsImpl(
@@ -316,9 +316,9 @@ AbstractSparseBackwardDataFlowAnalysis::initializeRecursively(Operation *op) {
LogicalResult
AbstractSparseBackwardDataFlowAnalysis::visit(ProgramPoint point) {
- if (Operation *op = point.dyn_cast<Operation *>())
+ if (Operation *op = llvm::dyn_cast_if_present<Operation *>(point))
visitOperation(op);
- else if (point.dyn_cast<Block *>())
+ else if (llvm::isa_and_present<Block *>(point))
// For backward dataflow, we don't have to do any work for the blocks
// themselves. CFG edges between blocks are processed by the BranchOp
// logic in `visitOperation`, and entry blocks for functions are tied
diff --git a/mlir/lib/Analysis/DataFlowFramework.cpp b/mlir/lib/Analysis/DataFlowFramework.cpp
index 9c8a8899d8c12..47caf268290ad 100644
--- a/mlir/lib/Analysis/DataFlowFramework.cpp
+++ b/mlir/lib/Analysis/DataFlowFramework.cpp
@@ -39,21 +39,21 @@ void ProgramPoint::print(raw_ostream &os) const {
os << "<NULL POINT>";
return;
}
- if (auto *programPoint = dyn_cast<GenericProgramPoint *>())
+ if (auto *programPoint = llvm::dyn_cast<GenericProgramPoint *>(*this))
return programPoint->print(os);
- if (auto *op = dyn_cast<Operation *>())
+ if (auto *op = llvm::dyn_cast<Operation *>(*this))
return op->print(os);
- if (auto value = dyn_cast<Value>())
+ if (auto value = llvm::dyn_cast<Value>(*this))
return value.print(os);
return get<Block *>()->print(os);
}
Location ProgramPoint::getLoc() const {
- if (auto *programPoint = dyn_cast<GenericProgramPoint *>())
+ if (auto *programPoint = llvm::dyn_cast<GenericProgramPoint *>(*this))
return programPoint->getLoc();
- if (auto *op = dyn_cast<Operation *>())
+ if (auto *op = llvm::dyn_cast<Operation *>(*this))
return op->getLoc();
- if (auto value = dyn_cast<Value>())
+ if (auto value = llvm::dyn_cast<Value>(*this))
return value.getLoc();
return get<Block *>()->getParent()->getLoc();
}
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 75f4d4d607fc0..3b562e013ccbb 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -2060,7 +2060,7 @@ OperationParser::parseTrailingLocationSpecifier(OpOrArgument opOrArgument) {
if (parseToken(Token::r_paren, "expected ')' in location"))
return failure();
- if (auto *op = opOrArgument.dyn_cast<Operation *>())
+ if (auto *op = llvm::dyn_cast_if_present<Operation *>(opOrArgument))
op->setLoc(directLoc);
else
opOrArgument.get<BlockArgument>().setLoc(directLoc);
diff --git a/mlir/lib/CAPI/Interfaces/Interfaces.cpp b/mlir/lib/CAPI/Interfaces/Interfaces.cpp
index 3144a338fa426..d3fd6b4c0b34b 100644
--- a/mlir/lib/CAPI/Interfaces/Interfaces.cpp
+++ b/mlir/lib/CAPI/Interfaces/Interfaces.cpp
@@ -47,7 +47,7 @@ SmallVector<Value> unwrapOperands(intptr_t nOperands, MlirValue *operands) {
DictionaryAttr unwrapAttributes(MlirAttribute attributes) {
DictionaryAttr attributeDict;
if (!mlirAttributeIsNull(attributes))
- attributeDict = unwrap(attributes).cast<DictionaryAttr>();
+ attributeDict = llvm::cast<DictionaryAttr>(unwrap(attributes));
return attributeDict;
}
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 1d1923d4d0c2b..8defd8970b900 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -1190,9 +1190,9 @@ LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite(
// TODO: safer and more flexible to store data type in actual op instead?
static Type getSpMatElemType(Value spMat) {
if (auto op = spMat.getDefiningOp<gpu::CreateCooOp>())
- return op.getValues().getType().cast<MemRefType>().getElementType();
+ return llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
if (auto op = spMat.getDefiningOp<gpu::CreateCsrOp>())
- return op.getValues().getType().cast<MemRefType>().getElementType();
+ return llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
llvm_unreachable("cannot find spmat def");
}
@@ -1235,7 +1235,7 @@ LogicalResult ConvertCreateDnVecOpToGpuRuntimeCallPattern::matchAndRewrite(
MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
if (!getTypeConverter()->useOpaquePointers())
pVec = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pVec);
- Type dType = op.getMemref().getType().cast<MemRefType>().getElementType();
+ Type dType = llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
dType.getIntOrFloatBitWidth());
auto handle =
@@ -1271,7 +1271,7 @@ LogicalResult ConvertCreateDnMatOpToGpuRuntimeCallPattern::matchAndRewrite(
MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
if (!getTypeConverter()->useOpaquePointers())
pMat = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pMat);
- Type dType = op.getMemref().getType().cast<MemRefType>().getElementType();
+ Type dType = llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
dType.getIntOrFloatBitWidth());
auto handle =
@@ -1315,8 +1315,8 @@ LogicalResult ConvertCreateCooOpToGpuRuntimeCallPattern::matchAndRewrite(
pColIdxs = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pColIdxs);
pValues = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pValues);
}
- Type iType = op.getColIdxs().getType().cast<MemRefType>().getElementType();
- Type dType = op.getValues().getType().cast<MemRefType>().getElementType();
+ Type iType = llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
+ Type dType = llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
auto iw = rewriter.create<LLVM::ConstantOp>(
loc, llvmInt32Type, iType.isIndex() ? 64 : iType.getIntOrFloatBitWidth());
auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
@@ -1350,9 +1350,9 @@ LogicalResult ConvertCreateCsrOpToGpuRuntimeCallPattern::matchAndRewrite(
pColIdxs = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pColIdxs);
pValues = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pValues);
}
- Type pType = op.getRowPos().getType().cast<MemRefType>().getElementType();
- Type iType = op.getColIdxs().getType().cast<MemRefType>().getElementType();
- Type dType = op.getValues().getType().cast<MemRefType>().getElementType();
+ Type pType = llvm::cast<MemRefType>(op.getRowPos().getType()).getElementType();
+ Type iType = llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
+ Type dType = llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
auto pw = rewriter.create<LLVM::ConstantOp>(
loc, llvmInt32Type, pType.isIndex() ? 64 : pType.getIntOrFloatBitWidth());
auto iw = rewriter.create<LLVM::ConstantOp>(
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index cf0d5068c02d7..aac6e60b4f50d 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -405,7 +405,7 @@ LLVMTypeConverter::getMemRefAddressSpace(BaseMemRefType type) {
return failure();
if (!(*converted)) // Conversion to default is 0.
return 0;
- if (auto explicitSpace = converted->dyn_cast_or_null<IntegerAttr>())
+ if (auto explicitSpace = llvm::dyn_cast_if_present<IntegerAttr>(*converted))
return explicitSpace.getInt();
return failure();
}
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 24ea1a6eab5c6..f82a86c88efb0 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -671,7 +671,7 @@ struct GlobalMemrefOpLowering
Attribute initialValue = nullptr;
if (!global.isExternal() && !global.isUninitialized()) {
- auto elementsAttr = global.getInitialValue()->cast<ElementsAttr>();
+ auto elementsAttr = llvm::cast<ElementsAttr>(*global.getInitialValue());
initialValue = elementsAttr;
// For scalar memrefs, the global variable created is of the element type,
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
index 8cd180d450eb0..b9a1cc93ffd95 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
@@ -412,10 +412,10 @@ void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
auto *ans = cast<TypeAnswer>(answer);
if (isa<pdl::RangeType>(val.getType()))
builder.create<pdl_interp::CheckTypesOp>(
- loc, val, ans->getValue().cast<ArrayAttr>(), success, failure);
+ loc, val, llvm::cast<ArrayAttr>(ans->getValue()), success, failure);
else
builder.create<pdl_interp::CheckTypeOp>(
- loc, val, ans->getValue().cast<TypeAttr>(), success, failure);
+ loc, val, llvm::cast<TypeAttr>(ans->getValue()), success, failure);
break;
}
case Predicates::AttributeQuestion: {
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 2faf7f1a625ea..9e0cccff6cf99 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -300,7 +300,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);
// tosa::ErfOp
- if (isa<tosa::ErfOp>(op) && elementTy.isa<FloatType>())
+ if (isa<tosa::ErfOp>(op) && llvm::isa<FloatType>(elementTy))
return rewriter.create<mlir::math::ErfOp>(loc, resultTypes, args);
// tosa::GreaterOp
@@ -1885,7 +1885,7 @@ class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
auto addDynamicDimension = [&](Value source, int64_t dim) {
auto dynamicDim = tensor::createDimValue(builder, loc, source, dim);
- if (auto dimValue = dynamicDim.value().dyn_cast<Value>())
+ if (auto dimValue = llvm::dyn_cast_if_present<Value>(dynamicDim.value()))
results.push_back(dimValue);
};
diff --git a/mlir/lib/Debug/DebuggerExecutionContextHook.cpp b/mlir/lib/Debug/DebuggerExecutionContextHook.cpp
index 6cbd7f380c07f..744a0380ec710 100644
--- a/mlir/lib/Debug/DebuggerExecutionContextHook.cpp
+++ b/mlir/lib/Debug/DebuggerExecutionContextHook.cpp
@@ -121,11 +121,11 @@ void mlirDebuggerCursorSelectParentIRUnit() {
return;
}
IRUnit *unit = &state.cursor;
- if (auto *op = unit->dyn_cast<Operation *>()) {
+ if (auto *op = llvm::dyn_cast_if_present<Operation *>(*unit)) {
state.cursor = op->getBlock();
- } else if (auto *region = unit->dyn_cast<Region *>()) {
+ } else if (auto *region = llvm::dyn_cast_if_present<Region *>(*unit)) {
state.cursor = region->getParentOp();
- } else if (auto *block = unit->dyn_cast<Block *>()) {
+ } else if (auto *block = llvm::dyn_cast_if_present<Block *>(*unit)) {
state.cursor = block->getParent();
} else {
llvm::outs() << "Current cursor is not a valid IRUnit";
@@ -142,14 +142,14 @@ void mlirDebuggerCursorSelectChildIRUnit(int index) {
return;
}
IRUnit *unit = &state.cursor;
- if (auto *op = unit->dyn_cast<Operation *>()) {
+ if (auto *op = llvm::dyn_cast_if_present<Operation *>(*unit)) {
if (index < 0 || index >= static_cast<int>(op->getNumRegions())) {
llvm::outs() << "Index invalid, op has " << op->getNumRegions()
<< " but got " << index << "\n";
return;
}
state.cursor = &op->getRegion(index);
- } else if (auto *region = unit->dyn_cast<Region *>()) {
+ } else if (auto *region = llvm::dyn_cast_if_present<Region *>(*unit)) {
auto block = region->begin();
int count = 0;
while (block != region->end() && count != index) {
@@ -163,7 +163,7 @@ void mlirDebuggerCursorSelectChildIRUnit(int index) {
return;
}
state.cursor = &*block;
- } else if (auto *block = unit->dyn_cast<Block *>()) {
+ } else if (auto *block = llvm::dyn_cast_if_present<Block *>(*unit)) {
auto op = block->begin();
int count = 0;
while (op != block->end() && count != index) {
@@ -192,14 +192,14 @@ void mlirDebuggerCursorSelectPreviousIRUnit() {
return;
}
IRUnit *unit = &state.cursor;
- if (auto *op = unit->dyn_cast<Operation *>()) {
+ if (auto *op = llvm::dyn_cast_if_present<Operation *>(*unit)) {
Operation *previous = op->getPrevNode();
if (!previous) {
llvm::outs() << "No previous operation in the current block\n";
return;
}
state.cursor = previous;
- } else if (auto *region = unit->dyn_cast<Region *>()) {
+ } else if (auto *region = llvm::dyn_cast_if_present<Region *>(*unit)) {
llvm::outs() << "Has region\n";
Operation *parent = region->getParentOp();
if (!parent) {
@@ -212,7 +212,7 @@ void mlirDebuggerCursorSelectPreviousIRUnit() {
}
state.cursor =
&region->getParentOp()->getRegion(region->getRegionNumber() - 1);
- } else if (auto *block = unit->dyn_cast<Block *>()) {
+ } else if (auto *block = llvm::dyn_cast_if_present<Block *>(*unit)) {
Block *previous = block->getPrevNode();
if (!previous) {
llvm::outs() << "No previous block in the current region\n";
@@ -234,14 +234,14 @@ void mlirDebuggerCursorSelectNextIRUnit() {
return;
}
IRUnit *unit = &state.cursor;
- if (auto *op = unit->dyn_cast<Operation *>()) {
+ if (auto *op = llvm::dyn_cast_if_present<Operation *>(*unit)) {
Operation *next = op->getNextNode();
if (!next) {
llvm::outs() << "No next operation in the current block\n";
return;
}
state.cursor = next;
- } else if (auto *region = unit->dyn_cast<Region *>()) {
+ } else if (auto *region = llvm::dyn_cast_if_present<Region *>(*unit)) {
Operation *parent = region->getParentOp();
if (!parent) {
llvm::outs() << "No parent operation for the current region\n";
@@ -253,7 +253,7 @@ void mlirDebuggerCursorSelectNextIRUnit() {
}
state.cursor =
&region->getParentOp()->getRegion(region->getRegionNumber() + 1);
- } else if (auto *block = unit->dyn_cast<Block *>()) {
+ } else if (auto *block = llvm::dyn_cast_if_present<Block *>(*unit)) {
Block *next = block->getNextNode();
if (!next) {
llvm::outs() << "No next block in the current region\n";
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 2009c393ff4dc..9153686ab2f52 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -1212,7 +1212,7 @@ static void materializeConstants(OpBuilder &b, Location loc,
actualValues.reserve(values.size());
auto *dialect = b.getContext()->getLoadedDialect<AffineDialect>();
for (OpFoldResult ofr : values) {
- if (auto value = ofr.dyn_cast<Value>()) {
+ if (auto value = llvm::dyn_cast_if_present<Value>(ofr)) {
actualValues.push_back(value);
continue;
}
@@ -4599,7 +4599,7 @@ void AffineDelinearizeIndexOp::build(OpBuilder &builder, OperationState &result,
if (staticDim.has_value())
return builder.create<arith::ConstantIndexOp>(result.location,
*staticDim);
- return ofr.dyn_cast<Value>();
+ return llvm::dyn_cast_if_present<Value>(ofr);
});
result.addOperands(basisValues);
}
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index d0d83c22a8f36..e0dd2d6fbc03b 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -808,7 +808,7 @@ OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {
if (matchPattern(getRhs(), m_Zero()))
return getLhs();
/// or(x, <all ones>) -> <all ones>
- if (auto rhsAttr = adaptor.getRhs().dyn_cast_or_null<IntegerAttr>())
+ if (auto rhsAttr = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs()))
if (rhsAttr.getValue().isAllOnes())
return rhsAttr;
@@ -1249,7 +1249,7 @@ LogicalResult arith::ExtSIOp::verify() {
/// Always fold extension of FP constants.
OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
- auto constOperand = adaptor.getIn().dyn_cast_or_null<FloatAttr>();
+ auto constOperand = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getIn());
if (!constOperand)
return {};
@@ -1702,7 +1702,7 @@ OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
// We are moving constants to the right side; So if lhs is constant rhs is
// guaranteed to be a constant.
- if (auto lhs = adaptor.getLhs().dyn_cast_or_null<TypedAttr>()) {
+ if (auto lhs = llvm::dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
return constFoldBinaryOp<IntegerAttr>(
adaptor.getOperands(), getI1SameShape(lhs.getType()),
[pred = getPredicate()](const APInt &lhs, const APInt &rhs) {
@@ -1772,8 +1772,8 @@ bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
}
OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
- auto lhs = adaptor.getLhs().dyn_cast_or_null<FloatAttr>();
- auto rhs = adaptor.getRhs().dyn_cast_or_null<FloatAttr>();
+ auto lhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
+ auto rhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getRhs());
// If one operand is NaN, making them both NaN does not change the result.
if (lhs && lhs.getValue().isNaN())
@@ -2193,11 +2193,11 @@ OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
// Constant-fold constant operands over non-splat constant condition.
// select %cst_vec, %cst0, %cst1 => %cst2
if (auto cond =
- adaptor.getCondition().dyn_cast_or_null<DenseElementsAttr>()) {
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
if (auto lhs =
- adaptor.getTrueValue().dyn_cast_or_null<DenseElementsAttr>()) {
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
if (auto rhs =
- adaptor.getFalseValue().dyn_cast_or_null<DenseElementsAttr>()) {
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
SmallVector<Attribute> results;
results.reserve(static_cast<size_t>(cond.getNumElements()));
auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index 85e07253c488f..61ec365f3f26b 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -184,7 +184,7 @@ struct SelectOpInterface
// If the buffers have different types, they differ only in their layout
// map.
- auto memrefType = trueType->cast<MemRefType>();
+ auto memrefType = llvm::cast<MemRefType>(*trueType);
return getMemRefTypeWithFullyDynamicLayout(
RankedTensorType::get(memrefType.getShape(),
memrefType.getElementType()),
diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
index fb363c82a069f..965ef117d79ef 100644
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -33,8 +33,8 @@ LogicalResult mlir::foldDynamicIndexList(Builder &b,
if (ofr.is<Attribute>())
continue;
// Newly static, move from Value to constant.
- if (auto cstOp =
- ofr.dyn_cast<Value>().getDefiningOp<arith::ConstantIndexOp>()) {
+ if (auto cstOp = llvm::dyn_cast_if_present<Value>(ofr)
+ .getDefiningOp<arith::ConstantIndexOp>()) {
ofr = b.getIndexAttr(cstOp.value());
valuesChanged = true;
}
@@ -56,9 +56,9 @@ llvm::SmallBitVector mlir::getPositionsOfShapeOne(unsigned rank,
Value mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
OpFoldResult ofr) {
- if (auto value = ofr.dyn_cast<Value>())
+ if (auto value = llvm::dyn_cast_if_present<Value>(ofr))
return value;
- auto attr = dyn_cast<IntegerAttr>(ofr.dyn_cast<Attribute>());
+ auto attr = dyn_cast<IntegerAttr>(llvm::dyn_cast_if_present<Attribute>(ofr));
assert(attr && "expect the op fold result casts to an integer attribute");
return b.create<arith::ConstantIndexOp>(loc, attr.getValue().getSExtValue());
}
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index eeba571658206..fdd6f3de6464a 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -179,7 +179,7 @@ LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
populateDynamicDimSizes(rewriter, loc, copyBuffer, dynamicDims);
}
FailureOr<Value> alloc = options.createAlloc(
- rewriter, loc, allocType->cast<MemRefType>(), dynamicDims);
+ rewriter, loc, llvm::cast<MemRefType>(*allocType), dynamicDims);
if (failed(alloc))
return failure();
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
index be80f30768358..016ec2be62dce 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp
@@ -59,7 +59,8 @@ static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
/// Return the func::FuncOp called by `callOp`.
static func::FuncOp getCalledFunction(CallOpInterface callOp) {
- SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
+ SymbolRefAttr sym =
+ llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
if (!sym)
return nullptr;
return dyn_cast_or_null<func::FuncOp>(
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index f73efc120d377..89904db43508f 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -80,7 +80,7 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
/// Return the FuncOp called by `callOp`.
static FuncOp getCalledFunction(CallOpInterface callOp) {
- SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
+ SymbolRefAttr sym = llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
if (!sym)
return nullptr;
return dyn_cast_or_null<FuncOp>(
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index 5e3519669ec84..34959aad6577b 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -995,7 +995,7 @@ static void annotateOpsWithAliasSets(Operation *op,
op->walk([&](Operation *op) {
SmallVector<Attribute> aliasSets;
for (OpResult opResult : op->getOpResults()) {
- if (opResult.getType().isa<TensorType>()) {
+ if (llvm::isa<TensorType>(opResult.getType())) {
SmallVector<Attribute> aliases;
state.applyOnAliases(opResult, [&](Value alias) {
std::string buffer;
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index d0af1c278c146..417f457c8910c 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -238,7 +238,7 @@ static void removeBufferizationAttributes(BlockArgument bbArg) {
/// Return the func::FuncOp called by `callOp`.
static func::FuncOp getCalledFunction(func::CallOp callOp) {
- SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
+ SymbolRefAttr sym = llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
if (!sym)
return nullptr;
return dyn_cast_or_null<func::FuncOp>(
diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
index 02c0e1643f3c6..f2d1a96fa4a28 100644
--- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
@@ -90,7 +90,8 @@ OpFoldResult CreateOp::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===//
OpFoldResult ImOp::fold(FoldAdaptor adaptor) {
- ArrayAttr arrayAttr = adaptor.getComplex().dyn_cast_or_null<ArrayAttr>();
+ ArrayAttr arrayAttr =
+ llvm::dyn_cast_if_present<ArrayAttr>(adaptor.getComplex());
if (arrayAttr && arrayAttr.size() == 2)
return arrayAttr[1];
if (auto createOp = getOperand().getDefiningOp<CreateOp>())
@@ -103,7 +104,8 @@ OpFoldResult ImOp::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===//
OpFoldResult ReOp::fold(FoldAdaptor adaptor) {
- ArrayAttr arrayAttr = adaptor.getComplex().dyn_cast_or_null<ArrayAttr>();
+ ArrayAttr arrayAttr =
+ llvm::dyn_cast_if_present<ArrayAttr>(adaptor.getComplex());
if (arrayAttr && arrayAttr.size() == 2)
return arrayAttr[0];
if (auto createOp = getOperand().getDefiningOp<CreateOp>())
diff --git a/mlir/lib/Dialect/DLTI/DLTI.cpp b/mlir/lib/Dialect/DLTI/DLTI.cpp
index 3970c9c659a22..aba9e7db0a2fe 100644
--- a/mlir/lib/Dialect/DLTI/DLTI.cpp
+++ b/mlir/lib/Dialect/DLTI/DLTI.cpp
@@ -94,7 +94,7 @@ DataLayoutEntryAttr DataLayoutEntryAttr::parse(AsmParser &parser) {
void DataLayoutEntryAttr::print(AsmPrinter &os) const {
os << DataLayoutEntryAttr::kAttrKeyword << "<";
- if (auto type = getKey().dyn_cast<Type>())
+ if (auto type = llvm::dyn_cast_if_present<Type>(getKey()))
os << type;
else
os << "\"" << getKey().get<StringAttr>().strref() << "\"";
@@ -151,7 +151,7 @@ DataLayoutSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
DenseSet<Type> types;
DenseSet<StringAttr> ids;
for (DataLayoutEntryInterface entry : entries) {
- if (auto type = entry.getKey().dyn_cast<Type>()) {
+ if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
if (!types.insert(type).second)
return emitError() << "repeated layout entry key: " << type;
} else {
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 65a20a0c426b2..06e8d79f68a6e 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -493,7 +493,7 @@ static void destructureIndices(Type currType, ArrayRef<GEPArg> indices,
// error. All other canonicalization is done in the fold method.
bool requiresConst = !rawConstantIndices.empty() &&
currType.isa_and_nonnull<LLVMStructType>();
- if (Value val = iter.dyn_cast<Value>()) {
+ if (Value val = llvm::dyn_cast_if_present<Value>(iter)) {
APInt intC;
if (requiresConst && matchPattern(val, m_ConstantInt(&intC)) &&
intC.isSignedIntN(kGEPConstantBitWidth)) {
@@ -598,7 +598,7 @@ static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp,
llvm::interleaveComma(
GEPIndicesAdaptor<OperandRange>(rawConstantIndices, indices), printer,
[&](PointerUnion<IntegerAttr, Value> cst) {
- if (Value val = cst.dyn_cast<Value>())
+ if (Value val = llvm::dyn_cast_if_present<Value>(cst))
printer.printOperand(val);
else
printer << cst.get<IntegerAttr>().getInt();
@@ -2495,7 +2495,7 @@ OpFoldResult LLVM::GEPOp::fold(FoldAdaptor adaptor) {
!integer.getValue().isSignedIntN(kGEPConstantBitWidth)) {
PointerUnion<IntegerAttr, Value> existing = getIndices()[iter.index()];
- if (Value val = existing.dyn_cast<Value>())
+ if (Value val = llvm::dyn_cast_if_present<Value>(existing))
gepArgs.emplace_back(val);
else
gepArgs.emplace_back(existing.get<IntegerAttr>().getInt());
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index 89693ec69cf6c..b36ed78d6685c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -261,7 +261,7 @@ DeletionKind LLVM::DbgDeclareOp::removeBlockingUses(
static bool hasAllZeroIndices(LLVM::GEPOp gepOp) {
return llvm::all_of(gepOp.getIndices(), [](auto index) {
- auto indexAttr = index.template dyn_cast<IntegerAttr>();
+ auto indexAttr = llvm::dyn_cast_if_present<IntegerAttr>(index);
return indexAttr && indexAttr.getValue() == 0;
});
}
@@ -289,7 +289,7 @@ static Type computeReachedGEPType(LLVM::GEPOp gep) {
// Ensures all indices are static and fetches them.
SmallVector<IntegerAttr> indices;
for (auto index : gep.getIndices()) {
- IntegerAttr indexInt = index.dyn_cast<IntegerAttr>();
+ IntegerAttr indexInt = llvm::dyn_cast_if_present<IntegerAttr>(index);
if (!indexInt)
return {};
indices.push_back(indexInt);
@@ -310,7 +310,7 @@ static Type computeReachedGEPType(LLVM::GEPOp gep) {
for (IntegerAttr index : llvm::drop_begin(indices)) {
// Ensure the structure of the type being indexed can be reasoned about.
// This includes rejecting any potential typed pointer.
- auto destructurable = selectedType.dyn_cast<DestructurableTypeInterface>();
+ auto destructurable = llvm::dyn_cast<DestructurableTypeInterface>(selectedType);
if (!destructurable)
return {};
@@ -343,7 +343,7 @@ LogicalResult LLVM::GEPOp::ensureOnlySafeAccesses(
bool LLVM::GEPOp::canRewire(const DestructurableMemorySlot &slot,
SmallPtrSetImpl<Attribute> &usedIndices,
SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
- auto basePtrType = getBase().getType().dyn_cast<LLVM::LLVMPointerType>();
+ auto basePtrType = llvm::dyn_cast<LLVM::LLVMPointerType>(getBase().getType());
if (!basePtrType)
return false;
@@ -359,7 +359,7 @@ bool LLVM::GEPOp::canRewire(const DestructurableMemorySlot &slot,
return false;
auto firstLevelIndex = cast<IntegerAttr>(getIndices()[1]);
assert(slot.elementPtrs.contains(firstLevelIndex));
- if (!slot.elementPtrs.at(firstLevelIndex).isa<LLVM::LLVMPointerType>())
+ if (!llvm::isa<LLVM::LLVMPointerType>(slot.elementPtrs.at(firstLevelIndex)))
return false;
mustBeSafelyUsed.emplace_back<MemorySlot>({getResult(), reachedType});
usedIndices.insert(firstLevelIndex);
@@ -369,7 +369,7 @@ bool LLVM::GEPOp::canRewire(const DestructurableMemorySlot &slot,
DeletionKind LLVM::GEPOp::rewire(const DestructurableMemorySlot &slot,
DenseMap<Attribute, MemorySlot> &subslots,
RewriterBase &rewriter) {
- IntegerAttr firstLevelIndex = getIndices()[1].dyn_cast<IntegerAttr>();
+ IntegerAttr firstLevelIndex = llvm::dyn_cast_if_present<IntegerAttr>(getIndices()[1]);
const MemorySlot &newSlot = subslots.at(firstLevelIndex);
ArrayRef<int32_t> remainingIndices = getRawConstantIndices().slice(2);
@@ -414,7 +414,7 @@ LLVM::LLVMStructType::getSubelementIndexMap() {
}
Type LLVM::LLVMStructType::getTypeAtIndex(Attribute index) {
- auto indexAttr = index.dyn_cast<IntegerAttr>();
+ auto indexAttr = llvm::dyn_cast<IntegerAttr>(index);
if (!indexAttr || !indexAttr.getType().isInteger(32))
return {};
int32_t indexInt = indexAttr.getInt();
@@ -439,7 +439,7 @@ LLVM::LLVMArrayType::getSubelementIndexMap() const {
}
Type LLVM::LLVMArrayType::getTypeAtIndex(Attribute index) const {
- auto indexAttr = index.dyn_cast<IntegerAttr>();
+ auto indexAttr = llvm::dyn_cast<IntegerAttr>(index);
if (!indexAttr || !indexAttr.getType().isInteger(32))
return {};
int32_t indexInt = indexAttr.getInt();
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index dcbfbf32a1487..be129ffe2aadc 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -354,7 +354,7 @@ bool LLVMPointerType::areCompatible(DataLayoutEntryListRef oldLayout,
auto newType = llvm::cast<LLVMPointerType>(newEntry.getKey().get<Type>());
const auto *it =
llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
- if (auto type = entry.getKey().dyn_cast<Type>()) {
+ if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
return llvm::cast<LLVMPointerType>(type).getAddressSpace() ==
newType.getAddressSpace();
}
@@ -362,7 +362,7 @@ bool LLVMPointerType::areCompatible(DataLayoutEntryListRef oldLayout,
});
if (it == oldLayout.end()) {
llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
- if (auto type = entry.getKey().dyn_cast<Type>()) {
+ if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
return llvm::cast<LLVMPointerType>(type).getAddressSpace() == 0;
}
return false;
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index a63f6647cf7b3..52699db910461 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2368,7 +2368,7 @@ transform::TileOp::apply(TransformResults &transformResults,
sizes.reserve(tileSizes.size());
unsigned dynamicIdx = 0;
for (OpFoldResult ofr : getMixedSizes()) {
- if (auto attr = ofr.dyn_cast<Attribute>()) {
+ if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
sizes.push_back(b.create<arith::ConstantIndexOp>(
getLoc(), cast<IntegerAttr>(attr).getInt()));
continue;
@@ -2794,7 +2794,7 @@ transform::TileToScfForOp::apply(TransformResults &transformResults,
sizes.reserve(tileSizes.size());
unsigned dynamicIdx = 0;
for (OpFoldResult ofr : getMixedSizes()) {
- if (auto attr = ofr.dyn_cast<Attribute>()) {
+ if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
sizes.push_back(b.create<arith::ConstantIndexOp>(
getLoc(), cast<IntegerAttr>(attr).getInt()));
} else {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 33ff4a3ecc091..6a800578f4ac8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1447,7 +1447,7 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
cast<LinalgOp>(genericOp.getOperation())
.createLoopRanges(rewriter, genericOp.getLoc());
auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
- if (auto attr = ofr.dyn_cast<Attribute>())
+ if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
return cast<IntegerAttr>(attr).getInt() == value;
llvm::APInt actual;
return matchPattern(ofr.get<Value>(), m_ConstantInt(&actual)) &&
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index d39cd0e686e00..e952f9437ce8f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -229,7 +229,7 @@ FailureOr<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer(
// to look for the bound.
LLVM_DEBUG(llvm::dbgs() << "Extract tightest: " << rangeValue.size << "\n");
Value size;
- if (auto attr = rangeValue.size.dyn_cast<Attribute>()) {
+ if (auto attr = llvm::dyn_cast_if_present<Attribute>(rangeValue.size)) {
size = getValueOrCreateConstantIndexOp(b, loc, rangeValue.size);
} else {
Value materializedSize =
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
index 203ae437a2a5a..bbe3a542f66b8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
@@ -92,7 +92,7 @@ linalg::splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension,
rewriter, op.getLoc(), d0 + d1 - d2,
{iterationSpace[dimension].offset, iterationSpace[dimension].size,
minSplitPoint});
- if (auto attr = remainingSize.dyn_cast<Attribute>()) {
+ if (auto attr = llvm::dyn_cast_if_present<Attribute>(remainingSize)) {
if (cast<IntegerAttr>(attr).getValue().isZero())
return {op, TilingInterface()};
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 1293d03f15bfc..5ef34b198d789 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -48,7 +48,7 @@ using namespace mlir::scf;
static bool isZero(OpFoldResult v) {
if (!v)
return false;
- if (auto attr = v.dyn_cast<Attribute>()) {
+ if (auto attr = llvm::dyn_cast_if_present<Attribute>(v)) {
IntegerAttr intAttr = dyn_cast<IntegerAttr>(attr);
return intAttr && intAttr.getValue().isZero();
}
@@ -104,7 +104,7 @@ void mlir::linalg::transformIndexOps(
/// checked at runtime.
static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b,
OpFoldResult value) {
- if (auto attr = value.dyn_cast<Attribute>()) {
+ if (auto attr = llvm::dyn_cast_if_present<Attribute>(value)) {
assert(cast<IntegerAttr>(attr).getValue().isStrictlyPositive() &&
"expected strictly positive tile size and divisor");
return;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 6f932bd9fead8..4dceab36f4501 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1135,7 +1135,7 @@ GeneralizePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
PatternRewriter &rewriter) const {
// Given an OpFoldResult, return an index-typed value.
auto getIdxValue = [&](OpFoldResult ofr) {
- if (auto val = ofr.dyn_cast<Value>())
+ if (auto val = llvm::dyn_cast_if_present<Value>(ofr))
return val;
return rewriter
.create<arith::ConstantIndexOp>(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index aae36035eeece..d081e1a90a0d2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1646,7 +1646,7 @@ static SmallVector<Value> ofrToIndexValues(RewriterBase &rewriter, Location loc,
ArrayRef<OpFoldResult> ofrs) {
SmallVector<Value> result;
for (auto o : ofrs) {
- if (auto val = o.template dyn_cast<Value>()) {
+ if (auto val = llvm::dyn_cast_if_present<Value>(o)) {
result.push_back(val);
} else {
result.push_back(rewriter.create<arith::ConstantIndexOp>(
@@ -1954,8 +1954,8 @@ struct PadOpVectorizationWithTransferWritePattern
continue;
// Other cases: Take a deeper look at defining ops of values.
- auto v1 = size1.dyn_cast<Value>();
- auto v2 = size2.dyn_cast<Value>();
+ auto v1 = llvm::dyn_cast_if_present<Value>(size1);
+ auto v2 = llvm::dyn_cast_if_present<Value>(size2);
if (!v1 || !v2)
return false;
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index ef31668ed25b1..d5eea24534982 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -970,7 +970,7 @@ getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
auto dim = it.index();
auto size = it.value();
curr.push_back(dim);
- auto attr = size.dyn_cast<Attribute>();
+ auto attr = llvm::dyn_cast_if_present<Attribute>(size);
if (attr && cast<IntegerAttr>(attr).getInt() == 1)
continue;
reassociation.emplace_back(ReassociationIndices{});
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
index 0474d1923e086..faffe9a5969eb 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
@@ -64,7 +64,7 @@ static void walkIndicesAsAttr(MLIRContext *ctx, ArrayRef<int64_t> shape,
//===----------------------------------------------------------------------===//
static bool isSupportedElementType(Type type) {
- return type.isa<MemRefType>() ||
+ return llvm::isa<MemRefType>(type) ||
OpBuilder(type.getContext()).getZeroAttr(type);
}
@@ -110,7 +110,7 @@ void memref::AllocaOp::handleBlockArgument(const MemorySlot &slot,
SmallVector<DestructurableMemorySlot>
memref::AllocaOp::getDestructurableSlots() {
MemRefType memrefType = getType();
- auto destructurable = memrefType.dyn_cast<DestructurableTypeInterface>();
+ auto destructurable = llvm::dyn_cast<DestructurableTypeInterface>(memrefType);
if (!destructurable)
return {};
@@ -134,7 +134,7 @@ memref::AllocaOp::destructure(const DestructurableMemorySlot &slot,
DenseMap<Attribute, MemorySlot> slotMap;
- auto memrefType = getType().cast<DestructurableTypeInterface>();
+ auto memrefType = llvm::cast<DestructurableTypeInterface>(getType());
for (Attribute usedIndex : usedIndices) {
Type elemType = memrefType.getTypeAtIndex(usedIndex);
MemRefType elemPtr = MemRefType::get({}, elemType);
@@ -281,7 +281,7 @@ struct MemRefDestructurableTypeExternalModel
MemRefDestructurableTypeExternalModel, MemRefType> {
std::optional<DenseMap<Attribute, Type>>
getSubelementIndexMap(Type type) const {
- auto memrefType = type.cast<MemRefType>();
+ auto memrefType = llvm::cast<MemRefType>(type);
constexpr int64_t maxMemrefSizeForDestructuring = 16;
if (!memrefType.hasStaticShape() ||
memrefType.getNumElements() > maxMemrefSizeForDestructuring ||
@@ -298,15 +298,15 @@ struct MemRefDestructurableTypeExternalModel
}
Type getTypeAtIndex(Type type, Attribute index) const {
- auto memrefType = type.cast<MemRefType>();
- auto coordArrAttr = index.dyn_cast<ArrayAttr>();
+ auto memrefType = llvm::cast<MemRefType>(type);
+ auto coordArrAttr = llvm::dyn_cast<ArrayAttr>(index);
if (!coordArrAttr || coordArrAttr.size() != memrefType.getShape().size())
return {};
Type indexType = IndexType::get(memrefType.getContext());
for (const auto &[coordAttr, dimSize] :
llvm::zip(coordArrAttr, memrefType.getShape())) {
- auto coord = coordAttr.dyn_cast<IntegerAttr>();
+ auto coord = llvm::dyn_cast<IntegerAttr>(coordAttr);
if (!coord || coord.getType() != indexType || coord.getInt() < 0 ||
coord.getInt() >= dimSize)
return {};
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 3beda2cafff12..8ed790c421a4c 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -970,7 +970,7 @@ computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
return unusedDims;
for (const auto &dim : llvm::enumerate(sizes))
- if (auto attr = dim.value().dyn_cast<Attribute>())
+ if (auto attr = llvm::dyn_cast_if_present<Attribute>(dim.value()))
if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
unusedDims.set(dim.index());
@@ -1042,7 +1042,7 @@ llvm::SmallBitVector SubViewOp::getDroppedDims() {
OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
// All forms of folding require a known index.
- auto index = adaptor.getIndex().dyn_cast_or_null<IntegerAttr>();
+ auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
if (!index)
return {};
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
index 9b1d85b290274..431d270b0a2cb 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
@@ -56,7 +56,7 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
// Because we only support input strides of 1, the output stride is also
// always 1.
if (llvm::all_of(strides, [](OpFoldResult &valueOrAttr) {
- Attribute attr = valueOrAttr.dyn_cast<Attribute>();
+ Attribute attr = llvm::dyn_cast_if_present<Attribute>(valueOrAttr);
return attr && cast<IntegerAttr>(attr).getInt() == 1;
})) {
strides = SmallVector<OpFoldResult>(sourceOp.getMixedStrides().size(),
@@ -86,8 +86,9 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
}
sizes.push_back(opSize);
- Attribute opOffsetAttr = opOffset.dyn_cast<Attribute>(),
- sourceOffsetAttr = sourceOffset.dyn_cast<Attribute>();
+ Attribute opOffsetAttr = llvm::dyn_cast_if_present<Attribute>(opOffset),
+ sourceOffsetAttr =
+ llvm::dyn_cast_if_present<Attribute>(sourceOffset);
if (opOffsetAttr && sourceOffsetAttr) {
// If both offsets are static we can simply calculate the combined
@@ -101,7 +102,7 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
AffineExpr expr = rewriter.getAffineConstantExpr(0);
SmallVector<Value> affineApplyOperands;
for (auto valueOrAttr : {opOffset, sourceOffset}) {
- if (auto attr = valueOrAttr.dyn_cast<Attribute>()) {
+ if (auto attr = llvm::dyn_cast_if_present<Attribute>(valueOrAttr)) {
expr = expr + cast<IntegerAttr>(attr).getInt();
} else {
expr =
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 15be4d51f5388..fab270a6f1730 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -520,7 +520,7 @@ checkSymOperandList(Operation *op, std::optional<mlir::ArrayAttr> attributes,
<< operandName << " operand appears more than once";
mlir::Type varType = operand.getType();
- auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>();
+ auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
auto decl = SymbolTable::lookupNearestSymbolFrom<Op>(op, symbolRef);
if (!decl)
return op->emitOpError()
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 65cca0e494209..4e3d899c062c6 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -802,10 +802,10 @@ static LogicalResult verifyMapClause(Operation *op, OperandRange map_operands,
for (const auto &mapTypeOp : *map_types) {
int64_t mapTypeBits = 0x00;
- if (!mapTypeOp.isa<mlir::IntegerAttr>())
+ if (!llvm::isa<mlir::IntegerAttr>(mapTypeOp))
return failure();
- mapTypeBits = mapTypeOp.cast<mlir::IntegerAttr>().getInt();
+ mapTypeBits = llvm::cast<mlir::IntegerAttr>(mapTypeOp).getInt();
bool to =
bitAnd(mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 4b0d0e40740f0..b6cb8c7f438e6 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -381,7 +381,7 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
// map.
auto yieldedRanked = cast<MemRefType>(yieldedValueBufferType);
#ifndef NDEBUG
- auto iterRanked = initArgBufferType->cast<MemRefType>();
+ auto iterRanked = llvm::cast<MemRefType>(*initArgBufferType);
assert(llvm::equal(yieldedRanked.getShape(), iterRanked.getShape()) &&
"expected same shape");
assert(yieldedRanked.getMemorySpace() == iterRanked.getMemorySpace() &&
@@ -802,7 +802,7 @@ struct WhileOpInterface
if (!isa<TensorType>(bbArg.getType()))
return bbArg.getType();
// TODO: error handling
- return bufferization::getBufferType(bbArg, options)->cast<Type>();
+ return llvm::cast<Type>(*bufferization::getBufferType(bbArg, options));
}));
// Construct a new scf.while op with memref instead of tensor values.
diff --git a/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp b/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp
index 11df319fffc32..89dfb61d48af9 100644
--- a/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp
@@ -88,10 +88,10 @@ LogicalResult scf::addLoopRangeConstraints(FlatAffineValueConstraints &cstr,
return failure();
unsigned dimIv = cstr.appendDimVar(iv);
- auto lbv = lb.dyn_cast<Value>();
+ auto lbv = llvm::dyn_cast_if_present<Value>(lb);
unsigned symLb =
lbv ? cstr.appendSymbolVar(lbv) : cstr.appendSymbolVar(/*num=*/1);
- auto ubv = ub.dyn_cast<Value>();
+ auto ubv = llvm::dyn_cast_if_present<Value>(ub);
unsigned symUb =
ubv ? cstr.appendSymbolVar(ubv) : cstr.appendSymbolVar(/*num=*/1);
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 164e2e024423c..cb4ae4efe6062 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -152,7 +152,7 @@ OpFoldResult spirv::CompositeExtractOp::fold(FoldAdaptor adaptor) {
auto type = llvm::cast<spirv::CompositeType>(constructOp.getType());
if (getIndices().size() == 1 &&
constructOp.getConstituents().size() == type.getNumElements()) {
- auto i = getIndices().begin()->cast<IntegerAttr>();
+ auto i = llvm::cast<IntegerAttr>(*getIndices().begin());
return constructOp.getConstituents()[i.getValue().getSExtValue()];
}
}
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 16737560362d1..6747f75f468a0 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1562,8 +1562,8 @@ LogicalResult spirv::BitcastOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult spirv::ConvertPtrToUOp::verify() {
- auto operandType = getPointer().getType().cast<spirv::PointerType>();
- auto resultType = getResult().getType().cast<spirv::ScalarType>();
+ auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
+ auto resultType = llvm::cast<spirv::ScalarType>(getResult().getType());
if (!resultType || !resultType.isSignlessInteger())
return emitError("result must be a scalar type of unsigned integer");
auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>();
@@ -1583,8 +1583,8 @@ LogicalResult spirv::ConvertPtrToUOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult spirv::ConvertUToPtrOp::verify() {
- auto operandType = getOperand().getType().cast<spirv::ScalarType>();
- auto resultType = getResult().getType().cast<spirv::PointerType>();
+ auto operandType = llvm::cast<spirv::ScalarType>(getOperand().getType());
+ auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
if (!operandType || !operandType.isSignlessInteger())
return emitError("result must be a scalar type of unsigned integer");
auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>();
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index fbcc5d84a2701..30fc3e1d11bb1 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -125,23 +125,23 @@ Type CompositeType::getElementType(unsigned index) const {
}
unsigned CompositeType::getNumElements() const {
- if (auto arrayType = dyn_cast<ArrayType>())
+ if (auto arrayType = llvm::dyn_cast<ArrayType>(*this))
return arrayType.getNumElements();
- if (auto matrixType = dyn_cast<MatrixType>())
+ if (auto matrixType = llvm::dyn_cast<MatrixType>(*this))
return matrixType.getNumColumns();
- if (auto structType = dyn_cast<StructType>())
+ if (auto structType = llvm::dyn_cast<StructType>(*this))
return structType.getNumElements();
- if (auto vectorType = dyn_cast<VectorType>())
+ if (auto vectorType = llvm::dyn_cast<VectorType>(*this))
return vectorType.getNumElements();
- if (isa<CooperativeMatrixNVType>()) {
+ if (llvm::isa<CooperativeMatrixNVType>(*this)) {
llvm_unreachable(
"invalid to query number of elements of spirv::CooperativeMatrix type");
}
- if (isa<JointMatrixINTELType>()) {
+ if (llvm::isa<JointMatrixINTELType>(*this)) {
llvm_unreachable(
"invalid to query number of elements of spirv::JointMatrix type");
}
- if (isa<RuntimeArrayType>()) {
+ if (llvm::isa<RuntimeArrayType>(*this)) {
llvm_unreachable(
"invalid to query number of elements of spirv::RuntimeArray type");
}
@@ -149,8 +149,8 @@ unsigned CompositeType::getNumElements() const {
}
bool CompositeType::hasCompileTimeKnownNumElements() const {
- return !isa<CooperativeMatrixNVType, JointMatrixINTELType,
- RuntimeArrayType>();
+ return !llvm::isa<CooperativeMatrixNVType, JointMatrixINTELType,
+ RuntimeArrayType>(*this);
}
void CompositeType::getExtensions(
@@ -188,11 +188,11 @@ void CompositeType::getCapabilities(
}
std::optional<int64_t> CompositeType::getSizeInBytes() {
- if (auto arrayType = dyn_cast<ArrayType>())
+ if (auto arrayType = llvm::dyn_cast<ArrayType>(*this))
return arrayType.getSizeInBytes();
- if (auto structType = dyn_cast<StructType>())
+ if (auto structType = llvm::dyn_cast<StructType>(*this))
return structType.getSizeInBytes();
- if (auto vectorType = dyn_cast<VectorType>()) {
+ if (auto vectorType = llvm::dyn_cast<VectorType>(*this)) {
std::optional<int64_t> elementSize =
llvm::cast<ScalarType>(vectorType.getElementType()).getSizeInBytes();
if (!elementSize)
@@ -680,7 +680,7 @@ void ScalarType::getCapabilities(
capabilities.push_back(ref); \
} break
- if (auto intType = dyn_cast<IntegerType>()) {
+ if (auto intType = llvm::dyn_cast<IntegerType>(*this)) {
switch (bitwidth) {
WIDTH_CASE(Int, 8);
WIDTH_CASE(Int, 16);
@@ -692,7 +692,7 @@ void ScalarType::getCapabilities(
llvm_unreachable("invalid bitwidth to getCapabilities");
}
} else {
- assert(isa<FloatType>());
+ assert(llvm::isa<FloatType>(*this));
switch (bitwidth) {
WIDTH_CASE(Float, 16);
WIDTH_CASE(Float, 64);
@@ -735,22 +735,22 @@ bool SPIRVType::classof(Type type) {
}
bool SPIRVType::isScalarOrVector() {
- return isIntOrFloat() || isa<VectorType>();
+ return isIntOrFloat() || llvm::isa<VectorType>(*this);
}
void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage) {
- if (auto scalarType = dyn_cast<ScalarType>()) {
+ if (auto scalarType = llvm::dyn_cast<ScalarType>(*this)) {
scalarType.getExtensions(extensions, storage);
- } else if (auto compositeType = dyn_cast<CompositeType>()) {
+ } else if (auto compositeType = llvm::dyn_cast<CompositeType>(*this)) {
compositeType.getExtensions(extensions, storage);
- } else if (auto imageType = dyn_cast<ImageType>()) {
+ } else if (auto imageType = llvm::dyn_cast<ImageType>(*this)) {
imageType.getExtensions(extensions, storage);
- } else if (auto sampledImageType = dyn_cast<SampledImageType>()) {
+ } else if (auto sampledImageType = llvm::dyn_cast<SampledImageType>(*this)) {
sampledImageType.getExtensions(extensions, storage);
- } else if (auto matrixType = dyn_cast<MatrixType>()) {
+ } else if (auto matrixType = llvm::dyn_cast<MatrixType>(*this)) {
matrixType.getExtensions(extensions, storage);
- } else if (auto ptrType = dyn_cast<PointerType>()) {
+ } else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) {
ptrType.getExtensions(extensions, storage);
} else {
llvm_unreachable("invalid SPIR-V Type to getExtensions");
@@ -760,17 +760,17 @@ void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
void SPIRVType::getCapabilities(
SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage) {
- if (auto scalarType = dyn_cast<ScalarType>()) {
+ if (auto scalarType = llvm::dyn_cast<ScalarType>(*this)) {
scalarType.getCapabilities(capabilities, storage);
- } else if (auto compositeType = dyn_cast<CompositeType>()) {
+ } else if (auto compositeType = llvm::dyn_cast<CompositeType>(*this)) {
compositeType.getCapabilities(capabilities, storage);
- } else if (auto imageType = dyn_cast<ImageType>()) {
+ } else if (auto imageType = llvm::dyn_cast<ImageType>(*this)) {
imageType.getCapabilities(capabilities, storage);
- } else if (auto sampledImageType = dyn_cast<SampledImageType>()) {
+ } else if (auto sampledImageType = llvm::dyn_cast<SampledImageType>(*this)) {
sampledImageType.getCapabilities(capabilities, storage);
- } else if (auto matrixType = dyn_cast<MatrixType>()) {
+ } else if (auto matrixType = llvm::dyn_cast<MatrixType>(*this)) {
matrixType.getCapabilities(capabilities, storage);
- } else if (auto ptrType = dyn_cast<PointerType>()) {
+ } else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) {
ptrType.getCapabilities(capabilities, storage);
} else {
llvm_unreachable("invalid SPIR-V Type to getCapabilities");
@@ -778,9 +778,9 @@ void SPIRVType::getCapabilities(
}
std::optional<int64_t> SPIRVType::getSizeInBytes() {
- if (auto scalarType = dyn_cast<ScalarType>())
+ if (auto scalarType = llvm::dyn_cast<ScalarType>(*this))
return scalarType.getSizeInBytes();
- if (auto compositeType = dyn_cast<CompositeType>())
+ if (auto compositeType = llvm::dyn_cast<CompositeType>(*this))
return compositeType.getSizeInBytes();
return std::nullopt;
}
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 58d0e6aa39644..b1dffbf98920d 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -856,9 +856,9 @@ OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
if (!adaptor.getLhs() || !adaptor.getRhs())
return nullptr;
auto lhsShape = llvm::to_vector<6>(
- adaptor.getLhs().cast<DenseIntElementsAttr>().getValues<int64_t>());
+ llvm::cast<DenseIntElementsAttr>(adaptor.getLhs()).getValues<int64_t>());
auto rhsShape = llvm::to_vector<6>(
- adaptor.getRhs().cast<DenseIntElementsAttr>().getValues<int64_t>());
+ llvm::cast<DenseIntElementsAttr>(adaptor.getRhs()).getValues<int64_t>());
SmallVector<int64_t, 6> resultShape;
resultShape.append(lhsShape.begin(), lhsShape.end());
resultShape.append(rhsShape.begin(), rhsShape.end());
@@ -989,7 +989,7 @@ OpFoldResult CstrBroadcastableOp::fold(FoldAdaptor adaptor) {
if (!operand)
return false;
extents.push_back(llvm::to_vector<6>(
- operand.cast<DenseIntElementsAttr>().getValues<int64_t>()));
+ llvm::cast<DenseIntElementsAttr>(operand).getValues<int64_t>()));
}
return OpTrait::util::staticallyKnownBroadcastable(extents);
}())
@@ -1132,10 +1132,10 @@ LogicalResult mlir::shape::DimOp::verify() {
//===----------------------------------------------------------------------===//
OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
- auto lhs = adaptor.getLhs().dyn_cast_or_null<IntegerAttr>();
+ auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
if (!lhs)
return nullptr;
- auto rhs = adaptor.getRhs().dyn_cast_or_null<IntegerAttr>();
+ auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
if (!rhs)
return nullptr;
@@ -1346,7 +1346,7 @@ std::optional<int64_t> GetExtentOp::getConstantDim() {
}
OpFoldResult GetExtentOp::fold(FoldAdaptor adaptor) {
- auto elements = adaptor.getShape().dyn_cast_or_null<DenseIntElementsAttr>();
+ auto elements = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape());
if (!elements)
return nullptr;
std::optional<int64_t> dim = getConstantDim();
@@ -1490,7 +1490,7 @@ bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
//===----------------------------------------------------------------------===//
OpFoldResult shape::RankOp::fold(FoldAdaptor adaptor) {
- auto shape = adaptor.getShape().dyn_cast_or_null<DenseIntElementsAttr>();
+ auto shape = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape());
if (!shape)
return {};
int64_t rank = shape.getNumElements();
@@ -1671,10 +1671,10 @@ bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
//===----------------------------------------------------------------------===//
OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
- auto lhs = adaptor.getLhs().dyn_cast_or_null<IntegerAttr>();
+ auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
if (!lhs)
return nullptr;
- auto rhs = adaptor.getRhs().dyn_cast_or_null<IntegerAttr>();
+ auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
if (!rhs)
return nullptr;
APInt folded = lhs.getValue() * rhs.getValue();
@@ -1864,9 +1864,9 @@ LogicalResult SplitAtOp::fold(FoldAdaptor adaptor,
if (!adaptor.getOperand() || !adaptor.getIndex())
return failure();
auto shapeVec = llvm::to_vector<6>(
- adaptor.getOperand().cast<DenseIntElementsAttr>().getValues<int64_t>());
+ llvm::cast<DenseIntElementsAttr>(adaptor.getOperand()).getValues<int64_t>());
auto shape = llvm::ArrayRef(shapeVec);
- auto splitPoint = adaptor.getIndex().cast<IntegerAttr>().getInt();
+ auto splitPoint = llvm::cast<IntegerAttr>(adaptor.getIndex()).getInt();
// Verify that the split point is in the correct range.
// TODO: Constant fold to an "error".
int64_t rank = shape.size();
@@ -1889,7 +1889,7 @@ OpFoldResult ToExtentTensorOp::fold(FoldAdaptor adaptor) {
return OpFoldResult();
Builder builder(getContext());
auto shape = llvm::to_vector<6>(
- adaptor.getInput().cast<DenseIntElementsAttr>().getValues<int64_t>());
+ llvm::cast<DenseIntElementsAttr>(adaptor.getInput()).getValues<int64_t>());
auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
builder.getIndexType());
return DenseIntElementsAttr::get(type, shape);
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 0ecc77f228e42..3175e957698d0 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -815,7 +815,7 @@ static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
Level cooStartLvl = getCOOStart(stt.getEncoding());
if (cooStartLvl < stt.getLvlRank()) {
// We only supports trailing COO for now, must be the last input.
- auto cooTp = lvlTps.back().cast<ShapedType>();
+ auto cooTp = llvm::cast<ShapedType>(lvlTps.back());
// The coordinates should be in shape of <? x rank>
unsigned expCOORank = stt.getLvlRank() - cooStartLvl;
if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) {
@@ -844,7 +844,7 @@ static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
inputTp = lvlTps[idx++];
}
// The input element type and expected element type should match.
- Type inpElemTp = inputTp.cast<TensorType>().getElementType();
+ Type inpElemTp = llvm::cast<TensorType>(inputTp).getElementType();
Type expElemTp = getFieldElemType(stt, fKind);
if (inpElemTp != expElemTp) {
misMatch = true;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
index 5a1615ee7f197..246e5d985d87c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -188,7 +188,7 @@ static Value genAllocCopy(OpBuilder &builder, Location loc, Value b,
/// Generates a memref from tensor operation.
static Value genTensorToMemref(PatternRewriter &rewriter, Location loc,
Value tensor) {
- auto tensorType = tensor.getType().cast<ShapedType>();
+ auto tensorType = llvm::cast<ShapedType>(tensor.getType());
auto memrefType =
MemRefType::get(tensorType.getShape(), tensorType.getElementType());
return rewriter.create<bufferization::ToMemrefOp>(loc, memrefType, tensor);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index f6405d2a47e4f..20d0c5e7d4f1b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -414,7 +414,7 @@ class SparseInsertGenerator
/// TODO: better unord/not-unique; also generalize, optimize, specialize!
SmallVector<Value> genImplementation(TypeRange retTypes, ValueRange args,
OpBuilder &builder, Location loc) {
- const SparseTensorType stt(rtp.cast<RankedTensorType>());
+ const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
const Level lvlRank = stt.getLvlRank();
// Extract fields and coordinates from args.
SmallVector<Value> fields = llvm::to_vector(args.drop_back(lvlRank + 1));
@@ -466,7 +466,7 @@ class SparseInsertGenerator
// The mangled name of the function has this format:
// <namePrefix>_<DLT>_<shape>_<ordering>_<eltType>_<crdWidth>_<posWidth>
constexpr const char kInsertFuncNamePrefix[] = "_insert_";
- const SparseTensorType stt(rtp.cast<RankedTensorType>());
+ const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
SmallString<32> nameBuffer;
llvm::raw_svector_ostream nameOstream(nameBuffer);
@@ -541,14 +541,14 @@ static void genEndInsert(OpBuilder &builder, Location loc,
static TypedValue<BaseMemRefType> genToMemref(OpBuilder &builder, Location loc,
Value tensor) {
- auto tTp = tensor.getType().cast<TensorType>();
+ auto tTp = llvm::cast<TensorType>(tensor.getType());
auto mTp = MemRefType::get(tTp.getShape(), tTp.getElementType());
return builder.create<bufferization::ToMemrefOp>(loc, mTp, tensor)
.getResult();
}
Value genSliceToSize(OpBuilder &builder, Location loc, Value mem, Value sz) {
- auto elemTp = mem.getType().cast<MemRefType>().getElementType();
+ auto elemTp = llvm::cast<MemRefType>(mem.getType()).getElementType();
return builder
.create<memref::SubViewOp>(
loc, MemRefType::get({ShapedType::kDynamic}, elemTp), mem,
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
index 8a52082dfca5f..a964d9116f11c 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
@@ -180,7 +180,7 @@ struct ReifyPadOp
AffineExpr expr = b.getAffineDimExpr(0);
unsigned numSymbols = 0;
auto addOpFoldResult = [&](OpFoldResult valueOrAttr) {
- if (Value v = valueOrAttr.dyn_cast<Value>()) {
+ if (Value v = llvm::dyn_cast_if_present<Value>(valueOrAttr)) {
expr = expr + b.getAffineSymbolExpr(numSymbols++);
mapOperands.push_back(v);
return;
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index eab64b5cf9994..1adb9c7f262fe 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -501,7 +501,7 @@ Speculation::Speculatability DimOp::getSpeculatability() {
OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
// All forms of folding require a known index.
- auto index = adaptor.getIndex().dyn_cast_or_null<IntegerAttr>();
+ auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
if (!index)
return {};
@@ -764,7 +764,7 @@ struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
OpFoldResult currDim = std::get<1>(it);
// Case 1: The empty tensor dim is static. Check that the tensor cast
// result dim matches.
- if (auto attr = currDim.dyn_cast<Attribute>()) {
+ if (auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
if (ShapedType::isDynamic(newDim) ||
newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
// Something is off, the cast result shape cannot be more dynamic
@@ -2106,7 +2106,7 @@ static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
}
OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
- if (auto splat = adaptor.getSource().dyn_cast_or_null<SplatElementsAttr>()) {
+ if (auto splat = llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource())) {
auto resultType = llvm::cast<ShapedType>(getResult().getType());
if (resultType.hasStaticShape())
return splat.resizeSplat(resultType);
@@ -3558,7 +3558,7 @@ asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
SmallVector<int64_t> result;
for (auto o : ofrs) {
// Have to do this first, as getConstantIntValue special-cases constants.
- if (o.dyn_cast<Value>())
+ if (llvm::dyn_cast_if_present<Value>(o))
result.push_back(ShapedType::kDynamic);
else
result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic));
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 935a1b9fededf..545a9d09c4aba 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -76,7 +76,7 @@ struct CastOpInterface
auto rankedResultType = cast<RankedTensorType>(castOp.getType());
return MemRefType::get(
rankedResultType.getShape(), rankedResultType.getElementType(),
- maybeSrcBufferType->cast<MemRefType>().getLayout(), memorySpace);
+ llvm::cast<MemRefType>(*maybeSrcBufferType).getLayout(), memorySpace);
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -139,7 +139,7 @@ struct CollapseShapeOpInterface
collapseShapeOp.getSrc(), options, fixedTypes);
if (failed(maybeSrcBufferType))
return failure();
- auto srcBufferType = maybeSrcBufferType->cast<MemRefType>();
+ auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
srcBufferType, collapseShapeOp.getReassociationIndices());
@@ -303,7 +303,7 @@ struct ExpandShapeOpInterface
expandShapeOp.getSrc(), options, fixedTypes);
if (failed(maybeSrcBufferType))
return failure();
- auto srcBufferType = maybeSrcBufferType->cast<MemRefType>();
+ auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
auto maybeResultType = memref::ExpandShapeOp::computeExpandedType(
srcBufferType, expandShapeOp.getResultType().getShape(),
expandShapeOp.getReassociationIndices());
@@ -369,7 +369,7 @@ struct ExtractSliceOpInterface
if (failed(resultMemrefType))
return failure();
Value subView = rewriter.create<memref::SubViewOp>(
- loc, resultMemrefType->cast<MemRefType>(), *srcMemref, mixedOffsets,
+ loc, llvm::cast<MemRefType>(*resultMemrefType), *srcMemref, mixedOffsets,
mixedSizes, mixedStrides);
replaceOpWithBufferizedValues(rewriter, op, subView);
@@ -389,7 +389,7 @@ struct ExtractSliceOpInterface
SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
return cast<BaseMemRefType>(memref::SubViewOp::inferRankReducedResultType(
- extractSliceOp.getType().getShape(), srcMemrefType->cast<MemRefType>(),
+ extractSliceOp.getType().getShape(), llvm::cast<MemRefType>(*srcMemrefType),
mixedOffsets, mixedSizes, mixedStrides));
}
};
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index e95e6282049b1..fb3b934b4f9af 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -548,8 +548,8 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
return {};
auto resultETy = resultTy.getElementType();
- auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
- auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
+ auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
+ auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
return getInput1();
@@ -573,8 +573,8 @@ OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
return {};
auto resultETy = resultTy.getElementType();
- auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
- auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
+ auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
+ auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
if (lhsAttr && lhsAttr.isSplat()) {
if (llvm::isa<IntegerType>(resultETy) &&
lhsAttr.getSplatValue<APInt>().isZero())
@@ -642,8 +642,8 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
return {};
auto resultETy = resultTy.getElementType();
- auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
- auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
+ auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
+ auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
const int64_t shift = llvm::isa<IntegerType>(resultETy) ? getShift() : 0;
if (rhsTy == resultTy) {
@@ -670,8 +670,8 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
return {};
auto resultETy = resultTy.getElementType();
- auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
- auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
+ auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
+ auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
return getInput1();
@@ -713,8 +713,8 @@ struct APIntFoldGreaterEqual {
OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
- auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
- auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
+ auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
+ auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
if (!lhsAttr || !rhsAttr)
return {};
@@ -725,8 +725,8 @@ OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
- auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
- auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
+ auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
+ auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
if (!lhsAttr || !rhsAttr)
return {};
@@ -738,8 +738,8 @@ OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
- auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
- auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
+ auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
+ auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
Value lhs = getInput1();
Value rhs = getInput2();
auto lhsTy = llvm::cast<ShapedType>(lhs.getType());
@@ -763,7 +763,7 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
if (getInput().getType() == getType())
return getInput();
- auto operand = adaptor.getInput().dyn_cast_or_null<ElementsAttr>();
+ auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
if (!operand)
return {};
@@ -852,7 +852,7 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
if (inputTy == outputTy)
return getInput1();
- auto operand = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
+ auto operand = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
if (operand && outputTy.hasStaticShape() && operand.isSplat()) {
return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
}
@@ -863,7 +863,7 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
// If the pad is all zeros we can fold this operation away.
if (adaptor.getPadding()) {
- auto densePad = adaptor.getPadding().cast<DenseElementsAttr>();
+ auto densePad = llvm::cast<DenseElementsAttr>(adaptor.getPadding());
if (densePad.isSplat() && densePad.getSplatValue<APInt>().isZero()) {
return getInput1();
}
@@ -907,7 +907,7 @@ OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
auto operand = getInput();
auto operandTy = llvm::cast<ShapedType>(operand.getType());
auto axis = getAxis();
- auto operandAttr = adaptor.getInput().dyn_cast_or_null<SplatElementsAttr>();
+ auto operandAttr = llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput());
if (operandAttr)
return operandAttr;
@@ -936,7 +936,7 @@ OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
!outputTy.getElementType().isIntOrIndexOrFloat())
return {};
- auto operand = adaptor.getInput().cast<ElementsAttr>();
+ auto operand = llvm::cast<ElementsAttr>(adaptor.getInput());
if (operand.isSplat() && outputTy.hasStaticShape()) {
return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
}
@@ -955,7 +955,7 @@ OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
if (getOnTrue() == getOnFalse())
return getOnTrue();
- auto predicate = adaptor.getPred().dyn_cast_or_null<DenseIntElementsAttr>();
+ auto predicate = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getPred());
if (!predicate)
return {};
@@ -977,7 +977,7 @@ OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
auto resultTy = llvm::cast<ShapedType>(getType());
// Transposing splat values just means reshaping.
- if (auto input = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>()) {
+ if (auto input = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
if (input.isSplat() && resultTy.hasStaticShape() &&
inputTy.getElementType() == resultTy.getElementType())
return input.reshape(resultTy);
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
index 829db2a86c44a..94cbb0afd2744 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
@@ -63,9 +63,9 @@ LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter, Location loc,
// Verify the rank agrees with the output type if the output type is ranked.
if (outputType) {
if (outputType.getRank() !=
- input1_copy.getType().cast<RankedTensorType>().getRank() ||
+ llvm::cast<RankedTensorType>(input1_copy.getType()).getRank() ||
outputType.getRank() !=
- input2_copy.getType().cast<RankedTensorType>().getRank())
+ llvm::cast<RankedTensorType>(input2_copy.getType()).getRank())
return rewriter.notifyMatchFailure(
loc, "the reshaped type doesn't agrees with the ranked output type");
}
diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
index 8f84a064382f4..d260c93e1cf44 100644
--- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
@@ -103,8 +103,8 @@ computeReshapeOutput(ArrayRef<int64_t> higherRankShape,
LogicalResult mlir::tosa::EqualizeRanks(PatternRewriter &rewriter, Location loc,
Value &input1, Value &input2) {
- auto input1Ty = input1.getType().dyn_cast<RankedTensorType>();
- auto input2Ty = input2.getType().dyn_cast<RankedTensorType>();
+ auto input1Ty = llvm::dyn_cast<RankedTensorType>(input1.getType());
+ auto input2Ty = llvm::dyn_cast<RankedTensorType>(input2.getType());
if (!input1Ty || !input2Ty) {
return failure();
@@ -126,9 +126,9 @@ LogicalResult mlir::tosa::EqualizeRanks(PatternRewriter &rewriter, Location loc,
}
ArrayRef<int64_t> higherRankShape =
- higherTensorValue.getType().cast<RankedTensorType>().getShape();
+ llvm::cast<RankedTensorType>(higherTensorValue.getType()).getShape();
ArrayRef<int64_t> lowerRankShape =
- lowerTensorValue.getType().cast<RankedTensorType>().getShape();
+ llvm::cast<RankedTensorType>(lowerTensorValue.getType()).getShape();
SmallVector<int64_t, 4> reshapeOutputShape;
@@ -136,7 +136,8 @@ LogicalResult mlir::tosa::EqualizeRanks(PatternRewriter &rewriter, Location loc,
.failed())
return failure();
- auto reshapeInputType = lowerTensorValue.getType().cast<RankedTensorType>();
+ auto reshapeInputType =
+ llvm::cast<RankedTensorType>(lowerTensorValue.getType());
auto reshapeOutputType = RankedTensorType::get(
ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType());
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index b3176b144451d..695b4a39f3d65 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -118,7 +118,7 @@ static DiagnosedSilenceableFailure dispatchMappedValues(
SmallVector<Operation *> operations;
operations.reserve(values.size());
for (transform::MappedValue value : values) {
- if (auto *op = value.dyn_cast<Operation *>()) {
+ if (auto *op = llvm::dyn_cast_if_present<Operation *>(value)) {
operations.push_back(op);
continue;
}
@@ -135,7 +135,7 @@ static DiagnosedSilenceableFailure dispatchMappedValues(
SmallVector<Value> payloadValues;
payloadValues.reserve(values.size());
for (transform::MappedValue value : values) {
- if (auto v = value.dyn_cast<Value>()) {
+ if (auto v = llvm::dyn_cast_if_present<Value>(value)) {
payloadValues.push_back(v);
continue;
}
@@ -152,7 +152,7 @@ static DiagnosedSilenceableFailure dispatchMappedValues(
SmallVector<transform::Param> parameters;
parameters.reserve(values.size());
for (transform::MappedValue value : values) {
- if (auto attr = value.dyn_cast<Attribute>()) {
+ if (auto attr = llvm::dyn_cast_if_present<Attribute>(value)) {
parameters.push_back(attr);
continue;
}
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 09137d3336cc0..75d2dcec13643 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -18,7 +18,7 @@ namespace mlir {
bool isZeroIndex(OpFoldResult v) {
if (!v)
return false;
- if (auto attr = v.dyn_cast<Attribute>()) {
+ if (auto attr = llvm::dyn_cast_if_present<Attribute>(v)) {
IntegerAttr intAttr = dyn_cast<IntegerAttr>(attr);
return intAttr && intAttr.getValue().isZero();
}
@@ -51,7 +51,7 @@ getOffsetsSizesAndStrides(ArrayRef<Range> ranges) {
void dispatchIndexOpFoldResult(OpFoldResult ofr,
SmallVectorImpl<Value> &dynamicVec,
SmallVectorImpl<int64_t> &staticVec) {
- auto v = ofr.dyn_cast<Value>();
+ auto v = llvm::dyn_cast_if_present<Value>(ofr);
if (!v) {
APInt apInt = cast<IntegerAttr>(ofr.get<Attribute>()).getValue();
staticVec.push_back(apInt.getSExtValue());
@@ -116,14 +116,14 @@ SmallVector<OpFoldResult> getAsIndexOpFoldResult(MLIRContext *ctx,
/// If ofr is a constant integer or an IntegerAttr, return the integer.
std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
// Case 1: Check for Constant integer.
- if (auto val = ofr.dyn_cast<Value>()) {
+ if (auto val = llvm::dyn_cast_if_present<Value>(ofr)) {
APSInt intVal;
if (matchPattern(val, m_ConstantInt(&intVal)))
return intVal.getSExtValue();
return std::nullopt;
}
// Case 2: Check for IntegerAttr.
- Attribute attr = ofr.dyn_cast<Attribute>();
+ Attribute attr = llvm::dyn_cast_if_present<Attribute>(ofr);
if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr))
return intAttr.getValue().getSExtValue();
return std::nullopt;
@@ -143,7 +143,8 @@ bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2) {
auto cst1 = getConstantIntValue(ofr1), cst2 = getConstantIntValue(ofr2);
if (cst1 && cst2 && *cst1 == *cst2)
return true;
- auto v1 = ofr1.dyn_cast<Value>(), v2 = ofr2.dyn_cast<Value>();
+ auto v1 = llvm::dyn_cast_if_present<Value>(ofr1),
+ v2 = llvm::dyn_cast_if_present<Value>(ofr2);
return v1 && v1 == v2;
}
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index aac677792e768..20c088c2acfe1 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1154,7 +1154,7 @@ ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
OpaqueProperties properties, RegionRange,
SmallVectorImpl<Type> &inferredReturnTypes) {
ExtractOp::Adaptor op(operands, attributes, properties);
- auto vectorType = op.getVector().getType().cast<VectorType>();
+ auto vectorType = llvm::cast<VectorType>(op.getVector().getType());
if (static_cast<int64_t>(op.getPosition().size()) == vectorType.getRank()) {
inferredReturnTypes.push_back(vectorType.getElementType());
} else {
@@ -2003,9 +2003,9 @@ OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
if (!adaptor.getSource())
return {};
auto vectorType = getResultVectorType();
- if (adaptor.getSource().isa<IntegerAttr, FloatAttr>())
+ if (llvm::isa<IntegerAttr, FloatAttr>(adaptor.getSource()))
return DenseElementsAttr::get(vectorType, adaptor.getSource());
- if (auto attr = adaptor.getSource().dyn_cast<SplatElementsAttr>())
+ if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
return {};
}
@@ -2090,7 +2090,7 @@ ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
OpaqueProperties properties, RegionRange,
SmallVectorImpl<Type> &inferredReturnTypes) {
ShuffleOp::Adaptor op(operands, attributes, properties);
- auto v1Type = op.getV1().getType().cast<VectorType>();
+ auto v1Type = llvm::cast<VectorType>(op.getV1().getType());
auto v1Rank = v1Type.getRank();
// Construct resulting type: leading dimension matches mask
// length, all trailing dimensions match the operands.
@@ -4951,7 +4951,7 @@ void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
// Eliminate splat constant transpose ops.
- if (auto attr = adaptor.getVector().dyn_cast_or_null<DenseElementsAttr>())
+ if (auto attr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getVector()))
if (attr.isSplat())
return attr.reshape(getResultVectorType());
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index a75bc584ef3a5..ef0bf75a9cd67 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -3642,7 +3642,7 @@ void Value::print(raw_ostream &os, const OpPrintingFlags &flags) {
if (auto *op = getDefiningOp())
return op->print(os, flags);
// TODO: Improve BlockArgument print'ing.
- BlockArgument arg = this->cast<BlockArgument>();
+ BlockArgument arg = llvm::cast<BlockArgument>(*this);
os << "<block argument> of type '" << arg.getType()
<< "' at index: " << arg.getArgNumber();
}
@@ -3656,7 +3656,7 @@ void Value::print(raw_ostream &os, AsmState &state) {
return op->print(os, state);
// TODO: Improve BlockArgument print'ing.
- BlockArgument arg = this->cast<BlockArgument>();
+ BlockArgument arg = llvm::cast<BlockArgument>(*this);
os << "<block argument> of type '" << arg.getType()
<< "' at index: " << arg.getArgNumber();
}
@@ -3693,10 +3693,10 @@ static Operation *findParent(Operation *op, bool shouldUseLocalScope) {
void Value::printAsOperand(raw_ostream &os, const OpPrintingFlags &flags) {
Operation *op;
- if (auto result = dyn_cast<OpResult>()) {
+ if (auto result = llvm::dyn_cast<OpResult>(*this)) {
op = result.getOwner();
} else {
- op = cast<BlockArgument>().getOwner()->getParentOp();
+ op = llvm::cast<BlockArgument>(*this).getOwner()->getParentOp();
if (!op) {
os << "<<UNKNOWN SSA VALUE>>";
return;
diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp
index dc4ec1458bd95..069be73908acd 100644
--- a/mlir/lib/IR/Block.cpp
+++ b/mlir/lib/IR/Block.cpp
@@ -347,14 +347,14 @@ BlockRange::BlockRange(SuccessorRange successors)
/// See `llvm::detail::indexed_accessor_range_base` for details.
BlockRange::OwnerT BlockRange::offset_base(OwnerT object, ptrdiff_t index) {
- if (auto *operand = object.dyn_cast<BlockOperand *>())
+ if (auto *operand = llvm::dyn_cast_if_present<BlockOperand *>(object))
return {operand + index};
- return {object.dyn_cast<Block *const *>() + index};
+ return {llvm::dyn_cast_if_present<Block *const *>(object) + index};
}
/// See `llvm::detail::indexed_accessor_range_base` for details.
Block *BlockRange::dereference_iterator(OwnerT object, ptrdiff_t index) {
- if (const auto *operand = object.dyn_cast<BlockOperand *>())
+ if (const auto *operand = llvm::dyn_cast_if_present<BlockOperand *>(object))
return operand[index].get();
- return object.dyn_cast<Block *const *>()[index];
+ return llvm::dyn_cast_if_present<Block *const *>(object)[index];
}
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index c4fad9c4b3d49..22abd52120680 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -483,7 +483,7 @@ LogicalResult OpBuilder::tryFold(Operation *op,
Type expectedType = std::get<1>(it);
// Normal values get pushed back directly.
- if (auto value = std::get<0>(it).dyn_cast<Value>()) {
+ if (auto value = llvm::dyn_cast_if_present<Value>(std::get<0>(it))) {
if (value.getType() != expectedType)
return cleanupFailure();
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index de26c8e05802a..9a18f07efed01 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -1247,12 +1247,12 @@ DenseElementsAttr DenseElementsAttr::bitcast(Type newElType) {
DenseElementsAttr
DenseElementsAttr::mapValues(Type newElementType,
function_ref<APInt(const APInt &)> mapping) const {
- return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping);
+ return llvm::cast<DenseIntElementsAttr>(*this).mapValues(newElementType, mapping);
}
DenseElementsAttr DenseElementsAttr::mapValues(
Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
- return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping);
+ return llvm::cast<DenseFPElementsAttr>(*this).mapValues(newElementType, mapping);
}
ShapedType DenseElementsAttr::getType() const {
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index c816e4a6dbcf3..eea07edfdab3c 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -88,45 +88,45 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
//===----------------------------------------------------------------------===//
unsigned FloatType::getWidth() {
- if (isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
- Float8E4M3FNUZType, Float8E4M3B11FNUZType>())
+ if (llvm::isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
+ Float8E4M3FNUZType, Float8E4M3B11FNUZType>(*this))
return 8;
- if (isa<Float16Type, BFloat16Type>())
+ if (llvm::isa<Float16Type, BFloat16Type>(*this))
return 16;
- if (isa<Float32Type>())
+ if (llvm::isa<Float32Type>(*this))
return 32;
- if (isa<Float64Type>())
+ if (llvm::isa<Float64Type>(*this))
return 64;
- if (isa<Float80Type>())
+ if (llvm::isa<Float80Type>(*this))
return 80;
- if (isa<Float128Type>())
+ if (llvm::isa<Float128Type>(*this))
return 128;
llvm_unreachable("unexpected float type");
}
/// Returns the floating semantics for the given type.
const llvm::fltSemantics &FloatType::getFloatSemantics() {
- if (isa<Float8E5M2Type>())
+ if (llvm::isa<Float8E5M2Type>(*this))
return APFloat::Float8E5M2();
- if (isa<Float8E4M3FNType>())
+ if (llvm::isa<Float8E4M3FNType>(*this))
return APFloat::Float8E4M3FN();
- if (isa<Float8E5M2FNUZType>())
+ if (llvm::isa<Float8E5M2FNUZType>(*this))
return APFloat::Float8E5M2FNUZ();
- if (isa<Float8E4M3FNUZType>())
+ if (llvm::isa<Float8E4M3FNUZType>(*this))
return APFloat::Float8E4M3FNUZ();
- if (isa<Float8E4M3B11FNUZType>())
+ if (llvm::isa<Float8E4M3B11FNUZType>(*this))
return APFloat::Float8E4M3B11FNUZ();
- if (isa<BFloat16Type>())
+ if (llvm::isa<BFloat16Type>(*this))
return APFloat::BFloat();
- if (isa<Float16Type>())
+ if (llvm::isa<Float16Type>(*this))
return APFloat::IEEEhalf();
- if (isa<Float32Type>())
+ if (llvm::isa<Float32Type>(*this))
return APFloat::IEEEsingle();
- if (isa<Float64Type>())
+ if (llvm::isa<Float64Type>(*this))
return APFloat::IEEEdouble();
- if (isa<Float80Type>())
+ if (llvm::isa<Float80Type>(*this))
return APFloat::x87DoubleExtended();
- if (isa<Float128Type>())
+ if (llvm::isa<Float128Type>(*this))
return APFloat::IEEEquad();
llvm_unreachable("non-floating point type used");
}
@@ -269,21 +269,21 @@ Type TensorType::getElementType() const {
[](auto type) { return type.getElementType(); });
}
-bool TensorType::hasRank() const { return !isa<UnrankedTensorType>(); }
+bool TensorType::hasRank() const { return !llvm::isa<UnrankedTensorType>(*this); }
ArrayRef<int64_t> TensorType::getShape() const {
- return cast<RankedTensorType>().getShape();
+ return llvm::cast<RankedTensorType>(*this).getShape();
}
TensorType TensorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
Type elementType) const {
- if (auto unrankedTy = dyn_cast<UnrankedTensorType>()) {
+ if (auto unrankedTy = llvm::dyn_cast<UnrankedTensorType>(*this)) {
if (shape)
return RankedTensorType::get(*shape, elementType);
return UnrankedTensorType::get(elementType);
}
- auto rankedTy = cast<RankedTensorType>();
+ auto rankedTy = llvm::cast<RankedTensorType>(*this);
if (!shape)
return RankedTensorType::get(rankedTy.getShape(), elementType,
rankedTy.getEncoding());
@@ -356,15 +356,15 @@ Type BaseMemRefType::getElementType() const {
[](auto type) { return type.getElementType(); });
}
-bool BaseMemRefType::hasRank() const { return !isa<UnrankedMemRefType>(); }
+bool BaseMemRefType::hasRank() const { return !llvm::isa<UnrankedMemRefType>(*this); }
ArrayRef<int64_t> BaseMemRefType::getShape() const {
- return cast<MemRefType>().getShape();
+ return llvm::cast<MemRefType>(*this).getShape();
}
BaseMemRefType BaseMemRefType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
Type elementType) const {
- if (auto unrankedTy = dyn_cast<UnrankedMemRefType>()) {
+ if (auto unrankedTy = llvm::dyn_cast<UnrankedMemRefType>(*this)) {
if (!shape)
return UnrankedMemRefType::get(elementType, getMemorySpace());
MemRefType::Builder builder(*shape, elementType);
@@ -372,7 +372,7 @@ BaseMemRefType BaseMemRefType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
return builder;
}
- MemRefType::Builder builder(cast<MemRefType>());
+ MemRefType::Builder builder(llvm::cast<MemRefType>(*this));
if (shape)
builder.setShape(*shape);
builder.setElementType(elementType);
@@ -389,15 +389,15 @@ MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape) const {
}
Attribute BaseMemRefType::getMemorySpace() const {
- if (auto rankedMemRefTy = dyn_cast<MemRefType>())
+ if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this))
return rankedMemRefTy.getMemorySpace();
- return cast<UnrankedMemRefType>().getMemorySpace();
+ return llvm::cast<UnrankedMemRefType>(*this).getMemorySpace();
}
unsigned BaseMemRefType::getMemorySpaceAsInt() const {
- if (auto rankedMemRefTy = dyn_cast<MemRefType>())
+ if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this))
return rankedMemRefTy.getMemorySpaceAsInt();
- return cast<UnrankedMemRefType>().getMemorySpaceAsInt();
+ return llvm::cast<UnrankedMemRefType>(*this).getMemorySpaceAsInt();
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 0a4a19dbf2c69..c353188d964e3 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -626,17 +626,17 @@ ValueRange::ValueRange(ResultRange values)
/// See `llvm::detail::indexed_accessor_range_base` for details.
ValueRange::OwnerT ValueRange::offset_base(const OwnerT &owner,
ptrdiff_t index) {
- if (const auto *value = owner.dyn_cast<const Value *>())
+ if (const auto *value = llvm::dyn_cast_if_present<const Value *>(owner))
return {value + index};
- if (auto *operand = owner.dyn_cast<OpOperand *>())
+ if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
return {operand + index};
return owner.get<detail::OpResultImpl *>()->getNextResultAtOffset(index);
}
/// See `llvm::detail::indexed_accessor_range_base` for details.
Value ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) {
- if (const auto *value = owner.dyn_cast<const Value *>())
+ if (const auto *value = llvm::dyn_cast_if_present<const Value *>(owner))
return value[index];
- if (auto *operand = owner.dyn_cast<OpOperand *>())
+ if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
return operand[index].get();
return owner.get<detail::OpResultImpl *>()->getNextResultAtOffset(index);
}
diff --git a/mlir/lib/IR/Region.cpp b/mlir/lib/IR/Region.cpp
index 2b84a935371a9..e1caa8979eb41 100644
--- a/mlir/lib/IR/Region.cpp
+++ b/mlir/lib/IR/Region.cpp
@@ -267,18 +267,18 @@ RegionRange::RegionRange(ArrayRef<Region *> regions)
/// See `llvm::detail::indexed_accessor_range_base` for details.
RegionRange::OwnerT RegionRange::offset_base(const OwnerT &owner,
ptrdiff_t index) {
- if (auto *region = owner.dyn_cast<const std::unique_ptr<Region> *>())
+ if (auto *region = llvm::dyn_cast_if_present<const std::unique_ptr<Region> *>(owner))
return region + index;
- if (auto **region = owner.dyn_cast<Region **>())
+ if (auto **region = llvm::dyn_cast_if_present<Region **>(owner))
return region + index;
return &owner.get<Region *>()[index];
}
/// See `llvm::detail::indexed_accessor_range_base` for details.
Region *RegionRange::dereference_iterator(const OwnerT &owner,
ptrdiff_t index) {
- if (auto *region = owner.dyn_cast<const std::unique_ptr<Region> *>())
+ if (auto *region = llvm::dyn_cast_if_present<const std::unique_ptr<Region> *>(owner))
return region[index].get();
- if (auto **region = owner.dyn_cast<Region **>())
+ if (auto **region = llvm::dyn_cast_if_present<Region **>(owner))
return region[index];
return &owner.get<Region *>()[index];
}
diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
index c03f4dd5f352e..2494cb7086f0d 100644
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -551,7 +551,7 @@ struct SymbolScope {
typename llvm::function_traits<CallbackT>::result_t,
void>::value> * = nullptr>
std::optional<WalkResult> walk(CallbackT cback) {
- if (Region *region = limit.dyn_cast<Region *>())
+ if (Region *region = llvm::dyn_cast_if_present<Region *>(limit))
return walkSymbolUses(*region, cback);
return walkSymbolUses(limit.get<Operation *>(), cback);
}
@@ -571,7 +571,7 @@ struct SymbolScope {
/// traversing into any nested symbol tables.
template <typename CallbackT>
std::optional<WalkResult> walkSymbolTable(CallbackT &&cback) {
- if (Region *region = limit.dyn_cast<Region *>())
+ if (Region *region = llvm::dyn_cast_if_present<Region *>(limit))
return ::walkSymbolTable(*region, cback);
return ::walkSymbolTable(limit.get<Operation *>(), cback);
}
diff --git a/mlir/lib/IR/TypeRange.cpp b/mlir/lib/IR/TypeRange.cpp
index 2e2121af024fb..c05c0ce0d2544 100644
--- a/mlir/lib/IR/TypeRange.cpp
+++ b/mlir/lib/IR/TypeRange.cpp
@@ -27,9 +27,9 @@ TypeRange::TypeRange(ValueRange values) : TypeRange(OwnerT(), values.size()) {
if (count == 0)
return;
ValueRange::OwnerT owner = values.begin().getBase();
- if (auto *result = owner.dyn_cast<detail::OpResultImpl *>())
+ if (auto *result = llvm::dyn_cast_if_present<detail::OpResultImpl *>(owner))
this->base = result;
- else if (auto *operand = owner.dyn_cast<OpOperand *>())
+ else if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
this->base = operand;
else
this->base = owner.get<const Value *>();
@@ -37,22 +37,22 @@ TypeRange::TypeRange(ValueRange values) : TypeRange(OwnerT(), values.size()) {
/// See `llvm::detail::indexed_accessor_range_base` for details.
TypeRange::OwnerT TypeRange::offset_base(OwnerT object, ptrdiff_t index) {
- if (const auto *value = object.dyn_cast<const Value *>())
+ if (const auto *value = llvm::dyn_cast_if_present<const Value *>(object))
return {value + index};
- if (auto *operand = object.dyn_cast<OpOperand *>())
+ if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(object))
return {operand + index};
- if (auto *result = object.dyn_cast<detail::OpResultImpl *>())
+ if (auto *result = llvm::dyn_cast_if_present<detail::OpResultImpl *>(object))
return {result->getNextResultAtOffset(index)};
- return {object.dyn_cast<const Type *>() + index};
+ return {llvm::dyn_cast_if_present<const Type *>(object) + index};
}
/// See `llvm::detail::indexed_accessor_range_base` for details.
Type TypeRange::dereference_iterator(OwnerT object, ptrdiff_t index) {
- if (const auto *value = object.dyn_cast<const Value *>())
+ if (const auto *value = llvm::dyn_cast_if_present<const Value *>(object))
return (value + index)->getType();
- if (auto *operand = object.dyn_cast<OpOperand *>())
+ if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(object))
return (operand + index)->get().getType();
- if (auto *result = object.dyn_cast<detail::OpResultImpl *>())
+ if (auto *result = llvm::dyn_cast_if_present<detail::OpResultImpl *>(object))
return result->getNextResultAtOffset(index)->getType();
- return object.dyn_cast<const Type *>()[index];
+ return llvm::dyn_cast_if_present<const Type *>(object)[index];
}
diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
index d3d1d860d5d32..e376a5fd33922 100644
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -34,84 +34,94 @@ Type AbstractType::replaceImmediateSubElements(Type type,
MLIRContext *Type::getContext() const { return getDialect().getContext(); }
-bool Type::isFloat8E5M2() const { return isa<Float8E5M2Type>(); }
-bool Type::isFloat8E4M3FN() const { return isa<Float8E4M3FNType>(); }
-bool Type::isFloat8E5M2FNUZ() const { return isa<Float8E5M2FNUZType>(); }
-bool Type::isFloat8E4M3FNUZ() const { return isa<Float8E4M3FNUZType>(); }
-bool Type::isFloat8E4M3B11FNUZ() const { return isa<Float8E4M3B11FNUZType>(); }
-bool Type::isBF16() const { return isa<BFloat16Type>(); }
-bool Type::isF16() const { return isa<Float16Type>(); }
-bool Type::isF32() const { return isa<Float32Type>(); }
-bool Type::isF64() const { return isa<Float64Type>(); }
-bool Type::isF80() const { return isa<Float80Type>(); }
-bool Type::isF128() const { return isa<Float128Type>(); }
-
-bool Type::isIndex() const { return isa<IndexType>(); }
+bool Type::isFloat8E5M2() const { return llvm::isa<Float8E5M2Type>(*this); }
+bool Type::isFloat8E4M3FN() const { return llvm::isa<Float8E4M3FNType>(*this); }
+bool Type::isFloat8E5M2FNUZ() const {
+ return llvm::isa<Float8E5M2FNUZType>(*this);
+}
+bool Type::isFloat8E4M3FNUZ() const {
+ return llvm::isa<Float8E4M3FNUZType>(*this);
+}
+bool Type::isFloat8E4M3B11FNUZ() const {
+ return llvm::isa<Float8E4M3B11FNUZType>(*this);
+}
+bool Type::isBF16() const { return llvm::isa<BFloat16Type>(*this); }
+bool Type::isF16() const { return llvm::isa<Float16Type>(*this); }
+bool Type::isF32() const { return llvm::isa<Float32Type>(*this); }
+bool Type::isF64() const { return llvm::isa<Float64Type>(*this); }
+bool Type::isF80() const { return llvm::isa<Float80Type>(*this); }
+bool Type::isF128() const { return llvm::isa<Float128Type>(*this); }
+
+bool Type::isIndex() const { return llvm::isa<IndexType>(*this); }
/// Return true if this is an integer type with the specified width.
bool Type::isInteger(unsigned width) const {
- if (auto intTy = dyn_cast<IntegerType>())
+ if (auto intTy = llvm::dyn_cast<IntegerType>(*this))
return intTy.getWidth() == width;
return false;
}
bool Type::isSignlessInteger() const {
- if (auto intTy = dyn_cast<IntegerType>())
+ if (auto intTy = llvm::dyn_cast<IntegerType>(*this))
return intTy.isSignless();
return false;
}
bool Type::isSignlessInteger(unsigned width) const {
- if (auto intTy = dyn_cast<IntegerType>())
+ if (auto intTy = llvm::dyn_cast<IntegerType>(*this))
return intTy.isSignless() && intTy.getWidth() == width;
return false;
}
bool Type::isSignedInteger() const {
- if (auto intTy = dyn_cast<IntegerType>())
+ if (auto intTy = llvm::dyn_cast<IntegerType>(*this))
return intTy.isSigned();
return false;
}
bool Type::isSignedInteger(unsigned width) const {
- if (auto intTy = dyn_cast<IntegerType>())
+ if (auto intTy = llvm::dyn_cast<IntegerType>(*this))
return intTy.isSigned() && intTy.getWidth() == width;
return false;
}
bool Type::isUnsignedInteger() const {
- if (auto intTy = dyn_cast<IntegerType>())
+ if (auto intTy = llvm::dyn_cast<IntegerType>(*this))
return intTy.isUnsigned();
return false;
}
bool Type::isUnsignedInteger(unsigned width) const {
- if (auto intTy = dyn_cast<IntegerType>())
+ if (auto intTy = llvm::dyn_cast<IntegerType>(*this))
return intTy.isUnsigned() && intTy.getWidth() == width;
return false;
}
bool Type::isSignlessIntOrIndex() const {
- return isSignlessInteger() || isa<IndexType>();
+ return isSignlessInteger() || llvm::isa<IndexType>(*this);
}
bool Type::isSignlessIntOrIndexOrFloat() const {
- return isSignlessInteger() || isa<IndexType, FloatType>();
+ return isSignlessInteger() || llvm::isa<IndexType, FloatType>(*this);
}
bool Type::isSignlessIntOrFloat() const {
- return isSignlessInteger() || isa<FloatType>();
+ return isSignlessInteger() || llvm::isa<FloatType>(*this);
}
-bool Type::isIntOrIndex() const { return isa<IntegerType>() || isIndex(); }
+bool Type::isIntOrIndex() const {
+ return llvm::isa<IntegerType>(*this) || isIndex();
+}
-bool Type::isIntOrFloat() const { return isa<IntegerType, FloatType>(); }
+bool Type::isIntOrFloat() const {
+ return llvm::isa<IntegerType, FloatType>(*this);
+}
bool Type::isIntOrIndexOrFloat() const { return isIntOrFloat() || isIndex(); }
unsigned Type::getIntOrFloatBitWidth() const {
assert(isIntOrFloat() && "only integers and floats have a bitwidth");
- if (auto intType = dyn_cast<IntegerType>())
+ if (auto intType = llvm::dyn_cast<IntegerType>(*this))
return intType.getWidth();
- return cast<FloatType>().getWidth();
+ return llvm::cast<FloatType>(*this).getWidth();
}
diff --git a/mlir/lib/IR/Unit.cpp b/mlir/lib/IR/Unit.cpp
index 7da714fe7d539..c109d387e6b74 100644
--- a/mlir/lib/IR/Unit.cpp
+++ b/mlir/lib/IR/Unit.cpp
@@ -48,11 +48,11 @@ static void printBlock(llvm::raw_ostream &os, Block *block,
}
void mlir::IRUnit::print(llvm::raw_ostream &os, OpPrintingFlags flags) const {
- if (auto *op = this->dyn_cast<Operation *>())
+ if (auto *op = llvm::dyn_cast_if_present<Operation *>(*this))
return printOp(os, op, flags);
- if (auto *region = this->dyn_cast<Region *>())
+ if (auto *region = llvm::dyn_cast_if_present<Region *>(*this))
return printRegion(os, region, flags);
- if (auto *block = this->dyn_cast<Block *>())
+ if (auto *block = llvm::dyn_cast_if_present<Block *>(*this))
return printBlock(os, block, flags);
llvm_unreachable("unknown IRUnit");
}
diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp
index 86b9cde76c05d..6b5195da5e47b 100644
--- a/mlir/lib/IR/Value.cpp
+++ b/mlir/lib/IR/Value.cpp
@@ -18,7 +18,7 @@ using namespace mlir::detail;
/// If this value is the result of an Operation, return the operation that
/// defines it.
Operation *Value::getDefiningOp() const {
- if (auto result = dyn_cast<OpResult>())
+ if (auto result = llvm::dyn_cast<OpResult>(*this))
return result.getOwner();
return nullptr;
}
@@ -27,28 +27,28 @@ Location Value::getLoc() const {
if (auto *op = getDefiningOp())
return op->getLoc();
- return cast<BlockArgument>().getLoc();
+ return llvm::cast<BlockArgument>(*this).getLoc();
}
void Value::setLoc(Location loc) {
if (auto *op = getDefiningOp())
return op->setLoc(loc);
- return cast<BlockArgument>().setLoc(loc);
+ return llvm::cast<BlockArgument>(*this).setLoc(loc);
}
/// Return the Region in which this Value is defined.
Region *Value::getParentRegion() {
if (auto *op = getDefiningOp())
return op->getParentRegion();
- return cast<BlockArgument>().getOwner()->getParent();
+ return llvm::cast<BlockArgument>(*this).getOwner()->getParent();
}
/// Return the Block in which this Value is defined.
Block *Value::getParentBlock() {
if (Operation *op = getDefiningOp())
return op->getBlock();
- return cast<BlockArgument>().getOwner();
+ return llvm::cast<BlockArgument>(*this).getOwner();
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
index e335e156e89df..e460fe133379c 100644
--- a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
+++ b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
@@ -241,7 +241,7 @@ mlir::detail::filterEntriesForType(DataLayoutEntryListRef entries,
TypeID typeID) {
return llvm::to_vector<4>(llvm::make_filter_range(
entries, [typeID](DataLayoutEntryInterface entry) {
- auto type = entry.getKey().dyn_cast<Type>();
+ auto type = llvm::dyn_cast_if_present<Type>(entry.getKey());
return type && type.getTypeID() == typeID;
}));
}
@@ -521,7 +521,7 @@ void DataLayoutSpecInterface::bucketEntriesByType(
DenseMap<TypeID, DataLayoutEntryList> &types,
DenseMap<StringAttr, DataLayoutEntryInterface> &ids) {
for (DataLayoutEntryInterface entry : getEntries()) {
- if (auto type = entry.getKey().dyn_cast<Type>())
+ if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey()))
types[type.getTypeID()].push_back(entry);
else
ids[entry.getKey().get<StringAttr>()] = entry;
diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
index aaa1e1b245251..00d1c51b0348f 100644
--- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp
+++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
@@ -68,7 +68,7 @@ mlir::reifyResultShapes(OpBuilder &b, Operation *op,
bool ShapeAdaptor::hasRank() const {
if (val.isNull())
return false;
- if (auto t = val.dyn_cast<Type>())
+ if (auto t = llvm::dyn_cast_if_present<Type>(val))
return cast<ShapedType>(t).hasRank();
if (val.is<Attribute>())
return true;
@@ -78,7 +78,7 @@ bool ShapeAdaptor::hasRank() const {
Type ShapeAdaptor::getElementType() const {
if (val.isNull())
return nullptr;
- if (auto t = val.dyn_cast<Type>())
+ if (auto t = llvm::dyn_cast_if_present<Type>(val))
return cast<ShapedType>(t).getElementType();
if (val.is<Attribute>())
return nullptr;
@@ -87,10 +87,10 @@ Type ShapeAdaptor::getElementType() const {
void ShapeAdaptor::getDims(SmallVectorImpl<int64_t> &res) const {
assert(hasRank());
- if (auto t = val.dyn_cast<Type>()) {
+ if (auto t = llvm::dyn_cast_if_present<Type>(val)) {
ArrayRef<int64_t> vals = cast<ShapedType>(t).getShape();
res.assign(vals.begin(), vals.end());
- } else if (auto attr = val.dyn_cast<Attribute>()) {
+ } else if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
auto dattr = cast<DenseIntElementsAttr>(attr);
res.clear();
res.reserve(dattr.size());
@@ -110,9 +110,9 @@ void ShapeAdaptor::getDims(ShapedTypeComponents &res) const {
int64_t ShapeAdaptor::getDimSize(int index) const {
assert(hasRank());
- if (auto t = val.dyn_cast<Type>())
+ if (auto t = llvm::dyn_cast_if_present<Type>(val))
return cast<ShapedType>(t).getDimSize(index);
- if (auto attr = val.dyn_cast<Attribute>())
+ if (auto attr = llvm::dyn_cast_if_present<Attribute>(val))
return cast<DenseIntElementsAttr>(attr)
.getValues<APInt>()[index]
.getSExtValue();
@@ -122,9 +122,9 @@ int64_t ShapeAdaptor::getDimSize(int index) const {
int64_t ShapeAdaptor::getRank() const {
assert(hasRank());
- if (auto t = val.dyn_cast<Type>())
+ if (auto t = llvm::dyn_cast_if_present<Type>(val))
return cast<ShapedType>(t).getRank();
- if (auto attr = val.dyn_cast<Attribute>())
+ if (auto attr = llvm::dyn_cast_if_present<Attribute>(val))
return cast<DenseIntElementsAttr>(attr).size();
return val.get<ShapedTypeComponents *>()->getDims().size();
}
@@ -133,9 +133,9 @@ bool ShapeAdaptor::hasStaticShape() const {
if (!hasRank())
return false;
- if (auto t = val.dyn_cast<Type>())
+ if (auto t = llvm::dyn_cast_if_present<Type>(val))
return cast<ShapedType>(t).hasStaticShape();
- if (auto attr = val.dyn_cast<Attribute>()) {
+ if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
auto dattr = cast<DenseIntElementsAttr>(attr);
for (auto index : dattr.getValues<APInt>())
if (ShapedType::isDynamic(index.getSExtValue()))
@@ -149,10 +149,10 @@ bool ShapeAdaptor::hasStaticShape() const {
int64_t ShapeAdaptor::getNumElements() const {
assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
- if (auto t = val.dyn_cast<Type>())
+ if (auto t = llvm::dyn_cast_if_present<Type>(val))
return cast<ShapedType>(t).getNumElements();
- if (auto attr = val.dyn_cast<Attribute>()) {
+ if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
auto dattr = cast<DenseIntElementsAttr>(attr);
int64_t num = 1;
for (auto index : dattr.getValues<APInt>()) {
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index bc7d6b45cba57..3fab2a3f90896 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -26,14 +26,14 @@ namespace mlir {
/// If ofr is a constant integer or an IntegerAttr, return the integer.
static std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
// Case 1: Check for Constant integer.
- if (auto val = ofr.dyn_cast<Value>()) {
+ if (auto val = llvm::dyn_cast_if_present<Value>(ofr)) {
APSInt intVal;
if (matchPattern(val, m_ConstantInt(&intVal)))
return intVal.getSExtValue();
return std::nullopt;
}
// Case 2: Check for IntegerAttr.
- Attribute attr = ofr.dyn_cast<Attribute>();
+ Attribute attr = llvm::dyn_cast_if_present<Attribute>(ofr);
if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr))
return intAttr.getValue().getSExtValue();
return std::nullopt;
@@ -99,7 +99,7 @@ AffineExpr ValueBoundsConstraintSet::getExpr(Value value,
}
AffineExpr ValueBoundsConstraintSet::getExpr(OpFoldResult ofr) {
- if (Value value = ofr.dyn_cast<Value>())
+ if (Value value = llvm::dyn_cast_if_present<Value>(ofr))
return getExpr(value, /*dim=*/std::nullopt);
auto constInt = getConstantIntValue(ofr);
assert(constInt.has_value() && "expected Integer constant");
diff --git a/mlir/lib/Pass/PassDetail.h b/mlir/lib/Pass/PassDetail.h
index 35f1edd8b0e2a..727607146a68c 100644
--- a/mlir/lib/Pass/PassDetail.h
+++ b/mlir/lib/Pass/PassDetail.h
@@ -26,7 +26,8 @@ struct PassExecutionAction : public tracing::ActionImpl<PassExecutionAction> {
const Pass &getPass() const { return pass; }
Operation *getOp() const {
ArrayRef<IRUnit> irUnits = getContextIRUnits();
- return irUnits.empty() ? nullptr : irUnits[0].dyn_cast<Operation *>();
+ return irUnits.empty() ? nullptr
+ : llvm::dyn_cast_if_present<Operation *>(irUnits[0]);
}
public:
diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp
index 03557d9d283a3..fe9cae3b5439d 100644
--- a/mlir/lib/TableGen/Operator.cpp
+++ b/mlir/lib/TableGen/Operator.cpp
@@ -384,7 +384,7 @@ void Operator::populateTypeInferenceInfo(
if (getTrait("::mlir::OpTrait::SameOperandsAndResultType")) {
// Check for a non-variable length operand to use as the type anchor.
auto *operandI = llvm::find_if(arguments, [](const Argument &arg) {
- NamedTypeConstraint *operand = arg.dyn_cast<NamedTypeConstraint *>();
+ NamedTypeConstraint *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(arg);
return operand && !operand->isVariableLength();
});
if (operandI == arguments.end())
@@ -824,7 +824,7 @@ StringRef Operator::getAssemblyFormat() const {
void Operator::print(llvm::raw_ostream &os) const {
os << "op '" << getOperationName() << "'\n";
for (Argument arg : arguments) {
- if (auto *attr = arg.dyn_cast<NamedAttribute *>())
+ if (auto *attr = llvm::dyn_cast_if_present<NamedAttribute *>(arg))
os << "[attribute] " << attr->name << '\n';
else
os << "[operand] " << arg.get<NamedTypeConstraint *>()->name << '\n';
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index 8783be7da377e..96c727e429de9 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -131,7 +131,7 @@ convertBranchWeights(std::optional<ElementsAttr> weights,
return nullptr;
SmallVector<uint32_t> weightValues;
weightValues.reserve(weights->size());
- for (APInt weight : weights->cast<DenseIntElementsAttr>())
+ for (APInt weight : llvm::cast<DenseIntElementsAttr>(*weights))
weightValues.push_back(weight.getLimitedValue());
return llvm::MDBuilder(moduleTranslation.getLLVMContext())
.createBranchWeights(weightValues);
@@ -330,7 +330,7 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
auto *ty = llvm::cast<llvm::IntegerType>(
moduleTranslation.convertType(switchOp.getValue().getType()));
for (auto i :
- llvm::zip(switchOp.getCaseValues()->cast<DenseIntElementsAttr>(),
+ llvm::zip(llvm::cast<DenseIntElementsAttr>(*switchOp.getCaseValues()),
switchOp.getCaseDestinations()))
switchInst->addCase(
llvm::ConstantInt::get(ty, std::get<0>(i).getLimitedValue()),
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 2e67db55abc6b..05d6b7827d83a 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -730,8 +730,8 @@ Attribute ModuleImport::getConstantAsAttr(llvm::Constant *constant) {
// Returns the static shape of the provided type if possible.
auto getConstantShape = [&](llvm::Type *type) {
- return getBuiltinTypeForAttr(convertType(type))
- .dyn_cast_or_null<ShapedType>();
+ return llvm::dyn_cast_if_present<ShapedType>(getBuiltinTypeForAttr(convertType(type))
+ );
};
// Convert one-dimensional constant arrays or vectors that store 1/2/4/8-byte
@@ -798,8 +798,8 @@ Attribute ModuleImport::getConstantAsAttr(llvm::Constant *constant) {
// Convert zero aggregates.
if (auto *constZero = dyn_cast<llvm::ConstantAggregateZero>(constant)) {
- auto shape = getBuiltinTypeForAttr(convertType(constZero->getType()))
- .dyn_cast_or_null<ShapedType>();
+ auto shape = llvm::dyn_cast_if_present<ShapedType>(getBuiltinTypeForAttr(convertType(constZero->getType()))
+ );
if (!shape)
return {};
// Convert zero aggregates with a static shape to splat elements attributes.
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index d1d23b6c6f647..772721e31e1ce 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -69,7 +69,7 @@ translateDataLayout(DataLayoutSpecInterface attribute,
std::string llvmDataLayout;
llvm::raw_string_ostream layoutStream(llvmDataLayout);
for (DataLayoutEntryInterface entry : attribute.getEntries()) {
- auto key = entry.getKey().dyn_cast<StringAttr>();
+ auto key = llvm::dyn_cast_if_present<StringAttr>(entry.getKey());
if (!key)
continue;
if (key.getValue() == DLTIDialect::kDataLayoutEndiannessKey) {
@@ -108,7 +108,7 @@ translateDataLayout(DataLayoutSpecInterface attribute,
// specified in entries. Where possible, data layout queries are used instead
// of directly inspecting the entries.
for (DataLayoutEntryInterface entry : attribute.getEntries()) {
- auto type = entry.getKey().dyn_cast<Type>();
+ auto type = llvm::dyn_cast_if_present<Type>(entry.getKey());
if (!type)
continue;
// Data layout for the index type is irrelevant at this point.
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index a0b8faa7e5eab..b84d1d9c21879 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -285,7 +285,7 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
static_cast<::mlir::spirv::LinkageType>(words[wordIndex++]));
auto linkageAttr = opBuilder.getAttr<::mlir::spirv::LinkageAttributesAttr>(
linkageName, linkageTypeAttr);
- decorations[words[0]].set(symbol, linkageAttr.dyn_cast<Attribute>());
+ decorations[words[0]].set(symbol, llvm::dyn_cast<Attribute>(linkageAttr));
break;
}
case spirv::Decoration::Aliased:
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index 582f02f37456f..44538c38a41b8 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -639,7 +639,7 @@ Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
if (values) {
for (auto &intVal : values.getValue()) {
operands.push_back(static_cast<uint32_t>(
- intVal.cast<IntegerAttr>().getValue().getZExtValue()));
+ llvm::cast<IntegerAttr>(intVal).getValue().getZExtValue()));
}
}
encodeInstructionInto(executionModes, spirv::Opcode::OpExecutionMode,
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 9089ebc84fa60..f32f6e8242063 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -222,7 +222,7 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
case spirv::Decoration::LinkageAttributes: {
// Get the value of the Linkage Attributes
// e.g., LinkageAttributes=["linkageName", linkageType].
- auto linkageAttr = attr.getValue().dyn_cast<spirv::LinkageAttributesAttr>();
+ auto linkageAttr = llvm::dyn_cast<spirv::LinkageAttributesAttr>(attr.getValue());
auto linkageName = linkageAttr.getLinkageName();
auto linkageType = linkageAttr.getLinkageType().getValue();
// Encode the Linkage Name (string literal to uint32_t).
diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
index 4c6aaf34ba05c..8efc614767317 100644
--- a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
+++ b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
@@ -136,7 +136,7 @@ struct PDLIndexSymbol {
/// Return the location of the definition of this symbol.
SMRange getDefLoc() const {
- if (const ast::Decl *decl = definition.dyn_cast<const ast::Decl *>()) {
+ if (const ast::Decl *decl = llvm::dyn_cast_if_present<const ast::Decl *>(definition)) {
const ast::Name *declName = decl->getName();
return declName ? declName->getLoc() : decl->getLoc();
}
@@ -465,7 +465,7 @@ PDLDocument::findHover(const lsp::URIForFile &uri,
return std::nullopt;
// Add hover for operation names.
- if (const auto *op = symbol->definition.dyn_cast<const ods::Operation *>())
+ if (const auto *op = llvm::dyn_cast_if_present<const ods::Operation *>(symbol->definition))
return buildHoverForOpName(op, hoverRange);
const auto *decl = symbol->definition.get<const ast::Decl *>();
return findHover(decl, hoverRange);
diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp
index 57ccb3b2057c1..a7dcd2b0e87d2 100644
--- a/mlir/lib/Transforms/Inliner.cpp
+++ b/mlir/lib/Transforms/Inliner.cpp
@@ -373,7 +373,7 @@ static void collectCallOps(iterator_range<Region::iterator> blocks,
#ifndef NDEBUG
static std::string getNodeName(CallOpInterface op) {
- if (auto sym = op.getCallableForCallee().dyn_cast<SymbolRefAttr>())
+ if (auto sym = llvm::dyn_cast_if_present<SymbolRefAttr>(op.getCallableForCallee()))
return debugString(op);
return "_unnamed_callee_";
}
diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
index 827c0ad4290b7..e9e59cfeed79e 100644
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -272,7 +272,7 @@ OperationFolder::processFoldResults(Operation *op,
assert(!foldResults[i].isNull() && "expected valid OpFoldResult");
// Check if the result was an SSA value.
- if (auto repl = foldResults[i].dyn_cast<Value>()) {
+ if (auto repl = llvm::dyn_cast_if_present<Value>(foldResults[i])) {
if (repl.getType() != op->getResult(i).getType()) {
results.clear();
return failure();
diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp
index 46126b0c5f400..f0768692d6ab1 100644
--- a/mlir/lib/Transforms/Utils/InliningUtils.cpp
+++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp
@@ -266,7 +266,7 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
// Remap the locations of the inlined operations if a valid source location
// was provided.
- if (inlineLoc && !inlineLoc->isa<UnknownLoc>())
+ if (inlineLoc && !llvm::isa<UnknownLoc>(*inlineLoc))
remapInlinedLocations(newBlocks, *inlineLoc);
// If the blocks were moved in-place, make sure to remap any necessary
diff --git a/mlir/test/lib/Analysis/TestDataFlowFramework.cpp b/mlir/test/lib/Analysis/TestDataFlowFramework.cpp
index 5a66f1916515d..ed361b5a0e270 100644
--- a/mlir/test/lib/Analysis/TestDataFlowFramework.cpp
+++ b/mlir/test/lib/Analysis/TestDataFlowFramework.cpp
@@ -115,11 +115,11 @@ LogicalResult FooAnalysis::initialize(Operation *top) {
}
LogicalResult FooAnalysis::visit(ProgramPoint point) {
- if (auto *op = point.dyn_cast<Operation *>()) {
+ if (auto *op = llvm::dyn_cast_if_present<Operation *>(point)) {
visitOperation(op);
return success();
}
- if (auto *block = point.dyn_cast<Block *>()) {
+ if (auto *block = llvm::dyn_cast_if_present<Block *>(point)) {
visitBlock(block);
return success();
}
diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
index db3b9a1decf07..ad017cef1b9ba 100644
--- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
@@ -161,7 +161,7 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
}
// Replace the op with the reified bound.
- if (auto val = reified->dyn_cast<Value>()) {
+ if (auto val = llvm::dyn_cast_if_present<Value>(*reified)) {
rewriter.replaceOp(op, val);
return WalkResult::skip();
}
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 31b7504f7b859..98e88fe01019d 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -1134,7 +1134,7 @@ void OpEmitter::genPropertiesSupport() {
)decl";
for (const auto &attrOrProp : attrOrProperties) {
if (const auto *namedProperty =
- attrOrProp.dyn_cast<const NamedProperty *>()) {
+ llvm::dyn_cast_if_present<const NamedProperty *>(attrOrProp)) {
StringRef name = namedProperty->name;
auto &prop = namedProperty->prop;
FmtContext fctx;
@@ -1145,7 +1145,7 @@ void OpEmitter::genPropertiesSupport() {
.addSubst("_diag", propertyDiag)),
name);
} else {
- const auto *namedAttr = attrOrProp.dyn_cast<const AttributeMetadata *>();
+ const auto *namedAttr = llvm::dyn_cast_if_present<const AttributeMetadata *>(attrOrProp);
StringRef name = namedAttr->attrName;
setPropMethod << formatv(R"decl(
{{
@@ -1187,7 +1187,7 @@ void OpEmitter::genPropertiesSupport() {
)decl";
for (const auto &attrOrProp : attrOrProperties) {
if (const auto *namedProperty =
- attrOrProp.dyn_cast<const NamedProperty *>()) {
+ llvm::dyn_cast_if_present<const NamedProperty *>(attrOrProp)) {
StringRef name = namedProperty->name;
auto &prop = namedProperty->prop;
FmtContext fctx;
@@ -1198,7 +1198,7 @@ void OpEmitter::genPropertiesSupport() {
.addSubst("_storage", propertyStorage)));
continue;
}
- const auto *namedAttr = attrOrProp.dyn_cast<const AttributeMetadata *>();
+ const auto *namedAttr = llvm::dyn_cast_if_present<const AttributeMetadata *>(attrOrProp);
StringRef name = namedAttr->attrName;
getPropMethod << formatv(R"decl(
{{
@@ -1225,7 +1225,7 @@ void OpEmitter::genPropertiesSupport() {
)decl";
for (const auto &attrOrProp : attrOrProperties) {
if (const auto *namedProperty =
- attrOrProp.dyn_cast<const NamedProperty *>()) {
+ llvm::dyn_cast_if_present<const NamedProperty *>(attrOrProp)) {
StringRef name = namedProperty->name;
auto &prop = namedProperty->prop;
FmtContext fctx;
@@ -1238,13 +1238,13 @@ void OpEmitter::genPropertiesSupport() {
llvm::interleaveComma(
attrOrProperties, hashMethod, [&](const ConstArgument &attrOrProp) {
if (const auto *namedProperty =
- attrOrProp.dyn_cast<const NamedProperty *>()) {
+ llvm::dyn_cast_if_present<const NamedProperty *>(attrOrProp)) {
hashMethod << "\n hash_" << namedProperty->name << "(prop."
<< namedProperty->name << ")";
return;
}
const auto *namedAttr =
- attrOrProp.dyn_cast<const AttributeMetadata *>();
+ llvm::dyn_cast_if_present<const AttributeMetadata *>(attrOrProp);
StringRef name = namedAttr->attrName;
hashMethod << "\n llvm::hash_value(prop." << name
<< ".getAsOpaquePointer())";
@@ -1266,7 +1266,7 @@ void OpEmitter::genPropertiesSupport() {
)decl";
for (const auto &attrOrProp : attrOrProperties) {
if (const auto *namedAttr =
- attrOrProp.dyn_cast<const AttributeMetadata *>()) {
+ llvm::dyn_cast_if_present<const AttributeMetadata *>(attrOrProp)) {
StringRef name = namedAttr->attrName;
getInherentAttrMethod << formatv(getInherentAttrMethodFmt, name);
setInherentAttrMethod << formatv(setInherentAttrMethodFmt, name);
@@ -1281,7 +1281,7 @@ void OpEmitter::genPropertiesSupport() {
// syntax. This method verifies the constraint on the properties attributes
// before they are set, since dyn_cast<> will silently omit failures.
for (const auto &attrOrProp : attrOrProperties) {
- const auto *namedAttr = attrOrProp.dyn_cast<const AttributeMetadata *>();
+ const auto *namedAttr = llvm::dyn_cast_if_present<const AttributeMetadata *>(attrOrProp);
if (!namedAttr || !namedAttr->constraint)
continue;
Attribute attr = *namedAttr->constraint;
@@ -2472,7 +2472,7 @@ void OpEmitter::buildParamList(SmallVectorImpl<MethodParameter> ¶mList,
// Calculate the start index from which we can attach default values in the
// builder declaration.
for (int i = op.getNumArgs() - 1; i >= 0; --i) {
- auto *namedAttr = op.getArg(i).dyn_cast<tblgen::NamedAttribute *>();
+ auto *namedAttr = llvm::dyn_cast_if_present<tblgen::NamedAttribute *>(op.getArg(i));
if (!namedAttr || !namedAttr->attr.hasDefaultValue())
break;
@@ -2502,7 +2502,7 @@ void OpEmitter::buildParamList(SmallVectorImpl<MethodParameter> ¶mList,
for (int i = 0, e = op.getNumArgs(), numOperands = 0; i < e; ++i) {
Argument arg = op.getArg(i);
- if (const auto *operand = arg.dyn_cast<NamedTypeConstraint *>()) {
+ if (const auto *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(arg)) {
StringRef type;
if (operand->isVariadicOfVariadic())
type = "::llvm::ArrayRef<::mlir::ValueRange>";
@@ -2515,7 +2515,7 @@ void OpEmitter::buildParamList(SmallVectorImpl<MethodParameter> ¶mList,
operand->isOptional());
continue;
}
- if (const auto *operand = arg.dyn_cast<NamedProperty *>()) {
+ if (const auto *operand = llvm::dyn_cast_if_present<NamedProperty *>(arg)) {
// TODO
continue;
}
@@ -3442,7 +3442,7 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
llvm::raw_string_ostream comparatorOs(comparator);
for (const auto &attrOrProp : attrOrProperties) {
if (const auto *namedProperty =
- attrOrProp.dyn_cast<const NamedProperty *>()) {
+ llvm::dyn_cast_if_present<const NamedProperty *>(attrOrProp)) {
StringRef name = namedProperty->name;
if (name.empty())
report_fatal_error("missing name for property");
@@ -3476,7 +3476,7 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
.addSubst("_storage", propertyStorage)));
continue;
}
- const auto *namedAttr = attrOrProp.dyn_cast<const AttributeMetadata *>();
+ const auto *namedAttr = llvm::dyn_cast_if_present<const AttributeMetadata *>(attrOrProp);
const Attribute *attr = nullptr;
if (namedAttr->constraint)
attr = &*namedAttr->constraint;
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 9f8fa4310334f..3e6db51b59758 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -265,11 +265,11 @@ struct OperationFormat {
/// Get the variable this type is resolved to, or nullptr.
const NamedTypeConstraint *getVariable() const {
- return resolver.dyn_cast<const NamedTypeConstraint *>();
+ return llvm::dyn_cast_if_present<const NamedTypeConstraint *>(resolver);
}
/// Get the attribute this type is resolved to, or nullptr.
const NamedAttribute *getAttribute() const {
- return resolver.dyn_cast<const NamedAttribute *>();
+ return llvm::dyn_cast_if_present<const NamedAttribute *>(resolver);
}
/// Get the transformer for the type of the variable, or std::nullopt.
std::optional<StringRef> getVarTransformer() const {
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 330268dc8183c..6e4a9e347cd6d 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -674,7 +674,7 @@ populateBuilderLinesAttr(const Operator &op,
builderLines.push_back("_ods_context = _ods_get_default_loc_context(loc)");
for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
Argument arg = op.getArg(i);
- auto *attribute = arg.dyn_cast<NamedAttribute *>();
+ auto *attribute = llvm::dyn_cast_if_present<NamedAttribute *>(arg);
if (!attribute)
continue;
@@ -914,9 +914,9 @@ static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) {
// - default-valued named attributes
// - optional operands
Argument a = op.getArg(builderArgIndex - numResultArgs);
- if (auto *nattr = a.dyn_cast<NamedAttribute *>())
+ if (auto *nattr = llvm::dyn_cast_if_present<NamedAttribute *>(a))
return (nattr->attr.isOptional() || nattr->attr.hasDefaultValue());
- if (auto *ntype = a.dyn_cast<NamedTypeConstraint *>())
+ if (auto *ntype = llvm::dyn_cast_if_present<NamedTypeConstraint *>(a))
return ntype->isOptional();
return false;
};
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 8a04cc9f5be11..9463c4f9c7e29 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -595,7 +595,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
++opArgIdx;
continue;
}
- if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
+ if (auto *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(opArg)) {
if (operand->isVariableLength()) {
auto error = formatv("use nested DAG construct to match op {0}'s "
"variadic operand #{1} unsupported now",
@@ -1524,7 +1524,7 @@ void PatternEmitter::createSeparateLocalVarsForOpArgs(
int valueIndex = 0; // An index for uniquing local variable names.
for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
const auto *operand =
- resultOp.getArg(argIndex).dyn_cast<NamedTypeConstraint *>();
+ llvm::dyn_cast_if_present<NamedTypeConstraint *>(resultOp.getArg(argIndex));
// We do not need special handling for attributes.
if (!operand)
continue;
@@ -1579,7 +1579,7 @@ void PatternEmitter::supplyValuesForOpArgs(
Argument opArg = resultOp.getArg(argIndex);
// Handle the case of operand first.
- if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
+ if (auto *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(opArg)) {
if (!operand->name.empty())
os << "/*" << operand->name << "=*/";
os << childNodeNames.lookup(argIndex);
diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index 4caaf1df23334..7bf7755a8a52f 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -926,7 +926,7 @@ static void emitOperandDeserialization(const Operator &op, ArrayRef<SMLoc> loc,
// Process operands/attributes
for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) {
auto argument = op.getArg(i);
- if (auto *valueArg = argument.dyn_cast<NamedTypeConstraint *>()) {
+ if (auto *valueArg = llvm::dyn_cast_if_present<NamedTypeConstraint *>(argument)) {
if (valueArg->isVariableLength()) {
if (i != e - 1) {
PrintFatalError(loc, "SPIR-V ops can have Variadic<..> or "
diff --git a/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp b/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp
index 6601f32f3288a..e9ba28e854bb4 100644
--- a/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp
+++ b/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp
@@ -159,7 +159,7 @@ struct OpWithLayout : public Op<OpWithLayout, DataLayoutOpInterface::Trait> {
// Handle built-in types that are not handled by the default process.
if (auto iType = dyn_cast<IntegerType>(type)) {
for (DataLayoutEntryInterface entry : params)
- if (entry.getKey().dyn_cast<Type>() == type)
+ if (llvm::dyn_cast_if_present<Type>(entry.getKey()) == type)
return 8 *
cast<IntegerAttr>(entry.getValue()).getValue().getZExtValue();
return 8 * iType.getIntOrFloatBitWidth();
More information about the Mlir-commits
mailing list