[Mlir-commits] [mlir] 930916c - [MLIR][PDL] Add PDLL support for negated native constraints
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Sep 1 16:12:23 PDT 2023
Author: Mogball
Date: 2023-09-01T23:12:16Z
New Revision: 930916c7f3622870b40138dafcc5f94740404e8c
URL: https://github.com/llvm/llvm-project/commit/930916c7f3622870b40138dafcc5f94740404e8c
DIFF: https://github.com/llvm/llvm-project/commit/930916c7f3622870b40138dafcc5f94740404e8c.diff
LOG: [MLIR][PDL] Add PDLL support for negated native constraints
This commit enables expressing negated native constraints in PDLL:
if a constraint is prefixed with `not`, it is parsed as a negated constraint, and the `isNegated` attribute of the emitted `pdl.apply_native_constraint` operation is set to `true`.
For now, this is supported only for calls to external native C++ constraints and for the generation of PDL patterns.
Previously, negating a native constraint would have required declaring an additional native constraint, e.g.
```PDLL
Constraint checkA(input: Attr);
Constraint checkNotA(input: Attr);
```
or by including an explicit additional operand for negation, e.g.
`Constraint checkA(input: Attr, negated: Attr);`
With this change, a constraint can simply be negated by prefixing it with `not`, e.g.:
```PDLL
Constraint simpleConstraint(op: Op);
Pattern example {
let inputOp = op<test.bar>() ->(type: Type);
let root = op<test.foo>(inputOp.0) -> ();
not simpleConstraint(inputOp);
simpleConstraint(root);
erase root;
}
```
Depends on D153871 (https://reviews.llvm.org/D153871).
Reviewed By: Mogball
Differential Revision: https://reviews.llvm.org/D153959
Added:
Modified:
mlir/include/mlir/Tools/PDLL/AST/Nodes.h
mlir/lib/Tools/PDLL/AST/NodePrinter.cpp
mlir/lib/Tools/PDLL/AST/Nodes.cpp
mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp
mlir/lib/Tools/PDLL/Parser/Lexer.cpp
mlir/lib/Tools/PDLL/Parser/Lexer.h
mlir/lib/Tools/PDLL/Parser/Parser.cpp
mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll
mlir/test/mlir-pdll/Parser/expr-failure.pdll
mlir/test/mlir-pdll/Parser/expr.pdll
Removed:
################################################################################
diff --git a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
index 3b29d48fb07bdf7..5515ee7548b5ab5 100644
--- a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
+++ b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
@@ -390,7 +390,8 @@ class CallExpr final : public Node::NodeBase<CallExpr, Expr>,
private llvm::TrailingObjects<CallExpr, Expr *> {
public:
static CallExpr *create(Context &ctx, SMRange loc, Expr *callable,
- ArrayRef<Expr *> arguments, Type resultType);
+ ArrayRef<Expr *> arguments, Type resultType,
+ bool isNegated = false);
/// Return the callable of this call.
Expr *getCallableExpr() const { return callable; }
@@ -403,9 +404,14 @@ class CallExpr final : public Node::NodeBase<CallExpr, Expr>,
return const_cast<CallExpr *>(this)->getArguments();
}
+ /// Returns whether the result of this call is to be negated.
+ bool getIsNegated() const { return isNegated; }
+
private:
- CallExpr(SMRange loc, Type type, Expr *callable, unsigned numArgs)
- : Base(loc, type), callable(callable), numArgs(numArgs) {}
+ CallExpr(SMRange loc, Type type, Expr *callable, unsigned numArgs,
+ bool isNegated)
+ : Base(loc, type), callable(callable), numArgs(numArgs),
+ isNegated(isNegated) {}
/// The callable of this call.
Expr *callable;
@@ -415,6 +421,9 @@ class CallExpr final : public Node::NodeBase<CallExpr, Expr>,
/// TrailingObject utilities.
friend llvm::TrailingObjects<CallExpr, Expr *>;
+
+ // Is the result of this call to be negated.
+ bool isNegated;
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp b/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp
index 4155fd69c589754..04f02f20dee9f46 100644
--- a/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp
+++ b/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp
@@ -225,7 +225,10 @@ void NodePrinter::printImpl(const AttributeExpr *expr) {
void NodePrinter::printImpl(const CallExpr *expr) {
os << "CallExpr " << expr << " Type<";
print(expr->getType());
- os << ">\n";
+ os << ">";
+ if (expr->getIsNegated())
+ os << " Negated";
+ os << "\n";
printChildren(expr->getCallableExpr());
printChildren("Arguments", expr->getArguments());
}
diff --git a/mlir/lib/Tools/PDLL/AST/Nodes.cpp b/mlir/lib/Tools/PDLL/AST/Nodes.cpp
index 47556295b7cbed6..654ff24454cb1fb 100644
--- a/mlir/lib/Tools/PDLL/AST/Nodes.cpp
+++ b/mlir/lib/Tools/PDLL/AST/Nodes.cpp
@@ -266,12 +266,13 @@ AttributeExpr *AttributeExpr::create(Context &ctx, SMRange loc,
//===----------------------------------------------------------------------===//
CallExpr *CallExpr::create(Context &ctx, SMRange loc, Expr *callable,
- ArrayRef<Expr *> arguments, Type resultType) {
+ ArrayRef<Expr *> arguments, Type resultType,
+ bool isNegated) {
unsigned allocSize = CallExpr::totalSizeToAlloc<Expr *>(arguments.size());
void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(CallExpr));
- CallExpr *expr =
- new (rawData) CallExpr(loc, resultType, callable, arguments.size());
+ CallExpr *expr = new (rawData)
+ CallExpr(loc, resultType, callable, arguments.size(), isNegated);
std::uninitialized_copy(arguments.begin(), arguments.end(),
expr->getArguments().begin());
return expr;
diff --git a/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp b/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp
index 00b68162a1ed281..16c3ccf0de2690b 100644
--- a/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp
+++ b/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp
@@ -103,12 +103,14 @@ class CodeGen {
Value genExprImpl(const ast::TypeExpr *expr);
SmallVector<Value> genConstraintCall(const ast::UserConstraintDecl *decl,
- Location loc, ValueRange inputs);
+ Location loc, ValueRange inputs,
+ bool isNegated = false);
SmallVector<Value> genRewriteCall(const ast::UserRewriteDecl *decl,
Location loc, ValueRange inputs);
template <typename PDLOpT, typename T>
SmallVector<Value> genConstraintOrRewriteCall(const T *decl, Location loc,
- ValueRange inputs);
+ ValueRange inputs,
+ bool isNegated = false);
//===--------------------------------------------------------------------===//
// Fields
@@ -419,7 +421,7 @@ SmallVector<Value> CodeGen::genExprImpl(const ast::CallExpr *expr) {
// Generate the PDL based on the type of callable.
const ast::Decl *callable = callableExpr->getDecl();
if (const auto *decl = dyn_cast<ast::UserConstraintDecl>(callable))
- return genConstraintCall(decl, loc, arguments);
+ return genConstraintCall(decl, loc, arguments, expr->getIsNegated());
if (const auto *decl = dyn_cast<ast::UserRewriteDecl>(callable))
return genRewriteCall(decl, loc, arguments);
llvm_unreachable("unhandled CallExpr callable");
@@ -547,15 +549,15 @@ Value CodeGen::genExprImpl(const ast::TypeExpr *expr) {
SmallVector<Value>
CodeGen::genConstraintCall(const ast::UserConstraintDecl *decl, Location loc,
- ValueRange inputs) {
+ ValueRange inputs, bool isNegated) {
// Apply any constraints defined on the arguments to the input values.
for (auto it : llvm::zip(decl->getInputs(), inputs))
applyVarConstraints(std::get<0>(it), std::get<1>(it));
// Generate the constraint call.
SmallVector<Value> results =
- genConstraintOrRewriteCall<pdl::ApplyNativeConstraintOp>(decl, loc,
- inputs);
+ genConstraintOrRewriteCall<pdl::ApplyNativeConstraintOp>(
+ decl, loc, inputs, isNegated);
// Apply any constraints defined on the results of the constraint.
for (auto it : llvm::zip(decl->getResults(), results))
@@ -570,9 +572,9 @@ SmallVector<Value> CodeGen::genRewriteCall(const ast::UserRewriteDecl *decl,
}
template <typename PDLOpT, typename T>
-SmallVector<Value> CodeGen::genConstraintOrRewriteCall(const T *decl,
- Location loc,
- ValueRange inputs) {
+SmallVector<Value>
+CodeGen::genConstraintOrRewriteCall(const T *decl, Location loc,
+ ValueRange inputs, bool isNegated) {
const ast::CompoundStmt *cstBody = decl->getBody();
// If the decl doesn't have a statement body, it is a native decl.
@@ -585,8 +587,10 @@ SmallVector<Value> CodeGen::genConstraintOrRewriteCall(const T *decl,
} else {
resultTypes.push_back(genType(declResultType));
}
- Operation *pdlOp = builder.create<PDLOpT>(
+ PDLOpT pdlOp = builder.create<PDLOpT>(
loc, resultTypes, decl->getName().getName(), inputs);
+ if (isNegated && std::is_same_v<PDLOpT, pdl::ApplyNativeConstraintOp>)
+ cast<pdl::ApplyNativeConstraintOp>(pdlOp).setIsNegated(true);
return pdlOp->getResults();
}
diff --git a/mlir/lib/Tools/PDLL/Parser/Lexer.cpp b/mlir/lib/Tools/PDLL/Parser/Lexer.cpp
index 74b02cc5209d966..4673a73b4efdc7a 100644
--- a/mlir/lib/Tools/PDLL/Parser/Lexer.cpp
+++ b/mlir/lib/Tools/PDLL/Parser/Lexer.cpp
@@ -315,6 +315,7 @@ Token Lexer::lexIdentifier(const char *tokStart) {
.Case("erase", Token::kw_erase)
.Case("let", Token::kw_let)
.Case("Constraint", Token::kw_Constraint)
+ .Case("not", Token::kw_not)
.Case("op", Token::kw_op)
.Case("Op", Token::kw_Op)
.Case("OpName", Token::kw_OpName)
diff --git a/mlir/lib/Tools/PDLL/Parser/Lexer.h b/mlir/lib/Tools/PDLL/Parser/Lexer.h
index 6a78669f854d50c..c80cb3693c4b18c 100644
--- a/mlir/lib/Tools/PDLL/Parser/Lexer.h
+++ b/mlir/lib/Tools/PDLL/Parser/Lexer.h
@@ -57,6 +57,7 @@ class Token {
kw_erase,
kw_let,
kw_Constraint,
+ kw_not,
kw_Op,
kw_OpName,
kw_Pattern,
diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
index fa8bc082af86f97..93eb5e892e228e0 100644
--- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp
+++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
@@ -315,12 +315,14 @@ class Parser {
/// Identifier expressions.
FailureOr<ast::Expr *> parseAttributeExpr();
- FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr);
+ FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr,
+ bool isNegated = false);
FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc);
FailureOr<ast::Expr *> parseIdentifierExpr();
FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr();
FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr();
FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr);
+ FailureOr<ast::Expr *> parseNegatedExpr();
FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false);
FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName);
FailureOr<ast::Expr *>
@@ -405,7 +407,8 @@ class Parser {
FailureOr<ast::CallExpr *>
createCallExpr(SMRange loc, ast::Expr *parentExpr,
- MutableArrayRef<ast::Expr *> arguments);
+ MutableArrayRef<ast::Expr *> arguments,
+ bool isNegated = false);
FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl);
FailureOr<ast::DeclRefExpr *>
createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
@@ -1805,6 +1808,9 @@ FailureOr<ast::Expr *> Parser::parseExpr() {
case Token::kw_Constraint:
lhsExpr = parseInlineConstraintLambdaExpr();
break;
+ case Token::kw_not:
+ lhsExpr = parseNegatedExpr();
+ break;
case Token::identifier:
lhsExpr = parseIdentifierExpr();
break;
@@ -1866,7 +1872,8 @@ FailureOr<ast::Expr *> Parser::parseAttributeExpr() {
return ast::AttributeExpr::create(ctx, loc, attrExpr);
}
-FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr) {
+FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr,
+ bool isNegated) {
consumeToken(Token::l_paren);
// Parse the arguments of the call.
@@ -1890,7 +1897,7 @@ FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr) {
if (failed(parseToken(Token::r_paren, "expected `)` after argument list")))
return failure();
- return createCallExpr(loc, parentExpr, arguments);
+ return createCallExpr(loc, parentExpr, arguments, isNegated);
}
FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) {
@@ -1959,6 +1966,17 @@ FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) {
return createMemberAccessExpr(parentExpr, memberName, loc);
}
+FailureOr<ast::Expr *> Parser::parseNegatedExpr() {
+ consumeToken(Token::kw_not);
+ // Only native constraints are supported after negation
+ if (!curToken.is(Token::identifier))
+ return emitError("expected native constraint");
+ FailureOr<ast::Expr *> identifierExpr = parseIdentifierExpr();
+ if (failed(identifierExpr))
+ return failure();
+ return parseCallExpr(*identifierExpr, /*isNegated = */ true);
+}
+
FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) {
SMRange loc = curToken.getLoc();
@@ -2672,7 +2690,7 @@ Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) {
FailureOr<ast::CallExpr *>
Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr,
- MutableArrayRef<ast::Expr *> arguments) {
+ MutableArrayRef<ast::Expr *> arguments, bool isNegated) {
ast::Type parentType = parentExpr->getType();
ast::CallableDecl *callableDecl = tryExtractCallableDecl(parentExpr);
@@ -2686,8 +2704,14 @@ Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr,
if (isa<ast::UserConstraintDecl>(callableDecl))
return emitError(
loc, "unable to invoke `Constraint` within a rewrite section");
- } else if (isa<ast::UserRewriteDecl>(callableDecl)) {
- return emitError(loc, "unable to invoke `Rewrite` within a match section");
+ if (isNegated)
+ return emitError(loc, "unable to negate a Rewrite");
+ } else {
+ if (isa<ast::UserRewriteDecl>(callableDecl))
+ return emitError(loc,
+ "unable to invoke `Rewrite` within a match section");
+ if (isNegated && cast<ast::UserConstraintDecl>(callableDecl)->getBody())
+ return emitError(loc, "unable to negate non native constraints");
}
// Verify the arguments of the call.
@@ -2718,7 +2742,7 @@ Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr,
}
return ast::CallExpr::create(ctx, loc, parentExpr, arguments,
- callableDecl->getResultType());
+ callableDecl->getResultType(), isNegated);
}
FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc,
diff --git a/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll b/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll
index 950b90d75d6a4a0..948571cd7e5c810 100644
--- a/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll
+++ b/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll
@@ -36,6 +36,20 @@ Pattern TestExternalCall => replace root: Op with TestRewrite(root);
// -----
+// CHECK: pdl.pattern @TestExternalNegatedCall
+// CHECK: %[[ROOT:.*]] = operation
+// CHECK: apply_native_constraint "TestConstraint"(%[[ROOT]] : !pdl.operation) {isNegated = true}
+// CHECK: rewrite %[[ROOT]]
+// CHECK: erase %[[ROOT]]
+Constraint TestConstraint(op: Op);
+Pattern TestExternalNegatedCall {
+ let root = op : Op;
+ not TestConstraint(root);
+ erase root;
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// MemberAccessExpr
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/mlir-pdll/Parser/expr-failure.pdll b/mlir/test/mlir-pdll/Parser/expr-failure.pdll
index 31258cb99ebc499..253d770c83e58a6 100644
--- a/mlir/test/mlir-pdll/Parser/expr-failure.pdll
+++ b/mlir/test/mlir-pdll/Parser/expr-failure.pdll
@@ -173,6 +173,16 @@ Pattern {
// -----
+Constraint Foo(op: Op) {}
+
+Pattern {
+ // CHECK: unable to negate non native constraints
+ let root = op<>;
+ not Foo(root);
+}
+
+// -----
+
Rewrite Foo();
Pattern {
@@ -183,6 +193,18 @@ Pattern {
// -----
+Rewrite Foo(op: Op);
+
+Pattern {
+ // CHECK: unable to negate a Rewrite
+ let root = op<>;
+ rewrite root with {
+ not Foo(root);
+ }
+}
+
+// -----
+
Pattern {
// CHECK: expected expression
let tuple = (10 = _: Value);
diff --git a/mlir/test/mlir-pdll/Parser/expr.pdll b/mlir/test/mlir-pdll/Parser/expr.pdll
index 0736962dada78f3..712e8835477aa17 100644
--- a/mlir/test/mlir-pdll/Parser/expr.pdll
+++ b/mlir/test/mlir-pdll/Parser/expr.pdll
@@ -50,6 +50,22 @@ Pattern {
// -----
+// CHECK: Module {{.*}}
+// CHECK: -UserConstraintDecl {{.*}} Name<TestConstraint> ResultType<Tuple<>>
+// CHECK: `-PatternDecl {{.*}}
+// CHECK: -CallExpr {{.*}} Type<Tuple<>> Negated
+// CHECK: `-DeclRefExpr {{.*}} Type<Constraint>
+// CHECK: `-UserConstraintDecl {{.*}} Name<TestConstraint> ResultType<Tuple<>>
+Constraint TestConstraint(op: Op);
+
+Pattern {
+ let inputOp = op<my_dialect.bar>;
+ not TestConstraint(inputOp);
+ erase inputOp;
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// MemberAccessExpr
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list