[Mlir-commits] [mlir] 222d7fc - [mlir][spirv] Avoid duplicated Block decoration during serialization
Lei Zhang
llvmlistbot at llvm.org
Fri Dec 10 16:21:13 PST 2021
Author: Lei Zhang
Date: 2021-12-10T19:20:49-05:00
New Revision: 222d7fc7f81936edf488bc85acf3ea5d5300bca3
URL: https://github.com/llvm/llvm-project/commit/222d7fc7f81936edf488bc85acf3ea5d5300bca3
DIFF: https://github.com/llvm/llvm-project/commit/222d7fc7f81936edf488bc85acf3ea5d5300bca3.diff
LOG: [mlir][spirv] Avoid duplicated Block decoration during serialization
Duplicated decorations are legal per the Vulkan / SPIR-V spec; still, it's better to avoid
such duplication to produce a cleaner blob and reduce the binary size.
Reviewed By: scotttodd
Differential Revision: https://reviews.llvm.org/D115532
Added:
Modified:
mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index d877c1b70311e..ef558aa6c3929 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -316,18 +316,6 @@ Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
return failure();
}
- if (isInterfaceStructPtrType(varOp.type())) {
- auto structType = varOp.type()
- .cast<spirv::PointerType>()
- .getPointeeType()
- .cast<spirv::StructType>();
- if (failed(
- emitDecoration(getTypeID(structType), spirv::Decoration::Block))) {
- return varOp.emitError("cannot decorate ")
- << structType << " with Block decoration";
- }
- }
-
elidedAttrs.push_back("type");
SmallVector<uint32_t, 4> operands;
operands.push_back(resultTypeID);
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index bcead6e527d5e..bd618ec4884b8 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -331,9 +331,9 @@ LogicalResult
Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID,
SetVector<StringRef> &serializationCtx) {
typeID = getTypeID(type);
- if (typeID) {
+ if (typeID)
return success();
- }
+
typeID = getNextID();
SmallVector<uint32_t, 4> operands;
@@ -499,6 +499,14 @@ LogicalResult Serializer::prepareBasicType(
typeEnum = spirv::Opcode::OpTypePointer;
operands.push_back(static_cast<uint32_t>(ptrType.getStorageClass()));
operands.push_back(pointeeTypeID);
+
+ if (isInterfaceStructPtrType(ptrType)) {
+ if (failed(emitDecoration(getTypeID(pointeeStruct),
+ spirv::Decoration::Block)))
+ return emitError(loc, "cannot decorate ")
+ << pointeeStruct << " with Block decoration";
+ }
+
return success();
}
diff --git a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
index 9222b0cd3654a..f17bc53c5bd8b 100644
--- a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
+++ b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
@@ -76,27 +76,29 @@ class SerializationTest : public ::testing::Test {
builder.getStringAttr(name), nullptr);
}
+ /// Handles a SPIR-V instruction with the given `opcode` and `operands`.
+ /// Returns true to stop scanning the remaining instructions.
+ using HandleFn = llvm::function_ref<bool(spirv::Opcode opcode,
+ ArrayRef<uint32_t> operands)>;
+
/// Returns true if we can find a matching instruction in the SPIR-V blob.
- bool findInstruction(llvm::function_ref<bool(spirv::Opcode opcode,
- ArrayRef<uint32_t> operands)>
- matchFn) {
+ bool scanInstruction(HandleFn handleFn) {
auto binarySize = binary.size();
auto *begin = binary.begin();
auto currOffset = spirv::kHeaderWordCount;
while (currOffset < binarySize) {
auto wordCount = binary[currOffset] >> 16;
- if (!wordCount || (currOffset + wordCount > binarySize)) {
+ if (!wordCount || (currOffset + wordCount > binarySize))
return false;
- }
+
spirv::Opcode opcode =
static_cast<spirv::Opcode>(binary[currOffset] & 0xffff);
-
- if (matchFn(opcode,
- llvm::ArrayRef<uint32_t>(begin + currOffset + 1,
- begin + currOffset + wordCount))) {
+ llvm::ArrayRef<uint32_t> operands(begin + currOffset + 1,
+ begin + currOffset + wordCount);
+ if (handleFn(opcode, operands))
return true;
- }
+
currOffset += wordCount;
}
return false;
@@ -119,12 +121,32 @@ TEST_F(SerializationTest, ContainsBlockDecoration) {
ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary)));
auto hasBlockDecoration = [](spirv::Opcode opcode,
- ArrayRef<uint32_t> operands) -> bool {
- if (opcode != spirv::Opcode::OpDecorate || operands.size() != 2)
- return false;
- return operands[1] == static_cast<uint32_t>(spirv::Decoration::Block);
+ ArrayRef<uint32_t> operands) {
+ return opcode == spirv::Opcode::OpDecorate && operands.size() == 2 &&
+ operands[1] == static_cast<uint32_t>(spirv::Decoration::Block);
+ };
+ EXPECT_TRUE(scanInstruction(hasBlockDecoration));
+}
+
+TEST_F(SerializationTest, ContainsNoDuplicatedBlockDecoration) {
+ auto structType = getFloatStructType();
+ // Two global variables using the same type should not decorate the type with
+ // duplicated `Block` decorations.
+ addGlobalVar(structType, "var0");
+ addGlobalVar(structType, "var1");
+
+ ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary)));
+
+ unsigned count = 0;
+ auto countBlockDecoration = [&count](spirv::Opcode opcode,
+ ArrayRef<uint32_t> operands) {
+ if (opcode == spirv::Opcode::OpDecorate && operands.size() == 2 &&
+ operands[1] == static_cast<uint32_t>(spirv::Decoration::Block))
+ ++count;
+ return false;
};
- EXPECT_TRUE(findInstruction(hasBlockDecoration));
+ ASSERT_FALSE(scanInstruction(countBlockDecoration));
+ EXPECT_EQ(count, 1u);
}
TEST_F(SerializationTest, ContainsSymbolName) {
@@ -140,7 +162,7 @@ TEST_F(SerializationTest, ContainsSymbolName) {
return opcode == spirv::Opcode::OpName &&
spirv::decodeStringLiteral(operands, index) == "var0";
};
- EXPECT_TRUE(findInstruction(hasVarName));
+ EXPECT_TRUE(scanInstruction(hasVarName));
}
TEST_F(SerializationTest, DoesNotContainSymbolName) {
@@ -156,5 +178,5 @@ TEST_F(SerializationTest, DoesNotContainSymbolName) {
return opcode == spirv::Opcode::OpName &&
spirv::decodeStringLiteral(operands, index) == "var0";
};
- EXPECT_FALSE(findInstruction(hasVarName));
+ EXPECT_FALSE(scanInstruction(hasVarName));
}
More information about the Mlir-commits
mailing list