[Mlir-commits] [mlir] MLIR-QUERY DefinitionsMatcher implementation	& DAG (PR #115670)
    Denzel-Brian Budii 
    llvmlistbot at llvm.org
       
    Sun Nov 10 14:15:35 PST 2024
    
    
  
https://github.com/chios202 created https://github.com/llvm/llvm-project/pull/115670
This Pull Request aims to enhance the MLIR-QUERY tool by introducing a new matcher that identifies defining operations for a given operation, tracing back through dependencies up to n hops, including operations within regions. To support this functionality, a new data structure, BoundOperationsGraphBuilder, has been introduced. This structure constructs a Directed Acyclic Graph (DAG) to represent the relationships between operations, allowing for efficient dependency analysis. This addition will improve the tool's capability to query and understand how operations in MLIR modules are interrelated, which is crucial for debugging, optimization, and transformation tasks.
>From 187026822ad0ff0278cc84fc5a614e8ab4f18e36 Mon Sep 17 00:00:00 2001
From: Denzel-Brian Budii <denzel-brian.budii at intel.com>
Date: Wed, 23 Oct 2024 17:42:19 +0000
Subject: [PATCH] MLIR-QUERY DefinitionsMatcher implementation & DAG 	-
 included printing logic for DAG 	- sfinae for match methods
---
 .../mlir/Query/Matcher/ExtraMatchers.h        | 109 ++++++++++++++
 mlir/include/mlir/Query/Matcher/Marshallers.h |  17 ++-
 mlir/include/mlir/Query/Matcher/MatchFinder.h |  18 ++-
 .../mlir/Query/Matcher/MatchersInternal.h     | 134 +++++++++++++----
 .../include/mlir/Query/Matcher/VariantValue.h |  10 +-
 mlir/lib/Query/Matcher/Parser.cpp             |  59 ++++++++
 mlir/lib/Query/Matcher/RegistryManager.cpp    |   2 +
 mlir/lib/Query/Matcher/VariantValue.cpp       |  23 +++
 mlir/lib/Query/Query.cpp                      | 141 +++++++++++++++---
 mlir/tools/mlir-query/mlir-query.cpp          |   7 +-
 10 files changed, 463 insertions(+), 57 deletions(-)
 create mode 100644 mlir/include/mlir/Query/Matcher/ExtraMatchers.h
diff --git a/mlir/include/mlir/Query/Matcher/ExtraMatchers.h b/mlir/include/mlir/Query/Matcher/ExtraMatchers.h
new file mode 100644
index 00000000000000..1764ad35cc9c30
--- /dev/null
+++ b/mlir/include/mlir/Query/Matcher/ExtraMatchers.h
@@ -0,0 +1,127 @@
+//===- ExtraMatchers.h - Extra mlir-query matchers --------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file provides extra matchers for mlir-query. Besides deciding whether
+// an operation matches, these matchers record related operations in a
+// BoundOperationsGraphBuilder so that results can be printed as a graph.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_QUERY_MATCHER_EXTRAMATCHERS_H
+#define MLIR_QUERY_MATCHER_EXTRAMATCHERS_H
+
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Query/Matcher/MatchersInternal.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include <utility>
+
+namespace mlir::query::extramatcher {
+
+namespace detail {
+
+/// Matches operations that are accepted by an inner matcher and have at least
+/// one traceable definition, recording the operations that define their
+/// operands (transitively, for up to `hops` use-def steps) in the bound graph.
+///
+/// An operand is traced to its defining operation, or, for a block argument,
+/// to the operation owning the block. Entry-block arguments of
+/// FunctionOpInterface operations are function inputs and are not traced.
+class DefinitionsMatcher {
+public:
+  DefinitionsMatcher(matcher::DynMatcher &&innerMatcher, unsigned hops)
+      : innerMatcher(std::move(innerMatcher)), hops(hops) {}
+
+  /// Returns true if `op` is accepted by the inner matcher and has at least
+  /// one traceable definition. Only then are `op` (as a root node) and its
+  /// definitions added to `bound`; a failed match leaves `bound` untouched.
+  bool match(Operation *op, matcher::BoundOperationsGraphBuilder &bound) {
+    if (hops == 0 || !innerMatcher.match(op) || !hasTraceableDefinition(op))
+      return false;
+    traceDefinitions(op, bound);
+    return true;
+  }
+
+private:
+  /// Returns the operation `operand` is traced back to, or nullptr if it is
+  /// not traced (function entry-block argument, or a block without a parent
+  /// operation).
+  static Operation *getDefinition(Value operand) {
+    if (Operation *definingOp = operand.getDefiningOp())
+      return definingOp;
+    Block *owner = cast<BlockArgument>(operand).getOwner();
+    Operation *parentOp = owner->getParentOp();
+    if (!parentOp)
+      return nullptr;
+    if (owner->isEntryBlock() && isa<FunctionOpInterface>(parentOp))
+      return nullptr;
+    return parentOp;
+  }
+
+  /// Returns true if at least one operand of `op` has a traceable definition.
+  static bool hasTraceableDefinition(Operation *op) {
+    return llvm::any_of(op->getOperands(), [](Value operand) {
+      return getDefinition(operand) != nullptr;
+    });
+  }
+
+  /// Records `root` as a root node of `bound`, then walks its use-def chains
+  /// depth-first for up to `hops` steps, adding an edge from each visited
+  /// operation to each of its traced definitions.
+  void traceDefinitions(Operation *root,
+                        matcher::BoundOperationsGraphBuilder &bound) const {
+    // addNode() leaves the flags of an existing node unchanged, and `root`
+    // may already have been recorded as the definition of an earlier match.
+    matcher::BoundOperationNode *rootNode =
+        bound.addNode(root, /*IsRootNode=*/true, /*DetailedPrinting=*/true);
+    rootNode->IsRootNode = true;
+    rootNode->DetailedPrinting = true;
+
+    // Each operand value is expanded at most once.
+    llvm::DenseSet<Value> visited;
+    llvm::SmallVector<std::pair<Operation *, unsigned>, 4> worklist;
+    worklist.push_back({root, hops});
+    while (!worklist.empty()) {
+      auto [currentOp, remainingHops] = worklist.pop_back_val();
+      if (remainingHops == 0)
+        continue;
+      for (Value operand : currentOp->getOperands()) {
+        Operation *definition = getDefinition(operand);
+        if (!definition)
+          continue;
+        bound.addEdge(currentOp, definition);
+        if (visited.insert(operand).second)
+          worklist.emplace_back(definition, remainingHops - 1);
+      }
+    }
+  }
+
+  matcher::DynMatcher innerMatcher;
+  unsigned hops;
+};
+
+} // namespace detail
+
+/// Matches operations accepted by `innerMatcher` that have at least one
+/// traceable definition, and records their direct (one-hop) definitions.
+inline detail::DefinitionsMatcher
+definedBy(mlir::query::matcher::DynMatcher innerMatcher) {
+  return detail::DefinitionsMatcher(std::move(innerMatcher), 1);
+}
+
+/// Like definedBy, but follows use-def chains for up to `hops` steps.
+inline detail::DefinitionsMatcher
+getDefinitions(mlir::query::matcher::DynMatcher innerMatcher, unsigned hops) {
+  assert(hops > 0 && "hops must be >= 1");
+  return detail::DefinitionsMatcher(std::move(innerMatcher), hops);
+}
+
+} // namespace mlir::query::extramatcher
+
+#endif // MLIR_QUERY_MATCHER_EXTRAMATCHERS_H
diff --git a/mlir/include/mlir/Query/Matcher/Marshallers.h b/mlir/include/mlir/Query/Matcher/Marshallers.h
index 6ed35ac0ddccc7..4a08b9af82c26c 100644
--- a/mlir/include/mlir/Query/Matcher/Marshallers.h
+++ b/mlir/include/mlir/Query/Matcher/Marshallers.h
@@ -50,6 +50,21 @@ struct ArgTypeTraits<llvm::StringRef> {
   }
 };
 
