[Mlir-commits] [mlir] c9f175b - [mlir][SPIR-V] Add support for SPV_INTEL_masked_gather_scatter extension (#189099)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Apr 13 13:59:38 PDT 2026


Author: Arseniy Obolenskiy
Date: 2026-04-13T22:59:34+02:00
New Revision: c9f175bed493ccbe6c85a5c4e4fb7f1b800123a2

URL: https://github.com/llvm/llvm-project/commit/c9f175bed493ccbe6c85a5c4e4fb7f1b800123a2
DIFF: https://github.com/llvm/llvm-project/commit/c9f175bed493ccbe6c85a5c4e4fb7f1b800123a2.diff

LOG: [mlir][SPIR-V] Add support for SPV_INTEL_masked_gather_scatter extension (#189099)

Add MaskedGather/MaskedScatter ops and VectorOfPointerType for the
SPV_INTEL_masked_gather_scatter extension, implemented in #185418.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
    mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
    mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
    mlir/test/Target/SPIRV/intel-ext-ops.mlir

Removed: 
    


################################################################################
Review note: registers the SPV_INTEL_masked_gather_scatter extension (enum
value 4034), the MaskedGatherScatterINTEL capability (6427, available only
via that extension), and the OpMaskedGatherINTEL / OpMaskedScatterINTEL
opcodes (6428 / 6429). The opcode section is marked "Generated from SPIR-V
spec; DO NOT MODIFY!", so these values should track the SPIR-V grammar.

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index c4d123e0f539c..297dde3a67b2a 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -407,6 +407,7 @@ def SPV_INTEL_split_barrier                      : I32EnumAttrCase<"SPV_INTEL_sp
 def SPV_INTEL_bfloat16_conversion                : I32EnumAttrCase<"SPV_INTEL_bfloat16_conversion", 4031>;
 def SPV_INTEL_cache_controls                     : I32EnumAttrCase<"SPV_INTEL_cache_controls", 4032>;
 def SPV_INTEL_tensor_float32_conversion          : I32EnumAttrCase<"SPV_INTEL_tensor_float32_conversion", 4033>;
+def SPV_INTEL_masked_gather_scatter              : I32EnumAttrCase<"SPV_INTEL_masked_gather_scatter", 4034>;
 
 def SPV_NV_compute_shader_derivatives    : I32EnumAttrCase<"SPV_NV_compute_shader_derivatives", 5000>;
 def SPV_NV_cooperative_matrix            : I32EnumAttrCase<"SPV_NV_cooperative_matrix", 5001>;
@@ -472,6 +473,7 @@ def SPIRV_ExtensionAttr :
       SPV_INTEL_memory_access_aliasing, SPV_INTEL_split_barrier,
       SPV_INTEL_bfloat16_conversion, SPV_INTEL_cache_controls,
       SPV_INTEL_tensor_float32_conversion,
+      SPV_INTEL_masked_gather_scatter,
       SPV_NV_compute_shader_derivatives, SPV_NV_cooperative_matrix,
       SPV_NV_fragment_shader_barycentric, SPV_NV_geometry_shader_passthrough,
       SPV_NV_ray_tracing, SPV_NV_sample_mask_override_coverage,
@@ -1481,6 +1483,12 @@ def SPIRV_C_TensorFloat32RoundingINTEL                       : I32EnumAttrCase<"
   ];
 }
 
