[Mlir-commits] [mlir] [mlir] add a chapter on matchers to the transform dialect tutorial (PR #76725)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Tue Jan 2 06:57:02 PST 2024


https://github.com/ftynse created https://github.com/llvm/llvm-project/pull/76725

These operations have been available for a while, but were not described
in the tutorial. Add a new chapter on using and defining match
operations.

>From 38942f62735a54cffe6848181cdb772fb1bc2285 Mon Sep 17 00:00:00 2001
From: Alex Zinenko <zinenko at google.com>
Date: Tue, 2 Jan 2024 13:31:15 +0000
Subject: [PATCH 1/3] [mlir] introduce transform.num_associations

Add a new transform operation that creates a new parameter containing
the number of payload objects (operations, values or attributes)
associated with the argument. This is useful in matching and for
debugging purposes. This replaces three ad-hoc operations previously
provided by the test extension.
---
 .../mlir/Dialect/Transform/IR/TransformOps.td | 21 ++++
 .../lib/Dialect/Transform/IR/TransformOps.cpp | 29 ++++++
 .../transform-op-bufferize-to-allocation.mlir | 15 ++-
 .../Dialect/Linalg/transform-op-match.mlir    |  5 +-
 .../test/Dialect/Linalg/transform-op-pad.mlir |  3 +-
 .../Dialect/Transform/test-interpreter.mlir   | 96 ++++++++++++-------
 .../Transform/test-loop-transforms.mlir       |  9 +-
 .../TestTransformDialectExtension.cpp         | 45 ---------
 .../TestTransformDialectExtension.td          | 27 ------
 9 files changed, 135 insertions(+), 115 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 307257f4a582be..da0162faa6e466 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -438,6 +438,27 @@ def CastOp : TransformDialectOp<"cast",
   }];
 }
 
+def NumAssociationsOp : TransformDialectOp<"num_associations",
+    [MemoryEffectsOpInterface, ParamProducerTransformOpTrait,
+     DeclareOpInterfaceMethods<TransformOpInterface>,
+     MatchOpInterface]> {
+  let summary =
+    "Returns the number of payload objects associated with the argument";
+  let description = [{
+    Given an argument, handle or parameter, returns a new parameter associated
+    with a single 64-bit number that corresponds to the number of payload
+    objects (operations or values for a handle, attributes for a parameter)
+    associated with the argument.
+
+    Always succeeds.
+  }];
+  let arguments = (ins Transform_AnyHandleOrParamType:$handle);
+  let results = (outs TransformParamTypeInterface:$num);
+  let assemblyFormat = [{
+    $handle attr-dict `:` functional-type(operands, results)
+  }];
+}
+
 def ForeachMatchOp : TransformDialectOp<"foreach_match", [
     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
     DeclareOpInterfaceMethods<SymbolUserOpInterface>,
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 7136e423470a28..ca644252f3514a 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -32,6 +32,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
 #include <optional>
 
@@ -1974,6 +1975,34 @@ void transform::NamedSequenceOp::build(OpBuilder &builder,
                     /*extraBindingTypes=*/TypeRange(), bodyBuilder);
 }
 
+//===----------------------------------------------------------------------===//
+// NumAssociationsOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::NumAssociationsOp::apply(transform::TransformRewriter &rewriter,
+                                    transform::TransformResults &results,
+                                    transform::TransformState &state) {
+  size_t numAssociations =
+      llvm::TypeSwitch<Type, size_t>(getHandle().getType())
+          .Case([&](TransformHandleTypeInterface opHandle) {
+            return llvm::range_size(state.getPayloadOps(getHandle()));
+          })
+          .Case([&](TransformValueHandleTypeInterface valueHandle) {
+            return llvm::range_size(state.getPayloadValues(getHandle()));
+          })
+          .Case([&](TransformParamTypeInterface param) {
+            return llvm::range_size(state.getParams(getHandle()));
+          })
+          .Default([](Type) {
+            llvm_unreachable("unknown kind of transform dialect type");
+            return 0;
+          });
+  results.setParams(getNum().cast<OpResult>(),
+                    rewriter.getI64IntegerAttr(numAssociations));
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // SelectOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir b/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
index 49a52ba9e06f86..aa15ccf0beeee2 100644
--- a/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
@@ -36,13 +36,15 @@ module attributes {transform.with_named_sequence} {
 
     // Ensure that one linalg.fill was generated.
     %fill_op = transform.select "linalg.fill" in %new : (!transform.any_op) -> !transform.any_op
+    %p = transform.num_associations %fill_op : (!transform.any_op) -> !transform.param<i64>
     // expected-remark @below{{1}}
-    transform.test_print_number_of_associated_payload_ir_ops %fill_op : !transform.any_op
+    transform.test_print_param %p : !transform.param<i64>
 
     // Ensure that one linalg.copy was generated.
     %mat = transform.select "bufferization.materialize_in_destination" in %new : (!transform.any_op) -> !transform.any_op
+    %p2 = transform.num_associations %mat : (!transform.any_op) -> !transform.param<i64>
     // expected-remark @below{{1}}
-    transform.test_print_number_of_associated_payload_ir_ops %mat : !transform.any_op
+    transform.test_print_param %p2 : !transform.param<i64>
     transform.yield
   }
 }
@@ -73,18 +75,21 @@ module attributes {transform.with_named_sequence} {
 
     // Ensure that one linalg.fill was generated.
     %fill_op = transform.select "linalg.fill" in %new : (!transform.any_op) -> !transform.any_op
+    %p = transform.num_associations %fill_op : (!transform.any_op) -> !transform.param<i64>
     // expected-remark @below{{1}}
-    transform.test_print_number_of_associated_payload_ir_ops %fill_op : !transform.any_op
+    transform.test_print_param %p : !transform.param<i64>
 
     // Ensure that one linalg.copy was generated.
     %linalg_copy = transform.select "linalg.copy" in %new : (!transform.any_op) -> !transform.any_op
+    %p2 = transform.num_associations %linalg_copy : (!transform.any_op) -> !transform.param<i64>
     // expected-remark @below{{1}}
-    transform.test_print_number_of_associated_payload_ir_ops %linalg_copy : !transform.any_op
+    transform.test_print_param %p2 : !transform.param<i64>
 
     // Ensure that one memref.alloca was generated.
     %alloca = transform.select "memref.alloca" in %new : (!transform.any_op) -> !transform.any_op
+    %p3 = transform.num_associations %alloca : (!transform.any_op) -> !transform.param<i64>
     // expected-remark @below{{1}}
-    transform.test_print_number_of_associated_payload_ir_ops %alloca : !transform.any_op
+    transform.test_print_param %p3 : !transform.param<i64>
 
     // Make sure that One-Shot Bufferize can bufferize the rest.
     %4 = transform.bufferization.one_shot_bufferize %arg1 : (!transform.any_op) -> !transform.any_op
diff --git a/mlir/test/Dialect/Linalg/transform-op-match.mlir b/mlir/test/Dialect/Linalg/transform-op-match.mlir
index 15942db9b5db20..db5b5f1c786776 100644
--- a/mlir/test/Dialect/Linalg/transform-op-match.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-match.mlir
@@ -134,8 +134,9 @@ module attributes {transform.with_named_sequence} {
           #linalg.iterator_type<parallel>,
           #linalg.iterator_type<reduction>]}
         in %arg1 : (!transform.any_op) -> !transform.any_op
-  // expected-remark @below {{0}}
-    transform.test_print_number_of_associated_payload_ir_ops %no_match : !transform.any_op
+    %p = transform.num_associations %no_match : (!transform.any_op) -> !transform.param<i64>
+    // expected-remark @below {{0}}
+    transform.test_print_param %p : !transform.param<i64>
     transform.yield
   }
 }
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
index 6bca6c1fd6bf12..1f9d81a819e7fb 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
@@ -41,8 +41,9 @@ module attributes {transform.with_named_sequence} {
       padding_dimensions=[0, 1, 2],
       pack_paddings=[1, 1, 0]
     } : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.op<"bufferization.materialize_in_destination">)
+    %p = transform.num_associations %copy_back : (!transform.op<"bufferization.materialize_in_destination">) -> !transform.param<i64>
     // expected-remark @below {{1}}
-    transform.test_print_number_of_associated_payload_ir_ops %copy_back : !transform.op<"bufferization.materialize_in_destination">
+    transform.test_print_param %p : !transform.param<i64>
     transform.yield
   }
 }
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index d9a11994eb9d90..a39e6f94cb34f6 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -575,8 +575,9 @@ transform.with_pdl_patterns {
     %0 = pdl_match @addi in %arg1 : (!transform.any_op) -> !transform.any_op
     %1 = pdl_match @addi in %arg1 : (!transform.any_op) -> !transform.any_op
     %2 = merge_handles deduplicate %0, %1 : !transform.any_op
+    %3 = num_associations %2 : (!transform.any_op) -> !transform.param<i64>
     // expected-remark @below {{1}}
-    test_print_number_of_associated_payload_ir_ops %2 : !transform.any_op
+    test_print_param %3 : !transform.param<i64>
   }
 }
 
@@ -676,11 +677,13 @@ module {
     ^bb0(%arg1: !transform.any_op):
       %0 = pdl_match @func in %arg1 : (!transform.any_op) -> !transform.any_op
       %1 = replicate num(%0) %arg1 : !transform.any_op, !transform.any_op
+      %p = num_associations %1 : (!transform.any_op) -> !transform.param<i64>
       // expected-remark @below {{2}}
-      test_print_number_of_associated_payload_ir_ops %1 : !transform.any_op
+      test_print_param %p : !transform.param<i64>
       %2 = replicate num(%0) %1 : !transform.any_op, !transform.any_op
+      %p2 = num_associations %2 : (!transform.any_op) -> !transform.param<i64>
       // expected-remark @below {{4}}
-      test_print_number_of_associated_payload_ir_ops %2 : !transform.any_op
+      test_print_param %p2 : !transform.param<i64>
     }
   }
 }
@@ -708,8 +711,9 @@ transform.with_pdl_patterns {
     %f = pdl_match @const in %arg1 : (!transform.any_op) -> !transform.any_op
     transform.foreach %f : !transform.any_op {
     ^bb2(%arg2: !transform.any_op):
+      %p = transform.num_associations %arg2 : (!transform.any_op) -> !transform.param<i64>
       // expected-remark @below {{1}}
-      transform.test_print_number_of_associated_payload_ir_ops %arg2 : !transform.any_op
+      transform.test_print_param %p : !transform.param<i64>
       transform.test_print_remark_at_operand %arg2, "transform applied" : !transform.any_op
     }
   }
@@ -780,8 +784,9 @@ transform.with_pdl_patterns {
       transform.yield %g : !transform.any_op
     }
 
+    %p = transform.num_associations %results : (!transform.any_op) -> !transform.param<i64>
     // expected-remark @below {{3}}
-    transform.test_print_number_of_associated_payload_ir_ops %results : !transform.any_op
+    transform.test_print_param %p : !transform.param<i64>
     transform.test_print_remark_at_operand %results, "transform applied" : !transform.any_op
   }
 }
