[Mlir-commits] [mlir] [mlir] make transform.foreach_match forward arguments (PR #89920)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Fri May 3 00:58:58 PDT 2024


https://github.com/ftynse updated https://github.com/llvm/llvm-project/pull/89920

>From 0fb51dd43ec0e2db8093894c9bc0b73e0515b1c2 Mon Sep 17 00:00:00 2001
From: Alex Zinenko <zinenko at google.com>
Date: Wed, 24 Apr 2024 13:09:27 +0000
Subject: [PATCH] [mlir] make transform.foreach_match forward arguments

It may be useful to have access to additional handles or parameters when
performing matches and actions in `foreach_match`, for example, to
parameterize the matcher by rank or restrict it in a non-trivial way.
Enable `foreach_match` to forward additional handles and parameters from its
operands to the matcher symbols, and from the action symbols to its results.
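A rough mental model of the new data flow may help when reviewing the patch. The standalone C++ sketch below is not MLIR code: payload ops are plain strings, and `Entities`, `Mapping`, `Matcher`, `Action` and `foreachMatch` are hypothetical names. It assumes the ops are already listed in post-order (with the root included or skipped according to `restrict_root`) and it does not model failures. As with `matchInputMapping` in the patch, slot 0 of the matcher input is reset to the current op while the forwarded inputs stay fixed; only the action paired with the first successful matcher runs, and its yields are appended to one accumulator per forwarded output.

```cpp
#include <cstddef>
#include <functional>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical model: payload ops are names, and each transform value is
// associated with a list of payload entities.
using Entities = std::vector<std::string>;
using Mapping = std::vector<Entities>;  // one entry per transform value

// A matcher receives (matched op, forwarded inputs...) and either fails
// (std::nullopt) or yields the values passed to its paired action.
using Matcher = std::function<std::optional<Mapping>(const Mapping &)>;
// An action receives the matcher's yields and returns its own yields.
using Action = std::function<Mapping(const Mapping &)>;

// `ops` is assumed to be in post-order already. Failures are not modeled.
Mapping foreachMatch(const std::vector<std::string> &ops,
                     const Mapping &forwardedInputs,
                     const std::vector<std::pair<Matcher, Action>> &pairs,
                     std::size_t numForwardedOutputs) {
  // One accumulator per forwarded output of the op.
  Mapping accumulated(numForwardedOutputs);
  // Slot 0 is the op being matched; the forwarded inputs never change.
  Mapping matchInput(1);
  matchInput.insert(matchInput.end(), forwardedInputs.begin(),
                    forwardedInputs.end());
  for (const std::string &op : ops) {
    matchInput[0] = {op};
    for (const auto &[matcher, action] : pairs) {
      std::optional<Mapping> yielded = matcher(matchInput);
      if (!yielded)
        continue;  // no match: try the next matcher/action pair
      Mapping results = action(*yielded);
      // The verifier guarantees one action result per forwarded output;
      // `at` throws instead of reading out of bounds if that is violated.
      for (std::size_t i = 0; i < accumulated.size(); ++i)
        accumulated[i].insert(accumulated[i].end(), results.at(i).begin(),
                              results.at(i).end());
      break;  // the first successful matcher wins
    }
  }
  return accumulated;
}
```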
---
 .../mlir/Dialect/Transform/IR/TransformOps.td |  62 ++++---
 .../Interfaces/TransformInterfaces.h          |  13 ++
 .../lib/Dialect/Transform/IR/TransformOps.cpp | 160 ++++++++++++++----
 .../Interfaces/TransformInterfaces.cpp        |  34 +++-
 .../test/Dialect/Transform/foreach-match.mlir | 110 ++++++++++++
 mlir/test/Dialect/Transform/ops-invalid.mlir  |  85 +++++++++-
 6 files changed, 395 insertions(+), 69 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index fbac1ffb621fd2..77048a28d75108 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -512,7 +512,10 @@ def CollectMatchingOp : TransformDialectOp<"collect_matching", [
 def ForeachMatchOp : TransformDialectOp<"foreach_match", [
     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
     DeclareOpInterfaceMethods<SymbolUserOpInterface>,
-    DeclareOpInterfaceMethods<TransformOpInterface>]> {
+    DeclareOpInterfaceMethods<TransformOpInterface,
+                              ["allowsRepeatedHandleOperands"]>,
+    DeclareOpInterfaceMethods<OpAsmOpInterface,
+                              ["getAsmResultNames"]>]> {
   let summary = "Applies named sequences when a named matcher succeeds";
   let description = [{
     Given a pair of co-indexed lists of transform dialect symbols (such as
@@ -528,25 +531,31 @@ def ForeachMatchOp : TransformDialectOp<"foreach_match", [
     the following matchers are not applied to the same payload operation. If the
     action succeeds, the next payload operation in walk order is matched. If it
     fails, both silenceable and definite errors are propagated as the result of
-    this op.
-
-    The matcher symbol must take one operand of a type that implements the same
-    transform dialect interface as the `root` operand (a check is performed at
-    application time to see if the associated payload satisfies the constraints
-    of the actual type). It must not consume the operand as multiple matchers
+    this op; propagation of silenceable errors is postponed until the end of the
+    walk.
+
+    The matcher symbol must take one operand per operand of this op, each of a
+    type implementing the same transform dialect interface as the corresponding
+    `root` or forwarded operand (a check is performed at application time to
+    see if the associated payload satisfies the constraints of the actual type).
+    It must not consume operands as multiple matchers
     may be applied. The matcher may produce any number of results. The action
     symbol paired with the matcher must take the same number of arguments as the
     matcher has results, and these arguments must implement the same transform
     dialect interfaces, but not necessarily have the exact same type (again, a
     check is performed at application time to see if the associated payload
-    satisfies the constraints of actual types on both sides). The action symbol
-    may not have results. The actions are expected to only modify payload
-    operations nested in the `root` payload operations associated with the
-    operand of this transform operation. Furhermore, the actions may not modify
-    operations outside of the currently matched payload operation, e.g., they
-    may not modify sibling or parent operations. If such behavior is desired,
-    the parent must be matched first and the nested operations obtained by
-    traversing the IR from the parent. This is due to the matching being
+    satisfies the constraints of actual types on both sides).
+
+    The action symbol may have results that are accumulated from all actions and
+    returned from the `foreach_match` operation on success. Unless the
+    `flatten_results` attribute is present, each action result must be
+    associated with exactly one payload entity. The actions are expected to only
+    modify payload operations nested in the `root` payload operations associated
+    with the operand of this transform operation. Furthermore, the actions may
+    not modify operations outside of the currently matched payload operation,
+    e.g., they may not modify sibling or parent operations. If such behavior is
+    desired, the parent must be matched first and the nested operations obtained
+    by traversing the IR from the parent. This is due to the matching being
     performed as a post-order IR walk.
 
     This operation consumes the operand and produces a new handle associated
@@ -573,19 +582,26 @@ def ForeachMatchOp : TransformDialectOp<"foreach_match", [
     produced a definite failure.
   }];
 
-  let arguments = (ins TransformHandleTypeInterface:$root,
-                       UnitAttr:$restrict_root,
-                       SymbolRefArrayAttr:$matchers,
-                       SymbolRefArrayAttr:$actions);
-  let results = (outs TransformHandleTypeInterface:$updated);
+  let arguments =
+      (ins TransformHandleTypeInterface:$root,
+           Variadic<Transform_AnyHandleOrParamType>:$forwarded_inputs,
+           UnitAttr:$restrict_root,
+           UnitAttr:$flatten_results,
+           SymbolRefArrayAttr:$matchers,
+           SymbolRefArrayAttr:$actions);
+  let results =
+      (outs TransformHandleTypeInterface:$updated,
+            Variadic<Transform_AnyHandleOrParamType>:$forwarded_outputs);
 
   let assemblyFormat = [{
-    (`restrict_root` $restrict_root^)?
+    oilist( `restrict_root` $restrict_root
+          | `flatten_results` $flatten_results
+          )
     `in`
-    $root
+    $root (`,` $forwarded_inputs^)?
     custom<ForeachMatchSymbols>($matchers, $actions)
     attr-dict
-    `:` functional-type($root, $updated)
+    `:` functional-type(operands, results)
   }];
 
   let hasVerifier = 1;
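The failure-propagation rule in the description (an action's definite failure stops the walk, while a silenceable failure is recorded and reported only after the walk finishes) can be sketched as follows. This standalone C++ model is derived from the description text rather than from `ForeachMatchOp::apply`; `Status`, `WalkSummary` and `summarizeWalk` are hypothetical names, and matcher failures are left out because a silenceable matcher failure only means "no match".

```cpp
#include <cstddef>
#include <vector>

// Hypothetical three-way status mirroring the transform dialect's
// success / silenceable failure / definite failure distinction.
enum class Status { Success, Silenceable, Definite };

struct WalkSummary {
  Status overall;       // what the op would report
  std::size_t visited;  // how many actions actually ran
};

// `actionOutcomes[i]` is the outcome of the action run on the i-th matched
// payload op, in post-order. A definite failure interrupts the walk at once;
// a silenceable failure is remembered and reported only after the walk.
WalkSummary summarizeWalk(const std::vector<Status> &actionOutcomes) {
  Status overall = Status::Success;
  std::size_t visited = 0;
  for (Status outcome : actionOutcomes) {
    ++visited;
    if (outcome == Status::Definite)
      return {Status::Definite, visited};
    if (outcome == Status::Silenceable)
      overall = Status::Silenceable;  // postponed; keep walking
  }
  return {overall, visited};
}
```

Under this model, the outcomes `{Success, Silenceable, Success}` still visit all three ops and report a silenceable failure, whereas a definite failure in second position stops the walk after two ops.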
diff --git a/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h
index 59cc2f22c93813..21795753ac5f36 100644
--- a/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h
@@ -52,6 +52,17 @@ void getPotentialTopLevelEffects(
 /// Verification hook for TransformOpInterface.
 LogicalResult verifyTransformOpInterface(Operation *op);
 
+/// Appends the entities associated with the given transform values in `state`
+/// to the pre-existing list of mappings. The array of mappings must have as
+/// many elements as values. If `flatten` is set, each transform value may be
+/// associated with any number of payload entities, and this always succeeds.
+/// Otherwise, checks that each value is associated with exactly one payload
+/// entity and returns failure if it is not.
+LogicalResult appendValueMappings(
+    MutableArrayRef<SmallVector<transform::MappedValue>> mappings,
+    ValueRange values, const transform::TransformState &state,
+    bool flatten = true);
+
 /// Populates `mappings` with mapped values associated with the given transform
 /// IR values in the given `state`.
 void prepareValueMappings(
@@ -317,6 +328,8 @@ class TransformState {
   }
   LogicalResult mapBlockArgument(BlockArgument argument,
                                  ArrayRef<MappedValue> values);
+  LogicalResult mapBlockArguments(Block::BlockArgListType arguments,
+                                  ArrayRef<SmallVector<MappedValue>> mapping);
 
   // Forward declarations to support limited visibility.
   class RegionScope;
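The flattening contract documented for `appendValueMappings` can be illustrated on plain containers. This is a standalone approximation rather than the MLIR function: `appendMappings` and `Entities` are hypothetical names, and payload entities are plain integers. Like the patch, it counts only the entities appended by the current call (`mapped.size() - mappedSize`), so with `flatten == false` a value associated with zero entities also fails, and entries appended before a failure are left in place.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical model: each transform value is associated with a list of
// payload entities (plain ints here).
using Entities = std::vector<int>;

// Appends the entities of values[i] to mappings[i]. With `flatten`, any
// number of entities may be appended; without it, each value must
// contribute exactly one entity, otherwise the function reports failure.
bool appendMappings(std::vector<Entities> &mappings,
                    const std::vector<Entities> &values, bool flatten) {
  assert(mappings.size() == values.size() && "mismatching number of mappings");
  for (std::size_t i = 0; i < values.size(); ++i) {
    std::size_t before = mappings[i].size();
    mappings[i].insert(mappings[i].end(), values[i].begin(), values[i].end());
    // Only the newly appended entities count, not what earlier calls
    // already accumulated in this slot.
    if (mappings[i].size() - before != 1 && !flatten)
      return false;
  }
  return true;
}
```

Appending a two-entity value twice with flattening leaves four entities in the slot, which is the arithmetic behind the `flatten_results` test in `foreach-match.mlir`.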
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 7a5a6974700586..eb09f007fbca88 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -19,6 +19,7 @@
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Dominance.h"
+#include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/Verifier.h"
@@ -834,19 +835,23 @@ bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
 // CollectMatchingOp
 //===----------------------------------------------------------------------===//
 
-/// Applies matcher operations from the given `block` assigning `op` as the
-/// payload of the block's first argument. Updates `state` accordingly. If any
-/// of the matcher produces a silenceable failure, discards it (printing the
-/// content to the debug output stream) and returns failure. If any of the
-/// matchers produces a definite failure, reports it and returns failure. If all
-/// matchers in the block succeed, populates `mappings` with the payload
-/// entities associated with the block terminator operands.
+/// Applies matcher operations from the given `block` using
+/// `blockArgumentMapping` to initialize block arguments. Updates `state`
+/// accordingly. If any of the matchers produces a silenceable failure,
+/// discards it (printing the content to the debug output stream) and returns
+/// failure. If any of the matchers produces a definite failure, reports it and
+/// returns failure. If all matchers in the block succeed, populates `mappings`
+/// with the payload entities associated with the block terminator operands.
+/// Note that `mappings` is cleared first.
 static DiagnosedSilenceableFailure
-matchBlock(Block &block, Operation *op, transform::TransformState &state,
+matchBlock(Block &block,
+           ArrayRef<SmallVector<transform::MappedValue>> blockArgumentMapping,
+           transform::TransformState &state,
            SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings) {
   assert(block.getParent() && "cannot match using a detached block");
   auto matchScope = state.make_region_scope(*block.getParent());
-  if (failed(state.mapBlockArgument(block.getArgument(0), {op})))
+  if (failed(
+          state.mapBlockArguments(block.getArguments(), blockArgumentMapping)))
     return DiagnosedSilenceableFailure::definiteFailure();
 
   for (Operation &match : block.without_terminator()) {
@@ -866,6 +871,9 @@ matchBlock(Block &block, Operation *op, transform::TransformState &state,
   // Remember the values mapped to the terminator operands so we can
   // forward them to the action.
   ValueRange yieldedValues = block.getTerminator()->getOperands();
+  // Our contract with the caller is that the mappings will contain only the
+  // newly mapped values, so clear anything that was there before.
+  mappings.clear();
   transform::detail::prepareValueMappings(mappings, yieldedValues, state);
   return DiagnosedSilenceableFailure::success();
 }
@@ -915,8 +923,11 @@ transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter,
 
       // Try matching.
       SmallVector<SmallVector<MappedValue>> mappings;
-      DiagnosedSilenceableFailure diag =
-          matchBlock(matcher.getFunctionBody().front(), op, state, mappings);
+      SmallVector<transform::MappedValue> inputMapping({op});
+      DiagnosedSilenceableFailure diag = matchBlock(
+          matcher.getFunctionBody().front(),
+          ArrayRef<SmallVector<transform::MappedValue>>(inputMapping), state,
+          mappings);
       if (diag.isDefiniteFailure())
         return WalkResult::interrupt();
       if (diag.isSilenceableFailure()) {
@@ -1001,6 +1012,9 @@ LogicalResult transform::CollectMatchingOp::verifySymbolUses(
 // ForeachMatchOp
 //===----------------------------------------------------------------------===//
 
+// This is fine because nothing is actually consumed by this op.
+bool transform::ForeachMatchOp::allowsRepeatedHandleOperands() { return true; }
+
 DiagnosedSilenceableFailure
 transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
                                  transform::TransformResults &results,
@@ -1030,6 +1044,18 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
 
   DiagnosedSilenceableFailure overallDiag =
       DiagnosedSilenceableFailure::success();
+
+  SmallVector<SmallVector<MappedValue>> matchInputMapping;
+  SmallVector<SmallVector<MappedValue>> matchOutputMapping;
+  SmallVector<SmallVector<MappedValue>> actionResultMapping;
+  // Explicitly add the mapping for the first block argument (the op being
+  // matched).
+  matchInputMapping.emplace_back();
+  transform::detail::prepareValueMappings(matchInputMapping,
+                                          getForwardedInputs(), state);
+  SmallVector<MappedValue> &firstMatchArgument = matchInputMapping.front();
+  actionResultMapping.resize(getForwardedOutputs().size());
+
   for (Operation *root : state.getPayloadOps(getRoot())) {
     WalkResult walkResult = root->walk([&](Operation *op) {
       // If getRestrictRoot is not present, skip over the root op itself so we
@@ -1044,11 +1070,14 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
         llvm::dbgs() << " @" << op << "\n";
       });
 
+      firstMatchArgument.clear();
+      firstMatchArgument.push_back(op);
+
       // Try all the match/action pairs until the first successful match.
       for (auto [matcher, action] : matchActionPairs) {
-        SmallVector<SmallVector<MappedValue>> mappings;
         DiagnosedSilenceableFailure diag =
-            matchBlock(matcher.getFunctionBody().front(), op, state, mappings);
+            matchBlock(matcher.getFunctionBody().front(), matchInputMapping,
+                       state, matchOutputMapping);
         if (diag.isDefiniteFailure())
           return WalkResult::interrupt();
         if (diag.isSilenceableFailure()) {
@@ -1058,10 +1087,10 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
         }
 
         auto scope = state.make_region_scope(action.getFunctionBody());
-        for (auto &&[arg, map] : llvm::zip_equal(
-                 action.getFunctionBody().front().getArguments(), mappings)) {
-          if (failed(state.mapBlockArgument(arg, map)))
-            return WalkResult::interrupt();
+        if (failed(state.mapBlockArguments(
+                action.getFunctionBody().front().getArguments(),
+                matchOutputMapping))) {
+          return WalkResult::interrupt();
         }
 
         for (Operation &transform :
@@ -1082,6 +1111,16 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
             continue;
           }
         }
+        if (failed(detail::appendValueMappings(
+                MutableArrayRef<SmallVector<MappedValue>>(actionResultMapping),
+                action.getFunctionBody().front().getTerminator()->getOperands(),
+                state, getFlattenResults()))) {
+          emitDefiniteFailure()
+              << "action @" << action.getName()
+              << " has results associated with multiple payload entities, "
+                 "but flattening was not requested";
+          return WalkResult::interrupt();
+        }
         break;
       }
       return WalkResult::advance();
@@ -1096,9 +1135,21 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
   // by actions, are invalidated.
   results.set(llvm::cast<OpResult>(getUpdated()),
               state.getPayloadOps(getRoot()));
+  for (auto &&[result, mapping] :
+       llvm::zip_equal(getForwardedOutputs(), actionResultMapping)) {
+    results.setMappedValues(result, mapping);
+  }
   return overallDiag;
 }
 
+void transform::ForeachMatchOp::getAsmResultNames(
+    OpAsmSetValueNameFn setNameFn) {
+  setNameFn(getUpdated(), "updated_root");
+  for (Value v : getForwardedOutputs()) {
+    setNameFn(v, "yielded");
+  }
+}
+
 void transform::ForeachMatchOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
   // Bail if invalid.
@@ -1108,7 +1159,8 @@ void transform::ForeachMatchOp::getEffects(
   }
 
   consumesHandle(getRoot(), effects);
-  producesHandle(getUpdated(), effects);
+  onlyReadsHandle(getForwardedInputs(), effects);
+  producesHandle(getResults(), effects);
   modifiesPayload(effects);
 }
 
@@ -1224,6 +1276,7 @@ LogicalResult transform::ForeachMatchOp::verifySymbolUses(
       StringAttr::get(getContext(), TransformDialect::kArgConsumedAttrName);
   for (auto &&[matcher, action] :
        llvm::zip_equal(getMatchers(), getActions())) {
+    // Presence and typing.
     auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
         symbolTable.lookupNearestSymbolFrom(getOperation(),
                                             cast<SymbolRefAttr>(matcher)));
@@ -1250,8 +1303,41 @@ LogicalResult transform::ForeachMatchOp::verifySymbolUses(
       return failure();
     }
 
-    ArrayRef<Type> matcherResults = matcherSymbol.getResultTypes();
-    ArrayRef<Type> actionArguments = actionSymbol.getArgumentTypes();
+    // Input -> matcher forwarding.
+    TypeRange operandTypes = getOperandTypes();
+    TypeRange matcherArguments = matcherSymbol.getArgumentTypes();
+    if (operandTypes.size() != matcherArguments.size()) {
+      InFlightDiagnostic diag =
+          emitError() << "the number of operands (" << operandTypes.size()
+                      << ") doesn't match the number of matcher arguments ("
+                      << matcherArguments.size() << ") for " << matcher;
+      diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
+      return diag;
+    }
+    for (auto &&[i, operand, argument] :
+         llvm::enumerate(operandTypes, matcherArguments)) {
+      if (matcherSymbol.getArgAttr(i, consumedAttr)) {
+        InFlightDiagnostic diag =
+            emitOpError()
+            << "does not expect matcher symbol to consume its operand #" << i;
+        diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
+        return diag;
+      }
+
+      if (implementSameTransformInterface(operand, argument))
+        continue;
+
+      InFlightDiagnostic diag =
+          emitError()
+          << "mismatching type interfaces for operand and matcher argument #"
+          << i << " of matcher " << matcher;
+      diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
+      return diag;
+    }
+
+    // Matcher -> action forwarding.
+    TypeRange matcherResults = matcherSymbol.getResultTypes();
+    TypeRange actionArguments = actionSymbol.getArgumentTypes();
     if (matcherResults.size() != actionArguments.size()) {
       return emitError() << "mismatching number of matcher results and "
                             "action arguments between "
@@ -1265,31 +1351,31 @@ LogicalResult transform::ForeachMatchOp::verifySymbolUses(
 
       return emitError() << "mismatching type interfaces for matcher result "
                             "and action argument #"
-                         << i;
+                         << i << " of matcher " << matcher << " and action "
+                         << action;
     }
 
-    if (!actionSymbol.getResultTypes().empty()) {
+    // Action -> result forwarding.
+    TypeRange actionResults = actionSymbol.getResultTypes();
+    auto resultTypes = TypeRange(getResultTypes()).drop_front();
+    if (actionResults.size() != resultTypes.size()) {
       InFlightDiagnostic diag =
-          emitError() << "action symbol is not expected to have results";
+          emitError() << "the number of action results ("
+                      << actionResults.size() << ") for " << action
+                      << " doesn't match the number of extra op results ("
+                      << resultTypes.size() << ")";
       diag.attachNote(actionSymbol->getLoc()) << "symbol declaration";
       return diag;
     }
+    for (auto &&[i, resultType, actionType] :
+         llvm::enumerate(resultTypes, actionResults)) {
+      if (implementSameTransformInterface(resultType, actionType))
+        continue;
 
-    if (matcherSymbol.getArgumentTypes().size() != 1 ||
-        !implementSameTransformInterface(matcherSymbol.getArgumentTypes()[0],
-                                         getRoot().getType())) {
-      InFlightDiagnostic diag =
-          emitOpError() << "expects matcher symbol to have one argument with "
-                           "the same transform interface as the first operand";
-      diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
-      return diag;
-    }
-
-    if (matcherSymbol.getArgAttr(0, consumedAttr)) {
       InFlightDiagnostic diag =
-          emitOpError()
-          << "does not expect matcher symbol to consume its operand";
-      diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
+          emitError() << "mismatching type interfaces for action result #" << i
+                      << " of action " << action << " and op result";
+      diag.attachNote(actionSymbol->getLoc()) << "symbol declaration";
       return diag;
     }
   }
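The three forwarding edges that `verifySymbolUses` now checks can be summarized with a coarse model. In this hypothetical sketch (`Kind`, `Signature`, `compare` and `verifyForwarding` are illustrative names), each transform type is reduced to the interface it implements, so `implementSameTransformInterface` becomes plain equality. The `transform.consumed` check on matcher arguments is omitted, and the error strings are not the ones MLIR emits.

```cpp
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Hypothetical coarse model: every transform type implements exactly one of
// these interfaces, and "compatible" means "same interface".
enum class Kind { OpHandle, ValueHandle, Param };

struct Signature {
  std::vector<Kind> arguments;
  std::vector<Kind> results;
};

// Compares `expected` with `actual` element-wise; returns an error message on
// the first count or interface mismatch, prefixed with `what`.
static std::optional<std::string> compare(const std::vector<Kind> &expected,
                                          const std::vector<Kind> &actual,
                                          const std::string &what) {
  if (expected.size() != actual.size())
    return what + ": count mismatch";
  for (std::size_t i = 0; i < expected.size(); ++i)
    if (expected[i] != actual[i])
      return what + ": interface mismatch at #" + std::to_string(i);
  return std::nullopt;
}

// Checks the forwarding edges in the same order as verifySymbolUses:
// op operands -> matcher arguments, matcher results -> action arguments,
// action results -> op results other than #0 (the updated root handle).
std::optional<std::string> verifyForwarding(const std::vector<Kind> &opOperands,
                                            const std::vector<Kind> &opResults,
                                            const Signature &matcher,
                                            const Signature &action) {
  if (auto err = compare(opOperands, matcher.arguments, "operands/matcher"))
    return err;
  if (auto err = compare(matcher.results, action.arguments, "matcher/action"))
    return err;
  if (opResults.empty())
    return std::string("missing updated root result");
  // Equivalent of `TypeRange(getResultTypes()).drop_front()` in the patch.
  std::vector<Kind> extraResults(opResults.begin() + 1, opResults.end());
  return compare(extraResults, action.results, "action/op results");
}
```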
diff --git a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
index 48f3954b6cf69f..b6a35e23a5d1fc 100644
--- a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
@@ -206,6 +206,15 @@ transform::TransformState::mapBlockArgument(BlockArgument argument,
       .checkAndReport();
 }
 
+LogicalResult transform::TransformState::mapBlockArguments(
+    Block::BlockArgListType arguments,
+    ArrayRef<SmallVector<MappedValue>> mapping) {
+  for (auto &&[argument, values] : llvm::zip_equal(arguments, mapping))
+    if (failed(mapBlockArgument(argument, values)))
+      return failure();
+  return success();
+}
+
 LogicalResult
 transform::TransformState::setPayloadOps(Value value,
                                          ArrayRef<Operation *> targets) {
@@ -1528,11 +1537,12 @@ void transform::detail::setApplyToOneResults(
 // Utilities for implementing transform ops with regions.
 //===----------------------------------------------------------------------===//
 
-void transform::detail::prepareValueMappings(
-    SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings,
-    ValueRange values, const transform::TransformState &state) {
-  for (Value operand : values) {
-    SmallVector<MappedValue> &mapped = mappings.emplace_back();
+LogicalResult transform::detail::appendValueMappings(
+    MutableArrayRef<SmallVector<transform::MappedValue>> mappings,
+    ValueRange values, const transform::TransformState &state, bool flatten) {
+  assert(mappings.size() == values.size() && "mismatching number of mappings");
+  for (auto &&[operand, mapped] : llvm::zip_equal(values, mappings)) {
+    size_t mappedSize = mapped.size();
     if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
       llvm::append_range(mapped, state.getPayloadOps(operand));
     } else if (llvm::isa<TransformValueHandleTypeInterface>(
@@ -1543,7 +1553,21 @@ void transform::detail::prepareValueMappings(
              "unsupported kind of transform dialect value");
       llvm::append_range(mapped, state.getParams(operand));
     }
+
+    if (mapped.size() - mappedSize != 1 && !flatten)
+      return failure();
   }
+  return success();
+}
+
+void transform::detail::prepareValueMappings(
+    SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings,
+    ValueRange values, const transform::TransformState &state) {
+  mappings.resize(mappings.size() + values.size());
+  (void)appendValueMappings(
+      MutableArrayRef<SmallVector<transform::MappedValue>>(mappings).take_back(
+          values.size()),
+      values, state);
 }
 
 void transform::detail::forwardTerminatorOperands(
diff --git a/mlir/test/Dialect/Transform/foreach-match.mlir b/mlir/test/Dialect/Transform/foreach-match.mlir
index 206625ae0746be..a7cd8e9ff543ee 100644
--- a/mlir/test/Dialect/Transform/foreach-match.mlir
+++ b/mlir/test/Dialect/Transform/foreach-match.mlir
@@ -78,3 +78,113 @@ module attributes { transform.with_named_sequence } {
     transform.yield
   }
 }
+
+// -----
+
+// expected-remark @below {{op from within the matcher}}
+module attributes { transform.with_named_sequence } {
+  // expected-remark @below {{returned root}}
+  func.func @foo() {
+    return
+  }
+
+  transform.named_sequence @match_fail(
+      %op: !transform.any_op {transform.readonly},
+      %root: !transform.any_op {transform.readonly},
+      %param: !transform.param<i64> {transform.readonly}) -> (!transform.any_op, !transform.param<i64>) {
+    transform.test_succeed_if_operand_of_op_kind %op, "test.impossible_to_match" : !transform.any_op
+    transform.yield %root, %param : !transform.any_op, !transform.param<i64>
+  }
+
+  transform.named_sequence @match_succeed(
+      %op: !transform.any_op {transform.readonly},
+      %root: !transform.any_op {transform.readonly},
+      %param: !transform.param<i64> {transform.readonly}) -> (!transform.any_op, !transform.param<i64>) {
+    transform.debug.emit_remark_at %root, "op from within the matcher" : !transform.any_op
+    // expected-remark @below {{param from within the matcher 42}}
+    transform.debug.emit_param_as_remark %param, "param from within the matcher" : !transform.param<i64>
+    transform.yield %root, %param : !transform.any_op, !transform.param<i64>
+  }
+
+  transform.named_sequence @return(
+      %root: !transform.any_op {transform.readonly},
+      %param: !transform.param<i64> {transform.readonly}) -> (!transform.param<i64>, !transform.param<i64>, !transform.any_op) {
+    %func = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
+    transform.yield %param, %param, %func : !transform.param<i64>, !transform.param<i64>, !transform.any_op
+  }
+
+  transform.named_sequence @__transform_main(%root: !transform.any_op) {
+    %param = transform.param.constant 42 : i64 -> !transform.param<i64>
+    %func = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
+    %func2, %yielded:3 = transform.foreach_match restrict_root in %func, %root, %param
+      @match_fail -> @return,
+      @match_succeed -> @return
+      : (!transform.any_op, !transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.param<i64>, !transform.param<i64>, !transform.any_op)
+    transform.debug.emit_remark_at %yielded#2, "returned root" : !transform.any_op
+    // expected-remark @below {{42 : i64, 42 : i64}}
+    transform.debug.emit_param_as_remark %yielded#0: !transform.param<i64>
+    %num_roots = transform.num_associations %yielded#2 : (!transform.any_op) -> !transform.param<i64>
+    // expected-remark @below {{2 : i64}}
+    transform.debug.emit_param_as_remark %num_roots : !transform.param<i64>
+    transform.yield
+  }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+  func.func private @foo()
+  func.func private @bar()
+
+  transform.named_sequence @match(
+      %op: !transform.any_op {transform.readonly},
+      %func: !transform.any_op {transform.readonly}) -> (!transform.any_op) {
+    transform.yield %func : !transform.any_op
+  }
+
+  transform.named_sequence @return(
+      %func: !transform.any_op {transform.readonly}) -> (!transform.any_op) {
+    transform.yield %func : !transform.any_op
+  }
+
+  transform.named_sequence @__transform_main(%root: !transform.any_op) {
+    %func = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
+    %func2, %yielded = transform.foreach_match flatten_results restrict_root in %func, %func
+      @match -> @return
+      : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %num = transform.num_associations %yielded : (!transform.any_op) -> !transform.param<i64>
+    // The action yields both funcs for each of the 2 matched funcs, 2 * 2 = 4:
+    // expected-remark @below {{4 : i64}}
+    transform.debug.emit_param_as_remark %num : !transform.param<i64>
+    transform.yield
+  }
+}
+
+// -----
+
+
+module attributes { transform.with_named_sequence } {
+  func.func private @foo()
+  func.func private @bar()
+
+  transform.named_sequence @match(
+      %op: !transform.any_op {transform.readonly},
+      %func: !transform.any_op {transform.readonly}) -> (!transform.any_op) {
+    transform.yield %func : !transform.any_op
+  }
+
+  transform.named_sequence @return(
+      %func: !transform.any_op {transform.readonly}) -> (!transform.any_op) {
+    transform.yield %func : !transform.any_op
+  }
+
+  transform.named_sequence @__transform_main(%root: !transform.any_op) {
+    %func = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
+    // expected-error @below {{action @return has results associated with multiple payload entities, but flattening was not requested}}
+    %func2, %yielded = transform.foreach_match restrict_root in %func, %func
+      @match -> @return
+      : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %num = transform.num_associations %yielded : (!transform.any_op) -> !transform.param<i64>
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir
index cc04e65420c5b7..30a68cc5f3c448 100644
--- a/mlir/test/Dialect/Transform/ops-invalid.mlir
+++ b/mlir/test/Dialect/Transform/ops-invalid.mlir
@@ -629,7 +629,84 @@ module attributes { transform.with_named_sequence } {
 // -----
 
 module attributes { transform.with_named_sequence } {
-  transform.named_sequence @match() -> !transform.any_op
+  // expected-note @below {{symbol declaration}}
+  transform.named_sequence @match(!transform.any_op {transform.readonly}, !transform.any_op {transform.readonly}) -> !transform.any_op
+  transform.named_sequence @action(!transform.any_op {transform.readonly})
+
+  transform.sequence failures(propagate) {
+  ^bb0(%root: !transform.any_op):
+    // expected-error @below {{the number of operands (1) doesn't match the number of matcher arguments (2) for @match}}
+    transform.foreach_match in %root
+      @match -> @action : (!transform.any_op) -> !transform.any_op
+  }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+  // expected-note @below {{symbol declaration}}
+  transform.named_sequence @match(!transform.any_op {transform.readonly}, !transform.any_op {transform.consumed}) -> !transform.any_op
+  transform.named_sequence @action(!transform.any_op {transform.readonly})
+
+  transform.sequence failures(propagate) {
+  ^bb0(%root: !transform.any_op):
+    %r = transform.replicate num(%root) %root : !transform.any_op, !transform.any_op
+    // expected-error @below {{'transform.foreach_match' op does not expect matcher symbol to consume its operand #1}}
+    transform.foreach_match in %root, %r
+      @match -> @action : (!transform.any_op, !transform.any_op) -> !transform.any_op
+  }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+  // expected-note @below {{symbol declaration}}
+  transform.named_sequence @match(!transform.any_op {transform.readonly}, !transform.any_op {transform.readonly}) -> !transform.any_op
+  transform.named_sequence @action(!transform.any_op {transform.readonly})
+
+  transform.sequence failures(propagate) {
+  ^bb0(%root: !transform.any_op):
+    %r = transform.get_operand %root[0] : (!transform.any_op) -> !transform.any_value
+    // expected-error @below {{mismatching type interfaces for operand and matcher argument #1 of matcher @match}}
+    transform.foreach_match in %root, %r
+      @match -> @action : (!transform.any_op, !transform.any_value) -> !transform.any_op
+  }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @match(!transform.any_op {transform.readonly}) -> !transform.any_op
+  // expected-note @below {{symbol declaration}}
+  transform.named_sequence @action(!transform.any_op {transform.readonly}) -> !transform.any_op
+
+  transform.sequence failures(propagate) {
+  ^bb0(%root: !transform.any_op):
+    // expected-error @below {{the number of action results (1) for @action doesn't match the number of extra op results (0)}}
+    transform.foreach_match in %root
+      @match -> @action : (!transform.any_op) -> !transform.any_op
+  }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @match(!transform.any_op {transform.readonly}) -> !transform.any_op
+  // expected-note @below {{symbol declaration}}
+  transform.named_sequence @action(!transform.any_op {transform.readonly}) -> !transform.any_op
+
+  transform.sequence failures(propagate) {
+  ^bb0(%root: !transform.any_op):
+    // expected-error @below {{mismatching type interfaces for action result #0 of action @action and op result}}
+    transform.foreach_match in %root
+      @match -> @action : (!transform.any_op) -> (!transform.any_op, !transform.any_value)
+  }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @match(!transform.any_op {transform.readonly}) -> !transform.any_op
   transform.named_sequence @action()
 
   transform.sequence failures(propagate) {
@@ -649,7 +726,7 @@ module attributes { transform.with_named_sequence } {
 
   transform.sequence failures(propagate) {
   ^bb0(%root: !transform.any_op):
-    // expected-error @below {{action symbol is not expected to have results}}
+    // expected-error @below {{the number of action results (1) for @action doesn't match the number of extra op results (0)}}
     transform.foreach_match in %root
       @match -> @action : (!transform.any_op) -> !transform.any_op
   }
@@ -664,7 +741,7 @@ module attributes { transform.with_named_sequence } {
 
   transform.sequence failures(propagate) {
   ^bb0(%root: !transform.any_op):
-    // expected-error @below {{expects matcher symbol to have one argument with the same transform interface as the first operand}}
+    // expected-error @below {{the number of operands (1) doesn't match the number of matcher arguments (0) for @match}}
     transform.foreach_match in %root
       @match -> @action : (!transform.any_op) -> !transform.any_op
   }
@@ -679,7 +756,7 @@ module attributes { transform.with_named_sequence } {
 
   transform.sequence failures(propagate) {
   ^bb0(%root: !transform.any_op):
-    // expected-error @below {{'transform.foreach_match' op does not expect matcher symbol to consume its operand}}
+    // expected-error @below {{'transform.foreach_match' op does not expect matcher symbol to consume its operand #0}}
     transform.foreach_match in %root
       @match -> @action : (!transform.any_op) -> !transform.any_op
   }


