[Mlir-commits] [mlir] [MLIR][ByteCode] Skip over all results in the Bytecode if a Constraint/Rewrite failed, instead of just skipping over the first result. (PR #139255)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri May 9 05:38:15 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Jonas Rickert (jorickert)

<details>
<summary>Changes</summary>

When a native Constraint/Rewrite fails, skipping over only its first result leaves `curCodeIt` pointing to the wrong location in the bytecode. Execution then continues with the wrong instruction after the Constraint/Rewrite. This change skips over all of the declared results instead.

---
Full diff: https://github.com/llvm/llvm-project/pull/139255.diff


3 Files Affected:

- (modified) mlir/lib/Rewrite/ByteCode.cpp (+9-7) 
- (modified) mlir/test/Rewrite/pdl-bytecode.mlir (+30) 
- (modified) mlir/test/lib/Rewrite/TestPDLByteCode.cpp (+9) 


``````````diff
diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp
index 17cb3a74184f1..83940edffa4c2 100644
--- a/mlir/lib/Rewrite/ByteCode.cpp
+++ b/mlir/lib/Rewrite/ByteCode.cpp
@@ -1496,22 +1496,24 @@ LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
 void ByteCodeExecutor::processNativeFunResults(
     ByteCodeRewriteResultList &results, unsigned numResults,
     LogicalResult &rewriteResult) {
-  // Store the results in the bytecode memory or handle missing results on
-  // failure.
-  for (unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) {
-    PDLValue::Kind resultKind = read<PDLValue::Kind>();
-
+  if (failed(rewriteResult)) {
     // Skip the according number of values on the buffer on failure and exit
     // early as there are no results to process.
-    if (failed(rewriteResult)) {
+    for (unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) {
+      const PDLValue::Kind resultKind = read<PDLValue::Kind>();
       if (resultKind == PDLValue::Kind::TypeRange ||
           resultKind == PDLValue::Kind::ValueRange) {
         skip(2);
       } else {
         skip(1);
       }
-      return;
     }
+    return;
+  }
+
+  // Store the results in the bytecode memory
+  for (unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) {
+    PDLValue::Kind resultKind = read<PDLValue::Kind>();
     PDLValue result = results.getResults()[resultIdx];
     LLVM_DEBUG(llvm::dbgs() << "  * Result: " << result << "\n");
     assert(result.getKind() == resultKind &&
diff --git a/mlir/test/Rewrite/pdl-bytecode.mlir b/mlir/test/Rewrite/pdl-bytecode.mlir
index f8e4f2e83b296..8221f009a659f 100644
--- a/mlir/test/Rewrite/pdl-bytecode.mlir
+++ b/mlir/test/Rewrite/pdl-bytecode.mlir
@@ -143,6 +143,36 @@ module @ir attributes { test.apply_constraint_4 } {
 
 // -----
 
+// Test that a failing native constraint with multiple results skips all of its results.
+module @patterns {
+  pdl_interp.func @matcher(%root : !pdl.operation) {
+    %new_type:2 = pdl_interp.apply_constraint "op_multiple_returns_failure"(%root : !pdl.operation) : !pdl.type, !pdl.type -> ^pat2, ^end
+
+  ^pat2:
+    pdl_interp.record_match @rewriters::@success(%root, %new_type#0 : !pdl.operation, !pdl.type) : benefit(1), loc([%root]) -> ^end
+
+  ^end:
+    pdl_interp.finalize
+  }
+
+  module @rewriters {
+    pdl_interp.func @success(%root : !pdl.operation, %new_type : !pdl.type) {
+      %op = pdl_interp.create_operation "test.replaced_by_pattern" -> (%new_type : !pdl.type)
+      pdl_interp.erase %root
+      pdl_interp.finalize
+    }
+  }
+}
+
+// CHECK-LABEL: test.apply_constraint_multi_result_failure
+// CHECK-NOT: "test.replaced_by_pattern"
+// CHECK: "test.success_op"
+module @ir attributes { test.apply_constraint_multi_result_failure } {
+  "test.success_op"() : () -> ()
+}
+
+// -----
+
 // Test success and failure cases of native constraints with pdl.range results.
 module @patterns {
   pdl_interp.func @matcher(%root : !pdl.operation) {
diff --git a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
index 7b96bf5e28d32..e5783c96f44e4 100644
--- a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
+++ b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
@@ -55,6 +55,13 @@ static LogicalResult customTypeResultConstraint(PatternRewriter &rewriter,
   return failure();
 }
 
+// Custom constraint that always returns failure
+static LogicalResult customConstraintFailure(PatternRewriter & /*rewriter*/,
+                                             PDLResultList & /*results*/,
+                                             ArrayRef<PDLValue> /*args*/) {
+  return failure();
+}
+
 // Custom constraint that returns a type range of variable length if the op is
 // named test.success_op
 static LogicalResult customTypeRangeResultConstraint(PatternRewriter &rewriter,
@@ -150,6 +157,8 @@ struct TestPDLByteCodePass
                                           customValueResultConstraint);
     pdlPattern.registerConstraintFunction("op_constr_return_type",
                                           customTypeResultConstraint);
+    pdlPattern.registerConstraintFunction("op_multiple_returns_failure",
+                                          customConstraintFailure);
     pdlPattern.registerConstraintFunction("op_constr_return_type_range",
                                           customTypeRangeResultConstraint);
     pdlPattern.registerRewriteFunction("creator", customCreate);

``````````

</details>


https://github.com/llvm/llvm-project/pull/139255


More information about the Mlir-commits mailing list