[Mlir-commits] [mlir] c9f175b - [mlir][SPIR-V] Add support for SPV_INTEL_masked_gather_scatter extension (#189099)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Apr 13 13:59:38 PDT 2026
Author: Arseniy Obolenskiy
Date: 2026-04-13T22:59:34+02:00
New Revision: c9f175bed493ccbe6c85a5c4e4fb7f1b800123a2
URL: https://github.com/llvm/llvm-project/commit/c9f175bed493ccbe6c85a5c4e4fb7f1b800123a2
DIFF: https://github.com/llvm/llvm-project/commit/c9f175bed493ccbe6c85a5c4e4fb7f1b800123a2.diff
LOG: [mlir][SPIR-V] Add support for SPV_INTEL_masked_gather_scatter extension (#189099)
Add MaskedGather/MaskedScatter ops and vector-of-pointer type support
(letting !spirv.ptr be used as a vector element type) for the
SPV_INTEL_masked_gather_scatter extension implemented in #185418.
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
mlir/test/Target/SPIRV/intel-ext-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index c4d123e0f539c..297dde3a67b2a 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -407,6 +407,7 @@ def SPV_INTEL_split_barrier : I32EnumAttrCase<"SPV_INTEL_sp
def SPV_INTEL_bfloat16_conversion : I32EnumAttrCase<"SPV_INTEL_bfloat16_conversion", 4031>;
def SPV_INTEL_cache_controls : I32EnumAttrCase<"SPV_INTEL_cache_controls", 4032>;
def SPV_INTEL_tensor_float32_conversion : I32EnumAttrCase<"SPV_INTEL_tensor_float32_conversion", 4033>;
+def SPV_INTEL_masked_gather_scatter : I32EnumAttrCase<"SPV_INTEL_masked_gather_scatter", 4034>;
def SPV_NV_compute_shader_derivatives : I32EnumAttrCase<"SPV_NV_compute_shader_derivatives", 5000>;
def SPV_NV_cooperative_matrix : I32EnumAttrCase<"SPV_NV_cooperative_matrix", 5001>;
@@ -472,6 +473,7 @@ def SPIRV_ExtensionAttr :
SPV_INTEL_memory_access_aliasing, SPV_INTEL_split_barrier,
SPV_INTEL_bfloat16_conversion, SPV_INTEL_cache_controls,
SPV_INTEL_tensor_float32_conversion,
+ SPV_INTEL_masked_gather_scatter,
SPV_NV_compute_shader_derivatives, SPV_NV_cooperative_matrix,
SPV_NV_fragment_shader_barycentric, SPV_NV_geometry_shader_passthrough,
SPV_NV_ray_tracing, SPV_NV_sample_mask_override_coverage,
@@ -1481,6 +1483,12 @@ def SPIRV_C_TensorFloat32RoundingINTEL : I32EnumAttrCase<"
];
}
+def SPIRV_C_MaskedGatherScatterINTEL : I32EnumAttrCase<"MaskedGatherScatterINTEL", 6427> {
+ list<Availability> availability = [
+ Extension<[SPV_INTEL_masked_gather_scatter]>
+ ];
+}
+
def SPIRV_C_CacheControlsINTEL : I32EnumAttrCase<"CacheControlsINTEL", 6441> {
list<Availability> availability = [
Extension<[SPV_INTEL_cache_controls]>
@@ -1590,7 +1598,8 @@ def SPIRV_CapabilityAttr :
SPIRV_C_ShaderStereoViewNV, SPIRV_C_Bfloat16ConversionINTEL,
SPIRV_C_CacheControlsINTEL, SPIRV_C_BFloat16TypeKHR,
SPIRV_C_BFloat16DotProductKHR, SPIRV_C_BFloat16CooperativeMatrixKHR,
- SPIRV_C_TensorFloat32RoundingINTEL, SPIRV_C_Float8EXT
+ SPIRV_C_TensorFloat32RoundingINTEL, SPIRV_C_MaskedGatherScatterINTEL,
+ SPIRV_C_Float8EXT
]>;
def SPIRV_AM_Logical : I32EnumAttrCase<"Logical", 0>;
@@ -4656,6 +4665,8 @@ def SPIRV_OC_OpControlBarrierWaitINTEL : I32EnumAttrCase<"OpControlBarrie
def SPIRV_OC_OpGroupIMulKHR : I32EnumAttrCase<"OpGroupIMulKHR", 6401>;
def SPIRV_OC_OpGroupFMulKHR : I32EnumAttrCase<"OpGroupFMulKHR", 6402>;
def SPIRV_OC_OpRoundFToTF32INTEL : I32EnumAttrCase<"OpRoundFToTF32INTEL", 6426>;
+def SPIRV_OC_OpMaskedGatherINTEL : I32EnumAttrCase<"OpMaskedGatherINTEL", 6428>;
+def SPIRV_OC_OpMaskedScatterINTEL : I32EnumAttrCase<"OpMaskedScatterINTEL", 6429>;
def SPIRV_OpcodeAttr :
SPIRV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", "opcode", [
@@ -4770,7 +4781,8 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpConvertFToBF16INTEL, SPIRV_OC_OpConvertBF16ToFINTEL,
SPIRV_OC_OpControlBarrierArriveINTEL, SPIRV_OC_OpControlBarrierWaitINTEL,
SPIRV_OC_OpGroupIMulKHR, SPIRV_OC_OpGroupFMulKHR,
- SPIRV_OC_OpRoundFToTF32INTEL
+ SPIRV_OC_OpRoundFToTF32INTEL,
+ SPIRV_OC_OpMaskedGatherINTEL, SPIRV_OC_OpMaskedScatterINTEL
]>;
// End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
index 2a7fa534cc3dc..9d4d6a0785f1a 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
@@ -250,6 +250,140 @@ def SPIRV_INTELControlBarrierWaitOp
}
+// -----
+
+def SPIRV_INTELMaskedGatherOp : SPIRV_IntelVendorOp<"MaskedGather",
+ [AllTypesMatch<["fill_empty", "result"]>,
+ TypesMatchWith<"pointee type of ptr_vector must match result element type",
+ "ptr_vector", "result",
+ "VectorType::get("
+ "::llvm::cast<VectorType>($_self).getShape(), "
+ "::llvm::cast<spirv::PointerType>("
+ "::llvm::cast<VectorType>($_self).getElementType())"
+ ".getPointeeType())">,
+ TypesMatchWith<"mask must be a vector of i1 matching result shape",
+ "result", "mask",
+ "getMatchingBoolType($_self)">]> {
+ let summary = "Gather values from memory using a vector of pointers and a mask";
+
+ let description = [{
+ Reads values from a vector of pointers gathering them into a result
+ vector. Lanes where the mask is false receive the corresponding
+ FillEmpty value.
+
+ Result Type must be a vector of numerical type.
+
+ PtrVector must be a vector of pointers to the scalar element type of
+ Result Type. It must have the same number of components as Result Type.
+
+ Alignment is the known minimum alignment in bytes of each pointer in
+ PtrVector.
+
+ Mask must be a vector of boolean type with the same number of components
+ as Result Type.
+
+ FillEmpty must have the same type as Result Type.
+
+ #### Example:
+
+ ```mlir
+ %result = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
+ : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32,
+ vector<4xi1>, vector<4xf32> -> vector<4xf32>
+ ```
+ }];
+
+ let availability = [
+ MinVersion<SPIRV_V_1_0>,
+ MaxVersion<SPIRV_V_1_6>,
+ Extension<[SPV_INTEL_masked_gather_scatter]>,
+ Capability<[SPIRV_C_MaskedGatherScatterINTEL]>
+ ];
+
+ let arguments = (ins
+ SPIRV_VectorOf<SPIRV_AnyPtr>:$ptr_vector,
+ SPIRV_Int32:$alignment,
+ SPIRV_VectorOf<SPIRV_Bool>:$mask,
+ SPIRV_VectorOf<SPIRV_Numerical>:$fill_empty
+ );
+
+ let results = (outs
+ SPIRV_VectorOf<SPIRV_Numerical>:$result
+ );
+
+ let assemblyFormat = [{
+ $ptr_vector `,` $alignment `,` $mask `,` $fill_empty attr-dict `:`
+ type($ptr_vector) `,` type($alignment) `,`
+ type($mask) `,` type($fill_empty) `->` type($result)
+ }];
+
+ let hasVerifier = 0;
+}
+
+// -----
+
+def SPIRV_INTELMaskedScatterOp : SPIRV_IntelVendorOp<"MaskedScatter",
+ [TypesMatchWith<"pointee type of ptr_vector must match input element type",
+ "ptr_vector", "input_vector",
+ "VectorType::get("
+ "::llvm::cast<VectorType>($_self).getShape(), "
+ "::llvm::cast<spirv::PointerType>("
+ "::llvm::cast<VectorType>($_self).getElementType())"
+ ".getPointeeType())">,
+ TypesMatchWith<"mask must be a vector of i1 matching input shape",
+ "input_vector", "mask",
+ "getMatchingBoolType($_self)">]> {
+ let summary = "Scatter values to memory using a vector of pointers and a mask";
+
+ let description = [{
+ Writes values from a vector into memory locations pointed to by a
+ vector of pointers. Only lanes where the mask is true are written.
+
+ PtrVector must be a vector of pointers to the scalar element type of
+ InputVector. It must have the same number of components as InputVector.
+
+ Alignment is the known minimum alignment in bytes of each pointer in
+ PtrVector.
+
+ Mask must be a vector of boolean type with the same number of components
+ as InputVector.
+
+ InputVector is the vector of values to scatter into memory.
+
+ #### Example:
+
+ ```mlir
+ spirv.INTEL.MaskedScatter %ptrs, %alignment, %mask, %values
+ : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32,
+ vector<4xi1>, vector<4xf32>
+ ```
+ }];
+
+ let availability = [
+ MinVersion<SPIRV_V_1_0>,
+ MaxVersion<SPIRV_V_1_6>,
+ Extension<[SPV_INTEL_masked_gather_scatter]>,
+ Capability<[SPIRV_C_MaskedGatherScatterINTEL]>
+ ];
+
+ let arguments = (ins
+ SPIRV_VectorOf<SPIRV_AnyPtr>:$ptr_vector,
+ SPIRV_Int32:$alignment,
+ SPIRV_VectorOf<SPIRV_Bool>:$mask,
+ SPIRV_VectorOf<SPIRV_Numerical>:$input_vector
+ );
+
+ let results = (outs);
+
+ let assemblyFormat = [{
+ $ptr_vector `,` $alignment `,` $mask `,` $input_vector attr-dict `:`
+ type($ptr_vector) `,` type($alignment) `,`
+ type($mask) `,` type($input_vector)
+ }];
+
+ let hasVerifier = 0;
+}
+
// -----
#endif // MLIR_DIALECT_SPIRV_IR_INTEL_EXT_OPS
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 9864f644aa93e..b7890ff101b2b 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -14,6 +14,7 @@
#define MLIR_DIALECT_SPIRV_IR_SPIRVTYPES_H_
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Location.h"
@@ -172,8 +173,9 @@ class ImageType
};
// SPIR-V pointer type
-class PointerType : public Type::TypeBase<PointerType, SPIRVType,
- detail::PointerTypeStorage> {
+class PointerType
+ : public Type::TypeBase<PointerType, SPIRVType, detail::PointerTypeStorage,
+ VectorElementTypeInterface::Trait> {
public:
using Base::Base;
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index c4dd4cea778d7..0853c5aa59f92 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -190,7 +190,8 @@ bool CompositeType::classof(Type type) {
bool CompositeType::isValid(VectorType type) {
return type.getRank() == 1 &&
llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) &&
- isa<ScalarType>(type.getElementType());
+ (isa<ScalarType>(type.getElementType()) ||
+ isa<PointerType>(type.getElementType()));
}
Type CompositeType::getElementType(unsigned index) const {
diff --git a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
index d124c02231161..3579fa76edcb3 100644
--- a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
@@ -126,6 +126,140 @@ spirv.func @split_barrier() "None" {
// spirv.INTEL.CacheControls
//===----------------------------------------------------------------------===//
+//===----------------------------------------------------------------------===//
+// spirv.INTEL.MaskedGather
+//===----------------------------------------------------------------------===//
+
+spirv.func @masked_gather(
+ %ptrs : vector<4x!spirv.ptr<f32, CrossWorkgroup>>,
+ %alignment : i32,
+ %mask : vector<4xi1>,
+ %fill : vector<4xf32>) "None" {
+ // CHECK: {{%.*}} = spirv.INTEL.MaskedGather {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}} : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32, vector<4xi1>, vector<4xf32> -> vector<4xf32>
+ %0 = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
+ : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32,
+ vector<4xi1>, vector<4xf32> -> vector<4xf32>
+ spirv.Return
+}
+
+// -----
+
+spirv.func @masked_gather_i32(
+ %ptrs : vector<4x!spirv.ptr<i32, CrossWorkgroup>>,
+ %alignment : i32,
+ %mask : vector<4xi1>,
+ %fill : vector<4xi32>) "None" {
+ // CHECK: {{%.*}} = spirv.INTEL.MaskedGather {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}} : vector<4x!spirv.ptr<i32, CrossWorkgroup>>, i32, vector<4xi1>, vector<4xi32> -> vector<4xi32>
+ %0 = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
+ : vector<4x!spirv.ptr<i32, CrossWorkgroup>>, i32,
+ vector<4xi1>, vector<4xi32> -> vector<4xi32>
+ spirv.Return
+}
+
+// -----
+
+spirv.func @masked_gather_pointee_type_mismatch(
+ %ptrs : vector<4x!spirv.ptr<f32, CrossWorkgroup>>,
+ %alignment : i32,
+ %mask : vector<4xi1>,
+ %fill : vector<4xi32>) "None" {
+ // expected-error @+1 {{'spirv.INTEL.MaskedGather' op failed to verify that pointee type of ptr_vector must match result element type}}
+ %0 = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
+ : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32,
+ vector<4xi1>, vector<4xi32> -> vector<4xi32>
+ spirv.Return
+}
+
+// -----
+
+spirv.func @masked_gather_elem_count_mismatch(
+ %ptrs : vector<2x!spirv.ptr<f32, CrossWorkgroup>>,
+ %alignment : i32,
+ %mask : vector<4xi1>,
+ %fill : vector<4xf32>) "None" {
+ // expected-error @+1 {{'spirv.INTEL.MaskedGather' op failed to verify that pointee type of ptr_vector must match result element type}}
+ %0 = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
+ : vector<2x!spirv.ptr<f32, CrossWorkgroup>>, i32,
+ vector<4xi1>, vector<4xf32> -> vector<4xf32>
+ spirv.Return
+}
+
+// -----
+
+spirv.func @masked_gather_mask_not_bool(
+ %ptrs : vector<4x!spirv.ptr<f32, CrossWorkgroup>>,
+ %alignment : i32,
+ %mask : vector<4xi8>,
+ %fill : vector<4xf32>) "None" {
+ // expected-error @+1 {{operand #2 must be fixed-length vector of bool values of length 2/3/4/8/16 of ranks 1, but got 'vector<4xi8>'}}
+ %0 = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
+ : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32,
+ vector<4xi8>, vector<4xf32> -> vector<4xf32>
+ spirv.Return
+}
+
+// -----
+
+spirv.func @masked_gather_mask_count_mismatch(
+ %ptrs : vector<4x!spirv.ptr<f32, CrossWorkgroup>>,
+ %alignment : i32,
+ %mask : vector<2xi1>,
+ %fill : vector<4xf32>) "None" {
+ // expected-error @+1 {{'spirv.INTEL.MaskedGather' op failed to verify that mask must be a vector of i1 matching result shape}}
+ %0 = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
+ : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32,
+ vector<2xi1>, vector<4xf32> -> vector<4xf32>
+ spirv.Return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.INTEL.MaskedScatter
+//===----------------------------------------------------------------------===//
+
+spirv.func @masked_scatter(
+ %ptrs : vector<4x!spirv.ptr<f32, CrossWorkgroup>>,
+ %alignment : i32,
+ %mask : vector<4xi1>,
+ %values : vector<4xf32>) "None" {
+ // CHECK: spirv.INTEL.MaskedScatter {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}} : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32, vector<4xi1>, vector<4xf32>
+ spirv.INTEL.MaskedScatter %ptrs, %alignment, %mask, %values
+ : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32,
+ vector<4xi1>, vector<4xf32>
+ spirv.Return
+}
+
+// -----
+
+spirv.func @masked_scatter_pointee_mismatch(
+ %ptrs : vector<4x!spirv.ptr<i32, CrossWorkgroup>>,
+ %alignment : i32,
+ %mask : vector<4xi1>,
+ %values : vector<4xf32>) "None" {
+ // expected-error @+1 {{'spirv.INTEL.MaskedScatter' op failed to verify that pointee type of ptr_vector must match input element type}}
+ spirv.INTEL.MaskedScatter %ptrs, %alignment, %mask, %values
+ : vector<4x!spirv.ptr<i32, CrossWorkgroup>>, i32,
+ vector<4xi1>, vector<4xf32>
+ spirv.Return
+}
+
+// -----
+
+spirv.func @masked_scatter_mask_count_mismatch(
+ %ptrs : vector<4x!spirv.ptr<f32, CrossWorkgroup>>,
+ %alignment : i32,
+ %mask : vector<2xi1>,
+ %values : vector<4xf32>) "None" {
+ // expected-error @+1 {{'spirv.INTEL.MaskedScatter' op failed to verify that mask must be a vector of i1 matching input shape}}
+ spirv.INTEL.MaskedScatter %ptrs, %alignment, %mask, %values
+ : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32,
+ vector<2xi1>, vector<4xf32>
+ spirv.Return
+}
+
+// -----
+
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [CacheControlsINTEL], [SPV_INTEL_cache_controls]> {
spirv.func @foo() "None" {
// CHECK: spirv.Variable {cache_control_load_intel = [#spirv.cache_control_load_intel<cache_level = 0, load_cache_control = Uncached>, #spirv.cache_control_load_intel<cache_level = 1, load_cache_control = Cached>, #spirv.cache_control_load_intel<cache_level = 2, load_cache_control = InvalidateAfterR>]} : !spirv.ptr<f32, Function>
diff --git a/mlir/test/Target/SPIRV/intel-ext-ops.mlir b/mlir/test/Target/SPIRV/intel-ext-ops.mlir
index 118bed8be7095..b6c42747f5b68 100644
--- a/mlir/test/Target/SPIRV/intel-ext-ops.mlir
+++ b/mlir/test/Target/SPIRV/intel-ext-ops.mlir
@@ -59,6 +59,66 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage, TensorF
// -----
+//===----------------------------------------------------------------------===//
+// spirv.INTEL.MaskedGather / MaskedScatter
+//===----------------------------------------------------------------------===//
+
+spirv.module Physical64 OpenCL requires #spirv.vce<v1.0, [Kernel, Addresses, Linkage, MaskedGatherScatterINTEL], [SPV_INTEL_masked_gather_scatter]> {
+ // CHECK-LABEL: @masked_gather_f32
+ spirv.func @masked_gather_f32(
+ %ptrs : vector<4x!spirv.ptr<f32, CrossWorkgroup>>,
+ %alignment : i32,
+ %mask : vector<4xi1>,
+ %fill : vector<4xf32>) "None" {
+ // CHECK: {{%.*}} = spirv.INTEL.MaskedGather {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}} : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32, vector<4xi1>, vector<4xf32> -> vector<4xf32>
+ %0 = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
+ : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32,
+ vector<4xi1>, vector<4xf32> -> vector<4xf32>
+ spirv.Return
+ }
+
+ // CHECK-LABEL: @masked_gather_i32
+ spirv.func @masked_gather_i32(
+ %ptrs : vector<4x!spirv.ptr<i32, CrossWorkgroup>>,
+ %alignment : i32,
+ %mask : vector<4xi1>,
+ %fill : vector<4xi32>) "None" {
+ // CHECK: {{%.*}} = spirv.INTEL.MaskedGather {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}} : vector<4x!spirv.ptr<i32, CrossWorkgroup>>, i32, vector<4xi1>, vector<4xi32> -> vector<4xi32>
+ %0 = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
+ : vector<4x!spirv.ptr<i32, CrossWorkgroup>>, i32,
+ vector<4xi1>, vector<4xi32> -> vector<4xi32>
+ spirv.Return
+ }
+
+ // CHECK-LABEL: @masked_scatter_f32
+ spirv.func @masked_scatter_f32(
+ %ptrs : vector<4x!spirv.ptr<f32, CrossWorkgroup>>,
+ %alignment : i32,
+ %mask : vector<4xi1>,
+ %values : vector<4xf32>) "None" {
+ // CHECK: spirv.INTEL.MaskedScatter {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}} : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32, vector<4xi1>, vector<4xf32>
+ spirv.INTEL.MaskedScatter %ptrs, %alignment, %mask, %values
+ : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32,
+ vector<4xi1>, vector<4xf32>
+ spirv.Return
+ }
+
+ // CHECK-LABEL: @masked_scatter_i32
+ spirv.func @masked_scatter_i32(
+ %ptrs : vector<4x!spirv.ptr<i32, CrossWorkgroup>>,
+ %alignment : i32,
+ %mask : vector<4xi1>,
+ %values : vector<4xi32>) "None" {
+ // CHECK: spirv.INTEL.MaskedScatter {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}} : vector<4x!spirv.ptr<i32, CrossWorkgroup>>, i32, vector<4xi1>, vector<4xi32>
+ spirv.INTEL.MaskedScatter %ptrs, %alignment, %mask, %values
+ : vector<4x!spirv.ptr<i32, CrossWorkgroup>>, i32,
+ vector<4xi1>, vector<4xi32>
+ spirv.Return
+ }
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.INTEL.SplitBarrier
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list