[Mlir-commits] [mlir] fef74e1 - [mlir][spirv] add ExecutionModeIdOp (#186241)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Mar 18 03:23:38 PDT 2026


Author: Emimendoza
Date: 2026-03-18T06:23:33-04:00
New Revision: fef74e1c005d49804ea906bff39f759373681923

URL: https://github.com/llvm/llvm-project/commit/fef74e1c005d49804ea906bff39f759373681923
DIFF: https://github.com/llvm/llvm-project/commit/fef74e1c005d49804ea906bff39f759373681923.diff

LOG: [mlir][spirv] add ExecutionModeIdOp (#186241)

Adds OpExecutionModeId from SPIR-V 1.2.

---------

Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>

Added: 
    mlir/test/Target/SPIRV/execution-mode-id.mlir

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
    mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
    mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
    mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
    mlir/test/Dialect/SPIRV/IR/structure-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index fc9e2f11092e6..8badb84a879fa 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4583,6 +4583,7 @@ def SPIRV_OC_OpGroupUMax                      : I32EnumAttrCase<"OpGroupUMax", 2
 def SPIRV_OC_OpGroupSMax                      : I32EnumAttrCase<"OpGroupSMax", 271>;
 def SPIRV_OC_OpNoLine                         : I32EnumAttrCase<"OpNoLine", 317>;
 def SPIRV_OC_OpModuleProcessed                : I32EnumAttrCase<"OpModuleProcessed", 330>;
+def SPIRV_OC_OpExecutionModeId                : I32EnumAttrCase<"OpExecutionModeId", 331>;
 def SPIRV_OC_OpGroupNonUniformElect           : I32EnumAttrCase<"OpGroupNonUniformElect", 333>;
 def SPIRV_OC_OpGroupNonUniformAll             : I32EnumAttrCase<"OpGroupNonUniformAll", 334>;
 def SPIRV_OC_OpGroupNonUniformAny             : I32EnumAttrCase<"OpGroupNonUniformAny", 335>;
@@ -4726,6 +4727,7 @@ def SPIRV_OpcodeAttr :
       SPIRV_OC_OpGroupFAdd, SPIRV_OC_OpGroupFMin, SPIRV_OC_OpGroupUMin,
       SPIRV_OC_OpGroupSMin, SPIRV_OC_OpGroupFMax, SPIRV_OC_OpGroupUMax,
       SPIRV_OC_OpGroupSMax, SPIRV_OC_OpNoLine, SPIRV_OC_OpModuleProcessed,
+      SPIRV_OC_OpExecutionModeId,
       SPIRV_OC_OpGroupNonUniformElect, SPIRV_OC_OpGroupNonUniformAll,
       SPIRV_OC_OpGroupNonUniformAny, SPIRV_OC_OpGroupNonUniformAllEqual,
       SPIRV_OC_OpGroupNonUniformBroadcast, SPIRV_OC_OpGroupNonUniformBroadcastFirst,

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
index 9959f0bec781e..6899fc2dabc70 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
@@ -290,6 +290,77 @@ def SPIRV_ExecutionModeOp : SPIRV_Op<"ExecutionMode", [InModuleScope]> {
 
 // -----
 
+def SPIRV_ExecutionModeIdOp : SPIRV_Op<"ExecutionModeId", [InModuleScope]> {
+  let summary = [{
+    Declare an execution mode for an entry point, using <id>s as Extra
+    Operands.
+  }];
+
+  let description = [{
+    Entry Point must be the Entry Point <id> operand of an OpEntryPoint
+    instruction.
+
+    Mode is the execution mode. See Execution Mode.
+
+    This instruction is only valid if the Mode operand is an execution mode
+    that takes Extra Operands that are <id> operands. Otherwise, use
+    OpExecutionMode.
+
+    <!-- End of AutoGen section -->
+
+    ```
+    execution-mode-id ::= "LocalSizeId" | "LocalSizeHintId" |
+                          "SubgroupsPerWorkgroupId"
+    execution-mode-id-op ::= `spirv.ExecutionModeId ` symbol-reference
+                             execution-mode-id symbol-reference
+                             (`, ` symbol-reference)*
+    ```
+
+    Each value is a symbol reference to a specialization constant
+    (`spirv.SpecConstant`) that supplies the corresponding <id> operand.
+    Per the SPIR-V specification, `LocalSizeId` and `LocalSizeHintId` take
+    three values (x, y, z) and `SubgroupsPerWorkgroupId` takes one.
+
+    #### Example:
+
+    ```mlir
+    spirv.ExecutionModeId @foo "LocalSizeId" @var0, @var1, @var2
+    spirv.ExecutionModeId @bar "LocalSizeHintId" @x, @y, @z
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPIRV_V_1_2>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[]>,
+    Capability<[]>
+  ];
+
+  let arguments = (ins
+    FlatSymbolRefAttr:$fn,
+    SPIRV_ExecutionModeAttr:$execution_mode,
+    SymbolRefArrayAttr:$values
+  );
+
+  let results = (outs);
+
+  let hasVerifier = 1;
+
+  let autogenSerialization = 0;
+
+  let builders = [OpBuilder<(ins "spirv::FuncOp":$function,
+                                 "spirv::ExecutionMode":$executionMode,
+                                 "ArrayRef<Attribute>":$params),
+                            [{
+      build($_builder, $_state,
+            SymbolRefAttr::get(function),
+            ExecutionModeAttr::get($_builder.getContext(), executionMode),
+            $_builder.getArrayAttr(params));
+    }]>];
+}
+
+// -----
+
 def SPIRV_FuncOp : SPIRV_Op<"func", [
     AutomaticAllocationScope, FunctionOpInterface,
     InModuleScope, IsolatedFromAbove

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 41a0e8558ed89..cecc8c2194237 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -919,6 +919,93 @@ void spirv::ExecutionModeOp::print(OpAsmPrinter &printer) {
     printer << ", " << llvm::interleaved(values.getAsValueRange<IntegerAttr>());
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.ExecutionModeId
+//===----------------------------------------------------------------------===//
+
+ParseResult spirv::ExecutionModeIdOp::parse(OpAsmParser &parser,
+                                            OperationState &result) {
+  ExecutionMode execMode;
+  if (Attribute fn;
+      parser.parseAttribute(fn, kFnNameAttrName, result.attributes) ||
+      parseEnumStrAttr<ExecutionModeAttr>(execMode, parser, result)) {
+    return failure();
+  }
+
+  SmallVector<Attribute, 4> values;
+  if (parser.parseCommaSeparatedList([&]() -> ParseResult {
+        FlatSymbolRefAttr attr;
+        if (parser.parseAttribute(attr))
+          return failure();
+        values.push_back(attr);
+        return success();
+      })) {
+    return failure();
+  }
+
+  StringRef valuesAttrName = getValuesAttrName(result.name);
+  ArrayAttr valuesAttr = parser.getBuilder().getArrayAttr(values);
+  result.addAttribute(valuesAttrName, valuesAttr);
+  return success();
+}
+
+void spirv::ExecutionModeIdOp::print(OpAsmPrinter &printer) {
+  printer << " ";
+  printer.printSymbolName(getFn());
+  printer << " \"" << stringifyExecutionMode(getExecutionMode()) << "\" ";
+
+  llvm::interleaveComma(
+      getValues().getAsValueRange<FlatSymbolRefAttr>(), printer,
+      [&](StringRef value) { printer.printSymbolName(value); });
+}
+
+LogicalResult spirv::ExecutionModeIdOp::verify() {
+  // Execution modes that take <id> extra operands as of SPIR-V 1.6, and the
+  // number of operands each one takes: LocalSizeId and LocalSizeHintId take
+  // x, y and z sizes; SubgroupsPerWorkgroupId takes a single count.
+  unsigned expectedNumValues = 0;
+  switch (getExecutionMode()) {
+  case ExecutionMode::SubgroupsPerWorkgroupId:
+    expectedNumValues = 1;
+    break;
+  case ExecutionMode::LocalSizeId:
+  case ExecutionMode::LocalSizeHintId:
+    expectedNumValues = 3;
+    break;
+  default:
+    return emitOpError("expected ExecutionMode that takes extra operands that "
+                       "are <id> operands, got: ")
+           << stringifyExecutionMode(getExecutionMode());
+  }
+
+  ArrayAttr values = getValues();
+  if (values.size() != expectedNumValues)
+    return emitOpError("expected ")
+           << expectedNumValues << " value operand(s) for "
+           << stringifyExecutionMode(getExecutionMode()) << ", got "
+           << values.size();
+
+  for (Attribute value : values) {
+    auto valueSymbol = dyn_cast<FlatSymbolRefAttr>(value);
+    if (!valueSymbol)
+      return emitOpError("expected value operands to be symbol reference");
+    Operation *valueOp =
+        SymbolTable::lookupNearestSymbolFrom(getOperation(), valueSymbol);
+    if (!valueOp)
+      return emitOpError("cannot find symbol referenced by value operand: ")
+             << valueSymbol.getValue();
+    // Serialization can only encode references to specialization constants,
+    // so reject any other symbol here instead of failing during
+    // serialization.
+    if (!isa<spirv::SpecConstantOp>(valueOp))
+      return emitOpError("expected value operand to reference a "
+                         "spirv.SpecConstant, got: ")
+             << valueSymbol.getValue();
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.func
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index 5b04a14a78036..cc6302126d64a 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -147,6 +147,7 @@ LogicalResult spirv::Deserializer::processInstruction(
     return processMemoryModel(operands);
   case spirv::Opcode::OpEntryPoint:
   case spirv::Opcode::OpExecutionMode:
+  case spirv::Opcode::OpExecutionModeId:
     if (deferInstructions) {
       deferredInstructions.emplace_back(opcode, operands);
       return success();
@@ -453,6 +454,50 @@ Deserializer::processOp<spirv::ExecutionModeOp>(ArrayRef<uint32_t> words) {
   return success();
 }
 
+template <>
+LogicalResult
+Deserializer::processOp<spirv::ExecutionModeIdOp>(ArrayRef<uint32_t> words) {
+  unsigned wordIndex = 0;
+  const unsigned wordsSize = words.size();
+  if (wordIndex >= wordsSize)
+    return emitError(unknownLoc,
+                     "missing function result <id> in OpExecutionModeId");
+
+  // Get the function <id> to get the name of the function.
+  uint32_t fnID = words[wordIndex++];
+  FuncOp fn = getFunction(fnID);
+  if (!fn)
+    return emitError(unknownLoc, "no function matching <id> ") << fnID;
+
+  // Get the Execution mode.
+  if (wordIndex >= wordsSize)
+    return emitError(unknownLoc, "missing Execution Mode in OpExecutionModeId");
+
+  ExecutionModeAttr execMode = spirv::ExecutionModeAttr::get(
+      context, static_cast<spirv::ExecutionMode>(words[wordIndex++]));
+
+  // Get the values. Each remaining word is the <id> of a constant, which the
+  // op refers to by symbol; only specialization constants have symbols, so
+  // reject any other <id> instead of emitting a dangling symbol reference.
+  SmallVector<Attribute, 4> attrListElems;
+  while (wordIndex < wordsSize) {
+    uint32_t valueID = words[wordIndex++];
+    spirv::SpecConstantOp specConst = getSpecConstant(valueID);
+    if (!specConst)
+      return emitError(unknownLoc, "unsupported <id> ")
+             << valueID
+             << " in OpExecutionModeId: only OpSpecConstant operands are "
+                "supported";
+    attrListElems.push_back(SymbolRefAttr::get(specConst));
+  }
+  ArrayAttr values = opBuilder.getArrayAttr(attrListElems);
+  spirv::ExecutionModeIdOp::create(
+      opBuilder, unknownLoc,
+      SymbolRefAttr::get(opBuilder.getContext(), fn.getName()), execMode,
+      values);
+  return success();
+}
+
 template <>
 LogicalResult
 Deserializer::processOp<spirv::FunctionCallOp>(ArrayRef<uint32_t> operands) {

diff  --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index b78fac532d8c5..a2c942d4188e7 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -884,6 +884,34 @@ Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
   return success();
 }
 
+template <>
+LogicalResult
+Serializer::processOp<spirv::ExecutionModeIdOp>(spirv::ExecutionModeIdOp op) {
+  // The entry point function must already have been assigned an <id>.
+  uint32_t fnID = getFunctionID(op.getFn());
+  if (fnID == 0)
+    return op.emitError("missing <id> for function ")
+           << op.getFn()
+           << "; function needs to be serialized before ExecutionModeIdOp is "
+              "serialized";
+
+  // Operand layout: function <id>, execution mode, then one <id> per value.
+  SmallVector<uint32_t, 4> operands = {
+      fnID, static_cast<uint32_t>(op.getExecutionMode())};
+  for (Attribute refVal : op.getValues()) {
+    StringRef constName = cast<FlatSymbolRefAttr>(refVal).getValue();
+    uint32_t constID = getSpecConstID(constName);
+    if (constID == 0)
+      return op.emitError("unknown <id> for specialization constant ")
+             << constName;
+    operands.push_back(constID);
+  }
+
+  encodeInstructionInto(executionModes, spirv::Opcode::OpExecutionModeId,
+                        operands);
+  return success();
+}
+
 template <>
 LogicalResult
 Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {

diff  --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
index 7e37826795d83..e12b70cc5c139 100644
--- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
@@ -327,6 +327,99 @@ spirv.module Logical GLSL450 {
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.ExecutionModeId
+//===----------------------------------------------------------------------===//
+
+spirv.module Logical GLSL450 {
+   spirv.SpecConstant @x = 3 : i32
+   spirv.SpecConstant @y = 4 : i32
+   spirv.SpecConstant @z = 5 : i32
+   spirv.func @do_nothing() -> () "None" {
+     spirv.Return
+   }
+   spirv.EntryPoint "GLCompute" @do_nothing
+   // CHECK: spirv.ExecutionModeId {{@.*}} "LocalSizeHintId" @x, @y, @z
+   spirv.ExecutionModeId @do_nothing "LocalSizeHintId" @x, @y, @z
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+   spirv.SpecConstant @n = 2 : i32
+   spirv.func @do_nothing() -> () "None" {
+     spirv.Return
+   }
+   spirv.EntryPoint "GLCompute" @do_nothing
+   // CHECK: spirv.ExecutionModeId {{@.*}} "SubgroupsPerWorkgroupId" @n
+   spirv.ExecutionModeId @do_nothing "SubgroupsPerWorkgroupId" @n
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+   spirv.func @do_nothing() -> () "None" {
+     spirv.Return
+   }
+   spirv.EntryPoint "GLCompute" @do_nothing
+   // expected-error @+1 {{expected attribute value}}
+   spirv.ExecutionModeId @do_nothing "LocalSizeId"
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+   spirv.SpecConstant @x = 3 : i32
+   spirv.func @do_nothing() -> () "None" {
+     spirv.Return
+   }
+   spirv.EntryPoint "GLCompute" @do_nothing
+   // expected-error @+1 {{'spirv.ExecutionModeId' op expected ExecutionMode that takes extra operands that are <id> operands, got: ContractionOff}}
+   spirv.ExecutionModeId @do_nothing "ContractionOff" @x
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+   spirv.SpecConstant @x = 3 : i32
+   spirv.SpecConstant @y = 4 : i32
+   spirv.SpecConstant @z = 5 : i32
+   spirv.func @do_nothing() -> () "None" {
+     spirv.Return
+   }
+   spirv.EntryPoint "GLCompute" @do_nothing
+   // expected-error @+1 {{custom op 'spirv.ExecutionModeId' invalid execution_mode attribute specification: "GLCompute"}}
+   spirv.ExecutionModeId @do_nothing "GLCompute" @x, @y, @z
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+   spirv.SpecConstant @x = 3 : i32
+   spirv.SpecConstant @y = 4 : i32
+   spirv.func @do_nothing() -> () "None" {
+     spirv.Return
+   }
+   spirv.EntryPoint "GLCompute" @do_nothing
+   // expected-error @+1 {{custom op 'spirv.ExecutionModeId' invalid kind of attribute specified}}
+   spirv.ExecutionModeId @do_nothing "LocalSizeId" @x, @y, 2
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+   spirv.SpecConstant @x = 3 : i32
+   spirv.SpecConstant @y = 4 : i32
+   spirv.func @do_nothing() -> () "None" {
+     spirv.Return
+   }
+   spirv.EntryPoint "GLCompute" @do_nothing
+   // expected-error @+1 {{cannot find symbol referenced by value operand: w}}
+   spirv.ExecutionModeId @do_nothing "LocalSizeId" @x, @y, @w
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.func
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Target/SPIRV/execution-mode-id.mlir b/mlir/test/Target/SPIRV/execution-mode-id.mlir
new file mode 100644
index 0000000000000..f5975655fcda3
--- /dev/null
+++ b/mlir/test/Target/SPIRV/execution-mode-id.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-translate --no-implicit-module --test-spirv-roundtrip %s | FileCheck %s
+
+// RUN: %if spirv-tools %{ rm -rf %t %}
+// RUN: %if spirv-tools %{ mkdir %t %}
+// RUN: %if spirv-tools %{ mlir-translate --no-implicit-module --serialize-spirv --spirv-save-validation-files-with-prefix=%t/module %s %}
+// RUN: %if spirv-tools %{ spirv-val %t %}
+
+// Round-trips OpExecutionModeId (SPIR-V 1.2+) whose <id> operands are
+// specialization constants, and validates the serialized binary when
+// spirv-tools is available.
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.2, [Shader], []> {
+  // CHECK: spirv.SpecConstant @x = 3 : i32
+  // CHECK: spirv.SpecConstant @y = 4 : i32
+  // CHECK: spirv.SpecConstant @z = 5 : i32
+  spirv.SpecConstant @x = 3 : i32
+  spirv.SpecConstant @y = 4 : i32
+  spirv.SpecConstant @z = 5 : i32
+  spirv.func @foo() -> () "None" {
+    spirv.Return
+  }
+  spirv.EntryPoint "GLCompute" @foo
+  // CHECK: spirv.ExecutionModeId @foo "LocalSizeId" @x, @y, @z
+  spirv.ExecutionModeId @foo "LocalSizeId" @x, @y, @z
+}


        


More information about the Mlir-commits mailing list