[Mlir-commits] [mlir] 1938039 - [mlir][Transform] Allow parameter and value types in merge_handles op
Quinn Dawkins
llvmlistbot at llvm.org
Tue Jun 20 11:02:14 PDT 2023
Author: Quinn Dawkins
Date: 2023-06-20T13:54:38-04:00
New Revision: 193803968f12170c7b381629b15e3bf3c6c10196
URL: https://github.com/llvm/llvm-project/commit/193803968f12170c7b381629b15e3bf3c6c10196
DIFF: https://github.com/llvm/llvm-project/commit/193803968f12170c7b381629b15e3bf3c6c10196.diff
LOG: [mlir][Transform] Allow parameter and value types in merge_handles op
Similar to operation handles, merging handles for other types can be useful to
avoid repetition of common transformations across a set of parameters.
For example, forming a list of static values for comparison rather than
comparing the parameters one at a time.
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D153240
Added:
Modified:
mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
mlir/test/Dialect/Transform/test-interpreter.mlir
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index a5ed64ee262b3..cf49c451286d8 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -559,22 +559,22 @@ def MatchParamCmpIOp : Op<Transform_Dialect, "match.param.cmpi", [
def MergeHandlesOp : TransformDialectOp<"merge_handles",
[DeclareOpInterfaceMethods<TransformOpInterface, ["allowsRepeatedHandleOperands"]>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
- SameOperandsAndResultType]> {
- let summary = "Merges handles into one pointing to the union of payload ops";
+ MatchOpInterface, SameOperandsAndResultType]> {
+ let summary = "Merges handles into one pointing to the union of payloads";
let description = [{
Creates a new Transform IR handle value that points to the same Payload IR
- operations as the operand handles. The Payload IR operations are listed
- in the same order as they are in the operand handles, grouped by operand
- handle, e.g., all Payload IR operations associated with the first handle
- come first, then all Payload IR operations associated with the second handle
- and so on. If `deduplicate` is set, do not add the given Payload IR
- operation more than once to the final list regardless of it coming from the
+ operations/values/parameters as the operand handles. The Payload IR elements
+ are listed in the same order as they are in the operand handles, grouped by
+ operand handle, e.g., all Payload IR associated with the first handle comes
+ first, then all Payload IR associated with the second handle and so on. If
+ `deduplicate` is set, do not add the given Payload IR operation, value, or
+ parameter more than once to the final list regardless of it coming from the
same or different handles. Consumes the operands and produces a new handle.
}];
- let arguments = (ins Variadic<TransformHandleTypeInterface>:$handles,
+ let arguments = (ins Variadic<Transform_AnyHandleOrParamType>:$handles,
UnitAttr:$deduplicate);
- let results = (outs TransformHandleTypeInterface:$result);
+ let results = (outs Transform_AnyHandleOrParamType:$result);
let assemblyFormat = "(`deduplicate` $deduplicate^)? $handles attr-dict `:` type($result)";
let hasFolder = 1;
}
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 72cf6630c3f80..5b71bf895a626 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1255,16 +1255,51 @@ DiagnosedSilenceableFailure
transform::MergeHandlesOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
- SmallVector<Operation *> operations;
- for (Value operand : getHandles())
- llvm::append_range(operations, state.getPayloadOps(operand));
+ ValueRange handles = getHandles();
+ // All handles share the result type (SameOperandsAndResultType, whose
+ // verifier also rejects an empty operand list), so the first handle decides
+ // whether payload ops, parameters or payload values are merged.
+ if (llvm::isa<TransformHandleTypeInterface>(handles.front().getType())) {
+ SmallVector<Operation *> operations;
+ for (Value handle : handles)
+ llvm::append_range(operations, state.getPayloadOps(handle));
+ if (!getDeduplicate()) {
+ results.set(llvm::cast<OpResult>(getResult()), operations);
+ return DiagnosedSilenceableFailure::success();
+ }
+
+ SetVector<Operation *> uniqued(operations.begin(), operations.end());
+ results.set(llvm::cast<OpResult>(getResult()), uniqued.getArrayRef());
+ return DiagnosedSilenceableFailure::success();
+ }
+
+ if (llvm::isa<TransformParamTypeInterface>(handles.front().getType())) {
+ SmallVector<Attribute> attrs;
+ for (Value handle : handles)
+ llvm::append_range(attrs, state.getParams(handle));
+ if (!getDeduplicate()) {
+ results.setParams(llvm::cast<OpResult>(getResult()), attrs);
+ return DiagnosedSilenceableFailure::success();
+ }
+
+ SetVector<Attribute> uniqued(attrs.begin(), attrs.end());
+ results.setParams(llvm::cast<OpResult>(getResult()), uniqued.getArrayRef());
+ return DiagnosedSilenceableFailure::success();
+ }
+
+ assert(
+ llvm::isa<TransformValueHandleTypeInterface>(handles.front().getType()) &&
+ "expected value handle type");
+ SmallVector<Value> payloadValues;
+ for (Value handle : handles)
+ llvm::append_range(payloadValues, state.getPayloadValues(handle));
if (!getDeduplicate()) {
- results.set(llvm::cast<OpResult>(getResult()), operations);
+ results.setValues(llvm::cast<OpResult>(getResult()), payloadValues);
return DiagnosedSilenceableFailure::success();
}
- SetVector<Operation *> uniqued(operations.begin(), operations.end());
- results.set(llvm::cast<OpResult>(getResult()), uniqued.getArrayRef());
+ SetVector<Value> uniqued(payloadValues.begin(), payloadValues.end());
+ results.setValues(llvm::cast<OpResult>(getResult()), uniqued.getArrayRef());
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
index a8678a766f6b5..2c8287130b8b4 100644
--- a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
+++ b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
@@ -223,6 +223,11 @@ module attributes { transform.with_named_sequence } {
transform.yield
}
+ transform.named_sequence @print_dimension_size_match(%arg0: !transform.any_op {transform.readonly}) {
+ transform.test_print_remark_at_operand %arg0, "matched sizes" : !transform.any_op
+ transform.yield
+ }
+
transform.named_sequence @match_dimension_capture(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
// Capture multiple dimension values. Suppress failures so we can print them anyway after the capture.
%0:9 = transform.match.structured failures(suppress) %arg0
@@ -253,9 +258,25 @@ module attributes { transform.with_named_sequence } {
transform.yield %0#0 : !transform.any_op
}
+ transform.named_sequence @match_dimension_sizes(%arg0: !transform.any_op {transform.readonly}) -> (!transform.any_op) {
+ %0 = transform.match.structured failures(propagate) %arg0 : (!transform.any_op) -> !transform.any_op {
+ ^bb0(%arg1: !transform.any_op):
+ %1 = transform.match.structured.dim %arg1[all] : (!transform.any_op) -> !transform.param<i64>
+ %c2 = transform.param.constant 2 : i64 -> !transform.param<i64>
+ %c3 = transform.param.constant 3 : i64 -> !transform.param<i64>
+ %c4 = transform.param.constant 4 : i64 -> !transform.param<i64>
+ %2 = transform.merge_handles %c2, %c3, %c4 : !transform.param<i64> // %2 holds [2, 3, 4], in operand order.
+ transform.match.param.cmpi eq %1, %2 : !transform.param<i64>
+
+ transform.match.structured.yield %arg1 : !transform.any_op
+ }
+ transform.yield %0 : !transform.any_op
+ }
+
transform.sequence failures(propagate) attributes { transform.target_tag = "transform" } {
^bb0(%arg0: !transform.any_op):
- transform.foreach_match in %arg0 @match_dimension_capture -> @do_nothing : (!transform.any_op) -> !transform.any_op
+ %0 = transform.foreach_match in %arg0 @match_dimension_capture -> @do_nothing : (!transform.any_op) -> !transform.any_op
+ %1 = transform.foreach_match in %0 @match_dimension_sizes -> @print_dimension_size_match : (!transform.any_op) -> !transform.any_op
}
func.func @payload(%lhs: tensor<2x4xf32>, %rhs: tensor<4x3xf32>, %out: tensor<2x3xf32>) attributes { transform.target_tag = "start_here" } {
@@ -269,6 +290,7 @@ module attributes { transform.with_named_sequence } {
// expected-remark @below {{dimensions except -1: 2 : i64, 3 : i64}}
// expected-remark @below {{dimensions except 0, -2: 4 : i64}}
// expected-remark @below {{dimensions 0, -3:}}
+ // expected-remark @below {{matched sizes}}
linalg.generic {
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
iterator_types = ["parallel", "parallel", "reduction"]
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index b2d7e7a4bdb5e..bfe61df4e4043 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -1622,9 +1622,64 @@ transform.sequence failures(propagate) {
test_print_number_of_associated_payload_ir_ops %0 : !transform.any_op
}
+// -----
+
+// Parameter deduplication happens by value
+
+module {
+
+ transform.sequence failures(propagate) {
+ ^bb0(%0: !transform.any_op):
+ %1 = transform.param.constant 1 -> !transform.param<i64>
+ %2 = transform.param.constant 1 -> !transform.param<i64>
+ %3 = transform.param.constant 2 -> !transform.param<i64>
+ %4 = transform.merge_handles %1, %2 { deduplicate } : !transform.param<i64>
+ // expected-remark @below {{1}}
+ test_print_number_of_associated_payload_ir_params %4 : !transform.param<i64>
+
+ %5 = transform.merge_handles %1, %1 { deduplicate } : !transform.param<i64>
+ // expected-remark @below {{1}}
+ test_print_number_of_associated_payload_ir_params %5 : !transform.param<i64>
+
+ %6 = transform.merge_handles %1, %3 { deduplicate } : !transform.param<i64>
+ // expected-remark @below {{2}}
+ test_print_number_of_associated_payload_ir_params %6 : !transform.param<i64>
+
+ %7 = transform.merge_handles %1, %1, %2, %3 : !transform.param<i64>
+ // expected-remark @below {{4}}
+ test_print_number_of_associated_payload_ir_params %7 : !transform.param<i64>
+ }
+}
// -----
+%0:3 = "test.get_two_results"() : () -> (i32, i32, f32)
+
+transform.sequence failures(propagate) {
+^bb1(%arg0: !transform.any_op):
+ %1 = transform.structured.match ops{["test.get_two_results"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %2 = test_produce_value_handle_to_result %1, 0 : (!transform.any_op) -> !transform.any_value
+ %3 = test_produce_value_handle_to_result %1, 1 : (!transform.any_op) -> !transform.any_value
+
+ %4 = transform.merge_handles %2, %2 { deduplicate } : !transform.any_value
+ // expected-remark @below {{1}}
+ test_print_number_of_associated_payload_ir_values %4 : !transform.any_value
+
+ %5 = transform.merge_handles %2, %3 { deduplicate } : !transform.any_value
+ // expected-remark @below {{2}}
+ test_print_number_of_associated_payload_ir_values %5 : !transform.any_value
+
+ %6 = test_produce_value_handle_to_result %1, 0 : (!transform.any_op) -> !transform.any_value
+ %7 = transform.merge_handles %2, %6 { deduplicate } : !transform.any_value
+ // expected-remark @below {{1}}
+ test_print_number_of_associated_payload_ir_values %7 : !transform.any_value
+
+ %8 = transform.merge_handles %2, %2, %3, %4 : !transform.any_value
+ // expected-remark @below {{4}}
+ test_print_number_of_associated_payload_ir_values %8 : !transform.any_value
+}
+// -----
+
// CHECK-LABEL: func @test_annotation()
// CHECK-NEXT: "test.annotate_me"()
// CHECK-SAME: broadcast_attr = 2 : i64
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index 8f53b9c927f0f..4b0a8f0c197e9 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -470,6 +470,42 @@ void mlir::test::TestPrintNumberOfAssociatedPayloadIROps::getEffects(
transform::onlyReadsHandle(getHandle(), effects);
}
+DiagnosedSilenceableFailure
+mlir::test::TestPrintNumberOfAssociatedPayloadIRValues::apply(
+ transform::TransformRewriter &rewriter,
+ transform::TransformResults &results, transform::TransformState &state) {
+ if (!getValueHandle()) {
+ emitRemark() << 0;
+ // Stop here instead of querying the state with a null handle.
+ return DiagnosedSilenceableFailure::success();
+ }
+ emitRemark() << llvm::range_size(state.getPayloadValues(getValueHandle()));
+ return DiagnosedSilenceableFailure::success();
+}
+
+void mlir::test::TestPrintNumberOfAssociatedPayloadIRValues::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ transform::onlyReadsHandle(getValueHandle(), effects);
+}
+
+DiagnosedSilenceableFailure
+mlir::test::TestPrintNumberOfAssociatedPayloadIRParams::apply(
+ transform::TransformRewriter &rewriter,
+ transform::TransformResults &results, transform::TransformState &state) {
+ if (!getParam()) {
+ emitRemark() << 0;
+ // Stop here instead of querying the state with a null handle.
+ return DiagnosedSilenceableFailure::success();
+ }
+ emitRemark() << llvm::range_size(state.getParams(getParam()));
+ return DiagnosedSilenceableFailure::success();
+}
+
+void mlir::test::TestPrintNumberOfAssociatedPayloadIRParams::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ transform::onlyReadsHandle(getParam(), effects);
+}
+
DiagnosedSilenceableFailure
mlir::test::TestCopyPayloadOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
index fda619053867f..6c0bef9a81ec6 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -352,6 +352,24 @@ def TestPrintNumberOfAssociatedPayloadIROps
let cppNamespace = "::mlir::test";
}
+def TestPrintNumberOfAssociatedPayloadIRValues // Emits a remark with the number of associated payload values.
+ : Op<Transform_Dialect, "test_print_number_of_associated_payload_ir_values",
+ [DeclareOpInterfaceMethods<TransformOpInterface>,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+ let arguments = (ins TransformValueHandleTypeInterface:$value_handle);
+ let assemblyFormat = "$value_handle attr-dict `:` type($value_handle)";
+ let cppNamespace = "::mlir::test";
+}
+
+def TestPrintNumberOfAssociatedPayloadIRParams // Emits a remark with the number of associated parameters.
+ : Op<Transform_Dialect, "test_print_number_of_associated_payload_ir_params",
+ [DeclareOpInterfaceMethods<TransformOpInterface>,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+ let arguments = (ins TransformParamTypeInterface:$param);
+ let assemblyFormat = "$param attr-dict `:` type($param)";
+ let cppNamespace = "::mlir::test";
+}
+
def TestCopyPayloadOp
: Op<Transform_Dialect, "test_copy_payload",
[DeclareOpInterfaceMethods<TransformOpInterface>,
More information about the Mlir-commits
mailing list