[Mlir-commits] [mlir] [mlir] introduce transform.collect_matching (PR #76724)

Matthias Springer llvmlistbot at llvm.org
Tue Jan 2 07:39:29 PST 2024


================
@@ -822,6 +823,137 @@ matchBlock(Block &block, Operation *op, transform::TransformState &state,
   return DiagnosedSilenceableFailure::success();
 }
 
+/// Checks whether `t1` and `t2` share at least one interface out of those
+/// listed in `Tys`, i.e. whether there is a `Ty` in `Tys` such that `isa<Ty>`
+/// holds for both types. Yields `false` when `Tys` is empty.
+template <typename... Tys>
+static bool implementSameInterface(Type t1, Type t2) {
+  return (... || (isa<Tys>(t1) && isa<Tys>(t2)));
+}
+
+/// Checks whether `t1` and `t2` are the same kind of transform dialect type,
+/// that is, whether both implement the handle type interface, both implement
+/// the parameter type interface, or both implement the value handle type
+/// interface.
+static bool implementSameTransformInterface(Type t1, Type t2) {
+  const bool sameKind =
+      implementSameInterface<transform::TransformHandleTypeInterface,
+                             transform::TransformParamTypeInterface,
+                             transform::TransformValueHandleTypeInterface>(
+          t1, t2);
+  return sameKind;
+}
+
+//===----------------------------------------------------------------------===//
+// CollectMatchingOp
+//===----------------------------------------------------------------------===//
+
+/// Applies the op: walks each payload operation associated with the `root`
+/// handle, runs the body block of the matcher symbol against every visited
+/// operation via `matchBlock`, and accumulates the values yielded by successful
+/// matches into the results of this op.
+///
+/// Outcomes:
+///   - definite failure if the matcher symbol is external (has no body);
+///   - definite failure if the matcher produces a definite failure;
+///   - silenceable failure if a successful match associates a result with a
+///     number of payload objects other than one;
+///   - operations on which the matcher fails silenceably are skipped.
+DiagnosedSilenceableFailure
+transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter,
+                                    transform::TransformResults &results,
+                                    transform::TransformState &state) {
+  auto matcher = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
+      getOperation(), getMatcher());
+  if (matcher.isExternal()) {
+    return emitDefiniteFailure()
+           << "unresolved external symbol " << getMatcher();
+  }
+
+  // One list of matched payload objects per result of this op. The lists keep
+  // growing across all roots and all successfully matched operations.
+  SmallVector<SmallVector<MappedValue>, 2> rawResults;
+  rawResults.resize(getOperation()->getNumResults());
+  // Populated only when the walk is interrupted with a specific diagnostic
+  // that must be forwarded to the caller.
+  std::optional<DiagnosedSilenceableFailure> maybeFailure;
+  for (Operation *root : state.getPayloadOps(getRoot())) {
+    WalkResult walkResult = root->walk([&](Operation *op) {
+      DEBUG_MATCHER({
+        DBGS_MATCHER() << "matching ";
+        op->print(llvm::dbgs(),
+                  OpPrintingFlags().assumeVerified().skipRegions());
+        llvm::dbgs() << " @" << op << "\n";
+      });
+
+      // Try matching.
+      SmallVector<SmallVector<MappedValue>> mappings;
+      DiagnosedSilenceableFailure diag =
+          matchBlock(matcher.getFunctionBody().front(), op, state, mappings);
+      // A definite failure aborts the whole walk; no diagnostic is stored in
+      // `maybeFailure` on this path.
+      if (diag.isDefiniteFailure())
+        return WalkResult::interrupt();
+      // A silenceable failure only means that this operation did not match.
+      if (diag.isSilenceableFailure()) {
+        DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName()
+                                     << " failed: " << diag.getMessage());
+        return WalkResult::advance();
+      }
+
+      // If succeeded, collect results.
+      for (auto &&[i, mapping] : llvm::enumerate(mappings)) {
+        if (mapping.size() != 1) {
+          maybeFailure.emplace(emitSilenceableError()
+                               << "result #" << i << ", associated with "
+                               << mapping.size()
+                               << " payload objects, expected 1");
+          return WalkResult::interrupt();
+        }
+        rawResults[i].push_back(mapping[0]);
+      }
+      return WalkResult::advance();
+    });
+    if (walkResult.wasInterrupted()) {
+      // The walk is interrupted both when a diagnostic was stored above and
+      // when the matcher failed definitely without storing one; never
+      // dereference an empty optional.
+      if (maybeFailure)
+        return std::move(*maybeFailure);
+      return DiagnosedSilenceableFailure::definiteFailure();
+    }
+    assert(!maybeFailure && "failure set but the walk was not interrupted");
+  }
+
+  // Publish the results exactly once, after all roots have been visited:
+  // `rawResults` accumulates across roots, and doing it here also assigns
+  // (empty) results when the root handle has no associated payload.
+  for (auto &&[opResult, rawResult] :
+       llvm::zip(getOperation()->getResults(), rawResults)) {
+    results.setMappedValues(opResult, rawResult);
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::CollectMatchingOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  onlyReadsHandle(getRoot(), effects); // Root handle: declared as read-only.
+  producesHandle(getResults(), effects); // Results: declared as produced here.
+  onlyReadsPayload(effects); // NOTE(review): assumes the matcher never modifies the payload; not checked in this function.
+}
+
+LogicalResult transform::CollectMatchingOp::verifySymbolUses(
+    SymbolTableCollection &symbolTable) {
+  auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
+      symbolTable.lookupNearestSymbolFrom(getOperation(), getMatcher()));
+  if (!matcherSymbol ||
+      !isa<TransformOpInterface>(matcherSymbol.getOperation()))
+    return emitError() << "unresolved matcher symbol " << getMatcher();
+
+  ArrayRef<Type> argumentTypes = matcherSymbol.getArgumentTypes();
+  if (argumentTypes.size() != 1 ||
+      !isa<TransformHandleTypeInterface>(argumentTypes[0])) {
+    return emitError()
+           << "expected the matcher to take one operation handle argument";
+  }
+  if (!matcherSymbol.getArgAttr(
+          0, transform::TransformDialect::kArgReadOnlyAttrName)) {
----------------
matthias-springer wrote:

Is there an easy way to check that the matcher does not modify the payload?


https://github.com/llvm/llvm-project/pull/76724


More information about the Mlir-commits mailing list