[llvm] ec83c7e - [MLGO] Make TFLiteUtils throw an error if some features haven't been passed to the model

Aiden Grossman via llvm-commits llvm-commits at lists.llvm.org
Sat Sep 10 15:59:23 PDT 2022


Author: Aiden Grossman
Date: 2022-09-10T22:59:03Z
New Revision: ec83c7e358ecd7db9af2d980b6d528f5ea6865a4

URL: https://github.com/llvm/llvm-project/commit/ec83c7e358ecd7db9af2d980b6d528f5ea6865a4
DIFF: https://github.com/llvm/llvm-project/commit/ec83c7e358ecd7db9af2d980b6d528f5ea6865a4.diff

LOG: [MLGO] Make TFLiteUtils throw an error if some features haven't been passed to the model

In the TensorFlow C library utilities, an error is thrown if some features
haven't been passed to the model (due to differences in ordering, which
no longer exist after the transition to TFLite). However, this is not
currently the case when using TFLiteUtils. This patch makes some minor
changes to throw an error when not all of the model's inputs have been
passed; if left unhandled, the missing inputs result in a segfault inside
TFLite.

Reviewed By: mtrofin

Differential Revision: https://reviews.llvm.org/D133451
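
Not part of the patch, but for illustration, here is a minimal sketch of the
client-facing effect. The header paths are those used by the TFLite utilities
around this revision, the model path and the "some_feature" input name are
hypothetical, and TFModelEvaluator, TensorSpec::createSpec, and isValid()
come from the existing API exercised in the unit test below:

  #include "llvm/Analysis/TensorSpec.h"
  #include "llvm/Analysis/Utils/TFUtils.h"
  #include <vector>

  using namespace llvm;

  bool modelInputsFullyDescribed() {
    // Deliberately describe fewer inputs than the saved model declares.
    std::vector<TensorSpec> InputSpecs{
        TensorSpec::createSpec<int64_t>("some_feature", {1})};
    std::vector<TensorSpec> OutputSpecs{
        TensorSpec::createSpec<float>("StatefulPartitionedCall", {1})};

    TFModelEvaluator Evaluator("path/to/saved_model", InputSpecs, OutputSpecs);
    // With this patch the evaluator is invalidated at construction time
    // instead of segfaulting inside TFLite on a later evaluate() call.
    return Evaluator.isValid();
  }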

Added: 
    

Modified: 
    llvm/lib/Analysis/TFLiteUtils.cpp
    llvm/unittests/Analysis/TFUtilsTest.cpp

Removed: 
    


################################################################################
diff --git a/llvm/lib/Analysis/TFLiteUtils.cpp b/llvm/lib/Analysis/TFLiteUtils.cpp
index 9c43193476f0c..41c9847ad64af 100644
--- a/llvm/lib/Analysis/TFLiteUtils.cpp
+++ b/llvm/lib/Analysis/TFLiteUtils.cpp
@@ -134,6 +134,7 @@ TFModelEvaluatorImpl::TFModelEvaluatorImpl(
   for (size_t I = 0; I < Interpreter->outputs().size(); ++I)
     OutputsMap[Interpreter->GetOutputName(I)] = I;
 
+  size_t NumberFeaturesPassed = 0;
   for (size_t I = 0; I < InputSpecs.size(); ++I) {
     auto &InputSpec = InputSpecs[I];
     auto MapI = InputsMap.find(InputSpec.name() + ":" +
@@ -147,6 +148,14 @@ TFModelEvaluatorImpl::TFModelEvaluatorImpl(
       return;
     std::memset(Input[I]->data.data, 0,
                 InputSpecs[I].getTotalTensorBufferSize());
+    ++NumberFeaturesPassed;
+  }
+
+  if (NumberFeaturesPassed < Interpreter->inputs().size()) {
+    // we haven't passed all the required features to the model, throw an error.
+    errs() << "Required feature(s) have not been passed to the ML model";
+    invalidate();
+    return;
   }
 
   for (size_t I = 0; I < OutputSpecsSize; ++I) {

diff --git a/llvm/unittests/Analysis/TFUtilsTest.cpp b/llvm/unittests/Analysis/TFUtilsTest.cpp
index fe3b115822bee..c604afd86d904 100644
--- a/llvm/unittests/Analysis/TFUtilsTest.cpp
+++ b/llvm/unittests/Analysis/TFUtilsTest.cpp
@@ -121,3 +121,12 @@ TEST(TFUtilsTest, UnsupportedFeature) {
   for (auto I = 0; I < 2 * 5; ++I)
     EXPECT_FLOAT_EQ(F[I], 3.14 + I);
 }
+
+TEST(TFUtilsTest, MissingFeature) {
+  std::vector<TensorSpec> InputSpecs{};
+  std::vector<TensorSpec> OutputSpecs{
+      TensorSpec::createSpec<float>("StatefulPartitionedCall", {1})};
+
+  TFModelEvaluator Evaluator(getModelPath(), InputSpecs, OutputSpecs);
+  EXPECT_FALSE(Evaluator.isValid());
+}


More information about the llvm-commits mailing list