[Mlir-commits] [mlir] [mlir] MLIR-QUERY DefinitionsMatcher implementation & DAG (PR #115670)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Jan 19 05:09:58 PST 2025
https://github.com/dbudii updated https://github.com/llvm/llvm-project/pull/115670
>From cb380e7d65ab7cdd83285a05c6eb5c6cc3a50cce Mon Sep 17 00:00:00 2001
From: Denzel-Brian Budii <denzel-brian.budii at intel.com>
Date: Wed, 23 Oct 2024 17:42:19 +0000
Subject: [PATCH 1/5] MLIR-QUERY DefinitionsMatcher implementation & DAG
- included printing logic for DAG - sfinae for match methods
---
.../mlir/Query/Matcher/ExtraMatchers.h | 109 ++++++++++++++
mlir/include/mlir/Query/Matcher/Marshallers.h | 17 ++-
mlir/include/mlir/Query/Matcher/MatchFinder.h | 18 ++-
.../mlir/Query/Matcher/MatchersInternal.h | 134 +++++++++++++----
.../include/mlir/Query/Matcher/VariantValue.h | 10 +-
mlir/lib/Query/Matcher/Parser.cpp | 59 ++++++++
mlir/lib/Query/Matcher/RegistryManager.cpp | 2 +
mlir/lib/Query/Matcher/VariantValue.cpp | 23 +++
mlir/lib/Query/Query.cpp | 141 +++++++++++++++---
mlir/tools/mlir-query/mlir-query.cpp | 7 +-
10 files changed, 463 insertions(+), 57 deletions(-)
create mode 100644 mlir/include/mlir/Query/Matcher/ExtraMatchers.h
diff --git a/mlir/include/mlir/Query/Matcher/ExtraMatchers.h b/mlir/include/mlir/Query/Matcher/ExtraMatchers.h
new file mode 100644
index 000000000000000..1764ad35cc9c303
--- /dev/null
+++ b/mlir/include/mlir/Query/Matcher/ExtraMatchers.h
@@ -0,0 +1,109 @@
+//===- Matchers.h - Various common matchers ---------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file provides extra matchers that are very useful for mlir-query
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_EXTRAMATCHERS_H
+#define MLIR_IR_EXTRAMATCHERS_H
+
+#include "MatchFinder.h"
+#include "MatchersInternal.h"
+
+namespace mlir {
+
+namespace query {
+
+namespace extramatcher {
+
+namespace detail {
+
+class DefinitionsMatcher {
+public:
+ DefinitionsMatcher(matcher::DynMatcher &&InnerMatcher, unsigned Hops)
+ : InnerMatcher(std::move(InnerMatcher)), Hops(Hops) {}
+
+private:
+ bool matches(Operation *op, matcher::BoundOperationsGraphBuilder &Bound,
+ unsigned TempHops) {
+
+ llvm::DenseSet<mlir::Value> Ccache;
+ llvm::SmallVector<std::pair<Operation *, size_t>, 4> TempStorage;
+ TempStorage.push_back({op, TempHops});
+ while (!TempStorage.empty()) {
+ auto [CurrentOp, RemainingHops] = TempStorage.pop_back_val();
+
+ matcher::BoundOperationNode *CurrentNode =
+ Bound.addNode(CurrentOp, true, true);
+ if (RemainingHops == 0) {
+ continue;
+ }
+
+ for (auto Operand : CurrentOp->getOperands()) {
+ if (auto DefiningOp = Operand.getDefiningOp()) {
+ Bound.addEdge(CurrentOp, DefiningOp);
+ if (!Ccache.contains(Operand)) {
+ Ccache.insert(Operand);
+ TempStorage.emplace_back(DefiningOp, RemainingHops - 1);
+ }
+ } else if (auto BlockArg = Operand.dyn_cast<BlockArgument>()) {
+ auto *Block = BlockArg.getOwner();
+
+ if (Block->isEntryBlock() &&
+ isa<FunctionOpInterface>(Block->getParentOp())) {
+ continue;
+ }
+
+ Operation *ParentOp = BlockArg.getOwner()->getParentOp();
+ if (ParentOp) {
+ Bound.addEdge(CurrentOp, ParentOp);
+ if (!!Ccache.contains(BlockArg)) {
+ Ccache.insert(BlockArg);
+ TempStorage.emplace_back(ParentOp, RemainingHops - 1);
+ }
+ }
+ }
+ }
+ }
+ // We need at least 1 defining op
+ return Ccache.size() >= 2;
+ }
+
+public:
+ bool match(Operation *op, matcher::BoundOperationsGraphBuilder &Bound) {
+ if (InnerMatcher.match(op) && matches(op, Bound, Hops)) {
+ return true;
+ }
+ return false;
+ }
+
+private:
+ matcher::DynMatcher InnerMatcher;
+ unsigned Hops;
+};
+} // namespace detail
+
+inline detail::DefinitionsMatcher
+definedBy(mlir::query::matcher::DynMatcher InnerMatcher) {
+ return detail::DefinitionsMatcher(std::move(InnerMatcher), 1);
+}
+
+inline detail::DefinitionsMatcher
+getDefinitions(mlir::query::matcher::DynMatcher InnerMatcher, unsigned Hops) {
+ assert(Hops > 0 && "hops must be >= 1");
+ return detail::DefinitionsMatcher(std::move(InnerMatcher), Hops);
+}
+
+} // namespace extramatcher
+
+} // namespace query
+
+} // namespace mlir
+
+#endif // MLIR_IR_EXTRAMATCHERS_H
diff --git a/mlir/include/mlir/Query/Matcher/Marshallers.h b/mlir/include/mlir/Query/Matcher/Marshallers.h
index 6ed35ac0ddccc70..4a08b9af82c26cf 100644
--- a/mlir/include/mlir/Query/Matcher/Marshallers.h
+++ b/mlir/include/mlir/Query/Matcher/Marshallers.h
@@ -50,6 +50,21 @@ struct ArgTypeTraits<llvm::StringRef> {
}
};
+template <>
+struct ArgTypeTraits<unsigned> {
+ static bool hasCorrectType(const VariantValue &value) {
+ return value.isUnsigned();
+ }
+
+ static unsigned get(const VariantValue &value) { return value.getUnsigned(); }
+
+ static ArgKind getKind() { return ArgKind::Unsigned; }
+
+ static std::optional<std::string> getBestGuess(const VariantValue &) {
+ return std::nullopt;
+ }
+};
+
template <>
struct ArgTypeTraits<DynMatcher> {
@@ -166,7 +181,7 @@ matcherMarshallFixedImpl(void (*matcherFunc)(), llvm::StringRef matcherName,
ReturnType fnPointer = reinterpret_cast<FuncType>(matcherFunc)(
ArgTypeTraits<ArgTypes>::get(args[Is].value)...);
return VariantMatcher::SingleMatcher(
- *DynMatcher::constructDynMatcherFromMatcherFn(fnPointer));
+ *DynMatcher::constructDynMatcherFromMatcherFn(fnPointer, matcherName));
}
return VariantMatcher();
diff --git a/mlir/include/mlir/Query/Matcher/MatchFinder.h b/mlir/include/mlir/Query/Matcher/MatchFinder.h
index b008a21f53ae2a6..4664e48b51b94a2 100644
--- a/mlir/include/mlir/Query/Matcher/MatchFinder.h
+++ b/mlir/include/mlir/Query/Matcher/MatchFinder.h
@@ -15,6 +15,7 @@
#define MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H
#include "MatchersInternal.h"
+#include "mlir/IR/Operation.h"
namespace mlir::query::matcher {
@@ -22,17 +23,18 @@ namespace mlir::query::matcher {
class MatchFinder {
public:
// Returns all operations that match the given matcher.
- static std::vector<Operation *> getMatches(Operation *root,
- DynMatcher matcher) {
- std::vector<Operation *> matches;
+ static BoundOperationsGraphBuilder getMatches(Operation *root,
+ DynMatcher matcher) {
- // Simple match finding with walk.
+ BoundOperationsGraphBuilder Bound;
root->walk([&](Operation *subOp) {
- if (matcher.match(subOp))
- matches.push_back(subOp);
+ if (matcher.match(subOp)) {
+ matcher::BoundOperationNode *currentNode = Bound.addNode(subOp);
+ } else if (matcher.match(subOp, Bound)) {
+ ////
+ }
});
-
- return matches;
+ return Bound;
}
};
diff --git a/mlir/include/mlir/Query/Matcher/MatchersInternal.h b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
index 117f7d4edef9e31..cb4063dc2845260 100644
--- a/mlir/include/mlir/Query/Matcher/MatchersInternal.h
+++ b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
@@ -1,19 +1,8 @@
//===- MatchersInternal.h - Structural query framework ----------*- C++ -*-===//
//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// Part of the LLVM Project, under the Apache License v2.0 wIth LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// Implements the base layer of the matcher framework.
-//
-// Matchers are methods that return a Matcher which provides a method
-// match(Operation *op)
-//
-// The matcher functions are defined in include/mlir/IR/Matchers.h.
-// This file contains the wrapper classes needed to construct matchers for
-// mlir-query.
+// SPDX-License-Identifier: Apache-2.0 WItH LLVM-exception
//
//===----------------------------------------------------------------------===//
@@ -22,16 +11,91 @@
#include "mlir/IR/Matchers.h"
#include "llvm/ADT/IntrusiveRefCntPtr.h"
+#include "llvm/ADT/MapVector.h"
+#include <memory>
+#include <stack>
+#include <unordered_set>
+#include <vector>
namespace mlir::query::matcher {
+struct BoundOperationNode {
+ Operation *op;
+ std::vector<BoundOperationNode *> Parents;
+ std::vector<BoundOperationNode *> Children;
+
+ bool IsRootNode;
+ bool DetailedPrinting;
+
+ BoundOperationNode(Operation *op, bool IsRootNode = false,
+ bool DetailedPrinting = false)
+ : op(op), IsRootNode(IsRootNode), DetailedPrinting(DetailedPrinting) {}
+};
+
+class BoundOperationsGraphBuilder {
+public:
+ BoundOperationNode *addNode(Operation *op, bool IsRootNode = false,
+ bool DetailedPrinting = false) {
+ auto It = Nodes.find(op);
+ if (It != Nodes.end()) {
+ return It->second.get();
+ }
+ auto Node =
+ std::make_unique<BoundOperationNode>(op, IsRootNode, DetailedPrinting);
+ BoundOperationNode *NodePtr = Node.get();
+ Nodes[op] = std::move(Node);
+ return NodePtr;
+ }
+
+ void addEdge(Operation *parentOp, Operation *childOp) {
+ BoundOperationNode *ParentNode = addNode(parentOp, false, false);
+ BoundOperationNode *ChildNode = addNode(childOp, false, false);
+
+ ParentNode->Children.push_back(ChildNode);
+ ChildNode->Parents.push_back(ParentNode);
+ }
+
+ BoundOperationNode *getNode(Operation *op) const {
+ auto It = Nodes.find(op);
+ return It != Nodes.end() ? It->second.get() : nullptr;
+ }
+
+ const llvm::MapVector<Operation *, std::unique_ptr<BoundOperationNode>> &
+ getNodes() const {
+ return Nodes;
+ }
+
+private:
+ llvm::MapVector<Operation *, std::unique_ptr<BoundOperationNode>> Nodes;
+};
+
+// Type traIt to detect if a matcher has a match(Operation*) method
+template <typename T, typename = void>
+struct has_simple_match : std::false_type {};
+
+template <typename T>
+struct has_simple_match<T, std::void_t<decltype(std::declval<T>().match(
+ std::declval<Operation *>()))>>
+ : std::true_type {};
+
+// Type traIt to detect if a matcher has a match(Operation*,
+// BoundOperationsGraphBuilder&) method
+template <typename T, typename = void>
+struct has_bound_match : std::false_type {};
+
+template <typename T>
+struct has_bound_match<T, std::void_t<decltype(std::declval<T>().match(
+ std::declval<Operation *>(),
+ std::declval<BoundOperationsGraphBuilder &>()))>>
+ : std::true_type {};
+
// Generic interface for matchers on an MLIR operation.
class MatcherInterface
: public llvm::ThreadSafeRefCountedBase<MatcherInterface> {
public:
virtual ~MatcherInterface() = default;
-
virtual bool match(Operation *op) = 0;
+ virtual bool match(Operation *op, BoundOperationsGraphBuilder &bound) = 0;
};
// MatcherFnImpl takes a matcher function object and implements
@@ -40,40 +104,56 @@ template <typename MatcherFn>
class MatcherFnImpl : public MatcherInterface {
public:
MatcherFnImpl(MatcherFn &matcherFn) : matcherFn(matcherFn) {}
- bool match(Operation *op) override { return matcherFn.match(op); }
+
+ bool match(Operation *op) override {
+ if constexpr (has_simple_match<MatcherFn>::value)
+ return matcherFn.match(op);
+ return false;
+ }
+
+ bool match(Operation *op, BoundOperationsGraphBuilder &bound) override {
+ if constexpr (has_bound_match<MatcherFn>::value)
+ return matcherFn.match(op, bound);
+ return false;
+ }
private:
MatcherFn matcherFn;
};
-// Matcher wraps a MatcherInterface implementation and provides a match()
-// method that redirects calls to the underlying implementation.
+// Matcher wraps a MatcherInterface implementation and provides match()
+// methods that redirect calls to the underlying implementation.
class DynMatcher {
public:
// Takes ownership of the provided implementation pointer.
- DynMatcher(MatcherInterface *implementation)
- : implementation(implementation) {}
+ DynMatcher(MatcherInterface *implementation, StringRef matcherName)
+ : implementation(implementation), matcherName(matcherName.str()) {}
template <typename MatcherFn>
static std::unique_ptr<DynMatcher>
- constructDynMatcherFromMatcherFn(MatcherFn &matcherFn) {
+ constructDynMatcherFromMatcherFn(MatcherFn &matcherFn,
+ StringRef matcherName) {
auto impl = std::make_unique<MatcherFnImpl<MatcherFn>>(matcherFn);
- return std::make_unique<DynMatcher>(impl.release());
+ return std::make_unique<DynMatcher>(impl.release(), matcherName);
}
bool match(Operation *op) const { return implementation->match(op); }
+ bool match(Operation *op, BoundOperationsGraphBuilder &bound) const {
+ return implementation->match(op, bound);
+ }
- void setFunctionName(StringRef name) { functionName = name.str(); };
-
- bool hasFunctionName() const { return !functionName.empty(); };
-
- StringRef getFunctionName() const { return functionName; };
+ void setFunctionName(StringRef name) { functionName = name.str(); }
+ void setMatcherName(StringRef name) { matcherName = name.str(); }
+ bool hasFunctionName() const { return !functionName.empty(); }
+ StringRef getFunctionName() const { return functionName; }
+ StringRef getMatcherName() const { return matcherName; }
private:
llvm::IntrusiveRefCntPtr<MatcherInterface> implementation;
+ std::string matcherName;
std::string functionName;
};
} // namespace mlir::query::matcher
-#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
+#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
\ No newline at end of file
diff --git a/mlir/include/mlir/Query/Matcher/VariantValue.h b/mlir/include/mlir/Query/Matcher/VariantValue.h
index 449f8b3a01e0217..73d96a6913dfe4a 100644
--- a/mlir/include/mlir/Query/Matcher/VariantValue.h
+++ b/mlir/include/mlir/Query/Matcher/VariantValue.h
@@ -21,7 +21,7 @@
namespace mlir::query::matcher {
// All types that VariantValue can contain.
-enum class ArgKind { Matcher, String };
+enum class ArgKind { Matcher, String, Unsigned };
// A variant matcher object to abstract simple and complex matchers into a
// single object type.
@@ -81,6 +81,7 @@ class VariantValue {
// Specific constructors for each supported type.
VariantValue(const llvm::StringRef string);
VariantValue(const VariantMatcher &matcher);
+ VariantValue(unsigned Unsigned);
// String value functions.
bool isString() const;
@@ -92,6 +93,11 @@ class VariantValue {
const VariantMatcher &getMatcher() const;
void setMatcher(const VariantMatcher &matcher);
+ // Unsigned value functions.
+ bool isUnsigned() const;
+ unsigned getUnsigned() const;
+ void setUnsigned(unsigned Unsigned);
+
// String representation of the type of the value.
std::string getTypeAsString() const;
@@ -103,12 +109,14 @@ class VariantValue {
Nothing,
String,
Matcher,
+ Unsigned,
};
// All supported value types.
union AllValues {
llvm::StringRef *String;
VariantMatcher *Matcher;
+ unsigned Unsigned;
};
ValueType type;
diff --git a/mlir/lib/Query/Matcher/Parser.cpp b/mlir/lib/Query/Matcher/Parser.cpp
index 3609e24f9939f7c..4f1b716756e318e 100644
--- a/mlir/lib/Query/Matcher/Parser.cpp
+++ b/mlir/lib/Query/Matcher/Parser.cpp
@@ -135,6 +135,18 @@ class Parser::CodeTokenizer {
case '\'':
consumeStringLiteral(&result);
break;
+ case '0':
+ case '1':
+ case '2':
+ case '3':
+ case '4':
+ case '5':
+ case '6':
+ case '7':
+ case '8':
+ case '9':
+ consumeNumberLiteral(&result);
+ break;
default:
parseIdentifierOrInvalid(&result);
break;
@@ -144,6 +156,53 @@ class Parser::CodeTokenizer {
return result;
}
+ void consumeNumberLiteral(TokenInfo *result) {
+ bool isFloatingLiteral = false;
+ unsigned length = 1;
+ if (code.size() > 1) {
+ // Consume the 'x' or 'b' radix modifier, if present.
+ switch (tolower(code[1])) {
+ case 'x':
+ case 'b':
+ length = 2;
+ }
+ }
+ while (length < code.size() && isdigit(code[length]))
+ ++length;
+
+ // Try to recognize a floating point literal.
+ while (length < code.size()) {
+ char c = code[length];
+ if (c == '-' || c == '+' || c == '.' || isdigit(c)) {
+ isFloatingLiteral = true;
+ length++;
+ } else {
+ break;
+ }
+ }
+
+ result->text = code.take_front(length);
+ code = code.drop_front(length);
+
+ if (isFloatingLiteral) {
+ char *end;
+ errno = 0;
+ std::string text = result->text.str();
+ double doubleValue = strtod(text.c_str(), &end);
+ if (*end == 0 && errno == 0) {
+ result->kind = TokenKind::Literal;
+ result->value = doubleValue;
+ return;
+ }
+ } else {
+ unsigned value;
+ if (!result->text.getAsInteger(0, value)) {
+ result->kind = TokenKind::Literal;
+ result->value = value;
+ return;
+ }
+ }
+ }
// Consume a string literal, handle escape sequences and missing closing
// quote.
void consumeStringLiteral(TokenInfo *result) {
diff --git a/mlir/lib/Query/Matcher/RegistryManager.cpp b/mlir/lib/Query/Matcher/RegistryManager.cpp
index 645db7109c2deb3..8d6c0135aa11768 100644
--- a/mlir/lib/Query/Matcher/RegistryManager.cpp
+++ b/mlir/lib/Query/Matcher/RegistryManager.cpp
@@ -31,6 +31,8 @@ static std::string asArgString(ArgKind kind) {
return "Matcher";
case ArgKind::String:
return "String";
+ case ArgKind::Unsigned:
+ return "unsigned";
}
llvm_unreachable("Unhandled ArgKind");
}
diff --git a/mlir/lib/Query/Matcher/VariantValue.cpp b/mlir/lib/Query/Matcher/VariantValue.cpp
index 65bd4bd77bcf8af..50d79512196d1a5 100644
--- a/mlir/lib/Query/Matcher/VariantValue.cpp
+++ b/mlir/lib/Query/Matcher/VariantValue.cpp
@@ -56,6 +56,10 @@ VariantValue::VariantValue(const VariantMatcher &matcher)
value.Matcher = new VariantMatcher(matcher);
}
+VariantValue::VariantValue(unsigned Unsigned) : type(ValueType::Unsigned) {
+ value.Unsigned = Unsigned;
+}
+
VariantValue::~VariantValue() { reset(); }
VariantValue &VariantValue::operator=(const VariantValue &other) {
@@ -69,6 +73,9 @@ VariantValue &VariantValue::operator=(const VariantValue &other) {
case ValueType::Matcher:
setMatcher(other.getMatcher());
break;
+ case ValueType::Unsigned:
+ setUnsigned(other.getUnsigned());
+ break;
case ValueType::Nothing:
type = ValueType::Nothing;
break;
@@ -85,12 +92,26 @@ void VariantValue::reset() {
delete value.Matcher;
break;
// Cases that do nothing.
+ case ValueType::Unsigned:
case ValueType::Nothing:
break;
}
type = ValueType::Nothing;
}
+bool VariantValue::isUnsigned() const { return type == ValueType::Unsigned; }
+
+unsigned VariantValue::getUnsigned() const {
+ assert(isUnsigned());
+ return value.Unsigned;
+}
+
+void VariantValue::setUnsigned(unsigned newValue) {
+ reset();
+ type = ValueType::Unsigned;
+ value.Unsigned = newValue;
+}
+
bool VariantValue::isString() const { return type == ValueType::String; }
const llvm::StringRef &VariantValue::getString() const {
@@ -123,6 +144,8 @@ std::string VariantValue::getTypeAsString() const {
return "String";
case ValueType::Matcher:
return "Matcher";
+ case ValueType::Unsigned:
+ return "Unsigned";
case ValueType::Nothing:
return "Nothing";
}
diff --git a/mlir/lib/Query/Query.cpp b/mlir/lib/Query/Query.cpp
index 7d9f36067005198..70be7c36888d50b 100644
--- a/mlir/lib/Query/Query.cpp
+++ b/mlir/lib/Query/Query.cpp
@@ -12,8 +12,11 @@
#include "mlir/IR/IRMapping.h"
#include "mlir/Query/Matcher/MatchFinder.h"
#include "mlir/Query/QuerySession.h"
+#include "llvm/ADT/SetVector.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
+#include <unordered_map>
+#include <unordered_set>
namespace mlir::query {
@@ -124,30 +127,130 @@ LogicalResult QuitQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
return mlir::success();
}
+void collectMatchNodes(
+ matcher::BoundOperationNode *Node,
+ llvm::SetVector<matcher::BoundOperationNode *> &MatchNodes) {
+ MatchNodes.insert(Node);
+ for (auto ChildNode : Node->Children) {
+ collectMatchNodes(ChildNode, MatchNodes);
+ }
+}
+
+void analyzeAndPrint(llvm::raw_ostream &os, QuerySession &qs,
+ const matcher::BoundOperationsGraphBuilder &Bound) {
+
+ const auto &Nodes = Bound.getNodes();
+ if (Nodes.empty()) {
+ os << "The graph is empty.\n";
+ return;
+ }
+
+ bool AnyDetailedPrinting = false;
+ for (const auto &Pair : Nodes) {
+ if (Pair.second->DetailedPrinting) {
+ AnyDetailedPrinting = true;
+ break;
+ }
+ }
+
+ unsigned MatchesCounter = 0;
+ if (!AnyDetailedPrinting) {
+ os << "Operations:\n";
+ for (const auto &Pair : Nodes) {
+ os << "\n";
+ os << " Match #" << ++MatchesCounter << "\n";
+ printMatch(os, qs, Pair.first, "root");
+ }
+ os << MatchesCounter << " matches found!\n";
+ return;
+ }
+
+ // Maps ids to nodes
+ std::unordered_map<Operation *, int> NodeIDs;
+ int id = 0;
+ for (const auto &Pair : Nodes) {
+ NodeIDs[Pair.first] = id++;
+ }
+
+ // Finds root nodes
+ std::vector<matcher::BoundOperationNode *> RootNodes;
+ for (const auto &Pair : Nodes) {
+ matcher::BoundOperationNode *Node = Pair.second.get();
+ if (Node->IsRootNode) {
+ RootNodes.push_back(Node);
+ }
+ }
+
+ for (auto RootNode : RootNodes) {
+ os << "\n";
+ os << " Match #" << ++MatchesCounter << "\n";
+
+ llvm::SetVector<matcher::BoundOperationNode *> MatchNodes;
+ collectMatchNodes(RootNode, MatchNodes);
+ std::vector<matcher::BoundOperationNode *> SortedMatchNodes(
+ MatchNodes.begin(), MatchNodes.end());
+
+ // Sorts based on file location
+ std::sort(
+ SortedMatchNodes.begin(), SortedMatchNodes.end(),
+ [&](matcher::BoundOperationNode *a, matcher::BoundOperationNode *b) {
+ auto fileLocA = a->op->getLoc()->findInstanceOf<FileLineColLoc>();
+ auto fileLocB = b->op->getLoc()->findInstanceOf<FileLineColLoc>();
+
+ if (!fileLocA && !fileLocB)
+ return false;
+ if (!fileLocA)
+ return false;
+ if (!fileLocB)
+ return true;
+
+ if (fileLocA.getFilename().str() != fileLocB.getFilename().str())
+ return fileLocA.getFilename().str() < fileLocB.getFilename().str();
+ return fileLocA.getLine() < fileLocB.getLine();
+ });
+
+ for (auto Node : SortedMatchNodes) {
+ unsigned NodeID = NodeIDs[Node->op];
+ std::string binding = Node->IsRootNode ? "root" : "";
+ os << NodeID << ": ";
+ printMatch(os, qs, Node->op, binding);
+ }
+
+ // Prints edges
+ os << "Edges:\n";
+ for (auto Node : MatchNodes) {
+ int ParentID = NodeIDs[Node->op];
+ for (auto ChildNode : Node->Children) {
+ if (MatchNodes.count(ChildNode) > 0) {
+ int ChildID = NodeIDs[ChildNode->op];
+ os << " " << ParentID << " ---> " << ChildID << "\n";
+ }
+ }
+ }
+ }
+ os << "\n" << MatchesCounter << " matches found!\n";
+}
+
LogicalResult MatchQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
Operation *rootOp = qs.getRootOp();
int matchCount = 0;
- std::vector<Operation *> matches =
- matcher::MatchFinder().getMatches(rootOp, matcher);
-
- // An extract call is recognized by considering if the matcher has a name.
- // TODO: Consider making the extract more explicit.
- if (matcher.hasFunctionName()) {
- auto functionName = matcher.getFunctionName();
- Operation *function =
- extractFunction(matches, rootOp->getContext(), functionName);
- os << "\n" << *function << "\n\n";
- function->erase();
- return mlir::success();
- }
+ auto matches = matcher::MatchFinder().getMatches(rootOp, matcher);
+
+ // An extract call is recognized by considering if the matcher has a
+ // name.TODO : Consider making the extract
+ // more explicit.
+ // if (matcher.hasFunctionName()) {
+ // auto functionName = matcher.getFunctionName();
+ // Operation *function = extractFunction(matches.getOperations(),
+ // rootOp->getContext(),
+ // functionName);
+ // os << "\n" << *function << "\n\n";
+ // function->erase();
+ // return mlir::success();
+ // }
os << "\n";
- for (Operation *op : matches) {
- os << "Match #" << ++matchCount << ":\n\n";
- // Placeholder "root" binding for the initial draft.
- printMatch(os, qs, op, "root");
- }
- os << matchCount << (matchCount == 1 ? " match.\n\n" : " matches.\n\n");
+ analyzeAndPrint(os, qs, matches);
return mlir::success();
}
diff --git a/mlir/tools/mlir-query/mlir-query.cpp b/mlir/tools/mlir-query/mlir-query.cpp
index 0ed4f94d5802b09..d5c0b1632d3c5d4 100644
--- a/mlir/tools/mlir-query/mlir-query.cpp
+++ b/mlir/tools/mlir-query/mlir-query.cpp
@@ -10,11 +10,12 @@
// of the registered queries.
//
//===----------------------------------------------------------------------===//
-
#include "mlir/IR/Dialect.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/InitAllDialects.h"
+#include "mlir/Query/Matcher/ExtraMatchers.h"
+#include "mlir/Query/Matcher/MatchFinder.h"
#include "mlir/Query/Matcher/Registry.h"
#include "mlir/Tools/mlir-query/MlirQueryMain.h"
@@ -39,6 +40,10 @@ int main(int argc, char **argv) {
query::matcher::Registry matcherRegistry;
// Matchers registered in alphabetical order for consistency:
+ matcherRegistry.registerMatcher("getDefinitions",
+ mlir::query::extramatcher::getDefinitions);
+ matcherRegistry.registerMatcher("definedBy",
+ mlir::query::extramatcher::definedBy);
matcherRegistry.registerMatcher("hasOpAttrName",
static_cast<HasOpAttrName *>(m_Attr));
matcherRegistry.registerMatcher("hasOpName", static_cast<HasOpName *>(m_Op));
>From 23d38a70b3905cb595d0c22d058712588bc9762e Mon Sep 17 00:00:00 2001
From: Denzel-Brian Budii <denzel-brian.budii at intel.com>
Date: Wed, 8 Jan 2025 17:58:36 +0000
Subject: [PATCH 2/5] Added SetQuery, LetQuery, new implementation for matchers
---
mlir/include/mlir/IR/Matchers.h | 4 +-
.../mlir/Query/Matcher/ExtraMatchers.h | 186 +++++++++++++-----
mlir/include/mlir/Query/Matcher/MatchFinder.h | 14 +-
.../mlir/Query/Matcher/MatchersInternal.h | 76 ++-----
.../include/mlir/Query/Matcher/VariantValue.h | 2 +
mlir/include/mlir/Query/Query.h | 50 ++++-
mlir/include/mlir/Query/QuerySession.h | 5 +
mlir/lib/Query/Matcher/Parser.cpp | 20 +-
mlir/lib/Query/Matcher/VariantValue.cpp | 1 +
mlir/lib/Query/Query.cpp | 144 +++-----------
mlir/lib/Query/QueryParser.cpp | 77 +++++++-
mlir/lib/Query/QueryParser.h | 2 +-
mlir/test/mlir-query/complex-test.mlir | 23 +++
mlir/test/mlir-query/function-extraction.mlir | 2 +-
mlir/tools/mlir-query/mlir-query.cpp | 4 +
15 files changed, 367 insertions(+), 243 deletions(-)
create mode 100644 mlir/test/mlir-query/complex-test.mlir
diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index 1dce055db1b4a71..2204a68be26b104 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -59,7 +59,7 @@ struct NameOpMatcher {
NameOpMatcher(StringRef name) : name(name) {}
bool match(Operation *op) { return op->getName().getStringRef() == name; }
- StringRef name;
+ std::string name;
};
/// The matcher that matches operations that have the specified attribute name.
@@ -67,7 +67,7 @@ struct AttrOpMatcher {
AttrOpMatcher(StringRef attrName) : attrName(attrName) {}
bool match(Operation *op) { return op->hasAttr(attrName); }
- StringRef attrName;
+ std::string attrName;
};
/// The matcher that matches operations that have the `ConstantLike` trait, and
diff --git a/mlir/include/mlir/Query/Matcher/ExtraMatchers.h b/mlir/include/mlir/Query/Matcher/ExtraMatchers.h
index 1764ad35cc9c303..1900879ca70920b 100644
--- a/mlir/include/mlir/Query/Matcher/ExtraMatchers.h
+++ b/mlir/include/mlir/Query/Matcher/ExtraMatchers.h
@@ -15,6 +15,9 @@
#include "MatchFinder.h"
#include "MatchersInternal.h"
+#include "mlir/IR/Region.h"
+#include "mlir/Query/Query.h"
+#include "llvm/Support/raw_ostream.h"
namespace mlir {
@@ -24,80 +27,161 @@ namespace extramatcher {
namespace detail {
-class DefinitionsMatcher {
+class BackwardSliceMatcher {
public:
- DefinitionsMatcher(matcher::DynMatcher &&InnerMatcher, unsigned Hops)
- : InnerMatcher(std::move(InnerMatcher)), Hops(Hops) {}
+ BackwardSliceMatcher(matcher::DynMatcher &&innerMatcher, unsigned hops)
+ : innerMatcher(std::move(innerMatcher)), hops(hops) {}
private:
- bool matches(Operation *op, matcher::BoundOperationsGraphBuilder &Bound,
- unsigned TempHops) {
-
- llvm::DenseSet<mlir::Value> Ccache;
- llvm::SmallVector<std::pair<Operation *, size_t>, 4> TempStorage;
- TempStorage.push_back({op, TempHops});
- while (!TempStorage.empty()) {
- auto [CurrentOp, RemainingHops] = TempStorage.pop_back_val();
-
- matcher::BoundOperationNode *CurrentNode =
- Bound.addNode(CurrentOp, true, true);
- if (RemainingHops == 0) {
- continue;
- }
+ bool matches(Operation *op, SetVector<Operation *> &backwardSlice,
+ QueryOptions &options, unsigned tempHops) {
- for (auto Operand : CurrentOp->getOperands()) {
- if (auto DefiningOp = Operand.getDefiningOp()) {
- Bound.addEdge(CurrentOp, DefiningOp);
- if (!Ccache.contains(Operand)) {
- Ccache.insert(Operand);
- TempStorage.emplace_back(DefiningOp, RemainingHops - 1);
- }
- } else if (auto BlockArg = Operand.dyn_cast<BlockArgument>()) {
- auto *Block = BlockArg.getOwner();
+ bool validSlice = true;
+ if (op->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
+ return false;
+ }
- if (Block->isEntryBlock() &&
- isa<FunctionOpInterface>(Block->getParentOp())) {
- continue;
+ auto processValue = [&](Value value) {
+ if (tempHops == 0) {
+ return;
+ }
+ if (auto *definingOp = value.getDefiningOp()) {
+ if (backwardSlice.count(definingOp) == 0)
+ matches(definingOp, backwardSlice, options, tempHops - 1);
+ } else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
+ if (options.omitBlockArguments)
+ return;
+ Block *block = blockArg.getOwner();
+
+ Operation *parentOp = block->getParentOp();
+
+ if (parentOp && backwardSlice.count(parentOp) == 0) {
+ if (parentOp->getNumRegions() == 1 &&
+ parentOp->getRegion(0).getBlocks().size() == 1) {
+ validSlice = false;
+ return;
+ };
+ matches(parentOp, backwardSlice, options, tempHops - 1);
+ }
+ } else {
+ validSlice = false;
+ return;
+ }
+ };
+
+ if (!options.omitUsesFromAbove) {
+ llvm::for_each(op->getRegions(), [&](Region &region) {
+ SmallPtrSet<Region *, 4> descendents;
+ region.walk(
+ [&](Region *childRegion) { descendents.insert(childRegion); });
+ region.walk([&](Operation *op) {
+ for (OpOperand &operand : op->getOpOperands()) {
+ if (!descendents.contains(operand.get().getParentRegion()))
+ processValue(operand.get());
+ if (!validSlice)
+ return;
}
+ });
+ });
+ }
- Operation *ParentOp = BlockArg.getOwner()->getParentOp();
- if (ParentOp) {
- Bound.addEdge(CurrentOp, ParentOp);
- if (!!Ccache.contains(BlockArg)) {
- Ccache.insert(BlockArg);
- TempStorage.emplace_back(ParentOp, RemainingHops - 1);
- }
- }
- }
+ llvm::for_each(op->getOperands(), [&](Value operand) {
+ processValue(operand);
+ if (!validSlice)
+ return;
+ });
+ backwardSlice.insert(op);
+ if (!validSlice) {
+ return false;
+ }
+ return true;
+ }
+
+public:
+ bool match(Operation *op, SetVector<Operation *> &backwardSlice,
+ QueryOptions &options) {
+ if (innerMatcher.match(op) && matches(op, backwardSlice, options, hops)) {
+ if (!options.inclusive) {
+ backwardSlice.remove(op);
}
+ return true;
}
- // We need at least 1 defining op
- return Ccache.size() >= 2;
+ return false;
}
+private:
+ matcher::DynMatcher innerMatcher;
+ unsigned hops;
+};
+
+class ForwardSliceMatcher {
public:
- bool match(Operation *op, matcher::BoundOperationsGraphBuilder &Bound) {
- if (InnerMatcher.match(op) && matches(op, Bound, Hops)) {
+ ForwardSliceMatcher(matcher::DynMatcher &&innerMatcher, unsigned hops)
+ : innerMatcher(std::move(innerMatcher)), hops(hops) {}
+
+private:
+ bool matches(Operation *op, SetVector<Operation *> &forwardSlice,
+ QueryOptions &options, unsigned tempHops) {
+
+ if (tempHops == 0) {
+ forwardSlice.insert(op);
+ return true;
+ }
+
+ for (Region &region : op->getRegions())
+ for (Block &block : region)
+ for (Operation &blockOp : block)
+ if (forwardSlice.count(&blockOp) == 0)
+ matches(&blockOp, forwardSlice, options, tempHops - 1);
+ for (Value result : op->getResults()) {
+ for (Operation *userOp : result.getUsers())
+ if (forwardSlice.count(userOp) == 0)
+ matches(userOp, forwardSlice, options, tempHops - 1);
+ }
+
+ forwardSlice.insert(op);
+ return true;
+ }
+
+public:
+ bool match(Operation *op, SetVector<Operation *> &forwardSlice,
+ QueryOptions &options) {
+ if (innerMatcher.match(op) && matches(op, forwardSlice, options, hops)) {
+ if (!options.inclusive) {
+ forwardSlice.remove(op);
+ }
+ SmallVector<Operation *, 0> v(forwardSlice.takeVector());
+ forwardSlice.insert(v.rbegin(), v.rend());
return true;
}
return false;
}
private:
- matcher::DynMatcher InnerMatcher;
- unsigned Hops;
+ matcher::DynMatcher innerMatcher;
+ unsigned hops;
};
+
} // namespace detail
-inline detail::DefinitionsMatcher
-definedBy(mlir::query::matcher::DynMatcher InnerMatcher) {
- return detail::DefinitionsMatcher(std::move(InnerMatcher), 1);
+inline detail::BackwardSliceMatcher
+definedBy(mlir::query::matcher::DynMatcher innerMatcher) {
+ return detail::BackwardSliceMatcher(std::move(innerMatcher), 1);
+}
+
+inline detail::BackwardSliceMatcher
+getDefinitions(mlir::query::matcher::DynMatcher innerMatcher, unsigned hops) {
+ return detail::BackwardSliceMatcher(std::move(innerMatcher), hops);
+}
+
+inline detail::ForwardSliceMatcher
+usedBy(mlir::query::matcher::DynMatcher innerMatcher) {
+ return detail::ForwardSliceMatcher(std::move(innerMatcher), 1);
}
-inline detail::DefinitionsMatcher
-getDefinitions(mlir::query::matcher::DynMatcher InnerMatcher, unsigned Hops) {
- assert(Hops > 0 && "hops must be >= 1");
- return detail::DefinitionsMatcher(std::move(InnerMatcher), Hops);
+inline detail::ForwardSliceMatcher
+getUses(mlir::query::matcher::DynMatcher innerMatcher, unsigned hops) {
+ return detail::ForwardSliceMatcher(std::move(innerMatcher), hops);
}
} // namespace extramatcher
diff --git a/mlir/include/mlir/Query/Matcher/MatchFinder.h b/mlir/include/mlir/Query/Matcher/MatchFinder.h
index 4664e48b51b94a2..cbdf2fec46a7611 100644
--- a/mlir/include/mlir/Query/Matcher/MatchFinder.h
+++ b/mlir/include/mlir/Query/Matcher/MatchFinder.h
@@ -23,18 +23,18 @@ namespace mlir::query::matcher {
class MatchFinder {
public:
// Returns all operations that match the given matcher.
- static BoundOperationsGraphBuilder getMatches(Operation *root,
- DynMatcher matcher) {
-
- BoundOperationsGraphBuilder Bound;
+ static SetVector<Operation *>
+ getMatches(Operation *root, QueryOptions &options, DynMatcher matcher) {
+ SetVector<Operation *> backwardSlice;
root->walk([&](Operation *subOp) {
if (matcher.match(subOp)) {
- matcher::BoundOperationNode *currentNode = Bound.addNode(subOp);
- } else if (matcher.match(subOp, Bound)) {
+ backwardSlice.insert(subOp);
+ } else {
+ matcher.match(subOp, backwardSlice, options);
////
}
});
- return Bound;
+ return backwardSlice;
}
};
diff --git a/mlir/include/mlir/Query/Matcher/MatchersInternal.h b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
index cb4063dc2845260..081d216b56b61c0 100644
--- a/mlir/include/mlir/Query/Matcher/MatchersInternal.h
+++ b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
@@ -17,59 +17,13 @@
#include <unordered_set>
#include <vector>
-namespace mlir::query::matcher {
-
-struct BoundOperationNode {
- Operation *op;
- std::vector<BoundOperationNode *> Parents;
- std::vector<BoundOperationNode *> Children;
-
- bool IsRootNode;
- bool DetailedPrinting;
-
- BoundOperationNode(Operation *op, bool IsRootNode = false,
- bool DetailedPrinting = false)
- : op(op), IsRootNode(IsRootNode), DetailedPrinting(DetailedPrinting) {}
-};
+namespace mlir {
+namespace query {
+struct QueryOptions;
+}
+} // namespace mlir
-class BoundOperationsGraphBuilder {
-public:
- BoundOperationNode *addNode(Operation *op, bool IsRootNode = false,
- bool DetailedPrinting = false) {
- auto It = Nodes.find(op);
- if (It != Nodes.end()) {
- return It->second.get();
- }
- auto Node =
- std::make_unique<BoundOperationNode>(op, IsRootNode, DetailedPrinting);
- BoundOperationNode *NodePtr = Node.get();
- Nodes[op] = std::move(Node);
- return NodePtr;
- }
-
- void addEdge(Operation *parentOp, Operation *childOp) {
- BoundOperationNode *ParentNode = addNode(parentOp, false, false);
- BoundOperationNode *ChildNode = addNode(childOp, false, false);
-
- ParentNode->Children.push_back(ChildNode);
- ChildNode->Parents.push_back(ParentNode);
- }
-
- BoundOperationNode *getNode(Operation *op) const {
- auto It = Nodes.find(op);
- return It != Nodes.end() ? It->second.get() : nullptr;
- }
-
- const llvm::MapVector<Operation *, std::unique_ptr<BoundOperationNode>> &
- getNodes() const {
- return Nodes;
- }
-
-private:
- llvm::MapVector<Operation *, std::unique_ptr<BoundOperationNode>> Nodes;
-};
-
-// Type traIt to detect if a matcher has a match(Operation*) method
+namespace mlir::query::matcher {
template <typename T, typename = void>
struct has_simple_match : std::false_type {};
@@ -78,15 +32,14 @@ struct has_simple_match<T, std::void_t<decltype(std::declval<T>().match(
std::declval<Operation *>()))>>
: std::true_type {};
-// Type traIt to detect if a matcher has a match(Operation*,
-// BoundOperationsGraphBuilder&) method
template <typename T, typename = void>
struct has_bound_match : std::false_type {};
template <typename T>
struct has_bound_match<T, std::void_t<decltype(std::declval<T>().match(
std::declval<Operation *>(),
- std::declval<BoundOperationsGraphBuilder &>()))>>
+ std::declval<SetVector<Operation *> &>(),
+ std::declval<QueryOptions &>()))>>
: std::true_type {};
// Generic interface for matchers on an MLIR operation.
@@ -95,7 +48,8 @@ class MatcherInterface
public:
virtual ~MatcherInterface() = default;
virtual bool match(Operation *op) = 0;
- virtual bool match(Operation *op, BoundOperationsGraphBuilder &bound) = 0;
+ virtual bool match(Operation *op, SetVector<Operation *> &matchedOps,
+ QueryOptions &options) = 0;
};
// MatcherFnImpl takes a matcher function object and implements
@@ -111,9 +65,10 @@ class MatcherFnImpl : public MatcherInterface {
return false;
}
- bool match(Operation *op, BoundOperationsGraphBuilder &bound) override {
+ bool match(Operation *op, SetVector<Operation *> &matchedOps,
+ QueryOptions &options) override {
if constexpr (has_bound_match<MatcherFn>::value)
- return matcherFn.match(op, bound);
+ return matcherFn.match(op, matchedOps, options);
return false;
}
@@ -138,8 +93,9 @@ class DynMatcher {
}
bool match(Operation *op) const { return implementation->match(op); }
- bool match(Operation *op, BoundOperationsGraphBuilder &bound) const {
- return implementation->match(op, bound);
+ bool match(Operation *op, SetVector<Operation *> &matchedOps,
+ QueryOptions &options) const {
+ return implementation->match(op, matchedOps, options);
}
void setFunctionName(StringRef name) { functionName = name.str(); }
diff --git a/mlir/include/mlir/Query/Matcher/VariantValue.h b/mlir/include/mlir/Query/Matcher/VariantValue.h
index 73d96a6913dfe4a..6b57119df7a9bf5 100644
--- a/mlir/include/mlir/Query/Matcher/VariantValue.h
+++ b/mlir/include/mlir/Query/Matcher/VariantValue.h
@@ -100,6 +100,8 @@ class VariantValue {
// String representation of the type of the value.
std::string getTypeAsString() const;
+ explicit operator bool() const { return hasValue(); }
+ bool hasValue() const { return type != ValueType::Nothing; }
private:
void reset();
diff --git a/mlir/include/mlir/Query/Query.h b/mlir/include/mlir/Query/Query.h
index 18f2172c9510a35..d695f648016be31 100644
--- a/mlir/include/mlir/Query/Query.h
+++ b/mlir/include/mlir/Query/Query.h
@@ -10,6 +10,7 @@
#define MLIR_TOOLS_MLIRQUERY_QUERY_H
#include "Matcher/VariantValue.h"
+#include "mlir/Analysis/SliceAnalysis.h"
#include "llvm/ADT/IntrusiveRefCntPtr.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/LineEditor/LineEditor.h"
@@ -17,7 +18,13 @@
namespace mlir::query {
-enum class QueryKind { Invalid, NoOp, Help, Match, Quit };
+struct QueryOptions {
+ bool omitBlockArguments = false;
+ bool omitUsesFromAbove = true;
+ bool inclusive = true;
+};
+
+enum class QueryKind { Invalid, NoOp, Help, Match, Quit, Let, SetBool };
class QuerySession;
@@ -103,6 +110,47 @@ struct MatchQuery : Query {
}
};
+struct LetQuery : Query {
+ LetQuery(llvm::StringRef name, const matcher::VariantValue &value)
+ : Query(QueryKind::Let), name(name), value(value) {}
+
+ llvm::LogicalResult run(llvm::raw_ostream &os,
+ QuerySession &qs) const override;
+
+ std::string name;
+ matcher::VariantValue value;
+
+ static bool classof(const Query *query) {
+ return query->kind == QueryKind::Let;
+ }
+};
+
+template <typename T>
+struct SetQueryKind {};
+
+template <>
+struct SetQueryKind<bool> {
+ static const QueryKind value = QueryKind::SetBool;
+};
+template <typename T>
+struct SetQuery : Query {
+ SetQuery(T QuerySession::*var, T value)
+ : Query(SetQueryKind<T>::value), var(var), value(value) {}
+
+ llvm::LogicalResult run(llvm::raw_ostream &os,
+ QuerySession &qs) const override {
+ qs.*var = value;
+ return mlir::success();
+ }
+
+ static bool classof(const Query *query) {
+ return query->kind == SetQueryKind<T>::value;
+ }
+
+ T QuerySession::*var;
+ T value;
+};
+
} // namespace mlir::query
#endif
diff --git a/mlir/include/mlir/Query/QuerySession.h b/mlir/include/mlir/Query/QuerySession.h
index fe552d750fc7716..7faf206b96124b5 100644
--- a/mlir/include/mlir/Query/QuerySession.h
+++ b/mlir/include/mlir/Query/QuerySession.h
@@ -33,6 +33,11 @@ class QuerySession {
llvm::StringMap<matcher::VariantValue> namedValues;
bool terminate = false;
+public:
+ bool omitBlockArguments = false;
+ bool omitUsesFromAbove = true;
+ bool inclusive = true;
+
private:
Operation *rootOp;
llvm::SourceMgr &sourceMgr;
diff --git a/mlir/lib/Query/Matcher/Parser.cpp b/mlir/lib/Query/Matcher/Parser.cpp
index 4f1b716756e318e..79797ec45b549b0 100644
--- a/mlir/lib/Query/Matcher/Parser.cpp
+++ b/mlir/lib/Query/Matcher/Parser.cpp
@@ -191,7 +191,7 @@ class Parser::CodeTokenizer {
double doubleValue = strtod(text.c_str(), &end);
if (*end == 0 && errno == 0) {
result->kind = TokenKind::Literal;
- result->value = doubleValue;
+ result->value = static_cast<double>(doubleValue);
return;
}
} else {
@@ -316,13 +316,19 @@ bool Parser::parseIdentifierPrefixImpl(VariantValue *value) {
if (tokenizer->nextTokenKind() != TokenKind::OpenParen) {
// Parse as a named value.
- auto namedValue =
- namedValues ? namedValues->lookup(nameToken.text) : VariantValue();
+ if (auto namedValue = namedValues ? namedValues->lookup(nameToken.text)
+ : VariantValue()) {
- if (!namedValue.isMatcher()) {
- error->addError(tokenizer->peekNextToken().range,
- ErrorType::ParserNotAMatcher);
- return false;
+ if (tokenizer->nextTokenKind() != TokenKind::Period) {
+ *value = namedValue;
+ return true;
+ }
+
+ if (!namedValue.isMatcher()) {
+ error->addError(tokenizer->peekNextToken().range,
+ ErrorType::ParserNotAMatcher);
+ return false;
+ }
}
if (tokenizer->nextTokenKind() == TokenKind::NewLine) {
diff --git a/mlir/lib/Query/Matcher/VariantValue.cpp b/mlir/lib/Query/Matcher/VariantValue.cpp
index 50d79512196d1a5..d5218d8dad8c936 100644
--- a/mlir/lib/Query/Matcher/VariantValue.cpp
+++ b/mlir/lib/Query/Matcher/VariantValue.cpp
@@ -99,6 +99,7 @@ void VariantValue::reset() {
type = ValueType::Nothing;
}
+// Unsigned
bool VariantValue::isUnsigned() const { return type == ValueType::Unsigned; }
unsigned VariantValue::getUnsigned() const {
diff --git a/mlir/lib/Query/Query.cpp b/mlir/lib/Query/Query.cpp
index 70be7c36888d50b..03d5018abf5b012 100644
--- a/mlir/lib/Query/Query.cpp
+++ b/mlir/lib/Query/Query.cpp
@@ -29,8 +29,8 @@ complete(llvm::StringRef line, size_t pos, const QuerySession &qs) {
return QueryParser::complete(line, pos, qs);
}
-static void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op,
- const std::string &binding) {
+static void printMatch(llvm::raw_ostream &os, QuerySession &qs,
+ mlir::Operation *op, const std::string &binding) {
auto fileLoc = op->getLoc()->findInstanceOf<FileLineColLoc>();
auto smloc = qs.getSourceManager().FindLocForLineAndColumn(
qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn());
@@ -102,6 +102,12 @@ static Operation *extractFunction(std::vector<Operation *> &ops,
return funcOp;
}
+static void parseQueryOptions(QuerySession &qs, QueryOptions &options) {
+ options.omitBlockArguments = qs.omitBlockArguments;
+ options.omitUsesFromAbove = qs.omitUsesFromAbove;
+ options.inclusive = qs.inclusive;
+}
+
Query::~Query() = default;
LogicalResult InvalidQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
@@ -117,6 +123,8 @@ LogicalResult HelpQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
os << "Available commands:\n\n"
" match MATCHER, m MATCHER "
"Match the mlir against the given matcher.\n"
+ " let NAME MATCHER, l NAME MATCHER "
+ "Give a matcher expression a name, to be used later\n"
" quit "
"Terminates the query session.\n\n";
return mlir::success();
@@ -127,132 +135,44 @@ LogicalResult QuitQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
return mlir::success();
}
-void collectMatchNodes(
- matcher::BoundOperationNode *Node,
- llvm::SetVector<matcher::BoundOperationNode *> &MatchNodes) {
- MatchNodes.insert(Node);
- for (auto ChildNode : Node->Children) {
- collectMatchNodes(ChildNode, MatchNodes);
+LogicalResult LetQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
+ if (value.hasValue()) {
+ qs.namedValues[name] = value;
+ } else {
+ qs.namedValues.erase(name);
}
-}
-
-void analyzeAndPrint(llvm::raw_ostream &os, QuerySession &qs,
- const matcher::BoundOperationsGraphBuilder &Bound) {
-
- const auto &Nodes = Bound.getNodes();
- if (Nodes.empty()) {
- os << "The graph is empty.\n";
- return;
- }
-
- bool AnyDetailedPrinting = false;
- for (const auto &Pair : Nodes) {
- if (Pair.second->DetailedPrinting) {
- AnyDetailedPrinting = true;
- break;
- }
- }
-
- unsigned MatchesCounter = 0;
- if (!AnyDetailedPrinting) {
- os << "Operations:\n";
- for (const auto &Pair : Nodes) {
- os << "\n";
- os << " Match #" << ++MatchesCounter << "\n";
- printMatch(os, qs, Pair.first, "root");
- }
- os << MatchesCounter << " matches found!\n";
- return;
- }
-
- // Maps ids to nodes
- std::unordered_map<Operation *, int> NodeIDs;
- int id = 0;
- for (const auto &Pair : Nodes) {
- NodeIDs[Pair.first] = id++;
- }
-
- // Finds root nodes
- std::vector<matcher::BoundOperationNode *> RootNodes;
- for (const auto &Pair : Nodes) {
- matcher::BoundOperationNode *Node = Pair.second.get();
- if (Node->IsRootNode) {
- RootNodes.push_back(Node);
- }
- }
-
- for (auto RootNode : RootNodes) {
- os << "\n";
- os << " Match #" << ++MatchesCounter << "\n";
-
- llvm::SetVector<matcher::BoundOperationNode *> MatchNodes;
- collectMatchNodes(RootNode, MatchNodes);
- std::vector<matcher::BoundOperationNode *> SortedMatchNodes(
- MatchNodes.begin(), MatchNodes.end());
-
- // Sorts based on file location
- std::sort(
- SortedMatchNodes.begin(), SortedMatchNodes.end(),
- [&](matcher::BoundOperationNode *a, matcher::BoundOperationNode *b) {
- auto fileLocA = a->op->getLoc()->findInstanceOf<FileLineColLoc>();
- auto fileLocB = b->op->getLoc()->findInstanceOf<FileLineColLoc>();
-
- if (!fileLocA && !fileLocB)
- return false;
- if (!fileLocA)
- return false;
- if (!fileLocB)
- return true;
-
- if (fileLocA.getFilename().str() != fileLocB.getFilename().str())
- return fileLocA.getFilename().str() < fileLocB.getFilename().str();
- return fileLocA.getLine() < fileLocB.getLine();
- });
-
- for (auto Node : SortedMatchNodes) {
- unsigned NodeID = NodeIDs[Node->op];
- std::string binding = Node->IsRootNode ? "root" : "";
- os << NodeID << ": ";
- printMatch(os, qs, Node->op, binding);
- }
-
- // Prints edges
- os << "Edges:\n";
- for (auto Node : MatchNodes) {
- int ParentID = NodeIDs[Node->op];
- for (auto ChildNode : Node->Children) {
- if (MatchNodes.count(ChildNode) > 0) {
- int ChildID = NodeIDs[ChildNode->op];
- os << " " << ParentID << " ---> " << ChildID << "\n";
- }
- }
- }
- }
- os << "\n" << MatchesCounter << " matches found!\n";
+ return mlir::success();
}
LogicalResult MatchQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
Operation *rootOp = qs.getRootOp();
int matchCount = 0;
- auto matches = matcher::MatchFinder().getMatches(rootOp, matcher);
- // An extract call is recognized by considering if the matcher has a
- // name.TODO : Consider making the extract
- // more explicit.
+ QueryOptions options;
+ parseQueryOptions(qs, options);
+ auto matches =
+ matcher::MatchFinder().getMatches(rootOp, options, std::move(matcher));
+
+ // An extract call is recognized by considering if the matcher has a name.
+ // TODO: Consider making the extract more explicit.
// if (matcher.hasFunctionName()) {
// auto functionName = matcher.getFunctionName();
- // Operation *function = extractFunction(matches.getOperations(),
- // rootOp->getContext(),
- // functionName);
+ // Operation *function =
+ // extractFunction(matches, rootOp->getContext(), functionName);
// os << "\n" << *function << "\n\n";
// function->erase();
// return mlir::success();
// }
os << "\n";
- analyzeAndPrint(os, qs, matches);
+ for (Operation *op : matches) {
+ os << "Match #" << ++matchCount << ":\n\n";
+ // Placeholder "root" binding for the initial draft.
+ printMatch(os, qs, op, "root");
+ }
+ os << matchCount << (matchCount == 1 ? " match.\n\n" : " matches.\n\n");
return mlir::success();
}
-} // namespace mlir::query
+} // namespace mlir::query
\ No newline at end of file
diff --git a/mlir/lib/Query/QueryParser.cpp b/mlir/lib/Query/QueryParser.cpp
index 31aead7d403d0df..0565efeb2b80f3b 100644
--- a/mlir/lib/Query/QueryParser.cpp
+++ b/mlir/lib/Query/QueryParser.cpp
@@ -107,6 +107,18 @@ QueryRef QueryParser::endQuery(QueryRef queryRef) {
return queryRef;
}
+QueryRef QueryParser::parseSetBool(bool QuerySession::*var) {
+ StringRef valStr;
+ unsigned value = LexOrCompleteWord<unsigned>(this, valStr)
+ .Case("false", 0)
+ .Case("true", 1)
+ .Default(~0u);
+ if (value == ~0u) {
+ return new InvalidQuery("expected 'true' or 'false', got '" + valStr + "'");
+ }
+ return new SetQuery<bool>(var, value);
+}
+
namespace {
enum class ParsedQueryKind {
@@ -116,6 +128,15 @@ enum class ParsedQueryKind {
Help,
Match,
Quit,
+ Let,
+ Set,
+};
+
+enum ParsedQueryVariable {
+ Invalid,
+ OmitBlockArguments,
+ OmitUsesFromAbove,
+ Inclusive,
};
QueryRef
@@ -146,7 +167,10 @@ QueryRef QueryParser::doParse() {
.Case("", ParsedQueryKind::NoOp)
.Case("#", ParsedQueryKind::Comment, /*isCompletion=*/false)
.Case("help", ParsedQueryKind::Help)
+ .Case("l", ParsedQueryKind::Let, /*isCompletion=*/false)
+ .Case("let", ParsedQueryKind::Let, /*isCompletion=*/false)
.Case("m", ParsedQueryKind::Match, /*isCompletion=*/false)
+ .Case("set", ParsedQueryKind::Set)
.Case("match", ParsedQueryKind::Match)
.Case("q", ParsedQueryKind::Quit, /*IsCompletion=*/false)
.Case("quit", ParsedQueryKind::Quit)
@@ -167,6 +191,27 @@ QueryRef QueryParser::doParse() {
case ParsedQueryKind::Quit:
return endQuery(new QuitQuery);
+ case ParsedQueryKind::Let: {
+ llvm::StringRef name = lexWord();
+
+ if (name.empty()) {
+ return new InvalidQuery("expected variable name");
+ }
+
+ if (completionPos) {
+ return completeMatcherExpression();
+ }
+
+ matcher::internal::Diagnostics diag;
+ matcher::VariantValue value;
+ if (!matcher::internal::Parser::parseExpression(
+ line, qs.getRegistryData(), &qs.namedValues, &value, &diag)) {
+ return makeInvalidQueryFromDiagnostics(diag);
+ }
+ QueryRef query = new LetQuery(name, value);
+ query->remainingContent = line;
+ return query;
+ }
case ParsedQueryKind::Match: {
if (completionPos) {
return completeMatcherExpression();
@@ -183,11 +228,41 @@ QueryRef QueryParser::doParse() {
}
auto actualSource = origMatcherSource.substr(0, origMatcherSource.size() -
matcherSource.size());
+
QueryRef query = new MatchQuery(actualSource, *matcher);
query->remainingContent = matcherSource;
return query;
}
-
+ case ParsedQueryKind::Set: {
+ llvm::StringRef varStr;
+ ParsedQueryVariable var =
+ LexOrCompleteWord<ParsedQueryVariable>(this, varStr)
+ .Case("omitBlockArguments", ParsedQueryVariable::OmitBlockArguments)
+ .Case("omitUsesFromAbove", ParsedQueryVariable::OmitUsesFromAbove)
+ .Case("inclusive", ParsedQueryVariable::Inclusive)
+ .Default(ParsedQueryVariable::Invalid);
+ if (varStr.empty()) {
+ return new InvalidQuery("expected variable name");
+ }
+ if (var == ParsedQueryVariable::Invalid) {
+ return new InvalidQuery("unknown variable: '" + varStr + "'");
+ }
+ QueryRef query;
+ switch (var) {
+ case ParsedQueryVariable::OmitBlockArguments:
+ query = parseSetBool(&QuerySession::omitBlockArguments);
+ break;
+ case ParsedQueryVariable::OmitUsesFromAbove:
+ query = parseSetBool(&QuerySession::omitUsesFromAbove);
+ break;
+ case ParsedQueryVariable::Inclusive:
+ query = parseSetBool(&QuerySession::inclusive);
+ break;
+ case ParsedQueryVariable::Invalid:
+ llvm_unreachable("Invalid query kind");
+ }
+ return endQuery(query);
+ }
case ParsedQueryKind::Invalid:
return new InvalidQuery("unknown command: " + commandStr);
}
diff --git a/mlir/lib/Query/QueryParser.h b/mlir/lib/Query/QueryParser.h
index e9c30eccecab9e4..69cc5d0043d57df 100644
--- a/mlir/lib/Query/QueryParser.h
+++ b/mlir/lib/Query/QueryParser.h
@@ -39,8 +39,8 @@ class QueryParser {
struct LexOrCompleteWord;
QueryRef completeMatcherExpression();
-
QueryRef endQuery(QueryRef queryRef);
+ QueryRef parseSetBool(bool QuerySession::*var);
// Parse [begin, end) and returns a reference to the parsed query object,
// which may be an InvalidQuery if a parse error occurs.
diff --git a/mlir/test/mlir-query/complex-test.mlir b/mlir/test/mlir-query/complex-test.mlir
new file mode 100644
index 000000000000000..e491beb5007792b
--- /dev/null
+++ b/mlir/test/mlir-query/complex-test.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-query %s -c "m definedBy(hasOpName(())" | FileCheck %s
+
+
+func.func @matrix_multiply(%A: memref<4x4xf32>, %B: memref<4x4xf32>, %C: memref<4x4xf32>) {
+ %c0 = arith.constant 0 : index
+ %c4 = arith.constant 4 : index
+ %c1 = arith.constant 1 : index
+
+ scf.for %i = %c0 to %c4 step %c1 {
+ scf.for %j = %c0 to %c4 step %c1 {
+ %sum_init = arith.constant 0.0 : f32
+ %sum = scf.for %k = %c0 to %c4 step %c1 iter_args(%acc = %sum_init) -> (f32) {
+ %a_ik = memref.load %A[%i, %k] : memref<4x4xf32>
+ %b_kj = memref.load %B[%k, %j] : memref<4x4xf32>
+ %prod = arith.mulf %a_ik, %b_kj : f32
+ %new_acc = arith.addf %acc, %prod : f32
+ scf.yield %new_acc : f32
+ }
+ memref.store %sum, %C[%i, %j] : memref<4x4xf32>
+ }
+ }
+ return
+}
\ No newline at end of file
diff --git a/mlir/test/mlir-query/function-extraction.mlir b/mlir/test/mlir-query/function-extraction.mlir
index a783f65c6761bcf..473c30aee7123ae 100644
--- a/mlir/test/mlir-query/function-extraction.mlir
+++ b/mlir/test/mlir-query/function-extraction.mlir
@@ -16,4 +16,4 @@ func.func @mixedOperations(%a: f32, %b: f32, %c: f32) -> f32 {
%sum2 = arith.addf %mul1, %b : f32
%mul2 = arith.mulf %sub2, %sum2 : f32
return %mul2 : f32
-}
+ }
diff --git a/mlir/tools/mlir-query/mlir-query.cpp b/mlir/tools/mlir-query/mlir-query.cpp
index d5c0b1632d3c5d4..468f948bec24cac 100644
--- a/mlir/tools/mlir-query/mlir-query.cpp
+++ b/mlir/tools/mlir-query/mlir-query.cpp
@@ -44,6 +44,10 @@ int main(int argc, char **argv) {
mlir::query::extramatcher::getDefinitions);
matcherRegistry.registerMatcher("definedBy",
mlir::query::extramatcher::definedBy);
+ matcherRegistry.registerMatcher("usedBy", mlir::query::extramatcher::usedBy);
+ matcherRegistry.registerMatcher("getUses",
+ mlir::query::extramatcher::getUses);
+
matcherRegistry.registerMatcher("hasOpAttrName",
static_cast<HasOpAttrName *>(m_Attr));
matcherRegistry.registerMatcher("hasOpName", static_cast<HasOpName *>(m_Op));
>From 214a32343c7e3438ddd5dadb195c3271a5660e50 Mon Sep 17 00:00:00 2001
From: Denzel-Brian Budii <denzel-brian.budii at intel.com>
Date: Sun, 19 Jan 2025 13:07:17 +0000
Subject: [PATCH 3/5] conflicts fixed
---
.../mlir/Query/Matcher/ExtraMatchers.h | 27 +++------
mlir/include/mlir/Query/Matcher/MatchFinder.h | 58 ++++++++++++-------
.../mlir/Query/Matcher/MatchersInternal.h | 9 +--
mlir/include/mlir/Query/Query.h | 1 -
mlir/include/mlir/Query/QuerySession.h | 6 +-
mlir/lib/Query/Matcher/Parser.cpp | 35 ++---------
mlir/lib/Query/Query.cpp | 23 +-------
mlir/test/mlir-query/complex-test.mlir | 39 +++++++------
mlir/test/mlir-query/function-extraction.mlir | 2 +-
9 files changed, 82 insertions(+), 118 deletions(-)
diff --git a/mlir/include/mlir/Query/Matcher/ExtraMatchers.h b/mlir/include/mlir/Query/Matcher/ExtraMatchers.h
index 1900879ca70920b..57adc3241b0bef5 100644
--- a/mlir/include/mlir/Query/Matcher/ExtraMatchers.h
+++ b/mlir/include/mlir/Query/Matcher/ExtraMatchers.h
@@ -1,4 +1,5 @@
-//===- Matchers.h - Various common matchers ---------------------*- C++ -*-===//
+//===- ExtraMatchers.h - Various common matchers ---------------------*- C++
+//-*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -36,7 +37,6 @@ class BackwardSliceMatcher {
bool matches(Operation *op, SetVector<Operation *> &backwardSlice,
QueryOptions &options, unsigned tempHops) {
- bool validSlice = true;
if (op->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
return false;
}
@@ -56,16 +56,12 @@ class BackwardSliceMatcher {
Operation *parentOp = block->getParentOp();
if (parentOp && backwardSlice.count(parentOp) == 0) {
- if (parentOp->getNumRegions() == 1 &&
- parentOp->getRegion(0).getBlocks().size() == 1) {
- validSlice = false;
- return;
- };
- matches(parentOp, backwardSlice, options, tempHops - 1);
+ assert(parentOp->getNumRegions() == 1 &&
+ parentOp->getRegion(0).getBlocks().size() == 1);
+ matches(parentOp, backwardSlice, options, tempHops-1);
}
} else {
- validSlice = false;
- return;
+ llvm_unreachable("No definingOp and not a block argument.");
}
};
@@ -78,22 +74,13 @@ class BackwardSliceMatcher {
for (OpOperand &operand : op->getOpOperands()) {
if (!descendents.contains(operand.get().getParentRegion()))
processValue(operand.get());
- if (!validSlice)
- return;
}
});
});
}
- llvm::for_each(op->getOperands(), [&](Value operand) {
- processValue(operand);
- if (!validSlice)
- return;
- });
+ llvm::for_each(op->getOperands(), processValue);
backwardSlice.insert(op);
- if (!validSlice) {
- return false;
- }
return true;
}
diff --git a/mlir/include/mlir/Query/Matcher/MatchFinder.h b/mlir/include/mlir/Query/Matcher/MatchFinder.h
index cbdf2fec46a7611..c188bfe75406a7e 100644
--- a/mlir/include/mlir/Query/Matcher/MatchFinder.h
+++ b/mlir/include/mlir/Query/Matcher/MatchFinder.h
@@ -1,43 +1,61 @@
-//===- MatchFinder.h - ------------------------------------------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
//===----------------------------------------------------------------------===//
-//
-// This file contains the MatchFinder class, which is used to find operations
-// that match a given matcher.
-//
+// MatchFinder.h
//===----------------------------------------------------------------------===//
#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H
#define MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H
#include "MatchersInternal.h"
-#include "mlir/IR/Operation.h"
+#include "mlir/Query/QuerySession.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/raw_ostream.h"
namespace mlir::query::matcher {
-// MatchFinder is used to find all operations that match a given matcher.
class MatchFinder {
+private:
+ // Base print function with binding text
+ static void printMatch(llvm::raw_ostream &os, QuerySession &qs,
+ mlir::Operation *op, const std::string &binding) {
+ auto fileLoc = op->getLoc()->findInstanceOf<FileLineColLoc>();
+ auto smloc = qs.getSourceManager().FindLocForLineAndColumn(
+ qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn());
+ qs.getSourceManager().PrintMessage(os, smloc, llvm::SourceMgr::DK_Note,
+ "\"" + binding + "\" binds here");
+ };
+
public:
- // Returns all operations that match the given matcher.
static SetVector<Operation *>
- getMatches(Operation *root, QueryOptions &options, DynMatcher matcher) {
- SetVector<Operation *> backwardSlice;
+ getMatches(Operation *root, QueryOptions &options, DynMatcher matcher,
+ llvm::raw_ostream &os, QuerySession &qs) {
+ unsigned matchCount = 0;
+ SetVector<Operation *> matchedOps;
+ SetVector<Operation *> tempStorage;
+
root->walk([&](Operation *subOp) {
if (matcher.match(subOp)) {
- backwardSlice.insert(subOp);
+ matchedOps.insert(subOp);
+ os << "Match #" << ++matchCount << ":\n\n";
+ printMatch(os, qs, subOp, "root");
} else {
- matcher.match(subOp, backwardSlice, options);
- ////
+ SmallVector<Operation *> printingOps;
+ size_t sizeBefore = matchedOps.size();
+ if (matcher.match(subOp, tempStorage, options)) {
+ os << "Match #" << ++matchCount << ":\n\n";
+ SmallVector<Operation *> printingOps(tempStorage.takeVector());
+ for (auto op : printingOps) {
+ printMatch(os, qs, op, ""); // Using version without binding text
+ matchedOps.insert(op);
+ }
+ printingOps.clear();
+ }
}
});
- return backwardSlice;
+ return matchedOps;
}
};
} // namespace mlir::query::matcher
-#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H
+#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H
\ No newline at end of file
diff --git a/mlir/include/mlir/Query/Matcher/MatchersInternal.h b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
index 081d216b56b61c0..b795d9a291844f8 100644
--- a/mlir/include/mlir/Query/Matcher/MatchersInternal.h
+++ b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
@@ -1,6 +1,6 @@
//===- MatchersInternal.h - Structural query framework ----------*- C++ -*-===//
//
-// Part of the LLVM Project, under the Apache License v2.0 wIth LLVM Exceptions.
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WItH LLVM-exception
//
@@ -11,11 +11,6 @@
#include "mlir/IR/Matchers.h"
#include "llvm/ADT/IntrusiveRefCntPtr.h"
-#include "llvm/ADT/MapVector.h"
-#include <memory>
-#include <stack>
-#include <unordered_set>
-#include <vector>
namespace mlir {
namespace query {
@@ -112,4 +107,4 @@ class DynMatcher {
} // namespace mlir::query::matcher
-#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
\ No newline at end of file
+#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
diff --git a/mlir/include/mlir/Query/Query.h b/mlir/include/mlir/Query/Query.h
index d695f648016be31..89d48773d2c3e6b 100644
--- a/mlir/include/mlir/Query/Query.h
+++ b/mlir/include/mlir/Query/Query.h
@@ -10,7 +10,6 @@
#define MLIR_TOOLS_MLIRQUERY_QUERY_H
#include "Matcher/VariantValue.h"
-#include "mlir/Analysis/SliceAnalysis.h"
#include "llvm/ADT/IntrusiveRefCntPtr.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/LineEditor/LineEditor.h"
diff --git a/mlir/include/mlir/Query/QuerySession.h b/mlir/include/mlir/Query/QuerySession.h
index 7faf206b96124b5..495358e8f36f94e 100644
--- a/mlir/include/mlir/Query/QuerySession.h
+++ b/mlir/include/mlir/Query/QuerySession.h
@@ -9,14 +9,18 @@
#ifndef MLIR_TOOLS_MLIRQUERY_QUERYSESSION_H
#define MLIR_TOOLS_MLIRQUERY_QUERYSESSION_H
+#include "Matcher/VariantValue.h"
#include "mlir/IR/Operation.h"
#include "mlir/Query/Matcher/Registry.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/SourceMgr.h"
+namespace mlir::query::matcher {
+class Registry;
+}
+
namespace mlir::query {
-class Registry;
// Represents the state for a particular mlir-query session.
class QuerySession {
public:
diff --git a/mlir/lib/Query/Matcher/Parser.cpp b/mlir/lib/Query/Matcher/Parser.cpp
index 79797ec45b549b0..726f1188d7e4c8c 100644
--- a/mlir/lib/Query/Matcher/Parser.cpp
+++ b/mlir/lib/Query/Matcher/Parser.cpp
@@ -157,7 +157,6 @@ class Parser::CodeTokenizer {
}
void consumeNumberLiteral(TokenInfo *result) {
- bool isFloatingLiteral = false;
unsigned length = 1;
if (code.size() > 1) {
// Consume the 'x' or 'b' radix modifier, if present.
@@ -170,39 +169,17 @@ class Parser::CodeTokenizer {
while (length < code.size() && isdigit(code[length]))
++length;
- // Try to recognize a floating point literal.
- while (length < code.size()) {
- char c = code[length];
- if (c == '-' || c == '+' || c == '.' || isdigit(c)) {
- isFloatingLiteral = true;
- length++;
- } else {
- break;
- }
- }
-
result->text = code.take_front(length);
code = code.drop_front(length);
- if (isFloatingLiteral) {
- char *end;
- errno = 0;
- std::string text = result->text.str();
- double doubleValue = strtod(text.c_str(), &end);
- if (*end == 0 && errno == 0) {
- result->kind = TokenKind::Literal;
- result->value = static_cast<double>(doubleValue);
- return;
- }
- } else {
- unsigned value;
- if (!result->text.getAsInteger(0, value)) {
- result->kind = TokenKind::Literal;
- result->value = value;
- return;
- }
+ unsigned value;
+ if (!result->text.getAsInteger(0, value)) {
+ result->kind = TokenKind::Literal;
+ result->value = static_cast<unsigned>(value);
+ return;
}
}
+
// Consume a string literal, handle escape sequences and missing closing
// quote.
void consumeStringLiteral(TokenInfo *result) {
diff --git a/mlir/lib/Query/Query.cpp b/mlir/lib/Query/Query.cpp
index 03d5018abf5b012..cbc436299afb701 100644
--- a/mlir/lib/Query/Query.cpp
+++ b/mlir/lib/Query/Query.cpp
@@ -15,8 +15,6 @@
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
-#include <unordered_map>
-#include <unordered_set>
namespace mlir::query {
@@ -29,15 +27,6 @@ complete(llvm::StringRef line, size_t pos, const QuerySession &qs) {
return QueryParser::complete(line, pos, qs);
}
-static void printMatch(llvm::raw_ostream &os, QuerySession &qs,
- mlir::Operation *op, const std::string &binding) {
- auto fileLoc = op->getLoc()->findInstanceOf<FileLineColLoc>();
- auto smloc = qs.getSourceManager().FindLocForLineAndColumn(
- qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn());
- qs.getSourceManager().PrintMessage(os, smloc, llvm::SourceMgr::DK_Note,
- "\"" + binding + "\" binds here");
-}
-
// TODO: Extract into a helper function that can be reused outside query
// context.
static Operation *extractFunction(std::vector<Operation *> &ops,
@@ -150,8 +139,8 @@ LogicalResult MatchQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
QueryOptions options;
parseQueryOptions(qs, options);
- auto matches =
- matcher::MatchFinder().getMatches(rootOp, options, std::move(matcher));
+ auto matches = matcher::MatchFinder().getMatches(rootOp, options,
+ std::move(matcher), os, qs);
// An extract call is recognized by considering if the matcher has a name.
// TODO: Consider making the extract more explicit.
@@ -164,14 +153,6 @@ LogicalResult MatchQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
// return mlir::success();
// }
- os << "\n";
- for (Operation *op : matches) {
- os << "Match #" << ++matchCount << ":\n\n";
- // Placeholder "root" binding for the initial draft.
- printMatch(os, qs, op, "root");
- }
- os << matchCount << (matchCount == 1 ? " match.\n\n" : " matches.\n\n");
-
return mlir::success();
}
diff --git a/mlir/test/mlir-query/complex-test.mlir b/mlir/test/mlir-query/complex-test.mlir
index e491beb5007792b..af6193e4818abfe 100644
--- a/mlir/test/mlir-query/complex-test.mlir
+++ b/mlir/test/mlir-query/complex-test.mlir
@@ -1,23 +1,26 @@
// RUN: mlir-query %s -c "m definedBy(hasOpName(())" | FileCheck %s
+func.func @region_control_flow(%arg: memref<2xf32>, %cond: i1) attributes {test.ptr = "func"} {
+ %0 = memref.alloca() {test.ptr = "alloca_1"} : memref<8x64xf32>
+ %1 = memref.alloca() {test.ptr = "alloca_2"} : memref<8x64xf32>
+ %2 = memref.alloc() {test.ptr = "alloc_1"} : memref<8x64xf32>
-func.func @matrix_multiply(%A: memref<4x4xf32>, %B: memref<4x4xf32>, %C: memref<4x4xf32>) {
- %c0 = arith.constant 0 : index
- %c4 = arith.constant 4 : index
- %c1 = arith.constant 1 : index
+ %3 = scf.if %cond -> (memref<8x64xf32>) {
+ scf.yield %0 : memref<8x64xf32>
+ } else {
+ scf.yield %0 : memref<8x64xf32>
+ } {test.ptr = "if_alloca"}
- scf.for %i = %c0 to %c4 step %c1 {
- scf.for %j = %c0 to %c4 step %c1 {
- %sum_init = arith.constant 0.0 : f32
- %sum = scf.for %k = %c0 to %c4 step %c1 iter_args(%acc = %sum_init) -> (f32) {
- %a_ik = memref.load %A[%i, %k] : memref<4x4xf32>
- %b_kj = memref.load %B[%k, %j] : memref<4x4xf32>
- %prod = arith.mulf %a_ik, %b_kj : f32
- %new_acc = arith.addf %acc, %prod : f32
- scf.yield %new_acc : f32
- }
- memref.store %sum, %C[%i, %j] : memref<4x4xf32>
- }
- }
+ %4 = scf.if %cond -> (memref<8x64xf32>) {
+ scf.yield %0 : memref<8x64xf32>
+ } else {
+ scf.yield %1 : memref<8x64xf32>
+ } {test.ptr = "if_alloca_merge"}
+
+ %5 = scf.if %cond -> (memref<8x64xf32>) {
+ scf.yield %2 : memref<8x64xf32>
+ } else {
+ scf.yield %2 : memref<8x64xf32>
+ } {test.ptr = "if_alloc"}
return
-}
\ No newline at end of file
+}
diff --git a/mlir/test/mlir-query/function-extraction.mlir b/mlir/test/mlir-query/function-extraction.mlir
index 473c30aee7123ae..ecbd17ab59dfb17 100644
--- a/mlir/test/mlir-query/function-extraction.mlir
+++ b/mlir/test/mlir-query/function-extraction.mlir
@@ -4,7 +4,7 @@
// CHECK: %[[MUL0:.*]] = arith.mulf {{.*}} : f32
// CHECK: %[[MUL1:.*]] = arith.mulf {{.*}}, %[[MUL0]] : f32
// CHECK: %[[MUL2:.*]] = arith.mulf {{.*}} : f32
-// CHECK-NEXT: return %[[MUL0]], %[[MUL1]], %[[MUL2]] : f32, f32, f32
+// CHECK-NEXT: return %[[MUL0]], %[[MUL1]], %[[MUL2]] : f32, f32, f32
func.func @mixedOperations(%a: f32, %b: f32, %c: f32) -> f32 {
%sum0 = arith.addf %a, %b : f32
>From 5d0bf4fbc53543091082bfe747f1a492e4f70a11 Mon Sep 17 00:00:00 2001
From: Denzel-Brian Budii <denzel-brian.budii at intel.com>
Date: Sun, 19 Jan 2025 12:42:43 +0000
Subject: [PATCH 4/5] added info in header files & fixed newline
---
mlir/include/mlir/Query/Matcher/MatchFinder.h | 10 ++++++++++
.../mlir/Query/Matcher/MatchersInternal.h | 16 ++++++++++++++--
mlir/lib/Query/Query.cpp | 2 +-
3 files changed, 25 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Query/Matcher/MatchFinder.h b/mlir/include/mlir/Query/Matcher/MatchFinder.h
index c188bfe75406a7e..f957cff9d038bc6 100644
--- a/mlir/include/mlir/Query/Matcher/MatchFinder.h
+++ b/mlir/include/mlir/Query/Matcher/MatchFinder.h
@@ -1,4 +1,14 @@
+//===- MatchFinder.h - ------------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
//===----------------------------------------------------------------------===//
+//
+// This file contains the MatchFinder class, which is used to find operations
+// that match a given matcher and print them.
+//
// MatchFinder.h
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Query/Matcher/MatchersInternal.h b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
index b795d9a291844f8..3f077bc44bd6771 100644
--- a/mlir/include/mlir/Query/Matcher/MatchersInternal.h
+++ b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
@@ -1,11 +1,23 @@
-//===- MatchersInternal.h - Structural query framework ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implements the base layer of the matcher framework.
+//
+// Matchers are methods that return a Matcher which provides a method
+// match(Operation *op)
+// match(Operation *op, SetVector<Operation *> &matchedOps, QueryOptions
+// &options)
+//
+// The matcher functions are defined in include/mlir/IR/Matchers.h.
+// This file contains the wrapper classes needed to construct matchers for
+// mlir-query.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
-
#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
#define MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
diff --git a/mlir/lib/Query/Query.cpp b/mlir/lib/Query/Query.cpp
index cbc436299afb701..960fe4a3775455e 100644
--- a/mlir/lib/Query/Query.cpp
+++ b/mlir/lib/Query/Query.cpp
@@ -156,4 +156,4 @@ LogicalResult MatchQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
return mlir::success();
}
-} // namespace mlir::query
\ No newline at end of file
+} // namespace mlir::query
>From bafbd37932158e744350846f06c94287f8d3aa4c Mon Sep 17 00:00:00 2001
From: Denzel-Brian Budii <denzel-brian.budii at intel.com>
Date: Sun, 19 Jan 2025 13:09:35 +0000
Subject: [PATCH 5/5] removed redundant code
---
mlir/include/mlir/Query/Matcher/MatchFinder.h | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Query/Matcher/MatchFinder.h b/mlir/include/mlir/Query/Matcher/MatchFinder.h
index f957cff9d038bc6..6d8214bd42ca69d 100644
--- a/mlir/include/mlir/Query/Matcher/MatchFinder.h
+++ b/mlir/include/mlir/Query/Matcher/MatchFinder.h
@@ -50,12 +50,12 @@ class MatchFinder {
printMatch(os, qs, subOp, "root");
} else {
SmallVector<Operation *> printingOps;
- size_t sizeBefore = matchedOps.size();
+
if (matcher.match(subOp, tempStorage, options)) {
os << "Match #" << ++matchCount << ":\n\n";
SmallVector<Operation *> printingOps(tempStorage.takeVector());
for (auto op : printingOps) {
- printMatch(os, qs, op, ""); // Using version without binding text
+ printMatch(os, qs, op, "");
matchedOps.insert(op);
}
printingOps.clear();
More information about the Mlir-commits
mailing list