[Mlir-commits] [mlir] 3d8b906 - [PDLL] Add support for single line lambda-like patterns
River Riddle
llvmlistbot at llvm.org
Thu Feb 10 12:49:16 PST 2022
Author: River Riddle
Date: 2022-02-10T12:48:58-08:00
New Revision: 3d8b90601211914b0d4690fa603e4b5c43e5c9ac
URL: https://github.com/llvm/llvm-project/commit/3d8b90601211914b0d4690fa603e4b5c43e5c9ac
DIFF: https://github.com/llvm/llvm-project/commit/3d8b90601211914b0d4690fa603e4b5c43e5c9ac.diff
LOG: [PDLL] Add support for single line lambda-like patterns
This allows for defining simple patterns in a single line. The lambda
body of a Pattern expects a single operation rewrite statement:
```
Pattern => replace op<my_dialect.foo>(operands: ValueRange) with operands;
```
Differential Revision: https://reviews.llvm.org/D115835
Added:
Modified:
mlir/lib/Tools/PDLL/Parser/Parser.cpp
mlir/test/mlir-pdll/Parser/pattern-failure.pdll
mlir/test/mlir-pdll/Parser/pattern.pdll
Removed:
################################################################################
diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
index a6c8ebbe76705..5264b953aaabc 100644
--- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp
+++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
@@ -106,6 +106,10 @@ class Parser {
FailureOr<ast::Decl *> parseTopLevelDecl();
FailureOr<ast::NamedAttributeDecl *> parseNamedAttributeDecl();
+ FailureOr<ast::CompoundStmt *>
+ parseLambdaBody(function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
+ bool expectTerminalSemicolon = true);
+ FailureOr<ast::CompoundStmt *> parsePatternLambdaBody();
FailureOr<ast::Decl *> parsePatternDecl();
LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata);
@@ -547,6 +551,36 @@ FailureOr<ast::NamedAttributeDecl *> Parser::parseNamedAttributeDecl() {
return ast::NamedAttributeDecl::create(ctx, name, attrValue);
}
+FailureOr<ast::CompoundStmt *> Parser::parseLambdaBody(
+ function_ref<LogicalResult(ast::Stmt *&)> processStatementFn,
+ bool expectTerminalSemicolon) {
+ consumeToken(Token::equal_arrow);
+
+ // Parse the single statement of the lambda body.
+ SMLoc bodyStartLoc = curToken.getStartLoc();
+ pushDeclScope();
+ FailureOr<ast::Stmt *> singleStatement = parseStmt(expectTerminalSemicolon);
+ bool failedToParse =
+ failed(singleStatement) || failed(processStatementFn(*singleStatement));
+ popDeclScope();
+ if (failedToParse)
+ return failure();
+
+ SMRange bodyLoc(bodyStartLoc, curToken.getStartLoc());
+ return ast::CompoundStmt::create(ctx, bodyLoc, *singleStatement);
+}
+
+FailureOr<ast::CompoundStmt *> Parser::parsePatternLambdaBody() {
+ return parseLambdaBody([&](ast::Stmt *&statement) -> LogicalResult {
+ if (isa<ast::OpRewriteStmt>(statement))
+ return success();
+ return emitError(
+ statement->getLoc(),
+ "expected Pattern lambda body to contain a single operation "
+ "rewrite statement, such as `erase`, `replace`, or `rewrite`");
+ });
+}
+
FailureOr<ast::Decl *> Parser::parsePatternDecl() {
SMRange loc = curToken.getLoc();
consumeToken(Token::kw_Pattern);
@@ -568,29 +602,37 @@ FailureOr<ast::Decl *> Parser::parsePatternDecl() {
// Parse the pattern body.
ast::CompoundStmt *body;
- if (curToken.isNot(Token::l_brace))
- return emitError("expected `{` to start pattern body");
- FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
- if (failed(bodyResult))
- return failure();
- body = *bodyResult;
-
- // Verify the body of the pattern.
- auto bodyIt = body->begin(), bodyE = body->end();
- for (; bodyIt != bodyE; ++bodyIt) {
- // Break when we've found the rewrite statement.
- if (isa<ast::OpRewriteStmt>(*bodyIt))
- break;
- }
- if (bodyIt == bodyE) {
- return emitError(loc,
- "expected Pattern body to terminate with an operation "
- "rewrite statement, such as `erase`");
- }
- if (std::next(bodyIt) != bodyE) {
- return emitError((*std::next(bodyIt))->getLoc(),
- "Pattern body was terminated by an operation "
- "rewrite statement, but found trailing statements");
+ // Handle a lambda body.
+ if (curToken.is(Token::equal_arrow)) {
+ FailureOr<ast::CompoundStmt *> bodyResult = parsePatternLambdaBody();
+ if (failed(bodyResult))
+ return failure();
+ body = *bodyResult;
+ } else {
+ if (curToken.isNot(Token::l_brace))
+ return emitError("expected `{` or `=>` to start pattern body");
+ FailureOr<ast::CompoundStmt *> bodyResult = parseCompoundStmt();
+ if (failed(bodyResult))
+ return failure();
+ body = *bodyResult;
+
+ // Verify the body of the pattern.
+ auto bodyIt = body->begin(), bodyE = body->end();
+ for (; bodyIt != bodyE; ++bodyIt) {
+ // Break when we've found the rewrite statement.
+ if (isa<ast::OpRewriteStmt>(*bodyIt))
+ break;
+ }
+ if (bodyIt == bodyE) {
+ return emitError(loc,
+ "expected Pattern body to terminate with an operation "
+ "rewrite statement, such as `erase`");
+ }
+ if (std::next(bodyIt) != bodyE) {
+ return emitError((*std::next(bodyIt))->getLoc(),
+ "Pattern body was terminated by an operation "
+ "rewrite statement, but found trailing statements");
+ }
}
return createPatternDecl(loc, name, metadata, body);
diff --git a/mlir/test/mlir-pdll/Parser/pattern-failure.pdll b/mlir/test/mlir-pdll/Parser/pattern-failure.pdll
index caa084cda0b68..42ea657821120 100644
--- a/mlir/test/mlir-pdll/Parser/pattern-failure.pdll
+++ b/mlir/test/mlir-pdll/Parser/pattern-failure.pdll
@@ -1,6 +1,6 @@
// RUN: not mlir-pdll %s -split-input-file 2>&1 | FileCheck %s
-// CHECK: expected `{` to start pattern body
+// CHECK: expected `{` or `=>` to start pattern body
Pattern }
// -----
@@ -27,6 +27,11 @@ Pattern {
// -----
+// CHECK: expected Pattern lambda body to contain a single operation rewrite statement, such as `erase`, `replace`, or `rewrite`
+Pattern => op<>;
+
+// -----
+
//===----------------------------------------------------------------------===//
// Metadata
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/mlir-pdll/Parser/pattern.pdll b/mlir/test/mlir-pdll/Parser/pattern.pdll
index 1a7851606213e..f0b2046e4b1b8 100644
--- a/mlir/test/mlir-pdll/Parser/pattern.pdll
+++ b/mlir/test/mlir-pdll/Parser/pattern.pdll
@@ -23,3 +23,11 @@ Pattern NamedPattern {
Pattern NamedPattern with benefit(10), recursion {
erase _: Op;
}
+
+// -----
+
+// CHECK: Module
+// CHECK: `-PatternDecl {{.*}} Name<NamedPattern>
+// CHECK: `-CompoundStmt
+// CHECK: `-EraseStmt
+Pattern NamedPattern => erase _: Op;
More information about the Mlir-commits
mailing list