[Mlir-commits] [mlir] [MLIR][ByteCode] Skip over all results in the Bytecode if a Constraint/Rewrite failed, instead of just skipping over the first result. (PR #139255)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri May 9 05:38:15 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Jonas Rickert (jorickert)
<details>
<summary>Changes</summary>
Skipping over only the first result leaves `curCodeIt` pointing to the wrong location in the bytecode, so execution continues with the wrong instruction after the Constraint/Rewrite.
---
Full diff: https://github.com/llvm/llvm-project/pull/139255.diff
3 Files Affected:
- (modified) mlir/lib/Rewrite/ByteCode.cpp (+9-7)
- (modified) mlir/test/Rewrite/pdl-bytecode.mlir (+30)
- (modified) mlir/test/lib/Rewrite/TestPDLByteCode.cpp (+9)
``````````diff
diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp
index 17cb3a74184f1..83940edffa4c2 100644
--- a/mlir/lib/Rewrite/ByteCode.cpp
+++ b/mlir/lib/Rewrite/ByteCode.cpp
@@ -1496,22 +1496,24 @@ LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
void ByteCodeExecutor::processNativeFunResults(
ByteCodeRewriteResultList &results, unsigned numResults,
LogicalResult &rewriteResult) {
- // Store the results in the bytecode memory or handle missing results on
- // failure.
- for (unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) {
- PDLValue::Kind resultKind = read<PDLValue::Kind>();
-
+ if (failed(rewriteResult)) {
// Skip the according number of values on the buffer on failure and exit
// early as there are no results to process.
- if (failed(rewriteResult)) {
+ for (unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) {
+ const PDLValue::Kind resultKind = read<PDLValue::Kind>();
if (resultKind == PDLValue::Kind::TypeRange ||
resultKind == PDLValue::Kind::ValueRange) {
skip(2);
} else {
skip(1);
}
- return;
}
+ return;
+ }
+
+ // Store the results in the bytecode memory.
+ for (unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) {
+ PDLValue::Kind resultKind = read<PDLValue::Kind>();
PDLValue result = results.getResults()[resultIdx];
LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n");
assert(result.getKind() == resultKind &&
diff --git a/mlir/test/Rewrite/pdl-bytecode.mlir b/mlir/test/Rewrite/pdl-bytecode.mlir
index f8e4f2e83b296..8221f009a659f 100644
--- a/mlir/test/Rewrite/pdl-bytecode.mlir
+++ b/mlir/test/Rewrite/pdl-bytecode.mlir
@@ -143,6 +143,36 @@ module @ir attributes { test.apply_constraint_4 } {
// -----
+// Test that a failing native constraint with multiple results is skipped correctly.
+module @patterns {
+ pdl_interp.func @matcher(%root : !pdl.operation) {
+ %new_type:2 = pdl_interp.apply_constraint "op_multiple_returns_failure"(%root : !pdl.operation) : !pdl.type, !pdl.type -> ^pat2, ^end
+
+ ^pat2:
+ pdl_interp.record_match @rewriters::@success(%root, %new_type#0 : !pdl.operation, !pdl.type) : benefit(1), loc([%root]) -> ^end
+
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ pdl_interp.func @success(%root : !pdl.operation, %new_type : !pdl.type) {
+ %op = pdl_interp.create_operation "test.replaced_by_pattern" -> (%new_type : !pdl.type)
+ pdl_interp.erase %root
+ pdl_interp.finalize
+ }
+ }
+}
+
+// CHECK-LABEL: test.apply_constraint_multi_result_failure
+// CHECK-NOT: "test.replaced_by_pattern"
+// CHECK: "test.success_op"
+module @ir attributes { test.apply_constraint_multi_result_failure } {
+ "test.success_op"() : () -> ()
+}
+
+// -----
+
// Test success and failure cases of native constraints with pdl.range results.
module @patterns {
pdl_interp.func @matcher(%root : !pdl.operation) {
diff --git a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
index 7b96bf5e28d32..e5783c96f44e4 100644
--- a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
+++ b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
@@ -55,6 +55,13 @@ static LogicalResult customTypeResultConstraint(PatternRewriter &rewriter,
return failure();
}
+// Custom constraint that always fails without producing any results.
+static LogicalResult customConstraintFailure(PatternRewriter & /*rewriter*/,
+ PDLResultList & /*results*/,
+ ArrayRef<PDLValue> /*args*/) {
+ return failure();
+}
+
// Custom constraint that returns a type range of variable length if the op is
// named test.success_op
static LogicalResult customTypeRangeResultConstraint(PatternRewriter &rewriter,
@@ -150,6 +157,8 @@ struct TestPDLByteCodePass
customValueResultConstraint);
pdlPattern.registerConstraintFunction("op_constr_return_type",
customTypeResultConstraint);
+ pdlPattern.registerConstraintFunction("op_multiple_returns_failure",
+ customConstraintFailure);
pdlPattern.registerConstraintFunction("op_constr_return_type_range",
customTypeRangeResultConstraint);
pdlPattern.registerRewriteFunction("creator", customCreate);
``````````
</details>
https://github.com/llvm/llvm-project/pull/139255
More information about the Mlir-commits
mailing list