[Mlir-commits] [mlir] f90b609 - [mlir] introduce transform.num_associations (#76723)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jan 3 04:33:22 PST 2024


Author: Oleksandr "Alex" Zinenko
Date: 2024-01-03T13:33:18+01:00
New Revision: f90b6090048b277cad44dd49067c26deec3e110c

URL: https://github.com/llvm/llvm-project/commit/f90b6090048b277cad44dd49067c26deec3e110c
DIFF: https://github.com/llvm/llvm-project/commit/f90b6090048b277cad44dd49067c26deec3e110c.diff

LOG: [mlir] introduce transform.num_associations (#76723)

Add a new transform operation that creates a new parameter containing the number of payload objects (operations, values or attributes) associated with the argument. This is useful in matching and for debugging purposes. This replaces three ad-hoc operations previously provided by the test extension.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp
    mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
    mlir/test/Dialect/Linalg/transform-op-match.mlir
    mlir/test/Dialect/Linalg/transform-op-pad.mlir
    mlir/test/Dialect/Transform/ops-invalid.mlir
    mlir/test/Dialect/Transform/test-interpreter.mlir
    mlir/test/Dialect/Transform/test-loop-transforms.mlir
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 307257f4a582be..fcdb21d21503a1 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -438,6 +438,28 @@ def CastOp : TransformDialectOp<"cast",
   }];
 }
 
+def NumAssociationsOp : TransformDialectOp<"num_associations",
+    [MemoryEffectsOpInterface, ParamProducerTransformOpTrait,
+     DeclareOpInterfaceMethods<TransformOpInterface>,
+     MatchOpInterface]> {
+  let summary =
+    "Returns the number of payload objects associated with the argument";
+  let description = [{
+    Given an argument, handle or parameter, returns a new parameter associated
+    with a single 64-bit number that corresponds to the number of payload
+    objects (operations or values for a handle, attributes for a parameter)
+    associated with the argument.
+
+    Always succeeds.
+  }];
+  let arguments = (ins Transform_AnyHandleOrParamType:$handle);
+  let results = (outs TransformParamTypeInterface:$num);
+  let assemblyFormat = [{
+    $handle attr-dict `:` functional-type(operands, results)
+  }];
+  let hasVerifier = 1;
+}
+
 def ForeachMatchOp : TransformDialectOp<"foreach_match", [
     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
     DeclareOpInterfaceMethods<SymbolUserOpInterface>,

diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 7136e423470a28..aa4694c88d3b2a 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -32,6 +32,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
 #include <optional>
 
@@ -1974,6 +1975,42 @@ void transform::NamedSequenceOp::build(OpBuilder &builder,
                     /*extraBindingTypes=*/TypeRange(), bodyBuilder);
 }
 
+//===----------------------------------------------------------------------===//
+// NumAssociationsOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::NumAssociationsOp::apply(transform::TransformRewriter &rewriter,
+                                    transform::TransformResults &results,
+                                    transform::TransformState &state) {
+  size_t numAssociations =
+      llvm::TypeSwitch<Type, size_t>(getHandle().getType())
+          .Case([&](TransformHandleTypeInterface opHandle) {
+            return llvm::range_size(state.getPayloadOps(getHandle()));
+          })
+          .Case([&](TransformValueHandleTypeInterface valueHandle) {
+            return llvm::range_size(state.getPayloadValues(getHandle()));
+          })
+          .Case([&](TransformParamTypeInterface param) {
+            return llvm::range_size(state.getParams(getHandle()));
+          })
+          .Default([](Type) {
+            llvm_unreachable("unknown kind of transform dialect type");
+            return 0;
+          });
+  results.setParams(getNum().cast<OpResult>(),
+                    rewriter.getI64IntegerAttr(numAssociations));
+  return DiagnosedSilenceableFailure::success();
+}
+
+LogicalResult transform::NumAssociationsOp::verify() {
+  // Verify that the result type accepts an i64 attribute as payload.
+  auto resultType = getNum().getType().cast<TransformParamTypeInterface>();
+  return resultType
+      .checkPayload(getLoc(), {Builder(getContext()).getI64IntegerAttr(0)})
+      .checkAndReport();
+}
+
 //===----------------------------------------------------------------------===//
 // SelectOp
 //===----------------------------------------------------------------------===//

