[Mlir-commits] [mlir] 288529e - [mlir][transform] Clean up SplitHandlesOp
Matthias Springer
llvmlistbot at llvm.org
Fri May 5 06:30:09 PDT 2023
Author: Matthias Springer
Date: 2023-05-05T22:29:43+09:00
New Revision: 288529e730f8f225c96ec2f816cef94593bb61a8
URL: https://github.com/llvm/llvm-project/commit/288529e730f8f225c96ec2f816cef94593bb61a8
DIFF: https://github.com/llvm/llvm-project/commit/288529e730f8f225c96ec2f816cef94593bb61a8.diff
LOG: [mlir][transform] Clean up SplitHandlesOp
* Rename to `SplitHandleOp`: it splits a single handle.
* Drop `num_result_handles` attribute: it is redundant and can be inferred from the number of results.
* Improve documentation and minor code cleanups.
Differential Revision: https://reviews.llvm.org/D149937
Added:
Modified:
mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/test/Dialect/Transform/test-interpreter.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 4bee1d4c6eb2e..53c1c0af54c0f 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -509,26 +509,28 @@ def NamedSequenceOp : TransformDialectOp<"named_sequence",
}];
}
-def SplitHandlesOp : TransformDialectOp<"split_handles",
+def SplitHandleOp : TransformDialectOp<"split_handle",
[FunctionalStyleTransformOpTrait,
DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
- let summary = "Splits handles from a union of payload ops to a list";
+ let summary = "Splits a handle of payload ops into handles with a single op";
let description = [{
- Creates `num_result_handles` transform IR handles extracted from the
- `handle` operand. The resulting Payload IR operation handles are listed
- in the same order as the operations appear in the source `handle`.
- This is useful for ensuring a statically known number of operations are
- tracked by the source `handle` and to extract them into individual handles
- that can be further manipulated in isolation.
-
- This operation succeeds and returns `num_result_handles` if the statically
- specified `num_result_handles` corresponds to the dynamic number of
- operations contained in the source `handle`. Otherwise it silently fails.
+ Splits `handle` into one or more handles, as specified by the number of
+ results of this operation. `handle` should be mapped to as many payload ops
+ as there are results; otherwise, this transform emits a silenceable error.
+ Each result handle is mapped to exactly one payload op. The order
+ of the payload ops is preserved, i.e., the i-th payload op is mapped to the
+ i-th result handle.
+
+ This operation is useful for ensuring that a statically known number of
+ operations is tracked by the source `handle` and for extracting them into
+ individual handles that can be further manipulated in isolation.
+
+ If `handle` is empty, this transform will succeed and all result handles
+ are empty.
}];
- let arguments = (ins TransformHandleTypeInterface:$handle,
- I64Attr:$num_result_handles);
+ let arguments = (ins TransformHandleTypeInterface:$handle);
let results = (outs Variadic<TransformHandleTypeInterface>:$results);
let builders = [
@@ -536,8 +538,7 @@ def SplitHandlesOp : TransformDialectOp<"split_handles",
];
let assemblyFormat = [{
- $handle `in` `[` $num_result_handles `]`
- attr-dict `:` functional-type(operands, results)
+ $handle attr-dict `:` functional-type(operands, results)
}];
}
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 138f4ff69cfbc..5c39ccc3ea510 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1488,48 +1488,43 @@ LogicalResult transform::NamedSequenceOp::verify() {
}
//===----------------------------------------------------------------------===//
-// SplitHandlesOp
+// SplitHandleOp
//===----------------------------------------------------------------------===//
-void transform::SplitHandlesOp::build(OpBuilder &builder,
- OperationState &result, Value target,
- int64_t numResultHandles) {
+void transform::SplitHandleOp::build(OpBuilder &builder, OperationState &result,
+ Value target, int64_t numResultHandles) {
result.addOperands(target);
- result.addAttribute(SplitHandlesOp::getNumResultHandlesAttrName(result.name),
- builder.getI64IntegerAttr(numResultHandles));
auto pdlOpType = pdl::OperationType::get(builder.getContext());
result.addTypes(SmallVector<pdl::OperationType>(numResultHandles, pdlOpType));
}
DiagnosedSilenceableFailure
-transform::SplitHandlesOp::apply(transform::TransformResults &results,
- transform::TransformState &state) {
- int64_t numResultHandles =
- getHandle() ? state.getPayloadOps(getHandle()).size() : 0;
- int64_t expectedNumResultHandles = getNumResultHandles();
- if (numResultHandles != expectedNumResultHandles) {
- // Empty input handle corner case: always propagates empty handles in both
- // suppress and propagate modes.
- if (numResultHandles == 0) {
- for (OpResult result : getResults())
- results.set(result, {});
- return DiagnosedSilenceableFailure::success();
- }
+transform::SplitHandleOp::apply(transform::TransformResults &results,
+ transform::TransformState &state) {
+ int64_t numPayloadOps = state.getPayloadOps(getHandle()).size();
- // If the input handle was not empty and the number of result handles does
- // not match, this is a legit silenceable error.
- return emitSilenceableError()
- << getHandle() << " expected to contain " << expectedNumResultHandles
- << " operation handles but it contains " << numResultHandles
- << " handles";
+ // Empty handle corner case: all result handles are empty.
+ if (numPayloadOps == 0) {
+ for (OpResult result : getResults())
+ results.set(result, {});
+ return DiagnosedSilenceableFailure::success();
}
- // Normal successful case.
+
+ // If the input handle is not empty and the number of payload ops does not
+ // match the number of results, this is a genuine silenceable error.
+ if (numPayloadOps != getNumResults())
+ return emitSilenceableError()
+ << getHandle() << " expected to contain " << getNumResults()
+ << " payload ops but it contains " << numPayloadOps
+ << " payload ops";
+
for (const auto &en : llvm::enumerate(state.getPayloadOps(getHandle())))
results.set(getResults()[en.index()].cast<OpResult>(), en.value());
+
return DiagnosedSilenceableFailure::success();
}
-void transform::SplitHandlesOp::getEffects(
+void transform::SplitHandleOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
onlyReadsHandle(getHandle(), effects);
producesHandle(getResults(), effects);
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index 3668abb1c3a42..9ed99577fbd0b 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -818,7 +818,7 @@ transform.sequence failures(propagate) {
// -----
-func.func @split_handles(%a: index, %b: index, %c: index) {
+func.func @split_handle(%a: index, %b: index, %c: index) {
%0 = arith.muli %a, %b : index
%1 = arith.muli %a, %c : index
return
@@ -827,17 +827,17 @@ func.func @split_handles(%a: index, %b: index, %c: index) {
transform.sequence failures(propagate) {
^bb1(%fun: !pdl.operation):
%muli = transform.structured.match ops{["arith.muli"]} in %fun : (!pdl.operation) -> !pdl.operation
- %h:2 = split_handles %muli in [2] : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
+ %h:2 = split_handle %muli : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
// expected-remark @below {{1}}
transform.test_print_number_of_associated_payload_ir_ops %h#0
%muli_2 = transform.structured.match ops{["arith.muli"]} in %fun : (!pdl.operation) -> !pdl.operation
- // expected-error @below {{expected to contain 3 operation handles but it contains 2 handles}}
- %h_2:3 = split_handles %muli_2 in [3] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation)
+ // expected-error @below {{expected to contain 3 payload ops but it contains 2 payload ops}}
+ %h_2:3 = split_handle %muli_2 : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation)
}
// -----
-func.func @split_handles(%a: index, %b: index, %c: index) {
+func.func @split_handle(%a: index, %b: index, %c: index) {
%0 = arith.muli %a, %b : index
%1 = arith.muli %a, %c : index
return
@@ -846,12 +846,12 @@ func.func @split_handles(%a: index, %b: index, %c: index) {
transform.sequence failures(suppress) {
^bb1(%fun: !pdl.operation):
%muli = transform.structured.match ops{["arith.muli"]} in %fun : (!pdl.operation) -> !pdl.operation
- %h:2 = split_handles %muli in [2] : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
+ %h:2 = split_handle %muli : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
// expected-remark @below {{1}}
transform.test_print_number_of_associated_payload_ir_ops %h#0
%muli_2 = transform.structured.match ops{["arith.muli"]} in %fun : (!pdl.operation) -> !pdl.operation
// Silenceable failure and all handles are now empty.
- %h_2:3 = split_handles %muli_2 in [3] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation)
+ %h_2:3 = split_handle %muli_2 : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation)
// expected-remark @below {{0}}
transform.test_print_number_of_associated_payload_ir_ops %h_2#0
}
@@ -966,7 +966,7 @@ transform.with_pdl_patterns {
// -----
-func.func @split_handles(%a: index, %b: index, %c: index) {
+func.func @split_handle(%a: index, %b: index, %c: index) {
%0 = arith.muli %a, %b : index
%1 = arith.muli %a, %c : index
return
@@ -975,8 +975,8 @@ func.func @split_handles(%a: index, %b: index, %c: index) {
transform.sequence -> !pdl.operation failures(propagate) {
^bb1(%fun: !pdl.operation):
%muli = transform.structured.match ops{["arith.muli"]} in %fun : (!pdl.operation) -> !pdl.operation
- // expected-error @below {{expected to contain 3 operation handles but it contains 2 handles}}
- %h_2:3 = split_handles %muli in [3] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation)
+ // expected-error @below {{expected to contain 3 payload ops but it contains 2 payload ops}}
+ %h_2:3 = split_handle %muli : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation)
/// Test that yield does not crash in the presence of silenceable error in
/// propagate mode.
yield %fun : !pdl.operation
@@ -988,7 +988,7 @@ transform.sequence -> !transform.any_op failures(suppress) {
^bb0(%arg0: !transform.any_op):
%muli = transform.structured.match ops{["arith.muli"]} in %arg0 : (!transform.any_op) -> !transform.any_op
// Edge case propagating empty handles in splitting.
- %0:3 = split_handles %muli in [3] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ %0:3 = split_handle %muli : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
// Test does not crash when accessing the empty handle.
yield %0#0 : !transform.any_op
}
More information about the Mlir-commits mailing list