[Mlir-commits] [mlir] ff2f6c3 - [mlir][transform] Extract getConsumedHandleOpOperands helper function
Matthias Springer
llvmlistbot at llvm.org
Fri Jun 9 02:32:36 PDT 2023
Author: Matthias Springer
Date: 2023-06-09T11:32:27+02:00
New Revision: ff2f6c34be71cecea54483ad0a98385bff000b29
URL: https://github.com/llvm/llvm-project/commit/ff2f6c34be71cecea54483ad0a98385bff000b29
DIFF: https://github.com/llvm/llvm-project/commit/ff2f6c34be71cecea54483ad0a98385bff000b29.diff
LOG: [mlir][transform] Extract getConsumedHandleOpOperands helper function
This function is extracted from `TransformState::applyTransform`.
Differential Revision: https://reviews.llvm.org/D152374
Added:
Modified:
mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index a283f974a68f9..5605c4804c680 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -64,6 +64,10 @@ void forwardTerminatorOperands(Block *block, transform::TransformState &state,
/// outside of test cases.
TransformState makeTransformStateForTesting(Region *region,
Operation *payloadRoot);
+
+/// Returns all operands consumed (freed) by the given op; types not checked.
+SmallVector<OpOperand *>
+getConsumedHandleOpOperands(transform::TransformOpInterface transformOp);
} // namespace detail
/// Options controlling the application of transform operations by the
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
index c3184c1e35bb8..07f8ebcf8d6d9 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
@@ -100,6 +100,11 @@ def TransformOpInterface : OpInterface<"TransformOpInterface"> {
diag.attachNote(target->getLoc()) << "when applied to this op";
return diag;
}
+
+ /// Returns all operands consumed (freed) by this op; types not checked.
+ ::llvm::SmallVector<OpOperand *> getConsumedHandleOpOperands() {
+ return ::mlir::transform::detail::getConsumedHandleOpOperands($_op);
+ }
}];
let verify = [{
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 57b3226c7b3c1..e69e56436322c 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -843,21 +843,8 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
}
// Find which operands are consumed.
- DenseSet<unsigned> consumedOperands;
- auto memEffectInterface =
- cast<MemoryEffectOpInterface>(transform.getOperation());
- SmallVector<MemoryEffects::EffectInstance, 2> effects;
- for (OpOperand &target : transform->getOpOperands()) {
- effects.clear();
- memEffectInterface.getEffectsOnValue(target.get(), effects);
- if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
- return isa<transform::TransformMappingResource>(
- effect.getResource()) &&
- isa<MemoryEffects::Free>(effect.getEffect());
- })) {
- consumedOperands.insert(target.getOperandNumber());
- }
- }
+ SmallVector<OpOperand *> consumedOperands =
+ transform.getConsumedHandleOpOperands();
// Remember the results of the payload ops associated with the consumed
// op handles or the ops defining the value handles so we can drop the
@@ -869,8 +856,8 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
DenseSet<Operation *> consumedPayloadOps;
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
- for (unsigned index : consumedOperands) {
- Value operand = transform->getOperand(index);
+ for (OpOperand *opOperand : consumedOperands) {
+ Value operand = opOperand->get();
if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
for (Operation *payloadOp : getPayloadOps(operand)) {
llvm::append_range(origOpFlatResults, payloadOp->getResults());
@@ -901,7 +888,7 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
DiagnosedDefiniteFailure diag =
emitDefiniteFailure(transform->getLoc())
<< "unexpectedly consumed a value that is not a handle as operand #"
- << index;
+ << opOperand->getOperandNumber();
diag.attachNote(operand.getLoc())
<< "value defined here with type " << operand.getType();
return diag;
@@ -923,8 +910,8 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
// Remove the mapping for the operand if it is consumed by the operation. This
// allows us to catch use-after-free with assertions later on.
- for (unsigned index : consumedOperands) {
- Value operand = transform->getOperand(index);
+ for (OpOperand *opOperand : consumedOperands) {
+ Value operand = opOperand->get();
if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
forgetMapping(operand, origOpFlatResults);
} else if (llvm::isa<TransformValueHandleTypeInterface>(
@@ -1593,6 +1580,27 @@ void transform::getConsumedBlockArguments(
// Utilities for TransformOpInterface.
//===----------------------------------------------------------------------===//
+SmallVector<OpOperand *> transform::detail::getConsumedHandleOpOperands(
+ TransformOpInterface transformOp) {
+ SmallVector<OpOperand *> consumedOperands;
+ consumedOperands.reserve(transformOp->getNumOperands()); // Upper bound.
+ auto memEffectInterface =
+ cast<MemoryEffectOpInterface>(transformOp.getOperation()); // Op must implement it.
+ SmallVector<MemoryEffects::EffectInstance, 2> effects; // Reused per operand.
+ for (OpOperand &target : transformOp->getOpOperands()) {
+ effects.clear();
+ memEffectInterface.getEffectsOnValue(target.get(), effects);
+ if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
+ return isa<transform::TransformMappingResource>(
+ effect.getResource()) &&
+ isa<MemoryEffects::Free>(effect.getEffect());
+ })) { // Operand frees the transform mapping, i.e. it is consumed.
+ consumedOperands.push_back(&target); // Type not checked; callers must.
+ }
+ }
+ return consumedOperands;
+}
+
LogicalResult transform::detail::verifyTransformOpInterface(Operation *op) {
auto iface = cast<MemoryEffectOpInterface>(op);
SmallVector<MemoryEffects::EffectInstance> effects;
More information about the Mlir-commits mailing list