[llvm] ec83c7e - [MLGO] Make TFLiteUtils throw an error if some features haven't been passed to the model
Aiden Grossman via llvm-commits
llvm-commits at lists.llvm.org
Sat Sep 10 15:59:23 PDT 2022
Author: Aiden Grossman
Date: 2022-09-10T22:59:03Z
New Revision: ec83c7e358ecd7db9af2d980b6d528f5ea6865a4
URL: https://github.com/llvm/llvm-project/commit/ec83c7e358ecd7db9af2d980b6d528f5ea6865a4
DIFF: https://github.com/llvm/llvm-project/commit/ec83c7e358ecd7db9af2d980b6d528f5ea6865a4.diff
LOG: [MLGO] Make TFLiteUtils throw an error if some features haven't been passed to the model
In the TensorFlow C library utilities, an error is raised if some features
have not been passed to the model (which could happen because of ordering
differences that no longer exist after the transition to TFLite). TFLiteUtils
currently performs no such check. This patch makes minor changes so that an
error is reported, and the evaluator invalidated, when not all of the model's
inputs have been passed. Left unhandled, that case results in a segfault
inside TFLite.
Reviewed By: mtrofin
Differential Revision: https://reviews.llvm.org/D133451
Added:
Modified:
llvm/lib/Analysis/TFLiteUtils.cpp
llvm/unittests/Analysis/TFUtilsTest.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Analysis/TFLiteUtils.cpp b/llvm/lib/Analysis/TFLiteUtils.cpp
index 9c43193476f0c..41c9847ad64af 100644
--- a/llvm/lib/Analysis/TFLiteUtils.cpp
+++ b/llvm/lib/Analysis/TFLiteUtils.cpp
@@ -134,6 +134,7 @@ TFModelEvaluatorImpl::TFModelEvaluatorImpl(
for (size_t I = 0; I < Interpreter->outputs().size(); ++I)
OutputsMap[Interpreter->GetOutputName(I)] = I;
+ size_t NumberFeaturesPassed = 0;
for (size_t I = 0; I < InputSpecs.size(); ++I) {
auto &InputSpec = InputSpecs[I];
auto MapI = InputsMap.find(InputSpec.name() + ":" +
@@ -147,6 +148,14 @@ TFModelEvaluatorImpl::TFModelEvaluatorImpl(
return;
std::memset(Input[I]->data.data, 0,
InputSpecs[I].getTotalTensorBufferSize());
+ ++NumberFeaturesPassed;
+ }
+
+ if (NumberFeaturesPassed < Interpreter->inputs().size()) {
+ // we haven't passed all the required features to the model, throw an error.
+ errs() << "Required feature(s) have not been passed to the ML model";
+ invalidate();
+ return;
}
for (size_t I = 0; I < OutputSpecsSize; ++I) {
diff --git a/llvm/unittests/Analysis/TFUtilsTest.cpp b/llvm/unittests/Analysis/TFUtilsTest.cpp
index fe3b115822bee..c604afd86d904 100644
--- a/llvm/unittests/Analysis/TFUtilsTest.cpp
+++ b/llvm/unittests/Analysis/TFUtilsTest.cpp
@@ -121,3 +121,12 @@ TEST(TFUtilsTest, UnsupportedFeature) {
for (auto I = 0; I < 2 * 5; ++I)
EXPECT_FLOAT_EQ(F[I], 3.14 + I);
}
+
+TEST(TFUtilsTest, MissingFeature) {
+ std::vector<TensorSpec> InputSpecs{};
+ std::vector<TensorSpec> OutputSpecs{
+ TensorSpec::createSpec<float>("StatefulPartitionedCall", {1})};
+
+ TFModelEvaluator Evaluator(getModelPath(), InputSpecs, OutputSpecs);
+ EXPECT_FALSE(Evaluator.isValid());
+}
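The check added to the TFModelEvaluatorImpl constructor above can be sketched in isolation. This is a minimal standalone illustration, not the LLVM code. Several names are hypothetical simplifications: InputIndexMap stands in for InputsMap and Interpreter->inputs(), FeatureSpec for TensorSpec, and SketchEvaluator for TFModelEvaluator. The real constructor's type and shape checks and its zeroing of input buffers are omitted. As the existing UnsupportedFeature test suggests, specs the model does not recognize are skipped. The evaluator is invalidated when fewer specs matched than the model has inputs, which is the situation the MissingFeature test exercises.

```cpp
#include <cassert>
#include <cstddef>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for the model's inputs: maps "name:port" to the
// interpreter input index, in the spirit of InputsMap in TFLiteUtils.cpp.
using InputIndexMap = std::unordered_map<std::string, std::size_t>;

// Hypothetical, simplified feature spec: only the name and port matter here.
struct FeatureSpec {
  std::string Name;
  int Port;
};

// Hypothetical evaluator showing the pattern: validate in the constructor,
// mark the object invalid on failure, and let callers query isValid().
class SketchEvaluator {
public:
  SketchEvaluator(const InputIndexMap &ModelInputs,
                  const std::vector<FeatureSpec> &Specs) {
    std::size_t NumberFeaturesPassed = 0;
    for (const FeatureSpec &Spec : Specs) {
      // Specs the model does not know about are skipped, not treated as errors.
      if (ModelInputs.find(Spec.Name + ":" + std::to_string(Spec.Port)) ==
          ModelInputs.end())
        continue;
      ++NumberFeaturesPassed;
    }
    // Fewer matched specs than model inputs means some input would be left
    // unset, which is the case the patch turns into an error.
    if (NumberFeaturesPassed < ModelInputs.size()) {
      std::cerr << "Required feature(s) have not been passed to the ML model\n";
      Valid = false;
    }
  }

  bool isValid() const { return Valid; }

private:
  bool Valid = true;
};
```

Like the patch, the sketch compares counts rather than tracking which inputs were matched. Passing the same spec twice would therefore be counted twice. Recording the matched input indices (for example in a std::vector<bool> sized to the number of model inputs) would be a stricter variant of the same check.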