[llvm] [mlgo] Support composite AOT-ed models (PR #96276)

Mircea Trofin via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 24 09:20:08 PDT 2024


================
@@ -17,37 +17,91 @@
 #include "llvm/Analysis/MLModelRunner.h"
 #include "llvm/Analysis/TensorSpec.h"
 #include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/MD5.h"
 
 #include <memory>
-#include <vector>
 
 namespace llvm {
 
 /// ReleaseModeModelRunner - production mode implementation of the
 /// MLModelRunner. It uses an AOT-compiled SavedModel for efficient execution.
+struct EmbeddedModelRunnerOptions { // Option bundle accepted by ReleaseModeModelRunner's constructor. NOTE(review): the ReleaseModeModelRunner doc comment directly above now attaches to this struct.
+  /// Feed and Fetch feature prefixes, prepended to tensor names on lookup -
+  /// i.e. with the defaults, a feature named "foo" is presumably looked up as
+  /// "feed_foo" and the output named "bar" as "fetch_bar" (TODO confirm).
+  StringRef FeedPrefix = "feed_";
+  StringRef FetchPrefix = "fetch_";
+
+  /// ModelSelector is the name (recognized by the AOT-ed model) of a sub-model
+  /// to use. "" is allowed if the model doesn't support sub-models.
+  StringRef ModelSelector = "";
+
+  EmbeddedModelRunnerOptions &setFeedPrefix(StringRef Value) { // Chainable setter: returns *this.
+    FeedPrefix = Value;
+    return *this;
+  }
+  EmbeddedModelRunnerOptions &setFetchPrefix(StringRef Value) { // Chainable setter: returns *this.
+    FetchPrefix = Value;
+    return *this;
+  }
+  EmbeddedModelRunnerOptions &setModelSelector(StringRef Value) { // Chainable setter: returns *this.
+    ModelSelector = Value;
+    return *this;
+  }
+};
+
 template <class TGen>
 class ReleaseModeModelRunner final : public MLModelRunner {
 public:
   /// FeatureNames' type should be an indexed collection of std::string, like
   /// std::array or std::vector, that has a size() method.
   template <class FType>
   ReleaseModeModelRunner(LLVMContext &Ctx, const FType &InputSpec,
-                         StringRef DecisionName, StringRef FeedPrefix = "feed_",
-                         StringRef FetchPrefix = "fetch_")
-      : MLModelRunner(Ctx, MLModelRunner::Kind::Release, InputSpec.size()),
+                         StringRef DecisionName,
+                         const EmbeddedModelRunnerOptions &Options = {})
+      : MLModelRunner(Ctx, MLModelRunner::Kind::Release, InputSpec.size() + 1),
         CompiledModel(std::make_unique<TGen>()) {
     assert(CompiledModel && "The CompiledModel should be valid");
-
-    for (size_t I = 0; I < InputSpec.size(); ++I) {
-      const int Index =
-          CompiledModel->LookupArgIndex(FeedPrefix.str() + InputSpec[I].name());
-      void *Buffer = nullptr;
-      if (Index >= 0)
-        Buffer = CompiledModel->arg_data(Index);
-      setUpBufferForTensor(I, InputSpec[I], Buffer);
+    // Set up the model_selector past all the InputSpecs in all cases.
+    //   - if the model doesn't have such a feature, but the user requested it,
+    //   we report error. Same if the model supports it but the user didn't
+    //   specify it
+    //   - finally, we compute the MD5 hash of the user input and set the value
----------------
mtrofin wrote:

Yes, plus it's deterministic.

https://github.com/llvm/llvm-project/pull/96276


More information about the llvm-commits mailing list