[Mlir-commits] [mlir] f3df4b9 - [mlir][PDL] Support running `pdl_interp.foreach` on ranges of values and types (#173161)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Feb 2 08:58:53 PST 2026
Author: jumerckx
Date: 2026-02-02T08:58:48-08:00
New Revision: f3df4b9292ce6fbd90de53f0124e55db0a9ee714
URL: https://github.com/llvm/llvm-project/commit/f3df4b9292ce6fbd90de53f0124e55db0a9ee714
DIFF: https://github.com/llvm/llvm-project/commit/f3df4b9292ce6fbd90de53f0124e55db0a9ee714.diff
LOG: [mlir][PDL] Support running `pdl_interp.foreach` on ranges of values and types (#173161)
Previously, executing `pdl_interp.foreach` only worked on ranges of
operations, typically produced by `pdl_interp.get_users`.
However, custom rewrites/constraints can also return ranges of types and
values.
This PR adds support for executing `pdl_interp.foreach` over those ranges.
Added:
Modified:
mlir/lib/Rewrite/ByteCode.cpp
mlir/test/Rewrite/pdl-bytecode.mlir
mlir/test/lib/Rewrite/TestPDLByteCode.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp
index ede7d8a4006fc..cf00216288115 100644
--- a/mlir/lib/Rewrite/ByteCode.cpp
+++ b/mlir/lib/Rewrite/ByteCode.cpp
@@ -1741,6 +1741,36 @@ void ByteCodeExecutor::executeForEach() {
selectJump(size_t(0));
return;
}
+ case PDLValue::Kind::Value: {
+ unsigned &index = loopIndex[read()];
+ ValueRange range = valueRangeMemory[rangeIndex];
+ assert(index <= range.size() && "iterated past the end");
+ if (index < range.size()) {
+ LLVM_DEBUG(llvm::dbgs() << " * Result: " << range[index] << "\n");
+ value = range[index].getAsOpaquePointer();
+ break;
+ }
+
+ LLVM_DEBUG(llvm::dbgs() << " * Done\n");
+ index = 0;
+ selectJump(size_t(0));
+ return;
+ }
+ case PDLValue::Kind::Type: {
+ unsigned &index = loopIndex[read()];
+ TypeRange range = typeRangeMemory[rangeIndex];
+ assert(index <= range.size() && "iterated past the end");
+ if (index < range.size()) {
+ LLVM_DEBUG(llvm::dbgs() << " * Result: " << range[index] << "\n");
+ value = range[index].getAsOpaquePointer();
+ break;
+ }
+
+ LLVM_DEBUG(llvm::dbgs() << " * Done\n");
+ index = 0;
+ selectJump(size_t(0));
+ return;
+ }
default:
llvm_unreachable("unexpected `ForEach` value kind");
}
diff --git a/mlir/test/Rewrite/pdl-bytecode.mlir b/mlir/test/Rewrite/pdl-bytecode.mlir
index 8221f009a659f..844f832cd22c6 100644
--- a/mlir/test/Rewrite/pdl-bytecode.mlir
+++ b/mlir/test/Rewrite/pdl-bytecode.mlir
@@ -956,6 +956,92 @@ module @ir attributes { test.foreach } {
// -----
+// Test pdl_interp.foreach over a range of types.
+// The range is built by mapping the root's results (get_results) through
+// get_value_type, so for the input below the loop visits i32, then i64.
+module @patterns {
+ pdl_interp.func @matcher(%root : !pdl.operation) {
+ pdl_interp.check_operation_name of %root is "test.success_op" -> ^pat, ^end
+
+ ^pat:
+ %results = pdl_interp.get_results of %root : !pdl.range<value>
+ %types = pdl_interp.get_value_type of %results : !pdl.range<type>
+ // Iterate over the types of the results of the root op
+ pdl_interp.foreach %type : !pdl.type in %types {
+ // Only match if the type is i64, verifying we introspect all types
+ // but only trigger one rewrite
+ pdl_interp.check_type %type is i64 -> ^record, ^cont
+ ^record:
+ pdl_interp.record_match @rewriters::@success(%root, %type : !pdl.operation, !pdl.type) : benefit(1), loc([%root]) -> ^cont
+ ^cont:
+ // Move on to the next element of %types; once the range is exhausted
+ // the loop exits through its successor block (^end).
+ pdl_interp.continue
+ } -> ^end
+
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ // Receives the root and the matched loop element (the i64 type).
+ pdl_interp.func @success(%root : !pdl.operation, %type : !pdl.type) {
+ // Create an op for the matched i64 type
+ pdl_interp.create_operation "test.type_found" -> (%type : !pdl.type)
+ pdl_interp.erase %root
+ pdl_interp.finalize
+ }
+ }
+}
+// Exactly one test.type_found, producing i64, is expected: the i32 result
+// is visited by the loop but rejected by check_type.
+// CHECK-LABEL: test.foreach_type
+// CHECK: "test.type_found"() : () -> i64
+// CHECK-NOT: "test.type_found"
+module @ir attributes { test.foreach_type } {
+ "test.success_op"() : () -> (i32, i64)
+}
+// -----
+
+// Test pdl_interp.foreach over a range of values returned by native constraint.
+// "op_constr_return_value_range" is registered by the test pass; it yields the
+// root's operands, so for the input below the loop visits an f32 value, then
+// an f16 value.
+module @patterns {
+ pdl_interp.func @matcher(%root : !pdl.operation) {
+ pdl_interp.check_operation_name of %root is "test.success_op" -> ^pat, ^end
+
+ ^pat:
+ // On constraint failure, control goes straight to ^end (no match).
+ %values = pdl_interp.apply_constraint "op_constr_return_value_range"(%root : !pdl.operation) : !pdl.range<value> -> ^loop, ^end
+
+ ^loop:
+ pdl_interp.foreach %val : !pdl.value in %values {
+ %type = pdl_interp.get_value_type of %val : !pdl.type
+ // Only match if the type is f16, verifying we introspect all values
+ // but only trigger one rewrite
+ pdl_interp.check_type %type is f16 -> ^record, ^cont
+ ^record:
+ pdl_interp.record_match @rewriters::@success(%root, %val : !pdl.operation, !pdl.value) : benefit(1), loc([%root]) -> ^cont
+ ^cont:
+ // Move on to the next element of %values; once the range is exhausted
+ // the loop exits through its successor block (^end).
+ pdl_interp.continue
+ } -> ^end
+
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ // Receives the root and the matched loop element (the f16 value).
+ pdl_interp.func @success(%root : !pdl.operation, %val : !pdl.value) {
+ // Create an op that consumes the matched value and whose result has the
+ // same type, then erase the root.
+ %type = pdl_interp.get_value_type of %val : !pdl.type
+ pdl_interp.create_operation "test.value_found"(%val : !pdl.value) -> (%type : !pdl.type)
+ pdl_interp.erase %root
+ pdl_interp.finalize
+ }
+ }
+}
+
+// Exactly one test.value_found is expected, consuming the f16 value
+// defined by test.input1; the f32 operand is visited but rejected.
+// CHECK-LABEL: test.foreach_value
+// CHECK: %[[VAL1:.*]] = "test.input1"
+// CHECK: "test.value_found"(%[[VAL1]]) : (f16) -> f16
+// CHECK-NOT: "test.value_found"
+module @ir attributes { test.foreach_value } {
+ %0 = "test.input0"() : () -> f32
+ %1 = "test.input1"() : () -> f16
+ "test.success_op"(%0, %1) : (f32, f16) -> ()
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// pdl_interp::GetUsersOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
index e5783c96f44e4..1e3d0186eb1c9 100644
--- a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
+++ b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
@@ -81,6 +81,19 @@ static LogicalResult customTypeRangeResultConstraint(PatternRewriter &rewriter,
return failure();
}
+// Native constraint backing "op_constr_return_value_range". It fails unless
+// the matched operation is named `test.success_op`; on success it returns
+// that operation's operands to PDL as a single range-of-values result.
+static LogicalResult customValueRangeResultConstraint(PatternRewriter &rewriter,
+                                                      PDLResultList &results,
+                                                      ArrayRef<PDLValue> args) {
+  // The first argument is assumed to be the operation being matched.
+  Operation *matchedOp = args[0].cast<Operation *>();
+
+  // Guard: only the test's target op satisfies the constraint.
+  if (matchedOp->getName().getStringRef() != "test.success_op")
+    return failure();
+
+  // Yield the operand list as the constraint's (value range) result.
+  results.push_back(matchedOp->getOperands());
+  return success();
+}
+
// Custom creator invoked from PDL.
static Operation *customCreate(PatternRewriter &rewriter, Operation *op) {
return rewriter.create(OperationState(op->getLoc(), "test.success"));
@@ -161,6 +174,8 @@ struct TestPDLByteCodePass
customConstraintFailure);
pdlPattern.registerConstraintFunction("op_constr_return_type_range",
customTypeRangeResultConstraint);
+ pdlPattern.registerConstraintFunction("op_constr_return_value_range",
+ customValueRangeResultConstraint);
pdlPattern.registerRewriteFunction("creator", customCreate);
pdlPattern.registerRewriteFunction("var_creator",
customVariadicResultCreate);
More information about the Mlir-commits
mailing list