[llvm] e67430c - [MLGO] ML Regalloc Eviction Advisor

Mircea Trofin via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 19 11:00:40 PST 2022


Author: Mircea Trofin
Date: 2022-01-19T11:00:32-08:00
New Revision: e67430cca40455d31b95b088a88fa3b16a37ea34

URL: https://github.com/llvm/llvm-project/commit/e67430cca40455d31b95b088a88fa3b16a37ea34
DIFF: https://github.com/llvm/llvm-project/commit/e67430cca40455d31b95b088a88fa3b16a37ea34.diff

LOG: [MLGO] ML Regalloc Eviction Advisor

The bulk of the implementation is common between 'release' mode (==AOT-ed
model) and 'development' mode (for training); the main difference is
that in development mode, we may also log features (for training logs),
inject scoring information (currently after the Virtual Register
Rewriter) and then produce the log file.

This patch also introduces the score injection pass, 'Register
Allocation Pass Scoring', which is trivially just logging the score in
development mode.

Differential Revision: https://reviews.llvm.org/D117147

Added: 
    llvm/lib/Analysis/models/gen-regalloc-eviction-test-model.py
    llvm/test/CodeGen/MLRegalloc/Inputs/input.ll
    llvm/test/CodeGen/MLRegalloc/dev-mode-log-2-fcts.ll
    llvm/test/CodeGen/MLRegalloc/dev-mode-logging.ll
    llvm/test/CodeGen/MLRegalloc/dev-rel-equivalence.ll
    llvm/test/CodeGen/MLRegalloc/rel-codepath.ll

Modified: 
    llvm/CMakeLists.txt
    llvm/include/llvm/CodeGen/Passes.h
    llvm/include/llvm/InitializePasses.h
    llvm/lib/CodeGen/CMakeLists.txt
    llvm/lib/CodeGen/MLRegallocEvictAdvisor.cpp
    llvm/lib/CodeGen/RegAllocEvictionAdvisor.cpp
    llvm/lib/CodeGen/RegAllocGreedy.h
    llvm/lib/CodeGen/TargetPassConfig.cpp
    llvm/test/CodeGen/AArch64/O3-pipeline.ll
    llvm/test/CodeGen/ARM/O3-pipeline.ll
    llvm/test/CodeGen/PowerPC/O3-pipeline.ll
    llvm/test/CodeGen/X86/opt-pipeline.ll

Removed: 
    


################################################################################
diff  --git a/llvm/CMakeLists.txt b/llvm/CMakeLists.txt
index ab5317757af76..4f248c58e1822 100644
--- a/llvm/CMakeLists.txt
+++ b/llvm/CMakeLists.txt
@@ -897,6 +897,13 @@ if (NOT TENSORFLOW_AOT_PATH STREQUAL "")
     set(LLVM_INLINER_MODEL_PATH "autogenerate")
     set(LLVM_INLINER_MODEL_AUTOGENERATED 1)
   endif()
+  if (NOT DEFINED LLVM_RAEVICT_MODEL_PATH
+      OR "${LLVM_RAEVICT_MODEL_PATH}" STREQUAL ""
+      OR "${LLVM_RAEVICT_MODEL_PATH}" STREQUAL "autogenerate")
+    set(LLVM_RAEVICT_MODEL_PATH "autogenerate")
+    set(LLVM_RAEVICT_MODEL_AUTOGENERATED 1)
+  endif()
+
 endif()
 
 # Configure the three LLVM configuration header files.

diff  --git a/llvm/include/llvm/CodeGen/Passes.h b/llvm/include/llvm/CodeGen/Passes.h
index f4c6edba61f22..616ab10341334 100644
--- a/llvm/include/llvm/CodeGen/Passes.h
+++ b/llvm/include/llvm/CodeGen/Passes.h
@@ -550,6 +550,10 @@ namespace llvm {
   /// The pass transforms amx intrinsics to scalar operation if the function has
   /// optnone attribute or it is O0.
   FunctionPass *createX86LowerAMXIntrinsicsPass();
+
+  /// When learning an eviction policy, extract score (reward) information;
+  /// otherwise this does nothing.
+  FunctionPass *createRegAllocScoringPass();
 } // End llvm namespace
 
 #endif

diff  --git a/llvm/include/llvm/InitializePasses.h b/llvm/include/llvm/InitializePasses.h
index 0c5ebc9a2f289..489ef045796f0 100644
--- a/llvm/include/llvm/InitializePasses.h
+++ b/llvm/include/llvm/InitializePasses.h
@@ -380,6 +380,7 @@ void initializeReassociateLegacyPassPass(PassRegistry&);
 void initializeRedundantDbgInstEliminationPass(PassRegistry&);
 void initializeRegAllocEvictionAdvisorAnalysisPass(PassRegistry &);
 void initializeRegAllocFastPass(PassRegistry&);
+void initializeRegAllocScoringPass(PassRegistry &);
 void initializeRegBankSelectPass(PassRegistry&);
 void initializeRegToMemLegacyPass(PassRegistry&);
 void initializeRegUsageInfoCollectorPass(PassRegistry&);

diff  --git a/llvm/lib/Analysis/models/gen-regalloc-eviction-test-model.py b/llvm/lib/Analysis/models/gen-regalloc-eviction-test-model.py
new file mode 100644
index 0000000000000..1cb2492f4776b
--- /dev/null
+++ b/llvm/lib/Analysis/models/gen-regalloc-eviction-test-model.py
@@ -0,0 +1,103 @@
+"""Generate a mock model for LLVM tests for Register Allocation.
+The generated model is not a neural net - it is just a tf.function with the
+correct input and output parameters. By construction, the mock model will always
+output the first liverange that can be evicted.
+"""
+import os
+import sys
+import tensorflow as tf
+POLICY_DECISION_LABEL = 'index_to_evict'
+POLICY_OUTPUT_SPEC = """
+[
+    {
+        "logging_name": "index_to_evict",
+        "tensor_spec": {
+            "name": "StatefulPartitionedCall",
+            "port": 0,
+            "type": "int64_t",
+            "shape": [
+                1
+            ]
+        }
+    }
+]
+"""
+PER_REGISTER_INT64_FEATURE_LIST = [
+    'mask', 'is_hint', 'is_local', 'is_free', 'max_stage', 'min_stage'
+]
+PER_REGISTER_FLOAT32_FEATURE_LIST = ['nr_urgent',
+    'weighed_reads_by_max', 'weighed_writes_by_max',
+    'weighed_read_writes_by_max', 'weighed_indvars_by_max',
+    'hint_weights_by_max', 'start_bb_freq_by_max', 'end_bb_freq_by_max',
+    'hottest_bb_freq_by_max', 'liverange_size', 'use_def_density',
+    'nr_defs_and_uses', 'nr_broken_hints', 'nr_rematerializable'
+]
+PER_REGISTER_FEATURE_LIST = PER_REGISTER_FLOAT32_FEATURE_LIST + \
+    PER_REGISTER_INT64_FEATURE_LIST
+CONTEXT_FEATURE_LIST = ('progress', 'discount', 'reward', 'step_type')
+NUM_REGISTERS = 33
+
+
+def get_input_signature():
+  """Returns (time_step_spec, action_spec) for LLVM register allocation."""
+  inputs = dict(
+      (key, tf.TensorSpec(dtype=tf.int64, shape=(NUM_REGISTERS), name=key))
+      for key in PER_REGISTER_INT64_FEATURE_LIST)
+  inputs.update(
+      dict((key,
+            tf.TensorSpec(dtype=tf.float32, shape=(NUM_REGISTERS), name=key))
+           for key in PER_REGISTER_FLOAT32_FEATURE_LIST))
+  inputs['progress'] = tf.TensorSpec(
+      dtype=tf.float32, shape=(), name='progress')
+  inputs.update(
+      dict((key, tf.TensorSpec(dtype=tf.float32, shape=(), name=key))
+           for key in ['discount', 'reward']))
+  inputs.update(
+      dict((key, tf.TensorSpec(dtype=tf.int32, shape=(), name=key))
+           for key in ['step_type']))
+  return inputs
+
+
+def get_output_spec_path(path):
+  return os.path.join(path, 'output_spec.json')
+
+
+def build_mock_model(path):
+  """Build and save the mock model with the given signature."""
+  module = tf.Module()
+  # We have to set this useless variable in order for the TF C API to correctly
+  # intake it
+  module.var = tf.Variable(0, dtype=tf.int64)
+
+  def action(*inputs):
+    s1 = tf.reduce_sum([
+        tf.cast(inputs[0][key], tf.float32) for key in PER_REGISTER_FEATURE_LIST
+    ],
+        axis=0)
+    s2 = tf.reduce_sum(
+        [tf.cast(inputs[0][key], tf.float32) for key in CONTEXT_FEATURE_LIST])
+    # Add a large number so s won't be 0.
+    s = s1 + s2 + 123456789.123456789
+    # Equals to mask feature.
+    mask_alias = tf.not_equal(s * tf.cast(inputs[0]['mask'], tf.float32), 0)
+    result = tf.math.argmax(mask_alias, axis=-1) + module.var
+    return {POLICY_DECISION_LABEL: result}
+  module.action = tf.function()(action)
+  action = {
+      'action': module.action.get_concrete_function(get_input_signature())
+  }
+  tf.saved_model.save(module, path, signatures=action)
+  output_spec_path = get_output_spec_path(path)
+  with open(output_spec_path, 'w') as f:
+    print(f'Writing output spec to {output_spec_path}.')
+    f.write(POLICY_OUTPUT_SPEC)
+
+
+def main(argv):
+  assert len(argv) == 2
+  model_path = argv[1]
+  build_mock_model(model_path)
+
+
+if __name__ == '__main__':
+  main(sys.argv)

