[llvm] 5c8d123 - [SelectionDAG] Add space-optimized forms of OPC_CheckPatternPredicate (#73319)

via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 10 23:36:25 PST 2024


Author: Wang Pengcheng
Date: 2024-01-11T15:36:21+08:00
New Revision: 5c8d1238382ce3ef6004d9cbe3fe67b8342d868c

URL: https://github.com/llvm/llvm-project/commit/5c8d1238382ce3ef6004d9cbe3fe67b8342d868c
DIFF: https://github.com/llvm/llvm-project/commit/5c8d1238382ce3ef6004d9cbe3fe67b8342d868c.diff

LOG: [SelectionDAG] Add space-optimized forms of OPC_CheckPatternPredicate (#73319)


We count how often each `PatternPredicate` is used and sort them by
usage count, most frequently used first.

For the 8 most frequently used `PatternPredicate`s, we emit an
`OPC_CheckPatternPredicateN` opcode that encodes the predicate index in
the opcode itself, saving one byte per check.

The old `OPC_CheckPatternPredicate2` is renamed to
`OPC_CheckPatternPredicateTwoByte`.

Overall, this reduces the size of an llc binary built with all in-tree
targets by about 93K.

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/SelectionDAGISel.h
    llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
    llvm/utils/TableGen/DAGISelMatcherEmitter.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/CodeGen/SelectionDAGISel.h b/llvm/include/llvm/CodeGen/SelectionDAGISel.h
index 99ce658e7eb711..e4d90f6e898fe8 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGISel.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGISel.h
@@ -159,7 +159,15 @@ class SelectionDAGISel : public MachineFunctionPass {
     OPC_CheckChild2Same,
     OPC_CheckChild3Same,
     OPC_CheckPatternPredicate,
+    OPC_CheckPatternPredicate0,
+    OPC_CheckPatternPredicate1,
     OPC_CheckPatternPredicate2,
+    OPC_CheckPatternPredicate3,
+    OPC_CheckPatternPredicate4,
+    OPC_CheckPatternPredicate5,
+    OPC_CheckPatternPredicate6,
+    OPC_CheckPatternPredicate7,
+    OPC_CheckPatternPredicateTwoByte,
     OPC_CheckPredicate,
     OPC_CheckPredicateWithOperands,
     OPC_CheckOpcode,

diff  --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
index 344dc8d8a9b677..678d273e4bd605 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
@@ -2697,9 +2697,14 @@ LLVM_ATTRIBUTE_ALWAYS_INLINE static bool CheckChildSame(
 
 /// CheckPatternPredicate - Implements OP_CheckPatternPredicate.
 LLVM_ATTRIBUTE_ALWAYS_INLINE static bool
-CheckPatternPredicate(const unsigned char *MatcherTable, unsigned &MatcherIndex,
-                      const SelectionDAGISel &SDISel, bool TwoBytePredNo) {
-  unsigned PredNo = MatcherTable[MatcherIndex++];
+CheckPatternPredicate(unsigned Opcode, const unsigned char *MatcherTable,
+                      unsigned &MatcherIndex, const SelectionDAGISel &SDISel) {
+  bool TwoBytePredNo =
+      Opcode == SelectionDAGISel::OPC_CheckPatternPredicateTwoByte;
+  unsigned PredNo =
+      TwoBytePredNo || Opcode == SelectionDAGISel::OPC_CheckPatternPredicate
+          ? MatcherTable[MatcherIndex++]
+          : Opcode - SelectionDAGISel::OPC_CheckPatternPredicate0;
   if (TwoBytePredNo)
     PredNo |= MatcherTable[MatcherIndex++] << 8;
   return SDISel.CheckPatternPredicate(PredNo);
@@ -2851,10 +2856,16 @@ static unsigned IsPredicateKnownToFail(const unsigned char *Table,
                         Table[Index-1] - SelectionDAGISel::OPC_CheckChild0Same);
     return Index;
   case SelectionDAGISel::OPC_CheckPatternPredicate:
+  case SelectionDAGISel::OPC_CheckPatternPredicate0:
+  case SelectionDAGISel::OPC_CheckPatternPredicate1:
   case SelectionDAGISel::OPC_CheckPatternPredicate2:
-    Result = !::CheckPatternPredicate(
-        Table, Index, SDISel,
-        Table[Index - 1] == SelectionDAGISel::OPC_CheckPatternPredicate2);
+  case SelectionDAGISel::OPC_CheckPatternPredicate3:
+  case SelectionDAGISel::OPC_CheckPatternPredicate4:
+  case SelectionDAGISel::OPC_CheckPatternPredicate5:
+  case SelectionDAGISel::OPC_CheckPatternPredicate6:
+  case SelectionDAGISel::OPC_CheckPatternPredicate7:
+  case SelectionDAGISel::OPC_CheckPatternPredicateTwoByte:
+    Result = !::CheckPatternPredicate(Opcode, Table, Index, SDISel);
     return Index;
   case SelectionDAGISel::OPC_CheckPredicate:
     Result = !::CheckNodePredicate(Table, Index, SDISel, N.getNode());
@@ -3336,9 +3347,16 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch,
       continue;
 
     case OPC_CheckPatternPredicate:
+    case OPC_CheckPatternPredicate0:
+    case OPC_CheckPatternPredicate1:
     case OPC_CheckPatternPredicate2:
-      if (!::CheckPatternPredicate(MatcherTable, MatcherIndex, *this,
-                                   Opcode == OPC_CheckPatternPredicate2))
+    case OPC_CheckPatternPredicate3:
+    case OPC_CheckPatternPredicate4:
+    case OPC_CheckPatternPredicate5:
+    case OPC_CheckPatternPredicate6:
+    case OPC_CheckPatternPredicate7:
+    case OPC_CheckPatternPredicateTwoByte:
+      if (!::CheckPatternPredicate(Opcode, MatcherTable, MatcherIndex, *this))
         break;
       continue;
     case OPC_CheckPredicate:

diff  --git a/llvm/utils/TableGen/DAGISelMatcherEmitter.cpp b/llvm/utils/TableGen/DAGISelMatcherEmitter.cpp
index e460a2804c6649..a3e2facf948e89 100644
--- a/llvm/utils/TableGen/DAGISelMatcherEmitter.cpp
+++ b/llvm/utils/TableGen/DAGISelMatcherEmitter.cpp
@@ -60,7 +60,6 @@ class MatcherTableEmitter {
   // all the patterns with "identical" predicates.
   StringMap<TinyPtrVector<TreePattern *>> NodePredicatesByCodeToRun;
 
-  StringMap<unsigned> PatternPredicateMap;
   std::vector<std::string> PatternPredicates;
 
   std::vector<const ComplexPattern*> ComplexPatterns;
@@ -87,6 +86,8 @@ class MatcherTableEmitter {
       : CGP(cgp), OpcodeCounts(Matcher::HighestKind + 1, 0) {
     // Record the usage of ComplexPattern.
     DenseMap<const ComplexPattern *, unsigned> ComplexPatternUsage;
+    // Record the usage of PatternPredicate.
+    std::map<StringRef, unsigned> PatternPredicateUsage;
 
     // Iterate the whole MatcherTable once and do some statistics.
     std::function<void(const Matcher *)> Statistic = [&](const Matcher *N) {
@@ -102,6 +103,8 @@ class MatcherTableEmitter {
             Statistic(STM->getCaseMatcher(I));
         else if (auto *CPM = dyn_cast<CheckComplexPatMatcher>(N))
           ++ComplexPatternUsage[&CPM->getPattern()];
+        else if (auto *CPPM = dyn_cast<CheckPatternPredicateMatcher>(N))
+          ++PatternPredicateUsage[CPPM->getPredicate()];
         N = N->getNext();
       }
     };
@@ -114,6 +117,14 @@ class MatcherTableEmitter {
          [](const auto &A, const auto &B) { return A.second > B.second; });
     for (const auto &ComplexPattern : ComplexPatternList)
       ComplexPatterns.push_back(ComplexPattern.first);
+
+    // Sort PatternPredicates by usage.
+    std::vector<std::pair<std::string, unsigned>> PatternPredicateList(
+        PatternPredicateUsage.begin(), PatternPredicateUsage.end());
+    sort(PatternPredicateList,
+         [](const auto &A, const auto &B) { return A.second > B.second; });
+    for (const auto &PatternPredicate : PatternPredicateList)
+      PatternPredicates.push_back(PatternPredicate.first);
   }
 
   unsigned EmitMatcherList(const Matcher *N, const unsigned Indent,
@@ -167,12 +178,7 @@ class MatcherTableEmitter {
   }
 
   unsigned getPatternPredicate(StringRef PredName) {
-    unsigned &Entry = PatternPredicateMap[PredName];
-    if (Entry == 0) {
-      PatternPredicates.push_back(PredName.str());
-      Entry = PatternPredicates.size();
-    }
-    return Entry-1;
+    return llvm::find(PatternPredicates, PredName) - PatternPredicates.begin();
   }
   unsigned getComplexPat(const ComplexPattern &P) {
     return llvm::find(ComplexPatterns, &P) - ComplexPatterns.begin();
@@ -510,13 +516,15 @@ EmitMatcher(const Matcher *N, const unsigned Indent, unsigned CurrentIdx,
     StringRef Pred = cast<CheckPatternPredicateMatcher>(N)->getPredicate();
     unsigned PredNo = getPatternPredicate(Pred);
     if (PredNo > 255)
-      OS << "OPC_CheckPatternPredicate2, TARGET_VAL(" << PredNo << "),";
+      OS << "OPC_CheckPatternPredicateTwoByte, TARGET_VAL(" << PredNo << "),";
+    else if (PredNo < 8)
+      OS << "OPC_CheckPatternPredicate" << PredNo << ',';
     else
       OS << "OPC_CheckPatternPredicate, " << PredNo << ',';
     if (!OmitComments)
       OS << " // " << Pred;
     OS << '\n';
-    return 2 + (PredNo > 255);
+    return 2 + (PredNo > 255) - (PredNo < 8);
   }
   case Matcher::CheckPredicate: {
     TreePredicateFn Pred = cast<CheckPredicateMatcher>(N)->getPredicate();


        


More information about the llvm-commits mailing list