[Mlir-commits] [mlir] [mlir] introduce transform.collect_matching (PR #76724)
Oleksandr Alex Zinenko
llvmlistbot at llvm.org
Wed Jan 3 03:03:06 PST 2024
https://github.com/ftynse updated https://github.com/llvm/llvm-project/pull/76724
>From 38942f62735a54cffe6848181cdb772fb1bc2285 Mon Sep 17 00:00:00 2001
From: Alex Zinenko <zinenko at google.com>
Date: Tue, 2 Jan 2024 13:31:15 +0000
Subject: [PATCH 1/3] [mlir] introduce transform.num_associations
Add a new transform operation that creates a new parameter containing
the number of payload objects (operations, values or attributes)
associated with the argument. This is useful in matching and for
debugging purposes. This replaces three ad-hoc operations previously
provided by the test extension.
---
.../mlir/Dialect/Transform/IR/TransformOps.td | 21 ++++
.../lib/Dialect/Transform/IR/TransformOps.cpp | 29 ++++++
.../transform-op-bufferize-to-allocation.mlir | 15 ++-
.../Dialect/Linalg/transform-op-match.mlir | 5 +-
.../test/Dialect/Linalg/transform-op-pad.mlir | 3 +-
.../Dialect/Transform/test-interpreter.mlir | 96 ++++++++++++-------
.../Transform/test-loop-transforms.mlir | 9 +-
.../TestTransformDialectExtension.cpp | 45 ---------
.../TestTransformDialectExtension.td | 27 ------
9 files changed, 135 insertions(+), 115 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 307257f4a582be..da0162faa6e466 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -438,6 +438,27 @@ def CastOp : TransformDialectOp<"cast",
}];
}
+def NumAssociationsOp : TransformDialectOp<"num_associations",
+ [MemoryEffectsOpInterface, ParamProducerTransformOpTrait,
+ DeclareOpInterfaceMethods<TransformOpInterface>,
+ MatchOpInterface]> {
+ let summary =
+ "Returns the number of payload objects associated with the argument";
+ let description = [{
+ Given an argument, handle or parameter, returns a new parameter associated
+ with a single 64-bit number that corresponds to the number of payload
+ objects (operations or values for a handle, attributes for a parameter)
+ associated with the argument.
+
+ Always succeeds.
+ }];
+ let arguments = (ins Transform_AnyHandleOrParamType:$handle);
+ let results = (outs TransformParamTypeInterface:$num);
+ let assemblyFormat = [{
+ $handle attr-dict `:` functional-type(operands, results)
+ }];
+}
+
def ForeachMatchOp : TransformDialectOp<"foreach_match", [
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 7136e423470a28..ca644252f3514a 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -32,6 +32,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <optional>
@@ -1974,6 +1975,34 @@ void transform::NamedSequenceOp::build(OpBuilder &builder,
/*extraBindingTypes=*/TypeRange(), bodyBuilder);
}
+//===----------------------------------------------------------------------===//
+// NumAssociationsOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::NumAssociationsOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ size_t numAssociations =
+ llvm::TypeSwitch<Type, size_t>(getHandle().getType())
+ .Case([&](TransformHandleTypeInterface opHandle) {
+ return llvm::range_size(state.getPayloadOps(getHandle()));
+ })
+ .Case([&](TransformValueHandleTypeInterface valueHandle) {
+ return llvm::range_size(state.getPayloadValues(getHandle()));
+ })
+ .Case([&](TransformParamTypeInterface param) {
+ return llvm::range_size(state.getParams(getHandle()));
+ })
+ .Default([](Type) {
+ llvm_unreachable("unknown kind of transform dialect type");
+ return 0;
+ });
+ results.setParams(getNum().cast<OpResult>(),
+ rewriter.getI64IntegerAttr(numAssociations));
+ return DiagnosedSilenceableFailure::success();
+}
+
//===----------------------------------------------------------------------===//
// SelectOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir b/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
index 49a52ba9e06f86..aa15ccf0beeee2 100644
--- a/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
@@ -36,13 +36,15 @@ module attributes {transform.with_named_sequence} {
// Ensure that one linalg.fill was generated.
%fill_op = transform.select "linalg.fill" in %new : (!transform.any_op) -> !transform.any_op
+ %p = transform.num_associations %fill_op : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{1}}
- transform.test_print_number_of_associated_payload_ir_ops %fill_op : !transform.any_op
+ transform.test_print_param %p : !transform.param<i64>
// Ensure that one linalg.copy was generated.
%mat = transform.select "bufferization.materialize_in_destination" in %new : (!transform.any_op) -> !transform.any_op
+ %p2 = transform.num_associations %mat : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{1}}
- transform.test_print_number_of_associated_payload_ir_ops %mat : !transform.any_op
+ transform.test_print_param %p2 : !transform.param<i64>
transform.yield
}
}
@@ -73,18 +75,21 @@ module attributes {transform.with_named_sequence} {
// Ensure that one linalg.fill was generated.
%fill_op = transform.select "linalg.fill" in %new : (!transform.any_op) -> !transform.any_op
+ %p = transform.num_associations %fill_op : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{1}}
- transform.test_print_number_of_associated_payload_ir_ops %fill_op : !transform.any_op
+ transform.test_print_param %p : !transform.param<i64>
// Ensure that one linalg.copy was generated.
%linalg_copy = transform.select "linalg.copy" in %new : (!transform.any_op) -> !transform.any_op
+ %p2 = transform.num_associations %linalg_copy : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{1}}
- transform.test_print_number_of_associated_payload_ir_ops %linalg_copy : !transform.any_op
+ transform.test_print_param %p2 : !transform.param<i64>
// Ensure that one memref.alloca was generated.
%alloca = transform.select "memref.alloca" in %new : (!transform.any_op) -> !transform.any_op
+ %p3 = transform.num_associations %alloca : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{1}}
- transform.test_print_number_of_associated_payload_ir_ops %alloca : !transform.any_op
+ transform.test_print_param %p3 : !transform.param<i64>
// Make sure that One-Shot Bufferize can bufferize the rest.
%4 = transform.bufferization.one_shot_bufferize %arg1 : (!transform.any_op) -> !transform.any_op
diff --git a/mlir/test/Dialect/Linalg/transform-op-match.mlir b/mlir/test/Dialect/Linalg/transform-op-match.mlir
index 15942db9b5db20..db5b5f1c786776 100644
--- a/mlir/test/Dialect/Linalg/transform-op-match.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-match.mlir
@@ -134,8 +134,9 @@ module attributes {transform.with_named_sequence} {
#linalg.iterator_type<parallel>,
#linalg.iterator_type<reduction>]}
in %arg1 : (!transform.any_op) -> !transform.any_op
- // expected-remark @below {{0}}
- transform.test_print_number_of_associated_payload_ir_ops %no_match : !transform.any_op
+ %p = transform.num_associations %no_match : (!transform.any_op) -> !transform.param<i64>
+ // expected-remark @below {{0}}
+ transform.test_print_param %p : !transform.param<i64>
transform.yield
}
}
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
index 6bca6c1fd6bf12..1f9d81a819e7fb 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
@@ -41,8 +41,9 @@ module attributes {transform.with_named_sequence} {
padding_dimensions=[0, 1, 2],
pack_paddings=[1, 1, 0]
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.op<"bufferization.materialize_in_destination">)
+ %p = transform.num_associations %copy_back : (!transform.op<"bufferization.materialize_in_destination">) -> !transform.param<i64>
// expected-remark @below {{1}}
- transform.test_print_number_of_associated_payload_ir_ops %copy_back : !transform.op<"bufferization.materialize_in_destination">
+ transform.test_print_param %p : !transform.param<i64>
transform.yield
}
}
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index d9a11994eb9d90..a39e6f94cb34f6 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -575,8 +575,9 @@ transform.with_pdl_patterns {
%0 = pdl_match @addi in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = pdl_match @addi in %arg1 : (!transform.any_op) -> !transform.any_op
%2 = merge_handles deduplicate %0, %1 : !transform.any_op
+ %3 = num_associations %2 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{1}}
- test_print_number_of_associated_payload_ir_ops %2 : !transform.any_op
+ test_print_param %3 : !transform.param<i64>
}
}
@@ -676,11 +677,13 @@ module {
^bb0(%arg1: !transform.any_op):
%0 = pdl_match @func in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = replicate num(%0) %arg1 : !transform.any_op, !transform.any_op
+ %p = num_associations %1 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{2}}
- test_print_number_of_associated_payload_ir_ops %1 : !transform.any_op
+ test_print_param %p : !transform.param<i64>
%2 = replicate num(%0) %1 : !transform.any_op, !transform.any_op
+ %p2 = num_associations %2 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{4}}
- test_print_number_of_associated_payload_ir_ops %2 : !transform.any_op
+ test_print_param %p2 : !transform.param<i64>
}
}
}
@@ -708,8 +711,9 @@ transform.with_pdl_patterns {
%f = pdl_match @const in %arg1 : (!transform.any_op) -> !transform.any_op
transform.foreach %f : !transform.any_op {
^bb2(%arg2: !transform.any_op):
+ %p = transform.num_associations %arg2 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{1}}
- transform.test_print_number_of_associated_payload_ir_ops %arg2 : !transform.any_op
+ transform.test_print_param %p : !transform.param<i64>
transform.test_print_remark_at_operand %arg2, "transform applied" : !transform.any_op
}
}
@@ -780,8 +784,9 @@ transform.with_pdl_patterns {
transform.yield %g : !transform.any_op
}
+ %p = transform.num_associations %results : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{3}}
- transform.test_print_number_of_associated_payload_ir_ops %results : !transform.any_op
+ transform.test_print_param %p : !transform.param<i64>
transform.test_print_remark_at_operand %results, "transform applied" : !transform.any_op
}
}
@@ -877,8 +882,9 @@ transform.sequence failures(propagate) {
^bb1(%fun: !transform.any_op):
%muli = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
%h:2 = split_handle %muli : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %p = transform.num_associations %h#0 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{1}}
- transform.test_print_number_of_associated_payload_ir_ops %h#0 : !transform.any_op
+ transform.test_print_param %p : !transform.param<i64>
%muli_2 = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
// expected-error @below {{expected to contain 3 payload ops but it contains 2 payload ops}}
%h_2:3 = split_handle %muli_2 : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
@@ -896,13 +902,15 @@ transform.sequence failures(suppress) {
^bb1(%fun: !transform.any_op):
%muli = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
%h:2 = split_handle %muli : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %p = transform.num_associations %h#0 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{1}}
- transform.test_print_number_of_associated_payload_ir_ops %h#0 : !transform.any_op
+ transform.test_print_param %p : !transform.param<i64>
%muli_2 = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
// Silenceable failure and all handles are now empty.
%h_2:3 = split_handle %muli_2 : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ %p2 = transform.num_associations %h_2#0 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{0}}
- transform.test_print_number_of_associated_payload_ir_ops %h_2#0 : !transform.any_op
+ transform.test_print_param %p2 : !transform.param<i64>
}
// -----
@@ -918,12 +926,15 @@ transform.sequence failures(propagate) {
%muli_2 = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
// No error, last result handle is empty.
%h:3 = split_handle %muli_2 {fail_on_payload_too_small = false} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ %p = transform.num_associations %h#0 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{1}}
- transform.test_print_number_of_associated_payload_ir_ops %h#0 : !transform.any_op
+ transform.test_print_param %p : !transform.param<i64>
+ %p2 = transform.num_associations %h#1 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{1}}
- transform.test_print_number_of_associated_payload_ir_ops %h#1 : !transform.any_op
+ transform.test_print_param %p2 : !transform.param<i64>
+ %p3 = transform.num_associations %h#2 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{0}}
- transform.test_print_number_of_associated_payload_ir_ops %h#2 : !transform.any_op
+ transform.test_print_param %p3 : !transform.param<i64>
}
// -----
@@ -940,10 +951,12 @@ transform.sequence failures(propagate) {
^bb1(%fun: !transform.any_op):
%muli_2 = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
%h:2 = split_handle %muli_2 {overflow_result = 0} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %p = transform.num_associations %h#0 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{3}}
- transform.test_print_number_of_associated_payload_ir_ops %h#0 : !transform.any_op
+ transform.test_print_param %p : !transform.param<i64>
+ %p2 = transform.num_associations %h#1 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{1}}
- transform.test_print_number_of_associated_payload_ir_ops %h#1 : !transform.any_op
+ transform.test_print_param %p2 : !transform.param<i64>
}
// -----
@@ -1668,8 +1681,9 @@ transform.sequence failures(propagate) {
// expected-remark @below {{2 iterations}}
transform.test_tracked_rewrite %0 : (!transform.any_op) -> ()
// One replacement op (test.drop_mapping) is dropped from the mapping.
+ %p = num_associations %0 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{2}}
- test_print_number_of_associated_payload_ir_ops %0 : !transform.any_op
+ test_print_param %p : !transform.param<i64>
}
// -----
@@ -1684,20 +1698,24 @@ module {
%2 = transform.param.constant 1 -> !transform.param<i64>
%3 = transform.param.constant 2 -> !transform.param<i64>
%4 = transform.merge_handles %1, %2 { deduplicate } : !transform.param<i64>
+ %p = num_associations %4 : (!transform.param<i64>) -> !transform.param<i64>
// expected-remark @below {{1}}
- test_print_number_of_associated_payload_ir_params %4 : !transform.param<i64>
+ test_print_param %p : !transform.param<i64>
%5 = transform.merge_handles %1, %1 { deduplicate } : !transform.param<i64>
+ %p2 = num_associations %5 : (!transform.param<i64>) -> !transform.param<i64>
// expected-remark @below {{1}}
- test_print_number_of_associated_payload_ir_params %5 : !transform.param<i64>
+ test_print_param %p2 : !transform.param<i64>
%6 = transform.merge_handles %1, %3 { deduplicate } : !transform.param<i64>
+ %p3 = num_associations %6 : (!transform.param<i64>) -> !transform.param<i64>
// expected-remark @below {{2}}
- test_print_number_of_associated_payload_ir_params %6 : !transform.param<i64>
+ test_print_param %p3 : !transform.param<i64>
%7 = transform.merge_handles %1, %1, %2, %3 : !transform.param<i64>
+ %p4 = num_associations %7 : (!transform.param<i64>) -> !transform.param<i64>
// expected-remark @below {{4}}
- test_print_number_of_associated_payload_ir_params %7 : !transform.param<i64>
+ test_print_param %p4 : !transform.param<i64>
}
}
@@ -1712,21 +1730,25 @@ transform.sequence failures(propagate) {
%3 = test_produce_value_handle_to_result %1, 1 : (!transform.any_op) -> !transform.any_value
%4 = transform.merge_handles %2, %2 { deduplicate } : !transform.any_value
+ %p = num_associations %4 : (!transform.any_value) -> !transform.param<i64>
// expected-remark @below {{1}}
- test_print_number_of_associated_payload_ir_values %4 : !transform.any_value
+ test_print_param %p : !transform.param<i64>
%5 = transform.merge_handles %2, %3 { deduplicate } : !transform.any_value
+ %p2 = num_associations %5 : (!transform.any_value) -> !transform.param<i64>
// expected-remark @below {{2}}
- test_print_number_of_associated_payload_ir_values %5 : !transform.any_value
+ test_print_param %p2 : !transform.param<i64>
%6 = test_produce_value_handle_to_result %1, 0 : (!transform.any_op) -> !transform.any_value
%7 = transform.merge_handles %2, %6 { deduplicate } : !transform.any_value
+ %p3 = num_associations %6 : (!transform.any_value) -> !transform.param<i64>
// expected-remark @below {{1}}
- test_print_number_of_associated_payload_ir_values %6 : !transform.any_value
+ test_print_param %p3 : !transform.param<i64>
%8 = transform.merge_handles %2, %2, %3, %4 : !transform.any_value
+ %p4 = num_associations %8 : (!transform.any_value) -> !transform.param<i64>
// expected-remark @below {{4}}
- test_print_number_of_associated_payload_ir_values %8 : !transform.any_value
+ test_print_param %p4 : !transform.param<i64>
}
// -----
@@ -1820,31 +1842,37 @@ transform.sequence failures(propagate) {
// There are 3 arith.constant ops.
%all = transform.structured.match ops{["arith.constant"]} in %0 : (!transform.any_op) -> !transform.any_op
+ %p = num_associations %all : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{3}}
- test_print_number_of_associated_payload_ir_ops %all : !transform.any_op
+ test_print_param %p : !transform.param<i64>
// "deduplicate" has no effect because these are 3 different ops.
%merged_before = transform.merge_handles deduplicate %all : !transform.any_op
+ %p2 = num_associations %merged_before : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{3}}
- test_print_number_of_associated_payload_ir_ops %merged_before : !transform.any_op
+ test_print_param %p2 : !transform.param<i64>
// Apply CSE.
transform.apply_cse to %0 : !transform.any_op
// The handle is still mapped to 3 arith.constant ops.
+ %p3 = num_associations %all : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{3}}
- test_print_number_of_associated_payload_ir_ops %all : !transform.any_op
+ test_print_param %p3 : !transform.param<i64>
// But they are all the same op.
%merged_after = transform.merge_handles deduplicate %all : !transform.any_op
+ %p4 = num_associations %merged_after : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{1}}
- test_print_number_of_associated_payload_ir_ops %merged_after : !transform.any_op
+ test_print_param %p4 : !transform.param<i64>
// The other handles were also updated.
test_print_remark_at_operand %elim_first, "eliminated 1" : !transform.any_op
+ %p5 = num_associations %elim_first : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{1}}
- test_print_number_of_associated_payload_ir_ops %elim_first : !transform.any_op
+ test_print_param %p5 : !transform.param<i64>
test_print_remark_at_operand %elim_second, "eliminated 2" : !transform.any_op
+ %p6 = num_associations %elim_second : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{1}}
- test_print_number_of_associated_payload_ir_ops %elim_second : !transform.any_op
+ test_print_param %p6 : !transform.param<i64>
}
// -----
@@ -1907,14 +1935,16 @@ transform.sequence failures(propagate) {
// Get immediate parent.
%2 = transform.get_parent_op %0 : (!transform.any_op) -> !transform.any_op
test_print_remark_at_operand %2, "direct parent" : !transform.any_op
+ %p = num_associations %2 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{2}}
- test_print_number_of_associated_payload_ir_ops %2 : !transform.any_op
+ test_print_param %p : !transform.param<i64>
// Deduplicate results.
%3 = transform.structured.match ops{["test.qux"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%4 = transform.get_parent_op %3 {deduplicate} : (!transform.any_op) -> !transform.any_op
+ %p2 = num_associations %4 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{1}}
- test_print_number_of_associated_payload_ir_ops %4 : !transform.any_op
+ test_print_param %p2 : !transform.param<i64>
}
@@ -2029,8 +2059,9 @@ transform.sequence failures(propagate) {
// Match all ops inside the function (including the function itself).
%func_op = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%0 = transform.structured.match in %func_op : (!transform.any_op) -> !transform.any_op
+ %p = num_associations %0 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{5}}
- test_print_number_of_associated_payload_ir_ops %0 : !transform.any_op
+ test_print_param %p : !transform.param<i64>
// Select "test.foo".
%foo = transform.select "test.foo" in %0 : (!transform.any_op) -> !transform.any_op
@@ -2060,8 +2091,9 @@ transform.sequence failures(propagate) {
%empty_op = transform.structured.match ops{["tensor.empty"]} in %func_op : (!transform.any_op) -> !transform.any_op
transform.apply_dce to %func_op : !transform.any_op
+ %p = num_associations %empty_op : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{0}}
- test_print_number_of_associated_payload_ir_ops %empty_op : !transform.any_op
+ test_print_param %p : !transform.param<i64>
}
diff --git a/mlir/test/Dialect/Transform/test-loop-transforms.mlir b/mlir/test/Dialect/Transform/test-loop-transforms.mlir
index 425962757f720b..c34f4baf8cd972 100644
--- a/mlir/test/Dialect/Transform/test-loop-transforms.mlir
+++ b/mlir/test/Dialect/Transform/test-loop-transforms.mlir
@@ -37,13 +37,16 @@ module attributes {transform.with_named_sequence} {
// Make sure that the handles are still valid (and were updated in case of
// the loop).
+ %p = transform.num_associations %0 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{1}}
- transform.test_print_number_of_associated_payload_ir_ops %0 : !transform.any_op
+ transform.test_print_param %p : !transform.param<i64>
transform.test_print_remark_at_operand %0, "new loop op" : !transform.any_op
+ %p2 = transform.num_associations %1 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{1}}
- transform.test_print_number_of_associated_payload_ir_ops %1 : !transform.any_op
+ transform.test_print_param %p2 : !transform.param<i64>
+ %p3 = transform.num_associations %2 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{1}}
- transform.test_print_number_of_associated_payload_ir_ops %2 : !transform.any_op
+ transform.test_print_param %p3 : !transform.param<i64>
transform.yield
}
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index e8c25aca237251..9c69164e33f600 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -456,51 +456,6 @@ mlir::test::TestMixedSuccessAndSilenceableOp::applyToOne(
return emitDefaultSilenceableFailure(target);
}
-DiagnosedSilenceableFailure
-mlir::test::TestPrintNumberOfAssociatedPayloadIROps::apply(
- transform::TransformRewriter &rewriter,
- transform::TransformResults &results, transform::TransformState &state) {
- if (!getHandle())
- emitRemark() << 0;
- emitRemark() << llvm::range_size(state.getPayloadOps(getHandle()));
- return DiagnosedSilenceableFailure::success();
-}
-
-void mlir::test::TestPrintNumberOfAssociatedPayloadIROps::getEffects(
- SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
- transform::onlyReadsHandle(getHandle(), effects);
-}
-
-DiagnosedSilenceableFailure
-mlir::test::TestPrintNumberOfAssociatedPayloadIRValues::apply(
- transform::TransformRewriter &rewriter,
- transform::TransformResults &results, transform::TransformState &state) {
- if (!getValueHandle())
- emitRemark() << 0;
- emitRemark() << llvm::range_size(state.getPayloadValues(getValueHandle()));
- return DiagnosedSilenceableFailure::success();
-}
-
-void mlir::test::TestPrintNumberOfAssociatedPayloadIRValues::getEffects(
- SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
- transform::onlyReadsHandle(getValueHandle(), effects);
-}
-
-DiagnosedSilenceableFailure
-mlir::test::TestPrintNumberOfAssociatedPayloadIRParams::apply(
- transform::TransformRewriter &rewriter,
- transform::TransformResults &results, transform::TransformState &state) {
- if (!getParam())
- emitRemark() << 0;
- emitRemark() << llvm::range_size(state.getParams(getParam()));
- return DiagnosedSilenceableFailure::success();
-}
-
-void mlir::test::TestPrintNumberOfAssociatedPayloadIRParams::getEffects(
- SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
- transform::onlyReadsHandle(getParam(), effects);
-}
-
DiagnosedSilenceableFailure
mlir::test::TestCopyPayloadOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
index 41f318db68405b..5cb47659fdbdfd 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -343,33 +343,6 @@ def TestMixedSuccessAndSilenceableOp
}];
}
-def TestPrintNumberOfAssociatedPayloadIROps
- : Op<Transform_Dialect, "test_print_number_of_associated_payload_ir_ops",
- [DeclareOpInterfaceMethods<TransformOpInterface>,
- DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
- let arguments = (ins TransformHandleTypeInterface:$handle);
- let assemblyFormat = "$handle attr-dict `:` type($handle)";
- let cppNamespace = "::mlir::test";
-}
-
-def TestPrintNumberOfAssociatedPayloadIRValues
- : Op<Transform_Dialect, "test_print_number_of_associated_payload_ir_values",
- [DeclareOpInterfaceMethods<TransformOpInterface>,
- DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
- let arguments = (ins TransformValueHandleTypeInterface:$value_handle);
- let assemblyFormat = "$value_handle attr-dict `:` type($value_handle)";
- let cppNamespace = "::mlir::test";
-}
-
-def TestPrintNumberOfAssociatedPayloadIRParams
- : Op<Transform_Dialect, "test_print_number_of_associated_payload_ir_params",
- [DeclareOpInterfaceMethods<TransformOpInterface>,
- DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
- let arguments = (ins TransformParamTypeInterface:$param);
- let assemblyFormat = "$param attr-dict `:` type($param)";
- let cppNamespace = "::mlir::test";
-}
-
def TestCopyPayloadOp
: Op<Transform_Dialect, "test_copy_payload",
[DeclareOpInterfaceMethods<TransformOpInterface>,
>From f1b335a9dae6822ce4ca9b5be9f6227d9e97230c Mon Sep 17 00:00:00 2001
From: Alex Zinenko <zinenko at google.com>
Date: Tue, 2 Jan 2024 13:50:24 +0000
Subject: [PATCH 2/3] [mlir] introduce transform.collect_matching
Introduce a new match combinator into the transform dialect. This
operation collects all operations that are yielded by a satisfactory
match into its results. This is a simpler version of `foreach_match`
that can be inserted directly into existing transform scripts.
---
.../mlir/Dialect/Transform/IR/TransformOps.td | 32 +++-
.../lib/Dialect/Transform/IR/TransformOps.cpp | 150 ++++++++++++++++--
mlir/test/Dialect/Transform/ops-invalid.mlir | 68 ++++++++
.../Dialect/Transform/test-interpreter.mlir | 29 ++++
4 files changed, 261 insertions(+), 18 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index da0162faa6e466..18cdbde54db353 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -459,6 +459,36 @@ def NumAssociationsOp : TransformDialectOp<"num_associations",
}];
}
+def CollectMatchingOp : TransformDialectOp<"collect_matching", [
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+ DeclareOpInterfaceMethods<TransformOpInterface>]> {
+ let summary = "Collects all payload ops that match the given named matcher";
+ let description = [{
+ Collects operations nested under `root` or other payload IR objects that
+ match the given matcher expressed as a named sequence. The matcher sequence
+ must accept exactly one argument that it is not allowed to modify. It must
+ yield as many values as this op has results. Each of the yielded values must
+ be associated with exactly one payload object. If any operation in the
+ matcher sequence produces a silenceable failure, the matcher advances to the
+ next payload operation in the walk order without finishing the sequence.
+
+ The results of this operation are constructed by concatenating values
+ yielded by successful application of the matcher named sequence.
+
+ The operation succeeds unless the matcher sequence produced a definite
+ failure for any invocation.
+ }];
+
+ let arguments = (ins TransformHandleTypeInterface:$root,
+ SymbolRefAttr:$matcher);
+ let results = (outs Variadic<Transform_AnyHandleOrParamType>:$results);
+
+ let assemblyFormat = [{
+ $matcher `in` $root attr-dict `:` functional-type($root, $results)
+ }];
+}
+
def ForeachMatchOp : TransformDialectOp<"foreach_match", [
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
@@ -673,7 +703,7 @@ def GetParentOp : TransformDialectOp<"get_parent_op",
def GetProducerOfOperand : TransformDialectOp<"get_producer_of_operand",
[DeclareOpInterfaceMethods<TransformOpInterface>,
- NavigationTransformOpTrait, MemoryEffectsOpInterface]> {
+ NavigationTransformOpTrait, MatchOpInterface, MemoryEffectsOpInterface]> {
let summary = "Get handle to the producer of this operation's operand number";
let description = [{
The handle defined by this Transform op corresponds to operation that
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index ca644252f3514a..76293bfb31719c 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -22,6 +22,7 @@
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionImplementation.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"
@@ -783,7 +784,7 @@ bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
}
//===----------------------------------------------------------------------===//
-// ForeachMatchOp
+// CollectMatchingOp
//===----------------------------------------------------------------------===//
/// Applies matcher operations from the given `block` assigning `op` as the
@@ -822,6 +823,137 @@ matchBlock(Block &block, Operation *op, transform::TransformState &state,
return DiagnosedSilenceableFailure::success();
}
+/// Returns `true` if both types implement one of the interfaces provided as
+/// template parameters.
+template <typename... Tys>
+static bool implementSameInterface(Type t1, Type t2) {
+ return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false);
+}
+
+/// Returns `true` if both types implement one of the transform dialect
+/// interfaces.
+static bool implementSameTransformInterface(Type t1, Type t2) {
+ return implementSameInterface<transform::TransformHandleTypeInterface,
+ transform::TransformParamTypeInterface,
+ transform::TransformValueHandleTypeInterface>(
+ t1, t2);
+}
+
+//===----------------------------------------------------------------------===//
+// CollectMatchingOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ auto matcher = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
+ getOperation(), getMatcher());
+ if (matcher.isExternal()) {
+ return emitDefiniteFailure()
+ << "unresolved external symbol " << getMatcher();
+ }
+
+ SmallVector<SmallVector<MappedValue>, 2> rawResults;
+ rawResults.resize(getOperation()->getNumResults());
+ std::optional<DiagnosedSilenceableFailure> maybeFailure;
+ for (Operation *root : state.getPayloadOps(getRoot())) {
+ WalkResult walkResult = root->walk([&](Operation *op) {
+ DEBUG_MATCHER({
+ DBGS_MATCHER() << "matching ";
+ op->print(llvm::dbgs(),
+ OpPrintingFlags().assumeVerified().skipRegions());
+ llvm::dbgs() << " @" << op << "\n";
+ });
+
+ // Try matching.
+ SmallVector<SmallVector<MappedValue>> mappings;
+ DiagnosedSilenceableFailure diag =
+ matchBlock(matcher.getFunctionBody().front(), op, state, mappings);
+ if (diag.isDefiniteFailure())
+ return WalkResult::interrupt();
+ if (diag.isSilenceableFailure()) {
+ DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName()
+ << " failed: " << diag.getMessage());
+ return WalkResult::advance();
+ }
+
+ // If succeeded, collect results.
+ for (auto &&[i, mapping] : llvm::enumerate(mappings)) {
+ if (mapping.size() != 1) {
+ maybeFailure.emplace(emitSilenceableError()
+ << "result #" << i << ", associated with "
+ << mapping.size()
+ << " payload objects, expected 1");
+ return WalkResult::interrupt();
+ }
+ rawResults[i].push_back(mapping[0]);
+ }
+ return WalkResult::advance();
+ });
+ if (walkResult.wasInterrupted())
+ return std::move(*maybeFailure);
+ assert(!maybeFailure && "failure set but the walk was not interrupted");
+
+ for (auto &&[opResult, rawResult] :
+ llvm::zip(getOperation()->getResults(), rawResults)) {
+ results.setMappedValues(opResult, rawResult);
+ }
+ }
+ return DiagnosedSilenceableFailure::success();
+}
+
+void transform::CollectMatchingOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ onlyReadsHandle(getRoot(), effects);
+ producesHandle(getResults(), effects);
+ onlyReadsPayload(effects);
+}
+
+LogicalResult transform::CollectMatchingOp::verifySymbolUses(
+ SymbolTableCollection &symbolTable) {
+ auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
+ symbolTable.lookupNearestSymbolFrom(getOperation(), getMatcher()));
+ if (!matcherSymbol ||
+ !isa<TransformOpInterface>(matcherSymbol.getOperation()))
+ return emitError() << "unresolved matcher symbol " << getMatcher();
+
+ ArrayRef<Type> argumentTypes = matcherSymbol.getArgumentTypes();
+ if (argumentTypes.size() != 1 ||
+ !isa<TransformHandleTypeInterface>(argumentTypes[0])) {
+ return emitError()
+ << "expected the matcher to take one operation handle argument";
+ }
+ if (!matcherSymbol.getArgAttr(
+ 0, transform::TransformDialect::kArgReadOnlyAttrName)) {
+ return emitError() << "expected the matcher argument to be marked readonly";
+ }
+
+ ArrayRef<Type> resultTypes = matcherSymbol.getResultTypes();
+ if (resultTypes.size() != getOperation()->getNumResults()) {
+ return emitError()
+ << "expected the matcher to yield as many values as op has results ("
+ << getOperation()->getNumResults() << "), got "
+ << resultTypes.size();
+ }
+
+ for (auto &&[i, matcherType, resultType] :
+ llvm::enumerate(resultTypes, getOperation()->getResultTypes())) {
+ if (implementSameTransformInterface(matcherType, resultType))
+ continue;
+
+ return emitError()
+ << "mismatching type interfaces for matcher result and op result #"
+ << i;
+ }
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// ForeachMatchOp
+//===----------------------------------------------------------------------===//
+
DiagnosedSilenceableFailure
transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
@@ -978,22 +1110,6 @@ LogicalResult transform::ForeachMatchOp::verify() {
return success();
}
-/// Returns `true` if both types implement one of the interfaces provided as
-/// template parameters.
-template <typename... Tys>
-static bool implementSameInterface(Type t1, Type t2) {
- return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false);
-}
-
-/// Returns `true` if both types implement one of the transform dialect
-/// interfaces.
-static bool implementSameTransformInterface(Type t1, Type t2) {
- return implementSameInterface<transform::TransformHandleTypeInterface,
- transform::TransformParamTypeInterface,
- transform::TransformValueHandleTypeInterface>(
- t1, t2);
-}
-
/// Checks that the attributes of the function-like operation have correct
/// consumption effect annotations. If `alsoVerifyInternal`, checks for
/// annotations being present even if they can be inferred from the body.
diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir
index 09641615887981..fb8d0c6adf5eb2 100644
--- a/mlir/test/Dialect/Transform/ops-invalid.mlir
+++ b/mlir/test/Dialect/Transform/ops-invalid.mlir
@@ -696,3 +696,71 @@ transform.sequence failures(propagate) {
transform.named_sequence @foo()
} : !transform.any_op
}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+ // expected-error @below {{unresolved matcher symbol @missing_symbol}}
+ transform.collect_matching @missing_symbol in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+ // expected-error @below {{expected the matcher to take one operation handle argument}}
+ transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+
+ transform.named_sequence @matcher() {
+ transform.yield
+ }
+}
+
+// -----
+
+
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+ // expected-error @below {{expected the matcher argument to be marked readonly}}
+ transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+
+ transform.named_sequence @matcher(%arg0: !transform.any_op) {
+ transform.yield
+ }
+}
+
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+ // expected-error @below {{expected the matcher to yield as many values as op has results (1), got 0}}
+ transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+
+ transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) {
+ transform.yield
+ }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+ // expected-error @below {{mismatching type interfaces for matcher result and op result #0}}
+ transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_value
+ transform.yield
+ }
+
+ transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
+ transform.yield %arg0 : !transform.any_op
+ }
+}
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index a39e6f94cb34f6..d24d091a6627a4 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -2167,3 +2167,32 @@ transform.sequence failures(propagate) {
transform.yield
}
}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+ transform.sequence failures(propagate) {
+ ^bb0(%arg0: !transform.any_op):
+ // expected-error @below {{result #0, associated with 2 payload objects, expected 1}}
+ transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+
+ transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
+ %0 = transform.merge_handles %arg0, %arg0 : !transform.any_op
+ transform.yield %0 : !transform.any_op
+ }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+ transform.sequence failures(propagate) {
+ ^bb0(%arg0: !transform.any_op):
+ // expected-error @below {{unresolved external symbol @matcher}}
+ transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+
+ transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op
+}
>From c8aa646c5db16db83d4890f9e25134ab4abbf7f7 Mon Sep 17 00:00:00 2001
From: Alex Zinenko <zinenko at google.com>
Date: Wed, 3 Jan 2024 10:59:31 +0000
Subject: [PATCH 3/3] address review
---
.../mlir/Dialect/Transform/IR/TransformOps.td | 6 ++++--
mlir/lib/Dialect/Transform/IR/TransformOps.cpp | 2 +-
.../Dialect/Transform/test-interpreter.mlir | 18 ++++++++++++++++++
3 files changed, 23 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 18cdbde54db353..3dc70920482d9c 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -473,8 +473,10 @@ def CollectMatchingOp : TransformDialectOp<"collect_matching", [
matcher sequence produces a silenceable failure, the matcher advances to the
next payload operation in the walk order without finishing the sequence.
- The results of this operation are constructed by concatenating values
- yielded by successful application of the matcher named sequence.
+ The i-th result of this operation is constructed by concatenating the i-th
+ yielded payload IR objects of all successful matcher sequence applications.
+ All results are guaranteed to be mapped to the same number of payload IR
+ objects.
The operation succeeds unless the matcher sequence produced a definite
failure for any invocation.
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 76293bfb31719c..1bcc2ba76035e7 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -896,7 +896,7 @@ transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter,
assert(!maybeFailure && "failure set but the walk was not interrupted");
for (auto &&[opResult, rawResult] :
- llvm::zip(getOperation()->getResults(), rawResults)) {
+ llvm::zip_equal(getOperation()->getResults(), rawResults)) {
results.setMappedValues(opResult, rawResult);
}
}
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index d24d091a6627a4..0b758d3c48c1c0 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -2196,3 +2196,21 @@ module attributes { transform.with_named_sequence } {
transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op
}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+ // expected-remark @below {{matched}}
+ transform.sequence failures(propagate) {
+ ^bb0(%arg0: !transform.any_op):
+ // expected-remark @below {{matched}}
+ %0 = transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.test_print_remark_at_operand %0, "matched" : !transform.any_op
+ transform.yield
+ }
+
+ transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
+ transform.match.operation_name %arg0 ["transform.sequence", "transform.collect_matching"] : !transform.any_op
+ transform.yield %arg0 : !transform.any_op
+ }
+}
More information about the Mlir-commits
mailing list