[Mlir-commits] [mlir] 9f0d5cd - [mlir][spirv] Clean up casts. NFC. (#174115)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Dec 31 14:08:08 PST 2025


Author: Jakub Kuderski
Date: 2025-12-31T17:08:03-05:00
New Revision: 9f0d5cd0c2280a3bf9237c6ad8f7fe17be36e7e6

URL: https://github.com/llvm/llvm-project/commit/9f0d5cd0c2280a3bf9237c6ad8f7fe17be36e7e6
DIFF: https://github.com/llvm/llvm-project/commit/9f0d5cd0c2280a3bf9237c6ad8f7fe17be36e7e6.diff

LOG: [mlir][spirv] Clean up casts. NFC. (#174115)

Drop the `llvm::` namespace prefix where needlessly used. These prefixes were
introduced by clang-tidy when we migrated from cast member functions to
free functions. Also switch one `patterns.insert` call to the equivalent
`patterns.add`.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
    mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
    mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp
    mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
    mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
    mlir/lib/Dialect/SPIRV/IR/DotProductOps.cpp
    mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
    mlir/lib/Dialect/SPIRV/IR/ImageOps.cpp
    mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
    mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
    mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
    mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
    mlir/lib/Target/SPIRV/Serialization/Serializer.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index e46b576810316..aac5ef17370b2 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -407,7 +407,7 @@ class CooperativeMatrixType
   /// Returns the use parameter of the cooperative matrix.
   CooperativeMatrixUseKHR getUse() const;
 
-  operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
+  operator ShapedType() const { return cast<ShapedType>(*this); }
 
   ArrayRef<int64_t> getShape() const;
 
@@ -491,7 +491,7 @@ class TensorArmType
   Type getElementType() const;
   ArrayRef<int64_t> getShape() const;
   bool hasRank() const { return !getShape().empty(); }
-  operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
+  operator ShapedType() const { return cast<ShapedType>(*this); }
 };
 
 } // namespace spirv

