[llvm] [SelectionDAG] Add space-optimized forms of OPC_CheckPredicate (PR #73488)

Wang Pengcheng via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 10 23:43:13 PST 2024


https://github.com/wangpc-pp updated https://github.com/llvm/llvm-project/pull/73488

>From e8c153340d1168e69701f5c9d404bda25f8b2f8e Mon Sep 17 00:00:00 2001
From: wangpc <wangpengcheng.pp at bytedance.com>
Date: Mon, 27 Nov 2023 16:28:36 +0800
Subject: [PATCH] [SelectionDAG] Add space-optimized forms of
 OPC_CheckPredicate

We count how often each `Predicate` is used and sort them by usage.

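For the 8 most frequently used `Predicate`s that do not use operands, we
emit `OPC_CheckPredicate0`..`OPC_CheckPredicate7`, which encode the
predicate index in the opcode itself and so save one byte per check.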

Overall, this reduces the size of an llc binary built with all in-tree
targets by about 61K.

This PR is stacked on #73310.
---
 llvm/include/llvm/CodeGen/SelectionDAGISel.h  |  8 ++
 .../CodeGen/SelectionDAG/SelectionDAGISel.cpp | 30 +++++-
 llvm/test/TableGen/address-space-patfrags.td  |  4 +-
 llvm/test/TableGen/predicate-patfags.td       |  4 +-
 llvm/utils/TableGen/DAGISelMatcherEmitter.cpp | 93 ++++++++++++-------
 5 files changed, 94 insertions(+), 45 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/SelectionDAGISel.h b/llvm/include/llvm/CodeGen/SelectionDAGISel.h
index e4d90f6e898fe8..dbd9b391f4a431 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGISel.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGISel.h
@@ -169,6 +169,14 @@ class SelectionDAGISel : public MachineFunctionPass {
     OPC_CheckPatternPredicate7,
     OPC_CheckPatternPredicateTwoByte,
     OPC_CheckPredicate,
+    OPC_CheckPredicate0, // 0-7: predicate index implied by opcode.
+    OPC_CheckPredicate1,
+    OPC_CheckPredicate2,
+    OPC_CheckPredicate3,
+    OPC_CheckPredicate4,
+    OPC_CheckPredicate5,
+    OPC_CheckPredicate6,
+    OPC_CheckPredicate7,
     OPC_CheckPredicateWithOperands,
     OPC_CheckOpcode,
     OPC_SwitchOpcode,
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
index 678d273e4bd605..359d738d2ca09f 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
@@ -2712,9 +2712,13 @@ CheckPatternPredicate(unsigned Opcode, const unsigned char *MatcherTable,
 
 /// CheckNodePredicate - Implements OP_CheckNodePredicate.
 LLVM_ATTRIBUTE_ALWAYS_INLINE static bool
-CheckNodePredicate(const unsigned char *MatcherTable, unsigned &MatcherIndex,
-                   const SelectionDAGISel &SDISel, SDNode *N) {
-  return SDISel.CheckNodePredicate(N, MatcherTable[MatcherIndex++]);
+CheckNodePredicate(unsigned Opcode, const unsigned char *MatcherTable,
+                   unsigned &MatcherIndex, const SelectionDAGISel &SDISel,
+                   SDNode *N) {
+  unsigned PredNo = Opcode == SelectionDAGISel::OPC_CheckPredicate
+                        ? MatcherTable[MatcherIndex++]
+                        : Opcode - SelectionDAGISel::OPC_CheckPredicate0;
+  return SDISel.CheckNodePredicate(N, PredNo);
 }
 
 LLVM_ATTRIBUTE_ALWAYS_INLINE static bool
@@ -2868,7 +2872,15 @@ static unsigned IsPredicateKnownToFail(const unsigned char *Table,
     Result = !::CheckPatternPredicate(Opcode, Table, Index, SDISel);
     return Index;
   case SelectionDAGISel::OPC_CheckPredicate:
-    Result = !::CheckNodePredicate(Table, Index, SDISel, N.getNode());
+  case SelectionDAGISel::OPC_CheckPredicate0:
+  case SelectionDAGISel::OPC_CheckPredicate1:
+  case SelectionDAGISel::OPC_CheckPredicate2:
+  case SelectionDAGISel::OPC_CheckPredicate3:
+  case SelectionDAGISel::OPC_CheckPredicate4:
+  case SelectionDAGISel::OPC_CheckPredicate5:
+  case SelectionDAGISel::OPC_CheckPredicate6:
+  case SelectionDAGISel::OPC_CheckPredicate7:
+    Result = !::CheckNodePredicate(Opcode, Table, Index, SDISel, N.getNode());
     return Index;
   case SelectionDAGISel::OPC_CheckOpcode:
     Result = !::CheckOpcode(Table, Index, N.getNode());
@@ -3359,8 +3371,16 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch,
       if (!::CheckPatternPredicate(Opcode, MatcherTable, MatcherIndex, *this))
         break;
       continue;
     case OPC_CheckPredicate:
+    case OPC_CheckPredicate0:
+    case OPC_CheckPredicate1:
+    case OPC_CheckPredicate2:
+    case OPC_CheckPredicate3:
+    case OPC_CheckPredicate4:
+    case OPC_CheckPredicate5:
+    case OPC_CheckPredicate6:
+    case OPC_CheckPredicate7:
-      if (!::CheckNodePredicate(MatcherTable, MatcherIndex, *this,
+      if (!::CheckNodePredicate(Opcode, MatcherTable, MatcherIndex, *this,
                                 N.getNode()))
         break;
       continue;
diff --git a/llvm/test/TableGen/address-space-patfrags.td b/llvm/test/TableGen/address-space-patfrags.td
index 27b174b4633cd8..4aec6ea7e0eae8 100644
--- a/llvm/test/TableGen/address-space-patfrags.td
+++ b/llvm/test/TableGen/address-space-patfrags.td
@@ -46,7 +46,7 @@ def inst_d : Instruction {
   let InOperandList = (ins GPR32:$src0, GPR32:$src1);
 }
 
-// SDAG: case 2: {
+// SDAG: case 1: {
 // SDAG-NEXT: // Predicate_pat_frag_b
 // SDAG-NEXT: // Predicate_truncstorei16_addrspace
 // SDAG-NEXT: SDNode *N = Node;
@@ -69,7 +69,7 @@ def : Pat <
 >;
 
 
-// SDAG: case 3: {
+// SDAG: case 6: {
 // SDAG: // Predicate_pat_frag_a
 // SDAG-NEXT: SDNode *N = Node;
 // SDAG-NEXT: (void)N;
diff --git a/llvm/test/TableGen/predicate-patfags.td b/llvm/test/TableGen/predicate-patfags.td
index 0912b05127ef81..2cf29769dc13a7 100644
--- a/llvm/test/TableGen/predicate-patfags.td
+++ b/llvm/test/TableGen/predicate-patfags.td
@@ -39,10 +39,10 @@ def TGTmul24_oneuse : PatFrag<
 }
 
 // SDAG: OPC_CheckOpcode, TARGET_VAL(ISD::INTRINSIC_W_CHAIN),
-// SDAG: OPC_CheckPredicate, 0, // Predicate_TGTmul24_oneuse
+// SDAG: OPC_CheckPredicate0, // Predicate_TGTmul24_oneuse
 
 // SDAG: OPC_CheckOpcode, TARGET_VAL(TargetISD::MUL24),
-// SDAG: OPC_CheckPredicate, 0, // Predicate_TGTmul24_oneuse
+// SDAG: OPC_CheckPredicate0, // Predicate_TGTmul24_oneuse
 
 // GISEL: GIM_CheckOpcode, /*MI*/1, GIMT_Encode2(TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS),
 // GISEL: GIM_CheckIntrinsicID, /*MI*/1, /*Op*/1, GIMT_Encode2(Intrinsic::tgt_mul24),
diff --git a/llvm/utils/TableGen/DAGISelMatcherEmitter.cpp b/llvm/utils/TableGen/DAGISelMatcherEmitter.cpp
index a3e2facf948e89..69d040f9b85c49 100644
--- a/llvm/utils/TableGen/DAGISelMatcherEmitter.cpp
+++ b/llvm/utils/TableGen/DAGISelMatcherEmitter.cpp
@@ -52,9 +52,8 @@ class MatcherTableEmitter {
 
   SmallVector<unsigned, Matcher::HighestKind+1> OpcodeCounts;
 
-  DenseMap<TreePattern *, unsigned> NodePredicateMap;
-  std::vector<TreePredicateFn> NodePredicates;
-  std::vector<TreePredicateFn> NodePredicatesWithOperands;
+  std::vector<TreePattern *> NodePredicates; // Most-used first.
+  std::vector<TreePattern *> NodePredicatesWithOperands; // Most-used first.
 
   // We de-duplicate the predicates by code string, and use this map to track
   // all the patterns with "identical" predicates.
@@ -88,6 +87,8 @@ class MatcherTableEmitter {
     DenseMap<const ComplexPattern *, unsigned> ComplexPatternUsage;
     // Record the usage of PatternPredicate.
     std::map<StringRef, unsigned> PatternPredicateUsage;
+    // Record the usage of Predicate. MapVector keeps iteration deterministic.
+    MapVector<TreePattern *, unsigned> PredicateUsage;
 
     // Iterate the whole MatcherTable once and do some statistics.
     std::function<void(const Matcher *)> Statistic = [&](const Matcher *N) {
@@ -105,6 +106,8 @@ class MatcherTableEmitter {
           ++ComplexPatternUsage[&CPM->getPattern()];
         else if (auto *CPPM = dyn_cast<CheckPatternPredicateMatcher>(N))
           ++PatternPredicateUsage[CPPM->getPredicate()];
+        else if (auto *PM = dyn_cast<CheckPredicateMatcher>(N))
+          ++PredicateUsage[PM->getPredicate().getOrigPatFragRecord()];
         N = N->getNext();
       }
     };
@@ -125,6 +128,39 @@ class MatcherTableEmitter {
          [](const auto &A, const auto &B) { return A.second > B.second; });
     for (const auto &PatternPredicate : PatternPredicateList)
       PatternPredicates.push_back(PatternPredicate.first);
+
+    // Group predicates with the same code, then sort the groups by usage.
+    for (const auto &Usage : PredicateUsage) {
+      TreePattern *TP = Usage.first;
+      TreePredicateFn Pred(TP);
+      NodePredicatesByCodeToRun[Pred.getCodeToRunOnSDNode()].push_back(TP);
+    }
+
+    std::vector<std::pair<TreePattern *, unsigned>> PredicateList;
+    // Sum the usage.
+    for (auto &Predicate : NodePredicatesByCodeToRun) {
+      TinyPtrVector<TreePattern *> &TPs = Predicate.second;
+      sort(TPs, [](const auto *A, const auto *B) {
+        return A->getRecord()->getName() < B->getRecord()->getName();
+      });
+      unsigned Uses = 0;
+      for (TreePattern *TP : TPs)
+        Uses += PredicateUsage.lookup(TP);
+
+      // Only the first predicate is added since they all share the same code.
+      PredicateList.push_back({TPs[0], Uses});
+    }
+    // stable_sort keeps equal-usage entries in insertion order (determinism).
+    stable_sort(PredicateList, [](const auto &A, const auto &B) {
+      return A.second > B.second;
+    });
+    for (const auto &Predicate : PredicateList) {
+      TreePattern *TP = Predicate.first;
+      if (TreePredicateFn(TP).usesOperands())
+        NodePredicatesWithOperands.push_back(TP);
+      else
+        NodePredicates.push_back(TP);
+    }
   }
 
   unsigned EmitMatcherList(const Matcher *N, const unsigned Indent,
@@ -139,7 +175,7 @@ class MatcherTableEmitter {
   void EmitPatternMatchTable(raw_ostream &OS);
 
 private:
-  void EmitNodePredicatesFunction(const std::vector<TreePredicateFn> &Preds,
+  void EmitNodePredicatesFunction(const std::vector<TreePattern *> &Preds,
                                   StringRef Decl, raw_ostream &OS);
 
   unsigned SizeMatcher(Matcher *N, raw_ostream &OS);
@@ -148,33 +184,13 @@ class MatcherTableEmitter {
                        raw_ostream &OS);
 
   unsigned getNodePredicate(TreePredicateFn Pred) {
-    TreePattern *TP = Pred.getOrigPatFragRecord();
-    unsigned &Entry = NodePredicateMap[TP];
-    if (Entry == 0) {
-      TinyPtrVector<TreePattern *> &SameCodePreds =
-          NodePredicatesByCodeToRun[Pred.getCodeToRunOnSDNode()];
-      if (SameCodePreds.empty()) {
-        // We've never seen a predicate with the same code: allocate an entry.
-        if (Pred.usesOperands()) {
-          NodePredicatesWithOperands.push_back(Pred);
-          Entry = NodePredicatesWithOperands.size();
-        } else {
-          NodePredicates.push_back(Pred);
-          Entry = NodePredicates.size();
-        }
-      } else {
-        // We did see an identical predicate: re-use it.
-        Entry = NodePredicateMap[SameCodePreds.front()];
-        assert(Entry != 0);
-        assert(TreePredicateFn(SameCodePreds.front()).usesOperands() ==
-               Pred.usesOperands() &&
-               "PatFrags with some code must have same usesOperands setting");
-      }
-      // In both cases, we've never seen this particular predicate before, so
-      // mark it in the list of predicates sharing the same code.
-      SameCodePreds.push_back(TP);
-    }
-    return Entry-1;
+    // Same-code predicates share the index of their group's first entry.
+    TreePattern *PredPat =
+        NodePredicatesByCodeToRun[Pred.getCodeToRunOnSDNode()][0];
+    assert(TreePredicateFn(PredPat).usesOperands() == Pred.usesOperands());
+    const auto &Preds =
+        Pred.usesOperands() ? NodePredicatesWithOperands : NodePredicates;
+    return llvm::find(Preds, PredPat) - Preds.begin();
   }
 
   unsigned getPatternPredicate(StringRef PredName) {
@@ -529,6 +545,7 @@ EmitMatcher(const Matcher *N, const unsigned Indent, unsigned CurrentIdx,
   case Matcher::CheckPredicate: {
     TreePredicateFn Pred = cast<CheckPredicateMatcher>(N)->getPredicate();
     unsigned OperandBytes = 0;
+    unsigned PredNo = getNodePredicate(Pred);
 
     if (Pred.usesOperands()) {
       unsigned NumOps = cast<CheckPredicateMatcher>(N)->getNumOperands();
@@ -537,10 +554,15 @@ EmitMatcher(const Matcher *N, const unsigned Indent, unsigned CurrentIdx,
         OS << cast<CheckPredicateMatcher>(N)->getOperandNo(i) << ", ";
       OperandBytes = 1 + NumOps;
     } else {
-      OS << "OPC_CheckPredicate, ";
+      if (PredNo < 8) {
+        OperandBytes = -1; // PredNo is encoded in the opcode; no index byte.
+        OS << "OPC_CheckPredicate" << PredNo << ", ";
+      } else
+        OS << "OPC_CheckPredicate, ";
     }
 
-    OS << getNodePredicate(Pred) << ',';
+    if (PredNo >= 8 || Pred.usesOperands())
+      OS << PredNo << ',';
     if (!OmitComments)
       OS << " // " << Pred.getFnName();
     OS << '\n';
@@ -1029,8 +1051,7 @@ EmitMatcherList(const Matcher *N, const unsigned Indent, unsigned CurrentIdx,
 }
 
 void MatcherTableEmitter::EmitNodePredicatesFunction(
-    const std::vector<TreePredicateFn> &Preds, StringRef Decl,
-    raw_ostream &OS) {
+    const std::vector<TreePattern *> &Preds, StringRef Decl, raw_ostream &OS) {
   if (Preds.empty())
     return;
 
@@ -1040,7 +1061,7 @@ void MatcherTableEmitter::EmitNodePredicatesFunction(
   OS << "  default: llvm_unreachable(\"Invalid predicate in table?\");\n";
   for (unsigned i = 0, e = Preds.size(); i != e; ++i) {
     // Emit the predicate code corresponding to this pattern.
-    const TreePredicateFn PredFn = Preds[i];
+    const TreePredicateFn PredFn(Preds[i]);
     assert(!PredFn.isAlwaysTrue() && "No code in this predicate");
     std::string PredFnCodeStr = PredFn.getCodeToRunOnSDNode();
 



More information about the llvm-commits mailing list