[Mlir-commits] [mlir] [mlir][spirv] add ExecutionModeIdOp (PR #186241)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Mar 17 15:20:20 PDT 2026


https://github.com/Emimendoza updated https://github.com/llvm/llvm-project/pull/186241

>From a44c8b52ac1b0062f50a59fe3ef5cf741539037a Mon Sep 17 00:00:00 2001
From: Emilio M <emendoz at clemson.edu>
Date: Thu, 12 Mar 2026 17:41:40 -0400
Subject: [PATCH 1/3] [mlir][spirv] add ExecutionModeIdOp

---
 .../mlir/Dialect/SPIRV/IR/SPIRVBase.td        |  2 +
 .../Dialect/SPIRV/IR/SPIRVStructureOps.td     | 58 ++++++++++++++++
 mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp        | 68 +++++++++++++++++++
 .../SPIRV/Deserialization/DeserializeOps.cpp  | 36 ++++++++++
 .../SPIRV/Serialization/SerializeOps.cpp      | 32 +++++++++
 mlir/test/Dialect/SPIRV/IR/structure-ops.mlir | 50 ++++++++++++++
 mlir/test/Target/SPIRV/execution-mode-id.mlir | 18 +++++
 7 files changed, 264 insertions(+)
 create mode 100644 mlir/test/Target/SPIRV/execution-mode-id.mlir

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index dd5d00de4147d..1e038c5b76906 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4583,6 +4583,7 @@ def SPIRV_OC_OpGroupUMax                      : I32EnumAttrCase<"OpGroupUMax", 2
 def SPIRV_OC_OpGroupSMax                      : I32EnumAttrCase<"OpGroupSMax", 271>;
 def SPIRV_OC_OpNoLine                         : I32EnumAttrCase<"OpNoLine", 317>;
 def SPIRV_OC_OpModuleProcessed                : I32EnumAttrCase<"OpModuleProcessed", 330>;
+def SPIRV_OC_OpExecutionModeId                : I32EnumAttrCase<"OpExecutionModeId", 331>;
 def SPIRV_OC_OpGroupNonUniformElect           : I32EnumAttrCase<"OpGroupNonUniformElect", 333>;
 def SPIRV_OC_OpGroupNonUniformAll             : I32EnumAttrCase<"OpGroupNonUniformAll", 334>;
 def SPIRV_OC_OpGroupNonUniformAny             : I32EnumAttrCase<"OpGroupNonUniformAny", 335>;
@@ -4725,6 +4726,7 @@ def SPIRV_OpcodeAttr :
       SPIRV_OC_OpGroupFAdd, SPIRV_OC_OpGroupFMin, SPIRV_OC_OpGroupUMin,
       SPIRV_OC_OpGroupSMin, SPIRV_OC_OpGroupFMax, SPIRV_OC_OpGroupUMax,
       SPIRV_OC_OpGroupSMax, SPIRV_OC_OpNoLine, SPIRV_OC_OpModuleProcessed,
+      SPIRV_OC_OpExecutionModeId,
       SPIRV_OC_OpGroupNonUniformElect, SPIRV_OC_OpGroupNonUniformAll,
       SPIRV_OC_OpGroupNonUniformAny, SPIRV_OC_OpGroupNonUniformAllEqual,
       SPIRV_OC_OpGroupNonUniformBroadcast, SPIRV_OC_OpGroupNonUniformBallot,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
index 9959f0bec781e..58306e1434790 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
@@ -290,6 +290,64 @@ def SPIRV_ExecutionModeOp : SPIRV_Op<"ExecutionMode", [InModuleScope]> {
 
 // -----
 
+def SPIRV_ExecutionModeIdOp : SPIRV_Op<"ExecutionModeId", [InModuleScope]> {
+  let summary = [{
+    Declare an execution mode for an entry point, using <id>s as Extra
+    Operands.
+  }];
+
+  let description = [{
+    Entry Point must be the Entry Point <id> operand of an OpEntryPoint
+    instruction.
+
+    Mode is the execution mode. See Execution Mode.
+
+    This instruction is only valid if the Mode operand is an execution mode
+    that takes Extra Operands that are <id> operands. Otherwise, use
+    OpExecutionMode.
+
+    <!-- End of AutoGen section -->
+
+    ```
+    execution-mode ::= "LocalSizeId" | "LocalSizeHintId" |
+                       "SubgroupsPerWorkgroupId"
+    execution-mode-id-op ::= `spirv.ExecutionModeId ` symbol-reference
+                             execution-mode `, ` symbol-reference (`, ` symbol-reference)*
+    ```
+
+    #### Example:
+
+    ```mlir
+    spirv.ExecutionModeId @foo "LocalSizeId", @var0, @var1, @var2
+    spirv.ExecutionModeId @bar "LocalSizeHintId", @x, @y, @z
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPIRV_V_1_2>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[]>,
+    Capability<[]>
+  ];
+
+  let arguments = (ins
+    FlatSymbolRefAttr:$fn,
+    SPIRV_ExecutionModeAttr:$execution_mode,
+    SymbolRefArrayAttr:$values
+  );
+
+  let results = (outs);
+
+  let hasVerifier = 1;
+
+  let autogenSerialization = 0;
+
+  let builders = [OpBuilder<(ins "spirv::FuncOp":$function,
+      "spirv::ExecutionMode":$executionMode, "ArrayRef<Attribute>":$params)>];
+}
+
+// -----
+
 def SPIRV_FuncOp : SPIRV_Op<"func", [
     AutomaticAllocationScope, FunctionOpInterface,
     InModuleScope, IsolatedFromAbove
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 0f039d89b8fab..cd6ef91ab59c3 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -919,6 +919,74 @@ void spirv::ExecutionModeOp::print(OpAsmPrinter &printer) {
     printer << ", " << llvm::interleaved(values.getAsValueRange<IntegerAttr>());
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.ExecutionModeId
+//===----------------------------------------------------------------------===//
+
+void spirv::ExecutionModeIdOp::build(OpBuilder &builder, OperationState &state,
+                                     FuncOp function,
+                                     ExecutionMode executionMode,
+                                     ArrayRef<Attribute> params) {
+  build(builder, state, SymbolRefAttr::get(function),
+        ExecutionModeAttr::get(builder.getContext(), executionMode),
+        builder.getArrayAttr(params));
+}
+
+ParseResult spirv::ExecutionModeIdOp::parse(OpAsmParser &parser,
+                                            OperationState &result) {
+  ExecutionMode execMode;
+  if (Attribute fn;
+      parser.parseAttribute(fn, kFnNameAttrName, result.attributes) ||
+      parseEnumStrAttr<ExecutionModeAttr>(execMode, parser, result)) {
+    return failure();
+  }
+
+  SmallVector<Attribute, 4> values;
+  while (!parser.parseOptionalComma()) {
+    FlatSymbolRefAttr attr;
+    if (parser.parseAttribute(attr)) {
+      return failure();
+    }
+    values.push_back(attr);
+  }
+
+  StringRef valuesAttrName = getValuesAttrName(result.name);
+  ArrayAttr valuesAttr = parser.getBuilder().getArrayAttr(values);
+  result.addAttribute(valuesAttrName, valuesAttr);
+  return success();
+}
+
+void spirv::ExecutionModeIdOp::print(OpAsmPrinter &printer) {
+  printer << " ";
+  printer.printSymbolName(getFn());
+  printer << " \"" << stringifyExecutionMode(getExecutionMode()) << "\"";
+  for (const auto &value : getValues()) {
+    printer << ", ";
+    printer.printSymbolName(cast<FlatSymbolRefAttr>(value).getValue());
+  }
+}
+
+LogicalResult spirv::ExecutionModeIdOp::verify() {
+  // TODO: Add check to ensure that ExecutionMode is an execution mode that
+  //       takes Extra Operands that are <id> operands
+  if (getValues().empty())
+    return emitOpError("expected at least one value operand");
+
+  for (const auto &value : getValues()) {
+    auto valueSymbol = dyn_cast<FlatSymbolRefAttr>(value);
+    if (!valueSymbol)
+      return emitOpError("expected value operands to be symbol reference");
+    Operation *valueOp = SymbolTable::lookupNearestSymbolFrom(
+        (*this)->getParentOp(), valueSymbol);
+    if (!valueOp) {
+      return emitOpError("cannot find symbol referenced by value operand: ")
+             << valueSymbol.getValue();
+    }
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.func
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index 5b04a14a78036..e38d3bc2c845a 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -147,6 +147,7 @@ LogicalResult spirv::Deserializer::processInstruction(
     return processMemoryModel(operands);
   case spirv::Opcode::OpEntryPoint:
   case spirv::Opcode::OpExecutionMode:
+  case spirv::Opcode::OpExecutionModeId:
     if (deferInstructions) {
       deferredInstructions.emplace_back(opcode, operands);
       return success();
@@ -453,6 +454,41 @@ Deserializer::processOp<spirv::ExecutionModeOp>(ArrayRef<uint32_t> words) {
   return success();
 }
 
+template <>
+LogicalResult
+Deserializer::processOp<spirv::ExecutionModeIdOp>(ArrayRef<uint32_t> words) {
+  unsigned wordIndex = 0;
+  if (wordIndex >= words.size()) {
+    return emitError(unknownLoc,
+                     "missing function result <id> in OpExecutionMode");
+  }
+  // Get the function <id> to get the name of the function
+  auto fnID = words[wordIndex++];
+  auto fn = getFunction(fnID);
+  if (!fn) {
+    return emitError(unknownLoc, "no function matching <id> ") << fnID;
+  }
+  // Get the Execution mode
+  if (wordIndex >= words.size()) {
+    return emitError(unknownLoc, "missing Execution Mode in OpExecutionMode");
+  }
+  auto execMode = spirv::ExecutionModeAttr::get(
+      context, static_cast<spirv::ExecutionMode>(words[wordIndex++]));
+
+  // Get the values
+  SmallVector<Attribute, 4> attrListElems;
+  while (wordIndex < words.size()) {
+    auto id = getSpecConstantSymbol(words[wordIndex++]);
+    attrListElems.push_back(FlatSymbolRefAttr::get(context, id));
+  }
+  auto values = opBuilder.getArrayAttr(attrListElems);
+  spirv::ExecutionModeIdOp::create(
+      opBuilder, unknownLoc,
+      SymbolRefAttr::get(opBuilder.getContext(), fn.getName()), execMode,
+      values);
+  return success();
+}
+
 template <>
 LogicalResult
 Deserializer::processOp<spirv::FunctionCallOp>(ArrayRef<uint32_t> operands) {
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index b78fac532d8c5..f6dc23b46f2b5 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -884,6 +884,38 @@ Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
   return success();
 }
 
+template <>
+LogicalResult
+Serializer::processOp<spirv::ExecutionModeIdOp>(spirv::ExecutionModeIdOp op) {
+  SmallVector<uint32_t, 4> operands;
+  // Add the function <id>.
+  auto funcID = getFunctionID(op.getFn());
+  if (!funcID) {
+    return op.emitError("missing <id> for function ")
+           << op.getFn()
+           << "; function needs to be serialized before ExecutionModeIdOp is "
+              "serialized";
+  }
+  operands.push_back(funcID);
+  // Add the ExecutionMode.
+  operands.push_back(static_cast<uint32_t>(op.getExecutionMode()));
+
+  // Serialize values if any.
+  if (const auto values = op.getValues(); values) {
+    for (auto &refVal : values.getValue()) {
+      auto id = getSpecConstID(cast<FlatSymbolRefAttr>(refVal).getValue());
+      if (!id) {
+        return op.emitError("unknown <id> for specialization constant ")
+               << cast<FlatSymbolRefAttr>(refVal).getValue();
+      }
+      operands.push_back(id);
+    }
+  }
+  encodeInstructionInto(executionModes, spirv::Opcode::OpExecutionModeId,
+                        operands);
+  return success();
+}
+
 template <>
 LogicalResult
 Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
diff --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
index 7e37826795d83..36dfe40ed7756 100644
--- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
@@ -327,6 +327,56 @@ spirv.module Logical GLSL450 {
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.ExecutionModeId
+//===----------------------------------------------------------------------===//
+
+spirv.module Logical GLSL450 {
+   spirv.SpecConstant @x = 3 : i32
+   spirv.SpecConstant @y = 4 : i32
+   spirv.SpecConstant @z = 5 : i32
+   spirv.func @do_nothing() -> () "None" {
+     spirv.Return
+   }
+   spirv.EntryPoint "GLCompute" @do_nothing
+   // CHECK: spirv.ExecutionModeId {{@.*}} "LocalSizeHintId", @x, @y, @z
+   spirv.ExecutionModeId @do_nothing "LocalSizeHintId", @x, @y, @z
+}
+// -----
+spirv.module Logical GLSL450 {
+   spirv.func @do_nothing() -> () "None" {
+     spirv.Return
+   }
+   spirv.EntryPoint "GLCompute" @do_nothing
+   // expected-error @+1 {{'spirv.ExecutionModeId' op expected at least one value operand}}
+   spirv.ExecutionModeId @do_nothing "ContractionOff"
+}
+// -----
+spirv.module Logical GLSL450 {
+   spirv.SpecConstant @x = 3 : i32
+   spirv.SpecConstant @y = 4 : i32
+   spirv.SpecConstant @z = 5 : i32
+   spirv.func @do_nothing() -> () "None" {
+     spirv.Return
+   }
+   spirv.EntryPoint "GLCompute" @do_nothing
+   // expected-error @+1 {{custom op 'spirv.ExecutionModeId' invalid execution_mode attribute specification: "GLCompute"}}
+   spirv.ExecutionModeId @do_nothing "GLCompute", @x, @y, @z
+}
+// -----
+spirv.module Logical GLSL450 {
+   spirv.SpecConstant @x = 3 : i32
+   spirv.SpecConstant @y = 4 : i32
+   spirv.func @do_nothing() -> () "None" {
+     spirv.Return
+   }
+   spirv.EntryPoint "GLCompute" @do_nothing
+   // expected-error @+1 {{custom op 'spirv.ExecutionModeId' invalid kind of attribute specified}}
+   spirv.ExecutionModeId @do_nothing "LocalSizeId", @x, @y, 2
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.func
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/SPIRV/execution-mode-id.mlir b/mlir/test/Target/SPIRV/execution-mode-id.mlir
new file mode 100644
index 0000000000000..dce4cf14144d4
--- /dev/null
+++ b/mlir/test/Target/SPIRV/execution-mode-id.mlir
@@ -0,0 +1,18 @@
+// RUN: mlir-translate -no-implicit-module -test-spirv-roundtrip %s | FileCheck %s
+
+// RUN: %if spirv-tools %{ rm -rf %t %}
+// RUN: %if spirv-tools %{ mkdir %t %}
+// RUN: %if spirv-tools %{ mlir-translate --no-implicit-module --serialize-spirv --split-input-file --spirv-save-validation-files-with-prefix=%t/module %s %}
+// RUN: %if spirv-tools %{ spirv-val %t %}
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.2, [Shader], []> {
+  spirv.SpecConstant @x = 3 : i32
+  spirv.SpecConstant @y = 4 : i32
+  spirv.SpecConstant @z = 5 : i32
+  spirv.func @foo() -> () "None" {
+    spirv.Return
+  }
+  spirv.EntryPoint "GLCompute" @foo
+  // CHECK: spirv.ExecutionModeId @foo "LocalSizeId", @x, @y, @z
+  spirv.ExecutionModeId @foo "LocalSizeId", @x, @y, @z
+}

>From 5a2b59ea8e3fd7a2684c423240473b89e079721f Mon Sep 17 00:00:00 2001
From: Emilio M <emendoz at clemson.edu>
Date: Sat, 14 Mar 2026 17:09:02 -0400
Subject: [PATCH 2/3] requested changes from review

---
 mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp        | 27 ++++++++++++-------
 .../SPIRV/Deserialization/DeserializeOps.cpp  | 26 +++++++++---------
 .../SPIRV/Serialization/SerializeOps.cpp      | 23 +++++++---------
 mlir/test/Dialect/SPIRV/IR/structure-ops.mlir | 17 ++++++++++++
 4 files changed, 58 insertions(+), 35 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index cd6ef91ab59c3..a43731f060164 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -959,16 +959,26 @@ ParseResult spirv::ExecutionModeIdOp::parse(OpAsmParser &parser,
 void spirv::ExecutionModeIdOp::print(OpAsmPrinter &printer) {
   printer << " ";
   printer.printSymbolName(getFn());
-  printer << " \"" << stringifyExecutionMode(getExecutionMode()) << "\"";
-  for (const auto &value : getValues()) {
-    printer << ", ";
-    printer.printSymbolName(cast<FlatSymbolRefAttr>(value).getValue());
-  }
+  printer << " \"" << stringifyExecutionMode(getExecutionMode()) << "\", ";
+
+  llvm::interleaveComma(
+      getValues().getAsValueRange<FlatSymbolRefAttr>(), printer,
+      [&](StringRef value) { printer.printSymbolName(value); });
 }
 
 LogicalResult spirv::ExecutionModeIdOp::verify() {
-  // TODO: Add check to ensure that ExecutionMode is an execution mode that
-  //       takes Extra Operands that are <id> operands
+  // Core SPIR-V 1.6 execution modes that take <id> Extra Operands.
+  switch (getExecutionMode()) {
+  case ExecutionMode::SubgroupsPerWorkgroupId:
+  case ExecutionMode::LocalSizeId:
+  case ExecutionMode::LocalSizeHintId:
+    break;
+  default:
+    return emitOpError("expected ExecutionMode that takes extra operands that "
+                       "are <id> operands, got: ")
+           << stringifyExecutionMode(getExecutionMode());
+  }
+
   if (getValues().empty())
     return emitOpError("expected at least one value operand");
 
@@ -978,10 +988,9 @@ LogicalResult spirv::ExecutionModeIdOp::verify() {
       return emitOpError("expected value operands to be symbol reference");
     Operation *valueOp = SymbolTable::lookupNearestSymbolFrom(
         (*this)->getParentOp(), valueSymbol);
-    if (!valueOp) {
+    if (!valueOp)
       return emitOpError("cannot find symbol referenced by value operand: ")
              << valueSymbol.getValue();
-    }
   }
 
   return success();
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index e38d3bc2c845a..74ec3bcc5db00 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -458,30 +458,30 @@ template <>
 LogicalResult
 Deserializer::processOp<spirv::ExecutionModeIdOp>(ArrayRef<uint32_t> words) {
   unsigned wordIndex = 0;
-  if (wordIndex >= words.size()) {
+  if (wordIndex >= words.size())
     return emitError(unknownLoc,
-                     "missing function result <id> in OpExecutionMode");
-  }
+                     "missing function result <id> in OpExecutionModeId");
+
   // Get the function <id> to get the name of the function
-  auto fnID = words[wordIndex++];
-  auto fn = getFunction(fnID);
-  if (!fn) {
+  uint32_t fnID = words[wordIndex++];
+  FuncOp fn = getFunction(fnID);
+  if (!fn)
     return emitError(unknownLoc, "no function matching <id> ") << fnID;
-  }
+
   // Get the Execution mode
-  if (wordIndex >= words.size()) {
-    return emitError(unknownLoc, "missing Execution Mode in OpExecutionMode");
-  }
-  auto execMode = spirv::ExecutionModeAttr::get(
+  if (wordIndex >= words.size())
+    return emitError(unknownLoc, "missing Execution Mode in OpExecutionModeId");
+
+  ExecutionModeAttr execMode = spirv::ExecutionModeAttr::get(
       context, static_cast<spirv::ExecutionMode>(words[wordIndex++]));
 
   // Get the values
   SmallVector<Attribute, 4> attrListElems;
   while (wordIndex < words.size()) {
-    auto id = getSpecConstantSymbol(words[wordIndex++]);
+    std::string id = getSpecConstantSymbol(words[wordIndex++]);
     attrListElems.push_back(FlatSymbolRefAttr::get(context, id));
   }
-  auto values = opBuilder.getArrayAttr(attrListElems);
+  ArrayAttr values = opBuilder.getArrayAttr(attrListElems);
   spirv::ExecutionModeIdOp::create(
       opBuilder, unknownLoc,
       SymbolRefAttr::get(opBuilder.getContext(), fn.getName()), execMode,
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index f6dc23b46f2b5..539193b64e3d8 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -889,27 +889,24 @@ LogicalResult
 Serializer::processOp<spirv::ExecutionModeIdOp>(spirv::ExecutionModeIdOp op) {
   SmallVector<uint32_t, 4> operands;
   // Add the function <id>.
-  auto funcID = getFunctionID(op.getFn());
-  if (!funcID) {
+  uint32_t funcID = getFunctionID(op.getFn());
+  if (!funcID)
     return op.emitError("missing <id> for function ")
            << op.getFn()
            << "; function needs to be serialized before ExecutionModeIdOp is "
               "serialized";
-  }
+
   operands.push_back(funcID);
   // Add the ExecutionMode.
   operands.push_back(static_cast<uint32_t>(op.getExecutionMode()));
 
-  // Serialize values if any.
-  if (const auto values = op.getValues(); values) {
-    for (auto &refVal : values.getValue()) {
-      auto id = getSpecConstID(cast<FlatSymbolRefAttr>(refVal).getValue());
-      if (!id) {
-        return op.emitError("unknown <id> for specialization constant ")
-               << cast<FlatSymbolRefAttr>(refVal).getValue();
-      }
-      operands.push_back(id);
-    }
+  for (Attribute refVal : op.getValues().getValue()) {
+    uint32_t id = getSpecConstID(cast<FlatSymbolRefAttr>(refVal).getValue());
+    if (!id)
+      return op.emitError("unknown <id> for specialization constant ")
+             << cast<FlatSymbolRefAttr>(refVal).getValue();
+
+    operands.push_back(id);
   }
   encodeInstructionInto(executionModes, spirv::Opcode::OpExecutionModeId,
                         operands);
diff --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
index 36dfe40ed7756..361bc6ea20f3b 100644
--- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
@@ -342,16 +342,31 @@ spirv.module Logical GLSL450 {
    // CHECK: spirv.ExecutionModeId {{@.*}} "LocalSizeHintId", @x, @y, @z
    spirv.ExecutionModeId @do_nothing "LocalSizeHintId", @x, @y, @z
 }
+
 // -----
+
 spirv.module Logical GLSL450 {
    spirv.func @do_nothing() -> () "None" {
      spirv.Return
    }
    spirv.EntryPoint "GLCompute" @do_nothing
    // expected-error @+1 {{'spirv.ExecutionModeId' op expected at least one value operand}}
+   spirv.ExecutionModeId @do_nothing "LocalSizeId"
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+   spirv.func @do_nothing() -> () "None" {
+     spirv.Return
+   }
+   spirv.EntryPoint "GLCompute" @do_nothing
+   // expected-error @+1 {{'spirv.ExecutionModeId' op expected ExecutionMode that takes extra operands that are <id> operands, got: ContractionOff}}
    spirv.ExecutionModeId @do_nothing "ContractionOff"
 }
+
 // -----
+
 spirv.module Logical GLSL450 {
    spirv.SpecConstant @x = 3 : i32
    spirv.SpecConstant @y = 4 : i32
@@ -363,7 +378,9 @@ spirv.module Logical GLSL450 {
    // expected-error @+1 {{custom op 'spirv.ExecutionModeId' invalid execution_mode attribute specification: "GLCompute"}}
    spirv.ExecutionModeId @do_nothing "GLCompute", @x, @y, @z
 }
+
 // -----
+
 spirv.module Logical GLSL450 {
    spirv.SpecConstant @x = 3 : i32
    spirv.SpecConstant @y = 4 : i32

>From 40e5038e8069f3b28bd2a238586af8d8b40c09e3 Mon Sep 17 00:00:00 2001
From: Emimendoza <emiliomendozareyes at gmail.com>
Date: Tue, 17 Mar 2026 18:20:09 -0400
Subject: [PATCH 3/3] Update mlir/test/Target/SPIRV/execution-mode-id.mlir

Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
---
 mlir/test/Target/SPIRV/execution-mode-id.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Target/SPIRV/execution-mode-id.mlir b/mlir/test/Target/SPIRV/execution-mode-id.mlir
index dce4cf14144d4..aa5660a5c1c52 100644
--- a/mlir/test/Target/SPIRV/execution-mode-id.mlir
+++ b/mlir/test/Target/SPIRV/execution-mode-id.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-translate -no-implicit-module -test-spirv-roundtrip %s | FileCheck %s
+// RUN: mlir-translate --no-implicit-module --test-spirv-roundtrip %s | FileCheck %s
 
 // RUN: %if spirv-tools %{ rm -rf %t %}
 // RUN: %if spirv-tools %{ mkdir %t %}



More information about the Mlir-commits mailing list