+def SPIRV_C_MaskedGatherScatterINTEL                         : I32EnumAttrCase<"MaskedGatherScatterINTEL", 6427> {
+  list<Availability> availability = [
+    Extension<[SPV_INTEL_masked_gather_scatter]>
+  ];
+}
+
 def SPIRV_C_CacheControlsINTEL : I32EnumAttrCase<"CacheControlsINTEL", 6441> {
   list<Availability> availability = [
     Extension<[SPV_INTEL_cache_controls]>
@@ -1590,7 +1598,8 @@ def SPIRV_CapabilityAttr :
       SPIRV_C_ShaderStereoViewNV, SPIRV_C_Bfloat16ConversionINTEL,
       SPIRV_C_CacheControlsINTEL, SPIRV_C_BFloat16TypeKHR,
       SPIRV_C_BFloat16DotProductKHR, SPIRV_C_BFloat16CooperativeMatrixKHR,
-      SPIRV_C_TensorFloat32RoundingINTEL, SPIRV_C_Float8EXT
+      SPIRV_C_TensorFloat32RoundingINTEL, SPIRV_C_MaskedGatherScatterINTEL,
+      SPIRV_C_Float8EXT
     ]>;
 
 def SPIRV_AM_Logical                 : I32EnumAttrCase<"Logical", 0>;
@@ -4656,6 +4665,8 @@ def SPIRV_OC_OpControlBarrierWaitINTEL        : I32EnumAttrCase<"OpControlBarrie
 def SPIRV_OC_OpGroupIMulKHR                   : I32EnumAttrCase<"OpGroupIMulKHR", 6401>;
 def SPIRV_OC_OpGroupFMulKHR                   : I32EnumAttrCase<"OpGroupFMulKHR", 6402>;
 def SPIRV_OC_OpRoundFToTF32INTEL              : I32EnumAttrCase<"OpRoundFToTF32INTEL", 6426>;
+def SPIRV_OC_OpMaskedGatherINTEL              : I32EnumAttrCase<"OpMaskedGatherINTEL", 6428>;
+def SPIRV_OC_OpMaskedScatterINTEL             : I32EnumAttrCase<"OpMaskedScatterINTEL", 6429>;
 
 def SPIRV_OpcodeAttr :
     SPIRV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", "opcode", [
@@ -4770,7 +4781,8 @@ def SPIRV_OpcodeAttr :
       SPIRV_OC_OpConvertFToBF16INTEL, SPIRV_OC_OpConvertBF16ToFINTEL,
       SPIRV_OC_OpControlBarrierArriveINTEL, SPIRV_OC_OpControlBarrierWaitINTEL,
       SPIRV_OC_OpGroupIMulKHR, SPIRV_OC_OpGroupFMulKHR,
-      SPIRV_OC_OpRoundFToTF32INTEL
+      SPIRV_OC_OpRoundFToTF32INTEL,
+      SPIRV_OC_OpMaskedGatherINTEL, SPIRV_OC_OpMaskedScatterINTEL
     ]>;
 
 // End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
index 2a7fa534cc3dc..9d4d6a0785f1a 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
@@ -250,6 +250,140 @@ def SPIRV_INTELControlBarrierWaitOp
 }
 
 
