[Mlir-commits] [mlir] bb25060 - [mlir-tblgen] Add DagNode StaticMatcher.

Chia-hung Duan llvmlistbot at llvm.org
Mon Sep 20 16:40:19 PDT 2021


Author: Chia-hung Duan
Date: 2021-09-20T23:37:42Z
New Revision: bb2506061b06e9786b5eb9c458f52f9ba7e52a73

URL: https://github.com/llvm/llvm-project/commit/bb2506061b06e9786b5eb9c458f52f9ba7e52a73
DIFF: https://github.com/llvm/llvm-project/commit/bb2506061b06e9786b5eb9c458f52f9ba7e52a73.diff

LOG: [mlir-tblgen] Add DagNode StaticMatcher.

Some patterns may share common DAG structures. Generate a static
function that performs the match logic, to reduce the binary size.

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D105797

Added: 
    mlir/test/mlir-tblgen/rewriter-static-matcher.td

Modified: 
    mlir/include/mlir/TableGen/Pattern.h
    mlir/lib/TableGen/Pattern.cpp
    mlir/tools/mlir-tblgen/RewriterGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h
index a3786cd8e0b89..fdc510447d1fb 100644
--- a/mlir/include/mlir/TableGen/Pattern.h
+++ b/mlir/include/mlir/TableGen/Pattern.h
@@ -18,6 +18,7 @@
 #include "mlir/TableGen/Argument.h"
 #include "mlir/TableGen/Operator.h"
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/Hashing.h"
 #include "llvm/ADT/StringMap.h"
 #include "llvm/ADT/StringSet.h"
 
