[Mlir-commits] [mlir] f3df4b9 - [mlir][PDL] Support running `pdl_interp.foreach` on ranges of values and types (#173161)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Feb 2 08:58:53 PST 2026


Author: jumerckx
Date: 2026-02-02T08:58:48-08:00
New Revision: f3df4b9292ce6fbd90de53f0124e55db0a9ee714

URL: https://github.com/llvm/llvm-project/commit/f3df4b9292ce6fbd90de53f0124e55db0a9ee714
DIFF: https://github.com/llvm/llvm-project/commit/f3df4b9292ce6fbd90de53f0124e55db0a9ee714.diff

LOG: [mlir][PDL] Support running `pdl_interp.foreach` on ranges of values and types (#173161)

Previously, `pdl_interp.foreach` execution only worked for operation
ranges, typically stemming from `pdl_interp.get_users`.
However, custom rewrites/constraints can also return ranges of types and
values.
This PR adds support for executing `pdl_interp.foreach` in those cases.

Added: 
    

Modified: 
    mlir/lib/Rewrite/ByteCode.cpp
    mlir/test/Rewrite/pdl-bytecode.mlir
    mlir/test/lib/Rewrite/TestPDLByteCode.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp
index ede7d8a4006fc..cf00216288115 100644
--- a/mlir/lib/Rewrite/ByteCode.cpp
+++ b/mlir/lib/Rewrite/ByteCode.cpp
@@ -1741,6 +1741,36 @@ void ByteCodeExecutor::executeForEach() {
     selectJump(size_t(0));
     return;
   }
+  case PDLValue::Kind::Value: {
+    unsigned &index = loopIndex[read()];
+    ValueRange range = valueRangeMemory[rangeIndex];
+    assert(index <= range.size() && "iterated past the end");
+    if (index < range.size()) {
+      LLVM_DEBUG(llvm::dbgs() << "  * Result: " << range[index] << "\n");
+      value = range[index].getAsOpaquePointer();
+      break;
+    }
+
+    LLVM_DEBUG(llvm::dbgs() << "  * Done\n");
+    index = 0;
+    selectJump(size_t(0));
+    return;
+  }
+  case PDLValue::Kind::Type: {
+    unsigned &index = loopIndex[read()];
+    TypeRange range = typeRangeMemory[rangeIndex];
+    assert(index <= range.size() && "iterated past the end");
+    if (index < range.size()) {
+      LLVM_DEBUG(llvm::dbgs() << "  * Result: " << range[index] << "\n");
+      value = range[index].getAsOpaquePointer();
+      break;
+    }
+
+    LLVM_DEBUG(llvm::dbgs() << "  * Done\n");
+    index = 0;
+    selectJump(size_t(0));
+    return;
+  }
   default:
     llvm_unreachable("unexpected `ForEach` value kind");
   }

diff  --git a/mlir/test/Rewrite/pdl-bytecode.mlir b/mlir/test/Rewrite/pdl-bytecode.mlir
index 8221f009a659f..844f832cd22c6 100644
--- a/mlir/test/Rewrite/pdl-bytecode.mlir
+++ b/mlir/test/Rewrite/pdl-bytecode.mlir
@@ -956,6 +956,92 @@ module @ir attributes { test.foreach } {
 
 // -----
 
+// Test pdl_interp.foreach over a range of types (the root op's result types).
+module @patterns {
+  pdl_interp.func @matcher(%root : !pdl.operation) {
+    pdl_interp.check_operation_name of %root is "test.success_op" -> ^pat, ^end
+  
+  ^pat:
+    %results = pdl_interp.get_results of %root : !pdl.range<value>
+    %types = pdl_interp.get_value_type of %results : !pdl.range<type>
+    // Iterate over the result types of the root op.
+    pdl_interp.foreach %type : !pdl.type in %types {
+      // Only record a match for i64; this verifies that every type is
+      // visited while only one rewrite is triggered.
+      pdl_interp.check_type %type is i64 -> ^record, ^cont
+    ^record:
+      pdl_interp.record_match @rewriters::@success(%root, %type : !pdl.operation, !pdl.type) : benefit(1), loc([%root]) -> ^cont
+    ^cont:
+      pdl_interp.continue
+    } -> ^end
+
+  ^end:
+    pdl_interp.finalize
+  }
+
+  module @rewriters {
+    pdl_interp.func @success(%root : !pdl.operation, %type : !pdl.type) {
+      // Create an op whose single result has the matched (i64) type.
+      pdl_interp.create_operation "test.type_found" -> (%type : !pdl.type)
+      pdl_interp.erase %root
+      pdl_interp.finalize
+    }
+  }
+}
+// CHECK-LABEL: test.foreach_type
+// CHECK: "test.type_found"() : () -> i64
+// CHECK-NOT: "test.type_found"
+module @ir attributes { test.foreach_type } {
+  "test.success_op"() : () -> (i32, i64)
+}
+// -----
+
+// Test pdl_interp.foreach over a range of values returned by a native constraint.
+module @patterns {
+  pdl_interp.func @matcher(%root : !pdl.operation) {
+    pdl_interp.check_operation_name of %root is "test.success_op" -> ^pat, ^end
+  
+  ^pat:
+    %values = pdl_interp.apply_constraint "op_constr_return_value_range"(%root : !pdl.operation) : !pdl.range<value> -> ^loop, ^end
+
+  ^loop:
+    pdl_interp.foreach %val : !pdl.value in %values {
+      %type = pdl_interp.get_value_type of %val : !pdl.type
+      // Only record a match for f16 values; this verifies that every value is
+      // visited while only one rewrite is triggered.
+      pdl_interp.check_type %type is f16 -> ^record, ^cont
+    ^record:
+      pdl_interp.record_match @rewriters::@success(%root, %val : !pdl.operation, !pdl.value) : benefit(1), loc([%root]) -> ^cont
+    ^cont:
+      pdl_interp.continue
+    } -> ^end
+
+  ^end:
+    pdl_interp.finalize
+  }
+
+  module @rewriters {
+    pdl_interp.func @success(%root : !pdl.operation, %val : !pdl.value) {
+      %type = pdl_interp.get_value_type of %val : !pdl.type
+      pdl_interp.create_operation "test.value_found"(%val : !pdl.value) -> (%type : !pdl.type)
+      pdl_interp.erase %root
+      pdl_interp.finalize
+    }
+  }
+}
+
+// CHECK-LABEL: test.foreach_value
+// CHECK: %[[VAL1:.*]] = "test.input1"
+// CHECK: "test.value_found"(%[[VAL1]]) : (f16) -> f16
+// CHECK-NOT: "test.value_found"
+module @ir attributes { test.foreach_value } {
+  %0 = "test.input0"() : () -> f32
+  %1 = "test.input1"() : () -> f16
+  "test.success_op"(%0, %1) : (f32, f16) -> ()
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // pdl_interp::GetUsersOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
index e5783c96f44e4..1e3d0186eb1c9 100644
--- a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
+++ b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
@@ -81,6 +81,19 @@ static LogicalResult customTypeRangeResultConstraint(PatternRewriter &rewriter,
   return failure();
 }
 
+// Custom constraint that succeeds if the op is named `test.success_op`,
+// returning the op's operands as a value range; fails otherwise.
+static LogicalResult customValueRangeResultConstraint(PatternRewriter &rewriter,
+                                                      PDLResultList &results,
+                                                      ArrayRef<PDLValue> args) {
+  auto *op = args[0].cast<Operation *>();
+  if (op->getName().getStringRef() == "test.success_op") {
+    results.push_back(op->getOperands()); // Result: the operands as a range.
+    return success();
+  }
+  return failure();
+}
+
 // Custom creator invoked from PDL.
 static Operation *customCreate(PatternRewriter &rewriter, Operation *op) {
   return rewriter.create(OperationState(op->getLoc(), "test.success"));
@@ -161,6 +174,8 @@ struct TestPDLByteCodePass
                                           customConstraintFailure);
     pdlPattern.registerConstraintFunction("op_constr_return_type_range",
                                           customTypeRangeResultConstraint);
+    pdlPattern.registerConstraintFunction("op_constr_return_value_range",
+                                          customValueRangeResultConstraint);
     pdlPattern.registerRewriteFunction("creator", customCreate);
     pdlPattern.registerRewriteFunction("var_creator",
                                        customVariadicResultCreate);


        


More information about the Mlir-commits mailing list