[Mlir-commits] [mlir] [mlir] MLIR-QUERY DefinitionsMatcher implementation & DAG (PR #115670)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Jan 19 05:09:58 PST 2025


https://github.com/dbudii updated https://github.com/llvm/llvm-project/pull/115670

>From cb380e7d65ab7cdd83285a05c6eb5c6cc3a50cce Mon Sep 17 00:00:00 2001
From: Denzel-Brian Budii <denzel-brian.budii at intel.com>
Date: Wed, 23 Oct 2024 17:42:19 +0000
Subject: [PATCH 1/5] MLIR-QUERY DefinitionsMatcher implementation & DAG
 - include printing logic for the DAG; use SFINAE to dispatch match methods

---
 .../mlir/Query/Matcher/ExtraMatchers.h        | 109 ++++++++++++++
 mlir/include/mlir/Query/Matcher/Marshallers.h |  17 ++-
 mlir/include/mlir/Query/Matcher/MatchFinder.h |  18 ++-
 .../mlir/Query/Matcher/MatchersInternal.h     | 134 +++++++++++++----
 .../include/mlir/Query/Matcher/VariantValue.h |  10 +-
 mlir/lib/Query/Matcher/Parser.cpp             |  59 ++++++++
 mlir/lib/Query/Matcher/RegistryManager.cpp    |   2 +
 mlir/lib/Query/Matcher/VariantValue.cpp       |  23 +++
 mlir/lib/Query/Query.cpp                      | 141 +++++++++++++++---
 mlir/tools/mlir-query/mlir-query.cpp          |   7 +-
 10 files changed, 463 insertions(+), 57 deletions(-)
 create mode 100644 mlir/include/mlir/Query/Matcher/ExtraMatchers.h

diff --git a/mlir/include/mlir/Query/Matcher/ExtraMatchers.h b/mlir/include/mlir/Query/Matcher/ExtraMatchers.h
new file mode 100644
index 000000000000000..1764ad35cc9c303
--- /dev/null
+++ b/mlir/include/mlir/Query/Matcher/ExtraMatchers.h
@@ -0,0 +1,109 @@
+//===- Matchers.h - Various common matchers ---------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file provides extra matchers that are very useful for mlir-query
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_EXTRAMATCHERS_H
+#define MLIR_IR_EXTRAMATCHERS_H
+
+#include "MatchFinder.h"
+#include "MatchersInternal.h"
+
+namespace mlir {
+
+namespace query {
+
+namespace extramatcher {
+
+namespace detail {
+
+class DefinitionsMatcher {
+public:
+  DefinitionsMatcher(matcher::DynMatcher &&InnerMatcher, unsigned Hops)
+      : InnerMatcher(std::move(InnerMatcher)), Hops(Hops) {}
+
+private:
+  bool matches(Operation *op, matcher::BoundOperationsGraphBuilder &Bound,
+               unsigned TempHops) {
+
+    llvm::DenseSet<mlir::Value> Ccache;
+    llvm::SmallVector<std::pair<Operation *, size_t>, 4> TempStorage;
+    TempStorage.push_back({op, TempHops});
+    while (!TempStorage.empty()) {
+      auto [CurrentOp, RemainingHops] = TempStorage.pop_back_val();
+
+      matcher::BoundOperationNode *CurrentNode =
+          Bound.addNode(CurrentOp, true, true);
+      if (RemainingHops == 0) {
+        continue;
+      }
+
+      for (auto Operand : CurrentOp->getOperands()) {
+        if (auto DefiningOp = Operand.getDefiningOp()) {
+          Bound.addEdge(CurrentOp, DefiningOp);
+          if (!Ccache.contains(Operand)) {
+            Ccache.insert(Operand);
+            TempStorage.emplace_back(DefiningOp, RemainingHops - 1);
+          }
+        } else if (auto BlockArg = Operand.dyn_cast<BlockArgument>()) {
+          auto *Block = BlockArg.getOwner();
+
+          if (Block->isEntryBlock() &&
+              isa<FunctionOpInterface>(Block->getParentOp())) {
+            continue;
+          }
+
+          Operation *ParentOp = BlockArg.getOwner()->getParentOp();
+          if (ParentOp) {
+            Bound.addEdge(CurrentOp, ParentOp);
+            if (!!Ccache.contains(BlockArg)) {
+              Ccache.insert(BlockArg);
+              TempStorage.emplace_back(ParentOp, RemainingHops - 1);
+            }
+          }
+        }
+      }
+    }
+    // We need at least 1 defining op
+    return Ccache.size() >= 2;
+  }
+
+public:
+  bool match(Operation *op, matcher::BoundOperationsGraphBuilder &Bound) {
+    if (InnerMatcher.match(op) && matches(op, Bound, Hops)) {
+      return true;
+    }
+    return false;
+  }
+
+private:
+  matcher::DynMatcher InnerMatcher;
+  unsigned Hops;
+};
+} // namespace detail
+
+inline detail::DefinitionsMatcher
+definedBy(mlir::query::matcher::DynMatcher InnerMatcher) {
+  return detail::DefinitionsMatcher(std::move(InnerMatcher), 1);
+}
+
+inline detail::DefinitionsMatcher
+getDefinitions(mlir::query::matcher::DynMatcher InnerMatcher, unsigned Hops) {
+  assert(Hops > 0 && "hops must be >= 1");
+  return detail::DefinitionsMatcher(std::move(InnerMatcher), Hops);
+}
+
+} // namespace extramatcher
+
+} // namespace query
+
+} // namespace mlir
+
+#endif // MLIR_IR_EXTRAMATCHERS_H
diff --git a/mlir/include/mlir/Query/Matcher/Marshallers.h b/mlir/include/mlir/Query/Matcher/Marshallers.h
index 6ed35ac0ddccc70..4a08b9af82c26cf 100644
--- a/mlir/include/mlir/Query/Matcher/Marshallers.h
+++ b/mlir/include/mlir/Query/Matcher/Marshallers.h
@@ -50,6 +50,21 @@ struct ArgTypeTraits<llvm::StringRef> {
   }
 };
 
+template <>
+struct ArgTypeTraits<unsigned> {
+  static bool hasCorrectType(const VariantValue &value) {
+    return value.isUnsigned();
+  }
+
+  static unsigned get(const VariantValue &value) { return value.getUnsigned(); }
+
+  static ArgKind getKind() { return ArgKind::Unsigned; }
+
+  static std::optional<std::string> getBestGuess(const VariantValue &) {
+    return std::nullopt;
+  }
+};
+
 template <>
 struct ArgTypeTraits<DynMatcher> {
 
@@ -166,7 +181,7 @@ matcherMarshallFixedImpl(void (*matcherFunc)(), llvm::StringRef matcherName,
     ReturnType fnPointer = reinterpret_cast<FuncType>(matcherFunc)(
         ArgTypeTraits<ArgTypes>::get(args[Is].value)...);
     return VariantMatcher::SingleMatcher(
-        *DynMatcher::constructDynMatcherFromMatcherFn(fnPointer));
+        *DynMatcher::constructDynMatcherFromMatcherFn(fnPointer, matcherName));
   }
 
   return VariantMatcher();
diff --git a/mlir/include/mlir/Query/Matcher/MatchFinder.h b/mlir/include/mlir/Query/Matcher/MatchFinder.h
index b008a21f53ae2a6..4664e48b51b94a2 100644
--- a/mlir/include/mlir/Query/Matcher/MatchFinder.h
+++ b/mlir/include/mlir/Query/Matcher/MatchFinder.h
@@ -15,6 +15,7 @@
 #define MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H
 
 #include "MatchersInternal.h"
