[Mlir-commits] [mlir] 930916c - [MLIR][PDL] Add PDLL support for negated native constraints

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Sep 1 16:12:23 PDT 2023


Author: Mogball
Date: 2023-09-01T23:12:16Z
New Revision: 930916c7f3622870b40138dafcc5f94740404e8c

URL: https://github.com/llvm/llvm-project/commit/930916c7f3622870b40138dafcc5f94740404e8c
DIFF: https://github.com/llvm/llvm-project/commit/930916c7f3622870b40138dafcc5f94740404e8c.diff

LOG: [MLIR][PDL] Add PDLL support for negated native constraints

This commit enables the expression of negated native constraints in PDLL:

If a constraint is prefixed with `not`, it is parsed as a negated constraint, and the `isNegated` attribute of the emitted `pdl.apply_native_constraint` operation is set to `true`.
Initially, this is only supported for calls to external native C++ constraints and for the generation of PDL patterns.

Previously, negating a native constraint had to be handled by declaring an additional native constraint, e.g.
```PDLL
Constraint checkA(input: Attr);
Constraint checkNotA(input: Attr);
```
or by including an explicit additional operand for negation, e.g.
`Constraint checkA(input: Attr, negated: Attr);`

With this change, a constraint can simply be negated by prefixing it with `not`, e.g.
```PDLL
Constraint simpleConstraint(op: Op);

Pattern example {
    let inputOp = op<test.bar>() ->(type: Type);
    let root = op<test.foo>(inputOp.0) -> ();
    not simpleConstraint(inputOp);
    simpleConstraint(root);
    erase root;
}
```

