[Mlir-commits] [mlir] [mlir][SPIR-V] Add OpDecorateId support for Id-form decorations (PR #194611)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Apr 28 06:04:30 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Arseniy Obolenskiy (aobolensk)
<details>
<summary>Changes</summary>
Support serialization and deserialization of the standard Id-form decorations (AlignmentId, MaxByteOffsetId and CounterBuffer), which are encoded with OpDecorateId.
---
Full diff: https://github.com/llvm/llvm-project/pull/194611.diff
7 Files Affected:
- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td (+2-1)
- (modified) mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp (+1)
- (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (+60)
- (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.h (+15)
- (modified) mlir/lib/Target/SPIRV/Serialization/Serializer.cpp (+29)
- (modified) mlir/lib/Target/SPIRV/Serialization/Serializer.h (+5)
- (added) mlir/test/Target/SPIRV/decorations-id.mlir (+32)
``````````diff
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 733150d11c8b1..a2e1274cac3a5 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4607,6 +4607,7 @@ def SPIRV_OC_OpGroupSMax : I32EnumAttrCase<"OpGroupSMax", 2
def SPIRV_OC_OpNoLine : I32EnumAttrCase<"OpNoLine", 317>;
def SPIRV_OC_OpModuleProcessed : I32EnumAttrCase<"OpModuleProcessed", 330>;
def SPIRV_OC_OpExecutionModeId : I32EnumAttrCase<"OpExecutionModeId", 331>;
+def SPIRV_OC_OpDecorateId : I32EnumAttrCase<"OpDecorateId", 332>;
def SPIRV_OC_OpGroupNonUniformElect : I32EnumAttrCase<"OpGroupNonUniformElect", 333>;
def SPIRV_OC_OpGroupNonUniformAll : I32EnumAttrCase<"OpGroupNonUniformAll", 334>;
def SPIRV_OC_OpGroupNonUniformAny : I32EnumAttrCase<"OpGroupNonUniformAny", 335>;
@@ -4755,7 +4756,7 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpGroupFAdd, SPIRV_OC_OpGroupFMin, SPIRV_OC_OpGroupUMin,
SPIRV_OC_OpGroupSMin, SPIRV_OC_OpGroupFMax, SPIRV_OC_OpGroupUMax,
SPIRV_OC_OpGroupSMax, SPIRV_OC_OpNoLine, SPIRV_OC_OpModuleProcessed,
- SPIRV_OC_OpExecutionModeId,
+ SPIRV_OC_OpExecutionModeId, SPIRV_OC_OpDecorateId,
SPIRV_OC_OpGroupNonUniformElect, SPIRV_OC_OpGroupNonUniformAll,
SPIRV_OC_OpGroupNonUniformAny, SPIRV_OC_OpGroupNonUniformAllEqual,
SPIRV_OC_OpGroupNonUniformBroadcast, SPIRV_OC_OpGroupNonUniformBroadcastFirst,
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index 0faa5f0f29d7d..c12647d4255aa 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -221,6 +221,7 @@ LogicalResult spirv::Deserializer::processInstruction(
case spirv::Opcode::OpGraphConstantARM:
return processGraphConstantARM(operands);
case spirv::Opcode::OpDecorate:
+ case spirv::Opcode::OpDecorateId:
return processDecoration(operands);
case spirv::Opcode::OpMemberDecorate:
return processMemberDecoration(operands);
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index f4481305fd552..3a98c27727e31 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -92,6 +92,9 @@ LogicalResult spirv::Deserializer::deserialize() {
}
}
+ if (failed(resolveDeferredIdDecorations()))
+ return failure();
+
attachVCETriple();
LLVM_DEBUG(logger.startLine()
@@ -377,12 +380,69 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
return res;
break;
}
+ case spirv::Decoration::AlignmentId:
+ case spirv::Decoration::MaxByteOffsetId:
+ case spirv::Decoration::CounterBuffer:
+ if (words.size() != 3) {
+ return emitError(unknownLoc, "OpDecorateId with ")
+ << decorationName << " needs a single <id> operand";
+ }
+ pendingIdDecorations.push_back(
+ {words[0], static_cast<spirv::Decoration>(words[1]), words[2],
+ unknownLoc});
+ break;
default:
return emitError(unknownLoc, "unhandled Decoration : '") << decorationName;
}
return success();
}
+LogicalResult spirv::Deserializer::resolveDeferredIdDecorations() {
+  for (const DeferredIdDecoration &entry : pendingIdDecorations) {
+    auto decorationName = stringifyDecoration(entry.decoration);
+    StringAttr symbol = getSymbolDecoration(decorationName);
+
+    // Resolve the operand <id> to a symbol name. The operand must reference a
+    // module-scope symbol op (global variable or specialization constant).
+    StringRef operandSymName;
+    if (auto varOp = globalVariableMap.lookup(entry.operandID))
+      operandSymName = varOp.getSymName();
+    else if (auto specOp = specConstMap.lookup(entry.operandID))
+      operandSymName = specOp.getSymName();
+
+    if (operandSymName.empty()) {
+      return emitError(entry.loc, "OpDecorateId with ")
+             << decorationName << " references <id> " << entry.operandID
+             << " which is not a global variable or specialization constant";
+    }
+
+    auto symRef = FlatSymbolRefAttr::get(context, operandSymName);
+
+    // This pass runs at the end of deserialize(), after all instructions have
+    // been processed, so every op has already consumed its `decorations`
+    // entry. Attach the attribute directly to the target op. An entry added to
+    // `decorations` at this point would never be read, so an unresolvable
+    // target is diagnosed instead of being silently dropped.
+    Operation *targetOp = nullptr;
+    if (auto varOp = globalVariableMap.lookup(entry.targetID))
+      targetOp = varOp;
+    else if (auto specOp = specConstMap.lookup(entry.targetID))
+      targetOp = specOp;
+    else if (auto fnOp = funcMap.lookup(entry.targetID))
+      targetOp = fnOp;
+    else if (Value v = valueMap.lookup(entry.targetID))
+      targetOp = v.getDefiningOp();
+
+    if (!targetOp) {
+      return emitError(entry.loc, "OpDecorateId with ")
+             << decorationName << " targets <id> " << entry.targetID
+             << " which does not correspond to a deserialized operation";
+    }
+    targetOp->setAttr(symbol, symRef);
+  }
+  pendingIdDecorations.clear();
+  return success();
+}
+
LogicalResult
spirv::Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) {
// The binary layout of OpMemberDecorate is different comparing to OpDecorate
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index e0743503acc5b..9725c63deb8c2 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -177,6 +177,11 @@ class Deserializer {
/// Processes an OpDecorate instruction.
LogicalResult processDecoration(ArrayRef<uint32_t> words);
+ /// Resolves all OpDecorateId entries previously queued during
+ /// processDecoration. Called after all module ops have been deserialized so
+ /// the operand <id>s can be looked up as MLIR symbols.
+ LogicalResult resolveDeferredIdDecorations();
+
// Processes an OpMemberDecorate instruction.
LogicalResult processMemberDecoration(ArrayRef<uint32_t> words);
@@ -682,6 +687,16 @@ class Deserializer {
// Result <id> to decorations mapping.
DenseMap<uint32_t, NamedAttrList> decorations;
+  // Decoration entries from OpDecorateId whose operand <id>s must be resolved
+  // to MLIR symbols after all module ops have been deserialized.
+  // Field order matters: entries are brace-initialized in processDecoration.
+  struct DeferredIdDecoration {
+    // Result <id> of the decorated instruction (OpDecorateId word 0).
+    uint32_t targetID;
+    // The Id-form decoration being applied (OpDecorateId word 1).
+    spirv::Decoration decoration;
+    // The single <id> operand of the decoration (OpDecorateId word 2).
+    uint32_t operandID;
+    // Location reported if the operand cannot be resolved later.
+    Location loc;
+  };
+  SmallVector<DeferredIdDecoration> pendingIdDecorations;
+
// Result <id> to type decorations.
DenseMap<uint32_t, uint32_t> typeDecorations;
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index aaa80470f40e1..e29a437cca87f 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -405,6 +405,23 @@ LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
resultID, decoration,
{cacheLevel, static_cast<uint32_t>(storeCacheControl)});
});
+ case spirv::Decoration::AlignmentId:
+ case spirv::Decoration::MaxByteOffsetId:
+ case spirv::Decoration::CounterBuffer: {
+ auto symRef = dyn_cast<FlatSymbolRefAttr>(attr);
+ if (!symRef)
+ return emitError(loc, "expected symbol reference for ")
+ << stringifyDecoration(decoration);
+ StringRef symName = symRef.getValue();
+ uint32_t operandID = getVariableID(symName);
+ if (!operandID)
+ operandID = getSpecConstID(symName);
+ if (!operandID)
+ return emitError(loc, "could not find <id> for symbol '")
+ << symName << "' referenced by "
+ << stringifyDecoration(decoration);
+ return emitDecorationId(resultID, decoration, {operandID});
+ }
default:
return emitError(loc, "unhandled decoration ")
<< stringifyDecoration(decoration);
@@ -1677,6 +1694,18 @@ LogicalResult Serializer::emitDecoration(uint32_t target,
return success();
}
+LogicalResult Serializer::emitDecorationId(uint32_t target,
+                                           spirv::Decoration decoration,
+                                           ArrayRef<uint32_t> operandIds) {
+  // OpDecorateId must carry at least one <id> operand; every Id-form case in
+  // processDecorationAttr passes exactly one. An empty list would produce a
+  // malformed instruction, so catch that misuse in debug builds.
+  assert(!operandIds.empty() && "OpDecorateId requires at least one <id>");
+  // One word each for the opcode/word-count, the target and the decoration,
+  // followed by the operand <id>s.
+  uint32_t wordCount = 3 + operandIds.size();
+  llvm::append_values(
+      decorations,
+      spirv::getPrefixedOpcode(wordCount, spirv::Opcode::OpDecorateId), target,
+      static_cast<uint32_t>(decoration));
+  llvm::append_range(decorations, operandIds);
+  return success();
+}
+
LogicalResult Serializer::emitDebugLine(SmallVectorImpl<uint32_t> &binary,
Location loc) {
if (!options.emitDebugInfo)
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.h b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
index 6e79c133eb6af..eb5ac0d60038e 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.h
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
@@ -355,6 +355,11 @@ class Serializer {
LogicalResult emitDecoration(uint32_t target, spirv::Decoration decoration,
ArrayRef<uint32_t> params = {});
+ /// Emits an OpDecorateId instruction to decorate the given `target` with the
+ /// given `decoration` whose extra operands are SPIR-V <id>s.
+ LogicalResult emitDecorationId(uint32_t target, spirv::Decoration decoration,
+ ArrayRef<uint32_t> operandIds);
+
/// Emits an OpLine instruction with the given `loc` location information into
/// the given `binary` vector.
LogicalResult emitDebugLine(SmallVectorImpl<uint32_t> &binary, Location loc);
diff --git a/mlir/test/Target/SPIRV/decorations-id.mlir b/mlir/test/Target/SPIRV/decorations-id.mlir
new file mode 100644
index 0000000000000..dfd83319032d9
--- /dev/null
+++ b/mlir/test/Target/SPIRV/decorations-id.mlir
@@ -0,0 +1,32 @@
+// RUN: mlir-translate --no-implicit-module --split-input-file --test-spirv-roundtrip %s | FileCheck %s
+
+// Round-trip tests for decorations whose operand is a SPIR-V <id>
+// (serialized as OpDecorateId, opcode 332). OpDecorateId is missing before
+// SPIR-V 1.2, so every module below declares at least that version.
+
+// AlignmentId references a specialization constant that supplies the alignment.
+spirv.module Logical OpenCL requires #spirv.vce<v1.2, [Kernel, Addresses, Linkage], []> {
+  // CHECK: spirv.SpecConstant @sc_align = 16
+  // CHECK: alignment_id = @sc_align
+  spirv.SpecConstant @sc_align = 16 : i32
+  spirv.GlobalVariable @var {alignment_id = @sc_align} : !spirv.ptr<f32, CrossWorkgroup>
+}
+
+// -----
+
+// MaxByteOffsetId references a specialization constant.
+spirv.module Logical OpenCL requires #spirv.vce<v1.2, [Kernel, Addresses, Linkage], []> {
+  // CHECK: spirv.SpecConstant @sc_offset = 1024
+  // CHECK: max_byte_offset_id = @sc_offset
+  spirv.SpecConstant @sc_offset = 1024 : i32
+  spirv.GlobalVariable @var {max_byte_offset_id = @sc_offset} : !spirv.ptr<f32, CrossWorkgroup>
+}
+
+// -----
+
+// CounterBuffer references another global variable. CounterBuffer (formerly
+// HlslCounterBufferGOOGLE) and the StorageBuffer storage class are core only
+// in newer SPIR-V versions, so this module declares v1.4.
+spirv.module Logical GLSL450 requires #spirv.vce<v1.4, [Shader, Linkage], []> {
+  // CHECK: spirv.GlobalVariable @counter
+  // CHECK: counter_buffer = @counter
+  spirv.GlobalVariable @counter bind(0, 1) : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4>[0])>, StorageBuffer>
+  spirv.GlobalVariable @var bind(0, 0) {counter_buffer = @counter} : !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4>[0])>, StorageBuffer>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/194611
More information about the Mlir-commits mailing list