[Mlir-commits] [mlir] [mlir] MLIR-QUERY slice-matchers implementation (PR #115670)

Denzel Budii llvmlistbot at llvm.org
Sat Mar 29 08:20:09 PDT 2025


https://github.com/dbudii updated https://github.com/llvm/llvm-project/pull/115670

>From 4f34c45b5d13918f7f757b9084ce724898649406 Mon Sep 17 00:00:00 2001
From: Denzel-Brian Budii <denzel-brian.budii at intel.com>
Date: Sat, 25 Jan 2025 13:38:31 +0000
Subject: [PATCH 1/5] Fixed pattern matching in mlir-query test files & removed
 asserts from slice-matchers

---
 mlir/include/mlir/IR/Matchers.h               |   4 +-
 .../mlir/Query/Matcher/ExtraMatchers.h        | 188 ++++++++++++++++++
 mlir/include/mlir/Query/Matcher/Marshallers.h |  15 ++
 mlir/include/mlir/Query/Matcher/MatchFinder.h |  52 +++--
 .../mlir/Query/Matcher/MatchersInternal.h     |  60 +++++-
 .../include/mlir/Query/Matcher/VariantValue.h |  12 +-
 mlir/include/mlir/Query/Query.h               |  34 +++-
 mlir/include/mlir/Query/QuerySession.h        |  11 +-
 mlir/lib/Query/Matcher/Parser.cpp             |  36 ++++
 mlir/lib/Query/Matcher/RegistryManager.cpp    |   2 +
 mlir/lib/Query/Matcher/VariantValue.cpp       |  24 +++
 mlir/lib/Query/Query.cpp                      |  37 ++--
 mlir/lib/Query/QueryParser.cpp                |  52 ++++-
 mlir/lib/Query/QueryParser.h                  |   2 +-
 mlir/test/mlir-query/complex-test.mlir        |  39 ++++
 mlir/test/mlir-query/function-extraction.mlir |   2 +-
 mlir/tools/mlir-query/mlir-query.cpp          |  10 +
 17 files changed, 529 insertions(+), 51 deletions(-)
 create mode 100644 mlir/include/mlir/Query/Matcher/ExtraMatchers.h
 create mode 100644 mlir/test/mlir-query/complex-test.mlir

diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index 1dce055db1b4a..2204a68be26b1 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -59,7 +59,7 @@ struct NameOpMatcher {
   NameOpMatcher(StringRef name) : name(name) {}
   bool match(Operation *op) { return op->getName().getStringRef() == name; }
 
-  StringRef name;
+  std::string name;
 };
 
 /// The matcher that matches operations that have the specified attribute name.
@@ -67,7 +67,7 @@ struct AttrOpMatcher {
   AttrOpMatcher(StringRef attrName) : attrName(attrName) {}
   bool match(Operation *op) { return op->hasAttr(attrName); }
 
-  StringRef attrName;
+  std::string attrName;
 };
 
 /// The matcher that matches operations that have the `ConstantLike` trait, and
diff --git a/mlir/include/mlir/Query/Matcher/ExtraMatchers.h b/mlir/include/mlir/Query/Matcher/ExtraMatchers.h
new file mode 100644
index 0000000000000..908fccfc704c3
--- /dev/null
+++ b/mlir/include/mlir/Query/Matcher/ExtraMatchers.h
@@ -0,0 +1,188 @@
+//===- ExtraMatchers.h - Various common matchers ---------------------*- C++
+//-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file provides extra matchers that are very useful for mlir-query
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_EXTRAMATCHERS_H
+#define MLIR_IR_EXTRAMATCHERS_H
+
+#include "MatchFinder.h"
+#include "MatchersInternal.h"
+#include "mlir/IR/Region.h"
+#include "mlir/Query/Query.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace mlir {
+
+namespace query {
+
+namespace extramatcher {
+
+namespace detail {
+
+class BackwardSliceMatcher {
+public:
+  BackwardSliceMatcher(matcher::DynMatcher &&innerMatcher, unsigned hops)
+      : innerMatcher(std::move(innerMatcher)), hops(hops) {}
+
+private:
+  bool matches(Operation *op, SetVector<Operation *> &backwardSlice,
+               QueryOptions &options, unsigned tempHops) {
+
+    if (op->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
+      return false;
+    }
+
+    auto processValue = [&](Value value) {
+      if (tempHops == 0) {
+        return;
+      }
+      if (auto *definingOp = value.getDefiningOp()) {
+        if (backwardSlice.count(definingOp) == 0)
+          matches(definingOp, backwardSlice, options, tempHops - 1);
+      } else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
+        if (options.omitBlockArguments)
+          return;
+        Block *block = blockArg.getOwner();
+
+        Operation *parentOp = block->getParentOp();
+
+        if (parentOp && backwardSlice.count(parentOp) == 0) {
+          if (parentOp->getNumRegions() != 1 &&
+              parentOp->getRegion(0).getBlocks().size() != 1) {
+            llvm::errs()
+                << "Error: Expected parentOp to have exactly one region and "
+                << "exactly one block, but found " << parentOp->getNumRegions()
+                << " regions and "
+                << (parentOp->getRegion(0).getBlocks().size()) << " blocks.\n";
+          };
+          matches(parentOp, backwardSlice, options, tempHops - 1);
+        }
+      } else {
+        llvm::errs() << "No definingOp and not a block argument\n";
+        return;
+      }
+    };
+
+    if (!options.omitUsesFromAbove) {
+      llvm::for_each(op->getRegions(), [&](Region &region) {
+        SmallPtrSet<Region *, 4> descendents;
+        region.walk(
+            [&](Region *childRegion) { descendents.insert(childRegion); });
+        region.walk([&](Operation *op) {
+          for (OpOperand &operand : op->getOpOperands()) {
+            if (!descendents.contains(operand.get().getParentRegion()))
+              processValue(operand.get());
+          }
+        });
+      });
+    }
+
+    llvm::for_each(op->getOperands(), processValue);
+    backwardSlice.insert(op);
+    return true;
+  }
+
+public:
+  bool match(Operation *op, SetVector<Operation *> &backwardSlice,
+             QueryOptions &options) {
+
+    if (innerMatcher.match(op) && matches(op, backwardSlice, options, hops)) {
+      if (!options.inclusive) {
+        backwardSlice.remove(op);
+      }
+      return true;
+    }
+    return false;
+  }
+
+private:
+  matcher::DynMatcher innerMatcher;
+  unsigned hops;
+};
+
+class ForwardSliceMatcher {
+public:
+  ForwardSliceMatcher(matcher::DynMatcher &&innerMatcher, unsigned hops)
+      : innerMatcher(std::move(innerMatcher)), hops(hops) {}
+
+private:
+  bool matches(Operation *op, SetVector<Operation *> &forwardSlice,
+               QueryOptions &options, unsigned tempHops) {
+
+    if (tempHops == 0) {
+      forwardSlice.insert(op);
+      return true;
+    }
+
+    for (Region &region : op->getRegions())
+      for (Block &block : region)
+        for (Operation &blockOp : block)
+          if (forwardSlice.count(&blockOp) == 0)
+            matches(&blockOp, forwardSlice, options, tempHops - 1);
+    for (Value result : op->getResults()) {
+      for (Operation *userOp : result.getUsers())
+        if (forwardSlice.count(userOp) == 0)
+          matches(userOp, forwardSlice, options, tempHops - 1);
+    }
+
+    forwardSlice.insert(op);
+    return true;
+  }
+
+public:
+  bool match(Operation *op, SetVector<Operation *> &forwardSlice,
+             QueryOptions &options) {
+    if (innerMatcher.match(op) && matches(op, forwardSlice, options, hops)) {
+      if (!options.inclusive) {
+        forwardSlice.remove(op);
+      }
+      SmallVector<Operation *, 0> v(forwardSlice.takeVector());
+      forwardSlice.insert(v.rbegin(), v.rend());
+      return true;
+    }
+    return false;
+  }
+
+private:
+  matcher::DynMatcher innerMatcher;
+  unsigned hops;
+};
+
+} // namespace detail
+
+inline detail::BackwardSliceMatcher
+definedBy(mlir::query::matcher::DynMatcher innerMatcher) {
+  return detail::BackwardSliceMatcher(std::move(innerMatcher), 1);
+}
+
+inline detail::BackwardSliceMatcher
+getDefinitions(mlir::query::matcher::DynMatcher innerMatcher, unsigned hops) {
+  return detail::BackwardSliceMatcher(std::move(innerMatcher), hops);
+}
+
+inline detail::ForwardSliceMatcher
+usedBy(mlir::query::matcher::DynMatcher innerMatcher) {
+  return detail::ForwardSliceMatcher(std::move(innerMatcher), 1);
+}
+
+inline detail::ForwardSliceMatcher
+getUses(mlir::query::matcher::DynMatcher innerMatcher, unsigned hops) {
+  return detail::ForwardSliceMatcher(std::move(innerMatcher), hops);
+}
+
+} // namespace extramatcher
+
+} // namespace query
+
+} // namespace mlir
+
+#endif // MLIR_IR_EXTRAMATCHERS_H
diff --git a/mlir/include/mlir/Query/Matcher/Marshallers.h b/mlir/include/mlir/Query/Matcher/Marshallers.h
index 6ed35ac0ddccc..c775dbc5c86da 100644
--- a/mlir/include/mlir/Query/Matcher/Marshallers.h
+++ b/mlir/include/mlir/Query/Matcher/Marshallers.h
@@ -50,6 +50,21 @@ struct ArgTypeTraits<llvm::StringRef> {
   }
 };
 
