[llvm] 83080a2 - [llvm] Native size estimator for training -Oz inliner

Mircea Trofin via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 13 10:14:15 PDT 2020


Author: Mircea Trofin
Date: 2020-07-13T10:13:56-07:00
New Revision: 83080a294ad7d145d758821bcf4354ad0cb7d299

URL: https://github.com/llvm/llvm-project/commit/83080a294ad7d145d758821bcf4354ad0cb7d299
DIFF: https://github.com/llvm/llvm-project/commit/83080a294ad7d145d758821bcf4354ad0cb7d299.diff

LOG: [llvm] Native size estimator for training -Oz inliner

Summary:
This is an experimental ML-based native size estimator, necessary for
computing partial rewards during -Oz inliner policy training. Data
extraction for model training will be provided in a separate patch.

RFC: http://lists.llvm.org/pipermail/llvm-dev/2020-April/140763.html

Reviewers: davidxl, jdoerfert

Subscribers: mgorny, hiraditya, mgrang, arphaman, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D82817

Added: 
    llvm/include/llvm/Analysis/InlineSizeEstimatorAnalysis.h
    llvm/include/llvm/Analysis/Utils/TFUtils.h
    llvm/lib/Analysis/InlineSizeEstimatorAnalysis.cpp
    llvm/lib/Analysis/TFUtils.cpp
    llvm/unittests/Analysis/InlineSizeEstimatorAnalysisTest.cpp
    llvm/unittests/Analysis/Inputs/ir2native_x86_64_model/saved_model.pbtxt
    llvm/unittests/Analysis/Inputs/ir2native_x86_64_model/variables/variables.data-00000-of-00001
    llvm/unittests/Analysis/Inputs/ir2native_x86_64_model/variables/variables.index
    llvm/unittests/Analysis/TFUtilsTest.cpp

Modified: 
    llvm/CMakeLists.txt
    llvm/lib/Analysis/CMakeLists.txt
    llvm/lib/Passes/PassBuilder.cpp
    llvm/lib/Passes/PassRegistry.def
    llvm/unittests/Analysis/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/llvm/CMakeLists.txt b/llvm/CMakeLists.txt
index de2887b64c2a..4e14e61fcacd 100644
--- a/llvm/CMakeLists.txt
+++ b/llvm/CMakeLists.txt
@@ -981,6 +981,22 @@ if (NOT TENSORFLOW_AOT_PATH STREQUAL "")
     ${CMAKE_ARCHIVE_OUTPUT_DIRECTORY}/tf_runtime)
 endif()
 
+set(TENSORFLOW_C_LIB_PATH "" CACHE PATH "Path to TensorFlow C library install")
+# Similar to the above Tensorflow dependency, please refer to the same script.
+# In this case, the latest C API library is available for download from
+# https://www.tensorflow.org/install/lang_c
+# Only search the given install: with an empty path, find_library would also
+# probe the default system paths and could enable the TF API unintentionally.
+if (NOT TENSORFLOW_C_LIB_PATH STREQUAL "")
+  find_library(tensorflow_c_api tensorflow PATHS ${TENSORFLOW_C_LIB_PATH}/lib
+    NO_DEFAULT_PATH)
+endif()
+if (tensorflow_c_api)
+  set(LLVM_HAVE_TF_API "ON" CACHE BOOL "Full Tensorflow API available")
+  add_definitions("-DLLVM_HAVE_TF_API")
+  include_directories(${TENSORFLOW_C_LIB_PATH}/include)
+endif()
+
 # Put this before tblgen. Else we have a circular dependence.
 add_subdirectory(lib/Demangle)
 add_subdirectory(lib/Support)

diff  --git a/llvm/include/llvm/Analysis/InlineSizeEstimatorAnalysis.h b/llvm/include/llvm/Analysis/InlineSizeEstimatorAnalysis.h
new file mode 100644
index 000000000000..29a6f5914674
--- /dev/null
+++ b/llvm/include/llvm/Analysis/InlineSizeEstimatorAnalysis.h
@@ -0,0 +1,48 @@
+//===- InlineSizeEstimatorAnalysis.h - ML size estimator --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Function analysis estimating a function's native code size with an ML model.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_ANALYSIS_INLINESIZEESTIMATORANALYSIS_H
+#define LLVM_ANALYSIS_INLINESIZEESTIMATORANALYSIS_H
+
+#include "llvm/ADT/Optional.h"
+#include "llvm/IR/PassManager.h"
+
+#include <cstddef>
+#include <memory>
+
+namespace llvm {
+class Function;
+
+class TFModelEvaluator;
+
+/// Estimates the native size of a function from its IR using a TensorFlow
+/// model. The result is None when no model was requested (or it failed to
+/// load or evaluate), and always None in builds without the TF C API.
+class InlineSizeEstimatorAnalysis
+    : public AnalysisInfoMixin<InlineSizeEstimatorAnalysis> {
+public:
+  InlineSizeEstimatorAnalysis();
+  InlineSizeEstimatorAnalysis(InlineSizeEstimatorAnalysis &&);
+  ~InlineSizeEstimatorAnalysis();
+
+  static AnalysisKey Key;
+  using Result = Optional<size_t>;
+  Result run(const Function &F, FunctionAnalysisManager &FAM);
+  /// Whether a model path was provided (always false without the TF C API).
+  static bool isEvaluatorRequested();
+
+private:
+  /// Null unless a model was requested and loaded successfully.
+  std::unique_ptr<TFModelEvaluator> Evaluator;
+};
+} // namespace llvm
+#endif // LLVM_ANALYSIS_INLINESIZEESTIMATORANALYSIS_H

diff  --git a/llvm/include/llvm/Analysis/Utils/TFUtils.h b/llvm/include/llvm/Analysis/Utils/TFUtils.h
new file mode 100644
index 000000000000..a1d7108b149f
--- /dev/null
+++ b/llvm/include/llvm/Analysis/Utils/TFUtils.h
@@ -0,0 +1,144 @@
+//===- TFUtils.h - utilities for tensorflow C API ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Utilities for evaluating a TensorFlow SavedModel through the TF C API.
+//
+//===----------------------------------------------------------------------===//
+#ifndef LLVM_ANALYSIS_UTILS_TFUTILS_H
+#define LLVM_ANALYSIS_UTILS_TFUTILS_H
+
+#include "tensorflow/c/c_api.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/IR/LLVMContext.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+namespace llvm {
+
+/// Load a SavedModel, find the given inputs and outputs, and setup storage
+/// for input tensors. The user is responsible for correctly dimensioning the
+/// input tensors and setting their values before calling evaluate().
+/// To initialize:
+/// - construct the object, and check isValid()
+/// - initialize the input tensors using initInput. Indices must correspond to
+///   indices in the InputNames used at construction.
+/// To use:
+/// - set input values by using getInput to get each input tensor, and then
+///   setting internal scalars, for all dimensions (tensors are row-major:
+///   https://github.com/tensorflow/tensorflow/blob/r1.5/tensorflow/c/c_api.h#L205)
+/// - call evaluate. The input tensors' values are not consumed after this, and
+///   may still be read.
+/// - read the outputs via EvaluationResult::getTensorValue, indexed like
+///   OutputNames. The result owns the output tensors and deletes them when it
+///   is destroyed, so use (or copy) the values before then.
+class TFModelEvaluator final {
+public:
+  /// The result of a model evaluation. Handles the lifetime of the output
+  /// TF_Tensor objects, which means that their values need to be used before
+  /// the EvaluationResult's dtor is called.
+  class EvaluationResult {
+  public:
+    ~EvaluationResult() {
+      for (auto *P : Output)
+        if (P)
+          TF_DeleteTensor(P);
+    }
+
+    EvaluationResult(const EvaluationResult &) = delete;
+    EvaluationResult(EvaluationResult &&Other)
+        : OutputSize(Other.OutputSize), Output(std::move(Other.Output)) {
+      // Leave Other empty so its destructor does not delete our tensors.
+      Other.Output.clear();
+    }
+
+    /// Get a pointer to the first element of the tensor at Index. The pointer
+    /// is only valid while this EvaluationResult is alive.
+    template <typename T> T *getTensorValue(size_t Index) {
+      return static_cast<T *>(TF_TensorData(Output[Index]));
+    }
+
+  private:
+    friend class TFModelEvaluator;
+    explicit EvaluationResult(size_t OutputSize)
+        : OutputSize(OutputSize), Output(OutputSize) {}
+
+    const size_t OutputSize;
+    std::vector<TF_Tensor *> Output;
+  };
+
+  using TFGraphPtr = std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)>;
+  using TFSessionOptionsPtr =
+      std::unique_ptr<TF_SessionOptions, decltype(&TF_DeleteSessionOptions)>;
+  using TFStatusPtr = std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)>;
+
+  TFModelEvaluator(StringRef SavedModelPath,
+                   const std::vector<std::string> &InputNames,
+                   const std::vector<std::string> &OutputNames,
+                   const char *Tags = "serve");
+  ~TFModelEvaluator();
+  TFModelEvaluator(const TFModelEvaluator &) = delete;
+  TFModelEvaluator(TFModelEvaluator &&) = delete;
+
+  /// Evaluate the model, assuming it is valid. Returns None if the evaluation
+  /// fails or the model is invalid, or an EvaluationResult otherwise. The
+  /// inputs are assumed to have been already provided via getInput(). When
+  /// returning None, it also marks the object invalid. The EvaluationResult
+  /// holds one output tensor per OutputNames entry, in the same order, and
+  /// deletes them when destroyed.
+  Optional<EvaluationResult> evaluate();
+
+  /// Provides access to the input vector. It is already dimensioned correctly,
+  /// but the values need to be allocated by the user (see initInput).
+  std::vector<TF_Tensor *> &getInput() { return Input; }
+
+  /// Returns true if the tensorflow model was loaded successfully, false
+  /// otherwise.
+  bool isValid() const { return !!Session; }
+
+  /// Initialize the input at Index as a zero-filled tensor of the given type
+  /// and dimensions. The evaluator owns the tensor.
+  void initInput(int Index, TF_DataType Type,
+                 const std::vector<int64_t> &Dimensions);
+
+private:
+  /// The objects necessary for carrying out an evaluation of the SavedModel.
+  /// They are expensive to set up, and we maintain them across all the
+  /// evaluations of the model.
+  TF_Session *Session = nullptr;
+  TFGraphPtr Graph;
+  TFSessionOptionsPtr Options;
+
+  /// The specification of the input nodes.
+  std::vector<TF_Output> InputFeed;
+
+  /// The input tensors. They must match by index of the corresponding InputFeed
+  /// value. We set up the tensors once and just mutate their scalars before
+  /// each evaluation. The input tensors keep their value after an evaluation.
+  std::vector<TF_Tensor *> Input;
+
+  /// The specification of the output nodes. When evaluating, the tensors in the
+  /// output tensor vector must match by index the corresponding element in the
+  /// OutputFeed.
+  std::vector<TF_Output> OutputFeed;
+
+  /// Delete the session, if any; afterwards isValid() returns false.
+  void deleteSession();
+
+  /// Returns true if Output was bound to a graph node; otherwise reports Name
+  /// as missing, deletes the session (invalidating the evaluator), and fails.
+  bool checkReportAndReset(const TF_Output &Output, StringRef Name);
+};
+} // namespace llvm
+
+#endif // LLVM_ANALYSIS_UTILS_TFUTILS_H

diff  --git a/llvm/lib/Analysis/CMakeLists.txt b/llvm/lib/Analysis/CMakeLists.txt
index a317579ecc83..703623396d96 100644
--- a/llvm/lib/Analysis/CMakeLists.txt
+++ b/llvm/lib/Analysis/CMakeLists.txt
@@ -1,17 +1,36 @@
 set(CommonMLSources MLInlineAdvisor.cpp)
 set(ReleaseModeMLSources ReleaseModeModelRunner.cpp)
+set(DevelopmentModeMLSources TFUtils.cpp)
 
-if (DEFINED LLVM_HAVE_TF_AOT)
-  include(TensorFlowCompile)
-  tfcompile(models/inliner serve action InlinerSizeModel llvm::InlinerSizeModel)
-  list(APPEND ReleaseModeMLSources
-    $<TARGET_OBJECTS:tf_xla_runtime_objects>
-    ${GENERATED_OBJS}
-  )
-  set(MLPolicySources ${CommonMLSources} ${ReleaseModeMLSources})
+# Sources for ML modes whose TensorFlow dependency is unavailable are listed
+# in LLVM_OPTIONAL_SOURCES so LLVM's unused-source check accepts them.
+if (DEFINED LLVM_HAVE_TF_AOT OR DEFINED LLVM_HAVE_TF_API)
+  set(MLPolicySources ${CommonMLSources})
+  if (DEFINED LLVM_HAVE_TF_AOT)
+    include(TensorFlowCompile)
+    tfcompile(models/inliner serve action InlinerSizeModel llvm::InlinerSizeModel)
+    list(APPEND ReleaseModeMLSources
+      $<TARGET_OBJECTS:tf_xla_runtime_objects>
+      ${GENERATED_OBJS}
+    )
+    list(APPEND MLPolicySources ${ReleaseModeMLSources})
+  else()
+    list(APPEND LLVM_OPTIONAL_SOURCES ${ReleaseModeMLSources})
+  endif()
+
+  if (DEFINED LLVM_HAVE_TF_API)
+    list(APPEND MLPolicySources ${DevelopmentModeMLSources})
+    list(APPEND MLLinkDeps ${tensorflow_c_api})
+  else()
+    list(APPEND LLVM_OPTIONAL_SOURCES ${DevelopmentModeMLSources})
+  endif()
 else()
