[Mlir-commits] [mlir] [mlir] introduce transform.collect_matching (PR #76724)
Oleksandr Alex Zinenko
llvmlistbot at llvm.org
Mon Jan 8 05:10:14 PST 2024
================
@@ -822,6 +823,137 @@ matchBlock(Block &block, Operation *op, transform::TransformState &state,
return DiagnosedSilenceableFailure::success();
}
+/// Checks whether `t1` and `t2` share an interface from the `Tys` pack, i.e.
+/// whether at least one interface listed in `Tys` is implemented by *both*
+/// types. An empty `Tys` pack yields `false`.
+template <typename... Tys>
+static bool implementSameInterface(Type t1, Type t2) {
+  // Unary left fold over `||`: evaluated left to right with short-circuiting,
+  // and defined to be `false` for an empty pack, so no explicit `false`
+  // terminator is needed.
+  return (... || (isa<Tys>(t1) && isa<Tys>(t2)));
+}
+
+/// Checks whether `t1` and `t2` are the same kind of transform dialect type:
+/// both implement `TransformHandleTypeInterface`, both implement
+/// `TransformParamTypeInterface`, or both implement
+/// `TransformValueHandleTypeInterface`.
+static bool implementSameTransformInterface(Type t1, Type t2) {
+  const bool sameKind = implementSameInterface<
+      transform::TransformHandleTypeInterface,
+      transform::TransformParamTypeInterface,
+      transform::TransformValueHandleTypeInterface>(t1, t2);
+  return sameKind;
+}
+
+//===----------------------------------------------------------------------===//
+// CollectMatchingOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ auto matcher = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
+ getOperation(), getMatcher());
+ if (matcher.isExternal()) {
+ return emitDefiniteFailure()
+ << "unresolved external symbol " << getMatcher();
+ }
+
+ SmallVector<SmallVector<MappedValue>, 2> rawResults;
+ rawResults.resize(getOperation()->getNumResults());
+ std::optional<DiagnosedSilenceableFailure> maybeFailure;
+ for (Operation *root : state.getPayloadOps(getRoot())) {
+ WalkResult walkResult = root->walk([&](Operation *op) {
+ DEBUG_MATCHER({
+ DBGS_MATCHER() << "matching ";
+ op->print(llvm::dbgs(),
+ OpPrintingFlags().assumeVerified().skipRegions());
+ llvm::dbgs() << " @" << op << "\n";
+ });
+
+ // Try matching.
+ SmallVector<SmallVector<MappedValue>> mappings;
+ DiagnosedSilenceableFailure diag =
+ matchBlock(matcher.getFunctionBody().front(), op, state, mappings);
----------------
ftynse wrote:
Normally, _match_ operations should not. But we don't currently differentiate match ops from other transform ops internally.
https://github.com/llvm/llvm-project/pull/76724
More information about the Mlir-commits
mailing list