[Mlir-commits] [mlir] [mlir][SPIR-V] Add support for SPV_INTEL_masked_gather_scatter extension (PR #189099)

Arseniy Obolenskiy llvmlistbot at llvm.org
Mon Apr 13 01:33:24 PDT 2026


https://github.com/aobolensk updated https://github.com/llvm/llvm-project/pull/189099

>From bdc658c95274ad076dd14593492abae33113caec Mon Sep 17 00:00:00 2001
From: Arseniy Obolenskiy <arseniy.obolenskiy at amd.com>
Date: Fri, 27 Mar 2026 17:29:54 +0100
Subject: [PATCH 1/3] [mlir][SPIR-V] Add MaskedGather/MaskedScatter ops and
 VectorOfPointerType for SPV_INTEL_masked_gather_scatter extension

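Add MLIR SPIR-V dialect support (spirv.INTEL.MaskedGather/MaskedScatter ops and the !spirv.vecptr type) for the SPV_INTEL_masked_gather_scatter extension, which was implemented in #185418.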
---
 .../mlir/Dialect/SPIRV/IR/SPIRVBase.td        |  25 +++-
 .../mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td | 113 +++++++++++++++++
 .../mlir/Dialect/SPIRV/IR/SPIRVTypes.h        |  20 +++
 mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp    |  49 ++++++-
 mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp        |  65 ++++++++++
 mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp      |  56 +++++++-
 .../SPIRV/Deserialization/Deserializer.cpp    |   6 +-
 .../Target/SPIRV/Serialization/Serializer.cpp |  12 ++
 mlir/test/Dialect/SPIRV/IR/composite-ops.mlir |   2 +-
 mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir | 120 ++++++++++++++++++
 mlir/test/Dialect/SPIRV/IR/structure-ops.mlir |   2 +-
 mlir/test/Target/SPIRV/intel-ext-ops.mlir     |  60 +++++++++
 12 files changed, 517 insertions(+), 13 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 9f9e2f5f9a677..d9a573b59c368 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -407,6 +407,7 @@ def SPV_INTEL_split_barrier                      : I32EnumAttrCase<"SPV_INTEL_sp
 def SPV_INTEL_bfloat16_conversion                : I32EnumAttrCase<"SPV_INTEL_bfloat16_conversion", 4031>;
 def SPV_INTEL_cache_controls                     : I32EnumAttrCase<"SPV_INTEL_cache_controls", 4032>;
 def SPV_INTEL_tensor_float32_conversion          : I32EnumAttrCase<"SPV_INTEL_tensor_float32_conversion", 4033>;
+def SPV_INTEL_masked_gather_scatter              : I32EnumAttrCase<"SPV_INTEL_masked_gather_scatter", 4034>;
 
 def SPV_NV_compute_shader_derivatives    : I32EnumAttrCase<"SPV_NV_compute_shader_derivatives", 5000>;
 def SPV_NV_cooperative_matrix            : I32EnumAttrCase<"SPV_NV_cooperative_matrix", 5001>;
@@ -472,6 +473,7 @@ def SPIRV_ExtensionAttr :
       SPV_INTEL_memory_access_aliasing, SPV_INTEL_split_barrier,
       SPV_INTEL_bfloat16_conversion, SPV_INTEL_cache_controls,
       SPV_INTEL_tensor_float32_conversion,
+      SPV_INTEL_masked_gather_scatter,
       SPV_NV_compute_shader_derivatives, SPV_NV_cooperative_matrix,
       SPV_NV_fragment_shader_barycentric, SPV_NV_geometry_shader_passthrough,
       SPV_NV_ray_tracing, SPV_NV_sample_mask_override_coverage,
@@ -1481,6 +1483,12 @@ def SPIRV_C_TensorFloat32RoundingINTEL                       : I32EnumAttrCase<"
   ];
 }
 
+def SPIRV_C_MaskedGatherScatterINTEL                         : I32EnumAttrCase<"MaskedGatherScatterINTEL", 6427> {
+  list<Availability> availability = [
+    Extension<[SPV_INTEL_masked_gather_scatter]>
+  ];
+}
+
 def SPIRV_C_CacheControlsINTEL : I32EnumAttrCase<"CacheControlsINTEL", 6441> {
   list<Availability> availability = [
     Extension<[SPV_INTEL_cache_controls]>
@@ -1590,7 +1598,8 @@ def SPIRV_CapabilityAttr :
       SPIRV_C_ShaderStereoViewNV, SPIRV_C_Bfloat16ConversionINTEL,
       SPIRV_C_CacheControlsINTEL, SPIRV_C_BFloat16TypeKHR,
       SPIRV_C_BFloat16DotProductKHR, SPIRV_C_BFloat16CooperativeMatrixKHR,
-      SPIRV_C_TensorFloat32RoundingINTEL, SPIRV_C_Float8EXT
+      SPIRV_C_TensorFloat32RoundingINTEL, SPIRV_C_MaskedGatherScatterINTEL,
+      SPIRV_C_Float8EXT
     ]>;
 
 def SPIRV_AM_Logical                 : I32EnumAttrCase<"Logical", 0>;
@@ -4257,6 +4266,7 @@ def SPIRV_IsRTArrayType : CPred<"::llvm::isa<::mlir::spirv::RuntimeArrayType>($_
 def SPIRV_IsSampledImageType : CPred<"::llvm::isa<::mlir::spirv::SampledImageType>($_self)">;
 def SPIRV_IsStructType : CPred<"::llvm::isa<::mlir::spirv::StructType>($_self)">;
 def SPIRV_IsTensorArmType : CPred<"::llvm::isa<::mlir::spirv::TensorArmType>($_self)">;
+def SPIRV_IsVectorOfPointerType : CPred<"::llvm::isa<::mlir::spirv::VectorOfPointerType>($_self)">;
 
 // See https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_types
 // for the definition of the following types and type categories.
@@ -4300,18 +4310,22 @@ def SPIRV_AnySampledImage : DialectType<SPIRV_Dialect, SPIRV_IsSampledImageType,
                                 "any SPIR-V sampled image type">;
 def SPIRV_AnyTensorArm : DialectType<SPIRV_Dialect, SPIRV_IsTensorArmType,
                                  "any SPIR-V tensorArm type">;
+def SPIRV_AnyVectorOfPointer : DialectType<SPIRV_Dialect,
+                                 SPIRV_IsVectorOfPointerType,
+                                 "any SPIR-V vector of pointer type">;
 
 def SPIRV_Numerical : AnyTypeOf<[SPIRV_Integer, SPIRV_AnyFloat]>;
 def SPIRV_Scalar : AnyTypeOf<[SPIRV_Numerical, SPIRV_Bool]>;
 def SPIRV_Aggregate : AnyTypeOf<[SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct]>;
 def SPIRV_Composite :
     AnyTypeOf<[SPIRV_Vector, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
-               SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnyTensorArm]>;
+               SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnyTensorArm,
+               SPIRV_AnyVectorOfPointer]>;
 def SPIRV_Type : AnyTypeOf<[
     SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_AnyFloat, SPIRV_Vector,
     SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
     SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage,
-    SPIRV_AnyImage, SPIRV_AnyTensorArm
+    SPIRV_AnyImage, SPIRV_AnyTensorArm, SPIRV_AnyVectorOfPointer
   ]>;
 
 def SPIRV_SignedInt : SignedIntOfWidths<[8, 16, 32, 64]>;
@@ -4651,6 +4665,8 @@ def SPIRV_OC_OpControlBarrierWaitINTEL        : I32EnumAttrCase<"OpControlBarrie
 def SPIRV_OC_OpGroupIMulKHR                   : I32EnumAttrCase<"OpGroupIMulKHR", 6401>;
 def SPIRV_OC_OpGroupFMulKHR                   : I32EnumAttrCase<"OpGroupFMulKHR", 6402>;
 def SPIRV_OC_OpRoundFToTF32INTEL              : I32EnumAttrCase<"OpRoundFToTF32INTEL", 6426>;
+def SPIRV_OC_OpMaskedGatherINTEL              : I32EnumAttrCase<"OpMaskedGatherINTEL", 6428>;
+def SPIRV_OC_OpMaskedScatterINTEL             : I32EnumAttrCase<"OpMaskedScatterINTEL", 6429>;
 
 def SPIRV_OpcodeAttr :
     SPIRV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", "opcode", [
@@ -4763,7 +4779,8 @@ def SPIRV_OpcodeAttr :
       SPIRV_OC_OpConvertFToBF16INTEL, SPIRV_OC_OpConvertBF16ToFINTEL,
       SPIRV_OC_OpControlBarrierArriveINTEL, SPIRV_OC_OpControlBarrierWaitINTEL,
       SPIRV_OC_OpGroupIMulKHR, SPIRV_OC_OpGroupFMulKHR,
-      SPIRV_OC_OpRoundFToTF32INTEL
+      SPIRV_OC_OpRoundFToTF32INTEL,
+      SPIRV_OC_OpMaskedGatherINTEL, SPIRV_OC_OpMaskedScatterINTEL
     ]>;
 
 // End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
index 2a7fa534cc3dc..0ea57960fec74 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
@@ -250,6 +250,119 @@ def SPIRV_INTELControlBarrierWaitOp
 }
 
 
+// -----
+
+def SPIRV_INTELMaskedGatherOp : SPIRV_IntelVendorOp<"MaskedGather", []> {
+  let summary = "See extension SPV_INTEL_masked_gather_scatter";
+
+  let description = [{
+    Reads values from a vector of pointers gathering them into a result
+    vector. Lanes where the mask is false receive the corresponding
+    FillEmpty value.
+
+    Result Type must be a vector of numerical type.
+
+    PtrVector must be a vector of pointers to the scalar element type of
+    Result Type. It must have the same number of components as Result Type.
+
+    Alignment is the known minimum alignment in bytes of each pointer in
+    PtrVector.
+
+    Mask must be a vector of boolean type with the same number of components
+    as Result Type.
+
+    FillEmpty must have the same type as Result Type.
+
+    #### Example:
+
+    ```mlir
+    %result = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
+              : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>, i32,
+                vector<4xi1>, vector<4xf32> -> vector<4xf32>
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPIRV_V_1_0>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[SPV_INTEL_masked_gather_scatter]>,
+    Capability<[SPIRV_C_MaskedGatherScatterINTEL]>
+  ];
+
+  let arguments = (ins
+    SPIRV_AnyVectorOfPointer:$ptr_vector,
+    SPIRV_Int32:$alignment,
+    SPIRV_Vector:$mask,
+    SPIRV_Vector:$fill_empty
+  );
+
+  let results = (outs
+    SPIRV_Vector:$result
+  );
+
+  let assemblyFormat = [{
+    $ptr_vector `,` $alignment `,` $mask `,` $fill_empty attr-dict `:`
+      type($ptr_vector) `,` type($alignment) `,`
+      type($mask) `,` type($fill_empty) `->` type($result)
+  }];
+
+  let hasVerifier = 1;
+}
+
+// -----
+
+def SPIRV_INTELMaskedScatterOp : SPIRV_IntelVendorOp<"MaskedScatter", []> {
+  let summary = "See extension SPV_INTEL_masked_gather_scatter";
+
+  let description = [{
+    Writes values from a vector into memory locations pointed to by a
+    vector of pointers. Only lanes where the mask is true are written.
+
+    PtrVector must be a vector of pointers to the scalar element type of
+    InputVector. It must have the same number of components as InputVector.
+
+    Alignment is the known minimum alignment in bytes of each pointer in
+    PtrVector.
+
+    Mask must be a vector of boolean type with the same number of components
+    as InputVector.
+
+    InputVector is the vector of values to scatter into memory.
+
+    #### Example:
+
+    ```mlir
+    spirv.INTEL.MaskedScatter %ptrs, %alignment, %mask, %values
+              : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>, i32,
+                vector<4xi1>, vector<4xf32>
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPIRV_V_1_0>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[SPV_INTEL_masked_gather_scatter]>,
+    Capability<[SPIRV_C_MaskedGatherScatterINTEL]>
+  ];
+
+  let arguments = (ins
+    SPIRV_AnyVectorOfPointer:$ptr_vector,
+    SPIRV_Int32:$alignment,
+    SPIRV_Vector:$mask,
+    SPIRV_Vector:$input_vector
+  );
+
+  let results = (outs);
+
+  let assemblyFormat = [{
+    $ptr_vector `,` $alignment `,` $mask `,` $input_vector attr-dict `:`
+      type($ptr_vector) `,` type($alignment) `,`
+      type($mask) `,` type($input_vector)
+  }];
+
+  let hasVerifier = 1;
+}
+
 // -----
 
 #endif // MLIR_DIALECT_SPIRV_IR_INTEL_EXT_OPS
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 4a0c29d4b5d90..085482b6099d7 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -36,6 +36,7 @@ struct PointerTypeStorage;
 struct RuntimeArrayTypeStorage;
 struct SampledImageTypeStorage;
 struct StructTypeStorage;
+struct VectorOfPointerTypeStorage;
 
 } // namespace detail
 
@@ -515,6 +516,25 @@ class TensorArmType
   operator ShapedType() const { return cast<ShapedType>(*this); }
 };
 
+/// SPIR-V vector of pointers type. Represents an OpTypeVector whose element
+/// type is an OpTypePointer. This is needed because MLIR's built-in VectorType
+/// does not support pointer element types. Used by the
+/// SPV_INTEL_masked_gather_scatter extension.
+class VectorOfPointerType
+    : public Type::TypeBase<VectorOfPointerType, CompositeType,
+                            detail::VectorOfPointerTypeStorage> {
+public:
+  using Base::Base;
+
+  static constexpr StringLiteral name = "spirv.vecptr";
+
+  static VectorOfPointerType get(PointerType elementType, unsigned numElements);
+
+  PointerType getElementType() const;
+
+  unsigned getNumElements() const;
+};
+
 } // namespace spirv
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index c9b22fe145d88..160d460f2cf6f 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -837,6 +837,43 @@ static Type parseStructType(SPIRVDialect const &dialect,
                          structDecorationInfo);
 }
 
+// vecptr-type ::= `vecptr` `<` integer-literal `,` pointer-type `>`
+static Type parseVectorOfPointerType(SPIRVDialect const &dialect,
+                                     DialectAsmParser &parser) {
+  if (parser.parseLess())
+    return Type();
+
+  int64_t count = 0;
+  SMLoc countLoc = parser.getCurrentLocation();
+  if (parser.parseInteger(count))
+    return Type();
+  if (!llvm::is_contained({2, 3, 4, 8, 16}, count)) {
+    parser.emitError(countLoc,
+                     "vector length must be 2, 3, 4, 8, or 16, but got ")
+        << count;
+    return Type();
+  }
+
+  if (parser.parseComma())
+    return Type();
+  SMLoc elementLoc = parser.getCurrentLocation();
+  Type elementType = parseAndVerifyType(dialect, parser);
+  if (!elementType)
+    return Type();
+
+  auto ptrType = dyn_cast<spirv::PointerType>(elementType);
+  if (!ptrType) {
+    parser.emitError(elementLoc,
+                     "vecptr element type must be a spirv.ptr type");
+    return Type();
+  }
+
+  if (parser.parseGreater())
+    return Type();
+  // Printed form (see print(VectorOfPointerType)): vecptr<N, !spirv.ptr<...>>
+  return VectorOfPointerType::get(ptrType, count);
+}
+
 // spirv-type ::= array-type
 //              | element-type
 //              | image-type
@@ -844,6 +881,7 @@ static Type parseStructType(SPIRVDialect const &dialect,
 //              | runtime-array-type
 //              | sampled-image-type
 //              | struct-type
+//              | vecptr-type
 Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
   StringRef keyword;
   if (parser.parseKeyword(&keyword))
@@ -867,6 +905,8 @@ Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
     return parseMatrixType(*this, parser);
   if (keyword == "arm.tensor")
     return parseTensorArmType(*this, parser);
+  if (keyword == "vecptr")
+    return parseVectorOfPointerType(*this, parser);
   parser.emitError(parser.getNameLoc(), "unknown SPIR-V type: ") << keyword;
   return Type();
 }
@@ -998,11 +1038,16 @@ static void print(TensorArmType type, DialectAsmPrinter &os) {
   os << "x" << type.getElementType() << ">";
 }
 
+static void print(VectorOfPointerType type, DialectAsmPrinter &os) {
+  os << "vecptr<" << type.getNumElements() << ", " << type.getElementType()
+     << ">";
+}
+
 void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
   TypeSwitch<Type>(type)
       .Case<ArrayType, CooperativeMatrixType, PointerType, RuntimeArrayType,
-            ImageType, SampledImageType, StructType, MatrixType, TensorArmType>(
-          [&](auto type) { print(type, os); })
+            ImageType, SampledImageType, StructType, MatrixType, TensorArmType,
+            VectorOfPointerType>([&](auto type) { print(type, os); })
       .DefaultUnreachable("Unhandled SPIR-V type");
 }
 
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index cecc8c2194237..9be4d07eeca31 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1415,6 +1415,71 @@ LogicalResult spirv::INTELSubgroupBlockWriteOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.INTEL.MaskedGather
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::INTELMaskedGatherOp::verify() {
+  auto ptrVecType = cast<spirv::VectorOfPointerType>(getPtrVector().getType());
+  auto resultType = cast<VectorType>(getResult().getType());
+  unsigned numElems = resultType.getNumElements();
+
+  // Each lane reads one scalar of the result element type via its pointer.
+  if (ptrVecType.getElementType().getPointeeType() !=
+      resultType.getElementType())
+    return emitOpError(
+        "pointer pointee type must match result vector element type");
+
+  // There must be exactly one pointer per result lane.
+  if (ptrVecType.getNumElements() != numElems)
+    return emitOpError(
+        "ptr_vector must have the same number of elements as result");
+
+  // The mask must be an i1 vector with one lane per result element.
+  auto maskType = cast<VectorType>(getMask().getType());
+  if (!maskType.getElementType().isInteger(1))
+    return emitOpError("mask must be a vector of i1");
+  if (maskType.getNumElements() != numElems)
+    return emitOpError("mask must have the same number of elements as result");
+
+  // fill_empty provides masked-off lanes, so it must match the result type.
+  if (getFillEmpty().getType() != resultType)
+    return emitOpError("fill_empty must have the same type as result");
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.INTEL.MaskedScatter
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::INTELMaskedScatterOp::verify() {
+  auto ptrVecType = cast<spirv::VectorOfPointerType>(getPtrVector().getType());
+  auto inputType = cast<VectorType>(getInputVector().getType());
+  unsigned numElems = inputType.getNumElements();
+
+  // Each lane writes one scalar of the input element type via its pointer.
+  if (ptrVecType.getElementType().getPointeeType() !=
+      inputType.getElementType())
+    return emitOpError(
+        "pointer pointee type must match input vector element type");
+
+  // There must be exactly one pointer per input lane.
+  if (ptrVecType.getNumElements() != numElems)
+    return emitOpError(
+        "ptr_vector must have the same number of elements as input_vector");
+
+  // The mask must be an i1 vector with one lane per input element.
+  auto maskType = cast<VectorType>(getMask().getType());
+  if (!maskType.getElementType().isInteger(1))
+    return emitOpError("mask must be a vector of i1");
+  if (maskType.getNumElements() != numElems)
+    return emitOpError(
+        "mask must have the same number of elements as input_vector");
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.IAddCarryOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 331d98c1d9313..c03ca72c0a908 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -50,6 +50,9 @@ class TypeExtensionVisitor {
             [this](auto concreteType) { addConcrete(concreteType); })
         .Case<ArrayType, ImageType, MatrixType, RuntimeArrayType, VectorType>(
             [this](auto concreteType) { add(concreteType.getElementType()); })
+        .Case([this](VectorOfPointerType concreteType) {
+          add(concreteType.getElementType());
+        })
         .Case([this](SampledImageType concreteType) {
           add(concreteType.getImageType());
         })
@@ -100,6 +103,9 @@ class TypeCapabilityVisitor {
         .Case([this](ArrayType concreteType) {
           add(concreteType.getElementType());
         })
+        .Case([this](VectorOfPointerType concreteType) {
+          add(concreteType.getElementType());
+        })
         .Case([this](SampledImageType concreteType) {
           add(concreteType.getImageType());
         })
@@ -181,8 +187,8 @@ bool CompositeType::classof(Type type) {
   if (auto vectorType = dyn_cast<VectorType>(type))
     return isValid(vectorType);
   return isa<spirv::ArrayType, spirv::CooperativeMatrixType, spirv::MatrixType,
-             spirv::RuntimeArrayType, spirv::StructType, spirv::TensorArmType>(
-      type);
+             spirv::RuntimeArrayType, spirv::StructType, spirv::TensorArmType,
+             spirv::VectorOfPointerType>(type);
 }
 
 bool CompositeType::isValid(VectorType type) {
@@ -195,6 +201,9 @@ Type CompositeType::getElementType(unsigned index) const {
   return TypeSwitch<Type, Type>(*this)
       .Case<ArrayType, CooperativeMatrixType, RuntimeArrayType, VectorType,
             TensorArmType>([](auto type) { return type.getElementType(); })
+      .Case([](VectorOfPointerType type) -> Type {
+        return type.getElementType();
+      })
       .Case([](MatrixType type) { return type.getColumnType(); })
       .Case([index](StructType type) { return type.getElementType(index); })
       .DefaultUnreachable("Invalid composite type");
@@ -202,7 +211,8 @@ Type CompositeType::getElementType(unsigned index) const {
 
 unsigned CompositeType::getNumElements() const {
   return TypeSwitch<SPIRVType, unsigned>(*this)
-      .Case<ArrayType, StructType, TensorArmType, VectorType>(
+      .Case<ArrayType, StructType, TensorArmType, VectorType,
+            VectorOfPointerType>(
           [](auto type) { return type.getNumElements(); })
       .Case([](MatrixType type) { return type.getNumColumns(); })
       .DefaultUnreachable("Invalid type for number of elements query");
@@ -1325,11 +1335,49 @@ TensorArmType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// VectorOfPointerType
+//===----------------------------------------------------------------------===//
+
+struct spirv::detail::VectorOfPointerTypeStorage final : TypeStorage {
+  using KeyTy = std::pair<PointerType, unsigned>;
+
+  static VectorOfPointerTypeStorage *construct(TypeStorageAllocator &allocator,
+                                               const KeyTy &key) {
+    return new (allocator.allocate<VectorOfPointerTypeStorage>())
+        VectorOfPointerTypeStorage(key);
+  }
+
+  bool operator==(const KeyTy &key) const {
+    return key == KeyTy(elementType, numElements);
+  }
+
+  VectorOfPointerTypeStorage(const KeyTy &key)
+      : elementType(key.first), numElements(key.second) {}
+
+  PointerType elementType;
+  unsigned numElements;
+};
+
+VectorOfPointerType VectorOfPointerType::get(PointerType elementType,
+                                             unsigned numElements) {
+  return Base::get(elementType.getContext(), elementType, numElements);
+}
+
+PointerType VectorOfPointerType::getElementType() const {
+  return getImpl()->elementType;
+}
+
+unsigned VectorOfPointerType::getNumElements() const {
+  return getImpl()->numElements;
+}
+
 //===----------------------------------------------------------------------===//
 // SPIR-V Dialect
 //===----------------------------------------------------------------------===//
 
 void SPIRVDialect::registerTypes() {
   addTypes<ArrayType, CooperativeMatrixType, ImageType, MatrixType, PointerType,
-           RuntimeArrayType, SampledImageType, StructType, TensorArmType>();
+           RuntimeArrayType, SampledImageType, StructType, TensorArmType,
+           VectorOfPointerType>();
 }
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index e218bee4b4fe3..9ea8f0f06337c 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -1137,7 +1137,11 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
       return emitError(unknownLoc, "OpTypeVector references undefined <id> ")
              << operands[1];
     }
-    typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy);
+    if (auto ptrType = dyn_cast<spirv::PointerType>(elementTy))
+      typeMap[operands[0]] =
+          spirv::VectorOfPointerType::get(ptrType, operands[2]);
+    else
+      typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy);
   } break;
   case spirv::Opcode::OpTypePointer: {
     return processOpTypePointer(operands);
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index c21cb27b072f1..c3ee16486b751 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -626,6 +626,18 @@ LogicalResult Serializer::prepareBasicType(
     return success();
   }
 
+  if (auto vecPtrType = dyn_cast<spirv::VectorOfPointerType>(type)) {
+    uint32_t elementTypeID = 0;
+    if (failed(processTypeImpl(loc, vecPtrType.getElementType(), elementTypeID,
+                               serializationCtx))) {
+      return failure();
+    }
+    typeEnum = spirv::Opcode::OpTypeVector;
+    operands.push_back(elementTypeID);
+    operands.push_back(vecPtrType.getNumElements());
+    return success();
+  }
+
   if (auto imageType = dyn_cast<spirv::ImageType>(type)) {
     typeEnum = spirv::Opcode::OpTypeImage;
     uint32_t sampledTypeID = 0;
diff --git a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
index 6e4126172f670..826fc0f964361 100644
--- a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
@@ -100,7 +100,7 @@ func.func @composite_construct_vector_wrong_count(%arg0: f32, %arg1: f32, %arg2
 // -----
 
 func.func @composite_construct_vector_rank_two(%arg0: vector<2x2xi1>, %arg1: vector<2x2xi1>) -> vector<4x2xi1> {
-  // expected-error @+1 {{ op operand #0 must be variadic of void or bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 or Float8E4M3 or Float8E5M2 or vector of bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 or Float8E4M3 or Float8E5M2 values of length 2/3/4/8/16 of ranks 1 or any SPIR-V pointer type or any SPIR-V array type or any SPIR-V runtime array type or any SPIR-V struct type or any SPIR-V cooperative matrix type or any SPIR-V matrix type or any SPIR-V sampled image type or any SPIR-V image type or any SPIR-V tensorArm type, but got 'vector<2x2xi1>'}}
+  // expected-error @+1 {{ op operand #0 must be variadic of void or bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 or Float8E4M3 or Float8E5M2 or vector of bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 or Float8E4M3 or Float8E5M2 values of length 2/3/4/8/16 of ranks 1 or any SPIR-V pointer type or any SPIR-V array type or any SPIR-V runtime array type or any SPIR-V struct type or any SPIR-V cooperative matrix type or any SPIR-V matrix type or any SPIR-V sampled image type or any SPIR-V image type or any SPIR-V tensorArm type or any SPIR-V vector of pointer type, but got 'vector<2x2xi1>'}}
   %0 = spirv.CompositeConstruct %arg0, %arg1 : (vector<2x2xi1>, vector<2x2xi1>) -> vector<4x2xi1>
   return %0: vector<4x2xi1>
 }
diff --git a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
index d124c02231161..fa3c90fef82c9 100644
--- a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
@@ -126,6 +126,126 @@ spirv.func @split_barrier() "None" {
 // spirv.INTEL.CacheControls
 //===----------------------------------------------------------------------===//
 
+//===----------------------------------------------------------------------===//
+// spirv.INTEL.MaskedGather
+//===----------------------------------------------------------------------===//
+
+spirv.func @masked_gather(
+    %ptrs : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>,
+    %alignment : i32,
+    %mask : vector<4xi1>,
+    %fill : vector<4xf32>) "None" {
+  // CHECK: {{%.*}} = spirv.INTEL.MaskedGather
+  %0 = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
+       : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>, i32,
+         vector<4xi1>, vector<4xf32> -> vector<4xf32>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @masked_gather_i32(
+    %ptrs : !spirv.vecptr<4, !spirv.ptr<i32, CrossWorkgroup>>,
+    %alignment : i32,
+    %mask : vector<4xi1>,
+    %fill : vector<4xi32>) "None" {
+  // CHECK: {{%.*}} = spirv.INTEL.MaskedGather
+  %0 = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
+       : !spirv.vecptr<4, !spirv.ptr<i32, CrossWorkgroup>>, i32,
+         vector<4xi1>, vector<4xi32> -> vector<4xi32>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @masked_gather_pointee_type_mismatch(
+    %ptrs : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>,
+    %alignment : i32,
+    %mask : vector<4xi1>,
+    %fill : vector<4xi32>) "None" {
+  // expected-error @+1 {{pointer pointee type must match result vector element type}}
+  %0 = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
+       : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>, i32,
+         vector<4xi1>, vector<4xi32> -> vector<4xi32>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @masked_gather_elem_count_mismatch(
+    %ptrs : !spirv.vecptr<2, !spirv.ptr<f32, CrossWorkgroup>>,
+    %alignment : i32,
+    %mask : vector<4xi1>,
+    %fill : vector<4xf32>) "None" {
+  // expected-error @+1 {{ptr_vector must have the same number of elements as result}}
+  %0 = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
+       : !spirv.vecptr<2, !spirv.ptr<f32, CrossWorkgroup>>, i32,
+         vector<4xi1>, vector<4xf32> -> vector<4xf32>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @masked_gather_mask_not_bool(
+    %ptrs : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>,
+    %alignment : i32,
+    %mask : vector<4xi8>,
+    %fill : vector<4xf32>) "None" {
+  // expected-error @+1 {{mask must be a vector of i1}}
+  %0 = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
+       : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>, i32,
+         vector<4xi8>, vector<4xf32> -> vector<4xf32>
+  spirv.Return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.INTEL.MaskedScatter
+//===----------------------------------------------------------------------===//
+
+spirv.func @masked_scatter(
+    %ptrs : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>,
+    %alignment : i32,
+    %mask : vector<4xi1>,
+    %values : vector<4xf32>) "None" {
+  // CHECK: spirv.INTEL.MaskedScatter
+  spirv.INTEL.MaskedScatter %ptrs, %alignment, %mask, %values
+       : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>, i32,
+         vector<4xi1>, vector<4xf32>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @masked_scatter_pointee_mismatch(
+    %ptrs : !spirv.vecptr<4, !spirv.ptr<i32, CrossWorkgroup>>,
+    %alignment : i32,
+    %mask : vector<4xi1>,
+    %values : vector<4xf32>) "None" {
+  // expected-error @+1 {{pointer pointee type must match input vector element type}}
+  spirv.INTEL.MaskedScatter %ptrs, %alignment, %mask, %values
+       : !spirv.vecptr<4, !spirv.ptr<i32, CrossWorkgroup>>, i32,
+         vector<4xi1>, vector<4xf32>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @masked_scatter_mask_count_mismatch(
+    %ptrs : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>,
+    %alignment : i32,
+    %mask : vector<2xi1>,
+    %values : vector<4xf32>) "None" {
+  // expected-error @+1 {{mask must have the same number of elements as input_vector}}
+  spirv.INTEL.MaskedScatter %ptrs, %alignment, %mask, %values
+       : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>, i32,
+         vector<2xi1>, vector<4xf32>
+  spirv.Return
+}
+
+// -----
+
 spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [CacheControlsINTEL], [SPV_INTEL_cache_controls]> {
   spirv.func @foo() "None" {
     // CHECK: spirv.Variable {cache_control_load_intel = [#spirv.cache_control_load_intel<cache_level = 0, load_cache_control = Uncached>, #spirv.cache_control_load_intel<cache_level = 1, load_cache_control = Cached>, #spirv.cache_control_load_intel<cache_level = 2, load_cache_control = InvalidateAfterR>]} : !spirv.ptr<f32, Function>
diff --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
index e12b70cc5c139..5c219e0c86191 100644
--- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
@@ -170,7 +170,7 @@ func.func @coop_matrix_const_wrong_type() -> () {
 //===----------------------------------------------------------------------===//
 
 func.func @ccr_result_not_composite() -> () {
-  // expected-error @+1 {{op result #0 must be vector of bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 or Float8E4M3 or Float8E5M2 values of length 2/3/4/8/16 of ranks 1 or any SPIR-V array type or any SPIR-V runtime array type or any SPIR-V struct type or any SPIR-V cooperative matrix type or any SPIR-V matrix type or any SPIR-V tensorArm type, but got 'i32'}}
+  // expected-error @+1 {{op result #0 must be vector of bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 or Float8E4M3 or Float8E5M2 values of length 2/3/4/8/16 of ranks 1 or any SPIR-V array type or any SPIR-V runtime array type or any SPIR-V struct type or any SPIR-V cooperative matrix type or any SPIR-V matrix type or any SPIR-V tensorArm type or any SPIR-V vector of pointer type, but got 'i32'}}
   %0 = spirv.EXT.ConstantCompositeReplicate [1 : i32] : i32
   return
 }
diff --git a/mlir/test/Target/SPIRV/intel-ext-ops.mlir b/mlir/test/Target/SPIRV/intel-ext-ops.mlir
index 118bed8be7095..682fa2cd8320d 100644
--- a/mlir/test/Target/SPIRV/intel-ext-ops.mlir
+++ b/mlir/test/Target/SPIRV/intel-ext-ops.mlir
@@ -59,6 +59,66 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage, TensorF
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.INTEL.MaskedGather / MaskedScatter
+//===----------------------------------------------------------------------===//
+
+spirv.module Physical64 OpenCL requires #spirv.vce<v1.0, [Kernel, Addresses, Linkage, MaskedGatherScatterINTEL], [SPV_INTEL_masked_gather_scatter]> {
+  // CHECK-LABEL: @masked_gather_f32
+  spirv.func @masked_gather_f32(
+      %ptrs : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>,
+      %alignment : i32,
+      %mask : vector<4xi1>,
+      %fill : vector<4xf32>) "None" {
+    // CHECK: spirv.INTEL.MaskedGather
+    %0 = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
+         : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>, i32,
+           vector<4xi1>, vector<4xf32> -> vector<4xf32>
+    spirv.Return
+  }
+
+  // CHECK-LABEL: @masked_gather_i32
+  spirv.func @masked_gather_i32(
+      %ptrs : !spirv.vecptr<4, !spirv.ptr<i32, CrossWorkgroup>>,
+      %alignment : i32,
+      %mask : vector<4xi1>,
+      %fill : vector<4xi32>) "None" {
+    // CHECK: spirv.INTEL.MaskedGather
+    %0 = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
+         : !spirv.vecptr<4, !spirv.ptr<i32, CrossWorkgroup>>, i32,
+           vector<4xi1>, vector<4xi32> -> vector<4xi32>
+    spirv.Return
+  }
+
+  // CHECK-LABEL: @masked_scatter_f32
+  spirv.func @masked_scatter_f32(
+      %ptrs : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>,
+      %alignment : i32,
+      %mask : vector<4xi1>,
+      %values : vector<4xf32>) "None" {
+    // CHECK: spirv.INTEL.MaskedScatter
+    spirv.INTEL.MaskedScatter %ptrs, %alignment, %mask, %values
+         : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>, i32,
+           vector<4xi1>, vector<4xf32>
+    spirv.Return
+  }
+
+  // CHECK-LABEL: @masked_scatter_i32
+  spirv.func @masked_scatter_i32(
+      %ptrs : !spirv.vecptr<4, !spirv.ptr<i32, CrossWorkgroup>>,
+      %alignment : i32,
+      %mask : vector<4xi1>,
+      %values : vector<4xi32>) "None" {
+    // CHECK: spirv.INTEL.MaskedScatter
+    spirv.INTEL.MaskedScatter %ptrs, %alignment, %mask, %values
+         : !spirv.vecptr<4, !spirv.ptr<i32, CrossWorkgroup>>, i32,
+           vector<4xi1>, vector<4xi32>
+    spirv.Return
+  }
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.INTEL.SplitBarrier
 //===----------------------------------------------------------------------===//

>From c35ea72587c8be892af69ff4930db8bb2472a84c Mon Sep 17 00:00:00 2001
From: Arseniy Obolenskiy <arseniy.obolenskiy at amd.com>
Date: Tue, 7 Apr 2026 06:40:46 +0200
Subject: [PATCH 2/3] Address review comments

---
 .../mlir/Dialect/SPIRV/IR/SPIRVBase.td        |  9 +--
 .../mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td | 45 +++++++++----
 .../mlir/Dialect/SPIRV/IR/SPIRVTypes.h        | 26 ++------
 mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp    | 49 +-------------
 mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp        | 65 -------------------
 mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp      | 59 ++---------------
 .../SPIRV/Deserialization/Deserializer.cpp    |  6 +-
 .../Target/SPIRV/Serialization/Serializer.cpp | 12 ----
 mlir/test/Dialect/SPIRV/IR/composite-ops.mlir |  2 +-
 mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir | 42 ++++++------
 mlir/test/Dialect/SPIRV/IR/structure-ops.mlir |  2 +-
 mlir/test/Target/SPIRV/intel-ext-ops.mlir     | 16 ++---
 12 files changed, 79 insertions(+), 254 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index d9a573b59c368..cd6ea54555f4b 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4266,7 +4266,6 @@ def SPIRV_IsRTArrayType : CPred<"::llvm::isa<::mlir::spirv::RuntimeArrayType>($_
 def SPIRV_IsSampledImageType : CPred<"::llvm::isa<::mlir::spirv::SampledImageType>($_self)">;
 def SPIRV_IsStructType : CPred<"::llvm::isa<::mlir::spirv::StructType>($_self)">;
 def SPIRV_IsTensorArmType : CPred<"::llvm::isa<::mlir::spirv::TensorArmType>($_self)">;
-def SPIRV_IsVectorOfPointerType : CPred<"::llvm::isa<::mlir::spirv::VectorOfPointerType>($_self)">;
 
 // See https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_types
 // for the definition of the following types and type categories.
@@ -4310,22 +4309,18 @@ def SPIRV_AnySampledImage : DialectType<SPIRV_Dialect, SPIRV_IsSampledImageType,
                                 "any SPIR-V sampled image type">;
 def SPIRV_AnyTensorArm : DialectType<SPIRV_Dialect, SPIRV_IsTensorArmType,
                                  "any SPIR-V tensorArm type">;
-def SPIRV_AnyVectorOfPointer : DialectType<SPIRV_Dialect,
-                                 SPIRV_IsVectorOfPointerType,
-                                 "any SPIR-V vector of pointer type">;
 
 def SPIRV_Numerical : AnyTypeOf<[SPIRV_Integer, SPIRV_AnyFloat]>;
 def SPIRV_Scalar : AnyTypeOf<[SPIRV_Numerical, SPIRV_Bool]>;
 def SPIRV_Aggregate : AnyTypeOf<[SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct]>;
 def SPIRV_Composite :
     AnyTypeOf<[SPIRV_Vector, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
-               SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnyTensorArm,
-               SPIRV_AnyVectorOfPointer]>;
+               SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnyTensorArm]>;
 def SPIRV_Type : AnyTypeOf<[
     SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_AnyFloat, SPIRV_Vector,
     SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
     SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage,
-    SPIRV_AnyImage, SPIRV_AnyTensorArm, SPIRV_AnyVectorOfPointer
+    SPIRV_AnyImage, SPIRV_AnyTensorArm
   ]>;
 
 def SPIRV_SignedInt : SignedIntOfWidths<[8, 16, 32, 64]>;
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
index 0ea57960fec74..1502187683c1f 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
@@ -252,8 +252,19 @@ def SPIRV_INTELControlBarrierWaitOp
 
 // -----
 
-def SPIRV_INTELMaskedGatherOp : SPIRV_IntelVendorOp<"MaskedGather", []> {
-  let summary = "See extension SPV_INTEL_masked_gather_scatter";
+def SPIRV_INTELMaskedGatherOp : SPIRV_IntelVendorOp<"MaskedGather",
+    [AllTypesMatch<["fill_empty", "result"]>,
+     TypesMatchWith<"pointee type of ptr_vector must match result element type",
+                    "ptr_vector", "result",
+                    "VectorType::get("
+                      "::llvm::cast<VectorType>($_self).getShape(), "
+                      "::llvm::cast<spirv::PointerType>("
+                        "::llvm::cast<VectorType>($_self).getElementType())"
+                      ".getPointeeType())">,
+     TypesMatchWith<"mask must be a vector of i1 matching result shape",
+                    "result", "mask",
+                    "getUnaryOpResultType($_self)">]> {
+  let summary = "Gather values from memory using a vector of pointers and a mask";
 
   let description = [{
     Reads values from a vector of pointers gathering them into a result
@@ -277,7 +288,7 @@ def SPIRV_INTELMaskedGatherOp : SPIRV_IntelVendorOp<"MaskedGather", []> {
 
     ```mlir
     %result = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
-              : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>, i32,
+              : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32,
                 vector<4xi1>, vector<4xf32> -> vector<4xf32>
     ```
   }];
@@ -290,9 +301,9 @@ def SPIRV_INTELMaskedGatherOp : SPIRV_IntelVendorOp<"MaskedGather", []> {
   ];
 
   let arguments = (ins
-    SPIRV_AnyVectorOfPointer:$ptr_vector,
+    SPIRV_VectorOf<SPIRV_AnyPtr>:$ptr_vector,
     SPIRV_Int32:$alignment,
-    SPIRV_Vector:$mask,
+    SPIRV_VectorOf<I1>:$mask,
     SPIRV_Vector:$fill_empty
   );
 
@@ -306,13 +317,23 @@ def SPIRV_INTELMaskedGatherOp : SPIRV_IntelVendorOp<"MaskedGather", []> {
       type($mask) `,` type($fill_empty) `->` type($result)
   }];
 
-  let hasVerifier = 1;
+  let hasVerifier = 0;
 }
 
 // -----
 
-def SPIRV_INTELMaskedScatterOp : SPIRV_IntelVendorOp<"MaskedScatter", []> {
-  let summary = "See extension SPV_INTEL_masked_gather_scatter";
+def SPIRV_INTELMaskedScatterOp : SPIRV_IntelVendorOp<"MaskedScatter",
+    [TypesMatchWith<"pointee type of ptr_vector must match input element type",
+                    "ptr_vector", "input_vector",
+                    "VectorType::get("
+                      "::llvm::cast<VectorType>($_self).getShape(), "
+                      "::llvm::cast<spirv::PointerType>("
+                        "::llvm::cast<VectorType>($_self).getElementType())"
+                      ".getPointeeType())">,
+     TypesMatchWith<"mask must be a vector of i1 matching input shape",
+                    "input_vector", "mask",
+                    "getUnaryOpResultType($_self)">]> {
+  let summary = "Scatter values to memory using a vector of pointers and a mask";
 
   let description = [{
     Writes values from a vector into memory locations pointed to by a
@@ -333,7 +354,7 @@ def SPIRV_INTELMaskedScatterOp : SPIRV_IntelVendorOp<"MaskedScatter", []> {
 
     ```mlir
     spirv.INTEL.MaskedScatter %ptrs, %alignment, %mask, %values
-              : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>, i32,
+              : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32,
                 vector<4xi1>, vector<4xf32>
     ```
   }];
@@ -346,9 +367,9 @@ def SPIRV_INTELMaskedScatterOp : SPIRV_IntelVendorOp<"MaskedScatter", []> {
   ];
 
   let arguments = (ins
-    SPIRV_AnyVectorOfPointer:$ptr_vector,
+    SPIRV_VectorOf<SPIRV_AnyPtr>:$ptr_vector,
     SPIRV_Int32:$alignment,
-    SPIRV_Vector:$mask,
+    SPIRV_VectorOf<I1>:$mask,
     SPIRV_Vector:$input_vector
   );
 
@@ -360,7 +381,7 @@ def SPIRV_INTELMaskedScatterOp : SPIRV_IntelVendorOp<"MaskedScatter", []> {
       type($mask) `,` type($input_vector)
   }];
 
-  let hasVerifier = 1;
+  let hasVerifier = 0;
 }
 
 // -----
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 085482b6099d7..90b695a7ba2fd 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -14,6 +14,7 @@
 #define MLIR_DIALECT_SPIRV_IR_SPIRVTYPES_H_
 
 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Location.h"
@@ -36,7 +37,6 @@ struct PointerTypeStorage;
 struct RuntimeArrayTypeStorage;
 struct SampledImageTypeStorage;
 struct StructTypeStorage;
-struct VectorOfPointerTypeStorage;
 
 } // namespace detail
 
@@ -173,8 +173,9 @@ class ImageType
 };
 
 // SPIR-V pointer type
-class PointerType : public Type::TypeBase<PointerType, SPIRVType,
-                                          detail::PointerTypeStorage> {
+class PointerType
+    : public Type::TypeBase<PointerType, SPIRVType, detail::PointerTypeStorage,
+                            VectorElementTypeInterface::Trait> {
 public:
   using Base::Base;
 
@@ -516,25 +517,6 @@ class TensorArmType
   operator ShapedType() const { return cast<ShapedType>(*this); }
 };
 
-/// SPIR-V vector of pointers type. Represents an OpTypeVector whose element
-/// type is an OpTypePointer. This is needed because MLIR's built-in VectorType
-/// does not support pointer element types. Used by the
-/// SPV_INTEL_masked_gather_scatter extension.
-class VectorOfPointerType
-    : public Type::TypeBase<VectorOfPointerType, CompositeType,
-                            detail::VectorOfPointerTypeStorage> {
-public:
-  using Base::Base;
-
-  static constexpr StringLiteral name = "spirv.vecptr";
-
-  static VectorOfPointerType get(PointerType elementType, unsigned numElements);
-
-  PointerType getElementType() const;
-
-  unsigned getNumElements() const;
-};
-
 } // namespace spirv
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 160d460f2cf6f..c9b22fe145d88 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -837,43 +837,6 @@ static Type parseStructType(SPIRVDialect const &dialect,
                          structDecorationInfo);
 }
 
-// vecptr-type ::= `vecptr` `<` integer-literal `x` pointer-type `>`
-static Type parseVectorOfPointerType(SPIRVDialect const &dialect,
-                                     DialectAsmParser &parser) {
-  if (parser.parseLess())
-    return Type();
-
-  int64_t count = 0;
-  SMLoc countLoc = parser.getCurrentLocation();
-  if (parser.parseInteger(count))
-    return Type();
-
-  if (parser.parseComma())
-    return Type();
-  if (!llvm::is_contained({2, 3, 4, 8, 16}, count)) {
-    parser.emitError(countLoc,
-                     "vector length must be 2, 3, 4, 8, or 16, but got ")
-        << count;
-    return Type();
-  }
-
-  Type elementType = parseAndVerifyType(dialect, parser);
-  if (!elementType)
-    return Type();
-
-  auto ptrType = dyn_cast<spirv::PointerType>(elementType);
-  if (!ptrType) {
-    parser.emitError(parser.getNameLoc(),
-                     "vecptr element type must be a spirv.ptr type");
-    return Type();
-  }
-
-  if (parser.parseGreater())
-    return Type();
-
-  return VectorOfPointerType::get(ptrType, count);
-}
-
 // spirv-type ::= array-type
 //              | element-type
 //              | image-type
@@ -881,7 +844,6 @@ static Type parseVectorOfPointerType(SPIRVDialect const &dialect,
 //              | runtime-array-type
 //              | sampled-image-type
 //              | struct-type
-//              | vecptr-type
 Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
   StringRef keyword;
   if (parser.parseKeyword(&keyword))
@@ -905,8 +867,6 @@ Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
     return parseMatrixType(*this, parser);
   if (keyword == "arm.tensor")
     return parseTensorArmType(*this, parser);
-  if (keyword == "vecptr")
-    return parseVectorOfPointerType(*this, parser);
   parser.emitError(parser.getNameLoc(), "unknown SPIR-V type: ") << keyword;
   return Type();
 }
@@ -1038,16 +998,11 @@ static void print(TensorArmType type, DialectAsmPrinter &os) {
   os << "x" << type.getElementType() << ">";
 }
 
-static void print(VectorOfPointerType type, DialectAsmPrinter &os) {
-  os << "vecptr<" << type.getNumElements() << ", " << type.getElementType()
-     << ">";
-}
-
 void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
   TypeSwitch<Type>(type)
       .Case<ArrayType, CooperativeMatrixType, PointerType, RuntimeArrayType,
-            ImageType, SampledImageType, StructType, MatrixType, TensorArmType,
-            VectorOfPointerType>([&](auto type) { print(type, os); })
+            ImageType, SampledImageType, StructType, MatrixType, TensorArmType>(
+          [&](auto type) { print(type, os); })
       .DefaultUnreachable("Unhandled SPIR-V type");
 }
 
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 9be4d07eeca31..cecc8c2194237 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1415,71 +1415,6 @@ LogicalResult spirv::INTELSubgroupBlockWriteOp::verify() {
   return success();
 }
 
-//===----------------------------------------------------------------------===//
-// spirv.INTEL.MaskedGather
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::INTELMaskedGatherOp::verify() {
-  auto ptrVecType = cast<spirv::VectorOfPointerType>(getPtrVector().getType());
-  auto resultType = cast<VectorType>(getResult().getType());
-  unsigned numElems = resultType.getNumElements();
-
-  // Verify pointee type matches result element type.
-  if (ptrVecType.getElementType().getPointeeType() !=
-      resultType.getElementType())
-    return emitOpError(
-        "pointer pointee type must match result vector element type");
-
-  // Verify element counts match.
-  if (ptrVecType.getNumElements() != numElems)
-    return emitOpError(
-        "ptr_vector must have the same number of elements as result");
-
-  // Verify mask is a vector of i1.
-  auto maskType = cast<VectorType>(getMask().getType());
-  if (!maskType.getElementType().isInteger(1))
-    return emitOpError("mask must be a vector of i1");
-  if (maskType.getNumElements() != numElems)
-    return emitOpError("mask must have the same number of elements as result");
-
-  // Verify fill_empty matches result type.
-  if (getFillEmpty().getType() != resultType)
-    return emitOpError("fill_empty must have the same type as result");
-
-  return success();
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.INTEL.MaskedScatter
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::INTELMaskedScatterOp::verify() {
-  auto ptrVecType = cast<spirv::VectorOfPointerType>(getPtrVector().getType());
-  auto inputType = cast<VectorType>(getInputVector().getType());
-  unsigned numElems = inputType.getNumElements();
-
-  // Verify pointee type matches input element type.
-  if (ptrVecType.getElementType().getPointeeType() !=
-      inputType.getElementType())
-    return emitOpError(
-        "pointer pointee type must match input vector element type");
-
-  // Verify element counts match.
-  if (ptrVecType.getNumElements() != numElems)
-    return emitOpError(
-        "ptr_vector must have the same number of elements as input_vector");
-
-  // Verify mask is a vector of i1.
-  auto maskType = cast<VectorType>(getMask().getType());
-  if (!maskType.getElementType().isInteger(1))
-    return emitOpError("mask must be a vector of i1");
-  if (maskType.getNumElements() != numElems)
-    return emitOpError(
-        "mask must have the same number of elements as input_vector");
-
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // spirv.IAddCarryOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index c03ca72c0a908..55f7dafcab42d 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -50,9 +50,6 @@ class TypeExtensionVisitor {
             [this](auto concreteType) { addConcrete(concreteType); })
         .Case<ArrayType, ImageType, MatrixType, RuntimeArrayType, VectorType>(
             [this](auto concreteType) { add(concreteType.getElementType()); })
-        .Case([this](VectorOfPointerType concreteType) {
-          add(concreteType.getElementType());
-        })
         .Case([this](SampledImageType concreteType) {
           add(concreteType.getImageType());
         })
@@ -103,9 +100,6 @@ class TypeCapabilityVisitor {
         .Case([this](ArrayType concreteType) {
           add(concreteType.getElementType());
         })
-        .Case([this](VectorOfPointerType concreteType) {
-          add(concreteType.getElementType());
-        })
         .Case([this](SampledImageType concreteType) {
           add(concreteType.getImageType());
         })
@@ -187,23 +181,21 @@ bool CompositeType::classof(Type type) {
   if (auto vectorType = dyn_cast<VectorType>(type))
     return isValid(vectorType);
   return isa<spirv::ArrayType, spirv::CooperativeMatrixType, spirv::MatrixType,
-             spirv::RuntimeArrayType, spirv::StructType, spirv::TensorArmType,
-             spirv::VectorOfPointerType>(type);
+             spirv::RuntimeArrayType, spirv::StructType, spirv::TensorArmType>(
+      type);
 }
 
 bool CompositeType::isValid(VectorType type) {
   return type.getRank() == 1 &&
          llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) &&
-         isa<ScalarType>(type.getElementType());
+         (isa<ScalarType>(type.getElementType()) ||
+          isa<PointerType>(type.getElementType()));
 }
 
 Type CompositeType::getElementType(unsigned index) const {
   return TypeSwitch<Type, Type>(*this)
       .Case<ArrayType, CooperativeMatrixType, RuntimeArrayType, VectorType,
             TensorArmType>([](auto type) { return type.getElementType(); })
-      .Case([](VectorOfPointerType type) -> Type {
-        return type.getElementType();
-      })
       .Case([](MatrixType type) { return type.getColumnType(); })
       .Case([index](StructType type) { return type.getElementType(index); })
       .DefaultUnreachable("Invalid composite type");
@@ -211,8 +203,7 @@ Type CompositeType::getElementType(unsigned index) const {
 
 unsigned CompositeType::getNumElements() const {
   return TypeSwitch<SPIRVType, unsigned>(*this)
-      .Case<ArrayType, StructType, TensorArmType, VectorType,
-            VectorOfPointerType>(
+      .Case<ArrayType, StructType, TensorArmType, VectorType>(
           [](auto type) { return type.getNumElements(); })
       .Case([](MatrixType type) { return type.getNumColumns(); })
       .DefaultUnreachable("Invalid type for number of elements query");
@@ -1335,49 +1326,11 @@ TensorArmType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
   return success();
 }
 
-//===----------------------------------------------------------------------===//
-// VectorOfPointerType
-//===----------------------------------------------------------------------===//
-
-struct spirv::detail::VectorOfPointerTypeStorage final : TypeStorage {
-  using KeyTy = std::pair<PointerType, unsigned>;
-
-  static VectorOfPointerTypeStorage *construct(TypeStorageAllocator &allocator,
-                                               const KeyTy &key) {
-    return new (allocator.allocate<VectorOfPointerTypeStorage>())
-        VectorOfPointerTypeStorage(key);
-  }
-
-  bool operator==(const KeyTy &key) const {
-    return key == KeyTy(elementType, numElements);
-  }
-
-  VectorOfPointerTypeStorage(const KeyTy &key)
-      : elementType(key.first), numElements(key.second) {}
-
-  PointerType elementType;
-  unsigned numElements;
-};
-
-VectorOfPointerType VectorOfPointerType::get(PointerType elementType,
-                                             unsigned numElements) {
-  return Base::get(elementType.getContext(), elementType, numElements);
-}
-
-PointerType VectorOfPointerType::getElementType() const {
-  return getImpl()->elementType;
-}
-
-unsigned VectorOfPointerType::getNumElements() const {
-  return getImpl()->numElements;
-}
-
 //===----------------------------------------------------------------------===//
 // SPIR-V Dialect
 //===----------------------------------------------------------------------===//
 
 void SPIRVDialect::registerTypes() {
   addTypes<ArrayType, CooperativeMatrixType, ImageType, MatrixType, PointerType,
-           RuntimeArrayType, SampledImageType, StructType, TensorArmType,
-           VectorOfPointerType>();
+           RuntimeArrayType, SampledImageType, StructType, TensorArmType>();
 }
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 9ea8f0f06337c..e218bee4b4fe3 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -1137,11 +1137,7 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
       return emitError(unknownLoc, "OpTypeVector references undefined <id> ")
              << operands[1];
     }
-    if (auto ptrType = dyn_cast<spirv::PointerType>(elementTy))
-      typeMap[operands[0]] =
-          spirv::VectorOfPointerType::get(ptrType, operands[2]);
-    else
-      typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy);
+    typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy);
   } break;
   case spirv::Opcode::OpTypePointer: {
     return processOpTypePointer(operands);
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index c3ee16486b751..c21cb27b072f1 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -626,18 +626,6 @@ LogicalResult Serializer::prepareBasicType(
     return success();
   }
 
-  if (auto vecPtrType = dyn_cast<spirv::VectorOfPointerType>(type)) {
-    uint32_t elementTypeID = 0;
-    if (failed(processTypeImpl(loc, vecPtrType.getElementType(), elementTypeID,
-                               serializationCtx))) {
-      return failure();
-    }
-    typeEnum = spirv::Opcode::OpTypeVector;
-    operands.push_back(elementTypeID);
-    operands.push_back(vecPtrType.getNumElements());
-    return success();
-  }
-
   if (auto imageType = dyn_cast<spirv::ImageType>(type)) {
     typeEnum = spirv::Opcode::OpTypeImage;
     uint32_t sampledTypeID = 0;
diff --git a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
index 826fc0f964361..6e4126172f670 100644
--- a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
@@ -100,7 +100,7 @@ func.func @composite_construct_vector_wrong_count(%arg0: f32, %arg1: f32, %arg2
 // -----
 
 func.func @composite_construct_vector_rank_two(%arg0: vector<2x2xi1>, %arg1: vector<2x2xi1>) -> vector<4x2xi1> {
-  // expected-error @+1 {{ op operand #0 must be variadic of void or bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 or Float8E4M3 or Float8E5M2 or vector of bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 or Float8E4M3 or Float8E5M2 values of length 2/3/4/8/16 of ranks 1 or any SPIR-V pointer type or any SPIR-V array type or any SPIR-V runtime array type or any SPIR-V struct type or any SPIR-V cooperative matrix type or any SPIR-V matrix type or any SPIR-V sampled image type or any SPIR-V image type or any SPIR-V tensorArm type or any SPIR-V vector of pointer type, but got 'vector<2x2xi1>'}}
+  // expected-error @+1 {{ op operand #0 must be variadic of void or bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 or Float8E4M3 or Float8E5M2 or vector of bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 or Float8E4M3 or Float8E5M2 values of length 2/3/4/8/16 of ranks 1 or any SPIR-V pointer type or any SPIR-V array type or any SPIR-V runtime array type or any SPIR-V struct type or any SPIR-V cooperative matrix type or any SPIR-V matrix type or any SPIR-V sampled image type or any SPIR-V image type or any SPIR-V tensorArm type, but got 'vector<2x2xi1>'}}
   %0 = spirv.CompositeConstruct %arg0, %arg1 : (vector<2x2xi1>, vector<2x2xi1>) -> vector<4x2xi1>
   return %0: vector<4x2xi1>
 }
diff --git a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
index fa3c90fef82c9..06f33a235c660 100644
--- a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
@@ -131,13 +131,13 @@ spirv.func @split_barrier() "None" {
 //===----------------------------------------------------------------------===//
 
 spirv.func @masked_gather(
-    %ptrs : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>,
+    %ptrs : vector<4x!spirv.ptr<f32, CrossWorkgroup>>,
     %alignment : i32,
     %mask : vector<4xi1>,
     %fill : vector<4xf32>) "None" {
   // CHECK: {{%.*}} = spirv.INTEL.MaskedGather
   %0 = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
-       : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>, i32,
+       : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32,
          vector<4xi1>, vector<4xf32> -> vector<4xf32>
   spirv.Return
 }
@@ -145,13 +145,13 @@ spirv.func @masked_gather(
 // -----
 
 spirv.func @masked_gather_i32(
-    %ptrs : !spirv.vecptr<4, !spirv.ptr<i32, CrossWorkgroup>>,
+    %ptrs : vector<4x!spirv.ptr<i32, CrossWorkgroup>>,
     %alignment : i32,
     %mask : vector<4xi1>,
     %fill : vector<4xi32>) "None" {
   // CHECK: {{%.*}} = spirv.INTEL.MaskedGather
   %0 = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
-       : !spirv.vecptr<4, !spirv.ptr<i32, CrossWorkgroup>>, i32,
+       : vector<4x!spirv.ptr<i32, CrossWorkgroup>>, i32,
          vector<4xi1>, vector<4xi32> -> vector<4xi32>
   spirv.Return
 }
@@ -159,13 +159,13 @@ spirv.func @masked_gather_i32(
 // -----
 
 spirv.func @masked_gather_pointee_type_mismatch(
-    %ptrs : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>,
+    %ptrs : vector<4x!spirv.ptr<f32, CrossWorkgroup>>,
     %alignment : i32,
     %mask : vector<4xi1>,
     %fill : vector<4xi32>) "None" {
-  // expected-error @+1 {{pointer pointee type must match result vector element type}}
+  // expected-error @+1 {{'spirv.INTEL.MaskedGather' op failed to verify that pointee type of ptr_vector must match result element type}}
   %0 = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
-       : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>, i32,
+       : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32,
          vector<4xi1>, vector<4xi32> -> vector<4xi32>
   spirv.Return
 }
@@ -173,13 +173,13 @@ spirv.func @masked_gather_pointee_type_mismatch(
 // -----
 
 spirv.func @masked_gather_elem_count_mismatch(
-    %ptrs : !spirv.vecptr<2, !spirv.ptr<f32, CrossWorkgroup>>,
+    %ptrs : vector<2x!spirv.ptr<f32, CrossWorkgroup>>,
     %alignment : i32,
     %mask : vector<4xi1>,
     %fill : vector<4xf32>) "None" {
-  // expected-error @+1 {{ptr_vector must have the same number of elements as result}}
+  // expected-error @+1 {{'spirv.INTEL.MaskedGather' op failed to verify that pointee type of ptr_vector must match result element type}}
   %0 = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
-       : !spirv.vecptr<2, !spirv.ptr<f32, CrossWorkgroup>>, i32,
+       : vector<2x!spirv.ptr<f32, CrossWorkgroup>>, i32,
          vector<4xi1>, vector<4xf32> -> vector<4xf32>
   spirv.Return
 }
@@ -187,13 +187,13 @@ spirv.func @masked_gather_elem_count_mismatch(
 // -----
 
 spirv.func @masked_gather_mask_not_bool(
-    %ptrs : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>,
+    %ptrs : vector<4x!spirv.ptr<f32, CrossWorkgroup>>,
     %alignment : i32,
     %mask : vector<4xi8>,
     %fill : vector<4xf32>) "None" {
-  // expected-error @+1 {{mask must be a vector of i1}}
+  // expected-error @+1 {{operand #2 must be fixed-length vector of 1-bit signless integer values of length 2/3/4/8/16 of ranks 1, but got 'vector<4xi8>'}}
   %0 = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
-       : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>, i32,
+       : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32,
          vector<4xi8>, vector<4xf32> -> vector<4xf32>
   spirv.Return
 }
@@ -205,13 +205,13 @@ spirv.func @masked_gather_mask_not_bool(
 //===----------------------------------------------------------------------===//
 
 spirv.func @masked_scatter(
-    %ptrs : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>,
+    %ptrs : vector<4x!spirv.ptr<f32, CrossWorkgroup>>,
     %alignment : i32,
     %mask : vector<4xi1>,
     %values : vector<4xf32>) "None" {
   // CHECK: spirv.INTEL.MaskedScatter
   spirv.INTEL.MaskedScatter %ptrs, %alignment, %mask, %values
-       : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>, i32,
+       : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32,
          vector<4xi1>, vector<4xf32>
   spirv.Return
 }
@@ -219,13 +219,13 @@ spirv.func @masked_scatter(
 // -----
 
 spirv.func @masked_scatter_pointee_mismatch(
-    %ptrs : !spirv.vecptr<4, !spirv.ptr<i32, CrossWorkgroup>>,
+    %ptrs : vector<4x!spirv.ptr<i32, CrossWorkgroup>>,
     %alignment : i32,
     %mask : vector<4xi1>,
     %values : vector<4xf32>) "None" {
-  // expected-error @+1 {{pointer pointee type must match input vector element type}}
+  // expected-error @+1 {{'spirv.INTEL.MaskedScatter' op failed to verify that pointee type of ptr_vector must match input element type}}
   spirv.INTEL.MaskedScatter %ptrs, %alignment, %mask, %values
-       : !spirv.vecptr<4, !spirv.ptr<i32, CrossWorkgroup>>, i32,
+       : vector<4x!spirv.ptr<i32, CrossWorkgroup>>, i32,
          vector<4xi1>, vector<4xf32>
   spirv.Return
 }
@@ -233,13 +233,13 @@ spirv.func @masked_scatter_pointee_mismatch(
 // -----
 
 spirv.func @masked_scatter_mask_count_mismatch(
-    %ptrs : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>,
+    %ptrs : vector<4x!spirv.ptr<f32, CrossWorkgroup>>,
     %alignment : i32,
     %mask : vector<2xi1>,
     %values : vector<4xf32>) "None" {
-  // expected-error @+1 {{mask must have the same number of elements as input_vector}}
+  // expected-error @+1 {{'spirv.INTEL.MaskedScatter' op failed to verify that mask must be a vector of i1 matching input shape}}
   spirv.INTEL.MaskedScatter %ptrs, %alignment, %mask, %values
-       : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>, i32,
+       : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32,
          vector<2xi1>, vector<4xf32>
   spirv.Return
 }
diff --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
index 5c219e0c86191..e12b70cc5c139 100644
--- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
@@ -170,7 +170,7 @@ func.func @coop_matrix_const_wrong_type() -> () {
 //===----------------------------------------------------------------------===//
 
 func.func @ccr_result_not_composite() -> () {
-  // expected-error @+1 {{op result #0 must be vector of bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 or Float8E4M3 or Float8E5M2 values of length 2/3/4/8/16 of ranks 1 or any SPIR-V array type or any SPIR-V runtime array type or any SPIR-V struct type or any SPIR-V cooperative matrix type or any SPIR-V matrix type or any SPIR-V tensorArm type or any SPIR-V vector of pointer type, but got 'i32'}}
+  // expected-error @+1 {{op result #0 must be vector of bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 or Float8E4M3 or Float8E5M2 values of length 2/3/4/8/16 of ranks 1 or any SPIR-V array type or any SPIR-V runtime array type or any SPIR-V struct type or any SPIR-V cooperative matrix type or any SPIR-V matrix type or any SPIR-V tensorArm type, but got 'i32'}}
   %0 = spirv.EXT.ConstantCompositeReplicate [1 : i32] : i32
   return
 }
diff --git a/mlir/test/Target/SPIRV/intel-ext-ops.mlir b/mlir/test/Target/SPIRV/intel-ext-ops.mlir
index 682fa2cd8320d..4296c3f1a9682 100644
--- a/mlir/test/Target/SPIRV/intel-ext-ops.mlir
+++ b/mlir/test/Target/SPIRV/intel-ext-ops.mlir
@@ -66,52 +66,52 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage, TensorF
 spirv.module Physical64 OpenCL requires #spirv.vce<v1.0, [Kernel, Addresses, Linkage, MaskedGatherScatterINTEL], [SPV_INTEL_masked_gather_scatter]> {
   // CHECK-LABEL: @masked_gather_f32
   spirv.func @masked_gather_f32(
-      %ptrs : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>,
+      %ptrs : vector<4x!spirv.ptr<f32, CrossWorkgroup>>,
       %alignment : i32,
       %mask : vector<4xi1>,
       %fill : vector<4xf32>) "None" {
     // CHECK: spirv.INTEL.MaskedGather
     %0 = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
-         : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>, i32,
+         : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32,
            vector<4xi1>, vector<4xf32> -> vector<4xf32>
     spirv.Return
   }
 
   // CHECK-LABEL: @masked_gather_i32
   spirv.func @masked_gather_i32(
-      %ptrs : !spirv.vecptr<4, !spirv.ptr<i32, CrossWorkgroup>>,
+      %ptrs : vector<4x!spirv.ptr<i32, CrossWorkgroup>>,
       %alignment : i32,
       %mask : vector<4xi1>,
       %fill : vector<4xi32>) "None" {
     // CHECK: spirv.INTEL.MaskedGather
     %0 = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
-         : !spirv.vecptr<4, !spirv.ptr<i32, CrossWorkgroup>>, i32,
+         : vector<4x!spirv.ptr<i32, CrossWorkgroup>>, i32,
            vector<4xi1>, vector<4xi32> -> vector<4xi32>
     spirv.Return
   }
 
   // CHECK-LABEL: @masked_scatter_f32
   spirv.func @masked_scatter_f32(
-      %ptrs : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>,
+      %ptrs : vector<4x!spirv.ptr<f32, CrossWorkgroup>>,
       %alignment : i32,
       %mask : vector<4xi1>,
       %values : vector<4xf32>) "None" {
     // CHECK: spirv.INTEL.MaskedScatter
     spirv.INTEL.MaskedScatter %ptrs, %alignment, %mask, %values
-         : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>, i32,
+         : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32,
            vector<4xi1>, vector<4xf32>
     spirv.Return
   }
 
   // CHECK-LABEL: @masked_scatter_i32
   spirv.func @masked_scatter_i32(
-      %ptrs : !spirv.vecptr<4, !spirv.ptr<i32, CrossWorkgroup>>,
+      %ptrs : vector<4x!spirv.ptr<i32, CrossWorkgroup>>,
       %alignment : i32,
       %mask : vector<4xi1>,
       %values : vector<4xi32>) "None" {
     // CHECK: spirv.INTEL.MaskedScatter
     spirv.INTEL.MaskedScatter %ptrs, %alignment, %mask, %values
-         : !spirv.vecptr<4, !spirv.ptr<i32, CrossWorkgroup>>, i32,
+         : vector<4x!spirv.ptr<i32, CrossWorkgroup>>, i32,
            vector<4xi1>, vector<4xi32>
     spirv.Return
   }

>From 639da5e6cb4233866a55619383cd4b5af7263ac3 Mon Sep 17 00:00:00 2001
From: Arseniy Obolenskiy <arseniy.obolenskiy at amd.com>
Date: Mon, 13 Apr 2026 10:16:11 +0200
Subject: [PATCH 3/3] Address review comments

---
 .../mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td    | 10 +++++-----
 mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir    | 16 +++++++++++++++-
 2 files changed, 20 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
index 1502187683c1f..57a57a18cca33 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
@@ -303,12 +303,12 @@ def SPIRV_INTELMaskedGatherOp : SPIRV_IntelVendorOp<"MaskedGather",
   let arguments = (ins
     SPIRV_VectorOf<SPIRV_AnyPtr>:$ptr_vector,
     SPIRV_Int32:$alignment,
-    SPIRV_VectorOf<I1>:$mask,
-    SPIRV_Vector:$fill_empty
+    SPIRV_VectorOf<SPIRV_Bool>:$mask,
+    SPIRV_VectorOf<SPIRV_Numerical>:$fill_empty
   );
 
   let results = (outs
-    SPIRV_Vector:$result
+    SPIRV_VectorOf<SPIRV_Numerical>:$result
   );
 
   let assemblyFormat = [{
@@ -369,8 +369,8 @@ def SPIRV_INTELMaskedScatterOp : SPIRV_IntelVendorOp<"MaskedScatter",
   let arguments = (ins
     SPIRV_VectorOf<SPIRV_AnyPtr>:$ptr_vector,
     SPIRV_Int32:$alignment,
-    SPIRV_VectorOf<I1>:$mask,
-    SPIRV_Vector:$input_vector
+    SPIRV_VectorOf<SPIRV_Bool>:$mask,
+    SPIRV_VectorOf<SPIRV_Numerical>:$input_vector
   );
 
   let results = (outs);
diff --git a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
index 06f33a235c660..ca9f501d23938 100644
--- a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
@@ -191,7 +191,7 @@ spirv.func @masked_gather_mask_not_bool(
     %alignment : i32,
     %mask : vector<4xi8>,
     %fill : vector<4xf32>) "None" {
-  // expected-error @+1 {{operand #2 must be fixed-length vector of 1-bit signless integer values of length 2/3/4/8/16 of ranks 1, but got 'vector<4xi8>'}}
+  // expected-error @+1 {{operand #2 must be fixed-length vector of bool values of length 2/3/4/8/16 of ranks 1, but got 'vector<4xi8>'}}
   %0 = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
        : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32,
          vector<4xi8>, vector<4xf32> -> vector<4xf32>
@@ -200,6 +200,20 @@ spirv.func @masked_gather_mask_not_bool(
 
 // -----
 
+spirv.func @masked_gather_mask_count_mismatch(
+    %ptrs : vector<4x!spirv.ptr<f32, CrossWorkgroup>>,
+    %alignment : i32,
+    %mask : vector<2xi1>,
+    %fill : vector<4xf32>) "None" {
+  // expected-error @+1 {{'spirv.INTEL.MaskedGather' op failed to verify that mask must be a vector of i1 matching result shape}}
+  %0 = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
+       : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32,
+         vector<2xi1>, vector<4xf32> -> vector<4xf32>
+  spirv.Return
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.INTEL.MaskedScatter
 //===----------------------------------------------------------------------===//



More information about the Mlir-commits mailing list