[Mlir-commits] [mlir] [mlir][spirv] add ExecutionModeIdOp (PR #186241)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Mar 17 15:20:20 PDT 2026
https://github.com/Emimendoza updated https://github.com/llvm/llvm-project/pull/186241
>From a44c8b52ac1b0062f50a59fe3ef5cf741539037a Mon Sep 17 00:00:00 2001
From: Emilio M <emendoz at clemson.edu>
Date: Thu, 12 Mar 2026 17:41:40 -0400
Subject: [PATCH 1/3] [mlir][spirv] add ExecutionModeIdOp
---
.../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 2 +
.../Dialect/SPIRV/IR/SPIRVStructureOps.td | 58 ++++++++++++++++
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 68 +++++++++++++++++++
.../SPIRV/Deserialization/DeserializeOps.cpp | 36 ++++++++++
.../SPIRV/Serialization/SerializeOps.cpp | 32 +++++++++
mlir/test/Dialect/SPIRV/IR/structure-ops.mlir | 50 ++++++++++++++
mlir/test/Target/SPIRV/execution-mode-id.mlir | 18 +++++
7 files changed, 264 insertions(+)
create mode 100644 mlir/test/Target/SPIRV/execution-mode-id.mlir
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index dd5d00de4147d..1e038c5b76906 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4583,6 +4583,7 @@ def SPIRV_OC_OpGroupUMax : I32EnumAttrCase<"OpGroupUMax", 2
def SPIRV_OC_OpGroupSMax : I32EnumAttrCase<"OpGroupSMax", 271>;
def SPIRV_OC_OpNoLine : I32EnumAttrCase<"OpNoLine", 317>;
def SPIRV_OC_OpModuleProcessed : I32EnumAttrCase<"OpModuleProcessed", 330>;
+def SPIRV_OC_OpExecutionModeId : I32EnumAttrCase<"OpExecutionModeId", 331>;
def SPIRV_OC_OpGroupNonUniformElect : I32EnumAttrCase<"OpGroupNonUniformElect", 333>;
def SPIRV_OC_OpGroupNonUniformAll : I32EnumAttrCase<"OpGroupNonUniformAll", 334>;
def SPIRV_OC_OpGroupNonUniformAny : I32EnumAttrCase<"OpGroupNonUniformAny", 335>;
@@ -4725,6 +4726,7 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpGroupFAdd, SPIRV_OC_OpGroupFMin, SPIRV_OC_OpGroupUMin,
SPIRV_OC_OpGroupSMin, SPIRV_OC_OpGroupFMax, SPIRV_OC_OpGroupUMax,
SPIRV_OC_OpGroupSMax, SPIRV_OC_OpNoLine, SPIRV_OC_OpModuleProcessed,
+ SPIRV_OC_OpExecutionModeId,
SPIRV_OC_OpGroupNonUniformElect, SPIRV_OC_OpGroupNonUniformAll,
SPIRV_OC_OpGroupNonUniformAny, SPIRV_OC_OpGroupNonUniformAllEqual,
SPIRV_OC_OpGroupNonUniformBroadcast, SPIRV_OC_OpGroupNonUniformBallot,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
index 9959f0bec781e..58306e1434790 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
@@ -290,6 +290,67 @@ def SPIRV_ExecutionModeOp : SPIRV_Op<"ExecutionMode", [InModuleScope]> {
// -----
+def SPIRV_ExecutionModeIdOp : SPIRV_Op<"ExecutionModeId", [InModuleScope]> {
+ let summary = [{
+ Declare an execution mode for an entry point, using <id>s as Extra
+ Operands.
+ }];
+
+ let description = [{
+ Entry Point must be the Entry Point <id> operand of an OpEntryPoint
+ instruction.
+
+ Mode is the execution mode. See Execution Mode.
+
+ This instruction is only valid if the Mode operand is an execution mode
+ that takes Extra Operands that are <id> operands. Otherwise, use
+ OpExecutionMode.
+
+ <!-- End of AutoGen section -->
+
+ Each value operand is a symbol reference to a `spirv.SpecConstant`;
+ serialization emits the result <id> of each referenced constant.
+
+ ```
+ execution-mode-id ::= "LocalSizeId" | "LocalSizeHintId" |
+ "SubgroupsPerWorkgroupId"
+ execution-mode-id-op ::= `spirv.ExecutionModeId ` symbol-reference
+ execution-mode-id (`, ` symbol-reference)+
+ ```
+
+ #### Example:
+
+ ```mlir
+ spirv.ExecutionModeId @foo "LocalSizeId", @var0, @var1, @var2
+ spirv.ExecutionModeId @bar "LocalSizeHintId", @x, @y, @z
+ ```
+ }];
+
+ let availability = [
+ MinVersion<SPIRV_V_1_2>,
+ MaxVersion<SPIRV_V_1_6>,
+ Extension<[]>,
+ Capability<[]>
+ ];
+
+ let arguments = (ins
+ FlatSymbolRefAttr:$fn,
+ SPIRV_ExecutionModeAttr:$execution_mode,
+ SymbolRefArrayAttr:$values
+ );
+
+ let results = (outs);
+
+ let hasVerifier = 1;
+
+ let autogenSerialization = 0;
+
+ let builders = [OpBuilder<(ins "spirv::FuncOp":$function,
+ "spirv::ExecutionMode":$executionMode, "ArrayRef<Attribute>":$params)>];
+}
+
+// -----
+
def SPIRV_FuncOp : SPIRV_Op<"func", [
AutomaticAllocationScope, FunctionOpInterface,
InModuleScope, IsolatedFromAbove
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 0f039d89b8fab..cd6ef91ab59c3 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -919,6 +919,74 @@ void spirv::ExecutionModeOp::print(OpAsmPrinter &printer) {
printer << ", " << llvm::interleaved(values.getAsValueRange<IntegerAttr>());
}
+//===----------------------------------------------------------------------===//
+// spirv.ExecutionModeId
+//===----------------------------------------------------------------------===//
+
+void spirv::ExecutionModeIdOp::build(OpBuilder &builder, OperationState &state,
+ FuncOp function,
+ ExecutionMode executionMode,
+ ArrayRef<Attribute> params) {
+ build(builder, state, SymbolRefAttr::get(function),
+ ExecutionModeAttr::get(builder.getContext(), executionMode),
+ builder.getArrayAttr(params));
+}
+
+ParseResult spirv::ExecutionModeIdOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ ExecutionMode execMode;
+ if (Attribute fn;
+ parser.parseAttribute(fn, kFnNameAttrName, result.attributes) ||
+ parseEnumStrAttr<ExecutionModeAttr>(execMode, parser, result)) {
+ return failure();
+ }
+
+ SmallVector<Attribute, 4> values;
+ while (!parser.parseOptionalComma()) {
+ FlatSymbolRefAttr attr;
+ if (parser.parseAttribute(attr)) {
+ return failure();
+ }
+ values.push_back(attr);
+ }
+
+ StringRef valuesAttrName = getValuesAttrName(result.name);
+ ArrayAttr valuesAttr = parser.getBuilder().getArrayAttr(values);
+ result.addAttribute(valuesAttrName, valuesAttr);
+ return success();
+}
+
+void spirv::ExecutionModeIdOp::print(OpAsmPrinter &printer) {
+ printer << " ";
+ printer.printSymbolName(getFn());
+ printer << " \"" << stringifyExecutionMode(getExecutionMode()) << "\"";
+ for (const auto &value : getValues()) {
+ printer << ", ";
+ printer.printSymbolName(cast<FlatSymbolRefAttr>(value).getValue());
+ }
+}
+
+LogicalResult spirv::ExecutionModeIdOp::verify() {
+ // TODO: Add check to ensure that ExecutionMode is an execution mode that
+ // takes Extra Operands that are <id> operands
+ if (getValues().empty())
+ return emitOpError("expected at least one value operand");
+
+ for (const auto &value : getValues()) {
+ auto valueSymbol = dyn_cast<FlatSymbolRefAttr>(value);
+ if (!valueSymbol)
+ return emitOpError("expected value operands to be symbol reference");
+ Operation *valueOp = SymbolTable::lookupNearestSymbolFrom(
+ (*this)->getParentOp(), valueSymbol);
+ if (!valueOp) {
+ return emitOpError("cannot find symbol referenced by value operand: ")
+ << valueSymbol.getValue();
+ }
+ }
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// spirv.func
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index 5b04a14a78036..e38d3bc2c845a 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -147,6 +147,7 @@ LogicalResult spirv::Deserializer::processInstruction(
return processMemoryModel(operands);
case spirv::Opcode::OpEntryPoint:
case spirv::Opcode::OpExecutionMode:
+ case spirv::Opcode::OpExecutionModeId:
if (deferInstructions) {
deferredInstructions.emplace_back(opcode, operands);
return success();
@@ -453,6 +454,41 @@ Deserializer::processOp<spirv::ExecutionModeOp>(ArrayRef<uint32_t> words) {
return success();
}
+template <>
+LogicalResult
+Deserializer::processOp<spirv::ExecutionModeIdOp>(ArrayRef<uint32_t> words) {
+ unsigned wordIndex = 0;
+ if (wordIndex >= words.size()) {
+ return emitError(unknownLoc,
+ "missing function result <id> in OpExecutionMode");
+ }
+ // Get the function <id> to get the name of the function
+ auto fnID = words[wordIndex++];
+ auto fn = getFunction(fnID);
+ if (!fn) {
+ return emitError(unknownLoc, "no function matching <id> ") << fnID;
+ }
+ // Get the Execution mode
+ if (wordIndex >= words.size()) {
+ return emitError(unknownLoc, "missing Execution Mode in OpExecutionMode");
+ }
+ auto execMode = spirv::ExecutionModeAttr::get(
+ context, static_cast<spirv::ExecutionMode>(words[wordIndex++]));
+
+ // Get the values
+ SmallVector<Attribute, 4> attrListElems;
+ while (wordIndex < words.size()) {
+ auto id = getSpecConstantSymbol(words[wordIndex++]);
+ attrListElems.push_back(FlatSymbolRefAttr::get(context, id));
+ }
+ auto values = opBuilder.getArrayAttr(attrListElems);
+ spirv::ExecutionModeIdOp::create(
+ opBuilder, unknownLoc,
+ SymbolRefAttr::get(opBuilder.getContext(), fn.getName()), execMode,
+ values);
+ return success();
+}
+
template <>
LogicalResult
Deserializer::processOp<spirv::FunctionCallOp>(ArrayRef<uint32_t> operands) {
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index b78fac532d8c5..f6dc23b46f2b5 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -884,6 +884,38 @@ Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
return success();
}
+template <>
+LogicalResult
+Serializer::processOp<spirv::ExecutionModeIdOp>(spirv::ExecutionModeIdOp op) {
+ SmallVector<uint32_t, 4> operands;
+ // Add the function <id>.
+ auto funcID = getFunctionID(op.getFn());
+ if (!funcID) {
+ return op.emitError("missing <id> for function ")
+ << op.getFn()
+ << "; function needs to be serialized before ExecutionModeIdOp is "
+ "serialized";
+ }
+ operands.push_back(funcID);
+ // Add the ExecutionMode.
+ operands.push_back(static_cast<uint32_t>(op.getExecutionMode()));
+
+ // Serialize values if any.
+ if (const auto values = op.getValues(); values) {
+ for (auto &refVal : values.getValue()) {
+ auto id = getSpecConstID(cast<FlatSymbolRefAttr>(refVal).getValue());
+ if (!id) {
+ return op.emitError("unknown <id> for specialization constant ")
+ << cast<FlatSymbolRefAttr>(refVal).getValue();
+ }
+ operands.push_back(id);
+ }
+ }
+ encodeInstructionInto(executionModes, spirv::Opcode::OpExecutionModeId,
+ operands);
+ return success();
+}
+
template <>
LogicalResult
Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
diff --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
index 7e37826795d83..36dfe40ed7756 100644
--- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
@@ -327,6 +327,56 @@ spirv.module Logical GLSL450 {
// -----
+//===----------------------------------------------------------------------===//
+// spirv.ExecutionModeId
+//===----------------------------------------------------------------------===//
+
+spirv.module Logical GLSL450 {
+ spirv.SpecConstant @x = 3 : i32
+ spirv.SpecConstant @y = 4 : i32
+ spirv.SpecConstant @z = 5 : i32
+ spirv.func @do_nothing() -> () "None" {
+ spirv.Return
+ }
+ spirv.EntryPoint "GLCompute" @do_nothing
+ // CHECK: spirv.ExecutionModeId {{@.*}} "LocalSizeHintId", @x, @y, @z
+ spirv.ExecutionModeId @do_nothing "LocalSizeHintId", @x, @y, @z
+}
+// -----
+spirv.module Logical GLSL450 {
+ spirv.func @do_nothing() -> () "None" {
+ spirv.Return
+ }
+ spirv.EntryPoint "GLCompute" @do_nothing
+ // expected-error @+1 {{'spirv.ExecutionModeId' op expected at least one value operand}}
+ spirv.ExecutionModeId @do_nothing "ContractionOff"
+}
+// -----
+spirv.module Logical GLSL450 {
+ spirv.SpecConstant @x = 3 : i32
+ spirv.SpecConstant @y = 4 : i32
+ spirv.SpecConstant @z = 5 : i32
+ spirv.func @do_nothing() -> () "None" {
+ spirv.Return
+ }
+ spirv.EntryPoint "GLCompute" @do_nothing
+ // expected-error @+1 {{custom op 'spirv.ExecutionModeId' invalid execution_mode attribute specification: "GLCompute"}}
+ spirv.ExecutionModeId @do_nothing "GLCompute", @x, @y, @z
+}
+// -----
+spirv.module Logical GLSL450 {
+ spirv.SpecConstant @x = 3 : i32
+ spirv.SpecConstant @y = 4 : i32
+ spirv.func @do_nothing() -> () "None" {
+ spirv.Return
+ }
+ spirv.EntryPoint "GLCompute" @do_nothing
+ // expected-error @+1 {{custom op 'spirv.ExecutionModeId' invalid kind of attribute specified}}
+ spirv.ExecutionModeId @do_nothing "LocalSizeId", @x, @y, 2
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.func
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/SPIRV/execution-mode-id.mlir b/mlir/test/Target/SPIRV/execution-mode-id.mlir
new file mode 100644
index 0000000000000..dce4cf14144d4
--- /dev/null
+++ b/mlir/test/Target/SPIRV/execution-mode-id.mlir
@@ -0,0 +1,18 @@
+// RUN: mlir-translate -no-implicit-module -test-spirv-roundtrip %s | FileCheck %s
+
+// RUN: %if spirv-tools %{ rm -rf %t %}
+// RUN: %if spirv-tools %{ mkdir %t %}
+// RUN: %if spirv-tools %{ mlir-translate --no-implicit-module --serialize-spirv --split-input-file --spirv-save-validation-files-with-prefix=%t/module %s %}
+// RUN: %if spirv-tools %{ spirv-val %t %}
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.2, [Shader], []> {
+ spirv.SpecConstant @x = 3 : i32
+ spirv.SpecConstant @y = 4 : i32
+ spirv.SpecConstant @z = 5 : i32
+ spirv.func @foo() -> () "None" {
+ spirv.Return
+ }
+ spirv.EntryPoint "GLCompute" @foo
+ // CHECK: spirv.ExecutionModeId @foo "LocalSizeId", @x, @y, @z
+ spirv.ExecutionModeId @foo "LocalSizeId", @x, @y, @z
+}
>From 5a2b59ea8e3fd7a2684c423240473b89e079721f Mon Sep 17 00:00:00 2001
From: Emilio M <emendoz at clemson.edu>
Date: Sat, 14 Mar 2026 17:09:02 -0400
Subject: [PATCH 2/3] [mlir][spirv] ExecutionModeIdOp: address review feedback
---
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 27 ++++++++++++-------
.../SPIRV/Deserialization/DeserializeOps.cpp | 26 +++++++++---------
.../SPIRV/Serialization/SerializeOps.cpp | 23 +++++++---------
mlir/test/Dialect/SPIRV/IR/structure-ops.mlir | 17 ++++++++++++
4 files changed, 58 insertions(+), 35 deletions(-)
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index cd6ef91ab59c3..a43731f060164 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -959,16 +959,26 @@ ParseResult spirv::ExecutionModeIdOp::parse(OpAsmParser &parser,
void spirv::ExecutionModeIdOp::print(OpAsmPrinter &printer) {
printer << " ";
printer.printSymbolName(getFn());
- printer << " \"" << stringifyExecutionMode(getExecutionMode()) << "\"";
- for (const auto &value : getValues()) {
- printer << ", ";
- printer.printSymbolName(cast<FlatSymbolRefAttr>(value).getValue());
- }
+ printer << " \"" << stringifyExecutionMode(getExecutionMode()) << "\", ";
+
+ llvm::interleaveComma(
+ getValues().getAsValueRange<FlatSymbolRefAttr>(), printer,
+ [&](StringRef value) { printer.printSymbolName(value); });
}
LogicalResult spirv::ExecutionModeIdOp::verify() {
- // TODO: Add check to ensure that ExecutionMode is an execution mode that
- // takes Extra Operands that are <id> operands
+ // Valid as of SPIRV 1.6
+ switch (getExecutionMode()) {
+ case ExecutionMode::SubgroupsPerWorkgroupId:
+ case ExecutionMode::LocalSizeId:
+ case ExecutionMode::LocalSizeHintId:
+ break;
+ default:
+ return emitOpError("expected ExecutionMode that takes extra operands that "
+ "are <id> operands, got: ")
+ << stringifyExecutionMode(getExecutionMode());
+ }
+
if (getValues().empty())
return emitOpError("expected at least one value operand");
@@ -978,10 +988,9 @@ LogicalResult spirv::ExecutionModeIdOp::verify() {
return emitOpError("expected value operands to be symbol reference");
Operation *valueOp = SymbolTable::lookupNearestSymbolFrom(
(*this)->getParentOp(), valueSymbol);
- if (!valueOp) {
+ if (!valueOp)
return emitOpError("cannot find symbol referenced by value operand: ")
<< valueSymbol.getValue();
- }
}
return success();
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index e38d3bc2c845a..74ec3bcc5db00 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -458,30 +458,37 @@ template <>
LogicalResult
Deserializer::processOp<spirv::ExecutionModeIdOp>(ArrayRef<uint32_t> words) {
unsigned wordIndex = 0;
- if (wordIndex >= words.size()) {
+ if (wordIndex >= words.size())
return emitError(unknownLoc,
- "missing function result <id> in OpExecutionMode");
- }
+ "missing function result <id> in OpExecutionModeId");
+
// Get the function <id> to get the name of the function
- auto fnID = words[wordIndex++];
- auto fn = getFunction(fnID);
- if (!fn) {
+ uint32_t fnID = words[wordIndex++];
+ FuncOp fn = getFunction(fnID);
+ if (!fn)
return emitError(unknownLoc, "no function matching <id> ") << fnID;
- }
+
// Get the Execution mode
- if (wordIndex >= words.size()) {
- return emitError(unknownLoc, "missing Execution Mode in OpExecutionMode");
- }
- auto execMode = spirv::ExecutionModeAttr::get(
+ if (wordIndex >= words.size())
+ return emitError(unknownLoc, "missing Execution Mode in OpExecutionModeId");
+
+ ExecutionModeAttr execMode = spirv::ExecutionModeAttr::get(
context, static_cast<spirv::ExecutionMode>(words[wordIndex++]));
// Get the values
SmallVector<Attribute, 4> attrListElems;
while (wordIndex < words.size()) {
- auto id = getSpecConstantSymbol(words[wordIndex++]);
+ uint32_t valueID = words[wordIndex++];
+ // Values are modeled as symbol references, so only spec constants can be
+ // represented; reject other <id>s (e.g. OpConstant) instead of emitting a
+ // reference to a symbol that does not exist.
+ if (!getSpecConstant(valueID))
+ return emitError(unknownLoc, "OpExecutionModeId operand <id> ")
+ << valueID << " is not a specialization constant";
+ std::string id = getSpecConstantSymbol(valueID);
attrListElems.push_back(FlatSymbolRefAttr::get(context, id));
}
- auto values = opBuilder.getArrayAttr(attrListElems);
+ ArrayAttr values = opBuilder.getArrayAttr(attrListElems);
spirv::ExecutionModeIdOp::create(
opBuilder, unknownLoc,
SymbolRefAttr::get(opBuilder.getContext(), fn.getName()), execMode,
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index f6dc23b46f2b5..539193b64e3d8 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -889,27 +889,24 @@ LogicalResult
Serializer::processOp<spirv::ExecutionModeIdOp>(spirv::ExecutionModeIdOp op) {
SmallVector<uint32_t, 4> operands;
// Add the function <id>.
- auto funcID = getFunctionID(op.getFn());
- if (!funcID) {
+ uint32_t funcID = getFunctionID(op.getFn());
+ if (!funcID)
return op.emitError("missing <id> for function ")
<< op.getFn()
<< "; function needs to be serialized before ExecutionModeIdOp is "
"serialized";
- }
+
operands.push_back(funcID);
// Add the ExecutionMode.
operands.push_back(static_cast<uint32_t>(op.getExecutionMode()));
- // Serialize values if any.
- if (const auto values = op.getValues(); values) {
- for (auto &refVal : values.getValue()) {
- auto id = getSpecConstID(cast<FlatSymbolRefAttr>(refVal).getValue());
- if (!id) {
- return op.emitError("unknown <id> for specialization constant ")
- << cast<FlatSymbolRefAttr>(refVal).getValue();
- }
- operands.push_back(id);
- }
+ for (Attribute refVal : op.getValues().getValue()) {
+ uint32_t id = getSpecConstID(cast<FlatSymbolRefAttr>(refVal).getValue());
+ if (!id)
+ return op.emitError("unknown <id> for specialization constant ")
+ << cast<FlatSymbolRefAttr>(refVal).getValue();
+
+ operands.push_back(id);
}
encodeInstructionInto(executionModes, spirv::Opcode::OpExecutionModeId,
operands);
diff --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
index 36dfe40ed7756..361bc6ea20f3b 100644
--- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
@@ -342,16 +342,31 @@ spirv.module Logical GLSL450 {
// CHECK: spirv.ExecutionModeId {{@.*}} "LocalSizeHintId", @x, @y, @z
spirv.ExecutionModeId @do_nothing "LocalSizeHintId", @x, @y, @z
}
+
// -----
+
spirv.module Logical GLSL450 {
spirv.func @do_nothing() -> () "None" {
spirv.Return
}
spirv.EntryPoint "GLCompute" @do_nothing
// expected-error @+1 {{'spirv.ExecutionModeId' op expected at least one value operand}}
+ spirv.ExecutionModeId @do_nothing "LocalSizeId"
+}
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.func @do_nothing() -> () "None" {
+ spirv.Return
+ }
+ spirv.EntryPoint "GLCompute" @do_nothing
+ // expected-error @+1 {{'spirv.ExecutionModeId' op expected ExecutionMode that takes extra operands that are <id> operands, got: ContractionOff}}
spirv.ExecutionModeId @do_nothing "ContractionOff"
}
+
// -----
+
spirv.module Logical GLSL450 {
spirv.SpecConstant @x = 3 : i32
spirv.SpecConstant @y = 4 : i32
@@ -363,7 +378,9 @@ spirv.module Logical GLSL450 {
// expected-error @+1 {{custom op 'spirv.ExecutionModeId' invalid execution_mode attribute specification: "GLCompute"}}
spirv.ExecutionModeId @do_nothing "GLCompute", @x, @y, @z
}
+
// -----
+
spirv.module Logical GLSL450 {
spirv.SpecConstant @x = 3 : i32
spirv.SpecConstant @y = 4 : i32
>From 40e5038e8069f3b28bd2a238586af8d8b40c09e3 Mon Sep 17 00:00:00 2001
From: Emimendoza <emiliomendozareyes at gmail.com>
Date: Tue, 17 Mar 2026 18:20:09 -0400
Subject: [PATCH 3/3] Update mlir/test/Target/SPIRV/execution-mode-id.mlir
Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
---
mlir/test/Target/SPIRV/execution-mode-id.mlir | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/Target/SPIRV/execution-mode-id.mlir b/mlir/test/Target/SPIRV/execution-mode-id.mlir
index dce4cf14144d4..aa5660a5c1c52 100644
--- a/mlir/test/Target/SPIRV/execution-mode-id.mlir
+++ b/mlir/test/Target/SPIRV/execution-mode-id.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-translate -no-implicit-module -test-spirv-roundtrip %s | FileCheck %s
+// RUN: mlir-translate --no-implicit-module --test-spirv-roundtrip %s | FileCheck %s
// RUN: %if spirv-tools %{ rm -rf %t %}
// RUN: %if spirv-tools %{ mkdir %t %}
More information about the Mlir-commits
mailing list