[Mlir-commits] [mlir] 02d9f4d - [mlir][mlir-query] Introduce mlir-query tool with autocomplete support

Jacques Pienaar llvmlistbot at llvm.org
Fri Oct 13 14:03:35 PDT 2023

Author: Devajith
Date: 2023-10-13T14:03:27-07:00
New Revision: 02d9f4d1f128e17e04ab6e602d3c9b9942612428

URL: https://github.com/llvm/llvm-project/commit/02d9f4d1f128e17e04ab6e602d3c9b9942612428
DIFF: https://github.com/llvm/llvm-project/commit/02d9f4d1f128e17e04ab6e602d3c9b9942612428.diff

LOG: [mlir][mlir-query] Introduce mlir-query tool with autocomplete support

This commit adds the initial version of the mlir-query tool, which leverages the pre-existing matchers defined in mlir/include/mlir/IR/Matchers.h

The tool provides the following set of basic queries:


Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D155127




diff  --git a/mlir/include/mlir/Query/Matcher/ErrorBuilder.h b/mlir/include/mlir/Query/Matcher/ErrorBuilder.h
new file mode 100644
index 000000000000000..1073daed8703f5a
--- /dev/null
+++ b/mlir/include/mlir/Query/Matcher/ErrorBuilder.h
@@ -0,0 +1,63 @@
+//===--- ErrorBuilder.h - Helper for building error messages ----*- C++ -*-===//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// ErrorBuilder to manage error messages.
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/Twine.h"
+#include <initializer_list>
+namespace mlir::query::matcher::internal {
+class Diagnostics;
+// Represents the line and column numbers in a source query.
+// Both fields are value-initialized to zero. Diagnostics.cpp only prints a
+// "line:column: " prefix when both values are greater than zero.
+struct SourceLocation {
+  unsigned line{};
+  unsigned column{};
+// Represents a range in a source query, defined by its start and end locations.
+// NOTE(review): whether `end` is inclusive or exclusive is not visible in this
+// header - confirm against the parser before relying on it.
+struct SourceRange {
+  SourceLocation start{};
+  SourceLocation end{};
+// All errors from the system.
+// Each enumerator is translated into a user-facing format string by
+// errorTypeToFormatString() in Diagnostics.cpp; the "$N" placeholders of that
+// string are filled from the texts passed to addError().
+enum class ErrorType {
+  None,
+  // Parser Errors
+  ParserFailedToBuildMatcher,
+  ParserInvalidToken,
+  ParserNoCloseParen,
+  ParserNoCode,
+  ParserNoComma,
+  ParserNoOpenParen,
+  ParserNotAMatcher,
+  ParserOverloadedType,
+  ParserStringError,
+  ParserTrailingCode,
+  // Registry Errors
+  RegistryMatcherNotFound,
+  RegistryValueNotFound,
+  RegistryWrongArgCount,
+  RegistryWrongArgType
+// Records an error of kind `errorType` at `range` on `error`. Every entry of
+// `errorTexts` is forwarded, in order, as a positional argument ($0, $1, ...)
+// of the error's format string. The implementation dereferences `error`
+// unconditionally, so it must not be null.
+void addError(Diagnostics *error, SourceRange range, ErrorType errorType,
+              std::initializer_list<llvm::Twine> errorTexts);
+} // namespace mlir::query::matcher::internal

diff  --git a/mlir/include/mlir/Query/Matcher/Marshallers.h b/mlir/include/mlir/Query/Matcher/Marshallers.h
new file mode 100644
index 000000000000000..6ed35ac0ddccc70
--- /dev/null
+++ b/mlir/include/mlir/Query/Matcher/Marshallers.h
@@ -0,0 +1,199 @@
+//===--- Marshallers.h - Generic matcher function marshallers ---*- C++ -*-===//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// This file contains function templates and classes to wrap matcher construct
+// functions. It provides a collection of template function and classes that
+// present a generic marshalling layer on top of matcher construct functions.
+// The registry uses these to export all marshaller constructors with a uniform
+// interface. This mechanism takes inspiration from clang-query.
+#include "ErrorBuilder.h"
+#include "VariantValue.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/StringRef.h"
+namespace mlir::query::matcher::internal {
+// Helper template class for jumping from argument type to the correct is/get
+// functions in VariantValue. This is used for verifying and extracting the
+// matcher arguments.
+template <class T>
+struct ArgTypeTraits;
+template <class T>
+struct ArgTypeTraits<const T &> : public ArgTypeTraits<T> {};
+// Traits for matcher arguments of string type.
+template <>
+struct ArgTypeTraits<llvm::StringRef> {
+  // True when `value` currently holds a string.
+  static bool hasCorrectType(const VariantValue &value) {
+    return value.isString();
+  }
+  // Returns the string held by `value`; call only after hasCorrectType().
+  static const llvm::StringRef &get(const VariantValue &value) {
+    return value.getString();
+  }
+  // The argument kind reported for this argument type.
+  static ArgKind getKind() { return ArgKind::String; }
+  // No suggestion is offered for string arguments.
+  static std::optional<std::string> getBestGuess(const VariantValue &) {
+    return std::nullopt;
+  }
+// Traits for matcher arguments that are themselves matchers.
+template <>
+struct ArgTypeTraits<DynMatcher> {
+  // True when `value` currently holds a matcher.
+  static bool hasCorrectType(const VariantValue &value) {
+    return value.isMatcher();
+  }
+  // Returns a copy of the DynMatcher held by `value`.
+  // NOTE(review): the std::optional returned by getDynMatcher() is
+  // dereferenced without a check; presumably this is only reached after
+  // hasCorrectType() succeeded - confirm a matcher value can never be null.
+  static DynMatcher get(const VariantValue &value) {
+    return *value.getMatcher().getDynMatcher();
+  }
+  // The argument kind reported for this argument type.
+  static ArgKind getKind() { return ArgKind::Matcher; }
+  // No suggestion is offered for matcher arguments.
+  static std::optional<std::string> getBestGuess(const VariantValue &) {
+    return std::nullopt;
+  }
+// Interface for generic matcher descriptor.
+// Offers a create() method that constructs the matcher from the provided
+// arguments.
+class MatcherDescriptor {
+  virtual ~MatcherDescriptor() = default;
+  // Builds the matcher.
+  //   nameRange: source range of the matcher name, used for error reporting.
+  //   args:      the parsed arguments, each annotated with its source range.
+  //   error:     sink that receives diagnostics when construction fails.
+  virtual VariantMatcher create(SourceRange nameRange,
+                                const llvm::ArrayRef<ParserValue> args,
+                                Diagnostics *error) const = 0;
+  // Returns the number of arguments accepted by the matcher.
+  virtual unsigned getNumArgs() const = 0;
+  // Append the set of argument types accepted for argument 'argNo' to
+  // 'argKinds'.
+  virtual void getArgKinds(unsigned argNo,
+                           std::vector<ArgKind> &argKinds) const = 0;
+// MatcherDescriptor for matchers taking a fixed number of arguments.
+// It keeps a type-erased pointer to the matcher construct function together
+// with the marshaller that knows how to unpack the arguments and call it.
+class FixedArgCountMatcherDescriptor : public MatcherDescriptor {
+  using MarshallerType = VariantMatcher (*)(void (*matcherFunc)(),
+                                            llvm::StringRef matcherName,
+                                            SourceRange nameRange,
+                                            llvm::ArrayRef<ParserValue> args,
+                                            Diagnostics *error);
+  // Marshaller Function to unpack the arguments and call Func. Func is the
+  // Matcher construct function. This is the function that the matcher
+  // expressions would use to create the matcher.
+  //
+  // The matcher name is copied into an owning std::string: the descriptor
+  // lives in the registry for the whole session, while callers such as
+  // Registry::registerMatcher pass a name that may be a temporary.
+  FixedArgCountMatcherDescriptor(MarshallerType marshaller,
+                                 void (*matcherFunc)(),
+                                 llvm::StringRef matcherName,
+                                 llvm::ArrayRef<ArgKind> argKinds)
+      : marshaller(marshaller), matcherFunc(matcherFunc),
+        matcherName(matcherName.str()),
+        argKinds(argKinds.begin(), argKinds.end()) {}
+  // Delegates to the marshaller, which validates `args` and invokes the
+  // matcher construct function.
+  VariantMatcher create(SourceRange nameRange, llvm::ArrayRef<ParserValue> args,
+                        Diagnostics *error) const override {
+    return marshaller(matcherFunc, matcherName, nameRange, args, error);
+  }
+  // The number of arguments equals the number of recorded argument kinds.
+  unsigned getNumArgs() const override { return argKinds.size(); }
+  // Appends the single kind accepted at position `argNo`; `argNo` must be
+  // smaller than getNumArgs().
+  void getArgKinds(unsigned argNo, std::vector<ArgKind> &kinds) const override {
+    kinds.push_back(argKinds[argNo]);
+  }
+  const MarshallerType marshaller;
+  void (*const matcherFunc)();
+  // Owned copy of the matcher name (see constructor comment).
+  const std::string matcherName;
+  const std::vector<ArgKind> argKinds;
+// Verifies that exactly `expectedArgCount` arguments were supplied.
+// On a mismatch a RegistryWrongArgCount error carrying the expected and the
+// actual count is recorded at `nameRange` and false is returned.
+inline bool checkArgCount(SourceRange nameRange, size_t expectedArgCount,
+                          llvm::ArrayRef<ParserValue> args,
+                          Diagnostics *error) {
+  const size_t actualArgCount = args.size();
+  if (actualArgCount == expectedArgCount)
+    return true;
+  addError(error, nameRange, ErrorType::RegistryWrongArgCount,
+           {llvm::Twine(expectedArgCount), llvm::Twine(actualArgCount)});
+  return false;
+// Helper function for checking the type of the argument at position `Index`.
+// Returns true when the argument can be read as `ArgType`. Otherwise records
+// a RegistryWrongArgType error at the argument's source range and returns
+// false. `matcherName` is kept for interface compatibility; the error message
+// has no placeholder for it.
+template <typename ArgType, size_t Index>
+inline bool checkArgTypeAtIndex(llvm::StringRef matcherName,
+                                llvm::ArrayRef<ParserValue> args,
+                                Diagnostics *error) {
+  if (ArgTypeTraits<ArgType>::hasCorrectType(args[Index].value))
+    return true;
+  // The RegistryWrongArgType format string reads
+  //   "Incorrect type for arg $0. (Expected = $1) != (Actual = $2)"
+  // so the arguments must be, in order: the 1-based argument number, the
+  // expected kind and the actual type of the supplied value.
+  const llvm::StringRef expectedKind =
+      ArgTypeTraits<ArgType>::getKind() == ArgKind::Matcher ? "Matcher"
+                                                            : "String";
+  addError(error, args[Index].range, ErrorType::RegistryWrongArgType,
+           {llvm::Twine(Index + 1), llvm::Twine(expectedKind),
+            llvm::Twine(args[Index].value.getTypeAsString())});
+  return false;
+// Marshaller for matcher construct functions with a fixed argument list.
+// Validates the argument count and the type of every argument; when both
+// checks pass, casts `matcherFunc` back to its real signature, calls it with
+// the unpacked arguments and wraps the result. Returns a default-constructed
+// VariantMatcher when validation fails.
+template <typename ReturnType, typename... ArgTypes, size_t... Is>
+static VariantMatcher
+matcherMarshallFixedImpl(void (*matcherFunc)(), llvm::StringRef matcherName,
+                         SourceRange nameRange,
+                         llvm::ArrayRef<ParserValue> args, Diagnostics *error,
+                         std::index_sequence<Is...>) {
+  using FuncType = ReturnType (*)(ArgTypes...);
+  // The count must be right before any args[Is] is touched.
+  const bool argCountOk =
+      checkArgCount(nameRange, sizeof...(ArgTypes), args, error);
+  if (!argCountOk)
+    return VariantMatcher();
+  // The fold stops at the first argument whose type does not match.
+  const bool argTypesOk =
+      (... && checkArgTypeAtIndex<ArgTypes, Is>(matcherName, args, error));
+  if (!argTypesOk)
+    return VariantMatcher();
+  ReturnType matcherFn = reinterpret_cast<FuncType>(matcherFunc)(
+      ArgTypeTraits<ArgTypes>::get(args[Is].value)...);
+  return VariantMatcher::SingleMatcher(
+      *DynMatcher::constructDynMatcherFromMatcherFn(matcherFn));
+// Entry point with the MarshallerType signature: builds the index sequence
+// for the argument pack and delegates to matcherMarshallFixedImpl.
+template <typename ReturnType, typename... ArgTypes>
+static VariantMatcher
+matcherMarshallFixed(void (*matcherFunc)(), llvm::StringRef matcherName,
+                     SourceRange nameRange, llvm::ArrayRef<ParserValue> args,
+                     Diagnostics *error) {
+  using ArgIndices = std::index_sequence_for<ArgTypes...>;
+  return matcherMarshallFixedImpl<ReturnType, ArgTypes...>(
+      matcherFunc, matcherName, nameRange, args, error, ArgIndices{});
+// Fixed number of arguments overload
+// Wraps `matcherFunc` in a FixedArgCountMatcherDescriptor. The function
+// pointer is type-erased to `void (*)()` here and cast back to its real
+// signature by matcherMarshallFixedImpl when the matcher is created.
+// NOTE(review): the return-type line of this definition is not present in
+// this excerpt.
+template <typename ReturnType, typename... ArgTypes>
+makeMatcherAutoMarshall(ReturnType (*matcherFunc)(ArgTypes...),
+                        llvm::StringRef matcherName) {
+  // Create a vector of argument kinds
+  std::vector<ArgKind> argKinds = {ArgTypeTraits<ArgTypes>::getKind()...};
+  return std::make_unique<FixedArgCountMatcherDescriptor>(
+      matcherMarshallFixed<ReturnType, ArgTypes...>,
+      reinterpret_cast<void (*)()>(matcherFunc), matcherName, argKinds);
+} // namespace mlir::query::matcher::internal

diff  --git a/mlir/include/mlir/Query/Matcher/MatchFinder.h b/mlir/include/mlir/Query/Matcher/MatchFinder.h
new file mode 100644
index 000000000000000..b008a21f53ae2a6
--- /dev/null
+++ b/mlir/include/mlir/Query/Matcher/MatchFinder.h
@@ -0,0 +1,41 @@
+//===- MatchFinder.h - ------------------------------------------*- C++ -*-===//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// This file contains the MatchFinder class, which is used to find operations
+// that match a given matcher.
+#include "MatchersInternal.h"
+namespace mlir::query::matcher {
+// MatchFinder collects the operations accepted by a matcher.
+class MatchFinder {
+  // Walks `root` and returns, in visitation order, every visited operation
+  // for which `matcher.match()` is true.
+  static std::vector<Operation *> getMatches(Operation *root,
+                                             DynMatcher matcher) {
+    std::vector<Operation *> result;
+    root->walk([&result, &matcher](Operation *candidate) {
+      if (!matcher.match(candidate))
+        return;
+      result.push_back(candidate);
+    });
+    return result;
+  }
+} // namespace mlir::query::matcher

diff  --git a/mlir/include/mlir/Query/Matcher/MatchersInternal.h b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
new file mode 100644
index 000000000000000..67455be592393b4
--- /dev/null
+++ b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
@@ -0,0 +1,72 @@
+//===- MatchersInternal.h - Structural query framework ----------*- C++ -*-===//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Implements the base layer of the matcher framework.
+// Matchers are methods that return a Matcher which provides a method
+// match(Operation *op)
+// The matcher functions are defined in include/mlir/IR/Matchers.h.
+// This file contains the wrapper classes needed to construct matchers for
+// mlir-query.
+#include "mlir/IR/Matchers.h"
+#include "llvm/ADT/IntrusiveRefCntPtr.h"
+namespace mlir::query::matcher {
+// Generic interface for matchers on an MLIR operation.
+// Instances are reference counted (ThreadSafeRefCountedBase) and are held by
+// DynMatcher through an IntrusiveRefCntPtr.
+class MatcherInterface
+    : public llvm::ThreadSafeRefCountedBase<MatcherInterface> {
+  virtual ~MatcherInterface() = default;
+  // Returns true when `op` satisfies the matcher.
+  virtual bool match(Operation *op) = 0;
+// MatcherFnImpl takes a matcher function object and implements
+// MatcherInterface.
+template <typename MatcherFn>
+class MatcherFnImpl : public MatcherInterface {
+  // Stores a copy of `matcherFn`; the argument itself is not retained.
+  MatcherFnImpl(MatcherFn &matcherFn) : matcherFn(matcherFn) {}
+  // Forwards to the wrapped function object's match().
+  bool match(Operation *op) override { return matcherFn.match(op); }
+  MatcherFn matcherFn;
+// DynMatcher wraps a MatcherInterface implementation and provides a match()
+// method that redirects calls to the underlying implementation.
+class DynMatcher {
+  // Takes ownership of the provided implementation pointer.
+  DynMatcher(MatcherInterface *implementation)
+      : implementation(implementation) {}
+  // Wraps a copy of `matcherFn` in a MatcherFnImpl and returns a DynMatcher
+  // owning it. The unique_ptr is released because ownership passes to the
+  // intrusive reference count held by the new DynMatcher.
+  template <typename MatcherFn>
+  static std::unique_ptr<DynMatcher>
+  constructDynMatcherFromMatcherFn(MatcherFn &matcherFn) {
+    auto impl = std::make_unique<MatcherFnImpl<MatcherFn>>(matcherFn);
+    return std::make_unique<DynMatcher>(impl.release());
+  }
+  // Returns true when the wrapped implementation matches `op`.
+  bool match(Operation *op) const { return implementation->match(op); }
+  llvm::IntrusiveRefCntPtr<MatcherInterface> implementation;
+} // namespace mlir::query::matcher

diff  --git a/mlir/include/mlir/Query/Matcher/Registry.h b/mlir/include/mlir/Query/Matcher/Registry.h
new file mode 100644
index 000000000000000..e929b4a04d151db
--- /dev/null
+++ b/mlir/include/mlir/Query/Matcher/Registry.h
@@ -0,0 +1,51 @@
+//===--- Registry.h - Matcher Registry --------------------------*- C++ -*-===//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Registry class to manage the registry of matchers using a map.
+// This class provides a convenient interface for registering and accessing
+// matcher constructors using a string-based map.
+#include "Marshallers.h"
+#include "llvm/ADT/StringMap.h"
+#include <string>
+namespace mlir::query::matcher {
+using ConstructorMap =
+    llvm::StringMap<std::unique_ptr<const internal::MatcherDescriptor>>;
+// Maps matcher names to the descriptors used to construct them.
+class Registry {
+  Registry() = default;
+  ~Registry() = default;
+  // Read-only access to the name -> descriptor map.
+  const ConstructorMap &constructors() const { return constructorMap; }
+  // Registers the matcher construct function `matcher` under `name` by
+  // wrapping it in an auto-marshalling descriptor.
+  template <typename MatcherType>
+  void registerMatcher(const std::string &name, MatcherType matcher) {
+    registerMatcherDescriptor(name,
+                              internal::makeMatcherAutoMarshall(matcher, name));
+  }
+  // Stores `callback` under `matcherName`; defined out of line.
+  void registerMatcherDescriptor(
+      llvm::StringRef matcherName,
+      std::unique_ptr<internal::MatcherDescriptor> callback);
+  ConstructorMap constructorMap;
+} // namespace mlir::query::matcher

diff  --git a/mlir/include/mlir/Query/Matcher/VariantValue.h b/mlir/include/mlir/Query/Matcher/VariantValue.h
new file mode 100644
index 000000000000000..449f8b3a01e0217
--- /dev/null
+++ b/mlir/include/mlir/Query/Matcher/VariantValue.h
@@ -0,0 +1,128 @@
+//===--- VariantValue.h -----------------------------------------*- C++ -*-===//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Supports all the types required for dynamic Matcher construction.
+// Used by the registry to construct matchers in a generic way.
+#include "ErrorBuilder.h"
+#include "MatchersInternal.h"
+#include "llvm/ADT/StringRef.h"
+namespace mlir::query::matcher {
+// All types that VariantValue can contain.
+enum class ArgKind { Matcher, String };
+// A variant matcher object to abstract simple and complex matchers into a
+// single object type.
+class VariantMatcher {
+  class MatcherOps;
+  // Payload interface to be specialized by each matcher type. It follows a
+  // similar interface as VariantMatcher itself.
+  class Payload {
+  public:
+    virtual ~Payload();
+    virtual std::optional<DynMatcher> getDynMatcher() const = 0;
+    virtual std::string getTypeAsString() const = 0;
+  };
+  // A null matcher.
+  VariantMatcher();
+  // Clones the provided matcher.
+  static VariantMatcher SingleMatcher(DynMatcher matcher);
+  // Makes the matcher the "null" matcher.
+  void reset();
+  // Checks if the matcher is null.
+  bool isNull() const { return !value; }
+  // Returns the matcher
+  std::optional<DynMatcher> getDynMatcher() const;
+  // String representation of the type of the value.
+  std::string getTypeAsString() const;
+  // Wraps an existing payload.
+  explicit VariantMatcher(std::shared_ptr<Payload> value)
+      : value(std::move(value)) {}
+  // Payload implementation declared here and defined elsewhere.
+  class SinglePayload;
+  // The payload, shared between copies; empty for the null matcher.
+  std::shared_ptr<const Payload> value;
+// Variant value class with a tagged union with value type semantics. It is used
+// by the registry as the return value and argument type for the matcher factory
+// methods. It can be constructed from any of the supported types:
+//  - StringRef
+//  - VariantMatcher
+class VariantValue {
+  // Creates an empty value (ValueType::Nothing).
+  VariantValue() : type(ValueType::Nothing) {}
+  VariantValue(const VariantValue &other);
+  ~VariantValue();
+  VariantValue &operator=(const VariantValue &other);
+  // Specific constructors for each supported type.
+  VariantValue(const llvm::StringRef string);
+  VariantValue(const VariantMatcher &matcher);
+  // String value functions.
+  bool isString() const;
+  const llvm::StringRef &getString() const;
+  void setString(const llvm::StringRef &string);
+  // Matcher value functions.
+  bool isMatcher() const;
+  const VariantMatcher &getMatcher() const;
+  void setMatcher(const VariantMatcher &matcher);
+  // String representation of the type of the value.
+  std::string getTypeAsString() const;
+  void reset();
+  // Tag identifying which member of AllValues is active.
+  enum class ValueType {
+    Nothing,
+    String,
+    Matcher,
+  };
+  // Storage for the held value, selected by `type`.
+  // NOTE(review): the members are raw pointers; presumably they are allocated
+  // and released by the out-of-line special member functions - confirm in
+  // VariantValue.cpp.
+  union AllValues {
+    llvm::StringRef *String;
+    VariantMatcher *Matcher;
+  };
+  ValueType type;
+  AllValues value;
+// A VariantValue instance annotated with its parser context.
+struct ParserValue {
+  ParserValue() {}
+  // Source text associated with the value.
+  llvm::StringRef text;
+  // Location of the value in the query; used when reporting errors.
+  internal::SourceRange range;
+  // The parsed value itself.
+  VariantValue value;
+} // namespace mlir::query::matcher

diff  --git a/mlir/include/mlir/Query/Query.h b/mlir/include/mlir/Query/Query.h
new file mode 100644
index 000000000000000..447fc7ca21c8da4
--- /dev/null
+++ b/mlir/include/mlir/Query/Query.h
@@ -0,0 +1,109 @@
+//===--- Query.h ------------------------------------------------*- C++ -*-===//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#include "Matcher/VariantValue.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/IntrusiveRefCntPtr.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/LineEditor/LineEditor.h"
+#include <string>
+namespace mlir::query {
+enum class QueryKind { Invalid, NoOp, Help, Match, Quit };
+class QuerySession;
+// Base class of all queries. Instances are reference counted and handled
+// through QueryRef.
+struct Query : llvm::RefCountedBase<Query> {
+  Query(QueryKind kind) : kind(kind) {}
+  virtual ~Query();
+  // Perform the query on qs and print output to os.
+  virtual mlir::LogicalResult run(llvm::raw_ostream &os,
+                                  QuerySession &qs) const = 0;
+  // NOTE(review): presumably the part of the input line left over after this
+  // query was parsed - confirm against the query parser.
+  llvm::StringRef remainingContent;
+  // Discriminator read by the classof() helpers of the derived queries.
+  const QueryKind kind;
+typedef llvm::IntrusiveRefCntPtr<Query> QueryRef;
+QueryRef parse(llvm::StringRef line, const QuerySession &qs);
+complete(llvm::StringRef line, size_t pos, const QuerySession &qs);
+// Any query which resulted in a parse error. The error message is in errStr.
+struct InvalidQuery : Query {
+  // Stores a copy of the rendered error message.
+  InvalidQuery(const llvm::Twine &errStr)
+      : Query(QueryKind::Invalid), errStr(errStr.str()) {}
+  mlir::LogicalResult run(llvm::raw_ostream &os,
+                          QuerySession &qs) const override;
+  std::string errStr;
+  // Kind test used for type checks on Query pointers.
+  static bool classof(const Query *query) {
+    return query->kind == QueryKind::Invalid;
+  }
+// No-op query (i.e. a blank line).
+struct NoOpQuery : Query {
+  NoOpQuery() : Query(QueryKind::NoOp) {}
+  mlir::LogicalResult run(llvm::raw_ostream &os,
+                          QuerySession &qs) const override;
+  // Kind test used for type checks on Query pointers.
+  static bool classof(const Query *query) {
+    return query->kind == QueryKind::NoOp;
+  }
+// Query for "help".
+struct HelpQuery : Query {
+  HelpQuery() : Query(QueryKind::Help) {}
+  mlir::LogicalResult run(llvm::raw_ostream &os,
+                          QuerySession &qs) const override;
+  // Kind test used for type checks on Query pointers.
+  static bool classof(const Query *query) {
+    return query->kind == QueryKind::Help;
+  }
+// Query for "quit".
+struct QuitQuery : Query {
+  QuitQuery() : Query(QueryKind::Quit) {}
+  mlir::LogicalResult run(llvm::raw_ostream &os,
+                          QuerySession &qs) const override;
+  // Kind test used for type checks on Query pointers.
+  static bool classof(const Query *query) {
+    return query->kind == QueryKind::Quit;
+  }
+// Query for "match MATCHER".
+struct MatchQuery : Query {
+  // `matcher` is copied into the query; `source` is stored as a view only.
+  MatchQuery(llvm::StringRef source, const matcher::DynMatcher &matcher)
+      : Query(QueryKind::Match), matcher(matcher), source(source) {}
+  mlir::LogicalResult run(llvm::raw_ostream &os,
+                          QuerySession &qs) const override;
+  const matcher::DynMatcher matcher;
+  // NOTE(review): non-owning view; the referenced text must outlive this
+  // query - confirm at the construction site.
+  llvm::StringRef source;
+  // Kind test used for type checks on Query pointers.
+  static bool classof(const Query *query) {
+    return query->kind == QueryKind::Match;
+  }
+} // namespace mlir::query

diff  --git a/mlir/include/mlir/Query/QuerySession.h b/mlir/include/mlir/Query/QuerySession.h
new file mode 100644
index 000000000000000..b03a8cae8f18132
--- /dev/null
+++ b/mlir/include/mlir/Query/QuerySession.h
@@ -0,0 +1,42 @@
+//===--- QuerySession.h -----------------------------------------*- C++ -*-===//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#include "llvm/ADT/StringMap.h"
+namespace mlir::query {
+class Registry;
+// Represents the state for a particular mlir-query session.
+// Holds the root operation being queried, the source manager and buffer id of
+// the input, and the matcher registry. The source manager and the registry
+// are stored by reference, so both must outlive the session.
+class QuerySession {
+  QuerySession(Operation *rootOp, llvm::SourceMgr &sourceMgr, unsigned bufferId,
+               const matcher::Registry &matcherRegistry)
+      : rootOp(rootOp), sourceMgr(sourceMgr), bufferId(bufferId),
+        matcherRegistry(matcherRegistry) {}
+  // Accessors. All of them are const so that they can be used through the
+  // `const QuerySession &` handed to parse() and complete().
+  Operation *getRootOp() const { return rootOp; }
+  llvm::SourceMgr &getSourceManager() const { return sourceMgr; }
+  unsigned getBufferId() const { return bufferId; }
+  const matcher::Registry &getRegistryData() const { return matcherRegistry; }
+  // Values bound to names during the session.
+  llvm::StringMap<matcher::VariantValue> namedValues;
+  // Session termination flag; starts out false.
+  bool terminate = false;
+  Operation *rootOp;
+  llvm::SourceMgr &sourceMgr;
+  unsigned bufferId;
+  const matcher::Registry &matcherRegistry;
+} // namespace mlir::query

diff  --git a/mlir/include/mlir/Tools/mlir-query/MlirQueryMain.h b/mlir/include/mlir/Tools/mlir-query/MlirQueryMain.h
new file mode 100644
index 000000000000000..fa1cd5d8176ee12
--- /dev/null
+++ b/mlir/include/mlir/Tools/mlir-query/MlirQueryMain.h
@@ -0,0 +1,30 @@
+//===- MlirQueryMain.h - MLIR Query main ----------------------------------===//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Main entry function for mlir-query for when built as standalone
+// binary.
+#include "mlir/Query/Matcher/Registry.h"
+#include "mlir/Support/LogicalResult.h"
+namespace mlir {
+class MLIRContext;
+mlirQueryMain(int argc, char **argv, MLIRContext &context,
+              const mlir::query::matcher::Registry &matcherRegistry);
+} // namespace mlir

diff  --git a/mlir/lib/CMakeLists.txt b/mlir/lib/CMakeLists.txt
index c71664a3f00631b..d25c84a3975db4c 100644
--- a/mlir/lib/CMakeLists.txt
+++ b/mlir/lib/CMakeLists.txt
@@ -11,6 +11,7 @@ add_subdirectory(IR)

diff  --git a/mlir/lib/Query/CMakeLists.txt b/mlir/lib/Query/CMakeLists.txt
new file mode 100644
index 000000000000000..817583e94c52229
--- /dev/null
+++ b/mlir/lib/Query/CMakeLists.txt
@@ -0,0 +1,12 @@
+  Query.cpp
+  QueryParser.cpp
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Query
+  MLIRQueryMatcher
+  )

diff  --git a/mlir/lib/Query/Matcher/CMakeLists.txt b/mlir/lib/Query/Matcher/CMakeLists.txt
new file mode 100644
index 000000000000000..6afd24722bb70ce
--- /dev/null
+++ b/mlir/lib/Query/Matcher/CMakeLists.txt
@@ -0,0 +1,10 @@
+  Parser.cpp
+  RegistryManager.cpp
+  VariantValue.cpp
+  Diagnostics.cpp
+  ErrorBuilder.cpp
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Query/Matcher
+  )

diff  --git a/mlir/lib/Query/Matcher/Diagnostics.cpp b/mlir/lib/Query/Matcher/Diagnostics.cpp
new file mode 100644
index 000000000000000..10468dbcc530676
--- /dev/null
+++ b/mlir/lib/Query/Matcher/Diagnostics.cpp
@@ -0,0 +1,128 @@
+//===- Diagnostics.cpp ----------------------------------------------------===//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#include "Diagnostics.h"
+#include "mlir/Query/Matcher/ErrorBuilder.h"
+namespace mlir::query::matcher::internal {
+// Renders `arg` to a string and appends it to the argument list of the
+// message being built. Returns the stream so that calls can be chained.
+Diagnostics::ArgStream &
+Diagnostics::ArgStream::operator<<(const llvm::Twine &arg) {
+  out->emplace_back(arg.str());
+  return *this;
+// Starts a new error: records the current context stack plus one message of
+// kind `error` at `range`, and returns a stream that appends to that
+// message's argument list.
+Diagnostics::ArgStream Diagnostics::addError(SourceRange range,
+                                             ErrorType error) {
+  ErrorContent &content = errorValues.emplace_back();
+  content.contextStack = contextStack;
+  ErrorContent::Message &message = content.messages.emplace_back();
+  message.range = range;
+  message.type = error;
+  return ArgStream(&message.args);
+// Returns the message template for `type`. "$0".."$9" in the template are
+// replaced by the error's arguments when the message is printed.
+static llvm::StringRef errorTypeToFormatString(ErrorType type) {
+  switch (type) {
+  case ErrorType::RegistryMatcherNotFound:
+    return "Matcher not found: $0";
+  case ErrorType::RegistryWrongArgCount:
+    return "Incorrect argument count. (Expected = $0) != (Actual = $1)";
+  case ErrorType::RegistryWrongArgType:
+    return "Incorrect type for arg $0. (Expected = $1) != (Actual = $2)";
+  case ErrorType::RegistryValueNotFound:
+    return "Value not found: $0";
+  case ErrorType::ParserStringError:
+    return "Error parsing string token: <$0>";
+  case ErrorType::ParserNoOpenParen:
+    return "Error parsing matcher. Found token <$0> while looking for '('.";
+  case ErrorType::ParserNoCloseParen:
+    return "Error parsing matcher. Found end-of-code while looking for ')'.";
+  case ErrorType::ParserNoComma:
+    return "Error parsing matcher. Found token <$0> while looking for ','.";
+  case ErrorType::ParserNoCode:
+    return "End of code found while looking for token.";
+  case ErrorType::ParserNotAMatcher:
+    return "Input value is not a matcher expression.";
+  case ErrorType::ParserInvalidToken:
+    return "Invalid token <$0> found when looking for a value.";
+  case ErrorType::ParserTrailingCode:
+    // Reported when input remains after a complete expression, i.e. the end
+    // of the code was expected but more text was found.
+    return "Expected end of code.";
+  case ErrorType::ParserOverloadedType:
+    return "Input value has unresolved overloaded type: $0";
+  case ErrorType::ParserFailedToBuildMatcher:
+    return "Failed to build matcher: $0.";
+  case ErrorType::None:
+    return "<N/A>";
+  }
+  llvm_unreachable("Unknown ErrorType value.");
+// Writes `formatString` to `os`, replacing each "$<digit>" with the argument
+// at that index, or with "<Argument_Not_Provided>" when `args` has no such
+// entry. A '$' followed by a non-digit is dropped together with that
+// character; a trailing '$' is dropped.
+static void formatErrorString(llvm::StringRef formatString,
+                              llvm::ArrayRef<std::string> args,
+                              llvm::raw_ostream &os) {
+  while (!formatString.empty()) {
+    const size_t dollarPos = formatString.find('$');
+    // substr() clamps its length, so npos emits the whole remaining text.
+    os << formatString.substr(0, dollarPos);
+    if (dollarPos == llvm::StringRef::npos)
+      break;
+    const llvm::StringRef afterDollar = formatString.drop_front(dollarPos + 1);
+    if (afterDollar.empty())
+      break;
+    const char placeholder = afterDollar.front();
+    formatString = afterDollar.drop_front();
+    if (placeholder < '0' || placeholder > '9')
+      continue;
+    const unsigned index = placeholder - '0';
+    if (index < args.size())
+      os << args[index];
+    else
+      os << "<Argument_Not_Provided>";
+  }
+// Prints "<line>:<column>: " for the start of `range`, but only when both
+// the line and the column are non-zero.
+static void maybeAddLineAndColumn(SourceRange range, llvm::raw_ostream &os) {
+  const SourceLocation &begin = range.start;
+  if (begin.line == 0 || begin.column == 0)
+    return;
+  os << begin.line << ":" << begin.column << ": ";
+// Prints one message: optional location, then `prefix`, then the message
+// template for its type with the message's arguments substituted.
+void Diagnostics::printMessage(
+    const Diagnostics::ErrorContent::Message &message, const llvm::Twine prefix,
+    llvm::raw_ostream &os) const {
+  maybeAddLineAndColumn(message.range, os);
+  os << prefix;
+  const llvm::StringRef messageTemplate = errorTypeToFormatString(message.type);
+  formatErrorString(messageTemplate, message.args, os);
+// Prints every message of one error. A single message is printed as is;
+// several messages are printed one per line, each prefixed with
+// "Candidate <n>: " where n starts at 1.
+void Diagnostics::printErrorContent(const Diagnostics::ErrorContent &content,
+                                    llvm::raw_ostream &os) const {
+  const std::vector<ErrorContent::Message> &messages = content.messages;
+  if (messages.size() == 1) {
+    printMessage(messages.front(), "", os);
+    return;
+  }
+  size_t candidateNo = 0;
+  for (const ErrorContent::Message &message : messages) {
+    // Separate consecutive candidates with a newline.
+    if (candidateNo != 0)
+      os << "\n";
+    ++candidateNo;
+    printMessage(message, "Candidate " + llvm::Twine(candidateNo) + ": ", os);
+  }
+// Prints all recorded errors to `os`, separated by newlines.
+void Diagnostics::print(llvm::raw_ostream &os) const {
+  bool isFirst = true;
+  for (const ErrorContent &content : errorValues) {
+    if (!isFirst)
+      os << "\n";
+    isFirst = false;
+    printErrorContent(content, os);
+  }
+} // namespace mlir::query::matcher::internal

diff  --git a/mlir/lib/Query/Matcher/Diagnostics.h b/mlir/lib/Query/Matcher/Diagnostics.h
new file mode 100644
index 000000000000000..a58a435b16a9030
--- /dev/null
+++ b/mlir/lib/Query/Matcher/Diagnostics.h
@@ -0,0 +1,82 @@
+//===--- Diagnostics.h - Helper class for error diagnostics -----*- C++ -*-===//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Diagnostics class to manage error messages. Implementation shares similarity
+// to clang-query Diagnostics.
+#include "mlir/Query/Matcher/ErrorBuilder.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/raw_ostream.h"
+#include <string>
+#include <vector>
+namespace mlir::query::matcher::internal {
+// Diagnostics class to manage error messages.
+// Errors are accumulated through addError() and rendered with print().
+class Diagnostics {
+  // Helper stream class for constructing error messages.
+  // Every streamed value is converted to a string and appended to the
+  // argument list of the message being built.
+  class ArgStream {
+  public:
+    // `out` is the argument list to append to; it is not owned.
+    ArgStream(std::vector<std::string> *out) : out(out) {}
+    // Accepts anything an llvm::Twine can be constructed from.
+    template <class T>
+    ArgStream &operator<<(const T &arg) {
+      return operator<<(llvm::Twine(arg));
+    }
+    ArgStream &operator<<(const llvm::Twine &arg);
+  private:
+    std::vector<std::string> *out;
+  };
+  // Add an error message with the specified range and error type.
+  // Returns an ArgStream object to allow constructing the error message using
+  // the << operator.
+  ArgStream addError(SourceRange range, ErrorType error);
+  // Print all error messages to the specified output stream.
+  void print(llvm::raw_ostream &os) const;
+  // Information stored for one frame of the context.
+  struct ContextFrame {
+    SourceRange range;
+    std::vector<std::string> args;
+  };
+  // Information stored for each error found.
+  struct ErrorContent {
+    // Copy of the context stack taken when the error was added.
+    std::vector<ContextFrame> contextStack;
+    // One reported message: where it occurred, its kind and the positional
+    // arguments for its format string.
+    struct Message {
+      SourceRange range;
+      ErrorType type;
+      std::vector<std::string> args;
+    };
+    std::vector<Message> messages;
+  };
+  // Prints a single message preceded by `Prefix`.
+  void printMessage(const ErrorContent::Message &message,
+                    const llvm::Twine Prefix, llvm::raw_ostream &os) const;
+  // Prints all messages belonging to one error.
+  void printErrorContent(const ErrorContent &content,
+                         llvm::raw_ostream &os) const;
+  // Context frames copied into each new error.
+  std::vector<ContextFrame> contextStack;
+  // All errors recorded so far, in the order they were added.
+  std::vector<ErrorContent> errorValues;
+} // namespace mlir::query::matcher::internal

diff  --git a/mlir/lib/Query/Matcher/ErrorBuilder.cpp b/mlir/lib/Query/Matcher/ErrorBuilder.cpp
new file mode 100644
index 000000000000000..de6447dac490ac4
--- /dev/null
+++ b/mlir/lib/Query/Matcher/ErrorBuilder.cpp
@@ -0,0 +1,25 @@
+//===--- ErrorBuilder.cpp - Helper for building error messages ------------===//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#include "mlir/Query/Matcher/ErrorBuilder.h"
+#include "Diagnostics.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/Twine.h"
+#include <initializer_list>
+namespace mlir::query::matcher::internal {
+// Records an error of kind `errorType` at `range` in `error`, then streams
+// every entry of `errorTexts`, in order, into the ArgStream returned by
+// Diagnostics::addError. `error` is dereferenced unconditionally and must not
+// be null.
+void addError(Diagnostics *error, SourceRange range, ErrorType errorType,
+              std::initializer_list<llvm::Twine> errorTexts) {
+  Diagnostics::ArgStream argStream = error->addError(range, errorType);
+  for (const llvm::Twine &errorText : errorTexts) {
+    argStream << errorText;
+  }
+} // namespace mlir::query::matcher::internal

diff  --git a/mlir/lib/Query/Matcher/Parser.cpp b/mlir/lib/Query/Matcher/Parser.cpp
new file mode 100644
index 000000000000000..be9e60de221db19
--- /dev/null
+++ b/mlir/lib/Query/Matcher/Parser.cpp
@@ -0,0 +1,540 @@
+//===- Parser.cpp - Matcher expression parser -----------------------------===//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Recursive parser implementation for the matcher expression grammar.
+#include "Parser.h"
+#include <vector>
+namespace mlir::query::matcher::internal {
+// Simple structure to hold information for one token from the parser.
+struct Parser::TokenInfo {
+  TokenInfo() = default;
+  // Method to set the kind and text of the token
+  // (range and value are left untouched).
+  void set(TokenKind newKind, llvm::StringRef newText) {
+    kind = newKind;
+    text = newText;
+  }
+  // Spelling of the token; a StringRef view, so it does not own its
+  // characters.
+  llvm::StringRef text;
+  // Token category; a default-constructed token is Eof.
+  TokenKind kind = TokenKind::Eof;
+  // Start/end location of the token, filled in by the tokenizer.
+  SourceRange range;
+  // Parsed value; the tokenizer sets it for string literals.
+  VariantValue value;
+class Parser::CodeTokenizer {
+  // Constructor with matcherCode and error
+  // Tokenizes `matcherCode` without a code-completion position and reports
+  // tokenizer errors to `error`. The first token is lexed eagerly so that
+  // peekNextToken() is valid right after construction.
+  explicit CodeTokenizer(llvm::StringRef matcherCode, Diagnostics *error)
+      : code(matcherCode), startOfLine(matcherCode), error(error) {
+    nextToken = getNextToken();
+  }
+  // Constructor with matcherCode, error, and codeCompletionOffset
+  // `codeCompletionOffset` is a character offset into `matcherCode`; it is
+  // converted to a pointer without any bounds check here, so the caller is
+  // expected to pass an offset that lies within the code.
+  CodeTokenizer(llvm::StringRef matcherCode, Diagnostics *error,
+                unsigned codeCompletionOffset)
+      : code(matcherCode), startOfLine(matcherCode), error(error),
+        codeCompletionLocation(matcherCode.data() + codeCompletionOffset) {
+    nextToken = getNextToken();
+  }
+  // Peek at next token without consuming it
+  // Returns a reference to the buffered lookahead token; the reference
+  // observes later updates made by the consume/skip methods.
+  const TokenInfo &peekNextToken() const { return nextToken; }
+  // Hand out the buffered lookahead token and refill the buffer with the
+  // token that follows it in the input.
+  TokenInfo consumeNextToken() {
+    TokenInfo consumed = nextToken;
+    nextToken = getNextToken();
+    return consumed;
+  }
+  // Discard buffered NewLine tokens until the lookahead is some other kind of
+  // token, then return a copy of that lookahead (it is not consumed).
+  TokenInfo skipNewlines() {
+    while (true) {
+      if (nextToken.kind != TokenKind::NewLine)
+        return nextToken;
+      nextToken = getNextToken();
+    }
+  }
+  // Skip newline tokens, then consume and return the token that follows. An
+  // Eof token is returned without being consumed.
+  TokenInfo consumeNextTokenIgnoreNewlines() {
+    skipNewlines();
+    if (nextToken.kind == TokenKind::Eof)
+      return nextToken;
+    return consumeNextToken();
+  }
+  // Return kind of next token
+  // (the buffered lookahead; nothing is consumed).
+  TokenKind nextTokenKind() const { return nextToken.kind; }
+  // Split the leading character off `str`: returns a one-character StringRef
+  // for it and advances `str` past it. `str` must not be empty.
+  llvm::StringRef firstCharacterAndDrop(llvm::StringRef &str) {
+    assert(!str.empty());
+    const llvm::StringRef head = str.substr(0, 1);
+    str = str.drop_front(1);
+    return head;
+  }
+  // Get next token, consuming whitespaces and handling different token types
+  // Lexes one token from the front of `code` and returns it. NewLine is
+  // returned as a token of its own; other blanks are skipped.
+  TokenInfo getNextToken() {
+    // Leading blanks never form a token, so drop them before recording where
+    // the token starts.
+    consumeWhitespace();
+    TokenInfo result;
+    result.range.start = currentLocation();
+    // Code completion case
+    // Taken once the read position has reached or passed the completion
+    // position. The position is cleared so that only one completion token is
+    // produced. range.end keeps its default value on this path.
+    if (codeCompletionLocation && codeCompletionLocation <= code.data()) {
+      result.set(TokenKind::CodeCompletion,
+                 llvm::StringRef(codeCompletionLocation, 0));
+      codeCompletionLocation = nullptr;
+      return result;
+    }
+    // End of file case
+    // (range.end keeps its default value on this path as well).
+    if (code.empty()) {
+      result.set(TokenKind::Eof, "");
+      return result;
+    }
+    // Switch to handle specific characters
+    switch (code[0]) {
+    case '#':
+      // Comment: discard everything up to, but not including, the next '\n'
+      // and lex again from there.
+      code = code.drop_until([](char c) { return c == '\n'; });
+      return getNextToken();
+    case ',':
+      result.set(TokenKind::Comma, firstCharacterAndDrop(code));
+      break;
+    case '.':
+      result.set(TokenKind::Period, firstCharacterAndDrop(code));
+      break;
+    case '\n':
+      // Columns restart after the newline: remember the text that follows it
+      // as the new line start before the newline itself is dropped.
+      ++line;
+      startOfLine = code.drop_front();
+      result.set(TokenKind::NewLine, firstCharacterAndDrop(code));
+      break;
+    case '(':
+      result.set(TokenKind::OpenParen, firstCharacterAndDrop(code));
+      break;
+    case ')':
+      result.set(TokenKind::CloseParen, firstCharacterAndDrop(code));
+      break;
+    case '"':
+    case '\'':
+      // Either quote character opens a string literal.
+      consumeStringLiteral(&result);
+      break;
+    default:
+      parseIdentifierOrInvalid(&result);
+      break;
+    }
+    // The token has been removed from `code`, so the current location is the
+    // position just past it.
+    result.range.end = currentLocation();
+    return result;
+  }
+  // Lex the quoted string that starts at the front of `code`. The opening
+  // quote character decides which character closes the literal, and a
+  // backslash protects the character that follows it. On success the token
+  // becomes a Literal whose text includes the quotes and whose value is the
+  // text between them. If no closing quote exists, the rest of the input is
+  // consumed, a ParserStringError is reported and the token becomes Error.
+  void consumeStringLiteral(TokenInfo *result) {
+    const char quote = code[0];
+    bool escaped = false;
+    for (size_t pos = 1; pos < code.size(); ++pos) {
+      const char current = code[pos];
+      if (escaped) {
+        // The character after a backslash never terminates the literal.
+        escaped = false;
+      } else if (current == '\\') {
+        escaped = true;
+      } else if (current == quote) {
+        const size_t tokenLength = pos + 1;
+        result->kind = TokenKind::Literal;
+        result->text = code.substr(0, tokenLength);
+        result->value = code.substr(1, pos - 1);
+        code = code.drop_front(tokenLength);
+        return;
+      }
+    }
+    // Unterminated literal: swallow the remaining input and report it.
+    const llvm::StringRef unterminated = code;
+    code = code.drop_front(code.size());
+    SourceRange errorRange;
+    errorRange.start = result->range.start;
+    errorRange.end = currentLocation();
+    error->addError(errorRange, ErrorType::ParserStringError) << unterminated;
+    result->kind = TokenKind::Error;
+  }
+  // Lexes the token at the front of `code`, which must not be empty: either
+  // an identifier (a maximal run of alphanumeric characters) or a single
+  // invalid character. The outcome is stored in `result` and the lexed text
+  // is removed from `code`.
+  void parseIdentifierOrInvalid(TokenInfo *result) {
+    // The <cctype> classifiers have undefined behavior for negative arguments
+    // other than EOF, so a plain (possibly signed) char has to be converted to
+    // unsigned char before it is classified.
+    const auto isIdentChar = [](char c) {
+      return isalnum(static_cast<unsigned char>(c)) != 0;
+    };
+    if (isIdentChar(code[0])) {
+      // Parse an identifier
+      size_t tokenLength = 1;
+      while (true) {
+        // A code completion location in/immediately after an identifier will
+        // cause the portion of the identifier before the code completion
+        // location to become a code completion token.
+        if (codeCompletionLocation == code.data() + tokenLength) {
+          codeCompletionLocation = nullptr;
+          result->kind = TokenKind::CodeCompletion;
+          result->text = code.substr(0, tokenLength);
+          code = code.drop_front(tokenLength);
+          return;
+        }
+        if (tokenLength == code.size() || !isIdentChar(code[tokenLength]))
+          break;
+        ++tokenLength;
+      }
+      result->kind = TokenKind::Ident;
+      result->text = code.substr(0, tokenLength);
+      code = code.drop_front(tokenLength);
+    } else {
+      // Anything else is consumed one character at a time so that lexing
+      // always makes progress.
+      result->kind = TokenKind::InvalidChar;
+      result->text = code.substr(0, 1);
+      code = code.drop_front(1);
+    }
+  }
+  // Strip leading blanks (space, tab, vertical tab, form feed, carriage
+  // return) from `code`. '\n' is kept because it is a token of its own.
+  void consumeWhitespace() {
+    const llvm::StringRef blanks(" \t\v\f\r");
+    code = code.drop_while([blanks](char c) { return blanks.contains(c); });
+  }
+  // Report the position of the front of `code`: the current line number and
+  // the 1-based column measured from the start of that line.
+  SourceLocation currentLocation() {
+    SourceLocation here;
+    here.column = code.data() - startOfLine.data() + 1;
+    here.line = line;
+    return here;
+  }
+  // Input that has not been tokenized yet; it shrinks from the front as
+  // tokens are produced.
+  llvm::StringRef code;
+  // Text starting at the first character of the current line; used to
+  // compute column numbers.
+  llvm::StringRef startOfLine;
+  // Current line number, starting at 1.
+  unsigned line = 1;
+  // Receiver for tokenizer errors.
+  Diagnostics *error;
+  // One-token lookahead buffer.
+  TokenInfo nextToken;
+  // Position at which a CodeCompletion token should be produced, or nullptr
+  // when completion was not requested or has already been produced.
+  const char *codeCompletionLocation = nullptr;
+// Out-of-line defaulted destructor for the Sema interface.
+Parser::Sema::~Sema() = default;
+// Default implementation: no completion types are accepted.
+std::vector<ArgKind> Parser::Sema::getAcceptedCompletionTypes(
+    llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> context) {
+  return {};
+// Default implementation: no matcher completions are offered.
+Parser::Sema::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes) {
+  return {};
+// Entry for the scope of a parser
+// Pushes a (matcher constructor, argument index) pair onto the parser's
+// context stack on construction and pops it again on destruction. The
+// argument index starts at 0 and is advanced with nextArg().
+struct Parser::ScopedContextEntry {
+  Parser *parser;
+  ScopedContextEntry(Parser *parser, MatcherCtor c) : parser(parser) {
+    parser->contextStack.emplace_back(c, 0u);
+  }
+  ~ScopedContextEntry() { parser->contextStack.pop_back(); }
+  // Advance the argument index of the innermost context entry.
+  void nextArg() { ++parser->contextStack.back().second; }
+// Parse and validate expressions starting with an identifier.
+// This function can parse named values and matchers. In case of failure, it
+// will try to determine the user's intent to give an appropriate error message.
+// Returns true on success; on failure an error has been added to `error` and
+// false is returned.
+bool Parser::parseIdentifierPrefixImpl(VariantValue *value) {
+  const TokenInfo nameToken = tokenizer->consumeNextToken();
+  if (tokenizer->nextTokenKind() != TokenKind::OpenParen) {
+    // Parse as a named value.
+    // With no named-value map, a default-constructed VariantValue is used.
+    auto namedValue =
+        namedValues ? namedValues->lookup(nameToken.text) : VariantValue();
+    // NOTE(review): any identifier that does not name a matcher value fails
+    // here, and the value that was looked up is never stored into `value` on
+    // this path -- confirm that this is intended.
+    if (!namedValue.isMatcher()) {
+      error->addError(tokenizer->peekNextToken().range,
+                      ErrorType::ParserNotAMatcher);
+      return false;
+    }
+    if (tokenizer->nextTokenKind() == TokenKind::NewLine) {
+      error->addError(tokenizer->peekNextToken().range,
+                      ErrorType::ParserNoOpenParen)
+          << "NewLine";
+      return false;
+    }
+    // If the syntax is correct and the name is not a matcher either, report
+    // an unknown named value.
+    // NOTE(review): the NewLine alternative below cannot be true here because
+    // the NewLine case has already returned above.
+    if ((tokenizer->nextTokenKind() == TokenKind::Comma ||
+         tokenizer->nextTokenKind() == TokenKind::CloseParen ||
+         tokenizer->nextTokenKind() == TokenKind::NewLine ||
+         tokenizer->nextTokenKind() == TokenKind::Eof) &&
+        !sema->lookupMatcherCtor(nameToken.text)) {
+      error->addError(nameToken.range, ErrorType::RegistryValueNotFound)
+          << nameToken.text;
+      return false;
+    }
+    // Otherwise, fallback to the matcher parser.
+  }
+  tokenizer->skipNewlines();
+  assert(nameToken.kind == TokenKind::Ident);
+  // The matcher name must be followed by an opening parenthesis.
+  TokenInfo openToken = tokenizer->consumeNextToken();
+  if (openToken.kind != TokenKind::OpenParen) {
+    error->addError(openToken.range, ErrorType::ParserNoOpenParen)
+        << openToken.text;
+    return false;
+  }
+  // An empty optional is passed on deliberately; the callee reports the
+  // unknown matcher.
+  std::optional<MatcherCtor> ctor = sema->lookupMatcherCtor(nameToken.text);
+  // Parse as a matcher expression.
+  return parseMatcherExpressionImpl(nameToken, openToken, ctor, value);
+// Parse the arguments of a matcher
+// Parsed arguments are appended to `args`. When the closing parenthesis is
+// found it is consumed and stored in `endToken`; if the input ends first,
+// `endToken` is left untouched and the function still returns true, so the
+// caller has to check `endToken` itself. Returns false when a comma is
+// missing or an argument fails to parse. `nameToken` is not used in the body.
+bool Parser::parseMatcherArgs(std::vector<ParserValue> &args, MatcherCtor ctor,
+                              const TokenInfo &nameToken, TokenInfo &endToken) {
+  // Keeps (ctor, argument index) on the context stack while the arguments
+  // are being parsed.
+  ScopedContextEntry sce(this, ctor);
+  while (tokenizer->nextTokenKind() != TokenKind::Eof) {
+    if (tokenizer->nextTokenKind() == TokenKind::CloseParen) {
+      // end of args.
+      endToken = tokenizer->consumeNextToken();
+      break;
+    }
+    if (!args.empty()) {
+      // We must find a , token to continue.
+      TokenInfo commaToken = tokenizer->consumeNextToken();
+      if (commaToken.kind != TokenKind::Comma) {
+        error->addError(commaToken.range, ErrorType::ParserNoComma)
+            << commaToken.text;
+        return false;
+      }
+    }
+    ParserValue argValue;
+    tokenizer->skipNewlines();
+    // Text and range are taken from the first token of the argument only.
+    argValue.text = tokenizer->peekNextToken().text;
+    argValue.range = tokenizer->peekNextToken().range;
+    if (!parseExpressionImpl(&argValue.value)) {
+      return false;
+    }
+    tokenizer->skipNewlines();
+    args.push_back(argValue);
+    sce.nextArg();
+  }
+  return true;
+// Parse and validate a matcher expression.
+// `nameToken` and `openToken` are the already consumed matcher name and '('.
+// `ctor` is the constructor looked up for the name, or empty when the name is
+// unknown. On success the constructed matcher is stored in `*value` and true
+// is returned; otherwise an error has been reported and false is returned.
+bool Parser::parseMatcherExpressionImpl(const TokenInfo &nameToken,
+                                        const TokenInfo &openToken,
+                                        std::optional<MatcherCtor> ctor,
+                                        VariantValue *value) {
+  if (!ctor) {
+    error->addError(nameToken.range, ErrorType::RegistryMatcherNotFound)
+        << nameToken.text;
+    // Do not return here. We need to continue to give completion suggestions.
+  }
+  std::vector<ParserValue> args;
+  // Stays a default (Eof) token unless the closing parenthesis is found.
+  TokenInfo endToken;
+  tokenizer->skipNewlines();
+  // An unknown matcher is represented by a null constructor while its
+  // arguments are parsed.
+  if (!parseMatcherArgs(args, ctor.value_or(nullptr), nameToken, endToken)) {
+    return false;
+  }
+  // Check for the missing closing parenthesis
+  if (endToken.kind != TokenKind::CloseParen) {
+    error->addError(openToken.range, ErrorType::ParserNoCloseParen)
+        << nameToken.text;
+    return false;
+  }
+  // The unknown-matcher error was already reported above.
+  if (!ctor)
+    return false;
+  // Merge the start and end infos.
+  SourceRange matcherRange = nameToken.range;
+  matcherRange.end = endToken.range.end;
+  VariantMatcher result =
+      sema->actOnMatcherExpression(*ctor, matcherRange, args, error);
+  if (result.isNull())
+    return false;
+  *value = result;
+  return true;
+// If the prefix of this completion matches the completion token, add it to
+// completions minus the prefix.
+// That is: when `completion.typedText` starts with the text of `compToken`,
+// the remainder of typedText is recorded together with the unchanged
+// matcherDecl. Completions that do not match are ignored.
+void Parser::addCompletion(const TokenInfo &compToken,
+                           const MatcherCompletion &completion) {
+  if (llvm::StringRef(completion.typedText).startswith(compToken.text)) {
+    completions.emplace_back(completion.typedText.substr(compToken.text.size()),
+                             completion.matcherDecl);
+  }
+// Builds one completion per entry of the named-value map: the typed text is
+// the entry's name and the declaration is "<type> <name>". Returns an empty
+// list when no named-value map was supplied. `acceptedTypes` is not consulted
+// in the body.
+Parser::getNamedValueCompletions(llvm::ArrayRef<ArgKind> acceptedTypes) {
+  if (!namedValues)
+    return {};
+  std::vector<MatcherCompletion> result;
+  for (const auto &entry : *namedValues) {
+    std::string decl =
+        (entry.getValue().getTypeAsString() + " " + entry.getKey()).str();
+    result.emplace_back(entry.getKey(), decl);
+  }
+  return result;
+// Consumes the pending CodeCompletion token and records every matcher and
+// named-value completion whose typed text starts with the token's text.
+void Parser::addExpressionCompletions() {
+  const TokenInfo compToken = tokenizer->consumeNextTokenIgnoreNewlines();
+  assert(compToken.kind == TokenKind::CodeCompletion);
+  // We cannot complete code if there is an invalid element on the context
+  // stack.
+  // (A null constructor marks a matcher name that could not be resolved.)
+  for (const auto &entry : contextStack) {
+    if (!entry.first)
+      return;
+  }
+  auto acceptedTypes = sema->getAcceptedCompletionTypes(contextStack);
+  for (const auto &completion : sema->getMatcherCompletions(acceptedTypes)) {
+    addCompletion(compToken, completion);
+  }
+  for (const auto &completion : getNamedValueCompletions(acceptedTypes)) {
+    addCompletion(compToken, completion);
+  }
+// Parse an <Expression>
+// Dispatches on the kind of the next token. Returns true and stores the
+// result in `*value` on success. Returns false otherwise; an error is added
+// in every failing case except CodeCompletion (completions are collected
+// instead) and Error (already reported by the tokenizer).
+bool Parser::parseExpressionImpl(VariantValue *value) {
+  switch (tokenizer->nextTokenKind()) {
+  case TokenKind::Literal:
+    // A string literal evaluates to the value prepared by the tokenizer.
+    *value = tokenizer->consumeNextToken().value;
+    return true;
+  case TokenKind::Ident:
+    return parseIdentifierPrefixImpl(value);
+  case TokenKind::CodeCompletion:
+    // Completion requests never produce a value.
+    addExpressionCompletions();
+    return false;
+  case TokenKind::Eof:
+    error->addError(tokenizer->consumeNextToken().range,
+                    ErrorType::ParserNoCode);
+    return false;
+  case TokenKind::Error:
+    // This error was already reported by the tokenizer.
+    return false;
+  case TokenKind::NewLine:
+  case TokenKind::OpenParen:
+  case TokenKind::CloseParen:
+  case TokenKind::Comma:
+  case TokenKind::Period:
+  case TokenKind::InvalidChar:
+    // None of these tokens can start an expression. A newline is spelled out
+    // as "NewLine" in the message instead of its raw text.
+    const TokenInfo token = tokenizer->consumeNextToken();
+    error->addError(token.range, ErrorType::ParserInvalidToken)
+        << (token.kind == TokenKind::NewLine ? "NewLine" : token.text);
+    return false;
+  }
+  llvm_unreachable("Unknown token kind.");
+// Builds a parser that reads tokens from `tokenizer`, resolves matchers
+// through a RegistrySema created over `matcherRegistry`, looks up named
+// values in `namedValues` (may be null) and reports problems to `error`.
+// The tokenizer, map and diagnostics are stored as raw pointers.
+Parser::Parser(CodeTokenizer *tokenizer, const Registry &matcherRegistry,
+               const NamedValueMap *namedValues, Diagnostics *error)
+    : tokenizer(tokenizer),
+      sema(std::make_unique<RegistrySema>(matcherRegistry)),
+      namedValues(namedValues), error(error) {}
+// RegistrySema forwards every Sema operation to the corresponding static
+// RegistryManager function, passing the stored registry where one is needed.
+Parser::RegistrySema::~RegistrySema() = default;
+// Forwards to RegistryManager::lookupMatcherCtor.
+Parser::RegistrySema::lookupMatcherCtor(llvm::StringRef matcherName) {
+  return RegistryManager::lookupMatcherCtor(matcherName, matcherRegistry);
+// Forwards to RegistryManager::constructMatcher.
+VariantMatcher Parser::RegistrySema::actOnMatcherExpression(
+    MatcherCtor ctor, SourceRange nameRange, llvm::ArrayRef<ParserValue> args,
+    Diagnostics *error) {
+  return RegistryManager::constructMatcher(ctor, nameRange, args, error);
+// Forwards to RegistryManager::getAcceptedCompletionTypes.
+std::vector<ArgKind> Parser::RegistrySema::getAcceptedCompletionTypes(
+    llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> context) {
+  return RegistryManager::getAcceptedCompletionTypes(context);
+// Forwards to RegistryManager::getMatcherCompletions.
+std::vector<MatcherCompletion> Parser::RegistrySema::getMatcherCompletions(
+    llvm::ArrayRef<ArgKind> acceptedTypes) {
+  return RegistryManager::getMatcherCompletions(acceptedTypes, matcherRegistry);
+// Parses `code` as a single expression and stores the result in `*value`.
+// After the expression only end of input or a newline may follow; anything
+// else is reported as ParserTrailingCode. Returns true on success and false
+// after an error has been reported.
+bool Parser::parseExpression(llvm::StringRef &code,
+                             const Registry &matcherRegistry,
+                             const NamedValueMap *namedValues,
+                             VariantValue *value, Diagnostics *error) {
+  CodeTokenizer tokenizer(code, error);
+  Parser parser(&tokenizer, matcherRegistry, namedValues, error);
+  if (!parser.parseExpressionImpl(value))
+    return false;
+  // Only the first token after the expression is inspected.
+  auto nextToken = tokenizer.peekNextToken();
+  if (nextToken.kind != TokenKind::Eof &&
+      nextToken.kind != TokenKind::NewLine) {
+    error->addError(tokenizer.peekNextToken().range,
+                    ErrorType::ParserTrailingCode);
+    return false;
+  }
+  return true;
+// Computes the completions available at character offset `completionOffset`
+// of `code`. The code is parsed with a tokenizer that emits a CodeCompletion
+// token at that offset. Parse errors go to a local Diagnostics object and the
+// parse result is discarded; only the collected completions are returned.
+Parser::completeExpression(llvm::StringRef &code, unsigned completionOffset,
+                           const Registry &matcherRegistry,
+                           const NamedValueMap *namedValues) {
+  Diagnostics error;
+  CodeTokenizer tokenizer(code, &error, completionOffset);
+  Parser parser(&tokenizer, matcherRegistry, namedValues, &error);
+  VariantValue dummy;
+  parser.parseExpressionImpl(&dummy);
+  return parser.completions;
+// Parses `code` and returns the matcher it denotes. Returns std::nullopt when
+// parsing fails, when the parsed value is not a matcher (ParserNotAMatcher),
+// or when the matcher cannot be converted to a single DynMatcher
+// (ParserOverloadedType). The last two errors are reported with a
+// default-constructed SourceRange.
+std::optional<DynMatcher> Parser::parseMatcherExpression(
+    llvm::StringRef &code, const Registry &matcherRegistry,
+    const NamedValueMap *namedValues, Diagnostics *error) {
+  VariantValue value;
+  if (!parseExpression(code, matcherRegistry, namedValues, &value, error))
+    return std::nullopt;
+  if (!value.isMatcher()) {
+    error->addError(SourceRange(), ErrorType::ParserNotAMatcher);
+    return std::nullopt;
+  }
+  std::optional<DynMatcher> result = value.getMatcher().getDynMatcher();
+  if (!result) {
+    error->addError(SourceRange(), ErrorType::ParserOverloadedType)
+        << value.getTypeAsString();
+  }
+  return result;
+} // namespace mlir::query::matcher::internal

diff  --git a/mlir/lib/Query/Matcher/Parser.h b/mlir/lib/Query/Matcher/Parser.h
new file mode 100644
index 000000000000000..f049af34e9c907a
--- /dev/null
+++ b/mlir/lib/Query/Matcher/Parser.h
@@ -0,0 +1,188 @@
+//===--- Parser.h - Matcher expression parser -------------------*- C++ -*-===//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Simple matcher expression parser.
+// This file contains the Parser class, which is responsible for parsing
+// expressions in a specific format: matcherName(Arg0, Arg1, ..., ArgN). The
+// parser can also interpret simple types, like strings.
+// The actual processing of the matchers is handled by a Sema object that is
+// provided to the parser.
+// The grammar for the supported expressions is as follows:
+// <Expression>        := <StringLiteral> | <MatcherExpression>
+// <StringLiteral>     := "quoted string"
+// <MatcherExpression> := <MatcherName>(<ArgumentList>)
+// <MatcherName>       := [a-zA-Z]+
+// <ArgumentList>      := <Expression> | <Expression>,<ArgumentList>
+#include "Diagnostics.h"
+#include "RegistryManager.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/StringMap.h"
+#include "llvm/ADT/StringRef.h"
+#include <memory>
+#include <vector>
+namespace mlir::query::matcher::internal {
+// Matcher expression parser.
+class Parser {
+  // Different possible tokens.
+  enum class TokenKind {
+    // End of the input.
+    Eof,
+    // A '\n' character.
+    NewLine,
+    // '('
+    OpenParen,
+    // ')'
+    CloseParen,
+    // ','
+    Comma,
+    // '.'
+    Period,
+    // A quoted string literal.
+    Literal,
+    // A run of alphanumeric characters.
+    Ident,
+    // A single character that cannot start any other token.
+    InvalidChar,
+    // Marks the position at which code completion was requested.
+    CodeCompletion,
+    // A token the tokenizer could not lex; the error is reported by the
+    // tokenizer itself.
+    Error
+  };
+  // Interface to connect the parser with the registry and more. The parser
+  // routes every matcher lookup, matcher construction and completion query
+  // through its Sema instance.
+  class Sema {
+  public:
+    virtual ~Sema();
+    // Process a matcher expression: build the matcher for `ctor` from the
+    // parsed arguments `args`. `nameRange` is the source range of the
+    // expression and `error` receives any problems. The result is returned
+    // by value.
+    virtual VariantMatcher
+    actOnMatcherExpression(MatcherCtor ctor, SourceRange nameRange,
+                           llvm::ArrayRef<ParserValue> args,
+                           Diagnostics *error) = 0;
+    // Look up the matcher constructor for the name found by the parser. An
+    // empty optional means that no matcher with that name is known.
+    virtual std::optional<MatcherCtor>
+    lookupMatcherCtor(llvm::StringRef matcherName) = 0;
+    // Compute the list of completion types for Context.
+    virtual std::vector<ArgKind> getAcceptedCompletionTypes(
+        llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> Context);
+    // Compute the list of completions that match any of acceptedTypes.
+    virtual std::vector<MatcherCompletion>
+    getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes);
+  };
+  // An implementation of the Sema interface that uses the matcher registry to
+  // process tokens.
+  class RegistrySema : public Parser::Sema {
+  public:
+    // Keeps a reference to `matcherRegistry`, so the registry has to outlive
+    // this object.
+    RegistrySema(const Registry &matcherRegistry)
+        : matcherRegistry(matcherRegistry) {}
+    ~RegistrySema() override;
+    std::optional<MatcherCtor>
+    lookupMatcherCtor(llvm::StringRef matcherName) override;
+    VariantMatcher actOnMatcherExpression(MatcherCtor ctor,
+                                          SourceRange nameRange,
+                                          llvm::ArrayRef<ParserValue> args,
+                                          Diagnostics *error) override;
+    std::vector<ArgKind> getAcceptedCompletionTypes(
+        llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> context) override;
+    std::vector<MatcherCompletion>
+    getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes) override;
+  private:
+    // Registry consulted for lookups and completions.
+    const Registry &matcherRegistry;
+  };
+  // Maps the name of a named value to its value.
+  using NamedValueMap = llvm::StringMap<VariantValue>;
+  // Methods to parse a matcher expression and return a DynMatcher object.
+  // std::nullopt is returned when parsing fails or the result is not a
+  // matcher; the problem is then recorded in `error`.
+  static std::optional<DynMatcher>
+  parseMatcherExpression(llvm::StringRef &matcherCode,
+                         const Registry &matcherRegistry,
+                         const NamedValueMap *namedValues, Diagnostics *error);
+  // Convenience overload without named values.
+  static std::optional<DynMatcher>
+  parseMatcherExpression(llvm::StringRef &matcherCode,
+                         const Registry &matcherRegistry, Diagnostics *error) {
+    return parseMatcherExpression(matcherCode, matcherRegistry, nullptr, error);
+  }
+  // Methods to parse any expression supported by this parser.
+  // The result is stored in `*value`; the return value tells whether parsing
+  // succeeded. Errors are recorded in `error`.
+  static bool parseExpression(llvm::StringRef &code,
+                              const Registry &matcherRegistry,
+                              const NamedValueMap *namedValues,
+                              VariantValue *value, Diagnostics *error);
+  // Convenience overload without named values.
+  static bool parseExpression(llvm::StringRef &code,
+                              const Registry &matcherRegistry,
+                              VariantValue *value, Diagnostics *error) {
+    return parseExpression(code, matcherRegistry, nullptr, value, error);
+  }
+  // Methods to complete an expression at a given offset.
+  // `completionOffset` is a character offset into `code`. The returned list
+  // holds the completions collected while parsing up to that offset.
+  static std::vector<MatcherCompletion>
+  completeExpression(llvm::StringRef &code, unsigned completionOffset,
+                     const Registry &matcherRegistry,
+                     const NamedValueMap *namedValues);
+  // Convenience overload without named values.
+  static std::vector<MatcherCompletion>
+  completeExpression(llvm::StringRef &code, unsigned completionOffset,
+                     const Registry &matcherRegistry) {
+    return completeExpression(code, completionOffset, matcherRegistry, nullptr);
+  }
+  // Helper types defined in the implementation file.
+  class CodeTokenizer;
+  struct ScopedContextEntry;
+  struct TokenInfo;
+  // The pointers passed here are stored as given (see the members below).
+  Parser(CodeTokenizer *tokenizer, const Registry &matcherRegistry,
+         const NamedValueMap *namedValues, Diagnostics *error);
+  // Parses one expression starting at the next token.
+  bool parseExpressionImpl(VariantValue *value);
+  // Parses the argument list of a matcher up to the closing parenthesis.
+  bool parseMatcherArgs(std::vector<ParserValue> &args, MatcherCtor ctor,
+                        const TokenInfo &nameToken, TokenInfo &endToken);
+  // Parses the remainder of a matcher expression after its name and '('.
+  bool parseMatcherExpressionImpl(const TokenInfo &nameToken,
+                                  const TokenInfo &openToken,
+                                  std::optional<MatcherCtor> ctor,
+                                  VariantValue *value);
+  // Parses an expression that starts with an identifier.
+  bool parseIdentifierPrefixImpl(VariantValue *value);
+  // Records `completion` if it extends the text of `compToken`.
+  void addCompletion(const TokenInfo &compToken,
+                     const MatcherCompletion &completion);
+  // Collects all completions for the pending CodeCompletion token.
+  void addExpressionCompletions();
+  // Completions derived from the named-value map.
+  std::vector<MatcherCompletion>
+  getNamedValueCompletions(llvm::ArrayRef<ArgKind> acceptedTypes);
+  // Token source.
+  CodeTokenizer *const tokenizer;
+  // Resolves and constructs matchers.
+  std::unique_ptr<RegistrySema> sema;
+  // Optional named values; may be null.
+  const NamedValueMap *const namedValues;
+  // Receiver for parse errors.
+  Diagnostics *const error;
+  // Stack of (matcher constructor, index of the argument being parsed), one
+  // entry per matcher expression currently being parsed.
+  using ContextStackTy = std::vector<std::pair<MatcherCtor, unsigned>>;
+  ContextStackTy contextStack;
+  // Completions collected so far.
+  std::vector<MatcherCompletion> completions;
+} // namespace mlir::query::matcher::internal

diff  --git a/mlir/lib/Query/Matcher/RegistryManager.cpp b/mlir/lib/Query/Matcher/RegistryManager.cpp
new file mode 100644
index 000000000000000..01856aa8ffa67f3
--- /dev/null
+++ b/mlir/lib/Query/Matcher/RegistryManager.cpp
@@ -0,0 +1,139 @@
+//===- RegistryManager.cpp - Matcher registry -----------------------------===//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Registry map populated at static initialization time.
+#include "RegistryManager.h"
+#include "mlir/Query/Matcher/Registry.h"
+#include <set>
+#include <utility>
+namespace mlir::query::matcher {
+namespace {
+// This is needed because these matchers are defined as overloaded functions.
+using IsConstantOp = detail::constant_op_matcher();
+using HasOpAttrName = detail::AttrOpMatcher(llvm::StringRef);
+using HasOpName = detail::NameOpMatcher(llvm::StringRef);
+// Enum to string for autocomplete.
+// Returns the display name of `kind` as used in completion declarations.
+static std::string asArgString(ArgKind kind) {
+  switch (kind) {
+  case ArgKind::Matcher:
+    return "Matcher";
+  case ArgKind::String:
+    return "String";
+  }
+  llvm_unreachable("Unhandled ArgKind");
+} // namespace
+// Registers `callback` as the descriptor for `matcherName` and takes
+// ownership of it. Registering the same name twice is a programming error
+// and is caught by the assertion.
+void Registry::registerMatcherDescriptor(
+    llvm::StringRef matcherName,
+    std::unique_ptr<internal::MatcherDescriptor> callback) {
+  assert(!constructorMap.contains(matcherName));
+  constructorMap[matcherName] = std::move(callback);
+// Looks up `matcherName` in `matcherRegistry`. Returns a pointer to the
+// registered descriptor, or an empty optional when the name is unknown. The
+// descriptor remains owned by the registry.
+RegistryManager::lookupMatcherCtor(llvm::StringRef matcherName,
+                                   const Registry &matcherRegistry) {
+  auto it = matcherRegistry.constructors().find(matcherName);
+  return it == matcherRegistry.constructors().end()
+             ? std::optional<MatcherCtor>()
+             : it->second.get();
+// Computes the argument kinds that are acceptable at the completion point
+// described by `context`, a stack of (matcher constructor, argument index)
+// pairs. The result contains each kind once. Every constructor in `context`
+// is dereferenced, so none of them may be null.
+std::vector<ArgKind> RegistryManager::getAcceptedCompletionTypes(
+    llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> context) {
+  // Starting with ArgKind::Matcher as the seed of acceptable top-level types,
+  // add the kinds accepted for the argument indicated by each context element.
+  std::set<ArgKind> typeSet;
+  typeSet.insert(ArgKind::Matcher);
+  for (const auto &ctxEntry : context) {
+    MatcherCtor ctor = ctxEntry.first;
+    unsigned argNumber = ctxEntry.second;
+    std::vector<ArgKind> nextTypeSet;
+    // An index past the matcher's last argument contributes nothing.
+    if (argNumber < ctor->getNumArgs())
+      ctor->getArgKinds(argNumber, nextTypeSet);
+    typeSet.insert(nextTypeSet.begin(), nextTypeSet.end());
+  }
+  return std::vector<ArgKind>(typeSet.begin(), typeSet.end());
+// Builds one completion for every matcher registered in `matcherRegistry`.
+// The declaration has the form "Matcher: name(kinds, ...)" and the typed text
+// is the name followed by "(", plus ")" for matchers without arguments or an
+// opening double quote when the first argument is a string. Argument kinds
+// are only collected when `acceptedTypes` contains ArgKind::Matcher.
+RegistryManager::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes,
+                                       const Registry &matcherRegistry) {
+  std::vector<MatcherCompletion> completions;
+  // Search the registry for acceptable matchers.
+  for (const auto &m : matcherRegistry.constructors()) {
+    const internal::MatcherDescriptor &matcher = *m.getValue();
+    llvm::StringRef name = m.getKey();
+    unsigned numArgs = matcher.getNumArgs();
+    std::vector<std::vector<ArgKind>> argKinds(numArgs);
+    for (const ArgKind &kind : acceptedTypes) {
+      if (kind != ArgKind::Matcher)
+        continue;
+      for (unsigned arg = 0; arg != numArgs; ++arg)
+        matcher.getArgKinds(arg, argKinds[arg]);
+      // The kinds do not depend on which entry of acceptedTypes matched, so
+      // collecting them once is enough even if Matcher is listed repeatedly.
+      break;
+    }
+    std::string decl;
+    llvm::raw_string_ostream os(decl);
+    std::string typedText = std::string(name);
+    os << "Matcher: " << name << "(";
+    for (const std::vector<ArgKind> &arg : argKinds) {
+      if (&arg != &argKinds[0])
+        os << ", ";
+      bool firstArgKind = true;
+      // Alternatives for one argument are separated by '|'.
+      for (const ArgKind &argKind : arg) {
+        if (!firstArgKind)
+          os << "|";
+        firstArgKind = false;
+        os << asArgString(argKind);
+      }
+    }
+    os << ")";
+    typedText += "(";
+    if (argKinds.empty())
+      typedText += ")";
+    // The kind list of the first argument is empty when no kinds were
+    // collected above, so it must be checked before its first element is read.
+    else if (!argKinds[0].empty() && argKinds[0][0] == ArgKind::String)
+      typedText += "\"";
+    completions.emplace_back(typedText, os.str());
+  }
+  return completions;
+// Builds a matcher by calling `ctor->create` with the given source range,
+// arguments and diagnostics. `ctor` is dereferenced unconditionally and must
+// not be null.
+VariantMatcher RegistryManager::constructMatcher(
+    MatcherCtor ctor, internal::SourceRange nameRange,
+    llvm::ArrayRef<ParserValue> args, internal::Diagnostics *error) {
+  return ctor->create(nameRange, args, error);
+} // namespace mlir::query::matcher

diff  --git a/mlir/lib/Query/Matcher/RegistryManager.h b/mlir/lib/Query/Matcher/RegistryManager.h
new file mode 100644
index 000000000000000..5f2867261225e76
--- /dev/null
+++ b/mlir/lib/Query/Matcher/RegistryManager.h
@@ -0,0 +1,70 @@
+//===--- RegistryManager.h - Matcher registry -------------------*- C++ -*-===//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// RegistryManager to manage registry of all known matchers.
+// The registry provides a generic interface to construct any matcher by name.
+#include "Diagnostics.h"
+#include "mlir/Query/Matcher/Marshallers.h"
+#include "mlir/Query/Matcher/Registry.h"
+#include "mlir/Query/Matcher/VariantValue.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/StringMap.h"
+#include "llvm/ADT/StringRef.h"
+#include <string>
+namespace mlir::query::matcher {
+using MatcherCtor = const internal::MatcherDescriptor *;
+// One code-completion suggestion: the text to insert and a human-readable
+// declaration describing it. Both strings are copied into the object.
+struct MatcherCompletion {
+  MatcherCompletion() = default;
+  MatcherCompletion(llvm::StringRef typedText, llvm::StringRef matcherDecl)
+      : typedText(typedText.str()), matcherDecl(matcherDecl.str()) {}
+  // Two completions are equal when both strings are equal.
+  bool operator==(const MatcherCompletion &other) const {
+    return typedText == other.typedText && matcherDecl == other.matcherDecl;
+  }
+  // The text to type to select this matcher.
+  std::string typedText;
+  // The "declaration" of the matcher, with type information.
+  std::string matcherDecl;
+// Collection of static helpers that query a matcher Registry. The class
+// cannot be instantiated.
+class RegistryManager {
+  RegistryManager() = delete;
+  // Returns the constructor registered under `matcherName`, or an empty
+  // optional when there is none.
+  static std::optional<MatcherCtor>
+  lookupMatcherCtor(llvm::StringRef matcherName,
+                    const Registry &matcherRegistry);
+  // Returns the argument kinds acceptable at the position described by
+  // `context`, a stack of (matcher constructor, argument index) pairs.
+  static std::vector<ArgKind> getAcceptedCompletionTypes(
+      llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> context);
+  // Returns the completions offered for the matchers in `matcherRegistry`.
+  static std::vector<MatcherCompletion>
+  getMatcherCompletions(ArrayRef<ArgKind> acceptedTypes,
+                        const Registry &matcherRegistry);
+  // Constructs a matcher from `ctor` and the parsed arguments `args`.
+  static VariantMatcher constructMatcher(MatcherCtor ctor,
+                                         internal::SourceRange nameRange,
+                                         ArrayRef<ParserValue> args,
+                                         internal::Diagnostics *error);
+} // namespace mlir::query::matcher

diff  --git a/mlir/lib/Query/Matcher/VariantValue.cpp b/mlir/lib/Query/Matcher/VariantValue.cpp
new file mode 100644
index 000000000000000..65bd4bd77bcf8af
--- /dev/null
+++ b/mlir/lib/Query/Matcher/VariantValue.cpp
@@ -0,0 +1,132 @@
+//===--- VariantValue.cpp -------------------------------------------------===//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#include "mlir/Query/Matcher/VariantValue.h"
+namespace mlir::query::matcher {
+VariantMatcher::Payload::~Payload() = default;
+// Payload implementation that holds exactly one DynMatcher.
+class VariantMatcher::SinglePayload : public VariantMatcher::Payload {
+  // Takes the matcher by value and moves it into the member.
+  explicit SinglePayload(DynMatcher matcher) : matcher(std::move(matcher)) {}
+  // Returns a copy of the stored matcher wrapped in an optional.
+  std::optional<DynMatcher> getDynMatcher() const override { return matcher; }
+  // The type name reported for this payload kind.
+  std::string getTypeAsString() const override { return "Matcher"; }
+  // The wrapped matcher.
+  DynMatcher matcher;
+VariantMatcher::VariantMatcher() = default;
+// Creates a VariantMatcher whose payload is a SinglePayload wrapping `matcher`.
+VariantMatcher VariantMatcher::SingleMatcher(DynMatcher matcher) {
+  return VariantMatcher(std::make_shared<SinglePayload>(std::move(matcher)));
+// Returns the payload's matcher, or std::nullopt when no payload is set.
+std::optional<DynMatcher> VariantMatcher::getDynMatcher() const {
+  return value ? value->getDynMatcher() : std::nullopt;
+// Drops the payload so that this object holds nothing.
+void VariantMatcher::reset() { value.reset(); }
+// Returns the type name reported by the held payload, or "<Nothing>" when no
+// payload is set.
+std::string VariantMatcher::getTypeAsString() const {
+  return value ? value->getTypeAsString() : "<Nothing>";
+}
+// Copy constructor: starts out as Nothing and delegates to copy assignment.
+VariantValue::VariantValue(const VariantValue &other)
+    : type(ValueType::Nothing) {
+  *this = other;
+// Constructs a String value holding a heap-allocated copy of the StringRef.
+// NOTE(review): StringRef is a non-owning view, so the character data is not
+// copied here; confirm callers keep the underlying buffer alive.
+VariantValue::VariantValue(const llvm::StringRef string)
+    : type(ValueType::String) {
+  value.String = new llvm::StringRef(string);
+// Constructs a Matcher value holding a heap-allocated copy of `matcher`.
+VariantValue::VariantValue(const VariantMatcher &matcher)
+    : type(ValueType::Matcher) {
+  value.Matcher = new VariantMatcher(matcher);
+// Releases whatever the value currently owns.
+VariantValue::~VariantValue() { reset(); }
+// Copy assignment: releases the current contents and then duplicates the
+// contents of `other` according to its type tag.
+VariantValue &VariantValue::operator=(const VariantValue &other) {
+  // Self-assignment must not release the data that is about to be copied.
+  if (this == &other)
+    return *this;
+  reset();
+  switch (other.type) {
+  case ValueType::String:
+    setString(other.getString());
+    break;
+  case ValueType::Matcher:
+    setMatcher(other.getMatcher());
+    break;
+  case ValueType::Nothing:
+    type = ValueType::Nothing;
+    break;
+  }
+  return *this;
+// Frees the heap object selected by the current type tag (if any) and leaves
+// the value in the Nothing state.
+void VariantValue::reset() {
+  switch (type) {
+  case ValueType::String:
+    delete value.String;
+    break;
+  case ValueType::Matcher:
+    delete value.Matcher;
+    break;
+  // Cases that do nothing.
+  case ValueType::Nothing:
+    break;
+  }
+  type = ValueType::Nothing;
+// True when the value currently holds a string.
+bool VariantValue::isString() const { return type == ValueType::String; }
+// Returns the held string; asserts that the value is a String.
+const llvm::StringRef &VariantValue::getString() const {
+  assert(isString());
+  return *value.String;
+// Replaces the current contents with a heap-allocated copy of `newValue`.
+void VariantValue::setString(const llvm::StringRef &newValue) {
+  reset();
+  type = ValueType::String;
+  value.String = new llvm::StringRef(newValue);
+// True when the value currently holds a matcher.
+bool VariantValue::isMatcher() const { return type == ValueType::Matcher; }
+// Returns the held matcher; asserts that the value is a Matcher.
+const VariantMatcher &VariantValue::getMatcher() const {
+  assert(isMatcher());
+  return *value.Matcher;
+// Replaces the current contents with a heap-allocated copy of `newValue`.
+void VariantValue::setMatcher(const VariantMatcher &newValue) {
+  reset();
+  type = ValueType::Matcher;
+  value.Matcher = new VariantMatcher(newValue);
+// Maps the current type tag to its display name.
+std::string VariantValue::getTypeAsString() const {
+  switch (type) {
+  case ValueType::String:
+    return "String";
+  case ValueType::Matcher:
+    return "Matcher";
+  case ValueType::Nothing:
+    return "Nothing";
+  }
+  // Reached only if `type` holds a value outside the enumerators above.
+  llvm_unreachable("Invalid Type");
+} // namespace mlir::query::matcher

diff  --git a/mlir/lib/Query/Query.cpp b/mlir/lib/Query/Query.cpp
new file mode 100644
index 000000000000000..5c42e5a5f0a116e
--- /dev/null
+++ b/mlir/lib/Query/Query.cpp
@@ -0,0 +1,82 @@
+//===---- Query.cpp - -----------------------------------------------------===//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#include "mlir/Query/Query.h"
+#include "QueryParser.h"
+#include "mlir/Query/Matcher/MatchFinder.h"
+#include "mlir/Query/QuerySession.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/raw_ostream.h"
+namespace mlir::query {
+// Parses `line` into a query by forwarding to QueryParser::parse.
+QueryRef parse(llvm::StringRef line, const QuerySession &qs) {
+  return QueryParser::parse(line, qs);
+// Computes completions for `line` at offset `pos` by forwarding to
+// QueryParser::complete.
+complete(llvm::StringRef line, size_t pos, const QuerySession &qs) {
+  return QueryParser::complete(line, pos, qs);
+// Prints a note diagnostic to `os` marking where `op` sits in the session's
+// source buffer, labelled with the name `binding`.
+static void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op,
+                       const std::string &binding) {
+  auto fileLoc = op->getLoc()->findInstanceOf<FileLineColLoc>();
+  // Not every location contains a file/line/column (e.g. `loc(unknown)`); in
+  // that case findInstanceOf returns a null attribute that must not be
+  // dereferenced. Fall back to printing the raw location instead.
+  if (!fileLoc) {
+    os << op->getLoc() << ": note: \"" << binding << "\" binds here\n";
+    return;
+  }
+  auto smloc = qs.getSourceManager().FindLocForLineAndColumn(
+      qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn());
+  qs.getSourceManager().PrintMessage(os, smloc, llvm::SourceMgr::DK_Note,
+                                     "\"" + binding + "\" binds here");
+Query::~Query() = default;
+// Reports the stored error string on `os` and signals failure.
+mlir::LogicalResult InvalidQuery::run(llvm::raw_ostream &os,
+                                      QuerySession &qs) const {
+  os << errStr << "\n";
+  return mlir::failure();
+// Does nothing and reports success.
+mlir::LogicalResult NoOpQuery::run(llvm::raw_ostream &os,
+                                   QuerySession &qs) const {
+  return mlir::success();
+// Prints the list of supported commands to `os`; always succeeds.
+mlir::LogicalResult HelpQuery::run(llvm::raw_ostream &os,
+                                   QuerySession &qs) const {
+  // Every command column is padded to the same width (36 characters) so that
+  // the descriptions line up in the printed table.
+  os << "Available commands:\n\n"
+        "  help                              "
+        "Print this help text.\n"
+        "  match MATCHER, m MATCHER          "
+        "Match the mlir against the given matcher.\n"
+        "  quit, q                           "
+        "Terminates the query session.\n\n";
+  return mlir::success();
+// Sets the session's `terminate` flag and reports success.
+mlir::LogicalResult QuitQuery::run(llvm::raw_ostream &os,
+                                   QuerySession &qs) const {
+  qs.terminate = true;
+  return mlir::success();
+// Runs the stored matcher over the session's root operation, prints one note
+// per matching operation and finishes with a summary line.
+mlir::LogicalResult MatchQuery::run(llvm::raw_ostream &os,
+                                    QuerySession &qs) const {
+  const std::vector<Operation *> found =
+      matcher::MatchFinder().getMatches(qs.getRootOp(), matcher);
+  os << "\n";
+  int ordinal = 0;
+  for (Operation *matchedOp : found) {
+    ++ordinal;
+    os << "Match #" << ordinal << ":\n\n";
+    // Every match is reported under the placeholder binding name "root" in
+    // this initial draft.
+    printMatch(os, qs, matchedOp, "root");
+  }
+  const char *summary = ordinal == 1 ? " match.\n\n" : " matches.\n\n";
+  os << ordinal << summary;
+  return mlir::success();
+} // namespace mlir::query

diff  --git a/mlir/lib/Query/QueryParser.cpp b/mlir/lib/Query/QueryParser.cpp
new file mode 100644
index 000000000000000..f43a28569f0078b
--- /dev/null
+++ b/mlir/lib/Query/QueryParser.cpp
@@ -0,0 +1,217 @@
+//===---- QueryParser.cpp - mlir-query command parser ---------------------===//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#include "QueryParser.h"
+#include "llvm/ADT/StringSwitch.h"
+namespace mlir::query {
+// Lex any amount of whitespace (newlines excluded) followed by a "word" from
+// the start of `line`. A word is a single '#' or a run of characters up to the
+// next such whitespace. If no word is found before the end of `line`, return
+// the (empty) remainder. `line` is adjusted to exclude the lexed region.
+llvm::StringRef QueryParser::lexWord() {
+  // Horizontal whitespace only: newlines are deliberately not trimmed.
+  const auto isBlank = [](char c) {
+    return llvm::StringRef(" \t\v\f\r").contains(c);
+  };
+  line = line.drop_while(isBlank);
+  // Even an empty line still carries a pointer and a (zero) length; the
+  // pointer is used by the LexOrCompleteWord code completion.
+  if (line.empty())
+    return line;
+  // A leading '#' forms a one-character word of its own; otherwise the word
+  // extends up to the next blank character.
+  const llvm::StringRef word =
+      line.front() == '#' ? line.substr(0, 1) : line.take_until(isBlank);
+  line = line.drop_front(word.size());
+  return word;
+// This is the StringSwitch-alike used by doParse below: it either selects a
+// value for the lexed command word or records completions for it.
+template <typename T>
+// Helper that lexes one word and then behaves either like a
+// llvm::StringSwitch<T> over that word, or, when the completion point falls
+// at or before the end of the word, like a collector of completion candidates.
+struct QueryParser::LexOrCompleteWord {
+  // The word that was lexed from the parser's line.
+  llvm::StringRef word;
+  // The underlying switch used when no completion is in progress for `word`.
+  llvm::StringSwitch<T> stringSwitch;
+  // The parser whose line is lexed and whose completions list is appended to.
+  QueryParser *queryParser;
+  // Set to the completion point offset in word, or StringRef::npos if
+  // completion point not in word.
+  size_t wordCompletionPos;
+  // Lexes a word and stores it in word. Returns a LexOrCompleteWord<T> object
+  // that can be used like a llvm::StringSwitch<T>, but adds cases as possible
+  // completions if the lexed word contains the completion point.
+  LexOrCompleteWord(QueryParser *queryParser, llvm::StringRef &outWord)
+      : word(queryParser->lexWord()), stringSwitch(word),
+        queryParser(queryParser), wordCompletionPos(llvm::StringRef::npos) {
+    outWord = word;
+    // A completion point located at or before the end of the word applies to
+    // this word; a point before the word's start is clamped to offset 0.
+    if (queryParser->completionPos &&
+        queryParser->completionPos <= word.data() + word.size()) {
+      if (queryParser->completionPos < word.data())
+        wordCompletionPos = 0;
+      else
+        wordCompletionPos = queryParser->completionPos - word.data();
+    }
+  }
+  // Without a completion point in the word, registers `caseStr` -> `value` on
+  // the switch. Otherwise, if `isCompletion` is set and `caseStr` agrees with
+  // the word up to the completion point, records the remainder of `caseStr`
+  // (plus a trailing space) as a completion; nothing is added to the switch.
+  LexOrCompleteWord &Case(llvm::StringLiteral caseStr, const T &value,
+                          bool isCompletion = true) {
+    if (wordCompletionPos == llvm::StringRef::npos)
+      stringSwitch.Case(caseStr, value);
+    else if (!caseStr.empty() && isCompletion &&
+             wordCompletionPos <= caseStr.size() &&
+             caseStr.substr(0, wordCompletionPos) ==
+                 word.substr(0, wordCompletionPos)) {
+      queryParser->completions.emplace_back(
+          (caseStr.substr(wordCompletionPos) + " ").str(),
+          std::string(caseStr));
+    }
+    return *this;
+  }
+  // Returns the value of the matching case, or `value` when none matched.
+  T Default(T value) { return stringSwitch.Default(value); }
+// Finishes parsing a command. If the rest of the line starts (after blanks)
+// with a newline, it is stored in queryRef->remainingContent. A trailing '#'
+// comment is skipped and the check repeated. Any other trailing word turns the
+// result into an InvalidQuery.
+QueryRef QueryParser::endQuery(QueryRef queryRef) {
+  llvm::StringRef extra = line;
+  llvm::StringRef extraTrimmed = extra.drop_while(
+      [](char c) { return llvm::StringRef(" \t\v\f\r").contains(c); });
+  // '\r' is stripped above, so a "\r\n" line ending is already covered by
+  // testing for a leading '\n'.
+  if (!extraTrimmed.empty() && extraTrimmed[0] == '\n')
+    queryRef->remainingContent = extra;
+  else {
+    llvm::StringRef trailingWord = lexWord();
+    if (!trailingWord.empty() && trailingWord.front() == '#') {
+      // Skip the comment through the end of its line, then re-check the rest.
+      line = line.drop_until([](char c) { return c == '\n'; });
+      line = line.drop_while([](char c) { return c == '\n'; });
+      return endQuery(queryRef);
+    }
+    if (!trailingWord.empty()) {
+      return new InvalidQuery("unexpected extra input: '" + extra + "'");
+    }
+  }
+  return queryRef;
+namespace {
+// The kinds of command that doParse distinguishes after lexing the first word.
+enum class ParsedQueryKind {
+  // The first word is not a known command.
+  Invalid,
+  // The first word is "#".
+  Comment,
+  // The first word is empty.
+  NoOp,
+  // "help".
+  Help,
+  // "match" or its short form "m".
+  Match,
+  // "quit" or its short form "q".
+  Quit,
+// Renders `diag` into a string and wraps it in a newly allocated InvalidQuery.
+makeInvalidQueryFromDiagnostics(const matcher::internal::Diagnostics &diag) {
+  std::string errStr;
+  llvm::raw_string_ostream os(errStr);
+  diag.print(os);
+  return new InvalidQuery(os.str());
+} // namespace
+// Asks the matcher parser for completions at the completion point within the
+// unparsed remainder of the line and appends them to `completions`. Always
+// returns a default-constructed QueryRef.
+QueryRef QueryParser::completeMatcherExpression() {
+  for (const matcher::MatcherCompletion &suggestion :
+       matcher::internal::Parser::completeExpression(
+           line, completionPos - line.begin(), qs.getRegistryData(),
+           &qs.namedValues))
+    completions.emplace_back(suggestion.typedText, suggestion.matcherDecl);
+  return QueryRef();
+// Parses one command from the front of `line`. The first word selects the
+// command kind; comments and empty commands are skipped and parsing continues
+// after the next newline. Unknown commands and matcher parse errors yield an
+// InvalidQuery.
+QueryRef QueryParser::doParse() {
+  llvm::StringRef commandStr;
+  // The single-letter aliases and '#' are excluded from completion
+  // suggestions via isCompletion=false.
+  ParsedQueryKind qKind =
+      LexOrCompleteWord<ParsedQueryKind>(this, commandStr)
+          .Case("", ParsedQueryKind::NoOp)
+          .Case("#", ParsedQueryKind::Comment, /*isCompletion=*/false)
+          .Case("help", ParsedQueryKind::Help)
+          .Case("m", ParsedQueryKind::Match, /*isCompletion=*/false)
+          .Case("match", ParsedQueryKind::Match)
+          .Case("q", ParsedQueryKind::Quit, /*isCompletion=*/false)
+          .Case("quit", ParsedQueryKind::Quit)
+          .Default(ParsedQueryKind::Invalid);
+  switch (qKind) {
+  case ParsedQueryKind::Comment:
+  case ParsedQueryKind::NoOp:
+    // Discard the rest of the current line plus any following newlines, then
+    // parse whatever remains.
+    line = line.drop_until([](char c) { return c == '\n'; });
+    line = line.drop_while([](char c) { return c == '\n'; });
+    if (line.empty())
+      return new NoOpQuery;
+    return doParse();
+  case ParsedQueryKind::Help:
+    return endQuery(new HelpQuery);
+  case ParsedQueryKind::Quit:
+    return endQuery(new QuitQuery);
+  case ParsedQueryKind::Match: {
+    // In completion mode only completions are collected; no query is built.
+    if (completionPos) {
+      return completeMatcherExpression();
+    }
+    matcher::internal::Diagnostics diag;
+    auto matcherSource = line.ltrim();
+    auto origMatcherSource = matcherSource;
+    std::optional<matcher::DynMatcher> matcher =
+        matcher::internal::Parser::parseMatcherExpression(
+            matcherSource, qs.getRegistryData(), &qs.namedValues, &diag);
+    if (!matcher) {
+      return makeInvalidQueryFromDiagnostics(diag);
+    }
+    // The text of the matcher is the original source minus whatever
+    // matcherSource still holds after parsing.
+    // NOTE(review): this relies on parseMatcherExpression shrinking
+    // matcherSource to the unconsumed tail; confirm in Parser.
+    auto actualSource = origMatcherSource.slice(0, origMatcherSource.size() -
+                                                       matcherSource.size());
+    QueryRef query = new MatchQuery(actualSource, *matcher);
+    query->remainingContent = matcherSource;
+    return query;
+  }
+  case ParsedQueryKind::Invalid:
+    return new InvalidQuery("unknown command: " + commandStr);
+  }
+  llvm_unreachable("Invalid query kind");
+// Parses `line` with a temporary parser and returns the resulting query.
+QueryRef QueryParser::parse(llvm::StringRef line, const QuerySession &qs) {
+  return QueryParser(line, qs).doParse();
+// Runs the parser in completion mode, with the completion point `pos`
+// characters into `line`, and returns the completions collected during the
+// parse. The parsed query itself is discarded.
+QueryParser::complete(llvm::StringRef line, size_t pos,
+                      const QuerySession &qs) {
+  QueryParser queryParser(line, qs);
+  queryParser.completionPos = line.data() + pos;
+  queryParser.doParse();
+  return queryParser.completions;
+} // namespace mlir::query

diff  --git a/mlir/lib/Query/QueryParser.h b/mlir/lib/Query/QueryParser.h
new file mode 100644
index 000000000000000..e9c30eccecab9e4
--- /dev/null
+++ b/mlir/lib/Query/QueryParser.h
@@ -0,0 +1,59 @@
+//===--- QueryParser.h - ----------------------------------------*- C++ -*-===//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#include "Matcher/Parser.h"
+#include "mlir/Query/Query.h"
+#include "mlir/Query/QuerySession.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/LineEditor/LineEditor.h"
+namespace mlir::query {
+class QuerySession;
+// Parser for mlir-query command lines. It can either produce a query object
+// or, in completion mode, a list of completion candidates.
+class QueryParser {
+  // Parse line as a query and return a QueryRef representing the query, which
+  // may be an InvalidQuery.
+  static QueryRef parse(llvm::StringRef line, const QuerySession &qs);
+  // Compute the completions for line with the completion point at offset pos.
+  static std::vector<llvm::LineEditor::Completion>
+  complete(llvm::StringRef line, size_t pos, const QuerySession &qs);
+  // A parser starts out in normal (non-completion) mode: completionPos is
+  // null.
+  QueryParser(llvm::StringRef line, const QuerySession &qs)
+      : line(line), completionPos(nullptr), qs(qs) {}
+  // Lex the next word from the front of line and remove it from line.
+  llvm::StringRef lexWord();
+  template <typename T>
+  struct LexOrCompleteWord;
+  // Collect completions for the matcher expression that follows a match
+  // command.
+  QueryRef completeMatcherExpression();
+  // Check the input that follows a complete command.
+  QueryRef endQuery(QueryRef queryRef);
+  // Parse the remaining line and return a reference to the parsed query
+  // object, which may be an InvalidQuery if a parse error occurs.
+  QueryRef doParse();
+  // The not-yet-consumed part of the input.
+  llvm::StringRef line;
+  // Position of the completion point inside the input, or null when not
+  // completing.
+  const char *completionPos;
+  // Completions gathered while parsing in completion mode.
+  std::vector<llvm::LineEditor::Completion> completions;
+  // The session the parser reads the matcher registry and named values from.
+  const QuerySession &qs;
+} // namespace mlir::query

diff  --git a/mlir/lib/Tools/CMakeLists.txt b/mlir/lib/Tools/CMakeLists.txt
index 6175a1ce5f8d1db..01270fa4b0fc341 100644
--- a/mlir/lib/Tools/CMakeLists.txt
+++ b/mlir/lib/Tools/CMakeLists.txt
@@ -2,6 +2,7 @@ add_subdirectory(lsp-server-support)

diff  --git a/mlir/lib/Tools/mlir-query/CMakeLists.txt b/mlir/lib/Tools/mlir-query/CMakeLists.txt
new file mode 100644
index 000000000000000..b81b02d42bfcaa4
--- /dev/null
+++ b/mlir/lib/Tools/mlir-query/CMakeLists.txt
@@ -0,0 +1,13 @@
+  lineeditor
+  )
+  MlirQueryMain.cpp
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Tools/mlir-query
+  MLIRQuery
+  )

diff  --git a/mlir/lib/Tools/mlir-query/MlirQueryMain.cpp b/mlir/lib/Tools/mlir-query/MlirQueryMain.cpp
new file mode 100644
index 000000000000000..15de16a8774bc07
--- /dev/null
+++ b/mlir/lib/Tools/mlir-query/MlirQueryMain.cpp
@@ -0,0 +1,115 @@
+//===- MlirQueryMain.cpp - MLIR Query main --------------------------------===//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// This file implements the general framework of the MLIR query tool. It
+// parses the command line arguments, parses the MLIR file and outputs the query
+// results.
+#include "mlir/Tools/mlir-query/MlirQueryMain.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Parser/Parser.h"
+#include "mlir/Query/Query.h"
+#include "mlir/Query/QuerySession.h"
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/LineEditor/LineEditor.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/SourceMgr.h"
+// Query Parser
+// Entry point shared by mlir-query drivers: parses the command line, loads and
+// parses the input file, then either runs the commands given with -c or starts
+// an interactive line-editor session. Returns failure when the input cannot be
+// opened or parsed, or when a -c command fails.
+mlir::mlirQueryMain(int argc, char **argv, MLIRContext &context,
+                    const mlir::query::matcher::Registry &matcherRegistry) {
+  // Override the default '-h' and use the default PrintHelpMessage() which
+  // won't print options in categories.
+  static llvm::cl::opt<bool> help("h", llvm::cl::desc("Alias for -help"),
+                                  llvm::cl::Hidden);
+  static llvm::cl::OptionCategory mlirQueryCategory("mlir-query options");
+  // Commands passed with -c; may be given more than once.
+  static llvm::cl::list<std::string> commands(
+      "c", llvm::cl::desc("Specify command to run"),
+      llvm::cl::value_desc("command"), llvm::cl::cat(mlirQueryCategory));
+  static llvm::cl::opt<std::string> inputFilename(
+      llvm::cl::Positional, llvm::cl::desc("<input file>"),
+      llvm::cl::cat(mlirQueryCategory));
+  static llvm::cl::opt<bool> noImplicitModule{
+      "no-implicit-module",
+      llvm::cl::desc(
+          "Disable implicit addition of a top-level module op during parsing"),
+      llvm::cl::init(false)};
+  static llvm::cl::opt<bool> allowUnregisteredDialects(
+      "allow-unregistered-dialect",
+      llvm::cl::desc("Allow operation with no registered dialects"),
+      llvm::cl::init(false));
+  llvm::cl::HideUnrelatedOptions(mlirQueryCategory);
+  llvm::InitLLVM y(argc, argv);
+  llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR test case query tool.\n");
+  if (help) {
+    llvm::cl::PrintHelpMessage();
+    return mlir::success();
+  }
+  // Set up the input file.
+  std::string errorMessage;
+  auto file = openInputFile(inputFilename, &errorMessage);
+  if (!file) {
+    llvm::errs() << errorMessage << "\n";
+    return mlir::failure();
+  }
+  // The source manager takes ownership of the buffer; the buffer id is kept so
+  // the session can refer back to this buffer.
+  auto sourceMgr = llvm::SourceMgr();
+  auto bufferId = sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc());
+  context.allowUnregisteredDialects(allowUnregisteredDialects);
+  // Parse the input MLIR file.
+  OwningOpRef<Operation *> opRef =
+      noImplicitModule ? parseSourceFile(sourceMgr, &context)
+                       : parseSourceFile<mlir::ModuleOp>(sourceMgr, &context);
+  if (!opRef)
+    return mlir::failure();
+  mlir::query::QuerySession qs(opRef.get(), sourceMgr, bufferId,
+                               matcherRegistry);
+  if (!commands.empty()) {
+    // Batch mode: run each -c command in order and stop at the first failure.
+    for (auto &command : commands) {
+      mlir::query::QueryRef queryRef = mlir::query::parse(command, qs);
+      if (mlir::failed(queryRef->run(llvm::outs(), qs)))
+        return mlir::failure();
+    }
+  } else {
+    // Interactive mode: read commands until end of input or until a query
+    // sets the session's terminate flag. Failures of individual queries are
+    // ignored here.
+    llvm::LineEditor le("mlir-query");
+    le.setListCompleter([&qs](llvm::StringRef line, size_t pos) {
+      return mlir::query::complete(line, pos, qs);
+    });
+    while (std::optional<std::string> line = le.readLine()) {
+      mlir::query::QueryRef queryRef = mlir::query::parse(*line, qs);
+      (void)queryRef->run(llvm::outs(), qs);
+      llvm::outs().flush();
+      if (qs.terminate)
+        break;
+    }
+  }
+  return mlir::success();

diff  --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt
index bf143d036c2f66f..6fc9ae0f3fc58fa 100644
--- a/mlir/test/CMakeLists.txt
+++ b/mlir/test/CMakeLists.txt
@@ -104,6 +104,7 @@ set(MLIR_TEST_DEPENDS
+  mlir-query

diff  --git a/mlir/test/mlir-query/simple-test.mlir b/mlir/test/mlir-query/simple-test.mlir
new file mode 100644
index 000000000000000..a4d006598767b3f
--- /dev/null
+++ b/mlir/test/mlir-query/simple-test.mlir
@@ -0,0 +1,16 @@
+// RUN: mlir-query %s -c "m isConstantOp()" | FileCheck %s
+// CHECK: {{.*}}.mlir:5:13: note: "root" binds here
+func.func @simple1() {
+  %c1_i32 = arith.constant 1 : i32
+  return
+// CHECK: {{.*}}.mlir:12:11: note: "root" binds here
+// CHECK: {{.*}}.mlir:13:11: note: "root" binds here
+func.func @simple2() {
+  %cst1 = arith.constant 1.0 : f32
+  %cst2 = arith.constant 2.0 : f32
+  %add = arith.addf %cst1, %cst2 : f32
+  return

diff  --git a/mlir/tools/CMakeLists.txt b/mlir/tools/CMakeLists.txt
index e9a1e4d6251722e..a01f74f737e1bc1 100644
--- a/mlir/tools/CMakeLists.txt
+++ b/mlir/tools/CMakeLists.txt
@@ -2,6 +2,7 @@ add_subdirectory(mlir-lsp-server)

diff  --git a/mlir/tools/mlir-query/CMakeLists.txt b/mlir/tools/mlir-query/CMakeLists.txt
new file mode 100644
index 000000000000000..ef2e5a84b5569b5
--- /dev/null
+++ b/mlir/tools/mlir-query/CMakeLists.txt
@@ -0,0 +1,20 @@
+get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
+  set(test_libs
+    MLIRTestDialect
+    )
+  mlir-query.cpp
+  )
+  ${dialect_libs}
+  ${test_libs}
+  MLIRQueryLib
+  )

diff  --git a/mlir/tools/mlir-query/mlir-query.cpp b/mlir/tools/mlir-query/mlir-query.cpp
new file mode 100644
index 000000000000000..0ed4f94d5802b09
--- /dev/null
+++ b/mlir/tools/mlir-query/mlir-query.cpp
@@ -0,0 +1,63 @@
+//===- mlir-query.cpp - MLIR Query Driver ---------------------------------===//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// This is a command line utility that runs queries against an MLIR file using
+// the registered matchers.
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/InitAllDialects.h"
+#include "mlir/Query/Matcher/Registry.h"
+#include "mlir/Tools/mlir-query/MlirQueryMain.h"
+using namespace mlir;
+// This is needed because these matchers are defined as overloaded functions.
+using HasOpAttrName = detail::AttrOpMatcher(StringRef);
+using HasOpName = detail::NameOpMatcher(StringRef);
+using IsConstantOp = detail::constant_op_matcher();
+namespace test {
+void registerTestDialect(DialectRegistry &);
+} // namespace test
+// Driver for the mlir-query tool: registers the dialects and the matchers that
+// can be named in queries, then hands control to mlirQueryMain. The process
+// exit code is 1 when mlirQueryMain fails and 0 otherwise.
+int main(int argc, char **argv) {
+  DialectRegistry dialectRegistry;
+  registerAllDialects(dialectRegistry);
+  query::matcher::Registry matcherRegistry;
+  // Matchers registered in alphabetical order for consistency:
+  // The static_casts select one overload of an overloaded matcher function.
+  matcherRegistry.registerMatcher("hasOpAttrName",
+                                  static_cast<HasOpAttrName *>(m_Attr));
+  matcherRegistry.registerMatcher("hasOpName", static_cast<HasOpName *>(m_Op));
+  matcherRegistry.registerMatcher("isConstantOp",
+                                  static_cast<IsConstantOp *>(m_Constant));
+  matcherRegistry.registerMatcher("isNegInfFloat", m_NegInfFloat);
+  matcherRegistry.registerMatcher("isNegZeroFloat", m_NegZeroFloat);
+  matcherRegistry.registerMatcher("isNonZero", m_NonZero);
+  matcherRegistry.registerMatcher("isOne", m_One);
+  matcherRegistry.registerMatcher("isOneFloat", m_OneFloat);
+  matcherRegistry.registerMatcher("isPosInfFloat", m_PosInfFloat);
+  matcherRegistry.registerMatcher("isPosZeroFloat", m_PosZeroFloat);
+  matcherRegistry.registerMatcher("isZero", m_Zero);
+  matcherRegistry.registerMatcher("isZeroFloat", m_AnyZeroFloat);
+  // The test dialect is added to the registry before the context is created
+  // from it.
+  test::registerTestDialect(dialectRegistry);
+  MLIRContext context(dialectRegistry);
+  return failed(mlirQueryMain(argc, argv, context, matcherRegistry));


More information about the Mlir-commits mailing list