-  set(LLVM_OPTIONAL_SOURCES ${CommonMLSources} ${ReleaseModeMLSources})
+  list(APPEND LLVM_OPTIONAL_SOURCES
+    ${CommonMLSources}
+    ${DevelopmentModeMLSources}
+    ${ReleaseModeMLSources}
+    )
 endif()
 
 add_llvm_component_library(LLVMAnalysis
   AliasAnalysis.cpp
@@ -57,6 +76,7 @@ add_llvm_component_library(LLVMAnalysis
   InlineCost.cpp
   InlineAdvisor.cpp
   InlineFeaturesAnalysis.cpp
+  InlineSizeEstimatorAnalysis.cpp
   InstCount.cpp
   InstructionPrecedenceTracking.cpp
   InstructionSimplify.cpp
@@ -124,4 +144,7 @@ add_llvm_component_library(LLVMAnalysis
 
   DEPENDS
   intrinsics_gen
+
+  LINK_LIBS
+  ${MLLinkDeps}
   )

diff  --git a/llvm/lib/Analysis/InlineSizeEstimatorAnalysis.cpp b/llvm/lib/Analysis/InlineSizeEstimatorAnalysis.cpp
new file mode 100644
index 000000000000..1d1952ae6cbb
--- /dev/null
+++ b/llvm/lib/Analysis/InlineSizeEstimatorAnalysis.cpp
@@ -0,0 +1,299 @@
+//===- InlineSizeEstimatorAnalysis.cpp - IR to native size from ML model --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This implements feature extraction for, and evaluation of, an ML model that
+// estimates a function's native code size from its IR. Without the TensorFlow
+// C API (LLVM_HAVE_TF_API), the analysis is a stub that always returns None.
+//
+//===----------------------------------------------------------------------===//
+#include "llvm/Analysis/InlineSizeEstimatorAnalysis.h"
+
+#ifdef LLVM_HAVE_TF_API
+#include "llvm/Analysis/Utils/TFUtils.h"
+#endif
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/MC/MCAsmLayout.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include <algorithm>
+#include <deque>
+
+using namespace llvm;
+#define DEBUG_TYPE "inline-size-estimator"
+
+AnalysisKey InlineSizeEstimatorAnalysis::Key;
+
+#ifdef LLVM_HAVE_TF_API
+// Deliberately non-static: unittests reference it through an extern decl.
+cl::opt<std::string> TFIR2NativeModelPath(
+    "ml-inliner-ir2native-model", cl::Hidden,
+    cl::desc("Path to saved model evaluating native size from IR."));
+
+namespace {
+unsigned getMaxInstructionID() {
+#define LAST_OTHER_INST(NR) return NR;
+#include "llvm/IR/Instruction.def"
+}
+
+class IRToNativeSizeLearning {
+public:
+  enum class NamedFeatureIndex : size_t {
+    InitialSize,
+    Blocks,
+    Calls,
+    IsLocal,
+    IsLinkOnceODR,
+    IsLinkOnce,
+    Loops,
+    MaxLoopDepth,
+    MaxDomTreeLevel,
+
+    NumNamedFeatures
+  };
+  static const size_t NumNamedFeatures =
+      static_cast<size_t>(NamedFeatureIndex::NumNamedFeatures);
+  struct FunctionFeatures {
+    static std::vector<std::pair<size_t, size_t>>
+        ImportantInstructionSuccessions;
+    static const size_t FeatureCount;
+
+    std::array<int32_t, NumNamedFeatures> NamedFeatures = {0};
+    std::vector<int32_t> InstructionHistogram;
+    std::vector<int32_t> InstructionPairHistogram;
+
+    void fillTensor(int32_t *Ptr) const;
+    int32_t &operator[](NamedFeatureIndex Pos) {
+      return NamedFeatures[static_cast<size_t>(Pos)];
+    }
+  };
+  IRToNativeSizeLearning() = default;
+
+  static FunctionFeatures getFunctionFeatures(Function &F,
+                                              FunctionAnalysisManager &FAM);
+
+private:
+  /// Sorts ImportantInstructionSuccessions once, on first TupleSorter use.
+  struct SortFeatureTuples {
+    bool IsSorted = false;
+    SortFeatureTuples() {
+      std::sort(FunctionFeatures::ImportantInstructionSuccessions.begin(),
+                FunctionFeatures::ImportantInstructionSuccessions.end());
+      IsSorted = true;
+    }
+  };
+
+  static llvm::ManagedStatic<SortFeatureTuples> TupleSorter;
+
+  static bool ensureSortedTuples() { return TupleSorter->IsSorted; }
+};
+llvm::ManagedStatic<IRToNativeSizeLearning::SortFeatureTuples>
+    IRToNativeSizeLearning::TupleSorter;
+
+// This is a point in time - we determined that including these pairs of
+// consecutive instruction opcodes (in the IR layout available at inline time)
+// as features improves model performance. We want to move away from manual
+// feature selection. Pairs are opcodes rather than labels because 1) labels
+// weren't readily available, and 2) the successions were hand-extracted. NOTE:
+// TupleSorter sorts this in place on first access, reordering pair features.
+std::vector<std::pair<size_t, size_t>>
+    IRToNativeSizeLearning::FunctionFeatures::ImportantInstructionSuccessions =
+        {{1, 34},  {15, 27}, {53, 53}, {53, 34}, {1, 11},  {32, 2},  {2, 48},
+         {28, 48}, {1, 45},  {49, 32}, {57, 56}, {55, 53}, {1, 28},  {57, 34},
+         {1, 1},   {32, 28}, {32, 15}, {49, 28}, {53, 1},  {2, 53},  {48, 34},
+         {28, 53}, {2, 32},  {1, 40},  {32, 48}, {29, 56}, {56, 32}, {55, 56},
+         {48, 56}, {1, 31},  {33, 34}, {2, 28},  {1, 12},  {55, 1},  {31, 31},
+         {65, 1},  {33, 56}, {32, 32}, {13, 13}, {1, 26},  {13, 26}, {2, 1},
+         {1, 33},  {47, 49}, {64, 1},  {2, 38},  {34, 53}, {48, 2},  {55, 34},
+         {34, 32}, {1, 5},   {56, 13}, {2, 2},   {2, 49},  {33, 2},  {49, 39},
+         {56, 49}, {33, 49}, {32, 39}, {39, 57}, {29, 33}, {31, 34}, {32, 29},
+         {47, 15}, {13, 34}, {2, 33},  {32, 49}, {49, 34}, {56, 33}, {1, 30},
+         {33, 33}, {31, 33}, {2, 29},  {56, 7},  {32, 13}, {2, 55},  {56, 56},
+         {2, 34},  {1, 42},  {34, 49}, {1, 20},  {32, 33}, {1, 25},  {53, 28},
+         {1, 14},  {31, 49}, {28, 2},  {2, 13},  {2, 56},  {1, 32},  {56, 53},
+         {65, 65}, {33, 53}, {64, 64}, {13, 2},  {34, 33}, {1, 4},   {49, 2},
+         {1, 9},   {56, 1},  {33, 1},  {53, 57}, {32, 53}, {13, 56}, {32, 56},
+         {55, 55}, {1, 18},  {49, 56}, {34, 34}, {1, 7},   {56, 64}, {32, 1},
+         {13, 33}, {55, 28}, {49, 33}, {57, 57}, {56, 34}, {34, 56}, {33, 32},
+         {32, 40}, {1, 29},  {53, 2},  {34, 1},  {32, 34}, {49, 49}, {1, 24},
+         {40, 34}, {1, 13},  {38, 34}, {29, 2},  {34, 2},  {1, 39},  {1, 22},
+         {1, 27},  {49, 1},  {1, 8},   {56, 2}};
+
+// We have: 9 named features (NamedFeatureIndex); 1 feature for each
+// instruction opcode; and 1 feature for each manually-identified succession.
+// For the latter 2, we build a histogram: we count the number of
+// occurrences of each instruction opcode or succession of instructions,
+// respectively. fillTensor lays them out in that same order.
+// Note that instruction opcodes start from 1. For convenience, we also have an
+// always 0 feature for the '0' opcode, hence the extra 1.
+const size_t IRToNativeSizeLearning::FunctionFeatures::FeatureCount =
+    IRToNativeSizeLearning::FunctionFeatures::ImportantInstructionSuccessions
+        .size() +
+    getMaxInstructionID() + 1 + IRToNativeSizeLearning::NumNamedFeatures;
+
+size_t getSize(Function &F, TargetTransformInfo &TTI) {
+  size_t Ret = 0;
+  for (auto &BB : F)
+    for (auto &I : BB)
+      Ret += TTI.getInstructionCost(
+          &I, TargetTransformInfo::TargetCostKind::TCK_CodeSize);
+  return Ret;
+}
+
+size_t getSize(Function &F, FunctionAnalysisManager &FAM) {
+  auto &TTI = FAM.getResult<TargetIRAnalysis>(F);
+  return getSize(F, TTI);
+}
+
+unsigned getMaxDominatorTreeDepth(const Function &F,
+                                  const DominatorTree &Tree) {
+  unsigned Ret = 0;
+  for (auto &BB : F)
+    if (auto *TN = Tree.getNode(&BB))
+      Ret = std::max(Ret, TN->getLevel());
+  return Ret;
+}
+} // namespace
+
+/// Compute the model's input features for F: the named scalar features, a
+/// per-opcode histogram, and a histogram of selected opcode successions.
+IRToNativeSizeLearning::FunctionFeatures
+IRToNativeSizeLearning::getFunctionFeatures(Function &F,
+                                            FunctionAnalysisManager &FAM) {
+  // Sort outside of assert(): under NDEBUG the call would be compiled out and
+  // the pair-histogram layout would silently differ between build modes.
+  const bool TuplesSorted = ensureSortedTuples();
+  assert(TuplesSorted && "expected lazy initialization");
+  (void)TuplesSorted;
+  auto &DomTree = FAM.getResult<DominatorTreeAnalysis>(F);
+  FunctionFeatures FF;
+  FF.InstructionHistogram.resize(getMaxInstructionID() + 1);
+  const auto &Successions = FunctionFeatures::ImportantInstructionSuccessions;
+  FF.InstructionPairHistogram.resize(Successions.size());
+  // Index of (A, B) in the (now sorted) succession table, or -1 if absent.
+  auto getPairIndex = [&Successions](size_t A, size_t B) {
+    const auto Key = std::make_pair(A, B);
+    auto I = std::lower_bound(Successions.begin(), Successions.end(), Key);
+    if (I == Successions.end() || *I != Key)
+      return -1;
+    return static_cast<int>(std::distance(Successions.begin(), I));
+  };
+
+  // Opcodes start at 1, so the initial LastID of 0 never forms a pair. LastID
+  // is deliberately carried across basic blocks (IR layout order).
+  unsigned LastID = 0;
+  // We don't want debug calls, because they'd just add noise.
+  for (auto &BB : F) {
+    for (const Instruction &I : BB.instructionsWithoutDebug()) {
+      unsigned ID = I.getOpcode();
+      ++FF.InstructionHistogram[ID];
+      int PairIndex = getPairIndex(LastID, ID);
+      if (PairIndex >= 0)
+        ++FF.InstructionPairHistogram[PairIndex];
+      LastID = ID;
+      if (isa<CallBase>(I))
+        ++FF[NamedFeatureIndex::Calls];
+    }
+  }
+
+  FF[NamedFeatureIndex::InitialSize] = getSize(F, FAM);
+  FF[NamedFeatureIndex::IsLocal] = F.hasLocalLinkage();
+  FF[NamedFeatureIndex::IsLinkOnceODR] = F.hasLinkOnceODRLinkage();
+  FF[NamedFeatureIndex::IsLinkOnce] = F.hasLinkOnceLinkage();
+  FF[NamedFeatureIndex::Blocks] =
+      std::distance(F.getBasicBlockList().begin(), F.getBasicBlockList().end());
+  auto &LI = FAM.getResult<LoopAnalysis>(F);
+  FF[NamedFeatureIndex::Loops] = std::distance(LI.begin(), LI.end());
+  // NOTE(review): LI.begin()/end() visit only top-level loops (depth 1), so
+  // MaxLoopDepth is at most 1; confirm vs. training-time features first.
+  for (auto &L : LI)
+    FF[NamedFeatureIndex::MaxLoopDepth] =
+        std::max(FF[NamedFeatureIndex::MaxLoopDepth],
+                 static_cast<int32_t>(L->getLoopDepth()));
+  FF[NamedFeatureIndex::MaxDomTreeLevel] = getMaxDominatorTreeDepth(F, DomTree);
+  return FF;
+}
+
+void IRToNativeSizeLearning::FunctionFeatures::fillTensor(int32_t *Ptr) const {
+  // Layout: named features, then the opcode histogram, then the pair histogram.
+  Ptr = std::copy(NamedFeatures.begin(), NamedFeatures.end(), Ptr);
+  Ptr = std::copy(InstructionHistogram.begin(), InstructionHistogram.end(),
+                  Ptr);
+  std::copy(InstructionPairHistogram.begin(), InstructionPairHistogram.end(),
+            Ptr);
+}
+
+bool InlineSizeEstimatorAnalysis::isEvaluatorRequested() {
+  return !TFIR2NativeModelPath.getValue().empty(); // Any path => requested.
+}
+
+InlineSizeEstimatorAnalysis::InlineSizeEstimatorAnalysis() {
+  // Without a model path, Evaluator stays null and run() returns None.
+  if (!isEvaluatorRequested())
+    return;
+  const std::vector<std::string> Inputs{"serving_default_input_1"};
+  const std::vector<std::string> Outputs{"StatefulPartitionedCall"};
+  Evaluator = std::make_unique<TFModelEvaluator>(
+      TFIR2NativeModelPath.getValue().c_str(), Inputs, Outputs);
+  if (Evaluator->isValid()) {
+    // A single int32 input tensor of shape [1, FeatureCount].
+    static const std::vector<int64_t> Dim{
+        1, static_cast<int64_t>(
+               IRToNativeSizeLearning::FunctionFeatures::FeatureCount)};
+    Evaluator->initInput(0, TF_INT32, Dim);
+    return;
+  }
+  Evaluator.reset(); // Failed to load: behave as if no model was requested.
+}
+
+InlineSizeEstimatorAnalysis::Result
+InlineSizeEstimatorAnalysis::run(const Function &F,
+                                 FunctionAnalysisManager &FAM) {
+  if (!Evaluator)
+    return None;
+  auto Features = IRToNativeSizeLearning::getFunctionFeatures(
+      const_cast<Function &>(F), FAM);
+  int32_t *V = static_cast<int32_t *>(TF_TensorData(Evaluator->getInput()[0]));
+  Features.fillTensor(V);
+  auto ER = Evaluator->evaluate();
+  if (!ER)
+    return None;
+  const float Ret = *ER->getTensorValue<float>(0);
+  // Clamp first (NaN fails Ret > 0): float -> size_t is UB out of range.
+  const float Limit = static_cast<float>(~static_cast<size_t>(0));
+  return Ret > 0.0f ? (Ret < Limit ? static_cast<size_t>(Ret) : ~size_t(0)) : 0;
+}
+
+// Out of line so the unique_ptr member sees the complete TFModelEvaluator.
+InlineSizeEstimatorAnalysis::~InlineSizeEstimatorAnalysis() = default;
+InlineSizeEstimatorAnalysis::InlineSizeEstimatorAnalysis(
+    InlineSizeEstimatorAnalysis &&) = default;
+
+#else
+// No TF C API: an empty definition lets the unique_ptr member be destroyed.
+namespace llvm {
+class TFModelEvaluator {};
+} // namespace llvm
+InlineSizeEstimatorAnalysis::InlineSizeEstimatorAnalysis() = default;
+InlineSizeEstimatorAnalysis::InlineSizeEstimatorAnalysis(
+    InlineSizeEstimatorAnalysis &&) = default;
+InlineSizeEstimatorAnalysis::~InlineSizeEstimatorAnalysis() = default;
+InlineSizeEstimatorAnalysis::Result
+InlineSizeEstimatorAnalysis::run(const Function &, FunctionAnalysisManager &) {
+  return None;
+}
+bool InlineSizeEstimatorAnalysis::isEvaluatorRequested() { return false; }
+#endif
\ No newline at end of file

diff  --git a/llvm/lib/Analysis/TFUtils.cpp b/llvm/lib/Analysis/TFUtils.cpp
new file mode 100644
index 000000000000..6cd5b5c9b4ea
--- /dev/null
+++ b/llvm/lib/Analysis/TFUtils.cpp
@@ -0,0 +1,143 @@
+//===- TFUtils.cpp - tensorflow evaluation utilities ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements utilities for interfacing with the TensorFlow C API:
+// loading a SavedModel and evaluating it.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/Utils/TFUtils.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ManagedStatic.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include "tensorflow/c/c_api_experimental.h"
+
+#include <cassert>
+
+using namespace llvm;
+
+namespace {
+// Calls TF_InitMain exactly once, the first time TFLibInitializer is used.
+struct TFInitializer {
+  bool IsInitialized = false;
+  TFInitializer() {
+    assert(!IsInitialized && "TFInitialized should be called only once");
+    // TF_InitMain takes (usage, &argc, &argv): hand it a one-entry argv.
+    const char *ProgName = "";
+    const char **Argv = &ProgName;
+    int Argc = 1;
+    TF_InitMain(ProgName, &Argc, const_cast<char ***>(&Argv));
+    IsInitialized = true;
+  }
+};
+
+llvm::ManagedStatic<TFInitializer> TFLibInitializer;
+
+/// Triggers the lazy initialization above; true once TF_InitMain has run.
+bool ensureInitTF() { return TFLibInitializer->IsInitialized; }
+
+// Factories pairing each TF_New* object with its TF_Delete* deleter.
+TFModelEvaluator::TFGraphPtr createTFGraph() {
+  return {TF_NewGraph(), &TF_DeleteGraph};
+}
+TFModelEvaluator::TFStatusPtr createTFStatus() {
+  return {TF_NewStatus(), &TF_DeleteStatus};
+}
+TFModelEvaluator::TFSessionOptionsPtr createTFSessionOptions() {
+  return {TF_NewSessionOptions(), &TF_DeleteSessionOptions};
+}
+} // namespace
+
+TFModelEvaluator::TFModelEvaluator(StringRef SavedModelPath,
+                                   const std::vector<std::string> &InputNames,
+                                   const std::vector<std::string> &OutputNames,
+                                   const char *Tags)
+    : Graph(createTFGraph()), Options(createTFSessionOptions()),
+      InputFeed(InputNames.size()), Input(InputNames.size()),
+      OutputFeed(OutputNames.size()) {
+  if (!ensureInitTF()) {
+    errs() << "Tensorflow should have been initialized\n";
+    return;
+  }
+  auto Status = createTFStatus();
+  Session = TF_LoadSessionFromSavedModel(Options.get(), nullptr,
+                                         SavedModelPath.str().c_str(), &Tags, 1,
+                                         Graph.get(), nullptr, Status.get());
+  if (TF_GetCode(Status.get()) != TF_Code::TF_OK) {
+    errs() << TF_Message(Status.get()) << "\n";
+    deleteSession(); // Don't go on to look up nodes in an unloaded graph.
+    return;
+  }
+  // Bind each name to output 0 of its graph node; stop at the first miss.
+  auto BindAll = [this](std::vector<TF_Output> &Feed,
+                        const std::vector<std::string> &Names) {
+    for (size_t I = 0; I < Names.size(); ++I) {
+      Feed[I] = {TF_GraphOperationByName(Graph.get(), Names[I].c_str()), 0};
+      if (!checkReportAndReset(Feed[I], Names[I]))
+        return false;
+    }
+    return true;
+  };
+  if (BindAll(InputFeed, InputNames))
+    BindAll(OutputFeed, OutputNames);
+}
+
+TFModelEvaluator::~TFModelEvaluator() {
+  // The evaluator owns the input tensors allocated by initInput.
+  for (TF_Tensor *Tensor : Input)
+    TF_DeleteTensor(Tensor);
+  deleteSession();
+}
+
+bool TFModelEvaluator::checkReportAndReset(const TF_Output &Output,
+                                           StringRef Name) {
+  if (Output.oper)
+    return true;
+  errs() << "Could not find TF_Output named: " << Name << "\n";
+  deleteSession(); // Invalidate: isValid() is false from now on.
+  return false;
+}
+
+void TFModelEvaluator::deleteSession() {
+  if (Session == nullptr)
+    return;
+  auto Status = createTFStatus();
+  TF_DeleteSession(Session, Status.get());
+  Session = nullptr; // Invalid from now on, even if deletion reported errors.
+  if (TF_GetCode(Status.get()) != TF_Code::TF_OK)
+    errs() << "Could not delete TF session\n";
+}
+
+Optional<TFModelEvaluator::EvaluationResult> TFModelEvaluator::evaluate() {
+  if (!isValid())
+    return None;
+  EvaluationResult Ret(OutputFeed.size());
+  auto Status = createTFStatus();
+  TF_SessionRun(Session, nullptr, InputFeed.data(), Input.data(), Input.size(),
+                OutputFeed.data(), Ret.Output.data(), Ret.Output.size(),
+                nullptr, 0, nullptr, Status.get());
+  if (TF_GetCode(Status.get()) != TF_Code::TF_OK) {
+    errs() << TF_Message(Status.get()) << "\n";
+    deleteSession(); // A failed run invalidates the evaluator (see isValid()).
+    return None;
+  }
+  return Ret;
+}
+
+void TFModelEvaluator::initInput(int Index, TF_DataType Type,
+                                 const std::vector<int64_t> &Dimensions) {
+  int64_t TotalSize = TF_DataTypeSize(Type);
+  for (auto &D : Dimensions)
+    TotalSize *= D;
+  TF_DeleteTensor(Input[Index]); // Don't leak a tensor from an earlier call.
+  Input[Index] =
+      TF_AllocateTensor(Type, Dimensions.data(), Dimensions.size(), TotalSize);
+  std::memset(TF_TensorData(Input[Index]), 0, TotalSize);
+}
\ No newline at end of file

diff  --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp
index 771cdfd17aa5..7f5763467695 100644
--- a/llvm/lib/Passes/PassBuilder.cpp
+++ b/llvm/lib/Passes/PassBuilder.cpp
@@ -35,6 +35,7 @@
 #include "llvm/Analysis/IVUsers.h"
 #include "llvm/Analysis/InlineAdvisor.h"
 #include "llvm/Analysis/InlineFeaturesAnalysis.h"
+#include "llvm/Analysis/InlineSizeEstimatorAnalysis.h"
 #include "llvm/Analysis/LazyCallGraph.h"
 #include "llvm/Analysis/LazyValueInfo.h"
 #include "llvm/Analysis/LoopAccessAnalysis.h"

diff  --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def
index eb2b740db561..dfdfc3d05976 100644
--- a/llvm/lib/Passes/PassRegistry.def
+++ b/llvm/lib/Passes/PassRegistry.def
@@ -133,6 +133,7 @@ FUNCTION_ANALYSIS("loops", LoopAnalysis())
 FUNCTION_ANALYSIS("lazy-value-info", LazyValueAnalysis())
 FUNCTION_ANALYSIS("da", DependenceAnalysis())
 FUNCTION_ANALYSIS("inliner-features", InlineFeaturesAnalysis())
+FUNCTION_ANALYSIS("inliner-size-estimator", InlineSizeEstimatorAnalysis())
 FUNCTION_ANALYSIS("memdep", MemoryDependenceAnalysis())
 FUNCTION_ANALYSIS("memoryssa", MemorySSAAnalysis())
 FUNCTION_ANALYSIS("phi-values", PhiValuesAnalysis())

diff  --git a/llvm/unittests/Analysis/CMakeLists.txt b/llvm/unittests/Analysis/CMakeLists.txt
index 42f7dd3c0610..59ad444d32fb 100644
--- a/llvm/unittests/Analysis/CMakeLists.txt
+++ b/llvm/unittests/Analysis/CMakeLists.txt
@@ -6,7 +6,14 @@ set(LLVM_LINK_COMPONENTS
   TransformUtils
   )
 
-add_llvm_unittest(AnalysisTests
+# TFUtilsTest.cpp is only built when the TensorFlow C API is available.
+if (DEFINED LLVM_HAVE_TF_API)
+  list(APPEND EXTRA_TESTS TFUtilsTest.cpp)
+else()
+  list(APPEND LLVM_OPTIONAL_SOURCES TFUtilsTest.cpp)
+endif()
+
+add_llvm_unittest_with_input_files(AnalysisTests
   AliasAnalysisTest.cpp
   AliasSetTrackerTest.cpp
   AssumeBundleQueriesTest.cpp
@@ -22,6 +29,7 @@ add_llvm_unittest(AnalysisTests
   DomTreeUpdaterTest.cpp
   GlobalsModRefTest.cpp
   InlineFeaturesAnalysisTest.cpp
+  InlineSizeEstimatorAnalysisTest.cpp
   IVDescriptorsTest.cpp
   LazyCallGraphTest.cpp
   LoadsTest.cpp
@@ -40,4 +48,7 @@ add_llvm_unittest(AnalysisTests
   ValueLatticeTest.cpp
   ValueTrackingTest.cpp
   VectorUtilsTest.cpp
+  ${EXTRA_TESTS}
   )
+
+target_link_libraries(AnalysisTests PRIVATE LLVMTestingSupport)

diff  --git a/llvm/unittests/Analysis/InlineSizeEstimatorAnalysisTest.cpp b/llvm/unittests/Analysis/InlineSizeEstimatorAnalysisTest.cpp
new file mode 100644
index 000000000000..377590be016a
--- /dev/null
+++ b/llvm/unittests/Analysis/InlineSizeEstimatorAnalysisTest.cpp
@@ -0,0 +1,101 @@
+//===- InlineSizeEstimatorAnalysisTest.cpp - test for ir2native -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/InlineSizeEstimatorAnalysis.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Path.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Testing/Support/SupportHelpers.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+extern const char *TestMainArgv0;
+extern cl::opt<std::string> TFIR2NativeModelPath;
+
+#if LLVM_HAVE_TF_API
+// Directory of the checked-in ir2native_x86_64_model test input. It lives under
+// the unit-test inputs directory, which is derived from TestMainArgv0.
+static std::string getModelPath() {
+  SmallString<128> ModelDir = unittest::getInputFileDirectory(TestMainArgv0);
+  sys::path::append(ModelDir, "ir2native_x86_64_model");
+  return ModelDir.str().str();
+}
+#endif
+
+static std::unique_ptr<Module> parseIR(LLVMContext &C, const char *IR) {
+  SMDiagnostic Err;
+  std::unique_ptr<Module> Mod = parseAssemblyString(IR, Err, C);
+  if (!Mod)
+    Err.print("MLAnalysisTests", errs());
+  return Mod;
+}
+
+// Returns a FunctionAnalysisManager with the dominator-tree, pass
+// instrumentation, TTI and loop analyses registered. These are presumably the
+// analyses InlineSizeEstimatorAnalysis::run queries (TODO confirm).
+static FunctionAnalysisManager buildFAM() {
+  FunctionAnalysisManager FAM;
+  // The factories capture nothing. They are stored in FAM, which is returned
+  // out of this frame, so avoid a by-reference default capture.
+  FAM.registerPass([] { return DominatorTreeAnalysis(); });
+  FAM.registerPass([] { return PassInstrumentationAnalysis(); });
+  FAM.registerPass([] { return TargetIRAnalysis(); });
+  FAM.registerPass([] { return LoopAnalysis(); });
+  return FAM;
+}
+
+// Test model loading and evaluation. With TF support, the checked-in ir2native
+// model must produce a positive size estimate for @branches. Without TF
+// support, the analysis must produce no estimate.
+TEST(InlineSizeEstimatorAnalysis, SizeIsValidTest) {
+  LLVMContext C;
+  std::unique_ptr<Module> M = parseIR(C,
+                                      R"IR(
+target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-pc-linux-gnu"
+
+declare i32 @f1(i32)
+declare i32 @f2(i32)
+
+define i32 @branches(i32) {
+  %cond = icmp slt i32 %0, 3
+  br i1 %cond, label %then, label %else
+
+then:
+  %ret.1 = call i32 @f1(i32 %0)
+  br label %last.block
+
+else:
+  %ret.2 = call i32 @f2(i32 %0)
+  br label %last.block
+
+last.block:
+  %ret = phi i32 [%ret.1, %then], [%ret.2, %else]
+  ret i32 %ret
+}
+
+define internal i32 @top() {
+  %1 = call i32 @branches(i32 2)
+  %2 = call i32 @f1(i32 %1)
+  ret i32 %2
+}
+)IR");
+  // Stop here rather than dereferencing a null module or function below.
+  ASSERT_TRUE(M) << "failed to parse test IR";
+  Function *F = M->getFunction("branches");
+  ASSERT_TRUE(F) << "test IR is missing @branches";
+
+  FunctionAnalysisManager FAM = buildFAM();
+#if LLVM_HAVE_TF_API
+  TFIR2NativeModelPath = getModelPath();
+#endif
+
+  InlineSizeEstimatorAnalysis FA;
+  auto SizeEstimate = FA.run(*F, FAM);
+#if LLVM_HAVE_TF_API
+  // An empty result (e.g. the model failed to load or evaluate) must fail the
+  // test cleanly instead of dereferencing an empty Optional.
+  ASSERT_TRUE(SizeEstimate.hasValue());
+  EXPECT_GT(*SizeEstimate, 0);
+#else
+  EXPECT_FALSE(SizeEstimate.hasValue());
+#endif
+}

diff  --git a/llvm/unittests/Analysis/Inputs/ir2native_x86_64_model/saved_model.pbtxt b/llvm/unittests/Analysis/Inputs/ir2native_x86_64_model/saved_model.pbtxt
new file mode 100644
index 000000000000..6efdad51083d
--- /dev/null
+++ b/llvm/unittests/Analysis/Inputs/ir2native_x86_64_model/saved_model.pbtxt
@@ -0,0 +1,10596 @@
+saved_model_schema_version: 1
+meta_graphs {
+  meta_info_def {
+    stripped_op_list {
+      op {
+        name: "Const"
+        output_arg {
+          name: "output"
+          type_attr: "dtype"
+        }
+        attr {
+          name: "value"
+          type: "tensor"
+        }
+        attr {
+          name: "dtype"
+          type: "type"
+        }
+      }
+      op {
+        name: "NoOp"
+      }
+      op {
+        name: "Placeholder"
+        output_arg {
+          name: "output"
+          type_attr: "dtype"
+        }
+        attr {
+          name: "dtype"
+          type: "type"
+        }
+        attr {
+          name: "shape"
+          type: "shape"
+          default_value {
+            shape {
+              unknown_rank: true
+            }
+          }
+        }
+      }
+      op {
+        name: "ReadVariableOp"
+        input_arg {
+          name: "resource"
+          type: DT_RESOURCE
+        }
+        output_arg {
+          name: "value"
+          type_attr: "dtype"
+        }
+        attr {
+          name: "dtype"
+          type: "type"
+        }
+        is_stateful: true
+      }
+      op {
+        name: "StatefulPartitionedCall"
+        input_arg {
+          name: "args"
+          type_list_attr: "Tin"
+        }
+        output_arg {
+          name: "output"
+          type_list_attr: "Tout"
+        }
+        attr {
+          name: "Tin"
+          type: "list(type)"
+          has_minimum: true
+        }
+        attr {
+          name: "Tout"
+          type: "list(type)"
+          has_minimum: true
+        }
+        attr {
+          name: "f"
+          type: "func"
+        }
+        attr {
+          name: "config"
+          type: "string"
+          default_value {
+            s: ""
+          }
+        }
+        attr {
+          name: "config_proto"
+          type: "string"
+          default_value {
+            s: ""
+          }
+        }
+        attr {
+          name: "executor_type"
+          type: "string"
+          default_value {
+            s: ""
+          }
+        }
+        is_stateful: true
+      }
+      op {
+        name: "VarHandleOp"
+        output_arg {
+          name: "resource"
+          type: DT_RESOURCE
+        }
+        attr {
+          name: "container"
+          type: "string"
+          default_value {
+            s: ""
+          }
+        }
+        attr {
+          name: "shared_name"
+          type: "string"
+          default_value {
+            s: ""
+          }
+        }
+        attr {
+          name: "dtype"
+          type: "type"
+        }
+        attr {
+          name: "shape"
+          type: "shape"
+        }
+        is_stateful: true
+      }
+    }
+    tags: "serve"
+    tensorflow_version: "1.15.0"
+    tensorflow_git_version: "unknown"
+    stripped_default_attrs: true
+  }
+  graph_def {
+    node {
+      name: "dense/kernel"
+      op: "VarHandleOp"
+      attr {
+        key: "_output_shapes"
+        value {
+          list {
+            shape {
+            }
+          }
+        }
+      }
+      attr {
+        key: "dtype"
+        value {
+          type: DT_FLOAT
+        }
+      }
+      attr {
+        key: "shape"
+        value {
+          shape {
+            dim {
+              size: 214
+            }
+            dim {
+              size: 100
+            }
+          }
+        }
+      }
+      attr {
+        key: "shared_name"
+        value {
+          s: "dense/kernel"
+        }
+      }
+    }
+    node {
+      name: "dense/kernel/Read/ReadVariableOp"
+      op: "ReadVariableOp"
+      input: "dense/kernel"
+      attr {
+        key: "_output_shapes"
+        value {
+          list {
+            shape {
+              dim {
+                size: 214
+              }
+              dim {
+                size: 100
+              }
+            }
+          }
+        }
+      }
+      attr {
+        key: "dtype"
+        value {
+          type: DT_FLOAT
+        }
+      }
+    }
+    node {
+      name: "dense/bias"
+      op: "VarHandleOp"
+      attr {
+        key: "_output_shapes"
+        value {
+          list {
+            shape {
+            }
+          }
+        }
+      }
+      attr {
+        key: "dtype"
+        value {
+          type: DT_FLOAT
+        }
+      }
+      attr {
+        key: "shape"
+        value {
+          shape {
+            dim {
+              size: 100
+            }
+          }
+        }
+      }
+      attr {
+        key: "shared_name"
+        value {
+          s: "dense/bias"
+        }
+      }
+    }
+    node {
+      name: "dense/bias/Read/ReadVariableOp"
+      op: "ReadVariableOp"
+      input: "dense/bias"
+      attr {
+        key: "_output_shapes"
+        value {
+          list {
+            shape {
+              dim {
+                size: 100
+              }
+            }
+          }
+        }
+      }
+      attr {
+        key: "dtype"
+        value {
+          type: DT_FLOAT
+        }
+      }
+    }
+    node {
+      name: "dense_1/kernel"
+      op: "VarHandleOp"
+      attr {
+        key: "_output_shapes"
+        value {
+          list {
+            shape {
+            }
+          }
+        }
+      }
+      attr {
+        key: "dtype"
+        value {
+          type: DT_FLOAT
+        }
+      }
+      attr {
+        key: "shape"
+        value {
+          shape {
+            dim {
+              size: 100
+            }
+            dim {
+              size: 1
+            }
+          }
+        }
+      }
+      attr {
+        key: "shared_name"
+        value {
+          s: "dense_1/kernel"
+        }
+      }
+    }
+    node {
+      name: "dense_1/kernel/Read/ReadVariableOp"
+      op: "ReadVariableOp"
+      input: "dense_1/kernel"
+      attr {
+        key: "_output_shapes"
+        value {
+          list {
+            shape {
+              dim {
+                size: 100
+              }
+              dim {
+                size: 1
+              }
+            }
+          }
+        }
+      }
+      attr {
+        key: "dtype"
+        value {
+          type: DT_FLOAT
+        }
+      }
+    }
+    node {
+      name: "dense_1/bias"
+      op: "VarHandleOp"
+      attr {
+        key: "_output_shapes"
+        value {
+          list {
+            shape {
+            }
+          }
+        }
+      }
+      attr {
+        key: "dtype"
+        value {
+          type: DT_FLOAT
+        }
+      }
+      attr {
+        key: "shape"
+        value {
+          shape {
+            dim {
+              size: 1
+            }
+          }
+        }
+      }
+      attr {
+        key: "shared_name"
+        value {
+          s: "dense_1/bias"
+        }
+      }
+    }
+    node {
+      name: "dense_1/bias/Read/ReadVariableOp"
+      op: "ReadVariableOp"
+      input: "dense_1/bias"
+      attr {
+        key: "_output_shapes"
+        value {
+          list {
+            shape {
+              dim {
+                size: 1
+              }
+            }
+          }
+        }
+      }
+      attr {
+        key: "dtype"
+        value {
+          type: DT_FLOAT
+        }
+      }
+    }
+    node {
+      name: "total"
+      op: "VarHandleOp"
+      attr {
+        key: "_output_shapes"
+        value {
+          list {
+            shape {
+            }
+          }
+        }
+      }
+      attr {
+        key: "dtype"
+        value {
+          type: DT_FLOAT
+        }
+      }
+      attr {
+        key: "shape"
+        value {
+          shape {
+          }
+        }
+      }
+      attr {
+        key: "shared_name"
+        value {
+          s: "total"
+        }
+      }
+    }
+    node {
+      name: "total/Read/ReadVariableOp"
+      op: "ReadVariableOp"
+      input: "total"
+      attr {
+        key: "_output_shapes"
+        value {
+          list {
+            shape {
+            }
+          }
+        }
+      }
+      attr {
+        key: "dtype"
+        value {
+          type: DT_FLOAT
+        }
+      }
+    }
+    node {
+      name: "count"
+      op: "VarHandleOp"
+      attr {
+        key: "_output_shapes"
+        value {
+          list {
+            shape {
+            }
+          }
+        }
+      }
+      attr {
+        key: "dtype"
+        value {
+          type: DT_FLOAT
+        }
+      }
+      attr {
+        key: "shape"
+        value {
+          shape {
+          }
+        }
+      }
+      attr {
+        key: "shared_name"
+        value {
+          s: "count"
+        }
+      }
+    }
+    node {
+      name: "count/Read/ReadVariableOp"
+      op: "ReadVariableOp"
+      input: "count"
+      attr {
+        key: "_output_shapes"
+        value {
+          list {
+            shape {
+            }
+          }
+        }
+      }
+      attr {
+        key: "dtype"
+        value {
+          type: DT_FLOAT
+        }
+      }
+    }
+    node {
+      name: "total_1"
+      op: "VarHandleOp"
+      attr {
+        key: "_output_shapes"
+        value {
+          list {
+            shape {
+            }
+          }
+        }
+      }
+      attr {
+        key: "dtype"
+        value {
+          type: DT_FLOAT
+        }
+      }
+      attr {
+        key: "shape"
+        value {
+          shape {
+          }
+        }
+      }
+      attr {
+        key: "shared_name"
+        value {
+          s: "total_1"
+        }
+      }
+    }
+    node {
+      name: "total_1/Read/ReadVariableOp"
+      op: "ReadVariableOp"
+      input: "total_1"
+      attr {
+        key: "_output_shapes"
+        value {
+          list {
+            shape {
+            }
+          }
+        }
+      }
+      attr {
+        key: "dtype"
+        value {
+          type: DT_FLOAT
+        }
+      }
+    }
+    node {
+      name: "count_1"
+      op: "VarHandleOp"
+      attr {
+        key: "_output_shapes"
+        value {
+          list {
+            shape {
+            }
+          }
+        }
+      }
+      attr {
+        key: "dtype"
+        value {
+          type: DT_FLOAT
+        }
+      }
+      attr {
+        key: "shape"
+        value {
+          shape {
+          }
+        }
+      }
+      attr {
+        key: "shared_name"
+        value {
+          s: "count_1"
+        }
+      }
+    }
+    node {
+      name: "count_1/Read/ReadVariableOp"
+      op: "ReadVariableOp"
+      input: "count_1"
+      attr {
+        key: "_output_shapes"
+        value {
+          list {
+            shape {
+            }
+          }
+        }
+      }
+      attr {
+        key: "dtype"
+        value {
+          type: DT_FLOAT
+        }
+      }
+    }
+    node {
+      name: "NoOp"
+      op: "NoOp"
+    }
+    node {
+      name: "Const"
+      op: "Const"
+      device: "/device:CPU:0"
+      attr {
+        key: "_output_shapes"
+        value {
+          list {
+            shape {
+            }
+          }
+        }
+      }
+      attr {
+        key: "dtype"
+        value {
+          type: DT_STRING
+        }
+      }
+      attr {
+        key: "value"
+        value {
+          tensor {
+            dtype: DT_STRING
+            tensor_shape {
+            }
+            string_val: "\n\277\001\n\030\010\001\022\024layer_with_weights-0\n\013\010\001\022\007layer-0\n\030\010\002\022\024layer_with_weights-1\n\013\010\002\022\007layer-1\n\r\010\003\022\toptimizer\n\031\010\004\022\025regularization_losses\n\r\010\005\022\tvariables\n\027\010\006\022\023trainable_variables\n\r\010\007\022\tkeras_api\n\016\010\010\022\nsignatures\nh\n\n\010\t\022\006kernel\n\010\010\n\022\004bias\n\031\010\013\022\025regularization_losses\n\r\010\014\022\tvariables\n\027\010\r\022\023trainable_variables\n\r\010\016\022\tkeras_api\nh\n\n\010\017\022\006kernel\n\010\010\020\022\004bias\n\031\010\021\022\025regularization_losses\n\r\010\022\022\tvariables\n\027\010\023\022\023trainable_variables\n\r\010\024\022\tkeras_api\n\000\n\000\n\034\n\005\010\t\022\0010\n\005\010\n\022\0011\n\005\010\017\022\0012\n\005\010\020\022\0013\n\034\n\005\010\t\022\0010\n\005\010\n\022\0011\n\005\010\017\022\0012\n\005\010\020\022\0013\n\255\001\n\n\010\025\022\006layers\n\037\010\026\022\033layer_regularization_losses\n\033\010\027\022\027non_trainable_variables\n\021\010\030\022\rlayer_metrics\n\031\010\004\022\025regularization_losses\n\013\010\031\022\007metrics\n\r\010\005\022\tvariables\n\027\010\006\022\023trainable_variables\n\000\nX\022V\n\016VARIABLE_VALUE\022\014dense/kernel\0326layer_with_weights-0/kernel/.ATTRIBUTES/VARIABLE_VALUE\nT\022R\n\016VARIABLE_VALUE\022\ndense/bias\0324layer_with_weights-0/bias/.ATTRIBUTES/VARIABLE_VALUE\n\000\n\016\n\005\010\t\022\0010\n\005\010\n\022\0011\n\016\n\005\010\t\022\0010\n\005\010\n\022\0011\n\255\001\n\n\010\032\022\006layers\n\037\010\033\022\033layer_regularization_losses\n\033\010\034\022\027non_trainable_variables\n\021\010\035\022\rlayer_metrics\n\031\010\013\022\025regularization_losses\n\013\010\036\022\007metrics\n\r\010\014\022\tvariables\n\027\010\r\022\023trainable_variables\nZ\022X\n\016VARIABLE_VALUE\022\016dense_1/kernel\0326layer_with_weights-1/kernel/.ATTRIBUTES/VARIABLE_VALUE\nV\022T\n\016VARI
ABLE_VALUE\022\014dense_1/bias\0324layer_with_weights-1/bias/.ATTRIBUTES/VARIABLE_VALUE\n\000\n\016\n\005\010\017\022\0010\n\005\010\020\022\0011\n\016\n\005\010\017\022\0010\n\005\010\020\022\0011\n\255\001\n\n\010\037\022\006layers\n\037\010 \022\033layer_regularization_losses\n\033\010!\022\027non_trainable_variables\n\021\010\"\022\rlayer_metrics\n\031\010\021\022\025regularization_losses\n\013\010#\022\007metrics\n\r\010\022\022\tvariables\n\027\010\023\022\023trainable_variables\n\016\n\005\010\001\022\0010\n\005\010\002\022\0011\n\000\n\000\n\000\n\016\n\005\010$\022\0010\n\005\010%\022\0011\n\000\n\000\n\000\n\000\n\000\n\000\n\000\n\000\n\000\n\000\n4\n\t\010&\022\005total\n\t\010\'\022\005count\n\r\010(\022\tvariables\n\r\010)\022\tkeras_api\nD\n\t\010*\022\005total\n\t\010+\022\005count\n\016\010,\022\n_fn_kwargs\n\r\010-\022\tvariables\n\r\010.\022\tkeras_api\nO\022M\n\016VARIABLE_VALUE\022\005total\0324keras_api/metrics/0/total/.ATTRIBUTES/VARIABLE_VALUE\nO\022M\n\016VARIABLE_VALUE\022\005count\0324keras_api/metrics/0/count/.ATTRIBUTES/VARIABLE_VALUE\n\016\n\005\010&\022\0010\n\005\010\'\022\0011\n\017\n\r\010(\022\tvariables\nQ\022O\n\016VARIABLE_VALUE\022\007total_1\0324keras_api/metrics/1/total/.ATTRIBUTES/VARIABLE_VALUE\nQ\022O\n\016VARIABLE_VALUE\022\007count_1\0324keras_api/metrics/1/count/.ATTRIBUTES/VARIABLE_VALUE\n\000\n\016\n\005\010*\022\0010\n\005\010+\022\0011\n\017\n\r\010-\022\tvariables"
+          }
+        }
+      }
+    }
+    node {
+      name: "serving_default_input_1"
+      op: "Placeholder"
+      attr {
+        key: "_output_shapes"
+        value {
+          list {
+            shape {
+              dim {
+                size: -1
+              }
+              dim {
+                size: 214
+              }
+            }
+          }
+        }
+      }
+      attr {
+        key: "dtype"
+        value {
+          type: DT_INT32
+        }
+      }
+      attr {
+        key: "shape"
+        value {
+          shape {
+            dim {
+              size: -1
+            }
+            dim {
+              size: 214
+            }
+          }
+        }
+      }
+    }
+    node {
+      name: "StatefulPartitionedCall"
+      op: "StatefulPartitionedCall"
+      input: "serving_default_input_1"
+      input: "dense/kernel"
+      input: "dense/bias"
+      input: "dense_1/kernel"
+      input: "dense_1/bias"
+      attr {
+        key: "Tin"
+        value {
+          list {
+            type: DT_INT32
+            type: DT_RESOURCE
+            type: DT_RESOURCE
+            type: DT_RESOURCE
+            type: DT_RESOURCE
+          }
+        }
+      }
+      attr {
+        key: "Tout"
+        value {
+          list {
+            type: DT_FLOAT
+          }
+        }
+      }
+      attr {
+        key: "_collective_manager_ids"
+        value {
+          list {
+          }
+        }
+      }
+      attr {
+        key: "_output_shapes"
+        value {
+          list {
+            shape {
+              dim {
+                size: -1
+              }
+              dim {
+                size: 1
+              }
+            }
+          }
+        }
+      }
+      attr {
+        key: "_read_only_resource_inputs"
+        value {
+          list {
+            i: 1
+            i: 2
+            i: 3
+            i: 4
+          }
+        }
+      }
+      attr {
+        key: "config_proto"
+        value {
+          s: "\n\007\n\003CPU\020\001\n\007\n\003GPU\020\0002\002J\0008\001"
+        }
+      }
+      attr {
+        key: "f"
+        value {
+          func {
+            name: "__inference_signature_wrapper_6671"
+          }
+        }
+      }
+    }
+    node {
+      name: "saver_filename"
+      op: "Placeholder"
+      attr {
+        key: "_output_shapes"
+        value {
+          list {
+            shape {
+            }
+          }
+        }
+      }
+      attr {
+        key: "dtype"
+        value {
+          type: DT_STRING
+        }
+      }
+      attr {
+        key: "shape"
+        value {
+          shape {
+          }
+        }
+      }
+    }
+    node {
+      name: "StatefulPartitionedCall_1"
+      op: "StatefulPartitionedCall"
+      input: "saver_filename"
+      input: "dense/kernel/Read/ReadVariableOp"
+      input: "dense/bias/Read/ReadVariableOp"
+      input: "dense_1/kernel/Read/ReadVariableOp"
+      input: "dense_1/bias/Read/ReadVariableOp"
+      input: "total/Read/ReadVariableOp"
+      input: "count/Read/ReadVariableOp"
+      input: "total_1/Read/ReadVariableOp"
+      input: "count_1/Read/ReadVariableOp"
+      input: "Const"
+      attr {
+        key: "Tin"
+        value {
+          list {
+            type: DT_STRING
+            type: DT_FLOAT
+            type: DT_FLOAT
+            type: DT_FLOAT
+            type: DT_FLOAT
+            type: DT_FLOAT
+            type: DT_FLOAT
+            type: DT_FLOAT
+            type: DT_FLOAT
+            type: DT_STRING
+          }
+        }
+      }
+      attr {
+        key: "Tout"
+        value {
+          list {
+            type: DT_STRING
+          }
+        }
+      }
+      attr {
+        key: "_collective_manager_ids"
+        value {
+          list {
+          }
+        }
+      }
+      attr {
+        key: "_output_shapes"
+        value {
+          list {
+            shape {
+            }
+          }
+        }
+      }
+      attr {
+        key: "_read_only_resource_inputs"
+        value {
+          list {
+          }
+        }
+      }
+      attr {
+        key: "config_proto"
+        value {
+          s: "\n\007\n\003CPU\020\001\n\007\n\003GPU\020\0002\002J\0008\001"
+        }
+      }
+      attr {
+        key: "f"
+        value {
+          func {
+            name: "__inference__traced_save_6824"
+          }
+        }
+      }
+    }
+    node {
+      name: "StatefulPartitionedCall_2"
+      op: "StatefulPartitionedCall"
+      input: "saver_filename"
+      input: "dense/kernel"
+      input: "dense/bias"
+      input: "dense_1/kernel"
+      input: "dense_1/bias"
+      input: "total"
+      input: "count"
+      input: "total_1"
+      input: "count_1"
+      attr {
+        key: "Tin"
+        value {
+          list {
+            type: DT_STRING
+            type: DT_RESOURCE
+            type: DT_RESOURCE
+            type: DT_RESOURCE
+            type: DT_RESOURCE
+            type: DT_RESOURCE
+            type: DT_RESOURCE
+            type: DT_RESOURCE
+            type: DT_RESOURCE
+          }
+        }
+      }
+      attr {
+        key: "Tout"
+        value {
+          list {
+            type: DT_STRING
+          }
+        }
+      }
+      attr {
+        key: "_collective_manager_ids"
+        value {
+          list {
+          }
+        }
+      }
+      attr {
+        key: "_output_shapes"
+        value {
+          list {
+            shape {
+            }
+          }
+        }
+      }
+      attr {
+        key: "_read_only_resource_inputs"
+        value {
+          list {
+          }
+        }
+      }
+      attr {
+        key: "config_proto"
+        value {
+          s: "\n\007\n\003CPU\020\001\n\007\n\003GPU\020\0002\002J\0008\001"
+        }
+      }
+      attr {
+        key: "f"
+        value {
+          func {
+            name: "__inference__traced_restore_6860"
+          }
+        }
+      }
+    }
+    library {
+      function {
+        signature {
+          name: "__inference__traced_restore_6860"
+          input_arg {
+            name: "file_prefix"
+            type: DT_STRING
+          }
+          input_arg {
+            name: "assignvariableop_dense_kernel"
+            type: DT_RESOURCE
+          }
+          input_arg {
+            name: "assignvariableop_1_dense_bias"
+            type: DT_RESOURCE
+          }
+          input_arg {
+            name: "assignvariableop_2_dense_1_kernel"
+            type: DT_RESOURCE
+          }
+          input_arg {
+            name: "assignvariableop_3_dense_1_bias"
+            type: DT_RESOURCE
+          }
+          input_arg {
+            name: "assignvariableop_4_total"
+            type: DT_RESOURCE
+          }
+          input_arg {
+            name: "assignvariableop_5_count"
+            type: DT_RESOURCE
+          }
+          input_arg {
+            name: "assignvariableop_6_total_1"
+            type: DT_RESOURCE
+          }
+          input_arg {
+            name: "assignvariableop_7_count_1"
+            type: DT_RESOURCE
+          }
+          output_arg {
+            name: "identity_9"
+            type: DT_STRING
+          }
+          is_stateful: true
+          control_output: "AssignVariableOp"
+          control_output: "AssignVariableOp_1"
+          control_output: "AssignVariableOp_2"
+          control_output: "AssignVariableOp_3"
+          control_output: "AssignVariableOp_4"
+          control_output: "AssignVariableOp_5"
+          control_output: "AssignVariableOp_6"
+          control_output: "AssignVariableOp_7"
+          control_output: "RestoreV2"
+          control_output: "RestoreV2_1"
+        }
+        node_def {
+          name: "RestoreV2/tensor_names"
+          op: "Const"
+          device: "/device:CPU:0"
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: 8
+                  }
+                }
+              }
+            }
+          }
+          attr {
+            key: "dtype"
+            value {
+              type: DT_STRING
+            }
+          }
+          attr {
+            key: "value"
+            value {
+              tensor {
+                dtype: DT_STRING
+                tensor_shape {
+                  dim {
+                    size: 8
+                  }
+                }
+                string_val: "layer_with_weights-0/kernel/.ATTRIBUTES/VARIABLE_VALUE"
+                string_val: "layer_with_weights-0/bias/.ATTRIBUTES/VARIABLE_VALUE"
+                string_val: "layer_with_weights-1/kernel/.ATTRIBUTES/VARIABLE_VALUE"
+                string_val: "layer_with_weights-1/bias/.ATTRIBUTES/VARIABLE_VALUE"
+                string_val: "keras_api/metrics/0/total/.ATTRIBUTES/VARIABLE_VALUE"
+                string_val: "keras_api/metrics/0/count/.ATTRIBUTES/VARIABLE_VALUE"
+                string_val: "keras_api/metrics/1/total/.ATTRIBUTES/VARIABLE_VALUE"
+                string_val: "keras_api/metrics/1/count/.ATTRIBUTES/VARIABLE_VALUE"
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "RestoreV2/tensor_names"
+          }
+        }
+        node_def {
+          name: "RestoreV2/shape_and_slices"
+          op: "Const"
+          device: "/device:CPU:0"
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: 8
+                  }
+                }
+              }
+            }
+          }
+          attr {
+            key: "dtype"
+            value {
+              type: DT_STRING
+            }
+          }
+          attr {
+            key: "value"
+            value {
+              tensor {
+                dtype: DT_STRING
+                tensor_shape {
+                  dim {
+                    size: 8
+                  }
+                }
+                string_val: ""
+                string_val: ""
+                string_val: ""
+                string_val: ""
+                string_val: ""
+                string_val: ""
+                string_val: ""
+                string_val: ""
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "RestoreV2/shape_and_slices"
+          }
+        }
+        node_def {
+          name: "RestoreV2"
+          op: "RestoreV2"
+          input: "file_prefix"
+          input: "RestoreV2/tensor_names:output:0"
+          input: "RestoreV2/shape_and_slices:output:0"
+          device: "/device:CPU:0"
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  unknown_rank: true
+                }
+                shape {
+                  unknown_rank: true
+                }
+                shape {
+                  unknown_rank: true
+                }
+                shape {
+                  unknown_rank: true
+                }
+                shape {
+                  unknown_rank: true
+                }
+                shape {
+                  unknown_rank: true
+                }
+                shape {
+                  unknown_rank: true
+                }
+                shape {
+                  unknown_rank: true
+                }
+              }
+            }
+          }
+          attr {
+            key: "dtypes"
+            value {
+              list {
+                type: DT_FLOAT
+                type: DT_FLOAT
+                type: DT_FLOAT
+                type: DT_FLOAT
+                type: DT_FLOAT
+                type: DT_FLOAT
+                type: DT_FLOAT
+                type: DT_FLOAT
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "RestoreV2"
+          }
+        }
+        node_def {
+          name: "Identity"
+          op: "Identity"
+          input: "RestoreV2:tensors:0"
+          attr {
+            key: "T"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  unknown_rank: true
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "Identity"
+          }
+        }
+        node_def {
+          name: "AssignVariableOp"
+          op: "AssignVariableOp"
+          input: "assignvariableop_dense_kernel"
+          input: "Identity:output:0"
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+              }
+            }
+          }
+          attr {
+            key: "dtype"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "AssignVariableOp"
+          }
+        }
+        node_def {
+          name: "Identity_1"
+          op: "Identity"
+          input: "RestoreV2:tensors:1"
+          attr {
+            key: "T"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  unknown_rank: true
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "Identity_1"
+          }
+        }
+        node_def {
+          name: "AssignVariableOp_1"
+          op: "AssignVariableOp"
+          input: "assignvariableop_1_dense_bias"
+          input: "Identity_1:output:0"
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+              }
+            }
+          }
+          attr {
+            key: "dtype"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "AssignVariableOp_1"
+          }
+        }
+        node_def {
+          name: "Identity_2"
+          op: "Identity"
+          input: "RestoreV2:tensors:2"
+          attr {
+            key: "T"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  unknown_rank: true
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "Identity_2"
+          }
+        }
+        node_def {
+          name: "AssignVariableOp_2"
+          op: "AssignVariableOp"
+          input: "assignvariableop_2_dense_1_kernel"
+          input: "Identity_2:output:0"
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+              }
+            }
+          }
+          attr {
+            key: "dtype"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "AssignVariableOp_2"
+          }
+        }
+        node_def {
+          name: "Identity_3"
+          op: "Identity"
+          input: "RestoreV2:tensors:3"
+          attr {
+            key: "T"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  unknown_rank: true
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "Identity_3"
+          }
+        }
+        node_def {
+          name: "AssignVariableOp_3"
+          op: "AssignVariableOp"
+          input: "assignvariableop_3_dense_1_bias"
+          input: "Identity_3:output:0"
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+              }
+            }
+          }
+          attr {
+            key: "dtype"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "AssignVariableOp_3"
+          }
+        }
+        node_def {
+          name: "Identity_4"
+          op: "Identity"
+          input: "RestoreV2:tensors:4"
+          attr {
+            key: "T"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  unknown_rank: true
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "Identity_4"
+          }
+        }
+        node_def {
+          name: "AssignVariableOp_4"
+          op: "AssignVariableOp"
+          input: "assignvariableop_4_total"
+          input: "Identity_4:output:0"
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+              }
+            }
+          }
+          attr {
+            key: "dtype"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "AssignVariableOp_4"
+          }
+        }
+        node_def {
+          name: "Identity_5"
+          op: "Identity"
+          input: "RestoreV2:tensors:5"
+          attr {
+            key: "T"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  unknown_rank: true
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "Identity_5"
+          }
+        }
+        node_def {
+          name: "AssignVariableOp_5"
+          op: "AssignVariableOp"
+          input: "assignvariableop_5_count"
+          input: "Identity_5:output:0"
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+              }
+            }
+          }
+          attr {
+            key: "dtype"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "AssignVariableOp_5"
+          }
+        }
+        node_def {
+          name: "Identity_6"
+          op: "Identity"
+          input: "RestoreV2:tensors:6"
+          attr {
+            key: "T"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  unknown_rank: true
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "Identity_6"
+          }
+        }
+        node_def {
+          name: "AssignVariableOp_6"
+          op: "AssignVariableOp"
+          input: "assignvariableop_6_total_1"
+          input: "Identity_6:output:0"
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+              }
+            }
+          }
+          attr {
+            key: "dtype"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "AssignVariableOp_6"
+          }
+        }
+        node_def {
+          name: "Identity_7"
+          op: "Identity"
+          input: "RestoreV2:tensors:7"
+          attr {
+            key: "T"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  unknown_rank: true
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "Identity_7"
+          }
+        }
+        node_def {
+          name: "AssignVariableOp_7"
+          op: "AssignVariableOp"
+          input: "assignvariableop_7_count_1"
+          input: "Identity_7:output:0"
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+              }
+            }
+          }
+          attr {
+            key: "dtype"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "AssignVariableOp_7"
+          }
+        }
+        node_def {
+          name: "RestoreV2_1/tensor_names"
+          op: "Const"
+          device: "/device:CPU:0"
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: 1
+                  }
+                }
+              }
+            }
+          }
+          attr {
+            key: "dtype"
+            value {
+              type: DT_STRING
+            }
+          }
+          attr {
+            key: "value"
+            value {
+              tensor {
+                dtype: DT_STRING
+                tensor_shape {
+                  dim {
+                    size: 1
+                  }
+                }
+                string_val: "_CHECKPOINTABLE_OBJECT_GRAPH"
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "RestoreV2_1/tensor_names"
+          }
+        }
+        node_def {
+          name: "RestoreV2_1/shape_and_slices"
+          op: "Const"
+          device: "/device:CPU:0"
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: 1
+                  }
+                }
+              }
+            }
+          }
+          attr {
+            key: "dtype"
+            value {
+              type: DT_STRING
+            }
+          }
+          attr {
+            key: "value"
+            value {
+              tensor {
+                dtype: DT_STRING
+                tensor_shape {
+                  dim {
+                    size: 1
+                  }
+                }
+                string_val: ""
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "RestoreV2_1/shape_and_slices"
+          }
+        }
+        node_def {
+          name: "RestoreV2_1"
+          op: "RestoreV2"
+          input: "file_prefix"
+          input: "RestoreV2_1/tensor_names:output:0"
+          input: "RestoreV2_1/shape_and_slices:output:0"
+          input: "^RestoreV2"
+          device: "/device:CPU:0"
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  unknown_rank: true
+                }
+              }
+            }
+          }
+          attr {
+            key: "dtypes"
+            value {
+              list {
+                type: DT_STRING
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "RestoreV2_1"
+          }
+        }
+        node_def {
+          name: "NoOp"
+          op: "NoOp"
+          device: "/device:CPU:0"
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "NoOp"
+          }
+        }
+        node_def {
+          name: "Identity_8"
+          op: "Identity"
+          input: "file_prefix"
+          input: "^AssignVariableOp"
+          input: "^AssignVariableOp_1"
+          input: "^AssignVariableOp_2"
+          input: "^AssignVariableOp_3"
+          input: "^AssignVariableOp_4"
+          input: "^AssignVariableOp_5"
+          input: "^AssignVariableOp_6"
+          input: "^AssignVariableOp_7"
+          input: "^NoOp"
+          device: "/device:CPU:0"
+          attr {
+            key: "T"
+            value {
+              type: DT_STRING
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "Identity_8"
+          }
+        }
+        node_def {
+          name: "Identity_9"
+          op: "Identity"
+          input: "Identity_8:output:0"
+          input: "^AssignVariableOp"
+          input: "^AssignVariableOp_1"
+          input: "^AssignVariableOp_2"
+          input: "^AssignVariableOp_3"
+          input: "^AssignVariableOp_4"
+          input: "^AssignVariableOp_5"
+          input: "^AssignVariableOp_6"
+          input: "^AssignVariableOp_7"
+          input: "^RestoreV2"
+          input: "^RestoreV2_1"
+          attr {
+            key: "T"
+            value {
+              type: DT_STRING
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "Identity_9"
+          }
+        }
+        ret {
+          key: "identity_9"
+          value: "Identity_9:output:0"
+        }
+        attr {
+          key: "_input_shapes"
+          value {
+            list {
+              shape {
+              }
+              shape {
+                unknown_rank: true
+              }
+              shape {
+                unknown_rank: true
+              }
+              shape {
+                unknown_rank: true
+              }
+              shape {
+                unknown_rank: true
+              }
+              shape {
+                unknown_rank: true
+              }
+              shape {
+                unknown_rank: true
+              }
+              shape {
+                unknown_rank: true
+              }
+              shape {
+                unknown_rank: true
+              }
+            }
+          }
+        }
+        control_ret {
+          key: "AssignVariableOp"
+          value: "AssignVariableOp"
+        }
+        control_ret {
+          key: "AssignVariableOp_1"
+          value: "AssignVariableOp_1"
+        }
+        control_ret {
+          key: "AssignVariableOp_2"
+          value: "AssignVariableOp_2"
+        }
+        control_ret {
+          key: "AssignVariableOp_3"
+          value: "AssignVariableOp_3"
+        }
+        control_ret {
+          key: "AssignVariableOp_4"
+          value: "AssignVariableOp_4"
+        }
+        control_ret {
+          key: "AssignVariableOp_5"
+          value: "AssignVariableOp_5"
+        }
+        control_ret {
+          key: "AssignVariableOp_6"
+          value: "AssignVariableOp_6"
+        }
+        control_ret {
+          key: "AssignVariableOp_7"
+          value: "AssignVariableOp_7"
+        }
+        control_ret {
+          key: "RestoreV2"
+          value: "RestoreV2"
+        }
+        control_ret {
+          key: "RestoreV2_1"
+          value: "RestoreV2_1"
+        }
+        arg_attr {
+          key: 0
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+            attr {
+              key: "_user_specified_name"
+              value {
+                s: "file_prefix"
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 1
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 2
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 3
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 4
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 5
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 6
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 7
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 8
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+      }
+      function {
+        signature {
+          name: "__inference_sequential_layer_call_fn_6629"
+          input_arg {
+            name: "input_1"
+            type: DT_INT32
+          }
+          input_arg {
+            name: "unknown"
+            type: DT_RESOURCE
+          }
+          input_arg {
+            name: "unknown_0"
+            type: DT_RESOURCE
+          }
+          input_arg {
+            name: "unknown_1"
+            type: DT_RESOURCE
+          }
+          input_arg {
+            name: "unknown_2"
+            type: DT_RESOURCE
+          }
+          output_arg {
+            name: "identity"
+            type: DT_FLOAT
+          }
+          is_stateful: true
+          control_output: "StatefulPartitionedCall"
+        }
+        node_def {
+          name: "StatefulPartitionedCall"
+          op: "StatefulPartitionedCall"
+          input: "input_1"
+          input: "unknown"
+          input: "unknown_0"
+          input: "unknown_1"
+          input: "unknown_2"
+          attr {
+            key: "Tin"
+            value {
+              list {
+                type: DT_INT32
+                type: DT_RESOURCE
+                type: DT_RESOURCE
+                type: DT_RESOURCE
+                type: DT_RESOURCE
+              }
+            }
+          }
+          attr {
+            key: "Tout"
+            value {
+              list {
+                type: DT_FLOAT
+              }
+            }
+          }
+          attr {
+            key: "_collective_manager_ids"
+            value {
+              list {
+              }
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 1
+                  }
+                }
+              }
+            }
+          }
+          attr {
+            key: "_read_only_resource_inputs"
+            value {
+              list {
+                i: 1
+                i: 2
+                i: 3
+                i: 4
+              }
+            }
+          }
+          attr {
+            key: "config_proto"
+            value {
+              s: "\n\007\n\003CPU\020\001\n\007\n\003GPU\020\0002\002J\0008\001"
+            }
+          }
+          attr {
+            key: "f"
+            value {
+              func {
+                name: "__inference_sequential_layer_call_and_return_conditional_losses_6618"
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "StatefulPartitionedCall"
+          }
+        }
+        node_def {
+          name: "Identity"
+          op: "Identity"
+          input: "StatefulPartitionedCall:output:0"
+          input: "^StatefulPartitionedCall"
+          attr {
+            key: "T"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 1
+                  }
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "Identity"
+          }
+        }
+        ret {
+          key: "identity"
+          value: "Identity:output:0"
+        }
+        attr {
+          key: "_input_shapes"
+          value {
+            list {
+              shape {
+                dim {
+                  size: -1
+                }
+                dim {
+                  size: 214
+                }
+              }
+              shape {
+                unknown_rank: true
+              }
+              shape {
+                unknown_rank: true
+              }
+              shape {
+                unknown_rank: true
+              }
+              shape {
+                unknown_rank: true
+              }
+            }
+          }
+        }
+        control_ret {
+          key: "StatefulPartitionedCall"
+          value: "StatefulPartitionedCall"
+        }
+        arg_attr {
+          key: 0
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                    dim {
+                      size: -1
+                    }
+                    dim {
+                      size: 214
+                    }
+                  }
+                }
+              }
+            }
+            attr {
+              key: "_user_specified_name"
+              value {
+                s: "input_1"
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 1
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 2
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 3
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 4
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+      }
+      function {
+        signature {
+          name: "__inference_sequential_layer_call_and_return_conditional_losses_6587"
+          input_arg {
+            name: "input_1"
+            type: DT_INT32
+          }
+          input_arg {
+            name: "dense_6555"
+            type: DT_RESOURCE
+          }
+          input_arg {
+            name: "dense_6557"
+            type: DT_RESOURCE
+          }
+          input_arg {
+            name: "dense_1_6581"
+            type: DT_RESOURCE
+          }
+          input_arg {
+            name: "dense_1_6583"
+            type: DT_RESOURCE
+          }
+          output_arg {
+            name: "identity"
+            type: DT_FLOAT
+          }
+          is_stateful: true
+          control_output: "dense/StatefulPartitionedCall"
+          control_output: "dense_1/StatefulPartitionedCall"
+        }
+        node_def {
+          name: "dense/StatefulPartitionedCall"
+          op: "StatefulPartitionedCall"
+          input: "input_1"
+          input: "dense_6555"
+          input: "dense_6557"
+          attr {
+            key: "Tin"
+            value {
+              list {
+                type: DT_INT32
+                type: DT_RESOURCE
+                type: DT_RESOURCE
+              }
+            }
+          }
+          attr {
+            key: "Tout"
+            value {
+              list {
+                type: DT_FLOAT
+              }
+            }
+          }
+          attr {
+            key: "_collective_manager_ids"
+            value {
+              list {
+              }
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 100
+                  }
+                }
+              }
+            }
+          }
+          attr {
+            key: "_read_only_resource_inputs"
+            value {
+              list {
+                i: 1
+                i: 2
+              }
+            }
+          }
+          attr {
+            key: "config_proto"
+            value {
+              s: "\n\007\n\003CPU\020\001\n\007\n\003GPU\020\0002\002J\0008\001"
+            }
+          }
+          attr {
+            key: "f"
+            value {
+              func {
+                name: "__inference_dense_layer_call_and_return_conditional_losses_6544"
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "dense/StatefulPartitionedCall"
+          }
+        }
+        node_def {
+          name: "dense_1/StatefulPartitionedCall"
+          op: "StatefulPartitionedCall"
+          input: "dense/StatefulPartitionedCall:output:0"
+          input: "dense_1_6581"
+          input: "dense_1_6583"
+          attr {
+            key: "Tin"
+            value {
+              list {
+                type: DT_FLOAT
+                type: DT_RESOURCE
+                type: DT_RESOURCE
+              }
+            }
+          }
+          attr {
+            key: "Tout"
+            value {
+              list {
+                type: DT_FLOAT
+              }
+            }
+          }
+          attr {
+            key: "_collective_manager_ids"
+            value {
+              list {
+              }
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 1
+                  }
+                }
+              }
+            }
+          }
+          attr {
+            key: "_read_only_resource_inputs"
+            value {
+              list {
+                i: 1
+                i: 2
+              }
+            }
+          }
+          attr {
+            key: "config_proto"
+            value {
+              s: "\n\007\n\003CPU\020\001\n\007\n\003GPU\020\0002\002J\0008\001"
+            }
+          }
+          attr {
+            key: "f"
+            value {
+              func {
+                name: "__inference_dense_1_layer_call_and_return_conditional_losses_6570"
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "dense_1/StatefulPartitionedCall"
+          }
+        }
+        node_def {
+          name: "Identity"
+          op: "Identity"
+          input: "dense_1/StatefulPartitionedCall:output:0"
+          input: "^dense/StatefulPartitionedCall"
+          input: "^dense_1/StatefulPartitionedCall"
+          attr {
+            key: "T"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 1
+                  }
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "Identity"
+          }
+        }
+        ret {
+          key: "identity"
+          value: "Identity:output:0"
+        }
+        attr {
+          key: "_input_shapes"
+          value {
+            list {
+              shape {
+                dim {
+                  size: -1
+                }
+                dim {
+                  size: 214
+                }
+              }
+              shape {
+                unknown_rank: true
+              }
+              shape {
+                unknown_rank: true
+              }
+              shape {
+                unknown_rank: true
+              }
+              shape {
+                unknown_rank: true
+              }
+            }
+          }
+        }
+        control_ret {
+          key: "dense/StatefulPartitionedCall"
+          value: "dense/StatefulPartitionedCall"
+        }
+        control_ret {
+          key: "dense_1/StatefulPartitionedCall"
+          value: "dense_1/StatefulPartitionedCall"
+        }
+        arg_attr {
+          key: 0
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                    dim {
+                      size: -1
+                    }
+                    dim {
+                      size: 214
+                    }
+                  }
+                }
+              }
+            }
+            attr {
+              key: "_user_specified_name"
+              value {
+                s: "input_1"
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 1
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 2
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 3
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 4
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+      }
+      function {
+        signature {
+          name: "__inference_sequential_layer_call_and_return_conditional_losses_6618"
+          input_arg {
+            name: "inputs"
+            type: DT_INT32
+          }
+          input_arg {
+            name: "dense_6607"
+            type: DT_RESOURCE
+          }
+          input_arg {
+            name: "dense_6609"
+            type: DT_RESOURCE
+          }
+          input_arg {
+            name: "dense_1_6612"
+            type: DT_RESOURCE
+          }
+          input_arg {
+            name: "dense_1_6614"
+            type: DT_RESOURCE
+          }
+          output_arg {
+            name: "identity"
+            type: DT_FLOAT
+          }
+          is_stateful: true
+          control_output: "dense/StatefulPartitionedCall"
+          control_output: "dense_1/StatefulPartitionedCall"
+        }
+        node_def {
+          name: "dense/StatefulPartitionedCall"
+          op: "StatefulPartitionedCall"
+          input: "inputs"
+          input: "dense_6607"
+          input: "dense_6609"
+          attr {
+            key: "Tin"
+            value {
+              list {
+                type: DT_INT32
+                type: DT_RESOURCE
+                type: DT_RESOURCE
+              }
+            }
+          }
+          attr {
+            key: "Tout"
+            value {
+              list {
+                type: DT_FLOAT
+              }
+            }
+          }
+          attr {
+            key: "_collective_manager_ids"
+            value {
+              list {
+              }
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 100
+                  }
+                }
+              }
+            }
+          }
+          attr {
+            key: "_read_only_resource_inputs"
+            value {
+              list {
+                i: 1
+                i: 2
+              }
+            }
+          }
+          attr {
+            key: "config_proto"
+            value {
+              s: "\n\007\n\003CPU\020\001\n\007\n\003GPU\020\0002\002J\0008\001"
+            }
+          }
+          attr {
+            key: "f"
+            value {
+              func {
+                name: "__inference_dense_layer_call_and_return_conditional_losses_6544"
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "dense/StatefulPartitionedCall"
+          }
+        }
+        node_def {
+          name: "dense_1/StatefulPartitionedCall"
+          op: "StatefulPartitionedCall"
+          input: "dense/StatefulPartitionedCall:output:0"
+          input: "dense_1_6612"
+          input: "dense_1_6614"
+          attr {
+            key: "Tin"
+            value {
+              list {
+                type: DT_FLOAT
+                type: DT_RESOURCE
+                type: DT_RESOURCE
+              }
+            }
+          }
+          attr {
+            key: "Tout"
+            value {
+              list {
+                type: DT_FLOAT
+              }
+            }
+          }
+          attr {
+            key: "_collective_manager_ids"
+            value {
+              list {
+              }
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 1
+                  }
+                }
+              }
+            }
+          }
+          attr {
+            key: "_read_only_resource_inputs"
+            value {
+              list {
+                i: 1
+                i: 2
+              }
+            }
+          }
+          attr {
+            key: "config_proto"
+            value {
+              s: "\n\007\n\003CPU\020\001\n\007\n\003GPU\020\0002\002J\0008\001"
+            }
+          }
+          attr {
+            key: "f"
+            value {
+              func {
+                name: "__inference_dense_1_layer_call_and_return_conditional_losses_6570"
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "dense_1/StatefulPartitionedCall"
+          }
+        }
+        node_def {
+          name: "Identity"
+          op: "Identity"
+          input: "dense_1/StatefulPartitionedCall:output:0"
+          input: "^dense/StatefulPartitionedCall"
+          input: "^dense_1/StatefulPartitionedCall"
+          attr {
+            key: "T"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 1
+                  }
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "Identity"
+          }
+        }
+        ret {
+          key: "identity"
+          value: "Identity:output:0"
+        }
+        attr {
+          key: "_input_shapes"
+          value {
+            list {
+              shape {
+                dim {
+                  size: -1
+                }
+                dim {
+                  size: 214
+                }
+              }
+              shape {
+                unknown_rank: true
+              }
+              shape {
+                unknown_rank: true
+              }
+              shape {
+                unknown_rank: true
+              }
+              shape {
+                unknown_rank: true
+              }
+            }
+          }
+        }
+        control_ret {
+          key: "dense/StatefulPartitionedCall"
+          value: "dense/StatefulPartitionedCall"
+        }
+        control_ret {
+          key: "dense_1/StatefulPartitionedCall"
+          value: "dense_1/StatefulPartitionedCall"
+        }
+        arg_attr {
+          key: 0
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                    dim {
+                      size: -1
+                    }
+                    dim {
+                      size: 214
+                    }
+                  }
+                }
+              }
+            }
+            attr {
+              key: "_user_specified_name"
+              value {
+                s: "inputs"
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 1
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 2
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 3
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 4
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+      }
+      function {
+        signature {
+          name: "__inference_sequential_layer_call_fn_6656"
+          input_arg {
+            name: "input_1"
+            type: DT_INT32
+          }
+          input_arg {
+            name: "unknown"
+            type: DT_RESOURCE
+          }
+          input_arg {
+            name: "unknown_0"
+            type: DT_RESOURCE
+          }
+          input_arg {
+            name: "unknown_1"
+            type: DT_RESOURCE
+          }
+          input_arg {
+            name: "unknown_2"
+            type: DT_RESOURCE
+          }
+          output_arg {
+            name: "identity"
+            type: DT_FLOAT
+          }
+          is_stateful: true
+          control_output: "StatefulPartitionedCall"
+        }
+        node_def {
+          name: "StatefulPartitionedCall"
+          op: "StatefulPartitionedCall"
+          input: "input_1"
+          input: "unknown"
+          input: "unknown_0"
+          input: "unknown_1"
+          input: "unknown_2"
+          attr {
+            key: "Tin"
+            value {
+              list {
+                type: DT_INT32
+                type: DT_RESOURCE
+                type: DT_RESOURCE
+                type: DT_RESOURCE
+                type: DT_RESOURCE
+              }
+            }
+          }
+          attr {
+            key: "Tout"
+            value {
+              list {
+                type: DT_FLOAT
+              }
+            }
+          }
+          attr {
+            key: "_collective_manager_ids"
+            value {
+              list {
+              }
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 1
+                  }
+                }
+              }
+            }
+          }
+          attr {
+            key: "_read_only_resource_inputs"
+            value {
+              list {
+                i: 1
+                i: 2
+                i: 3
+                i: 4
+              }
+            }
+          }
+          attr {
+            key: "config_proto"
+            value {
+              s: "\n\007\n\003CPU\020\001\n\007\n\003GPU\020\0002\002J\0008\001"
+            }
+          }
+          attr {
+            key: "f"
+            value {
+              func {
+                name: "__inference_sequential_layer_call_and_return_conditional_losses_6645"
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "StatefulPartitionedCall"
+          }
+        }
+        node_def {
+          name: "Identity"
+          op: "Identity"
+          input: "StatefulPartitionedCall:output:0"
+          input: "^StatefulPartitionedCall"
+          attr {
+            key: "T"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 1
+                  }
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "Identity"
+          }
+        }
+        ret {
+          key: "identity"
+          value: "Identity:output:0"
+        }
+        attr {
+          key: "_input_shapes"
+          value {
+            list {
+              shape {
+                dim {
+                  size: -1
+                }
+                dim {
+                  size: 214
+                }
+              }
+              shape {
+                unknown_rank: true
+              }
+              shape {
+                unknown_rank: true
+              }
+              shape {
+                unknown_rank: true
+              }
+              shape {
+                unknown_rank: true
+              }
+            }
+          }
+        }
+        control_ret {
+          key: "StatefulPartitionedCall"
+          value: "StatefulPartitionedCall"
+        }
+        arg_attr {
+          key: 0
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                    dim {
+                      size: -1
+                    }
+                    dim {
+                      size: 214
+                    }
+                  }
+                }
+              }
+            }
+            attr {
+              key: "_user_specified_name"
+              value {
+                s: "input_1"
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 1
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 2
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 3
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 4
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+      }
+      function {
+        signature {
+          name: "__inference_dense_1_layer_call_and_return_conditional_losses_6764"
+          input_arg {
+            name: "inputs"
+            type: DT_FLOAT
+          }
+          input_arg {
+            name: "matmul_readvariableop_resource"
+            type: DT_RESOURCE
+          }
+          input_arg {
+            name: "biasadd_readvariableop_resource"
+            type: DT_RESOURCE
+          }
+          output_arg {
+            name: "identity"
+            type: DT_FLOAT
+          }
+          is_stateful: true
+        }
+        node_def {
+          name: "MatMul/ReadVariableOp"
+          op: "ReadVariableOp"
+          input: "matmul_readvariableop_resource"
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: 100
+                  }
+                  dim {
+                    size: 1
+                  }
+                }
+              }
+            }
+          }
+          attr {
+            key: "dtype"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "MatMul/ReadVariableOp"
+          }
+        }
+        node_def {
+          name: "MatMul"
+          op: "MatMul"
+          input: "inputs"
+          input: "MatMul/ReadVariableOp:value:0"
+          attr {
+            key: "T"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 1
+                  }
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "MatMul"
+          }
+        }
+        node_def {
+          name: "BiasAdd/ReadVariableOp"
+          op: "ReadVariableOp"
+          input: "biasadd_readvariableop_resource"
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: 1
+                  }
+                }
+              }
+            }
+          }
+          attr {
+            key: "dtype"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "BiasAdd/ReadVariableOp"
+          }
+        }
+        node_def {
+          name: "BiasAdd"
+          op: "BiasAdd"
+          input: "MatMul:product:0"
+          input: "BiasAdd/ReadVariableOp:value:0"
+          attr {
+            key: "T"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 1
+                  }
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "BiasAdd"
+          }
+        }
+        node_def {
+          name: "Identity"
+          op: "Identity"
+          input: "BiasAdd:output:0"
+          attr {
+            key: "T"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 1
+                  }
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "Identity"
+          }
+        }
+        ret {
+          key: "identity"
+          value: "Identity:output:0"
+        }
+        attr {
+          key: "_input_shapes"
+          value {
+            list {
+              shape {
+                dim {
+                  size: -1
+                }
+                dim {
+                  size: 100
+                }
+              }
+              shape {
+                unknown_rank: true
+              }
+              shape {
+                unknown_rank: true
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 0
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                    dim {
+                      size: -1
+                    }
+                    dim {
+                      size: 100
+                    }
+                  }
+                }
+              }
+            }
+            attr {
+              key: "_user_specified_name"
+              value {
+                s: "inputs"
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 1
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 2
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+      }
+      function {
+        signature {
+          name: "__inference_dense_layer_call_fn_6754"
+          input_arg {
+            name: "inputs"
+            type: DT_INT32
+          }
+          input_arg {
+            name: "unknown"
+            type: DT_RESOURCE
+          }
+          input_arg {
+            name: "unknown_0"
+            type: DT_RESOURCE
+          }
+          output_arg {
+            name: "identity"
+            type: DT_FLOAT
+          }
+          is_stateful: true
+          control_output: "StatefulPartitionedCall"
+        }
+        node_def {
+          name: "StatefulPartitionedCall"
+          op: "StatefulPartitionedCall"
+          input: "inputs"
+          input: "unknown"
+          input: "unknown_0"
+          attr {
+            key: "Tin"
+            value {
+              list {
+                type: DT_INT32
+                type: DT_RESOURCE
+                type: DT_RESOURCE
+              }
+            }
+          }
+          attr {
+            key: "Tout"
+            value {
+              list {
+                type: DT_FLOAT
+              }
+            }
+          }
+          attr {
+            key: "_collective_manager_ids"
+            value {
+              list {
+              }
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 100
+                  }
+                }
+              }
+            }
+          }
+          attr {
+            key: "_read_only_resource_inputs"
+            value {
+              list {
+                i: 1
+                i: 2
+              }
+            }
+          }
+          attr {
+            key: "config_proto"
+            value {
+              s: "\n\007\n\003CPU\020\001\n\007\n\003GPU\020\0002\002J\0008\001"
+            }
+          }
+          attr {
+            key: "f"
+            value {
+              func {
+                name: "__inference_dense_layer_call_and_return_conditional_losses_6544"
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "StatefulPartitionedCall"
+          }
+        }
+        node_def {
+          name: "Identity"
+          op: "Identity"
+          input: "StatefulPartitionedCall:output:0"
+          input: "^StatefulPartitionedCall"
+          attr {
+            key: "T"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 100
+                  }
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "Identity"
+          }
+        }
+        ret {
+          key: "identity"
+          value: "Identity:output:0"
+        }
+        attr {
+          key: "_input_shapes"
+          value {
+            list {
+              shape {
+                dim {
+                  size: -1
+                }
+                dim {
+                  size: 214
+                }
+              }
+              shape {
+                unknown_rank: true
+              }
+              shape {
+                unknown_rank: true
+              }
+            }
+          }
+        }
+        control_ret {
+          key: "StatefulPartitionedCall"
+          value: "StatefulPartitionedCall"
+        }
+        arg_attr {
+          key: 0
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                    dim {
+                      size: -1
+                    }
+                    dim {
+                      size: 214
+                    }
+                  }
+                }
+              }
+            }
+            attr {
+              key: "_user_specified_name"
+              value {
+                s: "inputs"
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 1
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 2
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+      }
+      function {
+        signature {
+          name: "__inference__traced_save_6824"
+          input_arg {
+            name: "file_prefix"
+            type: DT_STRING
+          }
+          input_arg {
+            name: "savev2_dense_kernel_read_readvariableop"
+            type: DT_FLOAT
+          }
+          input_arg {
+            name: "savev2_dense_bias_read_readvariableop"
+            type: DT_FLOAT
+          }
+          input_arg {
+            name: "savev2_dense_1_kernel_read_readvariableop"
+            type: DT_FLOAT
+          }
+          input_arg {
+            name: "savev2_dense_1_bias_read_readvariableop"
+            type: DT_FLOAT
+          }
+          input_arg {
+            name: "savev2_total_read_readvariableop"
+            type: DT_FLOAT
+          }
+          input_arg {
+            name: "savev2_count_read_readvariableop"
+            type: DT_FLOAT
+          }
+          input_arg {
+            name: "savev2_total_1_read_readvariableop"
+            type: DT_FLOAT
+          }
+          input_arg {
+            name: "savev2_count_1_read_readvariableop"
+            type: DT_FLOAT
+          }
+          input_arg {
+            name: "savev2_1_const"
+            type: DT_STRING
+          }
+          output_arg {
+            name: "identity_1"
+            type: DT_STRING
+          }
+          is_stateful: true
+          control_output: "MergeV2Checkpoints"
+          control_output: "SaveV2"
+          control_output: "SaveV2_1"
+        }
+        node_def {
+          name: "StaticRegexFullMatch"
+          op: "StaticRegexFullMatch"
+          input: "file_prefix"
+          device: "/device:CPU:*"
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                }
+              }
+            }
+          }
+          attr {
+            key: "pattern"
+            value {
+              s: "^s3://.*"
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "StaticRegexFullMatch"
+          }
+        }
+        node_def {
+          name: "Const"
+          op: "Const"
+          device: "/device:CPU:*"
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                }
+              }
+            }
+          }
+          attr {
+            key: "dtype"
+            value {
+              type: DT_STRING
+            }
+          }
+          attr {
+            key: "value"
+            value {
+              tensor {
+                dtype: DT_STRING
+                tensor_shape {
+                }
+                string_val: ".part"
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "Const"
+          }
+        }
+        node_def {
+          name: "Const_1"
+          op: "Const"
+          device: "/device:CPU:*"
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                }
+              }
+            }
+          }
+          attr {
+            key: "dtype"
+            value {
+              type: DT_STRING
+            }
+          }
+          attr {
+            key: "value"
+            value {
+              tensor {
+                dtype: DT_STRING
+                tensor_shape {
+                }
+                string_val: "_temp_6f1e5fef49bb4c06ace07a8a95dfbb1b/part"
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "Const_1"
+          }
+        }
+        node_def {
+          name: "Select"
+          op: "Select"
+          input: "StaticRegexFullMatch:output:0"
+          input: "Const:output:0"
+          input: "Const_1:output:0"
+          device: "/device:CPU:*"
+          attr {
+            key: "T"
+            value {
+              type: DT_STRING
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "Select"
+          }
+        }
+        node_def {
+          name: "StringJoin"
+          op: "StringJoin"
+          input: "file_prefix"
+          input: "Select:output:0"
+          device: "/device:CPU:*"
+          attr {
+            key: "N"
+            value {
+              i: 2
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "StringJoin"
+          }
+        }
+        node_def {
+          name: "num_shards"
+          op: "Const"
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                }
+              }
+            }
+          }
+          attr {
+            key: "dtype"
+            value {
+              type: DT_INT32
+            }
+          }
+          attr {
+            key: "value"
+            value {
+              tensor {
+                dtype: DT_INT32
+                tensor_shape {
+                }
+                int_val: 2
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "num_shards"
+          }
+        }
+        node_def {
+          name: "ShardedFilename/shard"
+          op: "Const"
+          device: "/device:CPU:0"
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                }
+              }
+            }
+          }
+          attr {
+            key: "dtype"
+            value {
+              type: DT_INT32
+            }
+          }
+          attr {
+            key: "value"
+            value {
+              tensor {
+                dtype: DT_INT32
+                tensor_shape {
+                }
+                int_val: 0
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "ShardedFilename/shard"
+          }
+        }
+        node_def {
+          name: "ShardedFilename"
+          op: "ShardedFilename"
+          input: "StringJoin:output:0"
+          input: "ShardedFilename/shard:output:0"
+          input: "num_shards:output:0"
+          device: "/device:CPU:0"
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "ShardedFilename"
+          }
+        }
+        node_def {
+          name: "SaveV2/tensor_names"
+          op: "Const"
+          device: "/device:CPU:0"
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: 8
+                  }
+                }
+              }
+            }
+          }
+          attr {
+            key: "dtype"
+            value {
+              type: DT_STRING
+            }
+          }
+          attr {
+            key: "value"
+            value {
+              tensor {
+                dtype: DT_STRING
+                tensor_shape {
+                  dim {
+                    size: 8
+                  }
+                }
+                string_val: "layer_with_weights-0/kernel/.ATTRIBUTES/VARIABLE_VALUE"
+                string_val: "layer_with_weights-0/bias/.ATTRIBUTES/VARIABLE_VALUE"
+                string_val: "layer_with_weights-1/kernel/.ATTRIBUTES/VARIABLE_VALUE"
+                string_val: "layer_with_weights-1/bias/.ATTRIBUTES/VARIABLE_VALUE"
+                string_val: "keras_api/metrics/0/total/.ATTRIBUTES/VARIABLE_VALUE"
+                string_val: "keras_api/metrics/0/count/.ATTRIBUTES/VARIABLE_VALUE"
+                string_val: "keras_api/metrics/1/total/.ATTRIBUTES/VARIABLE_VALUE"
+                string_val: "keras_api/metrics/1/count/.ATTRIBUTES/VARIABLE_VALUE"
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "SaveV2/tensor_names"
+          }
+        }
+        node_def {
+          name: "SaveV2/shape_and_slices"
+          op: "Const"
+          device: "/device:CPU:0"
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: 8
+                  }
+                }
+              }
+            }
+          }
+          attr {
+            key: "dtype"
+            value {
+              type: DT_STRING
+            }
+          }
+          attr {
+            key: "value"
+            value {
+              tensor {
+                dtype: DT_STRING
+                tensor_shape {
+                  dim {
+                    size: 8
+                  }
+                }
+                string_val: ""
+                string_val: ""
+                string_val: ""
+                string_val: ""
+                string_val: ""
+                string_val: ""
+                string_val: ""
+                string_val: ""
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "SaveV2/shape_and_slices"
+          }
+        }
+        node_def {
+          name: "SaveV2"
+          op: "SaveV2"
+          input: "ShardedFilename:filename:0"
+          input: "SaveV2/tensor_names:output:0"
+          input: "SaveV2/shape_and_slices:output:0"
+          input: "savev2_dense_kernel_read_readvariableop"
+          input: "savev2_dense_bias_read_readvariableop"
+          input: "savev2_dense_1_kernel_read_readvariableop"
+          input: "savev2_dense_1_bias_read_readvariableop"
+          input: "savev2_total_read_readvariableop"
+          input: "savev2_count_read_readvariableop"
+          input: "savev2_total_1_read_readvariableop"
+          input: "savev2_count_1_read_readvariableop"
+          device: "/device:CPU:0"
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+              }
+            }
+          }
+          attr {
+            key: "dtypes"
+            value {
+              list {
+                type: DT_FLOAT
+                type: DT_FLOAT
+                type: DT_FLOAT
+                type: DT_FLOAT
+                type: DT_FLOAT
+                type: DT_FLOAT
+                type: DT_FLOAT
+                type: DT_FLOAT
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "SaveV2"
+          }
+        }
+        node_def {
+          name: "ShardedFilename_1/shard"
+          op: "Const"
+          device: "/device:CPU:0"
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                }
+              }
+            }
+          }
+          attr {
+            key: "dtype"
+            value {
+              type: DT_INT32
+            }
+          }
+          attr {
+            key: "value"
+            value {
+              tensor {
+                dtype: DT_INT32
+                tensor_shape {
+                }
+                int_val: 1
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "ShardedFilename_1/shard"
+          }
+        }
+        node_def {
+          name: "ShardedFilename_1"
+          op: "ShardedFilename"
+          input: "StringJoin:output:0"
+          input: "ShardedFilename_1/shard:output:0"
+          input: "num_shards:output:0"
+          device: "/device:CPU:0"
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "ShardedFilename_1"
+          }
+        }
+        node_def {
+          name: "SaveV2_1/tensor_names"
+          op: "Const"
+          device: "/device:CPU:0"
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: 1
+                  }
+                }
+              }
+            }
+          }
+          attr {
+            key: "dtype"
+            value {
+              type: DT_STRING
+            }
+          }
+          attr {
+            key: "value"
+            value {
+              tensor {
+                dtype: DT_STRING
+                tensor_shape {
+                  dim {
+                    size: 1
+                  }
+                }
+                string_val: "_CHECKPOINTABLE_OBJECT_GRAPH"
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "SaveV2_1/tensor_names"
+          }
+        }
+        node_def {
+          name: "SaveV2_1/shape_and_slices"
+          op: "Const"
+          device: "/device:CPU:0"
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: 1
+                  }
+                }
+              }
+            }
+          }
+          attr {
+            key: "dtype"
+            value {
+              type: DT_STRING
+            }
+          }
+          attr {
+            key: "value"
+            value {
+              tensor {
+                dtype: DT_STRING
+                tensor_shape {
+                  dim {
+                    size: 1
+                  }
+                }
+                string_val: ""
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "SaveV2_1/shape_and_slices"
+          }
+        }
+        node_def {
+          name: "SaveV2_1"
+          op: "SaveV2"
+          input: "ShardedFilename_1:filename:0"
+          input: "SaveV2_1/tensor_names:output:0"
+          input: "SaveV2_1/shape_and_slices:output:0"
+          input: "savev2_1_const"
+          input: "^SaveV2"
+          device: "/device:CPU:0"
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+              }
+            }
+          }
+          attr {
+            key: "dtypes"
+            value {
+              list {
+                type: DT_STRING
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "SaveV2_1"
+          }
+        }
+        node_def {
+          name: "MergeV2Checkpoints/checkpoint_prefixes"
+          op: "Pack"
+          input: "ShardedFilename:filename:0"
+          input: "ShardedFilename_1:filename:0"
+          input: "^SaveV2"
+          input: "^SaveV2_1"
+          device: "/device:CPU:0"
+          attr {
+            key: "N"
+            value {
+              i: 2
+            }
+          }
+          attr {
+            key: "T"
+            value {
+              type: DT_STRING
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: 2
+                  }
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "MergeV2Checkpoints/checkpoint_prefixes"
+          }
+        }
+        node_def {
+          name: "MergeV2Checkpoints"
+          op: "MergeV2Checkpoints"
+          input: "MergeV2Checkpoints/checkpoint_prefixes:output:0"
+          input: "file_prefix"
+          input: "^SaveV2_1"
+          device: "/device:CPU:0"
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "MergeV2Checkpoints"
+          }
+        }
+        node_def {
+          name: "Identity"
+          op: "Identity"
+          input: "file_prefix"
+          input: "^MergeV2Checkpoints"
+          device: "/device:CPU:0"
+          attr {
+            key: "T"
+            value {
+              type: DT_STRING
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "Identity"
+          }
+        }
+        node_def {
+          name: "Identity_1"
+          op: "Identity"
+          input: "Identity:output:0"
+          input: "^MergeV2Checkpoints"
+          input: "^SaveV2"
+          input: "^SaveV2_1"
+          attr {
+            key: "T"
+            value {
+              type: DT_STRING
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "Identity_1"
+          }
+        }
+        ret {
+          key: "identity_1"
+          value: "Identity_1:output:0"
+        }
+        attr {
+          key: "_input_shapes"
+          value {
+            list {
+              shape {
+              }
+              shape {
+                dim {
+                  size: 214
+                }
+                dim {
+                  size: 100
+                }
+              }
+              shape {
+                dim {
+                  size: 100
+                }
+              }
+              shape {
+                dim {
+                  size: 100
+                }
+                dim {
+                  size: 1
+                }
+              }
+              shape {
+                dim {
+                  size: 1
+                }
+              }
+              shape {
+              }
+              shape {
+              }
+              shape {
+              }
+              shape {
+              }
+              shape {
+              }
+            }
+          }
+        }
+        control_ret {
+          key: "MergeV2Checkpoints"
+          value: "MergeV2Checkpoints"
+        }
+        control_ret {
+          key: "SaveV2"
+          value: "SaveV2"
+        }
+        control_ret {
+          key: "SaveV2_1"
+          value: "SaveV2_1"
+        }
+        arg_attr {
+          key: 0
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+            attr {
+              key: "_user_specified_name"
+              value {
+                s: "file_prefix"
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 1
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                    dim {
+                      size: 214
+                    }
+                    dim {
+                      size: 100
+                    }
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 2
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                    dim {
+                      size: 100
+                    }
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 3
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                    dim {
+                      size: 100
+                    }
+                    dim {
+                      size: 1
+                    }
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 4
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                    dim {
+                      size: 1
+                    }
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 5
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 6
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 7
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 8
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 9
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+      }
+      function {
+        signature {
+          name: "__inference_sequential_layer_call_and_return_conditional_losses_6689"
+          input_arg {
+            name: "inputs"
+            type: DT_INT32
+          }
+          input_arg {
+            name: "dense_matmul_readvariableop_resource"
+            type: DT_RESOURCE
+          }
+          input_arg {
+            name: "dense_biasadd_readvariableop_resource"
+            type: DT_RESOURCE
+          }
+          input_arg {
+            name: "dense_1_matmul_readvariableop_resource"
+            type: DT_RESOURCE
+          }
+          input_arg {
+            name: "dense_1_biasadd_readvariableop_resource"
+            type: DT_RESOURCE
+          }
+          output_arg {
+            name: "identity"
+            type: DT_FLOAT
+          }
+          is_stateful: true
+        }
+        node_def {
+          name: "dense/Cast"
+          op: "Cast"
+          input: "inputs"
+          attr {
+            key: "DstT"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "SrcT"
+            value {
+              type: DT_INT32
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 214
+                  }
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "dense/Cast"
+          }
+        }
+        node_def {
+          name: "dense/MatMul/ReadVariableOp"
+          op: "ReadVariableOp"
+          input: "dense_matmul_readvariableop_resource"
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: 214
+                  }
+                  dim {
+                    size: 100
+                  }
+                }
+              }
+            }
+          }
+          attr {
+            key: "dtype"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "dense/MatMul/ReadVariableOp"
+          }
+        }
+        node_def {
+          name: "dense/MatMul"
+          op: "MatMul"
+          input: "dense/Cast:y:0"
+          input: "dense/MatMul/ReadVariableOp:value:0"
+          attr {
+            key: "T"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 100
+                  }
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "dense/MatMul"
+          }
+        }
+        node_def {
+          name: "dense/BiasAdd/ReadVariableOp"
+          op: "ReadVariableOp"
+          input: "dense_biasadd_readvariableop_resource"
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: 100
+                  }
+                }
+              }
+            }
+          }
+          attr {
+            key: "dtype"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "dense/BiasAdd/ReadVariableOp"
+          }
+        }
+        node_def {
+          name: "dense/BiasAdd"
+          op: "BiasAdd"
+          input: "dense/MatMul:product:0"
+          input: "dense/BiasAdd/ReadVariableOp:value:0"
+          attr {
+            key: "T"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 100
+                  }
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "dense/BiasAdd"
+          }
+        }
+        node_def {
+          name: "dense/Relu"
+          op: "Relu"
+          input: "dense/BiasAdd:output:0"
+          attr {
+            key: "T"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 100
+                  }
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "dense/Relu"
+          }
+        }
+        node_def {
+          name: "dense_1/MatMul/ReadVariableOp"
+          op: "ReadVariableOp"
+          input: "dense_1_matmul_readvariableop_resource"
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: 100
+                  }
+                  dim {
+                    size: 1
+                  }
+                }
+              }
+            }
+          }
+          attr {
+            key: "dtype"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "dense_1/MatMul/ReadVariableOp"
+          }
+        }
+        node_def {
+          name: "dense_1/MatMul"
+          op: "MatMul"
+          input: "dense/Relu:activations:0"
+          input: "dense_1/MatMul/ReadVariableOp:value:0"
+          attr {
+            key: "T"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 1
+                  }
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "dense_1/MatMul"
+          }
+        }
+        node_def {
+          name: "dense_1/BiasAdd/ReadVariableOp"
+          op: "ReadVariableOp"
+          input: "dense_1_biasadd_readvariableop_resource"
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: 1
+                  }
+                }
+              }
+            }
+          }
+          attr {
+            key: "dtype"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "dense_1/BiasAdd/ReadVariableOp"
+          }
+        }
+        node_def {
+          name: "dense_1/BiasAdd"
+          op: "BiasAdd"
+          input: "dense_1/MatMul:product:0"
+          input: "dense_1/BiasAdd/ReadVariableOp:value:0"
+          attr {
+            key: "T"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 1
+                  }
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "dense_1/BiasAdd"
+          }
+        }
+        node_def {
+          name: "Identity"
+          op: "Identity"
+          input: "dense_1/BiasAdd:output:0"
+          attr {
+            key: "T"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 1
+                  }
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "Identity"
+          }
+        }
+        ret {
+          key: "identity"
+          value: "Identity:output:0"
+        }
+        attr {
+          key: "_input_shapes"
+          value {
+            list {
+              shape {
+                dim {
+                  size: -1
+                }
+                dim {
+                  size: 214
+                }
+              }
+              shape {
+                unknown_rank: true
+              }
+              shape {
+                unknown_rank: true
+              }
+              shape {
+                unknown_rank: true
+              }
+              shape {
+                unknown_rank: true
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 0
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                    dim {
+                      size: -1
+                    }
+                    dim {
+                      size: 214
+                    }
+                  }
+                }
+              }
+            }
+            attr {
+              key: "_user_specified_name"
+              value {
+                s: "inputs"
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 1
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 2
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 3
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 4
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+      }
+      function {
+        signature {
+          name: "__inference_dense_layer_call_and_return_conditional_losses_6745"
+          input_arg {
+            name: "inputs"
+            type: DT_INT32
+          }
+          input_arg {
+            name: "matmul_readvariableop_resource"
+            type: DT_RESOURCE
+          }
+          input_arg {
+            name: "biasadd_readvariableop_resource"
+            type: DT_RESOURCE
+          }
+          output_arg {
+            name: "identity"
+            type: DT_FLOAT
+          }
+          is_stateful: true
+        }
+        node_def {
+          name: "Cast"
+          op: "Cast"
+          input: "inputs"
+          attr {
+            key: "DstT"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "SrcT"
+            value {
+              type: DT_INT32
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 214
+                  }
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "Cast"
+          }
+        }
+        node_def {
+          name: "MatMul/ReadVariableOp"
+          op: "ReadVariableOp"
+          input: "matmul_readvariableop_resource"
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: 214
+                  }
+                  dim {
+                    size: 100
+                  }
+                }
+              }
+            }
+          }
+          attr {
+            key: "dtype"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "MatMul/ReadVariableOp"
+          }
+        }
+        node_def {
+          name: "MatMul"
+          op: "MatMul"
+          input: "Cast:y:0"
+          input: "MatMul/ReadVariableOp:value:0"
+          attr {
+            key: "T"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 100
+                  }
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "MatMul"
+          }
+        }
+        node_def {
+          name: "BiasAdd/ReadVariableOp"
+          op: "ReadVariableOp"
+          input: "biasadd_readvariableop_resource"
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: 100
+                  }
+                }
+              }
+            }
+          }
+          attr {
+            key: "dtype"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "BiasAdd/ReadVariableOp"
+          }
+        }
+        node_def {
+          name: "BiasAdd"
+          op: "BiasAdd"
+          input: "MatMul:product:0"
+          input: "BiasAdd/ReadVariableOp:value:0"
+          attr {
+            key: "T"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 100
+                  }
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "BiasAdd"
+          }
+        }
+        node_def {
+          name: "Relu"
+          op: "Relu"
+          input: "BiasAdd:output:0"
+          attr {
+            key: "T"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 100
+                  }
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "Relu"
+          }
+        }
+        node_def {
+          name: "Identity"
+          op: "Identity"
+          input: "Relu:activations:0"
+          attr {
+            key: "T"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 100
+                  }
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "Identity"
+          }
+        }
+        ret {
+          key: "identity"
+          value: "Identity:output:0"
+        }
+        attr {
+          key: "_input_shapes"
+          value {
+            list {
+              shape {
+                dim {
+                  size: -1
+                }
+                dim {
+                  size: 214
+                }
+              }
+              shape {
+                unknown_rank: true
+              }
+              shape {
+                unknown_rank: true
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 0
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                    dim {
+                      size: -1
+                    }
+                    dim {
+                      size: 214
+                    }
+                  }
+                }
+              }
+            }
+            attr {
+              key: "_user_specified_name"
+              value {
+                s: "inputs"
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 1
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 2
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+      }
+      function {
+        signature {
+          name: "__inference_dense_1_layer_call_fn_6773"
+          input_arg {
+            name: "inputs"
+            type: DT_FLOAT
+          }
+          input_arg {
+            name: "unknown"
+            type: DT_RESOURCE
+          }
+          input_arg {
+            name: "unknown_0"
+            type: DT_RESOURCE
+          }
+          output_arg {
+            name: "identity"
+            type: DT_FLOAT
+          }
+          is_stateful: true
+          control_output: "StatefulPartitionedCall"
+        }
+        node_def {
+          name: "StatefulPartitionedCall"
+          op: "StatefulPartitionedCall"
+          input: "inputs"
+          input: "unknown"
+          input: "unknown_0"
+          attr {
+            key: "Tin"
+            value {
+              list {
+                type: DT_FLOAT
+                type: DT_RESOURCE
+                type: DT_RESOURCE
+              }
+            }
+          }
+          attr {
+            key: "Tout"
+            value {
+              list {
+                type: DT_FLOAT
+              }
+            }
+          }
+          attr {
+            key: "_collective_manager_ids"
+            value {
+              list {
+              }
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 1
+                  }
+                }
+              }
+            }
+          }
+          attr {
+            key: "_read_only_resource_inputs"
+            value {
+              list {
+                i: 1
+                i: 2
+              }
+            }
+          }
+          attr {
+            key: "config_proto"
+            value {
+              s: "\n\007\n\003CPU\020\001\n\007\n\003GPU\020\0002\002J\0008\001"
+            }
+          }
+          attr {
+            key: "f"
+            value {
+              func {
+                name: "__inference_dense_1_layer_call_and_return_conditional_losses_6570"
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "StatefulPartitionedCall"
+          }
+        }
+        node_def {
+          name: "Identity"
+          op: "Identity"
+          input: "StatefulPartitionedCall:output:0"
+          input: "^StatefulPartitionedCall"
+          attr {
+            key: "T"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 1
+                  }
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "Identity"
+          }
+        }
+        ret {
+          key: "identity"
+          value: "Identity:output:0"
+        }
+        attr {
+          key: "_input_shapes"
+          value {
+            list {
+              shape {
+                dim {
+                  size: -1
+                }
+                dim {
+                  size: 100
+                }
+              }
+              shape {
+                unknown_rank: true
+              }
+              shape {
+                unknown_rank: true
+              }
+            }
+          }
+        }
+        control_ret {
+          key: "StatefulPartitionedCall"
+          value: "StatefulPartitionedCall"
+        }
+        arg_attr {
+          key: 0
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                    dim {
+                      size: -1
+                    }
+                    dim {
+                      size: 100
+                    }
+                  }
+                }
+              }
+            }
+            attr {
+              key: "_user_specified_name"
+              value {
+                s: "inputs"
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 1
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 2
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+      }
+      function {
+        signature {
+          name: "__inference__wrapped_model_6528"
+          input_arg {
+            name: "input_1"
+            type: DT_INT32
+          }
+          input_arg {
+            name: "sequential_dense_matmul_readvariableop_resource"
+            type: DT_RESOURCE
+          }
+          input_arg {
+            name: "sequential_dense_biasadd_readvariableop_resource"
+            type: DT_RESOURCE
+          }
+          input_arg {
+            name: "sequential_dense_1_matmul_readvariableop_resource"
+            type: DT_RESOURCE
+          }
+          input_arg {
+            name: "sequential_dense_1_biasadd_readvariableop_resource"
+            type: DT_RESOURCE
+          }
+          output_arg {
+            name: "identity"
+            type: DT_FLOAT
+          }
+          is_stateful: true
+        }
+        node_def {
+          name: "sequential/dense/Cast"
+          op: "Cast"
+          input: "input_1"
+          attr {
+            key: "DstT"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "SrcT"
+            value {
+              type: DT_INT32
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 214
+                  }
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "sequential/dense/Cast"
+          }
+        }
+        node_def {
+          name: "sequential/dense/MatMul/ReadVariableOp"
+          op: "ReadVariableOp"
+          input: "sequential_dense_matmul_readvariableop_resource"
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: 214
+                  }
+                  dim {
+                    size: 100
+                  }
+                }
+              }
+            }
+          }
+          attr {
+            key: "dtype"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "sequential/dense/MatMul/ReadVariableOp"
+          }
+        }
+        node_def {
+          name: "sequential/dense/MatMul"
+          op: "MatMul"
+          input: "sequential/dense/Cast:y:0"
+          input: "sequential/dense/MatMul/ReadVariableOp:value:0"
+          attr {
+            key: "T"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 100
+                  }
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "sequential/dense/MatMul"
+          }
+        }
+        node_def {
+          name: "sequential/dense/BiasAdd/ReadVariableOp"
+          op: "ReadVariableOp"
+          input: "sequential_dense_biasadd_readvariableop_resource"
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: 100
+                  }
+                }
+              }
+            }
+          }
+          attr {
+            key: "dtype"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "sequential/dense/BiasAdd/ReadVariableOp"
+          }
+        }
+        node_def {
+          name: "sequential/dense/BiasAdd"
+          op: "BiasAdd"
+          input: "sequential/dense/MatMul:product:0"
+          input: "sequential/dense/BiasAdd/ReadVariableOp:value:0"
+          attr {
+            key: "T"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 100
+                  }
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "sequential/dense/BiasAdd"
+          }
+        }
+        node_def {
+          name: "sequential/dense/Relu"
+          op: "Relu"
+          input: "sequential/dense/BiasAdd:output:0"
+          attr {
+            key: "T"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 100
+                  }
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "sequential/dense/Relu"
+          }
+        }
+        node_def {
+          name: "sequential/dense_1/MatMul/ReadVariableOp"
+          op: "ReadVariableOp"
+          input: "sequential_dense_1_matmul_readvariableop_resource"
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: 100
+                  }
+                  dim {
+                    size: 1
+                  }
+                }
+              }
+            }
+          }
+          attr {
+            key: "dtype"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "sequential/dense_1/MatMul/ReadVariableOp"
+          }
+        }
+        node_def {
+          name: "sequential/dense_1/MatMul"
+          op: "MatMul"
+          input: "sequential/dense/Relu:activations:0"
+          input: "sequential/dense_1/MatMul/ReadVariableOp:value:0"
+          attr {
+            key: "T"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 1
+                  }
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "sequential/dense_1/MatMul"
+          }
+        }
+        node_def {
+          name: "sequential/dense_1/BiasAdd/ReadVariableOp"
+          op: "ReadVariableOp"
+          input: "sequential_dense_1_biasadd_readvariableop_resource"
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: 1
+                  }
+                }
+              }
+            }
+          }
+          attr {
+            key: "dtype"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "sequential/dense_1/BiasAdd/ReadVariableOp"
+          }
+        }
+        node_def {
+          name: "sequential/dense_1/BiasAdd"
+          op: "BiasAdd"
+          input: "sequential/dense_1/MatMul:product:0"
+          input: "sequential/dense_1/BiasAdd/ReadVariableOp:value:0"
+          attr {
+            key: "T"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 1
+                  }
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "sequential/dense_1/BiasAdd"
+          }
+        }
+        node_def {
+          name: "Identity"
+          op: "Identity"
+          input: "sequential/dense_1/BiasAdd:output:0"
+          attr {
+            key: "T"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 1
+                  }
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "Identity"
+          }
+        }
+        ret {
+          key: "identity"
+          value: "Identity:output:0"
+        }
+        attr {
+          key: "_input_shapes"
+          value {
+            list {
+              shape {
+                dim {
+                  size: -1
+                }
+                dim {
+                  size: 214
+                }
+              }
+              shape {
+                unknown_rank: true
+              }
+              shape {
+                unknown_rank: true
+              }
+              shape {
+                unknown_rank: true
+              }
+              shape {
+                unknown_rank: true
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 0
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                    dim {
+                      size: -1
+                    }
+                    dim {
+                      size: 214
+                    }
+                  }
+                }
+              }
+            }
+            attr {
+              key: "_user_specified_name"
+              value {
+                s: "input_1"
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 1
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 2
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 3
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 4
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+      }
+      function {
+        signature {
+          name: "__inference_dense_layer_call_and_return_conditional_losses_6544"
+          input_arg {
+            name: "inputs"
+            type: DT_INT32
+          }
+          input_arg {
+            name: "matmul_readvariableop_resource"
+            type: DT_RESOURCE
+          }
+          input_arg {
+            name: "biasadd_readvariableop_resource"
+            type: DT_RESOURCE
+          }
+          output_arg {
+            name: "identity"
+            type: DT_FLOAT
+          }
+          is_stateful: true
+        }
+        node_def {
+          name: "Cast"
+          op: "Cast"
+          input: "inputs"
+          attr {
+            key: "DstT"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "SrcT"
+            value {
+              type: DT_INT32
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 214
+                  }
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "Cast"
+          }
+        }
+        node_def {
+          name: "MatMul/ReadVariableOp"
+          op: "ReadVariableOp"
+          input: "matmul_readvariableop_resource"
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: 214
+                  }
+                  dim {
+                    size: 100
+                  }
+                }
+              }
+            }
+          }
+          attr {
+            key: "dtype"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "MatMul/ReadVariableOp"
+          }
+        }
+        node_def {
+          name: "MatMul"
+          op: "MatMul"
+          input: "Cast:y:0"
+          input: "MatMul/ReadVariableOp:value:0"
+          attr {
+            key: "T"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 100
+                  }
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "MatMul"
+          }
+        }
+        node_def {
+          name: "BiasAdd/ReadVariableOp"
+          op: "ReadVariableOp"
+          input: "biasadd_readvariableop_resource"
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: 100
+                  }
+                }
+              }
+            }
+          }
+          attr {
+            key: "dtype"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "BiasAdd/ReadVariableOp"
+          }
+        }
+        node_def {
+          name: "BiasAdd"
+          op: "BiasAdd"
+          input: "MatMul:product:0"
+          input: "BiasAdd/ReadVariableOp:value:0"
+          attr {
+            key: "T"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 100
+                  }
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "BiasAdd"
+          }
+        }
+        node_def {
+          name: "Relu"
+          op: "Relu"
+          input: "BiasAdd:output:0"
+          attr {
+            key: "T"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 100
+                  }
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "Relu"
+          }
+        }
+        node_def {
+          name: "Identity"
+          op: "Identity"
+          input: "Relu:activations:0"
+          attr {
+            key: "T"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 100
+                  }
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "Identity"
+          }
+        }
+        ret {
+          key: "identity"
+          value: "Identity:output:0"
+        }
+        attr {
+          key: "_input_shapes"
+          value {
+            list {
+              shape {
+                dim {
+                  size: -1
+                }
+                dim {
+                  size: 214
+                }
+              }
+              shape {
+                unknown_rank: true
+              }
+              shape {
+                unknown_rank: true
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 0
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                    dim {
+                      size: -1
+                    }
+                    dim {
+                      size: 214
+                    }
+                  }
+                }
+              }
+            }
+            attr {
+              key: "_user_specified_name"
+              value {
+                s: "inputs"
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 1
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 2
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+      }
+      function {
+        signature {
+          name: "__inference_sequential_layer_call_and_return_conditional_losses_6601"
+          input_arg {
+            name: "input_1"
+            type: DT_INT32
+          }
+          input_arg {
+            name: "dense_6590"
+            type: DT_RESOURCE
+          }
+          input_arg {
+            name: "dense_6592"
+            type: DT_RESOURCE
+          }
+          input_arg {
+            name: "dense_1_6595"
+            type: DT_RESOURCE
+          }
+          input_arg {
+            name: "dense_1_6597"
+            type: DT_RESOURCE
+          }
+          output_arg {
+            name: "identity"
+            type: DT_FLOAT
+          }
+          is_stateful: true
+          control_output: "dense/StatefulPartitionedCall"
+          control_output: "dense_1/StatefulPartitionedCall"
+        }
+        node_def {
+          name: "dense/StatefulPartitionedCall"
+          op: "StatefulPartitionedCall"
+          input: "input_1"
+          input: "dense_6590"
+          input: "dense_6592"
+          attr {
+            key: "Tin"
+            value {
+              list {
+                type: DT_INT32
+                type: DT_RESOURCE
+                type: DT_RESOURCE
+              }
+            }
+          }
+          attr {
+            key: "Tout"
+            value {
+              list {
+                type: DT_FLOAT
+              }
+            }
+          }
+          attr {
+            key: "_collective_manager_ids"
+            value {
+              list {
+              }
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 100
+                  }
+                }
+              }
+            }
+          }
+          attr {
+            key: "_read_only_resource_inputs"
+            value {
+              list {
+                i: 1
+                i: 2
+              }
+            }
+          }
+          attr {
+            key: "config_proto"
+            value {
+              s: "\n\007\n\003CPU\020\001\n\007\n\003GPU\020\0002\002J\0008\001"
+            }
+          }
+          attr {
+            key: "f"
+            value {
+              func {
+                name: "__inference_dense_layer_call_and_return_conditional_losses_6544"
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "dense/StatefulPartitionedCall"
+          }
+        }
+        node_def {
+          name: "dense_1/StatefulPartitionedCall"
+          op: "StatefulPartitionedCall"
+          input: "dense/StatefulPartitionedCall:output:0"
+          input: "dense_1_6595"
+          input: "dense_1_6597"
+          attr {
+            key: "Tin"
+            value {
+              list {
+                type: DT_FLOAT
+                type: DT_RESOURCE
+                type: DT_RESOURCE
+              }
+            }
+          }
+          attr {
+            key: "Tout"
+            value {
+              list {
+                type: DT_FLOAT
+              }
+            }
+          }
+          attr {
+            key: "_collective_manager_ids"
+            value {
+              list {
+              }
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 1
+                  }
+                }
+              }
+            }
+          }
+          attr {
+            key: "_read_only_resource_inputs"
+            value {
+              list {
+                i: 1
+                i: 2
+              }
+            }
+          }
+          attr {
+            key: "config_proto"
+            value {
+              s: "\n\007\n\003CPU\020\001\n\007\n\003GPU\020\0002\002J\0008\001"
+            }
+          }
+          attr {
+            key: "f"
+            value {
+              func {
+                name: "__inference_dense_1_layer_call_and_return_conditional_losses_6570"
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "dense_1/StatefulPartitionedCall"
+          }
+        }
+        node_def {
+          name: "Identity"
+          op: "Identity"
+          input: "dense_1/StatefulPartitionedCall:output:0"
+          input: "^dense/StatefulPartitionedCall"
+          input: "^dense_1/StatefulPartitionedCall"
+          attr {
+            key: "T"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 1
+                  }
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "Identity"
+          }
+        }
+        ret {
+          key: "identity"
+          value: "Identity:output:0"
+        }
+        attr {
+          key: "_input_shapes"
+          value {
+            list {
+              shape {
+                dim {
+                  size: -1
+                }
+                dim {
+                  size: 214
+                }
+              }
+              shape {
+                unknown_rank: true
+              }
+              shape {
+                unknown_rank: true
+              }
+              shape {
+                unknown_rank: true
+              }
+              shape {
+                unknown_rank: true
+              }
+            }
+          }
+        }
+        control_ret {
+          key: "dense/StatefulPartitionedCall"
+          value: "dense/StatefulPartitionedCall"
+        }
+        control_ret {
+          key: "dense_1/StatefulPartitionedCall"
+          value: "dense_1/StatefulPartitionedCall"
+        }
+        arg_attr {
+          key: 0
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                    dim {
+                      size: -1
+                    }
+                    dim {
+                      size: 214
+                    }
+                  }
+                }
+              }
+            }
+            attr {
+              key: "_user_specified_name"
+              value {
+                s: "input_1"
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 1
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 2
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 3
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 4
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+      }
+      function {
+        signature {
+          name: "__inference_sequential_layer_call_fn_6733"
+          input_arg {
+            name: "inputs"
+            type: DT_INT32
+          }
+          input_arg {
+            name: "unknown"
+            type: DT_RESOURCE
+          }
+          input_arg {
+            name: "unknown_0"
+            type: DT_RESOURCE
+          }
+          input_arg {
+            name: "unknown_1"
+            type: DT_RESOURCE
+          }
+          input_arg {
+            name: "unknown_2"
+            type: DT_RESOURCE
+          }
+          output_arg {
+            name: "identity"
+            type: DT_FLOAT
+          }
+          is_stateful: true
+          control_output: "StatefulPartitionedCall"
+        }
+        node_def {
+          name: "StatefulPartitionedCall"
+          op: "StatefulPartitionedCall"
+          input: "inputs"
+          input: "unknown"
+          input: "unknown_0"
+          input: "unknown_1"
+          input: "unknown_2"
+          attr {
+            key: "Tin"
+            value {
+              list {
+                type: DT_INT32
+                type: DT_RESOURCE
+                type: DT_RESOURCE
+                type: DT_RESOURCE
+                type: DT_RESOURCE
+              }
+            }
+          }
+          attr {
+            key: "Tout"
+            value {
+              list {
+                type: DT_FLOAT
+              }
+            }
+          }
+          attr {
+            key: "_collective_manager_ids"
+            value {
+              list {
+              }
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 1
+                  }
+                }
+              }
+            }
+          }
+          attr {
+            key: "_read_only_resource_inputs"
+            value {
+              list {
+                i: 1
+                i: 2
+                i: 3
+                i: 4
+              }
+            }
+          }
+          attr {
+            key: "config_proto"
+            value {
+              s: "\n\007\n\003CPU\020\001\n\007\n\003GPU\020\0002\002J\0008\001"
+            }
+          }
+          attr {
+            key: "f"
+            value {
+              func {
+                name: "__inference_sequential_layer_call_and_return_conditional_losses_6645"
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "StatefulPartitionedCall"
+          }
+        }
+        node_def {
+          name: "Identity"
+          op: "Identity"
+          input: "StatefulPartitionedCall:output:0"
+          input: "^StatefulPartitionedCall"
+          attr {
+            key: "T"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 1
+                  }
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "Identity"
+          }
+        }
+        ret {
+          key: "identity"
+          value: "Identity:output:0"
+        }
+        attr {
+          key: "_input_shapes"
+          value {
+            list {
+              shape {
+                dim {
+                  size: -1
+                }
+                dim {
+                  size: 214
+                }
+              }
+              shape {
+                unknown_rank: true
+              }
+              shape {
+                unknown_rank: true
+              }
+              shape {
+                unknown_rank: true
+              }
+              shape {
+                unknown_rank: true
+              }
+            }
+          }
+        }
+        control_ret {
+          key: "StatefulPartitionedCall"
+          value: "StatefulPartitionedCall"
+        }
+        arg_attr {
+          key: 0
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                    dim {
+                      size: -1
+                    }
+                    dim {
+                      size: 214
+                    }
+                  }
+                }
+              }
+            }
+            attr {
+              key: "_user_specified_name"
+              value {
+                s: "inputs"
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 1
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 2
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 3
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 4
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+      }
+      function {
+        signature {
+          name: "__inference_sequential_layer_call_and_return_conditional_losses_6645"
+          input_arg {
+            name: "inputs"
+            type: DT_INT32
+          }
+          input_arg {
+            name: "dense_6634"
+            type: DT_RESOURCE
+          }
+          input_arg {
+            name: "dense_6636"
+            type: DT_RESOURCE
+          }
+          input_arg {
+            name: "dense_1_6639"
+            type: DT_RESOURCE
+          }
+          input_arg {
+            name: "dense_1_6641"
+            type: DT_RESOURCE
+          }
+          output_arg {
+            name: "identity"
+            type: DT_FLOAT
+          }
+          is_stateful: true
+          control_output: "dense/StatefulPartitionedCall"
+          control_output: "dense_1/StatefulPartitionedCall"
+        }
+        node_def {
+          name: "dense/StatefulPartitionedCall"
+          op: "StatefulPartitionedCall"
+          input: "inputs"
+          input: "dense_6634"
+          input: "dense_6636"
+          attr {
+            key: "Tin"
+            value {
+              list {
+                type: DT_INT32
+                type: DT_RESOURCE
+                type: DT_RESOURCE
+              }
+            }
+          }
+          attr {
+            key: "Tout"
+            value {
+              list {
+                type: DT_FLOAT
+              }
+            }
+          }
+          attr {
+            key: "_collective_manager_ids"
+            value {
+              list {
+              }
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 100
+                  }
+                }
+              }
+            }
+          }
+          attr {
+            key: "_read_only_resource_inputs"
+            value {
+              list {
+                i: 1
+                i: 2
+              }
+            }
+          }
+          attr {
+            key: "config_proto"
+            value {
+              s: "\n\007\n\003CPU\020\001\n\007\n\003GPU\020\0002\002J\0008\001"
+            }
+          }
+          attr {
+            key: "f"
+            value {
+              func {
+                name: "__inference_dense_layer_call_and_return_conditional_losses_6544"
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "dense/StatefulPartitionedCall"
+          }
+        }
+        node_def {
+          name: "dense_1/StatefulPartitionedCall"
+          op: "StatefulPartitionedCall"
+          input: "dense/StatefulPartitionedCall:output:0"
+          input: "dense_1_6639"
+          input: "dense_1_6641"
+          attr {
+            key: "Tin"
+            value {
+              list {
+                type: DT_FLOAT
+                type: DT_RESOURCE
+                type: DT_RESOURCE
+              }
+            }
+          }
+          attr {
+            key: "Tout"
+            value {
+              list {
+                type: DT_FLOAT
+              }
+            }
+          }
+          attr {
+            key: "_collective_manager_ids"
+            value {
+              list {
+              }
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 1
+                  }
+                }
+              }
+            }
+          }
+          attr {
+            key: "_read_only_resource_inputs"
+            value {
+              list {
+                i: 1
+                i: 2
+              }
+            }
+          }
+          attr {
+            key: "config_proto"
+            value {
+              s: "\n\007\n\003CPU\020\001\n\007\n\003GPU\020\0002\002J\0008\001"
+            }
+          }
+          attr {
+            key: "f"
+            value {
+              func {
+                name: "__inference_dense_1_layer_call_and_return_conditional_losses_6570"
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "dense_1/StatefulPartitionedCall"
+          }
+        }
+        node_def {
+          name: "Identity"
+          op: "Identity"
+          input: "dense_1/StatefulPartitionedCall:output:0"
+          input: "^dense/StatefulPartitionedCall"
+          input: "^dense_1/StatefulPartitionedCall"
+          attr {
+            key: "T"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 1
+                  }
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "Identity"
+          }
+        }
+        ret {
+          key: "identity"
+          value: "Identity:output:0"
+        }
+        attr {
+          key: "_input_shapes"
+          value {
+            list {
+              shape {
+                dim {
+                  size: -1
+                }
+                dim {
+                  size: 214
+                }
+              }
+              shape {
+                unknown_rank: true
+              }
+              shape {
+                unknown_rank: true
+              }
+              shape {
+                unknown_rank: true
+              }
+              shape {
+                unknown_rank: true
+              }
+            }
+          }
+        }
+        control_ret {
+          key: "dense/StatefulPartitionedCall"
+          value: "dense/StatefulPartitionedCall"
+        }
+        control_ret {
+          key: "dense_1/StatefulPartitionedCall"
+          value: "dense_1/StatefulPartitionedCall"
+        }
+        arg_attr {
+          key: 0
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                    dim {
+                      size: -1
+                    }
+                    dim {
+                      size: 214
+                    }
+                  }
+                }
+              }
+            }
+            attr {
+              key: "_user_specified_name"
+              value {
+                s: "inputs"
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 1
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 2
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 3
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 4
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+      }
+      function {
+        signature {
+          name: "__inference_dense_1_layer_call_and_return_conditional_losses_6570"
+          input_arg {
+            name: "inputs"
+            type: DT_FLOAT
+          }
+          input_arg {
+            name: "matmul_readvariableop_resource"
+            type: DT_RESOURCE
+          }
+          input_arg {
+            name: "biasadd_readvariableop_resource"
+            type: DT_RESOURCE
+          }
+          output_arg {
+            name: "identity"
+            type: DT_FLOAT
+          }
+          is_stateful: true
+        }
+        node_def {
+          name: "MatMul/ReadVariableOp"
+          op: "ReadVariableOp"
+          input: "matmul_readvariableop_resource"
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: 100
+                  }
+                  dim {
+                    size: 1
+                  }
+                }
+              }
+            }
+          }
+          attr {
+            key: "dtype"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "MatMul/ReadVariableOp"
+          }
+        }
+        node_def {
+          name: "MatMul"
+          op: "MatMul"
+          input: "inputs"
+          input: "MatMul/ReadVariableOp:value:0"
+          attr {
+            key: "T"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 1
+                  }
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "MatMul"
+          }
+        }
+        node_def {
+          name: "BiasAdd/ReadVariableOp"
+          op: "ReadVariableOp"
+          input: "biasadd_readvariableop_resource"
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: 1
+                  }
+                }
+              }
+            }
+          }
+          attr {
+            key: "dtype"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "BiasAdd/ReadVariableOp"
+          }
+        }
+        node_def {
+          name: "BiasAdd"
+          op: "BiasAdd"
+          input: "MatMul:product:0"
+          input: "BiasAdd/ReadVariableOp:value:0"
+          attr {
+            key: "T"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 1
+                  }
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "BiasAdd"
+          }
+        }
+        node_def {
+          name: "Identity"
+          op: "Identity"
+          input: "BiasAdd:output:0"
+          attr {
+            key: "T"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 1
+                  }
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "Identity"
+          }
+        }
+        ret {
+          key: "identity"
+          value: "Identity:output:0"
+        }
+        attr {
+          key: "_input_shapes"
+          value {
+            list {
+              shape {
+                dim {
+                  size: -1
+                }
+                dim {
+                  size: 100
+                }
+              }
+              shape {
+                unknown_rank: true
+              }
+              shape {
+                unknown_rank: true
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 0
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                    dim {
+                      size: -1
+                    }
+                    dim {
+                      size: 100
+                    }
+                  }
+                }
+              }
+            }
+            attr {
+              key: "_user_specified_name"
+              value {
+                s: "inputs"
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 1
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 2
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+      }
+      function {
+        signature {
+          name: "__inference_signature_wrapper_6671"
+          input_arg {
+            name: "input_1"
+            type: DT_INT32
+          }
+          input_arg {
+            name: "unknown"
+            type: DT_RESOURCE
+          }
+          input_arg {
+            name: "unknown_0"
+            type: DT_RESOURCE
+          }
+          input_arg {
+            name: "unknown_1"
+            type: DT_RESOURCE
+          }
+          input_arg {
+            name: "unknown_2"
+            type: DT_RESOURCE
+          }
+          output_arg {
+            name: "identity"
+            type: DT_FLOAT
+          }
+          is_stateful: true
+          control_output: "StatefulPartitionedCall"
+        }
+        node_def {
+          name: "StatefulPartitionedCall"
+          op: "StatefulPartitionedCall"
+          input: "input_1"
+          input: "unknown"
+          input: "unknown_0"
+          input: "unknown_1"
+          input: "unknown_2"
+          attr {
+            key: "Tin"
+            value {
+              list {
+                type: DT_INT32
+                type: DT_RESOURCE
+                type: DT_RESOURCE
+                type: DT_RESOURCE
+                type: DT_RESOURCE
+              }
+            }
+          }
+          attr {
+            key: "Tout"
+            value {
+              list {
+                type: DT_FLOAT
+              }
+            }
+          }
+          attr {
+            key: "_collective_manager_ids"
+            value {
+              list {
+              }
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 1
+                  }
+                }
+              }
+            }
+          }
+          attr {
+            key: "_read_only_resource_inputs"
+            value {
+              list {
+                i: 1
+                i: 2
+                i: 3
+                i: 4
+              }
+            }
+          }
+          attr {
+            key: "config_proto"
+            value {
+              s: "\n\007\n\003CPU\020\001\n\007\n\003GPU\020\0002\002J\0008\001"
+            }
+          }
+          attr {
+            key: "f"
+            value {
+              func {
+                name: "__inference__wrapped_model_6528"
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "StatefulPartitionedCall"
+          }
+        }
+        node_def {
+          name: "Identity"
+          op: "Identity"
+          input: "StatefulPartitionedCall:output:0"
+          input: "^StatefulPartitionedCall"
+          attr {
+            key: "T"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 1
+                  }
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "Identity"
+          }
+        }
+        ret {
+          key: "identity"
+          value: "Identity:output:0"
+        }
+        attr {
+          key: "_input_shapes"
+          value {
+            list {
+              shape {
+                dim {
+                  size: -1
+                }
+                dim {
+                  size: 214
+                }
+              }
+              shape {
+                unknown_rank: true
+              }
+              shape {
+                unknown_rank: true
+              }
+              shape {
+                unknown_rank: true
+              }
+              shape {
+                unknown_rank: true
+              }
+            }
+          }
+        }
+        control_ret {
+          key: "StatefulPartitionedCall"
+          value: "StatefulPartitionedCall"
+        }
+        arg_attr {
+          key: 0
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                    dim {
+                      size: -1
+                    }
+                    dim {
+                      size: 214
+                    }
+                  }
+                }
+              }
+            }
+            attr {
+              key: "_user_specified_name"
+              value {
+                s: "input_1"
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 1
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 2
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 3
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 4
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+      }
+      function {
+        signature {
+          name: "__inference_sequential_layer_call_fn_6720"
+          input_arg {
+            name: "inputs"
+            type: DT_INT32
+          }
+          input_arg {
+            name: "unknown"
+            type: DT_RESOURCE
+          }
+          input_arg {
+            name: "unknown_0"
+            type: DT_RESOURCE
+          }
+          input_arg {
+            name: "unknown_1"
+            type: DT_RESOURCE
+          }
+          input_arg {
+            name: "unknown_2"
+            type: DT_RESOURCE
+          }
+          output_arg {
+            name: "identity"
+            type: DT_FLOAT
+          }
+          is_stateful: true
+          control_output: "StatefulPartitionedCall"
+        }
+        node_def {
+          name: "StatefulPartitionedCall"
+          op: "StatefulPartitionedCall"
+          input: "inputs"
+          input: "unknown"
+          input: "unknown_0"
+          input: "unknown_1"
+          input: "unknown_2"
+          attr {
+            key: "Tin"
+            value {
+              list {
+                type: DT_INT32
+                type: DT_RESOURCE
+                type: DT_RESOURCE
+                type: DT_RESOURCE
+                type: DT_RESOURCE
+              }
+            }
+          }
+          attr {
+            key: "Tout"
+            value {
+              list {
+                type: DT_FLOAT
+              }
+            }
+          }
+          attr {
+            key: "_collective_manager_ids"
+            value {
+              list {
+              }
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 1
+                  }
+                }
+              }
+            }
+          }
+          attr {
+            key: "_read_only_resource_inputs"
+            value {
+              list {
+                i: 1
+                i: 2
+                i: 3
+                i: 4
+              }
+            }
+          }
+          attr {
+            key: "config_proto"
+            value {
+              s: "\n\007\n\003CPU\020\001\n\007\n\003GPU\020\0002\002J\0008\001"
+            }
+          }
+          attr {
+            key: "f"
+            value {
+              func {
+                name: "__inference_sequential_layer_call_and_return_conditional_losses_6618"
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "StatefulPartitionedCall"
+          }
+        }
+        node_def {
+          name: "Identity"
+          op: "Identity"
+          input: "StatefulPartitionedCall:output:0"
+          input: "^StatefulPartitionedCall"
+          attr {
+            key: "T"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 1
+                  }
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "Identity"
+          }
+        }
+        ret {
+          key: "identity"
+          value: "Identity:output:0"
+        }
+        attr {
+          key: "_input_shapes"
+          value {
+            list {
+              shape {
+                dim {
+                  size: -1
+                }
+                dim {
+                  size: 214
+                }
+              }
+              shape {
+                unknown_rank: true
+              }
+              shape {
+                unknown_rank: true
+              }
+              shape {
+                unknown_rank: true
+              }
+              shape {
+                unknown_rank: true
+              }
+            }
+          }
+        }
+        control_ret {
+          key: "StatefulPartitionedCall"
+          value: "StatefulPartitionedCall"
+        }
+        arg_attr {
+          key: 0
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                    dim {
+                      size: -1
+                    }
+                    dim {
+                      size: 214
+                    }
+                  }
+                }
+              }
+            }
+            attr {
+              key: "_user_specified_name"
+              value {
+                s: "inputs"
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 1
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 2
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 3
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 4
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+      }
+      function {
+        signature {
+          name: "__inference_sequential_layer_call_and_return_conditional_losses_6707"
+          input_arg {
+            name: "inputs"
+            type: DT_INT32
+          }
+          input_arg {
+            name: "dense_matmul_readvariableop_resource"
+            type: DT_RESOURCE
+          }
+          input_arg {
+            name: "dense_biasadd_readvariableop_resource"
+            type: DT_RESOURCE
+          }
+          input_arg {
+            name: "dense_1_matmul_readvariableop_resource"
+            type: DT_RESOURCE
+          }
+          input_arg {
+            name: "dense_1_biasadd_readvariableop_resource"
+            type: DT_RESOURCE
+          }
+          output_arg {
+            name: "identity"
+            type: DT_FLOAT
+          }
+          is_stateful: true
+        }
+        node_def {
+          name: "dense/Cast"
+          op: "Cast"
+          input: "inputs"
+          attr {
+            key: "DstT"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "SrcT"
+            value {
+              type: DT_INT32
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 214
+                  }
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "dense/Cast"
+          }
+        }
+        node_def {
+          name: "dense/MatMul/ReadVariableOp"
+          op: "ReadVariableOp"
+          input: "dense_matmul_readvariableop_resource"
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: 214
+                  }
+                  dim {
+                    size: 100
+                  }
+                }
+              }
+            }
+          }
+          attr {
+            key: "dtype"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "dense/MatMul/ReadVariableOp"
+          }
+        }
+        node_def {
+          name: "dense/MatMul"
+          op: "MatMul"
+          input: "dense/Cast:y:0"
+          input: "dense/MatMul/ReadVariableOp:value:0"
+          attr {
+            key: "T"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 100
+                  }
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "dense/MatMul"
+          }
+        }
+        node_def {
+          name: "dense/BiasAdd/ReadVariableOp"
+          op: "ReadVariableOp"
+          input: "dense_biasadd_readvariableop_resource"
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: 100
+                  }
+                }
+              }
+            }
+          }
+          attr {
+            key: "dtype"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "dense/BiasAdd/ReadVariableOp"
+          }
+        }
+        node_def {
+          name: "dense/BiasAdd"
+          op: "BiasAdd"
+          input: "dense/MatMul:product:0"
+          input: "dense/BiasAdd/ReadVariableOp:value:0"
+          attr {
+            key: "T"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 100
+                  }
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "dense/BiasAdd"
+          }
+        }
+        node_def {
+          name: "dense/Relu"
+          op: "Relu"
+          input: "dense/BiasAdd:output:0"
+          attr {
+            key: "T"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 100
+                  }
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "dense/Relu"
+          }
+        }
+        node_def {
+          name: "dense_1/MatMul/ReadVariableOp"
+          op: "ReadVariableOp"
+          input: "dense_1_matmul_readvariableop_resource"
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: 100
+                  }
+                  dim {
+                    size: 1
+                  }
+                }
+              }
+            }
+          }
+          attr {
+            key: "dtype"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "dense_1/MatMul/ReadVariableOp"
+          }
+        }
+        node_def {
+          name: "dense_1/MatMul"
+          op: "MatMul"
+          input: "dense/Relu:activations:0"
+          input: "dense_1/MatMul/ReadVariableOp:value:0"
+          attr {
+            key: "T"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 1
+                  }
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "dense_1/MatMul"
+          }
+        }
+        node_def {
+          name: "dense_1/BiasAdd/ReadVariableOp"
+          op: "ReadVariableOp"
+          input: "dense_1_biasadd_readvariableop_resource"
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: 1
+                  }
+                }
+              }
+            }
+          }
+          attr {
+            key: "dtype"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "dense_1/BiasAdd/ReadVariableOp"
+          }
+        }
+        node_def {
+          name: "dense_1/BiasAdd"
+          op: "BiasAdd"
+          input: "dense_1/MatMul:product:0"
+          input: "dense_1/BiasAdd/ReadVariableOp:value:0"
+          attr {
+            key: "T"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 1
+                  }
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "dense_1/BiasAdd"
+          }
+        }
+        node_def {
+          name: "Identity"
+          op: "Identity"
+          input: "dense_1/BiasAdd:output:0"
+          attr {
+            key: "T"
+            value {
+              type: DT_FLOAT
+            }
+          }
+          attr {
+            key: "_output_shapes"
+            value {
+              list {
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 1
+                  }
+                }
+              }
+            }
+          }
+          experimental_debug_info {
+            original_node_names: "Identity"
+          }
+        }
+        ret {
+          key: "identity"
+          value: "Identity:output:0"
+        }
+        attr {
+          key: "_input_shapes"
+          value {
+            list {
+              shape {
+                dim {
+                  size: -1
+                }
+                dim {
+                  size: 214
+                }
+              }
+              shape {
+                unknown_rank: true
+              }
+              shape {
+                unknown_rank: true
+              }
+              shape {
+                unknown_rank: true
+              }
+              shape {
+                unknown_rank: true
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 0
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                    dim {
+                      size: -1
+                    }
+                    dim {
+                      size: 214
+                    }
+                  }
+                }
+              }
+            }
+            attr {
+              key: "_user_specified_name"
+              value {
+                s: "inputs"
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 1
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 2
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 3
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+        arg_attr {
+          key: 4
+          value {
+            attr {
+              key: "_output_shapes"
+              value {
+                list {
+                  shape {
+                  }
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+    versions {
+      producer: 331
+      min_consumer: 12
+    }
+  }
+  saver_def {
+    filename_tensor_name: "saver_filename:0"
+    save_tensor_name: "StatefulPartitionedCall_1:0"
+    restore_op_name: "StatefulPartitionedCall_2"
+    version: V2
+  }
+  collection_def {
+    key: "saved_model_main_op"
+    value {
+      node_list {
+        value: "NoOp"
+      }
+    }
+  }
+  signature_def {
+    key: "__saved_model_init_op"
+    value {
+      outputs {
+        key: "__saved_model_init_op"
+        value {
+          name: "NoOp"
+          tensor_shape {
+            unknown_rank: true
+          }
+        }
+      }
+    }
+  }
+  signature_def {
+    key: "serving_default"
+    value {
+      inputs {
+        key: "input_1"
+        value {
+          name: "serving_default_input_1:0"
+          dtype: DT_INT32
+          tensor_shape {
+            dim {
+              size: -1
+            }
+            dim {
+              size: 214
+            }
+          }
+        }
+      }
+      outputs {
+        key: "output_1"
+        value {
+          name: "StatefulPartitionedCall:0"
+          dtype: DT_FLOAT
+          tensor_shape {
+            dim {
+              size: -1
+            }
+            dim {
+              size: 1
+            }
+          }
+        }
+      }
+      method_name: "tensorflow/serving/predict"
+    }
+  }
+  object_graph_def {
+    nodes {
+      children {
+        node_id: 1
+        local_name: "layer_with_weights-0"
+      }
+      children {
+        node_id: 1
+        local_name: "layer-0"
+      }
+      children {
+        node_id: 2
+        local_name: "layer_with_weights-1"
+      }
+      children {
+        node_id: 2
+        local_name: "layer-1"
+      }
+      children {
+        node_id: 3
+        local_name: "optimizer"
+      }
+      children {
+        node_id: 4
+        local_name: "regularization_losses"
+      }
+      children {
+        node_id: 5
+        local_name: "variables"
+      }
+      children {
+        node_id: 6
+        local_name: "trainable_variables"
+      }
+      children {
+        node_id: 7
+        local_name: "keras_api"
+      }
+      children {
+        node_id: 8
+        local_name: "signatures"
+      }
+      children {
+        node_id: 47
+        local_name: "__call__"
+      }
+      children {
+        node_id: 48
+        local_name: "_default_save_signature"
+      }
+      children {
+        node_id: 49
+        local_name: "call_and_return_all_conditional_losses"
+      }
+      user_object {
+        identifier: "_tf_keras_sequential"
+        version {
+          producer: 1
+          min_consumer: 1
+        }
+        metadata: "{\"class_name\": \"Sequential\", \"name\": \"sequential\", \"trainable\": true, \"expects_training_arg\": true, \"dtype\": \"float32\", \"batch_input_shape\": null, \"config\": {\"name\": \"sequential\", \"layers\": [{\"class_name\": \"Dense\", \"config\": {\"name\": \"dense\", \"trainable\": true, \"dtype\": \"float32\", \"units\": 100, \"activation\": \"relu\", \"use_bias\": true, \"kernel_initializer\": {\"class_name\": \"GlorotUniform\", \"config\": {\"seed\": null}}, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"kernel_regularizer\": null, \"bias_regularizer\": null, \"activity_regularizer\": null, \"kernel_constraint\": null, \"bias_constraint\": null}}, {\"class_name\": \"Dense\", \"config\": {\"name\": \"dense_1\", \"trainable\": true, \"dtype\": \"float32\", \"units\": 1, \"activation\": \"linear\", \"use_bias\": true, \"kernel_initializer\": {\"class_name\": \"GlorotUniform\", \"config\": {\"seed\": null}}, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"kernel_regularizer\": null, \"bias_regularizer\": null, \"activity_regularizer\": null, \"kernel_constraint\": null, \"bias_constraint\": null}}], \"build_input_shape\": {\"class_name\": \"__tuple__\", \"items\": [null, 214]}}, \"input_spec\": {\"class_name\": \"InputSpec\", \"config\": {\"dtype\": null, \"shape\": null, \"ndim\": null, \"max_ndim\": null, \"min_ndim\": 2, \"axes\": {\"-1\": 214}}}, \"build_input_shape\": {\"class_name\": \"__tuple__\", \"items\": [null, 214]}, \"is_graph_network\": false, \"keras_version\": \"2.2.4-tf\", \"backend\": \"tensorflow\", \"model_config\": {\"class_name\": \"Sequential\", \"config\": {\"name\": \"sequential\", \"layers\": [{\"class_name\": \"Dense\", \"config\": {\"name\": \"dense\", \"trainable\": true, \"dtype\": \"float32\", \"units\": 100, \"activation\": \"relu\", \"use_bias\": true, \"kernel_initializer\": {\"class_name\": \"GlorotUniform\", \"config\": {\"seed\": null}}, 
\"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"kernel_regularizer\": null, \"bias_regularizer\": null, \"activity_regularizer\": null, \"kernel_constraint\": null, \"bias_constraint\": null}}, {\"class_name\": \"Dense\", \"config\": {\"name\": \"dense_1\", \"trainable\": true, \"dtype\": \"float32\", \"units\": 1, \"activation\": \"linear\", \"use_bias\": true, \"kernel_initializer\": {\"class_name\": \"GlorotUniform\", \"config\": {\"seed\": null}}, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"kernel_regularizer\": null, \"bias_regularizer\": null, \"activity_regularizer\": null, \"kernel_constraint\": null, \"bias_constraint\": null}}], \"build_input_shape\": {\"class_name\": \"__tuple__\", \"items\": [null, 214]}}}, \"training_config\": {\"loss\": \"mean_absolute_error\", \"metrics\": [\"mean_squared_error\"], \"weighted_metrics\": null, \"loss_weights\": null, \"sample_weight_mode\": null, \"optimizer_config\": {\"class_name\": \"Adam\", \"config\": {\"name\": \"Adam\", \"learning_rate\": 0.0003000000142492354, \"decay\": 0.0, \"beta_1\": 0.8999999761581421, \"beta_2\": 0.9990000128746033, \"epsilon\": 1e-07, \"amsgrad\": false}}}}"
+      }
+    }
+    nodes {
+      children {
+        node_id: 9
+        local_name: "kernel"
+      }
+      children {
+        node_id: 10
+        local_name: "bias"
+      }
+      children {
+        node_id: 11
+        local_name: "regularization_losses"
+      }
+      children {
+        node_id: 12
+        local_name: "variables"
+      }
+      children {
+        node_id: 13
+        local_name: "trainable_variables"
+      }
+      children {
+        node_id: 14
+        local_name: "keras_api"
+      }
+      children {
+        node_id: 50
+        local_name: "__call__"
+      }
+      children {
+        node_id: 51
+        local_name: "call_and_return_all_conditional_losses"
+      }
+      user_object {
+        identifier: "_tf_keras_layer"
+        version {
+          producer: 1
+          min_consumer: 1
+        }
+        metadata: "{\"class_name\": \"Dense\", \"name\": \"dense\", \"trainable\": true, \"expects_training_arg\": false, \"dtype\": \"float32\", \"batch_input_shape\": null, \"stateful\": false, \"config\": {\"name\": \"dense\", \"trainable\": true, \"dtype\": \"float32\", \"units\": 100, \"activation\": \"relu\", \"use_bias\": true, \"kernel_initializer\": {\"class_name\": \"GlorotUniform\", \"config\": {\"seed\": null}}, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"kernel_regularizer\": null, \"bias_regularizer\": null, \"activity_regularizer\": null, \"kernel_constraint\": null, \"bias_constraint\": null}, \"input_spec\": {\"class_name\": \"InputSpec\", \"config\": {\"dtype\": null, \"shape\": null, \"ndim\": null, \"max_ndim\": null, \"min_ndim\": 2, \"axes\": {\"-1\": 214}}}, \"build_input_shape\": {\"class_name\": \"TensorShape\", \"items\": [null, 214]}}"
+      }
+    }
+    nodes {
+      children {
+        node_id: 15
+        local_name: "kernel"
+      }
+      children {
+        node_id: 16
+        local_name: "bias"
+      }
+      children {
+        node_id: 17
+        local_name: "regularization_losses"
+      }
+      children {
+        node_id: 18
+        local_name: "variables"
+      }
+      children {
+        node_id: 19
+        local_name: "trainable_variables"
+      }
+      children {
+        node_id: 20
+        local_name: "keras_api"
+      }
+      children {
+        node_id: 52
+        local_name: "__call__"
+      }
+      children {
+        node_id: 53
+        local_name: "call_and_return_all_conditional_losses"
+      }
+      user_object {
+        identifier: "_tf_keras_layer"
+        version {
+          producer: 1
+          min_consumer: 1
+        }
+        metadata: "{\"class_name\": \"Dense\", \"name\": \"dense_1\", \"trainable\": true, \"expects_training_arg\": false, \"dtype\": \"float32\", \"batch_input_shape\": null, \"stateful\": false, \"config\": {\"name\": \"dense_1\", \"trainable\": true, \"dtype\": \"float32\", \"units\": 1, \"activation\": \"linear\", \"use_bias\": true, \"kernel_initializer\": {\"class_name\": \"GlorotUniform\", \"config\": {\"seed\": null}}, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"kernel_regularizer\": null, \"bias_regularizer\": null, \"activity_regularizer\": null, \"kernel_constraint\": null, \"bias_constraint\": null}, \"input_spec\": {\"class_name\": \"InputSpec\", \"config\": {\"dtype\": null, \"shape\": null, \"ndim\": null, \"max_ndim\": null, \"min_ndim\": 2, \"axes\": {\"-1\": 100}}}, \"build_input_shape\": {\"class_name\": \"TensorShape\", \"items\": [null, 100]}}"
+      }
+    }
+    nodes {
+      user_object {
+        identifier: "optimizer"
+        version {
+          producer: 1
+          min_consumer: 1
+        }
+      }
+    }
+    nodes {
+      user_object {
+        identifier: "trackable_list_wrapper"
+        version {
+          producer: 1
+          min_consumer: 1
+        }
+      }
+    }
+    nodes {
+      children {
+        node_id: 9
+        local_name: "0"
+      }
+      children {
+        node_id: 10
+        local_name: "1"
+      }
+      children {
+        node_id: 15
+        local_name: "2"
+      }
+      children {
+        node_id: 16
+        local_name: "3"
+      }
+      user_object {
+        identifier: "trackable_list_wrapper"
+        version {
+          producer: 1
+          min_consumer: 1
+        }
+      }
+    }
+    nodes {
+      children {
+        node_id: 9
+        local_name: "0"
+      }
+      children {
+        node_id: 10
+        local_name: "1"
+      }
+      children {
+        node_id: 15
+        local_name: "2"
+      }
+      children {
+        node_id: 16
+        local_name: "3"
+      }
+      user_object {
+        identifier: "trackable_list_wrapper"
+        version {
+          producer: 1
+          min_consumer: 1
+        }
+      }
+    }
+    nodes {
+      children {
+        node_id: 21
+        local_name: "layers"
+      }
+      children {
+        node_id: 22
+        local_name: "layer_regularization_losses"
+      }
+      children {
+        node_id: 23
+        local_name: "non_trainable_variables"
+      }
+      children {
+        node_id: 24
+        local_name: "layer_metrics"
+      }
+      children {
+        node_id: 4
+        local_name: "regularization_losses"
+      }
+      children {
+        node_id: 25
+        local_name: "metrics"
+      }
+      children {
+        node_id: 5
+        local_name: "variables"
+      }
+      children {
+        node_id: 6
+        local_name: "trainable_variables"
+      }
+      children {
+        node_id: 47
+        local_name: "__call__"
+      }
+      children {
+        node_id: 48
+        local_name: "_default_save_signature"
+      }
+      children {
+        node_id: 49
+        local_name: "call_and_return_all_conditional_losses"
+      }
+      children {
+        node_id: 49
+        local_name: "call_and_return_conditional_losses"
+      }
+      user_object {
+        identifier: "_generic_user_object"
+        version {
+          producer: 1
+          min_consumer: 1
+        }
+      }
+    }
+    nodes {
+      children {
+        node_id: 54
+        local_name: "serving_default"
+      }
+      user_object {
+        identifier: "signature_map"
+        version {
+          producer: 1
+          min_consumer: 1
+        }
+      }
+    }
+    nodes {
+      variable {
+        dtype: DT_FLOAT
+        shape {
+          dim {
+            size: 214
+          }
+          dim {
+            size: 100
+          }
+        }
+        trainable: true
+        name: "dense/kernel"
+      }
+    }
+    nodes {
+      variable {
+        dtype: DT_FLOAT
+        shape {
+          dim {
+            size: 100
+          }
+        }
+        trainable: true
+        name: "dense/bias"
+      }
+    }
+    nodes {
+      user_object {
+        identifier: "trackable_list_wrapper"
+        version {
+          producer: 1
+          min_consumer: 1
+        }
+      }
+    }
+    nodes {
+      children {
+        node_id: 9
+        local_name: "0"
+      }
+      children {
+        node_id: 10
+        local_name: "1"
+      }
+      user_object {
+        identifier: "trackable_list_wrapper"
+        version {
+          producer: 1
+          min_consumer: 1
+        }
+      }
+    }
+    nodes {
+      children {
+        node_id: 9
+        local_name: "0"
+      }
+      children {
+        node_id: 10
+        local_name: "1"
+      }
+      user_object {
+        identifier: "trackable_list_wrapper"
+        version {
+          producer: 1
+          min_consumer: 1
+        }
+      }
+    }
+    nodes {
+      children {
+        node_id: 26
+        local_name: "layers"
+      }
+      children {
+        node_id: 27
+        local_name: "layer_regularization_losses"
+      }
+      children {
+        node_id: 28
+        local_name: "non_trainable_variables"
+      }
+      children {
+        node_id: 29
+        local_name: "layer_metrics"
+      }
+      children {
+        node_id: 11
+        local_name: "regularization_losses"
+      }
+      children {
+        node_id: 30
+        local_name: "metrics"
+      }
+      children {
+        node_id: 12
+        local_name: "variables"
+      }
+      children {
+        node_id: 13
+        local_name: "trainable_variables"
+      }
+      children {
+        node_id: 50
+        local_name: "__call__"
+      }
+      children {
+        node_id: 51
+        local_name: "call_and_return_all_conditional_losses"
+      }
+      children {
+        node_id: 51
+        local_name: "call_and_return_conditional_losses"
+      }
+      user_object {
+        identifier: "_generic_user_object"
+        version {
+          producer: 1
+          min_consumer: 1
+        }
+      }
+    }
+    nodes {
+      variable {
+        dtype: DT_FLOAT
+        shape {
+          dim {
+            size: 100
+          }
+          dim {
+            size: 1
+          }
+        }
+        trainable: true
+        name: "dense_1/kernel"
+      }
+    }
+    nodes {
+      variable {
+        dtype: DT_FLOAT
+        shape {
+          dim {
+            size: 1
+          }
+        }
+        trainable: true
+        name: "dense_1/bias"
+      }
+    }
+    nodes {
+      user_object {
+        identifier: "trackable_list_wrapper"
+        version {
+          producer: 1
+          min_consumer: 1
+        }
+      }
+    }
+    nodes {
+      children {
+        node_id: 15
+        local_name: "0"
+      }
+      children {
+        node_id: 16
+        local_name: "1"
+      }
+      user_object {
+        identifier: "trackable_list_wrapper"
+        version {
+          producer: 1
+          min_consumer: 1
+        }
+      }
+    }
+    nodes {
+      children {
+        node_id: 15
+        local_name: "0"
+      }
+      children {
+        node_id: 16
+        local_name: "1"
+      }
+      user_object {
+        identifier: "trackable_list_wrapper"
+        version {
+          producer: 1
+          min_consumer: 1
+        }
+      }
+    }
+    nodes {
+      children {
+        node_id: 31
+        local_name: "layers"
+      }
+      children {
+        node_id: 32
+        local_name: "layer_regularization_losses"
+      }
+      children {
+        node_id: 33
+        local_name: "non_trainable_variables"
+      }
+      children {
+        node_id: 34
+        local_name: "layer_metrics"
+      }
+      children {
+        node_id: 17
+        local_name: "regularization_losses"
+      }
+      children {
+        node_id: 35
+        local_name: "metrics"
+      }
+      children {
+        node_id: 18
+        local_name: "variables"
+      }
+      children {
+        node_id: 19
+        local_name: "trainable_variables"
+      }
+      children {
+        node_id: 52
+        local_name: "__call__"
+      }
+      children {
+        node_id: 53
+        local_name: "call_and_return_all_conditional_losses"
+      }
+      children {
+        node_id: 53
+        local_name: "call_and_return_conditional_losses"
+      }
+      user_object {
+        identifier: "_generic_user_object"
+        version {
+          producer: 1
+          min_consumer: 1
+        }
+      }
+    }
+    nodes {
+      children {
+        node_id: 1
+        local_name: "0"
+      }
+      children {
+        node_id: 2
+        local_name: "1"
+      }
+      user_object {
+        identifier: "trackable_list_wrapper"
+        version {
+          producer: 1
+          min_consumer: 1
+        }
+      }
+    }
+    nodes {
+      user_object {
+        identifier: "trackable_list_wrapper"
+        version {
+          producer: 1
+          min_consumer: 1
+        }
+      }
+    }
+    nodes {
+      user_object {
+        identifier: "trackable_list_wrapper"
+        version {
+          producer: 1
+          min_consumer: 1
+        }
+      }
+    }
+    nodes {
+      user_object {
+        identifier: "trackable_dict_wrapper"
+        version {
+          producer: 1
+          min_consumer: 1
+        }
+      }
+    }
+    nodes {
+      children {
+        node_id: 36
+        local_name: "0"
+      }
+      children {
+        node_id: 37
+        local_name: "1"
+      }
+      user_object {
+        identifier: "trackable_list_wrapper"
+        version {
+          producer: 1
+          min_consumer: 1
+        }
+      }
+    }
+    nodes {
+      user_object {
+        identifier: "trackable_list_wrapper"
+        version {
+          producer: 1
+          min_consumer: 1
+        }
+      }
+    }
+    nodes {
+      user_object {
+        identifier: "trackable_list_wrapper"
+        version {
+          producer: 1
+          min_consumer: 1
+        }
+      }
+    }
+    nodes {
+      user_object {
+        identifier: "trackable_list_wrapper"
+        version {
+          producer: 1
+          min_consumer: 1
+        }
+      }
+    }
+    nodes {
+      user_object {
+        identifier: "trackable_dict_wrapper"
+        version {
+          producer: 1
+          min_consumer: 1
+        }
+      }
+    }
+    nodes {
+      user_object {
+        identifier: "trackable_list_wrapper"
+        version {
+          producer: 1
+          min_consumer: 1
+        }
+      }
+    }
+    nodes {
+      user_object {
+        identifier: "trackable_list_wrapper"
+        version {
+          producer: 1
+          min_consumer: 1
+        }
+      }
+    }
+    nodes {
+      user_object {
+        identifier: "trackable_list_wrapper"
+        version {
+          producer: 1
+          min_consumer: 1
+        }
+      }
+    }
+    nodes {
+      user_object {
+        identifier: "trackable_list_wrapper"
+        version {
+          producer: 1
+          min_consumer: 1
+        }
+      }
+    }
+    nodes {
+      user_object {
+        identifier: "trackable_dict_wrapper"
+        version {
+          producer: 1
+          min_consumer: 1
+        }
+      }
+    }
+    nodes {
+      user_object {
+        identifier: "trackable_list_wrapper"
+        version {
+          producer: 1
+          min_consumer: 1
+        }
+      }
+    }
+    nodes {
+      children {
+        node_id: 38
+        local_name: "total"
+      }
+      children {
+        node_id: 39
+        local_name: "count"
+      }
+      children {
+        node_id: 40
+        local_name: "variables"
+      }
+      children {
+        node_id: 41
+        local_name: "keras_api"
+      }
+      user_object {
+        identifier: "_tf_keras_metric"
+        version {
+          producer: 1
+          min_consumer: 1
+        }
+        metadata: "{\"class_name\": \"Mean\", \"name\": \"loss\", \"dtype\": \"float32\", \"config\": {\"name\": \"loss\", \"dtype\": \"float32\"}}"
+      }
+    }
+    nodes {
+      children {
+        node_id: 42
+        local_name: "total"
+      }
+      children {
+        node_id: 43
+        local_name: "count"
+      }
+      children {
+        node_id: 44
+        local_name: "_fn_kwargs"
+      }
+      children {
+        node_id: 45
+        local_name: "variables"
+      }
+      children {
+        node_id: 46
+        local_name: "keras_api"
+      }
+      user_object {
+        identifier: "_tf_keras_metric"
+        version {
+          producer: 1
+          min_consumer: 1
+        }
+        metadata: "{\"class_name\": \"MeanMetricWrapper\", \"name\": \"mean_squared_error\", \"dtype\": \"float32\", \"config\": {\"name\": \"mean_squared_error\", \"dtype\": \"float32\", \"fn\": \"mean_squared_error\"}}"
+      }
+    }
+    nodes {
+      variable {
+        dtype: DT_FLOAT
+        shape {
+        }
+        synchronization: VARIABLE_SYNCHRONIZATION_ON_READ
+        aggregation: VARIABLE_AGGREGATION_SUM
+        name: "total"
+      }
+    }
+    nodes {
+      variable {
+        dtype: DT_FLOAT
+        shape {
+        }
+        synchronization: VARIABLE_SYNCHRONIZATION_ON_READ
+        aggregation: VARIABLE_AGGREGATION_SUM
+        name: "count"
+      }
+    }
+    nodes {
+      children {
+        node_id: 38
+        local_name: "0"
+      }
+      children {
+        node_id: 39
+        local_name: "1"
+      }
+      user_object {
+        identifier: "trackable_list_wrapper"
+        version {
+          producer: 1
+          min_consumer: 1
+        }
+      }
+    }
+    nodes {
+      children {
+        node_id: 40
+        local_name: "variables"
+      }
+      user_object {
+        identifier: "_generic_user_object"
+        version {
+          producer: 1
+          min_consumer: 1
+        }
+      }
+    }
+    nodes {
+      variable {
+        dtype: DT_FLOAT
+        shape {
+        }
+        synchronization: VARIABLE_SYNCHRONIZATION_ON_READ
+        aggregation: VARIABLE_AGGREGATION_SUM
+        name: "total"
+      }
+    }
+    nodes {
+      variable {
+        dtype: DT_FLOAT
+        shape {
+        }
+        synchronization: VARIABLE_SYNCHRONIZATION_ON_READ
+        aggregation: VARIABLE_AGGREGATION_SUM
+        name: "count"
+      }
+    }
+    nodes {
+      user_object {
+        identifier: "trackable_dict_wrapper"
+        version {
+          producer: 1
+          min_consumer: 1
+        }
+      }
+    }
+    nodes {
+      children {
+        node_id: 42
+        local_name: "0"
+      }
+      children {
+        node_id: 43
+        local_name: "1"
+      }
+      user_object {
+        identifier: "trackable_list_wrapper"
+        version {
+          producer: 1
+          min_consumer: 1
+        }
+      }
+    }
+    nodes {
+      children {
+        node_id: 45
+        local_name: "variables"
+      }
+      user_object {
+        identifier: "_generic_user_object"
+        version {
+          producer: 1
+          min_consumer: 1
+        }
+      }
+    }
+    nodes {
+      function {
+        concrete_functions: "__inference_sequential_layer_call_fn_6629"
+        concrete_functions: "__inference_sequential_layer_call_fn_6733"
+        concrete_functions: "__inference_sequential_layer_call_fn_6720"
+        concrete_functions: "__inference_sequential_layer_call_fn_6656"
+        function_spec {
+          fullargspec {
+            named_tuple_value {
+              name: "FullArgSpec"
+              values {
+                key: "args"
+                value {
+                  list_value {
+                    values {
+                      string_value: "self"
+                    }
+                    values {
+                      string_value: "inputs"
+                    }
+                    values {
+                      string_value: "training"
+                    }
+                    values {
+                      string_value: "mask"
+                    }
+                  }
+                }
+              }
+              values {
+                key: "varargs"
+                value {
+                  none_value {
+                  }
+                }
+              }
+              values {
+                key: "varkw"
+                value {
+                  none_value {
+                  }
+                }
+              }
+              values {
+                key: "defaults"
+                value {
+                  list_value {
+                    values {
+                      bool_value: false
+                    }
+                    values {
+                      none_value {
+                      }
+                    }
+                  }
+                }
+              }
+              values {
+                key: "kwonlyargs"
+                value {
+                  list_value {
+                  }
+                }
+              }
+              values {
+                key: "kwonlydefaults"
+                value {
+                  dict_value {
+                  }
+                }
+              }
+              values {
+                key: "annotations"
+                value {
+                  dict_value {
+                  }
+                }
+              }
+            }
+          }
+          is_method: true
+          input_signature {
+            none_value {
+            }
+          }
+        }
+      }
+    }
+    nodes {
+      function {
+        concrete_functions: "__inference__wrapped_model_6528"
+        function_spec {
+          fullargspec {
+            named_tuple_value {
+              name: "FullArgSpec"
+              values {
+                key: "args"
+                value {
+                  list_value {
+                  }
+                }
+              }
+              values {
+                key: "varargs"
+                value {
+                  string_value: "args"
+                }
+              }
+              values {
+                key: "varkw"
+                value {
+                  none_value {
+                  }
+                }
+              }
+              values {
+                key: "defaults"
+                value {
+                  none_value {
+                  }
+                }
+              }
+              values {
+                key: "kwonlyargs"
+                value {
+                  list_value {
+                  }
+                }
+              }
+              values {
+                key: "kwonlydefaults"
+                value {
+                  none_value {
+                  }
+                }
+              }
+              values {
+                key: "annotations"
+                value {
+                  dict_value {
+                  }
+                }
+              }
+            }
+          }
+          input_signature {
+            tuple_value {
+              values {
+                tensor_spec_value {
+                  name: "input_1"
+                  shape {
+                    dim {
+                      size: -1
+                    }
+                    dim {
+                      size: 214
+                    }
+                  }
+                  dtype: DT_INT32
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+    nodes {
+      function {
+        concrete_functions: "__inference_sequential_layer_call_and_return_conditional_losses_6689"
+        concrete_functions: "__inference_sequential_layer_call_and_return_conditional_losses_6587"
+        concrete_functions: "__inference_sequential_layer_call_and_return_conditional_losses_6707"
+        concrete_functions: "__inference_sequential_layer_call_and_return_conditional_losses_6601"
+        function_spec {
+          fullargspec {
+            named_tuple_value {
+              name: "FullArgSpec"
+              values {
+                key: "args"
+                value {
+                  list_value {
+                    values {
+                      string_value: "self"
+                    }
+                    values {
+                      string_value: "inputs"
+                    }
+                    values {
+                      string_value: "training"
+                    }
+                    values {
+                      string_value: "mask"
+                    }
+                  }
+                }
+              }
+              values {
+                key: "varargs"
+                value {
+                  none_value {
+                  }
+                }
+              }
+              values {
+                key: "varkw"
+                value {
+                  none_value {
+                  }
+                }
+              }
+              values {
+                key: "defaults"
+                value {
+                  list_value {
+                    values {
+                      bool_value: false
+                    }
+                    values {
+                      none_value {
+                      }
+                    }
+                  }
+                }
+              }
+              values {
+                key: "kwonlyargs"
+                value {
+                  list_value {
+                  }
+                }
+              }
+              values {
+                key: "kwonlydefaults"
+                value {
+                  dict_value {
+                  }
+                }
+              }
+              values {
+                key: "annotations"
+                value {
+                  dict_value {
+                  }
+                }
+              }
+            }
+          }
+          is_method: true
+          input_signature {
+            none_value {
+            }
+          }
+        }
+      }
+    }
+    nodes {
+      function {
+        concrete_functions: "__inference_dense_layer_call_fn_6754"
+        function_spec {
+          fullargspec {
+            named_tuple_value {
+              name: "FullArgSpec"
+              values {
+                key: "args"
+                value {
+                  list_value {
+                    values {
+                      string_value: "self"
+                    }
+                    values {
+                      string_value: "inputs"
+                    }
+                  }
+                }
+              }
+              values {
+                key: "varargs"
+                value {
+                  none_value {
+                  }
+                }
+              }
+              values {
+                key: "varkw"
+                value {
+                  none_value {
+                  }
+                }
+              }
+              values {
+                key: "defaults"
+                value {
+                  none_value {
+                  }
+                }
+              }
+              values {
+                key: "kwonlyargs"
+                value {
+                  list_value {
+                  }
+                }
+              }
+              values {
+                key: "kwonlydefaults"
+                value {
+                  none_value {
+                  }
+                }
+              }
+              values {
+                key: "annotations"
+                value {
+                  dict_value {
+                  }
+                }
+              }
+            }
+          }
+          is_method: true
+          input_signature {
+            none_value {
+            }
+          }
+        }
+      }
+    }
+    nodes {
+      function {
+        concrete_functions: "__inference_dense_layer_call_and_return_conditional_losses_6745"
+        function_spec {
+          fullargspec {
+            named_tuple_value {
+              name: "FullArgSpec"
+              values {
+                key: "args"
+                value {
+                  list_value {
+                    values {
+                      string_value: "self"
+                    }
+                    values {
+                      string_value: "inputs"
+                    }
+                  }
+                }
+              }
+              values {
+                key: "varargs"
+                value {
+                  none_value {
+                  }
+                }
+              }
+              values {
+                key: "varkw"
+                value {
+                  none_value {
+                  }
+                }
+              }
+              values {
+                key: "defaults"
+                value {
+                  none_value {
+                  }
+                }
+              }
+              values {
+                key: "kwonlyargs"
+                value {
+                  list_value {
+                  }
+                }
+              }
+              values {
+                key: "kwonlydefaults"
+                value {
+                  none_value {
+                  }
+                }
+              }
+              values {
+                key: "annotations"
+                value {
+                  dict_value {
+                  }
+                }
+              }
+            }
+          }
+          is_method: true
+          input_signature {
+            none_value {
+            }
+          }
+        }
+      }
+    }
+    nodes {
+      function {
+        concrete_functions: "__inference_dense_1_layer_call_fn_6773"
+        function_spec {
+          fullargspec {
+            named_tuple_value {
+              name: "FullArgSpec"
+              values {
+                key: "args"
+                value {
+                  list_value {
+                    values {
+                      string_value: "self"
+                    }
+                    values {
+                      string_value: "inputs"
+                    }
+                  }
+                }
+              }
+              values {
+                key: "varargs"
+                value {
+                  none_value {
+                  }
+                }
+              }
+              values {
+                key: "varkw"
+                value {
+                  none_value {
+                  }
+                }
+              }
+              values {
+                key: "defaults"
+                value {
+                  none_value {
+                  }
+                }
+              }
+              values {
+                key: "kwonlyargs"
+                value {
+                  list_value {
+                  }
+                }
+              }
+              values {
+                key: "kwonlydefaults"
+                value {
+                  none_value {
+                  }
+                }
+              }
+              values {
+                key: "annotations"
+                value {
+                  dict_value {
+                  }
+                }
+              }
+            }
+          }
+          is_method: true
+          input_signature {
+            none_value {
+            }
+          }
+        }
+      }
+    }
+    nodes {
+      function {
+        concrete_functions: "__inference_dense_1_layer_call_and_return_conditional_losses_6764"
+        function_spec {
+          fullargspec {
+            named_tuple_value {
+              name: "FullArgSpec"
+              values {
+                key: "args"
+                value {
+                  list_value {
+                    values {
+                      string_value: "self"
+                    }
+                    values {
+                      string_value: "inputs"
+                    }
+                  }
+                }
+              }
+              values {
+                key: "varargs"
+                value {
+                  none_value {
+                  }
+                }
+              }
+              values {
+                key: "varkw"
+                value {
+                  none_value {
+                  }
+                }
+              }
+              values {
+                key: "defaults"
+                value {
+                  none_value {
+                  }
+                }
+              }
+              values {
+                key: "kwonlyargs"
+                value {
+                  list_value {
+                  }
+                }
+              }
+              values {
+                key: "kwonlydefaults"
+                value {
+                  none_value {
+                  }
+                }
+              }
+              values {
+                key: "annotations"
+                value {
+                  dict_value {
+                  }
+                }
+              }
+            }
+          }
+          is_method: true
+          input_signature {
+            none_value {
+            }
+          }
+        }
+      }
+    }
+    nodes {
+      bare_concrete_function {
+        concrete_function_name: "__inference_signature_wrapper_6671"
+        argument_keywords: "input_1"
+        allowed_positional_arguments: 1
+      }
+    }
+    concrete_functions {
+      key: "__inference__wrapped_model_6528"
+      value {
+        bound_inputs: 9
+        bound_inputs: 10
+        bound_inputs: 15
+        bound_inputs: 16
+        canonicalized_input_signature {
+          tuple_value {
+            values {
+              tuple_value {
+                values {
+                  tensor_spec_value {
+                    name: "input_1"
+                    shape {
+                      dim {
+                        size: -1
+                      }
+                      dim {
+                        size: 214
+                      }
+                    }
+                    dtype: DT_INT32
+                  }
+                }
+              }
+            }
+            values {
+              dict_value {
+              }
+            }
+          }
+        }
+        output_signature {
+          dict_value {
+            fields {
+              key: "output_1"
+              value {
+                tensor_spec_value {
+                  name: "output_1"
+                  shape {
+                    dim {
+                      size: -1
+                    }
+                    dim {
+                      size: 1
+                    }
+                  }
+                  dtype: DT_FLOAT
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+    concrete_functions {
+      key: "__inference_dense_1_layer_call_and_return_conditional_losses_6764"
+      value {
+        bound_inputs: 15
+        bound_inputs: 16
+        canonicalized_input_signature {
+          tuple_value {
+            values {
+              tuple_value {
+                values {
+                  tensor_spec_value {
+                    name: "inputs"
+                    shape {
+                      dim {
+                        size: -1
+                      }
+                      dim {
+                        size: 100
+                      }
+                    }
+                    dtype: DT_FLOAT
+                  }
+                }
+              }
+            }
+            values {
+              dict_value {
+              }
+            }
+          }
+        }
+        output_signature {
+          tuple_value {
+            values {
+              tensor_spec_value {
+                name: "0"
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 1
+                  }
+                }
+                dtype: DT_FLOAT
+              }
+            }
+            values {
+              list_value {
+              }
+            }
+          }
+        }
+      }
+    }
+    concrete_functions {
+      key: "__inference_dense_1_layer_call_fn_6773"
+      value {
+        bound_inputs: 15
+        bound_inputs: 16
+        canonicalized_input_signature {
+          tuple_value {
+            values {
+              tuple_value {
+                values {
+                  tensor_spec_value {
+                    name: "inputs"
+                    shape {
+                      dim {
+                        size: -1
+                      }
+                      dim {
+                        size: 100
+                      }
+                    }
+                    dtype: DT_FLOAT
+                  }
+                }
+              }
+            }
+            values {
+              dict_value {
+              }
+            }
+          }
+        }
+        output_signature {
+          tensor_spec_value {
+            shape {
+              dim {
+                size: -1
+              }
+              dim {
+                size: 1
+              }
+            }
+            dtype: DT_FLOAT
+          }
+        }
+      }
+    }
+    concrete_functions {
+      key: "__inference_dense_layer_call_and_return_conditional_losses_6745"
+      value {
+        bound_inputs: 9
+        bound_inputs: 10
+        canonicalized_input_signature {
+          tuple_value {
+            values {
+              tuple_value {
+                values {
+                  tensor_spec_value {
+                    name: "inputs"
+                    shape {
+                      dim {
+                        size: -1
+                      }
+                      dim {
+                        size: 214
+                      }
+                    }
+                    dtype: DT_INT32
+                  }
+                }
+              }
+            }
+            values {
+              dict_value {
+              }
+            }
+          }
+        }
+        output_signature {
+          tuple_value {
+            values {
+              tensor_spec_value {
+                name: "0"
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 100
+                  }
+                }
+                dtype: DT_FLOAT
+              }
+            }
+            values {
+              list_value {
+              }
+            }
+          }
+        }
+      }
+    }
+    concrete_functions {
+      key: "__inference_dense_layer_call_fn_6754"
+      value {
+        bound_inputs: 9
+        bound_inputs: 10
+        canonicalized_input_signature {
+          tuple_value {
+            values {
+              tuple_value {
+                values {
+                  tensor_spec_value {
+                    name: "inputs"
+                    shape {
+                      dim {
+                        size: -1
+                      }
+                      dim {
+                        size: 214
+                      }
+                    }
+                    dtype: DT_INT32
+                  }
+                }
+              }
+            }
+            values {
+              dict_value {
+              }
+            }
+          }
+        }
+        output_signature {
+          tensor_spec_value {
+            shape {
+              dim {
+                size: -1
+              }
+              dim {
+                size: 100
+              }
+            }
+            dtype: DT_FLOAT
+          }
+        }
+      }
+    }
+    concrete_functions {
+      key: "__inference_sequential_layer_call_and_return_conditional_losses_6587"
+      value {
+        bound_inputs: 9
+        bound_inputs: 10
+        bound_inputs: 15
+        bound_inputs: 16
+        canonicalized_input_signature {
+          tuple_value {
+            values {
+              tuple_value {
+                values {
+                  tensor_spec_value {
+                    name: "input_1"
+                    shape {
+                      dim {
+                        size: -1
+                      }
+                      dim {
+                        size: 214
+                      }
+                    }
+                    dtype: DT_INT32
+                  }
+                }
+                values {
+                  bool_value: true
+                }
+                values {
+                  none_value {
+                  }
+                }
+              }
+            }
+            values {
+              dict_value {
+              }
+            }
+          }
+        }
+        output_signature {
+          tuple_value {
+            values {
+              tensor_spec_value {
+                name: "0"
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 1
+                  }
+                }
+                dtype: DT_FLOAT
+              }
+            }
+            values {
+              list_value {
+              }
+            }
+          }
+        }
+      }
+    }
+    concrete_functions {
+      key: "__inference_sequential_layer_call_and_return_conditional_losses_6601"
+      value {
+        bound_inputs: 9
+        bound_inputs: 10
+        bound_inputs: 15
+        bound_inputs: 16
+        canonicalized_input_signature {
+          tuple_value {
+            values {
+              tuple_value {
+                values {
+                  tensor_spec_value {
+                    name: "input_1"
+                    shape {
+                      dim {
+                        size: -1
+                      }
+                      dim {
+                        size: 214
+                      }
+                    }
+                    dtype: DT_INT32
+                  }
+                }
+                values {
+                  bool_value: false
+                }
+                values {
+                  none_value {
+                  }
+                }
+              }
+            }
+            values {
+              dict_value {
+              }
+            }
+          }
+        }
+        output_signature {
+          tuple_value {
+            values {
+              tensor_spec_value {
+                name: "0"
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 1
+                  }
+                }
+                dtype: DT_FLOAT
+              }
+            }
+            values {
+              list_value {
+              }
+            }
+          }
+        }
+      }
+    }
+    concrete_functions {
+      key: "__inference_sequential_layer_call_and_return_conditional_losses_6689"
+      value {
+        bound_inputs: 9
+        bound_inputs: 10
+        bound_inputs: 15
+        bound_inputs: 16
+        canonicalized_input_signature {
+          tuple_value {
+            values {
+              tuple_value {
+                values {
+                  tensor_spec_value {
+                    name: "inputs"
+                    shape {
+                      dim {
+                        size: -1
+                      }
+                      dim {
+                        size: 214
+                      }
+                    }
+                    dtype: DT_INT32
+                  }
+                }
+                values {
+                  bool_value: true
+                }
+                values {
+                  none_value {
+                  }
+                }
+              }
+            }
+            values {
+              dict_value {
+              }
+            }
+          }
+        }
+        output_signature {
+          tuple_value {
+            values {
+              tensor_spec_value {
+                name: "0"
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 1
+                  }
+                }
+                dtype: DT_FLOAT
+              }
+            }
+            values {
+              list_value {
+              }
+            }
+          }
+        }
+      }
+    }
+    concrete_functions {
+      key: "__inference_sequential_layer_call_and_return_conditional_losses_6707"
+      value {
+        bound_inputs: 9
+        bound_inputs: 10
+        bound_inputs: 15
+        bound_inputs: 16
+        canonicalized_input_signature {
+          tuple_value {
+            values {
+              tuple_value {
+                values {
+                  tensor_spec_value {
+                    name: "inputs"
+                    shape {
+                      dim {
+                        size: -1
+                      }
+                      dim {
+                        size: 214
+                      }
+                    }
+                    dtype: DT_INT32
+                  }
+                }
+                values {
+                  bool_value: false
+                }
+                values {
+                  none_value {
+                  }
+                }
+              }
+            }
+            values {
+              dict_value {
+              }
+            }
+          }
+        }
+        output_signature {
+          tuple_value {
+            values {
+              tensor_spec_value {
+                name: "0"
+                shape {
+                  dim {
+                    size: -1
+                  }
+                  dim {
+                    size: 1
+                  }
+                }
+                dtype: DT_FLOAT
+              }
+            }
+            values {
+              list_value {
+              }
+            }
+          }
+        }
+      }
+    }
+    concrete_functions {
+      key: "__inference_sequential_layer_call_fn_6629"
+      value {
+        bound_inputs: 9
+        bound_inputs: 10
+        bound_inputs: 15
+        bound_inputs: 16
+        canonicalized_input_signature {
+          tuple_value {
+            values {
+              tuple_value {
+                values {
+                  tensor_spec_value {
+                    name: "input_1"
+                    shape {
+                      dim {
+                        size: -1
+                      }
+                      dim {
+                        size: 214
+                      }
+                    }
+                    dtype: DT_INT32
+                  }
+                }
+                values {
+                  bool_value: true
+                }
+                values {
+                  none_value {
+                  }
+                }
+              }
+            }
+            values {
+              dict_value {
+              }
+            }
+          }
+        }
+        output_signature {
+          tensor_spec_value {
+            shape {
+              dim {
+                size: -1
+              }
+              dim {
+                size: 1
+              }
+            }
+            dtype: DT_FLOAT
+          }
+        }
+      }
+    }
+    concrete_functions {
+      key: "__inference_sequential_layer_call_fn_6656"
+      value {
+        bound_inputs: 9
+        bound_inputs: 10
+        bound_inputs: 15
+        bound_inputs: 16
+        canonicalized_input_signature {
+          tuple_value {
+            values {
+              tuple_value {
+                values {
+                  tensor_spec_value {
+                    name: "input_1"
+                    shape {
+                      dim {
+                        size: -1
+                      }
+                      dim {
+                        size: 214
+                      }
+                    }
+                    dtype: DT_INT32
+                  }
+                }
+                values {
+                  bool_value: false
+                }
+                values {
+                  none_value {
+                  }
+                }
+              }
+            }
+            values {
+              dict_value {
+              }
+            }
+          }
+        }
+        output_signature {
+          tensor_spec_value {
+            shape {
+              dim {
+                size: -1
+              }
+              dim {
+                size: 1
+              }
+            }
+            dtype: DT_FLOAT
+          }
+        }
+      }
+    }
+    concrete_functions {
+      key: "__inference_sequential_layer_call_fn_6720"
+      value {
+        bound_inputs: 9
+        bound_inputs: 10
+        bound_inputs: 15
+        bound_inputs: 16
+        canonicalized_input_signature {
+          tuple_value {
+            values {
+              tuple_value {
+                values {
+                  tensor_spec_value {
+                    name: "inputs"
+                    shape {
+                      dim {
+                        size: -1
+                      }
+                      dim {
+                        size: 214
+                      }
+                    }
+                    dtype: DT_INT32
+                  }
+                }
+                values {
+                  bool_value: true
+                }
+                values {
+                  none_value {
+                  }
+                }
+              }
+            }
+            values {
+              dict_value {
+              }
+            }
+          }
+        }
+        output_signature {
+          tensor_spec_value {
+            shape {
+              dim {
+                size: -1
+              }
+              dim {
+                size: 1
+              }
+            }
+            dtype: DT_FLOAT
+          }
+        }
+      }
+    }
+    concrete_functions {
+      key: "__inference_sequential_layer_call_fn_6733"
+      value {
+        bound_inputs: 9
+        bound_inputs: 10
+        bound_inputs: 15
+        bound_inputs: 16
+        canonicalized_input_signature {
+          tuple_value {
+            values {
+              tuple_value {
+                values {
+                  tensor_spec_value {
+                    name: "inputs"
+                    shape {
+                      dim {
+                        size: -1
+                      }
+                      dim {
+                        size: 214
+                      }
+                    }
+                    dtype: DT_INT32
+                  }
+                }
+                values {
+                  bool_value: false
+                }
+                values {
+                  none_value {
+                  }
+                }
+              }
+            }
+            values {
+              dict_value {
+              }
+            }
+          }
+        }
+        output_signature {
+          tensor_spec_value {
+            shape {
+              dim {
+                size: -1
+              }
+              dim {
+                size: 1
+              }
+            }
+            dtype: DT_FLOAT
+          }
+        }
+      }
+    }
+    concrete_functions {
+      key: "__inference_signature_wrapper_6671"
+      value {
+        bound_inputs: 9
+        bound_inputs: 10
+        bound_inputs: 15
+        bound_inputs: 16
+        canonicalized_input_signature {
+          tuple_value {
+            values {
+              tuple_value {
+              }
+            }
+            values {
+              dict_value {
+                fields {
+                  key: "input_1"
+                  value {
+                    tensor_spec_value {
+                      name: "input_1"
+                      shape {
+                        dim {
+                          size: -1
+                        }
+                        dim {
+                          size: 214
+                        }
+                      }
+                      dtype: DT_INT32
+                    }
+                  }
+                }
+              }
+            }
+          }
+        }
+        output_signature {
+          dict_value {
+            fields {
+              key: "output_1"
+              value {
+                tensor_spec_value {
+                  name: "output_1"
+                  shape {
+                    dim {
+                      size: -1
+                    }
+                    dim {
+                      size: 1
+                    }
+                  }
+                  dtype: DT_FLOAT
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+}
+

diff  --git a/llvm/unittests/Analysis/Inputs/ir2native_x86_64_model/variables/variables.data-00000-of-00001 b/llvm/unittests/Analysis/Inputs/ir2native_x86_64_model/variables/variables.data-00000-of-00001
new file mode 100644
index 000000000000..98807d26ee9f
Binary files /dev/null and b/llvm/unittests/Analysis/Inputs/ir2native_x86_64_model/variables/variables.data-00000-of-00001 differ

diff  --git a/llvm/unittests/Analysis/Inputs/ir2native_x86_64_model/variables/variables.index b/llvm/unittests/Analysis/Inputs/ir2native_x86_64_model/variables/variables.index
new file mode 100644
index 000000000000..c20d8afabf38
Binary files /dev/null and b/llvm/unittests/Analysis/Inputs/ir2native_x86_64_model/variables/variables.index differ

diff  --git a/llvm/unittests/Analysis/TFUtilsTest.cpp b/llvm/unittests/Analysis/TFUtilsTest.cpp
new file mode 100644
index 000000000000..4c775c4c0b93
--- /dev/null
+++ b/llvm/unittests/Analysis/TFUtilsTest.cpp
@@ -0,0 +1,98 @@
+//===- TFUtilsTest.cpp - test for TFUtils ---------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/Utils/TFUtils.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/Path.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Testing/Support/SupportHelpers.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+extern const char *TestMainArgv0;
+
+/// Returns "<test input file directory>/ir2native_x86_64_model", the
+/// SavedModel directory used by the tests below. The input directory is
+/// derived from the test binary's argv[0] (TestMainArgv0).
+static std::string getModelPath() {
+  SmallString<128> ModelDir = unittest::getInputFileDirectory(TestMainArgv0);
+  llvm::sys::path::append(ModelDir, "ir2native_x86_64_model");
+  return ModelDir.str().str();
+}
+
+// An evaluator constructed with an empty saved-model path and no input or
+// output names must report itself as invalid.
+TEST(TFUtilsTest, NoModel) {
+  TFModelEvaluator EmptyEvaluator("", {}, {});
+  EXPECT_FALSE(EmptyEvaluator.isValid());
+}
+
+// Test we can correctly load a savedmodel and evaluate it.
+TEST(TFUtilsTest, LoadAndExecuteTest) {
+  // We use the ir2native model for test. We know it has one feature of
+  // dimension (1, 214)
+  std::vector<std::string> InputNames{"serving_default_input_1"};
+  std::vector<std::string> OutputName{"StatefulPartitionedCall"};
+  constexpr int64_t KnownSize = 214;
+
+  TFModelEvaluator Evaluator(getModelPath(), InputNames, OutputName);
+  const std::vector<int64_t> Dim{1, KnownSize};
+
+  // ASSERT, not EXPECT: the code below initializes and writes through the
+  // evaluator's input tensor, which only makes sense for a loaded model.
+  ASSERT_TRUE(Evaluator.isValid());
+  Evaluator.initInput(0, TF_INT32, Dim);
+
+  int32_t *V = static_cast<int32_t *>(TF_TensorData(Evaluator.getInput()[0]));
+  // Fill it up with 1's, we know the output.
+  for (int64_t I = 0; I < KnownSize; ++I) {
+    V[I] = 1;
+  }
+  {
+    auto ER = Evaluator.evaluate();
+    // ASSERT, not EXPECT: on failure ER is empty and the dereference below
+    // would be undefined behavior instead of a clean test failure.
+    ASSERT_TRUE(ER.hasValue());
+    float Ret = *ER->getTensorValue<float>(0);
+    EXPECT_EQ(static_cast<size_t>(Ret), 80U);
+  }
+  // The input vector should be unchanged
+  for (int64_t I = 0; I < KnownSize; ++I) {
+    EXPECT_EQ(V[I], 1);
+  }
+  // Zero-out the unused position '0' of the instruction histogram, which is
+  // after the first 9 calculated values. The result should be the same.
+  V[9] = 0;
+  {
+    auto ER = Evaluator.evaluate();
+    ASSERT_TRUE(ER.hasValue());
+    float Ret = *ER->getTensorValue<float>(0);
+    EXPECT_EQ(static_cast<size_t>(Ret), 80U);
+  }
+}
+
+// Test incorrect input setup: the input tensor is initialized with the wrong
+// shape, so evaluation must fail and the evaluator must become invalid.
+TEST(TFUtilsTest, EvalError) {
+  // We use the ir2native model for test. It expects one feature of
+  // dimension (1, 214); we deliberately provide (1, 213) instead.
+  std::vector<std::string> InputNames{"serving_default_input_1"};
+  std::vector<std::string> OutputName{"StatefulPartitionedCall"};
+  constexpr int64_t WrongSize = 213;
+
+  TFModelEvaluator Evaluator(getModelPath(), InputNames, OutputName);
+  const std::vector<int64_t> Dim{1, WrongSize};
+
+  // Loading the model itself must succeed; only the input shape is wrong.
+  // ASSERT, not EXPECT: the writes below go through the input tensor.
+  ASSERT_TRUE(Evaluator.isValid());
+  Evaluator.initInput(0, TF_INT32, Dim);
+
+  int32_t *V = static_cast<int32_t *>(TF_TensorData(Evaluator.getInput()[0]));
+  // Fill the mis-shaped tensor with 1's; the values are irrelevant because
+  // evaluation is expected to fail on the shape mismatch.
+  for (int64_t I = 0; I < WrongSize; ++I) {
+    V[I] = 1;
+  }
+  auto ER = Evaluator.evaluate();
+  EXPECT_FALSE(ER.hasValue());
+  // A failed evaluation is expected to leave the evaluator invalid.
+  EXPECT_FALSE(Evaluator.isValid());
+}


        


More information about the llvm-commits mailing list