[Mlir-commits] [mlir] [mlir] MLIR-QUERY DefinitionsMatcher implementation & DAG (PR #115670)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Jan 25 05:45:23 PST 2025
https://github.com/dbudii updated https://github.com/llvm/llvm-project/pull/115670
>From 9d274f32043bd7282a23bb25959b96d8d5a33b02 Mon Sep 17 00:00:00 2001
From: Denzel-Brian Budii <denzel-brian.budii at intel.com>
Date: Sat, 25 Jan 2025 13:38:31 +0000
Subject: [PATCH 1/2] Fixed pattern matching in mlir-query test files & removed
asserts from slice-matchers
---
mlir/include/mlir/IR/Matchers.h | 4 +-
.../mlir/Query/Matcher/ExtraMatchers.h | 188 ++++++++++++++++++
mlir/include/mlir/Query/Matcher/Marshallers.h | 15 ++
mlir/include/mlir/Query/Matcher/MatchFinder.h | 52 +++--
.../mlir/Query/Matcher/MatchersInternal.h | 60 +++++-
.../include/mlir/Query/Matcher/VariantValue.h | 12 +-
mlir/include/mlir/Query/Query.h | 34 +++-
mlir/include/mlir/Query/QuerySession.h | 11 +-
mlir/lib/Query/Matcher/Parser.cpp | 36 ++++
mlir/lib/Query/Matcher/RegistryManager.cpp | 2 +
mlir/lib/Query/Matcher/VariantValue.cpp | 24 +++
mlir/lib/Query/Query.cpp | 37 ++--
mlir/lib/Query/QueryParser.cpp | 52 ++++-
mlir/lib/Query/QueryParser.h | 2 +-
mlir/test/mlir-query/complex-test.mlir | 39 ++++
mlir/test/mlir-query/function-extraction.mlir | 2 +-
mlir/tools/mlir-query/mlir-query.cpp | 10 +
17 files changed, 529 insertions(+), 51 deletions(-)
create mode 100644 mlir/include/mlir/Query/Matcher/ExtraMatchers.h
create mode 100644 mlir/test/mlir-query/complex-test.mlir
diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index 1dce055db1b4a7..2204a68be26b10 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -59,7 +59,7 @@ struct NameOpMatcher {
NameOpMatcher(StringRef name) : name(name) {}
bool match(Operation *op) { return op->getName().getStringRef() == name; }
- StringRef name;
+ std::string name;
};
/// The matcher that matches operations that have the specified attribute name.
@@ -67,7 +67,7 @@ struct AttrOpMatcher {
AttrOpMatcher(StringRef attrName) : attrName(attrName) {}
bool match(Operation *op) { return op->hasAttr(attrName); }
- StringRef attrName;
+ std::string attrName;
};
/// The matcher that matches operations that have the `ConstantLike` trait, and
diff --git a/mlir/include/mlir/Query/Matcher/ExtraMatchers.h b/mlir/include/mlir/Query/Matcher/ExtraMatchers.h
new file mode 100644
index 00000000000000..908fccfc704c33
--- /dev/null
+++ b/mlir/include/mlir/Query/Matcher/ExtraMatchers.h
@@ -0,0 +1,188 @@
+//===- ExtraMatchers.h - Various common matchers ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file provides extra matchers that are very useful for mlir-query.
+// Each one wraps an inner matcher and, for every operation it accepts,
+// collects a slice of related operations:
+//   - getDefinitions / definedBy: operations that (transitively) define the
+//     operands of the matched operation (backward slice);
+//   - getUses / usedBy: operations that (transitively) use its results or are
+//     nested in its regions (forward slice).
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_QUERY_MATCHER_EXTRAMATCHERS_H
+#define MLIR_QUERY_MATCHER_EXTRAMATCHERS_H
+
+#include "MatchFinder.h"
+#include "MatchersInternal.h"
+#include "mlir/IR/Region.h"
+#include "mlir/Query/Query.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace mlir {
+
+namespace query {
+
+namespace extramatcher {
+
+namespace detail {
+
+/// Matches operations accepted by `innerMatcher` and collects their backward
+/// slice: the operations defining their operands, followed transitively for
+/// up to `hops` levels.
+class BackwardSliceMatcher {
+public:
+  BackwardSliceMatcher(matcher::DynMatcher &&innerMatcher, unsigned hops)
+      : innerMatcher(std::move(innerMatcher)), hops(hops) {}
+
+private:
+  /// Recursively adds the definitions of `op` (up to `tempHops` further
+  /// levels) and then `op` itself to `backwardSlice`. Returns false, without
+  /// recording anything, when `op` is isolated from above.
+  bool matches(Operation *op, SetVector<Operation *> &backwardSlice,
+               QueryOptions &options, unsigned tempHops) {
+    if (op->hasTrait<OpTrait::IsIsolatedFromAbove>())
+      return false;
+
+    auto processValue = [&](Value value) {
+      // No hops left: the traversal stops at the current operation.
+      if (tempHops == 0)
+        return;
+      if (auto *definingOp = value.getDefiningOp()) {
+        if (backwardSlice.count(definingOp) == 0)
+          matches(definingOp, backwardSlice, options, tempHops - 1);
+      } else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
+        if (options.omitBlockArguments)
+          return;
+        Block *block = blockArg.getOwner();
+        Operation *parentOp = block->getParentOp();
+
+        if (parentOp && backwardSlice.count(parentOp) == 0) {
+          // Following a block argument to its parent op is only well defined
+          // for single-region, single-block parents; report anything else
+          // (either condition failing is unexpected) but keep going.
+          if (parentOp->getNumRegions() != 1 ||
+              parentOp->getRegion(0).getBlocks().size() != 1) {
+            llvm::errs()
+                << "Error: Expected parentOp to have exactly one region and "
+                << "exactly one block, but found " << parentOp->getNumRegions()
+                << " regions and "
+                << (parentOp->getRegion(0).getBlocks().size()) << " blocks.\n";
+          }
+          matches(parentOp, backwardSlice, options, tempHops - 1);
+        }
+      } else {
+        llvm::errs() << "No definingOp and not a block argument\n";
+        return;
+      }
+    };
+
+    if (!options.omitUsesFromAbove) {
+      llvm::for_each(op->getRegions(), [&](Region &region) {
+        // Collect the regions reached by walking `region`; operands whose
+        // parent region is not among them are uses from above.
+        SmallPtrSet<Region *, 4> descendents;
+        region.walk(
+            [&](Region *childRegion) { descendents.insert(childRegion); });
+        region.walk([&](Operation *nestedOp) {
+          for (OpOperand &operand : nestedOp->getOpOperands()) {
+            if (!descendents.contains(operand.get().getParentRegion()))
+              processValue(operand.get());
+          }
+        });
+      });
+    }
+
+    llvm::for_each(op->getOperands(), processValue);
+    backwardSlice.insert(op);
+    return true;
+  }
+
+public:
+  /// Returns true when `op` is accepted by the inner matcher and is not
+  /// isolated from above; its backward slice is then appended to
+  /// `backwardSlice` (`op` itself is kept only if `options.inclusive`).
+  bool match(Operation *op, SetVector<Operation *> &backwardSlice,
+             QueryOptions &options) {
+    if (!innerMatcher.match(op) || !matches(op, backwardSlice, options, hops))
+      return false;
+    if (!options.inclusive)
+      backwardSlice.remove(op);
+    return true;
+  }
+
+private:
+  matcher::DynMatcher innerMatcher;
+  unsigned hops;
+};
+
+/// Matches operations accepted by `innerMatcher` and collects their forward
+/// slice: the users of their results and the operations nested in their
+/// regions, followed transitively for up to `hops` levels.
+class ForwardSliceMatcher {
+public:
+  ForwardSliceMatcher(matcher::DynMatcher &&innerMatcher, unsigned hops)
+      : innerMatcher(std::move(innerMatcher)), hops(hops) {}
+
+private:
+  /// Recursively adds everything reachable from `op` within `tempHops`
+  /// further levels, then `op` itself, to `forwardSlice` (post-order).
+  bool matches(Operation *op, SetVector<Operation *> &forwardSlice,
+               QueryOptions &options, unsigned tempHops) {
+    if (tempHops == 0) {
+      forwardSlice.insert(op);
+      return true;
+    }
+
+    for (Region &region : op->getRegions())
+      for (Block &block : region)
+        for (Operation &blockOp : block)
+          if (forwardSlice.count(&blockOp) == 0)
+            matches(&blockOp, forwardSlice, options, tempHops - 1);
+    for (Value result : op->getResults()) {
+      for (Operation *userOp : result.getUsers())
+        if (forwardSlice.count(userOp) == 0)
+          matches(userOp, forwardSlice, options, tempHops - 1);
+    }
+
+    forwardSlice.insert(op);
+    return true;
+  }
+
+public:
+  /// Returns true when `op` is accepted by the inner matcher; its forward
+  /// slice is then appended to `forwardSlice` (`op` itself is kept only if
+  /// `options.inclusive`).
+  bool match(Operation *op, SetVector<Operation *> &forwardSlice,
+             QueryOptions &options) {
+    if (!innerMatcher.match(op) || !matches(op, forwardSlice, options, hops))
+      return false;
+    if (!options.inclusive)
+      forwardSlice.remove(op);
+    // `matches` records operations in post-order (reached ops before the op
+    // that reached them); reverse so that `op` precedes its uses.
+    SmallVector<Operation *, 0> postOrder(forwardSlice.takeVector());
+    forwardSlice.insert(postOrder.rbegin(), postOrder.rend());
+    return true;
+  }
+
+private:
+  matcher::DynMatcher innerMatcher;
+  unsigned hops;
+};
+
+} // namespace detail
+
+/// Backward slice of the operations matched by `innerMatcher`, one hop deep.
+inline detail::BackwardSliceMatcher
+definedBy(mlir::query::matcher::DynMatcher innerMatcher) {
+  return detail::BackwardSliceMatcher(std::move(innerMatcher), 1);
+}
+
+/// Backward slice of the operations matched by `innerMatcher`, `hops` deep.
+inline detail::BackwardSliceMatcher
+getDefinitions(mlir::query::matcher::DynMatcher innerMatcher, unsigned hops) {
+  return detail::BackwardSliceMatcher(std::move(innerMatcher), hops);
+}
+
+/// Forward slice of the operations matched by `innerMatcher`, one hop deep.
+inline detail::ForwardSliceMatcher
+usedBy(mlir::query::matcher::DynMatcher innerMatcher) {
+  return detail::ForwardSliceMatcher(std::move(innerMatcher), 1);
+}
+
+/// Forward slice of the operations matched by `innerMatcher`, `hops` deep.
+inline detail::ForwardSliceMatcher
+getUses(mlir::query::matcher::DynMatcher innerMatcher, unsigned hops) {
+  return detail::ForwardSliceMatcher(std::move(innerMatcher), hops);
+}
+
+} // namespace extramatcher
+
+} // namespace query
+
+} // namespace mlir
+
+#endif // MLIR_QUERY_MATCHER_EXTRAMATCHERS_H
diff --git a/mlir/include/mlir/Query/Matcher/Marshallers.h b/mlir/include/mlir/Query/Matcher/Marshallers.h
index 6ed35ac0ddccc7..c775dbc5c86da0 100644
--- a/mlir/include/mlir/Query/Matcher/Marshallers.h
+++ b/mlir/include/mlir/Query/Matcher/Marshallers.h
@@ -50,6 +50,21 @@ struct ArgTypeTraits<llvm::StringRef> {
}
};
+// Argument traits for unsigned integer arguments (such as a hop count).
+// No "did you mean" suggestion is offered for a mistyped argument.
+template <>
+struct ArgTypeTraits<unsigned> {
+  static ArgKind getKind() { return ArgKind::Unsigned; }
+
+  static bool hasCorrectType(const VariantValue &value) {
+    return value.isUnsigned();
+  }
+
+  static unsigned get(const VariantValue &value) {
+    return value.getUnsigned();
+  }
+
+  static std::optional<std::string>
+  getBestGuess(const VariantValue & /*value*/) {
+    return {};
+  }
+};
+
template <>
struct ArgTypeTraits<DynMatcher> {
diff --git a/mlir/include/mlir/Query/Matcher/MatchFinder.h b/mlir/include/mlir/Query/Matcher/MatchFinder.h
index b008a21f53ae2a..1b9d3bc307ff50 100644
--- a/mlir/include/mlir/Query/Matcher/MatchFinder.h
+++ b/mlir/include/mlir/Query/Matcher/MatchFinder.h
@@ -7,7 +7,7 @@
//===----------------------------------------------------------------------===//
//
// This file contains the MatchFinder class, which is used to find operations
-// that match a given matcher.
+// that match a given matcher and print them.
//
//===----------------------------------------------------------------------===//
@@ -15,24 +15,52 @@
#define MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H
#include "MatchersInternal.h"
+#include "mlir/Query/QuerySession.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/raw_ostream.h"
namespace mlir::query::matcher {
-// MatchFinder is used to find all operations that match a given matcher.
class MatchFinder {
-public:
- // Returns all operations that match the given matcher.
- static std::vector<Operation *> getMatches(Operation *root,
- DynMatcher matcher) {
- std::vector<Operation *> matches;
+private:
+  // Prints a note pointing at the source location of `op`. If `op` carries
+  // no FileLineColLoc (or the location is not found in the query buffer),
+  // the note is printed without a source snippet instead of dereferencing a
+  // null location.
+  static void printMatch(llvm::raw_ostream &os, QuerySession &qs,
+                         mlir::Operation *op, const std::string &binding) {
+    llvm::SMLoc smloc;
+    if (auto fileLoc = op->getLoc()->findInstanceOf<FileLineColLoc>())
+      smloc = qs.getSourceManager().FindLocForLineAndColumn(
+          qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn());
+    qs.getSourceManager().PrintMessage(os, smloc, llvm::SourceMgr::DK_Note,
+                                       "\"" + binding + "\" binds here");
+  }
- // Simple match finding with walk.
+public:
+ static std::vector<Operation *>
+ getMatches(Operation *root, QueryOptions &options, DynMatcher matcher,
+ llvm::raw_ostream &os, QuerySession &qs) {
+ unsigned matchCount = 0;
+ std::vector<Operation *> matchedOps;
+ SetVector<Operation *> tempStorage;
+ os << "\n";
root->walk([&](Operation *subOp) {
- if (matcher.match(subOp))
- matches.push_back(subOp);
+ if (matcher.match(subOp)) {
+ matchedOps.push_back(subOp);
+ os << "Match #" << ++matchCount << ":\n\n";
+ printMatch(os, qs, subOp, "root");
+ } else {
+ SmallVector<Operation *> printingOps;
+ if (matcher.match(subOp, tempStorage, options)) {
+ os << "Match #" << ++matchCount << ":\n\n";
+ SmallVector<Operation *> printingOps(tempStorage.takeVector());
+ for (auto op : printingOps) {
+ printMatch(os, qs, op, "root");
+ matchedOps.push_back(op);
+ }
+ printingOps.clear();
+ }
+ }
});
-
- return matches;
+ os << matchCount << (matchCount == 1 ? " match.\n\n" : " matches.\n\n");
+ return matchedOps;
}
};
diff --git a/mlir/include/mlir/Query/Matcher/MatchersInternal.h b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
index 117f7d4edef9e3..b532b47be7d051 100644
--- a/mlir/include/mlir/Query/Matcher/MatchersInternal.h
+++ b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
@@ -1,4 +1,3 @@
-//===- MatchersInternal.h - Structural query framework ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -10,28 +9,53 @@
//
// Matchers are methods that return a Matcher which provides a method
// match(Operation *op)
+// match(Operation *op, SetVector<Operation *> &matchedOps, QueryOptions
+// &options)
//
// The matcher functions are defined in include/mlir/IR/Matchers.h.
// This file contains the wrapper classes needed to construct matchers for
// mlir-query.
//
//===----------------------------------------------------------------------===//
-
#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
#define MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
#include "mlir/IR/Matchers.h"
#include "llvm/ADT/IntrusiveRefCntPtr.h"
+namespace mlir {
+namespace query {
+struct QueryOptions;
+}
+} // namespace mlir
+
namespace mlir::query::matcher {
+// Detects whether matcher type T provides the simple form
+// `match(Operation *)`. MatcherFnImpl forwards to it only when present and
+// otherwise reports no match for that overload.
+template <typename T, typename = void>
+struct has_simple_match : std::false_type {};
+
+template <typename T>
+struct has_simple_match<T, std::void_t<decltype(std::declval<T>().match(
+ std::declval<Operation *>()))>>
+ : std::true_type {};
+
+// Detects whether matcher type T provides the slice-collecting form
+// `match(Operation *, SetVector<Operation *> &, QueryOptions &)`, as
+// implemented by the slice matchers. MatcherFnImpl forwards to it only when
+// present and otherwise reports no match for that overload.
+template <typename T, typename = void>
+struct has_bound_match : std::false_type {};
+
+template <typename T>
+struct has_bound_match<T, std::void_t<decltype(std::declval<T>().match(
+ std::declval<Operation *>(),
+ std::declval<SetVector<Operation *> &>(),
+ std::declval<QueryOptions &>()))>>
+ : std::true_type {};
// Generic interface for matchers on an MLIR operation.
class MatcherInterface
: public llvm::ThreadSafeRefCountedBase<MatcherInterface> {
public:
virtual ~MatcherInterface() = default;
-
virtual bool match(Operation *op) = 0;
+ virtual bool match(Operation *op, SetVector<Operation *> &matchedOps,
+ QueryOptions &options) = 0;
};
// MatcherFnImpl takes a matcher function object and implements
@@ -40,14 +64,26 @@ template <typename MatcherFn>
class MatcherFnImpl : public MatcherInterface {
public:
MatcherFnImpl(MatcherFn &matcherFn) : matcherFn(matcherFn) {}
- bool match(Operation *op) override { return matcherFn.match(op); }
+
+ bool match(Operation *op) override {
+ if constexpr (has_simple_match<MatcherFn>::value)
+ return matcherFn.match(op);
+ return false;
+ }
+
+ bool match(Operation *op, SetVector<Operation *> &matchedOps,
+ QueryOptions &options) override {
+ if constexpr (has_bound_match<MatcherFn>::value)
+ return matcherFn.match(op, matchedOps, options);
+ return false;
+ }
private:
MatcherFn matcherFn;
};
-// Matcher wraps a MatcherInterface implementation and provides a match()
-// method that redirects calls to the underlying implementation.
+// Matcher wraps a MatcherInterface implementation and provides match()
+// methods that redirect calls to the underlying implementation.
class DynMatcher {
public:
// Takes ownership of the provided implementation pointer.
@@ -62,12 +98,14 @@ class DynMatcher {
}
bool match(Operation *op) const { return implementation->match(op); }
+ bool match(Operation *op, SetVector<Operation *> &matchedOps,
+ QueryOptions &options) const {
+ return implementation->match(op, matchedOps, options);
+ }
- void setFunctionName(StringRef name) { functionName = name.str(); };
-
- bool hasFunctionName() const { return !functionName.empty(); };
-
- StringRef getFunctionName() const { return functionName; };
+ void setFunctionName(StringRef name) { functionName = name.str(); }
+ bool hasFunctionName() const { return !functionName.empty(); }
+ StringRef getFunctionName() const { return functionName; }
private:
llvm::IntrusiveRefCntPtr<MatcherInterface> implementation;
diff --git a/mlir/include/mlir/Query/Matcher/VariantValue.h b/mlir/include/mlir/Query/Matcher/VariantValue.h
index 449f8b3a01e021..6b57119df7a9bf 100644
--- a/mlir/include/mlir/Query/Matcher/VariantValue.h
+++ b/mlir/include/mlir/Query/Matcher/VariantValue.h
@@ -21,7 +21,7 @@
namespace mlir::query::matcher {
// All types that VariantValue can contain.
-enum class ArgKind { Matcher, String };
+enum class ArgKind { Matcher, String, Unsigned };
// A variant matcher object to abstract simple and complex matchers into a
// single object type.
@@ -81,6 +81,7 @@ class VariantValue {
// Specific constructors for each supported type.
VariantValue(const llvm::StringRef string);
VariantValue(const VariantMatcher &matcher);
+ VariantValue(unsigned Unsigned);
// String value functions.
bool isString() const;
@@ -92,8 +93,15 @@ class VariantValue {
const VariantMatcher &getMatcher() const;
void setMatcher(const VariantMatcher &matcher);
+ // Unsigned value functions.
+ bool isUnsigned() const;
+ unsigned getUnsigned() const;
+ void setUnsigned(unsigned Unsigned);
+
// String representation of the type of the value.
std::string getTypeAsString() const;
+ explicit operator bool() const { return hasValue(); }
+ bool hasValue() const { return type != ValueType::Nothing; }
private:
void reset();
@@ -103,12 +111,14 @@ class VariantValue {
Nothing,
String,
Matcher,
+ Unsigned,
};
// All supported value types.
union AllValues {
llvm::StringRef *String;
VariantMatcher *Matcher;
+ unsigned Unsigned;
};
ValueType type;
diff --git a/mlir/include/mlir/Query/Query.h b/mlir/include/mlir/Query/Query.h
index 18f2172c9510a3..bb5b98432d51c2 100644
--- a/mlir/include/mlir/Query/Query.h
+++ b/mlir/include/mlir/Query/Query.h
@@ -17,7 +17,13 @@
namespace mlir::query {
-enum class QueryKind { Invalid, NoOp, Help, Match, Quit };
+/// Options that steer the slice matchers; filled from the session's `set`
+/// variables before a match query runs.
+struct QueryOptions {
+  /// When set, block arguments are not followed to their parent operation.
+  bool omitBlockArguments{false};
+  /// When set, values used inside nested regions but defined above them are
+  /// not followed.
+  bool omitUsesFromAbove{true};
+  /// When set, the matched root operation is kept in the reported slice.
+  bool inclusive{true};
+};
+
+enum class QueryKind {
+  Invalid,
+  NoOp,
+  Help,
+  Match,
+  Quit,
+  SetBool,
+};
class QuerySession;
@@ -103,6 +109,32 @@ struct MatchQuery : Query {
}
};
+/// Maps the value type of a `set` variable to the QueryKind of its query.
+template <typename T>
+struct SetQueryKind {};
+
+template <>
+struct SetQueryKind<bool> {
+  static const QueryKind value = QueryKind::SetBool;
+};
+
+/// Query that assigns a fixed value to one QuerySession member, selected by
+/// pointer-to-member.
+template <typename T>
+struct SetQuery : Query {
+  T QuerySession::*var;
+  T value;
+
+  SetQuery(T QuerySession::*var, T value)
+      : Query(SetQueryKind<T>::value), var(var), value(value) {}
+
+  llvm::LogicalResult run(llvm::raw_ostream &os,
+                          QuerySession &qs) const override {
+    // Write through the member pointer on the live session.
+    (qs.*var) = value;
+    return mlir::success();
+  }
+
+  static bool classof(const Query *query) {
+    return query->kind == SetQueryKind<T>::value;
+  }
+};
+
} // namespace mlir::query
#endif
diff --git a/mlir/include/mlir/Query/QuerySession.h b/mlir/include/mlir/Query/QuerySession.h
index fe552d750fc771..495358e8f36f94 100644
--- a/mlir/include/mlir/Query/QuerySession.h
+++ b/mlir/include/mlir/Query/QuerySession.h
@@ -9,14 +9,18 @@
#ifndef MLIR_TOOLS_MLIRQUERY_QUERYSESSION_H
#define MLIR_TOOLS_MLIRQUERY_QUERYSESSION_H
+#include "Matcher/VariantValue.h"
#include "mlir/IR/Operation.h"
#include "mlir/Query/Matcher/Registry.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/SourceMgr.h"
+namespace mlir::query::matcher {
+class Registry;
+}
+
namespace mlir::query {
-class Registry;
// Represents the state for a particular mlir-query session.
class QuerySession {
public:
@@ -33,6 +37,11 @@ class QuerySession {
llvm::StringMap<matcher::VariantValue> namedValues;
bool terminate = false;
+public:
+ bool omitBlockArguments = false;
+ bool omitUsesFromAbove = true;
+ bool inclusive = true;
+
private:
Operation *rootOp;
llvm::SourceMgr &sourceMgr;
diff --git a/mlir/lib/Query/Matcher/Parser.cpp b/mlir/lib/Query/Matcher/Parser.cpp
index 3609e24f9939f7..4dcb86a9383f32 100644
--- a/mlir/lib/Query/Matcher/Parser.cpp
+++ b/mlir/lib/Query/Matcher/Parser.cpp
@@ -135,6 +135,18 @@ class Parser::CodeTokenizer {
case '\'':
consumeStringLiteral(&result);
break;
+ case '0':
+ case '1':
+ case '2':
+ case '3':
+ case '4':
+ case '5':
+ case '6':
+ case '7':
+ case '8':
+ case '9':
+ consumeNumberLiteral(&result);
+ break;
default:
parseIdentifierOrInvalid(&result);
break;
@@ -144,6 +156,30 @@ class Parser::CodeTokenizer {
return result;
}
+  // Consume an unsigned integer literal. The literal spans every following
+  // alphanumeric character, so radix prefixes ("0x", "0b") and hexadecimal
+  // digits stay in one token and getAsInteger validates the whole spelling.
+  // Malformed or out-of-range literals (e.g. "12ab", "1x5") become an
+  // InvalidChar token instead of leaving the token kind unassigned.
+  void consumeNumberLiteral(TokenInfo *result) {
+    llvm::StringRef literal = code.take_while(
+        [](char c) { return isalnum(static_cast<unsigned char>(c)) != 0; });
+    result->text = literal;
+    code = code.drop_front(literal.size());
+
+    // Radix 0 lets getAsInteger auto-detect a radix prefix; it returns true
+    // on failure, including overflow of `unsigned`.
+    unsigned value;
+    if (!literal.getAsInteger(0, value)) {
+      result->kind = TokenKind::Literal;
+      result->value = value;
+      return;
+    }
+    result->kind = TokenKind::InvalidChar;
+  }
+
// Consume a string literal, handle escape sequences and missing closing
// quote.
void consumeStringLiteral(TokenInfo *result) {
diff --git a/mlir/lib/Query/Matcher/RegistryManager.cpp b/mlir/lib/Query/Matcher/RegistryManager.cpp
index 645db7109c2deb..8d6c0135aa1176 100644
--- a/mlir/lib/Query/Matcher/RegistryManager.cpp
+++ b/mlir/lib/Query/Matcher/RegistryManager.cpp
@@ -31,6 +31,8 @@ static std::string asArgString(ArgKind kind) {
return "Matcher";
case ArgKind::String:
return "String";
+ case ArgKind::Unsigned:
+ return "unsigned";
}
llvm_unreachable("Unhandled ArgKind");
}
diff --git a/mlir/lib/Query/Matcher/VariantValue.cpp b/mlir/lib/Query/Matcher/VariantValue.cpp
index 65bd4bd77bcf8a..d5218d8dad8c93 100644
--- a/mlir/lib/Query/Matcher/VariantValue.cpp
+++ b/mlir/lib/Query/Matcher/VariantValue.cpp
@@ -56,6 +56,10 @@ VariantValue::VariantValue(const VariantMatcher &matcher)
value.Matcher = new VariantMatcher(matcher);
}
+// Constructs a value holding an unsigned integer; only the integer member of
+// the union is active for this kind.
+VariantValue::VariantValue(unsigned newValue) : type(ValueType::Unsigned) {
+  value.Unsigned = newValue;
+}
+
VariantValue::~VariantValue() { reset(); }
VariantValue &VariantValue::operator=(const VariantValue &other) {
@@ -69,6 +73,9 @@ VariantValue &VariantValue::operator=(const VariantValue &other) {
case ValueType::Matcher:
setMatcher(other.getMatcher());
break;
+ case ValueType::Unsigned:
+ setUnsigned(other.getUnsigned());
+ break;
case ValueType::Nothing:
type = ValueType::Nothing;
break;
@@ -85,12 +92,27 @@ void VariantValue::reset() {
delete value.Matcher;
break;
// Cases that do nothing.
+ case ValueType::Unsigned:
case ValueType::Nothing:
break;
}
type = ValueType::Nothing;
}
+// Unsigned value functions.
+bool VariantValue::isUnsigned() const {
+  return type == ValueType::Unsigned;
+}
+
+unsigned VariantValue::getUnsigned() const {
+  // Callers are expected to check isUnsigned() first.
+  assert(isUnsigned());
+  return value.Unsigned;
+}
+
+void VariantValue::setUnsigned(unsigned newValue) {
+  // reset() frees any heap-allocated payload and marks the value as Nothing.
+  reset();
+  value.Unsigned = newValue;
+  type = ValueType::Unsigned;
+}
+
bool VariantValue::isString() const { return type == ValueType::String; }
const llvm::StringRef &VariantValue::getString() const {
@@ -123,6 +145,8 @@ std::string VariantValue::getTypeAsString() const {
return "String";
case ValueType::Matcher:
return "Matcher";
+ case ValueType::Unsigned:
+ return "Unsigned";
case ValueType::Nothing:
return "Nothing";
}
diff --git a/mlir/lib/Query/Query.cpp b/mlir/lib/Query/Query.cpp
index 7d9f3606700519..dd699857568d7e 100644
--- a/mlir/lib/Query/Query.cpp
+++ b/mlir/lib/Query/Query.cpp
@@ -12,6 +12,7 @@
#include "mlir/IR/IRMapping.h"
#include "mlir/Query/Matcher/MatchFinder.h"
#include "mlir/Query/QuerySession.h"
+#include "llvm/ADT/SetVector.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
@@ -26,15 +27,6 @@ complete(llvm::StringRef line, size_t pos, const QuerySession &qs) {
return QueryParser::complete(line, pos, qs);
}
-static void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op,
- const std::string &binding) {
- auto fileLoc = op->getLoc()->findInstanceOf<FileLineColLoc>();
- auto smloc = qs.getSourceManager().FindLocForLineAndColumn(
- qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn());
- qs.getSourceManager().PrintMessage(os, smloc, llvm::SourceMgr::DK_Note,
- "\"" + binding + "\" binds here");
-}
-
// TODO: Extract into a helper function that can be reused outside query
// context.
static Operation *extractFunction(std::vector<Operation *> &ops,
@@ -99,6 +91,12 @@ static Operation *extractFunction(std::vector<Operation *> &ops,
return funcOp;
}
+// Copies the session's `set` variables into the options handed to matchers.
+static void parseQueryOptions(QuerySession &qs, QueryOptions &options) {
+  options = QueryOptions{qs.omitBlockArguments, qs.omitUsesFromAbove,
+                         qs.inclusive};
+}
+
Query::~Query() = default;
LogicalResult InvalidQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
@@ -114,6 +112,11 @@ LogicalResult HelpQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
os << "Available commands:\n\n"
" match MATCHER, m MATCHER "
"Match the mlir against the given matcher.\n"
+ "Set query options, useful for complex matchers \n"
+ " set omitBlockArguments (true|false) \n"
+ " set omitUsesFromAbove (true|false) \n"
+ " set inclusive (true|false) \n"
+ "Give a matcher expression a name, to be used later\n"
" quit "
"Terminates the query session.\n\n";
return mlir::success();
@@ -126,9 +129,11 @@ LogicalResult QuitQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
LogicalResult MatchQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
Operation *rootOp = qs.getRootOp();
- int matchCount = 0;
- std::vector<Operation *> matches =
- matcher::MatchFinder().getMatches(rootOp, matcher);
+
+ QueryOptions options;
+ parseQueryOptions(qs, options);
+ auto matches = matcher::MatchFinder().getMatches(rootOp, options,
+ std::move(matcher), os, qs);
// An extract call is recognized by considering if the matcher has a name.
// TODO: Consider making the extract more explicit.
@@ -141,14 +146,6 @@ LogicalResult MatchQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
return mlir::success();
}
- os << "\n";
- for (Operation *op : matches) {
- os << "Match #" << ++matchCount << ":\n\n";
- // Placeholder "root" binding for the initial draft.
- printMatch(os, qs, op, "root");
- }
- os << matchCount << (matchCount == 1 ? " match.\n\n" : " matches.\n\n");
-
return mlir::success();
}
diff --git a/mlir/lib/Query/QueryParser.cpp b/mlir/lib/Query/QueryParser.cpp
index 31aead7d403d0d..7aaf4847f2e47b 100644
--- a/mlir/lib/Query/QueryParser.cpp
+++ b/mlir/lib/Query/QueryParser.cpp
@@ -107,6 +107,18 @@ QueryRef QueryParser::endQuery(QueryRef queryRef) {
return queryRef;
}
+// Parses the `true`/`false` operand of a `set` command and builds a query
+// that stores it into the session member `var`.
+QueryRef QueryParser::parseSetBool(bool QuerySession::*var) {
+  // Sentinel for "neither true nor false".
+  constexpr unsigned notABool = ~0u;
+  StringRef valStr;
+  unsigned parsed = LexOrCompleteWord<unsigned>(this, valStr)
+                        .Case("false", 0)
+                        .Case("true", 1)
+                        .Default(notABool);
+  if (parsed != notABool)
+    return new SetQuery<bool>(var, parsed);
+  return new InvalidQuery("expected 'true' or 'false', got '" + valStr + "'");
+}
+
namespace {
enum class ParsedQueryKind {
@@ -116,6 +128,14 @@ enum class ParsedQueryKind {
Help,
Match,
Quit,
+ Set,
+};
+
+enum ParsedQueryVariable {
+ Invalid,
+ OmitBlockArguments,
+ OmitUsesFromAbove,
+ Inclusive,
};
QueryRef
@@ -147,6 +167,7 @@ QueryRef QueryParser::doParse() {
.Case("#", ParsedQueryKind::Comment, /*isCompletion=*/false)
.Case("help", ParsedQueryKind::Help)
.Case("m", ParsedQueryKind::Match, /*isCompletion=*/false)
+ .Case("set", ParsedQueryKind::Set)
.Case("match", ParsedQueryKind::Match)
.Case("q", ParsedQueryKind::Quit, /*IsCompletion=*/false)
.Case("quit", ParsedQueryKind::Quit)
@@ -187,7 +208,36 @@ QueryRef QueryParser::doParse() {
query->remainingContent = matcherSource;
return query;
}
-
+ case ParsedQueryKind::Set: {
+ llvm::StringRef varStr;
+ ParsedQueryVariable var =
+ LexOrCompleteWord<ParsedQueryVariable>(this, varStr)
+ .Case("omitBlockArguments", ParsedQueryVariable::OmitBlockArguments)
+ .Case("omitUsesFromAbove", ParsedQueryVariable::OmitUsesFromAbove)
+ .Case("inclusive", ParsedQueryVariable::Inclusive)
+ .Default(ParsedQueryVariable::Invalid);
+ if (varStr.empty()) {
+ return new InvalidQuery("expected variable name");
+ }
+ if (var == ParsedQueryVariable::Invalid) {
+ return new InvalidQuery("unknown variable: '" + varStr + "'");
+ }
+ QueryRef query;
+ switch (var) {
+ case ParsedQueryVariable::OmitBlockArguments:
+ query = parseSetBool(&QuerySession::omitBlockArguments);
+ break;
+ case ParsedQueryVariable::OmitUsesFromAbove:
+ query = parseSetBool(&QuerySession::omitUsesFromAbove);
+ break;
+ case ParsedQueryVariable::Inclusive:
+ query = parseSetBool(&QuerySession::inclusive);
+ break;
+ case ParsedQueryVariable::Invalid:
+ llvm_unreachable("Invalid query kind");
+ }
+ return endQuery(query);
+ }
case ParsedQueryKind::Invalid:
return new InvalidQuery("unknown command: " + commandStr);
}
diff --git a/mlir/lib/Query/QueryParser.h b/mlir/lib/Query/QueryParser.h
index e9c30eccecab9e..69cc5d0043d57d 100644
--- a/mlir/lib/Query/QueryParser.h
+++ b/mlir/lib/Query/QueryParser.h
@@ -39,8 +39,8 @@ class QueryParser {
struct LexOrCompleteWord;
QueryRef completeMatcherExpression();
-
QueryRef endQuery(QueryRef queryRef);
+ QueryRef parseSetBool(bool QuerySession::*var);
// Parse [begin, end) and returns a reference to the parsed query object,
// which may be an InvalidQuery if a parse error occurs.
diff --git a/mlir/test/mlir-query/complex-test.mlir b/mlir/test/mlir-query/complex-test.mlir
new file mode 100644
index 00000000000000..b3df534ee88713
--- /dev/null
+++ b/mlir/test/mlir-query/complex-test.mlir
@@ -0,0 +1,39 @@
+// RUN: mlir-query %s -c "match getDefinitions(hasOpName(\"arith.addf\"),2)" | FileCheck %s
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @slice_use_from_above(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>) {
+ %0 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %2 = arith.addf %in, %in : f32
+ linalg.yield %2 : f32
+ } -> tensor<5x5xf32>
+ %collapsed = tensor.collapse_shape %0 [[0, 1]] : tensor<5x5xf32> into tensor<25xf32>
+ %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %c2 = arith.constant 2 : index
+ %extracted = tensor.extract %collapsed[%c2] : tensor<25xf32>
+ %2 = arith.addf %extracted, %extracted : f32
+ linalg.yield %2 : f32
+ } -> tensor<5x5xf32>
+ return
+}
+
+// CHECK: Match #1:
+
+// CHECK: %[[LINALG:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>)
+
+// CHECK: %[[ADDF1:.*]] = arith.addf %in, %in : f32
+
+
+// CHECK: Match #2:
+
+// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %0 {{\[\[.*\]\]}} : tensor<5x5xf32> into tensor<25xf32>
+// CHECK: %[[EXTRACTED:.*]] = tensor.extract %[[COLLAPSED]][%c2] : tensor<25xf32>
+// CHECK: %[[ADDF2:.*]] = arith.addf %[[EXTRACTED]], %[[EXTRACTED]] : f32
+
+
+
+
+
+
diff --git a/mlir/test/mlir-query/function-extraction.mlir b/mlir/test/mlir-query/function-extraction.mlir
index a783f65c6761bc..d7a867eb1a4525 100644
--- a/mlir/test/mlir-query/function-extraction.mlir
+++ b/mlir/test/mlir-query/function-extraction.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-query %s -c "m hasOpName(\"arith.mulf\").extract(\"testmul\")" | FileCheck %s
+// RUN: mlir-query %s -c "m hasOpName("arith.mulf").extract("testmul")" | FileCheck %s
// CHECK: func.func @testmul({{.*}}) -> (f32, f32, f32) {
// CHECK: %[[MUL0:.*]] = arith.mulf {{.*}} : f32
diff --git a/mlir/tools/mlir-query/mlir-query.cpp b/mlir/tools/mlir-query/mlir-query.cpp
index 0ed4f94d5802b0..5e74da7ee7bdc4 100644
--- a/mlir/tools/mlir-query/mlir-query.cpp
+++ b/mlir/tools/mlir-query/mlir-query.cpp
@@ -15,6 +15,8 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/InitAllDialects.h"
+#include "mlir/Query/Matcher/ExtraMatchers.h"
+#include "mlir/Query/Matcher/MatchFinder.h"
#include "mlir/Query/Matcher/Registry.h"
#include "mlir/Tools/mlir-query/MlirQueryMain.h"
@@ -39,6 +41,14 @@ int main(int argc, char **argv) {
query::matcher::Registry matcherRegistry;
// Matchers registered in alphabetical order for consistency:
+ matcherRegistry.registerMatcher("getDefinitions",
+ mlir::query::extramatcher::getDefinitions);
+ matcherRegistry.registerMatcher("definedBy",
+ mlir::query::extramatcher::definedBy);
+ matcherRegistry.registerMatcher("usedBy", mlir::query::extramatcher::usedBy);
+ matcherRegistry.registerMatcher("getUses",
+ mlir::query::extramatcher::getUses);
+
matcherRegistry.registerMatcher("hasOpAttrName",
static_cast<HasOpAttrName *>(m_Attr));
matcherRegistry.registerMatcher("hasOpName", static_cast<HasOpName *>(m_Op));
>From 0a12247e12ce6a1c7dd32f86726e93df6fccac8c Mon Sep 17 00:00:00 2001
From: Denzel-Brian Budii <denzel-brian.budii at intel.com>
Date: Wed, 23 Oct 2024 17:42:19 +0000
Subject: [PATCH 2/2] MLIR-QUERY: backwardSlice, forwardSlice & QueryOptions
added
---
mlir/lib/Query/Matcher/Parser.cpp | 18 +++++++++-----
mlir/lib/Query/Query.cpp | 9 +++++++
mlir/lib/Query/QueryParser.cpp | 24 +++++++++++++++++++
mlir/test/mlir-query/function-extraction.mlir | 4 ++--
mlir/tools/mlir-query/mlir-query.cpp | 1 -
5 files changed, 47 insertions(+), 9 deletions(-)
diff --git a/mlir/lib/Query/Matcher/Parser.cpp b/mlir/lib/Query/Matcher/Parser.cpp
index 4dcb86a9383f32..726f1188d7e4c8 100644
--- a/mlir/lib/Query/Matcher/Parser.cpp
+++ b/mlir/lib/Query/Matcher/Parser.cpp
@@ -293,13 +293,19 @@ bool Parser::parseIdentifierPrefixImpl(VariantValue *value) {
if (tokenizer->nextTokenKind() != TokenKind::OpenParen) {
// Parse as a named value.
- auto namedValue =
- namedValues ? namedValues->lookup(nameToken.text) : VariantValue();
+ if (auto namedValue = namedValues ? namedValues->lookup(nameToken.text)
+ : VariantValue()) {
- if (!namedValue.isMatcher()) {
- error->addError(tokenizer->peekNextToken().range,
- ErrorType::ParserNotAMatcher);
- return false;
+ if (tokenizer->nextTokenKind() != TokenKind::Period) {
+ *value = namedValue;
+ return true;
+ }
+
+ if (!namedValue.isMatcher()) {
+ error->addError(tokenizer->peekNextToken().range,
+ ErrorType::ParserNotAMatcher);
+ return false;
+ }
}
if (tokenizer->nextTokenKind() == TokenKind::NewLine) {
diff --git a/mlir/lib/Query/Query.cpp b/mlir/lib/Query/Query.cpp
index dd699857568d7e..500fee50a16093 100644
--- a/mlir/lib/Query/Query.cpp
+++ b/mlir/lib/Query/Query.cpp
@@ -127,6 +127,15 @@ LogicalResult QuitQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
return mlir::success();
}
+LogicalResult LetQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
+  // An empty value unbinds `name`; any other value (re)binds it in the session.
+  if (!value.hasValue())
+    qs.namedValues.erase(name);
+  else
+    qs.namedValues[name] = value;
+  return mlir::success();
+}
+
LogicalResult MatchQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
Operation *rootOp = qs.getRootOp();
diff --git a/mlir/lib/Query/QueryParser.cpp b/mlir/lib/Query/QueryParser.cpp
index 7aaf4847f2e47b..4350fb9a434d40 100644
--- a/mlir/lib/Query/QueryParser.cpp
+++ b/mlir/lib/Query/QueryParser.cpp
@@ -166,6 +166,8 @@ QueryRef QueryParser::doParse() {
.Case("", ParsedQueryKind::NoOp)
.Case("#", ParsedQueryKind::Comment, /*isCompletion=*/false)
.Case("help", ParsedQueryKind::Help)
+ .Case("l", ParsedQueryKind::Let, /*isCompletion=*/false)
+ .Case("let", ParsedQueryKind::Let, /*isCompletion=*/false)
.Case("m", ParsedQueryKind::Match, /*isCompletion=*/false)
.Case("set", ParsedQueryKind::Set)
.Case("match", ParsedQueryKind::Match)
@@ -188,6 +190,27 @@ QueryRef QueryParser::doParse() {
case ParsedQueryKind::Quit:
return endQuery(new QuitQuery);
+ case ParsedQueryKind::Let: {
+ llvm::StringRef name = lexWord();
+
+ if (name.empty()) {
+ return new InvalidQuery("expected variable name");
+ }
+
+ if (completionPos) {
+ return completeMatcherExpression();
+ }
+
+ matcher::internal::Diagnostics diag;
+ matcher::VariantValue value;
+ if (!matcher::internal::Parser::parseExpression(
+ line, qs.getRegistryData(), &qs.namedValues, &value, &diag)) {
+ return makeInvalidQueryFromDiagnostics(diag);
+ }
+ QueryRef query = new LetQuery(name, value);
+ query->remainingContent = line;
+ return query;
+ }
case ParsedQueryKind::Match: {
if (completionPos) {
return completeMatcherExpression();
@@ -204,6 +227,7 @@ QueryRef QueryParser::doParse() {
}
auto actualSource = origMatcherSource.substr(0, origMatcherSource.size() -
matcherSource.size());
+
QueryRef query = new MatchQuery(actualSource, *matcher);
query->remainingContent = matcherSource;
return query;
diff --git a/mlir/test/mlir-query/function-extraction.mlir b/mlir/test/mlir-query/function-extraction.mlir
index d7a867eb1a4525..5a20c09d02eb63 100644
--- a/mlir/test/mlir-query/function-extraction.mlir
+++ b/mlir/test/mlir-query/function-extraction.mlir
@@ -4,7 +4,7 @@
// CHECK: %[[MUL0:.*]] = arith.mulf {{.*}} : f32
// CHECK: %[[MUL1:.*]] = arith.mulf {{.*}}, %[[MUL0]] : f32
// CHECK: %[[MUL2:.*]] = arith.mulf {{.*}} : f32
-// CHECK-NEXT: return %[[MUL0]], %[[MUL1]], %[[MUL2]] : f32, f32, f32
+// CHECK-NEXT: return %[[MUL0]], %[[MUL1]], %[[MUL2]] : f32, f32, f32
func.func @mixedOperations(%a: f32, %b: f32, %c: f32) -> f32 {
%sum0 = arith.addf %a, %b : f32
@@ -16,4 +16,4 @@ func.func @mixedOperations(%a: f32, %b: f32, %c: f32) -> f32 {
%sum2 = arith.addf %mul1, %b : f32
%mul2 = arith.mulf %sub2, %sum2 : f32
return %mul2 : f32
-}
+ }
diff --git a/mlir/tools/mlir-query/mlir-query.cpp b/mlir/tools/mlir-query/mlir-query.cpp
index 5e74da7ee7bdc4..468f948bec24ca 100644
--- a/mlir/tools/mlir-query/mlir-query.cpp
+++ b/mlir/tools/mlir-query/mlir-query.cpp
@@ -10,7 +10,6 @@
// of the registered queries.
//
//===----------------------------------------------------------------------===//
-
#include "mlir/IR/Dialect.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
More information about the Mlir-commits
mailing list