diff --git a/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir b/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
index 49a52ba9e06f86..aa15ccf0beeee2 100644
--- a/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
@@ -36,13 +36,15 @@ module attributes {transform.with_named_sequence} {
 
     // Ensure that one linalg.fill was generated.
     %fill_op = transform.select "linalg.fill" in %new : (!transform.any_op) -> !transform.any_op
+    %p = transform.num_associations %fill_op : (!transform.any_op) -> !transform.param<i64>
     // expected-remark @below{{1}}
-    transform.test_print_number_of_associated_payload_ir_ops %fill_op : !transform.any_op
+    transform.test_print_param %p : !transform.param<i64>
 
     // Ensure that one linalg.copy was generated.
     %mat = transform.select "bufferization.materialize_in_destination" in %new : (!transform.any_op) -> !transform.any_op
+    %p2 = transform.num_associations %mat : (!transform.any_op) -> !transform.param<i64>
     // expected-remark @below{{1}}
-    transform.test_print_number_of_associated_payload_ir_ops %mat : !transform.any_op
+    transform.test_print_param %p2 : !transform.param<i64>
     transform.yield
   }
 }
@@ -73,18 +75,21 @@ module attributes {transform.with_named_sequence} {
 
     // Ensure that one linalg.fill was generated.
     %fill_op = transform.select "linalg.fill" in %new : (!transform.any_op) -> !transform.any_op
+    %p = transform.num_associations %fill_op : (!transform.any_op) -> !transform.param<i64>
     // expected-remark @below{{1}}
-    transform.test_print_number_of_associated_payload_ir_ops %fill_op : !transform.any_op
+    transform.test_print_param %p : !transform.param<i64>
 
     // Ensure that one linalg.copy was generated.
     %linalg_copy = transform.select "linalg.copy" in %new : (!transform.any_op) -> !transform.any_op
+    %p2 = transform.num_associations %linalg_copy : (!transform.any_op) -> !transform.param<i64>
     // expected-remark @below{{1}}
-    transform.test_print_number_of_associated_payload_ir_ops %linalg_copy : !transform.any_op
+    transform.test_print_param %p2 : !transform.param<i64>
 
     // Ensure that one memref.alloca was generated.
     %alloca = transform.select "memref.alloca" in %new : (!transform.any_op) -> !transform.any_op
+    %p3 = transform.num_associations %alloca : (!transform.any_op) -> !transform.param<i64>
     // expected-remark @below{{1}}
-    transform.test_print_number_of_associated_payload_ir_ops %alloca : !transform.any_op
+    transform.test_print_param %p3 : !transform.param<i64>
 
     // Make sure that One-Shot Bufferize can bufferize the rest.
     %4 = transform.bufferization.one_shot_bufferize %arg1 : (!transform.any_op) -> !transform.any_op

diff --git a/mlir/test/Dialect/Linalg/transform-op-match.mlir b/mlir/test/Dialect/Linalg/transform-op-match.mlir
index 15942db9b5db20..db5b5f1c786776 100644
--- a/mlir/test/Dialect/Linalg/transform-op-match.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-match.mlir
@@ -134,8 +134,9 @@ module attributes {transform.with_named_sequence} {
           #linalg.iterator_type<parallel>,
           #linalg.iterator_type<reduction>]}
         in %arg1 : (!transform.any_op) -> !transform.any_op
-  // expected-remark @below {{0}}
-    transform.test_print_number_of_associated_payload_ir_ops %no_match : !transform.any_op
+    %p = transform.num_associations %no_match : (!transform.any_op) -> !transform.param<i64>
+    // expected-remark @below {{0}}
+    transform.test_print_param %p : !transform.param<i64>
     transform.yield
   }
 }

diff --git a/mlir/test/Dialect/Linalg/transform-op-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
index 6bca6c1fd6bf12..1f9d81a819e7fb 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
@@ -41,8 +41,9 @@ module attributes {transform.with_named_sequence} {
       padding_dimensions=[0, 1, 2],
       pack_paddings=[1, 1, 0]
     } : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.op<"bufferization.materialize_in_destination">)
