[Mlir-commits] [mlir] b289266 - [mlir][spirv] Add serialization control to emit symbol name

Lei Zhang llvmlistbot at llvm.org
Fri Dec 10 16:21:07 PST 2021


Author: Lei Zhang
Date: 2021-12-10T19:20:49-05:00
New Revision: b289266cb2398cebca0ba1ecfc6242470cc7cedd

URL: https://github.com/llvm/llvm-project/commit/b289266cb2398cebca0ba1ecfc6242470cc7cedd
DIFF: https://github.com/llvm/llvm-project/commit/b289266cb2398cebca0ba1ecfc6242470cc7cedd.diff

LOG: [mlir][spirv] Add serialization control to emit symbol name

In SPIR-V, symbol names are encoded as `OpName` instructions.
They do not affect semantics and can be omitted to reduce the
binary size.

Reviewed By: scotttodd

Differential Revision: https://reviews.llvm.org/D115531

Added: 
    

Modified: 
    mlir/include/mlir/Target/SPIRV/SPIRVBinaryUtils.h
    mlir/include/mlir/Target/SPIRV/Serialization.h
    mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
    mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
    mlir/lib/Target/SPIRV/Serialization/Serialization.cpp
    mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
    mlir/lib/Target/SPIRV/Serialization/Serializer.h
    mlir/lib/Target/SPIRV/TranslateRegistration.cpp
    mlir/unittests/Dialect/SPIRV/SerializationTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Target/SPIRV/SPIRVBinaryUtils.h b/mlir/include/mlir/Target/SPIRV/SPIRVBinaryUtils.h
index dc2b8c7edec42..a155b1989e63b 100644
--- a/mlir/include/mlir/Target/SPIRV/SPIRVBinaryUtils.h
+++ b/mlir/include/mlir/Target/SPIRV/SPIRVBinaryUtils.h
@@ -13,8 +13,8 @@
 #ifndef MLIR_TARGET_SPIRV_BINARY_UTILS_H_
 #define MLIR_TARGET_SPIRV_BINARY_UTILS_H_
 
-#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
-#include "mlir/Support/LogicalResult.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
+#include "mlir/Support/LLVM.h"
 
 #include <cstdint>
 
@@ -41,6 +41,16 @@ uint32_t getPrefixedOpcode(uint32_t wordCount, spirv::Opcode opcode);
 /// Encodes an SPIR-V `literal` string into the given `binary` vector.
 LogicalResult encodeStringLiteralInto(SmallVectorImpl<uint32_t> &binary,
                                       StringRef literal);
+
+/// Decodes the NUL-terminated string literal starting at `words[wordIndex]`
+/// and advances `wordIndex` past the words it occupies (bounds unchecked).
+inline StringRef decodeStringLiteral(ArrayRef<uint32_t> words,
+                                     unsigned &wordIndex) {
+  StringRef str(reinterpret_cast<const char *>(words.data() + wordIndex));
+  wordIndex += str.size() / 4 + 1;
+  return str;
+}
+
 } // namespace spirv
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Target/SPIRV/Serialization.h b/mlir/include/mlir/Target/SPIRV/Serialization.h
index 25033e2c4f169..498f390148793 100644
--- a/mlir/include/mlir/Target/SPIRV/Serialization.h
+++ b/mlir/include/mlir/Target/SPIRV/Serialization.h
@@ -22,11 +22,18 @@ class MLIRContext;
 namespace spirv {
 class ModuleOp;
 
+struct SerializationOptions {
+  /// Whether to emit `OpName` instructions for SPIR-V symbol ops.
+  bool emitSymbolName = true;
+  /// Whether to emit `OpLine` location information for SPIR-V ops.
+  bool emitDebugInfo = false;
+};
+
 /// Serializes the given SPIR-V `module` and writes to `binary`. On failure,
 /// reports errors to the error handler registered with the MLIR context for
 /// `module`.
 LogicalResult serialize(ModuleOp module, SmallVectorImpl<uint32_t> &binary,
-                        bool emitDebugInfo = false);
+                        const SerializationOptions &options = {});
 
 } // namespace spirv
 } // namespace mlir