+// -----
+
+def SPIRV_INTELMaskedGatherOp : SPIRV_IntelVendorOp<"MaskedGather",
+    [AllTypesMatch<["fill_empty", "result"]>,
+     TypesMatchWith<"pointee type of ptr_vector must match result element type",
+                    "ptr_vector", "result",
+                    "VectorType::get("
+                      "::llvm::cast<VectorType>($_self).getShape(), "
+                      "::llvm::cast<spirv::PointerType>("
+                        "::llvm::cast<VectorType>($_self).getElementType())"
+                      ".getPointeeType())">,
+     TypesMatchWith<"mask must be a vector of i1 matching result shape",
+                    "result", "mask",
+                    "getMatchingBoolType($_self)">]> {
+  let summary = "Gather values from memory using a vector of pointers and a mask";
+
+  let description = [{
+    Reads values through a vector of pointers, gathering them into a
+    result vector. Lanes where the mask is false receive the corresponding
+    FillEmpty value.
+
+    Result Type must be a vector of numerical type.
+
+    PtrVector must be a vector of pointers to the scalar element type of
+    Result Type. It must have the same number of components as Result Type.
+
+    Alignment is the known minimum alignment in bytes of each pointer in
+    PtrVector.
+
+    Mask must be a vector of boolean type with the same number of components
+    as Result Type.
+
+    FillEmpty must have the same type as Result Type.
+
+    #### Example:
+
+    ```mlir
+    %result = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
+              : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32,
+                vector<4xi1>, vector<4xf32> -> vector<4xf32>
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPIRV_V_1_0>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[SPV_INTEL_masked_gather_scatter]>,
+    Capability<[SPIRV_C_MaskedGatherScatterINTEL]>
+  ];
+
+  let arguments = (ins
+    SPIRV_VectorOf<SPIRV_AnyPtr>:$ptr_vector,
+    SPIRV_Int32:$alignment,
+    SPIRV_VectorOf<SPIRV_Bool>:$mask,
+    SPIRV_VectorOf<SPIRV_Numerical>:$fill_empty
+  );
+
+  let results = (outs
+    SPIRV_VectorOf<SPIRV_Numerical>:$result
+  );
+
+  let assemblyFormat = [{
+    $ptr_vector `,` $alignment `,` $mask `,` $fill_empty attr-dict `:`
+      type($ptr_vector) `,` type($alignment) `,`
+      type($mask) `,` type($fill_empty) `->` type($result)
+  }];
+
+  let hasVerifier = 0; // No C++ verifier; ODS traits/constraints check it.
+}
+
+// -----
+
+def SPIRV_INTELMaskedScatterOp : SPIRV_IntelVendorOp<"MaskedScatter",
+    [TypesMatchWith<"pointee type of ptr_vector must match input element type",
+                    "ptr_vector", "input_vector",
+                    "VectorType::get("
+                      "::llvm::cast<VectorType>($_self).getShape(), "
+                      "::llvm::cast<spirv::PointerType>("
+                        "::llvm::cast<VectorType>($_self).getElementType())"
+                      ".getPointeeType())">,
+     TypesMatchWith<"mask must be a vector of i1 matching input shape",
+                    "input_vector", "mask",
+                    "getMatchingBoolType($_self)">]> {
+  let summary = "Scatter values to memory using a vector of pointers and a mask";
+
+  let description = [{
+    Writes values from a vector into memory locations pointed to by a
+    vector of pointers. Only lanes where the mask is true are written.
+
+    PtrVector must be a vector of pointers to the scalar element type of
+    InputVector. It must have the same number of components as InputVector.
+
+    Alignment is the known minimum alignment in bytes of each pointer in
+    PtrVector.
+
+    Mask must be a vector of boolean type with the same number of components
+    as InputVector.
+
+    InputVector is the vector of values to scatter into memory.
+
+    #### Example:
+
+    ```mlir
+    spirv.INTEL.MaskedScatter %ptrs, %alignment, %mask, %values
+              : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32,
+                vector<4xi1>, vector<4xf32>
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPIRV_V_1_0>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[SPV_INTEL_masked_gather_scatter]>,
+    Capability<[SPIRV_C_MaskedGatherScatterINTEL]>
+  ];
+
+  let arguments = (ins
+    SPIRV_VectorOf<SPIRV_AnyPtr>:$ptr_vector,
+    SPIRV_Int32:$alignment,
+    SPIRV_VectorOf<SPIRV_Bool>:$mask,
+    SPIRV_VectorOf<SPIRV_Numerical>:$input_vector
+  );
+
+  let results = (outs); // No results: the op only writes to memory.
+
+  let assemblyFormat = [{
+    $ptr_vector `,` $alignment `,` $mask `,` $input_vector attr-dict `:`
+      type($ptr_vector) `,` type($alignment) `,`
+      type($mask) `,` type($input_vector)
+  }];
+
+  let hasVerifier = 0; // No C++ verifier; ODS traits/constraints check it.
+}
+
 // -----
 
 #endif // MLIR_DIALECT_SPIRV_IR_INTEL_EXT_OPS

Review note: PointerType gains VectorElementTypeInterface::Trait, and the
BuiltinTypeInterfaces.h include is added alongside it. Presumably this lets
!spirv.ptr be used as a builtin VectorType element type for the masked
gather/scatter operands; confirm against VectorType's element-type check.

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 9864f644aa93e..b7890ff101b2b 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -14,6 +14,7 @@
 #define MLIR_DIALECT_SPIRV_IR_SPIRVTYPES_H_
 
 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Location.h"
@@ -172,8 +173,9 @@ class ImageType
 };
 
 // SPIR-V pointer type