diff  --git a/llvm/lib/CodeGen/CMakeLists.txt b/llvm/lib/CodeGen/CMakeLists.txt
index db5a6ffd826b0..8dce5d2e68ca3 100644
--- a/llvm/lib/CodeGen/CMakeLists.txt
+++ b/llvm/lib/CodeGen/CMakeLists.txt
@@ -1,3 +1,30 @@
+if (DEFINED LLVM_HAVE_TF_AOT OR DEFINED LLVM_HAVE_TF_API)
+  include(TensorFlowCompile)
+  set(LLVM_RAEVICT_MODEL_PATH_DEFAULT "models/regalloc-eviction")
+
+  # This url points to the most recent model which is known to be compatible with
+  # LLVM. When better models are published, this url should be updated to aid
+  # discoverability.
+  set(LLVM_RAEVICT_MODEL_CURRENT_URL "TO_BE_UPDATED")
+
+  if (DEFINED LLVM_HAVE_TF_AOT)
+    tf_find_and_compile(
+      ${LLVM_RAEVICT_MODEL_PATH}
+      ${LLVM_RAEVICT_MODEL_CURRENT_URL}
+      ${LLVM_RAEVICT_MODEL_PATH_DEFAULT}
+      "../Analysis/models/gen-regalloc-eviction-test-model.py"
+      serve
+      action
+      RegallocEvictModel
+      llvm::RegallocEvictModel
+    )
+  endif()
+
+  if (DEFINED LLVM_HAVE_TF_API)
+    list(APPEND MLLinkDeps ${tensorflow_c_api} ${tensorflow_fx})
+  endif()
+endif()
+
 add_llvm_component_library(LLVMCodeGen
   AggressiveAntiDepBreaker.cpp
   AllocationOrder.cpp
@@ -199,6 +226,7 @@ add_llvm_component_library(LLVMCodeGen
   WasmEHPrepare.cpp
   WinEHPrepare.cpp
   XRayInstrumentation.cpp
+  ${GeneratedMLSources}
 
   LiveDebugValues/LiveDebugValues.cpp
   LiveDebugValues/VarLocBasedImpl.cpp
@@ -208,10 +236,11 @@ add_llvm_component_library(LLVMCodeGen
   ${LLVM_MAIN_INCLUDE_DIR}/llvm/CodeGen
   ${LLVM_MAIN_INCLUDE_DIR}/llvm/CodeGen/PBQP
 
-  LINK_LIBS ${LLVM_PTHREAD_LIB}
+  LINK_LIBS ${LLVM_PTHREAD_LIB} ${MLLinkDeps}
 
   DEPENDS
   intrinsics_gen
+  ${MLDeps}
 
   LINK_COMPONENTS
   Analysis

diff  --git a/llvm/lib/CodeGen/MLRegallocEvictAdvisor.cpp b/llvm/lib/CodeGen/MLRegallocEvictAdvisor.cpp
index 7a6c4fade06e4..a07c28306fd41 100644
--- a/llvm/lib/CodeGen/MLRegallocEvictAdvisor.cpp
+++ b/llvm/lib/CodeGen/MLRegallocEvictAdvisor.cpp
@@ -11,14 +11,21 @@
 //===----------------------------------------------------------------------===//
 
 #include "RegAllocEvictionAdvisor.h"
+#include "RegAllocGreedy.h"
+#include "RegAllocScore.h"
+#include "llvm/Analysis/AliasAnalysis.h"
 #include "llvm/Analysis/MLModelRunner.h"
 #include "llvm/Analysis/ModelUnderTrainingRunner.h"
 #include "llvm/Analysis/NoInferenceModelRunner.h"
+#include "llvm/Analysis/ReleaseModeModelRunner.h"
 #include "llvm/Analysis/Utils/TFUtils.h"
 #include "llvm/CodeGen/CalcSpillWeights.h"
+#include "llvm/CodeGen/MachineBasicBlock.h"
 #include "llvm/CodeGen/MachineBlockFrequencyInfo.h"
 #include "llvm/CodeGen/MachineFunction.h"
 #include "llvm/CodeGen/MachineLoopInfo.h"
+#include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/CodeGen/Passes.h"
 #include "llvm/CodeGen/RegisterClassInfo.h"
 #include "llvm/CodeGen/VirtRegMap.h"
 #include "llvm/Config/config.h"
@@ -29,12 +36,73 @@
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Target/TargetMachine.h"
 
+#include <array>
 #include <memory>
 
 using namespace llvm;
 
 #define DEBUG_TYPE "ml-regalloc"
 
+// Generated header in release (AOT) mode
+#if defined LLVM_HAVE_TF_AOT
+#include "RegallocEvictModel.h"
+#endif
+
+// Options that only make sense in development mode
+#ifdef LLVM_HAVE_TF_API
+static cl::opt<std::string> TrainingLog(
+    "regalloc-training-log", cl::Hidden,
+    cl::desc("Training log for the register allocator eviction model"));
+
+static cl::opt<std::string> ModelUnderTraining(
+    "regalloc-model", cl::Hidden,
+    cl::desc("The model being trained for register allocation eviction"));
+
+#endif // #ifdef LLVM_HAVE_TF_API
+
+/// The score injection pass.
+/// This pass calculates the score for a function and inserts it in the log, but
+/// this happens only in development mode. It's a no-op otherwise.
+namespace llvm {
+class RegAllocScoring : public MachineFunctionPass {
+public:
+  static char ID;
+
+  RegAllocScoring() : MachineFunctionPass(ID) {
+    initializeRegAllocScoringPass(*PassRegistry::getPassRegistry());
+  }
+
+  ~RegAllocScoring() override = default;
+
+  StringRef getPassName() const override {
+    return "Register Allocation Pass Scoring";
+  }
+
+  /// RegAllocScoring analysis usage.
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.setPreservesAll();
+    AU.addRequired<RegAllocEvictionAdvisorAnalysis>();
+    AU.addRequired<MachineBlockFrequencyInfo>();
+    AU.addRequired<AAResultsWrapperPass>();
+    MachineFunctionPass::getAnalysisUsage(AU);
+  }
+
+  /// Performs this pass
+  bool runOnMachineFunction(MachineFunction &) override;
+};
+
+char RegAllocScoring::ID = 0;
+FunctionPass *llvm::createRegAllocScoringPass() {
+  return new RegAllocScoring();
+}
+
+INITIALIZE_PASS(RegAllocScoring, "regallocscoringpass",
+                "Register Allocation Scoring Pass", false, false)
+} // namespace llvm
+
+// ===================================
+// Common ML Advisor declarations
+// ===================================
 #if defined(LLVM_HAVE_TF_AOT) || defined(LLVM_HAVE_TF_API)
 namespace {
 // This is the maximum number of interfererring ranges. That's the number of
@@ -152,18 +220,646 @@ void resetInputs(MLModelRunner &Runner) {
 #undef _RESET
 }
 
+using CandidateRegList =
+    std::array<std::pair<MCRegister, bool>, NumberOfInterferences>;
+using FeaturesListNormalizer = std::array<float, FeatureIDs::FeatureCount>;
+
+/// The ML evictor (commonalities between release and development mode)
+class MLEvictAdvisor : public RegAllocEvictionAdvisor {
+public:
+  MLEvictAdvisor(const MachineFunction &MF, const RAGreedy &RA,
+                 MLModelRunner *Runner, const MachineBlockFrequencyInfo &MBFI,
+                 const MachineLoopInfo &Loops);
+
+protected:
+  const RegAllocEvictionAdvisor &getDefaultAdvisor() const {
+    return static_cast<const RegAllocEvictionAdvisor &>(DefaultAdvisor);
+  }
+
+  // The assumption is that if the Runner could not be constructed, we emitted
+  // an error, and we shouldn't be asking for it here.
+  const MLModelRunner &getRunner() const { return *Runner; }
+
+  /// This just calls Evaluate on the Runner, but in the development mode case,
+  /// if we're just capturing the log of the default advisor, it needs to call
+  /// the latter instead, so we need to pass all the necessary parameters for
+  /// it. In the development case, it will also log.
+  virtual int64_t tryFindEvictionCandidatePosition(
+      LiveInterval &VirtReg, const AllocationOrder &Order, unsigned OrderLimit,
+      uint8_t CostPerUseLimit, const SmallVirtRegSet &FixedRegisters) const;
+
+  /// Load the features of the given VirtReg (allocated or not) at column Pos,
+  /// but if that can't be evicted, return false instead.
+  bool
+  loadInterferenceFeatures(LiveInterval &VirtReg, MCRegister PhysReg,
+                           bool IsHint, const SmallVirtRegSet &FixedRegisters,
+                           std::array<float, FeatureIDs::FeatureCount> &Largest,
+                           size_t Pos) const;
+
+private:
+  static float getInitialQueueSize(const MachineFunction &MF);
+
+  MCRegister tryFindEvictionCandidate(
+      LiveInterval &VirtReg, const AllocationOrder &Order,
+      uint8_t CostPerUseLimit,
+      const SmallVirtRegSet &FixedRegisters) const override;
+
+  void extractFeatures(const SmallVectorImpl<LiveInterval *> &Intervals,
+                       std::array<float, FeatureIDs::FeatureCount> &Largest,
+                       size_t Pos, int64_t IsHint, int64_t LocalIntfsCount,
+                       float NrUrgent) const;
+
+  // Point-in-time: we didn't learn this, so we always delegate to the default.
+  bool canEvictHintInterference(
+      LiveInterval &VirtReg, MCRegister PhysReg,
+      const SmallVirtRegSet &FixedRegisters) const override {
+    return getDefaultAdvisor().canEvictHintInterference(VirtReg, PhysReg,
+                                                        FixedRegisters);
+  }
+
+  // Hold on to a default advisor for:
+  // 1) the implementation of canEvictHintInterference, because we didn't learn
+  // that nuance yet;
+  // 2) for bootstrapping (logging) in the development mode case.
+  const DefaultEvictionAdvisor DefaultAdvisor;
+  MLModelRunner *const Runner;
+  const MachineBlockFrequencyInfo &MBFI;
+  const MachineLoopInfo &Loops;
+
+  // Indices of those features we don't want to normalize.
+  // This could be static and shared, but its initialization is non-trivial.
+  std::bitset<FeatureIDs::FeatureCount> DoNotNormalize;
+  const float InitialQSize;
+};
+
+// ===================================
+// Release (AOT) - specifics
+// ===================================
+#ifdef LLVM_HAVE_TF_AOT
+const std::array<std::string, FeatureIDs::FeatureCount> FeatureNames{
+#define _GETNAME(_, NAME, __, ___) #NAME,
+    RA_EVICT_FEATURES_LIST(_GETNAME)
+#undef _GETNAME
+};
+
+class ReleaseModeEvictionAdvisorAnalysis final
+    : public RegAllocEvictionAdvisorAnalysis {
+public:
+  ReleaseModeEvictionAdvisorAnalysis()
+      : RegAllocEvictionAdvisorAnalysis(AdvisorMode::Release) {}
+  // support for isa<> and dyn_cast.
+  static bool classof(const RegAllocEvictionAdvisorAnalysis *R) {
+    return R->getAdvisorMode() == AdvisorMode::Release;
+  }
+
+private:
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<MachineBlockFrequencyInfo>();
+    AU.addRequired<MachineLoopInfo>();
+    RegAllocEvictionAdvisorAnalysis::getAnalysisUsage(AU);
+  }
+
+  std::unique_ptr<RegAllocEvictionAdvisor>
+  getAdvisor(const MachineFunction &MF, const RAGreedy &RA) override {
+    if (!Runner)
+      Runner = std::make_unique<ReleaseModeModelRunner<RegallocEvictModel>>(
+          MF.getFunction().getContext(), FeatureNames, DecisionName);
+    return std::make_unique<MLEvictAdvisor>(
+        MF, RA, Runner.get(), getAnalysis<MachineBlockFrequencyInfo>(),
+        getAnalysis<MachineLoopInfo>());
+  }
+  std::unique_ptr<ReleaseModeModelRunner<RegallocEvictModel>> Runner;
+};
+#endif // LLVM_HAVE_TF_AOT
+
+// ===================================
 // Development mode-specifics
+// ===================================
+//
+// Features we log
 #ifdef LLVM_HAVE_TF_API
 #define _DECL_FEATURES(type, name, shape, _)                                   \
   TensorSpec::createSpec<type>(#name, shape),
 
 static const std::vector<TensorSpec> InputFeatures{
-    {RA_EVICT_FEATURES_LIST(_DECL_FEATURES)}};
+    {RA_EVICT_FEATURES_LIST(_DECL_FEATURES)},
+};
 #undef _DECL_FEATURES
 static const TensorSpec Output =
     TensorSpec::createSpec<int64_t>(DecisionName, {1});
 static const TensorSpec Reward = TensorSpec::createSpec<float>("reward", {1});
 
+// Features we bind on the model. The tensor names have a prefix, and we also
+// need to include some tensors that are expected to be present by the training
+// algo.
+// TODO: can we just get rid of these?
+#define _DECL_TRAIN_FEATURES(type, name, shape, _)                             \
+  TensorSpec::createSpec<type>(std::string("action_") + #name, shape),
+
+static const std::vector<TensorSpec> TrainingInputFeatures{
+    {RA_EVICT_FEATURES_LIST(_DECL_TRAIN_FEATURES)
+         TensorSpec::createSpec<float>("action_discount", {1}),
+     TensorSpec::createSpec<int32_t>("action_step_type", {1}),
+     TensorSpec::createSpec<float>("action_reward", {1})}};
+#undef _DECL_TRAIN_FEATURES
+
+class DevelopmentModeEvictAdvisor : public MLEvictAdvisor {
+public:
+  DevelopmentModeEvictAdvisor(const MachineFunction &MF, const RAGreedy &RA,
+                              MLModelRunner *Runner,
+                              const MachineBlockFrequencyInfo &MBFI,
+                              const MachineLoopInfo &Loops, Logger *Log)
+      : MLEvictAdvisor(MF, RA, Runner, MBFI, Loops), Log(Log) {}
+
+private:
+  int64_t tryFindEvictionCandidatePosition(
+      LiveInterval &VirtReg, const AllocationOrder &Order, unsigned OrderLimit,
+      uint8_t CostPerUseLimit,
+      const SmallVirtRegSet &FixedRegisters) const override;
+
+  Logger *const Log;
+};
+
+class DevelopmentModeEvictionAdvisorAnalysis final
+    : public RegAllocEvictionAdvisorAnalysis {
+public:
+  DevelopmentModeEvictionAdvisorAnalysis()
+      : RegAllocEvictionAdvisorAnalysis(AdvisorMode::Development) {}
+  // support for isa<> and dyn_cast.
+  static bool classof(const RegAllocEvictionAdvisorAnalysis *R) {
+    return R->getAdvisorMode() == AdvisorMode::Development;
+  }
+
+  /// get the logger for the given function, or nullptr if we didn't collect
+  /// one. This is used to inject the score by the RegAllocScoring pass.
+  Logger *getLogger(const MachineFunction &MF) const {
+    auto I = LogMap.find(MF.getName());
+    if (I == LogMap.end())
+      return nullptr;
+    return I->second.get();
+  }
+
+private:
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<MachineBlockFrequencyInfo>();
+    AU.addRequired<MachineLoopInfo>();
+    RegAllocEvictionAdvisorAnalysis::getAnalysisUsage(AU);
+  }
+
+  // Save all the logs (when requested).
+  bool doFinalization(Module &M) override {
+    if (TrainingLog.empty())
+      return false;
+    std::error_code EC;
+    auto OS = std::make_unique<raw_fd_ostream>(TrainingLog, EC);
+    if (EC) {
+      M.getContext().emitError(EC.message() + ":" + TrainingLog);
+      return false;
+    }
+    Logger::flushLogs(*OS, LogMap);
+    return false;
+  }
+
+  std::unique_ptr<RegAllocEvictionAdvisor>
+  getAdvisor(const MachineFunction &MF, const RAGreedy &RA) override {
+    LLVMContext &Ctx = MF.getFunction().getContext();
+    if (ModelUnderTraining.empty() && TrainingLog.empty()) {
+      Ctx.emitError("Regalloc development mode should be requested with at "
+                    "least logging enabled and/or a training model");
+      return nullptr;
+    }
+    if (!Runner) {
+      if (ModelUnderTraining.empty())
+        Runner = std::make_unique<NoInferenceModelRunner>(Ctx, InputFeatures);
+      else
+        Runner = ModelUnderTrainingRunner::createAndEnsureValid(
+            Ctx, ModelUnderTraining, DecisionName, TrainingInputFeatures);
+      if (!Runner) {
+        Ctx.emitError("Regalloc: could not set up the model runner");
+        return nullptr;
+      }
+    }
+
+    Logger *Log = nullptr;
+    if (!TrainingLog.empty()) {
+      std::vector<LoggedFeatureSpec> LFS;
+      for (const auto &FS : InputFeatures)
+        LFS.push_back({FS, None});
+      if (auto *MUTR = dyn_cast<ModelUnderTrainingRunner>(Runner.get()))
+        if (MUTR->outputLoggedFeatureSpecs().size() > 1)
+          append_range(LFS, drop_begin(MUTR->outputLoggedFeatureSpecs()));
+      // We always log the output; in particular, if we're not evaluating, we
+      // don't have an output spec json file. That's why we handle the
+      // 'normal' output separately.
+      LFS.push_back({Output, None});
+      auto I = LogMap.insert(std::make_pair(
+          MF.getFunction().getName(),
+          std::make_unique<Logger>(LFS, Reward, /*IncludeReward*/ true)));
+      assert(I.second);
+      Log = I.first->second.get();
+    }
+    return std::make_unique<DevelopmentModeEvictAdvisor>(
+        MF, RA, Runner.get(), getAnalysis<MachineBlockFrequencyInfo>(),
+        getAnalysis<MachineLoopInfo>(), Log);
+  }
+
+  std::unique_ptr<MLModelRunner> Runner;
+  StringMap<std::unique_ptr<Logger>> LogMap;
+};
 #endif //#ifdef LLVM_HAVE_TF_API
 } // namespace
+
+float MLEvictAdvisor::getInitialQueueSize(const MachineFunction &MF) {
+  auto &MRI = MF.getRegInfo();
+  float Ret = 0.0;
+  for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
+    Register Reg = Register::index2VirtReg(I);
+    if (MRI.reg_nodbg_empty(Reg))
+      continue;
+    ++Ret;
+  }
+  return Ret;
+}
+
+MLEvictAdvisor::MLEvictAdvisor(const MachineFunction &MF, const RAGreedy &RA,
+                               MLModelRunner *Runner,
+                               const MachineBlockFrequencyInfo &MBFI,
+                               const MachineLoopInfo &Loops)
+    : RegAllocEvictionAdvisor(MF, RA), DefaultAdvisor(MF, RA),
+      Runner(std::move(Runner)), MBFI(MBFI), Loops(Loops),
+      InitialQSize(MLEvictAdvisor::getInitialQueueSize(MF)) {
+  assert(this->Runner);
+  DoNotNormalize.set(FeatureIDs::mask);
+  DoNotNormalize.set(FeatureIDs::is_free);
+  DoNotNormalize.set(FeatureIDs::is_hint);
+  DoNotNormalize.set(FeatureIDs::is_local);
+  DoNotNormalize.set(FeatureIDs::min_stage);
+  DoNotNormalize.set(FeatureIDs::max_stage);
+  DoNotNormalize.set(FeatureIDs::progress);
+}
+
+int64_t MLEvictAdvisor::tryFindEvictionCandidatePosition(
+    LiveInterval &, const AllocationOrder &, unsigned, uint8_t,
+    const SmallVirtRegSet &) const {
+  int64_t Ret = Runner->evaluate<int64_t>();
+  assert(Ret >= 0);
+  assert(Ret <= CandidateVirtRegPos);
+  return Ret;
+}
+
+bool MLEvictAdvisor::loadInterferenceFeatures(
+    LiveInterval &VirtReg, MCRegister PhysReg, bool IsHint,
+    const SmallVirtRegSet &FixedRegisters, FeaturesListNormalizer &Largest,
+    size_t Pos) const {
+  // It is only possible to evict virtual register interference.
+  if (Matrix->checkInterference(VirtReg, PhysReg) > LiveRegMatrix::IK_VirtReg) {
+    // leave unavailable
+    return false;
+  }
+
+  const bool IsLocal = LIS->intervalIsInOneMBB(VirtReg);
+  int64_t LocalIntfs = 0;
+  float NrUrgent = 0.0f;
+
+  // The cascade tracking is the same as in the default advisor
+  unsigned Cascade = RA.getExtraInfo().getCascadeOrCurrentNext(VirtReg.reg());
+
+  SmallVector<LiveInterval *, MaxInterferences> InterferingIntervals;
+  for (MCRegUnitIterator Units(PhysReg, TRI); Units.isValid(); ++Units) {
+    LiveIntervalUnion::Query &Q = Matrix->query(VirtReg, *Units);
+    // Different from the default heuristic, we don't make any assumptions about
+    // what having more than 10 results in the query may mean.
+    const auto &IFIntervals = Q.interferingVRegs();
+    if (IFIntervals.empty() && InterferingIntervals.empty())
+      continue;
+    InterferingIntervals.append(IFIntervals.begin(), IFIntervals.end());
+    for (LiveInterval *Intf : reverse(IFIntervals)) {
+      assert(Register::isVirtualRegister(Intf->reg()) &&
+             "Only expecting virtual register interference from query");
+      // This is the same set of legality checks as in the default case: don't
+      // try to evict fixed regs or 'done' ones. Also don't break cascades,
+      // except in the urgent case, with the same nuances used in the default
+      // heuristic.
+      // We could try sharing this between the advisors, but it may end up
+      // more complex than it is right now.
+      if (FixedRegisters.count(Intf->reg()))
+        return false;
+      if (RA.getExtraInfo().getStage(*Intf) == RS_Done)
+        return false;
+      bool Urgent =
+          !VirtReg.isSpillable() &&
+          (Intf->isSpillable() ||
+           RegClassInfo.getNumAllocatableRegs(MRI->getRegClass(VirtReg.reg())) <
+               RegClassInfo.getNumAllocatableRegs(
+                   MRI->getRegClass(Intf->reg())));
+      // Only evict older cascades or live ranges without a cascade.
+      unsigned IntfCascade = RA.getExtraInfo().getCascade(Intf->reg());
+      if (Cascade <= IntfCascade) {
+        if (!Urgent)
+          return false;
+        ++NrUrgent;
+      }
+
+      LocalIntfs += (IsLocal && LIS->intervalIsInOneMBB(*Intf) &&
+                     (!EnableLocalReassign || !canReassign(*Intf, PhysReg)));
+    }
+  }
+  // OK, so if we made it this far, this LR is an eviction candidate, load its
+  // features.
+  extractFeatures(InterferingIntervals, Largest, Pos, IsHint, LocalIntfs,
+                  NrUrgent);
+  return true;
+}
+
+/// Load features for each register in \p Order, ask the runner for a column,
+/// and return the register recorded there. Returns MCRegister::NoRegister if
+/// nothing is evictable or the chosen column is CandidateVirtRegPos.
+MCRegister MLEvictAdvisor::tryFindEvictionCandidate(
+    LiveInterval &VirtReg, const AllocationOrder &Order,
+    uint8_t CostPerUseLimit, const SmallVirtRegSet &FixedRegisters) const {
+  auto MaybeOrderLimit = getOrderLimit(VirtReg, Order, CostPerUseLimit);
+  if (!MaybeOrderLimit)
+    return MCRegister::NoRegister;
+  unsigned OrderLimit = *MaybeOrderLimit;
+
+  // The heuristic sets initial costs such that, if CostPerUseLimit is
+  // max<uint8_t>, then any of the costs of the legally-evictable intervals
+  // would be lower. When that happens, one of those will be selected.
+  // Therefore, we allow the candidate to be selected, unless the candidate is
+  // unspillable, in which case it would be incorrect to not find a register for
+  // it.
+  const bool MustFindEviction =
+      (!VirtReg.isSpillable() && CostPerUseLimit == static_cast<uint8_t>(~0u));
+  // Number of available candidates - if 0, no need to continue.
+  size_t Available = 0;
+  // Make sure we don't have leftover partial state from an attempt where we had
+  // no available candidates and bailed out early.
+  resetInputs(*Runner);
+
+  // Track the index->register mapping because AllocationOrder doesn't do that
+  // and we'd have to scan it. The bool is the mask: it is true only for
+  // columns whose features were loaded (i.e. legal eviction candidates).
+  CandidateRegList Regs;
+  Regs.fill({0, false});
+
+  // Track the largest value of features seen during this eviction session. We
+  // only normalize (some of) the float features, but it's just simpler to
+  // dimension 'Largest' to all the features, especially since we have the
+  // 'DoNotNormalize' list.
+  FeaturesListNormalizer Largest;
+  Largest.fill(0.0);
+
+  // Same overall idea as in the default eviction policy - we visit the values
+  // of AllocationOrder one at a time. If it's not legally available, we mask
+  // off the corresponding feature column (==do nothing because we already reset
+  // all the features to 0)
+  // Use Pos to capture the column we load features at - in AllocationOrder
+  // order.
+  size_t Pos = 0;
+  for (auto I = Order.begin(), E = Order.getOrderLimitEnd(OrderLimit); I != E;
+       ++I, ++Pos) {
+    MCRegister PhysReg = *I;
+    Regs[Pos] = std::make_pair(PhysReg, false);
+    assert(PhysReg);
+    if (!canAllocatePhysReg(CostPerUseLimit, PhysReg))
+      continue;
+    if (loadInterferenceFeatures(VirtReg, PhysReg, I.isHint(), FixedRegisters,
+                                 Largest, Pos)) {
+      ++Available;
+      Regs[Pos].second = true;
+    }
+  }
+  if (Available == 0) {
+    // Nothing to decide, nothing to learn.
+    assert(!MustFindEviction);
+    return MCRegister::NoRegister;
+  }
+  // If we must find eviction, the candidate should be masked out of the
+  // decision making process.
+  Regs[CandidateVirtRegPos].second = !MustFindEviction;
+  if (!MustFindEviction)
+    extractFeatures(SmallVector<LiveInterval *, 1>(1, &VirtReg), Largest,
+                    CandidateVirtRegPos, /*IsHint*/ 0, /*LocalIntfsCount*/ 0,
+                    /*NrUrgent*/ 0.0);
+  assert(InitialQSize > 0.0 && "We couldn't have gotten here if we had "
+                               "nothing to allocate initially.");
+  // Normalize the features.
+  for (auto &V : Largest)
+    V = V ? V : 1.0;
+  for (size_t FeatureIndex = 0; FeatureIndex < FeatureIDs::FeatureCount;
+       ++FeatureIndex) {
+    if (DoNotNormalize.test(FeatureIndex))
+      continue;
+    for (size_t Col = 0; Col < NumberOfInterferences; ++Col)
+      Runner->getTensor<float>(FeatureIndex)[Col] /= Largest[FeatureIndex];
+  }
+  *Runner->getTensor<float>(FeatureIDs::progress) =
+      static_cast<float>(RA.getQueueSize()) / InitialQSize;
+
+  // Get a decision.
+  size_t CandidatePos = tryFindEvictionCandidatePosition(
+      VirtReg, Order, OrderLimit, CostPerUseLimit, FixedRegisters);
+  // The contract with the ML side is that CandidatePos is mask == 1 (i.e.
+  // Regs[CandidatePos].second)
+  assert(Regs[CandidatePos].second);
+  if (CandidatePos == CandidateVirtRegPos) {
+    assert(!MustFindEviction);
+    return MCRegister::NoRegister;
+  }
+  return Regs[CandidatePos].first;
+}
+
+// Overall, this currently mimics what we do for weight calculation, but instead
+// of accumulating the various features, we keep them separate.
+void MLEvictAdvisor::extractFeatures(
+    const SmallVectorImpl<LiveInterval *> &Intervals,
+    std::array<float, FeatureIDs::FeatureCount> &Largest, size_t Pos,
+    int64_t IsHint, int64_t LocalIntfsCount, float NrUrgent) const {
+  int64_t NrDefsAndUses = 0;
+  int64_t NrBrokenHints = 0;
+  float R = 0;
+  float W = 0;
+  float RW = 0;
+  float IndVarUpdates = 0;
+  float HintWeights = 0.0;
+  float StartBBFreq = 0.0;
+  float EndBBFreq = 0.0;
+  float HottestBlockFreq = 0.0;
+  int32_t NrRematerializable = 0;
+  float TotalWeight = 0.0;
+
+  SlotIndex EndSI = LIS->getSlotIndexes()->getZeroIndex();
+  SlotIndex StartSI = LIS->getSlotIndexes()->getLastIndex();
+  int64_t MaxStage = 0;
+  int64_t MinStage =
+      Intervals.empty() ? 0 : std::numeric_limits<int64_t>::max();
+
+  for (const auto *L : Intervals) {
+    const LiveInterval &LI = *L;
+    MaxStage = std::max<int64_t>(
+        MaxStage, static_cast<int64_t>(RA.getExtraInfo().getStage(LI)));
+    MinStage = std::min<int64_t>(
+        MinStage, static_cast<int64_t>(RA.getExtraInfo().getStage(LI)));
+
+    TotalWeight = std::max(TotalWeight, LI.weight());
+
+    if (LI.beginIndex() < StartSI)
+      StartSI = LI.beginIndex();
+
+    if (LI.endIndex() > EndSI)
+      EndSI = LI.endIndex();
+
+    SmallPtrSet<MachineInstr *, 8> Visited;
+    const TargetRegisterInfo &TRI = *MF.getSubtarget().getRegisterInfo();
+    NrBrokenHints += VRM->hasPreferredPhys(LI.reg());
+
+    for (MachineRegisterInfo::reg_instr_nodbg_iterator
+             I = MRI->reg_instr_nodbg_begin(LI.reg()),
+             E = MRI->reg_instr_nodbg_end();
+         I != E;) {
+      MachineInstr *MI = &*(I++);
+
+      ++NrDefsAndUses;
+      if (!Visited.insert(MI).second)
+        continue;
+
+      if (MI->isIdentityCopy() || MI->isImplicitDef())
+        continue;
+
+      bool Reads, Writes;
+      std::tie(Reads, Writes) = MI->readsWritesVirtualRegister(LI.reg());
+
+      float Freq = MBFI.getBlockFreqRelativeToEntryBlock(MI->getParent());
+      if (Freq > HottestBlockFreq)
+        HottestBlockFreq = Freq;
+      R += (Reads && !Writes) * Freq;
+      W += (!Reads && Writes) * Freq;
+      RW += (Reads && Writes) * Freq;
+
+      auto *MBB = MI->getParent();
+      auto *Loop = Loops.getLoopFor(MBB);
+      bool IsExiting = Loop ? Loop->isLoopExiting(MBB) : false;
+
+      if (Writes && IsExiting && LIS->isLiveOutOfMBB(LI, MBB))
+        IndVarUpdates += Freq;
+
+      if (MI->isCopy() && VirtRegAuxInfo::copyHint(MI, LI.reg(), TRI, *MRI))
+        HintWeights += Freq;
+    }
+    NrRematerializable += VirtRegAuxInfo::isRematerializable(
+        LI, *LIS, *VRM, *MF.getSubtarget().getInstrInfo());
+  }
+  size_t Size = 0;
+  if (!Intervals.empty()) {
+    StartBBFreq =
+        MBFI.getBlockFreqRelativeToEntryBlock(LIS->getMBBFromIndex(StartSI));
+    if (EndSI >= LIS->getSlotIndexes()->getLastIndex())
+      EndSI = LIS->getSlotIndexes()->getLastIndex().getPrevIndex();
+    EndBBFreq =
+        MBFI.getBlockFreqRelativeToEntryBlock(LIS->getMBBFromIndex(EndSI));
+    Size = StartSI.distance(EndSI);
+  }
+  // Set the features at the column 'Pos'.
+#define SET(ID, TYPE, VAL)                                                     \
+  do {                                                                         \
+    Runner->getTensor<TYPE>(FeatureIDs::ID)[Pos] = static_cast<TYPE>(VAL);     \
+    if (!DoNotNormalize.test(FeatureIDs::ID))                                  \
+      Largest[FeatureIDs::ID] =                                                \
+          std::max(Largest[FeatureIDs::ID], static_cast<float>(VAL));          \
+  } while (false)
+  SET(mask, int64_t, 1);
+  SET(is_free, int64_t, Intervals.empty());
+  SET(nr_urgent, float, NrUrgent);
+  SET(nr_broken_hints, float, NrBrokenHints);
+  SET(is_hint, int64_t, IsHint);
+  SET(is_local, int64_t, LocalIntfsCount);
+  SET(nr_rematerializable, float, NrRematerializable);
+  SET(nr_defs_and_uses, float, NrDefsAndUses);
+  SET(weighed_reads_by_max, float, R);
+  SET(weighed_writes_by_max, float, W);
+  SET(weighed_read_writes_by_max, float, RW);
+  SET(weighed_indvars_by_max, float, IndVarUpdates);
+  SET(hint_weights_by_max, float, HintWeights);
+  SET(start_bb_freq_by_max, float, StartBBFreq);
+  SET(end_bb_freq_by_max, float, EndBBFreq);
+  SET(hottest_bb_freq_by_max, float, HottestBlockFreq);
+  SET(liverange_size, float, Size);
+  SET(use_def_density, float, TotalWeight);
+  SET(max_stage, int64_t, MaxStage);
+  SET(min_stage, int64_t, MinStage);
+#undef SET
+}
+
+// Development mode-specific implementations
+#ifdef LLVM_HAVE_TF_API
+RegAllocEvictionAdvisorAnalysis *llvm::createDevelopmentModeAdvisor() { // factory for the development-mode analysis; only compiled under LLVM_HAVE_TF_API
+  return new DevelopmentModeEvictionAdvisorAnalysis(); // NOTE(review): raw new - ownership presumably passes to the caller; confirm
+}
+
+int64_t DevelopmentModeEvictAdvisor::tryFindEvictionCandidatePosition( // picks a candidate position and, when a training log is requested, logs features and the decision
+    LiveInterval &VirtReg, const AllocationOrder &Order, unsigned OrderLimit,
+    uint8_t CostPerUseLimit, const SmallVirtRegSet &FixedRegisters) const {
+  int64_t Ret = 0;
+  if (isa<ModelUnderTrainingRunner>(getRunner())) { // a model under training is loaded: defer to the base ML implementation
+    Ret = MLEvictAdvisor::tryFindEvictionCandidatePosition(
+        VirtReg, Order, OrderLimit, CostPerUseLimit, FixedRegisters);
+  } else { // otherwise ask the default advisor and translate its answer into a position
+    MCRegister PhysReg = getDefaultAdvisor().tryFindEvictionCandidate(
+        VirtReg, Order, CostPerUseLimit, FixedRegisters);
+    // Find the index of the selected PhysReg. We need it for logging, otherwise
+    // this is wasted cycles (but so would starting development mode without a
+    // model nor logging)
+    if (!PhysReg)
+      Ret = CandidateVirtRegPos; // no eviction chosen: report the position reserved for the candidate vreg itself
+    else
+      for (auto I = Order.begin(), E = Order.getOrderLimitEnd(OrderLimit);
+           I != E; ++I, ++Ret) // Ret advances in lockstep with the order iterator
+        if (*I == PhysReg)
+          break; // NOTE(review): if PhysReg is never matched, Ret ends as the number of entries visited - confirm that cannot happen
+  }
+  if (TrainingLog.empty()) // no training log requested: skip all logging
+    return Ret;
+  size_t CurrentFeature = 0;
+  for (; CurrentFeature < FeatureIDs::FeatureCount; ++CurrentFeature) {
+    Log->logSpecifiedTensorValue(
+        CurrentFeature, reinterpret_cast<const char *>(
+                            getRunner().getTensorUntyped(CurrentFeature)));
+  }
+  if (auto *MUTR = dyn_cast<ModelUnderTrainingRunner>(&getRunner()))
+    for (size_t I = 1; I < MUTR->outputLoggedFeatureSpecs().size(); // starts at 1; presumably output 0 is the decision itself, logged below - confirm
+         ++I, ++CurrentFeature)
+      Log->logSpecifiedTensorValue(
+          CurrentFeature,
+          reinterpret_cast<const char *>(
+              MUTR->lastEvaluationResult()->getUntypedTensorValue(I)));
+  // The output is right after the features and the extra outputs
+  Log->logInt64Value(CurrentFeature, &Ret);
+  return Ret;
+}
+
+bool RegAllocScoring::runOnMachineFunction(MachineFunction &MF) { // development-mode variant: logs the regalloc score as the final reward
+  if (auto *DevModeAnalysis = dyn_cast<DevelopmentModeEvictionAdvisorAnalysis>( // does nothing unless the advisor analysis is the development-mode one
+          &getAnalysis<RegAllocEvictionAdvisorAnalysis>()))
+    if (auto *Log = DevModeAnalysis->getLogger(MF)) // and only when getLogger returns a non-null logger for MF
+      Log->logFloatFinalReward(static_cast<float>(
+          calculateRegAllocScore(
+              MF, getAnalysis<MachineBlockFrequencyInfo>(),
+              getAnalysis<AAResultsWrapperPass>().getAAResults())
+              .getScore()));
+
+  return false; // always false: this function never reports a change to MF
+}
+#endif // #ifdef LLVM_HAVE_TF_API
+
+// Release mode specific implementations
+#if defined LLVM_HAVE_TF_AOT
+RegAllocEvictionAdvisorAnalysis *llvm::createReleaseModeAdvisor() { // factory for the release-mode analysis; only compiled under LLVM_HAVE_TF_AOT
+  return new ReleaseModeEvictionAdvisorAnalysis(); // NOTE(review): raw new - ownership presumably passes to the caller; confirm
+}
+#endif // defined(LLVM_HAVE_TF_AOT)
 #endif // defined(LLVM_HAVE_TF_AOT) || defined(LLVM_HAVE_TF_API)
