[Mlir-commits] [mlir] [mlir][spirv] Add support for SPV_EXT_mesh_shader extension (PR #126555)

Igor Wodiany llvmlistbot at llvm.org
Mon Feb 10 09:42:22 PST 2025


https://github.com/IgWod-IMG created https://github.com/llvm/llvm-project/pull/126555

This patch adds support for all enums and operations defined in the SPV_EXT_mesh_shader extension. Where they conflict with the SPV_NV_mesh_shader definitions, the EXT specification takes precedence, as duplicate enum values are not allowed (e.g. the PerPrimitiveNV decoration is now named PerPrimitiveEXT). Enum values have been added manually because the define_enum.sh script modifies files too aggressively: it adds all missing values from various extensions.

>From 0401fecae7264350972ff6ed7b93e3442563be7d Mon Sep 17 00:00:00 2001
From: Igor Wodiany <igor.wodiany at imgtec.com>
Date: Wed, 29 Jan 2025 14:41:17 +0000
Subject: [PATCH] [mlir][spirv] Add support for SPV_EXT_mesh_shader extension

This patch adds support for all enums and operations defined
in the SPV_EXT_mesh_shader extension. Where they conflict with
the SPV_NV_mesh_shader definitions, the EXT specification takes
precedence, as duplicate enum values are not allowed (e.g. the
PerPrimitiveNV decoration is now named PerPrimitiveEXT). Enum
values have been added manually because the define_enum.sh
script modifies files too aggressively: it adds all missing
values from various extensions.
---
 .../mlir/Dialect/SPIRV/IR/SPIRVBase.td        | 112 ++++++++++----
 .../mlir/Dialect/SPIRV/IR/SPIRVMeshOps.td     | 139 ++++++++++++++++++
 .../include/mlir/Dialect/SPIRV/IR/SPIRVOps.td |   1 +
 mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt      |   1 +
 mlir/lib/Dialect/SPIRV/IR/MeshOps.cpp         |  34 +++++
 mlir/test/Dialect/SPIRV/IR/availability.mlir  |  23 +++
 mlir/test/Dialect/SPIRV/IR/mesh-ops.mlir      |  34 +++++
 mlir/test/Target/SPIRV/mesh-ops.mlir          |  33 +++++
 8 files changed, 349 insertions(+), 28 deletions(-)
 create mode 100755 mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMeshOps.td
 create mode 100644 mlir/lib/Dialect/SPIRV/IR/MeshOps.cpp
 create mode 100644 mlir/test/Dialect/SPIRV/IR/mesh-ops.mlir
 create mode 100644 mlir/test/Target/SPIRV/mesh-ops.mlir

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 6b2e4189aea028e..838f7cc70b0cf4a 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -357,6 +357,7 @@ def SPV_EXT_shader_atomic_float_add      : I32EnumAttrCase<"SPV_EXT_shader_atomi
 def SPV_EXT_shader_atomic_float_min_max  : I32EnumAttrCase<"SPV_EXT_shader_atomic_float_min_max", 1009>;
 def SPV_EXT_shader_image_int64           : I32EnumAttrCase<"SPV_EXT_shader_image_int64", 1010>;
 def SPV_EXT_shader_atomic_float16_add    : I32EnumAttrCase<"SPV_EXT_shader_atomic_float16_add", 1011>;
+def SPV_EXT_mesh_shader                  : I32EnumAttrCase<"SPV_EXT_mesh_shader", 1012>;
 
 def SPV_AMD_gpu_shader_half_float_fetch          : I32EnumAttrCase<"SPV_AMD_gpu_shader_half_float_fetch", 2000>;
 def SPV_AMD_shader_ballot                        : I32EnumAttrCase<"SPV_AMD_shader_ballot", 2001>;
@@ -443,6 +444,7 @@ def SPIRV_ExtensionAttr :
       SPV_EXT_shader_stencil_export, SPV_EXT_shader_viewport_index_layer,
       SPV_EXT_shader_atomic_float_add, SPV_EXT_shader_atomic_float_min_max,
       SPV_EXT_shader_image_int64, SPV_EXT_shader_atomic_float16_add,
+      SPV_EXT_mesh_shader,
       SPV_AMD_gpu_shader_half_float_fetch, SPV_AMD_shader_ballot,
       SPV_AMD_shader_explicit_vertex_parameter, SPV_AMD_shader_fragment_mask,
       SPV_AMD_shader_image_load_store_lod, SPV_AMD_texture_gather_bias_lod,
@@ -1207,6 +1209,12 @@ def SPIRV_C_MeshShadingNV                               : I32EnumAttrCase<"MeshS
     Extension<[SPV_NV_mesh_shader]>
   ];
 }