+template <>
+struct ArgTypeTraits<unsigned> {
+  // Accepts only values holding an unsigned integer (e.g. a hop count).
+  static bool hasCorrectType(const VariantValue &v) { return v.isUnsigned(); }
+
+  static unsigned get(const VariantValue &v) { return v.getUnsigned(); }
+
+  static ArgKind getKind() { return ArgKind::Unsigned; }
+
+  // No best-guess suggestion is offered for unsigned arguments.
+  static std::optional<std::string> getBestGuess(const VariantValue &) {
+    return std::nullopt;
+  }
+};
+
 template <>
 struct ArgTypeTraits<DynMatcher> {
 
@@ -166,7 +181,7 @@ matcherMarshallFixedImpl(void (*matcherFunc)(), llvm::StringRef matcherName,
     ReturnType fnPointer = reinterpret_cast<FuncType>(matcherFunc)(
         ArgTypeTraits<ArgTypes>::get(args[Is].value)...);
     return VariantMatcher::SingleMatcher(
-        *DynMatcher::constructDynMatcherFromMatcherFn(fnPointer));
+        *DynMatcher::constructDynMatcherFromMatcherFn(fnPointer, matcherName));
   }
 
   return VariantMatcher();
diff --git a/mlir/include/mlir/Query/Matcher/MatchFinder.h b/mlir/include/mlir/Query/Matcher/MatchFinder.h
index b008a21f53ae2a..4664e48b51b94a 100644
--- a/mlir/include/mlir/Query/Matcher/MatchFinder.h
+++ b/mlir/include/mlir/Query/Matcher/MatchFinder.h
@@ -15,6 +15,7 @@
 #define MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H
 
 #include "MatchersInternal.h"
