[llvm] [mlir][bytecode] Check that bytecode source buffer is sufficiently aligned. (PR #66380)

Christian Sigg via llvm-commits llvm-commits at lists.llvm.org
Sat Sep 16 04:58:21 PDT 2023


https://github.com/chsigg updated https://github.com/llvm/llvm-project/pull/66380

>From f8b97f3cc59c6ec1f0e5c7eee9285e0d891b0ada Mon Sep 17 00:00:00 2001
From: Christian Sigg <csigg at google.com>
Date: Thu, 14 Sep 2023 16:16:59 +0200
Subject: [PATCH 1/4] Check that bytecode source buffer is sufficiently
 aligned.

Adjust test to make a copy of the source buffer that is sufficiently aligned.
---
 mlir/lib/Bytecode/Reader/BytecodeReader.cpp   | 33 ++++++++++++++-----
 mlir/unittests/Bytecode/BytecodeTest.cpp      | 18 ++++++++--
 .../mlir/unittests/BUILD.bazel                | 21 ++++++++++++
 3 files changed, 60 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index 483cbfda8d0e565..d6a7ab7f366d3bd 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -11,29 +11,33 @@
 #include "mlir/Bytecode/BytecodeImplementation.h"
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/Bytecode/Encoding.h"
-#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/AsmState.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/Verifier.h"
 #include "mlir/IR/Visitors.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/ArrayRef.h"
-#include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/ScopeExit.h"
-#include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/Endian.h"
+#include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/MemoryBufferRef.h"
-#include "llvm/Support/SaveAndRestore.h"
 #include "llvm/Support/SourceMgr.h"
+
+#include <cassert>
 #include <cstddef>
+#include <cstdint>
+#include <cstring>
 #include <list>
 #include <memory>
 #include <numeric>
 #include <optional>
+#include <string>
 
 #define DEBUG_TYPE "mlir-bytecode-reader"
 
@@ -93,23 +97,31 @@ namespace {
 class EncodingReader {
 public:
   explicit EncodingReader(ArrayRef<uint8_t> contents, Location fileLoc)
-      : dataIt(contents.data()), dataEnd(contents.end()), fileLoc(fileLoc) {}
+      : buffer(contents), dataIt(buffer.begin()), fileLoc(fileLoc) {}
   explicit EncodingReader(StringRef contents, Location fileLoc)
       : EncodingReader({reinterpret_cast<const uint8_t *>(contents.data()),
                         contents.size()},
                        fileLoc) {}
 
   /// Returns true if the entire section has been read.
-  bool empty() const { return dataIt == dataEnd; }
+  bool empty() const { return dataIt == buffer.end(); }
 
   /// Returns the remaining size of the bytecode.
-  size_t size() const { return dataEnd - dataIt; }
+  size_t size() const { return buffer.end() - dataIt; }
 
   /// Align the current reader position to the specified alignment.
   LogicalResult alignTo(unsigned alignment) {
     if (!llvm::isPowerOf2_32(alignment))
       return emitError("expected alignment to be a power-of-two");
 
+    // Ensure the data buffer was sufficiently aligned in the first place.
+    if (LLVM_UNLIKELY(
+            !llvm::isAddrAligned(llvm::Align(alignment), buffer.begin()))) {
+      return emitError("expected bytecode buffer to be aligned to ", alignment,
+                       ", but got pointer: '0x" +
+                           llvm::utohexstr((uintptr_t)buffer.begin()) + "'");
+    }
+
     // Shift the reader position to the next alignment boundary.
     while (uintptr_t(dataIt) & (uintptr_t(alignment) - 1)) {
       uint8_t padding;
@@ -320,8 +332,11 @@ class EncodingReader {
     return success();
   }
 
-  /// The current data iterator, and an iterator to the end of the buffer.
-  const uint8_t *dataIt, *dataEnd;
+  /// The bytecode buffer.
+  ArrayRef<uint8_t> buffer;
+
+  /// The current iterator within the 'buffer'.
+  const uint8_t *dataIt;
 
   /// A location for the bytecode used to report errors.
   Location fileLoc;
diff --git a/mlir/unittests/Bytecode/BytecodeTest.cpp b/mlir/unittests/Bytecode/BytecodeTest.cpp
index fc86f132dd60b4d..26e63a3bf9ed34b 100644
--- a/mlir/unittests/Bytecode/BytecodeTest.cpp
+++ b/mlir/unittests/Bytecode/BytecodeTest.cpp
@@ -6,7 +6,6 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Bytecode/BytecodeReader.h"
 #include "mlir/Bytecode/BytecodeWriter.h"
 #include "mlir/IR/AsmState.h"
 #include "mlir/IR/BuiltinAttributes.h"
@@ -19,6 +18,10 @@
 #include "gmock/gmock.h"
 #include "gtest/gtest.h"
 
+#include <algorithm>
+#include <cstdlib>
+#include <memory>
+
 using namespace llvm;
 using namespace mlir;
 
@@ -50,9 +53,18 @@ TEST(Bytecode, MultiModuleWithResource) {
   llvm::raw_string_ostream ostream(buffer);
   ASSERT_TRUE(succeeded(writeBytecodeToFile(module.get(), ostream)));
 
+  // Make a sufficiently aligned copy of the buffer for reading it back.
+  ostream.flush();
+  constexpr std::size_t kAlignment = 16; // AsmResourceBlob alignment.
+  auto deleter = [](char *ptr) { std::free(ptr); };
+  std::unique_ptr<char, decltype(deleter)> aligned_buffer(
+      static_cast<char *>(std::aligned_alloc(kAlignment, buffer.size())),
+      deleter);
+  std::copy(buffer.begin(), buffer.end(), aligned_buffer.get());
+
   // Parse it back
-  OwningOpRef<Operation *> roundTripModule =
-      parseSourceString<Operation *>(ostream.str(), parseConfig);
+  OwningOpRef<Operation *> roundTripModule = parseSourceString<Operation *>(
+      {aligned_buffer.get(), buffer.size()}, parseConfig);
   ASSERT_TRUE(roundTripModule);
 
   // FIXME: Parsing external resources does not work on big-endian
diff --git a/utils/bazel/llvm-project-overlay/mlir/unittests/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/unittests/BUILD.bazel
index c9773437f5a8435..ec6619bedd6f53e 100644
--- a/utils/bazel/llvm-project-overlay/mlir/unittests/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/unittests/BUILD.bazel
@@ -359,6 +359,27 @@ cc_test(
     ],
 )
 
+cc_test(
+    name = "bytecode_tests",
+    size = "small",
+    srcs = glob([
+        "Bytecode/*.cpp",
+        "Bytecode/*.h",
+        "Bytecode/*/*.cpp",
+        "Bytecode/*/*.h",
+    ]),
+    deps = [
+        "//llvm:Support",
+        "//mlir:BytecodeReader",
+        "//mlir:BytecodeWriter",
+        "//mlir:IR",
+        "//mlir:Parser",
+        "//third-party/unittest:gmock",
+        "//third-party/unittest:gtest",
+        "//third-party/unittest:gtest_main",
+    ],
+)
+
 cc_test(
     name = "conversion_tests",
     size = "small",

>From b10ca7aa56152e7ccd1f70fd5c7ff4f9f225d02a Mon Sep 17 00:00:00 2001
From: Christian Sigg <chsigg at users.noreply.github.com>
Date: Thu, 14 Sep 2023 22:41:27 +0200
Subject: [PATCH 2/4] Micro-optimize alignment check.

---
 mlir/lib/Bytecode/Reader/BytecodeReader.cpp | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index d6a7ab7f366d3bd..95ba6ed80946d28 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -114,16 +114,19 @@ class EncodingReader {
     if (!llvm::isPowerOf2_32(alignment))
       return emitError("expected alignment to be a power-of-two");
 
+    auto isUnaligned = [&](const uint8_t *ptr) {
+      return ((uintptr_t)ptr & (alignment - 1)) != 0;
+    };
+
     // Ensure the data buffer was sufficiently aligned in the first place.
-    if (LLVM_UNLIKELY(
-            !llvm::isAddrAligned(llvm::Align(alignment), buffer.begin()))) {
+    if (LLVM_UNLIKELY(isUnaligned(buffer.begin()))) {
       return emitError("expected bytecode buffer to be aligned to ", alignment,
                        ", but got pointer: '0x" +
                            llvm::utohexstr((uintptr_t)buffer.begin()) + "'");
     }
 
     // Shift the reader position to the next alignment boundary.
-    while (uintptr_t(dataIt) & (uintptr_t(alignment) - 1)) {
+    while (isUnaligned(dataIt)) {
       uint8_t padding;
       if (failed(parseByte(padding)))
         return failure();
@@ -135,7 +138,7 @@ class EncodingReader {
 
     // Ensure the data iterator is now aligned. This case is unlikely because we
     // *just* went through the effort to align the data iterator.
-    if (LLVM_UNLIKELY(!llvm::isAddrAligned(llvm::Align(alignment), dataIt))) {
+    if (LLVM_UNLIKELY(isUnaligned(dataIt))) {
       return emitError("expected data iterator aligned to ", alignment,
                        ", but got pointer: '0x" +
                            llvm::utohexstr((uintptr_t)dataIt) + "'");

>From 2fac90d1f69b70f3179636d8045b67251e090324 Mon Sep 17 00:00:00 2001
From: Christian Sigg <csigg at google.com>
Date: Fri, 15 Sep 2023 11:29:28 +0200
Subject: [PATCH 3/4] Add test for insufficient alignment.

---
 mlir/unittests/Bytecode/BytecodeTest.cpp | 64 ++++++++++++++++++------
 1 file changed, 48 insertions(+), 16 deletions(-)

diff --git a/mlir/unittests/Bytecode/BytecodeTest.cpp b/mlir/unittests/Bytecode/BytecodeTest.cpp
index 26e63a3bf9ed34b..b5f8c09b617bdd2 100644
--- a/mlir/unittests/Bytecode/BytecodeTest.cpp
+++ b/mlir/unittests/Bytecode/BytecodeTest.cpp
@@ -18,14 +18,10 @@
 #include "gmock/gmock.h"
 #include "gtest/gtest.h"
 
-#include <algorithm>
-#include <cstdlib>
-#include <memory>
-
 using namespace llvm;
 using namespace mlir;
 
-using testing::ElementsAre;
+using ::testing::StartsWith;
 
 StringLiteral IRWithResources = R"(
 module @TestDialectResources attributes {
@@ -34,7 +30,7 @@ module @TestDialectResources attributes {
 {-#
   dialect_resources: {
     builtin: {
-      resource: "0x1000000001000000020000000300000004000000"
+      resource: "0x2000000001000000020000000300000004000000"
     }
   }
 #-}
@@ -52,19 +48,19 @@ TEST(Bytecode, MultiModuleWithResource) {
   std::string buffer;
   llvm::raw_string_ostream ostream(buffer);
   ASSERT_TRUE(succeeded(writeBytecodeToFile(module.get(), ostream)));
-
-  // Make a sufficiently aligned copy of the buffer for reading it back.
   ostream.flush();
-  constexpr std::size_t kAlignment = 16; // AsmResourceBlob alignment.
-  auto deleter = [](char *ptr) { std::free(ptr); };
-  std::unique_ptr<char, decltype(deleter)> aligned_buffer(
-      static_cast<char *>(std::aligned_alloc(kAlignment, buffer.size())),
-      deleter);
-  std::copy(buffer.begin(), buffer.end(), aligned_buffer.get());
+
+  // Shift the bytecode so it starts on a kAlignment (resource) boundary.
+  constexpr size_t kAlignment = 0x20;
+  size_t buffer_size = buffer.size();
+  buffer.reserve(buffer_size + kAlignment - 1);
+  size_t pad = (kAlignment - (uintptr_t)buffer.data()) & (kAlignment - 1);
+  buffer.insert(0, pad, ' ');
+  StringRef aligned_buffer(buffer.data() + pad, buffer_size);
 
   // Parse it back
-  OwningOpRef<Operation *> roundTripModule = parseSourceString<Operation *>(
-      {aligned_buffer.get(), buffer.size()}, parseConfig);
+  OwningOpRef<Operation *> roundTripModule =
+      parseSourceString<Operation *>(aligned_buffer, parseConfig);
   ASSERT_TRUE(roundTripModule);
 
   // FIXME: Parsing external resources does not work on big-endian
@@ -92,3 +88,39 @@ TEST(Bytecode, MultiModuleWithResource) {
   checkResourceAttribute(*module);
   checkResourceAttribute(*roundTripModule);
 }
+
+TEST(Bytecode, InsufficientAlignmentFailure) {
+  MLIRContext context;
+  ParserConfig parseConfig(&context);
+  OwningOpRef<Operation *> module =
+      parseSourceString<Operation *>(IRWithResources, parseConfig);
+  ASSERT_TRUE(module);
+
+  // Write the module to bytecode
+  std::string buffer;
+  llvm::raw_string_ostream ostream(buffer);
+  ASSERT_TRUE(succeeded(writeBytecodeToFile(module.get(), ostream)));
+  ostream.flush();
+
+  // Offset the bytecode to start half-way between kAlignment boundaries.
+  constexpr size_t kAlignment = 0x20;
+  size_t buffer_size = buffer.size();
+  buffer.reserve(buffer_size + kAlignment - 1);
+  size_t pad = (kAlignment / 2 - (uintptr_t)buffer.data()) & (kAlignment - 1);
+  buffer.insert(0, pad, ' ');
+  StringRef misaligned_buffer(buffer.data() + pad, buffer_size);
+
+  std::unique_ptr<Diagnostic> diagnostic;
+  context.getDiagEngine().registerHandler([&](Diagnostic &diag) {
+    diagnostic = std::make_unique<Diagnostic>(std::move(diag));
+  });
+
+  // Try to parse it back and check for alignment error.
+  OwningOpRef<Operation *> roundTripModule =
+      parseSourceString<Operation *>(misaligned_buffer, parseConfig);
+  EXPECT_FALSE(roundTripModule);
+  ASSERT_TRUE(diagnostic);
+  EXPECT_THAT(diagnostic->str(),
+              StartsWith("expected bytecode buffer to be aligned to " +
+                         std::to_string(kAlignment)));
+}

>From ec81a8ea34117045d0df7eb07b774d49241e6816 Mon Sep 17 00:00:00 2001
From: Christian Sigg <csigg at google.com>
Date: Fri, 15 Sep 2023 11:37:29 +0200
Subject: [PATCH 4/4] Undo IWYU (include-what-you-use) header changes.

---
 mlir/lib/Bytecode/Reader/BytecodeReader.cpp | 7 -------
 1 file changed, 7 deletions(-)

diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index 95ba6ed80946d28..98a080d23a1dce5 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -11,10 +11,8 @@
 #include "mlir/Bytecode/BytecodeImplementation.h"
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/Bytecode/Encoding.h"
-#include "mlir/IR/AsmState.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Diagnostics.h"
-#include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/Verifier.h"
 #include "mlir/IR/Visitors.h"
@@ -25,19 +23,14 @@
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/Endian.h"
-#include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/MemoryBufferRef.h"
 #include "llvm/Support/SourceMgr.h"
 
-#include <cassert>
 #include <cstddef>
-#include <cstdint>
-#include <cstring>
 #include <list>
 #include <memory>
 #include <numeric>
 #include <optional>
-#include <string>
 
 #define DEBUG_TYPE "mlir-bytecode-reader"
 



More information about the llvm-commits mailing list