[llvm] d5c81be - [NFC][MLInliner] Set up the logger outside the development mode advisor

Mircea Trofin via llvm-commits llvm-commits at lists.llvm.org
Mon Aug 10 09:22:30 PDT 2020


Author: Mircea Trofin
Date: 2020-08-10T09:22:17-07:00
New Revision: d5c81be3ca2504e32a99b57711ae101e02d810fa

URL: https://github.com/llvm/llvm-project/commit/d5c81be3ca2504e32a99b57711ae101e02d810fa
DIFF: https://github.com/llvm/llvm-project/commit/d5c81be3ca2504e32a99b57711ae101e02d810fa.diff

LOG: [NFC][MLInliner] Set up the logger outside the development mode advisor

This allows us to subsequently configure the logger for the case when we
use a model evaluator and want to log additional outputs.

Differential Revision: https://reviews.llvm.org/D85577

Added: 
    

Modified: 
    llvm/lib/Analysis/DevelopmentModeInlineAdvisor.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Analysis/DevelopmentModeInlineAdvisor.cpp b/llvm/lib/Analysis/DevelopmentModeInlineAdvisor.cpp
index 76bb3d39d606..6c0fa7a624e6 100644
--- a/llvm/lib/Analysis/DevelopmentModeInlineAdvisor.cpp
+++ b/llvm/lib/Analysis/DevelopmentModeInlineAdvisor.cpp
@@ -71,14 +71,14 @@ struct InlineEvent {
 /// lines up with how TF SequenceExample represents it.
 class TrainingLogger final {
 public:
-  TrainingLogger();
+  TrainingLogger(StringRef LogFileName);
 
   /// Log one inlining event.
   void logInlineEvent(const InlineEvent &Event,
                       const MLModelRunner &ModelRunner);
 
   /// Print the stored tensors.
-  void print(raw_fd_ostream &OutFile);
+  void print();
 
 private:
   /// Write the values of one tensor as a list.
@@ -156,6 +156,7 @@ class TrainingLogger final {
     OutFile << "  }\n";
   }
 
+  StringRef LogFileName;
   std::vector<InlineFeatures> Features;
   std::vector<int64_t> DefaultDecisions;
   std::vector<int64_t> Decisions;
@@ -193,7 +194,8 @@ class DevelopmentModeMLInlineAdvisor : public MLInlineAdvisor {
   DevelopmentModeMLInlineAdvisor(
       Module &M, ModuleAnalysisManager &MAM,
       std::unique_ptr<MLModelRunner> ModelRunner,
-      std::function<bool(CallBase &)> GetDefaultAdvice, bool IsDoingInference);
+      std::function<bool(CallBase &)> GetDefaultAdvice, bool IsDoingInference,
+      std::unique_ptr<TrainingLogger> Logger);
 
   size_t getTotalSizeEstimate();
 
@@ -211,11 +213,11 @@ class DevelopmentModeMLInlineAdvisor : public MLInlineAdvisor {
   size_t getNativeSizeEstimate(const Function &F) const;
 
 private:
-  bool isLogging() const { return !TrainingLog.empty(); }
+  bool isLogging() const { return !!Logger; }
 
   std::function<bool(CallBase &)> GetDefaultAdvice;
-  TrainingLogger Logger;
   const bool IsDoingInference;
+  std::unique_ptr<TrainingLogger> Logger;
 
   const int32_t InitialNativeSize;
   int32_t CurrentNativeSize = 0;
@@ -346,7 +348,8 @@ class ModelUnderTrainingRunner final : public MLModelRunner {
 };
 } // namespace
 
-TrainingLogger::TrainingLogger() {
+TrainingLogger::TrainingLogger(StringRef LogFileName)
+    : LogFileName(LogFileName) {
   for (size_t I = 0; I < NumberOfFeatures; ++I) {
     Features.push_back(InlineFeatures());
   }
@@ -364,7 +367,9 @@ void TrainingLogger::logInlineEvent(const InlineEvent &Event,
   DefaultDecisions.push_back(Event.DefaultDecision);
 }
 
-void TrainingLogger::print(raw_fd_ostream &OutFile) {
+void TrainingLogger::print() {
+  std::error_code EC;
+  raw_fd_ostream OutFile(LogFileName, EC);
   size_t NumberOfRecords = Decisions.size();
   if (NumberOfRecords == 0)
     return;
@@ -392,9 +397,11 @@ void TrainingLogger::print(raw_fd_ostream &OutFile) {
 DevelopmentModeMLInlineAdvisor::DevelopmentModeMLInlineAdvisor(
     Module &M, ModuleAnalysisManager &MAM,
     std::unique_ptr<MLModelRunner> ModelRunner,
-    std::function<bool(CallBase &)> GetDefaultAdvice, bool IsDoingInference)
+    std::function<bool(CallBase &)> GetDefaultAdvice, bool IsDoingInference,
+    std::unique_ptr<TrainingLogger> Logger)
     : MLInlineAdvisor(M, MAM, std::move(ModelRunner)),
       GetDefaultAdvice(GetDefaultAdvice), IsDoingInference(IsDoingInference),
+      Logger(std::move(Logger)),
       InitialNativeSize(isLogging() ? getTotalSizeEstimate() : 0),
       CurrentNativeSize(InitialNativeSize) {
   // We cannot have the case of neither inference nor logging.
@@ -402,11 +409,8 @@ DevelopmentModeMLInlineAdvisor::DevelopmentModeMLInlineAdvisor(
 }
 
 DevelopmentModeMLInlineAdvisor::~DevelopmentModeMLInlineAdvisor() {
-  if (TrainingLog.empty())
-    return;
-  std::error_code ErrorCode;
-  raw_fd_ostream OutFile(TrainingLog, ErrorCode);
-  Logger.print(OutFile);
+  if (isLogging())
+    Logger->print();
 }
 
 size_t
@@ -428,7 +432,7 @@ DevelopmentModeMLInlineAdvisor::getMandatoryAdvice(
     return MLInlineAdvisor::getMandatoryAdvice(CB, ORE);
   return std::make_unique<LoggingMLInlineAdvice>(
       /*Advisor=*/this,
-      /*CB=*/CB, /*ORE=*/ORE, /*Recommendation=*/true, /*Logger=*/Logger,
+      /*CB=*/CB, /*ORE=*/ORE, /*Recommendation=*/true, /*Logger=*/*Logger,
       /*CallerSizeEstimateBefore=*/getNativeSizeEstimate(*CB.getCaller()),
       /*CalleeSizeEstimateBefore=*/
       getNativeSizeEstimate(*CB.getCalledFunction()),
@@ -446,7 +450,7 @@ DevelopmentModeMLInlineAdvisor::getAdviceFromModel(
   return std::make_unique<LoggingMLInlineAdvice>(
       /*Advisor=*/this,
       /*CB=*/CB, /*ORE=*/ORE, /*Recommendation=*/Recommendation,
-      /*Logger=*/Logger,
+      /*Logger=*/*Logger,
       /*CallerSizeEstimateBefore=*/getNativeSizeEstimate(*CB.getCaller()),
       /*CalleeSizeEstimateBefore=*/
       getNativeSizeEstimate(*CB.getCalledFunction()),
@@ -531,7 +535,12 @@ std::unique_ptr<InlineAdvisor> llvm::getDevelopmentModeAdvisor(
     }
     IsDoingInference = true;
   }
+  std::unique_ptr<TrainingLogger> Logger;
+  if (!TrainingLog.empty())
+    Logger = std::make_unique<TrainingLogger>(TrainingLog);
+
   return std::make_unique<DevelopmentModeMLInlineAdvisor>(
-      M, MAM, std::move(Runner), GetDefaultAdvice, IsDoingInference);
+      M, MAM, std::move(Runner), GetDefaultAdvice, IsDoingInference,
+      std::move(Logger));
 }
 #endif // defined(LLVM_HAVE_TF_API)


        


More information about the llvm-commits mailing list