[Mlir-commits] [mlir] a21986b - [MLIR][PDL] Skip over all results in the PDL Bytecode if a Constraint/Rewrite failed (#139255)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue May 20 23:37:31 PDT 2025


Author: Jonas Rickert
Date: 2025-05-20T23:37:27-07:00
New Revision: a21986b152927b368eb9c7516ebeaa0b5fbd3167

URL: https://github.com/llvm/llvm-project/commit/a21986b152927b368eb9c7516ebeaa0b5fbd3167
DIFF: https://github.com/llvm/llvm-project/commit/a21986b152927b368eb9c7516ebeaa0b5fbd3167.diff

LOG: [MLIR][PDL] Skip over all results in the PDL Bytecode if a Constraint/Rewrite failed (#139255)

Skipping over only the first result leaves curCodeIt pointing to the
wrong location in the bytecode, so execution continues with the wrong
instruction after the Constraint/Rewrite.

Signed-off-by: Rickert, Jonas <Jonas.Rickert at amd.com>

Added: 
    

Modified: 
    mlir/lib/Rewrite/ByteCode.cpp
    mlir/test/Rewrite/pdl-bytecode.mlir
    mlir/test/lib/Rewrite/TestPDLByteCode.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp
index 17cb3a74184f1..83940edffa4c2 100644
--- a/mlir/lib/Rewrite/ByteCode.cpp
+++ b/mlir/lib/Rewrite/ByteCode.cpp
@@ -1496,22 +1496,24 @@ LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
 void ByteCodeExecutor::processNativeFunResults(
     ByteCodeRewriteResultList &results, unsigned numResults,
     LogicalResult &rewriteResult) {
-  // Store the results in the bytecode memory or handle missing results on
-  // failure.
-  for (unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) {
-    PDLValue::Kind resultKind = read<PDLValue::Kind>();
-
+  if (failed(rewriteResult)) {
     // Skip the according number of values on the buffer on failure and exit
     // early as there are no results to process.
-    if (failed(rewriteResult)) {
+    for (unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) {
+      const PDLValue::Kind resultKind = read<PDLValue::Kind>();
       if (resultKind == PDLValue::Kind::TypeRange ||
           resultKind == PDLValue::Kind::ValueRange) {
         skip(2);
       } else {
         skip(1);
       }
-      return;
     }
+    return;
+  }
+
+  // Store the results in the bytecode memory
+  for (unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) {
+    PDLValue::Kind resultKind = read<PDLValue::Kind>();
     PDLValue result = results.getResults()[resultIdx];
     LLVM_DEBUG(llvm::dbgs() << "  * Result: " << result << "\n");
     assert(result.getKind() == resultKind &&

diff  --git a/mlir/test/Rewrite/pdl-bytecode.mlir b/mlir/test/Rewrite/pdl-bytecode.mlir
index f8e4f2e83b296..8221f009a659f 100644
--- a/mlir/test/Rewrite/pdl-bytecode.mlir
+++ b/mlir/test/Rewrite/pdl-bytecode.mlir
@@ -143,6 +143,36 @@ module @ir attributes { test.apply_constraint_4 } {
 
 // -----
 
+// Test that a failing native constraint skips over all of its results.
+module @patterns {
+  pdl_interp.func @matcher(%root : !pdl.operation) {
+    %new_type:2 = pdl_interp.apply_constraint "op_multiple_returns_failure"(%root : !pdl.operation) : !pdl.type, !pdl.type -> ^pat2, ^end
+
+  ^pat2:
+    pdl_interp.record_match @rewriters::@success(%root, %new_type#0 : !pdl.operation, !pdl.type) : benefit(1), loc([%root]) -> ^end
+
+  ^end:
+    pdl_interp.finalize
+  }
+
+  module @rewriters {
+    pdl_interp.func @success(%root : !pdl.operation, %new_type : !pdl.type) {
+      %op = pdl_interp.create_operation "test.replaced_by_pattern" -> (%new_type : !pdl.type)
+      pdl_interp.erase %root
+      pdl_interp.finalize
+    }
+  }
+}
+
+// CHECK-LABEL: test.apply_constraint_multi_result_failure
+// CHECK-NOT: "test.replaced_by_pattern"
+// CHECK: "test.success_op"
+module @ir attributes { test.apply_constraint_multi_result_failure } {
+  "test.success_op"() : () -> ()
+}
+
+// -----
+
 // Test success and failure cases of native constraints with pdl.range results.
 module @patterns {
   pdl_interp.func @matcher(%root : !pdl.operation) {

diff  --git a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
index 7b96bf5e28d32..e5783c96f44e4 100644
--- a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
+++ b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
@@ -55,6 +55,13 @@ static LogicalResult customTypeResultConstraint(PatternRewriter &rewriter,
   return failure();
 }
 
+// Custom constraint that always fails and never populates its result list.
+static LogicalResult customConstraintFailure(PatternRewriter & /*rewriter*/,
+                                             PDLResultList & /*results*/,
+                                             ArrayRef<PDLValue> /*args*/) {
+  return failure();
+}
+
 // Custom constraint that returns a type range of variable length if the op is
 // named test.success_op
 static LogicalResult customTypeRangeResultConstraint(PatternRewriter &rewriter,
@@ -150,6 +157,8 @@ struct TestPDLByteCodePass
                                           customValueResultConstraint);
     pdlPattern.registerConstraintFunction("op_constr_return_type",
                                           customTypeResultConstraint);
+    pdlPattern.registerConstraintFunction("op_multiple_returns_failure",
+                                          customConstraintFailure);
     pdlPattern.registerConstraintFunction("op_constr_return_type_range",
                                           customTypeRangeResultConstraint);
     pdlPattern.registerRewriteFunction("creator", customCreate);


        


More information about the Mlir-commits mailing list