[llvm] 8b987ca - Add support for decoding base64.
Greg Clayton via llvm-commits
llvm-commits at lists.llvm.org
Thu Jun 23 16:13:31 PDT 2022
Author: Greg Clayton
Date: 2022-06-23T16:13:19-07:00
New Revision: 8b987ca5e37ee670dface295bd147b8d6d312084
URL: https://github.com/llvm/llvm-project/commit/8b987ca5e37ee670dface295bd147b8d6d312084
DIFF: https://github.com/llvm/llvm-project/commit/8b987ca5e37ee670dface295bd147b8d6d312084.diff
LOG: Add support for decoding base64.
An upcoming patch to LLDB will require the ability to decode base64. This patch adds support for decoding base64 and adds tests.
Differential Revision: https://reviews.llvm.org/D126254
Added:
Modified:
llvm/include/llvm/Support/Base64.h
llvm/unittests/Support/Base64Test.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/Support/Base64.h b/llvm/include/llvm/Support/Base64.h
index da4ae1688574f..78bf31f82dd52 100644
--- a/llvm/include/llvm/Support/Base64.h
+++ b/llvm/include/llvm/Support/Base64.h
@@ -13,6 +13,7 @@
#ifndef LLVM_SUPPORT_BASE64_H
#define LLVM_SUPPORT_BASE64_H
+#include "llvm/Support/Error.h"
#include <cstdint>
#include <string>
@@ -52,6 +53,86 @@ template <class InputBytes> std::string encodeBase64(InputBytes const &Bytes) {
return Buffer;
}
+/// Decodes the Base64 text in \p Input into \p Output.
+///
+/// \p Output is cleared first and then receives the decoded bytes through
+/// push_back(). Empty input succeeds and yields empty output. Decoding is
+/// strict: the input length must be a multiple of 4, every character must be
+/// in the alphabet A-Z, a-z, 0-9, '+', '/', and '=' padding may only occupy
+/// the last character or the last two characters of the input.
+///
+/// \returns Error::success() when decoding succeeds; otherwise a string error
+/// with std::errc::illegal_byte_sequence describing either the bad length or
+/// the offending character and its index. On failure \p Output may still hold
+/// the bytes of the 4-character groups decoded before the error was found.
+template <class OutputBytes>
+llvm::Error decodeBase64(llvm::StringRef Input, OutputBytes &Output) {
+  // Invalid table value with short name to fit in the table init below. The
+  // invalid value is 64 since valid base64 values are 0 - 63.
+  constexpr char Inv = 64;
+  // Maps an input byte to its 6-bit value. '=' maps to 0 so that padding
+  // contributes zero bits; whether a given '=' is legal is checked separately
+  // in the decode loop. The table ends at 'z', so larger bytes are rejected by
+  // the bounds check in decodeBase64Byte. The table is never written, hence
+  // const.
+  static const char DecodeTable[] = {
+      Inv, Inv, Inv, Inv, Inv, Inv, Inv, Inv, // ........
+      Inv, Inv, Inv, Inv, Inv, Inv, Inv, Inv, // ........
+      Inv, Inv, Inv, Inv, Inv, Inv, Inv, Inv, // ........
+      Inv, Inv, Inv, Inv, Inv, Inv, Inv, Inv, // ........
+      Inv, Inv, Inv, Inv, Inv, Inv, Inv, Inv, // ........
+      Inv, Inv, Inv, 62,  Inv, Inv, Inv, 63,  // ...+.../
+      52,  53,  54,  55,  56,  57,  58,  59,  // 01234567
+      60,  61,  Inv, Inv, Inv, 0,   Inv, Inv, // 89...=..
+      Inv, 0,   1,   2,   3,   4,   5,   6,   // .ABCDEFG
+      7,   8,   9,   10,  11,  12,  13,  14,  // HIJKLMNO
+      15,  16,  17,  18,  19,  20,  21,  22,  // PQRSTUVW
+      23,  24,  25,  Inv, Inv, Inv, Inv, Inv, // XYZ.....
+      Inv, 26,  27,  28,  29,  30,  31,  32,  // .abcdefg
+      33,  34,  35,  36,  37,  38,  39,  40,  // hijklmno
+      41,  42,  43,  44,  45,  46,  47,  48,  // pqrstuvw
+      49,  50,  51                            // xyz
+  };
+  // The parameter is unsigned so that bytes >= 0x80 fall past the end of the
+  // table and are rejected, rather than turning into negative indexes.
+  auto decodeBase64Byte = [](uint8_t Ch) -> char {
+    if (Ch >= sizeof(DecodeTable))
+      return Inv;
+    return DecodeTable[Ch];
+  };
+  Output.clear();
+  const uint64_t InputLength = Input.size();
+  if (InputLength == 0)
+    return Error::success();
+  // Make sure we have a valid input string length which must be a multiple
+  // of 4.
+  if ((InputLength % 4) != 0)
+    return createStringError(std::errc::illegal_byte_sequence,
+                             "Base64 encoded strings must be a multiple of 4 "
+                             "bytes in length");
+  // Lowest index at which an '=' padding character is allowed.
+  const uint64_t FirstValidEqualIdx = InputLength - 2;
+  char Hex64Bytes[4];
+  for (uint64_t Idx = 0; Idx < InputLength; Idx += 4) {
+    for (uint64_t ByteOffset = 0; ByteOffset < 4; ++ByteOffset) {
+      const uint64_t ByteIdx = Idx + ByteOffset;
+      const char Byte = Input[ByteIdx];
+      const char DecodedByte = decodeBase64Byte(Byte);
+      bool Illegal = DecodedByte == Inv;
+      if (!Illegal && Byte == '=') {
+        if (ByteIdx < FirstValidEqualIdx) {
+          // We have an '=' in the middle of the string which is invalid, only
+          // the last two characters can be '=' characters.
+          Illegal = true;
+        } else if (ByteIdx == FirstValidEqualIdx &&
+                   Input[ByteIdx + 1] != '=') {
+          // We have an equal second to last from the end and the last
+          // character is not also an equal, so the '=' character is invalid.
+          Illegal = true;
+        }
+      }
+      // Format the byte as unsigned: a plain (possibly signed) char >= 0x80
+      // would be sign-extended by integer promotion and print as 0xffffffXX.
+      if (Illegal)
+        return createStringError(
+            std::errc::illegal_byte_sequence,
+            "Invalid Base64 character %#2.2x at index %" PRIu64,
+            static_cast<uint8_t>(Byte), ByteIdx);
+      Hex64Bytes[ByteOffset] = DecodedByte;
+    }
+    // Each Hex64Bytes entry now holds one 6-bit value; the four of them carry
+    // 24 bits, i.e. 3 output bytes. Reassemble those bytes into Output.
+    Output.push_back((Hex64Bytes[0] << 2) + ((Hex64Bytes[1] >> 4) & 0x03));
+    Output.push_back((Hex64Bytes[1] << 4) + ((Hex64Bytes[2] >> 2) & 0x0f));
+    Output.push_back((Hex64Bytes[2] << 6) + (Hex64Bytes[3] & 0x3f));
+  }
+  // If we had valid trailing '=' characters strip the right number of bytes
+  // from the end of the output buffer. We already know that the Input length
+  // is a multiple of 4 and is not zero, so direct character access is safe.
+  if (Input.back() == '=') {
+    Output.pop_back();
+    if (Input[InputLength - 2] == '=')
+      Output.pop_back();
+  }
+  return Error::success();
+}
+
} // end namespace llvm
#endif
diff --git a/llvm/unittests/Support/Base64Test.cpp b/llvm/unittests/Support/Base64Test.cpp
index 9622c866b38aa..e6e54bb53098d 100644
--- a/llvm/unittests/Support/Base64Test.cpp
+++ b/llvm/unittests/Support/Base64Test.cpp
@@ -13,6 +13,7 @@
#include "llvm/Support/Base64.h"
#include "llvm/ADT/StringRef.h"
+#include "llvm/Testing/Support/Error.h"
#include "gtest/gtest.h"
using namespace llvm;
@@ -24,6 +25,28 @@ void TestBase64(StringRef Input, StringRef Final) {
EXPECT_EQ(Res, Final);
}
+/// Runs decodeBase64 on \p Input and checks the outcome.
+///
+/// When \p ExpectedErrorMessage is empty the decode must succeed and produce
+/// exactly the bytes of \p Expected. Otherwise the decode must fail with
+/// \p ExpectedErrorMessage, and \p Expected is not looked at.
+void TestBase64Decode(StringRef Input, StringRef Expected,
+                      StringRef ExpectedErrorMessage = {}) {
+  std::vector<char> Decoded;
+
+  // Failure case first: only the error message is verified.
+  if (!ExpectedErrorMessage.empty()) {
+    ASSERT_THAT_ERROR(decodeBase64(Input, Decoded),
+                      FailedWithMessage(ExpectedErrorMessage));
+    return;
+  }
+
+  // Success case: the decoded bytes must match the expectation byte for byte.
+  ASSERT_THAT_ERROR(decodeBase64(Input, Decoded), Succeeded());
+  EXPECT_EQ(llvm::ArrayRef<char>(Decoded),
+            llvm::ArrayRef<char>(Expected.data(), Expected.size()));
+}
+
+// Test payload that is not printable text: it contains NUL bytes and bytes
+// with the high bit set (0xff, 0xee). Used by both the encode and the decode
+// tests; its Base64 encoding is "AAAARgAI/+4=".
+char NonPrintableVector[] = {0x00, 0x00, 0x00, 0x46,
+ 0x00, 0x08, (char)0xff, (char)0xee};
+
+// Longer test payload: the ASCII bytes of
+// "The quick brown fox jumps over 13 lazy dogs." with no terminating NUL, so
+// users must pass sizeof(LargeVector) as the length. Used by both the encode
+// and the decode tests.
+char LargeVector[] = {0x54, 0x68, 0x65, 0x20, 0x71, 0x75, 0x69, 0x63, 0x6b,
+ 0x20, 0x62, 0x72, 0x6f, 0x77, 0x6e, 0x20, 0x66, 0x6f,
+ 0x78, 0x20, 0x6a, 0x75, 0x6d, 0x70, 0x73, 0x20, 0x6f,
+ 0x76, 0x65, 0x72, 0x20, 0x31, 0x33, 0x20, 0x6c, 0x61,
+ 0x7a, 0x79, 0x20, 0x64, 0x6f, 0x67, 0x73, 0x2e};
+
} // namespace
TEST(Base64Test, Base64) {
@@ -37,16 +60,45 @@ TEST(Base64Test, Base64) {
TestBase64("foobar", "Zm9vYmFy");
// With non-printable values.
- char NonPrintableVector[] = {0x00, 0x00, 0x00, 0x46,
- 0x00, 0x08, (char)0xff, (char)0xee};
TestBase64({NonPrintableVector, sizeof(NonPrintableVector)}, "AAAARgAI/+4=");
// Large test case
- char LargeVector[] = {0x54, 0x68, 0x65, 0x20, 0x71, 0x75, 0x69, 0x63, 0x6b,
- 0x20, 0x62, 0x72, 0x6f, 0x77, 0x6e, 0x20, 0x66, 0x6f,
- 0x78, 0x20, 0x6a, 0x75, 0x6d, 0x70, 0x73, 0x20, 0x6f,
- 0x76, 0x65, 0x72, 0x20, 0x31, 0x33, 0x20, 0x6c, 0x61,
- 0x7a, 0x79, 0x20, 0x64, 0x6f, 0x67, 0x73, 0x2e};
TestBase64({LargeVector, sizeof(LargeVector)},
"VGhlIHF1aWNrIGJyb3duIGZveCBqdW1wcyBvdmVyIDEzIGxhenkgZG9ncy4=");
}
+
+// Checks decodeBase64 two ways: round-tripping payloads through encodeBase64,
+// and feeding it malformed input whose exact error message is verified.
+TEST(Base64Test, DecodeBase64) {
+  // One of every possible byte value, so the round trip below proves that any
+  // byte survives an encode followed by a decode.
+  std::vector<char> EveryByte;
+  for (int Value = INT8_MIN; Value <= INT8_MAX; ++Value)
+    EveryByte.push_back(Value);
+
+  const std::vector<llvm::StringRef> Payloads = {
+      "",
+      "f",
+      "fo",
+      "foo",
+      "foob",
+      "fooba",
+      "foobar",
+      llvm::StringRef(NonPrintableVector, sizeof(NonPrintableVector)),
+      llvm::StringRef(LargeVector, sizeof(LargeVector)),
+      llvm::StringRef(EveryByte.data(), EveryByte.size())};
+
+  // Encoding is covered by the Base64Test::Base64 test, so it is trusted here
+  // to produce the input whose decoding must give back the original payload.
+  for (llvm::StringRef Payload : Payloads) {
+    const std::string Encoded = encodeBase64(Payload);
+    TestBase64Decode(Encoded, Payload);
+  }
+
+  // Malformed inputs paired with the exact error message each must produce.
+  const struct {
+    llvm::StringRef Input;
+    llvm::StringRef ErrorMessage;
+  } BadInputs[] = {
+      {"f", "Base64 encoded strings must be a multiple of 4 bytes in length"},
+      {"=abc", "Invalid Base64 character 0x3d at index 0"},
+      {"a=bc", "Invalid Base64 character 0x3d at index 1"},
+      {"ab=c", "Invalid Base64 character 0x3d at index 2"},
+      {"fun!", "Invalid Base64 character 0x21 at index 3"},
+  };
+
+  for (const auto &Bad : BadInputs)
+    TestBase64Decode(Bad.Input, "", Bad.ErrorMessage);
+}
More information about the llvm-commits
mailing list