[Mlir-commits] [mlir] faf4226 - [PDLL] Add support for user defined constraint and rewrite functions

River Riddle llvmlistbot at llvm.org
Thu Feb 10 12:49:18 PST 2022


Author: River Riddle
Date: 2022-02-10T12:48:59-08:00
New Revision: faf42264e5401a1dfca95b701e5c2bf951d7f8a7

URL: https://github.com/llvm/llvm-project/commit/faf42264e5401a1dfca95b701e5c2bf951d7f8a7
DIFF: https://github.com/llvm/llvm-project/commit/faf42264e5401a1dfca95b701e5c2bf951d7f8a7.diff

LOG: [PDLL] Add support for user defined constraint and rewrite functions

These functions allow for defining pattern fragments usable within the `match` and `rewrite` sections of a pattern. The main structure of Constraint and Rewrite functions is the same, and is similar to functions in other languages; each contains a signature (i.e. name, argument list, result list) and a body:

```pdll
// Constraint that takes a value as an input, and produces a value:
Constraint Cst(arg: Value) -> Value { ... }

// Constraint that returns multiple values:
Constraint Cst() -> (result1: Value, result2: ValueRange);
```

When returning multiple results, each result can optionally be named (the result of a Constraint/Rewrite in the case of multiple results is a tuple).

The body of a Constraint/Rewrite function can be specified in several ways:

* Externally
In this case we are importing an external function (registered by the user outside of PDLL):

```pdll
Constraint Foo(op: Op);
Rewrite Bar();
```

* In PDLL (using PDLL constructs)
In this case, the body is defined using PDLL constructs:

```pdll
Rewrite BuildFooOp() {
  // The result type of the Rewrite is inferred from the return.
  return op<my_dialect.foo>;
}
// Constraints/Rewrites can also implement a lambda/expression
// body for simple one line bodies.
Rewrite BuildFooOp() => op<my_dialect.foo>;
```

* In PDLL (using a native/C++ code block)
In this case the body is specified using a C++ (or, at some point, potentially another language) code block. When building PDLL in AOT mode this will generate a native constraint/rewrite and register it with the PDL bytecode.

```pdll
Rewrite BuildFooOp() -> Op<my_dialect.foo> [{
  return rewriter.create<my_dialect::FooOp>(...);
}];
```

Differential Revision: https://reviews.llvm.org/D115836

Added: 
    mlir/test/mlir-pdll/Parser/constraint-failure.pdll
    mlir/test/mlir-pdll/Parser/constraint.pdll
    mlir/test/mlir-pdll/Parser/rewrite-failure.pdll
    mlir/test/mlir-pdll/Parser/rewrite.pdll

Modified: 
    mlir/include/mlir/Tools/PDLL/AST/Nodes.h
    mlir/include/mlir/Tools/PDLL/AST/Types.h
    mlir/lib/Tools/PDLL/AST/Context.cpp
    mlir/lib/Tools/PDLL/AST/NodePrinter.cpp
    mlir/lib/Tools/PDLL/AST/Nodes.cpp
    mlir/lib/Tools/PDLL/AST/TypeDetail.h
    mlir/lib/Tools/PDLL/AST/Types.cpp
    mlir/lib/Tools/PDLL/Parser/Lexer.cpp
    mlir/lib/Tools/PDLL/Parser/Lexer.h
    mlir/lib/Tools/PDLL/Parser/Parser.cpp
    mlir/test/mlir-pdll/Parser/expr-failure.pdll
    mlir/test/mlir-pdll/Parser/expr.pdll
    mlir/test/mlir-pdll/Parser/pattern-failure.pdll
    mlir/test/mlir-pdll/Parser/stmt-failure.pdll

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
index d46b9004b03ad..6824354a16edd 100644
--- a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
+++ b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
@@ -300,6 +300,31 @@ class RewriteStmt final : public Node::NodeBase<RewriteStmt, OpRewriteStmt> {
   CompoundStmt *rewriteBody;
 };
 
+//===----------------------------------------------------------------------===//
+// ReturnStmt
+//===----------------------------------------------------------------------===//
+
+/// This statement represents a return from a "callable"-like decl, e.g. a
+/// Constraint or a Rewrite.
+class ReturnStmt final : public Node::NodeBase<ReturnStmt, Stmt> {
+public:
+  static ReturnStmt *create(Context &ctx, SMRange loc, Expr *resultExpr);
+
+  /// Return the result expression of this statement.
+  Expr *getResultExpr() { return resultExpr; }
+  const Expr *getResultExpr() const { return resultExpr; }
+
+  /// Set the result expression of this statement.
+  void setResultExpr(Expr *expr) { resultExpr = expr; }
+
+private:
+  ReturnStmt(SMRange loc, Expr *resultExpr)
+      : Base(loc), resultExpr(resultExpr) {}
+
+  /// The result expression of this statement.
+  Expr *resultExpr;
+};
+
 //===----------------------------------------------------------------------===//
 // Expr
 //===----------------------------------------------------------------------===//
@@ -345,6 +370,43 @@ class AttributeExpr : public Node::NodeBase<AttributeExpr, Expr> {
   StringRef value;
 };
 
+//===----------------------------------------------------------------------===//
+// CallExpr
+//===----------------------------------------------------------------------===//
+
+/// This expression represents a call to a decl, such as a
+/// UserConstraintDecl/UserRewriteDecl.
+class CallExpr final : public Node::NodeBase<CallExpr, Expr>,
+                       private llvm::TrailingObjects<CallExpr, Expr *> {
+public:
+  static CallExpr *create(Context &ctx, SMRange loc, Expr *callable,
+                          ArrayRef<Expr *> arguments, Type resultType);
+
+  /// Return the callable of this call.
+  Expr *getCallableExpr() const { return callable; }
+
+  /// Return the arguments of this call.
+  MutableArrayRef<Expr *> getArguments() {
+    return {getTrailingObjects<Expr *>(), numArgs};
+  }
+  ArrayRef<Expr *> getArguments() const {
+    return const_cast<CallExpr *>(this)->getArguments();
+  }
+
+private:
+  CallExpr(SMRange loc, Type type, Expr *callable, unsigned numArgs)
+      : Base(loc, type), callable(callable), numArgs(numArgs) {}
+
+  /// The callable of this call.
+  Expr *callable;
+
+  /// The number of arguments of the call.
+  unsigned numArgs;
+
+  /// TrailingObject utilities.
+  friend llvm::TrailingObjects<CallExpr, Expr *>;
+};
+
 //===----------------------------------------------------------------------===//
 // DeclRefExpr
 //===----------------------------------------------------------------------===//
@@ -738,6 +800,114 @@ class ValueRangeConstraintDecl
   Expr *typeExpr;
 };
 
+//===----------------------------------------------------------------------===//
+// UserConstraintDecl
+//===----------------------------------------------------------------------===//
+
+/// This decl represents a user defined constraint. This is either:
+///   * an imported native constraint
+///     - Similar to an external function declaration. This is a native
+///       constraint defined externally, and imported into PDLL via a
+///       declaration.
+///   * a native constraint defined in PDLL
+///     - This is a native constraint, i.e. a constraint whose implementation is
+///       defined in C++ (or potentially some other non-PDLL language). The
+///       implementation of this constraint is specified as a string code block
+///       in PDLL.
+///   * a PDLL constraint
+///     - This is a constraint which is defined using only PDLL constructs.
+class UserConstraintDecl final
+    : public Node::NodeBase<UserConstraintDecl, ConstraintDecl>,
+      llvm::TrailingObjects<UserConstraintDecl, VariableDecl *> {
+public:
+  /// Create a native constraint with the given optional code block.
+  static UserConstraintDecl *createNative(Context &ctx, const Name &name,
+                                          ArrayRef<VariableDecl *> inputs,
+                                          ArrayRef<VariableDecl *> results,
+                                          Optional<StringRef> codeBlock,
+                                          Type resultType) {
+    return createImpl(ctx, name, inputs, results, codeBlock, /*body=*/nullptr,
+                      resultType);
+  }
+
+  /// Create a PDLL constraint with the given body.
+  static UserConstraintDecl *createPDLL(Context &ctx, const Name &name,
+                                        ArrayRef<VariableDecl *> inputs,
+                                        ArrayRef<VariableDecl *> results,
+                                        const CompoundStmt *body,
+                                        Type resultType) {
+    return createImpl(ctx, name, inputs, results, /*codeBlock=*/llvm::None,
+                      body, resultType);
+  }
+
+  /// Return the name of the constraint.
+  const Name &getName() const { return *Decl::getName(); }
+
+  /// Return the input arguments of this constraint.
+  MutableArrayRef<VariableDecl *> getInputs() {
+    return {getTrailingObjects<VariableDecl *>(), numInputs};
+  }
+  ArrayRef<VariableDecl *> getInputs() const {
+    return const_cast<UserConstraintDecl *>(this)->getInputs();
+  }
+
+  /// Return the explicit results of the constraint declaration. May be empty,
+  /// even if the constraint has results (e.g. in the case of inferred results).
+  MutableArrayRef<VariableDecl *> getResults() {
+    return {getTrailingObjects<VariableDecl *>() + numInputs, numResults};
+  }
+  ArrayRef<VariableDecl *> getResults() const {
+    return const_cast<UserConstraintDecl *>(this)->getResults();
+  }
+
+  /// Return the optional code block of this constraint, if this is a native
+  /// constraint with a provided implementation.
+  Optional<StringRef> getCodeBlock() const { return codeBlock; }
+
+  /// Return the body of this constraint if this constraint is a PDLL
+  /// constraint, otherwise returns nullptr.
+  const CompoundStmt *getBody() const { return constraintBody; }
+
+  /// Return the result type of this constraint.
+  Type getResultType() const { return resultType; }
+
+  /// Returns true if this constraint is external (no PDLL body or code block).
+  bool isExternal() const { return !constraintBody && !codeBlock; }
+
+private:
+  /// Create either a PDLL constraint or a native constraint with the given
+  /// components.
+  static UserConstraintDecl *
+  createImpl(Context &ctx, const Name &name, ArrayRef<VariableDecl *> inputs,
+             ArrayRef<VariableDecl *> results, Optional<StringRef> codeBlock,
+             const CompoundStmt *body, Type resultType);
+
+  UserConstraintDecl(const Name &name, unsigned numInputs, unsigned numResults,
+                     Optional<StringRef> codeBlock, const CompoundStmt *body,
+                     Type resultType)
+      : Base(name.getLoc(), &name), numInputs(numInputs),
+        numResults(numResults), codeBlock(codeBlock), constraintBody(body),
+        resultType(resultType) {}
+
+  /// The number of inputs to this constraint.
+  unsigned numInputs;
+
+  /// The number of explicit results of this constraint.
+  unsigned numResults;
+
+  /// The optional code block of this constraint.
+  Optional<StringRef> codeBlock;
+
+  /// The optional body of this constraint.
+  const CompoundStmt *constraintBody;
+
+  /// The result type of the constraint.
+  Type resultType;
+
+  /// Allow access to various internals.
+  friend llvm::TrailingObjects<UserConstraintDecl, VariableDecl *>;
+};
+
 //===----------------------------------------------------------------------===//
 // NamedAttributeDecl
 //===----------------------------------------------------------------------===//
@@ -826,6 +996,149 @@ class PatternDecl : public Node::NodeBase<PatternDecl, Decl> {
   const CompoundStmt *patternBody;
 };
 
