[llvm] 35aa737 - [mlgo] Allow logging the spec for the "advice", if needed

Mircea Trofin via llvm-commits llvm-commits at lists.llvm.org
Wed Feb 1 10:24:52 PST 2023


Author: Mircea Trofin
Date: 2023-02-01T10:24:38-08:00
New Revision: 35aa73746c85563912765567850346b48c6610e6

URL: https://github.com/llvm/llvm-project/commit/35aa73746c85563912765567850346b48c6610e6
DIFF: https://github.com/llvm/llvm-project/commit/35aa73746c85563912765567850346b48c6610e6.diff

LOG: [mlgo] Allow logging the spec for the "advice", if needed

This is for the interactive model runner, so it can confirm the tensor
spec of the advice with its host.

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/Utils/TrainingLogger.h
    llvm/lib/Analysis/InteractiveModelRunner.cpp
    llvm/lib/Analysis/TrainingLogger.cpp
    llvm/unittests/Analysis/MLModelRunnerTest.cpp

Removed: 
    


################################################################################
diff --git a/llvm/include/llvm/Analysis/Utils/TrainingLogger.h b/llvm/include/llvm/Analysis/Utils/TrainingLogger.h
index 75b79e622a56e..ef6018914d34a 100644
--- a/llvm/include/llvm/Analysis/Utils/TrainingLogger.h
+++ b/llvm/include/llvm/Analysis/Utils/TrainingLogger.h
@@ -96,7 +96,7 @@ class Logger final {
   StringMap<size_t> ObservationIDs;
   std::string CurrentContext;
 
-  void writeHeader();
+  void writeHeader(std::optional<TensorSpec> AdviceSpec);
   void writeTensor(const TensorSpec &Spec, const char *RawData) {
     OS->write(RawData, Spec.getTotalTensorBufferSize());
   }
@@ -111,7 +111,8 @@ class Logger final {
   /// corresponding to the model being trained/logged.
   Logger(std::unique_ptr<raw_ostream> OS,
          const std::vector<TensorSpec> &FeatureSpecs,
-         const TensorSpec &RewardSpec, bool IncludeReward);
+         const TensorSpec &RewardSpec, bool IncludeReward,
+         std::optional<TensorSpec> AdviceSpec = std::nullopt);
 
   void switchContext(StringRef Name);
   void startObservation();

diff --git a/llvm/lib/Analysis/InteractiveModelRunner.cpp b/llvm/lib/Analysis/InteractiveModelRunner.cpp
index 40a4b5089085b..a347b49eb0729 100644
--- a/llvm/lib/Analysis/InteractiveModelRunner.cpp
+++ b/llvm/lib/Analysis/InteractiveModelRunner.cpp
@@ -37,7 +37,7 @@ InteractiveModelRunner::InteractiveModelRunner(
       InputSpecs(Inputs), OutputSpec(Advice), Inbound(InboundName, InEC),
       OutputBuffer(OutputSpec.getTotalTensorBufferSize()),
       Log(std::make_unique<raw_fd_ostream>(OutboundName, OutEC), InputSpecs,
-          Advice, /*IncludeReward=*/false) {
+          Advice, /*IncludeReward=*/false, Advice) {
   if (InEC) {
     Ctx.emitError("Cannot open inbound file: " + InEC.message());
     return;

diff --git a/llvm/lib/Analysis/TrainingLogger.cpp b/llvm/lib/Analysis/TrainingLogger.cpp
index dcee8d40c53d7..e236890aa2bcc 100644
--- a/llvm/lib/Analysis/TrainingLogger.cpp
+++ b/llvm/lib/Analysis/TrainingLogger.cpp
@@ -32,7 +32,7 @@ static cl::opt<bool>
     UseSimpleLogger("tfutils-use-simplelogger", cl::init(true), cl::Hidden,
                     cl::desc("Output simple (non-protobuf) log."));
 
-void Logger::writeHeader() {
+void Logger::writeHeader(std::optional<TensorSpec> AdviceSpec) {
   json::OStream JOS(*OS);
   JOS.object([&]() {
     JOS.attributeArray("features", [&]() {
@@ -44,6 +44,11 @@ void Logger::writeHeader() {
       RewardSpec.toJSON(JOS);
       JOS.attributeEnd();
     }
+    if (AdviceSpec.has_value()) {
+      JOS.attributeBegin("advice");
+      AdviceSpec->toJSON(JOS);
+      JOS.attributeEnd();
+    }
   });
   *OS << "\n";
 }
@@ -81,8 +86,9 @@ void Logger::logRewardImpl(const char *RawData) {
 
 Logger::Logger(std::unique_ptr<raw_ostream> OS,
                const std::vector<TensorSpec> &FeatureSpecs,
-               const TensorSpec &RewardSpec, bool IncludeReward)
+               const TensorSpec &RewardSpec, bool IncludeReward,
+               std::optional<TensorSpec> AdviceSpec)
     : OS(std::move(OS)), FeatureSpecs(FeatureSpecs), RewardSpec(RewardSpec),
       IncludeReward(IncludeReward) {
-  writeHeader();
+  writeHeader(AdviceSpec);
 }

diff --git a/llvm/unittests/Analysis/MLModelRunnerTest.cpp b/llvm/unittests/Analysis/MLModelRunnerTest.cpp
index 3750516ef7d58..1f80eb7820983 100644
--- a/llvm/unittests/Analysis/MLModelRunnerTest.cpp
+++ b/llvm/unittests/Analysis/MLModelRunnerTest.cpp
@@ -185,6 +185,7 @@ TEST(InteractiveModelRunner, Evaluation) {
   auto Header = json::parse(ReadLn());
   EXPECT_FALSE(Header.takeError());
   EXPECT_NE(Header->getAsObject()->getArray("features"), nullptr);
+  EXPECT_NE(Header->getAsObject()->getObject("advice"), nullptr);
   // Then comes the context
   EXPECT_FALSE(json::parse(ReadLn()).takeError());
 


        


More information about the llvm-commits mailing list