[llvm] [mlgo] Support composite AOT-ed models (PR #96276)
Mircea Trofin via llvm-commits
llvm-commits at lists.llvm.org
Mon Jun 24 09:24:15 PDT 2024
================
@@ -17,37 +17,91 @@
#include "llvm/Analysis/MLModelRunner.h"
#include "llvm/Analysis/TensorSpec.h"
#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/MD5.h"
#include <memory>
-#include <vector>
namespace llvm {
/// ReleaseModeModelRunner - production mode implementation of the
/// MLModelRunner. It uses an AOT-compiled SavedModel for efficient execution.
+struct EmbeddedModelRunnerOptions {
+ /// Feed and Fetch feature prefixes - i.e. a feature named "foo" will be
+ /// looked up as {FeedPrefix}foo (e.g. "feed_foo"), and the output named
+ /// "bar" will be looked up as {FetchPrefix}bar (e.g. "fetch_bar").
+ StringRef FeedPrefix = "feed_";
+ StringRef FetchPrefix = "fetch_";
+
+ /// ModelSelector is the name (recognized by the AOT-ed model) of a sub-model
+ /// to use. Leave it "" if and only if the model doesn't support sub-models.
+ StringRef ModelSelector = "";
+
+ EmbeddedModelRunnerOptions &setFeedPrefix(StringRef Value) {
+ FeedPrefix = Value;
+ return *this;
+ }
+ EmbeddedModelRunnerOptions &setFetchPrefix(StringRef Value) {
+ FetchPrefix = Value;
+ return *this;
+ }
+ EmbeddedModelRunnerOptions &setModelSelector(StringRef Value) {
+ ModelSelector = Value;
+ return *this;
+ }
+};
+
template <class TGen>
class ReleaseModeModelRunner final : public MLModelRunner {
public:
/// FeatureNames' type should be an indexed collection of std::string, like
/// std::array or std::vector, that has a size() method.
template <class FType>
ReleaseModeModelRunner(LLVMContext &Ctx, const FType &InputSpec,
- StringRef DecisionName, StringRef FeedPrefix = "feed_",
- StringRef FetchPrefix = "fetch_")
- : MLModelRunner(Ctx, MLModelRunner::Kind::Release, InputSpec.size()),
+ StringRef DecisionName,
+ const EmbeddedModelRunnerOptions &Options = {})
+ : MLModelRunner(Ctx, MLModelRunner::Kind::Release, InputSpec.size() + 1),
CompiledModel(std::make_unique<TGen>()) {
assert(CompiledModel && "The CompiledModel should be valid");
-
- for (size_t I = 0; I < InputSpec.size(); ++I) {
- const int Index =
- CompiledModel->LookupArgIndex(FeedPrefix.str() + InputSpec[I].name());
- void *Buffer = nullptr;
- if (Index >= 0)
- Buffer = CompiledModel->arg_data(Index);
- setUpBufferForTensor(I, InputSpec[I], Buffer);
+ // Set up the model_selector past all the InputSpecs in all cases.
+ // - if the model doesn't have such a feature, but the user requested it,
+ // we report error. Same if the model supports it but the user didn't
+ // specify it
+ // - finally, we compute the MD5 hash of the user input and set the value
+ // of the model selector to {high, low}
+ bool InputIsPresent = true;
+ populateTensor(InputSpec.size(),
+ TensorSpec::createSpec<uint64_t>("_model_selector", {2}),
+ Options.FeedPrefix, InputIsPresent);
+
+ // If we hit the "report an error" cases outlined above, continue with the
+ // set up in case there's some custom diagnostics handler installed and it
+ // doesn't promptly exit.
+ if (Options.ModelSelector.empty() && InputIsPresent)
+ Ctx.emitError(
+ "A model selector was not specified but the underlying model "
+ "requires selecting one because it exposes a _model_selector input");
+ uint64_t High = 0;
+ uint64_t Low = 0;
+ if (!Options.ModelSelector.empty()) {
+ if (!InputIsPresent)
+ Ctx.emitError("A model selector was specified but the underlying model "
+ "does not expose a _model_selector input");
+ const auto Hash = MD5::hash(
+ {reinterpret_cast<const uint8_t *>(Options.ModelSelector.data()),
+ Options.ModelSelector.size()});
+
+ High = Hash.high();
+ Low = Hash.low();
}
-
- ResultIndex = CompiledModel->LookupResultIndex(FetchPrefix.str() +
+ getTensor<uint64_t>(InputSpec.size())[0] = High;
+ getTensor<uint64_t>(InputSpec.size())[1] = Low;
+ // At this point, the model selector is set up. If the user didn't provide
+ // one, but the model has a _model_selector, it'll be set to (0, 0) which
----------------
mtrofin wrote:
Right. I added more explanatory comments.
https://github.com/llvm/llvm-project/pull/96276
More information about the llvm-commits
mailing list