Depends on [[ https://reviews.llvm.org/D153871 | D153871 ]]

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D153959

Added: 
    

Modified: 
    mlir/include/mlir/Tools/PDLL/AST/Nodes.h
    mlir/lib/Tools/PDLL/AST/NodePrinter.cpp
    mlir/lib/Tools/PDLL/AST/Nodes.cpp
    mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp
    mlir/lib/Tools/PDLL/Parser/Lexer.cpp
    mlir/lib/Tools/PDLL/Parser/Lexer.h
    mlir/lib/Tools/PDLL/Parser/Parser.cpp
    mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll
    mlir/test/mlir-pdll/Parser/expr-failure.pdll
    mlir/test/mlir-pdll/Parser/expr.pdll

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
index 3b29d48fb07bdf7..5515ee7548b5ab5 100644
--- a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
+++ b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
@@ -390,7 +390,8 @@ class CallExpr final : public Node::NodeBase<CallExpr, Expr>,
                        private llvm::TrailingObjects<CallExpr, Expr *> {
 public:
   static CallExpr *create(Context &ctx, SMRange loc, Expr *callable,
-                          ArrayRef<Expr *> arguments, Type resultType);
+                          ArrayRef<Expr *> arguments, Type resultType,
+                          bool isNegated = false);
 
   /// Return the callable of this call.
   Expr *getCallableExpr() const { return callable; }
@@ -403,9 +404,14 @@ class CallExpr final : public Node::NodeBase<CallExpr, Expr>,
     return const_cast<CallExpr *>(this)->getArguments();
   }
 
+  /// Returns whether the result of this call is to be negated.
+  bool getIsNegated() const { return isNegated; }
+
 private:
-  CallExpr(SMRange loc, Type type, Expr *callable, unsigned numArgs)
-      : Base(loc, type), callable(callable), numArgs(numArgs) {}
+  CallExpr(SMRange loc, Type type, Expr *callable, unsigned numArgs,
+           bool isNegated)
+      : Base(loc, type), callable(callable), numArgs(numArgs),
+        isNegated(isNegated) {}
 
   /// The callable of this call.
   Expr *callable;
@@ -415,6 +421,9 @@ class CallExpr final : public Node::NodeBase<CallExpr, Expr>,
 
   /// TrailingObject utilities.
   friend llvm::TrailingObjects<CallExpr, Expr *>;
+
+  // Is the result of this call to be negated.
+  bool isNegated;
 };
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp b/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp
index 4155fd69c589754..04f02f20dee9f46 100644
--- a/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp
+++ b/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp
@@ -225,7 +225,10 @@ void NodePrinter::printImpl(const AttributeExpr *expr) {
 void NodePrinter::printImpl(const CallExpr *expr) {
   os << "CallExpr " << expr << " Type<";
   print(expr->getType());
-  os << ">\n";
+  os << ">";
+  if (expr->getIsNegated())
+    os << " Negated";
+  os << "\n";
   printChildren(expr->getCallableExpr());
   printChildren("Arguments", expr->getArguments());
 }

diff  --git a/mlir/lib/Tools/PDLL/AST/Nodes.cpp b/mlir/lib/Tools/PDLL/AST/Nodes.cpp
index 47556295b7cbed6..654ff24454cb1fb 100644
--- a/mlir/lib/Tools/PDLL/AST/Nodes.cpp
+++ b/mlir/lib/Tools/PDLL/AST/Nodes.cpp
@@ -266,12 +266,13 @@ AttributeExpr *AttributeExpr::create(Context &ctx, SMRange loc,
 //===----------------------------------------------------------------------===//
 
 CallExpr *CallExpr::create(Context &ctx, SMRange loc, Expr *callable,
-                           ArrayRef<Expr *> arguments, Type resultType) {
+                           ArrayRef<Expr *> arguments, Type resultType,
+                           bool isNegated) {
   unsigned allocSize = CallExpr::totalSizeToAlloc<Expr *>(arguments.size());
   void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(CallExpr));
 
-  CallExpr *expr =
-      new (rawData) CallExpr(loc, resultType, callable, arguments.size());
+  CallExpr *expr = new (rawData)
+      CallExpr(loc, resultType, callable, arguments.size(), isNegated);
   std::uninitialized_copy(arguments.begin(), arguments.end(),
                           expr->getArguments().begin());
   return expr;

diff  --git a/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp b/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp
index 00b68162a1ed281..16c3ccf0de2690b 100644
--- a/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp
+++ b/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp
@@ -103,12 +103,14 @@ class CodeGen {
   Value genExprImpl(const ast::TypeExpr *expr);
 
   SmallVector<Value> genConstraintCall(const ast::UserConstraintDecl *decl,
-                                       Location loc, ValueRange inputs);
+                                       Location loc, ValueRange inputs,
+                                       bool isNegated = false);
   SmallVector<Value> genRewriteCall(const ast::UserRewriteDecl *decl,
                                     Location loc, ValueRange inputs);
   template <typename PDLOpT, typename T>
   SmallVector<Value> genConstraintOrRewriteCall(const T *decl, Location loc,
-                                                ValueRange inputs);
+                                                ValueRange inputs,
+                                                bool isNegated = false);
 
   //===--------------------------------------------------------------------===//
   // Fields
@@ -419,7 +421,7 @@ SmallVector<Value> CodeGen::genExprImpl(const ast::CallExpr *expr) {
   // Generate the PDL based on the type of callable.
   const ast::Decl *callable = callableExpr->getDecl();
   if (const auto *decl = dyn_cast<ast::UserConstraintDecl>(callable))
-    return genConstraintCall(decl, loc, arguments);
+    return genConstraintCall(decl, loc, arguments, expr->getIsNegated());
   if (const auto *decl = dyn_cast<ast::UserRewriteDecl>(callable))
     return genRewriteCall(decl, loc, arguments);
   llvm_unreachable("unhandled CallExpr callable");
@@ -547,15 +549,15 @@ Value CodeGen::genExprImpl(const ast::TypeExpr *expr) {
 
 SmallVector<Value>
 CodeGen::genConstraintCall(const ast::UserConstraintDecl *decl, Location loc,
-                           ValueRange inputs) {
+                           ValueRange inputs, bool isNegated) {
   // Apply any constraints defined on the arguments to the input values.
   for (auto it : llvm::zip(decl->getInputs(), inputs))
     applyVarConstraints(std::get<0>(it), std::get<1>(it));
 
   // Generate the constraint call.
   SmallVector<Value> results =
-      genConstraintOrRewriteCall<pdl::ApplyNativeConstraintOp>(decl, loc,
-                                                               inputs);
+      genConstraintOrRewriteCall<pdl::ApplyNativeConstraintOp>(
+          decl, loc, inputs, isNegated);
 
   // Apply any constraints defined on the results of the constraint.
   for (auto it : llvm::zip(decl->getResults(), results))
@@ -570,9 +572,9 @@ SmallVector<Value> CodeGen::genRewriteCall(const ast::UserRewriteDecl *decl,
 }
 
 template <typename PDLOpT, typename T>
-SmallVector<Value> CodeGen::genConstraintOrRewriteCall(const T *decl,
-                                                       Location loc,
-                                                       ValueRange inputs) {
+SmallVector<Value>
+CodeGen::genConstraintOrRewriteCall(const T *decl, Location loc,
+                                    ValueRange inputs, bool isNegated) {
   const ast::CompoundStmt *cstBody = decl->getBody();
 
   // If the decl doesn't have a statement body, it is a native decl.
@@ -585,8 +587,10 @@ SmallVector<Value> CodeGen::genConstraintOrRewriteCall(const T *decl,
     } else {
       resultTypes.push_back(genType(declResultType));
     }
-    Operation *pdlOp = builder.create<PDLOpT>(
+    PDLOpT pdlOp = builder.create<PDLOpT>(
         loc, resultTypes, decl->getName().getName(), inputs);
+    if (isNegated && std::is_same_v<PDLOpT, pdl::ApplyNativeConstraintOp>)
+      cast<pdl::ApplyNativeConstraintOp>(pdlOp).setIsNegated(true);
     return pdlOp->getResults();
   }
 

diff  --git a/mlir/lib/Tools/PDLL/Parser/Lexer.cpp b/mlir/lib/Tools/PDLL/Parser/Lexer.cpp
index 74b02cc5209d966..4673a73b4efdc7a 100644
--- a/mlir/lib/Tools/PDLL/Parser/Lexer.cpp
+++ b/mlir/lib/Tools/PDLL/Parser/Lexer.cpp
@@ -315,6 +315,7 @@ Token Lexer::lexIdentifier(const char *tokStart) {
                          .Case("erase", Token::kw_erase)
                          .Case("let", Token::kw_let)
                          .Case("Constraint", Token::kw_Constraint)
+                         .Case("not", Token::kw_not)
                          .Case("op", Token::kw_op)
                          .Case("Op", Token::kw_Op)
                          .Case("OpName", Token::kw_OpName)

diff  --git a/mlir/lib/Tools/PDLL/Parser/Lexer.h b/mlir/lib/Tools/PDLL/Parser/Lexer.h
index 6a78669f854d50c..c80cb3693c4b18c 100644
--- a/mlir/lib/Tools/PDLL/Parser/Lexer.h
+++ b/mlir/lib/Tools/PDLL/Parser/Lexer.h
@@ -57,6 +57,7 @@ class Token {
     kw_erase,
     kw_let,
     kw_Constraint,
+    kw_not,
     kw_Op,
     kw_OpName,
     kw_Pattern,

diff  --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
index fa8bc082af86f97..93eb5e892e228e0 100644
--- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp
+++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
@@ -315,12 +315,14 @@ class Parser {
 
   /// Identifier expressions.
   FailureOr<ast::Expr *> parseAttributeExpr();
-  FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr);
+  FailureOr<ast::Expr *> parseCallExpr(ast::Expr *parentExpr,
+                                       bool isNegated = false);
   FailureOr<ast::Expr *> parseDeclRefExpr(StringRef name, SMRange loc);
   FailureOr<ast::Expr *> parseIdentifierExpr();
   FailureOr<ast::Expr *> parseInlineConstraintLambdaExpr();
   FailureOr<ast::Expr *> parseInlineRewriteLambdaExpr();
   FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr);
+  FailureOr<ast::Expr *> parseNegatedExpr();
   FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false);
   FailureOr<ast::OpNameDecl *> parseWrappedOperationName(bool allowEmptyName);
   FailureOr<ast::Expr *>
@@ -405,7 +407,8 @@ class Parser {
 
   FailureOr<ast::CallExpr *>
   createCallExpr(SMRange loc, ast::Expr *parentExpr,
-                 MutableArrayRef<ast::Expr *> arguments);
+                 MutableArrayRef<ast::Expr *> arguments,
+                 bool isNegated = false);
   FailureOr<ast::DeclRefExpr *> createDeclRefExpr(SMRange loc, ast::Decl *decl);
   FailureOr<ast::DeclRefExpr *>
   createInlineVariableExpr(ast::Type type, StringRef name, SMRange loc,
@@ -1805,6 +1808,9 @@ FailureOr<ast::Expr *> Parser::parseExpr() {
   case Token::kw_Constraint:
     lhsExpr = parseInlineConstraintLambdaExpr();
     break;
+  case Token::kw_not:
+    lhsExpr = parseNegatedExpr();
+    break;
   case Token::identifier:
     lhsExpr = parseIdentifierExpr();
     break;
@@ -1866,7 +1872,8 @@ FailureOr<ast::Expr *> Parser::parseAttributeExpr() {
   return ast::AttributeExpr::create(ctx, loc, attrExpr);
 }
 
-FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr) {
+FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr,
+                                             bool isNegated) {
   consumeToken(Token::l_paren);
 
   // Parse the arguments of the call.
@@ -1890,7 +1897,7 @@ FailureOr<ast::Expr *> Parser::parseCallExpr(ast::Expr *parentExpr) {
   if (failed(parseToken(Token::r_paren, "expected `)` after argument list")))
     return failure();
 
-  return createCallExpr(loc, parentExpr, arguments);
+  return createCallExpr(loc, parentExpr, arguments, isNegated);
 }
 
 FailureOr<ast::Expr *> Parser::parseDeclRefExpr(StringRef name, SMRange loc) {
@@ -1959,6 +1966,17 @@ FailureOr<ast::Expr *> Parser::parseMemberAccessExpr(ast::Expr *parentExpr) {
   return createMemberAccessExpr(parentExpr, memberName, loc);
 }
 
+FailureOr<ast::Expr *> Parser::parseNegatedExpr() {
+  consumeToken(Token::kw_not);
+  // Only native constraints are supported after negation
+  if (!curToken.is(Token::identifier))
+    return emitError("expected native constraint");
+  FailureOr<ast::Expr *> identifierExpr = parseIdentifierExpr();
+  if (failed(identifierExpr))
+    return failure();
+  return parseCallExpr(*identifierExpr, /*isNegated = */ true);
+}
+
 FailureOr<ast::OpNameDecl *> Parser::parseOperationName(bool allowEmptyName) {
   SMRange loc = curToken.getLoc();
 
@@ -2672,7 +2690,7 @@ Parser::validateTypeRangeConstraintExpr(const ast::Expr *typeExpr) {
 
 FailureOr<ast::CallExpr *>
 Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr,
-                       MutableArrayRef<ast::Expr *> arguments) {
+                       MutableArrayRef<ast::Expr *> arguments, bool isNegated) {
   ast::Type parentType = parentExpr->getType();
 
   ast::CallableDecl *callableDecl = tryExtractCallableDecl(parentExpr);
@@ -2686,8 +2704,14 @@ Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr,
     if (isa<ast::UserConstraintDecl>(callableDecl))
       return emitError(
           loc, "unable to invoke `Constraint` within a rewrite section");
-  } else if (isa<ast::UserRewriteDecl>(callableDecl)) {
-    return emitError(loc, "unable to invoke `Rewrite` within a match section");
+    if (isNegated)
+      return emitError(loc, "unable to negate a Rewrite");
+  } else {
+    if (isa<ast::UserRewriteDecl>(callableDecl))
+      return emitError(loc,
+                       "unable to invoke `Rewrite` within a match section");
+    if (isNegated && cast<ast::UserConstraintDecl>(callableDecl)->getBody())
+      return emitError(loc, "unable to negate non native constraints");
   }
 
   // Verify the arguments of the call.
@@ -2718,7 +2742,7 @@ Parser::createCallExpr(SMRange loc, ast::Expr *parentExpr,
   }
 
   return ast::CallExpr::create(ctx, loc, parentExpr, arguments,
-                               callableDecl->getResultType());
+                               callableDecl->getResultType(), isNegated);
 }
 
 FailureOr<ast::DeclRefExpr *> Parser::createDeclRefExpr(SMRange loc,

diff  --git a/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll b/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll
index 950b90d75d6a4a0..948571cd7e5c810 100644
--- a/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll
+++ b/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll
@@ -36,6 +36,20 @@ Pattern TestExternalCall => replace root: Op with TestRewrite(root);
 
 // -----
 
+// CHECK: pdl.pattern @TestExternalNegatedCall
+// CHECK: %[[ROOT:.*]] = operation
+// CHECK: apply_native_constraint  "TestConstraint"(%[[ROOT]] : !pdl.operation) {isNegated = true}
+// CHECK: rewrite %[[ROOT]]
+// CHECK: erase %[[ROOT]]
+Constraint TestConstraint(op: Op);
+Pattern TestExternalNegatedCall {
+    let root = op : Op;
+    not TestConstraint(root);
+    erase root;
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // MemberAccessExpr
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/mlir-pdll/Parser/expr-failure.pdll b/mlir/test/mlir-pdll/Parser/expr-failure.pdll
index 31258cb99ebc499..253d770c83e58a6 100644
--- a/mlir/test/mlir-pdll/Parser/expr-failure.pdll
+++ b/mlir/test/mlir-pdll/Parser/expr-failure.pdll
@@ -173,6 +173,16 @@ Pattern {
 
 // -----
 
+Constraint Foo(op: Op) {}
+
+Pattern {
+  // CHECK: unable to negate non native constraints
+  let root = op<>;
+  not Foo(root);
+}
+
+// -----
+
 Rewrite Foo();
 
 Pattern {
@@ -183,6 +193,18 @@ Pattern {
 
 // -----
 
+Rewrite Foo(op: Op);
+
+Pattern {
+  // CHECK: unable to negate a Rewrite
+  let root = op<>;
+  rewrite root with {
+     not Foo(root);
+  }
+}
+
+// -----
+
 Pattern {
   // CHECK: expected expression
   let tuple = (10 = _: Value);

diff  --git a/mlir/test/mlir-pdll/Parser/expr.pdll b/mlir/test/mlir-pdll/Parser/expr.pdll
index 0736962dada78f3..712e8835477aa17 100644
--- a/mlir/test/mlir-pdll/Parser/expr.pdll
+++ b/mlir/test/mlir-pdll/Parser/expr.pdll
@@ -50,6 +50,22 @@ Pattern {
 
 // -----
 
+// CHECK: Module {{.*}}
+// CHECK:  -UserConstraintDecl {{.*}} Name<TestConstraint> ResultType<Tuple<>>
+// CHECK: `-PatternDecl {{.*}}
+// CHECK:   -CallExpr {{.*}} Type<Tuple<>> Negated
+// CHECK:     `-DeclRefExpr {{.*}} Type<Constraint>
+// CHECK:       `-UserConstraintDecl {{.*}} Name<TestConstraint> ResultType<Tuple<>>
+Constraint TestConstraint(op: Op);
+
+Pattern {
+  let inputOp = op<my_dialect.bar>;
+  not TestConstraint(inputOp);
+  erase inputOp;
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // MemberAccessExpr
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list