@@ -877,8 +882,9 @@ transform.sequence failures(propagate) {
 ^bb1(%fun: !transform.any_op):
   %muli = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
   %h:2 = split_handle %muli : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+  %p = transform.num_associations %h#0 : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below {{1}}
-  transform.test_print_number_of_associated_payload_ir_ops %h#0 : !transform.any_op
+  transform.test_print_param %p : !transform.param<i64>
   %muli_2 = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
   // expected-error @below {{expected to contain 3 payload ops but it contains 2 payload ops}}
   %h_2:3 = split_handle %muli_2 : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
@@ -896,13 +902,15 @@ transform.sequence failures(suppress) {
 ^bb1(%fun: !transform.any_op):
   %muli = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
   %h:2 = split_handle %muli : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+  %p = transform.num_associations %h#0 : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below {{1}}
-  transform.test_print_number_of_associated_payload_ir_ops %h#0 : !transform.any_op
+  transform.test_print_param %p : !transform.param<i64>
   %muli_2 = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
   // Silenceable failure and all handles are now empty.
   %h_2:3 = split_handle %muli_2 : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+  %p2 = transform.num_associations %h_2#0 : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below {{0}}
-  transform.test_print_number_of_associated_payload_ir_ops %h_2#0 : !transform.any_op
+  transform.test_print_param %p2 : !transform.param<i64>
 }
 
 // -----
@@ -918,12 +926,15 @@ transform.sequence failures(propagate) {
   %muli_2 = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
   // No error, last result handle is empty.
   %h:3 = split_handle %muli_2 {fail_on_payload_too_small = false} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+  %p = transform.num_associations %h#0 : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below {{1}}
-  transform.test_print_number_of_associated_payload_ir_ops %h#0 : !transform.any_op
+  transform.test_print_param %p : !transform.param<i64>
+  %p2 = transform.num_associations %h#1 : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below {{1}}
-  transform.test_print_number_of_associated_payload_ir_ops %h#1 : !transform.any_op
+  transform.test_print_param %p2 : !transform.param<i64>
+  %p3 = transform.num_associations %h#2 : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below {{0}}
-  transform.test_print_number_of_associated_payload_ir_ops %h#2 : !transform.any_op
+  transform.test_print_param %p3 : !transform.param<i64>
 }
 
 // -----
@@ -940,10 +951,12 @@ transform.sequence failures(propagate) {
 ^bb1(%fun: !transform.any_op):
   %muli_2 = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
   %h:2 = split_handle %muli_2 {overflow_result = 0} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+  %p = transform.num_associations %h#0 : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below {{3}}
-  transform.test_print_number_of_associated_payload_ir_ops %h#0 : !transform.any_op
+  transform.test_print_param %p : !transform.param<i64>
+  %p2 = transform.num_associations %h#1 : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below {{1}}
-  transform.test_print_number_of_associated_payload_ir_ops %h#1 : !transform.any_op
+  transform.test_print_param %p2 : !transform.param<i64>
 }
 
 // -----
@@ -1668,8 +1681,9 @@ transform.sequence failures(propagate) {
   // expected-remark @below {{2 iterations}}
   transform.test_tracked_rewrite %0 : (!transform.any_op) -> ()
   // One replacement op (test.drop_mapping) is dropped from the mapping.
+  %p = num_associations %0 : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below {{2}}
-  test_print_number_of_associated_payload_ir_ops %0 : !transform.any_op
+  test_print_param %p : !transform.param<i64>
 }
 
 // -----
@@ -1684,20 +1698,24 @@ module {
     %2 = transform.param.constant 1 -> !transform.param<i64>
     %3 = transform.param.constant 2 -> !transform.param<i64>
     %4 = transform.merge_handles %1, %2 { deduplicate } : !transform.param<i64>
+    %p = num_associations %4 : (!transform.param<i64>) -> !transform.param<i64>
     // expected-remark @below {{1}}
-    test_print_number_of_associated_payload_ir_params %4 : !transform.param<i64>
+    test_print_param %p : !transform.param<i64>
 
     %5 = transform.merge_handles %1, %1 { deduplicate } : !transform.param<i64>
+    %p2 = num_associations %5 : (!transform.param<i64>) -> !transform.param<i64>
     // expected-remark @below {{1}}
-    test_print_number_of_associated_payload_ir_params %5 : !transform.param<i64>
+    test_print_param %p2 : !transform.param<i64>
 
     %6 = transform.merge_handles %1, %3 { deduplicate } : !transform.param<i64>
+    %p3 = num_associations %6 : (!transform.param<i64>) -> !transform.param<i64>
     // expected-remark @below {{2}}
-    test_print_number_of_associated_payload_ir_params %6 : !transform.param<i64>
+    test_print_param %p3 : !transform.param<i64>
 
     %7 = transform.merge_handles %1, %1, %2, %3 : !transform.param<i64>
+    %p4 = num_associations %7 : (!transform.param<i64>) -> !transform.param<i64>
     // expected-remark @below {{4}}
-    test_print_number_of_associated_payload_ir_params %7 : !transform.param<i64>
+    test_print_param %p4 : !transform.param<i64>
   }
 }
 
@@ -1712,21 +1730,25 @@ transform.sequence failures(propagate) {
   %3 = test_produce_value_handle_to_result %1, 1 : (!transform.any_op) -> !transform.any_value
 
   %4 = transform.merge_handles %2, %2 { deduplicate } : !transform.any_value
+  %p = num_associations %4 : (!transform.any_value) -> !transform.param<i64>
   // expected-remark @below {{1}}
-  test_print_number_of_associated_payload_ir_values %4 : !transform.any_value
+  test_print_param %p : !transform.param<i64>
 
   %5 = transform.merge_handles %2, %3 { deduplicate } : !transform.any_value
+  %p2 = num_associations %5 : (!transform.any_value) -> !transform.param<i64>
   // expected-remark @below {{2}}
-  test_print_number_of_associated_payload_ir_values %5 : !transform.any_value
+  test_print_param %p2 : !transform.param<i64>
 
   %6 = test_produce_value_handle_to_result %1, 0 : (!transform.any_op) -> !transform.any_value
   %7 = transform.merge_handles %2, %6 { deduplicate } : !transform.any_value
+  %p3 = num_associations %6 : (!transform.any_value) -> !transform.param<i64>
   // expected-remark @below {{1}}
-  test_print_number_of_associated_payload_ir_values %6 : !transform.any_value
+  test_print_param %p3 : !transform.param<i64>
 
   %8 = transform.merge_handles %2, %2, %3, %4 : !transform.any_value
+  %p4 = num_associations %8 : (!transform.any_value) -> !transform.param<i64>
   // expected-remark @below {{4}}
-  test_print_number_of_associated_payload_ir_values %8 : !transform.any_value
+  test_print_param %p4 : !transform.param<i64>
 }
 // -----
 
@@ -1820,31 +1842,37 @@ transform.sequence failures(propagate) {
 
   // There are 3 arith.constant ops.
   %all = transform.structured.match ops{["arith.constant"]} in %0 : (!transform.any_op) -> !transform.any_op
+  %p = num_associations %all : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below{{3}}
-  test_print_number_of_associated_payload_ir_ops %all : !transform.any_op
+  test_print_param %p : !transform.param<i64>
   // "deduplicate" has no effect because these are 3 different ops.
   %merged_before = transform.merge_handles deduplicate %all : !transform.any_op
+  %p2 = num_associations %merged_before : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below{{3}}
-  test_print_number_of_associated_payload_ir_ops %merged_before : !transform.any_op
+  test_print_param %p2 : !transform.param<i64>
 
   // Apply CSE.
   transform.apply_cse to %0 : !transform.any_op
 
   // The handle is still mapped to 3 arith.constant ops.
+  %p3 = num_associations %all : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below{{3}}
-  test_print_number_of_associated_payload_ir_ops %all : !transform.any_op
+  test_print_param %p3 : !transform.param<i64>
   // But they are all the same op.
   %merged_after = transform.merge_handles deduplicate %all : !transform.any_op
+  %p4 = num_associations %merged_after : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below{{1}}
-  test_print_number_of_associated_payload_ir_ops %merged_after : !transform.any_op
+  test_print_param %p4 : !transform.param<i64>
 
   // The other handles were also updated.
   test_print_remark_at_operand %elim_first, "eliminated 1" : !transform.any_op
+  %p5 = num_associations %elim_first : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below{{1}}
-  test_print_number_of_associated_payload_ir_ops %elim_first : !transform.any_op
+  test_print_param %p5 : !transform.param<i64>
   test_print_remark_at_operand %elim_second, "eliminated 2" : !transform.any_op
+  %p6 = num_associations %elim_second : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below{{1}}
-  test_print_number_of_associated_payload_ir_ops %elim_second : !transform.any_op
+  test_print_param %p6 : !transform.param<i64>
 }
 
 // -----
@@ -1907,14 +1935,16 @@ transform.sequence failures(propagate) {
   // Get immediate parent.
   %2 = transform.get_parent_op %0 : (!transform.any_op) -> !transform.any_op
   test_print_remark_at_operand %2, "direct parent" : !transform.any_op
+  %p = num_associations %2 : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below{{2}}
-  test_print_number_of_associated_payload_ir_ops %2 : !transform.any_op
+  test_print_param %p : !transform.param<i64>
 
   // Deduplicate results.
   %3 = transform.structured.match ops{["test.qux"]} in %arg1 : (!transform.any_op) -> !transform.any_op
   %4 = transform.get_parent_op %3 {deduplicate} : (!transform.any_op) -> !transform.any_op
+  %p2 = num_associations %4 : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below{{1}}
-  test_print_number_of_associated_payload_ir_ops %4 : !transform.any_op
+  test_print_param %p2 : !transform.param<i64>
 }
 
 
@@ -2029,8 +2059,9 @@ transform.sequence failures(propagate) {
   // Match all ops inside the function (including the function itself).
   %func_op = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
   %0 = transform.structured.match in %func_op : (!transform.any_op) -> !transform.any_op
+  %p = num_associations %0 : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below{{5}}
-  test_print_number_of_associated_payload_ir_ops %0 : !transform.any_op
+  test_print_param %p : !transform.param<i64>
 
   // Select "test.foo".
   %foo = transform.select "test.foo" in %0 : (!transform.any_op) -> !transform.any_op
@@ -2060,8 +2091,9 @@ transform.sequence failures(propagate) {
   %empty_op = transform.structured.match ops{["tensor.empty"]} in %func_op : (!transform.any_op) -> !transform.any_op
   transform.apply_dce to %func_op : !transform.any_op
 
+  %p = num_associations %empty_op : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below{{0}}
-  test_print_number_of_associated_payload_ir_ops %empty_op : !transform.any_op
+  test_print_param %p : !transform.param<i64>
 }
 
 
diff --git a/mlir/test/Dialect/Transform/test-loop-transforms.mlir b/mlir/test/Dialect/Transform/test-loop-transforms.mlir
index 425962757f720b..c34f4baf8cd972 100644
--- a/mlir/test/Dialect/Transform/test-loop-transforms.mlir
+++ b/mlir/test/Dialect/Transform/test-loop-transforms.mlir
@@ -37,13 +37,16 @@ module attributes {transform.with_named_sequence} {
     // Make sure that the handles are still valid (and were updated in case of
     // the loop).
 
+    %p = transform.num_associations %0 : (!transform.any_op) -> !transform.param<i64>
     // expected-remark @below{{1}}
-    transform.test_print_number_of_associated_payload_ir_ops %0 : !transform.any_op
+    transform.test_print_param %p : !transform.param<i64>
     transform.test_print_remark_at_operand %0, "new loop op" : !transform.any_op
+    %p2 = transform.num_associations %1 : (!transform.any_op) -> !transform.param<i64>
     // expected-remark @below{{1}}
-    transform.test_print_number_of_associated_payload_ir_ops %1 : !transform.any_op
+    transform.test_print_param %p2 : !transform.param<i64>
+    %p3 = transform.num_associations %2 : (!transform.any_op) -> !transform.param<i64>
     // expected-remark @below{{1}}
-    transform.test_print_number_of_associated_payload_ir_ops %2 : !transform.any_op
+    transform.test_print_param %p3 : !transform.param<i64>
 
     transform.yield
   }
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index e8c25aca237251..9c69164e33f600 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -456,51 +456,6 @@ mlir::test::TestMixedSuccessAndSilenceableOp::applyToOne(
   return emitDefaultSilenceableFailure(target);
 }
 
-DiagnosedSilenceableFailure
-mlir::test::TestPrintNumberOfAssociatedPayloadIROps::apply(
-    transform::TransformRewriter &rewriter,
-    transform::TransformResults &results, transform::TransformState &state) {
-  if (!getHandle())
-    emitRemark() << 0;
-  emitRemark() << llvm::range_size(state.getPayloadOps(getHandle()));
-  return DiagnosedSilenceableFailure::success();
-}
-
-void mlir::test::TestPrintNumberOfAssociatedPayloadIROps::getEffects(
-    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
-  transform::onlyReadsHandle(getHandle(), effects);
-}
-
-DiagnosedSilenceableFailure
-mlir::test::TestPrintNumberOfAssociatedPayloadIRValues::apply(
-    transform::TransformRewriter &rewriter,
-    transform::TransformResults &results, transform::TransformState &state) {
-  if (!getValueHandle())
-    emitRemark() << 0;
-  emitRemark() << llvm::range_size(state.getPayloadValues(getValueHandle()));
-  return DiagnosedSilenceableFailure::success();
-}
-
-void mlir::test::TestPrintNumberOfAssociatedPayloadIRValues::getEffects(
-    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
-  transform::onlyReadsHandle(getValueHandle(), effects);
-}
-
-DiagnosedSilenceableFailure
-mlir::test::TestPrintNumberOfAssociatedPayloadIRParams::apply(
-    transform::TransformRewriter &rewriter,
-    transform::TransformResults &results, transform::TransformState &state) {
-  if (!getParam())
-    emitRemark() << 0;
-  emitRemark() << llvm::range_size(state.getParams(getParam()));
-  return DiagnosedSilenceableFailure::success();
-}
-
-void mlir::test::TestPrintNumberOfAssociatedPayloadIRParams::getEffects(
-    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
-  transform::onlyReadsHandle(getParam(), effects);
-}
-
 DiagnosedSilenceableFailure
 mlir::test::TestCopyPayloadOp::apply(transform::TransformRewriter &rewriter,
                                      transform::TransformResults &results,
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
index 41f318db68405b..5cb47659fdbdfd 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -343,33 +343,6 @@ def TestMixedSuccessAndSilenceableOp
   }];
 }
 
