[Mlir-commits] [mlir] [mlir] Improve mlir-query by adding matcher combinators (PR #141423)
Denzel-Brian Budii
llvmlistbot at llvm.org
Tue May 27 12:37:43 PDT 2025
https://github.com/chios202 updated https://github.com/llvm/llvm-project/pull/141423
>From 11792a69dc5f52f4f44dee3d9e06ece613ac3685 Mon Sep 17 00:00:00 2001
From: chios202 <chio.star at yahoo.com>
Date: Sun, 25 May 2025 17:57:30 +0000
Subject: [PATCH 1/2] Improve MLIR-Query by adding matcher combinators Limit
backward-slice with nested matching Add variadic operators
---
mlir/include/mlir/Query/Matcher/Marshallers.h | 61 ++++++++++
mlir/include/mlir/Query/Matcher/MatchFinder.h | 4 +-
.../mlir/Query/Matcher/MatchersInternal.h | 109 ++++++++++++++++-
.../mlir/Query/Matcher/SliceMatchers.h | 110 +++++++++++++++++-
.../include/mlir/Query/Matcher/VariantValue.h | 11 +-
mlir/lib/Query/Matcher/CMakeLists.txt | 1 +
mlir/lib/Query/Matcher/MatchersInternal.cpp | 36 ++++++
mlir/lib/Query/Matcher/RegistryManager.cpp | 7 +-
mlir/lib/Query/Matcher/VariantValue.cpp | 54 +++++++++
mlir/tools/mlir-query/mlir-query.cpp | 14 ++-
10 files changed, 392 insertions(+), 15 deletions(-)
create mode 100644 mlir/lib/Query/Matcher/MatchersInternal.cpp
diff --git a/mlir/include/mlir/Query/Matcher/Marshallers.h b/mlir/include/mlir/Query/Matcher/Marshallers.h
index 012bf7b9ec4a9..f81e789f274e6 100644
--- a/mlir/include/mlir/Query/Matcher/Marshallers.h
+++ b/mlir/include/mlir/Query/Matcher/Marshallers.h
@@ -108,6 +108,9 @@ class MatcherDescriptor {
const llvm::ArrayRef<ParserValue> args,
Diagnostics *error) const = 0;
+ // If the matcher is variadic, it can take any number of arguments.
+ virtual bool isVariadic() const = 0;
+
// Returns the number of arguments accepted by the matcher.
virtual unsigned getNumArgs() const = 0;
@@ -140,6 +143,8 @@ class FixedArgCountMatcherDescriptor : public MatcherDescriptor {
return marshaller(matcherFunc, matcherName, nameRange, args, error);
}
+ bool isVariadic() const override { return false; }
+
unsigned getNumArgs() const override { return argKinds.size(); }
void getArgKinds(unsigned argNo, std::vector<ArgKind> &kinds) const override {
@@ -153,6 +158,54 @@ class FixedArgCountMatcherDescriptor : public MatcherDescriptor {
const std::vector<ArgKind> argKinds;
};
+class VariadicOperatorMatcherDescriptor : public MatcherDescriptor {
+public:
+ using VarOp = DynMatcher::VariadicOperator;
+ VariadicOperatorMatcherDescriptor(unsigned minCount, unsigned maxCount,
+ VarOp varOp, StringRef matcherName)
+ : minCount(minCount), maxCount(maxCount), varOp(varOp),
+ matcherName(matcherName) {}
+
+ VariantMatcher create(SourceRange nameRange, ArrayRef<ParserValue> args,
+ Diagnostics *error) const override {
+ if (args.size() < minCount || maxCount < args.size()) {
+ addError(error, nameRange, ErrorType::RegistryWrongArgCount,
+ {llvm::Twine("requires between "), llvm::Twine(minCount),
+ llvm::Twine(" and "), llvm::Twine(maxCount),
+ llvm::Twine(" args, got "), llvm::Twine(args.size())});
+ return VariantMatcher();
+ }
+
+ std::vector<VariantMatcher> innerArgs;
+ for (size_t i = 0, e = args.size(); i != e; ++i) {
+ const ParserValue &arg = args[i];
+ const VariantValue &value = arg.value;
+ if (!value.isMatcher()) {
+ addError(error, arg.range, ErrorType::RegistryWrongArgType,
+ {llvm::Twine(i + 1), llvm::Twine("Matcher: "),
+ llvm::Twine(value.getTypeAsString())});
+ return VariantMatcher();
+ }
+ innerArgs.push_back(value.getMatcher());
+ }
+ return VariantMatcher::VariadicOperatorMatcher(varOp, std::move(innerArgs));
+ }
+
+ bool isVariadic() const override { return true; }
+
+ unsigned getNumArgs() const override { return 0; }
+
+ void getArgKinds(unsigned argNo, std::vector<ArgKind> &kinds) const override {
+ kinds.push_back(ArgKind(ArgKind::Matcher));
+ }
+
+private:
+ const unsigned minCount;
+ const unsigned maxCount;
+ const VarOp varOp;
+ const StringRef matcherName;
+};
+
// Helper function to check if argument count matches expected count
inline bool checkArgCount(SourceRange nameRange, size_t expectedArgCount,
llvm::ArrayRef<ParserValue> args,
@@ -224,6 +277,14 @@ makeMatcherAutoMarshall(ReturnType (*matcherFunc)(ArgTypes...),
reinterpret_cast<void (*)()>(matcherFunc), matcherName, argKinds);
}
+// Variadic operator overload.
+template <unsigned MinCount, unsigned MaxCount>
+std::unique_ptr<MatcherDescriptor>
+makeMatcherAutoMarshall(VariadicOperatorMatcherFunc<MinCount, MaxCount> func,
+ StringRef matcherName) {
+ return std::make_unique<VariadicOperatorMatcherDescriptor>(
+ MinCount, MaxCount, func.varOp, matcherName);
+}
} // namespace mlir::query::matcher::internal
#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H
diff --git a/mlir/include/mlir/Query/Matcher/MatchFinder.h b/mlir/include/mlir/Query/Matcher/MatchFinder.h
index f8abf20ef60bb..6d06ca13d1344 100644
--- a/mlir/include/mlir/Query/Matcher/MatchFinder.h
+++ b/mlir/include/mlir/Query/Matcher/MatchFinder.h
@@ -21,7 +21,9 @@
namespace mlir::query::matcher {
-/// A class that provides utilities to find operations in the IR.
+/// Finds and collects matches from the IR. After construction
+/// `collectMatches` can be used to traverse the IR and apply
+/// matchers.
class MatchFinder {
public:
diff --git a/mlir/include/mlir/Query/Matcher/MatchersInternal.h b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
index 183b2514e109f..a1c08417eb889 100644
--- a/mlir/include/mlir/Query/Matcher/MatchersInternal.h
+++ b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
@@ -8,11 +8,11 @@
//
// Implements the base layer of the matcher framework.
//
-// Matchers are methods that return a Matcher which provides a method one of the
-// following methods: match(Operation *op), match(Operation *op,
-// SetVector<Operation *> &matchedOps)
+// Matchers are methods that return a Matcher which provide a method
+// `match(...)` method. The method's parameters define the context of the match.
+// Support includes simple (unary) matchers as well as matcher combinators.
+// (anyOf, allOf, etc.)
//
-// The matcher functions are defined in include/mlir/IR/Matchers.h.
// This file contains the wrapper classes needed to construct matchers for
// mlir-query.
//
@@ -25,6 +25,15 @@
#include "llvm/ADT/IntrusiveRefCntPtr.h"
namespace mlir::query::matcher {
+class DynMatcher;
+namespace internal {
+
+bool allOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps,
+ ArrayRef<DynMatcher> innerMatchers);
+bool anyOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps,
+ ArrayRef<DynMatcher> innerMatchers);
+
+} // namespace internal
// Defaults to false if T has no match() method with the signature:
// match(Operation* op).
@@ -84,6 +93,26 @@ class MatcherFnImpl : public MatcherInterface {
MatcherFn matcherFn;
};
+// VariadicMatcher takes a vector of Matchers and returns true if any Matchers
+// match the given operation.
+using VariadicOperatorFunction = bool (*)(Operation *op,
+ SetVector<Operation *> *matchedOps,
+ ArrayRef<DynMatcher> innerMatchers);
+
+template <VariadicOperatorFunction Func>
+class VariadicMatcher : public MatcherInterface {
+public:
+ VariadicMatcher(std::vector<DynMatcher> matchers) : matchers(matchers) {}
+
+ bool match(Operation *op) override { return Func(op, nullptr, matchers); }
+ bool match(Operation *op, SetVector<Operation *> &matchedOps) override {
+ return Func(op, &matchedOps, matchers);
+ }
+
+private:
+ std::vector<DynMatcher> matchers;
+};
+
// Matcher wraps a MatcherInterface implementation and provides match()
// methods that redirect calls to the underlying implementation.
class DynMatcher {
@@ -92,6 +121,31 @@ class DynMatcher {
DynMatcher(MatcherInterface *implementation)
: implementation(implementation) {}
+ // Construct from a variadic function.
+ enum VariadicOperator {
+ // Matches operations for which all provided matchers match.
+ AllOf,
+ // Matches operations for which at least one of the provided matchers
+ // matches.
+ AnyOf
+ };
+
+ static std::unique_ptr<DynMatcher>
+ constructVariadic(VariadicOperator Op,
+ std::vector<DynMatcher> innerMatchers) {
+ switch (Op) {
+ case AllOf:
+ return std::make_unique<DynMatcher>(
+ new VariadicMatcher<internal::allOfVariadicOperator>(
+ std::move(innerMatchers)));
+ case AnyOf:
+ return std::make_unique<DynMatcher>(
+ new VariadicMatcher<internal::anyOfVariadicOperator>(
+ std::move(innerMatchers)));
+ }
+ llvm_unreachable("Invalid Op value.");
+ }
+
template <typename MatcherFn>
static std::unique_ptr<DynMatcher>
constructDynMatcherFromMatcherFn(MatcherFn &matcherFn) {
@@ -113,6 +167,53 @@ class DynMatcher {
std::string functionName;
};
+// VariadicOperatorMatcher related types.
+template <typename... Ps>
+class VariadicOperatorMatcher {
+public:
+ VariadicOperatorMatcher(DynMatcher::VariadicOperator varOp, Ps &&...params)
+ : varOp(varOp), params(std::forward<Ps>(params)...) {}
+
+ operator std::unique_ptr<DynMatcher>() const & {
+ return DynMatcher::constructVariadic(
+ varOp, getMatchers(std::index_sequence_for<Ps...>()));
+ }
+
+ operator std::unique_ptr<DynMatcher>() && {
+ return DynMatcher::constructVariadic(
+ varOp, std::move(*this).getMatchers(std::index_sequence_for<Ps...>()));
+ }
+
+private:
+ // Helper method to unpack the tuple into a vector.
+ template <std::size_t... Is>
+ std::vector<DynMatcher> getMatchers(std::index_sequence<Is...>) const & {
+ return {DynMatcher(std::get<Is>(params))...};
+ }
+
+ template <std::size_t... Is>
+ std::vector<DynMatcher> getMatchers(std::index_sequence<Is...>) && {
+ return {DynMatcher(std::get<Is>(std::move(params)))...};
+ }
+
+ const DynMatcher::VariadicOperator varOp;
+ std::tuple<Ps...> params;
+};
+
+// Overloaded function object to generate VariadicOperatorMatcher objects from
+// arbitrary matchers.
+template <unsigned MinCount, unsigned MaxCount>
+struct VariadicOperatorMatcherFunc {
+ DynMatcher::VariadicOperator varOp;
+
+ template <typename... Ms>
+ VariadicOperatorMatcher<Ms...> operator()(Ms &&...Ps) const {
+ static_assert(MinCount <= sizeof...(Ms) && sizeof...(Ms) <= MaxCount,
+ "invalid number of parameters for variadic matcher");
+ return VariadicOperatorMatcher<Ms...>(varOp, std::forward<Ms>(Ps)...);
+ }
+};
+
} // namespace mlir::query::matcher
#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
diff --git a/mlir/include/mlir/Query/Matcher/SliceMatchers.h b/mlir/include/mlir/Query/Matcher/SliceMatchers.h
index 5bb8251672eb7..2bc23d6479337 100644
--- a/mlir/include/mlir/Query/Matcher/SliceMatchers.h
+++ b/mlir/include/mlir/Query/Matcher/SliceMatchers.h
@@ -6,7 +6,8 @@
//
//===----------------------------------------------------------------------===//
//
-// This file provides matchers for MLIRQuery that peform slicing analysis
+// This file defines slicing-analysis matchers that extend and abstract the
+// core implementations from `SliceAnalysis.h`.
//
//===----------------------------------------------------------------------===//
@@ -15,9 +16,9 @@
#include "mlir/Analysis/SliceAnalysis.h"
-/// A matcher encapsulating `getBackwardSlice` method from SliceAnalysis.h.
-/// Additionally, it limits the slice computation to a certain depth level using
-/// a custom filter.
+/// Computes the backward-slice of all transitive defs reachable from `rootOp`,
+/// if `innerMatcher` matches. The traversal stops once the desired depth level
+/// is reached.
///
/// Example: starting from node 9, assuming the matcher
/// computes the slice for the first two depth levels:
@@ -116,6 +117,83 @@ bool BackwardSliceMatcher<Matcher>::matches(
: backwardSlice.size() >= 1;
}
+/// Computes the backward-slice of all transitive defs reachable from `rootOp`,
+/// if `innerMatcher` matches. Traversal stops where `filterMatcher` matches.
+template <typename BaseMatcher, typename Filter>
+class PredicateBackwardSliceMatcher {
+public:
+ PredicateBackwardSliceMatcher(BaseMatcher innerMatcher, Filter filterMatcher,
+ bool inclusive, bool omitBlockArguments,
+ bool omitUsesFromAbove)
+ : innerMatcher(std::move(innerMatcher)),
+ filterMatcher(std::move(filterMatcher)), inclusive(inclusive),
+ omitBlockArguments(omitBlockArguments),
+ omitUsesFromAbove(omitUsesFromAbove) {}
+
+ bool match(Operation *rootOp, SetVector<Operation *> &backwardSlice) {
+ backwardSlice.clear();
+ BackwardSliceOptions options;
+ options.inclusive = inclusive;
+ options.omitUsesFromAbove = omitUsesFromAbove;
+ options.omitBlockArguments = omitBlockArguments;
+ if (innerMatcher.match(rootOp)) {
+ options.filter = [&](Operation *subOp) {
+ return !filterMatcher.match(subOp);
+ };
+ getBackwardSlice(rootOp, &backwardSlice, options);
+ return options.inclusive ? backwardSlice.size() > 1
+ : backwardSlice.size() >= 1;
+ }
+ return false;
+ }
+
+private:
+ BaseMatcher innerMatcher;
+ Filter filterMatcher;
+ bool inclusive;
+ bool omitBlockArguments;
+ bool omitUsesFromAbove;
+};
+
+/// Computes the forward-slice of all users reachable from `rootOp`,
+/// if `innerMatcher` matches. Traversal stops where `filterMatcher` matches.
+template <typename BaseMatcher, typename Filter>
+class PredicateForwardSliceMatcher {
+public:
+ PredicateForwardSliceMatcher(BaseMatcher innerMatcher, Filter filterMatcher,
+ bool inclusive)
+ : innerMatcher(std::move(innerMatcher)),
+ filterMatcher(std::move(filterMatcher)), inclusive(inclusive) {}
+
+ bool match(Operation *rootOp, SetVector<Operation *> &forwardSlice) {
+ forwardSlice.clear();
+ ForwardSliceOptions options;
+ options.inclusive = inclusive;
+ if (innerMatcher.match(rootOp)) {
+ options.filter = [&](Operation *subOp) {
+ return !filterMatcher.match(subOp);
+ };
+ getForwardSlice(rootOp, &forwardSlice, options);
+ return options.inclusive ? forwardSlice.size() > 1
+ : forwardSlice.size() >= 1;
+ }
+ return false;
+ }
+
+private:
+ BaseMatcher innerMatcher;
+ Filter filterMatcher;
+ bool inclusive;
+};
+
+namespace internal {
+const matcher::VariadicOperatorMatcherFunc<1,
+ std::numeric_limits<unsigned>::max()>
+ anyOf = {matcher::DynMatcher::AnyOf};
+const matcher::VariadicOperatorMatcherFunc<1,
+ std::numeric_limits<unsigned>::max()>
+ allOf = {matcher::DynMatcher::AllOf};
+} // namespace internal
/// Matches transitive defs of a top-level operation up to N levels.
template <typename Matcher>
inline BackwardSliceMatcher<Matcher>
@@ -127,7 +205,7 @@ m_GetDefinitions(Matcher innerMatcher, int64_t maxDepth, bool inclusive,
omitUsesFromAbove);
}
-/// Matches all transitive defs of a top-level operation up to N levels
+/// Matches all transitive defs of a top-level operation up to N levels.
template <typename Matcher>
inline BackwardSliceMatcher<Matcher> m_GetAllDefinitions(Matcher innerMatcher,
int64_t maxDepth) {
@@ -136,6 +214,28 @@ inline BackwardSliceMatcher<Matcher> m_GetAllDefinitions(Matcher innerMatcher,
false, false);
}
+/// Matches all transitive defs of a top-level operation and stops where
+/// `filterMatcher` rejects.
+template <typename BaseMatcher, typename Filter>
+inline PredicateBackwardSliceMatcher<BaseMatcher, Filter>
+m_GetDefinitionsByPredicate(BaseMatcher innerMatcher, Filter filterMatcher,
+ bool inclusive, bool omitBlockArguments,
+ bool omitUsesFromAbove) {
+ return PredicateBackwardSliceMatcher<BaseMatcher, Filter>(
+ std::move(innerMatcher), std::move(filterMatcher), inclusive,
+ omitBlockArguments, omitUsesFromAbove);
+}
+
+/// Matches all users of a top-level operation and stops where
+/// `filterMatcher` rejects.
+template <typename BaseMatcher, typename Filter>
+inline PredicateForwardSliceMatcher<BaseMatcher, Filter>
+m_GetUsersByPredicate(BaseMatcher innerMatcher, Filter filterMatcher,
+ bool inclusive) {
+ return PredicateForwardSliceMatcher<BaseMatcher, Filter>(
+ std::move(innerMatcher), std::move(filterMatcher), inclusive);
+}
+
} // namespace mlir::query::matcher
#endif // MLIR_TOOLS_MLIRQUERY_MATCHERS_SLICEMATCHERS_H
diff --git a/mlir/include/mlir/Query/Matcher/VariantValue.h b/mlir/include/mlir/Query/Matcher/VariantValue.h
index 98c0a18e25101..1a47576de1841 100644
--- a/mlir/include/mlir/Query/Matcher/VariantValue.h
+++ b/mlir/include/mlir/Query/Matcher/VariantValue.h
@@ -26,7 +26,12 @@ enum class ArgKind { Boolean, Matcher, Signed, String };
// A variant matcher object to abstract simple and complex matchers into a
// single object type.
class VariantMatcher {
- class MatcherOps;
+ class MatcherOps {
+ public:
+ std::optional<DynMatcher>
+ constructVariadicOperator(DynMatcher::VariadicOperator varOp,
+ ArrayRef<VariantMatcher> innerMatchers) const;
+ };
// Payload interface to be specialized by each matcher type. It follows a
// similar interface as VariantMatcher itself.
@@ -43,6 +48,9 @@ class VariantMatcher {
// Clones the provided matcher.
static VariantMatcher SingleMatcher(DynMatcher matcher);
+ static VariantMatcher
+ VariadicOperatorMatcher(DynMatcher::VariadicOperator varOp,
+ ArrayRef<VariantMatcher> args);
// Makes the matcher the "null" matcher.
void reset();
@@ -61,6 +69,7 @@ class VariantMatcher {
: value(std::move(value)) {}
class SinglePayload;
+ class VariadicOpPayload;
std::shared_ptr<const Payload> value;
};
diff --git a/mlir/lib/Query/Matcher/CMakeLists.txt b/mlir/lib/Query/Matcher/CMakeLists.txt
index 629479bf7adc1..ba202762fdfbb 100644
--- a/mlir/lib/Query/Matcher/CMakeLists.txt
+++ b/mlir/lib/Query/Matcher/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_library(MLIRQueryMatcher
MatchFinder.cpp
+ MatchersInternal.cpp
Parser.cpp
RegistryManager.cpp
VariantValue.cpp
diff --git a/mlir/lib/Query/Matcher/MatchersInternal.cpp b/mlir/lib/Query/Matcher/MatchersInternal.cpp
new file mode 100644
index 0000000000000..223e6d1bdcf4f
--- /dev/null
+++ b/mlir/lib/Query/Matcher/MatchersInternal.cpp
@@ -0,0 +1,36 @@
+//===--- MatchersInternal.cpp----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implements the base layer of the matcher framework.
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Query/Matcher/MatchersInternal.h"
+#include "llvm/ADT/SetVector.h"
+
+namespace mlir::query::matcher {
+
+namespace internal {
+
+bool allOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps,
+ ArrayRef<DynMatcher> innerMatchers) {
+ return llvm::all_of(innerMatchers, [&](const DynMatcher &matcher) {
+ if (matchedOps)
+ return matcher.match(op, *matchedOps);
+ return matcher.match(op);
+ });
+}
+bool anyOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps,
+ ArrayRef<DynMatcher> innerMatchers) {
+ return llvm::any_of(innerMatchers, [&](const DynMatcher &matcher) {
+ if (matchedOps)
+ return matcher.match(op, *matchedOps);
+ return matcher.match(op);
+ });
+}
+} // namespace internal
+} // namespace mlir::query::matcher
diff --git a/mlir/lib/Query/Matcher/RegistryManager.cpp b/mlir/lib/Query/Matcher/RegistryManager.cpp
index 4b511c5f009e7..08b610453b11a 100644
--- a/mlir/lib/Query/Matcher/RegistryManager.cpp
+++ b/mlir/lib/Query/Matcher/RegistryManager.cpp
@@ -64,7 +64,7 @@ std::vector<ArgKind> RegistryManager::getAcceptedCompletionTypes(
unsigned argNumber = ctxEntry.second;
std::vector<ArgKind> nextTypeSet;
- if (argNumber < ctor->getNumArgs())
+ if (ctor->isVariadic() || argNumber < ctor->getNumArgs())
ctor->getArgKinds(argNumber, nextTypeSet);
typeSet.insert(nextTypeSet.begin(), nextTypeSet.end());
@@ -83,7 +83,7 @@ RegistryManager::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes,
const internal::MatcherDescriptor &matcher = *m.getValue();
llvm::StringRef name = m.getKey();
- unsigned numArgs = matcher.getNumArgs();
+ unsigned numArgs = matcher.isVariadic() ? 1 : matcher.getNumArgs();
std::vector<std::vector<ArgKind>> argKinds(numArgs);
for (const ArgKind &kind : acceptedTypes) {
@@ -115,6 +115,9 @@ RegistryManager::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes,
}
}
+ if (matcher.isVariadic())
+ os << ",...";
+
os << ")";
typedText += "(";
diff --git a/mlir/lib/Query/Matcher/VariantValue.cpp b/mlir/lib/Query/Matcher/VariantValue.cpp
index 1cb2d48f9d56f..61316cfd0d489 100644
--- a/mlir/lib/Query/Matcher/VariantValue.cpp
+++ b/mlir/lib/Query/Matcher/VariantValue.cpp
@@ -27,12 +27,66 @@ class VariantMatcher::SinglePayload : public VariantMatcher::Payload {
DynMatcher matcher;
};
+class VariantMatcher::VariadicOpPayload : public VariantMatcher::Payload {
+public:
+ VariadicOpPayload(DynMatcher::VariadicOperator varOp,
+ std::vector<VariantMatcher> args)
+ : varOp(varOp), args(std::move(args)) {}
+
+ std::optional<DynMatcher> getDynMatcher() const override {
+ std::vector<DynMatcher> dynMatchers;
+ for (auto variantMatcher : args) {
+ std::optional<DynMatcher> dynMatcher = variantMatcher.getDynMatcher();
+ if (dynMatcher)
+ dynMatchers.push_back(dynMatcher.value());
+ }
+ auto result = DynMatcher::constructVariadic(varOp, dynMatchers);
+ return *result;
+ }
+
+ std::string getTypeAsString() const override {
+ std::string inner;
+ for (size_t i = 0, e = args.size(); i != e; ++i) {
+ if (i != 0)
+ inner += "&";
+ inner += args[i].getTypeAsString();
+ }
+ return inner;
+ }
+
+private:
+ const DynMatcher::VariadicOperator varOp;
+ const std::vector<VariantMatcher> args;
+};
+
VariantMatcher::VariantMatcher() = default;
VariantMatcher VariantMatcher::SingleMatcher(DynMatcher matcher) {
return VariantMatcher(std::make_shared<SinglePayload>(std::move(matcher)));
}
+VariantMatcher
+VariantMatcher::VariadicOperatorMatcher(DynMatcher::VariadicOperator varOp,
+ ArrayRef<VariantMatcher> args) {
+ return VariantMatcher(
+ std::make_shared<VariadicOpPayload>(varOp, std::move(args)));
+}
+
+std::optional<DynMatcher> VariantMatcher::MatcherOps::constructVariadicOperator(
+ DynMatcher::VariadicOperator varOp,
+ ArrayRef<VariantMatcher> innerMatchers) const {
+ std::vector<DynMatcher> dynMatchers;
+ for (const auto &innerMatcher : innerMatchers) {
+ if (!innerMatcher.value)
+ return std::nullopt;
+ std::optional<DynMatcher> inner = innerMatcher.value->getDynMatcher();
+ if (!inner)
+ return std::nullopt;
+ dynMatchers.push_back(*inner);
+ }
+ return *DynMatcher::constructVariadic(varOp, dynMatchers);
+}
+
std::optional<DynMatcher> VariantMatcher::getDynMatcher() const {
return value ? value->getDynMatcher() : std::nullopt;
}
diff --git a/mlir/tools/mlir-query/mlir-query.cpp b/mlir/tools/mlir-query/mlir-query.cpp
index 78c0ec97c0cdf..8a17a33c61838 100644
--- a/mlir/tools/mlir-query/mlir-query.cpp
+++ b/mlir/tools/mlir-query/mlir-query.cpp
@@ -40,12 +40,22 @@ int main(int argc, char **argv) {
query::matcher::Registry matcherRegistry;
// Matchers registered in alphabetical order for consistency:
+ matcherRegistry.registerMatcher("allOf", query::matcher::internal::allOf);
+ matcherRegistry.registerMatcher("anyOf", query::matcher::internal::anyOf);
+ matcherRegistry.registerMatcher(
+ "getAllDefinitions",
+ query::matcher::m_GetAllDefinitions<query::matcher::DynMatcher>);
matcherRegistry.registerMatcher(
"getDefinitions",
query::matcher::m_GetDefinitions<query::matcher::DynMatcher>);
matcherRegistry.registerMatcher(
- "getAllDefinitions",
- query::matcher::m_GetAllDefinitions<query::matcher::DynMatcher>);
+ "getDefinitionsByPredicate",
+ query::matcher::m_GetDefinitionsByPredicate<query::matcher::DynMatcher,
+ query::matcher::DynMatcher>);
+ matcherRegistry.registerMatcher(
+ "getUsersByPredicate",
+ query::matcher::m_GetUsersByPredicate<query::matcher::DynMatcher,
+ query::matcher::DynMatcher>);
matcherRegistry.registerMatcher("hasOpAttrName",
static_cast<HasOpAttrName *>(m_Attr));
matcherRegistry.registerMatcher("hasOpName", static_cast<HasOpName *>(m_Op));
>From 37682e7b16457f8f6b9ce84d8d40f3d64385bda3 Mon Sep 17 00:00:00 2001
From: chios202 <chio.star at yahoo.com>
Date: Tue, 27 May 2025 16:04:16 +0000
Subject: [PATCH 2/2] Add test cases Add test cases for variadic matchers
Relocate variadic matchers
---
.../mlir/Query/Matcher/MatchersInternal.h | 6 +++++
.../mlir/Query/Matcher/SliceMatchers.h | 8 ------
mlir/lib/Query/Matcher/MatchersInternal.cpp | 5 +---
mlir/lib/Query/Matcher/VariantValue.cpp | 8 +++---
...ex-test.mlir => backward-slice-union.mlir} | 13 +++++++--
.../forward-slice-by-predicate.mlir | 27 +++++++++++++++++++
.../mlir-query/logical-operator-test.mlir | 12 +++++++++
7 files changed, 60 insertions(+), 19 deletions(-)
rename mlir/test/mlir-query/{complex-test.mlir => backward-slice-union.mlir} (71%)
create mode 100644 mlir/test/mlir-query/forward-slice-by-predicate.mlir
create mode 100644 mlir/test/mlir-query/logical-operator-test.mlir
diff --git a/mlir/include/mlir/Query/Matcher/MatchersInternal.h b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
index a1c08417eb889..482b2393e27eb 100644
--- a/mlir/include/mlir/Query/Matcher/MatchersInternal.h
+++ b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
@@ -214,6 +214,12 @@ struct VariadicOperatorMatcherFunc {
}
};
+namespace internal {
+const VariadicOperatorMatcherFunc<1, std::numeric_limits<unsigned>::max()>
+ anyOf = {DynMatcher::AnyOf};
+const VariadicOperatorMatcherFunc<1, std::numeric_limits<unsigned>::max()>
+ allOf = {DynMatcher::AllOf};
+} // namespace internal
} // namespace mlir::query::matcher
#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
diff --git a/mlir/include/mlir/Query/Matcher/SliceMatchers.h b/mlir/include/mlir/Query/Matcher/SliceMatchers.h
index 2bc23d6479337..d9dba2d106854 100644
--- a/mlir/include/mlir/Query/Matcher/SliceMatchers.h
+++ b/mlir/include/mlir/Query/Matcher/SliceMatchers.h
@@ -186,14 +186,6 @@ class PredicateForwardSliceMatcher {
bool inclusive;
};
-namespace internal {
-const matcher::VariadicOperatorMatcherFunc<1,
- std::numeric_limits<unsigned>::max()>
- anyOf = {matcher::DynMatcher::AnyOf};
-const matcher::VariadicOperatorMatcherFunc<1,
- std::numeric_limits<unsigned>::max()>
- allOf = {matcher::DynMatcher::AllOf};
-} // namespace internal
/// Matches transitive defs of a top-level operation up to N levels.
template <typename Matcher>
inline BackwardSliceMatcher<Matcher>
diff --git a/mlir/lib/Query/Matcher/MatchersInternal.cpp b/mlir/lib/Query/Matcher/MatchersInternal.cpp
index 223e6d1bdcf4f..01f412ade846b 100644
--- a/mlir/lib/Query/Matcher/MatchersInternal.cpp
+++ b/mlir/lib/Query/Matcher/MatchersInternal.cpp
@@ -5,10 +5,7 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
-//
-// Implements the base layer of the matcher framework.
-//
-//===----------------------------------------------------------------------===//
+
#include "mlir/Query/Matcher/MatchersInternal.h"
#include "llvm/ADT/SetVector.h"
diff --git a/mlir/lib/Query/Matcher/VariantValue.cpp b/mlir/lib/Query/Matcher/VariantValue.cpp
index 61316cfd0d489..7bf4774dba830 100644
--- a/mlir/lib/Query/Matcher/VariantValue.cpp
+++ b/mlir/lib/Query/Matcher/VariantValue.cpp
@@ -46,11 +46,9 @@ class VariantMatcher::VariadicOpPayload : public VariantMatcher::Payload {
std::string getTypeAsString() const override {
std::string inner;
- for (size_t i = 0, e = args.size(); i != e; ++i) {
- if (i != 0)
- inner += "&";
- inner += args[i].getTypeAsString();
- }
+ llvm::interleave(
+ args, [&](auto const &arg) { inner += arg.getTypeAsString(); },
+ [&] { inner += " & "; });
return inner;
}
diff --git a/mlir/test/mlir-query/complex-test.mlir b/mlir/test/mlir-query/backward-slice-union.mlir
similarity index 71%
rename from mlir/test/mlir-query/complex-test.mlir
rename to mlir/test/mlir-query/backward-slice-union.mlir
index ad96f03747a43..f8f88c2043749 100644
--- a/mlir/test/mlir-query/complex-test.mlir
+++ b/mlir/test/mlir-query/backward-slice-union.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-query %s -c "m getAllDefinitions(hasOpName(\"arith.addf\"),2)" | FileCheck %s
+// RUN: mlir-query %s -c "m anyOf(getAllDefinitions(hasOpName(\"arith.addf\"),2),getAllDefinitions(hasOpName(\"tensor.extract\"),1))" | FileCheck %s
#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @slice_use_from_above(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>) {
@@ -19,14 +19,23 @@ func.func @slice_use_from_above(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>)
}
// CHECK: Match #1:
-
// CHECK: %[[LINALG:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]}
// CHECK-SAME: ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>)
+
+// CHECK: {{.*}}.mlir:7:10: note: "root" binds here
// CHECK: %[[ADDF1:.*]] = arith.addf %in, %in : f32
// CHECK: Match #2:
+// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[LINALG]] {{\[\[.*\]\]}} : tensor<5x5xf32> into tensor<25xf32>
+// CHECK: %[[C2:.*]] = arith.constant {{.*}} : index
+// CHECK: {{.*}}.mlir:14:18: note: "root" binds here
+// CHECK: %[[EXTRACTED:.*]] = tensor.extract %[[COLLAPSED]][%[[C2]]] : tensor<25xf32>
+
+// CHECK: Match #3:
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[LINALG]] {{\[\[.*\]\]}} : tensor<5x5xf32> into tensor<25xf32>
// CHECK: %[[C2:.*]] = arith.constant {{.*}} : index
// CHECK: %[[EXTRACTED:.*]] = tensor.extract %[[COLLAPSED]][%[[C2]]] : tensor<25xf32>
+
+// CHECK: {{.*}}.mlir:15:10: note: "root" binds here
// CHECK: %[[ADDF2:.*]] = arith.addf %[[EXTRACTED]], %[[EXTRACTED]] : f32
diff --git a/mlir/test/mlir-query/forward-slice-by-predicate.mlir b/mlir/test/mlir-query/forward-slice-by-predicate.mlir
new file mode 100644
index 0000000000000..2ff70a08e0590
--- /dev/null
+++ b/mlir/test/mlir-query/forward-slice-by-predicate.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-query %s -c "m getUsersByPredicate(anyOf(hasOpName(\"memref.alloc\"),isConstantOp()),hasOpName(\"affine.load\"),true)" | FileCheck %s
+
+func.func @slice_depth1_loop_nest_with_offsets() {
+ %0 = memref.alloc() : memref<100xf32>
+ %cst = arith.constant 7.000000e+00 : f32
+ affine.for %i0 = 0 to 16 {
+ %a0 = affine.apply affine_map<(d0) -> (d0 + 2)>(%i0)
+ affine.store %cst, %0[%a0] : memref<100xf32>
+ }
+ affine.for %i1 = 4 to 8 {
+ %a1 = affine.apply affine_map<(d0) -> (d0 - 1)>(%i1)
+ %1 = affine.load %0[%a1] : memref<100xf32>
+ }
+ return
+}
+
+// CHECK: Match #1:
+// CHECK: {{.*}}.mlir:4:8: note: "root" binds here
+// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<100xf32>
+
+// CHECK: affine.store %cst, %0[%a0] : memref<100xf32>
+
+// CHECK: Match #2:
+// CHECK: {{.*}}.mlir:5:10: note: "root" binds here
+// CHECK: %[[CST:.*]] = arith.constant 7.000000e+00 : f32
+
+// CHECK: affine.store %[[CST]], %0[%a0] : memref<100xf32>
diff --git a/mlir/test/mlir-query/logical-operator-test.mlir b/mlir/test/mlir-query/logical-operator-test.mlir
new file mode 100644
index 0000000000000..b63d3d180d27f
--- /dev/null
+++ b/mlir/test/mlir-query/logical-operator-test.mlir
@@ -0,0 +1,12 @@
+
+// RUN: mlir-query %s -c "m allOf(hasOpName(\"memref.alloca\"), hasOpAttrName(\"alignment\"))" | FileCheck %s
+
+func.func @dynamic_alloca(%arg0: index, %arg1: index) -> memref<?x?xf32> {
+ %0 = memref.alloca(%arg0, %arg1) : memref<?x?xf32>
+ memref.alloca(%arg0, %arg1) {alignment = 32} : memref<?x?xf32>
+ return %0 : memref<?x?xf32>
+}
+
+// CHECK: Match #1:
+// CHECK: {{.*}}.mlir:6:3: note: "root" binds here
+// CHECK: memref.alloca(%arg0, %arg1) {alignment = 32} : memref<?x?xf32>
More information about the Mlir-commits
mailing list