[Mlir-commits] [mlir] [mlir] MLIR-QUERY DefinitionsMatcher implementation & DAG (PR #115670)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jan 20 13:51:35 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-core

Author: Denzel-Brian Budii (chios202)

<details>
<summary>Changes</summary>

This pull request improves the mlir-query tool by implementing `getBackwardSlice`- and `getForwardSlice`-style matchers (`getDefinitions`/`definedBy` and `getUses`/`usedBy`). In addition, `SetQuery` was added to enable custom configuration for each query, e.g. `inclusive`, `omitUsesFromAbove`, `omitBlockArguments`.

Example of the current matcher, run against the file `mlir/test/mlir-query/complex-test.mlir`:

```
mlir-query> match getDefinitions(hasOpName("arith.addf"),2)
Match #1:

/home/dbudii/personal/llvm-project-fork/mlir/test/mlir-query/complex-test.mlir:5:8: note: "root" binds here
  %0 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) {
       ^
/home/dbudii/personal/llvm-project-fork/mlir/test/mlir-query/complex-test.mlir:7:10: note: "root" binds here
    %2 = arith.addf %in, %in : f32
         ^
Match #2:

/home/dbudii/personal/llvm-project-fork/mlir/test/mlir-query/complex-test.mlir:10:16: note: "root" binds here
  %collapsed = tensor.collapse_shape %0 [[0, 1]] : tensor<5x5xf32> into tensor<25xf32>
               ^
/home/dbudii/personal/llvm-project-fork/mlir/test/mlir-query/complex-test.mlir:13:11: note: "root" binds here
    %c2 = arith.constant 2 : index
          ^
/home/dbudii/personal/llvm-project-fork/mlir/test/mlir-query/complex-test.mlir:14:18: note: "root" binds here
    %extracted = tensor.extract %collapsed[%c2] : tensor<25xf32>
                 ^
/home/dbudii/personal/llvm-project-fork/mlir/test/mlir-query/complex-test.mlir:15:10: note: "root" binds here
    %2 = arith.addf %extracted, %extracted : f32
         ^
mlir-query> 
```


---

Patch is 34.52 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/115670.diff


17 Files Affected:

- (modified) mlir/include/mlir/IR/Matchers.h (+2-2) 
- (added) mlir/include/mlir/Query/Matcher/ExtraMatchers.h (+180) 
- (modified) mlir/include/mlir/Query/Matcher/Marshallers.h (+15) 
- (modified) mlir/include/mlir/Query/Matcher/MatchFinder.h (+39-11) 
- (modified) mlir/include/mlir/Query/Matcher/MatchersInternal.h (+49-11) 
- (modified) mlir/include/mlir/Query/Matcher/VariantValue.h (+11-1) 
- (modified) mlir/include/mlir/Query/Query.h (+48-1) 
- (modified) mlir/include/mlir/Query/QuerySession.h (+10-1) 
- (modified) mlir/lib/Query/Matcher/Parser.cpp (+48-6) 
- (modified) mlir/lib/Query/Matcher/RegistryManager.cpp (+2) 
- (modified) mlir/lib/Query/Matcher/VariantValue.cpp (+24) 
- (modified) mlir/lib/Query/Query.cpp (+26-20) 
- (modified) mlir/lib/Query/QueryParser.cpp (+76-1) 
- (modified) mlir/lib/Query/QueryParser.h (+1-1) 
- (added) mlir/test/mlir-query/complex-test.mlir (+22) 
- (modified) mlir/test/mlir-query/function-extraction.mlir (+2-2) 
- (modified) mlir/tools/mlir-query/mlir-query.cpp (+10-1) 


``````````diff
diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index 1dce055db1b4a7..2204a68be26b10 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -59,7 +59,7 @@ struct NameOpMatcher {
   NameOpMatcher(StringRef name) : name(name) {}
   bool match(Operation *op) { return op->getName().getStringRef() == name; }
 
-  StringRef name;
+  std::string name;
 };
 
 /// The matcher that matches operations that have the specified attribute name.
@@ -67,7 +67,7 @@ struct AttrOpMatcher {
   AttrOpMatcher(StringRef attrName) : attrName(attrName) {}
   bool match(Operation *op) { return op->hasAttr(attrName); }
 
-  StringRef attrName;
+  std::string attrName;
 };
 
 /// The matcher that matches operations that have the `ConstantLike` trait, and
diff --git a/mlir/include/mlir/Query/Matcher/ExtraMatchers.h b/mlir/include/mlir/Query/Matcher/ExtraMatchers.h
new file mode 100644
index 00000000000000..57adc3241b0bef
--- /dev/null
+++ b/mlir/include/mlir/Query/Matcher/ExtraMatchers.h
@@ -0,0 +1,190 @@
+//===- ExtraMatchers.h - Various common matchers ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file provides slice-based matchers for mlir-query: backward slices
+// (definedBy, getDefinitions) and forward slices (usedBy, getUses) of the
+// operations accepted by an inner matcher.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_EXTRAMATCHERS_H
+#define MLIR_TOOLS_MLIRQUERY_MATCHER_EXTRAMATCHERS_H
+
+#include "MatchFinder.h"
+#include "MatchersInternal.h"
+#include "mlir/IR/Region.h"
+#include "mlir/Query/Query.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace mlir::query::extramatcher {
+
+namespace detail {
+
+/// Matches operations accepted by an inner matcher and collects their
+/// backward slice: the operations defining their operands and, unless
+/// `omitUsesFromAbove` is set, values used inside their regions but defined
+/// outside of them. Definitions are followed transitively for at most `hops`
+/// steps.
+class BackwardSliceMatcher {
+public:
+  BackwardSliceMatcher(matcher::DynMatcher &&innerMatcher, unsigned hops)
+      : innerMatcher(std::move(innerMatcher)), hops(hops) {}
+
+  /// Returns true if `op` is accepted by the inner matcher and is not
+  /// isolated from above. On success the slice is added to `backwardSlice`
+  /// with definitions before their users; `op` itself is kept only when
+  /// `options.inclusive` is set.
+  bool match(Operation *op, SetVector<Operation *> &backwardSlice,
+             QueryOptions &options) {
+    if (!innerMatcher.match(op) || !matches(op, backwardSlice, options, hops))
+      return false;
+    if (!options.inclusive)
+      backwardSlice.remove(op);
+    return true;
+  }
+
+private:
+  /// Collects the definitions reachable from `op` within `tempHops` steps and
+  /// then inserts `op`. Returns false, collecting nothing, when `op` is
+  /// isolated from above.
+  bool matches(Operation *op, SetVector<Operation *> &backwardSlice,
+               QueryOptions &options, unsigned tempHops) {
+    if (op->hasTrait<OpTrait::IsIsolatedFromAbove>())
+      return false;
+
+    auto processValue = [&](Value value) {
+      if (tempHops == 0)
+        return;
+      if (Operation *definingOp = value.getDefiningOp()) {
+        if (backwardSlice.count(definingOp) == 0)
+          matches(definingOp, backwardSlice, options, tempHops - 1);
+      } else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
+        if (options.omitBlockArguments)
+          return;
+        // A block argument is attributed to the op owning its block. For
+        // owners with several regions or blocks that is an approximation, but
+        // it must not abort: such IR is valid input to an interactive query.
+        Operation *parentOp = blockArg.getOwner()->getParentOp();
+        if (parentOp && backwardSlice.count(parentOp) == 0)
+          matches(parentOp, backwardSlice, options, tempHops - 1);
+      } else {
+        llvm_unreachable("No definingOp and not a block argument.");
+      }
+    };
+
+    if (!options.omitUsesFromAbove) {
+      // Values used inside op's regions but defined outside of them are
+      // dependencies of op as well.
+      for (Region &region : op->getRegions()) {
+        SmallPtrSet<Region *, 4> descendants;
+        region.walk(
+            [&](Region *childRegion) { descendants.insert(childRegion); });
+        region.walk([&](Operation *nestedOp) {
+          for (OpOperand &operand : nestedOp->getOpOperands())
+            if (!descendants.contains(operand.get().getParentRegion()))
+              processValue(operand.get());
+        });
+      }
+    }
+
+    // Visit operands before inserting op so definitions precede users.
+    llvm::for_each(op->getOperands(), processValue);
+    backwardSlice.insert(op);
+    return true;
+  }
+
+  matcher::DynMatcher innerMatcher;
+  unsigned hops;
+};
+
+/// Matches operations accepted by an inner matcher and collects their
+/// forward slice: the operations nested in their regions and the users of
+/// their results, followed transitively for at most `hops` steps.
+class ForwardSliceMatcher {
+public:
+  ForwardSliceMatcher(matcher::DynMatcher &&innerMatcher, unsigned hops)
+      : innerMatcher(std::move(innerMatcher)), hops(hops) {}
+
+  /// Returns true if `op` is accepted by the inner matcher. On success all of
+  /// `forwardSlice` (including entries it held before the call) is put in
+  /// reverse post-order of the traversal, so starting from an empty set `op`
+  /// comes first. `op` itself is kept only when `options.inclusive` is set.
+  bool match(Operation *op, SetVector<Operation *> &forwardSlice,
+             QueryOptions &options) {
+    if (!innerMatcher.match(op) || !matches(op, forwardSlice, options, hops))
+      return false;
+    if (!options.inclusive)
+      forwardSlice.remove(op);
+    // matches() inserts users before op; reverse so the slice reads top-down.
+    SmallVector<Operation *, 0> postOrder(forwardSlice.takeVector());
+    forwardSlice.insert(postOrder.rbegin(), postOrder.rend());
+    return true;
+  }
+
+private:
+  /// Inserts the operations nested in `op`'s regions and the users of `op`'s
+  /// results that are reachable within `tempHops` steps, then `op` itself.
+  /// Always returns true; `options` is currently unused here.
+  bool matches(Operation *op, SetVector<Operation *> &forwardSlice,
+               QueryOptions &options, unsigned tempHops) {
+    if (tempHops == 0) {
+      forwardSlice.insert(op);
+      return true;
+    }
+
+    for (Region &region : op->getRegions())
+      for (Block &block : region)
+        for (Operation &blockOp : block)
+          if (forwardSlice.count(&blockOp) == 0)
+            matches(&blockOp, forwardSlice, options, tempHops - 1);
+    for (Value result : op->getResults())
+      for (Operation *userOp : result.getUsers())
+        if (forwardSlice.count(userOp) == 0)
+          matches(userOp, forwardSlice, options, tempHops - 1);
+
+    forwardSlice.insert(op);
+    return true;
+  }
+
+  matcher::DynMatcher innerMatcher;
+  unsigned hops;
+};
+
+} // namespace detail
+
+/// Matches ops accepted by `innerMatcher` together with the ops that directly
+/// define their operands (a backward slice limited to one hop).
+inline detail::BackwardSliceMatcher
+definedBy(mlir::query::matcher::DynMatcher innerMatcher) {
+  return detail::BackwardSliceMatcher(std::move(innerMatcher), 1);
+}
+
+/// Matches ops accepted by `innerMatcher` together with their backward slice,
+/// following definitions for at most `hops` steps.
+inline detail::BackwardSliceMatcher
+getDefinitions(mlir::query::matcher::DynMatcher innerMatcher, unsigned hops) {
+  return detail::BackwardSliceMatcher(std::move(innerMatcher), hops);
+}
+
+/// Matches ops accepted by `innerMatcher` together with the ops nested in
+/// their regions and their direct users (a forward slice limited to one hop).
+inline detail::ForwardSliceMatcher
+usedBy(mlir::query::matcher::DynMatcher innerMatcher) {
+  return detail::ForwardSliceMatcher(std::move(innerMatcher), 1);
+}
+
+/// Matches ops accepted by `innerMatcher` together with their forward slice,
+/// following users for at most `hops` steps.
+inline detail::ForwardSliceMatcher
+getUses(mlir::query::matcher::DynMatcher innerMatcher, unsigned hops) {
+  return detail::ForwardSliceMatcher(std::move(innerMatcher), hops);
+}
+
+} // namespace mlir::query::extramatcher
+
+#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_EXTRAMATCHERS_H
diff --git a/mlir/include/mlir/Query/Matcher/Marshallers.h b/mlir/include/mlir/Query/Matcher/Marshallers.h
index 6ed35ac0ddccc7..c775dbc5c86da0 100644
--- a/mlir/include/mlir/Query/Matcher/Marshallers.h
+++ b/mlir/include/mlir/Query/Matcher/Marshallers.h
@@ -50,6 +50,21 @@ struct ArgTypeTraits<llvm::StringRef> {
   }
 };
 