+
+// In all cases except development mode, we don't need scoring. This stub is compiled whenever LLVM_HAVE_TF_API is not defined.
+#if !defined(LLVM_HAVE_TF_API)
+bool RegAllocScoring::runOnMachineFunction(MachineFunction &) { return false; } // no-op: ignores the function and reports no change
+#endif

diff  --git a/llvm/lib/CodeGen/RegAllocEvictionAdvisor.cpp b/llvm/lib/CodeGen/RegAllocEvictionAdvisor.cpp
index a41da98977c7d..d64e8cd06492b 100644
--- a/llvm/lib/CodeGen/RegAllocEvictionAdvisor.cpp
+++ b/llvm/lib/CodeGen/RegAllocEvictionAdvisor.cpp
@@ -83,10 +83,14 @@ template <> Pass *llvm::callDefaultCtor<RegAllocEvictionAdvisorAnalysis>() {
     Ret = new DefaultEvictionAdvisorAnalysis(/*NotAsRequested*/ false);
     break;
   case RegAllocEvictionAdvisorAnalysis::AdvisorMode::Development:
-    // TODO(mtrofin): add implementation
+#if defined(LLVM_HAVE_TF_API)
+    Ret = createDevelopmentModeAdvisor();
+#endif
     break;
   case RegAllocEvictionAdvisorAnalysis::AdvisorMode::Release:
-    // TODO(mtrofin): add implementation
+#if defined(LLVM_HAVE_TF_AOT)
+    Ret = createReleaseModeAdvisor();
+#endif
     break;
   }
   if (Ret)

