[Mlir-commits] [mlir] [mlir][SPIR-V] Add support for SPV_INTEL_long_composites extension (PR #195685)
Jakub Kuderski
llvmlistbot at llvm.org
Tue May 5 10:47:07 PDT 2026
================
@@ -229,3 +230,223 @@ TEST_F(SerializationTest, DoesNotContainSymbolName) {
};
EXPECT_FALSE(scanInstruction(hasVarName));
}
+
+//===----------------------------------------------------------------------===//
+// SPV_INTEL_long_composites: composites whose binary form would exceed the
+// SPIR-V 16-bit word-count limit are split into a parent + *ContinuedINTEL ops
+// on serialization, and merged back on deserialization. These tests build the
+// large composites programmatically so that the IR doesn't have to expand
+// thousands of operands literally.
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+// Picked to comfortably exceed kMaxWordCount = 65535 for any of the splittable
+// composite/struct opcodes -- each one packs at most kMaxWordCount - {1,2,3}
+// operands into the parent instruction, so 65540 always triggers a split.
+constexpr unsigned kLongCompositeSize = 65540;
+
+// Returns true if some instruction after the module header has opcode
+// `target`. Returns false when no such instruction is found, or when the
+// stream is malformed (an instruction with a zero word count or one that
+// would run past the end of `binary`).
+bool hasOpcode(SmallVectorImpl<uint32_t> &binary, spirv::Opcode target) {
+  const size_t numWords = binary.size();
+  for (size_t cursor = spirv::kHeaderWordCount; cursor < numWords;) {
+    // First word of an instruction: word count in the high 16 bits, opcode in
+    // the low 16 bits.
+    const uint32_t leadWord = binary[cursor];
+    const uint32_t length = leadWord >> 16;
+    if (length == 0 || cursor + length > numWords)
+      return false;
+    if (static_cast<spirv::Opcode>(leadWord & 0xffff) == target)
+      return true;
+    cursor += length;
+  }
+  return false;
+}
+
+// Returns true only if the binary contains both an OpCapability whose first
+// operand is LongCompositesINTEL and an OpExtension whose string literal is
+// the SPV_INTEL_long_composites extension name. Scanning stops at the first
+// instruction with a zero word count or one that would run past the end of
+// `b`; whatever was found before that point still counts.
+bool hasLongCompositesCapabilityAndExtension(SmallVectorImpl<uint32_t> &b) {
+  const auto expectedCap =
+      static_cast<uint32_t>(spirv::Capability::LongCompositesINTEL);
+  bool sawCapability = false;
+  bool sawExtension = false;
+  for (size_t pos = spirv::kHeaderWordCount; pos < b.size();) {
+    const uint32_t leadWord = b[pos];
+    const uint32_t numWords = leadWord >> 16;
+    if (numWords == 0 || pos + numWords > b.size())
+      break;
+    // Everything after the lead word is the instruction's operand list.
+    ArrayRef<uint32_t> args(b.data() + pos + 1, numWords - 1);
+    const auto opcode = static_cast<spirv::Opcode>(leadWord & 0xffff);
+    if (opcode == spirv::Opcode::OpCapability) {
+      if (!args.empty() && args[0] == expectedCap)
+        sawCapability = true;
+    } else if (opcode == spirv::Opcode::OpExtension) {
+      unsigned wordIdx = 0;
+      if (spirv::decodeStringLiteral(args, wordIdx) ==
+          spirv::stringifyExtension(
+              spirv::Extension::SPV_INTEL_long_composites))
+        sawExtension = true;
+    }
+    pos += numWords;
+  }
+  return sawCapability && sawExtension;
+}
+
+// Verifies that the instruction stream after the module header is well formed
+// with respect to the SPIR-V 16-bit word-count limit (a failure would mean the
+// splitting logic did not work).
+//
+// The word count is read from the high 16 bits of an instruction's first
+// word, so the decoded value alone cannot reveal an instruction that needed
+// more words than the field can hold: such an instruction would show up as a
+// zero count or as a stream whose instructions no longer fit the buffer.
+// Returns false if any instruction has a zero word count, a word count above
+// spirv::kMaxWordCount, or would extend past the end of `b`; returns true
+// otherwise.
+bool allInstructionsWithinWordLimit(SmallVectorImpl<uint32_t> &b) {
+  size_t offset = spirv::kHeaderWordCount;
+  while (offset < b.size()) {
+    uint32_t wordCount = b[offset] >> 16;
+    if (!wordCount || wordCount > spirv::kMaxWordCount)
+      return false;
+    // Same bounds check the other binary-walking helpers perform: an
+    // instruction that overruns the buffer means the walk has lost sync with
+    // the real instruction boundaries.
+    if (offset + wordCount > b.size())
+      return false;
+    offset += wordCount;
+  }
+  return true;
+}
+
+} // namespace
+
+TEST_F(SerializationTest, LongTypeStructIsSplit) {
+ // Build a struct with kLongCompositeSize i8 members and reference it via a
+ // global variable so the type gets serialized.
+ OpBuilder builder(module->getRegion());
+ Type i8Type = builder.getIntegerType(8);
+ SmallVector<Type> memberTypes(kLongCompositeSize, i8Type);
+ SmallVector<spirv::StructType::OffsetInfo> offsets(kLongCompositeSize, 0);
+ auto structType = spirv::StructType::get(memberTypes, offsets);
+ addGlobalVar(structType, "var0");
+
+ ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary)));
+ EXPECT_TRUE(allInstructionsWithinWordLimit(binary));
+ EXPECT_TRUE(hasOpcode(binary, spirv::Opcode::OpTypeStruct));
+ EXPECT_TRUE(hasOpcode(binary, spirv::Opcode::OpTypeStructContinuedINTEL));
+ EXPECT_TRUE(hasLongCompositesCapabilityAndExtension(binary));
+
+ MLIRContext freshContext;
+ freshContext.getOrLoadDialect<spirv::SPIRVDialect>();
+ OwningOpRef<spirv::ModuleOp> roundTripped =
+ spirv::deserialize(binary, &freshContext);
+ ASSERT_TRUE(roundTripped);
+ bool foundStruct = false;
+ roundTripped->walk([&](spirv::GlobalVariableOp gv) {
+ auto ptrType = dyn_cast<spirv::PointerType>(gv.getType());
+ if (!ptrType)
+ return;
+ auto rtStruct = dyn_cast<spirv::StructType>(ptrType.getPointeeType());
+ if (!rtStruct)
+ return;
+ EXPECT_EQ(rtStruct.getNumElements(), kLongCompositeSize);
----------------
kuhar wrote:
The four new tests verify that the parent + at least one continuation opcode are present and that the round-tripped op has `kLongCompositeSize` elements/constituents -- but they never check the content of those elements. With `getI8IntegerAttr(0)` and `entry->getArgument(0)` constituents, a bug that mangled member ordering, dropped a chunk, swapped two operand IDs, or zeroed out part of the buffer would still produce a struct/array/composite of the expected size and pass.
Concrete ideas:
* `LongConstantCompositeIsSplit`: populate `elements` with a non-uniform pattern (e.g. `getI8IntegerAttr(i & 0xff)`) and assert each round-tripped element matches by index. Uniform 0 makes this essentially a size check.
* `LongCompositeConstructIsSplit`: build N distinct SSA values (e.g., N `spirv.Constant` ops) so every constituent ID is different, then check each round-tripped operand maps back to the right defining op. Reusing the same block argument N times means a swap bug in the chunking math is invisible.
* `LongTypeStructIsSplit`: use a mix of member types (e.g., alternating `i8`/`i16`) so member ordering is observable.
* `LongSpecConstantCompositeIsSplit`: the constituents are already named uniquely (`sc0`, `sc1`, ...). Walk `op.getConstituents()` and assert the first / last few match the expected `SymbolRefAttr` names.
Right now the only asserted property is "we end up with N things again", which is exactly the property that's hardest to break.
https://github.com/llvm/llvm-project/pull/195685
More information about the Mlir-commits
mailing list