[Mlir-commits] [mlir] c1fa60b - [mlir] Update method cast calls to function calls
Tres Popp
llvmlistbot at llvm.org
Fri May 12 02:52:47 PDT 2023
Author: Tres Popp
Date: 2023-05-12T11:21:30+02:00
New Revision: c1fa60b4cde512964544ab66404dea79dbc5dcb4
URL: https://github.com/llvm/llvm-project/commit/c1fa60b4cde512964544ab66404dea79dbc5dcb4
DIFF: https://github.com/llvm/llvm-project/commit/c1fa60b4cde512964544ab66404dea79dbc5dcb4.diff
LOG: [mlir] Update method cast calls to function calls
The MLIR classes Type/Attribute/Operation/Op/Value support
cast/dyn_cast/isa/dyn_cast_or_null functionality through LLVM's doCast
functionality, in addition to defining member methods with the same names.
This change begins migrating uses of those methods to the corresponding
free-function calls, which were agreed upon as the more consistent style.
Note that some classes, such as AffineExpr, still define only the member
methods; this change does not add support for free-function cast/isa
calls on them.
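To make the two spellings concrete, here is a minimal, self-contained sketch that does not use MLIR. `mini::isa`/`mini::cast`/`mini::dyn_cast`, `Type`, `IntType` and `FloatType` are hypothetical stand-ins, and the free functions dispatch only through a static `classof` rather than LLVM's `CastInfo` machinery. The handle exposes both the member spelling (`t.isa<IntType>()`) and the free-function spelling (`mini::isa<IntType>(t)`, which corresponds to `llvm::isa<IntType>(t)` in MLIR). The last two helpers show one practical difference: when the handle's type is a template parameter, only the member spelling needs the `.template` disambiguator.

```cpp
#include <cassert>

// Hypothetical, heavily simplified stand-ins for llvm::isa/cast/dyn_cast.
// The real functions dispatch through llvm::CastInfo; these only require the
// target type to provide a static classof() and a constructor from storage.
namespace mini {
template <typename To, typename From> bool isa(const From &from) {
  return To::classof(from);
}
template <typename To, typename From> To cast(const From &from) {
  assert(isa<To>(from) && "cast<To>() argument of incompatible type!");
  return To(from.getImpl());
}
template <typename To, typename From> To dyn_cast(const From &from) {
  return isa<To>(from) ? To(from.getImpl()) : To();
}
} // namespace mini

// Fake uniqued storage; `kind` plays the role of a type ID.
enum Kind { IntKind, FloatKind };
struct Storage {
  Kind kind;
};

// A value-semantic handle in the spirit of mlir::Type.
class Type {
public:
  Type() = default;
  explicit Type(const Storage *storage) : impl(storage) {}
  const Storage *getImpl() const { return impl; }
  explicit operator bool() const { return impl != nullptr; }

  // Member spellings (the style being migrated away from), forwarding to the
  // free functions.
  template <typename U> bool isa() const { return mini::isa<U>(*this); }
  template <typename U> U cast() const { return mini::cast<U>(*this); }
  template <typename U> U dyn_cast() const { return mini::dyn_cast<U>(*this); }

private:
  const Storage *impl = nullptr;
};

class IntType : public Type {
public:
  using Type::Type;
  static bool classof(Type t) { return t && t.getImpl()->kind == IntKind; }
};

class FloatType : public Type {
public:
  using Type::Type;
  static bool classof(Type t) { return t && t.getImpl()->kind == FloatKind; }
};

// When the handle's type is a template parameter, the member spelling needs
// the `.template` disambiguator; the free-function spelling does not.
template <typename To, typename Handle> bool isaMemberStyle(const Handle &h) {
  return h.template isa<To>();
}
template <typename To, typename Handle> bool isaFreeStyle(const Handle &h) {
  return mini::isa<To>(h);
}
```

Both spellings give the same answers here; the free-function form is simply the one that generic code can use uniformly.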
Context:
* https://mlir.llvm.org/deprecation/ at "Use the free function variants for dyn_cast/cast/isa/…"
* Original discussion at https://discourse.llvm.org/t/preferred-casting-style-going-forward/68443
Implementation:
This follows a previous patch that updated calls
`op.cast<T>()` to `cast<T>(op)`. However, some cases could not use an
unprefixed `cast` call, either because a variable named `cast` was in
scope or because the call occurred inside a class definition, where the
name resolves to the member method. All C++ files that did not work
automatically with `cast<T>()` are updated here to use `llvm::cast` (and
likewise for the other casting functions), so that they can easily be
updated with a find-and-replace once the methods are removed.
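The name-lookup problem described above can be reproduced in plain C++. In this sketch the names are hypothetical again: `mini::cast` stands in for `llvm::cast`, and the global `using mini::cast;` stands in for the using-declarations that make the unprefixed `cast` available in MLIR code. `Derived` inherits a member template `cast` from `Handle`, much as `TypedValue` inherits the `cast` method of `Value` in `mlir/include/mlir/IR/Value.h`. Inside `Derived`, unqualified lookup of `cast` finds that inherited member first and stops there (argument-dependent lookup is not performed once ordinary lookup finds a class member), so `cast<IntHandle>(*this)` would try to call the zero-argument member and fail to compile. A local variable named `cast` breaks the unprefixed call in a similar way. Qualifying the call avoids both.

```cpp
#include <cassert>

// Hypothetical free-function cast standing in for llvm::cast.
namespace mini {
template <typename To, typename From> To cast(const From &from) {
  assert(To::classof(from) && "cast<To>() argument of incompatible type!");
  return To(from.getKind());
}
} // namespace mini

// Makes the unprefixed `cast` visible at namespace scope.
using mini::cast;

enum Kind { IntKind, FloatKind };

// Base handle that, like mlir::Value before the migration, still declares a
// *member* template named `cast`.
class Handle {
public:
  explicit Handle(Kind k) : kind(k) {}
  Kind getKind() const { return kind; }
  template <typename U> U cast() const { return mini::cast<U>(*this); }

private:
  Kind kind;
};

class IntHandle : public Handle {
public:
  using Handle::Handle;
  static bool classof(const Handle &h) { return h.getKind() == IntKind; }
};

// Inside this class, unqualified lookup of `cast` finds Handle::cast.
class Derived : public Handle {
public:
  using Handle::Handle;
  IntHandle asInt() const {
    // return cast<IntHandle>(*this);
    //   ^ Does not compile: lookup stops at the inherited member template
    //     Handle::cast, which takes no arguments, and argument-dependent
    //     lookup is not performed once a class member has been found.
    return mini::cast<IntHandle>(*this); // Qualified: the free function.
  }
};

// A local variable named `cast` hides the using-declared template.
Kind shadowed(const Handle &h) {
  int cast = 0;
  (void)cast;
  // `cast<IntHandle>(h)` would now parse as a comparison and fail; the
  // qualified name is unaffected.
  return mini::cast<IntHandle>(h).getKind();
}
```

Outside such classes and without a shadowing variable, the unprefixed `cast<IntHandle>(x)` works as expected, which is why only the remaining files needed the explicit `llvm::` prefix.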
See https://github.com/llvm/llvm-project/compare/main...tpopp:llvm-project:tidy-cast-check
for the clang-tidy check that was used. To reproduce this patch, first change
the function names that the check prints in its fixes to include the `llvm::`
prefix.
One can then run the following:
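```
# Build clang-tidy (including the misc-cast-functions check from the branch above).
ninja -C $BUILD_DIR clang-tidy
# Run only that check over the MLIR sources and headers, applying fixes in place.
run-clang-tidy -clang-tidy-binary=$BUILD_DIR/bin/clang-tidy -checks='-*,misc-cast-functions' \
  -export-fixes /tmp/cast/casts.yaml mlir/* \
  -header-filter=mlir/ -fix
# Delete generated .inc files so they are regenerated (in bash, '**' recurses only with 'shopt -s globstar').
rm -rf $BUILD_DIR/tools/mlir/**/*.inc
```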
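The check itself operates on the Clang AST. Purely to illustrate the shape of the rewrite it performs, here is a toy, regex-based sketch; `rewriteMemberCasts` is a hypothetical helper, not part of the check. It deliberately handles only receivers that are plain identifiers and leaves chained receivers (`foo().cast<T>()`, `a.b.cast<T>()`), `.template` forms, and nested template arguments untouched.

```cpp
#include <regex>
#include <string>

// Toy, text-based illustration only (hypothetical helper): rewrites
// `recv.cast<T>()`, `recv.dyn_cast<T>()`, `recv.dyn_cast_or_null<T>()` and
// `recv.isa<T...>()` into `llvm::cast<T>(recv)` etc. when `recv` is a plain
// identifier. Anything it cannot match safely is left unchanged.
std::string rewriteMemberCasts(const std::string &source) {
  // $1 = preceding character (or start of input), $2 = receiver,
  // $3 = cast kind, $4 = template arguments (no nested <> or ()).
  static const std::regex pattern(
      R"re((^|[^\w.])(\w+)\.(dyn_cast_or_null|dyn_cast|cast|isa)<([^<>()]+)>\(\))re");
  return std::regex_replace(source, pattern, "$1llvm::$3<$4>($2)");
}
```

For example, it turns `if (type.isa<IntegerType, IndexType>())` into `if (llvm::isa<IntegerType, IndexType>(type))`, the same change made in the `Matchers.h` hunk below.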
Differential Revision: https://reviews.llvm.org/D150348
Added:
Modified:
mlir/include/mlir/IR/Builders.h
mlir/include/mlir/IR/BuiltinAttributes.h
mlir/include/mlir/IR/BuiltinTypes.h
mlir/include/mlir/IR/FunctionInterfaces.h
mlir/include/mlir/IR/Location.h
mlir/include/mlir/IR/Matchers.h
mlir/include/mlir/IR/OpImplementation.h
mlir/include/mlir/IR/Operation.h
mlir/include/mlir/IR/Value.h
mlir/lib/CAPI/IR/BuiltinAttributes.cpp
mlir/lib/CAPI/IR/BuiltinTypes.cpp
mlir/lib/CAPI/IR/IR.cpp
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
mlir/lib/Dialect/Affine/IR/AffineOps.cpp
mlir/lib/Dialect/Arith/IR/ArithOps.cpp
mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
mlir/lib/Dialect/Async/IR/Async.cpp
mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp
mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
mlir/lib/Dialect/DLTI/DLTI.cpp
mlir/lib/Dialect/EmitC/IR/EmitC.cpp
mlir/lib/Dialect/Func/IR/FuncOps.cpp
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
mlir/lib/Dialect/Index/IR/IndexOps.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp
mlir/lib/Dialect/MLProgram/IR/MLProgramDialect.cpp
mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
mlir/lib/Dialect/PDL/IR/PDL.cpp
mlir/lib/Dialect/PDL/IR/PDLTypes.cpp
mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
mlir/lib/Dialect/Quant/IR/TypeParser.cpp
mlir/lib/Dialect/SCF/IR/SCF.cpp
mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp
mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/lib/Dialect/Transform/IR/TransformTypes.cpp
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
mlir/lib/IR/AffineMap.cpp
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/AttributeDetail.h
mlir/lib/IR/Attributes.cpp
mlir/lib/IR/Builders.cpp
mlir/lib/IR/BuiltinAttributeInterfaces.cpp
mlir/lib/IR/BuiltinAttributes.cpp
mlir/lib/IR/BuiltinDialect.cpp
mlir/lib/IR/BuiltinDialectBytecode.cpp
mlir/lib/IR/BuiltinTypes.cpp
mlir/lib/IR/Diagnostics.cpp
mlir/lib/IR/ExtensibleDialect.cpp
mlir/lib/IR/FunctionImplementation.cpp
mlir/lib/IR/FunctionInterfaces.cpp
mlir/lib/IR/Location.cpp
mlir/lib/IR/Operation.cpp
mlir/lib/IR/OperationSupport.cpp
mlir/lib/IR/SymbolTable.cpp
mlir/lib/IR/TypeUtilities.cpp
mlir/lib/IR/Verifier.cpp
mlir/test/lib/Dialect/Test/TestAttributes.cpp
mlir/test/lib/Dialect/Test/TestTypes.cpp
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
mlir/test/python/lib/PythonTestCAPI.cpp
mlir/unittests/TableGen/EnumsGenTest.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 1b0f4bfb3f629..4dbeb418099d1 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -403,7 +403,7 @@ class OpBuilder : public Builder {
if (Operation *op = val.getDefiningOp()) {
setInsertionPointAfter(op);
} else {
- auto blockArg = val.cast<BlockArgument>();
+ auto blockArg = llvm::cast<BlockArgument>(val);
setInsertionPointToStart(blockArg.getOwner());
}
}
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index 496c197e47152..7c4136021cb52 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -389,7 +389,7 @@ class DenseElementsAttr : public Attribute {
!std::is_same<Attribute, T>::value,
T>
getSplatValue() const {
- return getSplatValue<Attribute>().template cast<T>();
+ return llvm::cast<T>(getSplatValue<Attribute>());
}
/// Try to get an iterator of the given type to the start of the held element
@@ -510,7 +510,7 @@ class DenseElementsAttr : public Attribute {
T>::mapped_iterator_base;
/// Map the element to the iterator result type.
- T mapElement(Attribute attr) const { return attr.cast<T>(); }
+ T mapElement(Attribute attr) const { return llvm::cast<T>(attr); }
};
template <typename T, typename = DerivedAttrValueTemplateCheckT<T>>
FailureOr<iterator_range_impl<DerivedAttributeElementIterator<T>>>
@@ -684,7 +684,7 @@ class SplatElementsAttr : public DenseElementsAttr {
/// Method for support type inquiry through isa, cast and dyn_cast.
static bool classof(Attribute attr) {
- auto denseAttr = attr.dyn_cast<DenseElementsAttr>();
+ auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(attr);
return denseAttr && denseAttr.isSplat();
}
};
@@ -887,7 +887,7 @@ class FlatSymbolRefAttr : public SymbolRefAttr {
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(Attribute attr) {
- SymbolRefAttr refAttr = attr.dyn_cast<SymbolRefAttr>();
+ SymbolRefAttr refAttr = llvm::dyn_cast<SymbolRefAttr>(attr);
return refAttr && refAttr.getNestedReferences().empty();
}
@@ -912,14 +912,13 @@ class DenseFPElementsAttr : public DenseIntOrFPElementsAttr {
/// simply wraps the DenseElementsAttr::get calls.
template <typename Arg>
static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg) {
- return DenseElementsAttr::get(type, llvm::ArrayRef(arg))
- .template cast<DenseFPElementsAttr>();
+ return llvm::cast<DenseFPElementsAttr>(
+ DenseElementsAttr::get(type, llvm::ArrayRef(arg)));
}
template <typename T>
static DenseFPElementsAttr get(const ShapedType &type,
const std::initializer_list<T> &list) {
- return DenseElementsAttr::get(type, list)
- .template cast<DenseFPElementsAttr>();
+ return llvm::cast<DenseFPElementsAttr>(DenseElementsAttr::get(type, list));
}
/// Generates a new DenseElementsAttr by mapping each value attribute, and
@@ -954,14 +953,13 @@ class DenseIntElementsAttr : public DenseIntOrFPElementsAttr {
/// simply wraps the DenseElementsAttr::get calls.
template <typename Arg>
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg) {
- return DenseElementsAttr::get(type, llvm::ArrayRef(arg))
- .template cast<DenseIntElementsAttr>();
+ return llvm::cast<DenseIntElementsAttr>(
+ DenseElementsAttr::get(type, llvm::ArrayRef(arg)));
}
template <typename T>
static DenseIntElementsAttr get(const ShapedType &type,
const std::initializer_list<T> &list) {
- return DenseElementsAttr::get(type, list)
- .template cast<DenseIntElementsAttr>();
+ return llvm::cast<DenseIntElementsAttr>(DenseElementsAttr::get(type, list));
}
/// Generates a new DenseElementsAttr by mapping each value attribute, and
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index f26465ef1d667..4fc82dd7a8e9d 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -367,20 +367,21 @@ SliceVerificationResult isRankReducedType(ShapedType originalType,
//===----------------------------------------------------------------------===//
inline bool BaseMemRefType::classof(Type type) {
- return type.isa<MemRefType, UnrankedMemRefType>();
+ return llvm::isa<MemRefType, UnrankedMemRefType>(type);
}
inline bool BaseMemRefType::isValidElementType(Type type) {
return type.isIntOrIndexOrFloat() ||
- type.isa<ComplexType, MemRefType, VectorType, UnrankedMemRefType>() ||
- type.isa<MemRefElementTypeInterface>();
+ llvm::isa<ComplexType, MemRefType, VectorType, UnrankedMemRefType>(
+ type) ||
+ llvm::isa<MemRefElementTypeInterface>(type);
}
inline bool FloatType::classof(Type type) {
- return type
- .isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
- Float8E4M3FNUZType, Float8E4M3B11FNUZType, BFloat16Type, Float16Type,
- Float32Type, Float64Type, Float80Type, Float128Type>();
+ return llvm::isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
+ Float8E4M3FNUZType, Float8E4M3B11FNUZType, BFloat16Type,
+ Float16Type, Float32Type, Float64Type, Float80Type,
+ Float128Type>(type);
}
inline FloatType FloatType::getFloat8E5M2(MLIRContext *ctx) {
@@ -428,7 +429,7 @@ inline FloatType FloatType::getF128(MLIRContext *ctx) {
}
inline bool TensorType::classof(Type type) {
- return type.isa<RankedTensorType, UnrankedTensorType>();
+ return llvm::isa<RankedTensorType, UnrankedTensorType>(type);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/FunctionInterfaces.h b/mlir/include/mlir/IR/FunctionInterfaces.h
index 3beb3db4e5662..e813cb8f03903 100644
--- a/mlir/include/mlir/IR/FunctionInterfaces.h
+++ b/mlir/include/mlir/IR/FunctionInterfaces.h
@@ -178,7 +178,7 @@ LogicalResult verifyTrait(ConcreteOp op) {
}
for (unsigned i = 0; i != numArgs; ++i) {
DictionaryAttr argAttrs =
- allArgAttrs[i].dyn_cast_or_null<DictionaryAttr>();
+ llvm::dyn_cast_or_null<DictionaryAttr>(allArgAttrs[i]);
if (!argAttrs) {
return op.emitOpError() << "expects argument attribute dictionary "
"to be a DictionaryAttr, but got `"
@@ -209,7 +209,7 @@ LogicalResult verifyTrait(ConcreteOp op) {
}
for (unsigned i = 0; i != numResults; ++i) {
DictionaryAttr resultAttrs =
- allResultAttrs[i].dyn_cast_or_null<DictionaryAttr>();
+ llvm::dyn_cast_or_null<DictionaryAttr>(allResultAttrs[i]);
if (!resultAttrs) {
return op.emitOpError() << "expects result attribute dictionary "
"to be a DictionaryAttr, but got `"
diff --git a/mlir/include/mlir/IR/Location.h b/mlir/include/mlir/IR/Location.h
index 63b12899e2499..d4268e804f4f7 100644
--- a/mlir/include/mlir/IR/Location.h
+++ b/mlir/include/mlir/IR/Location.h
@@ -148,12 +148,12 @@ class FusedLocWith : public FusedLoc {
/// Return the metadata associated with this fused location.
MetadataT getMetadata() const {
- return FusedLoc::getMetadata().template cast<MetadataT>();
+ return llvm::cast<MetadataT>(FusedLoc::getMetadata());
}
/// Support llvm style casting.
static bool classof(Attribute attr) {
- auto fusedLoc = attr.dyn_cast<FusedLoc>();
+ auto fusedLoc = llvm::dyn_cast<FusedLoc>(attr);
return fusedLoc && fusedLoc.getMetadata().isa_and_nonnull<MetadataT>();
}
};
diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index 4dbc623916acf..2361a541efc22 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -39,7 +39,7 @@ struct attr_value_binder {
attr_value_binder(ValueType *bv) : bind_value(bv) {}
bool match(const Attribute &attr) {
- if (auto intAttr = attr.dyn_cast<AttrClass>()) {
+ if (auto intAttr = llvm::dyn_cast<AttrClass>(attr)) {
*bind_value = intAttr.getValue();
return true;
}
@@ -90,7 +90,7 @@ struct constant_op_binder {
(void)result;
assert(succeeded(result) && "expected ConstantLike op to be foldable");
- if (auto attr = foldedOp.front().get<Attribute>().dyn_cast<AttrT>()) {
+ if (auto attr = llvm::dyn_cast<AttrT>(foldedOp.front().get<Attribute>())) {
if (bind_value)
*bind_value = attr;
return true;
@@ -136,10 +136,10 @@ struct constant_float_op_binder {
return false;
auto type = op->getResult(0).getType();
- if (type.isa<FloatType>())
+ if (llvm::isa<FloatType>(type))
return attr_value_binder<FloatAttr>(bind_value).match(attr);
- if (type.isa<VectorType, RankedTensorType>()) {
- if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
+ if (llvm::isa<VectorType, RankedTensorType>(type)) {
+ if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(attr)) {
return attr_value_binder<FloatAttr>(bind_value)
.match(splatAttr.getSplatValue<Attribute>());
}
@@ -173,10 +173,10 @@ struct constant_int_op_binder {
return false;
auto type = op->getResult(0).getType();
- if (type.isa<IntegerType, IndexType>())
+ if (llvm::isa<IntegerType, IndexType>(type))
return attr_value_binder<IntegerAttr>(bind_value).match(attr);
- if (type.isa<VectorType, RankedTensorType>()) {
- if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
+ if (llvm::isa<VectorType, RankedTensorType>(type)) {
+ if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(attr)) {
return attr_value_binder<IntegerAttr>(bind_value)
.match(splatAttr.getSplatValue<Attribute>());
}
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index f4045b236a448..4c36453f31b2f 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -204,7 +204,7 @@ class AsmPrinter {
auto &os = getStream() << " -> ";
bool wrapped = !llvm::hasSingleElement(types) ||
- (*types.begin()).template isa<FunctionType>();
+ llvm::isa<FunctionType>((*types.begin()));
if (wrapped)
os << '(';
llvm::interleaveComma(types, *this);
@@ -865,7 +865,7 @@ class AsmParser {
return failure();
// Check for the right kind of attribute.
- if (!(result = attr.dyn_cast<AttrType>()))
+ if (!(result = llvm::dyn_cast<AttrType>(attr)))
return emitError(loc, "invalid kind of attribute specified");
return success();
@@ -899,7 +899,7 @@ class AsmParser {
return failure();
// Check for the right kind of attribute.
- result = attr.dyn_cast<AttrType>();
+ result = llvm::dyn_cast<AttrType>(attr);
if (!result)
return emitError(loc, "invalid kind of attribute specified");
@@ -936,7 +936,7 @@ class AsmParser {
return failure();
// Check for the right kind of attribute.
- result = attr.dyn_cast<AttrType>();
+ result = llvm::dyn_cast<AttrType>(attr);
if (!result)
return emitError(loc, "invalid kind of attribute specified");
@@ -970,7 +970,7 @@ class AsmParser {
return failure();
// Check for the right kind of attribute.
- result = attr.dyn_cast<AttrType>();
+ result = llvm::dyn_cast<AttrType>(attr);
if (!result)
return emitError(loc, "invalid kind of attribute specified");
return success();
@@ -1126,7 +1126,7 @@ class AsmParser {
return failure();
// Check for the right kind of type.
- result = type.dyn_cast<TypeT>();
+ result = llvm::dyn_cast<TypeT>(type);
if (!result)
return emitError(loc, "invalid kind of type specified");
@@ -1158,7 +1158,7 @@ class AsmParser {
return failure();
// Check for the right kind of Type.
- result = type.dyn_cast<TypeT>();
+ result = llvm::dyn_cast<TypeT>(type);
if (!result)
return emitError(loc, "invalid kind of Type specified");
return success();
@@ -1198,7 +1198,7 @@ class AsmParser {
return failure();
// Check for the right kind of type.
- result = type.dyn_cast<TypeType>();
+ result = llvm::dyn_cast<TypeType>(type);
if (!result)
return emitError(loc, "invalid kind of type specified");
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index 8bd23a2a1fc57..ec6d4ca2d6e6f 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -509,11 +509,11 @@ class alignas(8) Operation final
template <typename AttrClass>
AttrClass getAttrOfType(StringAttr name) {
- return getAttr(name).dyn_cast_or_null<AttrClass>();
+ return llvm::dyn_cast_or_null<AttrClass>(getAttr(name));
}
template <typename AttrClass>
AttrClass getAttrOfType(StringRef name) {
- return getAttr(name).dyn_cast_or_null<AttrClass>();
+ return llvm::dyn_cast_or_null<AttrClass>(getAttr(name));
}
/// Return true if the operation has an attribute with the provided name,
diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h
index 7a8aee29ca445..a280fbdf64bc8 100644
--- a/mlir/include/mlir/IR/Value.h
+++ b/mlir/include/mlir/IR/Value.h
@@ -433,7 +433,7 @@ struct TypedValue : Value {
static bool classof(Value value) { return llvm::isa<Ty>(value.getType()); }
/// Return the known Type
- Ty getType() { return Value::getType().template cast<Ty>(); }
+ Ty getType() { return llvm::cast<Ty>(Value::getType()); }
void setType(Ty ty) { Value::setType(ty); }
};
diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
index 66d291eddb65a..f2441e0b0ae9b 100644
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -25,7 +25,7 @@ MlirAttribute mlirAttributeGetNull() { return {nullptr}; }
//===----------------------------------------------------------------------===//
bool mlirAttributeIsALocation(MlirAttribute attr) {
- return unwrap(attr).isa<LocationAttr>();
+ return llvm::isa<LocationAttr>(unwrap(attr));
}
//===----------------------------------------------------------------------===//
@@ -33,7 +33,7 @@ bool mlirAttributeIsALocation(MlirAttribute attr) {
//===----------------------------------------------------------------------===//
bool mlirAttributeIsAAffineMap(MlirAttribute attr) {
- return unwrap(attr).isa<AffineMapAttr>();
+ return llvm::isa<AffineMapAttr>(unwrap(attr));
}
MlirAttribute mlirAffineMapAttrGet(MlirAffineMap map) {
@@ -41,7 +41,7 @@ MlirAttribute mlirAffineMapAttrGet(MlirAffineMap map) {
}
MlirAffineMap mlirAffineMapAttrGetValue(MlirAttribute attr) {
- return wrap(unwrap(attr).cast<AffineMapAttr>().getValue());
+ return wrap(llvm::cast<AffineMapAttr>(unwrap(attr)).getValue());
}
//===----------------------------------------------------------------------===//
@@ -49,7 +49,7 @@ MlirAffineMap mlirAffineMapAttrGetValue(MlirAttribute attr) {
//===----------------------------------------------------------------------===//
bool mlirAttributeIsAArray(MlirAttribute attr) {
- return unwrap(attr).isa<ArrayAttr>();
+ return llvm::isa<ArrayAttr>(unwrap(attr));
}
MlirAttribute mlirArrayAttrGet(MlirContext ctx, intptr_t numElements,
@@ -61,11 +61,11 @@ MlirAttribute mlirArrayAttrGet(MlirContext ctx, intptr_t numElements,
}
intptr_t mlirArrayAttrGetNumElements(MlirAttribute attr) {
- return static_cast<intptr_t>(unwrap(attr).cast<ArrayAttr>().size());
+ return static_cast<intptr_t>(llvm::cast<ArrayAttr>(unwrap(attr)).size());
}
MlirAttribute mlirArrayAttrGetElement(MlirAttribute attr, intptr_t pos) {
- return wrap(unwrap(attr).cast<ArrayAttr>().getValue()[pos]);
+ return wrap(llvm::cast<ArrayAttr>(unwrap(attr)).getValue()[pos]);
}
//===----------------------------------------------------------------------===//
@@ -73,7 +73,7 @@ MlirAttribute mlirArrayAttrGetElement(MlirAttribute attr, intptr_t pos) {
//===----------------------------------------------------------------------===//
bool mlirAttributeIsADictionary(MlirAttribute attr) {
- return unwrap(attr).isa<DictionaryAttr>();
+ return llvm::isa<DictionaryAttr>(unwrap(attr));
}
MlirAttribute mlirDictionaryAttrGet(MlirContext ctx, intptr_t numElements,
@@ -87,19 +87,19 @@ MlirAttribute mlirDictionaryAttrGet(MlirContext ctx, intptr_t numElements,
}
intptr_t mlirDictionaryAttrGetNumElements(MlirAttribute attr) {
- return static_cast<intptr_t>(unwrap(attr).cast<DictionaryAttr>().size());
+ return static_cast<intptr_t>(llvm::cast<DictionaryAttr>(unwrap(attr)).size());
}
MlirNamedAttribute mlirDictionaryAttrGetElement(MlirAttribute attr,
intptr_t pos) {
NamedAttribute attribute =
- unwrap(attr).cast<DictionaryAttr>().getValue()[pos];
+ llvm::cast<DictionaryAttr>(unwrap(attr)).getValue()[pos];
return {wrap(attribute.getName()), wrap(attribute.getValue())};
}
MlirAttribute mlirDictionaryAttrGetElementByName(MlirAttribute attr,
MlirStringRef name) {
- return wrap(unwrap(attr).cast<DictionaryAttr>().get(unwrap(name)));
+ return wrap(llvm::cast<DictionaryAttr>(unwrap(attr)).get(unwrap(name)));
}
//===----------------------------------------------------------------------===//
@@ -107,7 +107,7 @@ MlirAttribute mlirDictionaryAttrGetElementByName(MlirAttribute attr,
//===----------------------------------------------------------------------===//
bool mlirAttributeIsAFloat(MlirAttribute attr) {
- return unwrap(attr).isa<FloatAttr>();
+ return llvm::isa<FloatAttr>(unwrap(attr));
}
MlirAttribute mlirFloatAttrDoubleGet(MlirContext ctx, MlirType type,
@@ -121,7 +121,7 @@ MlirAttribute mlirFloatAttrDoubleGetChecked(MlirLocation loc, MlirType type,
}
double mlirFloatAttrGetValueDouble(MlirAttribute attr) {
- return unwrap(attr).cast<FloatAttr>().getValueAsDouble();
+ return llvm::cast<FloatAttr>(unwrap(attr)).getValueAsDouble();
}
//===----------------------------------------------------------------------===//
@@ -129,7 +129,7 @@ double mlirFloatAttrGetValueDouble(MlirAttribute attr) {
//===----------------------------------------------------------------------===//
bool mlirAttributeIsAInteger(MlirAttribute attr) {
- return unwrap(attr).isa<IntegerAttr>();
+ return llvm::isa<IntegerAttr>(unwrap(attr));
}
MlirAttribute mlirIntegerAttrGet(MlirType type, int64_t value) {
@@ -137,15 +137,15 @@ MlirAttribute mlirIntegerAttrGet(MlirType type, int64_t value) {
}
int64_t mlirIntegerAttrGetValueInt(MlirAttribute attr) {
- return unwrap(attr).cast<IntegerAttr>().getInt();
+ return llvm::cast<IntegerAttr>(unwrap(attr)).getInt();
}
int64_t mlirIntegerAttrGetValueSInt(MlirAttribute attr) {
- return unwrap(attr).cast<IntegerAttr>().getSInt();
+ return llvm::cast<IntegerAttr>(unwrap(attr)).getSInt();
}
uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr) {
- return unwrap(attr).cast<IntegerAttr>().getUInt();
+ return llvm::cast<IntegerAttr>(unwrap(attr)).getUInt();
}
//===----------------------------------------------------------------------===//
@@ -153,7 +153,7 @@ uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr) {
//===----------------------------------------------------------------------===//
bool mlirAttributeIsABool(MlirAttribute attr) {
- return unwrap(attr).isa<BoolAttr>();
+ return llvm::isa<BoolAttr>(unwrap(attr));
}
MlirAttribute mlirBoolAttrGet(MlirContext ctx, int value) {
@@ -161,7 +161,7 @@ MlirAttribute mlirBoolAttrGet(MlirContext ctx, int value) {
}
bool mlirBoolAttrGetValue(MlirAttribute attr) {
- return unwrap(attr).cast<BoolAttr>().getValue();
+ return llvm::cast<BoolAttr>(unwrap(attr)).getValue();
}
//===----------------------------------------------------------------------===//
@@ -169,7 +169,7 @@ bool mlirBoolAttrGetValue(MlirAttribute attr) {
//===----------------------------------------------------------------------===//
bool mlirAttributeIsAIntegerSet(MlirAttribute attr) {
- return unwrap(attr).isa<IntegerSetAttr>();
+ return llvm::isa<IntegerSetAttr>(unwrap(attr));
}
//===----------------------------------------------------------------------===//
@@ -177,7 +177,7 @@ bool mlirAttributeIsAIntegerSet(MlirAttribute attr) {
//===----------------------------------------------------------------------===//
bool mlirAttributeIsAOpaque(MlirAttribute attr) {
- return unwrap(attr).isa<OpaqueAttr>();
+ return llvm::isa<OpaqueAttr>(unwrap(attr));
}
MlirAttribute mlirOpaqueAttrGet(MlirContext ctx, MlirStringRef dialectNamespace,
@@ -189,11 +189,12 @@ MlirAttribute mlirOpaqueAttrGet(MlirContext ctx, MlirStringRef dialectNamespace,
}
MlirStringRef mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr) {
- return wrap(unwrap(attr).cast<OpaqueAttr>().getDialectNamespace().strref());
+ return wrap(
+ llvm::cast<OpaqueAttr>(unwrap(attr)).getDialectNamespace().strref());
}
MlirStringRef mlirOpaqueAttrGetData(MlirAttribute attr) {
- return wrap(unwrap(attr).cast<OpaqueAttr>().getAttrData());
+ return wrap(llvm::cast<OpaqueAttr>(unwrap(attr)).getAttrData());
}
//===----------------------------------------------------------------------===//
@@ -201,7 +202,7 @@ MlirStringRef mlirOpaqueAttrGetData(MlirAttribute attr) {
//===----------------------------------------------------------------------===//
bool mlirAttributeIsAString(MlirAttribute attr) {
- return unwrap(attr).isa<StringAttr>();
+ return llvm::isa<StringAttr>(unwrap(attr));
}
MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str) {
@@ -213,7 +214,7 @@ MlirAttribute mlirStringAttrTypedGet(MlirType type, MlirStringRef str) {
}
MlirStringRef mlirStringAttrGetValue(MlirAttribute attr) {
- return wrap(unwrap(attr).cast<StringAttr>().getValue());
+ return wrap(llvm::cast<StringAttr>(unwrap(attr)).getValue());
}
//===----------------------------------------------------------------------===//
@@ -221,7 +222,7 @@ MlirStringRef mlirStringAttrGetValue(MlirAttribute attr) {
//===----------------------------------------------------------------------===//
bool mlirAttributeIsASymbolRef(MlirAttribute attr) {
- return unwrap(attr).isa<SymbolRefAttr>();
+ return llvm::isa<SymbolRefAttr>(unwrap(attr));
}
MlirAttribute mlirSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol,
@@ -230,27 +231,30 @@ MlirAttribute mlirSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol,
SmallVector<FlatSymbolRefAttr, 4> refs;
refs.reserve(numReferences);
for (intptr_t i = 0; i < numReferences; ++i)
- refs.push_back(unwrap(references[i]).cast<FlatSymbolRefAttr>());
+ refs.push_back(llvm::cast<FlatSymbolRefAttr>(unwrap(references[i])));
auto symbolAttr = StringAttr::get(unwrap(ctx), unwrap(symbol));
return wrap(SymbolRefAttr::get(symbolAttr, refs));
}
MlirStringRef mlirSymbolRefAttrGetRootReference(MlirAttribute attr) {
- return wrap(unwrap(attr).cast<SymbolRefAttr>().getRootReference().getValue());
+ return wrap(
+ llvm::cast<SymbolRefAttr>(unwrap(attr)).getRootReference().getValue());
}
MlirStringRef mlirSymbolRefAttrGetLeafReference(MlirAttribute attr) {
- return wrap(unwrap(attr).cast<SymbolRefAttr>().getLeafReference().getValue());
+ return wrap(
+ llvm::cast<SymbolRefAttr>(unwrap(attr)).getLeafReference().getValue());
}
intptr_t mlirSymbolRefAttrGetNumNestedReferences(MlirAttribute attr) {
return static_cast<intptr_t>(
- unwrap(attr).cast<SymbolRefAttr>().getNestedReferences().size());
+ llvm::cast<SymbolRefAttr>(unwrap(attr)).getNestedReferences().size());
}
MlirAttribute mlirSymbolRefAttrGetNestedReference(MlirAttribute attr,
intptr_t pos) {
- return wrap(unwrap(attr).cast<SymbolRefAttr>().getNestedReferences()[pos]);
+ return wrap(
+ llvm::cast<SymbolRefAttr>(unwrap(attr)).getNestedReferences()[pos]);
}
//===----------------------------------------------------------------------===//
@@ -258,7 +262,7 @@ MlirAttribute mlirSymbolRefAttrGetNestedReference(MlirAttribute attr,
//===----------------------------------------------------------------------===//
bool mlirAttributeIsAFlatSymbolRef(MlirAttribute attr) {
- return unwrap(attr).isa<FlatSymbolRefAttr>();
+ return llvm::isa<FlatSymbolRefAttr>(unwrap(attr));
}
MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol) {
@@ -266,7 +270,7 @@ MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol) {
}
MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr) {
- return wrap(unwrap(attr).cast<FlatSymbolRefAttr>().getValue());
+ return wrap(llvm::cast<FlatSymbolRefAttr>(unwrap(attr)).getValue());
}
//===----------------------------------------------------------------------===//
@@ -274,7 +278,7 @@ MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr) {
//===----------------------------------------------------------------------===//
bool mlirAttributeIsAType(MlirAttribute attr) {
- return unwrap(attr).isa<TypeAttr>();
+ return llvm::isa<TypeAttr>(unwrap(attr));
}
MlirAttribute mlirTypeAttrGet(MlirType type) {
@@ -282,7 +286,7 @@ MlirAttribute mlirTypeAttrGet(MlirType type) {
}
MlirType mlirTypeAttrGetValue(MlirAttribute attr) {
- return wrap(unwrap(attr).cast<TypeAttr>().getValue());
+ return wrap(llvm::cast<TypeAttr>(unwrap(attr)).getValue());
}
//===----------------------------------------------------------------------===//
@@ -290,7 +294,7 @@ MlirType mlirTypeAttrGetValue(MlirAttribute attr) {
//===----------------------------------------------------------------------===//
bool mlirAttributeIsAUnit(MlirAttribute attr) {
- return unwrap(attr).isa<UnitAttr>();
+ return llvm::isa<UnitAttr>(unwrap(attr));
}
MlirAttribute mlirUnitAttrGet(MlirContext ctx) {
@@ -302,24 +306,23 @@ MlirAttribute mlirUnitAttrGet(MlirContext ctx) {
//===----------------------------------------------------------------------===//
bool mlirAttributeIsAElements(MlirAttribute attr) {
- return unwrap(attr).isa<ElementsAttr>();
+ return llvm::isa<ElementsAttr>(unwrap(attr));
}
MlirAttribute mlirElementsAttrGetValue(MlirAttribute attr, intptr_t rank,
uint64_t *idxs) {
- return wrap(unwrap(attr)
- .cast<ElementsAttr>()
+ return wrap(llvm::cast<ElementsAttr>(unwrap(attr))
.getValues<Attribute>()[llvm::ArrayRef(idxs, rank)]);
}
bool mlirElementsAttrIsValidIndex(MlirAttribute attr, intptr_t rank,
uint64_t *idxs) {
- return unwrap(attr).cast<ElementsAttr>().isValidIndex(
- llvm::ArrayRef(idxs, rank));
+ return llvm::cast<ElementsAttr>(unwrap(attr))
+ .isValidIndex(llvm::ArrayRef(idxs, rank));
}
int64_t mlirElementsAttrGetNumElements(MlirAttribute attr) {
- return unwrap(attr).cast<ElementsAttr>().getNumElements();
+ return llvm::cast<ElementsAttr>(unwrap(attr)).getNumElements();
}
//===----------------------------------------------------------------------===//
@@ -330,25 +333,25 @@ int64_t mlirElementsAttrGetNumElements(MlirAttribute attr) {
// IsA support.
bool mlirAttributeIsADenseBoolArray(MlirAttribute attr) {
- return unwrap(attr).isa<DenseBoolArrayAttr>();
+ return llvm::isa<DenseBoolArrayAttr>(unwrap(attr));
}
bool mlirAttributeIsADenseI8Array(MlirAttribute attr) {
- return unwrap(attr).isa<DenseI8ArrayAttr>();
+ return llvm::isa<DenseI8ArrayAttr>(unwrap(attr));
}
bool mlirAttributeIsADenseI16Array(MlirAttribute attr) {
- return unwrap(attr).isa<DenseI16ArrayAttr>();
+ return llvm::isa<DenseI16ArrayAttr>(unwrap(attr));
}
bool mlirAttributeIsADenseI32Array(MlirAttribute attr) {
- return unwrap(attr).isa<DenseI32ArrayAttr>();
+ return llvm::isa<DenseI32ArrayAttr>(unwrap(attr));
}
bool mlirAttributeIsADenseI64Array(MlirAttribute attr) {
- return unwrap(attr).isa<DenseI64ArrayAttr>();
+ return llvm::isa<DenseI64ArrayAttr>(unwrap(attr));
}
bool mlirAttributeIsADenseF32Array(MlirAttribute attr) {
- return unwrap(attr).isa<DenseF32ArrayAttr>();
+ return llvm::isa<DenseF32ArrayAttr>(unwrap(attr));
}
bool mlirAttributeIsADenseF64Array(MlirAttribute attr) {
- return unwrap(attr).isa<DenseF64ArrayAttr>();
+ return llvm::isa<DenseF64ArrayAttr>(unwrap(attr));
}
//===----------------------------------------------------------------------===//
@@ -394,32 +397,32 @@ MlirAttribute mlirDenseF64ArrayGet(MlirContext ctx, intptr_t size,
// Accessors.
intptr_t mlirDenseArrayGetNumElements(MlirAttribute attr) {
- return unwrap(attr).cast<DenseArrayAttr>().size();
+ return llvm::cast<DenseArrayAttr>(unwrap(attr)).size();
}
//===----------------------------------------------------------------------===//
// Indexed accessors.
bool mlirDenseBoolArrayGetElement(MlirAttribute attr, intptr_t pos) {
- return unwrap(attr).cast<DenseBoolArrayAttr>()[pos];
+ return llvm::cast<DenseBoolArrayAttr>(unwrap(attr))[pos];
}
int8_t mlirDenseI8ArrayGetElement(MlirAttribute attr, intptr_t pos) {
- return unwrap(attr).cast<DenseI8ArrayAttr>()[pos];
+ return llvm::cast<DenseI8ArrayAttr>(unwrap(attr))[pos];
}
int16_t mlirDenseI16ArrayGetElement(MlirAttribute attr, intptr_t pos) {
- return unwrap(attr).cast<DenseI16ArrayAttr>()[pos];
+ return llvm::cast<DenseI16ArrayAttr>(unwrap(attr))[pos];
}
int32_t mlirDenseI32ArrayGetElement(MlirAttribute attr, intptr_t pos) {
- return unwrap(attr).cast<DenseI32ArrayAttr>()[pos];
+ return llvm::cast<DenseI32ArrayAttr>(unwrap(attr))[pos];
}
int64_t mlirDenseI64ArrayGetElement(MlirAttribute attr, intptr_t pos) {
- return unwrap(attr).cast<DenseI64ArrayAttr>()[pos];
+ return llvm::cast<DenseI64ArrayAttr>(unwrap(attr))[pos];
}
float mlirDenseF32ArrayGetElement(MlirAttribute attr, intptr_t pos) {
- return unwrap(attr).cast<DenseF32ArrayAttr>()[pos];
+ return llvm::cast<DenseF32ArrayAttr>(unwrap(attr))[pos];
}
double mlirDenseF64ArrayGetElement(MlirAttribute attr, intptr_t pos) {
- return unwrap(attr).cast<DenseF64ArrayAttr>()[pos];
+ return llvm::cast<DenseF64ArrayAttr>(unwrap(attr))[pos];
}
//===----------------------------------------------------------------------===//
@@ -430,13 +433,13 @@ double mlirDenseF64ArrayGetElement(MlirAttribute attr, intptr_t pos) {
// IsA support.
bool mlirAttributeIsADenseElements(MlirAttribute attr) {
- return unwrap(attr).isa<DenseElementsAttr>();
+ return llvm::isa<DenseElementsAttr>(unwrap(attr));
}
bool mlirAttributeIsADenseIntElements(MlirAttribute attr) {
- return unwrap(attr).isa<DenseIntElementsAttr>();
+ return llvm::isa<DenseIntElementsAttr>(unwrap(attr));
}
bool mlirAttributeIsADenseFPElements(MlirAttribute attr) {
- return unwrap(attr).isa<DenseFPElementsAttr>();
+ return llvm::isa<DenseFPElementsAttr>(unwrap(attr));
}
//===----------------------------------------------------------------------===//
@@ -447,14 +450,14 @@ MlirAttribute mlirDenseElementsAttrGet(MlirType shapedType,
MlirAttribute const *elements) {
SmallVector<Attribute, 8> attributes;
return wrap(
- DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(),
+ DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
unwrapList(numElements, elements, attributes)));
}
MlirAttribute mlirDenseElementsAttrRawBufferGet(MlirType shapedType,
size_t rawBufferSize,
const void *rawBuffer) {
- auto shapedTypeCpp = unwrap(shapedType).cast<ShapedType>();
+ auto shapedTypeCpp = llvm::cast<ShapedType>(unwrap(shapedType));
ArrayRef<char> rawBufferCpp(static_cast<const char *>(rawBuffer),
rawBufferSize);
bool isSplat = false;
@@ -466,61 +469,61 @@ MlirAttribute mlirDenseElementsAttrRawBufferGet(MlirType shapedType,
MlirAttribute mlirDenseElementsAttrSplatGet(MlirType shapedType,
MlirAttribute element) {
- return wrap(DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(),
+ return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
unwrap(element)));
}
MlirAttribute mlirDenseElementsAttrBoolSplatGet(MlirType shapedType,
bool element) {
- return wrap(
- DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
+ return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
+ element));
}
MlirAttribute mlirDenseElementsAttrUInt8SplatGet(MlirType shapedType,
uint8_t element) {
- return wrap(
- DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
+ return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
+ element));
}
MlirAttribute mlirDenseElementsAttrInt8SplatGet(MlirType shapedType,
int8_t element) {
- return wrap(
- DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
+ return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
+ element));
}
MlirAttribute mlirDenseElementsAttrUInt32SplatGet(MlirType shapedType,
uint32_t element) {
- return wrap(
- DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
+ return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
+ element));
}
MlirAttribute mlirDenseElementsAttrInt32SplatGet(MlirType shapedType,
int32_t element) {
- return wrap(
- DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
+ return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
+ element));
}
MlirAttribute mlirDenseElementsAttrUInt64SplatGet(MlirType shapedType,
uint64_t element) {
- return wrap(
- DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
+ return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
+ element));
}
MlirAttribute mlirDenseElementsAttrInt64SplatGet(MlirType shapedType,
int64_t element) {
- return wrap(
- DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
+ return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
+ element));
}
MlirAttribute mlirDenseElementsAttrFloatSplatGet(MlirType shapedType,
float element) {
- return wrap(
- DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
+ return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
+ element));
}
MlirAttribute mlirDenseElementsAttrDoubleSplatGet(MlirType shapedType,
double element) {
- return wrap(
- DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
+ return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
+ element));
}
MlirAttribute mlirDenseElementsAttrBoolGet(MlirType shapedType,
intptr_t numElements,
const int *elements) {
SmallVector<bool, 8> values(elements, elements + numElements);
- return wrap(
- DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), values));
+ return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
+ values));
}
/// Creates a dense attribute with elements of the type deduced by templates.
@@ -528,7 +531,7 @@ template <typename T>
static MlirAttribute getDenseAttribute(MlirType shapedType,
intptr_t numElements,
const T *elements) {
- return wrap(DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(),
+ return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
llvm::ArrayRef(elements, numElements)));
}
@@ -605,99 +608,99 @@ MlirAttribute mlirDenseElementsAttrStringGet(MlirType shapedType,
for (intptr_t i = 0; i < numElements; ++i)
values.push_back(unwrap(strs[i]));
- return wrap(
- DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), values));
+ return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
+ values));
}
MlirAttribute mlirDenseElementsAttrReshapeGet(MlirAttribute attr,
MlirType shapedType) {
- return wrap(unwrap(attr).cast<DenseElementsAttr>().reshape(
- unwrap(shapedType).cast<ShapedType>()));
+ return wrap(llvm::cast<DenseElementsAttr>(unwrap(attr))
+ .reshape(llvm::cast<ShapedType>(unwrap(shapedType))));
}
//===----------------------------------------------------------------------===//
// Splat accessors.
bool mlirDenseElementsAttrIsSplat(MlirAttribute attr) {
- return unwrap(attr).cast<DenseElementsAttr>().isSplat();
+ return llvm::cast<DenseElementsAttr>(unwrap(attr)).isSplat();
}
MlirAttribute mlirDenseElementsAttrGetSplatValue(MlirAttribute attr) {
return wrap(
- unwrap(attr).cast<DenseElementsAttr>().getSplatValue<Attribute>());
+ llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<Attribute>());
}
int mlirDenseElementsAttrGetBoolSplatValue(MlirAttribute attr) {
- return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<bool>();
+ return llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<bool>();
}
int8_t mlirDenseElementsAttrGetInt8SplatValue(MlirAttribute attr) {
- return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<int8_t>();
+ return llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<int8_t>();
}
uint8_t mlirDenseElementsAttrGetUInt8SplatValue(MlirAttribute attr) {
- return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<uint8_t>();
+ return llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<uint8_t>();
}
int32_t mlirDenseElementsAttrGetInt32SplatValue(MlirAttribute attr) {
- return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<int32_t>();
+ return llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<int32_t>();
}
uint32_t mlirDenseElementsAttrGetUInt32SplatValue(MlirAttribute attr) {
- return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<uint32_t>();
+ return llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<uint32_t>();
}
int64_t mlirDenseElementsAttrGetInt64SplatValue(MlirAttribute attr) {
- return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<int64_t>();
+ return llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<int64_t>();
}
uint64_t mlirDenseElementsAttrGetUInt64SplatValue(MlirAttribute attr) {
- return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<uint64_t>();
+ return llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<uint64_t>();
}
float mlirDenseElementsAttrGetFloatSplatValue(MlirAttribute attr) {
- return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<float>();
+ return llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<float>();
}
double mlirDenseElementsAttrGetDoubleSplatValue(MlirAttribute attr) {
- return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<double>();
+ return llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<double>();
}
MlirStringRef mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr) {
return wrap(
- unwrap(attr).cast<DenseElementsAttr>().getSplatValue<StringRef>());
+ llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<StringRef>());
}
//===----------------------------------------------------------------------===//
// Indexed accessors.
bool mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos) {
- return unwrap(attr).cast<DenseElementsAttr>().getValues<bool>()[pos];
+ return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<bool>()[pos];
}
int8_t mlirDenseElementsAttrGetInt8Value(MlirAttribute attr, intptr_t pos) {
- return unwrap(attr).cast<DenseElementsAttr>().getValues<int8_t>()[pos];
+ return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<int8_t>()[pos];
}
uint8_t mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr, intptr_t pos) {
- return unwrap(attr).cast<DenseElementsAttr>().getValues<uint8_t>()[pos];
+ return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<uint8_t>()[pos];
}
int16_t mlirDenseElementsAttrGetInt16Value(MlirAttribute attr, intptr_t pos) {
- return unwrap(attr).cast<DenseElementsAttr>().getValues<int16_t>()[pos];
+ return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<int16_t>()[pos];
}
uint16_t mlirDenseElementsAttrGetUInt16Value(MlirAttribute attr, intptr_t pos) {
- return unwrap(attr).cast<DenseElementsAttr>().getValues<uint16_t>()[pos];
+ return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<uint16_t>()[pos];
}
int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos) {
- return unwrap(attr).cast<DenseElementsAttr>().getValues<int32_t>()[pos];
+ return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<int32_t>()[pos];
}
uint32_t mlirDenseElementsAttrGetUInt32Value(MlirAttribute attr, intptr_t pos) {
- return unwrap(attr).cast<DenseElementsAttr>().getValues<uint32_t>()[pos];
+ return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<uint32_t>()[pos];
}
int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos) {
- return unwrap(attr).cast<DenseElementsAttr>().getValues<int64_t>()[pos];
+ return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<int64_t>()[pos];
}
uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos) {
- return unwrap(attr).cast<DenseElementsAttr>().getValues<uint64_t>()[pos];
+ return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<uint64_t>()[pos];
}
float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos) {
- return unwrap(attr).cast<DenseElementsAttr>().getValues<float>()[pos];
+ return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<float>()[pos];
}
double mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr, intptr_t pos) {
- return unwrap(attr).cast<DenseElementsAttr>().getValues<double>()[pos];
+ return llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<double>()[pos];
}
MlirStringRef mlirDenseElementsAttrGetStringValue(MlirAttribute attr,
intptr_t pos) {
return wrap(
- unwrap(attr).cast<DenseElementsAttr>().getValues<StringRef>()[pos]);
+ llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<StringRef>()[pos]);
}
//===----------------------------------------------------------------------===//
@@ -705,7 +708,7 @@ MlirStringRef mlirDenseElementsAttrGetStringValue(MlirAttribute attr,
const void *mlirDenseElementsAttrGetRawData(MlirAttribute attr) {
return static_cast<const void *>(
- unwrap(attr).cast<DenseElementsAttr>().getRawData().data());
+ llvm::cast<DenseElementsAttr>(unwrap(attr)).getRawData().data());
}
//===----------------------------------------------------------------------===//
@@ -715,7 +718,7 @@ const void *mlirDenseElementsAttrGetRawData(MlirAttribute attr) {
template <typename U, typename T>
static MlirAttribute getDenseResource(MlirType shapedType, MlirStringRef name,
intptr_t numElements, const T *elements) {
- return wrap(U::get(unwrap(shapedType).cast<ShapedType>(), unwrap(name),
+ return wrap(U::get(llvm::cast<ShapedType>(unwrap(shapedType)), unwrap(name),
UnmanagedAsmResourceBlob::allocateInferAlign(
llvm::ArrayRef(elements, numElements))));
}
@@ -797,7 +800,7 @@ mlirUnmanagedDenseDoubleResourceElementsAttrGet(MlirType shapedType,
template <typename U, typename T>
static T getDenseResourceVal(MlirAttribute attr, intptr_t pos) {
- return (*unwrap(attr).cast<U>().tryGetAsArrayRef())[pos];
+ return (*llvm::cast<U>(unwrap(attr)).tryGetAsArrayRef())[pos];
}
MLIR_CAPI_EXPORTED bool
@@ -853,24 +856,24 @@ mlirDenseDoubleResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) {
//===----------------------------------------------------------------------===//
bool mlirAttributeIsASparseElements(MlirAttribute attr) {
- return unwrap(attr).isa<SparseElementsAttr>();
+ return llvm::isa<SparseElementsAttr>(unwrap(attr));
}
MlirAttribute mlirSparseElementsAttribute(MlirType shapedType,
MlirAttribute denseIndices,
MlirAttribute denseValues) {
- return wrap(
- SparseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(),
- unwrap(denseIndices).cast<DenseElementsAttr>(),
- unwrap(denseValues).cast<DenseElementsAttr>()));
+ return wrap(SparseElementsAttr::get(
+ llvm::cast<ShapedType>(unwrap(shapedType)),
+ llvm::cast<DenseElementsAttr>(unwrap(denseIndices)),
+ llvm::cast<DenseElementsAttr>(unwrap(denseValues))));
}
MlirAttribute mlirSparseElementsAttrGetIndices(MlirAttribute attr) {
- return wrap(unwrap(attr).cast<SparseElementsAttr>().getIndices());
+ return wrap(llvm::cast<SparseElementsAttr>(unwrap(attr)).getIndices());
}
MlirAttribute mlirSparseElementsAttrGetValues(MlirAttribute attr) {
- return wrap(unwrap(attr).cast<SparseElementsAttr>().getValues());
+ return wrap(llvm::cast<SparseElementsAttr>(unwrap(attr)).getValues());
}
//===----------------------------------------------------------------------===//
@@ -878,7 +881,7 @@ MlirAttribute mlirSparseElementsAttrGetValues(MlirAttribute attr) {
//===----------------------------------------------------------------------===//
bool mlirAttributeIsAStridedLayout(MlirAttribute attr) {
- return unwrap(attr).isa<StridedLayoutAttr>();
+ return llvm::isa<StridedLayoutAttr>(unwrap(attr));
}
MlirAttribute mlirStridedLayoutAttrGet(MlirContext ctx, int64_t offset,
@@ -889,14 +892,14 @@ MlirAttribute mlirStridedLayoutAttrGet(MlirContext ctx, int64_t offset,
}
int64_t mlirStridedLayoutAttrGetOffset(MlirAttribute attr) {
- return unwrap(attr).cast<StridedLayoutAttr>().getOffset();
+ return llvm::cast<StridedLayoutAttr>(unwrap(attr)).getOffset();
}
intptr_t mlirStridedLayoutAttrGetNumStrides(MlirAttribute attr) {
return static_cast<intptr_t>(
- unwrap(attr).cast<StridedLayoutAttr>().getStrides().size());
+ llvm::cast<StridedLayoutAttr>(unwrap(attr)).getStrides().size());
}
int64_t mlirStridedLayoutAttrGetStride(MlirAttribute attr, intptr_t pos) {
- return unwrap(attr).cast<StridedLayoutAttr>().getStrides()[pos];
+ return llvm::cast<StridedLayoutAttr>(unwrap(attr)).getStrides()[pos];
}
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index 2468c05463f4a..90ab847606ee0 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -23,7 +23,7 @@ using namespace mlir;
//===----------------------------------------------------------------------===//
bool mlirTypeIsAInteger(MlirType type) {
- return unwrap(type).isa<IntegerType>();
+ return llvm::isa<IntegerType>(unwrap(type));
}
MlirType mlirIntegerTypeGet(MlirContext ctx, unsigned bitwidth) {
@@ -39,26 +39,28 @@ MlirType mlirIntegerTypeUnsignedGet(MlirContext ctx, unsigned bitwidth) {
}
unsigned mlirIntegerTypeGetWidth(MlirType type) {
- return unwrap(type).cast<IntegerType>().getWidth();
+ return llvm::cast<IntegerType>(unwrap(type)).getWidth();
}
bool mlirIntegerTypeIsSignless(MlirType type) {
- return unwrap(type).cast<IntegerType>().isSignless();
+ return llvm::cast<IntegerType>(unwrap(type)).isSignless();
}
bool mlirIntegerTypeIsSigned(MlirType type) {
- return unwrap(type).cast<IntegerType>().isSigned();
+ return llvm::cast<IntegerType>(unwrap(type)).isSigned();
}
bool mlirIntegerTypeIsUnsigned(MlirType type) {
- return unwrap(type).cast<IntegerType>().isUnsigned();
+ return llvm::cast<IntegerType>(unwrap(type)).isUnsigned();
}
//===----------------------------------------------------------------------===//
// Index type.
//===----------------------------------------------------------------------===//
-bool mlirTypeIsAIndex(MlirType type) { return unwrap(type).isa<IndexType>(); }
+bool mlirTypeIsAIndex(MlirType type) {
+ return llvm::isa<IndexType>(unwrap(type));
+}
MlirType mlirIndexTypeGet(MlirContext ctx) {
return wrap(IndexType::get(unwrap(ctx)));
@@ -136,7 +138,9 @@ MlirType mlirF64TypeGet(MlirContext ctx) {
// None type.
//===----------------------------------------------------------------------===//
-bool mlirTypeIsANone(MlirType type) { return unwrap(type).isa<NoneType>(); }
+bool mlirTypeIsANone(MlirType type) {
+ return llvm::isa<NoneType>(unwrap(type));
+}
MlirType mlirNoneTypeGet(MlirContext ctx) {
return wrap(NoneType::get(unwrap(ctx)));
@@ -147,7 +151,7 @@ MlirType mlirNoneTypeGet(MlirContext ctx) {
//===----------------------------------------------------------------------===//
bool mlirTypeIsAComplex(MlirType type) {
- return unwrap(type).isa<ComplexType>();
+ return llvm::isa<ComplexType>(unwrap(type));
}
MlirType mlirComplexTypeGet(MlirType elementType) {
@@ -155,38 +159,41 @@ MlirType mlirComplexTypeGet(MlirType elementType) {
}
MlirType mlirComplexTypeGetElementType(MlirType type) {
- return wrap(unwrap(type).cast<ComplexType>().getElementType());
+ return wrap(llvm::cast<ComplexType>(unwrap(type)).getElementType());
}
//===----------------------------------------------------------------------===//
// Shaped type.
//===----------------------------------------------------------------------===//
-bool mlirTypeIsAShaped(MlirType type) { return unwrap(type).isa<ShapedType>(); }
+bool mlirTypeIsAShaped(MlirType type) {
+ return llvm::isa<ShapedType>(unwrap(type));
+}
MlirType mlirShapedTypeGetElementType(MlirType type) {
- return wrap(unwrap(type).cast<ShapedType>().getElementType());
+ return wrap(llvm::cast<ShapedType>(unwrap(type)).getElementType());
}
bool mlirShapedTypeHasRank(MlirType type) {
- return unwrap(type).cast<ShapedType>().hasRank();
+ return llvm::cast<ShapedType>(unwrap(type)).hasRank();
}
int64_t mlirShapedTypeGetRank(MlirType type) {
- return unwrap(type).cast<ShapedType>().getRank();
+ return llvm::cast<ShapedType>(unwrap(type)).getRank();
}
bool mlirShapedTypeHasStaticShape(MlirType type) {
- return unwrap(type).cast<ShapedType>().hasStaticShape();
+ return llvm::cast<ShapedType>(unwrap(type)).hasStaticShape();
}
bool mlirShapedTypeIsDynamicDim(MlirType type, intptr_t dim) {
- return unwrap(type).cast<ShapedType>().isDynamicDim(
- static_cast<unsigned>(dim));
+ return llvm::cast<ShapedType>(unwrap(type))
+ .isDynamicDim(static_cast<unsigned>(dim));
}
int64_t mlirShapedTypeGetDimSize(MlirType type, intptr_t dim) {
- return unwrap(type).cast<ShapedType>().getDimSize(static_cast<unsigned>(dim));
+ return llvm::cast<ShapedType>(unwrap(type))
+ .getDimSize(static_cast<unsigned>(dim));
}
int64_t mlirShapedTypeGetDynamicSize() { return ShapedType::kDynamic; }
@@ -207,7 +214,9 @@ int64_t mlirShapedTypeGetDynamicStrideOrOffset() {
// Vector type.
//===----------------------------------------------------------------------===//
-bool mlirTypeIsAVector(MlirType type) { return unwrap(type).isa<VectorType>(); }
+bool mlirTypeIsAVector(MlirType type) {
+ return llvm::isa<VectorType>(unwrap(type));
+}
MlirType mlirVectorTypeGet(intptr_t rank, const int64_t *shape,
MlirType elementType) {
@@ -226,14 +235,16 @@ MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank,
// Ranked / Unranked tensor type.
//===----------------------------------------------------------------------===//
-bool mlirTypeIsATensor(MlirType type) { return unwrap(type).isa<TensorType>(); }
+bool mlirTypeIsATensor(MlirType type) {
+ return llvm::isa<TensorType>(unwrap(type));
+}
bool mlirTypeIsARankedTensor(MlirType type) {
- return unwrap(type).isa<RankedTensorType>();
+ return llvm::isa<RankedTensorType>(unwrap(type));
}
bool mlirTypeIsAUnrankedTensor(MlirType type) {
- return unwrap(type).isa<UnrankedTensorType>();
+ return llvm::isa<UnrankedTensorType>(unwrap(type));
}
MlirType mlirRankedTensorTypeGet(intptr_t rank, const int64_t *shape,
@@ -253,7 +264,7 @@ MlirType mlirRankedTensorTypeGetChecked(MlirLocation loc, intptr_t rank,
}
MlirAttribute mlirRankedTensorTypeGetEncoding(MlirType type) {
- return wrap(unwrap(type).cast<RankedTensorType>().getEncoding());
+ return wrap(llvm::cast<RankedTensorType>(unwrap(type)).getEncoding());
}
MlirType mlirUnrankedTensorTypeGet(MlirType elementType) {
@@ -269,7 +280,9 @@ MlirType mlirUnrankedTensorTypeGetChecked(MlirLocation loc,
// Ranked / Unranked MemRef type.
//===----------------------------------------------------------------------===//
-bool mlirTypeIsAMemRef(MlirType type) { return unwrap(type).isa<MemRefType>(); }
+bool mlirTypeIsAMemRef(MlirType type) {
+ return llvm::isa<MemRefType>(unwrap(type));
+}
MlirType mlirMemRefTypeGet(MlirType elementType, intptr_t rank,
const int64_t *shape, MlirAttribute layout,
@@ -278,7 +291,7 @@ MlirType mlirMemRefTypeGet(MlirType elementType, intptr_t rank,
llvm::ArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
mlirAttributeIsNull(layout)
? MemRefLayoutAttrInterface()
- : unwrap(layout).cast<MemRefLayoutAttrInterface>(),
+ : llvm::cast<MemRefLayoutAttrInterface>(unwrap(layout)),
unwrap(memorySpace)));
}
@@ -291,7 +304,7 @@ MlirType mlirMemRefTypeGetChecked(MlirLocation loc, MlirType elementType,
unwrap(elementType),
mlirAttributeIsNull(layout)
? MemRefLayoutAttrInterface()
- : unwrap(layout).cast<MemRefLayoutAttrInterface>(),
+ : llvm::cast<MemRefLayoutAttrInterface>(unwrap(layout)),
unwrap(memorySpace)));
}
@@ -313,19 +326,19 @@ MlirType mlirMemRefTypeContiguousGetChecked(MlirLocation loc,
}
MlirAttribute mlirMemRefTypeGetLayout(MlirType type) {
- return wrap(unwrap(type).cast<MemRefType>().getLayout());
+ return wrap(llvm::cast<MemRefType>(unwrap(type)).getLayout());
}
MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type) {
- return wrap(unwrap(type).cast<MemRefType>().getLayout().getAffineMap());
+ return wrap(llvm::cast<MemRefType>(unwrap(type)).getLayout().getAffineMap());
}
MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type) {
- return wrap(unwrap(type).cast<MemRefType>().getMemorySpace());
+ return wrap(llvm::cast<MemRefType>(unwrap(type)).getMemorySpace());
}
bool mlirTypeIsAUnrankedMemRef(MlirType type) {
- return unwrap(type).isa<UnrankedMemRefType>();
+ return llvm::isa<UnrankedMemRefType>(unwrap(type));
}
MlirType mlirUnrankedMemRefTypeGet(MlirType elementType,
@@ -342,14 +355,16 @@ MlirType mlirUnrankedMemRefTypeGetChecked(MlirLocation loc,
}
MlirAttribute mlirUnrankedMemrefGetMemorySpace(MlirType type) {
- return wrap(unwrap(type).cast<UnrankedMemRefType>().getMemorySpace());
+ return wrap(llvm::cast<UnrankedMemRefType>(unwrap(type)).getMemorySpace());
}
//===----------------------------------------------------------------------===//
// Tuple type.
//===----------------------------------------------------------------------===//
-bool mlirTypeIsATuple(MlirType type) { return unwrap(type).isa<TupleType>(); }
+bool mlirTypeIsATuple(MlirType type) {
+ return llvm::isa<TupleType>(unwrap(type));
+}
MlirType mlirTupleTypeGet(MlirContext ctx, intptr_t numElements,
MlirType const *elements) {
@@ -359,11 +374,12 @@ MlirType mlirTupleTypeGet(MlirContext ctx, intptr_t numElements,
}
intptr_t mlirTupleTypeGetNumTypes(MlirType type) {
- return unwrap(type).cast<TupleType>().size();
+ return llvm::cast<TupleType>(unwrap(type)).size();
}
MlirType mlirTupleTypeGetType(MlirType type, intptr_t pos) {
- return wrap(unwrap(type).cast<TupleType>().getType(static_cast<size_t>(pos)));
+ return wrap(
+ llvm::cast<TupleType>(unwrap(type)).getType(static_cast<size_t>(pos)));
}
//===----------------------------------------------------------------------===//
@@ -371,7 +387,7 @@ MlirType mlirTupleTypeGetType(MlirType type, intptr_t pos) {
//===----------------------------------------------------------------------===//
bool mlirTypeIsAFunction(MlirType type) {
- return unwrap(type).isa<FunctionType>();
+ return llvm::isa<FunctionType>(unwrap(type));
}
MlirType mlirFunctionTypeGet(MlirContext ctx, intptr_t numInputs,
@@ -385,30 +401,32 @@ MlirType mlirFunctionTypeGet(MlirContext ctx, intptr_t numInputs,
}
intptr_t mlirFunctionTypeGetNumInputs(MlirType type) {
- return unwrap(type).cast<FunctionType>().getNumInputs();
+ return llvm::cast<FunctionType>(unwrap(type)).getNumInputs();
}
intptr_t mlirFunctionTypeGetNumResults(MlirType type) {
- return unwrap(type).cast<FunctionType>().getNumResults();
+ return llvm::cast<FunctionType>(unwrap(type)).getNumResults();
}
MlirType mlirFunctionTypeGetInput(MlirType type, intptr_t pos) {
assert(pos >= 0 && "pos in array must be positive");
- return wrap(
- unwrap(type).cast<FunctionType>().getInput(static_cast<unsigned>(pos)));
+ return wrap(llvm::cast<FunctionType>(unwrap(type))
+ .getInput(static_cast<unsigned>(pos)));
}
MlirType mlirFunctionTypeGetResult(MlirType type, intptr_t pos) {
assert(pos >= 0 && "pos in array must be positive");
- return wrap(
- unwrap(type).cast<FunctionType>().getResult(static_cast<unsigned>(pos)));
+ return wrap(llvm::cast<FunctionType>(unwrap(type))
+ .getResult(static_cast<unsigned>(pos)));
}
//===----------------------------------------------------------------------===//
// Opaque type.
//===----------------------------------------------------------------------===//
-bool mlirTypeIsAOpaque(MlirType type) { return unwrap(type).isa<OpaqueType>(); }
+bool mlirTypeIsAOpaque(MlirType type) {
+ return llvm::isa<OpaqueType>(unwrap(type));
+}
MlirType mlirOpaqueTypeGet(MlirContext ctx, MlirStringRef dialectNamespace,
MlirStringRef typeData) {
@@ -418,9 +436,10 @@ MlirType mlirOpaqueTypeGet(MlirContext ctx, MlirStringRef dialectNamespace,
}
MlirStringRef mlirOpaqueTypeGetDialectNamespace(MlirType type) {
- return wrap(unwrap(type).cast<OpaqueType>().getDialectNamespace().strref());
+ return wrap(
+ llvm::cast<OpaqueType>(unwrap(type)).getDialectNamespace().strref());
}
MlirStringRef mlirOpaqueTypeGetData(MlirType type) {
- return wrap(unwrap(type).cast<OpaqueType>().getTypeData());
+ return wrap(llvm::cast<OpaqueType>(unwrap(type)).getTypeData());
}
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 79386dedfdd98..c0cf5977736f4 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -172,7 +172,7 @@ MlirAttribute mlirLocationGetAttribute(MlirLocation location) {
}
MlirLocation mlirLocationFromAttribute(MlirAttribute attribute) {
- return wrap(Location(unwrap(attribute).cast<LocationAttr>()));
+ return wrap(Location(llvm::cast<LocationAttr>(unwrap(attribute))));
}
MlirLocation mlirLocationFileLineColGet(MlirContext context,
@@ -727,33 +727,33 @@ bool mlirValueEqual(MlirValue value1, MlirValue value2) {
}
bool mlirValueIsABlockArgument(MlirValue value) {
- return unwrap(value).isa<BlockArgument>();
+ return llvm::isa<BlockArgument>(unwrap(value));
}
bool mlirValueIsAOpResult(MlirValue value) {
- return unwrap(value).isa<OpResult>();
+ return llvm::isa<OpResult>(unwrap(value));
}
MlirBlock mlirBlockArgumentGetOwner(MlirValue value) {
- return wrap(unwrap(value).cast<BlockArgument>().getOwner());
+ return wrap(llvm::cast<BlockArgument>(unwrap(value)).getOwner());
}
intptr_t mlirBlockArgumentGetArgNumber(MlirValue value) {
return static_cast<intptr_t>(
- unwrap(value).cast<BlockArgument>().getArgNumber());
+ llvm::cast<BlockArgument>(unwrap(value)).getArgNumber());
}
void mlirBlockArgumentSetType(MlirValue value, MlirType type) {
- unwrap(value).cast<BlockArgument>().setType(unwrap(type));
+ llvm::cast<BlockArgument>(unwrap(value)).setType(unwrap(type));
}
MlirOperation mlirOpResultGetOwner(MlirValue value) {
- return wrap(unwrap(value).cast<OpResult>().getOwner());
+ return wrap(llvm::cast<OpResult>(unwrap(value)).getOwner());
}
intptr_t mlirOpResultGetResultNumber(MlirValue value) {
return static_cast<intptr_t>(
- unwrap(value).cast<OpResult>().getResultNumber());
+ llvm::cast<OpResult>(unwrap(value)).getResultNumber());
}
MlirType mlirValueGetType(MlirValue value) {
@@ -857,7 +857,7 @@ MlirContext mlirAttributeGetContext(MlirAttribute attribute) {
MlirType mlirAttributeGetType(MlirAttribute attribute) {
Attribute attr = unwrap(attribute);
- if (auto typedAttr = attr.dyn_cast<TypedAttr>())
+ if (auto typedAttr = llvm::dyn_cast<TypedAttr>(attr))
return wrap(typedAttr.getType());
return wrap(NoneType::get(attr.getContext()));
}
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index cafa6a8fa2dd0..105535b05de85 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -48,14 +48,15 @@ void AMDGPUDialect::initialize() {
//===----------------------------------------------------------------------===//
template <typename T>
static LogicalResult verifyRawBufferOp(T &op) {
- MemRefType bufferType = op.getMemref().getType().template cast<MemRefType>();
+ MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
Attribute memorySpace = bufferType.getMemorySpace();
bool isGlobal = false;
if (!memorySpace)
isGlobal = true;
- else if (auto intMemorySpace = memorySpace.dyn_cast<IntegerAttr>())
+ else if (auto intMemorySpace = llvm::dyn_cast<IntegerAttr>(memorySpace))
isGlobal = intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
- else if (auto gpuMemorySpace = memorySpace.dyn_cast<gpu::AddressSpaceAttr>())
+ else if (auto gpuMemorySpace =
+ llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
isGlobal = gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
if (!isGlobal)
@@ -216,11 +217,11 @@ LogicalResult MFMAOp::verify() {
Type sourceElem = sourceType, destElem = destType;
uint32_t sourceLen = 1, destLen = 1;
- if (auto sourceVector = sourceType.dyn_cast<VectorType>()) {
+ if (auto sourceVector = llvm::dyn_cast<VectorType>(sourceType)) {
sourceLen = sourceVector.getNumElements();
sourceElem = sourceVector.getElementType();
}
- if (auto destVector = destType.dyn_cast<VectorType>()) {
+ if (auto destVector = llvm::dyn_cast<VectorType>(destType)) {
destLen = destVector.getNumElements();
destElem = destVector.getElementType();
}
@@ -229,7 +230,7 @@ LogicalResult MFMAOp::verify() {
if (sourceElem.isFloat8E5M2FNUZ() || sourceElem.isFloat8E4M3FNUZ()) {
int64_t sourceBLen = 1;
Type sourceBElem = sourceBType;
- if (auto sourceBVector = sourceBType.dyn_cast<VectorType>()) {
+ if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
sourceBLen = sourceBVector.getNumElements();
sourceBElem = sourceBVector.getElementType();
}
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index f3626afcd7059..2009c393ff4dc 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -38,7 +38,7 @@ using namespace mlir::affine;
/// top level of a `AffineScope` region is always a valid symbol for all
/// uses in that region.
bool mlir::affine::isTopLevelValue(Value value, Region *region) {
- if (auto arg = value.dyn_cast<BlockArgument>())
+ if (auto arg = llvm::dyn_cast<BlockArgument>(value))
return arg.getParentRegion() == region;
return value.getDefiningOp()->getParentRegion() == region;
}
@@ -62,7 +62,7 @@ remainsLegalAfterInline(Value value, Region *src, Region *dest,
// If it's a top-level value because it's a block operand, i.e. a
// function argument, check whether the value replacing it after
// inlining is a valid dimension in the new region.
- if (value.isa<BlockArgument>())
+ if (llvm::isa<BlockArgument>(value))
return legalityCheck(mapping.lookup(value), dest);
// If it's a top-level value because it's defined in the region,
@@ -234,7 +234,7 @@ Operation *AffineDialect::materializeConstant(OpBuilder &builder,
/// conservatively assume it is not top-level. A value of index type defined at
/// the top level is always a valid symbol.
bool mlir::affine::isTopLevelValue(Value value) {
- if (auto arg = value.dyn_cast<BlockArgument>()) {
+ if (auto arg = llvm::dyn_cast<BlockArgument>(value)) {
// The block owning the argument may be unlinked, e.g. when the surrounding
// region has not yet been attached to an Op, at which point the parent Op
// is null.
@@ -273,7 +273,7 @@ bool mlir::affine::isValidDim(Value value) {
// This value has to be a block argument for an op that has the
// `AffineScope` trait or for an affine.for or affine.parallel.
- auto *parentOp = value.cast<BlockArgument>().getOwner()->getParentOp();
+ auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
return parentOp && (parentOp->hasTrait<OpTrait::AffineScope>() ||
isa<AffineForOp, AffineParallelOp>(parentOp));
}
@@ -296,7 +296,7 @@ bool mlir::affine::isValidDim(Value value, Region *region) {
if (!op) {
// This value has to be a block argument for an affine.for or an
// affine.parallel.
- auto *parentOp = value.cast<BlockArgument>().getOwner()->getParentOp();
+ auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
return isa<AffineForOp, AffineParallelOp>(parentOp);
}
@@ -334,7 +334,7 @@ static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region) {
// Conservatively handle remaining BlockArguments as non-valid symbols.
// E.g. scf.for iterArgs.
- if (dimOp.getShapedValue().template isa<BlockArgument>())
+ if (llvm::isa<BlockArgument>(dimOp.getShapedValue()))
return false;
// The dim op is also okay if its operand memref is a view/subview whose
@@ -1221,7 +1221,8 @@ static void materializeConstants(OpBuilder &b, Location loc,
// AffineDialect materializer will create invalid `arith.constant`
// operations if the provided Attribute is any other kind of integer.
constants.push_back(dialect->materializeConstant(
- b, b.getIndexAttr(ofr.get<Attribute>().cast<IntegerAttr>().getInt()),
+ b,
+ b.getIndexAttr(llvm::cast<IntegerAttr>(ofr.get<Attribute>()).getInt()),
b.getIndexType(), loc));
actualValues.push_back(constants.back()->getResult(0));
}
@@ -1785,11 +1786,11 @@ ParseResult AffineDmaStartOp::parse(OpAsmParser &parser,
}
LogicalResult AffineDmaStartOp::verifyInvariantsImpl() {
- if (!getOperand(getSrcMemRefOperandIndex()).getType().isa<MemRefType>())
+ if (!llvm::isa<MemRefType>(getOperand(getSrcMemRefOperandIndex()).getType()))
return emitOpError("expected DMA source to be of memref type");
- if (!getOperand(getDstMemRefOperandIndex()).getType().isa<MemRefType>())
+ if (!llvm::isa<MemRefType>(getOperand(getDstMemRefOperandIndex()).getType()))
return emitOpError("expected DMA destination to be of memref type");
- if (!getOperand(getTagMemRefOperandIndex()).getType().isa<MemRefType>())
+ if (!llvm::isa<MemRefType>(getOperand(getTagMemRefOperandIndex()).getType()))
return emitOpError("expected DMA tag to be of memref type");
unsigned numInputsAllMaps = getSrcMap().getNumInputs() +
@@ -1888,7 +1889,7 @@ ParseResult AffineDmaWaitOp::parse(OpAsmParser &parser,
parser.resolveOperand(numElementsInfo, indexType, result.operands))
return failure();
- if (!type.isa<MemRefType>())
+ if (!llvm::isa<MemRefType>(type))
return parser.emitError(parser.getNameLoc(),
"expected tag to be of memref type");
@@ -1899,7 +1900,7 @@ ParseResult AffineDmaWaitOp::parse(OpAsmParser &parser,
}
LogicalResult AffineDmaWaitOp::verifyInvariantsImpl() {
- if (!getOperand(0).getType().isa<MemRefType>())
+ if (!llvm::isa<MemRefType>(getOperand(0).getType()))
return emitOpError("expected DMA tag to be of memref type");
Region *scope = getAffineScope(*this);
for (auto idx : getTagIndices()) {
@@ -2073,7 +2074,7 @@ static ParseResult parseBound(bool isLower, OperationState &result,
return failure();
// Parse full form - affine map followed by dim and symbol list.
- if (auto affineMapAttr = boundAttr.dyn_cast<AffineMapAttr>()) {
+ if (auto affineMapAttr = llvm::dyn_cast<AffineMapAttr>(boundAttr)) {
unsigned currentNumOperands = result.operands.size();
unsigned numDims;
if (parseDimAndSymbolList(p, result.operands, numDims))
@@ -2106,7 +2107,7 @@ static ParseResult parseBound(bool isLower, OperationState &result,
}
// Parse custom assembly form.
- if (auto integerAttr = boundAttr.dyn_cast<IntegerAttr>()) {
+ if (auto integerAttr = llvm::dyn_cast<IntegerAttr>(boundAttr)) {
result.attributes.pop_back();
result.addAttribute(
boundAttrStrName,
@@ -2296,9 +2297,9 @@ static LogicalResult foldLoopBounds(AffineForOp forOp) {
// Compute the max or min as applicable over the results.
assert(!foldedResults.empty() && "bounds should have at least one result");
- auto maxOrMin = foldedResults[0].cast<IntegerAttr>().getValue();
+ auto maxOrMin = llvm::cast<IntegerAttr>(foldedResults[0]).getValue();
for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
- auto foldedResult = foldedResults[i].cast<IntegerAttr>().getValue();
+ auto foldedResult = llvm::cast<IntegerAttr>(foldedResults[i]).getValue();
maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
: llvm::APIntOps::smin(maxOrMin, foldedResult);
}
@@ -2653,7 +2654,7 @@ bool mlir::affine::isAffineInductionVar(Value val) {
}
AffineForOp mlir::affine::getForInductionVarOwner(Value val) {
- auto ivArg = val.dyn_cast<BlockArgument>();
+ auto ivArg = llvm::dyn_cast<BlockArgument>(val);
if (!ivArg || !ivArg.getOwner())
return AffineForOp();
auto *containingInst = ivArg.getOwner()->getParent()->getParentOp();
@@ -2664,7 +2665,7 @@ AffineForOp mlir::affine::getForInductionVarOwner(Value val) {
}
AffineParallelOp mlir::affine::getAffineParallelInductionVarOwner(Value val) {
- auto ivArg = val.dyn_cast<BlockArgument>();
+ auto ivArg = llvm::dyn_cast<BlockArgument>(val);
if (!ivArg || !ivArg.getOwner())
return nullptr;
Operation *containingOp = ivArg.getOwner()->getParentOp();
@@ -3113,7 +3114,7 @@ void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
result.addOperands(operands);
if (map)
result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
- auto memrefType = operands[0].getType().cast<MemRefType>();
+ auto memrefType = llvm::cast<MemRefType>(operands[0].getType());
result.types.push_back(memrefType.getElementType());
}
@@ -3122,14 +3123,14 @@ void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
result.addOperands(memref);
result.addOperands(mapOperands);
- auto memrefType = memref.getType().cast<MemRefType>();
+ auto memrefType = llvm::cast<MemRefType>(memref.getType());
result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
result.types.push_back(memrefType.getElementType());
}
void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
Value memref, ValueRange indices) {
- auto memrefType = memref.getType().cast<MemRefType>();
+ auto memrefType = llvm::cast<MemRefType>(memref.getType());
int64_t rank = memrefType.getRank();
// Create identity map for memrefs with at least one dimension or () -> ()
// for zero-dimensional memrefs.
@@ -3238,11 +3239,11 @@ OpFoldResult AffineLoadOp::fold(FoldAdaptor adaptor) {
// Check if the global memref is a constant.
auto cstAttr =
- global.getConstantInitValue().dyn_cast_or_null<DenseElementsAttr>();
+ llvm::dyn_cast_or_null<DenseElementsAttr>(global.getConstantInitValue());
if (!cstAttr)
return {};
// If it's a splat constant, we can fold irrespective of indices.
- if (auto splatAttr = cstAttr.dyn_cast<SplatElementsAttr>())
+ if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(cstAttr))
return splatAttr.getSplatValue<Attribute>();
// Otherwise, we can fold only if we know the indices.
if (!getAffineMap().isConstant())
@@ -3271,7 +3272,7 @@ void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
Value valueToStore, Value memref,
ValueRange indices) {
- auto memrefType = memref.getType().cast<MemRefType>();
+ auto memrefType = llvm::cast<MemRefType>(memref.getType());
int64_t rank = memrefType.getRank();
// Create identity map for memrefs with at least one dimension or () -> ()
// for zero-dimensional memrefs.
@@ -4017,7 +4018,7 @@ LogicalResult AffineParallelOp::verify() {
// Verify reduction ops are all valid
for (Attribute attr : getReductions()) {
- auto intAttr = attr.dyn_cast<IntegerAttr>();
+ auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
if (!intAttr || !arith::symbolizeAtomicRMWKind(intAttr.getInt()))
return emitOpError("invalid reduction attribute");
}
@@ -4119,7 +4120,7 @@ void AffineParallelOp::print(OpAsmPrinter &p) {
p << " reduce (";
llvm::interleaveComma(getReductions(), p, [&](auto &attr) {
arith::AtomicRMWKind sym = *arith::symbolizeAtomicRMWKind(
- attr.template cast<IntegerAttr>().getInt());
+ llvm::cast<IntegerAttr>(attr).getInt());
p << "\"" << arith::stringifyAtomicRMWKind(sym) << "\"";
});
p << ") -> (" << getResultTypes() << ")";
@@ -4429,7 +4430,7 @@ void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
VectorType resultType, Value memref,
ValueRange indices) {
- auto memrefType = memref.getType().cast<MemRefType>();
+ auto memrefType = llvm::cast<MemRefType>(memref.getType());
int64_t rank = memrefType.getRank();
// Create identity map for memrefs with at least one dimension or () -> ()
// for zero-dimensional memrefs.
@@ -4520,7 +4521,7 @@ void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
Value valueToStore, Value memref,
ValueRange indices) {
- auto memrefType = memref.getType().cast<MemRefType>();
+ auto memrefType = llvm::cast<MemRefType>(memref.getType());
int64_t rank = memrefType.getRank();
// Create identity map for memrefs with at least one dimension or () -> ()
// for zero-dimensional memrefs.
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index b4b0572fdee75..a58354413390e 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -36,15 +36,15 @@ using namespace mlir::arith;
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
Attribute lhs, Attribute rhs) {
return builder.getIntegerAttr(res.getType(),
- lhs.cast<IntegerAttr>().getInt() +
- rhs.cast<IntegerAttr>().getInt());
+ llvm::cast<IntegerAttr>(lhs).getInt() +
+ llvm::cast<IntegerAttr>(rhs).getInt());
}
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
Attribute lhs, Attribute rhs) {
return builder.getIntegerAttr(res.getType(),
- lhs.cast<IntegerAttr>().getInt() -
- rhs.cast<IntegerAttr>().getInt());
+ llvm::cast<IntegerAttr>(lhs).getInt() -
+ llvm::cast<IntegerAttr>(rhs).getInt());
}
/// Invert an integer comparison predicate.
@@ -92,11 +92,11 @@ static int64_t getScalarOrElementWidth(Value value) {
}
static FailureOr<APInt> getIntOrSplatIntValue(Attribute attr) {
- if (auto intAttr = attr.dyn_cast<IntegerAttr>())
+ if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
return intAttr.getValue();
- if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>())
- if (splatAttr.getElementType().isa<IntegerType>())
+ if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(attr))
+ if (llvm::isa<IntegerType>(splatAttr.getElementType()))
return splatAttr.getSplatValue<APInt>();
return failure();
@@ -117,11 +117,11 @@ namespace {
/// Return the type of the same shape (scalar, vector or tensor) containing i1.
static Type getI1SameShape(Type type) {
auto i1Type = IntegerType::get(type.getContext(), 1);
- if (auto tensorType = type.dyn_cast<RankedTensorType>())
+ if (auto tensorType = llvm::dyn_cast<RankedTensorType>(type))
return RankedTensorType::get(tensorType.getShape(), i1Type);
- if (type.isa<UnrankedTensorType>())
+ if (llvm::isa<UnrankedTensorType>(type))
return UnrankedTensorType::get(i1Type);
- if (auto vectorType = type.dyn_cast<VectorType>())
+ if (auto vectorType = llvm::dyn_cast<VectorType>(type))
return VectorType::get(vectorType.getShape(), i1Type,
vectorType.getNumScalableDims());
return i1Type;
@@ -134,8 +134,8 @@ static Type getI1SameShape(Type type) {
void arith::ConstantOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
auto type = getType();
- if (auto intCst = getValue().dyn_cast<IntegerAttr>()) {
- auto intType = type.dyn_cast<IntegerType>();
+ if (auto intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
+ auto intType = llvm::dyn_cast<IntegerType>(type);
// Sugar i1 constants with 'true' and 'false'.
if (intType && intType.getWidth() == 1)
@@ -163,10 +163,11 @@ LogicalResult arith::ConstantOp::verify() {
<< " must match return type: " << type;
}
// Integer values must be signless.
- if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
+ if (llvm::isa<IntegerType>(type) &&
+ !llvm::cast<IntegerType>(type).isSignless())
return emitOpError("integer return type must be signless");
// Any float or elements attribute are acceptable.
- if (!getValue().isa<IntegerAttr, FloatAttr, ElementsAttr>()) {
+ if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) {
return emitOpError(
"value must be an integer, float, or elements attribute");
}
@@ -175,14 +176,15 @@ LogicalResult arith::ConstantOp::verify() {
bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
// The value's type must be the same as the provided type.
- auto typedAttr = value.dyn_cast<TypedAttr>();
+ auto typedAttr = llvm::dyn_cast<TypedAttr>(value);
if (!typedAttr || typedAttr.getType() != type)
return false;
// Integer values must be signless.
- if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
+ if (llvm::isa<IntegerType>(type) &&
+ !llvm::cast<IntegerType>(type).isSignless())
return false;
// Integer, float, and element attributes are buildable.
- return value.isa<IntegerAttr, FloatAttr, ElementsAttr>();
+ return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
}
ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value,
@@ -223,7 +225,7 @@ void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
bool arith::ConstantFloatOp::classof(Operation *op) {
if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
- return constOp.getType().isa<FloatType>();
+ return llvm::isa<FloatType>(constOp.getType());
return false;
}
@@ -275,7 +277,7 @@ void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
std::optional<SmallVector<int64_t, 4>>
arith::AddUIExtendedOp::getShapeForUnroll() {
- if (auto vt = getType(0).dyn_cast<VectorType>())
+ if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
return llvm::to_vector<4>(vt.getShape());
return std::nullopt;
}
@@ -309,7 +311,7 @@ arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
[](APInt a, const APInt &b) { return std::move(a) + b; })) {
Attribute overflowAttr = constFoldBinaryOp<IntegerAttr>(
ArrayRef({sumAttr, adaptor.getLhs()}),
- getI1SameShape(sumAttr.cast<TypedAttr>().getType()),
+ getI1SameShape(llvm::cast<TypedAttr>(sumAttr).getType()),
calculateUnsignedOverflow);
if (!overflowAttr)
return failure();
@@ -385,7 +387,7 @@ OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) {
std::optional<SmallVector<int64_t, 4>>
arith::MulSIExtendedOp::getShapeForUnroll() {
- if (auto vt = getType(0).dyn_cast<VectorType>())
+ if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
return llvm::to_vector<4>(vt.getShape());
return std::nullopt;
}
@@ -433,7 +435,7 @@ void arith::MulSIExtendedOp::getCanonicalizationPatterns(
std::optional<SmallVector<int64_t, 4>>
arith::MulUIExtendedOp::getShapeForUnroll() {
- if (auto vt = getType(0).dyn_cast<VectorType>())
+ if (auto vt = llvm::dyn_cast<VectorType>(getType(0)))
return llvm::to_vector<4>(vt.getShape());
return std::nullopt;
}
@@ -1093,11 +1095,11 @@ using type_list = std::tuple<Types...> *;
template <typename... ShapedTypes, typename... ElementTypes>
static Type getUnderlyingType(Type type, type_list<ShapedTypes...>,
type_list<ElementTypes...>) {
- if (type.isa<ShapedType>() && !type.isa<ShapedTypes...>())
+ if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type))
return {};
auto underlyingType = getElementTypeOrSelf(type);
- if (!underlyingType.isa<ElementTypes...>())
+ if (!llvm::isa<ElementTypes...>(underlyingType))
return {};
return underlyingType;
@@ -1133,7 +1135,8 @@ static LogicalResult verifyExtOp(Op op) {
Type srcType = getElementTypeOrSelf(op.getIn().getType());
Type dstType = getElementTypeOrSelf(op.getType());
- if (srcType.cast<ValType>().getWidth() >= dstType.cast<ValType>().getWidth())
+ if (llvm::cast<ValType>(srcType).getWidth() >=
+ llvm::cast<ValType>(dstType).getWidth())
return op.emitError("result type ")
<< dstType << " must be wider than operand type " << srcType;
@@ -1146,7 +1149,8 @@ static LogicalResult verifyTruncateOp(Op op) {
Type srcType = getElementTypeOrSelf(op.getIn().getType());
Type dstType = getElementTypeOrSelf(op.getType());
- if (srcType.cast<ValType>().getWidth() <= dstType.cast<ValType>().getWidth())
+ if (llvm::cast<ValType>(srcType).getWidth() <=
+ llvm::cast<ValType>(dstType).getWidth())
return op.emitError("result type ")
<< dstType << " must be shorter than operand type " << srcType;
@@ -1179,7 +1183,7 @@ OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
}
Type resType = getElementTypeOrSelf(getType());
- unsigned bitWidth = resType.cast<IntegerType>().getWidth();
+ unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
return constFoldCastOp<IntegerAttr, IntegerAttr>(
adaptor.getOperands(), getType(),
[bitWidth](const APInt &a, bool &castStatus) {
@@ -1206,7 +1210,7 @@ OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
}
Type resType = getElementTypeOrSelf(getType());
- unsigned bitWidth = resType.cast<IntegerType>().getWidth();
+ unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
return constFoldCastOp<IntegerAttr, IntegerAttr>(
adaptor.getOperands(), getType(),
[bitWidth](const APInt &a, bool &castStatus) {
@@ -1259,8 +1263,8 @@ OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
Type dstType = getElementTypeOrSelf(getType());
// trunci(zexti(a)) -> trunci(a)
// trunci(sexti(a)) -> trunci(a)
- if (srcType.cast<IntegerType>().getWidth() >
- dstType.cast<IntegerType>().getWidth()) {
+ if (llvm::cast<IntegerType>(srcType).getWidth() >
+ llvm::cast<IntegerType>(dstType).getWidth()) {
setOperand(src);
return getResult();
}
@@ -1276,7 +1280,7 @@ OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
}
Type resType = getElementTypeOrSelf(getType());
- unsigned bitWidth = resType.cast<IntegerType>().getWidth();
+ unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
return constFoldCastOp<IntegerAttr, IntegerAttr>(
adaptor.getOperands(), getType(),
[bitWidth](const APInt &a, bool &castStatus) {
@@ -1307,12 +1311,12 @@ LogicalResult arith::TruncIOp::verify() {
/// can be represented without precision loss or rounding.
OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
auto constOperand = adaptor.getIn();
- if (!constOperand || !constOperand.isa<FloatAttr>())
+ if (!constOperand || !llvm::isa<FloatAttr>(constOperand))
return {};
// Convert to target type via 'double'.
double sourceValue =
- constOperand.dyn_cast<FloatAttr>().getValue().convertToDouble();
+ llvm::dyn_cast<FloatAttr>(constOperand).getValue().convertToDouble();
auto targetAttr = FloatAttr::get(getType(), sourceValue);
// Propagate if constant's value does not change after truncation.
@@ -1376,7 +1380,7 @@ OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
return constFoldCastOp<IntegerAttr, FloatAttr>(
adaptor.getOperands(), getType(),
[&resEleType](const APInt &a, bool &castStatus) {
- FloatType floatTy = resEleType.cast<FloatType>();
+ FloatType floatTy = llvm::cast<FloatType>(resEleType);
APFloat apf(floatTy.getFloatSemantics(),
APInt::getZero(floatTy.getWidth()));
apf.convertFromAPInt(a, /*IsSigned=*/false,
@@ -1398,7 +1402,7 @@ OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
return constFoldCastOp<IntegerAttr, FloatAttr>(
adaptor.getOperands(), getType(),
[&resEleType](const APInt &a, bool &castStatus) {
- FloatType floatTy = resEleType.cast<FloatType>();
+ FloatType floatTy = llvm::cast<FloatType>(resEleType);
APFloat apf(floatTy.getFloatSemantics(),
APInt::getZero(floatTy.getWidth()));
apf.convertFromAPInt(a, /*IsSigned=*/true,
@@ -1416,7 +1420,7 @@ bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
Type resType = getElementTypeOrSelf(getType());
- unsigned bitWidth = resType.cast<IntegerType>().getWidth();
+ unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
return constFoldCastOp<FloatAttr, IntegerAttr>(
adaptor.getOperands(), getType(),
[&bitWidth](const APFloat &a, bool &castStatus) {
@@ -1438,7 +1442,7 @@ bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
Type resType = getElementTypeOrSelf(getType());
- unsigned bitWidth = resType.cast<IntegerType>().getWidth();
+ unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
return constFoldCastOp<FloatAttr, IntegerAttr>(
adaptor.getOperands(), getType(),
[&bitWidth](const APFloat &a, bool &castStatus) {
@@ -1542,18 +1546,18 @@ OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
return {};
/// Bitcast dense elements.
- if (auto denseAttr = operand.dyn_cast_or_null<DenseElementsAttr>())
- return denseAttr.bitcast(resType.cast<ShapedType>().getElementType());
+ if (auto denseAttr = llvm::dyn_cast_or_null<DenseElementsAttr>(operand))
+ return denseAttr.bitcast(llvm::cast<ShapedType>(resType).getElementType());
/// Other shaped types unhandled.
- if (resType.isa<ShapedType>())
+ if (llvm::isa<ShapedType>(resType))
return {};
/// Bitcast integer or float to integer or float.
- APInt bits = operand.isa<FloatAttr>()
- ? operand.cast<FloatAttr>().getValue().bitcastToAPInt()
- : operand.cast<IntegerAttr>().getValue();
+ APInt bits = llvm::isa<FloatAttr>(operand)
+ ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
+ : llvm::cast<IntegerAttr>(operand).getValue();
- if (auto resFloatType = resType.dyn_cast<FloatType>())
+ if (auto resFloatType = llvm::dyn_cast<FloatType>(resType))
return FloatAttr::get(resType,
APFloat(resFloatType.getFloatSemantics(), bits));
return IntegerAttr::get(resType, bits);
@@ -1618,18 +1622,18 @@ static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
auto boolAttr = BoolAttr::get(ctx, value);
- ShapedType shapedType = type.dyn_cast_or_null<ShapedType>();
+ ShapedType shapedType = llvm::dyn_cast_or_null<ShapedType>(type);
if (!shapedType)
return boolAttr;
return DenseElementsAttr::get(shapedType, boolAttr);
}
static std::optional<int64_t> getIntegerWidth(Type t) {
- if (auto intType = t.dyn_cast<IntegerType>()) {
+ if (auto intType = llvm::dyn_cast<IntegerType>(t)) {
return intType.getWidth();
}
- if (auto vectorIntType = t.dyn_cast<VectorType>()) {
- return vectorIntType.getElementType().cast<IntegerType>().getWidth();
+ if (auto vectorIntType = llvm::dyn_cast<VectorType>(t)) {
+ return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth();
}
return std::nullopt;
}
@@ -1817,7 +1821,7 @@ class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
// Get the width of the mantissa. We don't want to hack on conversions that
// might lose information from the integer, e.g. "i64 -> float"
- FloatType floatTy = op.getRhs().getType().cast<FloatType>();
+ FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType());
int mantissaWidth = floatTy.getFPMantissaWidth();
if (mantissaWidth <= 0)
return failure();
@@ -1837,7 +1841,7 @@ class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
// Check to see that the input is converted from an integer type that is
// small enough that preserves all bits.
- auto intTy = intVal.getType().cast<IntegerType>();
+ auto intTy = llvm::cast<IntegerType>(intVal.getType());
auto intWidth = intTy.getWidth();
// Number of bits representing values, as opposed to the sign
@@ -2103,7 +2107,7 @@ struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
LogicalResult matchAndRewrite(arith::SelectOp op,
PatternRewriter &rewriter) const override {
// Cannot extui i1 to i1, or i1 to f32
- if (!op.getType().isa<IntegerType>() || op.getType().isInteger(1))
+ if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1))
return failure();
// select %x, c1, %c0 => extui %arg
@@ -2230,7 +2234,8 @@ void arith::SelectOp::print(OpAsmPrinter &p) {
p << " " << getOperands();
p.printOptionalAttrDict((*this)->getAttrs());
p << " : ";
- if (ShapedType condType = getCondition().getType().dyn_cast<ShapedType>())
+ if (ShapedType condType =
+ llvm::dyn_cast<ShapedType>(getCondition().getType()))
p << condType << ", ";
p << getType();
}
@@ -2243,7 +2248,7 @@ LogicalResult arith::SelectOp::verify() {
// If the result type is a vector or tensor, the type can be a mask with the
// same elements.
Type resultType = getType();
- if (!resultType.isa<TensorType, VectorType>())
+ if (!llvm::isa<TensorType, VectorType>(resultType))
return emitOpError() << "expected condition to be a signless i1, but got "
<< conditionType;
Type shapedConditionType = getI1SameShape(resultType);
@@ -2320,7 +2325,7 @@ TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
case AtomicRMWKind::maxf:
return builder.getFloatAttr(
resultType,
- APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
+ APFloat::getInf(llvm::cast<FloatType>(resultType).getFloatSemantics(),
/*Negative=*/true));
case AtomicRMWKind::addf:
case AtomicRMWKind::addi:
@@ -2330,24 +2335,24 @@ TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
case AtomicRMWKind::andi:
return builder.getIntegerAttr(
resultType,
- APInt::getAllOnes(resultType.cast<IntegerType>().getWidth()));
+ APInt::getAllOnes(llvm::cast<IntegerType>(resultType).getWidth()));
case AtomicRMWKind::maxs:
return builder.getIntegerAttr(
- resultType,
- APInt::getSignedMinValue(resultType.cast<IntegerType>().getWidth()));
+ resultType, APInt::getSignedMinValue(
+ llvm::cast<IntegerType>(resultType).getWidth()));
case AtomicRMWKind::minf:
return builder.getFloatAttr(
resultType,
- APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
+ APFloat::getInf(llvm::cast<FloatType>(resultType).getFloatSemantics(),
/*Negative=*/false));
case AtomicRMWKind::mins:
return builder.getIntegerAttr(
- resultType,
- APInt::getSignedMaxValue(resultType.cast<IntegerType>().getWidth()));
+ resultType, APInt::getSignedMaxValue(
+ llvm::cast<IntegerType>(resultType).getWidth()));
case AtomicRMWKind::minu:
return builder.getIntegerAttr(
resultType,
- APInt::getMaxValue(resultType.cast<IntegerType>().getWidth()));
+ APInt::getMaxValue(llvm::cast<IntegerType>(resultType).getWidth()));
case AtomicRMWKind::muli:
return builder.getIntegerAttr(resultType, 1);
case AtomicRMWKind::mulf:
diff --git a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
index 16ff6dd2f6204..71eb36bb07a6e 100644
--- a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
@@ -25,7 +25,7 @@ using namespace mlir::intrange;
void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- auto constAttr = getValue().dyn_cast_or_null<IntegerAttr>();
+ auto constAttr = llvm::dyn_cast_or_null<IntegerAttr>(getValue());
if (constAttr) {
const APInt &value = constAttr.getValue();
setResultRange(getResult(), ConstantIntRanges::constant(value));
diff --git a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
index d24f1870ac078..9c6b50e767ea2 100644
--- a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -37,7 +37,7 @@ struct ConstantOpInterface
auto constantOp = cast<ConstantOp>(op);
assert(value == constantOp.getResult() && "invalid value");
- if (auto attr = constantOp.getValue().dyn_cast<IntegerAttr>())
+ if (auto attr = llvm::dyn_cast<IntegerAttr>(constantOp.getValue()))
cstr.bound(value) == attr.getInt();
}
};
diff --git a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
index 1ea2fad56272a..765242b5416af 100644
--- a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
+++ b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
@@ -28,7 +28,7 @@ using namespace mlir::arm_sve;
/// Return the scalable vector of the same shape and containing i1.
static Type getI1SameShape(Type type) {
auto i1Type = IntegerType::get(type.getContext(), 1);
- if (auto sVectorType = type.dyn_cast<VectorType>())
+ if (auto sVectorType = llvm::dyn_cast<VectorType>(type))
return VectorType::get(sVectorType.getShape(), i1Type,
sVectorType.getNumScalableDims());
return nullptr;
diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp
index 8ccec11b8e218..7d018bf8f3a3d 100644
--- a/mlir/lib/Dialect/Async/IR/Async.cpp
+++ b/mlir/lib/Dialect/Async/IR/Async.cpp
@@ -42,7 +42,7 @@ LogicalResult YieldOp::verify() {
auto executeOp = (*this)->getParentOfType<ExecuteOp>();
auto types =
llvm::map_range(executeOp.getBodyResults(), [](const OpResult &result) {
- return result.getType().cast<ValueType>().getValueType();
+ return llvm::cast<ValueType>(result.getType()).getValueType();
});
if (getOperandTypes() != types)
@@ -71,7 +71,7 @@ ExecuteOp::getSuccessorEntryOperands(std::optional<unsigned> index) {
bool ExecuteOp::areTypesCompatible(Type lhs, Type rhs) {
const auto getValueOrTokenType = [](Type type) {
- if (auto value = type.dyn_cast<ValueType>())
+ if (auto value = llvm::dyn_cast<ValueType>(type))
return value.getValueType();
return type;
};
@@ -118,7 +118,7 @@ void ExecuteOp::build(OpBuilder &builder, OperationState &result,
bodyRegion->push_back(new Block);
Block &bodyBlock = bodyRegion->front();
for (Value operand : operands) {
- auto valueType = operand.getType().dyn_cast<ValueType>();
+ auto valueType = llvm::dyn_cast<ValueType>(operand.getType());
bodyBlock.addArgument(valueType ? valueType.getValueType()
: operand.getType(),
operand.getLoc());
@@ -195,7 +195,7 @@ ParseResult ExecuteOp::parse(OpAsmParser &parser, OperationState &result) {
parser.parseColonType(valueTypes.emplace_back()))
return failure();
- auto valueTy = valueTypes.back().dyn_cast<ValueType>();
+ auto valueTy = llvm::dyn_cast<ValueType>(valueTypes.back());
unwrappedArgs.back().type = valueTy ? valueTy.getValueType() : Type();
return success();
};
@@ -234,7 +234,7 @@ ParseResult ExecuteOp::parse(OpAsmParser &parser, OperationState &result) {
LogicalResult ExecuteOp::verifyRegions() {
// Unwrap async.execute value operands types.
auto unwrappedTypes = llvm::map_range(getBodyOperands(), [](Value operand) {
- return operand.getType().cast<ValueType>().getValueType();
+ return llvm::cast<ValueType>(operand.getType()).getValueType();
});
// Verify that unwrapped argument types matches the body region arguments.
@@ -285,7 +285,7 @@ void AwaitOp::build(OpBuilder &builder, OperationState &result, Value operand,
result.attributes.append(attrs.begin(), attrs.end());
// Add unwrapped async.value type to the returned values types.
- if (auto valueType = operand.getType().dyn_cast<ValueType>())
+ if (auto valueType = llvm::dyn_cast<ValueType>(operand.getType()))
result.addTypes(valueType.getValueType());
}
@@ -295,7 +295,7 @@ static ParseResult parseAwaitResultType(OpAsmParser &parser, Type &operandType,
return failure();
// Add unwrapped async.value type to the returned values types.
- if (auto valueType = operandType.dyn_cast<ValueType>())
+ if (auto valueType = llvm::dyn_cast<ValueType>(operandType))
resultType = valueType.getValueType();
return success();
@@ -310,11 +310,11 @@ LogicalResult AwaitOp::verify() {
Type argType = getOperand().getType();
// Awaiting on a token does not have any results.
- if (argType.isa<TokenType>() && !getResultTypes().empty())
+ if (llvm::isa<TokenType>(argType) && !getResultTypes().empty())
return emitOpError("awaiting on a token must have empty result");
// Awaiting on a value unwraps the async value type.
- if (auto value = argType.dyn_cast<ValueType>()) {
+ if (auto value = llvm::dyn_cast<ValueType>(argType)) {
if (*getResultType() != value.getValueType())
return emitOpError() << "result type " << *getResultType()
<< " does not match async value type "
@@ -375,12 +375,12 @@ LogicalResult FuncOp::verify() {
for (unsigned i = 0, e = resultTypes.size(); i != e; ++i) {
auto type = resultTypes[i];
- if (!type.isa<TokenType>() && !type.isa<ValueType>())
+ if (!llvm::isa<TokenType>(type) && !llvm::isa<ValueType>(type))
return emitOpError() << "result type must be async value type or async "
"token type, but got "
<< type;
// We only allow AsyncToken appear as the first return value
- if (type.isa<TokenType>() && i != 0) {
+ if (llvm::isa<TokenType>(type) && i != 0) {
return emitOpError()
<< " results' (optional) async token type is expected "
"to appear as the 1st return value, but got "
@@ -446,7 +446,7 @@ LogicalResult ReturnOp::verify() {
// Get the underlying value types from async types returned from the
// parent `async.func` operation.
auto types = llvm::map_range(resultTypes, [](const Type &result) {
- return result.cast<ValueType>().getValueType();
+ return llvm::cast<ValueType>(result).getValueType();
});
if (getOperandTypes() != types)
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index d126b4c02b705..cb69a9e5879c0 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -91,9 +91,9 @@ Region *bufferization::getNextEnclosingRepetitiveRegion(
}
Operation *bufferization::getOwnerOfValue(Value value) {
- if (auto opResult = value.dyn_cast<OpResult>())
+ if (auto opResult = llvm::dyn_cast<OpResult>(value))
return opResult.getDefiningOp();
- return value.cast<BlockArgument>().getOwner()->getParentOp();
+ return llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
}
bool bufferization::allocationDoesNotEscape(OpResult opResult) {
@@ -109,7 +109,7 @@ bool bufferization::allocationDoesNotEscape(OpResult opResult) {
return false;
auto attr =
op->getAttrOfType<ArrayAttr>(BufferizationDialect::kEscapeAttrName);
- return !attr[opResult.getResultNumber()].cast<BoolAttr>().getValue();
+ return !llvm::cast<BoolAttr>(attr[opResult.getResultNumber()]).getValue();
}
/// Create an AllocTensorOp for the given shaped value. If `copy` is set, the
@@ -119,31 +119,31 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
OpBuilder &b, Location loc, Value shapedValue, bool escape,
const BufferizationOptions &options, bool copy) {
Value tensor;
- if (shapedValue.getType().isa<RankedTensorType>()) {
+ if (llvm::isa<RankedTensorType>(shapedValue.getType())) {
tensor = shapedValue;
- } else if (shapedValue.getType().isa<MemRefType>()) {
+ } else if (llvm::isa<MemRefType>(shapedValue.getType())) {
tensor = b.create<ToTensorOp>(loc, shapedValue);
- } else if (shapedValue.getType().isa<UnrankedTensorType>() ||
- shapedValue.getType().isa<UnrankedMemRefType>()) {
+ } else if (llvm::isa<UnrankedTensorType>(shapedValue.getType()) ||
+ llvm::isa<UnrankedMemRefType>(shapedValue.getType())) {
return getOwnerOfValue(shapedValue)
->emitError("copying of unranked tensors is not implemented");
} else {
llvm_unreachable("expected RankedTensorType or MemRefType");
}
- RankedTensorType tensorType = tensor.getType().cast<RankedTensorType>();
+ RankedTensorType tensorType = llvm::cast<RankedTensorType>(tensor.getType());
SmallVector<Value> dynamicSizes;
if (!copy) {
// Compute the dynamic part of the shape.
// First try to query the shape via ReifyRankedShapedTypeOpInterface.
bool reifiedShapes = false;
- if (shapedValue.getType().isa<RankedTensorType>() &&
- shapedValue.isa<OpResult>()) {
+ if (llvm::isa<RankedTensorType>(shapedValue.getType()) &&
+ llvm::isa<OpResult>(shapedValue)) {
ReifiedRankedShapedTypeDims resultDims;
if (succeeded(
reifyResultShapes(b, shapedValue.getDefiningOp(), resultDims))) {
reifiedShapes = true;
auto &shape =
- resultDims[shapedValue.cast<OpResult>().getResultNumber()];
+ resultDims[llvm::cast<OpResult>(shapedValue).getResultNumber()];
for (const auto &dim : enumerate(tensorType.getShape()))
if (ShapedType::isDynamic(dim.value()))
dynamicSizes.push_back(shape[dim.index()].get<Value>());
@@ -188,11 +188,11 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
// Find all out-of-place OpOperands.
for (OpOperand &opOperand : op->getOpOperands()) {
Type operandType = opOperand.get().getType();
- if (!operandType.isa<TensorType>())
+ if (!llvm::isa<TensorType>(operandType))
continue;
if (state.isInPlace(opOperand))
continue;
- if (operandType.isa<UnrankedTensorType>())
+ if (llvm::isa<UnrankedTensorType>(operandType))
return op->emitError("copying of unranked tensors is not implemented");
AliasingOpResultList aliasingOpResults =
@@ -209,9 +209,8 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
!state.bufferizesToMemoryWrite(opOperand) &&
state.getAliasingOpOperands(aliasingOpResults.getAliases()[0].opResult)
.getNumAliases() == 1 &&
- !aliasingOpResults.getAliases()[0]
- .opResult.getType()
- .isa<UnrankedTensorType>()) {
+ !llvm::isa<UnrankedTensorType>(
+ aliasingOpResults.getAliases()[0].opResult.getType())) {
// The op itself does not write but may create exactly one alias. Instead
// of copying the OpOperand, copy the OpResult. The OpResult can sometimes
// be smaller than the OpOperand (e.g., in the case of an extract_slice,
@@ -281,9 +280,9 @@ bool bufferization::shouldDeallocateOpResult(
AnalysisState analysisState(options);
if (op->hasAttr(BufferizationDialect::kEscapeAttrName)) {
// AllocTensorOp has one result.
- ArrayAttr escapeAttr =
- op->getAttr(BufferizationDialect::kEscapeAttrName).cast<ArrayAttr>();
- return !escapeAttr[0].cast<BoolAttr>().getValue();
+ ArrayAttr escapeAttr = llvm::cast<ArrayAttr>(
+ op->getAttr(BufferizationDialect::kEscapeAttrName));
+ return !llvm::cast<BoolAttr>(escapeAttr[0]).getValue();
}
// No "escape" annotation found.
@@ -335,8 +334,8 @@ defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace,
BaseMemRefType
defaultUnknownTypeConverter(Value value, Attribute memorySpace,
const BufferizationOptions &options) {
- return getMemRefTypeWithFullyDynamicLayout(value.getType().cast<TensorType>(),
- memorySpace);
+ return getMemRefTypeWithFullyDynamicLayout(
+ llvm::cast<TensorType>(value.getType()), memorySpace);
}
} // namespace
@@ -394,7 +393,7 @@ void BufferizationOptions::setFunctionBoundaryTypeConversion(
//===----------------------------------------------------------------------===//
static void setInsertionPointAfter(OpBuilder &b, Value value) {
- if (auto bbArg = value.dyn_cast<BlockArgument>()) {
+ if (auto bbArg = llvm::dyn_cast<BlockArgument>(value)) {
b.setInsertionPointToStart(bbArg.getOwner());
} else {
b.setInsertionPointAfter(value.getDefiningOp());
@@ -463,7 +462,7 @@ bool AnalysisState::bufferizesToAliasOnly(OpOperand &opOperand) const {
}
bool AnalysisState::bufferizesToMemoryWrite(Value value) const {
- auto opResult = value.dyn_cast<OpResult>();
+ auto opResult = llvm::dyn_cast<OpResult>(value);
if (!opResult)
return true;
auto bufferizableOp = getOptions().dynCastBufferizableOp(value);
@@ -476,7 +475,7 @@ bool AnalysisState::bufferizesToMemoryWrite(Value value) const {
/// read. Also takes into account ops that create an alias but do not read by
/// themselves (e.g., ExtractSliceOp).
bool AnalysisState::isValueRead(Value value) const {
- assert(value.getType().isa<TensorType>() && "expected TensorType");
+ assert(llvm::isa<TensorType>(value.getType()) && "expected TensorType");
SmallVector<OpOperand *> workingSet;
for (OpOperand &use : value.getUses())
workingSet.push_back(&use);
@@ -512,13 +511,13 @@ llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
continue;
}
- if (value.isa<BlockArgument>()) {
+ if (llvm::isa<BlockArgument>(value)) {
if (alwaysIncludeLeaves)
result.insert(value);
continue;
}
- OpResult opResult = value.cast<OpResult>();
+ OpResult opResult = llvm::cast<OpResult>(value);
BufferizableOpInterface bufferizableOp =
options.dynCastBufferizableOp(opResult.getDefiningOp());
AliasingOpOperandList aliases = getAliasingOpOperands(opResult);
@@ -658,8 +657,8 @@ bool AnalysisState::isTensorYielded(Value tensor) const {
// bufferization.to_memref is not allowed to change the rank.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
- auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
- assert((!rankedTensorType || memrefType.cast<MemRefType>().getRank() ==
+ auto rankedTensorType = llvm::dyn_cast<RankedTensorType>(tensor.getType());
+ assert((!rankedTensorType || llvm::cast<MemRefType>(memrefType).getRank() ==
rankedTensorType.getRank()) &&
"to_memref would be invalid: mismatching ranks");
#endif
@@ -668,7 +667,7 @@ static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
const BufferizationOptions &options) {
#ifndef NDEBUG
- auto tensorType = value.getType().dyn_cast<TensorType>();
+ auto tensorType = llvm::dyn_cast<TensorType>(value.getType());
assert(tensorType && "unexpected non-tensor type");
#endif // NDEBUG
@@ -699,7 +698,8 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options) {
FailureOr<BaseMemRefType> bufferization::getBufferType(
Value value, const BufferizationOptions &options,
const DenseMap<Value, BaseMemRefType> &fixedTypes) {
- assert(value.getType().isa<TensorType>() && "unexpected non-tensor type");
+ assert(llvm::isa<TensorType>(value.getType()) &&
+ "unexpected non-tensor type");
// If the `value` is in `fixedTypes`, return the mapped type.
const auto &it = fixedTypes.find(value);
@@ -731,11 +731,11 @@ void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
SmallVector<Value> replacements;
for (OpResult opResult : op->getOpResults()) {
Value replacement = values[opResult.getResultNumber()];
- if (opResult.getType().isa<TensorType>()) {
+ if (llvm::isa<TensorType>(opResult.getType())) {
// The OpResult is a tensor. Such values are replaced with memrefs during
// bufferization.
- assert((replacement.getType().isa<MemRefType>() ||
- replacement.getType().isa<UnrankedMemRefType>()) &&
+ assert((llvm::isa<MemRefType>(replacement.getType()) ||
+ llvm::isa<UnrankedMemRefType>(replacement.getType())) &&
"tensor op result should be replaced with a memref value");
// The existing uses of the OpResult still expect a tensor. Insert a
// ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
@@ -797,7 +797,7 @@ LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
//===----------------------------------------------------------------------===//
bool bufferization::isFunctionArgument(Value value) {
- auto bbArg = value.dyn_cast<BlockArgument>();
+ auto bbArg = llvm::dyn_cast<BlockArgument>(value);
if (!bbArg)
return false;
return isa<func::FuncOp>(bbArg.getOwner()->getParentOp());
@@ -807,17 +807,18 @@ BaseMemRefType bufferization::getMemRefType(Value value,
const BufferizationOptions &options,
MemRefLayoutAttrInterface layout,
Attribute memorySpace) {
- auto tensorType = value.getType().cast<TensorType>();
+ auto tensorType = llvm::cast<TensorType>(value.getType());
// Case 1: Unranked memref type.
- if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
+ if (auto unrankedTensorType =
+ llvm::dyn_cast<UnrankedTensorType>(tensorType)) {
assert(!layout && "UnrankedTensorType cannot have a layout map");
return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
memorySpace);
}
// Case 2: Ranked memref type with specified layout.
- auto rankedTensorType = tensorType.cast<RankedTensorType>();
+ auto rankedTensorType = llvm::cast<RankedTensorType>(tensorType);
if (layout) {
return MemRefType::get(rankedTensorType.getShape(),
rankedTensorType.getElementType(), layout,
@@ -831,13 +832,14 @@ BaseMemRefType
bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
Attribute memorySpace) {
// Case 1: Unranked memref type.
- if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
+ if (auto unrankedTensorType =
+ llvm::dyn_cast<UnrankedTensorType>(tensorType)) {
return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
memorySpace);
}
// Case 2: Ranked memref type.
- auto rankedTensorType = tensorType.cast<RankedTensorType>();
+ auto rankedTensorType = llvm::cast<RankedTensorType>(tensorType);
int64_t dynamicOffset = ShapedType::kDynamic;
SmallVector<int64_t> dynamicStrides(rankedTensorType.getRank(),
ShapedType::kDynamic);
@@ -854,13 +856,14 @@ BaseMemRefType
bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
Attribute memorySpace) {
// Case 1: Unranked memref type.
- if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
+ if (auto unrankedTensorType =
+ llvm::dyn_cast<UnrankedTensorType>(tensorType)) {
return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
memorySpace);
}
// Case 2: Ranked memref type.
- auto rankedTensorType = tensorType.cast<RankedTensorType>();
+ auto rankedTensorType = llvm::cast<RankedTensorType>(tensorType);
MemRefLayoutAttrInterface layout = {};
return MemRefType::get(rankedTensorType.getShape(),
rankedTensorType.getElementType(), layout,
@@ -943,7 +946,7 @@ AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands(
Operation *op = opResult.getDefiningOp();
SmallVector<AliasingOpOperand> result;
for (OpOperand &opOperand : op->getOpOperands()) {
- if (!opOperand.get().getType().isa<TensorType>())
+ if (!llvm::isa<TensorType>(opOperand.get().getType()))
continue;
AliasingOpResultList aliasingOpResults =
state.getAliasingOpResults(opOperand);
@@ -957,15 +960,15 @@ AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands(
FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
Value value, const BufferizationOptions &options,
const DenseMap<Value, BaseMemRefType> &fixedTypes) {
- assert(value.getType().isa<TensorType>() && "expected tensor type");
+ assert(llvm::isa<TensorType>(value.getType()) && "expected tensor type");
// No further analysis is possible for a block argument.
- if (value.isa<BlockArgument>())
+ if (llvm::isa<BlockArgument>(value))
return bufferization::getMemRefType(value, options);
// Value is an OpResult.
Operation *op = getOwnerOfValue(value);
- auto opResult = value.cast<OpResult>();
+ auto opResult = llvm::cast<OpResult>(value);
AnalysisState state(options);
AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult);
if (aliases.getNumAliases() > 0 &&
@@ -1000,7 +1003,7 @@ bufferization::detail::unknownGetAliasingOpOperands(OpResult opResult) {
// Conservatively assume that everything may be aliasing.
AliasingOpOperandList r;
for (OpOperand &operand : opResult.getDefiningOp()->getOpOperands())
- if (operand.get().getType().isa<TensorType>())
+ if (llvm::isa<TensorType>(operand.get().getType()))
r.addAlias({&operand, BufferRelation::Unknown, /*isDefinite=*/false});
return r;
}
@@ -1010,7 +1013,7 @@ bufferization::detail::unknownGetAliasingOpResults(OpOperand &opOperand) {
// Conservatively assume that everything may be aliasing.
AliasingOpResultList r;
for (OpResult result : opOperand.getOwner()->getOpResults())
- if (result.getType().isa<TensorType>())
+ if (llvm::isa<TensorType>(result.getType()))
r.addAlias({result, BufferRelation::Unknown, /*isDefinite=*/false});
return r;
}
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
index 0052829e2dd43..0dfd9a0d1a0bc 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
@@ -62,7 +62,7 @@ LogicalResult BufferizationDialect::verifyRegionArgAttribute(
Operation *op, unsigned /*regionIndex*/, unsigned argIndex,
NamedAttribute attr) {
if (attr.getName() == kWritableAttrName) {
- if (!attr.getValue().isa<BoolAttr>()) {
+ if (!llvm::isa<BoolAttr>(attr.getValue())) {
return op->emitError() << "'" << kWritableAttrName
<< "' is expected to be a boolean attribute";
}
@@ -75,11 +75,11 @@ LogicalResult BufferizationDialect::verifyRegionArgAttribute(
return success();
}
if (attr.getName() == kBufferAccessAttrName) {
- if (!attr.getValue().isa<StringAttr>()) {
+ if (!llvm::isa<StringAttr>(attr.getValue())) {
return op->emitError() << "'" << kBufferAccessAttrName
<< "' is expected to be a string attribute";
}
- StringRef str = attr.getValue().cast<StringAttr>().getValue();
+ StringRef str = llvm::cast<StringAttr>(attr.getValue()).getValue();
if (str != "none" && str != "read" && str != "write" && str != "read-write")
return op->emitError()
<< "invalid value for '" << kBufferAccessAttrName << "'";
@@ -89,7 +89,7 @@ LogicalResult BufferizationDialect::verifyRegionArgAttribute(
return success();
}
if (attr.getName() == kBufferLayoutAttrName) {
- if (!attr.getValue().isa<AffineMapAttr>()) {
+ if (!llvm::isa<AffineMapAttr>(attr.getValue())) {
return op->emitError() << "'" << kBufferLayoutAttrName
<< "' is expected to be a affine map attribute";
}
@@ -109,7 +109,7 @@ BufferizationDialect::verifyOperationAttribute(Operation *op,
using bufferization::BufferizableOpInterface;
if (attr.getName() == kEscapeAttrName) {
- auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>();
+ auto arrayAttr = llvm::dyn_cast<ArrayAttr>(attr.getValue());
if (!arrayAttr)
return op->emitError() << "'" << kEscapeAttrName
<< "' is expected to be a bool array attribute";
@@ -124,13 +124,13 @@ BufferizationDialect::verifyOperationAttribute(Operation *op,
<< "'" << kEscapeAttrName << "' only valid on bufferizable ops";
for (const auto &it : llvm::enumerate(arrayAttr)) {
auto attr = it.value();
- auto boolAttr = attr.dyn_cast<BoolAttr>();
+ auto boolAttr = llvm::dyn_cast<BoolAttr>(attr);
if (!boolAttr)
return op->emitError() << "'" << kEscapeAttrName
<< "' is expected to be a bool array attribute";
if (!boolAttr.getValue())
continue;
- if (!op->getResult(it.index()).getType().isa<TensorType>())
+ if (!llvm::isa<TensorType>(op->getResult(it.index()).getType()))
return op->emitError()
<< "'" << kEscapeAttrName << "' only valid for tensor results";
if (!bufferizableOp.bufferizesToAllocation(op->getOpResult(it.index())))
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index a25ca3806c11e..eeba571658206 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -27,7 +27,7 @@ using namespace mlir::bufferization;
FailureOr<Value>
mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
MemRefType destType) {
- auto srcType = value.getType().cast<MemRefType>();
+ auto srcType = llvm::cast<MemRefType>(value.getType());
// Element type, rank and memory space must match.
if (srcType.getElementType() != destType.getElementType())
@@ -100,9 +100,9 @@ mlir::bufferization::foldToMemrefToTensorPair(RewriterBase &rewriter,
return success();
}
- auto rankedSrcType = srcType.dyn_cast<MemRefType>();
- auto rankedDestType = destType.dyn_cast<MemRefType>();
- auto unrankedSrcType = srcType.dyn_cast<UnrankedMemRefType>();
+ auto rankedSrcType = llvm::dyn_cast<MemRefType>(srcType);
+ auto rankedDestType = llvm::dyn_cast<MemRefType>(destType);
+ auto unrankedSrcType = llvm::dyn_cast<UnrankedMemRefType>(srcType);
// Ranked memref -> Ranked memref cast.
if (rankedSrcType && rankedDestType) {
@@ -132,13 +132,13 @@ mlir::bufferization::foldToMemrefToTensorPair(RewriterBase &rewriter,
void mlir::bufferization::populateDynamicDimSizes(
OpBuilder &b, Location loc, Value shapedValue,
SmallVector<Value> &dynamicDims) {
- auto shapedType = shapedValue.getType().cast<ShapedType>();
+ auto shapedType = llvm::cast<ShapedType>(shapedValue.getType());
for (int64_t i = 0; i < shapedType.getRank(); ++i) {
if (shapedType.isDynamicDim(i)) {
- if (shapedType.isa<MemRefType>()) {
+ if (llvm::isa<MemRefType>(shapedType)) {
dynamicDims.push_back(b.create<memref::DimOp>(loc, shapedValue, i));
} else {
- assert(shapedType.isa<RankedTensorType>() && "expected tensor");
+ assert(llvm::isa<RankedTensorType>(shapedType) && "expected tensor");
dynamicDims.push_back(b.create<tensor::DimOp>(loc, shapedValue, i));
}
}
@@ -191,7 +191,7 @@ LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
// Should the buffer be deallocated?
bool dealloc =
- shouldDeallocateOpResult(getResult().cast<OpResult>(), options);
+ shouldDeallocateOpResult(llvm::cast<OpResult>(getResult()), options);
// Replace op.
replaceOpWithBufferizedValues(rewriter, getOperation(), *alloc);
@@ -431,7 +431,7 @@ void AllocTensorOp::print(OpAsmPrinter &p) {
AllocTensorOp::getOperandSegmentSizeAttr()});
p << " : ";
auto type = getResult().getType();
- if (auto validType = type.dyn_cast<::mlir::TensorType>())
+ if (auto validType = llvm::dyn_cast<::mlir::TensorType>(type))
p.printStrippedAttrOrType(validType);
else
p << type;
@@ -620,8 +620,8 @@ struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> {
toMemref.getOperand().getDefiningOp<tensor::CastOp>();
if (!tensorCastOperand)
return failure();
- auto srcTensorType =
- tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
+ auto srcTensorType = llvm::dyn_cast<RankedTensorType>(
+ tensorCastOperand.getOperand().getType());
if (!srcTensorType)
return failure();
auto memrefType = MemRefType::get(srcTensorType.getShape(),
diff --git a/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp b/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp
index 0a2691a113f71..da57d254676eb 100644
--- a/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp
@@ -34,7 +34,7 @@ Operation *complex::ComplexDialect::materializeConstant(OpBuilder &builder,
Location loc) {
if (complex::ConstantOp::isBuildableWith(value, type)) {
return builder.create<complex::ConstantOp>(loc, type,
- value.cast<ArrayAttr>());
+ llvm::cast<ArrayAttr>(value));
}
return arith::ConstantOp::materialize(builder, value, type, loc);
}
@@ -46,16 +46,16 @@ LogicalResult complex::NumberAttr::verify(
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
::llvm::APFloat real, ::llvm::APFloat imag, ::mlir::Type type) {
- if (!type.isa<ComplexType>())
+ if (!llvm::isa<ComplexType>(type))
return emitError() << "complex attribute must be a complex type.";
- Type elementType = type.cast<ComplexType>().getElementType();
- if (!elementType.isa<FloatType>())
+ Type elementType = llvm::cast<ComplexType>(type).getElementType();
+ if (!llvm::isa<FloatType>(elementType))
return emitError()
<< "element type of the complex attribute must be float like type.";
const auto &typeFloatSemantics =
- elementType.cast<FloatType>().getFloatSemantics();
+ llvm::cast<FloatType>(elementType).getFloatSemantics();
if (&real.getSemantics() != &typeFloatSemantics)
return emitError()
<< "type doesn't match the type implied by its `real` value";
@@ -67,7 +67,7 @@ LogicalResult complex::NumberAttr::verify(
}
void complex::NumberAttr::print(AsmPrinter &printer) const {
- printer << "<:" << getType().cast<ComplexType>().getElementType() << " "
+ printer << "<:" << llvm::cast<ComplexType>(getType()).getElementType() << " "
<< getReal() << ", " << getImag() << ">";
}
diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
index 28e121b6026a5..02c0e1643f3c6 100644
--- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
@@ -27,18 +27,18 @@ void ConstantOp::getAsmResultNames(
}
bool ConstantOp::isBuildableWith(Attribute value, Type type) {
- if (auto arrAttr = value.dyn_cast<ArrayAttr>()) {
- auto complexTy = type.dyn_cast<ComplexType>();
+ if (auto arrAttr = llvm::dyn_cast<ArrayAttr>(value)) {
+ auto complexTy = llvm::dyn_cast<ComplexType>(type);
if (!complexTy || arrAttr.size() != 2)
return false;
auto complexEltTy = complexTy.getElementType();
- if (auto fre = arrAttr[0].dyn_cast<FloatAttr>()) {
- auto im = arrAttr[1].dyn_cast<FloatAttr>();
+ if (auto fre = llvm::dyn_cast<FloatAttr>(arrAttr[0])) {
+ auto im = llvm::dyn_cast<FloatAttr>(arrAttr[1]);
return im && fre.getType() == complexEltTy &&
im.getType() == complexEltTy;
}
- if (auto ire = arrAttr[0].dyn_cast<IntegerAttr>()) {
- auto im = arrAttr[1].dyn_cast<IntegerAttr>();
+ if (auto ire = llvm::dyn_cast<IntegerAttr>(arrAttr[0])) {
+ auto im = llvm::dyn_cast<IntegerAttr>(arrAttr[1]);
return im && ire.getType() == complexEltTy &&
im.getType() == complexEltTy;
}
@@ -55,8 +55,8 @@ LogicalResult ConstantOp::verify() {
}
auto complexEltTy = getType().getElementType();
- auto re = arrayAttr[0].dyn_cast<FloatAttr>();
- auto im = arrayAttr[1].dyn_cast<FloatAttr>();
+ auto re = llvm::dyn_cast<FloatAttr>(arrayAttr[0]);
+ auto im = llvm::dyn_cast<FloatAttr>(arrayAttr[1]);
if (!re || !im)
return emitOpError("requires attribute's elements to be float attributes");
if (complexEltTy != re.getType() || complexEltTy != im.getType()) {
@@ -129,8 +129,8 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
// complex.add(a, complex.constant<0.0, 0.0>) -> a
if (auto constantOp = getRhs().getDefiningOp<ConstantOp>()) {
auto arrayAttr = constantOp.getValue();
- if (arrayAttr[0].cast<FloatAttr>().getValue().isZero() &&
- arrayAttr[1].cast<FloatAttr>().getValue().isZero()) {
+ if (llvm::cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
+ llvm::cast<FloatAttr>(arrayAttr[1]).getValue().isZero()) {
return getLhs();
}
}
@@ -151,8 +151,8 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
// complex.sub(a, complex.constant<0.0, 0.0>) -> a
if (auto constantOp = getRhs().getDefiningOp<ConstantOp>()) {
auto arrayAttr = constantOp.getValue();
- if (arrayAttr[0].cast<FloatAttr>().getValue().isZero() &&
- arrayAttr[1].cast<FloatAttr>().getValue().isZero()) {
+ if (llvm::cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
+ llvm::cast<FloatAttr>(arrayAttr[1]).getValue().isZero()) {
return getLhs();
}
}
diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
index 6bdf42e527887..0a86d8f15b0d6 100644
--- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
@@ -125,7 +125,7 @@ static LogicalResult collapseBranch(Block *&successor,
// Otherwise, we need to remap any argument operands.
for (Value operand : operands) {
- BlockArgument argOperand = operand.dyn_cast<BlockArgument>();
+ BlockArgument argOperand = llvm::dyn_cast<BlockArgument>(operand);
if (argOperand && argOperand.getOwner() == successor)
argStorage.push_back(successorOperands[argOperand.getArgNumber()]);
else
@@ -442,7 +442,8 @@ SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) {
}
Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
- if (IntegerAttr condAttr = operands.front().dyn_cast_or_null<IntegerAttr>())
+ if (IntegerAttr condAttr =
+ llvm::dyn_cast_or_null<IntegerAttr>(operands.front()))
return condAttr.getValue().isOne() ? getTrueDest() : getFalseDest();
return nullptr;
}
@@ -601,7 +602,7 @@ Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
return getDefaultDestination();
SuccessorRange caseDests = getCaseDestinations();
- if (auto value = operands.front().dyn_cast_or_null<IntegerAttr>()) {
+ if (auto value = llvm::dyn_cast_or_null<IntegerAttr>(operands.front())) {
for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>()))
if (it.value() == value.getValue())
return caseDests[it.index()];
diff --git a/mlir/lib/Dialect/DLTI/DLTI.cpp b/mlir/lib/Dialect/DLTI/DLTI.cpp
index 889778a1177cb..3970c9c659a22 100644
--- a/mlir/lib/Dialect/DLTI/DLTI.cpp
+++ b/mlir/lib/Dialect/DLTI/DLTI.cpp
@@ -215,7 +215,7 @@ combineOneSpec(DataLayoutSpecInterface spec,
typeSample.getContext()->getLoadedDialect<BuiltinDialect>() &&
"unexpected data layout entry for built-in type");
- auto interface = typeSample.cast<DataLayoutTypeInterface>();
+ auto interface = llvm::cast<DataLayoutTypeInterface>(typeSample);
if (!interface.areCompatible(entriesForType.lookup(kvp.first), kvp.second))
return failure();
@@ -250,7 +250,7 @@ DataLayoutSpecAttr::combineWith(ArrayRef<DataLayoutSpecInterface> specs) const {
// Only combine with attributes of the same kind.
// TODO: reconsider this when the need arises.
if (llvm::any_of(specs, [](DataLayoutSpecInterface spec) {
- return !spec.isa<DataLayoutSpecAttr>();
+ return !llvm::isa<DataLayoutSpecAttr>(spec);
}))
return {};
@@ -334,7 +334,7 @@ class TargetDataLayoutInterface : public DataLayoutDialectInterface {
Location loc) const final {
StringRef entryName = entry.getKey().get<StringAttr>().strref();
if (entryName == DLTIDialect::kDataLayoutEndiannessKey) {
- auto value = entry.getValue().dyn_cast<StringAttr>();
+ auto value = llvm::dyn_cast<StringAttr>(entry.getValue());
if (value &&
(value.getValue() == DLTIDialect::kDataLayoutEndiannessBig ||
value.getValue() == DLTIDialect::kDataLayoutEndiannessLittle))
@@ -383,7 +383,7 @@ void DLTIDialect::printAttribute(Attribute attr, DialectAsmPrinter &os) const {
LogicalResult DLTIDialect::verifyOperationAttribute(Operation *op,
NamedAttribute attr) {
if (attr.getName() == DLTIDialect::kDataLayoutAttrName) {
- if (!attr.getValue().isa<DataLayoutSpecAttr>()) {
+ if (!llvm::isa<DataLayoutSpecAttr>(attr.getValue())) {
return op->emitError() << "'" << DLTIDialect::kDataLayoutAttrName
<< "' is expected to be a #dlti.dl_spec attribute";
}
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 71904762d0b13..cbf615a3972b7 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -73,10 +73,10 @@ LogicalResult ApplyOp::verify() {
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
Type input = inputs.front(), output = outputs.front();
- return ((input.isa<IntegerType, FloatType, IndexType, emitc::OpaqueType,
- emitc::PointerType>()) &&
- (output.isa<IntegerType, FloatType, IndexType, emitc::OpaqueType,
- emitc::PointerType>()));
+ return ((llvm::isa<IntegerType, FloatType, IndexType, emitc::OpaqueType,
+ emitc::PointerType>(input)) &&
+ (llvm::isa<IntegerType, FloatType, IndexType, emitc::OpaqueType,
+ emitc::PointerType>(output)));
}
//===----------------------------------------------------------------------===//
@@ -90,8 +90,8 @@ LogicalResult emitc::CallOp::verify() {
if (std::optional<ArrayAttr> argsAttr = getArgs()) {
for (Attribute arg : *argsAttr) {
- auto intAttr = arg.dyn_cast<IntegerAttr>();
- if (intAttr && intAttr.getType().isa<IndexType>()) {
+ auto intAttr = llvm::dyn_cast<IntegerAttr>(arg);
+ if (intAttr && llvm::isa<IndexType>(intAttr.getType())) {
int64_t index = intAttr.getInt();
// Args with elements of type index must be in range
// [0..operands.size).
@@ -99,7 +99,8 @@ LogicalResult emitc::CallOp::verify() {
return emitOpError("index argument is out of range");
// Args with elements of type ArrayAttr must have a type.
- } else if (arg.isa<ArrayAttr>() /*&& arg.getType().isa<NoneType>()*/) {
+ } else if (llvm::isa<ArrayAttr>(
+ arg) /*&& arg.getType().isa<NoneType>()*/) {
// FIXME: Array attributes never have types
return emitOpError("array argument has no type");
}
@@ -108,7 +109,7 @@ LogicalResult emitc::CallOp::verify() {
if (std::optional<ArrayAttr> templateArgsAttr = getTemplateArgs()) {
for (Attribute tArg : *templateArgsAttr) {
- if (!tArg.isa<TypeAttr, IntegerAttr, FloatAttr, emitc::OpaqueAttr>())
+ if (!llvm::isa<TypeAttr, IntegerAttr, FloatAttr, emitc::OpaqueAttr>(tArg))
return emitOpError("template argument has invalid type");
}
}
@@ -122,17 +123,17 @@ LogicalResult emitc::CallOp::verify() {
/// The constant op requires that the attribute's type matches the return type.
LogicalResult emitc::ConstantOp::verify() {
- if (getValueAttr().isa<emitc::OpaqueAttr>())
+ if (llvm::isa<emitc::OpaqueAttr>(getValueAttr()))
return success();
// Value must not be empty
- StringAttr strAttr = getValueAttr().dyn_cast<StringAttr>();
+ StringAttr strAttr = llvm::dyn_cast<StringAttr>(getValueAttr());
if (strAttr && strAttr.getValue().empty())
return emitOpError() << "value must not be empty";
auto value = cast<TypedAttr>(getValueAttr());
Type type = getType();
- if (!value.getType().isa<NoneType>() && type != value.getType())
+ if (!llvm::isa<NoneType>(value.getType()) && type != value.getType())
return emitOpError() << "requires attribute's type (" << value.getType()
<< ") to match op's return type (" << type << ")";
return success();
@@ -183,12 +184,12 @@ ParseResult IncludeOp::parse(OpAsmParser &parser, OperationState &result) {
/// The variable op requires that the attribute's type matches the return type.
LogicalResult emitc::VariableOp::verify() {
- if (getValueAttr().isa<emitc::OpaqueAttr>())
+ if (llvm::isa<emitc::OpaqueAttr>(getValueAttr()))
return success();
auto value = cast<TypedAttr>(getValueAttr());
Type type = getType();
- if (!value.getType().isa<NoneType>() && type != value.getType())
+ if (!llvm::isa<NoneType>(value.getType()) && type != value.getType())
return emitOpError() << "requires attribute's type (" << value.getType()
<< ") to match op's return type (" << type << ")";
return success();
diff --git a/mlir/lib/Dialect/Func/IR/FuncOps.cpp b/mlir/lib/Dialect/Func/IR/FuncOps.cpp
index 7205a7eac18e2..4fa2608785f98 100644
--- a/mlir/lib/Dialect/Func/IR/FuncOps.cpp
+++ b/mlir/lib/Dialect/Func/IR/FuncOps.cpp
@@ -112,7 +112,7 @@ Operation *FuncDialect::materializeConstant(OpBuilder &builder, Attribute value,
Type type, Location loc) {
if (ConstantOp::isBuildableWith(value, type))
return builder.create<ConstantOp>(loc, type,
- value.cast<FlatSymbolRefAttr>());
+ llvm::cast<FlatSymbolRefAttr>(value));
return nullptr;
}
@@ -209,7 +209,7 @@ void ConstantOp::getAsmResultNames(
}
bool ConstantOp::isBuildableWith(Attribute value, Type type) {
- return value.isa<FlatSymbolRefAttr>() && type.isa<FunctionType>();
+ return llvm::isa<FlatSymbolRefAttr>(value) && llvm::isa<FunctionType>(type);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 9472a6798a094..e5e0327ebf1cf 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -220,7 +220,7 @@ void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
NamedAttribute attr) {
- if (!attr.getValue().isa<UnitAttr>() ||
+ if (!llvm::isa<UnitAttr>(attr.getValue()) ||
attr.getName() != getContainerModuleAttrName())
return success();
@@ -368,14 +368,14 @@ static LogicalResult verifyAttributions(Operation *op,
ArrayRef<BlockArgument> attributions,
gpu::AddressSpace memorySpace) {
for (Value v : attributions) {
- auto type = v.getType().dyn_cast<MemRefType>();
+ auto type = llvm::dyn_cast<MemRefType>(v.getType());
if (!type)
return op->emitOpError() << "expected memref type in attribution";
// We can only verify the address space if it hasn't already been lowered
// from the AddressSpaceAttr to a target-specific numeric value.
auto addressSpace =
- type.getMemorySpace().dyn_cast_or_null<gpu::AddressSpaceAttr>();
+ llvm::dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
if (!addressSpace)
continue;
if (addressSpace.getValue() != memorySpace)
@@ -395,7 +395,7 @@ static bool verifyReduceOpAndType(gpu::AllReduceOperation opName,
return (opName != gpu::AllReduceOperation::AND &&
opName != gpu::AllReduceOperation::OR &&
opName != gpu::AllReduceOperation::XOR) ||
- resType.isa<IntegerType>();
+ llvm::isa<IntegerType>(resType);
}
LogicalResult gpu::AllReduceOp::verifyRegions() {
@@ -1186,7 +1186,7 @@ static void printAttributions(OpAsmPrinter &p, StringRef keyword,
size_t attributionIndex = pair.index();
DictionaryAttr attrs;
if (attributes && attributionIndex < attributes.size())
- attrs = attributes[attributionIndex].cast<DictionaryAttr>();
+ attrs = llvm::cast<DictionaryAttr>(attributes[attributionIndex]);
if (attrs)
p.printOptionalAttrDict(attrs.getValue());
});
@@ -1221,10 +1221,10 @@ void GPUFuncOp::print(OpAsmPrinter &p) {
static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index,
StringAttr attrName) {
- auto allAttrs = op->getAttr(attrName).dyn_cast_or_null<ArrayAttr>();
+ auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
if (!allAttrs || index >= allAttrs.size())
return DictionaryAttr();
- return allAttrs[index].cast<DictionaryAttr>();
+ return llvm::cast<DictionaryAttr>(allAttrs[index]);
}
DictionaryAttr GPUFuncOp::getworkgroupAttributionAttrs(unsigned index) {
@@ -1238,7 +1238,7 @@ DictionaryAttr GPUFuncOp::getPrivateAttributionAttrs(unsigned index) {
static void setAttributionAttrs(GPUFuncOp op, unsigned index,
DictionaryAttr value, StringAttr attrName) {
MLIRContext *ctx = op.getContext();
- auto allAttrs = op->getAttr(attrName).dyn_cast_or_null<ArrayAttr>();
+ auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
SmallVector<Attribute> elements;
if (allAttrs)
elements.append(allAttrs.begin(), allAttrs.end());
@@ -1379,7 +1379,7 @@ static LogicalResult verifyKnownLaunchSizeAttr(gpu::GPUFuncOp op,
auto maybeAttr = op->getAttr(attrName);
if (!maybeAttr)
return success();
- auto array = maybeAttr.dyn_cast<DenseI32ArrayAttr>();
+ auto array = llvm::dyn_cast<DenseI32ArrayAttr>(maybeAttr);
if (!array)
return op.emitOpError(attrName + " must be a dense i32 array");
if (array.size() != 3)
@@ -1536,9 +1536,9 @@ static bool isLastMemrefDimUnitStride(MemRefType type) {
LogicalResult SubgroupMmaLoadMatrixOp::verify() {
auto srcType = getSrcMemref().getType();
auto resType = getRes().getType();
- auto resMatrixType = resType.cast<gpu::MMAMatrixType>();
+ auto resMatrixType = llvm::cast<gpu::MMAMatrixType>(resType);
auto operand = resMatrixType.getOperand();
- auto srcMemrefType = srcType.cast<MemRefType>();
+ auto srcMemrefType = llvm::cast<MemRefType>(srcType);
if (!isLastMemrefDimUnitStride(srcMemrefType))
return emitError(
@@ -1558,8 +1558,8 @@ LogicalResult SubgroupMmaLoadMatrixOp::verify() {
LogicalResult SubgroupMmaStoreMatrixOp::verify() {
auto srcType = getSrc().getType();
auto dstType = getDstMemref().getType();
- auto srcMatrixType = srcType.cast<gpu::MMAMatrixType>();
- auto dstMemrefType = dstType.cast<MemRefType>();
+ auto srcMatrixType = llvm::cast<gpu::MMAMatrixType>(srcType);
+ auto dstMemrefType = llvm::cast<MemRefType>(dstType);
if (!isLastMemrefDimUnitStride(dstMemrefType))
return emitError(
@@ -1579,9 +1579,9 @@ LogicalResult SubgroupMmaStoreMatrixOp::verify() {
LogicalResult SubgroupMmaComputeOp::verify() {
enum OperandMap { A, B, C };
SmallVector<MMAMatrixType, 3> opTypes;
- opTypes.push_back(getOpA().getType().cast<MMAMatrixType>());
- opTypes.push_back(getOpB().getType().cast<MMAMatrixType>());
- opTypes.push_back(getOpC().getType().cast<MMAMatrixType>());
+ opTypes.push_back(llvm::cast<MMAMatrixType>(getOpA().getType()));
+ opTypes.push_back(llvm::cast<MMAMatrixType>(getOpB().getType()));
+ opTypes.push_back(llvm::cast<MMAMatrixType>(getOpC().getType()));
if (!opTypes[A].getOperand().equals("AOp") ||
!opTypes[B].getOperand().equals("BOp") ||
@@ -1688,7 +1688,7 @@ void WaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
//===----------------------------------------------------------------------===//
LogicalResult AllocOp::verify() {
- auto memRefType = getMemref().getType().cast<MemRefType>();
+ auto memRefType = llvm::cast<MemRefType>(getMemref().getType());
if (static_cast<int64_t>(getDynamicSizes().size()) !=
memRefType.getNumDynamicDims())
@@ -1719,7 +1719,7 @@ struct SimplifyDimOfAllocOp : public OpRewritePattern<memref::DimOp> {
if (!index)
return failure();
- auto memrefType = dimOp.getSource().getType().dyn_cast<MemRefType>();
+ auto memrefType = llvm::dyn_cast<MemRefType>(dimOp.getSource().getType());
if (!memrefType || !memrefType.isDynamicDim(index.value()))
return failure();
diff --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
index 7f7cc50528a79..b6ccb77e53800 100644
--- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp
+++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
@@ -39,7 +39,8 @@ Operation *IndexDialect::materializeConstant(OpBuilder &b, Attribute value,
// Materialize integer attributes as `index`.
if (auto indexValue = dyn_cast<IntegerAttr>(value)) {
- if (!indexValue.getType().isa<IndexType>() || !type.isa<IndexType>())
+ if (!llvm::isa<IndexType>(indexValue.getType()) ||
+ !llvm::isa<IndexType>(type))
return nullptr;
assert(indexValue.getValue().getBitWidth() ==
IndexType::kInternalStorageBitWidth);
@@ -399,7 +400,8 @@ OpFoldResult XOrOp::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===//
bool CastSOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) {
- return lhsTypes.front().isa<IndexType>() != rhsTypes.front().isa<IndexType>();
+ return llvm::isa<IndexType>(lhsTypes.front()) !=
+ llvm::isa<IndexType>(rhsTypes.front());
}
//===----------------------------------------------------------------------===//
@@ -407,7 +409,8 @@ bool CastSOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) {
//===----------------------------------------------------------------------===//
bool CastUOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) {
- return lhsTypes.front().isa<IndexType>() != rhsTypes.front().isa<IndexType>();
+ return llvm::isa<IndexType>(lhsTypes.front()) !=
+ llvm::isa<IndexType>(rhsTypes.front());
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 5380ba0666197..65a20a0c426b2 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -179,7 +179,7 @@ ParseResult FCmpOp::parse(OpAsmParser &parser, OperationState &result) {
//===----------------------------------------------------------------------===//
void AllocaOp::print(OpAsmPrinter &p) {
- Type elemTy = getType().cast<LLVM::LLVMPointerType>().getElementType();
+ Type elemTy = llvm::cast<LLVM::LLVMPointerType>(getType()).getElementType();
if (!elemTy)
elemTy = *getElemType();
@@ -220,7 +220,7 @@ ParseResult AllocaOp::parse(OpAsmParser &parser, OperationState &result) {
std::optional<NamedAttribute> alignmentAttr =
result.attributes.getNamed("alignment");
if (alignmentAttr.has_value()) {
- auto alignmentInt = alignmentAttr->getValue().dyn_cast<IntegerAttr>();
+ auto alignmentInt = llvm::dyn_cast<IntegerAttr>(alignmentAttr->getValue());
if (!alignmentInt)
return parser.emitError(parser.getNameLoc(),
"expected integer alignment");
@@ -229,7 +229,7 @@ ParseResult AllocaOp::parse(OpAsmParser &parser, OperationState &result) {
}
// Extract the result type from the trailing function type.
- auto funcType = type.dyn_cast<FunctionType>();
+ auto funcType = llvm::dyn_cast<FunctionType>(type);
if (!funcType || funcType.getNumInputs() != 1 ||
funcType.getNumResults() != 1)
return parser.emitError(
@@ -240,7 +240,7 @@ ParseResult AllocaOp::parse(OpAsmParser &parser, OperationState &result) {
return failure();
Type resultType = funcType.getResult(0);
- if (auto ptrResultType = resultType.dyn_cast<LLVMPointerType>()) {
+ if (auto ptrResultType = llvm::dyn_cast<LLVMPointerType>(resultType)) {
if (ptrResultType.isOpaque())
result.addAttribute(kElemTypeAttrName, TypeAttr::get(elemType));
}
@@ -266,7 +266,7 @@ static LogicalResult verifyOpaquePtr(Operation *op, LLVMPointerType ptrType,
}
LogicalResult AllocaOp::verify() {
- return verifyOpaquePtr(getOperation(), getType().cast<LLVMPointerType>(),
+ return verifyOpaquePtr(getOperation(), llvm::cast<LLVMPointerType>(getType()),
getElemType());
}
@@ -410,7 +410,7 @@ static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType,
size_t index = 0;
llvm::interleave(
- llvm::zip(caseValues.cast<DenseIntElementsAttr>(), caseDestinations),
+ llvm::zip(llvm::cast<DenseIntElementsAttr>(caseValues), caseDestinations),
[&](auto i) {
p << " ";
p << std::get<0>(i).getLimitedValue();
@@ -457,11 +457,11 @@ GEPIndicesAdaptor<ValueRange> GEPOp::getIndices() {
/// Returns the elemental type of any LLVM-compatible vector type or self.
static Type extractVectorElementType(Type type) {
- if (auto vectorType = type.dyn_cast<VectorType>())
+ if (auto vectorType = llvm::dyn_cast<VectorType>(type))
return vectorType.getElementType();
- if (auto scalableVectorType = type.dyn_cast<LLVMScalableVectorType>())
+ if (auto scalableVectorType = llvm::dyn_cast<LLVMScalableVectorType>(type))
return scalableVectorType.getElementType();
- if (auto fixedVectorType = type.dyn_cast<LLVMFixedVectorType>())
+ if (auto fixedVectorType = llvm::dyn_cast<LLVMFixedVectorType>(type))
return fixedVectorType.getElementType();
return type;
}
@@ -470,7 +470,7 @@ void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
Value basePtr, ArrayRef<GEPArg> indices, bool inbounds,
ArrayRef<NamedAttribute> attributes) {
auto ptrType =
- extractVectorElementType(basePtr.getType()).cast<LLVMPointerType>();
+ llvm::cast<LLVMPointerType>(extractVectorElementType(basePtr.getType()));
assert(!ptrType.isOpaque() &&
"expected non-opaque pointer, provide elementType explicitly when "
"opaque pointers are used");
@@ -543,8 +543,7 @@ void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
result.addAttribute(getInboundsAttrName(result.name),
builder.getUnitAttr());
}
- if (extractVectorElementType(basePtr.getType())
- .cast<LLVMPointerType>()
+ if (llvm::cast<LLVMPointerType>(extractVectorElementType(basePtr.getType()))
.isOpaque())
result.addAttribute(kElemTypeAttrName, TypeAttr::get(elementType));
result.addOperands(basePtr);
@@ -695,7 +694,7 @@ static llvm::Error verifyStructIndices(Type baseGEPType,
LogicalResult LLVM::GEPOp::verify() {
if (failed(verifyOpaquePtr(
getOperation(),
- extractVectorElementType(getType()).cast<LLVMPointerType>(),
+ llvm::cast<LLVMPointerType>(extractVectorElementType(getType())),
getElemType())))
return failure();
@@ -716,8 +715,8 @@ Type LLVM::GEPOp::getSourceElementType() {
if (std::optional<Type> elemType = getElemType())
return *elemType;
- return extractVectorElementType(getBase().getType())
- .cast<LLVMPointerType>()
+ return llvm::cast<LLVMPointerType>(
+ extractVectorElementType(getBase().getType()))
.getElementType();
}
@@ -729,16 +728,16 @@ Type LLVM::GEPOp::getSourceElementType() {
/// integer and float types with limited bit width are supported. Additionally,
/// depending on the operation pointers may be supported as well.
static bool isTypeCompatibleWithAtomicOp(Type type, bool isPointerTypeAllowed) {
- if (type.isa<LLVMPointerType>())
+ if (llvm::isa<LLVMPointerType>(type))
return isPointerTypeAllowed;
std::optional<unsigned> bitWidth;
- if (auto floatType = type.dyn_cast<FloatType>()) {
+ if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
if (!isCompatibleFloatingPointType(type))
return false;
bitWidth = floatType.getWidth();
}
- if (auto integerType = type.dyn_cast<IntegerType>())
+ if (auto integerType = llvm::dyn_cast<IntegerType>(type))
bitWidth = integerType.getWidth();
// The type is neither an integer, float, or pointer type.
if (!bitWidth)
@@ -777,7 +776,7 @@ LogicalResult LoadOp::verify() {
void LoadOp::build(OpBuilder &builder, OperationState &state, Value addr,
unsigned alignment, bool isVolatile, bool isNonTemporal) {
- auto type = addr.getType().cast<LLVMPointerType>().getElementType();
+ auto type = llvm::cast<LLVMPointerType>(addr.getType()).getElementType();
assert(type && "must provide explicit element type to the constructor "
"when the pointer type is opaque");
build(builder, state, type, addr, alignment, isVolatile, isNonTemporal);
@@ -801,7 +800,7 @@ void LoadOp::build(OpBuilder &builder, OperationState &state, Type type,
// std::nullopt if the given type is not the pointer type.
static std::optional<Type>
getLoadStoreElementType(OpAsmParser &parser, Type type, SMLoc trailingTypeLoc) {
- auto llvmTy = type.dyn_cast<LLVM::LLVMPointerType>();
+ auto llvmTy = llvm::dyn_cast<LLVM::LLVMPointerType>(type);
if (!llvmTy) {
parser.emitError(trailingTypeLoc, "expected LLVM pointer type");
return std::nullopt;
@@ -919,7 +918,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
ValueRange args) {
SmallVector<Type> results;
Type resultType = func.getFunctionType().getReturnType();
- if (!resultType.isa<LLVM::LLVMVoidType>())
+ if (!llvm::isa<LLVM::LLVMVoidType>(resultType))
results.push_back(resultType);
build(builder, state, results, SymbolRefAttr::get(func), args, nullptr,
nullptr);
@@ -964,7 +963,7 @@ LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
if (!getNumOperands())
return emitOpError(
"must have either a `callee` attribute or at least an operand");
- auto ptrType = getOperand(0).getType().dyn_cast<LLVMPointerType>();
+ auto ptrType = llvm::dyn_cast<LLVMPointerType>(getOperand(0).getType());
if (!ptrType)
return emitOpError("indirect call expects a pointer as callee: ")
<< getOperand(0).getType();
@@ -988,7 +987,7 @@ LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
fnType = fn.getFunctionType();
}
- LLVMFunctionType funcType = fnType.dyn_cast<LLVMFunctionType>();
+ LLVMFunctionType funcType = llvm::dyn_cast<LLVMFunctionType>(fnType);
if (!funcType)
return emitOpError("callee does not have a functional type: ") << fnType;
@@ -1023,11 +1022,11 @@ LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
<< " != " << funcType.getParamType(i);
if (getNumResults() == 0 &&
- !funcType.getReturnType().isa<LLVM::LLVMVoidType>())
+ !llvm::isa<LLVM::LLVMVoidType>(funcType.getReturnType()))
return emitOpError() << "expected function call to produce a value";
if (getNumResults() != 0 &&
- funcType.getReturnType().isa<LLVM::LLVMVoidType>())
+ llvm::isa<LLVM::LLVMVoidType>(funcType.getReturnType()))
return emitOpError()
<< "calling function with void result must not produce values";
@@ -1083,7 +1082,7 @@ static ParseResult parseCallTypeAndResolveOperands(
return parser.emitError(trailingTypesLoc,
"expected indirect call to have 2 trailing types");
- auto funcType = types.pop_back_val().dyn_cast<FunctionType>();
+ auto funcType = llvm::dyn_cast<FunctionType>(types.pop_back_val());
if (!funcType)
return parser.emitError(trailingTypesLoc,
"expected trailing function type");
@@ -1091,7 +1090,7 @@ static ParseResult parseCallTypeAndResolveOperands(
return parser.emitError(trailingTypesLoc,
"expected function with 0 or 1 result");
if (funcType.getNumResults() == 1 &&
- funcType.getResult(0).isa<LLVM::LLVMVoidType>())
+ llvm::isa<LLVM::LLVMVoidType>(funcType.getResult(0)))
return parser.emitError(trailingTypesLoc,
"expected a non-void result type");
@@ -1292,7 +1291,7 @@ LogicalResult LandingpadOp::verify() {
for (unsigned idx = 0, ie = getNumOperands(); idx < ie; idx++) {
value = getOperand(idx);
- bool isFilter = value.getType().isa<LLVMArrayType>();
+ bool isFilter = llvm::isa<LLVMArrayType>(value.getType());
if (isFilter) {
// FIXME: Verify filter clauses when arrays are appropriately handled
} else {
@@ -1324,7 +1323,7 @@ void LandingpadOp::print(OpAsmPrinter &p) {
for (auto value : getOperands()) {
// Similar to llvm - if clause is an array type then it is filter
// clause else catch clause
- bool isArrayTy = value.getType().isa<LLVMArrayType>();
+ bool isArrayTy = llvm::isa<LLVMArrayType>(value.getType());
p << '(' << (isArrayTy ? "filter " : "catch ") << value << " : "
<< value.getType() << ") ";
}
@@ -1383,13 +1382,13 @@ static Type getInsertExtractValueElementType(
// structures. Check the position index before accessing, it is supposed to
// be in bounds.
for (int64_t idx : position) {
- if (auto arrayType = llvmType.dyn_cast<LLVMArrayType>()) {
+ if (auto arrayType = llvm::dyn_cast<LLVMArrayType>(llvmType)) {
if (idx < 0 || static_cast<unsigned>(idx) >= arrayType.getNumElements()) {
emitError("position out of bounds: ") << idx;
return {};
}
llvmType = arrayType.getElementType();
- } else if (auto structType = llvmType.dyn_cast<LLVMStructType>()) {
+ } else if (auto structType = llvm::dyn_cast<LLVMStructType>(llvmType)) {
if (idx < 0 ||
static_cast<unsigned>(idx) >= structType.getBody().size()) {
emitError("position out of bounds: ") << idx;
@@ -1409,10 +1408,10 @@ static Type getInsertExtractValueElementType(
static Type getInsertExtractValueElementType(Type llvmType,
ArrayRef<int64_t> position) {
for (int64_t idx : position) {
- if (auto structType = llvmType.dyn_cast<LLVMStructType>())
+ if (auto structType = llvm::dyn_cast<LLVMStructType>(llvmType))
llvmType = structType.getBody()[idx];
else
- llvmType = llvmType.cast<LLVMArrayType>().getElementType();
+ llvmType = llvm::cast<LLVMArrayType>(llvmType).getElementType();
}
return llvmType;
}
@@ -1519,7 +1518,7 @@ LogicalResult ReturnOp::verify() {
return success();
Type expectedType = parent.getFunctionType().getReturnType();
- if (expectedType.isa<LLVMVoidType>()) {
+ if (llvm::isa<LLVMVoidType>(expectedType)) {
if (!getArg())
return success();
InFlightDiagnostic diag = emitOpError("expected no operands");
@@ -1527,7 +1526,7 @@ LogicalResult ReturnOp::verify() {
return diag;
}
if (!getArg()) {
- if (expectedType.isa<LLVMVoidType>())
+ if (llvm::isa<LLVMVoidType>(expectedType))
return success();
InFlightDiagnostic diag = emitOpError("expected 1 operand");
diag.attachNote(parent->getLoc()) << "when returning from function";
@@ -1664,7 +1663,7 @@ void GlobalOp::print(OpAsmPrinter &p) {
getVisibility_AttrName()});
// Print the trailing type unless it's a string global.
- if (getValueOrNull().dyn_cast_or_null<StringAttr>())
+ if (llvm::dyn_cast_or_null<StringAttr>(getValueOrNull()))
return;
p << " : " << getType();
@@ -1779,7 +1778,7 @@ ParseResult GlobalOp::parse(OpAsmParser &parser, OperationState &result) {
Region &initRegion = *result.addRegion();
if (types.empty()) {
- if (auto strAttr = value.dyn_cast_or_null<StringAttr>()) {
+ if (auto strAttr = llvm::dyn_cast_or_null<StringAttr>(value)) {
MLIRContext *context = parser.getContext();
auto arrayType = LLVM::LLVMArrayType::get(IntegerType::get(context, 8),
strAttr.getValue().size());
@@ -1802,15 +1801,15 @@ ParseResult GlobalOp::parse(OpAsmParser &parser, OperationState &result) {
}
static bool isZeroAttribute(Attribute value) {
- if (auto intValue = value.dyn_cast<IntegerAttr>())
+ if (auto intValue = llvm::dyn_cast<IntegerAttr>(value))
return intValue.getValue().isZero();
- if (auto fpValue = value.dyn_cast<FloatAttr>())
+ if (auto fpValue = llvm::dyn_cast<FloatAttr>(value))
return fpValue.getValue().isZero();
- if (auto splatValue = value.dyn_cast<SplatElementsAttr>())
+ if (auto splatValue = llvm::dyn_cast<SplatElementsAttr>(value))
return isZeroAttribute(splatValue.getSplatValue<Attribute>());
- if (auto elementsValue = value.dyn_cast<ElementsAttr>())
+ if (auto elementsValue = llvm::dyn_cast<ElementsAttr>(value))
return llvm::all_of(elementsValue.getValues<Attribute>(), isZeroAttribute);
- if (auto arrayValue = value.dyn_cast<ArrayAttr>())
+ if (auto arrayValue = llvm::dyn_cast<ArrayAttr>(value))
return llvm::all_of(arrayValue.getValue(), isZeroAttribute);
return false;
}
@@ -1822,10 +1821,10 @@ LogicalResult GlobalOp::verify() {
if ((*this)->getParentOp() && !satisfiesLLVMModule((*this)->getParentOp()))
return emitOpError("must appear at the module level");
- if (auto strAttr = getValueOrNull().dyn_cast_or_null<StringAttr>()) {
- auto type = getType().dyn_cast<LLVMArrayType>();
+ if (auto strAttr = llvm::dyn_cast_or_null<StringAttr>(getValueOrNull())) {
+ auto type = llvm::dyn_cast<LLVMArrayType>(getType());
IntegerType elementType =
- type ? type.getElementType().dyn_cast<IntegerType>() : nullptr;
+ type ? llvm::dyn_cast<IntegerType>(type.getElementType()) : nullptr;
if (!elementType || elementType.getWidth() != 8 ||
type.getNumElements() != strAttr.getValue().size())
return emitOpError(
@@ -1844,7 +1843,7 @@ LogicalResult GlobalOp::verify() {
}
if (getLinkage() == Linkage::Appending) {
- if (!getType().isa<LLVMArrayType>()) {
+ if (!llvm::isa<LLVMArrayType>(getType())) {
return emitOpError() << "expected array type for '"
<< stringifyLinkage(Linkage::Appending)
<< "' linkage";
@@ -1892,7 +1891,7 @@ LogicalResult GlobalOp::verifyRegions() {
LogicalResult
GlobalCtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
for (Attribute ctor : getCtors()) {
- if (failed(verifySymbolAttrUse(ctor.cast<FlatSymbolRefAttr>(), *this,
+ if (failed(verifySymbolAttrUse(llvm::cast<FlatSymbolRefAttr>(ctor), *this,
symbolTable)))
return failure();
}
@@ -1913,7 +1912,7 @@ LogicalResult GlobalCtorsOp::verify() {
LogicalResult
GlobalDtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
for (Attribute dtor : getDtors()) {
- if (failed(verifySymbolAttrUse(dtor.cast<FlatSymbolRefAttr>(), *this,
+ if (failed(verifySymbolAttrUse(llvm::cast<FlatSymbolRefAttr>(dtor), *this,
symbolTable)))
return failure();
}
@@ -2012,7 +2011,7 @@ void LLVMFuncOp::build(OpBuilder &builder, OperationState &result,
if (argAttrs.empty())
return;
- assert(type.cast<LLVMFunctionType>().getNumParams() == argAttrs.size() &&
+ assert(llvm::cast<LLVMFunctionType>(type).getNumParams() == argAttrs.size() &&
"expected as many argument attribute lists as arguments");
function_interface_impl::addArgAndResultAttrs(
builder, result, argAttrs, /*resultAttrs=*/std::nullopt,
@@ -2143,7 +2142,7 @@ void LLVMFuncOp::print(OpAsmPrinter &p) {
argTypes.push_back(fnType.getParamType(i));
Type returnType = fnType.getReturnType();
- if (!returnType.isa<LLVMVoidType>())
+ if (!llvm::isa<LLVMVoidType>(returnType))
resTypes.push_back(returnType);
function_interface_impl::printFunctionSignature(p, *this, argTypes,
@@ -2251,8 +2250,8 @@ Region *LLVMFuncOp::getCallableRegion() {
//===----------------------------------------------------------------------===//
LogicalResult LLVM::ConstantOp::verify() {
- if (StringAttr sAttr = getValue().dyn_cast<StringAttr>()) {
- auto arrayType = getType().dyn_cast<LLVMArrayType>();
+ if (StringAttr sAttr = llvm::dyn_cast<StringAttr>(getValue())) {
+ auto arrayType = llvm::dyn_cast<LLVMArrayType>(getType());
if (!arrayType || arrayType.getNumElements() != sAttr.getValue().size() ||
!arrayType.getElementType().isInteger(8)) {
return emitOpError() << "expected array type of "
@@ -2261,35 +2260,35 @@ LogicalResult LLVM::ConstantOp::verify() {
}
return success();
}
- if (auto structType = getType().dyn_cast<LLVMStructType>()) {
+ if (auto structType = llvm::dyn_cast<LLVMStructType>(getType())) {
if (structType.getBody().size() != 2 ||
structType.getBody()[0] != structType.getBody()[1]) {
return emitError() << "expected struct type with two elements of the "
"same type, the type of a complex constant";
}
- auto arrayAttr = getValue().dyn_cast<ArrayAttr>();
+ auto arrayAttr = llvm::dyn_cast<ArrayAttr>(getValue());
if (!arrayAttr || arrayAttr.size() != 2) {
return emitOpError() << "expected array attribute with two elements, "
"representing a complex constant";
}
- auto re = arrayAttr[0].dyn_cast<TypedAttr>();
- auto im = arrayAttr[1].dyn_cast<TypedAttr>();
+ auto re = llvm::dyn_cast<TypedAttr>(arrayAttr[0]);
+ auto im = llvm::dyn_cast<TypedAttr>(arrayAttr[1]);
if (!re || !im || re.getType() != im.getType()) {
return emitOpError()
<< "expected array attribute with two elements of the same type";
}
Type elementType = structType.getBody()[0];
- if (!elementType
- .isa<IntegerType, Float16Type, Float32Type, Float64Type>()) {
+ if (!llvm::isa<IntegerType, Float16Type, Float32Type, Float64Type>(
+ elementType)) {
return emitError()
<< "expected struct element types to be floating point type or "
"integer type";
}
return success();
}
- if (!getValue().isa<IntegerAttr, ArrayAttr, FloatAttr, ElementsAttr>())
+ if (!llvm::isa<IntegerAttr, ArrayAttr, FloatAttr, ElementsAttr>(getValue()))
return emitOpError()
<< "only supports integer, float, string or elements attributes";
return success();
@@ -2314,7 +2313,7 @@ void AtomicRMWOp::build(OpBuilder &builder, OperationState &state,
}
LogicalResult AtomicRMWOp::verify() {
- auto ptrType = getPtr().getType().cast<LLVM::LLVMPointerType>();
+ auto ptrType = llvm::cast<LLVM::LLVMPointerType>(getPtr().getType());
auto valType = getVal().getType();
if (!ptrType.isOpaque() && valType != ptrType.getElementType())
return emitOpError("expected LLVM IR element type for operand #0 to "
@@ -2327,7 +2326,7 @@ LogicalResult AtomicRMWOp::verify() {
if (!isTypeCompatibleWithAtomicOp(valType, /*isPointerTypeAllowed=*/false))
return emitOpError("unexpected LLVM IR type for 'xchg' bin_op");
} else {
- auto intType = valType.dyn_cast<IntegerType>();
+ auto intType = llvm::dyn_cast<IntegerType>(valType);
unsigned intBitWidth = intType ? intType.getWidth() : 0;
if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 &&
intBitWidth != 64)
@@ -2367,7 +2366,7 @@ void AtomicCmpXchgOp::build(OpBuilder &builder, OperationState &state,
}
LogicalResult AtomicCmpXchgOp::verify() {
- auto ptrType = getPtr().getType().cast<LLVM::LLVMPointerType>();
+ auto ptrType = llvm::cast<LLVM::LLVMPointerType>(getPtr().getType());
if (!ptrType)
return emitOpError("expected LLVM IR pointer type for operand #0");
auto valType = getVal().getType();
@@ -2421,10 +2420,10 @@ OpFoldResult LLVM::BitcastOp::fold(FoldAdaptor adaptor) {
}
LogicalResult LLVM::BitcastOp::verify() {
- auto resultType = extractVectorElementType(getResult().getType())
- .dyn_cast<LLVMPointerType>();
- auto sourceType =
- extractVectorElementType(getArg().getType()).dyn_cast<LLVMPointerType>();
+ auto resultType = llvm::dyn_cast<LLVMPointerType>(
+ extractVectorElementType(getResult().getType()));
+ auto sourceType = llvm::dyn_cast<LLVMPointerType>(
+ extractVectorElementType(getArg().getType()));
// If one of the types is a pointer (or vector of pointers), then
// both source and result type have to be pointers.
@@ -2435,7 +2434,8 @@ LogicalResult LLVM::BitcastOp::verify() {
return success();
auto isVector = [](Type type) {
- return type.isa<VectorType, LLVMScalableVectorType, LLVMFixedVectorType>();
+ return llvm::isa<VectorType, LLVMScalableVectorType, LLVMFixedVectorType>(
+ type);
};
// Due to bitcast requiring both operands to be of the same size, it is not
@@ -2480,7 +2480,7 @@ OpFoldResult LLVM::GEPOp::fold(FoldAdaptor adaptor) {
// gep %x:T, 0 -> %x
if (getBase().getType() == getType() && indices.size() == 1)
- if (auto integer = indices[0].dyn_cast_or_null<IntegerAttr>())
+ if (auto integer = llvm::dyn_cast_or_null<IntegerAttr>(indices[0]))
if (integer.getValue().isZero())
return getBase();
@@ -2488,7 +2488,7 @@ OpFoldResult LLVM::GEPOp::fold(FoldAdaptor adaptor) {
bool changed = false;
SmallVector<GEPArg> gepArgs;
for (auto iter : llvm::enumerate(indices)) {
- auto integer = iter.value().dyn_cast_or_null<IntegerAttr>();
+ auto integer = llvm::dyn_cast_or_null<IntegerAttr>(iter.value());
// Constant indices can only be int32_t, so if integer does not fit we
// are forced to keep it dynamic, despite being a constant.
if (!indices.isDynamicIndex(iter.index()) || !integer ||
@@ -2686,7 +2686,7 @@ LogicalResult MetadataOp::verifyRegions() {
SmallVectorImpl<TBAAGraphNode *> &operands =
tbaaGraph[tdOp.getSymNameAttr()]->operands;
for (Attribute attr : tdOp.getMembers()) {
- StringAttr symbolRef = attr.cast<FlatSymbolRefAttr>().getAttr();
+ StringAttr symbolRef = llvm::cast<FlatSymbolRefAttr>(attr).getAttr();
if (failed(verifyReference(op, symbolRef, tdOp.getMembersAttrName())))
return failure();
@@ -2888,7 +2888,7 @@ LogicalResult LLVMDialect::verifyOperationAttribute(Operation *op,
// llvm::DataLayout constructor.
if (attr.getName() != LLVM::LLVMDialect::getDataLayoutAttrName())
return success();
- if (auto stringAttr = attr.getValue().dyn_cast<StringAttr>())
+ if (auto stringAttr = llvm::dyn_cast<StringAttr>(attr.getValue()))
return verifyDataLayoutString(
stringAttr.getValue(),
[op](const Twine &message) { op->emitOpError() << message.str(); });
@@ -2909,28 +2909,28 @@ LogicalResult LLVMDialect::verifyParameterAttribute(Operation *op,
StringAttr name = paramAttr.getName();
auto checkUnitAttrType = [&]() -> LogicalResult {
- if (!paramAttr.getValue().isa<UnitAttr>())
+ if (!llvm::isa<UnitAttr>(paramAttr.getValue()))
return op->emitError() << name << " should be a unit attribute";
return success();
};
auto checkTypeAttrType = [&]() -> LogicalResult {
- if (!paramAttr.getValue().isa<TypeAttr>())
+ if (!llvm::isa<TypeAttr>(paramAttr.getValue()))
return op->emitError() << name << " should be a type attribute";
return success();
};
auto checkIntegerAttrType = [&]() -> LogicalResult {
- if (!paramAttr.getValue().isa<IntegerAttr>())
+ if (!llvm::isa<IntegerAttr>(paramAttr.getValue()))
return op->emitError() << name << " should be an integer attribute";
return success();
};
auto checkPointerType = [&]() -> LogicalResult {
- if (!paramType.isa<LLVMPointerType>())
+ if (!llvm::isa<LLVMPointerType>(paramType))
return op->emitError()
<< name << " attribute attached to non-pointer LLVM type";
return success();
};
auto checkIntegerType = [&]() -> LogicalResult {
- if (!paramType.isa<IntegerType>())
+ if (!llvm::isa<IntegerType>(paramType))
return op->emitError()
<< name << " attribute attached to non-integer LLVM type";
return success();
@@ -2938,8 +2938,8 @@ LogicalResult LLVMDialect::verifyParameterAttribute(Operation *op,
auto checkPointerTypeMatches = [&]() -> LogicalResult {
if (failed(checkPointerType()))
return failure();
- auto ptrType = paramType.cast<LLVMPointerType>();
- auto typeAttr = paramAttr.getValue().cast<TypeAttr>();
+ auto ptrType = llvm::cast<LLVMPointerType>(paramType);
+ auto typeAttr = llvm::cast<TypeAttr>(paramAttr.getValue());
if (!ptrType.isOpaque() && ptrType.getElementType() != typeAttr.getValue())
return op->emitError()
@@ -3033,7 +3033,7 @@ LogicalResult LLVMDialect::verifyRegionResultAttribute(Operation *op,
// Check to see if this function has a void return with a result attribute
// to it. It isn't clear what semantics we would assign to that.
- if (resType.isa<LLVMVoidType>())
+ if (llvm::isa<LLVMVoidType>(resType))
return op->emitError() << "cannot attach result attributes to functions "
"with a void return";
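The LLVMDialect.cpp hunks above rely on the semantics of the free functions in several ways. `llvm::isa<IntegerType, Float16Type, Float32Type, Float64Type>(elementType)` is true if the type matches any of the listed classes. `llvm::dyn_cast` returns a null handle on a mismatch. `llvm::dyn_cast_or_null`, used in `GEPOp::fold`, also accepts a null input. Below is a minimal, self-contained sketch of this classof-based pattern for value-semantics handles like MLIR's `Type`. `TypeStorage`, `Kind`, the toy `IntegerType`/`FloatType`, and the `sketch::` functions are hypothetical stand-ins, not the LLVM or MLIR implementation.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical miniature of a value-semantics handle (NOT mlir::Type):
// the handle wraps a pointer to storage; a null pointer models a null Type.
enum class Kind : std::uint8_t { Integer, F16, F32, Index };

struct TypeStorage {
  Kind kind;
};

class Type {
public:
  Type() = default;
  explicit Type(const TypeStorage *impl) : impl(impl) {}
  explicit operator bool() const { return impl != nullptr; }
  Kind getKind() const { return impl->kind; }
  const TypeStorage *getImpl() const { return impl; }

private:
  const TypeStorage *impl = nullptr;
};

// Each subclass answers "is this handle one of me?" via a static classof.
struct IntegerType : Type {
  using Type::Type;
  static bool classof(Type t) { return t.getKind() == Kind::Integer; }
};
struct FloatType : Type {
  using Type::Type;
  static bool classof(Type t) {
    return t.getKind() == Kind::F16 || t.getKind() == Kind::F32;
  }
};

namespace sketch {
// Variadic isa: true if `t` matches ANY of the listed types (C++17 fold).
template <typename... To> bool isa(Type t) {
  assert(t && "isa<> used on a null type");
  return (To::classof(t) || ...);
}
// cast: the caller guarantees the kind; a mismatch trips the assert.
template <typename To> To cast(Type t) {
  assert(isa<To>(t) && "cast<Ty>() argument of incompatible type!");
  return To(t.getImpl());
}
// dyn_cast: checked conversion; returns a null handle on mismatch.
template <typename To> To dyn_cast(Type t) {
  return isa<To>(t) ? To(t.getImpl()) : To();
}
// dyn_cast_or_null: like dyn_cast, but a null input yields a null result.
template <typename To> To dyn_cast_or_null(Type t) {
  return t ? dyn_cast<To>(t) : To();
}
} // namespace sketch
```

The same distinction applies to `AtomicCmpXchgOp::verify`. A `cast` asserts on a mismatch (as LLVM's own `cast` does) instead of returning a null handle. As a result, the `if (!ptrType)` check that follows `llvm::cast<LLVM::LLVMPointerType>(...)` cannot report a non-pointer operand. Only `dyn_cast` signals a mismatch by returning null.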
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp
index aae87727f54b9..0c27d0fc2e2d7 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp
@@ -176,7 +176,7 @@ static unsigned tryToEnforceAlignment(Value value, unsigned requestedAlignment,
if (auto func = dyn_cast<LLVM::LLVMFuncOp>(parentOp)) {
// Use the alignment attribute set for this argument in the parent function
// if it has been set.
- auto blockArg = value.cast<BlockArgument>();
+ auto blockArg = llvm::cast<BlockArgument>(value);
if (Attribute alignAttr = func.getArgAttr(
blockArg.getArgNumber(), LLVM::LLVMDialect::getAlignAttrName()))
return cast<IntegerAttr>(alignAttr).getValue().getLimitedValue();
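The commit log explains why some files get an explicit `llvm::` prefix. An unqualified `cast` can fail when a variable named `cast` is in scope, or inside a class definition where lookup resolves to the method. This hunk keeps the pre-existing unqualified `cast<IntegerAttr>(alignAttr)` while spelling the migrated call `llvm::cast<BlockArgument>(value)`. The sketch below reproduces the class-scope case using plain C++ name lookup. `sketch::cast`, `Handle`, `widenFree` and `widenOther` are hypothetical stand-ins, not LLVM or MLIR code.

```cpp
#include <cassert>

namespace sketch {
// Free-function template, standing in for llvm::cast.
template <typename To, typename From> To cast(From v) {
  return static_cast<To>(v);
}
} // namespace sketch

// Makes the free function visible unqualified, much as `using namespace llvm`
// does for MLIR sources.
using sketch::cast;

// Outside any class, unqualified `cast` finds the namespace-scope template.
inline long widenFree(int x) { return cast<long>(x); }

struct Handle {
  int value = 0;

  // Member template with the same name, standing in for the method-style
  // Type::cast<T>() / Attribute::cast<T>().
  template <typename To> To cast() const { return static_cast<To>(value); }

  long widenOther(int other) const {
    // An unqualified `cast<long>(other)` here would find Handle::cast first
    // (class scope is searched before namespace scope and lookup stops
    // there). That member takes no arguments, so the call would not compile.
    return sketch::cast<long>(other); // qualified: selects the free function
  }
};
```

Lookup stops at the first scope that declares the name. Inside `Handle`, the qualified `sketch::cast` (like `llvm::cast` in this patch) is therefore what selects the free function. At namespace scope, the unqualified spelling works.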
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp
index 95e0ea9553676..4b5c1eace52c3 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp
@@ -29,7 +29,7 @@ verifySymbolRefs(Operation *op, StringRef name, ArrayAttr symbolRefs,
// names processed here (e.g. 'tbaa'). This verification
// is redundant in some cases.
if (!llvm::all_of(symbolRefs, [](Attribute attr) {
- return attr && attr.isa<SymbolRefAttr>();
+ return attr && llvm::isa<SymbolRefAttr>(attr);
}))
return op->emitOpError() << name
<< " attribute failed to satisfy constraint: "
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
index 0c8b97b4de124..742ab5b632663 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
@@ -24,7 +24,8 @@ using namespace mlir::LLVM;
/// internal functions to avoid getting a verbose `!llvm` prefix. Otherwise
/// prints it as usual.
static void dispatchPrint(AsmPrinter &printer, Type type) {
- if (isCompatibleType(type) && !type.isa<IntegerType, FloatType, VectorType>())
+ if (isCompatibleType(type) &&
+ !llvm::isa<IntegerType, FloatType, VectorType>(type))
return mlir::LLVM::detail::printType(type, printer);
printer.printType(type);
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index 4b3b1e183903f..dcbfbf32a1487 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -130,8 +130,9 @@ generatedTypePrinter(Type def, AsmPrinter &printer);
//===----------------------------------------------------------------------===//
bool LLVMArrayType::isValidElementType(Type type) {
- return !type.isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
- LLVMFunctionType, LLVMTokenType, LLVMScalableVectorType>();
+ return !llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
+ LLVMFunctionType, LLVMTokenType, LLVMScalableVectorType>(
+ type);
}
LLVMArrayType LLVMArrayType::get(Type elementType, unsigned numElements) {
@@ -186,11 +187,11 @@ LLVMArrayType::getPreferredAlignment(const DataLayout &dataLayout,
//===----------------------------------------------------------------------===//
bool LLVMFunctionType::isValidArgumentType(Type type) {
- return !type.isa<LLVMVoidType, LLVMFunctionType>();
+ return !llvm::isa<LLVMVoidType, LLVMFunctionType>(type);
}
bool LLVMFunctionType::isValidResultType(Type type) {
- return !type.isa<LLVMFunctionType, LLVMMetadataType, LLVMLabelType>();
+ return !llvm::isa<LLVMFunctionType, LLVMMetadataType, LLVMLabelType>(type);
}
LLVMFunctionType LLVMFunctionType::get(Type result, ArrayRef<Type> arguments,
@@ -239,9 +240,9 @@ bool LLVMPointerType::isValidElementType(Type type) {
if (!type)
return true;
return isCompatibleOuterType(type)
- ? !type.isa<LLVMVoidType, LLVMTokenType, LLVMMetadataType,
- LLVMLabelType>()
- : type.isa<PointerElementTypeInterface>();
+ ? !llvm::isa<LLVMVoidType, LLVMTokenType, LLVMMetadataType,
+ LLVMLabelType>(type)
+ : llvm::isa<PointerElementTypeInterface>(type);
}
LLVMPointerType LLVMPointerType::get(Type pointee, unsigned addressSpace) {
@@ -266,7 +267,7 @@ constexpr const static unsigned kDefaultPointerAlignment = 8;
std::optional<unsigned> mlir::LLVM::extractPointerSpecValue(Attribute attr,
PtrDLEntryPos pos) {
- auto spec = attr.cast<DenseIntElementsAttr>();
+ auto spec = llvm::cast<DenseIntElementsAttr>(attr);
auto idx = static_cast<unsigned>(pos);
if (idx >= spec.size())
return std::nullopt;
@@ -285,8 +286,8 @@ getPointerDataLayoutEntry(DataLayoutEntryListRef params, LLVMPointerType type,
for (DataLayoutEntryInterface entry : params) {
if (!entry.isTypeEntry())
continue;
- if (entry.getKey().get<Type>().cast<LLVMPointerType>().getAddressSpace() ==
- type.getAddressSpace()) {
+ if (llvm::cast<LLVMPointerType>(entry.getKey().get<Type>())
+ .getAddressSpace() == type.getAddressSpace()) {
currentEntry = entry.getValue();
break;
}
@@ -350,11 +351,11 @@ bool LLVMPointerType::areCompatible(DataLayoutEntryListRef oldLayout,
continue;
unsigned size = kDefaultPointerSizeBits;
unsigned abi = kDefaultPointerAlignment;
- auto newType = newEntry.getKey().get<Type>().cast<LLVMPointerType>();
+ auto newType = llvm::cast<LLVMPointerType>(newEntry.getKey().get<Type>());
const auto *it =
llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
if (auto type = entry.getKey().dyn_cast<Type>()) {
- return type.cast<LLVMPointerType>().getAddressSpace() ==
+ return llvm::cast<LLVMPointerType>(type).getAddressSpace() ==
newType.getAddressSpace();
}
return false;
@@ -362,7 +363,7 @@ bool LLVMPointerType::areCompatible(DataLayoutEntryListRef oldLayout,
if (it == oldLayout.end()) {
llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
if (auto type = entry.getKey().dyn_cast<Type>()) {
- return type.cast<LLVMPointerType>().getAddressSpace() == 0;
+ return llvm::cast<LLVMPointerType>(type).getAddressSpace() == 0;
}
return false;
});
@@ -372,7 +373,7 @@ bool LLVMPointerType::areCompatible(DataLayoutEntryListRef oldLayout,
abi = *extractPointerSpecValue(*it, PtrDLEntryPos::Abi);
}
- Attribute newSpec = newEntry.getValue().cast<DenseIntElementsAttr>();
+ Attribute newSpec = llvm::cast<DenseIntElementsAttr>(newEntry.getValue());
unsigned newSize = *extractPointerSpecValue(newSpec, PtrDLEntryPos::Size);
unsigned newAbi = *extractPointerSpecValue(newSpec, PtrDLEntryPos::Abi);
if (size != newSize || abi < newAbi || abi % newAbi != 0)
@@ -386,8 +387,8 @@ LogicalResult LLVMPointerType::verifyEntries(DataLayoutEntryListRef entries,
for (DataLayoutEntryInterface entry : entries) {
if (!entry.isTypeEntry())
continue;
- auto key = entry.getKey().get<Type>().cast<LLVMPointerType>();
- auto values = entry.getValue().dyn_cast<DenseIntElementsAttr>();
+ auto key = llvm::cast<LLVMPointerType>(entry.getKey().get<Type>());
+ auto values = llvm::dyn_cast<DenseIntElementsAttr>(entry.getValue());
if (!values || (values.size() != 3 && values.size() != 4)) {
return emitError(loc)
<< "expected layout attribute for " << entry.getKey().get<Type>()
@@ -412,8 +413,9 @@ LogicalResult LLVMPointerType::verifyEntries(DataLayoutEntryListRef entries,
//===----------------------------------------------------------------------===//
bool LLVMStructType::isValidElementType(Type type) {
- return !type.isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
- LLVMFunctionType, LLVMTokenType, LLVMScalableVectorType>();
+ return !llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
+ LLVMFunctionType, LLVMTokenType, LLVMScalableVectorType>(
+ type);
}
LLVMStructType LLVMStructType::getIdentified(MLIRContext *context,
@@ -538,7 +540,7 @@ getStructDataLayoutEntry(DataLayoutEntryListRef params, LLVMStructType type,
if (currentEntry == params.end())
return std::nullopt;
- auto attr = currentEntry->getValue().cast<DenseIntElementsAttr>();
+ auto attr = llvm::cast<DenseIntElementsAttr>(currentEntry->getValue());
if (pos == StructDLEntryPos::Preferred &&
attr.size() <= static_cast<unsigned>(StructDLEntryPos::Preferred))
// If no preferred was specified, fall back to abi alignment
@@ -586,7 +588,7 @@ LLVMStructType::getPreferredAlignment(const DataLayout &dataLayout,
}
static unsigned extractStructSpecValue(Attribute attr, StructDLEntryPos pos) {
- return attr.cast<DenseIntElementsAttr>()
+ return llvm::cast<DenseIntElementsAttr>(attr)
.getValues<unsigned>()[static_cast<unsigned>(pos)];
}
@@ -619,8 +621,8 @@ LogicalResult LLVMStructType::verifyEntries(DataLayoutEntryListRef entries,
if (!entry.isTypeEntry())
continue;
- auto key = entry.getKey().get<Type>().cast<LLVMStructType>();
- auto values = entry.getValue().dyn_cast<DenseIntElementsAttr>();
+ auto key = llvm::cast<LLVMStructType>(entry.getKey().get<Type>());
+ auto values = llvm::dyn_cast<DenseIntElementsAttr>(entry.getValue());
if (!values || (values.size() != 2 && values.size() != 1)) {
return emitError(loc)
<< "expected layout attribute for " << entry.getKey().get<Type>()
@@ -676,7 +678,7 @@ LLVMFixedVectorType::getChecked(function_ref<InFlightDiagnostic()> emitError,
}
bool LLVMFixedVectorType::isValidElementType(Type type) {
- return type.isa<LLVMPointerType, LLVMPPCFP128Type>();
+ return llvm::isa<LLVMPointerType, LLVMPPCFP128Type>(type);
}
LogicalResult
@@ -705,10 +707,11 @@ LLVMScalableVectorType::getChecked(function_ref<InFlightDiagnostic()> emitError,
}
bool LLVMScalableVectorType::isValidElementType(Type type) {
- if (auto intType = type.dyn_cast<IntegerType>())
+ if (auto intType = llvm::dyn_cast<IntegerType>(type))
return intType.isSignless();
- return isCompatibleFloatingPointType(type) || type.isa<LLVMPointerType>();
+ return isCompatibleFloatingPointType(type) ||
+ llvm::isa<LLVMPointerType>(type);
}
LogicalResult
@@ -724,7 +727,7 @@ LLVMScalableVectorType::verify(function_ref<InFlightDiagnostic()> emitError,
bool mlir::LLVM::isCompatibleOuterType(Type type) {
// clang-format off
- if (type.isa<
+ if (llvm::isa<
BFloat16Type,
Float16Type,
Float32Type,
@@ -745,17 +748,17 @@ bool mlir::LLVM::isCompatibleOuterType(Type type) {
LLVMScalableVectorType,
LLVMVoidType,
LLVMX86MMXType
- >()) {
+ >(type)) {
// clang-format on
return true;
}
// Only signless integers are compatible.
- if (auto intType = type.dyn_cast<IntegerType>())
+ if (auto intType = llvm::dyn_cast<IntegerType>(type))
return intType.isSignless();
// 1D vector types are compatible.
- if (auto vecType = type.dyn_cast<VectorType>())
+ if (auto vecType = llvm::dyn_cast<VectorType>(type))
return vecType.getRank() == 1;
return false;
@@ -835,22 +838,22 @@ bool mlir::LLVM::isCompatibleType(Type type) {
}
bool mlir::LLVM::isCompatibleFloatingPointType(Type type) {
- return type.isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
- Float80Type, Float128Type, LLVMPPCFP128Type>();
+ return llvm::isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
+ Float80Type, Float128Type, LLVMPPCFP128Type>(type);
}
bool mlir::LLVM::isCompatibleVectorType(Type type) {
- if (type.isa<LLVMFixedVectorType, LLVMScalableVectorType>())
+ if (llvm::isa<LLVMFixedVectorType, LLVMScalableVectorType>(type))
return true;
- if (auto vecType = type.dyn_cast<VectorType>()) {
+ if (auto vecType = llvm::dyn_cast<VectorType>(type)) {
if (vecType.getRank() != 1)
return false;
Type elementType = vecType.getElementType();
- if (auto intType = elementType.dyn_cast<IntegerType>())
+ if (auto intType = llvm::dyn_cast<IntegerType>(elementType))
return intType.isSignless();
- return elementType.isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
- Float80Type, Float128Type>();
+ return llvm::isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
+ Float80Type, Float128Type>(elementType);
}
return false;
}
@@ -883,13 +886,12 @@ llvm::ElementCount mlir::LLVM::getVectorNumElements(Type type) {
}
bool mlir::LLVM::isScalableVectorType(Type vectorType) {
- assert(
- (vectorType
- .isa<LLVMFixedVectorType, LLVMScalableVectorType, VectorType>()) &&
- "expected LLVM-compatible vector type");
- return !vectorType.isa<LLVMFixedVectorType>() &&
- (vectorType.isa<LLVMScalableVectorType>() ||
- vectorType.cast<VectorType>().isScalable());
+ assert((llvm::isa<LLVMFixedVectorType, LLVMScalableVectorType, VectorType>(
+ vectorType)) &&
+ "expected LLVM-compatible vector type");
+ return !llvm::isa<LLVMFixedVectorType>(vectorType) &&
+ (llvm::isa<LLVMScalableVectorType>(vectorType) ||
+ llvm::cast<VectorType>(vectorType).isScalable());
}
Type mlir::LLVM::getVectorType(Type elementType, unsigned numElements,
@@ -970,9 +972,9 @@ llvm::TypeSize mlir::LLVM::getPrimitiveTypeSizeInBits(Type type) {
elementSize.isScalable());
})
.Default([](Type ty) {
- assert((ty.isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
- LLVMTokenType, LLVMStructType, LLVMArrayType,
- LLVMPointerType, LLVMFunctionType>()) &&
+ assert((llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
+ LLVMTokenType, LLVMStructType, LLVMArrayType,
+ LLVMPointerType, LLVMFunctionType>(ty)) &&
"unexpected missing support for primitive type");
return llvm::TypeSize::Fixed(0);
});
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 7ff949e1fbfe2..9bec96c4d4bc5 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -90,13 +90,13 @@ MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
return NVVM::MMATypes::f32;
if (operandElType.isF32() && !isAccumulator)
return NVVM::MMATypes::tf32;
- if (operandElType.isa<IntegerType>()) {
+ if (llvm::isa<IntegerType>(operandElType)) {
if (isAccumulator)
return NVVM::MMATypes::s32;
return std::nullopt;
}
- if (auto structType = operandElType.dyn_cast<LLVM::LLVMStructType>()) {
+ if (auto structType = llvm::dyn_cast<LLVM::LLVMStructType>(operandElType)) {
if (structType.getBody().empty())
return std::nullopt;
return inferOperandMMAType(structType.getBody()[0], isAccumulator);
@@ -526,9 +526,9 @@ LogicalResult MmaOp::verify() {
LogicalResult ShflOp::verify() {
if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid"))
return success();
- auto type = getType().dyn_cast<LLVM::LLVMStructType>();
+ auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType());
auto elementType = (type && type.getBody().size() == 2)
- ? type.getBody()[1].dyn_cast<IntegerType>()
+ ? llvm::dyn_cast<IntegerType>(type.getBody()[1])
: nullptr;
if (!elementType || elementType.getWidth() != 1)
return emitError("expected return type to be a two-element struct with "
@@ -600,7 +600,7 @@ inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n,
LogicalResult NVVM::WMMALoadOp::verify() {
unsigned addressSpace =
- getPtr().getType().cast<LLVM::LLVMPointerType>().getAddressSpace();
+ llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
if (addressSpace != 0 && addressSpace != 1 && addressSpace != 3)
return emitOpError("expected source pointer in memory "
"space 0, 1, 3");
@@ -620,7 +620,7 @@ LogicalResult NVVM::WMMALoadOp::verify() {
LogicalResult NVVM::WMMAStoreOp::verify() {
unsigned addressSpace =
- getPtr().getType().cast<LLVM::LLVMPointerType>().getAddressSpace();
+ llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
if (addressSpace != 0 && addressSpace != 1 && addressSpace != 3)
return emitOpError("expected operands to be a source pointer in memory "
"space 0, 1, 3");
@@ -672,7 +672,7 @@ LogicalResult NVVM::WMMAMmaOp::verify() {
LogicalResult NVVM::LdMatrixOp::verify() {
unsigned addressSpace =
- getPtr().getType().cast<LLVM::LLVMPointerType>().getAddressSpace();
+ llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
if (addressSpace != 3)
return emitOpError("expected source pointer in memory space 3");
@@ -725,13 +725,13 @@ LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
// If maxntid and reqntid exist, it must be an array with max 3 dim
if (attrName == NVVMDialect::getMaxntidAttrName() ||
attrName == NVVMDialect::getReqntidAttrName()) {
- auto values = attr.getValue().dyn_cast<ArrayAttr>();
+ auto values = llvm::dyn_cast<ArrayAttr>(attr.getValue());
if (!values || values.empty() || values.size() > 3)
return op->emitError()
<< "'" << attrName
<< "' attribute must be integer array with maximum 3 index";
- for (auto val : attr.getValue().cast<ArrayAttr>()) {
- if (!val.dyn_cast<IntegerAttr>())
+ for (auto val : llvm::cast<ArrayAttr>(attr.getValue())) {
+ if (!llvm::dyn_cast<IntegerAttr>(val))
return op->emitError()
<< "'" << attrName
<< "' attribute must be integer array with maximum 3 index";
@@ -740,7 +740,7 @@ LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
// If minctasm and maxnreg exist, it must be an array with max 3 dim
if (attrName == NVVMDialect::getMinctasmAttrName() ||
attrName == NVVMDialect::getMaxnregAttrName()) {
- if (!attr.getValue().dyn_cast<IntegerAttr>())
+ if (!llvm::dyn_cast<IntegerAttr>(attr.getValue()))
return op->emitError()
<< "'" << attrName << "' attribute must be integer constant";
}
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index ed59518f64741..137204343ef06 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -506,15 +506,15 @@ LogicalResult mlir::linalg::detail::verifyFillInterface(Operation *op) {
/// the type of `source`.
static Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
int64_t dim) {
- if (source.getType().isa<UnrankedMemRefType, MemRefType>())
+ if (llvm::isa<UnrankedMemRefType, MemRefType>(source.getType()))
return b.createOrFold<memref::DimOp>(loc, source, dim);
- if (source.getType().isa<UnrankedTensorType, RankedTensorType>())
+ if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.getType()))
return b.createOrFold<tensor::DimOp>(loc, source, dim);
llvm_unreachable("Expected MemRefType or TensorType");
}
static OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value source,
int64_t dim) {
- auto shapedType = source.getType().cast<ShapedType>();
+ auto shapedType = llvm::cast<ShapedType>(source.getType());
if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
return createOrFoldDimOp(b, loc, source, dim);
return b.getIndexAttr(shapedType.getDimSize(dim));
@@ -644,7 +644,7 @@ LinalgOp::reifyResultShapes(OpBuilder &b,
for (OpOperand *opOperand : getDpsInitOperands()) {
SmallVector<OpFoldResult> shapes;
for (int64_t dim : llvm::seq<int64_t>(0, getRank(opOperand))) {
- auto shapedType = opOperand->get().getType().cast<ShapedType>();
+ auto shapedType = llvm::cast<ShapedType>(opOperand->get().getType());
if (!shapedType.isDynamicDim(dim)) {
// Static dim: Return IntegerAttr.
shapes.push_back(b.getIndexAttr(shapedType.getDimSize(dim)));
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 84663af346f5f..e5af2035a4832 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -63,7 +63,8 @@ static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
TypeRange inputTypes, TypeRange outputTypes,
ArrayRef<NamedAttribute> attrs,
RegionBuilderFn regionBuilder) {
- assert(llvm::all_of(outputTypes, [](Type t) { return t.isa<ShapedType>(); }));
+ assert(llvm::all_of(outputTypes,
+ [](Type t) { return llvm::isa<ShapedType>(t); }));
// TODO: atm all operands go through getElementTypeOrSelf,
// reconsider when we have evidence we need to.
@@ -106,7 +107,7 @@ static void buildStructuredOp(OpBuilder &b, OperationState &state,
resultTensorTypes.value_or(TypeRange());
if (!resultTensorTypes)
copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
- [](Type type) { return type.isa<RankedTensorType>(); });
+ [](Type type) { return llvm::isa<RankedTensorType>(type); });
state.addOperands(inputs);
state.addOperands(outputs);
@@ -173,7 +174,7 @@ parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
// Otherwise we append it to the discardable attributes dictionary where it is
// handled by the generic Operation::create(...) method.
if (result.propertiesAttr) {
- NamedAttrList attrs = result.propertiesAttr.cast<DictionaryAttr>();
+ NamedAttrList attrs = llvm::cast<DictionaryAttr>(result.propertiesAttr);
attrs.append("operand_segment_sizes",
parser.getBuilder().getDenseI32ArrayAttr(
{static_cast<int32_t>(inputsOperands.size()),
@@ -448,9 +449,15 @@ class RegionBuilderHelper {
return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
}
- bool isComplex(Value value) { return value.getType().isa<ComplexType>(); }
- bool isFloatingPoint(Value value) { return value.getType().isa<FloatType>(); }
- bool isInteger(Value value) { return value.getType().isa<IntegerType>(); }
+ bool isComplex(Value value) {
+ return llvm::isa<ComplexType>(value.getType());
+ }
+ bool isFloatingPoint(Value value) {
+ return llvm::isa<FloatType>(value.getType());
+ }
+ bool isInteger(Value value) {
+ return llvm::isa<IntegerType>(value.getType());
+ }
OpBuilder getBuilder() {
OpBuilder builder(context);
@@ -748,8 +755,7 @@ void GenericOp::print(OpAsmPrinter &p) {
for (auto attr : (*this)->getAttrs()) {
if (attr.getName() == getIteratorTypesAttrName()) {
auto iteratorTypes =
- attr.getValue()
- .cast<ArrayAttr>()
+ llvm::cast<ArrayAttr>(attr.getValue())
.getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
// Convert IteratorType enums into the string representation. This is
// needed, because tests still use the old format when 'iterator_types'
@@ -873,13 +879,13 @@ static void getGenericEffectsImpl(
ValueRange results, const OpOperandVector &inputOperands,
const OpOperandVector &outputOperands) {
for (auto *operand : inputOperands) {
- if (!operand->get().getType().isa<MemRefType>())
+ if (!llvm::isa<MemRefType>(operand->get().getType()))
continue;
effects.emplace_back(MemoryEffects::Read::get(), operand->get(),
SideEffects::DefaultResource::get());
}
for (auto *operand : outputOperands) {
- if (!operand->get().getType().isa<MemRefType>())
+ if (!llvm::isa<MemRefType>(operand->get().getType()))
continue;
effects.emplace_back(MemoryEffects::Read::get(), operand->get(),
SideEffects::DefaultResource::get());
@@ -942,7 +948,7 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
// number to use for replacing uses of this operation.
SmallVector<Value> returnedArgs;
for (const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
- auto yieldArg = yieldVal.value().dyn_cast<BlockArgument>();
+ auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
if (!yieldArg || yieldArg.getOwner() != &body)
return failure();
unsigned argumentNumber = yieldArg.getArgNumber();
@@ -1003,7 +1009,7 @@ static ParseResult parseDstStyleOp(
// Add result types.
for (Type outputType : outputTypes) {
- if (outputType.isa<RankedTensorType>())
+ if (llvm::isa<RankedTensorType>(outputType))
result.addTypes(outputType);
}
@@ -1037,7 +1043,7 @@ void MapOp::build(
// Add output types for `RankedTensorType` output arguments.
Type initType = init.getType();
- if (initType.isa<RankedTensorType>())
+ if (llvm::isa<RankedTensorType>(initType))
result.addTypes(initType);
if (bodyBuild)
@@ -1056,8 +1062,9 @@ static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
b.setInsertionPointToStart(&block);
SmallVector<Value> bbArgs;
for (auto &operand : operands) {
- block.addArgument(operand.getType().cast<ShapedType>().getElementType(),
- b.getUnknownLoc());
+ block.addArgument(
+ llvm::cast<ShapedType>(operand.getType()).getElementType(),
+ b.getUnknownLoc());
}
SmallVector<Value> payloadOpOperands;
// If initFirst flag is enabled, we consider init as the first position of
@@ -1074,8 +1081,8 @@ static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
Operation *payloadOp = b.create(
result.location, b.getStringAttr(payloadOpName.getStringRef()),
payloadOpOperands,
- TypeRange{
- result.operands.back().getType().cast<ShapedType>().getElementType()},
+ TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
+ .getElementType()},
payloadOpAttrs);
b.create<YieldOp>(result.location, payloadOp->getResults());
}
@@ -1151,7 +1158,8 @@ void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
std::string attrToElide;
p << " { " << payloadOp->getName().getStringRef();
for (const auto &attr : payloadOp->getAttrs()) {
- auto fastAttr = attr.getValue().dyn_cast<mlir::arith::FastMathFlagsAttr>();
+ auto fastAttr =
+ llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
attrToElide = attr.getName().str();
elidedAttrs.push_back(attrToElide);
@@ -1200,7 +1208,8 @@ LogicalResult MapOp::verify() {
// The parameters of mapper should all match the element type of inputs.
for (const auto &[bbArgType, inputArg] :
llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
- auto inputElemType = inputArg.getType().cast<ShapedType>().getElementType();
+ auto inputElemType =
+ llvm::cast<ShapedType>(inputArg.getType()).getElementType();
if (bbArgType != inputElemType) {
return emitOpError() << "expected element type of input " << inputElemType
<< " to match bbArg type " << bbArgType;
@@ -1210,7 +1219,7 @@ LogicalResult MapOp::verify() {
// The shape of each input must match the shape of the output.
auto outputShape = getInit().getType().getShape();
for (Type inputArgType : TypeRange{getInputs()}) {
- auto inputElemShape = inputArgType.cast<ShapedType>().getShape();
+ auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
if (inputElemShape != outputShape) {
return emitOpError() << "expected shape of input (" << inputElemShape
<< ") to match shape of output (" << outputShape
@@ -1270,7 +1279,7 @@ void ReduceOp::build(
// Add output types for `RankedTensorType` output arguments.
for (Value init : inits) {
Type initType = init.getType();
- if (initType.isa<RankedTensorType>())
+ if (llvm::isa<RankedTensorType>(initType))
result.addTypes(initType);
}
@@ -1280,7 +1289,8 @@ void ReduceOp::build(
}
SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
- int64_t inputRank = getInputs()[0].getType().cast<ShapedType>().getRank();
+ int64_t inputRank =
+ llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
SmallVector<utils::IteratorType> iteratorTypes(inputRank,
utils::IteratorType::parallel);
for (int64_t reductionDim : getDimensions())
@@ -1289,7 +1299,8 @@ SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
}
ArrayAttr ReduceOp::getIndexingMaps() {
- int64_t inputRank = getInputs()[0].getType().cast<ShapedType>().getRank();
+ int64_t inputRank =
+ llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
SmallVector<AffineMap> affineMaps(
getNumDpsInputs(),
AffineMap::getMultiDimIdentityMap(inputRank, getContext()));
@@ -1390,8 +1401,8 @@ LogicalResult ReduceOp::verify() {
ArrayRef<int64_t> dimensionsRef = getDimensions();
for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
- if (getInputs()[i].getType().cast<ShapedType>().getShape() !=
- getInputs()[0].getType().cast<ShapedType>().getShape()) {
+ if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() !=
+ llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) {
return emitOpError() << "expects all inputs to have the same shapes. "
"Shape at input-index "
<< i
@@ -1399,16 +1410,16 @@ LogicalResult ReduceOp::verify() {
}
}
for (int64_t i = 1; i < getNumDpsInits(); ++i) {
- if (getInits()[i].getType().cast<ShapedType>().getShape() !=
- getInits()[0].getType().cast<ShapedType>().getShape()) {
+ if (llvm::cast<ShapedType>(getInits()[i].getType()).getShape() !=
+ llvm::cast<ShapedType>(getInits()[0].getType()).getShape()) {
return emitOpError() << "expects all outputs to have the same shapes. "
"Shape at output-index "
<< i
<< " is not equal to the shape at output-index 0.";
}
}
- auto inputType = getInputs()[0].getType().cast<ShapedType>();
- auto initType = getInits()[0].getType().cast<ShapedType>();
+ auto inputType = llvm::cast<ShapedType>(getInputs()[0].getType());
+ auto initType = llvm::cast<ShapedType>(getInits()[0].getType());
DenseSet<int64_t> dimensionsToReduce;
for (int64_t dimension : dimensionsRef) {
@@ -1449,7 +1460,8 @@ LogicalResult ReduceOp::verify() {
// Check that the first block arguments match the element type of the inputs.
for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) {
- Type inputElementType = input.getType().cast<ShapedType>().getElementType();
+ Type inputElementType =
+ llvm::cast<ShapedType>(input.getType()).getElementType();
if (inputElementType != bbArg.getType())
return emitOpError()
<< "input element type " << inputElementType
@@ -1462,7 +1474,7 @@ LogicalResult ReduceOp::verify() {
llvm::zip(getDpsInitOperands(),
block->getArguments().take_back(getNumDpsInits()))) {
auto outputElementType =
- output->get().getType().cast<ShapedType>().getElementType();
+ llvm::cast<ShapedType>(output->get().getType()).getElementType();
if (outputElementType != bbArg.getType())
return emitOpError()
<< "output element type " << outputElementType
@@ -1496,7 +1508,7 @@ void TransposeOp::build(::mlir::OpBuilder &builder,
// Add output types for `RankedTensorType` output arguments.
Type initType = init.getType();
- if (initType.isa<RankedTensorType>())
+ if (llvm::isa<RankedTensorType>(initType))
result.addTypes(initType);
buildIdentityRegion(builder, result.location, *result.addRegion(), input,
@@ -1610,7 +1622,7 @@ void BroadcastOp::build(::mlir::OpBuilder &builder,
// Add output types for `RankedTensorType` output arguments.
Type initType = init.getType();
- if (initType.isa<RankedTensorType>())
+ if (llvm::isa<RankedTensorType>(initType))
result.addTypes(initType);
buildIdentityRegion(builder, result.location, *result.addRegion(), input,
@@ -1828,7 +1840,7 @@ SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a,
}
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t) {
- if (auto memref = t.dyn_cast<MemRefType>()) {
+ if (auto memref = llvm::dyn_cast<MemRefType>(t)) {
ss << "view";
for (auto size : memref.getShape())
if (size < 0)
@@ -1838,14 +1850,14 @@ static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t) {
if (failed(appendMangledType(ss, memref.getElementType())))
return failure();
if (auto as = memref.getMemorySpace()) {
- if (auto attr = as.dyn_cast<IntegerAttr>())
+ if (auto attr = llvm::dyn_cast<IntegerAttr>(as))
ss << "as" << attr.getInt();
else
return failure();
}
return success();
}
- if (auto vec = t.dyn_cast<VectorType>()) {
+ if (auto vec = llvm::dyn_cast<VectorType>(t)) {
ss << "vector";
llvm::interleave(
vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
@@ -1864,9 +1876,9 @@ std::string mlir::linalg::generateLibraryCallName(Operation *op) {
std::string name(op->getName().getStringRef().str());
std::string fun = "";
for (NamedAttribute kv : op->getAttrs()) {
- if (UnaryFnAttr ufa = kv.getValue().dyn_cast<UnaryFnAttr>()) {
+ if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
fun = stringifyEnum(ufa.getValue()).str() + "_";
- } else if (BinaryFnAttr bfa = kv.getValue().dyn_cast<BinaryFnAttr>()) {
+ } else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
fun = stringifyEnum(bfa.getValue()).str() + "_";
}
}
@@ -1898,7 +1910,7 @@ struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
// Linalg "inputs" may be either tensor or memref type.
// tensor<0xelt_type> is a convention that may not always mean
// "0 iterations". Only erase in cases we see memref<...x0x...>.
- auto mt = opOperand.get().getType().dyn_cast<MemRefType>();
+ auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
if (!mt)
continue;
if (llvm::is_contained(op.getShape(&opOperand), 0)) {
@@ -1934,9 +1946,10 @@ struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
rewriter.setInsertionPoint(linalgOp);
Location loc = linalgOp.getLoc();
- OpResult resultValue = castOp.getSource().cast<OpResult>();
+ OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
unsigned resultNumber = resultValue.getResultNumber();
- auto resultType = castOp->getResult(0).getType().cast<RankedTensorType>();
+ auto resultType =
+ llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
// Replace the `outs` for the result with a `tensor.cast`. This cast is now
// going from a more dynamic shape to a less dynamic shape. If the producer
// for this cast, i.e. producer of the out operand, is also an operation
@@ -1975,7 +1988,7 @@ static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
if (linalgOp.isScalar(&opOperand))
continue;
Value src = opOperand.get();
- auto sourceType = src.getType().cast<RankedTensorType>();
+ auto sourceType = llvm::cast<RankedTensorType>(src.getType());
auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
// Get the `sourceShape` of the `sourceType`. If the operand is a result of
@@ -1986,7 +1999,8 @@ static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
if (parentOp) {
if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
Value castSource = castOp.getSource();
- auto castSourceType = castSource.getType().dyn_cast<RankedTensorType>();
+ auto castSourceType =
+ llvm::dyn_cast<RankedTensorType>(castSource.getType());
if (castSourceType && castSourceType.hasStaticShape())
sourceShape = castSourceType.getShape();
}
@@ -2017,7 +2031,7 @@ static void createNewOperandWithStaticSizes(
newOperands.push_back(src);
if (linalgOp.isScalar(opOperand))
return;
- auto sourceType = src.getType().cast<RankedTensorType>();
+ auto sourceType = llvm::cast<RankedTensorType>(src.getType());
Type resultType = sourceType;
if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
resultTypes.push_back(resultType);
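The hunks above replace method calls such as `t.dyn_cast<MemRefType>()` with free functions such as `llvm::dyn_cast<MemRefType>(t)`. Those free functions dispatch on a static `classof` hook declared by each target class. Below is a minimal sketch of that mechanism covering the four operations the commit message lists (`isa`, `cast`, `dyn_cast`, `dyn_cast_or_null`). The `Shape`/`Square`/`Circle` classes and the `mini` namespace are hypothetical illustrations, not the real `llvm/Support/Casting.h`, which also supports value-typed handles such as `mlir::Type` through its `CastInfo`/`doCast` machinery. This sketch only handles raw pointers.

```cpp
#include <cassert>

// Hypothetical class hierarchy using a kind tag, in the style of LLVM RTTI.
struct Shape {
  enum class Kind { Square, Circle };
  explicit Shape(Kind k) : kind(k) {}
  Kind getKind() const { return kind; }

private:
  Kind kind;
};

struct Square : Shape {
  explicit Square(double s) : Shape(Kind::Square), side(s) {}
  // The free functions below call this hook to test the dynamic kind.
  static bool classof(const Shape *s) { return s->getKind() == Kind::Square; }
  double side;
};

struct Circle : Shape {
  explicit Circle(double r) : Shape(Kind::Circle), radius(r) {}
  static bool classof(const Shape *s) { return s->getKind() == Kind::Circle; }
  double radius;
};

namespace mini { // hypothetical stand-in for the llvm:: free functions

// isa: pure type test; like LLVM, it refuses null input.
template <typename To, typename From> bool isa(const From *from) {
  assert(from && "isa<> used on a null pointer");
  return To::classof(from);
}

// cast: checked downcast that asserts the kind matches.
template <typename To, typename From> To *cast(From *from) {
  assert(isa<To>(from) && "cast<> argument of incompatible type");
  return static_cast<To *>(from);
}

// dyn_cast: returns nullptr instead of asserting when the kind differs.
template <typename To, typename From> To *dyn_cast(From *from) {
  return isa<To>(from) ? static_cast<To *>(from) : nullptr;
}

// dyn_cast_or_null: like dyn_cast, but also accepts a null input.
template <typename To, typename From> To *dyn_cast_or_null(From *from) {
  return from ? dyn_cast<To>(from) : nullptr;
}

} // namespace mini

// A dyn_cast chain, analogous to the shape of appendMangledType above.
double area(Shape *s) {
  if (auto *sq = mini::dyn_cast<Square>(s))
    return sq->side * sq->side;
  if (auto *c = mini::dyn_cast<Circle>(s))
    return 3.141592653589793 * c->radius * c->radius;
  return 0.0;
}
```

Because the free functions take the object as an argument instead of being members of it, one set of templates serves every class that provides `classof`. That uniformity is the consistency argument the commit message cites.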
diff --git a/mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp
index 55d09c421e31b..a259074d68180 100644
--- a/mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -37,7 +37,7 @@ struct IndexOpInterface
int64_t flatDimCtr = 0;
for (Value operand : linalgOp->getOperands()) {
assert(flatDimPos >= flatDimCtr && "invalid pos");
- auto shapedType = operand.getType().cast<ShapedType>();
+ auto shapedType = llvm::cast<ShapedType>(operand.getType());
if (flatDimPos < flatDimCtr + shapedType.getRank()) {
cstr.bound(value) < cstr.getExpr(operand, flatDimPos - flatDimCtr);
break;
diff --git a/mlir/lib/Dialect/MLProgram/IR/MLProgramDialect.cpp b/mlir/lib/Dialect/MLProgram/IR/MLProgramDialect.cpp
index 70e5557d5a7bf..1a8fe208d4099 100644
--- a/mlir/lib/Dialect/MLProgram/IR/MLProgramDialect.cpp
+++ b/mlir/lib/Dialect/MLProgram/IR/MLProgramDialect.cpp
@@ -28,7 +28,7 @@ struct MLProgramOpAsmDialectInterface : public OpAsmDialectInterface {
using OpAsmDialectInterface::OpAsmDialectInterface;
AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
- if (attr.isa<ExternAttr>()) {
+ if (llvm::isa<ExternAttr>(attr)) {
os << "extern";
return AliasResult::OverridableAlias;
}
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp
index d10c6ca12bdee..b5f5272d64212 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp
@@ -22,7 +22,7 @@ using namespace mlir;
//===----------------------------------------------------------------------===//
static bool isSupportedElementType(Type type) {
- return type.isa<MemRefType>() ||
+ return llvm::isa<MemRefType>(type) ||
OpBuilder(type.getContext()).getZeroAttr(type);
}
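The commit message notes that some files could not use an unqualified `cast<T>(x)` because a variable named `cast` was in scope, or because the call sat inside a class definition where it resolved to the method. `foldMemRefCast` in the `MemRefOps.cpp` diff that follows declares such a local (`auto cast = ...`), which is one likely reason that file is spelled with `llvm::cast`. The sketch below reproduces both name-lookup problems in plain C++. `mini::cast` and `Handle` are hypothetical stand-ins for `llvm::cast` and for a class like `mlir::Type` that still has a `cast` method.

```cpp
#include <cassert>

namespace mini {
// Hypothetical stand-in for llvm::cast; a plain static_cast keeps it small.
template <typename To, typename From> To cast(const From &from) {
  return static_cast<To>(from);
}
} // namespace mini

// Make the free function visible unqualified, as a using-declaration would.
using mini::cast;

// Outside any class, and with no local named `cast`, the unqualified form works.
int truncateFree(double d) { return cast<int>(d); }

struct Handle {
  // Member template named `cast`, playing the role of the method form
  // (e.g. Type::cast<U>()) that the commit migrates away from.
  template <typename To> To cast() const { return To(value); }

  int value = 7;

  int truncate(double d) const {
    // An unqualified `cast<int>(d)` here would find the member template
    // first, because class scope is searched before namespace scope, and
    // would fail to compile since the member takes no arguments.
    return mini::cast<int>(d);
  }
};

int truncateWithLocal(double d) {
  int cast = 0; // a local named `cast` hides every other `cast`
  // Here `cast<int>(d)` would no longer parse as a template call,
  // so the qualified name is required.
  cast = mini::cast<int>(d);
  return cast;
}
```

Qualifying every call, as this commit does, sidesteps both cases. Once the methods are removed, the class-scope conflict disappears, which is why the commit expects a later find-replace back to the unqualified spelling.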
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 4dfc09f11b911..3beda2cafff12 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -90,7 +90,7 @@ LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) {
for (OpOperand &operand : op->getOpOperands()) {
auto cast = operand.get().getDefiningOp<CastOp>();
if (cast && operand.get() != inner &&
- !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
+ !llvm::isa<UnrankedMemRefType>(cast.getOperand().getType())) {
operand.set(cast.getOperand());
folded = true;
}
@@ -101,16 +101,16 @@ LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) {
/// Return an unranked/ranked tensor type for the given unranked/ranked memref
/// type.
Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
- if (auto memref = type.dyn_cast<MemRefType>())
+ if (auto memref = llvm::dyn_cast<MemRefType>(type))
return RankedTensorType::get(memref.getShape(), memref.getElementType());
- if (auto memref = type.dyn_cast<UnrankedMemRefType>())
+ if (auto memref = llvm::dyn_cast<UnrankedMemRefType>(type))
return UnrankedTensorType::get(memref.getElementType());
return NoneType::get(type.getContext());
}
SmallVector<OpFoldResult> memref::getMixedSizes(OpBuilder &builder,
Location loc, Value value) {
- auto memrefType = value.getType().cast<MemRefType>();
+ auto memrefType = llvm::cast<MemRefType>(value.getType());
SmallVector<OpFoldResult> result;
for (int64_t i = 0; i < memrefType.getRank(); ++i) {
if (memrefType.isDynamicDim(i)) {
@@ -180,7 +180,7 @@ static void constifyIndexValues(
// values, hence we recreate the attribute even when it is already static
// to make sure the type is consistent.
ofr = builder.getIndexAttr(
- ofr.get<Attribute>().cast<IntegerAttr>().getInt());
+ llvm::cast<IntegerAttr>(ofr.get<Attribute>()).getInt());
continue;
}
std::optional<int64_t> maybeConstant =
@@ -241,7 +241,7 @@ template <typename AllocLikeOp>
static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
"applies to only alloc or alloca");
- auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>();
+ auto memRefType = llvm::dyn_cast<MemRefType>(op.getResult().getType());
if (!memRefType)
return op.emitOpError("result must be a memref");
@@ -378,7 +378,7 @@ void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
//===----------------------------------------------------------------------===//
LogicalResult ReallocOp::verify() {
- auto sourceType = getOperand(0).getType().cast<MemRefType>();
+ auto sourceType = llvm::cast<MemRefType>(getOperand(0).getType());
MemRefType resultType = getType();
// The source memref should have identity layout (or none).
@@ -691,8 +691,9 @@ void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
/// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
/// ```
bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
- MemRefType sourceType = castOp.getSource().getType().dyn_cast<MemRefType>();
- MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();
+ MemRefType sourceType =
+ llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
+ MemRefType resultType = llvm::dyn_cast<MemRefType>(castOp.getType());
// Requires ranked MemRefType.
if (!sourceType || !resultType)
@@ -743,11 +744,11 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
if (inputs.size() != 1 || outputs.size() != 1)
return false;
Type a = inputs.front(), b = outputs.front();
- auto aT = a.dyn_cast<MemRefType>();
- auto bT = b.dyn_cast<MemRefType>();
+ auto aT = llvm::dyn_cast<MemRefType>(a);
+ auto bT = llvm::dyn_cast<MemRefType>(b);
- auto uaT = a.dyn_cast<UnrankedMemRefType>();
- auto ubT = b.dyn_cast<UnrankedMemRefType>();
+ auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
+ auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);
if (aT && bT) {
if (aT.getElementType() != bT.getElementType())
@@ -831,8 +832,8 @@ struct FoldCopyOfCast : public OpRewritePattern<CopyOp> {
// Check source.
if (auto castOp = copyOp.getSource().getDefiningOp<CastOp>()) {
- auto fromType = castOp.getSource().getType().dyn_cast<MemRefType>();
- auto toType = castOp.getSource().getType().dyn_cast<MemRefType>();
+ auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
+ auto toType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
if (fromType && toType) {
if (fromType.getShape() == toType.getShape() &&
@@ -847,8 +848,8 @@ struct FoldCopyOfCast : public OpRewritePattern<CopyOp> {
// Check target.
if (auto castOp = copyOp.getTarget().getDefiningOp<CastOp>()) {
- auto fromType = castOp.getSource().getType().dyn_cast<MemRefType>();
- auto toType = castOp.getSource().getType().dyn_cast<MemRefType>();
+ auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
+ auto toType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
if (fromType && toType) {
if (fromType.getShape() == toType.getShape() &&
@@ -970,7 +971,7 @@ computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
for (const auto &dim : llvm::enumerate(sizes))
if (auto attr = dim.value().dyn_cast<Attribute>())
- if (attr.cast<IntegerAttr>().getInt() == 1)
+ if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
unusedDims.set(dim.index());
// Early exit for the case where the number of unused dims matches the number
@@ -1046,7 +1047,7 @@ OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
return {};
// Folding for unranked types (UnrankedMemRefType) is not supported.
- auto memrefType = getSource().getType().dyn_cast<MemRefType>();
+ auto memrefType = llvm::dyn_cast<MemRefType>(getSource().getType());
if (!memrefType)
return {};
@@ -1256,7 +1257,7 @@ LogicalResult DmaStartOp::verify() {
// Check types of operands. The order of these calls is important: the later
// calls rely on some type properties to compute the operand position.
// 1. Source memref.
- if (!getSrcMemRef().getType().isa<MemRefType>())
+ if (!llvm::isa<MemRefType>(getSrcMemRef().getType()))
return emitOpError("expected source to be of memref type");
if (numOperands < getSrcMemRefRank() + 4)
return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
@@ -1267,7 +1268,7 @@ LogicalResult DmaStartOp::verify() {
return emitOpError("expected source indices to be of index type");
// 2. Destination memref.
- if (!getDstMemRef().getType().isa<MemRefType>())
+ if (!llvm::isa<MemRefType>(getDstMemRef().getType()))
return emitOpError("expected destination to be of memref type");
unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
if (numOperands < numExpectedOperands)
@@ -1283,7 +1284,7 @@ LogicalResult DmaStartOp::verify() {
return emitOpError("expected num elements to be of index type");
// 4. Tag memref.
- if (!getTagMemRef().getType().isa<MemRefType>())
+ if (!llvm::isa<MemRefType>(getTagMemRef().getType()))
return emitOpError("expected tag to be of memref type");
numExpectedOperands += getTagMemRefRank();
if (numOperands < numExpectedOperands)
@@ -1359,7 +1360,8 @@ LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
SmallVectorImpl<Type> &inferredReturnTypes) {
ExtractStridedMetadataOpAdaptor extractAdaptor(operands, attributes,
properties);
- auto sourceType = extractAdaptor.getSource().getType().dyn_cast<MemRefType>();
+ auto sourceType =
+ llvm::dyn_cast<MemRefType>(extractAdaptor.getSource().getType());
if (!sourceType)
return failure();
@@ -1409,8 +1411,7 @@ static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
"The constified value should be either unchanged (i.e., == result) "
"or a constant");
Value constantVal = rewriter.create<arith::ConstantIndexOp>(
- loc, maybeConstant.template get<Attribute>()
- .template cast<IntegerAttr>()
+ loc, llvm::cast<IntegerAttr>(maybeConstant.template get<Attribute>())
.getInt());
for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
// updateRootInplace: lambda cannot capture structured bindings in C++17
@@ -1470,7 +1471,7 @@ void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
result.addOperands(memref);
result.addOperands(ivs);
- if (auto memrefType = memref.getType().dyn_cast<MemRefType>()) {
+ if (auto memrefType = llvm::dyn_cast<MemRefType>(memref.getType())) {
Type elementType = memrefType.getElementType();
result.addTypes(elementType);
@@ -1519,7 +1520,7 @@ ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser,
if (parser.parseRegion(*body, {}) ||
parser.parseOptionalAttrDict(result.attributes))
return failure();
- result.types.push_back(memrefType.cast<MemRefType>().getElementType());
+ result.types.push_back(llvm::cast<MemRefType>(memrefType).getElementType());
return success();
}
@@ -1567,7 +1568,7 @@ parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
if (parser.parseType(type))
return failure();
- auto memrefType = type.dyn_cast<MemRefType>();
+ auto memrefType = llvm::dyn_cast<MemRefType>(type);
if (!memrefType || !memrefType.hasStaticShape())
return parser.emitError(parser.getNameLoc())
<< "type should be static shaped memref, but got " << type;
@@ -1584,14 +1585,14 @@ parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
Type tensorType = getTensorTypeFromMemRefType(memrefType);
if (parser.parseAttribute(initialValue, tensorType))
return failure();
- if (!initialValue.isa<ElementsAttr>())
+ if (!llvm::isa<ElementsAttr>(initialValue))
return parser.emitError(parser.getNameLoc())
<< "initial value should be a unit or elements attribute";
return success();
}
LogicalResult GlobalOp::verify() {
- auto memrefType = getType().dyn_cast<MemRefType>();
+ auto memrefType = llvm::dyn_cast<MemRefType>(getType());
if (!memrefType || !memrefType.hasStaticShape())
return emitOpError("type should be static shaped memref, but got ")
<< getType();
@@ -1600,14 +1601,14 @@ LogicalResult GlobalOp::verify() {
// an elements attribute.
if (getInitialValue().has_value()) {
Attribute initValue = getInitialValue().value();
- if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>())
+ if (!llvm::isa<UnitAttr>(initValue) && !llvm::isa<ElementsAttr>(initValue))
return emitOpError("initial value should be a unit or elements "
"attribute, but got ")
<< initValue;
// Check that the type of the initial value is compatible with the type of
// the global variable.
- if (auto elementsAttr = initValue.dyn_cast<ElementsAttr>()) {
+ if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
Type initType = elementsAttr.getType();
Type tensorType = getTensorTypeFromMemRefType(memrefType);
if (initType != tensorType)
@@ -1631,7 +1632,7 @@ LogicalResult GlobalOp::verify() {
ElementsAttr GlobalOp::getConstantInitValue() {
auto initVal = getInitialValue();
if (getConstant() && initVal.has_value())
- return initVal.value().cast<ElementsAttr>();
+ return llvm::cast<ElementsAttr>(initVal.value());
return {};
}
@@ -1687,11 +1688,11 @@ bool MemorySpaceCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
if (inputs.size() != 1 || outputs.size() != 1)
return false;
Type a = inputs.front(), b = outputs.front();
- auto aT = a.dyn_cast<MemRefType>();
- auto bT = b.dyn_cast<MemRefType>();
+ auto aT = llvm::dyn_cast<MemRefType>(a);
+ auto bT = llvm::dyn_cast<MemRefType>(b);
- auto uaT = a.dyn_cast<UnrankedMemRefType>();
- auto ubT = b.dyn_cast<UnrankedMemRefType>();
+ auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
+ auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);
if (aT && bT) {
if (aT.getElementType() != bT.getElementType())
@@ -1794,7 +1795,7 @@ LogicalResult PrefetchOp::fold(FoldAdaptor adaptor,
OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
// Constant fold rank when the rank of the operand is known.
auto type = getOperand().getType();
- auto shapedType = type.dyn_cast<ShapedType>();
+ auto shapedType = llvm::dyn_cast<ShapedType>(type);
if (shapedType && shapedType.hasRank())
return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
return IntegerAttr();
@@ -1861,8 +1862,8 @@ void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
// completed automatically, like we have for subview and extract_slice.
LogicalResult ReinterpretCastOp::verify() {
// The source and result memrefs should be in the same memory space.
- auto srcType = getSource().getType().cast<BaseMemRefType>();
- auto resultType = getType().cast<MemRefType>();
+ auto srcType = llvm::cast<BaseMemRefType>(getSource().getType());
+ auto resultType = llvm::cast<MemRefType>(getType());
if (srcType.getMemorySpace() != resultType.getMemorySpace())
return emitError("different memory spaces specified for source type ")
<< srcType << " and result memref type " << resultType;
@@ -2250,7 +2251,7 @@ void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
ArrayRef<int64_t> resultShape, Value src,
ArrayRef<ReassociationIndices> reassociation) {
// Only ranked memref source values are supported.
- auto srcType = src.getType().cast<MemRefType>();
+ auto srcType = llvm::cast<MemRefType>(src.getType());
FailureOr<MemRefType> resultType =
ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
// Failure of this assertion usually indicates a problem with the source
@@ -2406,7 +2407,7 @@ MemRefType CollapseShapeOp::computeCollapsedType(
void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
ArrayRef<ReassociationIndices> reassociation,
ArrayRef<NamedAttribute> attrs) {
- auto srcType = src.getType().cast<MemRefType>();
+ auto srcType = llvm::cast<MemRefType>(src.getType());
MemRefType resultType =
CollapseShapeOp::computeCollapsedType(srcType, reassociation);
build(b, result, resultType, src, attrs);
@@ -2473,7 +2474,7 @@ struct CollapseShapeOpMemRefCastFolder
return failure();
Type newResultType = CollapseShapeOp::computeCollapsedType(
- cast.getOperand().getType().cast<MemRefType>(),
+ llvm::cast<MemRefType>(cast.getOperand().getType()),
op.getReassociationIndices());
if (newResultType == op.getResultType()) {
@@ -2518,18 +2519,20 @@ LogicalResult ReshapeOp::verify() {
Type operandType = getSource().getType();
Type resultType = getResult().getType();
- Type operandElementType = operandType.cast<ShapedType>().getElementType();
- Type resultElementType = resultType.cast<ShapedType>().getElementType();
+ Type operandElementType =
+ llvm::cast<ShapedType>(operandType).getElementType();
+ Type resultElementType = llvm::cast<ShapedType>(resultType).getElementType();
if (operandElementType != resultElementType)
return emitOpError("element types of source and destination memref "
"types should be the same");
- if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
+ if (auto operandMemRefType = llvm::dyn_cast<MemRefType>(operandType))
if (!operandMemRefType.getLayout().isIdentity())
return emitOpError("source memref type should have identity affine map");
- int64_t shapeSize = getShape().getType().cast<MemRefType>().getDimSize(0);
- auto resultMemRefType = resultType.dyn_cast<MemRefType>();
+ int64_t shapeSize =
+ llvm::cast<MemRefType>(getShape().getType()).getDimSize(0);
+ auto resultMemRefType = llvm::dyn_cast<MemRefType>(resultType);
if (resultMemRefType) {
if (!resultMemRefType.getLayout().isIdentity())
return emitOpError("result memref type should have identity affine map");
@@ -2634,9 +2637,8 @@ Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
ArrayRef<int64_t> offsets,
ArrayRef<int64_t> sizes,
ArrayRef<int64_t> strides) {
- auto inferredType =
- inferResultType(sourceRankedTensorType, offsets, sizes, strides)
- .cast<MemRefType>();
+ auto inferredType = llvm::cast<MemRefType>(
+ inferResultType(sourceRankedTensorType, offsets, sizes, strides));
assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) &&
"expected ");
if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))
@@ -2648,7 +2650,7 @@ Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
assert(dimsToProject.has_value() && "invalid rank reduction");
// Compute the layout and result type.
- auto inferredLayout = inferredType.getLayout().cast<StridedLayoutAttr>();
+ auto inferredLayout = llvm::cast<StridedLayoutAttr>(inferredType.getLayout());
SmallVector<int64_t> rankReducedStrides;
rankReducedStrides.reserve(resultShape.size());
for (auto [idx, value] : llvm::enumerate(inferredLayout.getStrides())) {
@@ -2690,12 +2692,11 @@ void SubViewOp::build(OpBuilder &b, OperationState &result,
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
- auto sourceMemRefType = source.getType().cast<MemRefType>();
+ auto sourceMemRefType = llvm::cast<MemRefType>(source.getType());
// Structuring implementation this way avoids duplication between builders.
if (!resultType) {
- resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
- staticSizes, staticStrides)
- .cast<MemRefType>();
+ resultType = llvm::cast<MemRefType>(SubViewOp::inferResultType(
+ sourceMemRefType, staticOffsets, staticSizes, staticStrides));
}
build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
@@ -2824,7 +2825,7 @@ isRankReducedMemRefType(MemRefType originalType,
template <typename OpTy>
static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
OpTy op, Type expectedType) {
- auto memrefType = expectedType.cast<ShapedType>();
+ auto memrefType = llvm::cast<ShapedType>(expectedType);
switch (result) {
case SliceVerificationResult::Success:
return success();
@@ -2867,7 +2868,7 @@ LogicalResult SubViewOp::verify() {
auto expectedType = SubViewOp::inferResultType(
baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides());
- auto result = isRankReducedMemRefType(expectedType.cast<MemRefType>(),
+ auto result = isRankReducedMemRefType(llvm::cast<MemRefType>(expectedType),
subViewType, getMixedSizes());
return produceSubViewErrorMsg(result, *this, expectedType);
}
@@ -2917,9 +2918,8 @@ static MemRefType getCanonicalSubViewResultType(
MemRefType currentResultType, MemRefType currentSourceType,
MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
- auto nonRankReducedType = SubViewOp::inferResultType(sourceType, mixedOffsets,
- mixedSizes, mixedStrides)
- .cast<MemRefType>();
+ auto nonRankReducedType = llvm::cast<MemRefType>(SubViewOp::inferResultType(
+ sourceType, mixedOffsets, mixedSizes, mixedStrides));
std::optional<llvm::SmallBitVector> unusedDims =
computeMemRefRankReductionMask(currentSourceType, currentResultType,
mixedSizes);
@@ -2927,7 +2927,7 @@ static MemRefType getCanonicalSubViewResultType(
if (!unusedDims)
return nullptr;
- auto layout = nonRankReducedType.getLayout().cast<StridedLayoutAttr>();
+ auto layout = llvm::cast<StridedLayoutAttr>(nonRankReducedType.getLayout());
SmallVector<int64_t> shape, strides;
unsigned numDimsAfterReduction =
nonRankReducedType.getRank() - unusedDims->count();
@@ -2962,14 +2962,14 @@ static MemRefType getCanonicalSubViewResultType(
Value mlir::memref::createCanonicalRankReducingSubViewOp(
OpBuilder &b, Location loc, Value memref, ArrayRef<int64_t> targetShape) {
- auto memrefType = memref.getType().cast<MemRefType>();
+ auto memrefType = llvm::cast<MemRefType>(memref.getType());
unsigned rank = memrefType.getRank();
SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, memref);
SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
- auto targetType = SubViewOp::inferRankReducedResultType(
- targetShape, memrefType, offsets, sizes, strides)
- .cast<MemRefType>();
+ auto targetType =
+ llvm::cast<MemRefType>(SubViewOp::inferRankReducedResultType(
+ targetShape, memrefType, offsets, sizes, strides));
return b.createOrFold<memref::SubViewOp>(loc, targetType, memref, offsets,
sizes, strides);
}
@@ -2977,7 +2977,7 @@ Value mlir::memref::createCanonicalRankReducingSubViewOp(
FailureOr<Value> SubViewOp::rankReduceIfNeeded(OpBuilder &b, Location loc,
Value value,
ArrayRef<int64_t> desiredShape) {
- auto sourceMemrefType = value.getType().dyn_cast<MemRefType>();
+ auto sourceMemrefType = llvm::dyn_cast<MemRefType>(value.getType());
assert(sourceMemrefType && "not a ranked memref type");
auto sourceShape = sourceMemrefType.getShape();
if (sourceShape.equals(desiredShape))
@@ -3069,7 +3069,7 @@ class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
// if the operation is rank-reducing.
auto resultType = getCanonicalSubViewResultType(
subViewOp.getType(), subViewOp.getSourceType(),
- castOp.getSource().getType().cast<MemRefType>(),
+ llvm::cast<MemRefType>(castOp.getSource().getType()),
subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
subViewOp.getMixedStrides());
if (!resultType)
@@ -3134,8 +3134,8 @@ void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
}
OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
- auto resultShapedType = getResult().getType().cast<ShapedType>();
- auto sourceShapedType = getSource().getType().cast<ShapedType>();
+ auto resultShapedType = llvm::cast<ShapedType>(getResult().getType());
+ auto sourceShapedType = llvm::cast<ShapedType>(getSource().getType());
if (resultShapedType.hasStaticShape() &&
resultShapedType == sourceShapedType) {
@@ -3201,7 +3201,7 @@ void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
auto permutationMap = permutation.getValue();
assert(permutationMap);
- auto memRefType = in.getType().cast<MemRefType>();
+ auto memRefType = llvm::cast<MemRefType>(in.getType());
// Compute result type.
MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
@@ -3239,8 +3239,8 @@ LogicalResult TransposeOp::verify() {
if (getPermutation().getNumDims() != getIn().getType().getRank())
return emitOpError("expected a permutation map of same rank as the input");
- auto srcType = getIn().getType().cast<MemRefType>();
- auto dstType = getType().cast<MemRefType>();
+ auto srcType = llvm::cast<MemRefType>(getIn().getType());
+ auto dstType = llvm::cast<MemRefType>(getType());
auto transposedType = inferTransposeResultType(srcType, getPermutation());
if (dstType != transposedType)
return emitOpError("output type ")
@@ -3264,7 +3264,7 @@ void ViewOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
}
LogicalResult ViewOp::verify() {
- auto baseType = getOperand(0).getType().cast<MemRefType>();
+ auto baseType = llvm::cast<MemRefType>(getOperand(0).getType());
auto viewType = getType();
// The base memref should have identity layout map (or none).
@@ -3401,7 +3401,7 @@ LogicalResult AtomicRMWOp::verify() {
case arith::AtomicRMWKind::maxf:
case arith::AtomicRMWKind::minf:
case arith::AtomicRMWKind::mulf:
- if (!getValue().getType().isa<FloatType>())
+ if (!llvm::isa<FloatType>(getValue().getType()))
return emitOpError() << "with kind '"
<< arith::stringifyAtomicRMWKind(getKind())
<< "' expects a floating-point type";
@@ -3414,7 +3414,7 @@ LogicalResult AtomicRMWOp::verify() {
case arith::AtomicRMWKind::muli:
case arith::AtomicRMWKind::ori:
case arith::AtomicRMWKind::andi:
- if (!getValue().getType().isa<IntegerType>())
+ if (!llvm::isa<IntegerType>(getValue().getType()))
return emitOpError() << "with kind '"
<< arith::stringifyAtomicRMWKind(getKind())
<< "' expects an integer type";
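Why some call sites in this patch use the qualified `llvm::cast` instead of a bare `cast`: the commit log says an unqualified call can resolve to a member named `cast` when it appears inside a class definition. The standalone C++ sketch below shows that name-lookup rule. `rtti::cast` and `Handle` are hypothetical stand-ins, not MLIR or LLVM APIs.

```cpp
#include <cassert>

namespace rtti {
// Hypothetical free-function cast standing in for llvm::cast.
template <typename To, typename From> To cast(From v) {
  return static_cast<To>(v);
}
} // namespace rtti

struct Handle {
  int value;

  // Member-function form, analogous to the `x.cast<T>()` methods
  // that this migration moves away from.
  template <typename To> To cast() const { return static_cast<To>(value); }

  long widen(int x) const {
    // An unqualified `cast<long>(x)` here would be looked up in class scope
    // first and find the member template above. That member takes no
    // arguments, so the call would not compile. Qualifying the name skips
    // member lookup and reaches the free function.
    return rtti::cast<long>(x);
  }
};
```

A local variable named `cast` causes a similar problem, which the commit log also mentions. Inside that variable's scope, `cast<T>(x)` no longer names the template, so the qualified spelling is the safe choice there too.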
diff --git a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
index ca63fb3d0de6a..daec22cf6ebdc 100644
--- a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -37,8 +37,8 @@ struct CastOpInterface
auto castOp = cast<CastOp>(op);
assert(value == castOp.getResult() && "invalid value");
- if (castOp.getResult().getType().isa<MemRefType>() &&
- castOp.getSource().getType().isa<MemRefType>()) {
+ if (llvm::isa<MemRefType>(castOp.getResult().getType()) &&
+ llvm::isa<MemRefType>(castOp.getSource().getType())) {
cstr.bound(value)[dim] == cstr.getExpr(castOp.getSource(), dim);
}
}
@@ -79,7 +79,7 @@ struct RankOpInterface
auto rankOp = cast<RankOp>(op);
assert(value == rankOp.getResult() && "invalid value");
- auto memrefType = rankOp.getMemref().getType().dyn_cast<MemRefType>();
+ auto memrefType = llvm::dyn_cast<MemRefType>(rankOp.getMemref().getType());
if (!memrefType)
return;
cstr.bound(value) == memrefType.getRank();
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index 32f76b6cd5154..77c853a2c35f4 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -38,9 +38,9 @@ bool nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(MemRefType type) {
Attribute memorySpace = type.getMemorySpace();
if (!memorySpace)
return false;
- if (auto intAttr = memorySpace.dyn_cast<IntegerAttr>())
+ if (auto intAttr = llvm::dyn_cast<IntegerAttr>(memorySpace))
return intAttr.getInt() == NVGPUDialect::kSharedMemoryAddressSpace;
- if (auto gpuAttr = memorySpace.dyn_cast<gpu::AddressSpaceAttr>())
+ if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
return gpuAttr.getValue() == gpu::AddressSpace::Workgroup;
return false;
}
@@ -61,8 +61,8 @@ static bool isLastMemrefDimUnitStride(MemRefType type) {
}
LogicalResult DeviceAsyncCopyOp::verify() {
- auto srcMemref = getSrc().getType().cast<MemRefType>();
- auto dstMemref = getDst().getType().cast<MemRefType>();
+ auto srcMemref = llvm::cast<MemRefType>(getSrc().getType());
+ auto dstMemref = llvm::cast<MemRefType>(getDst().getType());
if (!isLastMemrefDimUnitStride(srcMemref))
return emitError("source memref most minor dim must have unit stride");
@@ -246,10 +246,10 @@ LogicalResult MmaSparseSyncOp::verify() {
LogicalResult LdMatrixOp::verify() {
// ldmatrix reads data from source in shared memory
- auto srcMemref = getSrcMemref().getType().cast<MemRefType>();
+ auto srcMemref = llvm::cast<MemRefType>(getSrcMemref().getType());
// ldmatrix writes data to result/destination in vector registers
- auto resVector = getRes().getType().cast<VectorType>();
+ auto resVector = llvm::cast<VectorType>(getRes().getType());
// vector register shape, element type, and bitwidth
ArrayRef<int64_t> resShape = resVector.getShape();
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 78c824c3deba1..66cc653d6a19b 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -41,7 +41,7 @@ template <typename T>
struct PointerLikeModel
: public PointerLikeType::ExternalModel<PointerLikeModel<T>, T> {
Type getElementType(Type pointer) const {
- return pointer.cast<T>().getElementType();
+ return llvm::cast<T>(pointer).getElementType();
}
};
@@ -231,7 +231,7 @@ verifyAlignedClause(Operation *op, std::optional<ArrayAttr> alignmentValues,
// Check if all alignment values are positive - OpenMP 4.5 -> 2.8.1 section
for (unsigned i = 0; i < (*alignmentValues).size(); ++i) {
- if (auto intAttr = (*alignmentValues)[i].dyn_cast<IntegerAttr>()) {
+ if (auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignmentValues)[i])) {
if (intAttr.getValue().sle(0))
return op->emitOpError() << "alignment should be greater than 0";
} else {
@@ -463,7 +463,7 @@ static LogicalResult verifyReductionVarList(Operation *op,
return op->emitOpError() << "accumulator variable used more than once";
Type varType = accum.getType();
- auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>();
+ auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
auto decl =
SymbolTable::lookupNearestSymbolFrom<ReductionDeclareOp>(op, symbolRef);
if (!decl)
@@ -521,7 +521,8 @@ static void printDependVarList(OpAsmPrinter &p, Operation *op,
if (i != 0)
p << ", ";
p << stringifyClauseTaskDepend(
- (*depends)[i].cast<mlir::omp::ClauseTaskDependAttr>().getValue())
+ llvm::cast<mlir::omp::ClauseTaskDependAttr>((*depends)[i])
+ .getValue())
<< " -> " << dependVars[i] << " : " << dependTypes[i];
}
}
@@ -723,8 +724,8 @@ static void printMapClause(OpAsmPrinter &p, Operation *op,
Value mapOp = map_operands[i];
Attribute mapTypeOp = map_types[i];
- assert(mapTypeOp.isa<mlir::IntegerAttr>());
- mapTypeBits = mapTypeOp.cast<mlir::IntegerAttr>().getInt();
+ assert(llvm::isa<mlir::IntegerAttr>(mapTypeOp));
+ mapTypeBits = llvm::cast<mlir::IntegerAttr>(mapTypeOp).getInt();
bool always = bitAnd(mapTypeBits,
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS);
@@ -1018,8 +1019,8 @@ LogicalResult ReductionDeclareOp::verifyRegions() {
atomicReductionEntryBlock.getArgumentTypes()[1])
return emitOpError() << "expects atomic reduction region with two "
"arguments of the same type";
- auto ptrType = atomicReductionEntryBlock.getArgumentTypes()[0]
- .dyn_cast<PointerLikeType>();
+ auto ptrType = llvm::dyn_cast<PointerLikeType>(
+ atomicReductionEntryBlock.getArgumentTypes()[0]);
if (!ptrType ||
(ptrType.getElementType() && ptrType.getElementType() != getType()))
return emitOpError() << "expects atomic reduction region arguments to "
@@ -1210,7 +1211,7 @@ LogicalResult AtomicWriteOp::verify() {
}
}
Type elementType =
- getAddress().getType().cast<PointerLikeType>().getElementType();
+ llvm::cast<PointerLikeType>(getAddress().getType()).getElementType();
if (elementType && elementType != getValue().getType())
return emitError("address must dereference to value type");
return verifySynchronizationHint(*this, getHintVal());
@@ -1261,7 +1262,8 @@ LogicalResult AtomicUpdateOp::verify() {
if (getRegion().getNumArguments() != 1)
return emitError("the region must accept exactly one argument");
- Type elementType = getX().getType().cast<PointerLikeType>().getElementType();
+ Type elementType =
+ llvm::cast<PointerLikeType>(getX().getType()).getElementType();
if (elementType && elementType != getRegion().getArgument(0).getType()) {
return emitError("the type of the operand must be a pointer type whose "
"element type is the same as that of the region argument");
diff --git a/mlir/lib/Dialect/PDL/IR/PDL.cpp b/mlir/lib/Dialect/PDL/IR/PDL.cpp
index a450386fef722..d5f34679f06c6 100644
--- a/mlir/lib/Dialect/PDL/IR/PDL.cpp
+++ b/mlir/lib/Dialect/PDL/IR/PDL.cpp
@@ -465,7 +465,7 @@ static void printResultsValueType(OpAsmPrinter &p, ResultsOp op,
}
LogicalResult ResultsOp::verify() {
- if (!getIndex() && getType().isa<pdl::ValueType>()) {
+ if (!getIndex() && llvm::isa<pdl::ValueType>(getType())) {
return emitOpError() << "expected `pdl.range<value>` result type when "
"no index is specified, but got: "
<< getType();
diff --git a/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp b/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp
index 49eee1afe0964..6f6c9d337066d 100644
--- a/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp
+++ b/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp
@@ -60,7 +60,7 @@ bool PDLType::classof(Type type) {
}
Type pdl::getRangeElementTypeOrSelf(Type type) {
- if (auto rangeType = type.dyn_cast<RangeType>())
+ if (auto rangeType = llvm::dyn_cast<RangeType>(type))
return rangeType.getElementType();
return type;
}
@@ -78,7 +78,7 @@ Type RangeType::parse(AsmParser &parser) {
if (!elementType || parser.parseGreater())
return Type();
- if (elementType.isa<RangeType>()) {
+ if (llvm::isa<RangeType>(elementType)) {
parser.emitError(elementLoc)
<< "element of pdl.range cannot be another range, but got"
<< elementType;
@@ -95,7 +95,7 @@ void RangeType::print(AsmPrinter &printer) const {
LogicalResult RangeType::verify(function_ref<InFlightDiagnostic()> emitError,
Type elementType) {
- if (!elementType.isa<PDLType>() || elementType.isa<RangeType>()) {
+ if (!llvm::isa<PDLType>(elementType) || llvm::isa<RangeType>(elementType)) {
return emitError()
<< "expected element of pdl.range to be one of [!pdl.attribute, "
"!pdl.operation, !pdl.type, !pdl.value], but got "
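The free functions used throughout this patch dispatch through a static `classof` hook, such as `PDLType::classof(Type)` in the hunk above. Below is a minimal, pointer-based sketch of that pattern. `mini::isa`, `mini::cast`, `mini::dyn_cast`, and the `Shape` hierarchy are hypothetical simplifications, not the templates in `llvm/Support/Casting.h`. MLIR's `Type`, `Attribute`, and `Value` are passed by value and plug into LLVM's casting machinery differently, and this sketch does not model that.

```cpp
#include <cassert>

namespace mini {
// isa: ask the target class whether the dynamic kind matches.
template <typename To, typename From> bool isa(const From *v) {
  return To::classof(v);
}
// cast: checked downcast; asserts if the kind does not match.
template <typename To, typename From> To *cast(From *v) {
  assert(isa<To>(v) && "cast<Ty>() argument of incompatible type!");
  return static_cast<To *>(v);
}
// dyn_cast: downcast that returns nullptr on a kind mismatch.
template <typename To, typename From> To *dyn_cast(From *v) {
  return isa<To>(v) ? static_cast<To *>(v) : nullptr;
}
} // namespace mini

struct Shape {
  enum Kind { K_Circle, K_Square };
  explicit Shape(Kind k) : kind(k) {}
  Kind getKind() const { return kind; }

private:
  Kind kind;
};

struct Circle : Shape {
  Circle() : Shape(K_Circle) {}
  static bool classof(const Shape *s) { return s->getKind() == K_Circle; }
};

struct Square : Shape {
  Square() : Shape(K_Square) {}
  static bool classof(const Shape *s) { return s->getKind() == K_Square; }
};
```

With `Shape *p = &circle;`, `mini::isa<Circle>(p)` is true, `mini::dyn_cast<Square>(p)` returns `nullptr`, and `mini::cast<Circle>(p)` returns the original `Circle` address.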
diff --git a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
index 2cc282d25a262..e242e9ad33993 100644
--- a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
+++ b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
@@ -145,7 +145,7 @@ void ForEachOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
if (initLoop) {
// Create the block and the loop variable.
// FIXME: Allow passing in a proper location for the loop variable.
- auto rangeType = range.getType().cast<pdl::RangeType>();
+ auto rangeType = llvm::cast<pdl::RangeType>(range.getType());
state.regions.front()->emplaceBlock();
state.regions.front()->addArgument(rangeType.getElementType(),
state.location);
@@ -238,7 +238,8 @@ void FuncOp::print(OpAsmPrinter &p) {
/// Given the result type of a `GetValueTypeOp`, return the expected input type.
static Type getGetValueTypeOpValueType(Type type) {
Type valueTy = pdl::ValueType::get(type.getContext());
- return type.isa<pdl::RangeType>() ? pdl::RangeType::get(valueTy) : valueTy;
+ return llvm::isa<pdl::RangeType>(type) ? pdl::RangeType::get(valueTy)
+ : valueTy;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
index 43f538b62a236..81e3b914755be 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
@@ -35,7 +35,7 @@ QuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
// Verify that the storage type is integral.
// This restriction may be lifted at some point in favor of using bf16
// or f16 as exact representations on hardware where that is advantageous.
- auto intStorageType = storageType.dyn_cast<IntegerType>();
+ auto intStorageType = llvm::dyn_cast<IntegerType>(storageType);
if (!intStorageType)
return emitError() << "storage type must be integral";
unsigned integralWidth = intStorageType.getWidth();
@@ -83,8 +83,8 @@ Type QuantizedType::getExpressedType() const {
}
bool QuantizedType::isCompatibleExpressedType(Type candidateExpressedType) {
- if (candidateExpressedType.isa<ShapedType>()) {
- return candidateExpressedType.cast<ShapedType>().getElementType() ==
+ if (llvm::isa<ShapedType>(candidateExpressedType)) {
+ return llvm::cast<ShapedType>(candidateExpressedType).getElementType() ==
getExpressedType();
}
return candidateExpressedType == getExpressedType();
@@ -92,12 +92,12 @@ bool QuantizedType::isCompatibleExpressedType(Type candidateExpressedType) {
QuantizedType
QuantizedType::getQuantizedElementType(Type primitiveOrContainerType) {
- if (primitiveOrContainerType.isa<ShapedType>()) {
+ if (llvm::isa<ShapedType>(primitiveOrContainerType)) {
Type elementType =
- primitiveOrContainerType.cast<ShapedType>().getElementType();
- return elementType.dyn_cast<QuantizedType>();
+ llvm::cast<ShapedType>(primitiveOrContainerType).getElementType();
+ return llvm::dyn_cast<QuantizedType>(elementType);
}
- return primitiveOrContainerType.dyn_cast<QuantizedType>();
+ return llvm::dyn_cast<QuantizedType>(primitiveOrContainerType);
}
Type QuantizedType::castFromStorageType(Type candidateType) {
@@ -105,18 +105,19 @@ Type QuantizedType::castFromStorageType(Type candidateType) {
// i.e. i32 -> quant<"uniform[i8:f32]{1.0}">
return *this;
}
- if (candidateType.isa<RankedTensorType>()) {
+ if (llvm::isa<RankedTensorType>(candidateType)) {
// i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
return RankedTensorType::get(
- candidateType.cast<RankedTensorType>().getShape(), getStorageType());
+ llvm::cast<RankedTensorType>(candidateType).getShape(),
+ getStorageType());
}
- if (candidateType.isa<UnrankedTensorType>()) {
+ if (llvm::isa<UnrankedTensorType>(candidateType)) {
// i.e. tensor<i8> -> tensor<!quant<"uniform[i8:f32]{1.0}">>
return UnrankedTensorType::get(getStorageType());
}
- if (candidateType.isa<VectorType>()) {
+ if (llvm::isa<VectorType>(candidateType)) {
// i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
- return VectorType::get(candidateType.cast<VectorType>().getShape(),
+ return VectorType::get(llvm::cast<VectorType>(candidateType).getShape(),
getStorageType());
}
@@ -124,25 +125,25 @@ Type QuantizedType::castFromStorageType(Type candidateType) {
}
Type QuantizedType::castToStorageType(Type quantizedType) {
- if (quantizedType.isa<QuantizedType>()) {
+ if (llvm::isa<QuantizedType>(quantizedType)) {
// i.e. quant<"uniform[i8:f32]{1.0}"> -> i8
- return quantizedType.cast<QuantizedType>().getStorageType();
+ return llvm::cast<QuantizedType>(quantizedType).getStorageType();
}
- if (quantizedType.isa<ShapedType>()) {
+ if (llvm::isa<ShapedType>(quantizedType)) {
// i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
- ShapedType sType = quantizedType.cast<ShapedType>();
- if (!sType.getElementType().isa<QuantizedType>()) {
+ ShapedType sType = llvm::cast<ShapedType>(quantizedType);
+ if (!llvm::isa<QuantizedType>(sType.getElementType())) {
return nullptr;
}
Type storageType =
- sType.getElementType().cast<QuantizedType>().getStorageType();
- if (quantizedType.isa<RankedTensorType>()) {
+ llvm::cast<QuantizedType>(sType.getElementType()).getStorageType();
+ if (llvm::isa<RankedTensorType>(quantizedType)) {
return RankedTensorType::get(sType.getShape(), storageType);
}
- if (quantizedType.isa<UnrankedTensorType>()) {
+ if (llvm::isa<UnrankedTensorType>(quantizedType)) {
return UnrankedTensorType::get(storageType);
}
- if (quantizedType.isa<VectorType>()) {
+ if (llvm::isa<VectorType>(quantizedType)) {
return VectorType::get(sType.getShape(), storageType);
}
}
@@ -155,21 +156,21 @@ Type QuantizedType::castFromExpressedType(Type candidateType) {
// i.e. f32 -> quant<"uniform[i8:f32]{1.0}">
return *this;
}
- if (candidateType.isa<ShapedType>()) {
- ShapedType candidateShapedType = candidateType.cast<ShapedType>();
+ if (llvm::isa<ShapedType>(candidateType)) {
+ ShapedType candidateShapedType = llvm::cast<ShapedType>(candidateType);
if (candidateShapedType.getElementType() != getExpressedType()) {
return nullptr;
}
- if (candidateType.isa<RankedTensorType>()) {
+ if (llvm::isa<RankedTensorType>(candidateType)) {
// i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
return RankedTensorType::get(candidateShapedType.getShape(), *this);
}
- if (candidateType.isa<UnrankedTensorType>()) {
+ if (llvm::isa<UnrankedTensorType>(candidateType)) {
// i.e. tensor<xf32> -> tensor<x!quant<"uniform[i8:f32]{1.0}">>
return UnrankedTensorType::get(*this);
}
- if (candidateType.isa<VectorType>()) {
+ if (llvm::isa<VectorType>(candidateType)) {
// i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
return VectorType::get(candidateShapedType.getShape(), *this);
}
@@ -179,25 +180,25 @@ Type QuantizedType::castFromExpressedType(Type candidateType) {
}
Type QuantizedType::castToExpressedType(Type quantizedType) {
- if (quantizedType.isa<QuantizedType>()) {
+ if (llvm::isa<QuantizedType>(quantizedType)) {
// i.e. quant<"uniform[i8:f32]{1.0}"> -> f32
- return quantizedType.cast<QuantizedType>().getExpressedType();
+ return llvm::cast<QuantizedType>(quantizedType).getExpressedType();
}
- if (quantizedType.isa<ShapedType>()) {
+ if (llvm::isa<ShapedType>(quantizedType)) {
// i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
- ShapedType sType = quantizedType.cast<ShapedType>();
- if (!sType.getElementType().isa<QuantizedType>()) {
+ ShapedType sType = llvm::cast<ShapedType>(quantizedType);
+ if (!llvm::isa<QuantizedType>(sType.getElementType())) {
return nullptr;
}
Type expressedType =
- sType.getElementType().cast<QuantizedType>().getExpressedType();
- if (quantizedType.isa<RankedTensorType>()) {
+ llvm::cast<QuantizedType>(sType.getElementType()).getExpressedType();
+ if (llvm::isa<RankedTensorType>(quantizedType)) {
return RankedTensorType::get(sType.getShape(), expressedType);
}
- if (quantizedType.isa<UnrankedTensorType>()) {
+ if (llvm::isa<UnrankedTensorType>(quantizedType)) {
return UnrankedTensorType::get(expressedType);
}
- if (quantizedType.isa<VectorType>()) {
+ if (llvm::isa<VectorType>(quantizedType)) {
return VectorType::get(sType.getShape(), expressedType);
}
}
@@ -243,7 +244,7 @@ AnyQuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
// Verify that the expressed type is floating point.
// If this restriction is ever eliminated, the parser/printer must be
// extended.
- if (expressedType && !expressedType.isa<FloatType>())
+ if (expressedType && !llvm::isa<FloatType>(expressedType))
return emitError() << "expressed type must be floating point";
return success();
@@ -284,7 +285,7 @@ LogicalResult UniformQuantizedType::verify(
// Verify that the expressed type is floating point.
// If this restriction is ever eliminated, the parser/printer must be
// extended.
- if (!expressedType.isa<FloatType>())
+ if (!llvm::isa<FloatType>(expressedType))
return emitError() << "expressed type must be floating point";
// Verify scale.
@@ -338,7 +339,7 @@ LogicalResult UniformQuantizedPerAxisType::verify(
// Verify that the expressed type is floating point.
// If this restriction is ever eliminated, the parser/printer must be
// extended.
- if (!expressedType.isa<FloatType>())
+ if (!llvm::isa<FloatType>(expressedType))
return emitError() << "expressed type must be floating point";
// Ensure that the number of scales and zeroPoints match.
@@ -385,7 +386,7 @@ CalibratedQuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
// Verify that the expressed type is floating point.
// If this restriction is ever eliminated, the parser/printer must be
// extended.
- if (!expressedType.isa<FloatType>())
+ if (!llvm::isa<FloatType>(expressedType))
return emitError() << "expressed type must be floating point";
if (max <= min)
return emitError() << "illegal min and max: (" << min << ":" << max << ")";
diff --git a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
index 05a8ebd5095de..926a8a0aa13d5 100644
--- a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
+++ b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
@@ -420,13 +420,13 @@ static void printCalibratedQuantizedType(CalibratedQuantizedType type,
/// Print a type registered to this dialect.
void QuantizationDialect::printType(Type type, DialectAsmPrinter &os) const {
- if (auto anyType = type.dyn_cast<AnyQuantizedType>())
+ if (auto anyType = llvm::dyn_cast<AnyQuantizedType>(type))
printAnyQuantizedType(anyType, os);
- else if (auto uniformType = type.dyn_cast<UniformQuantizedType>())
+ else if (auto uniformType = llvm::dyn_cast<UniformQuantizedType>(type))
printUniformQuantizedType(uniformType, os);
- else if (auto perAxisType = type.dyn_cast<UniformQuantizedPerAxisType>())
+ else if (auto perAxisType = llvm::dyn_cast<UniformQuantizedPerAxisType>(type))
printUniformQuantizedPerAxisType(perAxisType, os);
- else if (auto calibratedType = type.dyn_cast<CalibratedQuantizedType>())
+ else if (auto calibratedType = llvm::dyn_cast<CalibratedQuantizedType>(type))
printCalibratedQuantizedType(calibratedType, os);
else
llvm_unreachable("Unhandled quantized type");
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index cbd98f39b4068..a88b9136974a8 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -495,7 +495,7 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
Region &ForOp::getLoopBody() { return getRegion(); }
ForOp mlir::scf::getForInductionVarOwner(Value val) {
- auto ivArg = val.dyn_cast<BlockArgument>();
+ auto ivArg = llvm::dyn_cast<BlockArgument>(val);
if (!ivArg)
return ForOp();
assert(ivArg.getOwner() && "unlinked block argument");
@@ -576,7 +576,7 @@ void mlir::scf::promote(PatternRewriter &rewriter, scf::ForallOp forallOp) {
};
Value srcVal = mapping.lookupOrDefault(src);
- if (srcVal.getType().isa<TensorType>()) {
+ if (llvm::isa<TensorType>(srcVal.getType())) {
results.push_back(rewriter.create<tensor::InsertSliceOp>(
forallOp.getLoc(), dst.getType(), srcVal,
mapping.lookupOrDefault(dst),
@@ -890,7 +890,8 @@ static SmallVector<Value>
replaceTensorCastForOpIterArg(PatternRewriter &rewriter, OpOperand &operand,
Value replacement) {
Type oldType = operand.get().getType(), newType = replacement.getType();
- assert(oldType.isa<RankedTensorType>() && newType.isa<RankedTensorType>() &&
+ assert(llvm::isa<RankedTensorType>(oldType) &&
+ llvm::isa<RankedTensorType>(newType) &&
"expected ranked tensor types");
// 1. Create new iter operands, exactly 1 is replaced.
@@ -1074,7 +1075,7 @@ struct LastTensorLoadCanonicalization : public OpRewritePattern<ForOp> {
cast<scf::YieldOp>(forOp.getRegion().front().getTerminator());
Value yieldVal = yieldOp->getOperand(idx);
auto tensorLoadOp = yieldVal.getDefiningOp<bufferization::ToTensorOp>();
- bool isTensor = bbArg.getType().isa<TensorType>();
+ bool isTensor = llvm::isa<TensorType>(bbArg.getType());
bufferization::ToMemrefOp tensorToMemref;
// Either bbArg has no use or it has a single buffer_cast use.
@@ -1445,7 +1446,7 @@ InParallelOp ForallOp::getTerminator() {
}
ForallOp mlir::scf::getForallOpThreadIndexOwner(Value val) {
- auto tidxArg = val.dyn_cast<BlockArgument>();
+ auto tidxArg = llvm::dyn_cast<BlockArgument>(val);
if (!tidxArg)
return ForallOp();
assert(tidxArg.getOwner() && "unlinked block argument");
@@ -1464,7 +1465,8 @@ struct DimOfForallOp : public OpRewritePattern<tensor::DimOp> {
if (!forallOp)
return failure();
Value sharedOut =
- forallOp.getTiedOpOperand(dimOp.getSource().cast<OpResult>())->get();
+ forallOp.getTiedOpOperand(llvm::cast<OpResult>(dimOp.getSource()))
+ ->get();
rewriter.updateRootInPlace(
dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });
return success();
@@ -1744,7 +1746,7 @@ SmallVector<BlockArgument> InParallelOp::getDests() {
llvm::map_range(getYieldingOps(), [](Operation &op) {
// Add new ops here as needed.
auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(&op);
- return insertSliceOp.getDest().cast<BlockArgument>();
+ return llvm::cast<BlockArgument>(insertSliceOp.getDest());
}));
}
@@ -1964,7 +1966,7 @@ void IfOp::getSuccessorRegions(std::optional<unsigned> index,
// Otherwise, the successor is dependent on the condition.
bool condition;
- if (auto condAttr = operands.front().dyn_cast_or_null<IntegerAttr>()) {
+ if (auto condAttr = llvm::dyn_cast_or_null<IntegerAttr>(operands.front())) {
condition = condAttr.getValue().isOne();
} else {
// If the condition isn't constant, both regions may be executed.
@@ -2006,7 +2008,7 @@ LogicalResult IfOp::fold(FoldAdaptor adaptor,
void IfOp::getRegionInvocationBounds(
ArrayRef<Attribute> operands,
SmallVectorImpl<InvocationBounds> &invocationBounds) {
- if (auto cond = operands[0].dyn_cast_or_null<BoolAttr>()) {
+ if (auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
// If the condition is known, then one region is known to be executed once
// and the other zero times.
invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
@@ -2542,7 +2544,7 @@ struct CombineNestedIfs : public OpRewritePattern<IfOp> {
// come from the same scf.if.
for (const auto &tup : llvm::enumerate(thenYield)) {
if (tup.value().getDefiningOp() == nestedIf) {
- auto nestedIdx = tup.value().cast<OpResult>().getResultNumber();
+ auto nestedIdx = llvm::cast<OpResult>(tup.value()).getResultNumber();
if (nestedIf.elseYield().getOperand(nestedIdx) !=
elseYield[tup.index()]) {
return failure();
@@ -2818,7 +2820,7 @@ void ParallelOp::print(OpAsmPrinter &p) {
Region &ParallelOp::getLoopBody() { return getRegion(); }
ParallelOp mlir::scf::getParallelForInductionVarOwner(Value val) {
- auto ivArg = val.dyn_cast<BlockArgument>();
+ auto ivArg = llvm::dyn_cast<BlockArgument>(val);
if (!ivArg)
return ParallelOp();
assert(ivArg.getOwner() && "unlinked block argument");
@@ -3130,7 +3132,7 @@ void WhileOp::getSuccessorRegions(std::optional<unsigned> index,
// Try to narrow the successor to the condition region.
assert(!operands.empty() && "expected at least one operand");
- auto cond = operands[0].dyn_cast_or_null<BoolAttr>();
+ auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0]);
if (!cond || !cond.getValue())
regions.emplace_back(getResults());
if (!cond || cond.getValue())
@@ -3360,7 +3362,7 @@ struct RemoveLoopInvariantArgsFromBeforeBlock
// block argument or the initial value of i-th before block argument. If
// the comparison results `true`, i-th before block argument is a loop
// invariant.
- auto yieldOpBlockArg = yieldOpArg.dyn_cast<BlockArgument>();
+ auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
@@ -3392,7 +3394,7 @@ struct RemoveLoopInvariantArgsFromBeforeBlock
// before block argument or the initial value of i-th before block
// argument. If the comparison results `true`, i-th before block
// argument is a loop invariant.
- auto yieldOpBlockArg = yieldOpArg.dyn_cast<BlockArgument>();
+ auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
@@ -3960,7 +3962,7 @@ void IndexSwitchOp::getSuccessorRegions(
}
// If a constant was not provided, all regions are possible successors.
- auto operandValue = operands.front().dyn_cast_or_null<IntegerAttr>();
+ auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
if (!operandValue) {
for (Region &caseRegion : getCaseRegions())
successors.emplace_back(&caseRegion);
@@ -3981,7 +3983,7 @@ void IndexSwitchOp::getSuccessorRegions(
void IndexSwitchOp::getRegionInvocationBounds(
ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
- auto operandValue = operands.front().dyn_cast_or_null<IntegerAttr>();
+ auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
if (!operandValue) {
// All regions are invoked at most once.
bounds.append(getNumRegions(), InvocationBounds(/*lb=*/0, /*ub=*/1));
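Hooks such as `IfOp::getSuccessorRegions` and `IndexSwitchOp::getRegionInvocationBounds` receive one `Attribute` per operand. That attribute is null when the operand is not a known constant, so these hooks use `dyn_cast_or_null`: in LLVM, plain `dyn_cast` asserts on a null argument. The sketch below shows that distinction in simplified form. `Attr`, `IntAttr`, and the `selectedRegion` helper are hypothetical, and the helper only loosely mirrors the `scf.if` logic.

```cpp
#include <cassert>

namespace sketch {
struct Attr {
  enum Kind { K_Int, K_Str };
  explicit Attr(Kind k) : kind(k) {}
  Kind kind;
};

struct IntAttr : Attr {
  explicit IntAttr(long v) : Attr(K_Int), value(v) {}
  static bool classof(const Attr *a) { return a->kind == K_Int; }
  long value;
};

// Plain dyn_cast requires a non-null argument.
template <typename To, typename From> To *dyn_cast(From *v) {
  assert(v && "dyn_cast on a null value");
  return To::classof(v) ? static_cast<To *>(v) : nullptr;
}

// The _or_null variant maps a null input to a null result.
template <typename To, typename From> To *dyn_cast_or_null(From *v) {
  return v ? dyn_cast<To>(v) : nullptr;
}

// Returns 0 ("then") or 1 ("else") when the condition is a known integer
// constant, and -1 when it is unknown (null) or not an integer, meaning
// both regions remain possible.
int selectedRegion(Attr *cond) {
  if (IntAttr *c = dyn_cast_or_null<IntAttr>(cond))
    return c->value == 1 ? 0 : 1;
  return -1;
}
} // namespace sketch
```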
diff --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
index cc6541e02f9ca..47f6da25b3251 100644
--- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -27,10 +27,10 @@ struct ForOpInterface
ValueBoundsConstraintSet &cstr) {
// `value` is an iter_arg or an OpResult.
int64_t iterArgIdx;
- if (auto iterArg = value.dyn_cast<BlockArgument>()) {
+ if (auto iterArg = llvm::dyn_cast<BlockArgument>(value)) {
iterArgIdx = iterArg.getArgNumber() - forOp.getNumInductionVars();
} else {
- iterArgIdx = value.cast<OpResult>().getResultNumber();
+ iterArgIdx = llvm::cast<OpResult>(value).getResultNumber();
}
// An EQ constraint can be added if the yielded value (dimension size)
@@ -63,7 +63,7 @@ struct ForOpInterface
bound, boundOperands, BoundType::EQ, yieldedValue, dim,
[&](Value v, std::optional<int64_t> d) {
// Stop when reaching a block argument of the loop body.
- if (auto bbArg = v.dyn_cast<BlockArgument>())
+ if (auto bbArg = llvm::dyn_cast<BlockArgument>(v))
return bbArg.getOwner()->getParentOp() == forOp;
// Stop when reaching a value that is defined outside of the loop. It
// is impossible to reach an iter_arg from there.
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp
index c8179cc5e7b34..8a0ee7a3d8136 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp
@@ -145,18 +145,20 @@ StringRef spirv::InterfaceVarABIAttr::getKindName() {
}
uint32_t spirv::InterfaceVarABIAttr::getBinding() {
- return getImpl()->binding.cast<IntegerAttr>().getInt();
+ return llvm::cast<IntegerAttr>(getImpl()->binding).getInt();
}
uint32_t spirv::InterfaceVarABIAttr::getDescriptorSet() {
- return getImpl()->descriptorSet.cast<IntegerAttr>().getInt();
+ return llvm::cast<IntegerAttr>(getImpl()->descriptorSet).getInt();
}
std::optional<spirv::StorageClass>
spirv::InterfaceVarABIAttr::getStorageClass() {
if (getImpl()->storageClass)
return static_cast<spirv::StorageClass>(
- getImpl()->storageClass.cast<IntegerAttr>().getValue().getZExtValue());
+ llvm::cast<IntegerAttr>(getImpl()->storageClass)
+ .getValue()
+ .getZExtValue());
return std::nullopt;
}
@@ -170,7 +172,7 @@ LogicalResult spirv::InterfaceVarABIAttr::verify(
return emitError() << "expected 32-bit integer for binding";
if (storageClass) {
- if (auto storageClassAttr = storageClass.cast<IntegerAttr>()) {
+ if (auto storageClassAttr = llvm::cast<IntegerAttr>(storageClass)) {
auto storageClassValue =
spirv::symbolizeStorageClass(storageClassAttr.getInt());
if (!storageClassValue)
@@ -219,14 +221,14 @@ StringRef spirv::VerCapExtAttr::getKindName() { return "vce"; }
spirv::Version spirv::VerCapExtAttr::getVersion() {
return static_cast<spirv::Version>(
- getImpl()->version.cast<IntegerAttr>().getValue().getZExtValue());
+ llvm::cast<IntegerAttr>(getImpl()->version).getValue().getZExtValue());
}
spirv::VerCapExtAttr::ext_iterator::ext_iterator(ArrayAttr::iterator it)
: llvm::mapped_iterator<ArrayAttr::iterator,
spirv::Extension (*)(Attribute)>(
it, [](Attribute attr) {
- return *symbolizeExtension(attr.cast<StringAttr>().getValue());
+ return *symbolizeExtension(llvm::cast<StringAttr>(attr).getValue());
}) {}
spirv::VerCapExtAttr::ext_range spirv::VerCapExtAttr::getExtensions() {
@@ -235,7 +237,7 @@ spirv::VerCapExtAttr::ext_range spirv::VerCapExtAttr::getExtensions() {
}
ArrayAttr spirv::VerCapExtAttr::getExtensionsAttr() {
- return getImpl()->extensions.cast<ArrayAttr>();
+ return llvm::cast<ArrayAttr>(getImpl()->extensions);
}
spirv::VerCapExtAttr::cap_iterator::cap_iterator(ArrayAttr::iterator it)
@@ -243,7 +245,7 @@ spirv::VerCapExtAttr::cap_iterator::cap_iterator(ArrayAttr::iterator it)
spirv::Capability (*)(Attribute)>(
it, [](Attribute attr) {
return *symbolizeCapability(
- attr.cast<IntegerAttr>().getValue().getZExtValue());
+ llvm::cast<IntegerAttr>(attr).getValue().getZExtValue());
}) {}
spirv::VerCapExtAttr::cap_range spirv::VerCapExtAttr::getCapabilities() {
@@ -252,7 +254,7 @@ spirv::VerCapExtAttr::cap_range spirv::VerCapExtAttr::getCapabilities() {
}
ArrayAttr spirv::VerCapExtAttr::getCapabilitiesAttr() {
- return getImpl()->capabilities.cast<ArrayAttr>();
+ return llvm::cast<ArrayAttr>(getImpl()->capabilities);
}
LogicalResult
@@ -263,7 +265,7 @@ spirv::VerCapExtAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return emitError() << "expected 32-bit integer for version";
if (!llvm::all_of(capabilities.getValue(), [](Attribute attr) {
- if (auto intAttr = attr.dyn_cast<IntegerAttr>())
+ if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
if (spirv::symbolizeCapability(intAttr.getValue().getZExtValue()))
return true;
return false;
@@ -271,7 +273,7 @@ spirv::VerCapExtAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return emitError() << "unknown capability in capability list";
if (!llvm::all_of(extensions.getValue(), [](Attribute attr) {
- if (auto strAttr = attr.dyn_cast<StringAttr>())
+ if (auto strAttr = llvm::dyn_cast<StringAttr>(attr))
if (spirv::symbolizeExtension(strAttr.getValue()))
return true;
return false;
@@ -297,7 +299,7 @@ spirv::TargetEnvAttr spirv::TargetEnvAttr::get(
StringRef spirv::TargetEnvAttr::getKindName() { return "target_env"; }
spirv::VerCapExtAttr spirv::TargetEnvAttr::getTripleAttr() const {
- return getImpl()->triple.cast<spirv::VerCapExtAttr>();
+ return llvm::cast<spirv::VerCapExtAttr>(getImpl()->triple);
}
spirv::Version spirv::TargetEnvAttr::getVersion() const {
@@ -337,7 +339,7 @@ uint32_t spirv::TargetEnvAttr::getDeviceID() const {
}
spirv::ResourceLimitsAttr spirv::TargetEnvAttr::getResourceLimits() const {
- return getImpl()->limits.cast<spirv::ResourceLimitsAttr>();
+ return llvm::cast<spirv::ResourceLimitsAttr>(getImpl()->limits);
}
//===----------------------------------------------------------------------===//
@@ -628,7 +630,7 @@ static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer) {
[&](spirv::Capability cap) { os << spirv::stringifyCapability(cap); });
printer << "], [";
llvm::interleaveComma(triple.getExtensionsAttr(), os, [&](Attribute attr) {
- os << attr.cast<StringAttr>().getValue();
+ os << llvm::cast<StringAttr>(attr).getValue();
});
printer << "]>";
}
@@ -669,11 +671,11 @@ void SPIRVDialect::printAttribute(Attribute attr,
if (succeeded(generatedAttributePrinter(attr, printer)))
return;
- if (auto targetEnv = attr.dyn_cast<TargetEnvAttr>())
+ if (auto targetEnv = llvm::dyn_cast<TargetEnvAttr>(attr))
print(targetEnv, printer);
- else if (auto vceAttr = attr.dyn_cast<VerCapExtAttr>())
+ else if (auto vceAttr = llvm::dyn_cast<VerCapExtAttr>(attr))
print(vceAttr, printer);
- else if (auto interfaceVarABIAttr = attr.dyn_cast<InterfaceVarABIAttr>())
+ else if (auto interfaceVarABIAttr = llvm::dyn_cast<InterfaceVarABIAttr>(attr))
print(interfaceVarABIAttr, printer);
else
llvm_unreachable("unhandled SPIR-V attribute kind");
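Every hunk above follows the same pattern: a member call such as `attr.cast<T>()` becomes the free function `llvm::cast<T>(attr)`. The sketch below shows the LLVM-style RTTI those free functions rely on. Each class provides a static `classof` predicate, and `isa`, `cast` and `dyn_cast` consult it. This is a simplified, pointer-only illustration and not LLVM's implementation. The `mini` namespace, the `Attr`/`IntegerAttr`/`StringAttr` classes and `intOrZero` are all hypothetical. The real `llvm::cast` family is more general and also handles MLIR's value-like `Type` and `Attribute` handles.

```cpp
#include <cassert>
#include <string>
#include <utility>

namespace mini {

// Hypothetical attribute hierarchy using LLVM-style RTTI: a kind tag in the
// base class plus a static classof() predicate on each subclass.
class Attr {
public:
  enum class Kind { Integer, String };
  virtual ~Attr() = default;
  Kind getKind() const { return kind; }

protected:
  explicit Attr(Kind k) : kind(k) {}

private:
  Kind kind;
};

class IntegerAttr : public Attr {
public:
  explicit IntegerAttr(long v) : Attr(Kind::Integer), value(v) {}
  static bool classof(const Attr *a) { return a->getKind() == Kind::Integer; }
  long getInt() const { return value; }

private:
  long value;
};

class StringAttr : public Attr {
public:
  explicit StringAttr(std::string s) : Attr(Kind::String), str(std::move(s)) {}
  static bool classof(const Attr *a) { return a->getKind() == Kind::String; }
  const std::string &getValue() const { return str; }

private:
  std::string str;
};

// Simplified, pointer-only stand-ins for llvm::isa / llvm::cast /
// llvm::dyn_cast (hypothetical; the real templates are more general).
template <typename To, typename From> bool isa(const From *v) {
  return To::classof(v);
}

// cast<>: the caller guarantees the kind; a mismatch trips the assertion.
template <typename To, typename From> To *cast(From *v) {
  assert(isa<To>(v) && "cast<To>() argument of incompatible type");
  return static_cast<To *>(v);
}

// dyn_cast<>: checked conversion that yields nullptr on a kind mismatch.
template <typename To, typename From> To *dyn_cast(From *v) {
  return isa<To>(v) ? static_cast<To *>(v) : nullptr;
}

} // namespace mini

// Mirrors the shape of the patch: a free-function call with an explicit
// namespace qualifier, so it cannot resolve to a member named `cast`.
long intOrZero(mini::Attr *attr) {
  if (auto *intAttr = mini::dyn_cast<mini::IntegerAttr>(attr))
    return intAttr->getInt();
  return 0;
}
```

The explicit `mini::` qualifier plays the same role as `llvm::` in the patch. The commit message notes that an unqualified `cast` could resolve to a variable or a class method of the same name, and the qualifier rules that out.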
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 37b9051bfee9b..3aa1c3f1c2c54 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -33,9 +33,9 @@ static std::optional<bool> getScalarOrSplatBoolAttr(Attribute attr) {
if (!attr)
return std::nullopt;
- if (auto boolAttr = attr.dyn_cast<BoolAttr>())
+ if (auto boolAttr = llvm::dyn_cast<BoolAttr>(attr))
return boolAttr.getValue();
- if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>())
+ if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(attr))
if (splatAttr.getElementType().isInteger(1))
return splatAttr.getSplatValue<bool>();
return std::nullopt;
@@ -52,12 +52,12 @@ static Attribute extractCompositeElement(Attribute composite,
if (indices.empty())
return composite;
- if (auto vector = composite.dyn_cast<ElementsAttr>()) {
+ if (auto vector = llvm::dyn_cast<ElementsAttr>(composite)) {
assert(indices.size() == 1 && "must have exactly one index for a vector");
return vector.getValues<Attribute>()[indices[0]];
}
- if (auto array = composite.dyn_cast<ArrayAttr>()) {
+ if (auto array = llvm::dyn_cast<ArrayAttr>(composite)) {
assert(!indices.empty() && "must have at least one index for an array");
return extractCompositeElement(array.getValue()[indices[0]],
indices.drop_front());
@@ -149,7 +149,7 @@ OpFoldResult spirv::CompositeExtractOp::fold(FoldAdaptor adaptor) {
if (auto constructOp =
getComposite().getDefiningOp<spirv::CompositeConstructOp>()) {
- auto type = constructOp.getType().cast<spirv::CompositeType>();
+ auto type = llvm::cast<spirv::CompositeType>(constructOp.getType());
if (getIndices().size() == 1 &&
constructOp.getConstituents().size() == type.getNumElements()) {
auto i = getIndices().begin()->cast<IntegerAttr>();
@@ -159,7 +159,7 @@ OpFoldResult spirv::CompositeExtractOp::fold(FoldAdaptor adaptor) {
auto indexVector =
llvm::to_vector<8>(llvm::map_range(getIndices(), [](Attribute attr) {
- return static_cast<unsigned>(attr.cast<IntegerAttr>().getInt());
+ return static_cast<unsigned>(llvm::cast<IntegerAttr>(attr).getInt());
}));
return extractCompositeElement(adaptor.getComposite(), indexVector);
}
@@ -434,10 +434,9 @@ LogicalResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
// "Before version 1.4, Result Type must be a pointer, scalar, or vector.
// Starting with version 1.4, Result Type can additionally be a composite type
// other than a vector."
- bool isScalarOrVector = trueBrStoreOp.getValue()
- .getType()
- .cast<spirv::SPIRVType>()
- .isScalarOrVector();
+ bool isScalarOrVector =
+ llvm::cast<spirv::SPIRVType>(trueBrStoreOp.getValue().getType())
+ .isScalarOrVector();
// Check that each `spirv.Store` uses the same pointer, memory access
// attributes and a valid type of the value.
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 220e08f309b32..37e5d77caa49c 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -164,19 +164,19 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
return type;
// Check other allowed types
- if (auto t = type.dyn_cast<FloatType>()) {
+ if (auto t = llvm::dyn_cast<FloatType>(type)) {
if (type.isBF16()) {
parser.emitError(typeLoc, "cannot use 'bf16' to compose SPIR-V types");
return Type();
}
- } else if (auto t = type.dyn_cast<IntegerType>()) {
+ } else if (auto t = llvm::dyn_cast<IntegerType>(type)) {
if (!ScalarType::isValid(t)) {
parser.emitError(typeLoc,
"only 1/8/16/32/64-bit integer type allowed but found ")
<< type;
return Type();
}
- } else if (auto t = type.dyn_cast<VectorType>()) {
+ } else if (auto t = llvm::dyn_cast<VectorType>(type)) {
if (t.getRank() != 1) {
parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
return Type();
@@ -203,7 +203,7 @@ static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect,
if (parser.parseType(type))
return Type();
- if (auto t = type.dyn_cast<VectorType>()) {
+ if (auto t = llvm::dyn_cast<VectorType>(type)) {
if (t.getRank() != 1) {
parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
return Type();
@@ -216,7 +216,7 @@ static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect,
return Type();
}
- if (!t.getElementType().isa<FloatType>()) {
+ if (!llvm::isa<FloatType>(t.getElementType())) {
parser.emitError(typeLoc, "matrix columns' elements must be of "
"Float type, got ")
<< t.getElementType();
@@ -239,7 +239,7 @@ static Type parseAndVerifySampledImageType(SPIRVDialect const &dialect,
if (parser.parseType(type))
return Type();
- if (!type.isa<ImageType>()) {
+ if (!llvm::isa<ImageType>(type)) {
parser.emitError(typeLoc,
"sampled image must be composed using image type, got ")
<< type;
@@ -939,12 +939,12 @@ LogicalResult SPIRVDialect::verifyOperationAttribute(Operation *op,
Attribute attr = attribute.getValue();
if (symbol == spirv::getEntryPointABIAttrName()) {
- if (!attr.isa<spirv::EntryPointABIAttr>()) {
+ if (!llvm::isa<spirv::EntryPointABIAttr>(attr)) {
return op->emitError("'")
<< symbol << "' attribute must be an entry point ABI attribute";
}
} else if (symbol == spirv::getTargetEnvAttrName()) {
- if (!attr.isa<spirv::TargetEnvAttr>())
+ if (!llvm::isa<spirv::TargetEnvAttr>(attr))
return op->emitError("'") << symbol << "' must be a spirv::TargetEnvAttr";
} else {
return op->emitError("found unsupported '")
@@ -965,7 +965,7 @@ static LogicalResult verifyRegionAttribute(Location loc, Type valueType,
return emitError(loc, "found unsupported '")
<< symbol << "' attribute on region argument";
- auto varABIAttr = attr.dyn_cast<spirv::InterfaceVarABIAttr>();
+ auto varABIAttr = llvm::dyn_cast<spirv::InterfaceVarABIAttr>(attr);
if (!varABIAttr)
return emitError(loc, "'")
<< symbol << "' must be a spirv::InterfaceVarABIAttr";
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 2ad249773a6fe..a51ca20572387 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -81,7 +81,7 @@ static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColon() || parser.parseType(type))
return failure();
- auto fnType = type.dyn_cast<FunctionType>();
+ auto fnType = llvm::dyn_cast<FunctionType>(type);
if (!fnType) {
parser.emitError(loc, "expected function type");
return failure();
@@ -141,7 +141,7 @@ static LogicalResult extractValueFromConstOp(Operation *op, int32_t &value) {
return failure();
}
auto valueAttr = constOp.getValue();
- auto integerValueAttr = valueAttr.dyn_cast<IntegerAttr>();
+ auto integerValueAttr = llvm::dyn_cast<IntegerAttr>(valueAttr);
if (!integerValueAttr) {
return failure();
}
@@ -181,11 +181,11 @@ parseEnumStrAttr(EnumClass &value, OpAsmParser &parser,
if (parser.parseAttribute(attrVal, parser.getBuilder().getNoneType(),
attrName, attr))
return failure();
- if (!attrVal.isa<StringAttr>())
+ if (!llvm::isa<StringAttr>(attrVal))
return parser.emitError(loc, "expected ")
<< attrName << " attribute specified as string";
- auto attrOptional =
- spirv::symbolizeEnum<EnumClass>(attrVal.cast<StringAttr>().getValue());
+ auto attrOptional = spirv::symbolizeEnum<EnumClass>(
+ llvm::cast<StringAttr>(attrVal).getValue());
if (!attrOptional)
return parser.emitError(loc, "invalid ")
<< attrName << " attribute specification: " << attrVal;
@@ -430,23 +430,23 @@ static LogicalResult verifyCastOp(Operation *op,
Type resultType = op->getResult(0).getType();
// ODS checks that result type and operand type have the same shape.
- if (auto vectorType = operandType.dyn_cast<VectorType>()) {
+ if (auto vectorType = llvm::dyn_cast<VectorType>(operandType)) {
operandType = vectorType.getElementType();
- resultType = resultType.cast<VectorType>().getElementType();
+ resultType = llvm::cast<VectorType>(resultType).getElementType();
}
if (auto coopMatrixType =
- operandType.dyn_cast<spirv::CooperativeMatrixNVType>()) {
+ llvm::dyn_cast<spirv::CooperativeMatrixNVType>(operandType)) {
operandType = coopMatrixType.getElementType();
resultType =
- resultType.cast<spirv::CooperativeMatrixNVType>().getElementType();
+ llvm::cast<spirv::CooperativeMatrixNVType>(resultType).getElementType();
}
if (auto jointMatrixType =
- operandType.dyn_cast<spirv::JointMatrixINTELType>()) {
+ llvm::dyn_cast<spirv::JointMatrixINTELType>(operandType)) {
operandType = jointMatrixType.getElementType();
resultType =
- resultType.cast<spirv::JointMatrixINTELType>().getElementType();
+ llvm::cast<spirv::JointMatrixINTELType>(resultType).getElementType();
}
auto operandTypeBitWidth = operandType.getIntOrFloatBitWidth();
@@ -490,7 +490,7 @@ static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
return success();
}
- auto memAccess = memAccessAttr.template cast<spirv::MemoryAccessAttr>();
+ auto memAccess = llvm::cast<spirv::MemoryAccessAttr>(memAccessAttr);
if (!memAccess) {
return memoryOp.emitOpError("invalid memory access specifier: ")
@@ -534,7 +534,7 @@ static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
return success();
}
- auto memAccess = memAccessAttr.template cast<spirv::MemoryAccessAttr>();
+ auto memAccess = llvm::cast<spirv::MemoryAccessAttr>(memAccessAttr);
if (!memAccess) {
return memoryOp.emitOpError("invalid memory access specifier: ")
@@ -589,7 +589,7 @@ static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr,
// TODO: Check that the value type satisfies restrictions of
// SPIR-V OpLoad/OpStore operations
if (val.getType() !=
- ptr.getType().cast<spirv::PointerType>().getPointeeType()) {
+ llvm::cast<spirv::PointerType>(ptr.getType()).getPointeeType()) {
return op.emitOpError("mismatch in result type and pointer type");
}
return success();
@@ -599,10 +599,11 @@ template <typename BlockReadWriteOpTy>
static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op,
Value ptr, Value val) {
auto valType = val.getType();
- if (auto valVecTy = valType.dyn_cast<VectorType>())
+ if (auto valVecTy = llvm::dyn_cast<VectorType>(valType))
valType = valVecTy.getElementType();
- if (valType != ptr.getType().cast<spirv::PointerType>().getPointeeType()) {
+ if (valType !=
+ llvm::cast<spirv::PointerType>(ptr.getType()).getPointeeType()) {
return op.emitOpError("mismatch in result type and pointer type");
}
return success();
@@ -674,7 +675,7 @@ static void printVariableDecorations(Operation *op, OpAsmPrinter &printer,
// Get bit width of types.
static unsigned getBitWidth(Type type) {
- if (type.isa<spirv::PointerType>()) {
+ if (llvm::isa<spirv::PointerType>(type)) {
// Just return 64 bits for pointer types for now.
// TODO: Make sure not caller relies on the actual pointer width value.
return 64;
@@ -683,7 +684,7 @@ static unsigned getBitWidth(Type type) {
if (type.isIntOrFloat())
return type.getIntOrFloatBitWidth();
- if (auto vectorType = type.dyn_cast<VectorType>()) {
+ if (auto vectorType = llvm::dyn_cast<VectorType>(type)) {
assert(vectorType.getElementType().isIntOrFloat());
return vectorType.getNumElements() *
vectorType.getElementType().getIntOrFloatBitWidth();
@@ -703,7 +704,7 @@ getElementType(Type type, ArrayRef<int32_t> indices,
}
for (auto index : indices) {
- if (auto cType = type.dyn_cast<spirv::CompositeType>()) {
+ if (auto cType = llvm::dyn_cast<spirv::CompositeType>(type)) {
if (cType.hasCompileTimeKnownNumElements() &&
(index < 0 ||
static_cast<uint64_t>(index) >= cType.getNumElements())) {
@@ -723,7 +724,7 @@ getElementType(Type type, ArrayRef<int32_t> indices,
static Type
getElementType(Type type, Attribute indices,
function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
- auto indicesArrayAttr = indices.dyn_cast<ArrayAttr>();
+ auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(indices);
if (!indicesArrayAttr) {
emitErrorFn("expected a 32-bit integer array attribute for 'indices'");
return nullptr;
@@ -735,7 +736,7 @@ getElementType(Type type, Attribute indices,
SmallVector<int32_t, 2> indexVals;
for (auto indexAttr : indicesArrayAttr) {
- auto indexIntAttr = indexAttr.dyn_cast<IntegerAttr>();
+ auto indexIntAttr = llvm::dyn_cast<IntegerAttr>(indexAttr);
if (!indexIntAttr) {
emitErrorFn("expected an 32-bit integer for index, but found '")
<< indexAttr << "'";
@@ -769,7 +770,7 @@ static inline bool isMergeBlock(Block &block) {
template <typename ExtendedBinaryOp>
static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op) {
- auto resultType = op.getType().template cast<spirv::StructType>();
+ auto resultType = llvm::cast<spirv::StructType>(op.getType());
if (resultType.getNumElements() != 2)
return op.emitOpError("expected result struct type containing two members");
@@ -794,7 +795,7 @@ static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser,
if (parser.parseType(resultType))
return failure();
- auto structType = resultType.dyn_cast<spirv::StructType>();
+ auto structType = llvm::dyn_cast<spirv::StructType>(resultType);
if (!structType || structType.getNumElements() != 2)
return parser.emitError(loc, "expected spirv.struct type with two members");
@@ -836,7 +837,7 @@ static ParseResult parseAtomicUpdateOp(OpAsmParser &parser,
parser.getCurrentLocation(&loc) || parser.parseColonType(type))
return failure();
- auto ptrType = type.dyn_cast<spirv::PointerType>();
+ auto ptrType = llvm::dyn_cast<spirv::PointerType>(type);
if (!ptrType)
return parser.emitError(loc, "expected pointer type");
@@ -877,9 +878,9 @@ StringRef stringifyTypeName<FloatType>() {
// Verifies an atomic update op.
template <typename ExpectedElementType>
static LogicalResult verifyAtomicUpdateOp(Operation *op) {
- auto ptrType = op->getOperand(0).getType().cast<spirv::PointerType>();
+ auto ptrType = llvm::cast<spirv::PointerType>(op->getOperand(0).getType());
auto elementType = ptrType.getPointeeType();
- if (!elementType.isa<ExpectedElementType>())
+ if (!llvm::isa<ExpectedElementType>(elementType))
return op->emitOpError() << "pointer operand must point to an "
<< stringifyTypeName<ExpectedElementType>()
<< " value, found " << elementType;
@@ -990,7 +991,7 @@ static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) {
static Type getUnaryOpResultType(Type operandType) {
Builder builder(operandType.getContext());
Type resultType = builder.getIntegerType(1);
- if (auto vecType = operandType.dyn_cast<VectorType>())
+ if (auto vecType = llvm::dyn_cast<VectorType>(operandType))
return VectorType::get(vecType.getNumElements(), resultType);
return resultType;
}
@@ -1010,7 +1011,7 @@ static LogicalResult verifyShiftOp(Operation *op) {
//===----------------------------------------------------------------------===//
static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) {
- auto ptrType = type.dyn_cast<spirv::PointerType>();
+ auto ptrType = llvm::dyn_cast<spirv::PointerType>(type);
if (!ptrType) {
emitError(baseLoc, "'spirv.AccessChain' op expected a pointer "
"to composite type, but provided ")
@@ -1023,7 +1024,7 @@ static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) {
int32_t index = 0;
for (auto indexSSA : indices) {
- auto cType = resultType.dyn_cast<spirv::CompositeType>();
+ auto cType = llvm::dyn_cast<spirv::CompositeType>(resultType);
if (!cType) {
emitError(
baseLoc,
@@ -1032,7 +1033,7 @@ static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) {
return nullptr;
}
index = 0;
- if (resultType.isa<spirv::StructType>()) {
+ if (llvm::isa<spirv::StructType>(resultType)) {
Operation *op = indexSSA.getDefiningOp();
if (!op) {
emitError(baseLoc, "'spirv.AccessChain' op index must be an "
@@ -1134,7 +1135,7 @@ static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices) {
return failure();
auto providedResultType =
- accessChainOp.getType().template dyn_cast<spirv::PointerType>();
+ llvm::dyn_cast<spirv::PointerType>(accessChainOp.getType());
if (!providedResultType)
return accessChainOp.emitOpError(
"result type must be a pointer, but provided")
@@ -1201,7 +1202,7 @@ static ParseResult parseAtomicCompareExchangeImpl(OpAsmParser &parser,
if (parser.parseColonType(type))
return failure();
- auto ptrType = type.dyn_cast<spirv::PointerType>();
+ auto ptrType = llvm::dyn_cast<spirv::PointerType>(type);
if (!ptrType)
return parser.emitError(loc, "expected pointer type");
@@ -1231,10 +1232,9 @@ static LogicalResult verifyAtomicCompareExchangeImpl(T atomOp) {
"result, but found ")
<< atomOp.getComparator().getType() << " vs " << atomOp.getType();
- Type pointeeType = atomOp.getPointer()
- .getType()
- .template cast<spirv::PointerType>()
- .getPointeeType();
+ Type pointeeType =
+ llvm::cast<spirv::PointerType>(atomOp.getPointer().getType())
+ .getPointeeType();
if (atomOp.getType() != pointeeType)
return atomOp.emitOpError(
"pointer operand's pointee type must have the same "
@@ -1322,7 +1322,7 @@ ParseResult spirv::AtomicExchangeOp::parse(OpAsmParser &parser,
if (parser.parseColonType(type))
return failure();
- auto ptrType = type.dyn_cast<spirv::PointerType>();
+ auto ptrType = llvm::dyn_cast<spirv::PointerType>(type);
if (!ptrType)
return parser.emitError(loc, "expected pointer type");
@@ -1340,7 +1340,7 @@ LogicalResult spirv::AtomicExchangeOp::verify() {
<< getValue().getType() << " vs " << getType();
Type pointeeType =
- getPointer().getType().cast<spirv::PointerType>().getPointeeType();
+ llvm::cast<spirv::PointerType>(getPointer().getType()).getPointeeType();
if (getType() != pointeeType)
return emitOpError("pointer operand's pointee type must have the same "
"as the op result type, but found ")
@@ -1537,13 +1537,13 @@ LogicalResult spirv::BitcastOp::verify() {
if (operandType == resultType) {
return emitError("result type must be different from operand type");
}
- if (operandType.isa<spirv::PointerType>() &&
- !resultType.isa<spirv::PointerType>()) {
+ if (llvm::isa<spirv::PointerType>(operandType) &&
+ !llvm::isa<spirv::PointerType>(resultType)) {
return emitError(
"unhandled bit cast conversion from pointer type to non-pointer type");
}
- if (!operandType.isa<spirv::PointerType>() &&
- resultType.isa<spirv::PointerType>()) {
+ if (!llvm::isa<spirv::PointerType>(operandType) &&
+ llvm::isa<spirv::PointerType>(resultType)) {
return emitError(
"unhandled bit cast conversion from non-pointer type to pointer type");
}
@@ -1562,8 +1562,8 @@ LogicalResult spirv::BitcastOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult spirv::PtrCastToGenericOp::verify() {
- auto operandType = getPointer().getType().cast<spirv::PointerType>();
- auto resultType = getResult().getType().cast<spirv::PointerType>();
+ auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
+ auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
spirv::StorageClass operandStorage = operandType.getStorageClass();
if (operandStorage != spirv::StorageClass::Workgroup &&
@@ -1590,8 +1590,8 @@ LogicalResult spirv::PtrCastToGenericOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult spirv::GenericCastToPtrOp::verify() {
- auto operandType = getPointer().getType().cast<spirv::PointerType>();
- auto resultType = getResult().getType().cast<spirv::PointerType>();
+ auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
+ auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
spirv::StorageClass operandStorage = operandType.getStorageClass();
if (operandStorage != spirv::StorageClass::Generic)
@@ -1618,8 +1618,8 @@ LogicalResult spirv::GenericCastToPtrOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult spirv::GenericCastToPtrExplicitOp::verify() {
- auto operandType = getPointer().getType().cast<spirv::PointerType>();
- auto resultType = getResult().getType().cast<spirv::PointerType>();
+ auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
+ auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
spirv::StorageClass operandStorage = operandType.getStorageClass();
if (operandStorage != spirv::StorageClass::Generic)
@@ -1719,7 +1719,7 @@ void spirv::BranchConditionalOp::print(OpAsmPrinter &printer) {
if (auto weights = getBranchWeights()) {
printer << " [";
llvm::interleaveComma(weights->getValue(), printer, [&](Attribute a) {
- printer << a.cast<IntegerAttr>().getInt();
+ printer << llvm::cast<IntegerAttr>(a).getInt();
});
printer << "]";
}
@@ -1736,7 +1736,7 @@ LogicalResult spirv::BranchConditionalOp::verify() {
return emitOpError("must have exactly two branch weights");
}
if (llvm::all_of(*weights, [](Attribute attr) {
- return attr.cast<IntegerAttr>().getValue().isZero();
+ return llvm::cast<IntegerAttr>(attr).getValue().isZero();
}))
return emitOpError("branch weights cannot both be zero");
}
@@ -1749,10 +1749,10 @@ LogicalResult spirv::BranchConditionalOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult spirv::CompositeConstructOp::verify() {
- auto cType = getType().cast<spirv::CompositeType>();
+ auto cType = llvm::cast<spirv::CompositeType>(getType());
operand_range constituents = this->getConstituents();
- if (auto coopType = cType.dyn_cast<spirv::CooperativeMatrixNVType>()) {
+ if (auto coopType = llvm::dyn_cast<spirv::CooperativeMatrixNVType>(cType)) {
if (constituents.size() != 1)
return emitOpError("has incorrect number of operands: expected ")
<< "1, but provided " << constituents.size();
@@ -1763,7 +1763,7 @@ LogicalResult spirv::CompositeConstructOp::verify() {
return success();
}
- if (auto jointType = cType.dyn_cast<spirv::JointMatrixINTELType>()) {
+ if (auto jointType = llvm::dyn_cast<spirv::JointMatrixINTELType>(cType)) {
if (constituents.size() != 1)
return emitOpError("has incorrect number of operands: expected ")
<< "1, but provided " << constituents.size();
@@ -1787,7 +1787,7 @@ LogicalResult spirv::CompositeConstructOp::verify() {
// If not constructing a cooperative matrix type, then we must be constructing
// a vector type.
- auto resultType = cType.dyn_cast<VectorType>();
+ auto resultType = llvm::dyn_cast<VectorType>(cType);
if (!resultType)
return emitOpError(
"expected to return a vector or cooperative matrix when the number of "
@@ -1795,14 +1795,14 @@ LogicalResult spirv::CompositeConstructOp::verify() {
SmallVector<unsigned> sizes;
for (Value component : constituents) {
- if (!component.getType().isa<VectorType>() &&
+ if (!llvm::isa<VectorType>(component.getType()) &&
!component.getType().isIntOrFloat())
return emitOpError("operand type mismatch: expected operand to have "
"a scalar or vector type, but provided ")
<< component.getType();
Type elementType = component.getType();
- if (auto vectorType = component.getType().dyn_cast<VectorType>()) {
+ if (auto vectorType = llvm::dyn_cast<VectorType>(component.getType())) {
sizes.push_back(vectorType.getNumElements());
elementType = vectorType.getElementType();
} else {
@@ -1866,7 +1866,7 @@ void spirv::CompositeExtractOp::print(OpAsmPrinter &printer) {
}
LogicalResult spirv::CompositeExtractOp::verify() {
- auto indicesArrayAttr = getIndices().dyn_cast<ArrayAttr>();
+ auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(getIndices());
auto resultType =
getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
if (!resultType)
@@ -1909,7 +1909,7 @@ ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser,
}
LogicalResult spirv::CompositeInsertOp::verify() {
- auto indicesArrayAttr = getIndices().dyn_cast<ArrayAttr>();
+ auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(getIndices());
auto objectType =
getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
if (!objectType)
@@ -1946,9 +1946,9 @@ ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
return failure();
Type type = NoneType::get(parser.getContext());
- if (auto typedAttr = value.dyn_cast<TypedAttr>())
+ if (auto typedAttr = llvm::dyn_cast<TypedAttr>(value))
type = typedAttr.getType();
- if (type.isa<NoneType, TensorType>()) {
+ if (llvm::isa<NoneType, TensorType>(type)) {
if (parser.parseColonType(type))
return failure();
}
@@ -1958,25 +1958,25 @@ ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
void spirv::ConstantOp::print(OpAsmPrinter &printer) {
printer << ' ' << getValue();
- if (getType().isa<spirv::ArrayType>())
+ if (llvm::isa<spirv::ArrayType>(getType()))
printer << " : " << getType();
}
static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
Type opType) {
- if (value.isa<IntegerAttr, FloatAttr>()) {
- auto valueType = value.cast<TypedAttr>().getType();
+ if (llvm::isa<IntegerAttr, FloatAttr>(value)) {
+ auto valueType = llvm::cast<TypedAttr>(value).getType();
if (valueType != opType)
return op.emitOpError("result type (")
<< opType << ") does not match value type (" << valueType << ")";
return success();
}
- if (value.isa<DenseIntOrFPElementsAttr, SparseElementsAttr>()) {
- auto valueType = value.cast<TypedAttr>().getType();
+ if (llvm::isa<DenseIntOrFPElementsAttr, SparseElementsAttr>(value)) {
+ auto valueType = llvm::cast<TypedAttr>(value).getType();
if (valueType == opType)
return success();
- auto arrayType = opType.dyn_cast<spirv::ArrayType>();
- auto shapedType = valueType.dyn_cast<ShapedType>();
+ auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType);
+ auto shapedType = llvm::dyn_cast<ShapedType>(valueType);
if (!arrayType)
return op.emitOpError("result or element type (")
<< opType << ") does not match value type (" << valueType
@@ -1984,7 +1984,7 @@ static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
int numElements = arrayType.getNumElements();
auto opElemType = arrayType.getElementType();
- while (auto t = opElemType.dyn_cast<spirv::ArrayType>()) {
+ while (auto t = llvm::dyn_cast<spirv::ArrayType>(opElemType)) {
numElements *= t.getNumElements();
opElemType = t.getElementType();
}
@@ -2005,8 +2005,8 @@ static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
}
return success();
}
- if (auto arrayAttr = value.dyn_cast<ArrayAttr>()) {
- auto arrayType = opType.dyn_cast<spirv::ArrayType>();
+ if (auto arrayAttr = llvm::dyn_cast<ArrayAttr>(value)) {
+ auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType);
if (!arrayType)
return op.emitOpError(
"must have spirv.array result type for array value");
@@ -2030,12 +2030,12 @@ LogicalResult spirv::ConstantOp::verify() {
bool spirv::ConstantOp::isBuildableWith(Type type) {
// Must be valid SPIR-V type first.
- if (!type.isa<spirv::SPIRVType>())
+ if (!llvm::isa<spirv::SPIRVType>(type))
return false;
if (isa<SPIRVDialect>(type.getDialect())) {
// TODO: support constant struct
- return type.isa<spirv::ArrayType>();
+ return llvm::isa<spirv::ArrayType>(type);
}
return true;
@@ -2043,7 +2043,7 @@ bool spirv::ConstantOp::isBuildableWith(Type type) {
spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
OpBuilder &builder) {
- if (auto intType = type.dyn_cast<IntegerType>()) {
+ if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
unsigned width = intType.getWidth();
if (width == 1)
return builder.create<spirv::ConstantOp>(loc, type,
@@ -2051,19 +2051,19 @@ spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
return builder.create<spirv::ConstantOp>(
loc, type, builder.getIntegerAttr(type, APInt(width, 0)));
}
- if (auto floatType = type.dyn_cast<FloatType>()) {
+ if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
return builder.create<spirv::ConstantOp>(
loc, type, builder.getFloatAttr(floatType, 0.0));
}
- if (auto vectorType = type.dyn_cast<VectorType>()) {
+ if (auto vectorType = llvm::dyn_cast<VectorType>(type)) {
Type elemType = vectorType.getElementType();
- if (elemType.isa<IntegerType>()) {
+ if (llvm::isa<IntegerType>(elemType)) {
return builder.create<spirv::ConstantOp>(
loc, type,
DenseElementsAttr::get(vectorType,
IntegerAttr::get(elemType, 0).getValue()));
}
- if (elemType.isa<FloatType>()) {
+ if (llvm::isa<FloatType>(elemType)) {
return builder.create<spirv::ConstantOp>(
loc, type,
DenseFPElementsAttr::get(vectorType,
@@ -2076,7 +2076,7 @@ spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
OpBuilder &builder) {
- if (auto intType = type.dyn_cast<IntegerType>()) {
+ if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
unsigned width = intType.getWidth();
if (width == 1)
return builder.create<spirv::ConstantOp>(loc, type,
@@ -2084,19 +2084,19 @@ spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
return builder.create<spirv::ConstantOp>(
loc, type, builder.getIntegerAttr(type, APInt(width, 1)));
}
- if (auto floatType = type.dyn_cast<FloatType>()) {
+ if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
return builder.create<spirv::ConstantOp>(
loc, type, builder.getFloatAttr(floatType, 1.0));
}
- if (auto vectorType = type.dyn_cast<VectorType>()) {
+ if (auto vectorType = llvm::dyn_cast<VectorType>(type)) {
Type elemType = vectorType.getElementType();
- if (elemType.isa<IntegerType>()) {
+ if (llvm::isa<IntegerType>(elemType)) {
return builder.create<spirv::ConstantOp>(
loc, type,
DenseElementsAttr::get(vectorType,
IntegerAttr::get(elemType, 1).getValue()));
}
- if (elemType.isa<FloatType>()) {
+ if (llvm::isa<FloatType>(elemType)) {
return builder.create<spirv::ConstantOp>(
loc, type,
DenseFPElementsAttr::get(vectorType,
@@ -2115,9 +2115,9 @@ void mlir::spirv::ConstantOp::getAsmResultNames(
llvm::raw_svector_ostream specialName(specialNameBuffer);
specialName << "cst";
- IntegerType intTy = type.dyn_cast<IntegerType>();
+ IntegerType intTy = llvm::dyn_cast<IntegerType>(type);
- if (IntegerAttr intCst = getValue().dyn_cast<IntegerAttr>()) {
+ if (IntegerAttr intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
if (intTy && intTy.getWidth() == 1) {
return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
}
@@ -2131,17 +2131,18 @@ void mlir::spirv::ConstantOp::getAsmResultNames(
}
}
- if (intTy || type.isa<FloatType>()) {
+ if (intTy || llvm::isa<FloatType>(type)) {
specialName << '_' << type;
}
- if (auto vecType = type.dyn_cast<VectorType>()) {
+ if (auto vecType = llvm::dyn_cast<VectorType>(type)) {
specialName << "_vec_";
specialName << vecType.getDimSize(0);
Type elementType = vecType.getElementType();
- if (elementType.isa<IntegerType>() || elementType.isa<FloatType>()) {
+ if (llvm::isa<IntegerType>(elementType) ||
+ llvm::isa<FloatType>(elementType)) {
specialName << "x" << elementType;
}
}
@@ -2210,9 +2211,10 @@ LogicalResult spirv::INTELConvertBF16ToFOp::verify() {
auto resultType = getResult().getType();
// ODS checks that vector result type and vector operand type have the same
// shape.
- if (auto vectorType = operandType.dyn_cast<VectorType>()) {
+ if (auto vectorType = llvm::dyn_cast<VectorType>(operandType)) {
unsigned operandNumElements = vectorType.getNumElements();
- unsigned resultNumElements = resultType.cast<VectorType>().getNumElements();
+ unsigned resultNumElements =
+ llvm::cast<VectorType>(resultType).getNumElements();
if (operandNumElements != resultNumElements) {
return emitOpError(
"operand and result must have same number of elements");
@@ -2230,9 +2232,10 @@ LogicalResult spirv::INTELConvertFToBF16Op::verify() {
auto resultType = getResult().getType();
// ODS checks that vector result type and vector operand type have the same
// shape.
- if (auto vectorType = operandType.dyn_cast<VectorType>()) {
+ if (auto vectorType = llvm::dyn_cast<VectorType>(operandType)) {
unsigned operandNumElements = vectorType.getNumElements();
- unsigned resultNumElements = resultType.cast<VectorType>().getNumElements();
+ unsigned resultNumElements =
+ llvm::cast<VectorType>(resultType).getNumElements();
if (operandNumElements != resultNumElements) {
return emitOpError(
"operand and result must have same number of elements");
@@ -2331,7 +2334,7 @@ ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser,
if (parser.parseAttribute(value, i32Type, "value", attr)) {
return failure();
}
- values.push_back(value.cast<IntegerAttr>().getInt());
+ values.push_back(llvm::cast<IntegerAttr>(value).getInt());
}
result.addAttribute(kValuesAttrName,
parser.getBuilder().getI32ArrayAttr(values));
@@ -2347,7 +2350,7 @@ void spirv::ExecutionModeOp::print(OpAsmPrinter &printer) {
return;
printer << ", ";
llvm::interleaveComma(values, printer, [&](Attribute a) {
- printer << a.cast<IntegerAttr>().getInt();
+ printer << llvm::cast<IntegerAttr>(a).getInt();
});
}
@@ -2677,7 +2680,7 @@ ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
if (parser.parseColonType(type)) {
return failure();
}
- if (!type.isa<spirv::PointerType>()) {
+ if (!llvm::isa<spirv::PointerType>(type)) {
return parser.emitError(loc, "expected spirv.ptr type");
}
result.addAttribute(kTypeAttrName, TypeAttr::get(type));
@@ -2708,7 +2711,7 @@ void spirv::GlobalVariableOp::print(OpAsmPrinter &printer) {
}
LogicalResult spirv::GlobalVariableOp::verify() {
- if (!getType().isa<spirv::PointerType>())
+ if (!llvm::isa<spirv::PointerType>(getType()))
return emitOpError("result must be of a !spv.ptr type");
// SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
@@ -2748,7 +2751,7 @@ LogicalResult spirv::GroupBroadcastOp::verify() {
if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
- if (auto localIdTy = getLocalid().getType().dyn_cast<VectorType>())
+ if (auto localIdTy = llvm::dyn_cast<VectorType>(getLocalid().getType()))
if (localIdTy.getNumElements() != 2 && localIdTy.getNumElements() != 3)
return emitOpError("localid is a vector and can be with only "
" 2 or 3 components, actual number is ")
@@ -2839,7 +2842,7 @@ ParseResult spirv::INTELSubgroupBlockReadOp::parse(OpAsmParser &parser,
}
auto ptrType = spirv::PointerType::get(elementType, storageClass);
- if (auto valVecTy = elementType.dyn_cast<VectorType>())
+ if (auto valVecTy = llvm::dyn_cast<VectorType>(elementType))
ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
if (parser.resolveOperand(ptrInfo, ptrType, result.operands)) {
@@ -2879,7 +2882,7 @@ ParseResult spirv::INTELSubgroupBlockWriteOp::parse(OpAsmParser &parser,
}
auto ptrType = spirv::PointerType::get(elementType, storageClass);
- if (auto valVecTy = elementType.dyn_cast<VectorType>())
+ if (auto valVecTy = llvm::dyn_cast<VectorType>(elementType))
ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
@@ -3148,7 +3151,7 @@ void spirv::UMulExtendedOp::print(OpAsmPrinter &printer) {
void spirv::LoadOp::build(OpBuilder &builder, OperationState &state,
Value basePtr, MemoryAccessAttr memoryAccess,
IntegerAttr alignment) {
- auto ptrType = basePtr.getType().cast<spirv::PointerType>();
+ auto ptrType = llvm::cast<spirv::PointerType>(basePtr.getType());
build(builder, state, ptrType.getPointeeType(), basePtr, memoryAccess,
alignment);
}
@@ -3177,7 +3180,7 @@ ParseResult spirv::LoadOp::parse(OpAsmParser &parser, OperationState &result) {
void spirv::LoadOp::print(OpAsmPrinter &printer) {
SmallVector<StringRef, 4> elidedAttrs;
StringRef sc = stringifyStorageClass(
- getPtr().getType().cast<spirv::PointerType>().getStorageClass());
+ llvm::cast<spirv::PointerType>(getPtr().getType()).getStorageClass());
printer << " \"" << sc << "\" " << getPtr();
printMemoryAccessAttribute(*this, printer, elidedAttrs);
@@ -3494,7 +3497,7 @@ LogicalResult spirv::ModuleOp::verifyRegions() {
}
if (auto interface = entryPointOp.getInterface()) {
for (Attribute varRef : interface) {
- auto varSymRef = varRef.dyn_cast<FlatSymbolRefAttr>();
+ auto varSymRef = llvm::dyn_cast<FlatSymbolRefAttr>(varRef);
if (!varSymRef) {
return entryPointOp.emitError(
"expected symbol reference for interface "
@@ -3587,8 +3590,8 @@ LogicalResult spirv::ReturnValueOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult spirv::SelectOp::verify() {
- if (auto conditionTy = getCondition().getType().dyn_cast<VectorType>()) {
- auto resultVectorTy = getResult().getType().dyn_cast<VectorType>();
+ if (auto conditionTy = llvm::dyn_cast<VectorType>(getCondition().getType())) {
+ auto resultVectorTy = llvm::dyn_cast<VectorType>(getResult().getType());
if (!resultVectorTy) {
return emitOpError("result expected to be of vector type when "
"condition is of vector type");
@@ -3760,9 +3763,9 @@ LogicalResult spirv::SpecConstantOp::verify() {
return emitOpError("SpecId cannot be negative");
auto value = getDefaultValue();
- if (value.isa<IntegerAttr, FloatAttr>()) {
+ if (llvm::isa<IntegerAttr, FloatAttr>(value)) {
// Make sure bitwidth is allowed.
- if (!value.getType().isa<spirv::SPIRVType>())
+ if (!llvm::isa<spirv::SPIRVType>(value.getType()))
return emitOpError("default value bitwidth disallowed");
return success();
}
@@ -3798,7 +3801,7 @@ ParseResult spirv::StoreOp::parse(OpAsmParser &parser, OperationState &result) {
void spirv::StoreOp::print(OpAsmPrinter &printer) {
SmallVector<StringRef, 4> elidedAttrs;
StringRef sc = stringifyStorageClass(
- getPtr().getType().cast<spirv::PointerType>().getStorageClass());
+ llvm::cast<spirv::PointerType>(getPtr().getType()).getStorageClass());
printer << " \"" << sc << "\" " << getPtr() << ", " << getValue();
printMemoryAccessAttribute(*this, printer, elidedAttrs);
@@ -3861,7 +3864,7 @@ ParseResult spirv::VariableOp::parse(OpAsmParser &parser,
if (parser.parseType(type))
return failure();
- auto ptrType = type.dyn_cast<spirv::PointerType>();
+ auto ptrType = llvm::dyn_cast<spirv::PointerType>(type);
if (!ptrType)
return parser.emitError(loc, "expected spirv.ptr type");
result.addTypes(ptrType);
@@ -3901,7 +3904,7 @@ LogicalResult spirv::VariableOp::verify() {
"spirv.GlobalVariable for module-level variables.");
}
- auto pointerType = getPointer().getType().cast<spirv::PointerType>();
+ auto pointerType = llvm::cast<spirv::PointerType>(getPointer().getType());
if (getStorageClass() != pointerType.getStorageClass())
return emitOpError(
"storage class must match result pointer's storage class");
@@ -3940,7 +3943,7 @@ LogicalResult spirv::VariableOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult spirv::VectorShuffleOp::verify() {
- VectorType resultType = getType().cast<VectorType>();
+ VectorType resultType = llvm::cast<VectorType>(getType());
size_t numResultElements = resultType.getNumElements();
if (numResultElements != getComponents().size())
@@ -3950,8 +3953,8 @@ LogicalResult spirv::VectorShuffleOp::verify() {
<< getComponents().size() << ")";
size_t totalSrcElements =
- getVector1().getType().cast<VectorType>().getNumElements() +
- getVector2().getType().cast<VectorType>().getNumElements();
+ llvm::cast<VectorType>(getVector1().getType()).getNumElements() +
+ llvm::cast<VectorType>(getVector2().getType()).getNumElements();
for (const auto &selector : getComponents().getAsValueRange<IntegerAttr>()) {
uint32_t index = selector.getZExtValue();
@@ -4001,13 +4004,14 @@ void spirv::NVCooperativeMatrixLoadOp::print(OpAsmPrinter &printer) {
static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
Type coopMatrix) {
- Type pointeeType = pointer.cast<spirv::PointerType>().getPointeeType();
- if (!pointeeType.isa<spirv::ScalarType>() && !pointeeType.isa<VectorType>())
+ Type pointeeType = llvm::cast<spirv::PointerType>(pointer).getPointeeType();
+ if (!llvm::isa<spirv::ScalarType>(pointeeType) &&
+ !llvm::isa<VectorType>(pointeeType))
return op->emitError(
"Pointer must point to a scalar or vector type but provided ")
<< pointeeType;
spirv::StorageClass storage =
- pointer.cast<spirv::PointerType>().getStorageClass();
+ llvm::cast<spirv::PointerType>(pointer).getStorageClass();
if (storage != spirv::StorageClass::Workgroup &&
storage != spirv::StorageClass::StorageBuffer &&
storage != spirv::StorageClass::PhysicalStorageBuffer)
@@ -4071,10 +4075,11 @@ static LogicalResult
verifyCoopMatrixMulAdd(spirv::NVCooperativeMatrixMulAddOp op) {
if (op.getC().getType() != op.getResult().getType())
return op.emitOpError("result and third operand must have the same type");
- auto typeA = op.getA().getType().cast<spirv::CooperativeMatrixNVType>();
- auto typeB = op.getB().getType().cast<spirv::CooperativeMatrixNVType>();
- auto typeC = op.getC().getType().cast<spirv::CooperativeMatrixNVType>();
- auto typeR = op.getResult().getType().cast<spirv::CooperativeMatrixNVType>();
+ auto typeA = llvm::cast<spirv::CooperativeMatrixNVType>(op.getA().getType());
+ auto typeB = llvm::cast<spirv::CooperativeMatrixNVType>(op.getB().getType());
+ auto typeC = llvm::cast<spirv::CooperativeMatrixNVType>(op.getC().getType());
+ auto typeR =
+ llvm::cast<spirv::CooperativeMatrixNVType>(op.getResult().getType());
if (typeA.getRows() != typeR.getRows() ||
typeA.getColumns() != typeB.getRows() ||
typeB.getColumns() != typeR.getColumns())
@@ -4086,8 +4091,8 @@ verifyCoopMatrixMulAdd(spirv::NVCooperativeMatrixMulAddOp op) {
auto elementTypeA = typeA.getElementType();
auto elementTypeB = typeB.getElementType();
if (isa<IntegerType>(elementTypeA) && isa<IntegerType>(elementTypeB)) {
- if (elementTypeA.cast<IntegerType>().getWidth() !=
- elementTypeB.cast<IntegerType>().getWidth())
+ if (llvm::cast<IntegerType>(elementTypeA).getWidth() !=
+ llvm::cast<IntegerType>(elementTypeB).getWidth())
return op.emitOpError(
"matrix A and B integer element types must be the same bit width");
} else if (elementTypeA != elementTypeB) {
@@ -4105,13 +4110,14 @@ LogicalResult spirv::NVCooperativeMatrixMulAddOp::verify() {
static LogicalResult
verifyPointerAndJointMatrixType(Operation *op, Type pointer, Type jointMatrix) {
- Type pointeeType = pointer.cast<spirv::PointerType>().getPointeeType();
- if (!pointeeType.isa<spirv::ScalarType>() && !pointeeType.isa<VectorType>())
+ Type pointeeType = llvm::cast<spirv::PointerType>(pointer).getPointeeType();
+ if (!llvm::isa<spirv::ScalarType>(pointeeType) &&
+ !llvm::isa<VectorType>(pointeeType))
return op->emitError(
"Pointer must point to a scalar or vector type but provided ")
<< pointeeType;
spirv::StorageClass storage =
- pointer.cast<spirv::PointerType>().getStorageClass();
+ llvm::cast<spirv::PointerType>(pointer).getStorageClass();
if (storage != spirv::StorageClass::Workgroup &&
storage != spirv::StorageClass::CrossWorkgroup &&
storage != spirv::StorageClass::UniformConstant &&
@@ -4147,10 +4153,11 @@ LogicalResult spirv::INTELJointMatrixStoreOp::verify() {
static LogicalResult verifyJointMatrixMad(spirv::INTELJointMatrixMadOp op) {
if (op.getC().getType() != op.getResult().getType())
return op.emitOpError("result and third operand must have the same type");
- auto typeA = op.getA().getType().cast<spirv::JointMatrixINTELType>();
- auto typeB = op.getB().getType().cast<spirv::JointMatrixINTELType>();
- auto typeC = op.getC().getType().cast<spirv::JointMatrixINTELType>();
- auto typeR = op.getResult().getType().cast<spirv::JointMatrixINTELType>();
+ auto typeA = llvm::cast<spirv::JointMatrixINTELType>(op.getA().getType());
+ auto typeB = llvm::cast<spirv::JointMatrixINTELType>(op.getB().getType());
+ auto typeC = llvm::cast<spirv::JointMatrixINTELType>(op.getC().getType());
+ auto typeR =
+ llvm::cast<spirv::JointMatrixINTELType>(op.getResult().getType());
if (typeA.getRows() != typeR.getRows() ||
typeA.getColumns() != typeB.getRows() ||
typeB.getColumns() != typeR.getColumns())
@@ -4174,8 +4181,8 @@ LogicalResult spirv::INTELJointMatrixMadOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult spirv::MatrixTimesScalarOp::verify() {
- if (auto inputCoopmat =
- getMatrix().getType().dyn_cast<spirv::CooperativeMatrixNVType>()) {
+ if (auto inputCoopmat = llvm::dyn_cast<spirv::CooperativeMatrixNVType>(
+ getMatrix().getType())) {
if (inputCoopmat.getElementType() != getScalar().getType())
return emitError("input matrix components' type and scaling value must "
"have the same type");
@@ -4183,7 +4190,7 @@ LogicalResult spirv::MatrixTimesScalarOp::verify() {
}
// Check that the scalar type is the same as the matrix element type.
- auto inputMatrix = getMatrix().getType().cast<spirv::MatrixType>();
+ auto inputMatrix = llvm::cast<spirv::MatrixType>(getMatrix().getType());
if (getScalar().getType() != inputMatrix.getElementType())
return emitError("input matrix components' type and scaling value must "
"have the same type");
@@ -4199,11 +4206,11 @@ void spirv::CopyMemoryOp::print(OpAsmPrinter &printer) {
printer << ' ';
StringRef targetStorageClass = stringifyStorageClass(
- getTarget().getType().cast<spirv::PointerType>().getStorageClass());
+ llvm::cast<spirv::PointerType>(getTarget().getType()).getStorageClass());
printer << " \"" << targetStorageClass << "\" " << getTarget() << ", ";
StringRef sourceStorageClass = stringifyStorageClass(
- getSource().getType().cast<spirv::PointerType>().getStorageClass());
+ llvm::cast<spirv::PointerType>(getSource().getType()).getStorageClass());
printer << " \"" << sourceStorageClass << "\" " << getSource();
SmallVector<StringRef, 4> elidedAttrs;
@@ -4215,7 +4222,7 @@ void spirv::CopyMemoryOp::print(OpAsmPrinter &printer) {
printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
Type pointeeType =
- getTarget().getType().cast<spirv::PointerType>().getPointeeType();
+ llvm::cast<spirv::PointerType>(getTarget().getType()).getPointeeType();
printer << " : " << pointeeType;
}
@@ -4263,10 +4270,10 @@ ParseResult spirv::CopyMemoryOp::parse(OpAsmParser &parser,
LogicalResult spirv::CopyMemoryOp::verify() {
Type targetType =
- getTarget().getType().cast<spirv::PointerType>().getPointeeType();
+ llvm::cast<spirv::PointerType>(getTarget().getType()).getPointeeType();
Type sourceType =
- getSource().getType().cast<spirv::PointerType>().getPointeeType();
+ llvm::cast<spirv::PointerType>(getSource().getType()).getPointeeType();
if (targetType != sourceType)
return emitOpError("both operands must be pointers to the same type");
@@ -4290,8 +4297,8 @@ LogicalResult spirv::CopyMemoryOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult spirv::TransposeOp::verify() {
- auto inputMatrix = getMatrix().getType().cast<spirv::MatrixType>();
- auto resultMatrix = getResult().getType().cast<spirv::MatrixType>();
+ auto inputMatrix = llvm::cast<spirv::MatrixType>(getMatrix().getType());
+ auto resultMatrix = llvm::cast<spirv::MatrixType>(getResult().getType());
// Verify that the input and output matrices have correct shapes.
if (inputMatrix.getNumRows() != resultMatrix.getNumColumns())
@@ -4315,9 +4322,9 @@ LogicalResult spirv::TransposeOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult spirv::MatrixTimesMatrixOp::verify() {
- auto leftMatrix = getLeftmatrix().getType().cast<spirv::MatrixType>();
- auto rightMatrix = getRightmatrix().getType().cast<spirv::MatrixType>();
- auto resultMatrix = getResult().getType().cast<spirv::MatrixType>();
+ auto leftMatrix = llvm::cast<spirv::MatrixType>(getLeftmatrix().getType());
+ auto rightMatrix = llvm::cast<spirv::MatrixType>(getRightmatrix().getType());
+ auto resultMatrix = llvm::cast<spirv::MatrixType>(getResult().getType());
// left matrix columns' count and right matrix rows' count must be equal
if (leftMatrix.getNumColumns() != rightMatrix.getNumRows())
@@ -4403,16 +4410,16 @@ void spirv::SpecConstantCompositeOp::print(OpAsmPrinter &printer) {
}
LogicalResult spirv::SpecConstantCompositeOp::verify() {
- auto cType = getType().dyn_cast<spirv::CompositeType>();
+ auto cType = llvm::dyn_cast<spirv::CompositeType>(getType());
auto constituents = this->getConstituents().getValue();
if (!cType)
return emitError("result type must be a composite type, but provided ")
<< getType();
- if (cType.isa<spirv::CooperativeMatrixNVType>())
+ if (llvm::isa<spirv::CooperativeMatrixNVType>(cType))
return emitError("unsupported composite type ") << cType;
- if (cType.isa<spirv::JointMatrixINTELType>())
+ if (llvm::isa<spirv::JointMatrixINTELType>(cType))
return emitError("unsupported composite type ") << cType;
if (constituents.size() != cType.getNumElements())
return emitError("has incorrect number of operands: expected ")
@@ -4420,7 +4427,7 @@ LogicalResult spirv::SpecConstantCompositeOp::verify() {
<< constituents.size();
for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
- auto constituent = constituents[index].cast<FlatSymbolRefAttr>();
+ auto constituent = llvm::cast<FlatSymbolRefAttr>(constituents[index]);
auto constituentSpecConstOp =
dyn_cast<spirv::SpecConstantOp>(SymbolTable::lookupNearestSymbolFrom(
@@ -4498,19 +4505,19 @@ LogicalResult spirv::SpecConstantOperationOp::verifyRegions() {
LogicalResult spirv::GLFrexpStructOp::verify() {
spirv::StructType structTy =
- getResult().getType().dyn_cast<spirv::StructType>();
+ llvm::dyn_cast<spirv::StructType>(getResult().getType());
if (structTy.getNumElements() != 2)
return emitError("result type must be a struct type with two memebers");
Type significandTy = structTy.getElementType(0);
Type exponentTy = structTy.getElementType(1);
- VectorType exponentVecTy = exponentTy.dyn_cast<VectorType>();
- IntegerType exponentIntTy = exponentTy.dyn_cast<IntegerType>();
+ VectorType exponentVecTy = llvm::dyn_cast<VectorType>(exponentTy);
+ IntegerType exponentIntTy = llvm::dyn_cast<IntegerType>(exponentTy);
Type operandTy = getOperand().getType();
- VectorType operandVecTy = operandTy.dyn_cast<VectorType>();
- FloatType operandFTy = operandTy.dyn_cast<FloatType>();
+ VectorType operandVecTy = llvm::dyn_cast<VectorType>(operandTy);
+ FloatType operandFTy = llvm::dyn_cast<FloatType>(operandTy);
if (significandTy != operandTy)
return emitError("member zero of the resulting struct type must be the "
@@ -4518,7 +4525,7 @@ LogicalResult spirv::GLFrexpStructOp::verify() {
if (exponentVecTy) {
IntegerType componentIntTy =
- exponentVecTy.getElementType().dyn_cast<IntegerType>();
+ llvm::dyn_cast<IntegerType>(exponentVecTy.getElementType());
if (!componentIntTy || componentIntTy.getWidth() != 32)
return emitError("member one of the resulting struct type must"
"be a scalar or vector of 32 bit integer type");
@@ -4547,11 +4554,12 @@ LogicalResult spirv::GLLdexpOp::verify() {
Type significandType = getX().getType();
Type exponentType = getExp().getType();
- if (significandType.isa<FloatType>() != exponentType.isa<IntegerType>())
+ if (llvm::isa<FloatType>(significandType) !=
+ llvm::isa<IntegerType>(exponentType))
return emitOpError("operands must both be scalars or vectors");
auto getNumElements = [](Type type) -> unsigned {
- if (auto vectorType = type.dyn_cast<VectorType>())
+ if (auto vectorType = llvm::dyn_cast<VectorType>(type))
return vectorType.getNumElements();
return 1;
};
@@ -4567,17 +4575,19 @@ LogicalResult spirv::GLLdexpOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult spirv::ImageDrefGatherOp::verify() {
- VectorType resultType = getResult().getType().cast<VectorType>();
+ VectorType resultType = llvm::cast<VectorType>(getResult().getType());
auto sampledImageType =
- getSampledimage().getType().cast<spirv::SampledImageType>();
- auto imageType = sampledImageType.getImageType().cast<spirv::ImageType>();
+ llvm::cast<spirv::SampledImageType>(getSampledimage().getType());
+ auto imageType =
+ llvm::cast<spirv::ImageType>(sampledImageType.getImageType());
if (resultType.getNumElements() != 4)
return emitOpError("result type must be a vector of four components");
Type elementType = resultType.getElementType();
Type sampledElementType = imageType.getElementType();
- if (!sampledElementType.isa<NoneType>() && elementType != sampledElementType)
+ if (!llvm::isa<NoneType>(sampledElementType) &&
+ elementType != sampledElementType)
return emitOpError(
"the component type of result must be the same as sampled type of the "
"underlying image type");
@@ -4629,7 +4639,8 @@ LogicalResult spirv::ShiftRightLogicalOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult spirv::ImageQuerySizeOp::verify() {
- spirv::ImageType imageType = getImage().getType().cast<spirv::ImageType>();
+ spirv::ImageType imageType =
+ llvm::cast<spirv::ImageType>(getImage().getType());
Type resultType = getResult().getType();
spirv::Dim dim = imageType.getDim();
@@ -4677,7 +4688,7 @@ LogicalResult spirv::ImageQuerySizeOp::verify() {
componentNumber += 1;
unsigned resultComponentNumber = 1;
- if (auto resultVectorType = resultType.dyn_cast<VectorType>())
+ if (auto resultVectorType = llvm::dyn_cast<VectorType>(resultType))
resultComponentNumber = resultVectorType.getNumElements();
if (componentNumber != resultComponentNumber)
@@ -4798,7 +4809,7 @@ LogicalResult spirv::PtrAccessChainOp::verify() {
LogicalResult spirv::VectorTimesScalarOp::verify() {
if (getVector().getType() != getType())
return emitOpError("vector operand and result type mismatch");
- auto scalarType = getType().cast<VectorType>().getElementType();
+ auto scalarType = llvm::cast<VectorType>(getType()).getElementType();
if (getScalar().getType() != scalarType)
return emitOpError("scalar operand and result element type match");
return success();
@@ -4851,11 +4862,11 @@ static LogicalResult verifyIntegerDotProduct(Operation *op) {
return op->emitOpError("requires the same type for both vector operands");
unsigned expectedNumAttrs = 0;
- if (auto intTy = factorTy.dyn_cast<IntegerType>()) {
+ if (auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
++expectedNumAttrs;
auto packedVectorFormat =
- op->getAttr(kPackedVectorFormatAttrName)
- .dyn_cast_or_null<spirv::PackedVectorFormatAttr>();
+ llvm::dyn_cast_or_null<spirv::PackedVectorFormatAttr>(
+ op->getAttr(kPackedVectorFormatAttrName));
if (!packedVectorFormat)
return op->emitOpError("requires Packed Vector Format attribute for "
"integer vector operands");
@@ -4927,9 +4938,9 @@ getIntegerDotProductCapabilities(Operation *op) {
SmallVector<ArrayRef<spirv::Capability>, 1> capabilities = {dotProductCap};
Type factorTy = op->getOperand(0).getType();
- if (auto intTy = factorTy.dyn_cast<IntegerType>()) {
- auto formatAttr = op->getAttr(kPackedVectorFormatAttrName)
- .cast<spirv::PackedVectorFormatAttr>();
+ if (auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
+ auto formatAttr = llvm::cast<spirv::PackedVectorFormatAttr>(
+ op->getAttr(kPackedVectorFormatAttrName));
if (formatAttr.getValue() ==
spirv::PackedVectorFormat::PackedVectorFormat4x8Bit)
capabilities.push_back(dotProductInput4x8BitPackedCap);
@@ -4937,7 +4948,7 @@ getIntegerDotProductCapabilities(Operation *op) {
return capabilities;
}
- auto vecTy = factorTy.cast<VectorType>();
+ auto vecTy = llvm::cast<VectorType>(factorTy);
if (vecTy.getElementTypeBitWidth() == 8) {
capabilities.push_back(dotProductInput4x8BitCap);
return capabilities;
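The hunks above make the same mechanical rewrite throughout: `x.isa<T>()`, `x.cast<T>()`, `x.dyn_cast<T>()` and `x.dyn_cast_or_null<T>()` become `llvm::isa<T>(x)` and similar. The sketch below is a simplified, hypothetical model of the pattern those free functions rely on. It does not reproduce `llvm/Support/Casting.h`, which is far more general. The names `mini`, `Kind`, and `Holder` are invented for illustration. It shows four things:

- each target type answers `classof`;
- `isa` accepts several types at once, as in `llvm::isa<IntegerAttr, FloatAttr>(value)`;
- a `classof` may itself accept several kinds, as `CompositeType::classof` does;
- inside a class that declares a member named `cast`, an unqualified call resolves to the member. That is why these call sites spell out `llvm::`.

```cpp
#include <cassert>

// Hypothetical miniature of LLVM-style RTTI; not the real llvm API.
namespace mini {

enum class Kind { Integer, Float, Vector, Array };

// Value-semantic handle loosely modelled on mlir::Type: a kind tag plus a
// null state (real MLIR types point at uniqued storage instead).
class Type {
public:
  Type() = default;
  explicit Type(Kind k) : kind_(k), valid_(true) {}
  explicit operator bool() const { return valid_; }
  Kind getKind() const { return kind_; }

private:
  Kind kind_ = Kind::Integer;
  bool valid_ = false;
};

// Each subtype answers "is this Type one of mine?" through classof.
struct IntegerType : Type {
  IntegerType() = default;
  explicit IntegerType(Type t) : Type(t) {}
  static bool classof(Type t) { return t.getKind() == Kind::Integer; }
};
struct FloatType : Type {
  FloatType() = default;
  explicit FloatType(Type t) : Type(t) {}
  static bool classof(Type t) { return t.getKind() == Kind::Float; }
};
struct VectorType : Type {
  VectorType() = default;
  explicit VectorType(Type t) : Type(t) {}
  static bool classof(Type t) { return t.getKind() == Kind::Vector; }
};
// A classof may accept several kinds, similar in spirit to
// spirv::CompositeType::classof.
struct CompositeType : Type {
  CompositeType() = default;
  explicit CompositeType(Type t) : Type(t) {}
  static bool classof(Type t) {
    return t.getKind() == Kind::Vector || t.getKind() == Kind::Array;
  }
};

// isa<Ts...>(t): true if t is any of the listed types (C++17 fold).
template <typename... Ts> bool isa(Type t) {
  assert(t && "isa<> used on a null type");
  return (Ts::classof(t) || ...);
}

// cast<T>(t): checked conversion that asserts on a type mismatch.
template <typename T> T cast(Type t) {
  assert(isa<T>(t) && "cast<Ty>() argument of incompatible type!");
  return T(t);
}

// dyn_cast<T>(t): returns a null T when t is not a T.
template <typename T> T dyn_cast(Type t) { return isa<T>(t) ? T(t) : T(); }

// dyn_cast_or_null<T>(t): like dyn_cast, but also tolerates a null t.
template <typename T> T dyn_cast_or_null(Type t) {
  return (t && isa<T>(t)) ? T(t) : T();
}

// A class with its own member named `cast`, like the method-style API.
struct Holder {
  Type value;
  template <typename T> T cast() const { return mini::cast<T>(value); }
  bool holdsVector() const {
    // A bare `cast<VectorType>(value)` here would find the member template
    // above (which takes no arguments) and fail to compile, so free
    // functions are qualified with their namespace.
    return mini::isa<VectorType>(value);
  }
};

} // namespace mini
```

In this model, `dyn_cast_or_null` plays the role it has in `verifyIntegerDotProduct`, where `op->getAttr(...)` may return a null attribute. Plain `dyn_cast` and `isa` assume a non-null input.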
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index d2f2cb8e3ce65..fbcc5d84a2701 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -68,17 +68,18 @@ unsigned ArrayType::getArrayStride() const { return getImpl()->stride; }
void ArrayType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage) {
- getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
+ llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
}
void ArrayType::getCapabilities(
SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage) {
- getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
+ llvm::cast<SPIRVType>(getElementType())
+ .getCapabilities(capabilities, storage);
}
std::optional<int64_t> ArrayType::getSizeInBytes() {
- auto elementType = getElementType().cast<SPIRVType>();
+ auto elementType = llvm::cast<SPIRVType>(getElementType());
std::optional<int64_t> size = elementType.getSizeInBytes();
if (!size)
return std::nullopt;
@@ -90,11 +91,11 @@ std::optional<int64_t> ArrayType::getSizeInBytes() {
//===----------------------------------------------------------------------===//
bool CompositeType::classof(Type type) {
- if (auto vectorType = type.dyn_cast<VectorType>())
+ if (auto vectorType = llvm::dyn_cast<VectorType>(type))
return isValid(vectorType);
- return type.isa<spirv::ArrayType, spirv::CooperativeMatrixNVType,
- spirv::JointMatrixINTELType, spirv::MatrixType,
- spirv::RuntimeArrayType, spirv::StructType>();
+ return llvm::isa<spirv::ArrayType, spirv::CooperativeMatrixNVType,
+ spirv::JointMatrixINTELType, spirv::MatrixType,
+ spirv::RuntimeArrayType, spirv::StructType>(type);
}
bool CompositeType::isValid(VectorType type) {
@@ -108,7 +109,7 @@ bool CompositeType::isValid(VectorType type) {
default:
return false;
}
- return type.getRank() == 1 && type.getElementType().isa<ScalarType>();
+ return type.getRank() == 1 && llvm::isa<ScalarType>(type.getElementType());
}
Type CompositeType::getElementType(unsigned index) const {
@@ -160,8 +161,8 @@ void CompositeType::getExtensions(
MatrixType, RuntimeArrayType, StructType>(
[&](auto type) { type.getExtensions(extensions, storage); })
.Case<VectorType>([&](VectorType type) {
- return type.getElementType().cast<ScalarType>().getExtensions(
- extensions, storage);
+ return llvm::cast<ScalarType>(type.getElementType())
+ .getExtensions(extensions, storage);
})
.Default([](Type) { llvm_unreachable("invalid composite type"); });
}
@@ -180,8 +181,8 @@ void CompositeType::getCapabilities(
ArrayRef<Capability> ref(caps, std::size(caps));
capabilities.push_back(ref);
}
- return type.getElementType().cast<ScalarType>().getCapabilities(
- capabilities, storage);
+ return llvm::cast<ScalarType>(type.getElementType())
+ .getCapabilities(capabilities, storage);
})
.Default([](Type) { llvm_unreachable("invalid composite type"); });
}
@@ -193,7 +194,7 @@ std::optional<int64_t> CompositeType::getSizeInBytes() {
return structType.getSizeInBytes();
if (auto vectorType = dyn_cast<VectorType>()) {
std::optional<int64_t> elementSize =
- vectorType.getElementType().cast<ScalarType>().getSizeInBytes();
+ llvm::cast<ScalarType>(vectorType.getElementType()).getSizeInBytes();
if (!elementSize)
return std::nullopt;
return *elementSize * vectorType.getNumElements();
@@ -249,7 +250,7 @@ unsigned CooperativeMatrixNVType::getColumns() const {
void CooperativeMatrixNVType::getExtensions(
SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage) {
- getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
+ llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
static const Extension exts[] = {Extension::SPV_NV_cooperative_matrix};
ArrayRef<Extension> ref(exts, std::size(exts));
extensions.push_back(ref);
@@ -258,7 +259,8 @@ void CooperativeMatrixNVType::getExtensions(
void CooperativeMatrixNVType::getCapabilities(
SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage) {
- getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
+ llvm::cast<SPIRVType>(getElementType())
+ .getCapabilities(capabilities, storage);
static const Capability caps[] = {Capability::CooperativeMatrixNV};
ArrayRef<Capability> ref(caps, std::size(caps));
capabilities.push_back(ref);
@@ -317,7 +319,7 @@ MatrixLayout JointMatrixINTELType::getMatrixLayout() const {
void JointMatrixINTELType::getExtensions(
SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage) {
- getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
+ llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
static const Extension exts[] = {Extension::SPV_INTEL_joint_matrix};
ArrayRef<Extension> ref(exts, std::size(exts));
extensions.push_back(ref);
@@ -326,7 +328,8 @@ void JointMatrixINTELType::getExtensions(
void JointMatrixINTELType::getCapabilities(
SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage) {
- getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
+ llvm::cast<SPIRVType>(getElementType())
+ .getCapabilities(capabilities, storage);
static const Capability caps[] = {Capability::JointMatrixINTEL};
ArrayRef<Capability> ref(caps, std::size(caps));
capabilities.push_back(ref);
@@ -489,8 +492,8 @@ void PointerType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage) {
// Use this pointer type's storage class because this pointer indicates we are
// using the pointee type in that specific storage class.
- getPointeeType().cast<SPIRVType>().getExtensions(extensions,
- getStorageClass());
+ llvm::cast<SPIRVType>(getPointeeType())
+ .getExtensions(extensions, getStorageClass());
if (auto scExts = spirv::getExtensions(getStorageClass()))
extensions.push_back(*scExts);
@@ -501,8 +504,8 @@ void PointerType::getCapabilities(
std::optional<StorageClass> storage) {
// Use this pointer type's storage class because this pointer indicates we are
// using the pointee type in that specific storage class.
- getPointeeType().cast<SPIRVType>().getCapabilities(capabilities,
- getStorageClass());
+ llvm::cast<SPIRVType>(getPointeeType())
+ .getCapabilities(capabilities, getStorageClass());
if (auto scCaps = spirv::getCapabilities(getStorageClass()))
capabilities.push_back(*scCaps);
@@ -547,7 +550,7 @@ unsigned RuntimeArrayType::getArrayStride() const { return getImpl()->stride; }
void RuntimeArrayType::getExtensions(
SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage) {
- getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
+ llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
}
void RuntimeArrayType::getCapabilities(
@@ -558,7 +561,8 @@ void RuntimeArrayType::getCapabilities(
ArrayRef<Capability> ref(caps, std::size(caps));
capabilities.push_back(ref);
}
- getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
+ llvm::cast<SPIRVType>(getElementType())
+ .getCapabilities(capabilities, storage);
}
//===----------------------------------------------------------------------===//
@@ -566,10 +570,10 @@ void RuntimeArrayType::getCapabilities(
//===----------------------------------------------------------------------===//
bool ScalarType::classof(Type type) {
- if (auto floatType = type.dyn_cast<FloatType>()) {
+ if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
return isValid(floatType);
}
- if (auto intType = type.dyn_cast<IntegerType>()) {
+ if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
return isValid(intType);
}
return false;
@@ -723,9 +727,9 @@ bool SPIRVType::classof(Type type) {
// Allow SPIR-V dialect types
if (llvm::isa<SPIRVDialect>(type.getDialect()))
return true;
- if (type.isa<ScalarType>())
+ if (llvm::isa<ScalarType>(type))
return true;
- if (auto vectorType = type.dyn_cast<VectorType>())
+ if (auto vectorType = llvm::dyn_cast<VectorType>(type))
return CompositeType::isValid(vectorType);
return false;
}
@@ -815,7 +819,7 @@ Type SampledImageType::getImageType() const { return getImpl()->imageType; }
LogicalResult
SampledImageType::verify(function_ref<InFlightDiagnostic()> emitError,
Type imageType) {
- if (!imageType.isa<ImageType>())
+ if (!llvm::isa<ImageType>(imageType))
return emitError() << "expected image type";
return success();
@@ -824,13 +828,13 @@ SampledImageType::verify(function_ref<InFlightDiagnostic()> emitError,
void SampledImageType::getExtensions(
SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage) {
- getImageType().cast<ImageType>().getExtensions(extensions, storage);
+ llvm::cast<ImageType>(getImageType()).getExtensions(extensions, storage);
}
void SampledImageType::getCapabilities(
SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage) {
- getImageType().cast<ImageType>().getCapabilities(capabilities, storage);
+ llvm::cast<ImageType>(getImageType()).getCapabilities(capabilities, storage);
}
//===----------------------------------------------------------------------===//
@@ -1125,14 +1129,14 @@ StructType::trySetBody(ArrayRef<Type> memberTypes,
void StructType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage) {
for (Type elementType : getElementTypes())
- elementType.cast<SPIRVType>().getExtensions(extensions, storage);
+ llvm::cast<SPIRVType>(elementType).getExtensions(extensions, storage);
}
void StructType::getCapabilities(
SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage) {
for (Type elementType : getElementTypes())
- elementType.cast<SPIRVType>().getCapabilities(capabilities, storage);
+ llvm::cast<SPIRVType>(elementType).getCapabilities(capabilities, storage);
}
llvm::hash_code spirv::hash_value(
@@ -1186,7 +1190,7 @@ LogicalResult MatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
return emitError() << "matrix columns must be vectors of floats";
/// The underlying vectors (columns) must be of size 2, 3, or 4
- ArrayRef<int64_t> columnShape = columnType.cast<VectorType>().getShape();
+ ArrayRef<int64_t> columnShape = llvm::cast<VectorType>(columnType).getShape();
if (columnShape.size() != 1)
return emitError() << "matrix columns must be 1D vectors";
@@ -1198,8 +1202,8 @@ LogicalResult MatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
/// Returns true if the matrix elements are vectors of float elements
bool MatrixType::isValidColumnType(Type columnType) {
- if (auto vectorType = columnType.dyn_cast<VectorType>()) {
- if (vectorType.getElementType().isa<FloatType>())
+ if (auto vectorType = llvm::dyn_cast<VectorType>(columnType)) {
+ if (llvm::isa<FloatType>(vectorType.getElementType()))
return true;
}
return false;
@@ -1208,13 +1212,13 @@ bool MatrixType::isValidColumnType(Type columnType) {
Type MatrixType::getColumnType() const { return getImpl()->columnType; }
Type MatrixType::getElementType() const {
- return getImpl()->columnType.cast<VectorType>().getElementType();
+ return llvm::cast<VectorType>(getImpl()->columnType).getElementType();
}
unsigned MatrixType::getNumColumns() const { return getImpl()->columnCount; }
unsigned MatrixType::getNumRows() const {
- return getImpl()->columnType.cast<VectorType>().getShape()[0];
+ return llvm::cast<VectorType>(getImpl()->columnType).getShape()[0];
}
unsigned MatrixType::getNumElements() const {
@@ -1223,7 +1227,7 @@ unsigned MatrixType::getNumElements() const {
void MatrixType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage) {
- getColumnType().cast<SPIRVType>().getExtensions(extensions, storage);
+ llvm::cast<SPIRVType>(getColumnType()).getExtensions(extensions, storage);
}
void MatrixType::getCapabilities(
@@ -1235,7 +1239,7 @@ void MatrixType::getCapabilities(
capabilities.push_back(ref);
}
// Add any capabilities associated with the underlying vectors (i.e., columns)
- getColumnType().cast<SPIRVType>().getCapabilities(capabilities, storage);
+ llvm::cast<SPIRVType>(getColumnType()).getCapabilities(capabilities, storage);
}
//===----------------------------------------------------------------------===//
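Every rewritten call in the hunks above is spelled `llvm::cast<...>(...)` rather than bare `cast<...>(...)`. The commit message gives the reason. A bare name breaks wherever a variable named `cast` is in scope, or where the call sits inside a class that still defines a `cast` method, because unqualified lookup finds that name first. The minimal C++ sketch below models only the class-member case. `lib::cast` and `Wrapper` are hypothetical stand-ins, not LLVM's casting machinery.

```cpp
#include <cassert>

namespace lib {
// Hypothetical free-function cast, standing in for llvm::cast.
template <typename To, typename From>
To cast(const From &value) {
  return static_cast<To>(value);
}
} // namespace lib

// A class that still defines a member named `cast`, as MLIR's Type,
// Attribute, and Value classes did while both spellings coexisted.
struct Wrapper {
  double value;

  // Method form: w.cast<int>()
  template <typename To>
  To cast() const {
    return static_cast<To>(value);
  }

  // Inside the class, unqualified lookup of `cast` finds the member
  // Wrapper::cast first and stops. Argument-dependent lookup is also
  // suppressed because a class member was found. Writing `cast<int>(other)`
  // here would therefore try to call the zero-argument member and fail to
  // compile. Qualifying the name selects the free function instead.
  int truncate(double other) const { return lib::cast<int>(other); }
};
```

Because the qualified spelling sidesteps the member in every context, the commit can later drop the prefix with a plain find-and-replace once the methods are removed.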
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index d198b00b9caba..1a056a0d597cc 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -41,14 +41,14 @@ RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) {
}
bool shape::isExtentTensorType(Type type) {
- auto ranked = type.dyn_cast<RankedTensorType>();
+ auto ranked = llvm::dyn_cast<RankedTensorType>(type);
return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex();
}
LogicalResult shape::getShapeVec(Value input,
SmallVectorImpl<int64_t> &shapeValues) {
if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
- auto type = inputOp.getArg().getType().cast<ShapedType>();
+ auto type = llvm::cast<ShapedType>(inputOp.getArg().getType());
if (!type.hasRank())
return failure();
llvm::append_range(shapeValues, type.getShape());
@@ -64,7 +64,7 @@ LogicalResult shape::getShapeVec(Value input,
static bool isErrorPropagationPossible(TypeRange operandTypes) {
return llvm::any_of(operandTypes, [](Type ty) {
- return ty.isa<SizeType, ShapeType, ValueShapeType>();
+ return llvm::isa<SizeType, ShapeType, ValueShapeType>(ty);
});
}
@@ -72,7 +72,7 @@ static LogicalResult verifySizeOrIndexOp(Operation *op) {
assert(op != nullptr && op->getNumResults() == 1);
Type resultTy = op->getResultTypes().front();
if (isErrorPropagationPossible(op->getOperandTypes())) {
- if (!resultTy.isa<SizeType>())
+ if (!llvm::isa<SizeType>(resultTy))
return op->emitOpError()
<< "if at least one of the operands can hold error values then "
"the result must be of type `size` to propagate them";
@@ -84,7 +84,7 @@ static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
assert(op != nullptr && op->getNumResults() == 1);
Type resultTy = op->getResultTypes().front();
if (isErrorPropagationPossible(op->getOperandTypes())) {
- if (!resultTy.isa<ShapeType>())
+ if (!llvm::isa<ShapeType>(resultTy))
return op->emitOpError()
<< "if at least one of the operands can hold error values then "
"the result must be of type `shape` to propagate them";
@@ -94,7 +94,7 @@ static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
template <typename... Ty>
static bool eachHasOnlyOneOfTypes(TypeRange typeRange) {
- return typeRange.size() == 1 && typeRange.front().isa<Ty...>();
+ return typeRange.size() == 1 && llvm::isa<Ty...>(typeRange.front());
}
template <typename... Ty, typename... ranges>
@@ -147,13 +147,15 @@ void ShapeDialect::initialize() {
Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
- if (type.isa<ShapeType>() || isExtentTensorType(type))
- return builder.create<ConstShapeOp>(loc, type,
- value.cast<DenseIntElementsAttr>());
- if (type.isa<SizeType>())
- return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
- if (type.isa<WitnessType>())
- return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
+ if (llvm::isa<ShapeType>(type) || isExtentTensorType(type))
+ return builder.create<ConstShapeOp>(
+ loc, type, llvm::cast<DenseIntElementsAttr>(value));
+ if (llvm::isa<SizeType>(type))
+ return builder.create<ConstSizeOp>(loc, type,
+ llvm::cast<IntegerAttr>(value));
+ if (llvm::isa<WitnessType>(type))
+ return builder.create<ConstWitnessOp>(loc, type,
+ llvm::cast<BoolAttr>(value));
return arith::ConstantOp::materialize(builder, value, type, loc);
}
@@ -165,7 +167,7 @@ LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
return op->emitError(
"shape.lib attribute may only be on op implementing SymbolTable");
- if (auto symbolRef = attribute.getValue().dyn_cast<SymbolRefAttr>()) {
+ if (auto symbolRef = llvm::dyn_cast<SymbolRefAttr>(attribute.getValue())) {
auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef);
if (!symbol)
return op->emitError("shape function library ")
@@ -176,17 +178,17 @@ LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
<< symbolRef << " required to be shape function library";
}
- if (auto arr = attribute.getValue().dyn_cast<ArrayAttr>()) {
+ if (auto arr = llvm::dyn_cast<ArrayAttr>(attribute.getValue())) {
// Verify all entries are function libraries and mappings in libraries
// refer to unique ops.
DenseSet<StringAttr> key;
for (auto it : arr) {
- if (!it.isa<SymbolRefAttr>())
+ if (!llvm::isa<SymbolRefAttr>(it))
return op->emitError(
"only SymbolRefAttr allowed in shape.lib attribute array");
auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>(
- SymbolTable::lookupSymbolIn(op, it.cast<SymbolRefAttr>()));
+ SymbolTable::lookupSymbolIn(op, llvm::cast<SymbolRefAttr>(it)));
if (!shapeFnLib)
return op->emitError()
<< it << " does not refer to FunctionLibraryOp";
@@ -395,8 +397,8 @@ LogicalResult mlir::shape::AddOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
- if (operands[0].getType().isa<SizeType>() ||
- operands[1].getType().isa<SizeType>())
+ if (llvm::isa<SizeType>(operands[0].getType()) ||
+ llvm::isa<SizeType>(operands[1].getType()))
inferredReturnTypes.assign({SizeType::get(context)});
else
inferredReturnTypes.assign({IndexType::get(context)});
@@ -617,7 +619,7 @@ OpFoldResult AssumingAllOp::fold(FoldAdaptor adaptor) {
getOperation()->eraseOperand(idx);
// Always false if any input is statically known false
- if (!a.cast<BoolAttr>().getValue())
+ if (!llvm::cast<BoolAttr>(a).getValue())
return a;
}
// If this is reached, all inputs were statically known passing.
@@ -651,9 +653,11 @@ OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
if (!adaptor.getShapes()[0] || !adaptor.getShapes()[1])
return nullptr;
auto lhsShape = llvm::to_vector<6>(
- adaptor.getShapes()[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
+ llvm::cast<DenseIntElementsAttr>(adaptor.getShapes()[0])
+ .getValues<int64_t>());
auto rhsShape = llvm::to_vector<6>(
- adaptor.getShapes()[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
+ llvm::cast<DenseIntElementsAttr>(adaptor.getShapes()[1])
+ .getValues<int64_t>());
SmallVector<int64_t, 6> resultShape;
// If the shapes are not compatible, we can't fold it.
@@ -677,7 +681,8 @@ struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
auto isPotentiallyNonEmptyShape = [](Value shape) {
- if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
+ if (auto extentTensorTy =
+ llvm::dyn_cast<RankedTensorType>(shape.getType())) {
if (extentTensorTy.getDimSize(0) == 0)
return false;
}
@@ -714,11 +719,11 @@ struct BroadcastForwardSingleOperandPattern
// Insert cast if needed.
if (replacement.getType() != op.getType()) {
auto loc = op.getLoc();
- if (op.getType().isa<ShapeType>()) {
+ if (llvm::isa<ShapeType>(op.getType())) {
replacement = rewriter.create<FromExtentTensorOp>(loc, replacement);
} else {
- assert(!op.getType().isa<ShapeType>() &&
- !replacement.getType().isa<ShapeType>() &&
+ assert(!llvm::isa<ShapeType>(op.getType()) &&
+ !llvm::isa<ShapeType>(replacement.getType()) &&
"expect extent tensor cast");
replacement =
rewriter.create<tensor::CastOp>(loc, op.getType(), replacement);
@@ -781,7 +786,7 @@ struct CanonicalizeCastExtentTensorOperandsPattern
if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
// Only eliminate the cast if it holds no shape information.
bool isInformationLoosingCast =
- castOp.getType().cast<RankedTensorType>().isDynamicDim(0);
+ llvm::cast<RankedTensorType>(castOp.getType()).isDynamicDim(0);
if (isInformationLoosingCast) {
anyChange = true;
return castOp.getSource();
@@ -807,14 +812,15 @@ struct BroadcastConcretizeResultTypePattern
LogicalResult matchAndRewrite(BroadcastOp op,
PatternRewriter &rewriter) const override {
// Only concretize dynamic extent tensor result types.
- auto resultTy = op.getType().dyn_cast<RankedTensorType>();
+ auto resultTy = llvm::dyn_cast<RankedTensorType>(op.getType());
if (!resultTy || !resultTy.isDynamicDim(0))
return failure();
// Infer resulting shape rank if possible.
int64_t maxRank = 0;
for (Value shape : op.getShapes()) {
- if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
+ if (auto extentTensorTy =
+ llvm::dyn_cast<RankedTensorType>(shape.getType())) {
// Cannot infer resulting shape rank if any operand is dynamically
// ranked.
if (extentTensorTy.isDynamicDim(0))
@@ -883,12 +889,12 @@ ParseResult ConstShapeOp::parse(OpAsmParser &parser, OperationState &result) {
NamedAttrList dummy;
if (parser.parseAttribute(extentsRaw, "dummy", dummy))
return failure();
- auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>();
+ auto extentsArray = llvm::dyn_cast<ArrayAttr>(extentsRaw);
if (!extentsArray)
return failure();
SmallVector<int64_t, 6> ints;
for (Attribute extent : extentsArray) {
- IntegerAttr attr = extent.dyn_cast<IntegerAttr>();
+ IntegerAttr attr = llvm::dyn_cast<IntegerAttr>(extent);
if (!attr)
return failure();
ints.push_back(attr.getInt());
@@ -930,7 +936,7 @@ bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l,
Type lhs = l.front();
Type rhs = r.front();
- if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
+ if (llvm::isa<ShapeType>(lhs) || llvm::isa<ShapeType>(rhs))
// Shape type is compatible with all other valid return types.
return true;
return lhs == rhs;
@@ -956,7 +962,7 @@ void CstrBroadcastableOp::getCanonicalizationPatterns(
static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) {
bool nonScalarSeen = false;
for (Attribute a : attributes) {
- if (!a || a.cast<DenseIntElementsAttr>().getNumElements() != 0) {
+ if (!a || llvm::cast<DenseIntElementsAttr>(a).getNumElements() != 0) {
if (nonScalarSeen)
return false;
nonScalarSeen = true;
@@ -1070,13 +1076,13 @@ std::optional<int64_t> DimOp::getConstantIndex() {
if (auto constSizeOp = getIndex().getDefiningOp<ConstSizeOp>())
return constSizeOp.getValue().getLimitedValue();
if (auto constantOp = getIndex().getDefiningOp<arith::ConstantOp>())
- return constantOp.getValue().cast<IntegerAttr>().getInt();
+ return llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
return std::nullopt;
}
OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
Type valType = getValue().getType();
- auto valShapedType = valType.dyn_cast<ShapedType>();
+ auto valShapedType = llvm::dyn_cast<ShapedType>(valType);
if (!valShapedType || !valShapedType.hasRank())
return nullptr;
std::optional<int64_t> index = getConstantIndex();
@@ -1104,7 +1110,7 @@ bool mlir::shape::DimOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
}
LogicalResult mlir::shape::DimOp::verify() {
- auto st = getValue().getType().cast<ShapedType>();
+ auto st = llvm::cast<ShapedType>(getValue().getType());
if (!st.hasRank())
return success();
if (auto index = getConstantIndex()) {
@@ -1142,8 +1148,8 @@ LogicalResult mlir::shape::DivOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
- if (operands[0].getType().isa<SizeType>() ||
- operands[1].getType().isa<SizeType>())
+ if (llvm::isa<SizeType>(operands[0].getType()) ||
+ llvm::isa<SizeType>(operands[1].getType()))
inferredReturnTypes.assign({SizeType::get(context)});
else
inferredReturnTypes.assign({IndexType::get(context)});
@@ -1199,7 +1205,7 @@ OpFoldResult FromExtentsOp::fold(FoldAdaptor adaptor) {
return nullptr;
SmallVector<int64_t, 6> extents;
for (auto attr : adaptor.getExtents())
- extents.push_back(attr.cast<IntegerAttr>().getInt());
+ extents.push_back(llvm::cast<IntegerAttr>(attr).getInt());
Builder builder(getContext());
return builder.getIndexTensorAttr(extents);
}
@@ -1215,9 +1221,8 @@ void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
}
FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
- auto attr = getMapping()
- .get(op->getName().getIdentifier())
- .dyn_cast_or_null<FlatSymbolRefAttr>();
+ auto attr = llvm::dyn_cast_or_null<FlatSymbolRefAttr>(
+ getMapping().get(op->getName().getIdentifier()));
if (!attr)
return nullptr;
return lookupSymbol<FuncOp>(attr);
@@ -1329,7 +1334,7 @@ std::optional<int64_t> GetExtentOp::getConstantDim() {
if (auto constSizeOp = getDim().getDefiningOp<ConstSizeOp>())
return constSizeOp.getValue().getLimitedValue();
if (auto constantOp = getDim().getDefiningOp<arith::ConstantOp>())
- return constantOp.getValue().cast<IntegerAttr>().getInt();
+ return llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
return std::nullopt;
}
@@ -1349,7 +1354,7 @@ void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
int64_t dim) {
auto loc = result.location;
auto dimAttr = builder.getIndexAttr(dim);
- if (shape.getType().isa<ShapeType>()) {
+ if (llvm::isa<ShapeType>(shape.getType())) {
Value dim = builder.create<ConstSizeOp>(loc, dimAttr);
build(builder, result, builder.getType<SizeType>(), shape, dim);
} else {
@@ -1405,7 +1410,7 @@ LogicalResult mlir::shape::MeetOp::inferReturnTypes(
return failure();
auto isShapeType = [](Type arg) {
- if (arg.isa<ShapeType>())
+ if (llvm::isa<ShapeType>(arg))
return true;
return isExtentTensorType(arg);
};
@@ -1414,29 +1419,29 @@ LogicalResult mlir::shape::MeetOp::inferReturnTypes(
Type acc = types.front();
for (auto t : drop_begin(types)) {
Type l = acc, r = t;
- if (!l.isa<ShapeType, SizeType>())
+ if (!llvm::isa<ShapeType, SizeType>(l))
std::swap(l, r);
// Handle sizes, propagate error type if present.
- if (l.isa<SizeType>()) {
- if (r.isa<SizeType, IndexType>())
+ if (llvm::isa<SizeType>(l)) {
+ if (llvm::isa<SizeType, IndexType>(r))
acc = l;
else
return emitOptionalError(location, "requires all sizes or shapes");
- } else if (l.isa<IndexType>()) {
- if (r.isa<IndexType>())
+ } else if (llvm::isa<IndexType>(l)) {
+ if (llvm::isa<IndexType>(r))
acc = r;
else
return emitOptionalError(location, "requires all sizes or shapes");
- } else if (l.isa<ShapeType>()) {
+ } else if (llvm::isa<ShapeType>(l)) {
// Handle shapes, propagate error type if present.
if (isShapeType(r))
acc = l;
else
return emitOptionalError(location, "requires all sizes or shapes");
} else if (isExtentTensorType(l)) {
- auto rank1 = l.cast<RankedTensorType>().getShape()[0];
- auto rank2 = r.cast<RankedTensorType>().getShape()[0];
+ auto rank1 = llvm::cast<RankedTensorType>(l).getShape()[0];
+ auto rank2 = llvm::cast<RankedTensorType>(r).getShape()[0];
if (ShapedType::isDynamic(rank1))
acc = l;
else if (ShapedType::isDynamic(rank2))
@@ -1460,13 +1465,13 @@ bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
Type lhs = l.front();
Type rhs = r.front();
- if (!lhs.isa<ShapeType, SizeType>())
+ if (!llvm::isa<ShapeType, SizeType>(lhs))
std::swap(lhs, rhs);
- if (lhs.isa<SizeType>())
- return rhs.isa<SizeType, IndexType>();
- if (lhs.isa<ShapeType>())
- return rhs.isa<ShapeType, TensorType>();
+ if (llvm::isa<SizeType>(lhs))
+ return llvm::isa<SizeType, IndexType>(rhs);
+ if (llvm::isa<ShapeType>(lhs))
+ return llvm::isa<ShapeType, TensorType>(rhs);
if (succeeded(verifyCompatibleShapes({lhs, rhs})))
return true;
@@ -1511,14 +1516,14 @@ struct RankShapeOfCanonicalizationPattern
if (!shapeOfOp)
return failure();
auto rankedTensorType =
- shapeOfOp.getArg().getType().dyn_cast<RankedTensorType>();
+ llvm::dyn_cast<RankedTensorType>(shapeOfOp.getArg().getType());
if (!rankedTensorType)
return failure();
int64_t rank = rankedTensorType.getRank();
- if (op.getType().isa<IndexType>()) {
+ if (llvm::isa<IndexType>(op.getType())) {
rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op.getOperation(),
rank);
- } else if (op.getType().isa<shape::SizeType>()) {
+ } else if (llvm::isa<shape::SizeType>(op.getType())) {
rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
} else {
return failure();
@@ -1537,7 +1542,7 @@ LogicalResult mlir::shape::RankOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
- if (operands[0].getType().isa<ShapeType>())
+ if (llvm::isa<ShapeType>(operands[0].getType()))
inferredReturnTypes.assign({SizeType::get(context)});
else
inferredReturnTypes.assign({IndexType::get(context)});
@@ -1563,7 +1568,7 @@ OpFoldResult NumElementsOp::fold(FoldAdaptor adaptor) {
return {};
APInt product(64, 1);
- for (auto value : shape.cast<DenseIntElementsAttr>())
+ for (auto value : llvm::cast<DenseIntElementsAttr>(shape))
product *= value;
Builder builder(getContext());
return builder.getIndexAttr(product.getLimitedValue());
@@ -1573,7 +1578,7 @@ LogicalResult mlir::shape::NumElementsOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
- if (operands[0].getType().isa<ShapeType>())
+ if (llvm::isa<ShapeType>(operands[0].getType()))
inferredReturnTypes.assign({SizeType::get(context)});
else
inferredReturnTypes.assign({IndexType::get(context)});
@@ -1615,9 +1620,9 @@ LogicalResult mlir::shape::MaxOp::inferReturnTypes(
bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
if (l.size() != 1 || r.size() != 1)
return false;
- if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>())
+ if (llvm::isa<ShapeType>(l.front()) && llvm::isa<ShapeType>(r.front()))
return true;
- if (l.front().isa<SizeType>() && r.front().isa<SizeType>())
+ if (llvm::isa<SizeType>(l.front()) && llvm::isa<SizeType>(r.front()))
return true;
return false;
}
@@ -1647,9 +1652,9 @@ LogicalResult mlir::shape::MinOp::inferReturnTypes(
bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
if (l.size() != 1 || r.size() != 1)
return false;
- if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>())
+ if (llvm::isa<ShapeType>(l.front()) && llvm::isa<ShapeType>(r.front()))
return true;
- if (l.front().isa<SizeType>() && r.front().isa<SizeType>())
+ if (llvm::isa<SizeType>(l.front()) && llvm::isa<SizeType>(r.front()))
return true;
return false;
}
@@ -1674,8 +1679,8 @@ LogicalResult mlir::shape::MulOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
- if (operands[0].getType().isa<SizeType>() ||
- operands[1].getType().isa<SizeType>())
+ if (llvm::isa<SizeType>(operands[0].getType()) ||
+ llvm::isa<SizeType>(operands[1].getType()))
inferredReturnTypes.assign({SizeType::get(context)});
else
inferredReturnTypes.assign({IndexType::get(context)});
@@ -1694,7 +1699,7 @@ LogicalResult shape::MulOp::verify() { return verifySizeOrIndexOp(*this); }
//===----------------------------------------------------------------------===//
OpFoldResult ShapeOfOp::fold(FoldAdaptor) {
- auto type = getOperand().getType().dyn_cast<ShapedType>();
+ auto type = llvm::dyn_cast<ShapedType>(getOperand().getType());
if (!type || !type.hasStaticShape())
return nullptr;
Builder builder(getContext());
@@ -1707,9 +1712,9 @@ struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
LogicalResult matchAndRewrite(shape::ShapeOfOp op,
PatternRewriter &rewriter) const override {
- if (!op.getArg().getType().isa<ShapedType>())
+ if (!llvm::isa<ShapedType>(op.getArg().getType()))
return failure();
- if (op.getType().isa<ShapedType>())
+ if (llvm::isa<ShapedType>(op.getType()))
return failure();
rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(),
@@ -1732,7 +1737,7 @@ struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
LogicalResult matchAndRewrite(tensor::CastOp op,
PatternRewriter &rewriter) const override {
- auto ty = op.getType().dyn_cast<RankedTensorType>();
+ auto ty = llvm::dyn_cast<RankedTensorType>(op.getType());
if (!ty || ty.getRank() != 1)
return failure();
@@ -1741,7 +1746,7 @@ struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
return failure();
// Argument type must be ranked and must not conflict.
- auto argTy = shapeOfOp.getArg().getType().dyn_cast<RankedTensorType>();
+ auto argTy = llvm::dyn_cast<RankedTensorType>(shapeOfOp.getArg().getType());
if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank()))
return failure();
@@ -1761,10 +1766,10 @@ LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
- if (operands[0].getType().isa<ValueShapeType>())
+ if (llvm::isa<ValueShapeType>(operands[0].getType()))
inferredReturnTypes.assign({ShapeType::get(context)});
else {
- auto shapedTy = operands[0].getType().cast<ShapedType>();
+ auto shapedTy = llvm::cast<ShapedType>(operands[0].getType());
int64_t rank =
shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamic;
Type indexTy = IndexType::get(context);
@@ -1783,10 +1788,11 @@ bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
Type lhs = l.front();
Type rhs = r.front();
- if (!lhs.isa<ShapeType, ShapedType>() || !rhs.isa<ShapeType, ShapedType>())
+ if (!llvm::isa<ShapeType, ShapedType>(lhs) ||
+ !llvm::isa<ShapeType, ShapedType>(rhs))
return false;
- if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
+ if (llvm::isa<ShapeType>(lhs) || llvm::isa<ShapeType>(rhs))
// Shape type is compatible with all other valid return types.
return true;
@@ -1819,7 +1825,8 @@ void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
bool SizeToIndexOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
if (inputs.size() != 1 || outputs.size() != 1)
return false;
- return inputs[0].isa<IndexType, SizeType>() && outputs[0].isa<IndexType>();
+ return llvm::isa<IndexType, SizeType>(inputs[0]) &&
+ llvm::isa<IndexType>(outputs[0]);
}
//===----------------------------------------------------------------------===//
@@ -1884,16 +1891,16 @@ OpFoldResult ToExtentTensorOp::fold(FoldAdaptor adaptor) {
bool ToExtentTensorOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
if (inputs.size() != 1 || outputs.size() != 1)
return false;
- if (auto inputTensor = inputs[0].dyn_cast<RankedTensorType>()) {
- if (!inputTensor.getElementType().isa<IndexType>() ||
+ if (auto inputTensor = llvm::dyn_cast<RankedTensorType>(inputs[0])) {
+ if (!llvm::isa<IndexType>(inputTensor.getElementType()) ||
inputTensor.getRank() != 1)
return false;
- } else if (!inputs[0].isa<ShapeType>()) {
+ } else if (!llvm::isa<ShapeType>(inputs[0])) {
return false;
}
- TensorType outputTensor = outputs[0].dyn_cast<TensorType>();
- return outputTensor && outputTensor.getElementType().isa<IndexType>();
+ TensorType outputTensor = llvm::dyn_cast<TensorType>(outputs[0]);
+ return outputTensor && llvm::isa<IndexType>(outputTensor.getElementType());
}
//===----------------------------------------------------------------------===//
@@ -1911,7 +1918,7 @@ void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
bodyBlock.addArgument(builder.getIndexType(), result.location);
Type elementType;
- if (auto tensorType = shape.getType().dyn_cast<TensorType>())
+ if (auto tensorType = llvm::dyn_cast<TensorType>(shape.getType()))
elementType = tensorType.getElementType();
else
elementType = SizeType::get(builder.getContext());
@@ -1934,7 +1941,7 @@ LogicalResult ReduceOp::verify() {
<< blockArgsCount << " arguments";
// The first block argument is the index and must always be of type `index`.
- if (!block.getArgument(0).getType().isa<IndexType>())
+ if (!llvm::isa<IndexType>(block.getArgument(0).getType()))
return emitOpError(
"argument 0 of ReduceOp body is expected to be of IndexType");
@@ -1942,12 +1949,12 @@ LogicalResult ReduceOp::verify() {
// `index`, depending on whether the reduce operation is applied to a shape or
// to an extent tensor.
Type extentTy = block.getArgument(1).getType();
- if (getShape().getType().isa<ShapeType>()) {
- if (!extentTy.isa<SizeType>())
+ if (llvm::isa<ShapeType>(getShape().getType())) {
+ if (!llvm::isa<SizeType>(extentTy))
return emitOpError("argument 1 of ReduceOp body is expected to be of "
"SizeType if the ReduceOp operates on a ShapeType");
} else {
- if (!extentTy.isa<IndexType>())
+ if (!llvm::isa<IndexType>(extentTy))
return emitOpError(
"argument 1 of ReduceOp body is expected to be of IndexType if the "
"ReduceOp operates on an extent tensor");
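Shape.cpp's predicates, such as `llvm::isa<SizeType, ShapeType, ValueShapeType>(ty)`, test a generic `Type` handle against one or more concrete type classes. Roughly, each check ends up consulting the target class's static `classof` hook, the same kind of hook defined as `ScalarType::classof` and `SPIRVType::classof` in SPIRVTypes.cpp. The sketch below is a simplified, self-contained model of that pattern. It assumes value-semantic handles that wrap a storage pointer. The `sketch` namespace, `Kind`, `TypeStorage`, and the three concrete types are hypothetical, not MLIR's or LLVM's actual implementation.

```cpp
#include <cassert>

// Hypothetical, simplified model of LLVM-style casting on value-semantic
// handles. Not the actual MLIR/LLVM implementation.
namespace sketch {

enum class Kind { Integer, Float, Vector };

// Storage that a handle points at (MLIR keeps uniqued storage in its context).
struct TypeStorage {
  Kind kind;
};

// Base handle: cheap to copy, may be null.
class Type {
public:
  Type() = default;
  explicit Type(const TypeStorage *storage) : impl(storage) {}
  explicit operator bool() const { return impl != nullptr; }
  Kind getKind() const { return impl->kind; }
  const TypeStorage *getImpl() const { return impl; }

private:
  const TypeStorage *impl = nullptr;
};

// Each derived handle answers "is this generic Type one of mine?" via classof.
class IntegerType : public Type {
public:
  using Type::Type;
  static bool classof(Type t) { return t.getKind() == Kind::Integer; }
};

class FloatType : public Type {
public:
  using Type::Type;
  static bool classof(Type t) { return t.getKind() == Kind::Float; }
};

class VectorType : public Type {
public:
  using Type::Type;
  static bool classof(Type t) { return t.getKind() == Kind::Vector; }
};

// isa<A, B, ...>(t): true if any listed class's classof accepts t.
template <typename... To>
bool isa(Type t) {
  assert(t && "isa<> used on a null type");
  return (To::classof(t) || ...);
}

// cast<To>(t): checked conversion that asserts on a mismatch.
template <typename To>
To cast(Type t) {
  assert(isa<To>(t) && "cast<> argument of incompatible type");
  return To(t.getImpl());
}

// dyn_cast<To>(t): returns a null handle instead of asserting.
template <typename To>
To dyn_cast(Type t) {
  return isa<To>(t) ? To(t.getImpl()) : To();
}

} // namespace sketch
```

The free functions need nothing from the handle except `classof` and a way to rewrap its storage. That is why one set of templates can serve many handle classes without each class redefining `cast`, `isa`, and `dyn_cast` as methods.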
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 4b98030a9dbfa..2def7ccfba946 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -261,10 +261,10 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
if (attrName == "dimLevelType") {
Attribute attr;
RETURN_ON_FAIL(parser.parseAttribute(attr));
- auto arrayAttr = attr.dyn_cast<ArrayAttr>();
+ auto arrayAttr = llvm::dyn_cast<ArrayAttr>(attr);
ERROR_IF(!arrayAttr, "expected an array for dimension level types")
for (auto i : arrayAttr) {
- auto strAttr = i.dyn_cast<StringAttr>();
+ auto strAttr = llvm::dyn_cast<StringAttr>(i);
ERROR_IF(!strAttr, "expected a string value in dimension level types")
auto strVal = strAttr.getValue();
if (auto optDLT = parseDLT(strVal)) {
@@ -279,25 +279,25 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
} else if (attrName == "dimOrdering") {
Attribute attr;
RETURN_ON_FAIL(parser.parseAttribute(attr))
- auto affineAttr = attr.dyn_cast<AffineMapAttr>();
+ auto affineAttr = llvm::dyn_cast<AffineMapAttr>(attr);
ERROR_IF(!affineAttr, "expected an affine map for dimension ordering")
dimOrd = affineAttr.getValue();
} else if (attrName == "higherOrdering") {
Attribute attr;
RETURN_ON_FAIL(parser.parseAttribute(attr))
- auto affineAttr = attr.dyn_cast<AffineMapAttr>();
+ auto affineAttr = llvm::dyn_cast<AffineMapAttr>(attr);
ERROR_IF(!affineAttr, "expected an affine map for higher ordering")
higherOrd = affineAttr.getValue();
} else if (attrName == "posWidth") {
Attribute attr;
RETURN_ON_FAIL(parser.parseAttribute(attr))
- auto intAttr = attr.dyn_cast<IntegerAttr>();
+ auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
ERROR_IF(!intAttr, "expected an integral position bitwidth")
posWidth = intAttr.getInt();
} else if (attrName == "crdWidth") {
Attribute attr;
RETURN_ON_FAIL(parser.parseAttribute(attr))
- auto intAttr = attr.dyn_cast<IntegerAttr>();
+ auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
ERROR_IF(!intAttr, "expected an integral index bitwidth")
crdWidth = intAttr.getInt();
} else if (attrName == "slice") {
@@ -305,7 +305,7 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
// Dispatches to DimSliceAttr to skip mnemonic
bool finished = false;
while (auto attr = SparseTensorDimSliceAttr::parse(parser, nullptr)) {
- auto sliceAttr = attr.cast<SparseTensorDimSliceAttr>();
+ auto sliceAttr = llvm::cast<SparseTensorDimSliceAttr>(attr);
slices.push_back(sliceAttr);
if (parser.parseOptionalComma().failed()) {
finished = true;
@@ -442,9 +442,9 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
SparseTensorEncodingAttr
mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
- if (auto ttp = type.dyn_cast<RankedTensorType>())
- return ttp.getEncoding().dyn_cast_or_null<SparseTensorEncodingAttr>();
- if (auto mdtp = type.dyn_cast<StorageSpecifierType>())
+ if (auto ttp = llvm::dyn_cast<RankedTensorType>(type))
+ return llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(ttp.getEncoding());
+ if (auto mdtp = llvm::dyn_cast<StorageSpecifierType>(type))
return mdtp.getEncoding();
return nullptr;
}
@@ -725,12 +725,12 @@ unsigned UnpackOp::getNumBatchedLvls() {
}
LogicalResult ConvertOp::verify() {
- if (auto tp1 = getSource().getType().dyn_cast<RankedTensorType>()) {
- if (auto tp2 = getDest().getType().dyn_cast<RankedTensorType>()) {
+ if (auto tp1 = llvm::dyn_cast<RankedTensorType>(getSource().getType())) {
+ if (auto tp2 = llvm::dyn_cast<RankedTensorType>(getDest().getType())) {
if (tp1.getRank() != tp2.getRank())
return emitError("unexpected conversion mismatch in rank");
auto dstEnc =
- tp2.getEncoding().dyn_cast_or_null<SparseTensorEncodingAttr>();
+ llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding());
if (dstEnc && dstEnc.isSlice())
return emitError("cannot convert to a sparse tensor slice");
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
index defef6608edb7..20fc678773be0 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
@@ -112,7 +112,7 @@ static Value genVectorLoad(PatternRewriter &rewriter, Location loc, VL vl,
Value mem, ArrayRef<Value> idxs, Value vmask) {
VectorType vtp = vectorType(vl, mem);
Value pass = constantZero(rewriter, loc, vtp);
- if (idxs.back().getType().isa<VectorType>()) {
+ if (llvm::isa<VectorType>(idxs.back().getType())) {
SmallVector<Value> scalarArgs(idxs.begin(), idxs.end());
Value indexVec = idxs.back();
scalarArgs.back() = constantIndex(rewriter, loc, 0);
@@ -129,7 +129,7 @@ static Value genVectorLoad(PatternRewriter &rewriter, Location loc, VL vl,
/// the last index, i.e. back().
static void genVectorStore(PatternRewriter &rewriter, Location loc, Value mem,
ArrayRef<Value> idxs, Value vmask, Value rhs) {
- if (idxs.back().getType().isa<VectorType>()) {
+ if (llvm::isa<VectorType>(idxs.back().getType())) {
SmallVector<Value> scalarArgs(idxs.begin(), idxs.end());
Value indexVec = idxs.back();
scalarArgs.back() = constantIndex(rewriter, loc, 0);
@@ -260,7 +260,7 @@ static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp,
// innermost loop simply pass through as well.
// Example:
// a[i][j] for both i and j
- if (auto arg = sub.dyn_cast<BlockArgument>()) {
+ if (auto arg = llvm::dyn_cast<BlockArgument>(sub)) {
if (isInvariantArg(arg, block) == innermost)
return false;
if (codegen)
@@ -298,8 +298,8 @@ static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp,
Location loc = forOp.getLoc();
Value vload =
genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs2, vmask);
- Type etp = vload.getType().cast<VectorType>().getElementType();
- if (!etp.isa<IndexType>()) {
+ Type etp = llvm::cast<VectorType>(vload.getType()).getElementType();
+ if (!llvm::isa<IndexType>(etp)) {
if (etp.getIntOrFloatBitWidth() < 32)
vload = rewriter.create<arith::ExtUIOp>(
loc, vectorType(vl, rewriter.getI32Type()), vload);
@@ -318,7 +318,7 @@ static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp,
Value inv = load.getOperand(0);
Value idx = load.getOperand(1);
if (isInvariantValue(inv, block)) {
- if (auto arg = idx.dyn_cast<BlockArgument>()) {
+ if (auto arg = llvm::dyn_cast<BlockArgument>(idx)) {
if (isInvariantArg(arg, block) || !innermost)
return false;
if (codegen)
@@ -369,7 +369,7 @@ static bool vectorizeExpr(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
if (!VectorType::isValidElementType(exp.getType()))
return false;
// A block argument is invariant/reduction/index.
- if (auto arg = exp.dyn_cast<BlockArgument>()) {
+ if (auto arg = llvm::dyn_cast<BlockArgument>(exp)) {
if (arg == forOp.getInductionVar()) {
// We encountered a single, innermost index inside the computation,
// such as a[i] = i, which must convert to [i, i+1, ...].
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
index f30423de62b40..8a52082dfca5f 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
@@ -130,7 +130,8 @@ getReshapeOutputShapeFromInputShape(OpBuilder &builder, Location loc, Value src,
ArrayRef<int64_t> dstStaticShape,
ArrayRef<AffineMap> reassocation) {
return dstStaticShape.size() >
- static_cast<size_t>(src.getType().cast<ShapedType>().getRank())
+ static_cast<size_t>(
+ llvm::cast<ShapedType>(src.getType()).getRank())
? getExpandedOutputShapeFromInputShape(
builder, loc, src, dstStaticShape, reassocation)
: getCollapsedOutputShapeFromInputShape(
@@ -185,7 +186,7 @@ struct ReifyPadOp
return;
}
int64_t staticValue =
- valueOrAttr.get<Attribute>().cast<IntegerAttr>().getInt();
+ llvm::cast<IntegerAttr>(valueOrAttr.get<Attribute>()).getInt();
expr = expr + staticValue;
};
addOpFoldResult(lowPad[dim]);
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 74a21771d034e..eab64b5cf9994 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -42,13 +42,13 @@ Operation *TensorDialect::materializeConstant(OpBuilder &builder,
return op;
if (complex::ConstantOp::isBuildableWith(value, type))
return builder.create<complex::ConstantOp>(loc, type,
- value.cast<ArrayAttr>());
+ llvm::cast<ArrayAttr>(value));
return nullptr;
}
SmallVector<OpFoldResult> tensor::getMixedSizes(OpBuilder &builder,
Location loc, Value value) {
- auto tensorType = value.getType().cast<RankedTensorType>();
+ auto tensorType = llvm::cast<RankedTensorType>(value.getType());
SmallVector<OpFoldResult> result;
for (int64_t i = 0; i < tensorType.getRank(); ++i) {
if (tensorType.isDynamicDim(i)) {
@@ -63,7 +63,7 @@ SmallVector<OpFoldResult> tensor::getMixedSizes(OpBuilder &builder,
FailureOr<Value> tensor::getOrCreateDestination(OpBuilder &b, Location loc,
OpResult opResult) {
- auto tensorType = opResult.getType().dyn_cast<TensorType>();
+ auto tensorType = llvm::dyn_cast<TensorType>(opResult.getType());
assert(tensorType && "expected tensor type");
// If the op has a destination, it implements DestinationStyleOpInterface and
@@ -100,7 +100,7 @@ LogicalResult tensor::getOrCreateDestinations(OpBuilder &b, Location loc,
Operation *op,
SmallVector<Value> &result) {
for (OpResult opResult : op->getResults()) {
- if (opResult.getType().isa<TensorType>()) {
+ if (llvm::isa<TensorType>(opResult.getType())) {
FailureOr<Value> destination = getOrCreateDestination(b, loc, opResult);
if (failed(destination))
return failure();
@@ -111,8 +111,8 @@ LogicalResult tensor::getOrCreateDestinations(OpBuilder &b, Location loc,
}
bool tensor::isSameTypeWithoutEncoding(Type tp1, Type tp2) {
- if (auto rtp1 = tp1.dyn_cast<RankedTensorType>()) {
- if (auto rtp2 = tp2.dyn_cast<RankedTensorType>())
+ if (auto rtp1 = llvm::dyn_cast<RankedTensorType>(tp1)) {
+ if (auto rtp2 = llvm::dyn_cast<RankedTensorType>(tp2))
return rtp1.getShape() == rtp2.getShape() &&
rtp1.getElementType() == rtp2.getElementType();
return false;
@@ -131,7 +131,7 @@ static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape,
// Rank-reduced dims must have a static unit dimension.
bool isStaticUnitSize =
size.value().is<Attribute>() &&
- size.value().get<Attribute>().cast<IntegerAttr>().getInt() == 1;
+ llvm::cast<IntegerAttr>(size.value().get<Attribute>()).getInt() == 1;
if (shapePos == static_cast<int64_t>(reducedShape.size())) {
// There are no more dims in the reduced shape. All remaining sizes must
@@ -220,8 +220,8 @@ void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
/// Returns true if `target` is a ranked tensor type that preserves static
/// information available in the `source` ranked tensor type.
bool mlir::tensor::preservesStaticInformation(Type source, Type target) {
- auto sourceType = source.dyn_cast<RankedTensorType>();
- auto targetType = target.dyn_cast<RankedTensorType>();
+ auto sourceType = llvm::dyn_cast<RankedTensorType>(source);
+ auto targetType = llvm::dyn_cast<RankedTensorType>(target);
// Requires RankedTensorType.
if (!sourceType || !targetType)
@@ -322,8 +322,8 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
if (inputs.size() != 1 || outputs.size() != 1)
return false;
Type a = inputs.front(), b = outputs.front();
- auto aT = a.dyn_cast<TensorType>();
- auto bT = b.dyn_cast<TensorType>();
+ auto aT = llvm::dyn_cast<TensorType>(a);
+ auto bT = llvm::dyn_cast<TensorType>(b);
if (!aT || !bT)
return false;
@@ -380,9 +380,9 @@ struct ChainedTensorCast : public OpRewritePattern<CastOp> {
return failure();
auto sourceType =
- tensorCastOperand.getOperand().getType().cast<TensorType>();
- auto intermediateType = tensorCastOperand.getType().cast<TensorType>();
- auto resultType = tensorCast.getType().cast<TensorType>();
+ llvm::cast<TensorType>(tensorCastOperand.getOperand().getType());
+ auto intermediateType = llvm::cast<TensorType>(tensorCastOperand.getType());
+ auto resultType = llvm::cast<TensorType>(tensorCast.getType());
// We can remove the intermediate cast if joining all three produces the
// same result as just joining the source and result shapes.
@@ -427,15 +427,15 @@ struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();
// Cannot fold cast to unranked tensor.
- auto rankedResultType = tensorCast.getType().dyn_cast<RankedTensorType>();
+ auto rankedResultType =
+ llvm::dyn_cast<RankedTensorType>(tensorCast.getType());
if (!rankedResultType)
return failure();
if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
- rankedResultType.getShape() == tensorCast.getSource()
- .getType()
- .cast<RankedTensorType>()
- .getShape())
+ rankedResultType.getShape() ==
+ llvm::cast<RankedTensorType>(tensorCast.getSource().getType())
+ .getShape())
return failure();
SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
@@ -506,7 +506,7 @@ OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
return {};
// Folding for unranked types (UnrankedTensorType) is not supported.
- auto tensorType = getSource().getType().dyn_cast<RankedTensorType>();
+ auto tensorType = llvm::dyn_cast<RankedTensorType>(getSource().getType());
if (!tensorType)
return {};
@@ -527,7 +527,7 @@ OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
// Fold dim to the operand of tensor.generate.
if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
auto resultType =
- fromElements.getResult().getType().cast<RankedTensorType>();
+ llvm::cast<RankedTensorType>(fromElements.getResult().getType());
// The case where the type encodes the size of the dimension is handled
// above.
assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));
@@ -751,7 +751,8 @@ struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
if (!producer)
return failure();
- auto resultType = castOp->getResult(0).getType().cast<RankedTensorType>();
+ auto resultType =
+ llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
ArrayRef<int64_t> resultShape = resultType.getShape();
SmallVector<OpFoldResult> currMixedSizes = producer.getMixedSizes();
SmallVector<OpFoldResult> newMixedSizes;
@@ -765,7 +766,7 @@ struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
// result dim matches.
if (auto attr = currDim.dyn_cast<Attribute>()) {
if (ShapedType::isDynamic(newDim) ||
- newDim != attr.cast<IntegerAttr>().getInt()) {
+ newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
// Something is off, the cast result shape cannot be more dynamic
// than the empty tensor result shape (enforced by
// `canFoldIntoProducer`). Abort for now.
@@ -826,7 +827,7 @@ struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
if (!tensorCast)
return failure();
- if (!tensorCast.getSource().getType().isa<RankedTensorType>())
+ if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType()))
return failure();
rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
extract, tensorCast.getSource(), extract.getIndices());
@@ -843,7 +844,7 @@ void ExtractOp::getAsmResultNames(
LogicalResult ExtractOp::verify() {
// Verify the # indices match if we have a ranked type.
- auto tensorType = getTensor().getType().cast<RankedTensorType>();
+ auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
return emitOpError("incorrect number of indices for extract_element");
return success();
@@ -853,20 +854,20 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
// If this is a splat elements attribute, simply return the value. All of
// the elements of a splat attribute are the same.
if (Attribute tensor = adaptor.getTensor())
- if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>())
+ if (auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
return splatTensor.getSplatValue<Attribute>();
// Collect the constant indices into the tensor.
SmallVector<uint64_t, 8> indices;
for (Attribute indice : adaptor.getIndices()) {
- if (!indice || !indice.isa<IntegerAttr>())
+ if (!indice || !llvm::isa<IntegerAttr>(indice))
return {};
- indices.push_back(indice.cast<IntegerAttr>().getInt());
+ indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
}
// Fold extract(from_elements(...)).
if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
- auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
+ auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
auto rank = tensorType.getRank();
assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
"rank mismatch");
@@ -887,7 +888,7 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
// If this is an elements attribute, query the value at the given indices.
if (Attribute tensor = adaptor.getTensor()) {
- auto elementsAttr = tensor.dyn_cast<ElementsAttr>();
+ auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
if (elementsAttr && elementsAttr.isValidIndex(indices))
return elementsAttr.getValues<Attribute>()[indices];
}
@@ -1070,7 +1071,7 @@ void InsertOp::getAsmResultNames(
LogicalResult InsertOp::verify() {
// Verify the # indices match if we have a ranked type.
- auto destType = getDest().getType().cast<RankedTensorType>();
+ auto destType = llvm::cast<RankedTensorType>(getDest().getType());
if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
return emitOpError("incorrect number of indices");
return success();
@@ -1080,7 +1081,7 @@ OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
Attribute scalar = adaptor.getScalar();
Attribute dest = adaptor.getDest();
if (scalar && dest)
- if (auto splatDest = dest.dyn_cast<SplatElementsAttr>())
+ if (auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
if (scalar == splatDest.getSplatValue<Attribute>())
return dest;
return {};
@@ -1113,7 +1114,7 @@ LogicalResult GenerateOp::reifyResultShapes(
LogicalResult GenerateOp::verify() {
// Ensure that the tensor type has as many dynamic dimensions as are
// specified by the operands.
- RankedTensorType resultTy = getType().cast<RankedTensorType>();
+ RankedTensorType resultTy = llvm::cast<RankedTensorType>(getType());
if (getNumOperands() != resultTy.getNumDynamicDims())
return emitError("must have as many index operands as dynamic extents "
"in the result type");
@@ -1122,7 +1123,7 @@ LogicalResult GenerateOp::verify() {
}
LogicalResult GenerateOp::verifyRegions() {
- RankedTensorType resultTy = getType().cast<RankedTensorType>();
+ RankedTensorType resultTy = llvm::cast<RankedTensorType>(getType());
// Ensure that region arguments span the index space.
if (!llvm::all_of(getBody().getArgumentTypes(),
[](Type ty) { return ty.isIndex(); }))
@@ -1150,7 +1151,7 @@ void GenerateOp::build(
// Build and populate body.
OpBuilder::InsertionGuard guard(b);
Region *bodyRegion = result.regions.front().get();
- auto rank = resultTy.cast<RankedTensorType>().getRank();
+ auto rank = llvm::cast<RankedTensorType>(resultTy).getRank();
SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
SmallVector<Location, 2> argumentLocs(rank, result.location);
Block *bodyBlock =
@@ -1170,7 +1171,7 @@ struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
LogicalResult matchAndRewrite(GenerateOp tensorFromElements,
PatternRewriter &rewriter) const final {
auto resultType =
- tensorFromElements.getResult().getType().cast<RankedTensorType>();
+ llvm::cast<RankedTensorType>(tensorFromElements.getResult().getType());
if (resultType.hasStaticShape())
return failure();
@@ -1261,7 +1262,7 @@ void RankOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
// Constant fold rank when the rank of the operand is known.
auto type = getOperand().getType();
- auto shapedType = type.dyn_cast<ShapedType>();
+ auto shapedType = llvm::dyn_cast<ShapedType>(type);
if (shapedType && shapedType.hasRank())
return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
return IntegerAttr();
@@ -1284,17 +1285,17 @@ static int64_t getNumElements(ShapedType type) {
}
LogicalResult ReshapeOp::verify() {
- TensorType operandType = getSource().getType().cast<TensorType>();
- TensorType resultType = getResult().getType().cast<TensorType>();
+ TensorType operandType = llvm::cast<TensorType>(getSource().getType());
+ TensorType resultType = llvm::cast<TensorType>(getResult().getType());
if (operandType.getElementType() != resultType.getElementType())
return emitOpError("element types of source and destination tensor "
"types should be the same");
int64_t shapeSize =
- getShape().getType().cast<RankedTensorType>().getDimSize(0);
- auto resultRankedType = resultType.dyn_cast<RankedTensorType>();
- auto operandRankedType = operandType.dyn_cast<RankedTensorType>();
+ llvm::cast<RankedTensorType>(getShape().getType()).getDimSize(0);
+ auto resultRankedType = llvm::dyn_cast<RankedTensorType>(resultType);
+ auto operandRankedType = llvm::dyn_cast<RankedTensorType>(operandType);
if (resultRankedType) {
if (operandRankedType && resultRankedType.hasStaticShape() &&
@@ -1392,7 +1393,7 @@ void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
ArrayRef<ReassociationIndices> reassociation,
ArrayRef<NamedAttribute> attrs) {
auto resultType = inferCollapsedType(
- src.getType().cast<RankedTensorType>(),
+ llvm::cast<RankedTensorType>(src.getType()),
getSymbolLessAffineMaps(
convertReassociationIndicesToExprs(b.getContext(), reassociation)));
build(b, result, resultType, src, attrs);
@@ -1488,7 +1489,7 @@ struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
if (!fromElements)
return failure();
- auto shapedTy = reshapeOp.getType().template cast<ShapedType>();
+ auto shapedTy = llvm::cast<ShapedType>(reshapeOp.getType());
if (!shapedTy.hasStaticShape())
return failure();
@@ -1510,7 +1511,7 @@ struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
return failure();
RankedTensorType srcType =
- castOp.getSource().getType().cast<RankedTensorType>();
+ llvm::cast<RankedTensorType>(castOp.getSource().getType());
RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
srcType, collapseShapeOp.getReassociationMaps());
@@ -1693,9 +1694,8 @@ RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
ArrayRef<int64_t> strides) {
// Type inferred in the absence of rank-reducing behavior.
- auto inferredType =
- inferResultType(sourceRankedTensorType, offsets, sizes, strides)
- .cast<RankedTensorType>();
+ auto inferredType = llvm::cast<RankedTensorType>(
+ inferResultType(sourceRankedTensorType, offsets, sizes, strides));
int rankDiff = inferredType.getRank() - desiredResultRank;
if (rankDiff > 0) {
auto shape = inferredType.getShape();
@@ -1739,13 +1739,11 @@ void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
- auto sourceRankedTensorType = source.getType().cast<RankedTensorType>();
+ auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType());
// Structuring implementation this way avoids duplication between builders.
if (!resultType) {
- resultType =
- ExtractSliceOp::inferResultType(sourceRankedTensorType, staticOffsets,
- staticSizes, staticStrides)
- .cast<RankedTensorType>();
+ resultType = llvm::cast<RankedTensorType>(ExtractSliceOp::inferResultType(
+ sourceRankedTensorType, staticOffsets, staticSizes, staticStrides));
}
build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
@@ -1831,7 +1829,7 @@ llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
FailureOr<Value>
ExtractSliceOp::rankReduceIfNeeded(OpBuilder &b, Location loc, Value value,
ArrayRef<int64_t> desiredShape) {
- auto sourceTensorType = value.getType().dyn_cast<RankedTensorType>();
+ auto sourceTensorType = llvm::dyn_cast<RankedTensorType>(value.getType());
assert(sourceTensorType && "not a ranked tensor type");
auto sourceShape = sourceTensorType.getShape();
if (sourceShape.equals(desiredShape))
@@ -1968,8 +1966,8 @@ class ConstantOpExtractSliceFolder final
return failure();
// Dynamic result shape is not supported.
- auto sourceType = op.getSource().getType().cast<ShapedType>();
- auto resultType = op.getResult().getType().cast<ShapedType>();
+ auto sourceType = llvm::cast<ShapedType>(op.getSource().getType());
+ auto resultType = llvm::cast<ShapedType>(op.getResult().getType());
if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
return failure();
@@ -2004,13 +2002,13 @@ class ConstantOpExtractSliceFolder final
// New attribute constructed by the sliced values.
DenseElementsAttr newAttr;
- if (auto elems = attr.dyn_cast<DenseIntElementsAttr>()) {
+ if (auto elems = llvm::dyn_cast<DenseIntElementsAttr>(attr)) {
SmallVector<APInt> outValues;
outValues.reserve(sourceType.getNumElements());
sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
elems.begin(), counts, offsets, sizes, strides, &outValues);
newAttr = DenseElementsAttr::get(resultType, outValues);
- } else if (auto elems = attr.dyn_cast<DenseFPElementsAttr>()) {
+ } else if (auto elems = llvm::dyn_cast<DenseFPElementsAttr>(attr)) {
SmallVector<APFloat> outValues;
outValues.reserve(sourceType.getNumElements());
sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
@@ -2109,7 +2107,7 @@ static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
if (auto splat = adaptor.getSource().dyn_cast_or_null<SplatElementsAttr>()) {
- auto resultType = getResult().getType().cast<ShapedType>();
+ auto resultType = llvm::cast<ShapedType>(getResult().getType());
if (resultType.hasStaticShape())
return splat.resizeSplat(resultType);
}
@@ -2124,7 +2122,7 @@ OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
Value mlir::tensor::createCanonicalRankReducingExtractSliceOp(
OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) {
- auto rankedTensorType = tensor.getType().cast<RankedTensorType>();
+ auto rankedTensorType = llvm::cast<RankedTensorType>(tensor.getType());
unsigned rank = rankedTensorType.getRank();
SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, tensor);
@@ -2372,8 +2370,8 @@ struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
auto src =
(sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
- auto srcType = src.getType().template dyn_cast<RankedTensorType>();
- auto dstType = dst.getType().template dyn_cast<RankedTensorType>();
+ auto srcType = llvm::dyn_cast<RankedTensorType>(src.getType());
+ auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
if (!srcType || !dstType)
return failure();
if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
@@ -2482,7 +2480,7 @@ Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
Location loc,
Value tensor,
Value dest) {
- auto rankedTensorType = dest.getType().cast<RankedTensorType>();
+ auto rankedTensorType = llvm::cast<RankedTensorType>(dest.getType());
unsigned rank = rankedTensorType.getRank();
SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, dest);
@@ -2514,8 +2512,8 @@ parseInferType(OpAsmParser &parser,
}
LogicalResult PadOp::verify() {
- auto sourceType = getSource().getType().cast<RankedTensorType>();
- auto resultType = getResult().getType().cast<RankedTensorType>();
+ auto sourceType = llvm::cast<RankedTensorType>(getSource().getType());
+ auto resultType = llvm::cast<RankedTensorType>(getResult().getType());
auto expectedType =
PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
if (!expectedType) {
@@ -2542,7 +2540,7 @@ LogicalResult PadOp::verify() {
LogicalResult PadOp::verifyRegions() {
auto &region = getRegion();
- unsigned rank = getResult().getType().cast<RankedTensorType>().getRank();
+ unsigned rank = llvm::cast<RankedTensorType>(getResult().getType()).getRank();
Block &block = region.front();
if (block.getNumArguments() != rank)
return emitError("expected the block to have ") << rank << " arguments";
@@ -2557,7 +2555,7 @@ LogicalResult PadOp::verifyRegions() {
// Ensure that the region yields an element of the right type.
auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());
if (yieldOp.getValue().getType() !=
- getType().cast<ShapedType>().getElementType())
+ llvm::cast<ShapedType>(getType()).getElementType())
return emitOpError("expected yield type to match shape element type");
return success();
@@ -2597,7 +2595,7 @@ void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
Value source, ArrayRef<int64_t> staticLow,
ArrayRef<int64_t> staticHigh, ValueRange low, ValueRange high,
bool nofold, ArrayRef<NamedAttribute> attrs) {
- auto sourceType = source.getType().cast<RankedTensorType>();
+ auto sourceType = llvm::cast<RankedTensorType>(source.getType());
if (!resultType)
resultType = inferResultType(sourceType, staticLow, staticHigh);
build(b, result, resultType, source, low, high,
@@ -2609,7 +2607,7 @@ void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
Value source, ValueRange low, ValueRange high, bool nofold,
ArrayRef<NamedAttribute> attrs) {
- auto sourceType = source.getType().cast<RankedTensorType>();
+ auto sourceType = llvm::cast<RankedTensorType>(source.getType());
unsigned rank = sourceType.getRank();
SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamic);
build(b, result, resultType, source, staticVector, staticVector, low, high,
@@ -2620,7 +2618,7 @@ void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
Value source, ArrayRef<OpFoldResult> low,
ArrayRef<OpFoldResult> high, bool nofold,
ArrayRef<NamedAttribute> attrs) {
- auto sourceType = source.getType().cast<RankedTensorType>();
+ auto sourceType = llvm::cast<RankedTensorType>(source.getType());
SmallVector<Value, 4> dynamicLow, dynamicHigh;
SmallVector<int64_t, 4> staticLow, staticHigh;
// staticLow and staticHigh have full information of the padding config.
@@ -2632,7 +2630,7 @@ void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
if (!resultType) {
resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
}
- assert(resultType.isa<RankedTensorType>());
+ assert(llvm::isa<RankedTensorType>(resultType));
build(b, result, resultType, source, dynamicLow, dynamicHigh,
b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
nofold ? b.getUnitAttr() : UnitAttr());
@@ -2647,7 +2645,7 @@ void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
// Add a region and a block to yield the pad value.
Region *region = result.regions[0].get();
- int sourceRank = source.getType().cast<RankedTensorType>().getRank();
+ int sourceRank = llvm::cast<RankedTensorType>(source.getType()).getRank();
SmallVector<Type> blockArgTypes(sourceRank, b.getIndexType());
SmallVector<Location> blockArgLocs(sourceRank, result.location);
@@ -2700,7 +2698,7 @@ struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
return failure();
auto newResultType = PadOp::inferResultType(
- castOp.getSource().getType().cast<RankedTensorType>(),
+ llvm::cast<RankedTensorType>(castOp.getSource().getType()),
padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
padTensorOp.getResultType().getShape());
@@ -2919,9 +2917,9 @@ struct FoldStaticPadding : public OpRewritePattern<PadOp> {
LogicalResult matchAndRewrite(PadOp padTensorOp,
PatternRewriter &rewriter) const override {
Value input = padTensorOp.getSource();
- if (!input.getType().isa<RankedTensorType>())
+ if (!llvm::isa<RankedTensorType>(input.getType()))
return failure();
- auto inputDims = input.getType().cast<RankedTensorType>().getShape();
+ auto inputDims = llvm::cast<RankedTensorType>(input.getType()).getShape();
auto inputRank = inputDims.size();
auto oldResultType =
@@ -3240,7 +3238,7 @@ reifyResultShapesImpl(OpTy op, OpBuilder &builder,
"applies to only pack or unpack operations");
int64_t destRank = op.getDestRank();
reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(destRank));
- ShapedType resultType = op.getResult().getType().template cast<ShapedType>();
+ ShapedType resultType = llvm::cast<ShapedType>(op.getResult().getType());
for (auto dim : llvm::seq<int64_t>(0, destRank)) {
if (resultType.isDynamicDim(dim)) {
reifiedReturnShapes[0][dim] =
@@ -3655,8 +3653,8 @@ Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
};
SmallVector<OpFoldResult> mixedSizes;
- for (auto [index, value] :
- llvm::enumerate(source.getType().cast<RankedTensorType>().getShape())) {
+ for (auto [index, value] : llvm::enumerate(
+ llvm::cast<RankedTensorType>(source.getType()).getShape())) {
if (ShapedType::isDynamic(value))
mixedSizes.push_back(b.create<DimOp>(loc, source, index).getResult());
else
@@ -3671,7 +3669,7 @@ Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
applyPermutationToVector<OpFoldResult>(mixedSizes, outerDimsPerm);
mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
- auto elemType = source.getType().cast<ShapedType>().getElementType();
+ auto elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
}
@@ -3789,7 +3787,7 @@ static bool isLikePadUnPad(PackOrUnpackOp packOp,
bool PackOp::isLikePad() {
auto packedTensorType =
- (*this)->getResultTypes().front().cast<RankedTensorType>();
+ llvm::cast<RankedTensorType>((*this)->getResultTypes().front());
return isLikePadUnPad(*this, packedTensorType);
}
@@ -3861,7 +3859,7 @@ Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
};
SmallVector<OpFoldResult> mixedSizes;
- auto srcType = source.getType().cast<RankedTensorType>();
+ auto srcType = llvm::cast<RankedTensorType>(source.getType());
for (auto i :
llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
if (srcType.isDynamicDim(i))
@@ -3944,7 +3942,7 @@ struct FoldTensorCastProducerOp
// If no operand comes from a tensor::CastOp and can be folded then fail.
bool hasTensorCastOperand =
llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
- if (opOperand.get().isa<BlockArgument>())
+ if (llvm::isa<BlockArgument>(opOperand.get()))
return false;
auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
return castOp && canFoldIntoConsumerOp(castOp);
@@ -3961,7 +3959,7 @@ struct FoldTensorCastProducerOp
bool fold = canFoldIntoConsumerOp(tensorCastOp);
newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
if (op.isDpsInit(&opOperand) &&
- !newOperands.back().getType().isa<MemRefType>())
+ !llvm::isa<MemRefType>(newOperands.back().getType()))
newResultTypes.push_back(newOperands.back().getType());
}
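The hunks above swap member-style calls such as `source.getType().cast<RankedTensorType>()` for the free functions `llvm::cast`, `llvm::dyn_cast` and `llvm::isa`. Below is a minimal sketch of how such free functions can dispatch on a static `classof` predicate, in the spirit of LLVM's hand-rolled RTTI. The `mini` namespace, `Shape`, `Ranked`, `Unranked`, `OpLike` and `rankOrMinusOne` are hypothetical stand-ins, not LLVM or MLIR API. The sketch also shows one side benefit visible in `reifyResultShapesImpl`: inside a template, the member form had to be written `.template cast<ShapedType>()`, while the free function needs no disambiguator.

```cpp
#include <cassert>

// Toy hierarchy with a kind discriminator, in the style of LLVM RTTI.
struct Shape {
  enum class Kind { kRanked, kUnranked };
  explicit Shape(Kind k) : kind(k) {}
  virtual ~Shape() = default;
  Kind getKind() const { return kind; }

private:
  Kind kind;
};

struct Ranked : Shape {
  explicit Ranked(int r) : Shape(Kind::kRanked), rank(r) {}
  static bool classof(const Shape *s) { return s->getKind() == Kind::kRanked; }
  int getRank() const { return rank; }

private:
  int rank;
};

struct Unranked : Shape {
  Unranked() : Shape(Kind::kUnranked) {}
  static bool classof(const Shape *s) {
    return s->getKind() == Kind::kUnranked;
  }
};

namespace mini {
// isa<To>(v): true if v's dynamic kind satisfies To::classof.
template <typename To, typename From> bool isa(const From *v) {
  assert(v && "isa<> on a null pointer");
  return To::classof(v);
}

// cast<To>(v): checked downcast; asserts that the kind matches.
template <typename To, typename From> const To *cast(const From *v) {
  assert(isa<To>(v) && "cast<> to an incompatible type");
  return static_cast<const To *>(v);
}

// dyn_cast<To>(v): downcast that yields nullptr on a kind mismatch.
template <typename To, typename From> const To *dyn_cast(const From *v) {
  return isa<To>(v) ? static_cast<const To *>(v) : nullptr;
}
} // namespace mini

// Had Shape offered a member template `cast`, calling it on the dependent
// expression below would require `op.getShape()->template cast<Ranked>()`.
// The free function needs no `template` disambiguator.
template <typename OpTy> int rankOrMinusOne(const OpTy &op) {
  if (const Ranked *r = mini::dyn_cast<Ranked>(op.getShape()))
    return r->getRank();
  return -1;
}

// Hypothetical op-like wrapper exposing a shape.
struct OpLike {
  const Shape *shape;
  const Shape *getShape() const { return shape; }
};
```

Because the helpers only require a static `To::classof`, a class can opt in to casting without declaring any cast methods of its own.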
diff --git a/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp
index 757885a736ae6..d3f16756b17a5 100644
--- a/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -24,8 +24,8 @@ struct CastOpInterface
auto castOp = cast<CastOp>(op);
assert(value == castOp.getResult() && "invalid value");
- if (castOp.getResult().getType().isa<RankedTensorType>() &&
- castOp.getSource().getType().isa<RankedTensorType>()) {
+ if (llvm::isa<RankedTensorType>(castOp.getResult().getType()) &&
+ llvm::isa<RankedTensorType>(castOp.getSource().getType())) {
cstr.bound(value)[dim] == cstr.getExpr(castOp.getSource(), dim);
}
}
@@ -100,7 +100,8 @@ struct RankOpInterface
auto rankOp = cast<RankOp>(op);
assert(value == rankOp.getResult() && "invalid value");
- auto tensorType = rankOp.getTensor().getType().dyn_cast<RankedTensorType>();
+ auto tensorType =
+ llvm::dyn_cast<RankedTensorType>(rankOp.getTensor().getType());
if (!tensorType)
return;
cstr.bound(value) == tensorType.getRank();
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 11a9661800e5b..d5263a3ff0567 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -88,8 +88,8 @@ struct ReshapeConstOptimization : public OpRewritePattern<tosa::ReshapeOp> {
LogicalResult matchAndRewrite(tosa::ReshapeOp op,
PatternRewriter &rewriter) const override {
Value input = op.getInput1();
- ShapedType inputTy = input.getType().cast<ShapedType>();
- ShapedType resultTy = op.getType().cast<ShapedType>();
+ ShapedType inputTy = llvm::cast<ShapedType>(input.getType());
+ ShapedType resultTy = llvm::cast<ShapedType>(op.getType());
if (inputTy.getElementType() != resultTy.getElementType())
return rewriter.notifyMatchFailure(op, "element type does not match.");
@@ -106,7 +106,7 @@ struct ReshapeConstOptimization : public OpRewritePattern<tosa::ReshapeOp> {
// Build new const op with correct output shape
DenseElementsAttr outputAttr = inputAttr.reshape(
- inputAttr.getType().cast<ShapedType>().clone(op.getNewShape()));
+ llvm::cast<ShapedType>(inputAttr.getType()).clone(op.getNewShape()));
rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, resultTy, outputAttr);
return success();
}
@@ -198,7 +198,7 @@ struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
}
auto input = op.getInput1();
- auto inputTy = input.getType().cast<ShapedType>();
+ auto inputTy = llvm::cast<ShapedType>(input.getType());
if (!inputTy.hasRank())
return rewriter.notifyMatchFailure(op, "Unranked input.");
@@ -255,15 +255,15 @@ struct MaterializePadValue : public OpRewritePattern<tosa::PadOp> {
auto input = op.getInput1();
auto padding = op.getPadding();
- ShapedType inputTy = input.getType().cast<ShapedType>();
+ ShapedType inputTy = llvm::cast<ShapedType>(input.getType());
Type elementTy = inputTy.getElementType();
Attribute constantAttr;
- if (elementTy.isa<FloatType>()) {
+ if (llvm::isa<FloatType>(elementTy)) {
constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
- } else if (elementTy.isa<IntegerType>() && !op.getQuantizationInfo()) {
+ } else if (llvm::isa<IntegerType>(elementTy) && !op.getQuantizationInfo()) {
constantAttr = rewriter.getIntegerAttr(elementTy, 0);
- } else if (elementTy.isa<IntegerType>() && op.getQuantizationInfo()) {
+ } else if (llvm::isa<IntegerType>(elementTy) && op.getQuantizationInfo()) {
auto value = op.getQuantizationInfo()->getInputZp();
constantAttr = rewriter.getIntegerAttr(elementTy, value);
}
@@ -298,8 +298,8 @@ struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
PatternRewriter &rewriter) const override {
Value input = op.getInput();
Value output = op.getOutput();
- ShapedType inputType = input.getType().cast<ShapedType>();
- ShapedType outputType = output.getType().cast<ShapedType>();
+ ShapedType inputType = llvm::cast<ShapedType>(input.getType());
+ ShapedType outputType = llvm::cast<ShapedType>(output.getType());
if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
return failure();
@@ -332,8 +332,7 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
LogicalResult matchAndRewrite(tosa::ClampOp op,
PatternRewriter &rewriter) const override {
Value input = op.getInput();
- auto inputType =
- op.getInput().getType().template dyn_cast<RankedTensorType>();
+ auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
auto inputElementType = inputType.getElementType();
if (!inputType.hasStaticShape()) {
@@ -373,7 +372,7 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
return failure();
}
- if (inputElementType.isa<IntegerType>()) {
+ if (llvm::isa<IntegerType>(inputElementType)) {
int64_t minClamp = op.getMinInt();
int64_t maxClamp = op.getMaxInt();
@@ -498,19 +497,19 @@ template <typename IntFolder, typename FloatFolder>
DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
RankedTensorType returnTy) {
if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
- auto lETy = lhs.getType().cast<ShapedType>().getElementType();
- auto rETy = rhs.getType().cast<ShapedType>().getElementType();
+ auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
+ auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
if (lETy != rETy)
return {};
- if (lETy.isa<IntegerType>()) {
+ if (llvm::isa<IntegerType>(lETy)) {
APInt l = lhs.getSplatValue<APInt>();
APInt r = rhs.getSplatValue<APInt>();
auto result = IntFolder()(l, r);
return DenseElementsAttr::get(returnTy, result);
}
- if (lETy.isa<FloatType>()) {
+ if (llvm::isa<FloatType>(lETy)) {
APFloat l = lhs.getSplatValue<APFloat>();
APFloat r = rhs.getSplatValue<APFloat>();
auto result = FloatFolder()(l, r);
@@ -522,18 +521,18 @@ DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
}
static bool isSplatZero(Type elemType, DenseElementsAttr val) {
- if (elemType.isa<FloatType>())
+ if (llvm::isa<FloatType>(elemType))
return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
- if (elemType.isa<IntegerType>())
+ if (llvm::isa<IntegerType>(elemType))
return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
return false;
}
static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift) {
- if (elemType.isa<FloatType>())
+ if (llvm::isa<FloatType>(elemType))
return val && val.isSplat() &&
val.getSplatValue<APFloat>().isExactlyValue(1.0);
- if (elemType.isa<IntegerType>()) {
+ if (llvm::isa<IntegerType>(elemType)) {
const int64_t shifted = 1LL << shift;
return val && val.isSplat() &&
val.getSplatValue<APInt>().getSExtValue() == shifted;
@@ -542,9 +541,9 @@ static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift) {
}
OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
- auto lhsTy = getInput1().getType().dyn_cast<RankedTensorType>();
- auto rhsTy = getInput2().getType().dyn_cast<RankedTensorType>();
- auto resultTy = getType().dyn_cast<RankedTensorType>();
+ auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
+ auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
+ auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
if (!lhsTy || !rhsTy || !resultTy)
return {};
@@ -565,9 +564,9 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
}
OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
- auto lhsTy = getInput1().getType().dyn_cast<RankedTensorType>();
- auto rhsTy = getInput2().getType().dyn_cast<RankedTensorType>();
- auto resultTy = getType().dyn_cast<RankedTensorType>();
+ auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
+ auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
+ auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
if (!lhsTy || !rhsTy || !resultTy)
return {};
if (lhsTy != rhsTy)
@@ -577,17 +576,19 @@ OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
if (lhsAttr && lhsAttr.isSplat()) {
- if (resultETy.isa<IntegerType>() && lhsAttr.getSplatValue<APInt>().isZero())
+ if (llvm::isa<IntegerType>(resultETy) &&
+ lhsAttr.getSplatValue<APInt>().isZero())
return lhsAttr;
}
if (rhsAttr && rhsAttr.isSplat()) {
- if (resultETy.isa<IntegerType>() && rhsAttr.getSplatValue<APInt>().isOne())
+ if (llvm::isa<IntegerType>(resultETy) &&
+ rhsAttr.getSplatValue<APInt>().isOne())
return getInput1();
}
if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat()) {
- if (resultETy.isa<IntegerType>()) {
+ if (llvm::isa<IntegerType>(resultETy)) {
APInt l = lhsAttr.getSplatValue<APInt>();
APInt r = rhsAttr.getSplatValue<APInt>();
APInt result = l.sdiv(r);
@@ -602,7 +603,7 @@ namespace {
DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
RankedTensorType ty, int32_t shift) {
if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
- if (ty.getElementType().isa<IntegerType>()) {
+ if (llvm::isa<IntegerType>(ty.getElementType())) {
APInt l = lhs.getSplatValue<APInt>();
APInt r = rhs.getSplatValue<APInt>();
@@ -619,7 +620,7 @@ DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
return DenseElementsAttr::get(ty, result);
}
- if (ty.getElementType().isa<FloatType>()) {
+ if (llvm::isa<FloatType>(ty.getElementType())) {
APFloat l = lhs.getSplatValue<APFloat>();
APFloat r = rhs.getSplatValue<APFloat>();
APFloat result = l * r;
@@ -634,9 +635,9 @@ DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
auto lhs = getInput1();
auto rhs = getInput2();
- auto lhsTy = lhs.getType().dyn_cast<RankedTensorType>();
- auto rhsTy = rhs.getType().dyn_cast<RankedTensorType>();
- auto resultTy = getType().dyn_cast<RankedTensorType>();
+ auto lhsTy = llvm::dyn_cast<RankedTensorType>(lhs.getType());
+ auto rhsTy = llvm::dyn_cast<RankedTensorType>(rhs.getType());
+ auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
if (!lhsTy || !rhsTy || !resultTy)
return {};
@@ -644,7 +645,7 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
- const int64_t shift = resultETy.isa<IntegerType>() ? getShift() : 0;
+ const int64_t shift = llvm::isa<IntegerType>(resultETy) ? getShift() : 0;
if (rhsTy == resultTy) {
if (isSplatZero(resultETy, lhsAttr))
return lhsAttr;
@@ -662,9 +663,9 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
}
OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
- auto lhsTy = getInput1().getType().dyn_cast<RankedTensorType>();
- auto rhsTy = getInput2().getType().dyn_cast<RankedTensorType>();
- auto resultTy = getType().dyn_cast<RankedTensorType>();
+ auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
+ auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
+ auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
if (!lhsTy || !rhsTy || !resultTy)
return {};
@@ -711,7 +712,7 @@ struct APIntFoldGreaterEqual {
} // namespace
OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
- auto resultTy = getType().dyn_cast<RankedTensorType>();
+ auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
@@ -723,7 +724,7 @@ OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
}
OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
- auto resultTy = getType().dyn_cast<RankedTensorType>();
+ auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
@@ -736,16 +737,16 @@ OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
}
OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
- auto resultTy = getType().dyn_cast<RankedTensorType>();
+ auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
Value lhs = getInput1();
Value rhs = getInput2();
- auto lhsTy = lhs.getType().cast<ShapedType>();
+ auto lhsTy = llvm::cast<ShapedType>(lhs.getType());
// If we are comparing an integer value to itself it is always true. We can
// not do this with float due to float values.
- if (lhsTy.getElementType().isa<IntegerType>() && resultTy &&
+ if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy &&
resultTy.hasStaticShape() && lhs == rhs) {
return DenseElementsAttr::get(resultTy, true);
}
@@ -766,41 +767,41 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
if (!operand)
return {};
- auto inTy = getInput().getType().cast<ShapedType>();
- auto outTy = getType().cast<ShapedType>();
+ auto inTy = llvm::cast<ShapedType>(getInput().getType());
+ auto outTy = llvm::cast<ShapedType>(getType());
auto inETy = inTy.getElementType();
auto outETy = outTy.getElementType();
if (operand.isSplat()) {
- if (inETy.isa<FloatType>() && outETy.isa<FloatType>()) {
+ if (llvm::isa<FloatType>(inETy) && llvm::isa<FloatType>(outETy)) {
bool overflow;
auto splatVal = operand.getSplatValue<APFloat>();
- auto &semantics = outETy.cast<FloatType>().getFloatSemantics();
+ auto &semantics = llvm::cast<FloatType>(outETy).getFloatSemantics();
splatVal.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
&overflow);
return SplatElementsAttr::get(outTy, splatVal);
}
- if (inETy.isa<IntegerType>() && outETy.isa<FloatType>()) {
- auto unsign = inETy.cast<IntegerType>().isUnsignedInteger();
- APFloat splatVal(outETy.cast<FloatType>().getFloatSemantics());
+ if (llvm::isa<IntegerType>(inETy) && llvm::isa<FloatType>(outETy)) {
+ auto unsign = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
+ APFloat splatVal(llvm::cast<FloatType>(outETy).getFloatSemantics());
splatVal.convertFromAPInt(operand.getSplatValue<APInt>(), !unsign,
llvm::RoundingMode::NearestTiesToEven);
return SplatElementsAttr::get(outTy, splatVal);
}
- if (inETy.isa<FloatType>() && outETy.isa<IntegerType>()) {
- auto unsign = outETy.cast<IntegerType>().isUnsignedInteger();
- auto intVal =
- APSInt(outETy.cast<IntegerType>().getIntOrFloatBitWidth(), unsign);
+ if (llvm::isa<FloatType>(inETy) && llvm::isa<IntegerType>(outETy)) {
+ auto unsign = llvm::cast<IntegerType>(outETy).isUnsignedInteger();
+ auto intVal = APSInt(
+ llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), unsign);
auto floatVal = operand.getSplatValue<APFloat>();
bool exact;
floatVal.convertToInteger(intVal, llvm::RoundingMode::TowardZero, &exact);
return SplatElementsAttr::get(outTy, intVal);
}
- if (inETy.isa<IntegerType>() && outETy.isa<IntegerType>()) {
- auto unsignIn = inETy.cast<IntegerType>().isUnsignedInteger();
+ if (llvm::isa<IntegerType>(inETy) && llvm::isa<IntegerType>(outETy)) {
+ auto unsignIn = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
bool trunc =
inETy.getIntOrFloatBitWidth() > outETy.getIntOrFloatBitWidth();
auto intVal = operand.getSplatValue<APInt>();
@@ -842,8 +843,8 @@ REDUCE_FOLDER(ReduceSumOp)
#undef REDUCE_FOLDER
OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
- auto inputTy = getInput1().getType().dyn_cast<RankedTensorType>();
- auto outputTy = getType().dyn_cast<RankedTensorType>();
+ auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
+ auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
if (!inputTy || !outputTy)
return {};
@@ -894,8 +895,8 @@ OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
}
auto input = getInput();
- auto inputTy = input.getType().cast<RankedTensorType>();
- auto resultTy = getType().cast<RankedTensorType>();
+ auto inputTy = llvm::cast<RankedTensorType>(input.getType());
+ auto resultTy = llvm::cast<RankedTensorType>(getType());
if (inputTy != resultTy)
return {};
@@ -904,7 +905,7 @@ OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
auto operand = getInput();
- auto operandTy = operand.getType().cast<ShapedType>();
+ auto operandTy = llvm::cast<ShapedType>(operand.getType());
auto axis = getAxis();
auto operandAttr = adaptor.getInput().dyn_cast_or_null<SplatElementsAttr>();
if (operandAttr)
@@ -918,8 +919,8 @@ OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
}
OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
- auto inputTy = getInput().getType().dyn_cast<RankedTensorType>();
- auto outputTy = getType().dyn_cast<RankedTensorType>();
+ auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().getType());
+ auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
if (!inputTy || !outputTy)
return {};
@@ -972,8 +973,8 @@ OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
}
OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
- auto inputTy = getInput1().getType().cast<ShapedType>();
- auto resultTy = getType().cast<ShapedType>();
+ auto inputTy = llvm::cast<ShapedType>(getInput1().getType());
+ auto resultTy = llvm::cast<ShapedType>(getType());
// Transposing splat values just means reshaping.
if (auto input = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>()) {
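In these hunks, calls such as `adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>()` stay in method form, while the `cast`, `dyn_cast` and `isa` calls around them are rewritten. The null-tolerant variant is the one a fold needs. A fold adaptor's attribute is null when the corresponding operand is not a constant, and `llvm::dyn_cast` asserts that its input is present. Below is a toy sketch of the distinction. `mini`, `Attr`, `DenseAttr` and `hasDenseConstant` are hypothetical names, not the LLVM/MLIR API, and raw pointers stand in for MLIR's nullable attribute handles.

```cpp
#include <cassert>

// Toy attribute hierarchy; a null pointer models an absent attribute.
struct Attr {
  enum class Kind { kDense, kOther };
  explicit Attr(Kind k) : kind(k) {}
  Kind kind;
};

struct DenseAttr : Attr {
  DenseAttr() : Attr(Kind::kDense) {}
  static bool classof(const Attr *a) { return a->kind == Kind::kDense; }
};

namespace mini {
// dyn_cast: the input must be present; a kind mismatch yields nullptr.
template <typename To> const To *dyn_cast(const Attr *a) {
  assert(a && "dyn_cast on a null value");
  return To::classof(a) ? static_cast<const To *>(a) : nullptr;
}

// dyn_cast_or_null: like dyn_cast, but a null input also yields nullptr.
template <typename To> const To *dyn_cast_or_null(const Attr *a) {
  return a ? dyn_cast<To>(a) : nullptr;
}
} // namespace mini

// Hypothetical folder helper: `maybeConst` is null when the operand is not
// a constant, so only the null-tolerant variant is safe to call here.
bool hasDenseConstant(const Attr *maybeConst) {
  return mini::dyn_cast_or_null<DenseAttr>(maybeConst) != nullptr;
}
```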
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 5c17d281c2ec7..3070b63647a5f 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -90,8 +90,9 @@ Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
Type type, Location loc) {
// Tosa dialect constants only support ElementsAttr unlike standard dialect
// constant which supports all attributes.
- if (value.isa<ElementsAttr>())
- return builder.create<tosa::ConstOp>(loc, type, value.cast<ElementsAttr>());
+ if (llvm::isa<ElementsAttr>(value))
+ return builder.create<tosa::ConstOp>(loc, type,
+ llvm::cast<ElementsAttr>(value));
return nullptr;
}
@@ -101,10 +102,8 @@ Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
template <typename T> static LogicalResult verifyConvOp(T op) {
// All TOSA conv ops have an input() and weight().
- auto inputType =
- op.getInput().getType().template dyn_cast<RankedTensorType>();
- auto weightType =
- op.getWeight().getType().template dyn_cast<RankedTensorType>();
+ auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
+ auto weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
// Must be ranked tensor types
if (!inputType) {
@@ -119,8 +118,8 @@ template <typename T> static LogicalResult verifyConvOp(T op) {
auto inputEType = inputType.getElementType();
auto weightEType = weightType.getElementType();
- bool inputIsQuant = !inputEType.template isa<FloatType>();
- bool weightIsQuant = !weightEType.template isa<FloatType>();
+ bool inputIsQuant = !llvm::isa<FloatType>(inputEType);
+ bool weightIsQuant = !llvm::isa<FloatType>(weightEType);
// Either both must be quantized or both unquantized.
if (inputIsQuant != weightIsQuant) {
@@ -143,13 +142,15 @@ template <typename T> static LogicalResult verifyConvOp(T op) {
}
LogicalResult tosa::AvgPool2dOp::verify() {
- auto inputETy = getInput().getType().cast<ShapedType>().getElementType();
- auto resultETy = getType().cast<ShapedType>().getElementType();
+ auto inputETy = llvm::cast<ShapedType>(getInput().getType()).getElementType();
+ auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
- if (auto quantType = inputETy.dyn_cast<mlir::quant::UniformQuantizedType>())
+ if (auto quantType =
+ llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy))
inputETy = quantType.getStorageType();
- if (auto quantType = resultETy.dyn_cast<mlir::quant::UniformQuantizedType>())
+ if (auto quantType =
+ llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultETy))
resultETy = quantType.getStorageType();
if (inputETy.isF32() && resultETy.isF32())
@@ -240,16 +241,16 @@ static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
if (quantAttr) {
result.addAttribute("quantization_info", quantAttr);
- auto inputType = a.getType().dyn_cast<ShapedType>();
+ auto inputType = llvm::dyn_cast<ShapedType>(a.getType());
assert(inputType && "Input must be a shaped tensor type!");
- auto inputQType = inputType.getElementType()
- .dyn_cast<mlir::quant::UniformQuantizedType>();
+ auto inputQType = llvm::dyn_cast<mlir::quant::UniformQuantizedType>(
+ inputType.getElementType());
assert(inputQType && "Tensor must have quantized datatype!");
unsigned inputBits = inputQType.getStorageTypeIntegralWidth();
- auto outputShapedType = outputType.dyn_cast<ShapedType>();
+ auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
assert(outputShapedType && "Output must be a shaped type");
IntegerType accElementType;
@@ -368,7 +369,7 @@ LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
ShapeAdaptor inputShape = operands.getShape(0);
- IntegerAttr axis = attributes.get("axis").cast<IntegerAttr>();
+ IntegerAttr axis = llvm::cast<IntegerAttr>(attributes.get("axis"));
int32_t axisVal = axis.getValue().getSExtValue();
if (!inputShape.hasRank()) {
@@ -432,7 +433,7 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
// Infer all dimension sizes by reducing based on inputs.
int32_t axis =
- attributes.get("axis").cast<IntegerAttr>().getValue().getSExtValue();
+ llvm::cast<IntegerAttr>(attributes.get("axis")).getValue().getSExtValue();
llvm::SmallVector<int64_t> outputShape;
bool hasRankedInput = false;
for (auto operand : operands) {
@@ -459,7 +460,8 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
hasRankedInput = true;
}
- Type inputType = operands.getType()[0].cast<TensorType>().getElementType();
+ Type inputType =
+ llvm::cast<TensorType>(operands.getType()[0]).getElementType();
if (!hasRankedInput) {
inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
return success();
@@ -738,8 +740,8 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
}
mlir::LogicalResult tosa::ReshapeOp::verify() {
- ShapedType inputType = getInput1().getType().cast<ShapedType>();
- ShapedType outputType = getType().cast<ShapedType>();
+ ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
+ ShapedType outputType = llvm::cast<ShapedType>(getType());
if (inputType.hasStaticShape() && outputType.hasStaticShape()) {
int64_t inputElementsNum = inputType.getNumElements();
@@ -1064,9 +1066,11 @@ static LogicalResult poolingInferReturnTypes(
int64_t height = inputShape.getDimSize(1);
int64_t width = inputShape.getDimSize(2);
- ArrayRef<int64_t> kernel = attributes.get("kernel").cast<DenseI64ArrayAttr>();
- ArrayRef<int64_t> stride = attributes.get("stride").cast<DenseI64ArrayAttr>();
- ArrayRef<int64_t> pad = attributes.get("pad").cast<DenseI64ArrayAttr>();
+ ArrayRef<int64_t> kernel =
+ llvm::cast<DenseI64ArrayAttr>(attributes.get("kernel"));
+ ArrayRef<int64_t> stride =
+ llvm::cast<DenseI64ArrayAttr>(attributes.get("stride"));
+ ArrayRef<int64_t> pad = llvm::cast<DenseI64ArrayAttr>(attributes.get("pad"));
if (!ShapedType::isDynamic(height)) {
int64_t padded = height + pad[0] + pad[1] - kernel[0];
@@ -1473,7 +1477,7 @@ LogicalResult WhileOp::inferReturnTypeComponents(
}
std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
- if (auto vt = getType().dyn_cast<VectorType>())
+ if (auto vt = llvm::dyn_cast<VectorType>(getType()))
return llvm::to_vector<4>(vt.getShape());
return std::nullopt;
}
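Every rewritten call in these hunks spells out `llvm::` rather than relying on an unqualified `cast<T>(x)`. Per the commit log, unqualified calls did not always work. Inside a class definition that still has a `cast` method, or where a variable named `cast` is in scope, the name resolves to that entity instead of the free function. The toy below illustrates the class-member case. `mini::cast`, `Handle` and `derivedKind` are hypothetical names, not LLVM or MLIR API.

```cpp
#include <cassert>

namespace mini {
struct Base {
  explicit Base(int k) : kind(k) {}
  int kind;
};

struct Derived : Base {
  Derived() : Base(1) {}
  static bool classof(const Base *b) { return b->kind == 1; }
};

// Free-function checked downcast, standing in for llvm::cast.
template <typename To, typename From> const To *cast(const From *v) {
  assert(To::classof(v) && "cast<> to an incompatible type");
  return static_cast<const To *>(v);
}
} // namespace mini

// Hypothetical wrapper that still defines a member template named `cast`.
class Handle {
public:
  explicit Handle(const mini::Base *b) : ptr(b) {}

  // Member-style cast, offered alongside the free function.
  template <typename To> const To *cast() const { return mini::cast<To>(ptr); }

  int derivedKind(const Handle &other) const {
    // Writing `cast<mini::Derived>(other.ptr)` here would not compile:
    // unqualified lookup finds the member template Handle::cast first, and
    // argument-dependent lookup is skipped once a class member is found.
    // Qualifying the name selects the free function.
    return mini::cast<mini::Derived>(other.ptr)->kind;
  }

private:
  const mini::Base *ptr;
};
```

As the log notes, once the member methods are removed, the explicit prefix can be dropped again with a find-replace.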
diff --git a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
index 5d29e0bd3b3d9..6780c3bad9685 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
@@ -169,7 +169,7 @@ LogicalResult transform::TransformDialect::verifyOperationAttribute(
return success();
}
if (attribute.getName().getValue() == kTargetTagAttrName) {
- if (!attribute.getValue().isa<StringAttr>()) {
+ if (!llvm::isa<StringAttr>(attribute.getValue())) {
return op->emitError()
<< attribute.getName() << " attribute must be a string";
}
@@ -177,7 +177,7 @@ LogicalResult transform::TransformDialect::verifyOperationAttribute(
}
if (attribute.getName().getValue() == kArgConsumedAttrName ||
attribute.getName().getValue() == kArgReadOnlyAttrName) {
- if (!attribute.getValue().isa<UnitAttr>()) {
+ if (!llvm::isa<UnitAttr>(attribute.getValue())) {
return op->emitError()
<< attribute.getName() << " must be a unit attribute";
}
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 9c2baa3d41c23..4f3afd8e695d7 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -114,7 +114,7 @@ static DiagnosedSilenceableFailure dispatchMappedValues(
function_ref<LogicalResult(ArrayRef<Operation *>)> operationsFn,
function_ref<LogicalResult(ArrayRef<transform::Param>)> paramsFn,
function_ref<LogicalResult(ValueRange)> valuesFn) {
- if (handle.getType().isa<transform::TransformHandleTypeInterface>()) {
+ if (llvm::isa<transform::TransformHandleTypeInterface>(handle.getType())) {
SmallVector<Operation *> operations;
operations.reserve(values.size());
for (transform::MappedValue value : values) {
@@ -130,7 +130,8 @@ static DiagnosedSilenceableFailure dispatchMappedValues(
return DiagnosedSilenceableFailure::success();
}
- if (handle.getType().isa<transform::TransformValueHandleTypeInterface>()) {
+ if (llvm::isa<transform::TransformValueHandleTypeInterface>(
+ handle.getType())) {
SmallVector<Value> payloadValues;
payloadValues.reserve(values.size());
for (transform::MappedValue value : values) {
@@ -146,7 +147,7 @@ static DiagnosedSilenceableFailure dispatchMappedValues(
return DiagnosedSilenceableFailure::success();
}
- assert(handle.getType().isa<transform::TransformParamTypeInterface>() &&
+ assert(llvm::isa<transform::TransformParamTypeInterface>(handle.getType()) &&
"unsupported kind of block argument");
SmallVector<transform::Param> parameters;
parameters.reserve(values.size());
@@ -185,7 +186,7 @@ transform::TransformState::setPayloadOps(Value value,
ArrayRef<Operation *> targets) {
assert(value != kTopLevelValue &&
"attempting to reset the transformation root");
- assert(value.getType().isa<TransformHandleTypeInterface>() &&
+ assert(llvm::isa<TransformHandleTypeInterface>(value.getType()) &&
"wrong handle type");
for (Operation *target : targets) {
@@ -195,7 +196,7 @@ transform::TransformState::setPayloadOps(Value value,
<< "attempting to assign a null payload op to this transform value";
}
- auto iface = value.getType().cast<TransformHandleTypeInterface>();
+ auto iface = llvm::cast<TransformHandleTypeInterface>(value.getType());
DiagnosedSilenceableFailure result =
iface.checkPayload(value.getLoc(), targets);
if (failed(result.checkAndReport()))
@@ -220,7 +221,7 @@ LogicalResult
transform::TransformState::setPayloadValues(Value handle,
ValueRange payloadValues) {
assert(handle != nullptr && "attempting to set params for a null value");
- assert(handle.getType().isa<TransformValueHandleTypeInterface>() &&
+ assert(llvm::isa<TransformValueHandleTypeInterface>(handle.getType()) &&
"wrong handle type");
for (Value payload : payloadValues) {
@@ -230,7 +231,7 @@ transform::TransformState::setPayloadValues(Value handle,
"value to this transform handle";
}
- auto iface = handle.getType().cast<TransformValueHandleTypeInterface>();
+ auto iface = llvm::cast<TransformValueHandleTypeInterface>(handle.getType());
SmallVector<Value> payloadValueVector = llvm::to_vector(payloadValues);
DiagnosedSilenceableFailure result =
iface.checkPayload(handle.getLoc(), payloadValueVector);
@@ -262,7 +263,7 @@ LogicalResult transform::TransformState::setParams(Value value,
<< "attempting to assign a null parameter to this transform value";
}
- auto valueType = value.getType().dyn_cast<TransformParamTypeInterface>();
+ auto valueType = llvm::dyn_cast<TransformParamTypeInterface>(value.getType());
assert(value &&
"cannot associate parameter with a value of non-parameter type");
DiagnosedSilenceableFailure result =
@@ -497,11 +498,11 @@ void transform::TransformState::recordValueHandleInvalidationByOpHandleOne(
Operation *definingOp;
std::optional<unsigned> resultNo;
unsigned argumentNo, blockNo, regionNo;
- if (auto opResult = payloadValue.dyn_cast<OpResult>()) {
+ if (auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
definingOp = opResult.getOwner();
resultNo = opResult.getResultNumber();
} else {
- auto arg = payloadValue.cast<BlockArgument>();
+ auto arg = llvm::cast<BlockArgument>(payloadValue);
definingOp = arg.getParentBlock()->getParentOp();
argumentNo = arg.getArgNumber();
blockNo = std::distance(arg.getOwner()->getParent()->begin(),
@@ -602,11 +603,11 @@ void transform::TransformState::recordValueHandleInvalidation(
};
}
- if (auto opResult = payloadValue.dyn_cast<OpResult>()) {
+ if (auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
Operation *payloadOp = opResult.getOwner();
recordOpHandleInvalidation(valueHandle, payloadOp, payloadValue);
} else {
- auto arg = payloadValue.dyn_cast<BlockArgument>();
+ auto arg = llvm::dyn_cast<BlockArgument>(payloadValue);
for (Operation &payloadOp : *arg.getOwner())
recordOpHandleInvalidation(valueHandle, &payloadOp, payloadValue);
}
@@ -642,13 +643,12 @@ LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
};
if (llvm::any_of(effects, consumesTarget)) {
FULL_LDBG("----found consume effect -> SKIP\n");
- if (target.get().getType().isa<TransformHandleTypeInterface>()) {
+ if (llvm::isa<TransformHandleTypeInterface>(target.get().getType())) {
FULL_LDBG("----recordOpHandleInvalidation\n");
ArrayRef<Operation *> payloadOps = getPayloadOps(target.get());
recordOpHandleInvalidation(target, payloadOps);
- } else if (target.get()
- .getType()
- .isa<TransformValueHandleTypeInterface>()) {
+ } else if (llvm::isa<TransformValueHandleTypeInterface>(
+ target.get().getType())) {
FULL_LDBG("----recordValueHandleInvalidation\n");
recordValueHandleInvalidation(target);
} else {
@@ -717,7 +717,7 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
FULL_LDBG("--handle is consumed\n");
Type operandType = operand.get().getType();
- if (operandType.isa<TransformHandleTypeInterface>()) {
+ if (llvm::isa<TransformHandleTypeInterface>(operandType)) {
FULL_LDBG("--checkRepeatedConsumptionInOperand for Operation*\n");
DiagnosedSilenceableFailure check =
checkRepeatedConsumptionInOperand<Operation *>(
@@ -727,7 +727,7 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
FULL_LDBG("----FAILED\n");
return check;
}
- } else if (operandType.isa<TransformValueHandleTypeInterface>()) {
+ } else if (llvm::isa<TransformValueHandleTypeInterface>(operandType)) {
FULL_LDBG("--checkRepeatedConsumptionInOperand For Value\n");
DiagnosedSilenceableFailure check =
checkRepeatedConsumptionInOperand<Value>(
@@ -794,7 +794,7 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
for (unsigned index : consumedOperands) {
Value operand = transform->getOperand(index);
- if (operand.getType().isa<TransformHandleTypeInterface>()) {
+ if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
for (Operation *payloadOp : getPayloadOps(operand)) {
llvm::append_range(origOpFlatResults, payloadOp->getResults());
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
@@ -808,15 +808,15 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
}
continue;
}
- if (operand.getType().isa<TransformValueHandleTypeInterface>()) {
+ if (llvm::isa<TransformValueHandleTypeInterface>(operand.getType())) {
for (Value payloadValue : getPayloadValues(operand)) {
- if (payloadValue.isa<OpResult>()) {
+ if (llvm::isa<OpResult>(payloadValue)) {
origAssociatedOps.push_back(payloadValue.getDefiningOp());
continue;
}
llvm::append_range(
origAssociatedOps,
- llvm::map_range(*payloadValue.cast<BlockArgument>().getOwner(),
+ llvm::map_range(*llvm::cast<BlockArgument>(payloadValue).getOwner(),
[](Operation &op) { return &op; }));
}
continue;
@@ -847,9 +847,10 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
// allows us to catch use-after-free with assertions later on.
for (unsigned index : consumedOperands) {
Value operand = transform->getOperand(index);
- if (operand.getType().isa<TransformHandleTypeInterface>()) {
+ if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
forgetMapping(operand, origOpFlatResults);
- } else if (operand.getType().isa<TransformValueHandleTypeInterface>()) {
+ } else if (llvm::isa<TransformValueHandleTypeInterface>(
+ operand.getType())) {
forgetValueMapping(operand, origAssociatedOps);
}
}
@@ -923,14 +924,14 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
LogicalResult transform::TransformState::updateStateFromResults(
const TransformResults &results, ResultRange opResults) {
for (OpResult result : opResults) {
- if (result.getType().isa<TransformParamTypeInterface>()) {
+ if (llvm::isa<TransformParamTypeInterface>(result.getType())) {
assert(results.isParam(result.getResultNumber()) &&
"expected parameters for the parameter-typed result");
if (failed(
setParams(result, results.getParams(result.getResultNumber())))) {
return failure();
}
- } else if (result.getType().isa<TransformValueHandleTypeInterface>()) {
+ } else if (llvm::isa<TransformValueHandleTypeInterface>(result.getType())) {
assert(results.isValue(result.getResultNumber()) &&
"expected values for value-type-result");
if (failed(setPayloadValues(
@@ -1137,19 +1138,19 @@ transform::detail::checkApplyToOne(Operation *transformOp,
llvm::zip(partialResult, transformOp->getResults())) {
if (ptr.isNull())
continue;
- if (res.getType().template isa<TransformHandleTypeInterface>() &&
+ if (llvm::isa<TransformHandleTypeInterface>(res.getType()) &&
!ptr.is<Operation *>()) {
return emitDiag() << "application of " << transformOpName
<< " expected to produce an Operation * for result #"
<< res.getResultNumber();
}
- if (res.getType().template isa<TransformParamTypeInterface>() &&
+ if (llvm::isa<TransformParamTypeInterface>(res.getType()) &&
!ptr.is<Attribute>()) {
return emitDiag() << "application of " << transformOpName
<< " expected to produce an Attribute for result #"
<< res.getResultNumber();
}
- if (res.getType().template isa<TransformValueHandleTypeInterface>() &&
+ if (llvm::isa<TransformValueHandleTypeInterface>(res.getType()) &&
!ptr.is<Value>()) {
return emitDiag() << "application of " << transformOpName
<< " expected to produce a Value for result #"
@@ -1182,10 +1183,10 @@ void transform::detail::setApplyToOneResults(
for (OpResult r : transformOp->getResults()) {
unsigned position = r.getResultNumber();
- if (r.getType().isa<TransformParamTypeInterface>()) {
+ if (llvm::isa<TransformParamTypeInterface>(r.getType())) {
transformResults.setParams(r,
castVector<Attribute>(transposed[position]));
- } else if (r.getType().isa<TransformValueHandleTypeInterface>()) {
+ } else if (llvm::isa<TransformValueHandleTypeInterface>(r.getType())) {
transformResults.setValues(r, castVector<Value>(transposed[position]));
} else {
transformResults.set(r, castVector<Operation *>(transposed[position]));
@@ -1202,12 +1203,13 @@ void transform::detail::prepareValueMappings(
ValueRange values, const transform::TransformState &state) {
for (Value operand : values) {
SmallVector<MappedValue> &mapped = mappings.emplace_back();
- if (operand.getType().isa<TransformHandleTypeInterface>()) {
+ if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
llvm::append_range(mapped, state.getPayloadOps(operand));
- } else if (operand.getType().isa<TransformValueHandleTypeInterface>()) {
+ } else if (llvm::isa<TransformValueHandleTypeInterface>(
+ operand.getType())) {
llvm::append_range(mapped, state.getPayloadValues(operand));
} else {
- assert(operand.getType().isa<TransformParamTypeInterface>() &&
+ assert(llvm::isa<TransformParamTypeInterface>(operand.getType()) &&
"unsupported kind of transform dialect value");
llvm::append_range(mapped, state.getParams(operand));
}
@@ -1220,14 +1222,15 @@ void transform::detail::forwardTerminatorOperands(
for (auto &&[terminatorOperand, result] :
llvm::zip(block->getTerminator()->getOperands(),
block->getParentOp()->getOpResults())) {
- if (result.getType().isa<transform::TransformHandleTypeInterface>()) {
+ if (llvm::isa<transform::TransformHandleTypeInterface>(result.getType())) {
results.set(result, state.getPayloadOps(terminatorOperand));
- } else if (result.getType()
- .isa<transform::TransformValueHandleTypeInterface>()) {
+ } else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
+ result.getType())) {
results.setValues(result, state.getPayloadValues(terminatorOperand));
} else {
- assert(result.getType().isa<transform::TransformParamTypeInterface>() &&
- "unhandled transform type interface");
+ assert(
+ llvm::isa<transform::TransformParamTypeInterface>(result.getType()) &&
+ "unhandled transform type interface");
results.setParams(result, state.getParams(terminatorOperand));
}
}
@@ -1291,7 +1294,8 @@ transform::detail::verifyPossibleTopLevelTransformOpTrait(Operation *op) {
return op->emitOpError()
<< "expects the entry block to have at least one argument";
}
- if (!body->getArgument(0).getType().isa<TransformHandleTypeInterface>()) {
+ if (!llvm::isa<TransformHandleTypeInterface>(
+ body->getArgument(0).getType())) {
return op->emitOpError()
<< "expects the first entry block argument to be of type "
"implementing TransformHandleTypeInterface";
@@ -1305,9 +1309,8 @@ transform::detail::verifyPossibleTopLevelTransformOpTrait(Operation *op) {
}
}
for (BlockArgument arg : body->getArguments().drop_front()) {
- if (arg.getType()
- .isa<TransformHandleTypeInterface, TransformParamTypeInterface,
- TransformValueHandleTypeInterface>())
+ if (llvm::isa<TransformHandleTypeInterface, TransformParamTypeInterface,
+ TransformValueHandleTypeInterface>(arg.getType()))
continue;
InFlightDiagnostic diag =
@@ -1344,9 +1347,8 @@ void transform::detail::getParamProducerTransformOpTraitEffects(
bool hasPayloadOperands = false;
for (Value operand : op->getOperands()) {
onlyReadsHandle(operand, effects);
- if (operand.getType()
- .isa<TransformHandleTypeInterface,
- TransformValueHandleTypeInterface>())
+ if (llvm::isa<TransformHandleTypeInterface,
+ TransformValueHandleTypeInterface>(operand.getType()))
hasPayloadOperands = true;
}
if (hasPayloadOperands)
@@ -1364,7 +1366,7 @@ transform::detail::verifyParamProducerTransformOpTrait(Operation *op) {
op->getName().getStringRef());
}
for (Value result : op->getResults()) {
- if (result.getType().isa<TransformParamTypeInterface>())
+ if (llvm::isa<TransformParamTypeInterface>(result.getType()))
continue;
return op->emitOpError()
<< "ParamProducerTransformOpTrait attached to this op expects "
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 62ef94d54e477..7451193d51fd0 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -440,8 +440,8 @@ bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
return llvm::all_of(
std::initializer_list<Type>{inputs.front(), outputs.front()},
[](Type ty) {
- return ty
- .isa<pdl::OperationType, transform::TransformHandleTypeInterface>();
+ return llvm::isa<pdl::OperationType,
+ transform::TransformHandleTypeInterface>(ty);
});
}
@@ -563,7 +563,8 @@ transform::ForeachMatchOp::apply(transform::TransformResults &results,
// the payload to the result. Note that we need to consume the root handle to
// make sure any handles to operations inside, that could have been affected
// by actions, are invalidated.
- results.set(getUpdated().cast<OpResult>(), state.getPayloadOps(getRoot()));
+ results.set(llvm::cast<OpResult>(getUpdated()),
+ state.getPayloadOps(getRoot()));
return DiagnosedSilenceableFailure::success();
}
@@ -810,7 +811,7 @@ transform::ForeachOp::apply(transform::TransformResults &results,
}
for (unsigned i = 0; i < getNumResults(); ++i)
- results.set(getResult(i).cast<OpResult>(), resultOps[i]);
+ results.set(llvm::cast<OpResult>(getResult(i)), resultOps[i]);
return DiagnosedSilenceableFailure::success();
}
@@ -863,7 +864,7 @@ LogicalResult transform::ForeachOp::verify() {
return emitOpError() << "expects the same number of results as the "
"terminator has operands";
for (Value v : yieldOp.getOperands())
- if (!v.getType().isa<TransformHandleTypeInterface>())
+ if (!llvm::isa<TransformHandleTypeInterface>(v.getType()))
return yieldOp->emitOpError("expects operands to have types implementing "
"TransformHandleTypeInterface");
return success();
@@ -888,7 +889,7 @@ DiagnosedSilenceableFailure transform::GetClosestIsolatedParentOp::apply(
}
parents.insert(parent);
}
- results.set(getResult().cast<OpResult>(), parents.getArrayRef());
+ results.set(llvm::cast<OpResult>(getResult()), parents.getArrayRef());
return DiagnosedSilenceableFailure::success();
}
@@ -902,7 +903,7 @@ transform::GetConsumersOfResult::apply(transform::TransformResults &results,
int64_t resultNumber = getResultNumber();
ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
if (payloadOps.empty()) {
- results.set(getResult().cast<OpResult>(), {});
+ results.set(llvm::cast<OpResult>(getResult()), {});
return DiagnosedSilenceableFailure::success();
}
if (payloadOps.size() != 1)
@@ -912,7 +913,7 @@ transform::GetConsumersOfResult::apply(transform::TransformResults &results,
Operation *target = payloadOps.front();
if (target->getNumResults() <= resultNumber)
return emitDefiniteFailure() << "result number overflow";
- results.set(getResult().cast<OpResult>(),
+ results.set(llvm::cast<OpResult>(getResult()),
llvm::to_vector(target->getResult(resultNumber).getUsers()));
return DiagnosedSilenceableFailure::success();
}
@@ -926,7 +927,7 @@ transform::GetDefiningOp::apply(transform::TransformResults &results,
transform::TransformState &state) {
SmallVector<Operation *> definingOps;
for (Value v : state.getPayloadValues(getTarget())) {
- if (v.isa<BlockArgument>()) {
+ if (llvm::isa<BlockArgument>(v)) {
DiagnosedSilenceableFailure diag =
emitSilenceableError() << "cannot get defining op of block argument";
diag.attachNote(v.getLoc()) << "target value";
@@ -934,7 +935,7 @@ transform::GetDefiningOp::apply(transform::TransformResults &results,
}
definingOps.push_back(v.getDefiningOp());
}
- results.set(getResult().cast<OpResult>(), definingOps);
+ results.set(llvm::cast<OpResult>(getResult()), definingOps);
return DiagnosedSilenceableFailure::success();
}
@@ -962,7 +963,7 @@ transform::GetProducerOfOperand::apply(transform::TransformResults &results,
}
producers.push_back(producer);
}
- results.set(getResult().cast<OpResult>(), producers);
+ results.set(llvm::cast<OpResult>(getResult()), producers);
return DiagnosedSilenceableFailure::success();
}
@@ -984,7 +985,7 @@ transform::GetResultOp::apply(transform::TransformResults &results,
}
opResults.push_back(target->getOpResult(resultNumber));
}
- results.setValues(getResult().cast<OpResult>(), opResults);
+ results.setValues(llvm::cast<OpResult>(getResult()), opResults);
return DiagnosedSilenceableFailure::success();
}
@@ -1211,8 +1212,8 @@ transform::MatchParamCmpIOp::apply(transform::TransformResults &results,
}
for (auto &&[i, param, reference] : llvm::enumerate(params, references)) {
- auto intAttr = param.dyn_cast<IntegerAttr>();
- auto refAttr = reference.dyn_cast<IntegerAttr>();
+ auto intAttr = llvm::dyn_cast<IntegerAttr>(param);
+ auto refAttr = llvm::dyn_cast<IntegerAttr>(reference);
if (!intAttr || !refAttr) {
return emitDefiniteFailure()
<< "non-integer parameter value not expected";
@@ -1295,12 +1296,12 @@ transform::MergeHandlesOp::apply(transform::TransformResults &results,
for (Value operand : getHandles())
llvm::append_range(operations, state.getPayloadOps(operand));
if (!getDeduplicate()) {
- results.set(getResult().cast<OpResult>(), operations);
+ results.set(llvm::cast<OpResult>(getResult()), operations);
return DiagnosedSilenceableFailure::success();
}
SetVector<Operation *> uniqued(operations.begin(), operations.end());
- results.set(getResult().cast<OpResult>(), uniqued.getArrayRef());
+ results.set(llvm::cast<OpResult>(getResult()), uniqued.getArrayRef());
return DiagnosedSilenceableFailure::success();
}
@@ -1535,7 +1536,7 @@ transform::SplitHandleOp::apply(transform::TransformResults &results,
// Set transform op results.
for (auto &&it : llvm::enumerate(resultHandles))
- results.set(getResult(it.index()).cast<OpResult>(), it.value());
+ results.set(llvm::cast<OpResult>(getResult(it.index())), it.value());
return DiagnosedSilenceableFailure::success();
}
@@ -1573,7 +1574,7 @@ transform::PDLMatchOp::apply(transform::TransformResults &results,
<< "could not find pattern '" << getPatternName() << "'";
}
}
- results.set(getResult().cast<OpResult>(), targets);
+ results.set(llvm::cast<OpResult>(getResult()), targets);
return DiagnosedSilenceableFailure::success();
}
@@ -1594,22 +1595,23 @@ transform::ReplicateOp::apply(transform::TransformResults &results,
unsigned numRepetitions = state.getPayloadOps(getPattern()).size();
for (const auto &en : llvm::enumerate(getHandles())) {
Value handle = en.value();
- if (handle.getType().isa<TransformHandleTypeInterface>()) {
+ if (llvm::isa<TransformHandleTypeInterface>(handle.getType())) {
ArrayRef<Operation *> current = state.getPayloadOps(handle);
SmallVector<Operation *> payload;
payload.reserve(numRepetitions * current.size());
for (unsigned i = 0; i < numRepetitions; ++i)
llvm::append_range(payload, current);
- results.set(getReplicated()[en.index()].cast<OpResult>(), payload);
+ results.set(llvm::cast<OpResult>(getReplicated()[en.index()]), payload);
} else {
- assert(handle.getType().isa<TransformParamTypeInterface>() &&
+ assert(llvm::isa<TransformParamTypeInterface>(handle.getType()) &&
"expected param type");
ArrayRef<Attribute> current = state.getParams(handle);
SmallVector<Attribute> params;
params.reserve(numRepetitions * current.size());
for (unsigned i = 0; i < numRepetitions; ++i)
llvm::append_range(params, current);
- results.setParams(getReplicated()[en.index()].cast<OpResult>(), params);
+ results.setParams(llvm::cast<OpResult>(getReplicated()[en.index()]),
+ params);
}
}
return DiagnosedSilenceableFailure::success();
diff --git a/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp b/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp
index 579e1065e9521..d67016efd0971 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformTypes.cpp
@@ -75,7 +75,7 @@ transform::OperationType::checkPayload(Location loc,
LogicalResult
transform::ParamType::verify(function_ref<InFlightDiagnostic()> emitError,
Type type) {
- IntegerType intType = type.dyn_cast<IntegerType>();
+ IntegerType intType = llvm::dyn_cast<IntegerType>(type);
if (!intType || intType.getWidth() > 64)
return emitError() << "only supports integer types with width <=64";
return success();
@@ -85,7 +85,7 @@ DiagnosedSilenceableFailure
transform::ParamType::checkPayload(Location loc,
ArrayRef<Attribute> payload) const {
for (Attribute attr : payload) {
- auto integerAttr = attr.dyn_cast<IntegerAttr>();
+ auto integerAttr = llvm::dyn_cast<IntegerAttr>(attr);
if (!integerAttr) {
return emitSilenceableError(loc)
<< "expected parameter to be an integer attribute, got " << attr;
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 99eed540afb0e..2d3b27f1af58f 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -65,7 +65,7 @@ static MaskFormat getMaskFormat(Value mask) {
// Inspect constant dense values. We count up for bits that
// are set, count down for bits that are cleared, and bail
// when a mix is detected.
- if (auto denseElts = c.getValue().dyn_cast<DenseIntElementsAttr>()) {
+ if (auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) {
int64_t val = 0;
for (bool b : denseElts.getValues<bool>())
if (b && val >= 0)
@@ -88,7 +88,7 @@ static MaskFormat getMaskFormat(Value mask) {
bool allTrue = true;
bool allFalse = true;
for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
- int64_t i = maskIdx.cast<IntegerAttr>().getInt();
+ int64_t i = llvm::cast<IntegerAttr>(maskIdx).getInt();
if (i < dimSize)
allTrue = false;
if (i > 0)
@@ -125,7 +125,7 @@ static bool isSupportedCombiningKind(CombiningKind combiningKind,
return elementType.isIntOrIndex();
case CombiningKind::MINF:
case CombiningKind::MAXF:
- return elementType.isa<FloatType>();
+ return llvm::isa<FloatType>(elementType);
}
return false;
}
@@ -143,7 +143,7 @@ AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
VectorType vectorType) {
int64_t elementVectorRank = 0;
VectorType elementVectorType =
- shapedType.getElementType().dyn_cast<VectorType>();
+ llvm::dyn_cast<VectorType>(shapedType.getElementType());
if (elementVectorType)
elementVectorRank += elementVectorType.getRank();
// 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
@@ -190,15 +190,15 @@ bool mlir::vector::isDisjointTransferIndices(
if (i < rankOffset) {
// For leading dimensions, if we can prove that index are different we
// know we are accessing disjoint slices.
- if (indexA.getValue().cast<IntegerAttr>().getInt() !=
- indexB.getValue().cast<IntegerAttr>().getInt())
+ if (llvm::cast<IntegerAttr>(indexA.getValue()).getInt() !=
+ llvm::cast<IntegerAttr>(indexB.getValue()).getInt())
return true;
} else {
// For this dimension, we slice a part of the memref we need to make sure
// the intervals accessed don't overlap.
int64_t distance =
- std::abs(indexA.getValue().cast<IntegerAttr>().getInt() -
- indexB.getValue().cast<IntegerAttr>().getInt());
+ std::abs(llvm::cast<IntegerAttr>(indexA.getValue()).getInt() -
+ llvm::cast<IntegerAttr>(indexB.getValue()).getInt());
if (distance >= transferA.getVectorType().getDimSize(i - rankOffset))
return true;
}
@@ -325,7 +325,7 @@ LogicalResult MultiDimReductionOp::verify() {
Type inferredReturnType;
for (auto it : llvm::enumerate(getSourceVectorType().getShape()))
if (!llvm::any_of(getReductionDims().getValue(), [&](Attribute attr) {
- return attr.cast<IntegerAttr>().getValue() == it.index();
+ return llvm::cast<IntegerAttr>(attr).getValue() == it.index();
}))
targetShape.push_back(it.value());
// TODO: update to also allow 0-d vectors when available.
@@ -426,8 +426,9 @@ void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
CombiningKind kind, Value vector, Value acc) {
- build(builder, result, vector.getType().cast<VectorType>().getElementType(),
- kind, vector, acc);
+ build(builder, result,
+ llvm::cast<VectorType>(vector.getType()).getElementType(), kind, vector,
+ acc);
}
LogicalResult ReductionOp::verify() {
@@ -659,9 +660,8 @@ ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
// because tests still use the old format when 'iterator_types' attribute is
// represented as an array of strings.
// TODO: Remove this conversion once tests are fixed.
- ArrayAttr iteratorTypes =
- result.attributes.get(getIteratorTypesAttrName(result.name))
- .cast<ArrayAttr>();
+ ArrayAttr iteratorTypes = llvm::cast<ArrayAttr>(
+ result.attributes.get(getIteratorTypesAttrName(result.name)));
SmallVector<Attribute> iteratorTypeAttrs;
@@ -687,8 +687,8 @@ ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
if (masksInfo.size() != 2)
return parser.emitError(parser.getNameLoc(),
"expected zero or exactly 2 vector mask operands");
- auto lhsType = types[0].cast<VectorType>();
- auto rhsType = types[1].cast<VectorType>();
+ auto lhsType = llvm::cast<VectorType>(types[0]);
+ auto rhsType = llvm::cast<VectorType>(types[1]);
auto maskElementType = parser.getBuilder().getI1Type();
std::array<Type, 2> maskTypes = {
VectorType::Builder(lhsType).setElementType(maskElementType),
@@ -707,8 +707,7 @@ void ContractionOp::print(OpAsmPrinter &p) {
for (auto attr : (*this)->getAttrs()) {
if (attr.getName() == getIteratorTypesAttrName()) {
auto iteratorTypes =
- attr.getValue()
- .cast<ArrayAttr>()
+ llvm::cast<ArrayAttr>(attr.getValue())
.getAsValueRange<IteratorTypeAttr, IteratorType>();
// Convert IteratorType enums into the string representation. This is
// needed, because tests still use the old format when 'iterator_types'
@@ -778,12 +777,12 @@ static LogicalResult verifyOutputShape(
// Verify 'expectedResultDims'.
if (expectedResultDims.empty()) {
// No batch or free dimension implies a scalar result.
- if (resType.isa<VectorType>() || accType.isa<VectorType>())
+ if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType))
return op.emitOpError("invalid accumulator/result vector shape");
} else {
// At least one batch or free dimension implies a vector result.
- auto resVectorType = resType.dyn_cast<VectorType>();
- auto accVectorType = accType.dyn_cast<VectorType>();
+ auto resVectorType = llvm::dyn_cast<VectorType>(resType);
+ auto accVectorType = llvm::dyn_cast<VectorType>(accType);
if (!resVectorType || !accVectorType)
return op.emitOpError("invalid accumulator/result vector shape");
@@ -841,7 +840,7 @@ LogicalResult ContractionOp::verify() {
Type accType = getAccType();
Type resType = getResultType();
- if (lhsType.getElementType().isa<IntegerType>()) {
+ if (llvm::isa<IntegerType>(lhsType.getElementType())) {
if (!lhsType.getElementType().isSignlessInteger())
return emitOpError("only supports signless integer types");
}
@@ -860,7 +859,7 @@ LogicalResult ContractionOp::verify() {
if (map.getNumSymbols() != 0)
return emitOpError("expected indexing map ")
<< index << " to have no symbols";
- auto vectorType = getOperand(index).getType().dyn_cast<VectorType>();
+ auto vectorType = llvm::dyn_cast<VectorType>(getOperand(index).getType());
unsigned rank = vectorType ? vectorType.getShape().size() : 0;
// Verify that the map has the right number of inputs, outputs, and indices.
// This also correctly accounts for (..) -> () for rank-0 results.
@@ -896,7 +895,7 @@ LogicalResult ContractionOp::verify() {
return failure();
// Verify supported combining kind.
- auto vectorType = resType.dyn_cast<VectorType>();
+ auto vectorType = llvm::dyn_cast<VectorType>(resType);
auto elementType = vectorType ? vectorType.getElementType() : resType;
if (!isSupportedCombiningKind(getKind(), elementType))
return emitOpError("unsupported contraction type");
@@ -949,7 +948,7 @@ getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
IteratorType targetIteratorType, MLIRContext *context) {
std::vector<std::pair<int64_t, int64_t>> dimMap;
for (const auto &it : llvm::enumerate(iteratorTypes)) {
- auto iteratorType = it.value().cast<IteratorTypeAttr>().getValue();
+ auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
if (iteratorType != targetIteratorType)
continue;
// Search lhs/rhs map results for 'targetExpr'.
@@ -965,13 +964,13 @@ getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
void ContractionOp::getIterationBounds(
SmallVectorImpl<int64_t> &iterationBounds) {
auto lhsShape = getLhsType().getShape();
- auto resVectorType = getResultType().dyn_cast<VectorType>();
+ auto resVectorType = llvm::dyn_cast<VectorType>(getResultType());
SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
SmallVector<int64_t, 2> iterationShape;
for (const auto &it : llvm::enumerate(getIteratorTypes())) {
// Search lhs/rhs map results for 'targetExpr'.
auto targetExpr = getAffineDimExpr(it.index(), getContext());
- auto iteratorType = it.value().cast<IteratorTypeAttr>().getValue();
+ auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
if (iteratorType == IteratorType::reduction) {
// Get reduction dim size from lhs shape (same size in rhsShape).
int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
@@ -1085,7 +1084,7 @@ void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
Value source) {
result.addOperands({source});
- result.addTypes(source.getType().cast<VectorType>().getElementType());
+ result.addTypes(llvm::cast<VectorType>(source.getType()).getElementType());
}
LogicalResult vector::ExtractElementOp::verify() {
@@ -1116,15 +1115,15 @@ OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
// Fold extractelement(broadcast(X)) -> X.
if (auto broadcast = getVector().getDefiningOp<vector::BroadcastOp>())
- if (!broadcast.getSource().getType().isa<VectorType>())
+ if (!llvm::isa<VectorType>(broadcast.getSource().getType()))
return broadcast.getSource();
if (!pos || !src)
return {};
- auto srcElements = src.cast<DenseElementsAttr>().getValues<Attribute>();
+ auto srcElements = llvm::cast<DenseElementsAttr>(src).getValues<Attribute>();
- auto attr = pos.dyn_cast<IntegerAttr>();
+ auto attr = llvm::dyn_cast<IntegerAttr>(pos);
uint64_t posIdx = attr.getInt();
return srcElements[posIdx];
@@ -1155,7 +1154,7 @@ ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
OpaqueProperties properties, RegionRange,
SmallVectorImpl<Type> &inferredReturnTypes) {
ExtractOp::Adaptor op(operands, attributes);
- auto vectorType = op.getVector().getType().cast<VectorType>();
+ auto vectorType = llvm::cast<VectorType>(op.getVector().getType());
if (static_cast<int64_t>(op.getPosition().size()) == vectorType.getRank()) {
inferredReturnTypes.push_back(vectorType.getElementType());
} else {
@@ -1170,7 +1169,7 @@ ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
// Allow extracting 1-element vectors instead of scalars.
auto isCompatible = [](TypeRange l, TypeRange r) {
- auto vectorType = l.front().dyn_cast<VectorType>();
+ auto vectorType = llvm::dyn_cast<VectorType>(l.front());
return vectorType && vectorType.getShape().equals({1}) &&
vectorType.getElementType() == r.front();
};
@@ -1187,7 +1186,7 @@ LogicalResult vector::ExtractOp::verify() {
return emitOpError(
"expected position attribute of rank smaller than vector rank");
for (const auto &en : llvm::enumerate(positionAttr)) {
- auto attr = en.value().dyn_cast<IntegerAttr>();
+ auto attr = llvm::dyn_cast<IntegerAttr>(en.value());
if (!attr || attr.getInt() < 0 ||
attr.getInt() >= getSourceVectorType().getDimSize(en.index()))
return emitOpError("expected position attribute #")
@@ -1451,7 +1450,8 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
if (extractOp.getType() == source.getType())
return source;
auto getRank = [](Type type) {
- return type.isa<VectorType>() ? type.cast<VectorType>().getRank() : 0;
+ return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
+ : 0;
};
// If splat or broadcast from a scalar, just return the source scalar.
unsigned broadcastSrcRank = getRank(source.getType());
@@ -1462,8 +1462,8 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
if (extractResultRank >= broadcastSrcRank)
return Value();
// Check that the dimension of the result haven't been broadcasted.
- auto extractVecType = extractOp.getType().dyn_cast<VectorType>();
- auto broadcastVecType = source.getType().dyn_cast<VectorType>();
+ auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType());
+ auto broadcastVecType = llvm::dyn_cast<VectorType>(source.getType());
if (extractVecType && broadcastVecType &&
extractVecType.getShape() !=
broadcastVecType.getShape().take_back(extractResultRank))
@@ -1502,13 +1502,14 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
return type.getShape().take_back(n + 1).front();
};
int64_t destinationRank =
- extractOp.getType().isa<VectorType>()
- ? extractOp.getType().cast<VectorType>().getRank()
+ llvm::isa<VectorType>(extractOp.getType())
+ ? llvm::cast<VectorType>(extractOp.getType()).getRank()
: 0;
if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
return Value();
if (destinationRank > 0) {
- auto destinationType = extractOp.getResult().getType().cast<VectorType>();
+ auto destinationType =
+ llvm::cast<VectorType>(extractOp.getResult().getType());
for (int64_t i = 0; i < destinationRank; i++) {
// The lowest dimension of of the destination must match the lowest
// dimension of the shapecast op source.
@@ -1574,7 +1575,7 @@ static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
sliceOffsets.pop_back();
}
unsigned destinationRank = 0;
- if (auto vecType = extractOp.getType().dyn_cast<VectorType>())
+ if (auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType()))
destinationRank = vecType.getRank();
// The dimensions of the result need to be untouched by the
// extractStridedSlice op.
@@ -1595,8 +1596,8 @@ static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
/// Fold extract_op fed from a chain of insertStridedSlice ops.
static Value foldExtractStridedOpFromInsertChain(ExtractOp op) {
- int64_t destinationRank = op.getType().isa<VectorType>()
- ? op.getType().cast<VectorType>().getRank()
+ int64_t destinationRank = llvm::isa<VectorType>(op.getType())
+ ? llvm::cast<VectorType>(op.getType()).getRank()
: 0;
auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
while (insertOp) {
@@ -1608,7 +1609,7 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp op) {
auto extractOffsets = extractVector<int64_t>(op.getPosition());
if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
- return attr.cast<IntegerAttr>().getInt() != 1;
+ return llvm::cast<IntegerAttr>(attr).getInt() != 1;
}))
return Value();
bool disjoint = false;
@@ -1691,7 +1692,9 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
if (extractOp.getType() == source.getType())
return failure();
auto getRank = [](Type type) {
- return type.isa<VectorType>() ? type.cast<VectorType>().getRank() : 0;
+ return llvm::isa<VectorType>(type)
+ ? llvm::cast<VectorType>(type).getRank()
+ : 0;
};
unsigned broadcastSrcRank = getRank(source.getType());
unsigned extractResultRank = getRank(extractOp.getType());
@@ -1703,7 +1706,7 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
// Special case if broadcast src is a 0D vector.
if (extractResultRank == 0) {
- assert(broadcastSrcRank == 0 && source.getType().isa<VectorType>());
+ assert(broadcastSrcRank == 0 && llvm::isa<VectorType>(source.getType()));
rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(extractOp, source);
return success();
}
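The commit message notes a limitation of unprefixed `cast` calls: inside class definitions they resolve to the member method. That is one reason files like this one spell out `llvm::cast` throughout. The sketch below reproduces the lookup rule with hypothetical stand-ins. The `mini` namespace plays the role of `llvm`, and `Attribute`/`IntegerAttr` are toy classes rather than MLIR's.

```cpp
#include <cassert>

// Hypothetical stand-in for namespace llvm and its free cast<>.
namespace mini {
template <typename To, typename From> To cast(const From &v) {
  assert(To::classof(v) && "cast<To>() argument of incompatible type!");
  return To::get(v);
}
} // namespace mini

// Toy attribute handle that still has an old-style member cast<T>().
struct Attribute {
  int kind;   // 1 means "integer" in this sketch
  long value;

  template <typename T> T cast() const { return mini::cast<T>(*this); }

  // In a member function body, ordinary lookup of the unqualified name
  // `cast` finds the member template above and stops there. That hides any
  // namespace-scope `cast` and also suppresses argument-dependent lookup.
  // Writing `cast<IntegerAttr>(*this)` in asInt() would try to call the
  // zero-argument member with one argument and fail to compile, so the
  // namespace is spelled out.
  long asInt() const;
};

struct IntegerAttr {
  long v;
  static bool classof(const Attribute &a) { return a.kind == 1; }
  static IntegerAttr get(const Attribute &a) { return IntegerAttr{a.value}; }
  long getInt() const { return v; }
};

long Attribute::asInt() const {
  return mini::cast<IntegerAttr>(*this).getInt();
}
```

Qualifying the call also keeps such code valid after the member methods are removed. The commit message says it then intends to drop the `llvm::` prefix with a find-replace.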
@@ -1726,11 +1729,11 @@ class ExtractOpSplatConstantFolder final : public OpRewritePattern<ExtractOp> {
Attribute vectorCst;
if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
return failure();
- auto splat = vectorCst.dyn_cast<SplatElementsAttr>();
+ auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
if (!splat)
return failure();
TypedAttr newAttr = splat.getSplatValue<TypedAttr>();
- if (auto vecDstType = extractOp.getType().dyn_cast<VectorType>())
+ if (auto vecDstType = llvm::dyn_cast<VectorType>(extractOp.getType()))
newAttr = DenseElementsAttr::get(vecDstType, newAttr);
rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
return success();
@@ -1752,12 +1755,12 @@ class ExtractOpNonSplatConstantFolder final
if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
return failure();
- auto vecTy = sourceVector.getType().cast<VectorType>();
+ auto vecTy = llvm::cast<VectorType>(sourceVector.getType());
if (vecTy.isScalable())
return failure();
// The splat case is handled by `ExtractOpSplatConstantFolder`.
- auto dense = vectorCst.dyn_cast<DenseElementsAttr>();
+ auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
if (!dense || dense.isSplat())
return failure();
@@ -1770,7 +1773,7 @@ class ExtractOpNonSplatConstantFolder final
auto denseValuesBegin = dense.value_begin<TypedAttr>() + elemBeginPosition;
TypedAttr newAttr;
- if (auto resVecTy = extractOp.getType().dyn_cast<VectorType>()) {
+ if (auto resVecTy = llvm::dyn_cast<VectorType>(extractOp.getType())) {
SmallVector<Attribute> elementValues(
denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
newAttr = DenseElementsAttr::get(resVecTy, elementValues);
@@ -1794,7 +1797,7 @@ void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
SmallVectorImpl<int64_t> &results) {
for (auto attr : arrayAttr)
- results.push_back(attr.cast<IntegerAttr>().getInt());
+ results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
}
//===----------------------------------------------------------------------===//
@@ -1830,7 +1833,7 @@ computeBroadcastedUnitDims(ArrayRef<int64_t> srcShape,
llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
// Scalar broadcast is without any unit dim broadcast.
- auto srcVectorType = getSourceType().dyn_cast<VectorType>();
+ auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
if (!srcVectorType)
return {};
return ::computeBroadcastedUnitDims(srcVectorType.getShape(),
@@ -1867,7 +1870,7 @@ Value BroadcastOp::createOrFoldBroadcastOp(
Location loc = value.getLoc();
Type elementType = getElementTypeOrSelf(value.getType());
- VectorType srcVectorType = value.getType().dyn_cast<VectorType>();
+ VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.getType());
VectorType dstVectorType = VectorType::get(dstShape, elementType);
// Step 2. If scalar -> dstShape broadcast, just do it.
@@ -1952,7 +1955,7 @@ mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType))
return BroadcastableToResult::Success;
// From now on, only vectors broadcast.
- VectorType srcVectorType = srcType.dyn_cast<VectorType>();
+ VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
if (!srcVectorType)
return BroadcastableToResult::SourceTypeNotAVector;
@@ -2074,7 +2077,7 @@ LogicalResult ShuffleOp::verify() {
int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
(v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
for (const auto &en : llvm::enumerate(maskAttr)) {
- auto attr = en.value().dyn_cast<IntegerAttr>();
+ auto attr = llvm::dyn_cast<IntegerAttr>(en.value());
if (!attr || attr.getInt() < 0 || attr.getInt() >= indexSize)
return emitOpError("mask index #") << (en.index() + 1) << " out of range";
}
@@ -2087,7 +2090,7 @@ ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
OpaqueProperties properties, RegionRange,
SmallVectorImpl<Type> &inferredReturnTypes) {
ShuffleOp::Adaptor op(operands, attributes);
- auto v1Type = op.getV1().getType().cast<VectorType>();
+ auto v1Type = llvm::cast<VectorType>(op.getV1().getType());
auto v1Rank = v1Type.getRank();
// Construct resulting type: leading dimension matches mask
// length, all trailing dimensions match the operands.
@@ -2132,7 +2135,8 @@ OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
if (!lhs || !rhs)
return {};
- auto lhsType = lhs.cast<DenseElementsAttr>().getType().cast<VectorType>();
+ auto lhsType =
+ llvm::cast<VectorType>(llvm::cast<DenseElementsAttr>(lhs).getType());
// Only support 1-D for now to avoid complicated n-D DenseElementsAttr
// manipulation.
if (lhsType.getRank() != 1)
@@ -2140,8 +2144,8 @@ OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
int64_t lhsSize = lhsType.getDimSize(0);
SmallVector<Attribute> results;
- auto lhsElements = lhs.cast<DenseElementsAttr>().getValues<Attribute>();
- auto rhsElements = rhs.cast<DenseElementsAttr>().getValues<Attribute>();
+ auto lhsElements = llvm::cast<DenseElementsAttr>(lhs).getValues<Attribute>();
+ auto rhsElements = llvm::cast<DenseElementsAttr>(rhs).getValues<Attribute>();
for (const auto &index : this->getMask().getAsValueRange<IntegerAttr>()) {
int64_t i = index.getZExtValue();
if (i >= lhsSize) {
@@ -2170,7 +2174,7 @@ struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
if (mask.size() != 1)
return failure();
Type resType = VectorType::Builder(v1VectorType).setShape({1});
- if (mask[0].cast<IntegerAttr>().getInt() == 0)
+ if (llvm::cast<IntegerAttr>(mask[0]).getInt() == 0)
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
shuffleOp.getV1());
else
@@ -2242,11 +2246,11 @@ OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) {
if (!src || !dst || !pos)
return {};
- auto dstElements = dst.cast<DenseElementsAttr>().getValues<Attribute>();
+ auto dstElements = llvm::cast<DenseElementsAttr>(dst).getValues<Attribute>();
SmallVector<Attribute> results(dstElements);
- auto attr = pos.dyn_cast<IntegerAttr>();
+ auto attr = llvm::dyn_cast<IntegerAttr>(pos);
uint64_t posIdx = attr.getInt();
results[posIdx] = src;
@@ -2282,7 +2286,7 @@ LogicalResult InsertOp::verify() {
if (positionAttr.size() > static_cast<unsigned>(destVectorType.getRank()))
return emitOpError(
"expected position attribute of rank smaller than dest vector rank");
- auto srcVectorType = getSourceType().dyn_cast<VectorType>();
+ auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
if (srcVectorType &&
(static_cast<unsigned>(srcVectorType.getRank()) + positionAttr.size() !=
static_cast<unsigned>(destVectorType.getRank())))
@@ -2293,7 +2297,7 @@ LogicalResult InsertOp::verify() {
return emitOpError(
"expected position attribute rank to match the dest vector rank");
for (const auto &en : llvm::enumerate(positionAttr)) {
- auto attr = en.value().dyn_cast<IntegerAttr>();
+ auto attr = llvm::dyn_cast<IntegerAttr>(en.value());
if (!attr || attr.getInt() < 0 ||
attr.getInt() >= destVectorType.getDimSize(en.index()))
return emitOpError("expected position attribute #")
@@ -2314,7 +2318,7 @@ class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
LogicalResult matchAndRewrite(InsertOp insertOp,
PatternRewriter &rewriter) const override {
- auto srcVecType = insertOp.getSourceType().dyn_cast<VectorType>();
+ auto srcVecType = llvm::dyn_cast<VectorType>(insertOp.getSourceType());
if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
srcVecType.getNumElements())
return failure();
@@ -2372,7 +2376,7 @@ class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {
!destVector.hasOneUse())
return failure();
- auto denseDest = vectorDestCst.cast<DenseElementsAttr>();
+ auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
Value sourceValue = op.getSource();
Attribute sourceCst;
@@ -2387,7 +2391,7 @@ class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {
linearize(completePositions, computeStrides(destTy.getShape()));
SmallVector<Attribute> insertedValues;
- if (auto denseSource = sourceCst.dyn_cast<DenseElementsAttr>())
+ if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(sourceCst))
llvm::append_range(insertedValues, denseSource.getValues<Attribute>());
else
insertedValues.push_back(sourceCst);
@@ -2455,7 +2459,7 @@ isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
int64_t max, StringRef attrName,
bool halfOpen = true) {
for (auto attr : arrayAttr) {
- auto val = attr.cast<IntegerAttr>().getInt();
+ auto val = llvm::cast<IntegerAttr>(attr).getInt();
auto upper = max;
if (!halfOpen)
upper += 1;
@@ -2476,8 +2480,7 @@ isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
bool halfOpen = true, int64_t min = 0) {
for (auto [index, attrDimPair] :
llvm::enumerate(llvm::zip_first(arrayAttr, shape))) {
- int64_t val =
- std::get<0>(attrDimPair).template cast<IntegerAttr>().getInt();
+ int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
int64_t max = std::get<1>(attrDimPair);
if (!halfOpen)
max += 1;
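The old member spelling sometimes carried a `.template` disambiguator, as in `std::get<0>(attrDimPair).template cast<IntegerAttr>()` above. C++ requires that keyword when the object's type depends on a template parameter, because the compiler cannot otherwise tell that `cast` names a member template. It is permitted but redundant elsewhere. The free-function form never needs it. A minimal sketch with hypothetical toy `Attribute`/`IntegerAttr` classes:

```cpp
#include <cassert>
#include <tuple>

// Toy attribute handle with an old-style member cast<T>() (hypothetical).
struct Attribute {
  int kind;   // 1 means "integer" in this sketch
  long value;
  template <typename T> T cast() const {
    assert(T::classof(*this) && "incompatible cast");
    return T::get(*this);
  }
};

struct IntegerAttr {
  long v;
  static bool classof(const Attribute &a) { return a.kind == 1; }
  static IntegerAttr get(const Attribute &a) { return IntegerAttr{a.value}; }
  long getInt() const { return v; }
};

// Free-function counterpart, in the spirit of llvm::cast.
template <typename To, typename From> To cast(const From &v) {
  assert(To::classof(v) && "incompatible cast");
  return To::get(v);
}

// Tuple is a template parameter, so std::get<0>(t) has a dependent type and
// the member template call must be prefixed with `template`.
template <typename Tuple> long firstAsIntMember(const Tuple &t) {
  return std::get<0>(t).template cast<IntegerAttr>().getInt();
}

// The free function is found by ordinary lookup; no disambiguator needed.
template <typename Tuple> long firstAsIntFree(const Tuple &t) {
  return cast<IntegerAttr>(std::get<0>(t)).getInt();
}
```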
@@ -2501,8 +2504,8 @@ static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
assert(arrayAttr2.size() <= shape.size());
for (auto [index, it] :
llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2, shape))) {
- auto val1 = std::get<0>(it).template cast<IntegerAttr>().getInt();
- auto val2 = std::get<1>(it).template cast<IntegerAttr>().getInt();
+ auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
+ auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
int64_t max = std::get<2>(it);
if (!halfOpen)
max += 1;
@@ -2643,7 +2646,7 @@ class InsertStridedSliceConstantFolder final
!destVector.hasOneUse())
return failure();
- auto denseDest = vectorDestCst.cast<DenseElementsAttr>();
+ auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
TypedValue<VectorType> sourceValue = op.getSource();
Attribute sourceCst;
@@ -2666,7 +2669,7 @@ class InsertStridedSliceConstantFolder final
// increasing linearized position indices.
// Because the destination may have higher dimensionality then the slice,
// we keep track of two overlapping sets of positions and offsets.
- auto denseSlice = sourceCst.cast<DenseElementsAttr>();
+ auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
auto sliceValuesIt = denseSlice.value_begin<Attribute>();
auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
SmallVector<int64_t> currDestPosition(offsets.begin(), offsets.end());
@@ -2735,8 +2738,8 @@ ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
if (operandsInfo.size() < 2)
return parser.emitError(parser.getNameLoc(),
"expected at least 2 operands");
- VectorType vLHS = tLHS.dyn_cast<VectorType>();
- VectorType vRHS = tRHS.dyn_cast<VectorType>();
+ VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
+ VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
if (!vLHS)
return parser.emitError(parser.getNameLoc(),
"expected vector type for operand #1");
@@ -2771,7 +2774,7 @@ ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
LogicalResult OuterProductOp::verify() {
Type tRHS = getOperandTypeRHS();
VectorType vLHS = getOperandVectorTypeLHS(),
- vRHS = tRHS.dyn_cast<VectorType>(),
+ vRHS = llvm::dyn_cast<VectorType>(tRHS),
vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();
if (vLHS.getRank() != 1)
@@ -2897,7 +2900,7 @@ static Type inferStridedSliceOpResultType(VectorType vectorType,
shape.reserve(vectorType.getRank());
unsigned idx = 0;
for (unsigned e = offsets.size(); idx < e; ++idx)
- shape.push_back(sizes[idx].cast<IntegerAttr>().getInt());
+ shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
shape.push_back(vectorType.getShape()[idx]);
@@ -2913,7 +2916,7 @@ void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
auto stridesAttr = getVectorSubscriptAttr(builder, strides);
result.addTypes(
- inferStridedSliceOpResultType(source.getType().cast<VectorType>(),
+ inferStridedSliceOpResultType(llvm::cast<VectorType>(source.getType()),
offsetsAttr, sizesAttr, stridesAttr));
result.addAttribute(getOffsetsAttrStrName(), offsetsAttr);
result.addAttribute(getSizesAttrStrName(), sizesAttr);
@@ -2967,7 +2970,7 @@ static LogicalResult
foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
// Helper to extract integer out of ArrayAttr.
auto getElement = [](ArrayAttr array, int idx) {
- return array[idx].cast<IntegerAttr>().getInt();
+ return llvm::cast<IntegerAttr>(array[idx]).getInt();
};
ArrayAttr extractOffsets = op.getOffsets();
ArrayAttr extractStrides = op.getStrides();
@@ -3112,7 +3115,7 @@ class StridedSliceSplatConstantFolder final
if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
return failure();
- auto splat = vectorCst.dyn_cast<SplatElementsAttr>();
+ auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
if (!splat)
return failure();
@@ -3141,7 +3144,7 @@ class StridedSliceNonSplatConstantFolder final
return failure();
// The splat case is handled by `StridedSliceSplatConstantFolder`.
- auto dense = vectorCst.dyn_cast<DenseElementsAttr>();
+ auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
if (!dense || dense.isSplat())
return failure();
@@ -3149,7 +3152,7 @@ class StridedSliceNonSplatConstantFolder final
if (extractStridedSliceOp.hasNonUnitStrides())
return failure();
- auto sourceVecTy = sourceVector.getType().cast<VectorType>();
+ auto sourceVecTy = llvm::cast<VectorType>(sourceVector.getType());
ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);
@@ -3201,9 +3204,10 @@ class StridedSliceBroadcast final
auto broadcast = op.getVector().getDefiningOp<BroadcastOp>();
if (!broadcast)
return failure();
- auto srcVecType = broadcast.getSource().getType().dyn_cast<VectorType>();
+ auto srcVecType =
+ llvm::dyn_cast<VectorType>(broadcast.getSource().getType());
unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
- auto dstVecType = op.getType().cast<VectorType>();
+ auto dstVecType = llvm::cast<VectorType>(op.getType());
unsigned dstRank = dstVecType.getRank();
unsigned rankDiff = dstRank - srcRank;
// Check if the most inner dimensions of the source of the broadcast are the
@@ -3269,7 +3273,7 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result,
VectorType vectorType, Value source,
ValueRange indices, AffineMapAttr permutationMapAttr,
/*optional*/ ArrayAttr inBoundsAttr) {
- Type elemType = source.getType().cast<ShapedType>().getElementType();
+ Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
Value padding = builder.create<arith::ConstantOp>(
result.location, elemType, builder.getZeroAttr(elemType));
build(builder, result, vectorType, source, indices, permutationMapAttr,
@@ -3295,7 +3299,7 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result,
ValueRange indices, Value padding,
std::optional<ArrayRef<bool>> inBounds) {
AffineMap permutationMap = getTransferMinorIdentityMap(
- source.getType().cast<ShapedType>(), vectorType);
+ llvm::cast<ShapedType>(source.getType()), vectorType);
auto permutationMapAttr = AffineMapAttr::get(permutationMap);
auto inBoundsAttr = (inBounds && !inBounds.value().empty())
? builder.getBoolArrayAttr(inBounds.value())
@@ -3311,7 +3315,7 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result,
VectorType vectorType, Value source,
ValueRange indices,
std::optional<ArrayRef<bool>> inBounds) {
- Type elemType = source.getType().cast<ShapedType>().getElementType();
+ Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
Value padding = builder.create<arith::ConstantOp>(
result.location, elemType, builder.getZeroAttr(elemType));
build(builder, result, vectorType, source, indices, padding, inBounds);
@@ -3356,13 +3360,13 @@ verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
"Use in_bounds instead.");
}
- if (!shapedType.isa<MemRefType, RankedTensorType>())
+ if (!llvm::isa<MemRefType, RankedTensorType>(shapedType))
return op->emitOpError(
"requires source to be a memref or ranked tensor type");
auto elementType = shapedType.getElementType();
DataLayout dataLayout = DataLayout::closest(op);
- if (auto vectorElementType = elementType.dyn_cast<VectorType>()) {
+ if (auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
// Memref or tensor has vector element type.
unsigned sourceVecSize =
dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
@@ -3425,7 +3429,7 @@ verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
<< " vs inBounds of size: " << inBounds.size();
for (unsigned int i = 0; i < permutationMap.getNumResults(); ++i)
if (permutationMap.getResult(i).isa<AffineConstantExpr>() &&
- !inBounds.getValue()[i].cast<BoolAttr>().getValue())
+ !llvm::cast<BoolAttr>(inBounds.getValue()[i]).getValue())
return op->emitOpError("requires broadcast dimensions to be in-bounds");
}
@@ -3440,7 +3444,7 @@ static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
bool elideInBounds = true;
if (auto inBounds = op.in_bounds()) {
for (auto attr : *inBounds) {
- if (attr.template cast<BoolAttr>().getValue()) {
+ if (llvm::cast<BoolAttr>(attr).getValue()) {
elideInBounds = false;
break;
}
@@ -3496,10 +3500,10 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
if (types.size() != 2)
return parser.emitError(typesLoc, "requires two types");
auto indexType = builder.getIndexType();
- auto shapedType = types[0].dyn_cast<ShapedType>();
- if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>())
+ auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
+ if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
return parser.emitError(typesLoc, "requires memref or ranked tensor type");
- VectorType vectorType = types[1].dyn_cast<VectorType>();
+ VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
if (!vectorType)
return parser.emitError(typesLoc, "requires vector type");
auto permMapAttrName = TransferReadOp::getPermutationMapAttrStrName();
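`llvm::isa` accepts several target types and returns true if the value is any one of them. That is how `!llvm::isa<MemRefType, RankedTensorType>(shapedType)` above replaces the member call of the same form. The sketch below models that behavior with a C++17 fold expression over hypothetical toy type classes. It is not LLVM's implementation.

```cpp
#include <cassert>

// Hypothetical type kinds and a value-typed handle, loosely modeled on MLIR.
enum class Kind { MemRef, RankedTensor, UnrankedTensor, Vector };

class Type {
public:
  explicit Type(Kind kind) : kind(kind) {}
  Kind getKind() const { return kind; }

private:
  Kind kind;
};

struct MemRefType {
  static bool classof(Type t) { return t.getKind() == Kind::MemRef; }
};
struct RankedTensorType {
  static bool classof(Type t) { return t.getKind() == Kind::RankedTensor; }
};

// Variadic isa<>: true if the value matches any of the listed types.
template <typename... Tos, typename From> bool isa(const From &v) {
  static_assert(sizeof...(Tos) > 0, "isa<> needs at least one target type");
  return (Tos::classof(v) || ...);
}

// The same check the parser and verifier hunks perform.
bool isMemRefOrRankedTensor(Type t) {
  return isa<MemRefType, RankedTensorType>(t);
}
```

The explicitly listed types fill the `Tos` pack and `From` is deduced from the argument. Because `||` short-circuits, classification stops at the first match.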
@@ -3509,7 +3513,7 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
permMap = getTransferMinorIdentityMap(shapedType, vectorType);
result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
} else {
- permMap = permMapAttr.cast<AffineMapAttr>().getValue();
+ permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
}
if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
parser.resolveOperands(indexInfo, indexType, result.operands) ||
@@ -3517,7 +3521,7 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
result.operands))
return failure();
if (hasMask.succeeded()) {
- if (shapedType.getElementType().dyn_cast<VectorType>())
+ if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
return parser.emitError(
maskInfo.location, "does not support masks with vector element type");
// Instead of adding the mask type as an op type, compute it based on the
@@ -3554,7 +3558,8 @@ LogicalResult TransferReadOp::verify() {
getInBounds() ? *getInBounds() : ArrayAttr())))
return failure();
- if (auto sourceVectorElementType = sourceElementType.dyn_cast<VectorType>()) {
+ if (auto sourceVectorElementType =
+ llvm::dyn_cast<VectorType>(sourceElementType)) {
// Source has vector element type.
// Check that 'sourceVectorElementType' and 'paddingType' types match.
if (sourceVectorElementType != paddingType)
@@ -3647,7 +3652,7 @@ static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
/// %v0
/// ```
static Value foldRAW(TransferReadOp readOp) {
- if (!readOp.getShapedType().isa<RankedTensorType>())
+ if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
return {};
auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
while (defWrite) {
@@ -3682,7 +3687,7 @@ std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
void TransferReadOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
- if (getShapedType().isa<MemRefType>())
+ if (llvm::isa<MemRefType>(getShapedType()))
effects.emplace_back(MemoryEffects::Read::get(), getSource(),
SideEffects::DefaultResource::get());
}
@@ -3818,7 +3823,7 @@ struct TransferReadAfterWriteToBroadcast
LogicalResult matchAndRewrite(TransferReadOp readOp,
PatternRewriter &rewriter) const override {
if (readOp.hasOutOfBoundsDim() ||
- !readOp.getShapedType().isa<RankedTensorType>())
+ !llvm::isa<RankedTensorType>(readOp.getShapedType()))
return failure();
auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
if (!defWrite)
@@ -3889,7 +3894,7 @@ void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
AffineMapAttr permutationMapAttr,
/*optional*/ Value mask,
/*optional*/ ArrayAttr inBoundsAttr) {
- Type resultType = dest.getType().dyn_cast<RankedTensorType>();
+ Type resultType = llvm::dyn_cast<RankedTensorType>(dest.getType());
build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
mask, inBoundsAttr);
}
@@ -3922,9 +3927,9 @@ void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
Value vector, Value dest, ValueRange indices,
std::optional<ArrayRef<bool>> inBounds) {
- auto vectorType = vector.getType().cast<VectorType>();
+ auto vectorType = llvm::cast<VectorType>(vector.getType());
AffineMap permutationMap = getTransferMinorIdentityMap(
- dest.getType().cast<ShapedType>(), vectorType);
+ llvm::cast<ShapedType>(dest.getType()), vectorType);
build(builder, result, vector, dest, indices, permutationMap, inBounds);
}
@@ -3949,11 +3954,11 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
if (types.size() != 2)
return parser.emitError(typesLoc, "requires two types");
auto indexType = builder.getIndexType();
- VectorType vectorType = types[0].dyn_cast<VectorType>();
+ VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
if (!vectorType)
return parser.emitError(typesLoc, "requires vector type");
- ShapedType shapedType = types[1].dyn_cast<ShapedType>();
- if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>())
+ ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]);
+ if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
return parser.emitError(typesLoc, "requires memref or ranked tensor type");
auto permMapAttrName = TransferWriteOp::getPermutationMapAttrStrName();
auto permMapAttr = result.attributes.get(permMapAttrName);
@@ -3962,14 +3967,14 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
permMap = getTransferMinorIdentityMap(shapedType, vectorType);
result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
} else {
- permMap = permMapAttr.cast<AffineMapAttr>().getValue();
+ permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
}
if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
parser.resolveOperands(indexInfo, indexType, result.operands))
return failure();
if (hasMask.succeeded()) {
- if (shapedType.getElementType().dyn_cast<VectorType>())
+ if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
return parser.emitError(
maskInfo.location, "does not support masks with vector element type");
auto maskType = inferTransferOpMaskType(vectorType, permMap);
@@ -3980,7 +3985,7 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
builder.getDenseI32ArrayAttr(
{1, 1, static_cast<int32_t>(indexInfo.size()),
static_cast<int32_t>(hasMask.succeeded())}));
- return failure(shapedType.isa<RankedTensorType>() &&
+ return failure(llvm::isa<RankedTensorType>(shapedType) &&
parser.addTypeToList(shapedType, result.types));
}
@@ -4052,7 +4057,7 @@ static LogicalResult foldReadInitWrite(TransferWriteOp write,
if (write.getTransferRank() == 0)
return failure();
auto rankedTensorType =
- write.getSource().getType().dyn_cast<RankedTensorType>();
+ llvm::dyn_cast<RankedTensorType>(write.getSource().getType());
// If not operating on tensors, bail.
if (!rankedTensorType)
return failure();
@@ -4119,7 +4124,7 @@ static bool checkSameValueWAR(vector::TransferReadOp read,
/// ```
static LogicalResult foldWAR(TransferWriteOp write,
SmallVectorImpl<OpFoldResult> &results) {
- if (!write.getSource().getType().isa<RankedTensorType>())
+ if (!llvm::isa<RankedTensorType>(write.getSource().getType()))
return failure();
auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
if (!read)
@@ -4149,7 +4154,7 @@ std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
void TransferWriteOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
- if (getShapedType().isa<MemRefType>())
+ if (llvm::isa<MemRefType>(getShapedType()))
effects.emplace_back(MemoryEffects::Write::get(), getSource(),
SideEffects::DefaultResource::get());
}
@@ -4184,7 +4189,7 @@ class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(TransferWriteOp writeOp,
PatternRewriter &rewriter) const override {
- if (!writeOp.getShapedType().isa<RankedTensorType>())
+ if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
return failure();
vector::TransferWriteOp writeToModify = writeOp;
@@ -4439,7 +4444,7 @@ LogicalResult vector::LoadOp::verify() {
// Checks for vector memrefs.
Type memElemTy = memRefTy.getElementType();
- if (auto memVecTy = memElemTy.dyn_cast<VectorType>()) {
+ if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
if (memVecTy != resVecTy)
return emitOpError("base memref and result vector types should match");
memElemTy = memVecTy.getElementType();
@@ -4471,7 +4476,7 @@ LogicalResult vector::StoreOp::verify() {
// Checks for vector memrefs.
Type memElemTy = memRefTy.getElementType();
- if (auto memVecTy = memElemTy.dyn_cast<VectorType>()) {
+ if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
if (memVecTy != valueVecTy)
return emitOpError(
"base memref and valueToStore vector types should match");
@@ -4604,7 +4609,7 @@ LogicalResult GatherOp::verify() {
VectorType resVType = getVectorType();
ShapedType baseType = getBaseType();
- if (!baseType.isa<MemRefType, RankedTensorType>())
+ if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
return emitOpError("requires base to be a memref or ranked tensor type");
if (resVType.getElementType() != baseType.getElementType())
@@ -4864,8 +4869,10 @@ static LogicalResult verifyVectorShapeCast(Operation *op,
}
LogicalResult ShapeCastOp::verify() {
- auto sourceVectorType = getSource().getType().dyn_cast_or_null<VectorType>();
- auto resultVectorType = getResult().getType().dyn_cast_or_null<VectorType>();
+ auto sourceVectorType =
+ llvm::dyn_cast_or_null<VectorType>(getSource().getType());
+ auto resultVectorType =
+ llvm::dyn_cast_or_null<VectorType>(getResult().getType());
// Check if source/result are of vector type.
if (sourceVectorType && resultVectorType)
@@ -4885,8 +4892,8 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
return otherOp.getSource();
// Only allows valid transitive folding.
- VectorType srcType = otherOp.getSource().getType().cast<VectorType>();
- VectorType resultType = getResult().getType().cast<VectorType>();
+ VectorType srcType = llvm::cast<VectorType>(otherOp.getSource().getType());
+ VectorType resultType = llvm::cast<VectorType>(getResult().getType());
if (srcType.getRank() < resultType.getRank()) {
if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
return {};
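`ShapeCastOp::verify` above switches to `llvm::dyn_cast_or_null`, which accepts a null handle and returns null for it. `llvm::dyn_cast` instead asserts that its argument is present. A minimal model of the difference, using hypothetical toy classes (`Type`, `VectorType`) rather than MLIR's:

```cpp
#include <cassert>

// Hypothetical handle: a null kind pointer models a null mlir::Type.
enum class Kind { Integer, Vector };

class Type {
public:
  Type() = default;
  explicit Type(const Kind *kind) : kind(kind) {}
  explicit operator bool() const { return kind != nullptr; }
  const Kind *getImpl() const { return kind; }

private:
  const Kind *kind = nullptr;
};

class VectorType : public Type {
public:
  using Type::Type;
  static bool classof(Type t) { return *t.getImpl() == Kind::Vector; }
};

// Like llvm::dyn_cast: the input must be non-null.
template <typename To> To dyn_cast(Type v) {
  assert(v && "dyn_cast on a null value");
  return To::classof(v) ? To(v.getImpl()) : To();
}

// Like llvm::dyn_cast_or_null: a null input simply yields a null result.
template <typename To> To dyn_cast_or_null(Type v) {
  return v ? dyn_cast<To>(v) : To();
}
```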
@@ -4923,11 +4930,11 @@ class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
if (!constantOp)
return failure();
// Only handle splat for now.
- auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>();
+ auto dense = llvm::dyn_cast<SplatElementsAttr>(constantOp.getValue());
if (!dense)
return failure();
auto newAttr =
- DenseElementsAttr::get(shapeCastOp.getType().cast<VectorType>(),
+ DenseElementsAttr::get(llvm::cast<VectorType>(shapeCastOp.getType()),
dense.getSplatValue<Attribute>());
rewriter.replaceOpWithNewOp<arith::ConstantOp>(shapeCastOp, newAttr);
return success();
@@ -4950,7 +4957,7 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
return failure();
auto broadcastSourceVectorType =
- broadcastOp.getSourceType().dyn_cast<VectorType>();
+ llvm::dyn_cast<VectorType>(broadcastOp.getSourceType());
auto broadcastSourceShape = broadcastSourceVectorType
? broadcastSourceVectorType.getShape()
: ArrayRef<int64_t>{};
@@ -5029,7 +5036,7 @@ OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
Type srcElemType = getSourceVectorType().getElementType();
Type dstElemType = getResultVectorType().getElementType();
- if (auto floatPack = sourceConstant.dyn_cast<DenseFPElementsAttr>()) {
+ if (auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) {
if (floatPack.isSplat()) {
auto splat = floatPack.getSplatValue<FloatAttr>();
@@ -5046,11 +5053,11 @@ OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
}
}
- if (auto intPack = sourceConstant.dyn_cast<DenseIntElementsAttr>()) {
+ if (auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) {
if (intPack.isSplat()) {
auto splat = intPack.getSplatValue<IntegerAttr>();
- if (dstElemType.isa<IntegerType>()) {
+ if (llvm::isa<IntegerType>(dstElemType)) {
uint64_t srcBitWidth = srcElemType.getIntOrFloatBitWidth();
uint64_t dstBitWidth = dstElemType.getIntOrFloatBitWidth();
@@ -5075,7 +5082,7 @@ OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===//
static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
- auto vectorType = memRefType.getElementType().dyn_cast<VectorType>();
+ auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType());
SmallVector<int64_t, 8> res(memRefType.getShape().begin(),
memRefType.getShape().end());
if (vectorType)
@@ -5088,7 +5095,7 @@ static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
void TypeCastOp::build(OpBuilder &builder, OperationState &result,
Value source) {
result.addOperands(source);
- MemRefType memRefType = source.getType().cast<MemRefType>();
+ MemRefType memRefType = llvm::cast<MemRefType>(source.getType());
VectorType vectorType =
VectorType::get(extractShape(memRefType),
getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
@@ -5126,7 +5133,7 @@ LogicalResult TypeCastOp::verify() {
void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
Value vector, ArrayRef<int64_t> transp) {
- VectorType vt = vector.getType().cast<VectorType>();
+ VectorType vt = llvm::cast<VectorType>(vector.getType());
SmallVector<int64_t, 4> transposedShape(vt.getRank());
for (unsigned i = 0; i < transp.size(); ++i)
transposedShape[i] = vt.getShape()[transp[i]];
@@ -5170,7 +5177,7 @@ LogicalResult vector::TransposeOp::verify() {
return emitOpError("transposition length mismatch: ") << size;
SmallVector<bool, 8> seen(rank, false);
for (const auto &ta : llvm::enumerate(transpAttr)) {
- int64_t i = ta.value().cast<IntegerAttr>().getInt();
+ int64_t i = llvm::cast<IntegerAttr>(ta.value()).getInt();
if (i < 0 || i >= rank)
return emitOpError("transposition index out of range: ") << i;
if (seen[i])
@@ -5239,7 +5246,7 @@ struct FoldTransposedScalarBroadcast final
if (!bcastOp)
return failure();
- auto srcVectorType = bcastOp.getSourceType().dyn_cast<VectorType>();
+ auto srcVectorType = llvm::dyn_cast<VectorType>(bcastOp.getSourceType());
if (!srcVectorType || srcVectorType.getNumElements() == 1) {
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
transposeOp, transposeOp.getResultVectorType(), bcastOp.getSource());
@@ -5324,12 +5331,12 @@ void vector::TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {
//===----------------------------------------------------------------------===//
LogicalResult ConstantMaskOp::verify() {
- auto resultType = getResult().getType().cast<VectorType>();
+ auto resultType = llvm::cast<VectorType>(getResult().getType());
// Check the corner case of 0-D vectors first.
if (resultType.getRank() == 0) {
if (getMaskDimSizes().size() != 1)
return emitError("array attr must have length 1 for 0-D vectors");
- auto dim = getMaskDimSizes()[0].cast<IntegerAttr>().getInt();
+ auto dim = llvm::cast<IntegerAttr>(getMaskDimSizes()[0]).getInt();
if (dim != 0 && dim != 1)
return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
return success();
@@ -5344,7 +5351,7 @@ LogicalResult ConstantMaskOp::verify() {
auto resultShape = resultType.getShape();
SmallVector<int64_t, 4> maskDimSizes;
for (const auto &it : llvm::enumerate(getMaskDimSizes())) {
- int64_t attrValue = it.value().cast<IntegerAttr>().getInt();
+ int64_t attrValue = llvm::cast<IntegerAttr>(it.value()).getInt();
if (attrValue < 0 || attrValue > resultShape[it.index()])
return emitOpError(
"array attr of size out of bounds of vector result dimension size");
@@ -5363,7 +5370,7 @@ LogicalResult ConstantMaskOp::verify() {
// `vector.constant_mask`. In the future, a convention could be established
// to decide if a specific dimension value could be considered as "all set".
if (resultType.isScalable() &&
- getMaskDimSizes()[0].cast<IntegerAttr>().getInt() != 0)
+ llvm::cast<IntegerAttr>(getMaskDimSizes()[0]).getInt() != 0)
return emitOpError("expected mask dim sizes for scalable masks to be 0");
return success();
}
@@ -5381,14 +5388,14 @@ void CreateMaskOp::build(OpBuilder &builder, OperationState &result,
}
LogicalResult CreateMaskOp::verify() {
- auto vectorType = getResult().getType().cast<VectorType>();
+ auto vectorType = llvm::cast<VectorType>(getResult().getType());
// Verify that an operand was specified for each result vector each dimension.
if (vectorType.getRank() == 0) {
if (getNumOperands() != 1)
return emitOpError(
"must specify exactly one operand for 0-D create_mask");
} else if (getNumOperands() !=
- getResult().getType().cast<VectorType>().getRank()) {
+ llvm::cast<VectorType>(getResult().getType()).getRank()) {
return emitOpError(
"must specify an operand for each result vector dimension");
}
@@ -5413,7 +5420,7 @@ class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
// CreateMaskOp for scalable vectors can be folded only if all dimensions
// are negative or zero.
- if (auto vType = createMaskOp.getType().dyn_cast<VectorType>()) {
+ if (auto vType = llvm::dyn_cast<VectorType>(createMaskOp.getType())) {
if (vType.isScalable())
for (auto opDim : createMaskOp.getOperands()) {
APInt intVal;
@@ -5615,7 +5622,7 @@ LogicalResult MaskOp::verify() {
"expects result type to match maskable operation result type");
if (llvm::count_if(maskableOp->getResultTypes(),
- [](Type t) { return t.isa<VectorType>(); }) > 1)
+ [](Type t) { return llvm::isa<VectorType>(t); }) > 1)
return emitOpError("multiple vector results not supported");
// Mask checks.
@@ -5759,7 +5766,7 @@ void WarpExecuteOnLane0Op::print(OpAsmPrinter &p) {
SmallVector<StringRef> coreAttr = {getWarpSizeAttrName()};
auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName());
- p << "[" << warpSizeAttr.cast<IntegerAttr>().getInt() << "]";
+ p << "[" << llvm::cast<IntegerAttr>(warpSizeAttr).getInt() << "]";
if (!getArgs().empty())
p << " args(" << getArgs() << " : " << getArgs().getTypes() << ")";
@@ -5872,8 +5879,8 @@ static LogicalResult verifyDistributedType(Type expanded, Type distributed,
// If the types matches there is no distribution.
if (expanded == distributed)
return success();
- auto expandedVecType = expanded.dyn_cast<VectorType>();
- auto distributedVecType = distributed.dyn_cast<VectorType>();
+ auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
+ auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
if (!expandedVecType || !distributedVecType)
return op->emitOpError("expected vector type for distributed operands.");
if (expandedVecType.getRank() != distributedVecType.getRank() ||
@@ -5940,7 +5947,7 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
case CombiningKind::ADD:
if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
result = b.createOrFold<arith::AddIOp>(loc, v1, acc);
- else if (t1.isa<FloatType>() && tAcc.isa<FloatType>())
+ else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
result = b.createOrFold<arith::AddFOp>(loc, v1, acc);
else
llvm_unreachable("invalid value types for ADD reduction");
@@ -5950,12 +5957,12 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
result = b.createOrFold<arith::AndIOp>(loc, v1, acc);
break;
case CombiningKind::MAXF:
- assert(t1.isa<FloatType>() && tAcc.isa<FloatType>() &&
+ assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
"expected float values");
result = b.createOrFold<arith::MaxFOp>(loc, v1, acc);
break;
case CombiningKind::MINF:
- assert(t1.isa<FloatType>() && tAcc.isa<FloatType>() &&
+ assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
"expected float values");
result = b.createOrFold<arith::MinFOp>(loc, v1, acc);
break;
@@ -5978,7 +5985,7 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
case CombiningKind::MUL:
if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
result = b.createOrFold<arith::MulIOp>(loc, v1, acc);
- else if (t1.isa<FloatType>() && tAcc.isa<FloatType>())
+ else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
result = b.createOrFold<arith::MulFOp>(loc, v1, acc);
else
llvm_unreachable("invalid value types for MUL reduction");
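The VectorOps.cpp hunks above replace member calls such as `x.dyn_cast<T>()` with the free functions `llvm::dyn_cast<T>(x)`, `llvm::cast<T>(x)` and `llvm::isa<T>(x)`, which ultimately consult a static `classof` predicate on the target type. The sketch below is a simplified, self-contained model of that LLVM-style RTTI scheme, assuming a plain pointer-based hierarchy; the `mini` namespace and its `Type`, `IntegerType` and `FloatType` classes are hypothetical stand-ins, not MLIR's classes. The real functions also support value-semantic handles such as `mlir::Type` through `CastInfo` specializations.

```cpp
#include <cassert>

namespace mini {

// Hypothetical stand-in for a type hierarchy (not MLIR's mlir::Type).
class Type {
public:
  enum class Kind { Integer, Float };
  explicit Type(Kind k) : kind(k) {}
  Kind getKind() const { return kind; }

private:
  Kind kind;
};

class IntegerType : public Type {
public:
  explicit IntegerType(unsigned w) : Type(Kind::Integer), width(w) {}
  // The hook the free functions consult: "is this Type an IntegerType?"
  static bool classof(const Type *t) { return t->getKind() == Kind::Integer; }
  unsigned getWidth() const { return width; }

private:
  unsigned width;
};

class FloatType : public Type {
public:
  FloatType() : Type(Kind::Float) {}
  static bool classof(const Type *t) { return t->getKind() == Kind::Float; }
};

// isa<T>(x): a pure query; the argument must not be null.
template <typename To, typename From>
bool isa(const From *val) {
  assert(val && "isa<> used on a null pointer");
  return To::classof(val);
}

// cast<T>(x): checked conversion; asserts if x is not a T.
template <typename To, typename From>
To *cast(From *val) {
  assert(isa<To>(val) && "cast<Ty>() argument of incompatible type");
  return static_cast<To *>(val);
}

// dyn_cast<T>(x): returns null instead of asserting when x is not a T.
template <typename To, typename From>
To *dyn_cast(From *val) {
  return isa<To>(val) ? static_cast<To *>(val) : nullptr;
}

} // namespace mini
```

As with the real functions, `cast` asserts on a type mismatch while `dyn_cast` returns null, which is why the patched code pairs `llvm::dyn_cast` with an `if (!x)` check (as in `ShapeCastConstantFolder`) and reserves `llvm::cast` for types already guaranteed by the op definition.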
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 1744c46db5886..5d2abf7c03680 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -231,7 +231,7 @@ class ReduceMultiDimReductionRank
Value vectorMask = maskableOp.getMaskingOp().getMask();
auto maskCastedType = VectorType::get(
vectorShape,
- vectorMask.getType().cast<VectorType>().getElementType());
+ llvm::cast<VectorType>(vectorMask.getType()).getElementType());
newVectorMask =
rewriter.create<vector::ShapeCastOp>(loc, maskCastedType, vectorMask);
}
@@ -413,7 +413,7 @@ struct OneDimMultiReductionToTwoDim
srcVectorType.getElementType());
auto accType =
VectorType::get(ArrayRef<int64_t>{1}, srcVectorType.getElementType());
- assert(!multiReductionOp.getDestType().isa<VectorType>() &&
+ assert(!llvm::isa<VectorType>(multiReductionOp.getDestType()) &&
"multi_reduction with a single dimension expects a scalar result");
// If the unique dim is reduced and we insert a parallel in front, we need a
@@ -427,7 +427,7 @@ struct OneDimMultiReductionToTwoDim
loc, accType, multiReductionOp.getAcc());
Value castMask;
if (maskableOp.isMasked()) {
- auto maskType = mask.getType().cast<ShapedType>();
+ auto maskType = llvm::cast<ShapedType>(mask.getType());
auto castMaskType =
VectorType::get(ArrayRef<int64_t>{1, maskType.getShape().back()},
maskType.getElementType());
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index 6c9034d446341..4a67010b7a3a0 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -66,14 +66,14 @@ class AffineExprConstantFolder {
case AffineExprKind::Constant:
return expr.cast<AffineConstantExpr>().getValue();
case AffineExprKind::DimId:
- if (auto attr = operandConsts[expr.cast<AffineDimExpr>().getPosition()]
- .dyn_cast_or_null<IntegerAttr>())
+ if (auto attr = llvm::dyn_cast_or_null<IntegerAttr>(
+ operandConsts[expr.cast<AffineDimExpr>().getPosition()]))
return attr.getInt();
return std::nullopt;
case AffineExprKind::SymbolId:
- if (auto attr = operandConsts[numDims +
- expr.cast<AffineSymbolExpr>().getPosition()]
- .dyn_cast_or_null<IntegerAttr>())
+ if (auto attr = llvm::dyn_cast_or_null<IntegerAttr>(
+ operandConsts[numDims +
+ expr.cast<AffineSymbolExpr>().getPosition()]))
return attr.getInt();
return std::nullopt;
}
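In the AffineMap.cpp hunk, entries of `operandConsts` may be null (operands that are not constants), which is presumably why the code uses `llvm::dyn_cast_or_null` rather than `llvm::dyn_cast`; in LLVM, `dyn_cast` asserts that its argument is present. A minimal pointer-based sketch of the difference follows; `mini::Attr`, `mini::IntegerAttr`, `mini::StringAttr` and the `foldDim` helper (loosely modeled on the `DimId` case) are hypothetical and not MLIR code.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

namespace mini {

// Hypothetical stand-in attribute hierarchy (not MLIR's mlir::Attribute).
class Attr {
public:
  enum class Kind { Integer, String };
  explicit Attr(Kind k) : kind(k) {}
  Kind getKind() const { return kind; }

private:
  Kind kind;
};

class IntegerAttr : public Attr {
public:
  explicit IntegerAttr(std::int64_t v) : Attr(Kind::Integer), value(v) {}
  static bool classof(const Attr *a) { return a->getKind() == Kind::Integer; }
  std::int64_t getInt() const { return value; }

private:
  std::int64_t value;
};

class StringAttr : public Attr {
public:
  StringAttr() : Attr(Kind::String) {}
  static bool classof(const Attr *a) { return a->getKind() == Kind::String; }
};

// Like dyn_cast, but a null input is accepted and simply yields null.
template <typename To, typename From>
To *dyn_cast_or_null(From *val) {
  if (!val)
    return nullptr;
  return To::classof(val) ? static_cast<To *>(val) : nullptr;
}

// Hypothetical helper: a null entry means "this operand is not a constant".
inline std::optional<std::int64_t>
foldDim(const std::vector<Attr *> &operandConsts, unsigned pos) {
  if (auto *attr = dyn_cast_or_null<IntegerAttr>(operandConsts[pos]))
    return attr->getInt();
  return std::nullopt;
}

} // namespace mini
```

A null entry and a non-integer entry both fold to `std::nullopt`, mirroring the `return std::nullopt;` paths in the hunk.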
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 3afafd6d4cdf0..206a097a02802 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -91,7 +91,7 @@ void OpAsmPrinter::printFunctionalType(Operation *op) {
// it is a function (avoiding a grammar ambiguity).
bool wrapped = op->getNumResults() != 1;
if (!wrapped && op->getResult(0).getType() &&
- op->getResult(0).getType().isa<FunctionType>())
+ llvm::isa<FunctionType>(op->getResult(0).getType()))
wrapped = true;
if (wrapped)
@@ -254,7 +254,7 @@ OpPrintingFlags &OpPrintingFlags::printValueUsers() {
bool OpPrintingFlags::shouldElideElementsAttr(ElementsAttr attr) const {
return elementsAttrElementLimit &&
*elementsAttrElementLimit < int64_t(attr.getNumElements()) &&
- !attr.isa<SplatElementsAttr>();
+ !llvm::isa<SplatElementsAttr>(attr);
}
/// Return the size limit for printing large ElementsAttr.
@@ -803,8 +803,8 @@ class DummyAliasDialectAsmPrinter : public DialectAsmPrinter {
attr.getDialect().printAttribute(attr, *this);
// Process the builtin attributes.
- } else if (attr.isa<AffineMapAttr, DenseArrayAttr, FloatAttr, IntegerAttr,
- IntegerSetAttr, UnitAttr>()) {
+ } else if (llvm::isa<AffineMapAttr, DenseArrayAttr, FloatAttr, IntegerAttr,
+ IntegerSetAttr, UnitAttr>(attr)) {
return;
} else if (auto dictAttr = dyn_cast<DictionaryAttr>(attr)) {
for (const NamedAttribute &nestedAttr : dictAttr.getValue()) {
@@ -833,9 +833,9 @@ class DummyAliasDialectAsmPrinter : public DialectAsmPrinter {
// Don't print the type if we must elide it, or if it is a None type.
if (!elideType) {
- if (auto typedAttr = attr.dyn_cast<TypedAttr>()) {
+ if (auto typedAttr = llvm::dyn_cast<TypedAttr>(attr)) {
Type attrType = typedAttr.getType();
- if (!attrType.isa<NoneType>())
+ if (!llvm::isa<NoneType>(attrType))
printType(attrType);
}
}
@@ -845,10 +845,10 @@ class DummyAliasDialectAsmPrinter : public DialectAsmPrinter {
return type.getDialect().printType(type, *this);
// Only visit the layout of memref if it isn't the identity.
- if (auto memrefTy = type.dyn_cast<MemRefType>()) {
+ if (auto memrefTy = llvm::dyn_cast<MemRefType>(type)) {
printType(memrefTy.getElementType());
MemRefLayoutAttrInterface layout = memrefTy.getLayout();
- if (!layout.isa<AffineMapAttr>() || !layout.isIdentity())
+ if (!llvm::isa<AffineMapAttr>(layout) || !layout.isIdentity())
printAttribute(memrefTy.getLayout());
if (memrefTy.getMemorySpace())
printAttribute(memrefTy.getMemorySpace());
@@ -1418,7 +1418,7 @@ void SSANameState::shadowRegionArgs(Region &region, ValueRange namesToUse) {
void SSANameState::numberValuesInRegion(Region &region) {
auto setBlockArgNameFn = [&](Value arg, StringRef name) {
assert(!valueIDs.count(arg) && "arg numbered multiple times");
- assert(arg.cast<BlockArgument>().getOwner()->getParent() == &region &&
+ assert(llvm::cast<BlockArgument>(arg).getOwner()->getParent() == &region &&
"arg not defined in current region");
setValueName(arg, name);
};
@@ -1479,7 +1479,7 @@ void SSANameState::numberValuesInOp(Operation &op) {
setValueName(result, name);
// Record the result number for groups not anchored at 0.
- if (int resultNo = result.cast<OpResult>().getResultNumber())
+ if (int resultNo = llvm::cast<OpResult>(result).getResultNumber())
resultGroups.push_back(resultNo);
};
// Operations can customize the printing of block names in OpAsmOpInterface.
@@ -1878,7 +1878,7 @@ void AsmPrinter::Impl::printLocationInternal(LocationAttr loc, bool pretty,
// Print the child if it isn't unknown.
auto childLoc = loc.getChildLoc();
- if (!childLoc.isa<UnknownLoc>()) {
+ if (!llvm::isa<UnknownLoc>(childLoc)) {
os << '(';
printLocationInternal(childLoc, pretty);
os << ')';
@@ -1891,8 +1891,8 @@ void AsmPrinter::Impl::printLocationInternal(LocationAttr loc, bool pretty,
os << "callsite(";
printLocationInternal(callee, pretty);
if (pretty) {
- if (callee.isa<NameLoc>()) {
- if (caller.isa<FileLineColLoc>()) {
+ if (llvm::isa<NameLoc>(callee)) {
+ if (llvm::isa<FileLineColLoc>(caller)) {
os << " at ";
} else {
os << newLine << " at ";
@@ -2100,19 +2100,19 @@ void AsmPrinter::Impl::printAttributeImpl(Attribute attr,
AttrTypeElision typeElision) {
if (!isa<BuiltinDialect>(attr.getDialect())) {
printDialectAttribute(attr);
- } else if (auto opaqueAttr = attr.dyn_cast<OpaqueAttr>()) {
+ } else if (auto opaqueAttr = llvm::dyn_cast<OpaqueAttr>(attr)) {
printDialectSymbol(os, "#", opaqueAttr.getDialectNamespace(),
opaqueAttr.getAttrData());
- } else if (attr.isa<UnitAttr>()) {
+ } else if (llvm::isa<UnitAttr>(attr)) {
os << "unit";
return;
- } else if (auto dictAttr = attr.dyn_cast<DictionaryAttr>()) {
+ } else if (auto dictAttr = llvm::dyn_cast<DictionaryAttr>(attr)) {
os << '{';
interleaveComma(dictAttr.getValue(),
[&](NamedAttribute attr) { printNamedAttribute(attr); });
os << '}';
- } else if (auto intAttr = attr.dyn_cast<IntegerAttr>()) {
+ } else if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr)) {
Type intType = intAttr.getType();
if (intType.isSignlessInteger(1)) {
os << (intAttr.getValue().getBoolValue() ? "true" : "false");
@@ -2132,24 +2132,24 @@ void AsmPrinter::Impl::printAttributeImpl(Attribute attr,
if (typeElision == AttrTypeElision::May && intType.isSignlessInteger(64))
return;
- } else if (auto floatAttr = attr.dyn_cast<FloatAttr>()) {
+ } else if (auto floatAttr = llvm::dyn_cast<FloatAttr>(attr)) {
printFloatValue(floatAttr.getValue(), os);
// FloatAttr elides the type if F64.
if (typeElision == AttrTypeElision::May && floatAttr.getType().isF64())
return;
- } else if (auto strAttr = attr.dyn_cast<StringAttr>()) {
+ } else if (auto strAttr = llvm::dyn_cast<StringAttr>(attr)) {
printEscapedString(strAttr.getValue());
- } else if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
+ } else if (auto arrayAttr = llvm::dyn_cast<ArrayAttr>(attr)) {
os << '[';
interleaveComma(arrayAttr.getValue(), [&](Attribute attr) {
printAttribute(attr, AttrTypeElision::May);
});
os << ']';
- } else if (auto affineMapAttr = attr.dyn_cast<AffineMapAttr>()) {
+ } else if (auto affineMapAttr = llvm::dyn_cast<AffineMapAttr>(attr)) {
os << "affine_map<";
affineMapAttr.getValue().print(os);
os << '>';
@@ -2157,7 +2157,7 @@ void AsmPrinter::Impl::printAttributeImpl(Attribute attr,
// AffineMap always elides the type.
return;
- } else if (auto integerSetAttr = attr.dyn_cast<IntegerSetAttr>()) {
+ } else if (auto integerSetAttr = llvm::dyn_cast<IntegerSetAttr>(attr)) {
os << "affine_set<";
integerSetAttr.getValue().print(os);
os << '>';
@@ -2165,17 +2165,18 @@ void AsmPrinter::Impl::printAttributeImpl(Attribute attr,
// IntegerSet always elides the type.
return;
- } else if (auto typeAttr = attr.dyn_cast<TypeAttr>()) {
+ } else if (auto typeAttr = llvm::dyn_cast<TypeAttr>(attr)) {
printType(typeAttr.getValue());
- } else if (auto refAttr = attr.dyn_cast<SymbolRefAttr>()) {
+ } else if (auto refAttr = llvm::dyn_cast<SymbolRefAttr>(attr)) {
printSymbolReference(refAttr.getRootReference().getValue(), os);
for (FlatSymbolRefAttr nestedRef : refAttr.getNestedReferences()) {
os << "::";
printSymbolReference(nestedRef.getValue(), os);
}
- } else if (auto intOrFpEltAttr = attr.dyn_cast<DenseIntOrFPElementsAttr>()) {
+ } else if (auto intOrFpEltAttr =
+ llvm::dyn_cast<DenseIntOrFPElementsAttr>(attr)) {
if (printerFlags.shouldElideElementsAttr(intOrFpEltAttr)) {
printElidedElementsAttr(os);
} else {
@@ -2184,7 +2185,7 @@ void AsmPrinter::Impl::printAttributeImpl(Attribute attr,
os << '>';
}
- } else if (auto strEltAttr = attr.dyn_cast<DenseStringElementsAttr>()) {
+ } else if (auto strEltAttr = llvm::dyn_cast<DenseStringElementsAttr>(attr)) {
if (printerFlags.shouldElideElementsAttr(strEltAttr)) {
printElidedElementsAttr(os);
} else {
@@ -2193,7 +2194,7 @@ void AsmPrinter::Impl::printAttributeImpl(Attribute attr,
os << '>';
}
- } else if (auto sparseEltAttr = attr.dyn_cast<SparseElementsAttr>()) {
+ } else if (auto sparseEltAttr = llvm::dyn_cast<SparseElementsAttr>(attr)) {
if (printerFlags.shouldElideElementsAttr(sparseEltAttr.getIndices()) ||
printerFlags.shouldElideElementsAttr(sparseEltAttr.getValues())) {
printElidedElementsAttr(os);
@@ -2207,9 +2208,9 @@ void AsmPrinter::Impl::printAttributeImpl(Attribute attr,
}
os << '>';
}
- } else if (auto stridedLayoutAttr = attr.dyn_cast<StridedLayoutAttr>()) {
+ } else if (auto stridedLayoutAttr = llvm::dyn_cast<StridedLayoutAttr>(attr)) {
stridedLayoutAttr.print(os);
- } else if (auto denseArrayAttr = attr.dyn_cast<DenseArrayAttr>()) {
+ } else if (auto denseArrayAttr = llvm::dyn_cast<DenseArrayAttr>(attr)) {
os << "array<";
printType(denseArrayAttr.getElementType());
if (!denseArrayAttr.empty()) {
@@ -2218,20 +2219,21 @@ void AsmPrinter::Impl::printAttributeImpl(Attribute attr,
}
os << ">";
return;
- } else if (auto resourceAttr = attr.dyn_cast<DenseResourceElementsAttr>()) {
+ } else if (auto resourceAttr =
+ llvm::dyn_cast<DenseResourceElementsAttr>(attr)) {
os << "dense_resource<";
printResourceHandle(resourceAttr.getRawHandle());
os << ">";
- } else if (auto locAttr = attr.dyn_cast<LocationAttr>()) {
+ } else if (auto locAttr = llvm::dyn_cast<LocationAttr>(attr)) {
printLocation(locAttr);
} else {
llvm::report_fatal_error("Unknown builtin attribute");
}
// Don't print the type if we must elide it, or if it is a None type.
if (typeElision != AttrTypeElision::Must) {
- if (auto typedAttr = attr.dyn_cast<TypedAttr>()) {
+ if (auto typedAttr = llvm::dyn_cast<TypedAttr>(attr)) {
Type attrType = typedAttr.getType();
- if (!attrType.isa<NoneType>()) {
+ if (!llvm::isa<NoneType>(attrType)) {
os << " : ";
printType(attrType);
}
@@ -2300,10 +2302,10 @@ printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os,
void AsmPrinter::Impl::printDenseElementsAttr(DenseElementsAttr attr,
bool allowHex) {
- if (auto stringAttr = attr.dyn_cast<DenseStringElementsAttr>())
+ if (auto stringAttr = llvm::dyn_cast<DenseStringElementsAttr>(attr))
return printDenseStringElementsAttr(stringAttr);
- printDenseIntOrFPElementsAttr(attr.cast<DenseIntOrFPElementsAttr>(),
+ printDenseIntOrFPElementsAttr(llvm::cast<DenseIntOrFPElementsAttr>(attr),
allowHex);
}
@@ -2333,12 +2335,12 @@ void AsmPrinter::Impl::printDenseIntOrFPElementsAttr(
return;
}
- if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) {
+ if (ComplexType complexTy = llvm::dyn_cast<ComplexType>(elementType)) {
Type complexElementType = complexTy.getElementType();
// Note: The if and else below had a common lambda function which invoked
// printDenseElementsAttrImpl. This lambda was hitting a bug in gcc 9.1,9.2
// and hence was replaced.
- if (complexElementType.isa<IntegerType>()) {
+ if (llvm::isa<IntegerType>(complexElementType)) {
auto valueIt = attr.value_begin<std::complex<APInt>>();
printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
auto complexValue = *(valueIt + index);
@@ -2365,7 +2367,7 @@ void AsmPrinter::Impl::printDenseIntOrFPElementsAttr(
printDenseIntElement(*(valueIt + index), os, elementType);
});
} else {
- assert(elementType.isa<FloatType>() && "unexpected element type");
+ assert(llvm::isa<FloatType>(elementType) && "unexpected element type");
auto valueIt = attr.value_begin<APFloat>();
printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
printFloatValue(*(valueIt + index), os);
@@ -2397,7 +2399,7 @@ void AsmPrinter::Impl::printDenseArrayAttr(DenseArrayAttr attr) {
if (type.isIntOrIndex()) {
printDenseIntElement(value, getStream(), type);
} else {
- APFloat fltVal(type.cast<FloatType>().getFloatSemantics(), value);
+ APFloat fltVal(llvm::cast<FloatType>(type).getFloatSemantics(), value);
printFloatValue(fltVal, getStream());
}
};
@@ -2447,7 +2449,7 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
interleaveComma(funcTy.getInputs(), [&](Type ty) { printType(ty); });
os << ") -> ";
ArrayRef<Type> results = funcTy.getResults();
- if (results.size() == 1 && !results[0].isa<FunctionType>()) {
+ if (results.size() == 1 && !llvm::isa<FunctionType>(results[0])) {
printType(results[0]);
} else {
os << '(';
@@ -2506,7 +2508,7 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
}
printType(memrefTy.getElementType());
MemRefLayoutAttrInterface layout = memrefTy.getLayout();
- if (!layout.isa<AffineMapAttr>() || !layout.isIdentity()) {
+ if (!llvm::isa<AffineMapAttr>(layout) || !layout.isIdentity()) {
os << ", ";
printAttribute(memrefTy.getLayout(), AttrTypeElision::May);
}
@@ -2580,7 +2582,7 @@ void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) {
::printKeywordOrString(attr.getName().strref(), os);
// Pretty printing elides the attribute value for unit attributes.
- if (attr.getValue().isa<UnitAttr>())
+ if (llvm::isa<UnitAttr>(attr.getValue()))
return;
os << " = ";
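Several AsmPrinter.cpp hunks (for example the builtin-attribute check in `DummyAliasDialectAsmPrinter`) pass more than one type to `llvm::isa`; the call is true when the value is an instance of any of the listed types. A simplified sketch of that variadic form, assuming hypothetical stand-in attribute classes in a `mini` namespace (not MLIR's):

```cpp
#include <cassert>

namespace mini {

// Hypothetical stand-in attribute hierarchy (not MLIR's classes).
class Attr {
public:
  enum class Kind { AffineMap, Float, Integer, Dictionary };
  explicit Attr(Kind k) : kind(k) {}
  Kind getKind() const { return kind; }

private:
  Kind kind;
};

// One concrete attribute class per kind, each with its own classof hook.
template <Attr::Kind K>
class KindAttr : public Attr {
public:
  KindAttr() : Attr(K) {}
  static bool classof(const Attr *a) { return a->getKind() == K; }
};

using AffineMapAttr = KindAttr<Attr::Kind::AffineMap>;
using FloatAttr = KindAttr<Attr::Kind::Float>;
using IntegerAttr = KindAttr<Attr::Kind::Integer>;
using DictionaryAttr = KindAttr<Attr::Kind::Dictionary>;

// isa<T1, T2, ...>(x) is true if x is an instance of any listed type.
// With a single type, the fold over the empty pack contributes `false`.
template <typename First, typename... Rest, typename From>
bool isa(const From *val) {
  assert(val && "isa<> used on a null pointer");
  return First::classof(val) || (Rest::classof(val) || ...);
}

} // namespace mini
```

When a multi-type `isa` call is placed inside a macro such as `assert`, wrap it in an extra pair of parentheses; otherwise the preprocessor treats the commas in the template argument list as macro-argument separators.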
diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h
index bc336dcc121e9..5bf7caa440277 100644
--- a/mlir/lib/IR/AttributeDetail.h
+++ b/mlir/lib/IR/AttributeDetail.h
@@ -33,7 +33,7 @@ namespace detail {
/// Return the bit width which DenseElementsAttr should use for this type.
inline size_t getDenseElementBitWidth(Type eltType) {
// Align the width for complex to 8 to make storage and interpretation easier.
- if (ComplexType comp = eltType.dyn_cast<ComplexType>())
+ if (ComplexType comp = llvm::dyn_cast<ComplexType>(eltType))
return llvm::alignTo<8>(getDenseElementBitWidth(comp.getElementType())) * 2;
if (eltType.isIndex())
return IndexType::kInternalStorageBitWidth;
diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp
index 2798944c2df36..c572f12091647 100644
--- a/mlir/lib/IR/Attributes.cpp
+++ b/mlir/lib/IR/Attributes.cpp
@@ -46,7 +46,9 @@ NamedAttribute::NamedAttribute(StringAttr name, Attribute value)
assert(name.size() != 0 && "expected valid attribute name");
}
-StringAttr NamedAttribute::getName() const { return name.cast<StringAttr>(); }
+StringAttr NamedAttribute::getName() const {
+ return llvm::cast<StringAttr>(name);
+}
Dialect *NamedAttribute::getNameDialect() const {
return getName().getReferencedDialect();
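The commit message notes that some call sites could not use an unqualified `cast`, because inside a class definition the name resolves to the member method. The C++ rule behind that is that unqualified lookup searches class scope before namespace scope and stops at the first match, and argument-dependent lookup is skipped once a class member is found. A small self-contained illustration, using hypothetical `util::cast` and `Handle` names (not MLIR code):

```cpp
#include <cassert>

namespace util {
// Hypothetical free-function cast, loosely in the spirit of llvm::cast.
template <typename To, typename From>
To cast(From v) {
  return static_cast<To>(v);
}
} // namespace util

using util::cast;

struct Handle {
  // A member template with the same name, like the method form x.cast<T>().
  template <typename To>
  To cast() const {
    return static_cast<To>(raw);
  }

  int truncated(double d) const {
    // An unqualified `cast<int>(d)` here would find Handle::cast first and
    // fail to compile (the member takes no arguments), even with the
    // using-declaration above. Qualifying the name bypasses the member.
    return util::cast<int>(d);
  }

  long raw = 0;
};
```

Outside the class, the unqualified `cast<int>(x)` does reach `util::cast`; writing the qualified form everywhere, as this patch does with `llvm::cast`, keeps such call sites valid in both contexts.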
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 7943655aa1b89..6cbba068fc1a9 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -316,14 +316,15 @@ ArrayAttr Builder::getAffineMapArrayAttr(ArrayRef<AffineMap> values) {
}
TypedAttr Builder::getZeroAttr(Type type) {
- if (type.isa<FloatType>())
+ if (llvm::isa<FloatType>(type))
return getFloatAttr(type, 0.0);
- if (type.isa<IndexType>())
+ if (llvm::isa<IndexType>(type))
return getIndexAttr(0);
- if (auto integerType = type.dyn_cast<IntegerType>())
- return getIntegerAttr(type, APInt(type.cast<IntegerType>().getWidth(), 0));
- if (type.isa<RankedTensorType, VectorType>()) {
- auto vtType = type.cast<ShapedType>();
+ if (auto integerType = llvm::dyn_cast<IntegerType>(type))
+ return getIntegerAttr(type,
+ APInt(llvm::cast<IntegerType>(type).getWidth(), 0));
+ if (llvm::isa<RankedTensorType, VectorType>(type)) {
+ auto vtType = llvm::cast<ShapedType>(type);
auto element = getZeroAttr(vtType.getElementType());
if (!element)
return {};
diff --git a/mlir/lib/IR/BuiltinAttributeInterfaces.cpp b/mlir/lib/IR/BuiltinAttributeInterfaces.cpp
index ab216baede61c..9b5235a6c5ceb 100644
--- a/mlir/lib/IR/BuiltinAttributeInterfaces.cpp
+++ b/mlir/lib/IR/BuiltinAttributeInterfaces.cpp
@@ -53,7 +53,7 @@ bool ElementsAttr::isValidIndex(ElementsAttr elementsAttr,
}
uint64_t ElementsAttr::getFlattenedIndex(Type type, ArrayRef<uint64_t> index) {
- ShapedType shapeType = type.cast<ShapedType>();
+ ShapedType shapeType = llvm::cast<ShapedType>(type);
assert(isValidIndex(shapeType, index) &&
"expected valid multi-dimensional index");
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 536328d7cc761..de26c8e05802a 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -300,11 +300,12 @@ double FloatAttr::getValueAsDouble(APFloat value) {
LogicalResult FloatAttr::verify(function_ref<InFlightDiagnostic()> emitError,
Type type, APFloat value) {
// Verify that the type is correct.
- if (!type.isa<FloatType>())
+ if (!llvm::isa<FloatType>(type))
return emitError() << "expected floating point type";
// Verify that the type semantics match that of the value.
- if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) {
+ if (&llvm::cast<FloatType>(type).getFloatSemantics() !=
+ &value.getSemantics()) {
return emitError()
<< "FloatAttr type doesn't match the type implied by its value";
}
@@ -321,11 +322,11 @@ SymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value,
}
FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) {
- return get(ctx, value, {}).cast<FlatSymbolRefAttr>();
+ return llvm::cast<FlatSymbolRefAttr>(get(ctx, value, {}));
}
FlatSymbolRefAttr SymbolRefAttr::get(StringAttr value) {
- return get(value, {}).cast<FlatSymbolRefAttr>();
+ return llvm::cast<FlatSymbolRefAttr>(get(value, {}));
}
FlatSymbolRefAttr SymbolRefAttr::get(Operation *symbol) {
@@ -370,14 +371,14 @@ APSInt IntegerAttr::getAPSInt() const {
LogicalResult IntegerAttr::verify(function_ref<InFlightDiagnostic()> emitError,
Type type, APInt value) {
- if (IntegerType integerType = type.dyn_cast<IntegerType>()) {
+ if (IntegerType integerType = llvm::dyn_cast<IntegerType>(type)) {
if (integerType.getWidth() != value.getBitWidth())
return emitError() << "integer type bit width (" << integerType.getWidth()
<< ") doesn't match value bit width ("
<< value.getBitWidth() << ")";
return success();
}
- if (type.isa<IndexType>()) {
+ if (llvm::isa<IndexType>(type)) {
if (value.getBitWidth() != IndexType::kInternalStorageBitWidth)
return emitError()
<< "value bit width (" << value.getBitWidth()
@@ -390,7 +391,7 @@ LogicalResult IntegerAttr::verify(function_ref<InFlightDiagnostic()> emitError,
BoolAttr IntegerAttr::getBoolAttrUnchecked(IntegerType type, bool value) {
auto attr = Base::get(type.getContext(), type, APInt(/*numBits=*/1, value));
- return attr.cast<BoolAttr>();
+ return llvm::cast<BoolAttr>(attr);
}
//===----------------------------------------------------------------------===//
@@ -403,7 +404,7 @@ bool BoolAttr::getValue() const {
}
bool BoolAttr::classof(Attribute attr) {
- IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>();
+ IntegerAttr intAttr = llvm::dyn_cast<IntegerAttr>(attr);
return intAttr && intAttr.getType().isSignlessInteger(1);
}
@@ -600,21 +601,21 @@ DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
attr.getAsOpaquePointer(), index) {}
Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
- auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>();
+ auto owner = llvm::cast<DenseElementsAttr>(getFromOpaquePointer(base));
Type eltTy = owner.getElementType();
- if (auto intEltTy = eltTy.dyn_cast<IntegerType>())
+ if (auto intEltTy = llvm::dyn_cast<IntegerType>(eltTy))
return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
- if (eltTy.isa<IndexType>())
+ if (llvm::isa<IndexType>(eltTy))
return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
- if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) {
+ if (auto floatEltTy = llvm::dyn_cast<FloatType>(eltTy)) {
IntElementIterator intIt(owner, index);
FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt);
return FloatAttr::get(eltTy, *floatIt);
}
- if (auto complexTy = eltTy.dyn_cast<ComplexType>()) {
+ if (auto complexTy = llvm::dyn_cast<ComplexType>(eltTy)) {
auto complexEltTy = complexTy.getElementType();
ComplexIntElementIterator complexIntIt(owner, index);
- if (complexEltTy.isa<IntegerType>()) {
+ if (llvm::isa<IntegerType>(complexEltTy)) {
auto value = *complexIntIt;
auto real = IntegerAttr::get(complexEltTy, value.real());
auto imag = IntegerAttr::get(complexEltTy, value.imag());
@@ -623,14 +624,14 @@ Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
}
ComplexFloatElementIterator complexFloatIt(
- complexEltTy.cast<FloatType>().getFloatSemantics(), complexIntIt);
+ llvm::cast<FloatType>(complexEltTy).getFloatSemantics(), complexIntIt);
auto value = *complexFloatIt;
auto real = FloatAttr::get(complexEltTy, value.real());
auto imag = FloatAttr::get(complexEltTy, value.imag());
return ArrayAttr::get(complexTy.getContext(),
ArrayRef<Attribute>{real, imag});
}
- if (owner.isa<DenseStringElementsAttr>()) {
+ if (llvm::isa<DenseStringElementsAttr>(owner)) {
ArrayRef<StringRef> vals = owner.getRawStringData();
return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy);
}
@@ -673,7 +674,7 @@ DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator(
std::complex<APInt>, std::complex<APInt>,
std::complex<APInt>>(
attr.getRawData().data(), attr.isSplat(), dataIndex) {
- auto complexType = attr.getElementType().cast<ComplexType>();
+ auto complexType = llvm::cast<ComplexType>(attr.getElementType());
bitWidth = getDenseElementBitWidth(complexType.getElementType());
}
@@ -713,7 +714,7 @@ template <size_t width,
IntegerType::SignednessSemantics signedness = IntegerType::Signless>
struct DenseArrayAttrIntUtil {
static bool checkElementType(Type eltType) {
- auto type = eltType.dyn_cast<IntegerType>();
+ auto type = llvm::dyn_cast<IntegerType>(eltType);
if (!type || type.getWidth() != width)
return false;
return type.getSignedness() == signedness;
@@ -860,7 +861,7 @@ DenseArrayAttrImpl<T> DenseArrayAttrImpl<T>::get(MLIRContext *context,
template <typename T>
bool DenseArrayAttrImpl<T>::classof(Attribute attr) {
- if (auto denseArray = attr.dyn_cast<DenseArrayAttr>())
+ if (auto denseArray = llvm::dyn_cast<DenseArrayAttr>(attr))
return DenseArrayAttrUtil<T>::checkElementType(denseArray.getElementType());
return false;
}
@@ -884,7 +885,7 @@ template class DenseArrayAttrImpl<double>;
/// Method for support type inquiry through isa, cast and dyn_cast.
bool DenseElementsAttr::classof(Attribute attr) {
- return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>();
+ return llvm::isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>(attr);
}
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
@@ -894,20 +895,19 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
Type eltType = type.getElementType();
// Take care complex type case first.
- if (auto complexType = eltType.dyn_cast<ComplexType>()) {
+ if (auto complexType = llvm::dyn_cast<ComplexType>(eltType)) {
if (complexType.getElementType().isIntOrIndex()) {
SmallVector<std::complex<APInt>> complexValues;
complexValues.reserve(values.size());
for (Attribute attr : values) {
- assert(attr.isa<ArrayAttr>() &&
- "expected ArrayAttr for complex");
- auto arrayAttr = attr.cast<ArrayAttr>();
+ assert(llvm::isa<ArrayAttr>(attr) && "expected ArrayAttr for complex");
+ auto arrayAttr = llvm::cast<ArrayAttr>(attr);
assert(arrayAttr.size() == 2 && "expected 2 element for complex");
auto attr0 = arrayAttr[0];
auto attr1 = arrayAttr[1];
complexValues.push_back(
- std::complex<APInt>(attr0.cast<IntegerAttr>().getValue(),
- attr1.cast<IntegerAttr>().getValue()));
+ std::complex<APInt>(llvm::cast<IntegerAttr>(attr0).getValue(),
+ llvm::cast<IntegerAttr>(attr1).getValue()));
}
return DenseElementsAttr::get(type, complexValues);
}
@@ -915,14 +915,14 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
SmallVector<std::complex<APFloat>> complexValues;
complexValues.reserve(values.size());
for (Attribute attr : values) {
- assert(attr.isa<ArrayAttr>() && "expected ArrayAttr for complex");
- auto arrayAttr = attr.cast<ArrayAttr>();
+ assert(llvm::isa<ArrayAttr>(attr) && "expected ArrayAttr for complex");
+ auto arrayAttr = llvm::cast<ArrayAttr>(attr);
assert(arrayAttr.size() == 2 && "expected 2 element for complex");
auto attr0 = arrayAttr[0];
auto attr1 = arrayAttr[1];
complexValues.push_back(
- std::complex<APFloat>(attr0.cast<FloatAttr>().getValue(),
- attr1.cast<FloatAttr>().getValue()));
+ std::complex<APFloat>(llvm::cast<FloatAttr>(attr0).getValue(),
+ llvm::cast<FloatAttr>(attr1).getValue()));
}
return DenseElementsAttr::get(type, complexValues);
}
@@ -933,9 +933,9 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
SmallVector<StringRef, 8> stringValues;
stringValues.reserve(values.size());
for (Attribute attr : values) {
- assert(attr.isa<StringAttr>() &&
+ assert(llvm::isa<StringAttr>(attr) &&
"expected string value for non integer/index/float element");
- stringValues.push_back(attr.cast<StringAttr>().getValue());
+ stringValues.push_back(llvm::cast<StringAttr>(attr).getValue());
}
return get(type, stringValues);
}
@@ -949,12 +949,12 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
llvm::divideCeil(storageBitWidth * values.size(), CHAR_BIT));
APInt intVal;
for (unsigned i = 0, e = values.size(); i < e; ++i) {
- if (auto floatAttr = values[i].dyn_cast<FloatAttr>()) {
+ if (auto floatAttr = llvm::dyn_cast<FloatAttr>(values[i])) {
assert(floatAttr.getType() == eltType &&
"expected float attribute type to equal element type");
intVal = floatAttr.getValue().bitcastToAPInt();
} else {
- auto intAttr = values[i].cast<IntegerAttr>();
+ auto intAttr = llvm::cast<IntegerAttr>(values[i]);
assert(intAttr.getType() == eltType &&
"expected integer attribute type to equal element type");
intVal = intAttr.getValue();
@@ -1015,8 +1015,8 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
}
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
ArrayRef<std::complex<APInt>> values) {
- ComplexType complex = type.getElementType().cast<ComplexType>();
- assert(complex.getElementType().isa<IntegerType>());
+ ComplexType complex = llvm::cast<ComplexType>(type.getElementType());
+ assert(llvm::isa<IntegerType>(complex.getElementType()));
assert(hasSameElementsOrSplat(type, values));
size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()),
@@ -1029,7 +1029,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
// element type of 'type'.
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
ArrayRef<APFloat> values) {
- assert(type.getElementType().isa<FloatType>());
+ assert(llvm::isa<FloatType>(type.getElementType()));
assert(hasSameElementsOrSplat(type, values));
size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values);
@@ -1037,8 +1037,8 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
DenseElementsAttr
DenseElementsAttr::get(ShapedType type,
ArrayRef<std::complex<APFloat>> values) {
- ComplexType complex = type.getElementType().cast<ComplexType>();
- assert(complex.getElementType().isa<FloatType>());
+ ComplexType complex = llvm::cast<ComplexType>(type.getElementType());
+ assert(llvm::isa<FloatType>(complex.getElementType()));
assert(hasSameElementsOrSplat(type, values));
ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()),
values.size() * 2);
@@ -1104,11 +1104,11 @@ static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt,
// Check that the element type is either float or integer or index.
if (!isInt)
- return type.isa<FloatType>();
+ return llvm::isa<FloatType>(type);
if (type.isIndex())
return true;
- auto intType = type.dyn_cast<IntegerType>();
+ auto intType = llvm::dyn_cast<IntegerType>(type);
if (!intType)
return false;
@@ -1142,8 +1142,8 @@ bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt,
bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt,
bool isSigned) const {
return ::isValidIntOrFloat(
- getElementType().cast<ComplexType>().getElementType(), dataEltSize / 2,
- isInt, isSigned);
+ llvm::cast<ComplexType>(getElementType()).getElementType(),
+ dataEltSize / 2, isInt, isSigned);
}
/// Returns true if this attribute corresponds to a splat, i.e. if all element
@@ -1154,7 +1154,7 @@ bool DenseElementsAttr::isSplat() const {
/// Return if the given complex type has an integer element type.
static bool isComplexOfIntType(Type type) {
- return type.cast<ComplexType>().getElementType().isa<IntegerType>();
+ return llvm::isa<IntegerType>(llvm::cast<ComplexType>(type).getElementType());
}
auto DenseElementsAttr::tryGetComplexIntValues() const
@@ -1168,7 +1168,7 @@ auto DenseElementsAttr::tryGetComplexIntValues() const
auto DenseElementsAttr::tryGetFloatValues() const
-> FailureOr<iterator_range_impl<FloatElementIterator>> {
- auto eltTy = getElementType().dyn_cast<FloatType>();
+ auto eltTy = llvm::dyn_cast<FloatType>(getElementType());
if (!eltTy)
return failure();
const auto &elementSemantics = eltTy.getFloatSemantics();
@@ -1179,10 +1179,10 @@ auto DenseElementsAttr::tryGetFloatValues() const
auto DenseElementsAttr::tryGetComplexFloatValues() const
-> FailureOr<iterator_range_impl<ComplexFloatElementIterator>> {
- auto complexTy = getElementType().dyn_cast<ComplexType>();
+ auto complexTy = llvm::dyn_cast<ComplexType>(getElementType());
if (!complexTy)
return failure();
- auto eltTy = complexTy.getElementType().dyn_cast<FloatType>();
+ auto eltTy = llvm::dyn_cast<FloatType>(complexTy.getElementType());
if (!eltTy)
return failure();
const auto &semantics = eltTy.getFloatSemantics();
@@ -1331,7 +1331,7 @@ DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type,
bool isInt,
bool isSigned) {
assert(::isValidIntOrFloat(
- type.getElementType().cast<ComplexType>().getElementType(),
+ llvm::cast<ComplexType>(type.getElementType()).getElementType(),
dataEltSize / 2, isInt, isSigned));
int64_t numElements = data.size() / dataEltSize;
@@ -1404,7 +1404,7 @@ void DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
ShapedType type) {
size_t numElements = type.getNumElements();
Type elementType = type.getElementType();
- if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) {
+ if (ComplexType complexTy = llvm::dyn_cast<ComplexType>(elementType)) {
elementType = complexTy.getElementType();
numElements = numElements * 2;
}
@@ -1470,8 +1470,8 @@ DenseElementsAttr DenseFPElementsAttr::mapValues(
/// Method for supporting type inquiry through isa, cast and dyn_cast.
bool DenseFPElementsAttr::classof(Attribute attr) {
- if (auto denseAttr = attr.dyn_cast<DenseElementsAttr>())
- return denseAttr.getType().getElementType().isa<FloatType>();
+ if (auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(attr))
+ return llvm::isa<FloatType>(denseAttr.getType().getElementType());
return false;
}
@@ -1489,7 +1489,7 @@ DenseElementsAttr DenseIntElementsAttr::mapValues(
/// Method for supporting type inquiry through isa, cast and dyn_cast.
bool DenseIntElementsAttr::classof(Attribute attr) {
- if (auto denseAttr = attr.dyn_cast<DenseElementsAttr>())
+ if (auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(attr))
return denseAttr.getType().getElementType().isIntOrIndex();
return false;
}
@@ -1525,7 +1525,7 @@ struct DenseResourceAttrUtil;
template <size_t width, bool isSigned>
struct DenseResourceElementsAttrIntUtil {
static bool checkElementType(Type eltType) {
- IntegerType type = eltType.dyn_cast<IntegerType>();
+ IntegerType type = llvm::dyn_cast<IntegerType>(eltType);
if (!type || type.getWidth() != width)
return false;
return isSigned ? !type.isUnsigned() : !type.isSigned();
@@ -1582,8 +1582,8 @@ DenseResourceElementsAttrBase<T>::get(ShapedType type, StringRef blobName,
"size mismatch between expected element width and blob size");
assert(DenseResourceAttrUtil<T>::checkElementType(type.getElementType()) &&
"invalid shape element type for provided type `T`");
- return DenseResourceElementsAttr::get(type, blobName, std::move(blob))
- .template cast<DenseResourceElementsAttrBase<T>>();
+ return llvm::cast<DenseResourceElementsAttrBase<T>>(
+ DenseResourceElementsAttr::get(type, blobName, std::move(blob)));
}
template <typename T>
@@ -1596,7 +1596,7 @@ DenseResourceElementsAttrBase<T>::tryGetAsArrayRef() const {
template <typename T>
bool DenseResourceElementsAttrBase<T>::classof(Attribute attr) {
- auto resourceAttr = attr.dyn_cast<DenseResourceElementsAttr>();
+ auto resourceAttr = llvm::dyn_cast<DenseResourceElementsAttr>(attr);
return resourceAttr && DenseResourceAttrUtil<T>::checkElementType(
resourceAttr.getElementType());
}
@@ -1624,13 +1624,13 @@ template class DenseResourceElementsAttrBase<double>;
/// Get a zero APFloat for the given sparse attribute.
APFloat SparseElementsAttr::getZeroAPFloat() const {
- auto eltType = getElementType().cast<FloatType>();
+ auto eltType = llvm::cast<FloatType>(getElementType());
return APFloat(eltType.getFloatSemantics());
}
/// Get a zero APInt for the given sparse attribute.
APInt SparseElementsAttr::getZeroAPInt() const {
- auto eltType = getElementType().cast<IntegerType>();
+ auto eltType = llvm::cast<IntegerType>(getElementType());
return APInt::getZero(eltType.getWidth());
}
@@ -1639,14 +1639,14 @@ Attribute SparseElementsAttr::getZeroAttr() const {
auto eltType = getElementType();
// Handle floating point elements.
- if (eltType.isa<FloatType>())
+ if (llvm::isa<FloatType>(eltType))
return FloatAttr::get(eltType, 0);
// Handle complex elements.
- if (auto complexTy = eltType.dyn_cast<ComplexType>()) {
+ if (auto complexTy = llvm::dyn_cast<ComplexType>(eltType)) {
auto eltType = complexTy.getElementType();
Attribute zero;
- if (eltType.isa<FloatType>())
+ if (llvm::isa<FloatType>(eltType))
zero = FloatAttr::get(eltType, 0);
else // must be integer
zero = IntegerAttr::get(eltType, 0);
@@ -1655,7 +1655,7 @@ Attribute SparseElementsAttr::getZeroAttr() const {
}
// Handle string type.
- if (getValues().isa<DenseStringElementsAttr>())
+ if (llvm::isa<DenseStringElementsAttr>(getValues()))
return StringAttr::get("", eltType);
// Otherwise, this is an integer.
diff --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp
index 2dd4e2234f0cb..f73863248969d 100644
--- a/mlir/lib/IR/BuiltinDialect.cpp
+++ b/mlir/lib/IR/BuiltinDialect.cpp
@@ -48,15 +48,15 @@ struct BuiltinOpAsmDialectInterface : public OpAsmDialectInterface {
: OpAsmDialectInterface(dialect), blobManager(mgr) {}
AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
- if (attr.isa<AffineMapAttr>()) {
+ if (llvm::isa<AffineMapAttr>(attr)) {
os << "map";
return AliasResult::OverridableAlias;
}
- if (attr.isa<IntegerSetAttr>()) {
+ if (llvm::isa<IntegerSetAttr>(attr)) {
os << "set";
return AliasResult::OverridableAlias;
}
- if (attr.isa<LocationAttr>()) {
+ if (llvm::isa<LocationAttr>(attr)) {
os << "loc";
return AliasResult::OverridableAlias;
}
@@ -64,7 +64,7 @@ struct BuiltinOpAsmDialectInterface : public OpAsmDialectInterface {
}
AliasResult getAlias(Type type, raw_ostream &os) const final {
- if (auto tupleType = type.dyn_cast<TupleType>()) {
+ if (auto tupleType = llvm::dyn_cast<TupleType>(type)) {
if (tupleType.size() > 16) {
os << "tuple";
return AliasResult::OverridableAlias;
@@ -145,7 +145,7 @@ DataLayoutSpecInterface ModuleOp::getDataLayoutSpec() {
// interface. This needs a linear search, but is called only once per data
// layout object construction that is used for repeated queries.
for (NamedAttribute attr : getOperation()->getAttrs())
- if (auto spec = attr.getValue().dyn_cast<DataLayoutSpecInterface>())
+ if (auto spec = llvm::dyn_cast<DataLayoutSpecInterface>(attr.getValue()))
return spec;
return {};
}
@@ -168,7 +168,7 @@ LogicalResult ModuleOp::verify() {
StringRef layoutSpecAttrName;
DataLayoutSpecInterface layoutSpec;
for (const NamedAttribute &na : (*this)->getAttrs()) {
- if (auto spec = na.getValue().dyn_cast<DataLayoutSpecInterface>()) {
+ if (auto spec = llvm::dyn_cast<DataLayoutSpecInterface>(na.getValue())) {
if (layoutSpec) {
InFlightDiagnostic diag =
emitOpError() << "expects at most one data layout attribute";
diff --git a/mlir/lib/IR/BuiltinDialectBytecode.cpp b/mlir/lib/IR/BuiltinDialectBytecode.cpp
index 40af5f3b17448..75457e9ff85fb 100644
--- a/mlir/lib/IR/BuiltinDialectBytecode.cpp
+++ b/mlir/lib/IR/BuiltinDialectBytecode.cpp
@@ -32,7 +32,7 @@ namespace {
static unsigned getIntegerBitWidth(DialectBytecodeReader &reader, Type type) {
if (auto intType = dyn_cast<IntegerType>(type)) {
return intType.getWidth();
- } else if (type.isa<IndexType>()) {
+ } else if (llvm::isa<IndexType>(type)) {
return IndexType::kInternalStorageBitWidth;
}
reader.emitError()
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 0810d8965b385..b46ea8a2e6e10 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -244,10 +244,10 @@ LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
VectorType VectorType::scaleElementBitwidth(unsigned scale) {
if (!scale)
return VectorType();
- if (auto et = getElementType().dyn_cast<IntegerType>())
+ if (auto et = llvm::dyn_cast<IntegerType>(getElementType()))
if (auto scaledEt = et.scaleElementBitwidth(scale))
return VectorType::get(getShape(), scaledEt, getNumScalableDims());
- if (auto et = getElementType().dyn_cast<FloatType>())
+ if (auto et = llvm::dyn_cast<FloatType>(getElementType()))
if (auto scaledEt = et.scaleElementBitwidth(scale))
return VectorType::get(getShape(), scaledEt, getNumScalableDims());
return VectorType();
@@ -305,8 +305,8 @@ bool TensorType::isValidElementType(Type type) {
// Note: Non standard/builtin types are allowed to exist within tensor
// types. Dialects are expected to verify that tensor types have a valid
// element type within that dialect.
- return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
- IndexType>() ||
+ return llvm::isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
+ IndexType>(type) ||
!llvm::isa<BuiltinDialect>(type.getDialect());
}
@@ -321,7 +321,7 @@ RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
for (int64_t s : shape)
if (s < 0 && !ShapedType::isDynamic(s))
return emitError() << "invalid tensor dimension size";
- if (auto v = encoding.dyn_cast_or_null<VerifiableTensorEncoding>())
+ if (auto v = llvm::dyn_cast_or_null<VerifiableTensorEncoding>(encoding))
if (failed(v.verifyEncoding(shape, elementType, emitError)))
return failure();
return checkTensorElementType(emitError, elementType);
@@ -426,9 +426,9 @@ mlir::isRankReducedType(ShapedType originalType,
if (originalType == candidateReducedType)
return SliceVerificationResult::Success;
- ShapedType originalShapedType = originalType.cast<ShapedType>();
+ ShapedType originalShapedType = llvm::cast<ShapedType>(originalType);
ShapedType candidateReducedShapedType =
- candidateReducedType.cast<ShapedType>();
+ llvm::cast<ShapedType>(candidateReducedType);
// Rank and size logic is valid for all ShapedTypes.
ArrayRef<int64_t> originalShape = originalShapedType.getShape();
@@ -459,7 +459,7 @@ bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) {
return true;
// Supported built-in attributes.
- if (memorySpace.isa<IntegerAttr, StringAttr, DictionaryAttr>())
+ if (llvm::isa<IntegerAttr, StringAttr, DictionaryAttr>(memorySpace))
return true;
// Allow custom dialect attributes.
@@ -478,7 +478,7 @@ Attribute mlir::detail::wrapIntegerMemorySpace(unsigned memorySpace,
}
Attribute mlir::detail::skipDefaultMemorySpace(Attribute memorySpace) {
- IntegerAttr intMemorySpace = memorySpace.dyn_cast_or_null<IntegerAttr>();
+ IntegerAttr intMemorySpace = llvm::dyn_cast_or_null<IntegerAttr>(memorySpace);
if (intMemorySpace && intMemorySpace.getValue() == 0)
return nullptr;
@@ -489,10 +489,10 @@ unsigned mlir::detail::getMemorySpaceAsInt(Attribute memorySpace) {
if (!memorySpace)
return 0;
- assert(memorySpace.isa<IntegerAttr>() &&
+ assert(llvm::isa<IntegerAttr>(memorySpace) &&
"Using `getMemorySpaceInteger` with non-Integer attribute");
- return static_cast<unsigned>(memorySpace.cast<IntegerAttr>().getInt());
+ return static_cast<unsigned>(llvm::cast<IntegerAttr>(memorySpace).getInt());
}
unsigned MemRefType::getMemorySpaceAsInt() const {
@@ -786,7 +786,7 @@ LogicalResult mlir::getStridesAndOffset(MemRefType t,
SmallVectorImpl<int64_t> &strides,
int64_t &offset) {
// Happy path: the type uses the strided layout directly.
- if (auto strided = t.getLayout().dyn_cast<StridedLayoutAttr>()) {
+ if (auto strided = llvm::dyn_cast<StridedLayoutAttr>(t.getLayout())) {
llvm::append_range(strides, strided.getStrides());
offset = strided.getOffset();
return success();
@@ -834,7 +834,7 @@ ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }
/// (i32, tensor<i32>, f32, i64)
void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
for (Type type : getTypes()) {
- if (auto nestedTuple = type.dyn_cast<TupleType>())
+ if (auto nestedTuple = llvm::dyn_cast<TupleType>(type))
nestedTuple.getFlattenedTypes(types);
else
types.push_back(type);
diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp
index 5926f2e6953aa..6788660d87828 100644
--- a/mlir/lib/IR/Diagnostics.cpp
+++ b/mlir/lib/IR/Diagnostics.cpp
@@ -259,7 +259,7 @@ void DiagnosticEngineImpl::emit(Diagnostic &&diag) {
return;
auto &os = llvm::errs();
- if (!diag.getLocation().isa<UnknownLoc>())
+ if (!llvm::isa<UnknownLoc>(diag.getLocation()))
os << diag.getLocation() << ": ";
os << "error: ";
@@ -448,7 +448,7 @@ void SourceMgrDiagnosticHandler::emitDiagnostic(Location loc, Twine message,
if (!fileLoc) {
std::string str;
llvm::raw_string_ostream strOS(str);
- if (!loc.isa<UnknownLoc>())
+ if (!llvm::isa<UnknownLoc>(loc))
strOS << loc << ": ";
strOS << message;
return mgr.PrintMessage(os, SMLoc(), getDiagKind(kind), strOS.str());
@@ -983,7 +983,7 @@ struct ParallelDiagnosticHandlerImpl : public llvm::PrettyStackTraceEntry {
// Print each diagnostic with the format:
// "<location>: <kind>: <msg>"
- if (!diag.getLocation().isa<UnknownLoc>())
+ if (!llvm::isa<UnknownLoc>(diag.getLocation()))
os << diag.getLocation() << ": ";
switch (diag.getSeverity()) {
case DiagnosticSeverity::Error:
diff --git a/mlir/lib/IR/ExtensibleDialect.cpp b/mlir/lib/IR/ExtensibleDialect.cpp
index 00849857b51d4..2225a8f2d1b91 100644
--- a/mlir/lib/IR/ExtensibleDialect.cpp
+++ b/mlir/lib/IR/ExtensibleDialect.cpp
@@ -474,7 +474,7 @@ OptionalParseResult ExtensibleDialect::parseOptionalDynamicType(
LogicalResult ExtensibleDialect::printIfDynamicType(Type type,
AsmPrinter &printer) {
- if (auto dynType = type.dyn_cast<DynamicType>()) {
+ if (auto dynType = llvm::dyn_cast<DynamicType>(type)) {
dynType.print(printer);
return success();
}
@@ -496,7 +496,7 @@ OptionalParseResult ExtensibleDialect::parseOptionalDynamicAttr(
LogicalResult ExtensibleDialect::printIfDynamicAttr(Attribute attribute,
AsmPrinter &printer) {
- if (auto dynAttr = attribute.dyn_cast<DynamicAttr>()) {
+ if (auto dynAttr = llvm::dyn_cast<DynamicAttr>(attribute)) {
dynAttr.print(printer);
return success();
}
diff --git a/mlir/lib/IR/FunctionImplementation.cpp b/mlir/lib/IR/FunctionImplementation.cpp
index 5ca6777cea18f..af625a4f36118 100644
--- a/mlir/lib/IR/FunctionImplementation.cpp
+++ b/mlir/lib/IR/FunctionImplementation.cpp
@@ -250,14 +250,14 @@ static void printFunctionResultList(OpAsmPrinter &p, ArrayRef<Type> types,
"Invalid number of attributes.");
auto &os = p.getStream();
- bool needsParens = types.size() > 1 || types[0].isa<FunctionType>() ||
- (attrs && !attrs[0].cast<DictionaryAttr>().empty());
+ bool needsParens = types.size() > 1 || llvm::isa<FunctionType>(types[0]) ||
+ (attrs && !llvm::cast<DictionaryAttr>(attrs[0]).empty());
if (needsParens)
os << '(';
llvm::interleaveComma(llvm::seq<size_t>(0, types.size()), os, [&](size_t i) {
p.printType(types[i]);
if (attrs)
- p.printOptionalAttrDict(attrs[i].cast<DictionaryAttr>().getValue());
+ p.printOptionalAttrDict(llvm::cast<DictionaryAttr>(attrs[i]).getValue());
});
if (needsParens)
os << ')';
@@ -278,12 +278,13 @@ void function_interface_impl::printFunctionSignature(
if (!isExternal) {
ArrayRef<NamedAttribute> attrs;
if (argAttrs)
- attrs = argAttrs[i].cast<DictionaryAttr>().getValue();
+ attrs = llvm::cast<DictionaryAttr>(argAttrs[i]).getValue();
p.printRegionArgument(body.getArgument(i), attrs);
} else {
p.printType(argTypes[i]);
if (argAttrs)
- p.printOptionalAttrDict(argAttrs[i].cast<DictionaryAttr>().getValue());
+ p.printOptionalAttrDict(
+ llvm::cast<DictionaryAttr>(argAttrs[i]).getValue());
}
}
diff --git a/mlir/lib/IR/FunctionInterfaces.cpp b/mlir/lib/IR/FunctionInterfaces.cpp
index 4a50a49502f57..474f7c817aad8 100644
--- a/mlir/lib/IR/FunctionInterfaces.cpp
+++ b/mlir/lib/IR/FunctionInterfaces.cpp
@@ -21,14 +21,14 @@ using namespace mlir;
//===----------------------------------------------------------------------===//
static bool isEmptyAttrDict(Attribute attr) {
- return attr.cast<DictionaryAttr>().empty();
+ return llvm::cast<DictionaryAttr>(attr).empty();
}
DictionaryAttr function_interface_impl::getArgAttrDict(FunctionOpInterface op,
unsigned index) {
ArrayAttr attrs = op.getArgAttrsAttr();
DictionaryAttr argAttrs =
- attrs ? attrs[index].cast<DictionaryAttr>() : DictionaryAttr();
+ attrs ? llvm::cast<DictionaryAttr>(attrs[index]) : DictionaryAttr();
return argAttrs;
}
@@ -37,7 +37,7 @@ function_interface_impl::getResultAttrDict(FunctionOpInterface op,
unsigned index) {
ArrayAttr attrs = op.getResAttrsAttr();
DictionaryAttr resAttrs =
- attrs ? attrs[index].cast<DictionaryAttr>() : DictionaryAttr();
+ attrs ? llvm::cast<DictionaryAttr>(attrs[index]) : DictionaryAttr();
return resAttrs;
}
@@ -288,7 +288,7 @@ void function_interface_impl::eraseFunctionArguments(
newArgAttrs.reserve(argAttrs.size());
for (unsigned i = 0, e = argIndices.size(); i < e; ++i)
if (!argIndices[i])
- newArgAttrs.emplace_back(argAttrs[i].cast<DictionaryAttr>());
+ newArgAttrs.emplace_back(llvm::cast<DictionaryAttr>(argAttrs[i]));
setAllArgAttrDicts(op, newArgAttrs);
}
@@ -309,7 +309,7 @@ void function_interface_impl::eraseFunctionResults(
newResultAttrs.reserve(resAttrs.size());
for (unsigned i = 0, e = resultIndices.size(); i < e; ++i)
if (!resultIndices[i])
- newResultAttrs.emplace_back(resAttrs[i].cast<DictionaryAttr>());
+ newResultAttrs.emplace_back(llvm::cast<DictionaryAttr>(resAttrs[i]));
setAllResultAttrDicts(op, newResultAttrs);
}
diff --git a/mlir/lib/IR/Location.cpp b/mlir/lib/IR/Location.cpp
index dcbf9dcecfe29..c548bbe4b6c86 100644
--- a/mlir/lib/IR/Location.cpp
+++ b/mlir/lib/IR/Location.cpp
@@ -64,8 +64,8 @@ WalkResult LocationAttr::walk(function_ref<WalkResult(Location)> walkFn) {
/// Methods for support type inquiry through isa, cast, and dyn_cast.
bool LocationAttr::classof(Attribute attr) {
- return attr.isa<CallSiteLoc, FileLineColLoc, FusedLoc, NameLoc, OpaqueLoc,
- UnknownLoc>();
+ return llvm::isa<CallSiteLoc, FileLineColLoc, FusedLoc, NameLoc, OpaqueLoc,
+ UnknownLoc>(attr);
}
//===----------------------------------------------------------------------===//
@@ -101,7 +101,7 @@ Location FusedLoc::get(ArrayRef<Location> locs, Attribute metadata,
}
}
// Otherwise, only add known locations to the set.
- if (!loc.isa<UnknownLoc>())
+ if (!llvm::isa<UnknownLoc>(loc))
decomposedLocs.insert(loc);
}
locs = decomposedLocs.getArrayRef();
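A note for readers following the `classof` hunks in this patch (for example `LocationAttr::classof` in Location.cpp): the free functions dispatch through a static `classof` hook on each target type. What follows is a self-contained toy model of that idea, not LLVM's implementation, which goes through the `doCast` machinery mentioned in the commit message. The `toy` namespace and the `Loc`, `FileLineLoc`, `NameLoc` and `Kind` names are hypothetical stand-ins. The sketch also illustrates the name-lookup problem the commit message describes: inside a class that defines its own `isa` method, the free function has to be called with a namespace qualifier (`toy::isa` here, `llvm::isa` in the patch).

```cpp
#include <cassert>

// Hypothetical toy model of MLIR-style value handles (not LLVM/MLIR code).
namespace toy {

enum class Kind { None, FileLine, Name };

// Forward declaration so that the method-style shorthand in Loc can call it.
template <typename... Ts, typename From> bool isa(const From &x);

// Base handle, standing in for something like mlir::Attribute.
// A default-constructed handle is null.
class Loc {
public:
  Loc() = default;
  explicit Loc(Kind k) : kind(k) {}
  Kind getKind() const { return kind; }
  explicit operator bool() const { return kind != Kind::None; }

  // Method form, like the `attr.isa<T>()` calls this patch replaces. Inside
  // the class, an unqualified `isa<T>(*this)` would find this member instead,
  // so the free function must be named with its namespace.
  template <typename T> bool isa() const { return toy::isa<T>(*this); }

private:
  Kind kind = Kind::None;
};

// Each subclass answers "is this handle one of me?" via a static classof.
class FileLineLoc : public Loc {
public:
  FileLineLoc() = default;
  explicit FileLineLoc(Loc base) : Loc(base) {}
  static bool classof(Loc l) { return l.getKind() == Kind::FileLine; }
};

class NameLoc : public Loc {
public:
  NameLoc() = default;
  explicit NameLoc(Loc base) : Loc(base) {}
  static bool classof(Loc l) { return l.getKind() == Kind::Name; }
};

// True if the non-null handle `x` is any of Ts, e.g. isa<A, B>(x).
template <typename... Ts, typename From> bool isa(const From &x) {
  static_assert(sizeof...(Ts) > 0, "isa<> needs at least one target type");
  assert(x && "isa<> used on a null handle");
  return (Ts::classof(x) || ...);
}

// Checked conversion: asserts if `x` is not a To.
template <typename To, typename From> To cast(const From &x) {
  assert(isa<To>(x) && "cast<To>() argument of incompatible type");
  return To(x);
}

// Returns a null To on mismatch; `x` itself must be non-null.
template <typename To, typename From> To dyn_cast(const From &x) {
  return isa<To>(x) ? To(x) : To();
}

// Like dyn_cast, but a null input simply yields a null result.
template <typename To, typename From> To dyn_cast_or_null(const From &x) {
  return x ? dyn_cast<To>(x) : To();
}

} // namespace toy
```

With these definitions, `toy::isa<FileLineLoc, NameLoc>(loc)` mirrors the variadic `llvm::isa<CallSiteLoc, FileLineColLoc, ...>(attr)` form used in `LocationAttr::classof`, while `loc.isa<FileLineLoc>()` mirrors the method form that the patch migrates away from.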
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index bc7a9d852fd0f..42b2c79ef6f9c 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -829,11 +829,11 @@ LogicalResult OpTrait::impl::verifyAtLeastNOperands(Operation *op,
/// If this is a vector type, or a tensor type, return the scalar element type
/// that it is built around, otherwise return the type unmodified.
static Type getTensorOrVectorElementType(Type type) {
- if (auto vec = type.dyn_cast<VectorType>())
+ if (auto vec = llvm::dyn_cast<VectorType>(type))
return vec.getElementType();
// Look through tensor<vector<...>> to find the underlying element type.
- if (auto tensor = type.dyn_cast<TensorType>())
+ if (auto tensor = llvm::dyn_cast<TensorType>(type))
return getTensorOrVectorElementType(tensor.getElementType());
return type;
}
@@ -867,7 +867,7 @@ OpTrait::impl::verifyOperandsAreSignlessIntegerLike(Operation *op) {
LogicalResult OpTrait::impl::verifyOperandsAreFloatLike(Operation *op) {
for (auto opType : op->getOperandTypes()) {
auto type = getTensorOrVectorElementType(opType);
- if (!type.isa<FloatType>())
+ if (!llvm::isa<FloatType>(type))
return op->emitOpError("requires a float type");
}
return success();
@@ -1102,7 +1102,7 @@ LogicalResult OpTrait::impl::verifyResultsAreBoolLike(Operation *op) {
LogicalResult OpTrait::impl::verifyResultsAreFloatLike(Operation *op) {
for (auto resultType : op->getResultTypes())
- if (!getTensorOrVectorElementType(resultType).isa<FloatType>())
+ if (!llvm::isa<FloatType>(getTensorOrVectorElementType(resultType)))
return op->emitOpError() << "requires a floating point type";
return success();
@@ -1169,7 +1169,7 @@ LogicalResult OpTrait::impl::verifyNoRegionArguments(Operation *op) {
LogicalResult OpTrait::impl::verifyElementwise(Operation *op) {
auto isMappableType = [](Type type) {
- return type.isa<VectorType, TensorType>();
+ return llvm::isa<VectorType, TensorType>(type);
};
auto resultMappableTypes = llvm::to_vector<1>(
llvm::make_filter_range(op->getResultTypes(), isMappableType));
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index a0a86e776ce37..b8f3601222390 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -59,7 +59,7 @@ DictionaryAttr NamedAttrList::getDictionary(MLIRContext *context) const {
}
if (!dictionarySorted.getPointer())
dictionarySorted.setPointer(DictionaryAttr::getWithSorted(context, attrs));
- return dictionarySorted.getPointer().cast<DictionaryAttr>();
+ return llvm::cast<DictionaryAttr>(dictionarySorted.getPointer());
}
/// Add an attribute with the specified name.
@@ -405,18 +405,19 @@ OperandRangeRange OperandRange::split(DenseI32ArrayAttr segmentSizes) const {
OperandRangeRange::OperandRangeRange(OperandRange operands,
Attribute operandSegments)
: OperandRangeRange(OwnerT(operands.getBase(), operandSegments), 0,
- operandSegments.cast<DenseI32ArrayAttr>().size()) {}
+ llvm::cast<DenseI32ArrayAttr>(operandSegments).size()) {
+}
OperandRange OperandRangeRange::join() const {
const OwnerT &owner = getBase();
- ArrayRef<int32_t> sizeData = owner.second.cast<DenseI32ArrayAttr>();
+ ArrayRef<int32_t> sizeData = llvm::cast<DenseI32ArrayAttr>(owner.second);
return OperandRange(owner.first,
std::accumulate(sizeData.begin(), sizeData.end(), 0));
}
OperandRange OperandRangeRange::dereference(const OwnerT &object,
ptrdiff_t index) {
- ArrayRef<int32_t> sizeData = object.second.cast<DenseI32ArrayAttr>();
+ ArrayRef<int32_t> sizeData = llvm::cast<DenseI32ArrayAttr>(object.second);
uint32_t startIndex =
std::accumulate(sizeData.begin(), sizeData.begin() + index, 0);
return OperandRange(object.first + startIndex, *(sizeData.begin() + index));
@@ -508,7 +509,7 @@ void MutableOperandRange::updateLength(unsigned newLength) {
// Update any of the provided segment attributes.
for (OperandSegment &segment : operandSegments) {
- auto attr = segment.second.getValue().cast<DenseI32ArrayAttr>();
+ auto attr = llvm::cast<DenseI32ArrayAttr>(segment.second.getValue());
SmallVector<int32_t, 8> segments(attr.asArrayRef());
segments[segment.first] += diff;
segment.second.setValue(
@@ -524,7 +525,8 @@ MutableOperandRangeRange::MutableOperandRangeRange(
const MutableOperandRange &operands, NamedAttribute operandSegmentAttr)
: MutableOperandRangeRange(
OwnerT(operands, operandSegmentAttr), 0,
- operandSegmentAttr.getValue().cast<DenseI32ArrayAttr>().size()) {}
+ llvm::cast<DenseI32ArrayAttr>(operandSegmentAttr.getValue()).size()) {
+}
MutableOperandRange MutableOperandRangeRange::join() const {
return getBase().first;
@@ -537,7 +539,7 @@ MutableOperandRangeRange::operator OperandRangeRange() const {
MutableOperandRange MutableOperandRangeRange::dereference(const OwnerT &object,
ptrdiff_t index) {
ArrayRef<int32_t> sizeData =
- object.second.getValue().cast<DenseI32ArrayAttr>();
+ llvm::cast<DenseI32ArrayAttr>(object.second.getValue());
uint32_t startIndex =
std::accumulate(sizeData.begin(), sizeData.begin() + index, 0);
return object.first.slice(
@@ -782,8 +784,8 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
auto sortValues = [](ValueRange values) {
SmallVector<Value> sortedValues = llvm::to_vector(values);
llvm::sort(sortedValues, [](Value a, Value b) {
- auto aArg = a.dyn_cast<BlockArgument>();
- auto bArg = b.dyn_cast<BlockArgument>();
+ auto aArg = llvm::dyn_cast<BlockArgument>(a);
+ auto bArg = llvm::dyn_cast<BlockArgument>(b);
// Case 1. Both `a` and `b` are `BlockArgument`s.
if (aArg && bArg) {
diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
index 85f5cb1852fe0..c03f4dd5f352e 100644
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -459,7 +459,7 @@ LogicalResult detail::verifySymbol(Operation *op) {
// Verify the visibility attribute.
if (Attribute vis = op->getAttr(mlir::SymbolTable::getVisibilityAttrName())) {
- StringAttr visStrAttr = vis.dyn_cast<StringAttr>();
+ StringAttr visStrAttr = llvm::dyn_cast<StringAttr>(vis);
if (!visStrAttr)
return op->emitOpError() << "requires visibility attribute '"
<< mlir::SymbolTable::getVisibilityAttrName()
@@ -669,7 +669,7 @@ static bool isReferencePrefixOf(SymbolRefAttr subRef, SymbolRefAttr ref) {
// If the references are not pointer equal, check to see if `subRef` is a
// prefix of `ref`.
- if (ref.isa<FlatSymbolRefAttr>() ||
+ if (llvm::isa<FlatSymbolRefAttr>(ref) ||
ref.getRootReference() != subRef.getRootReference())
return false;
@@ -789,7 +789,7 @@ bool SymbolTable::symbolKnownUseEmpty(Operation *symbol, Region *from) {
/// Generates a new symbol reference attribute with a new leaf reference.
static SymbolRefAttr generateNewRefAttr(SymbolRefAttr oldAttr,
FlatSymbolRefAttr newLeafAttr) {
- if (oldAttr.isa<FlatSymbolRefAttr>())
+ if (llvm::isa<FlatSymbolRefAttr>(oldAttr))
return newLeafAttr;
auto nestedRefs = llvm::to_vector<2>(oldAttr.getNestedReferences());
nestedRefs.back() = newLeafAttr;
diff --git a/mlir/lib/IR/TypeUtilities.cpp b/mlir/lib/IR/TypeUtilities.cpp
index bcf698400b1d0..7aa37cb015fcb 100644
--- a/mlir/lib/IR/TypeUtilities.cpp
+++ b/mlir/lib/IR/TypeUtilities.cpp
@@ -22,7 +22,7 @@
using namespace mlir;
Type mlir::getElementTypeOrSelf(Type type) {
- if (auto st = type.dyn_cast<ShapedType>())
+ if (auto st = llvm::dyn_cast<ShapedType>(type))
return st.getElementType();
return type;
}
@@ -32,7 +32,7 @@ Type mlir::getElementTypeOrSelf(Value val) {
}
Type mlir::getElementTypeOrSelf(Attribute attr) {
- if (auto typedAttr = attr.dyn_cast<TypedAttr>())
+ if (auto typedAttr = llvm::dyn_cast<TypedAttr>(attr))
return getElementTypeOrSelf(typedAttr.getType());
return {};
}
@@ -47,7 +47,7 @@ SmallVector<Type, 10> mlir::getFlattenedTypes(TupleType t) {
/// dialect and typeData.
bool mlir::isOpaqueTypeWithName(Type type, StringRef dialect,
StringRef typeData) {
- if (auto opaque = type.dyn_cast<mlir::OpaqueType>())
+ if (auto opaque = llvm::dyn_cast<mlir::OpaqueType>(type))
return opaque.getDialectNamespace() == dialect &&
opaque.getTypeData() == typeData;
return false;
@@ -76,8 +76,8 @@ LogicalResult mlir::verifyCompatibleShape(ArrayRef<int64_t> shape1,
/// compatible if at least one is dynamic or both are equal. The element type
/// does not matter.
LogicalResult mlir::verifyCompatibleShape(Type type1, Type type2) {
- auto sType1 = type1.dyn_cast<ShapedType>();
- auto sType2 = type2.dyn_cast<ShapedType>();
+ auto sType1 = llvm::dyn_cast<ShapedType>(type1);
+ auto sType2 = llvm::dyn_cast<ShapedType>(type2);
// Either both or neither type should be shaped.
if (!sType1)
@@ -120,7 +120,7 @@ LogicalResult mlir::verifyCompatibleDims(ArrayRef<int64_t> dims) {
/// dims are equal. The element type does not matter.
LogicalResult mlir::verifyCompatibleShapes(TypeRange types) {
auto shapedTypes = llvm::to_vector<8>(llvm::map_range(
- types, [](auto type) { return type.template dyn_cast<ShapedType>(); }));
+ types, [](auto type) { return llvm::dyn_cast<ShapedType>(type); }));
// Return failure if some, but not all are not shaped. Return early if none
// are shaped also.
if (llvm::none_of(shapedTypes, [](auto t) { return t; }))
@@ -132,7 +132,7 @@ LogicalResult mlir::verifyCompatibleShapes(TypeRange types) {
bool hasScalableVecTypes = false;
bool hasNonScalableVecTypes = false;
for (Type t : types) {
- auto vType = t.dyn_cast<VectorType>();
+ auto vType = llvm::dyn_cast<VectorType>(t);
if (vType && vType.isScalable())
hasScalableVecTypes = true;
else
@@ -167,9 +167,9 @@ LogicalResult mlir::verifyCompatibleShapes(TypeRange types) {
}
Type OperandElementTypeIterator::mapElement(Value value) const {
- return value.getType().cast<ShapedType>().getElementType();
+ return llvm::cast<ShapedType>(value.getType()).getElementType();
}
Type ResultElementTypeIterator::mapElement(Value value) const {
- return value.getType().cast<ShapedType>().getElementType();
+ return llvm::cast<ShapedType>(value.getType()).getElementType();
}
diff --git a/mlir/lib/IR/Verifier.cpp b/mlir/lib/IR/Verifier.cpp
index 432d7bb28ae9b..68e498d57324e 100644
--- a/mlir/lib/IR/Verifier.cpp
+++ b/mlir/lib/IR/Verifier.cpp
@@ -302,7 +302,7 @@ static void diagnoseInvalidOperandDominance(Operation &op, unsigned operandNo) {
}
// Block argument case.
Block *block1 = op.getBlock();
- Block *block2 = operand.cast<BlockArgument>().getOwner();
+ Block *block2 = llvm::cast<BlockArgument>(operand).getOwner();
Region *region1 = block1->getParent();
Region *region2 = block2->getParent();
Location loc = UnknownLoc::get(op.getContext());
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index e0ccd500aa909..7fc2e6ab3ec0a 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -75,7 +75,7 @@ Attribute TestI64ElementsAttr::parse(AsmParser &parser, Type type) {
if (parser.parseRSquare() || parser.parseGreater())
return Attribute();
return parser.getChecked<TestI64ElementsAttr>(
- parser.getContext(), type.cast<ShapedType>(), elements);
+ parser.getContext(), llvm::cast<ShapedType>(type), elements);
}
void TestI64ElementsAttr::print(AsmPrinter &printer) const {
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index 231f69f629ce2..0633752067a14 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -287,18 +287,18 @@ TestTypeWithLayoutType::verifyEntries(DataLayoutEntryListRef params,
for (DataLayoutEntryInterface entry : params) {
// This is for testing purposes only, so assert well-formedness.
assert(entry.isTypeEntry() && "unexpected identifier entry");
- assert(entry.getKey().get<Type>().isa<TestTypeWithLayoutType>() &&
+ assert(llvm::isa<TestTypeWithLayoutType>(entry.getKey().get<Type>()) &&
"wrong type passed in");
- auto array = entry.getValue().dyn_cast<ArrayAttr>();
+ auto array = llvm::dyn_cast<ArrayAttr>(entry.getValue());
assert(array && array.getValue().size() == 2 &&
"expected array of two elements");
- auto kind = array.getValue().front().dyn_cast<StringAttr>();
+ auto kind = llvm::dyn_cast<StringAttr>(array.getValue().front());
(void)kind;
assert(kind &&
(kind.getValue() == "size" || kind.getValue() == "alignment" ||
kind.getValue() == "preferred") &&
"unexpected kind");
- assert(array.getValue().back().isa<IntegerAttr>());
+ assert(llvm::isa<IntegerAttr>(array.getValue().back()));
}
return success();
}
@@ -306,10 +306,11 @@ TestTypeWithLayoutType::verifyEntries(DataLayoutEntryListRef params,
unsigned TestTypeWithLayoutType::extractKind(DataLayoutEntryListRef params,
StringRef expectedKind) const {
for (DataLayoutEntryInterface entry : params) {
- ArrayRef<Attribute> pair = entry.getValue().cast<ArrayAttr>().getValue();
- StringRef kind = pair.front().cast<StringAttr>().getValue();
+ ArrayRef<Attribute> pair =
+ llvm::cast<ArrayAttr>(entry.getValue()).getValue();
+ StringRef kind = llvm::cast<StringAttr>(pair.front()).getValue();
if (kind == expectedKind)
- return pair.back().cast<IntegerAttr>().getValue().getZExtValue();
+ return llvm::cast<IntegerAttr>(pair.back()).getValue().getZExtValue();
}
return 1;
}
@@ -466,7 +467,7 @@ void TestDialect::printTestType(Type type, AsmPrinter &printer,
if (succeeded(printIfDynamicType(type, printer)))
return;
- auto rec = type.cast<TestRecursiveType>();
+ auto rec = llvm::cast<TestRecursiveType>(type);
printer << "test_rec<" << rec.getName();
if (!stack.contains(rec)) {
printer << ", ";
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index 9cb26c11eb892..1d1bbc3a57084 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -109,10 +109,10 @@ DiagnosedSilenceableFailure
mlir::test::TestProduceSelfHandleOrForwardOperandOp::apply(
transform::TransformResults &results, transform::TransformState &state) {
if (getOperation()->getNumOperands() != 0) {
- results.set(getResult().cast<OpResult>(),
+ results.set(llvm::cast<OpResult>(getResult()),
getOperation()->getOperand(0).getDefiningOp());
} else {
- results.set(getResult().cast<OpResult>(), getOperation());
+ results.set(llvm::cast<OpResult>(getResult()), getOperation());
}
return DiagnosedSilenceableFailure::success();
}
@@ -127,7 +127,7 @@ void mlir::test::TestProduceSelfHandleOrForwardOperandOp::getEffects(
DiagnosedSilenceableFailure
mlir::test::TestProduceValueHandleToSelfOperand::apply(
transform::TransformResults &results, transform::TransformState &state) {
- results.setValues(getOut().cast<OpResult>(), getIn());
+ results.setValues(llvm::cast<OpResult>(getOut()), getIn());
return DiagnosedSilenceableFailure::success();
}
@@ -249,13 +249,13 @@ DiagnosedSilenceableFailure mlir::test::TestPrintRemarkAtOperandValue::apply(
for (Value value : values) {
std::string note;
llvm::raw_string_ostream os(note);
- if (auto arg = value.dyn_cast<BlockArgument>()) {
+ if (auto arg = llvm::dyn_cast<BlockArgument>(value)) {
os << "a block argument #" << arg.getArgNumber() << " in block #"
<< std::distance(arg.getOwner()->getParent()->begin(),
arg.getOwner()->getIterator())
<< " in region #" << arg.getOwner()->getParent()->getRegionNumber();
} else {
- os << "an op result #" << value.cast<OpResult>().getResultNumber();
+ os << "an op result #" << llvm::cast<OpResult>(value).getResultNumber();
}
InFlightDiagnostic diag = ::emitRemark(value.getLoc()) << getMessage();
diag.attachNote() << "value handle points to " << os.str();
@@ -317,7 +317,7 @@ DiagnosedSilenceableFailure mlir::test::TestRemapOperandPayloadToSelfOp::apply(
getOperation())))
return DiagnosedSilenceableFailure::definiteFailure();
if (getNumResults() > 0)
- results.set(getResult(0).cast<OpResult>(), getOperation());
+ results.set(llvm::cast<OpResult>(getResult(0)), getOperation());
return DiagnosedSilenceableFailure::success();
}
@@ -339,7 +339,7 @@ mlir::test::TestReversePayloadOpsOp::apply(transform::TransformResults &results,
transform::TransformState &state) {
ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
auto reversedOps = llvm::to_vector(llvm::reverse(payloadOps));
- results.set(getResult().cast<OpResult>(), reversedOps);
+ results.set(llvm::cast<OpResult>(getResult()), reversedOps);
return DiagnosedSilenceableFailure::success();
}
@@ -443,7 +443,8 @@ void mlir::test::TestPrintNumberOfAssociatedPayloadIROps::getEffects(
DiagnosedSilenceableFailure
mlir::test::TestCopyPayloadOp::apply(transform::TransformResults &results,
transform::TransformState &state) {
- results.set(getCopy().cast<OpResult>(), state.getPayloadOps(getHandle()));
+ results.set(llvm::cast<OpResult>(getCopy()),
+ state.getPayloadOps(getHandle()));
return DiagnosedSilenceableFailure::success();
}
@@ -472,7 +473,7 @@ DiagnosedSilenceableFailure mlir::transform::TestDialectOpType::checkPayload(
DiagnosedSilenceableFailure mlir::transform::TestDialectParamType::checkPayload(
Location loc, ArrayRef<Attribute> payload) const {
for (Attribute attr : payload) {
- auto integerAttr = attr.dyn_cast<IntegerAttr>();
+ auto integerAttr = llvm::dyn_cast<IntegerAttr>(attr);
if (integerAttr && integerAttr.getType().isSignlessInteger(32))
continue;
return emitSilenceableError(loc)
@@ -534,7 +535,7 @@ mlir::test::TestAddToParamOp::apply(transform::TransformResults &results,
if (Value param = getParam()) {
values = llvm::to_vector(
llvm::map_range(state.getParams(param), [](Attribute attr) -> uint32_t {
- return attr.cast<IntegerAttr>().getValue().getLimitedValue(
+ return llvm::cast<IntegerAttr>(attr).getValue().getLimitedValue(
UINT32_MAX);
}));
}
@@ -544,7 +545,7 @@ mlir::test::TestAddToParamOp::apply(transform::TransformResults &results,
llvm::map_range(values, [this, &builder](uint32_t value) -> Attribute {
return builder.getI32IntegerAttr(value + getAddendum());
}));
- results.setParams(getResult().cast<OpResult>(), result);
+ results.setParams(llvm::cast<OpResult>(getResult()), result);
return DiagnosedSilenceableFailure::success();
}
@@ -562,7 +563,7 @@ mlir::test::TestProduceParamWithNumberOfTestOps::apply(
});
return builder.getI32IntegerAttr(count);
}));
- results.setParams(getResult().cast<OpResult>(), result);
+ results.setParams(llvm::cast<OpResult>(getResult()), result);
return DiagnosedSilenceableFailure::success();
}
@@ -570,12 +571,12 @@ DiagnosedSilenceableFailure
mlir::test::TestProduceIntegerParamWithTypeOp::apply(
transform::TransformResults &results, transform::TransformState &state) {
Attribute zero = IntegerAttr::get(getType(), 0);
- results.setParams(getResult().cast<OpResult>(), zero);
+ results.setParams(llvm::cast<OpResult>(getResult()), zero);
return DiagnosedSilenceableFailure::success();
}
LogicalResult mlir::test::TestProduceIntegerParamWithTypeOp::verify() {
- if (!getType().isa<IntegerType>()) {
+ if (!llvm::isa<IntegerType>(getType())) {
return emitOpError() << "expects an integer type";
}
return success();
@@ -618,7 +619,7 @@ void mlir::test::TestProduceNullPayloadOp::getEffects(
DiagnosedSilenceableFailure mlir::test::TestProduceNullPayloadOp::apply(
transform::TransformResults &results, transform::TransformState &state) {
SmallVector<Operation *, 1> null({nullptr});
- results.set(getOut().cast<OpResult>(), null);
+ results.set(llvm::cast<OpResult>(getOut()), null);
return DiagnosedSilenceableFailure::success();
}
@@ -630,7 +631,7 @@ void mlir::test::TestProduceNullParamOp::getEffects(
DiagnosedSilenceableFailure
mlir::test::TestProduceNullParamOp::apply(transform::TransformResults &results,
transform::TransformState &state) {
- results.setParams(getOut().cast<OpResult>(), Attribute());
+ results.setParams(llvm::cast<OpResult>(getOut()), Attribute());
return DiagnosedSilenceableFailure::success();
}
@@ -642,7 +643,7 @@ void mlir::test::TestProduceNullValueOp::getEffects(
DiagnosedSilenceableFailure
mlir::test::TestProduceNullValueOp::apply(transform::TransformResults &results,
transform::TransformState &state) {
- results.setValues(getOut().cast<OpResult>(), Value());
+ results.setValues(llvm::cast<OpResult>(getOut()), Value());
return DiagnosedSilenceableFailure::success();
}
@@ -662,7 +663,7 @@ void mlir::test::TestRequiredMemoryEffectsOp::getEffects(
DiagnosedSilenceableFailure mlir::test::TestRequiredMemoryEffectsOp::apply(
transform::TransformResults &results, transform::TransformState &state) {
- results.set(getOut().cast<OpResult>(), state.getPayloadOps(getIn()));
+ results.set(llvm::cast<OpResult>(getOut()), state.getPayloadOps(getIn()));
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/test/python/lib/PythonTestCAPI.cpp b/mlir/test/python/lib/PythonTestCAPI.cpp
index 280cfa0b1738d..7b443554440bc 100644
--- a/mlir/test/python/lib/PythonTestCAPI.cpp
+++ b/mlir/test/python/lib/PythonTestCAPI.cpp
@@ -16,7 +16,7 @@ MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(PythonTest, python_test,
python_test::PythonTestDialect)
bool mlirAttributeIsAPythonTestTestAttribute(MlirAttribute attr) {
- return unwrap(attr).isa<python_test::TestAttrAttr>();
+ return llvm::isa<python_test::TestAttrAttr>(unwrap(attr));
}
MlirAttribute mlirPythonTestTestAttributeGet(MlirContext context) {
@@ -24,7 +24,7 @@ MlirAttribute mlirPythonTestTestAttributeGet(MlirContext context) {
}
bool mlirTypeIsAPythonTestTestType(MlirType type) {
- return unwrap(type).isa<python_test::TestTypeType>();
+ return llvm::isa<python_test::TestTypeType>(unwrap(type));
}
MlirType mlirPythonTestTestTypeGet(MlirContext context) {
diff --git a/mlir/unittests/TableGen/EnumsGenTest.cpp b/mlir/unittests/TableGen/EnumsGenTest.cpp
index 128696d838b98..3bc02fc2cb3ae 100644
--- a/mlir/unittests/TableGen/EnumsGenTest.cpp
+++ b/mlir/unittests/TableGen/EnumsGenTest.cpp
@@ -175,7 +175,7 @@ TEST(EnumsGenTest, GeneratedIntAttributeClass) {
mlir::Type intType = mlir::IntegerType::get(&ctx, 32);
mlir::Attribute intAttr = mlir::IntegerAttr::get(intType, 5);
- EXPECT_TRUE(intAttr.isa<I32EnumAttr>());
+ EXPECT_TRUE(llvm::isa<I32EnumAttr>(intAttr));
EXPECT_EQ(intAttr, enumAttr);
}
@@ -186,10 +186,10 @@ TEST(EnumsGenTest, GeneratedBitAttributeClass) {
mlir::Attribute intAttr = mlir::IntegerAttr::get(
intType,
static_cast<uint32_t>(BitEnumWithNone::Bit0 | BitEnumWithNone::Bit3));
- EXPECT_TRUE(intAttr.isa<BitEnumWithNoneAttr>());
- EXPECT_TRUE(intAttr.isa<BitEnumWithoutNoneAttr>());
+ EXPECT_TRUE(llvm::isa<BitEnumWithNoneAttr>(intAttr));
+ EXPECT_TRUE(llvm::isa<BitEnumWithoutNoneAttr>(intAttr));
intAttr = mlir::IntegerAttr::get(
intType, static_cast<uint32_t>(BitEnumWithGroup::Bits0To3) | (1u << 6));
- EXPECT_FALSE(intAttr.isa<BitEnumWithGroupAttr>());
+ EXPECT_FALSE(llvm::isa<BitEnumWithGroupAttr>(intAttr));
}
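Every hunk above applies the same mechanical rewrite: a member call such as `x.cast<T>()` or `x.dyn_cast<T>()` becomes a qualified free-function call such as `llvm::cast<T>(x)` or `llvm::dyn_cast<T>(x)`. The following is a minimal, self-contained sketch of that calling style under stated assumptions. It does not use LLVM or MLIR. `Shape`, `Circle`, `Square`, `Visitor`, and the namespace `sketch` are hypothetical. The `sketch::isa`/`cast`/`dyn_cast` templates are simplified, pointer-only stand-ins for LLVM's casting utilities, not their real implementation.

```cpp
#include <cassert>

// Hypothetical toy hierarchy (not MLIR code) using LLVM-style RTTI:
// each subclass exposes a static classof() that inspects a kind tag.
struct Shape {
  enum Kind { K_Circle, K_Square };
  explicit Shape(Kind k) : kind(k) {}
  Kind getKind() const { return kind; }

private:
  Kind kind;
};

struct Circle : Shape {
  Circle() : Shape(K_Circle) {}
  static bool classof(const Shape *s) { return s->getKind() == K_Circle; }
};

struct Square : Shape {
  Square() : Shape(K_Square) {}
  static bool classof(const Shape *s) { return s->getKind() == K_Square; }
};

// Simplified, pointer-only stand-ins for llvm::isa/cast/dyn_cast.
// They live in a namespace so call sites can qualify them, mirroring
// the `llvm::cast<T>(x)` spelling used in the patch.
namespace sketch {
template <typename To, typename From> bool isa(const From *v) {
  return To::classof(v);
}

// Checked downcast: asserts (in debug builds) that the kind matches.
template <typename To, typename From> To *cast(From *v) {
  assert(sketch::isa<To>(v) && "cast<Ty>() argument of incompatible type");
  return static_cast<To *>(v);
}

// Conditional downcast: returns nullptr when the kind does not match.
template <typename To, typename From> To *dyn_cast(From *v) {
  return sketch::isa<To>(v) ? static_cast<To *>(v) : nullptr;
}
} // namespace sketch

// A class that also declares a member named `cast`. Inside its member
// functions, an unqualified `cast<Circle>(s)` would find this member by
// class-scope lookup instead of sketch::cast and would not compile.
struct Visitor {
  int cast() const { return 0; } // hypothetical member that shadows the name

  bool isCircle(Shape *s) const {
    // The qualified call resolves to the free function, not Visitor::cast.
    return sketch::dyn_cast<Circle>(s) != nullptr;
  }
};
```

Qualifying the call (`sketch::` here, `llvm::` in the patch) keeps it unambiguous even where a member named `cast` would otherwise win name lookup, which matches the commit's reason for spelling these call sites as `llvm::cast` rather than a bare `cast`.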