[llvm] 55c5740 - [gicombiner] Add support for arbitrary match data being passed from match to apply

Daniel Sanders via llvm-commits llvm-commits at lists.llvm.org
Wed Dec 18 04:27:58 PST 2019


Author: Daniel Sanders
Date: 2019-12-18T12:27:29Z
New Revision: 55c57408b0e70b188b0505e011172f13ec3b15fc

URL: https://github.com/llvm/llvm-project/commit/55c57408b0e70b188b0505e011172f13ec3b15fc
DIFF: https://github.com/llvm/llvm-project/commit/55c57408b0e70b188b0505e011172f13ec3b15fc.diff

LOG: [gicombiner] Add support for arbitrary match data being passed from match to apply

Summary:
This is used by the extending_loads combine to tell the apply step which
use is the preferred one to fold, and that the other uses should be
rewritten to consume the result of that fold.

Depends on D69117

Reviewers: volkan, bogner

Reviewed By: volkan

Subscribers: hiraditya, Petar.Avramovic, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D69147

Added: 
    

Modified: 
    llvm/include/llvm/Target/GlobalISel/Combine.td
    llvm/lib/Target/AArch64/AArch64PreLegalizerCombiner.cpp
    llvm/utils/TableGen/GICombinerEmitter.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Target/GlobalISel/Combine.td b/llvm/include/llvm/Target/GlobalISel/Combine.td
index dcac399fd693..577a251da241 100644
--- a/llvm/include/llvm/Target/GlobalISel/Combine.td
+++ b/llvm/include/llvm/Target/GlobalISel/Combine.td
@@ -66,11 +66,20 @@ class GIDefKindWithArgs;
 ///       is incorrect.
 def root : GIDefKind;
 
+/// Declares data that is passed from the match stage to the apply stage.
+/// Listing one in a rule's defs gives the generated matcher a variable of this
+/// type; the rule's C++ fragments refer to it as ${<name>}.
+class GIDefMatchData<string type> : GIDefKind {
+  /// A C++ type name indicating the storage type. It is emitted verbatim as
+  /// the declared type of the generated variable, so it must name a type that
+  /// is visible where the generated combiner is compiled.
+  string Type = type;
+}
+
+/// Match data for extending_loads: the PreferredTuple filled in by
+/// matchCombineExtendingLoads and read back by applyCombineExtendingLoads.
+def extending_load_matchdata : GIDefMatchData<"PreferredTuple">;
+
 /// The operator at the root of a GICombineRule.Match dag.
 def match;
 /// All arguments of the match operator must be either:
 /// * A subclass of GIMatchKind
 /// * A subclass of GIMatchKindWithArgs
+/// * A subclass of Instruction
 /// * A MIR code block (deprecated)
 /// The GIMatchKind and GIMatchKindWithArgs cases are described in more detail
 /// in their definitions below.
@@ -93,6 +102,11 @@ def copy_prop : GICombineRule<
   (apply [{ Helper.applyCombineCopy(${d}); }])>;
 def trivial_combines : GICombineGroup<[copy_prop]>;
 
+/// Combine for extending uses of a load. The match step stores the preferred
+/// use to fold in $matchinfo; the apply step reads it back and rewrites the
+/// remaining uses accordingly.
+def extending_loads : GICombineRule<
+  (defs root:$root, extending_load_matchdata:$matchinfo),
+  (match [{ return Helper.matchCombineExtendingLoads(${root}, ${matchinfo}); }]),
+  (apply [{ Helper.applyCombineExtendingLoads(${root}, ${matchinfo}); }])>;
+
 // FIXME: Is there a reason this wasn't in tryCombine? I've left it out of
 //        all_combines because it wasn't there.
 def elide_br_by_inverting_cond : GICombineRule<
@@ -100,4 +114,6 @@ def elide_br_by_inverting_cond : GICombineRule<
   (match [{ return Helper.matchElideBrByInvertingCond(${d}); }]),
   (apply [{ Helper.applyElideBrByInvertingCond(${d}); }])>;
 
-def all_combines : GICombineGroup<[trivial_combines]>;
+/// Group for the extending-load combines (currently just extending_loads).
+def combines_for_extload : GICombineGroup<[extending_loads]>;
+
+def all_combines : GICombineGroup<[trivial_combines, combines_for_extload]>;

