[Mlir-commits] [mlir] 709098f - [mlir][transform] SplitHandleOp: add additional distribution options
Matthias Springer
llvmlistbot at llvm.org
Tue May 9 02:43:49 PDT 2023
Author: Matthias Springer
Date: 2023-05-09T11:38:18+02:00
New Revision: 709098fb38b50ea376704f83b373bd41f0ed0cba
URL: https://github.com/llvm/llvm-project/commit/709098fb38b50ea376704f83b373bd41f0ed0cba
DIFF: https://github.com/llvm/llvm-project/commit/709098fb38b50ea376704f83b373bd41f0ed0cba.diff
LOG: [mlir][transform] SplitHandleOp: add additional distribution options
Add options to handle cases where there are not enough or too many payload ops mapped to the given handle.
Differential Revision: https://reviews.llvm.org/D149955
Added:
Modified:
mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/test/Dialect/Transform/test-interpreter.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 8154835546bb6..0408b341b7c63 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -521,8 +521,8 @@ def SplitHandleOp : TransformDialectOp<"split_handle",
let description = [{
Splits `handle` into one or multiple handles, as specified by the number
of results of this operation. `handle` should be mapped to as many payload
- ops as there are results. Otherwise, this transform will fail silently.
- Each result handle is mapped to exactly one payload op. The order
+ ops as there are results. Otherwise, this transform will fail silently by
+ default (see below for options that relax this). If the counts match, each
+ result handle is mapped to exactly one payload op. The order
of the payload ops is preserved, i.e., the i-th payload op is mapped to the
i-th result handle.
@@ -530,12 +530,23 @@ def SplitHandleOp : TransformDialectOp<"split_handle",
operations are tracked by the source `handle` and to extract them into
individual handles that can be further manipulated in isolation.
- If `handle` is empty, this transform will succeed and all result handles
- are empty.
+ If there are more payload ops than results, the remaining ops are mapped to
+ the result with index `overflow_result`. If no `overflow_result` is
+ specified, the transform fails silently.
+
+ If there are fewer payload ops than results, the transform fails silently
+ if `fail_on_payload_too_small` is set to "true" (the default). Otherwise, it
+ succeeds and the remaining result handles are not mapped to any op. It also
+ succeeds if `handle` is empty and `pass_through_empty_handle` is set to
+ "true" (the default), regardless of `fail_on_payload_too_small`.
}];
- let arguments = (ins TransformHandleTypeInterface:$handle);
+ let arguments = (ins TransformHandleTypeInterface:$handle,
+ DefaultValuedAttr<BoolAttr, "true">:$pass_through_empty_handle,
+ DefaultValuedAttr<BoolAttr, "true">:$fail_on_payload_too_small,
+ OptionalAttr<I64Attr>:$overflow_result);
let results = (outs Variadic<TransformHandleTypeInterface>:$results);
+ let hasVerifier = 1;
let builders = [
OpBuilder<(ins "Value":$handle, "int64_t":$numResultHandles)>
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 5c39ccc3ea510..62ef94d54e477 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1502,24 +1502,47 @@ DiagnosedSilenceableFailure
transform::SplitHandleOp::apply(transform::TransformResults &results,
transform::TransformState &state) {
int64_t numPayloadOps = state.getPayloadOps(getHandle()).size();
-
- // Empty handle corner case: all result handles are empty.
- if (numPayloadOps == 0) {
- for (OpResult result : getResults())
- results.set(result, {});
- return DiagnosedSilenceableFailure::success();
- }
-
- // If the input handle was not empty and the number of payload ops does not
- // match, this is a legit silenceable error.
- if (numPayloadOps != getNumResults())
+ auto produceNumOpsError = [&]() {
return emitSilenceableError()
- << getHandle() << " expected to contain " << getNumResults()
+ << getHandle() << " expected to contain " << this->getNumResults()
<< " payload ops but it contains " << numPayloadOps
<< " payload ops";
+ };
- for (const auto &en : llvm::enumerate(state.getPayloadOps(getHandle())))
- results.set(getResults()[en.index()].cast<OpResult>(), en.value());
+ // Fail if there are more payload ops than results and no overflow result was
+ // specified.
+ if (numPayloadOps > getNumResults() && !getOverflowResult().has_value())
+ return produceNumOpsError();
+
+ // Fail if there are more results than payload ops. Unless:
+ // - "fail_on_payload_too_small" is set to "false", or
+ // - "pass_through_empty_handle" is set to "true" and there are 0 payload ops.
+ if (numPayloadOps < getNumResults() && getFailOnPayloadTooSmall() &&
+ !(numPayloadOps == 0 && getPassThroughEmptyHandle()))
+ return produceNumOpsError();
+
+ // Distribute payload ops: the i-th op goes to the i-th result, ops beyond the
+ // last result go to the overflow result, and results beyond the last op stay
+ // empty.
+ SmallVector<SmallVector<Operation *, 1>> resultHandles(getNumResults(), {});
+ // Only reserve when there actually is overflow: otherwise
+ // `numPayloadOps - getNumResults()` is negative and, once converted to an
+ // unsigned size, requests a huge capacity (reachable e.g. with an empty handle,
+ // the default `pass_through_empty_handle` and an `overflow_result`). The
+ // overflow result also keeps its own regular payload op, hence the `+ 1`.
+ if (getOverflowResult() && numPayloadOps > getNumResults())
+ resultHandles[*getOverflowResult()].reserve(numPayloadOps -
+ getNumResults() + 1);
+ for (auto &&en : llvm::enumerate(state.getPayloadOps(getHandle()))) {
+ int64_t resultNum = en.index();
+ if (resultNum >= getNumResults())
+ resultNum = *getOverflowResult();
+ resultHandles[resultNum].push_back(en.value());
+ }
+
+ // Set transform op results.
+ for (auto &&it : llvm::enumerate(resultHandles))
+ results.set(getResult(it.index()).cast<OpResult>(), it.value());
return DiagnosedSilenceableFailure::success();
}
@@ -1532,6 +1548,18 @@ void transform::SplitHandleOp::getEffects(
// manipulation.
}
+/// Verifies that `overflow_result`, if present, is the index of one of this
+/// op's results. The value is compared as `uint64_t`, so a negative index
+/// wraps to a huge value and is rejected by the upper-bound check alone; an
+/// explicit `>= 0` test would be tautological for the unsigned getter generated
+/// for signless integer attributes.
+LogicalResult transform::SplitHandleOp::verify() {
+ if (getOverflowResult().has_value() &&
+ static_cast<uint64_t>(*getOverflowResult()) >= getNumResults())
+ return emitOpError("overflow_result is not a valid result index");
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// PDLMatchOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index 9ed99577fbd0b..8ceb72d8d46b3 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -858,6 +858,47 @@ transform.sequence failures(suppress) {
// -----
+func.func @split_handle(%a: index, %b: index, %c: index) {
+ %0 = arith.muli %a, %b : index
+ %1 = arith.muli %a, %c : index
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%fun: !pdl.operation):
+ %muli_2 = transform.structured.match ops{["arith.muli"]} in %fun : (!pdl.operation) -> !pdl.operation
+ // No error, last result handle is empty.
+ %h:3 = split_handle %muli_2 {fail_on_payload_too_small = false} : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation)
+ // expected-remark @below {{1}}
+ transform.test_print_number_of_associated_payload_ir_ops %h#0
+ // expected-remark @below {{1}}
+ transform.test_print_number_of_associated_payload_ir_ops %h#1
+ // expected-remark @below {{0}}
+ transform.test_print_number_of_associated_payload_ir_ops %h#2
+}
+
+// -----
+
+func.func @split_handle(%a: index, %b: index, %c: index) {
+ %0 = arith.muli %a, %b : index
+ %1 = arith.muli %a, %c : index
+ %2 = arith.muli %a, %c : index
+ %3 = arith.muli %a, %c : index
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%fun: !pdl.operation):
+ %muli_2 = transform.structured.match ops{["arith.muli"]} in %fun : (!pdl.operation) -> !pdl.operation
+ // Four payload ops but only two results: the ops that do not fit (#2 and #3)
+ // overflow into result #0, which thus holds 3 ops; result #1 holds 1 op.
+ %h:2 = split_handle %muli_2 {overflow_result = 0} : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
+ // expected-remark @below {{3}}
+ transform.test_print_number_of_associated_payload_ir_ops %h#0
+ // expected-remark @below {{1}}
+ transform.test_print_number_of_associated_payload_ir_ops %h#1
+}
+
+// -----
+
"test.some_op"() : () -> ()
"other_dialect.other_op"() : () -> ()
More information about the Mlir-commits
mailing list