+#include "mlir/IR/Operation.h"
 
 namespace mlir::query::matcher {
 
@@ -22,17 +23,20 @@ namespace mlir::query::matcher {
 class MatchFinder {
 public:
-  // Returns all operations that match the given matcher.
-  static std::vector<Operation *> getMatches(Operation *root,
-                                             DynMatcher matcher) {
-    std::vector<Operation *> matches;
+  // Walks `root` and returns a graph of the operations recorded for the given
+  // matcher. A plain matcher contributes one node per matched operation; a
+  // matcher with a graph-building match(op, bound) overload adds its own nodes
+  // and edges.
+  static BoundOperationsGraphBuilder getMatches(Operation *root,
+                                                DynMatcher matcher) {

-    // Simple match finding with walk.
+    BoundOperationsGraphBuilder bound;
     root->walk([&](Operation *subOp) {
       if (matcher.match(subOp))
-        matches.push_back(subOp);
+        bound.addNode(subOp);
+      else
+        matcher.match(subOp, bound);
     });
-
-    return matches;
+    return bound;
   }
 };

diff --git a/mlir/include/mlir/Query/Matcher/MatchersInternal.h b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
index 117f7d4edef9e3..cb4063dc284526 100644
--- a/mlir/include/mlir/Query/Matcher/MatchersInternal.h
+++ b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
@@ -1,19 +1,8 @@
 //===- MatchersInternal.h - Structural query framework ----------*- C++ -*-===//
 //
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// Implements the base layer of the matcher framework.
-//
-// Matchers are methods that return a Matcher which provides a method
-// match(Operation *op)
-//
-// The matcher functions are defined in include/mlir/IR/Matchers.h.
-// This file contains the wrapper classes needed to construct matchers for
-// mlir-query.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
 
@@ -22,16 +11,110 @@

 #include "mlir/IR/Matchers.h"
 #include "llvm/ADT/IntrusiveRefCntPtr.h"
+#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/STLExtras.h"
+#include <memory>
+#include <type_traits>
+#include <vector>

 namespace mlir::query::matcher {

+/// A node in the graph of operations recorded while matching: the operation
+/// itself plus its incoming (`Parents`) and outgoing (`Children`) edges.
+struct BoundOperationNode {
+  Operation *op;
+  std::vector<BoundOperationNode *> Parents;
+  std::vector<BoundOperationNode *> Children;
+
+  /// Marks an operation a graph-building matcher started from. In detailed
+  /// printing, each root node is reported as a separate match.
+  bool IsRootNode;
+  /// When set on any node, matches are printed as per-root graphs with edges
+  /// instead of as a flat list of operations.
+  bool DetailedPrinting;
+
+  explicit BoundOperationNode(Operation *op, bool IsRootNode = false,
+                              bool DetailedPrinting = false)
+      : op(op), IsRootNode(IsRootNode), DetailedPrinting(DetailedPrinting) {}
+};
+
+/// Owns the nodes of the match graph, keyed by operation and kept in
+/// insertion order, together with the edges between them.
+class BoundOperationsGraphBuilder {
+public:
+  /// Returns the node for `op`, creating it with the given flags if it does
+  /// not exist yet. The flags of an existing node are left unchanged.
+  BoundOperationNode *addNode(Operation *op, bool IsRootNode = false,
+                              bool DetailedPrinting = false) {
+    std::unique_ptr<BoundOperationNode> &slot = Nodes[op];
+    if (!slot)
+      slot = std::make_unique<BoundOperationNode>(op, IsRootNode,
+                                                  DetailedPrinting);
+    return slot.get();
+  }
+
+  /// Adds an edge from `parentOp` to `childOp`, creating their nodes if
+  /// needed. Each (parent, child) pair is recorded at most once, so repeated
+  /// uses of a value do not produce duplicate edges.
+  void addEdge(Operation *parentOp, Operation *childOp) {
+    BoundOperationNode *parentNode = addNode(parentOp);
+    BoundOperationNode *childNode = addNode(childOp);
+    if (llvm::is_contained(parentNode->Children, childNode))
+      return;
+    parentNode->Children.push_back(childNode);
+    childNode->Parents.push_back(parentNode);
+  }
+
+  /// Returns the node for `op`, or nullptr if `op` has not been added.
+  BoundOperationNode *getNode(Operation *op) const {
+    auto it = Nodes.find(op);
+    return it != Nodes.end() ? it->second.get() : nullptr;
+  }
+
+  /// Returns all nodes in the order they were first added.
+  const llvm::MapVector<Operation *, std::unique_ptr<BoundOperationNode>> &
+  getNodes() const {
+    return Nodes;
+  }
+
+private:
+  llvm::MapVector<Operation *, std::unique_ptr<BoundOperationNode>> Nodes;
+};
+
+// Type trait detecting whether a matcher provides `match(Operation *)`.
+template <typename T, typename = void>
+struct has_simple_match : std::false_type {};
+
+template <typename T>
+struct has_simple_match<T, std::void_t<decltype(std::declval<T>().match(
+                               std::declval<Operation *>()))>>
+    : std::true_type {};
+
+// Type trait detecting whether a matcher provides
+// `match(Operation *, BoundOperationsGraphBuilder &)`.
+template <typename T, typename = void>
+struct has_bound_match : std::false_type {};
+
+template <typename T>
+struct has_bound_match<T, std::void_t<decltype(std::declval<T>().match(
+                              std::declval<Operation *>(),
+                              std::declval<BoundOperationsGraphBuilder &>()))>>
+    : std::true_type {};
+
 // Generic interface for matchers on an MLIR operation.
 class MatcherInterface
     : public llvm::ThreadSafeRefCountedBase<MatcherInterface> {
 public:
   virtual ~MatcherInterface() = default;

   virtual bool match(Operation *op) = 0;
+
+  // Graph-building match: on success, implementations may record `op` and
+  // related operations in `bound`. Defaults to "no match" so implementations
+  // that only provide the plain overload above keep compiling.
+  virtual bool match(Operation *op, BoundOperationsGraphBuilder &bound) {
+    return false;
+  }
 };

 // MatcherFnImpl takes a matcher function object and implements
@@ -40,40 +123,65 @@ template <typename MatcherFn>
 class MatcherFnImpl : public MatcherInterface {
 public:
   MatcherFnImpl(MatcherFn &matcherFn) : matcherFn(matcherFn) {}
-  bool match(Operation *op) override { return matcherFn.match(op); }
+  static_assert(has_simple_match<MatcherFn>::value ||
+                    has_bound_match<MatcherFn>::value,
+                "matcher must provide match(Operation *) or "
+                "match(Operation *, BoundOperationsGraphBuilder &)");
+
+  // Forwards to matcherFn.match(op) when available; otherwise never matches.
+  bool match(Operation *op) override {
+    if constexpr (has_simple_match<MatcherFn>::value)
+      return matcherFn.match(op);
+    return false;
+  }
+
+  // Forwards to matcherFn.match(op, bound) when available; otherwise never
+  // matches.
+  bool match(Operation *op, BoundOperationsGraphBuilder &bound) override {
+    if constexpr (has_bound_match<MatcherFn>::value)
+      return matcherFn.match(op, bound);
+    return false;
+  }

 private:
   MatcherFn matcherFn;
 };

-// Matcher wraps a MatcherInterface implementation and provides a match()
-// method that redirects calls to the underlying implementation.
+// Matcher wraps a MatcherInterface implementation and provides match()
+// methods that redirect calls to the underlying implementation.
 class DynMatcher {
 public:
   // Takes ownership of the provided implementation pointer.
-  DynMatcher(MatcherInterface *implementation)
-      : implementation(implementation) {}
+  // `matcherName` defaults to empty so existing single-argument callers
+  // keep compiling.
+  DynMatcher(MatcherInterface *implementation, StringRef matcherName = "")
+      : implementation(implementation), matcherName(matcherName.str()) {}

   template <typename MatcherFn>
   static std::unique_ptr<DynMatcher>
-  constructDynMatcherFromMatcherFn(MatcherFn &matcherFn) {
+  constructDynMatcherFromMatcherFn(MatcherFn &matcherFn,
+                                   StringRef matcherName = "") {
     auto impl = std::make_unique<MatcherFnImpl<MatcherFn>>(matcherFn);
-    return std::make_unique<DynMatcher>(impl.release());
+    return std::make_unique<DynMatcher>(impl.release(), matcherName);
   }

   bool match(Operation *op) const { return implementation->match(op); }
+  bool match(Operation *op, BoundOperationsGraphBuilder &bound) const {
+    return implementation->match(op, bound);
+  }

-  void setFunctionName(StringRef name) { functionName = name.str(); };
-
-  bool hasFunctionName() const { return !functionName.empty(); };
-
-  StringRef getFunctionName() const { return functionName; };
+  void setFunctionName(StringRef name) { functionName = name.str(); }
+  void setMatcherName(StringRef name) { matcherName = name.str(); }
+  bool hasFunctionName() const { return !functionName.empty(); }
+  StringRef getFunctionName() const { return functionName; }
+  StringRef getMatcherName() const { return matcherName; }

 private:
   llvm::IntrusiveRefCntPtr<MatcherInterface> implementation;
+  std::string matcherName;
   std::string functionName;
 };

 } // namespace mlir::query::matcher

 #endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
diff --git a/mlir/include/mlir/Query/Matcher/VariantValue.h b/mlir/include/mlir/Query/Matcher/VariantValue.h
index 449f8b3a01e021..73d96a6913dfe4 100644
--- a/mlir/include/mlir/Query/Matcher/VariantValue.h
+++ b/mlir/include/mlir/Query/Matcher/VariantValue.h
@@ -21,7 +21,7 @@
 namespace mlir::query::matcher {
 
 // All types that VariantValue can contain.
-enum class ArgKind { Matcher, String };
+enum class ArgKind { Matcher, String, Unsigned }; // Unsigned: integer literals.
 
 // A variant matcher object to abstract simple and complex matchers into a
 // single object type.
@@ -81,6 +81,7 @@ class VariantValue {
   // Specific constructors for each supported type.
   VariantValue(const llvm::StringRef string);
   VariantValue(const VariantMatcher &matcher);
+  VariantValue(unsigned Unsigned);
 
   // String value functions.
   bool isString() const;
@@ -92,6 +93,11 @@ class VariantValue {
   const VariantMatcher &getMatcher() const;
   void setMatcher(const VariantMatcher &matcher);
 
+  // Unsigned value functions.
+  bool isUnsigned() const;
+  unsigned getUnsigned() const;
+  void setUnsigned(unsigned Unsigned);
+
   // String representation of the type of the value.
   std::string getTypeAsString() const;
 
@@ -103,12 +109,14 @@ class VariantValue {
     Nothing,
     String,
     Matcher,
+    Unsigned,
   };
 
   // All supported value types.
   union AllValues {
     llvm::StringRef *String;
     VariantMatcher *Matcher;
+    unsigned Unsigned;
   };
 
   ValueType type;
diff --git a/mlir/lib/Query/Matcher/Parser.cpp b/mlir/lib/Query/Matcher/Parser.cpp
index 3609e24f9939f7..4f1b716756e318 100644
--- a/mlir/lib/Query/Matcher/Parser.cpp
+++ b/mlir/lib/Query/Matcher/Parser.cpp
@@ -135,6 +135,18 @@ class Parser::CodeTokenizer {
     case '\'':
       consumeStringLiteral(&result);
       break;
+    case '0':
+    case '1':
+    case '2':
+    case '3':
+    case '4':
+    case '5':
+    case '6':
+    case '7':
+    case '8':
+    case '9':
+      consumeNumberLiteral(&result);
+      break;
     default:
       parseIdentifierOrInvalid(&result);
       break;
@@ -144,6 +156,33 @@ class Parser::CodeTokenizer {
     return result;
   }

+  // Consume an unsigned integer literal. The radix is auto-detected as by
+  // StringRef::getAsInteger with radix 0 (e.g. "0x" hex, "0b" binary). Any
+  // alphanumeric, '_' or '.' characters that follow the first digit are
+  // consumed too, so malformed and floating-point literals (which VariantValue
+  // cannot represent) are reported as a single invalid token.
+  void consumeNumberLiteral(TokenInfo *result) {
+    size_t length = 1;
+    while (length < code.size()) {
+      char c = code[length];
+      if (!isalnum(static_cast<unsigned char>(c)) && c != '_' && c != '.')
+        break;
+      ++length;
+    }
+
+    result->text = code.take_front(length);
+    code = code.drop_front(length);
+
+    unsigned value;
+    if (!result->text.getAsInteger(0, value)) {
+      result->kind = TokenKind::Literal;
+      result->value = value;
+      return;
+    }
+    // Out of range, invalid digits for the radix, or a floating-point literal.
+    result->kind = TokenKind::InvalidChar;
+  }
+
   // Consume a string literal, handle escape sequences and missing closing
   // quote.
   void consumeStringLiteral(TokenInfo *result) {
diff --git a/mlir/lib/Query/Matcher/RegistryManager.cpp b/mlir/lib/Query/Matcher/RegistryManager.cpp
index 645db7109c2deb..8d6c0135aa1176 100644
--- a/mlir/lib/Query/Matcher/RegistryManager.cpp
+++ b/mlir/lib/Query/Matcher/RegistryManager.cpp
@@ -31,6 +31,8 @@ static std::string asArgString(ArgKind kind) {
     return "Matcher";
   case ArgKind::String:
     return "String";
+  case ArgKind::Unsigned:
+    return "unsigned";
   }
   llvm_unreachable("Unhandled ArgKind");
 }
diff --git a/mlir/lib/Query/Matcher/VariantValue.cpp b/mlir/lib/Query/Matcher/VariantValue.cpp
index 65bd4bd77bcf8a..50d79512196d1a 100644
--- a/mlir/lib/Query/Matcher/VariantValue.cpp
+++ b/mlir/lib/Query/Matcher/VariantValue.cpp
@@ -56,6 +56,10 @@ VariantValue::VariantValue(const VariantMatcher &matcher)
   value.Matcher = new VariantMatcher(matcher);
 }
 
+VariantValue::VariantValue(unsigned newValue) : type(ValueType::Nothing) {
+  setUnsigned(newValue);
+}
+
 VariantValue::~VariantValue() { reset(); }
 
 VariantValue &VariantValue::operator=(const VariantValue &other) {
@@ -69,6 +73,9 @@ VariantValue &VariantValue::operator=(const VariantValue &other) {
   case ValueType::Matcher:
     setMatcher(other.getMatcher());
     break;
+  case ValueType::Unsigned:
+    setUnsigned(other.getUnsigned());
+    break;
   case ValueType::Nothing:
     type = ValueType::Nothing;
     break;
@@ -85,12 +92,26 @@ void VariantValue::reset() {
     delete value.Matcher;
     break;
   // Cases that do nothing.
+  case ValueType::Unsigned:
   case ValueType::Nothing:
     break;
   }
   type = ValueType::Nothing;
 }
 
+bool VariantValue::isUnsigned() const { return type == ValueType::Unsigned; }
+
+unsigned VariantValue::getUnsigned() const {
+  assert(isUnsigned());
+  return value.Unsigned;
+}
+
+void VariantValue::setUnsigned(unsigned newValue) {
+  reset(); // Releases the previously held value, if any.
+  value.Unsigned = newValue;
+  type = ValueType::Unsigned;
+}
+
 bool VariantValue::isString() const { return type == ValueType::String; }
 
 const llvm::StringRef &VariantValue::getString() const {
@@ -123,6 +144,8 @@ std::string VariantValue::getTypeAsString() const {
     return "String";
   case ValueType::Matcher:
     return "Matcher";
+  case ValueType::Unsigned:
+    return "Unsigned";
   case ValueType::Nothing:
     return "Nothing";
   }
diff --git a/mlir/lib/Query/Query.cpp b/mlir/lib/Query/Query.cpp
index 7d9f3606700519..70be7c36888d50 100644
--- a/mlir/lib/Query/Query.cpp
+++ b/mlir/lib/Query/Query.cpp
@@ -12,8 +12,11 @@
 #include "mlir/IR/IRMapping.h"
 #include "mlir/Query/Matcher/MatchFinder.h"
 #include "mlir/Query/QuerySession.h"
+#include "llvm/ADT/SetVector.h"
 #include "llvm/Support/SourceMgr.h"
 #include "llvm/Support/raw_ostream.h"
+#include <unordered_map>
+#include <unordered_set>
 
 namespace mlir::query {
 
@@ -124,30 +127,130 @@ LogicalResult QuitQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
   return mlir::success();
 }
 
+static void collectMatchNodes(
+    matcher::BoundOperationNode *Node,
+    llvm::SetVector<matcher::BoundOperationNode *> &MatchNodes) {
+  if (!MatchNodes.insert(Node)) // Already visited; also guards against cycles.
+    return;
+  for (matcher::BoundOperationNode *ChildNode : Node->Children)
+    collectMatchNodes(ChildNode, MatchNodes);
+}
+
+/// Prints the operations recorded in `Bound`.
+///
+/// When no node requests detailed printing, each recorded operation is printed
+/// as its own "root" match, in the output format mlir-query used before match
+/// graphs were introduced. Otherwise, for each root node, the root and every
+/// node reachable from it through `Children` edges are printed, sorted by
+/// source location and prefixed by an ID that is unique across the graph,
+/// followed by the edges between them.
+///
+/// \param os Stream the matches are printed to.
+/// \param qs Query session, forwarded to printMatch.
+/// \param Bound Graph produced by MatchFinder::getMatches.
+static void analyzeAndPrint(llvm::raw_ostream &os, QuerySession &qs,
+                            const matcher::BoundOperationsGraphBuilder &Bound) {
+  const auto &Nodes = Bound.getNodes();
+
+  bool AnyDetailedPrinting = llvm::any_of(Nodes, [](const auto &Pair) {
+    return Pair.second->DetailedPrinting;
+  });
+
+  unsigned MatchesCounter = 0;
+  if (!AnyDetailedPrinting) {
+    // Flat listing: one "root" binding per recorded operation. An empty graph
+    // prints "0 matches.", as before.
+    for (const auto &Pair : Nodes) {
+      os << "Match #" << ++MatchesCounter << ":\n\n";
+      printMatch(os, qs, Pair.first, "root");
+    }
+    os << MatchesCounter
+       << (MatchesCounter == 1 ? " match.\n\n" : " matches.\n\n");
+    return;
+  }
+
+  // Node IDs follow insertion order and are shared by all match sections, so
+  // a node reachable from several roots is printed with the same ID each time.
+  std::unordered_map<Operation *, unsigned> NodeIDs;
+  unsigned NextID = 0;
+  for (const auto &Pair : Nodes)
+    NodeIDs[Pair.first] = NextID++;
+
+  // Orders nodes by file name, line and column; nodes without a file location
+  // are placed last. Filenames are compared as StringRefs to avoid copies.
+  auto ByLocation = [](matcher::BoundOperationNode *A,
+                       matcher::BoundOperationNode *B) {
+    auto LocA = A->op->getLoc()->findInstanceOf<FileLineColLoc>();
+    auto LocB = B->op->getLoc()->findInstanceOf<FileLineColLoc>();
+    if (!LocA || !LocB)
+      return LocA && !LocB;
+    if (LocA.getFilename() != LocB.getFilename())
+      return LocA.getFilename().getValue() < LocB.getFilename().getValue();
+    if (LocA.getLine() != LocB.getLine())
+      return LocA.getLine() < LocB.getLine();
+    return LocA.getColumn() < LocB.getColumn();
+  };
+
+  // Detailed listing: one section per root node, in insertion order.
+  for (const auto &Pair : Nodes) {
+    matcher::BoundOperationNode *RootNode = Pair.second.get();
+    if (!RootNode->IsRootNode)
+      continue;
+
+    os << "\n";
+    os << "  Match #" << ++MatchesCounter << "\n";
+
+    // Gather the root and every node reachable from it, then order them by
+    // location; stable_sort keeps insertion order for equal locations.
+    llvm::SetVector<matcher::BoundOperationNode *> MatchNodes;
+    collectMatchNodes(RootNode, MatchNodes);
+    std::vector<matcher::BoundOperationNode *> SortedMatchNodes(
+        MatchNodes.begin(), MatchNodes.end());
+    std::stable_sort(SortedMatchNodes.begin(), SortedMatchNodes.end(),
+                     ByLocation);
+
+    for (matcher::BoundOperationNode *Node : SortedMatchNodes) {
+      os << NodeIDs[Node->op] << ": ";
+      printMatch(os, qs, Node->op, Node->IsRootNode ? "root" : "");
+    }
+
+    // Print each edge whose endpoints both belong to this match.
+    os << "Edges:\n";
+    for (matcher::BoundOperationNode *Node : MatchNodes) {
+      unsigned ParentID = NodeIDs[Node->op];
+      for (matcher::BoundOperationNode *ChildNode : Node->Children) {
+        if (MatchNodes.count(ChildNode) == 0)
+          continue;
+        unsigned ChildID = NodeIDs[ChildNode->op];
+        os << "  " << ParentID << " ---> " << ChildID << "\n";
+      }
+    }
+  }
+
+  // Summary for the detailed listing.
+  os << "\n" << MatchesCounter << " matches found!\n";
+}
+
 LogicalResult MatchQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
   Operation *rootOp = qs.getRootOp();
-  int matchCount = 0;
-  std::vector<Operation *> matches =
-      matcher::MatchFinder().getMatches(rootOp, matcher);
-
-  // An extract call is recognized by considering if the matcher has a name.
-  // TODO: Consider making the extract more explicit.
-  if (matcher.hasFunctionName()) {
-    auto functionName = matcher.getFunctionName();
-    Operation *function =
-        extractFunction(matches, rootOp->getContext(), functionName);
-    os << "\n" << *function << "\n\n";
-    function->erase();
-    return mlir::success();
-  }
+  matcher::BoundOperationsGraphBuilder matches =
+      matcher::MatchFinder().getMatches(rootOp, matcher);
+
+  // An extract call is recognized by considering if the matcher has a name.
+  // TODO: Consider making the extract more explicit.
+  if (matcher.hasFunctionName()) {
+    std::vector<Operation *> ops; // Every operation recorded in the graph.
+    for (const auto &entry : matches.getNodes())
+      ops.push_back(entry.first);
+    Operation *function = extractFunction(ops, rootOp->getContext(),
+                                          matcher.getFunctionName());
+    os << "\n" << *function << "\n\n";
+    function->erase();
+    return mlir::success();
+  }

   os << "\n";
-  for (Operation *op : matches) {
-    os << "Match #" << ++matchCount << ":\n\n";
-    // Placeholder "root" binding for the initial draft.
-    printMatch(os, qs, op, "root");
-  }
-  os << matchCount << (matchCount == 1 ? " match.\n\n" : " matches.\n\n");
+  analyzeAndPrint(os, qs, matches);

   return mlir::success();
 }
diff --git a/mlir/tools/mlir-query/mlir-query.cpp b/mlir/tools/mlir-query/mlir-query.cpp
index 0ed4f94d5802b0..d5c0b1632d3c5d 100644
--- a/mlir/tools/mlir-query/mlir-query.cpp
+++ b/mlir/tools/mlir-query/mlir-query.cpp
@@ -10,11 +10,12 @@
 // of the registered queries.
 //
 //===----------------------------------------------------------------------===//
-
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/InitAllDialects.h"
+#include "mlir/Query/Matcher/ExtraMatchers.h"
+#include "mlir/Query/Matcher/MatchFinder.h"
 #include "mlir/Query/Matcher/Registry.h"
 #include "mlir/Tools/mlir-query/MlirQueryMain.h"
 
@@ -39,6 +40,10 @@ int main(int argc, char **argv) {
   query::matcher::Registry matcherRegistry;
 
   // Matchers registered in alphabetical order for consistency:
+  matcherRegistry.registerMatcher("getDefinitions",
+                                  mlir::query::extramatcher::getDefinitions);
+  matcherRegistry.registerMatcher("definedBy",
+                                  mlir::query::extramatcher::definedBy);
   matcherRegistry.registerMatcher("hasOpAttrName",
                                   static_cast<HasOpAttrName *>(m_Attr));
   matcherRegistry.registerMatcher("hasOpName", static_cast<HasOpName *>(m_Op));
    
    
More information about the Mlir-commits
mailing list