[Mlir-commits] [mlir] faf4226 - [PDLL] Add support for user defined constraint and rewrite functions
River Riddle
llvmlistbot at llvm.org
Thu Feb 10 12:49:18 PST 2022
Author: River Riddle
Date: 2022-02-10T12:48:59-08:00
New Revision: faf42264e5401a1dfca95b701e5c2bf951d7f8a7
URL: https://github.com/llvm/llvm-project/commit/faf42264e5401a1dfca95b701e5c2bf951d7f8a7
DIFF: https://github.com/llvm/llvm-project/commit/faf42264e5401a1dfca95b701e5c2bf951d7f8a7.diff
LOG: [PDLL] Add support for user defined constraint and rewrite functions
These functions allow for defining pattern fragments usable within the `match` and `rewrite` sections of a pattern. The main structure of Constraint and Rewrite functions is the same, and is similar to functions in other languages; they contain a signature (i.e. name, argument list, result list) and a body:
```pdll
// Constraint that takes a value as an input, and produces a value:
Constraint Cst(arg: Value) -> Value { ... }
// Constraint that returns multiple values:
Constraint Cst() -> (result1: Value, result2: ValueRange);
```
When returning multiple results, each result can optionally be named (in the case of multiple results, the result of a Constraint/Rewrite is a tuple).
The body of a Constraint/Rewrite function can be specified in several ways:
* Externally
In this case we are importing an external function (registered by the user outside of PDLL):
```pdll
Constraint Foo(op: Op);
Rewrite Bar();
```
* In PDLL (using PDLL constructs)
In this case, the body is defined using PDLL constructs:
```pdll
Rewrite BuildFooOp() {
// The result type of the Rewrite is inferred from the return.
return op<my_dialect.foo>;
}
// Constraints/Rewrites can also implement a lambda/expression
// body for simple one line bodies.
Rewrite BuildFooOp() => op<my_dialect.foo>;
```
* In PDLL (using a native/C++ code block)
In this case the body is specified using a C++ (or potentially other language at some point) code block. When building PDLL in AOT mode this will generate a native constraint/rewrite and register it with the PDL bytecode.
```pdll
Rewrite BuildFooOp() -> Op<my_dialect.foo> [{
return rewriter.create<my_dialect::FooOp>(...);
}];
```
Differential Revision: https://reviews.llvm.org/D115836
Added:
mlir/test/mlir-pdll/Parser/constraint-failure.pdll
mlir/test/mlir-pdll/Parser/constraint.pdll
mlir/test/mlir-pdll/Parser/rewrite-failure.pdll
mlir/test/mlir-pdll/Parser/rewrite.pdll
Modified:
mlir/include/mlir/Tools/PDLL/AST/Nodes.h
mlir/include/mlir/Tools/PDLL/AST/Types.h
mlir/lib/Tools/PDLL/AST/Context.cpp
mlir/lib/Tools/PDLL/AST/NodePrinter.cpp
mlir/lib/Tools/PDLL/AST/Nodes.cpp
mlir/lib/Tools/PDLL/AST/TypeDetail.h
mlir/lib/Tools/PDLL/AST/Types.cpp
mlir/lib/Tools/PDLL/Parser/Lexer.cpp
mlir/lib/Tools/PDLL/Parser/Lexer.h
mlir/lib/Tools/PDLL/Parser/Parser.cpp
mlir/test/mlir-pdll/Parser/expr-failure.pdll
mlir/test/mlir-pdll/Parser/expr.pdll
mlir/test/mlir-pdll/Parser/pattern-failure.pdll
mlir/test/mlir-pdll/Parser/stmt-failure.pdll
Removed:
################################################################################
diff --git a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
index d46b9004b03ad..6824354a16edd 100644
--- a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
+++ b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
@@ -300,6 +300,31 @@ class RewriteStmt final : public Node::NodeBase<RewriteStmt, OpRewriteStmt> {
CompoundStmt *rewriteBody;
};
+//===----------------------------------------------------------------------===//
+// ReturnStmt
+//===----------------------------------------------------------------------===//
+
+/// This statement represents a return from a "callable" like decl, e.g. a
+/// Constraint or a Rewrite.
+class ReturnStmt final : public Node::NodeBase<ReturnStmt, Stmt> {
+public:
+ static ReturnStmt *create(Context &ctx, SMRange loc, Expr *resultExpr);
+
+ /// Return the result expression of this statement.
+ Expr *getResultExpr() { return resultExpr; }
+ const Expr *getResultExpr() const { return resultExpr; }
+
+ /// Set the result expression of this statement.
+ void setResultExpr(Expr *expr) { resultExpr = expr; }
+
+private:
+ ReturnStmt(SMRange loc, Expr *resultExpr)
+ : Base(loc), resultExpr(resultExpr) {}
+
+ // The result expression of this statement.
+ Expr *resultExpr;
+};
+
//===----------------------------------------------------------------------===//
// Expr
//===----------------------------------------------------------------------===//
@@ -345,6 +370,43 @@ class AttributeExpr : public Node::NodeBase<AttributeExpr, Expr> {
StringRef value;
};
+//===----------------------------------------------------------------------===//
+// CallExpr
+//===----------------------------------------------------------------------===//
+
+/// This expression represents a call to a decl, such as a
+/// UserConstraintDecl/UserRewriteDecl.
+class CallExpr final : public Node::NodeBase<CallExpr, Expr>,
+ private llvm::TrailingObjects<CallExpr, Expr *> {
+public:
+ static CallExpr *create(Context &ctx, SMRange loc, Expr *callable,
+ ArrayRef<Expr *> arguments, Type resultType);
+
+ /// Return the callable of this call.
+ Expr *getCallableExpr() const { return callable; }
+
+ /// Return the arguments of this call.
+ MutableArrayRef<Expr *> getArguments() {
+ return {getTrailingObjects<Expr *>(), numArgs};
+ }
+ ArrayRef<Expr *> getArguments() const {
+ return const_cast<CallExpr *>(this)->getArguments();
+ }
+
+private:
+ CallExpr(SMRange loc, Type type, Expr *callable, unsigned numArgs)
+ : Base(loc, type), callable(callable), numArgs(numArgs) {}
+
+ /// The callable of this call.
+ Expr *callable;
+
+ /// The number of arguments of the call.
+ unsigned numArgs;
+
+ /// TrailingObject utilities.
+ friend llvm::TrailingObjects<CallExpr, Expr *>;
+};
+
//===----------------------------------------------------------------------===//
// DeclRefExpr
//===----------------------------------------------------------------------===//
@@ -738,6 +800,114 @@ class ValueRangeConstraintDecl
Expr *typeExpr;
};
+//===----------------------------------------------------------------------===//
+// UserConstraintDecl
+//===----------------------------------------------------------------------===//
+
+/// This decl represents a user defined constraint. This is either:
+/// * an imported native constraint
+/// - Similar to an external function declaration. This is a native
+/// constraint defined externally, and imported into PDLL via a
+/// declaration.
+/// * a native constraint defined in PDLL
+/// - This is a native constraint, i.e. a constraint whose implementation is
+/// defined in C++(or potentially some other non-PDLL language). The
+/// implementation of this constraint is specified as a string code block
+/// in PDLL.
+/// * a PDLL constraint
+/// - This is a constraint which is defined using only PDLL constructs.
+class UserConstraintDecl final
+ : public Node::NodeBase<UserConstraintDecl, ConstraintDecl>,
+ llvm::TrailingObjects<UserConstraintDecl, VariableDecl *> {
+public:
+ /// Create a native constraint with the given optional code block.
+ static UserConstraintDecl *createNative(Context &ctx, const Name &name,
+ ArrayRef<VariableDecl *> inputs,
+ ArrayRef<VariableDecl *> results,
+ Optional<StringRef> codeBlock,
+ Type resultType) {
+ return createImpl(ctx, name, inputs, results, codeBlock, /*body=*/nullptr,
+ resultType);
+ }
+
+ /// Create a PDLL constraint with the given body.
+ static UserConstraintDecl *createPDLL(Context &ctx, const Name &name,
+ ArrayRef<VariableDecl *> inputs,
+ ArrayRef<VariableDecl *> results,
+ const CompoundStmt *body,
+ Type resultType) {
+ return createImpl(ctx, name, inputs, results, /*codeBlock=*/llvm::None,
+ body, resultType);
+ }
+
+ /// Return the name of the constraint.
+ const Name &getName() const { return *Decl::getName(); }
+
+ /// Return the input arguments of this constraint.
+ MutableArrayRef<VariableDecl *> getInputs() {
+ return {getTrailingObjects<VariableDecl *>(), numInputs};
+ }
+ ArrayRef<VariableDecl *> getInputs() const {
+ return const_cast<UserConstraintDecl *>(this)->getInputs();
+ }
+
+ /// Return the explicit results of the constraint declaration. May be empty,
+ /// even if the constraint has results (e.g. in the case of inferred results).
+ MutableArrayRef<VariableDecl *> getResults() {
+ return {getTrailingObjects<VariableDecl *>() + numInputs, numResults};
+ }
+ ArrayRef<VariableDecl *> getResults() const {
+ return const_cast<UserConstraintDecl *>(this)->getResults();
+ }
+
+ /// Return the optional code block of this constraint, if this is a native
+ /// constraint with a provided implementation.
+ Optional<StringRef> getCodeBlock() const { return codeBlock; }
+
+ /// Return the body of this constraint if this constraint is a PDLL
+ /// constraint, otherwise returns nullptr.
+ const CompoundStmt *getBody() const { return constraintBody; }
+
+ /// Return the result type of this constraint.
+ Type getResultType() const { return resultType; }
+
+ /// Returns true if this constraint is external.
+ bool isExternal() const { return !constraintBody && !codeBlock; }
+
+private:
+ /// Create either a PDLL constraint or a native constraint with the given
+ /// components.
+ static UserConstraintDecl *
+ createImpl(Context &ctx, const Name &name, ArrayRef<VariableDecl *> inputs,
+ ArrayRef<VariableDecl *> results, Optional<StringRef> codeBlock,
+ const CompoundStmt *body, Type resultType);
+
+ UserConstraintDecl(const Name &name, unsigned numInputs, unsigned numResults,
+ Optional<StringRef> codeBlock, const CompoundStmt *body,
+ Type resultType)
+ : Base(name.getLoc(), &name), numInputs(numInputs),
+ numResults(numResults), codeBlock(codeBlock), constraintBody(body),
+ resultType(resultType) {}
+
+ /// The number of inputs to this constraint.
+ unsigned numInputs;
+
+ /// The number of explicit results to this constraint.
+ unsigned numResults;
+
+ /// The optional code block of this constraint.
+ Optional<StringRef> codeBlock;
+
+ /// The optional body of this constraint.
+ const CompoundStmt *constraintBody;
+
+ /// The result type of the constraint.
+ Type resultType;
+
+ /// Allow access to various internals.
+ friend llvm::TrailingObjects<UserConstraintDecl, VariableDecl *>;
+};
+
//===----------------------------------------------------------------------===//
// NamedAttributeDecl
//===----------------------------------------------------------------------===//
@@ -826,6 +996,149 @@ class PatternDecl : public Node::NodeBase<PatternDecl, Decl> {
const CompoundStmt *patternBody;
};
+//===----------------------------------------------------------------------===//
+// UserRewriteDecl
+//===----------------------------------------------------------------------===//
+
+/// This decl represents a user defined rewrite. This is either:
+/// * an imported native rewrite
+/// - Similar to an external function declaration. This is a native
+/// rewrite defined externally, and imported into PDLL via a declaration.
+/// * a native rewrite defined in PDLL
+/// - This is a native rewrite, i.e. a rewrite whose implementation is
+/// defined in C++(or potentially some other non-PDLL language). The
+/// implementation of this rewrite is specified as a string code block
+/// in PDLL.
+/// * a PDLL rewrite
+/// - This is a rewrite which is defined using only PDLL constructs.
+class UserRewriteDecl final
+ : public Node::NodeBase<UserRewriteDecl, Decl>,
+ llvm::TrailingObjects<UserRewriteDecl, VariableDecl *> {
+public:
+ /// Create a native rewrite with the given optional code block.
+ static UserRewriteDecl *createNative(Context &ctx, const Name &name,
+ ArrayRef<VariableDecl *> inputs,
+ ArrayRef<VariableDecl *> results,
+ Optional<StringRef> codeBlock,
+ Type resultType) {
+ return createImpl(ctx, name, inputs, results, codeBlock, /*body=*/nullptr,
+ resultType);
+ }
+
+ /// Create a PDLL rewrite with the given body.
+ static UserRewriteDecl *createPDLL(Context &ctx, const Name &name,
+ ArrayRef<VariableDecl *> inputs,
+ ArrayRef<VariableDecl *> results,
+ const CompoundStmt *body,
+ Type resultType) {
+ return createImpl(ctx, name, inputs, results, /*codeBlock=*/llvm::None,
+ body, resultType);
+ }
+
+ /// Return the name of the rewrite.
+ const Name &getName() const { return *Decl::getName(); }
+
+ /// Return the input arguments of this rewrite.
+ MutableArrayRef<VariableDecl *> getInputs() {
+ return {getTrailingObjects<VariableDecl *>(), numInputs};
+ }
+ ArrayRef<VariableDecl *> getInputs() const {
+ return const_cast<UserRewriteDecl *>(this)->getInputs();
+ }
+
+ /// Return the explicit results of the rewrite declaration. May be empty,
+ /// even if the rewrite has results (e.g. in the case of inferred results).
+ MutableArrayRef<VariableDecl *> getResults() {
+ return {getTrailingObjects<VariableDecl *>() + numInputs, numResults};
+ }
+ ArrayRef<VariableDecl *> getResults() const {
+ return const_cast<UserRewriteDecl *>(this)->getResults();
+ }
+
+ /// Return the optional code block of this rewrite, if this is a native
+ /// rewrite with a provided implementation.
+ Optional<StringRef> getCodeBlock() const { return codeBlock; }
+
+ /// Return the body of this rewrite if this rewrite is a PDLL rewrite,
+ /// otherwise returns nullptr.
+ const CompoundStmt *getBody() const { return rewriteBody; }
+
+ /// Return the result type of this rewrite.
+ Type getResultType() const { return resultType; }
+
+ /// Returns true if this rewrite is external.
+ bool isExternal() const { return !rewriteBody && !codeBlock; }
+
+private:
+ /// Create either a PDLL rewrite or a native rewrite with the given
+ /// components.
+ static UserRewriteDecl *createImpl(Context &ctx, const Name &name,
+ ArrayRef<VariableDecl *> inputs,
+ ArrayRef<VariableDecl *> results,
+ Optional<StringRef> codeBlock,
+ const CompoundStmt *body, Type resultType);
+
+ UserRewriteDecl(const Name &name, unsigned numInputs, unsigned numResults,
+ Optional<StringRef> codeBlock, const CompoundStmt *body,
+ Type resultType)
+ : Base(name.getLoc(), &name), numInputs(numInputs),
+ numResults(numResults), codeBlock(codeBlock), rewriteBody(body),
+ resultType(resultType) {}
+
+ /// The number of inputs to this rewrite.
+ unsigned numInputs;
+
+ /// The number of explicit results to this rewrite.
+ unsigned numResults;
+
+ /// The optional code block of this rewrite.
+ Optional<StringRef> codeBlock;
+
+ /// The optional body of this rewrite.
+ const CompoundStmt *rewriteBody;
+
+ /// The result type of the rewrite.
+ Type resultType;
+
+ /// Allow access to various internals.
+ friend llvm::TrailingObjects<UserRewriteDecl, VariableDecl *>;
+};
+
+//===----------------------------------------------------------------------===//
+// CallableDecl
+//===----------------------------------------------------------------------===//
+
+/// This decl represents a shared interface for all callable decls.
+class CallableDecl : public Decl {
+public:
+ /// Return the callable type of this decl.
+ StringRef getCallableType() const {
+ if (isa<UserConstraintDecl>(this))
+ return "constraint";
+ assert(isa<UserRewriteDecl>(this) && "unknown callable type");
+ return "rewrite";
+ }
+
+ /// Return the inputs of this decl.
+ ArrayRef<VariableDecl *> getInputs() const {
+ if (const auto *cst = dyn_cast<UserConstraintDecl>(this))
+ return cst->getInputs();
+ return cast<UserRewriteDecl>(this)->getInputs();
+ }
+
+ /// Return the result type of this decl.
+ Type getResultType() const {
+ if (const auto *cst = dyn_cast<UserConstraintDecl>(this))
+ return cst->getResultType();
+ return cast<UserRewriteDecl>(this)->getResultType();
+ }
+
+ /// Support LLVM type casting facilities.
+ static bool classof(const Node *decl) {
+ return isa<UserConstraintDecl, UserRewriteDecl>(decl);
+ }
+};
+
//===----------------------------------------------------------------------===//
// VariableDecl
//===----------------------------------------------------------------------===//
@@ -912,11 +1225,11 @@ class Module final : public Node::NodeBase<Module, Node>,
inline bool Decl::classof(const Node *node) {
return isa<ConstraintDecl, NamedAttributeDecl, OpNameDecl, PatternDecl,
- VariableDecl>(node);
+ UserRewriteDecl, VariableDecl>(node);
}
inline bool ConstraintDecl::classof(const Node *node) {
- return isa<CoreConstraintDecl>(node);
+ return isa<CoreConstraintDecl, UserConstraintDecl>(node);
}
inline bool CoreConstraintDecl::classof(const Node *node) {
diff --git a/mlir/include/mlir/Tools/PDLL/AST/Types.h b/mlir/include/mlir/Tools/PDLL/AST/Types.h
index cac3cae962c2d..58a20801fb1b5 100644
--- a/mlir/include/mlir/Tools/PDLL/AST/Types.h
+++ b/mlir/include/mlir/Tools/PDLL/AST/Types.h
@@ -22,6 +22,7 @@ struct AttributeTypeStorage;
struct ConstraintTypeStorage;
struct OperationTypeStorage;
struct RangeTypeStorage;
+struct RewriteTypeStorage;
struct TupleTypeStorage;
struct TypeTypeStorage;
struct ValueTypeStorage;
@@ -203,6 +204,20 @@ class ValueRangeType : public RangeType {
static ValueRangeType get(Context &context);
};
+//===----------------------------------------------------------------------===//
+// RewriteType
+//===----------------------------------------------------------------------===//
+
+/// This class represents a PDLL type that corresponds to a rewrite reference.
+/// This type has no MLIR C++ API correspondence.
+class RewriteType : public Type::TypeBase<detail::RewriteTypeStorage> {
+public:
+ using Base::Base;
+
+ /// Return an instance of the Rewrite type.
+ static RewriteType get(Context &context);
+};
+
//===----------------------------------------------------------------------===//
// TupleType
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Tools/PDLL/AST/Context.cpp b/mlir/lib/Tools/PDLL/AST/Context.cpp
index 960870aabf436..09ae0e6ad6e07 100644
--- a/mlir/lib/Tools/PDLL/AST/Context.cpp
+++ b/mlir/lib/Tools/PDLL/AST/Context.cpp
@@ -15,6 +15,7 @@ using namespace mlir::pdll::ast;
Context::Context() {
typeUniquer.registerSingletonStorageType<detail::AttributeTypeStorage>();
typeUniquer.registerSingletonStorageType<detail::ConstraintTypeStorage>();
+ typeUniquer.registerSingletonStorageType<detail::RewriteTypeStorage>();
typeUniquer.registerSingletonStorageType<detail::TypeTypeStorage>();
typeUniquer.registerSingletonStorageType<detail::ValueTypeStorage>();
diff --git a/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp b/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp
index cc7bdd47bf2c5..2be71b67f5292 100644
--- a/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp
+++ b/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp
@@ -76,9 +76,11 @@ class NodePrinter {
void printImpl(const EraseStmt *stmt);
void printImpl(const LetStmt *stmt);
void printImpl(const ReplaceStmt *stmt);
+ void printImpl(const ReturnStmt *stmt);
void printImpl(const RewriteStmt *stmt);
void printImpl(const AttributeExpr *expr);
+ void printImpl(const CallExpr *expr);
void printImpl(const DeclRefExpr *expr);
void printImpl(const MemberAccessExpr *expr);
void printImpl(const OperationExpr *expr);
@@ -89,11 +91,13 @@ class NodePrinter {
void printImpl(const OpConstraintDecl *decl);
void printImpl(const TypeConstraintDecl *decl);
void printImpl(const TypeRangeConstraintDecl *decl);
+ void printImpl(const UserConstraintDecl *decl);
void printImpl(const ValueConstraintDecl *decl);
void printImpl(const ValueRangeConstraintDecl *decl);
void printImpl(const NamedAttributeDecl *decl);
void printImpl(const OpNameDecl *decl);
void printImpl(const PatternDecl *decl);
+ void printImpl(const UserRewriteDecl *decl);
void printImpl(const VariableDecl *decl);
void printImpl(const Module *module);
@@ -135,6 +139,7 @@ void NodePrinter::print(Type type) {
print(type.getElementType());
os << "Range";
})
+ .Case([&](RewriteType) { os << "Rewrite"; })
.Case([&](TupleType type) {
os << "Tuple<";
llvm::interleaveComma(
@@ -160,17 +165,19 @@ void NodePrinter::print(const Node *node) {
.Case<
// Statements.
const CompoundStmt, const EraseStmt, const LetStmt, const ReplaceStmt,
- const RewriteStmt,
+ const ReturnStmt, const RewriteStmt,
// Expressions.
- const AttributeExpr, const DeclRefExpr, const MemberAccessExpr,
- const OperationExpr, const TupleExpr, const TypeExpr,
+ const AttributeExpr, const CallExpr, const DeclRefExpr,
+ const MemberAccessExpr, const OperationExpr, const TupleExpr,
+ const TypeExpr,
// Decls.
const AttrConstraintDecl, const OpConstraintDecl,
const TypeConstraintDecl, const TypeRangeConstraintDecl,
- const ValueConstraintDecl, const ValueRangeConstraintDecl,
- const NamedAttributeDecl, const OpNameDecl, const PatternDecl,
+ const UserConstraintDecl, const ValueConstraintDecl,
+ const ValueRangeConstraintDecl, const NamedAttributeDecl,
+ const OpNameDecl, const PatternDecl, const UserRewriteDecl,
const VariableDecl,
const Module>([&](auto derivedNode) { this->printImpl(derivedNode); })
@@ -199,6 +206,11 @@ void NodePrinter::printImpl(const ReplaceStmt *stmt) {
printChildren("ReplValues", stmt->getReplExprs());
}
+void NodePrinter::printImpl(const ReturnStmt *stmt) {
+ os << "ReturnStmt " << stmt << "\n";
+ printChildren(stmt->getResultExpr());
+}
+
void NodePrinter::printImpl(const RewriteStmt *stmt) {
os << "RewriteStmt " << stmt << "\n";
printChildren(stmt->getRootOpExpr(), stmt->getRewriteBody());
@@ -208,6 +220,14 @@ void NodePrinter::printImpl(const AttributeExpr *expr) {
os << "AttributeExpr " << expr << " Value<\"" << expr->getValue() << "\">\n";
}
+void NodePrinter::printImpl(const CallExpr *expr) {
+ os << "CallExpr " << expr << " Type<";
+ print(expr->getType());
+ os << ">\n";
+ printChildren(expr->getCallableExpr());
+ printChildren("Arguments", expr->getArguments());
+}
+
void NodePrinter::printImpl(const DeclRefExpr *expr) {
os << "DeclRefExpr " << expr << " Type<";
print(expr->getType());
@@ -265,6 +285,21 @@ void NodePrinter::printImpl(const TypeRangeConstraintDecl *decl) {
os << "TypeRangeConstraintDecl " << decl << "\n";
}
+void NodePrinter::printImpl(const UserConstraintDecl *decl) {
+ os << "UserConstraintDecl " << decl << " Name<" << decl->getName().getName()
+ << "> ResultType<" << decl->getResultType() << ">";
+ if (Optional<StringRef> codeBlock = decl->getCodeBlock()) {
+ os << " Code<";
+ llvm::printEscapedString(*codeBlock, os);
+ os << ">";
+ }
+ os << "\n";
+ printChildren("Inputs", decl->getInputs());
+ printChildren("Results", decl->getResults());
+ if (const CompoundStmt *body = decl->getBody())
+ printChildren(body);
+}
+
void NodePrinter::printImpl(const ValueConstraintDecl *decl) {
os << "ValueConstraintDecl " << decl << "\n";
if (const auto *typeExpr = decl->getTypeExpr())
@@ -303,6 +338,21 @@ void NodePrinter::printImpl(const PatternDecl *decl) {
printChildren(decl->getBody());
}
+void NodePrinter::printImpl(const UserRewriteDecl *decl) {
+ os << "UserRewriteDecl " << decl << " Name<" << decl->getName().getName()
+ << "> ResultType<" << decl->getResultType() << ">";
+ if (Optional<StringRef> codeBlock = decl->getCodeBlock()) {
+ os << " Code<";
+ llvm::printEscapedString(*codeBlock, os);
+ os << ">";
+ }
+ os << "\n";
+ printChildren("Inputs", decl->getInputs());
+ printChildren("Results", decl->getResults());
+ if (const CompoundStmt *body = decl->getBody())
+ printChildren(body);
+}
+
void NodePrinter::printImpl(const VariableDecl *decl) {
os << "VariableDecl " << decl << " Name<" << decl->getName().getName()
<< "> Type<";
diff --git a/mlir/lib/Tools/PDLL/AST/Nodes.cpp b/mlir/lib/Tools/PDLL/AST/Nodes.cpp
index 76f42f5329fd4..2e7b252b35c86 100644
--- a/mlir/lib/Tools/PDLL/AST/Nodes.cpp
+++ b/mlir/lib/Tools/PDLL/AST/Nodes.cpp
@@ -108,6 +108,15 @@ RewriteStmt *RewriteStmt::create(Context &ctx, SMRange loc, Expr *rootOp,
RewriteStmt(loc, rootOp, rewriteBody);
}
+//===----------------------------------------------------------------------===//
+// ReturnStmt
+//===----------------------------------------------------------------------===//
+
+ReturnStmt *ReturnStmt::create(Context &ctx, SMRange loc, Expr *resultExpr) {
+ return new (ctx.getAllocator().Allocate<ReturnStmt>())
+ ReturnStmt(loc, resultExpr);
+}
+
//===----------------------------------------------------------------------===//
// AttributeExpr
//===----------------------------------------------------------------------===//
@@ -118,6 +127,22 @@ AttributeExpr *AttributeExpr::create(Context &ctx, SMRange loc,
AttributeExpr(ctx, loc, copyStringWithNull(ctx, value));
}
+//===----------------------------------------------------------------------===//
+// CallExpr
+//===----------------------------------------------------------------------===//
+
+CallExpr *CallExpr::create(Context &ctx, SMRange loc, Expr *callable,
+ ArrayRef<Expr *> arguments, Type resultType) {
+ unsigned allocSize = CallExpr::totalSizeToAlloc<Expr *>(arguments.size());
+ void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(CallExpr));
+
+ CallExpr *expr =
+ new (rawData) CallExpr(loc, resultType, callable, arguments.size());
+ std::uninitialized_copy(arguments.begin(), arguments.end(),
+ expr->getArguments().begin());
+ return expr;
+}
+
//===----------------------------------------------------------------------===//
// DeclRefExpr
//===----------------------------------------------------------------------===//
@@ -267,6 +292,30 @@ ValueRangeConstraintDecl *ValueRangeConstraintDecl::create(Context &ctx,
ValueRangeConstraintDecl(loc, typeExpr);
}
+//===----------------------------------------------------------------------===//
+// UserConstraintDecl
+//===----------------------------------------------------------------------===//
+
+UserConstraintDecl *UserConstraintDecl::createImpl(
+ Context &ctx, const Name &name, ArrayRef<VariableDecl *> inputs,
+ ArrayRef<VariableDecl *> results, Optional<StringRef> codeBlock,
+ const CompoundStmt *body, Type resultType) {
+ unsigned allocSize = UserConstraintDecl::totalSizeToAlloc<VariableDecl *>(
+ inputs.size() + results.size());
+ void *rawData =
+ ctx.getAllocator().Allocate(allocSize, alignof(UserConstraintDecl));
+ if (codeBlock)
+ codeBlock = codeBlock->copy(ctx.getAllocator());
+
+ UserConstraintDecl *decl = new (rawData) UserConstraintDecl(
+ name, inputs.size(), results.size(), codeBlock, body, resultType);
+ std::uninitialized_copy(inputs.begin(), inputs.end(),
+ decl->getInputs().begin());
+ std::uninitialized_copy(results.begin(), results.end(),
+ decl->getResults().begin());
+ return decl;
+}
+
//===----------------------------------------------------------------------===//
// NamedAttributeDecl
//===----------------------------------------------------------------------===//
@@ -300,6 +349,32 @@ PatternDecl *PatternDecl::create(Context &ctx, SMRange loc,
PatternDecl(loc, name, benefit, hasBoundedRecursion, body);
}
+//===----------------------------------------------------------------------===//
+// UserRewriteDecl
+//===----------------------------------------------------------------------===//
+
+UserRewriteDecl *UserRewriteDecl::createImpl(Context &ctx, const Name &name,
+ ArrayRef<VariableDecl *> inputs,
+ ArrayRef<VariableDecl *> results,
+ Optional<StringRef> codeBlock,
+ const CompoundStmt *body,
+ Type resultType) {
+ unsigned allocSize = UserRewriteDecl::totalSizeToAlloc<VariableDecl *>(
+ inputs.size() + results.size());
+ void *rawData =
+ ctx.getAllocator().Allocate(allocSize, alignof(UserRewriteDecl));
+ if (codeBlock)
+ codeBlock = codeBlock->copy(ctx.getAllocator());
+
+ UserRewriteDecl *decl = new (rawData) UserRewriteDecl(
+ name, inputs.size(), results.size(), codeBlock, body, resultType);
+ std::uninitialized_copy(inputs.begin(), inputs.end(),
+ decl->getInputs().begin());
+ std::uninitialized_copy(results.begin(), results.end(),
+ decl->getResults().begin());
+ return decl;
+}
+
//===----------------------------------------------------------------------===//
// VariableDecl
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Tools/PDLL/AST/TypeDetail.h b/mlir/lib/Tools/PDLL/AST/TypeDetail.h
index b8615175400ae..e6719fb961216 100644
--- a/mlir/lib/Tools/PDLL/AST/TypeDetail.h
+++ b/mlir/lib/Tools/PDLL/AST/TypeDetail.h
@@ -93,6 +93,12 @@ struct RangeTypeStorage : public TypeStorageBase<RangeTypeStorage, Type> {
using Base::Base;
};
+//===----------------------------------------------------------------------===//
+// RewriteType
+//===----------------------------------------------------------------------===//
+
+struct RewriteTypeStorage : public TypeStorageBase<RewriteTypeStorage> {};
+
//===----------------------------------------------------------------------===//
// TupleType
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Tools/PDLL/AST/Types.cpp b/mlir/lib/Tools/PDLL/AST/Types.cpp
index ba9b31f85adf3..cf0f0e918870b 100644
--- a/mlir/lib/Tools/PDLL/AST/Types.cpp
+++ b/mlir/lib/Tools/PDLL/AST/Types.cpp
@@ -107,6 +107,14 @@ ValueRangeType ValueRangeType::get(Context &context) {
.cast<ValueRangeType>();
}
+//===----------------------------------------------------------------------===//
+// RewriteType
+//===----------------------------------------------------------------------===//
+
+RewriteType RewriteType::get(Context &context) {
+ return context.getTypeUniquer().get<ImplTy>();
+}
+
//===----------------------------------------------------------------------===//
// TupleType
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Tools/PDLL/Parser/Lexer.cpp b/mlir/lib/Tools/PDLL/Parser/Lexer.cpp
index 3db5cfbdfd645..61c5783e24545 100644
--- a/mlir/lib/Tools/PDLL/Parser/Lexer.cpp
+++ b/mlir/lib/Tools/PDLL/Parser/Lexer.cpp
@@ -298,7 +298,9 @@ Token Lexer::lexIdentifier(const char *tokStart) {
.Case("OpName", Token::kw_OpName)
.Case("Pattern", Token::kw_Pattern)
.Case("replace", Token::kw_replace)
+ .Case("return", Token::kw_return)
.Case("rewrite", Token::kw_rewrite)
+ .Case("Rewrite", Token::kw_Rewrite)
.Case("type", Token::kw_type)
.Case("Type", Token::kw_Type)
.Case("TypeRange", Token::kw_TypeRange)
diff --git a/mlir/lib/Tools/PDLL/Parser/Lexer.h b/mlir/lib/Tools/PDLL/Parser/Lexer.h
index 4692f28ba877c..0109b0da36c79 100644
--- a/mlir/lib/Tools/PDLL/Parser/Lexer.h
+++ b/mlir/lib/Tools/PDLL/Parser/Lexer.h
@@ -55,7 +55,9 @@ class Token {
kw_OpName,
kw_Pattern,
kw_replace,
+ kw_return,
kw_rewrite,
+ kw_Rewrite,
kw_Type,
kw_TypeRange,
kw_Value,
diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
index 5264b953aaabc..5ed2481027800 100644
--- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp
+++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
@@ -50,6 +50,9 @@ class Parser {
enum class ParserContext {
/// The parser is in the global context.
Global,
+ /// The parser is currently within a Constraint, which disallows all types
+ /// of rewrites (e.g. `erase`, `replace`, calls to Rewrites, etc.).
+ Constraint,
/// The parser is currently within the matcher portion of a Pattern, which
/// is allows a terminal operation rewrite statement but no other rewrite
/// transformations.
@@ -106,6 +109,77 @@ class Parser {
FailureOr<ast::Decl *> parseTopLevelDecl();
FailureOr<ast::NamedAttributeDecl *> parseNamedAttributeDecl();
+
+ /// Parse an argument variable as part of the signature of a
+ /// UserConstraintDecl or UserRewriteDecl.
+ FailureOr<ast::VariableDecl *> parseArgumentDecl();
+
+ /// Parse a result variable as part of the signature of a UserConstraintDecl
+ /// or UserRewriteDecl.
+ FailureOr<ast::VariableDecl *> parseResultDecl(unsigned resultNum);
+
+ /// Parse a UserConstraintDecl. `isInline` signals if the constraint is being
+ /// defined in a non-global context.
+ FailureOr<ast::UserConstraintDecl *>
+ parseUserConstraintDecl(bool isInline = false);
+
+ /// Parse an inline UserConstraintDecl. An inline decl is one defined in a
+ /// non-global context, such as within a Pattern/Constraint/etc.
+ FailureOr<ast::UserConstraintDecl *> parseInlineUserConstraintDecl();
+
+ /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using
+ /// PDLL constructs.
+ FailureOr<ast::UserConstraintDecl *> parseUserPDLLConstraintDecl(
+ const ast::Name &name, bool isInline,
+ ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
+ ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
+
+ /// Parse a parseUserRewriteDecl. `isInline` signals if the rewrite is being
+ /// defined in a non-global context.
+ FailureOr<ast::UserRewriteDecl *> parseUserRewriteDecl(bool isInline = false);
+
+ /// Parse an inline UserRewriteDecl. An inline decl is one defined in a
+ /// non-global context, such as within a Pattern/Rewrite/etc.
+ FailureOr<ast::UserRewriteDecl *> parseInlineUserRewriteDecl();
+
+ /// Parse a PDLL (i.e. non-native) UserRewriteDecl whose body is defined using
+ /// PDLL constructs.
+ FailureOr<ast::UserRewriteDecl *> parseUserPDLLRewriteDecl(
+ const ast::Name &name, bool isInline,
+ ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
+ ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
+
+ /// Parse either a UserConstraintDecl or UserRewriteDecl. These decls have
+ /// effectively the same syntax, and only differ on slight semantics (given
+ /// the different parsing contexts).
+ template <typename T, typename ParseUserPDLLDeclFnT>
+ FailureOr<T *> parseUserConstraintOrRewriteDecl(
+ ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
+ StringRef anonymousNamePrefix, bool isInline);
+
+ /// Parse a native (i.e. non-PDLL) UserConstraintDecl or UserRewriteDecl.
+ /// These decls have effectively the same syntax.
+ template <typename T>
+ FailureOr<T *> parseUserNativeConstraintOrRewriteDecl(
+ const ast::Name &name, bool isInline,
+ ArrayRef<ast::VariableDecl *> arguments,
+ ArrayRef<ast::VariableDecl *> results, ast::Type resultType);
+
+ /// Parse the functional signature (i.e. the arguments and results) of a
+ /// UserConstraintDecl or UserRewriteDecl.
+ LogicalResult parseUserConstraintOrRewriteSignature(
+ SmallVectorImpl<ast::VariableDecl *> &arguments,
+ SmallVectorImpl<ast::VariableDecl *> &results,
+ ast::DeclScope *&argumentScope, ast::Type &resultType);
+
+ /// Validate the return (which if present is specified by bodyIt) of a
+ /// UserConstraintDecl or UserRewriteDecl.
+ LogicalResult validateUserConstraintOrRewriteReturn(
+ StringRef declType, ast::CompoundStmt *body,
+ ArrayRef<ast::Stmt *>::iterator bodyIt,
+ ArrayRef<ast::Stmt *>::iterator bodyE,
+ ArrayRef<ast::VariableDecl *> results, ast::Type &resultType);
+
FailureOr<ast::CompoundStmt *>
parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
bool expectTerminalSemicolon = true);
@@ -138,10 +212,17 @@ class Parser {
/// location of a previously parsed type constraint for the entity that will
/// be constrained by the parsed constraint. `existingConstraints` are any
/// existing constraints that have already been parsed for the same entity
- /// that will be constrained by this constraint.
+ /// that will be constrained by this constraint. `allowInlineTypeConstraints`
+ /// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`.
FailureOr<ast::ConstraintRef>
parseConstraint(Optional<SMRange> &typeConstraint,
- ArrayRef<ast::ConstraintRef> existingConstraints);
+ ArrayRef<ast::ConstraintRef> existingConstraints,
+ bool allowInlineTypeConstraints);
+
+ /// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl
+ /// argument or result variable. The constraints for these variables do not
+ /// allow inline type constraints, and only permit a single constraint.
+ FailureOr<ast::ConstraintRef> parseArgOrResultConstraint();
//===--------------------------------------------------------------------===//
// Exprs
@@ -150,8 +231,11 @@ class Parser {
/// Identifier expressions.
FailureOr<ast::Expr *> parseAttributeExpr();
+ FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr);
FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc);
FailureOr<ast::Expr *> parseIdentifierExpr();
+ FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr();
+ FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr();
FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr);
FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false);
FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName);
@@ -168,6 +252,7 @@ class Parser {
FailureOr<ast::EraseStmt *> parseEraseStmt();
FailureOr<ast::LetStmt *> parseLetStmt();
FailureOr<ast::ReplaceStmt *> parseReplaceStmt();
+ FailureOr<ast::ReturnStmt *> parseReturnStmt();
FailureOr<ast::RewriteStmt *> parseRewriteStmt();
//===--------------------------------------------------------------------===//
@@ -177,6 +262,10 @@ class Parser {
//===--------------------------------------------------------------------===//
// Decls
+ /// Try to extract a callable from the given AST node. Returns nullptr on
+ /// failure.
+ ast::CallableDecl *tryExtractCallableDecl(ast::Node *node);
+
/// Try to create a pattern decl with the given components, returning the
/// Pattern on success.
FailureOr<ast::PatternDecl *>
@@ -184,12 +273,30 @@ class Parser {
const ParsedPatternMetadata &metadata,
ast::CompoundStmt *body);
+ /// Build the result type for a UserConstraintDecl/UserRewriteDecl given a set
+ /// of results, defined as part of the signature.
+ ast::Type
+ createUserConstraintRewriteResultType(ArrayRef<ast::VariableDecl *> results);
+
+ /// Create a PDLL (i.e. non-native) UserConstraintDecl or UserRewriteDecl.
+ template <typename T>
+ FailureOr<T *> createUserPDLLConstraintOrRewriteDecl(
+ const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
+ ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
+ ast::CompoundStmt *body);
+
/// Try to create a variable decl with the given components, returning the
/// Variable on success.
FailureOr<ast::VariableDecl *>
createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
ArrayRef<ast::ConstraintRef> constraints);
+ /// Create a variable for an argument or result defined as part of the
+ /// signature of a UserConstraintDecl/UserRewriteDecl.
+ FailureOr<ast::VariableDecl *>
+ createArgOrResultVariableDecl(StringRef name, SMRange loc,
+ const ast::ConstraintRef &constraint);
+
/// Validate the constraints used to constraint a variable decl.
/// `inferredType` is the type of the variable inferred by the constraints
/// within the list, and is updated to the most refined type as determined by
@@ -201,23 +308,26 @@ class Parser {
/// Validate a single reference to a constraint. `inferredType` contains the
/// currently inferred variabled type and is refined within the type defined
/// by the constraint. Returns success if the constraint is valid, failure
- /// otherwise.
+ /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user
+ /// defined constraints) may be used with the variable.
LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref,
- ast::Type &inferredType);
+ ast::Type &inferredType,
+ bool allowNonCoreConstraints = true);
LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr);
LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr);
//===--------------------------------------------------------------------===//
// Exprs
- FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc,
- ast::Decl *decl);
+ FailureOr<ast::CallExpr *>
+ createCallExpr(SMRange loc, ast::Expr *parentExpr,
+ MutableArrayRef<ast::Expr *> arguments);
+ FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl);
FailureOr<ast::DeclRefExpr *>
createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
ArrayRef<ast::ConstraintRef> constraints);
FailureOr<ast::MemberAccessExpr *>
- createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
- SMRange loc);
+ createMemberAccessExpr(ast::Expr *parentExpr, StringRef name, SMRange loc);
/// Validate the member access `name` into the given parent expression. On
/// success, this also returns the type of the member accessed.
@@ -231,12 +341,10 @@ class Parser {
LogicalResult
validateOperationOperands(SMRange loc, Optional<StringRef> name,
MutableArrayRef<ast::Expr *> operands);
- LogicalResult validateOperationResults(SMRange loc,
- Optional<StringRef> name,
+ LogicalResult validateOperationResults(SMRange loc, Optional<StringRef> name,
MutableArrayRef<ast::Expr *> results);
LogicalResult
- validateOperationOperandsOrResults(SMRange loc,
- Optional<StringRef> name,
+ validateOperationOperandsOrResults(SMRange loc, Optional<StringRef> name,
MutableArrayRef<ast::Expr *> values,
ast::Type singleTy, ast::Type rangeTy);
FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc,
@@ -246,8 +354,7 @@ class Parser {
//===--------------------------------------------------------------------===//
// Stmts
- FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc,
- ast::Expr *rootOp);
+ FailureOr<ast::EraseStmt *> createEraseStmt(SMRange loc, ast::Expr *rootOp);
FailureOr<ast::ReplaceStmt *>
createReplaceStmt(SMRange loc, ast::Expr *rootOp,
MutableArrayRef<ast::Expr *> replValues);
@@ -304,8 +411,8 @@ class Parser {
LogicalResult emitError(const Twine &msg) {
return emitError(curToken.getLoc(), msg);
}
- LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg,
- SMRange noteLoc, const Twine &note) {
+ LogicalResult emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc,
+ const Twine &note) {
lexer.emitErrorAndNote(loc, msg, noteLoc, note);
return failure();
}
@@ -333,6 +440,9 @@ class Parser {
/// Cached types to simplify verification and expression creation.
ast::Type valueTy, valueRangeTy;
ast::Type typeTy, typeRangeTy;
+
+ /// A counter used when naming anonymous constraints and rewrites.
+ unsigned anonymousDeclNameCounter = 0;
};
} // namespace
@@ -506,9 +616,15 @@ LogicalResult Parser::parseInclude(SmallVector<ast::Decl *> &decls) {
FailureOr<ast::Decl *> Parser::parseTopLevelDecl() {
FailureOr<ast::Decl *> decl;
switch (curToken.getKind()) {
+ case Token::kw_Constraint:
+ decl = parseUserConstraintDecl();
+ break;
case Token::kw_Pattern:
decl = parsePatternDecl();
break;
+ case Token::kw_Rewrite:
+ decl = parseUserRewriteDecl();
+ break;
default:
return emitError("expected top-level declaration, such as a `Pattern`");
}
@@ -570,6 +686,363 @@ FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody(
return ast::CompoundStmt::create(ctx, bodyLoc, *singleStatement);
}
+/// Parse a single argument of a Constraint/Rewrite signature, written as
+/// `name: <constraint>`. The name may be an identifier or a dependent keyword.
+/// The constraint is parsed by parseArgOrResultConstraint, so exactly one
+/// constraint is accepted and inline type constraints are rejected.
+FailureOr<ast::VariableDecl *> Parser::parseArgumentDecl() {
+ // Ensure that the argument is named.
+ if (curToken.isNot(Token::identifier) && !curToken.isDependentKeyword())
+ return emitError("expected identifier argument name");
+
+ // Parse the argument similarly to a normal variable.
+ StringRef name = curToken.getSpelling();
+ SMRange nameLoc = curToken.getLoc();
+ consumeToken();
+
+ if (failed(
+ parseToken(Token::colon, "expected `:` before argument constraint")))
+ return failure();
+
+ FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
+ if (failed(cst))
+ return failure();
+
+ return createArgOrResultVariableDecl(name, nameLoc, *cst);
+}
+
+/// Parse one result of a Constraint/Rewrite signature. A result is either
+/// named (`name: <constraint>`) or unnamed (a bare `<constraint>`).
+/// An identifier that resolves to an existing ConstraintDecl in the current
+/// scope is treated as the constraint of an unnamed result, not as a name.
+/// Inside a Rewrite, that case is rejected with an error.
+/// Unnamed results get an empty name located at the constraint reference.
+/// NOTE(review): `resultNum` is not referenced in this body.
+FailureOr<ast::VariableDecl *> Parser::parseResultDecl(unsigned resultNum) {
+ // Check to see if this result is named.
+ if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
+ // Check to see if this name actually refers to a Constraint.
+ ast::Decl *existingDecl = curDeclScope->lookup(curToken.getSpelling());
+ if (isa_and_nonnull<ast::ConstraintDecl>(existingDecl)) {
+ // If yes, and this is a Rewrite, give a nice error message as non-Core
+ // constraints are not supported on Rewrite results.
+ if (parserContext == ParserContext::Rewrite) {
+ return emitError(
+ "`Rewrite` results are only permitted to use core constraints, "
+ "such as `Attr`, `Op`, `Type`, `TypeRange`, `Value`, `ValueRange`");
+ }
+
+ // Otherwise (not in a Rewrite), leave this branch without consuming the
+ // token; it is parsed below as the constraint of an unnamed result.
+ } else {
+ // If it wasn't a constraint, parse the result similarly to a variable. If
+ // there is already an existing decl, we will emit an error when defining
+ // this variable later.
+ StringRef name = curToken.getSpelling();
+ SMRange nameLoc = curToken.getLoc();
+ consumeToken();
+
+ if (failed(parseToken(Token::colon,
+ "expected `:` before result constraint")))
+ return failure();
+
+ FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
+ if (failed(cst))
+ return failure();
+
+ return createArgOrResultVariableDecl(name, nameLoc, *cst);
+ }
+ }
+
+ // If it isn't named, we parse the constraint directly and create an unnamed
+ // result variable.
+ FailureOr<ast::ConstraintRef> cst = parseArgOrResultConstraint();
+ if (failed(cst))
+ return failure();
+
+ return createArgOrResultVariableDecl("", cst->referenceLoc, *cst);
+}
+
+/// Parse a `Constraint` declaration. `isInline` is true when the decl appears
+/// outside the global scope; in that case the name is optional.
+/// "constraint" is the prefix used when generating anonymous names.
+FailureOr<ast::UserConstraintDecl *>
+Parser::parseUserConstraintDecl(bool isInline) {
+ // Constraints and rewrites have very similar formats, dispatch to a shared
+ // interface for parsing.
+ return parseUserConstraintOrRewriteDecl<ast::UserConstraintDecl>(
+ [&](auto &&...args) { return parseUserPDLLConstraintDecl(args...); },
+ ParserContext::Constraint, "constraint", isInline);
+}
+
+/// Parse an inline (non-global) `Constraint` decl and add it to the current
+/// decl scope. Fails if the decl cannot be parsed or if checkDefineNamedDecl
+/// rejects its name.
+FailureOr<ast::UserConstraintDecl *> Parser::parseInlineUserConstraintDecl() {
+ FailureOr<ast::UserConstraintDecl *> decl =
+ parseUserConstraintDecl(/*isInline=*/true);
+ if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
+ return failure();
+
+ curDeclScope->add(*decl);
+ return decl;
+}
+
+/// Parse the PDLL body of a Constraint whose signature has already been parsed.
+/// The body takes one of two forms:
+/// - a lambda (`=> <expr>`), whose single expression is wrapped in a
+///   ReturnStmt;
+/// - a compound block (`{ ... }`), whose `return` structure is checked by
+///   validateUserConstraintOrRewriteReturn.
+/// The argument scope is re-entered while the body is parsed so that
+/// arguments can be referenced, and is popped afterwards.
+FailureOr<ast::UserConstraintDecl *> Parser::parseUserPDLLConstraintDecl(
+ const ast::Name &name, bool isInline,
+ ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
+ ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
+ // Push the argument scope back onto the list, so that the body can
+ // reference arguments.
+ pushDeclScope(argumentScope);
+
+ // Parse the body of the constraint. The body is either defined as a compound
+ // block, i.e. `{ ... }`, or a lambda body, i.e. `=> <expr>`.
+ ast::CompoundStmt *body;
+ if (curToken.is(Token::equal_arrow)) {
+ // A terminating `;` is only expected for non-inline decls.
+ FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
+ [&](ast::Stmt *&stmt) -> LogicalResult {
+ ast::Expr *stmtExpr = dyn_cast<ast::Expr>(stmt);
+ if (!stmtExpr) {
+ return emitError(stmt->getLoc(),
+ "expected `Constraint` lambda body to contain a "
+ "single expression");
+ }
+ // The lambda's expression becomes an implicit `return`.
+ stmt = ast::ReturnStmt::create(ctx, stmt->getLoc(), stmtExpr);
+ return success();
+ },
+ /*expectTerminalSemicolon=*/!isInline);
+ if (failed(bodyResult))
+ return failure();
+ body = *bodyResult;
+ } else {
+ FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
+ if (failed(bodyResult))
+ return failure();
+ body = *bodyResult;
+
+ // Verify the structure of the body.
+ // Find the first `return`; bodyIt == bodyE if there is none.
+ auto bodyIt = body->begin(), bodyE = body->end();
+ for (; bodyIt != bodyE; ++bodyIt)
+ if (isa<ast::ReturnStmt>(*bodyIt))
+ break;
+ if (failed(validateUserConstraintOrRewriteReturn(
+ "Constraint", body, bodyIt, bodyE, results, resultType)))
+ return failure();
+ }
+ popDeclScope();
+
+ return createUserPDLLConstraintOrRewriteDecl<ast::UserConstraintDecl>(
+ name, arguments, results, resultType, body);
+}
+
+/// Parse a `Rewrite` declaration. `isInline` is true when the decl appears
+/// outside the global scope; in that case the name is optional.
+/// "rewrite" is the prefix used when generating anonymous names.
+FailureOr<ast::UserRewriteDecl *> Parser::parseUserRewriteDecl(bool isInline) {
+ // Constraints and rewrites have very similar formats, dispatch to a shared
+ // interface for parsing.
+ return parseUserConstraintOrRewriteDecl<ast::UserRewriteDecl>(
+ [&](auto &&...args) { return parseUserPDLLRewriteDecl(args...); },
+ ParserContext::Rewrite, "rewrite", isInline);
+}
+
+/// Parse an inline (non-global) `Rewrite` decl and add it to the current decl
+/// scope. Fails if the decl cannot be parsed or if checkDefineNamedDecl
+/// rejects its name.
+FailureOr<ast::UserRewriteDecl *> Parser::parseInlineUserRewriteDecl() {
+ FailureOr<ast::UserRewriteDecl *> decl =
+ parseUserRewriteDecl(/*isInline=*/true);
+ if (failed(decl) || failed(checkDefineNamedDecl((*decl)->getName())))
+ return failure();
+
+ curDeclScope->add(*decl);
+ return decl;
+}
+
+/// Parse the PDLL body of a Rewrite whose signature has already been parsed.
+/// The body takes one of two forms:
+/// - a lambda (`=> ...`): an operation rewrite statement (such as `erase`,
+///   `replace`, or `rewrite`) is kept as is, and a single expression is
+///   wrapped in a ReturnStmt;
+/// - a compound block (`{ ... }`).
+/// In both cases the `return` structure is then checked by
+/// validateUserConstraintOrRewriteReturn.
+FailureOr<ast::UserRewriteDecl *> Parser::parseUserPDLLRewriteDecl(
+ const ast::Name &name, bool isInline,
+ ArrayRef<ast::VariableDecl *> arguments, ast::DeclScope *argumentScope,
+ ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
+ // Push the argument scope back onto the list, so that the body can
+ // reference arguments.
+ // NOTE(review): parseUserPDLLConstraintDecl does this through
+ // pushDeclScope(argumentScope). The direct assignment here is presumably
+ // equivalent - confirm, and prefer one spelling for consistency.
+ curDeclScope = argumentScope;
+ ast::CompoundStmt *body;
+ if (curToken.is(Token::equal_arrow)) {
+ // A terminating `;` is only expected for non-inline decls.
+ FailureOr<ast::CompoundStmt *> bodyResult = parseLambdaBody(
+ [&](ast::Stmt *&statement) -> LogicalResult {
+ if (isa<ast::OpRewriteStmt>(statement))
+ return success();
+
+ ast::Expr *statementExpr = dyn_cast<ast::Expr>(statement);
+ if (!statementExpr) {
+ return emitError(
+ statement->getLoc(),
+ "expected `Rewrite` lambda body to contain a single expression "
+ "or an operation rewrite statement; such as `erase`, "
+ "`replace`, or `rewrite`");
+ }
+ // A lambda expression becomes an implicit `return`.
+ statement =
+ ast::ReturnStmt::create(ctx, statement->getLoc(), statementExpr);
+ return success();
+ },
+ /*expectTerminalSemicolon=*/!isInline);
+ if (failed(bodyResult))
+ return failure();
+ body = *bodyResult;
+ } else {
+ FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
+ if (failed(bodyResult))
+ return failure();
+ body = *bodyResult;
+ }
+ popDeclScope();
+
+ // Verify the structure of the body.
+ // Find the first `return`; bodyIt == bodyE if there is none.
+ auto bodyIt = body->begin(), bodyE = body->end();
+ for (; bodyIt != bodyE; ++bodyIt)
+ if (isa<ast::ReturnStmt>(*bodyIt))
+ break;
+ if (failed(validateUserConstraintOrRewriteReturn("Rewrite", body, bodyIt,
+ bodyE, results, resultType)))
+ return failure();
+ return createUserPDLLConstraintOrRewriteDecl<ast::UserRewriteDecl>(
+ name, arguments, results, resultType, body);
+}
+
+/// Shared parser for `Constraint` and `Rewrite` decls.
+/// - The current token (the `Constraint`/`Rewrite` keyword at the visible call
+///   sites) is consumed unconditionally.
+/// - `parserContext` is set to `declContext` for the rest of the decl and
+///   restored on exit.
+/// - A name is required unless `isInline`. A missing inline name is replaced
+///   by a unique `<anonymous_{anonymousNamePrefix}_{N}>` name.
+/// - After the signature, a `{` or `=>` selects a PDLL body, which is handed
+///   to `parseUserPDLLFn`. Anything else is parsed as a native decl.
+template <typename T, typename ParseUserPDLLDeclFnT>
+FailureOr<T *> Parser::parseUserConstraintOrRewriteDecl(
+ ParseUserPDLLDeclFnT &&parseUserPDLLFn, ParserContext declContext,
+ StringRef anonymousNamePrefix, bool isInline) {
+ SMRange loc = curToken.getLoc();
+ consumeToken();
+ // The context drives checks such as rejecting `erase`/`replace`/`rewrite`
+ // inside a Constraint; SaveAndRestore resets it when this function returns.
+ llvm::SaveAndRestore<ParserContext> saveCtx(parserContext, declContext);
+
+ // Parse the name of the decl.
+ const ast::Name *name = nullptr;
+ if (curToken.isNot(Token::identifier)) {
+ // Only inline decls can be un-named. Inline decls are similar to "lambdas"
+ // in C++, so being unnamed is fine.
+ if (!isInline)
+ return emitError("expected identifier name");
+
+ // Create a unique anonymous name to use, as the name for this decl is not
+ // important.
+ std::string anonName =
+ llvm::formatv("<anonymous_{0}_{1}>", anonymousNamePrefix,
+ anonymousDeclNameCounter++)
+ .str();
+ name = &ast::Name::create(ctx, anonName, loc);
+ } else {
+ // If a name was provided, we can use it directly.
+ name = &ast::Name::create(ctx, curToken.getSpelling(), curToken.getLoc());
+ consumeToken(Token::identifier);
+ }
+
+ // Parse the functional signature of the decl.
+ // argumentScope is assigned by the signature parser on success.
+ SmallVector<ast::VariableDecl *> arguments, results;
+ ast::DeclScope *argumentScope;
+ ast::Type resultType;
+ if (failed(parseUserConstraintOrRewriteSignature(arguments, results,
+ argumentScope, resultType)))
+ return failure();
+
+ // Check to see which type of constraint this is. If the constraint contains a
+ // compound body, this is a PDLL decl.
+ if (curToken.isAny(Token::l_brace, Token::equal_arrow))
+ return parseUserPDLLFn(*name, isInline, arguments, argumentScope, results,
+ resultType);
+
+ // Otherwise, this is a native decl.
+ return parseUserNativeConstraintOrRewriteDecl<T>(*name, isInline, arguments,
+ results, resultType);
+}
+
+/// Parse the remainder of a native (non-PDLL) Constraint/Rewrite decl: an
+/// optional string literal holding the native code body, followed by `;`.
+/// Inline decls must provide the code string. A body-less (external) decl is
+/// only accepted at global scope.
+template <typename T>
+FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl(
+ const ast::Name &name, bool isInline,
+ ArrayRef<ast::VariableDecl *> arguments,
+ ArrayRef<ast::VariableDecl *> results, ast::Type resultType) {
+ // If followed by a string, the native code body has also been specified.
+ // codeStrStorage owns the text; optCodeStr is a StringRef into it.
+ std::string codeStrStorage;
+ Optional<StringRef> optCodeStr;
+ if (curToken.isString()) {
+ codeStrStorage = curToken.getStringValue();
+ optCodeStr = codeStrStorage;
+ consumeToken();
+ } else if (isInline) {
+ return emitError(name.getLoc(),
+ "external declarations must be declared in global scope");
+ }
+ if (failed(parseToken(Token::semicolon,
+ "expected `;` after native declaration")))
+ return failure();
+ return T::createNative(ctx, name, arguments, results, optCodeStr, resultType);
+}
+
+/// Parse the signature `(<args>) [-> <result> | -> (<result>, ...)]` of a
+/// Constraint/Rewrite.
+/// - Arguments are declared in a fresh decl scope. That scope is returned
+///   through `argumentScope` and popped here, so the caller can re-enter it
+///   for the body.
+/// - Results are parsed in a separate, temporary scope.
+/// - `resultType` is computed from the parsed results.
+/// - A single result must not carry a name.
+LogicalResult Parser::parseUserConstraintOrRewriteSignature(
+ SmallVectorImpl<ast::VariableDecl *> &arguments,
+ SmallVectorImpl<ast::VariableDecl *> &results,
+ ast::DeclScope *&argumentScope, ast::Type &resultType) {
+ // Parse the argument list of the decl.
+ if (failed(parseToken(Token::l_paren, "expected `(` to start argument list")))
+ return failure();
+
+ argumentScope = pushDeclScope();
+ if (curToken.isNot(Token::r_paren)) {
+ do {
+ FailureOr<ast::VariableDecl *> argument = parseArgumentDecl();
+ if (failed(argument))
+ return failure();
+ arguments.emplace_back(*argument);
+ } while (consumeIf(Token::comma));
+ }
+ popDeclScope();
+ if (failed(parseToken(Token::r_paren, "expected `)` to end argument list")))
+ return failure();
+
+ // Parse the results of the decl.
+ // Results get their own temporary scope (the argument scope was popped
+ // above), presumably so result names do not land in the argument scope
+ // that the body later re-enters.
+ pushDeclScope();
+ if (consumeIf(Token::arrow)) {
+ auto parseResultFn = [&]() -> LogicalResult {
+ FailureOr<ast::VariableDecl *> result = parseResultDecl(results.size());
+ if (failed(result))
+ return failure();
+ results.emplace_back(*result);
+ return success();
+ };
+
+ // Check for a list of results.
+ // The do/while parses at least one result, so an empty `-> ()` is not
+ // accepted by this path.
+ if (consumeIf(Token::l_paren)) {
+ do {
+ if (failed(parseResultFn()))
+ return failure();
+ } while (consumeIf(Token::comma));
+ if (failed(parseToken(Token::r_paren, "expected `)` to end result list")))
+ return failure();
+
+ // Otherwise, there is only one result.
+ } else if (failed(parseResultFn())) {
+ return failure();
+ }
+ }
+ popDeclScope();
+
+ // Compute the result type of the decl.
+ resultType = createUserConstraintRewriteResultType(results);
+
+ // Verify that results are only named if there are more than one.
+ if (results.size() == 1 && !results.front()->getName().getName().empty()) {
+ return emitError(
+ results.front()->getLoc(),
+ "cannot create a single-element tuple with an element label");
+ }
+ return success();
+}
+
+/// Validate the `return` of a Constraint/Rewrite body.
+/// - `bodyIt` points at the first ReturnStmt in `body`, or equals `bodyE` if
+///   there is none.
+/// - It is an error for statements to follow the return.
+/// - It is an error for the return to be missing while `results` is
+///   non-empty.
+/// `declType` ("Constraint" or "Rewrite") and `resultType` are only used to
+/// build diagnostics. The returned expression is not checked against
+/// `resultType`.
+/// NOTE(review): `resultType` is taken by non-const reference but never
+/// written here.
+LogicalResult Parser::validateUserConstraintOrRewriteReturn(
+ StringRef declType, ast::CompoundStmt *body,
+ ArrayRef<ast::Stmt *>::iterator bodyIt,
+ ArrayRef<ast::Stmt *>::iterator bodyE,
+ ArrayRef<ast::VariableDecl *> results, ast::Type &resultType) {
+ // Handle if a `return` was provided.
+ if (bodyIt != bodyE) {
+ // Emit an error if we have trailing statements after the return.
+ if (std::next(bodyIt) != bodyE) {
+ return emitError(
+ (*std::next(bodyIt))->getLoc(),
+ llvm::formatv("`return` terminated the `{0}` body, but found "
+ "trailing statements afterwards",
+ declType));
+ }
+
+ // Otherwise if a return wasn't provided, check that no results are
+ // expected.
+ } else if (!results.empty()) {
+ // Report the error at the end of the body.
+ return emitError(
+ {body->getLoc().End, body->getLoc().End},
+ llvm::formatv("missing return in a `{0}` expected to return `{1}`",
+ declType, resultType));
+ }
+ return success();
+}
+
FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() {
return parseLambdaBody([&](ast::Stmt *&statement) -> LogicalResult {
if (isa<ast::OpRewriteStmt>(statement))
@@ -619,6 +1092,11 @@ FailureOr<ast::Decl *> Parser::parsePatternDecl() {
// Verify the body of the pattern.
auto bodyIt = body->begin(), bodyE = body->end();
for (; bodyIt != bodyE; ++bodyIt) {
+ if (isa<ast::ReturnStmt>(*bodyIt)) {
+ return emitError((*bodyIt)->getLoc(),
+ "`return` statements are only permitted within a "
+ "`Constraint` or `Rewrite` body");
+ }
// Break when we've found the rewrite statement.
if (isa<ast::OpRewriteStmt>(*bodyIt))
break;
@@ -719,8 +1197,8 @@ LogicalResult Parser::checkDefineNamedDecl(const ast::Name &name) {
}
FailureOr<ast::VariableDecl *>
-Parser::defineVariableDecl(StringRef name, SMRange nameLoc,
- ast::Type type, ast::Expr *initExpr,
+Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
+ ast::Expr *initExpr,
ArrayRef<ast::ConstraintRef> constraints) {
assert(curDeclScope && "defining variable outside of decl scope");
const ast::Name &nameDecl = ast::Name::create(ctx, name, nameLoc);
@@ -741,8 +1219,7 @@ Parser::defineVariableDecl(StringRef name, SMRange nameLoc,
}
FailureOr<ast::VariableDecl *>
-Parser::defineVariableDecl(StringRef name, SMRange nameLoc,
- ast::Type type,
+Parser::defineVariableDecl(StringRef name, SMRange nameLoc, ast::Type type,
ArrayRef<ast::ConstraintRef> constraints) {
return defineVariableDecl(name, nameLoc, type, /*initExpr=*/nullptr,
constraints);
@@ -752,8 +1229,8 @@ LogicalResult Parser::parseVariableDeclConstraintList(
SmallVectorImpl<ast::ConstraintRef> &constraints) {
Optional<SMRange> typeConstraint;
auto parseSingleConstraint = [&] {
- FailureOr<ast::ConstraintRef> constraint =
- parseConstraint(typeConstraint, constraints);
+ FailureOr<ast::ConstraintRef> constraint = parseConstraint(
+ typeConstraint, constraints, /*allowInlineTypeConstraints=*/true);
if (failed(constraint))
return failure();
constraints.push_back(*constraint);
@@ -773,8 +1250,15 @@ LogicalResult Parser::parseVariableDeclConstraintList(
FailureOr<ast::ConstraintRef>
Parser::parseConstraint(Optional<SMRange> &typeConstraint,
- ArrayRef<ast::ConstraintRef> existingConstraints) {
+ ArrayRef<ast::ConstraintRef> existingConstraints,
+ bool allowInlineTypeConstraints) {
auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult {
+ if (!allowInlineTypeConstraints) {
+ return emitError(
+ curToken.getLoc(),
+ "inline `Attr`, `Value`, and `ValueRange` type constraints are not "
+ "permitted on arguments or results");
+ }
if (typeConstraint)
return emitErrorAndNote(
curToken.getLoc(),
@@ -842,6 +1326,14 @@ Parser::parseConstraint(Optional<SMRange> &typeConstraint,
return ast::ConstraintRef(
ast::ValueRangeConstraintDecl::create(ctx, loc, typeExpr), loc);
}
+
+ case Token::kw_Constraint: {
+ // Handle an inline constraint.
+ FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
+ if (failed(decl))
+ return failure();
+ return ast::ConstraintRef(*decl, loc);
+ }
case Token::identifier: {
StringRef constraintName = curToken.getSpelling();
consumeToken(Token::identifier);
@@ -867,6 +1359,12 @@ Parser::parseConstraint(Optional<SMRange> &typeConstraint,
return emitError(loc, "expected identifier constraint");
}
+/// Parse the single constraint of a signature argument or result. Inline type
+/// constraints are disallowed, and there are no earlier constraints to check
+/// against. The `typeConstraint` local only exists to satisfy parseConstraint's
+/// reference parameter.
+FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() {
+ Optional<SMRange> typeConstraint;
+ return parseConstraint(typeConstraint, /*existingConstraints=*/llvm::None,
+ /*allowInlineTypeConstraints=*/false);
+}
+
//===----------------------------------------------------------------------===//
// Exprs
@@ -880,12 +1378,18 @@ FailureOr<ast::Expr *> Parser::parseExpr() {
case Token::kw_attr:
lhsExpr = parseAttributeExpr();
break;
+ case Token::kw_Constraint:
+ lhsExpr = parseInlineConstraintLambdaExpr();
+ break;
case Token::identifier:
lhsExpr = parseIdentifierExpr();
break;
case Token::kw_op:
lhsExpr = parseOperationExpr();
break;
+ case Token::kw_Rewrite:
+ lhsExpr = parseInlineRewriteLambdaExpr();
+ break;
case Token::kw_type:
lhsExpr = parseTypeExpr();
break;
@@ -904,6 +1408,9 @@ FailureOr<ast::Expr *> Parser::parseExpr() {
case Token::dot:
lhsExpr = parseMemberAccessExpr(*lhsExpr);
break;
+ case Token::l_paren:
+ lhsExpr = parseCallExpr(*lhsExpr);
+ break;
default:
return lhsExpr;
}
@@ -934,8 +1441,28 @@ FailureOr<ast::Expr *> Parser::parseAttributeExpr() {
return ast::AttributeExpr::create(ctx, loc, attrExpr);
}
-FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name,
- SMRange loc) {
+/// Parse a call argument list `(<expr>, ...)` applied to `parentExpr`. The
+/// current token must be `(`. The resulting location spans from the `(` through
+/// the `)` and does not include the callee. Building (and any checking of) the
+/// call is delegated to createCallExpr.
+FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr) {
+ SMRange loc = curToken.getLoc();
+ consumeToken(Token::l_paren);
+
+ // Parse the arguments of the call.
+ SmallVector<ast::Expr *> arguments;
+ if (curToken.isNot(Token::r_paren)) {
+ do {
+ FailureOr<ast::Expr *> argument = parseExpr();
+ if (failed(argument))
+ return failure();
+ arguments.push_back(*argument);
+ } while (consumeIf(Token::comma));
+ }
+ // Extend the location to the closing paren before it is consumed.
+ loc.End = curToken.getEndLoc();
+ if (failed(parseToken(Token::r_paren, "expected `)` after argument list")))
+ return failure();
+
+ return createCallExpr(loc, parentExpr, arguments);
+}
+
+FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) {
ast::Decl *decl = curDeclScope->lookup(name);
if (!decl)
return emitError(loc, "undefined reference to `" + name + "`");
@@ -963,6 +1490,24 @@ FailureOr<ast::Expr *> Parser::parseIdentifierExpr() {
return parseDeclRefExpr(name, nameLoc);
}
+/// Parse an inline `Constraint` in expression position. The decl is registered
+/// in the current scope and returned as a DeclRefExpr of ConstraintType,
+/// located at the decl.
+FailureOr<ast::Expr *> Parser::parseInlineConstraintLambdaExpr() {
+ FailureOr<ast::UserConstraintDecl *> decl = parseInlineUserConstraintDecl();
+ if (failed(decl))
+ return failure();
+
+ return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl,
+ ast::ConstraintType::get(ctx));
+}
+
+/// Parse an inline `Rewrite` in expression position. The decl is registered in
+/// the current scope and returned as a DeclRefExpr of RewriteType, located at
+/// the decl.
+FailureOr<ast::Expr *> Parser::parseInlineRewriteLambdaExpr() {
+ FailureOr<ast::UserRewriteDecl *> decl = parseInlineUserRewriteDecl();
+ if (failed(decl))
+ return failure();
+
+ return ast::DeclRefExpr::create(ctx, (*decl)->getLoc(), *decl,
+ ast::RewriteType::get(ctx));
+}
+
FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) {
SMRange loc = curToken.getLoc();
consumeToken(Token::dot);
@@ -1202,6 +1747,9 @@ FailureOr<ast::Stmt *> Parser::parseStmt(bool expectTerminalSemicolon) {
case Token::kw_replace:
stmt = parseReplaceStmt();
break;
+ case Token::kw_return:
+ stmt = parseReturnStmt();
+ break;
case Token::kw_rewrite:
stmt = parseRewriteStmt();
break;
@@ -1239,6 +1787,8 @@ FailureOr<ast::CompoundStmt *> Parser::parseCompoundStmt() {
}
FailureOr<ast::EraseStmt *> Parser::parseEraseStmt() {
+ if (parserContext == ParserContext::Constraint)
+ return emitError("`erase` cannot be used within a Constraint");
SMRange loc = curToken.getLoc();
consumeToken(Token::kw_erase);
@@ -1311,6 +1861,8 @@ FailureOr<ast::LetStmt *> Parser::parseLetStmt() {
}
FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() {
+ if (parserContext == ParserContext::Constraint)
+ return emitError("`replace` cannot be used within a Constraint");
SMRange loc = curToken.getLoc();
consumeToken(Token::kw_replace);
@@ -1356,7 +1908,21 @@ FailureOr<ast::ReplaceStmt *> Parser::parseReplaceStmt() {
return createReplaceStmt(loc, *rootOp, replValues);
}
+/// Parse `return <expr>`. The current token must be `return`. The parser
+/// context is not checked here: `return` is rejected afterwards by the Pattern
+/// body and `rewrite` body verification.
+FailureOr<ast::ReturnStmt *> Parser::parseReturnStmt() {
+ SMRange loc = curToken.getLoc();
+ consumeToken(Token::kw_return);
+
+ // Parse the result value.
+ FailureOr<ast::Expr *> resultExpr = parseExpr();
+ if (failed(resultExpr))
+ return failure();
+
+ return ast::ReturnStmt::create(ctx, loc, *resultExpr);
+}
+
FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() {
+ if (parserContext == ParserContext::Constraint)
+ return emitError("`rewrite` cannot be used within a Constraint");
SMRange loc = curToken.getLoc();
consumeToken(Token::kw_rewrite);
@@ -1379,6 +1945,15 @@ FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() {
if (failed(rewriteBody))
return failure();
+ // Verify the rewrite body.
+ for (const ast::Stmt *stmt : (*rewriteBody)->getChildren()) {
+ if (isa<ast::ReturnStmt>(stmt)) {
+ return emitError(stmt->getLoc(),
+ "`return` statements are only permitted within a "
+ "`Constraint` or `Rewrite` body");
+ }
+ }
+
return createRewriteStmt(loc, *rootOp, *rewriteBody);
}
@@ -1389,6 +1964,13 @@ FailureOr<ast::RewriteStmt *> Parser::parseRewriteStmt() {
//===----------------------------------------------------------------------===//
// Decls
+ast::CallableDecl *Parser::tryExtractCallableDecl(ast::Node *node) {
+  // Look through a `DeclRefExpr` to the decl it names; null if not callable.
+  auto *ref = dyn_cast<ast::DeclRefExpr>(node);
+  ast::Node *target = ref ? ref->getDecl() : node;
+  return dyn_cast<ast::CallableDecl>(target);
+}
+
FailureOr<ast::PatternDecl *>
Parser::createPatternDecl(SMRange loc, const ast::Name *name,
const ParsedPatternMetadata &metadata,
@@ -1397,9 +1979,47 @@ Parser::createPatternDecl(SMRange loc, const ast::Name *name,
metadata.hasBoundedRecursion, body);
}
+ast::Type Parser::createUserConstraintRewriteResultType(
+    ArrayRef<ast::VariableDecl *> results) {
+  // Exactly one result: the callable's type is just that result's type.
+  if (results.size() == 1)
+    return results.front()->getType();
+
+  // Zero or several results: a tuple mirroring each result's type and label.
+  llvm::SmallVector<ast::Type> elementTypes;
+  llvm::SmallVector<StringRef> elementNames;
+  for (const ast::VariableDecl *result : results) {
+    elementTypes.push_back(result->getType());
+    elementNames.push_back(result->getName().getName());
+  }
+  return ast::TupleType::get(ctx, elementTypes, elementNames);
+}
+
+template <typename T>
+FailureOr<T *> Parser::createUserPDLLConstraintOrRewriteDecl(
+    const ast::Name &name, ArrayRef<ast::VariableDecl *> arguments,
+    ArrayRef<ast::VariableDecl *> results, ast::Type resultType,
+    ast::CompoundStmt *body) {
+  // Only a body whose final statement is a `return` affects the result type.
+  auto *retStmt = body->getChildren().empty()
+                      ? nullptr
+                      : dyn_cast<ast::ReturnStmt>(body->getChildren().back());
+  if (retStmt && results.empty()) {
+    // No results in the signature: infer the type from the returned value.
+    resultType = retStmt->getResultExpr()->getType();
+  } else if (retStmt) {
+    // Declared results: the returned value must convert to the declared type,
+    // and the expression produced by the conversion is stored back.
+    ast::Expr *resultExpr = retStmt->getResultExpr();
+    if (failed(convertExpressionTo(resultExpr, resultType)))
+      return failure();
+    retStmt->setResultExpr(resultExpr);
+  }
+  return T::createPDLL(ctx, name, arguments, results, body, resultType);
+}
+
FailureOr<ast::VariableDecl *>
-Parser::createVariableDecl(StringRef name, SMRange loc,
- ast::Expr *initializer,
+Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
ArrayRef<ast::ConstraintRef> constraints) {
// The type of the variable, which is expected to be inferred by either a
// constraint or an initializer expression.
@@ -1426,6 +2046,12 @@ Parser::createVariableDecl(StringRef name, SMRange loc,
"list or the initializer");
}
+ // Constraint types cannot be used when defining variables.
+ if (type.isa<ast::ConstraintType, ast::RewriteType>()) {
+ return emitError(
+ loc, llvm::formatv("unable to define variable of `{0}` type", type));
+ }
+
// Try to define a variable with the given name.
FailureOr<ast::VariableDecl *> varDecl =
defineVariableDecl(name, loc, type, initializer, constraints);
@@ -1435,6 +2061,18 @@ Parser::createVariableDecl(StringRef name, SMRange loc,
return *varDecl;
}
+FailureOr<ast::VariableDecl *>
+Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc,
+                                      const ast::ConstraintRef &constraint) {
+  // Only `Constraint` signatures may use non-core (user) constraints; the
+  // variable's type is inferred from the constraint that is applied.
+  const bool inConstraint = parserContext == ParserContext::Constraint;
+  ast::Type type;
+  if (succeeded(validateVariableConstraint(constraint, type, inConstraint)))
+    return defineVariableDecl(name, loc, type, constraint);
+  return failure();
+}
+
LogicalResult
Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
ast::Type &inferredType) {
@@ -1445,7 +2083,8 @@ Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
}
LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref,
- ast::Type &inferredType) {
+ ast::Type &inferredType,
+ bool allowNonCoreConstraints) {
ast::Type constraintType;
if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.constraint)) {
if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
@@ -1474,6 +2113,25 @@ LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref,
return failure();
}
constraintType = valueRangeTy;
+ } else if (const auto *cst =
+ dyn_cast<ast::UserConstraintDecl>(ref.constraint)) {
+ if (!allowNonCoreConstraints) {
+ return emitError(ref.referenceLoc,
+ "`Rewrite` arguments and results are only permitted to "
+ "use core constraints, such as `Attr`, `Op`, `Type`, "
+ "`TypeRange`, `Value`, `ValueRange`");
+ }
+
+ ArrayRef<ast::VariableDecl *> inputs = cst->getInputs();
+ if (inputs.size() != 1) {
+ return emitErrorAndNote(ref.referenceLoc,
+ "`Constraint`s applied via a variable constraint "
+ "list must take a single input, but got " +
+ Twine(inputs.size()),
+ cst->getLoc(),
+ "see definition of constraint here");
+ }
+ constraintType = inputs.front()->getType();
} else {
llvm_unreachable("unknown constraint type");
}
@@ -1515,11 +2173,66 @@ Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) {
//===----------------------------------------------------------------------===//
// Exprs
+// Build a call to the `Constraint` or `Rewrite` referenced by `parentExpr`,
+// checking that the callee may be invoked in the current parser context and
+// that `arguments` match its inputs in count and (after conversion) type.
+FailureOr<ast::CallExpr *>
+Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr,
+                       MutableArrayRef<ast::Expr *> arguments) {
+  ast::Type parentType = parentExpr->getType();
+
+  ast::CallableDecl *callableDecl = tryExtractCallableDecl(parentExpr);
+  if (!callableDecl) {
+    return emitError(loc,
+                     llvm::formatv("expected a reference to a callable "
+                                   "`Constraint` or `Rewrite`, but got: `{0}`",
+                                   parentType));
+  }
+  // Constraints cannot be invoked from a rewrite section, and rewrites cannot
+  // be invoked outside of one.
+  if (parserContext == ParserContext::Rewrite) {
+    if (isa<ast::UserConstraintDecl>(callableDecl))
+      return emitError(
+          loc, "unable to invoke `Constraint` within a rewrite section");
+  } else if (isa<ast::UserRewriteDecl>(callableDecl)) {
+    return emitError(loc, "unable to invoke `Rewrite` within a match section");
+  }
+
+  // Both argument diagnostics below point back at the callee's definition;
+  // they share one note format so the callee name is quoted consistently.
+  StringRef calleeName = callableDecl->getName()->getName();
+
+  // Verify the number of arguments.
+  ArrayRef<ast::VariableDecl *> callArgs = callableDecl->getInputs();
+  if (callArgs.size() != arguments.size()) {
+    return emitErrorAndNote(
+        loc,
+        llvm::formatv("invalid number of arguments for {0} call; expected "
+                      "{1}, but got {2}",
+                      callableDecl->getCallableType(), callArgs.size(),
+                      arguments.size()),
+        callableDecl->getLoc(),
+        llvm::formatv("see the definition of `{0}` here", calleeName));
+  }
+
+  // Verify each argument converts to the type of the corresponding input.
+  auto attachDiagFn = [&](ast::Diagnostic &diag) {
+    diag.attachNote(
+        llvm::formatv("see the definition of `{0}` here", calleeName),
+        callableDecl->getLoc());
+  };
+  for (auto it : llvm::zip(callArgs, arguments)) {
+    if (failed(convertExpressionTo(std::get<1>(it), std::get<0>(it)->getType(),
+                                   attachDiagFn)))
+      return failure();
+  }
+
+  return ast::CallExpr::create(ctx, loc, parentExpr, arguments,
+                               callableDecl->getResultType());
+}
+
FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc,
ast::Decl *decl) {
// Check the type of decl being referenced.
ast::Type declType;
- if (auto *varDecl = dyn_cast<ast::VariableDecl>(decl))
+ if (isa<ast::ConstraintDecl>(decl))
+ declType = ast::ConstraintType::get(ctx);
+ else if (isa<ast::UserRewriteDecl>(decl))
+ declType = ast::RewriteType::get(ctx);
+ else if (auto *varDecl = dyn_cast<ast::VariableDecl>(decl))
declType = varDecl->getType();
else
return emitError(loc, "invalid reference to `" +
@@ -1529,8 +2242,7 @@ FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc,
}
FailureOr<ast::DeclRefExpr *>
-Parser::createInlineVariableExpr(ast::Type type, StringRef name,
- SMRange loc,
+Parser::createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
ArrayRef<ast::ConstraintRef> constraints) {
FailureOr<ast::VariableDecl *> decl =
defineVariableDecl(name, loc, type, constraints);
@@ -1551,8 +2263,7 @@ Parser::createMemberAccessExpr(ast::Expr *parentExpr, StringRef name,
}
FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
- StringRef name,
- SMRange loc) {
+ StringRef name, SMRange loc) {
ast::Type parentType = parentExpr->getType();
if (parentType.isa<ast::OperationType>()) {
if (name == ast::AllResultsMemberAccessExpr::getMemberName())
@@ -1622,9 +2333,8 @@ Parser::validateOperationResults(SMRange loc, Optional<StringRef> name,
}
LogicalResult Parser::validateOperationOperandsOrResults(
- SMRange loc, Optional<StringRef> name,
- MutableArrayRef<ast::Expr *> values, ast::Type singleTy,
- ast::Type rangeTy) {
+ SMRange loc, Optional<StringRef> name, MutableArrayRef<ast::Expr *> values,
+ ast::Type singleTy, ast::Type rangeTy) {
// All operation types accept a single range parameter.
if (values.size() == 1) {
if (failed(convertExpressionTo(values[0], rangeTy)))
@@ -1665,7 +2375,7 @@ Parser::createTupleExpr(SMRange loc, ArrayRef<ast::Expr *> elements,
ArrayRef<StringRef> elementNames) {
for (const ast::Expr *element : elements) {
ast::Type eleTy = element->getType();
- if (eleTy.isa<ast::ConstraintType, ast::TupleType>()) {
+ if (eleTy.isa<ast::ConstraintType, ast::RewriteType, ast::TupleType>()) {
return emitError(
element->getLoc(),
llvm::formatv("unable to build a tuple with `{0}` element", eleTy));
diff --git a/mlir/test/mlir-pdll/Parser/constraint-failure.pdll b/mlir/test/mlir-pdll/Parser/constraint-failure.pdll
new file mode 100644
index 0000000000000..8291913623943
--- /dev/null
+++ b/mlir/test/mlir-pdll/Parser/constraint-failure.pdll
@@ -0,0 +1,160 @@
+// RUN: not mlir-pdll %s -split-input-file 2>&1 | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// Constraint Structure
+//===----------------------------------------------------------------------===//
+
+// CHECK: expected identifier name
+Constraint {}
+
+// -----
+
+// CHECK: :6:12: error: `Foo` has already been defined
+// CHECK: :5:12: note: see previous definition here
+Constraint Foo() { op<>; }
+Constraint Foo() { op<>; }
+
+// -----
+
+Constraint Foo() {
+ // CHECK: `erase` cannot be used within a Constraint
+ erase op<>;
+}
+
+// -----
+
+Constraint Foo() {
+ // CHECK: `replace` cannot be used within a Constraint
+ replace;
+}
+
+// -----
+
+Constraint Foo() {
+ // CHECK: `rewrite` cannot be used within a Constraint
+ rewrite;
+}
+
+// -----
+
+Constraint Foo() -> Value {
+ // CHECK: `return` terminated the `Constraint` body, but found trailing statements afterwards
+ return _: Value;
+ return _: Value;
+}
+
+// -----
+
+// CHECK: missing return in a `Constraint` expected to return `Value`
+Constraint Foo() -> Value {
+ let value: Value;
+}
+
+// -----
+
+// CHECK: expected `Constraint` lambda body to contain a single expression
+Constraint Foo() -> Value => let foo: Value;
+
+// -----
+
+// CHECK: unable to convert expression of type `Op` to the expected type of `Attr`
+Constraint Foo() -> Attr => op<>;
+
+// -----
+
+Rewrite SomeRewrite();
+
+// CHECK: unable to invoke `Rewrite` within a match section
+Constraint Foo() {
+ SomeRewrite();
+}
+
+// -----
+
+Constraint Foo() {
+ Constraint Foo() {};
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Arguments
+//===----------------------------------------------------------------------===//
+
+// CHECK: expected `(` to start argument list
+Constraint Foo {}
+
+// -----
+
+// CHECK: expected identifier argument name
+Constraint Foo(10{}
+
+// -----
+
+// CHECK: expected `:` before argument constraint
+Constraint Foo(arg{}
+
+// -----
+
+// CHECK: inline `Attr`, `Value`, and `ValueRange` type constraints are not permitted on arguments or results
+Constraint Foo(arg: Value<type>){}
+
+// -----
+
+// CHECK: expected `)` to end argument list
+Constraint Foo(arg: Value{}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Results
+//===----------------------------------------------------------------------===//
+
+// CHECK: expected identifier constraint
+Constraint Foo() -> {}
+
+// -----
+
+// CHECK: cannot create a single-element tuple with an element label
+Constraint Foo() -> result: Value;
+
+// -----
+
+// CHECK: cannot create a single-element tuple with an element label
+Constraint Foo() -> (result: Value);
+
+// -----
+
+// CHECK: expected identifier constraint
+Constraint Foo() -> ();
+
+// -----
+
+// CHECK: expected `:` before result constraint
+Constraint Foo() -> (result{};
+
+// -----
+
+// CHECK: expected `)` to end result list
+Constraint Foo() -> (Op{};
+
+// -----
+
+// CHECK: inline `Attr`, `Value`, and `ValueRange` type constraints are not permitted on arguments or results
+Constraint Foo() -> Value<type>){}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Native Constraints
+//===----------------------------------------------------------------------===//
+
+Pattern {
+ // CHECK: external declarations must be declared in global scope
+ Constraint ExternalConstraint();
+}
+
+// -----
+
+// CHECK: expected `;` after native declaration
+Constraint Foo() [{}]
diff --git a/mlir/test/mlir-pdll/Parser/constraint.pdll b/mlir/test/mlir-pdll/Parser/constraint.pdll
new file mode 100644
index 0000000000000..1c0a015ab4a7b
--- /dev/null
+++ b/mlir/test/mlir-pdll/Parser/constraint.pdll
@@ -0,0 +1,74 @@
+// RUN: mlir-pdll %s -I %S -split-input-file | FileCheck %s
+
+// CHECK: Module
+// CHECK: `-UserConstraintDecl {{.*}} Name<Foo> ResultType<Tuple<>>
+Constraint Foo();
+
+// -----
+
+// CHECK: Module
+// CHECK: `-UserConstraintDecl {{.*}} Name<Foo> ResultType<Tuple<>> Code< /* Native Code */ >
+Constraint Foo() [{ /* Native Code */ }];
+
+// -----
+
+// CHECK: Module
+// CHECK: `-UserConstraintDecl {{.*}} Name<Foo> ResultType<Value>
+// CHECK: `Inputs`
+// CHECK: `-VariableDecl {{.*}} Name<arg> Type<Value>
+// CHECK: `Results`
+// CHECK: `-VariableDecl {{.*}} Name<> Type<Value>
+// CHECK: `-CompoundStmt {{.*}}
+// CHECK: `-ReturnStmt {{.*}}
+// CHECK: `-DeclRefExpr {{.*}} Type<Value>
+// CHECK: `-VariableDecl {{.*}} Name<arg> Type<Value>
+Constraint Foo(arg: Value) -> Value => arg;
+
+// -----
+
+// CHECK: Module
+// CHECK: `-UserConstraintDecl {{.*}} Name<Foo> ResultType<Tuple<result1: Value, result2: Attr>>
+// CHECK: `Results`
+// CHECK: |-VariableDecl {{.*}} Name<result1> Type<Value>
+// CHECK: | `Constraints`
+// CHECK: | `-ValueConstraintDecl {{.*}}
+// CHECK: `-VariableDecl {{.*}} Name<result2> Type<Attr>
+// CHECK: `Constraints`
+// CHECK: `-AttrConstraintDecl {{.*}}
+// CHECK: `-CompoundStmt {{.*}}
+// CHECK: `-ReturnStmt {{.*}}
+// CHECK: `-TupleExpr {{.*}} Type<Tuple<result1: Value, result2: Attr>>
+// CHECK: |-MemberAccessExpr {{.*}} Member<0> Type<Value>
+// CHECK: | `-TupleExpr {{.*}} Type<Tuple<Value, Attr>>
+// CHECK: `-MemberAccessExpr {{.*}} Member<1> Type<Attr>
+// CHECK: `-TupleExpr {{.*}} Type<Tuple<Value, Attr>>
+Constraint Foo() -> (result1: Value, result2: Attr) => (_: Value, attr<"10">);
+
+// -----
+
+// CHECK: Module
+// CHECK: |-UserConstraintDecl {{.*}} Name<Bar> ResultType<Tuple<>>
+// CHECK: `-UserConstraintDecl {{.*}} Name<Foo> ResultType<Value>
+// CHECK: `Inputs`
+// CHECK: `-VariableDecl {{.*}} Name<arg> Type<Value>
+// CHECK: `Constraints`
+// CHECK: `-UserConstraintDecl {{.*}} Name<Bar> ResultType<Tuple<>>
+// CHECK: `Results`
+// CHECK: `-VariableDecl {{.*}} Name<> Type<Value>
+// CHECK: `Constraints`
+// CHECK: `-UserConstraintDecl {{.*}} Name<Bar> ResultType<Tuple<>>
+Constraint Bar(input: Value);
+
+Constraint Foo(arg: Bar) -> Bar => arg;
+
+// -----
+
+// Test that anonymous constraints are uniquely named.
+
+// CHECK: Module
+// CHECK: UserConstraintDecl {{.*}} Name<<anonymous_constraint_0>> ResultType<Tuple<>>
+// CHECK: UserConstraintDecl {{.*}} Name<<anonymous_constraint_1>> ResultType<Attr>
+Constraint Outer() {
+ Constraint() {};
+ Constraint() => attr<"10">;
+}
diff --git a/mlir/test/mlir-pdll/Parser/expr-failure.pdll b/mlir/test/mlir-pdll/Parser/expr-failure.pdll
index 1bf73a4ad60fa..7ed3ba8057bd5 100644
--- a/mlir/test/mlir-pdll/Parser/expr-failure.pdll
+++ b/mlir/test/mlir-pdll/Parser/expr-failure.pdll
@@ -43,6 +43,45 @@ Pattern {
// -----
+//===----------------------------------------------------------------------===//
+// Call Expr
+//===----------------------------------------------------------------------===//
+
+Constraint foo(value: Value);
+
+Pattern {
+ // CHECK: expected `)` after argument list
+ foo(_: Value{};
+}
+
+// -----
+
+Pattern {
+ // CHECK: expected a reference to a callable `Constraint` or `Rewrite`, but got: `Op`
+ let foo: Op;
+ foo();
+}
+
+// -----
+
+Constraint Foo();
+
+Pattern {
+ // CHECK: invalid number of arguments for constraint call; expected 0, but got 1
+ Foo(_: Value);
+}
+
+// -----
+
+Constraint Foo(arg: Value);
+
+Pattern {
+ // CHECK: unable to convert expression of type `Attr` to the expected type of `Value`
+ Foo(attr<"i32">);
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// Member Access Expr
//===----------------------------------------------------------------------===//
@@ -105,6 +144,26 @@ Pattern {
// -----
+Constraint Foo();
+
+Pattern {
+ // CHECK: unable to build a tuple with `Constraint` element
+ let tuple = (Foo);
+ erase op<>;
+}
+
+// -----
+
+Rewrite Foo();
+
+Pattern {
+ // CHECK: unable to build a tuple with `Rewrite` element
+ let tuple = (Foo);
+ erase op<>;
+}
+
+// -----
+
Pattern {
// CHECK: expected expression
let tuple = (10 = _: Value);
diff --git a/mlir/test/mlir-pdll/Parser/expr.pdll b/mlir/test/mlir-pdll/Parser/expr.pdll
index f457d653dbd7c..d645feaf118fa 100644
--- a/mlir/test/mlir-pdll/Parser/expr.pdll
+++ b/mlir/test/mlir-pdll/Parser/expr.pdll
@@ -14,6 +14,42 @@ Pattern {
// -----
+//===----------------------------------------------------------------------===//
+// CallExpr
+//===----------------------------------------------------------------------===//
+
+// CHECK: Module
+// CHECK: |-UserConstraintDecl {{.*}} Name<MakeRootOp> ResultType<Op<my_dialect.foo>>
+// CHECK: `-CallExpr {{.*}} Type<Op<my_dialect.foo>>
+// CHECK: `-DeclRefExpr {{.*}} Type<Constraint>
+// CHECK: `-UserConstraintDecl {{.*}} Name<MakeRootOp> ResultType<Op<my_dialect.foo>>
+Constraint MakeRootOp() => op<my_dialect.foo>;
+
+Pattern {
+ erase MakeRootOp();
+}
+
+// -----
+
+// CHECK: Module
+// CHECK: |-UserRewriteDecl {{.*}} Name<CreateNewOp> ResultType<Op<my_dialect.foo>>
+// CHECK: `-PatternDecl {{.*}}
+// CHECK: `-CallExpr {{.*}} Type<Op<my_dialect.foo>>
+// CHECK: `-DeclRefExpr {{.*}} Type<Rewrite>
+// CHECK: `-UserRewriteDecl {{.*}} Name<CreateNewOp> ResultType<Op<my_dialect.foo>>
+// CHECK: `Arguments`
+// CHECK: `-MemberAccessExpr {{.*}} Member<$results> Type<ValueRange>
+// CHECK: `-DeclRefExpr {{.*}} Type<Op<my_dialect.bar>>
+// CHECK: `-VariableDecl {{.*}} Name<inputOp> Type<Op<my_dialect.bar>>
+Rewrite CreateNewOp(inputs: ValueRange) => op<my_dialect.foo>(inputs);
+
+Pattern {
+ let inputOp = op<my_dialect.bar>;
+ replace op<my_dialect.bar>(inputOp) with CreateNewOp(inputOp);
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// MemberAccessExpr
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/mlir-pdll/Parser/pattern-failure.pdll b/mlir/test/mlir-pdll/Parser/pattern-failure.pdll
index 42ea657821120..8b104128f5204 100644
--- a/mlir/test/mlir-pdll/Parser/pattern-failure.pdll
+++ b/mlir/test/mlir-pdll/Parser/pattern-failure.pdll
@@ -1,4 +1,4 @@
-// RUN: not mlir-pdll %s -split-input-file 2>&1 | FileCheck %s
+// RUN: not mlir-pdll %s -split-input-file 2>&1 | FileCheck %s
// CHECK: expected `{` or `=>` to start pattern body
Pattern }
@@ -12,6 +12,13 @@ Pattern Foo { erase root: Op; }
// -----
+// CHECK: `return` statements are only permitted within a `Constraint` or `Rewrite` body
+Pattern {
+ return _: Value;
+}
+
+// -----
+
// CHECK: expected Pattern body to terminate with an operation rewrite statement
Pattern {
let value: Value;
@@ -32,6 +39,15 @@ Pattern => op<>;
// -----
+Rewrite SomeRewrite();
+
+// CHECK: unable to invoke `Rewrite` within a match section
+Pattern {
+ SomeRewrite();
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// Metadata
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/mlir-pdll/Parser/rewrite-failure.pdll b/mlir/test/mlir-pdll/Parser/rewrite-failure.pdll
new file mode 100644
index 0000000000000..1cdb32b5d6b0e
--- /dev/null
+++ b/mlir/test/mlir-pdll/Parser/rewrite-failure.pdll
@@ -0,0 +1,161 @@
+// RUN: not mlir-pdll %s -split-input-file 2>&1 | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// Rewrite Structure
+//===----------------------------------------------------------------------===//
+
+// CHECK: expected identifier name
+Rewrite {}
+
+// -----
+
+// CHECK: :6:9: error: `Foo` has already been defined
+// CHECK: :5:9: note: see previous definition here
+Rewrite Foo();
+Rewrite Foo();
+
+// -----
+
+Rewrite Foo() -> Value {
+ // CHECK: `return` terminated the `Rewrite` body, but found trailing statements afterwards
+ return _: Value;
+ return _: Value;
+}
+
+// -----
+
+// CHECK: missing return in a `Rewrite` expected to return `Value`
+Rewrite Foo() -> Value {
+ let value: Value;
+}
+
+// -----
+
+// CHECK: missing return in a `Rewrite` expected to return `Value`
+Rewrite Foo() -> Value => erase op<my_dialect.foo>;
+
+// -----
+
+// CHECK: unable to convert expression of type `Op<my_dialect.foo>` to the expected type of `Attr`
+Rewrite Foo() -> Attr => op<my_dialect.foo>;
+
+// -----
+
+// CHECK: expected `Rewrite` lambda body to contain a single expression or an operation rewrite statement; such as `erase`, `replace`, or `rewrite`
+Rewrite Foo() => let foo = op<my_dialect.foo>;
+
+// -----
+
+Constraint ValueConstraint(value: Value);
+
+// CHECK: unable to invoke `Constraint` within a rewrite section
+Rewrite Foo(value: Value) {
+ ValueConstraint(value);
+}
+
+// -----
+
+Rewrite Bar();
+
+// CHECK: `Bar` has already been defined
+Rewrite Foo() {
+ Rewrite Bar() {};
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Arguments
+//===----------------------------------------------------------------------===//
+
+// CHECK: expected `(` to start argument list
+Rewrite Foo {}
+
+// -----
+
+// CHECK: expected identifier argument name
+Rewrite Foo(10{}
+
+// -----
+
+// CHECK: expected `:` before argument constraint
+Rewrite Foo(arg{}
+
+// -----
+
+// CHECK: inline `Attr`, `Value`, and `ValueRange` type constraints are not permitted on arguments or results
+Rewrite Foo(arg: Value<type>){}
+
+// -----
+
+Constraint ValueConstraint(value: Value);
+
+// CHECK: arguments and results are only permitted to use core constraints, such as `Attr`, `Op`, `Type`, `TypeRange`, `Value`, `ValueRange`
+Rewrite Foo(arg: ValueConstraint);
+
+// -----
+
+// CHECK: expected `)` to end argument list
+Rewrite Foo(arg: Value{}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Results
+//===----------------------------------------------------------------------===//
+
+// CHECK: expected identifier constraint
+Rewrite Foo() -> {}
+
+// -----
+
+// CHECK: cannot create a single-element tuple with an element label
+Rewrite Foo() -> result: Value;
+
+// -----
+
+// CHECK: cannot create a single-element tuple with an element label
+Rewrite Foo() -> (result: Value);
+
+// -----
+
+// CHECK: expected identifier constraint
+Rewrite Foo() -> ();
+
+// -----
+
+// CHECK: expected `:` before result constraint
+Rewrite Foo() -> (result{};
+
+// -----
+
+// CHECK: expected `)` to end result list
+Rewrite Foo() -> (Op{};
+
+// -----
+
+// CHECK: inline `Attr`, `Value`, and `ValueRange` type constraints are not permitted on arguments or results
+Rewrite Foo() -> Value<type>){}
+
+// -----
+
+Constraint ValueConstraint(value: Value);
+
+// CHECK: results are only permitted to use core constraints, such as `Attr`, `Op`, `Type`, `TypeRange`, `Value`, `ValueRange`
+Rewrite Foo() -> ValueConstraint;
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Native Rewrites
+//===----------------------------------------------------------------------===//
+
+Pattern {
+ // CHECK: external declarations must be declared in global scope
+ Rewrite ExternalConstraint();
+}
+
+// -----
+
+// CHECK: expected `;` after native declaration
+Rewrite Foo() [{}]
diff --git a/mlir/test/mlir-pdll/Parser/rewrite.pdll b/mlir/test/mlir-pdll/Parser/rewrite.pdll
new file mode 100644
index 0000000000000..7ef7478319d4b
--- /dev/null
+++ b/mlir/test/mlir-pdll/Parser/rewrite.pdll
@@ -0,0 +1,58 @@
+// RUN: mlir-pdll %s -I %S -split-input-file | FileCheck %s
+
+// CHECK: Module
+// CHECK: `-UserRewriteDecl {{.*}} Name<Foo> ResultType<Tuple<>>
+Rewrite Foo();
+
+// -----
+
+// CHECK: Module
+// CHECK: `-UserRewriteDecl {{.*}} Name<Foo> ResultType<Tuple<>> Code< /* Native Code */ >
+Rewrite Foo() [{ /* Native Code */ }];
+
+// -----
+
+// CHECK: Module
+// CHECK: `-UserRewriteDecl {{.*}} Name<Foo> ResultType<Value>
+// CHECK: `Inputs`
+// CHECK: `-VariableDecl {{.*}} Name<arg> Type<Op>
+// CHECK: `Results`
+// CHECK: `-VariableDecl {{.*}} Name<> Type<Value>
+// CHECK: `-CompoundStmt {{.*}}
+// CHECK: `-ReturnStmt {{.*}}
+// CHECK: `-MemberAccessExpr {{.*}} Member<$results> Type<Value>
+// CHECK: `-DeclRefExpr {{.*}} Type<Op>
+// CHECK: `-VariableDecl {{.*}} Name<arg> Type<Op>
+Rewrite Foo(arg: Op) -> Value => arg;
+
+// -----
+
+// CHECK: Module
+// CHECK: `-UserRewriteDecl {{.*}} Name<Foo> ResultType<Tuple<result1: Value, result2: Attr>>
+// CHECK: `Results`
+// CHECK: |-VariableDecl {{.*}} Name<result1> Type<Value>
+// CHECK: | `Constraints`
+// CHECK: | `-ValueConstraintDecl {{.*}}
+// CHECK: `-VariableDecl {{.*}} Name<result2> Type<Attr>
+// CHECK: `Constraints`
+// CHECK: `-AttrConstraintDecl {{.*}}
+// CHECK: `-CompoundStmt {{.*}}
+// CHECK: `-ReturnStmt {{.*}}
+// CHECK: `-TupleExpr {{.*}} Type<Tuple<result1: Value, result2: Attr>>
+// CHECK: |-MemberAccessExpr {{.*}} Member<0> Type<Value>
+// CHECK: | `-TupleExpr {{.*}} Type<Tuple<Value, Attr>>
+// CHECK: `-MemberAccessExpr {{.*}} Member<1> Type<Attr>
+// CHECK: `-TupleExpr {{.*}} Type<Tuple<Value, Attr>>
+Rewrite Foo() -> (result1: Value, result2: Attr) => (_: Value, attr<"10">);
+
+// -----
+
+// Test that anonymous Rewrites are uniquely named.
+
+// CHECK: Module
+// CHECK: UserRewriteDecl {{.*}} Name<<anonymous_rewrite_0>> ResultType<Tuple<>>
+// CHECK: UserRewriteDecl {{.*}} Name<<anonymous_rewrite_1>> ResultType<Attr>
+Rewrite Outer() {
+ Rewrite() {};
+ Rewrite() => attr<"10">;
+}
diff --git a/mlir/test/mlir-pdll/Parser/stmt-failure.pdll b/mlir/test/mlir-pdll/Parser/stmt-failure.pdll
index 802353615d29c..d52a2d28b6d05 100644
--- a/mlir/test/mlir-pdll/Parser/stmt-failure.pdll
+++ b/mlir/test/mlir-pdll/Parser/stmt-failure.pdll
@@ -223,6 +223,33 @@ Pattern {
// -----
+Constraint Foo();
+
+Pattern {
+ // CHECK: unable to define variable of `Constraint` type
+ let foo = Foo;
+}
+
+// -----
+
+Rewrite Foo();
+
+Pattern {
+ // CHECK: unable to define variable of `Rewrite` type
+ let foo = Foo;
+}
+
+// -----
+
+Constraint MultiConstraint(arg1: Value, arg2: Value);
+
+Pattern {
+ // CHECK: `Constraint`s applied via a variable constraint list must take a single input, but got 2
+ let foo: MultiConstraint;
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// `replace`
//===----------------------------------------------------------------------===//
@@ -276,6 +303,17 @@ Pattern {
// -----
+//===----------------------------------------------------------------------===//
+// `return`
+//===----------------------------------------------------------------------===//
+
+// CHECK: expected `;` after statement
+Constraint Foo(arg: Value) -> Value {
+ return arg
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// `rewrite`
//===----------------------------------------------------------------------===//
@@ -307,3 +345,12 @@ Pattern {
op<>;
};
}
+
+// -----
+
+Pattern {
+ // CHECK: `return` statements are only permitted within a `Constraint` or `Rewrite` body
+ rewrite root: Op with {
+ return root;
+ };
+}
More information about the Mlir-commits
mailing list