[Mlir-commits] [mlir] cfd9093 - Fix MLIR bytecode loading of resources

Mehdi Amini llvmlistbot at llvm.org
Thu May 25 00:28:18 PDT 2023


Author: Mehdi Amini
Date: 2023-05-25T00:27:59-07:00
New Revision: cfd90939f70805f9824ac6b99bbe3ef4e50c7e8b

URL: https://github.com/llvm/llvm-project/commit/cfd90939f70805f9824ac6b99bbe3ef4e50c7e8b
DIFF: https://github.com/llvm/llvm-project/commit/cfd90939f70805f9824ac6b99bbe3ef4e50c7e8b.diff

LOG: Fix MLIR bytecode loading of resources

The bytecode reader didn't properly handle the case where resource names
conflicted and were renamed, leading to orphaned handles in the IR as well
as overwriting the existing resources.

Differential Revision: https://reviews.llvm.org/D151408

Added: 
    mlir/unittests/Bytecode/BytecodeTest.cpp
    mlir/unittests/Bytecode/CMakeLists.txt

Modified: 
    mlir/lib/AsmParser/Parser.cpp
    mlir/lib/Bytecode/Reader/BytecodeReader.cpp
    mlir/unittests/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 69116ef39741b..75f4d4d607fc0 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -25,6 +25,7 @@
 #include "llvm/ADT/StringSet.h"
 #include "llvm/ADT/bit.h"
 #include "llvm/Support/Endian.h"
+#include "llvm/Support/MathExtras.h"
 #include "llvm/Support/PrettyStackTrace.h"
 #include "llvm/Support/SourceMgr.h"
 #include <algorithm>
@@ -2482,6 +2483,13 @@ class ParsedResourceEntry : public AsmParsedResourceEntry {
     }
     llvm::support::ulittle32_t align;
     memcpy(&align, blobData->data(), sizeof(uint32_t));
+    if (align && !llvm::isPowerOf2_32(align)) {
+      return p.emitError(value.getLoc(),
+                         "expected hex string blob for key '" + key +
+                             "' to encode alignment in first 4 bytes, but got "
+                             "non-power-of-2 value: " +
+                             Twine(align));
+    }
 
     // Get the data portion of the blob.
     StringRef data = StringRef(*blobData).drop_front(sizeof(uint32_t));

diff  --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index 05a1d33276101..8ff48ad72d0bf 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -24,6 +24,7 @@
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/StringRef.h"
 #include "llvm/Support/MemoryBufferRef.h"
 #include "llvm/Support/SaveAndRestore.h"
 #include "llvm/Support/SourceMgr.h"