+template <>
+struct ArgTypeTraits<unsigned> {
+  static bool hasCorrectType(const VariantValue &value) {
+    return value.isUnsigned();
+  }
+
+  static unsigned get(const VariantValue &value) { return value.getUnsigned(); }
+
+  static ArgKind getKind() { return ArgKind::Unsigned; }
+
+  static std::optional<std::string> getBestGuess(const VariantValue &) {
+    return std::nullopt;
+  }
+};
+
 template <>
 struct ArgTypeTraits<DynMatcher> {
 
diff --git a/mlir/include/mlir/Query/Matcher/MatchFinder.h b/mlir/include/mlir/Query/Matcher/MatchFinder.h
index b008a21f53ae2..1b9d3bc307ff5 100644
--- a/mlir/include/mlir/Query/Matcher/MatchFinder.h
+++ b/mlir/include/mlir/Query/Matcher/MatchFinder.h
@@ -7,7 +7,7 @@
 //===----------------------------------------------------------------------===//
 //
 // This file contains the MatchFinder class, which is used to find operations
-// that match a given matcher.
+// that match a given matcher and print them.
 //
 //===----------------------------------------------------------------------===//
 
@@ -15,24 +15,52 @@
 #define MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H
 
 #include "MatchersInternal.h"
+#include "mlir/Query/QuerySession.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/raw_ostream.h"
 
 namespace mlir::query::matcher {
 
-// MatchFinder is used to find all operations that match a given matcher.
 class MatchFinder {
-public:
-  // Returns all operations that match the given matcher.
-  static std::vector<Operation *> getMatches(Operation *root,
-                                             DynMatcher matcher) {
-    std::vector<Operation *> matches;
+private:
+  static void printMatch(llvm::raw_ostream &os, QuerySession &qs,
+                         mlir::Operation *op, const std::string &binding) {
+    auto fileLoc = op->getLoc()->findInstanceOf<FileLineColLoc>();
+    auto smloc = qs.getSourceManager().FindLocForLineAndColumn(
+        qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn());
+    qs.getSourceManager().PrintMessage(os, smloc, llvm::SourceMgr::DK_Note,
+                                       "\"" + binding + "\" binds here");
+  };
 
-    // Simple match finding with walk.
+public:
+  static std::vector<Operation *>
+  getMatches(Operation *root, QueryOptions &options, DynMatcher matcher,
+             llvm::raw_ostream &os, QuerySession &qs) {
+    unsigned matchCount = 0;
+    std::vector<Operation *> matchedOps;
+    SetVector<Operation *> tempStorage;
+    os << "\n";
     root->walk([&](Operation *subOp) {
-      if (matcher.match(subOp))
-        matches.push_back(subOp);
+      if (matcher.match(subOp)) {
+        matchedOps.push_back(subOp);
+        os << "Match #" << ++matchCount << ":\n\n";
+        printMatch(os, qs, subOp, "root");
+      } else {
+        SmallVector<Operation *> printingOps;
+        if (matcher.match(subOp, tempStorage, options)) {
+          os << "Match #" << ++matchCount << ":\n\n";
+          SmallVector<Operation *> printingOps(tempStorage.takeVector());
+          for (auto op : printingOps) {
+            printMatch(os, qs, op, "root");
+            matchedOps.push_back(op);
+          }
+          printingOps.clear();
+        }
+      }
     });
-
-    return matches;
+    os << matchCount << (matchCount == 1 ? " match.\n\n" : " matches.\n\n");
+    return matchedOps;
   }
 };
 
diff --git a/mlir/include/mlir/Query/Matcher/MatchersInternal.h b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
index 117f7d4edef9e..b532b47be7d05 100644
--- a/mlir/include/mlir/Query/Matcher/MatchersInternal.h
+++ b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
@@ -1,4 +1,3 @@
-//===- MatchersInternal.h - Structural query framework ----------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -10,28 +9,53 @@
 //
 // Matchers are methods that return a Matcher which provides a method
 // match(Operation *op)
+// match(Operation *op, SetVector<Operation *> &matchedOps, QueryOptions
+// &options)
 //
 // The matcher functions are defined in include/mlir/IR/Matchers.h.
 // This file contains the wrapper classes needed to construct matchers for
 // mlir-query.
 //
 //===----------------------------------------------------------------------===//
-
 #ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
 #define MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
 
 #include "mlir/IR/Matchers.h"
 #include "llvm/ADT/IntrusiveRefCntPtr.h"
 
+namespace mlir {
+namespace query {
+struct QueryOptions;
+}
+} // namespace mlir
+
 namespace mlir::query::matcher {
+template <typename T, typename = void>
+struct has_simple_match : std::false_type {};
+
+template <typename T>
+struct has_simple_match<T, std::void_t<decltype(std::declval<T>().match(
+                               std::declval<Operation *>()))>>
+    : std::true_type {};
+
+template <typename T, typename = void>
+struct has_bound_match : std::false_type {};
+
+template <typename T>
+struct has_bound_match<T, std::void_t<decltype(std::declval<T>().match(
+                              std::declval<Operation *>(),
+                              std::declval<SetVector<Operation *> &>(),
+                              std::declval<QueryOptions &>()))>>
+    : std::true_type {};
 
 // Generic interface for matchers on an MLIR operation.
 class MatcherInterface
     : public llvm::ThreadSafeRefCountedBase<MatcherInterface> {
 public:
   virtual ~MatcherInterface() = default;
-
   virtual bool match(Operation *op) = 0;
+  virtual bool match(Operation *op, SetVector<Operation *> &matchedOps,
+                     QueryOptions &options) = 0;
 };
 
 // MatcherFnImpl takes a matcher function object and implements
@@ -40,14 +64,26 @@ template <typename MatcherFn>
 class MatcherFnImpl : public MatcherInterface {
 public:
   MatcherFnImpl(MatcherFn &matcherFn) : matcherFn(matcherFn) {}
-  bool match(Operation *op) override { return matcherFn.match(op); }
+
+  bool match(Operation *op) override {
+    if constexpr (has_simple_match<MatcherFn>::value)
+      return matcherFn.match(op);
+    return false;
+  }
+
+  bool match(Operation *op, SetVector<Operation *> &matchedOps,
+             QueryOptions &options) override {
+    if constexpr (has_bound_match<MatcherFn>::value)
+      return matcherFn.match(op, matchedOps, options);
+    return false;
+  }
 
 private:
   MatcherFn matcherFn;
 };
 
-// Matcher wraps a MatcherInterface implementation and provides a match()
-// method that redirects calls to the underlying implementation.
+// Matcher wraps a MatcherInterface implementation and provides match()
+// methods that redirect calls to the underlying implementation.
 class DynMatcher {
 public:
   // Takes ownership of the provided implementation pointer.
@@ -62,12 +98,14 @@ class DynMatcher {
   }
 
   bool match(Operation *op) const { return implementation->match(op); }
+  bool match(Operation *op, SetVector<Operation *> &matchedOps,
+             QueryOptions &options) const {
+    return implementation->match(op, matchedOps, options);
+  }
 
-  void setFunctionName(StringRef name) { functionName = name.str(); };
-
-  bool hasFunctionName() const { return !functionName.empty(); };
-
-  StringRef getFunctionName() const { return functionName; };
+  void setFunctionName(StringRef name) { functionName = name.str(); }
+  bool hasFunctionName() const { return !functionName.empty(); }
+  StringRef getFunctionName() const { return functionName; }
 
 private:
   llvm::IntrusiveRefCntPtr<MatcherInterface> implementation;
diff --git a/mlir/include/mlir/Query/Matcher/VariantValue.h b/mlir/include/mlir/Query/Matcher/VariantValue.h
index 449f8b3a01e02..6b57119df7a9b 100644
--- a/mlir/include/mlir/Query/Matcher/VariantValue.h
+++ b/mlir/include/mlir/Query/Matcher/VariantValue.h
@@ -21,7 +21,7 @@
 namespace mlir::query::matcher {
 
 // All types that VariantValue can contain.
-enum class ArgKind { Matcher, String };
+enum class ArgKind { Matcher, String, Unsigned };
 
 // A variant matcher object to abstract simple and complex matchers into a
 // single object type.
@@ -81,6 +81,7 @@ class VariantValue {
   // Specific constructors for each supported type.
   VariantValue(const llvm::StringRef string);
   VariantValue(const VariantMatcher &matcher);
+  VariantValue(unsigned Unsigned);
 
   // String value functions.
   bool isString() const;
@@ -92,8 +93,15 @@ class VariantValue {
   const VariantMatcher &getMatcher() const;
   void setMatcher(const VariantMatcher &matcher);
 
+  // Unsigned value functions.
+  bool isUnsigned() const;
+  unsigned getUnsigned() const;
+  void setUnsigned(unsigned Unsigned);
+
   // String representation of the type of the value.
   std::string getTypeAsString() const;
+  explicit operator bool() const { return hasValue(); }
+  bool hasValue() const { return type != ValueType::Nothing; }
 
 private:
   void reset();
@@ -103,12 +111,14 @@ class VariantValue {
     Nothing,
     String,
     Matcher,
+    Unsigned,
   };
 
   // All supported value types.
   union AllValues {
     llvm::StringRef *String;
     VariantMatcher *Matcher;
+    unsigned Unsigned;
   };
 
   ValueType type;
diff --git a/mlir/include/mlir/Query/Query.h b/mlir/include/mlir/Query/Query.h
index 18f2172c9510a..bb5b98432d51c 100644
--- a/mlir/include/mlir/Query/Query.h
+++ b/mlir/include/mlir/Query/Query.h
@@ -17,7 +17,13 @@
 
 namespace mlir::query {
 
-enum class QueryKind { Invalid, NoOp, Help, Match, Quit };
+struct QueryOptions {
+  bool omitBlockArguments = false;
+  bool omitUsesFromAbove = true;
+  bool inclusive = true;
+};
+
+enum class QueryKind { Invalid, NoOp, Help, Match, Quit, SetBool };
 
 class QuerySession;
 
@@ -103,6 +109,32 @@ struct MatchQuery : Query {
   }
 };
 
+template <typename T>
+struct SetQueryKind {};
+
+template <>
+struct SetQueryKind<bool> {
+  static const QueryKind value = QueryKind::SetBool;
+};
+template <typename T>
+struct SetQuery : Query {
+  SetQuery(T QuerySession::*var, T value)
+      : Query(SetQueryKind<T>::value), var(var), value(value) {}
+
+  llvm::LogicalResult run(llvm::raw_ostream &os,
+                          QuerySession &qs) const override {
+    qs.*var = value;
+    return mlir::success();
+  }
+
+  static bool classof(const Query *query) {
+    return query->kind == SetQueryKind<T>::value;
+  }
+
+  T QuerySession::*var;
+  T value;
+};
+
 } // namespace mlir::query
 
 #endif
diff --git a/mlir/include/mlir/Query/QuerySession.h b/mlir/include/mlir/Query/QuerySession.h
index fe552d750fc77..495358e8f36f9 100644
--- a/mlir/include/mlir/Query/QuerySession.h
+++ b/mlir/include/mlir/Query/QuerySession.h
@@ -9,14 +9,18 @@
 #ifndef MLIR_TOOLS_MLIRQUERY_QUERYSESSION_H
 #define MLIR_TOOLS_MLIRQUERY_QUERYSESSION_H
 
+#include "Matcher/VariantValue.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/Query/Matcher/Registry.h"
 #include "llvm/ADT/StringMap.h"
 #include "llvm/Support/SourceMgr.h"
 
+namespace mlir::query::matcher {
+class Registry;
+}
+
 namespace mlir::query {
 
-class Registry;
 // Represents the state for a particular mlir-query session.
 class QuerySession {
 public:
@@ -33,6 +37,11 @@ class QuerySession {
   llvm::StringMap<matcher::VariantValue> namedValues;
   bool terminate = false;
 
+public:
+  bool omitBlockArguments = false;
+  bool omitUsesFromAbove = true;
+  bool inclusive = true;
+
 private:
   Operation *rootOp;
   llvm::SourceMgr &sourceMgr;
diff --git a/mlir/lib/Query/Matcher/Parser.cpp b/mlir/lib/Query/Matcher/Parser.cpp
index 3609e24f9939f..4dcb86a9383f3 100644
--- a/mlir/lib/Query/Matcher/Parser.cpp
+++ b/mlir/lib/Query/Matcher/Parser.cpp
@@ -135,6 +135,18 @@ class Parser::CodeTokenizer {
     case '\'':
       consumeStringLiteral(&result);
       break;
+    case '0':
+    case '1':
+    case '2':
+    case '3':
+    case '4':
+    case '5':
+    case '6':
+    case '7':
+    case '8':
+    case '9':
+      consumeNumberLiteral(&result);
+      break;
     default:
       parseIdentifierOrInvalid(&result);
       break;
@@ -144,6 +156,30 @@ class Parser::CodeTokenizer {
     return result;
   }
 
+  void consumeNumberLiteral(TokenInfo *result) {
+    unsigned length = 1;
+    if (code.size() > 1) {
+      // Consume the 'x' or 'b' radix modifier, if present.
+      switch (tolower(code[1])) {
+      case 'x':
+      case 'b':
+        length = 2;
+      }
+    }
+    while (length < code.size() && isdigit(code[length]))
+      ++length;
+
+    result->text = code.take_front(length);
+    code = code.drop_front(length);
+
+    unsigned value;
+    if (!result->text.getAsInteger(0, value)) {
+      result->kind = TokenKind::Literal;
+      result->value = static_cast<unsigned>(value);
+      return;
+    }
+  }
+
   // Consume a string literal, handle escape sequences and missing closing
   // quote.
   void consumeStringLiteral(TokenInfo *result) {
diff --git a/mlir/lib/Query/Matcher/RegistryManager.cpp b/mlir/lib/Query/Matcher/RegistryManager.cpp
index 645db7109c2de..8d6c0135aa117 100644
--- a/mlir/lib/Query/Matcher/RegistryManager.cpp
+++ b/mlir/lib/Query/Matcher/RegistryManager.cpp
@@ -31,6 +31,8 @@ static std::string asArgString(ArgKind kind) {
     return "Matcher";
   case ArgKind::String:
     return "String";
+  case ArgKind::Unsigned:
+    return "unsigned";
   }
   llvm_unreachable("Unhandled ArgKind");
 }
diff --git a/mlir/lib/Query/Matcher/VariantValue.cpp b/mlir/lib/Query/Matcher/VariantValue.cpp
index 65bd4bd77bcf8..d5218d8dad8c9 100644
--- a/mlir/lib/Query/Matcher/VariantValue.cpp
+++ b/mlir/lib/Query/Matcher/VariantValue.cpp
@@ -56,6 +56,10 @@ VariantValue::VariantValue(const VariantMatcher &matcher)
   value.Matcher = new VariantMatcher(matcher);
 }
 
+VariantValue::VariantValue(unsigned Unsigned) : type(ValueType::Unsigned) {
+  value.Unsigned = Unsigned;
+}
+
 VariantValue::~VariantValue() { reset(); }
 
 VariantValue &VariantValue::operator=(const VariantValue &other) {
@@ -69,6 +73,9 @@ VariantValue &VariantValue::operator=(const VariantValue &other) {
   case ValueType::Matcher:
     setMatcher(other.getMatcher());
     break;
+  case ValueType::Unsigned:
+    setUnsigned(other.getUnsigned());
+    break;
   case ValueType::Nothing:
     type = ValueType::Nothing;
     break;
@@ -85,12 +92,27 @@ void VariantValue::reset() {
     delete value.Matcher;
     break;
   // Cases that do nothing.
+  case ValueType::Unsigned:
   case ValueType::Nothing:
     break;
   }
   type = ValueType::Nothing;
 }
 
+// Unsinged
+bool VariantValue::isUnsigned() const { return type == ValueType::Unsigned; }
+
+unsigned VariantValue::getUnsigned() const {
+  assert(isUnsigned());
+  return value.Unsigned;
+}
+
+void VariantValue::setUnsigned(unsigned newValue) {
+  reset();
+  type = ValueType::Unsigned;
+  value.Unsigned = newValue;
+}
+
 bool VariantValue::isString() const { return type == ValueType::String; }
 
 const llvm::StringRef &VariantValue::getString() const {
@@ -123,6 +145,8 @@ std::string VariantValue::getTypeAsString() const {
     return "String";
   case ValueType::Matcher:
     return "Matcher";
+  case ValueType::Unsigned:
+    return "Unsigned";
   case ValueType::Nothing:
     return "Nothing";
   }
diff --git a/mlir/lib/Query/Query.cpp b/mlir/lib/Query/Query.cpp
index 7d9f360670051..dd699857568d7 100644
--- a/mlir/lib/Query/Query.cpp
+++ b/mlir/lib/Query/Query.cpp
@@ -12,6 +12,7 @@
 #include "mlir/IR/IRMapping.h"
 #include "mlir/Query/Matcher/MatchFinder.h"
 #include "mlir/Query/QuerySession.h"
+#include "llvm/ADT/SetVector.h"
 #include "llvm/Support/SourceMgr.h"
 #include "llvm/Support/raw_ostream.h"
 
@@ -26,15 +27,6 @@ complete(llvm::StringRef line, size_t pos, const QuerySession &qs) {
   return QueryParser::complete(line, pos, qs);
 }
 
-static void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op,
-                       const std::string &binding) {
-  auto fileLoc = op->getLoc()->findInstanceOf<FileLineColLoc>();
-  auto smloc = qs.getSourceManager().FindLocForLineAndColumn(
-      qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn());
-  qs.getSourceManager().PrintMessage(os, smloc, llvm::SourceMgr::DK_Note,
-                                     "\"" + binding + "\" binds here");
-}
-
 // TODO: Extract into a helper function that can be reused outside query
 // context.
 static Operation *extractFunction(std::vector<Operation *> &ops,
@@ -99,6 +91,12 @@ static Operation *extractFunction(std::vector<Operation *> &ops,
   return funcOp;
 }
 
+static void parseQueryOptions(QuerySession &qs, QueryOptions &options) {
+  options.omitBlockArguments = qs.omitBlockArguments;
+  options.omitUsesFromAbove = qs.omitUsesFromAbove;
+  options.inclusive = qs.inclusive;
+}
+
 Query::~Query() = default;
 
 LogicalResult InvalidQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
@@ -114,6 +112,11 @@ LogicalResult HelpQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
   os << "Available commands:\n\n"
         "  match MATCHER, m MATCHER      "
         "Match the mlir against the given matcher.\n"
+        "Set query options, useful for complex matchers \n"
+        "   set omitBlockArguments (true|false) \n"
+        "   set omitUsesFromAbove (true|false) \n"
+        "   set inclusive (true|false) \n"
+        "Give a matcher expression a name, to be used later\n"
         "  quit                              "
         "Terminates the query session.\n\n";
   return mlir::success();
@@ -126,9 +129,11 @@ LogicalResult QuitQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
 
 LogicalResult MatchQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
   Operation *rootOp = qs.getRootOp();
-  int matchCount = 0;
-  std::vector<Operation *> matches =
-      matcher::MatchFinder().getMatches(rootOp, matcher);
+
+  QueryOptions options;
+  parseQueryOptions(qs, options);
+  auto matches = matcher::MatchFinder().getMatches(rootOp, options,
+                                                   std::move(matcher), os, qs);
 
   // An extract call is recognized by considering if the matcher has a name.
   // TODO: Consider making the extract more explicit.
@@ -141,14 +146,6 @@ LogicalResult MatchQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
     return mlir::success();
   }
 
-  os << "\n";
-  for (Operation *op : matches) {
-    os << "Match #" << ++matchCount << ":\n\n";
-    // Placeholder "root" binding for the initial draft.
-    printMatch(os, qs, op, "root");
-  }
-  os << matchCount << (matchCount == 1 ? " match.\n\n" : " matches.\n\n");
-
   return mlir::success();
 }
 
diff --git a/mlir/lib/Query/QueryParser.cpp b/mlir/lib/Query/QueryParser.cpp
index 31aead7d403d0..7aaf4847f2e47 100644
--- a/mlir/lib/Query/QueryParser.cpp
+++ b/mlir/lib/Query/QueryParser.cpp
@@ -107,6 +107,18 @@ QueryRef QueryParser::endQuery(QueryRef queryRef) {
   return queryRef;
 }
 
+QueryRef QueryParser::parseSetBool(bool QuerySession::*var) {
+  StringRef valStr;
+  unsigned value = LexOrCompleteWord<unsigned>(this, valStr)
+                       .Case("false", 0)
+                       .Case("true", 1)
+                       .Default(~0u);
+  if (value == ~0u) {
+    return new InvalidQuery("expected 'true' or 'false', got '" + valStr + "'");
+  }
+  return new SetQuery<bool>(var, value);
+}
+
 namespace {
 
 enum class ParsedQueryKind {
@@ -116,6 +128,14 @@ enum class ParsedQueryKind {
   Help,
   Match,
   Quit,
+  Set,
+};
+
+enum ParsedQueryVariable {
+  Invalid,
+  OmitBlockArguments,
+  OmitUsesFromAbove,
+  Inclusive,
 };
 
 QueryRef
@@ -147,6 +167,7 @@ QueryRef QueryParser::doParse() {
           .Case("#", ParsedQueryKind::Comment, /*isCompletion=*/false)
           .Case("help", ParsedQueryKind::Help)
           .Case("m", ParsedQueryKind::Match, /*isCompletion=*/false)
+          .Case("set", ParsedQueryKind::Set)
           .Case("match", ParsedQueryKind::Match)
           .Case("q", ParsedQueryKind::Quit, /*IsCompletion=*/false)
           .Case("quit", ParsedQueryKind::Quit)
@@ -187,7 +208,36 @@ QueryRef QueryParser::doParse() {
     query->remainingContent = matcherSource;
     return query;
   }
-
+  case ParsedQueryKind::Set: {
+    llvm::StringRef varStr;
+    ParsedQueryVariable var =
+        LexOrCompleteWord<ParsedQueryVariable>(this, varStr)
+            .Case("omitBlockArguments", ParsedQueryVariable::OmitBlockArguments)
+            .Case("omitUsesFromAbove", ParsedQueryVariable::OmitUsesFromAbove)
+            .Case("inclusive", ParsedQueryVariable::Inclusive)
+            .Default(ParsedQueryVariable::Invalid);
+    if (varStr.empty()) {
+      return new InvalidQuery("expected variable name");
+    }
+    if (var == ParsedQueryVariable::Invalid) {
+      return new InvalidQuery("unknown variable: '" + varStr + "'");
+    }
+    QueryRef query;
+    switch (var) {
+    case ParsedQueryVariable::OmitBlockArguments:
+      query = parseSetBool(&QuerySession::omitBlockArguments);
+      break;
+    case ParsedQueryVariable::OmitUsesFromAbove:
+      query = parseSetBool(&QuerySession::omitUsesFromAbove);
+      break;
+    case ParsedQueryVariable::Inclusive:
+      query = parseSetBool(&QuerySession::inclusive);
+      break;
+    case ParsedQueryVariable::Invalid:
+      llvm_unreachable("Invalid query kind");
+    }
+    return endQuery(query);
+  }
   case ParsedQueryKind::Invalid:
     return new InvalidQuery("unknown command: " + commandStr);
   }
diff --git a/mlir/lib/Query/QueryParser.h b/mlir/lib/Query/QueryParser.h
index e9c30eccecab9..69cc5d0043d57 100644
--- a/mlir/lib/Query/QueryParser.h
+++ b/mlir/lib/Query/QueryParser.h
@@ -39,8 +39,8 @@ class QueryParser {
   struct LexOrCompleteWord;
 
   QueryRef completeMatcherExpression();
-
   QueryRef endQuery(QueryRef queryRef);
+  QueryRef parseSetBool(bool QuerySession::*var);
 
   // Parse [begin, end) and returns a reference to the parsed query object,
   // which may be an InvalidQuery if a parse error occurs.
diff --git a/mlir/test/mlir-query/complex-test.mlir b/mlir/test/mlir-query/complex-test.mlir
new file mode 100644
index 0000000000000..b3df534ee8871
--- /dev/null
+++ b/mlir/test/mlir-query/complex-test.mlir
@@ -0,0 +1,39 @@
+// RUN: mlir-query %s -c "match getDefinitions(hasOpName("arith.addf"),2)" | FileCheck %s
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @slice_use_from_above(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>) {
+  %0 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %2 = arith.addf %in, %in : f32
+    linalg.yield %2 : f32
+  } -> tensor<5x5xf32>
+  %collapsed = tensor.collapse_shape %0 [[0, 1]] : tensor<5x5xf32> into tensor<25xf32>
+  %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %c2 = arith.constant 2 : index
+    %extracted = tensor.extract %collapsed[%c2] : tensor<25xf32>
+    %2 = arith.addf %extracted, %extracted : f32
+    linalg.yield %2 : f32
+  } -> tensor<5x5xf32>
+  return
+}
+
+// CHECK: Match #1:
+
+// CHECK: %[[LINALG:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} 
+// CHECK-SAME: ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>)
+
+// CHECK: %[[ADDF1:.*]] = arith.addf %in, %in : f32
+
+
+// CHECK: Match #2:
+
+// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %0 {{\[\[.*\]\]}} : tensor<5x5xf32> into tensor<25xf32>
+// CHECK: %[[EXTRACTED:.*]] = tensor.extract %[[COLLAPSED]][%c2] : tensor<25xf32>
+// CHECK: %[[ADDF2:.*]] = arith.addf %[[EXTRACTED]], %[[EXTRACTED]] : f32
+
+
+
+
+
+
diff --git a/mlir/test/mlir-query/function-extraction.mlir b/mlir/test/mlir-query/function-extraction.mlir
index a783f65c6761b..d7a867eb1a452 100644
--- a/mlir/test/mlir-query/function-extraction.mlir
+++ b/mlir/test/mlir-query/function-extraction.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-query %s -c "m hasOpName(\"arith.mulf\").extract(\"testmul\")" | FileCheck %s
+// RUN: mlir-query %s -c "m hasOpName("arith.mulf").extract("testmul")" | FileCheck %s
 
 // CHECK: func.func @testmul({{.*}}) -> (f32, f32, f32) {
 // CHECK:       %[[MUL0:.*]] = arith.mulf {{.*}} : f32
diff --git a/mlir/tools/mlir-query/mlir-query.cpp b/mlir/tools/mlir-query/mlir-query.cpp
index 0ed4f94d5802b..5e74da7ee7bdc 100644
--- a/mlir/tools/mlir-query/mlir-query.cpp
+++ b/mlir/tools/mlir-query/mlir-query.cpp
@@ -15,6 +15,8 @@
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/InitAllDialects.h"
+#include "mlir/Query/Matcher/ExtraMatchers.h"
+#include "mlir/Query/Matcher/MatchFinder.h"
 #include "mlir/Query/Matcher/Registry.h"
 #include "mlir/Tools/mlir-query/MlirQueryMain.h"
 
@@ -39,6 +41,14 @@ int main(int argc, char **argv) {
   query::matcher::Registry matcherRegistry;
 
   // Matchers registered in alphabetical order for consistency:
+  matcherRegistry.registerMatcher("getDefinitions",
+                                  mlir::query::extramatcher::getDefinitions);
+  matcherRegistry.registerMatcher("definedBy",
+                                  mlir::query::extramatcher::definedBy);
+  matcherRegistry.registerMatcher("usedBy", mlir::query::extramatcher::usedBy);
+  matcherRegistry.registerMatcher("getUses",
+                                  mlir::query::extramatcher::getUses);
+
   matcherRegistry.registerMatcher("hasOpAttrName",
                                   static_cast<HasOpAttrName *>(m_Attr));
   matcherRegistry.registerMatcher("hasOpName", static_cast<HasOpName *>(m_Op));

>From ec431eb20f17d3a9bb17c4373c509d4cf291777a Mon Sep 17 00:00:00 2001
From: Denzel-Brian Budii <denzel-brian.budii at intel.com>
Date: Wed, 23 Oct 2024 17:42:19 +0000
Subject: [PATCH 2/5] MLIR-QUERY: backwardSlice, forwardSlice & QueryOptions
 added

---
 mlir/lib/Query/Matcher/Parser.cpp             | 18 +++++++++-----
 mlir/lib/Query/Query.cpp                      |  9 +++++++
 mlir/lib/Query/QueryParser.cpp                | 24 +++++++++++++++++++
 mlir/test/mlir-query/function-extraction.mlir |  4 ++--
 mlir/tools/mlir-query/mlir-query.cpp          |  1 -
 5 files changed, 47 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Query/Matcher/Parser.cpp b/mlir/lib/Query/Matcher/Parser.cpp
index 4dcb86a9383f3..726f1188d7e4c 100644
--- a/mlir/lib/Query/Matcher/Parser.cpp
+++ b/mlir/lib/Query/Matcher/Parser.cpp
@@ -293,13 +293,19 @@ bool Parser::parseIdentifierPrefixImpl(VariantValue *value) {
 
   if (tokenizer->nextTokenKind() != TokenKind::OpenParen) {
     // Parse as a named value.
-    auto namedValue =
-        namedValues ? namedValues->lookup(nameToken.text) : VariantValue();
+    if (auto namedValue = namedValues ? namedValues->lookup(nameToken.text)
+                                      : VariantValue()) {
 
-    if (!namedValue.isMatcher()) {
-      error->addError(tokenizer->peekNextToken().range,
-                      ErrorType::ParserNotAMatcher);
-      return false;
+      if (tokenizer->nextTokenKind() != TokenKind::Period) {
+        *value = namedValue;
+        return true;
+      }
+
+      if (!namedValue.isMatcher()) {
+        error->addError(tokenizer->peekNextToken().range,
+                        ErrorType::ParserNotAMatcher);
+        return false;
+      }
     }
 
     if (tokenizer->nextTokenKind() == TokenKind::NewLine) {
diff --git a/mlir/lib/Query/Query.cpp b/mlir/lib/Query/Query.cpp
index dd699857568d7..500fee50a1609 100644
--- a/mlir/lib/Query/Query.cpp
+++ b/mlir/lib/Query/Query.cpp
@@ -127,6 +127,15 @@ LogicalResult QuitQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
   return mlir::success();
 }
 
+LogicalResult LetQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
+  if (value.hasValue()) {
+    qs.namedValues[name] = value;
+  } else {
+    qs.namedValues.erase(name);
+  }
+  return mlir::success();
+}
+
 LogicalResult MatchQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
   Operation *rootOp = qs.getRootOp();
 
diff --git a/mlir/lib/Query/QueryParser.cpp b/mlir/lib/Query/QueryParser.cpp
index 7aaf4847f2e47..4350fb9a434d4 100644
--- a/mlir/lib/Query/QueryParser.cpp
+++ b/mlir/lib/Query/QueryParser.cpp
@@ -166,6 +166,8 @@ QueryRef QueryParser::doParse() {
           .Case("", ParsedQueryKind::NoOp)
           .Case("#", ParsedQueryKind::Comment, /*isCompletion=*/false)
           .Case("help", ParsedQueryKind::Help)
+          .Case("l", ParsedQueryKind::Let, /*isCompletion=*/false)
+          .Case("let", ParsedQueryKind::Let, /*isCompletion=*/false)
           .Case("m", ParsedQueryKind::Match, /*isCompletion=*/false)
           .Case("set", ParsedQueryKind::Set)
           .Case("match", ParsedQueryKind::Match)
@@ -188,6 +190,27 @@ QueryRef QueryParser::doParse() {
   case ParsedQueryKind::Quit:
     return endQuery(new QuitQuery);
 
+  case ParsedQueryKind::Let: {
+    llvm::StringRef name = lexWord();
+
+    if (name.empty()) {
+      return new InvalidQuery("expected variable name");
+    }
+
+    if (completionPos) {
+      return completeMatcherExpression();
+    }
+
+    matcher::internal::Diagnostics diag;
+    matcher::VariantValue value;
+    if (!matcher::internal::Parser::parseExpression(
+            line, qs.getRegistryData(), &qs.namedValues, &value, &diag)) {
+      return makeInvalidQueryFromDiagnostics(diag);
+    }
+    QueryRef query = new LetQuery(name, value);
+    query->remainingContent = line;
+    return query;
+  }
   case ParsedQueryKind::Match: {
     if (completionPos) {
       return completeMatcherExpression();
@@ -204,6 +227,7 @@ QueryRef QueryParser::doParse() {
     }
     auto actualSource = origMatcherSource.substr(0, origMatcherSource.size() -
                                                         matcherSource.size());
+
     QueryRef query = new MatchQuery(actualSource, *matcher);
     query->remainingContent = matcherSource;
     return query;
diff --git a/mlir/test/mlir-query/function-extraction.mlir b/mlir/test/mlir-query/function-extraction.mlir
index d7a867eb1a452..5a20c09d02eb6 100644
--- a/mlir/test/mlir-query/function-extraction.mlir
+++ b/mlir/test/mlir-query/function-extraction.mlir
@@ -4,7 +4,7 @@
 // CHECK:       %[[MUL0:.*]] = arith.mulf {{.*}} : f32
 // CHECK:       %[[MUL1:.*]] = arith.mulf {{.*}}, %[[MUL0]] : f32
 // CHECK:       %[[MUL2:.*]] = arith.mulf {{.*}} : f32
-// CHECK-NEXT:  return %[[MUL0]], %[[MUL1]], %[[MUL2]] : f32, f32, f32
+// CHECK-NEXT:  return %[[MUL0]], %[[MUL1]], %[[MUL2]] : f32, f32, f32S
 
 func.func @mixedOperations(%a: f32, %b: f32, %c: f32) -> f32 {
   %sum0 = arith.addf %a, %b : f32
@@ -16,4 +16,4 @@ func.func @mixedOperations(%a: f32, %b: f32, %c: f32) -> f32 {
   %sum2 = arith.addf %mul1, %b : f32
   %mul2 = arith.mulf %sub2, %sum2 : f32
   return %mul2 : f32
-}
+    }
diff --git a/mlir/tools/mlir-query/mlir-query.cpp b/mlir/tools/mlir-query/mlir-query.cpp
index 5e74da7ee7bdc..468f948bec24c 100644
--- a/mlir/tools/mlir-query/mlir-query.cpp
+++ b/mlir/tools/mlir-query/mlir-query.cpp
@@ -10,7 +10,6 @@
 // of the registered queries.
 //
 //===----------------------------------------------------------------------===//
-
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Matchers.h"

>From 9cf7595da453256e645b49cc52fafe8b841821ac Mon Sep 17 00:00:00 2001
From: chios202 <chio.star at yahoo.com>
Date: Sat, 25 Jan 2025 15:29:12 +0000
Subject: [PATCH 3/5] removed LetQuery implementation

---
 mlir/lib/Query/Query.cpp       |  9 ---------
 mlir/lib/Query/QueryParser.cpp | 24 ------------------------
 2 files changed, 33 deletions(-)

diff --git a/mlir/lib/Query/Query.cpp b/mlir/lib/Query/Query.cpp
index 500fee50a1609..dd699857568d7 100644
--- a/mlir/lib/Query/Query.cpp
+++ b/mlir/lib/Query/Query.cpp
@@ -127,15 +127,6 @@ LogicalResult QuitQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
   return mlir::success();
 }
 
-LogicalResult LetQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
-  if (value.hasValue()) {
-    qs.namedValues[name] = value;
-  } else {
-    qs.namedValues.erase(name);
-  }
-  return mlir::success();
-}
-
 LogicalResult MatchQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
   Operation *rootOp = qs.getRootOp();
 
diff --git a/mlir/lib/Query/QueryParser.cpp b/mlir/lib/Query/QueryParser.cpp
index 4350fb9a434d4..53e8f91e657cb 100644
--- a/mlir/lib/Query/QueryParser.cpp
+++ b/mlir/lib/Query/QueryParser.cpp
@@ -166,8 +166,6 @@ QueryRef QueryParser::doParse() {
           .Case("", ParsedQueryKind::NoOp)
           .Case("#", ParsedQueryKind::Comment, /*isCompletion=*/false)
           .Case("help", ParsedQueryKind::Help)
-          .Case("l", ParsedQueryKind::Let, /*isCompletion=*/false)
-          .Case("let", ParsedQueryKind::Let, /*isCompletion=*/false)
           .Case("m", ParsedQueryKind::Match, /*isCompletion=*/false)
           .Case("set", ParsedQueryKind::Set)
           .Case("match", ParsedQueryKind::Match)
@@ -189,28 +187,6 @@ QueryRef QueryParser::doParse() {
 
   case ParsedQueryKind::Quit:
     return endQuery(new QuitQuery);
-
-  case ParsedQueryKind::Let: {
-    llvm::StringRef name = lexWord();
-
-    if (name.empty()) {
-      return new InvalidQuery("expected variable name");
-    }
-
-    if (completionPos) {
-      return completeMatcherExpression();
-    }
-
-    matcher::internal::Diagnostics diag;
-    matcher::VariantValue value;
-    if (!matcher::internal::Parser::parseExpression(
-            line, qs.getRegistryData(), &qs.namedValues, &value, &diag)) {
-      return makeInvalidQueryFromDiagnostics(diag);
-    }
-    QueryRef query = new LetQuery(name, value);
-    query->remainingContent = line;
-    return query;
-  }
   case ParsedQueryKind::Match: {
     if (completionPos) {
       return completeMatcherExpression();

>From 410c5c9b1dc5fd07b07246abddfa4ca24e2ec6a8 Mon Sep 17 00:00:00 2001
From: chios202 <chio.star at yahoo.com>
Date: Sun, 23 Feb 2025 15:01:25 +0000
Subject: [PATCH 4/5] Enhance matcher and QueryOptions documentation

- Enhance docs for matchers and QueryOptions
- Fix whitespace and alignment issues
- Move matchers to Matchers.h
- Change data type from unsigned to signed for arithmetic operations
---
 mlir/include/mlir/IR/Matchers.h               | 262 ++++++++++++++++++
 .../mlir/Query/Matcher/ExtraMatchers.h        | 188 -------------
 mlir/include/mlir/Query/Matcher/Marshallers.h |   8 +-
 mlir/include/mlir/Query/Matcher/MatchFinder.h |  63 ++++-
 .../mlir/Query/Matcher/MatchersInternal.h     |  21 +-
 .../include/mlir/Query/Matcher/VariantValue.h |  16 +-
 mlir/include/mlir/Query/Query.h               |  20 +-
 mlir/include/mlir/Query/QuerySession.h        |   4 -
 mlir/lib/Query/Matcher/RegistryManager.cpp    |   9 +-
 mlir/lib/Query/Matcher/VariantValue.cpp       |  28 +-
 mlir/test/mlir-query/complex-test.mlir        |  13 +-
 mlir/test/mlir-query/function-extraction.mlir |   6 +-
 mlir/tools/mlir-query/mlir-query.cpp          |  13 +-
 13 files changed, 383 insertions(+), 268 deletions(-)
 delete mode 100644 mlir/include/mlir/Query/Matcher/ExtraMatchers.h

diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index 2204a68be26b1..ee9e2afb10bad 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -19,6 +19,9 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Interfaces/InferIntRangeInterface.h"
+#include "mlir/Query/Matcher/MatchersInternal.h"
+#include "mlir/Query/Query.h"
+#include "llvm/ADT/SetVector.h"
 
 namespace mlir {
 
@@ -363,8 +366,267 @@ struct RecursivePatternMatcher {
   std::tuple<OperandMatchers...> operandMatchers;
 };
 
+/// Fills `backwardSlice` with the computed backward slice (i.e.
+/// all the transitive defs of op)
+///
+/// The implementation traverses the def chains in postorder traversal for
+/// efficiency reasons: if an operation is already in `backwardSlice`, no
+/// need to traverse its definitions again. Since use-def chains form a DAG,
+/// this terminates.
+///
+/// Upon return to the root call, `backwardSlice` is filled with a
+/// postorder list of defs. This happens to be a topological order, from the
+/// point of view of the use-def chains.
+///
+/// Example starting from node 8
+/// ============================
+///
+///    1       2      3      4
+///    |_______|      |______|
+///    |   |             |
+///    |   5             6
+///    |___|_____________|
+///      |               |
+///      7               8
+///      |_______________|
+///              |
+///              9
+///
+/// Assuming all local orders match the numbering order:
+///    {1, 2, 5, 3, 4, 6}
+///
+
+class BackwardSliceMatcher {
+public:
+  BackwardSliceMatcher(mlir::query::matcher::DynMatcher &&innerMatcher,
+                       int64_t maxDepth)
+      : innerMatcher(std::move(innerMatcher)), maxDepth(maxDepth) {}
+
+  bool match(Operation *op, SetVector<Operation *> &backwardSlice,
+             mlir::query::QueryOptions &options) {
+
+    if (innerMatcher.match(op) &&
+        matches(op, backwardSlice, options, maxDepth)) {
+      if (!options.inclusive) {
+        // Don't insert the top level operation, we just queried on it and don't
+        // want it in the results.
+        backwardSlice.remove(op);
+      }
+      return true;
+    }
+    return false;
+  }
+
+private:
+  bool matches(Operation *op, SetVector<Operation *> &backwardSlice,
+               mlir::query::QueryOptions &options, int64_t remainingDepth) {
+
+    if (op->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
+      return false;
+    }
+
+    auto processValue = [&](Value value) {
+      // We need to check the current depth level;
+      // if we have reached level 0, we stop further traversing
+      if (remainingDepth == 0) {
+        return;
+      }
+      if (auto *definingOp = value.getDefiningOp()) {
+        // We omit traversing the same operations
+        if (backwardSlice.count(definingOp) == 0)
+          matches(definingOp, backwardSlice, options, remainingDepth - 1);
+      } else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
+        if (options.omitBlockArguments)
+          return;
+        Block *block = blockArg.getOwner();
+
+        Operation *parentOp = block->getParentOp();
+        // TODO: determine whether we want to recurse backward into the other
+        // blocks of parentOp, which are not technically backward unless they
+        // flow into us. For now, just bail.
+        if (parentOp && backwardSlice.count(parentOp) == 0) {
+          if (parentOp->getNumRegions() != 1 &&
+              parentOp->getRegion(0).getBlocks().size() != 1) {
+            llvm::errs()
+                << "Error: Expected parentOp to have exactly one region and "
+                << "exactly one block, but found " << parentOp->getNumRegions()
+                << " regions and "
+                << (parentOp->getRegion(0).getBlocks().size()) << " blocks.\n";
+          };
+          matches(parentOp, backwardSlice, options, remainingDepth - 1);
+        }
+      } else {
+        llvm_unreachable("No definingOp and not a block argument\n");
+        return;
+      }
+    };
+
+    if (!options.omitUsesFromAbove) {
+      llvm::for_each(op->getRegions(), [&](Region &region) {
+        // Walk this region recursively to collect the regions that descend from
+        // this op's nested regions (inclusive).
+        SmallPtrSet<Region *, 4> descendents;
+        region.walk(
+            [&](Region *childRegion) { descendents.insert(childRegion); });
+        region.walk([&](Operation *op) {
+          for (OpOperand &operand : op->getOpOperands()) {
+            if (!descendents.contains(operand.get().getParentRegion()))
+              processValue(operand.get());
+          }
+        });
+      });
+    }
+
+    llvm::for_each(op->getOperands(), processValue);
+    backwardSlice.insert(op);
+    return true;
+  }
+
+private:
+  // The outer matcher (e.g., BackwardSliceMatcher) relies on the innerMatcher
+  // to determine whether we want to traverse the DAG or not. For example, we
+  // want to explore the DAG only if the top-level operation name is
+  // "arith.addf".
+  mlir::query::matcher::DynMatcher innerMatcher;
+
+  // maxDepth specifies the maximum depth that the matcher can traverse in the
+  // DAG. For example, if maxDepth is 2, the matcher will explore the defining
+  // operations of the top-level op up to 2 levels.
+  int64_t maxDepth;
+};
+
+/// Fills `forwardSlice` with the computed forward slice (i.e. all
+/// the transitive uses of op)
+///
+///
+/// The implementation traverses the use chains in postorder traversal for
+/// efficiency reasons: if an operation is already in `forwardSlice`, no
+/// need to traverse its uses again. Since use-def chains form a DAG, this
+/// terminates.
+///
+/// Upon return to the root call, `forwardSlice` is filled with a
+/// postorder list of uses (i.e. a reverse topological order). To get a proper
+/// topological order, we just reverse the order in `forwardSlice` before
+/// returning.
+///
+/// Example starting from node 0
+/// ============================
+///
+///               0
+///    ___________|___________
+///    1       2      3      4
+///    |_______|      |______|
+///    |   |             |
+///    |   5             6
+///    |___|_____________|
+///      |               |
+///      7               8
+///      |_______________|
+///              |
+///              9
+///
+/// Assuming all local orders match the numbering order:
+/// 1. after getting back to the root getForwardSlice, `forwardSlice` may
+///    contain:
+///      {9, 7, 8, 5, 1, 2, 6, 3, 4}
+/// 2. reversing the result of 1. gives:
+///      {4, 3, 6, 2, 1, 5, 8, 7, 9}
+///
+class ForwardSliceMatcher {
+public:
+  ForwardSliceMatcher(mlir::query::matcher::DynMatcher &&innerMatcher,
+                      int64_t maxDepth)
+      : innerMatcher(std::move(innerMatcher)), maxDepth(maxDepth) {}
+
+  bool match(Operation *op, SetVector<Operation *> &forwardSlice,
+             mlir::query::QueryOptions &options) {
+    if (innerMatcher.match(op) &&
+        matches(op, forwardSlice, options, maxDepth)) {
+      if (!options.inclusive) {
+        // Don't insert the top level operation, we just queried on it and don't
+        // want it in the results.
+        forwardSlice.remove(op);
+      }
+      // Reverse to get back the actual topological order.
+      // std::reverse does not work out of the box on SetVector and I want an
+      // in-place swap based thing (the real std::reverse, not the LLVM
+      // adapter).
+      SmallVector<Operation *, 0> v(forwardSlice.takeVector());
+      forwardSlice.insert(v.rbegin(), v.rend());
+      return true;
+    }
+    return false;
+  }
+
+private:
+  bool matches(Operation *op, SetVector<Operation *> &forwardSlice,
+               mlir::query::QueryOptions &options, int64_t remainingDepth) {
+
+    // We need to check the current depth level;
+    // if we have reached level 0, we stop further traversing and insert
+    // the last user in def-use chain
+    if (remainingDepth == 0) {
+      forwardSlice.insert(op);
+      return true;
+    }
+
+    for (Region &region : op->getRegions())
+      for (Block &block : region)
+        for (Operation &blockOp : block)
+          if (forwardSlice.count(&blockOp) == 0)
+            matches(&blockOp, forwardSlice, options, remainingDepth - 1);
+    for (Value result : op->getResults()) {
+      for (Operation *userOp : result.getUsers())
+        // We omit traversing the same operations
+        if (forwardSlice.count(userOp) == 0)
+          matches(userOp, forwardSlice, options, remainingDepth - 1);
+    }
+
+    forwardSlice.insert(op);
+    return true;
+  }
+
+private:
+  // The outer matcher e.g (ForwardSliceMatcher) relies on the innerMatcher to
+  // determine whether we want to traverse the graph or not. E.g: we want to
+  // explore the DAG only if the top level operation name is "arith.addf"
+  mlir::query::matcher::DynMatcher innerMatcher;
+
+  // maxDepth specifies the maximum depth that the matcher can traverse the
+  // graph E.g: if maxDepth is 2, the matcher will explore the user
+  // operations of the top level op up to 2 levels
+  int64_t maxDepth;
+};
+
 } // namespace detail
 
+// Matches transitive defs of a top level operation up to 1 level
+inline detail::BackwardSliceMatcher
+m_DefinedBy(mlir::query::matcher::DynMatcher innerMatcher) {
+  return detail::BackwardSliceMatcher(std::move(innerMatcher), 1);
+}
+
+// Matches transitive defs of a top level operation up to N levels
+inline detail::BackwardSliceMatcher
+m_GetDefinitions(mlir::query::matcher::DynMatcher innerMatcher,
+                 int64_t maxDepth) {
+  assert(maxDepth >= 0 && "maxDepth must be non-negative");
+  return detail::BackwardSliceMatcher(std::move(innerMatcher), maxDepth);
+}
+
+// Matches uses of a top level operation up to 1 level
+inline detail::ForwardSliceMatcher
+m_UsedBy(mlir::query::matcher::DynMatcher innerMatcher) {
+  return detail::ForwardSliceMatcher(std::move(innerMatcher), 1);
+}
+
+// Matches uses of a top level operation up to N  levels
+inline detail::ForwardSliceMatcher
+m_GetUses(mlir::query::matcher::DynMatcher innerMatcher, int64_t maxDepth) {
+  assert(maxDepth >= 0 && "maxDepth must be non-negative");
+  return detail::ForwardSliceMatcher(std::move(innerMatcher), maxDepth);
+}
+
 /// Matches a constant foldable operation.
 inline detail::constant_op_matcher m_Constant() {
   return detail::constant_op_matcher();
diff --git a/mlir/include/mlir/Query/Matcher/ExtraMatchers.h b/mlir/include/mlir/Query/Matcher/ExtraMatchers.h
deleted file mode 100644
index 908fccfc704c3..0000000000000
--- a/mlir/include/mlir/Query/Matcher/ExtraMatchers.h
+++ /dev/null
@@ -1,188 +0,0 @@
-//===- ExtraMatchers.h - Various common matchers ---------------------*- C++
-//-*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file provides extra matchers that are very useful for mlir-query
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_IR_EXTRAMATCHERS_H
-#define MLIR_IR_EXTRAMATCHERS_H
-
-#include "MatchFinder.h"
-#include "MatchersInternal.h"
-#include "mlir/IR/Region.h"
-#include "mlir/Query/Query.h"
-#include "llvm/Support/raw_ostream.h"
-
-namespace mlir {
-
-namespace query {
-
-namespace extramatcher {
-
-namespace detail {
-
-class BackwardSliceMatcher {
-public:
-  BackwardSliceMatcher(matcher::DynMatcher &&innerMatcher, unsigned hops)
-      : innerMatcher(std::move(innerMatcher)), hops(hops) {}
-
-private:
-  bool matches(Operation *op, SetVector<Operation *> &backwardSlice,
-               QueryOptions &options, unsigned tempHops) {
-
-    if (op->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
-      return false;
-    }
-
-    auto processValue = [&](Value value) {
-      if (tempHops == 0) {
-        return;
-      }
-      if (auto *definingOp = value.getDefiningOp()) {
-        if (backwardSlice.count(definingOp) == 0)
-          matches(definingOp, backwardSlice, options, tempHops - 1);
-      } else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
-        if (options.omitBlockArguments)
-          return;
-        Block *block = blockArg.getOwner();
-
-        Operation *parentOp = block->getParentOp();
-
-        if (parentOp && backwardSlice.count(parentOp) == 0) {
-          if (parentOp->getNumRegions() != 1 &&
-              parentOp->getRegion(0).getBlocks().size() != 1) {
-            llvm::errs()
-                << "Error: Expected parentOp to have exactly one region and "
-                << "exactly one block, but found " << parentOp->getNumRegions()
-                << " regions and "
-                << (parentOp->getRegion(0).getBlocks().size()) << " blocks.\n";
-          };
-          matches(parentOp, backwardSlice, options, tempHops - 1);
-        }
-      } else {
-        llvm::errs() << "No definingOp and not a block argument\n";
-        return;
-      }
-    };
-
-    if (!options.omitUsesFromAbove) {
-      llvm::for_each(op->getRegions(), [&](Region &region) {
-        SmallPtrSet<Region *, 4> descendents;
-        region.walk(
-            [&](Region *childRegion) { descendents.insert(childRegion); });
-        region.walk([&](Operation *op) {
-          for (OpOperand &operand : op->getOpOperands()) {
-            if (!descendents.contains(operand.get().getParentRegion()))
-              processValue(operand.get());
-          }
-        });
-      });
-    }
-
-    llvm::for_each(op->getOperands(), processValue);
-    backwardSlice.insert(op);
-    return true;
-  }
-
-public:
-  bool match(Operation *op, SetVector<Operation *> &backwardSlice,
-             QueryOptions &options) {
-
-    if (innerMatcher.match(op) && matches(op, backwardSlice, options, hops)) {
-      if (!options.inclusive) {
-        backwardSlice.remove(op);
-      }
-      return true;
-    }
-    return false;
-  }
-
-private:
-  matcher::DynMatcher innerMatcher;
-  unsigned hops;
-};
-
-class ForwardSliceMatcher {
-public:
-  ForwardSliceMatcher(matcher::DynMatcher &&innerMatcher, unsigned hops)
-      : innerMatcher(std::move(innerMatcher)), hops(hops) {}
-
-private:
-  bool matches(Operation *op, SetVector<Operation *> &forwardSlice,
-               QueryOptions &options, unsigned tempHops) {
-
-    if (tempHops == 0) {
-      forwardSlice.insert(op);
-      return true;
-    }
-
-    for (Region &region : op->getRegions())
-      for (Block &block : region)
-        for (Operation &blockOp : block)
-          if (forwardSlice.count(&blockOp) == 0)
-            matches(&blockOp, forwardSlice, options, tempHops - 1);
-    for (Value result : op->getResults()) {
-      for (Operation *userOp : result.getUsers())
-        if (forwardSlice.count(userOp) == 0)
-          matches(userOp, forwardSlice, options, tempHops - 1);
-    }
-
-    forwardSlice.insert(op);
-    return true;
-  }
-
-public:
-  bool match(Operation *op, SetVector<Operation *> &forwardSlice,
-             QueryOptions &options) {
-    if (innerMatcher.match(op) && matches(op, forwardSlice, options, hops)) {
-      if (!options.inclusive) {
-        forwardSlice.remove(op);
-      }
-      SmallVector<Operation *, 0> v(forwardSlice.takeVector());
-      forwardSlice.insert(v.rbegin(), v.rend());
-      return true;
-    }
-    return false;
-  }
-
-private:
-  matcher::DynMatcher innerMatcher;
-  unsigned hops;
-};
-
-} // namespace detail
-
-inline detail::BackwardSliceMatcher
-definedBy(mlir::query::matcher::DynMatcher innerMatcher) {
-  return detail::BackwardSliceMatcher(std::move(innerMatcher), 1);
-}
-
-inline detail::BackwardSliceMatcher
-getDefinitions(mlir::query::matcher::DynMatcher innerMatcher, unsigned hops) {
-  return detail::BackwardSliceMatcher(std::move(innerMatcher), hops);
-}
-
-inline detail::ForwardSliceMatcher
-usedBy(mlir::query::matcher::DynMatcher innerMatcher) {
-  return detail::ForwardSliceMatcher(std::move(innerMatcher), 1);
-}
-
-inline detail::ForwardSliceMatcher
-getUses(mlir::query::matcher::DynMatcher innerMatcher, unsigned hops) {
-  return detail::ForwardSliceMatcher(std::move(innerMatcher), hops);
-}
-
-} // namespace extramatcher
-
-} // namespace query
-
-} // namespace mlir
-
-#endif // MLIR_IR_EXTRAMATCHERS_H
diff --git a/mlir/include/mlir/Query/Matcher/Marshallers.h b/mlir/include/mlir/Query/Matcher/Marshallers.h
index c775dbc5c86da..43643298e4702 100644
--- a/mlir/include/mlir/Query/Matcher/Marshallers.h
+++ b/mlir/include/mlir/Query/Matcher/Marshallers.h
@@ -51,14 +51,14 @@ struct ArgTypeTraits<llvm::StringRef> {
 };
 
 template <>
-struct ArgTypeTraits<unsigned> {
+struct ArgTypeTraits<int64_t> {
   static bool hasCorrectType(const VariantValue &value) {
-    return value.isUnsigned();
+    return value.isSigned();
   }
 
-  static unsigned get(const VariantValue &value) { return value.getUnsigned(); }
+  static unsigned get(const VariantValue &value) { return value.getSigned(); }
 
-  static ArgKind getKind() { return ArgKind::Unsigned; }
+  static ArgKind getKind() { return ArgKind::Signed; }
 
   static std::optional<std::string> getBestGuess(const VariantValue &) {
     return std::nullopt;
diff --git a/mlir/include/mlir/Query/Matcher/MatchFinder.h b/mlir/include/mlir/Query/Matcher/MatchFinder.h
index 1b9d3bc307ff5..1d64f894bb8a1 100644
--- a/mlir/include/mlir/Query/Matcher/MatchFinder.h
+++ b/mlir/include/mlir/Query/Matcher/MatchFinder.h
@@ -23,45 +23,78 @@
 namespace mlir::query::matcher {
 
 class MatchFinder {
-private:
-  static void printMatch(llvm::raw_ostream &os, QuerySession &qs,
-                         mlir::Operation *op, const std::string &binding) {
-    auto fileLoc = op->getLoc()->findInstanceOf<FileLineColLoc>();
-    auto smloc = qs.getSourceManager().FindLocForLineAndColumn(
-        qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn());
-    qs.getSourceManager().PrintMessage(os, smloc, llvm::SourceMgr::DK_Note,
-                                       "\"" + binding + "\" binds here");
-  };
 
 public:
+  //
+  // getMatches walks the IR and prints operations as soon as it matches them
+  // if a matcher is to be further extracted into the function, then it does not
+  // print operations
+  //
   static std::vector<Operation *>
   getMatches(Operation *root, QueryOptions &options, DynMatcher matcher,
              llvm::raw_ostream &os, QuerySession &qs) {
-    unsigned matchCount = 0;
+    int matchCount = 0;
+    bool printMatchingOps = true;
+    // If matcher is to be extracted to a function, we don't want to print
+    // matching ops to sdout
+    if (matcher.hasFunctionName()) {
+      printMatchingOps = false;
+    }
     std::vector<Operation *> matchedOps;
     SetVector<Operation *> tempStorage;
     os << "\n";
     root->walk([&](Operation *subOp) {
       if (matcher.match(subOp)) {
         matchedOps.push_back(subOp);
-        os << "Match #" << ++matchCount << ":\n\n";
-        printMatch(os, qs, subOp, "root");
+        if (printMatchingOps) {
+          os << "Match #" << ++matchCount << ":\n\n";
+          printMatch(os, qs, subOp, "root");
+        }
       } else {
         SmallVector<Operation *> printingOps;
         if (matcher.match(subOp, tempStorage, options)) {
-          os << "Match #" << ++matchCount << ":\n\n";
+          if (printMatchingOps) {
+            os << "Match #" << ++matchCount << ":\n\n";
+          }
           SmallVector<Operation *> printingOps(tempStorage.takeVector());
           for (auto op : printingOps) {
-            printMatch(os, qs, op, "root");
+            if (printMatchingOps) {
+              printMatch(os, qs, op, "root");
+            }
             matchedOps.push_back(op);
           }
           printingOps.clear();
         }
       }
     });
-    os << matchCount << (matchCount == 1 ? " match.\n\n" : " matches.\n\n");
+    if (printMatchingOps) {
+      os << matchCount << (matchCount == 1 ? " match.\n\n" : " matches.\n\n");
+    }
     return matchedOps;
   }
+
+private:
+  // Overloaded version that doesn't print the binding
+  static void printMatch(llvm::raw_ostream &os, QuerySession &qs,
+                         mlir::Operation *op) {
+    auto fileLoc = op->getLoc()->dyn_cast<FileLineColLoc>();
+    SMLoc smloc = qs.getSourceManager().FindLocForLineAndColumn(
+        qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn());
+
+    llvm::SMDiagnostic diag =
+        qs.getSourceManager().GetMessage(smloc, llvm::SourceMgr::DK_Note,
+
+                                         "");
+    diag.print("", os, true, false, true);
+  }
+  static void printMatch(llvm::raw_ostream &os, QuerySession &qs,
+                         mlir::Operation *op, const std::string &binding) {
+    auto fileLoc = op->getLoc()->findInstanceOf<FileLineColLoc>();
+    auto smloc = qs.getSourceManager().FindLocForLineAndColumn(
+        qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn());
+    qs.getSourceManager().PrintMessage(os, smloc, llvm::SourceMgr::DK_Note,
+                                       "\"" + binding + "\" binds here");
+  }
 };
 
 } // namespace mlir::query::matcher
diff --git a/mlir/include/mlir/Query/Matcher/MatchersInternal.h b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
index b532b47be7d05..c5c24190f0e7f 100644
--- a/mlir/include/mlir/Query/Matcher/MatchersInternal.h
+++ b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
@@ -1,3 +1,4 @@
+//===- MatchersInternal.h - Structural query framework ----------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -7,20 +8,20 @@
 //
 // Implements the base layer of the matcher framework.
 //
-// Matchers are methods that return a Matcher which provides a method
-// match(Operation *op)
-// match(Operation *op, SetVector<Operation *> &matchedOps, QueryOptions
-// &options)
+// Matchers are methods that return a Matcher which provides a method one of the
+// following methods: match(Operation *op), match(Operation *op,
+// SetVector<Operation *> &matchedOps, QueryOptions &options)
 //
 // The matcher functions are defined in include/mlir/IR/Matchers.h.
 // This file contains the wrapper classes needed to construct matchers for
 // mlir-query.
 //
 //===----------------------------------------------------------------------===//
+
 #ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
 #define MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
 
-#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Operation.h"
 #include "llvm/ADT/IntrusiveRefCntPtr.h"
 
 namespace mlir {
@@ -30,17 +31,27 @@ struct QueryOptions;
 } // namespace mlir
 
 namespace mlir::query::matcher {
+
+// Defaults to false if T has no match() method with the signature:
+// match(Operation* op).
 template <typename T, typename = void>
 struct has_simple_match : std::false_type {};
 
+// Specialized type trait that evaluates to true if T has a match() method
+// with the signature: match(Operation* op).
 template <typename T>
 struct has_simple_match<T, std::void_t<decltype(std::declval<T>().match(
                                std::declval<Operation *>()))>>
     : std::true_type {};
 
+// Defaults to false if T has no match() method with the signature:
+// match(Operation* op, SetVector<Operation*>&, QueryOptions&).
 template <typename T, typename = void>
 struct has_bound_match : std::false_type {};
 
+// Specialized type trait that evaluates to true if T has a match() method
+// with the signature: match(Operation* op, SetVector<Operation*>&,
+// QueryOptions&).
 template <typename T>
 struct has_bound_match<T, std::void_t<decltype(std::declval<T>().match(
                               std::declval<Operation *>(),
diff --git a/mlir/include/mlir/Query/Matcher/VariantValue.h b/mlir/include/mlir/Query/Matcher/VariantValue.h
index 6b57119df7a9b..71ec628edeea1 100644
--- a/mlir/include/mlir/Query/Matcher/VariantValue.h
+++ b/mlir/include/mlir/Query/Matcher/VariantValue.h
@@ -21,7 +21,7 @@
 namespace mlir::query::matcher {
 
 // All types that VariantValue can contain.
-enum class ArgKind { Matcher, String, Unsigned };
+enum class ArgKind { Matcher, String, Signed };
 
 // A variant matcher object to abstract simple and complex matchers into a
 // single object type.
@@ -81,7 +81,7 @@ class VariantValue {
   // Specific constructors for each supported type.
   VariantValue(const llvm::StringRef string);
   VariantValue(const VariantMatcher &matcher);
-  VariantValue(unsigned Unsigned);
+  VariantValue(int64_t signedValue);
 
   // String value functions.
   bool isString() const;
@@ -93,10 +93,10 @@ class VariantValue {
   const VariantMatcher &getMatcher() const;
   void setMatcher(const VariantMatcher &matcher);
 
-  // Unsigned value functions.
-  bool isUnsigned() const;
-  unsigned getUnsigned() const;
-  void setUnsigned(unsigned Unsigned);
+  // Signed value functions.
+  bool isSigned() const;
+  int64_t getSigned() const;
+  void setSigned(int64_t signedValue);
 
   // String representation of the type of the value.
   std::string getTypeAsString() const;
@@ -111,14 +111,14 @@ class VariantValue {
     Nothing,
     String,
     Matcher,
-    Unsigned,
+    Signed,
   };
 
   // All supported value types.
   union AllValues {
     llvm::StringRef *String;
     VariantMatcher *Matcher;
-    unsigned Unsigned;
+    int64_t Signed;
   };
 
   ValueType type;
diff --git a/mlir/include/mlir/Query/Query.h b/mlir/include/mlir/Query/Query.h
index bb5b98432d51c..77114a2a00e23 100644
--- a/mlir/include/mlir/Query/Query.h
+++ b/mlir/include/mlir/Query/Query.h
@@ -17,13 +17,31 @@
 
 namespace mlir::query {
 
+///
+/// Options for configuring which parts of the IR are to be
+/// traversed by the matcher
+///
 struct QueryOptions {
+  /// When omitBlockArguments is true, the matcher omits traversing
+  /// any block arguments
   bool omitBlockArguments = false;
+  /// When omitUsesFromAbove is true, the matcher omits
+  /// traversing values that are captured from above.
   bool omitUsesFromAbove = true;
+  /// When inclusive is true, the matcher will include the include the
+  /// top level op in the slice. When inclusive is false, the matcher will
+  /// not include thee top level op in the slice
   bool inclusive = true;
 };
 
-enum class QueryKind { Invalid, NoOp, Help, Match, Quit, SetBool };
+enum class QueryKind {
+  Invalid,
+  NoOp,
+  Help,
+  SetBool,
+  Match,
+  Quit,
+};
 
 class QuerySession;
 
diff --git a/mlir/include/mlir/Query/QuerySession.h b/mlir/include/mlir/Query/QuerySession.h
index 495358e8f36f9..03dbd481d64cf 100644
--- a/mlir/include/mlir/Query/QuerySession.h
+++ b/mlir/include/mlir/Query/QuerySession.h
@@ -15,10 +15,6 @@
 #include "llvm/ADT/StringMap.h"
 #include "llvm/Support/SourceMgr.h"
 
-namespace mlir::query::matcher {
-class Registry;
-}
-
 namespace mlir::query {
 
 // Represents the state for a particular mlir-query session.
diff --git a/mlir/lib/Query/Matcher/RegistryManager.cpp b/mlir/lib/Query/Matcher/RegistryManager.cpp
index 8d6c0135aa117..8f7da5aeaa5e6 100644
--- a/mlir/lib/Query/Matcher/RegistryManager.cpp
+++ b/mlir/lib/Query/Matcher/RegistryManager.cpp
@@ -19,11 +19,6 @@
 namespace mlir::query::matcher {
 namespace {
 
-// This is needed because these matchers are defined as overloaded functions.
-using IsConstantOp = detail::constant_op_matcher();
-using HasOpAttrName = detail::AttrOpMatcher(llvm::StringRef);
-using HasOpName = detail::NameOpMatcher(llvm::StringRef);
-
 // Enum to string for autocomplete.
 static std::string asArgString(ArgKind kind) {
   switch (kind) {
@@ -31,8 +26,8 @@ static std::string asArgString(ArgKind kind) {
     return "Matcher";
   case ArgKind::String:
     return "String";
-  case ArgKind::Unsigned:
-    return "unsigned";
+  case ArgKind::Signed:
+    return "signed";
   }
   llvm_unreachable("Unhandled ArgKind");
 }
diff --git a/mlir/lib/Query/Matcher/VariantValue.cpp b/mlir/lib/Query/Matcher/VariantValue.cpp
index d5218d8dad8c9..d4f3e4f4d594d 100644
--- a/mlir/lib/Query/Matcher/VariantValue.cpp
+++ b/mlir/lib/Query/Matcher/VariantValue.cpp
@@ -56,8 +56,8 @@ VariantValue::VariantValue(const VariantMatcher &matcher)
   value.Matcher = new VariantMatcher(matcher);
 }
 
-VariantValue::VariantValue(unsigned Unsigned) : type(ValueType::Unsigned) {
-  value.Unsigned = Unsigned;
+VariantValue::VariantValue(int64_t signedValue) : type(ValueType::Signed) {
+  value.Signed = signedValue;
 }
 
 VariantValue::~VariantValue() { reset(); }
@@ -73,8 +73,8 @@ VariantValue &VariantValue::operator=(const VariantValue &other) {
   case ValueType::Matcher:
     setMatcher(other.getMatcher());
     break;
-  case ValueType::Unsigned:
-    setUnsigned(other.getUnsigned());
+  case ValueType::Signed:
+    setSigned(other.getSigned());
     break;
   case ValueType::Nothing:
     type = ValueType::Nothing;
@@ -92,7 +92,7 @@ void VariantValue::reset() {
     delete value.Matcher;
     break;
   // Cases that do nothing.
-  case ValueType::Unsigned:
+  case ValueType::Signed:
   case ValueType::Nothing:
     break;
   }
@@ -100,17 +100,17 @@ void VariantValue::reset() {
 }
 
 // Unsinged
-bool VariantValue::isUnsigned() const { return type == ValueType::Unsigned; }
+bool VariantValue::isSigned() const { return type == ValueType::Signed; }
 
-unsigned VariantValue::getUnsigned() const {
-  assert(isUnsigned());
-  return value.Unsigned;
+int64_t VariantValue::getSigned() const {
+  assert(isSigned());
+  return value.Signed;
 }
 
-void VariantValue::setUnsigned(unsigned newValue) {
+void VariantValue::setSigned(int64_t newValue) {
   reset();
-  type = ValueType::Unsigned;
-  value.Unsigned = newValue;
+  type = ValueType::Signed;
+  value.Signed = newValue;
 }
 
 bool VariantValue::isString() const { return type == ValueType::String; }
@@ -145,8 +145,8 @@ std::string VariantValue::getTypeAsString() const {
     return "String";
   case ValueType::Matcher:
     return "Matcher";
-  case ValueType::Unsigned:
-    return "Unsigned";
+  case ValueType::Signed:
+    return "Signed";
   case ValueType::Nothing:
     return "Nothing";
   }
diff --git a/mlir/test/mlir-query/complex-test.mlir b/mlir/test/mlir-query/complex-test.mlir
index b3df534ee8871..c5fee38327704 100644
--- a/mlir/test/mlir-query/complex-test.mlir
+++ b/mlir/test/mlir-query/complex-test.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-query %s -c "match getDefinitions(hasOpName("arith.addf"),2)" | FileCheck %s
+// RUN: mlir-query %s -c "match getDefinitions(hasOpName(\"arith.addf\"),2)" | FileCheck %s
 
 #map = affine_map<(d0, d1) -> (d0, d1)>
 func.func @slice_use_from_above(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>) {
@@ -22,18 +22,11 @@ func.func @slice_use_from_above(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>)
 
 // CHECK: %[[LINALG:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} 
 // CHECK-SAME: ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>)
-
 // CHECK: %[[ADDF1:.*]] = arith.addf %in, %in : f32
 
-
 // CHECK: Match #2:
 
 // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %0 {{\[\[.*\]\]}} : tensor<5x5xf32> into tensor<25xf32>
-// CHECK: %[[EXTRACTED:.*]] = tensor.extract %[[COLLAPSED]][%c2] : tensor<25xf32>
+// CHECK: %[[C2:.*]] = arith.constant {{.*}} : index
+// CHECK: %[[EXTRACTED:.*]] = tensor.extract %[[COLLAPSED]][%[[C2]]] : tensor<25xf32>
 // CHECK: %[[ADDF2:.*]] = arith.addf %[[EXTRACTED]], %[[EXTRACTED]] : f32
-
-
-
-
-
-
diff --git a/mlir/test/mlir-query/function-extraction.mlir b/mlir/test/mlir-query/function-extraction.mlir
index 5a20c09d02eb6..a783f65c6761b 100644
--- a/mlir/test/mlir-query/function-extraction.mlir
+++ b/mlir/test/mlir-query/function-extraction.mlir
@@ -1,10 +1,10 @@
-// RUN: mlir-query %s -c "m hasOpName("arith.mulf").extract("testmul")" | FileCheck %s
+// RUN: mlir-query %s -c "m hasOpName(\"arith.mulf\").extract(\"testmul\")" | FileCheck %s
 
 // CHECK: func.func @testmul({{.*}}) -> (f32, f32, f32) {
 // CHECK:       %[[MUL0:.*]] = arith.mulf {{.*}} : f32
 // CHECK:       %[[MUL1:.*]] = arith.mulf {{.*}}, %[[MUL0]] : f32
 // CHECK:       %[[MUL2:.*]] = arith.mulf {{.*}} : f32
-// CHECK-NEXT:  return %[[MUL0]], %[[MUL1]], %[[MUL2]] : f32, f32, f32S
+// CHECK-NEXT:  return %[[MUL0]], %[[MUL1]], %[[MUL2]] : f32, f32, f32
 
 func.func @mixedOperations(%a: f32, %b: f32, %c: f32) -> f32 {
   %sum0 = arith.addf %a, %b : f32
@@ -16,4 +16,4 @@ func.func @mixedOperations(%a: f32, %b: f32, %c: f32) -> f32 {
   %sum2 = arith.addf %mul1, %b : f32
   %mul2 = arith.mulf %sub2, %sum2 : f32
   return %mul2 : f32
-    }
+}
diff --git a/mlir/tools/mlir-query/mlir-query.cpp b/mlir/tools/mlir-query/mlir-query.cpp
index 468f948bec24c..91714aab33699 100644
--- a/mlir/tools/mlir-query/mlir-query.cpp
+++ b/mlir/tools/mlir-query/mlir-query.cpp
@@ -14,7 +14,6 @@
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/InitAllDialects.h"
-#include "mlir/Query/Matcher/ExtraMatchers.h"
 #include "mlir/Query/Matcher/MatchFinder.h"
 #include "mlir/Query/Matcher/Registry.h"
 #include "mlir/Tools/mlir-query/MlirQueryMain.h"
@@ -40,14 +39,10 @@ int main(int argc, char **argv) {
   query::matcher::Registry matcherRegistry;
 
   // Matchers registered in alphabetical order for consistency:
-  matcherRegistry.registerMatcher("getDefinitions",
-                                  mlir::query::extramatcher::getDefinitions);
-  matcherRegistry.registerMatcher("definedBy",
-                                  mlir::query::extramatcher::definedBy);
-  matcherRegistry.registerMatcher("usedBy", mlir::query::extramatcher::usedBy);
-  matcherRegistry.registerMatcher("getUses",
-                                  mlir::query::extramatcher::getUses);
-
+  matcherRegistry.registerMatcher("getDefinitions", m_GetDefinitions);
+  matcherRegistry.registerMatcher("definedBy", m_DefinedBy);
+  matcherRegistry.registerMatcher("usedBy", m_UsedBy);
+  matcherRegistry.registerMatcher("getUses", m_GetUses);
   matcherRegistry.registerMatcher("hasOpAttrName",
                                   static_cast<HasOpAttrName *>(m_Attr));
   matcherRegistry.registerMatcher("hasOpName", static_cast<HasOpName *>(m_Op));

>From e6bc9b34a1fd45100afb0c8b23d928ecc659c681 Mon Sep 17 00:00:00 2001
From: chios202 <chio.star at yahoo.com>
Date: Sat, 29 Mar 2025 15:12:50 +0000
Subject: [PATCH 5/5] Implement nested slicing matcher & enhance MatchFinder
 class 	- nested slicing matcher 	- enhance MatchFinder class 	-
 rename getSlice static method to avoid collision with SliceAnalysis::getSlice

---
 mlir/include/mlir/IR/Matchers.h               | 225 ++----------------
 mlir/include/mlir/Query/Matcher/MatchFinder.h |  97 ++------
 .../mlir/Query/Matcher/MatchersInternal.h     |   6 +-
 mlir/include/mlir/Query/Query.h               |  20 +-
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      |   6 +-
 mlir/lib/IR/CMakeLists.txt                    |   1 +
 mlir/lib/IR/Matchers.cpp                      |  57 +++++
 mlir/lib/Query/Matcher/CMakeLists.txt         |   1 +
 mlir/lib/Query/Matcher/MatchFinder.cpp        |  72 ++++++
 mlir/lib/Query/Matcher/Parser.cpp             |  24 +-
 mlir/lib/Query/Matcher/VariantValue.cpp       |   7 +-
 mlir/lib/Query/Query.cpp                      |  22 +-
 mlir/lib/Query/QueryParser.cpp                |   2 +-
 mlir/test/mlir-query/complex-test.mlir        |   2 +-
 mlir/tools/mlir-query/mlir-query.cpp          |   2 -
 15 files changed, 208 insertions(+), 336 deletions(-)
 create mode 100644 mlir/lib/IR/Matchers.cpp
 create mode 100644 mlir/lib/Query/Matcher/MatchFinder.cpp

diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index ee9e2afb10bad..5ea91e64fa93c 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -21,8 +21,6 @@
 #include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Query/Matcher/MatchersInternal.h"
 #include "mlir/Query/Query.h"
-#include "llvm/ADT/SetVector.h"
-
 namespace mlir {
 
 namespace detail {
@@ -366,21 +364,14 @@ struct RecursivePatternMatcher {
   std::tuple<OperandMatchers...> operandMatchers;
 };
 
-/// Fills `backwardSlice` with the computed backward slice (i.e.
-/// all the transitive defs of op)
-///
-/// The implementation traverses the def chains in postorder traversal for
-/// efficiency reasons: if an operation is already in `backwardSlice`, no
-/// need to traverse its definitions again. Since use-def chains form a DAG,
-/// this terminates.
-///
-/// Upon return to the root call, `backwardSlice` is filled with a
-/// postorder list of defs. This happens to be a topological order, from the
-/// point of view of the use-def chains.
+/// A matcher encapsulating the initial `getBackwardSlice` method from
+/// SliceAnalysis.h
+/// Additionally, it limits the slice computation to a certain depth level using
+/// a custom filter
 ///
-/// Example starting from node 8
+/// Example starting from node 9, assuming the matcher
+/// computes the slice for the first two depth levels
 /// ============================
-///
 ///    1       2      3      4
 ///    |_______|      |______|
 ///    |   |             |
@@ -393,240 +384,52 @@ struct RecursivePatternMatcher {
 ///              9
 ///
 /// Assuming all local orders match the numbering order:
-///    {1, 2, 5, 3, 4, 6}
-///
-
+///     {5, 7, 6, 8, 9}
 class BackwardSliceMatcher {
 public:
-  BackwardSliceMatcher(mlir::query::matcher::DynMatcher &&innerMatcher,
+  BackwardSliceMatcher(query::matcher::DynMatcher &&innerMatcher,
                        int64_t maxDepth)
       : innerMatcher(std::move(innerMatcher)), maxDepth(maxDepth) {}
-
   bool match(Operation *op, SetVector<Operation *> &backwardSlice,
-             mlir::query::QueryOptions &options) {
+             query::QueryOptions &options) {
 
     if (innerMatcher.match(op) &&
         matches(op, backwardSlice, options, maxDepth)) {
-      if (!options.inclusive) {
-        // Don't insert the top level operation, we just queried on it and don't
-        // want it in the results.
-        backwardSlice.remove(op);
-      }
       return true;
     }
     return false;
   }
 
 private:
-  bool matches(Operation *op, SetVector<Operation *> &backwardSlice,
-               mlir::query::QueryOptions &options, int64_t remainingDepth) {
-
-    if (op->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
-      return false;
-    }
-
-    auto processValue = [&](Value value) {
-      // We need to check the current depth level;
-      // if we have reached level 0, we stop further traversing
-      if (remainingDepth == 0) {
-        return;
-      }
-      if (auto *definingOp = value.getDefiningOp()) {
-        // We omit traversing the same operations
-        if (backwardSlice.count(definingOp) == 0)
-          matches(definingOp, backwardSlice, options, remainingDepth - 1);
-      } else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
-        if (options.omitBlockArguments)
-          return;
-        Block *block = blockArg.getOwner();
-
-        Operation *parentOp = block->getParentOp();
-        // TODO: determine whether we want to recurse backward into the other
-        // blocks of parentOp, which are not technically backward unless they
-        // flow into us. For now, just bail.
-        if (parentOp && backwardSlice.count(parentOp) == 0) {
-          if (parentOp->getNumRegions() != 1 &&
-              parentOp->getRegion(0).getBlocks().size() != 1) {
-            llvm::errs()
-                << "Error: Expected parentOp to have exactly one region and "
-                << "exactly one block, but found " << parentOp->getNumRegions()
-                << " regions and "
-                << (parentOp->getRegion(0).getBlocks().size()) << " blocks.\n";
-          };
-          matches(parentOp, backwardSlice, options, remainingDepth - 1);
-        }
-      } else {
-        llvm_unreachable("No definingOp and not a block argument\n");
-        return;
-      }
-    };
-
-    if (!options.omitUsesFromAbove) {
-      llvm::for_each(op->getRegions(), [&](Region &region) {
-        // Walk this region recursively to collect the regions that descend from
-        // this op's nested regions (inclusive).
-        SmallPtrSet<Region *, 4> descendents;
-        region.walk(
-            [&](Region *childRegion) { descendents.insert(childRegion); });
-        region.walk([&](Operation *op) {
-          for (OpOperand &operand : op->getOpOperands()) {
-            if (!descendents.contains(operand.get().getParentRegion()))
-              processValue(operand.get());
-          }
-        });
-      });
-    }
-
-    llvm::for_each(op->getOperands(), processValue);
-    backwardSlice.insert(op);
-    return true;
-  }
+  bool matches(Operation *rootOp, llvm::SetVector<Operation *> &backwardSlice,
+               query::QueryOptions &options, int64_t maxDepth);
 
 private:
   // The outer matcher (e.g., BackwardSliceMatcher) relies on the innerMatcher
   // to determine whether we want to traverse the DAG or not. For example, we
   // want to explore the DAG only if the top-level operation name is
   // "arith.addf".
-  mlir::query::matcher::DynMatcher innerMatcher;
-
+  query::matcher::DynMatcher innerMatcher;
   // maxDepth specifies the maximum depth that the matcher can traverse in the
   // DAG. For example, if maxDepth is 2, the matcher will explore the defining
   // operations of the top-level op up to 2 levels.
   int64_t maxDepth;
 };
-
-/// Fills `forwardSlice` with the computed forward slice (i.e. all
-/// the transitive uses of op)
-///
-///
-/// The implementation traverses the use chains in postorder traversal for
-/// efficiency reasons: if an operation is already in `forwardSlice`, no
-/// need to traverse its uses again. Since use-def chains form a DAG, this
-/// terminates.
-///
-/// Upon return to the root call, `forwardSlice` is filled with a
-/// postorder list of uses (i.e. a reverse topological order). To get a proper
-/// topological order, we just reverse the order in `forwardSlice` before
-/// returning.
-///
-/// Example starting from node 0
-/// ============================
-///
-///               0
-///    ___________|___________
-///    1       2      3      4
-///    |_______|      |______|
-///    |   |             |
-///    |   5             6
-///    |___|_____________|
-///      |               |
-///      7               8
-///      |_______________|
-///              |
-///              9
-///
-/// Assuming all local orders match the numbering order:
-/// 1. after getting back to the root getForwardSlice, `forwardSlice` may
-///    contain:
-///      {9, 7, 8, 5, 1, 2, 6, 3, 4}
-/// 2. reversing the result of 1. gives:
-///      {4, 3, 6, 2, 1, 5, 8, 7, 9}
-///
-class ForwardSliceMatcher {
-public:
-  ForwardSliceMatcher(mlir::query::matcher::DynMatcher &&innerMatcher,
-                      int64_t maxDepth)
-      : innerMatcher(std::move(innerMatcher)), maxDepth(maxDepth) {}
-
-  bool match(Operation *op, SetVector<Operation *> &forwardSlice,
-             mlir::query::QueryOptions &options) {
-    if (innerMatcher.match(op) &&
-        matches(op, forwardSlice, options, maxDepth)) {
-      if (!options.inclusive) {
-        // Don't insert the top level operation, we just queried on it and don't
-        // want it in the results.
-        forwardSlice.remove(op);
-      }
-      // Reverse to get back the actual topological order.
-      // std::reverse does not work out of the box on SetVector and I want an
-      // in-place swap based thing (the real std::reverse, not the LLVM
-      // adapter).
-      SmallVector<Operation *, 0> v(forwardSlice.takeVector());
-      forwardSlice.insert(v.rbegin(), v.rend());
-      return true;
-    }
-    return false;
-  }
-
-private:
-  bool matches(Operation *op, SetVector<Operation *> &forwardSlice,
-               mlir::query::QueryOptions &options, int64_t remainingDepth) {
-
-    // We need to check the current depth level;
-    // if we have reached level 0, we stop further traversing and insert
-    // the last user in def-use chain
-    if (remainingDepth == 0) {
-      forwardSlice.insert(op);
-      return true;
-    }
-
-    for (Region &region : op->getRegions())
-      for (Block &block : region)
-        for (Operation &blockOp : block)
-          if (forwardSlice.count(&blockOp) == 0)
-            matches(&blockOp, forwardSlice, options, remainingDepth - 1);
-    for (Value result : op->getResults()) {
-      for (Operation *userOp : result.getUsers())
-        // We omit traversing the same operations
-        if (forwardSlice.count(userOp) == 0)
-          matches(userOp, forwardSlice, options, remainingDepth - 1);
-    }
-
-    forwardSlice.insert(op);
-    return true;
-  }
-
-private:
-  // The outer matcher e.g (ForwardSliceMatcher) relies on the innerMatcher to
-  // determine whether we want to traverse the graph or not. E.g: we want to
-  // explore the DAG only if the top level operation name is "arith.addf"
-  mlir::query::matcher::DynMatcher innerMatcher;
-
-  // maxDepth specifies the maximum depth that the matcher can traverse the
-  // graph E.g: if maxDepth is 2, the matcher will explore the user
-  // operations of the top level op up to 2 levels
-  int64_t maxDepth;
-};
-
 } // namespace detail
 
 // Matches transitive defs of a top level operation up to 1 level
 inline detail::BackwardSliceMatcher
-m_DefinedBy(mlir::query::matcher::DynMatcher innerMatcher) {
+m_DefinedBy(query::matcher::DynMatcher innerMatcher) {
   return detail::BackwardSliceMatcher(std::move(innerMatcher), 1);
 }
 
 // Matches transitive defs of a top level operation up to N levels
 inline detail::BackwardSliceMatcher
-m_GetDefinitions(mlir::query::matcher::DynMatcher innerMatcher,
-                 int64_t maxDepth) {
+m_GetDefinitions(query::matcher::DynMatcher innerMatcher, int64_t maxDepth) {
   assert(maxDepth >= 0 && "maxDepth must be non-negative");
   return detail::BackwardSliceMatcher(std::move(innerMatcher), maxDepth);
 }
 
-// Matches uses of a top level operation up to 1 level
-inline detail::ForwardSliceMatcher
-m_UsedBy(mlir::query::matcher::DynMatcher innerMatcher) {
-  return detail::ForwardSliceMatcher(std::move(innerMatcher), 1);
-}
-
-// Matches uses of a top level operation up to N  levels
-inline detail::ForwardSliceMatcher
-m_GetUses(mlir::query::matcher::DynMatcher innerMatcher, int64_t maxDepth) {
-  assert(maxDepth >= 0 && "maxDepth must be non-negative");
-  return detail::ForwardSliceMatcher(std::move(innerMatcher), maxDepth);
-}
-
 /// Matches a constant foldable operation.
 inline detail::constant_op_matcher m_Constant() {
   return detail::constant_op_matcher();
diff --git a/mlir/include/mlir/Query/Matcher/MatchFinder.h b/mlir/include/mlir/Query/Matcher/MatchFinder.h
index 1d64f894bb8a1..3591cf05e7599 100644
--- a/mlir/include/mlir/Query/Matcher/MatchFinder.h
+++ b/mlir/include/mlir/Query/Matcher/MatchFinder.h
@@ -15,86 +15,41 @@
 #define MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H
 
 #include "MatchersInternal.h"
+#include "mlir/Query/Query.h"
 #include "mlir/Query/QuerySession.h"
-#include "llvm/ADT/SetVector.h"
-#include "llvm/Support/SourceMgr.h"
-#include "llvm/Support/raw_ostream.h"
 
 namespace mlir::query::matcher {
 
+/// A class that provides utilities to find operations in a DAG
 class MatchFinder {
 
 public:
-  //
-  // getMatches walks the IR and prints operations as soon as it matches them
-  // if a matcher is to be further extracted into the function, then it does not
-  // print operations
-  //
-  static std::vector<Operation *>
-  getMatches(Operation *root, QueryOptions &options, DynMatcher matcher,
-             llvm::raw_ostream &os, QuerySession &qs) {
-    int matchCount = 0;
-    bool printMatchingOps = true;
-    // If matcher is to be extracted to a function, we don't want to print
-    // matching ops to sdout
-    if (matcher.hasFunctionName()) {
-      printMatchingOps = false;
-    }
-    std::vector<Operation *> matchedOps;
-    SetVector<Operation *> tempStorage;
-    os << "\n";
-    root->walk([&](Operation *subOp) {
-      if (matcher.match(subOp)) {
-        matchedOps.push_back(subOp);
-        if (printMatchingOps) {
-          os << "Match #" << ++matchCount << ":\n\n";
-          printMatch(os, qs, subOp, "root");
-        }
-      } else {
-        SmallVector<Operation *> printingOps;
-        if (matcher.match(subOp, tempStorage, options)) {
-          if (printMatchingOps) {
-            os << "Match #" << ++matchCount << ":\n\n";
-          }
-          SmallVector<Operation *> printingOps(tempStorage.takeVector());
-          for (auto op : printingOps) {
-            if (printMatchingOps) {
-              printMatch(os, qs, op, "root");
-            }
-            matchedOps.push_back(op);
-          }
-          printingOps.clear();
-        }
-      }
-    });
-    if (printMatchingOps) {
-      os << matchCount << (matchCount == 1 ? " match.\n\n" : " matches.\n\n");
-    }
-    return matchedOps;
-  }
-
-private:
-  // Overloaded version that doesn't print the binding
-  static void printMatch(llvm::raw_ostream &os, QuerySession &qs,
-                         mlir::Operation *op) {
-    auto fileLoc = op->getLoc()->dyn_cast<FileLineColLoc>();
-    SMLoc smloc = qs.getSourceManager().FindLocForLineAndColumn(
-        qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn());
+  /// A subclass which preserves the matching information
+  struct MatchResult {
+    MatchResult() = default;
+    MatchResult(Operation *rootOp, std::vector<Operation *> matchedOps);
 
-    llvm::SMDiagnostic diag =
-        qs.getSourceManager().GetMessage(smloc, llvm::SourceMgr::DK_Note,
+    /// Contains the root operation of the matching environment
+    Operation *rootOp = nullptr;
 
-                                         "");
-    diag.print("", os, true, false, true);
-  }
-  static void printMatch(llvm::raw_ostream &os, QuerySession &qs,
-                         mlir::Operation *op, const std::string &binding) {
-    auto fileLoc = op->getLoc()->findInstanceOf<FileLineColLoc>();
-    auto smloc = qs.getSourceManager().FindLocForLineAndColumn(
-        qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn());
-    qs.getSourceManager().PrintMessage(os, smloc, llvm::SourceMgr::DK_Note,
-                                       "\"" + binding + "\" binds here");
-  }
+    /// Contains the matching enviroment. This allows the user to easily extract
+    /// the matched operations
+    std::vector<Operation *> matchedOps;
+  };
+  /// Traverses the DAG and collects the "rootOp" + "matching enviroment" for a
+  /// given Matcher
+  std::vector<MatchResult>
+  collectMatches(Operation *root, QueryOptions &options, DynMatcher matcher,
+                 llvm::raw_ostream &os, QuerySession &qs) const;
+  /// Prints the matched operation
+  void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op) const;
+  /// Labels the matched operation with the given binding (e.g., "root") and
+  /// prints it
+  void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op,
+                  const std::string &binding) const;
+  /// Flattens a vector of MatchResults into a vector of operations
+  std::vector<Operation *>
+  flattenMatchedOps(std::vector<MatchResult> &matches) const;
 };
 
 } // namespace mlir::query::matcher
diff --git a/mlir/include/mlir/Query/Matcher/MatchersInternal.h b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
index c5c24190f0e7f..e26697cdc4ae8 100644
--- a/mlir/include/mlir/Query/Matcher/MatchersInternal.h
+++ b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
@@ -21,15 +21,13 @@
 #ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
 #define MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
 
+#include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/IR/Operation.h"
 #include "llvm/ADT/IntrusiveRefCntPtr.h"
 
-namespace mlir {
-namespace query {
+namespace mlir::query {
 struct QueryOptions;
 }
-} // namespace mlir
-
 namespace mlir::query::matcher {
 
 // Defaults to false if T has no match() method with the signature:
diff --git a/mlir/include/mlir/Query/Query.h b/mlir/include/mlir/Query/Query.h
index 77114a2a00e23..5644113ba9e18 100644
--- a/mlir/include/mlir/Query/Query.h
+++ b/mlir/include/mlir/Query/Query.h
@@ -10,6 +10,7 @@
 #define MLIR_TOOLS_MLIRQUERY_QUERY_H
 
 #include "Matcher/VariantValue.h"
+#include "mlir/Analysis/SliceAnalysis.h"
 #include "llvm/ADT/IntrusiveRefCntPtr.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/LineEditor/LineEditor.h"
@@ -17,22 +18,9 @@
 
 namespace mlir::query {
 
-///
-/// Options for configuring which parts of the IR are to be
-/// traversed by the matcher
-///
-struct QueryOptions {
-  /// When omitBlockArguments is true, the matcher omits traversing
-  /// any block arguments
-  bool omitBlockArguments = false;
-  /// When omitUsesFromAbove is true, the matcher omits
-  /// traversing values that are captured from above.
-  bool omitUsesFromAbove = true;
-  /// When inclusive is true, the matcher will include the include the
-  /// top level op in the slice. When inclusive is false, the matcher will
-  /// not include thee top level op in the slice
-  bool inclusive = true;
-};
+/// QueryOptions is a class derived from BackwardSliceOptions
+/// Addtional options can be added for further customization
+struct QueryOptions : public BackwardSliceOptions {};
 
 enum class QueryKind {
   Invalid,
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index ff89ead59981c..c1e3942432352 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -74,7 +74,7 @@ static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v,
 
 /// Returns a memref.subview or a tensor.extract_slice based on the type of the
 /// `source`.
-static Operation *getSlice(OpBuilder &b, Location loc, Value source,
+static Operation *getSubviewOrSlice(OpBuilder &b, Location loc, Value source,
                            ArrayRef<OpFoldResult> offsets,
                            ArrayRef<OpFoldResult> sizes,
                            ArrayRef<OpFoldResult> strides) {
@@ -2675,13 +2675,13 @@ SoftmaxOp::getTiledImplementation(OpBuilder &builder,
   SmallVector<OpFoldResult> strides(rank, oneAttr);
   SmallVector<Value> tiledOperands;
   Operation *inputSlice =
-      getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
+      getSubviewOrSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
   if (!inputSlice) {
     return emitOpError("failed to compute input slice");
   }
   tiledOperands.emplace_back(inputSlice->getResult(0));
   Operation *outputSlice =
-      getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
+      getSubviewOrSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
   if (!outputSlice) {
     return emitOpError("failed to compute output slice");
   }
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index 4cabac185171c..c6c44260fe776 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -25,6 +25,7 @@ add_mlir_library(MLIRIR
   ExtensibleDialect.cpp
   IntegerSet.cpp
   Location.cpp
+  Matchers.cpp
   MLIRContext.cpp
   ODSSupport.cpp
   Operation.cpp
diff --git a/mlir/lib/IR/Matchers.cpp b/mlir/lib/IR/Matchers.cpp
new file mode 100644
index 0000000000000..055f0a17527db
--- /dev/null
+++ b/mlir/lib/IR/Matchers.cpp
@@ -0,0 +1,57 @@
+//===- Matchers.h - Various common matchers ---------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements specific matchers
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Matchers.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+
+namespace mlir::detail {
+
+bool BackwardSliceMatcher::matches(Operation *rootOp,
+                                   llvm::SetVector<Operation *> &backwardSlice,
+                                   query::QueryOptions &options,
+                                   int64_t maxDepth) {
+  backwardSlice.clear();
+  llvm::DenseMap<Operation *, int64_t> opDepths;
+  // Initialize the map with the root operation
+  // and set its depth to 0
+  opDepths[rootOp] = 0;
+  options.filter = [&](Operation *op) {
+    if (opDepths[op] > maxDepth)
+      return false;
+    // Begins by checking the previous operation's arguments
+    // and computing their depth
+    for (auto operand : op->getOperands()) {
+      if (auto definingOp = operand.getDefiningOp()) {
+        // If the operation is in the map, it means
+        // we have already computed its depth
+        if (!opDepths.contains(definingOp)) {
+          // The operation's depth is 1 level above its root op
+          opDepths[definingOp] = opDepths[op] + 1;
+          if (opDepths[op] > maxDepth)
+            return false;
+        }
+      } else {
+        auto blockArgument = cast<BlockArgument>(operand);
+        Operation *parentOp = blockArgument.getOwner()->getParentOp();
+        if (!opDepths.contains(parentOp)) {
+          opDepths[parentOp] = opDepths[op] + 1;
+          if (opDepths[op] > maxDepth)
+            return false;
+        }
+      }
+    }
+    return true;
+  };
+  getBackwardSlice(rootOp, &backwardSlice, options);
+  return true;
+}
+} // namespace mlir::detail
diff --git a/mlir/lib/Query/Matcher/CMakeLists.txt b/mlir/lib/Query/Matcher/CMakeLists.txt
index 3adff9f99243f..629479bf7adc1 100644
--- a/mlir/lib/Query/Matcher/CMakeLists.txt
+++ b/mlir/lib/Query/Matcher/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_library(MLIRQueryMatcher
+  MatchFinder.cpp
   Parser.cpp
   RegistryManager.cpp
   VariantValue.cpp
diff --git a/mlir/lib/Query/Matcher/MatchFinder.cpp b/mlir/lib/Query/Matcher/MatchFinder.cpp
new file mode 100644
index 0000000000000..b0a95660c1d59
--- /dev/null
+++ b/mlir/lib/Query/Matcher/MatchFinder.cpp
@@ -0,0 +1,72 @@
+//===- MatchFinder.cpp - -----------------------------------------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the method definitions for the `MatchFinder` class
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Query/Matcher/MatchFinder.h"
+namespace mlir::query::matcher {
+
+MatchFinder::MatchResult::MatchResult(Operation *rootOp,
+                                      std::vector<Operation *> matchedOps)
+    : rootOp(rootOp), matchedOps(std::move(matchedOps)) {}
+
+std::vector<MatchFinder::MatchResult>
+MatchFinder::collectMatches(Operation *root, QueryOptions &options,
+                            DynMatcher matcher, llvm::raw_ostream &os,
+                            QuerySession &qs) const {
+  std::vector<MatchResult> results;
+  llvm::SetVector<Operation *> tempStorage;
+  os << "\n";
+  root->walk([&](Operation *subOp) {
+    if (matcher.match(subOp)) {
+      MatchResult match;
+      match.rootOp = subOp;
+      match.matchedOps.push_back(subOp);
+      results.push_back(std::move(match));
+    } else if (matcher.match(subOp, tempStorage, options)) {
+      results.emplace_back(subOp, std::vector<Operation *>(tempStorage.begin(),
+                                                           tempStorage.end()));
+    }
+    tempStorage.clear();
+  });
+  return results;
+}
+
+void MatchFinder::printMatch(llvm::raw_ostream &os, QuerySession &qs,
+                             Operation *op) const {
+  auto fileLoc = dyn_cast<FileLineColLoc>(op->getLoc());
+  SMLoc smloc = qs.getSourceManager().FindLocForLineAndColumn(
+      qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn());
+
+  llvm::SMDiagnostic diag =
+      qs.getSourceManager().GetMessage(smloc, llvm::SourceMgr::DK_Note, "");
+  diag.print("", os, true, false, true);
+}
+
+void MatchFinder::printMatch(llvm::raw_ostream &os, QuerySession &qs,
+                             Operation *op, const std::string &binding) const {
+  auto fileLoc = op->getLoc()->findInstanceOf<FileLineColLoc>();
+  auto smloc = qs.getSourceManager().FindLocForLineAndColumn(
+      qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn());
+  qs.getSourceManager().PrintMessage(os, smloc, llvm::SourceMgr::DK_Note,
+                                     "\"" + binding + "\" binds here");
+}
+
+std::vector<Operation *>
+MatchFinder::flattenMatchedOps(std::vector<MatchResult> &matches) const {
+  std::vector<Operation *> newVector;
+  for (auto &result : matches) {
+    newVector.insert(newVector.end(), result.matchedOps.begin(),
+                     result.matchedOps.end());
+  }
+  return newVector;
+}
+
+} // namespace mlir::query::matcher
\ No newline at end of file
diff --git a/mlir/lib/Query/Matcher/Parser.cpp b/mlir/lib/Query/Matcher/Parser.cpp
index 726f1188d7e4c..a82af80dbdb0c 100644
--- a/mlir/lib/Query/Matcher/Parser.cpp
+++ b/mlir/lib/Query/Matcher/Parser.cpp
@@ -157,25 +157,13 @@ class Parser::CodeTokenizer {
   }
 
   void consumeNumberLiteral(TokenInfo *result) {
-    unsigned length = 1;
-    if (code.size() > 1) {
-      // Consume the 'x' or 'b' radix modifier, if present.
-      switch (tolower(code[1])) {
-      case 'x':
-      case 'b':
-        length = 2;
-      }
-    }
-    while (length < code.size() && isdigit(code[length]))
-      ++length;
-
-    result->text = code.take_front(length);
-    code = code.drop_front(length);
-
-    unsigned value;
-    if (!result->text.getAsInteger(0, value)) {
+    StringRef original = code;
+    unsigned value = 0;
+    if (!code.consumeInteger(0, value)) {
+      size_t numConsumed = original.size() - code.size();
+      result->text = original.take_front(numConsumed);
       result->kind = TokenKind::Literal;
-      result->value = static_cast<unsigned>(value);
+      result->value = static_cast<int64_t>(value);
       return;
     }
   }
diff --git a/mlir/lib/Query/Matcher/VariantValue.cpp b/mlir/lib/Query/Matcher/VariantValue.cpp
index d4f3e4f4d594d..f2bf0f9065bbe 100644
--- a/mlir/lib/Query/Matcher/VariantValue.cpp
+++ b/mlir/lib/Query/Matcher/VariantValue.cpp
@@ -99,13 +99,10 @@ void VariantValue::reset() {
   type = ValueType::Nothing;
 }
 
-// Unsinged
+// Signed
 bool VariantValue::isSigned() const { return type == ValueType::Signed; }
 
-int64_t VariantValue::getSigned() const {
-  assert(isSigned());
-  return value.Signed;
-}
+int64_t VariantValue::getSigned() const { return value.Signed; }
 
 void VariantValue::setSigned(int64_t newValue) {
   reset();
diff --git a/mlir/lib/Query/Query.cpp b/mlir/lib/Query/Query.cpp
index dd699857568d7..7082fdb0f8482 100644
--- a/mlir/lib/Query/Query.cpp
+++ b/mlir/lib/Query/Query.cpp
@@ -129,23 +129,37 @@ LogicalResult QuitQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
 
 LogicalResult MatchQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
   Operation *rootOp = qs.getRootOp();
-
+  int matchCount = 0;
   QueryOptions options;
+  matcher::MatchFinder finder;
   parseQueryOptions(qs, options);
-  auto matches = matcher::MatchFinder().getMatches(rootOp, options,
-                                                   std::move(matcher), os, qs);
+  auto matches =
+      finder.collectMatches(rootOp, options, std::move(matcher), os, qs);
 
   // An extract call is recognized by considering if the matcher has a name.
   // TODO: Consider making the extract more explicit.
   if (matcher.hasFunctionName()) {
     auto functionName = matcher.getFunctionName();
+    std::vector<Operation *> flattenedMatches =
+        finder.flattenMatchedOps(matches);
     Operation *function =
-        extractFunction(matches, rootOp->getContext(), functionName);
+        extractFunction(flattenedMatches, rootOp->getContext(), functionName);
     os << "\n" << *function << "\n\n";
     function->erase();
     return mlir::success();
   }
 
+  for (auto &results : matches) {
+    os << "Match #" << ++matchCount << ":\n\n";
+    for (auto op : results.matchedOps) {
+      if (op == results.rootOp) {
+        finder.printMatch(os, qs, op, "root");
+      } else {
+        finder.printMatch(os, qs, op);
+      }
+    }
+  }
+  os << matchCount << (matchCount == 1 ? " match.\n\n" : " matches.\n\n");
   return mlir::success();
 }
 
diff --git a/mlir/lib/Query/QueryParser.cpp b/mlir/lib/Query/QueryParser.cpp
index 53e8f91e657cb..b7c6118575dc8 100644
--- a/mlir/lib/Query/QueryParser.cpp
+++ b/mlir/lib/Query/QueryParser.cpp
@@ -167,10 +167,10 @@ QueryRef QueryParser::doParse() {
           .Case("#", ParsedQueryKind::Comment, /*isCompletion=*/false)
           .Case("help", ParsedQueryKind::Help)
           .Case("m", ParsedQueryKind::Match, /*isCompletion=*/false)
-          .Case("set", ParsedQueryKind::Set)
           .Case("match", ParsedQueryKind::Match)
           .Case("q", ParsedQueryKind::Quit, /*IsCompletion=*/false)
           .Case("quit", ParsedQueryKind::Quit)
+          .Case("set", ParsedQueryKind::Set)
           .Default(ParsedQueryKind::Invalid);
 
   switch (qKind) {
diff --git a/mlir/test/mlir-query/complex-test.mlir b/mlir/test/mlir-query/complex-test.mlir
index c5fee38327704..e0f7ee3034ed9 100644
--- a/mlir/test/mlir-query/complex-test.mlir
+++ b/mlir/test/mlir-query/complex-test.mlir
@@ -29,4 +29,4 @@ func.func @slice_use_from_above(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>)
 // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %0 {{\[\[.*\]\]}} : tensor<5x5xf32> into tensor<25xf32>
 // CHECK: %[[C2:.*]] = arith.constant {{.*}} : index
 // CHECK: %[[EXTRACTED:.*]] = tensor.extract %[[COLLAPSED]][%[[C2]]] : tensor<25xf32>
-// CHECK: %[[ADDF2:.*]] = arith.addf %[[EXTRACTED]], %[[EXTRACTED]] : f32
+// CHECK: %[[ADDF2:.*]] = arith.addf %[[EXTRACTED]], %[[EXTRACTED]] : f32  
diff --git a/mlir/tools/mlir-query/mlir-query.cpp b/mlir/tools/mlir-query/mlir-query.cpp
index 91714aab33699..34fb7d1d80a8d 100644
--- a/mlir/tools/mlir-query/mlir-query.cpp
+++ b/mlir/tools/mlir-query/mlir-query.cpp
@@ -41,8 +41,6 @@ int main(int argc, char **argv) {
   // Matchers registered in alphabetical order for consistency:
   matcherRegistry.registerMatcher("getDefinitions", m_GetDefinitions);
   matcherRegistry.registerMatcher("definedBy", m_DefinedBy);
-  matcherRegistry.registerMatcher("usedBy", m_UsedBy);
-  matcherRegistry.registerMatcher("getUses", m_GetUses);
   matcherRegistry.registerMatcher("hasOpAttrName",
                                   static_cast<HasOpAttrName *>(m_Attr));
   matcherRegistry.registerMatcher("hasOpName", static_cast<HasOpName *>(m_Op));



More information about the Mlir-commits mailing list