[Mlir-commits] [mlir] 4299be1 - [mlir] optionally allow repeated handles in transform dialect
Alex Zinenko
llvmlistbot at llvm.org
Mon Dec 19 01:02:11 PST 2022
Author: Alex Zinenko
Date: 2022-12-19T09:02:03Z
New Revision: 4299be1a087edc434fd0111a3316931d7ebc0638
URL: https://github.com/llvm/llvm-project/commit/4299be1a087edc434fd0111a3316931d7ebc0638
DIFF: https://github.com/llvm/llvm-project/commit/4299be1a087edc434fd0111a3316931d7ebc0638.diff
LOG: [mlir] optionally allow repeated handles in transform dialect
Some operations may be able to deal with several handle operands pointing to
the same payload operation even when one of these handles is consumed. For
example, merge_handles with deduplication does not actually destroy payload
operations and is specifically intended to remove such duplicates. Add a
method to the transform interface to allow ops to declare that they can
support repeated handles.
Reviewed By: springerm
Differential Revision: https://reviews.llvm.org/D140124
Added:
Modified:
mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/test/Dialect/Transform/expensive-checks.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
index 33781536239e9..c760cb12598bc 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
@@ -48,6 +48,19 @@ def TransformOpInterface : OpInterface<"TransformOpInterface"> {
"::mlir::transform::TransformResults &":$transformResults,
"::mlir::transform::TransformState &":$state
)>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Indicates whether the op instance allows its handle operands to be
+ associated with the same payload operations.
+ }],
+ /*returnType=*/"bool",
+ /*name=*/"allowsRepeatedHandleOperands",
+ /*arguments=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return false;
+ }]
+ >,
];
let extraSharedClassDeclaration = [{
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index f6788dece6ca6..c813a64bd7e2c 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -210,7 +210,7 @@ def GetProducerOfOperand : TransformDialectOp<"get_producer_of_operand",
}
def MergeHandlesOp : TransformDialectOp<"merge_handles",
- [DeclareOpInterfaceMethods<TransformOpInterface>,
+ [DeclareOpInterfaceMethods<TransformOpInterface, ["allowsRepeatedHandleOperands"]>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
SameOperandsAndResultType]> {
let summary = "Merges handles into one pointing to the union of payload ops";
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 9b136cccbe6f1..10a4381e1d75a 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -189,7 +189,19 @@ LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
+  // Handles invalidated by previously applied transforms must be reported even
+  // when this op allows repeated handle operands: that relaxation only covers
+  // handles invalidated by this op consuming one of its other operands.
+  if (transform.allowsRepeatedHandleOperands()) {
+    for (OpOperand &operand : transform->getOpOperands()) {
+      auto it = invalidatedHandles.find(operand.get());
+      if (it != invalidatedHandles.end())
+        return it->getSecond()(transform->getLoc()), failure();
+    }
+  }
+
for (OpOperand &target : transform->getOpOperands()) {
// If the operand uses an invalidated handle, report it.
auto it = invalidatedHandles.find(target.get());
- if (it != invalidatedHandles.end())
+ if (!transform.allowsRepeatedHandleOperands() &&
+ it != invalidatedHandles.end())
return it->getSecond()(transform->getLoc()), failure();
// Invalidate handles pointing to the operations nested in the operation
@@ -201,6 +213,7 @@ LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
if (llvm::any_of(effects, consumesTarget))
recordHandleInvalidation(target);
}
+
return success();
}
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index f7a8ee1a979a0..629b1f6a073b1 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -449,6 +449,11 @@ transform::MergeHandlesOp::apply(transform::TransformResults &results,
return DiagnosedSilenceableFailure::success();
}
+bool transform::MergeHandlesOp::allowsRepeatedHandleOperands() {
+ // Handles may be the same if deduplicating is enabled.
+ return getDeduplicate();
+}
+
void transform::MergeHandlesOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
consumesHandle(getHandles(), effects);
diff --git a/mlir/test/Dialect/Transform/expensive-checks.mlir b/mlir/test/Dialect/Transform/expensive-checks.mlir
index 550ede6353121..abc09d799ee23 100644
--- a/mlir/test/Dialect/Transform/expensive-checks.mlir
+++ b/mlir/test/Dialect/Transform/expensive-checks.mlir
@@ -99,3 +99,17 @@ module {
transform.test_consume_operand %1, %2
}
}
+
+// -----
+
+// Deduplication attribute allows "merge_handles" to take repeated operands.
+
+module {
+
+ transform.sequence failures(propagate) {
+ ^bb0(%0: !pdl.operation):
+ %1 = transform.test_copy_payload %0
+ %2 = transform.test_copy_payload %0
+ transform.merge_handles %1, %2 { deduplicate } : !pdl.operation
+ }
+}
More information about the Mlir-commits
mailing list