[Mlir-commits] [mlir] 8ec28af - Reapply "[mlir][PDL] Add support for native constraints with results (#82760)"

Matthias Gehre llvmlistbot at llvm.org
Sat Mar 2 11:57:43 PST 2024


Author: Matthias Gehre
Date: 2024-03-02T20:57:30+01:00
New Revision: 8ec28af8eaff5acd0df3e53340159c034f08533d

URL: https://github.com/llvm/llvm-project/commit/8ec28af8eaff5acd0df3e53340159c034f08533d
DIFF: https://github.com/llvm/llvm-project/commit/8ec28af8eaff5acd0df3e53340159c034f08533d.diff

LOG: Reapply "[mlir][PDL] Add support for native constraints with results (#82760)"

with a small stack-use-after-scope fix in getConstraintPredicates()

This reverts commit c80e6edba4a9593f0587e27fa0ac825ebe174afd.

Added: 
    mlir/test/Conversion/PDLToPDLInterp/use-constraint-result.mlir

Modified: 
    mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
    mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
    mlir/include/mlir/IR/PDLPatternMatch.h.inc
    mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
    mlir/lib/Conversion/PDLToPDLInterp/Predicate.h
    mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
    mlir/lib/Dialect/PDL/IR/PDL.cpp
    mlir/lib/Rewrite/ByteCode.cpp
    mlir/lib/Tools/PDLL/Parser/Parser.cpp
    mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
    mlir/test/Dialect/PDL/ops.mlir
    mlir/test/Rewrite/pdl-bytecode.mlir
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
    mlir/test/lib/Rewrite/TestPDLByteCode.cpp
    mlir/test/mlir-pdll/Parser/constraint-failure.pdll
    mlir/test/mlir-pdll/Parser/constraint.pdll
    mlir/test/python/dialects/pdl_ops.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
index 4e9ebccba77d88..1e108c3d8ac77a 100644
--- a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
+++ b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
@@ -35,20 +35,25 @@ def PDL_ApplyNativeConstraintOp
   let description = [{
     `pdl.apply_native_constraint` operations apply a native C++ constraint, that
     has been registered externally with the consumer of PDL, to a given set of
-    entities.
+    entities and optionally return a number of values.
 
     Example:
 
     ```mlir
     // Apply `myConstraint` to the entities defined by `input`, `attr`, and `op`.
     pdl.apply_native_constraint "myConstraint"(%input, %attr, %op : !pdl.value, !pdl.attribute, !pdl.operation)
+    // Apply constraint `with_result` to `root`. This constraint returns an attribute.
+    %attr = pdl.apply_native_constraint "with_result"(%root : !pdl.operation) : !pdl.attribute
     ```
   }];
 
   let arguments = (ins StrAttr:$name, 
                        Variadic<PDL_AnyType>:$args, 
                        DefaultValuedAttr<BoolAttr, "false">:$isNegated);
-  let assemblyFormat = "$name `(` $args `:` type($args) `)` attr-dict";
+  let results = (outs Variadic<PDL_AnyType>:$results);
+  let assemblyFormat = [{
+    $name `(` $args `:` type($args) `)` (`:`  type($results)^ )? attr-dict
+  }];
   let hasVerifier = 1;
 }
 

