[Mlir-commits] [mlir] ff2f6c3 - [mlir][transform] Extract getConsumedHandleOpOperands helper function

Matthias Springer llvmlistbot at llvm.org
Fri Jun 9 02:32:36 PDT 2023


Author: Matthias Springer
Date: 2023-06-09T11:32:27+02:00
New Revision: ff2f6c34be71cecea54483ad0a98385bff000b29

URL: https://github.com/llvm/llvm-project/commit/ff2f6c34be71cecea54483ad0a98385bff000b29
DIFF: https://github.com/llvm/llvm-project/commit/ff2f6c34be71cecea54483ad0a98385bff000b29.diff

LOG: [mlir][transform] Extract getConsumedHandleOpOperands helper function

This function is extracted from `TransformState::applyTransform`.

Differential Revision: https://reviews.llvm.org/D152374

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
    mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
    mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index a283f974a68f9..5605c4804c680 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -64,6 +64,10 @@ void forwardTerminatorOperands(Block *block, transform::TransformState &state,
 /// outside of test cases.
 TransformState makeTransformStateForTesting(Region *region,
                                             Operation *payloadRoot);
+
+/// Returns all operands (expected to be handles) consumed by the given op.
+SmallVector<OpOperand *>
+getConsumedHandleOpOperands(transform::TransformOpInterface transformOp);
 } // namespace detail
 
 /// Options controlling the application of transform operations by the

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
index c3184c1e35bb8..07f8ebcf8d6d9 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
@@ -100,6 +100,11 @@ def TransformOpInterface : OpInterface<"TransformOpInterface"> {
       diag.attachNote(target->getLoc()) << "when applied to this op";
       return diag;
     }
+
+    /// Returns all operands (expected to be handles) consumed by this op.
+    ::llvm::SmallVector<OpOperand *> getConsumedHandleOpOperands() {
+      return ::mlir::transform::detail::getConsumedHandleOpOperands($_op);
+    }
   }];
 
   let verify = [{

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 57b3226c7b3c1..e69e56436322c 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -843,21 +843,8 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
   }
 
   // Find which operands are consumed.
-  DenseSet<unsigned> consumedOperands;
-  auto memEffectInterface =
-      cast<MemoryEffectOpInterface>(transform.getOperation());
-  SmallVector<MemoryEffects::EffectInstance, 2> effects;
-  for (OpOperand &target : transform->getOpOperands()) {
-    effects.clear();
-    memEffectInterface.getEffectsOnValue(target.get(), effects);
-    if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
-          return isa<transform::TransformMappingResource>(
-                     effect.getResource()) &&
-                 isa<MemoryEffects::Free>(effect.getEffect());
-        })) {
-      consumedOperands.insert(target.getOperandNumber());
-    }
-  }
+  SmallVector<OpOperand *> consumedOperands =
+      transform.getConsumedHandleOpOperands();
 
   // Remember the results of the payload ops associated with the consumed
   // op handles or the ops defining the value handles so we can drop the
@@ -869,8 +856,8 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
   DenseSet<Operation *> consumedPayloadOps;
 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
-  for (unsigned index : consumedOperands) {
-    Value operand = transform->getOperand(index);
+  for (OpOperand *opOperand : consumedOperands) {
+    Value operand = opOperand->get();
     if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
       for (Operation *payloadOp : getPayloadOps(operand)) {
         llvm::append_range(origOpFlatResults, payloadOp->getResults());
@@ -901,7 +888,7 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
     DiagnosedDefiniteFailure diag =
         emitDefiniteFailure(transform->getLoc())
         << "unexpectedly consumed a value that is not a handle as operand #"
-        << index;
+        << opOperand->getOperandNumber();
     diag.attachNote(operand.getLoc())
         << "value defined here with type " << operand.getType();
     return diag;
@@ -923,8 +910,8 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
 
   // Remove the mapping for the operand if it is consumed by the operation. This
   // allows us to catch use-after-free with assertions later on.
-  for (unsigned index : consumedOperands) {
-    Value operand = transform->getOperand(index);
+  for (OpOperand *opOperand : consumedOperands) {
+    Value operand = opOperand->get();
     if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
       forgetMapping(operand, origOpFlatResults);
     } else if (llvm::isa<TransformValueHandleTypeInterface>(
@@ -1593,6 +1580,27 @@ void transform::getConsumedBlockArguments(
 // Utilities for TransformOpInterface.
 //===----------------------------------------------------------------------===//
 
+SmallVector<OpOperand *> transform::detail::getConsumedHandleOpOperands(
+    TransformOpInterface transformOp) {
+  SmallVector<OpOperand *> consumedOperands; // In operand order.
+  consumedOperands.reserve(transformOp->getNumOperands()); // Upper bound.
+  auto memEffectInterface =
+      cast<MemoryEffectOpInterface>(transformOp.getOperation());
+  SmallVector<MemoryEffects::EffectInstance, 2> effects; // Reused buffer.
+  for (OpOperand &target : transformOp->getOpOperands()) {
+    effects.clear();
+    memEffectInterface.getEffectsOnValue(target.get(), effects);
+    if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
+          return isa<transform::TransformMappingResource>(
+                     effect.getResource()) &&
+                 isa<MemoryEffects::Free>(effect.getEffect());
+        })) { // Op frees this operand's mapping, i.e. consumes it.
+      consumedOperands.push_back(&target);
+    }
+  }
+  return consumedOperands;
+}
+
 LogicalResult transform::detail::verifyTransformOpInterface(Operation *op) {
   auto iface = cast<MemoryEffectOpInterface>(op);
   SmallVector<MemoryEffects::EffectInstance> effects;


        


More information about the Mlir-commits mailing list