[llvm] 35aa737 - [mlgo] Allow logging the spec for the "advice", if needed
Mircea Trofin via llvm-commits
llvm-commits at lists.llvm.org
Wed Feb 1 10:24:52 PST 2023
Author: Mircea Trofin
Date: 2023-02-01T10:24:38-08:00
New Revision: 35aa73746c85563912765567850346b48c6610e6
URL: https://github.com/llvm/llvm-project/commit/35aa73746c85563912765567850346b48c6610e6
DIFF: https://github.com/llvm/llvm-project/commit/35aa73746c85563912765567850346b48c6610e6.diff
LOG: [mlgo] Allow logging the spec for the "advice", if needed
This is needed by the interactive model runner, so that it can confirm the
advice's tensor spec with its host.
Added:
Modified:
llvm/include/llvm/Analysis/Utils/TrainingLogger.h
llvm/lib/Analysis/InteractiveModelRunner.cpp
llvm/lib/Analysis/TrainingLogger.cpp
llvm/unittests/Analysis/MLModelRunnerTest.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/Analysis/Utils/TrainingLogger.h b/llvm/include/llvm/Analysis/Utils/TrainingLogger.h
index 75b79e622a56e..ef6018914d34a 100644
--- a/llvm/include/llvm/Analysis/Utils/TrainingLogger.h
+++ b/llvm/include/llvm/Analysis/Utils/TrainingLogger.h
@@ -96,7 +96,7 @@ class Logger final {
StringMap<size_t> ObservationIDs;
std::string CurrentContext;
- void writeHeader();
+ void writeHeader(std::optional<TensorSpec> AdviceSpec);
void writeTensor(const TensorSpec &Spec, const char *RawData) {
OS->write(RawData, Spec.getTotalTensorBufferSize());
}
@@ -111,7 +111,8 @@ class Logger final {
/// corresponding to the model being trained/logged.
Logger(std::unique_ptr<raw_ostream> OS,
const std::vector<TensorSpec> &FeatureSpecs,
- const TensorSpec &RewardSpec, bool IncludeReward);
+ const TensorSpec &RewardSpec, bool IncludeReward,
+ std::optional<TensorSpec> AdviceSpec = std::nullopt);
void switchContext(StringRef Name);
void startObservation();
diff --git a/llvm/lib/Analysis/InteractiveModelRunner.cpp b/llvm/lib/Analysis/InteractiveModelRunner.cpp
index 40a4b5089085b..a347b49eb0729 100644
--- a/llvm/lib/Analysis/InteractiveModelRunner.cpp
+++ b/llvm/lib/Analysis/InteractiveModelRunner.cpp
@@ -37,7 +37,7 @@ InteractiveModelRunner::InteractiveModelRunner(
InputSpecs(Inputs), OutputSpec(Advice), Inbound(InboundName, InEC),
OutputBuffer(OutputSpec.getTotalTensorBufferSize()),
Log(std::make_unique<raw_fd_ostream>(OutboundName, OutEC), InputSpecs,
- Advice, /*IncludeReward=*/false) {
+ Advice, /*IncludeReward=*/false, Advice) {
if (InEC) {
Ctx.emitError("Cannot open inbound file: " + InEC.message());
return;
diff --git a/llvm/lib/Analysis/TrainingLogger.cpp b/llvm/lib/Analysis/TrainingLogger.cpp
index dcee8d40c53d7..e236890aa2bcc 100644
--- a/llvm/lib/Analysis/TrainingLogger.cpp
+++ b/llvm/lib/Analysis/TrainingLogger.cpp
@@ -32,7 +32,7 @@ static cl::opt<bool>
UseSimpleLogger("tfutils-use-simplelogger", cl::init(true), cl::Hidden,
cl::desc("Output simple (non-protobuf) log."));
-void Logger::writeHeader() {
+void Logger::writeHeader(std::optional<TensorSpec> AdviceSpec) {
json::OStream JOS(*OS);
JOS.object([&]() {
JOS.attributeArray("features", [&]() {
@@ -44,6 +44,11 @@ void Logger::writeHeader() {
RewardSpec.toJSON(JOS);
JOS.attributeEnd();
}
+ if (AdviceSpec.has_value()) {
+ JOS.attributeBegin("advice");
+ AdviceSpec->toJSON(JOS);
+ JOS.attributeEnd();
+ }
});
*OS << "\n";
}
@@ -81,8 +86,9 @@ void Logger::logRewardImpl(const char *RawData) {
Logger::Logger(std::unique_ptr<raw_ostream> OS,
const std::vector<TensorSpec> &FeatureSpecs,
- const TensorSpec &RewardSpec, bool IncludeReward)
+ const TensorSpec &RewardSpec, bool IncludeReward,
+ std::optional<TensorSpec> AdviceSpec)
: OS(std::move(OS)), FeatureSpecs(FeatureSpecs), RewardSpec(RewardSpec),
IncludeReward(IncludeReward) {
- writeHeader();
+ writeHeader(AdviceSpec);
}
diff --git a/llvm/unittests/Analysis/MLModelRunnerTest.cpp b/llvm/unittests/Analysis/MLModelRunnerTest.cpp
index 3750516ef7d58..1f80eb7820983 100644
--- a/llvm/unittests/Analysis/MLModelRunnerTest.cpp
+++ b/llvm/unittests/Analysis/MLModelRunnerTest.cpp
@@ -185,6 +185,7 @@ TEST(InteractiveModelRunner, Evaluation) {
auto Header = json::parse(ReadLn());
EXPECT_FALSE(Header.takeError());
EXPECT_NE(Header->getAsObject()->getArray("features"), nullptr);
+ EXPECT_NE(Header->getAsObject()->getObject("advice"), nullptr);
// Then comes the context
EXPECT_FALSE(json::parse(ReadLn()).takeError());
More information about the llvm-commits mailing list