[llvm] Adding Matching and Inference Functionality to Propeller-PR4: Implement matching and inference and create clusters (PR #165868)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Nov 5 07:06:20 PST 2025
================
@@ -0,0 +1,187 @@
+//===- llvm/CodeGen/BasicBlockMatchingAndInference.cpp ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Infer weights for all basic blocks using matching and inference.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/CodeGen/BasicBlockMatchingAndInference.h"
+#include "llvm/CodeGen/BasicBlockSectionsProfileReader.h"
+#include "llvm/CodeGen/MachineBlockHashInfo.h"
+#include "llvm/CodeGen/Passes.h"
+#include "llvm/InitializePasses.h"
+#include <llvm/Support/CommandLine.h>
+
+using namespace llvm;
+
+/// Float command-line option registered as "propeller-infer-threshold", with a
+/// default value of 0.6.
+/// NOTE(review): presumably a cutoff consulted when inferring weights from a
+/// stale profile (per the option description); its consumer is not visible
+/// here -- confirm against the code that reads it.
+static cl::opt<float>
+ PropellerInferThreshold("propeller-infer-threshold",
+ cl::desc("Threshold for infer stale profile"),
+ cl::init(0.6), cl::Optional);
+
+namespace {
+
+/// Identifies and matches basic blocks given their hashes.
+///
+/// Blocks are bucketed by the opcode component of their BlendedBlockHash; a
+/// query returns the entry of the matching bucket whose hash has the smallest
+/// BlendedBlockHash::distance to the queried hash.
+///
+/// The class is a helper local to this file, so it lives in an anonymous
+/// namespace to give it internal linkage.
+class StaleMatcher {
+public:
+  /// Initialize stale matcher.
+  ///
+  /// \param Blocks blocks to make available for matching.
+  /// \param Hashes hash of each block; Hashes[I] belongs to Blocks[I], so both
+  ///        vectors must have the same size (asserted).
+  void init(const std::vector<MachineBasicBlock *> &Blocks,
+            const std::vector<BlendedBlockHash> &Hashes) {
+    assert(Blocks.size() == Hashes.size() &&
+           "incorrect matcher initialization");
+    for (size_t I = 0; I < Blocks.size(); I++) {
+      uint16_t OpHash = Hashes[I].getOpcodeHash();
+      OpHashToBlocks[OpHash].emplace_back(Hashes[I], Blocks[I]);
+    }
+  }
+
+  /// Find the most similar block for a given hash.
+  ///
+  /// \returns the registered block with the same opcode hash whose blended
+  /// hash is closest to \p BlendedHash (the first one registered wins ties),
+  /// or nullptr when no block shares the opcode hash.
+  MachineBasicBlock *matchBlock(BlendedBlockHash BlendedHash) const {
+    auto BucketIt = OpHashToBlocks.find(BlendedHash.getOpcodeHash());
+    if (BucketIt == OpHashToBlocks.end())
+      return nullptr;
+
+    MachineBasicBlock *BestBlock = nullptr;
+    uint64_t BestDist = std::numeric_limits<uint64_t>::max();
+    // Bind by const reference so the candidates are not copied per iteration.
+    for (const auto &[Hash, Block] : BucketIt->second) {
+      uint64_t Dist = Hash.distance(BlendedHash);
+      // The null check lets a candidate at the maximum distance still be
+      // selected when it is the first one seen.
+      if (BestBlock == nullptr || Dist < BestDist) {
+        BestDist = Dist;
+        BestBlock = Block;
+      }
+    }
+    return BestBlock;
+  }
+
+private:
+  using HashBlockPairType = std::pair<BlendedBlockHash, MachineBasicBlock *>;
+  /// Candidate (hash, block) pairs keyed by the opcode hash of the block.
+  std::unordered_map<uint16_t, std::vector<HashBlockPairType>> OpHashToBlocks;
+};
+
+} // end anonymous namespace
+
+// Pass registration: registers BasicBlockMatchingAndInference under the
+// argument string "machine-block-match-infer" and lists MachineBlockHashInfo
+// and BasicBlockSectionsProfileReaderWrapperPass as passes it depends on.
+// NOTE(review): the two trailing `true` arguments are presumably the CFG-only
+// and is-analysis flags of the INITIALIZE_PASS macros -- confirm.
+INITIALIZE_PASS_BEGIN(BasicBlockMatchingAndInference,
+ "machine-block-match-infer",
+ "Machine Block Matching and Inference Analysis", true,
+ true)
+INITIALIZE_PASS_DEPENDENCY(MachineBlockHashInfo)
+INITIALIZE_PASS_DEPENDENCY(BasicBlockSectionsProfileReaderWrapperPass)
+INITIALIZE_PASS_END(BasicBlockMatchingAndInference, "machine-block-match-infer",
+ "Machine Block Matching and Inference Analysis", true, true)
+
+// Static pass ID; handed to the MachineFunctionPass base by the constructor.
+char BasicBlockMatchingAndInference::ID = 0;
+
+/// Constructs the pass with its static ID and runs the pass initializer
+/// against the registry returned by PassRegistry::getPassRegistry().
+BasicBlockMatchingAndInference::BasicBlockMatchingAndInference()
+    : MachineFunctionPass(ID) {
+  PassRegistry &Registry = *PassRegistry::getPassRegistry();
+  initializeBasicBlockMatchingAndInferencePass(Registry);
+}
+
+/// Declares the analyses this pass requires and that it preserves all
+/// analyses.
+///
+/// \param AU usage record to fill in.
+void BasicBlockMatchingAndInference::getAnalysisUsage(AnalysisUsage &AU) const {
+ // Both the block hashes and the basic-block-sections profile reader are
+ // required inputs.
+ AU.addRequired<MachineBlockHashInfo>();
+ AU.addRequired<BasicBlockSectionsProfileReaderWrapperPass>();
+ AU.setPreservesAll();
+ // Forward to the base class so it can record its own usage as well.
+ MachineFunctionPass::getAnalysisUsage(AU);
+}
+
+/// Looks up the weight information stored for \p FuncName.
+///
+/// \returns a copy of the stored WeightInfo, or std::nullopt when
+/// ProgramWeightInfo has no entry for that name.
+std::optional<BasicBlockMatchingAndInference::WeightInfo>
+BasicBlockMatchingAndInference::getWeightInfo(StringRef FuncName) const {
+  if (auto Entry = ProgramWeightInfo.find(FuncName);
+      Entry != ProgramWeightInfo.end())
+    return Entry->second;
+  return std::nullopt;
+}
+
+BasicBlockMatchingAndInference::WeightInfo
+BasicBlockMatchingAndInference::initWeightInfoByMatching(MachineFunction &MF) {
+ std::vector<MachineBasicBlock *> Blocks;
+ std::vector<BlendedBlockHash> Hashes;
+ auto BSPR = &getAnalysis<BasicBlockSectionsProfileReaderWrapperPass>();
+ auto MBHI = &getAnalysis<MachineBlockHashInfo>();
+ for (auto &Block : MF) {
+ Blocks.push_back(&Block);
+ Hashes.push_back(BlendedBlockHash(MBHI->getMBBHash(Block)));
+ }
+ StaleMatcher Matcher;
+ Matcher.init(Blocks, Hashes);
+ BasicBlockMatchingAndInference::WeightInfo MatchWeight;
+ auto [Flag, PathAndClusterInfo] =
----------------
spupyrev wrote:
It's unclear what this `Flag` is. Can you find a more descriptive name?
https://github.com/llvm/llvm-project/pull/165868
More information about the llvm-commits
mailing list