[llvm] [mlgo] Support composite AOT-ed models (PR #96276)
Aiden Grossman via llvm-commits
llvm-commits at lists.llvm.org
Fri Jun 21 11:17:33 PDT 2024
================
@@ -17,37 +17,91 @@
#include "llvm/Analysis/MLModelRunner.h"
#include "llvm/Analysis/TensorSpec.h"
#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/MD5.h"
#include <memory>
-#include <vector>
namespace llvm {
/// EmbeddedModelRunnerOptions - construction options for ReleaseModeModelRunner
/// (declared below), the production mode implementation of the MLModelRunner,
/// which uses an AOT-compiled SavedModel for efficient execution.
+struct EmbeddedModelRunnerOptions {
+ // All fields are StringRefs, which do not own their characters: the strings
+ // they refer to must stay alive for as long as these options are in use.
+
+ /// Feed and Fetch feature prefixes. The prefix is prepended to the tensor
+ /// name as-is, with no separator inserted: a feature named "foo" is looked
+ /// up as {FeedPrefix}foo ("feed_foo" with the default), and the output named
+ /// "bar" is looked up as {FetchPrefix}bar ("fetch_bar" with the default).
+ StringRef FeedPrefix = "feed_";
+ StringRef FetchPrefix = "fetch_";
+
+ /// ModelSelector is the name (recognized by the AOT-ed model) of a sub-model
+ /// to use. "" (the default) is allowed only if the model doesn't support
+ /// sub-models, i.e. doesn't expose a _model_selector input. The runner
+ /// reports an error through the LLVMContext if this is empty but the model
+ /// has that input, or non-empty but the model lacks it. A non-empty name is
+ /// MD5-hashed, and the digest is passed to the model as two uint64_t values,
+ /// {high, low}.
+ StringRef ModelSelector = "";
+
+ /// Fluent setters: each assigns one field and returns *this so that calls
+ /// can be chained.
+ EmbeddedModelRunnerOptions &setFeedPrefix(StringRef Value) {
+ FeedPrefix = Value;
+ return *this;
+ }
+ EmbeddedModelRunnerOptions &setFetchPrefix(StringRef Value) {
+ FetchPrefix = Value;
+ return *this;
+ }
+ EmbeddedModelRunnerOptions &setModelSelector(StringRef Value) {
+ ModelSelector = Value;
+ return *this;
+ }
+};
+
template <class TGen>
class ReleaseModeModelRunner final : public MLModelRunner {
public:
/// FeatureNames' type should be an indexed collection of std::string, like
/// std::array or std::vector, that has a size() method.
template <class FType>
ReleaseModeModelRunner(LLVMContext &Ctx, const FType &InputSpec,
- StringRef DecisionName, StringRef FeedPrefix = "feed_",
- StringRef FetchPrefix = "fetch_")
- : MLModelRunner(Ctx, MLModelRunner::Kind::Release, InputSpec.size()),
+ StringRef DecisionName,
+ const EmbeddedModelRunnerOptions &Options = {})
+ : MLModelRunner(Ctx, MLModelRunner::Kind::Release, InputSpec.size() + 1),
CompiledModel(std::make_unique<TGen>()) {
assert(CompiledModel && "The CompiledModel should be valid");
-
- for (size_t I = 0; I < InputSpec.size(); ++I) {
- const int Index =
- CompiledModel->LookupArgIndex(FeedPrefix.str() + InputSpec[I].name());
- void *Buffer = nullptr;
- if (Index >= 0)
- Buffer = CompiledModel->arg_data(Index);
- setUpBufferForTensor(I, InputSpec[I], Buffer);
+ // Set up the model_selector past all the InputSpecs in all cases.
+ // - if the model doesn't have such a feature, but the user requested it,
+ // we report error. Same if the model supports it but the user didn't
+ // specify it
+ // - finally, we compute the MD5 hash of the user input and set the value
+ // of the model selector to {high, low}
+ bool InputIsPresent = true;
+ populateTensor(InputSpec.size(),
+ TensorSpec::createSpec<uint64_t>("_model_selector", {2}),
+ Options.FeedPrefix, InputIsPresent);
+
+ // If we hit the "report an error" cases outlined above, continue with the
+ // set up in case there's some custom diagnostics handler installed and it
+ // doesn't promptly exit.
+ if (Options.ModelSelector.empty() && InputIsPresent)
+ Ctx.emitError(
+ "A model selector was not specified but the underlying model "
+ "requires selecting one because it exposes a _model_selector input");
+ uint64_t High = 0;
+ uint64_t Low = 0;
+ if (!Options.ModelSelector.empty()) {
+ if (!InputIsPresent)
+ Ctx.emitError("A model selector was specified but the underlying model "
+ "does not expose a _model_selector input");
+ const auto Hash = MD5::hash(
+ {reinterpret_cast<const uint8_t *>(Options.ModelSelector.data()),
+ Options.ModelSelector.size()});
+
+ High = Hash.high();
+ Low = Hash.low();
}
-
- ResultIndex = CompiledModel->LookupResultIndex(FetchPrefix.str() +
+ getTensor<uint64_t>(InputSpec.size())[0] = High;
----------------
boomanaiden154 wrote:
Are we separating the high and low values because TensorFlow lacks 128-bit integer support?
Not sure it would simplify things much, but couldn't we just use either the high or the low 64 bits, since the total entropy doesn't really matter in this case?
https://github.com/llvm/llvm-project/pull/96276
More information about the llvm-commits
mailing list