[Mlir-commits] [mlir] [mlir] Improve mlir-query tool by implementing `getBackwardSlice` and `getForwardSlice` matchers (PR #115670)

Denzel-Brian Budii llvmlistbot at llvm.org
Sun May 4 05:20:40 PDT 2025


https://github.com/chios202 updated https://github.com/llvm/llvm-project/pull/115670

>From e07e1feb13ea9607424c6817808f02a2f313f867 Mon Sep 17 00:00:00 2001
From: chios202 <chio.star at yahoo.com>
Date: Tue, 22 Apr 2025 11:44:35 +0000
Subject: [PATCH 1/3] Compute shortest depth in backwardSlice method Relocate
 backwardSlice matcher to Query-specific headers Remove unnecessary code

---
 .../mlir/Query/Matcher/ExtraMatchers.h        | 85 +++++++++++++++++++
 mlir/include/mlir/Query/Matcher/Marshallers.h | 30 +++++++
 mlir/include/mlir/Query/Matcher/MatchFinder.h | 45 ++++++----
 .../mlir/Query/Matcher/MatchersInternal.h     | 59 ++++++++++---
 .../include/mlir/Query/Matcher/VariantValue.h | 21 ++++-
 mlir/lib/Query/Matcher/CMakeLists.txt         |  2 +
 mlir/lib/Query/Matcher/ExtraMatchers.cpp      | 66 ++++++++++++++
 mlir/lib/Query/Matcher/MatchFinder.cpp        | 68 +++++++++++++++
 mlir/lib/Query/Matcher/Parser.cpp             | 59 +++++++++++--
 mlir/lib/Query/Matcher/RegistryManager.cpp    |  9 +-
 mlir/lib/Query/Matcher/VariantValue.cpp       | 40 +++++++++
 mlir/lib/Query/Query.cpp                      | 30 +++----
 mlir/lib/Query/QueryParser.cpp                |  1 -
 mlir/test/mlir-query/complex-test.mlir        | 32 +++++++
 mlir/tools/mlir-query/mlir-query.cpp          |  3 +
 15 files changed, 493 insertions(+), 57 deletions(-)
 create mode 100644 mlir/include/mlir/Query/Matcher/ExtraMatchers.h
 create mode 100644 mlir/lib/Query/Matcher/ExtraMatchers.cpp
 create mode 100644 mlir/lib/Query/Matcher/MatchFinder.cpp
 create mode 100644 mlir/test/mlir-query/complex-test.mlir

diff --git a/mlir/include/mlir/Query/Matcher/ExtraMatchers.h b/mlir/include/mlir/Query/Matcher/ExtraMatchers.h
new file mode 100644
index 0000000000000..4766a767cf783
--- /dev/null
+++ b/mlir/include/mlir/Query/Matcher/ExtraMatchers.h
@@ -0,0 +1,91 @@
+//===- ExtraMatchers.h - Various common matchers --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file provides matchers that depend on Query.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TOOLS_MLIRQUERY_MATCHERS_EXTRAMATCHERS_H
+#define MLIR_TOOLS_MLIRQUERY_MATCHERS_EXTRAMATCHERS_H
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Query/Matcher/MatchersInternal.h"
+
+namespace mlir::query::matcher {
+
+/// Matcher built on top of `getBackwardSlice` from SliceAnalysis.h. A custom
+/// slice filter stops the traversal once it goes deeper than `maxDepth`
+/// levels below the root operation.
+///
+/// Example starting from node 9, assuming the matcher
+/// computes the slice for the first two depth levels
+/// ============================
+///    1       2      3      4
+///    |_______|      |______|
+///    |   |             |
+///    |   5             6
+///    |___|_____________|
+///      |               |
+///      7               8
+///      |_______________|
+///              |
+///              9
+///
+/// Assuming all local orders match the numbering order:
+///     {5, 7, 6, 8, 9}
+class BackwardSliceMatcher {
+public:
+  explicit BackwardSliceMatcher(query::matcher::DynMatcher &&innerMatcher,
+                                int64_t maxDepth, bool inclusive,
+                                bool omitBlockArguments, bool omitUsesFromAbove)
+      : innerMatcher(std::move(innerMatcher)), maxDepth(maxDepth),
+        inclusive(inclusive), omitBlockArguments(omitBlockArguments),
+        omitUsesFromAbove(omitUsesFromAbove) {}
+
+  /// Returns false, leaving `backwardSlice` untouched, when `innerMatcher`
+  /// rejects `op`. Otherwise fills `backwardSlice` with the depth-limited
+  /// backward slice of `op` and returns what `matches` returns.
+  bool match(Operation *op, SetVector<Operation *> &backwardSlice) {
+    if (!innerMatcher.match(op))
+      return false;
+    BackwardSliceOptions sliceOptions;
+    return matches(op, backwardSlice, sliceOptions, maxDepth);
+  }
+
+private:
+  /// Computes the backward slice of `rootOp` into `backwardSlice`, after
+  /// configuring `options` from the members below.
+  bool matches(Operation *rootOp, llvm::SetVector<Operation *> &backwardSlice,
+               BackwardSliceOptions &options, int64_t maxDepth);
+
+  /// Evaluated on the root op first; the slice is only computed for ops it
+  /// accepts (e.g. ops named "arith.addf").
+  query::matcher::DynMatcher innerMatcher;
+  /// Maximum number of def-use levels explored above the root op; e.g. with
+  /// 2, the root's producers and their producers are visited.
+  int64_t maxDepth;
+  /// Copied into BackwardSliceOptions::inclusive by `matches`.
+  bool inclusive;
+  /// Copied into BackwardSliceOptions::omitBlockArguments by `matches`.
+  bool omitBlockArguments;
+  /// Copied into BackwardSliceOptions::omitUsesFromAbove by `matches`.
+  bool omitUsesFromAbove;
+};
+
+/// Matches transitive definitions of a top-level operation up to `maxDepth`
+/// levels; `maxDepth` must be non-negative (asserted).
+inline BackwardSliceMatcher
+m_GetDefinitions(query::matcher::DynMatcher innerMatcher, int64_t maxDepth,
+                 bool inclusive, bool omitBlockArguments,
+                 bool omitUsesFromAbove) {
+  assert(maxDepth >= 0 && "maxDepth must be non-negative");
+  return BackwardSliceMatcher(std::move(innerMatcher), maxDepth, inclusive,
+                              omitBlockArguments, omitUsesFromAbove);
+}
+} // namespace mlir::query::matcher
+
+#endif // MLIR_TOOLS_MLIRQUERY_MATCHERS_EXTRAMATCHERS_H
diff --git a/mlir/include/mlir/Query/Matcher/Marshallers.h b/mlir/include/mlir/Query/Matcher/Marshallers.h
index 6ed35ac0ddccc..012bf7b9ec4a9 100644
--- a/mlir/include/mlir/Query/Matcher/Marshallers.h
+++ b/mlir/include/mlir/Query/Matcher/Marshallers.h
@@ -50,6 +50,38 @@ struct ArgTypeTraits<llvm::StringRef> {
   }
 };
 