-def TestPrintNumberOfAssociatedPayloadIROps
-  : Op<Transform_Dialect, "test_print_number_of_associated_payload_ir_ops",
-       [DeclareOpInterfaceMethods<TransformOpInterface>,
-        DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
-  let arguments = (ins TransformHandleTypeInterface:$handle);
-  let assemblyFormat = "$handle attr-dict `:` type($handle)";
-  let cppNamespace = "::mlir::test";
-}
-
-def TestPrintNumberOfAssociatedPayloadIRValues
-  : Op<Transform_Dialect, "test_print_number_of_associated_payload_ir_values",
-       [DeclareOpInterfaceMethods<TransformOpInterface>,
-        DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
-  let arguments = (ins TransformValueHandleTypeInterface:$value_handle);
-  let assemblyFormat = "$value_handle attr-dict `:` type($value_handle)";
-  let cppNamespace = "::mlir::test";
-}
-
-def TestPrintNumberOfAssociatedPayloadIRParams
-  : Op<Transform_Dialect, "test_print_number_of_associated_payload_ir_params",
-       [DeclareOpInterfaceMethods<TransformOpInterface>,
-        DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
-  let arguments = (ins TransformParamTypeInterface:$param);
-  let assemblyFormat = "$param attr-dict `:` type($param)";
-  let cppNamespace = "::mlir::test";
-}
-
 def TestCopyPayloadOp
   : Op<Transform_Dialect, "test_copy_payload",
        [DeclareOpInterfaceMethods<TransformOpInterface>,

>From f1b335a9dae6822ce4ca9b5be9f6227d9e97230c Mon Sep 17 00:00:00 2001
From: Alex Zinenko <zinenko at google.com>
Date: Tue, 2 Jan 2024 13:50:24 +0000
Subject: [PATCH 2/3] [mlir] introduce transform.collect_matching

Introduce a new match combinator into the transform dialect. This
operation collects all operations that are yielded by a satisfactory
match into its results. This is a simpler version of `foreach_match`
that can be inserted directly into existing transform scripts.
---
 .../mlir/Dialect/Transform/IR/TransformOps.td |  32 +++-
 .../lib/Dialect/Transform/IR/TransformOps.cpp | 150 ++++++++++++++++--
 mlir/test/Dialect/Transform/ops-invalid.mlir  |  68 ++++++++
 .../Dialect/Transform/test-interpreter.mlir   |  29 ++++
 4 files changed, 261 insertions(+), 18 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index da0162faa6e466..18cdbde54db353 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -459,6 +459,36 @@ def NumAssociationsOp : TransformDialectOp<"num_associations",
   }];
 }
 
+def CollectMatchingOp : TransformDialectOp<"collect_matching", [
+    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+    DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+    DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let summary = "Collects all payload ops that match the given named matcher";
+  let description = [{
+    Collects operations nested under `root` or other payload IR objects that
+    match the given matcher expressed as a named sequence. The matcher sequence
+    must accept exactly one argument that it is not allowed to modify. It must
+    yield as many values as this op has results. Each of the yielded values must
+    be associated with exactly one payload object. If any operation in the
+    matcher sequence produces a silenceable failure, the matcher advances to the
+    next payload operation in the walk order without finishing the sequence.
+
+    The results of this operation are constructed by concatenating values
+    yielded by successful application of the matcher named sequence.
+
+    The operation succeeds unless the matcher sequence produced a definite
+    failure for any invocation.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$root,
+                       SymbolRefAttr:$matcher);
+  let results = (outs Variadic<Transform_AnyHandleOrParamType>:$results);
+
+  let assemblyFormat = [{
+    $matcher `in` $root attr-dict `:` functional-type($root, $results)
+  }];
+}
+
 def ForeachMatchOp : TransformDialectOp<"foreach_match", [
     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
     DeclareOpInterfaceMethods<SymbolUserOpInterface>,
@@ -673,7 +703,7 @@ def GetParentOp : TransformDialectOp<"get_parent_op",
 
 def GetProducerOfOperand : TransformDialectOp<"get_producer_of_operand",
     [DeclareOpInterfaceMethods<TransformOpInterface>,
-     NavigationTransformOpTrait, MemoryEffectsOpInterface]> {
+     NavigationTransformOpTrait, MatchOpInterface, MemoryEffectsOpInterface]> {
   let summary = "Get handle to the producer of this operation's operand number";
   let description = [{
     The handle defined by this Transform op corresponds to operation that
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index ca644252f3514a..76293bfb31719c 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -22,6 +22,7 @@
 #include "mlir/IR/Verifier.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/FunctionImplementation.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Pass/PassRegistry.h"
@@ -783,7 +784,7 @@ bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
 }
 
 //===----------------------------------------------------------------------===//
-// ForeachMatchOp
+// CollectMatchingOp
 //===----------------------------------------------------------------------===//
 
 /// Applies matcher operations from the given `block` assigning `op` as the
@@ -822,6 +823,137 @@ matchBlock(Block &block, Operation *op, transform::TransformState &state,
   return DiagnosedSilenceableFailure::success();
 }
 
+/// Returns `true` if both types implement one of the interfaces provided as
+/// template parameters.
+template <typename... Tys>
+static bool implementSameInterface(Type t1, Type t2) {
+  return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false);
+}
+
+/// Returns `true` if both types implement one of the transform dialect
+/// interfaces.
+static bool implementSameTransformInterface(Type t1, Type t2) {
+  return implementSameInterface<transform::TransformHandleTypeInterface,
+                                transform::TransformParamTypeInterface,
+                                transform::TransformValueHandleTypeInterface>(
+      t1, t2);
+}
+
+//===----------------------------------------------------------------------===//
+// CollectMatchingOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter,
+                                    transform::TransformResults &results,
+                                    transform::TransformState &state) {
+  auto matcher = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
+      getOperation(), getMatcher());
+  if (matcher.isExternal())
+    return emitDefiniteFailure()
+           << "unresolved external symbol " << getMatcher();
+
+  // One list of collected payload objects per result of this op.
+  SmallVector<SmallVector<MappedValue>, 2> rawResults;
+  rawResults.resize(getOperation()->getNumResults());
+  std::optional<DiagnosedSilenceableFailure> maybeFailure;
+  for (Operation *root : state.getPayloadOps(getRoot())) {
+    WalkResult walkResult = root->walk([&](Operation *op) {
+      DEBUG_MATCHER({
+        DBGS_MATCHER() << "matching ";
+        op->print(llvm::dbgs(),
+                  OpPrintingFlags().assumeVerified().skipRegions());
+        llvm::dbgs() << " @" << op << "\n";
+      });
+
+      // Try matching; record definite failures for the return path below.
+      SmallVector<SmallVector<MappedValue>> mappings;
+      DiagnosedSilenceableFailure diag =
+          matchBlock(matcher.getFunctionBody().front(), op, state, mappings);
+      if (diag.isDefiniteFailure()) {
+        maybeFailure.emplace(std::move(diag));
+        return WalkResult::interrupt();
+      }
+      if (diag.isSilenceableFailure()) {
+        DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName()
+                                     << " failed: " << diag.getMessage());
+        return WalkResult::advance();
+      }
+      // If succeeded, collect results.
+      for (auto &&[i, mapping] : llvm::enumerate(mappings)) {
+        if (mapping.size() != 1) {
+          maybeFailure.emplace(emitSilenceableError()
+                               << "result #" << i << ", associated with "
+                               << mapping.size()
+                               << " payload objects, expected 1");
+          return WalkResult::interrupt();
+        }
+        rawResults[i].push_back(mapping[0]);
+      }
+      return WalkResult::advance();
+    });
+    if (walkResult.wasInterrupted())
+      return std::move(*maybeFailure);
+    assert(!maybeFailure && "failure set but the walk was not interrupted");
+  }
+  // Set the results once, after all roots have been visited.
+  for (auto &&[opResult, rawResult] :
+       llvm::zip(getOperation()->getResults(), rawResults))
+    results.setMappedValues(opResult, rawResult);
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::CollectMatchingOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  onlyReadsHandle(getRoot(), effects);
+  producesHandle(getResults(), effects);
+  onlyReadsPayload(effects);
+}
+
+LogicalResult transform::CollectMatchingOp::verifySymbolUses(
+    SymbolTableCollection &symbolTable) {
+  auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
+      symbolTable.lookupNearestSymbolFrom(getOperation(), getMatcher()));
+  if (!matcherSymbol ||
+      !isa<TransformOpInterface>(matcherSymbol.getOperation()))
+    return emitError() << "unresolved matcher symbol " << getMatcher();
+
+  ArrayRef<Type> argumentTypes = matcherSymbol.getArgumentTypes();
+  if (argumentTypes.size() != 1 ||
+      !isa<TransformHandleTypeInterface>(argumentTypes[0])) {
+    return emitError()
+           << "expected the matcher to take one operation handle argument";
+  }
+  if (!matcherSymbol.getArgAttr(
+          0, transform::TransformDialect::kArgReadOnlyAttrName)) {
+    return emitError() << "expected the matcher argument to be marked readonly";
+  }
+
+  ArrayRef<Type> resultTypes = matcherSymbol.getResultTypes();
+  if (resultTypes.size() != getOperation()->getNumResults()) {
+    return emitError()
+           << "expected the matcher to yield as many values as op has results ("
+           << getOperation()->getNumResults() << "), got "
+           << resultTypes.size();
+  }
+
+  for (auto &&[i, matcherType, resultType] :
+       llvm::enumerate(resultTypes, getOperation()->getResultTypes())) {
+    if (implementSameTransformInterface(matcherType, resultType))
+      continue;
+
+    return emitError()
+           << "mismatching type interfaces for matcher result and op result #"
+           << i;
+  }
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// ForeachMatchOp
+//===----------------------------------------------------------------------===//
+
 DiagnosedSilenceableFailure
 transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
                                  transform::TransformResults &results,
@@ -978,22 +1110,6 @@ LogicalResult transform::ForeachMatchOp::verify() {
   return success();
 }
 
-/// Returns `true` if both types implement one of the interfaces provided as
-/// template parameters.
-template <typename... Tys>
-static bool implementSameInterface(Type t1, Type t2) {
-  return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false);
-}
-
-/// Returns `true` if both types implement one of the transform dialect
-/// interfaces.
-static bool implementSameTransformInterface(Type t1, Type t2) {
-  return implementSameInterface<transform::TransformHandleTypeInterface,
-                                transform::TransformParamTypeInterface,
-                                transform::TransformValueHandleTypeInterface>(
-      t1, t2);
-}
-
 /// Checks that the attributes of the function-like operation have correct
 /// consumption effect annotations. If `alsoVerifyInternal`, checks for
 /// annotations being present even if they can be inferred from the body.
diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir
index 09641615887981..fb8d0c6adf5eb2 100644
--- a/mlir/test/Dialect/Transform/ops-invalid.mlir
+++ b/mlir/test/Dialect/Transform/ops-invalid.mlir
@@ -696,3 +696,71 @@ transform.sequence failures(propagate) {
     transform.named_sequence @foo()
   } : !transform.any_op
 }
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+    // expected-error @below {{unresolved matcher symbol @missing_symbol}}
+    transform.collect_matching @missing_symbol in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+    // expected-error @below {{expected the matcher to take one operation handle argument}}
+    transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+
+  transform.named_sequence @matcher() {
+    transform.yield
+  }
+}
+
+// -----
+
+
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+    // expected-error @below {{expected the matcher argument to be marked readonly}}
+    transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+
+  transform.named_sequence @matcher(%arg0: !transform.any_op) {
+    transform.yield
+  }
+}
+
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+    // expected-error @below {{expected the matcher to yield as many values as op has results (1), got 0}}
+    transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+
+  transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) {
+    transform.yield
+  }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+    // expected-error @below {{mismatching type interfaces for matcher result and op result #0}}
+    transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_value
+    transform.yield
+  }
+
+  transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
+    transform.yield %arg0 : !transform.any_op
+  }
+}
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index a39e6f94cb34f6..d24d091a6627a4 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -2167,3 +2167,32 @@ transform.sequence failures(propagate) {
   transform.yield 
 }
 }
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+  transform.sequence failures(propagate) {
+  ^bb0(%arg0: !transform.any_op):
+    // expected-error @below {{result #0, associated with 2 payload objects, expected 1}}
+    transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+
+  transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
+    %0 = transform.merge_handles %arg0, %arg0 : !transform.any_op
+    transform.yield %0 : !transform.any_op
+  }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+  transform.sequence failures(propagate) {
+  ^bb0(%arg0: !transform.any_op):
+    // expected-error @below {{unresolved external symbol @matcher}}
+    transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+
+  transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op
+}