@@ -516,6 +517,7 @@ class ResourceSectionReader {
 private:
   /// The table of dialect resources within the bytecode file.
   SmallVector<AsmDialectResourceHandle> dialectResources;
+  llvm::StringMap<std::string> dialectResourceHandleRenamingMap;
 };
 
 class ParsedResourceEntry : public AsmParsedResourceEntry {
@@ -604,6 +606,7 @@ parseResourceGroup(Location fileLoc, bool allowEmpty,
                    EncodingReader &offsetReader, EncodingReader &resourceReader,
                    StringSectionReader &stringReader, T *handler,
                    const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef,
+                   function_ref<StringRef(StringRef)> remapKey = {},
                    function_ref<LogicalResult(StringRef)> processKeyFn = {}) {
   uint64_t numResources;
   if (failed(offsetReader.parseVarInt(numResources)))
@@ -635,6 +638,7 @@ parseResourceGroup(Location fileLoc, bool allowEmpty,
 
     // Otherwise, parse the resource value.
     EncodingReader entryReader(data, fileLoc);
+    key = remapKey(key);
     ParsedResourceEntry entry(key, kind, entryReader, stringReader,
                               bufferOwnerRef);
     if (failed(handler->parseResource(entry)))
@@ -665,8 +669,16 @@ LogicalResult ResourceSectionReader::initialize(
   // provides most of the arguments.
   auto parseGroup = [&](auto *handler, bool allowEmpty = false,
                         function_ref<LogicalResult(StringRef)> keyFn = {}) {
+    // Resolve a key to its renamed form; unrenamed keys are returned as-is.
+    auto resolveKey = [&](StringRef key) -> StringRef {
+      auto it = dialectResourceHandleRenamingMap.find(key);
+      return it == dialectResourceHandleRenamingMap.end() ? key
+                                                          : StringRef(it->second);
+    };
+
     return parseResourceGroup(fileLoc, allowEmpty, offsetReader, resourceReader,
-                              stringReader, handler, bufferOwnerRef, keyFn);
+                              stringReader, handler, bufferOwnerRef, resolveKey,
+                              keyFn);
   };
 
   // Read the external resources from the bytecode.
@@ -714,6 +726,7 @@ LogicalResult ResourceSectionReader::initialize(
                << "unknown 'resource' key '" << key << "' for dialect '"
                << dialect->name << "'";
       }
+      dialectResourceHandleRenamingMap[key] = handler->getResourceKey(*handle);
       dialectResources.push_back(*handle);
       return success();
     };

diff  --git a/mlir/unittests/Bytecode/BytecodeTest.cpp b/mlir/unittests/Bytecode/BytecodeTest.cpp
new file mode 100644
index 0000000000000..96cdbaeef205f
--- /dev/null
+++ b/mlir/unittests/Bytecode/BytecodeTest.cpp
@@ -0,0 +1,75 @@
+//===- BytecodeTest.cpp - Bytecode unit tests -----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Bytecode/BytecodeReader.h"
+#include "mlir/Bytecode/BytecodeWriter.h"
+#include "mlir/IR/AsmState.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/OwningOpRef.h"
+#include "mlir/Parser/Parser.h"
+
+#include "llvm/ADT/StringRef.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+using namespace mlir;
+
+using testing::ElementsAre;
+
+// Module whose `bytecode.test` attribute references a builtin dialect resource.
+// The hex blob starts with a 4-byte little-endian alignment (0x10 = 16); the
+// remaining bytes are the four i32 values (1, 2, 3, 4) checked by the test.
+StringLiteral IRWithResources = R"(
+module @TestDialectResources attributes {
+  bytecode.test = dense_resource<resource> : tensor<4xi32>
+} {}
+{-#
+  dialect_resources: {
+    builtin: {
+      resource: "0x1000000001000000020000000300000004000000"
+    }
+  }
+#-}
+)";
+
+TEST(Bytecode, MultiModuleWithResource) {
+  // Both modules share one context; reload presumably hits key renaming.
+  MLIRContext context;
+  ParserConfig parseConfig(&context);
+  OwningOpRef<Operation *> module =
+      parseSourceString<Operation *>(IRWithResources, parseConfig);
+  ASSERT_TRUE(module);
+
+  // Write the module to bytecode.
+  std::string buffer;
+  llvm::raw_string_ostream ostream(buffer);
+  ASSERT_TRUE(succeeded(writeBytecodeToFile(module.get(), ostream)));
+
+  // Parse it back into the same context.
+  OwningOpRef<Operation *> roundTripModule =
+      parseSourceString<Operation *>(ostream.str(), parseConfig);
+  ASSERT_TRUE(roundTripModule);
+
+  // Check that `op` carries a resource attribute holding the expected data.
+  auto checkResourceAttribute = [](Operation *op) {
+    // Query the operation passed in, so each module is verified on its own.
+    Attribute attr = op->getAttr("bytecode.test");
+    ASSERT_TRUE(attr);
+    auto denseResourceAttr = dyn_cast<DenseI32ResourceElementsAttr>(attr);
+    ASSERT_TRUE(denseResourceAttr);
+    std::optional<ArrayRef<int32_t>> attrData =
+        denseResourceAttr.tryGetAsArrayRef();
+    ASSERT_TRUE(attrData.has_value());
+    ASSERT_EQ(attrData->size(), static_cast<size_t>(4));
+    EXPECT_EQ((*attrData)[0], 1);
+    EXPECT_EQ((*attrData)[1], 2);
+    EXPECT_EQ((*attrData)[2], 3);
+    EXPECT_EQ((*attrData)[3], 4);
+  };
+
+  checkResourceAttribute(*module);
+  checkResourceAttribute(*roundTripModule);
+}

diff  --git a/mlir/unittests/Bytecode/CMakeLists.txt b/mlir/unittests/Bytecode/CMakeLists.txt
new file mode 100644
index 0000000000000..82f7ee60e9c52
--- /dev/null
+++ b/mlir/unittests/Bytecode/CMakeLists.txt
@@ -0,0 +1,9 @@
+add_mlir_unittest(MLIRBytecodeTests
+  BytecodeTest.cpp
+)
+target_link_libraries(MLIRBytecodeTests
+  PRIVATE
+  MLIRBytecodeReader
+  MLIRBytecodeWriter
+  MLIRParser
+)

diff  --git a/mlir/unittests/CMakeLists.txt b/mlir/unittests/CMakeLists.txt
index 137fc97a3fc02..5ca3826e529c0 100644
--- a/mlir/unittests/CMakeLists.txt
+++ b/mlir/unittests/CMakeLists.txt
@@ -5,6 +5,7 @@ function(add_mlir_unittest test_dirname)
 endfunction()
 
 add_subdirectory(Analysis)
+add_subdirectory(Bytecode)
 add_subdirectory(Conversion)
 add_subdirectory(Debug)
 add_subdirectory(Dialect)


        


More information about the Mlir-commits mailing list