[Mlir-commits] [mlir] 8c6da76 - [mlir][Transform] Fix applyToOne corner case when no op is matched.

Nicolas Vasilache llvmlistbot at llvm.org
Thu Jun 23 12:19:11 PDT 2022


Author: Nicolas Vasilache
Date: 2022-06-23T12:18:21-07:00
New Revision: 8c6da76483935d172c34e04e6c0106e33d803c61

URL: https://github.com/llvm/llvm-project/commit/8c6da76483935d172c34e04e6c0106e33d803c61
DIFF: https://github.com/llvm/llvm-project/commit/8c6da76483935d172c34e04e6c0106e33d803c61.diff

LOG: [mlir][Transform] Fix applyToOne corner case when no op is matched.

Such situations manifest themselves as an empty payload, which ends up producing empty results.
In such cases, we still want to honor the transform op contract and return one empty SmallVector<Operation *>
for each result of the op.

Differential Revision: https://reviews.llvm.org/D128456

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
    mlir/test/Dialect/Transform/test-interpreter.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index e6fbfc88e31ee..ef891dd2ddc51 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -824,6 +824,17 @@ mlir::transform::TransformEachOpTrait<OpTy>::apply(
       decltype(&OpTy::applyToOne)>::template arg_t<0>;
   ArrayRef<Operation *> targets =
       state.getPayloadOps(this->getOperation()->getOperand(0));
+  // Handle the corner case where no target is specified.
+  // This is typically the case when the matcher fails to apply and we need to
+  // propagate gracefully.
+  // In this case, we fill all results with an empty vector.
+  if (targets.empty()) {
+    SmallVector<Operation *> emptyResult;
+    for (auto r : this->getOperation()->getResults())
+      transformResults.set(r.template cast<OpResult>(), emptyResult);
+    return DiagnosedSilenceableFailure::success();
+  }
+
   SmallVector<SmallVector<Operation *>, 1> results;
   // In the multi-result case, collect the number of results each transform
   // produced.
@@ -831,14 +842,17 @@ mlir::transform::TransformEachOpTrait<OpTy>::apply(
       targets, results, [&](TransformOpType specificOp) {
         return static_cast<OpTy *>(this)->applyToOne(specificOp, state);
       });
+  // Propagate the failure (definite or silenceable), if any.
   if (!result.succeeded())
     return result;
-  if (results.empty())
+
+  // Legitimately no results, bail early.
+  if (results.empty() && OpTy::template hasTrait<OpTrait::ZeroResults>())
     return DiagnosedSilenceableFailure::success();
 
   // Ensure all applications return the same number of results.
   // Variadic cases are much trickier to handle in a generic fashion.
-  int64_t nRes = results[0].size();
+  int64_t nRes = results.empty() ? 0 : results[0].size();
   if (llvm::any_of(results, [&](const auto &r) {
         return static_cast<int64_t>(r.size()) != nRes;
       })) {
@@ -849,6 +863,8 @@ mlir::transform::TransformEachOpTrait<OpTy>::apply(
               "generic `apply` instead of the specialized `applyToOne`";
   }
   // Ensure the number of results agrees with what the transform op expects.
+  // Unless we see empty results, in which case we just want to propagate the
+  // emptiness.
   if (this->getOperation()->getNumResults() != nRes) {
     InFlightDiagnostic diag = static_cast<OpTy *>(this)->emitError()
                               << "unexpected number of results (got " << nRes
@@ -857,10 +873,6 @@ mlir::transform::TransformEachOpTrait<OpTy>::apply(
     return DiagnosedSilenceableFailure::definiteFailure();
   }
 
-  // If no results, bail early.
-  if (OpTy::template hasTrait<OpTrait::ZeroResults>())
-    return DiagnosedSilenceableFailure::success();
-
   // Perform transposition of M applications producing N results each into N
   // results for each of the M applications.
   SmallVector<SmallVector<Operation *, 1>> transposedResults =

diff  --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index a487d1dbef19c..34d1fc8a2b174 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -436,3 +436,27 @@ transform.with_pdl_patterns {
     %1:2 = transform.test_correct_number_of_multi_results %0
   }
 }
+
+// -----
+
+func.func @foo() {
+  "wrong_op_name" () : () -> ()
+  return
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @some : benefit(1) {
+    %0 = pdl.operands
+    %1 = pdl.types
+    %2 = pdl.operation "op"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+    pdl.rewrite %2 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = pdl_match @some in %arg1
+    // Transform fails to match any but still produces 2 results.
+    %1:2 = transform.test_correct_number_of_multi_results %0
+  }
+}


        


More information about the Mlir-commits mailing list