[Mlir-commits] [mlir] [mlir] introduce transform.collect_matching (PR #76724)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Tue Jan 9 02:18:04 PST 2024


https://github.com/ftynse updated https://github.com/llvm/llvm-project/pull/76724

>From 8af4755963635c007d8556e29dd3d1126aeae233 Mon Sep 17 00:00:00 2001
From: Alex Zinenko <zinenko at google.com>
Date: Tue, 2 Jan 2024 13:50:24 +0000
Subject: [PATCH 1/4] [mlir] introduce transform.collect_matching

Introduce a new match combinator into the transform dialect. This
operation collects into its results all payload objects yielded by
successful applications of a named matcher sequence. It is a simpler
version of `foreach_match` that can be inserted directly into existing
transform scripts.
---
 .../mlir/Dialect/Transform/IR/TransformOps.td |  32 +++-
 .../lib/Dialect/Transform/IR/TransformOps.cpp | 150 ++++++++++++++++--
 mlir/test/Dialect/Transform/ops-invalid.mlir  |  68 ++++++++
 .../Dialect/Transform/test-interpreter.mlir   |  29 ++++
 4 files changed, 261 insertions(+), 18 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index fcdb21d21503a1..32c4828f66359a 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -460,6 +460,36 @@ def NumAssociationsOp : TransformDialectOp<"num_associations",
   let hasVerifier = 1;
 }
 