diff  --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
index 48f625bd1fa3fd..901acc0e6733bb 100644
--- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
+++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
@@ -88,7 +88,9 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> {
   let description = [{
     `pdl_interp.apply_constraint` operations apply a generic constraint, that
     has been registered with the interpreter, with a given set of positional
-    values. On success, this operation branches to the true destination,
+    values.
+    The constraint function may return any number of results.
+    On success, this operation branches to the true destination,
     otherwise the false destination is taken. This behavior can be reversed
     by setting the attribute `isNegated` to true.
 
@@ -104,8 +106,10 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> {
   let arguments = (ins StrAttr:$name, 
                        Variadic<PDL_AnyType>:$args,
                        DefaultValuedAttr<BoolAttr, "false">:$isNegated);
+  let results = (outs Variadic<PDL_AnyType>:$results);
   let assemblyFormat = [{
-    $name `(` $args `:` type($args) `)` attr-dict `->` successors
+    $name `(` $args `:` type($args) `)` (`:` type($results)^)? attr-dict 
+    `->` successors
   }];
 }
 

diff  --git a/mlir/include/mlir/IR/PDLPatternMatch.h.inc b/mlir/include/mlir/IR/PDLPatternMatch.h.inc
index a215da8cb6431d..66286ed7a4c898 100644
--- a/mlir/include/mlir/IR/PDLPatternMatch.h.inc
+++ b/mlir/include/mlir/IR/PDLPatternMatch.h.inc
@@ -318,8 +318,9 @@ protected:
 /// A generic PDL pattern constraint function. This function applies a
 /// constraint to a given set of opaque PDLValue entities. Returns success if
 /// the constraint successfully held, failure otherwise.
-using PDLConstraintFunction =
-    std::function<LogicalResult(PatternRewriter &, ArrayRef<PDLValue>)>;
+using PDLConstraintFunction = std::function<LogicalResult(
+    PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
+
 /// A native PDL rewrite function. This function performs a rewrite on the
 /// given set of values. Any results from this rewrite that should be passed
 /// back to PDL should be added to the provided result list. This method is only
@@ -726,7 +727,7 @@ std::enable_if_t<
     PDLConstraintFunction>
 buildConstraintFn(ConstraintFnT &&constraintFn) {
   return [constraintFn = std::forward<ConstraintFnT>(constraintFn)](
-             PatternRewriter &rewriter,
+             PatternRewriter &rewriter, PDLResultList &,
              ArrayRef<PDLValue> values) -> LogicalResult {
     auto argIndices = std::make_index_sequence<
         llvm::function_traits<ConstraintFnT>::num_args - 1>();
@@ -842,10 +843,13 @@ public:
   /// Register a constraint function with PDL. A constraint function may be
   /// specified in one of two ways:
   ///
-  ///   * `LogicalResult (PatternRewriter &, ArrayRef<PDLValue>)`
+  ///   * `LogicalResult (PatternRewriter &,
+  ///                     PDLResultList &,
+  ///                     ArrayRef<PDLValue>)`
   ///
   ///   In this overload the arguments of the constraint function are passed via
-  ///   the low-level PDLValue form.
+  ///   the low-level PDLValue form, and the results are manually appended to
+  ///   the given result list.
   ///
   ///   * `LogicalResult (PatternRewriter &, ValueTs... values)`
   ///
@@ -960,8 +964,8 @@ public:
   }
 };
 class PDLResultList {};
-using PDLConstraintFunction =
-    std::function<LogicalResult(PatternRewriter &, ArrayRef<PDLValue>)>;
+using PDLConstraintFunction = std::function<LogicalResult(
+    PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
 using PDLRewriteFunction = std::function<LogicalResult(
     PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
 

diff  --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
index e911631a4bc52a..b00cd0dee3ae80 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
@@ -50,7 +50,8 @@ struct PatternLowering {
 
   /// Generate interpreter operations for the tree rooted at the given matcher
   /// node, in the specified region.
-  Block *generateMatcher(MatcherNode &node, Region &region);
+  Block *generateMatcher(MatcherNode &node, Region &region,
+                         Block *block = nullptr);
 
   /// Get or create an access to the provided positional value in the current
   /// block. This operation may mutate the provided block pointer if nested
@@ -148,6 +149,10 @@ struct PatternLowering {
   /// A mapping between pattern operations and the corresponding configuration
   /// set.
   DenseMap<Operation *, PDLPatternConfigSet *> *configMap;
+
+  /// A mapping from a constraint question to the ApplyConstraintOp
+  /// that implements it.
+  DenseMap<ConstraintQuestion *, pdl_interp::ApplyConstraintOp> constraintOpMap;
 };
 } // namespace
 
@@ -182,9 +187,11 @@ void PatternLowering::lower(ModuleOp module) {
   firstMatcherBlock->erase();
 }
 
-Block *PatternLowering::generateMatcher(MatcherNode &node, Region &region) {
+Block *PatternLowering::generateMatcher(MatcherNode &node, Region &region,
+                                        Block *block) {
   // Push a new scope for the values used by this matcher.
-  Block *block = &region.emplaceBlock();
+  if (!block)
+    block = &region.emplaceBlock();
   ValueMapScope scope(values);
 
   // If this is the return node, simply insert the corresponding interpreter
@@ -364,6 +371,15 @@ Value PatternLowering::getValueAt(Block *&currentBlock, Position *pos) {
           loc, cast<ArrayAttr>(rawTypeAttr));
     break;
   }
+  case Predicates::ConstraintResultPos: {
+    // Due to the order of traversal, the ApplyConstraintOp has already been
+    // created and we can find it in constraintOpMap.
+    auto *constrResPos = cast<ConstraintPosition>(pos);
+    auto i = constraintOpMap.find(constrResPos->getQuestion());
+    assert(i != constraintOpMap.end());
+    value = i->second->getResult(constrResPos->getIndex());
+    break;
+  }
   default:
     llvm_unreachable("Generating unknown Position getter");
     break;
@@ -390,12 +406,11 @@ void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
       args.push_back(getValueAt(currentBlock, position));
   }
 
-  // Generate the matcher in the current (potentially nested) region
-  // and get the failure successor.
-  Block *success = generateMatcher(*boolNode->getSuccessNode(), *region);
+  // Generate a new block as success successor and get the failure successor.
+  Block *success = &region->emplaceBlock();
   Block *failure = failureBlockStack.back();
 
-  // Finally, create the predicate.
+  // Create the predicate.
   builder.setInsertionPointToEnd(currentBlock);
   Predicates::Kind kind = question->getKind();
   switch (kind) {
@@ -447,14 +462,20 @@ void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
   }
   case Predicates::ConstraintQuestion: {
     auto *cstQuestion = cast<ConstraintQuestion>(question);
-    builder.create<pdl_interp::ApplyConstraintOp>(
-        loc, cstQuestion->getName(), args, cstQuestion->getIsNegated(), success,
-        failure);
+    auto applyConstraintOp = builder.create<pdl_interp::ApplyConstraintOp>(
+        loc, cstQuestion->getResultTypes(), cstQuestion->getName(), args,
+        cstQuestion->getIsNegated(), success, failure);
+    // Overwrite: the same question may be emitted again in another branch.
+    constraintOpMap[cstQuestion] = applyConstraintOp;
     break;
   }
   default:
     llvm_unreachable("Generating unknown Predicate operation");
   }
+
+  // Generate the matcher in the current (potentially nested) region.
+  // This might use the results of the current predicate.
+  generateMatcher(*boolNode->getSuccessNode(), *region, success);
 }
 
 template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy>

diff  --git a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h
index 2c9b63f86d6efa..5ad2c477573a5b 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h
+++ b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h
@@ -47,6 +47,7 @@ enum Kind : unsigned {
   OperandPos,
   OperandGroupPos,
   AttributePos,
+  ConstraintResultPos,
   ResultPos,
   ResultGroupPos,
   TypePos,
@@ -279,6 +280,28 @@ struct OperationPosition : public PredicateBase<OperationPosition, Position,
   bool isOperandDefiningOp() const;
 };
 
+//===----------------------------------------------------------------------===//
+// ConstraintPosition
+
+struct ConstraintQuestion;
+
+/// A position describing the result of a native constraint. It saves the
+/// corresponding ConstraintQuestion and result index to enable referring
+/// back to them
+struct ConstraintPosition
+    : public PredicateBase<ConstraintPosition, Position,
+                           std::pair<ConstraintQuestion *, unsigned>,
+                           Predicates::ConstraintResultPos> {
+  using PredicateBase::PredicateBase;
+
+  /// Returns the ConstraintQuestion to enable keeping track of the native
+  /// constraint this position stems from.
+  ConstraintQuestion *getQuestion() const { return key.first; }
+
+  // Returns the result index of this position
+  unsigned getIndex() const { return key.second; }
+};
+
 //===----------------------------------------------------------------------===//
 // ResultPosition
 
@@ -447,11 +470,13 @@ struct AttributeQuestion
     : public PredicateBase<AttributeQuestion, Qualifier, void,
                            Predicates::AttributeQuestion> {};
 
-/// Apply a parameterized constraint to multiple position values.
+/// Apply a parameterized constraint to multiple position values and possibly
+/// produce results.
 struct ConstraintQuestion
-    : public PredicateBase<ConstraintQuestion, Qualifier,
-                           std::tuple<StringRef, ArrayRef<Position *>, bool>,
-                           Predicates::ConstraintQuestion> {
+    : public PredicateBase<
+          ConstraintQuestion, Qualifier,
+          std::tuple<StringRef, ArrayRef<Position *>, ArrayRef<Type>, bool>,
+          Predicates::ConstraintQuestion> {
   using Base::Base;
 
   /// Return the name of the constraint.
@@ -460,15 +485,19 @@ struct ConstraintQuestion
   /// Return the arguments of the constraint.
   ArrayRef<Position *> getArgs() const { return std::get<1>(key); }
 
+  /// Return the result types of the constraint.
+  ArrayRef<Type> getResultTypes() const { return std::get<2>(key); }
+
   /// Return the negation status of the constraint.
-  bool getIsNegated() const { return std::get<2>(key); }
+  bool getIsNegated() const { return std::get<3>(key); }
 
   /// Construct an instance with the given storage allocator.
   static ConstraintQuestion *construct(StorageUniquer::StorageAllocator &alloc,
                                        KeyTy key) {
     return Base::construct(alloc, KeyTy{alloc.copyInto(std::get<0>(key)),
                                         alloc.copyInto(std::get<1>(key)),
-                                        std::get<2>(key)});
+                                        alloc.copyInto(std::get<2>(key)),
+                                        std::get<3>(key)});
   }
 
   /// Returns a hash suitable for the given keytype.
@@ -526,6 +555,7 @@ class PredicateUniquer : public StorageUniquer {
     // Register the types of Positions with the uniquer.
     registerParametricStorageType<AttributePosition>();
     registerParametricStorageType<AttributeLiteralPosition>();
+    registerParametricStorageType<ConstraintPosition>();
     registerParametricStorageType<ForEachPosition>();
     registerParametricStorageType<OperandPosition>();
     registerParametricStorageType<OperandGroupPosition>();
@@ -588,6 +618,12 @@ class PredicateBuilder {
     return OperationPosition::get(uniquer, p);
   }
 
+  // Returns a position for a new value created by a constraint.
+  ConstraintPosition *getConstraintPosition(ConstraintQuestion *q,
+                                            unsigned index) {
+    return ConstraintPosition::get(uniquer, std::make_pair(q, index));
+  }
+
   /// Returns an attribute position for an attribute of the given operation.
   Position *getAttribute(OperationPosition *p, StringRef name) {
     return AttributePosition::get(uniquer, p, StringAttr::get(ctx, name));
@@ -673,11 +709,11 @@ class PredicateBuilder {
   }
 
   /// Create a predicate that applies a generic constraint.
-  Predicate getConstraint(StringRef name, ArrayRef<Position *> pos,
-                          bool isNegated) {
-    return {
-        ConstraintQuestion::get(uniquer, std::make_tuple(name, pos, isNegated)),
-        TrueAnswer::get(uniquer)};
+  Predicate getConstraint(StringRef name, ArrayRef<Position *> args,
+                          ArrayRef<Type> resultTypes, bool isNegated) {
+    return {ConstraintQuestion::get(
+                uniquer, std::make_tuple(name, args, resultTypes, isNegated)),
+            TrueAnswer::get(uniquer)};
   }
 
   /// Create a predicate comparing a value with null.

diff  --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
index a9c3b0a71ef0d7..419ea863919786 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
@@ -15,6 +15,7 @@
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
 #include <queue>
@@ -49,14 +50,15 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList,
                               DenseMap<Value, Position *> &inputs,
                               AttributePosition *pos) {
   assert(isa<pdl::AttributeType>(val.getType()) && "expected attribute type");
-  pdl::AttributeOp attr = cast<pdl::AttributeOp>(val.getDefiningOp());
   predList.emplace_back(pos, builder.getIsNotNull());
 
-  // If the attribute has a type or value, add a constraint.
-  if (Value type = attr.getValueType())
-    getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
-  else if (Attribute value = attr.getValueAttr())
-    predList.emplace_back(pos, builder.getAttributeConstraint(value));
+  if (auto attr = dyn_cast<pdl::AttributeOp>(val.getDefiningOp())) {
+    // If the attribute has a type or value, add a constraint.
+    if (Value type = attr.getValueType())
+      getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
+    else if (Attribute value = attr.getValueAttr())
+      predList.emplace_back(pos, builder.getAttributeConstraint(value));
+  }
 }
 
 /// Collect all of the predicates for the given operand position.
@@ -272,8 +274,27 @@ static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
   // Push the constraint to the furthest position.
   Position *pos = *std::max_element(allPositions.begin(), allPositions.end(),
                                     comparePosDepth);
-  PredicateBuilder::Predicate pred =
-      builder.getConstraint(op.getName(), allPositions, op.getIsNegated());
+  ResultRange results = op.getResults();
+  PredicateBuilder::Predicate pred = builder.getConstraint(
+      op.getName(), allPositions, SmallVector<Type>(results.getTypes()),
+      op.getIsNegated());
+
+  // Register a position for each result so later predicates can refer to it.
+  auto *question = cast<ConstraintQuestion>(pred.first);
+  for (auto [i, result] : llvm::enumerate(results)) {
+    ConstraintPosition *resultPos = builder.getConstraintPosition(question, i);
+    auto [it, inserted] = inputs.try_emplace(result, resultPos);
+    // If this is an input value that has been visited in the tree, add a
+    // constraint to ensure that both instances refer to the same value.
+    if (!inserted) {
+      Position *first = resultPos;
+      Position *second = it->second;
+      if (comparePosDepth(second, first))
+        std::tie(second, first) = std::make_pair(first, second);
+
+      predList.emplace_back(second, builder.getEqualTo(first));
+    }
+  }
   predList.emplace_back(pos, pred);
 }
 
@@ -875,6 +896,49 @@ static void insertExitNode(std::unique_ptr<MatcherNode> *root) {
   *root = std::make_unique<ExitNode>();
 }
 
+/// Sorts the range begin/end with the partial order given by cmp.
+template <typename Iterator, typename Compare>
+static void stableTopologicalSort(Iterator begin, Iterator end, Compare cmp) {
+  while (begin != end) {
+    // Cannot compute sortBeforeOthers in the predicate of stable_partition
+    // because stable_partition will not keep the [begin, end) range intact
+    // while it runs.
+    llvm::SmallPtrSet<typename Iterator::value_type, 16> sortBeforeOthers;
+    for (auto i = begin; i != end; ++i) {
+      if (std::none_of(begin, end, [&](auto const &b) { return cmp(b, *i); }))
+        sortBeforeOthers.insert(*i);
+    }
+
+    auto const next = std::stable_partition(begin, end, [&](auto const &a) {
+      return sortBeforeOthers.contains(a);
+    });
+    assert(next != begin && "not a partial ordering");
+    begin = next;
+  }
+}
+
+/// Returns true if 'b' depends on a result of 'a'.
+static bool dependsOn(OrderedPredicate *a, OrderedPredicate *b) {
+  auto *cqa = dyn_cast<ConstraintQuestion>(a->question);
+  if (!cqa)
+    return false;
+
+  auto positionDependsOnA = [&](Position *p) {
+    auto *cp = dyn_cast<ConstraintPosition>(p);
+    return cp && cp->getQuestion() == cqa;
+  };
+
+  if (auto *cqb = dyn_cast<ConstraintQuestion>(b->question)) {
+    // Does any argument of b use a?
+    return llvm::any_of(cqb->getArgs(), positionDependsOnA);
+  }
+  if (auto *equalTo = dyn_cast<EqualToQuestion>(b->question)) {
+    return positionDependsOnA(b->position) ||
+           positionDependsOnA(equalTo->getValue());
+  }
+  return positionDependsOnA(b->position);
+}
+
 /// Given a module containing PDL pattern operations, generate a matcher tree
 /// using the patterns within the given module and return the root matcher node.
 std::unique_ptr<MatcherNode>
@@ -955,6 +1019,10 @@ MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder,
     return *lhs < *rhs;
   });
 
+  // Mostly keep the now established order, but also ensure that
+  // ConstraintQuestions come after the results they use.
+  stableTopologicalSort(ordered.begin(), ordered.end(), dependsOn);
+
   // Build the matchers for each of the pattern predicate lists.
   std::unique_ptr<MatcherNode> root;
   for (OrderedPredicateList &list : lists)

diff  --git a/mlir/lib/Dialect/PDL/IR/PDL.cpp b/mlir/lib/Dialect/PDL/IR/PDL.cpp
index d5f34679f06c60..428b19f4c74af8 100644
--- a/mlir/lib/Dialect/PDL/IR/PDL.cpp
+++ b/mlir/lib/Dialect/PDL/IR/PDL.cpp
@@ -94,6 +94,12 @@ static void visit(Operation *op, DenseSet<Operation *> &visited) {
 LogicalResult ApplyNativeConstraintOp::verify() {
   if (getNumOperands() == 0)
     return emitOpError("expected at least one argument");
+  if (llvm::any_of(getResults(), [](OpResult result) {
+        return isa<OperationType>(result.getType());
+      })) {
+    return emitOpError(
+        "returning an operation from a constraint is not supported");
+  }
   return success();
 }
 

diff  --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp
index 6e6992dcdeea78..559ce7f466a83d 100644
--- a/mlir/lib/Rewrite/ByteCode.cpp
+++ b/mlir/lib/Rewrite/ByteCode.cpp
@@ -769,11 +769,25 @@ void Generator::generate(Operation *op, ByteCodeWriter &writer) {
 
 void Generator::generate(pdl_interp::ApplyConstraintOp op,
                          ByteCodeWriter &writer) {
   assert(constraintToMemIndex.count(op.getName()) &&
          "expected index for constraint function");
+  // A constraint may also produce results; they are encoded after isNegated.
   writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
   writer.appendPDLValueList(op.getArgs());
   writer.append(ByteCodeField(op.getIsNegated()));
+  ResultRange results = op.getResults();
+  writer.append(ByteCodeField(results.size()));
+  for (Value result : results) {
+    // We record the expected kind of the result, so that the executor can
+    // verify the value returned by the native constraint function and skip
+    // the result's bytecode slots when the constraint fails.
+    writer.appendPDLValueKind(result);
+
+    // Range results also need to append the range storage index.
+    if (isa<pdl::RangeType>(result.getType()))
+      writer.append(getRangeStorageIndex(result));
+    writer.append(result);
+  }
   writer.append(op.getSuccessors());
 }
 void Generator::generate(pdl_interp::ApplyRewriteOp op,
@@ -786,11 +800,9 @@ void Generator::generate(pdl_interp::ApplyRewriteOp op,
   ResultRange results = op.getResults();
   writer.append(ByteCodeField(results.size()));
   for (Value result : results) {
-    // In debug mode we also record the expected kind of the result, so that we
+    // We record the expected kind of the result, so that we
     // can provide extra verification of the native rewrite function.
-#ifndef NDEBUG
     writer.appendPDLValueKind(result);
-#endif
 
     // Range results also need to append the range storage index.
     if (isa<pdl::RangeType>(result.getType()))
@@ -1076,6 +1088,28 @@ void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
 // ByteCode Execution
 
 namespace {
+/// This class is an instantiation of the PDLResultList that provides access to
+/// the returned results. This API is not on `PDLResultList` to avoid
+/// overexposing access to information specific solely to the ByteCode.
+class ByteCodeRewriteResultList : public PDLResultList {
+public:
+  ByteCodeRewriteResultList(unsigned maxNumResults)
+      : PDLResultList(maxNumResults) {}
+
+  /// Return the list of PDL results.
+  MutableArrayRef<PDLValue> getResults() { return results; }
+
+  /// Return the type ranges allocated by this list.
+  MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
+    return allocatedTypeRanges;
+  }
+
+  /// Return the value ranges allocated by this list.
+  MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
+    return allocatedValueRanges;
+  }
+};
+
 /// This class provides support for executing a bytecode stream.
 class ByteCodeExecutor {
 public:
@@ -1152,6 +1186,9 @@ class ByteCodeExecutor {
   void executeSwitchResultCount();
   void executeSwitchType();
   void executeSwitchTypes();
+  void processNativeFunResults(ByteCodeRewriteResultList &results,
+                               unsigned numResults,
+                               LogicalResult &rewriteResult);
 
   /// Pushes a code iterator to the stack.
   void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); }
@@ -1225,6 +1262,8 @@ class ByteCodeExecutor {
     return T::getFromOpaquePointer(pointer);
   }
 
+  void skip(size_t skipN) { curCodeIt += skipN; }
+
   /// Jump to a specific successor based on a predicate value.
   void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
   /// Jump to a specific successor based on a destination index.
@@ -1381,33 +1420,11 @@ class ByteCodeExecutor {
   ArrayRef<PDLConstraintFunction> constraintFunctions;
   ArrayRef<PDLRewriteFunction> rewriteFunctions;
 };
-
-/// This class is an instantiation of the PDLResultList that provides access to
-/// the returned results. This API is not on `PDLResultList` to avoid
-/// overexposing access to information specific solely to the ByteCode.
-class ByteCodeRewriteResultList : public PDLResultList {
-public:
-  ByteCodeRewriteResultList(unsigned maxNumResults)
-      : PDLResultList(maxNumResults) {}
-
-  /// Return the list of PDL results.
-  MutableArrayRef<PDLValue> getResults() { return results; }
-
-  /// Return the type ranges allocated by this list.
-  MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
-    return allocatedTypeRanges;
-  }
-
-  /// Return the value ranges allocated by this list.
-  MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
-    return allocatedValueRanges;
-  }
-};
 } // namespace
 
 void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
   LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
-  const PDLConstraintFunction &constraintFn = constraintFunctions[read()];
+  ByteCodeField constraintFnIdx = read();
   SmallVector<PDLValue, 16> args;
   readList<PDLValue>(args);
 
@@ -1422,8 +1439,29 @@ void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
     llvm::dbgs() << "  * isNegated: " << isNegated << "\n";
     llvm::interleaveComma(args, llvm::dbgs());
   });
-  // Invoke the constraint and jump to the proper destination.
-  selectJump(isNegated != succeeded(constraintFn(rewriter, args)));
+
+  ByteCodeField numResults = read();
+  const PDLConstraintFunction &constraintFn = constraintFunctions[constraintFnIdx];
+  ByteCodeRewriteResultList results(numResults);
+  LogicalResult rewriteResult = constraintFn(rewriter, results, args);
+  ArrayRef<PDLValue> constraintResults = results.getResults();
+  LLVM_DEBUG({
+    if (succeeded(rewriteResult)) {
+      llvm::dbgs() << "  * Constraint succeeded\n";
+      llvm::dbgs() << "  * Results: ";
+      llvm::interleaveComma(constraintResults, llvm::dbgs());
+      llvm::dbgs() << "\n";
+    } else {
+      llvm::dbgs() << "  * Constraint failed\n";
+    }
+  });
+  assert((failed(rewriteResult) || constraintResults.size() == numResults) &&
+         "native PDL constraint function succeeded but returned "
+         "unexpected number of results");
+  processNativeFunResults(results, numResults, rewriteResult);
+
+  // Jump to the proper destination depending on the constraint outcome.
+  selectJump(isNegated != succeeded(rewriteResult));
 }
 
 LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
@@ -1445,16 +1483,39 @@ LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
   assert(results.getResults().size() == numResults &&
          "native PDL rewrite function returned unexpected number of results");
 
-  // Store the results in the bytecode memory.
-  for (PDLValue &result : results.getResults()) {
-    LLVM_DEBUG(llvm::dbgs() << "  * Result: " << result << "\n");
+  processNativeFunResults(results, numResults, rewriteResult);
 
-// In debug mode we also verify the expected kind of the result.
-#ifndef NDEBUG
-    assert(result.getKind() == read<PDLValue::Kind>() &&
-           "native PDL rewrite function returned an unexpected type of result");
-#endif
+  if (failed(rewriteResult)) {
+    LLVM_DEBUG(llvm::dbgs() << "  - Failed");
+    return failure();
+  }
+  return success();
+}
 
+void ByteCodeExecutor::processNativeFunResults(
+    ByteCodeRewriteResultList &results, unsigned numResults,
+    LogicalResult &rewriteResult) {
+  // Store the results in the bytecode memory or handle missing results on
+  // failure.
+  for (unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) {
+    PDLValue::Kind resultKind = read<PDLValue::Kind>();
+
+    // On failure there is nothing to store: skip this result's bytecode slots
+    // (ranges also carry a storage index) and continue with the next result.
+    if (failed(rewriteResult)) {
+      if (resultKind == PDLValue::Kind::TypeRange ||
+          resultKind == PDLValue::Kind::ValueRange) {
+        skip(2);
+      } else {
+        skip(1);
+      }
+      continue;
+    }
+    PDLValue result = results.getResults()[resultIdx];
+    LLVM_DEBUG(llvm::dbgs() << "  * Result: " << result << "\n");
+    assert(result.getKind() == resultKind &&
+           "native PDL rewrite function returned an unexpected type of "
+           "result");
     // If the result is a range, we need to copy it over to the bytecodes
     // range memory.
     if (std::optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) {
@@ -1476,13 +1537,6 @@ LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
     allocatedTypeRangeMemory.push_back(std::move(it));
   for (auto &it : results.getAllocatedValueRanges())
     allocatedValueRangeMemory.push_back(std::move(it));
-
-  // Process the result of the rewrite.
-  if (failed(rewriteResult)) {
-    LLVM_DEBUG(llvm::dbgs() << "  - Failed");
-    return failure();
-  }
-  return success();
 }
 
 void ByteCodeExecutor::executeAreEqual() {

diff  --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
index 97ff8bd0d8584d..206781ed146692 100644
--- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp
+++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
@@ -1362,12 +1362,6 @@ FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl(
   if (failed(parseToken(Token::semicolon,
                         "expected `;` after native declaration")))
     return failure();
-  // TODO: PDL should be able to support constraint results in certain
-  // situations, we should revise this.
-  if (std::is_same<ast::UserConstraintDecl, T>::value && !results.empty()) {
-    return emitError(
-        "native Constraints currently do not support returning results");
-  }
   return T::createNative(ctx, name, arguments, results, optCodeStr, resultType);
 }
 

diff  --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
index 02bb8316c02db0..92afb765b5ab4e 100644
--- a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
+++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
@@ -79,6 +79,57 @@ module @constraints {
 
 // -----
 
+// CHECK-LABEL: module @constraint_with_result
+module @constraint_with_result {
+  // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
+  // CHECK: %[[ATTR:.*]] = pdl_interp.apply_constraint "check_op_and_get_attr_constr"(%[[ROOT]]
+  // CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter(%[[ROOT]], %[[ATTR]] : !pdl.operation, !pdl.attribute)
+  pdl.pattern : benefit(1) {
+    %root = operation
+    %attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute
+    rewrite %root with "rewriter"(%attr : !pdl.attribute)
+  }
+}
+
+// -----
+
+// CHECK-LABEL: module @constraint_with_unused_result
+module @constraint_with_unused_result {
+  // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
+  // CHECK: %[[ATTR:.*]] = pdl_interp.apply_constraint "check_op_and_get_attr_constr"(%[[ROOT]]
+  // CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter(%[[ROOT]] : !pdl.operation)
+  pdl.pattern : benefit(1) {
+    %root = operation
+    %attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute
+    rewrite %root with "rewriter"
+  }
+}
+
+// -----
+
+// CHECK-LABEL: module @constraint_with_result_multiple
+module @constraint_with_result_multiple {
+  // Check that native constraints work as expected even when multiple identical constraints are fused.
+
+  // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
+  // CHECK: %[[ATTR:.*]] = pdl_interp.apply_constraint "check_op_and_get_attr_constr"(%[[ROOT]]
+  // CHECK-NOT: pdl_interp.apply_constraint "check_op_and_get_attr_constr"
+  // CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter_0(%[[ROOT]], %[[ATTR]]  : !pdl.operation, !pdl.attribute)
+  // CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter(%[[ROOT]], %[[ATTR]] : !pdl.operation, !pdl.attribute)
+  pdl.pattern : benefit(1) {
+    %root = operation
+    %attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute
+    rewrite %root with "rewriter"(%attr : !pdl.attribute)
+  }
+  pdl.pattern : benefit(1) {
+    %root = operation
+    %attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute
+    rewrite %root with "rewriter"(%attr : !pdl.attribute)
+  }
+}
+
+// -----
+
 // CHECK-LABEL: module @negated_constraint
 module @negated_constraint {
   // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)

diff  --git a/mlir/test/Conversion/PDLToPDLInterp/use-constraint-result.mlir b/mlir/test/Conversion/PDLToPDLInterp/use-constraint-result.mlir
new file mode 100644
index 00000000000000..4511baff2015a4
--- /dev/null
+++ b/mlir/test/Conversion/PDLToPDLInterp/use-constraint-result.mlir
@@ -0,0 +1,77 @@
+// RUN: mlir-opt -split-input-file -convert-pdl-to-pdl-interp %s | FileCheck %s
+
+// Ensure that the dependency between return_attr_constraint and
+// use_attr_constraint causes them to be emitted in the correct order.
+// CHECK-LABEL: matcher
+// CHECK: apply_constraint "return_attr_constraint"
+// CHECK: apply_constraint "use_attr_constraint"
+
+module {
+  pdl.pattern : benefit(1) {
+    %0 = attribute
+    %1 = types
+    %2 = operation "tosa.mul" {"shift" = %0} -> (%1 : !pdl.range<type>)
+    %3 = attribute = 0 : i32
+    %4 = attribute = 1 : i32
+    %5 = apply_native_constraint "return_attr_constraint"(%3, %4 : !pdl.attribute, !pdl.attribute) : !pdl.attribute
+    apply_native_constraint "use_attr_constraint"(%0, %5 : !pdl.attribute, !pdl.attribute)
+    rewrite %2 with "rewriter"
+  }
+}
+
+// -----
+
+// CHECK-LABEL: matcher
+// CHECK: %[[ATTR:.*]] = pdl_interp.get_attribute "attr" of
+// CHECK: %[[CONSTRAINT:.*]] = pdl_interp.apply_constraint "return_attr_constraint"
+// CHECK: pdl_interp.are_equal %[[ATTR:.*]], %[[CONSTRAINT:.*]]
+
+pdl.pattern : benefit(1) {
+    %inputOp = operation
+    %result = result 0 of %inputOp
+    %attr = pdl.apply_native_constraint "return_attr_constraint"(%inputOp : !pdl.operation) : !pdl.attribute
+    %root = operation(%result : !pdl.value) {"attr" = %attr}
+    rewrite %root with "rewriter"(%attr : !pdl.attribute)
+}
+
+// -----
+
+// CHECK-LABEL: matcher
+// CHECK: %[[CONSTRAINT:.*]] = pdl_interp.apply_constraint "return_value_constr"
+// CHECK: %[[VALUE:.*]] = pdl_interp.get_operand 0
+// CHECK: pdl_interp.are_equal %[[VALUE:.*]], %[[CONSTRAINT:.*]]
+
+pdl.pattern : benefit(1) {
+    %attr = attribute = 10
+    %value = pdl.apply_native_constraint "return_value_constr"(%attr: !pdl.attribute) : !pdl.value
+    %root = operation(%value : !pdl.value)
+    rewrite %root with "rewriter"
+}
+
+// -----
+
+// CHECK-LABEL: matcher
+// CHECK: %[[CONSTRAINT:.*]] = pdl_interp.apply_constraint "return_type_constr"
+// CHECK: %[[TYPE:.*]] = pdl_interp.get_value_type of
+// CHECK: pdl_interp.are_equal %[[TYPE:.*]], %[[CONSTRAINT:.*]]
+
+pdl.pattern : benefit(1) {
+    %attr = attribute = 10
+    %type = pdl.apply_native_constraint "return_type_constr"(%attr: !pdl.attribute) : !pdl.type
+    %root = operation -> (%type : !pdl.type)
+    rewrite %root with "rewriter"
+}
+
+// -----
+
+// CHECK-LABEL: matcher
+// CHECK: %[[CONSTRAINT:.*]] = pdl_interp.apply_constraint "return_type_range_constr"
+// CHECK: %[[TYPE:.*]] = pdl_interp.get_value_type of
+// CHECK: pdl_interp.are_equal %[[TYPE:.*]], %[[CONSTRAINT:.*]]
+
+pdl.pattern : benefit(1) {
+    %attr = attribute = 10
+    %types = pdl.apply_native_constraint "return_type_range_constr"(%attr: !pdl.attribute) : !pdl.range<type>
+    %root = operation -> (%types : !pdl.range<type>)
+    rewrite %root with "rewriter"
+}

diff  --git a/mlir/test/Dialect/PDL/ops.mlir b/mlir/test/Dialect/PDL/ops.mlir
index 6e6da5cce446ae..20e40deea5f863 100644
--- a/mlir/test/Dialect/PDL/ops.mlir
+++ b/mlir/test/Dialect/PDL/ops.mlir
@@ -134,6 +134,24 @@ pdl.pattern @apply_rewrite_with_no_results : benefit(1) {
 
 // -----
 
+pdl.pattern @apply_constraint_with_no_results : benefit(1) {
+  %root = operation
+  apply_native_constraint "NativeConstraint"(%root : !pdl.operation)
+  rewrite %root with "rewriter"
+}
+
+// -----
+
+pdl.pattern @apply_constraint_with_results : benefit(1) {
+  %root = operation
+  %attr = apply_native_constraint "NativeConstraint"(%root : !pdl.operation) : !pdl.attribute
+  rewrite %root {
+    apply_native_rewrite "NativeRewrite"(%attr : !pdl.attribute)
+  }
+}
+
+// -----
+
 pdl.pattern @attribute_with_dict : benefit(1) {
   %root = operation
   rewrite %root {

diff  --git a/mlir/test/Rewrite/pdl-bytecode.mlir b/mlir/test/Rewrite/pdl-bytecode.mlir
index ae61c1a079548a..f8e4f2e83b296a 100644
--- a/mlir/test/Rewrite/pdl-bytecode.mlir
+++ b/mlir/test/Rewrite/pdl-bytecode.mlir
@@ -109,6 +109,74 @@ module @ir attributes { test.apply_constraint_3 } {
 
 // -----
 
+// Test returning a type from a native constraint.
+module @patterns {
+  pdl_interp.func @matcher(%root : !pdl.operation) {
+    pdl_interp.check_operation_name of %root is "test.success_op" -> ^pat, ^end
+
+  ^pat:
+    %new_type = pdl_interp.apply_constraint "op_constr_return_type"(%root : !pdl.operation) : !pdl.type -> ^pat2, ^end
+
+  ^pat2:
+    pdl_interp.record_match @rewriters::@success(%root, %new_type : !pdl.operation, !pdl.type) : benefit(1), loc([%root]) -> ^end
+
+  ^end:
+    pdl_interp.finalize
+  }
+
+  module @rewriters {
+    pdl_interp.func @success(%root : !pdl.operation, %new_type : !pdl.type) {
+      %op = pdl_interp.create_operation "test.replaced_by_pattern" -> (%new_type : !pdl.type)
+      pdl_interp.erase %root
+      pdl_interp.finalize
+    }
+  }
+}
+
+// CHECK-LABEL: test.apply_constraint_4
+// CHECK-NOT: "test.replaced_by_pattern"
+// CHECK: "test.replaced_by_pattern"() : () -> f32
+module @ir attributes { test.apply_constraint_4 } {
+  "test.failure_op"() : () -> ()
+  "test.success_op"() : () -> ()
+}
+
+// -----
+
+// Test success and failure cases of native constraints with pdl.range results.
+module @patterns {
+  pdl_interp.func @matcher(%root : !pdl.operation) {
+    pdl_interp.check_operation_name of %root is "test.success_op" -> ^pat, ^end
+  
+  ^pat:
+    %num_results = pdl_interp.create_attribute 2 : i32
+    %types = pdl_interp.apply_constraint "op_constr_return_type_range"(%root, %num_results : !pdl.operation, !pdl.attribute) : !pdl.range<type> -> ^pat1, ^end
+
+  ^pat1:
+    pdl_interp.record_match @rewriters::@success(%root, %types : !pdl.operation, !pdl.range<type>) : benefit(1), loc([%root]) -> ^end
+
+  ^end:
+    pdl_interp.finalize
+  }
+
+  module @rewriters {
+    pdl_interp.func @success(%root : !pdl.operation, %types : !pdl.range<type>) {
+      %op = pdl_interp.create_operation "test.replaced_by_pattern" -> (%types : !pdl.range<type>)
+      pdl_interp.erase %root
+      pdl_interp.finalize
+    }
+  }
+}
+
+// CHECK-LABEL: test.apply_constraint_5
+// CHECK-NOT: "test.replaced_by_pattern"
+// CHECK: "test.replaced_by_pattern"() : () -> (f32, f32)
+module @ir attributes { test.apply_constraint_5 } {
+  "test.failure_op"() : () -> ()
+  "test.success_op"() : () -> ()
+}
+
+// -----
 
 //===----------------------------------------------------------------------===//
 // pdl_interp::ApplyRewriteOp

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index 50caf8f9cfc709..b9424e06bf0318 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -887,7 +887,7 @@ class TestTransformDialectExtension
 #include "TestTransformDialectExtensionTypes.cpp.inc"
         >();
 
-    auto verboseConstraint = [](PatternRewriter &rewriter,
+    auto verboseConstraint = [](PatternRewriter &rewriter, PDLResultList &,
                                 ArrayRef<PDLValue> pdlValues) {
       for (const PDLValue &pdlValue : pdlValues) {
         if (Operation *op = pdlValue.dyn_cast<Operation *>()) {

diff  --git a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
index daa1c371f27c92..56af3c15b905f1 100644
--- a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
+++ b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
@@ -30,6 +30,50 @@ static LogicalResult customMultiEntityVariadicConstraint(
   return success();
 }
 
+// Custom constraint that returns an attribute if the op is named test.success_op.
+static LogicalResult customValueResultConstraint(PatternRewriter &rewriter,
+                                                 PDLResultList &results,
+                                                 ArrayRef<PDLValue> args) {
+  auto *op = args[0].cast<Operation *>();
+  if (op->getName().getStringRef() == "test.success_op") {
+    StringAttr customAttr = rewriter.getStringAttr("test.success");
+    results.push_back(customAttr);
+    return success();
+  }
+  return failure();
+}
+
+// Custom constraint that returns a type if the op is named test.success_op
+static LogicalResult customTypeResultConstraint(PatternRewriter &rewriter,
+                                                PDLResultList &results,
+                                                ArrayRef<PDLValue> args) {
+  auto *op = args[0].cast<Operation *>();
+  if (op->getName().getStringRef() == "test.success_op") {
+    results.push_back(rewriter.getF32Type());
+    return success();
+  }
+  return failure();
+}
+
+// Custom constraint that returns a type range of variable length if the op is
+// named test.success_op
+static LogicalResult customTypeRangeResultConstraint(PatternRewriter &rewriter,
+                                                     PDLResultList &results,
+                                                     ArrayRef<PDLValue> args) {
+  auto *op = args[0].cast<Operation *>();
+  int numTypes = args[1].cast<Attribute>().cast<IntegerAttr>().getInt();
+
+  if (op->getName().getStringRef() == "test.success_op") {
+    SmallVector<Type> types;
+    for (int i = 0; i < numTypes; i++) {
+      types.push_back(rewriter.getF32Type());
+    }
+    results.push_back(TypeRange(types));
+    return success();
+  }
+  return failure();
+}
+
 // Custom creator invoked from PDL.
 static Operation *customCreate(PatternRewriter &rewriter, Operation *op) {
   return rewriter.create(OperationState(op->getLoc(), "test.success"));
@@ -102,6 +146,12 @@ struct TestPDLByteCodePass
                                           customMultiEntityConstraint);
     pdlPattern.registerConstraintFunction("multi_entity_var_constraint",
                                           customMultiEntityVariadicConstraint);
+    pdlPattern.registerConstraintFunction("op_constr_return_attr",
+                                          customValueResultConstraint);
+    pdlPattern.registerConstraintFunction("op_constr_return_type",
+                                          customTypeResultConstraint);
+    pdlPattern.registerConstraintFunction("op_constr_return_type_range",
+                                          customTypeRangeResultConstraint);
     pdlPattern.registerRewriteFunction("creator", customCreate);
     pdlPattern.registerRewriteFunction("var_creator",
                                        customVariadicResultCreate);

diff  --git a/mlir/test/mlir-pdll/Parser/constraint-failure.pdll b/mlir/test/mlir-pdll/Parser/constraint-failure.pdll
index 18877b4bcc50ec..48747d3fa2e681 100644
--- a/mlir/test/mlir-pdll/Parser/constraint-failure.pdll
+++ b/mlir/test/mlir-pdll/Parser/constraint-failure.pdll
@@ -158,8 +158,3 @@ Pattern {
 
 // CHECK: expected `;` after native declaration
 Constraint Foo() [{}]
-
-// -----
-
-// CHECK: native Constraints currently do not support returning results
-Constraint Foo() -> Op;

diff  --git a/mlir/test/mlir-pdll/Parser/constraint.pdll b/mlir/test/mlir-pdll/Parser/constraint.pdll
index 1c0a015ab4a7b4..e2a52ff130cc84 100644
--- a/mlir/test/mlir-pdll/Parser/constraint.pdll
+++ b/mlir/test/mlir-pdll/Parser/constraint.pdll
@@ -12,6 +12,14 @@ Constraint Foo() [{ /* Native Code */ }];
 
 // -----
 
+// Test that native constraints support returning results.
+
+// CHECK:  Module
+// CHECK:  `-UserConstraintDecl {{.*}} Name<Foo> ResultType<Attr>
+Constraint Foo() -> Attr;
+
+// -----
+
 // CHECK: Module
 // CHECK: `-UserConstraintDecl {{.*}} Name<Foo> ResultType<Value>
 // CHECK:   `Inputs`

diff  --git a/mlir/test/python/dialects/pdl_ops.py b/mlir/test/python/dialects/pdl_ops.py
index 0d364f9222a657..95cb25c14873db 100644
--- a/mlir/test/python/dialects/pdl_ops.py
+++ b/mlir/test/python/dialects/pdl_ops.py
@@ -298,6 +298,6 @@ def test_apply_native_constraint():
     pattern = PatternOp(1)
     with InsertionPoint(pattern.body):
         resultType = TypeOp()
-        ApplyNativeConstraintOp("typeConstraint", args=[resultType])
+        ApplyNativeConstraintOp([], "typeConstraint", args=[resultType])
         root = OperationOp(types=[resultType])
         RewriteOp(root, name="rewrite")


        


More information about the Mlir-commits mailing list