[Mlir-commits] [mlir] [mlir][SPIR-V] Add support for SPV_INTEL_masked_gather_scatter extension (PR #189099)
Arseniy Obolenskiy
llvmlistbot at llvm.org
Mon Apr 13 01:16:25 PDT 2026
https://github.com/aobolensk updated https://github.com/llvm/llvm-project/pull/189099
>From bdc658c95274ad076dd14593492abae33113caec Mon Sep 17 00:00:00 2001
From: Arseniy Obolenskiy <arseniy.obolenskiy at amd.com>
Date: Fri, 27 Mar 2026 17:29:54 +0100
Subject: [PATCH 1/3] [mlir][SPIR-V] Add MaskedGather/MaskedScatter ops and
VectorOfPointerType for SPV_INTEL_masked_gather_scatter extension
Add support for the SPV_INTEL_masked_gather_scatter extension implemented in #185418
---
.../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 25 +++-
.../mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td | 113 +++++++++++++++++
.../mlir/Dialect/SPIRV/IR/SPIRVTypes.h | 20 +++
mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp | 49 ++++++-
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 65 ++++++++++
mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | 56 +++++++-
.../SPIRV/Deserialization/Deserializer.cpp | 6 +-
.../Target/SPIRV/Serialization/Serializer.cpp | 12 ++
mlir/test/Dialect/SPIRV/IR/composite-ops.mlir | 2 +-
mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir | 120 ++++++++++++++++++
mlir/test/Dialect/SPIRV/IR/structure-ops.mlir | 2 +-
mlir/test/Target/SPIRV/intel-ext-ops.mlir | 60 +++++++++
12 files changed, 517 insertions(+), 13 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 9f9e2f5f9a677..d9a573b59c368 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -407,6 +407,7 @@ def SPV_INTEL_split_barrier : I32EnumAttrCase<"SPV_INTEL_sp
def SPV_INTEL_bfloat16_conversion : I32EnumAttrCase<"SPV_INTEL_bfloat16_conversion", 4031>;
def SPV_INTEL_cache_controls : I32EnumAttrCase<"SPV_INTEL_cache_controls", 4032>;
def SPV_INTEL_tensor_float32_conversion : I32EnumAttrCase<"SPV_INTEL_tensor_float32_conversion", 4033>;
+def SPV_INTEL_masked_gather_scatter : I32EnumAttrCase<"SPV_INTEL_masked_gather_scatter", 4034>;
def SPV_NV_compute_shader_derivatives : I32EnumAttrCase<"SPV_NV_compute_shader_derivatives", 5000>;
def SPV_NV_cooperative_matrix : I32EnumAttrCase<"SPV_NV_cooperative_matrix", 5001>;
@@ -472,6 +473,7 @@ def SPIRV_ExtensionAttr :
SPV_INTEL_memory_access_aliasing, SPV_INTEL_split_barrier,
SPV_INTEL_bfloat16_conversion, SPV_INTEL_cache_controls,
SPV_INTEL_tensor_float32_conversion,
+ SPV_INTEL_masked_gather_scatter,
SPV_NV_compute_shader_derivatives, SPV_NV_cooperative_matrix,
SPV_NV_fragment_shader_barycentric, SPV_NV_geometry_shader_passthrough,
SPV_NV_ray_tracing, SPV_NV_sample_mask_override_coverage,
@@ -1481,6 +1483,12 @@ def SPIRV_C_TensorFloat32RoundingINTEL : I32EnumAttrCase<"
];
}
+def SPIRV_C_MaskedGatherScatterINTEL : I32EnumAttrCase<"MaskedGatherScatterINTEL", 6427> {
+ list<Availability> availability = [
+ Extension<[SPV_INTEL_masked_gather_scatter]>
+ ];
+}
+
def SPIRV_C_CacheControlsINTEL : I32EnumAttrCase<"CacheControlsINTEL", 6441> {
list<Availability> availability = [
Extension<[SPV_INTEL_cache_controls]>
@@ -1590,7 +1598,8 @@ def SPIRV_CapabilityAttr :
SPIRV_C_ShaderStereoViewNV, SPIRV_C_Bfloat16ConversionINTEL,
SPIRV_C_CacheControlsINTEL, SPIRV_C_BFloat16TypeKHR,
SPIRV_C_BFloat16DotProductKHR, SPIRV_C_BFloat16CooperativeMatrixKHR,
- SPIRV_C_TensorFloat32RoundingINTEL, SPIRV_C_Float8EXT
+ SPIRV_C_TensorFloat32RoundingINTEL, SPIRV_C_MaskedGatherScatterINTEL,
+ SPIRV_C_Float8EXT
]>;
def SPIRV_AM_Logical : I32EnumAttrCase<"Logical", 0>;
@@ -4257,6 +4266,7 @@ def SPIRV_IsRTArrayType : CPred<"::llvm::isa<::mlir::spirv::RuntimeArrayType>($_
def SPIRV_IsSampledImageType : CPred<"::llvm::isa<::mlir::spirv::SampledImageType>($_self)">;
def SPIRV_IsStructType : CPred<"::llvm::isa<::mlir::spirv::StructType>($_self)">;
def SPIRV_IsTensorArmType : CPred<"::llvm::isa<::mlir::spirv::TensorArmType>($_self)">;
+def SPIRV_IsVectorOfPointerType : CPred<"::llvm::isa<::mlir::spirv::VectorOfPointerType>($_self)">;
// See https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_types
// for the definition of the following types and type categories.
@@ -4300,18 +4310,22 @@ def SPIRV_AnySampledImage : DialectType<SPIRV_Dialect, SPIRV_IsSampledImageType,
"any SPIR-V sampled image type">;
def SPIRV_AnyTensorArm : DialectType<SPIRV_Dialect, SPIRV_IsTensorArmType,
"any SPIR-V tensorArm type">;
+// An OpTypeVector whose components are SPIR-V pointers (spirv::VectorOfPointerType).
+def SPIRV_AnyVectorOfPointer : DialectType<SPIRV_Dialect,
+ SPIRV_IsVectorOfPointerType,
+ "any SPIR-V vector of pointer type">;
def SPIRV_Numerical : AnyTypeOf<[SPIRV_Integer, SPIRV_AnyFloat]>;
def SPIRV_Scalar : AnyTypeOf<[SPIRV_Numerical, SPIRV_Bool]>;
def SPIRV_Aggregate : AnyTypeOf<[SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct]>;
def SPIRV_Composite :
AnyTypeOf<[SPIRV_Vector, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
- SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnyTensorArm]>;
+ SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnyTensorArm,
+ SPIRV_AnyVectorOfPointer]>;
def SPIRV_Type : AnyTypeOf<[
SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_AnyFloat, SPIRV_Vector,
SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage,
- SPIRV_AnyImage, SPIRV_AnyTensorArm
+ SPIRV_AnyImage, SPIRV_AnyTensorArm, SPIRV_AnyVectorOfPointer
]>;
def SPIRV_SignedInt : SignedIntOfWidths<[8, 16, 32, 64]>;
@@ -4651,6 +4665,8 @@ def SPIRV_OC_OpControlBarrierWaitINTEL : I32EnumAttrCase<"OpControlBarrie
def SPIRV_OC_OpGroupIMulKHR : I32EnumAttrCase<"OpGroupIMulKHR", 6401>;
def SPIRV_OC_OpGroupFMulKHR : I32EnumAttrCase<"OpGroupFMulKHR", 6402>;
def SPIRV_OC_OpRoundFToTF32INTEL : I32EnumAttrCase<"OpRoundFToTF32INTEL", 6426>;
+def SPIRV_OC_OpMaskedGatherINTEL : I32EnumAttrCase<"OpMaskedGatherINTEL", 6428>;
+def SPIRV_OC_OpMaskedScatterINTEL : I32EnumAttrCase<"OpMaskedScatterINTEL", 6429>;
def SPIRV_OpcodeAttr :
SPIRV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", "opcode", [
@@ -4763,7 +4779,8 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpConvertFToBF16INTEL, SPIRV_OC_OpConvertBF16ToFINTEL,
SPIRV_OC_OpControlBarrierArriveINTEL, SPIRV_OC_OpControlBarrierWaitINTEL,
SPIRV_OC_OpGroupIMulKHR, SPIRV_OC_OpGroupFMulKHR,
- SPIRV_OC_OpRoundFToTF32INTEL
+ SPIRV_OC_OpRoundFToTF32INTEL,
+ SPIRV_OC_OpMaskedGatherINTEL, SPIRV_OC_OpMaskedScatterINTEL
]>;
// End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
index 2a7fa534cc3dc..0ea57960fec74 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
@@ -250,6 +250,119 @@ def SPIRV_INTELControlBarrierWaitOp
}
+// -----
+
+def SPIRV_INTELMaskedGatherOp : SPIRV_IntelVendorOp<"MaskedGather", []> {
+ let summary = "See extension SPV_INTEL_masked_gather_scatter";
+
+ let description = [{
+ Reads values from a vector of pointers gathering them into a result
+ vector. Lanes where the mask is false receive the corresponding
+ FillEmpty value.
+
+ Result Type must be a vector of numerical type.
+
+ PtrVector must be a vector of pointers to the scalar element type of
+ Result Type. It must have the same number of components as Result Type.
+
+ Alignment is the known minimum alignment in bytes of each pointer in
+ PtrVector.
+
+ Mask must be a vector of boolean type with the same number of components
+ as Result Type.
+
+ FillEmpty must have the same type as Result Type.
+
+ #### Example:
+
+ ```mlir
+ %result = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
+ : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>, i32,
+ vector<4xi1>, vector<4xf32> -> vector<4xf32>
+ ```
+ }];
+
+ let availability = [
+ MinVersion<SPIRV_V_1_0>,
+ MaxVersion<SPIRV_V_1_6>,
+ Extension<[SPV_INTEL_masked_gather_scatter]>,
+ Capability<[SPIRV_C_MaskedGatherScatterINTEL]>
+ ];
+
+ // NOTE(review): Alignment is modeled as an SSA i32 operand. If the
+ // extension grammar encodes Alignment as a literal rather than an <id>,
+ // it should become an integer attribute so (de)serialization emits a
+ // literal -- please confirm against the SPIR-V grammar.
+ // NOTE(review): SPIRV_Vector also admits boolean vectors, so the
+ // "numerical type" requirement on Result Type stated above is not
+ // enforced by the result constraint below.
+ let arguments = (ins
+ SPIRV_AnyVectorOfPointer:$ptr_vector,
+ SPIRV_Int32:$alignment,
+ SPIRV_Vector:$mask,
+ SPIRV_Vector:$fill_empty
+ );
+
+ let results = (outs
+ SPIRV_Vector:$result
+ );
+
+ let assemblyFormat = [{
+ $ptr_vector `,` $alignment `,` $mask `,` $fill_empty attr-dict `:`
+ type($ptr_vector) `,` type($alignment) `,`
+ type($mask) `,` type($fill_empty) `->` type($result)
+ }];
+
+ let hasVerifier = 1;
+}
+
+// -----
+
+def SPIRV_INTELMaskedScatterOp : SPIRV_IntelVendorOp<"MaskedScatter", []> {
+ let summary = "See extension SPV_INTEL_masked_gather_scatter";
+
+ let description = [{
+ Writes values from a vector into memory locations pointed to by a
+ vector of pointers. Only lanes where the mask is true are written.
+
+ PtrVector must be a vector of pointers to the scalar element type of
+ InputVector. It must have the same number of components as InputVector.
+
+ Alignment is the known minimum alignment in bytes of each pointer in
+ PtrVector.
+
+ Mask must be a vector of boolean type with the same number of components
+ as InputVector.
+
+ InputVector is the vector of values to scatter into memory.
+
+ #### Example:
+
+ ```mlir
+ spirv.INTEL.MaskedScatter %ptrs, %alignment, %mask, %values
+ : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>, i32,
+ vector<4xi1>, vector<4xf32>
+ ```
+ }];
+
+ let availability = [
+ MinVersion<SPIRV_V_1_0>,
+ MaxVersion<SPIRV_V_1_6>,
+ Extension<[SPV_INTEL_masked_gather_scatter]>,
+ Capability<[SPIRV_C_MaskedGatherScatterINTEL]>
+ ];
+
+ // NOTE(review): the Alignment encoding question noted on
+ // SPIRV_INTELMaskedGatherOp (SSA operand vs. literal) applies here too.
+ let arguments = (ins
+ SPIRV_AnyVectorOfPointer:$ptr_vector,
+ SPIRV_Int32:$alignment,
+ SPIRV_Vector:$mask,
+ SPIRV_Vector:$input_vector
+ );
+
+ let results = (outs);
+
+ let assemblyFormat = [{
+ $ptr_vector `,` $alignment `,` $mask `,` $input_vector attr-dict `:`
+ type($ptr_vector) `,` type($alignment) `,`
+ type($mask) `,` type($input_vector)
+ }];
+
+ let hasVerifier = 1;
+}
+
// -----
#endif // MLIR_DIALECT_SPIRV_IR_INTEL_EXT_OPS
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 4a0c29d4b5d90..085482b6099d7 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -36,6 +36,7 @@ struct PointerTypeStorage;
struct RuntimeArrayTypeStorage;
struct SampledImageTypeStorage;
struct StructTypeStorage;
+struct VectorOfPointerTypeStorage;
} // namespace detail
@@ -515,6 +516,25 @@ class TensorArmType
operator ShapedType() const { return cast<ShapedType>(*this); }
};
+/// SPIR-V vector of pointers type. Represents an OpTypeVector whose element
+/// type is an OpTypePointer. This is needed because MLIR's built-in VectorType
+/// does not support pointer element types. Used by the
+/// SPV_INTEL_masked_gather_scatter extension.
+class VectorOfPointerType
+ : public Type::TypeBase<VectorOfPointerType, CompositeType,
+ detail::VectorOfPointerTypeStorage> {
+public:
+ using Base::Base;
+
+ static constexpr StringLiteral name = "spirv.vecptr";
+
+ /// Returns the vector type with `numElements` components of pointer type
+ /// `elementType`, created in the context of `elementType`. `numElements` is
+ /// not range-checked here; callers are responsible for passing a valid count.
+ static VectorOfPointerType get(PointerType elementType, unsigned numElements);
+
+ /// Returns the pointer type shared by every component.
+ PointerType getElementType() const;
+
+ /// Returns the number of components.
+ unsigned getNumElements() const;
+};
+
} // namespace spirv
} // namespace mlir
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index c9b22fe145d88..160d460f2cf6f 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -837,6 +837,43 @@ static Type parseStructType(SPIRVDialect const &dialect,
structDecorationInfo);
}
+// Parses the part after the `vecptr` keyword (already consumed by parseType).
+// vecptr-type ::= `vecptr` `<` integer-literal `,` pointer-type `>`
+static Type parseVectorOfPointerType(SPIRVDialect const &dialect,
+                                     DialectAsmParser &parser) {
+  if (parser.parseLess())
+    return Type();
+
+  // Check the count before the comma so the error points at the count.
+  int64_t count = 0;
+  SMLoc countLoc = parser.getCurrentLocation();
+  if (parser.parseInteger(count))
+    return Type();
+  if (!llvm::is_contained({2, 3, 4, 8, 16}, count)) {
+    parser.emitError(countLoc,
+                     "vector length must be 2, 3, 4, 8, or 16, but got ")
+        << count;
+    return Type();
+  }
+  if (parser.parseComma())
+    return Type();
+
+  // Anchor a non-pointer diagnostic at the element type, not the keyword.
+  SMLoc elementLoc = parser.getCurrentLocation();
+  Type elementType = parseAndVerifyType(dialect, parser);
+  if (!elementType)
+    return Type();
+  auto ptrType = dyn_cast<spirv::PointerType>(elementType);
+  if (!ptrType) {
+    parser.emitError(elementLoc,
+                     "vecptr element type must be a spirv.ptr type");
+    return Type();
+  }
+  if (parser.parseGreater())
+    return Type();
+  return VectorOfPointerType::get(ptrType, count);
+}
+
// spirv-type ::= array-type
// | element-type
// | image-type
@@ -844,6 +881,7 @@ static Type parseStructType(SPIRVDialect const &dialect,
// | runtime-array-type
// | sampled-image-type
// | struct-type
+// | vecptr-type
Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
StringRef keyword;
if (parser.parseKeyword(&keyword))
@@ -867,6 +905,8 @@ Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
return parseMatrixType(*this, parser);
if (keyword == "arm.tensor")
return parseTensorArmType(*this, parser);
+ if (keyword == "vecptr")
+ return parseVectorOfPointerType(*this, parser);
parser.emitError(parser.getNameLoc(), "unknown SPIR-V type: ") << keyword;
return Type();
}
@@ -998,11 +1038,16 @@ static void print(TensorArmType type, DialectAsmPrinter &os) {
os << "x" << type.getElementType() << ">";
}
+// Prints `vecptr<N, pointer-type>`, the syntax accepted by
+// parseVectorOfPointerType.
+static void print(VectorOfPointerType type, DialectAsmPrinter &os) {
+ os << "vecptr<" << type.getNumElements() << ", " << type.getElementType()
+ << ">";
+}
+
void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
TypeSwitch<Type>(type)
.Case<ArrayType, CooperativeMatrixType, PointerType, RuntimeArrayType,
- ImageType, SampledImageType, StructType, MatrixType, TensorArmType>(
- [&](auto type) { print(type, os); })
+ ImageType, SampledImageType, StructType, MatrixType, TensorArmType,
+ VectorOfPointerType>([&](auto type) { print(type, os); })
.DefaultUnreachable("Unhandled SPIR-V type");
}
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index cecc8c2194237..9be4d07eeca31 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1415,6 +1415,75 @@ LogicalResult spirv::INTELSubgroupBlockWriteOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// spirv.INTEL.MaskedGather
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::INTELMaskedGatherOp::verify() {
+  auto ptrVecType = cast<spirv::VectorOfPointerType>(getPtrVector().getType());
+  auto resultType = cast<VectorType>(getResult().getType());
+  unsigned numElems = resultType.getNumElements();
+
+  // The result constraint (SPIRV_Vector) also admits boolean vectors, but the
+  // gathered data must be numerical as stated in the op description.
+  if (resultType.getElementType().isInteger(1))
+    return emitOpError("result must be a vector of numerical type");
+
+  // Verify pointee type matches result element type.
+  if (ptrVecType.getElementType().getPointeeType() !=
+      resultType.getElementType())
+    return emitOpError(
+        "pointer pointee type must match result vector element type");
+
+  // Verify element counts match.
+  if (ptrVecType.getNumElements() != numElems)
+    return emitOpError(
+        "ptr_vector must have the same number of elements as result");
+
+  // Verify mask is a vector of i1.
+  auto maskType = cast<VectorType>(getMask().getType());
+  if (!maskType.getElementType().isInteger(1))
+    return emitOpError("mask must be a vector of i1");
+  if (maskType.getNumElements() != numElems)
+    return emitOpError("mask must have the same number of elements as result");
+
+  // Verify fill_empty matches result type.
+  if (getFillEmpty().getType() != resultType)
+    return emitOpError("fill_empty must have the same type as result");
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.INTEL.MaskedScatter
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::INTELMaskedScatterOp::verify() {
+  auto ptrVecType = cast<spirv::VectorOfPointerType>(getPtrVector().getType());
+  auto inputType = cast<VectorType>(getInputVector().getType());
+  unsigned numElems = inputType.getNumElements();
+
+  // Verify pointee type matches input element type.
+  if (ptrVecType.getElementType().getPointeeType() !=
+      inputType.getElementType())
+    return emitOpError(
+        "pointer pointee type must match input vector element type");
+
+  // Verify element counts match.
+  if (ptrVecType.getNumElements() != numElems)
+    return emitOpError(
+        "ptr_vector must have the same number of elements as input_vector");
+
+  // Verify mask is a vector of i1.
+  auto maskType = cast<VectorType>(getMask().getType());
+  if (!maskType.getElementType().isInteger(1))
+    return emitOpError("mask must be a vector of i1");
+  if (maskType.getNumElements() != numElems)
+    return emitOpError(
+        "mask must have the same number of elements as input_vector");
+
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// spirv.IAddCarryOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 331d98c1d9313..c03ca72c0a908 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -50,6 +50,9 @@ class TypeExtensionVisitor {
[this](auto concreteType) { addConcrete(concreteType); })
.Case<ArrayType, ImageType, MatrixType, RuntimeArrayType, VectorType>(
[this](auto concreteType) { add(concreteType.getElementType()); })
+ .Case([this](VectorOfPointerType concreteType) {
+ // NOTE(review): only the pointer element's extensions are collected;
+ // confirm whether a vector of pointers should additionally require
+ // SPV_INTEL_masked_gather_scatter.
+ add(concreteType.getElementType());
+ })
.Case([this](SampledImageType concreteType) {
add(concreteType.getImageType());
})
@@ -100,6 +103,9 @@ class TypeCapabilityVisitor {
.Case([this](ArrayType concreteType) {
add(concreteType.getElementType());
})
+ .Case([this](VectorOfPointerType concreteType) {
+ // NOTE(review): only the pointer element's capabilities are collected;
+ // confirm whether a vector of pointers should additionally require
+ // the MaskedGatherScatterINTEL capability.
+ add(concreteType.getElementType());
+ })
.Case([this](SampledImageType concreteType) {
add(concreteType.getImageType());
})
@@ -181,8 +187,8 @@ bool CompositeType::classof(Type type) {
if (auto vectorType = dyn_cast<VectorType>(type))
return isValid(vectorType);
return isa<spirv::ArrayType, spirv::CooperativeMatrixType, spirv::MatrixType,
- spirv::RuntimeArrayType, spirv::StructType, spirv::TensorArmType>(
- type);
+ spirv::RuntimeArrayType, spirv::StructType, spirv::TensorArmType,
+ spirv::VectorOfPointerType>(type);
}
bool CompositeType::isValid(VectorType type) {
@@ -195,6 +201,9 @@ Type CompositeType::getElementType(unsigned index) const {
return TypeSwitch<Type, Type>(*this)
.Case<ArrayType, CooperativeMatrixType, RuntimeArrayType, VectorType,
TensorArmType>([](auto type) { return type.getElementType(); })
+ .Case([](VectorOfPointerType type) -> Type {
+ return type.getElementType();
+ })
.Case([](MatrixType type) { return type.getColumnType(); })
.Case([index](StructType type) { return type.getElementType(index); })
.DefaultUnreachable("Invalid composite type");
@@ -202,7 +211,8 @@ Type CompositeType::getElementType(unsigned index) const {
unsigned CompositeType::getNumElements() const {
return TypeSwitch<SPIRVType, unsigned>(*this)
- .Case<ArrayType, StructType, TensorArmType, VectorType>(
+ .Case<ArrayType, StructType, TensorArmType, VectorType,
+ VectorOfPointerType>(
[](auto type) { return type.getNumElements(); })
.Case([](MatrixType type) { return type.getNumColumns(); })
.DefaultUnreachable("Invalid type for number of elements query");
@@ -1325,11 +1335,49 @@ TensorArmType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
return success();
}
+//===----------------------------------------------------------------------===//
+// VectorOfPointerType
+//===----------------------------------------------------------------------===//
+
+/// Uniquing storage for VectorOfPointerType. The key is the pair
+/// (pointer element type, number of components).
+struct spirv::detail::VectorOfPointerTypeStorage final : TypeStorage {
+ using KeyTy = std::pair<PointerType, unsigned>;
+
+ static VectorOfPointerTypeStorage *construct(TypeStorageAllocator &allocator,
+ const KeyTy &key) {
+ return new (allocator.allocate<VectorOfPointerTypeStorage>())
+ VectorOfPointerTypeStorage(key);
+ }
+
+ bool operator==(const KeyTy &key) const {
+ return key == KeyTy(elementType, numElements);
+ }
+
+ VectorOfPointerTypeStorage(const KeyTy &key)
+ : elementType(key.first), numElements(key.second) {}
+
+ PointerType elementType;
+ unsigned numElements;
+};
+
+// `numElements` is not validated here: the textual parser restricts it to
+// 2/3/4/8/16, while the deserializer passes the OpTypeVector count through.
+VectorOfPointerType VectorOfPointerType::get(PointerType elementType,
+ unsigned numElements) {
+ return Base::get(elementType.getContext(), elementType, numElements);
+}
+
+PointerType VectorOfPointerType::getElementType() const {
+ return getImpl()->elementType;
+}
+
+unsigned VectorOfPointerType::getNumElements() const {
+ return getImpl()->numElements;
+}
+
//===----------------------------------------------------------------------===//
// SPIR-V Dialect
//===----------------------------------------------------------------------===//
void SPIRVDialect::registerTypes() {
addTypes<ArrayType, CooperativeMatrixType, ImageType, MatrixType, PointerType,
- RuntimeArrayType, SampledImageType, StructType, TensorArmType>();
+ RuntimeArrayType, SampledImageType, StructType, TensorArmType,
+ VectorOfPointerType>();
}
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index e218bee4b4fe3..9ea8f0f06337c 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -1137,7 +1137,11 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
return emitError(unknownLoc, "OpTypeVector references undefined <id> ")
<< operands[1];
}
- typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy);
+ // Builtin VectorType cannot hold pointer elements, so a vector of
+ // pointers is represented as spirv::VectorOfPointerType instead.
+ if (auto ptrType = dyn_cast<spirv::PointerType>(elementTy))
+ typeMap[operands[0]] =
+ spirv::VectorOfPointerType::get(ptrType, operands[2]);
+ else
+ typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy);
} break;
case spirv::Opcode::OpTypePointer: {
return processOpTypePointer(operands);
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index c21cb27b072f1..c3ee16486b751 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -626,6 +626,18 @@ LogicalResult Serializer::prepareBasicType(
return success();
}
+ // A vector of pointers is serialized as a plain OpTypeVector whose
+ // component type is the serialized pointer type.
+ if (auto vecPtrType = dyn_cast<spirv::VectorOfPointerType>(type)) {
+ uint32_t elementTypeID = 0;
+ if (failed(processTypeImpl(loc, vecPtrType.getElementType(), elementTypeID,
+ serializationCtx))) {
+ return failure();
+ }
+ typeEnum = spirv::Opcode::OpTypeVector;
+ operands.push_back(elementTypeID);
+ operands.push_back(vecPtrType.getNumElements());
+ return success();
+ }
+
if (auto imageType = dyn_cast<spirv::ImageType>(type)) {
typeEnum = spirv::Opcode::OpTypeImage;
uint32_t sampledTypeID = 0;
diff --git a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
index 6e4126172f670..826fc0f964361 100644
--- a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
@@ -100,7 +100,7 @@ func.func @composite_construct_vector_wrong_count(%arg0: f32, %arg1: f32, %arg2
// -----
func.func @composite_construct_vector_rank_two(%arg0: vector<2x2xi1>, %arg1: vector<2x2xi1>) -> vector<4x2xi1> {
- // expected-error @+1 {{ op operand #0 must be variadic of void or bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 or Float8E4M3 or Float8E5M2 or vector of bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 or Float8E4M3 or Float8E5M2 values of length 2/3/4/8/16 of ranks 1 or any SPIR-V pointer type or any SPIR-V array type or any SPIR-V runtime array type or any SPIR-V struct type or any SPIR-V cooperative matrix type or any SPIR-V matrix type or any SPIR-V sampled image type or any SPIR-V image type or any SPIR-V tensorArm type, but got 'vector<2x2xi1>'}}
+ // expected-error @+1 {{ op operand #0 must be variadic of void or bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 or Float8E4M3 or Float8E5M2 or vector of bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 or Float8E4M3 or Float8E5M2 values of length 2/3/4/8/16 of ranks 1 or any SPIR-V pointer type or any SPIR-V array type or any SPIR-V runtime array type or any SPIR-V struct type or any SPIR-V cooperative matrix type or any SPIR-V matrix type or any SPIR-V sampled image type or any SPIR-V image type or any SPIR-V tensorArm type or any SPIR-V vector of pointer type, but got 'vector<2x2xi1>'}}
%0 = spirv.CompositeConstruct %arg0, %arg1 : (vector<2x2xi1>, vector<2x2xi1>) -> vector<4x2xi1>
return %0: vector<4x2xi1>
}
diff --git a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
index d124c02231161..fa3c90fef82c9 100644
--- a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
@@ -126,6 +126,200 @@ spirv.func @split_barrier() "None" {
-// spirv.INTEL.CacheControls
+// spirv.INTEL.MaskedGather
//===----------------------------------------------------------------------===//
+
+spirv.func @masked_gather(
+  %ptrs : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>,
+  %alignment : i32,
+  %mask : vector<4xi1>,
+  %fill : vector<4xf32>) "None" {
+  // CHECK: {{%.*}} = spirv.INTEL.MaskedGather
+  %0 = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
+    : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>, i32,
+      vector<4xi1>, vector<4xf32> -> vector<4xf32>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @masked_gather_i32(
+  %ptrs : !spirv.vecptr<4, !spirv.ptr<i32, CrossWorkgroup>>,
+  %alignment : i32,
+  %mask : vector<4xi1>,
+  %fill : vector<4xi32>) "None" {
+  // CHECK: {{%.*}} = spirv.INTEL.MaskedGather
+  %0 = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
+    : !spirv.vecptr<4, !spirv.ptr<i32, CrossWorkgroup>>, i32,
+      vector<4xi1>, vector<4xi32> -> vector<4xi32>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @masked_gather_pointee_type_mismatch(
+  %ptrs : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>,
+  %alignment : i32,
+  %mask : vector<4xi1>,
+  %fill : vector<4xi32>) "None" {
+  // expected-error @+1 {{pointer pointee type must match result vector element type}}
+  %0 = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
+    : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>, i32,
+      vector<4xi1>, vector<4xi32> -> vector<4xi32>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @masked_gather_elem_count_mismatch(
+  %ptrs : !spirv.vecptr<2, !spirv.ptr<f32, CrossWorkgroup>>,
+  %alignment : i32,
+  %mask : vector<4xi1>,
+  %fill : vector<4xf32>) "None" {
+  // expected-error @+1 {{ptr_vector must have the same number of elements as result}}
+  %0 = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
+    : !spirv.vecptr<2, !spirv.ptr<f32, CrossWorkgroup>>, i32,
+      vector<4xi1>, vector<4xf32> -> vector<4xf32>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @masked_gather_mask_not_bool(
+  %ptrs : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>,
+  %alignment : i32,
+  %mask : vector<4xi8>,
+  %fill : vector<4xf32>) "None" {
+  // expected-error @+1 {{mask must be a vector of i1}}
+  %0 = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
+    : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>, i32,
+      vector<4xi8>, vector<4xf32> -> vector<4xf32>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @masked_gather_mask_count_mismatch(
+  %ptrs : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>,
+  %alignment : i32,
+  %mask : vector<2xi1>,
+  %fill : vector<4xf32>) "None" {
+  // expected-error @+1 {{mask must have the same number of elements as result}}
+  %0 = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
+    : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>, i32,
+      vector<2xi1>, vector<4xf32> -> vector<4xf32>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @masked_gather_fill_type_mismatch(
+  %ptrs : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>,
+  %alignment : i32,
+  %mask : vector<4xi1>,
+  %fill : vector<4xi32>) "None" {
+  // expected-error @+1 {{fill_empty must have the same type as result}}
+  %0 = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
+    : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>, i32,
+      vector<4xi1>, vector<4xi32> -> vector<4xf32>
+  spirv.Return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.INTEL.MaskedScatter
+//===----------------------------------------------------------------------===//
+
+spirv.func @masked_scatter(
+  %ptrs : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>,
+  %alignment : i32,
+  %mask : vector<4xi1>,
+  %values : vector<4xf32>) "None" {
+  // CHECK: spirv.INTEL.MaskedScatter
+  spirv.INTEL.MaskedScatter %ptrs, %alignment, %mask, %values
+    : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>, i32,
+      vector<4xi1>, vector<4xf32>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @masked_scatter_pointee_mismatch(
+  %ptrs : !spirv.vecptr<4, !spirv.ptr<i32, CrossWorkgroup>>,
+  %alignment : i32,
+  %mask : vector<4xi1>,
+  %values : vector<4xf32>) "None" {
+  // expected-error @+1 {{pointer pointee type must match input vector element type}}
+  spirv.INTEL.MaskedScatter %ptrs, %alignment, %mask, %values
+    : !spirv.vecptr<4, !spirv.ptr<i32, CrossWorkgroup>>, i32,
+      vector<4xi1>, vector<4xf32>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @masked_scatter_mask_count_mismatch(
+  %ptrs : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>,
+  %alignment : i32,
+  %mask : vector<2xi1>,
+  %values : vector<4xf32>) "None" {
+  // expected-error @+1 {{mask must have the same number of elements as input_vector}}
+  spirv.INTEL.MaskedScatter %ptrs, %alignment, %mask, %values
+    : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>, i32,
+      vector<2xi1>, vector<4xf32>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @masked_scatter_ptr_count_mismatch(
+  %ptrs : !spirv.vecptr<2, !spirv.ptr<f32, CrossWorkgroup>>,
+  %alignment : i32,
+  %mask : vector<4xi1>,
+  %values : vector<4xf32>) "None" {
+  // expected-error @+1 {{ptr_vector must have the same number of elements as input_vector}}
+  spirv.INTEL.MaskedScatter %ptrs, %alignment, %mask, %values
+    : !spirv.vecptr<2, !spirv.ptr<f32, CrossWorkgroup>>, i32,
+      vector<4xi1>, vector<4xf32>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @masked_scatter_mask_not_bool(
+  %ptrs : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>,
+  %alignment : i32,
+  %mask : vector<4xi8>,
+  %values : vector<4xf32>) "None" {
+  // expected-error @+1 {{mask must be a vector of i1}}
+  spirv.INTEL.MaskedScatter %ptrs, %alignment, %mask, %values
+    : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>, i32,
+      vector<4xi8>, vector<4xf32>
+  spirv.Return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// !spirv.vecptr type
+//===----------------------------------------------------------------------===//
+
+// expected-error @+1 {{vector length must be 2, 3, 4, 8, or 16, but got 5}}
+spirv.func @vecptr_bad_length(%ptrs : !spirv.vecptr<5, !spirv.ptr<f32, CrossWorkgroup>>) "None" {
+  spirv.Return
+}
+
+// -----
+
+// expected-error @+1 {{vecptr element type must be a spirv.ptr type}}
+spirv.func @vecptr_non_pointer_element(%v : !spirv.vecptr<4, f32>) "None" {
+  spirv.Return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.INTEL.CacheControls
+//===----------------------------------------------------------------------===//
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [CacheControlsINTEL], [SPV_INTEL_cache_controls]> {
spirv.func @foo() "None" {
// CHECK: spirv.Variable {cache_control_load_intel = [#spirv.cache_control_load_intel<cache_level = 0, load_cache_control = Uncached>, #spirv.cache_control_load_intel<cache_level = 1, load_cache_control = Cached>, #spirv.cache_control_load_intel<cache_level = 2, load_cache_control = InvalidateAfterR>]} : !spirv.ptr<f32, Function>
diff --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
index e12b70cc5c139..5c219e0c86191 100644
--- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
@@ -170,7 +170,7 @@ func.func @coop_matrix_const_wrong_type() -> () {
//===----------------------------------------------------------------------===//
func.func @ccr_result_not_composite() -> () {
- // expected-error @+1 {{op result #0 must be vector of bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 or Float8E4M3 or Float8E5M2 values of length 2/3/4/8/16 of ranks 1 or any SPIR-V array type or any SPIR-V runtime array type or any SPIR-V struct type or any SPIR-V cooperative matrix type or any SPIR-V matrix type or any SPIR-V tensorArm type, but got 'i32'}}
+ // expected-error @+1 {{op result #0 must be vector of bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 or Float8E4M3 or Float8E5M2 values of length 2/3/4/8/16 of ranks 1 or any SPIR-V array type or any SPIR-V runtime array type or any SPIR-V struct type or any SPIR-V cooperative matrix type or any SPIR-V matrix type or any SPIR-V tensorArm type or any SPIR-V vector of pointer type, but got 'i32'}}
%0 = spirv.EXT.ConstantCompositeReplicate [1 : i32] : i32
return
}
diff --git a/mlir/test/Target/SPIRV/intel-ext-ops.mlir b/mlir/test/Target/SPIRV/intel-ext-ops.mlir
index 118bed8be7095..682fa2cd8320d 100644
--- a/mlir/test/Target/SPIRV/intel-ext-ops.mlir
+++ b/mlir/test/Target/SPIRV/intel-ext-ops.mlir
@@ -59,6 +59,66 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage, TensorF
// -----
+//===----------------------------------------------------------------------===//
+// spirv.INTEL.MaskedGather / MaskedScatter
+//===----------------------------------------------------------------------===//
+
+spirv.module Physical64 OpenCL requires #spirv.vce<v1.0, [Kernel, Addresses, Linkage, MaskedGatherScatterINTEL], [SPV_INTEL_masked_gather_scatter]> {
+ // CHECK-LABEL: @masked_gather_f32
+ spirv.func @masked_gather_f32(
+ %ptrs : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>,
+ %alignment : i32,
+ %mask : vector<4xi1>,
+ %fill : vector<4xf32>) "None" {
+ // CHECK: spirv.INTEL.MaskedGather
+ %0 = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
+ : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>, i32,
+ vector<4xi1>, vector<4xf32> -> vector<4xf32>
+ spirv.Return
+ }
+
+ // CHECK-LABEL: @masked_gather_i32
+ spirv.func @masked_gather_i32(
+ %ptrs : !spirv.vecptr<4, !spirv.ptr<i32, CrossWorkgroup>>,
+ %alignment : i32,
+ %mask : vector<4xi1>,
+ %fill : vector<4xi32>) "None" {
+ // CHECK: spirv.INTEL.MaskedGather
+ %0 = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
+ : !spirv.vecptr<4, !spirv.ptr<i32, CrossWorkgroup>>, i32,
+ vector<4xi1>, vector<4xi32> -> vector<4xi32>
+ spirv.Return
+ }
+
+ // CHECK-LABEL: @masked_scatter_f32
+ spirv.func @masked_scatter_f32(
+ %ptrs : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>,
+ %alignment : i32,
+ %mask : vector<4xi1>,
+ %values : vector<4xf32>) "None" {
+ // CHECK: spirv.INTEL.MaskedScatter
+ spirv.INTEL.MaskedScatter %ptrs, %alignment, %mask, %values
+ : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>, i32,
+ vector<4xi1>, vector<4xf32>
+ spirv.Return
+ }
+
+ // CHECK-LABEL: @masked_scatter_i32
+ spirv.func @masked_scatter_i32(
+ %ptrs : !spirv.vecptr<4, !spirv.ptr<i32, CrossWorkgroup>>,
+ %alignment : i32,
+ %mask : vector<4xi1>,
+ %values : vector<4xi32>) "None" {
+ // CHECK: spirv.INTEL.MaskedScatter
+ spirv.INTEL.MaskedScatter %ptrs, %alignment, %mask, %values
+ : !spirv.vecptr<4, !spirv.ptr<i32, CrossWorkgroup>>, i32,
+ vector<4xi1>, vector<4xi32>
+ spirv.Return
+ }
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.INTEL.SplitBarrier
//===----------------------------------------------------------------------===//
>From c35ea72587c8be892af69ff4930db8bb2472a84c Mon Sep 17 00:00:00 2001
From: Arseniy Obolenskiy <arseniy.obolenskiy at amd.com>
Date: Tue, 7 Apr 2026 06:40:46 +0200
Subject: [PATCH 2/3] Address review comments: replace VectorOfPointerType with builtin vector of !spirv.ptr
---
.../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 9 +--
.../mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td | 45 +++++++++----
.../mlir/Dialect/SPIRV/IR/SPIRVTypes.h | 26 ++------
mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp | 49 +-------------
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 65 -------------------
mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | 59 ++---------------
.../SPIRV/Deserialization/Deserializer.cpp | 6 +-
.../Target/SPIRV/Serialization/Serializer.cpp | 12 ----
mlir/test/Dialect/SPIRV/IR/composite-ops.mlir | 2 +-
mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir | 42 ++++++------
mlir/test/Dialect/SPIRV/IR/structure-ops.mlir | 2 +-
mlir/test/Target/SPIRV/intel-ext-ops.mlir | 16 ++---
12 files changed, 79 insertions(+), 254 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index d9a573b59c368..cd6ea54555f4b 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4266,7 +4266,6 @@ def SPIRV_IsRTArrayType : CPred<"::llvm::isa<::mlir::spirv::RuntimeArrayType>($_
def SPIRV_IsSampledImageType : CPred<"::llvm::isa<::mlir::spirv::SampledImageType>($_self)">;
def SPIRV_IsStructType : CPred<"::llvm::isa<::mlir::spirv::StructType>($_self)">;
def SPIRV_IsTensorArmType : CPred<"::llvm::isa<::mlir::spirv::TensorArmType>($_self)">;
-def SPIRV_IsVectorOfPointerType : CPred<"::llvm::isa<::mlir::spirv::VectorOfPointerType>($_self)">;
// See https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_types
// for the definition of the following types and type categories.
@@ -4310,22 +4309,18 @@ def SPIRV_AnySampledImage : DialectType<SPIRV_Dialect, SPIRV_IsSampledImageType,
"any SPIR-V sampled image type">;
def SPIRV_AnyTensorArm : DialectType<SPIRV_Dialect, SPIRV_IsTensorArmType,
"any SPIR-V tensorArm type">;
-def SPIRV_AnyVectorOfPointer : DialectType<SPIRV_Dialect,
- SPIRV_IsVectorOfPointerType,
- "any SPIR-V vector of pointer type">;
def SPIRV_Numerical : AnyTypeOf<[SPIRV_Integer, SPIRV_AnyFloat]>;
def SPIRV_Scalar : AnyTypeOf<[SPIRV_Numerical, SPIRV_Bool]>;
def SPIRV_Aggregate : AnyTypeOf<[SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct]>;
def SPIRV_Composite :
AnyTypeOf<[SPIRV_Vector, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
- SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnyTensorArm,
- SPIRV_AnyVectorOfPointer]>;
+ SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnyTensorArm]>;
def SPIRV_Type : AnyTypeOf<[
SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_AnyFloat, SPIRV_Vector,
SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage,
- SPIRV_AnyImage, SPIRV_AnyTensorArm, SPIRV_AnyVectorOfPointer
+ SPIRV_AnyImage, SPIRV_AnyTensorArm
]>;
def SPIRV_SignedInt : SignedIntOfWidths<[8, 16, 32, 64]>;
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
index 0ea57960fec74..1502187683c1f 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
@@ -252,8 +252,19 @@ def SPIRV_INTELControlBarrierWaitOp
// -----
-def SPIRV_INTELMaskedGatherOp : SPIRV_IntelVendorOp<"MaskedGather", []> {
- let summary = "See extension SPV_INTEL_masked_gather_scatter";
+def SPIRV_INTELMaskedGatherOp : SPIRV_IntelVendorOp<"MaskedGather",
+ [AllTypesMatch<["fill_empty", "result"]>,
+ TypesMatchWith<"pointee type of ptr_vector must match result element type",
+ "ptr_vector", "result",
+ "VectorType::get("
+ "::llvm::cast<VectorType>($_self).getShape(), "
+ "::llvm::cast<spirv::PointerType>("
+ "::llvm::cast<VectorType>($_self).getElementType())"
+ ".getPointeeType())">,
+ TypesMatchWith<"mask must be a vector of i1 matching result shape",
+ "result", "mask",
+ "getUnaryOpResultType($_self)">]> {
+ let summary = "Gather values from memory using a vector of pointers and a mask";
let description = [{
Reads values from a vector of pointers gathering them into a result
@@ -277,7 +288,7 @@ def SPIRV_INTELMaskedGatherOp : SPIRV_IntelVendorOp<"MaskedGather", []> {
```mlir
%result = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
- : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>, i32,
+ : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32,
vector<4xi1>, vector<4xf32> -> vector<4xf32>
```
}];
@@ -290,9 +301,9 @@ def SPIRV_INTELMaskedGatherOp : SPIRV_IntelVendorOp<"MaskedGather", []> {
];
let arguments = (ins
- SPIRV_AnyVectorOfPointer:$ptr_vector,
+ SPIRV_VectorOf<SPIRV_AnyPtr>:$ptr_vector,
SPIRV_Int32:$alignment,
- SPIRV_Vector:$mask,
+ SPIRV_VectorOf<I1>:$mask,
SPIRV_Vector:$fill_empty
);
@@ -306,13 +317,23 @@ def SPIRV_INTELMaskedGatherOp : SPIRV_IntelVendorOp<"MaskedGather", []> {
type($mask) `,` type($fill_empty) `->` type($result)
}];
- let hasVerifier = 1;
+ let hasVerifier = 0;
}
// -----
-def SPIRV_INTELMaskedScatterOp : SPIRV_IntelVendorOp<"MaskedScatter", []> {
- let summary = "See extension SPV_INTEL_masked_gather_scatter";
+def SPIRV_INTELMaskedScatterOp : SPIRV_IntelVendorOp<"MaskedScatter",
+ [TypesMatchWith<"pointee type of ptr_vector must match input element type",
+ "ptr_vector", "input_vector",
+ "VectorType::get("
+ "::llvm::cast<VectorType>($_self).getShape(), "
+ "::llvm::cast<spirv::PointerType>("
+ "::llvm::cast<VectorType>($_self).getElementType())"
+ ".getPointeeType())">,
+ TypesMatchWith<"mask must be a vector of i1 matching input shape",
+ "input_vector", "mask",
+ "getUnaryOpResultType($_self)">]> {
+ let summary = "Scatter values to memory using a vector of pointers and a mask";
let description = [{
Writes values from a vector into memory locations pointed to by a
@@ -333,7 +354,7 @@ def SPIRV_INTELMaskedScatterOp : SPIRV_IntelVendorOp<"MaskedScatter", []> {
```mlir
spirv.INTEL.MaskedScatter %ptrs, %alignment, %mask, %values
- : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>, i32,
+ : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32,
vector<4xi1>, vector<4xf32>
```
}];
@@ -346,9 +367,9 @@ def SPIRV_INTELMaskedScatterOp : SPIRV_IntelVendorOp<"MaskedScatter", []> {
];
let arguments = (ins
- SPIRV_AnyVectorOfPointer:$ptr_vector,
+ SPIRV_VectorOf<SPIRV_AnyPtr>:$ptr_vector,
SPIRV_Int32:$alignment,
- SPIRV_Vector:$mask,
+ SPIRV_VectorOf<I1>:$mask,
SPIRV_Vector:$input_vector
);
@@ -360,7 +381,7 @@ def SPIRV_INTELMaskedScatterOp : SPIRV_IntelVendorOp<"MaskedScatter", []> {
type($mask) `,` type($input_vector)
}];
- let hasVerifier = 1;
+ let hasVerifier = 0;
}
// -----
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 085482b6099d7..90b695a7ba2fd 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -14,6 +14,7 @@
#define MLIR_DIALECT_SPIRV_IR_SPIRVTYPES_H_
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Location.h"
@@ -36,7 +37,6 @@ struct PointerTypeStorage;
struct RuntimeArrayTypeStorage;
struct SampledImageTypeStorage;
struct StructTypeStorage;
-struct VectorOfPointerTypeStorage;
} // namespace detail
@@ -173,8 +173,9 @@ class ImageType
};
// SPIR-V pointer type
-class PointerType : public Type::TypeBase<PointerType, SPIRVType,
- detail::PointerTypeStorage> {
+class PointerType
+ : public Type::TypeBase<PointerType, SPIRVType, detail::PointerTypeStorage,
+ VectorElementTypeInterface::Trait> {
public:
using Base::Base;
@@ -516,25 +517,6 @@ class TensorArmType
operator ShapedType() const { return cast<ShapedType>(*this); }
};
-/// SPIR-V vector of pointers type. Represents an OpTypeVector whose element
-/// type is an OpTypePointer. This is needed because MLIR's built-in VectorType
-/// does not support pointer element types. Used by the
-/// SPV_INTEL_masked_gather_scatter extension.
-class VectorOfPointerType
- : public Type::TypeBase<VectorOfPointerType, CompositeType,
- detail::VectorOfPointerTypeStorage> {
-public:
- using Base::Base;
-
- static constexpr StringLiteral name = "spirv.vecptr";
-
- static VectorOfPointerType get(PointerType elementType, unsigned numElements);
-
- PointerType getElementType() const;
-
- unsigned getNumElements() const;
-};
-
} // namespace spirv
} // namespace mlir
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 160d460f2cf6f..c9b22fe145d88 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -837,43 +837,6 @@ static Type parseStructType(SPIRVDialect const &dialect,
structDecorationInfo);
}
-// vecptr-type ::= `vecptr` `<` integer-literal `x` pointer-type `>`
-static Type parseVectorOfPointerType(SPIRVDialect const &dialect,
- DialectAsmParser &parser) {
- if (parser.parseLess())
- return Type();
-
- int64_t count = 0;
- SMLoc countLoc = parser.getCurrentLocation();
- if (parser.parseInteger(count))
- return Type();
-
- if (parser.parseComma())
- return Type();
- if (!llvm::is_contained({2, 3, 4, 8, 16}, count)) {
- parser.emitError(countLoc,
- "vector length must be 2, 3, 4, 8, or 16, but got ")
- << count;
- return Type();
- }
-
- Type elementType = parseAndVerifyType(dialect, parser);
- if (!elementType)
- return Type();
-
- auto ptrType = dyn_cast<spirv::PointerType>(elementType);
- if (!ptrType) {
- parser.emitError(parser.getNameLoc(),
- "vecptr element type must be a spirv.ptr type");
- return Type();
- }
-
- if (parser.parseGreater())
- return Type();
-
- return VectorOfPointerType::get(ptrType, count);
-}
-
// spirv-type ::= array-type
// | element-type
// | image-type
@@ -881,7 +844,6 @@ static Type parseVectorOfPointerType(SPIRVDialect const &dialect,
// | runtime-array-type
// | sampled-image-type
// | struct-type
-// | vecptr-type
Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
StringRef keyword;
if (parser.parseKeyword(&keyword))
@@ -905,8 +867,6 @@ Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
return parseMatrixType(*this, parser);
if (keyword == "arm.tensor")
return parseTensorArmType(*this, parser);
- if (keyword == "vecptr")
- return parseVectorOfPointerType(*this, parser);
parser.emitError(parser.getNameLoc(), "unknown SPIR-V type: ") << keyword;
return Type();
}
@@ -1038,16 +998,11 @@ static void print(TensorArmType type, DialectAsmPrinter &os) {
os << "x" << type.getElementType() << ">";
}
-static void print(VectorOfPointerType type, DialectAsmPrinter &os) {
- os << "vecptr<" << type.getNumElements() << ", " << type.getElementType()
- << ">";
-}
-
void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
TypeSwitch<Type>(type)
.Case<ArrayType, CooperativeMatrixType, PointerType, RuntimeArrayType,
- ImageType, SampledImageType, StructType, MatrixType, TensorArmType,
- VectorOfPointerType>([&](auto type) { print(type, os); })
+ ImageType, SampledImageType, StructType, MatrixType, TensorArmType>(
+ [&](auto type) { print(type, os); })
.DefaultUnreachable("Unhandled SPIR-V type");
}
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 9be4d07eeca31..cecc8c2194237 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1415,71 +1415,6 @@ LogicalResult spirv::INTELSubgroupBlockWriteOp::verify() {
return success();
}
-//===----------------------------------------------------------------------===//
-// spirv.INTEL.MaskedGather
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::INTELMaskedGatherOp::verify() {
- auto ptrVecType = cast<spirv::VectorOfPointerType>(getPtrVector().getType());
- auto resultType = cast<VectorType>(getResult().getType());
- unsigned numElems = resultType.getNumElements();
-
- // Verify pointee type matches result element type.
- if (ptrVecType.getElementType().getPointeeType() !=
- resultType.getElementType())
- return emitOpError(
- "pointer pointee type must match result vector element type");
-
- // Verify element counts match.
- if (ptrVecType.getNumElements() != numElems)
- return emitOpError(
- "ptr_vector must have the same number of elements as result");
-
- // Verify mask is a vector of i1.
- auto maskType = cast<VectorType>(getMask().getType());
- if (!maskType.getElementType().isInteger(1))
- return emitOpError("mask must be a vector of i1");
- if (maskType.getNumElements() != numElems)
- return emitOpError("mask must have the same number of elements as result");
-
- // Verify fill_empty matches result type.
- if (getFillEmpty().getType() != resultType)
- return emitOpError("fill_empty must have the same type as result");
-
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.INTEL.MaskedScatter
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::INTELMaskedScatterOp::verify() {
- auto ptrVecType = cast<spirv::VectorOfPointerType>(getPtrVector().getType());
- auto inputType = cast<VectorType>(getInputVector().getType());
- unsigned numElems = inputType.getNumElements();
-
- // Verify pointee type matches input element type.
- if (ptrVecType.getElementType().getPointeeType() !=
- inputType.getElementType())
- return emitOpError(
- "pointer pointee type must match input vector element type");
-
- // Verify element counts match.
- if (ptrVecType.getNumElements() != numElems)
- return emitOpError(
- "ptr_vector must have the same number of elements as input_vector");
-
- // Verify mask is a vector of i1.
- auto maskType = cast<VectorType>(getMask().getType());
- if (!maskType.getElementType().isInteger(1))
- return emitOpError("mask must be a vector of i1");
- if (maskType.getNumElements() != numElems)
- return emitOpError(
- "mask must have the same number of elements as input_vector");
-
- return success();
-}
-
//===----------------------------------------------------------------------===//
// spirv.IAddCarryOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index c03ca72c0a908..55f7dafcab42d 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -50,9 +50,6 @@ class TypeExtensionVisitor {
[this](auto concreteType) { addConcrete(concreteType); })
.Case<ArrayType, ImageType, MatrixType, RuntimeArrayType, VectorType>(
[this](auto concreteType) { add(concreteType.getElementType()); })
- .Case([this](VectorOfPointerType concreteType) {
- add(concreteType.getElementType());
- })
.Case([this](SampledImageType concreteType) {
add(concreteType.getImageType());
})
@@ -103,9 +100,6 @@ class TypeCapabilityVisitor {
.Case([this](ArrayType concreteType) {
add(concreteType.getElementType());
})
- .Case([this](VectorOfPointerType concreteType) {
- add(concreteType.getElementType());
- })
.Case([this](SampledImageType concreteType) {
add(concreteType.getImageType());
})
@@ -187,23 +181,21 @@ bool CompositeType::classof(Type type) {
if (auto vectorType = dyn_cast<VectorType>(type))
return isValid(vectorType);
return isa<spirv::ArrayType, spirv::CooperativeMatrixType, spirv::MatrixType,
- spirv::RuntimeArrayType, spirv::StructType, spirv::TensorArmType,
- spirv::VectorOfPointerType>(type);
+ spirv::RuntimeArrayType, spirv::StructType, spirv::TensorArmType>(
+ type);
}
bool CompositeType::isValid(VectorType type) {
return type.getRank() == 1 &&
llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) &&
- isa<ScalarType>(type.getElementType());
+ (isa<ScalarType>(type.getElementType()) ||
+ isa<PointerType>(type.getElementType()));
}
Type CompositeType::getElementType(unsigned index) const {
return TypeSwitch<Type, Type>(*this)
.Case<ArrayType, CooperativeMatrixType, RuntimeArrayType, VectorType,
TensorArmType>([](auto type) { return type.getElementType(); })
- .Case([](VectorOfPointerType type) -> Type {
- return type.getElementType();
- })
.Case([](MatrixType type) { return type.getColumnType(); })
.Case([index](StructType type) { return type.getElementType(index); })
.DefaultUnreachable("Invalid composite type");
@@ -211,8 +203,7 @@ Type CompositeType::getElementType(unsigned index) const {
unsigned CompositeType::getNumElements() const {
return TypeSwitch<SPIRVType, unsigned>(*this)
- .Case<ArrayType, StructType, TensorArmType, VectorType,
- VectorOfPointerType>(
+ .Case<ArrayType, StructType, TensorArmType, VectorType>(
[](auto type) { return type.getNumElements(); })
.Case([](MatrixType type) { return type.getNumColumns(); })
.DefaultUnreachable("Invalid type for number of elements query");
@@ -1335,49 +1326,11 @@ TensorArmType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
return success();
}
-//===----------------------------------------------------------------------===//
-// VectorOfPointerType
-//===----------------------------------------------------------------------===//
-
-struct spirv::detail::VectorOfPointerTypeStorage final : TypeStorage {
- using KeyTy = std::pair<PointerType, unsigned>;
-
- static VectorOfPointerTypeStorage *construct(TypeStorageAllocator &allocator,
- const KeyTy &key) {
- return new (allocator.allocate<VectorOfPointerTypeStorage>())
- VectorOfPointerTypeStorage(key);
- }
-
- bool operator==(const KeyTy &key) const {
- return key == KeyTy(elementType, numElements);
- }
-
- VectorOfPointerTypeStorage(const KeyTy &key)
- : elementType(key.first), numElements(key.second) {}
-
- PointerType elementType;
- unsigned numElements;
-};
-
-VectorOfPointerType VectorOfPointerType::get(PointerType elementType,
- unsigned numElements) {
- return Base::get(elementType.getContext(), elementType, numElements);
-}
-
-PointerType VectorOfPointerType::getElementType() const {
- return getImpl()->elementType;
-}
-
-unsigned VectorOfPointerType::getNumElements() const {
- return getImpl()->numElements;
-}
-
//===----------------------------------------------------------------------===//
// SPIR-V Dialect
//===----------------------------------------------------------------------===//
void SPIRVDialect::registerTypes() {
addTypes<ArrayType, CooperativeMatrixType, ImageType, MatrixType, PointerType,
- RuntimeArrayType, SampledImageType, StructType, TensorArmType,
- VectorOfPointerType>();
+ RuntimeArrayType, SampledImageType, StructType, TensorArmType>();
}
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 9ea8f0f06337c..e218bee4b4fe3 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -1137,11 +1137,7 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
return emitError(unknownLoc, "OpTypeVector references undefined <id> ")
<< operands[1];
}
- if (auto ptrType = dyn_cast<spirv::PointerType>(elementTy))
- typeMap[operands[0]] =
- spirv::VectorOfPointerType::get(ptrType, operands[2]);
- else
- typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy);
+ typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy);
} break;
case spirv::Opcode::OpTypePointer: {
return processOpTypePointer(operands);
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index c3ee16486b751..c21cb27b072f1 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -626,18 +626,6 @@ LogicalResult Serializer::prepareBasicType(
return success();
}
- if (auto vecPtrType = dyn_cast<spirv::VectorOfPointerType>(type)) {
- uint32_t elementTypeID = 0;
- if (failed(processTypeImpl(loc, vecPtrType.getElementType(), elementTypeID,
- serializationCtx))) {
- return failure();
- }
- typeEnum = spirv::Opcode::OpTypeVector;
- operands.push_back(elementTypeID);
- operands.push_back(vecPtrType.getNumElements());
- return success();
- }
-
if (auto imageType = dyn_cast<spirv::ImageType>(type)) {
typeEnum = spirv::Opcode::OpTypeImage;
uint32_t sampledTypeID = 0;
diff --git a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
index 826fc0f964361..6e4126172f670 100644
--- a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
@@ -100,7 +100,7 @@ func.func @composite_construct_vector_wrong_count(%arg0: f32, %arg1: f32, %arg2
// -----
func.func @composite_construct_vector_rank_two(%arg0: vector<2x2xi1>, %arg1: vector<2x2xi1>) -> vector<4x2xi1> {
- // expected-error @+1 {{ op operand #0 must be variadic of void or bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 or Float8E4M3 or Float8E5M2 or vector of bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 or Float8E4M3 or Float8E5M2 values of length 2/3/4/8/16 of ranks 1 or any SPIR-V pointer type or any SPIR-V array type or any SPIR-V runtime array type or any SPIR-V struct type or any SPIR-V cooperative matrix type or any SPIR-V matrix type or any SPIR-V sampled image type or any SPIR-V image type or any SPIR-V tensorArm type or any SPIR-V vector of pointer type, but got 'vector<2x2xi1>'}}
+ // expected-error @+1 {{ op operand #0 must be variadic of void or bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 or Float8E4M3 or Float8E5M2 or vector of bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 or Float8E4M3 or Float8E5M2 values of length 2/3/4/8/16 of ranks 1 or any SPIR-V pointer type or any SPIR-V array type or any SPIR-V runtime array type or any SPIR-V struct type or any SPIR-V cooperative matrix type or any SPIR-V matrix type or any SPIR-V sampled image type or any SPIR-V image type or any SPIR-V tensorArm type, but got 'vector<2x2xi1>'}}
%0 = spirv.CompositeConstruct %arg0, %arg1 : (vector<2x2xi1>, vector<2x2xi1>) -> vector<4x2xi1>
return %0: vector<4x2xi1>
}
diff --git a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
index fa3c90fef82c9..06f33a235c660 100644
--- a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
@@ -131,13 +131,13 @@ spirv.func @split_barrier() "None" {
//===----------------------------------------------------------------------===//
spirv.func @masked_gather(
- %ptrs : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>,
+ %ptrs : vector<4x!spirv.ptr<f32, CrossWorkgroup>>,
%alignment : i32,
%mask : vector<4xi1>,
%fill : vector<4xf32>) "None" {
// CHECK: {{%.*}} = spirv.INTEL.MaskedGather
%0 = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
- : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>, i32,
+ : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32,
vector<4xi1>, vector<4xf32> -> vector<4xf32>
spirv.Return
}
@@ -145,13 +145,13 @@ spirv.func @masked_gather(
// -----
spirv.func @masked_gather_i32(
- %ptrs : !spirv.vecptr<4, !spirv.ptr<i32, CrossWorkgroup>>,
+ %ptrs : vector<4x!spirv.ptr<i32, CrossWorkgroup>>,
%alignment : i32,
%mask : vector<4xi1>,
%fill : vector<4xi32>) "None" {
// CHECK: {{%.*}} = spirv.INTEL.MaskedGather
%0 = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
- : !spirv.vecptr<4, !spirv.ptr<i32, CrossWorkgroup>>, i32,
+ : vector<4x!spirv.ptr<i32, CrossWorkgroup>>, i32,
vector<4xi1>, vector<4xi32> -> vector<4xi32>
spirv.Return
}
@@ -159,13 +159,13 @@ spirv.func @masked_gather_i32(
// -----
spirv.func @masked_gather_pointee_type_mismatch(
- %ptrs : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>,
+ %ptrs : vector<4x!spirv.ptr<f32, CrossWorkgroup>>,
%alignment : i32,
%mask : vector<4xi1>,
%fill : vector<4xi32>) "None" {
- // expected-error @+1 {{pointer pointee type must match result vector element type}}
+ // expected-error @+1 {{'spirv.INTEL.MaskedGather' op failed to verify that pointee type of ptr_vector must match result element type}}
%0 = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
- : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>, i32,
+ : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32,
vector<4xi1>, vector<4xi32> -> vector<4xi32>
spirv.Return
}
@@ -173,13 +173,13 @@ spirv.func @masked_gather_pointee_type_mismatch(
// -----
spirv.func @masked_gather_elem_count_mismatch(
- %ptrs : !spirv.vecptr<2, !spirv.ptr<f32, CrossWorkgroup>>,
+ %ptrs : vector<2x!spirv.ptr<f32, CrossWorkgroup>>,
%alignment : i32,
%mask : vector<4xi1>,
%fill : vector<4xf32>) "None" {
- // expected-error @+1 {{ptr_vector must have the same number of elements as result}}
+ // expected-error @+1 {{'spirv.INTEL.MaskedGather' op failed to verify that pointee type of ptr_vector must match result element type}}
%0 = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
- : !spirv.vecptr<2, !spirv.ptr<f32, CrossWorkgroup>>, i32,
+ : vector<2x!spirv.ptr<f32, CrossWorkgroup>>, i32,
vector<4xi1>, vector<4xf32> -> vector<4xf32>
spirv.Return
}
@@ -187,13 +187,13 @@ spirv.func @masked_gather_elem_count_mismatch(
// -----
spirv.func @masked_gather_mask_not_bool(
- %ptrs : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>,
+ %ptrs : vector<4x!spirv.ptr<f32, CrossWorkgroup>>,
%alignment : i32,
%mask : vector<4xi8>,
%fill : vector<4xf32>) "None" {
- // expected-error @+1 {{mask must be a vector of i1}}
+ // expected-error @+1 {{operand #2 must be fixed-length vector of 1-bit signless integer values of length 2/3/4/8/16 of ranks 1, but got 'vector<4xi8>'}}
%0 = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
- : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>, i32,
+ : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32,
vector<4xi8>, vector<4xf32> -> vector<4xf32>
spirv.Return
}
@@ -205,13 +205,13 @@ spirv.func @masked_gather_mask_not_bool(
//===----------------------------------------------------------------------===//
spirv.func @masked_scatter(
- %ptrs : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>,
+ %ptrs : vector<4x!spirv.ptr<f32, CrossWorkgroup>>,
%alignment : i32,
%mask : vector<4xi1>,
%values : vector<4xf32>) "None" {
// CHECK: spirv.INTEL.MaskedScatter
spirv.INTEL.MaskedScatter %ptrs, %alignment, %mask, %values
- : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>, i32,
+ : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32,
vector<4xi1>, vector<4xf32>
spirv.Return
}
@@ -219,13 +219,13 @@ spirv.func @masked_scatter(
// -----
spirv.func @masked_scatter_pointee_mismatch(
- %ptrs : !spirv.vecptr<4, !spirv.ptr<i32, CrossWorkgroup>>,
+ %ptrs : vector<4x!spirv.ptr<i32, CrossWorkgroup>>,
%alignment : i32,
%mask : vector<4xi1>,
%values : vector<4xf32>) "None" {
- // expected-error @+1 {{pointer pointee type must match input vector element type}}
+ // expected-error @+1 {{'spirv.INTEL.MaskedScatter' op failed to verify that pointee type of ptr_vector must match input element type}}
spirv.INTEL.MaskedScatter %ptrs, %alignment, %mask, %values
- : !spirv.vecptr<4, !spirv.ptr<i32, CrossWorkgroup>>, i32,
+ : vector<4x!spirv.ptr<i32, CrossWorkgroup>>, i32,
vector<4xi1>, vector<4xf32>
spirv.Return
}
@@ -233,13 +233,13 @@ spirv.func @masked_scatter_pointee_mismatch(
// -----
spirv.func @masked_scatter_mask_count_mismatch(
- %ptrs : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>,
+ %ptrs : vector<4x!spirv.ptr<f32, CrossWorkgroup>>,
%alignment : i32,
%mask : vector<2xi1>,
%values : vector<4xf32>) "None" {
- // expected-error @+1 {{mask must have the same number of elements as input_vector}}
+ // expected-error @+1 {{'spirv.INTEL.MaskedScatter' op failed to verify that mask must be a vector of i1 matching input shape}}
spirv.INTEL.MaskedScatter %ptrs, %alignment, %mask, %values
- : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>, i32,
+ : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32,
vector<2xi1>, vector<4xf32>
spirv.Return
}
diff --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
index 5c219e0c86191..e12b70cc5c139 100644
--- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
@@ -170,7 +170,7 @@ func.func @coop_matrix_const_wrong_type() -> () {
//===----------------------------------------------------------------------===//
func.func @ccr_result_not_composite() -> () {
- // expected-error @+1 {{op result #0 must be vector of bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 or Float8E4M3 or Float8E5M2 values of length 2/3/4/8/16 of ranks 1 or any SPIR-V array type or any SPIR-V runtime array type or any SPIR-V struct type or any SPIR-V cooperative matrix type or any SPIR-V matrix type or any SPIR-V tensorArm type or any SPIR-V vector of pointer type, but got 'i32'}}
+ // expected-error @+1 {{op result #0 must be vector of bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 or Float8E4M3 or Float8E5M2 values of length 2/3/4/8/16 of ranks 1 or any SPIR-V array type or any SPIR-V runtime array type or any SPIR-V struct type or any SPIR-V cooperative matrix type or any SPIR-V matrix type or any SPIR-V tensorArm type, but got 'i32'}}
%0 = spirv.EXT.ConstantCompositeReplicate [1 : i32] : i32
return
}
diff --git a/mlir/test/Target/SPIRV/intel-ext-ops.mlir b/mlir/test/Target/SPIRV/intel-ext-ops.mlir
index 682fa2cd8320d..4296c3f1a9682 100644
--- a/mlir/test/Target/SPIRV/intel-ext-ops.mlir
+++ b/mlir/test/Target/SPIRV/intel-ext-ops.mlir
@@ -66,52 +66,52 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage, TensorF
spirv.module Physical64 OpenCL requires #spirv.vce<v1.0, [Kernel, Addresses, Linkage, MaskedGatherScatterINTEL], [SPV_INTEL_masked_gather_scatter]> {
// CHECK-LABEL: @masked_gather_f32
spirv.func @masked_gather_f32(
- %ptrs : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>,
+ %ptrs : vector<4x!spirv.ptr<f32, CrossWorkgroup>>,
%alignment : i32,
%mask : vector<4xi1>,
%fill : vector<4xf32>) "None" {
// CHECK: spirv.INTEL.MaskedGather
%0 = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
- : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>, i32,
+ : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32,
vector<4xi1>, vector<4xf32> -> vector<4xf32>
spirv.Return
}
// CHECK-LABEL: @masked_gather_i32
spirv.func @masked_gather_i32(
- %ptrs : !spirv.vecptr<4, !spirv.ptr<i32, CrossWorkgroup>>,
+ %ptrs : vector<4x!spirv.ptr<i32, CrossWorkgroup>>,
%alignment : i32,
%mask : vector<4xi1>,
%fill : vector<4xi32>) "None" {
// CHECK: spirv.INTEL.MaskedGather
%0 = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
- : !spirv.vecptr<4, !spirv.ptr<i32, CrossWorkgroup>>, i32,
+ : vector<4x!spirv.ptr<i32, CrossWorkgroup>>, i32,
vector<4xi1>, vector<4xi32> -> vector<4xi32>
spirv.Return
}
// CHECK-LABEL: @masked_scatter_f32
spirv.func @masked_scatter_f32(
- %ptrs : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>,
+ %ptrs : vector<4x!spirv.ptr<f32, CrossWorkgroup>>,
%alignment : i32,
%mask : vector<4xi1>,
%values : vector<4xf32>) "None" {
// CHECK: spirv.INTEL.MaskedScatter
spirv.INTEL.MaskedScatter %ptrs, %alignment, %mask, %values
- : !spirv.vecptr<4, !spirv.ptr<f32, CrossWorkgroup>>, i32,
+ : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32,
vector<4xi1>, vector<4xf32>
spirv.Return
}
// CHECK-LABEL: @masked_scatter_i32
spirv.func @masked_scatter_i32(
- %ptrs : !spirv.vecptr<4, !spirv.ptr<i32, CrossWorkgroup>>,
+ %ptrs : vector<4x!spirv.ptr<i32, CrossWorkgroup>>,
%alignment : i32,
%mask : vector<4xi1>,
%values : vector<4xi32>) "None" {
// CHECK: spirv.INTEL.MaskedScatter
spirv.INTEL.MaskedScatter %ptrs, %alignment, %mask, %values
- : !spirv.vecptr<4, !spirv.ptr<i32, CrossWorkgroup>>, i32,
+ : vector<4x!spirv.ptr<i32, CrossWorkgroup>>, i32,
vector<4xi1>, vector<4xi32>
spirv.Return
}
>From 7739f8bd45803d17f541d10a9ce0c346c0b7725b Mon Sep 17 00:00:00 2001
From: Arseniy Obolenskiy <arseniy.obolenskiy at amd.com>
Date: Mon, 13 Apr 2026 10:16:11 +0200
Subject: [PATCH 3/3] Address review comments
---
.../mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td | 10 +++++-----
mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir | 14 ++++++++++++++
2 files changed, 19 insertions(+), 5 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
index 1502187683c1f..57a57a18cca33 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
@@ -303,12 +303,12 @@ def SPIRV_INTELMaskedGatherOp : SPIRV_IntelVendorOp<"MaskedGather",
let arguments = (ins
SPIRV_VectorOf<SPIRV_AnyPtr>:$ptr_vector,
SPIRV_Int32:$alignment,
- SPIRV_VectorOf<I1>:$mask,
- SPIRV_Vector:$fill_empty
+ SPIRV_VectorOf<SPIRV_Bool>:$mask,
+ SPIRV_VectorOf<SPIRV_Numerical>:$fill_empty
);
let results = (outs
- SPIRV_Vector:$result
+ SPIRV_VectorOf<SPIRV_Numerical>:$result
);
let assemblyFormat = [{
@@ -369,8 +369,8 @@ def SPIRV_INTELMaskedScatterOp : SPIRV_IntelVendorOp<"MaskedScatter",
let arguments = (ins
SPIRV_VectorOf<SPIRV_AnyPtr>:$ptr_vector,
SPIRV_Int32:$alignment,
- SPIRV_VectorOf<I1>:$mask,
- SPIRV_Vector:$input_vector
+ SPIRV_VectorOf<SPIRV_Bool>:$mask,
+ SPIRV_VectorOf<SPIRV_Numerical>:$input_vector
);
let results = (outs);
diff --git a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
index 06f33a235c660..20d5b112b1652 100644
--- a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
@@ -200,6 +200,20 @@ spirv.func @masked_gather_mask_not_bool(
// -----
+spirv.func @masked_gather_mask_count_mismatch(
+ %ptrs : vector<4x!spirv.ptr<f32, CrossWorkgroup>>,
+ %alignment : i32,
+ %mask : vector<2xi1>,
+ %fill : vector<4xf32>) "None" {
+ // expected-error @+1 {{'spirv.INTEL.MaskedGather' op failed to verify that mask must be a vector of i1 matching result shape}}
+ %0 = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
+ : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32,
+ vector<2xi1>, vector<4xf32> -> vector<4xf32>
+ spirv.Return
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.INTEL.MaskedScatter
//===----------------------------------------------------------------------===//
More information about the Mlir-commits mailing list