+    %p = transform.num_associations %copy_back : (!transform.op<"bufferization.materialize_in_destination">) -> !transform.param<i64>
     // expected-remark @below {{1}}
-    transform.test_print_number_of_associated_payload_ir_ops %copy_back : !transform.op<"bufferization.materialize_in_destination">
+    transform.test_print_param %p : !transform.param<i64>
     transform.yield
   }
 }

diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir
index 09641615887981..5123958b02bfb8 100644
--- a/mlir/test/Dialect/Transform/ops-invalid.mlir
+++ b/mlir/test/Dialect/Transform/ops-invalid.mlir
@@ -696,3 +696,11 @@ transform.sequence failures(propagate) {
     transform.named_sequence @foo()
   } : !transform.any_op
 }
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  // expected-error @below {{expected the type of the parameter attribute ('i64') to match the parameter type ('i32')}}
+  transform.num_associations %arg0 : (!transform.any_op) -> !transform.param<i32>
+}

diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index d9a11994eb9d90..a39e6f94cb34f6 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -575,8 +575,9 @@ transform.with_pdl_patterns {
     %0 = pdl_match @addi in %arg1 : (!transform.any_op) -> !transform.any_op
     %1 = pdl_match @addi in %arg1 : (!transform.any_op) -> !transform.any_op
     %2 = merge_handles deduplicate %0, %1 : !transform.any_op
+    %3 = num_associations %2 : (!transform.any_op) -> !transform.param<i64>
     // expected-remark @below {{1}}
-    test_print_number_of_associated_payload_ir_ops %2 : !transform.any_op
+    test_print_param %3 : !transform.param<i64>
   }
 }
 
@@ -676,11 +677,13 @@ module {
     ^bb0(%arg1: !transform.any_op):
       %0 = pdl_match @func in %arg1 : (!transform.any_op) -> !transform.any_op
       %1 = replicate num(%0) %arg1 : !transform.any_op, !transform.any_op
+      %p = num_associations %1 : (!transform.any_op) -> !transform.param<i64>
       // expected-remark @below {{2}}
-      test_print_number_of_associated_payload_ir_ops %1 : !transform.any_op
+      test_print_param %p : !transform.param<i64>
       %2 = replicate num(%0) %1 : !transform.any_op, !transform.any_op
+      %p2 = num_associations %2 : (!transform.any_op) -> !transform.param<i64>
       // expected-remark @below {{4}}
-      test_print_number_of_associated_payload_ir_ops %2 : !transform.any_op
+      test_print_param %p2 : !transform.param<i64>
     }
   }
 }
@@ -708,8 +711,9 @@ transform.with_pdl_patterns {
     %f = pdl_match @const in %arg1 : (!transform.any_op) -> !transform.any_op
     transform.foreach %f : !transform.any_op {
     ^bb2(%arg2: !transform.any_op):
+      %p = transform.num_associations %arg2 : (!transform.any_op) -> !transform.param<i64>
       // expected-remark @below {{1}}
-      transform.test_print_number_of_associated_payload_ir_ops %arg2 : !transform.any_op
+      transform.test_print_param %p : !transform.param<i64>
       transform.test_print_remark_at_operand %arg2, "transform applied" : !transform.any_op
     }
   }
@@ -780,8 +784,9 @@ transform.with_pdl_patterns {
       transform.yield %g : !transform.any_op
     }
 
+    %p = transform.num_associations %results : (!transform.any_op) -> !transform.param<i64>
     // expected-remark @below {{3}}
-    transform.test_print_number_of_associated_payload_ir_ops %results : !transform.any_op
+    transform.test_print_param %p : !transform.param<i64>
     transform.test_print_remark_at_operand %results, "transform applied" : !transform.any_op
   }
 }