diff  --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 56e8fee191432..c101a95685a25 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -422,10 +422,10 @@ struct VectorReductionPattern final : OpConversionPattern<vector::ReductionOp> {
 
 #define INT_AND_FLOAT_CASE(kind, iop, fop)                                     \
   case vector::CombiningKind::kind:                                            \
-    if (llvm::isa<IntegerType>(resultType)) {                                  \
+    if (isa<IntegerType>(resultType)) {                                        \
       result = spirv::iop::create(rewriter, loc, resultType, result, next);    \
     } else {                                                                   \
-      assert(llvm::isa<FloatType>(resultType));                                \
+      assert(isa<FloatType>(resultType));                                      \
       result = spirv::fop::create(rewriter, loc, resultType, result, next);    \
     }                                                                          \
     break

diff  --git a/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp b/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp
index 948d48980f2e8..7029268177128 100644
--- a/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp
@@ -35,9 +35,9 @@ StringRef stringifyTypeName<FloatType>() {
 // Verifies an atomic update op.
 template <typename AtomicOpTy, typename ExpectedElementType>
 static LogicalResult verifyAtomicUpdateOp(Operation *op) {
-  auto ptrType = llvm::cast<spirv::PointerType>(op->getOperand(0).getType());
+  auto ptrType = cast<spirv::PointerType>(op->getOperand(0).getType());
   auto elementType = ptrType.getPointeeType();
-  if (!llvm::isa<ExpectedElementType>(elementType))
+  if (!isa<ExpectedElementType>(elementType))
     return op->emitOpError() << "pointer operand must point to an "
                              << stringifyTypeName<ExpectedElementType>()
                              << " value, found " << elementType;

diff  --git a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
index fcf4eb6fbcf60..a5330dc56d48f 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
@@ -87,13 +87,13 @@ LogicalResult BitcastOp::verify() {
   if (operandType == resultType) {
     return emitError("result type must be different from operand type");
   }
-  if (llvm::isa<spirv::PointerType>(operandType) &&
-      !llvm::isa<spirv::PointerType>(resultType)) {
+  if (isa<spirv::PointerType>(operandType) &&
+      !isa<spirv::PointerType>(resultType)) {
     return emitError(
         "unhandled bit cast conversion from pointer type to non-pointer type");
   }
-  if (!llvm::isa<spirv::PointerType>(operandType) &&
-      llvm::isa<spirv::PointerType>(resultType)) {
+  if (!isa<spirv::PointerType>(operandType) &&
+      isa<spirv::PointerType>(resultType)) {
     return emitError(
         "unhandled bit cast conversion from non-pointer type to pointer type");
   }
@@ -112,8 +112,8 @@ LogicalResult BitcastOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult ConvertPtrToUOp::verify() {
-  auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
-  auto resultType = llvm::cast<spirv::ScalarType>(getResult().getType());
+  auto operandType = cast<spirv::PointerType>(getPointer().getType());
+  auto resultType = cast<spirv::ScalarType>(getResult().getType());
   if (!resultType || !resultType.isSignlessInteger())
     return emitError("result must be a scalar type of unsigned integer");
   auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>();
@@ -133,8 +133,8 @@ LogicalResult ConvertPtrToUOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult ConvertUToPtrOp::verify() {
-  auto operandType = llvm::cast<spirv::ScalarType>(getOperand().getType());
-  auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
+  auto operandType = cast<spirv::ScalarType>(getOperand().getType());
+  auto resultType = cast<spirv::PointerType>(getResult().getType());
   if (!operandType || !operandType.isSignlessInteger())
     return emitError("result must be a scalar type of unsigned integer");
   auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>();
@@ -154,8 +154,8 @@ LogicalResult ConvertUToPtrOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult PtrCastToGenericOp::verify() {
-  auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
-  auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
+  auto operandType = cast<spirv::PointerType>(getPointer().getType());
+  auto resultType = cast<spirv::PointerType>(getResult().getType());
 
   spirv::StorageClass operandStorage = operandType.getStorageClass();
   if (operandStorage != spirv::StorageClass::Workgroup &&
@@ -182,8 +182,8 @@ LogicalResult PtrCastToGenericOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult GenericCastToPtrOp::verify() {
-  auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
-  auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
+  auto operandType = cast<spirv::PointerType>(getPointer().getType());
+  auto resultType = cast<spirv::PointerType>(getResult().getType());
 
   spirv::StorageClass operandStorage = operandType.getStorageClass();
   if (operandStorage != spirv::StorageClass::Generic)
@@ -210,8 +210,8 @@ LogicalResult GenericCastToPtrOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult GenericCastToPtrExplicitOp::verify() {
-  auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
-  auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
+  auto operandType = cast<spirv::PointerType>(getPointer().getType());
+  auto resultType = cast<spirv::PointerType>(getResult().getType());
 
   spirv::StorageClass operandStorage = operandType.getStorageClass();
   if (operandStorage != spirv::StorageClass::Generic)

diff  --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
index a846d7e60024c..4d0aedca27d42 100644
--- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
@@ -138,7 +138,7 @@ LogicalResult BranchConditionalOp::verify() {
       return emitOpError("must have exactly two branch weights");
     }
     if (llvm::all_of(*weights, [](Attribute attr) {
-          return llvm::cast<IntegerAttr>(attr).getValue().isZero();
+          return cast<IntegerAttr>(attr).getValue().isZero();
         }))
       return emitOpError("branch weights cannot both be zero");
   }
@@ -504,8 +504,8 @@ LogicalResult ReturnValueOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult SelectOp::verify() {
-  if (auto conditionTy = llvm::dyn_cast<VectorType>(getCondition().getType())) {
-    auto resultVectorTy = llvm::dyn_cast<VectorType>(getResult().getType());
+  if (auto conditionTy = dyn_cast<VectorType>(getCondition().getType())) {
+    auto resultVectorTy = dyn_cast<VectorType>(getResult().getType());
     if (!resultVectorTy) {
       return emitOpError("result expected to be of vector type when "
                          "condition is of vector type");

diff  --git a/mlir/lib/Dialect/SPIRV/IR/DotProductOps.cpp b/mlir/lib/Dialect/SPIRV/IR/DotProductOps.cpp
index 01ef1bdc42515..dada8925b88e1 100644
--- a/mlir/lib/Dialect/SPIRV/IR/DotProductOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/DotProductOps.cpp
@@ -74,10 +74,9 @@ static LogicalResult verifyIntegerDotProduct(Operation *op) {
   Type factorTy = op->getOperand(0).getType();
   StringAttr packedVectorFormatAttrName =
       IntegerDotProductOpTy::getFormatAttrName(op->getName());
-  if (auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
-    auto packedVectorFormat =
-        llvm::dyn_cast_or_null<spirv::PackedVectorFormatAttr>(
-            op->getAttr(packedVectorFormatAttrName));
+  if (auto intTy = dyn_cast<IntegerType>(factorTy)) {
+    auto packedVectorFormat = dyn_cast_or_null<spirv::PackedVectorFormatAttr>(
+        op->getAttr(packedVectorFormatAttrName));
     if (!packedVectorFormat)
       return op->emitOpError("requires Packed Vector Format attribute for "
                              "integer vector operands");
@@ -135,8 +134,8 @@ getIntegerDotProductCapabilities(Operation *op) {
   Type factorTy = op->getOperand(0).getType();
   StringAttr packedVectorFormatAttrName =
       IntegerDotProductOpTy::getFormatAttrName(op->getName());
-  if (auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
-    auto formatAttr = llvm::cast<spirv::PackedVectorFormatAttr>(
+  if (auto intTy = dyn_cast<IntegerType>(factorTy)) {
+    auto formatAttr = cast<spirv::PackedVectorFormatAttr>(
         op->getAttr(packedVectorFormatAttrName));
     if (formatAttr.getValue() ==
         spirv::PackedVectorFormat::PackedVectorFormat4x8Bit)
@@ -145,7 +144,7 @@ getIntegerDotProductCapabilities(Operation *op) {
     return capabilities;
   }
 
-  auto vecTy = llvm::cast<VectorType>(factorTy);
+  auto vecTy = cast<VectorType>(factorTy);
   if (vecTy.getElementTypeBitWidth() == 8) {
     capabilities.push_back(dotProductInput4x8BitCap);
     return capabilities;

diff  --git a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
index 461d037134dae..a1bb7f89e9183 100644
--- a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
@@ -65,7 +65,7 @@ LogicalResult GroupBroadcastOp::verify() {
   if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
     return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
 
-  if (auto localIdTy = llvm::dyn_cast<VectorType>(getLocalid().getType()))
+  if (auto localIdTy = dyn_cast<VectorType>(getLocalid().getType()))
     if (localIdTy.getNumElements() != 2 && localIdTy.getNumElements() != 3)
       return emitOpError("localid is a vector and can be with only "
                          " 2 or 3 components, actual number is ")

diff  --git a/mlir/lib/Dialect/SPIRV/IR/ImageOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ImageOps.cpp
index 661f3d5d9b81d..6ea07330b70cb 100644
--- a/mlir/lib/Dialect/SPIRV/IR/ImageOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/ImageOps.cpp
@@ -89,8 +89,8 @@ static LogicalResult verifyImageOperands(Operation *imageOp,
                                   "floating-point type scalar");
 
       auto samplingOp = cast<spirv::SamplingOpInterface>(imageOp);
-      auto sampledImageType = llvm::cast<spirv::SampledImageType>(
-          samplingOp.getSampledImage().getType());
+      auto sampledImageType =
+          cast<spirv::SampledImageType>(samplingOp.getSampledImage().getType());
       imageType = cast<spirv::ImageType>(sampledImageType.getImageType());
     } else {
       if (!isa<mlir::IntegerType>(operands[index].getType()))
@@ -243,8 +243,7 @@ LogicalResult spirv::ImageWriteOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult spirv::ImageQuerySizeOp::verify() {
-  spirv::ImageType imageType =
-      llvm::cast<spirv::ImageType>(getImage().getType());
+  spirv::ImageType imageType = cast<spirv::ImageType>(getImage().getType());
   Type resultType = getResult().getType();
 
   spirv::Dim dim = imageType.getDim();
@@ -292,7 +291,7 @@ LogicalResult spirv::ImageQuerySizeOp::verify() {
     componentNumber += 1;
 
   unsigned resultComponentNumber = 1;
-  if (auto resultVectorType = llvm::dyn_cast<VectorType>(resultType))
+  if (auto resultVectorType = dyn_cast<VectorType>(resultType))
     resultComponentNumber = resultVectorType.getNumElements();
 
   if (componentNumber != resultComponentNumber)

diff  --git a/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
index 5ae27e5d82bd7..e3187d3dc1901 100644
--- a/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
@@ -166,7 +166,7 @@ static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr,
   // TODO: Check that the value type satisfies restrictions of
   // SPIR-V OpLoad/OpStore operations
   if (val.getType() !=
-      llvm::cast<spirv::PointerType>(ptr.getType()).getPointeeType()) {
+      cast<spirv::PointerType>(ptr.getType()).getPointeeType()) {
     return op.emitOpError("mismatch in result type and pointer type");
   }
   return success();
@@ -190,7 +190,7 @@ static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
     return success();
   }
 
-  auto memAccess = llvm::cast<spirv::MemoryAccessAttr>(memAccessAttr);
+  auto memAccess = cast<spirv::MemoryAccessAttr>(memAccessAttr);
 
   if (!memAccess) {
     return memoryOp.emitOpError("invalid memory access specifier: ")
@@ -234,7 +234,7 @@ static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
     return success();
   }
 
-  auto memAccess = llvm::cast<spirv::MemoryAccessAttr>(memAccessAttr);
+  auto memAccess = cast<spirv::MemoryAccessAttr>(memAccessAttr);
 
   if (!memAccess) {
     return memoryOp.emitOpError("invalid memory access specifier: ")
@@ -261,7 +261,7 @@ static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
 //===----------------------------------------------------------------------===//
 
 static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) {
-  auto ptrType = llvm::dyn_cast<spirv::PointerType>(type);
+  auto ptrType = dyn_cast<spirv::PointerType>(type);
   if (!ptrType) {
     emitError(baseLoc, "'spirv.AccessChain' op expected a pointer "
                        "to composite type, but provided ")
@@ -274,7 +274,7 @@ static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) {
   int32_t index = 0;
 
   for (auto indexSSA : indices) {
-    auto cType = llvm::dyn_cast<spirv::CompositeType>(resultType);
+    auto cType = dyn_cast<spirv::CompositeType>(resultType);
     if (!cType) {
       emitError(
           baseLoc,
@@ -283,7 +283,7 @@ static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) {
       return nullptr;
     }
     index = 0;
-    if (llvm::isa<spirv::StructType>(resultType)) {
+    if (isa<spirv::StructType>(resultType)) {
       Operation *op = indexSSA.getDefiningOp();
       if (!op) {
         emitError(baseLoc, "'spirv.AccessChain' op index must be an "
@@ -334,7 +334,7 @@ static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices) {
     return failure();
 
   auto providedResultType =
-      llvm::dyn_cast<spirv::PointerType>(accessChainOp.getType());
+      dyn_cast<spirv::PointerType>(accessChainOp.getType());
   if (!providedResultType)
     return accessChainOp.emitOpError(
                "result type must be a pointer, but provided")
@@ -357,7 +357,7 @@ LogicalResult AccessChainOp::verify() {
 
 void LoadOp::build(OpBuilder &builder, OperationState &state, Value basePtr,
                    MemoryAccessAttr memoryAccess, IntegerAttr alignment) {
-  auto ptrType = llvm::cast<spirv::PointerType>(basePtr.getType());
+  auto ptrType = cast<spirv::PointerType>(basePtr.getType());
   build(builder, state, ptrType.getPointeeType(), basePtr, memoryAccess,
         alignment);
 }
@@ -386,7 +386,7 @@ ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) {
 void LoadOp::print(OpAsmPrinter &printer) {
   SmallVector<StringRef, 4> elidedAttrs;
   StringRef sc = stringifyStorageClass(
-      llvm::cast<spirv::PointerType>(getPtr().getType()).getStorageClass());
+      cast<spirv::PointerType>(getPtr().getType()).getStorageClass());
   printer << " \"" << sc << "\" " << getPtr();
 
   printMemoryAccessAttribute(*this, printer, elidedAttrs);
@@ -433,7 +433,7 @@ ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) {
 void StoreOp::print(OpAsmPrinter &printer) {
   SmallVector<StringRef, 4> elidedAttrs;
   StringRef sc = stringifyStorageClass(
-      llvm::cast<spirv::PointerType>(getPtr().getType()).getStorageClass());
+      cast<spirv::PointerType>(getPtr().getType()).getStorageClass());
   printer << " \"" << sc << "\" " << getPtr() << ", " << getValue();
 
   printMemoryAccessAttribute(*this, printer, elidedAttrs);
@@ -458,11 +458,11 @@ void CopyMemoryOp::print(OpAsmPrinter &printer) {
   printer << ' ';
 
   StringRef targetStorageClass = stringifyStorageClass(
-      llvm::cast<spirv::PointerType>(getTarget().getType()).getStorageClass());
+      cast<spirv::PointerType>(getTarget().getType()).getStorageClass());
   printer << " \"" << targetStorageClass << "\" " << getTarget() << ", ";
 
   StringRef sourceStorageClass = stringifyStorageClass(
-      llvm::cast<spirv::PointerType>(getSource().getType()).getStorageClass());
+      cast<spirv::PointerType>(getSource().getType()).getStorageClass());
   printer << " \"" << sourceStorageClass << "\" " << getSource();
 
   SmallVector<StringRef, 4> elidedAttrs;
@@ -474,7 +474,7 @@ void CopyMemoryOp::print(OpAsmPrinter &printer) {
   printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
 
   Type pointeeType =
-      llvm::cast<spirv::PointerType>(getTarget().getType()).getPointeeType();
+      cast<spirv::PointerType>(getTarget().getType()).getPointeeType();
   printer << " : " << pointeeType;
 }
 
@@ -521,10 +521,10 @@ ParseResult CopyMemoryOp::parse(OpAsmParser &parser, OperationState &result) {
 
 LogicalResult CopyMemoryOp::verify() {
   Type targetType =
-      llvm::cast<spirv::PointerType>(getTarget().getType()).getPointeeType();
+      cast<spirv::PointerType>(getTarget().getType()).getPointeeType();
 
   Type sourceType =
-      llvm::cast<spirv::PointerType>(getSource().getType()).getPointeeType();
+      cast<spirv::PointerType>(getSource().getType()).getPointeeType();
 
   if (targetType != sourceType)
     return emitOpError("both operands must be pointers to the same type");
@@ -600,7 +600,7 @@ ParseResult VariableOp::parse(OpAsmParser &parser, OperationState &result) {
   if (parser.parseType(type))
     return failure();
 
-  auto ptrType = llvm::dyn_cast<spirv::PointerType>(type);
+  auto ptrType = dyn_cast<spirv::PointerType>(type);
   if (!ptrType)
     return parser.emitError(loc, "expected spirv.ptr type");
   result.addTypes(ptrType);
@@ -640,7 +640,7 @@ LogicalResult VariableOp::verify() {
         "spirv.GlobalVariable for module-level variables.");
   }
 
-  auto pointerType = llvm::cast<spirv::PointerType>(getPointer().getType());
+  auto pointerType = cast<spirv::PointerType>(getPointer().getType());
   if (getStorageClass() != pointerType.getStorageClass())
     return emitOpError(
         "storage class must match result pointer's storage class");

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp
index 2ba6106896c1f..f1940091ca238 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp
@@ -146,20 +146,18 @@ StringRef spirv::InterfaceVarABIAttr::getKindName() {
 }
 
 uint32_t spirv::InterfaceVarABIAttr::getBinding() {
-  return llvm::cast<IntegerAttr>(getImpl()->binding).getInt();
+  return cast<IntegerAttr>(getImpl()->binding).getInt();
 }
 
 uint32_t spirv::InterfaceVarABIAttr::getDescriptorSet() {
-  return llvm::cast<IntegerAttr>(getImpl()->descriptorSet).getInt();
+  return cast<IntegerAttr>(getImpl()->descriptorSet).getInt();
 }
 
 std::optional<spirv::StorageClass>
 spirv::InterfaceVarABIAttr::getStorageClass() {
   if (getImpl()->storageClass)
     return static_cast<spirv::StorageClass>(
-        llvm::cast<IntegerAttr>(getImpl()->storageClass)
-            .getValue()
-            .getZExtValue());
+        cast<IntegerAttr>(getImpl()->storageClass).getValue().getZExtValue());
   return std::nullopt;
 }
 
@@ -173,7 +171,7 @@ LogicalResult spirv::InterfaceVarABIAttr::verifyInvariants(
     return emitError() << "expected 32-bit integer for binding";
 
   if (storageClass) {
-    if (auto storageClassAttr = llvm::cast<IntegerAttr>(storageClass)) {
+    if (auto storageClassAttr = cast<IntegerAttr>(storageClass)) {
       auto storageClassValue =
           spirv::symbolizeStorageClass(storageClassAttr.getInt());
       if (!storageClassValue)
@@ -222,14 +220,14 @@ StringRef spirv::VerCapExtAttr::getKindName() { return "vce"; }
 
 spirv::Version spirv::VerCapExtAttr::getVersion() {
   return static_cast<spirv::Version>(
-      llvm::cast<IntegerAttr>(getImpl()->version).getValue().getZExtValue());
+      cast<IntegerAttr>(getImpl()->version).getValue().getZExtValue());
 }
 
 spirv::VerCapExtAttr::ext_iterator::ext_iterator(ArrayAttr::iterator it)
     : llvm::mapped_iterator<ArrayAttr::iterator,
                             spirv::Extension (*)(Attribute)>(
           it, [](Attribute attr) {
-            return *symbolizeExtension(llvm::cast<StringAttr>(attr).getValue());
+            return *symbolizeExtension(cast<StringAttr>(attr).getValue());
           }) {}
 
 spirv::VerCapExtAttr::ext_range spirv::VerCapExtAttr::getExtensions() {
@@ -238,7 +236,7 @@ spirv::VerCapExtAttr::ext_range spirv::VerCapExtAttr::getExtensions() {
 }
 
 ArrayAttr spirv::VerCapExtAttr::getExtensionsAttr() {
-  return llvm::cast<ArrayAttr>(getImpl()->extensions);
+  return cast<ArrayAttr>(getImpl()->extensions);
 }
 
 spirv::VerCapExtAttr::cap_iterator::cap_iterator(ArrayAttr::iterator it)
@@ -246,7 +244,7 @@ spirv::VerCapExtAttr::cap_iterator::cap_iterator(ArrayAttr::iterator it)
                             spirv::Capability (*)(Attribute)>(
           it, [](Attribute attr) {
             return *symbolizeCapability(
-                llvm::cast<IntegerAttr>(attr).getValue().getZExtValue());
+                cast<IntegerAttr>(attr).getValue().getZExtValue());
           }) {}
 
 spirv::VerCapExtAttr::cap_range spirv::VerCapExtAttr::getCapabilities() {
@@ -255,7 +253,7 @@ spirv::VerCapExtAttr::cap_range spirv::VerCapExtAttr::getCapabilities() {
 }
 
 ArrayAttr spirv::VerCapExtAttr::getCapabilitiesAttr() {
-  return llvm::cast<ArrayAttr>(getImpl()->capabilities);
+  return cast<ArrayAttr>(getImpl()->capabilities);
 }
 
 LogicalResult spirv::VerCapExtAttr::verifyInvariants(
@@ -265,7 +263,7 @@ LogicalResult spirv::VerCapExtAttr::verifyInvariants(
     return emitError() << "expected 32-bit integer for version";
 
   if (!llvm::all_of(capabilities.getValue(), [](Attribute attr) {
-        if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
+        if (auto intAttr = dyn_cast<IntegerAttr>(attr))
           if (spirv::symbolizeCapability(intAttr.getValue().getZExtValue()))
             return true;
         return false;
@@ -273,7 +271,7 @@ LogicalResult spirv::VerCapExtAttr::verifyInvariants(
     return emitError() << "unknown capability in capability list";
 
   if (!llvm::all_of(extensions.getValue(), [](Attribute attr) {
-        if (auto strAttr = llvm::dyn_cast<StringAttr>(attr))
+        if (auto strAttr = dyn_cast<StringAttr>(attr))
           if (spirv::symbolizeExtension(strAttr.getValue()))
             return true;
         return false;
@@ -299,7 +297,7 @@ spirv::TargetEnvAttr spirv::TargetEnvAttr::get(
 StringRef spirv::TargetEnvAttr::getKindName() { return "target_env"; }
 
 spirv::VerCapExtAttr spirv::TargetEnvAttr::getTripleAttr() const {
-  return llvm::cast<spirv::VerCapExtAttr>(getImpl()->triple);
+  return cast<spirv::VerCapExtAttr>(getImpl()->triple);
 }
 
 spirv::Version spirv::TargetEnvAttr::getVersion() const {
@@ -339,7 +337,7 @@ uint32_t spirv::TargetEnvAttr::getDeviceID() const {
 }
 
 spirv::ResourceLimitsAttr spirv::TargetEnvAttr::getResourceLimits() const {
-  return llvm::cast<spirv::ResourceLimitsAttr>(getImpl()->limits);
+  return cast<spirv::ResourceLimitsAttr>(getImpl()->limits);
 }
 
 //===----------------------------------------------------------------------===//
@@ -668,11 +666,11 @@ void SPIRVDialect::printAttribute(Attribute attr,
   if (succeeded(generatedAttributePrinter(attr, printer)))
     return;
 
-  if (auto targetEnv = llvm::dyn_cast<TargetEnvAttr>(attr))
+  if (auto targetEnv = dyn_cast<TargetEnvAttr>(attr))
     print(targetEnv, printer);
-  else if (auto vceAttr = llvm::dyn_cast<VerCapExtAttr>(attr))
+  else if (auto vceAttr = dyn_cast<VerCapExtAttr>(attr))
     print(vceAttr, printer);
-  else if (auto interfaceVarABIAttr = llvm::dyn_cast<InterfaceVarABIAttr>(attr))
+  else if (auto interfaceVarABIAttr = dyn_cast<InterfaceVarABIAttr>(attr))
     print(interfaceVarABIAttr, printer);
   else
     llvm_unreachable("unhandled SPIR-V attribute kind");

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index ccc85368c78a4..9ab3bdc6ab102 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -35,9 +35,9 @@ static std::optional<bool> getScalarOrSplatBoolAttr(Attribute attr) {
   if (!attr)
     return std::nullopt;
 
-  if (auto boolAttr = llvm::dyn_cast<BoolAttr>(attr))
+  if (auto boolAttr = dyn_cast<BoolAttr>(attr))
     return boolAttr.getValue();
-  if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(attr))
+  if (auto splatAttr = dyn_cast<SplatElementsAttr>(attr))
     if (splatAttr.getElementType().isInteger(1))
       return splatAttr.getSplatValue<bool>();
   return std::nullopt;
@@ -54,12 +54,12 @@ static Attribute extractCompositeElement(Attribute composite,
   if (indices.empty())
     return composite;
 
-  if (auto vector = llvm::dyn_cast<ElementsAttr>(composite)) {
+  if (auto vector = dyn_cast<ElementsAttr>(composite)) {
     assert(indices.size() == 1 && "must have exactly one index for a vector");
     return vector.getValues<Attribute>()[indices[0]];
   }
 
-  if (auto array = llvm::dyn_cast<ArrayAttr>(composite)) {
+  if (auto array = dyn_cast<ArrayAttr>(composite)) {
     assert(!indices.empty() && "must have at least one index for an array");
     return extractCompositeElement(array.getValue()[indices[0]],
                                    indices.drop_front());
@@ -370,7 +370,7 @@ struct UModSimplification final : OpRewritePattern<spirv::UModOp> {
 
 void spirv::UModOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                 MLIRContext *context) {
-  patterns.insert<UModSimplification>(context);
+  patterns.add<UModSimplification>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -412,10 +412,10 @@ OpFoldResult spirv::CompositeExtractOp::fold(FoldAdaptor adaptor) {
 
   if (auto constructOp =
           compositeOp.getDefiningOp<spirv::CompositeConstructOp>()) {
-    auto type = llvm::cast<spirv::CompositeType>(constructOp.getType());
+    auto type = cast<spirv::CompositeType>(constructOp.getType());
     if (getIndices().size() == 1 &&
         constructOp.getConstituents().size() == type.getNumElements()) {
-      auto i = llvm::cast<IntegerAttr>(*getIndices().begin());
+      auto i = cast<IntegerAttr>(*getIndices().begin());
       if (i.getValue().getSExtValue() <
           static_cast<int64_t>(constructOp.getConstituents().size()))
         return constructOp.getConstituents()[i.getValue().getSExtValue()];
@@ -423,7 +423,7 @@ OpFoldResult spirv::CompositeExtractOp::fold(FoldAdaptor adaptor) {
   }
 
   auto indexVector = llvm::map_to_vector(getIndices(), [](Attribute attr) {
-    return static_cast<unsigned>(llvm::cast<IntegerAttr>(attr).getInt());
+    return static_cast<unsigned>(cast<IntegerAttr>(attr).getInt());
   });
   return extractCompositeElement(adaptor.getComposite(), indexVector);
 }
@@ -1379,7 +1379,7 @@ LogicalResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
   // Starting with version 1.4, Result Type can additionally be a composite type
   // other than a vector."
   bool isScalarOrVector =
-      llvm::cast<spirv::SPIRVType>(trueBrStoreOp.getValue().getType())
+      cast<spirv::SPIRVType>(trueBrStoreOp.getValue().getType())
           .isScalarOrVector();
 
   // Check that each `spirv.Store` uses the same pointer, memory access

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 24c33f9ae1b90..22b57d6c0821a 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -172,16 +172,16 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
     return type;
 
   // Check other allowed types
-  if (auto t = llvm::dyn_cast<FloatType>(type)) {
+  if (auto t = dyn_cast<FloatType>(type)) {
     // TODO: All float types are allowed for now, but this should be fixed.
-  } else if (auto t = llvm::dyn_cast<IntegerType>(type)) {
+  } else if (auto t = dyn_cast<IntegerType>(type)) {
     if (!ScalarType::isValid(t)) {
       parser.emitError(typeLoc,
                        "only 1/8/16/32/64-bit integer type allowed but found ")
           << type;
       return Type();
     }
-  } else if (auto t = llvm::dyn_cast<VectorType>(type)) {
+  } else if (auto t = dyn_cast<VectorType>(type)) {
     if (t.getRank() != 1) {
       parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
       return Type();
@@ -215,7 +215,7 @@ static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect,
   if (parser.parseType(type))
     return Type();
 
-  if (auto t = llvm::dyn_cast<VectorType>(type)) {
+  if (auto t = dyn_cast<VectorType>(type)) {
     if (t.getRank() != 1) {
       parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
       return Type();
@@ -228,7 +228,7 @@ static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect,
       return Type();
     }
 
-    if (!llvm::isa<FloatType>(t.getElementType())) {
+    if (!isa<FloatType>(t.getElementType())) {
       parser.emitError(typeLoc, "matrix columns' elements must be of "
                                 "Float type, got ")
           << t.getElementType();
@@ -1016,12 +1016,12 @@ LogicalResult SPIRVDialect::verifyOperationAttribute(Operation *op,
   Attribute attr = attribute.getValue();
 
   if (symbol == spirv::getEntryPointABIAttrName()) {
-    if (!llvm::isa<spirv::EntryPointABIAttr>(attr)) {
+    if (!isa<spirv::EntryPointABIAttr>(attr)) {
       return op->emitError("'")
              << symbol << "' attribute must be an entry point ABI attribute";
     }
   } else if (symbol == spirv::getTargetEnvAttrName()) {
-    if (!llvm::isa<spirv::TargetEnvAttr>(attr))
+    if (!isa<spirv::TargetEnvAttr>(attr))
       return op->emitError("'") << symbol << "' must be a spirv::TargetEnvAttr";
   } else {
     return op->emitError("found unsupported '")
@@ -1039,7 +1039,7 @@ static LogicalResult verifyRegionAttribute(Location loc, Type valueType,
   Attribute attr = attribute.getValue();
 
   if (symbol == spirv::getInterfaceVarABIAttrName()) {
-    auto varABIAttr = llvm::dyn_cast<spirv::InterfaceVarABIAttr>(attr);
+    auto varABIAttr = dyn_cast<spirv::InterfaceVarABIAttr>(attr);
     if (!varABIAttr)
       return emitError(loc, "'")
              << symbol << "' must be a spirv::InterfaceVarABIAttr";

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
index 8575487ff52cc..ba69fa75cf2b8 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
@@ -53,7 +53,7 @@ static bool isDirectInModuleLikeOp(Operation *op) {
 static Type getUnaryOpResultType(Type operandType) {
   Builder builder(operandType.getContext());
   Type resultType = builder.getIntegerType(1);
-  if (auto vecType = llvm::dyn_cast<VectorType>(operandType))
+  if (auto vecType = dyn_cast<VectorType>(operandType))
     return VectorType::get(vecType.getNumElements(), resultType);
   return resultType;
 }

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 938952ed273cd..1962538d804a8 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -52,7 +52,7 @@ LogicalResult spirv::extractValueFromConstOp(Operation *op, int32_t &value) {
     return failure();
   }
   auto valueAttr = constOp.getValue();
-  auto integerValueAttr = llvm::dyn_cast<IntegerAttr>(valueAttr);
+  auto integerValueAttr = dyn_cast<IntegerAttr>(valueAttr);
   if (!integerValueAttr) {
     return failure();
   }
@@ -129,7 +129,7 @@ static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
         parser.parseOptionalAttrDict(result.attributes) ||
         parser.parseColon() || parser.parseType(type))
       return failure();
-    auto fnType = llvm::dyn_cast<FunctionType>(type);
+    auto fnType = dyn_cast<FunctionType>(type);
     if (!fnType) {
       parser.emitError(loc, "expected function type");
       return failure();
@@ -169,11 +169,10 @@ template <typename BlockReadWriteOpTy>
 static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op,
                                                         Value ptr, Value val) {
   auto valType = val.getType();
-  if (auto valVecTy = llvm::dyn_cast<VectorType>(valType))
+  if (auto valVecTy = dyn_cast<VectorType>(valType))
     valType = valVecTy.getElementType();
 
-  if (valType !=
-      llvm::cast<spirv::PointerType>(ptr.getType()).getPointeeType()) {
+  if (valType != cast<spirv::PointerType>(ptr.getType()).getPointeeType()) {
     return op.emitOpError("mismatch in result type and pointer type");
   }
   return success();
@@ -191,7 +190,7 @@ getElementType(Type type, ArrayRef<int32_t> indices,
   }
 
   for (auto index : indices) {
-    if (auto cType = llvm::dyn_cast<spirv::CompositeType>(type)) {
+    if (auto cType = dyn_cast<spirv::CompositeType>(type)) {
       if (cType.hasCompileTimeKnownNumElements() &&
           (index < 0 ||
            static_cast<uint64_t>(index) >= cType.getNumElements())) {
@@ -211,7 +210,7 @@ getElementType(Type type, ArrayRef<int32_t> indices,
 static Type
 getElementType(Type type, Attribute indices,
                function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
-  auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(indices);
+  auto indicesArrayAttr = dyn_cast<ArrayAttr>(indices);
   if (!indicesArrayAttr) {
     emitErrorFn("expected a 32-bit integer array attribute for 'indices'");
     return nullptr;
@@ -223,7 +222,7 @@ getElementType(Type type, Attribute indices,
 
   SmallVector<int32_t, 2> indexVals;
   for (auto indexAttr : indicesArrayAttr) {
-    auto indexIntAttr = llvm::dyn_cast<IntegerAttr>(indexAttr);
+    auto indexIntAttr = dyn_cast<IntegerAttr>(indexAttr);
     if (!indexIntAttr) {
       emitErrorFn("expected an 32-bit integer for index, but found '")
           << indexAttr << "'";
@@ -251,7 +250,7 @@ static Type getElementType(Type type, Attribute indices, OpAsmParser &parser,
 
 template <typename ExtendedBinaryOp>
 static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op) {
-  auto resultType = llvm::cast<spirv::StructType>(op.getType());
+  auto resultType = cast<spirv::StructType>(op.getType());
   if (resultType.getNumElements() != 2)
     return op.emitOpError("expected result struct type containing two members");
 
@@ -276,7 +275,7 @@ static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser,
   if (parser.parseType(resultType))
     return failure();
 
-  auto structType = llvm::dyn_cast<spirv::StructType>(resultType);
+  auto structType = dyn_cast<spirv::StructType>(resultType);
   if (!structType || structType.getNumElements() != 2)
     return parser.emitError(loc, "expected spirv.struct type with two members");
 
@@ -361,7 +360,7 @@ LogicalResult spirv::CompositeConstructOp::verify() {
   }
 
   // Case 2./3./4. -- number of constituents matches the number of elements.
-  auto cType = llvm::cast<spirv::CompositeType>(getType());
+  auto cType = cast<spirv::CompositeType>(getType());
   if (constituents.size() == cType.getNumElements()) {
     for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
       if (constituents[index].getType() != cType.getElementType(index)) {
@@ -374,7 +373,7 @@ LogicalResult spirv::CompositeConstructOp::verify() {
   }
 
   // Case 4. -- check that all constituents add up tp the expected vector type.
-  auto resultType = llvm::dyn_cast<VectorType>(cType);
+  auto resultType = dyn_cast<VectorType>(cType);
   if (!resultType)
     return emitOpError(
         "expected to return a vector or cooperative matrix when the number of "
@@ -382,14 +381,14 @@ LogicalResult spirv::CompositeConstructOp::verify() {
 
   SmallVector<unsigned> sizes;
   for (Value component : constituents) {
-    if (!llvm::isa<VectorType>(component.getType()) &&
+    if (!isa<VectorType>(component.getType()) &&
         !component.getType().isIntOrFloat())
       return emitOpError("operand type mismatch: expected operand to have "
                          "a scalar or vector type, but provided ")
              << component.getType();
 
     Type elementType = component.getType();
-    if (auto vectorType = llvm::dyn_cast<VectorType>(component.getType())) {
+    if (auto vectorType = dyn_cast<VectorType>(component.getType())) {
       sizes.push_back(vectorType.getNumElements());
       elementType = vectorType.getElementType();
     } else {
@@ -455,7 +454,7 @@ void spirv::CompositeExtractOp::print(OpAsmPrinter &printer) {
 }
 
 LogicalResult spirv::CompositeExtractOp::verify() {
-  auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(getIndices());
+  auto indicesArrayAttr = dyn_cast<ArrayAttr>(getIndices());
   auto resultType =
       getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
   if (!resultType)
@@ -500,7 +499,7 @@ ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser,
 }
 
 LogicalResult spirv::CompositeInsertOp::verify() {
-  auto indicesArrayAttr = llvm::dyn_cast<ArrayAttr>(getIndices());
+  auto indicesArrayAttr = dyn_cast<ArrayAttr>(getIndices());
   auto objectType =
       getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
   if (!objectType)
@@ -538,14 +537,14 @@ ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
     return failure();
 
   Type type = NoneType::get(parser.getContext());
-  if (auto typedAttr = llvm::dyn_cast<TypedAttr>(value))
+  if (auto typedAttr = dyn_cast<TypedAttr>(value))
     type = typedAttr.getType();
-  if (llvm::isa<NoneType, TensorType>(type)) {
+  if (isa<NoneType, TensorType>(type)) {
     if (parser.parseColonType(type))
       return failure();
   }
 
-  if (llvm::isa<TensorArmType>(type)) {
+  if (isa<TensorArmType>(type)) {
     if (parser.parseOptionalColon().succeeded())
       if (parser.parseType(type))
         return failure();
@@ -556,7 +555,7 @@ ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
 
 void spirv::ConstantOp::print(OpAsmPrinter &printer) {
   printer << ' ' << getValue();
-  if (llvm::isa<spirv::ArrayType>(getType()))
+  if (isa<spirv::ArrayType>(getType()))
     printer << " : " << getType();
 }
 
@@ -569,19 +568,19 @@ static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
                             "matrix constant, but found ")
              << denseAttr;
   }
-  if (llvm::isa<IntegerAttr, FloatAttr>(value)) {
-    auto valueType = llvm::cast<TypedAttr>(value).getType();
+  if (isa<IntegerAttr, FloatAttr>(value)) {
+    auto valueType = cast<TypedAttr>(value).getType();
     if (valueType != opType)
       return op.emitOpError("result type (")
              << opType << ") does not match value type (" << valueType << ")";
     return success();
   }
-  if (llvm::isa<DenseIntOrFPElementsAttr, SparseElementsAttr>(value)) {
-    auto valueType = llvm::cast<TypedAttr>(value).getType();
+  if (isa<DenseIntOrFPElementsAttr, SparseElementsAttr>(value)) {
+    auto valueType = cast<TypedAttr>(value).getType();
     if (valueType == opType)
       return success();
-    auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType);
-    auto shapedType = llvm::dyn_cast<ShapedType>(valueType);
+    auto arrayType = dyn_cast<spirv::ArrayType>(opType);
+    auto shapedType = dyn_cast<ShapedType>(valueType);
     if (!arrayType)
       return op.emitOpError("result or element type (")
              << opType << ") does not match value type (" << valueType
@@ -589,7 +588,7 @@ static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
 
     int numElements = arrayType.getNumElements();
     auto opElemType = arrayType.getElementType();
-    while (auto t = llvm::dyn_cast<spirv::ArrayType>(opElemType)) {
+    while (auto t = dyn_cast<spirv::ArrayType>(opElemType)) {
       numElements *= t.getNumElements();
       opElemType = t.getElementType();
     }
@@ -610,8 +609,8 @@ static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
     }
     return success();
   }
-  if (auto arrayAttr = llvm::dyn_cast<ArrayAttr>(value)) {
-    auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType);
+  if (auto arrayAttr = dyn_cast<ArrayAttr>(value)) {
+    auto arrayType = dyn_cast<spirv::ArrayType>(opType);
     if (!arrayType)
       return op.emitOpError(
           "must have spirv.array result type for array value");
@@ -635,12 +634,12 @@ LogicalResult spirv::ConstantOp::verify() {
 
 bool spirv::ConstantOp::isBuildableWith(Type type) {
   // Must be valid SPIR-V type first.
-  if (!llvm::isa<spirv::SPIRVType>(type))
+  if (!isa<spirv::SPIRVType>(type))
     return false;
 
   if (isa<SPIRVDialect>(type.getDialect())) {
     // TODO: support constant struct
-    return llvm::isa<spirv::ArrayType>(type);
+    return isa<spirv::ArrayType>(type);
   }
 
   return true;
@@ -648,7 +647,7 @@ bool spirv::ConstantOp::isBuildableWith(Type type) {
 
 spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
                                              OpBuilder &builder) {
-  if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
+  if (auto intType = dyn_cast<IntegerType>(type)) {
     unsigned width = intType.getWidth();
     if (width == 1)
       return spirv::ConstantOp::create(builder, loc, type,
@@ -656,19 +655,19 @@ spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
     return spirv::ConstantOp::create(
         builder, loc, type, builder.getIntegerAttr(type, APInt(width, 0)));
   }
-  if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
+  if (auto floatType = dyn_cast<FloatType>(type)) {
     return spirv::ConstantOp::create(builder, loc, type,
                                      builder.getFloatAttr(floatType, 0.0));
   }
-  if (auto vectorType = llvm::dyn_cast<VectorType>(type)) {
+  if (auto vectorType = dyn_cast<VectorType>(type)) {
     Type elemType = vectorType.getElementType();
-    if (llvm::isa<IntegerType>(elemType)) {
+    if (isa<IntegerType>(elemType)) {
       return spirv::ConstantOp::create(
           builder, loc, type,
           DenseElementsAttr::get(vectorType,
                                  IntegerAttr::get(elemType, 0).getValue()));
     }
-    if (llvm::isa<FloatType>(elemType)) {
+    if (isa<FloatType>(elemType)) {
       return spirv::ConstantOp::create(
           builder, loc, type,
           DenseFPElementsAttr::get(vectorType,
@@ -681,7 +680,7 @@ spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
 
 spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
                                             OpBuilder &builder) {
-  if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
+  if (auto intType = dyn_cast<IntegerType>(type)) {
     unsigned width = intType.getWidth();
     if (width == 1)
       return spirv::ConstantOp::create(builder, loc, type,
@@ -689,19 +688,19 @@ spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
     return spirv::ConstantOp::create(
         builder, loc, type, builder.getIntegerAttr(type, APInt(width, 1)));
   }
-  if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
+  if (auto floatType = dyn_cast<FloatType>(type)) {
     return spirv::ConstantOp::create(builder, loc, type,
                                      builder.getFloatAttr(floatType, 1.0));
   }
-  if (auto vectorType = llvm::dyn_cast<VectorType>(type)) {
+  if (auto vectorType = dyn_cast<VectorType>(type)) {
     Type elemType = vectorType.getElementType();
-    if (llvm::isa<IntegerType>(elemType)) {
+    if (isa<IntegerType>(elemType)) {
       return spirv::ConstantOp::create(
           builder, loc, type,
           DenseElementsAttr::get(vectorType,
                                  IntegerAttr::get(elemType, 1).getValue()));
     }
-    if (llvm::isa<FloatType>(elemType)) {
+    if (isa<FloatType>(elemType)) {
       return spirv::ConstantOp::create(
           builder, loc, type,
           DenseFPElementsAttr::get(vectorType,
@@ -720,9 +719,9 @@ void mlir::spirv::ConstantOp::getAsmResultNames(
   llvm::raw_svector_ostream specialName(specialNameBuffer);
   specialName << "cst";
 
-  IntegerType intTy = llvm::dyn_cast<IntegerType>(type);
+  IntegerType intTy = dyn_cast<IntegerType>(type);
 
-  if (IntegerAttr intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
+  if (IntegerAttr intCst = dyn_cast<IntegerAttr>(getValue())) {
     assert(intTy);
 
     if (intTy.getWidth() == 1) {
@@ -738,18 +737,17 @@ void mlir::spirv::ConstantOp::getAsmResultNames(
     }
   }
 
-  if (intTy || llvm::isa<FloatType>(type)) {
+  if (intTy || isa<FloatType>(type)) {
     specialName << '_' << type;
   }
 
-  if (auto vecType = llvm::dyn_cast<VectorType>(type)) {
+  if (auto vecType = dyn_cast<VectorType>(type)) {
     specialName << "_vec_";
     specialName << vecType.getDimSize(0);
 
     Type elementType = vecType.getElementType();
 
-    if (llvm::isa<IntegerType>(elementType) ||
-        llvm::isa<FloatType>(elementType)) {
+    if (isa<IntegerType>(elementType) || isa<FloatType>(elementType)) {
       specialName << "x" << elementType;
     }
   }
@@ -903,7 +901,7 @@ ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser,
     if (parser.parseAttribute(value, i32Type, "value", attr)) {
       return failure();
     }
-    values.push_back(llvm::cast<IntegerAttr>(value).getInt());
+    values.push_back(cast<IntegerAttr>(value).getInt());
   }
   StringRef valuesAttrName =
       spirv::ExecutionModeOp::getValuesAttrName(result.name);
@@ -1005,7 +1003,7 @@ LogicalResult spirv::FuncOp::verifyType() {
 
   auto hasDecorationAttr = [&](spirv::Decoration decoration,
                                unsigned argIndex) {
-    auto func = llvm::cast<FunctionOpInterface>(getOperation());
+    auto func = cast<FunctionOpInterface>(getOperation());
     for (auto argAttr : cast<FunctionOpInterface>(func).getArgAttrs(argIndex)) {
       if (argAttr.getName() != spirv::DecorationAttr::name)
         continue;
@@ -1224,7 +1222,7 @@ ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
   if (parser.parseColonType(type)) {
     return failure();
   }
-  if (!llvm::isa<spirv::PointerType>(type)) {
+  if (!isa<spirv::PointerType>(type)) {
     return parser.emitError(loc, "expected spirv.ptr type");
   }
   result.addAttribute(typeAttrName, TypeAttr::get(type));
@@ -1257,7 +1255,7 @@ void spirv::GlobalVariableOp::print(OpAsmPrinter &printer) {
 }
 
 LogicalResult spirv::GlobalVariableOp::verify() {
-  if (!llvm::isa<spirv::PointerType>(getType()))
+  if (!isa<spirv::PointerType>(getType()))
     return emitOpError("result must be of a !spv.ptr type");
 
   // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
@@ -1325,7 +1323,7 @@ ParseResult spirv::INTELSubgroupBlockWriteOp::parse(OpAsmParser &parser,
   }
 
   auto ptrType = spirv::PointerType::get(elementType, storageClass);
-  if (auto valVecTy = llvm::dyn_cast<VectorType>(elementType))
+  if (auto valVecTy = dyn_cast<VectorType>(elementType))
     ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
 
   if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
@@ -1539,7 +1537,7 @@ LogicalResult spirv::ModuleOp::verifyRegions() {
       }
       if (auto interface = entryPointOp.getInterface()) {
         for (Attribute varRef : interface) {
-          auto varSymRef = llvm::dyn_cast<FlatSymbolRefAttr>(varRef);
+          auto varSymRef = dyn_cast<FlatSymbolRefAttr>(varRef);
           if (!varSymRef) {
             return entryPointOp.emitError(
                        "expected symbol reference for interface "
@@ -1660,9 +1658,9 @@ LogicalResult spirv::SpecConstantOp::verify() {
       return emitOpError("SpecId cannot be negative");
 
   auto value = getDefaultValue();
-  if (llvm::isa<IntegerAttr, FloatAttr>(value)) {
+  if (isa<IntegerAttr, FloatAttr>(value)) {
     // Make sure bitwidth is allowed.
-    if (!llvm::isa<spirv::SPIRVType>(value.getType()))
+    if (!isa<spirv::SPIRVType>(value.getType()))
       return emitOpError("default value bitwidth disallowed");
     return success();
   }
@@ -1675,7 +1673,7 @@ LogicalResult spirv::SpecConstantOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult spirv::VectorShuffleOp::verify() {
-  VectorType resultType = llvm::cast<VectorType>(getType());
+  VectorType resultType = cast<VectorType>(getType());
 
   size_t numResultElements = resultType.getNumElements();
   if (numResultElements != getComponents().size())
@@ -1685,8 +1683,8 @@ LogicalResult spirv::VectorShuffleOp::verify() {
            << getComponents().size() << ")";
 
   size_t totalSrcElements =
-      llvm::cast<VectorType>(getVector1().getType()).getNumElements() +
-      llvm::cast<VectorType>(getVector2().getType()).getNumElements();
+      cast<VectorType>(getVector1().getType()).getNumElements() +
+      cast<VectorType>(getVector2().getType()).getNumElements();
 
   for (const auto &selector : getComponents().getAsValueRange<IntegerAttr>()) {
     uint32_t index = selector.getZExtValue();
@@ -1725,8 +1723,8 @@ LogicalResult spirv::MatrixTimesScalarOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult spirv::TransposeOp::verify() {
-  auto inputMatrix = llvm::cast<spirv::MatrixType>(getMatrix().getType());
-  auto resultMatrix = llvm::cast<spirv::MatrixType>(getResult().getType());
+  auto inputMatrix = cast<spirv::MatrixType>(getMatrix().getType());
+  auto resultMatrix = cast<spirv::MatrixType>(getResult().getType());
 
   // Verify that the input and output matrices have correct shapes.
   if (inputMatrix.getNumRows() != resultMatrix.getNumColumns())
@@ -1750,9 +1748,9 @@ LogicalResult spirv::TransposeOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult spirv::MatrixTimesVectorOp::verify() {
-  auto matrixType = llvm::cast<spirv::MatrixType>(getMatrix().getType());
-  auto vectorType = llvm::cast<VectorType>(getVector().getType());
-  auto resultType = llvm::cast<VectorType>(getType());
+  auto matrixType = cast<spirv::MatrixType>(getMatrix().getType());
+  auto vectorType = cast<VectorType>(getVector().getType());
+  auto resultType = cast<VectorType>(getType());
 
   if (matrixType.getNumColumns() != vectorType.getNumElements())
     return emitOpError("matrix columns (")
@@ -1775,9 +1773,9 @@ LogicalResult spirv::MatrixTimesVectorOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult spirv::VectorTimesMatrixOp::verify() {
-  auto vectorType = llvm::cast<VectorType>(getVector().getType());
-  auto matrixType = llvm::cast<spirv::MatrixType>(getMatrix().getType());
-  auto resultType = llvm::cast<VectorType>(getType());
+  auto vectorType = cast<VectorType>(getVector().getType());
+  auto matrixType = cast<spirv::MatrixType>(getMatrix().getType());
+  auto resultType = cast<VectorType>(getType());
 
   if (matrixType.getNumRows() != vectorType.getNumElements())
     return emitOpError("number of components in vector must equal the number "
@@ -1799,9 +1797,9 @@ LogicalResult spirv::VectorTimesMatrixOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult spirv::MatrixTimesMatrixOp::verify() {
-  auto leftMatrix = llvm::cast<spirv::MatrixType>(getLeftmatrix().getType());
-  auto rightMatrix = llvm::cast<spirv::MatrixType>(getRightmatrix().getType());
-  auto resultMatrix = llvm::cast<spirv::MatrixType>(getResult().getType());
+  auto leftMatrix = cast<spirv::MatrixType>(getLeftmatrix().getType());
+  auto rightMatrix = cast<spirv::MatrixType>(getRightmatrix().getType());
+  auto resultMatrix = cast<spirv::MatrixType>(getResult().getType());
 
   // left matrix columns' count and right matrix rows' count must be equal
   if (leftMatrix.getNumColumns() != rightMatrix.getNumRows())
@@ -1886,14 +1884,14 @@ void spirv::SpecConstantCompositeOp::print(OpAsmPrinter &printer) {
 }
 
 LogicalResult spirv::SpecConstantCompositeOp::verify() {
-  auto cType = llvm::dyn_cast<spirv::CompositeType>(getType());
+  auto cType = dyn_cast<spirv::CompositeType>(getType());
   auto constituents = this->getConstituents().getValue();
 
   if (!cType)
     return emitError("result type must be a composite type, but provided ")
            << getType();
 
-  if (llvm::isa<spirv::CooperativeMatrixType>(cType))
+  if (isa<spirv::CooperativeMatrixType>(cType))
     return emitError("unsupported composite type  ") << cType;
   if (constituents.size() != cType.getNumElements())
     return emitError("has incorrect number of operands: expected ")
@@ -1901,7 +1899,7 @@ LogicalResult spirv::SpecConstantCompositeOp::verify() {
            << constituents.size();
 
   for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
-    auto constituent = llvm::cast<FlatSymbolRefAttr>(constituents[index]);
+    auto constituent = cast<FlatSymbolRefAttr>(constituents[index]);
 
     auto constituentSpecConstOp =
         dyn_cast<spirv::SpecConstantOp>(SymbolTable::lookupNearestSymbolFrom(
@@ -2042,19 +2040,19 @@ LogicalResult spirv::SpecConstantOperationOp::verifyRegions() {
 
 LogicalResult spirv::GLFrexpStructOp::verify() {
   spirv::StructType structTy =
-      llvm::dyn_cast<spirv::StructType>(getResult().getType());
+      dyn_cast<spirv::StructType>(getResult().getType());
 
   if (structTy.getNumElements() != 2)
     return emitError("result type must be a struct type with two memebers");
 
   Type significandTy = structTy.getElementType(0);
   Type exponentTy = structTy.getElementType(1);
-  VectorType exponentVecTy = llvm::dyn_cast<VectorType>(exponentTy);
-  IntegerType exponentIntTy = llvm::dyn_cast<IntegerType>(exponentTy);
+  VectorType exponentVecTy = dyn_cast<VectorType>(exponentTy);
+  IntegerType exponentIntTy = dyn_cast<IntegerType>(exponentTy);
 
   Type operandTy = getOperand().getType();
-  VectorType operandVecTy = llvm::dyn_cast<VectorType>(operandTy);
-  FloatType operandFTy = llvm::dyn_cast<FloatType>(operandTy);
+  VectorType operandVecTy = dyn_cast<VectorType>(operandTy);
+  FloatType operandFTy = dyn_cast<FloatType>(operandTy);
 
   if (significandTy != operandTy)
     return emitError("member zero of the resulting struct type must be the "
@@ -2062,7 +2060,7 @@ LogicalResult spirv::GLFrexpStructOp::verify() {
 
   if (exponentVecTy) {
     IntegerType componentIntTy =
-        llvm::dyn_cast<IntegerType>(exponentVecTy.getElementType());
+        dyn_cast<IntegerType>(exponentVecTy.getElementType());
     if (!componentIntTy || componentIntTy.getWidth() != 32)
       return emitError("member one of the resulting struct type must"
                        "be a scalar or vector of 32 bit integer type");
@@ -2091,12 +2089,11 @@ LogicalResult spirv::GLLdexpOp::verify() {
   Type significandType = getX().getType();
   Type exponentType = getExp().getType();
 
-  if (llvm::isa<FloatType>(significandType) !=
-      llvm::isa<IntegerType>(exponentType))
+  if (isa<FloatType>(significandType) != isa<IntegerType>(exponentType))
     return emitOpError("operands must both be scalars or vectors");
 
   auto getNumElements = [](Type type) -> unsigned {
-    if (auto vectorType = llvm::dyn_cast<VectorType>(type))
+    if (auto vectorType = dyn_cast<VectorType>(type))
       return vectorType.getNumElements();
     return 1;
   };
@@ -2138,7 +2135,7 @@ LogicalResult spirv::ShiftRightLogicalOp::verify() {
 LogicalResult spirv::VectorTimesScalarOp::verify() {
   if (getVector().getType() != getType())
     return emitOpError("vector operand and result type mismatch");
-  auto scalarType = llvm::cast<VectorType>(getType()).getElementType();
+  auto scalarType = cast<VectorType>(getType()).getElementType();
   if (getScalar().getType() != scalarType)
     return emitOpError("scalar operand and result element type match");
   return success();

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
index f28d386f8874d..772219c8db654 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
@@ -75,11 +75,11 @@ parseEnumStrAttr(EnumClass &value, OpAsmParser &parser,
   if (parser.parseAttribute(attrVal, parser.getBuilder().getNoneType(),
                             attrName, attr))
     return failure();
-  if (!llvm::isa<StringAttr>(attrVal))
+  if (!isa<StringAttr>(attrVal))
     return parser.emitError(loc, "expected ")
            << attrName << " attribute specified as string";
-  auto attrOptional = spirv::symbolizeEnum<EnumClass>(
-      llvm::cast<StringAttr>(attrVal).getValue());
+  auto attrOptional =
+      spirv::symbolizeEnum<EnumClass>(cast<StringAttr>(attrVal).getValue());
   if (!attrOptional)
     return parser.emitError(loc, "invalid ")
            << attrName << " attribute specification: " << attrVal;

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index d1e275d590f78..53a48abe5ad02 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -178,17 +178,17 @@ unsigned ArrayType::getArrayStride() const { return getImpl()->stride; }
 //===----------------------------------------------------------------------===//
 
 bool CompositeType::classof(Type type) {
-  if (auto vectorType = llvm::dyn_cast<VectorType>(type))
+  if (auto vectorType = dyn_cast<VectorType>(type))
     return isValid(vectorType);
-  return llvm::isa<spirv::ArrayType, spirv::CooperativeMatrixType,
-                   spirv::MatrixType, spirv::RuntimeArrayType,
-                   spirv::StructType, spirv::TensorArmType>(type);
+  return isa<spirv::ArrayType, spirv::CooperativeMatrixType, spirv::MatrixType,
+             spirv::RuntimeArrayType, spirv::StructType, spirv::TensorArmType>(
+      type);
 }
 
 bool CompositeType::isValid(VectorType type) {
   return type.getRank() == 1 &&
          llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) &&
-         llvm::isa<ScalarType>(type.getElementType());
+         isa<ScalarType>(type.getElementType());
 }
 
 Type CompositeType::getElementType(unsigned index) const {
@@ -210,7 +210,7 @@ unsigned CompositeType::getNumElements() const {
 }
 
 bool CompositeType::hasCompileTimeKnownNumElements() const {
-  return !llvm::isa<CooperativeMatrixType, RuntimeArrayType>(*this);
+  return !isa<CooperativeMatrixType, RuntimeArrayType>(*this);
 }
 
 void TypeCapabilityVisitor::addConcrete(VectorType type) {
@@ -529,10 +529,10 @@ void TypeCapabilityVisitor::addConcrete(RuntimeArrayType type) {
 //===----------------------------------------------------------------------===//
 
 bool ScalarType::classof(Type type) {
-  if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
+  if (auto floatType = dyn_cast<FloatType>(type)) {
     return isValid(floatType);
   }
-  if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
+  if (auto intType = dyn_cast<IntegerType>(type)) {
     return isValid(intType);
   }
   return false;
@@ -676,19 +676,19 @@ void TypeCapabilityVisitor::addConcrete(ScalarType type) {
 
 bool SPIRVType::classof(Type type) {
   // Allow SPIR-V dialect types
-  if (llvm::isa<SPIRVDialect>(type.getDialect()))
+  if (isa<SPIRVDialect>(type.getDialect()))
     return true;
-  if (llvm::isa<ScalarType>(type))
+  if (isa<ScalarType>(type))
     return true;
-  if (auto vectorType = llvm::dyn_cast<VectorType>(type))
+  if (auto vectorType = dyn_cast<VectorType>(type))
     return CompositeType::isValid(vectorType);
-  if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(type))
-    return llvm::isa<ScalarType>(tensorArmType.getElementType());
+  if (auto tensorArmType = dyn_cast<TensorArmType>(type))
+    return isa<ScalarType>(tensorArmType.getElementType());
   return false;
 }
 
 bool SPIRVType::isScalarOrVector() {
-  return isIntOrFloat() || llvm::isa<VectorType>(*this);
+  return isIntOrFloat() || isa<VectorType>(*this);
 }
 
 void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
@@ -1190,7 +1190,7 @@ MatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
     return emitError() << "matrix columns must be vectors of floats";
 
   /// The underlying vectors (columns) must be of size 2, 3, or 4
-  ArrayRef<int64_t> columnShape = llvm::cast<VectorType>(columnType).getShape();
+  ArrayRef<int64_t> columnShape = cast<VectorType>(columnType).getShape();
   if (columnShape.size() != 1)
     return emitError() << "matrix columns must be 1D vectors";
 
@@ -1202,8 +1202,8 @@ MatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
 
 /// Returns true if the matrix elements are vectors of float elements
 bool MatrixType::isValidColumnType(Type columnType) {
-  if (auto vectorType = llvm::dyn_cast<VectorType>(columnType)) {
-    if (llvm::isa<FloatType>(vectorType.getElementType()))
+  if (auto vectorType = dyn_cast<VectorType>(columnType)) {
+    if (isa<FloatType>(vectorType.getElementType()))
       return true;
   }
   return false;
@@ -1212,13 +1212,13 @@ bool MatrixType::isValidColumnType(Type columnType) {
 Type MatrixType::getColumnType() const { return getImpl()->columnType; }
 
 Type MatrixType::getElementType() const {
-  return llvm::cast<VectorType>(getImpl()->columnType).getElementType();
+  return cast<VectorType>(getImpl()->columnType).getElementType();
 }
 
 unsigned MatrixType::getNumColumns() const { return getImpl()->columnCount; }
 
 unsigned MatrixType::getNumRows() const {
-  return llvm::cast<VectorType>(getImpl()->columnType).getShape()[0];
+  return cast<VectorType>(getImpl()->columnType).getShape()[0];
 }
 
 unsigned MatrixType::getNumElements() const {

diff  --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 50883d9ed5e75..ce7e7dc4116c8 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -243,7 +243,7 @@ static LogicalResult deserializeCacheControlDecoration(
   auto value = opBuilder.getAttr<AttrTy>(cacheLevel, cacheControlAttr);
   SmallVector<Attribute> attrs;
   if (auto attrList =
-          llvm::dyn_cast_or_null<ArrayAttr>(decorations[words[0]].get(symbol)))
+          dyn_cast_or_null<ArrayAttr>(decorations[words[0]].get(symbol)))
     llvm::append_range(attrs, attrList);
   attrs.push_back(value);
   decorations[words[0]].set(symbol, opBuilder.getArrayAttr(attrs));
@@ -326,7 +326,7 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
         static_cast<::mlir::spirv::LinkageType>(words[wordIndex++]));
     auto linkageAttr = opBuilder.getAttr<::mlir::spirv::LinkageAttributesAttr>(
         StringAttr::get(context, linkageName), linkageTypeAttr);
-    decorations[words[0]].set(symbol, llvm::dyn_cast<Attribute>(linkageAttr));
+    decorations[words[0]].set(symbol, dyn_cast<Attribute>(linkageAttr));
     break;
   }
   case spirv::Decoration::Aliased:
@@ -1511,10 +1511,10 @@ spirv::Deserializer::processTensorARMType(ArrayRef<uint32_t> operands) {
     return emitError(unknownLoc, "OpTypeTensorARM shape must come from a "
                                  "constant instruction of type OpTypeArray");
 
-  ArrayAttr shapeArrayAttr = llvm::dyn_cast<ArrayAttr>(shapeInfo->first);
+  ArrayAttr shapeArrayAttr = dyn_cast<ArrayAttr>(shapeInfo->first);
   SmallVector<int64_t, 1> shape;
   for (auto dimAttr : shapeArrayAttr.getValue()) {
-    auto dimIntAttr = llvm::dyn_cast<IntegerAttr>(dimAttr);
+    auto dimIntAttr = dyn_cast<IntegerAttr>(dimAttr);
     if (!dimIntAttr)
       return emitError(unknownLoc, "OpTypeTensorARM shape has an invalid "
                                    "dimension size");

diff  --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index 6397d2c005c16..b78fac532d8c5 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -876,7 +876,7 @@ Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
   if (values) {
     for (auto &intVal : values.getValue()) {
       operands.push_back(static_cast<uint32_t>(
-          llvm::cast<IntegerAttr>(intVal).getValue().getZExtValue()));
+          cast<IntegerAttr>(intVal).getValue().getZExtValue()));
     }
   }
   encodeInstructionInto(executionModes, spirv::Opcode::OpExecutionMode,

diff  --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index c879a2b3e0207..c29d20f755332 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -312,7 +312,7 @@ LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
   case spirv::Decoration::LinkageAttributes: {
     // Get the value of the Linkage Attributes
     // e.g., LinkageAttributes=["linkageName", linkageType].
-    auto linkageAttr = llvm::dyn_cast<spirv::LinkageAttributesAttr>(attr);
+    auto linkageAttr = dyn_cast<spirv::LinkageAttributesAttr>(attr);
     auto linkageName = linkageAttr.getLinkageName();
     auto linkageType = linkageAttr.getLinkageType().getValue();
     // Encode the Linkage Name (string literal to uint32_t).
@@ -822,7 +822,7 @@ LogicalResult Serializer::prepareBasicType(
     return success();
   }
 
-  if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(type)) {
+  if (auto tensorArmType = dyn_cast<TensorArmType>(type)) {
     uint32_t elementTypeID = 0;
     uint32_t rank = 0;
     uint32_t shapeID = 0;


        


More information about the Mlir-commits mailing list