[Mlir-commits] [mlir] 1938039 - [mlir][Transform] Allow parameter and value types in merge_handles op

Quinn Dawkins llvmlistbot at llvm.org
Tue Jun 20 11:02:14 PDT 2023


Author: Quinn Dawkins
Date: 2023-06-20T13:54:38-04:00
New Revision: 193803968f12170c7b381629b15e3bf3c6c10196

URL: https://github.com/llvm/llvm-project/commit/193803968f12170c7b381629b15e3bf3c6c10196
DIFF: https://github.com/llvm/llvm-project/commit/193803968f12170c7b381629b15e3bf3c6c10196.diff

LOG: [mlir][Transform] Allow parameter and value types in merge_handles op

Similar to operation handles, merging handles for other types can be useful to
avoid repetition of common transformations across a set of parameters.
For example, forming a list of static values for comparison rather than
comparing the parameters one at a time.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D153240

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp
    mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
    mlir/test/Dialect/Transform/test-interpreter.mlir
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index a5ed64ee262b3..cf49c451286d8 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -559,22 +559,22 @@ def MatchParamCmpIOp : Op<Transform_Dialect, "match.param.cmpi", [
 def MergeHandlesOp : TransformDialectOp<"merge_handles",
     [DeclareOpInterfaceMethods<TransformOpInterface, ["allowsRepeatedHandleOperands"]>,
      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
-     SameOperandsAndResultType]> {
+     MatchOpInterface, SameOperandsAndResultType]> {
-  let summary = "Merges handles into one pointing to the union of payload ops";
+  let summary = "Merges handles or parameters into one holding their union";
   let description = [{
     Creates a new Transform IR handle value that points to the same Payload IR
-    operations as the operand handles. The Payload IR operations are listed
-    in the same order as they are in the operand handles, grouped by operand
-    handle, e.g., all Payload IR operations associated with the first handle
-    come first, then all Payload IR operations associated with the second handle
-    and so on. If `deduplicate` is set, do not add the given Payload IR
-    operation more than once to the final list regardless of it coming from the
+    operations/values/parameters as the operand handles. The Payload IR elements
+    are listed in the same order as they are in the operand handles, grouped by
+    operand handle, e.g., all Payload IR associated with the first handle comes
+    first, then all Payload IR associated with the second handle and so on. If
+    `deduplicate` is set, do not add the given Payload IR operation, value, or
+    parameter more than once to the final list regardless of it coming from the
     same or different handles. Consumes the operands and produces a new handle.
   }];
 
-  let arguments = (ins Variadic<TransformHandleTypeInterface>:$handles,
+  let arguments = (ins Variadic<Transform_AnyHandleOrParamType>:$handles,
                        UnitAttr:$deduplicate);
-  let results = (outs TransformHandleTypeInterface:$result);
+  let results = (outs Transform_AnyHandleOrParamType:$result);
   let assemblyFormat = "(`deduplicate` $deduplicate^)? $handles attr-dict `:` type($result)";
   let hasFolder = 1;
 }

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 72cf6630c3f80..5b71bf895a626 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1255,16 +1255,53 @@ DiagnosedSilenceableFailure
 transform::MergeHandlesOp::apply(transform::TransformRewriter &rewriter,
                                  transform::TransformResults &results,
                                  transform::TransformState &state) {
-  SmallVector<Operation *> operations;
-  for (Value operand : getHandles())
-    llvm::append_range(operations, state.getPayloadOps(operand));
+  ValueRange handles = getHandles();
+  // SameOperandsAndResultType guarantees every operand has the result type, so
+  // dispatch on the result type rather than peeking at the first element of
+  // the variadic operand list.
+  Type handleType = getResult().getType();
+  if (isa<TransformHandleTypeInterface>(handleType)) {
+    SmallVector<Operation *> operations;
+    for (Value operand : handles)
+      llvm::append_range(operations, state.getPayloadOps(operand));
+    if (!getDeduplicate()) {
+      results.set(cast<OpResult>(getResult()), operations);
+      return DiagnosedSilenceableFailure::success();
+    }
+
+    SetVector<Operation *> uniqued(operations.begin(), operations.end());
+    results.set(cast<OpResult>(getResult()), uniqued.getArrayRef());
+    return DiagnosedSilenceableFailure::success();
+  }
+
+  if (isa<TransformParamTypeInterface>(handleType)) {
+    SmallVector<Attribute> attrs;
+    for (Value param : handles)
+      llvm::append_range(attrs, state.getParams(param));
+    if (!getDeduplicate()) {
+      results.setParams(cast<OpResult>(getResult()), attrs);
+      return DiagnosedSilenceableFailure::success();
+    }
+
+    // Attributes are uniqued in the MLIRContext, so identity-based
+    // deduplication here is deduplication by value.
+    SetVector<Attribute> uniqued(attrs.begin(), attrs.end());
+    results.setParams(cast<OpResult>(getResult()), uniqued.getArrayRef());
+    return DiagnosedSilenceableFailure::success();
+  }
+
+  assert(isa<TransformValueHandleTypeInterface>(handleType) &&
+         "expected value handle type");
+  SmallVector<Value> payloadValues;
+  for (Value value : handles)
+    llvm::append_range(payloadValues, state.getPayloadValues(value));
   if (!getDeduplicate()) {
-    results.set(llvm::cast<OpResult>(getResult()), operations);
+    results.setValues(cast<OpResult>(getResult()), payloadValues);
     return DiagnosedSilenceableFailure::success();
   }
 
-  SetVector<Operation *> uniqued(operations.begin(), operations.end());
-  results.set(llvm::cast<OpResult>(getResult()), uniqued.getArrayRef());
+  SetVector<Value> uniqued(payloadValues.begin(), payloadValues.end());
+  results.setValues(cast<OpResult>(getResult()), uniqued.getArrayRef());
   return DiagnosedSilenceableFailure::success();
 }
 

diff  --git a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
index a8678a766f6b5..2c8287130b8b4 100644
--- a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
+++ b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
@@ -223,6 +223,11 @@ module attributes { transform.with_named_sequence } {
     transform.yield
   }
 
+  transform.named_sequence @print_dimension_size_match(%arg0: !transform.any_op {transform.readonly}) {
+    transform.test_print_remark_at_operand %arg0, "matched sizes" : !transform.any_op
+    transform.yield
+  }
+
   transform.named_sequence @match_dimension_capture(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
     // Capture multiple dimension values. Suppress failures so we can print them anyway after the capture.
     %0:9 = transform.match.structured failures(suppress) %arg0 
@@ -253,9 +258,25 @@ module attributes { transform.with_named_sequence } {
     transform.yield %0#0 : !transform.any_op
   }
 
+  transform.named_sequence @match_dimension_sizes(%arg0: !transform.any_op {transform.readonly}) -> (!transform.any_op) {
+    %0 = transform.match.structured failures(propagate) %arg0 : (!transform.any_op) -> !transform.any_op {
+    ^bb0(%arg1: !transform.any_op):
+      %1 = transform.match.structured.dim %arg1[all] : (!transform.any_op) -> !transform.param<i64>
+      %c2 = transform.param.constant 2 : i64 -> !transform.param<i64>
+      %c3 = transform.param.constant 3 : i64 -> !transform.param<i64>
+      %c4 = transform.param.constant 4 : i64 -> !transform.param<i64>
+      %2 = transform.merge_handles %c2, %c3, %c4 : !transform.param<i64> // d0, d1, d2 sizes in @payload
+      transform.match.param.cmpi eq %1, %2 : !transform.param<i64>
+
+      transform.match.structured.yield %arg1 : !transform.any_op
+    }
+    transform.yield %0 : !transform.any_op
+  }
+
   transform.sequence failures(propagate) attributes { transform.target_tag = "transform" } {
   ^bb0(%arg0: !transform.any_op):
-    transform.foreach_match in %arg0 @match_dimension_capture -> @do_nothing : (!transform.any_op) -> !transform.any_op
+    %0 = transform.foreach_match in %arg0 @match_dimension_capture -> @do_nothing : (!transform.any_op) -> !transform.any_op
+    %1 = transform.foreach_match in %0 @match_dimension_sizes -> @print_dimension_size_match : (!transform.any_op) -> !transform.any_op
   }
 
   func.func @payload(%lhs: tensor<2x4xf32>, %rhs: tensor<4x3xf32>, %out: tensor<2x3xf32>) attributes { transform.target_tag = "start_here" } {
@@ -269,6 +290,7 @@ module attributes { transform.with_named_sequence } {
     // expected-remark @below {{dimensions except -1: 2 : i64, 3 : i64}}
     // expected-remark @below {{dimensions except 0, -2: 4 : i64}}
     // expected-remark @below {{dimensions 0, -3:}}
+    // expected-remark @below {{matched sizes}}
     linalg.generic {
       indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
       iterator_types = ["parallel", "parallel", "reduction"]

diff  --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index b2d7e7a4bdb5e..bfe61df4e4043 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -1622,9 +1622,64 @@ transform.sequence failures(propagate) {
   test_print_number_of_associated_payload_ir_ops %0 : !transform.any_op
 }
 
+// -----
+
+// Parameter deduplication happens by value
+
+module {
+
+  transform.sequence failures(propagate) {
+  ^bb0(%0: !transform.any_op):
+    %1 = transform.param.constant 1 -> !transform.param<i64>
+    %2 = transform.param.constant 1 -> !transform.param<i64>
+    %3 = transform.param.constant 2 -> !transform.param<i64>
+    %4 = transform.merge_handles %1, %2 { deduplicate } : !transform.param<i64>
+    // expected-remark @below {{1}}
+    test_print_number_of_associated_payload_ir_params %4 : !transform.param<i64>
+
+    %5 = transform.merge_handles %1, %1 { deduplicate } : !transform.param<i64>
+    // expected-remark @below {{1}}
+    test_print_number_of_associated_payload_ir_params %5 : !transform.param<i64>
+
+    %6 = transform.merge_handles %1, %3 { deduplicate } : !transform.param<i64>
+    // expected-remark @below {{2}}
+    test_print_number_of_associated_payload_ir_params %6 : !transform.param<i64>
+
+    %7 = transform.merge_handles %1, %1, %2, %3 : !transform.param<i64>
+    // expected-remark @below {{4}}
+    test_print_number_of_associated_payload_ir_params %7 : !transform.param<i64>
+  }
+}
 
 // -----
 
+%0:3 = "test.get_two_results"() : () -> (i32, i32, f32)
+
+transform.sequence failures(propagate) {
+^bb1(%arg0: !transform.any_op):
+  %1 = transform.structured.match ops{["test.get_two_results"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  %2 = test_produce_value_handle_to_result %1, 0 : (!transform.any_op) -> !transform.any_value
+  %3 = test_produce_value_handle_to_result %1, 1 : (!transform.any_op) -> !transform.any_value
+
+  %4 = transform.merge_handles %2, %2 { deduplicate } : !transform.any_value
+  // expected-remark @below {{1}}
+  test_print_number_of_associated_payload_ir_values %4 : !transform.any_value
+
+  %5 = transform.merge_handles %2, %3 { deduplicate } : !transform.any_value
+  // expected-remark @below {{2}}
+  test_print_number_of_associated_payload_ir_values %5 : !transform.any_value
+
+  %6 = test_produce_value_handle_to_result %1, 0 : (!transform.any_op) -> !transform.any_value
+  %7 = transform.merge_handles %2, %6 { deduplicate } : !transform.any_value
+  // expected-remark @below {{1}}
+  test_print_number_of_associated_payload_ir_values %7 : !transform.any_value
+
+  %8 = transform.merge_handles %2, %2, %3, %4 : !transform.any_value
+  // expected-remark @below {{4}}
+  test_print_number_of_associated_payload_ir_values %8 : !transform.any_value
+}
+// -----
+
 // CHECK-LABEL: func @test_annotation()
 //  CHECK-NEXT:   "test.annotate_me"()
 //  CHECK-SAME:                        broadcast_attr = 2 : i64

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index 8f53b9c927f0f..4b0a8f0c197e9 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -470,6 +470,40 @@ void mlir::test::TestPrintNumberOfAssociatedPayloadIROps::getEffects(
   transform::onlyReadsHandle(getHandle(), effects);
 }
 
+/// Emits exactly one remark with the payload value count of the handle.
+DiagnosedSilenceableFailure
+mlir::test::TestPrintNumberOfAssociatedPayloadIRValues::apply(
+    transform::TransformRewriter &rewriter,
+    transform::TransformResults &results, transform::TransformState &state) {
+  if (!getValueHandle())
+    emitRemark() << 0;
+  else
+    emitRemark() << llvm::range_size(state.getPayloadValues(getValueHandle()));
+  return DiagnosedSilenceableFailure::success();
+}
+
+void mlir::test::TestPrintNumberOfAssociatedPayloadIRValues::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::onlyReadsHandle(getValueHandle(), effects);
+}
+
+/// Emits exactly one remark with the parameter count of the handle.
+DiagnosedSilenceableFailure
+mlir::test::TestPrintNumberOfAssociatedPayloadIRParams::apply(
+    transform::TransformRewriter &rewriter,
+    transform::TransformResults &results, transform::TransformState &state) {
+  if (!getParam())
+    emitRemark() << 0;
+  else
+    emitRemark() << llvm::range_size(state.getParams(getParam()));
+  return DiagnosedSilenceableFailure::success();
+}
+
+void mlir::test::TestPrintNumberOfAssociatedPayloadIRParams::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::onlyReadsHandle(getParam(), effects);
+}
+
 DiagnosedSilenceableFailure
 mlir::test::TestCopyPayloadOp::apply(transform::TransformRewriter &rewriter,
                                      transform::TransformResults &results,

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
index fda619053867f..6c0bef9a81ec6 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -352,6 +352,24 @@ def TestPrintNumberOfAssociatedPayloadIROps
   let cppNamespace = "::mlir::test";
 }
 
+def TestPrintNumberOfAssociatedPayloadIRValues // Remarks payload value count.
+  : Op<Transform_Dialect, "test_print_number_of_associated_payload_ir_values",
+       [DeclareOpInterfaceMethods<TransformOpInterface>,
+        DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+  let arguments = (ins TransformValueHandleTypeInterface:$value_handle);
+  let assemblyFormat = "$value_handle attr-dict `:` type($value_handle)";
+  let cppNamespace = "::mlir::test";
+}
+
+def TestPrintNumberOfAssociatedPayloadIRParams // Remarks parameter count.
+  : Op<Transform_Dialect, "test_print_number_of_associated_payload_ir_params",
+       [DeclareOpInterfaceMethods<TransformOpInterface>,
+        DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+  let arguments = (ins TransformParamTypeInterface:$param);
+  let assemblyFormat = "$param attr-dict `:` type($param)";
+  let cppNamespace = "::mlir::test";
+}
+
 def TestCopyPayloadOp
   : Op<Transform_Dialect, "test_copy_payload",
        [DeclareOpInterfaceMethods<TransformOpInterface>,


        


More information about the Mlir-commits mailing list