@@ -877,8 +882,9 @@ transform.sequence failures(propagate) {
 ^bb1(%fun: !transform.any_op):
   %muli = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
   %h:2 = split_handle %muli : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+  %p = transform.num_associations %h#0 : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below {{1}}
-  transform.test_print_number_of_associated_payload_ir_ops %h#0 : !transform.any_op
+  transform.test_print_param %p : !transform.param<i64>
   %muli_2 = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
   // expected-error @below {{expected to contain 3 payload ops but it contains 2 payload ops}}
   %h_2:3 = split_handle %muli_2 : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
@@ -896,13 +902,15 @@ transform.sequence failures(suppress) {
 ^bb1(%fun: !transform.any_op):
   %muli = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
   %h:2 = split_handle %muli : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+  %p = transform.num_associations %h#0 : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below {{1}}
-  transform.test_print_number_of_associated_payload_ir_ops %h#0 : !transform.any_op
+  transform.test_print_param %p : !transform.param<i64>
   %muli_2 = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
   // Silenceable failure and all handles are now empty.
   %h_2:3 = split_handle %muli_2 : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+  %p2 = transform.num_associations %h_2#0 : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below {{0}}
-  transform.test_print_number_of_associated_payload_ir_ops %h_2#0 : !transform.any_op
+  transform.test_print_param %p2 : !transform.param<i64>
 }
 
 // -----
@@ -918,12 +926,15 @@ transform.sequence failures(propagate) {
   %muli_2 = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
   // No error, last result handle is empty.
   %h:3 = split_handle %muli_2 {fail_on_payload_too_small = false} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+  %p = transform.num_associations %h#0 : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below {{1}}
-  transform.test_print_number_of_associated_payload_ir_ops %h#0 : !transform.any_op
+  transform.test_print_param %p : !transform.param<i64>
+  %p2 = transform.num_associations %h#1 : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below {{1}}
-  transform.test_print_number_of_associated_payload_ir_ops %h#1 : !transform.any_op
+  transform.test_print_param %p2 : !transform.param<i64>
+  %p3 = transform.num_associations %h#2 : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below {{0}}
-  transform.test_print_number_of_associated_payload_ir_ops %h#2 : !transform.any_op
+  transform.test_print_param %p3 : !transform.param<i64>
 }
 
 // -----
@@ -940,10 +951,12 @@ transform.sequence failures(propagate) {
 ^bb1(%fun: !transform.any_op):
   %muli_2 = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
   %h:2 = split_handle %muli_2 {overflow_result = 0} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+  %p = transform.num_associations %h#0 : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below {{3}}
-  transform.test_print_number_of_associated_payload_ir_ops %h#0 : !transform.any_op
+  transform.test_print_param %p : !transform.param<i64>
+  %p2 = transform.num_associations %h#1 : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below {{1}}
-  transform.test_print_number_of_associated_payload_ir_ops %h#1 : !transform.any_op
+  transform.test_print_param %p2 : !transform.param<i64>
 }
 
 // -----
@@ -1668,8 +1681,9 @@ transform.sequence failures(propagate) {
   // expected-remark @below {{2 iterations}}
   transform.test_tracked_rewrite %0 : (!transform.any_op) -> ()
   // One replacement op (test.drop_mapping) is dropped from the mapping.
+  %p = num_associations %0 : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below {{2}}
-  test_print_number_of_associated_payload_ir_ops %0 : !transform.any_op
+  test_print_param %p : !transform.param<i64>
 }
 
 // -----
@@ -1684,20 +1698,24 @@ module {
     %2 = transform.param.constant 1 -> !transform.param<i64>
     %3 = transform.param.constant 2 -> !transform.param<i64>
     %4 = transform.merge_handles %1, %2 { deduplicate } : !transform.param<i64>
+    %p = num_associations %4 : (!transform.param<i64>) -> !transform.param<i64>
     // expected-remark @below {{1}}
-    test_print_number_of_associated_payload_ir_params %4 : !transform.param<i64>
+    test_print_param %p : !transform.param<i64>
 
     %5 = transform.merge_handles %1, %1 { deduplicate } : !transform.param<i64>
+    %p2 = num_associations %5 : (!transform.param<i64>) -> !transform.param<i64>
     // expected-remark @below {{1}}
-    test_print_number_of_associated_payload_ir_params %5 : !transform.param<i64>
+    test_print_param %p2 : !transform.param<i64>
 
     %6 = transform.merge_handles %1, %3 { deduplicate } : !transform.param<i64>
+    %p3 = num_associations %6 : (!transform.param<i64>) -> !transform.param<i64>
     // expected-remark @below {{2}}
-    test_print_number_of_associated_payload_ir_params %6 : !transform.param<i64>
+    test_print_param %p3 : !transform.param<i64>
 
     %7 = transform.merge_handles %1, %1, %2, %3 : !transform.param<i64>
+    %p4 = num_associations %7 : (!transform.param<i64>) -> !transform.param<i64>
     // expected-remark @below {{4}}
-    test_print_number_of_associated_payload_ir_params %7 : !transform.param<i64>
+    test_print_param %p4 : !transform.param<i64>
   }
 }
 
