[Mlir-commits] [mlir] Revert "[mlir] Improve mlir-query by adding matcher combinators" (PR #145534)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jun 24 08:38:11 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir
Author: Qinkun Bao (qinkunbao)
<details>
<summary>Changes</summary>
Reverts llvm/llvm-project#141423
---
Patch is 31.04 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/145534.diff
15 Files Affected:
- (modified) mlir/include/mlir/Query/Matcher/Marshallers.h (-61)
- (modified) mlir/include/mlir/Query/Matcher/MatchFinder.h (+1-3)
- (modified) mlir/include/mlir/Query/Matcher/MatchersInternal.h (+4-112)
- (modified) mlir/include/mlir/Query/Matcher/SliceMatchers.h (+5-99)
- (modified) mlir/include/mlir/Query/Matcher/VariantValue.h (+1-10)
- (modified) mlir/lib/Query/Matcher/CMakeLists.txt (-1)
- (removed) mlir/lib/Query/Matcher/MatchersInternal.cpp (-33)
- (modified) mlir/lib/Query/Matcher/RegistryManager.cpp (+2-5)
- (modified) mlir/lib/Query/Matcher/VariantValue.cpp (-52)
- (modified) mlir/lib/Query/Query.cpp (-5)
- (renamed) mlir/test/mlir-query/complex-test.mlir (+2-11)
- (removed) mlir/test/mlir-query/forward-slice-by-predicate.mlir (-27)
- (removed) mlir/test/mlir-query/logical-operator-test.mlir (-11)
- (removed) mlir/test/mlir-query/slice-function-extraction.mlir (-29)
- (modified) mlir/tools/mlir-query/mlir-query.cpp (+2-12)
``````````diff
diff --git a/mlir/include/mlir/Query/Matcher/Marshallers.h b/mlir/include/mlir/Query/Matcher/Marshallers.h
index 5fe6965f32efb..012bf7b9ec4a9 100644
--- a/mlir/include/mlir/Query/Matcher/Marshallers.h
+++ b/mlir/include/mlir/Query/Matcher/Marshallers.h
@@ -108,9 +108,6 @@ class MatcherDescriptor {
const llvm::ArrayRef<ParserValue> args,
Diagnostics *error) const = 0;
- // If the matcher is variadic, it can take any number of arguments.
- virtual bool isVariadic() const = 0;
-
// Returns the number of arguments accepted by the matcher.
virtual unsigned getNumArgs() const = 0;
@@ -143,8 +140,6 @@ class FixedArgCountMatcherDescriptor : public MatcherDescriptor {
return marshaller(matcherFunc, matcherName, nameRange, args, error);
}
- bool isVariadic() const override { return false; }
-
unsigned getNumArgs() const override { return argKinds.size(); }
void getArgKinds(unsigned argNo, std::vector<ArgKind> &kinds) const override {
@@ -158,54 +153,6 @@ class FixedArgCountMatcherDescriptor : public MatcherDescriptor {
const std::vector<ArgKind> argKinds;
};
-class VariadicOperatorMatcherDescriptor : public MatcherDescriptor {
-public:
- using VarOp = DynMatcher::VariadicOperator;
- VariadicOperatorMatcherDescriptor(unsigned minCount, unsigned maxCount,
- VarOp varOp, StringRef matcherName)
- : minCount(minCount), maxCount(maxCount), varOp(varOp),
- matcherName(matcherName) {}
-
- VariantMatcher create(SourceRange nameRange, ArrayRef<ParserValue> args,
- Diagnostics *error) const override {
- if (args.size() < minCount || maxCount < args.size()) {
- addError(error, nameRange, ErrorType::RegistryWrongArgCount,
- {llvm::Twine("requires between "), llvm::Twine(minCount),
- llvm::Twine(" and "), llvm::Twine(maxCount),
- llvm::Twine(" args, got "), llvm::Twine(args.size())});
- return VariantMatcher();
- }
-
- std::vector<VariantMatcher> innerArgs;
- for (int64_t i = 0, e = args.size(); i != e; ++i) {
- const ParserValue &arg = args[i];
- const VariantValue &value = arg.value;
- if (!value.isMatcher()) {
- addError(error, arg.range, ErrorType::RegistryWrongArgType,
- {llvm::Twine(i + 1), llvm::Twine("matcher: "),
- llvm::Twine(value.getTypeAsString())});
- return VariantMatcher();
- }
- innerArgs.push_back(value.getMatcher());
- }
- return VariantMatcher::VariadicOperatorMatcher(varOp, std::move(innerArgs));
- }
-
- bool isVariadic() const override { return true; }
-
- unsigned getNumArgs() const override { return 0; }
-
- void getArgKinds(unsigned argNo, std::vector<ArgKind> &kinds) const override {
- kinds.push_back(ArgKind(ArgKind::Matcher));
- }
-
-private:
- const unsigned minCount;
- const unsigned maxCount;
- const VarOp varOp;
- const StringRef matcherName;
-};
-
// Helper function to check if argument count matches expected count
inline bool checkArgCount(SourceRange nameRange, size_t expectedArgCount,
llvm::ArrayRef<ParserValue> args,
@@ -277,14 +224,6 @@ makeMatcherAutoMarshall(ReturnType (*matcherFunc)(ArgTypes...),
reinterpret_cast<void (*)()>(matcherFunc), matcherName, argKinds);
}
-// Variadic operator overload.
-template <unsigned MinCount, unsigned MaxCount>
-std::unique_ptr<MatcherDescriptor>
-makeMatcherAutoMarshall(VariadicOperatorMatcherFunc<MinCount, MaxCount> func,
- StringRef matcherName) {
- return std::make_unique<VariadicOperatorMatcherDescriptor>(
- MinCount, MaxCount, func.varOp, matcherName);
-}
} // namespace mlir::query::matcher::internal
#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H
diff --git a/mlir/include/mlir/Query/Matcher/MatchFinder.h b/mlir/include/mlir/Query/Matcher/MatchFinder.h
index 6d06ca13d1344..f8abf20ef60bb 100644
--- a/mlir/include/mlir/Query/Matcher/MatchFinder.h
+++ b/mlir/include/mlir/Query/Matcher/MatchFinder.h
@@ -21,9 +21,7 @@
namespace mlir::query::matcher {
-/// Finds and collects matches from the IR. After construction
-/// `collectMatches` can be used to traverse the IR and apply
-/// matchers.
+/// A class that provides utilities to find operations in the IR.
class MatchFinder {
public:
diff --git a/mlir/include/mlir/Query/Matcher/MatchersInternal.h b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
index 88109430b6feb..183b2514e109f 100644
--- a/mlir/include/mlir/Query/Matcher/MatchersInternal.h
+++ b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
@@ -8,11 +8,11 @@
//
// Implements the base layer of the matcher framework.
//
-// Matchers are methods that return a Matcher which provides a
-// `match(...)` method whose parameters define the context of the match.
-// Support includes simple (unary) matchers as well as matcher combinators
-// (anyOf, allOf, etc.)
+// Matchers are methods that return a Matcher which provides a method one of the
+// following methods: match(Operation *op), match(Operation *op,
+// SetVector<Operation *> &matchedOps)
//
+// The matcher functions are defined in include/mlir/IR/Matchers.h.
// This file contains the wrapper classes needed to construct matchers for
// mlir-query.
//
@@ -25,15 +25,6 @@
#include "llvm/ADT/IntrusiveRefCntPtr.h"
namespace mlir::query::matcher {
-class DynMatcher;
-namespace internal {
-
-bool allOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps,
- ArrayRef<DynMatcher> innerMatchers);
-bool anyOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps,
- ArrayRef<DynMatcher> innerMatchers);
-
-} // namespace internal
// Defaults to false if T has no match() method with the signature:
// match(Operation* op).
@@ -93,27 +84,6 @@ class MatcherFnImpl : public MatcherInterface {
MatcherFn matcherFn;
};
-// VariadicMatcher takes a vector of Matchers and returns true if any Matchers
-// match the given operation.
-using VariadicOperatorFunction = bool (*)(Operation *op,
- SetVector<Operation *> *matchedOps,
- ArrayRef<DynMatcher> innerMatchers);
-
-template <VariadicOperatorFunction Func>
-class VariadicMatcher : public MatcherInterface {
-public:
- VariadicMatcher(std::vector<DynMatcher> matchers)
- : matchers(std::move(matchers)) {}
-
- bool match(Operation *op) override { return Func(op, nullptr, matchers); }
- bool match(Operation *op, SetVector<Operation *> &matchedOps) override {
- return Func(op, &matchedOps, matchers);
- }
-
-private:
- std::vector<DynMatcher> matchers;
-};
-
// Matcher wraps a MatcherInterface implementation and provides match()
// methods that redirect calls to the underlying implementation.
class DynMatcher {
@@ -122,31 +92,6 @@ class DynMatcher {
DynMatcher(MatcherInterface *implementation)
: implementation(implementation) {}
- // Construct from a variadic function.
- enum VariadicOperator {
- // Matches operations for which all provided matchers match.
- AllOf,
- // Matches operations for which at least one of the provided matchers
- // matches.
- AnyOf
- };
-
- static std::unique_ptr<DynMatcher>
- constructVariadic(VariadicOperator Op,
- std::vector<DynMatcher> innerMatchers) {
- switch (Op) {
- case AllOf:
- return std::make_unique<DynMatcher>(
- new VariadicMatcher<internal::allOfVariadicOperator>(
- std::move(innerMatchers)));
- case AnyOf:
- return std::make_unique<DynMatcher>(
- new VariadicMatcher<internal::anyOfVariadicOperator>(
- std::move(innerMatchers)));
- }
- llvm_unreachable("Invalid Op value.");
- }
-
template <typename MatcherFn>
static std::unique_ptr<DynMatcher>
constructDynMatcherFromMatcherFn(MatcherFn &matcherFn) {
@@ -168,59 +113,6 @@ class DynMatcher {
std::string functionName;
};
-// VariadicOperatorMatcher related types.
-template <typename... Ps>
-class VariadicOperatorMatcher {
-public:
- VariadicOperatorMatcher(DynMatcher::VariadicOperator varOp, Ps &&...params)
- : varOp(varOp), params(std::forward<Ps>(params)...) {}
-
- operator std::unique_ptr<DynMatcher>() const & {
- return DynMatcher::constructVariadic(
- varOp, getMatchers(std::index_sequence_for<Ps...>()));
- }
-
- operator std::unique_ptr<DynMatcher>() && {
- return DynMatcher::constructVariadic(
- varOp, std::move(*this).getMatchers(std::index_sequence_for<Ps...>()));
- }
-
-private:
- // Helper method to unpack the tuple into a vector.
- template <std::size_t... Is>
- std::vector<DynMatcher> getMatchers(std::index_sequence<Is...>) const & {
- return {DynMatcher(std::get<Is>(params))...};
- }
-
- template <std::size_t... Is>
- std::vector<DynMatcher> getMatchers(std::index_sequence<Is...>) && {
- return {DynMatcher(std::get<Is>(std::move(params)))...};
- }
-
- const DynMatcher::VariadicOperator varOp;
- std::tuple<Ps...> params;
-};
-
-// Overloaded function object to generate VariadicOperatorMatcher objects from
-// arbitrary matchers.
-template <unsigned MinCount, unsigned MaxCount>
-struct VariadicOperatorMatcherFunc {
- DynMatcher::VariadicOperator varOp;
-
- template <typename... Ms>
- VariadicOperatorMatcher<Ms...> operator()(Ms &&...Ps) const {
- static_assert(MinCount <= sizeof...(Ms) && sizeof...(Ms) <= MaxCount,
- "invalid number of parameters for variadic matcher");
- return VariadicOperatorMatcher<Ms...>(varOp, std::forward<Ms>(Ps)...);
- }
-};
-
-namespace internal {
-const VariadicOperatorMatcherFunc<1, std::numeric_limits<unsigned>::max()>
- anyOf = {DynMatcher::AnyOf};
-const VariadicOperatorMatcherFunc<1, std::numeric_limits<unsigned>::max()>
- allOf = {DynMatcher::AllOf};
-} // namespace internal
} // namespace mlir::query::matcher
#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
diff --git a/mlir/include/mlir/Query/Matcher/SliceMatchers.h b/mlir/include/mlir/Query/Matcher/SliceMatchers.h
index 7181648f06f89..441205b3a9615 100644
--- a/mlir/include/mlir/Query/Matcher/SliceMatchers.h
+++ b/mlir/include/mlir/Query/Matcher/SliceMatchers.h
@@ -6,8 +6,7 @@
//
//===----------------------------------------------------------------------===//
//
-// This file defines slicing-analysis matchers that extend and abstract the
-// core implementations from `SliceAnalysis.h`.
+// This file provides matchers for MLIRQuery that peform slicing analysis
//
//===----------------------------------------------------------------------===//
@@ -17,9 +16,9 @@
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/IR/Operation.h"
-/// Computes the backward-slice of all transitive defs reachable from `rootOp`,
-/// if `innerMatcher` matches. The traversal stops once the desired depth level
-/// is reached.
+/// A matcher encapsulating `getBackwardSlice` method from SliceAnalysis.h.
+/// Additionally, it limits the slice computation to a certain depth level using
+/// a custom filter.
///
/// Example: starting from node 9, assuming the matcher
/// computes the slice for the first two depth levels:
@@ -120,77 +119,6 @@ bool BackwardSliceMatcher<Matcher>::matches(
: backwardSlice.size() >= 1;
}
-/// Computes the backward-slice of all transitive defs reachable from `rootOp`,
-/// if `innerMatcher` matches. Traversal stops where `filterMatcher` matches.
-template <typename BaseMatcher, typename Filter>
-class PredicateBackwardSliceMatcher {
-public:
- PredicateBackwardSliceMatcher(BaseMatcher innerMatcher, Filter filterMatcher,
- bool inclusive, bool omitBlockArguments,
- bool omitUsesFromAbove)
- : innerMatcher(std::move(innerMatcher)),
- filterMatcher(std::move(filterMatcher)), inclusive(inclusive),
- omitBlockArguments(omitBlockArguments),
- omitUsesFromAbove(omitUsesFromAbove) {}
-
- bool match(Operation *rootOp, SetVector<Operation *> &backwardSlice) {
- backwardSlice.clear();
- BackwardSliceOptions options;
- options.inclusive = inclusive;
- options.omitUsesFromAbove = omitUsesFromAbove;
- options.omitBlockArguments = omitBlockArguments;
- if (innerMatcher.match(rootOp)) {
- options.filter = [&](Operation *subOp) {
- return !filterMatcher.match(subOp);
- };
- LogicalResult result = getBackwardSlice(rootOp, &backwardSlice, options);
- assert(result.succeeded() && "expected backward slice to succeed");
- (void)result;
- return options.inclusive ? backwardSlice.size() > 1
- : backwardSlice.size() >= 1;
- }
- return false;
- }
-
-private:
- BaseMatcher innerMatcher;
- Filter filterMatcher;
- bool inclusive;
- bool omitBlockArguments;
- bool omitUsesFromAbove;
-};
-
-/// Computes the forward-slice of all users reachable from `rootOp`,
-/// if `innerMatcher` matches. Traversal stops where `filterMatcher` matches.
-template <typename BaseMatcher, typename Filter>
-class PredicateForwardSliceMatcher {
-public:
- PredicateForwardSliceMatcher(BaseMatcher innerMatcher, Filter filterMatcher,
- bool inclusive)
- : innerMatcher(std::move(innerMatcher)),
- filterMatcher(std::move(filterMatcher)), inclusive(inclusive) {}
-
- bool match(Operation *rootOp, SetVector<Operation *> &forwardSlice) {
- forwardSlice.clear();
- ForwardSliceOptions options;
- options.inclusive = inclusive;
- if (innerMatcher.match(rootOp)) {
- options.filter = [&](Operation *subOp) {
- return !filterMatcher.match(subOp);
- };
- getForwardSlice(rootOp, &forwardSlice, options);
- return options.inclusive ? forwardSlice.size() > 1
- : forwardSlice.size() >= 1;
- }
- return false;
- }
-
-private:
- BaseMatcher innerMatcher;
- Filter filterMatcher;
- bool inclusive;
-};
-
/// Matches transitive defs of a top-level operation up to N levels.
template <typename Matcher>
inline BackwardSliceMatcher<Matcher>
@@ -202,7 +130,7 @@ m_GetDefinitions(Matcher innerMatcher, int64_t maxDepth, bool inclusive,
omitUsesFromAbove);
}
-/// Matches all transitive defs of a top-level operation up to N levels.
+/// Matches all transitive defs of a top-level operation up to N levels
template <typename Matcher>
inline BackwardSliceMatcher<Matcher> m_GetAllDefinitions(Matcher innerMatcher,
int64_t maxDepth) {
@@ -211,28 +139,6 @@ inline BackwardSliceMatcher<Matcher> m_GetAllDefinitions(Matcher innerMatcher,
false, false);
}
-/// Matches all transitive defs of a top-level operation and stops where
-/// `filterMatcher` rejects.
-template <typename BaseMatcher, typename Filter>
-inline PredicateBackwardSliceMatcher<BaseMatcher, Filter>
-m_GetDefinitionsByPredicate(BaseMatcher innerMatcher, Filter filterMatcher,
- bool inclusive, bool omitBlockArguments,
- bool omitUsesFromAbove) {
- return PredicateBackwardSliceMatcher<BaseMatcher, Filter>(
- std::move(innerMatcher), std::move(filterMatcher), inclusive,
- omitBlockArguments, omitUsesFromAbove);
-}
-
-/// Matches all users of a top-level operation and stops where
-/// `filterMatcher` rejects.
-template <typename BaseMatcher, typename Filter>
-inline PredicateForwardSliceMatcher<BaseMatcher, Filter>
-m_GetUsersByPredicate(BaseMatcher innerMatcher, Filter filterMatcher,
- bool inclusive) {
- return PredicateForwardSliceMatcher<BaseMatcher, Filter>(
- std::move(innerMatcher), std::move(filterMatcher), inclusive);
-}
-
} // namespace mlir::query::matcher
#endif // MLIR_TOOLS_MLIRQUERY_MATCHERS_SLICEMATCHERS_H
diff --git a/mlir/include/mlir/Query/Matcher/VariantValue.h b/mlir/include/mlir/Query/Matcher/VariantValue.h
index 1a47576de1841..98c0a18e25101 100644
--- a/mlir/include/mlir/Query/Matcher/VariantValue.h
+++ b/mlir/include/mlir/Query/Matcher/VariantValue.h
@@ -26,12 +26,7 @@ enum class ArgKind { Boolean, Matcher, Signed, String };
// A variant matcher object to abstract simple and complex matchers into a
// single object type.
class VariantMatcher {
- class MatcherOps {
- public:
- std::optional<DynMatcher>
- constructVariadicOperator(DynMatcher::VariadicOperator varOp,
- ArrayRef<VariantMatcher> innerMatchers) const;
- };
+ class MatcherOps;
// Payload interface to be specialized by each matcher type. It follows a
// similar interface as VariantMatcher itself.
@@ -48,9 +43,6 @@ class VariantMatcher {
// Clones the provided matcher.
static VariantMatcher SingleMatcher(DynMatcher matcher);
- static VariantMatcher
- VariadicOperatorMatcher(DynMatcher::VariadicOperator varOp,
- ArrayRef<VariantMatcher> args);
// Makes the matcher the "null" matcher.
void reset();
@@ -69,7 +61,6 @@ class VariantMatcher {
: value(std::move(value)) {}
class SinglePayload;
- class VariadicOpPayload;
std::shared_ptr<const Payload> value;
};
diff --git a/mlir/lib/Query/Matcher/CMakeLists.txt b/mlir/lib/Query/Matcher/CMakeLists.txt
index ba202762fdfbb..629479bf7adc1 100644
--- a/mlir/lib/Query/Matcher/CMakeLists.txt
+++ b/mlir/lib/Query/Matcher/CMakeLists.txt
@@ -1,6 +1,5 @@
add_mlir_library(MLIRQueryMatcher
MatchFinder.cpp
- MatchersInternal.cpp
Parser.cpp
RegistryManager.cpp
VariantValue.cpp
diff --git a/mlir/lib/Query/Matcher/MatchersInternal.cpp b/mlir/lib/Query/Matcher/MatchersInternal.cpp
deleted file mode 100644
index 01f412ade846b..0000000000000
--- a/mlir/lib/Query/Matcher/MatchersInternal.cpp
+++ /dev/null
@@ -1,33 +0,0 @@
-//===--- MatchersInternal.cpp----------------------------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Query/Matcher/MatchersInternal.h"
-#include "llvm/ADT/SetVector.h"
-
-namespace mlir::query::matcher {
-
-namespace internal {
-
-bool allOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps,
- ArrayRef<DynMatcher> innerMatchers) {
- return llvm::all_of(innerMatchers, [&](const DynMatcher &matcher) {
- if (matchedOps)
- return matcher.match(op, *matchedOps);
- return matcher.match(op);
- });
-}
-bool anyOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps,
- ArrayRef<DynMatcher> innerMatchers) {
- return llvm::any_of(innerMatchers, [&](const DynMatcher &matcher) {
- if (matchedOps)
- return matcher.match(op, *matchedOps);
- return matcher.match(op);
- });
-}
-} // namespace internal
-} // namespace mlir::query::matcher
diff --git a/mlir/lib/Query/Matcher/RegistryManager.cpp b/mlir/lib/Query/Matcher/RegistryManager.cpp
index 08b610453b11a..4b511c5f009e7 100644
--- a/mlir/lib/Query/Matcher/RegistryManager.cpp
+++ b/mlir/lib/Query/Matcher/RegistryManager.cpp
@@ -64,7 +64,7 @@ std::vector<ArgKind> RegistryManager::getAcceptedCompletionTypes(
unsigned argNumber = ctxEntry.second;
std::vector<ArgKind> nextTypeSet;
- if (ctor->isVariadic() || argNumber < ctor->getNumArgs())
+ if (argNumber < ctor->getNumArgs())
ctor->getArgKinds(argNumber, nextTypeSet);
typeSet.insert(nextTypeSet.begin(), nextTypeSet.end());
@@ -83,7 +83,7 @@ RegistryManager::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes,
const internal::MatcherDescriptor &matcher = *m.getValue();
llvm::StringRef name = m.getKey();
- unsigned numArgs = matcher.isVariadic() ? 1 : matcher.getNumArgs();
+ unsigned numArgs = matcher.getNumArgs();
std::vector<std::vector<ArgKind>> argKinds(numArgs);
for (const ArgKind &kind : acceptedTypes) {
@@ -115,9 +115,6 @@ RegistryManager::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes,
}
...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/145534
More information about the Mlir-commits mailing list