diff  --git a/llvm/lib/CodeGen/RegAllocGreedy.h b/llvm/lib/CodeGen/RegAllocGreedy.h
index bb8c3e7a5b461..e9a5fe635f260 100644
--- a/llvm/lib/CodeGen/RegAllocGreedy.h
+++ b/llvm/lib/CodeGen/RegAllocGreedy.h
@@ -156,6 +156,7 @@ class LLVM_LIBRARY_VISIBILITY RAGreedy : public MachineFunctionPass,
   VirtRegMap *getVirtRegMap() const { return VRM; }
   const RegisterClassInfo &getRegClassInfo() const { return RegClassInfo; }
   const ExtraRegInfo &getExtraInfo() const { return *ExtraInfo; }
+  size_t getQueueSize() const { return Queue.size(); } // current size of Queue; the ML advisor divides it by the initial size for its progress feature
   // end (interface to eviction advisers)
 
 private:

diff  --git a/llvm/lib/CodeGen/TargetPassConfig.cpp b/llvm/lib/CodeGen/TargetPassConfig.cpp
index 6d9d226992732..05004fb935df8 100644
--- a/llvm/lib/CodeGen/TargetPassConfig.cpp
+++ b/llvm/lib/CodeGen/TargetPassConfig.cpp
@@ -1399,6 +1399,9 @@ bool TargetPassConfig::addRegAssignAndRewriteOptimized() {
   // Finally rewrite virtual registers.
   addPass(&VirtRegRewriterID);
 
+  // Regalloc scoring for ML-driven eviction - noop except when learning a new
+  // eviction policy.
+  addPass(createRegAllocScoringPass());
   return true;
 }
 