+//===----------------------------------------------------------------------===//
+// UserRewriteDecl
+//===----------------------------------------------------------------------===//
+
+/// This decl represents a user defined rewrite. This is either:
+///   * an imported native rewrite
+///     - Similar to an external function declaration. This is a native
+///       rewrite defined externally, and imported into PDLL via a declaration.
+///   * a native rewrite defined in PDLL
+///     - This is a native rewrite, i.e. a rewrite whose implementation is
+///       defined in C++ (or potentially some other non-PDLL language). The
+///       implementation of this rewrite is specified as a string code block
+///       in PDLL.
+///   * a PDLL rewrite
+///     - This is a rewrite which is defined using only PDLL constructs.
+class UserRewriteDecl final
+    : public Node::NodeBase<UserRewriteDecl, Decl>,
+      llvm::TrailingObjects<UserRewriteDecl, VariableDecl *> {
+public:
+  /// Create a native rewrite with the given optional code block.
+  static UserRewriteDecl *createNative(Context &ctx, const Name &name,
+                                       ArrayRef<VariableDecl *> inputs,
+                                       ArrayRef<VariableDecl *> results,
+                                       Optional<StringRef> codeBlock,
+                                       Type resultType) {
+    return createImpl(ctx, name, inputs, results, codeBlock, /*body=*/nullptr,
+                      resultType);
+  }
+
+  /// Create a PDLL rewrite with the given body.
+  static UserRewriteDecl *createPDLL(Context &ctx, const Name &name,
+                                     ArrayRef<VariableDecl *> inputs,
+                                     ArrayRef<VariableDecl *> results,
+                                     const CompoundStmt *body,
+                                     Type resultType) {
+    return createImpl(ctx, name, inputs, results, /*codeBlock=*/llvm::None,
+                      body, resultType);
+  }
+
+  /// Return the name of the rewrite.
+  const Name &getName() const { return *Decl::getName(); }
+
+  /// Return the input arguments of this rewrite.
+  MutableArrayRef<VariableDecl *> getInputs() {
+    return {getTrailingObjects<VariableDecl *>(), numInputs};
+  }
+  ArrayRef<VariableDecl *> getInputs() const {
+    return const_cast<UserRewriteDecl *>(this)->getInputs();
+  }
+
+  /// Return the explicit results of the rewrite declaration. May be empty,
+  /// even if the rewrite has results (e.g. in the case of inferred results).
+  MutableArrayRef<VariableDecl *> getResults() {
+    return {getTrailingObjects<VariableDecl *>() + numInputs, numResults};
+  }
+  ArrayRef<VariableDecl *> getResults() const {
+    return const_cast<UserRewriteDecl *>(this)->getResults();
+  }
+
+  /// Return the optional code block of this rewrite, if this is a native
+  /// rewrite with a provided implementation.
+  Optional<StringRef> getCodeBlock() const { return codeBlock; }
+
+  /// Return the body of this rewrite if this rewrite is a PDLL rewrite,
+  /// otherwise returns nullptr.
+  const CompoundStmt *getBody() const { return rewriteBody; }
+
+  /// Return the result type of this rewrite.
+  Type getResultType() const { return resultType; }
+
+  /// Returns true if this rewrite is external (no PDLL body or code block).
+  bool isExternal() const { return !rewriteBody && !codeBlock; }
+
+private:
+  /// Create either a PDLL rewrite or a native rewrite with the given
+  /// components.
+  static UserRewriteDecl *createImpl(Context &ctx, const Name &name,
+                                     ArrayRef<VariableDecl *> inputs,
+                                     ArrayRef<VariableDecl *> results,
+                                     Optional<StringRef> codeBlock,
+                                     const CompoundStmt *body, Type resultType);
+
+  UserRewriteDecl(const Name &name, unsigned numInputs, unsigned numResults,
+                  Optional<StringRef> codeBlock, const CompoundStmt *body,
+                  Type resultType)
+      : Base(name.getLoc(), &name), numInputs(numInputs),
+        numResults(numResults), codeBlock(codeBlock), rewriteBody(body),
+        resultType(resultType) {}
+
+  /// The number of inputs to this rewrite.
+  unsigned numInputs;
+
+  /// The number of explicit results of this rewrite.
+  unsigned numResults;
+
+  /// The optional code block of this rewrite.
+  Optional<StringRef> codeBlock;
+
+  /// The optional body of this rewrite.
+  const CompoundStmt *rewriteBody;
+
+  /// The result type of the rewrite.
+  Type resultType;
+
+  /// Allow access to various internals.
+  friend llvm::TrailingObjects<UserRewriteDecl, VariableDecl *>;
+};
+
+//===----------------------------------------------------------------------===//
+// CallableDecl
+//===----------------------------------------------------------------------===//
+
+/// This decl represents a shared interface for all callable decls.
+class CallableDecl : public Decl {
+public:
+  /// Return the callable type of this decl.
+  StringRef getCallableType() const {
+    if (isa<UserConstraintDecl>(this))
+      return "constraint";
+    assert(isa<UserRewriteDecl>(this) && "unknown callable type");
+    return "rewrite";
+  }
+
+  /// Return the inputs of this decl.
+  ArrayRef<VariableDecl *> getInputs() const {
+    if (const auto *cst = dyn_cast<UserConstraintDecl>(this))
+      return cst->getInputs();
+    return cast<UserRewriteDecl>(this)->getInputs();
+  }
+
+  /// Return the result type of this decl.
+  Type getResultType() const {
+    if (const auto *cst = dyn_cast<UserConstraintDecl>(this))
+      return cst->getResultType();
+    return cast<UserRewriteDecl>(this)->getResultType();
+  }
+
+  /// Support LLVM type casting facilities.
+  static bool classof(const Node *decl) {
+    return isa<UserConstraintDecl, UserRewriteDecl>(decl);
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // VariableDecl
 //===----------------------------------------------------------------------===//
@@ -912,11 +1225,11 @@ class Module final : public Node::NodeBase<Module, Node>,
 
 inline bool Decl::classof(const Node *node) {
   return isa<ConstraintDecl, NamedAttributeDecl, OpNameDecl, PatternDecl,
-             VariableDecl>(node);
+             UserRewriteDecl, VariableDecl>(node);
 }
 
 inline bool ConstraintDecl::classof(const Node *node) {
-  return isa<CoreConstraintDecl>(node);
+  return isa<CoreConstraintDecl, UserConstraintDecl>(node);
 }
 
 inline bool CoreConstraintDecl::classof(const Node *node) {

diff  --git a/mlir/include/mlir/Tools/PDLL/AST/Types.h b/mlir/include/mlir/Tools/PDLL/AST/Types.h
index cac3cae962c2d..58a20801fb1b5 100644
--- a/mlir/include/mlir/Tools/PDLL/AST/Types.h
+++ b/mlir/include/mlir/Tools/PDLL/AST/Types.h
@@ -22,6 +22,7 @@ struct AttributeTypeStorage;
 struct ConstraintTypeStorage;
 struct OperationTypeStorage;
 struct RangeTypeStorage;
+struct RewriteTypeStorage;
 struct TupleTypeStorage;
 struct TypeTypeStorage;
 struct ValueTypeStorage;
@@ -203,6 +204,20 @@ class ValueRangeType : public RangeType {
   static ValueRangeType get(Context &context);
 };
 
+//===----------------------------------------------------------------------===//
+// RewriteType
+//===----------------------------------------------------------------------===//
+
+/// This class represents a PDLL type that corresponds to a rewrite reference.
+/// This type has no MLIR C++ API correspondence.
+class RewriteType : public Type::TypeBase<detail::RewriteTypeStorage> {
+public:
+  using Base::Base;
+
+  /// Return an instance of the Rewrite type.
+  static RewriteType get(Context &context);
+};
+
 //===----------------------------------------------------------------------===//
 // TupleType
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Tools/PDLL/AST/Context.cpp b/mlir/lib/Tools/PDLL/AST/Context.cpp
index 960870aabf436..09ae0e6ad6e07 100644
--- a/mlir/lib/Tools/PDLL/AST/Context.cpp
+++ b/mlir/lib/Tools/PDLL/AST/Context.cpp
@@ -15,6 +15,7 @@ using namespace mlir::pdll::ast;
 Context::Context() {
   typeUniquer.registerSingletonStorageType<detail::AttributeTypeStorage>();
   typeUniquer.registerSingletonStorageType<detail::ConstraintTypeStorage>();
+  typeUniquer.registerSingletonStorageType<detail::RewriteTypeStorage>();
   typeUniquer.registerSingletonStorageType<detail::TypeTypeStorage>();
   typeUniquer.registerSingletonStorageType<detail::ValueTypeStorage>();
 

diff  --git a/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp b/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp
index cc7bdd47bf2c5..2be71b67f5292 100644
--- a/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp
+++ b/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp
@@ -76,9 +76,11 @@ class NodePrinter {
   void printImpl(const EraseStmt *stmt);
   void printImpl(const LetStmt *stmt);
   void printImpl(const ReplaceStmt *stmt);
+  void printImpl(const ReturnStmt *stmt);
   void printImpl(const RewriteStmt *stmt);
 
   void printImpl(const AttributeExpr *expr);
+  void printImpl(const CallExpr *expr);
   void printImpl(const DeclRefExpr *expr);
   void printImpl(const MemberAccessExpr *expr);
   void printImpl(const OperationExpr *expr);
@@ -89,11 +91,13 @@ class NodePrinter {
   void printImpl(const OpConstraintDecl *decl);
   void printImpl(const TypeConstraintDecl *decl);
   void printImpl(const TypeRangeConstraintDecl *decl);
+  void printImpl(const UserConstraintDecl *decl);
   void printImpl(const ValueConstraintDecl *decl);
   void printImpl(const ValueRangeConstraintDecl *decl);
   void printImpl(const NamedAttributeDecl *decl);
   void printImpl(const OpNameDecl *decl);
   void printImpl(const PatternDecl *decl);
+  void printImpl(const UserRewriteDecl *decl);
   void printImpl(const VariableDecl *decl);
   void printImpl(const Module *module);
 
@@ -135,6 +139,7 @@ void NodePrinter::print(Type type) {
         print(type.getElementType());
         os << "Range";
       })
+      .Case([&](RewriteType) { os << "Rewrite"; })
       .Case([&](TupleType type) {
         os << "Tuple<";
         llvm::interleaveComma(
@@ -160,17 +165,19 @@ void NodePrinter::print(const Node *node) {
       .Case<
           // Statements.
           const CompoundStmt, const EraseStmt, const LetStmt, const ReplaceStmt,
-          const RewriteStmt,
+          const ReturnStmt, const RewriteStmt,
 
           // Expressions.
-          const AttributeExpr, const DeclRefExpr, const MemberAccessExpr,
-          const OperationExpr, const TupleExpr, const TypeExpr,
+          const AttributeExpr, const CallExpr, const DeclRefExpr,
+          const MemberAccessExpr, const OperationExpr, const TupleExpr,
+          const TypeExpr,
 
           // Decls.
           const AttrConstraintDecl, const OpConstraintDecl,
           const TypeConstraintDecl, const TypeRangeConstraintDecl,
-          const ValueConstraintDecl, const ValueRangeConstraintDecl,
-          const NamedAttributeDecl, const OpNameDecl, const PatternDecl,
+          const UserConstraintDecl, const ValueConstraintDecl,
+          const ValueRangeConstraintDecl, const NamedAttributeDecl,
+          const OpNameDecl, const PatternDecl, const UserRewriteDecl,
           const VariableDecl,
 
           const Module>([&](auto derivedNode) { this->printImpl(derivedNode); })
@@ -199,6 +206,11 @@ void NodePrinter::printImpl(const ReplaceStmt *stmt) {
   printChildren("ReplValues", stmt->getReplExprs());
 }
 
+void NodePrinter::printImpl(const ReturnStmt *stmt) {
+  os << "ReturnStmt " << stmt << "\n";
+  printChildren(stmt->getResultExpr());
+}
+
 void NodePrinter::printImpl(const RewriteStmt *stmt) {
   os << "RewriteStmt " << stmt << "\n";
   printChildren(stmt->getRootOpExpr(), stmt->getRewriteBody());
@@ -208,6 +220,14 @@ void NodePrinter::printImpl(const AttributeExpr *expr) {
   os << "AttributeExpr " << expr << " Value<\"" << expr->getValue() << "\">\n";
 }
 
+void NodePrinter::printImpl(const CallExpr *expr) {
+  os << "CallExpr " << expr << " Type<";
+  print(expr->getType());
+  os << ">\n";
+  printChildren(expr->getCallableExpr());
+  printChildren("Arguments", expr->getArguments());
+}
+
 void NodePrinter::printImpl(const DeclRefExpr *expr) {
   os << "DeclRefExpr " << expr << " Type<";
   print(expr->getType());
@@ -265,6 +285,21 @@ void NodePrinter::printImpl(const TypeRangeConstraintDecl *decl) {
   os << "TypeRangeConstraintDecl " << decl << "\n";
 }
 
+void NodePrinter::printImpl(const UserConstraintDecl *decl) {
+  os << "UserConstraintDecl " << decl << " Name<" << decl->getName().getName()
+     << "> ResultType<" << decl->getResultType() << ">";
+  if (Optional<StringRef> codeBlock = decl->getCodeBlock()) {
+    os << " Code<";
+    llvm::printEscapedString(*codeBlock, os);
+    os << ">";
+  }
+  os << "\n";
+  printChildren("Inputs", decl->getInputs());
+  printChildren("Results", decl->getResults());
+  if (const CompoundStmt *body = decl->getBody())
+    printChildren(body);
+}
+
 void NodePrinter::printImpl(const ValueConstraintDecl *decl) {
   os << "ValueConstraintDecl " << decl << "\n";
   if (const auto *typeExpr = decl->getTypeExpr())
@@ -303,6 +338,21 @@ void NodePrinter::printImpl(const PatternDecl *decl) {
   printChildren(decl->getBody());
 }
 
+void NodePrinter::printImpl(const UserRewriteDecl *decl) {
+  os << "UserRewriteDecl " << decl << " Name<" << decl->getName().getName()
+     << "> ResultType<" << decl->getResultType() << ">";
+  if (Optional<StringRef> codeBlock = decl->getCodeBlock()) {
+    os << " Code<";
+    llvm::printEscapedString(*codeBlock, os);
+    os << ">";
+  }
+  os << "\n";
+  printChildren("Inputs", decl->getInputs());
+  printChildren("Results", decl->getResults());
+  if (const CompoundStmt *body = decl->getBody())
+    printChildren(body);
+}
+
 void NodePrinter::printImpl(const VariableDecl *decl) {
   os << "VariableDecl " << decl << " Name<" << decl->getName().getName()
      << "> Type<";

diff  --git a/mlir/lib/Tools/PDLL/AST/Nodes.cpp b/mlir/lib/Tools/PDLL/AST/Nodes.cpp
index 76f42f5329fd4..2e7b252b35c86 100644
--- a/mlir/lib/Tools/PDLL/AST/Nodes.cpp
+++ b/mlir/lib/Tools/PDLL/AST/Nodes.cpp
@@ -108,6 +108,15 @@ RewriteStmt *RewriteStmt::create(Context &ctx, SMRange loc, Expr *rootOp,
       RewriteStmt(loc, rootOp, rewriteBody);
 }
 
+//===----------------------------------------------------------------------===//
+// ReturnStmt
+//===----------------------------------------------------------------------===//
+
+ReturnStmt *ReturnStmt::create(Context &ctx, SMRange loc, Expr *resultExpr) {
+  return new (ctx.getAllocator().Allocate<ReturnStmt>())
+      ReturnStmt(loc, resultExpr);
+}
+
 //===----------------------------------------------------------------------===//
 // AttributeExpr
 //===----------------------------------------------------------------------===//
@@ -118,6 +127,22 @@ AttributeExpr *AttributeExpr::create(Context &ctx, SMRange loc,
       AttributeExpr(ctx, loc, copyStringWithNull(ctx, value));
 }
 
+//===----------------------------------------------------------------------===//
+// CallExpr
+//===----------------------------------------------------------------------===//
+
+CallExpr *CallExpr::create(Context &ctx, SMRange loc, Expr *callable,
+                           ArrayRef<Expr *> arguments, Type resultType) {
+  unsigned allocSize = CallExpr::totalSizeToAlloc<Expr *>(arguments.size());
+  void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(CallExpr));
+
+  CallExpr *expr =
+      new (rawData) CallExpr(loc, resultType, callable, arguments.size());
+  std::uninitialized_copy(arguments.begin(), arguments.end(),
+                          expr->getArguments().begin());
+  return expr;
+}
+
 //===----------------------------------------------------------------------===//
 // DeclRefExpr
 //===----------------------------------------------------------------------===//
@@ -267,6 +292,30 @@ ValueRangeConstraintDecl *ValueRangeConstraintDecl::create(Context &ctx,
       ValueRangeConstraintDecl(loc, typeExpr);
 }
 
+//===----------------------------------------------------------------------===//
+// UserConstraintDecl
+//===----------------------------------------------------------------------===//
+
+UserConstraintDecl *UserConstraintDecl::createImpl(
+    Context &ctx, const Name &name, ArrayRef<VariableDecl *> inputs,
+    ArrayRef<VariableDecl *> results, Optional<StringRef> codeBlock,
+    const CompoundStmt *body, Type resultType) {
+  unsigned allocSize = UserConstraintDecl::totalSizeToAlloc<VariableDecl *>(
+      inputs.size() + results.size());
+  void *rawData =
+      ctx.getAllocator().Allocate(allocSize, alignof(UserConstraintDecl));
+  if (codeBlock)
+    codeBlock = codeBlock->copy(ctx.getAllocator());
+
+  UserConstraintDecl *decl = new (rawData) UserConstraintDecl(
+      name, inputs.size(), results.size(), codeBlock, body, resultType);
+  std::uninitialized_copy(inputs.begin(), inputs.end(),
+                          decl->getInputs().begin());
+  std::uninitialized_copy(results.begin(), results.end(),
+                          decl->getResults().begin());
+  return decl;
+}
+
 //===----------------------------------------------------------------------===//
 // NamedAttributeDecl
 //===----------------------------------------------------------------------===//
@@ -300,6 +349,32 @@ PatternDecl *PatternDecl::create(Context &ctx, SMRange loc,
       PatternDecl(loc, name, benefit, hasBoundedRecursion, body);
 }
 
+//===----------------------------------------------------------------------===//
+// UserRewriteDecl
+//===----------------------------------------------------------------------===//
+
+UserRewriteDecl *UserRewriteDecl::createImpl(Context &ctx, const Name &name,
+                                             ArrayRef<VariableDecl *> inputs,
+                                             ArrayRef<VariableDecl *> results,
+                                             Optional<StringRef> codeBlock,
+                                             const CompoundStmt *body,
+                                             Type resultType) {
+  unsigned allocSize = UserRewriteDecl::totalSizeToAlloc<VariableDecl *>(
+      inputs.size() + results.size());
+  void *rawData =
+      ctx.getAllocator().Allocate(allocSize, alignof(UserRewriteDecl));
+  if (codeBlock)
+    codeBlock = codeBlock->copy(ctx.getAllocator());
+
+  UserRewriteDecl *decl = new (rawData) UserRewriteDecl(
+      name, inputs.size(), results.size(), codeBlock, body, resultType);
+  std::uninitialized_copy(inputs.begin(), inputs.end(),
+                          decl->getInputs().begin());
+  std::uninitialized_copy(results.begin(), results.end(),
+                          decl->getResults().begin());
+  return decl;
+}
+
 //===----------------------------------------------------------------------===//
 // VariableDecl
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Tools/PDLL/AST/TypeDetail.h b/mlir/lib/Tools/PDLL/AST/TypeDetail.h
index b8615175400ae..e6719fb961216 100644
--- a/mlir/lib/Tools/PDLL/AST/TypeDetail.h
+++ b/mlir/lib/Tools/PDLL/AST/TypeDetail.h
@@ -93,6 +93,12 @@ struct RangeTypeStorage : public TypeStorageBase<RangeTypeStorage, Type> {
   using Base::Base;
 };
 
+//===----------------------------------------------------------------------===//
+// RewriteType
+//===----------------------------------------------------------------------===//
+
+struct RewriteTypeStorage : public TypeStorageBase<RewriteTypeStorage> {};
+
 //===----------------------------------------------------------------------===//
 // TupleType
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Tools/PDLL/AST/Types.cpp b/mlir/lib/Tools/PDLL/AST/Types.cpp
index ba9b31f85adf3..cf0f0e918870b 100644
--- a/mlir/lib/Tools/PDLL/AST/Types.cpp
+++ b/mlir/lib/Tools/PDLL/AST/Types.cpp
@@ -107,6 +107,14 @@ ValueRangeType ValueRangeType::get(Context &context) {
       .cast<ValueRangeType>();
 }
 
+//===----------------------------------------------------------------------===//
+// RewriteType
+//===----------------------------------------------------------------------===//
+
+RewriteType RewriteType::get(Context &context) {
+  return context.getTypeUniquer().get<ImplTy>();
+}
+
 //===----------------------------------------------------------------------===//
 // TupleType
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Tools/PDLL/Parser/Lexer.cpp b/mlir/lib/Tools/PDLL/Parser/Lexer.cpp
index 3db5cfbdfd645..61c5783e24545 100644
--- a/mlir/lib/Tools/PDLL/Parser/Lexer.cpp
+++ b/mlir/lib/Tools/PDLL/Parser/Lexer.cpp
@@ -298,7 +298,9 @@ Token Lexer::lexIdentifier(const char *tokStart) {
                          .Case("OpName", Token::kw_OpName)
                          .Case("Pattern", Token::kw_Pattern)
                          .Case("replace", Token::kw_replace)
+                         .Case("return", Token::kw_return)
                          .Case("rewrite", Token::kw_rewrite)
+                         .Case("Rewrite", Token::kw_Rewrite)
                          .Case("type", Token::kw_type)
                          .Case("Type", Token::kw_Type)
                          .Case("TypeRange", Token::kw_TypeRange)

diff --git a/mlir/lib/Tools/PDLL/Parser/Lexer.h b/mlir/lib/Tools/PDLL/Parser/Lexer.h
index 4692f28ba877c..0109b0da36c79 100644
--- a/mlir/lib/Tools/PDLL/Parser/Lexer.h
+++ b/mlir/lib/Tools/PDLL/Parser/Lexer.h
@@ -55,7 +55,9 @@ class Token {
     kw_OpName,
     kw_Pattern,
     kw_replace,
+    kw_return,
     kw_rewrite,
+    kw_Rewrite,
     kw_Type,
     kw_TypeRange,
     kw_Value,

diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
index 5264b953aaabc..5ed2481027800 100644
--- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp
+++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
@@ -50,6 +50,9 @@ class Parser {
   enum class ParserContext {
     /// The parser is in the global context.
     Global,
+    /// The parser is currently within a Constraint, which disallows all types
+    /// of rewrites (e.g. `erase`, `replace`, calls to Rewrites, etc.).
+    Constraint,
     /// The parser is currently within the matcher portion of a Pattern, which
     /// is allows a terminal operation rewrite statement but no other rewrite
     /// transformations.
@@ -106,6 +109,77 @@ class Parser {
 
   FailureOr<ast::Decl *> parseTopLevelDecl();
   FailureOr<ast::NamedAttributeDecl *> parseNamedAttributeDecl();
+
+  /// Parse an argument variable as part of the signature of a
+  /// UserConstraintDecl or UserRewriteDecl.
+  FailureOr<ast::VariableDecl *> parseArgumentDecl();
+
+  /// Parse a result variable as part of the signature of a UserConstraintDecl
+  /// or UserRewriteDecl.
+  FailureOr<ast::VariableDecl *> parseResultDecl(unsigned resultNum);
+
+  /// Parse a UserConstraintDecl. `isInline` signals if the constraint is being
+  /// defined in a non-global context.
+  FailureOr<ast::UserConstraintDecl *>
+  parseUserConstraintDecl(bool isInline = false);
+
+  /// Parse an inline UserConstraintDecl. An inline decl is one defined in a
+  /// non-global context, such as within a Pattern/Constraint/etc.
+  FailureOr<ast::UserConstraintDecl *> parseInlineUserConstraintDecl();
+
+  /// Parse a PDLL (i.e. non-native) UserConstraintDecl whose body is defined
+  /// using PDLL constructs.
+  FailureOr<ast::UserConstraintDecl *> parseUserPDLLConstraintDecl(
+      const ast::Name &name, bool isInline,
+      ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
+      ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
+
+  /// Parse a UserRewriteDecl. `isInline` signals if the rewrite is being
+  /// defined in a non-global context.
+  FailureOr<ast::UserRewriteDecl *> parseUserRewriteDecl(bool isInline = false);
+
+  /// Parse an inline UserRewriteDecl. An inline decl is one defined in a
+  /// non-global context, such as within a Pattern/Rewrite/etc.
+  FailureOr<ast::UserRewriteDecl *> parseInlineUserRewriteDecl();
+
+  /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using
+  /// PDLL constructs.
+  FailureOr<ast::UserRewriteDecl *> parseUserPDLLRewriteDecl(
+      const ast::Name &name, bool isInline,
+      ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
+      ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
+
+  /// Parse either a UserConstraintDecl or UserRewriteDecl. These decls have
+  /// effectively the same syntax, and only differ on slight semantics (given
+  /// the different parsing contexts).
+  template <typename T, typename ParseUserPDLLDeclFnT>
+  FailureOr<T *> parseUserConstraintOrRewriteDecl(
+      ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
+      StringRef anonymousNamePrefix, bool isInline);
+
+  /// Parse a native (i.e. non-PDLL) UserConstraintDecl or UserRewriteDecl.
+  /// These decls have effectively the same syntax.
+  template <typename T>
+  FailureOr<T *> parseUserNativeConstraintOrRewriteDecl(
+      const ast::Name &name, bool isInline,
+      ArrayRef<ast::VariableDecl *> arguments,
+      ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
+
+  /// Parse the functional signature (i.e. the arguments and results) of a
+  /// UserConstraintDecl or UserRewriteDecl.
+  LogicalResult parseUserConstraintOrRewriteSignature(
+      SmallVectorImpl<ast::VariableDecl *> &arguments,
+      SmallVectorImpl<ast::VariableDecl *> &results,
+      ast::DeclScope *&argumentScope, ast::Type &resultType);
+
+  /// Validate the return (which if present is specified by bodyIt) of a
+  /// UserConstraintDecl or UserRewriteDecl.
+  LogicalResult validateUserConstraintOrRewriteReturn(
+      StringRef declType, ast::CompoundStmt *body,
+      ArrayRef<ast::Stmt *>::iterator bodyIt,
+      ArrayRef<ast::Stmt *>::iterator bodyE,
+      ArrayRef<ast::VariableDecl *> results, ast::Type &resultType);
+
   FailureOr<ast::CompoundStmt *>
   parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
                   bool expectTerminalSemicolon = true);
@@ -138,10 +212,17 @@ class Parser {
   /// location of a previously parsed type constraint for the entity that will
   /// be constrained by the parsed constraint. `existingConstraints` are any
   /// existing constraints that have already been parsed for the same entity
-  /// that will be constrained by this constraint.
+  /// that will be constrained by this constraint. `allowInlineTypeConstraints`
+  /// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`.
   FailureOr<ast::ConstraintRef>
   parseConstraint(Optional<SMRange> &typeConstraint,
-                  ArrayRef<ast::ConstraintRef> existingConstraints);
+                  ArrayRef<ast::ConstraintRef> existingConstraints,
+                  bool allowInlineTypeConstraints);
+
+  /// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl
+  /// argument or result variable. The constraints for these variables do not
+  /// allow inline type constraints, and only permit a single constraint.
+  FailureOr<ast::ConstraintRef> parseArgOrResultConstraint();
 
   //===--------------------------------------------------------------------===//
   // Exprs
@@ -150,8 +231,11 @@ class Parser {
 
   /// Identifier expressions.
   FailureOr<ast::Expr *> parseAttributeExpr();
+  FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr);
   FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc);
   FailureOr<ast::Expr *> parseIdentifierExpr();
+  FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr();
+  FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr();
   FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr);
   FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false);
   FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName);
@@ -168,6 +252,7 @@ class Parser {
   FailureOr<ast::EraseStmt *> parseEraseStmt();
   FailureOr<ast::LetStmt *> parseLetStmt();
   FailureOr<ast::ReplaceStmt *> parseReplaceStmt();
+  FailureOr<ast::ReturnStmt *> parseReturnStmt();
   FailureOr<ast::RewriteStmt *> parseRewriteStmt();
 
   //===--------------------------------------------------------------------===//
@@ -177,6 +262,10 @@ class Parser {
   //===--------------------------------------------------------------------===//
   // Decls
 
+  /// Try to extract a callable from the given AST node. Returns nullptr on
+  /// failure.
+  ast::CallableDecl *tryExtractCallableDecl(ast::Node *node);
+
   /// Try to create a pattern decl with the given components, returning the
   /// Pattern on success.
   FailureOr<ast::PatternDecl *>
@@ -184,12 +273,30 @@ class Parser {
                     const ParsedPatternMetadata &metadata,
                     ast::CompoundStmt *body);
 
+  /// Build the result type for a UserConstraintDecl/UserRewriteDecl given a set
+  /// of results, defined as part of the signature.
+  ast::Type
+  createUserConstraintRewriteResultType(ArrayRef<ast::VariableDecl *> results);
+
+  /// Create a PDLL (i.e. non-native) UserConstraintDecl or UserRewriteDecl.
+  template <typename T>
+  FailureOr<T *> createUserPDLLConstraintOrRewriteDecl(
+      const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
+      ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
+      ast::CompoundStmt *body);
+
   /// Try to create a variable decl with the given components, returning the
   /// Variable on success.
   FailureOr<ast::VariableDecl *>
   createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
                      ArrayRef<ast::ConstraintRef> constraints);
 
+  /// Create a variable for an argument or result defined as part of the
+  /// signature of a UserConstraintDecl/UserRewriteDecl.
+  FailureOr<ast::VariableDecl *>
+  createArgOrResultVariableDecl(StringRef name, SMRange loc,
+                                const ast::ConstraintRef &constraint);
+
   /// Validate the constraints used to constraint a variable decl.
   /// `inferredType` is the type of the variable inferred by the constraints
   /// within the list, and is updated to the most refined type as determined by
@@ -201,23 +308,26 @@ class Parser {
   /// Validate a single reference to a constraint. `inferredType` contains the
   /// currently inferred variabled type and is refined within the type defined
   /// by the constraint. Returns success if the constraint is valid, failure
-  /// otherwise.
+  /// otherwise. If `allowNonCoreConstraints` is true, then complex constraints
+  /// (e.g. user defined constraints) may be used with the variable.
   LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref,
-                                           ast::Type &inferredType);
+                                           ast::Type &inferredType,
+                                           bool allowNonCoreConstraints = true);
   LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr);
   LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr);
 
   //===--------------------------------------------------------------------===//
   // Exprs
 
-  FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc,
-                                                  ast::Decl *decl);
+  FailureOr<ast::CallExpr *>
+  createCallExpr(SMRange loc, ast::Expr *parentExpr,
+                 MutableArrayRef<ast::Expr *> arguments);
+  FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl);
   FailureOr<ast::DeclRefExpr *>
   createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
                            ArrayRef<ast::ConstraintRef> constraints);
   FailureOr<ast::MemberAccessExpr *>
-  createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
-                         SMRange loc);
+  createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc);
 
   /// Validate the member access `name` into the given parent expression. On
   /// success, this also returns the type of the member accessed.
@@ -231,12 +341,10 @@ class Parser {
   LogicalResult
   validateOperationOperands(SMRange loc, Optional<StringRef> name,
                             MutableArrayRef<ast::Expr *> operands);
-  LogicalResult validateOperationResults(SMRange loc,
-                                         Optional<StringRef> name,
+  LogicalResult validateOperationResults(SMRange loc, Optional<StringRef> name,
                                          MutableArrayRef<ast::Expr *> results);
   LogicalResult
-  validateOperationOperandsOrResults(SMRange loc,
-                                     Optional<StringRef> name,
+  validateOperationOperandsOrResults(SMRange loc, Optional<StringRef> name,
                                      MutableArrayRef<ast::Expr *> values,
                                      ast::Type singleTy, ast::Type rangeTy);
   FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc,
@@ -246,8 +354,7 @@ class Parser {
   //===--------------------------------------------------------------------===//
   // Stmts
 
-  FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc,
-                                              ast::Expr *rootOp);
+  FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc, ast::Expr *rootOp);
   FailureOr<ast::ReplaceStmt *>
   createReplaceStmt(SMRange loc, ast::Expr *rootOp,
                     MutableArrayRef<ast::Expr *> replValues);
@@ -304,8 +411,8 @@ class Parser {
   LogicalResult emitError(const Twine &msg) {
     return emitError(curToken.getLoc(), msg);
   }
-  LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg,
-                                 SMRange noteLoc, const Twine &note) {
+  LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc,
+                                 const Twine &note) {
     lexer.emitErrorAndNote(loc, msg, noteLoc, note);
     return failure();
   }
@@ -333,6 +440,9 @@ class Parser {
   /// Cached types to simplify verification and expression creation.
   ast::Type valueTy, valueRangeTy;
   ast::Type typeTy, typeRangeTy;
+
+  /// A counter used when naming anonymous constraints and rewrites.
+  unsigned anonymousDeclNameCounter = 0;
 };
 } // namespace
 
@@ -506,9 +616,15 @@ LogicalResult Parser::parseInclude(SmallVector<ast::Decl *> &decls) {
 FailureOr<ast::Decl *> Parser::parseTopLevelDecl() {
   FailureOr<ast::Decl *> decl;
   switch (curToken.getKind()) {
+  case Token::kw_Constraint:
+    decl = parseUserConstraintDecl();
+    break;
   case Token::kw_Pattern:
     decl = parsePatternDecl();
     break;
+  case Token::kw_Rewrite:
+    decl = parseUserRewriteDecl();
+    break;
   default:
     return emitError("expected top-level declaration, such as a `Pattern`");
   }
@@ -570,6 +686,363 @@ FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody(
   return ast::CompoundStmt::create(ctx, bodyLoc, *singleStatement);
 }
 
+FailureOr<ast::VariableDecl *> Parser::parseArgumentDecl() {
+  // Every argument must be named; dependent keywords are accepted as names.
+  bool hasName =
+      curToken.is(Token::identifier) || curToken.isDependentKeyword();
+  if (!hasName)
+    return emitError("expected identifier argument name");
+
+  // The argument itself is written like a variable: `name : Constraint`.
+  SMRange nameLoc = curToken.getLoc();
+  StringRef name = curToken.getSpelling();
+  consumeToken();
+
+  LogicalResult colonResult =
+      parseToken(Token::colon, "expected `:` before argument constraint");
+  if (failed(colonResult))
+    return failure();
+
+  FailureOr<ast::ConstraintRef> constraint = parseArgOrResultConstraint();
+  if (failed(constraint))
+    return failure();
+  return createArgOrResultVariableDecl(name, nameLoc, *constraint);
+}
+
+/// Parse one result of a UserConstraintDecl/UserRewriteDecl signature. A
+/// result is written either as `name: Constraint` or as a bare `Constraint`.
+/// An identifier that resolves to an existing ConstraintDecl in the current
+/// scope is treated as the constraint of an unnamed result, not as a name.
+/// `resultNum` is not referenced by this implementation.
+FailureOr<ast::VariableDecl *> Parser::parseResultDecl(unsigned resultNum) {
+  // Check to see if this result is named.
+  if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
+    // Check to see if this name actually refers to a Constraint.
+    ast::Decl *existingDecl = curDeclScope->lookup(curToken.getSpelling());
+    if (isa_and_nonnull<ast::ConstraintDecl>(existingDecl)) {
+      // If yes, and this is a Rewrite, give a nice error message as non-Core
+      // constraints are not supported on Rewrite results.
+      if (parserContext == ParserContext::Rewrite) {
+        return emitError(
+            "`Rewrite` results are only permitted to use core constraints, "
+            "such as `Attr`, `Op`, `Type`, `TypeRange`, `Value`, `ValueRange`");
+      }
+
+      // Otherwise, parse this as an unnamed result variable.
+      // (Control falls through to the unnamed-result path below without
+      // consuming the identifier, so it is re-read as the constraint.)
+    } else {
+      // If it wasn't a constraint, parse the result similarly to a variable. If
+      // there is already an existing decl, we will emit an error when defining
+      // this variable later.
+      StringRef name = curToken.getSpelling();
+      SMRange nameLoc = curToken.getLoc();
+      consumeToken();
+
+      if (failed(parseToken(Token::colon,
+                            "expected `:` before result constraint")))
+        return failure();
+
+      FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
+      if (failed(cst))
+        return failure();
+
+      return createArgOrResultVariableDecl(name, nameLoc, *cst);
+    }
+  }
+
+  // If it isn't named, we parse the constraint directly and create an unnamed
+  // result variable.
+  FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
+  if (failed(cst))
+    return failure();
+
+  // Unnamed results get an empty name, located at the constraint reference.
+  return createArgOrResultVariableDecl("", cst->referenceLoc, *cst);
+}
+
+FailureOr<ast::UserConstraintDecl *>
+Parser::parseUserConstraintDecl(bool isInline) {
+  // Constraint and Rewrite decls share their syntax, so defer to the common
+  // parser and only supply the Constraint-specific PDLL body parser.
+  auto parsePDLLBody = [&](auto &&...args) {
+    return parseUserPDLLConstraintDecl(args...);
+  };
+  return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>(
+      parsePDLLBody, ParserContext::Constraint, "constraint", isInline);
+}
+
+FailureOr<ast::UserConstraintDecl *> Parser::parseInlineUserConstraintDecl() {
+  FailureOr<ast::UserConstraintDecl *> decl =
+      parseUserConstraintDecl(/*isInline=*/true);
+  if (failed(decl))
+    return failure();
+
+  // Check the name is definable here, then register the decl in the current
+  // scope so that later lookups within that scope can find it.
+  ast::UserConstraintDecl *constraintDecl = *decl;
+  if (failed(checkDefineNamedDecl(constraintDecl->getName())))
+    return failure();
+  curDeclScope->add(constraintDecl);
+  return constraintDecl;
+}
+
+/// Parse the PDLL body of a UserConstraintDecl whose signature has already
+/// been parsed, and build the decl. The body is either a compound block
+/// (`{ ... }`), whose `return` placement is validated against `results`, or a
+/// lambda body (`=> <expr>`) whose single expression becomes an implicit
+/// `return`. `argumentScope` is made current while the body is parsed.
+/// NOTE(review): the early `return failure()` paths skip popDeclScope();
+/// presumably fine because parsing does not continue after a failure.
+FailureOr<ast::UserConstraintDecl *> Parser::parseUserPDLLConstraintDecl(
+    const ast::Name &name, bool isInline,
+    ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
+    ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
+  // Push the argument scope back onto the list, so that the body can
+  // reference arguments.
+  pushDeclScope(argumentScope);
+
+  // Parse the body of the constraint. The body is either defined as a compound
+  // block, i.e. `{ ... }`, or a lambda body, i.e. `=> <expr>`.
+  ast::CompoundStmt *body;
+  if (curToken.is(Token::equal_arrow)) {
+    // Only non-inline decls expect a terminating `;` after the lambda body.
+    FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
+        [&](ast::Stmt *&stmt) -> LogicalResult {
+          // The single statement must be an expression; it is replaced in
+          // place by an implicit `return` of that expression.
+          ast::Expr *stmtExpr = dyn_cast<ast::Expr>(stmt);
+          if (!stmtExpr) {
+            return emitError(stmt->getLoc(),
+                             "expected `Constraint` lambda body to contain a "
+                             "single expression");
+          }
+          stmt = ast::ReturnStmt::create(ctx, stmt->getLoc(), stmtExpr);
+          return success();
+        },
+        /*expectTerminalSemicolon=*/!isInline);
+    if (failed(bodyResult))
+      return failure();
+    body = *bodyResult;
+  } else {
+    FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
+    if (failed(bodyResult))
+      return failure();
+    body = *bodyResult;
+
+    // Verify the structure of the body.
+    // `bodyIt` stops at the first `return` statement, or reaches `bodyE`.
+    auto bodyIt = body->begin(), bodyE = body->end();
+    for (; bodyIt != bodyE; ++bodyIt)
+      if (isa<ast::ReturnStmt>(*bodyIt))
+        break;
+    if (failed(validateUserConstraintOrRewriteReturn(
+            "Constraint", body, bodyIt, bodyE, results, resultType)))
+      return failure();
+  }
+  popDeclScope();
+
+  return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>(
+      name, arguments, results, resultType, body);
+}
+
+FailureOr<ast::UserRewriteDecl *> Parser::parseUserRewriteDecl(bool isInline) {
+  // Rewrite and Constraint decls share their syntax, so defer to the common
+  // parser and only supply the Rewrite-specific PDLL body parser.
+  auto parsePDLLBody = [&](auto &&...args) {
+    return parseUserPDLLRewriteDecl(args...);
+  };
+  return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>(
+      parsePDLLBody, ParserContext::Rewrite, "rewrite", isInline);
+}
+
+FailureOr<ast::UserRewriteDecl *> Parser::parseInlineUserRewriteDecl() {
+  FailureOr<ast::UserRewriteDecl *> decl =
+      parseUserRewriteDecl(/*isInline=*/true);
+  if (failed(decl))
+    return failure();
+
+  // Check the name is definable here, then register the decl in the current
+  // scope so that later lookups within that scope can find it.
+  ast::UserRewriteDecl *rewriteDecl = *decl;
+  if (failed(checkDefineNamedDecl(rewriteDecl->getName())))
+    return failure();
+  curDeclScope->add(rewriteDecl);
+  return rewriteDecl;
+}
+
+/// Parse the PDLL body of a UserRewriteDecl whose signature has already been
+/// parsed, and build the decl. The body is either a compound block
+/// (`{ ... }`) or a lambda body (`=> ...`). A lambda body holds either a
+/// single operation rewrite statement (`erase`, `replace`, `rewrite`) or a
+/// single expression, which is replaced by an implicit `return`. In both
+/// forms the placement of any `return` is validated against `results`.
+FailureOr<ast::UserRewriteDecl *> Parser::parseUserPDLLRewriteDecl(
+    const ast::Name &name, bool isInline,
+    ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
+    ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
+  // Push the argument scope back onto the list, so that the body can
+  // reference arguments. This mirrors the Constraint body parser, and is
+  // balanced by the popDeclScope() after the body has been parsed.
+  pushDeclScope(argumentScope);
+  ast::CompoundStmt *body;
+  if (curToken.is(Token::equal_arrow)) {
+    // Only non-inline decls expect a terminating `;` after the lambda body.
+    FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
+        [&](ast::Stmt *&statement) -> LogicalResult {
+          // Operation rewrite statements are kept as-is.
+          if (isa<ast::OpRewriteStmt>(statement))
+            return success();
+
+          // Otherwise the statement must be an expression, which becomes an
+          // implicit `return` of that expression.
+          ast::Expr *statementExpr = dyn_cast<ast::Expr>(statement);
+          if (!statementExpr) {
+            return emitError(
+                statement->getLoc(),
+                "expected `Rewrite` lambda body to contain a single expression "
+                "or an operation rewrite statement; such as `erase`, "
+                "`replace`, or `rewrite`");
+          }
+          statement =
+              ast::ReturnStmt::create(ctx, statement->getLoc(), statementExpr);
+          return success();
+        },
+        /*expectTerminalSemicolon=*/!isInline);
+    if (failed(bodyResult))
+      return failure();
+    body = *bodyResult;
+  } else {
+    FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
+    if (failed(bodyResult))
+      return failure();
+    body = *bodyResult;
+  }
+  popDeclScope();
+
+  // Verify the structure of the body.
+  // `bodyIt` stops at the first `return` statement, or reaches `bodyE`.
+  auto bodyIt = body->begin(), bodyE = body->end();
+  for (; bodyIt != bodyE; ++bodyIt)
+    if (isa<ast::ReturnStmt>(*bodyIt))
+      break;
+  if (failed(validateUserConstraintOrRewriteReturn("Rewrite", body, bodyIt,
+                                                   bodyE, results, resultType)))
+    return failure();
+  return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>(
+      name, arguments, results, resultType, body);
+}
+
+/// Shared parser for the `Constraint`/`Rewrite` decl syntax:
+///   (Constraint|Rewrite) [name] `(` args `)` [`->` results]
+///       (`{ ... }` | `=> ...` | [code-string] `;`)
+/// The leading keyword token is consumed here. `declContext` is installed as
+/// the parser context for the rest of the decl and restored on exit. The name
+/// may only be omitted when `isInline` is set; a unique name is then built
+/// from `anonymousNamePrefix`. A `{` or `=>` after the signature dispatches to
+/// `parseUserPDLLFn`; anything else is parsed as a native decl of type `T`.
+template <typename T, typename ParseUserPDLLDeclFnT>
+FailureOr<T *> Parser::parseUserConstraintOrRewriteDecl(
+    ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
+    StringRef anonymousNamePrefix, bool isInline) {
+  // Location of the `Constraint`/`Rewrite` keyword; used for anonymous names.
+  SMRange loc = curToken.getLoc();
+  consumeToken();
+  // Restores the previous parser context when this function returns.
+  llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, declContext);
+
+  // Parse the name of the decl.
+  const ast::Name *name = nullptr;
+  if (curToken.isNot(Token::identifier)) {
+    // Only inline decls can be un-named. Inline decls are similar to "lambdas"
+    // in C++, so being unnamed is fine.
+    if (!isInline)
+      return emitError("expected identifier name");
+
+    // Create a unique anonymous name to use, as the name for this decl is not
+    // important.
+    std::string anonName =
+        llvm::formatv("<anonymous_{0}_{1}>", anonymousNamePrefix,
+                      anonymousDeclNameCounter++)
+            .str();
+    name = &ast::Name::create(ctx, anonName, loc);
+  } else {
+    // If a name was provided, we can use it directly.
+    name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
+    consumeToken(Token::identifier);
+  }
+
+  // Parse the functional signature of the decl.
+  SmallVector<ast::VariableDecl *> arguments, results;
+  ast::DeclScope *argumentScope;
+  ast::Type resultType;
+  if (failed(parseUserConstraintOrRewriteSignature(arguments, results,
+                                                   argumentScope, resultType)))
+    return failure();
+
+  // Check to see which type of constraint this is. If the constraint contains a
+  // compound body, this is a PDLL decl.
+  if (curToken.isAny(Token::l_brace, Token::equal_arrow))
+    return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results,
+                           resultType);
+
+  // Otherwise, this is a native decl.
+  return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments,
+                                                   results, resultType);
+}
+
+template <typename T>
+FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl(
+    const ast::Name &name, bool isInline,
+    ArrayRef<ast::VariableDecl *> arguments,
+    ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
+  // A native decl may carry its implementation as a trailing code string.
+  // Without one it is an external declaration, which is only accepted when
+  // the decl is not inline.
+  Optional<StringRef> optCodeStr;
+  std::string codeStrStorage;
+  if (!curToken.isString()) {
+    if (isInline)
+      return emitError(name.getLoc(),
+                       "external declarations must be declared in global scope");
+  } else {
+    codeStrStorage = curToken.getStringValue();
+    optCodeStr = StringRef(codeStrStorage);
+    consumeToken();
+  }
+
+  LogicalResult semicolonResult =
+      parseToken(Token::semicolon, "expected `;` after native declaration");
+  if (failed(semicolonResult))
+    return failure();
+  return T::createNative(ctx, name, arguments, results, optCodeStr, resultType);
+}
+
+/// Parse a decl signature: `(arg: Cst, ...)` optionally followed by
+/// `-> Result` or `-> (Result, ...)`. Arguments are defined in a new decl
+/// scope that is returned through `argumentScope`; it is popped here and is
+/// expected to be pushed again by the body parser. Results are defined in a
+/// separate scope that is pushed and popped here. `resultType` is computed
+/// from the parsed results. Fails if a lone result is given a name.
+/// NOTE(review): `-> ()` is rejected, because at least one result is parsed
+/// after the `(`.
+LogicalResult Parser::parseUserConstraintOrRewriteSignature(
+    SmallVectorImpl<ast::VariableDecl *> &arguments,
+    SmallVectorImpl<ast::VariableDecl *> &results,
+    ast::DeclScope *&argumentScope, ast::Type &resultType) {
+  // Parse the argument list of the decl.
+  if (failed(parseToken(Token::l_paren, "expected `(` to start argument list")))
+    return failure();
+
+  argumentScope = pushDeclScope();
+  if (curToken.isNot(Token::r_paren)) {
+    do {
+      FailureOr<ast::VariableDecl *> argument = parseArgumentDecl();
+      if (failed(argument))
+        return failure();
+      arguments.emplace_back(*argument);
+    } while (consumeIf(Token::comma));
+  }
+  popDeclScope();
+  if (failed(parseToken(Token::r_paren, "expected `)` to end argument list")))
+    return failure();
+
+  // Parse the results of the decl.
+  // Results are defined in their own scope, separate from the argument scope.
+  pushDeclScope();
+  if (consumeIf(Token::arrow)) {
+    auto parseResultFn = [&]() -> LogicalResult {
+      FailureOr<ast::VariableDecl *> result = parseResultDecl(results.size());
+      if (failed(result))
+        return failure();
+      results.emplace_back(*result);
+      return success();
+    };
+
+    // Check for a list of results.
+    if (consumeIf(Token::l_paren)) {
+      do {
+        if (failed(parseResultFn()))
+          return failure();
+      } while (consumeIf(Token::comma));
+      if (failed(parseToken(Token::r_paren, "expected `)` to end result list")))
+        return failure();
+
+      // Otherwise, there is only one result.
+    } else if (failed(parseResultFn())) {
+      return failure();
+    }
+  }
+  popDeclScope();
+
+  // Compute the result type of the decl.
+  resultType = createUserConstraintRewriteResultType(results);
+
+  // Verify that results are only named if there are more than one.
+  if (results.size() == 1 && !results.front()->getName().getName().empty()) {
+    return emitError(
+        results.front()->getLoc(),
+        "cannot create a single-element tuple with an element label");
+  }
+  return success();
+}
+
+/// Check the placement of `return` in a Constraint/Rewrite body. `bodyIt` is
+/// the first `return` in the body, or `bodyE` if there is none. A `return`
+/// must be the final statement. If there is no `return`, the decl must not
+/// declare any results. `declType` names the decl kind in diagnostics.
+LogicalResult Parser::validateUserConstraintOrRewriteReturn(
+    StringRef declType, ast::CompoundStmt *body,
+    ArrayRef<ast::Stmt *>::iterator bodyIt,
+    ArrayRef<ast::Stmt *>::iterator bodyE,
+    ArrayRef<ast::VariableDecl *> results, ast::Type &resultType) {
+  bool hasReturn = bodyIt != bodyE;
+
+  // Without a `return`, nothing may be expected back from the decl.
+  if (!hasReturn) {
+    if (results.empty())
+      return success();
+    auto bodyEnd = body->getLoc().End;
+    return emitError(
+        {bodyEnd, bodyEnd},
+        llvm::formatv("missing return in a `{0}` expected to return `{1}`",
+                      declType, resultType));
+  }
+
+  // A `return` must not be followed by any other statement.
+  auto trailingIt = std::next(bodyIt);
+  if (trailingIt == bodyE)
+    return success();
+  return emitError(
+      (*trailingIt)->getLoc(),
+      llvm::formatv("`return` terminated the `{0}` body, but found "
+                    "trailing statements afterwards",
+                    declType));
+}
+
 FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() {
   return parseLambdaBody([&](ast::Stmt *&statement) -> LogicalResult {
     if (isa<ast::OpRewriteStmt>(statement))
@@ -619,6 +1092,11 @@ FailureOr<ast::Decl *> Parser::parsePatternDecl() {
     // Verify the body of the pattern.
     auto bodyIt = body->begin(), bodyE = body->end();
     for (; bodyIt != bodyE; ++bodyIt) {
+      if (isa<ast::ReturnStmt>(*bodyIt)) {
+        return emitError((*bodyIt)->getLoc(),
+                         "`return` statements are only permitted within a "
+                         "`Constraint` or `Rewrite` body");
+      }
       // Break when we've found the rewrite statement.
       if (isa<ast::OpRewriteStmt>(*bodyIt))
         break;
@@ -719,8 +1197,8 @@ LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) {
 }
 
 FailureOr<ast::VariableDecl *>
-Parser::defineVariableDecl(StringRef name, SMRange nameLoc,
-                           ast::Type type, ast::Expr *initExpr,
+Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
+                           ast::Expr *initExpr,
                            ArrayRef<ast::ConstraintRef> constraints) {
   assert(curDeclScope && "defining variable outside of decl scope");
   const ast::Name &nameDecl = ast::Name::create(ctx, name, nameLoc);
@@ -741,8 +1219,7 @@ Parser::defineVariableDecl(StringRef name, SMRange nameLoc,
 }
 
 FailureOr<ast::VariableDecl *>
-Parser::defineVariableDecl(StringRef name, SMRange nameLoc,
-                           ast::Type type,
+Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
                            ArrayRef<ast::ConstraintRef> constraints) {
   return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr,
                             constraints);
@@ -752,8 +1229,8 @@ LogicalResult Parser::parseVariableDeclConstraintList(
     SmallVectorImpl<ast::ConstraintRef> &constraints) {
   Optional<SMRange> typeConstraint;
   auto parseSingleConstraint = [&] {
-    FailureOr<ast::ConstraintRef> constraint =
-        parseConstraint(typeConstraint, constraints);
+    FailureOr<ast::ConstraintRef> constraint = parseConstraint(
+        typeConstraint, constraints, /*allowInlineTypeConstraints=*/true);
     if (failed(constraint))
       return failure();
     constraints.push_back(*constraint);
@@ -773,8 +1250,15 @@ LogicalResult Parser::parseVariableDeclConstraintList(
 
 FailureOr<ast::ConstraintRef>
 Parser::parseConstraint(Optional<SMRange> &typeConstraint,
-                        ArrayRef<ast::ConstraintRef> existingConstraints) {
+                        ArrayRef<ast::ConstraintRef> existingConstraints,
+                        bool allowInlineTypeConstraints) {
   auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult {
+    if (!allowInlineTypeConstraints) {
+      return emitError(
+          curToken.getLoc(),
+          "inline `Attr`, `Value`, and `ValueRange` type constraints are not "
+          "permitted on arguments or results");
+    }
     if (typeConstraint)
       return emitErrorAndNote(
           curToken.getLoc(),
@@ -842,6 +1326,14 @@ Parser::parseConstraint(Optional<SMRange> &typeConstraint,
     return ast::ConstraintRef(
         ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc);
   }
+
+  case Token::kw_Constraint: {
+    // Handle an inline constraint.
+    FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
+    if (failed(decl))
+      return failure();
+    return ast::ConstraintRef(*decl, loc);
+  }
   case Token::identifier: {
     StringRef constraintName = curToken.getSpelling();
     consumeToken(Token::identifier);
@@ -867,6 +1359,12 @@ Parser::parseConstraint(Optional<SMRange> &typeConstraint,
   return emitError(loc, "expected identifier constraint");
 }
 
+FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() {
+  Optional<SMRange> typeConstraint;
+  return parseConstraint(typeConstraint, /*existingConstraints=*/llvm::None,
+                         /*allowInlineTypeConstraints=*/false);
+}
+
 //===----------------------------------------------------------------------===//
 // Exprs
 
@@ -880,12 +1378,18 @@ FailureOr<ast::Expr *> Parser::parseExpr() {
   case Token::kw_attr:
     lhsExpr = parseAttributeExpr();
     break;
+  case Token::kw_Constraint:
+    lhsExpr = parseInlineConstraintLambdaExpr();
+    break;
   case Token::identifier:
     lhsExpr = parseIdentifierExpr();
     break;
   case Token::kw_op:
     lhsExpr = parseOperationExpr();
     break;
+  case Token::kw_Rewrite:
+    lhsExpr = parseInlineRewriteLambdaExpr();
+    break;
   case Token::kw_type:
     lhsExpr = parseTypeExpr();
     break;
@@ -904,6 +1408,9 @@ FailureOr<ast::Expr *> Parser::parseExpr() {
     case Token::dot:
       lhsExpr = parseMemberAccessExpr(*lhsExpr);
       break;
+    case Token::l_paren:
+      lhsExpr = parseCallExpr(*lhsExpr);
+      break;
     default:
       return lhsExpr;
     }
@@ -934,8 +1441,28 @@ FailureOr<ast::Expr *> Parser::parseAttributeExpr() {
   return ast::AttributeExpr::create(ctx, loc, attrExpr);
 }
 
-FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name,
-                                                SMRange loc) {
+FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr) {
+  SMRange loc = curToken.getLoc();
+  consumeToken(Token::l_paren);
+
+  // Parse the arguments of the call.
+  SmallVector<ast::Expr *> arguments;
+  if (curToken.isNot(Token::r_paren)) {
+    do {
+      FailureOr<ast::Expr *> argument = parseExpr();
+      if (failed(argument))
+        return failure();
+      arguments.push_back(*argument);
+    } while (consumeIf(Token::comma));
+  }
+  loc.End = curToken.getEndLoc();
+  if (failed(parseToken(Token::r_paren, "expected `)` after argument list")))
+    return failure();
+
+  return createCallExpr(loc, parentExpr, arguments);
+}
+
+FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) {
   ast::Decl *decl = curDeclScope->lookup(name);
   if (!decl)
     return emitError(loc, "undefined reference to `" + name + "`");
@@ -963,6 +1490,24 @@ FailureOr<ast::Expr *> Parser::parseIdentifierExpr() {
   return parseDeclRefExpr(name, nameLoc);
 }
 
+FailureOr<ast::Expr *> Parser::parseInlineConstraintLambdaExpr() {
+  FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
+  if (failed(decl))
+    return failure();
+
+  return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl,
+                                  ast::ConstraintType::get(ctx));
+}
+
+FailureOr<ast::Expr *> Parser::parseInlineRewriteLambdaExpr() {
+  FailureOr<ast::UserRewriteDecl *> decl = parseInlineUserRewriteDecl();
+  if (failed(decl))
+    return failure();
+
+  return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl,
+                                  ast::RewriteType::get(ctx));
+}
+
 FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) {
   SMRange loc = curToken.getLoc();
   consumeToken(Token::dot);
@@ -1202,6 +1747,9 @@ FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) {
   case Token::kw_replace:
     stmt = parseReplaceStmt();
     break;
+  case Token::kw_return:
+    stmt = parseReturnStmt();
+    break;
   case Token::kw_rewrite:
     stmt = parseRewriteStmt();
     break;
@@ -1239,6 +1787,8 @@ FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() {
 }
 
 FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() {
+  if (parserContext == ParserContext::Constraint)
+    return emitError("`erase` cannot be used within a Constraint");
   SMRange loc = curToken.getLoc();
   consumeToken(Token::kw_erase);
 
@@ -1311,6 +1861,8 @@ FailureOr<ast::LetStmt *> Parser::parseLetStmt() {
 }
 
 FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() {
+  if (parserContext == ParserContext::Constraint)
+    return emitError("`replace` cannot be used within a Constraint");
   SMRange loc = curToken.getLoc();
   consumeToken(Token::kw_replace);
 
@@ -1356,7 +1908,21 @@ FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() {
   return createReplaceStmt(loc, *rootOp, replValues);
 }
 
+FailureOr<ast::ReturnStmt *> Parser::parseReturnStmt() {
+  SMRange loc = curToken.getLoc();
+  consumeToken(Token::kw_return);
+
+  // Parse the result value.
+  FailureOr<ast::Expr *> resultExpr = parseExpr();
+  if (failed(resultExpr))
+    return failure();
+
+  return ast::ReturnStmt::create(ctx, loc, *resultExpr);
+}
+
 FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() {
+  if (parserContext == ParserContext::Constraint)
+    return emitError("`rewrite` cannot be used within a Constraint");
   SMRange loc = curToken.getLoc();
   consumeToken(Token::kw_rewrite);
 
@@ -1379,6 +1945,15 @@ FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() {
   if (failed(rewriteBody))
     return failure();
 
+  // Verify the rewrite body.
+  for (const ast::Stmt *stmt : (*rewriteBody)->getChildren()) {
+    if (isa<ast::ReturnStmt>(stmt)) {
+      return emitError(stmt->getLoc(),
+                       "`return` statements are only permitted within a "
+                       "`Constraint` or `Rewrite` body");
+    }
+  }
+
   return createRewriteStmt(loc, *rootOp, *rewriteBody);
 }
 
@@ -1389,6 +1964,13 @@ FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() {
 //===----------------------------------------------------------------------===//
 // Decls
 
+ast::CallableDecl *Parser::tryExtractCallableDecl(ast::Node *node) {
+  // Unwrap reference expressions.
+  if (auto *init = dyn_cast<ast::DeclRefExpr>(node))
+    node = init->getDecl();
+  return dyn_cast<ast::CallableDecl>(node);
+}
+
 FailureOr<ast::PatternDecl *>
 Parser::createPatternDecl(SMRange loc, const ast::Name *name,
                           const ParsedPatternMetadata &metadata,
@@ -1397,9 +1979,47 @@ Parser::createPatternDecl(SMRange loc, const ast::Name *name,
                                   metadata.hasBoundedRecursion, body);
 }
 
+ast::Type Parser::createUserConstraintRewriteResultType(
+    ArrayRef<ast::VariableDecl *> results) {
+  // Single result decls use the type of the single result.
+  if (results.size() == 1)
+    return results[0]->getType();
+
+  // Multiple results use a tuple type, with the types and names grabbed from
+  // the result variable decls.
+  auto resultTypes = llvm::map_range(
+      results, [&](const auto *result) { return result->getType(); });
+  auto resultNames = llvm::map_range(
+      results, [&](const auto *result) { return result->getName().getName(); });
+  return ast::TupleType::get(ctx, llvm::to_vector(resultTypes),
+                             llvm::to_vector(resultNames));
+}
+
+template <typename T>
+FailureOr<T *> Parser::createUserPDLLConstraintOrRewriteDecl(
+    const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
+    ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
+    ast::CompoundStmt *body) {
+  if (!body->getChildren().empty()) {
+    if (auto *retStmt = dyn_cast<ast::ReturnStmt>(body->getChildren().back())) {
+      ast::Expr *resultExpr = retStmt->getResultExpr();
+
+      // Process the result of the decl. If no explicit signature results
+      // were provided, check for return type inference. Otherwise, check that
+      // the return expression can be converted to the expected type.
+      if (results.empty())
+        resultType = resultExpr->getType();
+      else if (failed(convertExpressionTo(resultExpr, resultType)))
+        return failure();
+      else
+        retStmt->setResultExpr(resultExpr);
+    }
+  }
+  return T::createPDLL(ctx, name, arguments, results, body, resultType);
+}
+
 FailureOr<ast::VariableDecl *>
-Parser::createVariableDecl(StringRef name, SMRange loc,
-                           ast::Expr *initializer,
+Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
                            ArrayRef<ast::ConstraintRef> constraints) {
   // The type of the variable, which is expected to be inferred by either a
   // constraint or an initializer expression.
@@ -1426,6 +2046,12 @@ Parser::createVariableDecl(StringRef name, SMRange loc,
         "list or the initializer");
   }
 
+  // Constraint types cannot be used when defining variables.
+  if (type.isa<ast::ConstraintType, ast::RewriteType>()) {
+    return emitError(
+        loc, llvm::formatv("unable to define variable of `{0}` type", type));
+  }
+
   // Try to define a variable with the given name.
   FailureOr<ast::VariableDecl *> varDecl =
       defineVariableDecl(name, loc, type, initializer, constraints);
@@ -1435,6 +2061,18 @@ Parser::createVariableDecl(StringRef name, SMRange loc,
   return *varDecl;
 }
 
+FailureOr<ast::VariableDecl *>
+Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc,
+                                      const ast::ConstraintRef &constraint) {
+  // Constraint arguments may apply more complex constraints via the arguments.
+  bool allowNonCoreConstraints = parserContext == ParserContext::Constraint;
+  ast::Type argType;
+  if (failed(validateVariableConstraint(constraint, argType,
+                                        allowNonCoreConstraints)))
+    return failure();
+  return defineVariableDecl(name, loc, argType, constraint);
+}
+
 LogicalResult
 Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
                                     ast::Type &inferredType) {
@@ -1445,7 +2083,8 @@ Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
 }
 
 LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref,
-                                                 ast::Type &inferredType) {
+                                                 ast::Type &inferredType,
+                                                 bool allowNonCoreConstraints) {
   ast::Type constraintType;
   if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.constraint)) {
     if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
@@ -1474,6 +2113,25 @@ LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref,
         return failure();
     }
     constraintType = valueRangeTy;
+  } else if (const auto *cst =
+                 dyn_cast<ast::UserConstraintDecl>(ref.constraint)) {
+    if (!allowNonCoreConstraints) {
+      return emitError(ref.referenceLoc,
+                       "`Rewrite` arguments and results are only permitted to "
+                       "use core constraints, such as `Attr`, `Op`, `Type`, "
+                       "`TypeRange`, `Value`, `ValueRange`");
+    }
+
+    ArrayRef<ast::VariableDecl *> inputs = cst->getInputs();
+    if (inputs.size() != 1) {
+      return emitErrorAndNote(ref.referenceLoc,
+                              "`Constraint`s applied via a variable constraint "
+                              "list must take a single input, but got " +
+                                  Twine(inputs.size()),
+                              cst->getLoc(),
+                              "see definition of constraint here");
+    }
+    constraintType = inputs.front()->getType();
   } else {
     llvm_unreachable("unknown constraint type");
   }
@@ -1515,11 +2173,66 @@ Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) {
 //===----------------------------------------------------------------------===//
 // Exprs
 
+FailureOr<ast::CallExpr *>
+Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr,
+                       MutableArrayRef<ast::Expr *> arguments) {
+  ast::Type parentType = parentExpr->getType();
+
+  ast::CallableDecl *callableDecl = tryExtractCallableDecl(parentExpr);
+  if (!callableDecl) {
+    return emitError(loc,
+                     llvm::formatv("expected a reference to a callable "
+                                   "`Constraint` or `Rewrite`, but got: `{0}`",
+                                   parentType));
+  }
+  if (parserContext == ParserContext::Rewrite) {
+    if (isa<ast::UserConstraintDecl>(callableDecl))
+      return emitError(
+          loc, "unable to invoke `Constraint` within a rewrite section");
+  } else if (isa<ast::UserRewriteDecl>(callableDecl)) {
+    return emitError(loc, "unable to invoke `Rewrite` within a match section");
+  }
+
+  // Verify the arguments of the call.
+  /// Handle size mismatch.
+  ArrayRef<ast::VariableDecl *> callArgs = callableDecl->getInputs();
+  if (callArgs.size() != arguments.size()) {
+    return emitErrorAndNote(
+        loc,
+        llvm::formatv("invalid number of arguments for {0} call; expected "
+                      "{1}, but got {2}",
+                      callableDecl->getCallableType(), callArgs.size(),
+                      arguments.size()),
+        callableDecl->getLoc(),
+        llvm::formatv("see the definition of {0} here",
+                      callableDecl->getName()->getName()));
+  }
+
+  /// Handle argument type mismatch.
+  auto attachDiagFn = [&](ast::Diagnostic &diag) {
+    diag.attachNote(llvm::formatv("see the definition of `{0}` here",
+                                  callableDecl->getName()->getName()),
+                    callableDecl->getLoc());
+  };
+  for (auto it : llvm::zip(callArgs, arguments)) {
+    if (failed(convertExpressionTo(std::get<1>(it), std::get<0>(it)->getType(),
+                                   attachDiagFn)))
+      return failure();
+  }
+
+  return ast::CallExpr::create(ctx, loc, parentExpr, arguments,
+                               callableDecl->getResultType());
+}
+
 FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc,
                                                         ast::Decl *decl) {
   // Check the type of decl being referenced.
   ast::Type declType;
-  if (auto *varDecl = dyn_cast<ast::VariableDecl>(decl))
+  if (isa<ast::ConstraintDecl>(decl))
+    declType = ast::ConstraintType::get(ctx);
+  else if (isa<ast::UserRewriteDecl>(decl))
+    declType = ast::RewriteType::get(ctx);
+  else if (auto *varDecl = dyn_cast<ast::VariableDecl>(decl))
     declType = varDecl->getType();
   else
     return emitError(loc, "invalid reference to `" +
@@ -1529,8 +2242,7 @@ FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc,
 }
 
 FailureOr<ast::DeclRefExpr *>
-Parser::createInlineVariableExpr(ast::Type type, StringRef name,
-                                 SMRange loc,
+Parser::createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
                                  ArrayRef<ast::ConstraintRef> constraints) {
   FailureOr<ast::VariableDecl *> decl =
       defineVariableDecl(name, loc, type, constraints);
@@ -1551,8 +2263,7 @@ Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
 }
 
 FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
-                                                  StringRef name,
-                                                  SMRange loc) {
+                                                  StringRef name, SMRange loc) {
   ast::Type parentType = parentExpr->getType();
   if (parentType.isa<ast::OperationType>()) {
     if (name == ast::AllResultsMemberAccessExpr::getMemberName())
@@ -1622,9 +2333,8 @@ Parser::validateOperationResults(SMRange loc, Optional<StringRef> name,
 }
 
 LogicalResult Parser::validateOperationOperandsOrResults(
-    SMRange loc, Optional<StringRef> name,
-    MutableArrayRef<ast::Expr *> values, ast::Type singleTy,
-    ast::Type rangeTy) {
+    SMRange loc, Optional<StringRef> name, MutableArrayRef<ast::Expr *> values,
+    ast::Type singleTy, ast::Type rangeTy) {
   // All operation types accept a single range parameter.
   if (values.size() == 1) {
     if (failed(convertExpressionTo(values[0], rangeTy)))
@@ -1665,7 +2375,7 @@ Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements,
                         ArrayRef<StringRef> elementNames) {
   for (const ast::Expr *element : elements) {
     ast::Type eleTy = element->getType();
-    if (eleTy.isa<ast::ConstraintType, ast::TupleType>()) {
+    if (eleTy.isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>()) {
       return emitError(
           element->getLoc(),
           llvm::formatv("unable to build a tuple with `{0}` element", eleTy));

diff  --git a/mlir/test/mlir-pdll/Parser/constraint-failure.pdll b/mlir/test/mlir-pdll/Parser/constraint-failure.pdll
new file mode 100644
index 0000000000000..8291913623943
--- /dev/null
+++ b/mlir/test/mlir-pdll/Parser/constraint-failure.pdll
@@ -0,0 +1,160 @@
+// RUN: not mlir-pdll %s -split-input-file 2>&1 | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// Constraint Structure
+//===----------------------------------------------------------------------===//
+
+// CHECK: expected identifier name
+Constraint {}
+
+// -----
+
+// CHECK: :6:12: error: `Foo` has already been defined
+// CHECK: :5:12: note: see previous definition here
+Constraint Foo() { op<>; }
+Constraint Foo() { op<>; }
+
+// -----
+
+Constraint Foo() {
+  // CHECK: `erase` cannot be used within a Constraint
+  erase op<>;
+}
+
+// -----
+
+Constraint Foo() {
+  // CHECK: `replace` cannot be used within a Constraint
+  replace;
+}
+
+// -----
+
+Constraint Foo() {
+  // CHECK: `rewrite` cannot be used within a Constraint
+  rewrite;
+}
+
+// -----
+
+Constraint Foo() -> Value {
+  // CHECK: `return` terminated the `Constraint` body, but found trailing statements afterwards
+  return _: Value;
+  return _: Value;
+}
+
+// -----
+
+// CHECK: missing return in a `Constraint` expected to return `Value`
+Constraint Foo() -> Value {
+  let value: Value;
+}
+
+// -----
+
+// CHECK: expected `Constraint` lambda body to contain a single expression
+Constraint Foo() -> Value => let foo: Value;
+
+// -----
+
+// CHECK: unable to convert expression of type `Op` to the expected type of `Attr`
+Constraint Foo() -> Attr => op<>;
+
+// -----
+
+Rewrite SomeRewrite();
+
+// CHECK: unable to invoke `Rewrite` within a match section
+Constraint Foo() {
+  SomeRewrite();
+}
+
+// -----
+
+Constraint Foo() {
+  Constraint Foo() {};
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Arguments
+//===----------------------------------------------------------------------===//
+
+// CHECK: expected `(` to start argument list
+Constraint Foo {}
+
+// -----
+
+// CHECK: expected identifier argument name
+Constraint Foo(10{}
+
+// -----
+
+// CHECK: expected `:` before argument constraint
+Constraint Foo(arg{}
+
+// -----
+
+// CHECK: inline `Attr`, `Value`, and `ValueRange` type constraints are not permitted on arguments or results
+Constraint Foo(arg: Value<type>){}
+
+// -----
+
+// CHECK: expected `)` to end argument list
+Constraint Foo(arg: Value{}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Results
+//===----------------------------------------------------------------------===//
+
+// CHECK: expected identifier constraint
+Constraint Foo() -> {}
+
+// -----
+
+// CHECK: cannot create a single-element tuple with an element label
+Constraint Foo() -> result: Value;
+
+// -----
+
+// CHECK: cannot create a single-element tuple with an element label
+Constraint Foo() -> (result: Value);
+
+// -----
+
+// CHECK: expected identifier constraint
+Constraint Foo() -> ();
+
+// -----
+
+// CHECK: expected `:` before result constraint
+Constraint Foo() -> (result{};
+
+// -----
+
+// CHECK: expected `)` to end result list
+Constraint Foo() -> (Op{};
+
+// -----
+
+// CHECK: inline `Attr`, `Value`, and `ValueRange` type constraints are not permitted on arguments or results
+Constraint Foo() -> Value<type>){}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Native Constraints
+//===----------------------------------------------------------------------===//
+
+Pattern {
+  // CHECK: external declarations must be declared in global scope
+  Constraint ExternalConstraint();
+}
+
+// -----
+
+// CHECK: expected `;` after native declaration
+Constraint Foo() [{}]

diff  --git a/mlir/test/mlir-pdll/Parser/constraint.pdll b/mlir/test/mlir-pdll/Parser/constraint.pdll
new file mode 100644
index 0000000000000..1c0a015ab4a7b
--- /dev/null
+++ b/mlir/test/mlir-pdll/Parser/constraint.pdll
@@ -0,0 +1,74 @@
+// RUN: mlir-pdll %s -I %S -split-input-file | FileCheck %s
+
+// CHECK:  Module
+// CHECK:  `-UserConstraintDecl {{.*}} Name<Foo> ResultType<Tuple<>>
+Constraint Foo();
+
+// -----
+
+// CHECK:  Module
+// CHECK:  `-UserConstraintDecl {{.*}} Name<Foo> ResultType<Tuple<>> Code< /* Native Code */ >
+Constraint Foo() [{ /* Native Code */ }];
+
+// -----
+
+// CHECK: Module
+// CHECK: `-UserConstraintDecl {{.*}} Name<Foo> ResultType<Value>
+// CHECK:   `Inputs`
+// CHECK:     `-VariableDecl {{.*}} Name<arg> Type<Value>
+// CHECK:   `Results`
+// CHECK:     `-VariableDecl {{.*}} Name<> Type<Value>
+// CHECK:   `-CompoundStmt {{.*}}
+// CHECK:     `-ReturnStmt {{.*}}
+// CHECK:       `-DeclRefExpr {{.*}} Type<Value>
+// CHECK:         `-VariableDecl {{.*}} Name<arg> Type<Value>
+Constraint Foo(arg: Value) -> Value => arg;
+
+// -----
+
+// CHECK: Module
+// CHECK: `-UserConstraintDecl {{.*}} Name<Foo> ResultType<Tuple<result1: Value, result2: Attr>>
+// CHECK:   `Results`
+// CHECK:     |-VariableDecl {{.*}} Name<result1> Type<Value>
+// CHECK:     | `Constraints`
+// CHECK:     |   `-ValueConstraintDecl {{.*}}
+// CHECK:     `-VariableDecl {{.*}} Name<result2> Type<Attr>
+// CHECK:       `Constraints`
+// CHECK:         `-AttrConstraintDecl {{.*}}
+// CHECK:   `-CompoundStmt {{.*}}
+// CHECK:     `-ReturnStmt {{.*}}
+// CHECK:       `-TupleExpr {{.*}} Type<Tuple<result1: Value, result2: Attr>>
+// CHECK:         |-MemberAccessExpr {{.*}} Member<0> Type<Value>
+// CHECK:         | `-TupleExpr {{.*}} Type<Tuple<Value, Attr>>
+// CHECK:         `-MemberAccessExpr {{.*}} Member<1> Type<Attr>
+// CHECK:           `-TupleExpr {{.*}} Type<Tuple<Value, Attr>>
+Constraint Foo() -> (result1: Value, result2: Attr) => (_: Value, attr<"10">);
+
+// -----
+
+// CHECK: Module
+// CHECK: |-UserConstraintDecl {{.*}} Name<Bar> ResultType<Tuple<>>
+// CHECK: `-UserConstraintDecl {{.*}} Name<Foo> ResultType<Value>
+// CHECK:   `Inputs`
+// CHECK:     `-VariableDecl {{.*}} Name<arg> Type<Value>
+// CHECK:       `Constraints`
+// CHECK:         `-UserConstraintDecl {{.*}} Name<Bar> ResultType<Tuple<>>
+// CHECK:   `Results`
+// CHECK:     `-VariableDecl {{.*}} Name<> Type<Value>
+// CHECK:       `Constraints`
+// CHECK:         `-UserConstraintDecl {{.*}} Name<Bar> ResultType<Tuple<>>
+Constraint Bar(input: Value);
+
+Constraint Foo(arg: Bar) -> Bar => arg;
+
+// -----
+
+// Test that anonymous constraints are uniquely named.
+
+// CHECK: Module
+// CHECK: UserConstraintDecl {{.*}} Name<<anonymous_constraint_0>> ResultType<Tuple<>>
+// CHECK: UserConstraintDecl {{.*}} Name<<anonymous_constraint_1>> ResultType<Attr>
+Constraint Outer() {
+  Constraint() {};
+  Constraint() => attr<"10">;
+}

diff  --git a/mlir/test/mlir-pdll/Parser/expr-failure.pdll b/mlir/test/mlir-pdll/Parser/expr-failure.pdll
index 1bf73a4ad60fa..7ed3ba8057bd5 100644
--- a/mlir/test/mlir-pdll/Parser/expr-failure.pdll
+++ b/mlir/test/mlir-pdll/Parser/expr-failure.pdll
@@ -43,6 +43,45 @@ Pattern {
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// Call Expr
+//===----------------------------------------------------------------------===//
+
+Constraint foo(value: Value);
+
+Pattern {
+  // CHECK: expected `)` after argument list
+  foo(_: Value{};
+}
+
+// -----
+
+Pattern {
+  // CHECK: expected a reference to a callable `Constraint` or `Rewrite`, but got: `Op`
+  let foo: Op;
+  foo();
+}
+
+// -----
+
+Constraint Foo();
+
+Pattern {
+  // CHECK: invalid number of arguments for constraint call; expected 0, but got 1
+  Foo(_: Value);
+}
+
+// -----
+
+Constraint Foo(arg: Value);
+
+Pattern {
+  // CHECK: unable to convert expression of type `Attr` to the expected type of `Value`
+  Foo(attr<"i32">);
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // Member Access Expr
 //===----------------------------------------------------------------------===//
@@ -105,6 +144,26 @@ Pattern {
 
 // -----
 
+Constraint Foo();
+
+Pattern {
+  // CHECK: unable to build a tuple with `Constraint` element
+  let tuple = (Foo);
+  erase op<>;
+}
+
+// -----
+
+Rewrite Foo();
+
+Pattern {
+  // CHECK: unable to build a tuple with `Rewrite` element
+  let tuple = (Foo);
+  erase op<>;
+}
+
+// -----
+
 Pattern {
   // CHECK: expected expression
   let tuple = (10 = _: Value);

diff  --git a/mlir/test/mlir-pdll/Parser/expr.pdll b/mlir/test/mlir-pdll/Parser/expr.pdll
index f457d653dbd7c..d645feaf118fa 100644
--- a/mlir/test/mlir-pdll/Parser/expr.pdll
+++ b/mlir/test/mlir-pdll/Parser/expr.pdll
@@ -14,6 +14,42 @@ Pattern {
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// CallExpr
+//===----------------------------------------------------------------------===//
+
+// CHECK: Module
+// CHECK: |-UserConstraintDecl {{.*}} Name<MakeRootOp> ResultType<Op<my_dialect.foo>>
+// CHECK:   `-CallExpr {{.*}} Type<Op<my_dialect.foo>>
+// CHECK:     `-DeclRefExpr {{.*}} Type<Constraint>
+// CHECK:       `-UserConstraintDecl {{.*}} Name<MakeRootOp> ResultType<Op<my_dialect.foo>>
+Constraint MakeRootOp() => op<my_dialect.foo>;
+
+Pattern {
+  erase MakeRootOp();
+}
+
+// -----
+
+// CHECK: Module
+// CHECK: |-UserRewriteDecl {{.*}} Name<CreateNewOp> ResultType<Op<my_dialect.foo>>
+// CHECK: `-PatternDecl {{.*}}
+// CHECK:   `-CallExpr {{.*}} Type<Op<my_dialect.foo>>
+// CHECK:     `-DeclRefExpr {{.*}} Type<Rewrite>
+// CHECK:       `-UserRewriteDecl {{.*}} Name<CreateNewOp> ResultType<Op<my_dialect.foo>>
+// CHECK:     `Arguments`
+// CHECK:       `-MemberAccessExpr {{.*}} Member<$results> Type<ValueRange>
+// CHECK:         `-DeclRefExpr {{.*}} Type<Op<my_dialect.bar>>
+// CHECK:           `-VariableDecl {{.*}} Name<inputOp> Type<Op<my_dialect.bar>>
+Rewrite CreateNewOp(inputs: ValueRange) => op<my_dialect.foo>(inputs);
+
+Pattern {
+  let inputOp = op<my_dialect.bar>;
+  replace op<my_dialect.bar>(inputOp) with CreateNewOp(inputOp);
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // MemberAccessExpr
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/mlir-pdll/Parser/pattern-failure.pdll b/mlir/test/mlir-pdll/Parser/pattern-failure.pdll
index 42ea657821120..8b104128f5204 100644
--- a/mlir/test/mlir-pdll/Parser/pattern-failure.pdll
+++ b/mlir/test/mlir-pdll/Parser/pattern-failure.pdll
@@ -1,4 +1,4 @@
-// RUN: not mlir-pdll %s -split-input-file 2>&1  | FileCheck %s
+// RUN: not mlir-pdll %s -split-input-file 2>&1 | FileCheck %s
 
 // CHECK: expected `{` or `=>` to start pattern body
 Pattern }
@@ -12,6 +12,13 @@ Pattern Foo { erase root: Op; }
 
 // -----
 
+// CHECK: `return` statements are only permitted within a `Constraint` or `Rewrite` body
+Pattern {
+  return _: Value;
+}
+
+// -----
+
 // CHECK: expected Pattern body to terminate with an operation rewrite statement
 Pattern {
   let value: Value;
@@ -32,6 +39,15 @@ Pattern => op<>;
 
 // -----
 
+Rewrite SomeRewrite();
+
+// CHECK: unable to invoke `Rewrite` within a match section
+Pattern {
+  SomeRewrite();
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // Metadata
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/mlir-pdll/Parser/rewrite-failure.pdll b/mlir/test/mlir-pdll/Parser/rewrite-failure.pdll
new file mode 100644
index 0000000000000..1cdb32b5d6b0e
--- /dev/null
+++ b/mlir/test/mlir-pdll/Parser/rewrite-failure.pdll
@@ -0,0 +1,161 @@
+// RUN: not mlir-pdll %s -split-input-file 2>&1 | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// Rewrite Structure
+//===----------------------------------------------------------------------===//
+
+// CHECK: expected identifier name
+Rewrite {}
+
+// -----
+
+// CHECK: :6:9: error: `Foo` has already been defined
+// CHECK: :5:9: note: see previous definition here
+Rewrite Foo();
+Rewrite Foo();
+
+// -----
+
+Rewrite Foo() -> Value {
+  // CHECK: `return` terminated the `Rewrite` body, but found trailing statements afterwards
+  return _: Value;
+  return _: Value;
+}
+
+// -----
+
+// CHECK: missing return in a `Rewrite` expected to return `Value`
+Rewrite Foo() -> Value {
+  let value: Value;
+}
+
+// -----
+
+// CHECK: missing return in a `Rewrite` expected to return `Value`
+Rewrite Foo() -> Value => erase op<my_dialect.foo>;
+
+// -----
+
+// CHECK: unable to convert expression of type `Op<my_dialect.foo>` to the expected type of `Attr`
+Rewrite Foo() -> Attr => op<my_dialect.foo>;
+
+// -----
+
+// CHECK: expected `Rewrite` lambda body to contain a single expression or an operation rewrite statement; such as `erase`, `replace`, or `rewrite`
+Rewrite Foo() => let foo = op<my_dialect.foo>;
+
+// -----
+
+Constraint ValueConstraint(value: Value);
+
+// CHECK: unable to invoke `Constraint` within a rewrite section
+Rewrite Foo(value: Value) {
+  ValueConstraint(value);
+}
+
+// -----
+
+Rewrite Bar();
+
+// CHECK: `Bar` has already been defined
+Rewrite Foo() {
+  Rewrite Bar() {};
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Arguments
+//===----------------------------------------------------------------------===//
+
+// CHECK: expected `(` to start argument list
+Rewrite Foo {}
+
+// -----
+
+// CHECK: expected identifier argument name
+Rewrite Foo(10{}
+
+// -----
+
+// CHECK: expected `:` before argument constraint
+Rewrite Foo(arg{}
+
+// -----
+
+// CHECK: inline `Attr`, `Value`, and `ValueRange` type constraints are not permitted on arguments or results
+Rewrite Foo(arg: Value<type>){}
+
+// -----
+
+Constraint ValueConstraint(value: Value);
+
+// CHECK: arguments and results are only permitted to use core constraints, such as `Attr`, `Op`, `Type`, `TypeRange`, `Value`, `ValueRange`
+Rewrite Foo(arg: ValueConstraint);
+
+// -----
+
+// CHECK: expected `)` to end argument list
+Rewrite Foo(arg: Value{}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Results
+//===----------------------------------------------------------------------===//
+
+// CHECK: expected identifier constraint
+Rewrite Foo() -> {}
+
+// -----
+
+// CHECK: cannot create a single-element tuple with an element label
+Rewrite Foo() -> result: Value;
+
+// -----
+
+// CHECK: cannot create a single-element tuple with an element label
+Rewrite Foo() -> (result: Value);
+
+// -----
+
+// CHECK: expected identifier constraint
+Rewrite Foo() -> ();
+
+// -----
+
+// CHECK: expected `:` before result constraint
+Rewrite Foo() -> (result{};
+
+// -----
+
+// CHECK: expected `)` to end result list
+Rewrite Foo() -> (Op{};
+
+// -----
+
+// CHECK: inline `Attr`, `Value`, and `ValueRange` type constraints are not permitted on arguments or results
+Rewrite Foo() -> Value<type>){}
+
+// -----
+
+Constraint ValueConstraint(value: Value);
+
+// CHECK: results are only permitted to use core constraints, such as `Attr`, `Op`, `Type`, `TypeRange`, `Value`, `ValueRange`
+Rewrite Foo() -> ValueConstraint;
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Native Rewrites
+//===----------------------------------------------------------------------===//
+
+Pattern {
+  // CHECK: external declarations must be declared in global scope
+  Rewrite ExternalConstraint();
+}
+
+// -----
+
+// CHECK: expected `;` after native declaration
+Rewrite Foo() [{}]

diff  --git a/mlir/test/mlir-pdll/Parser/rewrite.pdll b/mlir/test/mlir-pdll/Parser/rewrite.pdll
new file mode 100644
index 0000000000000..7ef7478319d4b
--- /dev/null
+++ b/mlir/test/mlir-pdll/Parser/rewrite.pdll
@@ -0,0 +1,58 @@
+// RUN: mlir-pdll %s -I %S -split-input-file | FileCheck %s
+
+// CHECK:  Module
+// CHECK:  `-UserRewriteDecl {{.*}} Name<Foo> ResultType<Tuple<>>
+Rewrite Foo();
+
+// -----
+
+// CHECK:  Module
+// CHECK:  `-UserRewriteDecl {{.*}} Name<Foo> ResultType<Tuple<>> Code< /* Native Code */ >
+Rewrite Foo() [{ /* Native Code */ }];
+
+// -----
+
+// CHECK: Module
+// CHECK: `-UserRewriteDecl {{.*}} Name<Foo> ResultType<Value>
+// CHECK:   `Inputs`
+// CHECK:     `-VariableDecl {{.*}} Name<arg> Type<Op>
+// CHECK:   `Results`
+// CHECK:     `-VariableDecl {{.*}} Name<> Type<Value>
+// CHECK:   `-CompoundStmt {{.*}}
+// CHECK:     `-ReturnStmt {{.*}}
+// CHECK:       `-MemberAccessExpr {{.*}} Member<$results> Type<Value>
+// CHECK:         `-DeclRefExpr {{.*}} Type<Op>
+// CHECK:           `-VariableDecl {{.*}} Name<arg> Type<Op>
+Rewrite Foo(arg: Op) -> Value => arg;
+
+// -----
+
+// CHECK: Module
+// CHECK: `-UserRewriteDecl {{.*}} Name<Foo> ResultType<Tuple<result1: Value, result2: Attr>>
+// CHECK:   `Results`
+// CHECK:     |-VariableDecl {{.*}} Name<result1> Type<Value>
+// CHECK:     | `Constraints`
+// CHECK:     |   `-ValueConstraintDecl {{.*}}
+// CHECK:     `-VariableDecl {{.*}} Name<result2> Type<Attr>
+// CHECK:       `Constraints`
+// CHECK:         `-AttrConstraintDecl {{.*}}
+// CHECK:   `-CompoundStmt {{.*}}
+// CHECK:     `-ReturnStmt {{.*}}
+// CHECK:       `-TupleExpr {{.*}} Type<Tuple<result1: Value, result2: Attr>>
+// CHECK:         |-MemberAccessExpr {{.*}} Member<0> Type<Value>
+// CHECK:         | `-TupleExpr {{.*}} Type<Tuple<Value, Attr>>
+// CHECK:         `-MemberAccessExpr {{.*}} Member<1> Type<Attr>
+// CHECK:           `-TupleExpr {{.*}} Type<Tuple<Value, Attr>>
+Rewrite Foo() -> (result1: Value, result2: Attr) => (_: Value, attr<"10">);
+
+// -----
+
+// Test that anonymous Rewrites are uniquely named.
+
+// CHECK: Module
+// CHECK: UserRewriteDecl {{.*}} Name<<anonymous_rewrite_0>> ResultType<Tuple<>>
+// CHECK: UserRewriteDecl {{.*}} Name<<anonymous_rewrite_1>> ResultType<Attr>
+Rewrite Outer() {
+  Rewrite() {};
+  Rewrite() => attr<"10">;
+}

diff  --git a/mlir/test/mlir-pdll/Parser/stmt-failure.pdll b/mlir/test/mlir-pdll/Parser/stmt-failure.pdll
index 802353615d29c..d52a2d28b6d05 100644
--- a/mlir/test/mlir-pdll/Parser/stmt-failure.pdll
+++ b/mlir/test/mlir-pdll/Parser/stmt-failure.pdll
@@ -223,6 +223,33 @@ Pattern {
 
 // -----
 
+Constraint Foo();
+
+Pattern {
+  // CHECK: unable to define variable of `Constraint` type
+  let foo = Foo;
+}
+
+// -----
+
+Rewrite Foo();
+
+Pattern {
+  // CHECK: unable to define variable of `Rewrite` type
+  let foo = Foo;
+}
+
+// -----
+
+Constraint MultiConstraint(arg1: Value, arg2: Value);
+
+Pattern {
+  // CHECK: `Constraint`s applied via a variable constraint list must take a single input, but got 2
+  let foo: MultiConstraint;
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // `replace`
 //===----------------------------------------------------------------------===//
@@ -276,6 +303,17 @@ Pattern {
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// `return`
+//===----------------------------------------------------------------------===//
+
+// CHECK: expected `;` after statement
+Constraint Foo(arg: Value) -> Value {
+  return arg
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // `rewrite`
 //===----------------------------------------------------------------------===//
@@ -307,3 +345,12 @@ Pattern {
       op<>;
   };
 }
+
+// -----
+
+Pattern {
+  // CHECK: `return` statements are only permitted within a `Constraint` or `Rewrite` body
+  rewrite root: Op with {
+      return root;
+  };
+}


        


More information about the Mlir-commits mailing list