[Mlir-commits] [mlir] 9f0d5cd - [mlir][spirv] Clean up casts. NFC. (#174115)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Dec 31 14:08:08 PST 2025
Author: Jakub Kuderski
Date: 2025-12-31T17:08:03-05:00
New Revision: 9f0d5cd0c2280a3bf9237c6ad8f7fe17be36e7e6
URL: https://github.com/llvm/llvm-project/commit/9f0d5cd0c2280a3bf9237c6ad8f7fe17be36e7e6
DIFF: https://github.com/llvm/llvm-project/commit/9f0d5cd0c2280a3bf9237c6ad8f7fe17be36e7e6.diff
LOG: [mlir][spirv] Clean up casts. NFC. (#174115)
Drop the `llvm::` namespace prefix where it is needlessly used. These prefixes were
introduced by clang-tidy when we migrated from cast member functions to
free functions. Also replace one `patterns.insert<...>` call with the
equivalent `patterns.add<...>`.
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp
mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
mlir/lib/Dialect/SPIRV/IR/DotProductOps.cpp
mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
mlir/lib/Dialect/SPIRV/IR/ImageOps.cpp
mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index e46b576810316..aac5ef17370b2 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -407,7 +407,7 @@ class CooperativeMatrixType
/// Returns the use parameter of the cooperative matrix.
CooperativeMatrixUseKHR getUse() const;
- operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
+ operator ShapedType() const { return cast<ShapedType>(*this); }
ArrayRef<int64_t> getShape() const;
@@ -491,7 +491,7 @@ class TensorArmType
Type getElementType() const;
ArrayRef<int64_t> getShape() const;
bool hasRank() const { return !getShape().empty(); }
- operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
+ operator ShapedType() const { return cast<ShapedType>(*this); }
};
} // namespace spirv
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 56e8fee191432..c101a95685a25 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -422,10 +422,10 @@ struct VectorReductionPattern final : OpConversionPattern<vector::ReductionOp> {
#define INT_AND_FLOAT_CASE(kind, iop, fop) \
case vector::CombiningKind::kind: \
- if (llvm::isa<IntegerType>(resultType)) { \
+ if (isa<IntegerType>(resultType)) { \
result = spirv::iop::create(rewriter, loc, resultType, result, next); \
} else { \
- assert(llvm::isa<FloatType>(resultType)); \
+ assert(isa<FloatType>(resultType)); \
result = spirv::fop::create(rewriter, loc, resultType, result, next); \
} \
break
diff --git a/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp b/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp
index 948d48980f2e8..7029268177128 100644
--- a/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp
@@ -35,9 +35,9 @@ StringRef stringifyTypeName<FloatType>() {
// Verifies an atomic update op.
template <typename AtomicOpTy, typename ExpectedElementType>
static LogicalResult verifyAtomicUpdateOp(Operation *op) {
- auto ptrType = llvm::cast<spirv::PointerType>(op->getOperand(0).getType());
+ auto ptrType = cast<spirv::PointerType>(op->getOperand(0).getType());
auto elementType = ptrType.getPointeeType();
- if (!llvm::isa<ExpectedElementType>(elementType))
+ if (!isa<ExpectedElementType>(elementType))
return op->emitOpError() << "pointer operand must point to an "
<< stringifyTypeName<ExpectedElementType>()
<< " value, found " << elementType;
diff --git a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
index fcf4eb6fbcf60..a5330dc56d48f 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
@@ -87,13 +87,13 @@ LogicalResult BitcastOp::verify() {
if (operandType == resultType) {
return emitError("result type must be different from operand type");
}
- if (llvm::isa<spirv::PointerType>(operandType) &&
- !llvm::isa<spirv::PointerType>(resultType)) {
+ if (isa<spirv::PointerType>(operandType) &&
+ !isa<spirv::PointerType>(resultType)) {
return emitError(
"unhandled bit cast conversion from pointer type to non-pointer type");
}
- if (!llvm::isa<spirv::PointerType>(operandType) &&
- llvm::isa<spirv::PointerType>(resultType)) {
+ if (!isa<spirv::PointerType>(operandType) &&
+ isa<spirv::PointerType>(resultType)) {
return emitError(
"unhandled bit cast conversion from non-pointer type to pointer type");
}
@@ -112,8 +112,8 @@ LogicalResult BitcastOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult ConvertPtrToUOp::verify() {
- auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
- auto resultType = llvm::cast<spirv::ScalarType>(getResult().getType());
+ auto operandType = cast<spirv::PointerType>(getPointer().getType());
+ auto resultType = cast<spirv::ScalarType>(getResult().getType());
if (!resultType || !resultType.isSignlessInteger())
return emitError("result must be a scalar type of unsigned integer");
auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>();
@@ -133,8 +133,8 @@ LogicalResult ConvertPtrToUOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult ConvertUToPtrOp::verify() {
- auto operandType = llvm::cast<spirv::ScalarType>(getOperand().getType());
- auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
+ auto operandType = cast<spirv::ScalarType>(getOperand().getType());
+ auto resultType = cast<spirv::PointerType>(getResult().getType());
if (!operandType || !operandType.isSignlessInteger())
return emitError("result must be a scalar type of unsigned integer");
auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>();
@@ -154,8 +154,8 @@ LogicalResult ConvertUToPtrOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult PtrCastToGenericOp::verify() {
- auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
- auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
+ auto operandType = cast<spirv::PointerType>(getPointer().getType());
+ auto resultType = cast<spirv::PointerType>(getResult().getType());
spirv::StorageClass operandStorage = operandType.getStorageClass();
if (operandStorage != spirv::StorageClass::Workgroup &&
@@ -182,8 +182,8 @@ LogicalResult PtrCastToGenericOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult GenericCastToPtrOp::verify() {
- auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
- auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
+ auto operandType = cast<spirv::PointerType>(getPointer().getType());
+ auto resultType = cast<spirv::PointerType>(getResult().getType());
spirv::StorageClass operandStorage = operandType.getStorageClass();
if (operandStorage != spirv::StorageClass::Generic)
@@ -210,8 +210,8 @@ LogicalResult GenericCastToPtrOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult GenericCastToPtrExplicitOp::verify() {
- auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
- auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
+ auto operandType = cast<spirv::PointerType>(getPointer().getType());
+ auto resultType = cast<spirv::PointerType>(getResult().getType());
spirv::StorageClass operandStorage = operandType.getStorageClass();
if (operandStorage != spirv::StorageClass::Generic)
diff --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
index a846d7e60024c..4d0aedca27d42 100644
--- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
@@ -138,7 +138,7 @@ LogicalResult BranchConditionalOp::verify() {
return emitOpError("must have exactly two branch weights");
}
if (llvm::all_of(*weights, [](Attribute attr) {
- return llvm::cast<IntegerAttr>(attr).getValue().isZero();
+ return cast<IntegerAttr>(attr).getValue().isZero();
}))
return emitOpError("branch weights cannot both be zero");
}
@@ -504,8 +504,8 @@ LogicalResult ReturnValueOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult SelectOp::verify() {
- if (auto conditionTy = llvm::dyn_cast<VectorType>(getCondition().getType())) {
- auto resultVectorTy = llvm::dyn_cast<VectorType>(getResult().getType());
+ if (auto conditionTy = dyn_cast<VectorType>(getCondition().getType())) {
+ auto resultVectorTy = dyn_cast<VectorType>(getResult().getType());
if (!resultVectorTy) {
return emitOpError("result expected to be of vector type when "
"condition is of vector type");
diff --git a/mlir/lib/Dialect/SPIRV/IR/DotProductOps.cpp b/mlir/lib/Dialect/SPIRV/IR/DotProductOps.cpp
index 01ef1bdc42515..dada8925b88e1 100644
--- a/mlir/lib/Dialect/SPIRV/IR/DotProductOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/DotProductOps.cpp
@@ -74,10 +74,9 @@ static LogicalResult verifyIntegerDotProduct(Operation *op) {
Type factorTy = op->getOperand(0).getType();
StringAttr packedVectorFormatAttrName =
IntegerDotProductOpTy::getFormatAttrName(op->getName());
- if (auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
- auto packedVectorFormat =
- llvm::dyn_cast_or_null<spirv::PackedVectorFormatAttr>(
- op->getAttr(packedVectorFormatAttrName));
+ if (auto intTy = dyn_cast<IntegerType>(factorTy)) {
+ auto packedVectorFormat = dyn_cast_or_null<spirv::PackedVectorFormatAttr>(
+ op->getAttr(packedVectorFormatAttrName));
if (!packedVectorFormat)
return op->emitOpError("requires Packed Vector Format attribute for "
"integer vector operands");
@@ -135,8 +134,8 @@ getIntegerDotProductCapabilities(Operation *op) {
Type factorTy = op->getOperand(0).getType();
StringAttr packedVectorFormatAttrName =
IntegerDotProductOpTy::getFormatAttrName(op->getName());
- if (auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
- auto formatAttr = llvm::cast<spirv::PackedVectorFormatAttr>(
+ if (auto intTy = dyn_cast<IntegerType>(factorTy)) {
+ auto formatAttr = cast<spirv::PackedVectorFormatAttr>(
op->getAttr(packedVectorFormatAttrName));
if (formatAttr.getValue() ==
spirv::PackedVectorFormat::PackedVectorFormat4x8Bit)
@@ -145,7 +144,7 @@ getIntegerDotProductCapabilities(Operation *op) {
return capabilities;
}
- auto vecTy = llvm::cast<VectorType>(factorTy);
+ auto vecTy = cast<VectorType>(factorTy);
if (vecTy.getElementTypeBitWidth() == 8) {
capabilities.push_back(dotProductInput4x8BitCap);
return capabilities;
diff --git a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
index 461d037134dae..a1bb7f89e9183 100644
--- a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
@@ -65,7 +65,7 @@ LogicalResult GroupBroadcastOp::verify() {
if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
- if (auto localIdTy = llvm::dyn_cast<VectorType>(getLocalid().getType()))
+ if (auto localIdTy = dyn_cast<VectorType>(getLocalid().getType()))
if (localIdTy.getNumElements() != 2 && localIdTy.getNumElements() != 3)
return emitOpError("localid is a vector and can be with only "
" 2 or 3 components, actual number is ")
diff --git a/mlir/lib/Dialect/SPIRV/IR/ImageOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ImageOps.cpp
index 661f3d5d9b81d..6ea07330b70cb 100644
--- a/mlir/lib/Dialect/SPIRV/IR/ImageOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/ImageOps.cpp
@@ -89,8 +89,8 @@ static LogicalResult verifyImageOperands(Operation *imageOp,
"floating-point type scalar");
auto samplingOp = cast<spirv::SamplingOpInterface>(imageOp);
- auto sampledImageType = llvm::cast<spirv::SampledImageType>(
- samplingOp.getSampledImage().getType());
+ auto sampledImageType =
+ cast<spirv::SampledImageType>(samplingOp.getSampledImage().getType());
imageType = cast<spirv::ImageType>(sampledImageType.getImageType());
} else {
if (!isa<mlir::IntegerType>(operands[index].getType()))
@@ -243,8 +243,7 @@ LogicalResult spirv::ImageWriteOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult spirv::ImageQuerySizeOp::verify() {
- spirv::ImageType imageType =
- llvm::cast<spirv::ImageType>(getImage().getType());
+ spirv::ImageType imageType = cast<spirv::ImageType>(getImage().getType());
Type resultType = getResult().getType();
spirv::Dim dim = imageType.getDim();
@@ -292,7 +291,7 @@ LogicalResult spirv::ImageQuerySizeOp::verify() {
componentNumber += 1;
unsigned resultComponentNumber = 1;
- if (auto resultVectorType = llvm::dyn_cast<VectorType>(resultType))
+ if (auto resultVectorType = dyn_cast<VectorType>(resultType))
resultComponentNumber = resultVectorType.getNumElements();
if (componentNumber != resultComponentNumber)
diff --git a/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
index 5ae27e5d82bd7..e3187d3dc1901 100644
--- a/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
@@ -166,7 +166,7 @@ static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr,
// TODO: Check that the value type satisfies restrictions of
// SPIR-V OpLoad/OpStore operations
if (val.getType() !=
- llvm::cast<spirv::PointerType>(ptr.getType()).getPointeeType()) {
+ cast<spirv::PointerType>(ptr.getType()).getPointeeType()) {
return op.emitOpError("mismatch in result type and pointer type");
}
return success();
@@ -190,7 +190,7 @@ static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
return success();
}
- auto memAccess = llvm::cast<spirv::MemoryAccessAttr>(memAccessAttr);
+ auto memAccess = cast<spirv::MemoryAccessAttr>(memAccessAttr);
if (!memAccess) {
return memoryOp.emitOpError("invalid memory access specifier: ")
@@ -234,7 +234,7 @@ static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
return success();
}
- auto memAccess = llvm::cast<spirv::MemoryAccessAttr>(memAccessAttr);
+ auto memAccess = cast<spirv::MemoryAccessAttr>(memAccessAttr);
if (!memAccess) {
return memoryOp.emitOpError("invalid memory access specifier: ")
@@ -261,7 +261,7 @@ static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
//===----------------------------------------------------------------------===//
static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) {
- auto ptrType = llvm::dyn_cast<spirv::PointerType>(type);
+ auto ptrType = dyn_cast<spirv::PointerType>(type);
if (!ptrType) {
emitError(baseLoc, "'spirv.AccessChain' op expected a pointer "
"to composite type, but provided ")
@@ -274,7 +274,7 @@ static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) {
int32_t index = 0;
for (auto indexSSA : indices) {
- auto cType = llvm::dyn_cast<spirv::CompositeType>(resultType);
+ auto cType = dyn_cast<spirv::CompositeType>(resultType);
if (!cType) {
emitError(
baseLoc,
@@ -283,7 +283,7 @@ static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) {
return nullptr;
}
index = 0;
- if (llvm::isa<spirv::StructType>(resultType)) {
+ if (isa<spirv::StructType>(resultType)) {
Operation *op = indexSSA.getDefiningOp();
if (!op) {
emitError(baseLoc, "'spirv.AccessChain' op index must be an "
@@ -334,7 +334,7 @@ static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices) {
return failure();
auto providedResultType =
- llvm::dyn_cast<spirv::PointerType>(accessChainOp.getType());
+ dyn_cast<spirv::PointerType>(accessChainOp.getType());
if (!providedResultType)
return accessChainOp.emitOpError(
"result type must be a pointer, but provided")
@@ -357,7 +357,7 @@ LogicalResult AccessChainOp::verify() {
void LoadOp::build(OpBuilder &builder, OperationState &state, Value basePtr,
MemoryAccessAttr memoryAccess, IntegerAttr alignment) {
- auto ptrType = llvm::cast<spirv::PointerType>(basePtr.getType());
+ auto ptrType = cast<spirv::PointerType>(basePtr.getType());
build(builder, state, ptrType.getPointeeType(), basePtr, memoryAccess,
alignment);
}
@@ -386,7 +386,7 @@ ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) {
void LoadOp::print(OpAsmPrinter &printer) {
SmallVector<StringRef, 4> elidedAttrs;
StringRef sc = stringifyStorageClass(
- llvm::cast<spirv::PointerType>(getPtr().getType()).getStorageClass());
+ cast<spirv::PointerType>(getPtr().getType()).getStorageClass());
printer << " \"" << sc << "\" " << getPtr();
printMemoryAccessAttribute(*this, printer, elidedAttrs);
@@ -433,7 +433,7 @@ ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) {
void StoreOp::print(OpAsmPrinter &printer) {
SmallVector<StringRef, 4> elidedAttrs;
StringRef sc = stringifyStorageClass(
- llvm::cast<spirv::PointerType>(getPtr().getType()).getStorageClass());
+ cast<spirv::PointerType>(getPtr().getType()).getStorageClass());
printer << " \"" << sc << "\" " << getPtr() << ", " << getValue();
printMemoryAccessAttribute(*this, printer, elidedAttrs);
@@ -458,11 +458,11 @@ void CopyMemoryOp::print(OpAsmPrinter &printer) {
printer << ' ';
StringRef targetStorageClass = stringifyStorageClass(
- llvm::cast<spirv::PointerType>(getTarget().getType()).getStorageClass());
+ cast<spirv::PointerType>(getTarget().getType()).getStorageClass());
printer << " \"" << targetStorageClass << "\" " << getTarget() << ", ";
StringRef sourceStorageClass = stringifyStorageClass(
- llvm::cast<spirv::PointerType>(getSource().getType()).getStorageClass());
+ cast<spirv::PointerType>(getSource().getType()).getStorageClass());
printer << " \"" << sourceStorageClass << "\" " << getSource();
SmallVector<StringRef, 4> elidedAttrs;
@@ -474,7 +474,7 @@ void CopyMemoryOp::print(OpAsmPrinter &printer) {
printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
Type pointeeType =
- llvm::cast<spirv::PointerType>(getTarget().getType()).getPointeeType();
+ cast<spirv::PointerType>(getTarget().getType()).getPointeeType();
printer << " : " << pointeeType;
}
@@ -521,10 +521,10 @@ ParseResult CopyMemoryOp::parse(OpAsmParser &parser, OperationState &result) {
LogicalResult CopyMemoryOp::verify() {
Type targetType =
- llvm::cast<spirv::PointerType>(getTarget().getType()).getPointeeType();
+ cast<spirv::PointerType>(getTarget().getType()).getPointeeType();
Type sourceType =
- llvm::cast<spirv::PointerType>(getSource().getType()).getPointeeType();
+ cast<spirv::PointerType>(getSource().getType()).getPointeeType();
if (targetType != sourceType)
return emitOpError("both operands must be pointers to the same type");
@@ -600,7 +600,7 @@ ParseResult VariableOp::parse(OpAsmParser &parser, OperationState &result) {
if (parser.parseType(type))
return failure();
- auto ptrType = llvm::dyn_cast<spirv::PointerType>(type);
+ auto ptrType = dyn_cast<spirv::PointerType>(type);
if (!ptrType)
return parser.emitError(loc, "expected spirv.ptr type");
result.addTypes(ptrType);
@@ -640,7 +640,7 @@ LogicalResult VariableOp::verify() {
"spirv.GlobalVariable for module-level variables.");
}
- auto pointerType = llvm::cast<spirv::PointerType>(getPointer().getType());
+ auto pointerType = cast<spirv::PointerType>(getPointer().getType());
if (getStorageClass() != pointerType.getStorageClass())
return emitOpError(
"storage class must match result pointer's storage class");
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp
index 2ba6106896c1f..f1940091ca238 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp
@@ -146,20 +146,18 @@ StringRef spirv::InterfaceVarABIAttr::getKindName() {
}
uint32_t spirv::InterfaceVarABIAttr::getBinding() {
- return llvm::cast<IntegerAttr>(getImpl()->binding).getInt();
+ return cast<IntegerAttr>(getImpl()->binding).getInt();
}
uint32_t spirv::InterfaceVarABIAttr::getDescriptorSet() {
- return llvm::cast<IntegerAttr>(getImpl()->descriptorSet).getInt();
+ return cast<IntegerAttr>(getImpl()->descriptorSet).getInt();
}
std::optional<spirv::StorageClass>
spirv::InterfaceVarABIAttr::getStorageClass() {
if (getImpl()->storageClass)
return static_cast<spirv::StorageClass>(
- llvm::cast<IntegerAttr>(getImpl()->storageClass)
- .getValue()
- .getZExtValue());
+ cast<IntegerAttr>(getImpl()->storageClass).getValue().getZExtValue());
return std::nullopt;
}
@@ -173,7 +171,7 @@ LogicalResult spirv::InterfaceVarABIAttr::verifyInvariants(
return emitError() << "expected 32-bit integer for binding";
if (storageClass) {
- if (auto storageClassAttr = llvm::cast<IntegerAttr>(storageClass)) {
+ if (auto storageClassAttr = cast<IntegerAttr>(storageClass)) {
auto storageClassValue =
spirv::symbolizeStorageClass(storageClassAttr.getInt());
if (!storageClassValue)
@@ -222,14 +220,14 @@ StringRef spirv::VerCapExtAttr::getKindName() { return "vce"; }
spirv::Version spirv::VerCapExtAttr::getVersion() {
return static_cast<spirv::Version>(
- llvm::cast<IntegerAttr>(getImpl()->version).getValue().getZExtValue());
+ cast<IntegerAttr>(getImpl()->version).getValue().getZExtValue());
}
spirv::VerCapExtAttr::ext_iterator::ext_iterator(ArrayAttr::iterator it)
: llvm::mapped_iterator<ArrayAttr::iterator,
spirv::Extension (*)(Attribute)>(
it, [](Attribute attr) {
- return *symbolizeExtension(llvm::cast<StringAttr>(attr).getValue());
+ return *symbolizeExtension(cast<StringAttr>(attr).getValue());
}) {}
spirv::VerCapExtAttr::ext_range spirv::VerCapExtAttr::getExtensions() {
@@ -238,7 +236,7 @@ spirv::VerCapExtAttr::ext_range spirv::VerCapExtAttr::getExtensions() {
}
ArrayAttr spirv::VerCapExtAttr::getExtensionsAttr() {
- return llvm::cast<ArrayAttr>(getImpl()->extensions);
+ return cast<ArrayAttr>(getImpl()->extensions);
}
spirv::VerCapExtAttr::cap_iterator::cap_iterator(ArrayAttr::iterator it)
@@ -246,7 +244,7 @@ spirv::VerCapExtAttr::cap_iterator::cap_iterator(ArrayAttr::iterator it)
spirv::Capability (*)(Attribute)>(
it, [](Attribute attr) {
return *symbolizeCapability(
- llvm::cast<IntegerAttr>(attr).getValue().getZExtValue());
+ cast<IntegerAttr>(attr).getValue().getZExtValue());
}) {}
spirv::VerCapExtAttr::cap_range spirv::VerCapExtAttr::getCapabilities() {
@@ -255,7 +253,7 @@ spirv::VerCapExtAttr::cap_range spirv::VerCapExtAttr::getCapabilities() {
}
ArrayAttr spirv::VerCapExtAttr::getCapabilitiesAttr() {
- return llvm::cast<ArrayAttr>(getImpl()->capabilities);
+ return cast<ArrayAttr>(getImpl()->capabilities);
}
LogicalResult spirv::VerCapExtAttr::verifyInvariants(
@@ -265,7 +263,7 @@ LogicalResult spirv::VerCapExtAttr::verifyInvariants(
return emitError() << "expected 32-bit integer for version";
if (!llvm::all_of(capabilities.getValue(), [](Attribute attr) {
- if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
+ if (auto intAttr = dyn_cast<IntegerAttr>(attr))
if (spirv::symbolizeCapability(intAttr.getValue().getZExtValue()))
return true;
return false;
@@ -273,7 +271,7 @@ LogicalResult spirv::VerCapExtAttr::verifyInvariants(
return emitError() << "unknown capability in capability list";
if (!llvm::all_of(extensions.getValue(), [](Attribute attr) {
- if (auto strAttr = llvm::dyn_cast<StringAttr>(attr))
+ if (auto strAttr = dyn_cast<StringAttr>(attr))
if (spirv::symbolizeExtension(strAttr.getValue()))
return true;
return false;
@@ -299,7 +297,7 @@ spirv::TargetEnvAttr spirv::TargetEnvAttr::get(
StringRef spirv::TargetEnvAttr::getKindName() { return "target_env"; }
spirv::VerCapExtAttr spirv::TargetEnvAttr::getTripleAttr() const {
- return llvm::cast<spirv::VerCapExtAttr>(getImpl()->triple);
+ return cast<spirv::VerCapExtAttr>(getImpl()->triple);
}
spirv::Version spirv::TargetEnvAttr::getVersion() const {
@@ -339,7 +337,7 @@ uint32_t spirv::TargetEnvAttr::getDeviceID() const {
}
spirv::ResourceLimitsAttr spirv::TargetEnvAttr::getResourceLimits() const {
- return llvm::cast<spirv::ResourceLimitsAttr>(getImpl()->limits);
+ return cast<spirv::ResourceLimitsAttr>(getImpl()->limits);
}
//===----------------------------------------------------------------------===//
@@ -668,11 +666,11 @@ void SPIRVDialect::printAttribute(Attribute attr,
if (succeeded(generatedAttributePrinter(attr, printer)))
return;
- if (auto targetEnv = llvm::dyn_cast<TargetEnvAttr>(attr))
+ if (auto targetEnv = dyn_cast<TargetEnvAttr>(attr))
print(targetEnv, printer);
- else if (auto vceAttr = llvm::dyn_cast<VerCapExtAttr>(attr))
+ else if (auto vceAttr = dyn_cast<VerCapExtAttr>(attr))
print(vceAttr, printer);
- else if (auto interfaceVarABIAttr = llvm::dyn_cast<InterfaceVarABIAttr>(attr))
+ else if (auto interfaceVarABIAttr = dyn_cast<InterfaceVarABIAttr>(attr))
print(interfaceVarABIAttr, printer);
else
llvm_unreachable("unhandled SPIR-V attribute kind");
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index ccc85368c78a4..9ab3bdc6ab102 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -35,9 +35,9 @@ static std::optional<bool> getScalarOrSplatBoolAttr(Attribute attr) {
if (!attr)
return std::nullopt;
- if (auto boolAttr = llvm::dyn_cast<BoolAttr>(attr))
+ if (auto boolAttr = dyn_cast<BoolAttr>(attr))
return boolAttr.getValue();
- if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(attr))
+ if (auto splatAttr = dyn_cast<SplatElementsAttr>(attr))
if (splatAttr.getElementType().isInteger(1))
return splatAttr.getSplatValue<bool>();
return std::nullopt;
@@ -54,12 +54,12 @@ static Attribute extractCompositeElement(Attribute composite,
if (indices.empty())
return composite;
- if (auto vector = llvm::dyn_cast<ElementsAttr>(composite)) {
+ if (auto vector = dyn_cast<ElementsAttr>(composite)) {
assert(indices.size() == 1 && "must have exactly one index for a vector");
return vector.getValues<Attribute>()[indices[0]];
}
- if (auto array = llvm::dyn_cast<ArrayAttr>(composite)) {
+ if (auto array = dyn_cast<ArrayAttr>(composite)) {
assert(!indices.empty() && "must have at least one index for an array");
return extractCompositeElement(array.getValue()[indices[0]],
indices.drop_front());
@@ -370,7 +370,7 @@ struct UModSimplification final : OpRewritePattern<spirv::UModOp> {
void spirv::UModOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
- patterns.insert<UModSimplification>(context);
+ patterns.add<UModSimplification>(context);
}
//===----------------------------------------------------------------------===//
@@ -412,10 +412,10 @@ OpFoldResult spirv::CompositeExtractOp::fold(FoldAdaptor adaptor) {
if (auto constructOp =
compositeOp.getDefiningOp<spirv::CompositeConstructOp>()) {
- auto type = llvm::cast<spirv::CompositeType>(constructOp.getType());
+ auto type = cast<spirv::CompositeType>(constructOp.getType());
if (getIndices().size() == 1 &&
constructOp.getConstituents().size() == type.getNumElements()) {
- auto i = llvm::cast<IntegerAttr>(*getIndices().begin());
+ auto i = cast<IntegerAttr>(*getIndices().begin());
if (i.getValue().getSExtValue() <
static_cast<int64_t>(constructOp.getConstituents().size()))
return constructOp.getConstituents()[i.getValue().getSExtValue()];
@@ -423,7 +423,7 @@ OpFoldResult spirv::CompositeExtractOp::fold(FoldAdaptor adaptor) {
}
auto indexVector = llvm::map_to_vector(getIndices(), [](Attribute attr) {
- return static_cast<unsigned>(llvm::cast<IntegerAttr>(attr).getInt());
+ return static_cast<unsigned>(cast<IntegerAttr>(attr).getInt());
});
return extractCompositeElement(adaptor.getComposite(), indexVector);
}
@@ -1379,7 +1379,7 @@ LogicalResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
// Starting with version 1.4, Result Type can additionally be a composite type
// other than a vector."
bool isScalarOrVector =
- llvm::cast<spirv::SPIRVType>(trueBrStoreOp.getValue().getType())
+ cast<spirv::SPIRVType>(trueBrStoreOp.getValue().getType())
.isScalarOrVector();
// Check that each `spirv.Store` uses the same pointer, memory access
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 24c33f9ae1b90..22b57d6c0821a 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -172,16 +172,16 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
return type;
// Check other allowed types
- if (auto t = llvm::dyn_cast<FloatType>(type)) {
+ if (auto t = dyn_cast<FloatType>(type)) {
// TODO: All float types are allowed for now, but this should be fixed.
- } else if (auto t = llvm::dyn_cast<IntegerType>(type)) {
+ } else if (auto t = dyn_cast<IntegerType>(type)) {
if (!ScalarType::isValid(t)) {
parser.emitError(typeLoc,
"only 1/8/16/32/64-bit integer type allowed but found ")
<< type;
return Type();
}
- } else if (auto t = llvm::dyn_cast<VectorType>(type)) {
+ } else if (auto t = dyn_cast<VectorType>(type)) {
if (t.getRank() != 1) {
parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
return Type();
@@ -215,7 +215,7 @@ static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect,
if (parser.parseType(type))
return Type();
- if (auto t = llvm::dyn_cast<VectorType>(type)) {
+ if (auto t = dyn_cast<VectorType>(type)) {
if (t.getRank() != 1) {
parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
return Type();
@@ -228,7 +228,7 @@ static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect,
return Type();
}
- if (!llvm::isa<FloatType>(t.getElementType())) {
+ if (!isa<FloatType>(t.getElementType())) {
parser.emitError(typeLoc, "matrix columns' elements must be of "
"Float type, got ")
<< t.getElementType();
@@ -1016,12 +1016,12 @@ LogicalResult SPIRVDialect::verifyOperationAttribute(Operation *op,
Attribute attr = attribute.getValue();
if (symbol == spirv::getEntryPointABIAttrName()) {
- if (!llvm::isa<spirv::EntryPointABIAttr>(attr)) {
+ if (!isa<spirv::EntryPointABIAttr>(attr)) {
return op->emitError("'")
<< symbol << "' attribute must be an entry point ABI attribute";
}
} else if (symbol == spirv::getTargetEnvAttrName()) {
- if (!llvm::isa<spirv::TargetEnvAttr>(attr))
+ if (!isa<spirv::TargetEnvAttr>(attr))
return op->emitError("'") << symbol << "' must be a spirv::TargetEnvAttr";
} else {
return op->emitError("found unsupported '")
@@ -1039,7 +1039,7 @@ static LogicalResult verifyRegionAttribute(Location loc, Type valueType,
Attribute attr = attribute.getValue();
if (symbol == spirv::getInterfaceVarABIAttrName()) {
- auto varABIAttr = llvm::dyn_cast<spirv::InterfaceVarABIAttr>(attr);
+ auto varABIAttr = dyn_cast<spirv::InterfaceVarABIAttr>(attr);
if (!varABIAttr)
return emitError(loc, "'")
<< symbol << "' must be a spirv::InterfaceVarABIAttr";
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
index 8575487ff52cc..ba69fa75cf2b8 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
@@ -53,7 +53,7 @@ static bool isDirectInModuleLikeOp(Operation *op) {
static Type getUnaryOpResultType(Type operandType) {
Builder builder(operandType.getContext());
Type resultType = builder.getIntegerType(1);
- if (auto vecType = llvm::dyn_cast<VectorType>(operandType))
+ if (auto vecType = dyn_cast<VectorType>(operandType))
return VectorType::get(vecType.getNumElements(), resultType);
return resultType;
}
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 938952ed273cd..1962538d804a8 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -52,7 +52,7 @@ LogicalResult spirv::extractValueFromConstOp(Operation *op, int32_t &value) {
return failure();
}
auto valueAttr = constOp.getValue();
- auto integerValueAttr = llvm::dyn_cast<IntegerAttr>(valueAttr);
+ auto integerValueAttr = dyn_cast<IntegerAttr>(valueAttr);
if (!integerValueAttr) {
return failure();
}
@@ -129,7 +129,7 @@ static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColon() || parser.parseType(type))
return failure();
- auto fnType = llvm::dyn_cast<FunctionType>(type);
+ auto fnType = dyn_cast<FunctionType>(type);
if (!fnType) {
parser.emitError(loc, "expected function type");
return failure();
@@ -169,11 +169,10 @@ template <typename BlockReadWriteOpTy>
static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op,
Value ptr, Value val) {
auto valType = val.getType();
- if (auto valVecTy = llvm::dyn_cast<VectorType>(valType))
+ if (auto valVecTy = dyn_cast<VectorType>(valType))
valType = valVecTy.getElementType();
- if (valType !=
- llvm::cast<spirv::PointerType>(ptr.getType()).getPointeeType()) {
+ if (valType != cast<spirv::PointerType>(ptr.getType()).getPointeeType()) {
return op.emitOpError("mismatch in result type and pointer type");
}
return success();
@@ -191,7 +190,7 @@ getElementType(Type type, ArrayRef<int32_t> indices,
}
for (auto index : indices) {
- if (auto cType = llvm::dyn_cast<spirv::CompositeType>(type)) {
+ if (auto cType = dyn_cast<spirv::CompositeType>(type)) {
if (cType.hasCompileTimeKnownNumElements() &&
(index < 0 ||
static_cast<uint64_t>(index) >= cType.getNumElements())) {
@@ -211,7 +210,7 @@ getElementType(Type type, ArrayRef<int32_t> indices,
static Type
getElementType(Type type, Attribute indices,
function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
- auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(indices);
+ auto indicesArrayAttr = dyn_cast<ArrayAttr>(indices);
if (!indicesArrayAttr) {
emitErrorFn("expected a 32-bit integer array attribute for 'indices'");
return nullptr;
@@ -223,7 +222,7 @@ getElementType(Type type, Attribute indices,
SmallVector<int32_t, 2> indexVals;
for (auto indexAttr : indicesArrayAttr) {
- auto indexIntAttr = llvm::dyn_cast<IntegerAttr>(indexAttr);
+ auto indexIntAttr = dyn_cast<IntegerAttr>(indexAttr);
if (!indexIntAttr) {
emitErrorFn("expected an 32-bit integer for index, but found '")
<< indexAttr << "'";
@@ -251,7 +250,7 @@ static Type getElementType(Type type, Attribute indices, OpAsmParser &parser,
template <typename ExtendedBinaryOp>
static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op) {
- auto resultType = llvm::cast<spirv::StructType>(op.getType());
+ auto resultType = cast<spirv::StructType>(op.getType());
if (resultType.getNumElements() != 2)
return op.emitOpError("expected result struct type containing two members");
@@ -276,7 +275,7 @@ static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser,
if (parser.parseType(resultType))
return failure();
- auto structType = llvm::dyn_cast<spirv::StructType>(resultType);
+ auto structType = dyn_cast<spirv::StructType>(resultType);
if (!structType || structType.getNumElements() != 2)
return parser.emitError(loc, "expected spirv.struct type with two members");
@@ -361,7 +360,7 @@ LogicalResult spirv::CompositeConstructOp::verify() {
}
// Case 2./3./4. -- number of constituents matches the number of elements.
- auto cType = llvm::cast<spirv::CompositeType>(getType());
+ auto cType = cast<spirv::CompositeType>(getType());
if (constituents.size() == cType.getNumElements()) {
for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
if (constituents[index].getType() != cType.getElementType(index)) {
@@ -374,7 +373,7 @@ LogicalResult spirv::CompositeConstructOp::verify() {
}
// Case 4. -- check that all constituents add up tp the expected vector type.
- auto resultType = llvm::dyn_cast<VectorType>(cType);
+ auto resultType = dyn_cast<VectorType>(cType);
if (!resultType)
return emitOpError(
"expected to return a vector or cooperative matrix when the number of "
@@ -382,14 +381,14 @@ LogicalResult spirv::CompositeConstructOp::verify() {
SmallVector<unsigned> sizes;
for (Value component : constituents) {
- if (!llvm::isa<VectorType>(component.getType()) &&
+ if (!isa<VectorType>(component.getType()) &&
!component.getType().isIntOrFloat())
return emitOpError("operand type mismatch: expected operand to have "
"a scalar or vector type, but provided ")
<< component.getType();
Type elementType = component.getType();
- if (auto vectorType = llvm::dyn_cast<VectorType>(component.getType())) {
+ if (auto vectorType = dyn_cast<VectorType>(component.getType())) {
sizes.push_back(vectorType.getNumElements());
elementType = vectorType.getElementType();
} else {
@@ -455,7 +454,7 @@ void spirv::CompositeExtractOp::print(OpAsmPrinter &printer) {
}
LogicalResult spirv::CompositeExtractOp::verify() {
- auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(getIndices());
+ auto indicesArrayAttr = dyn_cast<ArrayAttr>(getIndices());
auto resultType =
getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
if (!resultType)
@@ -500,7 +499,7 @@ ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser,
}
LogicalResult spirv::CompositeInsertOp::verify() {
- auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(getIndices());
+ auto indicesArrayAttr = dyn_cast<ArrayAttr>(getIndices());
auto objectType =
getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
if (!objectType)
@@ -538,14 +537,14 @@ ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
return failure();
Type type = NoneType::get(parser.getContext());
- if (auto typedAttr = llvm::dyn_cast<TypedAttr>(value))
+ if (auto typedAttr = dyn_cast<TypedAttr>(value))
type = typedAttr.getType();
- if (llvm::isa<NoneType, TensorType>(type)) {
+ if (isa<NoneType, TensorType>(type)) {
if (parser.parseColonType(type))
return failure();
}
- if (llvm::isa<TensorArmType>(type)) {
+ if (isa<TensorArmType>(type)) {
if (parser.parseOptionalColon().succeeded())
if (parser.parseType(type))
return failure();
@@ -556,7 +555,7 @@ ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
void spirv::ConstantOp::print(OpAsmPrinter &printer) {
printer << ' ' << getValue();
- if (llvm::isa<spirv::ArrayType>(getType()))
+ if (isa<spirv::ArrayType>(getType()))
printer << " : " << getType();
}
@@ -569,19 +568,19 @@ static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
"matrix constant, but found ")
<< denseAttr;
}
- if (llvm::isa<IntegerAttr, FloatAttr>(value)) {
- auto valueType = llvm::cast<TypedAttr>(value).getType();
+ if (isa<IntegerAttr, FloatAttr>(value)) {
+ auto valueType = cast<TypedAttr>(value).getType();
if (valueType != opType)
return op.emitOpError("result type (")
<< opType << ") does not match value type (" << valueType << ")";
return success();
}
- if (llvm::isa<DenseIntOrFPElementsAttr, SparseElementsAttr>(value)) {
- auto valueType = llvm::cast<TypedAttr>(value).getType();
+ if (isa<DenseIntOrFPElementsAttr, SparseElementsAttr>(value)) {
+ auto valueType = cast<TypedAttr>(value).getType();
if (valueType == opType)
return success();
- auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType);
- auto shapedType = llvm::dyn_cast<ShapedType>(valueType);
+ auto arrayType = dyn_cast<spirv::ArrayType>(opType);
+ auto shapedType = dyn_cast<ShapedType>(valueType);
if (!arrayType)
return op.emitOpError("result or element type (")
<< opType << ") does not match value type (" << valueType
@@ -589,7 +588,7 @@ static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
int numElements = arrayType.getNumElements();
auto opElemType = arrayType.getElementType();
- while (auto t = llvm::dyn_cast<spirv::ArrayType>(opElemType)) {
+ while (auto t = dyn_cast<spirv::ArrayType>(opElemType)) {
numElements *= t.getNumElements();
opElemType = t.getElementType();
}
@@ -610,8 +609,8 @@ static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
}
return success();
}
- if (auto arrayAttr = llvm::dyn_cast<ArrayAttr>(value)) {
- auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType);
+ if (auto arrayAttr = dyn_cast<ArrayAttr>(value)) {
+ auto arrayType = dyn_cast<spirv::ArrayType>(opType);
if (!arrayType)
return op.emitOpError(
"must have spirv.array result type for array value");
@@ -635,12 +634,12 @@ LogicalResult spirv::ConstantOp::verify() {
bool spirv::ConstantOp::isBuildableWith(Type type) {
// Must be valid SPIR-V type first.
- if (!llvm::isa<spirv::SPIRVType>(type))
+ if (!isa<spirv::SPIRVType>(type))
return false;
if (isa<SPIRVDialect>(type.getDialect())) {
// TODO: support constant struct
- return llvm::isa<spirv::ArrayType>(type);
+ return isa<spirv::ArrayType>(type);
}
return true;
@@ -648,7 +647,7 @@ bool spirv::ConstantOp::isBuildableWith(Type type) {
spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
OpBuilder &builder) {
- if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
+ if (auto intType = dyn_cast<IntegerType>(type)) {
unsigned width = intType.getWidth();
if (width == 1)
return spirv::ConstantOp::create(builder, loc, type,
@@ -656,19 +655,19 @@ spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
return spirv::ConstantOp::create(
builder, loc, type, builder.getIntegerAttr(type, APInt(width, 0)));
}
- if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
+ if (auto floatType = dyn_cast<FloatType>(type)) {
return spirv::ConstantOp::create(builder, loc, type,
builder.getFloatAttr(floatType, 0.0));
}
- if (auto vectorType = llvm::dyn_cast<VectorType>(type)) {
+ if (auto vectorType = dyn_cast<VectorType>(type)) {
Type elemType = vectorType.getElementType();
- if (llvm::isa<IntegerType>(elemType)) {
+ if (isa<IntegerType>(elemType)) {
return spirv::ConstantOp::create(
builder, loc, type,
DenseElementsAttr::get(vectorType,
IntegerAttr::get(elemType, 0).getValue()));
}
- if (llvm::isa<FloatType>(elemType)) {
+ if (isa<FloatType>(elemType)) {
return spirv::ConstantOp::create(
builder, loc, type,
DenseFPElementsAttr::get(vectorType,
@@ -681,7 +680,7 @@ spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
OpBuilder &builder) {
- if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
+ if (auto intType = dyn_cast<IntegerType>(type)) {
unsigned width = intType.getWidth();
if (width == 1)
return spirv::ConstantOp::create(builder, loc, type,
@@ -689,19 +688,19 @@ spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
return spirv::ConstantOp::create(
builder, loc, type, builder.getIntegerAttr(type, APInt(width, 1)));
}
- if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
+ if (auto floatType = dyn_cast<FloatType>(type)) {
return spirv::ConstantOp::create(builder, loc, type,
builder.getFloatAttr(floatType, 1.0));
}
- if (auto vectorType = llvm::dyn_cast<VectorType>(type)) {
+ if (auto vectorType = dyn_cast<VectorType>(type)) {
Type elemType = vectorType.getElementType();
- if (llvm::isa<IntegerType>(elemType)) {
+ if (isa<IntegerType>(elemType)) {
return spirv::ConstantOp::create(
builder, loc, type,
DenseElementsAttr::get(vectorType,
IntegerAttr::get(elemType, 1).getValue()));
}
- if (llvm::isa<FloatType>(elemType)) {
+ if (isa<FloatType>(elemType)) {
return spirv::ConstantOp::create(
builder, loc, type,
DenseFPElementsAttr::get(vectorType,
@@ -720,9 +719,9 @@ void mlir::spirv::ConstantOp::getAsmResultNames(
llvm::raw_svector_ostream specialName(specialNameBuffer);
specialName << "cst";
- IntegerType intTy = llvm::dyn_cast<IntegerType>(type);
+ IntegerType intTy = dyn_cast<IntegerType>(type);
- if (IntegerAttr intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
+ if (IntegerAttr intCst = dyn_cast<IntegerAttr>(getValue())) {
assert(intTy);
if (intTy.getWidth() == 1) {
@@ -738,18 +737,17 @@ void mlir::spirv::ConstantOp::getAsmResultNames(
}
}
- if (intTy || llvm::isa<FloatType>(type)) {
+ if (intTy || isa<FloatType>(type)) {
specialName << '_' << type;
}
- if (auto vecType = llvm::dyn_cast<VectorType>(type)) {
+ if (auto vecType = dyn_cast<VectorType>(type)) {
specialName << "_vec_";
specialName << vecType.getDimSize(0);
Type elementType = vecType.getElementType();
- if (llvm::isa<IntegerType>(elementType) ||
- llvm::isa<FloatType>(elementType)) {
+ if (isa<IntegerType>(elementType) || isa<FloatType>(elementType)) {
specialName << "x" << elementType;
}
}
@@ -903,7 +901,7 @@ ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser,
if (parser.parseAttribute(value, i32Type, "value", attr)) {
return failure();
}
- values.push_back(llvm::cast<IntegerAttr>(value).getInt());
+ values.push_back(cast<IntegerAttr>(value).getInt());
}
StringRef valuesAttrName =
spirv::ExecutionModeOp::getValuesAttrName(result.name);
@@ -1005,7 +1003,7 @@ LogicalResult spirv::FuncOp::verifyType() {
auto hasDecorationAttr = [&](spirv::Decoration decoration,
unsigned argIndex) {
- auto func = llvm::cast<FunctionOpInterface>(getOperation());
+ auto func = cast<FunctionOpInterface>(getOperation());
for (auto argAttr : cast<FunctionOpInterface>(func).getArgAttrs(argIndex)) {
if (argAttr.getName() != spirv::DecorationAttr::name)
continue;
@@ -1224,7 +1222,7 @@ ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
if (parser.parseColonType(type)) {
return failure();
}
- if (!llvm::isa<spirv::PointerType>(type)) {
+ if (!isa<spirv::PointerType>(type)) {
return parser.emitError(loc, "expected spirv.ptr type");
}
result.addAttribute(typeAttrName, TypeAttr::get(type));
@@ -1257,7 +1255,7 @@ void spirv::GlobalVariableOp::print(OpAsmPrinter &printer) {
}
LogicalResult spirv::GlobalVariableOp::verify() {
- if (!llvm::isa<spirv::PointerType>(getType()))
+ if (!isa<spirv::PointerType>(getType()))
return emitOpError("result must be of a !spv.ptr type");
// SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
@@ -1325,7 +1323,7 @@ ParseResult spirv::INTELSubgroupBlockWriteOp::parse(OpAsmParser &parser,
}
auto ptrType = spirv::PointerType::get(elementType, storageClass);
- if (auto valVecTy = llvm::dyn_cast<VectorType>(elementType))
+ if (auto valVecTy = dyn_cast<VectorType>(elementType))
ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
@@ -1539,7 +1537,7 @@ LogicalResult spirv::ModuleOp::verifyRegions() {
}
if (auto interface = entryPointOp.getInterface()) {
for (Attribute varRef : interface) {
- auto varSymRef = llvm::dyn_cast<FlatSymbolRefAttr>(varRef);
+ auto varSymRef = dyn_cast<FlatSymbolRefAttr>(varRef);
if (!varSymRef) {
return entryPointOp.emitError(
"expected symbol reference for interface "
@@ -1660,9 +1658,9 @@ LogicalResult spirv::SpecConstantOp::verify() {
return emitOpError("SpecId cannot be negative");
auto value = getDefaultValue();
- if (llvm::isa<IntegerAttr, FloatAttr>(value)) {
+ if (isa<IntegerAttr, FloatAttr>(value)) {
// Make sure bitwidth is allowed.
- if (!llvm::isa<spirv::SPIRVType>(value.getType()))
+ if (!isa<spirv::SPIRVType>(value.getType()))
return emitOpError("default value bitwidth disallowed");
return success();
}
@@ -1675,7 +1673,7 @@ LogicalResult spirv::SpecConstantOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult spirv::VectorShuffleOp::verify() {
- VectorType resultType = llvm::cast<VectorType>(getType());
+ VectorType resultType = cast<VectorType>(getType());
size_t numResultElements = resultType.getNumElements();
if (numResultElements != getComponents().size())
@@ -1685,8 +1683,8 @@ LogicalResult spirv::VectorShuffleOp::verify() {
<< getComponents().size() << ")";
size_t totalSrcElements =
- llvm::cast<VectorType>(getVector1().getType()).getNumElements() +
- llvm::cast<VectorType>(getVector2().getType()).getNumElements();
+ cast<VectorType>(getVector1().getType()).getNumElements() +
+ cast<VectorType>(getVector2().getType()).getNumElements();
for (const auto &selector : getComponents().getAsValueRange<IntegerAttr>()) {
uint32_t index = selector.getZExtValue();
@@ -1725,8 +1723,8 @@ LogicalResult spirv::MatrixTimesScalarOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult spirv::TransposeOp::verify() {
- auto inputMatrix = llvm::cast<spirv::MatrixType>(getMatrix().getType());
- auto resultMatrix = llvm::cast<spirv::MatrixType>(getResult().getType());
+ auto inputMatrix = cast<spirv::MatrixType>(getMatrix().getType());
+ auto resultMatrix = cast<spirv::MatrixType>(getResult().getType());
// Verify that the input and output matrices have correct shapes.
if (inputMatrix.getNumRows() != resultMatrix.getNumColumns())
@@ -1750,9 +1748,9 @@ LogicalResult spirv::TransposeOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult spirv::MatrixTimesVectorOp::verify() {
- auto matrixType = llvm::cast<spirv::MatrixType>(getMatrix().getType());
- auto vectorType = llvm::cast<VectorType>(getVector().getType());
- auto resultType = llvm::cast<VectorType>(getType());
+ auto matrixType = cast<spirv::MatrixType>(getMatrix().getType());
+ auto vectorType = cast<VectorType>(getVector().getType());
+ auto resultType = cast<VectorType>(getType());
if (matrixType.getNumColumns() != vectorType.getNumElements())
return emitOpError("matrix columns (")
@@ -1775,9 +1773,9 @@ LogicalResult spirv::MatrixTimesVectorOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult spirv::VectorTimesMatrixOp::verify() {
- auto vectorType = llvm::cast<VectorType>(getVector().getType());
- auto matrixType = llvm::cast<spirv::MatrixType>(getMatrix().getType());
- auto resultType = llvm::cast<VectorType>(getType());
+ auto vectorType = cast<VectorType>(getVector().getType());
+ auto matrixType = cast<spirv::MatrixType>(getMatrix().getType());
+ auto resultType = cast<VectorType>(getType());
if (matrixType.getNumRows() != vectorType.getNumElements())
return emitOpError("number of components in vector must equal the number "
@@ -1799,9 +1797,9 @@ LogicalResult spirv::VectorTimesMatrixOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult spirv::MatrixTimesMatrixOp::verify() {
- auto leftMatrix = llvm::cast<spirv::MatrixType>(getLeftmatrix().getType());
- auto rightMatrix = llvm::cast<spirv::MatrixType>(getRightmatrix().getType());
- auto resultMatrix = llvm::cast<spirv::MatrixType>(getResult().getType());
+ auto leftMatrix = cast<spirv::MatrixType>(getLeftmatrix().getType());
+ auto rightMatrix = cast<spirv::MatrixType>(getRightmatrix().getType());
+ auto resultMatrix = cast<spirv::MatrixType>(getResult().getType());
// left matrix columns' count and right matrix rows' count must be equal
if (leftMatrix.getNumColumns() != rightMatrix.getNumRows())
@@ -1886,14 +1884,14 @@ void spirv::SpecConstantCompositeOp::print(OpAsmPrinter &printer) {
}
LogicalResult spirv::SpecConstantCompositeOp::verify() {
- auto cType = llvm::dyn_cast<spirv::CompositeType>(getType());
+ auto cType = dyn_cast<spirv::CompositeType>(getType());
auto constituents = this->getConstituents().getValue();
if (!cType)
return emitError("result type must be a composite type, but provided ")
<< getType();
- if (llvm::isa<spirv::CooperativeMatrixType>(cType))
+ if (isa<spirv::CooperativeMatrixType>(cType))
return emitError("unsupported composite type ") << cType;
if (constituents.size() != cType.getNumElements())
return emitError("has incorrect number of operands: expected ")
@@ -1901,7 +1899,7 @@ LogicalResult spirv::SpecConstantCompositeOp::verify() {
<< constituents.size();
for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
- auto constituent = llvm::cast<FlatSymbolRefAttr>(constituents[index]);
+ auto constituent = cast<FlatSymbolRefAttr>(constituents[index]);
auto constituentSpecConstOp =
dyn_cast<spirv::SpecConstantOp>(SymbolTable::lookupNearestSymbolFrom(
@@ -2042,19 +2040,19 @@ LogicalResult spirv::SpecConstantOperationOp::verifyRegions() {
LogicalResult spirv::GLFrexpStructOp::verify() {
spirv::StructType structTy =
- llvm::dyn_cast<spirv::StructType>(getResult().getType());
+ dyn_cast<spirv::StructType>(getResult().getType());
if (structTy.getNumElements() != 2)
return emitError("result type must be a struct type with two memebers");
Type significandTy = structTy.getElementType(0);
Type exponentTy = structTy.getElementType(1);
- VectorType exponentVecTy = llvm::dyn_cast<VectorType>(exponentTy);
- IntegerType exponentIntTy = llvm::dyn_cast<IntegerType>(exponentTy);
+ VectorType exponentVecTy = dyn_cast<VectorType>(exponentTy);
+ IntegerType exponentIntTy = dyn_cast<IntegerType>(exponentTy);
Type operandTy = getOperand().getType();
- VectorType operandVecTy = llvm::dyn_cast<VectorType>(operandTy);
- FloatType operandFTy = llvm::dyn_cast<FloatType>(operandTy);
+ VectorType operandVecTy = dyn_cast<VectorType>(operandTy);
+ FloatType operandFTy = dyn_cast<FloatType>(operandTy);
if (significandTy != operandTy)
return emitError("member zero of the resulting struct type must be the "
@@ -2062,7 +2060,7 @@ LogicalResult spirv::GLFrexpStructOp::verify() {
if (exponentVecTy) {
IntegerType componentIntTy =
- llvm::dyn_cast<IntegerType>(exponentVecTy.getElementType());
+ dyn_cast<IntegerType>(exponentVecTy.getElementType());
if (!componentIntTy || componentIntTy.getWidth() != 32)
return emitError("member one of the resulting struct type must"
"be a scalar or vector of 32 bit integer type");
@@ -2091,12 +2089,11 @@ LogicalResult spirv::GLLdexpOp::verify() {
Type significandType = getX().getType();
Type exponentType = getExp().getType();
- if (llvm::isa<FloatType>(significandType) !=
- llvm::isa<IntegerType>(exponentType))
+ if (isa<FloatType>(significandType) != isa<IntegerType>(exponentType))
return emitOpError("operands must both be scalars or vectors");
auto getNumElements = [](Type type) -> unsigned {
- if (auto vectorType = llvm::dyn_cast<VectorType>(type))
+ if (auto vectorType = dyn_cast<VectorType>(type))
return vectorType.getNumElements();
return 1;
};
@@ -2138,7 +2135,7 @@ LogicalResult spirv::ShiftRightLogicalOp::verify() {
LogicalResult spirv::VectorTimesScalarOp::verify() {
if (getVector().getType() != getType())
return emitOpError("vector operand and result type mismatch");
- auto scalarType = llvm::cast<VectorType>(getType()).getElementType();
+ auto scalarType = cast<VectorType>(getType()).getElementType();
if (getScalar().getType() != scalarType)
return emitOpError("scalar operand and result element type match");
return success();
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
index f28d386f8874d..772219c8db654 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
@@ -75,11 +75,11 @@ parseEnumStrAttr(EnumClass &value, OpAsmParser &parser,
if (parser.parseAttribute(attrVal, parser.getBuilder().getNoneType(),
attrName, attr))
return failure();
- if (!llvm::isa<StringAttr>(attrVal))
+ if (!isa<StringAttr>(attrVal))
return parser.emitError(loc, "expected ")
<< attrName << " attribute specified as string";
- auto attrOptional = spirv::symbolizeEnum<EnumClass>(
- llvm::cast<StringAttr>(attrVal).getValue());
+ auto attrOptional =
+ spirv::symbolizeEnum<EnumClass>(cast<StringAttr>(attrVal).getValue());
if (!attrOptional)
return parser.emitError(loc, "invalid ")
<< attrName << " attribute specification: " << attrVal;
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index d1e275d590f78..53a48abe5ad02 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -178,17 +178,17 @@ unsigned ArrayType::getArrayStride() const { return getImpl()->stride; }
//===----------------------------------------------------------------------===//
bool CompositeType::classof(Type type) {
- if (auto vectorType = llvm::dyn_cast<VectorType>(type))
+ if (auto vectorType = dyn_cast<VectorType>(type))
return isValid(vectorType);
- return llvm::isa<spirv::ArrayType, spirv::CooperativeMatrixType,
- spirv::MatrixType, spirv::RuntimeArrayType,
- spirv::StructType, spirv::TensorArmType>(type);
+ return isa<spirv::ArrayType, spirv::CooperativeMatrixType, spirv::MatrixType,
+ spirv::RuntimeArrayType, spirv::StructType, spirv::TensorArmType>(
+ type);
}
bool CompositeType::isValid(VectorType type) {
return type.getRank() == 1 &&
llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) &&
- llvm::isa<ScalarType>(type.getElementType());
+ isa<ScalarType>(type.getElementType());
}
Type CompositeType::getElementType(unsigned index) const {
@@ -210,7 +210,7 @@ unsigned CompositeType::getNumElements() const {
}
bool CompositeType::hasCompileTimeKnownNumElements() const {
- return !llvm::isa<CooperativeMatrixType, RuntimeArrayType>(*this);
+ return !isa<CooperativeMatrixType, RuntimeArrayType>(*this);
}
void TypeCapabilityVisitor::addConcrete(VectorType type) {
@@ -529,10 +529,10 @@ void TypeCapabilityVisitor::addConcrete(RuntimeArrayType type) {
//===----------------------------------------------------------------------===//
bool ScalarType::classof(Type type) {
- if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
+ if (auto floatType = dyn_cast<FloatType>(type)) {
return isValid(floatType);
}
- if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
+ if (auto intType = dyn_cast<IntegerType>(type)) {
return isValid(intType);
}
return false;
@@ -676,19 +676,19 @@ void TypeCapabilityVisitor::addConcrete(ScalarType type) {
bool SPIRVType::classof(Type type) {
// Allow SPIR-V dialect types
- if (llvm::isa<SPIRVDialect>(type.getDialect()))
+ if (isa<SPIRVDialect>(type.getDialect()))
return true;
- if (llvm::isa<ScalarType>(type))
+ if (isa<ScalarType>(type))
return true;
- if (auto vectorType = llvm::dyn_cast<VectorType>(type))
+ if (auto vectorType = dyn_cast<VectorType>(type))
return CompositeType::isValid(vectorType);
- if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(type))
- return llvm::isa<ScalarType>(tensorArmType.getElementType());
+ if (auto tensorArmType = dyn_cast<TensorArmType>(type))
+ return isa<ScalarType>(tensorArmType.getElementType());
return false;
}
bool SPIRVType::isScalarOrVector() {
- return isIntOrFloat() || llvm::isa<VectorType>(*this);
+ return isIntOrFloat() || isa<VectorType>(*this);
}
void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
@@ -1190,7 +1190,7 @@ MatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
return emitError() << "matrix columns must be vectors of floats";
/// The underlying vectors (columns) must be of size 2, 3, or 4
- ArrayRef<int64_t> columnShape = llvm::cast<VectorType>(columnType).getShape();
+ ArrayRef<int64_t> columnShape = cast<VectorType>(columnType).getShape();
if (columnShape.size() != 1)
return emitError() << "matrix columns must be 1D vectors";
@@ -1202,8 +1202,8 @@ MatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
/// Returns true if the matrix elements are vectors of float elements
bool MatrixType::isValidColumnType(Type columnType) {
- if (auto vectorType = llvm::dyn_cast<VectorType>(columnType)) {
- if (llvm::isa<FloatType>(vectorType.getElementType()))
+ if (auto vectorType = dyn_cast<VectorType>(columnType)) {
+ if (isa<FloatType>(vectorType.getElementType()))
return true;
}
return false;
@@ -1212,13 +1212,13 @@ bool MatrixType::isValidColumnType(Type columnType) {
Type MatrixType::getColumnType() const { return getImpl()->columnType; }
Type MatrixType::getElementType() const {
- return llvm::cast<VectorType>(getImpl()->columnType).getElementType();
+ return cast<VectorType>(getImpl()->columnType).getElementType();
}
unsigned MatrixType::getNumColumns() const { return getImpl()->columnCount; }
unsigned MatrixType::getNumRows() const {
- return llvm::cast<VectorType>(getImpl()->columnType).getShape()[0];
+ return cast<VectorType>(getImpl()->columnType).getShape()[0];
}
unsigned MatrixType::getNumElements() const {
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 50883d9ed5e75..ce7e7dc4116c8 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -243,7 +243,7 @@ static LogicalResult deserializeCacheControlDecoration(
auto value = opBuilder.getAttr<AttrTy>(cacheLevel, cacheControlAttr);
SmallVector<Attribute> attrs;
if (auto attrList =
- llvm::dyn_cast_or_null<ArrayAttr>(decorations[words[0]].get(symbol)))
+ dyn_cast_or_null<ArrayAttr>(decorations[words[0]].get(symbol)))
llvm::append_range(attrs, attrList);
attrs.push_back(value);
decorations[words[0]].set(symbol, opBuilder.getArrayAttr(attrs));
@@ -326,7 +326,7 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
static_cast<::mlir::spirv::LinkageType>(words[wordIndex++]));
auto linkageAttr = opBuilder.getAttr<::mlir::spirv::LinkageAttributesAttr>(
StringAttr::get(context, linkageName), linkageTypeAttr);
- decorations[words[0]].set(symbol, llvm::dyn_cast<Attribute>(linkageAttr));
+ decorations[words[0]].set(symbol, dyn_cast<Attribute>(linkageAttr));
break;
}
case spirv::Decoration::Aliased:
@@ -1511,10 +1511,10 @@ spirv::Deserializer::processTensorARMType(ArrayRef<uint32_t> operands) {
return emitError(unknownLoc, "OpTypeTensorARM shape must come from a "
"constant instruction of type OpTypeArray");
- ArrayAttr shapeArrayAttr = llvm::dyn_cast<ArrayAttr>(shapeInfo->first);
+ ArrayAttr shapeArrayAttr = dyn_cast<ArrayAttr>(shapeInfo->first);
SmallVector<int64_t, 1> shape;
for (auto dimAttr : shapeArrayAttr.getValue()) {
- auto dimIntAttr = llvm::dyn_cast<IntegerAttr>(dimAttr);
+ auto dimIntAttr = dyn_cast<IntegerAttr>(dimAttr);
if (!dimIntAttr)
return emitError(unknownLoc, "OpTypeTensorARM shape has an invalid "
"dimension size");
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index 6397d2c005c16..b78fac532d8c5 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -876,7 +876,7 @@ Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
if (values) {
for (auto &intVal : values.getValue()) {
operands.push_back(static_cast<uint32_t>(
- llvm::cast<IntegerAttr>(intVal).getValue().getZExtValue()));
+ cast<IntegerAttr>(intVal).getValue().getZExtValue()));
}
}
encodeInstructionInto(executionModes, spirv::Opcode::OpExecutionMode,
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index c879a2b3e0207..c29d20f755332 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -312,7 +312,7 @@ LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
case spirv::Decoration::LinkageAttributes: {
// Get the value of the Linkage Attributes
// e.g., LinkageAttributes=["linkageName", linkageType].
- auto linkageAttr = llvm::dyn_cast<spirv::LinkageAttributesAttr>(attr);
+ auto linkageAttr = dyn_cast<spirv::LinkageAttributesAttr>(attr);
auto linkageName = linkageAttr.getLinkageName();
auto linkageType = linkageAttr.getLinkageType().getValue();
// Encode the Linkage Name (string literal to uint32_t).
@@ -822,7 +822,7 @@ LogicalResult Serializer::prepareBasicType(
return success();
}
- if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(type)) {
+ if (auto tensorArmType = dyn_cast<TensorArmType>(type)) {
uint32_t elementTypeID = 0;
uint32_t rank = 0;
uint32_t shapeID = 0;
More information about the Mlir-commits
mailing list