diff  --git a/llvm/test/CodeGen/AArch64/O3-pipeline.ll b/llvm/test/CodeGen/AArch64/O3-pipeline.ll
index 07587749bc296..d91e3dd4e5dc5 100644
--- a/llvm/test/CodeGen/AArch64/O3-pipeline.ll
+++ b/llvm/test/CodeGen/AArch64/O3-pipeline.ll
@@ -158,6 +158,7 @@
 ; CHECK-NEXT:       Machine Optimization Remark Emitter
 ; CHECK-NEXT:       Greedy Register Allocator
 ; CHECK-NEXT:       Virtual Register Rewriter
+; CHECK-NEXT:       Register Allocation Pass Scoring
 ; CHECK-NEXT:       Stack Slot Coloring
 ; CHECK-NEXT:       Machine Copy Propagation Pass
 ; CHECK-NEXT:       Machine Loop Invariant Code Motion

diff  --git a/llvm/test/CodeGen/ARM/O3-pipeline.ll b/llvm/test/CodeGen/ARM/O3-pipeline.ll
index b05429ff969ca..d25649dd13e5e 100644
--- a/llvm/test/CodeGen/ARM/O3-pipeline.ll
+++ b/llvm/test/CodeGen/ARM/O3-pipeline.ll
@@ -123,6 +123,7 @@
 ; CHECK-NEXT:      Machine Optimization Remark Emitter
 ; CHECK-NEXT:      Greedy Register Allocator
 ; CHECK-NEXT:      Virtual Register Rewriter
+; CHECK-NEXT:      Register Allocation Pass Scoring
 ; CHECK-NEXT:      Stack Slot Coloring
 ; CHECK-NEXT:      Machine Copy Propagation Pass
 ; CHECK-NEXT:      Machine Loop Invariant Code Motion

