[Mlir-commits] [mlir] fef74e1 - [mlir][spirv] add ExecutionModeIdOp (#186241)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Mar 18 03:23:38 PDT 2026
Author: Emimendoza
Date: 2026-03-18T06:23:33-04:00
New Revision: fef74e1c005d49804ea906bff39f759373681923
URL: https://github.com/llvm/llvm-project/commit/fef74e1c005d49804ea906bff39f759373681923
DIFF: https://github.com/llvm/llvm-project/commit/fef74e1c005d49804ea906bff39f759373681923.diff
LOG: [mlir][spirv] add ExecutionModeIdOp (#186241)
Adds OpExecutionModeId from SPIR-V 1.2
---------
Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
Added:
mlir/test/Target/SPIRV/execution-mode-id.mlir
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index fc9e2f11092e6..8badb84a879fa 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4583,6 +4583,7 @@ def SPIRV_OC_OpGroupUMax : I32EnumAttrCase<"OpGroupUMax", 2
def SPIRV_OC_OpGroupSMax : I32EnumAttrCase<"OpGroupSMax", 271>;
def SPIRV_OC_OpNoLine : I32EnumAttrCase<"OpNoLine", 317>;
def SPIRV_OC_OpModuleProcessed : I32EnumAttrCase<"OpModuleProcessed", 330>;
+// OpExecutionModeId (opcode 331) is available from SPIR-V 1.2 onwards.
+def SPIRV_OC_OpExecutionModeId : I32EnumAttrCase<"OpExecutionModeId", 331>;
def SPIRV_OC_OpGroupNonUniformElect : I32EnumAttrCase<"OpGroupNonUniformElect", 333>;
def SPIRV_OC_OpGroupNonUniformAll : I32EnumAttrCase<"OpGroupNonUniformAll", 334>;
def SPIRV_OC_OpGroupNonUniformAny : I32EnumAttrCase<"OpGroupNonUniformAny", 335>;
@@ -4726,6 +4727,7 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpGroupFAdd, SPIRV_OC_OpGroupFMin, SPIRV_OC_OpGroupUMin,
SPIRV_OC_OpGroupSMin, SPIRV_OC_OpGroupFMax, SPIRV_OC_OpGroupUMax,
SPIRV_OC_OpGroupSMax, SPIRV_OC_OpNoLine, SPIRV_OC_OpModuleProcessed,
+ SPIRV_OC_OpExecutionModeId,
SPIRV_OC_OpGroupNonUniformElect, SPIRV_OC_OpGroupNonUniformAll,
SPIRV_OC_OpGroupNonUniformAny, SPIRV_OC_OpGroupNonUniformAllEqual,
SPIRV_OC_OpGroupNonUniformBroadcast, SPIRV_OC_OpGroupNonUniformBroadcastFirst,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
index 9959f0bec781e..6899fc2dabc70 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
@@ -290,6 +290,71 @@ def SPIRV_ExecutionModeOp : SPIRV_Op<"ExecutionMode", [InModuleScope]> {
// -----
+// Like spirv.ExecutionMode, this op must sit directly in a spirv.module
+// (InModuleScope); the verifier resolves its symbol operands from there.
+def SPIRV_ExecutionModeIdOp : SPIRV_Op<"ExecutionModeId", [InModuleScope]> {
+  let summary = [{
+    Declare an execution mode for an entry point, using <id>s as Extra
+    Operands.
+  }];
+
+  let description = [{
+    Entry Point must be the Entry Point <id> operand of an OpEntryPoint
+    instruction.
+
+    Mode is the execution mode. See Execution Mode.
+
+    This instruction is only valid if the Mode operand is an execution mode
+    that takes Extra Operands that are <id> operands. Otherwise, use
+    OpExecutionMode.
+
+    <!-- End of AutoGen section -->
+
+    ```
+    execution-mode ::= "LocalSizeId" | "LocalSizeHintId" |
+                       "SubgroupsPerWorkgroupId"
+    execution-mode-id-op ::= `spirv.ExecutionModeId ` symbol-reference
+                             execution-mode symbol-reference
+                             (`, ` symbol-reference)*
+    ```
+
+    The first symbol reference names the entry point function. The remaining
+    symbol references name the <id> operands. Serialization requires each of
+    these to resolve to a specialization constant.
+
+    #### Example:
+
+    ```mlir
+    spirv.ExecutionModeId @foo "LocalSizeId" @var0, @var1, @var2
+    spirv.ExecutionModeId @bar "LocalSizeHintId" @x, @y, @z
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPIRV_V_1_2>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[]>,
+    Capability<[]>
+  ];
+
+  let arguments = (ins
+    FlatSymbolRefAttr:$fn,
+    SPIRV_ExecutionModeAttr:$execution_mode,
+    SymbolRefArrayAttr:$values
+  );
+
+  let results = (outs);
+
+  let hasVerifier = 1;
+
+  // Serialized by hand so the value symbols can be mapped to constant <id>s.
+  let autogenSerialization = 0;
+
+  let builders = [OpBuilder<(ins "spirv::FuncOp":$function,
+                                 "spirv::ExecutionMode":$executionMode,
+                                 "ArrayRef<Attribute>":$params),
+    [{
+      build($_builder, $_state,
+            SymbolRefAttr::get(function),
+            ExecutionModeAttr::get($_builder.getContext(), executionMode),
+            $_builder.getArrayAttr(params));
+    }]>];
+}
+
+// -----
+
def SPIRV_FuncOp : SPIRV_Op<"func", [
AutomaticAllocationScope, FunctionOpInterface,
InModuleScope, IsolatedFromAbove
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 41a0e8558ed89..cecc8c2194237 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -919,6 +919,76 @@ void spirv::ExecutionModeOp::print(OpAsmPrinter &printer) {
printer << ", " << llvm::interleaved(values.getAsValueRange<IntegerAttr>());
}
+//===----------------------------------------------------------------------===//
+// spirv.ExecutionModeId
+//===----------------------------------------------------------------------===//
+
+/// Parses `@fn "Mode" @v0, @v1, ...`.
+///
+/// The function symbol is stored under `kFnNameAttrName`. The quoted mode
+/// keyword is parsed by `parseEnumStrAttr`. The trailing non-empty list of
+/// flat symbol references becomes the `values` array attribute.
+ParseResult spirv::ExecutionModeIdOp::parse(OpAsmParser &parser,
+                                            OperationState &result) {
+  // Entry point: a symbol reference recorded under the function attr name.
+  Attribute fnAttr;
+  if (parser.parseAttribute(fnAttr, kFnNameAttrName, result.attributes))
+    return failure();
+
+  // Execution mode: a quoted enum keyword such as "LocalSizeId".
+  ExecutionMode mode;
+  if (parseEnumStrAttr<ExecutionModeAttr>(mode, parser, result))
+    return failure();
+
+  // Operands: at least one comma-separated flat symbol reference.
+  SmallVector<Attribute, 4> symbolRefs;
+  auto parseSymbolRef = [&]() -> ParseResult {
+    FlatSymbolRefAttr ref;
+    if (parser.parseAttribute(ref))
+      return failure();
+    symbolRefs.push_back(ref);
+    return success();
+  };
+  if (parser.parseCommaSeparatedList(parseSymbolRef))
+    return failure();
+
+  result.addAttribute(getValuesAttrName(result.name),
+                      parser.getBuilder().getArrayAttr(symbolRefs));
+  return success();
+}
+
+/// Prints `@fn "Mode" @v0, @v1, ...`, the form accepted by the custom parser.
+void spirv::ExecutionModeIdOp::print(OpAsmPrinter &printer) {
+  printer << " ";
+  printer.printSymbolName(getFn());
+
+  StringRef modeName = stringifyExecutionMode(getExecutionMode());
+  printer << " \"" << modeName << "\" ";
+
+  // Each value is a flat symbol reference; print them comma-separated.
+  llvm::interleave(
+      getValues().getAsValueRange<FlatSymbolRefAttr>(),
+      [&](StringRef symName) { printer.printSymbolName(symName); },
+      [&]() { printer << ", "; });
+}
+
+/// Verifies the op:
+///  - the execution mode takes <id> Extra Operands;
+///  - the number of value operands matches what that mode requires;
+///  - every value is a flat symbol reference that resolves to an existing
+///    symbol.
+LogicalResult spirv::ExecutionModeIdOp::verify() {
+  // Execution modes taking <id> Extra Operands, valid as of SPIR-V 1.6, and
+  // the number of <id>s each one requires.
+  unsigned expectedNumValues = 0;
+  switch (getExecutionMode()) {
+  case ExecutionMode::SubgroupsPerWorkgroupId:
+    // One operand: Subgroups Per Workgroup.
+    expectedNumValues = 1;
+    break;
+  case ExecutionMode::LocalSizeId:
+  case ExecutionMode::LocalSizeHintId:
+    // Three operands: the x, y and z size (or size hint).
+    expectedNumValues = 3;
+    break;
+  default:
+    return emitOpError("expected ExecutionMode that takes extra operands that "
+                       "are <id> operands, got: ")
+           << stringifyExecutionMode(getExecutionMode());
+  }
+
+  ArrayAttr values = getValues();
+  if (values.empty())
+    return emitOpError("expected at least one value operand");
+  if (values.size() != expectedNumValues)
+    return emitOpError("expected ")
+           << expectedNumValues << " value operand(s) for execution mode "
+           << stringifyExecutionMode(getExecutionMode()) << ", got "
+           << values.size();
+
+  for (Attribute value : values) {
+    auto valueSymbol = dyn_cast<FlatSymbolRefAttr>(value);
+    if (!valueSymbol)
+      return emitOpError("expected value operands to be symbol reference");
+    // Anchor the lookup at this op: the walk up to the nearest symbol table
+    // is identical to starting at the parent. Unlike getParentOp(), it never
+    // hands a null operation to the lookup when the op has no parent.
+    if (!SymbolTable::lookupNearestSymbolFrom(getOperation(), valueSymbol))
+      return emitOpError("cannot find symbol referenced by value operand: ")
+             << valueSymbol.getValue();
+  }
+
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// spirv.func
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index 5b04a14a78036..cc6302126d64a 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -147,6 +147,7 @@ LogicalResult spirv::Deserializer::processInstruction(
return processMemoryModel(operands);
case spirv::Opcode::OpEntryPoint:
case spirv::Opcode::OpExecutionMode:
+ case spirv::Opcode::OpExecutionModeId:
if (deferInstructions) {
deferredInstructions.emplace_back(opcode, operands);
return success();
@@ -453,6 +454,42 @@ Deserializer::processOp<spirv::ExecutionModeOp>(ArrayRef<uint32_t> words) {
return success();
}
+/// Deserializes OpExecutionModeId.
+///
+/// Word layout: <entry point function id>, <execution mode>, then one or
+/// more <id> Extra Operands. Each operand <id> is mapped to a symbol with
+/// `getSpecConstantSymbol`.
+template <>
+LogicalResult
+Deserializer::processOp<spirv::ExecutionModeIdOp>(ArrayRef<uint32_t> words) {
+  unsigned wordIndex = 0;
+  const unsigned wordsSize = words.size();
+  if (wordIndex >= wordsSize)
+    return emitError(unknownLoc,
+                     "missing function result <id> in OpExecutionModeId");
+
+  // Get the function <id> to get the name of the function.
+  uint32_t fnID = words[wordIndex++];
+  FuncOp fn = getFunction(fnID);
+  if (!fn)
+    return emitError(unknownLoc, "no function matching <id> ") << fnID;
+
+  // Get the Execution mode. Reject words that do not name a known mode
+  // rather than materializing an out-of-range enum value.
+  if (wordIndex >= wordsSize)
+    return emitError(unknownLoc, "missing Execution Mode in OpExecutionModeId");
+  uint32_t execModeWord = words[wordIndex++];
+  auto execModeValue = spirv::symbolizeExecutionMode(execModeWord);
+  if (!execModeValue)
+    return emitError(unknownLoc, "unknown Execution Mode ")
+           << execModeWord << " in OpExecutionModeId";
+  ExecutionModeAttr execMode =
+      spirv::ExecutionModeAttr::get(context, *execModeValue);
+
+  // Get the values: every remaining word is an <id> Extra Operand, and the
+  // instruction needs at least one.
+  if (wordIndex >= wordsSize)
+    return emitError(unknownLoc,
+                     "missing Extra Operands in OpExecutionModeId");
+  SmallVector<Attribute, 4> attrListElems;
+  attrListElems.reserve(wordsSize - wordIndex);
+  while (wordIndex < wordsSize) {
+    std::string id = getSpecConstantSymbol(words[wordIndex++]);
+    attrListElems.push_back(FlatSymbolRefAttr::get(context, id));
+  }
+  ArrayAttr values = opBuilder.getArrayAttr(attrListElems);
+  spirv::ExecutionModeIdOp::create(
+      opBuilder, unknownLoc,
+      SymbolRefAttr::get(opBuilder.getContext(), fn.getName()), execMode,
+      values);
+  return success();
+}
+
template <>
LogicalResult
Deserializer::processOp<spirv::FunctionCallOp>(ArrayRef<uint32_t> operands) {
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index b78fac532d8c5..a2c942d4188e7 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -884,6 +884,34 @@ Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
return success();
}
+/// Serializes spirv.ExecutionModeId as OpExecutionModeId into the
+/// execution-modes section.
+///
+/// The entry point function and every referenced specialization constant
+/// must already have been assigned <id>s.
+template <>
+LogicalResult
+Serializer::processOp<spirv::ExecutionModeIdOp>(spirv::ExecutionModeIdOp op) {
+  StringRef fnName = op.getFn();
+  uint32_t funcID = getFunctionID(fnName);
+  if (funcID == 0)
+    return op.emitError("missing <id> for function ")
+           << fnName
+           << "; function needs to be serialized before ExecutionModeIdOp is "
+              "serialized";
+
+  // Operand layout: <function id>, <execution mode>, <value id>...
+  SmallVector<uint32_t, 4> operands = {
+      funcID, static_cast<uint32_t>(op.getExecutionMode())};
+  for (StringRef constName :
+       op.getValues().getAsValueRange<FlatSymbolRefAttr>()) {
+    uint32_t constID = getSpecConstID(constName);
+    if (constID == 0)
+      return op.emitError("unknown <id> for specialization constant ")
+             << constName;
+    operands.push_back(constID);
+  }
+
+  encodeInstructionInto(executionModes, spirv::Opcode::OpExecutionModeId,
+                        operands);
+  return success();
+}
+
template <>
LogicalResult
Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
diff --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
index 7e37826795d83..e12b70cc5c139 100644
--- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
@@ -327,6 +327,74 @@ spirv.module Logical GLSL450 {
// -----
+//===----------------------------------------------------------------------===//
+// spirv.ExecutionModeId
+//===----------------------------------------------------------------------===//
+
+// Valid form: an <id>-taking execution mode followed by three symbol
+// references to specialization constants; the printed form is checked.
+spirv.module Logical GLSL450 {
+ spirv.SpecConstant @x = 3 : i32
+ spirv.SpecConstant @y = 4 : i32
+ spirv.SpecConstant @z = 5 : i32
+ spirv.func @do_nothing() -> () "None" {
+ spirv.Return
+ }
+ spirv.EntryPoint "GLCompute" @do_nothing
+ // CHECK: spirv.ExecutionModeId {{@.*}} "LocalSizeHintId" @x, @y, @z
+ spirv.ExecutionModeId @do_nothing "LocalSizeHintId" @x, @y, @z
+}
+
+// -----
+
+// The custom parser requires at least one value operand after the mode.
+spirv.module Logical GLSL450 {
+ spirv.func @do_nothing() -> () "None" {
+ spirv.Return
+ }
+ spirv.EntryPoint "GLCompute" @do_nothing
+ // expected-error @+1 {{expected attribute value}}
+ spirv.ExecutionModeId @do_nothing "LocalSizeId"
+}
+
+// -----
+
+// ContractionOff is a valid execution mode but takes no <id> Extra Operands,
+// so the verifier rejects it for this op.
+spirv.module Logical GLSL450 {
+ spirv.SpecConstant @x = 3 : i32
+ spirv.func @do_nothing() -> () "None" {
+ spirv.Return
+ }
+ spirv.EntryPoint "GLCompute" @do_nothing
+ // expected-error @+1 {{'spirv.ExecutionModeId' op expected ExecutionMode that takes extra operands that are <id> operands, got: ContractionOff}}
+ spirv.ExecutionModeId @do_nothing "ContractionOff" @x
+}
+
+// -----
+
+// "GLCompute" is not an execution mode name, so parsing the mode fails.
+spirv.module Logical GLSL450 {
+ spirv.SpecConstant @x = 3 : i32
+ spirv.SpecConstant @y = 4 : i32
+ spirv.SpecConstant @z = 5 : i32
+ spirv.func @do_nothing() -> () "None" {
+ spirv.Return
+ }
+ spirv.EntryPoint "GLCompute" @do_nothing
+ // expected-error @+1 {{custom op 'spirv.ExecutionModeId' invalid execution_mode attribute specification: "GLCompute"}}
+ spirv.ExecutionModeId @do_nothing "GLCompute" @x, @y, @z
+}
+
+// -----
+
+// Value operands must be symbol references; an integer literal is rejected.
+spirv.module Logical GLSL450 {
+ spirv.SpecConstant @x = 3 : i32
+ spirv.SpecConstant @y = 4 : i32
+ spirv.func @do_nothing() -> () "None" {
+ spirv.Return
+ }
+ spirv.EntryPoint "GLCompute" @do_nothing
+ // expected-error @+1 {{custom op 'spirv.ExecutionModeId' invalid kind of attribute specified}}
+ spirv.ExecutionModeId @do_nothing "LocalSizeId" @x, @y, 2
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.func
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/SPIRV/execution-mode-id.mlir b/mlir/test/Target/SPIRV/execution-mode-id.mlir
new file mode 100644
index 0000000000000..f5975655fcda3
--- /dev/null
+++ b/mlir/test/Target/SPIRV/execution-mode-id.mlir
@@ -0,0 +1,18 @@
+// RUN: mlir-translate --no-implicit-module --test-spirv-roundtrip %s | FileCheck %s
+
+// RUN: %if spirv-tools %{ rm -rf %t %}
+// RUN: %if spirv-tools %{ mkdir %t %}
+// RUN: %if spirv-tools %{ mlir-translate --no-implicit-module --serialize-spirv --spirv-save-validation-files-with-prefix=%t/module %s %}
+// RUN: %if spirv-tools %{ spirv-val %t %}
+
+// Round-trips OpExecutionModeId (SPIR-V 1.2+) with specialization-constant
+// operands. When spirv-tools is available, the serialized binary is also
+// checked with spirv-val.
+spirv.module Logical GLSL450 requires #spirv.vce<v1.2, [Shader], []> {
+ spirv.SpecConstant @x = 3 : i32
+ spirv.SpecConstant @y = 4 : i32
+ spirv.SpecConstant @z = 5 : i32
+ spirv.func @foo() -> () "None" {
+ spirv.Return
+ }
+ spirv.EntryPoint "GLCompute" @foo
+ // CHECK: spirv.ExecutionModeId @foo "LocalSizeId" @x, @y, @z
+ spirv.ExecutionModeId @foo "LocalSizeId" @x, @y, @z
+}
More information about the Mlir-commits
mailing list