[llvm] [SelectionDAG] Add space-optimized forms of OPC_CheckPredicate (PR #77763)
Wang Pengcheng via llvm-commits
llvm-commits at lists.llvm.org
Thu Jan 11 04:59:44 PST 2024
https://github.com/wangpc-pp created https://github.com/llvm/llvm-project/pull/77763
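We record how often each `Predicate` is used and sort the predicates by usage.
For the 8 most frequently used `Predicate`s, we emit an `OPC_CheckPredicateN`
opcode, which encodes the predicate index in the opcode itself and so saves
one byte per check.
Overall this reduces the size of the llc binary built with all in-tree targets
by about 61K.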
This is a recommit of 1a57927, which was reverted in bc98c31.
The CI failures occurred in builds with expensive checks enabled
(`LLVM_ENABLE_EXPENSIVE_CHECKS=ON`).
The key point is that the tests need a deterministic sort order. With
expensive checks enabled, `llvm::sort` shuffles its input before sorting, and
this exposed that entries with equal usage counts could end up in any order.
This revised patch therefore replaces `llvm::sort` with `llvm::stable_sort`.
>From 669bc5e1a4a16b2c6d126cac23390935bb7b386e Mon Sep 17 00:00:00 2001
From: wangpc <wangpengcheng.pp at bytedance.com>
Date: Thu, 11 Jan 2024 19:47:49 +0800
Subject: [PATCH] [SelectionDAG] Add space-optimized forms of
OPC_CheckPredicate
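We record how often each `Predicate` is used and sort the predicates by usage.
For the 8 most frequently used `Predicate`s, we emit an `OPC_CheckPredicateN`
opcode, which encodes the predicate index in the opcode itself and so saves
one byte per check.
Overall this reduces the size of the llc binary built with all in-tree targets
by about 61K.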
This is a recommit of 1a57927, which was reverted in bc98c31.
The CI failures occurred in builds with expensive checks enabled
(`LLVM_ENABLE_EXPENSIVE_CHECKS=ON`).
The key point is that the tests need a deterministic sort order. With
expensive checks enabled, `llvm::sort` shuffles its input before sorting, and
this exposed that entries with equal usage counts could end up in any order.
This revised patch therefore replaces `llvm::sort` with `llvm::stable_sort`.
---
llvm/include/llvm/CodeGen/SelectionDAGISel.h | 8 ++
.../CodeGen/SelectionDAG/SelectionDAGISel.cpp | 30 ++++-
llvm/test/TableGen/address-space-patfrags.td | 4 +-
llvm/test/TableGen/predicate-patfags.td | 4 +-
llvm/utils/TableGen/DAGISelMatcherEmitter.cpp | 104 +++++++++++-------
5 files changed, 101 insertions(+), 49 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGISel.h b/llvm/include/llvm/CodeGen/SelectionDAGISel.h
index e4d90f6e898fe8..dbd9b391f4a431 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGISel.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGISel.h
@@ -169,6 +169,14 @@ class SelectionDAGISel : public MachineFunctionPass {
OPC_CheckPatternPredicate7,
OPC_CheckPatternPredicateTwoByte,
OPC_CheckPredicate,
+ OPC_CheckPredicate0,
+ OPC_CheckPredicate1,
+ OPC_CheckPredicate2,
+ OPC_CheckPredicate3,
+ OPC_CheckPredicate4,
+ OPC_CheckPredicate5,
+ OPC_CheckPredicate6,
+ OPC_CheckPredicate7,
OPC_CheckPredicateWithOperands,
OPC_CheckOpcode,
OPC_SwitchOpcode,
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
index 678d273e4bd605..359d738d2ca09f 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
@@ -2712,9 +2712,13 @@ CheckPatternPredicate(unsigned Opcode, const unsigned char *MatcherTable,
/// CheckNodePredicate - Implements OP_CheckNodePredicate.
LLVM_ATTRIBUTE_ALWAYS_INLINE static bool
-CheckNodePredicate(const unsigned char *MatcherTable, unsigned &MatcherIndex,
- const SelectionDAGISel &SDISel, SDNode *N) {
- return SDISel.CheckNodePredicate(N, MatcherTable[MatcherIndex++]);
+CheckNodePredicate(unsigned Opcode, const unsigned char *MatcherTable,
+ unsigned &MatcherIndex, const SelectionDAGISel &SDISel,
+ SDNode *N) {
+ unsigned PredNo = Opcode == SelectionDAGISel::OPC_CheckPredicate
+ ? MatcherTable[MatcherIndex++]
+ : Opcode - SelectionDAGISel::OPC_CheckPredicate0;
+ return SDISel.CheckNodePredicate(N, PredNo);
}
LLVM_ATTRIBUTE_ALWAYS_INLINE static bool
@@ -2868,7 +2872,15 @@ static unsigned IsPredicateKnownToFail(const unsigned char *Table,
Result = !::CheckPatternPredicate(Opcode, Table, Index, SDISel);
return Index;
case SelectionDAGISel::OPC_CheckPredicate:
- Result = !::CheckNodePredicate(Table, Index, SDISel, N.getNode());
+ case SelectionDAGISel::OPC_CheckPredicate0:
+ case SelectionDAGISel::OPC_CheckPredicate1:
+ case SelectionDAGISel::OPC_CheckPredicate2:
+ case SelectionDAGISel::OPC_CheckPredicate3:
+ case SelectionDAGISel::OPC_CheckPredicate4:
+ case SelectionDAGISel::OPC_CheckPredicate5:
+ case SelectionDAGISel::OPC_CheckPredicate6:
+ case SelectionDAGISel::OPC_CheckPredicate7:
+ Result = !::CheckNodePredicate(Opcode, Table, Index, SDISel, N.getNode());
return Index;
case SelectionDAGISel::OPC_CheckOpcode:
Result = !::CheckOpcode(Table, Index, N.getNode());
@@ -3359,8 +3371,16 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch,
if (!::CheckPatternPredicate(Opcode, MatcherTable, MatcherIndex, *this))
break;
continue;
+ case SelectionDAGISel::OPC_CheckPredicate0:
+ case SelectionDAGISel::OPC_CheckPredicate1:
+ case SelectionDAGISel::OPC_CheckPredicate2:
+ case SelectionDAGISel::OPC_CheckPredicate3:
+ case SelectionDAGISel::OPC_CheckPredicate4:
+ case SelectionDAGISel::OPC_CheckPredicate5:
+ case SelectionDAGISel::OPC_CheckPredicate6:
+ case SelectionDAGISel::OPC_CheckPredicate7:
case OPC_CheckPredicate:
- if (!::CheckNodePredicate(MatcherTable, MatcherIndex, *this,
+ if (!::CheckNodePredicate(Opcode, MatcherTable, MatcherIndex, *this,
N.getNode()))
break;
continue;
diff --git a/llvm/test/TableGen/address-space-patfrags.td b/llvm/test/TableGen/address-space-patfrags.td
index 27b174b4633cd8..4aec6ea7e0eae8 100644
--- a/llvm/test/TableGen/address-space-patfrags.td
+++ b/llvm/test/TableGen/address-space-patfrags.td
@@ -46,7 +46,7 @@ def inst_d : Instruction {
let InOperandList = (ins GPR32:$src0, GPR32:$src1);
}
-// SDAG: case 2: {
+// SDAG: case 1: {
// SDAG-NEXT: // Predicate_pat_frag_b
// SDAG-NEXT: // Predicate_truncstorei16_addrspace
// SDAG-NEXT: SDNode *N = Node;
@@ -69,7 +69,7 @@ def : Pat <
>;
-// SDAG: case 3: {
+// SDAG: case 6: {
// SDAG: // Predicate_pat_frag_a
// SDAG-NEXT: SDNode *N = Node;
// SDAG-NEXT: (void)N;
diff --git a/llvm/test/TableGen/predicate-patfags.td b/llvm/test/TableGen/predicate-patfags.td
index 0912b05127ef81..2cf29769dc13a7 100644
--- a/llvm/test/TableGen/predicate-patfags.td
+++ b/llvm/test/TableGen/predicate-patfags.td
@@ -39,10 +39,10 @@ def TGTmul24_oneuse : PatFrag<
}
// SDAG: OPC_CheckOpcode, TARGET_VAL(ISD::INTRINSIC_W_CHAIN),
-// SDAG: OPC_CheckPredicate, 0, // Predicate_TGTmul24_oneuse
+// SDAG: OPC_CheckPredicate0, // Predicate_TGTmul24_oneuse
// SDAG: OPC_CheckOpcode, TARGET_VAL(TargetISD::MUL24),
-// SDAG: OPC_CheckPredicate, 0, // Predicate_TGTmul24_oneuse
+// SDAG: OPC_CheckPredicate0, // Predicate_TGTmul24_oneuse
// GISEL: GIM_CheckOpcode, /*MI*/1, GIMT_Encode2(TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS),
// GISEL: GIM_CheckIntrinsicID, /*MI*/1, /*Op*/1, GIMT_Encode2(Intrinsic::tgt_mul24),
diff --git a/llvm/utils/TableGen/DAGISelMatcherEmitter.cpp b/llvm/utils/TableGen/DAGISelMatcherEmitter.cpp
index a3e2facf948e89..f917b5689398b3 100644
--- a/llvm/utils/TableGen/DAGISelMatcherEmitter.cpp
+++ b/llvm/utils/TableGen/DAGISelMatcherEmitter.cpp
@@ -52,9 +52,8 @@ class MatcherTableEmitter {
SmallVector<unsigned, Matcher::HighestKind+1> OpcodeCounts;
- DenseMap<TreePattern *, unsigned> NodePredicateMap;
- std::vector<TreePredicateFn> NodePredicates;
- std::vector<TreePredicateFn> NodePredicatesWithOperands;
+ std::vector<TreePattern *> NodePredicates;
+ std::vector<TreePattern *> NodePredicatesWithOperands;
// We de-duplicate the predicates by code string, and use this map to track
// all the patterns with "identical" predicates.
@@ -88,6 +87,8 @@ class MatcherTableEmitter {
DenseMap<const ComplexPattern *, unsigned> ComplexPatternUsage;
// Record the usage of PatternPredicate.
std::map<StringRef, unsigned> PatternPredicateUsage;
+ // Record the usage of Predicate.
+ DenseMap<TreePattern *, unsigned> PredicateUsage;
// Iterate the whole MatcherTable once and do some statistics.
std::function<void(const Matcher *)> Statistic = [&](const Matcher *N) {
@@ -105,6 +106,8 @@ class MatcherTableEmitter {
++ComplexPatternUsage[&CPM->getPattern()];
else if (auto *CPPM = dyn_cast<CheckPatternPredicateMatcher>(N))
++PatternPredicateUsage[CPPM->getPredicate()];
+ else if (auto *PM = dyn_cast<CheckPredicateMatcher>(N))
+ ++PredicateUsage[PM->getPredicate().getOrigPatFragRecord()];
N = N->getNext();
}
};
@@ -113,18 +116,54 @@ class MatcherTableEmitter {
// Sort ComplexPatterns by usage.
std::vector<std::pair<const ComplexPattern *, unsigned>> ComplexPatternList(
ComplexPatternUsage.begin(), ComplexPatternUsage.end());
- sort(ComplexPatternList,
- [](const auto &A, const auto &B) { return A.second > B.second; });
+ stable_sort(ComplexPatternList, [](const auto &A, const auto &B) {
+ return A.second > B.second;
+ });
for (const auto &ComplexPattern : ComplexPatternList)
ComplexPatterns.push_back(ComplexPattern.first);
// Sort PatternPredicates by usage.
std::vector<std::pair<std::string, unsigned>> PatternPredicateList(
PatternPredicateUsage.begin(), PatternPredicateUsage.end());
- sort(PatternPredicateList,
- [](const auto &A, const auto &B) { return A.second > B.second; });
+ stable_sort(PatternPredicateList, [](const auto &A, const auto &B) {
+ return A.second > B.second;
+ });
for (const auto &PatternPredicate : PatternPredicateList)
PatternPredicates.push_back(PatternPredicate.first);
+
+ // Sort Predicates by usage.
+ // Merge predicates with same code.
+ for (const auto &Usage : PredicateUsage) {
+ TreePattern *TP = Usage.first;
+ TreePredicateFn Pred(TP);
+ NodePredicatesByCodeToRun[Pred.getCodeToRunOnSDNode()].push_back(TP);
+ }
+
+ std::vector<std::pair<TreePattern *, unsigned>> PredicateList;
+ // Sum the usage.
+ for (auto &Predicate : NodePredicatesByCodeToRun) {
+ TinyPtrVector<TreePattern *> &TPs = Predicate.second;
+ stable_sort(TPs, [](const auto *A, const auto *B) {
+ return A->getRecord()->getName() < B->getRecord()->getName();
+ });
+ unsigned Uses = 0;
+ for (TreePattern *TP : TPs)
+ Uses += PredicateUsage.at(TP);
+
+ // We only add the first predicate here since they are with the same code.
+ PredicateList.push_back({TPs[0], Uses});
+ }
+
+ stable_sort(PredicateList, [](const auto &A, const auto &B) {
+ return A.second > B.second;
+ });
+ for (const auto &Predicate : PredicateList) {
+ TreePattern *TP = Predicate.first;
+ if (TreePredicateFn(TP).usesOperands())
+ NodePredicatesWithOperands.push_back(TP);
+ else
+ NodePredicates.push_back(TP);
+ }
}
unsigned EmitMatcherList(const Matcher *N, const unsigned Indent,
@@ -139,7 +178,7 @@ class MatcherTableEmitter {
void EmitPatternMatchTable(raw_ostream &OS);
private:
- void EmitNodePredicatesFunction(const std::vector<TreePredicateFn> &Preds,
+ void EmitNodePredicatesFunction(const std::vector<TreePattern *> &Preds,
StringRef Decl, raw_ostream &OS);
unsigned SizeMatcher(Matcher *N, raw_ostream &OS);
@@ -148,33 +187,13 @@ class MatcherTableEmitter {
raw_ostream &OS);
unsigned getNodePredicate(TreePredicateFn Pred) {
- TreePattern *TP = Pred.getOrigPatFragRecord();
- unsigned &Entry = NodePredicateMap[TP];
- if (Entry == 0) {
- TinyPtrVector<TreePattern *> &SameCodePreds =
- NodePredicatesByCodeToRun[Pred.getCodeToRunOnSDNode()];
- if (SameCodePreds.empty()) {
- // We've never seen a predicate with the same code: allocate an entry.
- if (Pred.usesOperands()) {
- NodePredicatesWithOperands.push_back(Pred);
- Entry = NodePredicatesWithOperands.size();
- } else {
- NodePredicates.push_back(Pred);
- Entry = NodePredicates.size();
- }
- } else {
- // We did see an identical predicate: re-use it.
- Entry = NodePredicateMap[SameCodePreds.front()];
- assert(Entry != 0);
- assert(TreePredicateFn(SameCodePreds.front()).usesOperands() ==
- Pred.usesOperands() &&
- "PatFrags with some code must have same usesOperands setting");
- }
- // In both cases, we've never seen this particular predicate before, so
- // mark it in the list of predicates sharing the same code.
- SameCodePreds.push_back(TP);
- }
- return Entry-1;
+ // We use the first predicate.
+ TreePattern *PredPat =
+ NodePredicatesByCodeToRun[Pred.getCodeToRunOnSDNode()][0];
+ return Pred.usesOperands()
+ ? llvm::find(NodePredicatesWithOperands, PredPat) -
+ NodePredicatesWithOperands.begin()
+ : llvm::find(NodePredicates, PredPat) - NodePredicates.begin();
}
unsigned getPatternPredicate(StringRef PredName) {
@@ -529,6 +548,7 @@ EmitMatcher(const Matcher *N, const unsigned Indent, unsigned CurrentIdx,
case Matcher::CheckPredicate: {
TreePredicateFn Pred = cast<CheckPredicateMatcher>(N)->getPredicate();
unsigned OperandBytes = 0;
+ unsigned PredNo = getNodePredicate(Pred);
if (Pred.usesOperands()) {
unsigned NumOps = cast<CheckPredicateMatcher>(N)->getNumOperands();
@@ -537,10 +557,15 @@ EmitMatcher(const Matcher *N, const unsigned Indent, unsigned CurrentIdx,
OS << cast<CheckPredicateMatcher>(N)->getOperandNo(i) << ", ";
OperandBytes = 1 + NumOps;
} else {
- OS << "OPC_CheckPredicate, ";
+ if (PredNo < 8) {
+ OperandBytes = -1;
+ OS << "OPC_CheckPredicate" << PredNo << ", ";
+ } else
+ OS << "OPC_CheckPredicate, ";
}
- OS << getNodePredicate(Pred) << ',';
+ if (PredNo >= 8 || Pred.usesOperands())
+ OS << PredNo << ',';
if (!OmitComments)
OS << " // " << Pred.getFnName();
OS << '\n';
@@ -1029,8 +1054,7 @@ EmitMatcherList(const Matcher *N, const unsigned Indent, unsigned CurrentIdx,
}
void MatcherTableEmitter::EmitNodePredicatesFunction(
- const std::vector<TreePredicateFn> &Preds, StringRef Decl,
- raw_ostream &OS) {
+ const std::vector<TreePattern *> &Preds, StringRef Decl, raw_ostream &OS) {
if (Preds.empty())
return;
@@ -1040,7 +1064,7 @@ void MatcherTableEmitter::EmitNodePredicatesFunction(
OS << " default: llvm_unreachable(\"Invalid predicate in table?\");\n";
for (unsigned i = 0, e = Preds.size(); i != e; ++i) {
// Emit the predicate code corresponding to this pattern.
- const TreePredicateFn PredFn = Preds[i];
+ TreePredicateFn PredFn(Preds[i]);
assert(!PredFn.isAlwaysTrue() && "No code in this predicate");
std::string PredFnCodeStr = PredFn.getCodeToRunOnSDNode();
More information about the llvm-commits
mailing list