@@ -1712,21 +1730,25 @@ transform.sequence failures(propagate) {
   %3 = test_produce_value_handle_to_result %1, 1 : (!transform.any_op) -> !transform.any_value
 
   %4 = transform.merge_handles %2, %2 { deduplicate } : !transform.any_value
+  %p = num_associations %4 : (!transform.any_value) -> !transform.param<i64>
   // expected-remark @below {{1}}
-  test_print_number_of_associated_payload_ir_values %4 : !transform.any_value
+  test_print_param %p : !transform.param<i64>
 
   %5 = transform.merge_handles %2, %3 { deduplicate } : !transform.any_value
+  %p2 = num_associations %5 : (!transform.any_value) -> !transform.param<i64>
   // expected-remark @below {{2}}
-  test_print_number_of_associated_payload_ir_values %5 : !transform.any_value
+  test_print_param %p2 : !transform.param<i64>
 
   %6 = test_produce_value_handle_to_result %1, 0 : (!transform.any_op) -> !transform.any_value
   %7 = transform.merge_handles %2, %6 { deduplicate } : !transform.any_value
+  %p3 = num_associations %6 : (!transform.any_value) -> !transform.param<i64>
   // expected-remark @below {{1}}
-  test_print_number_of_associated_payload_ir_values %6 : !transform.any_value
+  test_print_param %p3 : !transform.param<i64>
 
   %8 = transform.merge_handles %2, %2, %3, %4 : !transform.any_value
+  %p4 = num_associations %8 : (!transform.any_value) -> !transform.param<i64>
   // expected-remark @below {{4}}
-  test_print_number_of_associated_payload_ir_values %8 : !transform.any_value
+  test_print_param %p4 : !transform.param<i64>
 }
 // -----
 
@@ -1820,31 +1842,37 @@ transform.sequence failures(propagate) {
 
   // There are 3 arith.constant ops.
   %all = transform.structured.match ops{["arith.constant"]} in %0 : (!transform.any_op) -> !transform.any_op
+  %p = num_associations %all : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below{{3}}
-  test_print_number_of_associated_payload_ir_ops %all : !transform.any_op
+  test_print_param %p : !transform.param<i64>
   // "deduplicate" has no effect because these are 3 different ops.
   %merged_before = transform.merge_handles deduplicate %all : !transform.any_op
+  %p2 = num_associations %merged_before : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below{{3}}
-  test_print_number_of_associated_payload_ir_ops %merged_before : !transform.any_op
+  test_print_param %p2 : !transform.param<i64>
 
   // Apply CSE.
   transform.apply_cse to %0 : !transform.any_op
 
   // The handle is still mapped to 3 arith.constant ops.
+  %p3 = num_associations %all : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below{{3}}
-  test_print_number_of_associated_payload_ir_ops %all : !transform.any_op
+  test_print_param %p3 : !transform.param<i64>
   // But they are all the same op.
   %merged_after = transform.merge_handles deduplicate %all : !transform.any_op
+  %p4 = num_associations %merged_after : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below{{1}}
-  test_print_number_of_associated_payload_ir_ops %merged_after : !transform.any_op
+  test_print_param %p4 : !transform.param<i64>
 
   // The other handles were also updated.
   test_print_remark_at_operand %elim_first, "eliminated 1" : !transform.any_op
+  %p5 = num_associations %elim_first : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below{{1}}
-  test_print_number_of_associated_payload_ir_ops %elim_first : !transform.any_op
+  test_print_param %p5 : !transform.param<i64>
   test_print_remark_at_operand %elim_second, "eliminated 2" : !transform.any_op
+  %p6 = num_associations %elim_second : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below{{1}}
-  test_print_number_of_associated_payload_ir_ops %elim_second : !transform.any_op
+  test_print_param %p6 : !transform.param<i64>
 }
 
 // -----