@@ -198,6 +199,7 @@ class DagNode {
 
 private:
   friend class SymbolInfoMap;
+  friend llvm::DenseMapInfo<DagNode>;
   const void *getAsOpaquePointer() const { return node; }
 
   const llvm::DagInit *node; // nullptr means null DagNode
@@ -242,10 +244,17 @@ class SymbolInfoMap {
   // Class for information regarding a symbol.
   class SymbolInfo {
   public:
+    // Returns a type string of a variable.
+    std::string getVarTypeStr(StringRef name) const;
+
     // Returns a string for defining a variable named as `name` to store the
     // value bound by this symbol.
     std::string getVarDecl(StringRef name) const;
 
+    // Returns a string for defining an argument which passes the reference of
+    // the variable.
+    std::string getArgDecl(StringRef name) const;
+
     // Returns a variable name for the symbol named as `name`.
     std::string getVarName(StringRef name) const;
 
@@ -383,6 +392,7 @@ class SymbolInfoMap {
   // with index `argIndex` for operator `op`.
   const_iterator findBoundSymbol(StringRef key, DagNode node,
                                  const Operator &op, int argIndex) const;
+  const_iterator findBoundSymbol(StringRef key, SymbolInfo symbolInfo) const;
 
   // Returns the bounds of a range that includes all the elements which
   // bind to the `key`.
@@ -474,15 +484,15 @@ class Pattern {
   // pair).
   std::vector<IdentifierLine> getLocation() const;
 
-private:
-  // Helper function to verify variabld binding.
-  void verifyBind(bool result, StringRef symbolName);
-
   // Recursively collects all bound symbols inside the DAG tree rooted
   // at `tree` and updates the given `infoMap`.
   void collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
                            bool isSrcPattern);
 
+private:
+  // Helper function to verify variable binding.
+  void verifyBind(bool result, StringRef symbolName);
+
   // The TableGen definition of this pattern.
   const llvm::Record &def;
 
@@ -495,4 +505,24 @@ class Pattern {
 } // end namespace tblgen
 } // end namespace mlir
 
+namespace llvm {
+template <>
+struct DenseMapInfo<mlir::tblgen::DagNode> {
+  static mlir::tblgen::DagNode getEmptyKey() {
+    return mlir::tblgen::DagNode(
+        llvm::DenseMapInfo<llvm::DagInit *>::getEmptyKey());
+  }
+  static mlir::tblgen::DagNode getTombstoneKey() {
+    return mlir::tblgen::DagNode(
+        llvm::DenseMapInfo<llvm::DagInit *>::getTombstoneKey());
+  }
+  static unsigned getHashValue(mlir::tblgen::DagNode node) {
+    return llvm::hash_value(node.getAsOpaquePointer());
+  }
+  static bool isEqual(mlir::tblgen::DagNode lhs, mlir::tblgen::DagNode rhs) {
+    return lhs.node == rhs.node;
+  }
+};
+} // end namespace llvm
+
 #endif // MLIR_TABLEGEN_PATTERN_H_

diff  --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp
index ce225ed93076c..b78abf7c9e3bc 100644
--- a/mlir/lib/TableGen/Pattern.cpp
+++ b/mlir/lib/TableGen/Pattern.cpp
@@ -230,45 +230,50 @@ std::string SymbolInfoMap::SymbolInfo::getVarName(StringRef name) const {
   return alternativeName.hasValue() ? alternativeName.getValue() : name.str();
 }
 
-std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
-  LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': ");
+std::string SymbolInfoMap::SymbolInfo::getVarTypeStr(StringRef name) const {
+  LLVM_DEBUG(llvm::dbgs() << "getVarTypeStr for '" << name << "': ");
   switch (kind) {
   case Kind::Attr: {
-    if (op) {
-      auto type = op->getArg(getArgIndex())
-                      .get<NamedAttribute *>()
-                      ->attr.getStorageType();
-      return std::string(formatv("{0} {1};\n", type, name));
-    }
+    if (op)
+      return op->getArg(getArgIndex())
+          .get<NamedAttribute *>()
+          ->attr.getStorageType()
+          .str();
     // TODO(suderman): Use a more exact type when available.
-    return std::string(formatv("Attribute {0};\n", name));
+    return "Attribute";
   }
   case Kind::Operand: {
     // Use operand range for captured operands (to support potential variadic
     // operands).
-    return std::string(
-        formatv("::mlir::Operation::operand_range {0}(op0->getOperands());\n",
-                getVarName(name)));
+    return "::mlir::Operation::operand_range";
   }
   case Kind::Value: {
-    return std::string(formatv("::mlir::Value {0};\n", name));
+    return "::mlir::Value";
   }
   case Kind::MultipleValues: {
-    // This is for the variable used in the source pattern. Each named value in
-    // source pattern will only be bound to a Value. The others in the result
-    // pattern may be associated with multiple Values as we will use `auto` to
-    // do the type inference.
-    return std::string(formatv(
-        "::mlir::Value {0}_raw; ::mlir::ValueRange {0}({0}_raw);\n", name));
+    return "::mlir::ValueRange";
   }
   case Kind::Result: {
     // Use the op itself for captured results.
-    return std::string(formatv("{0} {1};\n", op->getQualCppClassName(), name));
+    return op->getQualCppClassName();
   }
   }
   llvm_unreachable("unknown kind");
 }
 
+std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
+  LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': ");
+  std::string varInit = kind == Kind::Operand ? "(op0->getOperands())" : "";
+  return std::string(
+      formatv("{0} {1}{2};\n", getVarTypeStr(name), getVarName(name), varInit));
+}
+
+std::string SymbolInfoMap::SymbolInfo::getArgDecl(StringRef name) const {
+  LLVM_DEBUG(llvm::dbgs() << "getArgDecl for '" << name << "': ");
+  return std::string(
+      formatv("{0} &{1}", getVarTypeStr(name), getVarName(name)));
+}
+
 std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
     StringRef name, int index, const char *fmt, const char *separator) const {
   LLVM_DEBUG(llvm::dbgs() << "getValueAndRangeUse for '" << name << "': ");
@@ -486,11 +491,14 @@ SymbolInfoMap::const_iterator SymbolInfoMap::find(StringRef key) const {
 SymbolInfoMap::const_iterator
 SymbolInfoMap::findBoundSymbol(StringRef key, DagNode node, const Operator &op,
                                int argIndex) const {
+  return findBoundSymbol(key, SymbolInfo::getOperand(node, &op, argIndex));
+}
+
+SymbolInfoMap::const_iterator
+SymbolInfoMap::findBoundSymbol(StringRef key, SymbolInfo symbolInfo) const {
   std::string name = getValuePackName(key).str();
   auto range = symbolInfoMap.equal_range(name);
 
-  const auto symbolInfo = SymbolInfo::getOperand(node, &op, argIndex);
-
   for (auto it = range.first; it != range.second; ++it)
     if (it->second.dagAndConstant == symbolInfo.dagAndConstant)
       return it;

diff  --git a/mlir/test/mlir-tblgen/rewriter-static-matcher.td b/mlir/test/mlir-tblgen/rewriter-static-matcher.td
new file mode 100644
index 0000000000000..cfd80a40fb1c9
--- /dev/null
+++ b/mlir/test/mlir-tblgen/rewriter-static-matcher.td
@@ -0,0 +1,48 @@
+// RUN: mlir-tblgen -gen-rewriters -I %S/../../include %s | FileCheck %s
+
+include "mlir/IR/OpBase.td"
+
+def Test_Dialect : Dialect {
+  let name = "test";
+}
+class NS_Op<string mnemonic, list<OpTrait> traits> :
+    Op<Test_Dialect, mnemonic, traits>;
+
+def AOp : NS_Op<"a_op", []> {
+  let arguments = (ins
+    AnyInteger:$any_integer
+  );
+
+  let results = (outs AnyInteger);
+}
+
+def BOp : NS_Op<"b_op", []> {
+  let arguments = (ins
+    AnyAttr: $any_attr,
+    AnyInteger
+  );
+
+  let results = (outs AnyInteger);
+}
+
+def COp : NS_Op<"c_op", []> {
+  let arguments = (ins
+    AnyAttr: $any_attr,
+    AnyInteger
+  );
+
+  let results = (outs AnyInteger);
+}
+
+// Test static matcher for duplicate DagNode
+// ---
+
+// CHECK: static ::mlir::LogicalResult static_dag_matcher_0
+
+// CHECK: if(failed(static_dag_matcher_0(rewriter, op1, tblgen_ops
+def : Pat<(AOp (BOp I32Attr:$attr, I32:$int)),
+          (AOp $int)>;
+
+// CHECK: if(failed(static_dag_matcher_0(rewriter, op1, tblgen_ops
+def : Pat<(COp $_, (BOp I32Attr:$attr, I32:$int)),
+          (COp $attr, $int)>;

diff  --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 37fe800c2d4c1..1c2ba75de6d3b 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -18,6 +18,7 @@
 #include "mlir/TableGen/Pattern.h"
 #include "mlir/TableGen/Predicate.h"
 #include "mlir/TableGen/Type.h"
+#include "llvm/ADT/FunctionExtras.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/Support/CommandLine.h"
@@ -54,13 +55,20 @@ struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
 //===----------------------------------------------------------------------===//
 
 namespace {
+
+class StaticMatcherHelper;
+
 class PatternEmitter {
 public:
-  PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os);
+  PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os,
+                 StaticMatcherHelper &helper);
 
   // Emits the mlir::RewritePattern struct named `rewriteName`.
   void emit(StringRef rewriteName);
 
+  // Emits the static function of DAG matcher.
+  void emitStaticMatcher(DagNode tree, std::string funcName);
+
 private:
   // Emits the code for matching ops.
   void emitMatchLogic(DagNode tree, StringRef opName);
@@ -75,6 +83,9 @@ class PatternEmitter {
   // Emits C++ statements for matching the DAG structure.
   void emitMatch(DagNode tree, StringRef name, int depth);
 
+  // Emit C++ function call to static DAG matcher.
+  void emitStaticMatchCall(DagNode tree, StringRef name);
+
   // Emits C++ statements for matching using a native code call.
   void emitNativeCodeMatch(DagNode tree, StringRef name, int depth);
 
@@ -216,6 +227,8 @@ class PatternEmitter {
   // Map for all bound symbols' info.
   SymbolInfoMap symbolInfoMap;
 
+  StaticMatcherHelper &staticMatcherHelper;
+
   // The next unused ID for newly created values.
   unsigned nextValueId;
 
@@ -223,16 +236,79 @@ class PatternEmitter {
 
   // Format contexts containing placeholder substitutions.
   FmtContext fmtCtx;
+};
+
+// Tracks DagNodes that are referenced multiple times across patterns. Enables
+// generating static matcher functions for DagNodes referenced multiple times
+// rather than inlining them.
+class StaticMatcherHelper {
+public:
+  StaticMatcherHelper(RecordOperatorMap &mapper);
+
+  // Determine if we should inline the match logic or delegate to a static
+  // function.
+  bool useStaticMatcher(DagNode node) {
+    return refStats[node] > kStaticMatcherThreshold;
+  }
+
+  // Get the name of the static DAG matcher function corresponding to the node.
+  std::string getMatcherName(DagNode node) {
+    assert(useStaticMatcher(node));
+    return matcherNames[node];
+  }
+
+  // Collect the `Record`s, i.e., the DRR, so that we can get the information of
+  // the duplicated DAGs.
+  void addPattern(Record *record);
+
+  // Emit all static functions of DAG Matcher.
+  void populateStaticMatchers(raw_ostream &os);
 
-  // Number of op processed.
-  int opCounter = 0;
+private:
+  static constexpr unsigned kStaticMatcherThreshold = 1;
+
+  // Consider two patterns as down below,
+  //   DagNode_Root_A    DagNode_Root_B
+  //       \                 \
+  //     DagNode_C         DagNode_C
+  //         \                 \
+  //       DagNode_D         DagNode_D
+  //
+  // DagNode_Root_A and DagNode_Root_B share the same subtree which consists of
+  // DagNode_C and DagNode_D. Both DagNode_C and DagNode_D are referenced
+  // multiple times so we'll have static matchers for both of them. When we're
+  // emitting the match logic for DagNode_C, we will check if DagNode_D has the
+  // static matcher generated. If so, then we'll generate a call to the
+  // function, inline otherwise. In this case, inlining is not what we want. As
+  // a result, generate the static matcher in topological order to ensure all
+  // the dependent static matchers are generated and we can avoid accidentally
+  // inlining.
+  //
+  // The topological order of all the DagNodes among all patterns.
+  SmallVector<std::pair<DagNode, Record *>> topologicalOrder;
+
+  RecordOperatorMap &opMap;
+
+  // Records of the static function name of each DagNode
+  DenseMap<DagNode, std::string> matcherNames;
+
+  // After collecting all the DagNodes in each pattern, `refStats` records the
+  // number of users of each DagNode. We generate a static matcher for a
+  // DagNode when its number of users exceeds a certain threshold.
+  DenseMap<DagNode, unsigned> refStats;
+
+  // Number of static matchers generated. This is used to generate a unique
+  // name for each DagNode.
+  int staticMatcherCounter = 0;
 };
+
 } // end anonymous namespace
 
 PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
-                               raw_ostream &os)
+                               raw_ostream &os, StaticMatcherHelper &helper)
     : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper),
-      symbolInfoMap(pat->getLoc()), nextValueId(0), os(os) {
+      symbolInfoMap(pat->getLoc()), staticMatcherHelper(helper), nextValueId(0),
+      os(os) {
   fmtCtx.withBuilder("rewriter");
 }
 
@@ -246,6 +322,33 @@ std::string PatternEmitter::handleConstantAttr(Attribute attr,
   return std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value));
 }
 
+void PatternEmitter::emitStaticMatcher(DagNode tree, std::string funcName) {
+  os << formatv(
+      "static ::mlir::LogicalResult {0}(::mlir::PatternRewriter &rewriter, "
+      "::mlir::Operation *op0, ::llvm::SmallVector<::mlir::Operation "
+      "*, 4> &tblgen_ops",
+      funcName);
+
+  // We pass the reference of the variables that need to be captured. Hence we
+  // need to collect all the symbols in the tree first.
+  pattern.collectBoundSymbols(tree, symbolInfoMap, /*isSrcPattern=*/true);
+  symbolInfoMap.assignUniqueAlternativeNames();
+  for (const auto &info : symbolInfoMap)
+    os << formatv(", {0}", info.second.getArgDecl(info.first));
+
+  os << ") {\n";
+  os.indent();
+  os << "(void)tblgen_ops;\n";
+
+  // Note that a static matcher is considered at least one step from the match
+  // entry.
+  emitMatch(tree, "op0", /*depth=*/1);
+
+  os << "return ::mlir::success();\n";
+  os.unindent();
+  os << "}\n\n";
+}
+
 // Helper function to match patterns.
 void PatternEmitter::emitMatch(DagNode tree, StringRef name, int depth) {
   if (tree.isNativeCodeCall()) {
@@ -261,6 +364,36 @@ void PatternEmitter::emitMatch(DagNode tree, StringRef name, int depth) {
   PrintFatalError(loc, "encountered non-op, non-NativeCodeCall match.");
 }
 
+void PatternEmitter::emitStaticMatchCall(DagNode tree, StringRef opName) {
+  std::string funcName = staticMatcherHelper.getMatcherName(tree);
+  os << formatv("if(failed({0}(rewriter, {1}, tblgen_ops", funcName, opName);
+
+  // TODO(chiahungduan): Add a lookupBoundSymbols() to do the subtree lookup in
+  // one pass.
+
+  // In general, bound symbol should have the unique name in the pattern but
+  // for the operand, binding same symbol to multiple operands imply a
+  // constraint at the same time. In this case, we will rename those operands
+  // with different names. As a result, we need to collect all the symbolInfos
+  // from the DagNode then get the updated name of the local variables from the
+  // global symbolInfoMap.
+
+  // Collect all the bound symbols in the Dag
+  SymbolInfoMap localSymbolMap(loc);
+  pattern.collectBoundSymbols(tree, localSymbolMap, /*isSrcPattern=*/true);
+
+  for (const auto &info : localSymbolMap) {
+    auto name = info.first;
+    auto symboInfo = info.second;
+    auto ret = symbolInfoMap.findBoundSymbol(name, symboInfo);
+    os << formatv(", {0}", ret->second.getVarName(name));
+  }
+
+  os << "))) {\n";
+  os.scope().os << "return ::mlir::failure();\n";
+  os << "}\n";
+}
+
 // Helper function to match patterns.
 void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
                                          int depth) {
@@ -268,6 +401,21 @@ void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
   LLVM_DEBUG(tree.print(llvm::dbgs()));
   LLVM_DEBUG(llvm::dbgs() << '\n');
 
+  // The order of generating static matchers follows the topological order so
+  // that every dependent DagNode already has its static matcher generated if
+  // needed. We check `getMatcherName(tree).empty()` to detect the case where
+  // we are generating the static matcher for this DagNode itself. In that
+  // case, we need to emit the function body rather than a function call.
+  if (staticMatcherHelper.useStaticMatcher(tree) &&
+      !staticMatcherHelper.getMatcherName(tree).empty()) {
+    emitStaticMatchCall(tree, opName);
+
+    // NativeCodeCall will never be at depth 0, so we don't need to capture
+    // the root operation as emitOpMatch() does.
+
+    return;
+  }
+
   // TODO(suderman): iterate through arguments, determine their types, output
   // names.
   SmallVector<std::string, 8> capture;
@@ -356,7 +504,28 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
                           << op.getOperationName() << "' at depth " << depth
                           << '\n');
 
-  std::string castedName = formatv("castedOp{0}", depth);
+  auto getCastedName = [depth]() -> std::string {
+    return formatv("castedOp{0}", depth);
+  };
+
+  // The order of generating static matchers follows the topological order so
+  // that every dependent DagNode already has its static matcher generated if
+  // needed. We check `getMatcherName(tree).empty()` to detect the case where
+  // we are generating the static matcher for this DagNode itself. In that
+  // case, we need to emit the function body rather than a function call.
+  if (staticMatcherHelper.useStaticMatcher(tree) &&
+      !staticMatcherHelper.getMatcherName(tree).empty()) {
+    emitStaticMatchCall(tree, opName);
+    // The rewriter codegen assumes that castedOp0 captures the root
+    // operation. Manually add it if the root DagNode uses a static matcher.
+    if (depth == 0)
+      os << formatv("auto {2} = ::llvm::dyn_cast_or_null<{1}>({0}); "
+                    "(void){2};\n",
+                    opName, op.getQualCppClassName(), getCastedName());
+    return;
+  }
+
+  std::string castedName = getCastedName();
   os << formatv("auto {0} = ::llvm::dyn_cast<{2}>({1}); "
                 "(void){0};\n",
                 castedName, opName, op.getQualCppClassName());
@@ -405,7 +574,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
                      formatv("\"Operand {0} of {1} has null definingOp\"",
                              nextOperand++, castedName));
       emitMatch(argTree, argName, depth + 1);
-      os << formatv("tblgen_ops[{0}] = {1};\n", ++opCounter, argName);
+      os << formatv("tblgen_ops.push_back({0});\n", argName);
       os.unindent() << "}\n";
       continue;
     }
