[Mlir-commits] [mlir] [mlir][SPIR-V] Add OpDecorateId support for Id-form decorations (PR #194611)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Apr 28 06:04:30 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Arseniy Obolenskiy (aobolensk)

<details>
<summary>Changes</summary>

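Add serialization/deserialization support for the standard decorations whose operand is an <id> and which are therefore encoded with OpDecorateId: AlignmentId, MaxByteOffsetId, and CounterBuffer.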

---
Full diff: https://github.com/llvm/llvm-project/pull/194611.diff


7 Files Affected:

- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td (+2-1) 
- (modified) mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp (+1) 
- (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (+60) 
- (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.h (+15) 
- (modified) mlir/lib/Target/SPIRV/Serialization/Serializer.cpp (+29) 
- (modified) mlir/lib/Target/SPIRV/Serialization/Serializer.h (+5) 
- (added) mlir/test/Target/SPIRV/decorations-id.mlir (+32) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 733150d11c8b1..a2e1274cac3a5 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4607,6 +4607,7 @@ def SPIRV_OC_OpGroupSMax                      : I32EnumAttrCase<"OpGroupSMax", 2
 def SPIRV_OC_OpNoLine                         : I32EnumAttrCase<"OpNoLine", 317>;
 def SPIRV_OC_OpModuleProcessed                : I32EnumAttrCase<"OpModuleProcessed", 330>;
 def SPIRV_OC_OpExecutionModeId                : I32EnumAttrCase<"OpExecutionModeId", 331>;
+def SPIRV_OC_OpDecorateId                     : I32EnumAttrCase<"OpDecorateId", 332>;
 def SPIRV_OC_OpGroupNonUniformElect           : I32EnumAttrCase<"OpGroupNonUniformElect", 333>;
 def SPIRV_OC_OpGroupNonUniformAll             : I32EnumAttrCase<"OpGroupNonUniformAll", 334>;
 def SPIRV_OC_OpGroupNonUniformAny             : I32EnumAttrCase<"OpGroupNonUniformAny", 335>;
@@ -4755,7 +4756,7 @@ def SPIRV_OpcodeAttr :
       SPIRV_OC_OpGroupFAdd, SPIRV_OC_OpGroupFMin, SPIRV_OC_OpGroupUMin,
       SPIRV_OC_OpGroupSMin, SPIRV_OC_OpGroupFMax, SPIRV_OC_OpGroupUMax,
       SPIRV_OC_OpGroupSMax, SPIRV_OC_OpNoLine, SPIRV_OC_OpModuleProcessed,
-      SPIRV_OC_OpExecutionModeId,
+      SPIRV_OC_OpExecutionModeId, SPIRV_OC_OpDecorateId,
       SPIRV_OC_OpGroupNonUniformElect, SPIRV_OC_OpGroupNonUniformAll,
       SPIRV_OC_OpGroupNonUniformAny, SPIRV_OC_OpGroupNonUniformAllEqual,
       SPIRV_OC_OpGroupNonUniformBroadcast, SPIRV_OC_OpGroupNonUniformBroadcastFirst,
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index 0faa5f0f29d7d..c12647d4255aa 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -221,6 +221,7 @@ LogicalResult spirv::Deserializer::processInstruction(
   case spirv::Opcode::OpGraphConstantARM:
     return processGraphConstantARM(operands);
   case spirv::Opcode::OpDecorate:
+  case spirv::Opcode::OpDecorateId:
     return processDecoration(operands);
   case spirv::Opcode::OpMemberDecorate:
     return processMemberDecoration(operands);
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index f4481305fd552..3a98c27727e31 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -92,6 +92,9 @@ LogicalResult spirv::Deserializer::deserialize() {
     }
   }
 
+  if (failed(resolveDeferredIdDecorations()))
+    return failure();
+
   attachVCETriple();
 
   LLVM_DEBUG(logger.startLine()
@@ -377,12 +380,69 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
       return res;
     break;
   }
+  case spirv::Decoration::AlignmentId:
+  case spirv::Decoration::MaxByteOffsetId:
+  case spirv::Decoration::CounterBuffer:
+    if (words.size() != 3) {
+      return emitError(unknownLoc, "OpDecorateId with ")
+             << decorationName << " needs a single <id> operand";
+    }
+    pendingIdDecorations.push_back(
+        {words[0], static_cast<spirv::Decoration>(words[1]), words[2],
+         unknownLoc});
+    break;
   default:
     return emitError(unknownLoc, "unhandled Decoration : '") << decorationName;
   }
   return success();
 }
 
+LogicalResult spirv::Deserializer::resolveDeferredIdDecorations() {
+  for (const DeferredIdDecoration &entry : pendingIdDecorations) {
+    auto decorationName = stringifyDecoration(entry.decoration);
+    StringAttr symbol = getSymbolDecoration(decorationName);
+
+    // Resolve the operand <id> to a symbol name. The operand must reference a
+    // module-scope symbol op (global variable or specialization constant).
+    StringRef operandSymName;
+    if (auto varOp = globalVariableMap.lookup(entry.operandID))
+      operandSymName = varOp.getSymName();
+    else if (auto specOp = specConstMap.lookup(entry.operandID))
+      operandSymName = specOp.getSymName();
+
+    if (operandSymName.empty()) {
+      return emitError(entry.loc, "OpDecorateId with ")
+             << decorationName << " references <id> " << entry.operandID
+             << " which is not a global variable or specialization constant";
+    }
+
+    auto symRef = FlatSymbolRefAttr::get(context, operandSymName);
+
+    // The decoration target may already be a constructed module-scope op
+    // (its decorations dict was applied at construction time, before this
+    // resolution pass runs). In that case, set the attribute directly on the
+    // op. Otherwise, fall back to the deferred `decorations` map for ops that
+    // consume it later.
+    Operation *targetOp = nullptr;
+    if (auto varOp = globalVariableMap.lookup(entry.targetID))
+      targetOp = varOp;
+    else if (auto specOp = specConstMap.lookup(entry.targetID))
+      targetOp = specOp;
+    else if (auto fnOp = funcMap.lookup(entry.targetID))
+      targetOp = fnOp;
+    else if (Value v = valueMap.lookup(entry.targetID))
+      targetOp = v.getDefiningOp();
+
+    if (targetOp) {
+      targetOp->setAttr(symbol, symRef);
+    } else {
+      decorations[entry.targetID].set(symbol, symRef);
+    }
+  }
+  pendingIdDecorations.clear();
+  return success();
+}
+
 LogicalResult
 spirv::Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) {
   // The binary layout of OpMemberDecorate is different comparing to OpDecorate
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index e0743503acc5b..9725c63deb8c2 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -177,6 +177,11 @@ class Deserializer {
   /// Processes an OpDecorate instruction.
   LogicalResult processDecoration(ArrayRef<uint32_t> words);
 
+  /// Resolves all OpDecorateId entries previously queued during
+  /// processDecoration. Called after all module ops have been deserialized so
+  /// the operand <id>s can be looked up as MLIR symbols.
+  LogicalResult resolveDeferredIdDecorations();
+
   // Processes an OpMemberDecorate instruction.
   LogicalResult processMemberDecoration(ArrayRef<uint32_t> words);
 
@@ -682,6 +687,16 @@ class Deserializer {
   // Result <id> to decorations mapping.
   DenseMap<uint32_t, NamedAttrList> decorations;
 
+  // Decoration entries from OpDecorateId whose operand <id>s must be resolved
+  // to MLIR symbols after all module ops have been deserialized.
+  struct DeferredIdDecoration {
+    uint32_t targetID;
+    spirv::Decoration decoration;
+    uint32_t operandID;
+    Location loc;
+  };
+  SmallVector<DeferredIdDecoration> pendingIdDecorations;
+
   // Result <id> to type decorations.
   DenseMap<uint32_t, uint32_t> typeDecorations;
 
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index aaa80470f40e1..e29a437cca87f 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -405,6 +405,23 @@ LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
               resultID, decoration,
               {cacheLevel, static_cast<uint32_t>(storeCacheControl)});
         });
+  case spirv::Decoration::AlignmentId:
+  case spirv::Decoration::MaxByteOffsetId:
+  case spirv::Decoration::CounterBuffer: {
+    auto symRef = dyn_cast<FlatSymbolRefAttr>(attr);
+    if (!symRef)
+      return emitError(loc, "expected symbol reference for ")
+             << stringifyDecoration(decoration);
+    StringRef symName = symRef.getValue();
+    uint32_t operandID = getVariableID(symName);
+    if (!operandID)
+      operandID = getSpecConstID(symName);
+    if (!operandID)
+      return emitError(loc, "could not find <id> for symbol '")
+             << symName << "' referenced by "
+             << stringifyDecoration(decoration);
+    return emitDecorationId(resultID, decoration, {operandID});
+  }
   default:
     return emitError(loc, "unhandled decoration ")
            << stringifyDecoration(decoration);
@@ -1677,6 +1694,18 @@ LogicalResult Serializer::emitDecoration(uint32_t target,
   return success();
 }
 
+LogicalResult Serializer::emitDecorationId(uint32_t target,
+                                           spirv::Decoration decoration,
+                                           ArrayRef<uint32_t> operandIds) {
+  uint32_t wordCount = 3 + operandIds.size();
+  llvm::append_values(
+      decorations,
+      spirv::getPrefixedOpcode(wordCount, spirv::Opcode::OpDecorateId), target,
+      static_cast<uint32_t>(decoration));
+  llvm::append_range(decorations, operandIds);
+  return success();
+}
+
 LogicalResult Serializer::emitDebugLine(SmallVectorImpl<uint32_t> &binary,
                                         Location loc) {
   if (!options.emitDebugInfo)
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.h b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
index 6e79c133eb6af..eb5ac0d60038e 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.h
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
@@ -355,6 +355,11 @@ class Serializer {
   LogicalResult emitDecoration(uint32_t target, spirv::Decoration decoration,
                                ArrayRef<uint32_t> params = {});
 
+  /// Emits an OpDecorateId instruction to decorate the given `target` with the
+  /// given `decoration` whose extra operands are SPIR-V <id>s.
+  LogicalResult emitDecorationId(uint32_t target, spirv::Decoration decoration,
+                                 ArrayRef<uint32_t> operandIds);
+
   /// Emits an OpLine instruction with the given `loc` location information into
   /// the given `binary` vector.
   LogicalResult emitDebugLine(SmallVectorImpl<uint32_t> &binary, Location loc);
diff --git a/mlir/test/Target/SPIRV/decorations-id.mlir b/mlir/test/Target/SPIRV/decorations-id.mlir
new file mode 100644
index 0000000000000..dfd83319032d9
--- /dev/null
+++ b/mlir/test/Target/SPIRV/decorations-id.mlir
@@ -0,0 +1,32 @@
+// RUN: mlir-translate --no-implicit-module --split-input-file --test-spirv-roundtrip %s | FileCheck %s
+
+// Round-trip tests for decorations whose operand is a SPIR-V <id>
+// (serialized as OpDecorateId, opcode 332).
+
+// AlignmentId references a specialization constant that supplies the alignment.
+spirv.module Logical OpenCL requires #spirv.vce<v1.0, [Kernel, Addresses, Linkage], []> {
+  // CHECK: spirv.SpecConstant @sc_align = 16
+  // CHECK: alignment_id = @sc_align
+  spirv.SpecConstant @sc_align = 16 : i32
+  spirv.GlobalVariable @var {alignment_id = @sc_align} : !spirv.ptr<f32, CrossWorkgroup>
+}
+
+// -----
+
+// MaxByteOffsetId references a specialization constant.
+spirv.module Logical OpenCL requires #spirv.vce<v1.0, [Kernel, Addresses, Linkage], []> {
+  // CHECK: spirv.SpecConstant @sc_offset = 1024
+  // CHECK: max_byte_offset_id = @sc_offset
+  spirv.SpecConstant @sc_offset = 1024 : i32
+  spirv.GlobalVariable @var {max_byte_offset_id = @sc_offset} : !spirv.ptr<f32, CrossWorkgroup>
+}
+
+// -----
+
+// CounterBuffer references another global variable.
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
+  // CHECK: spirv.GlobalVariable @counter
+  // CHECK: counter_buffer = @counter
+  spirv.GlobalVariable @counter bind(0, 1) : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4>[0])>, StorageBuffer>
+  spirv.GlobalVariable @var bind(0, 0) {counter_buffer = @counter} : !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4>[0])>, StorageBuffer>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/194611


More information about the Mlir-commits mailing list