diff  --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index c01362d1e2d4c..1c77dfdddcaa5 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Location.h"
+#include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/Debug.h"

diff  --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index 17060dddc9198..8351f3a41fa17 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -21,19 +21,6 @@
 #include "llvm/ADT/StringRef.h"
 #include <cstdint>
 
-//===----------------------------------------------------------------------===//
-// Utility Functions
-//===----------------------------------------------------------------------===//
-
-/// Decodes a string literal in `words` starting at `wordIndex`. Update the
-/// latter to point to the position in words after the string literal.
-static inline llvm::StringRef
-decodeStringLiteral(llvm::ArrayRef<uint32_t> words, unsigned &wordIndex) {
-  llvm::StringRef str(reinterpret_cast<const char *>(words.data() + wordIndex));
-  wordIndex += str.size() / 4 + 1;
-  return str;
-}
-
 namespace mlir {
 namespace spirv {
 

diff  --git a/mlir/lib/Target/SPIRV/Serialization/Serialization.cpp b/mlir/lib/Target/SPIRV/Serialization/Serialization.cpp
index 33b886b6d369c..7d4d118a6783e 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serialization.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serialization.cpp
@@ -23,12 +23,12 @@
 namespace mlir {
 LogicalResult spirv::serialize(spirv::ModuleOp module,
                                SmallVectorImpl<uint32_t> &binary,
-                               bool emitDebugInfo) {
+                               const SerializationOptions &options) {
   if (!module.vce_triple().hasValue())
     return module.emitError(
         "module must have 'vce_triple' attribute to be serializeable");
 
-  Serializer serializer(module, emitDebugInfo);
+  Serializer serializer(module, options);
 
   if (failed(serializer.serialize()))
     return failure();

diff  --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 6bd5ff1629b25..bcead6e527d5e 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -81,9 +81,9 @@ LogicalResult encodeInstructionInto(SmallVectorImpl<uint32_t> &binary,
   return success();
 }
 
-Serializer::Serializer(spirv::ModuleOp module, bool emitDebugInfo)
-    : module(module), mlirBuilder(module.getContext()),
-      emitDebugInfo(emitDebugInfo) {}
+Serializer::Serializer(spirv::ModuleOp module,
+                       const SerializationOptions &options)
+    : module(module), mlirBuilder(module.getContext()), options(options) {}
 
 LogicalResult Serializer::serialize() {
   LLVM_DEBUG(llvm::dbgs() << "+++ starting serialization +++\n");
@@ -172,7 +172,7 @@ void Serializer::processCapability() {
 }
 
 void Serializer::processDebugInfo() {
-  if (!emitDebugInfo)
+  if (!options.emitDebugInfo)
     return;
   auto fileLoc = module.getLoc().dyn_cast<FileLineColLoc>();
   auto fileName = fileLoc ? fileLoc.getFilename().strref() : "<unknown>";
@@ -254,12 +254,13 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
 
 LogicalResult Serializer::processName(uint32_t resultID, StringRef name) {
   assert(!name.empty() && "unexpected empty string for OpName");
+  if (!options.emitSymbolName)
+    return success();
 
   SmallVector<uint32_t, 4> nameOperands;
   nameOperands.push_back(resultID);
-  if (failed(spirv::encodeStringLiteralInto(nameOperands, name))) {
+  if (failed(spirv::encodeStringLiteralInto(nameOperands, name)))
     return failure();
-  }
   return encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands);
 }
 
@@ -1170,7 +1171,7 @@ LogicalResult Serializer::emitDecoration(uint32_t target,
 
 LogicalResult Serializer::emitDebugLine(SmallVectorImpl<uint32_t> &binary,
                                         Location loc) {
-  if (!emitDebugInfo)
+  if (!options.emitDebugInfo)
     return success();
 
   if (lastProcessedWasMergeInst) {

diff  --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.h b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
index 9ee288a042b4f..5f4a4e999a673 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.h
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
@@ -15,6 +15,7 @@
 
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/Target/SPIRV/Serialization.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/raw_ostream.h"
@@ -42,7 +43,8 @@ LogicalResult encodeInstructionInto(SmallVectorImpl<uint32_t> &binary,
 class Serializer {
 public:
   /// Creates a serializer for the given SPIR-V `module`.
-  explicit Serializer(spirv::ModuleOp module, bool emitDebugInfo = false);
+  explicit Serializer(spirv::ModuleOp module,
+                      const SerializationOptions &options);
 
   /// Serializes the remembered SPIR-V module.
   LogicalResult serialize();
@@ -316,8 +318,8 @@ class Serializer {
   /// An MLIR builder for getting MLIR constructs.
   mlir::Builder mlirBuilder;
 
-  /// A flag which indicates if the debuginfo should be emitted.
-  bool emitDebugInfo = false;
+  /// Serialization options.
+  SerializationOptions options;
 
   /// A flag which indicates if the last processed instruction was a merge
   /// instruction.

diff  --git a/mlir/lib/Target/SPIRV/TranslateRegistration.cpp b/mlir/lib/Target/SPIRV/TranslateRegistration.cpp
index 989de4142d761..e63a68c6ed57b 100644
--- a/mlir/lib/Target/SPIRV/TranslateRegistration.cpp
+++ b/mlir/lib/Target/SPIRV/TranslateRegistration.cpp
@@ -40,7 +40,7 @@ static OwningModuleRef deserializeModule(const llvm::MemoryBuffer *input,
   context->loadDialect<spirv::SPIRVDialect>();
 
   // Make sure the input stream can be treated as a stream of SPIR-V words
-  auto start = input->getBufferStart();
+  auto *start = input->getBufferStart();
   auto size = input->getBufferSize();
   if (size % sizeof(uint32_t) != 0) {
     emitError(UnknownLoc::get(context))
@@ -94,8 +94,7 @@ static LogicalResult serializeModule(ModuleOp module, raw_ostream &output) {
   if (spirvModules.size() != 1)
     return module.emitError("found more than one 'spv.module' op");
 
-  if (failed(
-          spirv::serialize(spirvModules[0], binary, /*emitDebuginfo=*/false)))
+  if (failed(spirv::serialize(spirvModules[0], binary)))
     return failure();
 
   output.write(reinterpret_cast<char *>(binary.data()),
@@ -133,7 +132,9 @@ static LogicalResult roundTripModule(ModuleOp srcModule, bool emitDebugInfo,
   if (std::next(spirvModules.begin()) != spirvModules.end())
     return srcModule.emitError("found more than one 'spv.module' op");
 
-  if (failed(spirv::serialize(*spirvModules.begin(), binary, emitDebugInfo)))
+  spirv::SerializationOptions options;
+  options.emitDebugInfo = emitDebugInfo;
+  if (failed(spirv::serialize(*spirvModules.begin(), binary, options)))
     return failure();
 
   MLIRContext deserializationContext(context->getDialectRegistry());

diff  --git a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
index a3ae8bc3a9e75..9222b0cd3654a 100644
--- a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
+++ b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
@@ -37,10 +37,11 @@ class SerializationTest : public ::testing::Test {
 protected:
   SerializationTest() {
     context.getOrLoadDialect<mlir::spirv::SPIRVDialect>();
-    createModuleOp();
+    initModuleOp();
   }
 
-  void createModuleOp() {
+  /// Initializes an empty SPIR-V module op.
+  void initModuleOp() {
     OpBuilder builder(&context);
     OperationState state(UnknownLoc::get(&context),
                          spirv::ModuleOp::getOperationName());
@@ -58,27 +59,29 @@ class SerializationTest : public ::testing::Test {
     module = cast<spirv::ModuleOp>(Operation::create(state));
   }
 
-  Type getFloatStructType() {
-    OpBuilder opBuilder(module->getRegion());
-    llvm::SmallVector<Type, 1> elementTypes{opBuilder.getF32Type()};
+  /// Gets the `struct { float }` type.
+  spirv::StructType getFloatStructType() {
+    OpBuilder builder(module->getRegion());
+    llvm::SmallVector<Type, 1> elementTypes{builder.getF32Type()};
     llvm::SmallVector<spirv::StructType::OffsetInfo, 1> offsetInfo{0};
-    auto structType = spirv::StructType::get(elementTypes, offsetInfo);
-    return structType;
+    return spirv::StructType::get(elementTypes, offsetInfo);
   }
 
-  void addGlobalVar(Type type, llvm::StringRef name) {
-    OpBuilder opBuilder(module->getRegion());
+  /// Inserts a global variable of the given `type` and `name`.
+  spirv::GlobalVariableOp addGlobalVar(Type type, llvm::StringRef name) {
+    OpBuilder builder(module->getRegion());
     auto ptrType = spirv::PointerType::get(type, spirv::StorageClass::Uniform);
-    opBuilder.create<spirv::GlobalVariableOp>(
+    return builder.create<spirv::GlobalVariableOp>(
         UnknownLoc::get(&context), TypeAttr::get(ptrType),
-        opBuilder.getStringAttr(name), nullptr);
+        builder.getStringAttr(name), nullptr);
   }
 
+  /// Returns true if we can find a matching instruction in the SPIR-V blob.
   bool findInstruction(llvm::function_ref<bool(spirv::Opcode opcode,
                                                ArrayRef<uint32_t> operands)>
                            matchFn) {
     auto binarySize = binary.size();
-    auto begin = binary.begin();
+    auto *begin = binary.begin();
     auto currOffset = spirv::kHeaderWordCount;
 
     while (currOffset < binarySize) {
@@ -109,10 +112,12 @@ class SerializationTest : public ::testing::Test {
 // Block decoration
 //===----------------------------------------------------------------------===//
 
-TEST_F(SerializationTest, BlockDecorationTest) {
+TEST_F(SerializationTest, ContainsBlockDecoration) {
   auto structType = getFloatStructType();
   addGlobalVar(structType, "var0");
+
   ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary)));
+
   auto hasBlockDecoration = [](spirv::Opcode opcode,
                                ArrayRef<uint32_t> operands) -> bool {
     if (opcode != spirv::Opcode::OpDecorate || operands.size() != 2)
@@ -121,3 +126,35 @@ TEST_F(SerializationTest, BlockDecorationTest) {
   };
   EXPECT_TRUE(findInstruction(hasBlockDecoration));
 }
+
+TEST_F(SerializationTest, ContainsSymbolName) {
+  auto structType = getFloatStructType();
+  addGlobalVar(structType, "var0");
+
+  spirv::SerializationOptions options;
+  options.emitSymbolName = true;
+  ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary, options)));
+
+  auto hasVarName = [](spirv::Opcode opcode, ArrayRef<uint32_t> operands) {
+    unsigned index = 1; // Skip the result <id>
+    return opcode == spirv::Opcode::OpName &&
+           spirv::decodeStringLiteral(operands, index) == "var0";
+  };
+  EXPECT_TRUE(findInstruction(hasVarName));
+}
+
+TEST_F(SerializationTest, DoesNotContainSymbolName) {
+  auto structType = getFloatStructType();
+  addGlobalVar(structType, "var0");
+
+  spirv::SerializationOptions options;
+  options.emitSymbolName = false;
+  ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary, options)));
+
+  auto hasVarName = [](spirv::Opcode opcode, ArrayRef<uint32_t> operands) {
+    unsigned index = 1; // Skip the result <id>
+    return opcode == spirv::Opcode::OpName &&
+           spirv::decodeStringLiteral(operands, index) == "var0";
+  };
+  EXPECT_FALSE(findInstruction(hasVarName));
+}


        


More information about the Mlir-commits mailing list