[llvm] 9e3919d - [Object][DX] Parse DXContainer Parts
Chris Bieneman via llvm-commits
llvm-commits at lists.llvm.org
Wed Jun 1 12:55:46 PDT 2022
Author: Chris Bieneman
Date: 2022-06-01T14:55:36-05:00
New Revision: 9e3919dac449d6e018ffbc77b5511b8ab858ede3
URL: https://github.com/llvm/llvm-project/commit/9e3919dac449d6e018ffbc77b5511b8ab858ede3
DIFF: https://github.com/llvm/llvm-project/commit/9e3919dac449d6e018ffbc77b5511b8ab858ede3.diff
LOG: [Object][DX] Parse DXContainer Parts
DXContainer files are structured as parts. This patch adds support for
parsing out the file part offsets and file part headers.
Reviewed By: kuhar
Differential Revision: https://reviews.llvm.org/D124804
Added:
Modified:
llvm/include/llvm/BinaryFormat/DXContainer.h
llvm/include/llvm/Object/DXContainer.h
llvm/lib/Object/DXContainer.cpp
llvm/unittests/Object/DXContainerTest.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/BinaryFormat/DXContainer.h b/llvm/include/llvm/BinaryFormat/DXContainer.h
index 34660c8fe0e26..85706d55bd8e2 100644
--- a/llvm/include/llvm/BinaryFormat/DXContainer.h
+++ b/llvm/include/llvm/BinaryFormat/DXContainer.h
@@ -49,14 +49,14 @@ struct ShaderHash {
uint32_t Flags; // DxilShaderHashFlags
uint8_t Digest[16];
- void byteSwap() { sys::swapByteOrder(Flags); }
+ void swapBytes() { sys::swapByteOrder(Flags); }
};
struct ContainerVersion {
uint16_t Major;
uint16_t Minor;
- void byteSwap() {
+ void swapBytes() {
sys::swapByteOrder(Major);
sys::swapByteOrder(Minor);
}
@@ -69,8 +69,8 @@ struct Header {
uint32_t FileSize;
uint32_t PartCount;
- void byteSwap() {
- Version.byteSwap();
+ void swapBytes() {
+ Version.swapBytes();
sys::swapByteOrder(FileSize);
sys::swapByteOrder(PartCount);
}
@@ -82,6 +82,8 @@ struct Header {
struct PartHeader {
uint8_t Name[4];
uint32_t Size;
+
+ // Reverses the byte order of Size in place. Readers call this on big-endian
+ // hosts because the container is stored little endian; Name is a plain
+ // byte array and is left untouched.
+ void swapBytes() { sys::swapByteOrder(Size); }
// Structure is followed directly by part data: uint8_t PartData[PartSize].
};
diff --git a/llvm/include/llvm/Object/DXContainer.h b/llvm/include/llvm/Object/DXContainer.h
index 69a5e28539d6c..b3b153a125687 100644
--- a/llvm/include/llvm/Object/DXContainer.h
+++ b/llvm/include/llvm/Object/DXContainer.h
@@ -15,6 +15,7 @@
#ifndef LLVM_OBJECT_DXCONTAINER_H
#define LLVM_OBJECT_DXCONTAINER_H
+#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/BinaryFormat/DXContainer.h"
#include "llvm/Support/Error.h"
@@ -28,10 +29,80 @@ class DXContainer {
MemoryBufferRef Data;
dxbc::Header Header;
+ SmallVector<uint32_t, 4> PartOffsets;
Error parseHeader();
+ Error parsePartOffsets();
+ friend class PartIterator;
public:
+  // The PartIterator is a wrapper around the iterator for the PartOffsets
+  // member of the DXContainer. It contains a reference to the container, the
+  // current offset iterator, and storage for the parsed part header and data.
+  //
+  // The end() iterator carries the state of the last part (or a
+  // zero-initialized state when the container has no parts), and
+  // incrementing end() is a no-op.
+  class PartIterator {
+    const DXContainer &Container;
+    SmallVectorImpl<uint32_t>::const_iterator OffsetIt;
+    struct PartData {
+      dxbc::PartHeader Part;
+      StringRef Data;
+    };
+    // Value-initialized so that end() of a container with no parts never
+    // exposes uninitialized header bytes.
+    PartData IteratorState = {};
+
+    friend class DXContainer;
+
+    PartIterator(const DXContainer &C,
+                 SmallVectorImpl<uint32_t>::const_iterator It)
+        : Container(C), OffsetIt(It) {
+      if (OffsetIt != Container.PartOffsets.end())
+        updateIterator();
+      else if (!Container.PartOffsets.empty())
+        // end() mirrors the last part. back() must not be called on an empty
+        // offset list, which is the case when the header's PartCount is 0.
+        updateIteratorImpl(Container.PartOffsets.back());
+    }
+
+    // Updates the iterator's state data. This results in copying the part
+    // header into the iterator and handling any required byte swapping. This
+    // is called on construction and when incrementing the iterator; it does
+    // nothing once the iterator has reached end().
+    void updateIterator() {
+      if (OffsetIt != Container.PartOffsets.end())
+        updateIteratorImpl(*OffsetIt);
+    }
+
+    // Implementation for updating the iterator state based on a specified
+    // offset.
+    void updateIteratorImpl(const uint32_t Offset);
+
+  public:
+    PartIterator &operator++() {
+      if (OffsetIt == Container.PartOffsets.end())
+        return *this;
+      ++OffsetIt;
+      updateIterator();
+      return *this;
+    }
+
+    PartIterator operator++(int) {
+      PartIterator Tmp = *this;
+      ++(*this);
+      return Tmp;
+    }
+
+    bool operator==(const PartIterator &RHS) const {
+      return OffsetIt == RHS.OffsetIt;
+    }
+
+    bool operator!=(const PartIterator &RHS) const {
+      return OffsetIt != RHS.OffsetIt;
+    }
+
+    // Dereferencing does not modify the iterator, so these also work on
+    // const iterators.
+    const PartData &operator*() const { return IteratorState; }
+    const PartData *operator->() const { return &IteratorState; }
+  };
+
+  // Iterates over the parts in the order their offsets appear in the part
+  // offset table.
+  PartIterator begin() const {
+    return PartIterator(*this, PartOffsets.begin());
+  }
+
+  PartIterator end() const { return PartIterator(*this, PartOffsets.end()); }
+
StringRef getData() const { return Data.getBuffer(); }
static Expected<DXContainer> create(MemoryBufferRef Object);
diff --git a/llvm/lib/Object/DXContainer.cpp b/llvm/lib/Object/DXContainer.cpp
index 2c43f519dd947..ae3d2f00c2e4d 100644
--- a/llvm/lib/Object/DXContainer.cpp
+++ b/llvm/lib/Object/DXContainer.cpp
@@ -18,15 +18,32 @@ static Error parseFailed(const Twine &Msg) {
}
+// Copies sizeof(T) bytes starting at Src into Struct, failing with a parse
+// error if those bytes do not lie entirely within Buffer. T must be safe to
+// copy with memcpy and must provide swapBytes(), which is invoked on
+// big-endian hosts because the container is stored little endian.
template <typename T>
-static Error readStruct(StringRef Buffer, const char *P, T &Struct) {
+static Error readStruct(StringRef Buffer, const char *Src, T &Struct) {
// Don't read before the beginning or past the end of the file
- if (P < Buffer.begin() || P + sizeof(T) > Buffer.end())
+ if (Src < Buffer.begin() || Src + sizeof(T) > Buffer.end())
return parseFailed("Reading structure out of file bounds");
- memcpy(&Struct, P, sizeof(T));
+ memcpy(&Struct, Src, sizeof(T));
// DXContainer is always little endian
if (sys::IsBigEndianHost)
- Struct.byteSwap();
+ Struct.swapBytes();
+ return Error::success();
+}
+
+// Reads a little-endian integer of type T from Src into Val, failing with a
+// parse error if the sizeof(T) bytes at Src are not entirely within Buffer.
+// Src does not need to be aligned for T.
+template <typename T>
+static Error readInteger(StringRef Buffer, const char *Src, T &Val) {
+  static_assert(std::is_integral<T>::value,
+                "Cannot call readInteger on non-integral type.");
+  // Don't read before the beginning or past the end of the file
+  if (Src < Buffer.begin() || Src + sizeof(T) > Buffer.end())
+    return parseFailed("Reading structure out of file bounds");
+
+  // Copy the bytes out instead of dereferencing a casted pointer. Offsets come
+  // from the file and need not be aligned for T, and reading char storage
+  // through a T lvalue would violate strict aliasing.
+  memcpy(&Val, Src, sizeof(T));
+  // DXContainer is always little endian
+  if (sys::IsBigEndianHost)
+    sys::swapByteOrder(Val);
   return Error::success();
 }
@@ -36,9 +53,35 @@ Error DXContainer::parseHeader() {
return readStruct(Data.getBuffer(), Data.getBuffer().data(), Header);
}
+// Reads Header.PartCount part offsets that directly follow the file header.
+// Each offset must leave room for a complete part header, and that header's
+// Size must fit in the rest of the file. Valid offsets are appended to
+// PartOffsets in table order.
+Error DXContainer::parsePartOffsets() {
+  StringRef Buffer = Data.getBuffer();
+  const char *Current = Buffer.data() + sizeof(dxbc::Header);
+  for (uint32_t Part = 0; Part < Header.PartCount; ++Part) {
+    uint32_t PartOffset;
+    if (Error Err = readInteger(Buffer, Current, PartOffset))
+      return Err;
+    Current += sizeof(uint32_t);
+    // Compare by subtraction so that PartOffset + sizeof(dxbc::PartHeader)
+    // cannot wrap around on hosts where size_t is 32 bits wide.
+    if (PartOffset > Buffer.size() ||
+        Buffer.size() - PartOffset < sizeof(dxbc::PartHeader))
+      return parseFailed("Part offset points beyond boundary of the file");
+    // PartIterator exposes Size bytes after the part header, so the part data
+    // must also lie entirely within the file.
+    dxbc::PartHeader PartHdr;
+    if (Error Err = readStruct(Buffer, Buffer.data() + PartOffset, PartHdr))
+      return Err;
+    if (PartHdr.Size > Buffer.size() - PartOffset - sizeof(dxbc::PartHeader))
+      return parseFailed("Part data extends beyond boundary of the file");
+    PartOffsets.push_back(PartOffset);
+  }
+  return Error::success();
+}
+
+// Builds a DXContainer for Object by parsing the file header and then the
+// part offset table. The first parse error encountered is returned.
Expected<DXContainer> DXContainer::create(MemoryBufferRef Object) {
DXContainer Container(Object);
if (Error Err = Container.parseHeader())
return std::move(Err);
+ // Requires the parsed header: the offset table length is Header.PartCount.
+ if (Error Err = Container.parsePartOffsets())
+ return std::move(Err);
return Container;
}
+
+// Loads the part header at Offset into IteratorState (byte-swapped to host
+// order by readStruct) and points IteratorState.Data at the part's payload.
+void DXContainer::PartIterator::updateIteratorImpl(const uint32_t Offset) {
+  StringRef Buffer = Container.Data.getBuffer();
+  const char *Current = Buffer.data() + Offset;
+  // Offsets are validated during parsing, so all offsets in the container are
+  // valid and contain enough readable data to read a header.
+  cantFail(readStruct(Buffer, Current, IteratorState.Part));
+  // substr clamps to the end of the buffer, so a Size field larger than the
+  // remaining file can never yield a StringRef that reads past the buffer.
+  IteratorState.Data =
+      Buffer.substr(Offset + sizeof(dxbc::PartHeader), IteratorState.Part.Size);
+}
diff --git a/llvm/unittests/Object/DXContainerTest.cpp b/llvm/unittests/Object/DXContainerTest.cpp
index 1c8ef4fe2722d..14fb4b82d7c58 100644
--- a/llvm/unittests/Object/DXContainerTest.cpp
+++ b/llvm/unittests/Object/DXContainerTest.cpp
@@ -39,11 +39,17 @@ TEST(DXCFile, ParseHeaderErrors) {
FailedWithMessage("Reading structure out of file bounds"));
}
+// A zero-byte buffer is too small to hold even the container header, so
+// creation must fail while reading the header.
+TEST(DXCFile, EmptyFile) {
+  MemoryBufferRef EmptyBuffer(StringRef("", 0), "");
+  EXPECT_THAT_EXPECTED(
+      DXContainer::create(EmptyBuffer),
+      FailedWithMessage("Reading structure out of file bounds"));
+}
+
TEST(DXCFile, ParseHeader) {
uint8_t Buffer[] = {0x44, 0x58, 0x42, 0x43, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
- 0x70, 0x0D, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00};
+ 0x70, 0x0D, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00};
DXContainer C =
llvm::cantFail(DXContainer::create(getMemoryBuffer<32>(Buffer)));
EXPECT_TRUE(memcmp(C.getHeader().Magic, "DXBC", 4) == 0);
@@ -52,3 +58,71 @@ TEST(DXCFile, ParseHeader) {
EXPECT_EQ(C.getHeader().Version.Major, 1u);
EXPECT_EQ(C.getHeader().Version.Minor, 0u);
}
+
+// The header declares one part, but the 32-byte buffer ends right after the
+// header, so reading the part offset table must fail.
+TEST(DXCFile, ParsePartMissingOffsets) {
+  uint8_t Buffer[] = {
+      0x44, 0x58, 0x42, 0x43,                         // "DXBC"
+      0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // bytes 4-19: zero
+      0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+      0x01, 0x00, 0x00, 0x00,                         // version 1.0
+      0x70, 0x0D, 0x00, 0x00,                         // file size
+      0x01, 0x00, 0x00, 0x00,                         // part count: 1
+  };
+  EXPECT_THAT_EXPECTED(
+      DXContainer::create(getMemoryBuffer<32>(Buffer)),
+      FailedWithMessage("Reading structure out of file bounds"));
+}
+
+// The header declares one part whose offset (0xFFFFFFFF) points far outside
+// the 36-byte buffer, so offset validation must reject it.
+TEST(DXCFile, ParsePartInvalidOffsets) {
+  uint8_t Buffer[] = {
+      0x44, 0x58, 0x42, 0x43,                         // "DXBC"
+      0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // bytes 4-19: zero
+      0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+      0x01, 0x00, 0x00, 0x00,                         // version 1.0
+      0x70, 0x0D, 0x00, 0x00,                         // file size
+      0x01, 0x00, 0x00, 0x00,                         // part count: 1
+      0xFF, 0xFF, 0xFF, 0xFF,                         // offset of part 0
+  };
+  EXPECT_THAT_EXPECTED(
+      DXContainer::create(getMemoryBuffer<36>(Buffer)),
+      FailedWithMessage("Part offset points beyond boundary of the file"));
+}
+
+// Seven parts, each with a four-character name and a zero size. The test
+// checks range-based iteration, explicit stepping, and end() behavior.
+TEST(DXCFile, ParseEmptyParts) {
+  uint8_t Buffer[] = {
+      0x44, 0x58, 0x42, 0x43, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+      0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
+      0x70, 0x0D, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x3C, 0x00, 0x00, 0x00,
+      0x44, 0x00, 0x00, 0x00, 0x4C, 0x00, 0x00, 0x00, 0x54, 0x00, 0x00, 0x00,
+      0x5C, 0x00, 0x00, 0x00, 0x64, 0x00, 0x00, 0x00, 0x6C, 0x00, 0x00, 0x00,
+      0x53, 0x46, 0x49, 0x30, 0x00, 0x00, 0x00, 0x00, 0x49, 0x53, 0x47, 0x31,
+      0x00, 0x00, 0x00, 0x00, 0x4F, 0x53, 0x47, 0x31, 0x00, 0x00, 0x00, 0x00,
+      0x50, 0x53, 0x56, 0x30, 0x00, 0x00, 0x00, 0x00, 0x53, 0x54, 0x41, 0x54,
+      0x00, 0x00, 0x00, 0x00, 0x43, 0x58, 0x49, 0x4C, 0x00, 0x00, 0x00, 0x00,
+      0x44, 0x45, 0x41, 0x44, 0x00, 0x00, 0x00, 0x00,
+  };
+  DXContainer C =
+      llvm::cantFail(DXContainer::create(getMemoryBuffer<116>(Buffer)));
+  EXPECT_EQ(C.getHeader().PartCount, 7u);
+
+  // All the part sizes are 0, which makes a nice test of the range based for
+  int ElementsVisited = 0;
+  for (const auto &Part : C) {
+    EXPECT_EQ(Part.Part.Size, 0u);
+    EXPECT_EQ(Part.Data.size(), 0u);
+    ++ElementsVisited;
+  }
+  EXPECT_EQ(ElementsVisited, 7);
+
+  {
+    auto It = C.begin();
+    EXPECT_TRUE(memcmp(It->Part.Name, "SFI0", 4) == 0);
+    ++It;
+    EXPECT_TRUE(memcmp(It->Part.Name, "ISG1", 4) == 0);
+    ++It;
+    EXPECT_TRUE(memcmp(It->Part.Name, "OSG1", 4) == 0);
+    ++It;
+    EXPECT_TRUE(memcmp(It->Part.Name, "PSV0", 4) == 0);
+    ++It;
+    EXPECT_TRUE(memcmp(It->Part.Name, "STAT", 4) == 0);
+    ++It;
+    EXPECT_TRUE(memcmp(It->Part.Name, "CXIL", 4) == 0);
+    ++It;
+    EXPECT_TRUE(memcmp(It->Part.Name, "DEAD", 4) == 0);
+    // Stepping past the last part lands on end(), which still exposes the
+    // last part's state.
+    ++It;
+    EXPECT_TRUE(It == C.end());
+    EXPECT_TRUE(memcmp(It->Part.Name, "DEAD", 4) == 0);
+    // Incrementing end() is a no-op.
+    ++It;
+    EXPECT_TRUE(It == C.end());
+    EXPECT_TRUE(memcmp(It->Part.Name, "DEAD", 4) == 0);
+  }
+}
More information about the llvm-commits
mailing list