[Mlir-commits] [mlir] 3d8b906 - [PDLL] Add support for single line lambda-like patterns

River Riddle llvmlistbot at llvm.org
Thu Feb 10 12:49:16 PST 2022


Author: River Riddle
Date: 2022-02-10T12:48:58-08:00
New Revision: 3d8b90601211914b0d4690fa603e4b5c43e5c9ac

URL: https://github.com/llvm/llvm-project/commit/3d8b90601211914b0d4690fa603e4b5c43e5c9ac
DIFF: https://github.com/llvm/llvm-project/commit/3d8b90601211914b0d4690fa603e4b5c43e5c9ac.diff

LOG: [PDLL] Add support for single line lambda-like patterns

This allows simple patterns to be defined on a single line. The lambda
body of a Pattern, introduced by `=>`, must be a single operation
rewrite statement:

```
Pattern => replace op<my_dialect.foo>(operands: ValueRange) with operands;
```

Differential Revision: https://reviews.llvm.org/D115835

Added: 
    

Modified: 
    mlir/lib/Tools/PDLL/Parser/Parser.cpp
    mlir/test/mlir-pdll/Parser/pattern-failure.pdll
    mlir/test/mlir-pdll/Parser/pattern.pdll

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
index a6c8ebbe76705..5264b953aaabc 100644
--- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp
+++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
@@ -106,6 +106,10 @@ class Parser {
 
   FailureOr<ast::Decl *> parseTopLevelDecl();
   FailureOr<ast::NamedAttributeDecl *> parseNamedAttributeDecl();
+  FailureOr<ast::CompoundStmt *>
+  parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
+                  bool expectTerminalSemicolon = true);
+  FailureOr<ast::CompoundStmt *> parsePatternLambdaBody();
   FailureOr<ast::Decl *> parsePatternDecl();
   LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata);
 
@@ -547,6 +551,36 @@ FailureOr<ast::NamedAttributeDecl *> Parser::parseNamedAttributeDecl() {
   return ast::NamedAttributeDecl::create(ctx, name, attrValue);
 }
 
+FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody(
+    function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
+    bool expectTerminalSemicolon) {
+  consumeToken(Token::equal_arrow); // Callers are expected to check for `=>`.
+  // Body is `=> <stmt>`; processStatementFn validates (or may replace) it.
+  // Parse the single statement of the lambda body in a new decl scope.
+  SMLoc bodyStartLoc = curToken.getStartLoc();
+  pushDeclScope();
+  FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon);
+  bool failedToParse = // processStatementFn runs only if parsing succeeded.
+      failed(singleStatement) || failed(processStatementFn(*singleStatement));
+  popDeclScope(); // Popped before the failure check so scopes stay balanced.
+  if (failedToParse)
+    return failure();
+  // The body range ends where the token after the statement begins.
+  SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc());
+  return ast::CompoundStmt::create(ctx, bodyLoc, *singleStatement);
+}
+
+FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() {
+  return parseLambdaBody([&](ast::Stmt *&statement) -> LogicalResult {
+    if (isa<ast::OpRewriteStmt>(statement)) // Only op rewrite stmts pass.
+      return success();
+    return emitError( // Anything else, e.g. a bare `op<>;`, is rejected.
+        statement->getLoc(),
+        "expected Pattern lambda body to contain a single operation "
+        "rewrite statement, such as `erase`, `replace`, or `rewrite`");
+  }); // `=>` Pattern body; uses default expectTerminalSemicolon = true.
+}
+
 FailureOr<ast::Decl *> Parser::parsePatternDecl() {
   SMRange loc = curToken.getLoc();
   consumeToken(Token::kw_Pattern);
@@ -568,29 +602,37 @@ FailureOr<ast::Decl *> Parser::parsePatternDecl() {
   // Parse the pattern body.
   ast::CompoundStmt *body;
 
-  if (curToken.isNot(Token::l_brace))
-    return emitError("expected `{` to start pattern body");
-  FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
-  if (failed(bodyResult))
-    return failure();
-  body = *bodyResult;
-
-  // Verify the body of the pattern.
-  auto bodyIt = body->begin(), bodyE = body->end();
-  for (; bodyIt != bodyE; ++bodyIt) {
-    // Break when we've found the rewrite statement.
-    if (isa<ast::OpRewriteStmt>(*bodyIt))
-      break;
-  }
-  if (bodyIt == bodyE) {
-    return emitError(loc,
-                     "expected Pattern body to terminate with an operation "
-                     "rewrite statement, such as `erase`");
-  }
-  if (std::next(bodyIt) != bodyE) {
-    return emitError((*std::next(bodyIt))->getLoc(),
-                     "Pattern body was terminated by an operation "
-                     "rewrite statement, but found trailing statements");
+  // Handle a lambda body.
+  if (curToken.is(Token::equal_arrow)) {
+    FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody();
+    if (failed(bodyResult))
+      return failure();
+    body = *bodyResult;
+  } else {
+    if (curToken.isNot(Token::l_brace))
+      return emitError("expected `{` or `=>` to start pattern body");
+    FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
+    if (failed(bodyResult))
+      return failure();
+    body = *bodyResult;
+
+    // Verify the body of the pattern.
+    auto bodyIt = body->begin(), bodyE = body->end();
+    for (; bodyIt != bodyE; ++bodyIt) {
+      // Break when we've found the rewrite statement.
+      if (isa<ast::OpRewriteStmt>(*bodyIt))
+        break;
+    }
+    if (bodyIt == bodyE) {
+      return emitError(loc,
+                       "expected Pattern body to terminate with an operation "
+                       "rewrite statement, such as `erase`");
+    }
+    if (std::next(bodyIt) != bodyE) {
+      return emitError((*std::next(bodyIt))->getLoc(),
+                       "Pattern body was terminated by an operation "
+                       "rewrite statement, but found trailing statements");
+    }
   }
 
   return createPatternDecl(loc, name, metadata, body);

diff  --git a/mlir/test/mlir-pdll/Parser/pattern-failure.pdll b/mlir/test/mlir-pdll/Parser/pattern-failure.pdll
index caa084cda0b68..42ea657821120 100644
--- a/mlir/test/mlir-pdll/Parser/pattern-failure.pdll
+++ b/mlir/test/mlir-pdll/Parser/pattern-failure.pdll
@@ -1,6 +1,6 @@
 // RUN: not mlir-pdll %s -split-input-file 2>&1  | FileCheck %s
 
-// CHECK: expected `{` to start pattern body
+// CHECK: expected `{` or `=>` to start pattern body
 Pattern }
 
 // -----
@@ -27,6 +27,11 @@ Pattern {
 
 // -----
 
+// CHECK: expected Pattern lambda body to contain a single operation rewrite statement, such as `erase`, `replace`, or `rewrite`
+Pattern => op<>;
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // Metadata
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/mlir-pdll/Parser/pattern.pdll b/mlir/test/mlir-pdll/Parser/pattern.pdll
index 1a7851606213e..f0b2046e4b1b8 100644
--- a/mlir/test/mlir-pdll/Parser/pattern.pdll
+++ b/mlir/test/mlir-pdll/Parser/pattern.pdll
@@ -23,3 +23,11 @@ Pattern NamedPattern {
 Pattern NamedPattern with benefit(10), recursion {
   erase _: Op;
 }
+
+// -----
+
+// CHECK: Module
+// CHECK: `-PatternDecl {{.*}} Name<NamedPattern>
+// CHECK:   `-CompoundStmt
+// CHECK:     `-EraseStmt
+Pattern NamedPattern => erase _: Op;


        


More information about the Mlir-commits mailing list