[Mlir-commits] [mlir] b289266 - [mlir][spirv] Add serialization control to emit symbol name
Lei Zhang
llvmlistbot at llvm.org
Fri Dec 10 16:21:07 PST 2021
Author: Lei Zhang
Date: 2021-12-10T19:20:49-05:00
New Revision: b289266cb2398cebca0ba1ecfc6242470cc7cedd
URL: https://github.com/llvm/llvm-project/commit/b289266cb2398cebca0ba1ecfc6242470cc7cedd
DIFF: https://github.com/llvm/llvm-project/commit/b289266cb2398cebca0ba1ecfc6242470cc7cedd.diff
LOG: [mlir][spirv] Add serialization control to emit symbol name
In SPIR-V, symbol names are encoded as `OpName` instructions.
They have no semantic impact and can be omitted, which
reduces the binary size.
Reviewed By: scotttodd
Differential Revision: https://reviews.llvm.org/D115531
Added:
Modified:
mlir/include/mlir/Target/SPIRV/SPIRVBinaryUtils.h
mlir/include/mlir/Target/SPIRV/Serialization.h
mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
mlir/lib/Target/SPIRV/Serialization/Serialization.cpp
mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
mlir/lib/Target/SPIRV/Serialization/Serializer.h
mlir/lib/Target/SPIRV/TranslateRegistration.cpp
mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Target/SPIRV/SPIRVBinaryUtils.h b/mlir/include/mlir/Target/SPIRV/SPIRVBinaryUtils.h
index dc2b8c7edec42..a155b1989e63b 100644
--- a/mlir/include/mlir/Target/SPIRV/SPIRVBinaryUtils.h
+++ b/mlir/include/mlir/Target/SPIRV/SPIRVBinaryUtils.h
@@ -13,8 +13,8 @@
#ifndef MLIR_TARGET_SPIRV_BINARY_UTILS_H_
#define MLIR_TARGET_SPIRV_BINARY_UTILS_H_
-#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
-#include "mlir/Support/LogicalResult.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
+#include "mlir/Support/LLVM.h"
#include <cstdint>
@@ -41,6 +41,16 @@ uint32_t getPrefixedOpcode(uint32_t wordCount, spirv::Opcode opcode);
/// Encodes an SPIR-V `literal` string into the given `binary` vector.
LogicalResult encodeStringLiteralInto(SmallVectorImpl<uint32_t> &binary,
StringRef literal);
+
+/// Decodes a string literal in `words` starting at `wordIndex`. Update the
+/// latter to point to the position in words after the string literal.
+inline StringRef decodeStringLiteral(ArrayRef<uint32_t> words,
+ unsigned &wordIndex) {
+ StringRef str(reinterpret_cast<const char *>(words.data() + wordIndex));
+ wordIndex += str.size() / 4 + 1;
+ return str;
+}
+
} // namespace spirv
} // namespace mlir
diff --git a/mlir/include/mlir/Target/SPIRV/Serialization.h b/mlir/include/mlir/Target/SPIRV/Serialization.h
index 25033e2c4f169..498f390148793 100644
--- a/mlir/include/mlir/Target/SPIRV/Serialization.h
+++ b/mlir/include/mlir/Target/SPIRV/Serialization.h
@@ -22,11 +22,18 @@ class MLIRContext;
namespace spirv {
class ModuleOp;
+struct SerializationOptions {
+ /// Whether to emit `OpName` instructions for SPIR-V symbol ops.
+ bool emitSymbolName = true;
+ /// Whether to emit `OpLine` location information for SPIR-V ops.
+ bool emitDebugInfo = false;
+};
+
/// Serializes the given SPIR-V `module` and writes to `binary`. On failure,
/// reports errors to the error handler registered with the MLIR context for
/// `module`.
LogicalResult serialize(ModuleOp module, SmallVectorImpl<uint32_t> &binary,
- bool emitDebugInfo = false);
+ const SerializationOptions &options = {});
} // namespace spirv
} // namespace mlir
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index c01362d1e2d4c..1c77dfdddcaa5 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Location.h"
+#include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index 17060dddc9198..8351f3a41fa17 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -21,19 +21,6 @@
#include "llvm/ADT/StringRef.h"
#include <cstdint>
-//===----------------------------------------------------------------------===//
-// Utility Functions
-//===----------------------------------------------------------------------===//
-
-/// Decodes a string literal in `words` starting at `wordIndex`. Update the
-/// latter to point to the position in words after the string literal.
-static inline llvm::StringRef
-decodeStringLiteral(llvm::ArrayRef<uint32_t> words, unsigned &wordIndex) {
- llvm::StringRef str(reinterpret_cast<const char *>(words.data() + wordIndex));
- wordIndex += str.size() / 4 + 1;
- return str;
-}
-
namespace mlir {
namespace spirv {
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serialization.cpp b/mlir/lib/Target/SPIRV/Serialization/Serialization.cpp
index 33b886b6d369c..7d4d118a6783e 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serialization.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serialization.cpp
@@ -23,12 +23,12 @@
namespace mlir {
LogicalResult spirv::serialize(spirv::ModuleOp module,
SmallVectorImpl<uint32_t> &binary,
- bool emitDebugInfo) {
+ const SerializationOptions &options) {
if (!module.vce_triple().hasValue())
return module.emitError(
"module must have 'vce_triple' attribute to be serializeable");
- Serializer serializer(module, emitDebugInfo);
+ Serializer serializer(module, options);
if (failed(serializer.serialize()))
return failure();
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 6bd5ff1629b25..bcead6e527d5e 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -81,9 +81,9 @@ LogicalResult encodeInstructionInto(SmallVectorImpl<uint32_t> &binary,
return success();
}
-Serializer::Serializer(spirv::ModuleOp module, bool emitDebugInfo)
- : module(module), mlirBuilder(module.getContext()),
- emitDebugInfo(emitDebugInfo) {}
+Serializer::Serializer(spirv::ModuleOp module,
+ const SerializationOptions &options)
+ : module(module), mlirBuilder(module.getContext()), options(options) {}
LogicalResult Serializer::serialize() {
LLVM_DEBUG(llvm::dbgs() << "+++ starting serialization +++\n");
@@ -172,7 +172,7 @@ void Serializer::processCapability() {
}
void Serializer::processDebugInfo() {
- if (!emitDebugInfo)
+ if (!options.emitDebugInfo)
return;
auto fileLoc = module.getLoc().dyn_cast<FileLineColLoc>();
auto fileName = fileLoc ? fileLoc.getFilename().strref() : "<unknown>";
@@ -254,12 +254,13 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
LogicalResult Serializer::processName(uint32_t resultID, StringRef name) {
assert(!name.empty() && "unexpected empty string for OpName");
+ if (!options.emitSymbolName)
+ return success();
SmallVector<uint32_t, 4> nameOperands;
nameOperands.push_back(resultID);
- if (failed(spirv::encodeStringLiteralInto(nameOperands, name))) {
+ if (failed(spirv::encodeStringLiteralInto(nameOperands, name)))
return failure();
- }
return encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands);
}
@@ -1170,7 +1171,7 @@ LogicalResult Serializer::emitDecoration(uint32_t target,
LogicalResult Serializer::emitDebugLine(SmallVectorImpl<uint32_t> &binary,
Location loc) {
- if (!emitDebugInfo)
+ if (!options.emitDebugInfo)
return success();
if (lastProcessedWasMergeInst) {
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.h b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
index 9ee288a042b4f..5f4a4e999a673 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.h
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
@@ -15,6 +15,7 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/IR/Builders.h"
+#include "mlir/Target/SPIRV/Serialization.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"
@@ -42,7 +43,8 @@ LogicalResult encodeInstructionInto(SmallVectorImpl<uint32_t> &binary,
class Serializer {
public:
/// Creates a serializer for the given SPIR-V `module`.
- explicit Serializer(spirv::ModuleOp module, bool emitDebugInfo = false);
+ explicit Serializer(spirv::ModuleOp module,
+ const SerializationOptions &options);
/// Serializes the remembered SPIR-V module.
LogicalResult serialize();
@@ -316,8 +318,8 @@ class Serializer {
/// An MLIR builder for getting MLIR constructs.
mlir::Builder mlirBuilder;
- /// A flag which indicates if the debuginfo should be emitted.
- bool emitDebugInfo = false;
+ /// Serialization options.
+ SerializationOptions options;
/// A flag which indicates if the last processed instruction was a merge
/// instruction.
diff --git a/mlir/lib/Target/SPIRV/TranslateRegistration.cpp b/mlir/lib/Target/SPIRV/TranslateRegistration.cpp
index 989de4142d761..e63a68c6ed57b 100644
--- a/mlir/lib/Target/SPIRV/TranslateRegistration.cpp
+++ b/mlir/lib/Target/SPIRV/TranslateRegistration.cpp
@@ -40,7 +40,7 @@ static OwningModuleRef deserializeModule(const llvm::MemoryBuffer *input,
context->loadDialect<spirv::SPIRVDialect>();
// Make sure the input stream can be treated as a stream of SPIR-V words
- auto start = input->getBufferStart();
+ auto *start = input->getBufferStart();
auto size = input->getBufferSize();
if (size % sizeof(uint32_t) != 0) {
emitError(UnknownLoc::get(context))
@@ -94,8 +94,7 @@ static LogicalResult serializeModule(ModuleOp module, raw_ostream &output) {
if (spirvModules.size() != 1)
return module.emitError("found more than one 'spv.module' op");
- if (failed(
- spirv::serialize(spirvModules[0], binary, /*emitDebuginfo=*/false)))
+ if (failed(spirv::serialize(spirvModules[0], binary)))
return failure();
output.write(reinterpret_cast<char *>(binary.data()),
@@ -133,7 +132,9 @@ static LogicalResult roundTripModule(ModuleOp srcModule, bool emitDebugInfo,
if (std::next(spirvModules.begin()) != spirvModules.end())
return srcModule.emitError("found more than one 'spv.module' op");
- if (failed(spirv::serialize(*spirvModules.begin(), binary, emitDebugInfo)))
+ spirv::SerializationOptions options;
+ options.emitDebugInfo = emitDebugInfo;
+ if (failed(spirv::serialize(*spirvModules.begin(), binary, options)))
return failure();
MLIRContext deserializationContext(context->getDialectRegistry());
diff --git a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
index a3ae8bc3a9e75..9222b0cd3654a 100644
--- a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
+++ b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
@@ -37,10 +37,11 @@ class SerializationTest : public ::testing::Test {
protected:
SerializationTest() {
context.getOrLoadDialect<mlir::spirv::SPIRVDialect>();
- createModuleOp();
+ initModuleOp();
}
- void createModuleOp() {
+ /// Initializes an empty SPIR-V module op.
+ void initModuleOp() {
OpBuilder builder(&context);
OperationState state(UnknownLoc::get(&context),
spirv::ModuleOp::getOperationName());
@@ -58,27 +59,29 @@ class SerializationTest : public ::testing::Test {
module = cast<spirv::ModuleOp>(Operation::create(state));
}
- Type getFloatStructType() {
- OpBuilder opBuilder(module->getRegion());
- llvm::SmallVector<Type, 1> elementTypes{opBuilder.getF32Type()};
+ /// Gets the `struct { float }` type.
+ spirv::StructType getFloatStructType() {
+ OpBuilder builder(module->getRegion());
+ llvm::SmallVector<Type, 1> elementTypes{builder.getF32Type()};
llvm::SmallVector<spirv::StructType::OffsetInfo, 1> offsetInfo{0};
- auto structType = spirv::StructType::get(elementTypes, offsetInfo);
- return structType;
+ return spirv::StructType::get(elementTypes, offsetInfo);
}
- void addGlobalVar(Type type, llvm::StringRef name) {
- OpBuilder opBuilder(module->getRegion());
+ /// Inserts a global variable of the given `type` and `name`.
+ spirv::GlobalVariableOp addGlobalVar(Type type, llvm::StringRef name) {
+ OpBuilder builder(module->getRegion());
auto ptrType = spirv::PointerType::get(type, spirv::StorageClass::Uniform);
- opBuilder.create<spirv::GlobalVariableOp>(
+ return builder.create<spirv::GlobalVariableOp>(
UnknownLoc::get(&context), TypeAttr::get(ptrType),
- opBuilder.getStringAttr(name), nullptr);
+ builder.getStringAttr(name), nullptr);
}
+ /// Returns true if we can find a matching instruction in the SPIR-V blob.
bool findInstruction(llvm::function_ref<bool(spirv::Opcode opcode,
ArrayRef<uint32_t> operands)>
matchFn) {
auto binarySize = binary.size();
- auto begin = binary.begin();
+ auto *begin = binary.begin();
auto currOffset = spirv::kHeaderWordCount;
while (currOffset < binarySize) {
@@ -109,10 +112,12 @@ class SerializationTest : public ::testing::Test {
// Block decoration
//===----------------------------------------------------------------------===//
-TEST_F(SerializationTest, BlockDecorationTest) {
+TEST_F(SerializationTest, ContainsBlockDecoration) {
auto structType = getFloatStructType();
addGlobalVar(structType, "var0");
+
ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary)));
+
auto hasBlockDecoration = [](spirv::Opcode opcode,
ArrayRef<uint32_t> operands) -> bool {
if (opcode != spirv::Opcode::OpDecorate || operands.size() != 2)
@@ -121,3 +126,35 @@ TEST_F(SerializationTest, BlockDecorationTest) {
};
EXPECT_TRUE(findInstruction(hasBlockDecoration));
}
+
+TEST_F(SerializationTest, ContainsSymbolName) {
+ auto structType = getFloatStructType();
+ addGlobalVar(structType, "var0");
+
+ spirv::SerializationOptions options;
+ options.emitSymbolName = true;
+ ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary, options)));
+
+ auto hasVarName = [](spirv::Opcode opcode, ArrayRef<uint32_t> operands) {
+ unsigned index = 1; // Skip the result <id>
+ return opcode == spirv::Opcode::OpName &&
+ spirv::decodeStringLiteral(operands, index) == "var0";
+ };
+ EXPECT_TRUE(findInstruction(hasVarName));
+}
+
+TEST_F(SerializationTest, DoesNotContainSymbolName) {
+ auto structType = getFloatStructType();
+ addGlobalVar(structType, "var0");
+
+ spirv::SerializationOptions options;
+ options.emitSymbolName = false;
+ ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary, options)));
+
+ auto hasVarName = [](spirv::Opcode opcode, ArrayRef<uint32_t> operands) {
+ unsigned index = 1; // Skip the result <id>
+ return opcode == spirv::Opcode::OpName &&
+ spirv::decodeStringLiteral(operands, index) == "var0";
+ };
+ EXPECT_FALSE(findInstruction(hasVarName));
+}
More information about the Mlir-commits
mailing list