[Mlir-commits] [mlir] [mlir] introduce transform.collect_matching (PR #76724)
Matthias Springer
llvmlistbot at llvm.org
Tue Jan 2 07:39:30 PST 2024
================
@@ -822,6 +823,137 @@ matchBlock(Block &block, Operation *op, transform::TransformState &state,
return DiagnosedSilenceableFailure::success();
}
+/// Checks whether `t1` and `t2` have a common interface among `Tys`, i.e.
+/// whether there is at least one `Ty` in the pack such that both `isa<Ty>(t1)`
+/// and `isa<Ty>(t2)` hold. The interfaces are tested left to right and the
+/// check stops at the first match. An empty pack yields `false`.
+template <typename... Tys>
+static bool implementSameInterface(Type t1, Type t2) {
+  // Unary left fold over `||`: for an empty pack this folds to `false`.
+  return (... || (isa<Tys>(t1) && isa<Tys>(t2)));
+}
+
+/// Checks whether `t1` and `t2` implement the same transform dialect type
+/// interface: either both implement `TransformHandleTypeInterface`, or both
+/// implement `TransformParamTypeInterface`, or both implement
+/// `TransformValueHandleTypeInterface`.
+static bool implementSameTransformInterface(Type t1, Type t2) {
+  return implementSameInterface<
+      transform::TransformHandleTypeInterface,
+      transform::TransformParamTypeInterface,
+      transform::TransformValueHandleTypeInterface>(t1, t2);
+}
+
+//===----------------------------------------------------------------------===//
+// CollectMatchingOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ auto matcher = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
+ getOperation(), getMatcher());
+ if (matcher.isExternal()) {
+ return emitDefiniteFailure()
+ << "unresolved external symbol " << getMatcher();
+ }
+
+ SmallVector<SmallVector<MappedValue>, 2> rawResults;
+ rawResults.resize(getOperation()->getNumResults());
+ std::optional<DiagnosedSilenceableFailure> maybeFailure;
+ for (Operation *root : state.getPayloadOps(getRoot())) {
+ WalkResult walkResult = root->walk([&](Operation *op) {
+ DEBUG_MATCHER({
+ DBGS_MATCHER() << "matching ";
+ op->print(llvm::dbgs(),
+ OpPrintingFlags().assumeVerified().skipRegions());
+ llvm::dbgs() << " @" << op << "\n";
+ });
+
+ // Try matching.
+ SmallVector<SmallVector<MappedValue>> mappings;
+ DiagnosedSilenceableFailure diag =
+ matchBlock(matcher.getFunctionBody().front(), op, state, mappings);
+ if (diag.isDefiniteFailure())
+ return WalkResult::interrupt();
+ if (diag.isSilenceableFailure()) {
+ DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName()
+ << " failed: " << diag.getMessage());
+ return WalkResult::advance();
+ }
+
+ // If succeeded, collect results.
+ for (auto &&[i, mapping] : llvm::enumerate(mappings)) {
+ if (mapping.size() != 1) {
+ maybeFailure.emplace(emitSilenceableError()
+ << "result #" << i << ", associated with "
+ << mapping.size()
+ << " payload objects, expected 1");
+ return WalkResult::interrupt();
+ }
+ rawResults[i].push_back(mapping[0]);
+ }
+ return WalkResult::advance();
+ });
+ if (walkResult.wasInterrupted())
+ return std::move(*maybeFailure);
+ assert(!maybeFailure && "failure set but the walk was not interrupted");
+
+ for (auto &&[opResult, rawResult] :
+ llvm::zip(getOperation()->getResults(), rawResults)) {
----------------
matthias-springer wrote:
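nit: use `llvm::zip_equal` instead of `llvm::zip` here; it asserts that both ranges have the same length instead of silently stopping at the shorter one.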
https://github.com/llvm/llvm-project/pull/76724
More information about the Mlir-commits mailing list