diff  --git a/llvm/lib/Target/AArch64/AArch64PreLegalizerCombiner.cpp b/llvm/lib/Target/AArch64/AArch64PreLegalizerCombiner.cpp
index d30ea120bae4..0956930e2e8c 100644
--- a/llvm/lib/Target/AArch64/AArch64PreLegalizerCombiner.cpp
+++ b/llvm/lib/Target/AArch64/AArch64PreLegalizerCombiner.cpp
@@ -62,20 +62,6 @@ bool AArch64PreLegalizerCombinerInfo::combine(GISelChangeObserver &Observer,
   CombinerHelper Helper(Observer, B, KB, MDT);
 
   switch (MI.getOpcode()) {
-  case TargetOpcode::G_CONCAT_VECTORS:
-    return Helper.tryCombineConcatVectors(MI);
-  case TargetOpcode::G_SHUFFLE_VECTOR:
-    return Helper.tryCombineShuffleVector(MI);
-  case TargetOpcode::G_LOAD:
-  case TargetOpcode::G_SEXTLOAD:
-  case TargetOpcode::G_ZEXTLOAD: {
-    bool Changed = false;
-    Changed |= Helper.tryCombineExtendingLoads(MI);
-    Changed |= Helper.tryCombineIndexedLoadStore(MI);
-    return Changed;
-  }
-  case TargetOpcode::G_STORE:
-    return Helper.tryCombineIndexedLoadStore(MI);
   case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS:
     switch (MI.getIntrinsicID()) {
     case Intrinsic::memcpy:
@@ -96,6 +82,18 @@ bool AArch64PreLegalizerCombinerInfo::combine(GISelChangeObserver &Observer,
   if (Generated.tryCombineAll(Observer, MI, B))
     return true;
 
+  switch (MI.getOpcode()) {
+  case TargetOpcode::G_CONCAT_VECTORS:
+    return Helper.tryCombineConcatVectors(MI);
+  case TargetOpcode::G_SHUFFLE_VECTOR:
+    return Helper.tryCombineShuffleVector(MI);
+  case TargetOpcode::G_LOAD:
+  case TargetOpcode::G_SEXTLOAD:
+  case TargetOpcode::G_ZEXTLOAD:
+  case TargetOpcode::G_STORE:
+    return Helper.tryCombineIndexedLoadStore(MI);
+  }
+
   return false;
 }
 

diff  --git a/llvm/utils/TableGen/GICombinerEmitter.cpp b/llvm/utils/TableGen/GICombinerEmitter.cpp
index 223a929a6889..c83936f2ba45 100644
--- a/llvm/utils/TableGen/GICombinerEmitter.cpp
+++ b/llvm/utils/TableGen/GICombinerEmitter.cpp
@@ -61,6 +61,24 @@ StringRef insertStrTab(StringRef S) {
   return StrTab.insert(S).first->first();
 }
 
+/// Declares data that is passed from the match stage to the apply stage.
+class MatchDataInfo {
+  /// The symbol used in the tablegen patterns.
+  StringRef PatternSymbol;
+  /// The C++ type used to declare the variable.
+  StringRef Type;
+  /// The name of the variable as declared in the generated matcher. Owned here
+  /// (unlike the fields above) because callers build it as a temporary string.
+  std::string VariableName;
+
+public:
+  MatchDataInfo(StringRef PatternSymbol, StringRef Type, StringRef VariableName)
+      : PatternSymbol(PatternSymbol), Type(Type), VariableName(VariableName) {}
+
+  StringRef getPatternSymbol() const { return PatternSymbol; }
+  StringRef getType() const { return Type; }
+  StringRef getVariableName() const { return VariableName; }
+};
+
 class RootInfo {
   StringRef PatternSymbol;
 
@@ -71,6 +89,10 @@ class RootInfo {
 };
 
 class CombineRule {
+public:
+
+  using const_matchdata_iterator = std::vector<MatchDataInfo>::const_iterator;
+
   struct VarInfo {
     const GIMatchDagInstr *N;
     const GIMatchDagOperand *Op;
@@ -108,6 +130,33 @@ class CombineRule {
   /// FIXME: This is a temporary measure until we have actual pattern matching
   const CodeInit *MatchingFixupCode = nullptr;
 
+  /// The MatchData defined by the match stage and required by the apply stage.
+  /// This allows the plumbing of arbitrary data from C++ predicates between the
+  /// stages.
+  ///
+  /// For example, suppose you have:
+  ///   %A = <some-constant-expr>
+  ///   %root = G_ADD %0, %A
+  /// you could define a GIMatchPredicate that walks %A, constant folds as much
+  /// as possible and returns an APInt containing the discovered constant. You
+  /// could then declare:
+  ///   def apint : GIDefMatchData<"APInt">;
+  /// add it to the rule with:
+  ///   (defs root:$root, apint:$constant)
+  /// evaluate it in the pattern with a C++ function that takes a
+  /// MachineOperand& and an APInt& with:
+  ///   (match [{MIR %root = G_ADD %0, %A }],
+  ///             (constantfold operand:$A, apint:$constant))
+  /// and finally use it in the apply stage with:
+  ///   (apply (create_operand
+  ///                [{ MachineOperand::CreateImm(${constant}.getZExtValue());
+  ///                }], apint:$constant),
+  ///             [{MIR %root = FOO %0, %constant }])
+  std::vector<MatchDataInfo> MatchDataDecls;
+
+  /// Record a match-data declaration (see MatchDataDecls) for this rule.
+  void declareMatchData(StringRef PatternSymbol, StringRef Type,
+                        StringRef VarName);
+
   bool parseInstructionMatcher(const CodeGenTarget &Target, StringInit *ArgName,
                                const Init &Arg,
                                StringMap<std::vector<VarInfo>> &NamedEdgeDefs,
@@ -139,6 +188,16 @@ class CombineRule {
     return llvm::make_range(Roots.begin(), Roots.end());
   }
 
+  /// Iterate over the match data declared by this rule, in declaration order.
+  iterator_range<const_matchdata_iterator> matchdata_decls() const {
+    return llvm::make_range(MatchDataDecls.cbegin(), MatchDataDecls.cend());
+  }
+
+  /// Make each match-data pattern symbol of this rule expand to the name of
+  /// the variable that holds it in the generated matcher.
+  void declareExpansions(CodeExpansions &Expansions) const {
+    for (const MatchDataInfo &Decl : matchdata_decls())
+      Expansions.declare(Decl.getPatternSymbol(), Decl.getVariableName());
+  }
+
   /// The matcher will begin from the roots and will perform the match by
   /// traversing the edges to cover the whole DAG. This function reverses DAG
   /// edges such that everything is reachable from a root. This is part of the
@@ -243,6 +302,11 @@ StringRef makeNameForAnonPredicate(CombineRule &Rule) {
       to_string(format("__anonpred%d_%d", Rule.getID(), Rule.allocUID())));
 }
 
+/// Records one GIDefMatchData def of this rule in MatchDataDecls.
+void CombineRule::declareMatchData(StringRef PatternSymbol, StringRef Type,
+                                   StringRef VarName) {
+  // parseDefs() derives VarName from the rule ID alone, so a rule with several
+  // GIDefMatchData defs would otherwise declare the same variable more than
+  // once in the generated matcher. Keep the first name unchanged and suffix
+  // later ones with their index to keep them distinct.
+  std::string Name = VarName.str();
+  if (!MatchDataDecls.empty())
+    Name += "_" + std::to_string(MatchDataDecls.size());
+  MatchDataDecls.emplace_back(PatternSymbol, Type, Name);
+}
+
 bool CombineRule::parseDefs() {
   NamedRegionTimer T("parseDefs", "Time spent parsing the defs", "Rule Parsing",
                      "Time spent on rule parsing", TimeRegions);
@@ -260,6 +324,17 @@ bool CombineRule::parseDefs() {
       continue;
     }
 
+    // Subclasses of GIDefMatchData should declare that this rule needs to pass
+    // data from the match stage to the apply stage, and ensure that the
+    // generated matcher has a suitable variable for it to do so.
+    if (Record *MatchDataRec =
+            getDefOfSubClass(*Defs->getArg(I), "GIDefMatchData")) {
+      declareMatchData(Defs->getArgNameStr(I),
+                       MatchDataRec->getValueAsString("Type"),
+                       llvm::to_string(llvm::format("MatchData%d", ID)));
+      continue;
+    }
+
     // Otherwise emit an appropriate error message.
     if (getDefOfSubClass(*Defs->getArg(I), "GIDefKind"))
       PrintError(TheDef.getLoc(),
@@ -556,6 +631,8 @@ void GICombinerEmitter::generateCodeForRule(raw_ostream &OS,
     for (const RootInfo &Root : Rule->roots()) {
       Expansions.declare(Root.getPatternSymbol(), "MI");
     }
+    Rule->declareExpansions(Expansions);
+
     DagInit *Applyer = RuleDef.getValueAsDag("Apply");
     if (Applyer->getOperatorAsDef(RuleDef.getLoc())->getName() !=
         "apply") {
@@ -695,6 +772,12 @@ void GICombinerEmitter::run(raw_ostream &OS) {
      << "  MachineRegisterInfo &MRI = MF->getRegInfo();\n"
      << "  (void)MBB; (void)MF; (void)MRI;\n\n";
 
+  OS << "  // Match data\n";
+  for (const auto &Rule : Rules)
+    for (const auto &I : Rule->matchdata_decls())
+      OS << "  " << I.getType() << " " << I.getVariableName() << ";\n";
+  OS << "\n";
+
   for (const auto &Rule : Rules)
     generateCodeForRule(OS, Rule.get(), "  ");
   OS << "\n  return false;\n"


        


More information about the llvm-commits mailing list