[llvm] [mlgo] Support composite AOT-ed models (PR #96276)
Mircea Trofin via llvm-commits
llvm-commits at lists.llvm.org
Mon Jun 24 09:20:08 PDT 2024
================
@@ -17,37 +17,91 @@
#include "llvm/Analysis/MLModelRunner.h"
#include "llvm/Analysis/TensorSpec.h"
#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/MD5.h"
#include <memory>
-#include <vector>
namespace llvm {
/// ReleaseModeModelRunner - production mode implementation of the
/// MLModelRunner. It uses an AOT-compiled SavedModel for efficient execution.
+struct EmbeddedModelRunnerOptions {
+ /// Feed and Fetch feature prefixes, prepended to tensor names: with the
+ /// defaults below (which already end in '_'), a feature named "foo" will be
+ /// looked up as "feed_foo" and the output named "bar" as "fetch_bar".
+ StringRef FeedPrefix = "feed_";
+ StringRef FetchPrefix = "fetch_";
+ // NOTE(review): all three fields are StringRefs, i.e. presumably non-owning - callers likely must keep the strings alive; confirm.
+ /// ModelSelector is the name (recognized by the AOT-ed model) of a sub-model
+ /// to use. "" is allowed if the model doesn't support sub-models.
+ StringRef ModelSelector = "";
+ // Setters: each one stores Value in the matching field and returns *this, so calls can be chained.
+ EmbeddedModelRunnerOptions &setFeedPrefix(StringRef Value) {
+ FeedPrefix = Value;
+ return *this;
+ }
+ EmbeddedModelRunnerOptions &setFetchPrefix(StringRef Value) {
+ FetchPrefix = Value;
+ return *this;
+ }
+ EmbeddedModelRunnerOptions &setModelSelector(StringRef Value) {
+ ModelSelector = Value;
+ return *this;
+ }
+};
+
template <class TGen>
class ReleaseModeModelRunner final : public MLModelRunner {
public:
/// FeatureNames' type should be an indexed collection of std::string, like
/// std::array or std::vector, that has a size() method.
template <class FType>
ReleaseModeModelRunner(LLVMContext &Ctx, const FType &InputSpec,
- StringRef DecisionName, StringRef FeedPrefix = "feed_",
- StringRef FetchPrefix = "fetch_")
- : MLModelRunner(Ctx, MLModelRunner::Kind::Release, InputSpec.size()),
+ StringRef DecisionName,
+ const EmbeddedModelRunnerOptions &Options = {})
+ : MLModelRunner(Ctx, MLModelRunner::Kind::Release, InputSpec.size() + 1),
CompiledModel(std::make_unique<TGen>()) {
assert(CompiledModel && "The CompiledModel should be valid");
-
- for (size_t I = 0; I < InputSpec.size(); ++I) {
- const int Index =
- CompiledModel->LookupArgIndex(FeedPrefix.str() + InputSpec[I].name());
- void *Buffer = nullptr;
- if (Index >= 0)
- Buffer = CompiledModel->arg_data(Index);
- setUpBufferForTensor(I, InputSpec[I], Buffer);
+ // Set up the model_selector past all the InputSpecs in all cases.
+ // - if the model doesn't have such a feature, but the user requested it,
+ // we report error. Same if the model supports it but the user didn't
+ // specify it
+ // - finally, we compute the MD5 hash of the user input and set the value
----------------
mtrofin wrote:
Yes, plus it's deterministic.
https://github.com/llvm/llvm-project/pull/96276
More information about the llvm-commits
mailing list