[Mlir-commits] [mlir] b344939 - [mlir:Bytecode][NFC] Cleanup Attribute/Type reading

River Riddle llvmlistbot at llvm.org
Tue Aug 23 16:56:38 PDT 2022


Author: River Riddle
Date: 2022-08-23T16:56:03-07:00
New Revision: b3449392f55739ec950b038f6a81e5c0a5ef0cae

URL: https://github.com/llvm/llvm-project/commit/b3449392f55739ec950b038f6a81e5c0a5ef0cae
DIFF: https://github.com/llvm/llvm-project/commit/b3449392f55739ec950b038f6a81e5c0a5ef0cae.diff

LOG: [mlir:Bytecode][NFC] Cleanup Attribute/Type reading

This moves some parsing functionality from BytecodeReader to
AttrTypeReader, and removes some duplication between the attribute/type
code paths.

Differential Revision: https://reviews.llvm.org/D132497

Added: 
    

Modified: 
    mlir/lib/Bytecode/Reader/BytecodeReader.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index 75ef8bcc8021..010e4e492fa6 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -411,18 +411,50 @@ class AttrTypeReader {
   }
   Type resolveType(size_t index) { return resolveEntry(types, index, "Type"); }
 
+  /// Parse a reference to an attribute or type using the given reader.
+  LogicalResult parseAttribute(EncodingReader &reader, Attribute &result) {
+    uint64_t attrIdx;
+    if (failed(reader.parseVarInt(attrIdx)))
+      return failure();
+    result = resolveAttribute(attrIdx);
+    return success(!!result);
+  }
+  LogicalResult parseType(EncodingReader &reader, Type &result) {
+    uint64_t typeIdx;
+    if (failed(reader.parseVarInt(typeIdx)))
+      return failure();
+    result = resolveType(typeIdx);
+    return success(!!result);
+  }
+
+  template <typename T>
+  LogicalResult parseAttribute(EncodingReader &reader, T &result) {
+    Attribute baseResult;
+    if (failed(parseAttribute(reader, baseResult)))
+      return failure();
+    if ((result = baseResult.dyn_cast<T>()))
+      return success();
+    return reader.emitError("expected attribute of type: ",
+                            llvm::getTypeName<T>(), ", but got: ", baseResult);
+  }
+
 private:
   /// Resolve the given entry at `index`.
   template <typename T>
   T resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index,
                  StringRef entryType);
 
-  /// Parse the value defined within the given reader. `code` indicates how the
-  /// entry was encoded.
-  LogicalResult parseEntry(EncodingReader &reader, bool hasCustomEncoding,
-                           Attribute &result);
-  LogicalResult parseEntry(EncodingReader &reader, bool hasCustomEncoding,
-                           Type &result);
+  /// Parse an entry using the given reader that was encoded using the textual
+  /// assembly format.
+  template <typename T>
+  LogicalResult parseAsmEntry(T &result, EncodingReader &reader,
+                              StringRef entryType);
+
+  /// Parse an entry using the given reader that was encoded using a custom
+  /// bytecode format.
+  template <typename T>
+  LogicalResult parseCustomEntry(Entry<T> &entry, EncodingReader &reader,
+                                 StringRef entryType);
 
   /// The set of attribute and type entries.
   SmallVector<AttrEntry> attributes;
@@ -506,8 +538,15 @@ T AttrTypeReader::resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index,
 
   // Parse the entry.
   EncodingReader reader(entry.data, fileLoc);
-  if (failed(parseEntry(reader, entry.hasCustomEncoding, entry.entry)))
+
+  // Parse based on how the entry was encoded.
+  if (entry.hasCustomEncoding) {
+    if (failed(parseCustomEntry(entry, reader, entryType)))
+      return T();
+  } else if (failed(parseAsmEntry(entry.entry, reader, entryType))) {
     return T();
+  }
+
   if (!reader.empty()) {
     (void)reader.emitError("unexpected trailing bytes after " + entryType +
                            " entry");
@@ -516,51 +555,37 @@ T AttrTypeReader::resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index,
   return entry.entry;
 }
 