diff  --git a/llvm/test/CodeGen/MLRegalloc/Inputs/input.ll b/llvm/test/CodeGen/MLRegalloc/Inputs/input.ll
new file mode 100644
index 0000000000000..4d874da7aef26
--- /dev/null
+++ b/llvm/test/CodeGen/MLRegalloc/Inputs/input.ll
@@ -0,0 +1,687 @@
+; This is a copy of test/CodeGen/X86/ragreedy-hoist-spill.ll. It generates
+; sufficiently interesting differences between the default eviction heuristic
+; and the test ML policy: different eviction choices, and different reward.
+;
+;
+%struct.TMP.1 = type { %struct.TMP.2*, %struct.TMP.2*, [1024 x i8] }
+%struct.TMP.2 = type { i8*, i32, i32, i16, i16, %struct.TMP.3, i32, i8*, i32 (i8*)*, i32 (i8*, i8*, i32)*, i64 (i8*, i64, i32)*, i32 (i8*, i8*, i32)*, %struct.TMP.3, %struct.TMP.4*, i32, [3 x i8], [1 x i8], %struct.TMP.3, i32, i64 }
+%struct.TMP.4 = type opaque
+%struct.TMP.3 = type { i8*, i32 }
+
+@syBuf = external global [16 x %struct.TMP.1], align 16
+@syHistory = external global [8192 x i8], align 16
+@SyFgets.yank = external global [512 x i8], align 16
+@syCTRO = external global i32, align 4
+
+define i8* @SyFgets(i8* %line, i64 %length, i64 %fid) {
+; CHECK-LABEL: SyFgets:
+; CHECK:       ## %bb.0: ## %entry
+; CHECK-NEXT:    pushq %rbp
+; CHECK-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-NEXT:    pushq %r15
+; CHECK-NEXT:    .cfi_def_cfa_offset 24
+; CHECK-NEXT:    pushq %r14
+; CHECK-NEXT:    .cfi_def_cfa_offset 32
+; CHECK-NEXT:    pushq %r13
+; CHECK-NEXT:    .cfi_def_cfa_offset 40
+; CHECK-NEXT:    pushq %r12
+; CHECK-NEXT:    .cfi_def_cfa_offset 48
+; CHECK-NEXT:    pushq %rbx
+; CHECK-NEXT:    .cfi_def_cfa_offset 56
+; CHECK-NEXT:    subq $552, %rsp ## imm = 0x228
+; CHECK-NEXT:    .cfi_def_cfa_offset 608
+; CHECK-NEXT:    .cfi_offset %rbx, -56
+; CHECK-NEXT:    .cfi_offset %r12, -48
+; CHECK-NEXT:    .cfi_offset %r13, -40
+; CHECK-NEXT:    .cfi_offset %r14, -32
+; CHECK-NEXT:    .cfi_offset %r15, -24
+; CHECK-NEXT:    .cfi_offset %rbp, -16
+; CHECK-NEXT:    testq $-3, %rdx
+; CHECK-NEXT:    jne LBB0_4
+; CHECK-NEXT:  ## %bb.1: ## %if.end
+; CHECK-NEXT:    xorl %eax, %eax
+; CHECK-NEXT:    testb %al, %al
+; CHECK-NEXT:    jne LBB0_5
+; CHECK-NEXT:  ## %bb.2: ## %if.then4
+; CHECK-NEXT:    xorl %eax, %eax
+; CHECK-NEXT:    testb %al, %al
+; CHECK-NEXT:    je LBB0_55
+; CHECK-NEXT:  ## %bb.3: ## %SyTime.exit
+; CHECK-NEXT:    xorl %eax, %eax
+; CHECK-NEXT:    testb %al, %al
+; CHECK-NEXT:    je LBB0_55
+; CHECK-NEXT:  LBB0_4: ## %cleanup
+; CHECK-NEXT:    addq $552, %rsp ## imm = 0x228
+; CHECK-NEXT:    popq %rbx
+; CHECK-NEXT:    popq %r12
+; CHECK-NEXT:    popq %r13
+; CHECK-NEXT:    popq %r14
+; CHECK-NEXT:    popq %r15
+; CHECK-NEXT:    popq %rbp
+; CHECK-NEXT:    retq
+; CHECK-NEXT:  LBB0_5: ## %if.end25
+; CHECK-NEXT:    xorl %eax, %eax
+; CHECK-NEXT:    testb %al, %al
+; CHECK-NEXT:    je LBB0_55
+; CHECK-NEXT:  ## %bb.6: ## %SyTime.exit2720
+; CHECK-NEXT:    movq %rdx, %rbx
+; CHECK-NEXT:    movq %rdi, %rbp
+; CHECK-NEXT:    leaq {{[0-9]+}}(%rsp), %rax
+; CHECK-NEXT:    leaq {{[0-9]+}}(%rsp), %rcx
+; CHECK-NEXT:    cmpq %rax, %rcx
+; CHECK-NEXT:    jae LBB0_8
+; CHECK-NEXT:  ## %bb.7: ## %for.body.lr.ph
+; CHECK-NEXT:    movl $512, %edx ## imm = 0x200
+; CHECK-NEXT:    movl $32, %esi
+; CHECK-NEXT:    callq _memset
+; CHECK-NEXT:  LBB0_8: ## %while.body.preheader
+; CHECK-NEXT:    imulq $1040, %rbx, %rax ## imm = 0x410
+; CHECK-NEXT:    movq _syBuf@GOTPCREL(%rip), %rcx
+; CHECK-NEXT:    leaq 8(%rcx,%rax), %rdx
+; CHECK-NEXT:    movl $1, %r15d
+; CHECK-NEXT:    movq _syCTRO@GOTPCREL(%rip), %rax
+; CHECK-NEXT:    movb $1, %cl
+; CHECK-NEXT:    .p2align 4, 0x90
+; CHECK-NEXT:  LBB0_9: ## %do.body
+; CHECK-NEXT:    ## =>This Inner Loop Header: Depth=1
+; CHECK-NEXT:    movl $0, (%rax)
+; CHECK-NEXT:    testb %cl, %cl
+; CHECK-NEXT:    jne LBB0_9
+; CHECK-NEXT:  ## %bb.10: ## %do.end
+; CHECK-NEXT:    movq %rdx, {{[-0-9]+}}(%r{{[sb]}}p) ## 8-byte Spill
+; CHECK-NEXT:    movq %rbp, {{[-0-9]+}}(%r{{[sb]}}p) ## 8-byte Spill
+; CHECK-NEXT:    xorl %r13d, %r13d
+; CHECK-NEXT:    testb %r13b, %r13b
+; CHECK-NEXT:    jne LBB0_11
+; CHECK-NEXT:  ## %bb.12: ## %while.body200.preheader
+; CHECK-NEXT:    xorl %r12d, %r12d
+; CHECK-NEXT:    leaq LJTI0_0(%rip), %rdx
+; CHECK-NEXT:    leaq LJTI0_1(%rip), %rbx
+; CHECK-NEXT:    movl $0, {{[-0-9]+}}(%r{{[sb]}}p) ## 4-byte Folded Spill
+; CHECK-NEXT:    xorl %r14d, %r14d
+; CHECK-NEXT:    jmp LBB0_13
+; CHECK-NEXT:    .p2align 4, 0x90
+; CHECK-NEXT:  LBB0_20: ## %sw.bb256
+; CHECK-NEXT:    ## in Loop: Header=BB0_13 Depth=1
+; CHECK-NEXT:    movl %r13d, %r14d
+; CHECK-NEXT:  LBB0_21: ## %while.cond197.backedge
+; CHECK-NEXT:    ## in Loop: Header=BB0_13 Depth=1
+; CHECK-NEXT:    decl %r15d
+; CHECK-NEXT:    testl %r15d, %r15d
+; CHECK-NEXT:    movl %r14d, %r13d
+; CHECK-NEXT:    jle LBB0_22
+; CHECK-NEXT:  LBB0_13: ## %while.body200
+; CHECK-NEXT:    ## =>This Loop Header: Depth=1
+; CHECK-NEXT:    ## Child Loop BB0_29 Depth 2
+; CHECK-NEXT:    ## Child Loop BB0_38 Depth 2
+; CHECK-NEXT:    leal -268(%r13), %eax
+; CHECK-NEXT:    cmpl $105, %eax
+; CHECK-NEXT:    ja LBB0_14
+; CHECK-NEXT:  ## %bb.56: ## %while.body200
+; CHECK-NEXT:    ## in Loop: Header=BB0_13 Depth=1
+; CHECK-NEXT:    movslq (%rbx,%rax,4), %rax
+; CHECK-NEXT:    addq %rbx, %rax
+; CHECK-NEXT:    jmpq *%rax
+; CHECK-NEXT:  LBB0_44: ## %while.cond1037.preheader
+; CHECK-NEXT:    ## in Loop: Header=BB0_13 Depth=1
+; CHECK-NEXT:    testb %r12b, %r12b
+; CHECK-NEXT:    movl %r13d, %r14d
+; CHECK-NEXT:    jne LBB0_21
+; CHECK-NEXT:    jmp LBB0_55
+; CHECK-NEXT:    .p2align 4, 0x90
+; CHECK-NEXT:  LBB0_14: ## %while.body200
+; CHECK-NEXT:    ## in Loop: Header=BB0_13 Depth=1
+; CHECK-NEXT:    leal 1(%r13), %eax
+; CHECK-NEXT:    cmpl $21, %eax
+; CHECK-NEXT:    ja LBB0_20
+; CHECK-NEXT:  ## %bb.15: ## %while.body200
+; CHECK-NEXT:    ## in Loop: Header=BB0_13 Depth=1
+; CHECK-NEXT:    movl $-1, %r14d
+; CHECK-NEXT:    movslq (%rdx,%rax,4), %rax
+; CHECK-NEXT:    addq %rdx, %rax
+; CHECK-NEXT:    jmpq *%rax
+; CHECK-NEXT:  LBB0_18: ## %while.cond201.preheader
+; CHECK-NEXT:    ## in Loop: Header=BB0_13 Depth=1
+; CHECK-NEXT:    movl $1, %r14d
+; CHECK-NEXT:    jmp LBB0_21
+; CHECK-NEXT:  LBB0_26: ## %sw.bb474
+; CHECK-NEXT:    ## in Loop: Header=BB0_13 Depth=1
+; CHECK-NEXT:    testb %r12b, %r12b
+; CHECK-NEXT:    ## implicit-def: $rbp
+; CHECK-NEXT:    jne LBB0_34
+; CHECK-NEXT:  ## %bb.27: ## %do.body479.preheader
+; CHECK-NEXT:    ## in Loop: Header=BB0_13 Depth=1
+; CHECK-NEXT:    testb %r12b, %r12b
+; CHECK-NEXT:    ## implicit-def: $rbp
+; CHECK-NEXT:    jne LBB0_34
+; CHECK-NEXT:  ## %bb.28: ## %land.rhs485.preheader
+; CHECK-NEXT:    ## in Loop: Header=BB0_13 Depth=1
+; CHECK-NEXT:    ## implicit-def: $rax
+; CHECK-NEXT:    jmp LBB0_29
+; CHECK-NEXT:    .p2align 4, 0x90
+; CHECK-NEXT:  LBB0_32: ## %do.body479.backedge
+; CHECK-NEXT:    ## in Loop: Header=BB0_29 Depth=2
+; CHECK-NEXT:    leaq 1(%rbp), %rax
+; CHECK-NEXT:    testb %r12b, %r12b
+; CHECK-NEXT:    je LBB0_33
+; CHECK-NEXT:  LBB0_29: ## %land.rhs485
+; CHECK-NEXT:    ## Parent Loop BB0_13 Depth=1
+; CHECK-NEXT:    ## => This Inner Loop Header: Depth=2
+; CHECK-NEXT:    testb %al, %al
+; CHECK-NEXT:    js LBB0_55
+; CHECK-NEXT:  ## %bb.30: ## %cond.true.i.i2780
+; CHECK-NEXT:    ## in Loop: Header=BB0_29 Depth=2
+; CHECK-NEXT:    movq %rax, %rbp
+; CHECK-NEXT:    testb %r12b, %r12b
+; CHECK-NEXT:    jne LBB0_32
+; CHECK-NEXT:  ## %bb.31: ## %lor.rhs500
+; CHECK-NEXT:    ## in Loop: Header=BB0_29 Depth=2
+; CHECK-NEXT:    movl $256, %esi ## imm = 0x100
+; CHECK-NEXT:    callq ___maskrune
+; CHECK-NEXT:    testb %r12b, %r12b
+; CHECK-NEXT:    jne LBB0_32
+; CHECK-NEXT:    jmp LBB0_34
+; CHECK-NEXT:  LBB0_45: ## %sw.bb1134
+; CHECK-NEXT:    ## in Loop: Header=BB0_13 Depth=1
+; CHECK-NEXT:    leaq {{[0-9]+}}(%rsp), %rax
+; CHECK-NEXT:    leaq {{[0-9]+}}(%rsp), %rcx
+; CHECK-NEXT:    cmpq %rax, %rcx
+; CHECK-NEXT:    jb LBB0_55
+; CHECK-NEXT:  ## %bb.46: ## in Loop: Header=BB0_13 Depth=1
+; CHECK-NEXT:    movl $0, {{[-0-9]+}}(%r{{[sb]}}p) ## 4-byte Folded Spill
+; CHECK-NEXT:    movl $268, %r14d ## imm = 0x10C
+; CHECK-NEXT:    jmp LBB0_21
+; CHECK-NEXT:  LBB0_40: ## %sw.bb566
+; CHECK-NEXT:    ## in Loop: Header=BB0_13 Depth=1
+; CHECK-NEXT:    movl $20, %r14d
+; CHECK-NEXT:    jmp LBB0_21
+; CHECK-NEXT:  LBB0_19: ## %sw.bb243
+; CHECK-NEXT:    ## in Loop: Header=BB0_13 Depth=1
+; CHECK-NEXT:    movl $2, %r14d
+; CHECK-NEXT:    jmp LBB0_21
+; CHECK-NEXT:  LBB0_33: ## %if.end517.loopexitsplit
+; CHECK-NEXT:    ## in Loop: Header=BB0_13 Depth=1
+; CHECK-NEXT:    incq %rbp
+; CHECK-NEXT:  LBB0_34: ## %if.end517
+; CHECK-NEXT:    ## in Loop: Header=BB0_13 Depth=1
+; CHECK-NEXT:    leal -324(%r14), %eax
+; CHECK-NEXT:    cmpl $59, %eax
+; CHECK-NEXT:    ja LBB0_35
+; CHECK-NEXT:  ## %bb.57: ## %if.end517
+; CHECK-NEXT:    ## in Loop: Header=BB0_13 Depth=1
+; CHECK-NEXT:    movabsq $576460756598390785, %rcx ## imm = 0x800000100000001
+; CHECK-NEXT:    btq %rax, %rcx
+; CHECK-NEXT:    jb LBB0_38
+; CHECK-NEXT:  LBB0_35: ## %if.end517
+; CHECK-NEXT:    ## in Loop: Header=BB0_13 Depth=1
+; CHECK-NEXT:    cmpl $11, %r14d
+; CHECK-NEXT:    je LBB0_38
+; CHECK-NEXT:  ## %bb.36: ## %if.end517
+; CHECK-NEXT:    ## in Loop: Header=BB0_13 Depth=1
+; CHECK-NEXT:    cmpl $24, %r14d
+; CHECK-NEXT:    je LBB0_38
+; CHECK-NEXT:  ## %bb.37: ## %if.then532
+; CHECK-NEXT:    ## in Loop: Header=BB0_13 Depth=1
+; CHECK-NEXT:    movq _SyFgets.yank@GOTPCREL(%rip), %rax
+; CHECK-NEXT:    movb $0, (%rax)
+; CHECK-NEXT:    .p2align 4, 0x90
+; CHECK-NEXT:  LBB0_38: ## %for.cond534
+; CHECK-NEXT:    ## Parent Loop BB0_13 Depth=1
+; CHECK-NEXT:    ## => This Inner Loop Header: Depth=2
+; CHECK-NEXT:    testb %r12b, %r12b
+; CHECK-NEXT:    jne LBB0_38
+; CHECK-NEXT:  ## %bb.39: ## %for.cond542.preheader
+; CHECK-NEXT:    ## in Loop: Header=BB0_13 Depth=1
+; CHECK-NEXT:    testb %r12b, %r12b
+; CHECK-NEXT:    movb $0, (%rbp)
+; CHECK-NEXT:    movl %r13d, %r14d
+; CHECK-NEXT:    leaq LJTI0_0(%rip), %rdx
+; CHECK-NEXT:    jmp LBB0_21
+; CHECK-NEXT:    .p2align 4, 0x90
+; CHECK-NEXT:  LBB0_42: ## %while.cond864
+; CHECK-NEXT:    ## =>This Inner Loop Header: Depth=1
+; CHECK-NEXT:    jmp LBB0_42
+; CHECK-NEXT:    .p2align 4, 0x90
+; CHECK-NEXT:  LBB0_43: ## %while.cond962
+; CHECK-NEXT:    ## =>This Inner Loop Header: Depth=1
+; CHECK-NEXT:    jmp LBB0_43
+; CHECK-NEXT:    .p2align 4, 0x90
+; CHECK-NEXT:  LBB0_25: ## %for.cond357
+; CHECK-NEXT:    ## =>This Inner Loop Header: Depth=1
+; CHECK-NEXT:    jmp LBB0_25
+; CHECK-NEXT:  LBB0_11:
+; CHECK-NEXT:    movl $0, {{[-0-9]+}}(%r{{[sb]}}p) ## 4-byte Folded Spill
+; CHECK-NEXT:    xorl %r14d, %r14d
+; CHECK-NEXT:  LBB0_22: ## %while.end1465
+; CHECK-NEXT:    incl %r14d
+; CHECK-NEXT:    cmpl $16, %r14d
+; CHECK-NEXT:    ja LBB0_50
+; CHECK-NEXT:  ## %bb.23: ## %while.end1465
+; CHECK-NEXT:    movl $83969, %eax ## imm = 0x14801
+; CHECK-NEXT:    btl %r14d, %eax
+; CHECK-NEXT:    jae LBB0_50
+; CHECK-NEXT:  ## %bb.24:
+; CHECK-NEXT:    xorl %ebp, %ebp
+; CHECK-NEXT:    movq {{[-0-9]+}}(%r{{[sb]}}p), %rbx ## 8-byte Reload
+; CHECK-NEXT:  LBB0_48: ## %if.then1477
+; CHECK-NEXT:    movl $1, %edx
+; CHECK-NEXT:    callq _write
+; CHECK-NEXT:    subq %rbp, %rbx
+; CHECK-NEXT:    movq _syHistory@GOTPCREL(%rip), %rax
+; CHECK-NEXT:    leaq 8189(%rbx,%rax), %rax
+; CHECK-NEXT:    .p2align 4, 0x90
+; CHECK-NEXT:  LBB0_49: ## %for.body1723
+; CHECK-NEXT:    ## =>This Inner Loop Header: Depth=1
+; CHECK-NEXT:    decq %rax
+; CHECK-NEXT:    jmp LBB0_49
+; CHECK-NEXT:  LBB0_47: ## %if.then1477.loopexit
+; CHECK-NEXT:    movq {{[-0-9]+}}(%r{{[sb]}}p), %rbx ## 8-byte Reload
+; CHECK-NEXT:    movq %rbx, %rbp
+; CHECK-NEXT:    jmp LBB0_48
+; CHECK-NEXT:  LBB0_16: ## %while.cond635.preheader
+; CHECK-NEXT:    xorl %eax, %eax
+; CHECK-NEXT:    testb %al, %al
+; CHECK-NEXT:    je LBB0_41
+; CHECK-NEXT:    .p2align 4, 0x90
+; CHECK-NEXT:  LBB0_17: ## %for.body643.us
+; CHECK-NEXT:    ## =>This Inner Loop Header: Depth=1
+; CHECK-NEXT:    jmp LBB0_17
+; CHECK-NEXT:    .p2align 4, 0x90
+; CHECK-NEXT:  LBB0_41: ## %while.cond661
+; CHECK-NEXT:    ## =>This Inner Loop Header: Depth=1
+; CHECK-NEXT:    jmp LBB0_41
+; CHECK-NEXT:  LBB0_50: ## %for.cond1480.preheader
+; CHECK-NEXT:    movl $512, %eax ## imm = 0x200
+; CHECK-NEXT:    cmpq %rax, %rax
+; CHECK-NEXT:    jae LBB0_55
+; CHECK-NEXT:  ## %bb.51: ## %for.body1664.lr.ph
+; CHECK-NEXT:    xorl %eax, %eax
+; CHECK-NEXT:    testb %al, %al
+; CHECK-NEXT:    movq {{[-0-9]+}}(%r{{[sb]}}p), %rbx ## 8-byte Reload
+; CHECK-NEXT:    movl {{[-0-9]+}}(%r{{[sb]}}p), %ebp ## 4-byte Reload
+; CHECK-NEXT:    jne LBB0_54
+; CHECK-NEXT:  ## %bb.52: ## %while.body1679.preheader
+; CHECK-NEXT:    incl %ebp
+; CHECK-NEXT:    .p2align 4, 0x90
+; CHECK-NEXT:  LBB0_53: ## %while.body1679
+; CHECK-NEXT:    ## =>This Inner Loop Header: Depth=1
+; CHECK-NEXT:    movq (%rbx), %rdi
+; CHECK-NEXT:    callq _fileno
+; CHECK-NEXT:    movslq %ebp, %rax
+; CHECK-NEXT:    leal 1(%rax), %ebp
+; CHECK-NEXT:    cmpq %rax, %rax
+; CHECK-NEXT:    jl LBB0_53
+; CHECK-NEXT:  LBB0_54: ## %while.cond1683.preheader
+; CHECK-NEXT:    xorl %eax, %eax
+; CHECK-NEXT:    testb %al, %al
+; CHECK-NEXT:  LBB0_55: ## %if.then.i
+; CHECK-NEXT:    ud2
+entry:
+  %sub.ptr.rhs.cast646 = ptrtoint i8* %line to i64
+  %old = alloca [512 x i8], align 16
+  %0 = getelementptr inbounds [512 x i8], [512 x i8]* %old, i64 0, i64 0
+  switch i64 %fid, label %if.then [
+    i64 2, label %if.end
+    i64 0, label %if.end
+  ]
+
+if.then:
+  br label %cleanup
+
+if.end:
+  switch i64 undef, label %if.end25 [
+    i64 0, label %if.then4
+    i64 1, label %if.end25
+  ]
+
+if.then4:
+  br i1 undef, label %SyTime.exit, label %if.then.i
+
+if.then.i:
+  unreachable
+
+SyTime.exit:
+  br i1 undef, label %SyTime.exit2681, label %if.then.i2673
+
+if.then.i2673:
+  unreachable
+
+SyTime.exit2681:
+  br label %cleanup
+
+land.lhs.true14:
+  unreachable
+
+if.end25:
+  br i1 undef, label %SyTime.exit2720, label %if.then.i2712
+
+if.then.i2712:
+  unreachable
+
+SyTime.exit2720:
+  %add.ptr = getelementptr [512 x i8], [512 x i8]* %old, i64 0, i64 512
+  %cmp293427 = icmp ult i8* %0, %add.ptr
+  br i1 %cmp293427, label %for.body.lr.ph, label %while.body.preheader
+
+for.body.lr.ph:
+  call void @llvm.memset.p0i8.i64(i8* align 16 undef, i8 32, i64 512, i1 false)
+  br label %while.body.preheader
+
+while.body.preheader:
+  %add.ptr1603 = getelementptr [512 x i8], [512 x i8]* null, i64 0, i64 512
+  %echo.i3101 = getelementptr [16 x %struct.TMP.1], [16 x %struct.TMP.1]* @syBuf, i64 0, i64 %fid, i32 1
+  %1 = xor i64 %sub.ptr.rhs.cast646, -1
+  br label %do.body
+
+do.body:
+  %ch2.0 = phi i32 [ 0, %while.body.preheader ], [ %ch.12.ch2.12, %do.body ]
+  %rep.0 = phi i32 [ 1, %while.body.preheader ], [ %rep.6, %do.body ]
+  store i32 0, i32* @syCTRO, align 4, !tbaa !1
+  %ch.0.ch2.0 = select i1 undef, i32 14, i32 %ch2.0
+  %ch2.2 = select i1 undef, i32 0, i32 %ch.0.ch2.0
+  %ch.2.ch2.2 = select i1 undef, i32 0, i32 %ch2.2
+  %ch2.4 = select i1 undef, i32 278, i32 %ch.2.ch2.2
+  %ch2.5 = select i1 undef, i32 0, i32 %ch2.4
+  %rep.2 = select i1 undef, i32 undef, i32 %rep.0
+  %ch.5.ch2.5 = select i1 undef, i32 undef, i32 %ch2.5
+  %ch2.7 = select i1 undef, i32 0, i32 %ch.5.ch2.5
+  %rep.3 = select i1 undef, i32 undef, i32 %rep.2
+  %ch.7.ch2.7 = select i1 false, i32 0, i32 %ch2.7
+  %mul98.rep.3 = select i1 false, i32 0, i32 %rep.3
+  %ch2.9 = select i1 undef, i32 undef, i32 %ch.7.ch2.7
+  %rep.5 = select i1 undef, i32 undef, i32 %mul98.rep.3
+  %ch2.10 = select i1 false, i32 undef, i32 %ch2.9
+  %rep.6 = select i1 false, i32 undef, i32 %rep.5
+  %isdigittmp = add i32 %ch2.10, -48
+  %isdigit = icmp ult i32 %isdigittmp, 10
+  %cmp119 = icmp eq i32 undef, 22
+  %or.cond1875 = and i1 %isdigit, %cmp119
+  %ch.10.ch2.10 = select i1 %or.cond1875, i32 undef, i32 %ch2.10
+  %.ch.10 = select i1 %or.cond1875, i32 0, i32 undef
+  %ch2.12 = select i1 undef, i32 %.ch.10, i32 %ch.10.ch2.10
+  %ch.12 = select i1 undef, i32 0, i32 %.ch.10
+  %ch.12.ch2.12 = select i1 false, i32 %ch.12, i32 %ch2.12
+  %.ch.12 = select i1 false, i32 0, i32 %ch.12
+  %cmp147 = icmp eq i32 %.ch.12, 0
+  br i1 %cmp147, label %do.body, label %do.end
+
+do.end:
+  %cmp164 = icmp eq i32 %ch.12.ch2.12, 21
+  %mul167 = shl i32 %rep.6, 2
+  %rep.8 = select i1 %cmp164, i32 %mul167, i32 %rep.6
+  %..ch.19 = select i1 false, i32 2, i32 0
+  br i1 undef, label %while.body200, label %while.end1465
+
+while.body200:
+  %dec3386.in = phi i32 [ %dec3386, %while.cond197.backedge ], [ %rep.8, %do.end ]
+  %oldc.13384 = phi i32 [ %oldc.1.be, %while.cond197.backedge ], [ 0, %do.end ]
+  %ch.213379 = phi i32 [ %last.1.be, %while.cond197.backedge ], [ %..ch.19, %do.end ]
+  %last.13371 = phi i32 [ %last.1.be, %while.cond197.backedge ], [ 0, %do.end ]
+  %dec3386 = add i32 %dec3386.in, -1
+  switch i32 %ch.213379, label %sw.default [
+    i32 1, label %while.cond201.preheader
+    i32 322, label %sw.bb206
+    i32 354, label %sw.bb206
+    i32 2, label %sw.bb243
+    i32 364, label %sw.bb1077
+    i32 326, label %sw.bb256
+    i32 358, label %sw.bb256
+    i32 341, label %sw.bb979
+    i32 323, label %while.cond1037.preheader
+    i32 373, label %sw.bb979
+    i32 4, label %if.then1477
+    i32 332, label %sw.bb1077
+    i32 11, label %for.cond357
+    i32 355, label %while.cond1037.preheader
+    i32 324, label %sw.bb474
+    i32 356, label %sw.bb474
+    i32 20, label %sw.bb566
+    i32 -1, label %while.cond197.backedge
+    i32 268, label %sw.bb1134
+    i32 16, label %while.cond635.preheader
+    i32 18, label %sw.bb956
+    i32 316, label %while.cond864
+  ]
+
+while.cond1037.preheader:
+  %cmp10393273 = icmp eq i8 undef, 0
+  br i1 %cmp10393273, label %if.end1070, label %land.rhs1041
+
+while.cond635.preheader:
+  br i1 undef, label %for.body643.us, label %while.cond661
+
+for.body643.us:
+  br label %for.body643.us
+
+while.cond201.preheader:
+  %umax = select i1 false, i64 undef, i64 %1
+  %2 = xor i64 %umax, -1
+  %3 = inttoptr i64 %2 to i8*
+  br label %while.cond197.backedge
+
+sw.bb206:
+  br label %while.cond197.backedge
+
+sw.bb243:
+  br label %while.cond197.backedge
+
+sw.bb256:
+  br label %while.cond197.backedge
+
+while.cond197.backedge:
+  %last.1.be = phi i32 [ %ch.213379, %sw.default ], [ -1, %while.body200 ], [ %ch.213379, %sw.bb1077 ], [ %ch.213379, %sw.bb979 ], [ 18, %sw.bb956 ], [ 20, %sw.bb566 ], [ %ch.213379, %for.end552 ], [ %ch.213379, %sw.bb256 ], [ 2, %sw.bb243 ], [ 1, %while.cond201.preheader ], [ 268, %for.cond1145.preheader ], [ %ch.213379, %sw.bb206 ]
+  %oldc.1.be = phi i32 [ %oldc.13384, %sw.default ], [ %oldc.13384, %while.body200 ], [ %oldc.13384, %sw.bb1077 ], [ %oldc.13384, %sw.bb979 ], [ %oldc.13384, %sw.bb956 ], [ %oldc.13384, %sw.bb566 ], [ %oldc.13384, %for.end552 ], [ %oldc.13384, %sw.bb256 ], [ %oldc.13384, %sw.bb243 ], [ %oldc.13384, %while.cond201.preheader ], [ 0, %for.cond1145.preheader ], [ %oldc.13384, %sw.bb206 ]
+  %cmp198 = icmp sgt i32 %dec3386, 0
+  br i1 %cmp198, label %while.body200, label %while.end1465
+
+for.cond357:
+  br label %for.cond357
+
+sw.bb474:
+  ; spill is hoisted here. Although loop depth1 is even hotter than loop depth2, sw.bb474 is still cold.
+  %cmp476 = icmp eq i8 undef, 0
+  br i1 %cmp476, label %if.end517, label %do.body479.preheader
+
+do.body479.preheader:
+  %cmp4833314 = icmp eq i8 undef, 0
+  br i1 %cmp4833314, label %if.end517, label %land.rhs485
+
+land.rhs485:
+  %incdec.ptr4803316 = phi i8* [ %incdec.ptr480, %do.body479.backedge.land.rhs485_crit_edge ], [ undef, %do.body479.preheader ]
+  %isascii.i.i27763151 = icmp sgt i8 undef, -1
+  br i1 %isascii.i.i27763151, label %cond.true.i.i2780, label %cond.false.i.i2782
+
+cond.true.i.i2780:
+  br i1 undef, label %land.lhs.true490, label %lor.rhs500
+
+cond.false.i.i2782:
+  unreachable
+
+land.lhs.true490:
+  br i1 false, label %lor.rhs500, label %do.body479.backedge
+
+lor.rhs500:
+  ; Make sure spill is hoisted to a cold preheader in outside loop.
+  %call3.i.i2792 = call i32 @__maskrune(i32 undef, i64 256)
+  br i1 undef, label %land.lhs.true504, label %do.body479.backedge
+
+land.lhs.true504:
+  br i1 undef, label %do.body479.backedge, label %if.end517
+
+do.body479.backedge:
+  %incdec.ptr480 = getelementptr i8, i8* %incdec.ptr4803316, i64 1
+  %cmp483 = icmp eq i8 undef, 0
+  br i1 %cmp483, label %if.end517, label %do.body479.backedge.land.rhs485_crit_edge
+
+do.body479.backedge.land.rhs485_crit_edge:
+  br label %land.rhs485
+
+if.end517:
+  %q.4 = phi i8* [ undef, %sw.bb474 ], [ undef, %do.body479.preheader ], [ %incdec.ptr480, %do.body479.backedge ], [ %incdec.ptr4803316, %land.lhs.true504 ]
+  switch i32 %last.13371, label %if.then532 [
+    i32 383, label %for.cond534
+    i32 356, label %for.cond534
+    i32 324, label %for.cond534
+    i32 24, label %for.cond534
+    i32 11, label %for.cond534
+  ]
+
+if.then532:
+  store i8 0, i8* getelementptr inbounds ([512 x i8], [512 x i8]* @SyFgets.yank, i64 0, i64 0), align 16, !tbaa !5
+  br label %for.cond534
+
+for.cond534:
+  %cmp536 = icmp eq i8 undef, 0
+  br i1 %cmp536, label %for.cond542.preheader, label %for.cond534
+
+for.cond542.preheader:
+  br i1 undef, label %for.body545, label %for.end552
+
+for.body545:
+  br i1 undef, label %for.end552, label %for.body545
+
+for.end552:
+  %s.2.lcssa = phi i8* [ undef, %for.cond542.preheader ], [ %q.4, %for.body545 ]
+  %sub.ptr.lhs.cast553 = ptrtoint i8* %s.2.lcssa to i64
+  %sub.ptr.sub555 = sub i64 %sub.ptr.lhs.cast553, 0
+  %arrayidx556 = getelementptr i8, i8* null, i64 %sub.ptr.sub555
+  store i8 0, i8* %arrayidx556, align 1, !tbaa !5
+  br label %while.cond197.backedge
+
+sw.bb566:
+  br label %while.cond197.backedge
+
+while.cond661:
+  br label %while.cond661
+
+while.cond864:
+  br label %while.cond864
+
+sw.bb956:
+  br i1 undef, label %if.then959, label %while.cond197.backedge
+
+if.then959:
+  br label %while.cond962
+
+while.cond962:
+  br label %while.cond962
+
+sw.bb979:
+  br label %while.cond197.backedge
+
+land.rhs1041:
+  unreachable
+
+if.end1070:
+  br label %sw.bb1077
+
+sw.bb1077:
+  br label %while.cond197.backedge
+
+sw.bb1134:
+  br i1 false, label %for.body1139, label %for.cond1145.preheader
+
+for.cond1145.preheader:
+  br i1 %cmp293427, label %for.body1150.lr.ph, label %while.cond197.backedge
+
+for.body1150.lr.ph:
+  unreachable
+
+for.body1139:
+  unreachable
+
+sw.default:
+  br label %while.cond197.backedge
+
+while.end1465:
+  %oldc.1.lcssa = phi i32 [ 0, %do.end ], [ %oldc.1.be, %while.cond197.backedge ]
+  %ch.21.lcssa = phi i32 [ %..ch.19, %do.end ], [ %last.1.be, %while.cond197.backedge ]
+  switch i32 %ch.21.lcssa, label %for.cond1480.preheader [
+    i32 -1, label %if.then1477
+    i32 15, label %if.then1477
+    i32 13, label %if.then1477
+    i32 10, label %if.then1477
+  ]
+
+for.cond1480.preheader:
+  br i1 undef, label %for.body1606.lr.ph, label %for.end1609
+
+if.then1477:
+  %p.1.lcssa3539 = phi i8* [ null, %while.end1465 ], [ null, %while.end1465 ], [ null, %while.end1465 ], [ null, %while.end1465 ], [ %line, %while.body200 ]
+  %call1.i3057 = call i64 @"\01_write"(i32 undef, i8* undef, i64 1)
+  %sub.ptr.lhs.cast1717 = ptrtoint i8* %p.1.lcssa3539 to i64
+  %sub.ptr.sub1719 = sub i64 %sub.ptr.lhs.cast1717, %sub.ptr.rhs.cast646
+  %idx.neg1727 = sub i64 0, %sub.ptr.sub1719
+  br label %for.body1723
+
+for.body1606.lr.ph:
+  br label %for.end1609
+
+for.end1609:
+  br i1 undef, label %for.cond1659.preheader, label %land.lhs.true1614
+
+land.lhs.true1614:
+  br label %for.cond1659.preheader
+
+for.cond1659.preheader:
+  %cmp16623414 = icmp ult i8* undef, %add.ptr1603
+  br i1 %cmp16623414, label %for.body1664.lr.ph, label %while.body1703.lr.ph
+
+for.body1664.lr.ph:
+  %cmp16773405 = icmp slt i64 undef, undef
+  br i1 %cmp16773405, label %while.body1679, label %while.cond1683.preheader
+
+while.body1703.lr.ph:
+  unreachable
+
+while.cond1683.preheader:
+  br i1 undef, label %while.body1691, label %while.end1693
+
+while.body1679:
+  %oldc.43406 = phi i32 [ %inc, %syEchoch.exit3070 ], [ %oldc.1.lcssa, %for.body1664.lr.ph ]
+  %4 = load %struct.TMP.2*, %struct.TMP.2** %echo.i3101, align 8, !tbaa !6
+  %call.i3062 = call i32 @fileno(%struct.TMP.2* %4)
+  br i1 undef, label %if.then.i3069, label %syEchoch.exit3070
+
+if.then.i3069:
+  br label %syEchoch.exit3070
+
+syEchoch.exit3070:
+  %inc = add i32 %oldc.43406, 1
+  %conv1672 = sext i32 %inc to i64
+  %cmp1677 = icmp slt i64 %conv1672, undef
+  br i1 %cmp1677, label %while.body1679, label %while.cond1683.preheader
+
+while.body1691:
+  unreachable
+
+while.end1693:
+  unreachable
+
+for.body1723:
+  %q.303203 = phi i8* [ getelementptr inbounds ([8192 x i8], [8192 x i8]* @syHistory, i64 0, i64 8189), %if.then1477 ], [ %incdec.ptr1730, %for.body1723 ]
+  %add.ptr1728 = getelementptr i8, i8* %q.303203, i64 %idx.neg1727
+  %5 = load i8, i8* %add.ptr1728, align 1, !tbaa !5
+  %incdec.ptr1730 = getelementptr i8, i8* %q.303203, i64 -1
+  br label %for.body1723
+
+cleanup:
+  ret i8* undef
+}
+
+declare i32 @fileno(%struct.TMP.2* nocapture)
+declare i64 @"\01_write"(i32, i8*, i64)
+declare i32 @__maskrune(i32, i64)
+declare void @llvm.memset.p0i8.i64(i8* nocapture, i8, i64, i1)
+
+!llvm.ident = !{!0}
+
+!0 = !{!"clang version 3.5.0 (trunk 204257)"}
+!1 = !{!2, !2, i64 0}
+!2 = !{!"int", !3, i64 0}
+!3 = !{!"omnipotent char", !4, i64 0}
+!4 = !{!"Simple C/C++ TBAA"}
+!5 = !{!3, !3, i64 0}
+!6 = !{!7, !8, i64 8}
+!7 = !{!"", !8, i64 0, !8, i64 8, !3, i64 16}
+!8 = !{!"any pointer", !3, i64 0}