>From 28ef1e3d4f4c81546cf473bd44bbbf4ae4d70555 Mon Sep 17 00:00:00 2001
From: Alex Zinenko <zinenko at google.com>
Date: Tue, 2 Jan 2024 14:46:22 +0000
Subject: [PATCH 3/3] [mlir] add a chapter on matchers to the transform dialect
 tutorial

These operations have been available for a while, but were not described
in the tutorial. Add a new chapter on using and defining match
operations.
---
 mlir/docs/Tutorials/transform/Ch4.md          | 553 ++++++++++++++++++
 mlir/examples/transform/CMakeLists.txt        |   1 +
 .../Ch3/transform-opt/transform-opt.cpp       |   2 +-
 mlir/examples/transform/Ch4/CMakeLists.txt    |  21 +
 .../transform/Ch4/include/CMakeLists.txt      |  14 +
 .../transform/Ch4/include/MyExtension.h       |  30 +
 .../transform/Ch4/include/MyExtension.td      |  46 ++
 .../examples/transform/Ch4/lib/CMakeLists.txt |  20 +
 .../transform/Ch4/lib/MyExtension.cpp         | 207 +++++++
 .../Ch4/transform-opt/transform-opt.cpp       |  55 ++
 mlir/test/CMakeLists.txt                      |   1 +
 .../test/Examples/transform/Ch4/features.mlir | 123 ++++
 .../test/Examples/transform/Ch4/multiple.mlir | 131 +++++
 .../test/Examples/transform/Ch4/sequence.mlir | 139 +++++
 mlir/test/lit.cfg.py                          |   1 +
 15 files changed, 1343 insertions(+), 1 deletion(-)
 create mode 100644 mlir/docs/Tutorials/transform/Ch4.md
 create mode 100644 mlir/examples/transform/Ch4/CMakeLists.txt
 create mode 100644 mlir/examples/transform/Ch4/include/CMakeLists.txt
 create mode 100644 mlir/examples/transform/Ch4/include/MyExtension.h
 create mode 100644 mlir/examples/transform/Ch4/include/MyExtension.td
 create mode 100644 mlir/examples/transform/Ch4/lib/CMakeLists.txt
 create mode 100644 mlir/examples/transform/Ch4/lib/MyExtension.cpp
 create mode 100644 mlir/examples/transform/Ch4/transform-opt/transform-opt.cpp
 create mode 100644 mlir/test/Examples/transform/Ch4/features.mlir
 create mode 100644 mlir/test/Examples/transform/Ch4/multiple.mlir
 create mode 100644 mlir/test/Examples/transform/Ch4/sequence.mlir