+template <>
+struct ArgTypeTraits<unsigned> {
+  // Tags this argument type as ArgKind::Unsigned.
+  static ArgKind getKind() { return ArgKind::Unsigned; }
+  // Only VariantValues currently holding an unsigned literal qualify.
+  static bool hasCorrectType(const VariantValue &value) {
+    return value.isUnsigned();
+  }
+  static unsigned get(const VariantValue &value) { return value.getUnsigned(); }
+  // No best-guess correction is offered for unsigned values.
+  static std::optional<std::string> getBestGuess(const VariantValue &) {
+    return std::nullopt;
+  }
+};
+
 template <>
 struct ArgTypeTraits<DynMatcher> {
 
diff --git a/mlir/include/mlir/Query/Matcher/MatchFinder.h b/mlir/include/mlir/Query/Matcher/MatchFinder.h
index b008a21f53ae2a..2175db86a91bdf 100644
--- a/mlir/include/mlir/Query/Matcher/MatchFinder.h
+++ b/mlir/include/mlir/Query/Matcher/MatchFinder.h
@@ -7,7 +7,7 @@
 //===----------------------------------------------------------------------===//
 //
 // This file contains the MatchFinder class, which is used to find operations
-// that match a given matcher.
+// that match a given matcher and print them.
 //
 //===----------------------------------------------------------------------===//
 
@@ -15,24 +15,62 @@
 #define MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H

 #include "MatchersInternal.h"
+#include "mlir/Query/QuerySession.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/raw_ostream.h"

 namespace mlir::query::matcher {

-// MatchFinder is used to find all operations that match a given matcher.
+// MatchFinder is used to find all operations that match a given matcher and
+// to print each match as a note pointing into the session's source buffer.
 class MatchFinder {
+private:
+  // Prints a note labelled `binding` for `op`. Uses the op's FileLineColLoc
+  // when it has one; otherwise there is no source line to point at, so the
+  // op name and location are printed directly.
+  static void printMatch(llvm::raw_ostream &os, QuerySession &qs,
+                         mlir::Operation *op, const std::string &binding) {
+    auto fileLoc = op->getLoc()->findInstanceOf<FileLineColLoc>();
+    if (!fileLoc) {
+      os << "note: \"" << binding << "\" binds here: " << op->getName()
+         << " at " << op->getLoc() << "\n";
+      return;
+    }
+    auto smloc = qs.getSourceManager().FindLocForLineAndColumn(
+        qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn());
+    qs.getSourceManager().PrintMessage(os, smloc, llvm::SourceMgr::DK_Note,
+                                       "\"" + binding + "\" binds here");
+  }
+
 public:
-  // Returns all operations that match the given matcher.
-  static std::vector<Operation *> getMatches(Operation *root,
-                                             DynMatcher matcher) {
-    std::vector<Operation *> matches;
+  // Walks `root`, printing each match to `os`. For a plain matcher the match
+  // is the walked op itself; for a matcher that collects operations through
+  // the SetVector/QueryOptions overload (e.g. a slice) it is every collected
+  // op. Returns all printed operations in print order.
+  static std::vector<Operation *>
+  getMatches(Operation *root, QueryOptions &options, DynMatcher matcher,
+             llvm::raw_ostream &os, QuerySession &qs) {
+    unsigned matchCount = 0;
+    std::vector<Operation *> matchedOps;
+    SetVector<Operation *> tempStorage;

-    // Simple match finding with walk.
     root->walk([&](Operation *subOp) {
-      if (matcher.match(subOp))
-        matches.push_back(subOp);
-    });
+      if (matcher.match(subOp)) {
+        matchedOps.push_back(subOp);
+        os << "Match #" << ++matchCount << ":\n\n";
+        printMatch(os, qs, subOp, "root");
+      } else if (matcher.match(subOp, tempStorage, options)) {
+        os << "Match #" << ++matchCount << ":\n\n";
+        // takeVector() also leaves tempStorage empty for the next op.
+        for (Operation *op : tempStorage.takeVector()) {
+          printMatch(os, qs, op, "root");
+          matchedOps.push_back(op);
+        }
+      }
+    });

-    return matches;
+    return matchedOps;
   }
 };

diff --git a/mlir/include/mlir/Query/Matcher/MatchersInternal.h b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
index 117f7d4edef9e3..b532b47be7d051 100644
--- a/mlir/include/mlir/Query/Matcher/MatchersInternal.h
+++ b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
@@ -1,4 +1,3 @@
-//===- MatchersInternal.h - Structural query framework ----------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -10,28 +9,62 @@
 //
 // Matchers are methods that return a Matcher which provides a method
 // match(Operation *op)
+// and, for matchers that collect related operations (e.g. slices), a method
+// match(Operation *op, SetVector<Operation *> &matchedOps,
+//       QueryOptions &options)
 //
 // The matcher functions are defined in include/mlir/IR/Matchers.h.
 // This file contains the wrapper classes needed to construct matchers for
 // mlir-query.
 //
 //===----------------------------------------------------------------------===//

 #ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
 #define MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H

 #include "mlir/IR/Matchers.h"
 #include "llvm/ADT/IntrusiveRefCntPtr.h"
+#include "llvm/ADT/SetVector.h"
+#include <type_traits>
+#include <utility>

+namespace mlir::query {
+struct QueryOptions;
+} // namespace mlir::query
+
 namespace mlir::query::matcher {
+
+// Detects whether matcher type T provides `match(Operation *)`.
+template <typename T, typename = void>
+struct has_simple_match : std::false_type {};
+
+template <typename T>
+struct has_simple_match<T, std::void_t<decltype(std::declval<T>().match(
+                               std::declval<Operation *>()))>>
+    : std::true_type {};
+
+// Detects whether matcher type T provides the operation-collecting overload
+// `match(Operation *, SetVector<Operation *> &, QueryOptions &)`.
+template <typename T, typename = void>
+struct has_bound_match : std::false_type {};
+
+template <typename T>
+struct has_bound_match<T, std::void_t<decltype(std::declval<T>().match(
+                              std::declval<Operation *>(),
+                              std::declval<SetVector<Operation *> &>(),
+                              std::declval<QueryOptions &>()))>>
+    : std::true_type {};

 // Generic interface for matchers on an MLIR operation.
 class MatcherInterface
     : public llvm::ThreadSafeRefCountedBase<MatcherInterface> {
 public:
   virtual ~MatcherInterface() = default;
-
   virtual bool match(Operation *op) = 0;
+  // Overload for matchers that collect related operations (e.g. slices) into
+  // `matchedOps`, configured by `options`.
+  virtual bool match(Operation *op, SetVector<Operation *> &matchedOps,
+                     QueryOptions &options) = 0;
 };

 // MatcherFnImpl takes a matcher function object and implements
@@ -40,14 +64,26 @@ template <typename MatcherFn>
 class MatcherFnImpl : public MatcherInterface {
 public:
   MatcherFnImpl(MatcherFn &matcherFn) : matcherFn(matcherFn) {}
-  bool match(Operation *op) override { return matcherFn.match(op); }
+
+  bool match(Operation *op) override {
+    if constexpr (has_simple_match<MatcherFn>::value)
+      return matcherFn.match(op);
+    return false;
+  }
+
+  bool match(Operation *op, SetVector<Operation *> &matchedOps,
+             QueryOptions &options) override {
+    if constexpr (has_bound_match<MatcherFn>::value)
+      return matcherFn.match(op, matchedOps, options);
+    return false;
+  }
 
 private:
   MatcherFn matcherFn;
 };
 
-// Matcher wraps a MatcherInterface implementation and provides a match()
-// method that redirects calls to the underlying implementation.
+// Matcher wraps a MatcherInterface implementation and provides match()
+// methods that redirect calls to the underlying implementation.
 class DynMatcher {
 public:
   // Takes ownership of the provided implementation pointer.
@@ -62,12 +98,14 @@ class DynMatcher {
   }
 
   bool match(Operation *op) const { return implementation->match(op); }
+  bool match(Operation *op, SetVector<Operation *> &matchedOps,
+             QueryOptions &options) const {
+    return implementation->match(op, matchedOps, options);
+  }
 
-  void setFunctionName(StringRef name) { functionName = name.str(); };
-
-  bool hasFunctionName() const { return !functionName.empty(); };
-
-  StringRef getFunctionName() const { return functionName; };
+  void setFunctionName(StringRef name) { functionName = name.str(); }
+  bool hasFunctionName() const { return !functionName.empty(); }
+  StringRef getFunctionName() const { return functionName; }
 
 private:
   llvm::IntrusiveRefCntPtr<MatcherInterface> implementation;
diff --git a/mlir/include/mlir/Query/Matcher/VariantValue.h b/mlir/include/mlir/Query/Matcher/VariantValue.h
index 449f8b3a01e021..6b57119df7a9bf 100644
--- a/mlir/include/mlir/Query/Matcher/VariantValue.h
+++ b/mlir/include/mlir/Query/Matcher/VariantValue.h
@@ -21,7 +21,7 @@
 namespace mlir::query::matcher {
 
 // All types that VariantValue can contain.
-enum class ArgKind { Matcher, String };
+enum class ArgKind { Matcher, String, Unsigned };
 
 // A variant matcher object to abstract simple and complex matchers into a
 // single object type.
@@ -81,6 +81,7 @@ class VariantValue {
   // Specific constructors for each supported type.
   VariantValue(const llvm::StringRef string);
   VariantValue(const VariantMatcher &matcher);
+  VariantValue(unsigned Unsigned);
 
   // String value functions.
   bool isString() const;
@@ -92,8 +93,15 @@ class VariantValue {
   const VariantMatcher &getMatcher() const;
   void setMatcher(const VariantMatcher &matcher);
 
+  // Unsigned value functions.
+  bool isUnsigned() const;
+  unsigned getUnsigned() const;
+  void setUnsigned(unsigned Unsigned);
+
   // String representation of the type of the value.
   std::string getTypeAsString() const;
+  explicit operator bool() const { return hasValue(); }
+  bool hasValue() const { return type != ValueType::Nothing; }
 
 private:
   void reset();
@@ -103,12 +111,14 @@ class VariantValue {
     Nothing,
     String,
     Matcher,
+    Unsigned,
   };
 
   // All supported value types.
   union AllValues {
     llvm::StringRef *String;
     VariantMatcher *Matcher;
+    unsigned Unsigned;
   };
 
   ValueType type;
diff --git a/mlir/include/mlir/Query/Query.h b/mlir/include/mlir/Query/Query.h
index 18f2172c9510a3..89d48773d2c3e6 100644
--- a/mlir/include/mlir/Query/Query.h
+++ b/mlir/include/mlir/Query/Query.h
@@ -17,7 +17,13 @@
 
 namespace mlir::query {
 
-enum class QueryKind { Invalid, NoOp, Help, Match, Quit };
+struct QueryOptions {
+  bool omitBlockArguments = false; // Skip block arguments in backward slices.
+  bool omitUsesFromAbove = true;   // Ignore in-region uses of outer values.
+  bool inclusive = true;           // Keep the matched op in its slice.
+};
+
+enum class QueryKind { Invalid, NoOp, Help, Match, Quit, Let, SetBool };
 
 class QuerySession;
 
@@ -103,6 +109,56 @@ struct MatchQuery : Query {
   }
 };
 
+// A `let` query carrying a `name` and the `value` to associate with it.
+// run() is defined out of line.
+struct LetQuery : Query {
+  LetQuery(llvm::StringRef name, const matcher::VariantValue &value)
+      : Query(QueryKind::Let), name(name), value(value) {}
+
+  llvm::LogicalResult run(llvm::raw_ostream &os,
+                          QuerySession &qs) const override;
+
+  std::string name;
+  matcher::VariantValue value;
+
+  static bool classof(const Query *query) {
+    return query->kind == QueryKind::Let;
+  }
+};
+
+// Maps the value type of a SetQuery to its QueryKind. Only `bool` options
+// have a kind; other types leave `value` undeclared.
+template <typename T>
+struct SetQueryKind {};
+
+template <>
+struct SetQueryKind<bool> {
+  // constexpr static data members are implicitly inline (C++17), so `value`
+  // needs no out-of-line definition even if it is odr-used.
+  static constexpr QueryKind value = QueryKind::SetBool;
+};
+
+// A `set` query: when run, assigns `value` to the QuerySession member
+// pointed to by `var`.
+template <typename T>
+struct SetQuery : Query {
+  SetQuery(T QuerySession::*var, T value)
+      : Query(SetQueryKind<T>::value), var(var), value(value) {}
+
+  llvm::LogicalResult run(llvm::raw_ostream &os,
+                          QuerySession &qs) const override {
+    qs.*var = value;
+    return mlir::success();
+  }
+
+  static bool classof(const Query *query) {
+    return query->kind == SetQueryKind<T>::value;
+  }
+
+  T QuerySession::*var;
+  T value;
+};
+
 } // namespace mlir::query
 
 #endif
diff --git a/mlir/include/mlir/Query/QuerySession.h b/mlir/include/mlir/Query/QuerySession.h
index fe552d750fc771..495358e8f36f94 100644
--- a/mlir/include/mlir/Query/QuerySession.h
+++ b/mlir/include/mlir/Query/QuerySession.h
@@ -9,14 +9,18 @@
 #ifndef MLIR_TOOLS_MLIRQUERY_QUERYSESSION_H
 #define MLIR_TOOLS_MLIRQUERY_QUERYSESSION_H
 
+#include "Matcher/VariantValue.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/Query/Matcher/Registry.h"
 #include "llvm/ADT/StringMap.h"
 #include "llvm/Support/SourceMgr.h"
 
+namespace mlir::query::matcher {
+class Registry;
+}
+
 namespace mlir::query {
 
-class Registry;
 // Represents the state for a particular mlir-query session.
 class QuerySession {
 public:
@@ -33,6 +37,11 @@ class QuerySession {
   llvm::StringMap<matcher::VariantValue> namedValues;
   bool terminate = false;
 
+public:
+  bool omitBlockArguments = false;
+  bool omitUsesFromAbove = true;
+  bool inclusive = true;
+
 private:
   Operation *rootOp;
   llvm::SourceMgr &sourceMgr;
diff --git a/mlir/lib/Query/Matcher/Parser.cpp b/mlir/lib/Query/Matcher/Parser.cpp
index 3609e24f9939f7..726f1188d7e4c8 100644
--- a/mlir/lib/Query/Matcher/Parser.cpp
+++ b/mlir/lib/Query/Matcher/Parser.cpp
@@ -135,6 +135,18 @@ class Parser::CodeTokenizer {
     case '\'':
       consumeStringLiteral(&result);
       break;
+    case '0':
+    case '1':
+    case '2':
+    case '3':
+    case '4':
+    case '5':
+    case '6':
+    case '7':
+    case '8':
+    case '9':
+      consumeNumberLiteral(&result);
+      break;
     default:
       parseIdentifierOrInvalid(&result);
       break;
@@ -144,6 +156,30 @@ class Parser::CodeTokenizer {
     return result;
   }
 
+  // Consume an unsigned integer literal; a "0x"/"0b" prefix selects the radix.
+  // Text that does not parse becomes an InvalidChar token, not an unset kind.
+  void consumeNumberLiteral(TokenInfo *result) {
+    unsigned length = 1;
+    if (code.size() > 1) {
+      int radix = tolower(static_cast<unsigned char>(code[1]));
+      if (radix == 'x' || radix == 'b')
+        length = 2;
+    }
+    // Accept hex digits so "0x1F" stays one token; getAsInteger validates it.
+    while (length < code.size() &&
+           isxdigit(static_cast<unsigned char>(code[length])))
+      ++length;
+    result->text = code.take_front(length);
+    code = code.drop_front(length);
+    unsigned value;
+    if (result->text.getAsInteger(/*Radix=*/0, value)) {
+      result->kind = TokenKind::InvalidChar;
+      return;
+    }
+    result->kind = TokenKind::Literal;
+    result->value = value;
+  }
+
   // Consume a string literal, handle escape sequences and missing closing
   // quote.
   void consumeStringLiteral(TokenInfo *result) {
@@ -257,13 +293,19 @@ bool Parser::parseIdentifierPrefixImpl(VariantValue *value) {
 
   if (tokenizer->nextTokenKind() != TokenKind::OpenParen) {
     // Parse as a named value.
-    auto namedValue =
-        namedValues ? namedValues->lookup(nameToken.text) : VariantValue();
+    if (auto namedValue = namedValues ? namedValues->lookup(nameToken.te...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/115670


More information about the Mlir-commits mailing list