[Mlir-commits] [mlir] [mlir] Improve mlir-query tool by implementing `getBackwardSlice` and `getForwardSlice` matchers (PR #115670)
Denzel-Brian Budii
llvmlistbot at llvm.org
Sat Apr 26 03:15:38 PDT 2025
https://github.com/chios202 updated https://github.com/llvm/llvm-project/pull/115670
>From e07e1feb13ea9607424c6817808f02a2f313f867 Mon Sep 17 00:00:00 2001
From: chios202 <chio.star at yahoo.com>
Date: Tue, 22 Apr 2025 11:44:35 +0000
Subject: [PATCH 1/2] Compute shortest depth in backwardSlice method; relocate
backwardSlice matcher to Query-specific headers; remove unnecessary code
---
.../mlir/Query/Matcher/ExtraMatchers.h | 85 +++++++++++++++++++
mlir/include/mlir/Query/Matcher/Marshallers.h | 30 +++++++
mlir/include/mlir/Query/Matcher/MatchFinder.h | 45 ++++++----
.../mlir/Query/Matcher/MatchersInternal.h | 59 ++++++++++---
.../include/mlir/Query/Matcher/VariantValue.h | 21 ++++-
mlir/lib/Query/Matcher/CMakeLists.txt | 2 +
mlir/lib/Query/Matcher/ExtraMatchers.cpp | 66 ++++++++++++++
mlir/lib/Query/Matcher/MatchFinder.cpp | 68 +++++++++++++++
mlir/lib/Query/Matcher/Parser.cpp | 59 +++++++++++--
mlir/lib/Query/Matcher/RegistryManager.cpp | 9 +-
mlir/lib/Query/Matcher/VariantValue.cpp | 40 +++++++++
mlir/lib/Query/Query.cpp | 30 +++----
mlir/lib/Query/QueryParser.cpp | 1 -
mlir/test/mlir-query/complex-test.mlir | 32 +++++++
mlir/tools/mlir-query/mlir-query.cpp | 3 +
15 files changed, 493 insertions(+), 57 deletions(-)
create mode 100644 mlir/include/mlir/Query/Matcher/ExtraMatchers.h
create mode 100644 mlir/lib/Query/Matcher/ExtraMatchers.cpp
create mode 100644 mlir/lib/Query/Matcher/MatchFinder.cpp
create mode 100644 mlir/test/mlir-query/complex-test.mlir
diff --git a/mlir/include/mlir/Query/Matcher/ExtraMatchers.h b/mlir/include/mlir/Query/Matcher/ExtraMatchers.h
new file mode 100644
index 0000000000000..4766a767cf783
--- /dev/null
+++ b/mlir/include/mlir/Query/Matcher/ExtraMatchers.h
@@ -0,0 +1,85 @@
+//===- ExtraMatchers.h - Various common matchers --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file provides matchers that depend on Query.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TOOLS_MLIRQUERY_MATCHERS_EXTRAMATCHERS_H
+#define MLIR_TOOLS_MLIRQUERY_MATCHERS_EXTRAMATCHERS_H
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Query/Matcher/MatchersInternal.h"
+
+/// A matcher encapsulating the initial `getBackwardSlice` method from
+/// SliceAnalysis.h
+/// Additionally, it limits the slice computation to a certain depth level using
+/// a custom filter
+///
+/// Example starting from node 9, assuming the matcher
+/// computes the slice for the first two depth levels
+/// ============================
+/// 1 2 3 4
+/// |_______| |______|
+/// | | |
+/// | 5 6
+/// |___|_____________|
+/// | |
+/// 7 8
+/// |_______________|
+/// |
+/// 9
+///
+/// Assuming all local orders match the numbering order:
+/// {5, 7, 6, 8, 9}
+namespace mlir::query::matcher {
+class BackwardSliceMatcher {
+public:
+ explicit BackwardSliceMatcher(query::matcher::DynMatcher &&innerMatcher,
+ int64_t maxDepth, bool inclusive,
+ bool omitBlockArguments, bool omitUsesFromAbove)
+ : innerMatcher(std::move(innerMatcher)), maxDepth(maxDepth),
+ inclusive(inclusive), omitBlockArguments(omitBlockArguments),
+ omitUsesFromAbove(omitUsesFromAbove) {}
+ bool match(Operation *op, SetVector<Operation *> &backwardSlice) {
+ BackwardSliceOptions options;
+ return (innerMatcher.match(op) &&
+ matches(op, backwardSlice, options, maxDepth));
+ }
+
+private:
+ bool matches(Operation *rootOp, llvm::SetVector<Operation *> &backwardSlice,
+ BackwardSliceOptions &options, int64_t maxDepth);
+
+private:
+ // The outer matcher (e.g., BackwardSliceMatcher) relies on the innerMatcher
+ // to determine whether we want to traverse the DAG or not. For example, we
+ // want to explore the DAG only if the top-level operation name is
+ // "arith.addf".
+ query::matcher::DynMatcher innerMatcher;
+ // maxDepth specifies the maximum depth that the matcher can traverse in the
+ // DAG. For example, if maxDepth is 2, the matcher will explore the defining
+ // operations of the top-level op up to 2 levels.
+ int64_t maxDepth;
+
+ bool inclusive;
+ bool omitBlockArguments;
+ bool omitUsesFromAbove;
+};
+
+// Matches transitive defs of a top level operation up to N levels
+inline BackwardSliceMatcher
+m_GetDefinitions(query::matcher::DynMatcher innerMatcher, int64_t maxDepth,
+ bool inclusive, bool omitBlockArguments,
+ bool omitUsesFromAbove) {
+ assert(maxDepth >= 0 && "maxDepth must be non-negative");
+ return BackwardSliceMatcher(std::move(innerMatcher), maxDepth, inclusive,
+ omitBlockArguments, omitUsesFromAbove);
+}
+} // namespace mlir::query::matcher
+
+#endif // MLIR_TOOLS_MLIRQUERY_MATCHERS_EXTRAMATCHERS_H
diff --git a/mlir/include/mlir/Query/Matcher/Marshallers.h b/mlir/include/mlir/Query/Matcher/Marshallers.h
index 6ed35ac0ddccc..012bf7b9ec4a9 100644
--- a/mlir/include/mlir/Query/Matcher/Marshallers.h
+++ b/mlir/include/mlir/Query/Matcher/Marshallers.h
@@ -50,6 +50,36 @@ struct ArgTypeTraits<llvm::StringRef> {
}
};
+template <>
+struct ArgTypeTraits<int64_t> {
+  static bool hasCorrectType(const VariantValue &value) {
+    return value.isSigned();
+  }
+  // Keep the full int64_t; narrowing to unsigned would truncate/wrap it.
+  static int64_t get(const VariantValue &value) { return value.getSigned(); }
+
+  static ArgKind getKind() { return ArgKind::Signed; }
+
+  static std::optional<std::string> getBestGuess(const VariantValue &) {
+    return std::nullopt;
+  }
+};
+
+template <>
+struct ArgTypeTraits<bool> {
+  static bool hasCorrectType(const VariantValue &value) {
+    return value.isBoolean();
+  }
+  // Return the bool itself rather than widening it to unsigned.
+  static bool get(const VariantValue &value) { return value.getBoolean(); }
+
+  static ArgKind getKind() { return ArgKind::Boolean; }
+
+  static std::optional<std::string> getBestGuess(const VariantValue &) {
+    return std::nullopt;
+  }
+};
+
template <>
struct ArgTypeTraits<DynMatcher> {
diff --git a/mlir/include/mlir/Query/Matcher/MatchFinder.h b/mlir/include/mlir/Query/Matcher/MatchFinder.h
index b008a21f53ae2..6b554394b3654 100644
--- a/mlir/include/mlir/Query/Matcher/MatchFinder.h
+++ b/mlir/include/mlir/Query/Matcher/MatchFinder.h
@@ -7,7 +7,7 @@
//===----------------------------------------------------------------------===//
//
// This file contains the MatchFinder class, which is used to find operations
-// that match a given matcher.
+// that match a given matcher and print them.
//
//===----------------------------------------------------------------------===//
@@ -15,25 +15,40 @@
#define MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H
#include "MatchersInternal.h"
+#include "mlir/Query/Query.h"
+#include "mlir/Query/QuerySession.h"
+#include "llvm/ADT/SetVector.h"
namespace mlir::query::matcher {
-// MatchFinder is used to find all operations that match a given matcher.
+/// A class that provides utilities to find operations in a DAG
class MatchFinder {
+
public:
- // Returns all operations that match the given matcher.
- static std::vector<Operation *> getMatches(Operation *root,
- DynMatcher matcher) {
- std::vector<Operation *> matches;
-
- // Simple match finding with walk.
- root->walk([&](Operation *subOp) {
- if (matcher.match(subOp))
- matches.push_back(subOp);
- });
-
- return matches;
- }
+  /// Nested struct holding the information gathered for a single match.
+  struct MatchResult {
+    MatchResult() = default;
+    MatchResult(Operation *rootOp, std::vector<Operation *> matchedOps);
+
+    /// The operation on which the matcher succeeded (the match's root).
+    Operation *rootOp = nullptr;
+    /// The operations forming the match (the "matching environment"), so
+    /// callers can easily extract the matched operations.
+    std::vector<Operation *> matchedOps;
+  };
+  /// Walks `root` and collects one MatchResult ("rootOp" + "matching
+  /// environment") per operation accepted by `matcher`.
+  std::vector<MatchResult> collectMatches(Operation *root,
+                                          DynMatcher matcher) const;
+  /// Prints a note pointing at the matched operation's source location.
+  void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op) const;
+  /// Labels the matched operation with the given binding (e.g., "root") and
+  /// prints it.
+  void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op,
+                  const std::string &binding) const;
+  /// Concatenates the matchedOps of every MatchResult into a single vector.
+  std::vector<Operation *>
+  flattenMatchedOps(std::vector<MatchResult> &matches) const;
};
} // namespace mlir::query::matcher
diff --git a/mlir/include/mlir/Query/Matcher/MatchersInternal.h b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
index 117f7d4edef9e..183b2514e109f 100644
--- a/mlir/include/mlir/Query/Matcher/MatchersInternal.h
+++ b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
@@ -8,8 +8,9 @@
//
// Implements the base layer of the matcher framework.
//
-// Matchers are methods that return a Matcher which provides a method
-// match(Operation *op)
+// Matchers are methods that return a Matcher which provides one of the
+// following methods: match(Operation *op) or match(Operation *op,
+// SetVector<Operation *> &matchedOps).
//
// The matcher functions are defined in include/mlir/IR/Matchers.h.
// This file contains the wrapper classes needed to construct matchers for
@@ -25,6 +26,31 @@
namespace mlir::query::matcher {
+// Defaults to false if T has no match() method with the signature:
+// match(Operation* op).
+template <typename T, typename = void>
+struct has_simple_match : std::false_type {};
+
+// Specialized type trait that evaluates to true if T has a match() method
+// with the signature: match(Operation* op).
+template <typename T>
+struct has_simple_match<T, std::void_t<decltype(std::declval<T>().match(
+ std::declval<Operation *>()))>>
+ : std::true_type {};
+
+// Defaults to false if T has no match() method with the signature:
+// match(Operation* op, SetVector<Operation*>&).
+template <typename T, typename = void>
+struct has_bound_match : std::false_type {};
+
+// Specialized type trait that evaluates to true if T has a match() method
+// with the signature: match(Operation* op, SetVector<Operation*>&).
+template <typename T>
+struct has_bound_match<T, std::void_t<decltype(std::declval<T>().match(
+ std::declval<Operation *>(),
+ std::declval<SetVector<Operation *> &>()))>>
+ : std::true_type {};
+
// Generic interface for matchers on an MLIR operation.
class MatcherInterface
: public llvm::ThreadSafeRefCountedBase<MatcherInterface> {
@@ -32,6 +58,7 @@ class MatcherInterface
virtual ~MatcherInterface() = default;
virtual bool match(Operation *op) = 0;
+ virtual bool match(Operation *op, SetVector<Operation *> &matchedOps) = 0;
};
// MatcherFnImpl takes a matcher function object and implements
@@ -40,14 +67,25 @@ template <typename MatcherFn>
class MatcherFnImpl : public MatcherInterface {
public:
MatcherFnImpl(MatcherFn &matcherFn) : matcherFn(matcherFn) {}
- bool match(Operation *op) override { return matcherFn.match(op); }
+
+ bool match(Operation *op) override {
+ if constexpr (has_simple_match<MatcherFn>::value)
+ return matcherFn.match(op);
+ return false;
+ }
+
+ bool match(Operation *op, SetVector<Operation *> &matchedOps) override {
+ if constexpr (has_bound_match<MatcherFn>::value)
+ return matcherFn.match(op, matchedOps);
+ return false;
+ }
private:
MatcherFn matcherFn;
};
-// Matcher wraps a MatcherInterface implementation and provides a match()
-// method that redirects calls to the underlying implementation.
+// Matcher wraps a MatcherInterface implementation and provides match()
+// methods that redirect calls to the underlying implementation.
class DynMatcher {
public:
// Takes ownership of the provided implementation pointer.
@@ -62,12 +100,13 @@ class DynMatcher {
}
bool match(Operation *op) const { return implementation->match(op); }
+ bool match(Operation *op, SetVector<Operation *> &matchedOps) const {
+ return implementation->match(op, matchedOps);
+ }
- void setFunctionName(StringRef name) { functionName = name.str(); };
-
- bool hasFunctionName() const { return !functionName.empty(); };
-
- StringRef getFunctionName() const { return functionName; };
+ void setFunctionName(StringRef name) { functionName = name.str(); }
+ bool hasFunctionName() const { return !functionName.empty(); }
+ StringRef getFunctionName() const { return functionName; }
private:
llvm::IntrusiveRefCntPtr<MatcherInterface> implementation;
diff --git a/mlir/include/mlir/Query/Matcher/VariantValue.h b/mlir/include/mlir/Query/Matcher/VariantValue.h
index 449f8b3a01e02..98c0a18e25101 100644
--- a/mlir/include/mlir/Query/Matcher/VariantValue.h
+++ b/mlir/include/mlir/Query/Matcher/VariantValue.h
@@ -21,7 +21,7 @@
namespace mlir::query::matcher {
// All types that VariantValue can contain.
-enum class ArgKind { Matcher, String };
+enum class ArgKind { Boolean, Matcher, Signed, String };
// A variant matcher object to abstract simple and complex matchers into a
// single object type.
@@ -81,6 +81,8 @@ class VariantValue {
// Specific constructors for each supported type.
VariantValue(const llvm::StringRef string);
VariantValue(const VariantMatcher &matcher);
+ VariantValue(int64_t signedValue);
+ VariantValue(bool setBoolean);
// String value functions.
bool isString() const;
@@ -92,21 +94,36 @@ class VariantValue {
const VariantMatcher &getMatcher() const;
void setMatcher(const VariantMatcher &matcher);
+ // Signed value functions.
+ bool isSigned() const;
+ int64_t getSigned() const;
+ void setSigned(int64_t signedValue);
+
+ // Boolean value functions.
+ bool isBoolean() const;
+ bool getBoolean() const;
+ void setBoolean(bool booleanValue);
// String representation of the type of the value.
std::string getTypeAsString() const;
+ explicit operator bool() const { return hasValue(); }
+ bool hasValue() const { return type != ValueType::Nothing; }
private:
void reset();
// All supported value types.
enum class ValueType {
+ Boolean,
+ Matcher,
Nothing,
+ Signed,
String,
- Matcher,
};
// All supported value types.
union AllValues {
+ bool Boolean;
+ int64_t Signed;
llvm::StringRef *String;
VariantMatcher *Matcher;
};
diff --git a/mlir/lib/Query/Matcher/CMakeLists.txt b/mlir/lib/Query/Matcher/CMakeLists.txt
index 3adff9f99243f..d84b1b50e8b04 100644
--- a/mlir/lib/Query/Matcher/CMakeLists.txt
+++ b/mlir/lib/Query/Matcher/CMakeLists.txt
@@ -1,4 +1,6 @@
add_mlir_library(MLIRQueryMatcher
+ MatchFinder.cpp
+ ExtraMatchers.cpp
Parser.cpp
RegistryManager.cpp
VariantValue.cpp
diff --git a/mlir/lib/Query/Matcher/ExtraMatchers.cpp b/mlir/lib/Query/Matcher/ExtraMatchers.cpp
new file mode 100644
index 0000000000000..1c69995a5d690
--- /dev/null
+++ b/mlir/lib/Query/Matcher/ExtraMatchers.cpp
@@ -0,0 +1,66 @@
+//===- ExtraMatchers.cpp - Various common matchers ---------------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements specific matchers
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Query/Matcher/ExtraMatchers.h"
+
+namespace mlir::query::matcher {
+
+bool BackwardSliceMatcher::matches(Operation *rootOp,
+                                   llvm::SetVector<Operation *> &backwardSlice,
+                                   BackwardSliceOptions &options,
+                                   int64_t maxDepth) {
+  options.inclusive = inclusive;
+  options.omitUsesFromAbove = omitUsesFromAbove;
+  options.omitBlockArguments = omitBlockArguments;
+  backwardSlice.clear();
+  // Shortest known distance (in def-use edges) from rootOp to each op.
+  llvm::DenseMap<Operation *, int64_t> opDepths;
+  opDepths[rootOp] = 0;
+  // Records `depth` for `op` unless a shorter path to it is already known.
+  auto recordDepth = [&](Operation *op, int64_t depth) {
+    auto [it, inserted] = opDepths.try_emplace(op, depth);
+    if (!inserted)
+      it->second = std::min(it->second, depth);
+  };
+  // getBackwardSlice runs the filter on an op before traversing its operands,
+  // so the filter is where the operands' producers get their depths.
+  options.filter = [&](Operation *subOp) {
+    // Ops without a recorded depth were only reachable beyond maxDepth (or
+    // through uses from above, which this matcher does not track).
+    auto it = opDepths.find(subOp);
+    if (it == opDepths.end() || it->second > maxDepth)
+      return false;
+    int64_t newDepth = it->second + 1;
+    // Producers would lie beyond maxDepth; leaving them unrecorded makes the
+    // filter reject them unless a shorter path reaches them.
+    if (newDepth > maxDepth)
+      return true;
+    // Visit *every* operand: stopping after the first one would leave the
+    // remaining producers unrecorded (and thus unbounded).
+    for (Value operand : subOp->getOperands()) {
+      if (Operation *definingOp = operand.getDefiningOp()) {
+        recordDepth(definingOp, newDepth);
+        continue;
+      }
+      // A block argument leads to the op owning the enclosing region.
+      auto blockArgument = cast<BlockArgument>(operand);
+      if (Operation *parentOp = blockArgument.getOwner()->getParentOp())
+        recordDepth(parentOp, newDepth);
+    }
+    return true;
+  };
+  getBackwardSlice(rootOp, &backwardSlice, options);
+  // rootOp already satisfied the inner matcher, so report a match.
+  return true;
+}
+
+} // namespace mlir::query::matcher
diff --git a/mlir/lib/Query/Matcher/MatchFinder.cpp b/mlir/lib/Query/Matcher/MatchFinder.cpp
new file mode 100644
index 0000000000000..386b85b1e27a6
--- /dev/null
+++ b/mlir/lib/Query/Matcher/MatchFinder.cpp
@@ -0,0 +1,68 @@
+//===- MatchFinder.cpp - -----------------------------------------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the method definitions for the `MatchFinder` class
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Query/Matcher/MatchFinder.h"
+namespace mlir::query::matcher {
+
+MatchFinder::MatchResult::MatchResult(Operation *rootOp,
+ std::vector<Operation *> matchedOps)
+ : rootOp(rootOp), matchedOps(std::move(matchedOps)) {}
+
+std::vector<MatchFinder::MatchResult>
+MatchFinder::collectMatches(Operation *root, DynMatcher matcher) const {
+ std::vector<MatchResult> results;
+ llvm::SetVector<Operation *> tempStorage;
+ root->walk([&](Operation *subOp) {
+ if (matcher.match(subOp)) {
+ MatchResult match;
+ match.rootOp = subOp;
+ match.matchedOps.push_back(subOp);
+ results.push_back(std::move(match));
+ } else if (matcher.match(subOp, tempStorage)) {
+ results.emplace_back(subOp, std::vector<Operation *>(tempStorage.begin(),
+ tempStorage.end()));
+ }
+ tempStorage.clear();
+ });
+ return results;
+}
+
+void MatchFinder::printMatch(llvm::raw_ostream &os, QuerySession &qs,
+                             Operation *op) const {
+  // findInstanceOf (unlike cast<>) also accepts e.g. NameLoc-wrapped locs.
+  auto fileLoc = op->getLoc()->findInstanceOf<FileLineColLoc>();
+  SMLoc smloc = qs.getSourceManager().FindLocForLineAndColumn(
+      qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn());
+  qs.getSourceManager().GetMessage(smloc, llvm::SourceMgr::DK_Note, "")
+      .print("", os, /*ShowColors=*/true, /*ShowKindLabel=*/false);
+}
+
+void MatchFinder::printMatch(llvm::raw_ostream &os, QuerySession &qs,
+                             Operation *op, const std::string &binding) const {
+  auto fileLoc = op->getLoc()->findInstanceOf<FileLineColLoc>();
+  auto smloc = qs.getSourceManager().FindLocForLineAndColumn(
+      qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn());
+  qs.getSourceManager().PrintMessage(os, smloc, llvm::SourceMgr::DK_Note,
+                                     "\"" + binding + "\" binds here");
+}
+
+std::vector<Operation *>
+MatchFinder::flattenMatchedOps(std::vector<MatchResult> &matches) const {
+ std::vector<Operation *> newVector;
+ for (auto &result : matches) {
+ newVector.insert(newVector.end(), result.matchedOps.begin(),
+ result.matchedOps.end());
+ }
+ return newVector;
+}
+
+} // namespace mlir::query::matcher
diff --git a/mlir/lib/Query/Matcher/Parser.cpp b/mlir/lib/Query/Matcher/Parser.cpp
index 3609e24f9939f..e392a885c511b 100644
--- a/mlir/lib/Query/Matcher/Parser.cpp
+++ b/mlir/lib/Query/Matcher/Parser.cpp
@@ -135,6 +135,18 @@ class Parser::CodeTokenizer {
case '\'':
consumeStringLiteral(&result);
break;
+ case '0':
+ case '1':
+ case '2':
+ case '3':
+ case '4':
+ case '5':
+ case '6':
+ case '7':
+ case '8':
+ case '9':
+ consumeNumberLiteral(&result);
+ break;
default:
parseIdentifierOrInvalid(&result);
break;
@@ -144,6 +156,18 @@ class Parser::CodeTokenizer {
return result;
}
+  void consumeNumberLiteral(TokenInfo *result) {
+    // Parse a copy so a failure (incl. overflow) never half-consumes `code`.
+    StringRef remaining = code;
+    int64_t value = 0;
+    bool failed = remaining.consumeInteger(0, value);
+    size_t numConsumed = failed ? 1 : code.size() - remaining.size();
+    result->text = code.take_front(numConsumed);
+    result->kind = failed ? TokenKind::InvalidChar : TokenKind::Literal;
+    result->value = failed ? VariantValue() : VariantValue(value);
+    code = code.drop_front(numConsumed);
+  }
+
// Consume a string literal, handle escape sequences and missing closing
// quote.
void consumeStringLiteral(TokenInfo *result) {
@@ -195,9 +219,22 @@ class Parser::CodeTokenizer {
break;
++tokenLength;
}
- result->kind = TokenKind::Ident;
- result->text = code.substr(0, tokenLength);
+ llvm::StringRef token = code.substr(0, tokenLength);
code = code.drop_front(tokenLength);
+ // Check if the identifier is a boolean literal
+ if (token == "true") {
+ result->text = "false";
+ result->kind = TokenKind::Literal;
+ result->value = true;
+ } else if (token == "false") {
+ result->text = "false";
+ result->kind = TokenKind::Literal;
+ result->value = false;
+ } else {
+ // Otherwise it is treated as a normal identifier
+ result->kind = TokenKind::Ident;
+ result->text = token;
+ }
} else {
result->kind = TokenKind::InvalidChar;
result->text = code.substr(0, 1);
@@ -257,13 +294,19 @@ bool Parser::parseIdentifierPrefixImpl(VariantValue *value) {
if (tokenizer->nextTokenKind() != TokenKind::OpenParen) {
// Parse as a named value.
- auto namedValue =
- namedValues ? namedValues->lookup(nameToken.text) : VariantValue();
+ if (auto namedValue = namedValues ? namedValues->lookup(nameToken.text)
+ : VariantValue()) {
- if (!namedValue.isMatcher()) {
- error->addError(tokenizer->peekNextToken().range,
- ErrorType::ParserNotAMatcher);
- return false;
+ if (tokenizer->nextTokenKind() != TokenKind::Period) {
+ *value = namedValue;
+ return true;
+ }
+
+ if (!namedValue.isMatcher()) {
+ error->addError(tokenizer->peekNextToken().range,
+ ErrorType::ParserNotAMatcher);
+ return false;
+ }
}
if (tokenizer->nextTokenKind() == TokenKind::NewLine) {
diff --git a/mlir/lib/Query/Matcher/RegistryManager.cpp b/mlir/lib/Query/Matcher/RegistryManager.cpp
index 645db7109c2de..4b511c5f009e7 100644
--- a/mlir/lib/Query/Matcher/RegistryManager.cpp
+++ b/mlir/lib/Query/Matcher/RegistryManager.cpp
@@ -19,16 +19,15 @@
namespace mlir::query::matcher {
namespace {
-// This is needed because these matchers are defined as overloaded functions.
-using IsConstantOp = detail::constant_op_matcher();
-using HasOpAttrName = detail::AttrOpMatcher(llvm::StringRef);
-using HasOpName = detail::NameOpMatcher(llvm::StringRef);
-
// Enum to string for autocomplete.
static std::string asArgString(ArgKind kind) {
switch (kind) {
+ case ArgKind::Boolean:
+ return "Boolean";
case ArgKind::Matcher:
return "Matcher";
+ case ArgKind::Signed:
+ return "Signed";
case ArgKind::String:
return "String";
}
diff --git a/mlir/lib/Query/Matcher/VariantValue.cpp b/mlir/lib/Query/Matcher/VariantValue.cpp
index 65bd4bd77bcf8..1cb2d48f9d56f 100644
--- a/mlir/lib/Query/Matcher/VariantValue.cpp
+++ b/mlir/lib/Query/Matcher/VariantValue.cpp
@@ -56,6 +56,14 @@ VariantValue::VariantValue(const VariantMatcher &matcher)
value.Matcher = new VariantMatcher(matcher);
}
+VariantValue::VariantValue(int64_t signedValue) : type(ValueType::Signed) {
+ value.Signed = signedValue;
+}
+
+VariantValue::VariantValue(bool setBoolean) : type(ValueType::Boolean) {
+ value.Boolean = setBoolean;
+}
+
VariantValue::~VariantValue() { reset(); }
VariantValue &VariantValue::operator=(const VariantValue &other) {
@@ -69,6 +77,12 @@ VariantValue &VariantValue::operator=(const VariantValue &other) {
case ValueType::Matcher:
setMatcher(other.getMatcher());
break;
+ case ValueType::Signed:
+ setSigned(other.getSigned());
+ break;
+ case ValueType::Boolean:
+ setBoolean(other.getBoolean());
+ break;
case ValueType::Nothing:
type = ValueType::Nothing;
break;
@@ -85,12 +99,34 @@ void VariantValue::reset() {
delete value.Matcher;
break;
// Cases that do nothing.
+ case ValueType::Signed:
+ case ValueType::Boolean:
case ValueType::Nothing:
break;
}
type = ValueType::Nothing;
}
+// Signed
+bool VariantValue::isSigned() const { return type == ValueType::Signed; }
+int64_t VariantValue::getSigned() const { return value.Signed; }
+void VariantValue::setSigned(int64_t newValue) {
+  reset(); // Release any previously held payload (e.g. a heap Matcher).
+  type = ValueType::Signed;
+  value.Signed = newValue;
+}
+
+// Boolean. The constructor and these accessors all use the `Boolean` union
+// member; reading a member other than the one last written is undefined
+// behavior (a stored `false` could read back as `true`).
+bool VariantValue::isBoolean() const { return type == ValueType::Boolean; }
+bool VariantValue::getBoolean() const { return value.Boolean; }
+void VariantValue::setBoolean(bool newValue) {
+  reset(); // Release any previously held payload (e.g. a heap Matcher).
+  type = ValueType::Boolean;
+  value.Boolean = newValue;
+}
+
bool VariantValue::isString() const { return type == ValueType::String; }
const llvm::StringRef &VariantValue::getString() const {
@@ -123,6 +159,10 @@ std::string VariantValue::getTypeAsString() const {
return "String";
case ValueType::Matcher:
return "Matcher";
+ case ValueType::Signed:
+ return "Signed";
+ case ValueType::Boolean:
+ return "Boolean";
case ValueType::Nothing:
return "Nothing";
}
diff --git a/mlir/lib/Query/Query.cpp b/mlir/lib/Query/Query.cpp
index 869ee8f2ae1dc..f060ab80aa73d 100644
--- a/mlir/lib/Query/Query.cpp
+++ b/mlir/lib/Query/Query.cpp
@@ -12,6 +12,7 @@
#include "mlir/IR/IRMapping.h"
#include "mlir/Query/Matcher/MatchFinder.h"
#include "mlir/Query/QuerySession.h"
+#include "llvm/ADT/SetVector.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
@@ -26,15 +27,6 @@ complete(llvm::StringRef line, size_t pos, const QuerySession &qs) {
return QueryParser::complete(line, pos, qs);
}
-static void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op,
- const std::string &binding) {
- auto fileLoc = op->getLoc()->findInstanceOf<FileLineColLoc>();
- auto smloc = qs.getSourceManager().FindLocForLineAndColumn(
- qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn());
- qs.getSourceManager().PrintMessage(os, smloc, llvm::SourceMgr::DK_Note,
- "\"" + binding + "\" binds here");
-}
-
// TODO: Extract into a helper function that can be reused outside query
// context.
static Operation *extractFunction(std::vector<Operation *> &ops,
@@ -125,28 +117,34 @@ LogicalResult QuitQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
LogicalResult MatchQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
Operation *rootOp = qs.getRootOp();
int matchCount = 0;
- std::vector<Operation *> matches =
- matcher::MatchFinder().getMatches(rootOp, matcher);
+ matcher::MatchFinder finder;
+ auto matches = finder.collectMatches(rootOp, std::move(matcher));
// An extract call is recognized by considering if the matcher has a name.
// TODO: Consider making the extract more explicit.
if (matcher.hasFunctionName()) {
auto functionName = matcher.getFunctionName();
+ std::vector<Operation *> flattenedMatches =
+ finder.flattenMatchedOps(matches);
Operation *function =
- extractFunction(matches, rootOp->getContext(), functionName);
+ extractFunction(flattenedMatches, rootOp->getContext(), functionName);
os << "\n" << *function << "\n\n";
function->erase();
return mlir::success();
}
os << "\n";
- for (Operation *op : matches) {
+ for (auto &results : matches) {
os << "Match #" << ++matchCount << ":\n\n";
- // Placeholder "root" binding for the initial draft.
- printMatch(os, qs, op, "root");
+ for (auto op : results.matchedOps) {
+ if (op == results.rootOp) {
+ finder.printMatch(os, qs, op, "root");
+ } else {
+ finder.printMatch(os, qs, op);
+ }
+ }
}
os << matchCount << (matchCount == 1 ? " match.\n\n" : " matches.\n\n");
-
return mlir::success();
}
diff --git a/mlir/lib/Query/QueryParser.cpp b/mlir/lib/Query/QueryParser.cpp
index 31aead7d403d0..3990b697ead7f 100644
--- a/mlir/lib/Query/QueryParser.cpp
+++ b/mlir/lib/Query/QueryParser.cpp
@@ -166,7 +166,6 @@ QueryRef QueryParser::doParse() {
case ParsedQueryKind::Quit:
return endQuery(new QuitQuery);
-
case ParsedQueryKind::Match: {
if (completionPos) {
return completeMatcherExpression();
diff --git a/mlir/test/mlir-query/complex-test.mlir b/mlir/test/mlir-query/complex-test.mlir
new file mode 100644
index 0000000000000..d18a9cc1d1550
--- /dev/null
+++ b/mlir/test/mlir-query/complex-test.mlir
@@ -0,0 +1,32 @@
+// RUN: mlir-query %s -c "m getDefinitions(hasOpName(\"arith.addf\"),2,true,false,false)" | FileCheck %s
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @slice_use_from_above(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>) {
+ %0 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %2 = arith.addf %in, %in : f32
+ linalg.yield %2 : f32
+ } -> tensor<5x5xf32>
+ %collapsed = tensor.collapse_shape %0 [[0, 1]] : tensor<5x5xf32> into tensor<25xf32>
+ %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %c2 = arith.constant 2 : index
+ %extracted = tensor.extract %collapsed[%c2] : tensor<25xf32>
+ %2 = arith.addf %extracted, %extracted : f32
+ linalg.yield %2 : f32
+ } -> tensor<5x5xf32>
+ return
+}
+
+// CHECK: Match #1:
+
+// CHECK: %[[LINALG:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>)
+// CHECK: %[[ADDF1:.*]] = arith.addf %in, %in : f32
+
+// CHECK: Match #2:
+
+// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %0 {{\[\[.*\]\]}} : tensor<5x5xf32> into tensor<25xf32>
+// CHECK: %[[C2:.*]] = arith.constant {{.*}} : index
+// CHECK: %[[EXTRACTED:.*]] = tensor.extract %[[COLLAPSED]][%[[C2]]] : tensor<25xf32>
+// CHECK: %[[ADDF2:.*]] = arith.addf %[[EXTRACTED]], %[[EXTRACTED]] : f32
diff --git a/mlir/tools/mlir-query/mlir-query.cpp b/mlir/tools/mlir-query/mlir-query.cpp
index 0ed4f94d5802b..1d392e5f0dcfd 100644
--- a/mlir/tools/mlir-query/mlir-query.cpp
+++ b/mlir/tools/mlir-query/mlir-query.cpp
@@ -15,6 +15,7 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/InitAllDialects.h"
+#include "mlir/Query/Matcher/ExtraMatchers.h"
#include "mlir/Query/Matcher/Registry.h"
#include "mlir/Tools/mlir-query/MlirQueryMain.h"
@@ -39,6 +40,8 @@ int main(int argc, char **argv) {
query::matcher::Registry matcherRegistry;
// Matchers registered in alphabetical order for consistency:
+ matcherRegistry.registerMatcher("getDefinitions",
+ query::matcher::m_GetDefinitions);
matcherRegistry.registerMatcher("hasOpAttrName",
static_cast<HasOpAttrName *>(m_Attr));
matcherRegistry.registerMatcher("hasOpName", static_cast<HasOpName *>(m_Op));
>From 5f940da512394c8a57eacb57bfa451fa89f68c4d Mon Sep 17 00:00:00 2001
From: chios202 <chio.star at yahoo.com>
Date: Thu, 24 Apr 2025 15:32:47 +0000
Subject: [PATCH 2/2] Update grammar; make BackwardSlice matcher more generic;
capture values in tests
---
.../mlir/Query/Matcher/ExtraMatchers.h | 102 +++++++++++++-----
mlir/include/mlir/Query/Matcher/MatchFinder.h | 25 +++--
mlir/lib/Query/Matcher/CMakeLists.txt | 1 -
mlir/lib/Query/Matcher/ExtraMatchers.cpp | 66 ------------
mlir/lib/Query/Matcher/Parser.h | 5 +-
mlir/lib/Query/QueryParser.cpp | 1 +
mlir/test/mlir-query/complex-test.mlir | 2 +-
mlir/tools/mlir-query/mlir-query.cpp | 5 +-
8 files changed, 99 insertions(+), 108 deletions(-)
delete mode 100644 mlir/lib/Query/Matcher/ExtraMatchers.cpp
diff --git a/mlir/include/mlir/Query/Matcher/ExtraMatchers.h b/mlir/include/mlir/Query/Matcher/ExtraMatchers.h
index 4766a767cf783..48cab7760d5cf 100644
--- a/mlir/include/mlir/Query/Matcher/ExtraMatchers.h
+++ b/mlir/include/mlir/Query/Matcher/ExtraMatchers.h
@@ -12,16 +12,15 @@
#ifndef MLIR_TOOLS_MLIRQUERY_MATCHERS_EXTRAMATCHERS_H
#define MLIR_TOOLS_MLIRQUERY_MATCHERS_EXTRAMATCHERS_H
+
#include "mlir/Analysis/SliceAnalysis.h"
-#include "mlir/Query/Matcher/MatchersInternal.h"
-/// A matcher encapsulating the initial `getBackwardSlice` method from
-/// SliceAnalysis.h
+/// A matcher encapsulating `getBackwardSlice` method from SliceAnalysis.h.
/// Additionally, it limits the slice computation to a certain depth level using
-/// a custom filter
+/// a custom filter.
///
-/// Example starting from node 9, assuming the matcher
-/// computes the slice for the first two depth levels
+/// Example: starting from node 9, assuming the matcher
+/// computes the slice for the first two depth levels:
/// ============================
/// 1 2 3 4
/// |_______| |______|
@@ -37,18 +36,23 @@
/// Assuming all local orders match the numbering order:
/// {5, 7, 6, 8, 9}
namespace mlir::query::matcher {
+
+template <typename Matcher>
class BackwardSliceMatcher {
public:
- explicit BackwardSliceMatcher(query::matcher::DynMatcher &&innerMatcher,
- int64_t maxDepth, bool inclusive,
- bool omitBlockArguments, bool omitUsesFromAbove)
+ /// Stores the gating `innerMatcher` (taken by value and moved into the
+ /// member) together with the depth limit and the slicing flags.
+ BackwardSliceMatcher(Matcher innerMatcher, int64_t maxDepth, bool inclusive,
+ bool omitBlockArguments, bool omitUsesFromAbove)
: innerMatcher(std::move(innerMatcher)), maxDepth(maxDepth),
inclusive(inclusive), omitBlockArguments(omitBlockArguments),
omitUsesFromAbove(omitUsesFromAbove) {}
- bool match(Operation *op, SetVector<Operation *> &backwardSlice) {
+
+ /// Returns whether `innerMatcher` matches `rootOp`. Only in that case is
+ /// `backwardSlice` recomputed (via `matches`); otherwise the `&&` below
+ /// short-circuits and `backwardSlice` is left untouched.
+ bool match(Operation *rootOp, SetVector<Operation *> &backwardSlice) {
BackwardSliceOptions options;
+ // Forward the flags given at construction; `matches` installs the depth
+ // filter on top of them.
+ options.inclusive = inclusive;
+ options.omitUsesFromAbove = omitUsesFromAbove;
+ options.omitBlockArguments = omitBlockArguments;
+ return (innerMatcher.match(rootOp) &&
+ matches(rootOp, backwardSlice, options, maxDepth));
}
private:
@@ -57,29 +61,75 @@ class BackwardSliceMatcher {
private:
// The outer matcher (e.g., BackwardSliceMatcher) relies on the innerMatcher
- // to determine whether we want to traverse the DAG or not. For example, we
- // want to explore the DAG only if the top-level operation name is
- // "arith.addf".
- query::matcher::DynMatcher innerMatcher;
- // maxDepth specifies the maximum depth that the matcher can traverse in the
- // DAG. For example, if maxDepth is 2, the matcher will explore the defining
+ // to determine whether we want to traverse the IR or not. For example, we
+ // want to explore the IR only if the top-level operation name is
+ // `"arith.addf"`.
+ Matcher innerMatcher;
+ // `maxDepth` specifies the maximum depth that the matcher can traverse the
+ // IR. For example, if `maxDepth` is 2, the matcher will explore the defining
// operations of the top-level op up to 2 levels.
int64_t maxDepth;
-
bool inclusive;
bool omitBlockArguments;
bool omitUsesFromAbove;
};
-// Matches transitive defs of a top level operation up to N levels
-inline BackwardSliceMatcher
-m_GetDefinitions(query::matcher::DynMatcher innerMatcher, int64_t maxDepth,
- bool inclusive, bool omitBlockArguments,
- bool omitUsesFromAbove) {
+/// Fills `backwardSlice` with the ops reachable backwards from `rootOp` within
+/// `maxDepth` levels, honouring the slicing flags already set in `options`.
+/// Always returns true; `match` checks `innerMatcher` before calling this.
+template <typename Matcher>
+bool BackwardSliceMatcher<Matcher>::matches(
+    Operation *rootOp, llvm::SetVector<Operation *> &backwardSlice,
+    BackwardSliceOptions &options, int64_t maxDepth) {
+  backwardSlice.clear();
+  llvm::DenseMap<Operation *, int64_t> opDepths;
+  // The starting point is the root op; therefore, we set its depth to 0.
+  opDepths[rootOp] = 0;
+  options.filter = [&](Operation *subOp) {
+    // Ops without a recorded depth were only reached through paths longer
+    // than maxDepth, so slicing stops there. Use `find` rather than
+    // `operator[]`, which would default-insert a depth of 0 and reset the
+    // depth budget for that branch.
+    auto it = opDepths.find(subOp);
+    if (it == opDepths.end() || it->second > maxDepth)
+      return false;
+    int64_t newDepth = it->second + 1;
+    // Producers one level past maxDepth are deliberately left unrecorded.
+    if (newDepth > maxDepth)
+      return true;
+    // Record the smallest depth seen for the producer of *every* operand;
+    // returning inside this loop would leave later producers without a depth.
+    for (Value operand : subOp->getOperands()) {
+      Operation *producer = operand.getDefiningOp();
+      if (!producer) {
+        // A block argument is attributed to the op owning its block.
+        producer = cast<BlockArgument>(operand).getOwner()->getParentOp();
+        if (!producer)
+          continue;
+      }
+      auto [entry, inserted] = opDepths.try_emplace(producer, newDepth);
+      if (!inserted)
+        entry->second = std::min(entry->second, newDepth);
+    }
+    // NOTE(review): values used from above inside subOp's regions get no depth
+    // here; their producers are dropped unless reached via a regular operand.
+    return true;
+  };
+  getBackwardSlice(rootOp, &backwardSlice, options);
+  return true;
+}
+
+/// Matches the transitive defs of ops accepted by `innerMatcher`, up to `maxDepth` levels.
+template <typename Matcher>
+inline BackwardSliceMatcher<Matcher>
+m_GetDefinitions(Matcher innerMatcher, int64_t maxDepth, bool inclusive,
+ bool omitBlockArguments, bool omitUsesFromAbove) {
assert(maxDepth >= 0 && "maxDepth must be non-negative");
- return BackwardSliceMatcher(std::move(innerMatcher), maxDepth, inclusive,
- omitBlockArguments, omitUsesFromAbove);
+ return BackwardSliceMatcher<Matcher>(std::move(innerMatcher), maxDepth,
+ inclusive, omitBlockArguments,
+ omitUsesFromAbove);
}
+
} // namespace mlir::query::matcher
#endif // MLIR_TOOLS_MLIRQUERY_MATCHERS_EXTRAMATCHERS_H
diff --git a/mlir/include/mlir/Query/Matcher/MatchFinder.h b/mlir/include/mlir/Query/Matcher/MatchFinder.h
index 6b554394b3654..f8abf20ef60bb 100644
--- a/mlir/include/mlir/Query/Matcher/MatchFinder.h
+++ b/mlir/include/mlir/Query/Matcher/MatchFinder.h
@@ -21,32 +21,35 @@
namespace mlir::query::matcher {
-/// A class that provides utilities to find operations in a DAG
+/// A class that provides utilities to find operations in the IR.
class MatchFinder {
public:
- /// A subclass which preserves the matching information
+ /// A nested struct which preserves the matching information. Each instance
+ /// holds the `rootOp` along with its matching environment (`matchedOps`).
struct MatchResult {
MatchResult() = default;
MatchResult(Operation *rootOp, std::vector<Operation *> matchedOps);
- /// Contains the root operation of the matching environment
Operation *rootOp = nullptr;
- /// Contains the matching enviroment. This allows the user to easily
- /// extract the matched operations
+ /// The operations matched for `rootOp` (the matching environment).
std::vector<Operation *> matchedOps;
};
- /// Traverses the DAG and collects the "rootOp" + "matching enviroment" for
- /// a given Matcher
+
+ /// Traverses the IR and returns one `MatchResult` for each match of
+ /// `matcher`.
std::vector<MatchResult> collectMatches(Operation *root,
DynMatcher matcher) const;
- /// Prints the matched operation
+
+ /// Prints the matched operation.
void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op) const;
- /// Labels the matched operation with the given binding (e.g., "root") and
- /// prints it
+
+ /// Labels the matched operation with the given binding (e.g., `"root"`) and
+ /// prints it.
void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op,
const std::string &binding) const;
- /// Flattens a vector of MatchResults into a vector of operations
+
+ /// Flattens a vector of `MatchResult` into a single vector of operations.
std::vector<Operation *>
flattenMatchedOps(std::vector<MatchResult> &matches) const;
};
diff --git a/mlir/lib/Query/Matcher/CMakeLists.txt b/mlir/lib/Query/Matcher/CMakeLists.txt
index d84b1b50e8b04..629479bf7adc1 100644
--- a/mlir/lib/Query/Matcher/CMakeLists.txt
+++ b/mlir/lib/Query/Matcher/CMakeLists.txt
@@ -1,6 +1,5 @@
add_mlir_library(MLIRQueryMatcher
MatchFinder.cpp
- ExtraMatchers.cpp
Parser.cpp
RegistryManager.cpp
VariantValue.cpp
diff --git a/mlir/lib/Query/Matcher/ExtraMatchers.cpp b/mlir/lib/Query/Matcher/ExtraMatchers.cpp
deleted file mode 100644
index 1c69995a5d690..0000000000000
--- a/mlir/lib/Query/Matcher/ExtraMatchers.cpp
+++ /dev/null
@@ -1,66 +0,0 @@
-//===- ExtraMatchers.cpp - Various common matchers ---------------*- C++-*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements specific matchers
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Query/Matcher/ExtraMatchers.h"
-
-namespace mlir::query::matcher {
-
-bool BackwardSliceMatcher::matches(Operation *rootOp,
- llvm::SetVector<Operation *> &backwardSlice,
- BackwardSliceOptions &options,
- int64_t maxDepth) {
- options.inclusive = inclusive;
- options.omitUsesFromAbove = omitUsesFromAbove;
- options.omitBlockArguments = omitBlockArguments;
- backwardSlice.clear();
- llvm::DenseMap<Operation *, int64_t> opDepths;
- // The starting point is the root op, therfore we set its depth to 0
- opDepths[rootOp] = 0;
- options.filter = [&](Operation *subOp) {
- // If the subOp’s depth exceeds maxDepth, we can stop further computing the
- // slice for the current branch
- if (opDepths[subOp] > maxDepth)
- return false;
- // Examining subOp's operands to compute the depths of their defining
- // operations
- for (auto operand : subOp->getOperands()) {
- if (auto definingOp = operand.getDefiningOp()) {
- // Set the defining operation's depth to one level greater than
- // subOp's depth
- int64_t newDepth = opDepths[subOp] + 1;
- if (!opDepths.contains(definingOp)) {
- opDepths[definingOp] = newDepth;
- } else {
- opDepths[definingOp] = std::min(opDepths[definingOp], newDepth);
- }
- return !(opDepths[subOp] > maxDepth);
- } else {
- auto blockArgument = cast<BlockArgument>(operand);
- Operation *parentOp = blockArgument.getOwner()->getParentOp();
- if (!parentOp)
- continue;
- int64_t newDepth = opDepths[subOp] + 1;
- if (!opDepths.contains(parentOp)) {
- opDepths[parentOp] = newDepth;
- } else {
- opDepths[parentOp] = std::min(opDepths[parentOp], newDepth);
- }
- return !(opDepths[parentOp] > maxDepth);
- }
- }
- return true;
- };
- getBackwardSlice(rootOp, &backwardSlice, options);
- return true;
-}
-
-} // namespace mlir::query::matcher
diff --git a/mlir/lib/Query/Matcher/Parser.h b/mlir/lib/Query/Matcher/Parser.h
index 58968023022d5..2199a2335ba9c 100644
--- a/mlir/lib/Query/Matcher/Parser.h
+++ b/mlir/lib/Query/Matcher/Parser.h
@@ -16,8 +16,11 @@
// provided to the parser.
//
// The grammar for the supported expressions is as follows:
-// <Expression> := <StringLiteral> | <MatcherExpression>
+// <Expression> := <Literal> | <MatcherExpression>
+// <Literal> := <StringLiteral> | <NumericLiteral> | <BooleanLiteral>
// <StringLiteral> := "quoted string"
+// <BooleanLiteral> := "true" | "false"
+// <NumericLiteral> := [0-9]+
// <MatcherExpression> := <MatcherName>(<ArgumentList>)
// <MatcherName> := [a-zA-Z]+
// <ArgumentList> := <Expression> | <Expression>,<ArgumentList>
diff --git a/mlir/lib/Query/QueryParser.cpp b/mlir/lib/Query/QueryParser.cpp
index 3990b697ead7f..31aead7d403d0 100644
--- a/mlir/lib/Query/QueryParser.cpp
+++ b/mlir/lib/Query/QueryParser.cpp
@@ -166,6 +166,7 @@ QueryRef QueryParser::doParse() {
case ParsedQueryKind::Quit:
return endQuery(new QuitQuery);
+
case ParsedQueryKind::Match: {
if (completionPos) {
return completeMatcherExpression();
diff --git a/mlir/test/mlir-query/complex-test.mlir b/mlir/test/mlir-query/complex-test.mlir
index d18a9cc1d1550..3e0bf8b8b9fa6 100644
--- a/mlir/test/mlir-query/complex-test.mlir
+++ b/mlir/test/mlir-query/complex-test.mlir
@@ -26,7 +26,7 @@ func.func @slice_use_from_above(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>)
// CHECK: Match #2:
-// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %0 {{\[\[.*\]\]}} : tensor<5x5xf32> into tensor<25xf32>
+// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[LINALG]] {{\[\[.*\]\]}} : tensor<5x5xf32> into tensor<25xf32>
// CHECK: %[[C2:.*]] = arith.constant {{.*}} : index
// CHECK: %[[EXTRACTED:.*]] = tensor.extract %[[COLLAPSED]][%[[C2]]] : tensor<25xf32>
// CHECK: %[[ADDF2:.*]] = arith.addf %[[EXTRACTED]], %[[EXTRACTED]] : f32
diff --git a/mlir/tools/mlir-query/mlir-query.cpp b/mlir/tools/mlir-query/mlir-query.cpp
index 1d392e5f0dcfd..0cc9a5db25a91 100644
--- a/mlir/tools/mlir-query/mlir-query.cpp
+++ b/mlir/tools/mlir-query/mlir-query.cpp
@@ -40,8 +40,9 @@ int main(int argc, char **argv) {
query::matcher::Registry matcherRegistry;
// Matchers registered in alphabetical order for consistency:
- matcherRegistry.registerMatcher("getDefinitions",
- query::matcher::m_GetDefinitions);
+ matcherRegistry.registerMatcher(
+ "getDefinitions",
+ query::matcher::m_GetDefinitions<query::matcher::DynMatcher>);
matcherRegistry.registerMatcher("hasOpAttrName",
static_cast<HasOpAttrName *>(m_Attr));
matcherRegistry.registerMatcher("hasOpName", static_cast<HasOpName *>(m_Op));
More information about the Mlir-commits
mailing list