+def SPIRV_C_MeshShadingEXT                              : I32EnumAttrCase<"MeshShadingEXT", 5283> {
+  list<I32EnumAttrCase> implies = [SPIRV_C_Shader];
+  list<Availability> availability = [
+    Extension<[SPV_EXT_mesh_shader]>
+  ];
+}
 def SPIRV_C_FragmentDensityEXT                          : I32EnumAttrCase<"FragmentDensityEXT", 5291> {
   list<I32EnumAttrCase> implies = [SPIRV_C_Shader];
   list<Availability> availability = [
@@ -1436,7 +1444,7 @@ def SPIRV_CapabilityAttr :
       SPIRV_C_StorageBuffer8BitAccess, SPIRV_C_StoragePushConstant8,
       SPIRV_C_DenormPreserve, SPIRV_C_DenormFlushToZero, SPIRV_C_SignedZeroInfNanPreserve,
       SPIRV_C_RoundingModeRTE, SPIRV_C_RoundingModeRTZ, SPIRV_C_ImageFootprintNV,
-      SPIRV_C_FragmentBarycentricKHR, SPIRV_C_ComputeDerivativeGroupQuadsNV,
+      SPIRV_C_FragmentBarycentricKHR, SPIRV_C_MeshShadingEXT, SPIRV_C_ComputeDerivativeGroupQuadsNV,
       SPIRV_C_GroupNonUniformPartitionedNV, SPIRV_C_VulkanMemoryModel,
       SPIRV_C_VulkanMemoryModelDeviceScope, SPIRV_C_ComputeDerivativeGroupLinearNV,
       SPIRV_C_BindlessTextureNV, SPIRV_C_SubgroupShuffleINTEL,
@@ -1576,7 +1584,7 @@ def SPIRV_BI_InstanceId                  : I32EnumAttrCase<"InstanceId", 6> {
 }
 def SPIRV_BI_PrimitiveId                 : I32EnumAttrCase<"PrimitiveId", 7> {
   list<Availability> availability = [
-    Capability<[SPIRV_C_Geometry, SPIRV_C_MeshShadingNV, SPIRV_C_RayTracingKHR, SPIRV_C_RayTracingNV, SPIRV_C_Tessellation]>
+    Capability<[SPIRV_C_Geometry, SPIRV_C_MeshShadingNV, SPIRV_C_RayTracingKHR, SPIRV_C_RayTracingNV, SPIRV_C_MeshShadingEXT, SPIRV_C_Tessellation]>
   ];
 }
 def SPIRV_BI_InvocationId                : I32EnumAttrCase<"InvocationId", 8> {
@@ -1586,12 +1594,12 @@ def SPIRV_BI_InvocationId                : I32EnumAttrCase<"InvocationId", 8> {
 }
 def SPIRV_BI_Layer                       : I32EnumAttrCase<"Layer", 9> {
   list<Availability> availability = [
-    Capability<[SPIRV_C_Geometry, SPIRV_C_MeshShadingNV, SPIRV_C_ShaderLayer, SPIRV_C_ShaderViewportIndexLayerEXT]>
+    Capability<[SPIRV_C_Geometry, SPIRV_C_MeshShadingNV, SPIRV_C_MeshShadingEXT, SPIRV_C_ShaderLayer, SPIRV_C_ShaderViewportIndexLayerEXT]>
   ];
 }
 def SPIRV_BI_ViewportIndex               : I32EnumAttrCase<"ViewportIndex", 10> {
   list<Availability> availability = [
-    Capability<[SPIRV_C_MeshShadingNV, SPIRV_C_MultiViewport, SPIRV_C_ShaderViewportIndex, SPIRV_C_ShaderViewportIndexLayerEXT]>
+    Capability<[SPIRV_C_MeshShadingNV, SPIRV_C_MeshShadingEXT, SPIRV_C_MultiViewport, SPIRV_C_ShaderViewportIndex, SPIRV_C_ShaderViewportIndexLayerEXT]>
   ];
 }
 def SPIRV_BI_TessLevelOuter              : I32EnumAttrCase<"TessLevelOuter", 11> {
@@ -1769,8 +1777,8 @@ def SPIRV_BI_BaseInstance                : I32EnumAttrCase<"BaseInstance", 4425>
 }
 def SPIRV_BI_DrawIndex                   : I32EnumAttrCase<"DrawIndex", 4426> {
   list<Availability> availability = [
-    Extension<[SPV_KHR_shader_draw_parameters, SPV_NV_mesh_shader]>,
-    Capability<[SPIRV_C_DrawParameters, SPIRV_C_MeshShadingNV]>
+    Extension<[SPV_KHR_shader_draw_parameters, SPV_NV_mesh_shader, SPV_EXT_mesh_shader]>,
+    Capability<[SPIRV_C_DrawParameters, SPIRV_C_MeshShadingNV, SPIRV_C_MeshShadingEXT]>
   ];
 }
 def SPIRV_BI_PrimitiveShadingRateKHR     : I32EnumAttrCase<"PrimitiveShadingRateKHR", 4432> {
@@ -1946,6 +1954,30 @@ def SPIRV_BI_FragInvocationCountEXT      : I32EnumAttrCase<"FragInvocationCountE
     Capability<[SPIRV_C_FragmentDensityEXT]>
   ];
 }
+def SPIRV_BI_PrimitivePointIndicesEXT    : I32EnumAttrCase<"PrimitivePointIndicesEXT", 5294> {
+  list<Availability> availability = [
+    Extension<[SPV_EXT_mesh_shader]>,
+    Capability<[SPIRV_C_MeshShadingEXT]>
+  ];
+}
+def SPIRV_BI_PrimitiveLineIndicesEXT     : I32EnumAttrCase<"PrimitiveLineIndicesEXT", 5295> {
+  list<Availability> availability = [
+    Extension<[SPV_EXT_mesh_shader]>,
+    Capability<[SPIRV_C_MeshShadingEXT]>
+  ];
+}
+def SPIRV_BI_PrimitiveTriangleIndicesEXT : I32EnumAttrCase<"PrimitiveTriangleIndicesEXT", 5296> {
+  list<Availability> availability = [
+    Extension<[SPV_EXT_mesh_shader]>,
+    Capability<[SPIRV_C_MeshShadingEXT]>
+  ];
+}
+def SPIRV_BI_CullPrimitiveEXT            : I32EnumAttrCase<"CullPrimitiveEXT", 5299> {
+  list<Availability> availability = [
+    Extension<[SPV_EXT_mesh_shader]>,
+    Capability<[SPIRV_C_MeshShadingEXT]>
+  ];
+}
 def SPIRV_BI_LaunchIdKHR                 : I32EnumAttrCase<"LaunchIdKHR", 5319> {
   list<Availability> availability = [
     Extension<[SPV_KHR_ray_tracing, SPV_NV_ray_tracing]>,
@@ -2102,7 +2134,9 @@ def SPIRV_BuiltInAttr :
       SPIRV_BI_ClipDistancePerViewNV, SPIRV_BI_CullDistancePerViewNV,
       SPIRV_BI_LayerPerViewNV, SPIRV_BI_MeshViewCountNV, SPIRV_BI_MeshViewIndicesNV,
       SPIRV_BI_BaryCoordKHR, SPIRV_BI_BaryCoordNoPerspKHR, SPIRV_BI_FragSizeEXT,
-      SPIRV_BI_FragInvocationCountEXT, SPIRV_BI_LaunchIdKHR, SPIRV_BI_LaunchSizeKHR,
+      SPIRV_BI_FragInvocationCountEXT, SPIRV_BI_PrimitivePointIndicesEXT,
+      SPIRV_BI_PrimitiveLineIndicesEXT, SPIRV_BI_PrimitiveTriangleIndicesEXT,
+      SPIRV_BI_CullPrimitiveEXT, SPIRV_BI_LaunchIdKHR, SPIRV_BI_LaunchSizeKHR,
       SPIRV_BI_WorldRayOriginKHR, SPIRV_BI_WorldRayDirectionKHR,
       SPIRV_BI_ObjectRayOriginKHR, SPIRV_BI_ObjectRayDirectionKHR, SPIRV_BI_RayTminKHR,
       SPIRV_BI_RayTmaxKHR, SPIRV_BI_InstanceCustomIndexKHR, SPIRV_BI_ObjectToWorldKHR,
@@ -2358,10 +2392,10 @@ def SPIRV_D_SecondaryViewportRelativeNV        : I32EnumAttrCase<"SecondaryViewp
     Capability<[SPIRV_C_ShaderStereoViewNV]>
   ];
 }
-def SPIRV_D_PerPrimitiveNV                     : I32EnumAttrCase<"PerPrimitiveNV", 5271> {
+def SPIRV_D_PerPrimitiveEXT                    : I32EnumAttrCase<"PerPrimitiveEXT", 5271> {
   list<Availability> availability = [
-    Extension<[SPV_NV_mesh_shader]>,
-    Capability<[SPIRV_C_MeshShadingNV]>
+    Extension<[SPV_NV_mesh_shader, SPV_EXT_mesh_shader]>,
+    Capability<[SPIRV_C_MeshShadingNV, SPIRV_C_MeshShadingEXT]>
   ];
 }
 def SPIRV_D_PerViewNV                          : I32EnumAttrCase<"PerViewNV", 5272> {
@@ -2660,7 +2694,7 @@ def SPIRV_DecorationAttr :
       SPIRV_D_AlignmentId, SPIRV_D_MaxByteOffsetId, SPIRV_D_NoSignedWrap,
       SPIRV_D_NoUnsignedWrap, SPIRV_D_ExplicitInterpAMD, SPIRV_D_OverrideCoverageNV,
       SPIRV_D_PassthroughNV, SPIRV_D_ViewportRelativeNV,
-      SPIRV_D_SecondaryViewportRelativeNV, SPIRV_D_PerPrimitiveNV, SPIRV_D_PerViewNV,
+      SPIRV_D_SecondaryViewportRelativeNV, SPIRV_D_PerPrimitiveEXT, SPIRV_D_PerViewNV,
       SPIRV_D_PerTaskNV, SPIRV_D_PerVertexKHR, SPIRV_D_NonUniform, SPIRV_D_RestrictPointer,
       SPIRV_D_AliasedPointer, SPIRV_D_BindlessSamplerNV, SPIRV_D_BindlessImageNV,
       SPIRV_D_BoundSamplerNV, SPIRV_D_BoundImageNV, SPIRV_D_SIMTCallINTEL,
@@ -2843,12 +2877,12 @@ def SPIRV_EM_Isolines                         : I32EnumAttrCase<"Isolines", 25>
 }
 def SPIRV_EM_OutputVertices                   : I32EnumAttrCase<"OutputVertices", 26> {
   list<Availability> availability = [
-    Capability<[SPIRV_C_Geometry, SPIRV_C_MeshShadingNV, SPIRV_C_Tessellation]>
+    Capability<[SPIRV_C_Geometry, SPIRV_C_MeshShadingNV, SPIRV_C_MeshShadingEXT, SPIRV_C_Tessellation]>
   ];
 }
 def SPIRV_EM_OutputPoints                     : I32EnumAttrCase<"OutputPoints", 27> {
   list<Availability> availability = [
-    Capability<[SPIRV_C_Geometry, SPIRV_C_MeshShadingNV]>
+    Capability<[SPIRV_C_Geometry, SPIRV_C_MeshShadingNV, SPIRV_C_MeshShadingEXT]>
   ];
 }
 def SPIRV_EM_OutputLineStrip                  : I32EnumAttrCase<"OutputLineStrip", 28> {
@@ -3002,16 +3036,16 @@ def SPIRV_EM_StencilRefLessBackAMD            : I32EnumAttrCase<"StencilRefLessB
     Capability<[SPIRV_C_StencilExportEXT]>
   ];
 }
-def SPIRV_EM_OutputLinesNV                    : I32EnumAttrCase<"OutputLinesNV", 5269> {
+def SPIRV_EM_OutputLinesEXT                   : I32EnumAttrCase<"OutputLinesEXT", 5269> {
   list<Availability> availability = [
-    Extension<[SPV_NV_mesh_shader]>,
-    Capability<[SPIRV_C_MeshShadingNV]>
+    Extension<[SPV_NV_mesh_shader, SPV_EXT_mesh_shader]>,
+    Capability<[SPIRV_C_MeshShadingNV, SPIRV_C_MeshShadingEXT]>
   ];
 }
-def SPIRV_EM_OutputPrimitivesNV               : I32EnumAttrCase<"OutputPrimitivesNV", 5270> {
+def SPIRV_EM_OutputPrimitivesEXT              : I32EnumAttrCase<"OutputPrimitivesEXT", 5270> {
   list<Availability> availability = [
-    Extension<[SPV_NV_mesh_shader]>,
-    Capability<[SPIRV_C_MeshShadingNV]>
+    Extension<[SPV_NV_mesh_shader, SPV_EXT_mesh_shader]>,
+    Capability<[SPIRV_C_MeshShadingNV, SPIRV_C_MeshShadingEXT]>
   ];
 }
 def SPIRV_EM_DerivativeGroupQuadsNV           : I32EnumAttrCase<"DerivativeGroupQuadsNV", 5289> {
@@ -3026,10 +3060,10 @@ def SPIRV_EM_DerivativeGroupLinearNV          : I32EnumAttrCase<"DerivativeGroup
     Capability<[SPIRV_C_ComputeDerivativeGroupLinearNV]>
   ];
 }
-def SPIRV_EM_OutputTrianglesNV                : I32EnumAttrCase<"OutputTrianglesNV", 5298> {
+def SPIRV_EM_OutputTrianglesEXT               : I32EnumAttrCase<"OutputTrianglesEXT", 5298> {
   list<Availability> availability = [
-    Extension<[SPV_NV_mesh_shader]>,
-    Capability<[SPIRV_C_MeshShadingNV]>
+    Extension<[SPV_NV_mesh_shader, SPV_EXT_mesh_shader]>,
+    Capability<[SPIRV_C_MeshShadingNV, SPIRV_C_MeshShadingEXT]>
   ];
 }
 def SPIRV_EM_PixelInterlockOrderedEXT         : I32EnumAttrCase<"PixelInterlockOrderedEXT", 5366> {
@@ -3154,9 +3188,9 @@ def SPIRV_ExecutionModeAttr :
       SPIRV_EM_StencilRefReplacingEXT, SPIRV_EM_StencilRefUnchangedFrontAMD,
       SPIRV_EM_StencilRefGreaterFrontAMD, SPIRV_EM_StencilRefLessFrontAMD,
       SPIRV_EM_StencilRefUnchangedBackAMD, SPIRV_EM_StencilRefGreaterBackAMD,
-      SPIRV_EM_StencilRefLessBackAMD, SPIRV_EM_OutputLinesNV, SPIRV_EM_OutputPrimitivesNV,
-      SPIRV_EM_DerivativeGroupQuadsNV, SPIRV_EM_DerivativeGroupLinearNV,
-      SPIRV_EM_OutputTrianglesNV, SPIRV_EM_PixelInterlockOrderedEXT,
+      SPIRV_EM_StencilRefLessBackAMD, SPIRV_EM_OutputLinesEXT, SPIRV_EM_OutputPrimitivesEXT,
+      SPIRV_EM_DerivativeGroupQuadsNV, SPIRV_EM_DerivativeGroupLinearNV,
+      SPIRV_EM_OutputTrianglesEXT, SPIRV_EM_PixelInterlockOrderedEXT,
       SPIRV_EM_PixelInterlockUnorderedEXT, SPIRV_EM_SampleInterlockOrderedEXT,
       SPIRV_EM_SampleInterlockUnorderedEXT, SPIRV_EM_ShadingRateInterlockOrderedEXT,
       SPIRV_EM_ShadingRateInterlockUnorderedEXT, SPIRV_EM_SharedLocalMemorySizeINTEL,
@@ -3243,13 +3277,24 @@ def SPIRV_EM_CallableKHR            : I32EnumAttrCase<"CallableKHR", 5318> {
     Capability<[SPIRV_C_RayTracingKHR, SPIRV_C_RayTracingNV]>
   ];
 }
+def SPIRV_EM_TaskEXT                : I32EnumAttrCase<"TaskEXT", 5364> {
+  list<Availability> availability = [
+    Capability<[SPIRV_C_MeshShadingEXT]>
+  ];
+}
+def SPIRV_EM_MeshEXT                : I32EnumAttrCase<"MeshEXT", 5365> {
+  list<Availability> availability = [
+    Capability<[SPIRV_C_MeshShadingEXT]>
+  ];
+}
 
 def SPIRV_ExecutionModelAttr :
     SPIRV_I32EnumAttr<"ExecutionModel", "valid SPIR-V ExecutionModel", "execution_model", [
       SPIRV_EM_Vertex, SPIRV_EM_TessellationControl, SPIRV_EM_TessellationEvaluation,
       SPIRV_EM_Geometry, SPIRV_EM_Fragment, SPIRV_EM_GLCompute, SPIRV_EM_Kernel,
       SPIRV_EM_TaskNV, SPIRV_EM_MeshNV, SPIRV_EM_RayGenerationKHR, SPIRV_EM_IntersectionKHR,
-      SPIRV_EM_AnyHitKHR, SPIRV_EM_ClosestHitKHR, SPIRV_EM_MissKHR, SPIRV_EM_CallableKHR
+      SPIRV_EM_AnyHitKHR, SPIRV_EM_ClosestHitKHR, SPIRV_EM_MissKHR, SPIRV_EM_CallableKHR,
+      SPIRV_EM_TaskEXT, SPIRV_EM_MeshEXT
     ]>;
 
 def SPIRV_FC_None         : I32BitEnumAttrCaseNone<"None">;
@@ -3982,6 +4027,13 @@ def SPIRV_SC_PhysicalStorageBuffer   : I32EnumAttrCase<"PhysicalStorageBuffer",
     Capability<[SPIRV_C_PhysicalStorageBufferAddresses]>
   ];
 }
+def SPIRV_SC_TaskPayloadWorkgroupEXT : I32EnumAttrCase<"TaskPayloadWorkgroupEXT", 5402> {
+  list<Availability> availability = [
+    MinVersion<SPIRV_V_1_4>,
+    Extension<[SPV_EXT_mesh_shader]>,
+    Capability<[SPIRV_C_MeshShadingEXT]>
+  ];
+}
 def SPIRV_SC_CodeSectionINTEL        : I32EnumAttrCase<"CodeSectionINTEL", 5605> {
   list<Availability> availability = [
     Extension<[SPV_INTEL_function_pointers]>,
@@ -4009,7 +4061,8 @@ def SPIRV_StorageClassAttr :
       SPIRV_SC_StorageBuffer, SPIRV_SC_CallableDataKHR, SPIRV_SC_IncomingCallableDataKHR,
       SPIRV_SC_RayPayloadKHR, SPIRV_SC_HitAttributeKHR, SPIRV_SC_IncomingRayPayloadKHR,
       SPIRV_SC_ShaderRecordBufferKHR, SPIRV_SC_PhysicalStorageBuffer,
-      SPIRV_SC_CodeSectionINTEL, SPIRV_SC_DeviceOnlyINTEL, SPIRV_SC_HostOnlyINTEL
+      SPIRV_SC_TaskPayloadWorkgroupEXT, SPIRV_SC_CodeSectionINTEL,
+      SPIRV_SC_DeviceOnlyINTEL, SPIRV_SC_HostOnlyINTEL
     ]>;
 
 def SPIRV_PVF_PackedVectorFormat4x8Bit : I32EnumAttrCase<"PackedVectorFormat4x8Bit", 0> {
@@ -4524,6 +4577,8 @@ def SPIRV_OC_OpCooperativeMatrixLoadKHR      : I32EnumAttrCase<"OpCooperativeMat
 def SPIRV_OC_OpCooperativeMatrixStoreKHR     : I32EnumAttrCase<"OpCooperativeMatrixStoreKHR", 4458>;
 def SPIRV_OC_OpCooperativeMatrixMulAddKHR    : I32EnumAttrCase<"OpCooperativeMatrixMulAddKHR", 4459>;
 def SPIRV_OC_OpCooperativeMatrixLengthKHR    : I32EnumAttrCase<"OpCooperativeMatrixLengthKHR", 4460>;
+def SPIRV_OC_OpEmitMeshTasksEXT              : I32EnumAttrCase<"OpEmitMeshTasksEXT", 5294>;
+def SPIRV_OC_OpSetMeshOutputsEXT             : I32EnumAttrCase<"OpSetMeshOutputsEXT", 5295>;
 def SPIRV_OC_OpSubgroupBlockReadINTEL        : I32EnumAttrCase<"OpSubgroupBlockReadINTEL", 5575>;
 def SPIRV_OC_OpSubgroupBlockWriteINTEL       : I32EnumAttrCase<"OpSubgroupBlockWriteINTEL", 5576>;
 def SPIRV_OC_OpAssumeTrueKHR                 : I32EnumAttrCase<"OpAssumeTrueKHR", 5630>;
@@ -4622,7 +4677,8 @@ def SPIRV_OpcodeAttr :
       SPIRV_OC_OpUDotAccSat, SPIRV_OC_OpSUDotAccSat,
       SPIRV_OC_OpTypeCooperativeMatrixKHR, SPIRV_OC_OpCooperativeMatrixLoadKHR,
       SPIRV_OC_OpCooperativeMatrixStoreKHR, SPIRV_OC_OpCooperativeMatrixMulAddKHR,
-      SPIRV_OC_OpCooperativeMatrixLengthKHR, SPIRV_OC_OpSubgroupBlockReadINTEL,
+      SPIRV_OC_OpCooperativeMatrixLengthKHR, SPIRV_OC_OpEmitMeshTasksEXT,
+      SPIRV_OC_OpSetMeshOutputsEXT, SPIRV_OC_OpSubgroupBlockReadINTEL,
       SPIRV_OC_OpSubgroupBlockWriteINTEL, SPIRV_OC_OpAssumeTrueKHR,
       SPIRV_OC_OpAtomicFAddEXT, SPIRV_OC_OpConvertFToBF16INTEL,
       SPIRV_OC_OpConvertBF16ToFINTEL, SPIRV_OC_OpControlBarrierArriveINTEL,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMeshOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMeshOps.td
new file mode 100755
index 000000000000000..a2e3d0509525fad
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMeshOps.td
@@ -0,0 +1,139 @@
+//===-- SPIRVMeshOps.td - MLIR SPIR-V Mesh Ops -------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains mesh ops for the SPIR-V dialect. It corresponds to the
+// part of "3.52.25. Reserved Instructions" of the SPIR-V specification, and to
+// the SPV_EXT_mesh_shader specification.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPIRV_MESH_OPS
+#define MLIR_DIALECT_SPIRV_MESH_OPS
+
+include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
+
+// -----
+
+def SPIRV_EXTEmitMeshTasksOp : SPIRV_ExtVendorOp<"EmitMeshTasks", [Terminator]> {
+  let summary = [{
+    Defines the grid size of subsequent mesh shader workgroups to generate upon
+    completion of the task shader workgroup.
+  }];
+
+  let description = [{
+    Defines the grid size of subsequent mesh shader workgroups to generate upon
+    completion of the task shader workgroup.
+
+    Group Count X Y Z must each be a 32-bit unsigned integer value. They
+    configure the number of local workgroups in each respective dimension for the
+    launch of child mesh tasks. See the Vulkan API specification for more detail.
+
+    Payload is an optional pointer to the payload structure to pass to the
+    generated mesh shader invocations. Payload must be the result of an OpVariable
+    with a storage class of TaskPayloadWorkgroupEXT.
+
+    The arguments are taken from the first invocation in each workgroup.
+    Behavior is undefined if any invocation terminates without executing this
+    instruction, or if any invocation executes this instruction in non-uniform
+    control flow.
+
+    This instruction also serves as an OpControlBarrier instruction, and also
+    performs and adheres to the description and semantics of an OpControlBarrier
+    instruction with the Execution and Memory operands set to Workgroup and the
+    Semantics operand set to a combination of WorkgroupMemory and AcquireRelease.
+
+    Ceases all further processing: Only instructions executed before
+    OpEmitMeshTasksEXT have observable side effects.
+
+    This instruction must be the last instruction in a block.
+
+    This instruction is only valid in the TaskEXT Execution Model.
+
+    <!-- End of AutoGen section -->
+
+    #### Example:
+
+    ```mlir
+    spirv.EXT.EmitMeshTasks %x, %y, %z : i32, i32, i32
+    spirv.EXT.EmitMeshTasks %x, %y, %z, %payload : i32, i32, i32, !spirv.ptr<i32, TaskPayloadWorkgroupEXT>
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPIRV_V_1_4>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[SPV_EXT_mesh_shader]>,
+    Capability<[SPIRV_C_MeshShadingEXT]>
+  ];
+
+  let arguments = (ins
+    SignlessOrUnsignedIntOfWidths<[32]>:$group_count_x,
+    SignlessOrUnsignedIntOfWidths<[32]>:$group_count_y,
+    SignlessOrUnsignedIntOfWidths<[32]>:$group_count_z,
+    Optional<SPIRV_AnyPtr>:$payload
+  );
+
+  let results = (outs);
+
+  let assemblyFormat = [{
+    operands attr-dict `:` type(operands)
+  }];
+}
+
+// -----
+
+def SPIRV_EXTSetMeshOutputsOp : SPIRV_ExtVendorOp<"SetMeshOutputs", []> {
+  let summary = [{
+    Sets the actual output size of the primitives and vertices that the mesh
+    shader workgroup will emit upon completion.
+  }];
+
+  let description = [{
+    Vertex Count must be a 32-bit unsigned integer value. It defines the array size
+    of per-vertex outputs.
+
+    Primitive Count must be a 32-bit unsigned integer value. It defines the array
+    size of per-primitive outputs.
+
+    The arguments are taken from the first invocation in each workgroup. Behavior
+    is undefined if any invocation executes this instruction more than once or
+    under non-uniform control flow. Behavior is undefined if there is any control
+    flow path to an output write that is not preceded by this instruction.
+
+    This instruction is only valid in the MeshEXT Execution Model.
+
+    <!-- End of AutoGen section -->
+
+    #### Example:
+
+    ```mlir
+    spirv.EXT.SetMeshOutputs %vcount, %pcount : i32, i32
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPIRV_V_1_4>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[SPV_EXT_mesh_shader]>,
+    Capability<[SPIRV_C_MeshShadingEXT]>
+  ];
+
+  let arguments = (ins
+    SignlessOrUnsignedIntOfWidths<[32]>:$vertex_count,
+    SignlessOrUnsignedIntOfWidths<[32]>:$primitive_count
+  );
+
+  let results = (outs);
+  let hasVerifier = 0; // No checks beyond the ODS operand constraints.
+
+  let assemblyFormat = [{
+    operands attr-dict `:` type(operands)
+  }];
+}
+
+#endif // MLIR_DIALECT_SPIRV_MESH_OPS
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
index ff1ca89f93b5acc..0fa1bb9d5bd0184 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
@@ -38,6 +38,7 @@ include "mlir/Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td"
+include "mlir/Dialect/SPIRV/IR/SPIRVMeshOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVMiscOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVPrimitiveOps.td"
diff --git a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
index 7d760e0dd802222..ae8ad5a491ff2b1 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRSPIRVDialect
   GroupOps.cpp
   IntegerDotProductOps.cpp
   MemoryOps.cpp
+  MeshOps.cpp
   SPIRVAttributes.cpp
   SPIRVCanonicalization.cpp
   SPIRVGLCanonicalization.cpp
diff --git a/mlir/lib/Dialect/SPIRV/IR/MeshOps.cpp b/mlir/lib/Dialect/SPIRV/IR/MeshOps.cpp
new file mode 100644
index 000000000000000..d18f7644dfb64b9
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/IR/MeshOps.cpp
@@ -0,0 +1,34 @@
+//===- MeshOps.cpp - MLIR SPIR-V Mesh Ops ---------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines the mesh operations in the SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
+
+#include <optional>
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// spirv.EXT.EmitMeshTasks
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::EXTEmitMeshTasksOp::verify() {
+  Value payload = getPayload();
+  if (!payload) // The payload operand is optional.
+    return success();
+  auto ptrTy = cast<spirv::PointerType>(payload.getType());
+  if (ptrTy.getStorageClass() == spirv::StorageClass::TaskPayloadWorkgroupEXT)
+    return success();
+  return emitOpError("payload must be a variable with a storage class of "
+                     "TaskPayloadWorkgroupEXT");
+}
diff --git a/mlir/test/Dialect/SPIRV/IR/availability.mlir b/mlir/test/Dialect/SPIRV/IR/availability.mlir
index 31a90ad0329d809..64ba8e3fc249ee5 100644
--- a/mlir/test/Dialect/SPIRV/IR/availability.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/availability.mlir
@@ -255,3 +255,26 @@ func.func @end_primitive() -> () {
   spirv.EndPrimitive
   return
 }
+
+//===----------------------------------------------------------------------===//
+// Mesh ops (SPV_EXT_mesh_shader)
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: emit_mesh_tasks
+func.func @emit_mesh_tasks(%0 : i32) -> () {
+  // CHECK: min version: v1.4
+  // CHECK: max version: v1.6
+  // CHECK: extensions: [ [SPV_EXT_mesh_shader] ]
+  // CHECK: capabilities: [ [MeshShadingEXT] ]
+  spirv.EXT.EmitMeshTasks %0, %0, %0 : i32, i32, i32
+}
+
+// CHECK-LABEL: set_mesh_outputs
+func.func @set_mesh_outputs(%0 : i32, %1 : i32) -> () {
+  // CHECK: min version: v1.4
+  // CHECK: max version: v1.6
+  // CHECK: extensions: [ [SPV_EXT_mesh_shader] ]
+  // CHECK: capabilities: [ [MeshShadingEXT] ]
+  spirv.EXT.SetMeshOutputs %0, %1 : i32, i32
+  spirv.Return
+}
diff --git a/mlir/test/Dialect/SPIRV/IR/mesh-ops.mlir b/mlir/test/Dialect/SPIRV/IR/mesh-ops.mlir
new file mode 100644
index 000000000000000..436f7d1c9fb1571
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/IR/mesh-ops.mlir
@@ -0,0 +1,34 @@
+// RUN: mlir-opt --split-input-file --verify-diagnostics %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spirv.EXT.EmitMeshTasks
+//===----------------------------------------------------------------------===//
+
+func.func @emit_mesh_tasks(%0 : i32) {
+  // CHECK: spirv.EXT.EmitMeshTasks {{%.*}}, {{%.*}}, {{%.*}} : i32, i32, i32
+  spirv.EXT.EmitMeshTasks %0, %0, %0 : i32, i32, i32
+}
+
+func.func @emit_mesh_tasks_payload(%0 : i32, %1 : !spirv.ptr<i32, TaskPayloadWorkgroupEXT>) {
+  // CHECK: spirv.EXT.EmitMeshTasks {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}} : i32, i32, i32, !spirv.ptr<i32, TaskPayloadWorkgroupEXT>
+  spirv.EXT.EmitMeshTasks %0, %0, %0, %1 : i32, i32, i32, !spirv.ptr<i32, TaskPayloadWorkgroupEXT>
+}
+
+// -----
+
+func.func @emit_mesh_tasks_wrong_payload(%0 : i32, %1 : !spirv.ptr<i32, Image>) {
+  // expected-error @+1 {{payload must be a variable with a storage class of TaskPayloadWorkgroupEXT}}
+  spirv.EXT.EmitMeshTasks %0, %0, %0, %1 : i32, i32, i32, !spirv.ptr<i32, Image>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.EXT.SetMeshOutputs
+//===----------------------------------------------------------------------===//
+
+func.func @set_mesh_outputs(%0 : i32, %1 : i32) {
+  // CHECK: spirv.EXT.SetMeshOutputs {{%.*}}, {{%.*}} : i32, i32
+  spirv.EXT.SetMeshOutputs %0, %1 : i32, i32
+  spirv.Return
+}
diff --git a/mlir/test/Target/SPIRV/mesh-ops.mlir b/mlir/test/Target/SPIRV/mesh-ops.mlir
new file mode 100644
index 000000000000000..3b937072de04e48
--- /dev/null
+++ b/mlir/test/Target/SPIRV/mesh-ops.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-translate --no-implicit-module --split-input-file --test-spirv-roundtrip %s | FileCheck %s
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.4, [MeshShadingEXT], [SPV_EXT_mesh_shader]> {
+  // CHECK-LABEL: @emit_mesh_tasks
+  spirv.func @emit_mesh_tasks() "None" {
+    %0 = spirv.Constant 1 : i32
+    // CHECK: spirv.EXT.EmitMeshTasks {{%.*}}, {{%.*}}, {{%.*}} : i32, i32, i32
+    spirv.EXT.EmitMeshTasks %0, %0, %0 : i32, i32, i32
+  }
+  // CHECK-LABEL: @set_mesh_outputs
+  spirv.func @set_mesh_outputs(%0 : i32, %1 : i32) "None" {
+    // CHECK: spirv.EXT.SetMeshOutputs {{%.*}}, {{%.*}} : i32, i32
+    spirv.EXT.SetMeshOutputs %0, %1 : i32, i32
+    spirv.Return
+  }
+  // CHECK: spirv.EntryPoint "TaskEXT" {{@.*}}
+  spirv.EntryPoint "TaskEXT" @emit_mesh_tasks
+}
+
+// -----
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.4, [MeshShadingEXT], [SPV_EXT_mesh_shader]> {
+  spirv.GlobalVariable @payload : !spirv.ptr<i32, TaskPayloadWorkgroupEXT>
+  // CHECK-LABEL: @emit_mesh_tasks_payload
+  spirv.func @emit_mesh_tasks_payload() "None" {
+    %0 = spirv.Constant 1 : i32
+    %1 = spirv.mlir.addressof @payload : !spirv.ptr<i32, TaskPayloadWorkgroupEXT>
+    // CHECK: spirv.EXT.EmitMeshTasks {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}} : i32, i32, i32, !spirv.ptr<i32, TaskPayloadWorkgroupEXT>
+    spirv.EXT.EmitMeshTasks %0, %0, %0, %1 : i32, i32, i32, !spirv.ptr<i32, TaskPayloadWorkgroupEXT>
+  }
+  // CHECK: spirv.EntryPoint "TaskEXT" {{@.*}}, {{@.*}}
+  spirv.EntryPoint "TaskEXT" @emit_mesh_tasks_payload, @payload
+}



More information about the Mlir-commits mailing list