[Mlir-commits] [mlir] [mlir][SPIR-V] Add support for SPV_INTEL_long_composites extension (PR #195685)

Jakub Kuderski llvmlistbot at llvm.org
Tue May 5 10:47:07 PDT 2026


================
@@ -229,3 +230,223 @@ TEST_F(SerializationTest, DoesNotContainSymbolName) {
   };
   EXPECT_FALSE(scanInstruction(hasVarName));
 }
+
+//===----------------------------------------------------------------------===//
+// SPV_INTEL_long_composites: composites whose binary form would exceed the
+// SPIR-V 16-bit word-count limit are split into a parent + *ContinuedINTEL ops
+// on serialization, and merged back on deserialization. These tests build the
+// large composites programmatically so that the IR doesn't have to expand
+// thousands of operands literally.
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+// Number of members/constituents used by the long-composite tests. Picked to
+// comfortably exceed kMaxWordCount = 65535 for any of the splittable
+// composite/struct opcodes -- each one packs at most kMaxWordCount - {1,2,3}
+// operands into the parent instruction, so 65540 always triggers a split.
+constexpr unsigned kLongCompositeSize = 65540;
+// Keep the comment above honest if either constant is ever changed.
+static_assert(kLongCompositeSize > spirv::kMaxWordCount,
+              "long-composite tests need more operands than a single SPIR-V "
+              "instruction can hold");
+
+// Walks the instructions that follow the module header and reports whether any
+// of them carries the opcode `target`. The scan gives up and returns false on
+// a zero word count or on an instruction that would extend past the end of
+// `binary`.
+bool hasOpcode(SmallVectorImpl<uint32_t> &binary, spirv::Opcode target) {
+  const size_t numWords = binary.size();
+  for (size_t pos = spirv::kHeaderWordCount; pos < numWords;) {
+    const uint32_t firstWord = binary[pos];
+    // Upper half-word: instruction length in words; lower half-word: opcode.
+    const uint32_t length = firstWord >> 16;
+    if (length == 0 || pos + length > numWords)
+      return false;
+    if (static_cast<spirv::Opcode>(firstWord & 0xffff) == target)
+      return true;
+    pos += length;
+  }
+  return false;
+}
+
+// Scans the instructions after the module header and returns true only if the
+// binary contains both an OpCapability whose first operand is
+// LongCompositesINTEL and an OpExtension whose string literal names
+// SPV_INTEL_long_composites. Scanning stops at the first malformed
+// instruction; anything found before that point still counts.
+bool hasLongCompositesCapabilityAndExtension(SmallVectorImpl<uint32_t> &b) {
+  bool sawCapability = false;
+  bool sawExtension = false;
+  for (size_t pos = spirv::kHeaderWordCount; pos < b.size();) {
+    const uint32_t firstWord = b[pos];
+    const uint32_t length = firstWord >> 16;
+    if (length == 0 || pos + length > b.size())
+      break;
+    // Everything after the first word of the instruction is operand data.
+    ArrayRef<uint32_t> operands(b.data() + pos + 1, length - 1);
+    switch (static_cast<spirv::Opcode>(firstWord & 0xffff)) {
+    case spirv::Opcode::OpCapability:
+      if (!operands.empty() &&
+          operands[0] ==
+              static_cast<uint32_t>(spirv::Capability::LongCompositesINTEL))
+        sawCapability = true;
+      break;
+    case spirv::Opcode::OpExtension: {
+      unsigned wordIdx = 0;
+      if (spirv::decodeStringLiteral(operands, wordIdx) ==
+          spirv::stringifyExtension(
+              spirv::Extension::SPV_INTEL_long_composites))
+        sawExtension = true;
+      break;
+    }
+    default:
+      break;
+    }
+    pos += length;
+  }
+  return sawCapability && sawExtension;
+}
+
+// Verifies that the instruction stream after the module header is well formed
+// with respect to the SPIR-V word-count encoding: every instruction has a
+// non-zero word count no larger than spirv::kMaxWordCount, no instruction
+// extends past the end of the binary, and the walk ends exactly at the end.
+//
+// The word count is read from the upper 16 bits of an instruction's first
+// word, so the value seen here can never exceed 0xffff. A serializer that
+// overflowed that field would leave a truncated count behind, which typically
+// desynchronizes this walk; the bounds check makes such a stream fail instead
+// of silently running off the end and reporting success.
+bool allInstructionsWithinWordLimit(SmallVectorImpl<uint32_t> &b) {
+  size_t offset = spirv::kHeaderWordCount;
+  while (offset < b.size()) {
+    uint32_t wordCount = b[offset] >> 16;
+    if (!wordCount || wordCount > spirv::kMaxWordCount)
+      return false;
+    // An instruction claiming more words than remain means the stream is
+    // misaligned or truncated.
+    if (wordCount > b.size() - offset)
+      return false;
+    offset += wordCount;
+  }
+  // Also rejects a binary that is shorter than the module header.
+  return offset == b.size();
+}
+
+} // namespace
+
+TEST_F(SerializationTest, LongTypeStructIsSplit) {
+  // Build a struct with kLongCompositeSize i8 members and reference it via a
+  // global variable so the type gets serialized.
+  OpBuilder builder(module->getRegion());
+  Type i8Type = builder.getIntegerType(8);
+  SmallVector<Type> memberTypes(kLongCompositeSize, i8Type);
+  SmallVector<spirv::StructType::OffsetInfo> offsets(kLongCompositeSize, 0);
+  auto structType = spirv::StructType::get(memberTypes, offsets);
+  addGlobalVar(structType, "var0");
+
+  ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary)));
+  EXPECT_TRUE(allInstructionsWithinWordLimit(binary));
+  EXPECT_TRUE(hasOpcode(binary, spirv::Opcode::OpTypeStruct));
+  EXPECT_TRUE(hasOpcode(binary, spirv::Opcode::OpTypeStructContinuedINTEL));
+  EXPECT_TRUE(hasLongCompositesCapabilityAndExtension(binary));
+
+  MLIRContext freshContext;
+  freshContext.getOrLoadDialect<spirv::SPIRVDialect>();
+  OwningOpRef<spirv::ModuleOp> roundTripped =
+      spirv::deserialize(binary, &freshContext);
+  ASSERT_TRUE(roundTripped);
+  bool foundStruct = false;
+  roundTripped->walk([&](spirv::GlobalVariableOp gv) {
+    auto ptrType = dyn_cast<spirv::PointerType>(gv.getType());
+    if (!ptrType)
+      return;
+    auto rtStruct = dyn_cast<spirv::StructType>(ptrType.getPointeeType());
+    if (!rtStruct)
+      return;
+    EXPECT_EQ(rtStruct.getNumElements(), kLongCompositeSize);
----------------
kuhar wrote:

The four new tests verify that the parent + at least one continuation opcode is present and that the round-tripped op has `kLongCompositeSize` elements/constituents -- but they never check the content of those elements. With `getI8IntegerAttr(0)` and `entry->getArgument(0)` constituents, a bug that mangled member ordering, dropped a chunk, swapped two operand IDs, or zeroed-out part of the buffer would still produce a struct/array/composite of the expected size and pass.

Concrete ideas:
* `LongConstantCompositeIsSplit`: populate `elements` with a non-uniform pattern (e.g. `getI8IntegerAttr(i & 0xff)`) and assert each round-tripped element matches by index. Uniform 0 makes this essentially a size check.
* `LongCompositeConstructIsSplit`: build N distinct SSA values (e.g., N `spirv.Constant` ops) so every constituent ID is different, then check each round-tripped operand maps back to the right defining op. Reusing the same block argument N times means a swap bug in the chunking math is invisible.
* `LongTypeStructIsSplit`: use a mix of member types (e.g., alternating `i8`/`i16`) so member ordering is observable.
* `LongSpecConstantCompositeIsSplit`: the constituents are already named uniquely (`sc0`, `sc1`, ...). Walk `op.getConstituents()` and assert the first / last few match the expected `SymbolRefAttr` names.

Right now the only asserted property is "we end up with N things again", which is exactly the property that's hardest to break.


https://github.com/llvm/llvm-project/pull/195685


More information about the Mlir-commits mailing list