[llvm] [SelectionDAG] Add space-optimized forms of OPC_CheckPatternPredicate (PR #73319)
Wang Pengcheng via llvm-commits
llvm-commits at lists.llvm.org
Mon Jan 8 00:20:55 PST 2024
https://github.com/wangpc-pp updated https://github.com/llvm/llvm-project/pull/73319
>From 352fd1fb9d90ca0551c2083f9ee8cb31ae8756cf Mon Sep 17 00:00:00 2001
From: wangpc <wangpengcheng.pp at bytedance.com>
Date: Fri, 24 Nov 2023 13:24:12 +0800
Subject: [PATCH 1/2] [SelectionDAG] Add space-optimized forms of
OPC_CheckComplexPat
We record the usage of each `ComplexPat` and sort the `ComplexPat`s
by usage.
For the top 8 `ComplexPat`s, we will emit an `OPC_CheckComplexPatN`
to save one byte.
Overall, this reduces the size of the llc binary built with all in-tree
targets by about 89K.
---
llvm/include/llvm/CodeGen/SelectionDAGISel.h | 8 +++
.../CodeGen/SelectionDAG/SelectionDAGISel.cpp | 14 ++++-
llvm/test/TableGen/dag-isel-complexpattern.td | 2 +-
llvm/utils/TableGen/DAGISelMatcherEmitter.cpp | 55 ++++++++++++++-----
4 files changed, 63 insertions(+), 16 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGISel.h b/llvm/include/llvm/CodeGen/SelectionDAGISel.h
index 40046e0a8dec9a..99ce658e7eb711 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGISel.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGISel.h
@@ -207,6 +207,14 @@ class SelectionDAGISel : public MachineFunctionPass {
OPC_CheckChild2CondCode,
OPC_CheckValueType,
OPC_CheckComplexPat,
+ OPC_CheckComplexPat0,
+ OPC_CheckComplexPat1,
+ OPC_CheckComplexPat2,
+ OPC_CheckComplexPat3,
+ OPC_CheckComplexPat4,
+ OPC_CheckComplexPat5,
+ OPC_CheckComplexPat6,
+ OPC_CheckComplexPat7,
OPC_CheckAndImm,
OPC_CheckOrImm,
OPC_CheckImmAllOnesV,
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
index 99bb3d875d4fa5..0c708b3da58898 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
@@ -3361,8 +3361,18 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch,
break;
continue;
}
- case OPC_CheckComplexPat: {
- unsigned CPNum = MatcherTable[MatcherIndex++];
+ case OPC_CheckComplexPat:
+ case OPC_CheckComplexPat0:
+ case OPC_CheckComplexPat1:
+ case OPC_CheckComplexPat2:
+ case OPC_CheckComplexPat3:
+ case OPC_CheckComplexPat4:
+ case OPC_CheckComplexPat5:
+ case OPC_CheckComplexPat6:
+ case OPC_CheckComplexPat7: {
+ unsigned CPNum = Opcode == OPC_CheckComplexPat
+ ? MatcherTable[MatcherIndex++]
+ : Opcode - OPC_CheckComplexPat0;
unsigned RecNo = MatcherTable[MatcherIndex++];
assert(RecNo < RecordedNodes.size() && "Invalid CheckComplexPat");
diff --git a/llvm/test/TableGen/dag-isel-complexpattern.td b/llvm/test/TableGen/dag-isel-complexpattern.td
index 3d74e4e46dc41c..b8f517a1fc2890 100644
--- a/llvm/test/TableGen/dag-isel-complexpattern.td
+++ b/llvm/test/TableGen/dag-isel-complexpattern.td
@@ -22,7 +22,7 @@ def CP32 : ComplexPattern<i32, 0, "SelectCP32">;
def INSTR : Instruction {
// CHECK-LABEL: OPC_CheckOpcode, TARGET_VAL(ISD::STORE)
// CHECK: OPC_CheckTypeI32
-// CHECK: OPC_CheckComplexPat, /*CP*/0, /*#*/1, // SelectCP32:$
+// CHECK: OPC_CheckComplexPat0, /*#*/1, // SelectCP32:$
// CHECK: Src: (st (add:{ *:[i32] } (CP32:{ *:[i32] }), (CP32:{ *:[i32] })), i64:{ *:[i64] }:$addr)
let OutOperandList = (outs);
let InOperandList = (ins GPR64:$addr);
diff --git a/llvm/utils/TableGen/DAGISelMatcherEmitter.cpp b/llvm/utils/TableGen/DAGISelMatcherEmitter.cpp
index 6fd5698e7372e4..e460a2804c6649 100644
--- a/llvm/utils/TableGen/DAGISelMatcherEmitter.cpp
+++ b/llvm/utils/TableGen/DAGISelMatcherEmitter.cpp
@@ -63,7 +63,6 @@ class MatcherTableEmitter {
StringMap<unsigned> PatternPredicateMap;
std::vector<std::string> PatternPredicates;
- DenseMap<const ComplexPattern*, unsigned> ComplexPatternMap;
std::vector<const ComplexPattern*> ComplexPatterns;
@@ -84,8 +83,38 @@ class MatcherTableEmitter {
}
public:
- MatcherTableEmitter(const CodeGenDAGPatterns &cgp)
- : CGP(cgp), OpcodeCounts(Matcher::HighestKind + 1, 0) {}
+ MatcherTableEmitter(const Matcher *TheMatcher, const CodeGenDAGPatterns &cgp)
+ : CGP(cgp), OpcodeCounts(Matcher::HighestKind + 1, 0) {
+ // Record the usage of ComplexPattern.
+ DenseMap<const ComplexPattern *, unsigned> ComplexPatternUsage;
+
+ // Iterate the whole MatcherTable once and do some statistics.
+ std::function<void(const Matcher *)> Statistic = [&](const Matcher *N) {
+ while (N) {
+ if (auto *SM = dyn_cast<ScopeMatcher>(N))
+ for (unsigned I = 0; I < SM->getNumChildren(); I++)
+ Statistic(SM->getChild(I));
+ else if (auto *SOM = dyn_cast<SwitchOpcodeMatcher>(N))
+ for (unsigned I = 0; I < SOM->getNumCases(); I++)
+ Statistic(SOM->getCaseMatcher(I));
+ else if (auto *STM = dyn_cast<SwitchTypeMatcher>(N))
+ for (unsigned I = 0; I < STM->getNumCases(); I++)
+ Statistic(STM->getCaseMatcher(I));
+ else if (auto *CPM = dyn_cast<CheckComplexPatMatcher>(N))
+ ++ComplexPatternUsage[&CPM->getPattern()];
+ N = N->getNext();
+ }
+ };
+ Statistic(TheMatcher);
+
+ // Sort ComplexPatterns by usage.
+ std::vector<std::pair<const ComplexPattern *, unsigned>> ComplexPatternList(
+ ComplexPatternUsage.begin(), ComplexPatternUsage.end());
+ sort(ComplexPatternList,
+ [](const auto &A, const auto &B) { return A.second > B.second; });
+ for (const auto &ComplexPattern : ComplexPatternList)
+ ComplexPatterns.push_back(ComplexPattern.first);
+ }
unsigned EmitMatcherList(const Matcher *N, const unsigned Indent,
unsigned StartIdx, raw_ostream &OS);
@@ -146,12 +175,7 @@ class MatcherTableEmitter {
return Entry-1;
}
unsigned getComplexPat(const ComplexPattern &P) {
- unsigned &Entry = ComplexPatternMap[&P];
- if (Entry == 0) {
- ComplexPatterns.push_back(&P);
- Entry = ComplexPatterns.size();
- }
- return Entry-1;
+ return llvm::find(ComplexPatterns, &P) - ComplexPatterns.begin();
}
unsigned getNodeXFormID(Record *Rec) {
@@ -652,8 +676,13 @@ EmitMatcher(const Matcher *N, const unsigned Indent, unsigned CurrentIdx,
case Matcher::CheckComplexPat: {
const CheckComplexPatMatcher *CCPM = cast<CheckComplexPatMatcher>(N);
const ComplexPattern &Pattern = CCPM->getPattern();
- OS << "OPC_CheckComplexPat, /*CP*/" << getComplexPat(Pattern) << ", /*#*/"
- << CCPM->getMatchNumber() << ',';
+ unsigned PatternNo = getComplexPat(Pattern);
+ if (PatternNo < 8)
+ OS << "OPC_CheckComplexPat" << PatternNo << ", /*#*/"
+ << CCPM->getMatchNumber() << ',';
+ else
+ OS << "OPC_CheckComplexPat, /*CP*/" << PatternNo << ", /*#*/"
+ << CCPM->getMatchNumber() << ',';
if (!OmitComments) {
OS << " // " << Pattern.getSelectFunc();
@@ -665,7 +694,7 @@ EmitMatcher(const Matcher *N, const unsigned Indent, unsigned CurrentIdx,
OS << " + chain result";
}
OS << '\n';
- return 3;
+ return PatternNo < 8 ? 2 : 3;
}
case Matcher::CheckAndImm: {
@@ -1267,7 +1296,7 @@ void llvm::EmitMatcherTable(Matcher *TheMatcher,
OS << "#endif\n\n";
BeginEmitFunction(OS, "void", "SelectCode(SDNode *N)", false/*AddOverride*/);
- MatcherTableEmitter MatcherEmitter(CGP);
+ MatcherTableEmitter MatcherEmitter(TheMatcher, CGP);
// First we size all the children of the three kinds of matchers that have
// them. This is done by sharing the code in EmitMatcher(). but we don't
>From 97cc38fb4ec9308dfc341eea3f0b053be2346197 Mon Sep 17 00:00:00 2001
From: wangpc <wangpengcheng.pp at bytedance.com>
Date: Fri, 24 Nov 2023 19:45:06 +0800
Subject: [PATCH 2/2] [SelectionDAG] Add space-optimized forms of
OPC_CheckPatternPredicate
We record the usage of each `PatternPredicate` and sort them by
usage.
For the top 8 `PatternPredicate`s, we will emit an
`OPC_CheckPatternPredicateN` to save one byte.
The old `OPC_CheckPatternPredicate2` is renamed to
`OPC_CheckPatternPredicateTwoByte`.
Overall, this reduces the size of the llc binary built with all in-tree
targets by about 93K.
---
llvm/include/llvm/CodeGen/SelectionDAGISel.h | 8 +++++
.../CodeGen/SelectionDAG/SelectionDAGISel.cpp | 34 ++++++++++++++-----
llvm/utils/TableGen/DAGISelMatcherEmitter.cpp | 27 ++++++++++-----
3 files changed, 52 insertions(+), 17 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGISel.h b/llvm/include/llvm/CodeGen/SelectionDAGISel.h
index 99ce658e7eb711..e4d90f6e898fe8 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGISel.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGISel.h
@@ -159,7 +159,15 @@ class SelectionDAGISel : public MachineFunctionPass {
OPC_CheckChild2Same,
OPC_CheckChild3Same,
OPC_CheckPatternPredicate,
+ OPC_CheckPatternPredicate0,
+ OPC_CheckPatternPredicate1,
OPC_CheckPatternPredicate2,
+ OPC_CheckPatternPredicate3,
+ OPC_CheckPatternPredicate4,
+ OPC_CheckPatternPredicate5,
+ OPC_CheckPatternPredicate6,
+ OPC_CheckPatternPredicate7,
+ OPC_CheckPatternPredicateTwoByte,
OPC_CheckPredicate,
OPC_CheckPredicateWithOperands,
OPC_CheckOpcode,
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
index 0c708b3da58898..50d85442f64005 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
@@ -2700,9 +2700,14 @@ LLVM_ATTRIBUTE_ALWAYS_INLINE static bool CheckChildSame(
/// CheckPatternPredicate - Implements OP_CheckPatternPredicate.
LLVM_ATTRIBUTE_ALWAYS_INLINE static bool
-CheckPatternPredicate(const unsigned char *MatcherTable, unsigned &MatcherIndex,
- const SelectionDAGISel &SDISel, bool TwoBytePredNo) {
- unsigned PredNo = MatcherTable[MatcherIndex++];
+CheckPatternPredicate(unsigned Opcode, const unsigned char *MatcherTable,
+ unsigned &MatcherIndex, const SelectionDAGISel &SDISel) {
+ bool TwoBytePredNo =
+ Opcode == SelectionDAGISel::OPC_CheckPatternPredicateTwoByte;
+ unsigned PredNo =
+ TwoBytePredNo || Opcode == SelectionDAGISel::OPC_CheckPatternPredicate
+ ? MatcherTable[MatcherIndex++]
+ : Opcode - SelectionDAGISel::OPC_CheckPatternPredicate0;
if (TwoBytePredNo)
PredNo |= MatcherTable[MatcherIndex++] << 8;
return SDISel.CheckPatternPredicate(PredNo);
@@ -2854,10 +2859,16 @@ static unsigned IsPredicateKnownToFail(const unsigned char *Table,
Table[Index-1] - SelectionDAGISel::OPC_CheckChild0Same);
return Index;
case SelectionDAGISel::OPC_CheckPatternPredicate:
+ case SelectionDAGISel::OPC_CheckPatternPredicate0:
+ case SelectionDAGISel::OPC_CheckPatternPredicate1:
case SelectionDAGISel::OPC_CheckPatternPredicate2:
- Result = !::CheckPatternPredicate(
- Table, Index, SDISel,
- Table[Index - 1] == SelectionDAGISel::OPC_CheckPatternPredicate2);
+ case SelectionDAGISel::OPC_CheckPatternPredicate3:
+ case SelectionDAGISel::OPC_CheckPatternPredicate4:
+ case SelectionDAGISel::OPC_CheckPatternPredicate5:
+ case SelectionDAGISel::OPC_CheckPatternPredicate6:
+ case SelectionDAGISel::OPC_CheckPatternPredicate7:
+ case SelectionDAGISel::OPC_CheckPatternPredicateTwoByte:
+ Result = !::CheckPatternPredicate(Opcode, Table, Index, SDISel);
return Index;
case SelectionDAGISel::OPC_CheckPredicate:
Result = !::CheckNodePredicate(Table, Index, SDISel, N.getNode());
@@ -3339,9 +3350,16 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch,
continue;
case OPC_CheckPatternPredicate:
+ case OPC_CheckPatternPredicate0:
+ case OPC_CheckPatternPredicate1:
case OPC_CheckPatternPredicate2:
- if (!::CheckPatternPredicate(MatcherTable, MatcherIndex, *this,
- Opcode == OPC_CheckPatternPredicate2))
+ case OPC_CheckPatternPredicate3:
+ case OPC_CheckPatternPredicate4:
+ case OPC_CheckPatternPredicate5:
+ case OPC_CheckPatternPredicate6:
+ case OPC_CheckPatternPredicate7:
+ case OPC_CheckPatternPredicateTwoByte:
+ if (!::CheckPatternPredicate(Opcode, MatcherTable, MatcherIndex, *this))
break;
continue;
case OPC_CheckPredicate:
diff --git a/llvm/utils/TableGen/DAGISelMatcherEmitter.cpp b/llvm/utils/TableGen/DAGISelMatcherEmitter.cpp
index e460a2804c6649..6732b58661dc92 100644
--- a/llvm/utils/TableGen/DAGISelMatcherEmitter.cpp
+++ b/llvm/utils/TableGen/DAGISelMatcherEmitter.cpp
@@ -60,7 +60,6 @@ class MatcherTableEmitter {
// all the patterns with "identical" predicates.
StringMap<TinyPtrVector<TreePattern *>> NodePredicatesByCodeToRun;
- StringMap<unsigned> PatternPredicateMap;
std::vector<std::string> PatternPredicates;
std::vector<const ComplexPattern*> ComplexPatterns;
@@ -87,6 +86,8 @@ class MatcherTableEmitter {
: CGP(cgp), OpcodeCounts(Matcher::HighestKind + 1, 0) {
// Record the usage of ComplexPattern.
DenseMap<const ComplexPattern *, unsigned> ComplexPatternUsage;
+ // Record the usage of PatternPredicate.
+ std::map<StringRef, unsigned> PatternPredicateUsage;
// Iterate the whole MatcherTable once and do some statistics.
std::function<void(const Matcher *)> Statistic = [&](const Matcher *N) {
@@ -102,6 +103,9 @@ class MatcherTableEmitter {
Statistic(STM->getCaseMatcher(I));
else if (auto *CPM = dyn_cast<CheckComplexPatMatcher>(N))
++ComplexPatternUsage[&CPM->getPattern()];
+ else if (auto *CPPM = dyn_cast<CheckPatternPredicateMatcher>(N))
+ ++PatternPredicateUsage[CPPM->getPredicate()];
+
N = N->getNext();
}
};
@@ -114,6 +118,14 @@ class MatcherTableEmitter {
[](const auto &A, const auto &B) { return A.second > B.second; });
for (const auto &ComplexPattern : ComplexPatternList)
ComplexPatterns.push_back(ComplexPattern.first);
+
+ // Sort PatternPredicates by usage.
+ std::vector<std::pair<std::string, unsigned>> PatternPredicateList(
+ PatternPredicateUsage.begin(), PatternPredicateUsage.end());
+ sort(PatternPredicateList,
+ [](const auto &A, const auto &B) { return A.second > B.second; });
+ for (const auto &PatternPredicate : PatternPredicateList)
+ PatternPredicates.push_back(PatternPredicate.first);
}
unsigned EmitMatcherList(const Matcher *N, const unsigned Indent,
@@ -167,12 +179,7 @@ class MatcherTableEmitter {
}
unsigned getPatternPredicate(StringRef PredName) {
- unsigned &Entry = PatternPredicateMap[PredName];
- if (Entry == 0) {
- PatternPredicates.push_back(PredName.str());
- Entry = PatternPredicates.size();
- }
- return Entry-1;
+ return llvm::find(PatternPredicates, PredName) - PatternPredicates.begin();
}
unsigned getComplexPat(const ComplexPattern &P) {
return llvm::find(ComplexPatterns, &P) - ComplexPatterns.begin();
@@ -510,13 +517,15 @@ EmitMatcher(const Matcher *N, const unsigned Indent, unsigned CurrentIdx,
StringRef Pred = cast<CheckPatternPredicateMatcher>(N)->getPredicate();
unsigned PredNo = getPatternPredicate(Pred);
if (PredNo > 255)
- OS << "OPC_CheckPatternPredicate2, TARGET_VAL(" << PredNo << "),";
+ OS << "OPC_CheckPatternPredicateTwoByte, TARGET_VAL(" << PredNo << "),";
+ else if (PredNo < 8)
+ OS << "OPC_CheckPatternPredicate" << PredNo << ',';
else
OS << "OPC_CheckPatternPredicate, " << PredNo << ',';
if (!OmitComments)
OS << " // " << Pred;
OS << '\n';
- return 2 + (PredNo > 255);
+ return 2 + (PredNo > 255) - (PredNo < 8);
}
case Matcher::CheckPredicate: {
TreePredicateFn Pred = cast<CheckPredicateMatcher>(N)->getPredicate();
More information about the llvm-commits
mailing list