@@ -1907,14 +1935,16 @@ transform.sequence failures(propagate) {
   // Get immediate parent.
   %2 = transform.get_parent_op %0 : (!transform.any_op) -> !transform.any_op
   test_print_remark_at_operand %2, "direct parent" : !transform.any_op
+  %p = num_associations %2 : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below{{2}}
-  test_print_number_of_associated_payload_ir_ops %2 : !transform.any_op
+  test_print_param %p : !transform.param<i64>
 
   // Deduplicate results.
   %3 = transform.structured.match ops{["test.qux"]} in %arg1 : (!transform.any_op) -> !transform.any_op
   %4 = transform.get_parent_op %3 {deduplicate} : (!transform.any_op) -> !transform.any_op
+  %p2 = num_associations %4 : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below{{1}}
-  test_print_number_of_associated_payload_ir_ops %4 : !transform.any_op
+  test_print_param %p2 : !transform.param<i64>
 }
 
 
@@ -2029,8 +2059,9 @@ transform.sequence failures(propagate) {
   // Match all ops inside the function (including the function itself).
   %func_op = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
   %0 = transform.structured.match in %func_op : (!transform.any_op) -> !transform.any_op
+  %p = num_associations %0 : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below{{5}}
-  test_print_number_of_associated_payload_ir_ops %0 : !transform.any_op
+  test_print_param %p : !transform.param<i64>
 
   // Select "test.foo".
   %foo = transform.select "test.foo" in %0 : (!transform.any_op) -> !transform.any_op
@@ -2060,8 +2091,9 @@ transform.sequence failures(propagate) {
   %empty_op = transform.structured.match ops{["tensor.empty"]} in %func_op : (!transform.any_op) -> !transform.any_op
   transform.apply_dce to %func_op : !transform.any_op
 
+  %p = num_associations %empty_op : (!transform.any_op) -> !transform.param<i64>
   // expected-remark @below{{0}}
-  test_print_number_of_associated_payload_ir_ops %empty_op : !transform.any_op
+  test_print_param %p : !transform.param<i64>
 }
 
 

diff --git a/mlir/test/Dialect/Transform/test-loop-transforms.mlir b/mlir/test/Dialect/Transform/test-loop-transforms.mlir
index 425962757f720b..c34f4baf8cd972 100644
--- a/mlir/test/Dialect/Transform/test-loop-transforms.mlir
+++ b/mlir/test/Dialect/Transform/test-loop-transforms.mlir
@@ -37,13 +37,16 @@ module attributes {transform.with_named_sequence} {
     // Make sure that the handles are still valid (and were updated in case of
     // the loop).
 
+    %p = transform.num_associations %0 : (!transform.any_op) -> !transform.param<i64>
     // expected-remark @below{{1}}
-    transform.test_print_number_of_associated_payload_ir_ops %0 : !transform.any_op
+    transform.test_print_param %p : !transform.param<i64>
     transform.test_print_remark_at_operand %0, "new loop op" : !transform.any_op
+    %p2 = transform.num_associations %1 : (!transform.any_op) -> !transform.param<i64>
     // expected-remark @below{{1}}
-    transform.test_print_number_of_associated_payload_ir_ops %1 : !transform.any_op
+    transform.test_print_param %p2 : !transform.param<i64>
+    %p3 = transform.num_associations %2 : (!transform.any_op) -> !transform.param<i64>
     // expected-remark @below{{1}}
-    transform.test_print_number_of_associated_payload_ir_ops %2 : !transform.any_op
+    transform.test_print_param %p3 : !transform.param<i64>
 
     transform.yield
   }

diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index e8c25aca237251..9c69164e33f600 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -456,51 +456,6 @@ mlir::test::TestMixedSuccessAndSilenceableOp::applyToOne(
   return emitDefaultSilenceableFailure(target);
 }
 
-DiagnosedSilenceableFailure
-mlir::test::TestPrintNumberOfAssociatedPayloadIROps::apply(
-    transform::TransformRewriter &rewriter,
-    transform::TransformResults &results, transform::TransformState &state) {
-  if (!getHandle())
-    emitRemark() << 0;
-  emitRemark() << llvm::range_size(state.getPayloadOps(getHandle()));
-  return DiagnosedSilenceableFailure::success();
-}
-
-void mlir::test::TestPrintNumberOfAssociatedPayloadIROps::getEffects(
-    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
-  transform::onlyReadsHandle(getHandle(), effects);
-}
-
-DiagnosedSilenceableFailure
-mlir::test::TestPrintNumberOfAssociatedPayloadIRValues::apply(
-    transform::TransformRewriter &rewriter,
-    transform::TransformResults &results, transform::TransformState &state) {
-  if (!getValueHandle())
-    emitRemark() << 0;
-  emitRemark() << llvm::range_size(state.getPayloadValues(getValueHandle()));
-  return DiagnosedSilenceableFailure::success();
-}
-
-void mlir::test::TestPrintNumberOfAssociatedPayloadIRValues::getEffects(
-    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
-  transform::onlyReadsHandle(getValueHandle(), effects);
-}
-
-DiagnosedSilenceableFailure
-mlir::test::TestPrintNumberOfAssociatedPayloadIRParams::apply(
-    transform::TransformRewriter &rewriter,
-    transform::TransformResults &results, transform::TransformState &state) {
-  if (!getParam())
-    emitRemark() << 0;
-  emitRemark() << llvm::range_size(state.getParams(getParam()));
-  return DiagnosedSilenceableFailure::success();
-}
-
-void mlir::test::TestPrintNumberOfAssociatedPayloadIRParams::getEffects(
-    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
-  transform::onlyReadsHandle(getParam(), effects);
-}
-
 DiagnosedSilenceableFailure
 mlir::test::TestCopyPayloadOp::apply(transform::TransformRewriter &rewriter,
                                      transform::TransformResults &results,

diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
index 41f318db68405b..5cb47659fdbdfd 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -343,33 +343,6 @@ def TestMixedSuccessAndSilenceableOp
   }];
 }
 
