[Mlir-commits] [mlir] d9db5a5 - [mlir] relax value handle updates when operation is replaced

Alex Zinenko llvmlistbot at llvm.org
Tue Mar 14 08:57:39 PDT 2023


Author: Alex Zinenko
Date: 2023-03-14T15:57:31Z
New Revision: d9db5a5904fd328130a5dd100e3cf36eb7d9e6d0

URL: https://github.com/llvm/llvm-project/commit/d9db5a5904fd328130a5dd100e3cf36eb7d9e6d0
DIFF: https://github.com/llvm/llvm-project/commit/d9db5a5904fd328130a5dd100e3cf36eb7d9e6d0.diff

LOG: [mlir] relax value handle updates when operation is replaced

The initial implementation of value handle updates, performed when the
payload operation defining the values associated with value handles is
replaced, required the replacement operation to have the same number of
results. This is not strictly necessary. The replacement operation may
have more results, or fewer results provided that there are no handles to
the results that have no equivalent in the replacement op.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D145254

Added: 
    

Modified: 
    mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
    mlir/test/Dialect/Transform/transform-state-extension.mlir
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index ef5bfd2a85aeb..83f2a7c7c5932 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -339,14 +339,7 @@ transform::TransformState::replacePayloadOp(Operation *op,
   }
 
   // Otherwise, replace the pointed-to object of all handles while preserving
-  // their relative order.
-  if (op->getNumResults() != replacement->getNumResults()) {
-    return emitError(op->getLoc())
-           << "cannot replace an op with another op producing a different "
-              "number of results while tracking handles";
-  }
-
-  // Replace the mapped operation if present.
+  // their relative order. First, replace the mapped operation if present.
   for (Value handle : opHandles) {
     Mappings &mappings = getMapping(handle);
     auto it = mappings.direct.find(handle);
@@ -362,9 +355,21 @@ transform::TransformState::replacePayloadOp(Operation *op,
     mappings.reverse[replacement].push_back(handle);
   }
 
-  // Replace the mapped results of the operation.
-  for (auto [origResult, replacementResult, handleList] : llvm::zip(
-           op->getResults(), replacement->getResults(), resultValueHandles)) {
+  // Second, replace the mapped results of the operation.
+  for (auto [origResult, handleList] :
+       llvm::zip(op->getResults(), resultValueHandles)) {
+    // No handles to the value, skip even if there is no replacement.
+    if (handleList.empty())
+      continue;
+
+    unsigned resultNumber = origResult.getResultNumber();
+    if (resultNumber >= replacement->getNumResults()) {
+      return emitError(op->getLoc())
+             << "cannot replace an op with another op producing less results "
+                "while tracking handles";
+    }
+
+    Value replacementResult = replacement->getResult(resultNumber);
     for (Value resultHandle : handleList) {
       Mappings &mappings = getMapping(resultHandle);
       auto it = mappings.values.find(resultHandle);

diff  --git a/mlir/test/Dialect/Transform/transform-state-extension.mlir b/mlir/test/Dialect/Transform/transform-state-extension.mlir
index 054ee077496c5..7885e08291716 100644
--- a/mlir/test/Dialect/Transform/transform-state-extension.mlir
+++ b/mlir/test/Dialect/Transform/transform-state-extension.mlir
@@ -47,15 +47,36 @@ module {
 
 // -----
 
-// expected-error @below {{cannot replace an op with another op producing a different number of results while tracking handles}}
-module {
-  transform.sequence failures(propagate) {
-  ^bb0(%arg0: !pdl.operation):
-    test_add_test_extension "A"
-    %dummy = test_remap_operand_to_self %arg0 : !transform.any_op
-  }
+transform.sequence failures(propagate) {
+^bb0(%arg0: !pdl.operation):
+  test_add_test_extension "A"
+   // This is okay because we are replacing the top-level module operation
+   // (0 results) with this operation that has _more_ (1) results.
+  %dummy = test_remap_operand_to_self %arg0 : !pdl.operation
+}
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !pdl.operation):
+  test_add_test_extension "A"
+  %dummy = test_remap_operand_to_self %arg0 : !pdl.operation
+  // This is still okay. Even though we are replacing the previous
+  // operation (1 result) with this operation that has fewer (0) results,
+  // there is no handle to the result, hence no issue with value handle update.
+  test_remap_operand_to_self %dummy
 }
 
+// -----
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !pdl.operation):
+  test_add_test_extension "A"
+  // expected-error @below {{cannot replace an op with another op producing less results while tracking handles}}
+  %dummy = test_remap_operand_to_self %arg0 : !pdl.operation
+  %valuehandle = transform.get_result %dummy[0] : (!pdl.operation) -> !transform.any_value
+  test_remap_operand_to_self %dummy
+}
 
 // -----
 

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index e32cb3ec891ab..79b5256ef1e23 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -297,6 +297,8 @@ DiagnosedSilenceableFailure mlir::test::TestRemapOperandPayloadToSelfOp::apply(
   if (failed(extension->updateMapping(state.getPayloadOps(getOperand()).front(),
                                       getOperation())))
     return DiagnosedSilenceableFailure::definiteFailure();
+  if (getNumResults() > 0)
+    results.set(getResult(0).cast<OpResult>(), getOperation());
   return DiagnosedSilenceableFailure::success();
 }
 


        


More information about the Mlir-commits mailing list