+def CollectMatchingOp : TransformDialectOp<"collect_matching", [
+    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+    DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+    DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let summary = "Collects all payload ops that match the given named matcher";
+  let description = [{
+    Collects operations nested under `root` or other payload IR objects that
+    match the given matcher expressed as a named sequence. The matcher sequence
+    must accept exactly one argument that it is not allowed to modify. It must
+    yield as many values as this op has results. Each of the yielded values must
+    be associated with exactly one payload object. If any operation in the
+    matcher sequence produces a silenceable failure, the matcher advances to the
+    next payload operation in the walk order without finishing the sequence.
+
+    The results of this operation are constructed by concatenating values
+    yielded by successful application of the matcher named sequence.
+
+    The operation succeeds unless the matcher sequence produced a definite
+    failure for any invocation.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$root,
+                       SymbolRefAttr:$matcher);
+  let results = (outs Variadic<Transform_AnyHandleOrParamType>:$results);
+
+  let assemblyFormat = [{
+    $matcher `in` $root attr-dict `:` functional-type($root, $results)
+  }];
+}
+
 def ForeachMatchOp : TransformDialectOp<"foreach_match", [
     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
     DeclareOpInterfaceMethods<SymbolUserOpInterface>,
@@ -674,7 +704,7 @@ def GetParentOp : TransformDialectOp<"get_parent_op",
 
 def GetProducerOfOperand : TransformDialectOp<"get_producer_of_operand",
     [DeclareOpInterfaceMethods<TransformOpInterface>,
-     NavigationTransformOpTrait, MemoryEffectsOpInterface]> {
+     NavigationTransformOpTrait, MatchOpInterface, MemoryEffectsOpInterface]> {
   let summary = "Get handle to the producer of this operation's operand number";
   let description = [{
     The handle defined by this Transform op corresponds to operation that
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index aa4694c88d3b2a..85eade07c6d5b6 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -22,6 +22,7 @@
 #include "mlir/IR/Verifier.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/FunctionImplementation.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Pass/PassRegistry.h"
@@ -783,7 +784,7 @@ bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
 }
 
 //===----------------------------------------------------------------------===//
-// ForeachMatchOp
+// Helpers shared by CollectMatchingOp and ForeachMatchOp
 //===----------------------------------------------------------------------===//
 
 /// Applies matcher operations from the given `block` assigning `op` as the
@@ -822,6 +823,145 @@ matchBlock(Block &block, Operation *op, transform::TransformState &state,
   return DiagnosedSilenceableFailure::success();
 }
 
+/// Returns `true` if both types implement one of the interfaces provided as
+/// template parameters.
+template <typename... Tys>
+static bool implementSameInterface(Type t1, Type t2) {
+  return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false);
+}
+
+/// Returns `true` if both types implement one of the transform dialect
+/// interfaces.
+static bool implementSameTransformInterface(Type t1, Type t2) {
+  return implementSameInterface<transform::TransformHandleTypeInterface,
+                                transform::TransformParamTypeInterface,
+                                transform::TransformValueHandleTypeInterface>(
+      t1, t2);
+}
+
+//===----------------------------------------------------------------------===//
+// CollectMatchingOp
+//===----------------------------------------------------------------------===//
+
+/// Walks the payload ops nested under (and including) every op associated
+/// with `root`, applies the matcher sequence to each of them, and appends the
+/// payload object associated with the i-th yielded value to the i-th result.
+DiagnosedSilenceableFailure
+transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter,
+                                    transform::TransformResults &results,
+                                    transform::TransformState &state) {
+  auto matcher = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
+      getOperation(), getMatcher());
+  if (matcher.isExternal()) {
+    return emitDefiniteFailure()
+           << "unresolved external symbol " << getMatcher();
+  }
+
+  SmallVector<SmallVector<MappedValue>, 2> rawResults;
+  rawResults.resize(getOperation()->getNumResults());
+  // Failure reported when the walk is interrupted. Every path interrupting the
+  // walk must populate it first because it is dereferenced after the walk.
+  std::optional<DiagnosedSilenceableFailure> maybeFailure;
+  for (Operation *root : state.getPayloadOps(getRoot())) {
+    WalkResult walkResult = root->walk([&](Operation *op) {
+      DEBUG_MATCHER({
+        DBGS_MATCHER() << "matching ";
+        op->print(llvm::dbgs(),
+                  OpPrintingFlags().assumeVerified().skipRegions());
+        llvm::dbgs() << " @" << op << "\n";
+      });
+
+      // Try matching.
+      SmallVector<SmallVector<MappedValue>> mappings;
+      DiagnosedSilenceableFailure diag =
+          matchBlock(matcher.getFunctionBody().front(), op, state, mappings);
+      if (diag.isDefiniteFailure()) {
+        // Record the failure so that it is propagated after the walk.
+        maybeFailure.emplace(std::move(diag));
+        return WalkResult::interrupt();
+      }
+      if (diag.isSilenceableFailure()) {
+        DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName()
+                                     << " failed: " << diag.getMessage());
+        return WalkResult::advance();
+      }
+
+      // If succeeded, collect results.
+      for (auto &&[i, mapping] : llvm::enumerate(mappings)) {
+        if (mapping.size() != 1) {
+          maybeFailure.emplace(emitSilenceableError()
+                               << "result #" << i << ", associated with "
+                               << mapping.size()
+                               << " payload objects, expected 1");
+          return WalkResult::interrupt();
+        }
+        rawResults[i].push_back(mapping[0]);
+      }
+      return WalkResult::advance();
+    });
+    if (walkResult.wasInterrupted())
+      return std::move(*maybeFailure);
+    assert(!maybeFailure && "failure set but the walk was not interrupted");
+
+    for (auto &&[opResult, rawResult] :
+         llvm::zip(getOperation()->getResults(), rawResults)) {
+      results.setMappedValues(opResult, rawResult);
+    }
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::CollectMatchingOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  onlyReadsHandle(getRoot(), effects);
+  producesHandle(getResults(), effects);
+  onlyReadsPayload(effects);
+}
+
+LogicalResult transform::CollectMatchingOp::verifySymbolUses(
+    SymbolTableCollection &symbolTable) {
+  auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
+      symbolTable.lookupNearestSymbolFrom(getOperation(), getMatcher()));
+  if (!matcherSymbol ||
+      !isa<TransformOpInterface>(matcherSymbol.getOperation()))
+    return emitError() << "unresolved matcher symbol " << getMatcher();
+
+  ArrayRef<Type> argumentTypes = matcherSymbol.getArgumentTypes();
+  if (argumentTypes.size() != 1 ||
+      !isa<TransformHandleTypeInterface>(argumentTypes[0])) {
+    return emitError()
+           << "expected the matcher to take one operation handle argument";
+  }
+  if (!matcherSymbol.getArgAttr(
+          0, transform::TransformDialect::kArgReadOnlyAttrName)) {
+    return emitError() << "expected the matcher argument to be marked readonly";
+  }
+
+  ArrayRef<Type> resultTypes = matcherSymbol.getResultTypes();
+  if (resultTypes.size() != getOperation()->getNumResults()) {
+    return emitError()
+           << "expected the matcher to yield as many values as op has results ("
+           << getOperation()->getNumResults() << "), got "
+           << resultTypes.size();
+  }
+
+  for (auto &&[i, matcherType, resultType] :
+       llvm::enumerate(resultTypes, getOperation()->getResultTypes())) {
+    if (implementSameTransformInterface(matcherType, resultType))
+      continue;
+
+    return emitError()
+           << "mismatching type interfaces for matcher result and op result #"
+           << i;
+  }
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// ForeachMatchOp
+//===----------------------------------------------------------------------===//
+
 DiagnosedSilenceableFailure
 transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
                                  transform::TransformResults &results,
@@ -978,22 +1110,6 @@ LogicalResult transform::ForeachMatchOp::verify() {
   return success();
 }
 
-/// Returns `true` if both types implement one of the interfaces provided as
-/// template parameters.
-template <typename... Tys>
-static bool implementSameInterface(Type t1, Type t2) {
-  return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false);
-}
-
-/// Returns `true` if both types implement one of the transform dialect
-/// interfaces.
-static bool implementSameTransformInterface(Type t1, Type t2) {
-  return implementSameInterface<transform::TransformHandleTypeInterface,
-                                transform::TransformParamTypeInterface,
-                                transform::TransformValueHandleTypeInterface>(
-      t1, t2);
-}
-
 /// Checks that the attributes of the function-like operation have correct
 /// consumption effect annotations. If `alsoVerifyInternal`, checks for
 /// annotations being present even if they can be inferred from the body.
diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir
index 5123958b02bfb8..233dbbcb6804cc 100644
--- a/mlir/test/Dialect/Transform/ops-invalid.mlir
+++ b/mlir/test/Dialect/Transform/ops-invalid.mlir
@@ -704,3 +704,69 @@ transform.sequence failures(propagate) {
   // expected-error @below {{expected the type of the parameter attribute ('i64') to match the parameter type ('i32')}}
   transform.num_associations %arg0 : (!transform.any_op) -> !transform.param<i32>
 }
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+    // expected-error @below {{unresolved matcher symbol @missing_symbol}}
+    transform.collect_matching @missing_symbol in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+    // expected-error @below {{expected the matcher to take one operation handle argument}}
+    transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+
+  transform.named_sequence @matcher() {
+    transform.yield
+  }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+    // expected-error @below {{expected the matcher argument to be marked readonly}}
+    transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+
+  transform.named_sequence @matcher(%arg0: !transform.any_op) {
+    transform.yield
+  }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+    // expected-error @below {{expected the matcher to yield as many values as op has results (1), got 0}}
+    transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+
+  transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) {
+    transform.yield
+  }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+    // expected-error @below {{mismatching type interfaces for matcher result and op result #0}}
+    transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_value
+    transform.yield
+  }
+
+  transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
+    transform.yield %arg0 : !transform.any_op
+  }
+}
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index 3bbf875ef309ec..c1f0f8da1e922f 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -2380,3 +2380,32 @@ module @named_inclusion attributes { transform.with_named_sequence } {
     transform.yield
   }
 }
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+  transform.sequence failures(propagate) {
+  ^bb0(%arg0: !transform.any_op):
+    // expected-error @below {{result #0, associated with 2 payload objects, expected 1}}
+    transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+
+  transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
+    %0 = transform.merge_handles %arg0, %arg0 : !transform.any_op
+    transform.yield %0 : !transform.any_op
+  }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+  transform.sequence failures(propagate) {
+  ^bb0(%arg0: !transform.any_op):
+    // expected-error @below {{unresolved external symbol @matcher}}
+    transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+
+  transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op
+}

>From d93e1ce01a01841487f628f38cc6d8a627b39375 Mon Sep 17 00:00:00 2001
From: Alex Zinenko <zinenko at google.com>
Date: Wed, 3 Jan 2024 10:59:31 +0000
Subject: [PATCH 2/4] address review

---
 .../mlir/Dialect/Transform/IR/TransformOps.td  |  6 ++++--
 mlir/lib/Dialect/Transform/IR/TransformOps.cpp |  2 +-
 .../Dialect/Transform/test-interpreter.mlir    | 18 ++++++++++++++++++
 3 files changed, 23 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 32c4828f66359a..1cd9d76a34ca02 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -474,8 +474,13 @@ def CollectMatchingOp : TransformDialectOp<"collect_matching", [
     matcher sequence produces a silenceable failure, the matcher advances to the
     next payload operation in the walk order without finishing the sequence.
 
-    The results of this operation are constructed by concatenating values
-    yielded by successful application of the matcher named sequence.
+    The i-th result of this operation is constructed by concatenating the i-th
+    yielded payload IR objects of all successful matcher sequence applications.
+    All results are guaranteed to be mapped to the same number of payload IR
+    objects.
 
-    The operation succeeds unless the matcher sequence produced a definite
-    failure for any invocation.
+    The operation produces a definite failure if the matcher sequence is only
+    declared (external) or produced a definite failure for any invocation. It
+    produces a silenceable failure if a successful matcher application yields
+    a value associated with more or fewer than one payload object. Otherwise,
+    it succeeds.
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 85eade07c6d5b6..b80fc09751d2aa 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -896,9 +896,12 @@ transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter,
     assert(!maybeFailure && "failure set but the walk was not interrupted");
-
-    for (auto &&[opResult, rawResult] :
-         llvm::zip(getOperation()->getResults(), rawResults)) {
-      results.setMappedValues(opResult, rawResult);
-    }
   }