diff --git a/mlir/docs/Tutorials/transform/Ch4.md b/mlir/docs/Tutorials/transform/Ch4.md
new file mode 100644
index 00000000000000..3cd1fc53308a2b
--- /dev/null
+++ b/mlir/docs/Tutorials/transform/Ch4.md
@@ -0,0 +1,553 @@
+# Chapter 4: Matching Payload with Transform Operations
+
+Up until now, we were applying transform dialect scripts under the assumption
+that specific payload operations are identified by the caller when the transform
+dialect interpreter is invoked. This may be seen as contrary to the idea of
+driving transformations from a dialect since the transformation targets must be
+identified by the caller in C++. It also adds practical overhead due to
+increased interaction with the interpreter in C++, and cognitive overhead of
+manipulating two interfaces at once. To remedy this, Transform dialect proposes
+a subset of operations for _matching_ payload operations that need to be
+transformed.
+
+_Match_ operations are simply transform operations with some additional
+guarantees. In particular, they are not expected to modify the payload IR and
+are expected to fail if their operands (typically payload operation handles) are
+not associated with payload IR objects having desired properties, such as
+operation names or kinds of arguments. Using simple combinator operations, it
+becomes possible to set up a higher-level match and rewrite infrastructure
+directly within the transform dialect.
+
+
+## Simple match
+
+Let us reconsider the “fully connected layer” example from Chapter 1, reproduced
+below for convenience.
+
+
+```mlir
+// Original function to optimize.
+func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>,
+                   %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>)
+                   -> tensor<512x512xf32> {
+  // Matrix-matrix multiplication.
+  %matmul = linalg.matmul
+            ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>)
+            outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32>
+
+  // Elementwise addition.
+  %biased = linalg.elemwise_binary { fun = #linalg.binary_fn<add> }
+    ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>)
+    outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>
+
+  // Elementwise max with 0 (ReLU).
+  %c0f = arith.constant 0.0 : f32
+  %relued = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> }
+    ins(%biased, %c0f : tensor<512x512xf32>, f32)
+    outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>
+  func.return %relued : tensor<512x512xf32>
+}
+
+```
+
+
+In Chapter 1, we were calling the test transform interpreter pass with
+additional arguments, `bind-first-extra-to-ops=linalg.matmul
+bind-second-extra-to-ops=linalg.elemwise_binary`, to provide initial
+associations for operation handles. Instead, we can use match operations to
+discover relevant operations in the payload IR. Match operations can be combined
+with “regular” transform operations using, e.g., the
+`transform.collect_matching` combinator operation that leverages the concept of
+named sequences to organize matchers.
+
+
+```mlir
+// The module containing named sequences must have an attribute allowing them
+// to enable verification.
+module @transforms attributes { transform.with_named_sequence } {
+  // Entry point. This takes as the only argument the root operation (typically
+  // pass root) given to the transform interpreter.
+  transform.named_sequence @__transform_main(
+      %root: !transform.any_op {transform.readonly}) {
+    // Collect operations that match the criteria specified in the named
+    // sequence. If the named sequence fails with a silenceable failure,
+    // silences it (the message is forwarded to the debug stream). If the named
+    // sequence succeeds, appends its results to the results of this operation.
+    %elemwise = transform.collect_matching @match_elemwise in %root
+      : (!transform.any_op) -> !transform.any_op
+    %matmul = transform.collect_matching @match_matmul in %root
+      : (!transform.any_op) -> !transform.any_op
+    transform.include @print_elemwise failures(propagate)  (%elemwise)
+      : (!transform.any_op) -> ()
+    transform.include @print_matmul failures(propagate)  (%matmul)
+      : (!transform.any_op) -> ()
+
+    transform.yield
+  }
+
+  // This is a matcher sequence. It is given an operation to match and the
+  // match is considered successful unless any nested operation produces a
+  // failure. The values yielded by this operation will be forwarded to the
+  // rewriter sequence on success.
+  transform.named_sequence @match_elemwise(
+      %entry: !transform.any_op {transform.readonly}) -> !transform.any_op {
+    transform.match.operation_name %entry ["linalg.elemwise_binary"]
+      : !transform.any_op
+    transform.yield %entry : !transform.any_op
+  }
+  transform.named_sequence @match_matmul(
+      %entry: !transform.any_op {transform.readonly}) -> !transform.any_op {
+    transform.match.operation_name %entry ["linalg.matmul"] : !transform.any_op
+    transform.yield %entry : !transform.any_op
+  }
+
+  // This is a rewriter sequence.
+  transform.named_sequence @print_elemwise(
+      %elemwise_binary: !transform.any_op {transform.readonly}) {
+    transform.test_print_remark_at_operand
+      %elemwise_binary, "elementwise binary" : !transform.any_op
+    transform.yield
+  }
+  transform.named_sequence @print_matmul(
+      %matmul: !transform.any_op {transform.readonly}) {
+    transform.test_print_remark_at_operand %matmul, "matmul" : !transform.any_op
+    transform.yield
+  }
+}
+
+```
+
+
+This script can be executed using the non-test interpreter pass running on the
+root operation of the translation unit without additional flags: `mlir-opt
+--transform-interpreter`. It will emit corresponding remarks at elementwise and
+matmul operations. In debug builds, the infrastructure provides a convenient
+method to understand the matching process by passing
+`-debug-only=transform-matcher` to `mlir-opt` or a derived tool. It will print
+the silenceable failure messages produced by the match operations into the debug
+stream, for example:
+
+
+```
+[transform-matcher] matching %0 = linalg.matmul ins(%arg0, %arg1 : tensor<512x512xf32>, tensor<512x512xf32>) outs(%arg3 : tensor<512x512xf32>) -> tensor<512x512xf32> @0x5622eee08410
+[transform-matcher] matcher match_elemwise failed: wrong operation name
+mlir/test/Examples/transform/Ch4/sequence.mlir:14:13: remark: matmul
+```
+
+
+This is now sufficient to run the rest of the transform script from Chapter 1,
+substituting `%arg1` with `%matmul` and `%arg2` with `%elemwise`.
+
+
+## Matching Chains of Operations
+
+The matcher above remains naive as it matches _all_ operations of a certain
+kind under the payload root. These operations may or may not be related, and
+may, for example, belong to different functions. Even if they are in a single
+function, if there are multiple groups of such operations, we wouldn’t be able
+to differentiate them with this approach. In reality, we want to match a
+specific group of operations where a `matmul` operation produces a result that
+is used by an elementwise operation, which in turn feeds another elementwise
+operation in a similar way.
+
+This can be achieved using the following matcher sequence.
+
+
+```mlir
+// This is also a matcher sequence. It is similarly given an operation to
+// match and nested operations must succeed in order for a match to be deemed
+// successful. It starts matching from the last operation in the use-def chain
+// and goes back because each operand (use) has exactly one definition.
+transform.named_sequence @match_matmul_elemwise(
+    %last: !transform.any_op {transform.readonly})
+    -> (!transform.any_op, !transform.any_op, !transform.any_op) {
+  // The last operation must be an elementwise binary.
+  transform.match.operation_name %last ["linalg.elemwise_binary"]
+    : !transform.any_op
+  // Its first operand must be defined by another operation, to which we
+  // will get a handle here. We are guaranteed that the first operand exists
+  // because we know the operation is binary, but even in absence of such a
+  // guarantee, this operation would have produced a silenceable failure when
+  // `%last` does not have enough operands.
+  %middle = transform.get_producer_of_operand %last[0]
+    : (!transform.any_op) -> !transform.any_op
+  // The defining operation must itself be an elementwise binary.
+  transform.match.operation_name %middle ["linalg.elemwise_binary"]
+    : !transform.any_op
+  // And the first operand of that operation must be defined by yet another
+  // operation.
+  %matmul = transform.get_producer_of_operand %middle[0]
+    : (!transform.any_op) -> !transform.any_op
+  // And that operation is a matmul.
+  transform.match.operation_name %matmul ["linalg.matmul"] : !transform.any_op
+  // We will yield the handles to the matmul and the two elementwise
+  // operations separately.
+  transform.yield %matmul, %middle, %last
+    : !transform.any_op, !transform.any_op, !transform.any_op
+}
+```
+
+This matcher is applicable in presence of other `elemwise` and `matmul`
+operations and will return the triple of _related_ operations rather than
+operations in the order in which they are found.
+
+
+## Defining Match Operations
+
+The matcher of a chain of operations is correct in presence of other operations,
+but is still insufficiently robust for many cases of interest. In particular, it
+requires that the _first_ operand of elementwise operations is produced by
+another operation. The same transformation strategy may however apply regardless
+of the operand position: many binary operations are associative. Let us use this
+opportunity to introduce a new match operation. Specifically, we would like this
+operation to succeed if _any_ of the operands satisfies certain conditions that
+can be expressed as other match operations. We also want it to return some of
+the state and the position of the matched operand in the operand list.
+
+Match operations are defined similarly to other transform operations, with the
+only difference of additionally implementing the `MatchOpInterface`. Note that
+this interface has _no additional methods_ (though it may add some eventually)
+and is only used as a verification contract that the operation is intended for
+matching and will not attempt to transform the payload. The minimal definition
+of our operation is as follows.
+
+
+```tablegen
+// Define the new operation. By convention, prefix its name with `match`
+// followed by the name of the dialect extension.
+def HasOperandSatisfyingOp : TransformDialectOp<"match.my.has_operand_satisfying",
+    [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     DeclareOpInterfaceMethods<TransformOpInterface>,
+     // Indicate that the operation implements MatchOpInterface in addition to
+     // the TransformOpInterface. This interface is only used as a tag at this
+     // point and has no methods that are mandatory to implement.
+     MatchOpInterface,
+     SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">]> {
+  let summary = "Succeed if any of the operands matches all nested criteria";
+  let arguments = (ins TransformHandleTypeInterface:$op);
+  let results = (outs TransformParamTypeInterface:$position,
+                      Variadic<Transform_AnyHandleOrParamType>:$results);
+
+  // Match operations can be arbitrarily complex, e.g., containing regions.
+  let regions = (region SizedRegion<1>:$body);
+  let hasVerifier = 1;
+  let assemblyFormat = [{
+    $op `:` functional-type($op, results) attr-dict-with-keyword $body
+  }];
+}
+```
+
+
+It takes as argument the handle associated with the payload operations whose
+operands it will match, has an associated single-block region containing the
+match criteria, and returns the position of the matched operand as well as any
+other transform value yielded from the body on the successful match.
+
+The matching logic is implemented in the `apply` method of the
+`TransformOpInterface` and is easily composable with other transform operations.
+All facilities for managing the interpreter state and recursively entering the
+blocks are available in the same way as they are for “regular” transform
+operations. Match operations are expected to return a silenceable failure to
+indicate failure to match, and to immediately propagate definite failures. If
+they have nested operations, they are expected to handle and, in most cases,
+silence the silenceable failures produced when applying those operations. For
+our operation, the matching is essentially a loop iterating over all operands of
+the (single) payload operation and applying nested transform ops until they all
+succeed for one of the operands.
+
+
+```cpp
+// Matcher ops implement `apply` similarly to other transform ops. They are not
+// expected to modify payload, but use the tri-state result to signal failure or
+// success to match, as well as potential irrecoverable errors.
+mlir::DiagnosedSilenceableFailure
+mlir::transform::HasOperandSatisfyingOp::apply(
+    mlir::transform::TransformRewriter &rewriter,
+    mlir::transform::TransformResults &results,
+    mlir::transform::TransformState &state) {
+  // For simplicity, only handle a single payload op. Actual implementations
+  // can use `SingleOpMatcher` trait to simplify implementation and document
+  // this expectation.
+  auto payloadOps = state.getPayloadOps(getOp());
+  if (!llvm::hasSingleElement(payloadOps))
+    return emitSilenceableError() << "expected single payload";
+
+  // Iterate over all operands of the payload op to see if they can be matched
+  // using the body of this op.
+  Operation *payload = *payloadOps.begin();
+  for (OpOperand &operand : payload->getOpOperands()) {
+    // Create a scope for transform values defined in the body. This corresponds
+    // to the syntactic scope of the region attached to this op. Any values
+    // associated with payloads from now on will be automatically dissociated
+    // when this object is destroyed, i.e. at the end of the iteration.
+    // Associate the block argument handle with the operand.
+    auto matchScope = state.make_region_scope(getBody());
+    if (failed(state.mapBlockArgument(getBody().getArgument(0),
+                                      {operand.get()}))) {
+      return DiagnosedSilenceableFailure::definiteFailure();
+    }
+
+    // Iterate over all nested matchers with the current mapping and see if they
+    // succeed.
+    bool matchSucceeded = true;
+    for (Operation &matcher : getBody().front().without_terminator()) {
+      // Matcher ops are applied similarly to any other transform op.
+      DiagnosedSilenceableFailure diag =
+          state.applyTransform(cast<TransformOpInterface>(matcher));
+
+      // Definite failures are immediately propagated as they are irrecoverable.
+      if (diag.isDefiniteFailure())
+        return diag;
+
+      // On success, keep checking the remaining conditions.
+      if (diag.succeeded())
+        continue;
+
+      // Report failure-to-match for debugging purposes and stop matching this
+      // operand.
+      assert(diag.isSilenceableFailure());
+      DEBUG_MATCHER(DBGS_MATCHER()
+                    << "failed to match operand #" << operand.getOperandNumber()
+                    << ": " << diag.getMessage());
+      (void)diag.silence();
+      matchSucceeded = false;
+      break;
+    }
+    // If failed to match this operand, try other operands.
+    if (!matchSucceeded)
+      continue;
+
+    // If we reached this point, the matching succeeded for the current operand.
+    // Remap the values associated with terminator operands to be associated
+    // with op results, and also map the parameter result to the operand's
+    // position. Note that it is safe to do here despite the end of the scope
+    // as `results` are integrated into `state` by the interpreter after `apply`
+    // returns rather than immediately.
+    SmallVector<SmallVector<MappedValue>> yieldedMappings;
+    transform::detail::prepareValueMappings(
+        yieldedMappings, getBody().front().getTerminator()->getOperands(),
+        state);
+    results.setParams(getPosition().cast<OpResult>(),
+                      {rewriter.getI32IntegerAttr(operand.getOperandNumber())});
+    for (auto &&[result, mapping] : llvm::zip(getResults(), yieldedMappings))
+      results.setMappedValues(result, mapping);
+    return DiagnosedSilenceableFailure::success();
+  }
+
+  // If we reached this point, none of the operands succeeded the match.
+  return emitSilenceableError()
+         << "none of the operands satisfied the conditions";
+}
+
+```
+
+
+By convention, operations implementing `MatchOpInterface` must not modify
+payload IR and must therefore specify that they only read operand handles and
+payload as their effects.
+
+
+```cpp
+void transform::HasOperandSatisfyingOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  onlyReadsHandle(getOp(), effects); // The operand handle is only read.
+  producesHandle(getOperation()->getResults(), effects); // All op results.
+  onlyReadsPayload(effects); // Matchers must not modify the payload IR.
+}
+```
+
+
+This operation can now be included in a transform dialect extension, loaded and
+used in our matcher. Specifically, we will use it to indicate that either of the
+operands of the “max” elementwise operation in our example can be produced by
+the previous elementwise operation. The previous operation will still require
+the matmul to produce the first operand for simplicity. The updated matcher
+sequence looks as follows.
+
+
+```mlir
+transform.named_sequence @match_matmul_elemwise(
+    %last: !transform.any_op {transform.readonly})
+    -> (!transform.any_op, !transform.any_op, !transform.any_op,
+        !transform.param<i32>) {
+  // The last operation must be an elementwise binary.
+  transform.match.operation_name %last ["linalg.elemwise_binary"]
+    : !transform.any_op
+
+  // One of its operands must be defined by another operation, to which we
+  // will get a handle here. This is achieved thanks to a newly defined
+  // operation that tries to match operands one by one using the match
+  // operations nested in its region.
+  %pos, %middle = transform.match.my.has_operand_satisfying %last
+      : (!transform.any_op) -> (!transform.param<i32>, !transform.any_op) {
+  ^bb0(%operand: !transform.any_value):
+    // The operand must be defined by an operation.
+    %def = transform.get_defining_op %operand
+      : (!transform.any_value) -> !transform.any_op
+    // The defining operation must itself be an elementwise binary.
+    transform.match.operation_name %def ["linalg.elemwise_binary"]
+      : !transform.any_op
+    transform.yield %def : !transform.any_op
+  }
+
+  // And the first operand of that operation must be defined by yet another
+  // operation.
+  %matmul = transform.get_producer_of_operand %middle[0]
+    : (!transform.any_op) -> !transform.any_op
+  // And that operation is a matmul.
+  transform.match.operation_name %matmul ["linalg.matmul"] : !transform.any_op
+  // We will yield the handles to the matmul and the two elementwise
+  // operations separately.
+  transform.yield %matmul, %middle, %last, %pos
+    : !transform.any_op, !transform.any_op, !transform.any_op,
+      !transform.param<i32>
+}
+```
+
+
+This achieves the desired effect and matches both `max(add(matmul(...), bias),
+0)` and `max(0, add(matmul(...), bias))` in the same values. The `%pos` value is
+a transform dialect _parameter_, which is used to store lists of entities known
+to be constant throughout the transform application. Most often, parameters are
+numeric values, but they can generally be any MLIR attributes.
+
+In order to demonstrate that groups of operations are matched independently of
+each other, let us use the `transform.foreach_match` operation that allows one
+to implement a simple high-level pattern rewriting approach within the transform
+dialect (for advanced or lower-level pattern rewriting, consider PDL(L) or C++
+rewriting APIs). It maps a matcher named sequence to an action named sequence,
+and the latter gets invoked whenever the former succeeds.
+
+
+```mlir
+// Traverses the payload IR associated with the operand handle, invoking
+// @match_matmul_elemwise on each of the operations. If the named sequence
+// succeeds, i.e., if none of the nested match (transform) operations
+// produced a silenceable failure, invokes @print_matmul_elemwise and
+// forwards the values yielded as arguments of the new invocation. If the
+// named sequence fails with a silenceable failure, silences it (the message
+// is forwarded to the debug stream). Definite failures are propagated
+// immediately and unconditionally, as usual.
+transform.foreach_match in %root
+  @match_matmul_elemwise -> @print_matmul_elemwise
+  : (!transform.any_op) -> !transform.any_op
+```
+
+
+The `@print_matmul_elemwise` named sequence, available in `multiple.mlir`, will
+use the parameter with the position of the operand to differentiate the two
+groups.
+
+
+## Matchers for Inferred Features
+
+The matcher sequences described above, although useful to drive transformations
+from within the transform dialect interpreter, are rather basic since they
+mostly rely on operation names and use-def chains. Alternative implementations
+using APIs or various declarative rewrite rules are barely less expressive and
+sometimes more concise. The real power of transform dialect matcher ops lies in
+the possibility to define matchers of _inferred properties_ of payloads, i.e.,
+properties that are not directly accessible as an attribute of an operation or
+any straightforward relation between IR components.
+
+The utility of such matchers can be easily demonstrated by slightly modifying
+our original example. If matrix multiplication is expressed as a special case of
+tensor contraction using `linalg.generic` instead of `linalg.matmul`, the
+operation name-based matcher no longer applies. Yet such a representation is
+very common and can appear both in the original input and during the course of
+transformation, e.g., where a higher-dimensional contraction is decomposed into
+loops around a matrix multiplication.
+
+In order to be a (potentially transposed) matrix multiplication, the
+`linalg.generic` operation must have the following features:
+
+
+
+*   Total rank of 3.
+*   Two inputs accessed as projected permutation of iteration dimensions.
+*   One output accessed as projected permutation of iteration dimensions.
+*   Iteration dimensions can be subdivided into LHS parallel, RHS parallel and reduction dimensions.
+*   The body block consists of a multiplication and an addition.
+
+Most of these features can be derived from the properties of the operation,
+e.g., the total rank corresponds to the number of entries in the `iterator_types`
+attribute, but almost none of them are immediately accessible in the IR or in
+any declarative form, which is usually limited to checking the presence or the
+exact match of an attribute or a type.  The transform dialect allows these
+features to be implemented in the `apply` method of a matcher op and reused
+across multiple matching cases. For structured linear algebra payload
+operations, many such match operations are readily available in the `structured`
+extension. They are sufficient to implement a matrix multiplication matcher
+using the features listed above almost verbatim.
+
+
+```mlir
+transform.named_sequence @match_generic_matmul(
+    %candidate: !transform.any_op {transform.readonly}) -> !transform.any_op {
+  // Match a structured linear algebra operation.
+  transform.match.structured %candidate : !transform.any_op {
+  ^bb0(%c: !transform.any_op):
+    // With a rank equal to 3.
+    %rank = transform.match.structured.rank %c
+      : (!transform.any_op) -> !transform.param<i64>
+    %c3 = transform.param.constant 3 : i64 -> !transform.param<i64>
+    transform.match.param.cmpi eq %rank, %c3 : !transform.param<i64>
+
+    // With 2 inputs.
+    %n_ins = transform.match.structured.num_inputs %c
+      : (!transform.any_op) -> !transform.param<i64>
+    %c2 = transform.param.constant 2 : i64 -> !transform.param<i64>
+    transform.match.param.cmpi eq %n_ins, %c2 : !transform.param<i64>
+
+    // With 1 output (note that structured ops in destination passing style
+    // have as many inits as outputs).
+    %n_inits = transform.match.structured.num_inits %c
+      : (!transform.any_op) -> !transform.param<i64>
+    %c1 = transform.param.constant 1 : i64 -> !transform.param<i64>
+    transform.match.param.cmpi eq %n_inits, %c1 : !transform.param<i64>
+
+    // All inputs and inits are accessed with a projected permutation.
+    transform.match.structured.input %c[all] {projected_permutation}
+      : !transform.any_op
+    transform.match.structured.init %c[0] {projected_permutation}
+      : !transform.any_op
+
+    // The body is a mulf/addf contraction with appropriate dimensions.
+    transform.match.structured.body %c
+      { contraction = ["arith.mulf", "arith.addf"] } : !transform.any_op
+    %batch, %lhs, %rhs, %reduction =
+    transform.match.structured.classify_contraction_dims %c
+      : (!transform.any_op)
+      -> (!transform.param<i64>, !transform.param<i64>, !transform.param<i64>,
+          !transform.param<i64>)
+
+
+    // There is exactly one lhs, one rhs and one reduction dimension, and no
+    // batch dimensions.
+    %n_batch = transform.num_associations %batch
+      : (!transform.param<i64>) -> !transform.param<i64>
+    %n_lhs = transform.num_associations %lhs
+      : (!transform.param<i64>) -> !transform.param<i64>
+    %n_rhs = transform.num_associations %rhs
+      : (!transform.param<i64>) -> !transform.param<i64>
+    %n_reduction = transform.num_associations %reduction
+      : (!transform.param<i64>) -> !transform.param<i64>
+    %c0 = transform.param.constant 0 : i64 -> !transform.param<i64>
+    transform.match.param.cmpi eq %n_batch, %c0 : !transform.param<i64>
+    transform.match.param.cmpi eq %n_lhs, %c1 : !transform.param<i64>
+    transform.match.param.cmpi eq %n_rhs, %c1 : !transform.param<i64>
+    transform.match.param.cmpi eq %n_reduction, %c1 : !transform.param<i64>
+  }
+  transform.yield %candidate : !transform.any_op
+}
+```
+
+
+While this example leverages the contraction-specific matchers that have a
+rather non-trivial C++ implementation, the transform dialect is sufficiently
+flexible to implement this reasoning directly if desired. One could, for
+example, obtain the access map of each input as a parameter and extract the
+accessed dimensions as other parameters that can be compared with each other to
+ensure the subscripts are `m,k` for LHS, `k,n` for RHS and `m,n` for the
+init/result given the `m,n,k` notation for loops.
+
diff --git a/mlir/examples/transform/CMakeLists.txt b/mlir/examples/transform/CMakeLists.txt
index 3f3740ad2a8da7..b688aa7461d6f2 100644
--- a/mlir/examples/transform/CMakeLists.txt
+++ b/mlir/examples/transform/CMakeLists.txt
@@ -2,3 +2,4 @@ add_custom_target(TransformExample)
 
 add_subdirectory(Ch2)
 add_subdirectory(Ch3)