+#include "mlir/IR/Operation.h"
 
 namespace mlir::query::matcher {
 
@@ -22,17 +23,18 @@ namespace mlir::query::matcher {
 class MatchFinder {
 public:
   // Returns all operations that match the given matcher.
-  static std::vector<Operation *> getMatches(Operation *root,
-                                             DynMatcher matcher) {
-    std::vector<Operation *> matches;
+  static BoundOperationsGraphBuilder getMatches(Operation *root,
+                                                DynMatcher matcher) {
 
-    // Simple match finding with walk.
+    BoundOperationsGraphBuilder Bound;
     root->walk([&](Operation *subOp) {
-      if (matcher.match(subOp))
-        matches.push_back(subOp);
+      if (matcher.match(subOp)) {
+        matcher::BoundOperationNode *currentNode = Bound.addNode(subOp);
+      } else if (matcher.match(subOp, Bound)) {
+        ////
+      }
     });
-
-    return matches;
+    return Bound;
   }
 };
 
diff --git a/mlir/include/mlir/Query/Matcher/MatchersInternal.h b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
index 117f7d4edef9e31..cb4063dc2845260 100644
--- a/mlir/include/mlir/Query/Matcher/MatchersInternal.h
+++ b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
@@ -1,19 +1,8 @@
 //===- MatchersInternal.h - Structural query framework ----------*- C++ -*-===//
 //
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// Part of the LLVM Project, under the Apache License v2.0 wIth LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// Implements the base layer of the matcher framework.
-//
-// Matchers are methods that return a Matcher which provides a method
-// match(Operation *op)
-//
-// The matcher functions are defined in include/mlir/IR/Matchers.h.
-// This file contains the wrapper classes needed to construct matchers for
-// mlir-query.
+// SPDX-License-Identifier: Apache-2.0 WItH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
 
@@ -22,16 +11,91 @@
 
 #include "mlir/IR/Matchers.h"
 #include "llvm/ADT/IntrusiveRefCntPtr.h"
+#include "llvm/ADT/MapVector.h"
+#include <memory>
+#include <stack>
+#include <unordered_set>
+#include <vector>
 
 namespace mlir::query::matcher {
 
+struct BoundOperationNode {
+  Operation *op;
+  std::vector<BoundOperationNode *> Parents;
+  std::vector<BoundOperationNode *> Children;
+
+  bool IsRootNode;
+  bool DetailedPrinting;
+
+  BoundOperationNode(Operation *op, bool IsRootNode = false,
+                     bool DetailedPrinting = false)
+      : op(op), IsRootNode(IsRootNode), DetailedPrinting(DetailedPrinting) {}
+};
+
+class BoundOperationsGraphBuilder {
+public:
+  BoundOperationNode *addNode(Operation *op, bool IsRootNode = false,
+                              bool DetailedPrinting = false) {
+    auto It = Nodes.find(op);
+    if (It != Nodes.end()) {
+      return It->second.get();
+    }
+    auto Node =
+        std::make_unique<BoundOperationNode>(op, IsRootNode, DetailedPrinting);
+    BoundOperationNode *NodePtr = Node.get();
+    Nodes[op] = std::move(Node);
+    return NodePtr;
+  }
+
+  void addEdge(Operation *parentOp, Operation *childOp) {
+    BoundOperationNode *ParentNode = addNode(parentOp, false, false);
+    BoundOperationNode *ChildNode = addNode(childOp, false, false);
+
+    ParentNode->Children.push_back(ChildNode);
+    ChildNode->Parents.push_back(ParentNode);
+  }
+
+  BoundOperationNode *getNode(Operation *op) const {
+    auto It = Nodes.find(op);
+    return It != Nodes.end() ? It->second.get() : nullptr;
+  }
+
+  const llvm::MapVector<Operation *, std::unique_ptr<BoundOperationNode>> &
+  getNodes() const {
+    return Nodes;
+  }
+
+private:
+  llvm::MapVector<Operation *, std::unique_ptr<BoundOperationNode>> Nodes;
+};
+
+// Type traIt to detect if a matcher has a match(Operation*) method
+template <typename T, typename = void>
+struct has_simple_match : std::false_type {};
+
+template <typename T>
+struct has_simple_match<T, std::void_t<decltype(std::declval<T>().match(
+                               std::declval<Operation *>()))>>
+    : std::true_type {};
+
+// Type traIt to detect if a matcher has a match(Operation*,
+// BoundOperationsGraphBuilder&) method
+template <typename T, typename = void>
+struct has_bound_match : std::false_type {};
+
+template <typename T>
+struct has_bound_match<T, std::void_t<decltype(std::declval<T>().match(
+                              std::declval<Operation *>(),
+                              std::declval<BoundOperationsGraphBuilder &>()))>>
+    : std::true_type {};
+
 // Generic interface for matchers on an MLIR operation.
 class MatcherInterface
     : public llvm::ThreadSafeRefCountedBase<MatcherInterface> {
 public:
   virtual ~MatcherInterface() = default;
-
   virtual bool match(Operation *op) = 0;
+  virtual bool match(Operation *op, BoundOperationsGraphBuilder &bound) = 0;
 };
 
 // MatcherFnImpl takes a matcher function object and implements
@@ -40,40 +104,56 @@ template <typename MatcherFn>
 class MatcherFnImpl : public MatcherInterface {
 public:
   MatcherFnImpl(MatcherFn &matcherFn) : matcherFn(matcherFn) {}
-  bool match(Operation *op) override { return matcherFn.match(op); }
+
+  bool match(Operation *op) override {
+    if constexpr (has_simple_match<MatcherFn>::value)
+      return matcherFn.match(op);
+    return false;
+  }
+
+  bool match(Operation *op, BoundOperationsGraphBuilder &bound) override {
+    if constexpr (has_bound_match<MatcherFn>::value)
+      return matcherFn.match(op, bound);
+    return false;
+  }
 
 private:
   MatcherFn matcherFn;
 };
 
-// Matcher wraps a MatcherInterface implementation and provides a match()
-// method that redirects calls to the underlying implementation.
+// Matcher wraps a MatcherInterface implementation and provides match()
+// methods that redirect calls to the underlying implementation.
 class DynMatcher {
 public:
   // Takes ownership of the provided implementation pointer.
-  DynMatcher(MatcherInterface *implementation)
-      : implementation(implementation) {}
+  DynMatcher(MatcherInterface *implementation, StringRef matcherName)
+      : implementation(implementation), matcherName(matcherName.str()) {}
 
   template <typename MatcherFn>
   static std::unique_ptr<DynMatcher>
-  constructDynMatcherFromMatcherFn(MatcherFn &matcherFn) {
+  constructDynMatcherFromMatcherFn(MatcherFn &matcherFn,
+                                   StringRef matcherName) {
     auto impl = std::make_unique<MatcherFnImpl<MatcherFn>>(matcherFn);
-    return std::make_unique<DynMatcher>(impl.release());
+    return std::make_unique<DynMatcher>(impl.release(), matcherName);
   }
 
   bool match(Operation *op) const { return implementation->match(op); }
+  bool match(Operation *op, BoundOperationsGraphBuilder &bound) const {
+    return implementation->match(op, bound);
+  }
 
-  void setFunctionName(StringRef name) { functionName = name.str(); };
-
-  bool hasFunctionName() const { return !functionName.empty(); };
-
-  StringRef getFunctionName() const { return functionName; };
+  void setFunctionName(StringRef name) { functionName = name.str(); }
+  void setMatcherName(StringRef name) { matcherName = name.str(); }
+  bool hasFunctionName() const { return !functionName.empty(); }
+  StringRef getFunctionName() const { return functionName; }
+  StringRef getMatcherName() const { return matcherName; }
 
 private:
   llvm::IntrusiveRefCntPtr<MatcherInterface> implementation;
+  std::string matcherName;
   std::string functionName;
 };
 
 } // namespace mlir::query::matcher
 
-#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
+#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
\ No newline at end of file
diff --git a/mlir/include/mlir/Query/Matcher/VariantValue.h b/mlir/include/mlir/Query/Matcher/VariantValue.h
index 449f8b3a01e0217..73d96a6913dfe4a 100644
--- a/mlir/include/mlir/Query/Matcher/VariantValue.h
+++ b/mlir/include/mlir/Query/Matcher/VariantValue.h
@@ -21,7 +21,7 @@
 namespace mlir::query::matcher {
 
 // All types that VariantValue can contain.
-enum class ArgKind { Matcher, String };
+enum class ArgKind { Matcher, String, Unsigned };
 
 // A variant matcher object to abstract simple and complex matchers into a
 // single object type.
@@ -81,6 +81,7 @@ class VariantValue {
   // Specific constructors for each supported type.
   VariantValue(const llvm::StringRef string);
   VariantValue(const VariantMatcher &matcher);
+  VariantValue(unsigned Unsigned);
 
   // String value functions.
   bool isString() const;
@@ -92,6 +93,11 @@ class VariantValue {
   const VariantMatcher &getMatcher() const;
   void setMatcher(const VariantMatcher &matcher);
 
+  // Unsigned value functions.
+  bool isUnsigned() const;
+  unsigned getUnsigned() const;
+  void setUnsigned(unsigned Unsigned);
+
   // String representation of the type of the value.
   std::string getTypeAsString() const;
 
@@ -103,12 +109,14 @@ class VariantValue {
     Nothing,
     String,
     Matcher,
+    Unsigned,
   };
 
   // All supported value types.
   union AllValues {
     llvm::StringRef *String;
     VariantMatcher *Matcher;
+    unsigned Unsigned;
   };
 
   ValueType type;
diff --git a/mlir/lib/Query/Matcher/Parser.cpp b/mlir/lib/Query/Matcher/Parser.cpp
index 3609e24f9939f7c..4f1b716756e318e 100644
--- a/mlir/lib/Query/Matcher/Parser.cpp
+++ b/mlir/lib/Query/Matcher/Parser.cpp
@@ -135,6 +135,18 @@ class Parser::CodeTokenizer {
     case '\'':
       consumeStringLiteral(&result);
       break;
+    case '0':
+    case '1':
+    case '2':
+    case '3':
+    case '4':
+    case '5':
+    case '6':
+    case '7':
+    case '8':
+    case '9':
+      consumeNumberLiteral(&result);
+      break;
     default:
       parseIdentifierOrInvalid(&result);
       break;
@@ -144,6 +156,53 @@ class Parser::CodeTokenizer {
     return result;
   }
 
+  void consumeNumberLiteral(TokenInfo *result) {
+    bool isFloatingLiteral = false;
+    unsigned length = 1;
+    if (code.size() > 1) {
+      // Consume the 'x' or 'b' radix modifier, if present.
+      switch (tolower(code[1])) {
+      case 'x':
+      case 'b':
+        length = 2;
+      }
+    }
+    while (length < code.size() && isdigit(code[length]))
+      ++length;
+
+    // Try to recognize a floating point literal.
+    while (length < code.size()) {
+      char c = code[length];
+      if (c == '-' || c == '+' || c == '.' || isdigit(c)) {
+        isFloatingLiteral = true;
+        length++;
+      } else {
+        break;
+      }
+    }
+
+    result->text = code.take_front(length);
+    code = code.drop_front(length);
+
+    if (isFloatingLiteral) {
+      char *end;
+      errno = 0;
+      std::string text = result->text.str();
+      double doubleValue = strtod(text.c_str(), &end);
+      if (*end == 0 && errno == 0) {
+        result->kind = TokenKind::Literal;
+        result->value = doubleValue;
+        return;
+      }
+    } else {
+      unsigned value;
+      if (!result->text.getAsInteger(0, value)) {
+        result->kind = TokenKind::Literal;
+        result->value = value;
+        return;
+      }
+    }
+  }
   // Consume a string literal, handle escape sequences and missing closing
   // quote.
   void consumeStringLiteral(TokenInfo *result) {
diff --git a/mlir/lib/Query/Matcher/RegistryManager.cpp b/mlir/lib/Query/Matcher/RegistryManager.cpp
index 645db7109c2deb3..8d6c0135aa11768 100644
--- a/mlir/lib/Query/Matcher/RegistryManager.cpp
+++ b/mlir/lib/Query/Matcher/RegistryManager.cpp
@@ -31,6 +31,8 @@ static std::string asArgString(ArgKind kind) {
     return "Matcher";
   case ArgKind::String:
     return "String";
+  case ArgKind::Unsigned:
+    return "unsigned";
   }
   llvm_unreachable("Unhandled ArgKind");
 }
diff --git a/mlir/lib/Query/Matcher/VariantValue.cpp b/mlir/lib/Query/Matcher/VariantValue.cpp
index 65bd4bd77bcf8af..50d79512196d1a5 100644
--- a/mlir/lib/Query/Matcher/VariantValue.cpp
+++ b/mlir/lib/Query/Matcher/VariantValue.cpp
@@ -56,6 +56,10 @@ VariantValue::VariantValue(const VariantMatcher &matcher)
   value.Matcher = new VariantMatcher(matcher);
 }
 
+VariantValue::VariantValue(unsigned Unsigned) : type(ValueType::Unsigned) {
+  value.Unsigned = Unsigned;
+}
+
 VariantValue::~VariantValue() { reset(); }
 
 VariantValue &VariantValue::operator=(const VariantValue &other) {
@@ -69,6 +73,9 @@ VariantValue &VariantValue::operator=(const VariantValue &other) {
   case ValueType::Matcher:
     setMatcher(other.getMatcher());
     break;
+  case ValueType::Unsigned:
+    setUnsigned(other.getUnsigned());
+    break;
   case ValueType::Nothing:
     type = ValueType::Nothing;
     break;
@@ -85,12 +92,26 @@ void VariantValue::reset() {
     delete value.Matcher;
     break;
   // Cases that do nothing.
+  case ValueType::Unsigned:
   case ValueType::Nothing:
     break;
   }
   type = ValueType::Nothing;
 }
 
+bool VariantValue::isUnsigned() const { return type == ValueType::Unsigned; }
+
+unsigned VariantValue::getUnsigned() const {
+  assert(isUnsigned());
+  return value.Unsigned;
+}
+
+void VariantValue::setUnsigned(unsigned newValue) {
+  reset();
+  type = ValueType::Unsigned;
+  value.Unsigned = newValue;
+}
+
 bool VariantValue::isString() const { return type == ValueType::String; }
 
 const llvm::StringRef &VariantValue::getString() const {
@@ -123,6 +144,8 @@ std::string VariantValue::getTypeAsString() const {
     return "String";
   case ValueType::Matcher:
     return "Matcher";
+  case ValueType::Unsigned:
+    return "Unsigned";
   case ValueType::Nothing:
     return "Nothing";
   }
diff --git a/mlir/lib/Query/Query.cpp b/mlir/lib/Query/Query.cpp
index 7d9f36067005198..70be7c36888d50b 100644
--- a/mlir/lib/Query/Query.cpp
+++ b/mlir/lib/Query/Query.cpp
@@ -12,8 +12,11 @@
 #include "mlir/IR/IRMapping.h"
 #include "mlir/Query/Matcher/MatchFinder.h"
 #include "mlir/Query/QuerySession.h"
+#include "llvm/ADT/SetVector.h"
 #include "llvm/Support/SourceMgr.h"
 #include "llvm/Support/raw_ostream.h"
+#include <unordered_map>
+#include <unordered_set>
 
 namespace mlir::query {
 
@@ -124,30 +127,130 @@ LogicalResult QuitQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
   return mlir::success();
 }
 
+void collectMatchNodes(
+    matcher::BoundOperationNode *Node,
+    llvm::SetVector<matcher::BoundOperationNode *> &MatchNodes) {
+  MatchNodes.insert(Node);
+  for (auto ChildNode : Node->Children) {
+    collectMatchNodes(ChildNode, MatchNodes);
+  }
+}
+
+void analyzeAndPrint(llvm::raw_ostream &os, QuerySession &qs,
+                     const matcher::BoundOperationsGraphBuilder &Bound) {
+
+  const auto &Nodes = Bound.getNodes();
+  if (Nodes.empty()) {
+    os << "The graph is empty.\n";
+    return;
+  }
+
+  bool AnyDetailedPrinting = false;
+  for (const auto &Pair : Nodes) {
+    if (Pair.second->DetailedPrinting) {
+      AnyDetailedPrinting = true;
+      break;
+    }
+  }
+
+  unsigned MatchesCounter = 0;
+  if (!AnyDetailedPrinting) {
+    os << "Operations:\n";
+    for (const auto &Pair : Nodes) {
+      os << "\n";
+      os << "  Match #" << ++MatchesCounter << "\n";
+      printMatch(os, qs, Pair.first, "root");
+    }
+    os << MatchesCounter << " matches found!\n";
+    return;
+  }
+
+  // Maps ids to nodes
+  std::unordered_map<Operation *, int> NodeIDs;
+  int id = 0;
+  for (const auto &Pair : Nodes) {
+    NodeIDs[Pair.first] = id++;
+  }
+
+  // Finds root nodes
+  std::vector<matcher::BoundOperationNode *> RootNodes;
+  for (const auto &Pair : Nodes) {
+    matcher::BoundOperationNode *Node = Pair.second.get();
+    if (Node->IsRootNode) {
+      RootNodes.push_back(Node);
+    }
+  }
+
+  for (auto RootNode : RootNodes) {
+    os << "\n";
+    os << "  Match #" << ++MatchesCounter << "\n";
+
+    llvm::SetVector<matcher::BoundOperationNode *> MatchNodes;
+    collectMatchNodes(RootNode, MatchNodes);
+    std::vector<matcher::BoundOperationNode *> SortedMatchNodes(
+        MatchNodes.begin(), MatchNodes.end());
+
+    // Sorts based on file location
+    std::sort(
+        SortedMatchNodes.begin(), SortedMatchNodes.end(),
+        [&](matcher::BoundOperationNode *a, matcher::BoundOperationNode *b) {
+          auto fileLocA = a->op->getLoc()->findInstanceOf<FileLineColLoc>();
+          auto fileLocB = b->op->getLoc()->findInstanceOf<FileLineColLoc>();
+
+          if (!fileLocA && !fileLocB)
+            return false;
+          if (!fileLocA)
+            return false;
+          if (!fileLocB)
+            return true;
+
+          if (fileLocA.getFilename().str() != fileLocB.getFilename().str())
+            return fileLocA.getFilename().str() < fileLocB.getFilename().str();
+          return fileLocA.getLine() < fileLocB.getLine();
+        });
+
+    for (auto Node : SortedMatchNodes) {
+      unsigned NodeID = NodeIDs[Node->op];
+      std::string binding = Node->IsRootNode ? "root" : "";
+      os << NodeID << ": ";
+      printMatch(os, qs, Node->op, binding);
+    }
+
+    // Prints edges
+    os << "Edges:\n";
+    for (auto Node : MatchNodes) {
+      int ParentID = NodeIDs[Node->op];
+      for (auto ChildNode : Node->Children) {
+        if (MatchNodes.count(ChildNode) > 0) {
+          int ChildID = NodeIDs[ChildNode->op];
+          os << "  " << ParentID << " ---> " << ChildID << "\n";
+        }
+      }
+    }
+  }
+  os << "\n" << MatchesCounter << " matches found!\n";
+}
+
 LogicalResult MatchQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
   Operation *rootOp = qs.getRootOp();
   int matchCount = 0;
-  std::vector<Operation *> matches =
-      matcher::MatchFinder().getMatches(rootOp, matcher);
-
-  // An extract call is recognized by considering if the matcher has a name.
-  // TODO: Consider making the extract more explicit.
-  if (matcher.hasFunctionName()) {
-    auto functionName = matcher.getFunctionName();
-    Operation *function =
-        extractFunction(matches, rootOp->getContext(), functionName);
-    os << "\n" << *function << "\n\n";
-    function->erase();
-    return mlir::success();
-  }
+  auto matches = matcher::MatchFinder().getMatches(rootOp, matcher);
+
+  // An extract call is recognized by considering if the matcher has a
+  //     name.TODO : Consider making the extract
+  //                     more explicit.
+  // if (matcher.hasFunctionName()) {
+  //   auto functionName = matcher.getFunctionName();
+  //   Operation *function = extractFunction(matches.getOperations(),
+  //                                         rootOp->getContext(),
+  //                                         functionName);
+  //   os << "\n" << *function << "\n\n";
+  //   function->erase();
+  //   return mlir::success();
+  // }
 
   os << "\n";
-  for (Operation *op : matches) {
-    os << "Match #" << ++matchCount << ":\n\n";
-    // Placeholder "root" binding for the initial draft.
-    printMatch(os, qs, op, "root");
-  }
-  os << matchCount << (matchCount == 1 ? " match.\n\n" : " matches.\n\n");
+  analyzeAndPrint(os, qs, matches);
 
   return mlir::success();
 }
diff --git a/mlir/tools/mlir-query/mlir-query.cpp b/mlir/tools/mlir-query/mlir-query.cpp
index 0ed4f94d5802b09..d5c0b1632d3c5d4 100644
--- a/mlir/tools/mlir-query/mlir-query.cpp
+++ b/mlir/tools/mlir-query/mlir-query.cpp
@@ -10,11 +10,12 @@
 // of the registered queries.
 //
 //===----------------------------------------------------------------------===//
-
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/InitAllDialects.h"
+#include "mlir/Query/Matcher/ExtraMatchers.h"
+#include "mlir/Query/Matcher/MatchFinder.h"
 #include "mlir/Query/Matcher/Registry.h"
 #include "mlir/Tools/mlir-query/MlirQueryMain.h"
 
@@ -39,6 +40,10 @@ int main(int argc, char **argv) {
   query::matcher::Registry matcherRegistry;
 
   // Matchers registered in alphabetical order for consistency:
+  matcherRegistry.registerMatcher("getDefinitions",
+                                  mlir::query::extramatcher::getDefinitions);
+  matcherRegistry.registerMatcher("definedBy",
+                                  mlir::query::extramatcher::definedBy);
   matcherRegistry.registerMatcher("hasOpAttrName",
                                   static_cast<HasOpAttrName *>(m_Attr));
   matcherRegistry.registerMatcher("hasOpName", static_cast<HasOpName *>(m_Op));

>From 23d38a70b3905cb595d0c22d058712588bc9762e Mon Sep 17 00:00:00 2001
From: Denzel-Brian Budii <denzel-brian.budii at intel.com>
Date: Wed, 8 Jan 2025 17:58:36 +0000
Subject: [PATCH 2/5] Added SetQuery, LetQuery, new implementation for matchers

---
 mlir/include/mlir/IR/Matchers.h               |   4 +-
 .../mlir/Query/Matcher/ExtraMatchers.h        | 186 +++++++++++++-----
 mlir/include/mlir/Query/Matcher/MatchFinder.h |  14 +-
 .../mlir/Query/Matcher/MatchersInternal.h     |  76 ++-----
 .../include/mlir/Query/Matcher/VariantValue.h |   2 +
 mlir/include/mlir/Query/Query.h               |  50 ++++-
 mlir/include/mlir/Query/QuerySession.h        |   5 +
 mlir/lib/Query/Matcher/Parser.cpp             |  20 +-
 mlir/lib/Query/Matcher/VariantValue.cpp       |   1 +
 mlir/lib/Query/Query.cpp                      | 144 +++-----------
 mlir/lib/Query/QueryParser.cpp                |  77 +++++++-
 mlir/lib/Query/QueryParser.h                  |   2 +-
 mlir/test/mlir-query/complex-test.mlir        |  23 +++
 mlir/test/mlir-query/function-extraction.mlir |   2 +-
 mlir/tools/mlir-query/mlir-query.cpp          |   4 +
 15 files changed, 367 insertions(+), 243 deletions(-)
 create mode 100644 mlir/test/mlir-query/complex-test.mlir

diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index 1dce055db1b4a71..2204a68be26b104 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -59,7 +59,7 @@ struct NameOpMatcher {
   NameOpMatcher(StringRef name) : name(name) {}
   bool match(Operation *op) { return op->getName().getStringRef() == name; }
 
-  StringRef name;
+  std::string name;
 };
 
 /// The matcher that matches operations that have the specified attribute name.
@@ -67,7 +67,7 @@ struct AttrOpMatcher {
   AttrOpMatcher(StringRef attrName) : attrName(attrName) {}
   bool match(Operation *op) { return op->hasAttr(attrName); }
 
-  StringRef attrName;
+  std::string attrName;
 };
 
 /// The matcher that matches operations that have the `ConstantLike` trait, and
diff --git a/mlir/include/mlir/Query/Matcher/ExtraMatchers.h b/mlir/include/mlir/Query/Matcher/ExtraMatchers.h
index 1764ad35cc9c303..1900879ca70920b 100644
--- a/mlir/include/mlir/Query/Matcher/ExtraMatchers.h
+++ b/mlir/include/mlir/Query/Matcher/ExtraMatchers.h
@@ -15,6 +15,9 @@
 
 #include "MatchFinder.h"
 #include "MatchersInternal.h"
+#include "mlir/IR/Region.h"
+#include "mlir/Query/Query.h"
+#include "llvm/Support/raw_ostream.h"
 
 namespace mlir {
 
@@ -24,80 +27,161 @@ namespace extramatcher {
 
 namespace detail {
 
-class DefinitionsMatcher {
+class BackwardSliceMatcher {
 public:
-  DefinitionsMatcher(matcher::DynMatcher &&InnerMatcher, unsigned Hops)
-      : InnerMatcher(std::move(InnerMatcher)), Hops(Hops) {}
+  BackwardSliceMatcher(matcher::DynMatcher &&innerMatcher, unsigned hops)
+      : innerMatcher(std::move(innerMatcher)), hops(hops) {}
 
 private:
-  bool matches(Operation *op, matcher::BoundOperationsGraphBuilder &Bound,
-               unsigned TempHops) {
-
-    llvm::DenseSet<mlir::Value> Ccache;
-    llvm::SmallVector<std::pair<Operation *, size_t>, 4> TempStorage;
-    TempStorage.push_back({op, TempHops});
-    while (!TempStorage.empty()) {
-      auto [CurrentOp, RemainingHops] = TempStorage.pop_back_val();
-
-      matcher::BoundOperationNode *CurrentNode =
-          Bound.addNode(CurrentOp, true, true);
-      if (RemainingHops == 0) {
-        continue;
-      }
+  bool matches(Operation *op, SetVector<Operation *> &backwardSlice,
+               QueryOptions &options, unsigned tempHops) {
 
-      for (auto Operand : CurrentOp->getOperands()) {
-        if (auto DefiningOp = Operand.getDefiningOp()) {
-          Bound.addEdge(CurrentOp, DefiningOp);
-          if (!Ccache.contains(Operand)) {
-            Ccache.insert(Operand);
-            TempStorage.emplace_back(DefiningOp, RemainingHops - 1);
-          }
-        } else if (auto BlockArg = Operand.dyn_cast<BlockArgument>()) {
-          auto *Block = BlockArg.getOwner();
+    bool validSlice = true;
+    if (op->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
+      return false;
+    }
 
-          if (Block->isEntryBlock() &&
-              isa<FunctionOpInterface>(Block->getParentOp())) {
-            continue;
+    auto processValue = [&](Value value) {
+      if (tempHops == 0) {
+        return;
+      }
+      if (auto *definingOp = value.getDefiningOp()) {
+        if (backwardSlice.count(definingOp) == 0)
+          matches(definingOp, backwardSlice, options, tempHops - 1);
+      } else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
+        if (options.omitBlockArguments)
+          return;
+        Block *block = blockArg.getOwner();
+
+        Operation *parentOp = block->getParentOp();
+
+        if (parentOp && backwardSlice.count(parentOp) == 0) {
+          if (parentOp->getNumRegions() == 1 &&
+              parentOp->getRegion(0).getBlocks().size() == 1) {
+            validSlice = false;
+            return;
+          };
+          matches(parentOp, backwardSlice, options, tempHops - 1);
+        }
+      } else {
+        validSlice = false;
+        return;
+      }
+    };
+
+    if (!options.omitUsesFromAbove) {
+      llvm::for_each(op->getRegions(), [&](Region &region) {
+        SmallPtrSet<Region *, 4> descendents;
+        region.walk(
+            [&](Region *childRegion) { descendents.insert(childRegion); });
+        region.walk([&](Operation *op) {
+          for (OpOperand &operand : op->getOpOperands()) {
+            if (!descendents.contains(operand.get().getParentRegion()))
+              processValue(operand.get());
+            if (!validSlice)
+              return;
           }
+        });
+      });
+    }
 
-          Operation *ParentOp = BlockArg.getOwner()->getParentOp();
-          if (ParentOp) {
-            Bound.addEdge(CurrentOp, ParentOp);
-            if (!!Ccache.contains(BlockArg)) {
-              Ccache.insert(BlockArg);
-              TempStorage.emplace_back(ParentOp, RemainingHops - 1);
-            }
-          }
-        }
+    llvm::for_each(op->getOperands(), [&](Value operand) {
+      processValue(operand);
+      if (!validSlice)
+        return;
+    });
+    backwardSlice.insert(op);
+    if (!validSlice) {
+      return false;
+    }
+    return true;
+  }
+
+public:
+  bool match(Operation *op, SetVector<Operation *> &backwardSlice,
+             QueryOptions &options) {
+    if (innerMatcher.match(op) && matches(op, backwardSlice, options, hops)) {
+      if (!options.inclusive) {
+        backwardSlice.remove(op);
       }
+      return true;
     }
-    // We need at least 1 defining op
-    return Ccache.size() >= 2;
+    return false;
   }
 
+private:
+  matcher::DynMatcher innerMatcher;
+  unsigned hops;
+};
+
+class ForwardSliceMatcher {
 public:
-  bool match(Operation *op, matcher::BoundOperationsGraphBuilder &Bound) {
-    if (InnerMatcher.match(op) && matches(op, Bound, Hops)) {
+  ForwardSliceMatcher(matcher::DynMatcher &&innerMatcher, unsigned hops)
+      : innerMatcher(std::move(innerMatcher)), hops(hops) {}
+
+private:
+  bool matches(Operation *op, SetVector<Operation *> &forwardSlice,
+               QueryOptions &options, unsigned tempHops) {
+
+    if (tempHops == 0) {
+      forwardSlice.insert(op);
+      return true;
+    }
+
+    for (Region &region : op->getRegions())
+      for (Block &block : region)
+        for (Operation &blockOp : block)
+          if (forwardSlice.count(&blockOp) == 0)
+            matches(&blockOp, forwardSlice, options, tempHops - 1);
+    for (Value result : op->getResults()) {
+      for (Operation *userOp : result.getUsers())
+        if (forwardSlice.count(userOp) == 0)
+          matches(userOp, forwardSlice, options, tempHops - 1);
+    }
+
+    forwardSlice.insert(op);
+    return true;
+  }
+
+public:
+  bool match(Operation *op, SetVector<Operation *> &forwardSlice,
+             QueryOptions &options) {
+    if (innerMatcher.match(op) && matches(op, forwardSlice, options, hops)) {
+      if (!options.inclusive) {
+        forwardSlice.remove(op);
+      }
+      SmallVector<Operation *, 0> v(forwardSlice.takeVector());
+      forwardSlice.insert(v.rbegin(), v.rend());
       return true;
     }
     return false;
   }
 
 private:
-  matcher::DynMatcher InnerMatcher;
-  unsigned Hops;
+  matcher::DynMatcher innerMatcher;
+  unsigned hops;
 };
+
 } // namespace detail
 
-inline detail::DefinitionsMatcher
-definedBy(mlir::query::matcher::DynMatcher InnerMatcher) {
-  return detail::DefinitionsMatcher(std::move(InnerMatcher), 1);
+inline detail::BackwardSliceMatcher
+definedBy(mlir::query::matcher::DynMatcher innerMatcher) {
+  return detail::BackwardSliceMatcher(std::move(innerMatcher), 1);
+}
+
+inline detail::BackwardSliceMatcher
+getDefinitions(mlir::query::matcher::DynMatcher innerMatcher, unsigned hops) {
+  return detail::BackwardSliceMatcher(std::move(innerMatcher), hops);
+}
+
+inline detail::ForwardSliceMatcher
+usedBy(mlir::query::matcher::DynMatcher innerMatcher) {
+  return detail::ForwardSliceMatcher(std::move(innerMatcher), 1);
 }
 
-inline detail::DefinitionsMatcher
-getDefinitions(mlir::query::matcher::DynMatcher InnerMatcher, unsigned Hops) {
-  assert(Hops > 0 && "hops must be >= 1");
-  return detail::DefinitionsMatcher(std::move(InnerMatcher), Hops);
+inline detail::ForwardSliceMatcher
+getUses(mlir::query::matcher::DynMatcher innerMatcher, unsigned hops) {
+  return detail::ForwardSliceMatcher(std::move(innerMatcher), hops);
 }
 
 } // namespace extramatcher
diff --git a/mlir/include/mlir/Query/Matcher/MatchFinder.h b/mlir/include/mlir/Query/Matcher/MatchFinder.h
index 4664e48b51b94a2..cbdf2fec46a7611 100644
--- a/mlir/include/mlir/Query/Matcher/MatchFinder.h
+++ b/mlir/include/mlir/Query/Matcher/MatchFinder.h
@@ -23,18 +23,18 @@ namespace mlir::query::matcher {
 class MatchFinder {
 public:
   // Returns all operations that match the given matcher.
-  static BoundOperationsGraphBuilder getMatches(Operation *root,
-                                                DynMatcher matcher) {
-
-    BoundOperationsGraphBuilder Bound;
+  static SetVector<Operation *>
+  getMatches(Operation *root, QueryOptions &options, DynMatcher matcher) {
+    SetVector<Operation *> backwardSlice;
     root->walk([&](Operation *subOp) {
       if (matcher.match(subOp)) {
-        matcher::BoundOperationNode *currentNode = Bound.addNode(subOp);
-      } else if (matcher.match(subOp, Bound)) {
+        backwardSlice.insert(subOp);
+      } else {
+        matcher.match(subOp, backwardSlice, options);
         ////
       }
     });
-    return Bound;
+    return backwardSlice;
   }
 };
 
diff --git a/mlir/include/mlir/Query/Matcher/MatchersInternal.h b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
index cb4063dc2845260..081d216b56b61c0 100644
--- a/mlir/include/mlir/Query/Matcher/MatchersInternal.h
+++ b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
@@ -17,59 +17,13 @@
 #include <unordered_set>
 #include <vector>
 
-namespace mlir::query::matcher {
-
-struct BoundOperationNode {
-  Operation *op;
-  std::vector<BoundOperationNode *> Parents;
-  std::vector<BoundOperationNode *> Children;
-
-  bool IsRootNode;
-  bool DetailedPrinting;
-
-  BoundOperationNode(Operation *op, bool IsRootNode = false,
-                     bool DetailedPrinting = false)
-      : op(op), IsRootNode(IsRootNode), DetailedPrinting(DetailedPrinting) {}
-};
+namespace mlir {
+namespace query {
+struct QueryOptions;
+}
+} // namespace mlir
 
-class BoundOperationsGraphBuilder {
-public:
-  BoundOperationNode *addNode(Operation *op, bool IsRootNode = false,
-                              bool DetailedPrinting = false) {
-    auto It = Nodes.find(op);
-    if (It != Nodes.end()) {
-      return It->second.get();
-    }
-    auto Node =
-        std::make_unique<BoundOperationNode>(op, IsRootNode, DetailedPrinting);
-    BoundOperationNode *NodePtr = Node.get();
-    Nodes[op] = std::move(Node);
-    return NodePtr;
-  }
-
-  void addEdge(Operation *parentOp, Operation *childOp) {
-    BoundOperationNode *ParentNode = addNode(parentOp, false, false);
-    BoundOperationNode *ChildNode = addNode(childOp, false, false);
-
-    ParentNode->Children.push_back(ChildNode);
-    ChildNode->Parents.push_back(ParentNode);
-  }
-
-  BoundOperationNode *getNode(Operation *op) const {
-    auto It = Nodes.find(op);
-    return It != Nodes.end() ? It->second.get() : nullptr;
-  }
-
-  const llvm::MapVector<Operation *, std::unique_ptr<BoundOperationNode>> &
-  getNodes() const {
-    return Nodes;
-  }
-
-private:
-  llvm::MapVector<Operation *, std::unique_ptr<BoundOperationNode>> Nodes;
-};
-
-// Type traIt to detect if a matcher has a match(Operation*) method
+namespace mlir::query::matcher {
 template <typename T, typename = void>
 struct has_simple_match : std::false_type {};
 
@@ -78,15 +32,14 @@ struct has_simple_match<T, std::void_t<decltype(std::declval<T>().match(
                                std::declval<Operation *>()))>>
     : std::true_type {};
 
-// Type traIt to detect if a matcher has a match(Operation*,
-// BoundOperationsGraphBuilder&) method
 template <typename T, typename = void>
 struct has_bound_match : std::false_type {};
 
 template <typename T>
 struct has_bound_match<T, std::void_t<decltype(std::declval<T>().match(
                               std::declval<Operation *>(),
-                              std::declval<BoundOperationsGraphBuilder &>()))>>
+                              std::declval<SetVector<Operation *> &>(),
+                              std::declval<QueryOptions &>()))>>
     : std::true_type {};
 
 // Generic interface for matchers on an MLIR operation.
@@ -95,7 +48,8 @@ class MatcherInterface
 public:
   virtual ~MatcherInterface() = default;
   virtual bool match(Operation *op) = 0;
-  virtual bool match(Operation *op, BoundOperationsGraphBuilder &bound) = 0;
+  virtual bool match(Operation *op, SetVector<Operation *> &matchedOps,
+                     QueryOptions &options) = 0;
 };
 
 // MatcherFnImpl takes a matcher function object and implements
@@ -111,9 +65,10 @@ class MatcherFnImpl : public MatcherInterface {
     return false;
   }
 
-  bool match(Operation *op, BoundOperationsGraphBuilder &bound) override {
+  bool match(Operation *op, SetVector<Operation *> &matchedOps,
+             QueryOptions &options) override {
     if constexpr (has_bound_match<MatcherFn>::value)
-      return matcherFn.match(op, bound);
+      return matcherFn.match(op, matchedOps, options);
     return false;
   }
 
@@ -138,8 +93,9 @@ class DynMatcher {
   }
 
   bool match(Operation *op) const { return implementation->match(op); }
-  bool match(Operation *op, BoundOperationsGraphBuilder &bound) const {
-    return implementation->match(op, bound);
+  bool match(Operation *op, SetVector<Operation *> &matchedOps,
+             QueryOptions &options) const {
+    return implementation->match(op, matchedOps, options);
   }
 
   void setFunctionName(StringRef name) { functionName = name.str(); }
diff --git a/mlir/include/mlir/Query/Matcher/VariantValue.h b/mlir/include/mlir/Query/Matcher/VariantValue.h
index 73d96a6913dfe4a..6b57119df7a9bf5 100644
--- a/mlir/include/mlir/Query/Matcher/VariantValue.h
+++ b/mlir/include/mlir/Query/Matcher/VariantValue.h
@@ -100,6 +100,8 @@ class VariantValue {
 
   // String representation of the type of the value.
   std::string getTypeAsString() const;
+  explicit operator bool() const { return hasValue(); }
+  bool hasValue() const { return type != ValueType::Nothing; }
 
 private:
   void reset();
diff --git a/mlir/include/mlir/Query/Query.h b/mlir/include/mlir/Query/Query.h
index 18f2172c9510a35..d695f648016be31 100644
--- a/mlir/include/mlir/Query/Query.h
+++ b/mlir/include/mlir/Query/Query.h
@@ -10,6 +10,7 @@
 #define MLIR_TOOLS_MLIRQUERY_QUERY_H
 
 #include "Matcher/VariantValue.h"
+#include "mlir/Analysis/SliceAnalysis.h"
 #include "llvm/ADT/IntrusiveRefCntPtr.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/LineEditor/LineEditor.h"
@@ -17,7 +18,13 @@
 
 namespace mlir::query {
 
-enum class QueryKind { Invalid, NoOp, Help, Match, Quit };
+struct QueryOptions {
+  bool omitBlockArguments = false;
+  bool omitUsesFromAbove = true;
+  bool inclusive = true;
+};
+
+enum class QueryKind { Invalid, NoOp, Help, Match, Quit, Let, SetBool };
 
 class QuerySession;
 
@@ -103,6 +110,47 @@ struct MatchQuery : Query {
   }
 };
 
+struct LetQuery : Query {
+  LetQuery(llvm::StringRef name, const matcher::VariantValue &value)
+      : Query(QueryKind::Let), name(name), value(value) {}
+
+  llvm::LogicalResult run(llvm::raw_ostream &os,
+                          QuerySession &qs) const override;
+
+  std::string name;
+  matcher::VariantValue value;
+
+  static bool classof(const Query *query) {
+    return query->kind == QueryKind::Let;
+  }
+};
+
+template <typename T>
+struct SetQueryKind {};
+
+template <>
+struct SetQueryKind<bool> {
+  static const QueryKind value = QueryKind::SetBool;
+};
+template <typename T>
+struct SetQuery : Query {
+  SetQuery(T QuerySession::*var, T value)
+      : Query(SetQueryKind<T>::value), var(var), value(value) {}
+
+  llvm::LogicalResult run(llvm::raw_ostream &os,
+                          QuerySession &qs) const override {
+    qs.*var = value;
+    return mlir::success();
+  }
+
+  static bool classof(const Query *query) {
+    return query->kind == SetQueryKind<T>::value;
+  }
+
+  T QuerySession::*var;
+  T value;
+};
+
 } // namespace mlir::query
 
 #endif
diff --git a/mlir/include/mlir/Query/QuerySession.h b/mlir/include/mlir/Query/QuerySession.h
index fe552d750fc7716..7faf206b96124b5 100644
--- a/mlir/include/mlir/Query/QuerySession.h
+++ b/mlir/include/mlir/Query/QuerySession.h
@@ -33,6 +33,11 @@ class QuerySession {
   llvm::StringMap<matcher::VariantValue> namedValues;
   bool terminate = false;
 
+public:
+  bool omitBlockArguments = false;
+  bool omitUsesFromAbove = true;
+  bool inclusive = true;
+
 private:
   Operation *rootOp;
   llvm::SourceMgr &sourceMgr;
diff --git a/mlir/lib/Query/Matcher/Parser.cpp b/mlir/lib/Query/Matcher/Parser.cpp
index 4f1b716756e318e..79797ec45b549b0 100644
--- a/mlir/lib/Query/Matcher/Parser.cpp
+++ b/mlir/lib/Query/Matcher/Parser.cpp
@@ -191,7 +191,7 @@ class Parser::CodeTokenizer {
       double doubleValue = strtod(text.c_str(), &end);
       if (*end == 0 && errno == 0) {
         result->kind = TokenKind::Literal;
-        result->value = doubleValue;
+        result->value = static_cast<double>(doubleValue);
         return;
       }
     } else {
@@ -316,13 +316,19 @@ bool Parser::parseIdentifierPrefixImpl(VariantValue *value) {
 
   if (tokenizer->nextTokenKind() != TokenKind::OpenParen) {
     // Parse as a named value.
-    auto namedValue =
-        namedValues ? namedValues->lookup(nameToken.text) : VariantValue();
+    if (auto namedValue = namedValues ? namedValues->lookup(nameToken.text)
+                                      : VariantValue()) {
 
-    if (!namedValue.isMatcher()) {
-      error->addError(tokenizer->peekNextToken().range,
-                      ErrorType::ParserNotAMatcher);
-      return false;
+      if (tokenizer->nextTokenKind() != TokenKind::Period) {
+        *value = namedValue;
+        return true;
+      }
+
+      if (!namedValue.isMatcher()) {
+        error->addError(tokenizer->peekNextToken().range,
+                        ErrorType::ParserNotAMatcher);
+        return false;
+      }
     }
 
     if (tokenizer->nextTokenKind() == TokenKind::NewLine) {
diff --git a/mlir/lib/Query/Matcher/VariantValue.cpp b/mlir/lib/Query/Matcher/VariantValue.cpp
index 50d79512196d1a5..d5218d8dad8c936 100644
--- a/mlir/lib/Query/Matcher/VariantValue.cpp
+++ b/mlir/lib/Query/Matcher/VariantValue.cpp
@@ -99,6 +99,7 @@ void VariantValue::reset() {
   type = ValueType::Nothing;
 }
 
+// Unsigned
 bool VariantValue::isUnsigned() const { return type == ValueType::Unsigned; }
 
 unsigned VariantValue::getUnsigned() const {
diff --git a/mlir/lib/Query/Query.cpp b/mlir/lib/Query/Query.cpp
index 70be7c36888d50b..03d5018abf5b012 100644
--- a/mlir/lib/Query/Query.cpp
+++ b/mlir/lib/Query/Query.cpp
@@ -29,8 +29,8 @@ complete(llvm::StringRef line, size_t pos, const QuerySession &qs) {
   return QueryParser::complete(line, pos, qs);
 }
 
-static void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op,
-                       const std::string &binding) {
+static void printMatch(llvm::raw_ostream &os, QuerySession &qs,
+                       mlir::Operation *op, const std::string &binding) {
   auto fileLoc = op->getLoc()->findInstanceOf<FileLineColLoc>();
   auto smloc = qs.getSourceManager().FindLocForLineAndColumn(
       qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn());
@@ -102,6 +102,12 @@ static Operation *extractFunction(std::vector<Operation *> &ops,
   return funcOp;
 }
 
+static void parseQueryOptions(QuerySession &qs, QueryOptions &options) {
+  options.omitBlockArguments = qs.omitBlockArguments;
+  options.omitUsesFromAbove = qs.omitUsesFromAbove;
+  options.inclusive = qs.inclusive;
+}
+
 Query::~Query() = default;
 
 LogicalResult InvalidQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
@@ -117,6 +123,8 @@ LogicalResult HelpQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
   os << "Available commands:\n\n"
         "  match MATCHER, m MATCHER      "
         "Match the mlir against the given matcher.\n"
+        "  let NAME MATCHER, l NAME MATCHER  "
+        "Give a matcher expression a name, to be used later\n"
         "  quit                              "
         "Terminates the query session.\n\n";
   return mlir::success();
@@ -127,132 +135,44 @@ LogicalResult QuitQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
   return mlir::success();
 }
 
-void collectMatchNodes(
-    matcher::BoundOperationNode *Node,
-    llvm::SetVector<matcher::BoundOperationNode *> &MatchNodes) {
-  MatchNodes.insert(Node);
-  for (auto ChildNode : Node->Children) {
-    collectMatchNodes(ChildNode, MatchNodes);
+LogicalResult LetQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
+  if (value.hasValue()) {
+    qs.namedValues[name] = value;
+  } else {
+    qs.namedValues.erase(name);
   }
-}
-
-void analyzeAndPrint(llvm::raw_ostream &os, QuerySession &qs,
-                     const matcher::BoundOperationsGraphBuilder &Bound) {
-
-  const auto &Nodes = Bound.getNodes();
-  if (Nodes.empty()) {
-    os << "The graph is empty.\n";
-    return;
-  }
-
-  bool AnyDetailedPrinting = false;
-  for (const auto &Pair : Nodes) {
-    if (Pair.second->DetailedPrinting) {
-      AnyDetailedPrinting = true;
-      break;
-    }
-  }
-
-  unsigned MatchesCounter = 0;
-  if (!AnyDetailedPrinting) {
-    os << "Operations:\n";
-    for (const auto &Pair : Nodes) {
-      os << "\n";
-      os << "  Match #" << ++MatchesCounter << "\n";
-      printMatch(os, qs, Pair.first, "root");
-    }
-    os << MatchesCounter << " matches found!\n";
-    return;
-  }
-
-  // Maps ids to nodes
-  std::unordered_map<Operation *, int> NodeIDs;
-  int id = 0;
-  for (const auto &Pair : Nodes) {
-    NodeIDs[Pair.first] = id++;
-  }
-
-  // Finds root nodes
-  std::vector<matcher::BoundOperationNode *> RootNodes;
-  for (const auto &Pair : Nodes) {
-    matcher::BoundOperationNode *Node = Pair.second.get();
-    if (Node->IsRootNode) {
-      RootNodes.push_back(Node);
-    }
-  }
-
-  for (auto RootNode : RootNodes) {
-    os << "\n";
-    os << "  Match #" << ++MatchesCounter << "\n";
-
-    llvm::SetVector<matcher::BoundOperationNode *> MatchNodes;
-    collectMatchNodes(RootNode, MatchNodes);
-    std::vector<matcher::BoundOperationNode *> SortedMatchNodes(
-        MatchNodes.begin(), MatchNodes.end());
-
-    // Sorts based on file location
-    std::sort(
-        SortedMatchNodes.begin(), SortedMatchNodes.end(),
-        [&](matcher::BoundOperationNode *a, matcher::BoundOperationNode *b) {
-          auto fileLocA = a->op->getLoc()->findInstanceOf<FileLineColLoc>();
-          auto fileLocB = b->op->getLoc()->findInstanceOf<FileLineColLoc>();
-
-          if (!fileLocA && !fileLocB)
-            return false;
-          if (!fileLocA)
-            return false;
-          if (!fileLocB)
-            return true;
-
-          if (fileLocA.getFilename().str() != fileLocB.getFilename().str())
-            return fileLocA.getFilename().str() < fileLocB.getFilename().str();
-          return fileLocA.getLine() < fileLocB.getLine();
-        });
-
-    for (auto Node : SortedMatchNodes) {
-      unsigned NodeID = NodeIDs[Node->op];
-      std::string binding = Node->IsRootNode ? "root" : "";
-      os << NodeID << ": ";
-      printMatch(os, qs, Node->op, binding);
-    }
-
-    // Prints edges
-    os << "Edges:\n";
-    for (auto Node : MatchNodes) {
-      int ParentID = NodeIDs[Node->op];
-      for (auto ChildNode : Node->Children) {
-        if (MatchNodes.count(ChildNode) > 0) {
-          int ChildID = NodeIDs[ChildNode->op];
-          os << "  " << ParentID << " ---> " << ChildID << "\n";
-        }
-      }
-    }
-  }
-  os << "\n" << MatchesCounter << " matches found!\n";
+  return mlir::success();
 }
 
 LogicalResult MatchQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
   Operation *rootOp = qs.getRootOp();
   int matchCount = 0;
-  auto matches = matcher::MatchFinder().getMatches(rootOp, matcher);
 
-  // An extract call is recognized by considering if the matcher has a
-  //     name.TODO : Consider making the extract
-  //                     more explicit.
+  QueryOptions options;
+  parseQueryOptions(qs, options);
+  auto matches =
+      matcher::MatchFinder().getMatches(rootOp, options, std::move(matcher));
+
+  // An extract call is recognized by considering if the matcher has a name.
+  // TODO: Consider making the extract more explicit.
   // if (matcher.hasFunctionName()) {
   //   auto functionName = matcher.getFunctionName();
-  //   Operation *function = extractFunction(matches.getOperations(),
-  //                                         rootOp->getContext(),
-  //                                         functionName);
+  //   Operation *function =
+  //       extractFunction(matches, rootOp->getContext(), functionName);
   //   os << "\n" << *function << "\n\n";
   //   function->erase();
   //   return mlir::success();
   // }
 
   os << "\n";
-  analyzeAndPrint(os, qs, matches);
+  for (Operation *op : matches) {
+    os << "Match #" << ++matchCount << ":\n\n";
+    //  Placeholder "root" binding for the initial draft.
+    printMatch(os, qs, op, "root");
+  }
+  os << matchCount << (matchCount == 1 ? " match.\n\n" : " matches.\n\n");
 
   return mlir::success();
 }
 
-} // namespace mlir::query
+} // namespace mlir::query
\ No newline at end of file
diff --git a/mlir/lib/Query/QueryParser.cpp b/mlir/lib/Query/QueryParser.cpp
index 31aead7d403d0df..0565efeb2b80f3b 100644
--- a/mlir/lib/Query/QueryParser.cpp
+++ b/mlir/lib/Query/QueryParser.cpp
@@ -107,6 +107,18 @@ QueryRef QueryParser::endQuery(QueryRef queryRef) {
   return queryRef;
 }
 
+QueryRef QueryParser::parseSetBool(bool QuerySession::*var) {
+  StringRef valStr;
+  unsigned value = LexOrCompleteWord<unsigned>(this, valStr)
+                       .Case("false", 0)
+                       .Case("true", 1)
+                       .Default(~0u);
+  if (value == ~0u) {
+    return new InvalidQuery("expected 'true' or 'false', got '" + valStr + "'");
+  }
+  return new SetQuery<bool>(var, value);
+}
+
 namespace {
 
 enum class ParsedQueryKind {
@@ -116,6 +128,15 @@ enum class ParsedQueryKind {
   Help,
   Match,
   Quit,
+  Let,
+  Set,
+};
+
+enum ParsedQueryVariable {
+  Invalid,
+  OmitBlockArguments,
+  OmitUsesFromAbove,
+  Inclusive,
 };
 
 QueryRef
@@ -146,7 +167,10 @@ QueryRef QueryParser::doParse() {
           .Case("", ParsedQueryKind::NoOp)
           .Case("#", ParsedQueryKind::Comment, /*isCompletion=*/false)
           .Case("help", ParsedQueryKind::Help)
+          .Case("l", ParsedQueryKind::Let, /*isCompletion=*/false)
+          .Case("let", ParsedQueryKind::Let, /*isCompletion=*/false)
           .Case("m", ParsedQueryKind::Match, /*isCompletion=*/false)
+          .Case("set", ParsedQueryKind::Set)
           .Case("match", ParsedQueryKind::Match)
           .Case("q", ParsedQueryKind::Quit, /*IsCompletion=*/false)
           .Case("quit", ParsedQueryKind::Quit)
@@ -167,6 +191,27 @@ QueryRef QueryParser::doParse() {
   case ParsedQueryKind::Quit:
     return endQuery(new QuitQuery);
 
+  case ParsedQueryKind::Let: {
+    llvm::StringRef name = lexWord();
+
+    if (name.empty()) {
+      return new InvalidQuery("expected variable name");
+    }
+
+    if (completionPos) {
+      return completeMatcherExpression();
+    }
+
+    matcher::internal::Diagnostics diag;
+    matcher::VariantValue value;
+    if (!matcher::internal::Parser::parseExpression(
+            line, qs.getRegistryData(), &qs.namedValues, &value, &diag)) {
+      return makeInvalidQueryFromDiagnostics(diag);
+    }
+    QueryRef query = new LetQuery(name, value);
+    query->remainingContent = line;
+    return query;
+  }
   case ParsedQueryKind::Match: {
     if (completionPos) {
       return completeMatcherExpression();
@@ -183,11 +228,41 @@ QueryRef QueryParser::doParse() {
     }
     auto actualSource = origMatcherSource.substr(0, origMatcherSource.size() -
                                                         matcherSource.size());
+
     QueryRef query = new MatchQuery(actualSource, *matcher);
     query->remainingContent = matcherSource;
     return query;
   }
-
+  case ParsedQueryKind::Set: {
+    llvm::StringRef varStr;
+    ParsedQueryVariable var =
+        LexOrCompleteWord<ParsedQueryVariable>(this, varStr)
+            .Case("omitBlockArguments", ParsedQueryVariable::OmitBlockArguments)
+            .Case("omitUsesFromAbove", ParsedQueryVariable::OmitUsesFromAbove)
+            .Case("inclusive", ParsedQueryVariable::Inclusive)
+            .Default(ParsedQueryVariable::Invalid);
+    if (varStr.empty()) {
+      return new InvalidQuery("expected variable name");
+    }
+    if (var == ParsedQueryVariable::Invalid) {
+      return new InvalidQuery("unknown variable: '" + varStr + "'");
+    }
+    QueryRef query;
+    switch (var) {
+    case ParsedQueryVariable::OmitBlockArguments:
+      query = parseSetBool(&QuerySession::omitBlockArguments);
+      break;
+    case ParsedQueryVariable::OmitUsesFromAbove:
+      query = parseSetBool(&QuerySession::omitUsesFromAbove);
+      break;
+    case ParsedQueryVariable::Inclusive:
+      query = parseSetBool(&QuerySession::inclusive);
+      break;
+    case ParsedQueryVariable::Invalid:
+      llvm_unreachable("Invalid query kind");
+    }
+    return endQuery(query);
+  }
   case ParsedQueryKind::Invalid:
     return new InvalidQuery("unknown command: " + commandStr);
   }
diff --git a/mlir/lib/Query/QueryParser.h b/mlir/lib/Query/QueryParser.h
index e9c30eccecab9e4..69cc5d0043d57df 100644
--- a/mlir/lib/Query/QueryParser.h
+++ b/mlir/lib/Query/QueryParser.h
@@ -39,8 +39,8 @@ class QueryParser {
   struct LexOrCompleteWord;
 
   QueryRef completeMatcherExpression();
-
   QueryRef endQuery(QueryRef queryRef);
+  QueryRef parseSetBool(bool QuerySession::*var);
 
   // Parse [begin, end) and returns a reference to the parsed query object,
   // which may be an InvalidQuery if a parse error occurs.
diff --git a/mlir/test/mlir-query/complex-test.mlir b/mlir/test/mlir-query/complex-test.mlir
new file mode 100644
index 000000000000000..e491beb5007792b
--- /dev/null
+++ b/mlir/test/mlir-query/complex-test.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-query %s -c "m definedBy(hasOpName(())" | FileCheck %s
+
+
+func.func @matrix_multiply(%A: memref<4x4xf32>, %B: memref<4x4xf32>, %C: memref<4x4xf32>) {
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+  %c1 = arith.constant 1 : index
+
+  scf.for %i = %c0 to %c4 step %c1 {
+    scf.for %j = %c0 to %c4 step %c1 {
+      %sum_init = arith.constant 0.0 : f32
+      %sum = scf.for %k = %c0 to %c4 step %c1 iter_args(%acc = %sum_init) -> (f32) {
+        %a_ik = memref.load %A[%i, %k] : memref<4x4xf32>
+        %b_kj = memref.load %B[%k, %j] : memref<4x4xf32>
+        %prod = arith.mulf %a_ik, %b_kj : f32
+        %new_acc = arith.addf %acc, %prod : f32
+        scf.yield %new_acc : f32
+      }
+      memref.store %sum, %C[%i, %j] : memref<4x4xf32>
+    }
+  }
+  return
+}
\ No newline at end of file
diff --git a/mlir/test/mlir-query/function-extraction.mlir b/mlir/test/mlir-query/function-extraction.mlir
index a783f65c6761bcf..473c30aee7123ae 100644
--- a/mlir/test/mlir-query/function-extraction.mlir
+++ b/mlir/test/mlir-query/function-extraction.mlir
@@ -16,4 +16,4 @@ func.func @mixedOperations(%a: f32, %b: f32, %c: f32) -> f32 {
   %sum2 = arith.addf %mul1, %b : f32
   %mul2 = arith.mulf %sub2, %sum2 : f32
   return %mul2 : f32
-}
+    }
diff --git a/mlir/tools/mlir-query/mlir-query.cpp b/mlir/tools/mlir-query/mlir-query.cpp
index d5c0b1632d3c5d4..468f948bec24cac 100644
--- a/mlir/tools/mlir-query/mlir-query.cpp
+++ b/mlir/tools/mlir-query/mlir-query.cpp
@@ -44,6 +44,10 @@ int main(int argc, char **argv) {
                                   mlir::query::extramatcher::getDefinitions);
   matcherRegistry.registerMatcher("definedBy",
                                   mlir::query::extramatcher::definedBy);
+  matcherRegistry.registerMatcher("usedBy", mlir::query::extramatcher::usedBy);
+  matcherRegistry.registerMatcher("getUses",
+                                  mlir::query::extramatcher::getUses);
+
   matcherRegistry.registerMatcher("hasOpAttrName",
                                   static_cast<HasOpAttrName *>(m_Attr));
   matcherRegistry.registerMatcher("hasOpName", static_cast<HasOpName *>(m_Op));

>From 214a32343c7e3438ddd5dadb195c3271a5660e50 Mon Sep 17 00:00:00 2001
From: Denzel-Brian Budii <denzel-brian.budii at intel.com>
Date: Sun, 19 Jan 2025 13:07:17 +0000
Subject: [PATCH 3/5] conflicts fixed

---
 .../mlir/Query/Matcher/ExtraMatchers.h        | 27 +++------
 mlir/include/mlir/Query/Matcher/MatchFinder.h | 58 ++++++++++++-------
 .../mlir/Query/Matcher/MatchersInternal.h     |  9 +--
 mlir/include/mlir/Query/Query.h               |  1 -
 mlir/include/mlir/Query/QuerySession.h        |  6 +-
 mlir/lib/Query/Matcher/Parser.cpp             | 35 ++---------
 mlir/lib/Query/Query.cpp                      | 23 +-------
 mlir/test/mlir-query/complex-test.mlir        | 39 +++++++------
 mlir/test/mlir-query/function-extraction.mlir |  2 +-
 9 files changed, 82 insertions(+), 118 deletions(-)

diff --git a/mlir/include/mlir/Query/Matcher/ExtraMatchers.h b/mlir/include/mlir/Query/Matcher/ExtraMatchers.h
index 1900879ca70920b..57adc3241b0bef5 100644
--- a/mlir/include/mlir/Query/Matcher/ExtraMatchers.h
+++ b/mlir/include/mlir/Query/Matcher/ExtraMatchers.h
@@ -1,4 +1,5 @@
-//===- Matchers.h - Various common matchers ---------------------*- C++ -*-===//
+//===- ExtraMatchers.h - Various common matchers ---------------------*- C++
+//-*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -36,7 +37,6 @@ class BackwardSliceMatcher {
   bool matches(Operation *op, SetVector<Operation *> &backwardSlice,
                QueryOptions &options, unsigned tempHops) {
 
-    bool validSlice = true;
     if (op->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
       return false;
     }
@@ -56,16 +56,12 @@ class BackwardSliceMatcher {
         Operation *parentOp = block->getParentOp();
 
         if (parentOp && backwardSlice.count(parentOp) == 0) {
-          if (parentOp->getNumRegions() == 1 &&
-              parentOp->getRegion(0).getBlocks().size() == 1) {
-            validSlice = false;
-            return;
-          };
-          matches(parentOp, backwardSlice, options, tempHops - 1);
+          assert(parentOp->getNumRegions() == 1 &&
+                 parentOp->getRegion(0).getBlocks().size() == 1);
+          matches(parentOp, backwardSlice, options, tempHops-1);
         }
       } else {
-        validSlice = false;
-        return;
+        llvm_unreachable("No definingOp and not a block argument.");
       }
     };
 
@@ -78,22 +74,13 @@ class BackwardSliceMatcher {
           for (OpOperand &operand : op->getOpOperands()) {
             if (!descendents.contains(operand.get().getParentRegion()))
               processValue(operand.get());
-            if (!validSlice)
-              return;
           }
         });
       });
     }
 
-    llvm::for_each(op->getOperands(), [&](Value operand) {
-      processValue(operand);
-      if (!validSlice)
-        return;
-    });
+    llvm::for_each(op->getOperands(), processValue);
     backwardSlice.insert(op);
-    if (!validSlice) {
-      return false;
-    }
     return true;
   }
 
diff --git a/mlir/include/mlir/Query/Matcher/MatchFinder.h b/mlir/include/mlir/Query/Matcher/MatchFinder.h
index cbdf2fec46a7611..c188bfe75406a7e 100644
--- a/mlir/include/mlir/Query/Matcher/MatchFinder.h
+++ b/mlir/include/mlir/Query/Matcher/MatchFinder.h
@@ -1,43 +1,61 @@
-//===- MatchFinder.h - ------------------------------------------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
 //===----------------------------------------------------------------------===//
-//
-// This file contains the MatchFinder class, which is used to find operations
-// that match a given matcher.
-//
+// MatchFinder.h
 //===----------------------------------------------------------------------===//
 
 #ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H
 #define MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H
 
 #include "MatchersInternal.h"
-#include "mlir/IR/Operation.h"
+#include "mlir/Query/QuerySession.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/raw_ostream.h"
 
 namespace mlir::query::matcher {
 
-// MatchFinder is used to find all operations that match a given matcher.
 class MatchFinder {
+private:
+  // Base print function with binding text
+  static void printMatch(llvm::raw_ostream &os, QuerySession &qs,
+                         mlir::Operation *op, const std::string &binding) {
+    auto fileLoc = op->getLoc()->findInstanceOf<FileLineColLoc>();
+    auto smloc = qs.getSourceManager().FindLocForLineAndColumn(
+        qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn());
+    qs.getSourceManager().PrintMessage(os, smloc, llvm::SourceMgr::DK_Note,
+                                       "\"" + binding + "\" binds here");
+  };
+
 public:
-  // Returns all operations that match the given matcher.
   static SetVector<Operation *>
-  getMatches(Operation *root, QueryOptions &options, DynMatcher matcher) {
-    SetVector<Operation *> backwardSlice;
+  getMatches(Operation *root, QueryOptions &options, DynMatcher matcher,
+             llvm::raw_ostream &os, QuerySession &qs) {
+    unsigned matchCount = 0;
+    SetVector<Operation *> matchedOps;
+    SetVector<Operation *> tempStorage;
+
     root->walk([&](Operation *subOp) {
       if (matcher.match(subOp)) {
-        backwardSlice.insert(subOp);
+        matchedOps.insert(subOp);
+        os << "Match #" << ++matchCount << ":\n\n";
+        printMatch(os, qs, subOp, "root");
       } else {
-        matcher.match(subOp, backwardSlice, options);
-        ////
+        SmallVector<Operation *> printingOps;
+        size_t sizeBefore = matchedOps.size();
+        if (matcher.match(subOp, tempStorage, options)) {
+          os << "Match #" << ++matchCount << ":\n\n";
+          SmallVector<Operation *> printingOps(tempStorage.takeVector());
+          for (auto op : printingOps) {
+            printMatch(os, qs, op, ""); // Using version without binding text
+            matchedOps.insert(op);
+          }
+          printingOps.clear();
+        }
       }
     });
-    return backwardSlice;
+    return matchedOps;
   }
 };
 
 } // namespace mlir::query::matcher
 
-#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H
+#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H
\ No newline at end of file
diff --git a/mlir/include/mlir/Query/Matcher/MatchersInternal.h b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
index 081d216b56b61c0..b795d9a291844f8 100644
--- a/mlir/include/mlir/Query/Matcher/MatchersInternal.h
+++ b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
@@ -1,6 +1,6 @@
 //===- MatchersInternal.h - Structural query framework ----------*- C++ -*-===//
 //
-// Part of the LLVM Project, under the Apache License v2.0 wIth LLVM Exceptions.
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WItH LLVM-exception
 //
@@ -11,11 +11,6 @@
 
 #include "mlir/IR/Matchers.h"
 #include "llvm/ADT/IntrusiveRefCntPtr.h"
-#include "llvm/ADT/MapVector.h"
-#include <memory>
-#include <stack>
-#include <unordered_set>
-#include <vector>
 
 namespace mlir {
 namespace query {
@@ -112,4 +107,4 @@ class DynMatcher {
 
 } // namespace mlir::query::matcher
 
-#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
\ No newline at end of file
+#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
diff --git a/mlir/include/mlir/Query/Query.h b/mlir/include/mlir/Query/Query.h
index d695f648016be31..89d48773d2c3e6b 100644
--- a/mlir/include/mlir/Query/Query.h
+++ b/mlir/include/mlir/Query/Query.h
@@ -10,7 +10,6 @@
 #define MLIR_TOOLS_MLIRQUERY_QUERY_H
 
 #include "Matcher/VariantValue.h"
-#include "mlir/Analysis/SliceAnalysis.h"
 #include "llvm/ADT/IntrusiveRefCntPtr.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/LineEditor/LineEditor.h"
diff --git a/mlir/include/mlir/Query/QuerySession.h b/mlir/include/mlir/Query/QuerySession.h
index 7faf206b96124b5..495358e8f36f94e 100644
--- a/mlir/include/mlir/Query/QuerySession.h
+++ b/mlir/include/mlir/Query/QuerySession.h
@@ -9,14 +9,18 @@
 #ifndef MLIR_TOOLS_MLIRQUERY_QUERYSESSION_H
 #define MLIR_TOOLS_MLIRQUERY_QUERYSESSION_H
 
+#include "Matcher/VariantValue.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/Query/Matcher/Registry.h"
 #include "llvm/ADT/StringMap.h"
 #include "llvm/Support/SourceMgr.h"
 
+namespace mlir::query::matcher {
+class Registry;
+}
+
 namespace mlir::query {
 
-class Registry;
 // Represents the state for a particular mlir-query session.
 class QuerySession {
 public:
diff --git a/mlir/lib/Query/Matcher/Parser.cpp b/mlir/lib/Query/Matcher/Parser.cpp
index 79797ec45b549b0..726f1188d7e4c8c 100644
--- a/mlir/lib/Query/Matcher/Parser.cpp
+++ b/mlir/lib/Query/Matcher/Parser.cpp
@@ -157,7 +157,6 @@ class Parser::CodeTokenizer {
   }
 
   void consumeNumberLiteral(TokenInfo *result) {
-    bool isFloatingLiteral = false;
     unsigned length = 1;
     if (code.size() > 1) {
       // Consume the 'x' or 'b' radix modifier, if present.
@@ -170,39 +169,17 @@ class Parser::CodeTokenizer {
     while (length < code.size() && isdigit(code[length]))
       ++length;
 
-    // Try to recognize a floating point literal.
-    while (length < code.size()) {
-      char c = code[length];
-      if (c == '-' || c == '+' || c == '.' || isdigit(c)) {
-        isFloatingLiteral = true;
-        length++;
-      } else {
-        break;
-      }
-    }
-
     result->text = code.take_front(length);
     code = code.drop_front(length);
 
-    if (isFloatingLiteral) {
-      char *end;
-      errno = 0;
-      std::string text = result->text.str();
-      double doubleValue = strtod(text.c_str(), &end);
-      if (*end == 0 && errno == 0) {
-        result->kind = TokenKind::Literal;
-        result->value = static_cast<double>(doubleValue);
-        return;
-      }
-    } else {
-      unsigned value;
-      if (!result->text.getAsInteger(0, value)) {
-        result->kind = TokenKind::Literal;
-        result->value = value;
-        return;
-      }
+    unsigned value;
+    if (!result->text.getAsInteger(0, value)) {
+      result->kind = TokenKind::Literal;
+      result->value = static_cast<unsigned>(value);
+      return;
     }
   }
+
   // Consume a string literal, handle escape sequences and missing closing
   // quote.
   void consumeStringLiteral(TokenInfo *result) {
diff --git a/mlir/lib/Query/Query.cpp b/mlir/lib/Query/Query.cpp
index 03d5018abf5b012..cbc436299afb701 100644
--- a/mlir/lib/Query/Query.cpp
+++ b/mlir/lib/Query/Query.cpp
@@ -15,8 +15,6 @@
 #include "llvm/ADT/SetVector.h"
 #include "llvm/Support/SourceMgr.h"
 #include "llvm/Support/raw_ostream.h"
-#include <unordered_map>
-#include <unordered_set>
 
 namespace mlir::query {
 
@@ -29,15 +27,6 @@ complete(llvm::StringRef line, size_t pos, const QuerySession &qs) {
   return QueryParser::complete(line, pos, qs);
 }
 
-static void printMatch(llvm::raw_ostream &os, QuerySession &qs,
-                       mlir::Operation *op, const std::string &binding) {
-  auto fileLoc = op->getLoc()->findInstanceOf<FileLineColLoc>();
-  auto smloc = qs.getSourceManager().FindLocForLineAndColumn(
-      qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn());
-  qs.getSourceManager().PrintMessage(os, smloc, llvm::SourceMgr::DK_Note,
-                                     "\"" + binding + "\" binds here");
-}
-
 // TODO: Extract into a helper function that can be reused outside query
 // context.
 static Operation *extractFunction(std::vector<Operation *> &ops,
@@ -150,8 +139,8 @@ LogicalResult MatchQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
 
   QueryOptions options;
   parseQueryOptions(qs, options);
-  auto matches =
-      matcher::MatchFinder().getMatches(rootOp, options, std::move(matcher));
+  auto matches = matcher::MatchFinder().getMatches(rootOp, options,
+                                                   std::move(matcher), os, qs);
 
   // An extract call is recognized by considering if the matcher has a name.
   // TODO: Consider making the extract more explicit.
@@ -164,14 +153,6 @@ LogicalResult MatchQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
   //   return mlir::success();
   // }
 
-  os << "\n";
-  for (Operation *op : matches) {
-    os << "Match #" << ++matchCount << ":\n\n";
-    //  Placeholder "root" binding for the initial draft.
-    printMatch(os, qs, op, "root");
-  }
-  os << matchCount << (matchCount == 1 ? " match.\n\n" : " matches.\n\n");
-
   return mlir::success();
 }
 
diff --git a/mlir/test/mlir-query/complex-test.mlir b/mlir/test/mlir-query/complex-test.mlir
index e491beb5007792b..af6193e4818abfe 100644
--- a/mlir/test/mlir-query/complex-test.mlir
+++ b/mlir/test/mlir-query/complex-test.mlir
@@ -1,23 +1,26 @@
 // RUN: mlir-query %s -c "m definedBy(hasOpName(())" | FileCheck %s
 
+func.func @region_control_flow(%arg: memref<2xf32>, %cond: i1) attributes {test.ptr = "func"} {
+  %0 = memref.alloca() {test.ptr = "alloca_1"} : memref<8x64xf32>
+  %1 = memref.alloca() {test.ptr = "alloca_2"} : memref<8x64xf32>
+  %2 = memref.alloc() {test.ptr = "alloc_1"} : memref<8x64xf32>
 
-func.func @matrix_multiply(%A: memref<4x4xf32>, %B: memref<4x4xf32>, %C: memref<4x4xf32>) {
-  %c0 = arith.constant 0 : index
-  %c4 = arith.constant 4 : index
-  %c1 = arith.constant 1 : index
+  %3 = scf.if %cond -> (memref<8x64xf32>) {
+    scf.yield %0 : memref<8x64xf32>
+  } else {
+    scf.yield %0 : memref<8x64xf32>
+  } {test.ptr = "if_alloca"}
 
-  scf.for %i = %c0 to %c4 step %c1 {
-    scf.for %j = %c0 to %c4 step %c1 {
-      %sum_init = arith.constant 0.0 : f32
-      %sum = scf.for %k = %c0 to %c4 step %c1 iter_args(%acc = %sum_init) -> (f32) {
-        %a_ik = memref.load %A[%i, %k] : memref<4x4xf32>
-        %b_kj = memref.load %B[%k, %j] : memref<4x4xf32>
-        %prod = arith.mulf %a_ik, %b_kj : f32
-        %new_acc = arith.addf %acc, %prod : f32
-        scf.yield %new_acc : f32
-      }
-      memref.store %sum, %C[%i, %j] : memref<4x4xf32>
-    }
-  }
+  %4 = scf.if %cond -> (memref<8x64xf32>) {
+    scf.yield %0 : memref<8x64xf32>
+  } else {
+    scf.yield %1 : memref<8x64xf32>
+  } {test.ptr = "if_alloca_merge"}
+
+  %5 = scf.if %cond -> (memref<8x64xf32>) {
+    scf.yield %2 : memref<8x64xf32>
+  } else {
+    scf.yield %2 : memref<8x64xf32>
+  } {test.ptr = "if_alloc"}
   return
-}
\ No newline at end of file
+}
diff --git a/mlir/test/mlir-query/function-extraction.mlir b/mlir/test/mlir-query/function-extraction.mlir
index 473c30aee7123ae..ecbd17ab59dfb17 100644
--- a/mlir/test/mlir-query/function-extraction.mlir
+++ b/mlir/test/mlir-query/function-extraction.mlir
@@ -4,7 +4,7 @@
 // CHECK:       %[[MUL0:.*]] = arith.mulf {{.*}} : f32
 // CHECK:       %[[MUL1:.*]] = arith.mulf {{.*}}, %[[MUL0]] : f32
 // CHECK:       %[[MUL2:.*]] = arith.mulf {{.*}} : f32
-// CHECK-NEXT:  return %[[MUL0]], %[[MUL1]], %[[MUL2]] : f32, f32, f32
+// CHECK-NEXT:  return %[[MUL0]], %[[MUL1]], %[[MUL2]] : f32, f32, f32
 
 func.func @mixedOperations(%a: f32, %b: f32, %c: f32) -> f32 {
   %sum0 = arith.addf %a, %b : f32

>From 5d0bf4fbc53543091082bfe747f1a492e4f70a11 Mon Sep 17 00:00:00 2001
From: Denzel-Brian Budii <denzel-brian.budii at intel.com>
Date: Sun, 19 Jan 2025 12:42:43 +0000
Subject: [PATCH 4/5] added info in header files & fixed newline

---
 mlir/include/mlir/Query/Matcher/MatchFinder.h    | 10 ++++++++++
 .../mlir/Query/Matcher/MatchersInternal.h        | 16 ++++++++++++++--
 mlir/lib/Query/Query.cpp                         |  2 +-
 3 files changed, 25 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Query/Matcher/MatchFinder.h b/mlir/include/mlir/Query/Matcher/MatchFinder.h
index c188bfe75406a7e..f957cff9d038bc6 100644
--- a/mlir/include/mlir/Query/Matcher/MatchFinder.h
+++ b/mlir/include/mlir/Query/Matcher/MatchFinder.h
@@ -1,4 +1,14 @@
+//===- MatchFinder.h - ------------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
 //===----------------------------------------------------------------------===//
+//
+// This file contains the MatchFinder class, which is used to find operations
+// that match a given matcher and print them.
+//
 // MatchFinder.h
 //===----------------------------------------------------------------------===//
 
diff --git a/mlir/include/mlir/Query/Matcher/MatchersInternal.h b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
index b795d9a291844f8..3f077bc44bd6771 100644
--- a/mlir/include/mlir/Query/Matcher/MatchersInternal.h
+++ b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
@@ -1,11 +1,23 @@
-//===- MatchersInternal.h - Structural query framework ----------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implements the base layer of the matcher framework.
+//
+// Matchers are methods that return a Matcher which provides a method
+// match(Operation *op)
+// match(Operation *op, SetVector<Operation *> &matchedOps, QueryOptions
+// &options)
+//
+// The matcher functions are defined in include/mlir/IR/Matchers.h.
+// This file contains the wrapper classes needed to construct matchers for
+// mlir-query.
 // SPDX-License-Identifier: Apache-2.0 WItH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
-
 #ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
 #define MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
 
diff --git a/mlir/lib/Query/Query.cpp b/mlir/lib/Query/Query.cpp
index cbc436299afb701..960fe4a3775455e 100644
--- a/mlir/lib/Query/Query.cpp
+++ b/mlir/lib/Query/Query.cpp
@@ -156,4 +156,4 @@ LogicalResult MatchQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
   return mlir::success();
 }
 
-} // namespace mlir::query
\ No newline at end of file
+} // namespace mlir::query

>From bafbd37932158e744350846f06c94287f8d3aa4c Mon Sep 17 00:00:00 2001
From: Denzel-Brian Budii <denzel-brian.budii at intel.com>
Date: Sun, 19 Jan 2025 13:09:35 +0000
Subject: [PATCH 5/5] removed redundant code

---
 mlir/include/mlir/Query/Matcher/MatchFinder.h | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Query/Matcher/MatchFinder.h b/mlir/include/mlir/Query/Matcher/MatchFinder.h
index f957cff9d038bc6..6d8214bd42ca69d 100644
--- a/mlir/include/mlir/Query/Matcher/MatchFinder.h
+++ b/mlir/include/mlir/Query/Matcher/MatchFinder.h
@@ -50,12 +50,12 @@ class MatchFinder {
         printMatch(os, qs, subOp, "root");
       } else {
         SmallVector<Operation *> printingOps;
-        size_t sizeBefore = matchedOps.size();
+
         if (matcher.match(subOp, tempStorage, options)) {
           os << "Match #" << ++matchCount << ":\n\n";
           SmallVector<Operation *> printingOps(tempStorage.takeVector());
           for (auto op : printingOps) {
-            printMatch(os, qs, op, ""); // Using version without binding text
+            printMatch(os, qs, op, "");
             matchedOps.insert(op);
           }
           printingOps.clear();



More information about the Mlir-commits mailing list