-def TestPrintNumberOfAssociatedPayloadIROps
-  : Op<Transform_Dialect, "test_print_number_of_associated_payload_ir_ops",
-       [DeclareOpInterfaceMethods<TransformOpInterface>,
-        DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
-  let arguments = (ins TransformHandleTypeInterface:$handle);
-  let assemblyFormat = "$handle attr-dict `:` type($handle)";
-  let cppNamespace = "::mlir::test";
-}
-
-def TestPrintNumberOfAssociatedPayloadIRValues
-  : Op<Transform_Dialect, "test_print_number_of_associated_payload_ir_values",
-       [DeclareOpInterfaceMethods<TransformOpInterface>,
-        DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
-  let arguments = (ins TransformValueHandleTypeInterface:$value_handle);
-  let assemblyFormat = "$value_handle attr-dict `:` type($value_handle)";
-  let cppNamespace = "::mlir::test";
-}
-
-def TestPrintNumberOfAssociatedPayloadIRParams
-  : Op<Transform_Dialect, "test_print_number_of_associated_payload_ir_params",
-       [DeclareOpInterfaceMethods<TransformOpInterface>,
-        DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
-  let arguments = (ins TransformParamTypeInterface:$param);
-  let assemblyFormat = "$param attr-dict `:` type($param)";
-  let cppNamespace = "::mlir::test";
-}
-
 def TestCopyPayloadOp
   : Op<Transform_Dialect, "test_copy_payload",
        [DeclareOpInterfaceMethods<TransformOpInterface>,


        


More information about the Mlir-commits mailing list