+add_subdirectory(Ch4)
diff --git a/mlir/examples/transform/Ch3/transform-opt/transform-opt.cpp b/mlir/examples/transform/Ch3/transform-opt/transform-opt.cpp
index 1e4367ad469025..3c348c663abad4 100644
--- a/mlir/examples/transform/Ch3/transform-opt/transform-opt.cpp
+++ b/mlir/examples/transform/Ch3/transform-opt/transform-opt.cpp
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This is the top-level file for the Transform dialect tutorial chapter 2.
+// This is the top-level file for the Transform dialect tutorial chapter 3.
 //
 //===----------------------------------------------------------------------===//
 
diff --git a/mlir/examples/transform/Ch4/CMakeLists.txt b/mlir/examples/transform/Ch4/CMakeLists.txt
new file mode 100644
index 00000000000000..c070a04a35a80c
--- /dev/null
+++ b/mlir/examples/transform/Ch4/CMakeLists.txt
@@ -0,0 +1,21 @@
+# For a better top-level template to copy, see examples/standalone.
+
+include_directories(${CMAKE_CURRENT_BINARY_DIR})
+include_directories(${CMAKE_CURRENT_BINARY_DIR}/include)
+include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
+
+add_subdirectory(include)
+add_subdirectory(lib)
+
+add_dependencies(TransformExample transform-opt-ch4)
+add_llvm_example(transform-opt-ch4
+  transform-opt/transform-opt.cpp)
+
+target_link_libraries(transform-opt-ch4
+  PRIVATE
+  MLIRIR
+  MLIRMlirOptMain
+  MLIRSideEffectInterfaces
+  MLIRTransformDialectTransforms
+  MyExtensionCh4
+)
diff --git a/mlir/examples/transform/Ch4/include/CMakeLists.txt b/mlir/examples/transform/Ch4/include/CMakeLists.txt
new file mode 100644
index 00000000000000..1f960e590529b3
--- /dev/null
+++ b/mlir/examples/transform/Ch4/include/CMakeLists.txt
@@ -0,0 +1,14 @@
+# Tell Tablegen to use MyExtension.td as input.
+set(LLVM_TARGET_DEFINITIONS MyExtension.td)
+
+# Ask Tablegen to generate op declarations and definitions from ODS.
+mlir_tablegen(MyExtension.h.inc -gen-op-decls)
+mlir_tablegen(MyExtension.cpp.inc -gen-op-defs)
+
+# Add a CMakeTarget we can depend on to ensure the generation happens before the
+# compilation.
+add_public_tablegen_target(MyExtensionCh4IncGen)
+
+# Don't forget to generate the documentation, this will produce a
+# MyExtensionCh4.md under Tutorials/transform
+add_mlir_doc(MyExtension MyExtensionCh4 Tutorials/transform/ -gen-op-doc)
diff --git a/mlir/examples/transform/Ch4/include/MyExtension.h b/mlir/examples/transform/Ch4/include/MyExtension.h
new file mode 100644
index 00000000000000..13e5b3c04b02f1
--- /dev/null
+++ b/mlir/examples/transform/Ch4/include/MyExtension.h
@@ -0,0 +1,30 @@
+//===-- MyExtension.h - Transform dialect tutorial --------------*- c++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines Transform dialect extension operations used in
+// Chapter 4 of the Transform dialect tutorial.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
+
+namespace mlir {
+class CallOpInterface;
+namespace func {
+class CallOp;
+} // namespace func
+} // namespace mlir
+
+#define GET_OP_CLASSES
+#include "MyExtension.h.inc"
+
+// Registers our Transform dialect extension.
+void registerMyExtension(::mlir::DialectRegistry &registry);
diff --git a/mlir/examples/transform/Ch4/include/MyExtension.td b/mlir/examples/transform/Ch4/include/MyExtension.td
new file mode 100644
index 00000000000000..ae58dc37db43fc
--- /dev/null
+++ b/mlir/examples/transform/Ch4/include/MyExtension.td
@@ -0,0 +1,46 @@
+//===-- MyExtension.td - Transform dialect tutorial --------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines Transform dialect extension operations used in
+// Chapter 4 of the Transform dialect tutorial.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MY_EXTENSION
+#define MY_EXTENSION
+
+include "mlir/Dialect/Transform/IR/MatchInterfaces.td"
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
+include "mlir/IR/OpBase.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+// Define the new operation. By convention, prefix its name with `match`
+// followed by the name of the dialect extension.
+def HasOperandSatisfyingOp : TransformDialectOp<"match.my.has_operand_satisfying",
+    [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     DeclareOpInterfaceMethods<TransformOpInterface>,
+     // Indicate that the operation implements MatchOpInterface in addition to
+     // the TransformOpInterface. This interface is only used as a tag at this
+     // point and has no methods that are mandatory to implement.
+     MatchOpInterface,
+     SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">]> {
+  let summary = "Succeed if any of the operands matches all nested criteria";
+  let arguments = (ins TransformHandleTypeInterface:$op);
+  let results = (outs TransformParamTypeInterface:$position,
+                      Variadic<Transform_AnyHandleOrParamType>:$results);
+
+  // Match operations can be arbitrarily complex, e.g., containing regions.
+  let regions = (region SizedRegion<1>:$body);
+  let hasVerifier = 1;
+  let assemblyFormat = [{
+    $op `:` functional-type($op, results) attr-dict-with-keyword $body
+  }];
+}
+
+#endif // MY_EXTENSION
diff --git a/mlir/examples/transform/Ch4/lib/CMakeLists.txt b/mlir/examples/transform/Ch4/lib/CMakeLists.txt
new file mode 100644
index 00000000000000..33338a679af3ce
--- /dev/null
+++ b/mlir/examples/transform/Ch4/lib/CMakeLists.txt
@@ -0,0 +1,20 @@
+# Outside examples, this should be `add_mlir_library`.
+add_mlir_example_library(
+  # Library called MyExtensionCh4.
+  MyExtensionCh4
+
+  # Built from the following source files.
+  MyExtension.cpp
+
+  # Make includes visible without top-level path.
+  ADDITIONAL_HEADER_DIRS
+  ${PROJECT_SOURCE_DIR}/examples/transform/Ch4/include
+
+  # Make sure ODS declaration and definitions are generated before compiling this.
+  DEPENDS
+  MyExtensionCh4IncGen
+
+  # Link in the transform dialect, and all generated dialects.
+  LINK_LIBS PRIVATE
+  MLIRTransformDialect
+)
diff --git a/mlir/examples/transform/Ch4/lib/MyExtension.cpp b/mlir/examples/transform/Ch4/lib/MyExtension.cpp
new file mode 100644
index 00000000000000..26e348f2a30ec6
--- /dev/null
+++ b/mlir/examples/transform/Ch4/lib/MyExtension.cpp
@@ -0,0 +1,207 @@
+//===-- MyExtension.cpp - Transform dialect tutorial ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines Transform dialect extension operations used in
+// Chapter 4 of the Transform dialect tutorial.
+//
+//===----------------------------------------------------------------------===//
+
+#include "MyExtension.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE_MATCHER "transform-matcher"
+#define DBGS_MATCHER() (llvm::dbgs() << "[" DEBUG_TYPE_MATCHER "] ")
+#define DEBUG_MATCHER(x) DEBUG_WITH_TYPE(DEBUG_TYPE_MATCHER, x)
+
+#define GET_OP_CLASSES
+#include "MyExtension.cpp.inc"
+
+//===---------------------------------------------------------------------===//
+// MyExtension
+//===---------------------------------------------------------------------===//
+
+// Define a new transform dialect extension. This uses the CRTP idiom to
+// identify extensions.
+class MyExtension
+    : public ::mlir::transform::TransformDialectExtension<MyExtension> {
+public:
+  // The extension must inherit the base constructor.
+  using Base::Base;
+
+  // This function initializes the extension, similarly to `initialize` in
+  // dialect definitions. List individual operations and dependent dialects
+  // here.
+  void init();
+};
+
+void MyExtension::init() {
+  // Register the additional match operations with the dialect similarly to
+  // other transform operations. List all operations generated from ODS. This
+  // call will perform additional checks that the operations implement the
+  // transform and memory effect interfaces required by the dialect interpreter
+  // and assert if they do not.
+  registerTransformOps<
+#define GET_OP_LIST
+#include "MyExtension.cpp.inc"
+      >();
+}
+
+//===---------------------------------------------------------------------===//
+// HasOperandSatisfyingOp
+//===---------------------------------------------------------------------===//
+
+/// Returns `true` if, for at least one of the interfaces provided as template
+/// parameters, both `t1` and `t2` implement that same interface.
+template <typename... Tys>
+static bool implementSameInterface(mlir::Type t1, mlir::Type t2) {
+  return ((llvm::isa<Tys>(t1) && llvm::isa<Tys>(t2)) || ... || false);
+}
+
+/// Returns `true` if both types implement the same transform dialect type
+/// interface: operation handle, parameter or value handle.
+static bool implementSameTransformInterface(mlir::Type t1, mlir::Type t2) {
+  return implementSameInterface<
+      mlir::transform::TransformHandleTypeInterface,
+      mlir::transform::TransformParamTypeInterface,
+      mlir::transform::TransformValueHandleTypeInterface>(t1, t2);
+}
+
+// Matcher ops implement `apply` similarly to other transform ops. They are not
+// expected to modify payload, but use the tri-state result to signal failure or
+// success to match, as well as potential irrecoverable errors.
+mlir::DiagnosedSilenceableFailure
+mlir::transform::HasOperandSatisfyingOp::apply(
+    mlir::transform::TransformRewriter &rewriter,
+    mlir::transform::TransformResults &results,
+    mlir::transform::TransformState &state) {
+  // For simplicity, only handle a single payload op. Actual implementations
+  // can use `SingleOpMatcher` trait to simplify implementation and document
+  // this expectation.
+  auto payloadOps = state.getPayloadOps(getOp());
+  if (!llvm::hasSingleElement(payloadOps))
+    return emitSilenceableError() << "expected single payload";
+
+  // Iterate over all operands of the payload op to see if they can be matched
+  // using the body of this op.
+  Operation *payload = *payloadOps.begin();
+  for (OpOperand &operand : payload->getOpOperands()) {
+    // Create a scope for transform values defined in the body. This corresponds
+    // to the syntactic scope of the region attached to this op. Any values
+    // associated with payloads from now on will be automatically dissociated
+    // when this object is destroyed, i.e. at the end of the iteration. Then
+    // associate the block argument handle with the value of this operand.
+    auto matchScope = state.make_region_scope(getBody());
+    if (failed(state.mapBlockArgument(getBody().getArgument(0),
+                                      {operand.get()}))) {
+      return DiagnosedSilenceableFailure::definiteFailure();
+    }
+
+    // Iterate over all nested matchers with the current mapping and see if they
+    // succeed.
+    bool matchSucceeded = true;
+    for (Operation &matcher : getBody().front().without_terminator()) {
+      // Matcher ops are applied similarly to any other transform op.
+      DiagnosedSilenceableFailure diag =
+          state.applyTransform(cast<TransformOpInterface>(matcher));
+
+      // Definite failures are immediately propagated as they are irrecoverable.
+      if (diag.isDefiniteFailure())
+        return diag;
+
+      // On success, keep checking the remaining conditions.
+      if (diag.succeeded())
+        continue;
+
+      // Report failure-to-match for debugging purposes and stop matching this
+      // operand.
+      assert(diag.isSilenceableFailure());
+      DEBUG_MATCHER(DBGS_MATCHER()
+                    << "failed to match operand #" << operand.getOperandNumber()
+                    << ": " << diag.getMessage());
+      (void)diag.silence();
+      matchSucceeded = false;
+      break;
+    }
+    // If failed to match this operand, try other operands.
+    if (!matchSucceeded)
+      continue;
+
+    // If we reached this point, the matching succeeded for the current operand.
+    // Remap the values associated with terminator operands to be associated
+    // with op results, and also map the parameter result to the operand's
+    // position. Note that it is safe to do here despite the end of the scope
+    // as `results` are integrated into `state` by the interpreter after `apply`
+    // returns rather than immediately.
+    SmallVector<SmallVector<MappedValue>> yieldedMappings;
+    transform::detail::prepareValueMappings(
+        yieldedMappings, getBody().front().getTerminator()->getOperands(),
+        state);
+    results.setParams(cast<OpResult>(getPosition()),
+                      {rewriter.getI32IntegerAttr(operand.getOperandNumber())});
+    for (auto &&[result, mapping] : llvm::zip(getResults(), yieldedMappings))
+      results.setMappedValues(result, mapping);
+    return DiagnosedSilenceableFailure::success();
+  }
+
+  // If we reached this point, none of the operands matched the nested criteria.
+  return emitSilenceableError()
+         << "none of the operands satisfied the conditions";
+}
+
+// By convention, operations implementing MatchOpInterface must not modify
+// payload IR: declare that this op only reads the payload and its operand
+// handle, and that it produces the handles/parameters defined as its results.
+void mlir::transform::HasOperandSatisfyingOp::getEffects(
+    llvm::SmallVectorImpl<mlir::MemoryEffects::EffectInstance> &effects) {
+  onlyReadsPayload(effects);
+  onlyReadsHandle(getOp(), effects);
+  producesHandle(getPosition(), effects);
+  producesHandle(getResults(), effects);
+}
+
+// Verify the body: one value-handle block argument, a terminator yielding one
+// value per trailing result with a matching interface, and only match ops.
+mlir::LogicalResult mlir::transform::HasOperandSatisfyingOp::verify() {
+  mlir::Block &bodyBlock = getBody().front();
+  if (bodyBlock.getNumArguments() != 1 ||
+      !isa<TransformValueHandleTypeInterface>(
+          bodyBlock.getArgument(0).getType())) {
+    return emitOpError()
+           << "expects the body to have one value handle argument";
+  }
+  if (bodyBlock.getTerminator()->getNumOperands() != getNumResults() - 1) {
+    return emitOpError() << "expects the body to yield "
+                         << (getNumResults() - 1) << " values, got "
+                         << bodyBlock.getTerminator()->getNumOperands();
+  }
+  for (auto &&[i, operand, result] :
+       llvm::enumerate(bodyBlock.getTerminator()->getOperands().getTypes(),
+                       getResults().getTypes())) {
+    if (implementSameTransformInterface(operand, result))
+      continue;
+    return emitOpError() << "expects terminator operand #" << i
+                         << " and result #" << (i + 1)
+                         << " to implement the same transform interface";
+  }
+
+  for (Operation &op : bodyBlock.without_terminator()) {
+    if (!isa<TransformOpInterface>(op) || !isa<MatchOpInterface>(op)) {
+      InFlightDiagnostic diag = emitOpError()
+                                << "expects body to contain match ops";
+      diag.attachNote(op.getLoc()) << "non-match operation";
+      return diag;
+    }
+  }
+
+  return success();
+}
+
+void registerMyExtension(::mlir::DialectRegistry &registry) {
+  registry.addExtensions<MyExtension>();
+}
diff --git a/mlir/examples/transform/Ch4/transform-opt/transform-opt.cpp b/mlir/examples/transform/Ch4/transform-opt/transform-opt.cpp
new file mode 100644
index 00000000000000..10190664b51cdf
--- /dev/null
+++ b/mlir/examples/transform/Ch4/transform-opt/transform-opt.cpp
@@ -0,0 +1,55 @@
+//===-- transform-opt.cpp - Transform dialect tutorial entry point --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the top-level file for the Transform dialect tutorial chapter 4.
+//
+//===----------------------------------------------------------------------===//
+
+#include "MyExtension.h"
+
+#include "mlir/Dialect/Transform/Transforms/Passes.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/InitAllDialects.h"
+#include "mlir/InitAllExtensions.h"
+#include "mlir/Tools/mlir-opt/MlirOptMain.h"
+#include "mlir/Transforms/Passes.h"
+#include <cstdlib>
+
+namespace test {
+void registerTestTransformDialectExtension(mlir::DialectRegistry &);
+} // namespace test
+
+int main(int argc, char **argv) {
+  // Register all "core" dialects and our transform dialect extension.
+  mlir::DialectRegistry registry;
+  mlir::registerAllDialects(registry);
+  mlir::registerAllExtensions(registry);
+  registerMyExtension(registry);
+
+  // Register a handful of cleanup passes that we can run to make the output IR
+  // look nicer.
+  mlir::registerCanonicalizerPass();
+  mlir::registerCSEPass();
+  mlir::registerSymbolDCEPass();
+  mlir::transform::registerInterpreterPass();
+
+  // Register the test transform dialect extension when tests are built.
+#ifdef MLIR_INCLUDE_TESTS
+  test::registerTestTransformDialectExtension(registry);
+#else
+  llvm::errs() << "warning: MLIR built without test extension, interpreter "
+                  "testing will not be available\n";
+#endif // MLIR_INCLUDE_TESTS
+
+  // Delegate to the MLIR utility for parsing and pass management.
+  return mlir::MlirOptMain(argc, argv, "transform-opt-ch4", registry)
+                 .succeeded()
+             ? EXIT_SUCCESS
+             : EXIT_FAILURE;
+}
diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt
index 3f312164cb1f35..be90407f54d01a 100644
--- a/mlir/test/CMakeLists.txt
+++ b/mlir/test/CMakeLists.txt
@@ -161,6 +161,7 @@ if(LLVM_BUILD_EXAMPLES)
     toyc-ch5
     transform-opt-ch2
     transform-opt-ch3
