[Mlir-commits] [mlir] 02d9f4d - [mlir][mlir-query] Introduce mlir-query tool with autocomplete support
Jacques Pienaar
llvmlistbot at llvm.org
Fri Oct 13 14:03:35 PDT 2023
Author: Devajith
Date: 2023-10-13T14:03:27-07:00
New Revision: 02d9f4d1f128e17e04ab6e602d3c9b9942612428
URL: https://github.com/llvm/llvm-project/commit/02d9f4d1f128e17e04ab6e602d3c9b9942612428
DIFF: https://github.com/llvm/llvm-project/commit/02d9f4d1f128e17e04ab6e602d3c9b9942612428.diff
LOG: [mlir][mlir-query] Introduce mlir-query tool with autocomplete support
This commit adds the initial version of the mlir-query tool, which leverages the pre-existing matchers defined in mlir/include/mlir/IR/Matchers.h
The tool provides the following set of basic queries:
hasOpAttrName(string)
hasOpName(string)
isConstantOp()
isNegInfFloat()
isNegZeroFloat()
isNonZero()
isOne()
isOneFloat()
isPosInfFloat()
isPosZeroFloat()
isZero()
isZeroFloat()
Reviewed By: jpienaar
Differential Revision: https://reviews.llvm.org/D155127
Added:
mlir/include/mlir/Query/Matcher/ErrorBuilder.h
mlir/include/mlir/Query/Matcher/Marshallers.h
mlir/include/mlir/Query/Matcher/MatchFinder.h
mlir/include/mlir/Query/Matcher/MatchersInternal.h
mlir/include/mlir/Query/Matcher/Registry.h
mlir/include/mlir/Query/Matcher/VariantValue.h
mlir/include/mlir/Query/Query.h
mlir/include/mlir/Query/QuerySession.h
mlir/include/mlir/Tools/mlir-query/MlirQueryMain.h
mlir/lib/Query/CMakeLists.txt
mlir/lib/Query/Matcher/CMakeLists.txt
mlir/lib/Query/Matcher/Diagnostics.cpp
mlir/lib/Query/Matcher/Diagnostics.h
mlir/lib/Query/Matcher/ErrorBuilder.cpp
mlir/lib/Query/Matcher/Parser.cpp
mlir/lib/Query/Matcher/Parser.h
mlir/lib/Query/Matcher/RegistryManager.cpp
mlir/lib/Query/Matcher/RegistryManager.h
mlir/lib/Query/Matcher/VariantValue.cpp
mlir/lib/Query/Query.cpp
mlir/lib/Query/QueryParser.cpp
mlir/lib/Query/QueryParser.h
mlir/lib/Tools/mlir-query/CMakeLists.txt
mlir/lib/Tools/mlir-query/MlirQueryMain.cpp
mlir/test/mlir-query/simple-test.mlir
mlir/tools/mlir-query/CMakeLists.txt
mlir/tools/mlir-query/mlir-query.cpp
Modified:
mlir/lib/CMakeLists.txt
mlir/lib/Tools/CMakeLists.txt
mlir/test/CMakeLists.txt
mlir/tools/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Query/Matcher/ErrorBuilder.h b/mlir/include/mlir/Query/Matcher/ErrorBuilder.h
new file mode 100644
index 000000000000000..1073daed8703f5a
--- /dev/null
+++ b/mlir/include/mlir/Query/Matcher/ErrorBuilder.h
@@ -0,0 +1,63 @@
+//===--- ErrorBuilder.h - Helper for building error messages ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// ErrorBuilder to manage error messages.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_ERRORBUILDER_H
+#define MLIR_TOOLS_MLIRQUERY_MATCHER_ERRORBUILDER_H
+
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/Twine.h"
+#include <initializer_list>
+
+namespace mlir::query::matcher::internal {
+class Diagnostics;
+
+// A position inside a source query, expressed as a line and a column. Both
+// fields are zero unless explicitly set.
+struct SourceLocation {
+  unsigned line = 0;
+  unsigned column = 0;
+};
+
+// A span inside a source query, delimited by a `start` and an `end` location.
+// Both endpoints are value-initialized by default.
+struct SourceRange {
+  SourceLocation start = {};
+  SourceLocation end = {};
+};
+
+// All error categories that can be reported while parsing a query or while
+// building a matcher from the registry. Each enumerator identifies one kind
+// of diagnostic; `None` means "no error".
+enum class ErrorType {
+ None,
+
+ // Parser errors: problems found while tokenizing/parsing the query text.
+ ParserFailedToBuildMatcher,
+ ParserInvalidToken,
+ ParserNoCloseParen,
+ ParserNoCode,
+ ParserNoComma,
+ ParserNoOpenParen,
+ ParserNotAMatcher,
+ ParserOverloadedType,
+ ParserStringError,
+ ParserTrailingCode,
+
+ // Registry errors: problems found while looking up a matcher or value by
+ // name, or while checking the arguments supplied to a matcher.
+ RegistryMatcherNotFound,
+ RegistryValueNotFound,
+ RegistryWrongArgCount,
+ RegistryWrongArgType
+};
+
+// Reports an error of kind `errorType` located at `range` to `error`.
+// `errorTexts` carries the textual arguments for the message.
+// NOTE(review): the definition lives in ErrorBuilder.cpp (not visible here);
+// presumably each text becomes one positional argument of the message, in
+// order -- confirm against the implementation.
+void addError(Diagnostics *error, SourceRange range, ErrorType errorType,
+ std::initializer_list<llvm::Twine> errorTexts);
+
+} // namespace mlir::query::matcher::internal
+
+#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_ERRORBUILDER_H
diff --git a/mlir/include/mlir/Query/Matcher/Marshallers.h b/mlir/include/mlir/Query/Matcher/Marshallers.h
new file mode 100644
index 000000000000000..6ed35ac0ddccc70
--- /dev/null
+++ b/mlir/include/mlir/Query/Matcher/Marshallers.h
@@ -0,0 +1,199 @@
+//===--- Marshallers.h - Generic matcher function marshallers ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains function templates and classes to wrap matcher construct
+// functions. It provides a collection of function templates and classes that
+// present a generic marshalling layer on top of matcher construct functions.
+// The registry uses these to export all marshaller constructors with a uniform
+// interface. This mechanism takes inspiration from clang-query.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H
+#define MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H
+
+#include "ErrorBuilder.h"
+#include "VariantValue.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/StringRef.h"
+
+namespace mlir::query::matcher::internal {
+
+// Helper template class for jumping from argument type to the correct is/get
+// functions in VariantValue. This is used for verifying and extracting the
+// matcher arguments.
+//
+// The primary template is only declared; each supported argument type provides
+// an explicit specialization.
+template <class T>
+struct ArgTypeTraits;
+// Arguments taken by const reference reuse the traits of the referenced type.
+template <class T>
+struct ArgTypeTraits<const T &> : public ArgTypeTraits<T> {};
+
+// Traits for matcher arguments of type llvm::StringRef.
+template <>
+struct ArgTypeTraits<llvm::StringRef> {
+
+  // String arguments are advertised as ArgKind::String.
+  static ArgKind getKind() { return ArgKind::String; }
+
+  // True when `value` currently holds a string.
+  static bool hasCorrectType(const VariantValue &value) {
+    const bool holdsString = value.isString();
+    return holdsString;
+  }
+
+  // Returns the string held by `value`. Callers are expected to have checked
+  // hasCorrectType() first.
+  static const llvm::StringRef &get(const VariantValue &value) {
+    return value.getString();
+  }
+
+  // No alternative spelling is ever suggested for string arguments.
+  static std::optional<std::string> getBestGuess(const VariantValue &) {
+    return std::optional<std::string>();
+  }
+};
+
+// Traits for matcher arguments of type DynMatcher.
+template <>
+struct ArgTypeTraits<DynMatcher> {
+
+  // Matcher arguments are advertised as ArgKind::Matcher.
+  static ArgKind getKind() { return ArgKind::Matcher; }
+
+  // True when `value` currently holds a matcher.
+  static bool hasCorrectType(const VariantValue &value) {
+    const bool holdsMatcher = value.isMatcher();
+    return holdsMatcher;
+  }
+
+  // Extracts the DynMatcher held by `value`.
+  // NOTE(review): the optional returned by getDynMatcher() is dereferenced
+  // without a check, so this assumes the held VariantMatcher is not null --
+  // confirm that callers can never pass a null matcher here.
+  static DynMatcher get(const VariantValue &value) {
+    const VariantMatcher &variant = value.getMatcher();
+    return *variant.getDynMatcher();
+  }
+
+  // No alternative spelling is ever suggested for matcher arguments.
+  static std::optional<std::string> getBestGuess(const VariantValue &) {
+    return std::optional<std::string>();
+  }
+};
+
+// Interface for generic matcher descriptor.
+// Offers a create() method that constructs the matcher from the provided
+// arguments, plus introspection of the arguments the matcher accepts.
+class MatcherDescriptor {
+public:
+ virtual ~MatcherDescriptor() = default;
+
+ // Builds a matcher from `args`. `nameRange` is the source range of the
+ // matcher name, and `error` is the sink for diagnostics.
+ // NOTE(review): the visible implementation returns a default-constructed
+ // (null) VariantMatcher on failure; presumably all implementations do.
+ virtual VariantMatcher create(SourceRange nameRange,
+ const llvm::ArrayRef<ParserValue> args,
+ Diagnostics *error) const = 0;
+
+ // Returns the number of arguments accepted by the matcher.
+ virtual unsigned getNumArgs() const = 0;
+
+ // Append the set of argument types accepted for argument 'argNo' to
+ // 'argKinds'.
+ virtual void getArgKinds(unsigned argNo,
+ std::vector<ArgKind> &argKinds) const = 0;
+};
+
+// MatcherDescriptor for matcher construct functions that take a fixed number
+// of arguments. The type-erased construct function is stored together with a
+// marshaller that knows how to unpack the parsed arguments and call it.
+class FixedArgCountMatcherDescriptor : public MatcherDescriptor {
+public:
+  // Signature of the marshaller: receives the type-erased construct function,
+  // the matcher name (for diagnostics), the source range of the name, the
+  // parsed arguments and the diagnostics sink.
+  using MarshallerType = VariantMatcher (*)(void (*matcherFunc)(),
+                                            llvm::StringRef matcherName,
+                                            SourceRange nameRange,
+                                            llvm::ArrayRef<ParserValue> args,
+                                            Diagnostics *error);
+
+  // Marshaller Function to unpack the arguments and call Func. Func is the
+  // Matcher construct function. This is the function that the matcher
+  // expressions would use to create the matcher.
+  //
+  // `matcherName` is copied into an owned string: descriptors outlive the
+  // registration call, and the caller's name is frequently a temporary, so
+  // keeping only a StringRef would leave a dangling reference.
+  FixedArgCountMatcherDescriptor(MarshallerType marshaller,
+                                 void (*matcherFunc)(),
+                                 llvm::StringRef matcherName,
+                                 llvm::ArrayRef<ArgKind> argKinds)
+      : marshaller(marshaller), matcherFunc(matcherFunc),
+        matcherName(matcherName.str()),
+        argKinds(argKinds.begin(), argKinds.end()) {}
+
+  // Delegates to the marshaller, which validates `args` and invokes the
+  // construct function.
+  VariantMatcher create(SourceRange nameRange, llvm::ArrayRef<ParserValue> args,
+                        Diagnostics *error) const override {
+    return marshaller(matcherFunc, matcherName, nameRange, args, error);
+  }
+
+  // The argument count equals the number of recorded argument kinds.
+  unsigned getNumArgs() const override { return argKinds.size(); }
+
+  // Appends the single kind accepted at position `argNo`. `argNo` must be
+  // smaller than getNumArgs(); it is not range-checked here.
+  void getArgKinds(unsigned argNo, std::vector<ArgKind> &kinds) const override {
+    kinds.push_back(argKinds[argNo]);
+  }
+
+private:
+  const MarshallerType marshaller;
+  void (*const matcherFunc)();
+  // Owned copy of the matcher name (see the constructor comment).
+  const std::string matcherName;
+  const std::vector<ArgKind> argKinds;
+};
+
+// Verifies that exactly `expectedArgCount` arguments were supplied. On a
+// mismatch a RegistryWrongArgCount error (expected count, then actual count)
+// is reported at `nameRange` and false is returned.
+inline bool checkArgCount(SourceRange nameRange, size_t expectedArgCount,
+                          llvm::ArrayRef<ParserValue> args,
+                          Diagnostics *error) {
+  const size_t actualArgCount = args.size();
+  if (actualArgCount == expectedArgCount)
+    return true;
+
+  addError(error, nameRange, ErrorType::RegistryWrongArgCount,
+           {llvm::Twine(expectedArgCount), llvm::Twine(actualArgCount)});
+  return false;
+}
+
+// Verifies that the argument at position `Index` holds a value of the kind
+// required by `ArgType`. On a mismatch a RegistryWrongArgType error is
+// reported at the argument's source range and false is returned.
+//
+// The error arguments are supplied in the order expected by the
+// RegistryWrongArgType message: 1-based argument position, expected kind,
+// actual type of the supplied value.
+template <typename ArgType, size_t Index>
+inline bool checkArgTypeAtIndex(llvm::StringRef matcherName,
+                                llvm::ArrayRef<ParserValue> args,
+                                Diagnostics *error) {
+  // The matcher name is not part of the RegistryWrongArgType message; the
+  // parameter is kept so existing callers do not need to change.
+  (void)matcherName;
+
+  const ParserValue &arg = args[Index];
+  if (ArgTypeTraits<ArgType>::hasCorrectType(arg.value))
+    return true;
+
+  const llvm::StringRef expectedKind =
+      ArgTypeTraits<ArgType>::getKind() == ArgKind::String ? "String"
+                                                             : "Matcher";
+  // Keep the actual type string alive for the duration of the addError call;
+  // Twine only references its operands.
+  const std::string actualType = arg.value.getTypeAsString();
+  addError(error, arg.range, ErrorType::RegistryWrongArgType,
+           {llvm::Twine(Index + 1), llvm::Twine(expectedKind),
+            llvm::Twine(actualType)});
+  return false;
+}
+
+// Marshaller for construct functions with a fixed argument list. Validates
+// the number and the kinds of the parsed arguments, then casts `matcherFunc`
+// back to its real signature, calls it with the unpacked arguments and wraps
+// the result in a VariantMatcher. A null VariantMatcher is returned when
+// validation fails (the error has been reported through `error`).
+template <typename ReturnType, typename... ArgTypes, size_t... Is>
+static VariantMatcher
+matcherMarshallFixedImpl(void (*matcherFunc)(), llvm::StringRef matcherName,
+                         SourceRange nameRange,
+                         llvm::ArrayRef<ParserValue> args, Diagnostics *error,
+                         std::index_sequence<Is...>) {
+  using FuncType = ReturnType (*)(ArgTypes...);
+
+  // The number of supplied arguments has to match the function's arity.
+  const bool argCountOk =
+      checkArgCount(nameRange, sizeof...(ArgTypes), args, error);
+  if (!argCountOk)
+    return VariantMatcher();
+
+  // Every argument has to hold the kind its parameter requires. The fold
+  // stops at the first mismatch, so only one error is reported.
+  const bool argTypesOk =
+      (... && checkArgTypeAtIndex<ArgTypes, Is>(matcherName, args, error));
+  if (!argTypesOk)
+    return VariantMatcher();
+
+  ReturnType matcherFn = reinterpret_cast<FuncType>(matcherFunc)(
+      ArgTypeTraits<ArgTypes>::get(args[Is].value)...);
+  return VariantMatcher::SingleMatcher(
+      *DynMatcher::constructDynMatcherFromMatcherFn(matcherFn));
+}
+
+// Entry point with the MarshallerType signature: generates the index sequence
+// for `ArgTypes` and forwards everything to matcherMarshallFixedImpl.
+template <typename ReturnType, typename... ArgTypes>
+static VariantMatcher
+matcherMarshallFixed(void (*matcherFunc)(), llvm::StringRef matcherName,
+                     SourceRange nameRange, llvm::ArrayRef<ParserValue> args,
+                     Diagnostics *error) {
+  using ArgIndices = std::index_sequence_for<ArgTypes...>;
+  return matcherMarshallFixedImpl<ReturnType, ArgTypes...>(
+      matcherFunc, matcherName, nameRange, args, error, ArgIndices{});
+}
+
+// Builds a descriptor for a construct function with a fixed number of
+// arguments. The function pointer is type-erased to `void (*)()`; the
+// matching marshaller instantiation restores the real signature when the
+// matcher is created.
+template <typename ReturnType, typename... ArgTypes>
+std::unique_ptr<MatcherDescriptor>
+makeMatcherAutoMarshall(ReturnType (*matcherFunc)(ArgTypes...),
+                        llvm::StringRef matcherName) {
+  // One ArgKind per parameter, in declaration order.
+  const std::vector<ArgKind> argKinds = {ArgTypeTraits<ArgTypes>::getKind()...};
+
+  FixedArgCountMatcherDescriptor::MarshallerType marshaller =
+      matcherMarshallFixed<ReturnType, ArgTypes...>;
+  void (*erasedFunc)() = reinterpret_cast<void (*)()>(matcherFunc);
+
+  return std::make_unique<FixedArgCountMatcherDescriptor>(
+      marshaller, erasedFunc, matcherName, argKinds);
+}
+
+} // namespace mlir::query::matcher::internal
+
+#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H
diff --git a/mlir/include/mlir/Query/Matcher/MatchFinder.h b/mlir/include/mlir/Query/Matcher/MatchFinder.h
new file mode 100644
index 000000000000000..b008a21f53ae2a6
--- /dev/null
+++ b/mlir/include/mlir/Query/Matcher/MatchFinder.h
@@ -0,0 +1,41 @@
+//===- MatchFinder.h - ------------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the MatchFinder class, which is used to find operations
+// that match a given matcher.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H
+#define MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H
+
+#include "MatchersInternal.h"
+
+namespace mlir::query::matcher {
+
+// MatchFinder collects every operation accepted by a given matcher.
+class MatchFinder {
+public:
+  // Walks `root` and returns the visited operations for which `matcher`
+  // reports a match, in the order the walk visits them.
+  static std::vector<Operation *> getMatches(Operation *root,
+                                             DynMatcher matcher) {
+    std::vector<Operation *> found;
+
+    root->walk([&found, &matcher](Operation *candidate) {
+      if (!matcher.match(candidate))
+        return;
+      found.push_back(candidate);
+    });
+
+    return found;
+  }
+};
+
+} // namespace mlir::query::matcher
+
+#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H
diff --git a/mlir/include/mlir/Query/Matcher/MatchersInternal.h b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
new file mode 100644
index 000000000000000..67455be592393b4
--- /dev/null
+++ b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
@@ -0,0 +1,72 @@
+//===- MatchersInternal.h - Structural query framework ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implements the base layer of the matcher framework.
+//
+// Matchers are methods that return a Matcher which provides a method
+// match(Operation *op)
+//
+// The matcher functions are defined in include/mlir/IR/Matchers.h.
+// This file contains the wrapper classes needed to construct matchers for
+// mlir-query.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
+#define MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
+
+#include "mlir/IR/Matchers.h"
+#include "llvm/ADT/IntrusiveRefCntPtr.h"
+
+namespace mlir::query::matcher {
+
+// Generic interface for matchers on an MLIR operation. Instances are
+// reference counted through llvm::ThreadSafeRefCountedBase so that they can
+// be held by llvm::IntrusiveRefCntPtr.
+class MatcherInterface
+ : public llvm::ThreadSafeRefCountedBase<MatcherInterface> {
+public:
+ virtual ~MatcherInterface() = default;
+
+ // Returns true if `op` is accepted by this matcher.
+ virtual bool match(Operation *op) = 0;
+};
+
+// Adapts a matcher function object (anything exposing `match(Operation *)`)
+// to MatcherInterface. The function object is copied into the adapter.
+template <typename MatcherFn>
+class MatcherFnImpl : public MatcherInterface {
+private:
+  MatcherFn matcherFn;
+
+public:
+  MatcherFnImpl(MatcherFn &matcherFn) : matcherFn(matcherFn) {}
+
+  // Forwards to the wrapped function object.
+  bool match(Operation *op) override {
+    const bool matched = matcherFn.match(op);
+    return matched;
+  }
+};
+
+// DynMatcher holds a reference-counted MatcherInterface implementation and
+// exposes a match() method that forwards to it.
+class DynMatcher {
+public:
+  // Takes ownership of the provided implementation pointer.
+  DynMatcher(MatcherInterface *implementation)
+      : implementation(implementation) {}
+
+  // Wraps a copy of `matcherFn` in a MatcherFnImpl and returns a DynMatcher
+  // that owns it.
+  template <typename MatcherFn>
+  static std::unique_ptr<DynMatcher>
+  constructDynMatcherFromMatcherFn(MatcherFn &matcherFn) {
+    using ImplType = MatcherFnImpl<MatcherFn>;
+    std::unique_ptr<ImplType> owned = std::make_unique<ImplType>(matcherFn);
+    // Ownership moves from the unique_ptr to the DynMatcher being created.
+    ImplType *released = owned.release();
+    return std::make_unique<DynMatcher>(released);
+  }
+
+  // Returns true if the wrapped implementation accepts `op`.
+  bool match(Operation *op) const {
+    MatcherInterface &impl = *implementation;
+    return impl.match(op);
+  }
+
+private:
+  llvm::IntrusiveRefCntPtr<MatcherInterface> implementation;
+};
+
+} // namespace mlir::query::matcher
+
+#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
diff --git a/mlir/include/mlir/Query/Matcher/Registry.h b/mlir/include/mlir/Query/Matcher/Registry.h
new file mode 100644
index 000000000000000..e929b4a04d151db
--- /dev/null
+++ b/mlir/include/mlir/Query/Matcher/Registry.h
@@ -0,0 +1,51 @@
+//===--- Registry.h - Matcher Registry --------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Registry class to manage the registry of matchers using a map.
+//
+// This class provides a convenient interface for registering and accessing
+// matcher constructors using a string-based map.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_REGISTRY_H
+#define MLIR_TOOLS_MLIRQUERY_MATCHER_REGISTRY_H
+
+#include "Marshallers.h"
+#include "llvm/ADT/StringMap.h"
+#include <string>
+
+namespace mlir::query::matcher {
+
+// Maps a matcher name to the descriptor able to construct that matcher. The
+// map owns its descriptors.
+using ConstructorMap =
+ llvm::StringMap<std::unique_ptr<const internal::MatcherDescriptor>>;
+
+// Registry of the matchers that can be referenced by name in a query.
+class Registry {
+public:
+  Registry() = default;
+  ~Registry() = default;
+
+  // Read-only access to the name -> descriptor map.
+  const ConstructorMap &constructors() const { return constructorMap; }
+
+  // Registers the construct function `matcher` under `name`. A marshalling
+  // descriptor is generated automatically from the function's signature.
+  template <typename MatcherType>
+  void registerMatcher(const std::string &name, MatcherType matcher) {
+    std::unique_ptr<internal::MatcherDescriptor> descriptor =
+        internal::makeMatcherAutoMarshall(matcher, name);
+    registerMatcherDescriptor(name, std::move(descriptor));
+  }
+
+private:
+  // Stores `callback` under `matcherName`; defined out of line.
+  void registerMatcherDescriptor(
+      llvm::StringRef matcherName,
+      std::unique_ptr<internal::MatcherDescriptor> callback);
+
+  ConstructorMap constructorMap;
+};
+
+} // namespace mlir::query::matcher
+
+#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_REGISTRY_H
diff --git a/mlir/include/mlir/Query/Matcher/VariantValue.h b/mlir/include/mlir/Query/Matcher/VariantValue.h
new file mode 100644
index 000000000000000..449f8b3a01e0217
--- /dev/null
+++ b/mlir/include/mlir/Query/Matcher/VariantValue.h
@@ -0,0 +1,128 @@
+//===--- VariantValue.h -----------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Supports all the types required for dynamic Matcher construction.
+// Used by the registry to construct matchers in a generic way.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_VARIANTVALUE_H
+#define MLIR_TOOLS_MLIRQUERY_MATCHER_VARIANTVALUE_H
+
+#include "ErrorBuilder.h"
+#include "MatchersInternal.h"
+#include "llvm/ADT/StringRef.h"
+
+namespace mlir::query::matcher {
+
+// All types that VariantValue can contain. Also used by the marshalling layer
+// to describe the kind of argument a matcher parameter expects.
+enum class ArgKind { Matcher, String };
+
+// A variant matcher object to abstract simple and complex matchers into a
+// single object type. The matcher is "null" while it holds no payload.
+class VariantMatcher {
+ // Declared only; no definition or use is visible in this header.
+ class MatcherOps;
+
+ // Payload interface to be specialized by each matcher type. It follows a
+ // similar interface as VariantMatcher itself.
+ class Payload {
+ public:
+ virtual ~Payload();
+ virtual std::optional<DynMatcher> getDynMatcher() const = 0;
+ virtual std::string getTypeAsString() const = 0;
+ };
+
+public:
+ // A null matcher.
+ VariantMatcher();
+
+ // Clones the provided matcher.
+ static VariantMatcher SingleMatcher(DynMatcher matcher);
+
+ // Makes the matcher the "null" matcher.
+ void reset();
+
+ // Checks if the matcher is null, i.e. holds no payload.
+ bool isNull() const { return !value; }
+
+ // Returns the matcher.
+ // NOTE(review): presumably std::nullopt when the matcher is null -- the
+ // definition is in VariantValue.cpp; confirm there.
+ std::optional<DynMatcher> getDynMatcher() const;
+
+ // String representation of the type of the value.
+ std::string getTypeAsString() const;
+
+private:
+ // Wraps an existing payload; used by the factory functions.
+ explicit VariantMatcher(std::shared_ptr<Payload> value)
+ : value(std::move(value)) {}
+
+ // Payload type declared here and defined out of line.
+ class SinglePayload;
+
+ // Shared, immutable payload; empty for the null matcher.
+ std::shared_ptr<const Payload> value;
+};
+
+// Variant value class with a tagged union with value type semantics. It is used
+// by the registry as the return value and argument type for the matcher factory
+// methods. It can be constructed from any of the supported types:
+// - StringRef
+// - VariantMatcher
+class VariantValue {
+public:
+ // Creates an empty value (ValueType::Nothing).
+ VariantValue() : type(ValueType::Nothing) {}
+
+ // Copy operations and destructor are user-defined because the value is
+ // stored behind raw pointers in the union below.
+ VariantValue(const VariantValue &other);
+ ~VariantValue();
+ VariantValue &operator=(const VariantValue &other);
+
+ // Specific constructors for each supported type.
+ VariantValue(const llvm::StringRef string);
+ VariantValue(const VariantMatcher &matcher);
+
+ // String value functions.
+ bool isString() const;
+ const llvm::StringRef &getString() const;
+ void setString(const llvm::StringRef &string);
+
+ // Matcher value functions.
+ bool isMatcher() const;
+ const VariantMatcher &getMatcher() const;
+ void setMatcher(const VariantMatcher &matcher);
+
+ // String representation of the type of the value.
+ std::string getTypeAsString() const;
+
+private:
+ // NOTE(review): presumably releases the held value and returns to
+ // ValueType::Nothing -- defined in VariantValue.cpp; confirm there.
+ void reset();
+
+ // Tag identifying which kind of value, if any, is currently held.
+ enum class ValueType {
+ Nothing,
+ String,
+ Matcher,
+ };
+
+ // Pointer storage for the held value, one member per supported type.
+ // NOTE(review): presumably `type` selects the valid member -- confirm.
+ union AllValues {
+ llvm::StringRef *String;
+ VariantMatcher *Matcher;
+ };
+
+ ValueType type;
+ AllValues value;
+};
+
+// A VariantValue instance annotated with its parser context.
+struct ParserValue {
+ ParserValue() {}
+ // Text associated with the value.
+ // NOTE(review): presumably the source text it was parsed from -- confirm.
+ llvm::StringRef text;
+ // Source range associated with the value; used when reporting argument
+ // errors.
+ internal::SourceRange range;
+ // The value itself.
+ VariantValue value;
+};
+
+} // namespace mlir::query::matcher
+
+#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_VARIANTVALUE_H
diff --git a/mlir/include/mlir/Query/Query.h b/mlir/include/mlir/Query/Query.h
new file mode 100644
index 000000000000000..447fc7ca21c8da4
--- /dev/null
+++ b/mlir/include/mlir/Query/Query.h
@@ -0,0 +1,109 @@
+//===--- Query.h ------------------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TOOLS_MLIRQUERY_QUERY_H
+#define MLIR_TOOLS_MLIRQUERY_QUERY_H
+
+#include "Matcher/VariantValue.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/IntrusiveRefCntPtr.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/LineEditor/LineEditor.h"
+#include <string>
+
+namespace mlir::query {
+
+// Discriminator for the concrete Query subclasses; each subclass's classof()
+// tests for the enumerator of the same name.
+enum class QueryKind { Invalid, NoOp, Help, Match, Quit };
+
+class QuerySession;
+
+// Base class of all queries. Instances are reference counted (see QueryRef)
+// and carry a QueryKind tag used by the subclasses' classof() functions.
+struct Query : llvm::RefCountedBase<Query> {
+ Query(QueryKind kind) : kind(kind) {}
+ virtual ~Query();
+
+ // Perform the query on qs and print output to os.
+ virtual mlir::LogicalResult run(llvm::raw_ostream &os,
+ QuerySession &qs) const = 0;
+
+ // NOTE(review): presumably the part of the input line left unparsed after
+ // this query -- set by the query parser, which is not visible here.
+ llvm::StringRef remainingContent;
+ // Identifies the concrete query type.
+ const QueryKind kind;
+};
+
+// Reference-counted handle to a Query.
+using QueryRef = llvm::IntrusiveRefCntPtr<Query>;
+
+// Parses `line` into a query, using the state of `qs`.
+// NOTE(review): presumably yields an InvalidQuery on a parse error -- the
+// definition is not visible here; confirm.
+QueryRef parse(llvm::StringRef line, const QuerySession &qs);
+
+// Computes line-editor completions for `line` with the cursor at `pos`,
+// using the state of `qs`.
+std::vector<llvm::LineEditor::Completion>
+complete(llvm::StringRef line, size_t pos, const QuerySession &qs);
+
+// Any query which resulted in a parse error. The error message is in ErrStr.
+struct InvalidQuery : Query {
+ InvalidQuery(const llvm::Twine &errStr)
+ : Query(QueryKind::Invalid), errStr(errStr.str()) {}
+ mlir::LogicalResult run(llvm::raw_ostream &os,
+ QuerySession &qs) const override;
+
+ std::string errStr;
+
+ static bool classof(const Query *query) {
+ return query->kind == QueryKind::Invalid;
+ }
+};
+
+// No-op query (i.e. a blank line).
+struct NoOpQuery : Query {
+  // Type test used by LLVM-style casts: true for queries tagged NoOp.
+  static bool classof(const Query *query) {
+    return QueryKind::NoOp == query->kind;
+  }
+
+  NoOpQuery() : Query(QueryKind::NoOp) {}
+
+  mlir::LogicalResult run(llvm::raw_ostream &os,
+                          QuerySession &qs) const override;
+};
+
+// Query for "help".
+struct HelpQuery : Query {
+  // Type test used by LLVM-style casts: true for queries tagged Help.
+  static bool classof(const Query *query) {
+    return QueryKind::Help == query->kind;
+  }
+
+  HelpQuery() : Query(QueryKind::Help) {}
+
+  mlir::LogicalResult run(llvm::raw_ostream &os,
+                          QuerySession &qs) const override;
+};
+
+// Query for "quit".
+struct QuitQuery : Query {
+  // Type test used by LLVM-style casts: true for queries tagged Quit.
+  static bool classof(const Query *query) {
+    return QueryKind::Quit == query->kind;
+  }
+
+  QuitQuery() : Query(QueryKind::Quit) {}
+
+  mlir::LogicalResult run(llvm::raw_ostream &os,
+                          QuerySession &qs) const override;
+};
+
+// Query for "match MATCHER".
+struct MatchQuery : Query {
+  // Type test used by LLVM-style casts: true for queries tagged Match.
+  static bool classof(const Query *query) {
+    return QueryKind::Match == query->kind;
+  }
+
+  // Stores a copy of `matcher` and a non-owning view of `source`.
+  // NOTE(review): `source` is a StringRef, so the referenced text presumably
+  // has to outlive this query -- confirm at the call sites.
+  MatchQuery(llvm::StringRef source, const matcher::DynMatcher &matcher)
+      : Query(QueryKind::Match), matcher(matcher), source(source) {}
+
+  mlir::LogicalResult run(llvm::raw_ostream &os,
+                          QuerySession &qs) const override;
+
+  // The matcher to run.
+  const matcher::DynMatcher matcher;
+
+  // The text passed as `source` at construction.
+  llvm::StringRef source;
+};
+
+} // namespace mlir::query
+
+#endif
diff --git a/mlir/include/mlir/Query/QuerySession.h b/mlir/include/mlir/Query/QuerySession.h
new file mode 100644
index 000000000000000..b03a8cae8f18132
--- /dev/null
+++ b/mlir/include/mlir/Query/QuerySession.h
@@ -0,0 +1,42 @@
+//===--- QuerySession.h -----------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TOOLS_MLIRQUERY_QUERYSESSION_H
+#define MLIR_TOOLS_MLIRQUERY_QUERYSESSION_H
+
+#include "mlir/IR/Operation.h"
+#include "mlir/Query/Matcher/VariantValue.h"
+#include "llvm/ADT/StringMap.h"
+#include "llvm/Support/SourceMgr.h"
+
+namespace mlir::query {
+
+// The registry lives in the `matcher` sub-namespace; forward-declare it there
+// so that `matcher::Registry` below names the right class.
+namespace matcher {
+class Registry;
+} // namespace matcher
+
+// Represents the state for a particular mlir-query session: the root
+// operation being queried, the source manager/buffer it was read from, the
+// matcher registry, plus mutable per-session state.
+class QuerySession {
+public:
+  // All referenced objects are borrowed; they must outlive the session.
+  QuerySession(Operation *rootOp, llvm::SourceMgr &sourceMgr, unsigned bufferId,
+               const matcher::Registry &matcherRegistry)
+      : rootOp(rootOp), sourceMgr(sourceMgr), bufferId(bufferId),
+        matcherRegistry(matcherRegistry) {}
+
+  // Accessors are const so they can be used through the `const QuerySession &`
+  // taken by the query parsing/completion entry points.
+  Operation *getRootOp() const { return rootOp; }
+  llvm::SourceMgr &getSourceManager() const { return sourceMgr; }
+  unsigned getBufferId() const { return bufferId; }
+  const matcher::Registry &getRegistryData() const { return matcherRegistry; }
+
+  // Values that can be looked up by name during the session.
+  llvm::StringMap<matcher::VariantValue> namedValues;
+  // Termination flag for the session; false initially.
+  // NOTE(review): presumably set by the "quit" query -- confirm in Query.cpp.
+  bool terminate = false;
+
+private:
+  Operation *rootOp;
+  llvm::SourceMgr &sourceMgr;
+  unsigned bufferId;
+  const matcher::Registry &matcherRegistry;
+};
+
+} // namespace mlir::query
+
+#endif // MLIR_TOOLS_MLIRQUERY_QUERYSESSION_H
diff --git a/mlir/include/mlir/Tools/mlir-query/MlirQueryMain.h b/mlir/include/mlir/Tools/mlir-query/MlirQueryMain.h
new file mode 100644
index 000000000000000..fa1cd5d8176ee12
--- /dev/null
+++ b/mlir/include/mlir/Tools/mlir-query/MlirQueryMain.h
@@ -0,0 +1,30 @@
+//===- MlirQueryMain.h - MLIR Query main ----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Main entry function for mlir-query for when built as standalone
+// binary.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TOOLS_MLIRQUERY_MLIRQUERYMAIN_H
+#define MLIR_TOOLS_MLIRQUERY_MLIRQUERYMAIN_H
+
+#include "mlir/Query/Matcher/Registry.h"
+#include "mlir/Support/LogicalResult.h"
+
+namespace mlir {
+
+class MLIRContext;
+
+// Main entry point of the mlir-query tool. Takes the command-line arguments,
+// the MLIR context to use, and the registry of matchers that queries may
+// reference.
+// NOTE(review): the definition is in MlirQueryMain.cpp (not visible here);
+// presumably the LogicalResult reports whether the tool ran successfully.
+LogicalResult
+mlirQueryMain(int argc, char **argv, MLIRContext &context,
+ const mlir::query::matcher::Registry &matcherRegistry);
+
+} // namespace mlir
+
+#endif // MLIR_TOOLS_MLIRQUERY_MLIRQUERYMAIN_H
diff --git a/mlir/lib/CMakeLists.txt b/mlir/lib/CMakeLists.txt
index c71664a3f00631b..d25c84a3975db4c 100644
--- a/mlir/lib/CMakeLists.txt
+++ b/mlir/lib/CMakeLists.txt
@@ -11,6 +11,7 @@ add_subdirectory(IR)
add_subdirectory(Interfaces)
add_subdirectory(Parser)
add_subdirectory(Pass)
+add_subdirectory(Query)
add_subdirectory(Reducer)
add_subdirectory(Rewrite)
add_subdirectory(Support)
diff --git a/mlir/lib/Query/CMakeLists.txt b/mlir/lib/Query/CMakeLists.txt
new file mode 100644
index 000000000000000..817583e94c52229
--- /dev/null
+++ b/mlir/lib/Query/CMakeLists.txt
@@ -0,0 +1,12 @@
+# Query front end: the query objects and the parser for query lines. Links
+# publicly against the matcher library defined in the Matcher subdirectory.
+add_mlir_library(MLIRQuery
+ Query.cpp
+ QueryParser.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Query
+
+ LINK_LIBS PUBLIC
+ MLIRQueryMatcher
+ )
+
+# Defines the MLIRQueryMatcher target referenced above.
+add_subdirectory(Matcher)
diff --git a/mlir/lib/Query/Matcher/CMakeLists.txt b/mlir/lib/Query/Matcher/CMakeLists.txt
new file mode 100644
index 000000000000000..6afd24722bb70ce
--- /dev/null
+++ b/mlir/lib/Query/Matcher/CMakeLists.txt
@@ -0,0 +1,10 @@
+# Matcher infrastructure: matcher expression parser, registry, variant values
+# and diagnostics.
+# NOTE(review): no LINK_LIBS are listed although the matcher headers include
+# mlir/IR/Matchers.h -- presumably this should link against MLIRIR; confirm.
+add_mlir_library(MLIRQueryMatcher
+ Parser.cpp
+ RegistryManager.cpp
+ VariantValue.cpp
+ Diagnostics.cpp
+ ErrorBuilder.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Query/Matcher
+ )
diff --git a/mlir/lib/Query/Matcher/Diagnostics.cpp b/mlir/lib/Query/Matcher/Diagnostics.cpp
new file mode 100644
index 000000000000000..10468dbcc530676
--- /dev/null
+++ b/mlir/lib/Query/Matcher/Diagnostics.cpp
@@ -0,0 +1,128 @@
+//===- Diagnostics.cpp ----------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "Diagnostics.h"
+#include "mlir/Query/Matcher/ErrorBuilder.h"
+
+namespace mlir::query::matcher::internal {
+
+Diagnostics::ArgStream &
+Diagnostics::ArgStream::operator<<(const llvm::Twine &arg) {
+ out->push_back(arg.str());
+ return *this;
+}
+
+Diagnostics::ArgStream Diagnostics::addError(SourceRange range,
+ ErrorType error) {
+ errorValues.emplace_back();
+ ErrorContent &last = errorValues.back();
+ last.contextStack = contextStack;
+ last.messages.emplace_back();
+ last.messages.back().range = range;
+ last.messages.back().type = error;
+ return ArgStream(&last.messages.back().args);
+}
+
+// Returns the message template for `type`. Placeholders of the form `$N`
+// (N being a single digit) are substituted with the N-th argument recorded
+// for the error when the message is printed.
+static llvm::StringRef errorTypeToFormatString(ErrorType type) {
+  switch (type) {
+  case ErrorType::RegistryMatcherNotFound:
+    return "Matcher not found: $0";
+  case ErrorType::RegistryWrongArgCount:
+    return "Incorrect argument count. (Expected = $0) != (Actual = $1)";
+  case ErrorType::RegistryWrongArgType:
+    return "Incorrect type for arg $0. (Expected = $1) != (Actual = $2)";
+  case ErrorType::RegistryValueNotFound:
+    return "Value not found: $0";
+
+  case ErrorType::ParserStringError:
+    return "Error parsing string token: <$0>";
+  case ErrorType::ParserNoOpenParen:
+    return "Error parsing matcher. Found token <$0> while looking for '('.";
+  case ErrorType::ParserNoCloseParen:
+    return "Error parsing matcher. Found end-of-code while looking for ')'.";
+  case ErrorType::ParserNoComma:
+    return "Error parsing matcher. Found token <$0> while looking for ','.";
+  case ErrorType::ParserNoCode:
+    return "End of code found while looking for token.";
+  case ErrorType::ParserNotAMatcher:
+    return "Input value is not a matcher expression.";
+  case ErrorType::ParserInvalidToken:
+    return "Invalid token <$0> found when looking for a value.";
+  case ErrorType::ParserTrailingCode:
+    // Trailing code means input remained where the end of the query was
+    // expected, so the message states the expectation (a premature end of
+    // input is ParserNoCode above).
+    return "Expected end of code.";
+  case ErrorType::ParserOverloadedType:
+    return "Input value has unresolved overloaded type: $0";
+  case ErrorType::ParserFailedToBuildMatcher:
+    return "Failed to build matcher: $0.";
+
+  case ErrorType::None:
+    return "<N/A>";
+  }
+  llvm_unreachable("Unknown ErrorType value.");
+}
+
+// Writes `formatString` to `os`, replacing each `$N` (N a single digit) with
+// args[N], or with "<Argument_Not_Provided>" when there is no such argument.
+// A `$` followed by a non-digit is dropped together with that character.
+static void formatErrorString(llvm::StringRef formatString,
+                              llvm::ArrayRef<std::string> args,
+                              llvm::raw_ostream &os) {
+  llvm::StringRef remaining = formatString;
+  while (!remaining.empty()) {
+    // Everything up to the next '$' is literal text.
+    auto [literal, afterDollar] = remaining.split("$");
+    os << literal;
+    if (afterDollar.empty())
+      return;
+
+    // The character right after '$' selects the argument.
+    const char selector = afterDollar.front();
+    remaining = afterDollar.drop_front();
+
+    const bool isDigit = selector >= '0' && selector <= '9';
+    if (!isDigit)
+      continue;
+
+    const unsigned index = selector - '0';
+    if (index >= args.size())
+      os << "<Argument_Not_Provided>";
+    else
+      os << args[index];
+  }
+}
+
+// Prints "<line>:<column>: " for the start of `range`, but only when both the
+// line and the column are non-zero.
+static void maybeAddLineAndColumn(SourceRange range, llvm::raw_ostream &os) {
+  const SourceLocation &start = range.start;
+  if (start.line == 0 || start.column == 0)
+    return;
+
+  os << start.line << ":" << start.column << ": ";
+}
+
+// Prints one message: the optional location, then `prefix`, then the
+// message template for its error type with the arguments substituted.
+void Diagnostics::printMessage(
+    const Diagnostics::ErrorContent::Message &message, const llvm::Twine prefix,
+    llvm::raw_ostream &os) const {
+  maybeAddLineAndColumn(message.range, os);
+  os << prefix;
+
+  const llvm::StringRef format = errorTypeToFormatString(message.type);
+  formatErrorString(format, message.args, os);
+}
+
+void Diagnostics::printErrorContent(const Diagnostics::ErrorContent &content,
+ llvm::raw_ostream &os) const {
+ if (content.messages.size() == 1) {
+ printMessage(content.messages[0], "", os);
+ } else {
+ for (size_t i = 0, e = content.messages.size(); i != e; ++i) {
+ if (i != 0)
+ os << "\n";
+ printMessage(content.messages[i],
+ "Candidate " + llvm::Twine(i + 1) + ": ", os);
+ }
+ }
+}
+
+void Diagnostics::print(llvm::raw_ostream &os) const {
+ for (const ErrorContent &error : errorValues) {
+ if (&error != &errorValues.front())
+ os << "\n";
+ printErrorContent(error, os);
+ }
+}
+
+} // namespace mlir::query::matcher::internal
diff --git a/mlir/lib/Query/Matcher/Diagnostics.h b/mlir/lib/Query/Matcher/Diagnostics.h
new file mode 100644
index 000000000000000..a58a435b16a9030
--- /dev/null
+++ b/mlir/lib/Query/Matcher/Diagnostics.h
@@ -0,0 +1,82 @@
+//===--- Diagnostics.h - Helper class for error diagnostics -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Diagnostics class to manage error messages. Implementation shares similarity
+// to clang-query Diagnostics.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_DIAGNOSTICS_H
+#define MLIR_TOOLS_MLIRQUERY_MATCHER_DIAGNOSTICS_H
+
+#include "mlir/Query/Matcher/ErrorBuilder.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/raw_ostream.h"
+#include <string>
+#include <vector>
+
+namespace mlir::query::matcher::internal {
+
+// Diagnostics class to manage error messages.
+class Diagnostics {
+public:
+ // Helper stream class for constructing error messages.
+ class ArgStream {
+ public:
+ ArgStream(std::vector<std::string> *out) : out(out) {}
+ template <class T>
+ ArgStream &operator<<(const T &arg) {
+ return operator<<(llvm::Twine(arg));
+ }
+ ArgStream &operator<<(const llvm::Twine &arg);
+
+ private:
+ std::vector<std::string> *out;
+ };
+
+ // Add an error message with the specified range and error type.
+ // Returns an ArgStream object to allow constructing the error message using
+ // the << operator.
+ ArgStream addError(SourceRange range, ErrorType error);
+
+ // Print all error messages to the specified output stream.
+ void print(llvm::raw_ostream &os) const;
+
+private:
+ // Information stored for one frame of the context.
+ struct ContextFrame {
+ SourceRange range;
+ std::vector<std::string> args;
+ };
+
+ // Information stored for each error found.
+ struct ErrorContent {
+ std::vector<ContextFrame> contextStack;
+ struct Message {
+ SourceRange range;
+ ErrorType type;
+ std::vector<std::string> args;
+ };
+ std::vector<Message> messages;
+ };
+
+ void printMessage(const ErrorContent::Message &message,
+ const llvm::Twine Prefix, llvm::raw_ostream &os) const;
+
+ void printErrorContent(const ErrorContent &content,
+ llvm::raw_ostream &os) const;
+
+ std::vector<ContextFrame> contextStack;
+ std::vector<ErrorContent> errorValues;
+};
+
+} // namespace mlir::query::matcher::internal
+
+#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_DIAGNOSTICS_H
diff --git a/mlir/lib/Query/Matcher/ErrorBuilder.cpp b/mlir/lib/Query/Matcher/ErrorBuilder.cpp
new file mode 100644
index 000000000000000..de6447dac490ac4
--- /dev/null
+++ b/mlir/lib/Query/Matcher/ErrorBuilder.cpp
@@ -0,0 +1,25 @@
+//===--- ErrorBuilder.cpp - Helper for building error messages ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Query/Matcher/ErrorBuilder.h"
+#include "Diagnostics.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/Twine.h"
+#include <initializer_list>
+
+namespace mlir::query::matcher::internal {
+
+void addError(Diagnostics *error, SourceRange range, ErrorType errorType,
+ std::initializer_list<llvm::Twine> errorTexts) {
+ Diagnostics::ArgStream argStream = error->addError(range, errorType);
+ for (const llvm::Twine &errorText : errorTexts) {
+ argStream << errorText;
+ }
+}
+
+} // namespace mlir::query::matcher::internal
diff --git a/mlir/lib/Query/Matcher/Parser.cpp b/mlir/lib/Query/Matcher/Parser.cpp
new file mode 100644
index 000000000000000..be9e60de221db19
--- /dev/null
+++ b/mlir/lib/Query/Matcher/Parser.cpp
@@ -0,0 +1,540 @@
+//===- Parser.cpp - Matcher expression parser -----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Recursive parser implementation for the matcher expression grammar.
+//
+//===----------------------------------------------------------------------===//
+
+#include "Parser.h"
+
+#include <vector>
+
+namespace mlir::query::matcher::internal {
+
+// Plain record describing a single token produced by the tokenizer.
+struct Parser::TokenInfo {
+  TokenInfo() = default;
+
+  // Overwrites both the text and the kind of this token; range and value are
+  // left as they are.
+  void set(TokenKind newKind, llvm::StringRef newText) {
+    text = newText;
+    kind = newKind;
+  }
+
+  // Source text the token was produced from.
+  llvm::StringRef text;
+  // Token category; a default-constructed token is Eof.
+  TokenKind kind = TokenKind::Eof;
+  // Location of the token in the parsed code.
+  SourceRange range;
+  // Payload of the token; the tokenizer fills it in for string literals.
+  VariantValue value;
+};
+
+// Splits matcher source text into tokens while always keeping one token of
+// lookahead. Newlines are reported as tokens; other whitespace and '#' line
+// comments are skipped. An optional code-completion position turns the text at
+// that position into a CodeCompletion token.
+class Parser::CodeTokenizer {
+public:
+  // Tokenizes `matcherCode`; tokenizer errors are reported to `error`.
+  explicit CodeTokenizer(llvm::StringRef matcherCode, Diagnostics *error)
+      : code(matcherCode), startOfLine(matcherCode), error(error) {
+    nextToken = getNextToken();
+  }
+
+  // Same as above, but additionally requests a code-completion token at byte
+  // offset `codeCompletionOffset` into `matcherCode`.
+  CodeTokenizer(llvm::StringRef matcherCode, Diagnostics *error,
+                unsigned codeCompletionOffset)
+      : code(matcherCode), startOfLine(matcherCode), error(error),
+        codeCompletionLocation(matcherCode.data() + codeCompletionOffset) {
+    nextToken = getNextToken();
+  }
+
+  // Peek at next token without consuming it
+  const TokenInfo &peekNextToken() const { return nextToken; }
+
+  // Consume and return the next token
+  TokenInfo consumeNextToken() {
+    TokenInfo thisToken = nextToken;
+    nextToken = getNextToken();
+    return thisToken;
+  }
+
+  // Skip any newline tokens and return (a copy of) the first token after them
+  // without consuming it.
+  TokenInfo skipNewlines() {
+    while (nextToken.kind == TokenKind::NewLine)
+      nextToken = getNextToken();
+    return nextToken;
+  }
+
+  // Consume and return next token, ignoring newlines. An Eof token is returned
+  // without being consumed.
+  TokenInfo consumeNextTokenIgnoreNewlines() {
+    skipNewlines();
+    return nextToken.kind == TokenKind::Eof ? nextToken : consumeNextToken();
+  }
+
+  // Return kind of next token
+  TokenKind nextTokenKind() const { return nextToken.kind; }
+
+private:
+  // Returns true if `c` may appear in an identifier. The cast keeps the
+  // argument within the range isalnum() accepts: passing a negative value
+  // (any non-ASCII byte where char is signed) is undefined behavior.
+  static bool isIdentifierChar(char c) {
+    return isalnum(static_cast<unsigned char>(c)) != 0;
+  }
+
+  // Helper function to get the first character as a new StringRef and drop it
+  // from the original string
+  llvm::StringRef firstCharacterAndDrop(llvm::StringRef &str) {
+    assert(!str.empty());
+    llvm::StringRef firstChar = str.substr(0, 1);
+    str = str.drop_front();
+    return firstChar;
+  }
+
+  // Get next token, consuming whitespaces and handling different token types
+  TokenInfo getNextToken() {
+    consumeWhitespace();
+    TokenInfo result;
+    result.range.start = currentLocation();
+
+    // Code completion case: the completion position has been reached (or
+    // passed), so emit an empty completion token exactly once.
+    if (codeCompletionLocation && codeCompletionLocation <= code.data()) {
+      result.set(TokenKind::CodeCompletion,
+                 llvm::StringRef(codeCompletionLocation, 0));
+      codeCompletionLocation = nullptr;
+      return result;
+    }
+
+    // End of file case
+    if (code.empty()) {
+      result.set(TokenKind::Eof, "");
+      return result;
+    }
+
+    // Switch to handle specific characters
+    switch (code[0]) {
+    case '#':
+      // Line comment: skip up to (not including) the newline and retry.
+      code = code.drop_until([](char c) { return c == '\n'; });
+      return getNextToken();
+    case ',':
+      result.set(TokenKind::Comma, firstCharacterAndDrop(code));
+      break;
+    case '.':
+      result.set(TokenKind::Period, firstCharacterAndDrop(code));
+      break;
+    case '\n':
+      // Keep line/column bookkeeping in sync before consuming the newline.
+      ++line;
+      startOfLine = code.drop_front();
+      result.set(TokenKind::NewLine, firstCharacterAndDrop(code));
+      break;
+    case '(':
+      result.set(TokenKind::OpenParen, firstCharacterAndDrop(code));
+      break;
+    case ')':
+      result.set(TokenKind::CloseParen, firstCharacterAndDrop(code));
+      break;
+    case '"':
+    case '\'':
+      consumeStringLiteral(&result);
+      break;
+    default:
+      parseIdentifierOrInvalid(&result);
+      break;
+    }
+
+    result.range.end = currentLocation();
+    return result;
+  }
+
+  // Consume a string literal, handle escape sequences and missing closing
+  // quote. On success the token text includes the quotes and the token value
+  // is the raw text between them (escape sequences are skipped over, not
+  // decoded). Without a closing quote the rest of the input is consumed, a
+  // ParserStringError is reported and the token kind becomes Error.
+  void consumeStringLiteral(TokenInfo *result) {
+    bool inEscape = false;
+    const char marker = code[0];
+    for (size_t length = 1; length < code.size(); ++length) {
+      if (inEscape) {
+        inEscape = false;
+        continue;
+      }
+      if (code[length] == '\\') {
+        inEscape = true;
+        continue;
+      }
+      if (code[length] == marker) {
+        result->kind = TokenKind::Literal;
+        result->text = code.substr(0, length + 1);
+        result->value = code.substr(1, length - 1);
+        code = code.drop_front(length + 1);
+        return;
+      }
+    }
+    llvm::StringRef errorText = code;
+    code = code.drop_front(code.size());
+    SourceRange range;
+    range.start = result->range.start;
+    range.end = currentLocation();
+    error->addError(range, ErrorType::ParserStringError) << errorText;
+    result->kind = TokenKind::Error;
+  }
+
+  // Consume either an identifier (a run of alphanumeric characters) or a
+  // single invalid character.
+  void parseIdentifierOrInvalid(TokenInfo *result) {
+    if (isIdentifierChar(code[0])) {
+      // Parse an identifier
+      size_t tokenLength = 1;
+
+      while (true) {
+        // A code completion location in/immediately after an identifier will
+        // cause the portion of the identifier before the code completion
+        // location to become a code completion token.
+        if (codeCompletionLocation == code.data() + tokenLength) {
+          codeCompletionLocation = nullptr;
+          result->kind = TokenKind::CodeCompletion;
+          result->text = code.substr(0, tokenLength);
+          code = code.drop_front(tokenLength);
+          return;
+        }
+        if (tokenLength == code.size() || !isIdentifierChar(code[tokenLength]))
+          break;
+        ++tokenLength;
+      }
+      result->kind = TokenKind::Ident;
+      result->text = code.substr(0, tokenLength);
+      code = code.drop_front(tokenLength);
+    } else {
+      result->kind = TokenKind::InvalidChar;
+      result->text = code.substr(0, 1);
+      code = code.drop_front(1);
+    }
+  }
+
+  // Consume all leading whitespace from code, except newlines
+  void consumeWhitespace() {
+    code = code.drop_while(
+        [](char c) { return llvm::StringRef(" \t\v\f\r").contains(c); });
+  }
+
+  // Returns the current location in the source code (1-based line and column)
+  SourceLocation currentLocation() {
+    SourceLocation location;
+    location.line = line;
+    location.column = code.data() - startOfLine.data() + 1;
+    return location;
+  }
+
+  // Text that has not been tokenized yet.
+  llvm::StringRef code;
+  // Text starting at the beginning of the current line; used for columns.
+  llvm::StringRef startOfLine;
+  // Current 1-based line number.
+  unsigned line = 1;
+  // Sink for tokenizer errors.
+  Diagnostics *error;
+  // One-token lookahead.
+  TokenInfo nextToken;
+  // Pending code-completion position, or nullptr if none (remains).
+  const char *codeCompletionLocation = nullptr;
+};
+
+Parser::Sema::~Sema() = default;
+
+// Default implementation: no completion types are accepted.
+std::vector<ArgKind> Parser::Sema::getAcceptedCompletionTypes(
+    llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> context) {
+  return std::vector<ArgKind>();
+}
+
+// Default implementation: no matcher completions are offered.
+std::vector<MatcherCompletion>
+Parser::Sema::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes) {
+  return std::vector<MatcherCompletion>();
+}
+
+// Scope guard for the parser's context stack: construction pushes
+// (matcher constructor, argument index 0) and destruction pops it again.
+struct Parser::ScopedContextEntry {
+  Parser *parser;
+
+  ScopedContextEntry(Parser *parser, MatcherCtor c) : parser(parser) {
+    parser->contextStack.emplace_back(c, 0u);
+  }
+
+  // Not copyable: a copy would pop the context stack a second time.
+  ScopedContextEntry(const ScopedContextEntry &) = delete;
+  ScopedContextEntry &operator=(const ScopedContextEntry &) = delete;
+
+  ~ScopedContextEntry() { parser->contextStack.pop_back(); }
+
+  // Advances the argument index recorded for the innermost entry.
+  void nextArg() { ++parser->contextStack.back().second; }
+};
+
+// Parse and validate expressions starting with an identifier.
+// This function can parse named values and matchers. In case of failure, it
+// will try to determine the user's intent to give an appropriate error message.
+// Returns true and stores the result in *value on success; on failure an error
+// has been added to the diagnostics and false is returned.
+bool Parser::parseIdentifierPrefixImpl(VariantValue *value) {
+ const TokenInfo nameToken = tokenizer->consumeNextToken();
+
+ if (tokenizer->nextTokenKind() != TokenKind::OpenParen) {
+ // Parse as a named value.
+ // Without a named-value map the lookup yields a default VariantValue.
+ auto namedValue =
+ namedValues ? namedValues->lookup(nameToken.text) : VariantValue();
+
+ // Anything that is not a named matcher value is rejected here, reported at
+ // the range of the token following the identifier.
+ if (!namedValue.isMatcher()) {
+ error->addError(tokenizer->peekNextToken().range,
+ ErrorType::ParserNotAMatcher);
+ return false;
+ }
+
+ // NOTE(review): a named value that is a matcher is never stored into
+ // *value on this path; control only continues to the checks below --
+ // confirm whether returning the named value here was intended.
+ if (tokenizer->nextTokenKind() == TokenKind::NewLine) {
+ error->addError(tokenizer->peekNextToken().range,
+ ErrorType::ParserNoOpenParen)
+ << "NewLine";
+ return false;
+ }
+
+ // If the syntax is correct and the name is not a matcher either, report
+ // an unknown named value.
+ // NOTE(review): the NewLine test below cannot be true, because the NewLine
+ // case already returned above.
+ if ((tokenizer->nextTokenKind() == TokenKind::Comma ||
+ tokenizer->nextTokenKind() == TokenKind::CloseParen ||
+ tokenizer->nextTokenKind() == TokenKind::NewLine ||
+ tokenizer->nextTokenKind() == TokenKind::Eof) &&
+ !sema->lookupMatcherCtor(nameToken.text)) {
+ error->addError(nameToken.range, ErrorType::RegistryValueNotFound)
+ << nameToken.text;
+ return false;
+ }
+ // Otherwise, fallback to the matcher parser.
+ }
+
+ tokenizer->skipNewlines();
+
+ assert(nameToken.kind == TokenKind::Ident);
+ // The identifier must be followed by '(' to form a matcher expression.
+ TokenInfo openToken = tokenizer->consumeNextToken();
+ if (openToken.kind != TokenKind::OpenParen) {
+ error->addError(openToken.range, ErrorType::ParserNoOpenParen)
+ << openToken.text;
+ return false;
+ }
+
+ // An unknown matcher name yields an empty optional; that case is reported by
+ // parseMatcherExpressionImpl.
+ std::optional<MatcherCtor> ctor = sema->lookupMatcherCtor(nameToken.text);
+
+ // Parse as a matcher expression.
+ return parseMatcherExpressionImpl(nameToken, openToken, ctor, value);
+}
+
+// Parses a matcher's comma-separated argument list up to and including the
+// closing parenthesis. Parsed arguments are appended to `args`; the ')' token,
+// if one is reached, is stored in `endToken`. Returns false as soon as a comma
+// is missing or an argument fails to parse. Reaching the end of the input
+// returns true with `endToken` untouched, so the caller must check its kind.
+bool Parser::parseMatcherArgs(std::vector<ParserValue> &args, MatcherCtor ctor,
+                              const TokenInfo &nameToken, TokenInfo &endToken) {
+  // Keeps (ctor, argument index) on the context stack while parsing arguments.
+  ScopedContextEntry contextEntry(this, ctor);
+
+  for (;;) {
+    const TokenKind upcoming = tokenizer->nextTokenKind();
+    if (upcoming == TokenKind::Eof)
+      return true;
+
+    if (upcoming == TokenKind::CloseParen) {
+      // End of the argument list.
+      endToken = tokenizer->consumeNextToken();
+      return true;
+    }
+
+    // Every argument after the first has to be preceded by a comma.
+    if (!args.empty()) {
+      const TokenInfo separator = tokenizer->consumeNextToken();
+      if (separator.kind != TokenKind::Comma) {
+        error->addError(separator.range, ErrorType::ParserNoComma)
+            << separator.text;
+        return false;
+      }
+    }
+
+    tokenizer->skipNewlines();
+
+    // Remember where the argument starts before the expression is consumed.
+    ParserValue parsedArg;
+    parsedArg.text = tokenizer->peekNextToken().text;
+    parsedArg.range = tokenizer->peekNextToken().range;
+    if (!parseExpressionImpl(&parsedArg.value))
+      return false;
+
+    tokenizer->skipNewlines();
+    args.push_back(parsedArg);
+    contextEntry.nextArg();
+  }
+}
+
+// Parses the argument list of the matcher named by `nameToken` (whose '(' is
+// `openToken`) and, if `ctor` is set, builds the matcher through sema. On
+// success the matcher is stored in *value and true is returned. An unknown
+// matcher (`ctor` empty) is reported up front, but its arguments are still
+// parsed so that completion suggestions can be produced.
+bool Parser::parseMatcherExpressionImpl(const TokenInfo &nameToken,
+                                        const TokenInfo &openToken,
+                                        std::optional<MatcherCtor> ctor,
+                                        VariantValue *value) {
+  const bool knownMatcher = ctor.has_value();
+  if (!knownMatcher) {
+    // Report, but keep going: completion needs the arguments to be parsed.
+    error->addError(nameToken.range, ErrorType::RegistryMatcherNotFound)
+        << nameToken.text;
+  }
+
+  tokenizer->skipNewlines();
+
+  std::vector<ParserValue> args;
+  TokenInfo endToken;
+  const bool argsParsed =
+      parseMatcherArgs(args, ctor.value_or(nullptr), nameToken, endToken);
+  if (!argsParsed)
+    return false;
+
+  // A missing closing parenthesis is reported at the opening one.
+  if (endToken.kind != TokenKind::CloseParen) {
+    error->addError(openToken.range, ErrorType::ParserNoCloseParen)
+        << nameToken.text;
+    return false;
+  }
+
+  if (!knownMatcher)
+    return false;
+
+  // The matcher spans from the start of its name to the end of the ')'.
+  SourceRange matcherRange = nameToken.range;
+  matcherRange.end = endToken.range.end;
+
+  VariantMatcher matcher =
+      sema->actOnMatcherExpression(*ctor, matcherRange, args, error);
+  if (matcher.isNull())
+    return false;
+
+  *value = matcher;
+  return true;
+}
+
+// Records `completion` if its typed text starts with the text of the
+// completion token. The stored typed text is what remains after that prefix.
+void Parser::addCompletion(const TokenInfo &compToken,
+                           const MatcherCompletion &completion) {
+  const llvm::StringRef typedSoFar = compToken.text;
+  if (!llvm::StringRef(completion.typedText).startswith(typedSoFar))
+    return;
+
+  completions.emplace_back(completion.typedText.substr(typedSoFar.size()),
+                           completion.matcherDecl);
+}
+
+// Builds one completion per named value: the typed text is the value's name
+// and the declaration is "<type> <name>". `acceptedTypes` is not consulted.
+// Returns an empty list when no named values were supplied.
+std::vector<MatcherCompletion>
+Parser::getNamedValueCompletions(llvm::ArrayRef<ArgKind> acceptedTypes) {
+  std::vector<MatcherCompletion> result;
+  if (namedValues) {
+    for (const auto &entry : *namedValues) {
+      const llvm::StringRef name = entry.getKey();
+      std::string decl =
+          (entry.getValue().getTypeAsString() + " " + name).str();
+      result.emplace_back(name, decl);
+    }
+  }
+  return result;
+}
+
+// Consumes the pending code-completion token and collects every completion
+// matching it: matcher completions first, then named-value completions.
+void Parser::addExpressionCompletions() {
+  const TokenInfo compToken = tokenizer->consumeNextTokenIgnoreNewlines();
+  assert(compToken.kind == TokenKind::CodeCompletion);
+
+  // We cannot complete code if there is an invalid element on the context
+  // stack.
+  for (const auto &entry : contextStack)
+    if (!entry.first)
+      return;
+
+  const auto addAll = [&](const std::vector<MatcherCompletion> &candidates) {
+    for (const MatcherCompletion &candidate : candidates)
+      addCompletion(compToken, candidate);
+  };
+
+  auto acceptedTypes = sema->getAcceptedCompletionTypes(contextStack);
+  addAll(sema->getMatcherCompletions(acceptedTypes));
+  addAll(getNamedValueCompletions(acceptedTypes));
+}
+
+// Parse an <Expression>. On success the parsed value is stored in *value and
+// true is returned; otherwise false is returned.
+bool Parser::parseExpressionImpl(VariantValue *value) {
+  switch (tokenizer->nextTokenKind()) {
+  case TokenKind::Literal: {
+    const TokenInfo literal = tokenizer->consumeNextToken();
+    *value = literal.value;
+    return true;
+  }
+
+  case TokenKind::Ident:
+    return parseIdentifierPrefixImpl(value);
+
+  case TokenKind::CodeCompletion:
+    // Completion requests collect suggestions and never yield a value.
+    addExpressionCompletions();
+    return false;
+
+  case TokenKind::Eof: {
+    const TokenInfo eofToken = tokenizer->consumeNextToken();
+    error->addError(eofToken.range, ErrorType::ParserNoCode);
+    return false;
+  }
+
+  case TokenKind::Error:
+    // This error was already reported by the tokenizer.
+    return false;
+
+  case TokenKind::NewLine:
+  case TokenKind::OpenParen:
+  case TokenKind::CloseParen:
+  case TokenKind::Comma:
+  case TokenKind::Period:
+  case TokenKind::InvalidChar: {
+    // None of these tokens can start an expression.
+    const TokenInfo token = tokenizer->consumeNextToken();
+    llvm::StringRef shownText = token.text;
+    if (token.kind == TokenKind::NewLine)
+      shownText = "NewLine";
+    error->addError(token.range, ErrorType::ParserInvalidToken) << shownText;
+    return false;
+  }
+  }
+
+  llvm_unreachable("Unknown token kind.");
+}
+
+// Builds a parser reading tokens from `tokenizer` and reporting to `error`.
+// The parser creates its own RegistrySema backed by `matcherRegistry`;
+// `namedValues` may be null (it is checked before use elsewhere in the file).
+Parser::Parser(CodeTokenizer *tokenizer, const Registry &matcherRegistry,
+ const NamedValueMap *namedValues, Diagnostics *error)
+ : tokenizer(tokenizer),
+ sema(std::make_unique<RegistrySema>(matcherRegistry)),
+ namedValues(namedValues), error(error) {}
+
+Parser::RegistrySema::~RegistrySema() = default;
+
+// Forwards the name lookup to RegistryManager using this sema's registry.
+std::optional<MatcherCtor>
+Parser::RegistrySema::lookupMatcherCtor(llvm::StringRef matcherName) {
+ return RegistryManager::lookupMatcherCtor(matcherName, matcherRegistry);
+}
+
+// Forwards matcher construction to RegistryManager::constructMatcher.
+VariantMatcher Parser::RegistrySema::actOnMatcherExpression(
+ MatcherCtor ctor, SourceRange nameRange, llvm::ArrayRef<ParserValue> args,
+ Diagnostics *error) {
+ return RegistryManager::constructMatcher(ctor, nameRange, args, error);
+}
+
+// Forwards to RegistryManager::getAcceptedCompletionTypes.
+std::vector<ArgKind> Parser::RegistrySema::getAcceptedCompletionTypes(
+ llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> context) {
+ return RegistryManager::getAcceptedCompletionTypes(context);
+}
+
+// Forwards to RegistryManager::getMatcherCompletions using this sema's
+// registry.
+std::vector<MatcherCompletion> Parser::RegistrySema::getMatcherCompletions(
+ llvm::ArrayRef<ArgKind> acceptedTypes) {
+ return RegistryManager::getMatcherCompletions(acceptedTypes, matcherRegistry);
+}
+
+// Parses one expression from `code` into *value. Returns false if parsing
+// fails or if the expression is followed by anything other than a newline or
+// the end of the input (reported as ParserTrailingCode).
+bool Parser::parseExpression(llvm::StringRef &code,
+                             const Registry &matcherRegistry,
+                             const NamedValueMap *namedValues,
+                             VariantValue *value, Diagnostics *error) {
+  CodeTokenizer tokenizer(code, error);
+  Parser parser(&tokenizer, matcherRegistry, namedValues, error);
+  if (!parser.parseExpressionImpl(value))
+    return false;
+
+  const TokenInfo &trailing = tokenizer.peekNextToken();
+  const bool atEnd = trailing.kind == TokenKind::Eof ||
+                     trailing.kind == TokenKind::NewLine;
+  if (atEnd)
+    return true;
+
+  error->addError(trailing.range, ErrorType::ParserTrailingCode);
+  return false;
+}
+
+// Runs the parser over `code` with a code-completion token requested at
+// `completionOffset` and returns the completions collected along the way.
+// Parse errors go to a local Diagnostics object and are discarded.
+std::vector<MatcherCompletion>
+Parser::completeExpression(llvm::StringRef &code, unsigned completionOffset,
+                           const Registry &matcherRegistry,
+                           const NamedValueMap *namedValues) {
+  Diagnostics discardedErrors;
+  CodeTokenizer tokenizer(code, &discardedErrors, completionOffset);
+  Parser parser(&tokenizer, matcherRegistry, namedValues, &discardedErrors);
+
+  // Only the completions gathered as a side effect are of interest.
+  VariantValue ignoredValue;
+  (void)parser.parseExpressionImpl(&ignoredValue);
+
+  return parser.completions;
+}
+
+// Parses `code` and returns the resulting matcher. Returns std::nullopt, after
+// adding an error, when parsing fails, when the parsed value is not a matcher
+// (ParserNotAMatcher), or when the matcher cannot be converted to a single
+// DynMatcher (ParserOverloadedType).
+std::optional<DynMatcher> Parser::parseMatcherExpression(
+    llvm::StringRef &code, const Registry &matcherRegistry,
+    const NamedValueMap *namedValues, Diagnostics *error) {
+  VariantValue parsed;
+  const bool ok =
+      parseExpression(code, matcherRegistry, namedValues, &parsed, error);
+  if (!ok)
+    return std::nullopt;
+
+  if (!parsed.isMatcher()) {
+    error->addError(SourceRange(), ErrorType::ParserNotAMatcher);
+    return std::nullopt;
+  }
+
+  std::optional<DynMatcher> matcher = parsed.getMatcher().getDynMatcher();
+  if (matcher)
+    return matcher;
+
+  error->addError(SourceRange(), ErrorType::ParserOverloadedType)
+      << parsed.getTypeAsString();
+  return matcher;
+}
+
+} // namespace mlir::query::matcher::internal
diff --git a/mlir/lib/Query/Matcher/Parser.h b/mlir/lib/Query/Matcher/Parser.h
new file mode 100644
index 000000000000000..f049af34e9c907a
--- /dev/null
+++ b/mlir/lib/Query/Matcher/Parser.h
@@ -0,0 +1,188 @@
+//===--- Parser.h - Matcher expression parser -------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Simple matcher expression parser.
+//
+// This file contains the Parser class, which is responsible for parsing
+// expressions in a specific format: matcherName(Arg0, Arg1, ..., ArgN). The
+// parser can also interpret simple types, like strings.
+//
+// The actual processing of the matchers is handled by a Sema object that is
+// provided to the parser.
+//
+// The grammar for the supported expressions is as follows:
+// <Expression> := <StringLiteral> | <MatcherExpression>
+// <StringLiteral> := "quoted string"
+// <MatcherExpression> := <MatcherName>(<ArgumentList>)
+// <MatcherName> := [a-zA-Z0-9]+
+// <ArgumentList> := <Expression> | <Expression>,<ArgumentList>
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_PARSER_H
+#define MLIR_TOOLS_MLIRQUERY_MATCHER_PARSER_H
+
+#include "Diagnostics.h"
+#include "RegistryManager.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/StringMap.h"
+#include "llvm/ADT/StringRef.h"
+#include <memory>
+#include <vector>
+
+namespace mlir::query::matcher::internal {
+
+// Matcher expression parser. All entry points are static; each call builds a
+// tokenizer and a parser instance internally.
+class Parser {
+public:
+  // Different possible tokens.
+  enum class TokenKind {
+    Eof,
+    NewLine,
+    OpenParen,
+    CloseParen,
+    Comma,
+    Period,
+    Literal,
+    Ident,
+    InvalidChar,
+    CodeCompletion,
+    Error
+  };
+
+  // Interface to connect the parser with the registry and more. The parser
+  // uses a Sema instance to look up matcher names, construct matchers and
+  // compute completions.
+  class Sema {
+  public:
+    virtual ~Sema();
+
+    // Process a matcher expression: build the matcher for `ctor` from `args`.
+    // The matcher is returned by value; errors are reported through `error`.
+    virtual VariantMatcher
+    actOnMatcherExpression(MatcherCtor ctor, SourceRange nameRange,
+                           llvm::ArrayRef<ParserValue> args,
+                           Diagnostics *error) = 0;
+
+    // Look up the constructor of the matcher with the given name. Returns an
+    // empty optional if there is none.
+    virtual std::optional<MatcherCtor>
+    lookupMatcherCtor(llvm::StringRef matcherName) = 0;
+
+    // Compute the list of completion types for context. The default
+    // implementation returns an empty list.
+    virtual std::vector<ArgKind> getAcceptedCompletionTypes(
+        llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> context);
+
+    // Compute the list of completions that match any of acceptedTypes. The
+    // default implementation returns an empty list.
+    virtual std::vector<MatcherCompletion>
+    getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes);
+  };
+
+  // An implementation of the Sema interface that uses the matcher registry to
+  // process tokens. It keeps a reference to the registry, which therefore has
+  // to outlive this object.
+  class RegistrySema : public Parser::Sema {
+  public:
+    RegistrySema(const Registry &matcherRegistry)
+        : matcherRegistry(matcherRegistry) {}
+    ~RegistrySema() override;
+
+    std::optional<MatcherCtor>
+    lookupMatcherCtor(llvm::StringRef matcherName) override;
+
+    VariantMatcher actOnMatcherExpression(MatcherCtor ctor,
+                                          SourceRange nameRange,
+                                          llvm::ArrayRef<ParserValue> args,
+                                          Diagnostics *error) override;
+
+    std::vector<ArgKind> getAcceptedCompletionTypes(
+        llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> context) override;
+
+    std::vector<MatcherCompletion>
+    getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes) override;
+
+  private:
+    const Registry &matcherRegistry;
+  };
+
+  // Map from a name to the value it stands for in parsed expressions.
+  using NamedValueMap = llvm::StringMap<VariantValue>;
+
+  // Methods to parse a matcher expression and return it as a DynMatcher. An
+  // empty optional is returned on failure, with the reason added to `error`.
+  // `namedValues` may be null.
+  static std::optional<DynMatcher>
+  parseMatcherExpression(llvm::StringRef &matcherCode,
+                         const Registry &matcherRegistry,
+                         const NamedValueMap *namedValues, Diagnostics *error);
+  static std::optional<DynMatcher>
+  parseMatcherExpression(llvm::StringRef &matcherCode,
+                         const Registry &matcherRegistry, Diagnostics *error) {
+    return parseMatcherExpression(matcherCode, matcherRegistry, nullptr, error);
+  }
+
+  // Methods to parse any expression supported by this parser. On success the
+  // result is stored in *value and true is returned.
+  static bool parseExpression(llvm::StringRef &code,
+                              const Registry &matcherRegistry,
+                              const NamedValueMap *namedValues,
+                              VariantValue *value, Diagnostics *error);
+
+  static bool parseExpression(llvm::StringRef &code,
+                              const Registry &matcherRegistry,
+                              VariantValue *value, Diagnostics *error) {
+    return parseExpression(code, matcherRegistry, nullptr, value, error);
+  }
+
+  // Methods to complete an expression at a given offset.
+  static std::vector<MatcherCompletion>
+  completeExpression(llvm::StringRef &code, unsigned completionOffset,
+                     const Registry &matcherRegistry,
+                     const NamedValueMap *namedValues);
+  static std::vector<MatcherCompletion>
+  completeExpression(llvm::StringRef &code, unsigned completionOffset,
+                     const Registry &matcherRegistry) {
+    return completeExpression(code, completionOffset, matcherRegistry, nullptr);
+  }
+
+private:
+  class CodeTokenizer;
+  struct ScopedContextEntry;
+  struct TokenInfo;
+
+  Parser(CodeTokenizer *tokenizer, const Registry &matcherRegistry,
+         const NamedValueMap *namedValues, Diagnostics *error);
+
+  // Parses one expression; the entry point of the recursive parser.
+  bool parseExpressionImpl(VariantValue *value);
+
+  // Parses the argument list of a matcher into `args`; the closing ')' token,
+  // if reached, is stored in `endToken`.
+  bool parseMatcherArgs(std::vector<ParserValue> &args, MatcherCtor ctor,
+                        const TokenInfo &nameToken, TokenInfo &endToken);
+
+  // Parses the arguments of the named matcher and constructs it through sema.
+  bool parseMatcherExpressionImpl(const TokenInfo &nameToken,
+                                  const TokenInfo &openToken,
+                                  std::optional<MatcherCtor> ctor,
+                                  VariantValue *value);
+
+  // Parses an expression that starts with an identifier.
+  bool parseIdentifierPrefixImpl(VariantValue *value);
+
+  // Adds `completion` to the result list if it matches the completion token.
+  void addCompletion(const TokenInfo &compToken,
+                     const MatcherCompletion &completion);
+  // Collects all completions for the pending code-completion token.
+  void addExpressionCompletions();
+
+  // Returns one completion per entry of the named-value map.
+  std::vector<MatcherCompletion>
+  getNamedValueCompletions(llvm::ArrayRef<ArgKind> acceptedTypes);
+
+  CodeTokenizer *const tokenizer;
+  std::unique_ptr<RegistrySema> sema;
+  const NamedValueMap *const namedValues;
+  Diagnostics *const error;
+
+  // Stack of (matcher constructor, index of the argument being parsed).
+  using ContextStackTy = std::vector<std::pair<MatcherCtor, unsigned>>;
+
+  ContextStackTy contextStack;
+  std::vector<MatcherCompletion> completions;
+};
+
+} // namespace mlir::query::matcher::internal
+
+#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_PARSER_H
diff --git a/mlir/lib/Query/Matcher/RegistryManager.cpp b/mlir/lib/Query/Matcher/RegistryManager.cpp
new file mode 100644
index 000000000000000..01856aa8ffa67f3
--- /dev/null
+++ b/mlir/lib/Query/Matcher/RegistryManager.cpp
@@ -0,0 +1,139 @@
+//===- RegistryManager.cpp - Matcher registry -----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Registry map populated at static initialization time.
+//
+//===----------------------------------------------------------------------===//
+
+#include "RegistryManager.h"
+#include "mlir/Query/Matcher/Registry.h"
+
+#include <set>
+#include <utility>
+
+namespace mlir::query::matcher {
+// Enum to string for autocomplete: returns the display name of `kind`.
+static std::string asArgString(ArgKind kind) {
+  switch (kind) {
+  case ArgKind::Matcher:
+    return "Matcher";
+  case ArgKind::String:
+    return "String";
+  }
+  llvm_unreachable("Unhandled ArgKind");
+}
+
+// Registers `callback` as the descriptor for `matcherName`, taking ownership
+// of it. Registering the same name twice trips the assertion; when assertions
+// are compiled out, the assignment below replaces the earlier descriptor.
+void Registry::registerMatcherDescriptor(
+ llvm::StringRef matcherName,
+ std::unique_ptr<internal::MatcherDescriptor> callback) {
+ assert(!constructorMap.contains(matcherName));
+ constructorMap[matcherName] = std::move(callback);
+}
+
+// Looks up the descriptor registered under `matcherName`. Returns an empty
+// optional if the registry has no such entry; otherwise a non-owning pointer
+// to the registered descriptor.
+std::optional<MatcherCtor>
+RegistryManager::lookupMatcherCtor(llvm::StringRef matcherName,
+                                   const Registry &matcherRegistry) {
+  const auto &constructors = matcherRegistry.constructors();
+  auto entry = constructors.find(matcherName);
+  if (entry == constructors.end())
+    return std::nullopt;
+  return entry->second.get();
+}
+
+// Computes the argument kinds acceptable at the completion point described by
+// `context`, a list of (matcher constructor, argument index) pairs. The result
+// always contains ArgKind::Matcher and is free of duplicates.
+std::vector<ArgKind> RegistryManager::getAcceptedCompletionTypes(
+    llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> context) {
+  // Seed with the top-level matcher kind, then add the kinds accepted by the
+  // argument indicated by each context element.
+  std::set<ArgKind> typeSet{ArgKind::Matcher};
+
+  for (const auto &[ctor, argNumber] : context) {
+    std::vector<ArgKind> kindsForArg;
+    // An index past the constructor's arguments contributes nothing.
+    if (argNumber < ctor->getNumArgs())
+      ctor->getArgKinds(argNumber, kindsForArg);
+
+    typeSet.insert(kindsForArg.begin(), kindsForArg.end());
+  }
+
+  return std::vector<ArgKind>(typeSet.begin(), typeSet.end());
+}
+
+// Builds one completion per matcher in `matcherRegistry`.
+//
+// The declaration text has the form "Matcher: name(kinds, kinds, ...)", where
+// the kinds of one argument are joined with '|'. Argument kinds are only
+// listed when `acceptedTypes` contains ArgKind::Matcher. The typed text is
+// "name(", completed to "name()" for matchers without arguments, or followed
+// by an opening quote when the first kind of the first argument is a string.
+std::vector<MatcherCompletion>
+RegistryManager::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes,
+                                       const Registry &matcherRegistry) {
+  std::vector<MatcherCompletion> completions;
+
+  // Decide once whether argument kinds are collected at all, so a repeated
+  // Matcher entry in `acceptedTypes` cannot duplicate them.
+  bool acceptsMatcher = false;
+  for (const ArgKind &kind : acceptedTypes) {
+    if (kind == ArgKind::Matcher) {
+      acceptsMatcher = true;
+      break;
+    }
+  }
+
+  // Search the registry for acceptable matchers.
+  for (const auto &m : matcherRegistry.constructors()) {
+    const internal::MatcherDescriptor &matcher = *m.getValue();
+    llvm::StringRef name = m.getKey();
+
+    unsigned numArgs = matcher.getNumArgs();
+    std::vector<std::vector<ArgKind>> argKinds(numArgs);
+
+    if (acceptsMatcher) {
+      for (unsigned arg = 0; arg != numArgs; ++arg)
+        matcher.getArgKinds(arg, argKinds[arg]);
+    }
+
+    std::string decl;
+    llvm::raw_string_ostream os(decl);
+
+    std::string typedText = std::string(name);
+    os << "Matcher: " << name << "(";
+
+    for (const std::vector<ArgKind> &arg : argKinds) {
+      if (&arg != &argKinds[0])
+        os << ", ";
+
+      // Alternatives for one argument are separated by '|'.
+      bool firstArgKind = true;
+      for (const ArgKind &argKind : arg) {
+        if (!firstArgKind)
+          os << "|";
+
+        firstArgKind = false;
+        os << asArgString(argKind);
+      }
+    }
+
+    os << ")";
+    typedText += "(";
+
+    if (argKinds.empty())
+      typedText += ")";
+    // The first argument may have no kinds recorded; never index into an
+    // empty list.
+    else if (!argKinds[0].empty() && argKinds[0][0] == ArgKind::String)
+      typedText += "\"";
+
+    completions.emplace_back(typedText, os.str());
+  }
+
+  return completions;
+}
+
+// Constructs a matcher by forwarding to ctor->create(nameRange, args, error).
+// `ctor` is dereferenced unconditionally, so it must not be null.
+VariantMatcher RegistryManager::constructMatcher(
+ MatcherCtor ctor, internal::SourceRange nameRange,
+ llvm::ArrayRef<ParserValue> args, internal::Diagnostics *error) {
+ return ctor->create(nameRange, args, error);
+}
+
+} // namespace mlir::query::matcher
diff --git a/mlir/lib/Query/Matcher/RegistryManager.h b/mlir/lib/Query/Matcher/RegistryManager.h
new file mode 100644
index 000000000000000..5f2867261225e76
--- /dev/null
+++ b/mlir/lib/Query/Matcher/RegistryManager.h
@@ -0,0 +1,70 @@
+//===--- RegistryManager.h - Matcher registry -------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// RegistryManager to manage registry of all known matchers.
+//
+// The registry provides a generic interface to construct any matcher by name.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_REGISTRYMANAGER_H
+#define MLIR_TOOLS_MLIRQUERY_MATCHER_REGISTRYMANAGER_H
+
+#include "Diagnostics.h"
+#include "mlir/Query/Matcher/Marshallers.h"
+#include "mlir/Query/Matcher/Registry.h"
+#include "mlir/Query/Matcher/VariantValue.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/StringMap.h"
+#include "llvm/ADT/StringRef.h"
+#include <string>
+
+namespace mlir::query::matcher {
+
+// Handle used to refer to a matcher constructor: a pointer to the matcher's
+// descriptor.
+using MatcherCtor = const internal::MatcherDescriptor *;
+
+// One completion candidate: the text to insert and a human-readable
+// declaration of the matcher it stands for.
+struct MatcherCompletion {
+  MatcherCompletion() = default;
+  MatcherCompletion(llvm::StringRef typedText, llvm::StringRef matcherDecl)
+      : typedText(std::string(typedText)),
+        matcherDecl(std::string(matcherDecl)) {}
+
+  // Two completions are equal when both the typed text and the declaration
+  // match.
+  bool operator==(const MatcherCompletion &other) const {
+    if (typedText != other.typedText)
+      return false;
+    return matcherDecl == other.matcherDecl;
+  }
+
+  // The text to type to select this matcher.
+  std::string typedText;
+
+  // The "declaration" of the matcher, with type information.
+  std::string matcherDecl;
+};
+
+// Static-only helper around a matcher Registry: name lookup, completion
+// support and matcher construction. The class cannot be instantiated.
+class RegistryManager {
+public:
+  RegistryManager() = delete;
+
+  // Looks up `matcherName` in `matcherRegistry`.
+  // NOTE(review): presumably yields std::nullopt when the name is not
+  // registered; confirm against the definition in RegistryManager.cpp.
+  static std::optional<MatcherCtor>
+  lookupMatcherCtor(llvm::StringRef matcherName,
+                    const Registry &matcherRegistry);
+
+  // Computes the argument kinds acceptable at the completion point described
+  // by `context`, a list of (matcher constructor, argument index) pairs.
+  static std::vector<ArgKind> getAcceptedCompletionTypes(
+      llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> context);
+
+  // Builds a completion entry for each matcher in `matcherRegistry`;
+  // `acceptedTypes` controls whether argument kinds are included.
+  static std::vector<MatcherCompletion>
+  getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes,
+                        const Registry &matcherRegistry);
+
+  // Creates a matcher through `ctor` from the parsed `args`. `nameRange` and
+  // `error` are handed to the descriptor for diagnostics.
+  static VariantMatcher constructMatcher(MatcherCtor ctor,
+                                         internal::SourceRange nameRange,
+                                         llvm::ArrayRef<ParserValue> args,
+                                         internal::Diagnostics *error);
+};
+
+} // namespace mlir::query::matcher
+
+#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_REGISTRYMANAGER_H
diff --git a/mlir/lib/Query/Matcher/VariantValue.cpp b/mlir/lib/Query/Matcher/VariantValue.cpp
new file mode 100644
index 000000000000000..65bd4bd77bcf8af
--- /dev/null
+++ b/mlir/lib/Query/Matcher/VariantValue.cpp
@@ -0,0 +1,132 @@
+//===--- VariantValue.cpp -------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Query/Matcher/VariantValue.h"
+
+namespace mlir::query::matcher {
+
+// Out-of-line definition of the defaulted Payload destructor.
+VariantMatcher::Payload::~Payload() = default;
+
+// Payload that wraps exactly one DynMatcher.
+class VariantMatcher::SinglePayload : public VariantMatcher::Payload {
+public:
+  explicit SinglePayload(DynMatcher wrapped) : matcher(std::move(wrapped)) {}
+
+  // Returns a copy of the wrapped matcher.
+  std::optional<DynMatcher> getDynMatcher() const override {
+    return std::optional<DynMatcher>(matcher);
+  }
+
+  // Type name reported for a single matcher.
+  std::string getTypeAsString() const override {
+    return std::string("Matcher");
+  }
+
+private:
+  DynMatcher matcher;
+};
+
+// Default construction leaves the matcher without a payload.
+VariantMatcher::VariantMatcher() = default;
+
+// Wraps `matcher` in a SinglePayload and returns a VariantMatcher holding it.
+VariantMatcher VariantMatcher::SingleMatcher(DynMatcher matcher) {
+  auto payload = std::make_shared<SinglePayload>(std::move(matcher));
+  return VariantMatcher(std::move(payload));
+}
+
+// Returns the payload's matcher, or std::nullopt when no payload is held.
+std::optional<DynMatcher> VariantMatcher::getDynMatcher() const {
+  if (!value)
+    return std::nullopt;
+  return value->getDynMatcher();
+}
+
+// Drops the held payload, if any.
+void VariantMatcher::reset() { value.reset(); }
+
+// Returns the type description of the held payload, or "<Nothing>" when no
+// payload is held.
+std::string VariantMatcher::getTypeAsString() const {
+  if (value)
+    return value->getTypeAsString();
+  return "<Nothing>";
+}
+
+// Copy construction: start out empty, then reuse copy assignment.
+VariantValue::VariantValue(const VariantValue &other)
+    : type(ValueType::Nothing) {
+  operator=(other);
+}
+
+// Constructs a value holding `string`. Only the StringRef itself is stored on
+// the heap; the characters it refers to are not copied.
+VariantValue::VariantValue(const llvm::StringRef string)
+    : type(ValueType::Nothing) {
+  // Starting from the empty state, setString() stores the string and sets the
+  // type tag.
+  setString(string);
+}
+
+// Constructs a value holding a heap-allocated copy of `matcher`.
+VariantValue::VariantValue(const VariantMatcher &matcher)
+    : type(ValueType::Nothing) {
+  // Starting from the empty state, setMatcher() stores the copy and sets the
+  // type tag.
+  setMatcher(matcher);
+}
+
+// Releases whatever heap storage the current value owns.
+VariantValue::~VariantValue() { reset(); }
+
+// Copy assignment: releases the current contents and takes a copy of
+// `other`'s contents. Self-assignment is a no-op.
+VariantValue &VariantValue::operator=(const VariantValue &other) {
+  if (this != &other) {
+    reset();
+    if (other.type == ValueType::String)
+      setString(other.getString());
+    else if (other.type == ValueType::Matcher)
+      setMatcher(other.getMatcher());
+    else if (other.type == ValueType::Nothing)
+      type = ValueType::Nothing;
+  }
+  return *this;
+}
+
+// Frees the storage owned by the current value and returns to the empty
+// (Nothing) state.
+void VariantValue::reset() {
+  if (type == ValueType::String)
+    delete value.String;
+  else if (type == ValueType::Matcher)
+    delete value.Matcher;
+  // ValueType::Nothing has nothing to free.
+
+  type = ValueType::Nothing;
+}
+
+// True when the value currently holds a string.
+bool VariantValue::isString() const { return type == ValueType::String; }
+
+// Returns the held string. Must only be called when isString() is true
+// (asserted).
+const llvm::StringRef &VariantValue::getString() const {
+ assert(isString());
+ return *value.String;
+}
+
+// Replaces the current contents with a heap-allocated copy of the StringRef
+// `newValue`.
+void VariantValue::setString(const llvm::StringRef &newValue) {
+ reset();
+ type = ValueType::String;
+ value.String = new llvm::StringRef(newValue);
+}
+
+// True when the value currently holds a matcher.
+bool VariantValue::isMatcher() const { return type == ValueType::Matcher; }
+
+// Returns the held matcher. Must only be called when isMatcher() is true
+// (asserted).
+const VariantMatcher &VariantValue::getMatcher() const {
+ assert(isMatcher());
+ return *value.Matcher;
+}
+
+// Replaces the current contents with a heap-allocated copy of `newValue`.
+void VariantValue::setMatcher(const VariantMatcher &newValue) {
+ reset();
+ type = ValueType::Matcher;
+ value.Matcher = new VariantMatcher(newValue);
+}
+
+// Returns the name of the kind of value currently held.
+std::string VariantValue::getTypeAsString() const {
+  if (type == ValueType::String)
+    return "String";
+  if (type == ValueType::Matcher)
+    return "Matcher";
+  if (type == ValueType::Nothing)
+    return "Nothing";
+  llvm_unreachable("Invalid Type");
+}
+
+} // namespace mlir::query::matcher
diff --git a/mlir/lib/Query/Query.cpp b/mlir/lib/Query/Query.cpp
new file mode 100644
index 000000000000000..5c42e5a5f0a116e
--- /dev/null
+++ b/mlir/lib/Query/Query.cpp
@@ -0,0 +1,82 @@
+//===---- Query.cpp - -----------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Query/Query.h"
+#include "QueryParser.h"
+#include "mlir/Query/Matcher/MatchFinder.h"
+#include "mlir/Query/QuerySession.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace mlir::query {
+
+// Parses `line` into a query; thin wrapper over QueryParser::parse.
+QueryRef parse(llvm::StringRef line, const QuerySession &qs) {
+ return QueryParser::parse(line, qs);
+}
+
+// Returns the completions for `line` with the cursor at offset `pos`; thin
+// wrapper over QueryParser::complete.
+std::vector<llvm::LineEditor::Completion>
+complete(llvm::StringRef line, size_t pos, const QuerySession &qs) {
+ return QueryParser::complete(line, pos, qs);
+}
+
+// Prints a note labelled with `binding` that points at the source position
+// of `op`.
+//
+// The position comes from the FileLineColLoc found in `op`'s location,
+// translated to an SMLoc inside the session's source buffer. When the
+// location carries no file/line/column information the note is still
+// printed, just without a source position, instead of dereferencing a null
+// location.
+static void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op,
+                       const std::string &binding) {
+  auto fileLoc = op->getLoc()->findInstanceOf<FileLineColLoc>();
+
+  // A default-constructed SMLoc is invalid; PrintMessage then emits the
+  // message without a source line.
+  llvm::SMLoc smloc;
+  if (fileLoc)
+    smloc = qs.getSourceManager().FindLocForLineAndColumn(
+        qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn());
+
+  qs.getSourceManager().PrintMessage(os, smloc, llvm::SourceMgr::DK_Note,
+                                     "\"" + binding + "\" binds here");
+}
+
+// Out-of-line definition of the defaulted Query destructor.
+Query::~Query() = default;
+
+// Prints the stored error message followed by a newline and reports failure.
+mlir::LogicalResult InvalidQuery::run(llvm::raw_ostream &os,
+ QuerySession &qs) const {
+ os << errStr << "\n";
+ return mlir::failure();
+}
+
+// Does nothing and reports success.
+mlir::LogicalResult NoOpQuery::run(llvm::raw_ostream &os,
+ QuerySession &qs) const {
+ return mlir::success();
+}
+
+// Prints the list of available commands to `os` and reports success.
+mlir::LogicalResult HelpQuery::run(llvm::raw_ostream &os,
+ QuerySession &qs) const {
+ // Adjacent string literals are concatenated into a single help text.
+ os << "Available commands:\n\n"
+ " match MATCHER, m MATCHER "
+ "Match the mlir against the given matcher.\n"
+ " quit "
+ "Terminates the query session.\n\n";
+ return mlir::success();
+}
+
+// Sets the session's terminate flag, which the interactive loop checks after
+// each query, and reports success.
+mlir::LogicalResult QuitQuery::run(llvm::raw_ostream &os,
+ QuerySession &qs) const {
+ qs.terminate = true;
+ return mlir::success();
+}
+
+// Runs the stored matcher over the session's root operation, prints a note
+// for every matching operation followed by the total number of matches, and
+// reports success.
+mlir::LogicalResult MatchQuery::run(llvm::raw_ostream &os,
+                                    QuerySession &qs) const {
+  std::vector<Operation *> found =
+      matcher::MatchFinder().getMatches(qs.getRootOp(), matcher);
+
+  os << "\n";
+
+  int ordinal = 0;
+  for (Operation *matched : found) {
+    ++ordinal;
+    os << "Match #" << ordinal << ":\n\n";
+    // Placeholder "root" binding for the initial draft.
+    printMatch(os, qs, matched, "root");
+  }
+
+  os << ordinal;
+  if (ordinal == 1)
+    os << " match.\n\n";
+  else
+    os << " matches.\n\n";
+
+  return mlir::success();
+}
+
+} // namespace mlir::query
diff --git a/mlir/lib/Query/QueryParser.cpp b/mlir/lib/Query/QueryParser.cpp
new file mode 100644
index 000000000000000..f43a28569f0078b
--- /dev/null
+++ b/mlir/lib/Query/QueryParser.cpp
@@ -0,0 +1,217 @@
+//===---- QueryParser.cpp - mlir-query command parser ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "QueryParser.h"
+#include "llvm/ADT/StringSwitch.h"
+
+namespace mlir::query {
+
+// Lex any amount of horizontal whitespace followed by a "word" from the
+// start of `line`. A word is either a single '#' or a maximal run of
+// non-whitespace characters. `line` is advanced past the lexed word.
+//
+// Newlines are never skipped as leading whitespace and always end a word, so
+// a word never spans a line break; lexing at a newline yields an empty word.
+// An empty result still points into `line`, which the completion code relies
+// on.
+llvm::StringRef QueryParser::lexWord() {
+  line = line.drop_while([](char c) {
+    // Don't trim newlines.
+    return llvm::StringRef(" \t\v\f\r").contains(c);
+  });
+
+  if (line.empty())
+    // Even though the line is empty, it contains a pointer and
+    // a (zero) length. The pointer is used in the LexOrCompleteWord
+    // code completion.
+    return line;
+
+  llvm::StringRef word;
+  if (line.front() == '#') {
+    word = line.substr(0, 1);
+  } else {
+    // A word stops at any whitespace character, including a newline;
+    // otherwise "help\nquit" would be lexed as the single word "help\nquit".
+    word = line.take_until([](char c) {
+      return llvm::StringRef(" \t\v\f\r\n").contains(c);
+    });
+  }
+
+  line = line.drop_front(word.size());
+  return word;
+}
+
+// A StringSwitch-alike that lexes a word in its constructor. Outside of
+// completion the cases are forwarded to an llvm::StringSwitch; when the
+// completion point lies in (or before) the lexed word, matching cases are
+// recorded as completions instead.
+template <typename T>
+struct QueryParser::LexOrCompleteWord {
+ llvm::StringRef word;
+ llvm::StringSwitch<T> stringSwitch;
+
+ QueryParser *queryParser;
+ // Set to the completion point offset in word, or StringRef::npos if
+ // completion point not in word.
+ size_t wordCompletionPos;
+
+ // Lexes a word and stores it in word. Returns a LexOrCompleteWord<T> object
+ // that can be used like a llvm::StringSwitch<T>, but adds cases as possible
+ // completions if the lexed word contains the completion point.
+ LexOrCompleteWord(QueryParser *queryParser, llvm::StringRef &outWord)
+ : word(queryParser->lexWord()), stringSwitch(word),
+ queryParser(queryParser), wordCompletionPos(llvm::StringRef::npos) {
+ outWord = word;
+ // Completion applies only when a completion position is set and it does
+ // not lie past the end of the lexed word.
+ if (queryParser->completionPos &&
+ queryParser->completionPos <= word.data() + word.size()) {
+ // A completion point before the word is treated as offset 0.
+ if (queryParser->completionPos < word.data())
+ wordCompletionPos = 0;
+ else
+ wordCompletionPos = queryParser->completionPos - word.data();
+ }
+ }
+
+ // Adds a case. Outside of completion it is forwarded to the StringSwitch.
+ // During completion, a non-empty case with isCompletion set whose prefix
+ // matches the word up to the completion point is recorded as a completion:
+ // the rest of the case plus a trailing space as the text to insert, and the
+ // full case string as the text to display.
+ LexOrCompleteWord &Case(llvm::StringLiteral caseStr, const T &value,
+ bool isCompletion = true) {
+
+ if (wordCompletionPos == llvm::StringRef::npos)
+ stringSwitch.Case(caseStr, value);
+ else if (!caseStr.empty() && isCompletion &&
+ wordCompletionPos <= caseStr.size() &&
+ caseStr.substr(0, wordCompletionPos) ==
+ word.substr(0, wordCompletionPos)) {
+
+ queryParser->completions.emplace_back(
+ (caseStr.substr(wordCompletionPos) + " ").str(),
+ std::string(caseStr));
+ }
+ return *this;
+ }
+
+ // Returns the matched case's value, or `value` when no case matched (which
+ // is always so during completion, since no case is forwarded then).
+ T Default(T value) { return stringSwitch.Default(value); }
+};
+
+// Finishes parsing a query that takes no further arguments.
+//
+// If the rest of the input (after blanks) starts with a line break, it is
+// stored as the query's remaining content. A trailing '#' comment is skipped
+// up to the next line and the check is repeated. Any other trailing word
+// turns the result into an InvalidQuery.
+QueryRef QueryParser::endQuery(QueryRef queryRef) {
+  const llvm::StringRef rest = line;
+  const llvm::StringRef restNoBlanks = rest.drop_while(
+      [](char c) { return llvm::StringRef(" \t\v\f\r").contains(c); });
+
+  const bool startsWithLf = !restNoBlanks.empty() && restNoBlanks[0] == '\n';
+  const bool startsWithCrLf = restNoBlanks.size() >= 2 &&
+                              restNoBlanks[0] == '\r' &&
+                              restNoBlanks[1] == '\n';
+  if (startsWithLf || startsWithCrLf) {
+    queryRef->remainingContent = rest;
+    return queryRef;
+  }
+
+  const llvm::StringRef trailingWord = lexWord();
+  if (trailingWord.empty())
+    return queryRef;
+
+  if (trailingWord.front() != '#')
+    return new InvalidQuery("unexpected extra input: '" + rest + "'");
+
+  // Trailing comment: skip to the start of the next line and check again.
+  line = line.drop_until([](char c) { return c == '\n'; });
+  line = line.drop_while([](char c) { return c == '\n'; });
+  return endQuery(queryRef);
+}
+
+namespace {
+
+// Kinds of commands recognised by QueryParser::doParse.
+enum class ParsedQueryKind {
+  Invalid,
+  Comment,
+  NoOp,
+  Help,
+  Match,
+  Quit,
+};
+
+// Renders `diag` to text and wraps the text in an InvalidQuery.
+QueryRef
+makeInvalidQueryFromDiagnostics(const matcher::internal::Diagnostics &diag) {
+  std::string rendered;
+  llvm::raw_string_ostream stream(rendered);
+  diag.print(stream);
+  // Make sure everything printed so far has reached `rendered`.
+  stream.flush();
+  return new InvalidQuery(rendered);
+}
+} // namespace
+
+// Collects completions for the matcher expression in `line`, with the
+// completion point expressed as an offset from the start of `line`. Each
+// result is appended to `completions`; the returned QueryRef is always empty.
+QueryRef QueryParser::completeMatcherExpression() {
+  const auto matcherCompletions =
+      matcher::internal::Parser::completeExpression(
+          line, completionPos - line.begin(), qs.getRegistryData(),
+          &qs.namedValues);
+
+  for (const matcher::MatcherCompletion &candidate : matcherCompletions)
+    completions.emplace_back(candidate.typedText, candidate.matcherDecl);
+
+  return QueryRef();
+}
+
+// Parses the command at the start of `line` and returns the matching query.
+// In completion mode (completionPos set) the command word and, for "match",
+// the matcher expression contribute entries to `completions` instead.
+QueryRef QueryParser::doParse() {
+
+ llvm::StringRef commandStr;
+ // Short aliases ("m", "q") and "#" are accepted but never offered as
+ // completions.
+ ParsedQueryKind qKind =
+ LexOrCompleteWord<ParsedQueryKind>(this, commandStr)
+ .Case("", ParsedQueryKind::NoOp)
+ .Case("#", ParsedQueryKind::Comment, /*isCompletion=*/false)
+ .Case("help", ParsedQueryKind::Help)
+ .Case("m", ParsedQueryKind::Match, /*isCompletion=*/false)
+ .Case("match", ParsedQueryKind::Match)
+ .Case("q", ParsedQueryKind::Quit, /*isCompletion=*/false)
+ .Case("quit", ParsedQueryKind::Quit)
+ .Default(ParsedQueryKind::Invalid);
+
+ switch (qKind) {
+ case ParsedQueryKind::Comment:
+ case ParsedQueryKind::NoOp:
+ // Skip the rest of this line and any following newlines, then parse
+ // whatever is left (or finish with a NoOpQuery if nothing is).
+ line = line.drop_until([](char c) { return c == '\n'; });
+ line = line.drop_while([](char c) { return c == '\n'; });
+ if (line.empty())
+ return new NoOpQuery;
+ return doParse();
+
+ case ParsedQueryKind::Help:
+ return endQuery(new HelpQuery);
+
+ case ParsedQueryKind::Quit:
+ return endQuery(new QuitQuery);
+
+ case ParsedQueryKind::Match: {
+ if (completionPos) {
+ return completeMatcherExpression();
+ }
+
+ matcher::internal::Diagnostics diag;
+ auto matcherSource = line.ltrim();
+ // Keep the untouched source so the consumed prefix can be computed below.
+ auto origMatcherSource = matcherSource;
+ std::optional<matcher::DynMatcher> matcher =
+ matcher::internal::Parser::parseMatcherExpression(
+ matcherSource, qs.getRegistryData(), &qs.namedValues, &diag);
+ if (!matcher) {
+ return makeInvalidQueryFromDiagnostics(diag);
+ }
+ // The matcher's source text is the original source minus whatever is left
+ // in matcherSource after parsing.
+ auto actualSource = origMatcherSource.slice(0, origMatcherSource.size() -
+ matcherSource.size());
+ QueryRef query = new MatchQuery(actualSource, *matcher);
+ query->remainingContent = matcherSource;
+ return query;
+ }
+
+ case ParsedQueryKind::Invalid:
+ return new InvalidQuery("unknown command: " + commandStr);
+ }
+
+ llvm_unreachable("Invalid query kind");
+}
+
+// Parses `line` with a temporary parser that has no completion point set.
+QueryRef QueryParser::parse(llvm::StringRef line, const QuerySession &qs) {
+ return QueryParser(line, qs).doParse();
+}
+
+// Returns the completions available for `line` with the cursor at byte
+// offset `pos`.
+std::vector<llvm::LineEditor::Completion>
+QueryParser::complete(llvm::StringRef line, size_t pos,
+                      const QuerySession &qs) {
+  QueryParser parser(line, qs);
+  // A non-null completion position switches the parser into completion mode.
+  parser.completionPos = line.data() + pos;
+
+  // Only the completions recorded while parsing matter; the query is dropped.
+  (void)parser.doParse();
+  return parser.completions;
+}
+
+} // namespace mlir::query
diff --git a/mlir/lib/Query/QueryParser.h b/mlir/lib/Query/QueryParser.h
new file mode 100644
index 000000000000000..e9c30eccecab9e4
--- /dev/null
+++ b/mlir/lib/Query/QueryParser.h
@@ -0,0 +1,59 @@
+//===--- QueryParser.h - ----------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TOOLS_MLIRQUERY_QUERYPARSER_H
+#define MLIR_TOOLS_MLIRQUERY_QUERYPARSER_H
+
+#include "Matcher/Parser.h"
+#include "mlir/Query/Query.h"
+#include "mlir/Query/QuerySession.h"
+
+#include "llvm/ADT/StringRef.h"
+#include "llvm/LineEditor/LineEditor.h"
+
+namespace mlir::query {
+
+class QuerySession;
+
+// Parser for mlir-query command lines. It is only used through the static
+// parse() and complete() entry points; instances are created internally.
+class QueryParser {
+public:
+ // Parse line as a query and return a QueryRef representing the query, which
+ // may be an InvalidQuery.
+ static QueryRef parse(llvm::StringRef line, const QuerySession &qs);
+
+ // Compute the completions for line with the cursor at offset pos.
+ static std::vector<llvm::LineEditor::Completion>
+ complete(llvm::StringRef line, size_t pos, const QuerySession &qs);
+
+private:
+ // A null completionPos means "parse"; complete() sets it before parsing.
+ QueryParser(llvm::StringRef line, const QuerySession &qs)
+ : line(line), completionPos(nullptr), qs(qs) {}
+
+ // Lex the next word from line and advance line past it.
+ llvm::StringRef lexWord();
+
+ // StringSwitch-like helper that also records completions.
+ template <typename T>
+ struct LexOrCompleteWord;
+
+ // Record completions for the matcher expression in line.
+ QueryRef completeMatcherExpression();
+
+ // Check what follows a complete query: store remaining content, skip a
+ // trailing comment, or return an InvalidQuery for unexpected extra input.
+ QueryRef endQuery(QueryRef queryRef);
+
+ // Parse the remaining input in line and return a reference to the parsed
+ // query object, which may be an InvalidQuery if a parse error occurs.
+ QueryRef doParse();
+
+ // The input that has not been consumed yet.
+ llvm::StringRef line;
+
+ // Position of the completion point, or nullptr when not completing.
+ const char *completionPos;
+ // Completions collected while parsing in completion mode.
+ std::vector<llvm::LineEditor::Completion> completions;
+
+ const QuerySession &qs;
+};
+
+} // namespace mlir::query
+
+#endif // MLIR_TOOLS_MLIRQUERY_QUERYPARSER_H
diff --git a/mlir/lib/Tools/CMakeLists.txt b/mlir/lib/Tools/CMakeLists.txt
index 6175a1ce5f8d1db..01270fa4b0fc341 100644
--- a/mlir/lib/Tools/CMakeLists.txt
+++ b/mlir/lib/Tools/CMakeLists.txt
@@ -2,6 +2,7 @@ add_subdirectory(lsp-server-support)
add_subdirectory(mlir-lsp-server)
add_subdirectory(mlir-opt)
add_subdirectory(mlir-pdll-lsp-server)
+add_subdirectory(mlir-query)
add_subdirectory(mlir-reduce)
add_subdirectory(mlir-tblgen)
add_subdirectory(mlir-translate)
diff --git a/mlir/lib/Tools/mlir-query/CMakeLists.txt b/mlir/lib/Tools/mlir-query/CMakeLists.txt
new file mode 100644
index 000000000000000..b81b02d42bfcaa4
--- /dev/null
+++ b/mlir/lib/Tools/mlir-query/CMakeLists.txt
@@ -0,0 +1,13 @@
+# LLVM component needed for the interactive prompt (llvm::LineEditor).
+set(LLVM_LINK_COMPONENTS
+ lineeditor
+ )
+
+# Library holding the reusable mlir-query driver (MlirQueryMain.cpp).
+add_mlir_library(MLIRQueryLib
+ MlirQueryMain.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Tools/mlir-query
+
+ LINK_LIBS PUBLIC
+ MLIRQuery
+ )
diff --git a/mlir/lib/Tools/mlir-query/MlirQueryMain.cpp b/mlir/lib/Tools/mlir-query/MlirQueryMain.cpp
new file mode 100644
index 000000000000000..15de16a8774bc07
--- /dev/null
+++ b/mlir/lib/Tools/mlir-query/MlirQueryMain.cpp
@@ -0,0 +1,115 @@
+//===- MlirQueryMain.cpp - MLIR Query main --------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the general framework of the MLIR query tool. It
+// parses the command line arguments, parses the MLIR file and outputs the query
+// results.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Tools/mlir-query/MlirQueryMain.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Parser/Parser.h"
+#include "mlir/Query/Query.h"
+#include "mlir/Query/QuerySession.h"
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/LineEditor/LineEditor.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/SourceMgr.h"
+
+//===----------------------------------------------------------------------===//
+// Query Parser
+//===----------------------------------------------------------------------===//
+
+// Entry point shared by mlir-query drivers.
+//
+// Parses the command line, opens and parses the input MLIR file, then either
+// runs every command passed with `-c` (stopping at the first one that fails)
+// or starts an interactive session with completion support.
+//
+// Returns failure when the input cannot be opened or parsed, or when a `-c`
+// command fails. Failures of interactive commands do not affect the result.
+mlir::LogicalResult
+mlir::mlirQueryMain(int argc, char **argv, MLIRContext &context,
+                    const mlir::query::matcher::Registry &matcherRegistry) {
+
+  // Override the default '-h' and use the default PrintHelpMessage() which
+  // won't print options in categories.
+  static llvm::cl::opt<bool> help("h", llvm::cl::desc("Alias for -help"),
+                                  llvm::cl::Hidden);
+
+  static llvm::cl::OptionCategory mlirQueryCategory("mlir-query options");
+
+  static llvm::cl::list<std::string> commands(
+      "c", llvm::cl::desc("Specify command to run"),
+      llvm::cl::value_desc("command"), llvm::cl::cat(mlirQueryCategory));
+
+  static llvm::cl::opt<std::string> inputFilename(
+      llvm::cl::Positional, llvm::cl::desc("<input file>"),
+      llvm::cl::cat(mlirQueryCategory));
+
+  // The two options below belong to the mlir-query category as well;
+  // otherwise HideUnrelatedOptions() would hide them from the help output.
+  static llvm::cl::opt<bool> noImplicitModule{
+      "no-implicit-module",
+      llvm::cl::desc(
+          "Disable implicit addition of a top-level module op during parsing"),
+      llvm::cl::init(false), llvm::cl::cat(mlirQueryCategory)};
+
+  static llvm::cl::opt<bool> allowUnregisteredDialects(
+      "allow-unregistered-dialect",
+      llvm::cl::desc("Allow operation with no registered dialects"),
+      llvm::cl::init(false), llvm::cl::cat(mlirQueryCategory));
+
+  llvm::cl::HideUnrelatedOptions(mlirQueryCategory);
+
+  llvm::InitLLVM y(argc, argv);
+
+  llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR test case query tool.\n");
+
+  if (help) {
+    llvm::cl::PrintHelpMessage();
+    return mlir::success();
+  }
+
+  // Set up the input file.
+  std::string errorMessage;
+  auto file = openInputFile(inputFilename, &errorMessage);
+  if (!file) {
+    llvm::errs() << errorMessage << "\n";
+    return mlir::failure();
+  }
+
+  auto sourceMgr = llvm::SourceMgr();
+  auto bufferId = sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc());
+
+  context.allowUnregisteredDialects(allowUnregisteredDialects);
+
+  // Parse the input MLIR file.
+  OwningOpRef<Operation *> opRef =
+      noImplicitModule ? parseSourceFile(sourceMgr, &context)
+                       : parseSourceFile<mlir::ModuleOp>(sourceMgr, &context);
+  if (!opRef)
+    return mlir::failure();
+
+  mlir::query::QuerySession qs(opRef.get(), sourceMgr, bufferId,
+                               matcherRegistry);
+  if (!commands.empty()) {
+    // Batch mode: run the `-c` commands in order; the first failure aborts.
+    for (auto &command : commands) {
+      mlir::query::QueryRef queryRef = mlir::query::parse(command, qs);
+      if (mlir::failed(queryRef->run(llvm::outs(), qs)))
+        return mlir::failure();
+    }
+  } else {
+    // Interactive mode: read queries until end of input or a quit query.
+    llvm::LineEditor le("mlir-query");
+    le.setListCompleter([&qs](llvm::StringRef line, size_t pos) {
+      return mlir::query::complete(line, pos, qs);
+    });
+    while (std::optional<std::string> line = le.readLine()) {
+      mlir::query::QueryRef queryRef = mlir::query::parse(*line, qs);
+      // A failing query is not fatal for the interactive session.
+      (void)queryRef->run(llvm::outs(), qs);
+      llvm::outs().flush();
+      if (qs.terminate)
+        break;
+    }
+  }
+
+  return mlir::success();
+}
diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt
index bf143d036c2f66f..6fc9ae0f3fc58fa 100644
--- a/mlir/test/CMakeLists.txt
+++ b/mlir/test/CMakeLists.txt
@@ -104,6 +104,7 @@ set(MLIR_TEST_DEPENDS
mlir-pdll-lsp-server
mlir-opt
mlir-pdll
+ mlir-query
mlir-reduce
mlir-tblgen
mlir-translate
diff --git a/mlir/test/mlir-query/simple-test.mlir b/mlir/test/mlir-query/simple-test.mlir
new file mode 100644
index 000000000000000..a4d006598767b3f
--- /dev/null
+++ b/mlir/test/mlir-query/simple-test.mlir
@@ -0,0 +1,16 @@
+// RUN: mlir-query %s -c "m isConstantOp()" | FileCheck %s
+
+// CHECK: {{.*}}.mlir:5:13: note: "root" binds here
+func.func @simple1() {
+ %c1_i32 = arith.constant 1 : i32
+ return
+}
+
+// CHECK: {{.*}}.mlir:12:11: note: "root" binds here
+// CHECK: {{.*}}.mlir:13:11: note: "root" binds here
+func.func @simple2() {
+ %cst1 = arith.constant 1.0 : f32
+ %cst2 = arith.constant 2.0 : f32
+ %add = arith.addf %cst1, %cst2 : f32
+ return
+}
diff --git a/mlir/tools/CMakeLists.txt b/mlir/tools/CMakeLists.txt
index e9a1e4d6251722e..a01f74f737e1bc1 100644
--- a/mlir/tools/CMakeLists.txt
+++ b/mlir/tools/CMakeLists.txt
@@ -2,6 +2,7 @@ add_subdirectory(mlir-lsp-server)
add_subdirectory(mlir-opt)
add_subdirectory(mlir-parser-fuzzer)
add_subdirectory(mlir-pdll-lsp-server)
+add_subdirectory(mlir-query)
add_subdirectory(mlir-reduce)
add_subdirectory(mlir-shlib)
add_subdirectory(mlir-spirv-cpu-runner)
diff --git a/mlir/tools/mlir-query/CMakeLists.txt b/mlir/tools/mlir-query/CMakeLists.txt
new file mode 100644
index 000000000000000..ef2e5a84b5569b5
--- /dev/null
+++ b/mlir/tools/mlir-query/CMakeLists.txt
@@ -0,0 +1,20 @@
+# Every dialect library registered in the global MLIR_DIALECT_LIBS property.
+get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
+
+# The test dialect is linked only when MLIR tests are enabled.
+if(MLIR_INCLUDE_TESTS)
+ set(test_libs
+ MLIRTestDialect
+ )
+endif()
+
+# The mlir-query executable.
+add_mlir_tool(mlir-query
+ mlir-query.cpp
+ )
+llvm_update_compile_flags(mlir-query)
+target_link_libraries(mlir-query
+ PRIVATE
+ ${dialect_libs}
+ ${test_libs}
+ MLIRQueryLib
+ )
+
+mlir_check_link_libraries(mlir-query)
diff --git a/mlir/tools/mlir-query/mlir-query.cpp b/mlir/tools/mlir-query/mlir-query.cpp
new file mode 100644
index 000000000000000..0ed4f94d5802b09
--- /dev/null
+++ b/mlir/tools/mlir-query/mlir-query.cpp
@@ -0,0 +1,63 @@
+//===- mlir-query.cpp - MLIR Query Driver ---------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is a command line utility that queries a file from/to MLIR using one
+// of the registered queries.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/InitAllDialects.h"
+#include "mlir/Query/Matcher/Registry.h"
+#include "mlir/Tools/mlir-query/MlirQueryMain.h"
+
+using namespace mlir;
+
+// m_Attr, m_Op and m_Constant are defined as overloaded functions, so their
+// names alone do not identify one function. These function types name the
+// wanted overloads and are used with static_cast when registering them.
+using HasOpAttrName = detail::AttrOpMatcher(StringRef);
+using HasOpName = detail::NameOpMatcher(StringRef);
+using IsConstantOp = detail::constant_op_matcher();
+
+// Declaration of the test dialect's registration hook; only declared (and
+// only called from main) when MLIR_INCLUDE_TESTS is defined.
+namespace test {
+#ifdef MLIR_INCLUDE_TESTS
+void registerTestDialect(DialectRegistry &);
+#endif
+} // namespace test
+
+// Driver: registers the dialects and the built-in matchers, then hands over
+// to mlirQueryMain. The exit code is non-zero when mlirQueryMain fails.
+int main(int argc, char **argv) {
+  // Dialects available to the parsed input.
+  DialectRegistry dialectRegistry;
+  registerAllDialects(dialectRegistry);
+#ifdef MLIR_INCLUDE_TESTS
+  test::registerTestDialect(dialectRegistry);
+#endif
+
+  query::matcher::Registry matcherRegistry;
+
+  // Matchers registered in alphabetical order for consistency:
+  matcherRegistry.registerMatcher("hasOpAttrName",
+                                  static_cast<HasOpAttrName *>(m_Attr));
+  matcherRegistry.registerMatcher("hasOpName", static_cast<HasOpName *>(m_Op));
+  matcherRegistry.registerMatcher("isConstantOp",
+                                  static_cast<IsConstantOp *>(m_Constant));
+  matcherRegistry.registerMatcher("isNegInfFloat", m_NegInfFloat);
+  matcherRegistry.registerMatcher("isNegZeroFloat", m_NegZeroFloat);
+  matcherRegistry.registerMatcher("isNonZero", m_NonZero);
+  matcherRegistry.registerMatcher("isOne", m_One);
+  matcherRegistry.registerMatcher("isOneFloat", m_OneFloat);
+  matcherRegistry.registerMatcher("isPosInfFloat", m_PosInfFloat);
+  matcherRegistry.registerMatcher("isPosZeroFloat", m_PosZeroFloat);
+  matcherRegistry.registerMatcher("isZero", m_Zero);
+  matcherRegistry.registerMatcher("isZeroFloat", m_AnyZeroFloat);
+
+  MLIRContext context(dialectRegistry);
+
+  const LogicalResult result =
+      mlirQueryMain(argc, argv, context, matcherRegistry);
+  return failed(result);
+}
More information about the Mlir-commits
mailing list