+/// Signed integer arguments; `get` yields the full int64_t without narrowing.
+template <>
+struct ArgTypeTraits<int64_t> {
+  static bool hasCorrectType(const VariantValue &value) {
+    return value.isSigned();
+  }
+
+  static int64_t get(const VariantValue &value) { return value.getSigned(); }
+
+  static ArgKind getKind() { return ArgKind::Signed; }
+
+  static std::optional<std::string> getBestGuess(const VariantValue &) {
+    return std::nullopt;
+  }
+};
+
+/// Boolean arguments, e.g. the parser's `true` / `false` literals.
+template <>
+struct ArgTypeTraits<bool> {
+  static bool hasCorrectType(const VariantValue &value) {
+    return value.isBoolean();
+  }
+
+  static bool get(const VariantValue &value) { return value.getBoolean(); }
+
+  static ArgKind getKind() { return ArgKind::Boolean; }
+
+  static std::optional<std::string> getBestGuess(const VariantValue &) {
+    return std::nullopt;
+  }
+};
+
 template <>
 struct ArgTypeTraits<DynMatcher> {
 
diff --git a/mlir/include/mlir/Query/Matcher/MatchFinder.h b/mlir/include/mlir/Query/Matcher/MatchFinder.h
index b008a21f53ae2..6b554394b3654 100644
--- a/mlir/include/mlir/Query/Matcher/MatchFinder.h
+++ b/mlir/include/mlir/Query/Matcher/MatchFinder.h
@@ -7,7 +7,7 @@
 //===----------------------------------------------------------------------===//
 //
 // This file contains the MatchFinder class, which is used to find operations
-// that match a given matcher.
+// that match a given matcher and print them.
 //
 //===----------------------------------------------------------------------===//
 
@@ -15,25 +15,40 @@
 #define MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H
 
 #include "MatchersInternal.h"
+#include "mlir/Query/Query.h"
+#include "mlir/Query/QuerySession.h"
+#include "llvm/ADT/SetVector.h"
 
 namespace mlir::query::matcher {
 
-// MatchFinder is used to find all operations that match a given matcher.
+/// Collects the operations accepted by a DynMatcher and prints them.
 class MatchFinder {
+
 public:
-  // Returns all operations that match the given matcher.
-  static std::vector<Operation *> getMatches(Operation *root,
-                                             DynMatcher matcher) {
-    std::vector<Operation *> matches;
-
-    // Simple match finding with walk.
-    root->walk([&](Operation *subOp) {
-      if (matcher.match(subOp))
-        matches.push_back(subOp);
-    });
-
-    return matches;
-  }
+  /// One match: the operation the matcher accepted and the ops it collected.
+  struct MatchResult {
+    MatchResult() = default;
+    MatchResult(Operation *rootOp, std::vector<Operation *> matchedOps);
+
+    /// The operation the matcher was applied to and accepted.
+    Operation *rootOp = nullptr;
+    /// Operations belonging to this match: just `rootOp` for matchers with a
+    /// plain match(Operation *), otherwise the ops the matcher collected.
+    std::vector<Operation *> matchedOps;
+  };
+  /// Walks the operations nested under `root` and returns one MatchResult
+  /// for every operation accepted by `matcher`.
+  std::vector<MatchResult> collectMatches(Operation *root,
+                                          DynMatcher matcher) const;
+  /// Prints a note diagnostic pointing at the source location of `op`.
+  void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op) const;
+  /// Prints a note at the source location of `op` that reads
+  /// "<binding>" binds here (e.g. with binding "root").
+  void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op,
+                  const std::string &binding) const;
+  /// Concatenates the matchedOps of all `matches`, preserving their order.
+  std::vector<Operation *>
+  flattenMatchedOps(std::vector<MatchResult> &matches) const;
 };
 
 } // namespace mlir::query::matcher
diff --git a/mlir/include/mlir/Query/Matcher/MatchersInternal.h b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
index 117f7d4edef9e..183b2514e109f 100644
--- a/mlir/include/mlir/Query/Matcher/MatchersInternal.h
+++ b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
@@ -8,8 +8,9 @@
 //
 // Implements the base layer of the matcher framework.
 //
-// Matchers are methods that return a Matcher which provides a method
-// match(Operation *op)
+// Matchers are methods that return a Matcher which provides one of the
+// following methods: match(Operation *op) or match(Operation *op,
+// SetVector<Operation *> &matchedOps).
 //
 // The matcher functions are defined in include/mlir/IR/Matchers.h.
 // This file contains the wrapper classes needed to construct matchers for
