[Mlir-commits] [mlir] [mlir][PDL] Add support for native constraints with results (PR #82760)

Matthias Gehre llvmlistbot at llvm.org
Fri Feb 23 05:05:17 PST 2024


https://github.com/mgehre-amd created https://github.com/llvm/llvm-project/pull/82760

This adds support for native PDL (and PDLL) C++ constraints to return results.

This is useful when a pattern checks constraints on multiple interdependent attributes and computes a new attribute value from them. Currently, such a pattern has to escape to native C++ twice: once during matching to perform the check, and again after a successful match to perform the computation in the rewriting part of the pattern. With this work, the computation can be done in C++ during matching and its result used in the rewriting part of the pattern. Effectively, this enables a choice in the trade-off between memory consumption during matching and recomputation of values.

As an example, consider two operations whose attributes have interdependent constraints, for instance attr_foo: one_of [0, 2, 4, 8], attr_bar: one_of [0, 2, 4, 8], and attr_foo == attr_bar. The pattern should only match if all conditions are true. The new operation should be created with a new attribute computed from the two matched attributes, e.g. attr_baz = attr_foo * attr_bar. Since the check already escapes to native C++ and has all values at hand, it makes sense to compute the new attribute value there as well:

```
Constraint checkAndCompute(attr0: Attr, attr1: Attr) -> Attr;

Pattern example with benefit(1) {
    let foo = op<test.foo>() {attr = attr_foo : Attr};
    let bar = op<test.bar>(foo) {attr = attr_bar : Attr};
    let attr_baz = checkAndCompute(attr_foo, attr_bar);
    rewrite bar with {
        let baz = op<test.baz> {attr=attr_baz};
        replace bar with baz;
    };
}
```
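
The logic that a native `checkAndCompute` would carry out can be sketched in plain C++. This is only an illustration of the check-then-compute idea under simplifying assumptions: attributes are modeled as `int64_t`, failure is modeled as `std::nullopt`, and nothing here uses the MLIR API. The real function would take `PDLValue` arguments and append an attribute to a `PDLResultList`.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical stand-in for the body of the native constraint (C++17).
// Returns std::nullopt when the pattern should not match, otherwise the
// value that would become attr_baz.
std::optional<int64_t> checkAndCompute(int64_t attrFoo, int64_t attrBar) {
  auto isAllowed = [](int64_t v) {
    return v == 0 || v == 2 || v == 4 || v == 8;
  };
  // All three conditions must hold for the pattern to match.
  if (!isAllowed(attrFoo) || !isAllowed(attrBar) || attrFoo != attrBar)
    return std::nullopt;
  // The matched values are already at hand, so compute the result here
  // instead of escaping to C++ a second time during the rewrite.
  return attrFoo * attrBar;
}
```

With this model, `checkAndCompute(4, 4)` yields a value, while `checkAndCompute(2, 4)` and `checkAndCompute(3, 3)` signal a failed match.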
To achieve this the following notable changes were necessary:
PDLL
- Remove check in PDLL parser that prevented native constraints from returning results
PDL
- Change PDL definition of pdl.apply_native_constraint to allow variadic results
PDL_interp
- Change PDL_interp definition of pdl_interp.apply_constraint to allow variadic results

PDLToPDLInterp Pass:
The input to the pass is an arbitrary number of PDL patterns. The pass collects the predicates that are required to match all of the PDL patterns and establishes an ordering that allows creation of a single efficient matcher function to match all of them. Values that are matched and possibly used in the rewriting part of a pattern are represented as positions. This allows fusion and thus reusing a single position for multiple matching patterns. Accordingly, we introduce ConstraintPosition, which records the type and index of the result of the constraint. The problem is that, for the corresponding value to be usable in the rewriting part of a pattern, it has to be an input to the pdl_interp.record_match operation, which is generated early during the pass so that its surrounding block can be referred to by branching operations. As a consequence, the value has to be materialized after the original pdl.apply_native_constraint has been deleted but before we get the chance to generate the corresponding pdl_interp.apply_constraint operation. We solve this by emitting a placeholder value when a ConstraintPosition is evaluated. These placeholder values (due to fusion there may be multiple for one constraint result) are replaced later when the actual pdl_interp.apply_constraint operation is created.
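
The placeholder scheme can be illustrated with a small standalone model. This is a sketch under loud assumptions, not the pass itself: values are plain integer ids, `uses` stands in for the operands of pdl_interp.record_match, and `PlaceholderTable`, `getPlaceholder`, and `resolve` are hypothetical names. Only the keying by (constraint, result index) and the use of a multimap follow the patch.

```cpp
#include <cassert>
#include <map>
#include <utility>
#include <vector>

// Hypothetical toy model of placeholder bookkeeping.
struct PlaceholderTable {
  using Key = std::pair<int /*constraint id*/, unsigned /*result index*/>;
  std::multimap<Key, int> pending; // key -> placeholder value ids
  int nextPlaceholder = -1;        // negative ids mark placeholders

  // Called when a constraint-result position is evaluated before the
  // constraint op exists: hand out a fresh placeholder and remember it.
  int getPlaceholder(int constraintId, unsigned index) {
    int id = nextPlaceholder--;
    pending.emplace(Key{constraintId, index}, id);
    return id;
  }

  // Called once the real constraint op has been created: rewrite every use
  // of each placeholder registered for this result, then drop the entries.
  void resolve(int constraintId, unsigned index, int realValue,
               std::vector<int> &uses) {
    Key key{constraintId, index};
    auto range = pending.equal_range(key);
    for (auto it = range.first; it != range.second; ++it)
      for (int &use : uses)
        if (use == it->second)
          use = realValue;
    pending.erase(key);
  }
};
```

Because fusion may request the same constraint result more than once, several placeholders can be pending under one key; `equal_range` visits all of them in a single resolution step.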

Bytecode generator and interpreter:
Constraint functions which return results have a different type compared to existing constraint functions. They have the same type as native rewrite functions and hence are registered as rewrite functions.
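
A rough standalone model of this registration and dispatch scheme is shown below. The signatures are simplified assumptions (integers instead of `PDLValue`, no `PatternRewriter`), and `Registry`, `registerConstraintWithResults`, and `apply` are hypothetical names chosen for illustration. The point it shows is that a constraint with results shares the rewrite-function signature, so it is stored in the rewrite table and selected by the number of expected results.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical simplified signatures; the real ones take a PatternRewriter
// and PDLValues.
using ConstraintFn = std::function<bool(const std::vector<int> &)>;
using RewriteFn =
    std::function<bool(std::vector<int> &, const std::vector<int> &)>;

struct Registry {
  std::map<std::string, ConstraintFn> constraints;
  std::map<std::string, RewriteFn> rewrites;

  // Constraints with results have the rewrite signature, so they are kept
  // in the rewrite table.
  void registerConstraintWithResults(const std::string &name, RewriteFn fn) {
    rewrites.emplace(name, std::move(fn));
  }

  // Select the table based on the number of expected results.
  bool apply(const std::string &name, const std::vector<int> &args,
             unsigned numResults, std::vector<int> &results) const {
    if (numResults == 0)
      return constraints.at(name)(args);
    return rewrites.at(name)(results, args);
  }
};
```

In this model a plain constraint only reports success or failure, while a constraint with results additionally fills `results` for later use.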

From ab7d42225958dab1c8745788359637e1a94a3ef1 Mon Sep 17 00:00:00 2001
From: Matthias Gehre <matthias.gehre at amd.com>
Date: Fri, 23 Feb 2024 13:58:44 +0100
Subject: [PATCH] [mlir][PDL] Add support for native constraints with results
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This adds support for native PDL (and PDLL) C++ constraints to return results.

This is useful for situations where a pattern checks for certain constraints of multiple interdependent attributes and computes a new attribute value based on them. Currently, for such an example it is required to escape to C++ during matching to perform the check and after a successful match again escape to native C++ to perform the computation during the rewriting part of the pattern.
With this work we can do the computation in C++ during matching and use the result in the rewriting part of the pattern. Effectively this enables a choice in the trade-off of memory consumption during matching vs recomputation of values.

This is an example of a situation where this is useful: We have two operations with certain attributes that have interdependent constraints. For instance attr_foo: one_of [0, 2, 4, 8], attr_bar: one_of [0, 2, 4, 8] and attr_foo == attr_bar. The pattern should only match if all conditions are true. The new operation should be created with a new attribute which is computed from the two matched attributes e.g. attr_baz = attr_foo * attr_bar. For the check we already escape to native C++ and have all values at hand so it makes sense to directly compute the new attribute value as well:

```
Constraint checkAndCompute(attr0: Attr, attr1: Attr) -> Attr;

Pattern example with benefit(1) {
    let foo = op<test.foo>() {attr = attr_foo : Attr};
    let bar = op<test.bar>(foo) {attr = attr_bar : Attr};
    let attr_baz = checkAndCompute(attr_foo, attr_bar);
    rewrite bar with {
        let baz = op<test.baz> {attr=attr_baz};
        replace bar with baz;
    };
}
```
To achieve this the following notable changes were necessary:
PDLL
- Remove check in PDLL parser that prevented native constraints from returning results
PDL
- Change PDL definition of pdl.apply_native_constraint to allow variadic results
PDL_interp
- Change PDL_interp definition of pdl_interp.apply_constraint to allow variadic results

PDLToPDLInterp Pass:
The input to the pass is an arbitrary number of PDL patterns. The pass collects the predicates
that are required to match all of the pdl patterns and establishes an ordering that allows
creation of a single efficient matcher function to match all of them. Values that are matched and
possibly used in the rewriting part of a pattern are represented as positions. This allows fusion
and thus reusing a single position for multiple matching patterns.
Accordingly, we introduce ConstraintPosition, which records the type and index of the result of the constraint.
The problem is that, for the corresponding value to be usable in the rewriting part of a pattern,
it has to be an input to the pdl_interp.record_match operation, which is generated early during
the pass so that its surrounding block can be referred to by branching operations. As a consequence,
the value has to be materialized after the original pdl.apply_native_constraint has been deleted but
before we get the chance to generate the corresponding pdl_interp.apply_constraint operation.
We solve this by emitting a placeholder value when a ConstraintPosition is evaluated.
These placeholder values (due to fusion there may be multiple for one constraint result) are
replaced later when the actual pdl_interp.apply_constraint operation is created.

Bytecode generator and interpreter:
Constraint functions which return results have a different type compared to existing constraint functions.
They have the same type as native rewrite functions and hence are registered as rewrite functions.

Co-authored-by: Martin Lücke <martin.luecke at ed.ac.uk>
---
 mlir/include/mlir/Dialect/PDL/IR/PDLOps.td    |   9 +-
 .../mlir/Dialect/PDLInterp/IR/PDLInterpOps.td |   8 +-
 mlir/include/mlir/IR/PDLPatternMatch.h.inc    |  13 ++
 .../PDLToPDLInterp/PDLToPDLInterp.cpp         |  43 ++++-
 .../lib/Conversion/PDLToPDLInterp/Predicate.h |  58 +++++--
 .../PDLToPDLInterp/PredicateTree.cpp          |  53 +++++-
 mlir/lib/Dialect/PDL/IR/PDL.cpp               |   6 +
 mlir/lib/IR/PDL/PDLPatternMatch.cpp           |   9 +
 mlir/lib/Rewrite/ByteCode.cpp                 | 164 +++++++++++++-----
 mlir/lib/Tools/PDLL/Parser/Parser.cpp         |   6 -
 .../pdl-to-pdl-interp-matcher.mlir            |  51 ++++++
 .../PDLToPDLInterp/use-constraint-result.mlir |  21 +++
 mlir/test/Dialect/PDL/ops.mlir                |  18 ++
 mlir/test/Rewrite/pdl-bytecode.mlir           |  68 ++++++++
 mlir/test/lib/Rewrite/TestPDLByteCode.cpp     |  51 ++++++
 .../mlir-pdll/Parser/constraint-failure.pdll  |   5 -
 mlir/test/mlir-pdll/Parser/constraint.pdll    |   8 +
 17 files changed, 514 insertions(+), 77 deletions(-)
 create mode 100644 mlir/test/Conversion/PDLToPDLInterp/use-constraint-result.mlir

diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
index 4e9ebccba77d88..1e108c3d8ac77a 100644
--- a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
+++ b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
@@ -35,20 +35,25 @@ def PDL_ApplyNativeConstraintOp
   let description = [{
     `pdl.apply_native_constraint` operations apply a native C++ constraint, that
     has been registered externally with the consumer of PDL, to a given set of
-    entities.
+    entities and optionally return a number of values.
 
     Example:
 
     ```mlir
     // Apply `myConstraint` to the entities defined by `input`, `attr`, and `op`.
     pdl.apply_native_constraint "myConstraint"(%input, %attr, %op : !pdl.value, !pdl.attribute, !pdl.operation)
+    // Apply constraint `with_result` to `root`. This constraint returns an attribute.
+    %attr = pdl.apply_native_constraint "with_result"(%root : !pdl.operation) : !pdl.attribute
     ```
   }];
 
   let arguments = (ins StrAttr:$name, 
                        Variadic<PDL_AnyType>:$args, 
                        DefaultValuedAttr<BoolAttr, "false">:$isNegated);
-  let assemblyFormat = "$name `(` $args `:` type($args) `)` attr-dict";
+  let results = (outs Variadic<PDL_AnyType>:$results);
+  let assemblyFormat = [{
+    $name `(` $args `:` type($args) `)` (`:`  type($results)^ )? attr-dict
+  }];
   let hasVerifier = 1;
 }
 
diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
index 48f625bd1fa3fd..901acc0e6733bb 100644
--- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
+++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
@@ -88,7 +88,9 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> {
   let description = [{
     `pdl_interp.apply_constraint` operations apply a generic constraint, that
     has been registered with the interpreter, with a given set of positional
-    values. On success, this operation branches to the true destination,
+    values.
+    The constraint function may return any number of results.
+    On success, this operation branches to the true destination,
     otherwise the false destination is taken. This behavior can be reversed
     by setting the attribute `isNegated` to true.
 
@@ -104,8 +106,10 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> {
   let arguments = (ins StrAttr:$name, 
                        Variadic<PDL_AnyType>:$args,
                        DefaultValuedAttr<BoolAttr, "false">:$isNegated);
+  let results = (outs Variadic<PDL_AnyType>:$results);
   let assemblyFormat = [{
-    $name `(` $args `:` type($args) `)` attr-dict `->` successors
+    $name `(` $args `:` type($args) `)` (`:` type($results)^)? attr-dict 
+    `->` successors
   }];
 }
 
diff --git a/mlir/include/mlir/IR/PDLPatternMatch.h.inc b/mlir/include/mlir/IR/PDLPatternMatch.h.inc
index a215da8cb6431d..3516fda9e87e78 100644
--- a/mlir/include/mlir/IR/PDLPatternMatch.h.inc
+++ b/mlir/include/mlir/IR/PDLPatternMatch.h.inc
@@ -868,6 +868,19 @@ public:
                                    std::forward<ConstraintFnT>(constraintFn)));
   }
 
+  /// Register a constraint function that produces results with PDL. A
+  /// constraint function with results uses the same registry as rewrite
+  /// functions. It may be specified as follows:
+  ///
+  ///   * `LogicalResult (PatternRewriter &, PDLResultList &,
+  ///   ArrayRef<PDLValue>)`
+  ///
+  ///   The arguments of the constraint function are passed via the low-level
+  ///   PDLValue form, and the results are manually appended to the given result
+  ///   list.
+  void registerConstraintFunctionWithResults(StringRef name,
+                                             PDLRewriteFunction constraintFn);
+
   /// Register a rewrite function with PDL. A rewrite function may be specified
   /// in one of two ways:
   ///
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
index e911631a4bc52a..26bab62586ff2a 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
@@ -148,6 +148,12 @@ struct PatternLowering {
   /// A mapping between pattern operations and the corresponding configuration
   /// set.
   DenseMap<Operation *, PDLPatternConfigSet *> *configMap;
+
+  /// A mapping from a constraint question and result index that together
+  /// refer to a value created by a constraint to the temporary placeholder
+  /// values created for them.
+  std::multimap<std::pair<ConstraintQuestion *, unsigned>, Value>
+      constraintResultMap;
 };
 } // namespace
 
@@ -364,6 +370,21 @@ Value PatternLowering::getValueAt(Block *&currentBlock, Position *pos) {
           loc, cast<ArrayAttr>(rawTypeAttr));
     break;
   }
+  case Predicates::ConstraintResultPos: {
+    // The corresponding pdl.ApplyNativeConstraint op has already been deleted
+    // and the new pdl_interp.ApplyConstraint has not been created yet. To
+    // enable referring to results created by this operation we build a
+    // placeholder value that will be replaced when the actual
+    // pdl_interp.ApplyConstraint operation is created.
+    auto *constrResPos = cast<ConstraintPosition>(pos);
+    Value placeholderValue = builder.create<pdl_interp::CreateAttributeOp>(
+        loc, StringAttr::get(builder.getContext(), "placeholder"));
+    constraintResultMap.insert(
+        {{constrResPos->getQuestion(), constrResPos->getIndex()},
+         placeholderValue});
+    value = placeholderValue;
+    break;
+  }
   default:
     llvm_unreachable("Generating unknown Position getter");
     break;
@@ -447,9 +468,25 @@ void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
   }
   case Predicates::ConstraintQuestion: {
     auto *cstQuestion = cast<ConstraintQuestion>(question);
-    builder.create<pdl_interp::ApplyConstraintOp>(
-        loc, cstQuestion->getName(), args, cstQuestion->getIsNegated(), success,
-        failure);
+    auto applyConstraintOp = builder.create<pdl_interp::ApplyConstraintOp>(
+        loc, cstQuestion->getResultTypes(), cstQuestion->getName(), args,
+        cstQuestion->getIsNegated(), success, failure);
+    // Replace the generated placeholders with the results of the constraint and
+    // erase them
+    for (auto result : llvm::enumerate(applyConstraintOp.getResults())) {
+      std::pair<ConstraintQuestion *, unsigned> substitutionKey = {
+          cstQuestion, result.index()};
+      // Check if there are substitutions to perform. If the result is never
+      // used or multiple calls to the same constraint have been merged,
+      // no substitutions will have been generated for this specific op.
+      auto range = constraintResultMap.equal_range(substitutionKey);
+      std::for_each(range.first, range.second, [&](const auto &elem) {
+        Value placeholder = elem.second;
+        placeholder.replaceAllUsesWith(result.value());
+        placeholder.getDefiningOp()->erase();
+      });
+      constraintResultMap.erase(substitutionKey);
+    }
     break;
   }
   default:
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h
index 2c9b63f86d6efa..5ad2c477573a5b 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h
+++ b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h
@@ -47,6 +47,7 @@ enum Kind : unsigned {
   OperandPos,
   OperandGroupPos,
   AttributePos,
+  ConstraintResultPos,
   ResultPos,
   ResultGroupPos,
   TypePos,
@@ -279,6 +280,28 @@ struct OperationPosition : public PredicateBase<OperationPosition, Position,
   bool isOperandDefiningOp() const;
 };
 
+//===----------------------------------------------------------------------===//
+// ConstraintPosition
+
+struct ConstraintQuestion;
+
+/// A position describing the result of a native constraint. It saves the
+/// corresponding ConstraintQuestion and result index to enable referring
+/// back to them
+struct ConstraintPosition
+    : public PredicateBase<ConstraintPosition, Position,
+                           std::pair<ConstraintQuestion *, unsigned>,
+                           Predicates::ConstraintResultPos> {
+  using PredicateBase::PredicateBase;
+
+  /// Returns the ConstraintQuestion to enable keeping track of the native
+  /// constraint this position stems from.
+  ConstraintQuestion *getQuestion() const { return key.first; }
+
+  // Returns the result index of this position
+  unsigned getIndex() const { return key.second; }
+};
+
 //===----------------------------------------------------------------------===//
 // ResultPosition
 
@@ -447,11 +470,13 @@ struct AttributeQuestion
     : public PredicateBase<AttributeQuestion, Qualifier, void,
                            Predicates::AttributeQuestion> {};
 
-/// Apply a parameterized constraint to multiple position values.
+/// Apply a parameterized constraint to multiple position values and possibly
+/// produce results.
 struct ConstraintQuestion
-    : public PredicateBase<ConstraintQuestion, Qualifier,
-                           std::tuple<StringRef, ArrayRef<Position *>, bool>,
-                           Predicates::ConstraintQuestion> {
+    : public PredicateBase<
+          ConstraintQuestion, Qualifier,
+          std::tuple<StringRef, ArrayRef<Position *>, ArrayRef<Type>, bool>,
+          Predicates::ConstraintQuestion> {
   using Base::Base;
 
   /// Return the name of the constraint.
@@ -460,15 +485,19 @@ struct ConstraintQuestion
   /// Return the arguments of the constraint.
   ArrayRef<Position *> getArgs() const { return std::get<1>(key); }
 
+  /// Return the result types of the constraint.
+  ArrayRef<Type> getResultTypes() const { return std::get<2>(key); }
+
   /// Return the negation status of the constraint.
-  bool getIsNegated() const { return std::get<2>(key); }
+  bool getIsNegated() const { return std::get<3>(key); }
 
   /// Construct an instance with the given storage allocator.
   static ConstraintQuestion *construct(StorageUniquer::StorageAllocator &alloc,
                                        KeyTy key) {
     return Base::construct(alloc, KeyTy{alloc.copyInto(std::get<0>(key)),
                                         alloc.copyInto(std::get<1>(key)),
-                                        std::get<2>(key)});
+                                        alloc.copyInto(std::get<2>(key)),
+                                        std::get<3>(key)});
   }
 
   /// Returns a hash suitable for the given keytype.
@@ -526,6 +555,7 @@ class PredicateUniquer : public StorageUniquer {
     // Register the types of Positions with the uniquer.
     registerParametricStorageType<AttributePosition>();
     registerParametricStorageType<AttributeLiteralPosition>();
+    registerParametricStorageType<ConstraintPosition>();
     registerParametricStorageType<ForEachPosition>();
     registerParametricStorageType<OperandPosition>();
     registerParametricStorageType<OperandGroupPosition>();
@@ -588,6 +618,12 @@ class PredicateBuilder {
     return OperationPosition::get(uniquer, p);
   }
 
+  // Returns a position for a new value created by a constraint.
+  ConstraintPosition *getConstraintPosition(ConstraintQuestion *q,
+                                            unsigned index) {
+    return ConstraintPosition::get(uniquer, std::make_pair(q, index));
+  }
+
   /// Returns an attribute position for an attribute of the given operation.
   Position *getAttribute(OperationPosition *p, StringRef name) {
     return AttributePosition::get(uniquer, p, StringAttr::get(ctx, name));
@@ -673,11 +709,11 @@ class PredicateBuilder {
   }
 
   /// Create a predicate that applies a generic constraint.
-  Predicate getConstraint(StringRef name, ArrayRef<Position *> pos,
-                          bool isNegated) {
-    return {
-        ConstraintQuestion::get(uniquer, std::make_tuple(name, pos, isNegated)),
-        TrueAnswer::get(uniquer)};
+  Predicate getConstraint(StringRef name, ArrayRef<Position *> args,
+                          ArrayRef<Type> resultTypes, bool isNegated) {
+    return {ConstraintQuestion::get(
+                uniquer, std::make_tuple(name, args, resultTypes, isNegated)),
+            TrueAnswer::get(uniquer)};
   }
 
   /// Create a predicate comparing a value with null.
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
index a9c3b0a71ef0d7..d5d397eef8b01d 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
@@ -18,6 +18,7 @@
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
 #include <queue>
+#include "llvm/ADT/SmallPtrSet.h"
 
 #define DEBUG_TYPE "pdl-predicate-tree"
 
@@ -272,8 +273,17 @@ static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
   // Push the constraint to the furthest position.
   Position *pos = *std::max_element(allPositions.begin(), allPositions.end(),
                                     comparePosDepth);
-  PredicateBuilder::Predicate pred =
-      builder.getConstraint(op.getName(), allPositions, op.getIsNegated());
+  ResultRange results = op.getResults();
+  PredicateBuilder::Predicate pred = builder.getConstraint(
+      op.getName(), allPositions, SmallVector<Type>(results.getTypes()),
+      op.getIsNegated());
+
+  // for each result register a position so it can be used later
+  for (auto result : llvm::enumerate(results)) {
+    ConstraintQuestion *q = cast<ConstraintQuestion>(pred.first);
+    ConstraintPosition *pos = builder.getConstraintPosition(q, result.index());
+    inputs[result.value()] = pos;
+  }
   predList.emplace_back(pos, pred);
 }
 
@@ -875,6 +885,27 @@ static void insertExitNode(std::unique_ptr<MatcherNode> *root) {
   *root = std::make_unique<ExitNode>();
 }
 
+/// Sorts the range begin/end with the partial order given by cmp.
+template <typename Iterator, typename Compare>
+void stableTopologicalSort(Iterator begin, Iterator end, Compare cmp) {
+  while (begin != end) {
+    // Cannot compute sortBeforeOthers in the predicate of stable_partition
+    // because stable_partition will not keep the [begin, end) range intact
+    // while it runs.
+    llvm::SmallPtrSet<typename Iterator::value_type, 16> sortBeforeOthers;
+    for(auto i = begin; i != end; ++i) {
+      if (std::none_of(begin, end, [&](auto const &b) { return cmp(b, *i); }))
+        sortBeforeOthers.insert(*i);
+    }
+
+    auto const next = std::stable_partition(begin, end, [&](auto const &a) {
+      return sortBeforeOthers.contains(a);
+    });
+    assert(next != begin && "not a partial ordering");
+    begin = next;
+  }
+}
+
 /// Given a module containing PDL pattern operations, generate a matcher tree
 /// using the patterns within the given module and return the root matcher node.
 std::unique_ptr<MatcherNode>
@@ -955,6 +986,24 @@ MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder,
     return *lhs < *rhs;
   });
 
+  // Mostly keep the now established order, but also ensure that
+  // ConstraintQuestions come after the results they use.
+  stableTopologicalSort(ordered.begin(), ordered.end(),
+                        [](OrderedPredicate *a, OrderedPredicate *b) {
+                          auto *cqa = dyn_cast<ConstraintQuestion>(a->question);
+                          auto *cqb = dyn_cast<ConstraintQuestion>(b->question);
+                          if (cqa && cqb) {
+                            // Does any argument of b use a? Then b must be
+                            // sorted after a.
+                            return llvm::any_of(
+                                cqb->getArgs(), [&](Position *p) {
+                                  auto *cp = dyn_cast<ConstraintPosition>(p);
+                                  return cp && cp->getQuestion() == cqa;
+                                });
+                          }
+                          return false;
+                        });
+
   // Build the matchers for each of the pattern predicate lists.
   std::unique_ptr<MatcherNode> root;
   for (OrderedPredicateList &list : lists)
diff --git a/mlir/lib/Dialect/PDL/IR/PDL.cpp b/mlir/lib/Dialect/PDL/IR/PDL.cpp
index d5f34679f06c60..2d19a9f39cebd0 100644
--- a/mlir/lib/Dialect/PDL/IR/PDL.cpp
+++ b/mlir/lib/Dialect/PDL/IR/PDL.cpp
@@ -94,6 +94,12 @@ static void visit(Operation *op, DenseSet<Operation *> &visited) {
 LogicalResult ApplyNativeConstraintOp::verify() {
   if (getNumOperands() == 0)
     return emitOpError("expected at least one argument");
+  if (llvm::any_of(getResults(), [](OpResult result) {
+        return result.getType().isa<OperationType>();
+      })) {
+    return emitOpError(
+        "returning an operation from a constraint is not supported");
+  }
   return success();
 }
 
diff --git a/mlir/lib/IR/PDL/PDLPatternMatch.cpp b/mlir/lib/IR/PDL/PDLPatternMatch.cpp
index da07cc462a5a13..e37bcc40d03904 100644
--- a/mlir/lib/IR/PDL/PDLPatternMatch.cpp
+++ b/mlir/lib/IR/PDL/PDLPatternMatch.cpp
@@ -123,6 +123,15 @@ void PDLPatternModule::registerConstraintFunction(
   constraintFunctions.try_emplace(name, std::move(constraintFn));
 }
 
+void PDLPatternModule::registerConstraintFunctionWithResults(
+    StringRef name, PDLRewriteFunction constraintFn) {
+  // TODO: Is it possible to diagnose when `name` is already registered to
+  // a function that is not equivalent to `constraintFn`?
+  // Allow existing mappings in the case multiple patterns depend on the same
+  // constraint.
+  registerRewriteFunction(name, std::move(constraintFn));
+}
+
 void PDLPatternModule::registerRewriteFunction(StringRef name,
                                                PDLRewriteFunction rewriteFn) {
   // TODO: Is it possible to diagnose when `name` is already registered to
diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp
index 6e6992dcdeea78..6090255f30fe01 100644
--- a/mlir/lib/Rewrite/ByteCode.cpp
+++ b/mlir/lib/Rewrite/ByteCode.cpp
@@ -769,11 +769,36 @@ void Generator::generate(Operation *op, ByteCodeWriter &writer) {
 
 void Generator::generate(pdl_interp::ApplyConstraintOp op,
                          ByteCodeWriter &writer) {
-  assert(constraintToMemIndex.count(op.getName()) &&
-         "expected index for constraint function");
-  writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
+  // Constraints that should return a value have to be registered as rewrites.
+  // If a constraint and a rewrite of the same name are registered, the
+  // constraint takes precedence.
+  ResultRange results = op.getResults();
+  if (results.empty() && constraintToMemIndex.count(op.getName()) != 0) {
+    writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
+  } else if (!results.empty() &&
+             externalRewriterToMemIndex.count(op.getName()) != 0) {
+    writer.append(OpCode::ApplyConstraint,
+                  externalRewriterToMemIndex[op.getName()]);
+  } else {
+    assert(false && "expected index for constraint function, make sure it is "
+                   "registered properly. Note that native constraints with "
+                   "results have to be registered using "
+                   "PDLPatternModule::registerConstraintFunctionWithResults.");
+  }
   writer.appendPDLValueList(op.getArgs());
   writer.append(ByteCodeField(op.getIsNegated()));
+  writer.append(ByteCodeField(results.size()));
+  for (Value result : results) {
+    // We record the expected kind of the result, so that we can provide extra
+    // verification of the native rewrite function and handle the failure case
+    // of constraints accordingly.
+    writer.appendPDLValueKind(result);
+
+    // Range results also need to append the range storage index.
+    if (result.getType().isa<pdl::RangeType>())
+      writer.append(getRangeStorageIndex(result));
+    writer.append(result);
+  }
   writer.append(op.getSuccessors());
 }
 void Generator::generate(pdl_interp::ApplyRewriteOp op,
@@ -786,11 +811,9 @@ void Generator::generate(pdl_interp::ApplyRewriteOp op,
   ResultRange results = op.getResults();
   writer.append(ByteCodeField(results.size()));
   for (Value result : results) {
-    // In debug mode we also record the expected kind of the result, so that we
+    // We record the expected kind of the result, so that we
     // can provide extra verification of the native rewrite function.
-#ifndef NDEBUG
     writer.appendPDLValueKind(result);
-#endif
 
     // Range results also need to append the range storage index.
     if (isa<pdl::RangeType>(result.getType()))
@@ -1076,6 +1099,28 @@ void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
 // ByteCode Execution
 
 namespace {
+/// This class is an instantiation of the PDLResultList that provides access to
+/// the returned results. This API is not on `PDLResultList` to avoid
+/// overexposing access to information specific solely to the ByteCode.
+class ByteCodeRewriteResultList : public PDLResultList {
+public:
+  ByteCodeRewriteResultList(unsigned maxNumResults)
+      : PDLResultList(maxNumResults) {}
+
+  /// Return the list of PDL results.
+  MutableArrayRef<PDLValue> getResults() { return results; }
+
+  /// Return the type ranges allocated by this list.
+  MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
+    return allocatedTypeRanges;
+  }
+
+  /// Return the value ranges allocated by this list.
+  MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
+    return allocatedValueRanges;
+  }
+};
+
 /// This class provides support for executing a bytecode stream.
 class ByteCodeExecutor {
 public:
@@ -1152,6 +1197,9 @@ class ByteCodeExecutor {
   void executeSwitchResultCount();
   void executeSwitchType();
   void executeSwitchTypes();
+  void processNativeFunResults(ByteCodeRewriteResultList &results,
+                               unsigned numResults,
+                               LogicalResult &rewriteResult);
 
   /// Pushes a code iterator to the stack.
   void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); }
@@ -1225,6 +1273,8 @@ class ByteCodeExecutor {
     return T::getFromOpaquePointer(pointer);
   }
 
+  void skip(size_t skipN) { curCodeIt += skipN; }
+
   /// Jump to a specific successor based on a predicate value.
   void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
   /// Jump to a specific successor based on a destination index.
@@ -1381,33 +1431,11 @@ class ByteCodeExecutor {
   ArrayRef<PDLConstraintFunction> constraintFunctions;
   ArrayRef<PDLRewriteFunction> rewriteFunctions;
 };
-
-/// This class is an instantiation of the PDLResultList that provides access to
-/// the returned results. This API is not on `PDLResultList` to avoid
-/// overexposing access to information specific solely to the ByteCode.
-class ByteCodeRewriteResultList : public PDLResultList {
-public:
-  ByteCodeRewriteResultList(unsigned maxNumResults)
-      : PDLResultList(maxNumResults) {}
-
-  /// Return the list of PDL results.
-  MutableArrayRef<PDLValue> getResults() { return results; }
-
-  /// Return the type ranges allocated by this list.
-  MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
-    return allocatedTypeRanges;
-  }
-
-  /// Return the value ranges allocated by this list.
-  MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
-    return allocatedValueRanges;
-  }
-};
 } // namespace
 
 void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
   LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
-  const PDLConstraintFunction &constraintFn = constraintFunctions[read()];
+  ByteCodeField fun_idx = read();
   SmallVector<PDLValue, 16> args;
   readList<PDLValue>(args);
 
@@ -1422,8 +1450,36 @@ void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
     llvm::dbgs() << "  * isNegated: " << isNegated << "\n";
     llvm::interleaveComma(args, llvm::dbgs());
   });
-  // Invoke the constraint and jump to the proper destination.
-  selectJump(isNegated != succeeded(constraintFn(rewriter, args)));
+
+  ByteCodeField numResults = read();
+  if (numResults == 0) {
+    const PDLConstraintFunction &constraintFn = constraintFunctions[fun_idx];
+    LogicalResult rewriteResult = constraintFn(rewriter, args);
+    // Depending on the constraint jump to the proper destination.
+    selectJump(isNegated != succeeded(rewriteResult));
+    return;
+  }
+  const PDLRewriteFunction &constraintFn = rewriteFunctions[fun_idx];
+  ByteCodeRewriteResultList results(numResults);
+  LogicalResult rewriteResult = constraintFn(rewriter, results, args);
+  ArrayRef<PDLValue> constraintResults = results.getResults();
+  LLVM_DEBUG({
+    if (succeeded(rewriteResult)) {
+      llvm::dbgs() << "  * Constraint succeeded\n";
+      llvm::dbgs() << "  * Results: ";
+      llvm::interleaveComma(constraintResults, llvm::dbgs());
+      llvm::dbgs() << "\n";
+    } else {
+      llvm::dbgs() << "  * Constraint failed\n";
+    }
+  });
+  assert((failed(rewriteResult) || constraintResults.size() == numResults) &&
+         "native PDL rewrite function succeeded but returned "
+         "unexpected number of results");
+  processNativeFunResults(results, numResults, rewriteResult);
+
+  // Depending on the constraint jump to the proper destination.
+  selectJump(isNegated != succeeded(rewriteResult));
 }
 
 LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
@@ -1445,16 +1501,39 @@ LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
   assert(results.getResults().size() == numResults &&
          "native PDL rewrite function returned unexpected number of results");
 
-  // Store the results in the bytecode memory.
-  for (PDLValue &result : results.getResults()) {
-    LLVM_DEBUG(llvm::dbgs() << "  * Result: " << result << "\n");
+  processNativeFunResults(results, numResults, rewriteResult);
 
-// In debug mode we also verify the expected kind of the result.
-#ifndef NDEBUG
-    assert(result.getKind() == read<PDLValue::Kind>() &&
-           "native PDL rewrite function returned an unexpected type of result");
-#endif
+  if (failed(rewriteResult)) {
+    LLVM_DEBUG(llvm::dbgs() << "  - Failed");
+    return failure();
+  }
+  return success();
+}
 
+void ByteCodeExecutor::processNativeFunResults(
+    ByteCodeRewriteResultList &results, unsigned numResults,
+    LogicalResult &rewriteResult) {
+  // Store the results in the bytecode memory or handle missing results on
+  // failure.
+  for (unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) {
+    PDLValue::Kind resultKind = read<PDLValue::Kind>();
+
+    // On failure, skip the corresponding number of values in the buffer and
+    // exit early, as there are no results to process.
+    if (failed(rewriteResult)) {
+      if (resultKind == PDLValue::Kind::TypeRange ||
+          resultKind == PDLValue::Kind::ValueRange) {
+        skip(2);
+      } else {
+        skip(1);
+      }
+      return;
+    }
+    PDLValue result = results.getResults()[resultIdx];
+    LLVM_DEBUG(llvm::dbgs() << "  * Result: " << result << "\n");
+    assert(result.getKind() == resultKind &&
+           "native PDL rewrite function returned an unexpected type of "
+           "result");
     // If the result is a range, we need to copy it over to the bytecodes
     // range memory.
     if (std::optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) {
@@ -1476,13 +1555,6 @@ LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
     allocatedTypeRangeMemory.push_back(std::move(it));
   for (auto &it : results.getAllocatedValueRanges())
     allocatedValueRangeMemory.push_back(std::move(it));
-
-  // Process the result of the rewrite.
-  if (failed(rewriteResult)) {
-    LLVM_DEBUG(llvm::dbgs() << "  - Failed");
-    return failure();
-  }
-  return success();
 }
 
 void ByteCodeExecutor::executeAreEqual() {
diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
index 97ff8bd0d8584d..206781ed146692 100644
--- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp
+++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
@@ -1362,12 +1362,6 @@ FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl(
   if (failed(parseToken(Token::semicolon,
                         "expected `;` after native declaration")))
     return failure();
-  // TODO: PDL should be able to support constraint results in certain
-  // situations, we should revise this.
-  if (std::is_same<ast::UserConstraintDecl, T>::value && !results.empty()) {
-    return emitError(
-        "native Constraints currently do not support returning results");
-  }
   return T::createNative(ctx, name, arguments, results, optCodeStr, resultType);
 }
 
diff --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
index 02bb8316c02db0..92afb765b5ab4e 100644
--- a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
+++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
@@ -79,6 +79,57 @@ module @constraints {
 
 // -----
 
+// CHECK-LABEL: module @constraint_with_result
+module @constraint_with_result {
+  // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
+  // CHECK: %[[ATTR:.*]] = pdl_interp.apply_constraint "check_op_and_get_attr_constr"(%[[ROOT]]
+  // CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter(%[[ROOT]], %[[ATTR]] : !pdl.operation, !pdl.attribute)
+  pdl.pattern : benefit(1) {
+    %root = operation
+    %attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute
+    rewrite %root with "rewriter"(%attr : !pdl.attribute)
+  }
+}
+
+// -----
+
+// CHECK-LABEL: module @constraint_with_unused_result
+module @constraint_with_unused_result {
+  // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
+  // CHECK: %[[ATTR:.*]] = pdl_interp.apply_constraint "check_op_and_get_attr_constr"(%[[ROOT]]
+  // CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter(%[[ROOT]] : !pdl.operation)
+  pdl.pattern : benefit(1) {
+    %root = operation
+    %attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute
+    rewrite %root with "rewriter"
+  }
+}
+
+// -----
+
+// CHECK-LABEL: module @constraint_with_result_multiple
+module @constraint_with_result_multiple {
+  // check that native constraints work as expected even when multiple identical constraints are fused
+
+  // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
+  // CHECK: %[[ATTR:.*]] = pdl_interp.apply_constraint "check_op_and_get_attr_constr"(%[[ROOT]]
+  // CHECK-NOT: pdl_interp.apply_constraint "check_op_and_get_attr_constr"
+  // CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter_0(%[[ROOT]], %[[ATTR]]  : !pdl.operation, !pdl.attribute)
+  // CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter(%[[ROOT]], %[[ATTR]] : !pdl.operation, !pdl.attribute)
+  pdl.pattern : benefit(1) {
+    %root = operation
+    %attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute
+    rewrite %root with "rewriter"(%attr : !pdl.attribute)
+  }
+  pdl.pattern : benefit(1) {
+    %root = operation
+    %attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute
+    rewrite %root with "rewriter"(%attr : !pdl.attribute)
+  }
+}
+
+// -----
+
 // CHECK-LABEL: module @negated_constraint
 module @negated_constraint {
   // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
diff --git a/mlir/test/Conversion/PDLToPDLInterp/use-constraint-result.mlir b/mlir/test/Conversion/PDLToPDLInterp/use-constraint-result.mlir
new file mode 100644
index 00000000000000..cdd51ff0ad6379
--- /dev/null
+++ b/mlir/test/Conversion/PDLToPDLInterp/use-constraint-result.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-opt -split-input-file -convert-pdl-to-pdl-interp %s | FileCheck %s
+
+// Ensure that the dependency between add & less
+// causes them to be in the correct order.
+// CHECK: apply_constraint "__builtin_add"
+// CHECK: apply_constraint "__builtin_less"
+
+module {
+  pdl.pattern @test : benefit(1) {
+    %0 = attribute
+    %1 = types
+    %2 = operation "tosa.mul"  {"shift" = %0} -> (%1 : !pdl.range<type>)
+    %3 = attribute = 0 : i32
+    %4 = attribute = 1 : i32
+    %5 = apply_native_constraint "__builtin_add"(%3, %4 : !pdl.attribute, !pdl.attribute) : !pdl.attribute
+    apply_native_constraint "__builtin_less"(%0, %5 : !pdl.attribute, !pdl.attribute)
+    rewrite %2 {
+      replace %2 with %2
+    }
+  }
+}
diff --git a/mlir/test/Dialect/PDL/ops.mlir b/mlir/test/Dialect/PDL/ops.mlir
index 6e6da5cce446ae..20e40deea5f863 100644
--- a/mlir/test/Dialect/PDL/ops.mlir
+++ b/mlir/test/Dialect/PDL/ops.mlir
@@ -134,6 +134,24 @@ pdl.pattern @apply_rewrite_with_no_results : benefit(1) {
 
 // -----
 
+pdl.pattern @apply_constraint_with_no_results : benefit(1) {
+  %root = operation
+  apply_native_constraint "NativeConstraint"(%root : !pdl.operation)
+  rewrite %root with "rewriter"
+}
+
+// -----
+
+pdl.pattern @apply_constraint_with_results : benefit(1) {
+  %root = operation
+  %attr = apply_native_constraint "NativeConstraint"(%root : !pdl.operation) : !pdl.attribute
+  rewrite %root {
+    apply_native_rewrite "NativeRewrite"(%attr : !pdl.attribute)
+  }
+}
+
+// -----
+
 pdl.pattern @attribute_with_dict : benefit(1) {
   %root = operation
   rewrite %root {
diff --git a/mlir/test/Rewrite/pdl-bytecode.mlir b/mlir/test/Rewrite/pdl-bytecode.mlir
index ae61c1a079548a..f8e4f2e83b296a 100644
--- a/mlir/test/Rewrite/pdl-bytecode.mlir
+++ b/mlir/test/Rewrite/pdl-bytecode.mlir
@@ -109,6 +109,74 @@ module @ir attributes { test.apply_constraint_3 } {
 
 // -----
 
+// Test returning a type from a native constraint.
+module @patterns {
+  pdl_interp.func @matcher(%root : !pdl.operation) {
+    pdl_interp.check_operation_name of %root is "test.success_op" -> ^pat, ^end
+
+  ^pat:
+    %new_type = pdl_interp.apply_constraint "op_constr_return_type"(%root : !pdl.operation) : !pdl.type -> ^pat2, ^end
+
+  ^pat2:
+    pdl_interp.record_match @rewriters::@success(%root, %new_type : !pdl.operation, !pdl.type) : benefit(1), loc([%root]) -> ^end
+
+  ^end:
+    pdl_interp.finalize
+  }
+
+  module @rewriters {
+    pdl_interp.func @success(%root : !pdl.operation, %new_type : !pdl.type) {
+      %op = pdl_interp.create_operation "test.replaced_by_pattern" -> (%new_type : !pdl.type)
+      pdl_interp.erase %root
+      pdl_interp.finalize
+    }
+  }
+}
+
+// CHECK-LABEL: test.apply_constraint_4
+// CHECK-NOT: "test.replaced_by_pattern"
+// CHECK: "test.replaced_by_pattern"() : () -> f32
+module @ir attributes { test.apply_constraint_4 } {
+  "test.failure_op"() : () -> ()
+  "test.success_op"() : () -> ()
+}
+
+// -----
+
+// Test success and failure cases of native constraints with pdl.range results.
+module @patterns {
+  pdl_interp.func @matcher(%root : !pdl.operation) {
+    pdl_interp.check_operation_name of %root is "test.success_op" -> ^pat, ^end
+  
+  ^pat:
+    %num_results = pdl_interp.create_attribute 2 : i32
+    %types = pdl_interp.apply_constraint "op_constr_return_type_range"(%root, %num_results : !pdl.operation, !pdl.attribute) : !pdl.range<type> -> ^pat1, ^end
+
+  ^pat1:
+    pdl_interp.record_match @rewriters::@success(%root, %types : !pdl.operation, !pdl.range<type>) : benefit(1), loc([%root]) -> ^end
+
+  ^end:
+    pdl_interp.finalize
+  }
+
+  module @rewriters {
+    pdl_interp.func @success(%root : !pdl.operation, %types : !pdl.range<type>) {
+      %op = pdl_interp.create_operation "test.replaced_by_pattern" -> (%types : !pdl.range<type>)
+      pdl_interp.erase %root
+      pdl_interp.finalize
+    }
+  }
+}
+
+// CHECK-LABEL: test.apply_constraint_5
+// CHECK-NOT: "test.replaced_by_pattern"
+// CHECK: "test.replaced_by_pattern"() : () -> (f32, f32)
+module @ir attributes { test.apply_constraint_5 } {
+  "test.failure_op"() : () -> ()
+  "test.success_op"() : () -> ()
+}
+
+// -----
 
 //===----------------------------------------------------------------------===//
 // pdl_interp::ApplyRewriteOp
diff --git a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
index daa1c371f27c92..73c8844ece6a0b 100644
--- a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
+++ b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
@@ -30,6 +30,51 @@ static LogicalResult customMultiEntityVariadicConstraint(
   return success();
 }
 
+// Custom constraint that returns a value if the op is named test.success_op
+static LogicalResult customValueResultConstraint(PatternRewriter &rewriter,
+                                                 PDLResultList &results,
+                                                 ArrayRef<PDLValue> args) {
+  auto *op = args[0].cast<Operation *>();
+  if (op->getName().getStringRef() == "test.success_op") {
+    StringAttr customAttr = rewriter.getStringAttr("test.success");
+    results.push_back(customAttr);
+    return success();
+  }
+  return failure();
+}
+
+// Custom constraint that returns a type if the op is named test.success_op
+static LogicalResult customTypeResultConstraint(PatternRewriter &rewriter,
+                                                PDLResultList &results,
+                                                ArrayRef<PDLValue> args) {
+  auto *op = args[0].cast<Operation *>();
+  if (op->getName().getStringRef() == "test.success_op") {
+    results.push_back(rewriter.getF32Type());
+    return success();
+  }
+  return failure();
+}
+
+// Custom constraint that returns a type range of variable length if the op is
+// named test.success_op
+static LogicalResult customTypeRangeResultConstraint(PatternRewriter &rewriter,
+                                                     PDLResultList &results,
+                                                     ArrayRef<PDLValue> args) {
+  auto *op = args[0].cast<Operation *>();
+  int numTypes = args[1].cast<Attribute>().cast<IntegerAttr>().getInt();
+
+  if (op->getName().getStringRef() == "test.success_op") {
+    SmallVector<Type> types;
+    for (int i = 0; i < numTypes; i++) {
+      types.push_back(rewriter.getF32Type());
+    }
+    results.push_back(TypeRange(types));
+    return success();
+  }
+  return failure();
+}
+
+
 // Custom creator invoked from PDL.
 static Operation *customCreate(PatternRewriter &rewriter, Operation *op) {
   return rewriter.create(OperationState(op->getLoc(), "test.success"));
@@ -102,6 +147,12 @@ struct TestPDLByteCodePass
                                           customMultiEntityConstraint);
     pdlPattern.registerConstraintFunction("multi_entity_var_constraint",
                                           customMultiEntityVariadicConstraint);
+    pdlPattern.registerConstraintFunctionWithResults(
+        "op_constr_return_attr", customValueResultConstraint);
+    pdlPattern.registerConstraintFunctionWithResults(
+        "op_constr_return_type", customTypeResultConstraint);
+    pdlPattern.registerConstraintFunctionWithResults(
+        "op_constr_return_type_range", customTypeRangeResultConstraint);
     pdlPattern.registerRewriteFunction("creator", customCreate);
     pdlPattern.registerRewriteFunction("var_creator",
                                        customVariadicResultCreate);
diff --git a/mlir/test/mlir-pdll/Parser/constraint-failure.pdll b/mlir/test/mlir-pdll/Parser/constraint-failure.pdll
index 18877b4bcc50ec..48747d3fa2e681 100644
--- a/mlir/test/mlir-pdll/Parser/constraint-failure.pdll
+++ b/mlir/test/mlir-pdll/Parser/constraint-failure.pdll
@@ -158,8 +158,3 @@ Pattern {
 
 // CHECK: expected `;` after native declaration
 Constraint Foo() [{}]
-
-// -----
-
-// CHECK: native Constraints currently do not support returning results
-Constraint Foo() -> Op;
diff --git a/mlir/test/mlir-pdll/Parser/constraint.pdll b/mlir/test/mlir-pdll/Parser/constraint.pdll
index 1c0a015ab4a7b4..e2a52ff130cc84 100644
--- a/mlir/test/mlir-pdll/Parser/constraint.pdll
+++ b/mlir/test/mlir-pdll/Parser/constraint.pdll
@@ -12,6 +12,14 @@ Constraint Foo() [{ /* Native Code */ }];
 
 // -----
 
+// Test that native constraints support returning results.
+
+// CHECK:  Module
+// CHECK:  `-UserConstraintDecl {{.*}} Name<Foo> ResultType<Attr>
+Constraint Foo() -> Attr;
+
+// -----
+
 // CHECK: Module
 // CHECK: `-UserConstraintDecl {{.*}} Name<Foo> ResultType<Value>
 // CHECK:   `Inputs`
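
The behaviour exercised by the `test.apply_constraint_4` case above (a native constraint that yields a type only for `test.success_op`, so that only this op is rewritten) can be sketched outside of MLIR. The following is a minimal standalone model, not the code from this patch: `ToyResultList`, `ToyMatch`, `toyTypeResultConstraint` and `toyMatchAll` are hypothetical names, strings stand in for operations and types, and a plain `bool` stands in for `LogicalResult`.

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for a result list; this is NOT mlir::PDLResultList.
struct ToyResultList {
  std::vector<std::string> values;
  void push_back(std::string value) { values.push_back(std::move(value)); }
};

// Hypothetical stand-in for a recorded match: the matched op name plus the
// value that the constraint computed during matching.
struct ToyMatch {
  std::string opName;
  std::string resultType;
};

// Modelled on the shape of customTypeResultConstraint in the diff: succeed
// and push one result only if the op is named "test.success_op".
bool toyTypeResultConstraint(ToyResultList &results,
                             const std::string &opName) {
  if (opName == "test.success_op") {
    results.push_back("f32");
    return true;
  }
  // On failure no result is produced.
  return false;
}

// Toy matcher: apply the constraint to every op and record a match only when
// it succeeds, forwarding the computed result to the (imaginary) rewriter.
std::vector<ToyMatch> toyMatchAll(const std::vector<std::string> &opNames) {
  std::vector<ToyMatch> matches;
  for (const std::string &opName : opNames) {
    ToyResultList results;
    if (!toyTypeResultConstraint(results, opName))
      continue;
    assert(results.values.size() == 1 &&
           "constraint succeeded but returned unexpected number of results");
    matches.push_back({opName, results.values.front()});
  }
  return matches;
}
```

Calling `toyMatchAll({"test.failure_op", "test.success_op"})` in this model records a single match for `test.success_op` carrying `f32`. That corresponds to the single `"test.replaced_by_pattern"() : () -> f32` the FileCheck lines expect. The point of the design is that the value is computed once during matching and reused by the rewrite, rather than recomputed.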


