[llvm] 0cb9746 - [nfc][mlgo] Separate logger and training-mode model evaluator

Mircea Trofin via llvm-commits llvm-commits at lists.llvm.org
Wed Aug 3 16:20:56 PDT 2022


Author: Mircea Trofin
Date: 2022-08-03T16:20:28-07:00
New Revision: 0cb9746a7d85000fc7bbd8ac5d8557179ca4521c

URL: https://github.com/llvm/llvm-project/commit/0cb9746a7d85000fc7bbd8ac5d8557179ca4521c
DIFF: https://github.com/llvm/llvm-project/commit/0cb9746a7d85000fc7bbd8ac5d8557179ca4521c.diff

LOG: [nfc][mlgo] Separate logger and training-mode model evaluator

This just shuffles implementations and declarations around. Now the
logger and the TF C API-based model evaluator are separate.

Differential Revision: https://reviews.llvm.org/D131116

Added: 
    llvm/include/llvm/Analysis/Utils/TrainingLogger.h
    llvm/lib/Analysis/TrainingLogger.cpp
    llvm/unittests/Analysis/TrainingLoggerTest.cpp

Modified: 
    llvm/include/llvm/Analysis/Utils/TFUtils.h
    llvm/lib/Analysis/CMakeLists.txt
    llvm/lib/Analysis/DevelopmentModeInlineAdvisor.cpp
    llvm/lib/Analysis/TFUtils.cpp
    llvm/lib/CodeGen/MLRegallocEvictAdvisor.cpp
    llvm/unittests/Analysis/CMakeLists.txt
    llvm/unittests/Analysis/TFUtilsTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/Utils/TFUtils.h b/llvm/include/llvm/Analysis/Utils/TFUtils.h
index 372c35863f3fb..91ae397ef7650 100644
--- a/llvm/include/llvm/Analysis/Utils/TFUtils.h
+++ b/llvm/include/llvm/Analysis/Utils/TFUtils.h
@@ -39,81 +39,6 @@ namespace llvm {
 class TFModelEvaluatorImpl;
 class EvaluationResultImpl;
 
-/// Logging utility - given an ordered specification of features, and assuming
-/// a scalar reward, allow logging feature values and rewards, and then print
-/// as tf.train.SequenceExample text protobuf.
-/// The assumption is that, for an event to be logged (i.e. a set of feature
-/// values and a reward), the user calls the log* API for each feature exactly
-/// once, providing the index matching the position in the feature spec list
-/// provided at construction. The example assumes the first feature's element
-/// type is float, the second is int64, and the reward is float:
-///
-/// event 0:
-///   logFloatValue(0, ...)
-///   logInt64Value(1, ...)
-///   ...
-///   logFloatReward(...)
-/// event 1:
-///   logFloatValue(0, ...)
-///   logInt64Value(1, ...)
-///   ...
-///   logFloatReward(...)
-///
-/// At the end, call print to generate the protobuf.
-/// Alternatively, don't call logReward at the end of each event, just
-/// log{Float|Int32|Int64}FinalReward at the end.
-class LoggerDataImpl;
-class Logger final {
-public:
-  /// Construct a Logger. If IncludeReward is false, then logReward or
-  /// logFinalReward shouldn't be called, and the reward feature won't be
-  /// printed out.
-  /// NOTE: the FeatureSpecs are expected to be in the same order (i.e. have
-  /// corresponding indices) with any MLModelRunner implementations
-  /// corresponding to the model being trained/logged.
-  Logger(const std::vector<LoggedFeatureSpec> &FeatureSpecs,
-         const TensorSpec &RewardSpec, bool IncludeReward);
-
-  ~Logger();
-
-  void logFloatReward(float Value);
-  void logInt32Reward(int32_t Value);
-  void logInt64Reward(int64_t Value);
-
-  void logFloatFinalReward(float Value);
-  void logInt32FinalReward(int32_t Value);
-  void logInt64FinalReward(int64_t Value);
-
-  void logFloatValue(size_t FeatureID, const float *Value);
-  void logInt32Value(size_t FeatureID, const int32_t *Value);
-  void logInt64Value(size_t FeatureID, const int64_t *Value);
-
-  void logSpecifiedTensorValue(size_t FeatureID, const char *RawData);
-
-  // Warning! For int32_t, the return is set up for int64_t, so the caller needs
-  // to piecemeal cast their int32_t values.
-  // FIXME: let's drop int32_t support. While it's supported by evaluator, it's
-  // not supported by the tensorflow::SequenceExample proto. For small values,
-  // we can consider using bytes.
-  char *addEntryAndGetFloatOrInt64Buffer(size_t FeatureID);
-
-  // Flush the content of the log to the stream, clearing the stored data in the
-  // process.
-  void flush(std::string *Str);
-  void flush(raw_ostream &OS);
-
-  // Flush a set of logs that are produced from the same module, e.g.
-  // per-function regalloc traces, as a google::protobuf::Struct message.
-  static void flushLogs(raw_ostream &OS,
-                        const StringMap<std::unique_ptr<Logger>> &Loggers);
-
-private:
-  std::vector<LoggedFeatureSpec> FeatureSpecs;
-  TensorSpec RewardSpec;
-  const bool IncludeReward;
-  std::unique_ptr<LoggerDataImpl> LoggerData;
-};
-
 class TFModelEvaluator final {
 public:
   /// The result of a model evaluation. Handles the lifetime of the output

diff  --git a/llvm/include/llvm/Analysis/Utils/TrainingLogger.h b/llvm/include/llvm/Analysis/Utils/TrainingLogger.h
new file mode 100644
index 0000000000000..89a02aff82fe5
--- /dev/null
+++ b/llvm/include/llvm/Analysis/Utils/TrainingLogger.h
@@ -0,0 +1,103 @@
+//===- TrainingLogger.h - mlgo feature/reward logging  ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+#ifndef LLVM_ANALYSIS_UTILS_TRAININGLOGGER_H
+#define LLVM_ANALYSIS_UTILS_TRAININGLOGGER_H
+
+#include "llvm/Config/llvm-config.h"
+
+#ifdef LLVM_HAVE_TF_API
+#include "llvm/ADT/StringMap.h"
+#include "llvm/Analysis/TensorSpec.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/Support/JSON.h"
+
+#include <memory>
+#include <vector>
+
+namespace llvm {
+
+/// Logging utility - given an ordered specification of features, and assuming
+/// a scalar reward, allow logging feature values and rewards, and then flush
+/// them as a tf.train.SequenceExample protobuf (text or binary).
+/// The assumption is that, for an event to be logged (i.e. a set of feature
+/// values and a reward), the user calls the log* API for each feature exactly
+/// once, providing the index matching the position in the feature spec list
+/// provided at construction. The example assumes the first feature's element
+/// type is float, the second is int64, and the reward is float:
+///
+/// event 0:
+///   logFloatValue(0, ...)
+///   logInt64Value(1, ...)
+///   ...
+///   logFloatReward(...)
+/// event 1:
+///   logFloatValue(0, ...)
+///   logInt64Value(1, ...)
+///   ...
+///   logFloatReward(...)
+///
+/// At the end, call flush to serialize the protobuf.
+/// Alternatively, don't call log*Reward at the end of each event; instead call
+/// log{Float|Int32|Int64}FinalReward once, after the last event.
+class LoggerDataImpl;
+class Logger final {
+public:
+  /// Construct a Logger. If IncludeReward is false, then log*Reward and
+  /// log*FinalReward must not be called, and the reward feature is not
+  /// included in the flushed output.
+  /// NOTE: the FeatureSpecs are expected to be in the same order (i.e. have
+  /// corresponding indices) with any MLModelRunner implementations
+  /// corresponding to the model being trained/logged.
+  Logger(const std::vector<LoggedFeatureSpec> &FeatureSpecs,
+         const TensorSpec &RewardSpec, bool IncludeReward);
+
+  ~Logger();
+
+  void logFloatReward(float Value);
+  void logInt32Reward(int32_t Value);
+  void logInt64Reward(int64_t Value);
+
+  void logFloatFinalReward(float Value);
+  void logInt32FinalReward(int32_t Value);
+  void logInt64FinalReward(int64_t Value);
+
+  void logFloatValue(size_t FeatureID, const float *Value);
+  void logInt32Value(size_t FeatureID, const int32_t *Value);
+  void logInt64Value(size_t FeatureID, const int64_t *Value);
+
+  void logSpecifiedTensorValue(size_t FeatureID, const char *RawData);
+
+  // Warning! For int32_t, the return is set up for int64_t, so the caller needs
+  // to piecemeal cast their int32_t values.
+  // FIXME: let's drop int32_t support. While it's supported by evaluator, it's
+  // not supported by the tensorflow::SequenceExample proto. For small values,
+  // we can consider using bytes.
+  char *addEntryAndGetFloatOrInt64Buffer(size_t FeatureID);
+
+  // Flush the content of the log to the string/stream, clearing the stored
+  // data in the process.
+  void flush(std::string *Str);
+  void flush(raw_ostream &OS);
+
+  // Flush a set of logs that are produced from the same module, e.g.
+  // per-function regalloc traces, as a google::protobuf::Struct message.
+  static void flushLogs(raw_ostream &OS,
+                        const StringMap<std::unique_ptr<Logger>> &Loggers);
+
+private:
+  std::vector<LoggedFeatureSpec> FeatureSpecs;
+  TensorSpec RewardSpec;
+  const bool IncludeReward;
+  std::unique_ptr<LoggerDataImpl> LoggerData;
+};
+
+} // namespace llvm
+
+#endif // LLVM_HAVE_TF_API
+#endif // LLVM_ANALYSIS_UTILS_TRAININGLOGGER_H

diff  --git a/llvm/lib/Analysis/CMakeLists.txt b/llvm/lib/Analysis/CMakeLists.txt
index 529b522ac0ad6..a122851e5495f 100644
--- a/llvm/lib/Analysis/CMakeLists.txt
+++ b/llvm/lib/Analysis/CMakeLists.txt
@@ -130,10 +130,11 @@ add_llvm_component_library(LLVMAnalysis
   SyncDependenceAnalysis.cpp
   SyntheticCountsUtils.cpp
   TFUtils.cpp
-  TensorSpec.cpp
   TargetLibraryInfo.cpp
   TargetTransformInfo.cpp
+  TensorSpec.cpp
   Trace.cpp
+  TrainingLogger.cpp
   TypeBasedAliasAnalysis.cpp
   TypeMetadataUtils.cpp
   ScopedNoAliasAA.cpp

diff  --git a/llvm/lib/Analysis/DevelopmentModeInlineAdvisor.cpp b/llvm/lib/Analysis/DevelopmentModeInlineAdvisor.cpp
index 79ea160afc224..e44b022925114 100644
--- a/llvm/lib/Analysis/DevelopmentModeInlineAdvisor.cpp
+++ b/llvm/lib/Analysis/DevelopmentModeInlineAdvisor.cpp
@@ -20,6 +20,7 @@
 #include "llvm/Analysis/ModelUnderTrainingRunner.h"
 #include "llvm/Analysis/NoInferenceModelRunner.h"
 #include "llvm/Analysis/Utils/TFUtils.h"
+#include "llvm/Analysis/Utils/TrainingLogger.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/ManagedStatic.h"

diff  --git a/llvm/lib/Analysis/TFUtils.cpp b/llvm/lib/Analysis/TFUtils.cpp
index 682fc095b0e91..de5dde8d2ea4f 100644
--- a/llvm/lib/Analysis/TFUtils.cpp
+++ b/llvm/lib/Analysis/TFUtils.cpp
@@ -22,23 +22,13 @@
 #include "llvm/Support/Path.h"
 #include "llvm/Support/raw_ostream.h"
 
-#include "google/protobuf/struct.pb.h"
-#include "google/protobuf/text_format.h"
 #include "tensorflow/c/c_api.h"
 #include "tensorflow/c/c_api_experimental.h"
-#include "tensorflow/core/example/example.pb.h"
 #include <cassert>
 #include <numeric>
 
 using namespace llvm;
 
-using google::protobuf::Message;
-using google::protobuf::TextFormat;
-
-static cl::opt<bool>
-    ProtobufTextMode("tfutils-text-log", cl::init(false), cl::Hidden,
-                     cl::desc("Output textual (human-readable) protobuf."));
-
 namespace {
 
 using TFGraphPtr = std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)>;
@@ -72,14 +62,6 @@ TFSessionOptionsPtr createTFSessionOptions() {
   return TFSessionOptionsPtr(TF_NewSessionOptions(), &TF_DeleteSessionOptions);
 }
 
-void serialize(const Message &SE, std::string *OutStr) {
-  if (ProtobufTextMode) {
-    TextFormat::PrintToString(SE, OutStr);
-  } else {
-    *OutStr = SE.SerializeAsString();
-  }
-}
-
 int getTFTypeIndex(TensorType TType) {
   switch (TType) {
   case TensorType::Double:
@@ -182,99 +164,6 @@ class TFModelEvaluatorImpl {
                                 const TensorSpec &OutputSpec);
 };
 
-class LoggerDataImpl {
-  const std::vector<LoggedFeatureSpec> LoggedFeatureSpecs;
-  const TensorSpec RewardSpec;
-  const bool IncludeReward;
-
-  std::vector<tensorflow::FeatureList> FeatureLists;
-  tensorflow::FeatureList Reward;
-
-  bool isSelfConsistent(const tensorflow::SequenceExample &SE,
-                        size_t NrRecords) const {
-    bool Ret = true;
-    for (const auto &TSpecs : LoggedFeatureSpecs) {
-      const auto &Name = TSpecs.getLoggingName();
-      const auto &FL = SE.feature_lists().feature_list().at(Name).feature();
-      if (NrRecords != static_cast<size_t>(FL.size())) {
-        dbgs() << "[TF-UTILS]: " << Name << " has missing records. Expected "
-               << NrRecords << " got " << FL.size() << "\n";
-        Ret = false;
-      }
-    }
-    if (IncludeReward && static_cast<size_t>(SE.feature_lists()
-                                                 .feature_list()
-                                                 .at(RewardSpec.name())
-                                                 .feature()
-                                                 .size()) != NrRecords) {
-      dbgs() << "[TF-UTILS]: reward is missing records.\n";
-      Ret = false;
-    }
-    return Ret;
-  }
-
-  void transferLog(tensorflow::SequenceExample &SE) {
-    auto *FL = SE.mutable_feature_lists()->mutable_feature_list();
-    if (IncludeReward)
-      (*FL)[RewardSpec.name()] = std::move(Reward);
-    assert(FeatureLists.size() == LoggedFeatureSpecs.size());
-    for (size_t I = 0; I < FeatureLists.size(); ++I) {
-      const auto &LFS = LoggedFeatureSpecs[I];
-      (*FL)[LFS.getLoggingName()] = std::move(FeatureLists[I]);
-    }
-  }
-
-public:
-  LoggerDataImpl(const std::vector<LoggedFeatureSpec> &LoggedSpecs,
-                 const TensorSpec &RewardSpec, bool IncludeReward)
-      : LoggedFeatureSpecs(LoggedSpecs), RewardSpec(RewardSpec),
-        IncludeReward(IncludeReward), FeatureLists(LoggedFeatureSpecs.size()) {}
-
-  // flush the logged info to a stream and clear the log contents.
-  void flush(std::string *Str) {
-    size_t NrRecords = getNrRecords();
-    (void)NrRecords;
-    tensorflow::SequenceExample SE;
-    transferLog(SE);
-    assert(isSelfConsistent(SE, NrRecords));
-    serialize(SE, Str);
-  }
-
-  char *addNewTensor(size_t FeatureID) {
-    const auto &Spec = LoggedFeatureSpecs[FeatureID].Spec;
-    if (Spec.isElementType<float>()) {
-      auto *RF = FeatureLists[FeatureID]
-                     .add_feature()
-                     ->mutable_float_list()
-                     ->mutable_value();
-      RF->Resize(Spec.getElementCount(), 0.0);
-      return reinterpret_cast<char *>(RF->mutable_data());
-    } else if (Spec.isElementType<int32_t>() || Spec.isElementType<int64_t>()) {
-      auto *RF = FeatureLists[FeatureID]
-                     .add_feature()
-                     ->mutable_int64_list()
-                     ->mutable_value();
-      RF->Resize(Spec.getElementCount(), 0);
-      return reinterpret_cast<char *>(RF->mutable_data());
-    }
-    llvm_unreachable("Unsupported tensor type.");
-  }
-
-  template <typename T> void logReward(T Value) {
-    assert(IncludeReward);
-    if (RewardSpec.isElementType<float>())
-      Reward.add_feature()->mutable_float_list()->add_value(Value);
-    else if (RewardSpec.isElementType<int32_t>() ||
-             RewardSpec.isElementType<int64_t>())
-      Reward.add_feature()->mutable_int64_list()->add_value(Value);
-    else
-      llvm_unreachable("Unsupported tensor type.");
-  }
-
-  size_t getNrRecords() const {
-    return FeatureLists.empty() ? 0 : FeatureLists[0].feature().size();
-  }
-};
 } // namespace llvm
 
 TFModelEvaluatorImpl::TFModelEvaluatorImpl(
@@ -427,97 +316,4 @@ TFModelEvaluator::EvaluationResult::getUntypedTensorValue(size_t Index) const {
 TFModelEvaluator::EvaluationResult::~EvaluationResult() {}
 TFModelEvaluator::~TFModelEvaluator() {}
 
-Logger::Logger(const std::vector<LoggedFeatureSpec> &FeatureSpecs,
-               const TensorSpec &RewardSpec, bool IncludeReward)
-    : FeatureSpecs(FeatureSpecs), RewardSpec(RewardSpec),
-      IncludeReward(IncludeReward),
-      LoggerData(std::make_unique<LoggerDataImpl>(FeatureSpecs, RewardSpec,
-                                                  IncludeReward)) {}
-
-Logger::~Logger() {}
-
-#define LOG_REWARD(NAME, TYPE)                                                 \
-  void Logger::log##NAME##Reward(TYPE Value) {                                 \
-    assert(IncludeReward);                                                     \
-    LoggerData->logReward(Value);                                              \
-  }
-
-LOG_REWARD(Float, float)
-LOG_REWARD(Int32, int32_t)
-LOG_REWARD(Int64, int64_t)
-#undef LOG_REWARD
-
-#define LOG_FINAL_REWARD(NAME, TYPE)                                           \
-  void Logger::log##NAME##FinalReward(TYPE Value) {                            \
-    assert(RewardSpec.isElementType<TYPE>());                                  \
-    for (size_t I = 1; I < LoggerData->getNrRecords(); ++I)                    \
-      log##NAME##Reward(0);                                                    \
-    log##NAME##Reward(Value);                                                  \
-  }
-
-LOG_FINAL_REWARD(Float, float)
-LOG_FINAL_REWARD(Int32, int32_t)
-LOG_FINAL_REWARD(Int64, int64_t)
-#undef LOG_FINAL_REWARD
-
-void Logger::logFloatValue(size_t FeatureID, const float *Value) {
-  assert(FeatureSpecs[FeatureID].Spec.isElementType<float>());
-  logSpecifiedTensorValue(FeatureID, reinterpret_cast<const char *>(Value));
-}
-
-void Logger::logInt64Value(size_t FeatureID, const int64_t *Value) {
-  assert(FeatureSpecs[FeatureID].Spec.isElementType<int64_t>());
-  logSpecifiedTensorValue(FeatureID, reinterpret_cast<const char *>(Value));
-}
-
-void Logger::logInt32Value(size_t FeatureID, const int32_t *Value) {
-  assert(FeatureSpecs[FeatureID].Spec.isElementType<int32_t>());
-  logSpecifiedTensorValue(FeatureID, reinterpret_cast<const char *>(Value));
-}
-
-void Logger::logSpecifiedTensorValue(size_t FeatureID, const char *RawData) {
-  const auto &Spec = FeatureSpecs[FeatureID].Spec;
-  char *Buff = addEntryAndGetFloatOrInt64Buffer(FeatureID);
-  if (Spec.isElementType<int32_t>())
-    for (size_t I = 0; I < Spec.getElementCount(); ++I)
-      (reinterpret_cast<int64_t *>(Buff))[I] =
-          static_cast<int64_t>((reinterpret_cast<const int32_t *>(RawData))[I]);
-  else if (Spec.isElementType<int64_t>() || Spec.isElementType<float>())
-    std::memcpy(Buff, RawData,
-                Spec.getElementCount() * Spec.getElementByteSize());
-  else
-    llvm_unreachable("Unsupported tensor type");
-}
-
-char *Logger::addEntryAndGetFloatOrInt64Buffer(size_t FeatureID) {
-  return reinterpret_cast<char *>(LoggerData->addNewTensor(FeatureID));
-}
-
-void Logger::flush(std::string *Str) { LoggerData->flush(Str); }
-
-void Logger::flush(raw_ostream &OS) {
-  std::string Buff;
-  LoggerData->flush(&Buff);
-  OS << Buff;
-}
-
-void Logger::flushLogs(raw_ostream &OS,
-                       const StringMap<std::unique_ptr<Logger>> &Loggers) {
-  google::protobuf::Struct Msg;
-  for (const auto &NamedLogger : Loggers) {
-    tensorflow::SequenceExample SE;
-    const auto &Logger = NamedLogger.second;
-    std::string Unencoded;
-    if (Logger->LoggerData->getNrRecords() > 0)
-      Logger->flush(&Unencoded);
-
-    (*Msg.mutable_fields())[NamedLogger.first().str()]
-        .mutable_string_value()
-        ->append(ProtobufTextMode ? Unencoded : encodeBase64(Unencoded));
-  }
-
-  std::string OutStr;
-  serialize(Msg, &OutStr);
-  OS << OutStr;
-}
 #endif // defined(LLVM_HAVE_TF_API)

diff  --git a/llvm/lib/Analysis/TrainingLogger.cpp b/llvm/lib/Analysis/TrainingLogger.cpp
new file mode 100644
index 0000000000000..bdde216e48cb6
--- /dev/null
+++ b/llvm/lib/Analysis/TrainingLogger.cpp
@@ -0,0 +1,253 @@
+//===- TrainingLogger.cpp - mlgo feature/reward logging -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements logging infrastructure for extracting features and
+// rewards for mlgo policy training.
+//
+//===----------------------------------------------------------------------===//
+#include "llvm/Config/config.h"
+#if defined(LLVM_HAVE_TF_API)
+
+#include "llvm/ADT/Twine.h"
+#include "llvm/Analysis/Utils/TrainingLogger.h"
+#include "llvm/Support/Base64.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/JSON.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/Path.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include "google/protobuf/struct.pb.h"
+#include "google/protobuf/text_format.h"
+#include "tensorflow/core/example/example.pb.h"
+#include <cassert>
+#include <cstring>
+#include <numeric>
+
+using namespace llvm;
+
+using google::protobuf::Message;
+using google::protobuf::TextFormat;
+
+static cl::opt<bool>
+    ProtobufTextMode("tfutils-text-log", cl::init(false), cl::Hidden,
+                     cl::desc("Output textual (human-readable) protobuf."));
+
+namespace {
+
+/// Serialize \p SE into \p OutStr: as human-readable text protobuf if
+/// -tfutils-text-log is set, otherwise in the binary wire format.
+void serialize(const Message &SE, std::string *OutStr) {
+  if (ProtobufTextMode) {
+    TextFormat::PrintToString(SE, OutStr);
+  } else {
+    *OutStr = SE.SerializeAsString();
+  }
+}
+} // namespace
+
+namespace llvm {
+
+/// Backing storage for Logger: one tensorflow::FeatureList per logged feature
+/// (in the order of the LoggedFeatureSpecs given at construction), plus one
+/// for the reward. flush() moves the accumulated lists into a SequenceExample.
+class LoggerDataImpl {
+  const std::vector<LoggedFeatureSpec> LoggedFeatureSpecs;
+  const TensorSpec RewardSpec;
+  const bool IncludeReward;
+
+  std::vector<tensorflow::FeatureList> FeatureLists;
+  tensorflow::FeatureList Reward;
+
+  bool isSelfConsistent(const tensorflow::SequenceExample &SE,
+                        size_t NrRecords) const {
+    bool Ret = true;
+    for (const auto &TSpecs : LoggedFeatureSpecs) {
+      const auto &Name = TSpecs.getLoggingName();
+      const auto &FL = SE.feature_lists().feature_list().at(Name).feature();
+      if (NrRecords != static_cast<size_t>(FL.size())) {
+        dbgs() << "[TF-UTILS]: " << Name << " has missing records. Expected "
+               << NrRecords << " got " << FL.size() << "\n";
+        Ret = false;
+      }
+    }
+    if (IncludeReward && static_cast<size_t>(SE.feature_lists()
+                                                 .feature_list()
+                                                 .at(RewardSpec.name())
+                                                 .feature()
+                                                 .size()) != NrRecords) {
+      dbgs() << "[TF-UTILS]: reward is missing records.\n";
+      Ret = false;
+    }
+    return Ret;
+  }
+
+  void transferLog(tensorflow::SequenceExample &SE) {
+    auto *FL = SE.mutable_feature_lists()->mutable_feature_list();
+    if (IncludeReward)
+      (*FL)[RewardSpec.name()] = std::move(Reward);
+    assert(FeatureLists.size() == LoggedFeatureSpecs.size());
+    for (size_t I = 0; I < FeatureLists.size(); ++I) {
+      const auto &LFS = LoggedFeatureSpecs[I];
+      (*FL)[LFS.getLoggingName()] = std::move(FeatureLists[I]);
+    }
+  }
+
+public:
+  LoggerDataImpl(const std::vector<LoggedFeatureSpec> &LoggedSpecs,
+                 const TensorSpec &RewardSpec, bool IncludeReward)
+      : LoggedFeatureSpecs(LoggedSpecs), RewardSpec(RewardSpec),
+        IncludeReward(IncludeReward), FeatureLists(LoggedFeatureSpecs.size()) {}
+
+  // Serialize the logged info into *Str and clear the log contents.
+  void flush(std::string *Str) {
+    size_t NrRecords = getNrRecords();
+    (void)NrRecords;
+    tensorflow::SequenceExample SE;
+    transferLog(SE);
+    assert(isSelfConsistent(SE, NrRecords));
+    serialize(SE, Str);
+  }
+
+  char *addNewTensor(size_t FeatureID) {
+    const auto &Spec = LoggedFeatureSpecs[FeatureID].Spec;
+    if (Spec.isElementType<float>()) {
+      auto *RF = FeatureLists[FeatureID]
+                     .add_feature()
+                     ->mutable_float_list()
+                     ->mutable_value();
+      RF->Resize(Spec.getElementCount(), 0.0);
+      return reinterpret_cast<char *>(RF->mutable_data());
+    } else if (Spec.isElementType<int32_t>() || Spec.isElementType<int64_t>()) {
+      auto *RF = FeatureLists[FeatureID]
+                     .add_feature()
+                     ->mutable_int64_list()
+                     ->mutable_value();
+      RF->Resize(Spec.getElementCount(), 0);
+      return reinterpret_cast<char *>(RF->mutable_data());
+    }
+    llvm_unreachable("Unsupported tensor type.");
+  }
+
+  template <typename T> void logReward(T Value) {
+    assert(IncludeReward);
+    if (RewardSpec.isElementType<float>())
+      Reward.add_feature()->mutable_float_list()->add_value(Value);
+    else if (RewardSpec.isElementType<int32_t>() ||
+             RewardSpec.isElementType<int64_t>())
+      Reward.add_feature()->mutable_int64_list()->add_value(Value);
+    else
+      llvm_unreachable("Unsupported tensor type.");
+  }
+
+  size_t getNrRecords() const {
+    return FeatureLists.empty() ? 0 : FeatureLists[0].feature().size();
+  }
+};
+} // namespace llvm
+
+Logger::Logger(const std::vector<LoggedFeatureSpec> &FeatureSpecs,
+               const TensorSpec &RewardSpec, bool IncludeReward)
+    : FeatureSpecs(FeatureSpecs), RewardSpec(RewardSpec),
+      IncludeReward(IncludeReward),
+      LoggerData(std::make_unique<LoggerDataImpl>(FeatureSpecs, RewardSpec,
+                                                  IncludeReward)) {}
+
+Logger::~Logger() {}
+
+#define LOG_REWARD(NAME, TYPE)                                                 \
+  void Logger::log##NAME##Reward(TYPE Value) {                                 \
+    assert(IncludeReward);                                                     \
+    LoggerData->logReward(Value);                                              \
+  }
+
+LOG_REWARD(Float, float)
+LOG_REWARD(Int32, int32_t)
+LOG_REWARD(Int64, int64_t)
+#undef LOG_REWARD
+
+// Log a reward of 0 for every record except the last, and \p Value for the
+// last one. NOTE(review): with no records logged yet, this still adds one
+// reward entry, so the self-consistency assert in flush() would fire.
+#define LOG_FINAL_REWARD(NAME, TYPE)                                           \
+  void Logger::log##NAME##FinalReward(TYPE Value) {                            \
+    assert(RewardSpec.isElementType<TYPE>());                                  \
+    for (size_t I = 1; I < LoggerData->getNrRecords(); ++I)                    \
+      log##NAME##Reward(0);                                                    \
+    log##NAME##Reward(Value);                                                  \
+  }
+
+LOG_FINAL_REWARD(Float, float)
+LOG_FINAL_REWARD(Int32, int32_t)
+LOG_FINAL_REWARD(Int64, int64_t)
+#undef LOG_FINAL_REWARD
+
+void Logger::logFloatValue(size_t FeatureID, const float *Value) {
+  assert(FeatureSpecs[FeatureID].Spec.isElementType<float>());
+  logSpecifiedTensorValue(FeatureID, reinterpret_cast<const char *>(Value));
+}
+
+void Logger::logInt64Value(size_t FeatureID, const int64_t *Value) {
+  assert(FeatureSpecs[FeatureID].Spec.isElementType<int64_t>());
+  logSpecifiedTensorValue(FeatureID, reinterpret_cast<const char *>(Value));
+}
+
+void Logger::logInt32Value(size_t FeatureID, const int32_t *Value) {
+  assert(FeatureSpecs[FeatureID].Spec.isElementType<int32_t>());
+  logSpecifiedTensorValue(FeatureID, reinterpret_cast<const char *>(Value));
+}
+
+void Logger::logSpecifiedTensorValue(size_t FeatureID, const char *RawData) {
+  const auto &Spec = FeatureSpecs[FeatureID].Spec;
+  char *Buff = addEntryAndGetFloatOrInt64Buffer(FeatureID);
+  if (Spec.isElementType<int32_t>())
+    for (size_t I = 0; I < Spec.getElementCount(); ++I)
+      (reinterpret_cast<int64_t *>(Buff))[I] =
+          static_cast<int64_t>((reinterpret_cast<const int32_t *>(RawData))[I]);
+  else if (Spec.isElementType<int64_t>() || Spec.isElementType<float>())
+    std::memcpy(Buff, RawData,
+                Spec.getElementCount() * Spec.getElementByteSize());
+  else
+    llvm_unreachable("Unsupported tensor type");
+}
+
+char *Logger::addEntryAndGetFloatOrInt64Buffer(size_t FeatureID) {
+  return LoggerData->addNewTensor(FeatureID);
+}
+
+void Logger::flush(std::string *Str) { LoggerData->flush(Str); }
+
+void Logger::flush(raw_ostream &OS) {
+  std::string Buff;
+  LoggerData->flush(&Buff);
+  OS << Buff;
+}
+
+void Logger::flushLogs(raw_ostream &OS,
+                       const StringMap<std::unique_ptr<Logger>> &Loggers) {
+  google::protobuf::Struct Msg;
+  for (const auto &NamedLogger : Loggers) {
+    const auto &TheLogger = NamedLogger.second;
+    // A logger with no records is not flushed; its entry has no payload.
+    std::string Unencoded;
+    if (TheLogger->LoggerData->getNrRecords() > 0)
+      TheLogger->flush(&Unencoded);
+
+    // In binary mode the serialized SequenceExample is base64-encoded before
+    // being stored as a Struct string value; text mode stores it verbatim.
+    (*Msg.mutable_fields())[NamedLogger.first().str()]
+        .mutable_string_value()
+        ->append(ProtobufTextMode ? Unencoded : encodeBase64(Unencoded));
+  }
+
+  std::string OutStr;
+  serialize(Msg, &OutStr);
+  OS << OutStr;
+}
+#endif // defined(LLVM_HAVE_TF_API)

diff  --git a/llvm/lib/CodeGen/MLRegallocEvictAdvisor.cpp b/llvm/lib/CodeGen/MLRegallocEvictAdvisor.cpp
index d21d552227cfc..cab483903e1dc 100644
--- a/llvm/lib/CodeGen/MLRegallocEvictAdvisor.cpp
+++ b/llvm/lib/CodeGen/MLRegallocEvictAdvisor.cpp
@@ -18,6 +18,7 @@
 #if defined(LLVM_HAVE_TF_AOT_REGALLOCEVICTMODEL) || defined(LLVM_HAVE_TF_API)
 #include "llvm/Analysis/ModelUnderTrainingRunner.h"
 #include "llvm/Analysis/NoInferenceModelRunner.h"
+#include "llvm/Analysis/Utils/TrainingLogger.h"
 #endif
 #include "llvm/Analysis/ReleaseModeModelRunner.h"
 #include "llvm/CodeGen/CalcSpillWeights.h"

diff  --git a/llvm/unittests/Analysis/CMakeLists.txt b/llvm/unittests/Analysis/CMakeLists.txt
index af45304fdaf10..a748949ad4c2f 100644
--- a/llvm/unittests/Analysis/CMakeLists.txt
+++ b/llvm/unittests/Analysis/CMakeLists.txt
@@ -6,7 +6,7 @@ set(LLVM_LINK_COMPONENTS
   TransformUtils
   )
 
-set(MLGO_TESTS TFUtilsTest.cpp)
+set(MLGO_TESTS TFUtilsTest.cpp TrainingLoggerTest.cpp)
 if (DEFINED LLVM_HAVE_TF_API)
   LIST(APPEND EXTRA_TESTS ${MLGO_TESTS})
 else()

diff  --git a/llvm/unittests/Analysis/TFUtilsTest.cpp b/llvm/unittests/Analysis/TFUtilsTest.cpp
index 6ec129cf413d7..6d70ab08636ed 100644
--- a/llvm/unittests/Analysis/TFUtilsTest.cpp
+++ b/llvm/unittests/Analysis/TFUtilsTest.cpp
@@ -7,9 +7,6 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Analysis/Utils/TFUtils.h"
-#include "google/protobuf/struct.pb.h"
-#include "tensorflow/core/example/example.pb.h"
-#include "tensorflow/core/example/feature.pb.h"
 #include "llvm/Analysis/ModelUnderTrainingRunner.h"
 #include "llvm/Analysis/TensorSpec.h"
 #include "llvm/AsmParser/Parser.h"
@@ -133,171 +130,3 @@ TEST(TFUtilsTest, UnsupportedFeature) {
   for (auto I = 0; I < 2 * 5; ++I)
     EXPECT_FLOAT_EQ(F[I], 3.14 + I);
 }
-
-// Compares record INDEX of the feature list named FNAME in the parsed
-// SequenceExample `Expected` (which must be in scope) element-wise against the
-// array EXP. TYPE selects the value accessor (float_list, int64_list, ...).
-// NOTE(review): only V.size() elements are compared and the lengths are never
-// checked, so a shorter logged list passes and a longer one indexes past EXP.
-#define PROTO_CHECKER(FNAME, TYPE, INDEX, EXP)                                 \
-  do {                                                                         \
-    const auto &V = Expected.feature_lists()                                   \
-                        .feature_list()                                        \
-                        .at(FNAME)                                             \
-                        .feature(INDEX)                                        \
-                        .TYPE()                                                \
-                        .value();                                              \
-    for (auto I = 0; I < V.size(); ++I)                                        \
-      EXPECT_EQ(V.at(I), EXP[I]);                                              \
-  } while (false)
-
-// Logs two records (a float feature, an int64 feature logged under
-// "alternate_name", and a float reward each), flushes, and checks the parsed
-// SequenceExample.
-TEST(TFUtilsTest, Logger) {
-  std::vector<LoggedFeatureSpec> Features;
-  Features.push_back(
-      {TensorSpec::createSpec<float>("the_float", {2, 3}), None});
-  Features.push_back({TensorSpec::createSpec<int64_t>("the_int", {2}),
-                      std::string("alternate_name")});
-
-  auto Rewards = TensorSpec::createSpec<float>("reward", {1});
-  Logger L(Features, Rewards, true);
-  const float F00[]{0.0, 0.1, 0.2, 0.3, 0.4, 0.5};
-  const int64_t F01[]{2, 3};
-
-  L.logFloatValue(0, F00);
-  L.logInt64Value(1, F01);
-  L.logFloatReward(3.4);
-  const float F10[]{0.0, 1.0, 2.0, 3.0, 4.0, 5.0};
-  const int64_t F11[]{-2, -3};
-  L.logFloatValue(0, F10);
-  L.logInt64Value(1, F11);
-  L.logFloatReward(-3.0);
-  std::string Result;
-  raw_string_ostream OS(Result);
-  L.flush(OS);
-
-  tensorflow::SequenceExample Expected;
-  ASSERT_TRUE(Expected.ParseFromString(Result));
-  PROTO_CHECKER("the_float", float_list, 0, F00);
-  PROTO_CHECKER("the_float", float_list, 1, F10);
-  PROTO_CHECKER("alternate_name", int64_list, 0, F01);
-  PROTO_CHECKER("alternate_name", int64_list, 1, F11);
-  float R0[]{3.4};
-  float R1[]{-3.0};
-  PROTO_CHECKER("reward", float_list, 0, R0);
-  PROTO_CHECKER("reward", float_list, 1, R1);
-}
-
-// Like the Logger test, but with an int32 feature and an int32 reward; both
-// are checked through int64_list values in the parsed output.
-TEST(TFUtilsTest, LoggerInt32FeaturesAndReward) {
-  std::vector<LoggedFeatureSpec> Features;
-  Features.push_back(
-      {TensorSpec::createSpec<float>("the_float", {2, 3}), None});
-  Features.push_back({TensorSpec::createSpec<int32_t>("the_int", {2}),
-                      std::string("alternate_name")});
-
-  auto Rewards = TensorSpec::createSpec<int32_t>("reward", {1});
-  Logger L(Features, Rewards, true);
-  const float F00[]{0.0, 0.1, 0.2, 0.3, 0.4, 0.5};
-  const int32_t F01[]{2, 3};
-
-  L.logFloatValue(0, F00);
-  L.logInt32Value(1, F01);
-  L.logInt32Reward(3);
-  const float F10[]{0.0, 1.0, 2.0, 3.0, 4.0, 5.0};
-  const int32_t F11[]{-2, -3};
-  L.logFloatValue(0, F10);
-  L.logInt32Value(1, F11);
-  L.logInt32Reward(-3);
-  std::string Result;
-  raw_string_ostream OS(Result);
-  L.flush(OS);
-
-  tensorflow::SequenceExample Expected;
-  ASSERT_TRUE(Expected.ParseFromString(Result));
-  PROTO_CHECKER("the_float", float_list, 0, F00);
-  PROTO_CHECKER("the_float", float_list, 1, F10);
-  PROTO_CHECKER("alternate_name", int64_list, 0, F01);
-  PROTO_CHECKER("alternate_name", int64_list, 1, F11);
-  int32_t R0[]{3};
-  int32_t R1[]{-3};
-  PROTO_CHECKER("reward", int64_list, 0, R0);
-  PROTO_CHECKER("reward", int64_list, 1, R1);
-}
-
-// Logger created with `false` as its third argument and no reward is logged;
-// only the two feature lists are checked.
-// NOTE(review): the test never asserts that no "reward" list is emitted.
-TEST(TFUtilsTest, LoggerNoReward) {
-  std::vector<LoggedFeatureSpec> Features;
-  Features.push_back(
-      {TensorSpec::createSpec<float>("the_float", {2, 3}), None});
-  Features.push_back({TensorSpec::createSpec<int64_t>("the_int", {2}),
-                      std::string("alternate_name")});
-
-  auto Rewards = TensorSpec::createSpec<float>("reward", {1});
-  Logger L(Features, Rewards, false);
-  const float F00[]{0.0, 0.1, 0.2, 0.3, 0.4, 0.5};
-  const int64_t F01[]{2, 3};
-
-  L.logFloatValue(0, F00);
-  L.logInt64Value(1, F01);
-  const float F10[]{0.0, 1.0, 2.0, 3.0, 4.0, 5.0};
-  const int64_t F11[]{-2, -3};
-  L.logFloatValue(0, F10);
-  L.logInt64Value(1, F11);
-
-  std::string Result;
-  raw_string_ostream OS(Result);
-  L.flush(OS);
-  tensorflow::SequenceExample Expected;
-  ASSERT_TRUE(Expected.ParseFromString(Result));
-  PROTO_CHECKER("the_float", float_list, 0, F00);
-  PROTO_CHECKER("the_float", float_list, 1, F10);
-  PROTO_CHECKER("alternate_name", int64_list, 0, F01);
-  PROTO_CHECKER("alternate_name", int64_list, 1, F11);
-}
-
-// Logs three records without per-step rewards, then a final reward of 3.14,
-// and checks that the reward list reads 0, 0, 3.14.
-// NOTE(review): the logged feature values themselves are never checked.
-TEST(TFUtilsTest, LoggerFinalReward) {
-  std::vector<LoggedFeatureSpec> Features;
-  Features.push_back({TensorSpec::createSpec<float>("the_float", {1}), None});
-  Features.push_back({TensorSpec::createSpec<int64_t>("the_int", {1}), None});
-
-  auto Rewards = TensorSpec::createSpec<float>("reward", {1});
-  Logger L(Features, Rewards, true);
-  for (int64_t I = 0; I < 3; ++I) {
-    float F = static_cast<float>(I);
-    L.logFloatValue(0, &F);
-    L.logInt64Value(1, &I);
-  }
-  L.logFloatFinalReward(3.14);
-  std::string Result;
-  raw_string_ostream OS(Result);
-  L.flush(OS);
-  const float Zero[]{0.0};
-  const float R[]{3.14};
-  tensorflow::SequenceExample Expected;
-  ASSERT_TRUE(Expected.ParseFromString(Result));
-  PROTO_CHECKER("reward", float_list, 0, Zero);
-  PROTO_CHECKER("reward", float_list, 1, Zero);
-  PROTO_CHECKER("reward", float_list, 2, R);
-}
-
-// Fills one Logger per name ("a", "b"), flushes them together through
-// Logger::flushLogs, and checks the output parses as a protobuf Struct with one
-// field per name.
-TEST(TFUtilsTest, LoggerGroup) {
-  std::vector<LoggedFeatureSpec> Features;
-  Features.push_back({TensorSpec::createSpec<float>("the_float", {1}), None});
-  Features.push_back({TensorSpec::createSpec<int64_t>("the_int", {1}), None});
-
-  auto Rewards = TensorSpec::createSpec<float>("reward", {1});
-  StringMap<std::unique_ptr<Logger>> Loggers;
-  std::vector<std::string> Names{"a", "b"};
-  // NOTE(review): Bump is never incremented, so both loggers log identical
-  // data and a mix-up between the two names would go unnoticed.
-  size_t Bump = 0;
-  for (auto Name : Names) {
-    auto L = std::make_unique<Logger>(Features, Rewards, true);
-    for (int64_t I = 0; I < 3; ++I) {
-      float F = static_cast<float>(I) + Bump;
-      L->logFloatValue(0, &F);
-      L->logInt64Value(1, &I);
-    }
-    L->logFloatFinalReward(3.14 + Bump);
-    Loggers.insert(std::make_pair(Name, std::move(L)));
-  }
-  std::string Result;
-  raw_string_ostream OS(Result);
-  Logger::flushLogs(OS, Loggers);
-  google::protobuf::Struct Expected;
-  ASSERT_TRUE(Expected.ParseFromString(Result));
-  EXPECT_EQ(Expected.fields_size(), 2);
-  EXPECT_TRUE(Expected.fields().contains("a"));
-  EXPECT_TRUE(Expected.fields().contains("b"));
-}

diff  --git a/llvm/unittests/Analysis/TrainingLoggerTest.cpp b/llvm/unittests/Analysis/TrainingLoggerTest.cpp
new file mode 100644
index 0000000000000..f076190572d83
--- /dev/null
+++ b/llvm/unittests/Analysis/TrainingLoggerTest.cpp
@@ -0,0 +1,198 @@
+//===- TrainingLoggerTest.cpp - test for TrainingLogger -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/Utils/TrainingLogger.h"
+#include "google/protobuf/struct.pb.h"
+#include "tensorflow/core/example/example.pb.h"
+#include "tensorflow/core/example/feature.pb.h"
+#include "llvm/Analysis/TensorSpec.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/Path.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Testing/Support/SupportHelpers.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+extern const char *TestMainArgv0;
+
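+// These tests log a few records through the Logger API, flush them, and parse
+// the output back (as a tensorflow::SequenceExample, or as a
+// google::protobuf::Struct for logger groups) to check what was written. No
+// test model is involved.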
+
+// Checks that, in the parsed tensorflow::SequenceExample named `Expected`
+// (which must be in scope), record INDEX of the feature list named FNAME holds
+// exactly the values of the array EXP. TYPE selects the value accessor
+// (e.g. float_list, int64_list). The record's existence and the value count
+// are asserted first: without them a short or empty logged list would pass
+// silently, and a long one would read past the end of EXP.
+#define PROTO_CHECKER(FNAME, TYPE, INDEX, EXP)                                 \
+  do {                                                                         \
+    const auto &FL = Expected.feature_lists().feature_list().at(FNAME);        \
+    ASSERT_LT(INDEX, FL.feature_size());                                       \
+    const auto &V = FL.feature(INDEX).TYPE().value();                          \
+    ASSERT_EQ(static_cast<size_t>(V.size()), sizeof(EXP) / sizeof((EXP)[0])); \
+    for (int I = 0, E = V.size(); I < E; ++I)                                  \
+      EXPECT_EQ(V.at(I), (EXP)[I]) << "at index " << I;                        \
+  } while (false)
+
+// Logs two records -- each with a float feature, an int64 feature logged under
+// the alternate name "alternate_name", and a float reward -- then checks the
+// flushed tensorflow::SequenceExample.
+TEST(TrainingLoggerTest, Logger) {
+  std::vector<LoggedFeatureSpec> Features;
+  Features.push_back(
+      {TensorSpec::createSpec<float>("the_float", {2, 3}), None});
+  Features.push_back({TensorSpec::createSpec<int64_t>("the_int", {2}),
+                      std::string("alternate_name")});
+
+  auto Rewards = TensorSpec::createSpec<float>("reward", {1});
+  Logger L(Features, Rewards, true);
+  const float F00[]{0.0, 0.1, 0.2, 0.3, 0.4, 0.5};
+  const int64_t F01[]{2, 3};
+
+  L.logFloatValue(0, F00);
+  L.logInt64Value(1, F01);
+  L.logFloatReward(3.4);
+  const float F10[]{0.0, 1.0, 2.0, 3.0, 4.0, 5.0};
+  const int64_t F11[]{-2, -3};
+  L.logFloatValue(0, F10);
+  L.logInt64Value(1, F11);
+  L.logFloatReward(-3.0);
+  std::string Result;
+  raw_string_ostream OS(Result);
+  L.flush(OS);
+
+  tensorflow::SequenceExample Expected;
+  ASSERT_TRUE(Expected.ParseFromString(Result));
+  // Three lists are expected: the_float, the reward, and the int feature under
+  // its alternate name only -- it must not also appear as "the_int".
+  const auto &Lists = Expected.feature_lists().feature_list();
+  EXPECT_EQ(Lists.size(), 3U);
+  EXPECT_EQ(Lists.count("the_int"), 0U);
+  PROTO_CHECKER("the_float", float_list, 0, F00);
+  PROTO_CHECKER("the_float", float_list, 1, F10);
+  PROTO_CHECKER("alternate_name", int64_list, 0, F01);
+  PROTO_CHECKER("alternate_name", int64_list, 1, F11);
+  const float R0[]{3.4};
+  const float R1[]{-3.0};
+  PROTO_CHECKER("reward", float_list, 0, R0);
+  PROTO_CHECKER("reward", float_list, 1, R1);
+}
+
+// Int32 features and an int32 reward: the serialized SequenceExample is
+// expected to carry both as int64_list values.
+TEST(TrainingLoggerTest, LoggerInt32FeaturesAndReward) {
+  std::vector<LoggedFeatureSpec> Specs;
+  Specs.push_back({TensorSpec::createSpec<float>("the_float", {2, 3}), None});
+  Specs.push_back({TensorSpec::createSpec<int32_t>("the_int", {2}),
+                   std::string("alternate_name")});
+  auto RewardSpec = TensorSpec::createSpec<int32_t>("reward", {1});
+
+  // Payload for the first record.
+  const float FloatsA[]{0.0, 0.1, 0.2, 0.3, 0.4, 0.5};
+  const int32_t IntsA[]{2, 3};
+  const int32_t RewardA[]{3};
+  // Payload for the second record.
+  const float FloatsB[]{0.0, 1.0, 2.0, 3.0, 4.0, 5.0};
+  const int32_t IntsB[]{-2, -3};
+  const int32_t RewardB[]{-3};
+
+  Logger Log(Specs, RewardSpec, true);
+  Log.logFloatValue(0, FloatsA);
+  Log.logInt32Value(1, IntsA);
+  Log.logInt32Reward(RewardA[0]);
+  Log.logFloatValue(0, FloatsB);
+  Log.logInt32Value(1, IntsB);
+  Log.logInt32Reward(RewardB[0]);
+
+  std::string Serialized;
+  raw_string_ostream Stream(Serialized);
+  Log.flush(Stream);
+
+  tensorflow::SequenceExample Expected;
+  ASSERT_TRUE(Expected.ParseFromString(Serialized));
+  PROTO_CHECKER("the_float", float_list, 0, FloatsA);
+  PROTO_CHECKER("alternate_name", int64_list, 0, IntsA);
+  PROTO_CHECKER("reward", int64_list, 0, RewardA);
+  PROTO_CHECKER("the_float", float_list, 1, FloatsB);
+  PROTO_CHECKER("alternate_name", int64_list, 1, IntsB);
+  PROTO_CHECKER("reward", int64_list, 1, RewardB);
+}
+
+// The logger is created with `false` as its third argument and no reward is
+// ever logged, so only the two feature lists should appear in the output.
+TEST(TrainingLoggerTest, LoggerNoReward) {
+  std::vector<LoggedFeatureSpec> Features;
+  Features.push_back(
+      {TensorSpec::createSpec<float>("the_float", {2, 3}), None});
+  Features.push_back({TensorSpec::createSpec<int64_t>("the_int", {2}),
+                      std::string("alternate_name")});
+
+  auto Rewards = TensorSpec::createSpec<float>("reward", {1});
+  Logger L(Features, Rewards, false);
+  const float F00[]{0.0, 0.1, 0.2, 0.3, 0.4, 0.5};
+  const int64_t F01[]{2, 3};
+
+  L.logFloatValue(0, F00);
+  L.logInt64Value(1, F01);
+  const float F10[]{0.0, 1.0, 2.0, 3.0, 4.0, 5.0};
+  const int64_t F11[]{-2, -3};
+  L.logFloatValue(0, F10);
+  L.logInt64Value(1, F11);
+
+  std::string Result;
+  raw_string_ostream OS(Result);
+  L.flush(OS);
+  tensorflow::SequenceExample Expected;
+  ASSERT_TRUE(Expected.ParseFromString(Result));
+  // This is the point of the test: no "reward" feature list may be emitted.
+  const auto &Lists = Expected.feature_lists().feature_list();
+  EXPECT_EQ(Lists.count("reward"), 0U);
+  EXPECT_EQ(Lists.size(), 2U);
+  PROTO_CHECKER("the_float", float_list, 0, F00);
+  PROTO_CHECKER("the_float", float_list, 1, F10);
+  PROTO_CHECKER("alternate_name", int64_list, 0, F01);
+  PROTO_CHECKER("alternate_name", int64_list, 1, F11);
+}
+
+// Logs three records without per-step rewards, then a final reward. The test
+// expects the final reward on the last record and a zero reward on the earlier
+// ones, and checks that the per-step feature values were recorded.
+TEST(TrainingLoggerTest, LoggerFinalReward) {
+  std::vector<LoggedFeatureSpec> Features;
+  Features.push_back({TensorSpec::createSpec<float>("the_float", {1}), None});
+  Features.push_back({TensorSpec::createSpec<int64_t>("the_int", {1}), None});
+
+  auto Rewards = TensorSpec::createSpec<float>("reward", {1});
+  Logger L(Features, Rewards, true);
+  for (int64_t I = 0; I < 3; ++I) {
+    float F = static_cast<float>(I);
+    L.logFloatValue(0, &F);
+    L.logInt64Value(1, &I);
+  }
+  L.logFloatFinalReward(3.14);
+  std::string Result;
+  raw_string_ostream OS(Result);
+  L.flush(OS);
+  const float Zero[]{0.0};
+  const float R[]{3.14};
+  tensorflow::SequenceExample Expected;
+  ASSERT_TRUE(Expected.ParseFromString(Result));
+  // Exactly one reward record per logged step.
+  EXPECT_EQ(
+      Expected.feature_lists().feature_list().at("reward").feature_size(), 3);
+  PROTO_CHECKER("reward", float_list, 0, Zero);
+  PROTO_CHECKER("reward", float_list, 1, Zero);
+  PROTO_CHECKER("reward", float_list, 2, R);
+  // The features logged at each step must be present as well.
+  for (int Step = 0; Step < 3; ++Step) {
+    const float StepFloat[]{static_cast<float>(Step)};
+    const int64_t StepInt[]{Step};
+    PROTO_CHECKER("the_float", float_list, Step, StepFloat);
+    PROTO_CHECKER("the_int", int64_list, Step, StepInt);
+  }
+}
+
+// Fills one Logger per name ("a", "b") with distinct data, flushes them
+// together through Logger::flushLogs, and checks that the output parses as a
+// protobuf Struct with one non-empty, distinct string field per name.
+TEST(TrainingLoggerTest, LoggerGroup) {
+  std::vector<LoggedFeatureSpec> Features;
+  Features.push_back({TensorSpec::createSpec<float>("the_float", {1}), None});
+  Features.push_back({TensorSpec::createSpec<int64_t>("the_int", {1}), None});
+
+  auto Rewards = TensorSpec::createSpec<float>("reward", {1});
+  StringMap<std::unique_ptr<Logger>> Loggers;
+  std::vector<std::string> Names{"a", "b"};
+  size_t Bump = 0;
+  for (const auto &Name : Names) {
+    auto L = std::make_unique<Logger>(Features, Rewards, true);
+    for (int64_t I = 0; I < 3; ++I) {
+      float F = static_cast<float>(I) + Bump;
+      L->logFloatValue(0, &F);
+      L->logInt64Value(1, &I);
+    }
+    L->logFloatFinalReward(3.14 + Bump);
+    Loggers.insert(std::make_pair(Name, std::move(L)));
+    // Shift the next logger's values so the two payloads differ; otherwise a
+    // mix-up between names could not be detected.
+    ++Bump;
+  }
+  std::string Result;
+  raw_string_ostream OS(Result);
+  Logger::flushLogs(OS, Loggers);
+  google::protobuf::Struct Expected;
+  ASSERT_TRUE(Expected.ParseFromString(Result));
+  EXPECT_EQ(Expected.fields_size(), 2);
+  // Fatal: the fields are dereferenced below.
+  ASSERT_TRUE(Expected.fields().contains("a"));
+  ASSERT_TRUE(Expected.fields().contains("b"));
+  const std::string &A = Expected.fields().at("a").string_value();
+  const std::string &B = Expected.fields().at("b").string_value();
+  EXPECT_FALSE(A.empty());
+  EXPECT_FALSE(B.empty());
+  EXPECT_NE(A, B);
+}


        


More information about the llvm-commits mailing list