@@ -704,13 +873,12 @@ void PatternEmitter::emit(StringRef rewriteName) {
       }
       // TODO: capture ops with consistent numbering so that it can be
       // reused for fused loc.
-      os << formatv("::mlir::Operation *tblgen_ops[{0}];\n\n",
-                    pattern.getSourcePattern().getNumOps());
+      os << "::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops;\n\n";
       LLVM_DEBUG(llvm::dbgs()
                  << "done creating local variables for capturing matches\n");
 
       os << "// Match\n";
-      os << "tblgen_ops[0] = op0;\n";
+      os << "tblgen_ops.push_back(op0);\n";
       emitMatchLogic(sourceTree, "op0");
 
       os << "\n// Rewrite\n";
@@ -1399,17 +1567,67 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
   }
 }
 
+StaticMatcherHelper::StaticMatcherHelper(RecordOperatorMap &mapper)
+    : opMap(mapper) {}
+
+void StaticMatcherHelper::populateStaticMatchers(raw_ostream &os) {
+  // PatternEmitter will use the static matcher if there's one generated. To
+  // ensure that all the dependent static matchers are generated before emitting
+  // the matching logic of the DagNode, we use topological order to achieve it.
+  for (auto &dagInfo : topologicalOrder) {
+    DagNode node = dagInfo.first;
+    if (!useStaticMatcher(node))
+      continue;
+
+    std::string funcName =
+        formatv("static_dag_matcher_{0}", staticMatcherCounter++);
+    assert(matcherNames.find(node) == matcherNames.end());
+    PatternEmitter(dagInfo.second, &opMap, os, *this)
+        .emitStaticMatcher(node, funcName);
+    matcherNames[node] = funcName;
+  }
+}
+
+void StaticMatcherHelper::addPattern(Record *record) {
+  Pattern pat(record, &opMap);
+
+  // While generating the function body of the DAG matcher, it may depend on
+  // other DAG matchers. To ensure the dependent matchers are ready, we compute
+  // the topological order for all the DAGs and emit the DAG matchers in this
+  // order.
+  llvm::unique_function<void(DagNode)> dfs = [&](DagNode node) {
+    ++refStats[node];
+
+    if (refStats[node] != 1)
+      return;
+
+    for (unsigned i = 0, e = node.getNumArgs(); i < e; ++i)
+      if (DagNode sibling = node.getArgAsNestedDag(i))
+        dfs(sibling);
+
+    topologicalOrder.push_back(std::make_pair(node, record));
+  };
+
+  dfs(pat.getSourcePattern());
+}
+
 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
   emitSourceFileHeader("Rewriters", os);
 
   const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
-  auto numPatterns = patterns.size();
 
   // We put the map here because it can be shared among multiple patterns.
   RecordOperatorMap recordOpMap;
 
+  // Examine all the patterns and generate static matchers for the duplicated
+  // DagNodes.
+  StaticMatcherHelper staticMatcher(recordOpMap);
+  for (Record *p : patterns)
+    staticMatcher.addPattern(p);
+  staticMatcher.populateStaticMatchers(os);
+
   std::vector<std::string> rewriterNames;
-  rewriterNames.reserve(numPatterns);
+  rewriterNames.reserve(patterns.size());
 
   std::string baseRewriterName = "GeneratedConvert";
   int rewriterIndex = 0;
@@ -1425,7 +1643,7 @@ static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
     }
     LLVM_DEBUG(llvm::dbgs()
                << "=== start generating pattern '" << name << "' ===\n");
-    PatternEmitter(p, &recordOpMap, os).emit(name);
+    PatternEmitter(p, &recordOpMap, os, staticMatcher).emit(name);
     LLVM_DEBUG(llvm::dbgs()
                << "=== done generating pattern '" << name << "' ===\n");
     rewriterNames.push_back(std::move(name));


        


More information about the Mlir-commits mailing list