diff  --git a/llvm/test/CodeGen/MLRegalloc/dev-mode-log-2-fcts.ll b/llvm/test/CodeGen/MLRegalloc/dev-mode-log-2-fcts.ll
new file mode 100644
index 0000000000000..c22d95d698a42
--- /dev/null
+++ b/llvm/test/CodeGen/MLRegalloc/dev-mode-log-2-fcts.ll
@@ -0,0 +1,58 @@
+; REQUIRES: have_tf_api
+; REQUIRES: x86_64-linux
+;
+; Check that we can log more than 1 function.
+;
+; RUN: llc -mtriple=x86_64-linux-unknown -regalloc=greedy -regalloc-enable-advisor=development \
+; RUN:   -regalloc-training-log=%t1 -tfutils-text-log < %s
+; RUN: sed -i 's/ \+/ /g' %t1
+; RUN: sed -i 's/\\n key:/\n key:/g' %t1
+; RUN: sed -i 's/\\n feature/\n feature/g' %t1
+; RUN: sed -i 's/\\n/ /g' %t1
+; RUN: FileCheck --input-file %t1 %s
+
+; RUN: rm -rf %t && mkdir %t
+; RUN: %python %S/../../../lib/Analysis/models/gen-regalloc-eviction-test-model.py %t
+; RUN: llc -mtriple=x86_64-linux-unknown -regalloc=greedy -regalloc-enable-advisor=development \
+; RUN:   -regalloc-training-log=%t2 -tfutils-text-log -regalloc-model=%t < %s
+; RUN: sed -i 's/ \+/ /g' %t2
+; RUN: sed -i 's/\\n key:/\n key:/g' %t2
+; RUN: sed -i 's/\\n feature/\n feature/g' %t2
+; RUN: sed -i 's/\\n/ /g' %t2
+; RUN: FileCheck --input-file %t2 %s
+
+declare void @f();
+
+define void @f1(i64 %lhs, i64 %rhs, i64* %addr) !prof !15 {
+  %sum = add i64 %lhs, %rhs
+  call void @f();
+  store i64 %sum, i64* %addr
+  ret void
+}
+
+define void @f2(i64 %lhs, i64 %rhs, i64* %addr) !prof !16 {
+  %sum = add i64 %lhs, %rhs
+  store i64 %sum, i64* %addr
+  ret void
+}
+
+; CHECK:  key: "f1"
+; CHECK:  key: "f2"
+
+!llvm.module.flags = !{!1}
+!1 = !{i32 1, !"ProfileSummary", !2}
+!2 = !{!3, !4, !5, !6, !7, !8, !9, !10}
+!3 = !{!"ProfileFormat", !"InstrProf"}
+!4 = !{!"TotalCount", i64 10000}
+!5 = !{!"MaxCount", i64 10}
+!6 = !{!"MaxInternalCount", i64 1}
+!7 = !{!"MaxFunctionCount", i64 1000}
+!8 = !{!"NumCounts", i64 3}
+!9 = !{!"NumFunctions", i64 3}
+!10 = !{!"DetailedSummary", !11}
+!11 = !{!12, !13, !14}
+!12 = !{i32 10000, i64 100, i32 1}
+!13 = !{i32 999000, i64 100, i32 1}
+!14 = !{i32 999999, i64 1, i32 2}
+!15 = !{!"function_entry_count", i64 1}
+!16 = !{!"function_entry_count", i64 1000}