@@ -25,6 +26,31 @@
 
 namespace mlir::query::matcher {
 
+// Result type of T::match(Operation *); ill-formed if T lacks that member.
+template <typename T>
+using simple_match_t =
+    decltype(std::declval<T>().match(std::declval<Operation *>()));
+
+// Result type of T::match(Operation *, SetVector<Operation *> &); ill-formed
+// if T lacks that member.
+template <typename T>
+using bound_match_t = decltype(std::declval<T>().match(
+    std::declval<Operation *>(), std::declval<SetVector<Operation *> &>()));
+
+// has_simple_match<T>::value is true iff simple_match_t<T> is well-formed,
+// i.e. T provides match(Operation *).
+template <typename T, typename = void>
+struct has_simple_match : std::false_type {};
+template <typename T>
+struct has_simple_match<T, std::void_t<simple_match_t<T>>> : std::true_type {};
+
+// has_bound_match<T>::value is true iff bound_match_t<T> is well-formed,
+// i.e. T provides match(Operation *, SetVector<Operation *> &).
+template <typename T, typename = void>
+struct has_bound_match : std::false_type {};
+template <typename T>
+struct has_bound_match<T, std::void_t<bound_match_t<T>>> : std::true_type {};
+
 // Generic interface for matchers on an MLIR operation.
 class MatcherInterface
     : public llvm::ThreadSafeRefCountedBase<MatcherInterface> {
@@ -32,6 +58,7 @@ class MatcherInterface
   virtual ~MatcherInterface() = default;
 
   virtual bool match(Operation *op) = 0;
+  virtual bool match(Operation *op, SetVector<Operation *> &matchedOps) = 0;
 };
 
 // MatcherFnImpl takes a matcher function object and implements
@@ -40,14 +67,25 @@ template <typename MatcherFn>
 class MatcherFnImpl : public MatcherInterface {
 public:
   MatcherFnImpl(MatcherFn &matcherFn) : matcherFn(matcherFn) {}
-  bool match(Operation *op) override { return matcherFn.match(op); }
+  bool match(Operation *op) override {
+    if constexpr (!has_simple_match<MatcherFn>::value)
+      return false; // MatcherFn has no match(Operation *) overload.
+    else
+      return matcherFn.match(op);
+  }
+  bool match(Operation *op, SetVector<Operation *> &matchedOps) override {
+    if constexpr (!has_bound_match<MatcherFn>::value)
+      return false; // MatcherFn has no match(Operation *, SetVector &).
+    else
+      return matcherFn.match(op, matchedOps);
+  }
 
 private:
   MatcherFn matcherFn;
 };
 
-// Matcher wraps a MatcherInterface implementation and provides a match()
-// method that redirects calls to the underlying implementation.
+// Matcher wraps a MatcherInterface implementation and provides match()
+// methods that redirect calls to the underlying implementation.
 class DynMatcher {
 public:
   // Takes ownership of the provided implementation pointer.
@@ -62,12 +100,13 @@ class DynMatcher {
   }
 
   bool match(Operation *op) const { return implementation->match(op); }
+  bool match(Operation *op, SetVector<Operation *> &matchedOps) const {
+    return implementation->match(op, matchedOps);
+  }
 
-  void setFunctionName(StringRef name) { functionName = name.str(); };
-
-  bool hasFunctionName() const { return !functionName.empty(); };
-
-  StringRef getFunctionName() const { return functionName; };
+  void setFunctionName(StringRef name) { functionName = name.str(); }
+  bool hasFunctionName() const { return !functionName.empty(); }
+  StringRef getFunctionName() const { return functionName; }
 
 private:
   llvm::IntrusiveRefCntPtr<MatcherInterface> implementation;
diff --git a/mlir/include/mlir/Query/Matcher/VariantValue.h b/mlir/include/mlir/Query/Matcher/VariantValue.h
index 449f8b3a01e02..98c0a18e25101 100644
--- a/mlir/include/mlir/Query/Matcher/VariantValue.h
+++ b/mlir/include/mlir/Query/Matcher/VariantValue.h
@@ -21,7 +21,7 @@
 namespace mlir::query::matcher {
 
 // All types that VariantValue can contain.
-enum class ArgKind { Matcher, String };
+enum class ArgKind { Boolean, Matcher, Signed, String };
 
 // A variant matcher object to abstract simple and complex matchers into a
 // single object type.
@@ -81,6 +81,8 @@ class VariantValue {
   // Specific constructors for each supported type.
   VariantValue(const llvm::StringRef string);
   VariantValue(const VariantMatcher &matcher);
+  VariantValue(int64_t signedValue);
+  VariantValue(bool setBoolean);
 
   // String value functions.
   bool isString() const;
@@ -92,21 +94,36 @@ class VariantValue {
   const VariantMatcher &getMatcher() const;
   void setMatcher(const VariantMatcher &matcher);
 
+  // Signed value functions.
+  bool isSigned() const;
+  int64_t getSigned() const;
+  void setSigned(int64_t signedValue);
+
+  // Boolean value functions.
+  bool isBoolean() const;
+  bool getBoolean() const;
+  void setBoolean(bool booleanValue);
   // String representation of the type of the value.
   std::string getTypeAsString() const;
+  explicit operator bool() const { return hasValue(); }
+  bool hasValue() const { return type != ValueType::Nothing; }
 
 private:
   void reset();
 
   // All supported value types.
   enum class ValueType {
+    Boolean,
+    Matcher,
     Nothing,
+    Signed,
     String,
-    Matcher,
   };
 
   // All supported value types.
   union AllValues {
+    bool Boolean;
+    int64_t Signed;
     llvm::StringRef *String;
     VariantMatcher *Matcher;
   };
diff --git a/mlir/lib/Query/Matcher/CMakeLists.txt b/mlir/lib/Query/Matcher/CMakeLists.txt
index 3adff9f99243f..d84b1b50e8b04 100644
--- a/mlir/lib/Query/Matcher/CMakeLists.txt
+++ b/mlir/lib/Query/Matcher/CMakeLists.txt
@@ -1,4 +1,6 @@
 add_mlir_library(MLIRQueryMatcher
+  MatchFinder.cpp
+  ExtraMatchers.cpp
   Parser.cpp
   RegistryManager.cpp
   VariantValue.cpp
diff --git a/mlir/lib/Query/Matcher/ExtraMatchers.cpp b/mlir/lib/Query/Matcher/ExtraMatchers.cpp
new file mode 100644
index 0000000000000..1c69995a5d690
--- /dev/null
+++ b/mlir/lib/Query/Matcher/ExtraMatchers.cpp
@@ -0,0 +1,66 @@
+//===- ExtraMatchers.cpp - Various common matchers ---------------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements specific matchers
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Query/Matcher/ExtraMatchers.h"
+
+namespace mlir::query::matcher {
+
+/// Fills `backwardSlice` with the backward slice of `rootOp`, skipping any
+/// producer that lies more than `maxDepth` def-use levels above `rootOp`.
+/// `options` is configured from this matcher's `inclusive`,
+/// `omitUsesFromAbove` and `omitBlockArguments` members. Returns true.
+bool BackwardSliceMatcher::matches(Operation *rootOp,
+                                   llvm::SetVector<Operation *> &backwardSlice,
+                                   BackwardSliceOptions &options,
+                                   int64_t maxDepth) {
+  options.inclusive = inclusive;
+  options.omitUsesFromAbove = omitUsesFromAbove;
+  options.omitBlockArguments = omitBlockArguments;
+  backwardSlice.clear();
+  // Shortest known number of def-use edges from rootOp to each producer
+  // recorded by the filter below. The traversal starts at depth 0.
+  llvm::DenseMap<Operation *, int64_t> opDepths;
+  opDepths[rootOp] = 0;
+  // Stores `depth` for `op`, keeping the smaller value if one is known.
+  auto recordDepth = [&](Operation *op, int64_t depth) {
+    auto [it, inserted] = opDepths.try_emplace(op, depth);
+    if (!inserted)
+      it->second = std::min(it->second, depth);
+  };
+  options.filter = [&](Operation *subOp) {
+    // Ops without a recorded depth (presumably reached through uses from
+    // above rather than through operands -- TODO confirm) count as depth 0.
+    int64_t depth = opDepths.lookup(subOp);
+    // Prune the current branch once it is deeper than maxDepth.
+    if (depth > maxDepth)
+      return false;
+    // Each operand's producer (its defining op, or the parent op of the block
+    // owning a block argument) is one level deeper than subOp. Record all of
+    // them, not just the first, so none is left without a depth (assumes the
+    // filter runs on an op before its producers are visited).
+    for (Value operand : subOp->getOperands()) {
+      if (Operation *definingOp = operand.getDefiningOp()) {
+        recordDepth(definingOp, depth + 1);
+        continue;
+      }
+      Operation *parentOp =
+          cast<BlockArgument>(operand).getOwner()->getParentOp();
+      if (parentOp)
+        recordDepth(parentOp, depth + 1);
+    }
+    return true;
+  };
+  getBackwardSlice(rootOp, &backwardSlice, options);
+  return true;
+}
+
+} // namespace mlir::query::matcher
diff --git a/mlir/lib/Query/Matcher/MatchFinder.cpp b/mlir/lib/Query/Matcher/MatchFinder.cpp
new file mode 100644
index 0000000000000..386b85b1e27a6
--- /dev/null
+++ b/mlir/lib/Query/Matcher/MatchFinder.cpp
@@ -0,0 +1,78 @@
+//===- MatchFinder.cpp - -----------------------------------------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the method definitions for the `MatchFinder` class
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Query/Matcher/MatchFinder.h"
+namespace mlir::query::matcher {
+
+MatchFinder::MatchResult::MatchResult(Operation *rootOp,
+                                      std::vector<Operation *> matchedOps)
+    : rootOp(rootOp), matchedOps(std::move(matchedOps)) {}
+
+std::vector<MatchFinder::MatchResult>
+MatchFinder::collectMatches(Operation *root, DynMatcher matcher) const {
+  std::vector<MatchResult> results;
+  // Scratch buffer for matchers that collect ops; cleared after every op.
+  llvm::SetVector<Operation *> tempStorage;
+  root->walk([&](Operation *subOp) {
+    // The plain match(op) form yields a single-op result; otherwise try the
+    // form that lets the matcher collect related ops into tempStorage.
+    if (matcher.match(subOp)) {
+      MatchResult match;
+      match.rootOp = subOp;
+      match.matchedOps.push_back(subOp);
+      results.push_back(std::move(match));
+    } else if (matcher.match(subOp, tempStorage)) {
+      results.emplace_back(subOp, std::vector<Operation *>(tempStorage.begin(),
+                                                           tempStorage.end()));
+    }
+    tempStorage.clear();
+  });
+  return results;
+}
+
+/// Maps the FileLineColLoc found in `op`'s location to a position in the
+/// session's source buffer. `findInstanceOf` also searches nested locations,
+/// so ops whose location is not itself a FileLineColLoc are handled; when no
+/// FileLineColLoc exists at all, an invalid SMLoc is returned.
+static SMLoc getSourceLoc(QuerySession &qs, Operation *op) {
+  auto fileLoc = op->getLoc()->findInstanceOf<FileLineColLoc>();
+  if (!fileLoc)
+    return SMLoc();
+  return qs.getSourceManager().FindLocForLineAndColumn(
+      qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn());
+}
+
+void MatchFinder::printMatch(llvm::raw_ostream &os, QuerySession &qs,
+                             Operation *op) const {
+  llvm::SMDiagnostic diag = qs.getSourceManager().GetMessage(
+      getSourceLoc(qs, op), llvm::SourceMgr::DK_Note, "");
+  diag.print("", os, true, false, true);
+}
+
+void MatchFinder::printMatch(llvm::raw_ostream &os, QuerySession &qs,
+                             Operation *op, const std::string &binding) const {
+  qs.getSourceManager().PrintMessage(os, getSourceLoc(qs, op),
+                                     llvm::SourceMgr::DK_Note,
+                                     "\"" + binding + "\" binds here");
+}
+
+std::vector<Operation *>
+MatchFinder::flattenMatchedOps(std::vector<MatchResult> &matches) const {
+  std::vector<Operation *> newVector;
+  for (auto &result : matches) {
+    newVector.insert(newVector.end(), result.matchedOps.begin(),
+                     result.matchedOps.end());
+  }
+  return newVector;
+}
+
+} // namespace mlir::query::matcher
diff --git a/mlir/lib/Query/Matcher/Parser.cpp b/mlir/lib/Query/Matcher/Parser.cpp
index 3609e24f9939f..e392a885c511b 100644
--- a/mlir/lib/Query/Matcher/Parser.cpp
+++ b/mlir/lib/Query/Matcher/Parser.cpp
@@ -135,6 +135,18 @@ class Parser::CodeTokenizer {
     case '\'':
       consumeStringLiteral(&result);
       break;
+    case '0':
+    case '1':
+    case '2':
+    case '3':
+    case '4':
+    case '5':
+    case '6':
+    case '7':
+    case '8':
+    case '9':
+      consumeNumberLiteral(&result);
+      break;
     default:
       parseIdentifierOrInvalid(&result);
       break;
@@ -144,6 +156,18 @@ class Parser::CodeTokenizer {
     return result;
   }
 
+  void consumeNumberLiteral(TokenInfo *result) {
+    StringRef original = code;
+    int64_t value = 0; // Full 64-bit range; narrower types would truncate.
+    bool failed = code.consumeInteger(0, value); // Malformed or out of range.
+    if (failed)
+      code = original.drop_front(); // Report just the first digit as invalid.
+    result->text = original.take_front(original.size() - code.size());
+    result->kind = failed ? TokenKind::InvalidChar : TokenKind::Literal;
+    if (!failed)
+      result->value = value;
+  }
+
   // Consume a string literal, handle escape sequences and missing closing
   // quote.
   void consumeStringLiteral(TokenInfo *result) {
@@ -195,9 +219,22 @@ class Parser::CodeTokenizer {
           break;
         ++tokenLength;
       }
-      result->kind = TokenKind::Ident;
-      result->text = code.substr(0, tokenLength);
+      llvm::StringRef token = code.substr(0, tokenLength);
       code = code.drop_front(tokenLength);
+      // Boolean literals keep their own spelling as the token text.
+      if (token == "true") {
+        result->text = token;
+        result->kind = TokenKind::Literal;
+        result->value = true;
+      } else if (token == "false") {
+        result->text = token;
+        result->kind = TokenKind::Literal;
+        result->value = false;
+      } else {
+        // Any other spelling is a plain identifier.
+        result->kind = TokenKind::Ident;
+        result->text = token;
+      }
     } else {
       result->kind = TokenKind::InvalidChar;
       result->text = code.substr(0, 1);
@@ -257,13 +294,19 @@ bool Parser::parseIdentifierPrefixImpl(VariantValue *value) {
 
   if (tokenizer->nextTokenKind() != TokenKind::OpenParen) {
     // Parse as a named value.
-    auto namedValue =
-        namedValues ? namedValues->lookup(nameToken.text) : VariantValue();
+    if (auto namedValue = namedValues ? namedValues->lookup(nameToken.text)
+                                      : VariantValue()) {
 
-    if (!namedValue.isMatcher()) {
-      error->addError(tokenizer->peekNextToken().range,
-                      ErrorType::ParserNotAMatcher);
-      return false;
+      if (tokenizer->nextTokenKind() != TokenKind::Period) {
+        *value = namedValue;
+        return true;
+      }
+
+      if (!namedValue.isMatcher()) {
+        error->addError(tokenizer->peekNextToken().range,
+                        ErrorType::ParserNotAMatcher);
+        return false;
+      }
     }
 
     if (tokenizer->nextTokenKind() == TokenKind::NewLine) {
diff --git a/mlir/lib/Query/Matcher/RegistryManager.cpp b/mlir/lib/Query/Matcher/RegistryManager.cpp
index 645db7109c2de..4b511c5f009e7 100644
--- a/mlir/lib/Query/Matcher/RegistryManager.cpp
+++ b/mlir/lib/Query/Matcher/RegistryManager.cpp
@@ -19,16 +19,15 @@
 namespace mlir::query::matcher {
 namespace {
 
-// This is needed because these matchers are defined as overloaded functions.
-using IsConstantOp = detail::constant_op_matcher();
-using HasOpAttrName = detail::AttrOpMatcher(llvm::StringRef);
-using HasOpName = detail::NameOpMatcher(llvm::StringRef);
-
 // Enum to string for autocomplete.
 static std::string asArgString(ArgKind kind) {
   switch (kind) {
+  case ArgKind::Boolean:
+    return "Boolean";
   case ArgKind::Matcher:
     return "Matcher";
+  case ArgKind::Signed:
+    return "Signed";
   case ArgKind::String:
     return "String";
   }
diff --git a/mlir/lib/Query/Matcher/VariantValue.cpp b/mlir/lib/Query/Matcher/VariantValue.cpp
index 65bd4bd77bcf8..1cb2d48f9d56f 100644
--- a/mlir/lib/Query/Matcher/VariantValue.cpp
+++ b/mlir/lib/Query/Matcher/VariantValue.cpp
@@ -56,6 +56,14 @@ VariantValue::VariantValue(const VariantMatcher &matcher)
   value.Matcher = new VariantMatcher(matcher);
 }
 
+// Scalar constructors: the payload is stored inline in the union.
+VariantValue::VariantValue(int64_t number) : type(ValueType::Signed) {
+  value.Signed = number;
+}
+VariantValue::VariantValue(bool flag) : type(ValueType::Boolean) {
+  value.Boolean = flag;
+}
+
 VariantValue::~VariantValue() { reset(); }
 
 VariantValue &VariantValue::operator=(const VariantValue &other) {
@@ -69,6 +77,12 @@ VariantValue &VariantValue::operator=(const VariantValue &other) {
   case ValueType::Matcher:
     setMatcher(other.getMatcher());
     break;
+  case ValueType::Signed:
+    setSigned(other.getSigned());
+    break;
+  case ValueType::Boolean:
+    setBoolean(other.getBoolean());
+    break;
   case ValueType::Nothing:
     type = ValueType::Nothing;
     break;
@@ -85,12 +99,34 @@ void VariantValue::reset() {
     delete value.Matcher;
     break;
   // Cases that do nothing.
+  case ValueType::Signed:
+  case ValueType::Boolean:
   case ValueType::Nothing:
     break;
   }
   type = ValueType::Nothing;
 }
 
+// Signed value functions. The integer is stored inline in the `Signed`
+// union member, so there is no heap payload to manage for it.
+bool VariantValue::isSigned() const { return type == ValueType::Signed; }
+int64_t VariantValue::getSigned() const { return value.Signed; }
+void VariantValue::setSigned(int64_t newValue) {
+  reset(); // Free any previously held heap payload before switching type.
+  type = ValueType::Signed;
+  value.Signed = newValue;
+}
+
+// Boolean value functions. They use the `Boolean` union member, which is
+// the member the bool constructor initializes.
+bool VariantValue::isBoolean() const { return type == ValueType::Boolean; }
+bool VariantValue::getBoolean() const { return value.Boolean; }
+void VariantValue::setBoolean(bool newValue) {
+  reset(); // Free any previously held heap payload before switching type.
+  type = ValueType::Boolean;
+  value.Boolean = newValue;
+}
+
 bool VariantValue::isString() const { return type == ValueType::String; }
 
 const llvm::StringRef &VariantValue::getString() const {
@@ -123,6 +159,10 @@ std::string VariantValue::getTypeAsString() const {
     return "String";
   case ValueType::Matcher:
     return "Matcher";
+  case ValueType::Signed:
+    return "Signed";
+  case ValueType::Boolean:
+    return "Boolean";
   case ValueType::Nothing:
     return "Nothing";
   }
diff --git a/mlir/lib/Query/Query.cpp b/mlir/lib/Query/Query.cpp
index 869ee8f2ae1dc..f060ab80aa73d 100644
--- a/mlir/lib/Query/Query.cpp
+++ b/mlir/lib/Query/Query.cpp
@@ -12,6 +12,7 @@
 #include "mlir/IR/IRMapping.h"
 #include "mlir/Query/Matcher/MatchFinder.h"
 #include "mlir/Query/QuerySession.h"
+#include "llvm/ADT/SetVector.h"
 #include "llvm/Support/SourceMgr.h"
 #include "llvm/Support/raw_ostream.h"
 
@@ -26,15 +27,6 @@ complete(llvm::StringRef line, size_t pos, const QuerySession &qs) {
   return QueryParser::complete(line, pos, qs);
 }
 
-static void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op,
-                       const std::string &binding) {
-  auto fileLoc = op->getLoc()->findInstanceOf<FileLineColLoc>();
-  auto smloc = qs.getSourceManager().FindLocForLineAndColumn(
-      qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn());
-  qs.getSourceManager().PrintMessage(os, smloc, llvm::SourceMgr::DK_Note,
-                                     "\"" + binding + "\" binds here");
-}
-
 // TODO: Extract into a helper function that can be reused outside query
 // context.
 static Operation *extractFunction(std::vector<Operation *> &ops,
@@ -125,28 +117,34 @@ LogicalResult QuitQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
 LogicalResult MatchQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
   Operation *rootOp = qs.getRootOp();
   int matchCount = 0;
-  std::vector<Operation *> matches =
-      matcher::MatchFinder().getMatches(rootOp, matcher);
+  matcher::MatchFinder finder;
+  auto matches = finder.collectMatches(rootOp, matcher); // matcher reused below
 
   // An extract call is recognized by considering if the matcher has a name.
   // TODO: Consider making the extract more explicit.
   if (matcher.hasFunctionName()) {
     auto functionName = matcher.getFunctionName();
+    std::vector<Operation *> flattenedMatches =
+        finder.flattenMatchedOps(matches);
     Operation *function =
-        extractFunction(matches, rootOp->getContext(), functionName);
+        extractFunction(flattenedMatches, rootOp->getContext(), functionName);
     os << "\n" << *function << "\n\n";
     function->erase();
     return mlir::success();
   }
 
   os << "\n";
-  for (Operation *op : matches) {
+  for (const auto &result : matches) {
     os << "Match #" << ++matchCount << ":\n\n";
-    // Placeholder "root" binding for the initial draft.
-    printMatch(os, qs, op, "root");
+    for (Operation *op : result.matchedOps) {
+      // Only the op the matcher fired on gets the "root" binding label.
+      if (op == result.rootOp)
+        finder.printMatch(os, qs, op, "root");
+      else
+        finder.printMatch(os, qs, op);
+    }
   }
   os << matchCount << (matchCount == 1 ? " match.\n\n" : " matches.\n\n");
-
   return mlir::success();
 }
 
diff --git a/mlir/lib/Query/QueryParser.cpp b/mlir/lib/Query/QueryParser.cpp
index 31aead7d403d0..3990b697ead7f 100644
--- a/mlir/lib/Query/QueryParser.cpp
+++ b/mlir/lib/Query/QueryParser.cpp
@@ -166,7 +166,6 @@ QueryRef QueryParser::doParse() {
 
   case ParsedQueryKind::Quit:
     return endQuery(new QuitQuery);
-
   case ParsedQueryKind::Match: {
     if (completionPos) {
       return completeMatcherExpression();
diff --git a/mlir/test/mlir-query/complex-test.mlir b/mlir/test/mlir-query/complex-test.mlir
new file mode 100644
index 0000000000000..d18a9cc1d1550
--- /dev/null
+++ b/mlir/test/mlir-query/complex-test.mlir
@@ -0,0 +1,32 @@
+// RUN: mlir-query %s -c "m getDefinitions(hasOpName(\"arith.addf\"),2,true,false,false)" | FileCheck %s
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @slice_use_from_above(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>) {
+  %0 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %2 = arith.addf %in, %in : f32
+    linalg.yield %2 : f32
+  } -> tensor<5x5xf32>
+  %collapsed = tensor.collapse_shape %0 [[0, 1]] : tensor<5x5xf32> into tensor<25xf32>
+  %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %c2 = arith.constant 2 : index
+    %extracted = tensor.extract %collapsed[%c2] : tensor<25xf32>
+    %2 = arith.addf %extracted, %extracted : f32
+    linalg.yield %2 : f32
+  } -> tensor<5x5xf32>
+  return
+}
+
+// CHECK: Match #1:
+
+// CHECK: %[[LINALG:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>)
+// CHECK: %[[ADDF1:.*]] = arith.addf %in, %in : f32
+
+// CHECK: Match #2:
+
+// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %0 {{\[\[.*\]\]}} : tensor<5x5xf32> into tensor<25xf32>
+// CHECK: %[[C2:.*]] = arith.constant {{.*}} : index
+// CHECK: %[[EXTRACTED:.*]] = tensor.extract %[[COLLAPSED]][%[[C2]]] : tensor<25xf32>
+// CHECK: %[[ADDF2:.*]] = arith.addf %[[EXTRACTED]], %[[EXTRACTED]] : f32  
diff --git a/mlir/tools/mlir-query/mlir-query.cpp b/mlir/tools/mlir-query/mlir-query.cpp
index 0ed4f94d5802b..1d392e5f0dcfd 100644
--- a/mlir/tools/mlir-query/mlir-query.cpp
+++ b/mlir/tools/mlir-query/mlir-query.cpp
@@ -15,6 +15,7 @@
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/InitAllDialects.h"
+#include "mlir/Query/Matcher/ExtraMatchers.h"
 #include "mlir/Query/Matcher/Registry.h"
 #include "mlir/Tools/mlir-query/MlirQueryMain.h"
 
@@ -39,6 +40,8 @@ int main(int argc, char **argv) {
   query::matcher::Registry matcherRegistry;
 
   // Matchers registered in alphabetical order for consistency:
+  matcherRegistry.registerMatcher("getDefinitions",
+                                  query::matcher::m_GetDefinitions);
   matcherRegistry.registerMatcher("hasOpAttrName",
                                   static_cast<HasOpAttrName *>(m_Attr));
   matcherRegistry.registerMatcher("hasOpName", static_cast<HasOpName *>(m_Op));

>From 5f940da512394c8a57eacb57bfa451fa89f68c4d Mon Sep 17 00:00:00 2001
From: chios202 <chio.star at yahoo.com>
Date: Thu, 24 Apr 2025 15:32:47 +0000
Subject: [PATCH 2/3] Update grammar; make BackwardSlice matcher more generic;
 capture values in tests

---
 .../mlir/Query/Matcher/ExtraMatchers.h        | 102 +++++++++++++-----
 mlir/include/mlir/Query/Matcher/MatchFinder.h |  25 +++--
 mlir/lib/Query/Matcher/CMakeLists.txt         |   1 -
 mlir/lib/Query/Matcher/ExtraMatchers.cpp      |  66 ------------
 mlir/lib/Query/Matcher/Parser.h               |   5 +-
 mlir/lib/Query/QueryParser.cpp                |   1 +
 mlir/test/mlir-query/complex-test.mlir        |   2 +-
 mlir/tools/mlir-query/mlir-query.cpp          |   5 +-
 8 files changed, 99 insertions(+), 108 deletions(-)
 delete mode 100644 mlir/lib/Query/Matcher/ExtraMatchers.cpp

diff --git a/mlir/include/mlir/Query/Matcher/ExtraMatchers.h b/mlir/include/mlir/Query/Matcher/ExtraMatchers.h
index 4766a767cf783..48cab7760d5cf 100644
--- a/mlir/include/mlir/Query/Matcher/ExtraMatchers.h
+++ b/mlir/include/mlir/Query/Matcher/ExtraMatchers.h
@@ -12,16 +12,15 @@
 
 #ifndef MLIR_TOOLS_MLIRQUERY_MATCHERS_EXTRAMATCHERS_H
 #define MLIR_TOOLS_MLIRQUERY_MATCHERS_EXTRAMATCHERS_H
+
 #include "mlir/Analysis/SliceAnalysis.h"
-#include "mlir/Query/Matcher/MatchersInternal.h"
 
-/// A matcher encapsulating the initial `getBackwardSlice` method from
-/// SliceAnalysis.h
+/// A matcher encapsulating `getBackwardSlice` method from SliceAnalysis.h.
 /// Additionally, it limits the slice computation to a certain depth level using
-/// a custom filter
+/// a custom filter.
 ///
-/// Example starting from node 9, assuming the matcher
-/// computes the slice for the first two depth levels
+/// Example: starting from node 9, assuming the matcher
+/// computes the slice for the first two depth levels:
 /// ============================
 ///    1       2      3      4
 ///    |_______|      |______|
@@ -37,18 +36,23 @@
 /// Assuming all local orders match the numbering order:
 ///     {5, 7, 6, 8, 9}
 namespace mlir::query::matcher {
+
+template <typename Matcher>
 class BackwardSliceMatcher {
 public:
-  explicit BackwardSliceMatcher(query::matcher::DynMatcher &&innerMatcher,
-                                int64_t maxDepth, bool inclusive,
-                                bool omitBlockArguments, bool omitUsesFromAbove)
+  BackwardSliceMatcher(Matcher innerMatcher, int64_t maxDepth, bool inclusive,
+                       bool omitBlockArguments, bool omitUsesFromAbove)
       : innerMatcher(std::move(innerMatcher)), maxDepth(maxDepth),
         inclusive(inclusive), omitBlockArguments(omitBlockArguments),
         omitUsesFromAbove(omitUsesFromAbove) {}
-  bool match(Operation *op, SetVector<Operation *> &backwardSlice) {
+
+  bool match(Operation *rootOp, SetVector<Operation *> &backwardSlice) {
     BackwardSliceOptions options;
-    return (innerMatcher.match(op) &&
-            matches(op, backwardSlice, options, maxDepth));
+    options.inclusive = inclusive;
+    options.omitUsesFromAbove = omitUsesFromAbove;
+    options.omitBlockArguments = omitBlockArguments;
+    return (innerMatcher.match(rootOp) &&
+            matches(rootOp, backwardSlice, options, maxDepth));
   }
 
 private:
@@ -57,29 +61,75 @@ class BackwardSliceMatcher {
 
 private:
   // The outer matcher (e.g., BackwardSliceMatcher) relies on the innerMatcher
-  // to determine whether we want to traverse the DAG or not. For example, we
-  // want to explore the DAG only if the top-level operation name is
-  // "arith.addf".
-  query::matcher::DynMatcher innerMatcher;
-  // maxDepth specifies the maximum depth that the matcher can traverse in the
-  // DAG. For example, if maxDepth is 2, the matcher will explore the defining
+  // to determine whether we want to traverse the IR or not. For example, we
+  // want to explore the IR only if the top-level operation name is
+  // `"arith.addf"`.
+  Matcher innerMatcher;
+  // `maxDepth` specifies the maximum depth that the matcher can traverse the
+  // IR. For example, if `maxDepth` is 2, the matcher will explore the defining
   // operations of the top-level op up to 2 levels.
   int64_t maxDepth;
-
   bool inclusive;
   bool omitBlockArguments;
   bool omitUsesFromAbove;
 };
 
-// Matches transitive defs of a top level operation up to N levels
-inline BackwardSliceMatcher
-m_GetDefinitions(query::matcher::DynMatcher innerMatcher, int64_t maxDepth,
-                 bool inclusive, bool omitBlockArguments,
-                 bool omitUsesFromAbove) {
+template <typename Matcher>
+bool BackwardSliceMatcher<Matcher>::matches(
+    Operation *rootOp, llvm::SetVector<Operation *> &backwardSlice,
+    BackwardSliceOptions &options, int64_t maxDepth) {
+  backwardSlice.clear();
+  llvm::DenseMap<Operation *, int64_t> opDepths;
+  // The starting point is the root op; therefore, we set its depth to 0.
+  opDepths[rootOp] = 0;
+  options.filter = [&](Operation *subOp) {
+    // If the subOp's depth exceeds maxDepth, we stop further slicing for this
+    // branch.
+    if (opDepths[subOp] > maxDepth)
+      return false;
+    // Examine subOp's operands to compute depths of their defining operations.
+    for (auto operand : subOp->getOperands()) {
+      if (auto definingOp = operand.getDefiningOp()) {
+        // Set the defining operation's depth to one level greater than
+        // subOp's depth.
+        int64_t newDepth = opDepths[subOp] + 1;
+        if (!opDepths.contains(definingOp)) {
+          opDepths[definingOp] = newDepth;
+        } else {
+          opDepths[definingOp] = std::min(opDepths[definingOp], newDepth);
+        }
+        return !(opDepths[subOp] > maxDepth);
+      } else {
+        auto blockArgument = cast<BlockArgument>(operand);
+        Operation *parentOp = blockArgument.getOwner()->getParentOp();
+        if (!parentOp)
+          continue;
+        int64_t newDepth = opDepths[subOp] + 1;
+        if (!opDepths.contains(parentOp)) {
+          opDepths[parentOp] = newDepth;
+        } else {
+          opDepths[parentOp] = std::min(opDepths[parentOp], newDepth);
+        }
+        return !(opDepths[parentOp] > maxDepth);
+      }
+    }
+    return true;
+  };
+  getBackwardSlice(rootOp, &backwardSlice, options);
+  return true;
+}
+
+// Matches transitive defs of a top-level operation up to N levels.
+template <typename Matcher>
+inline BackwardSliceMatcher<Matcher>
+m_GetDefinitions(Matcher innerMatcher, int64_t maxDepth, bool inclusive,
+                 bool omitBlockArguments, bool omitUsesFromAbove) {
   assert(maxDepth >= 0 && "maxDepth must be non-negative");
-  return BackwardSliceMatcher(std::move(innerMatcher), maxDepth, inclusive,
-                              omitBlockArguments, omitUsesFromAbove);
+  return BackwardSliceMatcher<Matcher>(std::move(innerMatcher), maxDepth,
+                                       inclusive, omitBlockArguments,
+                                       omitUsesFromAbove);
 }
+
 } // namespace mlir::query::matcher
 
 #endif // MLIR_TOOLS_MLIRQUERY_MATCHERS_EXTRAMATCHERS_H
diff --git a/mlir/include/mlir/Query/Matcher/MatchFinder.h b/mlir/include/mlir/Query/Matcher/MatchFinder.h
index 6b554394b3654..f8abf20ef60bb 100644
--- a/mlir/include/mlir/Query/Matcher/MatchFinder.h
+++ b/mlir/include/mlir/Query/Matcher/MatchFinder.h
@@ -21,32 +21,35 @@
 
 namespace mlir::query::matcher {
 
-/// A class that provides utilities to find operations in a DAG
+/// A class that provides utilities to find operations in the IR.
 class MatchFinder {
 
 public:
-  /// A subclass which preserves the matching information
+  /// A nested struct that preserves the matching information. Each instance
+  /// contains the `rootOp` along with the matching environment.
   struct MatchResult {
     MatchResult() = default;
     MatchResult(Operation *rootOp, std::vector<Operation *> matchedOps);
 
-    /// Contains the root operation of the matching environment
     Operation *rootOp = nullptr;
-    /// Contains the matching enviroment. This allows the user to easily
-    /// extract the matched operations
+    /// The operations that make up the matching environment.
     std::vector<Operation *> matchedOps;
   };
-  /// Traverses the DAG and collects the "rootOp" + "matching enviroment" for
-  /// a given Matcher
+
+  /// Traverses the IR and returns one `MatchResult` for each match of
+  /// `matcher`.
   std::vector<MatchResult> collectMatches(Operation *root,
                                           DynMatcher matcher) const;
-  /// Prints the matched operation
+
+  /// Prints the matched operation.
   void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op) const;
-  /// Labels the matched operation with the given binding (e.g., "root") and
-  /// prints it
+
+  /// Labels the matched operation with the given binding (e.g., `"root"`) and
+  /// prints it.
   void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op,
                   const std::string &binding) const;
-  /// Flattens a vector of MatchResults into a vector of operations
+
+  /// Flattens a vector of `MatchResult` into a vector of operations.
   std::vector<Operation *>
   flattenMatchedOps(std::vector<MatchResult> &matches) const;
 };
diff --git a/mlir/lib/Query/Matcher/CMakeLists.txt b/mlir/lib/Query/Matcher/CMakeLists.txt
index d84b1b50e8b04..629479bf7adc1 100644
--- a/mlir/lib/Query/Matcher/CMakeLists.txt
+++ b/mlir/lib/Query/Matcher/CMakeLists.txt
@@ -1,6 +1,5 @@
 add_mlir_library(MLIRQueryMatcher
   MatchFinder.cpp
-  ExtraMatchers.cpp
   Parser.cpp
   RegistryManager.cpp
   VariantValue.cpp
diff --git a/mlir/lib/Query/Matcher/ExtraMatchers.cpp b/mlir/lib/Query/Matcher/ExtraMatchers.cpp
deleted file mode 100644
index 1c69995a5d690..0000000000000
--- a/mlir/lib/Query/Matcher/ExtraMatchers.cpp
+++ /dev/null
@@ -1,66 +0,0 @@
-//===- ExtraMatchers.cpp - Various common matchers ---------------*- C++-*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements specific matchers
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Query/Matcher/ExtraMatchers.h"
-
-namespace mlir::query::matcher {
-
-bool BackwardSliceMatcher::matches(Operation *rootOp,
-                                   llvm::SetVector<Operation *> &backwardSlice,
-                                   BackwardSliceOptions &options,
-                                   int64_t maxDepth) {
-  options.inclusive = inclusive;
-  options.omitUsesFromAbove = omitUsesFromAbove;
-  options.omitBlockArguments = omitBlockArguments;
-  backwardSlice.clear();
-  llvm::DenseMap<Operation *, int64_t> opDepths;
-  // The starting point is the root op, therfore we set its depth to 0
-  opDepths[rootOp] = 0;
-  options.filter = [&](Operation *subOp) {
-    // If the subOp’s depth exceeds maxDepth, we can stop further computing the
-    // slice for the current branch
-    if (opDepths[subOp] > maxDepth)
-      return false;
-    // Examining subOp's operands to compute the depths of their defining
-    // operations
-    for (auto operand : subOp->getOperands()) {
-      if (auto definingOp = operand.getDefiningOp()) {
-        // Set the defining operation's depth to one level greater than
-        // subOp's depth
-        int64_t newDepth = opDepths[subOp] + 1;
-        if (!opDepths.contains(definingOp)) {
-          opDepths[definingOp] = newDepth;
-        } else {
-          opDepths[definingOp] = std::min(opDepths[definingOp], newDepth);
-        }
-        return !(opDepths[subOp] > maxDepth);
-      } else {
-        auto blockArgument = cast<BlockArgument>(operand);
-        Operation *parentOp = blockArgument.getOwner()->getParentOp();
-        if (!parentOp)
-          continue;
-        int64_t newDepth = opDepths[subOp] + 1;
-        if (!opDepths.contains(parentOp)) {
-          opDepths[parentOp] = newDepth;
-        } else {
-          opDepths[parentOp] = std::min(opDepths[parentOp], newDepth);
-        }
-        return !(opDepths[parentOp] > maxDepth);
-      }
-    }
-    return true;
-  };
-  getBackwardSlice(rootOp, &backwardSlice, options);
-  return true;
-}
-
-} // namespace mlir::query::matcher
diff --git a/mlir/lib/Query/Matcher/Parser.h b/mlir/lib/Query/Matcher/Parser.h
index 58968023022d5..2199a2335ba9c 100644
--- a/mlir/lib/Query/Matcher/Parser.h
+++ b/mlir/lib/Query/Matcher/Parser.h
@@ -16,8 +16,11 @@
 // provided to the parser.
 //
 // The grammar for the supported expressions is as follows:
-// <Expression>        := <StringLiteral> | <MatcherExpression>
+// <Expression>        := <Literal> | <MatcherExpression>
+// <Literal>           := <StringLiteral> | <NumericLiteral> | <BooleanLiteral>
 // <StringLiteral>     := "quoted string"
+// <NumericLiteral>    := [0-9]+
+// <BooleanLiteral>    := "true" | "false"
 // <MatcherExpression> := <MatcherName>(<ArgumentList>)
 // <MatcherName>       := [a-zA-Z]+
 // <ArgumentList>      := <Expression> | <Expression>,<ArgumentList>
diff --git a/mlir/lib/Query/QueryParser.cpp b/mlir/lib/Query/QueryParser.cpp
index 3990b697ead7f..31aead7d403d0 100644
--- a/mlir/lib/Query/QueryParser.cpp
+++ b/mlir/lib/Query/QueryParser.cpp
@@ -166,6 +166,7 @@ QueryRef QueryParser::doParse() {
 
   case ParsedQueryKind::Quit:
     return endQuery(new QuitQuery);
+
   case ParsedQueryKind::Match: {
     if (completionPos) {
       return completeMatcherExpression();
diff --git a/mlir/test/mlir-query/complex-test.mlir b/mlir/test/mlir-query/complex-test.mlir
index d18a9cc1d1550..3e0bf8b8b9fa6 100644
--- a/mlir/test/mlir-query/complex-test.mlir
+++ b/mlir/test/mlir-query/complex-test.mlir
@@ -26,7 +26,7 @@ func.func @slice_use_from_above(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>)
 
 // CHECK: Match #2:
 
-// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %0 {{\[\[.*\]\]}} : tensor<5x5xf32> into tensor<25xf32>
+// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[LINALG]] {{\[\[.*\]\]}} : tensor<5x5xf32> into tensor<25xf32>
 // CHECK: %[[C2:.*]] = arith.constant {{.*}} : index
 // CHECK: %[[EXTRACTED:.*]] = tensor.extract %[[COLLAPSED]][%[[C2]]] : tensor<25xf32>
 // CHECK: %[[ADDF2:.*]] = arith.addf %[[EXTRACTED]], %[[EXTRACTED]] : f32  
diff --git a/mlir/tools/mlir-query/mlir-query.cpp b/mlir/tools/mlir-query/mlir-query.cpp
index 1d392e5f0dcfd..0cc9a5db25a91 100644
--- a/mlir/tools/mlir-query/mlir-query.cpp
+++ b/mlir/tools/mlir-query/mlir-query.cpp
@@ -40,8 +40,9 @@ int main(int argc, char **argv) {
   query::matcher::Registry matcherRegistry;
 
   // Matchers registered in alphabetical order for consistency:
-  matcherRegistry.registerMatcher("getDefinitions",
-                                  query::matcher::m_GetDefinitions);
+  matcherRegistry.registerMatcher(
+      "getDefinitions",
+      query::matcher::m_GetDefinitions<query::matcher::DynMatcher>);
   matcherRegistry.registerMatcher("hasOpAttrName",
                                   static_cast<HasOpAttrName *>(m_Attr));
   matcherRegistry.registerMatcher("hasOpName", static_cast<HasOpName *>(m_Op));

>From 9a16aedff5e14b057cf7105736a1bcbab364d506 Mon Sep 17 00:00:00 2001
From: chios202 <chio.star at yahoo.com>
Date: Sat, 3 May 2025 16:36:19 +0000
Subject: [PATCH 3/3] Improve depth limiting approach

---
 .../mlir/Query/Matcher/ExtraMatchers.h        | 43 +++++++++++--------
 mlir/test/mlir-query/complex-test.mlir        |  2 +-
 mlir/tools/mlir-query/mlir-query.cpp          |  3 ++
 3 files changed, 28 insertions(+), 20 deletions(-)

diff --git a/mlir/include/mlir/Query/Matcher/ExtraMatchers.h b/mlir/include/mlir/Query/Matcher/ExtraMatchers.h
index 48cab7760d5cf..097cfe82ab996 100644
--- a/mlir/include/mlir/Query/Matcher/ExtraMatchers.h
+++ b/mlir/include/mlir/Query/Matcher/ExtraMatchers.h
@@ -80,37 +80,33 @@ bool BackwardSliceMatcher<Matcher>::matches(
     BackwardSliceOptions &options, int64_t maxDepth) {
   backwardSlice.clear();
   llvm::DenseMap<Operation *, int64_t> opDepths;
-  // The starting point is the root op; therefore, we set its depth to 0.
+  // Initializing the root op with a depth of 0
   opDepths[rootOp] = 0;
   options.filter = [&](Operation *subOp) {
-    // If the subOp's depth exceeds maxDepth, we stop further slicing for this
-    // branch.
-    if (opDepths[subOp] > maxDepth)
+    // If the subOp hasn't been recorded in opDepths, it is deeper than
+    // maxDepth.
+    if (!opDepths.contains(subOp))
       return false;
     // Examine subOp's operands to compute depths of their defining operations.
     for (auto operand : subOp->getOperands()) {
+      int64_t newDepth = opDepths[subOp] + 1;
+      // If the newDepth is greater than maxDepth, further computation can be
+      // skipped.
+      if (newDepth > maxDepth)
+        continue;
+
       if (auto definingOp = operand.getDefiningOp()) {
-        // Set the defining operation's depth to one level greater than
-        // subOp's depth.
-        int64_t newDepth = opDepths[subOp] + 1;
-        if (!opDepths.contains(definingOp)) {
+        // Registers the minimum depth
+        if (!opDepths.contains(definingOp) || newDepth < opDepths[definingOp])
           opDepths[definingOp] = newDepth;
-        } else {
-          opDepths[definingOp] = std::min(opDepths[definingOp], newDepth);
-        }
-        return !(opDepths[subOp] > maxDepth);
       } else {
         auto blockArgument = cast<BlockArgument>(operand);
         Operation *parentOp = blockArgument.getOwner()->getParentOp();
         if (!parentOp)
           continue;
-        int64_t newDepth = opDepths[subOp] + 1;
-        if (!opDepths.contains(parentOp)) {
+
+        if (!opDepths.contains(parentOp) || newDepth < opDepths[parentOp])
           opDepths[parentOp] = newDepth;
-        } else {
-          opDepths[parentOp] = std::min(opDepths[parentOp], newDepth);
-        }
-        return !(opDepths[parentOp] > maxDepth);
       }
     }
     return true;
@@ -119,7 +115,7 @@ bool BackwardSliceMatcher<Matcher>::matches(
   return true;
 }
 
-// Matches transitive defs of a top-level operation up to N levels.
+/// Matches transitive defs of a top-level operation up to N levels.
 template <typename Matcher>
 inline BackwardSliceMatcher<Matcher>
 m_GetDefinitions(Matcher innerMatcher, int64_t maxDepth, bool inclusive,
@@ -130,6 +126,15 @@ m_GetDefinitions(Matcher innerMatcher, int64_t maxDepth, bool inclusive,
                                        omitUsesFromAbove);
 }
 
+/// Matches all transitive defs of a top-level operation up to N levels.
+template <typename Matcher>
+inline BackwardSliceMatcher<Matcher>
+m_GetAllDefinitions(Matcher innerMatcher, int64_t maxDepth) {
+  return m_GetDefinitions(std::move(innerMatcher), maxDepth, /*inclusive=*/true,
+                          /*omitBlockArguments=*/false,
+                          /*omitUsesFromAbove=*/false);
+}
+
 } // namespace mlir::query::matcher
 
 #endif // MLIR_TOOLS_MLIRQUERY_MATCHERS_EXTRAMATCHERS_H
diff --git a/mlir/test/mlir-query/complex-test.mlir b/mlir/test/mlir-query/complex-test.mlir
index 3e0bf8b8b9fa6..ad96f03747a43 100644
--- a/mlir/test/mlir-query/complex-test.mlir
+++ b/mlir/test/mlir-query/complex-test.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-query %s -c "m getDefinitions(hasOpName(\"arith.addf\"),2,true,false,false)" | FileCheck %s
+// RUN: mlir-query %s -c "m getAllDefinitions(hasOpName(\"arith.addf\"),2)" | FileCheck %s
 
 #map = affine_map<(d0, d1) -> (d0, d1)>
 func.func @slice_use_from_above(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>) {
diff --git a/mlir/tools/mlir-query/mlir-query.cpp b/mlir/tools/mlir-query/mlir-query.cpp
index 0cc9a5db25a91..b83c1c913ebfe 100644
--- a/mlir/tools/mlir-query/mlir-query.cpp
+++ b/mlir/tools/mlir-query/mlir-query.cpp
@@ -43,6 +43,9 @@ int main(int argc, char **argv) {
   matcherRegistry.registerMatcher(
       "getDefinitions",
       query::matcher::m_GetDefinitions<query::matcher::DynMatcher>);
+  matcherRegistry.registerMatcher(
+      "getAllDefinitions",
+      query::matcher::m_GetAllDefinitions<query::matcher::DynMatcher>);
   matcherRegistry.registerMatcher("hasOpAttrName",
                                   static_cast<HasOpAttrName *>(m_Attr));
   matcherRegistry.registerMatcher("hasOpName", static_cast<HasOpName *>(m_Op));



More information about the Mlir-commits mailing list