+
+  // Set each result exactly once, after all roots have been walked: this
+  // concatenates the matches found under every root and also sets the results
+  // when `root` is not associated with any payload op.
+  for (auto &&[opResult, rawResult] :
+       llvm::zip_equal(getOperation()->getResults(), rawResults)) {
+    results.setMappedValues(opResult, rawResult);
+  }
   return DiagnosedSilenceableFailure::success();
 }
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index c1f0f8da1e922f..1ec99ad3978468 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -2409,3 +2409,21 @@ module attributes { transform.with_named_sequence } {
 
   transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op
 }
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+  // expected-remark @below {{matched}}
+  transform.sequence failures(propagate) {
+  ^bb0(%arg0: !transform.any_op):
+    // expected-remark @below {{matched}}
+    %0 = transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.test_print_remark_at_operand %0, "matched" : !transform.any_op
+    transform.yield
+  }
+
+  transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
+    transform.match.operation_name %arg0 ["transform.sequence", "transform.collect_matching"] : !transform.any_op
+    transform.yield %arg0 : !transform.any_op
+  }
+}

>From f497482fed4e5fa3cf645248e1be897bd4176cf5 Mon Sep 17 00:00:00 2001
From: Alex Zinenko <zinenko at google.com>
Date: Tue, 9 Jan 2024 09:01:10 +0000
Subject: [PATCH 3/4] Rephrase documentation

---
 .../mlir/Dialect/Transform/IR/TransformOps.td     | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 1cd9d76a34ca02..fe2c28f45aea04 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -466,13 +466,14 @@ def CollectMatchingOp : TransformDialectOp<"collect_matching", [
     DeclareOpInterfaceMethods<TransformOpInterface>]> {
   let summary = "Collects all payload ops that match the given named matcher";
   let description = [{
-    Collects operations nested under `root` or other payload IR objects that
-    match the given matcher expressed as a named sequence. The matcher sequence
-    must accept exactly one argument that it is not allowed to modify. It must
-    yield as many values as this op has results. Each of the yielded values must
-    be associated with exactly one payload object. If any operation in the
-    matcher sequence produces a silenceable failure, the matcher advances to the
-    next payload operation in the walk order without finishing the sequence.
+    Collects operations or other payload IR objects nested under `root`
+    (inclusive) that match the given matcher expressed as a named sequence. The
+    matcher sequence must accept exactly one argument that it is not allowed to
+    modify. It must yield as many values as this op has results. Each of the
+    yielded values must be associated with exactly one payload object. If any
+    operation in the matcher sequence produces a silenceable failure, the
+    matcher advances to the next payload operation in the walk order without
+    finishing the sequence.
 
     The i-th result of this operation is constructed by concatenating the i-th
     yielded payload IR objects of all successful matcher sequence applications.

>From b665e800102908ebb1555ca9ebc9dbe1c0e14857 Mon Sep 17 00:00:00 2001
From: Alex Zinenko <zinenko at google.com>
Date: Tue, 9 Jan 2024 10:17:32 +0000
Subject: [PATCH 4/4] Update test

---
 mlir/test/Dialect/Transform/test-interpreter.mlir | 13 +++++--------
 1 file changed, 5 insertions(+), 8 deletions(-)

diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index 1ec99ad3978468..4ecd731ce4178f 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -2384,8 +2384,7 @@ module @named_inclusion attributes { transform.with_named_sequence } {
 // -----
 
 module attributes { transform.with_named_sequence } {
-  transform.sequence failures(propagate) {
-  ^bb0(%arg0: !transform.any_op):
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
     // expected-error @below {{result #0, associated with 2 payload objects, expected 1}}
     transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
     transform.yield
@@ -2400,8 +2399,7 @@ module attributes { transform.with_named_sequence } {
 // -----
 
 module attributes { transform.with_named_sequence } {
-  transform.sequence failures(propagate) {
-  ^bb0(%arg0: !transform.any_op):
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
     // expected-error @below {{unresolved external symbol @matcher}}
     transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
     transform.yield
@@ -2413,17 +2411,16 @@ module attributes { transform.with_named_sequence } {
 // -----
 
 module attributes { transform.with_named_sequence } {
-  // expected-remark @below {{matched}}
-  transform.sequence failures(propagate) {
-  ^bb0(%arg0: !transform.any_op):
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
     // expected-remark @below {{matched}}
     %0 = transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
+    // expected-remark @below {{matched}}
     transform.test_print_remark_at_operand %0, "matched" : !transform.any_op
     transform.yield
   }
 
   transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
-    transform.match.operation_name %arg0 ["transform.sequence", "transform.collect_matching"] : !transform.any_op
+    transform.match.operation_name %arg0 ["transform.test_print_remark_at_operand", "transform.collect_matching"] : !transform.any_op
     transform.yield %arg0 : !transform.any_op
   }
 }



More information about the Mlir-commits mailing list