diff  --git a/llvm/test/CodeGen/MLRegalloc/dev-mode-logging.ll b/llvm/test/CodeGen/MLRegalloc/dev-mode-logging.ll
new file mode 100644
index 0000000000000..995684959c6c8
--- /dev/null
+++ b/llvm/test/CodeGen/MLRegalloc/dev-mode-logging.ll
@@ -0,0 +1,33 @@
+; REQUIRES: have_tf_api
+; REQUIRES: x86_64-linux
+;
+; Check that we log correctly, both with a learned policy, and the default policy
+;
+; RUN: llc -mtriple=x86_64-linux-unknown -regalloc=greedy -regalloc-enable-advisor=development \
+; RUN:   -regalloc-training-log=%t1 -tfutils-text-log < %S/Inputs/input.ll
+; RUN: sed -i 's/ \+/ /g' %t1
+; RUN: sed -i 's/\\n key:/\n key:/g' %t1
+; RUN: sed -i 's/\\n feature/\n feature/g' %t1
+; RUN: sed -i 's/\\n/ /g' %t1
+; RUN: FileCheck --input-file %t1 %s --check-prefixes=CHECK,NOML
+
+; RUN: rm -rf %t && mkdir %t
+; RUN: %python %S/../../../lib/Analysis/models/gen-regalloc-eviction-test-model.py %t
+; RUN: llc -mtriple=x86_64-linux-unknown -regalloc=greedy -regalloc-enable-advisor=development \
+; RUN:   -regalloc-training-log=%t2 -tfutils-text-log -regalloc-model=%t < %S/Inputs/input.ll
+; RUN: sed -i 's/ \+/ /g' %t2
+; RUN: sed -i 's/\\n key:/\n key:/g' %t2
+; RUN: sed -i 's/\\n feature/\n feature/g' %t2
+; RUN: sed -i 's/\\n/ /g' %t2
+; RUN: FileCheck --input-file %t2 %s --check-prefixes=CHECK,ML
+
+; CHECK-NOT: nan
+; CHECK-LABEL: key: \"index_to_evict\"
+; CHECK-NEXT: value: 9
+; ML-NEXT:    value: 9
+; NOML-NEXT:  value: 32
+; CHECK-LABEL: key: \"reward\"
+; ML:   value: 37.73
+; NOML: value: 37.47
+; CHECK-NEXT: feature_list
+; CHECK-NEXT: key: \"start_bb_freq_by_max\"

diff  --git a/llvm/test/CodeGen/MLRegalloc/dev-rel-equivalence.ll b/llvm/test/CodeGen/MLRegalloc/dev-rel-equivalence.ll
new file mode 100644
index 0000000000000..0dddc899f1db5
--- /dev/null
+++ b/llvm/test/CodeGen/MLRegalloc/dev-rel-equivalence.ll
@@ -0,0 +1,19 @@
+; REQUIRES: have_tf_api
+; REQUIRES: have_tf_aot
+; REQUIRES: x86_64-linux
+;
+; Check that the same model (==the autogenerated one) produces the same output
+; regardless of how it's evaluated, which is different from the default
+;
+; RUN: llc -mtriple=x86_64-linux-unknown -regalloc=greedy -regalloc-enable-advisor=default \
+; RUN:   %S/Inputs/input.ll -o %t.default
+
+; RUN: llc -mtriple=x86_64-linux-unknown -regalloc=greedy -regalloc-enable-advisor=release \
+; RUN:   %S/Inputs/input.ll -o %t.release
+
+; RUN: rm -rf %t && mkdir %t
+; RUN: %python %S/../../../lib/Analysis/models/gen-regalloc-eviction-test-model.py %t
+; RUN: llc -mtriple=x86_64-linux-unknown -regalloc=greedy -regalloc-enable-advisor=development \
+; RUN:   -regalloc-model=%t %S/Inputs/input.ll -o %t.development
+; RUN: diff %t.release %t.development
+; RUN: not diff %t.release %t.default

diff  --git a/llvm/test/CodeGen/MLRegalloc/rel-codepath.ll b/llvm/test/CodeGen/MLRegalloc/rel-codepath.ll
new file mode 100644
index 0000000000000..6547e91d42a68
--- /dev/null
+++ b/llvm/test/CodeGen/MLRegalloc/rel-codepath.ll
@@ -0,0 +1,15 @@
+; REQUIRES: have_tf_aot
+; REQUIRES: x86_64-linux
+;
+; Check the code path for release mode is correctly taken. It is shared with
+; development mode, and we separately test the internals of that (logged
+; features, etc), so all we care about here is that the output is produced and
+; is different from default policy.
+;
+; RUN: llc -mtriple=x86_64-linux-unknown -regalloc=greedy -regalloc-enable-advisor=default \
+; RUN:   %S/Inputs/input.ll -o %t.default
+
+; RUN: llc -mtriple=x86_64-linux-unknown -regalloc=greedy -regalloc-enable-advisor=release \
+; RUN:   %S/Inputs/input.ll -o %t.release
+
+; RUN: not diff %t.release %t.default

diff  --git a/llvm/test/CodeGen/PowerPC/O3-pipeline.ll b/llvm/test/CodeGen/PowerPC/O3-pipeline.ll
index dc0f7c95d8273..d64947f5f97a2 100644
--- a/llvm/test/CodeGen/PowerPC/O3-pipeline.ll
+++ b/llvm/test/CodeGen/PowerPC/O3-pipeline.ll
@@ -160,6 +160,7 @@
 ; CHECK-NEXT:       Machine Optimization Remark Emitter
 ; CHECK-NEXT:       Greedy Register Allocator
 ; CHECK-NEXT:       Virtual Register Rewriter
+; CHECK-NEXT:       Register Allocation Pass Scoring
 ; CHECK-NEXT:       Stack Slot Coloring
 ; CHECK-NEXT:       Machine Copy Propagation Pass
 ; CHECK-NEXT:       Machine Loop Invariant Code Motion

diff  --git a/llvm/test/CodeGen/X86/opt-pipeline.ll b/llvm/test/CodeGen/X86/opt-pipeline.ll
index 3fbdc8d1a4e7a..49a2829833995 100644
--- a/llvm/test/CodeGen/X86/opt-pipeline.ll
+++ b/llvm/test/CodeGen/X86/opt-pipeline.ll
@@ -144,6 +144,7 @@
 ; CHECK-NEXT:       Greedy Register Allocator
 ; CHECK-NEXT:       Tile Register Configure
 ; CHECK-NEXT:       Virtual Register Rewriter
+; CHECK-NEXT:       Register Allocation Pass Scoring
 ; CHECK-NEXT:       Stack Slot Coloring
 ; CHECK-NEXT:       Machine Copy Propagation Pass
 ; CHECK-NEXT:       Machine Loop Invariant Code Motion


        


More information about the llvm-commits mailing list