[Mlir-commits] [mlir] c80e6ed - Revert "[mlir][PDL] Add support for native constraints with results (#82760)"

Matthias Gehre llvmlistbot at llvm.org
Thu Feb 29 22:44:58 PST 2024


Author: Matthias Gehre
Date: 2024-03-01T07:44:30+01:00
New Revision: c80e6edba4a9593f0587e27fa0ac825ebe174afd

URL: https://github.com/llvm/llvm-project/commit/c80e6edba4a9593f0587e27fa0ac825ebe174afd
DIFF: https://github.com/llvm/llvm-project/commit/c80e6edba4a9593f0587e27fa0ac825ebe174afd.diff

LOG: Revert "[mlir][PDL] Add support for native constraints with results (#82760)"

Due to a buildbot failure: https://lab.llvm.org/buildbot/#/builders/88/builds/72130

This reverts commit dca32a3b594b3c91f9766a9312b5d82534910fa1.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
    mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
    mlir/include/mlir/IR/PDLPatternMatch.h.inc
    mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
    mlir/lib/Conversion/PDLToPDLInterp/Predicate.h
    mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
    mlir/lib/Dialect/PDL/IR/PDL.cpp
    mlir/lib/Rewrite/ByteCode.cpp
    mlir/lib/Tools/PDLL/Parser/Parser.cpp
    mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
    mlir/test/Dialect/PDL/ops.mlir
    mlir/test/Rewrite/pdl-bytecode.mlir
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
    mlir/test/lib/Rewrite/TestPDLByteCode.cpp
    mlir/test/mlir-pdll/Parser/constraint-failure.pdll
    mlir/test/mlir-pdll/Parser/constraint.pdll
    mlir/test/python/dialects/pdl_ops.py

Removed: 
    mlir/test/Conversion/PDLToPDLInterp/use-constraint-result.mlir


################################################################################
diff  --git a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
index 1e108c3d8ac77a..4e9ebccba77d88 100644
--- a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
+++ b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
@@ -35,25 +35,20 @@ def PDL_ApplyNativeConstraintOp
   let description = [{
     `pdl.apply_native_constraint` operations apply a native C++ constraint, that
     has been registered externally with the consumer of PDL, to a given set of
-    entities and optionally return a number of values.
+    entities.
 
     Example:
 
     ```mlir
     // Apply `myConstraint` to the entities defined by `input`, `attr`, and `op`.
     pdl.apply_native_constraint "myConstraint"(%input, %attr, %op : !pdl.value, !pdl.attribute, !pdl.operation)
-    // Apply constraint `with_result` to `root`. This constraint returns an attribute.
-    %attr = pdl.apply_native_constraint "with_result"(%root : !pdl.operation) : !pdl.attribute
     ```
   }];
 
   let arguments = (ins StrAttr:$name, 
                        Variadic<PDL_AnyType>:$args, 
                        DefaultValuedAttr<BoolAttr, "false">:$isNegated);
-  let results = (outs Variadic<PDL_AnyType>:$results);
-  let assemblyFormat = [{
-    $name `(` $args `:` type($args) `)` (`:`  type($results)^ )? attr-dict
-  }];
+  let assemblyFormat = "$name `(` $args `:` type($args) `)` attr-dict";
   let hasVerifier = 1;
 }
 

