[llvm] f291667 - [mlgo][nfc] Virtualize Logger implementation
Mircea Trofin via llvm-commits
llvm-commits at lists.llvm.org
Thu Dec 1 16:03:17 PST 2022
Author: Mircea Trofin
Date: 2022-12-01T16:03:08-08:00
New Revision: f291667d61b66c82f123f6dd0d7a406633db1000
URL: https://github.com/llvm/llvm-project/commit/f291667d61b66c82f123f6dd0d7a406633db1000
DIFF: https://github.com/llvm/llvm-project/commit/f291667d61b66c82f123f6dd0d7a406633db1000.diff
LOG: [mlgo][nfc] Virtualize Logger implementation
This is in preparation for dropping the dependency on protobuf. This
first step allows us to subsequently introduce the non-protobuf
implementation behind a flag. After that we can update the training side
to ingest the new format, after which we can drop the protobuf
implementation and de-virtualize everything.
Differential Revision: https://reviews.llvm.org/D139062
Added:
Modified:
llvm/lib/Analysis/TrainingLogger.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Analysis/TrainingLogger.cpp b/llvm/lib/Analysis/TrainingLogger.cpp
index 81e8100b2d12..2aff026d15a6 100644
--- a/llvm/lib/Analysis/TrainingLogger.cpp
+++ b/llvm/lib/Analysis/TrainingLogger.cpp
@@ -52,10 +52,29 @@ void serialize(const Message &SE, std::string *OutStr) {
namespace llvm {
+// Abstract storage backend for Logger: accumulates logged feature tensors and
+// rewards, then serializes them on flush(). Logger owns one instance through
+// its LoggerData member; concrete subclasses choose the on-disk format.
class LoggerDataImpl {
+protected:
const std::vector<TensorSpec> LoggedFeatureSpecs;
const TensorSpec RewardSpec;
const bool IncludeReward;
+ LoggerDataImpl(const std::vector<TensorSpec> &LoggedSpecs,
+ const TensorSpec &RewardSpec, bool IncludeReward)
+ : LoggedFeatureSpecs(LoggedSpecs), RewardSpec(RewardSpec),
+ IncludeReward(IncludeReward) {}
+ // Receives the raw bytes of a reward passed to logReward (Size is sizeof(T)
+ // of the caller's type). Subclasses decode these bytes according to
+ // RewardSpec's element type.
+ virtual void logRewardImpl(const char *Value, size_t Size) = 0;
+
+public:
+ // flush the logged info to a stream and clear the log contents.
+ virtual void flush(std::string *Str) = 0;
+ // Adds storage for a new value of feature FeatureID and returns a pointer to
+ // it. NOTE(review): presumably the caller writes the tensor bytes there;
+ // confirm against Logger's callers.
+ virtual char *addNewTensor(size_t FeatureID) = 0;
+ // Number of records logged so far.
+ virtual size_t getNrRecords() const = 0;
+ virtual ~LoggerDataImpl() = default;
+
+ // Type-erases Value into raw bytes so the virtual logRewardImpl can handle
+ // it. The bytes are reinterpreted, not converted, so T must be exactly the
+ // element type described by RewardSpec.
+ template <typename T> void logReward(T Value) {
+ logRewardImpl(reinterpret_cast<const char *>(&Value), sizeof(T));
+ }
+};
+class TFSequenceExampleLoggerDataImpl : public LoggerDataImpl {
std::vector<tensorflow::FeatureList> FeatureLists;
tensorflow::FeatureList Reward;
@@ -94,13 +113,14 @@ class LoggerDataImpl {
}
public:
- LoggerDataImpl(const std::vector<TensorSpec> &LoggedSpecs,
- const TensorSpec &RewardSpec, bool IncludeReward)
- : LoggedFeatureSpecs(LoggedSpecs), RewardSpec(RewardSpec),
- IncludeReward(IncludeReward), FeatureLists(LoggedFeatureSpecs.size()) {}
+ // The LoggerDataImpl base subobject (and thus LoggedFeatureSpecs) is
+ // initialized before FeatureLists, so sizing FeatureLists from it here is
+ // well-defined: one FeatureList per logged feature.
+ TFSequenceExampleLoggerDataImpl(const std::vector<TensorSpec> &LoggedSpecs,
+ const TensorSpec &RewardSpec,
+ bool IncludeReward)
+ : LoggerDataImpl(LoggedSpecs, RewardSpec, IncludeReward),
+ FeatureLists(LoggedFeatureSpecs.size()) {}
// flush the logged info to a stream and clear the log contents.
- void flush(std::string *Str) {
+ void flush(std::string *Str) override {
size_t NrRecords = getNrRecords();
(void)NrRecords;
tensorflow::SequenceExample SE;
@@ -109,7 +129,7 @@ class LoggerDataImpl {
serialize(SE, Str);
}
- char *addNewTensor(size_t FeatureID) {
+ char *addNewTensor(size_t FeatureID) override {
const auto &Spec = LoggedFeatureSpecs[FeatureID];
if (Spec.isElementType<float>()) {
auto *RF = FeatureLists[FeatureID]
@@ -129,18 +149,22 @@ class LoggerDataImpl {
llvm_unreachable("Unsupported tensor type.");
}
- template <typename T> void logReward(T Value) {
+ // Decodes the raw reward bytes produced by LoggerDataImpl::logReward using
+ // RewardSpec's element type and appends the value to the Reward list.
+ // The bytes are reinterpreted, not converted, so the caller's type must have
+ // exactly the width of RewardSpec's element type. Check that explicitly
+ // instead of silently reading unrelated bytes on a mismatch.
+ void logRewardImpl(const char *Value, size_t Size) override {
assert(IncludeReward);
+ assert(Size == RewardSpec.getElementByteSize() &&
+ "reward value type does not match RewardSpec element type");
if (RewardSpec.isElementType<float>())
- Reward.add_feature()->mutable_float_list()->add_value(Value);
- else if (RewardSpec.isElementType<int32_t>() ||
- RewardSpec.isElementType<int64_t>())
- Reward.add_feature()->mutable_int64_list()->add_value(Value);
+ Reward.add_feature()->mutable_float_list()->add_value(
+ *reinterpret_cast<const float *>(Value));
+ else if (RewardSpec.isElementType<int32_t>())
+ Reward.add_feature()->mutable_int64_list()->add_value(
+ *reinterpret_cast<const int32_t *>(Value));
+ else if (RewardSpec.isElementType<int64_t>())
+ Reward.add_feature()->mutable_int64_list()->add_value(
+ *reinterpret_cast<const int64_t *>(Value));
else
llvm_unreachable("Unsupported tensor type.");
}
- size_t getNrRecords() const {
+ // Records are counted through the first feature list; returns 0 when no
+ // features are being logged.
+ size_t getNrRecords() const override {
return FeatureLists.empty() ? 0 : FeatureLists[0].feature().size();
}
};
@@ -150,8 +174,8 @@ Logger::Logger(const std::vector<TensorSpec> &FeatureSpecs,
const TensorSpec &RewardSpec, bool IncludeReward)
: FeatureSpecs(FeatureSpecs), RewardSpec(RewardSpec),
IncludeReward(IncludeReward),
- LoggerData(std::make_unique<LoggerDataImpl>(FeatureSpecs, RewardSpec,
- IncludeReward)) {}
+ // The protobuf-backed SequenceExample format is currently the only
+ // LoggerDataImpl; a non-protobuf implementation is planned to be
+ // selectable behind a flag.
+ LoggerData(std::make_unique<TFSequenceExampleLoggerDataImpl>(
+ FeatureSpecs, RewardSpec, IncludeReward)) {}
Logger::~Logger() {}
More information about the llvm-commits
mailing list