[llvm] [mlgo] Support composite AOT-ed models (PR #96276)

Mircea Trofin via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 24 09:23:01 PDT 2024


================
@@ -17,37 +17,91 @@
 #include "llvm/Analysis/MLModelRunner.h"
 #include "llvm/Analysis/TensorSpec.h"
 #include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/MD5.h"
 
 #include <memory>
-#include <vector>
 
 namespace llvm {
 
 /// ReleaseModeModelRunner - production mode implementation of the
 /// MLModelRunner. It uses an AOT-compiled SavedModel for efficient execution.
+struct EmbeddedModelRunnerOptions {
+  /// Feed and Fetch feature prefixes - i.e. a feature named "foo" will be
+  /// looked up as {FeedPrefix}foo, and the output named "bar" will be looked
+  /// up as {FetchPrefix}bar.
+  StringRef FeedPrefix = "feed_";
+  StringRef FetchPrefix = "fetch_";
+
+  /// ModelSelector is the name (recognized by the AOT-ed model) of a sub-model
+  /// to use. "" is allowed if the model doesn't support sub-models.
+  StringRef ModelSelector = "";
+
+  EmbeddedModelRunnerOptions &setFeedPrefix(StringRef Value) {
+    FeedPrefix = Value;
+    return *this;
+  }
+  EmbeddedModelRunnerOptions &setFetchPrefix(StringRef Value) {
+    FetchPrefix = Value;
+    return *this;
+  }
+  EmbeddedModelRunnerOptions &setModelSelector(StringRef Value) {
+    ModelSelector = Value;
+    return *this;
+  }
+};
+
 template <class TGen>
 class ReleaseModeModelRunner final : public MLModelRunner {
 public:
   /// FeatureNames' type should be an indexed collection of std::string, like
   /// std::array or std::vector, that has a size() method.
   template <class FType>
   ReleaseModeModelRunner(LLVMContext &Ctx, const FType &InputSpec,
-                         StringRef DecisionName, StringRef FeedPrefix = "feed_",
-                         StringRef FetchPrefix = "fetch_")
-      : MLModelRunner(Ctx, MLModelRunner::Kind::Release, InputSpec.size()),
+                         StringRef DecisionName,
+                         const EmbeddedModelRunnerOptions &Options = {})
+      : MLModelRunner(Ctx, MLModelRunner::Kind::Release, InputSpec.size() + 1),
         CompiledModel(std::make_unique<TGen>()) {
     assert(CompiledModel && "The CompiledModel should be valid");
-
-    for (size_t I = 0; I < InputSpec.size(); ++I) {
-      const int Index =
-          CompiledModel->LookupArgIndex(FeedPrefix.str() + InputSpec[I].name());
-      void *Buffer = nullptr;
-      if (Index >= 0)
-        Buffer = CompiledModel->arg_data(Index);
-      setUpBufferForTensor(I, InputSpec[I], Buffer);
+    // Set up the _model_selector tensor past all the InputSpecs, in all
+    // cases:
+    //   - if the model doesn't have such a feature but the user requested
+    //   one, we report an error; likewise if the model supports it but the
+    //   user didn't specify one
+    //   - finally, we compute the MD5 hash of the user input and set the
+    //   value of the model selector to {high, low}
+    bool InputIsPresent = true;
+    populateTensor(InputSpec.size(),
+                   TensorSpec::createSpec<uint64_t>("_model_selector", {2}),
+                   Options.FeedPrefix, InputIsPresent);
+
+    // If we hit the "report an error" cases outlined above, continue with the
+    // setup in case there's some custom diagnostics handler installed and it
+    // doesn't promptly exit.
+    if (Options.ModelSelector.empty() && InputIsPresent)
+      Ctx.emitError(
+          "A model selector was not specified but the underlying model "
+          "requires selecting one because it exposes a _model_selector input");
+    uint64_t High = 0;
+    uint64_t Low = 0;
+    if (!Options.ModelSelector.empty()) {
+      if (!InputIsPresent)
+        Ctx.emitError("A model selector was specified but the underlying model "
+                      "does not expose a _model_selector input");
+      const auto Hash = MD5::hash(
+          {reinterpret_cast<const uint8_t *>(Options.ModelSelector.data()),
+           Options.ModelSelector.size()});
+
+      High = Hash.high();
+      Low = Hash.low();
     }
-
-    ResultIndex = CompiledModel->LookupResultIndex(FetchPrefix.str() +
+    getTensor<uint64_t>(InputSpec.size())[0] = High;
----------------
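For context, here is a minimal usage sketch of the new options API. It is illustrative only: TGen stands in for an AOT-generated model class (the patch's CompiledModel type), and "some_feature", "some_decision", and "my-submodel" are invented names, not part of the patch.

#include "llvm/Analysis/MLModelRunner.h"
#include "llvm/Analysis/ReleaseModeModelRunner.h"
#include "llvm/Analysis/TensorSpec.h"
#include "llvm/IR/LLVMContext.h"
#include <memory>
#include <vector>

using namespace llvm;

// TGen is supplied by the client: the class emitted by AOT-compiling the
// SavedModel.
template <class TGen>
std::unique_ptr<MLModelRunner> makeRunner(LLVMContext &Ctx) {
  // One illustrative scalar feature; real clients pass their full feature set.
  std::vector<TensorSpec> Inputs{
      TensorSpec::createSpec<int64_t>("some_feature", {1})};
  // "my-submodel" is MD5-hashed and written into the extra _model_selector
  // tensor that the constructor sets up past the user-provided InputSpecs.
  return std::make_unique<ReleaseModeModelRunner<TGen>>(
      Ctx, Inputs, "some_decision",
      EmbeddedModelRunnerOptions().setModelSelector("my-submodel"));
}

The chained builder-style setters keep call sites readable when only one of the defaults needs overriding.
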
mtrofin wrote:

But why raise the risk of collisions when the cost right now is basically negligible?

https://github.com/llvm/llvm-project/pull/96276
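
To make the collision point concrete: the patch maps the selector string to the two 64-bit halves of its MD5 digest, so two sub-model names clash only if their full 128-bit digests coincide. A sketch of that step follows; the helper name is ours, not part of the patch.

#include "llvm/ADT/StringRef.h"
#include "llvm/Support/MD5.h"
#include <cstdint>
#include <utility>

// Mirrors the constructor logic quoted above: hash the selector and return
// the {high, low} words destined for the _model_selector tensor.
std::pair<uint64_t, uint64_t> selectorWords(llvm::StringRef Selector) {
  const llvm::MD5::MD5Result Hash = llvm::MD5::hash(
      {reinterpret_cast<const uint8_t *>(Selector.data()), Selector.size()});
  return {Hash.high(), Hash.low()};
}

With a full 128-bit digest, accidental collisions among a handful of sub-model names are vanishingly unlikely, which is the cost/benefit being weighed in the comment above.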

