[Mlir-commits] [mlir] 83dc999 - [mlir:Bytecode][NFC] Refactor string section writing and reading

River Riddle llvmlistbot at llvm.org
Tue Aug 23 16:56:36 PDT 2022


Author: River Riddle
Date: 2022-08-23T16:56:03-07:00
New Revision: 83dc9999486fb3fb7e11d684312d034128d1b050

URL: https://github.com/llvm/llvm-project/commit/83dc9999486fb3fb7e11d684312d034128d1b050
DIFF: https://github.com/llvm/llvm-project/commit/83dc9999486fb3fb7e11d684312d034128d1b050.diff

LOG: [mlir:Bytecode][NFC] Refactor string section writing and reading

This extracts the string section writer and reader into dedicated
classes, which better separates the logic and will also simplify future
patches that want to interact with the string section.

Differential Revision: https://reviews.llvm.org/D132496

Added: 
    

Modified: 
    mlir/lib/Bytecode/Reader/BytecodeReader.cpp
    mlir/lib/Bytecode/Writer/BytecodeWriter.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index 7df545871842b..75ef8bcc80211 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -240,6 +240,69 @@ static LogicalResult parseEntry(EncodingReader &reader, RangeT &entries,
   return resolveEntry(reader, entries, entryIdx, entry, entryStr);
 }
 