-class PointerType : public Type::TypeBase<PointerType, SPIRVType,
-                                          detail::PointerTypeStorage> {
+class PointerType
+    : public Type::TypeBase<PointerType, SPIRVType, detail::PointerTypeStorage,
+                            VectorElementTypeInterface::Trait> {
 public:
   using Base::Base;
 

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index c4dd4cea778d7..0853c5aa59f92 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -190,7 +190,8 @@ bool CompositeType::classof(Type type) {
 bool CompositeType::isValid(VectorType type) {
   return type.getRank() == 1 &&
          llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) &&
-         isa<ScalarType>(type.getElementType());
+         // Pointer elements are accepted for SPV_INTEL_masked_gather_scatter.
+         isa<ScalarType, PointerType>(type.getElementType());
 }
 
 Type CompositeType::getElementType(unsigned index) const {

diff  --git a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
index d124c02231161..3579fa76edcb3 100644
--- a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
@@ -126,6 +126,140 @@ spirv.func @split_barrier() "None" {
-// spirv.INTEL.CacheControls
-//===----------------------------------------------------------------------===//
-
+// spirv.INTEL.MaskedGather
+//===----------------------------------------------------------------------===//
+
+spirv.func @masked_gather(
+    %ptrs : vector<4x!spirv.ptr<f32, CrossWorkgroup>>,
+    %alignment : i32,
+    %mask : vector<4xi1>,
+    %fill : vector<4xf32>) "None" {
+  // CHECK: {{%.*}} = spirv.INTEL.MaskedGather {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}} : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32, vector<4xi1>, vector<4xf32> -> vector<4xf32>
+  %0 = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
+       : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32,
+         vector<4xi1>, vector<4xf32> -> vector<4xf32>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @masked_gather_i32(
+    %ptrs : vector<4x!spirv.ptr<i32, CrossWorkgroup>>,
+    %alignment : i32,
+    %mask : vector<4xi1>,
+    %fill : vector<4xi32>) "None" {
+  // CHECK: {{%.*}} = spirv.INTEL.MaskedGather {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}} : vector<4x!spirv.ptr<i32, CrossWorkgroup>>, i32, vector<4xi1>, vector<4xi32> -> vector<4xi32>
+  %0 = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
+       : vector<4x!spirv.ptr<i32, CrossWorkgroup>>, i32,
+         vector<4xi1>, vector<4xi32> -> vector<4xi32>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @masked_gather_pointee_type_mismatch(
+    %ptrs : vector<4x!spirv.ptr<f32, CrossWorkgroup>>,
+    %alignment : i32,
+    %mask : vector<4xi1>,
+    %fill : vector<4xi32>) "None" {
+  // expected-error @+1 {{'spirv.INTEL.MaskedGather' op failed to verify that pointee type of ptr_vector must match result element type}}
+  %0 = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
+       : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32,
+         vector<4xi1>, vector<4xi32> -> vector<4xi32>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @masked_gather_elem_count_mismatch(
+    %ptrs : vector<2x!spirv.ptr<f32, CrossWorkgroup>>,
+    %alignment : i32,
+    %mask : vector<4xi1>,
+    %fill : vector<4xf32>) "None" {
+  // expected-error @+1 {{'spirv.INTEL.MaskedGather' op failed to verify that pointee type of ptr_vector must match result element type}}
+  %0 = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
+       : vector<2x!spirv.ptr<f32, CrossWorkgroup>>, i32,
+         vector<4xi1>, vector<4xf32> -> vector<4xf32>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @masked_gather_mask_not_bool(
+    %ptrs : vector<4x!spirv.ptr<f32, CrossWorkgroup>>,
+    %alignment : i32,
+    %mask : vector<4xi8>,
+    %fill : vector<4xf32>) "None" {
+  // expected-error @+1 {{operand #2 must be fixed-length vector of bool values of length 2/3/4/8/16 of ranks 1, but got 'vector<4xi8>'}}
+  %0 = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
+       : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32,
+         vector<4xi8>, vector<4xf32> -> vector<4xf32>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @masked_gather_mask_count_mismatch(
+    %ptrs : vector<4x!spirv.ptr<f32, CrossWorkgroup>>,
+    %alignment : i32,
+    %mask : vector<2xi1>,
+    %fill : vector<4xf32>) "None" {
+  // expected-error @+1 {{'spirv.INTEL.MaskedGather' op failed to verify that mask must be a vector of i1 matching result shape}}
+  %0 = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
+       : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32,
+         vector<2xi1>, vector<4xf32> -> vector<4xf32>
+  spirv.Return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.INTEL.MaskedScatter
+//===----------------------------------------------------------------------===//
+
+spirv.func @masked_scatter(
+    %ptrs : vector<4x!spirv.ptr<f32, CrossWorkgroup>>,
+    %alignment : i32,
+    %mask : vector<4xi1>,
+    %values : vector<4xf32>) "None" {
+  // CHECK: spirv.INTEL.MaskedScatter {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}} : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32, vector<4xi1>, vector<4xf32>
+  spirv.INTEL.MaskedScatter %ptrs, %alignment, %mask, %values
+       : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32,
+         vector<4xi1>, vector<4xf32>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @masked_scatter_pointee_mismatch(
+    %ptrs : vector<4x!spirv.ptr<i32, CrossWorkgroup>>,
+    %alignment : i32,
+    %mask : vector<4xi1>,
+    %values : vector<4xf32>) "None" {
+  // expected-error @+1 {{'spirv.INTEL.MaskedScatter' op failed to verify that pointee type of ptr_vector must match input element type}}
+  spirv.INTEL.MaskedScatter %ptrs, %alignment, %mask, %values
+       : vector<4x!spirv.ptr<i32, CrossWorkgroup>>, i32,
+         vector<4xi1>, vector<4xf32>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @masked_scatter_mask_count_mismatch(
+    %ptrs : vector<4x!spirv.ptr<f32, CrossWorkgroup>>,
+    %alignment : i32,
+    %mask : vector<2xi1>,
+    %values : vector<4xf32>) "None" {
+  // expected-error @+1 {{'spirv.INTEL.MaskedScatter' op failed to verify that mask must be a vector of i1 matching input shape}}
+  spirv.INTEL.MaskedScatter %ptrs, %alignment, %mask, %values
+       : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32,
+         vector<2xi1>, vector<4xf32>
+  spirv.Return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.INTEL.CacheControls
+//===----------------------------------------------------------------------===//
+
 spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [CacheControlsINTEL], [SPV_INTEL_cache_controls]> {
   spirv.func @foo() "None" {
     // CHECK: spirv.Variable {cache_control_load_intel = [#spirv.cache_control_load_intel<cache_level = 0, load_cache_control = Uncached>, #spirv.cache_control_load_intel<cache_level = 1, load_cache_control = Cached>, #spirv.cache_control_load_intel<cache_level = 2, load_cache_control = InvalidateAfterR>]} : !spirv.ptr<f32, Function>

Review note: adds one Physical64/OpenCL module that declares the
MaskedGatherScatterINTEL capability and the SPV_INTEL_masked_gather_scatter
extension, with f32 and i32 variants of both MaskedGather and MaskedScatter.
The file's RUN line (not shown) presumably round-trips this module through
SPIR-V serialization; confirm.

diff  --git a/mlir/test/Target/SPIRV/intel-ext-ops.mlir b/mlir/test/Target/SPIRV/intel-ext-ops.mlir
index 118bed8be7095..b6c42747f5b68 100644
--- a/mlir/test/Target/SPIRV/intel-ext-ops.mlir
+++ b/mlir/test/Target/SPIRV/intel-ext-ops.mlir
@@ -59,6 +59,66 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage, TensorF
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.INTEL.MaskedGather / MaskedScatter
+//===----------------------------------------------------------------------===//
+
+spirv.module Physical64 OpenCL requires #spirv.vce<v1.0, [Kernel, Addresses, Linkage, MaskedGatherScatterINTEL], [SPV_INTEL_masked_gather_scatter]> {
+  // CHECK-LABEL: @masked_gather_f32
+  spirv.func @masked_gather_f32(
+      %ptrs : vector<4x!spirv.ptr<f32, CrossWorkgroup>>,
+      %alignment : i32,
+      %mask : vector<4xi1>,
+      %fill : vector<4xf32>) "None" {
+    // CHECK: {{%.*}} = spirv.INTEL.MaskedGather {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}} : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32, vector<4xi1>, vector<4xf32> -> vector<4xf32>
+    %0 = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
+         : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32,
+           vector<4xi1>, vector<4xf32> -> vector<4xf32>
+    spirv.Return
+  }
+
+  // CHECK-LABEL: @masked_gather_i32
+  spirv.func @masked_gather_i32(
+      %ptrs : vector<4x!spirv.ptr<i32, CrossWorkgroup>>,
+      %alignment : i32,
+      %mask : vector<4xi1>,
+      %fill : vector<4xi32>) "None" {
+    // CHECK: {{%.*}} = spirv.INTEL.MaskedGather {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}} : vector<4x!spirv.ptr<i32, CrossWorkgroup>>, i32, vector<4xi1>, vector<4xi32> -> vector<4xi32>
+    %0 = spirv.INTEL.MaskedGather %ptrs, %alignment, %mask, %fill
+         : vector<4x!spirv.ptr<i32, CrossWorkgroup>>, i32,
+           vector<4xi1>, vector<4xi32> -> vector<4xi32>
+    spirv.Return
+  }
+
+  // CHECK-LABEL: @masked_scatter_f32
+  spirv.func @masked_scatter_f32(
+      %ptrs : vector<4x!spirv.ptr<f32, CrossWorkgroup>>,
+      %alignment : i32,
+      %mask : vector<4xi1>,
+      %values : vector<4xf32>) "None" {
+    // CHECK: spirv.INTEL.MaskedScatter {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}} : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32, vector<4xi1>, vector<4xf32>
+    spirv.INTEL.MaskedScatter %ptrs, %alignment, %mask, %values
+         : vector<4x!spirv.ptr<f32, CrossWorkgroup>>, i32,
+           vector<4xi1>, vector<4xf32>
+    spirv.Return
+  }
+
+  // CHECK-LABEL: @masked_scatter_i32
+  spirv.func @masked_scatter_i32(
+      %ptrs : vector<4x!spirv.ptr<i32, CrossWorkgroup>>,
+      %alignment : i32,
+      %mask : vector<4xi1>,
+      %values : vector<4xi32>) "None" {
+    // CHECK: spirv.INTEL.MaskedScatter {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}} : vector<4x!spirv.ptr<i32, CrossWorkgroup>>, i32, vector<4xi1>, vector<4xi32>
+    spirv.INTEL.MaskedScatter %ptrs, %alignment, %mask, %values
+         : vector<4x!spirv.ptr<i32, CrossWorkgroup>>, i32,
+           vector<4xi1>, vector<4xi32>
+    spirv.Return
+  }
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.INTEL.SplitBarrier
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list