[Mlir-commits] [mlir] 1826fad - [mlir][bytecode] Avoid recording null arglocs & realloc opnames.

Jacques Pienaar llvmlistbot at llvm.org
Thu May 25 09:24:58 PDT 2023


Author: Jacques Pienaar
Date: 2023-05-25T09:24:50-07:00
New Revision: 1826fadb0d2bc5b61d6c028d8006b2a7d1249ec0

URL: https://github.com/llvm/llvm-project/commit/1826fadb0d2bc5b61d6c028d8006b2a7d1249ec0
DIFF: https://github.com/llvm/llvm-project/commit/1826fadb0d2bc5b61d6c028d8006b2a7d1249ec0.diff

LOG: [mlir][bytecode] Avoid recording null arglocs & realloc opnames.

For block arg locs a common case is no/unknown location (where the producer
signifies they don't care about blockarg location). Also avoid needing to
dynamically resize opnames during parsing.

Assumed to land after the lazy loading change, so chose version 4.

Differential Revision: https://reviews.llvm.org/D151038

Added: 
    

Modified: 
    mlir/docs/BytecodeFormat.md
    mlir/include/mlir/Bytecode/Encoding.h
    mlir/lib/Bytecode/Reader/BytecodeReader.cpp
    mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
    mlir/test/Bytecode/general.mlir
    mlir/test/Bytecode/invalid/invalid-structure.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/docs/BytecodeFormat.md b/mlir/docs/BytecodeFormat.md
index ca04d8ccbe267..589671ee6df3e 100644
--- a/mlir/docs/BytecodeFormat.md
+++ b/mlir/docs/BytecodeFormat.md
@@ -154,6 +154,7 @@ dialects that were also referenced.
 dialect_section {
   numDialects: varint,
   dialectNames: varint[],
+  numTotalOpNames: varint,
   opNames: op_name_group[]
 }
 
@@ -444,8 +445,8 @@ block_arguments {
 }
 
 block_argument {
-  typeIndex: varint,
-  location: varint
+  typeAndLocation: varint, // (type << 1) | (hasLocation)
+  location: varint? // Optional, else unknown location
 }
 ```
 

diff  --git a/mlir/include/mlir/Bytecode/Encoding.h b/mlir/include/mlir/Bytecode/Encoding.h
index 7ffbcfad1ccc0..a94bd50510e99 100644
--- a/mlir/include/mlir/Bytecode/Encoding.h
+++ b/mlir/include/mlir/Bytecode/Encoding.h
@@ -29,7 +29,7 @@ enum {
   kMinSupportedVersion = 0,
 
   /// The current bytecode version.
-  kVersion = 3,
+  kVersion = 4,
 
   /// An arbitrary value used to fill alignment padding.
   kAlignmentByte = 0xCB,

diff  --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index 8ff48ad72d0bf..ca05eac1e3e16 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -1603,6 +1603,14 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
     opNames.emplace_back(dialect, opName);
     return success();
   };
+  // Avoid re-allocation in bytecode version > 3, where the number of ops is
+  // known.
+  if (version > 3) {
+    uint64_t numOps;
+    if (failed(sectionReader.parseVarInt(numOps)))
+      return failure();
+    opNames.reserve(numOps);
+  }
   while (!sectionReader.empty())
     if (failed(parseDialectGrouping(sectionReader, dialects, parseOpName)))
       return failure();
@@ -2175,13 +2183,25 @@ LogicalResult BytecodeReader::Impl::parseBlockArguments(EncodingReader &reader,
   argTypes.reserve(numArgs);
   argLocs.reserve(numArgs);
 
+  Location unknownLoc = UnknownLoc::get(config.getContext());
   while (numArgs--) {
     Type argType;
-    LocationAttr argLoc;
-    if (failed(parseType(reader, argType)) ||
-        failed(parseAttribute(reader, argLoc)))
-      return failure();
-
+    LocationAttr argLoc = unknownLoc;
+    if (version > 3) {
+      // Parse the type with hasLoc flag to determine if it has a location.
+      uint64_t typeIdx;
+      bool hasLoc;
+      if (failed(reader.parseVarIntWithFlag(typeIdx, hasLoc)) ||
+          !(argType = attrTypeReader.resolveType(typeIdx)))
+        return failure();
+      if (hasLoc && failed(parseAttribute(reader, argLoc)))
+        return failure();
+    } else {
+      // All args have a type and a location.
+      if (failed(parseType(reader, argType)) ||
+          failed(parseAttribute(reader, argLoc)))
+        return failure();
+    }
     argTypes.push_back(argType);
     argLocs.push_back(argLoc);
   }

diff  --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
index c67437f317396..93484913548ab 100644
--- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -585,6 +585,9 @@ void BytecodeWriter::writeDialectSection(EncodingEmitter &emitter) {
                                  std::move(versionEmitter));
   }
 
+  if (config.bytecodeVersion > 3)
+    dialectEmitter.emitVarInt(size(numberingState.getOpNames()));
+
   // Emit the referenced operation names grouped by dialect.
   auto emitOpName = [&](OpNameNumbering &name) {
     dialectEmitter.emitVarInt(stringSection.insert(name.name.stripDialect()));
@@ -670,8 +673,16 @@ void BytecodeWriter::writeBlock(EncodingEmitter &emitter, Block *block) {
   if (hasArgs) {
     emitter.emitVarInt(args.size());
     for (BlockArgument arg : args) {
-      emitter.emitVarInt(numberingState.getNumber(arg.getType()));
-      emitter.emitVarInt(numberingState.getNumber(arg.getLoc()));
+      Location argLoc = arg.getLoc();
+      if (config.bytecodeVersion > 3) {
+        emitter.emitVarIntWithFlag(numberingState.getNumber(arg.getType()),
+                                   !isa<UnknownLoc>(argLoc));
+        if (!isa<UnknownLoc>(argLoc))
+          emitter.emitVarInt(numberingState.getNumber(argLoc));
+      } else {
+        emitter.emitVarInt(numberingState.getNumber(arg.getType()));
+        emitter.emitVarInt(numberingState.getNumber(argLoc));
+      }
     }
     if (config.bytecodeVersion > 2) {
       uint64_t maskOffset = emitter.size();
@@ -755,7 +766,7 @@ void BytecodeWriter::writeOp(EncodingEmitter &emitter, Operation *op) {
 
     for (Region &region : op->getRegions()) {
       // If the region is not isolated from above, or we are emitting bytecode
-      // targetting version <2, we don't use a section.
+      // targeting version <2, we don't use a section.
       if (!isIsolatedFromAbove || config.bytecodeVersion < 2) {
         writeRegion(emitter, &region);
         continue;

diff  --git a/mlir/test/Bytecode/general.mlir b/mlir/test/Bytecode/general.mlir
index 071b4c36526cd..180fb930737a0 100644
--- a/mlir/test/Bytecode/general.mlir
+++ b/mlir/test/Bytecode/general.mlir
@@ -32,7 +32,7 @@
   }
   "bytecode.branch"()[^secondBlock] : () -> ()
 
-^secondBlock(%arg1: i32, %arg2: !bytecode.int, %arg3: !pdl.operation):
+^secondBlock(%arg1: i32 loc(unknown), %arg2: !bytecode.int, %arg3: !pdl.operation loc(unknown)):
   "bytecode.regions"() ({
     "bytecode.operands"(%arg1, %arg2, %arg3) : (i32, !bytecode.int, !pdl.operation) -> ()
     "bytecode.return"() : () -> ()

diff  --git a/mlir/test/Bytecode/invalid/invalid-structure.mlir b/mlir/test/Bytecode/invalid/invalid-structure.mlir
index ae18cfaff687c..1d2ed4833083e 100644
--- a/mlir/test/Bytecode/invalid/invalid-structure.mlir
+++ b/mlir/test/Bytecode/invalid/invalid-structure.mlir
@@ -9,7 +9,7 @@
 //===--------------------------------------------------------------------===//
 
 // RUN: not mlir-opt %S/invalid-structure-version.mlirbc 2>&1 | FileCheck %s --check-prefix=VERSION
-// VERSION: bytecode version 127 is newer than the current version 3
+// VERSION: bytecode version 127 is newer than the current version
 
 //===--------------------------------------------------------------------===//
 // Producer


        


More information about the Mlir-commits mailing list