+//===----------------------------------------------------------------------===//
+// StringSectionReader
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This class resolves references to the string section of the bytecode. The
+/// parsed strings point into the section data, which must outlive their uses.
+class StringSectionReader {
+public:
+  /// Populate the string table from the raw bytes of the string section.
+  LogicalResult initialize(Location fileLoc, ArrayRef<uint8_t> sectionData);
+
+  /// Parse a shared string from the string section. The shared string is
+  /// encoded using an index to a corresponding string in the string section.
+  LogicalResult parseString(EncodingReader &reader, StringRef &result) {
+    return parseEntry(reader, strings, result, "string");
+  }
+
+private:
+  /// The strings of the section, indexed by their position in the section.
+  SmallVector<StringRef> strings;
+};
+} // namespace
+
+LogicalResult StringSectionReader::initialize(Location fileLoc,
+                                              ArrayRef<uint8_t> sectionData) {
+  EncodingReader stringReader(sectionData, fileLoc);
+
+  // Parse the number of strings, which can't exceed the section's byte size.
+  uint64_t numStrings;
+  if (failed(stringReader.parseVarInt(numStrings)))
+    return failure();
+  if (numStrings > sectionData.size())
+    return stringReader.emitError("invalid number of strings in section");
+  strings.resize(numStrings);
+
+  // The sizes are encoded in reverse order, so populate the table in reverse.
+  size_t stringDataEndOffset = sectionData.size();
+  for (StringRef &string : llvm::reverse(strings)) {
+    uint64_t stringSize;
+    if (failed(stringReader.parseVarInt(stringSize)))
+      return failure();
+    if (stringSize == 0)
+      return stringReader.emitError("string size must include a terminator");
+    if (stringDataEndOffset < stringSize)
+      return stringReader.emitError(
+          "string size exceeds the available data size");
+
+    // Extract the string, dropping the null character (size is non-zero).
+    size_t stringOffset = stringDataEndOffset - stringSize;
+    string = StringRef(
+        reinterpret_cast<const char *>(sectionData.data() + stringOffset),
+        stringSize - 1);
+    stringDataEndOffset = stringOffset;
+  }
+
+  // The reader should now be at the same offset as the first string's data.
+  if ((sectionData.size() - stringReader.size()) != stringDataEndOffset)
+    return stringReader.emitError("unexpected trailing data between the "
+                                  "offsets for strings and their data");
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // BytecodeDialect
 //===----------------------------------------------------------------------===//
@@ -595,17 +658,6 @@ class BytecodeReader {
   LogicalResult parseBlock(EncodingReader &reader, RegionReadState &readState);
   LogicalResult parseBlockArguments(EncodingReader &reader, Block *block);
 
-  //===--------------------------------------------------------------------===//
-  // String Section
-
-  LogicalResult parseStringSection(ArrayRef<uint8_t> sectionData);
-
-  /// Parse a shared string from the string section. The shared string is
-  /// encoded using an index to a corresponding string in the string section.
-  LogicalResult parseSharedString(EncodingReader &reader, StringRef &result) {
-    return parseEntry(reader, strings, result, "string");
-  }
-
   //===--------------------------------------------------------------------===//
   // Value Processing
 
@@ -667,7 +719,7 @@ class BytecodeReader {
   SmallVector<BytecodeOperationName> opNames;
 
   /// The table of strings referenced within the bytecode file.
-  SmallVector<StringRef> strings;
+  StringSectionReader stringReader;
 
   /// The current set of available IR value scopes.
   std::vector<ValueScope> valueScopes;
@@ -726,7 +778,8 @@ LogicalResult BytecodeReader::read(llvm::MemoryBufferRef buffer, Block *block) {
   }
 
   // Process the string section first.
-  if (failed(parseStringSection(*sectionDatas[bytecode::Section::kString])))
+  if (failed(stringReader.initialize(
+          fileLoc, *sectionDatas[bytecode::Section::kString])))
     return failure();
 
   // Process the dialect section.
@@ -777,13 +830,13 @@ BytecodeReader::parseDialectSection(ArrayRef<uint8_t> sectionData) {
 
   // Parse each of the dialects.
   for (uint64_t i = 0; i < numDialects; ++i)
-    if (failed(parseSharedString(sectionReader, dialects[i].name)))
+    if (failed(stringReader.parseString(sectionReader, dialects[i].name)))
       return failure();
 
   // Parse the operation names, which are grouped by dialect.
   auto parseOpName = [&](BytecodeDialect *dialect) {
     StringRef opName;
-    if (failed(parseSharedString(sectionReader, opName)))
+    if (failed(stringReader.parseString(sectionReader, opName)))
       return failure();
     opNames.emplace_back(dialect, opName);
     return success();
@@ -1091,51 +1144,6 @@ LogicalResult BytecodeReader::parseBlockArguments(EncodingReader &reader,
   return defineValues(reader, block->getArguments());
 }
 
-//===----------------------------------------------------------------------===//
-// String Section
-
-LogicalResult
-BytecodeReader::parseStringSection(ArrayRef<uint8_t> sectionData) {
-  EncodingReader stringReader(sectionData, fileLoc);
-
-  // Parse the number of strings in the section.
-  uint64_t numStrings;
-  if (failed(stringReader.parseVarInt(numStrings)))
-    return failure();
-  strings.resize(numStrings);
-
-  // Parse each of the strings. The sizes of the strings are encoded in reverse
-  // order, so that's the order we populate the table.
-  size_t stringDataEndOffset = sectionData.size();
-  size_t totalStringDataSize = 0;
-  for (StringRef &string : llvm::reverse(strings)) {
-    uint64_t stringSize;
-    if (failed(stringReader.parseVarInt(stringSize)))
-      return failure();
-    if (stringDataEndOffset < stringSize) {
-      return stringReader.emitError(
-          "string size exceeds the available data size");
-    }
-
-    // Extract the string from the data, dropping the null character.
-    size_t stringOffset = stringDataEndOffset - stringSize;
-    string = StringRef(
-        reinterpret_cast<const char *>(sectionData.data() + stringOffset),
-        stringSize - 1);
-    stringDataEndOffset = stringOffset;
-
-    // Update the total string data size.
-    totalStringDataSize += stringSize;
-  }
-
-  // Check that the only remaining data was for the strings
-  if (stringReader.size() != totalStringDataSize) {
-    return stringReader.emitError("unexpected trailing data between the "
-                                  "offsets for strings and their data");
-  }
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // Value Processing
 

diff  --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
index 8bf37c18030bf..ebf827fde41b8 100644
--- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -196,6 +196,41 @@ void EncodingEmitter::emitMultiByteVarInt(uint64_t value) {
   emitBytes({reinterpret_cast<uint8_t *>(&value), sizeof(value)});
 }
 
+//===----------------------------------------------------------------------===//
+// StringSectionBuilder
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Builds the string section. Inserted strings are referenced, not copied.
+class StringSectionBuilder {
+public:
+  /// Add the given string to the string section if it is not already present,
+  /// and return the index of the string within the section.
+  size_t insert(StringRef str) {
+    auto it = strings.insert({llvm::CachedHashStringRef(str), strings.size()});
+    return it.first->second;
+  }
+
+  /// Write the string count, the reversed string sizes, then the string data.
+  void write(EncodingEmitter &emitter) const {
+    emitter.emitVarInt(strings.size());
+
+    // Emit the sizes in reverse order, so that we don't need to backpatch an
+    // offset to the string data or have a separate section.
+    for (const auto &it : llvm::reverse(strings))
+      emitter.emitVarInt(it.first.size() + 1);
+    // Emit the string data itself.
+    for (const auto &it : strings)
+      emitter.emitNulTerminatedString(it.first.val());
+  }
+
+private:
+  /// The strings referenced within the bytecode, in insertion order, mapped
+  /// to the index that `insert` returned for each of them.
+  llvm::MapVector<llvm::CachedHashStringRef, size_t> strings;
+};
+} // namespace
+
 //===----------------------------------------------------------------------===//
 // Bytecode Writer
 //===----------------------------------------------------------------------===//
@@ -232,19 +267,14 @@ class BytecodeWriter {
 
   void writeStringSection(EncodingEmitter &emitter);
 
-  /// Get the number for the given shared string, that is contained within the
-  /// string section.
-  size_t getSharedStringNumber(StringRef str);
-
   //===--------------------------------------------------------------------===//
   // Fields
 
+  /// The builder used for the string section.
+  StringSectionBuilder stringSection;
+
   /// The IR numbering state generated for the root operation.
   IRNumberingState numberingState;
-
-  /// A set of strings referenced within the bytecode. The value of the map is
-  /// unused.
-  llvm::MapVector<llvm::CachedHashStringRef, size_t> strings;
 };
 } // namespace
 
@@ -314,11 +344,11 @@ void BytecodeWriter::writeDialectSection(EncodingEmitter &emitter) {
   auto dialects = numberingState.getDialects();
   dialectEmitter.emitVarInt(llvm::size(dialects));
   for (DialectNumbering &dialect : dialects)
-    dialectEmitter.emitVarInt(getSharedStringNumber(dialect.name));
+    dialectEmitter.emitVarInt(stringSection.insert(dialect.name));
 
   // Emit the referenced operation names grouped by dialect.
   auto emitOpName = [&](OpNameNumbering &name) {
-    dialectEmitter.emitVarInt(getSharedStringNumber(name.name.stripDialect()));
+    dialectEmitter.emitVarInt(stringSection.insert(name.name.stripDialect()));
   };
   writeDialectGrouping(dialectEmitter, numberingState.getOpNames(), emitOpName);
 
@@ -491,24 +521,10 @@ void BytecodeWriter::writeIRSection(EncodingEmitter &emitter, Operation *op) {
 
 void BytecodeWriter::writeStringSection(EncodingEmitter &emitter) {
   EncodingEmitter stringEmitter;
-  stringEmitter.emitVarInt(strings.size());
-
-  // Emit the sizes in reverse order, so that we don't need to backpatch an
-  // offset to the string data or have a separate section.
-  for (const auto &it : llvm::reverse(strings))
-    stringEmitter.emitVarInt(it.first.size() + 1);
-  // Emit the string data itself.
-  for (const auto &it : strings)
-    stringEmitter.emitNulTerminatedString(it.first.val());
-
+  stringSection.write(stringEmitter);
   emitter.emitSection(bytecode::Section::kString, std::move(stringEmitter));
 }
 
-size_t BytecodeWriter::getSharedStringNumber(StringRef str) {
-  auto it = strings.insert({llvm::CachedHashStringRef(str), strings.size()});
-  return it.first->second;
-}
-
 //===----------------------------------------------------------------------===//
 // Entry Points
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list