diff  --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
index 901acc0e6733bb..48f625bd1fa3fd 100644
--- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
+++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
@@ -88,9 +88,7 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> {
   let description = [{
     `pdl_interp.apply_constraint` operations apply a generic constraint, that
     has been registered with the interpreter, with a given set of positional
-    values.
-    The constraint function may return any number of results.
-    On success, this operation branches to the true destination,
+    values. On success, this operation branches to the true destination,
     otherwise the false destination is taken. This behavior can be reversed
     by setting the attribute `isNegated` to true.
 
@@ -106,10 +104,8 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> {
   let arguments = (ins StrAttr:$name, 
                        Variadic<PDL_AnyType>:$args,
                        DefaultValuedAttr<BoolAttr, "false">:$isNegated);
-  let results = (outs Variadic<PDL_AnyType>:$results);
   let assemblyFormat = [{
-    $name `(` $args `:` type($args) `)` (`:` type($results)^)? attr-dict 
-    `->` successors
+    $name `(` $args `:` type($args) `)` attr-dict `->` successors
   }];
 }
 

diff  --git a/mlir/include/mlir/IR/PDLPatternMatch.h.inc b/mlir/include/mlir/IR/PDLPatternMatch.h.inc
index 66286ed7a4c898..a215da8cb6431d 100644
--- a/mlir/include/mlir/IR/PDLPatternMatch.h.inc
+++ b/mlir/include/mlir/IR/PDLPatternMatch.h.inc
@@ -318,9 +318,8 @@ protected:
 /// A generic PDL pattern constraint function. This function applies a
 /// constraint to a given set of opaque PDLValue entities. Returns success if
 /// the constraint successfully held, failure otherwise.
-using PDLConstraintFunction = std::function<LogicalResult(
-    PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
-
+using PDLConstraintFunction =
+    std::function<LogicalResult(PatternRewriter &, ArrayRef<PDLValue>)>;
 /// A native PDL rewrite function. This function performs a rewrite on the
 /// given set of values. Any results from this rewrite that should be passed
 /// back to PDL should be added to the provided result list. This method is only
@@ -727,7 +726,7 @@ std::enable_if_t<
     PDLConstraintFunction>
 buildConstraintFn(ConstraintFnT &&constraintFn) {
   return [constraintFn = std::forward<ConstraintFnT>(constraintFn)](
-             PatternRewriter &rewriter, PDLResultList &,
+             PatternRewriter &rewriter,
              ArrayRef<PDLValue> values) -> LogicalResult {
     auto argIndices = std::make_index_sequence<
         llvm::function_traits<ConstraintFnT>::num_args - 1>();
@@ -843,13 +842,10 @@ public:
   /// Register a constraint function with PDL. A constraint function may be
   /// specified in one of two ways:
   ///
-  ///   * `LogicalResult (PatternRewriter &,
-  ///                     PDLResultList &,
-  ///                     ArrayRef<PDLValue>)`
+  ///   * `LogicalResult (PatternRewriter &, ArrayRef<PDLValue>)`
   ///
   ///   In this overload the arguments of the constraint function are passed via
-  ///   the low-level PDLValue form, and the results are manually appended to
-  ///   the given result list.
+  ///   the low-level PDLValue form.
   ///
   ///   * `LogicalResult (PatternRewriter &, ValueTs... values)`
   ///
@@ -964,8 +960,8 @@ public:
   }
 };
 class PDLResultList {};
-using PDLConstraintFunction = std::function<LogicalResult(
-    PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
+using PDLConstraintFunction =
+    std::function<LogicalResult(PatternRewriter &, ArrayRef<PDLValue>)>;
 using PDLRewriteFunction = std::function<LogicalResult(
     PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
 

diff  --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
index b00cd0dee3ae80..e911631a4bc52a 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
@@ -50,8 +50,7 @@ struct PatternLowering {
 
   /// Generate interpreter operations for the tree rooted at the given matcher
   /// node, in the specified region.
-  Block *generateMatcher(MatcherNode &node, Region &region,
-                         Block *block = nullptr);
+  Block *generateMatcher(MatcherNode &node, Region &region);
 
   /// Get or create an access to the provided positional value in the current
   /// block. This operation may mutate the provided block pointer if nested
@@ -149,10 +148,6 @@ struct PatternLowering {
   /// A mapping between pattern operations and the corresponding configuration
   /// set.
   DenseMap<Operation *, PDLPatternConfigSet *> *configMap;
-
-  /// A mapping from a constraint question to the ApplyConstraintOp
-  /// that implements it.
-  DenseMap<ConstraintQuestion *, pdl_interp::ApplyConstraintOp> constraintOpMap;
 };
 } // namespace
 
@@ -187,11 +182,9 @@ void PatternLowering::lower(ModuleOp module) {
   firstMatcherBlock->erase();
 }
 
-Block *PatternLowering::generateMatcher(MatcherNode &node, Region &region,
-                                        Block *block) {
+Block *PatternLowering::generateMatcher(MatcherNode &node, Region &region) {
   // Push a new scope for the values used by this matcher.
-  if (!block)
-    block = &region.emplaceBlock();
+  Block *block = &region.emplaceBlock();
   ValueMapScope scope(values);
 
   // If this is the return node, simply insert the corresponding interpreter
@@ -371,15 +364,6 @@ Value PatternLowering::getValueAt(Block *&currentBlock, Position *pos) {
           loc, cast<ArrayAttr>(rawTypeAttr));
     break;
   }
-  case Predicates::ConstraintResultPos: {
-    // Due to the order of traversal, the ApplyConstraintOp has already been
-    // created and we can find it in constraintOpMap.
-    auto *constrResPos = cast<ConstraintPosition>(pos);
-    auto i = constraintOpMap.find(constrResPos->getQuestion());
-    assert(i != constraintOpMap.end());
-    value = i->second->getResult(constrResPos->getIndex());
-    break;
-  }
   default:
     llvm_unreachable("Generating unknown Position getter");
     break;
@@ -406,11 +390,12 @@ void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
       args.push_back(getValueAt(currentBlock, position));
   }
 
-  // Generate a new block as success successor and get the failure successor.
-  Block *success = &region->emplaceBlock();
+  // Generate the matcher in the current (potentially nested) region
+  // and get the failure successor.
+  Block *success = generateMatcher(*boolNode->getSuccessNode(), *region);
   Block *failure = failureBlockStack.back();
 
-  // Create the predicate.
+  // Finally, create the predicate.
   builder.setInsertionPointToEnd(currentBlock);
   Predicates::Kind kind = question->getKind();
   switch (kind) {
@@ -462,20 +447,14 @@ void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
   }
   case Predicates::ConstraintQuestion: {
     auto *cstQuestion = cast<ConstraintQuestion>(question);
-    auto applyConstraintOp = builder.create<pdl_interp::ApplyConstraintOp>(
-        loc, cstQuestion->getResultTypes(), cstQuestion->getName(), args,
-        cstQuestion->getIsNegated(), success, failure);
-
-    constraintOpMap.insert({cstQuestion, applyConstraintOp});
+    builder.create<pdl_interp::ApplyConstraintOp>(
+        loc, cstQuestion->getName(), args, cstQuestion->getIsNegated(), success,
+        failure);
     break;
   }
   default:
     llvm_unreachable("Generating unknown Predicate operation");
   }
-
-  // Generate the matcher in the current (potentially nested) region.
-  // This might use the results of the current predicate.
-  generateMatcher(*boolNode->getSuccessNode(), *region, success);
 }
 
 template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy>

diff  --git a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h
index 5ad2c477573a5b..2c9b63f86d6efa 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h
+++ b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h
@@ -47,7 +47,6 @@ enum Kind : unsigned {
   OperandPos,
   OperandGroupPos,
   AttributePos,
-  ConstraintResultPos,
   ResultPos,
   ResultGroupPos,
   TypePos,
@@ -280,28 +279,6 @@ struct OperationPosition : public PredicateBase<OperationPosition, Position,
   bool isOperandDefiningOp() const;
 };
 
-//===----------------------------------------------------------------------===//
-// ConstraintPosition
-
-struct ConstraintQuestion;
-
-/// A position describing the result of a native constraint. It saves the
-/// corresponding ConstraintQuestion and result index to enable referring
-/// back to them
-struct ConstraintPosition
-    : public PredicateBase<ConstraintPosition, Position,
-                           std::pair<ConstraintQuestion *, unsigned>,
-                           Predicates::ConstraintResultPos> {
-  using PredicateBase::PredicateBase;
-
-  /// Returns the ConstraintQuestion to enable keeping track of the native
-  /// constraint this position stems from.
-  ConstraintQuestion *getQuestion() const { return key.first; }
-
-  // Returns the result index of this position
-  unsigned getIndex() const { return key.second; }
-};
-
 //===----------------------------------------------------------------------===//
 // ResultPosition
 
@@ -470,13 +447,11 @@ struct AttributeQuestion
     : public PredicateBase<AttributeQuestion, Qualifier, void,
                            Predicates::AttributeQuestion> {};
 
-/// Apply a parameterized constraint to multiple position values and possibly
-/// produce results.
+/// Apply a parameterized constraint to multiple position values.
 struct ConstraintQuestion
-    : public PredicateBase<
-          ConstraintQuestion, Qualifier,
-          std::tuple<StringRef, ArrayRef<Position *>, ArrayRef<Type>, bool>,
-          Predicates::ConstraintQuestion> {
+    : public PredicateBase<ConstraintQuestion, Qualifier,
+                           std::tuple<StringRef, ArrayRef<Position *>, bool>,
+                           Predicates::ConstraintQuestion> {
   using Base::Base;
 
   /// Return the name of the constraint.
@@ -485,19 +460,15 @@ struct ConstraintQuestion
   /// Return the arguments of the constraint.
   ArrayRef<Position *> getArgs() const { return std::get<1>(key); }
 
-  /// Return the result types of the constraint.
-  ArrayRef<Type> getResultTypes() const { return std::get<2>(key); }
-
   /// Return the negation status of the constraint.
-  bool getIsNegated() const { return std::get<3>(key); }
+  bool getIsNegated() const { return std::get<2>(key); }
 
   /// Construct an instance with the given storage allocator.
   static ConstraintQuestion *construct(StorageUniquer::StorageAllocator &alloc,
                                        KeyTy key) {
     return Base::construct(alloc, KeyTy{alloc.copyInto(std::get<0>(key)),
                                         alloc.copyInto(std::get<1>(key)),
-                                        alloc.copyInto(std::get<2>(key)),
-                                        std::get<3>(key)});
+                                        std::get<2>(key)});
   }
 
   /// Returns a hash suitable for the given keytype.
@@ -555,7 +526,6 @@ class PredicateUniquer : public StorageUniquer {
     // Register the types of Positions with the uniquer.
     registerParametricStorageType<AttributePosition>();
     registerParametricStorageType<AttributeLiteralPosition>();
-    registerParametricStorageType<ConstraintPosition>();
     registerParametricStorageType<ForEachPosition>();
     registerParametricStorageType<OperandPosition>();
     registerParametricStorageType<OperandGroupPosition>();
@@ -618,12 +588,6 @@ class PredicateBuilder {
     return OperationPosition::get(uniquer, p);
   }
 
-  // Returns a position for a new value created by a constraint.
-  ConstraintPosition *getConstraintPosition(ConstraintQuestion *q,
-                                            unsigned index) {
-    return ConstraintPosition::get(uniquer, std::make_pair(q, index));
-  }
-
   /// Returns an attribute position for an attribute of the given operation.
   Position *getAttribute(OperationPosition *p, StringRef name) {
     return AttributePosition::get(uniquer, p, StringAttr::get(ctx, name));
@@ -709,11 +673,11 @@ class PredicateBuilder {
   }
 
   /// Create a predicate that applies a generic constraint.
-  Predicate getConstraint(StringRef name, ArrayRef<Position *> args,
-                          ArrayRef<Type> resultTypes, bool isNegated) {
-    return {ConstraintQuestion::get(
-                uniquer, std::make_tuple(name, args, resultTypes, isNegated)),
-            TrueAnswer::get(uniquer)};
+  Predicate getConstraint(StringRef name, ArrayRef<Position *> pos,
+                          bool isNegated) {
+    return {
+        ConstraintQuestion::get(uniquer, std::make_tuple(name, pos, isNegated)),
+        TrueAnswer::get(uniquer)};
   }
 
   /// Create a predicate comparing a value with null.

diff  --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
index f3d0e08ef49b97..a9c3b0a71ef0d7 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
@@ -15,7 +15,6 @@
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "llvm/ADT/MapVector.h"
-#include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
 #include <queue>
@@ -50,15 +49,14 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList,
                               DenseMap<Value, Position *> &inputs,
                               AttributePosition *pos) {
   assert(isa<pdl::AttributeType>(val.getType()) && "expected attribute type");
+  pdl::AttributeOp attr = cast<pdl::AttributeOp>(val.getDefiningOp());
   predList.emplace_back(pos, builder.getIsNotNull());
 
-  if (auto attr = dyn_cast<pdl::AttributeOp>(val.getDefiningOp())) {
-    // If the attribute has a type or value, add a constraint.
-    if (Value type = attr.getValueType())
-      getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
-    else if (Attribute value = attr.getValueAttr())
-      predList.emplace_back(pos, builder.getAttributeConstraint(value));
-  }
+  // If the attribute has a type or value, add a constraint.
+  if (Value type = attr.getValueType())
+    getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
+  else if (Attribute value = attr.getValueAttr())
+    predList.emplace_back(pos, builder.getAttributeConstraint(value));
 }
 
 /// Collect all of the predicates for the given operand position.
@@ -274,25 +272,8 @@ static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
   // Push the constraint to the furthest position.
   Position *pos = *std::max_element(allPositions.begin(), allPositions.end(),
                                     comparePosDepth);
-  ResultRange results = op.getResults();
-  PredicateBuilder::Predicate pred = builder.getConstraint(
-      op.getName(), allPositions, SmallVector<Type>(results.getTypes()),
-      op.getIsNegated());
-
-  // For each result register a position so it can be used later
-  for (auto [i, result] : llvm::enumerate(results)) {
-    ConstraintQuestion *q = cast<ConstraintQuestion>(pred.first);
-    ConstraintPosition *pos = builder.getConstraintPosition(q, i);
-    auto [it, inserted] = inputs.insert({result, pos});
-    // If this is an input value that has been visited in the tree, add a
-    // constraint to ensure that both instances refer to the same value.
-    if (!inserted) {
-      auto minMaxPositions =
-          std::minmax<Position *>(pos, it->second, comparePosDepth);
-      predList.emplace_back(minMaxPositions.second,
-                            builder.getEqualTo(minMaxPositions.first));
-    }
-  }
+  PredicateBuilder::Predicate pred =
+      builder.getConstraint(op.getName(), allPositions, op.getIsNegated());
   predList.emplace_back(pos, pred);
 }
 
@@ -894,49 +875,6 @@ static void insertExitNode(std::unique_ptr<MatcherNode> *root) {
   *root = std::make_unique<ExitNode>();
 }
 
-/// Sorts the range begin/end with the partial order given by cmp.
-template <typename Iterator, typename Compare>
-static void stableTopologicalSort(Iterator begin, Iterator end, Compare cmp) {
-  while (begin != end) {
-    // Cannot compute sortBeforeOthers in the predicate of stable_partition
-    // because stable_partition will not keep the [begin, end) range intact
-    // while it runs.
-    llvm::SmallPtrSet<typename Iterator::value_type, 16> sortBeforeOthers;
-    for (auto i = begin; i != end; ++i) {
-      if (std::none_of(begin, end, [&](auto const &b) { return cmp(b, *i); }))
-        sortBeforeOthers.insert(*i);
-    }
-
-    auto const next = std::stable_partition(begin, end, [&](auto const &a) {
-      return sortBeforeOthers.contains(a);
-    });
-    assert(next != begin && "not a partial ordering");
-    begin = next;
-  }
-}
-
-/// Returns true if 'b' depends on a result of 'a'.
-static bool dependsOn(OrderedPredicate *a, OrderedPredicate *b) {
-  auto *cqa = dyn_cast<ConstraintQuestion>(a->question);
-  if (!cqa)
-    return false;
-
-  auto positionDependsOnA = [&](Position *p) {
-    auto *cp = dyn_cast<ConstraintPosition>(p);
-    return cp && cp->getQuestion() == cqa;
-  };
-
-  if (auto *cqb = dyn_cast<ConstraintQuestion>(b->question)) {
-    // Does any argument of b use a?
-    return llvm::any_of(cqb->getArgs(), positionDependsOnA);
-  }
-  if (auto *equalTo = dyn_cast<EqualToQuestion>(b->question)) {
-    return positionDependsOnA(b->position) ||
-           positionDependsOnA(equalTo->getValue());
-  }
-  return positionDependsOnA(b->position);
-}
-
 /// Given a module containing PDL pattern operations, generate a matcher tree
 /// using the patterns within the given module and return the root matcher node.
 std::unique_ptr<MatcherNode>
@@ -1017,10 +955,6 @@ MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder,
     return *lhs < *rhs;
   });
 
-  // Mostly keep the now established order, but also ensure that
-  // ConstraintQuestions come after the results they use.
-  stableTopologicalSort(ordered.begin(), ordered.end(), dependsOn);
-
   // Build the matchers for each of the pattern predicate lists.
   std::unique_ptr<MatcherNode> root;
   for (OrderedPredicateList &list : lists)

diff  --git a/mlir/lib/Dialect/PDL/IR/PDL.cpp b/mlir/lib/Dialect/PDL/IR/PDL.cpp
index 428b19f4c74af8..d5f34679f06c60 100644
--- a/mlir/lib/Dialect/PDL/IR/PDL.cpp
+++ b/mlir/lib/Dialect/PDL/IR/PDL.cpp
@@ -94,12 +94,6 @@ static void visit(Operation *op, DenseSet<Operation *> &visited) {
 LogicalResult ApplyNativeConstraintOp::verify() {
   if (getNumOperands() == 0)
     return emitOpError("expected at least one argument");
-  if (llvm::any_of(getResults(), [](OpResult result) {
-        return isa<OperationType>(result.getType());
-      })) {
-    return emitOpError(
-        "returning an operation from a constraint is not supported");
-  }
   return success();
 }
 

diff  --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp
index 559ce7f466a83d..6e6992dcdeea78 100644
--- a/mlir/lib/Rewrite/ByteCode.cpp
+++ b/mlir/lib/Rewrite/ByteCode.cpp
@@ -769,25 +769,11 @@ void Generator::generate(Operation *op, ByteCodeWriter &writer) {
 
 void Generator::generate(pdl_interp::ApplyConstraintOp op,
                          ByteCodeWriter &writer) {
-  // Constraints that should return a value have to be registered as rewrites.
-  // If a constraint and a rewrite of similar name are registered the
-  // constraint takes precedence
+  assert(constraintToMemIndex.count(op.getName()) &&
+         "expected index for constraint function");
   writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
   writer.appendPDLValueList(op.getArgs());
   writer.append(ByteCodeField(op.getIsNegated()));
-  ResultRange results = op.getResults();
-  writer.append(ByteCodeField(results.size()));
-  for (Value result : results) {
-    // We record the expected kind of the result, so that we can provide extra
-    // verification of the native rewrite function and handle the failure case
-    // of constraints accordingly.
-    writer.appendPDLValueKind(result);
-
-    // Range results also need to append the range storage index.
-    if (isa<pdl::RangeType>(result.getType()))
-      writer.append(getRangeStorageIndex(result));
-    writer.append(result);
-  }
   writer.append(op.getSuccessors());
 }
 void Generator::generate(pdl_interp::ApplyRewriteOp op,
@@ -800,9 +786,11 @@ void Generator::generate(pdl_interp::ApplyRewriteOp op,
   ResultRange results = op.getResults();
   writer.append(ByteCodeField(results.size()));
   for (Value result : results) {
-    // We record the expected kind of the result, so that we
+    // In debug mode we also record the expected kind of the result, so that we
     // can provide extra verification of the native rewrite function.
+#ifndef NDEBUG
     writer.appendPDLValueKind(result);
+#endif
 
     // Range results also need to append the range storage index.
     if (isa<pdl::RangeType>(result.getType()))
@@ -1088,28 +1076,6 @@ void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
 // ByteCode Execution
 
 namespace {
-/// This class is an instantiation of the PDLResultList that provides access to
-/// the returned results. This API is not on `PDLResultList` to avoid
-/// overexposing access to information specific solely to the ByteCode.
-class ByteCodeRewriteResultList : public PDLResultList {
-public:
-  ByteCodeRewriteResultList(unsigned maxNumResults)
-      : PDLResultList(maxNumResults) {}
-
-  /// Return the list of PDL results.
-  MutableArrayRef<PDLValue> getResults() { return results; }
-
-  /// Return the type ranges allocated by this list.
-  MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
-    return allocatedTypeRanges;
-  }
-
-  /// Return the value ranges allocated by this list.
-  MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
-    return allocatedValueRanges;
-  }
-};
-
 /// This class provides support for executing a bytecode stream.
 class ByteCodeExecutor {
 public:
@@ -1186,9 +1152,6 @@ class ByteCodeExecutor {
   void executeSwitchResultCount();
   void executeSwitchType();
   void executeSwitchTypes();
-  void processNativeFunResults(ByteCodeRewriteResultList &results,
-                               unsigned numResults,
-                               LogicalResult &rewriteResult);
 
   /// Pushes a code iterator to the stack.
   void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); }
@@ -1262,8 +1225,6 @@ class ByteCodeExecutor {
     return T::getFromOpaquePointer(pointer);
   }
 
-  void skip(size_t skipN) { curCodeIt += skipN; }
-
   /// Jump to a specific successor based on a predicate value.
   void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
   /// Jump to a specific successor based on a destination index.
@@ -1420,11 +1381,33 @@ class ByteCodeExecutor {
   ArrayRef<PDLConstraintFunction> constraintFunctions;
   ArrayRef<PDLRewriteFunction> rewriteFunctions;
 };
+
+/// This class is an instantiation of the PDLResultList that provides access to
+/// the returned results. This API is not on `PDLResultList` to avoid
+/// overexposing access to information specific solely to the ByteCode.
+class ByteCodeRewriteResultList : public PDLResultList {
+public:
+  ByteCodeRewriteResultList(unsigned maxNumResults)
+      : PDLResultList(maxNumResults) {}
+
+  /// Return the list of PDL results.
+  MutableArrayRef<PDLValue> getResults() { return results; }
+
+  /// Return the type ranges allocated by this list.
+  MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
+    return allocatedTypeRanges;
+  }
+
+  /// Return the value ranges allocated by this list.
+  MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
+    return allocatedValueRanges;
+  }
+};
 } // namespace
 
 void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
   LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
-  ByteCodeField fun_idx = read();
+  const PDLConstraintFunction &constraintFn = constraintFunctions[read()];
   SmallVector<PDLValue, 16> args;
   readList<PDLValue>(args);
 
@@ -1439,29 +1422,8 @@ void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
     llvm::dbgs() << "  * isNegated: " << isNegated << "\n";
     llvm::interleaveComma(args, llvm::dbgs());
   });
-
-  ByteCodeField numResults = read();
-  const PDLRewriteFunction &constraintFn = constraintFunctions[fun_idx];
-  ByteCodeRewriteResultList results(numResults);
-  LogicalResult rewriteResult = constraintFn(rewriter, results, args);
-  ArrayRef<PDLValue> constraintResults = results.getResults();
-  LLVM_DEBUG({
-    if (succeeded(rewriteResult)) {
-      llvm::dbgs() << "  * Constraint succeeded\n";
-      llvm::dbgs() << "  * Results: ";
-      llvm::interleaveComma(constraintResults, llvm::dbgs());
-      llvm::dbgs() << "\n";
-    } else {
-      llvm::dbgs() << "  * Constraint failed\n";
-    }
-  });
-  assert((failed(rewriteResult) || constraintResults.size() == numResults) &&
-         "native PDL rewrite function succeeded but returned "
-         "unexpected number of results");
-  processNativeFunResults(results, numResults, rewriteResult);
-
-  // Depending on the constraint jump to the proper destination.
-  selectJump(isNegated != succeeded(rewriteResult));
+  // Invoke the constraint and jump to the proper destination.
+  selectJump(isNegated != succeeded(constraintFn(rewriter, args)));
 }
 
 LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
@@ -1483,39 +1445,16 @@ LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
   assert(results.getResults().size() == numResults &&
          "native PDL rewrite function returned unexpected number of results");
 
-  processNativeFunResults(results, numResults, rewriteResult);
+  // Store the results in the bytecode memory.
+  for (PDLValue &result : results.getResults()) {
+    LLVM_DEBUG(llvm::dbgs() << "  * Result: " << result << "\n");
 
-  if (failed(rewriteResult)) {
-    LLVM_DEBUG(llvm::dbgs() << "  - Failed");
-    return failure();
-  }
-  return success();
-}
+// In debug mode we also verify the expected kind of the result.
+#ifndef NDEBUG
+    assert(result.getKind() == read<PDLValue::Kind>() &&
+           "native PDL rewrite function returned an unexpected type of result");
+#endif
 
-void ByteCodeExecutor::processNativeFunResults(
-    ByteCodeRewriteResultList &results, unsigned numResults,
-    LogicalResult &rewriteResult) {
-  // Store the results in the bytecode memory or handle missing results on
-  // failure.
-  for (unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) {
-    PDLValue::Kind resultKind = read<PDLValue::Kind>();
-
-    // Skip the according number of values on the buffer on failure and exit
-    // early as there are no results to process.
-    if (failed(rewriteResult)) {
-      if (resultKind == PDLValue::Kind::TypeRange ||
-          resultKind == PDLValue::Kind::ValueRange) {
-        skip(2);
-      } else {
-        skip(1);
-      }
-      return;
-    }
-    PDLValue result = results.getResults()[resultIdx];
-    LLVM_DEBUG(llvm::dbgs() << "  * Result: " << result << "\n");
-    assert(result.getKind() == resultKind &&
-           "native PDL rewrite function returned an unexpected type of "
-           "result");
     // If the result is a range, we need to copy it over to the bytecodes
     // range memory.
     if (std::optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) {
@@ -1537,6 +1476,13 @@ void ByteCodeExecutor::processNativeFunResults(
     allocatedTypeRangeMemory.push_back(std::move(it));
   for (auto &it : results.getAllocatedValueRanges())
     allocatedValueRangeMemory.push_back(std::move(it));
+
+  // Process the result of the rewrite.
+  if (failed(rewriteResult)) {
+    LLVM_DEBUG(llvm::dbgs() << "  - Failed");
+    return failure();
+  }
+  return success();
 }
 
 void ByteCodeExecutor::executeAreEqual() {

diff  --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
index 206781ed146692..97ff8bd0d8584d 100644
--- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp
+++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
@@ -1362,6 +1362,12 @@ FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl(
   if (failed(parseToken(Token::semicolon,
                         "expected `;` after native declaration")))
     return failure();
+  // TODO: PDL should be able to support constraint results in certain
+  // situations, we should revise this.
+  if (std::is_same<ast::UserConstraintDecl, T>::value && !results.empty()) {
+    return emitError(
+        "native Constraints currently do not support returning results");
+  }
   return T::createNative(ctx, name, arguments, results, optCodeStr, resultType);
 }
 

diff  --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
index 92afb765b5ab4e..02bb8316c02db0 100644
--- a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
+++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
@@ -79,57 +79,6 @@ module @constraints {
 
 // -----
 
-// CHECK-LABEL: module @constraint_with_result
-module @constraint_with_result {
-  // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
-  // CHECK: %[[ATTR:.*]] = pdl_interp.apply_constraint "check_op_and_get_attr_constr"(%[[ROOT]]
-  // CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter(%[[ROOT]], %[[ATTR]] : !pdl.operation, !pdl.attribute)
-  pdl.pattern : benefit(1) {
-    %root = operation
-    %attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute
-    rewrite %root with "rewriter"(%attr : !pdl.attribute)
-  }
-}
-
-// -----
-
-// CHECK-LABEL: module @constraint_with_unused_result
-module @constraint_with_unused_result {
-  // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
-  // CHECK: %[[ATTR:.*]] = pdl_interp.apply_constraint "check_op_and_get_attr_constr"(%[[ROOT]]
-  // CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter(%[[ROOT]] : !pdl.operation)
-  pdl.pattern : benefit(1) {
-    %root = operation
-    %attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute
-    rewrite %root with "rewriter"
-  }
-}
-
-// -----
-
-// CHECK-LABEL: module @constraint_with_result_multiple
-module @constraint_with_result_multiple {
-  // check that native constraints work as expected even when multiple identical constraints are fused
-
-  // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
-  // CHECK: %[[ATTR:.*]] = pdl_interp.apply_constraint "check_op_and_get_attr_constr"(%[[ROOT]]
-  // CHECK-NOT: pdl_interp.apply_constraint "check_op_and_get_attr_constr"
-  // CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter_0(%[[ROOT]], %[[ATTR]]  : !pdl.operation, !pdl.attribute)
-  // CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter(%[[ROOT]], %[[ATTR]] : !pdl.operation, !pdl.attribute)
-  pdl.pattern : benefit(1) {
-    %root = operation
-    %attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute
-    rewrite %root with "rewriter"(%attr : !pdl.attribute)
-  }
-  pdl.pattern : benefit(1) {
-    %root = operation
-    %attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute
-    rewrite %root with "rewriter"(%attr : !pdl.attribute)
-  }
-}
-
-// -----
-
 // CHECK-LABEL: module @negated_constraint
 module @negated_constraint {
   // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)

diff  --git a/mlir/test/Conversion/PDLToPDLInterp/use-constraint-result.mlir b/mlir/test/Conversion/PDLToPDLInterp/use-constraint-result.mlir
deleted file mode 100644
index 4511baff2015a4..00000000000000
--- a/mlir/test/Conversion/PDLToPDLInterp/use-constraint-result.mlir
+++ /dev/null
@@ -1,77 +0,0 @@
-// RUN: mlir-opt -split-input-file -convert-pdl-to-pdl-interp %s | FileCheck %s
-
-// Ensuse that the dependency between add & less
-// causes them to be in the correct order.
-// CHECK-LABEL: matcher
-// CHECK: apply_constraint "return_attr_constraint"
-// CHECK: apply_constraint "use_attr_constraint"
-
-module {
-  pdl.pattern : benefit(1) {
-    %0 = attribute
-    %1 = types
-    %2 = operation "tosa.mul" {"shift" = %0} -> (%1 : !pdl.range<type>)
-    %3 = attribute = 0 : i32
-    %4 = attribute = 1 : i32
-    %5 = apply_native_constraint "return_attr_constraint"(%3, %4 : !pdl.attribute, !pdl.attribute) : !pdl.attribute
-    apply_native_constraint "use_attr_constraint"(%0, %5 : !pdl.attribute, !pdl.attribute)
-    rewrite %2 with "rewriter"
-  }
-}
-
-// -----
-
-// CHECK-LABEL: matcher
-// CHECK: %[[ATTR:.*]] = pdl_interp.get_attribute "attr" of
-// CHECK: %[[CONSTRAINT:.*]] = pdl_interp.apply_constraint "return_attr_constraint"
-// CHECK: pdl_interp.are_equal %[[ATTR:.*]], %[[CONSTRAINT:.*]]
-
-pdl.pattern : benefit(1) {
-    %inputOp = operation
-    %result = result 0 of %inputOp
-    %attr = pdl.apply_native_constraint "return_attr_constraint"(%inputOp : !pdl.operation) : !pdl.attribute
-    %root = operation(%result : !pdl.value) {"attr" = %attr}
-    rewrite %root with "rewriter"(%attr : !pdl.attribute)
-}
-
-// -----
-
-// CHECK-LABEL: matcher
-// CHECK: %[[CONSTRAINT:.*]] = pdl_interp.apply_constraint "return_value_constr"
-// CHECK: %[[VALUE:.*]] = pdl_interp.get_operand 0
-// CHECK: pdl_interp.are_equal %[[VALUE:.*]], %[[CONSTRAINT:.*]]
-
-pdl.pattern : benefit(1) {
-    %attr = attribute = 10
-    %value = pdl.apply_native_constraint "return_value_constr"(%attr: !pdl.attribute) : !pdl.value
-    %root = operation(%value : !pdl.value)
-    rewrite %root with "rewriter"
-}
-
-// -----
-
-// CHECK-LABEL: matcher
-// CHECK: %[[CONSTRAINT:.*]] = pdl_interp.apply_constraint "return_type_constr"
-// CHECK: %[[TYPE:.*]] = pdl_interp.get_value_type of
-// CHECK: pdl_interp.are_equal %[[TYPE:.*]], %[[CONSTRAINT:.*]]
-
-pdl.pattern : benefit(1) {
-    %attr = attribute = 10
-    %type = pdl.apply_native_constraint "return_type_constr"(%attr: !pdl.attribute) : !pdl.type
-    %root = operation -> (%type : !pdl.type)
-    rewrite %root with "rewriter"
-}
-
-// -----
-
-// CHECK-LABEL: matcher
-// CHECK: %[[CONSTRAINT:.*]] = pdl_interp.apply_constraint "return_type_range_constr"
-// CHECK: %[[TYPE:.*]] = pdl_interp.get_value_type of
-// CHECK: pdl_interp.are_equal %[[TYPE:.*]], %[[CONSTRAINT:.*]]
-
-pdl.pattern : benefit(1) {
-    %attr = attribute = 10
-    %types = pdl.apply_native_constraint "return_type_range_constr"(%attr: !pdl.attribute) : !pdl.range<type>
-    %root = operation -> (%types : !pdl.range<type>)
-    rewrite %root with "rewriter"
-}

diff  --git a/mlir/test/Dialect/PDL/ops.mlir b/mlir/test/Dialect/PDL/ops.mlir
index 20e40deea5f863..6e6da5cce446ae 100644
--- a/mlir/test/Dialect/PDL/ops.mlir
+++ b/mlir/test/Dialect/PDL/ops.mlir
@@ -134,24 +134,6 @@ pdl.pattern @apply_rewrite_with_no_results : benefit(1) {
 
 // -----
 
-pdl.pattern @apply_constraint_with_no_results : benefit(1) {
-  %root = operation
-  apply_native_constraint "NativeConstraint"(%root : !pdl.operation)
-  rewrite %root with "rewriter"
-}
-
-// -----
-
-pdl.pattern @apply_constraint_with_results : benefit(1) {
-  %root = operation
-  %attr = apply_native_constraint "NativeConstraint"(%root : !pdl.operation) : !pdl.attribute
-  rewrite %root {
-    apply_native_rewrite "NativeRewrite"(%attr : !pdl.attribute)
-  }
-}
-
-// -----
-
 pdl.pattern @attribute_with_dict : benefit(1) {
   %root = operation
   rewrite %root {

diff  --git a/mlir/test/Rewrite/pdl-bytecode.mlir b/mlir/test/Rewrite/pdl-bytecode.mlir
index f8e4f2e83b296a..ae61c1a079548a 100644
--- a/mlir/test/Rewrite/pdl-bytecode.mlir
+++ b/mlir/test/Rewrite/pdl-bytecode.mlir
@@ -109,74 +109,6 @@ module @ir attributes { test.apply_constraint_3 } {
 
 // -----
 
-// Test returning a type from a native constraint.
-module @patterns {
-  pdl_interp.func @matcher(%root : !pdl.operation) {
-    pdl_interp.check_operation_name of %root is "test.success_op" -> ^pat, ^end
-
-  ^pat:
-    %new_type = pdl_interp.apply_constraint "op_constr_return_type"(%root : !pdl.operation) : !pdl.type -> ^pat2, ^end
-
-  ^pat2:
-    pdl_interp.record_match @rewriters::@success(%root, %new_type : !pdl.operation, !pdl.type) : benefit(1), loc([%root]) -> ^end
-
-  ^end:
-    pdl_interp.finalize
-  }
-
-  module @rewriters {
-    pdl_interp.func @success(%root : !pdl.operation, %new_type : !pdl.type) {
-      %op = pdl_interp.create_operation "test.replaced_by_pattern" -> (%new_type : !pdl.type)
-      pdl_interp.erase %root
-      pdl_interp.finalize
-    }
-  }
-}
-
-// CHECK-LABEL: test.apply_constraint_4
-// CHECK-NOT: "test.replaced_by_pattern"
-// CHECK: "test.replaced_by_pattern"() : () -> f32
-module @ir attributes { test.apply_constraint_4 } {
-  "test.failure_op"() : () -> ()
-  "test.success_op"() : () -> ()
-}
-
-// -----
-
-// Test success and failure cases of native constraints with pdl.range results.
-module @patterns {
-  pdl_interp.func @matcher(%root : !pdl.operation) {
-    pdl_interp.check_operation_name of %root is "test.success_op" -> ^pat, ^end
-  
-  ^pat:
-    %num_results = pdl_interp.create_attribute 2 : i32
-    %types = pdl_interp.apply_constraint "op_constr_return_type_range"(%root, %num_results : !pdl.operation, !pdl.attribute) : !pdl.range<type> -> ^pat1, ^end
-
-  ^pat1:
-    pdl_interp.record_match @rewriters::@success(%root, %types : !pdl.operation, !pdl.range<type>) : benefit(1), loc([%root]) -> ^end
-
-  ^end:
-    pdl_interp.finalize
-  }
-
-  module @rewriters {
-    pdl_interp.func @success(%root : !pdl.operation, %types : !pdl.range<type>) {
-      %op = pdl_interp.create_operation "test.replaced_by_pattern" -> (%types : !pdl.range<type>)
-      pdl_interp.erase %root
-      pdl_interp.finalize
-    }
-  }
-}
-
-// CHECK-LABEL: test.apply_constraint_5
-// CHECK-NOT: "test.replaced_by_pattern"
-// CHECK: "test.replaced_by_pattern"() : () -> (f32, f32)
-module @ir attributes { test.apply_constraint_5 } {
-  "test.failure_op"() : () -> ()
-  "test.success_op"() : () -> ()
-}
-
-// -----
 
 //===----------------------------------------------------------------------===//
 // pdl_interp::ApplyRewriteOp

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index b9424e06bf0318..50caf8f9cfc709 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -887,7 +887,7 @@ class TestTransformDialectExtension
 #include "TestTransformDialectExtensionTypes.cpp.inc"
         >();
 
-    auto verboseConstraint = [](PatternRewriter &rewriter, PDLResultList &,
+    auto verboseConstraint = [](PatternRewriter &rewriter,
                                 ArrayRef<PDLValue> pdlValues) {
       for (const PDLValue &pdlValue : pdlValues) {
         if (Operation *op = pdlValue.dyn_cast<Operation *>()) {

diff  --git a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
index 56af3c15b905f1..daa1c371f27c92 100644
--- a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
+++ b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
@@ -30,50 +30,6 @@ static LogicalResult customMultiEntityVariadicConstraint(
   return success();
 }
 
-// Custom constraint that returns a value if the op is named test.success_op
-static LogicalResult customValueResultConstraint(PatternRewriter &rewriter,
-                                                 PDLResultList &results,
-                                                 ArrayRef<PDLValue> args) {
-  auto *op = args[0].cast<Operation *>();
-  if (op->getName().getStringRef() == "test.success_op") {
-    StringAttr customAttr = rewriter.getStringAttr("test.success");
-    results.push_back(customAttr);
-    return success();
-  }
-  return failure();
-}
-
-// Custom constraint that returns a type if the op is named test.success_op
-static LogicalResult customTypeResultConstraint(PatternRewriter &rewriter,
-                                                PDLResultList &results,
-                                                ArrayRef<PDLValue> args) {
-  auto *op = args[0].cast<Operation *>();
-  if (op->getName().getStringRef() == "test.success_op") {
-    results.push_back(rewriter.getF32Type());
-    return success();
-  }
-  return failure();
-}
-
-// Custom constraint that returns a type range of variable length if the op is
-// named test.success_op
-static LogicalResult customTypeRangeResultConstraint(PatternRewriter &rewriter,
-                                                     PDLResultList &results,
-                                                     ArrayRef<PDLValue> args) {
-  auto *op = args[0].cast<Operation *>();
-  int numTypes = args[1].cast<Attribute>().cast<IntegerAttr>().getInt();
-
-  if (op->getName().getStringRef() == "test.success_op") {
-    SmallVector<Type> types;
-    for (int i = 0; i < numTypes; i++) {
-      types.push_back(rewriter.getF32Type());
-    }
-    results.push_back(TypeRange(types));
-    return success();
-  }
-  return failure();
-}
-
 // Custom creator invoked from PDL.
 static Operation *customCreate(PatternRewriter &rewriter, Operation *op) {
   return rewriter.create(OperationState(op->getLoc(), "test.success"));
@@ -146,12 +102,6 @@ struct TestPDLByteCodePass
                                           customMultiEntityConstraint);
     pdlPattern.registerConstraintFunction("multi_entity_var_constraint",
                                           customMultiEntityVariadicConstraint);
-    pdlPattern.registerConstraintFunction("op_constr_return_attr",
-                                          customValueResultConstraint);
-    pdlPattern.registerConstraintFunction("op_constr_return_type",
-                                          customTypeResultConstraint);
-    pdlPattern.registerConstraintFunction("op_constr_return_type_range",
-                                          customTypeRangeResultConstraint);
     pdlPattern.registerRewriteFunction("creator", customCreate);
     pdlPattern.registerRewriteFunction("var_creator",
                                        customVariadicResultCreate);

diff  --git a/mlir/test/mlir-pdll/Parser/constraint-failure.pdll b/mlir/test/mlir-pdll/Parser/constraint-failure.pdll
index 48747d3fa2e681..18877b4bcc50ec 100644
--- a/mlir/test/mlir-pdll/Parser/constraint-failure.pdll
+++ b/mlir/test/mlir-pdll/Parser/constraint-failure.pdll
@@ -158,3 +158,8 @@ Pattern {
 
 // CHECK: expected `;` after native declaration
 Constraint Foo() [{}]
+
+// -----
+
+// CHECK: native Constraints currently do not support returning results
+Constraint Foo() -> Op;

diff  --git a/mlir/test/mlir-pdll/Parser/constraint.pdll b/mlir/test/mlir-pdll/Parser/constraint.pdll
index e2a52ff130cc84..1c0a015ab4a7b4 100644
--- a/mlir/test/mlir-pdll/Parser/constraint.pdll
+++ b/mlir/test/mlir-pdll/Parser/constraint.pdll
@@ -12,14 +12,6 @@ Constraint Foo() [{ /* Native Code */ }];
 
 // -----
 
-// Test that native constraints support returning results.
-
-// CHECK:  Module
-// CHECK:  `-UserConstraintDecl {{.*}} Name<Foo> ResultType<Attr>
-Constraint Foo() -> Attr;
-
-// -----
-
 // CHECK: Module
 // CHECK: `-UserConstraintDecl {{.*}} Name<Foo> ResultType<Value>
 // CHECK:   `Inputs`

diff  --git a/mlir/test/python/dialects/pdl_ops.py b/mlir/test/python/dialects/pdl_ops.py
index 95cb25c14873db..0d364f9222a657 100644
--- a/mlir/test/python/dialects/pdl_ops.py
+++ b/mlir/test/python/dialects/pdl_ops.py
@@ -298,6 +298,6 @@ def test_apply_native_constraint():
     pattern = PatternOp(1)
     with InsertionPoint(pattern.body):
         resultType = TypeOp()
-        ApplyNativeConstraintOp([], "typeConstraint", args=[resultType])
+        ApplyNativeConstraintOp("typeConstraint", args=[resultType])
         root = OperationOp(types=[resultType])
         RewriteOp(root, name="rewrite")


        


More information about the Mlir-commits mailing list