[Mlir-commits] [mlir] [mlir][PDL] Support running `pdl_interp.foreach` on ranges of values and types (PR #173161)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jan 30 02:15:47 PST 2026
https://github.com/jumerckx updated https://github.com/llvm/llvm-project/pull/173161
>From d7f3b4bd424e0eadceceb89337568af5252e9865 Mon Sep 17 00:00:00 2001
From: jumerckx <31353884+jumerckx at users.noreply.github.com>
Date: Sat, 20 Dec 2025 09:10:22 -0600
Subject: [PATCH 1/2] Support pdl_interp.foreach on ranges of values and types
---
mlir/lib/Rewrite/ByteCode.cpp | 30 ++++++++++++++++++++++++++++++
1 file changed, 30 insertions(+)
diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp
index 159aa54686034..b089966f714cc 100644
--- a/mlir/lib/Rewrite/ByteCode.cpp
+++ b/mlir/lib/Rewrite/ByteCode.cpp
@@ -1741,6 +1741,36 @@ void ByteCodeExecutor::executeForEach() {
selectJump(size_t(0));
return;
}
+ case PDLValue::Kind::Value: {
+ unsigned &index = loopIndex[read()];
+ ValueRange range = valueRangeMemory[rangeIndex];
+ assert(index <= range.size() && "iterated past the end");
+ if (index < range.size()) {
+ LLVM_DEBUG(llvm::dbgs() << " * Result: " << range[index] << "\n");
+ value = range[index].getAsOpaquePointer();
+ break;
+ }
+
+ LLVM_DEBUG(llvm::dbgs() << " * Done\n");
+ index = 0;
+ selectJump(size_t(0));
+ return;
+ }
+ case PDLValue::Kind::Type: {
+ unsigned &index = loopIndex[read()];
+ TypeRange range = typeRangeMemory[rangeIndex];
+ assert(index <= range.size() && "iterated past the end");
+ if (index < range.size()) {
+ LLVM_DEBUG(llvm::dbgs() << " * Result: " << range[index] << "\n");
+ value = range[index].getAsOpaquePointer();
+ break;
+ }
+
+ LLVM_DEBUG(llvm::dbgs() << " * Done\n");
+ index = 0;
+ selectJump(size_t(0));
+ return;
+ }
default:
llvm_unreachable("unexpected `ForEach` value kind");
}
>From bdf870671d6b920d659fe78daeed47e7320b882d Mon Sep 17 00:00:00 2001
From: jumerckx <31353884+jumerckx at users.noreply.github.com>
Date: Sat, 20 Dec 2025 11:46:06 -0600
Subject: [PATCH 2/2] Add lit tests for pdl_interp.foreach over values and types
---
mlir/test/Rewrite/pdl-bytecode.mlir | 86 +++++++++++++++++++++++
mlir/test/lib/Rewrite/TestPDLByteCode.cpp | 15 ++++
2 files changed, 101 insertions(+)
diff --git a/mlir/test/Rewrite/pdl-bytecode.mlir b/mlir/test/Rewrite/pdl-bytecode.mlir
index 8221f009a659f..844f832cd22c6 100644
--- a/mlir/test/Rewrite/pdl-bytecode.mlir
+++ b/mlir/test/Rewrite/pdl-bytecode.mlir
@@ -956,6 +956,92 @@ module @ir attributes { test.foreach } {
// -----
+// Test pdl_interp.foreach over a range of types.
+module @patterns {
+ pdl_interp.func @matcher(%root : !pdl.operation) {
+ pdl_interp.check_operation_name of %root is "test.success_op" -> ^pat, ^end
+
+ ^pat:
+ %results = pdl_interp.get_results of %root : !pdl.range<value>
+ %types = pdl_interp.get_value_type of %results : !pdl.range<type>
+ // Iterate over the result types of the root op.
+ pdl_interp.foreach %type : !pdl.type in %types {
+ // Record a match only for i64, the second result type: the loop must visit
+ // past the first element, yet exactly one rewrite should be triggered.
+ pdl_interp.check_type %type is i64 -> ^record, ^cont
+ ^record:
+ pdl_interp.record_match @rewriters::@success(%root, %type : !pdl.operation, !pdl.type) : benefit(1), loc([%root]) -> ^cont
+ ^cont:
+ pdl_interp.continue
+ } -> ^end
+
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ pdl_interp.func @success(%root : !pdl.operation, %type : !pdl.type) {
+ // Create an op whose result type is the matched (i64) type.
+ pdl_interp.create_operation "test.type_found" -> (%type : !pdl.type)
+ pdl_interp.erase %root
+ pdl_interp.finalize
+ }
+ }
+}
+// CHECK-LABEL: test.foreach_type
+// CHECK: "test.type_found"() : () -> i64
+// CHECK-NOT: "test.type_found"
+module @ir attributes { test.foreach_type } {
+ "test.success_op"() : () -> (i32, i64)
+}
+// -----
+
+// Test pdl_interp.foreach over a value range returned by a native constraint.
+module @patterns {
+ pdl_interp.func @matcher(%root : !pdl.operation) {
+ pdl_interp.check_operation_name of %root is "test.success_op" -> ^pat, ^end
+
+ ^pat:
+ %values = pdl_interp.apply_constraint "op_constr_return_value_range"(%root : !pdl.operation) : !pdl.range<value> -> ^loop, ^end
+
+ ^loop:
+ pdl_interp.foreach %val : !pdl.value in %values {
+ %type = pdl_interp.get_value_type of %val : !pdl.type
+ // Record a match only for the f16 value, the second operand: the loop must
+ // visit past the first element, yet exactly one rewrite should be triggered.
+ pdl_interp.check_type %type is f16 -> ^record, ^cont
+ ^record:
+ pdl_interp.record_match @rewriters::@success(%root, %val : !pdl.operation, !pdl.value) : benefit(1), loc([%root]) -> ^cont
+ ^cont:
+ pdl_interp.continue
+ } -> ^end
+
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ pdl_interp.func @success(%root : !pdl.operation, %val : !pdl.value) {
+ %type = pdl_interp.get_value_type of %val : !pdl.type
+ pdl_interp.create_operation "test.value_found"(%val : !pdl.value) -> (%type : !pdl.type)
+ pdl_interp.erase %root
+ pdl_interp.finalize
+ }
+ }
+}
+
+// CHECK-LABEL: test.foreach_value
+// CHECK: %[[VAL1:.*]] = "test.input1"
+// CHECK: "test.value_found"(%[[VAL1]]) : (f16) -> f16
+// CHECK-NOT: "test.value_found"
+module @ir attributes { test.foreach_value } {
+ %0 = "test.input0"() : () -> f32
+ %1 = "test.input1"() : () -> f16
+ "test.success_op"(%0, %1) : (f32, f16) -> ()
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// pdl_interp::GetUsersOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
index e5783c96f44e4..715a589e26a37 100644
--- a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
+++ b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
@@ -81,6 +81,19 @@ static LogicalResult customTypeRangeResultConstraint(PatternRewriter &rewriter,
return failure();
}
+// Custom constraint returning the operands of the given op as a value range;
+// it only succeeds when the op is named `test.success_op`.
+static LogicalResult customValueRangeResultConstraint(PatternRewriter &rewriter,
+                                                      PDLResultList &results,
+                                                      ArrayRef<PDLValue> args) {
+  auto *op = args[0].cast<Operation *>();
+  if (op->getName().getStringRef() == "test.success_op") {
+    results.push_back(op->getOperands()); // Seen as a !pdl.range<value>.
+    return success();
+  }
+  return failure();
+}
+
// Custom creator invoked from PDL.
static Operation *customCreate(PatternRewriter &rewriter, Operation *op) {
return rewriter.create(OperationState(op->getLoc(), "test.success"));
@@ -161,6 +174,8 @@ struct TestPDLByteCodePass
customConstraintFailure);
pdlPattern.registerConstraintFunction("op_constr_return_type_range",
customTypeRangeResultConstraint);
+ pdlPattern.registerConstraintFunction("op_constr_return_value_range",
+ customValueRangeResultConstraint);
pdlPattern.registerRewriteFunction("creator", customCreate);
pdlPattern.registerRewriteFunction("var_creator",
customVariadicResultCreate);
More information about the Mlir-commits mailing list