+    transform-opt-ch4
     mlir-minimal-opt
     )
   if(MLIR_ENABLE_EXECUTION_ENGINE)
diff --git a/mlir/test/Examples/transform/Ch4/features.mlir b/mlir/test/Examples/transform/Ch4/features.mlir
new file mode 100644
index 00000000000000..9a2af474aa4fa6
--- /dev/null
+++ b/mlir/test/Examples/transform/Ch4/features.mlir
@@ -0,0 +1,123 @@
+// RUN: transform-opt-ch4 %s --transform-interpreter --verify-diagnostics
+
+// Matmul as a named operation.
+func.func @named(
+    %lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>,
+    %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>)
+    -> tensor<512x512xf32> {
+  // expected-remark @below {{matmul}}
+  %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>)
+                          outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32>
+  func.return %matmul : tensor<512x512xf32>
+}
+
+// Matmul as a generic operation.
+func.func @generic(
+    %lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>,
+    %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>)
+    -> tensor<512x512xf32> {
+  // expected-remark @below {{matmul}}
+  %matmul = linalg.generic {
+    iterator_types = ["parallel", "parallel", "reduction"],
+    indexing_maps = [
+      affine_map<(d0, d1, d2) -> (d0, d2)>,
+      affine_map<(d0, d1, d2) -> (d2, d1)>,
+      affine_map<(d0, d1, d2) -> (d0, d1)>]
+  } ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>)
+    outs(%output: tensor<512x512xf32>) {
+  ^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
+    %0 = arith.mulf %arg0, %arg1 : f32
+    %1 = arith.addf %0, %arg2 : f32
+    linalg.yield %1 : f32
+  } -> tensor<512x512xf32>
+  return %matmul : tensor<512x512xf32>
+}
+
+// The module containing named sequences must have the
+// `transform.with_named_sequence` attribute to enable their verification.
+module @transforms attributes { transform.with_named_sequence } {
+  // Entry point. This takes as the only argument the root operation (typically
+  // pass root) given to the transform interpreter.
+  transform.named_sequence @__transform_main(
+      %root: !transform.any_op {transform.consumed}) {
+
+    // Traverses the payload IR associated with the operand handle, invoking
+    // @match_generic_matmul on each of the operations. If the named sequence
+    // succeeds, i.e., if none of the nested match (transform) operations
+    // produced a silenceable failure, invokes @print_generic_matmul and
+    // forwards the values yielded as arguments of the new invocation. If the
+    // named sequence fails with a silenceable failure, silences it (the message
+    // is forwarded to the debug stream). Definite failures are propagated
+    // immediately and unconditionally, as usual.
+    transform.foreach_match in %root
+      @match_generic_matmul -> @print_generic_matmul
+      : (!transform.any_op) -> !transform.any_op
+
+    transform.yield
+  }
+
+  // This is an action sequence.
+  transform.named_sequence @print_generic_matmul(
+      %matmul: !transform.any_op {transform.readonly}) {
+    transform.test_print_remark_at_operand %matmul, "matmul" : !transform.any_op
+    transform.yield
+  }
+
+  transform.named_sequence @match_generic_matmul(
+      %candidate: !transform.any_op {transform.readonly}) -> !transform.any_op {
+    // Match a structured linear algebra operation.
+    transform.match.structured %candidate : !transform.any_op {
+    ^bb0(%c: !transform.any_op):
+      // With a rank equal to 3.
+      %rank = transform.match.structured.rank %c
+        : (!transform.any_op) -> !transform.param<i64>
+      %c3 = transform.param.constant 3 : i64 -> !transform.param<i64>
+      transform.match.param.cmpi eq %rank, %c3 : !transform.param<i64>
+
+      // With 2 inputs.
+      %n_ins = transform.match.structured.num_inputs %c
+        : (!transform.any_op) -> !transform.param<i64>
+      %c2 = transform.param.constant 2 : i64 -> !transform.param<i64>
+      transform.match.param.cmpi eq %n_ins, %c2 : !transform.param<i64>
+
+      // With 1 output (note that structured ops in destination passing style
+      // have as many inits as outputs).
+      %n_inits = transform.match.structured.num_inits %c
+        : (!transform.any_op) -> !transform.param<i64>
+      %c1 = transform.param.constant 1 : i64 -> !transform.param<i64>
+      transform.match.param.cmpi eq %n_inits, %c1 : !transform.param<i64>
+
+      // All inputs and inits are accessed with a projected permutation.
+      transform.match.structured.input %c[all] {projected_permutation}
+        : !transform.any_op
+      transform.match.structured.init %c[0] {projected_permutation}
+        : !transform.any_op
+
+      // The body is a mulf/addf contraction with appropriate dimensions.
+      transform.match.structured.body %c 
+        { contraction = ["arith.mulf", "arith.addf"] } : !transform.any_op
+      %batch, %lhs, %rhs, %reduction =
+      transform.match.structured.classify_contraction_dims %c
+        : (!transform.any_op)
+        -> (!transform.param<i64>, !transform.param<i64>, !transform.param<i64>,
+            !transform.param<i64>)
+
+      // There is exactly one lhs, one rhs and one reduction dimension, and no
+      // batch dimensions.
+      %n_batch = transform.num_associations %batch
+        : (!transform.param<i64>) -> !transform.param<i64>
+      %n_lhs = transform.num_associations %lhs
+        : (!transform.param<i64>) -> !transform.param<i64>
+      %n_rhs = transform.num_associations %rhs
+        : (!transform.param<i64>) -> !transform.param<i64>
+      %n_reduction = transform.num_associations %reduction
+        : (!transform.param<i64>) -> !transform.param<i64>
+      %c0 = transform.param.constant 0 : i64 -> !transform.param<i64>
+      transform.match.param.cmpi eq %n_batch, %c0 : !transform.param<i64>
+      transform.match.param.cmpi eq %n_lhs, %c1 : !transform.param<i64>
+      transform.match.param.cmpi eq %n_rhs, %c1 : !transform.param<i64>
+      transform.match.param.cmpi eq %n_reduction, %c1 : !transform.param<i64>
+    }
+    transform.yield %candidate : !transform.any_op
+  }
+}
diff --git a/mlir/test/Examples/transform/Ch4/multiple.mlir b/mlir/test/Examples/transform/Ch4/multiple.mlir
new file mode 100644
index 00000000000000..22ef7c99f86a36
--- /dev/null
+++ b/mlir/test/Examples/transform/Ch4/multiple.mlir
@@ -0,0 +1,131 @@
+// RUN: transform-opt-ch4 %s --transform-interpreter --verify-diagnostics
+
+// Matmul+ReLU.
+func.func @fc_relu_operands_00(
+    %lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>,
+    %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>)
+    -> tensor<512x512xf32> {
+  // Matrix-matrix multiplication.
+  // expected-remark @below {{matmul # 0}}
+  %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>)
+                          outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32>
+
+  // Elementwise addition.
+  // expected-remark @below {{add # 0}}
+  %biased = linalg.elemwise_binary { fun = #linalg.binary_fn<add> }
+    ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>)
+    outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>
+
+  // Elementwise max with 0 (ReLU).
+  %c0f = arith.constant 0.0 : f32
+  // expected-remark @below {{max # 0}}
+  %relued = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> }
+    ins(%biased, %c0f : tensor<512x512xf32>, f32)
+    outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>
+  func.return %relued : tensor<512x512xf32>
+}
+
+// Matmul+ReLU with swapped operands.
+func.func @fc_relu_operands_01(
+    %lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>,
+    %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>)
+    -> tensor<512x512xf32> {
+  // Matrix-matrix multiplication.
+  // expected-remark @below {{matmul # 1}}
+  %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>)
+                          outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32>
+
+  // Elementwise addition.
+  // expected-remark @below {{add # 1}}
+  %biased = linalg.elemwise_binary { fun = #linalg.binary_fn<add> }
+    ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>)
+    outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>
+
+  // Elementwise max with 0 (ReLU).
+  %c0f = arith.constant 0.0 : f32
+  // expected-remark @below {{max # 1}}
+  %relued = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> }
+    ins(%c0f, %biased : f32, tensor<512x512xf32>)
+    outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>
+  func.return %relued : tensor<512x512xf32>
+}
+
+// The module containing named sequences must have the
+// `transform.with_named_sequence` attribute to enable their verification.
+module @transforms attributes { transform.with_named_sequence } {
+  // Entry point. This takes as the only argument the root operation (typically
+  // pass root) given to the transform interpreter.
+  transform.named_sequence @__transform_main(
+      %root: !transform.any_op {transform.consumed}) {
+
+    // Traverses the payload IR associated with the operand handle, invoking
+    // @match_matmul_elemwise on each of the operations. If the named sequence
+    // succeeds, i.e., if none of the nested match (transform) operations
+    // produced a silenceable failure, invokes @print_matmul_elemwise and
+    // forwards the values yielded as arguments of the new invocation. If the
+    // named sequence fails with a silenceable failure, silences it (the message
+    // is forwarded to the debug stream). Definite failures are propagated
+    // immediately and unconditionally, as usual.
+    transform.foreach_match in %root
+      @match_matmul_elemwise -> @print_matmul_elemwise
+      : (!transform.any_op) -> !transform.any_op
+
+    transform.yield
+  }
+
+  // This is an action sequence.
+  transform.named_sequence @print_matmul_elemwise(
+      %matmul: !transform.any_op {transform.readonly},
+      %add: !transform.any_op {transform.readonly},
+      %max: !transform.any_op {transform.readonly},
+      %pos: !transform.param<i32> {transform.readonly}) {
+    transform.test_print_param %pos, "matmul #" at %matmul
+      : !transform.param<i32>, !transform.any_op
+    transform.test_print_param %pos, "add #" at %add
+      : !transform.param<i32>, !transform.any_op
+    transform.test_print_param %pos, "max #" at %max
+      : !transform.param<i32>, !transform.any_op
+    transform.yield
+  }
+
+  // This is a matcher sequence. It is given an operation to match, and all
+  // nested operations must succeed for the match to be deemed successful. It
+  // starts matching from the last operation in the use-def chain and goes
+  // back because each operand (use) has exactly one definition.
+  transform.named_sequence @match_matmul_elemwise(
+      %last: !transform.any_op {transform.readonly}) 
+      -> (!transform.any_op, !transform.any_op, !transform.any_op,
+          !transform.param<i32>) {
+    // The last operation must be an elementwise binary.
+    transform.match.operation_name %last ["linalg.elemwise_binary"]
+      : !transform.any_op
+
+    // One of its operands must be defined by another operation, to which we
+    // will get a handle here. This is achieved thanks to a newly defined
+    // operation that tries to match operands one by one using the match
+    // operations nested in its region.
+    %pos, %middle = transform.match.my.has_operand_satisfying %last
+        : (!transform.any_op) -> (!transform.param<i32>, !transform.any_op) {
+    ^bb0(%operand: !transform.any_value):
+      // The operand must be defined by an operation.
+      %def = transform.get_defining_op %operand 
+        : (!transform.any_value) -> !transform.any_op
+      // The defining operation must itself be an elementwise binary.
+      transform.match.operation_name %def ["linalg.elemwise_binary"]
+        : !transform.any_op
+      transform.yield %def : !transform.any_op
+    }
+    
+    // And the first operand of that operation must be defined by yet another
+    // operation.
+    %matmul = transform.get_producer_of_operand %middle[0]
+      : (!transform.any_op) -> !transform.any_op
+    // And that operation is a matmul.
+    transform.match.operation_name %matmul ["linalg.matmul"] : !transform.any_op
+    // Yield the handles to the matmul and the two elementwise operations
+    // separately, followed by the position of the matched operand.
+    transform.yield %matmul, %middle, %last, %pos
+      : !transform.any_op, !transform.any_op, !transform.any_op,
+        !transform.param<i32>
+  }
+}
diff --git a/mlir/test/Examples/transform/Ch4/sequence.mlir b/mlir/test/Examples/transform/Ch4/sequence.mlir
new file mode 100644
index 00000000000000..28c3e9649bd956
--- /dev/null
+++ b/mlir/test/Examples/transform/Ch4/sequence.mlir
@@ -0,0 +1,139 @@
+// RUN: transform-opt-ch4 %s --transform-interpreter --verify-diagnostics
+//
+// RUN: transform-opt-ch4 %s \
+// RUN:              --transform-interpreter='entry-point=__transform_main_v2' \
+// RUN:              --verify-diagnostics
+
+// ****************************** IMPORTANT NOTE ******************************
+//
+// If you are changing this file, you may also need to change
+// mlir/docs/Tutorials/Transform accordingly.
+//
+// ****************************************************************************
+
+// Original function to optimize.
+func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>,
+                   %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>)
+                   -> tensor<512x512xf32> {
+  // Matrix-matrix multiplication.
+  // expected-remark @below {{matmul}}
+  %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>)
+                          outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32>
+
+  // Elementwise addition.
+  // expected-remark @below {{elementwise binary}}
+  %biased = linalg.elemwise_binary { fun = #linalg.binary_fn<add> }
+    ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>)
+    outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>
+
+  // Elementwise max with 0 (ReLU).
+  %c0f = arith.constant 0.0 : f32
+  // expected-remark @below {{elementwise binary}}
+  %relued = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> }
+    ins(%biased, %c0f : tensor<512x512xf32>, f32)
+    outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>
+  func.return %relued : tensor<512x512xf32>
+}
+
+// The module containing named sequences must have the
+// `transform.with_named_sequence` attribute to enable their verification.
+module @transforms attributes { transform.with_named_sequence } {
+  // Entry point. This takes as the only argument the root operation (typically
+  // pass root) given to the transform interpreter.
+  transform.named_sequence @__transform_main(
+      %root: !transform.any_op {transform.readonly}) {
+    // Collect operations that match the criteria specified in the named
+    // sequence. If the named sequence fails with a silenceable failure,
+    // silences it (the message is forwarded to the debug stream). If the named
+    // sequence succeeds, appends its results to the results of this operation.
+    %elemwise = transform.collect_matching @match_elemwise in %root
+      : (!transform.any_op) -> !transform.any_op
+    %matmul = transform.collect_matching @match_matmul in %root
+      : (!transform.any_op) -> !transform.any_op
+
+    transform.include @print_elemwise failures(propagate)  (%elemwise)
+      : (!transform.any_op) -> ()
+    transform.include @print_matmul failures(propagate)  (%matmul)
+      : (!transform.any_op) -> ()
+
+    transform.yield
+  }
+
+  // Alternative entry point.
+  transform.named_sequence @__transform_main_v2(
+      %root: !transform.any_op {transform.readonly}) {
+    // Collect groups of operations that match the criteria specified in the
+    // named sequence.
+    %matmul, %el1, %el2 = transform.collect_matching @match_matmul_elemwise in %root 
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %elemwise = transform.merge_handles %el1, %el2 : !transform.any_op
+
+    transform.include @print_elemwise failures(propagate)  (%elemwise)
+      : (!transform.any_op) -> ()
+    transform.include @print_matmul failures(propagate)  (%matmul)
+      : (!transform.any_op) -> ()
+
+    transform.yield
+  }
+
+  // This is a matcher sequence. It is given an operation to match and the
+  // match is considered successful unless any nested operation produces a
+  // failure. The values yielded by this operation will be forwarded to the
+  // rewriter sequence on success.
+  transform.named_sequence @match_elemwise(
+      %entry: !transform.any_op {transform.readonly}) -> !transform.any_op {
+    transform.match.operation_name %entry ["linalg.elemwise_binary"] 
+      : !transform.any_op
+    transform.yield %entry : !transform.any_op
+  }
+  transform.named_sequence @match_matmul(
+      %entry: !transform.any_op {transform.readonly}) -> !transform.any_op {
+    transform.match.operation_name %entry ["linalg.matmul"] : !transform.any_op
+    transform.yield %entry : !transform.any_op
+  }
+
+  // This is an action sequence.
+  transform.named_sequence @print_elemwise(
+      %elemwise_binary: !transform.any_op {transform.readonly}) {
+    transform.test_print_remark_at_operand
+      %elemwise_binary, "elementwise binary" : !transform.any_op
+    transform.yield
+  }
+  transform.named_sequence @print_matmul(
+      %matmul: !transform.any_op {transform.readonly}) {
+    transform.test_print_remark_at_operand %matmul, "matmul" : !transform.any_op
+    transform.yield
+  }
+
+  // This is also a matcher sequence. It is similarly given an operation to
+  // match and nested operations must succeed in order for a match to be deemed
+  // successful. It starts matching from the last operation in the use-def chain
+  // and goes back because each operand (use) has exactly one definition.
+  transform.named_sequence @match_matmul_elemwise(
+      %last: !transform.any_op {transform.readonly}) 
+      -> (!transform.any_op, !transform.any_op, !transform.any_op) {
+    // The last operation must be an elementwise binary.
+    transform.match.operation_name %last ["linalg.elemwise_binary"]
+      : !transform.any_op
+    // Its first operand must be defined by another operation, to which we
+    // will get a handle here. We are guaranteed that the first operand exists
+    // because we know the operation is binary, but even in absence of such a
+    // guarantee, this operation would have produced a silenceable failure when
+    // `%last` does not have enough operands.
+    %middle = transform.get_producer_of_operand %last[0]
+      : (!transform.any_op) -> !transform.any_op
+    // The defining operation must itself be an elementwise binary.
+    transform.match.operation_name %middle ["linalg.elemwise_binary"]
+      : !transform.any_op
+    // And the first operand of that operation must be defined by yet another
+    // operation.
+    %matmul = transform.get_producer_of_operand %middle[0]
+      : (!transform.any_op) -> !transform.any_op
+    // And that operation is a matmul.
+    transform.match.operation_name %matmul ["linalg.matmul"] : !transform.any_op
+    // We will yield the handles to the matmul and the two elementwise
+    // operations separately. 
+    transform.yield %matmul, %middle, %last
+      : !transform.any_op, !transform.any_op, !transform.any_op
+  }
+}
diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py
index 5b92491175e5b3..dcbf2de1ca974c 100644
--- a/mlir/test/lit.cfg.py
+++ b/mlir/test/lit.cfg.py
@@ -156,6 +156,7 @@ def add_runtime(name):
         ToolSubst("toyc-ch7", unresolved="ignore"),
         ToolSubst('transform-opt-ch2', unresolved='ignore'),
         ToolSubst('transform-opt-ch3', unresolved='ignore'),
+        ToolSubst('transform-opt-ch4', unresolved='ignore'),
         ToolSubst("%mlir_lib_dir", config.mlir_lib_dir, unresolved="ignore"),
         ToolSubst("%mlir_src_dir", config.mlir_src_root, unresolved="ignore"),
     ]



More information about the Mlir-commits mailing list