-LogicalResult AttrTypeReader::parseEntry(EncodingReader &reader,
-                                         bool hasCustomEncoding,
-                                         Attribute &result) {
-  // Handle the fallback case, where the attribute was encoded using its
-  // assembly format.
-  if (!hasCustomEncoding) {
-    StringRef attrStr;
-    if (failed(reader.parseNullTerminatedString(attrStr)))
-      return failure();
-
-    size_t numRead = 0;
-    if (!(result = parseAttribute(attrStr, fileLoc->getContext(), numRead)))
-      return failure();
-    if (numRead != attrStr.size()) {
-      return reader.emitError(
-          "trailing characters found after Attribute assembly format: ",
-          attrStr.drop_front(numRead));
-    }
-    return success();
-  }
-
-  return reader.emitError("unexpected Attribute encoding");
-}
+template <typename T>
+LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader,
+                                            StringRef entryType) {
+  StringRef asmStr;
+  if (failed(reader.parseNullTerminatedString(asmStr)))
+    return failure();
 
-LogicalResult AttrTypeReader::parseEntry(EncodingReader &reader,
-                                         bool hasCustomEncoding, Type &result) {
-  // Handle the fallback case, where the type was encoded using its
-  // assembly format.
-  if (!hasCustomEncoding) {
-    StringRef typeStr;
-    if (failed(reader.parseNullTerminatedString(typeStr)))
-      return failure();
+  // Invoke the MLIR assembly parser to parse the entry text.
+  size_t numRead = 0;
+  MLIRContext *context = fileLoc->getContext();
+  if constexpr (std::is_same_v<T, Type>)
+    result = ::parseType(asmStr, context, numRead);
+  else
+    result = ::parseAttribute(asmStr, context, numRead);
+  if (!result)
+    return failure();
 
-    size_t numRead = 0;
-    if (!(result = parseType(typeStr, fileLoc->getContext(), numRead)))
-      return failure();
-    if (numRead != typeStr.size()) {
-      return reader.emitError(
-          "trailing characters found after Type assembly format: " +
-          typeStr.drop_front(numRead));
-    }
-    return success();
+  // Ensure there weren't dangling characters after the entry.
+  if (numRead != asmStr.size()) {
+    return reader.emitError("trailing characters found after ", entryType,
+                            " assembly format: ", asmStr.drop_front(numRead));
   }
+  return success();
+}
 
-  return reader.emitError("unexpected Type encoding");
+template <typename T>
+LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry,
+                                               EncodingReader &reader,
+                                               StringRef entryType) {
+  // FIXME: Add support for reading custom attribute/type encodings.
+  return reader.emitError("unexpected ", entryType, " encoding");
 }
 
 //===----------------------------------------------------------------------===//
@@ -600,20 +625,13 @@ class BytecodeReader {
   //===--------------------------------------------------------------------===//
   // Attribute/Type Section
 
-  /// Parse an attribute or type using the given reader. Returns nullptr in the
-  /// case of failure.
-  Attribute parseAttribute(EncodingReader &reader);
-  Type parseType(EncodingReader &reader);
-
+  /// Parse an attribute or type using the given reader.
   template <typename T>
-  T parseAttribute(EncodingReader &reader) {
-    if (Attribute attr = parseAttribute(reader)) {
-      if (auto derivedAttr = attr.dyn_cast<T>())
-        return derivedAttr;
-      (void)reader.emitError("expected attribute of type: ",
-                             llvm::getTypeName<T>(), ", but got: ", attr);
-    }
-    return T();
+  LogicalResult parseAttribute(EncodingReader &reader, T &result) {
+    return attrTypeReader.parseAttribute(reader, result);
+  }
+  LogicalResult parseType(EncodingReader &reader, Type &result) {
+    return attrTypeReader.parseType(reader, result);
   }
 
   //===--------------------------------------------------------------------===//
@@ -863,23 +881,6 @@ FailureOr<OperationName> BytecodeReader::parseOpName(EncodingReader &reader) {
   return *opName->opName;
 }
 
-//===----------------------------------------------------------------------===//
-// Attribute/Type Section
-
-Attribute BytecodeReader::parseAttribute(EncodingReader &reader) {
-  uint64_t attrIdx;
-  if (failed(reader.parseVarInt(attrIdx)))
-    return Attribute();
-  return attrTypeReader.resolveAttribute(attrIdx);
-}
-
-Type BytecodeReader::parseType(EncodingReader &reader) {
-  uint64_t typeIdx;
-  if (failed(reader.parseVarInt(typeIdx)))
-    return Type();
-  return attrTypeReader.resolveType(typeIdx);
-}
-
 //===----------------------------------------------------------------------===//
 // IR Section
 
@@ -996,8 +997,8 @@ BytecodeReader::parseOpWithoutRegions(EncodingReader &reader,
     return failure();
 
   /// Parse the location.
-  LocationAttr opLoc = parseAttribute<LocationAttr>(reader);
-  if (!opLoc)
+  LocationAttr opLoc;
+  if (failed(parseAttribute(reader, opLoc)))
     return failure();
 
   // With the location and name resolved, we can start building the operation
@@ -1006,8 +1007,8 @@ BytecodeReader::parseOpWithoutRegions(EncodingReader &reader,
 
   // Parse the attributes of the operation.
   if (opMask & bytecode::OpEncodingMask::kHasAttrs) {
-    DictionaryAttr dictAttr = parseAttribute<DictionaryAttr>(reader);
-    if (!dictAttr)
+    DictionaryAttr dictAttr;
+    if (failed(parseAttribute(reader, dictAttr)))
       return failure();
     opState.attributes = dictAttr;
   }
@@ -1019,7 +1020,7 @@ BytecodeReader::parseOpWithoutRegions(EncodingReader &reader,
       return failure();
     opState.types.resize(numResults);
     for (int i = 0, e = numResults; i < e; ++i)
-      if (!(opState.types[i] = parseType(reader)))
+      if (failed(parseType(reader, opState.types[i])))
         return failure();
   }
 
@@ -1130,11 +1131,10 @@ LogicalResult BytecodeReader::parseBlockArguments(EncodingReader &reader,
   argLocs.reserve(numArgs);
 
   while (numArgs--) {
-    Type argType = parseType(reader);
-    if (!argType)
-      return failure();
-    LocationAttr argLoc = parseAttribute<LocationAttr>(reader);
-    if (!argLoc)
+    Type argType;
+    LocationAttr argLoc;
+    if (failed(parseType(reader, argType)) ||
+        failed(parseAttribute(reader, argLoc)))
       return failure();
 
     argTypes.push_back(argType);


        


More information about the Mlir-commits mailing list