[PATCH] D133451: [MLGO] Make TFLiteUtils throw an error if some features haven't been passed to the model

Aiden Grossman via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Sat Sep 10 15:59:30 PDT 2022


This revision was automatically updated to reflect the committed changes.
Closed by commit rGec83c7e358ec: [MLGO] Make TFLiteUtils throw an error if some features haven't been passed to… (authored by aidengrossman).

Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D133451/new/

https://reviews.llvm.org/D133451

Files:
  llvm/lib/Analysis/TFLiteUtils.cpp
  llvm/unittests/Analysis/TFUtilsTest.cpp


Index: llvm/unittests/Analysis/TFUtilsTest.cpp
===================================================================
--- llvm/unittests/Analysis/TFUtilsTest.cpp
+++ llvm/unittests/Analysis/TFUtilsTest.cpp
@@ -121,3 +121,12 @@
   for (auto I = 0; I < 2 * 5; ++I)
     EXPECT_FLOAT_EQ(F[I], 3.14 + I);
 }
+
+TEST(TFUtilsTest, MissingFeature) {
+  std::vector<TensorSpec> InputSpecs{};
+  std::vector<TensorSpec> OutputSpecs{
+      TensorSpec::createSpec<float>("StatefulPartitionedCall", {1})};
+
+  TFModelEvaluator Evaluator(getModelPath(), InputSpecs, OutputSpecs);
+  EXPECT_FALSE(Evaluator.isValid());
+}
Index: llvm/lib/Analysis/TFLiteUtils.cpp
===================================================================
--- llvm/lib/Analysis/TFLiteUtils.cpp
+++ llvm/lib/Analysis/TFLiteUtils.cpp
@@ -134,6 +134,7 @@
   for (size_t I = 0; I < Interpreter->outputs().size(); ++I)
     OutputsMap[Interpreter->GetOutputName(I)] = I;
 
+  size_t NumberFeaturesPassed = 0;
   for (size_t I = 0; I < InputSpecs.size(); ++I) {
     auto &InputSpec = InputSpecs[I];
     auto MapI = InputsMap.find(InputSpec.name() + ":" +
@@ -147,6 +148,14 @@
       return;
     std::memset(Input[I]->data.data, 0,
                 InputSpecs[I].getTotalTensorBufferSize());
+    ++NumberFeaturesPassed;
+  }
+
+  if (NumberFeaturesPassed < Interpreter->inputs().size()) {
+    // Not all required features were passed to the model: report and invalidate.
+    errs() << "Required feature(s) have not been passed to the ML model";
+    invalidate();
+    return;
   }
 
   for (size_t I = 0; I < OutputSpecsSize; ++I) {
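
For readers who want to see the counting check in isolation, here is a minimal standalone sketch of the idea in the TFLiteUtils.cpp hunk. It does not use LLVM or TFLite. The names ModelInputs and allFeaturesPassed are hypothetical and exist only for this illustration. It assumes features are matched by name and that the caller does not pass the same feature twice.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <string>
#include <vector>

// Hypothetical stand-in for the model side: the set of input tensor names the
// model expects. In the patch this role is played by Interpreter->inputs().
using ModelInputs = std::set<std::string>;

// Returns true only if every provided name is known to the model and the
// number of matched features covers all of the model's inputs.
// Assumes Provided contains no duplicate names.
bool allFeaturesPassed(const ModelInputs &Expected,
                       const std::vector<std::string> &Provided) {
  std::size_t NumberFeaturesPassed = 0;
  for (const std::string &Name : Provided) {
    if (Expected.count(Name) == 0)
      return false; // a feature the model does not know about
    ++NumberFeaturesPassed;
  }
  // Same comparison as the new check in the patch: fewer matched features
  // than model inputs means a required feature is missing.
  return NumberFeaturesPassed >= Expected.size();
}
```

With an empty list of provided features and a model that expects at least one input, the function returns false, which corresponds to the situation the new MissingFeature unit test exercises.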



