[Mlir-commits] [mlir] [mlir] MLIR-QUERY slice-matchers implementation (PR #115670)
Denzel Budii
llvmlistbot at llvm.org
Sat Mar 29 08:20:09 PDT 2025
https://github.com/dbudii updated https://github.com/llvm/llvm-project/pull/115670
>From 4f34c45b5d13918f7f757b9084ce724898649406 Mon Sep 17 00:00:00 2001
From: Denzel-Brian Budii <denzel-brian.budii at intel.com>
Date: Sat, 25 Jan 2025 13:38:31 +0000
Subject: [PATCH 1/5] Fixed pattern matching in mlir-query test files & removed
asserts from slice-matchers
---
mlir/include/mlir/IR/Matchers.h | 4 +-
.../mlir/Query/Matcher/ExtraMatchers.h | 188 ++++++++++++++++++
mlir/include/mlir/Query/Matcher/Marshallers.h | 15 ++
mlir/include/mlir/Query/Matcher/MatchFinder.h | 52 +++--
.../mlir/Query/Matcher/MatchersInternal.h | 60 +++++-
.../include/mlir/Query/Matcher/VariantValue.h | 12 +-
mlir/include/mlir/Query/Query.h | 34 +++-
mlir/include/mlir/Query/QuerySession.h | 11 +-
mlir/lib/Query/Matcher/Parser.cpp | 36 ++++
mlir/lib/Query/Matcher/RegistryManager.cpp | 2 +
mlir/lib/Query/Matcher/VariantValue.cpp | 24 +++
mlir/lib/Query/Query.cpp | 37 ++--
mlir/lib/Query/QueryParser.cpp | 52 ++++-
mlir/lib/Query/QueryParser.h | 2 +-
mlir/test/mlir-query/complex-test.mlir | 39 ++++
mlir/test/mlir-query/function-extraction.mlir | 2 +-
mlir/tools/mlir-query/mlir-query.cpp | 10 +
17 files changed, 529 insertions(+), 51 deletions(-)
create mode 100644 mlir/include/mlir/Query/Matcher/ExtraMatchers.h
create mode 100644 mlir/test/mlir-query/complex-test.mlir
diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index 1dce055db1b4a..2204a68be26b1 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -59,7 +59,7 @@ struct NameOpMatcher {
NameOpMatcher(StringRef name) : name(name) {}
bool match(Operation *op) { return op->getName().getStringRef() == name; }
- StringRef name;
+ std::string name;
};
/// The matcher that matches operations that have the specified attribute name.
@@ -67,7 +67,7 @@ struct AttrOpMatcher {
AttrOpMatcher(StringRef attrName) : attrName(attrName) {}
bool match(Operation *op) { return op->hasAttr(attrName); }
- StringRef attrName;
+ std::string attrName;
};
/// The matcher that matches operations that have the `ConstantLike` trait, and
diff --git a/mlir/include/mlir/Query/Matcher/ExtraMatchers.h b/mlir/include/mlir/Query/Matcher/ExtraMatchers.h
new file mode 100644
index 0000000000000..908fccfc704c3
--- /dev/null
+++ b/mlir/include/mlir/Query/Matcher/ExtraMatchers.h
@@ -0,0 +1,188 @@
+//===- ExtraMatchers.h - Various common matchers ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file provides the slice matchers (definedBy, getDefinitions, usedBy and
+// getUses) used by mlir-query.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_EXTRAMATCHERS_H
+#define MLIR_IR_EXTRAMATCHERS_H
+
+#include "MatchFinder.h"
+#include "MatchersInternal.h"
+#include "mlir/IR/Region.h"
+#include "mlir/Query/Query.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace mlir {
+
+namespace query {
+
+namespace extramatcher {
+
+namespace detail {
+
+/// Matches operations accepted by `innerMatcher` and collects their backward
+/// slice: the operations producing their operands, followed transitively for
+/// at most `hops` definition edges.
+class BackwardSliceMatcher {
+public:
+  BackwardSliceMatcher(matcher::DynMatcher &&innerMatcher, unsigned hops)
+      : innerMatcher(std::move(innerMatcher)), hops(hops) {}
+
+private:
+  /// Adds `op` and, within the remaining `tempHops` budget, the producers of
+  /// its operands to `backwardSlice`. Operations are appended after their
+  /// producers (post-order). Returns false without collecting anything when
+  /// `op` is isolated from above.
+  bool matches(Operation *op, SetVector<Operation *> &backwardSlice,
+               QueryOptions &options, unsigned tempHops) {
+    if (op->hasTrait<OpTrait::IsIsolatedFromAbove>())
+      return false;
+
+    auto processValue = [&](Value value) {
+      if (tempHops == 0)
+        return;
+      if (Operation *definingOp = value.getDefiningOp()) {
+        if (!backwardSlice.contains(definingOp))
+          matches(definingOp, backwardSlice, options, tempHops - 1);
+      } else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
+        if (options.omitBlockArguments)
+          return;
+        // A block argument is attributed to the operation owning its block.
+        // That operation owns the block's region, so getRegion(0) is valid.
+        Operation *parentOp = blockArg.getOwner()->getParentOp();
+        if (parentOp && !backwardSlice.contains(parentOp)) {
+          // Report any parent that does not have exactly one region holding
+          // exactly one block; either mismatch alone is unexpected. The walk
+          // continues regardless (best effort).
+          if (parentOp->getNumRegions() != 1 ||
+              parentOp->getRegion(0).getBlocks().size() != 1) {
+            llvm::errs()
+                << "Error: Expected parentOp to have exactly one region and "
+                << "exactly one block, but found " << parentOp->getNumRegions()
+                << " regions and "
+                << (parentOp->getRegion(0).getBlocks().size()) << " blocks.\n";
+          }
+          matches(parentOp, backwardSlice, options, tempHops - 1);
+        }
+      } else {
+        llvm::errs() << "No definingOp and not a block argument\n";
+      }
+    };
+
+    if (!options.omitUsesFromAbove) {
+      // Also follow operands of operations nested in op's regions whose
+      // defining region is not among the regions collected by the first walk
+      // (uses from above).
+      llvm::for_each(op->getRegions(), [&](Region &region) {
+        SmallPtrSet<Region *, 4> descendents;
+        region.walk(
+            [&](Region *childRegion) { descendents.insert(childRegion); });
+        region.walk([&](Operation *nestedOp) {
+          for (OpOperand &operand : nestedOp->getOpOperands()) {
+            if (!descendents.contains(operand.get().getParentRegion()))
+              processValue(operand.get());
+          }
+        });
+      });
+    }
+
+    llvm::for_each(op->getOperands(), processValue);
+    backwardSlice.insert(op);
+    return true;
+  }
+
+public:
+  /// Returns true if `op` is accepted by the inner matcher and its slice was
+  /// collected into `backwardSlice`. `op` itself stays in the slice only when
+  /// `options.inclusive` is set.
+  bool match(Operation *op, SetVector<Operation *> &backwardSlice,
+             QueryOptions &options) {
+    if (innerMatcher.match(op) && matches(op, backwardSlice, options, hops)) {
+      if (!options.inclusive)
+        backwardSlice.remove(op);
+      return true;
+    }
+    return false;
+  }
+
+private:
+  matcher::DynMatcher innerMatcher;
+  unsigned hops;
+};
+
+/// Matches operations accepted by `innerMatcher` and collects their forward
+/// slice: operations nested in their regions and users of their results,
+/// followed for at most `hops` levels.
+class ForwardSliceMatcher {
+public:
+  ForwardSliceMatcher(matcher::DynMatcher &&innerMatcher, unsigned hops)
+      : innerMatcher(std::move(innerMatcher)), hops(hops) {}
+
+  /// Returns true if `op` is accepted by the inner matcher. On success the
+  /// collected slice is left in `forwardSlice`, reversed from collection
+  /// order so that `op` (when kept) comes first.
+  bool match(Operation *op, SetVector<Operation *> &forwardSlice,
+             QueryOptions &options) {
+    if (!innerMatcher.match(op) || !matches(op, forwardSlice, options, hops))
+      return false;
+    if (!options.inclusive)
+      forwardSlice.remove(op);
+    SmallVector<Operation *, 0> collected(forwardSlice.takeVector());
+    forwardSlice.insert(collected.rbegin(), collected.rend());
+    return true;
+  }
+
+private:
+  /// Post-order collection: everything reachable within `tempHops` is added
+  /// before `op` itself is inserted. Always returns true.
+  bool matches(Operation *op, SetVector<Operation *> &forwardSlice,
+               QueryOptions &options, unsigned tempHops) {
+    if (tempHops != 0) {
+      unsigned remaining = tempHops - 1;
+      // Descend into the operations nested in op's regions...
+      for (Region &region : op->getRegions())
+        for (Block &block : region)
+          for (Operation &nested : block)
+            if (!forwardSlice.count(&nested))
+              matches(&nested, forwardSlice, options, remaining);
+      // ...then follow every user of every result.
+      for (Value result : op->getResults())
+        for (Operation *user : result.getUsers())
+          if (!forwardSlice.count(user))
+            matches(user, forwardSlice, options, remaining);
+    }
+    forwardSlice.insert(op);
+    return true;
+  }
+
+  matcher::DynMatcher innerMatcher;
+  unsigned hops;
+};
+
+} // namespace detail
+
+/// Backward slice of ops matched by `innerMatcher`, `hops` definitions deep.
+inline detail::BackwardSliceMatcher
+getDefinitions(mlir::query::matcher::DynMatcher innerMatcher, unsigned hops) {
+  return detail::BackwardSliceMatcher(std::move(innerMatcher), hops);
+}
+
+/// Ops matched by `innerMatcher` plus their direct definitions (one hop).
+inline detail::BackwardSliceMatcher
+definedBy(mlir::query::matcher::DynMatcher innerMatcher) {
+  return getDefinitions(std::move(innerMatcher), 1);
+}
+
+/// Forward slice of ops matched by `innerMatcher`, `hops` levels deep.
+inline detail::ForwardSliceMatcher
+getUses(mlir::query::matcher::DynMatcher innerMatcher, unsigned hops) {
+  return detail::ForwardSliceMatcher(std::move(innerMatcher), hops);
+}
+
+/// Ops matched by `innerMatcher` plus their direct uses (one hop).
+inline detail::ForwardSliceMatcher
+usedBy(mlir::query::matcher::DynMatcher innerMatcher) {
+  return getUses(std::move(innerMatcher), 1);
+}
+
+} // namespace extramatcher
+
+} // namespace query
+
+} // namespace mlir
+
+#endif // MLIR_IR_EXTRAMATCHERS_H
diff --git a/mlir/include/mlir/Query/Matcher/Marshallers.h b/mlir/include/mlir/Query/Matcher/Marshallers.h
index 6ed35ac0ddccc..c775dbc5c86da 100644
--- a/mlir/include/mlir/Query/Matcher/Marshallers.h
+++ b/mlir/include/mlir/Query/Matcher/Marshallers.h
@@ -50,6 +50,21 @@ struct ArgTypeTraits<llvm::StringRef> {
}
};
+/// Argument traits that let registered matchers take `unsigned` parameters
+/// (such as the hop count of the slice matchers).
+template <>
+struct ArgTypeTraits<unsigned> {
+  static ArgKind getKind() { return ArgKind::Unsigned; }
+
+  static bool hasCorrectType(const VariantValue &value) {
+    return value.isUnsigned();
+  }
+
+  static unsigned get(const VariantValue &value) {
+    return value.getUnsigned();
+  }
+
+  // No "did you mean" suggestion is offered for numeric arguments.
+  static std::optional<std::string> getBestGuess(const VariantValue &) {
+    return std::nullopt;
+  }
+};
+
template <>
struct ArgTypeTraits<DynMatcher> {
diff --git a/mlir/include/mlir/Query/Matcher/MatchFinder.h b/mlir/include/mlir/Query/Matcher/MatchFinder.h
index b008a21f53ae2..1b9d3bc307ff5 100644
--- a/mlir/include/mlir/Query/Matcher/MatchFinder.h
+++ b/mlir/include/mlir/Query/Matcher/MatchFinder.h
@@ -7,7 +7,7 @@
//===----------------------------------------------------------------------===//
//
// This file contains the MatchFinder class, which is used to find operations
-// that match a given matcher.
+// that match a given matcher and print them.
//
//===----------------------------------------------------------------------===//
@@ -15,24 +15,52 @@
#define MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H
#include "MatchersInternal.h"
+#include "mlir/Query/QuerySession.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/raw_ostream.h"
namespace mlir::query::matcher {
-// MatchFinder is used to find all operations that match a given matcher.
+// MatchFinder finds all operations that match a given matcher and prints a
+// source-location note for each of them.
class MatchFinder {
-public:
- // Returns all operations that match the given matcher.
- static std::vector<Operation *> getMatches(Operation *root,
- DynMatcher matcher) {
- std::vector<Operation *> matches;
+private:
+  /// Prints a note at the source location of `op`, labelled with `binding`.
+  /// Operations without a FileLineColLoc have no source position to point
+  /// at, so they are reported by name instead of dereferencing a null loc.
+  static void printMatch(llvm::raw_ostream &os, QuerySession &qs,
+                         mlir::Operation *op, const std::string &binding) {
+    auto fileLoc = op->getLoc()->findInstanceOf<FileLineColLoc>();
+    if (!fileLoc) {
+      os << "\"" << binding << "\" binds here: " << op->getName() << "\n";
+      return;
+    }
+    auto smloc = qs.getSourceManager().FindLocForLineAndColumn(
+        qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn());
+    qs.getSourceManager().PrintMessage(os, smloc, llvm::SourceMgr::DK_Note,
+                                       "\"" + binding + "\" binds here");
+  }
- // Simple match finding with walk.
+public:
+  /// Walks `root`, printing and returning every match. Matchers with a plain
+  /// match(Operation *) report the operation itself; slice matchers report
+  /// every operation they collected.
+  static std::vector<Operation *>
+  getMatches(Operation *root, QueryOptions &options, DynMatcher matcher,
+             llvm::raw_ostream &os, QuerySession &qs) {
+    unsigned matchCount = 0;
+    std::vector<Operation *> matchedOps;
+    SetVector<Operation *> tempStorage;
+    os << "\n";
root->walk([&](Operation *subOp) {
- if (matcher.match(subOp))
- matches.push_back(subOp);
+      if (matcher.match(subOp)) {
+        matchedOps.push_back(subOp);
+        os << "Match #" << ++matchCount << ":\n\n";
+        printMatch(os, qs, subOp, "root");
+        return;
+      }
+      if (!matcher.match(subOp, tempStorage, options))
+        return;
+      os << "Match #" << ++matchCount << ":\n\n";
+      // takeVector() also leaves tempStorage empty for the next operation.
+      for (Operation *op : tempStorage.takeVector()) {
+        printMatch(os, qs, op, "root");
+        matchedOps.push_back(op);
+      }
});
-
- return matches;
+    os << matchCount << (matchCount == 1 ? " match.\n\n" : " matches.\n\n");
+    return matchedOps;
}
};
diff --git a/mlir/include/mlir/Query/Matcher/MatchersInternal.h b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
index 117f7d4edef9e..b532b47be7d05 100644
--- a/mlir/include/mlir/Query/Matcher/MatchersInternal.h
+++ b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
@@ -1,4 +1,3 @@
//===- MatchersInternal.h - Structural query framework ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -10,28 +9,53 @@
//
// Matchers are methods that return a Matcher which provides a method
// match(Operation *op)
+// match(Operation *op, SetVector<Operation *> &matchedOps,
+//       QueryOptions &options)
//
// The matcher functions are defined in include/mlir/IR/Matchers.h.
// This file contains the wrapper classes needed to construct matchers for
// mlir-query.
//
//===----------------------------------------------------------------------===//
-
#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
#define MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
#include "mlir/IR/Matchers.h"
#include "llvm/ADT/IntrusiveRefCntPtr.h"
+namespace mlir {
+namespace query {
+struct QueryOptions;
+}
+} // namespace mlir
+
namespace mlir::query::matcher {
+// Detection trait: true when T has a member callable as
+// `match(Operation *)`, the plain matcher form.
+template <typename T, typename = void>
+struct has_simple_match : std::false_type {};
+
+template <typename T>
+struct has_simple_match<T, std::void_t<decltype(std::declval<T>().match(
+ std::declval<Operation *>()))>>
+ : std::true_type {};
+
+// Detection trait: true when T has a member callable as
+// `match(Operation *, SetVector<Operation *> &, QueryOptions &)`, the
+// slice-collecting form implemented by the extra (slice) matchers.
+template <typename T, typename = void>
+struct has_bound_match : std::false_type {};
+
+template <typename T>
+struct has_bound_match<T, std::void_t<decltype(std::declval<T>().match(
+ std::declval<Operation *>(),
+ std::declval<SetVector<Operation *> &>(),
+ std::declval<QueryOptions &>()))>>
+ : std::true_type {};
// Generic interface for matchers on an MLIR operation.
class MatcherInterface
: public llvm::ThreadSafeRefCountedBase<MatcherInterface> {
public:
virtual ~MatcherInterface() = default;
-
+  // Plain form: returns whether `op` matches.
virtual bool match(Operation *op) = 0;
+  // Slice-collecting form: implementations may append the operations they
+  // collect to `matchedOps`; `options` carries the session's slice settings.
+ virtual bool match(Operation *op, SetVector<Operation *> &matchedOps,
+ QueryOptions &options) = 0;
};
// MatcherFnImpl takes a matcher function object and implements
@@ -40,14 +64,26 @@ template <typename MatcherFn>
class MatcherFnImpl : public MatcherInterface {
public:
MatcherFnImpl(MatcherFn &matcherFn) : matcherFn(matcherFn) {}
- bool match(Operation *op) override { return matcherFn.match(op); }
+
+ bool match(Operation *op) override {
+ if constexpr (has_simple_match<MatcherFn>::value)
+ return matcherFn.match(op);
+ return false;
+ }
+
+ bool match(Operation *op, SetVector<Operation *> &matchedOps,
+ QueryOptions &options) override {
+ if constexpr (has_bound_match<MatcherFn>::value)
+ return matcherFn.match(op, matchedOps, options);
+ return false;
+ }
private:
MatcherFn matcherFn;
};
-// Matcher wraps a MatcherInterface implementation and provides a match()
-// method that redirects calls to the underlying implementation.
+// Matcher wraps a MatcherInterface implementation and provides match()
+// methods that redirect calls to the underlying implementation.
class DynMatcher {
public:
// Takes ownership of the provided implementation pointer.
@@ -62,12 +98,14 @@ class DynMatcher {
}
bool match(Operation *op) const { return implementation->match(op); }
+ bool match(Operation *op, SetVector<Operation *> &matchedOps,
+ QueryOptions &options) const {
+ return implementation->match(op, matchedOps, options);
+ }
- void setFunctionName(StringRef name) { functionName = name.str(); };
-
- bool hasFunctionName() const { return !functionName.empty(); };
-
- StringRef getFunctionName() const { return functionName; };
+ void setFunctionName(StringRef name) { functionName = name.str(); }
+ bool hasFunctionName() const { return !functionName.empty(); }
+ StringRef getFunctionName() const { return functionName; }
private:
llvm::IntrusiveRefCntPtr<MatcherInterface> implementation;
diff --git a/mlir/include/mlir/Query/Matcher/VariantValue.h b/mlir/include/mlir/Query/Matcher/VariantValue.h
index 449f8b3a01e02..6b57119df7a9b 100644
--- a/mlir/include/mlir/Query/Matcher/VariantValue.h
+++ b/mlir/include/mlir/Query/Matcher/VariantValue.h
@@ -21,7 +21,7 @@
namespace mlir::query::matcher {
// All types that VariantValue can contain.
-enum class ArgKind { Matcher, String };
+enum class ArgKind { Matcher, String, Unsigned };
// A variant matcher object to abstract simple and complex matchers into a
// single object type.
@@ -81,6 +81,7 @@ class VariantValue {
// Specific constructors for each supported type.
VariantValue(const llvm::StringRef string);
VariantValue(const VariantMatcher &matcher);
+ VariantValue(unsigned Unsigned);
// String value functions.
bool isString() const;
@@ -92,8 +93,15 @@ class VariantValue {
const VariantMatcher &getMatcher() const;
void setMatcher(const VariantMatcher &matcher);
+ // Unsigned value functions.
+ bool isUnsigned() const;
+ unsigned getUnsigned() const;
+ void setUnsigned(unsigned Unsigned);
+
// String representation of the type of the value.
std::string getTypeAsString() const;
+ explicit operator bool() const { return hasValue(); }
+ bool hasValue() const { return type != ValueType::Nothing; }
private:
void reset();
@@ -103,12 +111,14 @@ class VariantValue {
Nothing,
String,
Matcher,
+ Unsigned,
};
// All supported value types.
union AllValues {
llvm::StringRef *String;
VariantMatcher *Matcher;
+ unsigned Unsigned;
};
ValueType type;
diff --git a/mlir/include/mlir/Query/Query.h b/mlir/include/mlir/Query/Query.h
index 18f2172c9510a..bb5b98432d51c 100644
--- a/mlir/include/mlir/Query/Query.h
+++ b/mlir/include/mlir/Query/Query.h
@@ -17,7 +17,13 @@
namespace mlir::query {
-enum class QueryKind { Invalid, NoOp, Help, Match, Quit };
+/// Settings that control how the slice matchers collect operations.
+struct QueryOptions {
+  /// When set, block arguments are not followed back to their parent op.
+  bool omitBlockArguments{false};
+  /// When set, operands of nested operations that come from above are not
+  /// followed.
+  bool omitUsesFromAbove{true};
+  /// When set, the matched operation itself is kept in the reported slice.
+  bool inclusive{true};
+};
+
+enum class QueryKind { Invalid, NoOp, Help, Match, Quit, SetBool };
class QuerySession;
@@ -103,6 +109,32 @@ struct MatchQuery : Query {
}
};
+/// Maps the type of a settable session value to the QueryKind of its
+/// SetQuery.
+template <typename T>
+struct SetQueryKind {};
+
+template <>
+struct SetQueryKind<bool> {
+  static constexpr QueryKind value = QueryKind::SetBool;
+};
+
+/// Query that, when run, assigns a fixed value to a QuerySession member.
+template <typename T>
+struct SetQuery : Query {
+  SetQuery(T QuerySession::*member, T newValue)
+      : Query(SetQueryKind<T>::value), var(member), value(newValue) {}
+
+  static bool classof(const Query *query) {
+    return query->kind == SetQueryKind<T>::value;
+  }
+
+  llvm::LogicalResult run(llvm::raw_ostream &os,
+                          QuerySession &qs) const override {
+    qs.*var = value;
+    return mlir::success();
+  }
+
+  /// Session member updated by this query.
+  T QuerySession::*var;
+  /// Value stored into `var`.
+  T value;
+};
+
} // namespace mlir::query
#endif
diff --git a/mlir/include/mlir/Query/QuerySession.h b/mlir/include/mlir/Query/QuerySession.h
index fe552d750fc77..495358e8f36f9 100644
--- a/mlir/include/mlir/Query/QuerySession.h
+++ b/mlir/include/mlir/Query/QuerySession.h
@@ -9,14 +9,18 @@
#ifndef MLIR_TOOLS_MLIRQUERY_QUERYSESSION_H
#define MLIR_TOOLS_MLIRQUERY_QUERYSESSION_H
+#include "Matcher/VariantValue.h"
#include "mlir/IR/Operation.h"
#include "mlir/Query/Matcher/Registry.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/SourceMgr.h"
+namespace mlir::query::matcher {
+class Registry;
+}
+
namespace mlir::query {
-class Registry;
// Represents the state for a particular mlir-query session.
class QuerySession {
public:
@@ -33,6 +37,11 @@ class QuerySession {
llvm::StringMap<matcher::VariantValue> namedValues;
bool terminate = false;
+public:
+ bool omitBlockArguments = false;
+ bool omitUsesFromAbove = true;
+ bool inclusive = true;
+
private:
Operation *rootOp;
llvm::SourceMgr &sourceMgr;
diff --git a/mlir/lib/Query/Matcher/Parser.cpp b/mlir/lib/Query/Matcher/Parser.cpp
index 3609e24f9939f..4dcb86a9383f3 100644
--- a/mlir/lib/Query/Matcher/Parser.cpp
+++ b/mlir/lib/Query/Matcher/Parser.cpp
@@ -135,6 +135,18 @@ class Parser::CodeTokenizer {
case '\'':
consumeStringLiteral(&result);
break;
+ case '0':
+ case '1':
+ case '2':
+ case '3':
+ case '4':
+ case '5':
+ case '6':
+ case '7':
+ case '8':
+ case '9':
+ consumeNumberLiteral(&result);
+ break;
default:
parseIdentifierOrInvalid(&result);
break;
@@ -144,6 +156,30 @@ class Parser::CodeTokenizer {
return result;
}
+  // Consume an unsigned integer literal. An optional "0x"/"0b" radix prefix
+  // is accepted; StringRef::getAsInteger with radix 0 interprets it. A
+  // literal that does not parse is reported as an error token rather than
+  // being silently dropped.
+  void consumeNumberLiteral(TokenInfo *result) {
+    unsigned length = 1;
+    if (code.size() > 1) {
+      // Consume the 'x' or 'b' radix modifier, if present.
+      switch (tolower(static_cast<unsigned char>(code[1]))) {
+      case 'x':
+      case 'b':
+        length = 2;
+        break;
+      default:
+        break;
+      }
+    }
+    // Accept hex digits so that literals such as 0x1F are consumed whole;
+    // malformed digits for the radix are rejected by getAsInteger below.
+    while (length < code.size() &&
+           isxdigit(static_cast<unsigned char>(code[length])))
+      ++length;
+
+    result->text = code.take_front(length);
+    code = code.drop_front(length);
+
+    unsigned value;
+    if (!result->text.getAsInteger(0, value)) {
+      result->kind = TokenKind::Literal;
+      result->value = value;
+      return;
+    }
+
+    SourceRange range;
+    range.start = result->range.start;
+    range.end = currentLocation();
+    error->addError(range, ErrorType::ParserInvalidToken) << result->text;
+    result->kind = TokenKind::Error;
+  }
+
// Consume a string literal, handle escape sequences and missing closing
// quote.
void consumeStringLiteral(TokenInfo *result) {
diff --git a/mlir/lib/Query/Matcher/RegistryManager.cpp b/mlir/lib/Query/Matcher/RegistryManager.cpp
index 645db7109c2de..8d6c0135aa117 100644
--- a/mlir/lib/Query/Matcher/RegistryManager.cpp
+++ b/mlir/lib/Query/Matcher/RegistryManager.cpp
@@ -31,6 +31,8 @@ static std::string asArgString(ArgKind kind) {
return "Matcher";
case ArgKind::String:
return "String";
+ case ArgKind::Unsigned:
+ return "unsigned";
}
llvm_unreachable("Unhandled ArgKind");
}
diff --git a/mlir/lib/Query/Matcher/VariantValue.cpp b/mlir/lib/Query/Matcher/VariantValue.cpp
index 65bd4bd77bcf8..d5218d8dad8c9 100644
--- a/mlir/lib/Query/Matcher/VariantValue.cpp
+++ b/mlir/lib/Query/Matcher/VariantValue.cpp
@@ -56,6 +56,10 @@ VariantValue::VariantValue(const VariantMatcher &matcher)
value.Matcher = new VariantMatcher(matcher);
}
+// Constructs a VariantValue holding an unsigned integer.
+VariantValue::VariantValue(unsigned unsignedValue) : type(ValueType::Nothing) {
+  setUnsigned(unsignedValue);
+}
+
VariantValue::~VariantValue() { reset(); }
VariantValue &VariantValue::operator=(const VariantValue &other) {
@@ -69,6 +73,9 @@ VariantValue &VariantValue::operator=(const VariantValue &other) {
case ValueType::Matcher:
setMatcher(other.getMatcher());
break;
+ case ValueType::Unsigned:
+ setUnsigned(other.getUnsigned());
+ break;
case ValueType::Nothing:
type = ValueType::Nothing;
break;
@@ -85,12 +92,27 @@ void VariantValue::reset() {
delete value.Matcher;
break;
// Cases that do nothing.
+ case ValueType::Unsigned:
case ValueType::Nothing:
break;
}
type = ValueType::Nothing;
}
+// Unsigned value functions.
+bool VariantValue::isUnsigned() const {
+  return type == ValueType::Unsigned;
+}
+
+unsigned VariantValue::getUnsigned() const {
+  assert(isUnsigned());
+  return value.Unsigned;
+}
+
+void VariantValue::setUnsigned(unsigned newValue) {
+  // Release any previously held payload before switching the active member.
+  reset();
+  value.Unsigned = newValue;
+  type = ValueType::Unsigned;
+}
+
bool VariantValue::isString() const { return type == ValueType::String; }
const llvm::StringRef &VariantValue::getString() const {
@@ -123,6 +145,8 @@ std::string VariantValue::getTypeAsString() const {
return "String";
case ValueType::Matcher:
return "Matcher";
+ case ValueType::Unsigned:
+ return "Unsigned";
case ValueType::Nothing:
return "Nothing";
}
diff --git a/mlir/lib/Query/Query.cpp b/mlir/lib/Query/Query.cpp
index 7d9f360670051..dd699857568d7 100644
--- a/mlir/lib/Query/Query.cpp
+++ b/mlir/lib/Query/Query.cpp
@@ -12,6 +12,7 @@
#include "mlir/IR/IRMapping.h"
#include "mlir/Query/Matcher/MatchFinder.h"
#include "mlir/Query/QuerySession.h"
+#include "llvm/ADT/SetVector.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
@@ -26,15 +27,6 @@ complete(llvm::StringRef line, size_t pos, const QuerySession &qs) {
return QueryParser::complete(line, pos, qs);
}
-static void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op,
- const std::string &binding) {
- auto fileLoc = op->getLoc()->findInstanceOf<FileLineColLoc>();
- auto smloc = qs.getSourceManager().FindLocForLineAndColumn(
- qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn());
- qs.getSourceManager().PrintMessage(os, smloc, llvm::SourceMgr::DK_Note,
- "\"" + binding + "\" binds here");
-}
-
// TODO: Extract into a helper function that can be reused outside query
// context.
static Operation *extractFunction(std::vector<Operation *> &ops,
@@ -99,6 +91,12 @@ static Operation *extractFunction(std::vector<Operation *> &ops,
return funcOp;
}
+// Copies the session's slice-matcher settings into `options`.
+static void parseQueryOptions(QuerySession &qs, QueryOptions &options) {
+  options.inclusive = qs.inclusive;
+  options.omitUsesFromAbove = qs.omitUsesFromAbove;
+  options.omitBlockArguments = qs.omitBlockArguments;
+}
+
Query::~Query() = default;
LogicalResult InvalidQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
@@ -114,6 +112,11 @@ LogicalResult HelpQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
os << "Available commands:\n\n"
" match MATCHER, m MATCHER "
"Match the mlir against the given matcher.\n"
+ "Set query options, useful for complex matchers \n"
+ " set omitBlockArguments (true|false) \n"
+ " set omitUsesFromAbove (true|false) \n"
+ " set inclusive (true|false) \n"
+ "Give a matcher expression a name, to be used later\n"
" quit "
"Terminates the query session.\n\n";
return mlir::success();
@@ -126,9 +129,11 @@ LogicalResult QuitQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
LogicalResult MatchQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
Operation *rootOp = qs.getRootOp();
- int matchCount = 0;
- std::vector<Operation *> matches =
- matcher::MatchFinder().getMatches(rootOp, matcher);
+
+ QueryOptions options;
+ parseQueryOptions(qs, options);
+ auto matches = matcher::MatchFinder().getMatches(rootOp, options,
+ std::move(matcher), os, qs);
// An extract call is recognized by considering if the matcher has a name.
// TODO: Consider making the extract more explicit.
@@ -141,14 +146,6 @@ LogicalResult MatchQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
return mlir::success();
}
- os << "\n";
- for (Operation *op : matches) {
- os << "Match #" << ++matchCount << ":\n\n";
- // Placeholder "root" binding for the initial draft.
- printMatch(os, qs, op, "root");
- }
- os << matchCount << (matchCount == 1 ? " match.\n\n" : " matches.\n\n");
-
return mlir::success();
}
diff --git a/mlir/lib/Query/QueryParser.cpp b/mlir/lib/Query/QueryParser.cpp
index 31aead7d403d0..7aaf4847f2e47 100644
--- a/mlir/lib/Query/QueryParser.cpp
+++ b/mlir/lib/Query/QueryParser.cpp
@@ -107,6 +107,18 @@ QueryRef QueryParser::endQuery(QueryRef queryRef) {
return queryRef;
}
+// Parses the `true`/`false` argument of a `set` command targeting the boolean
+// session member `var`.
+QueryRef QueryParser::parseSetBool(bool QuerySession::*var) {
+  constexpr unsigned invalid = ~0u;
+  StringRef valStr;
+  unsigned parsed = LexOrCompleteWord<unsigned>(this, valStr)
+                        .Case("false", 0)
+                        .Case("true", 1)
+                        .Default(invalid);
+  if (parsed != invalid)
+    return new SetQuery<bool>(var, parsed);
+  return new InvalidQuery("expected 'true' or 'false', got '" + valStr + "'");
+}
+
namespace {
enum class ParsedQueryKind {
@@ -116,6 +128,14 @@ enum class ParsedQueryKind {
Help,
Match,
Quit,
+ Set,
+};
+
+// Session variables that can be changed with the `set` command. Scoped like
+// ParsedQueryKind so its enumerators (e.g. `Invalid`) do not leak into the
+// enclosing namespace.
+enum class ParsedQueryVariable {
+  Invalid,
+  OmitBlockArguments,
+  OmitUsesFromAbove,
+  Inclusive,
 };
QueryRef
@@ -147,6 +167,7 @@ QueryRef QueryParser::doParse() {
.Case("#", ParsedQueryKind::Comment, /*isCompletion=*/false)
.Case("help", ParsedQueryKind::Help)
.Case("m", ParsedQueryKind::Match, /*isCompletion=*/false)
+ .Case("set", ParsedQueryKind::Set)
.Case("match", ParsedQueryKind::Match)
.Case("q", ParsedQueryKind::Quit, /*IsCompletion=*/false)
.Case("quit", ParsedQueryKind::Quit)
@@ -187,7 +208,36 @@ QueryRef QueryParser::doParse() {
query->remainingContent = matcherSource;
return query;
}
-
+ case ParsedQueryKind::Set: {
+ llvm::StringRef varStr;
+ ParsedQueryVariable var =
+ LexOrCompleteWord<ParsedQueryVariable>(this, varStr)
+ .Case("omitBlockArguments", ParsedQueryVariable::OmitBlockArguments)
+ .Case("omitUsesFromAbove", ParsedQueryVariable::OmitUsesFromAbove)
+ .Case("inclusive", ParsedQueryVariable::Inclusive)
+ .Default(ParsedQueryVariable::Invalid);
+ if (varStr.empty()) {
+ return new InvalidQuery("expected variable name");
+ }
+ if (var == ParsedQueryVariable::Invalid) {
+ return new InvalidQuery("unknown variable: '" + varStr + "'");
+ }
+ QueryRef query;
+ switch (var) {
+ case ParsedQueryVariable::OmitBlockArguments:
+ query = parseSetBool(&QuerySession::omitBlockArguments);
+ break;
+ case ParsedQueryVariable::OmitUsesFromAbove:
+ query = parseSetBool(&QuerySession::omitUsesFromAbove);
+ break;
+ case ParsedQueryVariable::Inclusive:
+ query = parseSetBool(&QuerySession::inclusive);
+ break;
+ case ParsedQueryVariable::Invalid:
+ llvm_unreachable("Invalid query kind");
+ }
+ return endQuery(query);
+ }
case ParsedQueryKind::Invalid:
return new InvalidQuery("unknown command: " + commandStr);
}
diff --git a/mlir/lib/Query/QueryParser.h b/mlir/lib/Query/QueryParser.h
index e9c30eccecab9..69cc5d0043d57 100644
--- a/mlir/lib/Query/QueryParser.h
+++ b/mlir/lib/Query/QueryParser.h
@@ -39,8 +39,8 @@ class QueryParser {
struct LexOrCompleteWord;
QueryRef completeMatcherExpression();
-
QueryRef endQuery(QueryRef queryRef);
+ QueryRef parseSetBool(bool QuerySession::*var);
// Parse [begin, end) and returns a reference to the parsed query object,
// which may be an InvalidQuery if a parse error occurs.
diff --git a/mlir/test/mlir-query/complex-test.mlir b/mlir/test/mlir-query/complex-test.mlir
new file mode 100644
index 0000000000000..b3df534ee8871
--- /dev/null
+++ b/mlir/test/mlir-query/complex-test.mlir
@@ -0,0 +1,39 @@
+// RUN: mlir-query %s -c "match getDefinitions(hasOpName(\"arith.addf\"),2)" | FileCheck %s
+
+// Each arith.addf below is matched together with its definitions, followed
+// for at most 2 hops.
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @slice_use_from_above(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>) {
+  %0 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %2 = arith.addf %in, %in : f32
+    linalg.yield %2 : f32
+  } -> tensor<5x5xf32>
+  %collapsed = tensor.collapse_shape %0 [[0, 1]] : tensor<5x5xf32> into tensor<25xf32>
+  %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %c2 = arith.constant 2 : index
+    %extracted = tensor.extract %collapsed[%c2] : tensor<25xf32>
+    %2 = arith.addf %extracted, %extracted : f32
+    linalg.yield %2 : f32
+  } -> tensor<5x5xf32>
+  return
+}
+
+// CHECK: Match #1:
+
+// CHECK: %[[LINALG:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>)
+
+// CHECK: %[[ADDF1:.*]] = arith.addf %in, %in : f32
+
+// CHECK: Match #2:
+
+// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %0 {{\[\[.*\]\]}} : tensor<5x5xf32> into tensor<25xf32>
+// CHECK: %[[EXTRACTED:.*]] = tensor.extract %[[COLLAPSED]][%c2] : tensor<25xf32>
+// CHECK: %[[ADDF2:.*]] = arith.addf %[[EXTRACTED]], %[[EXTRACTED]] : f32
diff --git a/mlir/test/mlir-query/function-extraction.mlir b/mlir/test/mlir-query/function-extraction.mlir
index a783f65c6761b..d7a867eb1a452 100644
--- a/mlir/test/mlir-query/function-extraction.mlir
+++ b/mlir/test/mlir-query/function-extraction.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-query %s -c "m hasOpName(\"arith.mulf\").extract(\"testmul\")" | FileCheck %s
+// RUN: mlir-query %s -c "m hasOpName("arith.mulf").extract("testmul")" | FileCheck %s
// CHECK: func.func @testmul({{.*}}) -> (f32, f32, f32) {
// CHECK: %[[MUL0:.*]] = arith.mulf {{.*}} : f32
diff --git a/mlir/tools/mlir-query/mlir-query.cpp b/mlir/tools/mlir-query/mlir-query.cpp
index 0ed4f94d5802b..5e74da7ee7bdc 100644
--- a/mlir/tools/mlir-query/mlir-query.cpp
+++ b/mlir/tools/mlir-query/mlir-query.cpp
@@ -15,6 +15,8 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/InitAllDialects.h"
+#include "mlir/Query/Matcher/ExtraMatchers.h"
+#include "mlir/Query/Matcher/MatchFinder.h"
#include "mlir/Query/Matcher/Registry.h"
#include "mlir/Tools/mlir-query/MlirQueryMain.h"
@@ -39,6 +41,14 @@ int main(int argc, char **argv) {
query::matcher::Registry matcherRegistry;
// Matchers registered in alphabetical order for consistency:
+ matcherRegistry.registerMatcher("getDefinitions",
+ mlir::query::extramatcher::getDefinitions);
+ matcherRegistry.registerMatcher("definedBy",
+ mlir::query::extramatcher::definedBy);
+ matcherRegistry.registerMatcher("usedBy", mlir::query::extramatcher::usedBy);
+ matcherRegistry.registerMatcher("getUses",
+ mlir::query::extramatcher::getUses);
+
matcherRegistry.registerMatcher("hasOpAttrName",
static_cast<HasOpAttrName *>(m_Attr));
matcherRegistry.registerMatcher("hasOpName", static_cast<HasOpName *>(m_Op));
>From ec431eb20f17d3a9bb17c4373c509d4cf291777a Mon Sep 17 00:00:00 2001
From: Denzel-Brian Budii <denzel-brian.budii at intel.com>
Date: Wed, 23 Oct 2024 17:42:19 +0000
Subject: [PATCH 2/5] MLIR-QUERY: backwardSlice, forwardSlice & QueryOptions
added
---
mlir/lib/Query/Matcher/Parser.cpp | 18 +++++++++-----
mlir/lib/Query/Query.cpp | 9 +++++++
mlir/lib/Query/QueryParser.cpp | 24 +++++++++++++++++++
mlir/test/mlir-query/function-extraction.mlir | 4 ++--
mlir/tools/mlir-query/mlir-query.cpp | 1 -
5 files changed, 47 insertions(+), 9 deletions(-)
diff --git a/mlir/lib/Query/Matcher/Parser.cpp b/mlir/lib/Query/Matcher/Parser.cpp
index 4dcb86a9383f3..726f1188d7e4c 100644
--- a/mlir/lib/Query/Matcher/Parser.cpp
+++ b/mlir/lib/Query/Matcher/Parser.cpp
@@ -293,13 +293,19 @@ bool Parser::parseIdentifierPrefixImpl(VariantValue *value) {
if (tokenizer->nextTokenKind() != TokenKind::OpenParen) {
// Parse as a named value.
- auto namedValue =
- namedValues ? namedValues->lookup(nameToken.text) : VariantValue();
+ if (auto namedValue = namedValues ? namedValues->lookup(nameToken.text)
+ : VariantValue()) {
- if (!namedValue.isMatcher()) {
- error->addError(tokenizer->peekNextToken().range,
- ErrorType::ParserNotAMatcher);
- return false;
+ if (tokenizer->nextTokenKind() != TokenKind::Period) {
+ *value = namedValue;
+ return true;
+ }
+
+ if (!namedValue.isMatcher()) {
+ error->addError(tokenizer->peekNextToken().range,
+ ErrorType::ParserNotAMatcher);
+ return false;
+ }
}
if (tokenizer->nextTokenKind() == TokenKind::NewLine) {
diff --git a/mlir/lib/Query/Query.cpp b/mlir/lib/Query/Query.cpp
index dd699857568d7..500fee50a1609 100644
--- a/mlir/lib/Query/Query.cpp
+++ b/mlir/lib/Query/Query.cpp
@@ -127,6 +127,15 @@ LogicalResult QuitQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
return mlir::success();
}
+LogicalResult LetQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
+ if (value.hasValue()) {
+ qs.namedValues[name] = value;
+ } else {
+ qs.namedValues.erase(name);
+ }
+ return mlir::success();
+}
+
LogicalResult MatchQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
Operation *rootOp = qs.getRootOp();
diff --git a/mlir/lib/Query/QueryParser.cpp b/mlir/lib/Query/QueryParser.cpp
index 7aaf4847f2e47..4350fb9a434d4 100644
--- a/mlir/lib/Query/QueryParser.cpp
+++ b/mlir/lib/Query/QueryParser.cpp
@@ -166,6 +166,8 @@ QueryRef QueryParser::doParse() {
.Case("", ParsedQueryKind::NoOp)
.Case("#", ParsedQueryKind::Comment, /*isCompletion=*/false)
.Case("help", ParsedQueryKind::Help)
+ .Case("l", ParsedQueryKind::Let, /*isCompletion=*/false)
+ .Case("let", ParsedQueryKind::Let, /*isCompletion=*/false)
.Case("m", ParsedQueryKind::Match, /*isCompletion=*/false)
.Case("set", ParsedQueryKind::Set)
.Case("match", ParsedQueryKind::Match)
@@ -188,6 +190,27 @@ QueryRef QueryParser::doParse() {
case ParsedQueryKind::Quit:
return endQuery(new QuitQuery);
+ case ParsedQueryKind::Let: {
+ llvm::StringRef name = lexWord();
+
+ if (name.empty()) {
+ return new InvalidQuery("expected variable name");
+ }
+
+ if (completionPos) {
+ return completeMatcherExpression();
+ }
+
+ matcher::internal::Diagnostics diag;
+ matcher::VariantValue value;
+ if (!matcher::internal::Parser::parseExpression(
+ line, qs.getRegistryData(), &qs.namedValues, &value, &diag)) {
+ return makeInvalidQueryFromDiagnostics(diag);
+ }
+ QueryRef query = new LetQuery(name, value);
+ query->remainingContent = line;
+ return query;
+ }
case ParsedQueryKind::Match: {
if (completionPos) {
return completeMatcherExpression();
@@ -204,6 +227,7 @@ QueryRef QueryParser::doParse() {
}
auto actualSource = origMatcherSource.substr(0, origMatcherSource.size() -
matcherSource.size());
+
QueryRef query = new MatchQuery(actualSource, *matcher);
query->remainingContent = matcherSource;
return query;
diff --git a/mlir/test/mlir-query/function-extraction.mlir b/mlir/test/mlir-query/function-extraction.mlir
index d7a867eb1a452..5a20c09d02eb6 100644
--- a/mlir/test/mlir-query/function-extraction.mlir
+++ b/mlir/test/mlir-query/function-extraction.mlir
@@ -4,7 +4,7 @@
// CHECK: %[[MUL0:.*]] = arith.mulf {{.*}} : f32
// CHECK: %[[MUL1:.*]] = arith.mulf {{.*}}, %[[MUL0]] : f32
// CHECK: %[[MUL2:.*]] = arith.mulf {{.*}} : f32
-// CHECK-NEXT: return %[[MUL0]], %[[MUL1]], %[[MUL2]] : f32, f32, f32
+// CHECK-NEXT: return %[[MUL0]], %[[MUL1]], %[[MUL2]] : f32, f32, f32S
func.func @mixedOperations(%a: f32, %b: f32, %c: f32) -> f32 {
%sum0 = arith.addf %a, %b : f32
@@ -16,4 +16,4 @@ func.func @mixedOperations(%a: f32, %b: f32, %c: f32) -> f32 {
%sum2 = arith.addf %mul1, %b : f32
%mul2 = arith.mulf %sub2, %sum2 : f32
return %mul2 : f32
-}
+ }
diff --git a/mlir/tools/mlir-query/mlir-query.cpp b/mlir/tools/mlir-query/mlir-query.cpp
index 5e74da7ee7bdc..468f948bec24c 100644
--- a/mlir/tools/mlir-query/mlir-query.cpp
+++ b/mlir/tools/mlir-query/mlir-query.cpp
@@ -10,7 +10,6 @@
// of the registered queries.
//
//===----------------------------------------------------------------------===//
-
#include "mlir/IR/Dialect.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
>From 9cf7595da453256e645b49cc52fafe8b841821ac Mon Sep 17 00:00:00 2001
From: chios202 <chio.star at yahoo.com>
Date: Sat, 25 Jan 2025 15:29:12 +0000
Subject: [PATCH 3/5] removed LetQuery implementation
---
mlir/lib/Query/Query.cpp | 9 ---------
mlir/lib/Query/QueryParser.cpp | 24 ------------------------
2 files changed, 33 deletions(-)
diff --git a/mlir/lib/Query/Query.cpp b/mlir/lib/Query/Query.cpp
index 500fee50a1609..dd699857568d7 100644
--- a/mlir/lib/Query/Query.cpp
+++ b/mlir/lib/Query/Query.cpp
@@ -127,15 +127,6 @@ LogicalResult QuitQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
return mlir::success();
}
-LogicalResult LetQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
- if (value.hasValue()) {
- qs.namedValues[name] = value;
- } else {
- qs.namedValues.erase(name);
- }
- return mlir::success();
-}
-
LogicalResult MatchQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
Operation *rootOp = qs.getRootOp();
diff --git a/mlir/lib/Query/QueryParser.cpp b/mlir/lib/Query/QueryParser.cpp
index 4350fb9a434d4..53e8f91e657cb 100644
--- a/mlir/lib/Query/QueryParser.cpp
+++ b/mlir/lib/Query/QueryParser.cpp
@@ -166,8 +166,6 @@ QueryRef QueryParser::doParse() {
.Case("", ParsedQueryKind::NoOp)
.Case("#", ParsedQueryKind::Comment, /*isCompletion=*/false)
.Case("help", ParsedQueryKind::Help)
- .Case("l", ParsedQueryKind::Let, /*isCompletion=*/false)
- .Case("let", ParsedQueryKind::Let, /*isCompletion=*/false)
.Case("m", ParsedQueryKind::Match, /*isCompletion=*/false)
.Case("set", ParsedQueryKind::Set)
.Case("match", ParsedQueryKind::Match)
@@ -189,28 +187,6 @@ QueryRef QueryParser::doParse() {
case ParsedQueryKind::Quit:
return endQuery(new QuitQuery);
-
- case ParsedQueryKind::Let: {
- llvm::StringRef name = lexWord();
-
- if (name.empty()) {
- return new InvalidQuery("expected variable name");
- }
-
- if (completionPos) {
- return completeMatcherExpression();
- }
-
- matcher::internal::Diagnostics diag;
- matcher::VariantValue value;
- if (!matcher::internal::Parser::parseExpression(
- line, qs.getRegistryData(), &qs.namedValues, &value, &diag)) {
- return makeInvalidQueryFromDiagnostics(diag);
- }
- QueryRef query = new LetQuery(name, value);
- query->remainingContent = line;
- return query;
- }
case ParsedQueryKind::Match: {
if (completionPos) {
return completeMatcherExpression();
>From 410c5c9b1dc5fd07b07246abddfa4ca24e2ec6a8 Mon Sep 17 00:00:00 2001
From: chios202 <chio.star at yahoo.com>
Date: Sun, 23 Feb 2025 15:01:25 +0000
Subject: [PATCH 4/5] Enhance matcher and QueryOptions documentation
- Enhance docs for matchers and QueryOptions
- Fix whitespace and alignment issues
- Move matchers to Matchers.h
- Change data type from unsigned to signed for arithmetic operations
---
mlir/include/mlir/IR/Matchers.h | 262 ++++++++++++++++++
.../mlir/Query/Matcher/ExtraMatchers.h | 188 -------------
mlir/include/mlir/Query/Matcher/Marshallers.h | 8 +-
mlir/include/mlir/Query/Matcher/MatchFinder.h | 63 ++++-
.../mlir/Query/Matcher/MatchersInternal.h | 21 +-
.../include/mlir/Query/Matcher/VariantValue.h | 16 +-
mlir/include/mlir/Query/Query.h | 20 +-
mlir/include/mlir/Query/QuerySession.h | 4 -
mlir/lib/Query/Matcher/RegistryManager.cpp | 9 +-
mlir/lib/Query/Matcher/VariantValue.cpp | 28 +-
mlir/test/mlir-query/complex-test.mlir | 13 +-
mlir/test/mlir-query/function-extraction.mlir | 6 +-
mlir/tools/mlir-query/mlir-query.cpp | 13 +-
13 files changed, 383 insertions(+), 268 deletions(-)
delete mode 100644 mlir/include/mlir/Query/Matcher/ExtraMatchers.h
diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index 2204a68be26b1..ee9e2afb10bad 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -19,6 +19,9 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
+#include "mlir/Query/Matcher/MatchersInternal.h"
+#include "mlir/Query/Query.h"
+#include "llvm/ADT/SetVector.h"
namespace mlir {
@@ -363,8 +366,267 @@ struct RecursivePatternMatcher {
std::tuple<OperandMatchers...> operandMatchers;
};
+/// Fills `backwardSlice` with the computed backward slice (i.e. the
+/// transitive defs of op, up to `maxDepth` levels).
+///
+/// The implementation traverses the def chains in postorder traversal for
+/// efficiency reasons: if an operation is already in `backwardSlice`, no
+/// need to traverse its definitions again. Since use-def chains form a DAG,
+/// this terminates.
+///
+/// Upon return to the root call, `backwardSlice` is filled with a
+/// postorder list of defs. This happens to be a topological order, from the
+/// point of view of the use-def chains.
+///
+/// Example starting from node 8
+/// ============================
+///
+///    1       2      3      4
+///    |_______|      |______|
+///    |   |             |
+///    |   5             6
+///    |___|_____________|
+///      |               |
+///      7               8
+///      |_______________|
+///              |
+///              9
+///
+/// Assuming all local orders match the numbering order:
+///    {1, 2, 5, 3, 4, 6}
+///
+
+class BackwardSliceMatcher {
+public:
+ BackwardSliceMatcher(mlir::query::matcher::DynMatcher &&innerMatcher,
+ int64_t maxDepth)
+ : innerMatcher(std::move(innerMatcher)), maxDepth(maxDepth) {}
+
+ bool match(Operation *op, SetVector<Operation *> &backwardSlice,
+ mlir::query::QueryOptions &options) {
+
+ if (innerMatcher.match(op) &&
+ matches(op, backwardSlice, options, maxDepth)) {
+ if (!options.inclusive) {
+ // Don't insert the top level operation, we just queried on it and don't
+ // want it in the results.
+ backwardSlice.remove(op);
+ }
+ return true;
+ }
+ return false;
+ }
+
+private:
+ bool matches(Operation *op, SetVector<Operation *> &backwardSlice,
+ mlir::query::QueryOptions &options, int64_t remainingDepth) {
+
+ if (op->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
+ return false;
+ }
+
+ auto processValue = [&](Value value) {
+ // We need to check the current depth level;
+ // if we have reached level 0, we stop further traversing
+ if (remainingDepth == 0) {
+ return;
+ }
+ if (auto *definingOp = value.getDefiningOp()) {
+ // We omit traversing the same operations
+ if (backwardSlice.count(definingOp) == 0)
+ matches(definingOp, backwardSlice, options, remainingDepth - 1);
+ } else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
+ if (options.omitBlockArguments)
+ return;
+ Block *block = blockArg.getOwner();
+
+ Operation *parentOp = block->getParentOp();
+ // TODO: determine whether we want to recurse backward into the other
+ // blocks of parentOp, which are not technically backward unless they
+ // flow into us. For now, only report parents with several regions/blocks.
+ if (parentOp && backwardSlice.count(parentOp) == 0) {
+ if (parentOp->getNumRegions() != 1 ||
+ parentOp->getRegion(0).getBlocks().size() != 1) {
+ llvm::errs()
+ << "Error: Expected parentOp to have exactly one region and "
+ << "exactly one block, but found " << parentOp->getNumRegions()
+ << " regions and "
+ << (parentOp->getRegion(0).getBlocks().size()) << " blocks.\n";
+ }
+ matches(parentOp, backwardSlice, options, remainingDepth - 1);
+ }
+ } else {
+ llvm_unreachable("No definingOp and not a block argument\n");
+ return;
+ }
+ };
+
+ if (!options.omitUsesFromAbove) {
+ llvm::for_each(op->getRegions(), [&](Region &region) {
+ // Walk this region recursively to collect the regions that descend from
+ // this op's nested regions (inclusive).
+ SmallPtrSet<Region *, 4> descendents;
+ region.walk(
+ [&](Region *childRegion) { descendents.insert(childRegion); });
+ region.walk([&](Operation *op) {
+ for (OpOperand &operand : op->getOpOperands()) {
+ if (!descendents.contains(operand.get().getParentRegion()))
+ processValue(operand.get());
+ }
+ });
+ });
+ }
+
+ llvm::for_each(op->getOperands(), processValue);
+ backwardSlice.insert(op);
+ return true;
+ }
+
+private:
+ // The outer matcher (e.g., BackwardSliceMatcher) relies on the innerMatcher
+ // to determine whether we want to traverse the DAG or not. For example, we
+ // want to explore the DAG only if the top-level operation name is
+ // "arith.addf".
+ mlir::query::matcher::DynMatcher innerMatcher;
+
+ // maxDepth specifies the maximum depth that the matcher can traverse in the
+ // DAG. For example, if maxDepth is 2, the matcher will explore the defining
+ // operations of the top-level op up to 2 levels.
+ int64_t maxDepth;
+};
+
+/// Fills `forwardSlice` with the computed forward slice (i.e. the
+/// transitive uses of op, up to `maxDepth` levels).
+///
+///
+/// The implementation traverses the use chains in postorder traversal for
+/// efficiency reasons: if an operation is already in `forwardSlice`, no
+/// need to traverse its uses again. Since use-def chains form a DAG, this
+/// terminates.
+///
+/// Upon return to the root call, `forwardSlice` is filled with a
+/// postorder list of uses (i.e. a reverse topological order). To get a proper
+/// topological order, we just reverse the order in `forwardSlice` before
+/// returning.
+///
+/// Example starting from node 0
+/// ============================
+///
+///               0
+///    ___________|___________
+///    1       2      3      4
+///    |_______|      |______|
+///    |   |             |
+///    |   5             6
+///    |___|_____________|
+///      |               |
+///      7               8
+///      |_______________|
+///              |
+///              9
+///
+/// Assuming all local orders match the numbering order:
+/// 1. after getting back to the root `matches` call, `forwardSlice` may
+///    contain:
+///      {9, 7, 8, 5, 1, 2, 6, 3, 4}
+/// 2. reversing the result of 1. gives:
+///      {4, 3, 6, 2, 1, 5, 8, 7, 9}
+///
+class ForwardSliceMatcher {
+public:
+ ForwardSliceMatcher(mlir::query::matcher::DynMatcher &&innerMatcher,
+ int64_t maxDepth)
+ : innerMatcher(std::move(innerMatcher)), maxDepth(maxDepth) {}
+
+ bool match(Operation *op, SetVector<Operation *> &forwardSlice,
+ mlir::query::QueryOptions &options) {
+ if (innerMatcher.match(op) &&
+ matches(op, forwardSlice, options, maxDepth)) {
+ if (!options.inclusive) {
+ // Don't insert the top level operation, we just queried on it and don't
+ // want it in the results.
+ forwardSlice.remove(op);
+ }
+ // Reverse to get back the actual topological order.
+ // SetVector has no in-place reverse, so move its contents out with
+ // takeVector() (which leaves forwardSlice empty) and re-insert them back
+ // to front.
+ SmallVector<Operation *, 0> v(forwardSlice.takeVector());
+ forwardSlice.insert(v.rbegin(), v.rend());
+ return true;
+ }
+ return false;
+ }
+
+private:
+ bool matches(Operation *op, SetVector<Operation *> &forwardSlice,
+ mlir::query::QueryOptions &options, int64_t remainingDepth) {
+
+ // Once the depth budget is exhausted, record this op as the last user on
+ // the def-use chain and stop traversing. Ops nested in this op's regions
+ // are visited below and also consume one level of depth.
+ if (remainingDepth == 0) {
+ forwardSlice.insert(op);
+ return true;
+ }
+
+ for (Region &region : op->getRegions())
+ for (Block &block : region)
+ for (Operation &blockOp : block)
+ if (forwardSlice.count(&blockOp) == 0)
+ matches(&blockOp, forwardSlice, options, remainingDepth - 1);
+ for (Value result : op->getResults()) {
+ for (Operation *userOp : result.getUsers())
+ // We omit traversing the same operations
+ if (forwardSlice.count(userOp) == 0)
+ matches(userOp, forwardSlice, options, remainingDepth - 1);
+ }
+
+ forwardSlice.insert(op);
+ return true;
+ }
+
+private:
+ // The outer matcher (e.g., ForwardSliceMatcher) relies on the innerMatcher to
+ // determine whether we want to traverse the graph or not. For example, we
+ // want to explore the DAG only if the top-level operation name is "arith.addf".
+ mlir::query::matcher::DynMatcher innerMatcher;
+
+ // maxDepth specifies the maximum depth that the matcher can traverse in the
+ // graph. For example, if maxDepth is 2, the matcher will explore the user
+ // operations of the top-level op up to 2 levels.
+ int64_t maxDepth;
+};
+
} // namespace detail
+/// Matches ops accepted by `innerMatcher`; collects their defs 1 level up.
+inline detail::BackwardSliceMatcher
+m_DefinedBy(mlir::query::matcher::DynMatcher innerMatcher) {
+ return detail::BackwardSliceMatcher(std::move(innerMatcher), 1);
+}
+
+/// Like m_DefinedBy, but collects transitive defs up to `maxDepth` levels.
+inline detail::BackwardSliceMatcher
+m_GetDefinitions(mlir::query::matcher::DynMatcher innerMatcher,
+ int64_t maxDepth) {
+ assert(maxDepth >= 0 && "maxDepth must be non-negative");
+ return detail::BackwardSliceMatcher(std::move(innerMatcher), maxDepth);
+}
+
+/// Matches ops accepted by `innerMatcher`; collects their users 1 level down.
+inline detail::ForwardSliceMatcher
+m_UsedBy(mlir::query::matcher::DynMatcher innerMatcher) {
+ return detail::ForwardSliceMatcher(std::move(innerMatcher), 1);
+}
+
+/// Like m_UsedBy, but collects transitive users up to `maxDepth` levels.
+inline detail::ForwardSliceMatcher
+m_GetUses(mlir::query::matcher::DynMatcher innerMatcher, int64_t maxDepth) {
+ assert(maxDepth >= 0 && "maxDepth must be non-negative");
+ return detail::ForwardSliceMatcher(std::move(innerMatcher), maxDepth);
+}
+
/// Matches a constant foldable operation.
inline detail::constant_op_matcher m_Constant() {
return detail::constant_op_matcher();
diff --git a/mlir/include/mlir/Query/Matcher/ExtraMatchers.h b/mlir/include/mlir/Query/Matcher/ExtraMatchers.h
deleted file mode 100644
index 908fccfc704c3..0000000000000
--- a/mlir/include/mlir/Query/Matcher/ExtraMatchers.h
+++ /dev/null
@@ -1,188 +0,0 @@
-//===- ExtraMatchers.h - Various common matchers ---------------------*- C++
-//-*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file provides extra matchers that are very useful for mlir-query
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_IR_EXTRAMATCHERS_H
-#define MLIR_IR_EXTRAMATCHERS_H
-
-#include "MatchFinder.h"
-#include "MatchersInternal.h"
-#include "mlir/IR/Region.h"
-#include "mlir/Query/Query.h"
-#include "llvm/Support/raw_ostream.h"
-
-namespace mlir {
-
-namespace query {
-
-namespace extramatcher {
-
-namespace detail {
-
-class BackwardSliceMatcher {
-public:
- BackwardSliceMatcher(matcher::DynMatcher &&innerMatcher, unsigned hops)
- : innerMatcher(std::move(innerMatcher)), hops(hops) {}
-
-private:
- bool matches(Operation *op, SetVector<Operation *> &backwardSlice,
- QueryOptions &options, unsigned tempHops) {
-
- if (op->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
- return false;
- }
-
- auto processValue = [&](Value value) {
- if (tempHops == 0) {
- return;
- }
- if (auto *definingOp = value.getDefiningOp()) {
- if (backwardSlice.count(definingOp) == 0)
- matches(definingOp, backwardSlice, options, tempHops - 1);
- } else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
- if (options.omitBlockArguments)
- return;
- Block *block = blockArg.getOwner();
-
- Operation *parentOp = block->getParentOp();
-
- if (parentOp && backwardSlice.count(parentOp) == 0) {
- if (parentOp->getNumRegions() != 1 &&
- parentOp->getRegion(0).getBlocks().size() != 1) {
- llvm::errs()
- << "Error: Expected parentOp to have exactly one region and "
- << "exactly one block, but found " << parentOp->getNumRegions()
- << " regions and "
- << (parentOp->getRegion(0).getBlocks().size()) << " blocks.\n";
- };
- matches(parentOp, backwardSlice, options, tempHops - 1);
- }
- } else {
- llvm::errs() << "No definingOp and not a block argument\n";
- return;
- }
- };
-
- if (!options.omitUsesFromAbove) {
- llvm::for_each(op->getRegions(), [&](Region ®ion) {
- SmallPtrSet<Region *, 4> descendents;
- region.walk(
- [&](Region *childRegion) { descendents.insert(childRegion); });
- region.walk([&](Operation *op) {
- for (OpOperand &operand : op->getOpOperands()) {
- if (!descendents.contains(operand.get().getParentRegion()))
- processValue(operand.get());
- }
- });
- });
- }
-
- llvm::for_each(op->getOperands(), processValue);
- backwardSlice.insert(op);
- return true;
- }
-
-public:
- bool match(Operation *op, SetVector<Operation *> &backwardSlice,
- QueryOptions &options) {
-
- if (innerMatcher.match(op) && matches(op, backwardSlice, options, hops)) {
- if (!options.inclusive) {
- backwardSlice.remove(op);
- }
- return true;
- }
- return false;
- }
-
-private:
- matcher::DynMatcher innerMatcher;
- unsigned hops;
-};
-
-class ForwardSliceMatcher {
-public:
- ForwardSliceMatcher(matcher::DynMatcher &&innerMatcher, unsigned hops)
- : innerMatcher(std::move(innerMatcher)), hops(hops) {}
-
-private:
- bool matches(Operation *op, SetVector<Operation *> &forwardSlice,
- QueryOptions &options, unsigned tempHops) {
-
- if (tempHops == 0) {
- forwardSlice.insert(op);
- return true;
- }
-
- for (Region ®ion : op->getRegions())
- for (Block &block : region)
- for (Operation &blockOp : block)
- if (forwardSlice.count(&blockOp) == 0)
- matches(&blockOp, forwardSlice, options, tempHops - 1);
- for (Value result : op->getResults()) {
- for (Operation *userOp : result.getUsers())
- if (forwardSlice.count(userOp) == 0)
- matches(userOp, forwardSlice, options, tempHops - 1);
- }
-
- forwardSlice.insert(op);
- return true;
- }
-
-public:
- bool match(Operation *op, SetVector<Operation *> &forwardSlice,
- QueryOptions &options) {
- if (innerMatcher.match(op) && matches(op, forwardSlice, options, hops)) {
- if (!options.inclusive) {
- forwardSlice.remove(op);
- }
- SmallVector<Operation *, 0> v(forwardSlice.takeVector());
- forwardSlice.insert(v.rbegin(), v.rend());
- return true;
- }
- return false;
- }
-
-private:
- matcher::DynMatcher innerMatcher;
- unsigned hops;
-};
-
-} // namespace detail
-
-inline detail::BackwardSliceMatcher
-definedBy(mlir::query::matcher::DynMatcher innerMatcher) {
- return detail::BackwardSliceMatcher(std::move(innerMatcher), 1);
-}
-
-inline detail::BackwardSliceMatcher
-getDefinitions(mlir::query::matcher::DynMatcher innerMatcher, unsigned hops) {
- return detail::BackwardSliceMatcher(std::move(innerMatcher), hops);
-}
-
-inline detail::ForwardSliceMatcher
-usedBy(mlir::query::matcher::DynMatcher innerMatcher) {
- return detail::ForwardSliceMatcher(std::move(innerMatcher), 1);
-}
-
-inline detail::ForwardSliceMatcher
-getUses(mlir::query::matcher::DynMatcher innerMatcher, unsigned hops) {
- return detail::ForwardSliceMatcher(std::move(innerMatcher), hops);
-}
-
-} // namespace extramatcher
-
-} // namespace query
-
-} // namespace mlir
-
-#endif // MLIR_IR_EXTRAMATCHERS_H
diff --git a/mlir/include/mlir/Query/Matcher/Marshallers.h b/mlir/include/mlir/Query/Matcher/Marshallers.h
index c775dbc5c86da..43643298e4702 100644
--- a/mlir/include/mlir/Query/Matcher/Marshallers.h
+++ b/mlir/include/mlir/Query/Matcher/Marshallers.h
@@ -51,14 +51,14 @@ struct ArgTypeTraits<llvm::StringRef> {
};
template <>
-struct ArgTypeTraits<unsigned> {
+struct ArgTypeTraits<int64_t> {
static bool hasCorrectType(const VariantValue &value) {
- return value.isUnsigned();
+ return value.isSigned();
}
- static unsigned get(const VariantValue &value) { return value.getUnsigned(); }
+ static int64_t get(const VariantValue &value) { return value.getSigned(); }
- static ArgKind getKind() { return ArgKind::Unsigned; }
+ static ArgKind getKind() { return ArgKind::Signed; }
static std::optional<std::string> getBestGuess(const VariantValue &) {
return std::nullopt;
diff --git a/mlir/include/mlir/Query/Matcher/MatchFinder.h b/mlir/include/mlir/Query/Matcher/MatchFinder.h
index 1b9d3bc307ff5..1d64f894bb8a1 100644
--- a/mlir/include/mlir/Query/Matcher/MatchFinder.h
+++ b/mlir/include/mlir/Query/Matcher/MatchFinder.h
@@ -23,45 +23,78 @@
namespace mlir::query::matcher {
class MatchFinder {
-private:
- static void printMatch(llvm::raw_ostream &os, QuerySession &qs,
- mlir::Operation *op, const std::string &binding) {
- auto fileLoc = op->getLoc()->findInstanceOf<FileLineColLoc>();
- auto smloc = qs.getSourceManager().FindLocForLineAndColumn(
- qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn());
- qs.getSourceManager().PrintMessage(os, smloc, llvm::SourceMgr::DK_Note,
- "\"" + binding + "\" binds here");
- };
public:
+ //
+ // getMatches walks the IR and prints matches to `os` as soon as they are
+ // found. If the matcher is to be extracted into a function, matching
+ // operations are only collected and returned, not printed.
+ //
static std::vector<Operation *>
getMatches(Operation *root, QueryOptions &options, DynMatcher matcher,
llvm::raw_ostream &os, QuerySession &qs) {
- unsigned matchCount = 0;
+ int matchCount = 0;
+ bool printMatchingOps = true;
+ // If matcher is to be extracted to a function, we don't want to print
+ // matching ops to `os`.
+ if (matcher.hasFunctionName()) {
+ printMatchingOps = false;
+ }
std::vector<Operation *> matchedOps;
SetVector<Operation *> tempStorage;
os << "\n";
root->walk([&](Operation *subOp) {
if (matcher.match(subOp)) {
matchedOps.push_back(subOp);
- os << "Match #" << ++matchCount << ":\n\n";
- printMatch(os, qs, subOp, "root");
+ if (printMatchingOps) {
+ os << "Match #" << ++matchCount << ":\n\n";
+ printMatch(os, qs, subOp, "root");
+ }
} else {
SmallVector<Operation *> printingOps;
if (matcher.match(subOp, tempStorage, options)) {
- os << "Match #" << ++matchCount << ":\n\n";
+ if (printMatchingOps) {
+ os << "Match #" << ++matchCount << ":\n\n";
+ }
SmallVector<Operation *> printingOps(tempStorage.takeVector());
for (auto op : printingOps) {
- printMatch(os, qs, op, "root");
+ if (printMatchingOps) {
+ printMatch(os, qs, op, "root");
+ }
matchedOps.push_back(op);
}
printingOps.clear();
}
}
});
- os << matchCount << (matchCount == 1 ? " match.\n\n" : " matches.\n\n");
+ if (printMatchingOps) {
+ os << matchCount << (matchCount == 1 ? " match.\n\n" : " matches.\n\n");
+ }
return matchedOps;
}
+
+private:
+ // Overload that prints the location without a binding note (unused here).
+ static void printMatch(llvm::raw_ostream &os, QuerySession &qs,
+ mlir::Operation *op) {
+ auto fileLoc = dyn_cast<FileLineColLoc>(op->getLoc());
+ // Only file locations can be mapped back into the source buffer.
+ if (!fileLoc)
+ return;
+ SMLoc smloc = qs.getSourceManager().FindLocForLineAndColumn(
+ qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn());
+ llvm::SMDiagnostic diag = qs.getSourceManager().GetMessage(
+ smloc, llvm::SourceMgr::DK_Note, "");
+ diag.print("", os, true, false, true);
+ }
+ static void printMatch(llvm::raw_ostream &os, QuerySession &qs,
+ mlir::Operation *op, const std::string &binding) {
+ auto fileLoc = op->getLoc()->findInstanceOf<FileLineColLoc>();
+ auto smloc = qs.getSourceManager().FindLocForLineAndColumn(
+ qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn());
+ qs.getSourceManager().PrintMessage(os, smloc, llvm::SourceMgr::DK_Note,
+ "\"" + binding + "\" binds here");
+ }
};
} // namespace mlir::query::matcher
diff --git a/mlir/include/mlir/Query/Matcher/MatchersInternal.h b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
index b532b47be7d05..c5c24190f0e7f 100644
--- a/mlir/include/mlir/Query/Matcher/MatchersInternal.h
+++ b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
@@ -1,3 +1,4 @@
+//===- MatchersInternal.h - Structural query framework ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -7,20 +8,20 @@
//
// Implements the base layer of the matcher framework.
//
-// Matchers are methods that return a Matcher which provides a method
-// match(Operation *op)
-// match(Operation *op, SetVector<Operation *> &matchedOps, QueryOptions
-// &options)
+// Matchers are methods that return a Matcher which provides one of the
+// following methods: match(Operation *op), match(Operation *op,
+// SetVector<Operation *> &matchedOps, QueryOptions &options)
//
// The matcher functions are defined in include/mlir/IR/Matchers.h.
// This file contains the wrapper classes needed to construct matchers for
// mlir-query.
//
//===----------------------------------------------------------------------===//
+
#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
#define MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
-#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Operation.h"
#include "llvm/ADT/IntrusiveRefCntPtr.h"
namespace mlir {
@@ -30,17 +31,27 @@ struct QueryOptions;
} // namespace mlir
namespace mlir::query::matcher {
+
+// Defaults to false if T has no match() method with the signature:
+// match(Operation* op).
template <typename T, typename = void>
struct has_simple_match : std::false_type {};
+// Specialized type trait that evaluates to true if T has a match() method
+// with the signature: match(Operation* op).
template <typename T>
struct has_simple_match<T, std::void_t<decltype(std::declval<T>().match(
std::declval<Operation *>()))>>
: std::true_type {};
+// Defaults to false if T has no match() method with the signature:
+// match(Operation* op, SetVector<Operation*>&, QueryOptions&).
template <typename T, typename = void>
struct has_bound_match : std::false_type {};
+// Specialized type trait that evaluates to true if T has a match() method
+// with the signature: match(Operation* op, SetVector<Operation*>&,
+// QueryOptions&).
template <typename T>
struct has_bound_match<T, std::void_t<decltype(std::declval<T>().match(
std::declval<Operation *>(),
diff --git a/mlir/include/mlir/Query/Matcher/VariantValue.h b/mlir/include/mlir/Query/Matcher/VariantValue.h
index 6b57119df7a9b..71ec628edeea1 100644
--- a/mlir/include/mlir/Query/Matcher/VariantValue.h
+++ b/mlir/include/mlir/Query/Matcher/VariantValue.h
@@ -21,7 +21,7 @@
namespace mlir::query::matcher {
// All types that VariantValue can contain.
-enum class ArgKind { Matcher, String, Unsigned };
+enum class ArgKind { Matcher, String, Signed };
// A variant matcher object to abstract simple and complex matchers into a
// single object type.
@@ -81,7 +81,7 @@ class VariantValue {
// Specific constructors for each supported type.
VariantValue(const llvm::StringRef string);
VariantValue(const VariantMatcher &matcher);
- VariantValue(unsigned Unsigned);
+ VariantValue(int64_t signedValue);
// String value functions.
bool isString() const;
@@ -93,10 +93,10 @@ class VariantValue {
const VariantMatcher &getMatcher() const;
void setMatcher(const VariantMatcher &matcher);
- // Unsigned value functions.
- bool isUnsigned() const;
- unsigned getUnsigned() const;
- void setUnsigned(unsigned Unsigned);
+ // Signed value functions.
+ bool isSigned() const;
+ int64_t getSigned() const;
+ void setSigned(int64_t signedValue);
// String representation of the type of the value.
std::string getTypeAsString() const;
@@ -111,14 +111,14 @@ class VariantValue {
Nothing,
String,
Matcher,
- Unsigned,
+ Signed,
};
// All supported value types.
union AllValues {
llvm::StringRef *String;
VariantMatcher *Matcher;
- unsigned Unsigned;
+ int64_t Signed;
};
ValueType type;
diff --git a/mlir/include/mlir/Query/Query.h b/mlir/include/mlir/Query/Query.h
index bb5b98432d51c..77114a2a00e23 100644
--- a/mlir/include/mlir/Query/Query.h
+++ b/mlir/include/mlir/Query/Query.h
@@ -17,13 +17,31 @@
namespace mlir::query {
+///
+/// Options for configuring which parts of the IR are to be
+/// traversed by the matcher.
+///
struct QueryOptions {
+ /// When omitBlockArguments is true, the matcher does not traverse
+ /// through block arguments.
bool omitBlockArguments = false;
+ /// When omitUsesFromAbove is true, the matcher does not traverse
+ /// values that are captured from above.
bool omitUsesFromAbove = true;
+ /// When inclusive is true, the matcher includes the top-level op in the
+ /// slice. When inclusive is false, the top-level op is removed from the
+ /// slice.
bool inclusive = true;
};
-enum class QueryKind { Invalid, NoOp, Help, Match, Quit, SetBool };
+enum class QueryKind {
+ Invalid,
+ NoOp,
+ Help,
+ SetBool,
+ Match,
+ Quit,
+};
class QuerySession;
diff --git a/mlir/include/mlir/Query/QuerySession.h b/mlir/include/mlir/Query/QuerySession.h
index 495358e8f36f9..03dbd481d64cf 100644
--- a/mlir/include/mlir/Query/QuerySession.h
+++ b/mlir/include/mlir/Query/QuerySession.h
@@ -15,10 +15,6 @@
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/SourceMgr.h"
-namespace mlir::query::matcher {
-class Registry;
-}
-
namespace mlir::query {
// Represents the state for a particular mlir-query session.
diff --git a/mlir/lib/Query/Matcher/RegistryManager.cpp b/mlir/lib/Query/Matcher/RegistryManager.cpp
index 8d6c0135aa117..8f7da5aeaa5e6 100644
--- a/mlir/lib/Query/Matcher/RegistryManager.cpp
+++ b/mlir/lib/Query/Matcher/RegistryManager.cpp
@@ -19,11 +19,6 @@
namespace mlir::query::matcher {
namespace {
-// This is needed because these matchers are defined as overloaded functions.
-using IsConstantOp = detail::constant_op_matcher();
-using HasOpAttrName = detail::AttrOpMatcher(llvm::StringRef);
-using HasOpName = detail::NameOpMatcher(llvm::StringRef);
-
// Enum to string for autocomplete.
static std::string asArgString(ArgKind kind) {
switch (kind) {
@@ -31,8 +26,8 @@ static std::string asArgString(ArgKind kind) {
return "Matcher";
case ArgKind::String:
return "String";
- case ArgKind::Unsigned:
- return "unsigned";
+ case ArgKind::Signed:
+ return "signed";
}
llvm_unreachable("Unhandled ArgKind");
}
diff --git a/mlir/lib/Query/Matcher/VariantValue.cpp b/mlir/lib/Query/Matcher/VariantValue.cpp
index d5218d8dad8c9..d4f3e4f4d594d 100644
--- a/mlir/lib/Query/Matcher/VariantValue.cpp
+++ b/mlir/lib/Query/Matcher/VariantValue.cpp
@@ -56,8 +56,8 @@ VariantValue::VariantValue(const VariantMatcher &matcher)
value.Matcher = new VariantMatcher(matcher);
}
-VariantValue::VariantValue(unsigned Unsigned) : type(ValueType::Unsigned) {
- value.Unsigned = Unsigned;
+VariantValue::VariantValue(int64_t signedValue) : type(ValueType::Signed) {
+ value.Signed = signedValue;
}
VariantValue::~VariantValue() { reset(); }
@@ -73,8 +73,8 @@ VariantValue &VariantValue::operator=(const VariantValue &other) {
case ValueType::Matcher:
setMatcher(other.getMatcher());
break;
- case ValueType::Unsigned:
- setUnsigned(other.getUnsigned());
+ case ValueType::Signed:
+ setSigned(other.getSigned());
break;
case ValueType::Nothing:
type = ValueType::Nothing;
@@ -92,7 +92,7 @@ void VariantValue::reset() {
delete value.Matcher;
break;
// Cases that do nothing.
- case ValueType::Unsigned:
+ case ValueType::Signed:
case ValueType::Nothing:
break;
}
@@ -100,17 +100,17 @@ void VariantValue::reset() {
}
// Unsinged
-bool VariantValue::isUnsigned() const { return type == ValueType::Unsigned; }
+bool VariantValue::isSigned() const { return type == ValueType::Signed; }
-unsigned VariantValue::getUnsigned() const {
- assert(isUnsigned());
- return value.Unsigned;
+int64_t VariantValue::getSigned() const {
+ assert(isSigned());
+ return value.Signed;
}
-void VariantValue::setUnsigned(unsigned newValue) {
+void VariantValue::setSigned(int64_t newValue) {
reset();
- type = ValueType::Unsigned;
- value.Unsigned = newValue;
+ type = ValueType::Signed;
+ value.Signed = newValue;
}
bool VariantValue::isString() const { return type == ValueType::String; }
@@ -145,8 +145,8 @@ std::string VariantValue::getTypeAsString() const {
return "String";
case ValueType::Matcher:
return "Matcher";
- case ValueType::Unsigned:
- return "Unsigned";
+ case ValueType::Signed:
+ return "Signed";
case ValueType::Nothing:
return "Nothing";
}
diff --git a/mlir/test/mlir-query/complex-test.mlir b/mlir/test/mlir-query/complex-test.mlir
index b3df534ee8871..c5fee38327704 100644
--- a/mlir/test/mlir-query/complex-test.mlir
+++ b/mlir/test/mlir-query/complex-test.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-query %s -c "match getDefinitions(hasOpName("arith.addf"),2)" | FileCheck %s
+// RUN: mlir-query %s -c "match getDefinitions(hasOpName(\"arith.addf\"),2)" | FileCheck %s
#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @slice_use_from_above(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>) {
@@ -22,18 +22,11 @@ func.func @slice_use_from_above(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>)
// CHECK: %[[LINALG:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]}
// CHECK-SAME: ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>)
-
// CHECK: %[[ADDF1:.*]] = arith.addf %in, %in : f32
-
// CHECK: Match #2:
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %0 {{\[\[.*\]\]}} : tensor<5x5xf32> into tensor<25xf32>
-// CHECK: %[[EXTRACTED:.*]] = tensor.extract %[[COLLAPSED]][%c2] : tensor<25xf32>
+// CHECK: %[[C2:.*]] = arith.constant {{.*}} : index
+// CHECK: %[[EXTRACTED:.*]] = tensor.extract %[[COLLAPSED]][%[[C2]]] : tensor<25xf32>
// CHECK: %[[ADDF2:.*]] = arith.addf %[[EXTRACTED]], %[[EXTRACTED]] : f32
-
-
-
-
-
-
diff --git a/mlir/test/mlir-query/function-extraction.mlir b/mlir/test/mlir-query/function-extraction.mlir
index 5a20c09d02eb6..a783f65c6761b 100644
--- a/mlir/test/mlir-query/function-extraction.mlir
+++ b/mlir/test/mlir-query/function-extraction.mlir
@@ -1,10 +1,10 @@
-// RUN: mlir-query %s -c "m hasOpName("arith.mulf").extract("testmul")" | FileCheck %s
+// RUN: mlir-query %s -c "m hasOpName(\"arith.mulf\").extract(\"testmul\")" | FileCheck %s
// CHECK: func.func @testmul({{.*}}) -> (f32, f32, f32) {
// CHECK: %[[MUL0:.*]] = arith.mulf {{.*}} : f32
// CHECK: %[[MUL1:.*]] = arith.mulf {{.*}}, %[[MUL0]] : f32
// CHECK: %[[MUL2:.*]] = arith.mulf {{.*}} : f32
-// CHECK-NEXT: return %[[MUL0]], %[[MUL1]], %[[MUL2]] : f32, f32, f32S
+// CHECK-NEXT: return %[[MUL0]], %[[MUL1]], %[[MUL2]] : f32, f32, f32
func.func @mixedOperations(%a: f32, %b: f32, %c: f32) -> f32 {
%sum0 = arith.addf %a, %b : f32
@@ -16,4 +16,4 @@ func.func @mixedOperations(%a: f32, %b: f32, %c: f32) -> f32 {
%sum2 = arith.addf %mul1, %b : f32
%mul2 = arith.mulf %sub2, %sum2 : f32
return %mul2 : f32
- }
+}
diff --git a/mlir/tools/mlir-query/mlir-query.cpp b/mlir/tools/mlir-query/mlir-query.cpp
index 468f948bec24c..91714aab33699 100644
--- a/mlir/tools/mlir-query/mlir-query.cpp
+++ b/mlir/tools/mlir-query/mlir-query.cpp
@@ -14,7 +14,6 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/InitAllDialects.h"
-#include "mlir/Query/Matcher/ExtraMatchers.h"
#include "mlir/Query/Matcher/MatchFinder.h"
#include "mlir/Query/Matcher/Registry.h"
#include "mlir/Tools/mlir-query/MlirQueryMain.h"
@@ -40,14 +39,10 @@ int main(int argc, char **argv) {
query::matcher::Registry matcherRegistry;
// Matchers registered in alphabetical order for consistency:
- matcherRegistry.registerMatcher("getDefinitions",
- mlir::query::extramatcher::getDefinitions);
- matcherRegistry.registerMatcher("definedBy",
- mlir::query::extramatcher::definedBy);
- matcherRegistry.registerMatcher("usedBy", mlir::query::extramatcher::usedBy);
- matcherRegistry.registerMatcher("getUses",
- mlir::query::extramatcher::getUses);
-
+ matcherRegistry.registerMatcher("getDefinitions", m_GetDefinitions);
+ matcherRegistry.registerMatcher("definedBy", m_DefinedBy);
+ matcherRegistry.registerMatcher("usedBy", m_UsedBy);
+ matcherRegistry.registerMatcher("getUses", m_GetUses);
matcherRegistry.registerMatcher("hasOpAttrName",
static_cast<HasOpAttrName *>(m_Attr));
matcherRegistry.registerMatcher("hasOpName", static_cast<HasOpName *>(m_Op));
>From e6bc9b34a1fd45100afb0c8b23d928ecc659c681 Mon Sep 17 00:00:00 2001
From: chios202 <chio.star at yahoo.com>
Date: Sat, 29 Mar 2025 15:12:50 +0000
Subject: [PATCH 5/5] Implement nested slicing matcher & enhance MatchFinder
class - nested slicing matcher - enhance MatchFinder class -
rename getSlice static method to avoid collision with SliceAnalysis::getSlice
---
mlir/include/mlir/IR/Matchers.h | 225 ++----------------
mlir/include/mlir/Query/Matcher/MatchFinder.h | 97 ++------
.../mlir/Query/Matcher/MatchersInternal.h | 6 +-
mlir/include/mlir/Query/Query.h | 20 +-
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 6 +-
mlir/lib/IR/CMakeLists.txt | 1 +
mlir/lib/IR/Matchers.cpp | 57 +++++
mlir/lib/Query/Matcher/CMakeLists.txt | 1 +
mlir/lib/Query/Matcher/MatchFinder.cpp | 72 ++++++
mlir/lib/Query/Matcher/Parser.cpp | 24 +-
mlir/lib/Query/Matcher/VariantValue.cpp | 7 +-
mlir/lib/Query/Query.cpp | 22 +-
mlir/lib/Query/QueryParser.cpp | 2 +-
mlir/test/mlir-query/complex-test.mlir | 2 +-
mlir/tools/mlir-query/mlir-query.cpp | 2 -
15 files changed, 208 insertions(+), 336 deletions(-)
create mode 100644 mlir/lib/IR/Matchers.cpp
create mode 100644 mlir/lib/Query/Matcher/MatchFinder.cpp
diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index ee9e2afb10bad..5ea91e64fa93c 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -21,8 +21,6 @@
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Query/Matcher/MatchersInternal.h"
#include "mlir/Query/Query.h"
-#include "llvm/ADT/SetVector.h"
-
namespace mlir {
namespace detail {
@@ -366,21 +364,14 @@ struct RecursivePatternMatcher {
std::tuple<OperandMatchers...> operandMatchers;
};
-/// Fills `backwardSlice` with the computed backward slice (i.e.
-/// all the transitive defs of op)
-///
-/// The implementation traverses the def chains in postorder traversal for
-/// efficiency reasons: if an operation is already in `backwardSlice`, no
-/// need to traverse its definitions again. Since use-def chains form a DAG,
-/// this terminates.
-///
-/// Upon return to the root call, `backwardSlice` is filled with a
-/// postorder list of defs. This happens to be a topological order, from the
-/// point of view of the use-def chains.
+/// A matcher wrapping `getBackwardSlice` from SliceAnalysis.h. On top of the
+/// usual slicing options, it limits the slice to operations within `maxDepth`
+/// def-use levels of the matched root operation, enforced through a custom
+/// filter.
///
-/// Example starting from node 8
+/// Example starting from node 9, assuming the matcher
+/// computes the slice for the first two depth levels
/// ============================
-///
/// 1 2 3 4
/// |_______| |______|
/// | | |
@@ -393,240 +384,52 @@ struct RecursivePatternMatcher {
/// 9
///
/// Assuming all local orders match the numbering order:
-/// {1, 2, 5, 3, 4, 6}
-///
-
+/// {5, 7, 6, 8, 9}
class BackwardSliceMatcher {
public:
- BackwardSliceMatcher(mlir::query::matcher::DynMatcher &&innerMatcher,
+ BackwardSliceMatcher(query::matcher::DynMatcher &&innerMatcher,
int64_t maxDepth)
: innerMatcher(std::move(innerMatcher)), maxDepth(maxDepth) {}
-
bool match(Operation *op, SetVector<Operation *> &backwardSlice,
- mlir::query::QueryOptions &options) {
+ query::QueryOptions &options) {
if (innerMatcher.match(op) &&
matches(op, backwardSlice, options, maxDepth)) {
- if (!options.inclusive) {
- // Don't insert the top level operation, we just queried on it and don't
- // want it in the results.
- backwardSlice.remove(op);
- }
return true;
}
return false;
}
private:
- bool matches(Operation *op, SetVector<Operation *> &backwardSlice,
- mlir::query::QueryOptions &options, int64_t remainingDepth) {
-
- if (op->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
- return false;
- }
-
- auto processValue = [&](Value value) {
- // We need to check the current depth level;
- // if we have reached level 0, we stop further traversing
- if (remainingDepth == 0) {
- return;
- }
- if (auto *definingOp = value.getDefiningOp()) {
- // We omit traversing the same operations
- if (backwardSlice.count(definingOp) == 0)
- matches(definingOp, backwardSlice, options, remainingDepth - 1);
- } else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
- if (options.omitBlockArguments)
- return;
- Block *block = blockArg.getOwner();
-
- Operation *parentOp = block->getParentOp();
- // TODO: determine whether we want to recurse backward into the other
- // blocks of parentOp, which are not technically backward unless they
- // flow into us. For now, just bail.
- if (parentOp && backwardSlice.count(parentOp) == 0) {
- if (parentOp->getNumRegions() != 1 &&
- parentOp->getRegion(0).getBlocks().size() != 1) {
- llvm::errs()
- << "Error: Expected parentOp to have exactly one region and "
- << "exactly one block, but found " << parentOp->getNumRegions()
- << " regions and "
- << (parentOp->getRegion(0).getBlocks().size()) << " blocks.\n";
- };
- matches(parentOp, backwardSlice, options, remainingDepth - 1);
- }
- } else {
- llvm_unreachable("No definingOp and not a block argument\n");
- return;
- }
- };
-
- if (!options.omitUsesFromAbove) {
- llvm::for_each(op->getRegions(), [&](Region ®ion) {
- // Walk this region recursively to collect the regions that descend from
- // this op's nested regions (inclusive).
- SmallPtrSet<Region *, 4> descendents;
- region.walk(
- [&](Region *childRegion) { descendents.insert(childRegion); });
- region.walk([&](Operation *op) {
- for (OpOperand &operand : op->getOpOperands()) {
- if (!descendents.contains(operand.get().getParentRegion()))
- processValue(operand.get());
- }
- });
- });
- }
-
- llvm::for_each(op->getOperands(), processValue);
- backwardSlice.insert(op);
- return true;
- }
+ bool matches(Operation *rootOp, llvm::SetVector<Operation *> &backwardSlice,
+ query::QueryOptions &options, int64_t maxDepth);
private:
// The outer matcher (e.g., BackwardSliceMatcher) relies on the innerMatcher
// to determine whether we want to traverse the DAG or not. For example, we
// want to explore the DAG only if the top-level operation name is
// "arith.addf".
- mlir::query::matcher::DynMatcher innerMatcher;
-
+ query::matcher::DynMatcher innerMatcher;
// maxDepth specifies the maximum depth that the matcher can traverse in the
// DAG. For example, if maxDepth is 2, the matcher will explore the defining
// operations of the top-level op up to 2 levels.
int64_t maxDepth;
};
-
-/// Fills `forwardSlice` with the computed forward slice (i.e. all
-/// the transitive uses of op)
-///
-///
-/// The implementation traverses the use chains in postorder traversal for
-/// efficiency reasons: if an operation is already in `forwardSlice`, no
-/// need to traverse its uses again. Since use-def chains form a DAG, this
-/// terminates.
-///
-/// Upon return to the root call, `forwardSlice` is filled with a
-/// postorder list of uses (i.e. a reverse topological order). To get a proper
-/// topological order, we just reverse the order in `forwardSlice` before
-/// returning.
-///
-/// Example starting from node 0
-/// ============================
-///
-/// 0
-/// ___________|___________
-/// 1 2 3 4
-/// |_______| |______|
-/// | | |
-/// | 5 6
-/// |___|_____________|
-/// | |
-/// 7 8
-/// |_______________|
-/// |
-/// 9
-///
-/// Assuming all local orders match the numbering order:
-/// 1. after getting back to the root getForwardSlice, `forwardSlice` may
-/// contain:
-/// {9, 7, 8, 5, 1, 2, 6, 3, 4}
-/// 2. reversing the result of 1. gives:
-/// {4, 3, 6, 2, 1, 5, 8, 7, 9}
-///
-class ForwardSliceMatcher {
-public:
- ForwardSliceMatcher(mlir::query::matcher::DynMatcher &&innerMatcher,
- int64_t maxDepth)
- : innerMatcher(std::move(innerMatcher)), maxDepth(maxDepth) {}
-
- bool match(Operation *op, SetVector<Operation *> &forwardSlice,
- mlir::query::QueryOptions &options) {
- if (innerMatcher.match(op) &&
- matches(op, forwardSlice, options, maxDepth)) {
- if (!options.inclusive) {
- // Don't insert the top level operation, we just queried on it and don't
- // want it in the results.
- forwardSlice.remove(op);
- }
- // Reverse to get back the actual topological order.
- // std::reverse does not work out of the box on SetVector and I want an
- // in-place swap based thing (the real std::reverse, not the LLVM
- // adapter).
- SmallVector<Operation *, 0> v(forwardSlice.takeVector());
- forwardSlice.insert(v.rbegin(), v.rend());
- return true;
- }
- return false;
- }
-
-private:
- bool matches(Operation *op, SetVector<Operation *> &forwardSlice,
- mlir::query::QueryOptions &options, int64_t remainingDepth) {
-
- // We need to check the current depth level;
- // if we have reached level 0, we stop further traversing and insert
- // the last user in def-use chain
- if (remainingDepth == 0) {
- forwardSlice.insert(op);
- return true;
- }
-
- for (Region ®ion : op->getRegions())
- for (Block &block : region)
- for (Operation &blockOp : block)
- if (forwardSlice.count(&blockOp) == 0)
- matches(&blockOp, forwardSlice, options, remainingDepth - 1);
- for (Value result : op->getResults()) {
- for (Operation *userOp : result.getUsers())
- // We omit traversing the same operations
- if (forwardSlice.count(userOp) == 0)
- matches(userOp, forwardSlice, options, remainingDepth - 1);
- }
-
- forwardSlice.insert(op);
- return true;
- }
-
-private:
- // The outer matcher e.g (ForwardSliceMatcher) relies on the innerMatcher to
- // determine whether we want to traverse the graph or not. E.g: we want to
- // explore the DAG only if the top level operation name is "arith.addf"
- mlir::query::matcher::DynMatcher innerMatcher;
-
- // maxDepth specifies the maximum depth that the matcher can traverse the
- // graph E.g: if maxDepth is 2, the matcher will explore the user
- // operations of the top level op up to 2 levels
- int64_t maxDepth;
-};
-
} // namespace detail
// Matches transitive defs of a top level operation up to 1 level
inline detail::BackwardSliceMatcher
-m_DefinedBy(mlir::query::matcher::DynMatcher innerMatcher) {
+m_DefinedBy(query::matcher::DynMatcher innerMatcher) {
return detail::BackwardSliceMatcher(std::move(innerMatcher), 1);
}
// Matches transitive defs of a top level operation up to N levels
inline detail::BackwardSliceMatcher
-m_GetDefinitions(mlir::query::matcher::DynMatcher innerMatcher,
- int64_t maxDepth) {
+m_GetDefinitions(query::matcher::DynMatcher innerMatcher, int64_t maxDepth) {
assert(maxDepth >= 0 && "maxDepth must be non-negative");
return detail::BackwardSliceMatcher(std::move(innerMatcher), maxDepth);
}
-// Matches uses of a top level operation up to 1 level
-inline detail::ForwardSliceMatcher
-m_UsedBy(mlir::query::matcher::DynMatcher innerMatcher) {
- return detail::ForwardSliceMatcher(std::move(innerMatcher), 1);
-}
-
-// Matches uses of a top level operation up to N levels
-inline detail::ForwardSliceMatcher
-m_GetUses(mlir::query::matcher::DynMatcher innerMatcher, int64_t maxDepth) {
- assert(maxDepth >= 0 && "maxDepth must be non-negative");
- return detail::ForwardSliceMatcher(std::move(innerMatcher), maxDepth);
-}
-
/// Matches a constant foldable operation.
inline detail::constant_op_matcher m_Constant() {
return detail::constant_op_matcher();
diff --git a/mlir/include/mlir/Query/Matcher/MatchFinder.h b/mlir/include/mlir/Query/Matcher/MatchFinder.h
index 1d64f894bb8a1..3591cf05e7599 100644
--- a/mlir/include/mlir/Query/Matcher/MatchFinder.h
+++ b/mlir/include/mlir/Query/Matcher/MatchFinder.h
@@ -15,86 +15,41 @@
#define MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H
#include "MatchersInternal.h"
+#include "mlir/Query/Query.h"
#include "mlir/Query/QuerySession.h"
-#include "llvm/ADT/SetVector.h"
-#include "llvm/Support/SourceMgr.h"
-#include "llvm/Support/raw_ostream.h"
namespace mlir::query::matcher {
+/// A class that provides utilities to find operations in a DAG
class MatchFinder {
public:
- //
- // getMatches walks the IR and prints operations as soon as it matches them
- // if a matcher is to be further extracted into the function, then it does not
- // print operations
- //
- static std::vector<Operation *>
- getMatches(Operation *root, QueryOptions &options, DynMatcher matcher,
- llvm::raw_ostream &os, QuerySession &qs) {
- int matchCount = 0;
- bool printMatchingOps = true;
- // If matcher is to be extracted to a function, we don't want to print
- // matching ops to sdout
- if (matcher.hasFunctionName()) {
- printMatchingOps = false;
- }
- std::vector<Operation *> matchedOps;
- SetVector<Operation *> tempStorage;
- os << "\n";
- root->walk([&](Operation *subOp) {
- if (matcher.match(subOp)) {
- matchedOps.push_back(subOp);
- if (printMatchingOps) {
- os << "Match #" << ++matchCount << ":\n\n";
- printMatch(os, qs, subOp, "root");
- }
- } else {
- SmallVector<Operation *> printingOps;
- if (matcher.match(subOp, tempStorage, options)) {
- if (printMatchingOps) {
- os << "Match #" << ++matchCount << ":\n\n";
- }
- SmallVector<Operation *> printingOps(tempStorage.takeVector());
- for (auto op : printingOps) {
- if (printMatchingOps) {
- printMatch(os, qs, op, "root");
- }
- matchedOps.push_back(op);
- }
- printingOps.clear();
- }
- }
- });
- if (printMatchingOps) {
- os << matchCount << (matchCount == 1 ? " match.\n\n" : " matches.\n\n");
- }
- return matchedOps;
- }
-
-private:
- // Overloaded version that doesn't print the binding
- static void printMatch(llvm::raw_ostream &os, QuerySession &qs,
- mlir::Operation *op) {
- auto fileLoc = op->getLoc()->dyn_cast<FileLineColLoc>();
- SMLoc smloc = qs.getSourceManager().FindLocForLineAndColumn(
- qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn());
+  /// The outcome of a single match: the root operation and what it matched.
+ struct MatchResult {
+ MatchResult() = default;
+ MatchResult(Operation *rootOp, std::vector<Operation *> matchedOps);
- llvm::SMDiagnostic diag =
- qs.getSourceManager().GetMessage(smloc, llvm::SourceMgr::DK_Note,
+ /// Contains the root operation of the matching environment
+ Operation *rootOp = nullptr;
- "");
- diag.print("", os, true, false, true);
- }
- static void printMatch(llvm::raw_ostream &os, QuerySession &qs,
- mlir::Operation *op, const std::string &binding) {
- auto fileLoc = op->getLoc()->findInstanceOf<FileLineColLoc>();
- auto smloc = qs.getSourceManager().FindLocForLineAndColumn(
- qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn());
- qs.getSourceManager().PrintMessage(os, smloc, llvm::SourceMgr::DK_Note,
- "\"" + binding + "\" binds here");
- }
+  /// Every operation recorded for this match: the root operation itself for
+  /// a direct match, or the computed slice for a slice matcher.
+ std::vector<Operation *> matchedOps;
+ };
+  /// Walks the IR nested under `root` and collects one MatchResult (root
+  /// operation plus matched operations) for every match of `matcher`.
+ std::vector<MatchResult>
+ collectMatches(Operation *root, QueryOptions &options, DynMatcher matcher,
+ llvm::raw_ostream &os, QuerySession &qs) const;
+ /// Prints the matched operation
+ void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op) const;
+ /// Labels the matched operation with the given binding (e.g., "root") and
+ /// prints it
+ void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op,
+ const std::string &binding) const;
+ /// Flattens a vector of MatchResults into a vector of operations
+ std::vector<Operation *>
+ flattenMatchedOps(std::vector<MatchResult> &matches) const;
};
} // namespace mlir::query::matcher
diff --git a/mlir/include/mlir/Query/Matcher/MatchersInternal.h b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
index c5c24190f0e7f..e26697cdc4ae8 100644
--- a/mlir/include/mlir/Query/Matcher/MatchersInternal.h
+++ b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
@@ -21,15 +21,13 @@
#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
#define MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
+#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/IR/Operation.h"
#include "llvm/ADT/IntrusiveRefCntPtr.h"
-namespace mlir {
-namespace query {
+namespace mlir::query {
struct QueryOptions;
}
-} // namespace mlir
-
namespace mlir::query::matcher {
// Defaults to false if T has no match() method with the signature:
diff --git a/mlir/include/mlir/Query/Query.h b/mlir/include/mlir/Query/Query.h
index 77114a2a00e23..5644113ba9e18 100644
--- a/mlir/include/mlir/Query/Query.h
+++ b/mlir/include/mlir/Query/Query.h
@@ -10,6 +10,7 @@
#define MLIR_TOOLS_MLIRQUERY_QUERY_H
#include "Matcher/VariantValue.h"
+#include "mlir/Analysis/SliceAnalysis.h"
#include "llvm/ADT/IntrusiveRefCntPtr.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/LineEditor/LineEditor.h"
@@ -17,22 +18,9 @@
namespace mlir::query {
-///
-/// Options for configuring which parts of the IR are to be
-/// traversed by the matcher
-///
-struct QueryOptions {
- /// When omitBlockArguments is true, the matcher omits traversing
- /// any block arguments
- bool omitBlockArguments = false;
- /// When omitUsesFromAbove is true, the matcher omits
- /// traversing values that are captured from above.
- bool omitUsesFromAbove = true;
- /// When inclusive is true, the matcher will include the include the
- /// top level op in the slice. When inclusive is false, the matcher will
- /// not include thee top level op in the slice
- bool inclusive = true;
-};
+/// QueryOptions extends BackwardSliceOptions; additional query-specific
+/// options can be added here.
+struct QueryOptions : public BackwardSliceOptions {};
enum class QueryKind {
Invalid,
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index ff89ead59981c..c1e3942432352 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -74,7 +74,7 @@ static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v,
/// Returns a memref.subview or a tensor.extract_slice based on the type of the
/// `source`.
-static Operation *getSlice(OpBuilder &b, Location loc, Value source,
+static Operation *getSubviewOrSlice(OpBuilder &b, Location loc, Value source,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
ArrayRef<OpFoldResult> strides) {
@@ -2675,13 +2675,13 @@ SoftmaxOp::getTiledImplementation(OpBuilder &builder,
SmallVector<OpFoldResult> strides(rank, oneAttr);
SmallVector<Value> tiledOperands;
Operation *inputSlice =
- getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
+ getSubviewOrSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
if (!inputSlice) {
return emitOpError("failed to compute input slice");
}
tiledOperands.emplace_back(inputSlice->getResult(0));
Operation *outputSlice =
- getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
+ getSubviewOrSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
if (!outputSlice) {
return emitOpError("failed to compute output slice");
}
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index 4cabac185171c..c6c44260fe776 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -25,6 +25,7 @@ add_mlir_library(MLIRIR
ExtensibleDialect.cpp
IntegerSet.cpp
Location.cpp
+ Matchers.cpp
MLIRContext.cpp
ODSSupport.cpp
Operation.cpp
diff --git a/mlir/lib/IR/Matchers.cpp b/mlir/lib/IR/Matchers.cpp
new file mode 100644
index 0000000000000..055f0a17527db
--- /dev/null
+++ b/mlir/lib/IR/Matchers.cpp
@@ -0,0 +1,57 @@
+//===- Matchers.cpp - Various common matchers -------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the out-of-line parts of the matchers in Matchers.h.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Matchers.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+
+#include <algorithm>
+
+// NOTE(review): this file is built into MLIRIR but needs SliceAnalysis and the
+// mlir-query option types (via Matchers.h). MLIRAnalysis is presumably built
+// on MLIRIR, so this likely creates a dependency cycle; consider moving the
+// slice matchers into lib/Query.
+
+namespace mlir::detail {
+
+/// Computes into `backwardSlice` the backward slice of `rootOp`, keeping only
+/// operations at most `maxDepth` def-use levels away from it. Depths are
+/// tracked through a transitive filter handed to `getBackwardSlice`; when a
+/// producer is reached along several paths, the smallest depth is kept.
+/// NOTE(review): an operation already accepted into the slice is presumably
+/// not revisited, so a shorter path found afterwards does not re-expand its
+/// own producers. Values captured from above (omitUsesFromAbove == false) get
+/// no recorded depth and are treated as depth 0 — confirm this is intended.
+/// `options` is only read; the depth filter is installed on a local copy.
+/// Always returns true (the inner-matcher check happens in `match`).
+bool BackwardSliceMatcher::matches(Operation *rootOp,
+                                   llvm::SetVector<Operation *> &backwardSlice,
+                                   query::QueryOptions &options,
+                                   int64_t maxDepth) {
+  backwardSlice.clear();
+  llvm::DenseMap<Operation *, int64_t> opDepths;
+  opDepths[rootOp] = 0;
+  // Use a copy: the filter captures locals by reference and must not escape
+  // into the caller-owned `options` (it would dangle once we return).
+  query::QueryOptions sliceOptions = options;
+  sliceOptions.filter = [&](Operation *op) {
+    int64_t depth = opDepths.lookup(op);
+    if (depth > maxDepth)
+      return false;
+    // Record the depth of each producer of `op` before getBackwardSlice
+    // recurses into it and consults this filter again.
+    for (Value operand : op->getOperands()) {
+      Operation *producer = operand.getDefiningOp();
+      if (!producer)
+        producer = cast<BlockArgument>(operand).getOwner()->getParentOp();
+      if (!producer)
+        continue; // Block not attached to any operation.
+      auto [it, inserted] = opDepths.try_emplace(producer, depth + 1);
+      if (!inserted)
+        it->second = std::min(it->second, depth + 1);
+    }
+    return true;
+  };
+  getBackwardSlice(rootOp, &backwardSlice, sliceOptions);
+  return true;
+}
+} // namespace mlir::detail
diff --git a/mlir/lib/Query/Matcher/CMakeLists.txt b/mlir/lib/Query/Matcher/CMakeLists.txt
index 3adff9f99243f..629479bf7adc1 100644
--- a/mlir/lib/Query/Matcher/CMakeLists.txt
+++ b/mlir/lib/Query/Matcher/CMakeLists.txt
@@ -1,4 +1,5 @@
add_mlir_library(MLIRQueryMatcher
+ MatchFinder.cpp
Parser.cpp
RegistryManager.cpp
VariantValue.cpp
diff --git a/mlir/lib/Query/Matcher/MatchFinder.cpp b/mlir/lib/Query/Matcher/MatchFinder.cpp
new file mode 100644
index 0000000000000..b0a95660c1d59
--- /dev/null
+++ b/mlir/lib/Query/Matcher/MatchFinder.cpp
@@ -0,0 +1,72 @@
+//===- MatchFinder.cpp - MatchFinder implementation -------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the method definitions for the `MatchFinder` class
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Query/Matcher/MatchFinder.h"
+namespace mlir::query::matcher {
+
+// `ops` is taken by value and moved into place.
+MatchFinder::MatchResult::MatchResult(Operation *op,
+                                      std::vector<Operation *> ops)
+    : rootOp(op), matchedOps(std::move(ops)) {}
+
+/// Walks every operation under `root` and records one MatchResult per match.
+std::vector<MatchFinder::MatchResult>
+MatchFinder::collectMatches(Operation *root, QueryOptions &options,
+                            DynMatcher matcher, llvm::raw_ostream &os,
+                            QuerySession &qs) const {
+  std::vector<MatchResult> results;
+  // Scratch set filled by the slice-producing overload of `match`; it is
+  // emptied after every visited operation.
+  llvm::SetVector<Operation *> slice;
+  os << "\n";
+  root->walk([&](Operation *subOp) {
+    if (matcher.match(subOp)) {
+      // Direct match: the result holds just the operation itself.
+      results.emplace_back(subOp, std::vector<Operation *>{subOp});
+    } else if (matcher.match(subOp, slice, options)) {
+      // Slice match: the result holds everything the matcher collected.
+      results.emplace_back(subOp,
+                           std::vector<Operation *>(slice.begin(), slice.end()));
+    }
+    slice.clear();
+  });
+  return results;
+}
+
+/// Prints `op`'s source location as a note with a source excerpt. When the
+/// location carries no file/line/column information (e.g. an unknown
+/// location), the operation itself is printed instead of dereferencing a null
+/// location.
+void MatchFinder::printMatch(llvm::raw_ostream &os, QuerySession &qs,
+                             Operation *op) const {
+  auto fileLoc = op->getLoc()->findInstanceOf<FileLineColLoc>();
+  if (!fileLoc) {
+    os << *op << "\n";
+    return;
+  }
+  SMLoc smloc = qs.getSourceManager().FindLocForLineAndColumn(
+      qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn());
+  llvm::SMDiagnostic diag =
+      qs.getSourceManager().GetMessage(smloc, llvm::SourceMgr::DK_Note, "");
+  diag.print("", os, true, false, true);
+}
+
+/// Prints a note at `op`'s source location saying that `binding` binds there.
+/// When the location has no file/line/column information, the binding and the
+/// operation itself are printed instead of dereferencing a null location.
+void MatchFinder::printMatch(llvm::raw_ostream &os, QuerySession &qs,
+                             Operation *op, const std::string &binding) const {
+  auto fileLoc = op->getLoc()->findInstanceOf<FileLineColLoc>();
+  if (!fileLoc) {
+    os << "\"" << binding << "\" binds here: " << *op << "\n";
+    return;
+  }
+  auto smloc = qs.getSourceManager().FindLocForLineAndColumn(
+      qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn());
+  qs.getSourceManager().PrintMessage(os, smloc, llvm::SourceMgr::DK_Note,
+                                     "\"" + binding + "\" binds here");
+}
+
+/// Concatenates the matched operations of every result, preserving order.
+std::vector<Operation *>
+MatchFinder::flattenMatchedOps(std::vector<MatchResult> &matches) const {
+  std::vector<Operation *> flat;
+  for (const MatchResult &match : matches)
+    flat.insert(flat.end(), match.matchedOps.begin(), match.matchedOps.end());
+  return flat;
+}
+
+} // namespace mlir::query::matcher
\ No newline at end of file
diff --git a/mlir/lib/Query/Matcher/Parser.cpp b/mlir/lib/Query/Matcher/Parser.cpp
index 726f1188d7e4c..a82af80dbdb0c 100644
--- a/mlir/lib/Query/Matcher/Parser.cpp
+++ b/mlir/lib/Query/Matcher/Parser.cpp
@@ -157,25 +157,13 @@ class Parser::CodeTokenizer {
}
void consumeNumberLiteral(TokenInfo *result) {
- unsigned length = 1;
- if (code.size() > 1) {
- // Consume the 'x' or 'b' radix modifier, if present.
- switch (tolower(code[1])) {
- case 'x':
- case 'b':
- length = 2;
- }
- }
- while (length < code.size() && isdigit(code[length]))
- ++length;
-
- result->text = code.take_front(length);
- code = code.drop_front(length);
-
- unsigned value;
- if (!result->text.getAsInteger(0, value)) {
+ StringRef original = code;
+ unsigned value = 0;
+ if (!code.consumeInteger(0, value)) {
+ size_t numConsumed = original.size() - code.size();
+ result->text = original.take_front(numConsumed);
result->kind = TokenKind::Literal;
- result->value = static_cast<unsigned>(value);
+ result->value = static_cast<int64_t>(value);
return;
}
}
diff --git a/mlir/lib/Query/Matcher/VariantValue.cpp b/mlir/lib/Query/Matcher/VariantValue.cpp
index d4f3e4f4d594d..f2bf0f9065bbe 100644
--- a/mlir/lib/Query/Matcher/VariantValue.cpp
+++ b/mlir/lib/Query/Matcher/VariantValue.cpp
@@ -99,13 +99,10 @@ void VariantValue::reset() {
type = ValueType::Nothing;
}
-// Unsinged
+// Signed
bool VariantValue::isSigned() const { return type == ValueType::Signed; }
-int64_t VariantValue::getSigned() const {
- assert(isSigned());
- return value.Signed;
-}
+int64_t VariantValue::getSigned() const { return value.Signed; }
void VariantValue::setSigned(int64_t newValue) {
reset();
diff --git a/mlir/lib/Query/Query.cpp b/mlir/lib/Query/Query.cpp
index dd699857568d7..7082fdb0f8482 100644
--- a/mlir/lib/Query/Query.cpp
+++ b/mlir/lib/Query/Query.cpp
@@ -129,23 +129,37 @@ LogicalResult QuitQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
LogicalResult MatchQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
Operation *rootOp = qs.getRootOp();
-
+ int matchCount = 0;
QueryOptions options;
+ matcher::MatchFinder finder;
parseQueryOptions(qs, options);
- auto matches = matcher::MatchFinder().getMatches(rootOp, options,
- std::move(matcher), os, qs);
+ auto matches =
+ finder.collectMatches(rootOp, options, std::move(matcher), os, qs);
// An extract call is recognized by considering if the matcher has a name.
// TODO: Consider making the extract more explicit.
if (matcher.hasFunctionName()) {
auto functionName = matcher.getFunctionName();
+ std::vector<Operation *> flattenedMatches =
+ finder.flattenMatchedOps(matches);
Operation *function =
- extractFunction(matches, rootOp->getContext(), functionName);
+ extractFunction(flattenedMatches, rootOp->getContext(), functionName);
os << "\n" << *function << "\n\n";
function->erase();
return mlir::success();
}
+ for (auto &results : matches) {
+ os << "Match #" << ++matchCount << ":\n\n";
+ for (auto op : results.matchedOps) {
+ if (op == results.rootOp) {
+ finder.printMatch(os, qs, op, "root");
+ } else {
+ finder.printMatch(os, qs, op);
+ }
+ }
+ }
+ os << matchCount << (matchCount == 1 ? " match.\n\n" : " matches.\n\n");
return mlir::success();
}
diff --git a/mlir/lib/Query/QueryParser.cpp b/mlir/lib/Query/QueryParser.cpp
index 53e8f91e657cb..b7c6118575dc8 100644
--- a/mlir/lib/Query/QueryParser.cpp
+++ b/mlir/lib/Query/QueryParser.cpp
@@ -167,10 +167,10 @@ QueryRef QueryParser::doParse() {
.Case("#", ParsedQueryKind::Comment, /*isCompletion=*/false)
.Case("help", ParsedQueryKind::Help)
.Case("m", ParsedQueryKind::Match, /*isCompletion=*/false)
- .Case("set", ParsedQueryKind::Set)
.Case("match", ParsedQueryKind::Match)
.Case("q", ParsedQueryKind::Quit, /*IsCompletion=*/false)
.Case("quit", ParsedQueryKind::Quit)
+ .Case("set", ParsedQueryKind::Set)
.Default(ParsedQueryKind::Invalid);
switch (qKind) {
diff --git a/mlir/test/mlir-query/complex-test.mlir b/mlir/test/mlir-query/complex-test.mlir
index c5fee38327704..e0f7ee3034ed9 100644
--- a/mlir/test/mlir-query/complex-test.mlir
+++ b/mlir/test/mlir-query/complex-test.mlir
@@ -29,4 +29,4 @@ func.func @slice_use_from_above(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>)
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %0 {{\[\[.*\]\]}} : tensor<5x5xf32> into tensor<25xf32>
// CHECK: %[[C2:.*]] = arith.constant {{.*}} : index
// CHECK: %[[EXTRACTED:.*]] = tensor.extract %[[COLLAPSED]][%[[C2]]] : tensor<25xf32>
-// CHECK: %[[ADDF2:.*]] = arith.addf %[[EXTRACTED]], %[[EXTRACTED]] : f32
+// CHECK: %[[ADDF2:.*]] = arith.addf %[[EXTRACTED]], %[[EXTRACTED]] : f32
diff --git a/mlir/tools/mlir-query/mlir-query.cpp b/mlir/tools/mlir-query/mlir-query.cpp
index 91714aab33699..34fb7d1d80a8d 100644
--- a/mlir/tools/mlir-query/mlir-query.cpp
+++ b/mlir/tools/mlir-query/mlir-query.cpp
@@ -41,8 +41,6 @@ int main(int argc, char **argv) {
// Matchers registered in alphabetical order for consistency:
matcherRegistry.registerMatcher("getDefinitions", m_GetDefinitions);
matcherRegistry.registerMatcher("definedBy", m_DefinedBy);
- matcherRegistry.registerMatcher("usedBy", m_UsedBy);
- matcherRegistry.registerMatcher("getUses", m_GetUses);
matcherRegistry.registerMatcher("hasOpAttrName",
static_cast<HasOpAttrName *>(m_Attr));
matcherRegistry.registerMatcher("hasOpName", static_cast<HasOpName *>(m_Op));
More information about the Mlir-commits
mailing list