[Mlir-commits] [mlir] b6ec1de - [mlir:PDLL] Allow complex constraints on Rewrite arguments/results
River Riddle
llvmlistbot at llvm.org
Tue Nov 8 01:58:32 PST 2022
Author: River Riddle
Date: 2022-11-08T01:57:58-08:00
New Revision: b6ec1de7cbc4dda73248c6636d0747fd445598a4
URL: https://github.com/llvm/llvm-project/commit/b6ec1de7cbc4dda73248c6636d0747fd445598a4
DIFF: https://github.com/llvm/llvm-project/commit/b6ec1de7cbc4dda73248c6636d0747fd445598a4.diff
LOG: [mlir:PDLL] Allow complex constraints on Rewrite arguments/results
The documentation already has examples of this, and it allows for
using nicer C++ types when defining native Rewrites.
Differential Revision: https://reviews.llvm.org/D133989
Added:
mlir/test/mlir-pdll/CodeGen/CPP/include/ods.td
Modified:
mlir/include/mlir/Tools/PDLL/AST/Nodes.h
mlir/include/mlir/Tools/PDLL/Parser/CodeComplete.h
mlir/lib/Tools/PDLL/Parser/CodeComplete.cpp
mlir/lib/Tools/PDLL/Parser/Parser.cpp
mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
mlir/test/mlir-pdll/CodeGen/CPP/general.pdll
mlir/test/mlir-pdll/Parser/rewrite-failure.pdll
Removed:
################################################################################
diff --git a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
index 5f282b0c8884f..2281115dddddb 100644
--- a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
+++ b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
@@ -943,7 +943,7 @@ class UserConstraintDecl final
Type resultType)
: Base(name.getLoc(), &name), numInputs(numInputs),
numResults(numResults), codeBlock(codeBlock), constraintBody(body),
- resultType(resultType) {}
+ resultType(resultType), hasNativeInputTypes(hasNativeInputTypes) {}
/// The number of inputs to this constraint.
unsigned numInputs;
diff --git a/mlir/include/mlir/Tools/PDLL/Parser/CodeComplete.h b/mlir/include/mlir/Tools/PDLL/Parser/CodeComplete.h
index d0ccbe97e4eac..90ceda96e97e2 100644
--- a/mlir/include/mlir/Tools/PDLL/Parser/CodeComplete.h
+++ b/mlir/include/mlir/Tools/PDLL/Parser/CodeComplete.h
@@ -48,12 +48,9 @@ class CodeCompleteContext {
/// Signal code completion for a constraint name with an optional decl scope.
/// `currentType` is the current type of the variable that will use the
- /// constraint, or nullptr if a type is unknown. `allowNonCoreConstraints`
- /// indicates if user defined constraints are allowed in the completion
- /// results. `allowInlineTypeConstraints` enables inline type constraints for
- /// Attr/Value/ValueRange.
+ /// constraint, or nullptr if a type is unknown. `allowInlineTypeConstraints`
+ /// enables inline type constraints for Attr/Value/ValueRange.
virtual void codeCompleteConstraintName(ast::Type currentType,
- bool allowNonCoreConstraints,
bool allowInlineTypeConstraints,
const ast::DeclScope *scope);
diff --git a/mlir/lib/Tools/PDLL/Parser/CodeComplete.cpp b/mlir/lib/Tools/PDLL/Parser/CodeComplete.cpp
index acc2ca84037dc..9d241ea279641 100644
--- a/mlir/lib/Tools/PDLL/Parser/CodeComplete.cpp
+++ b/mlir/lib/Tools/PDLL/Parser/CodeComplete.cpp
@@ -24,5 +24,5 @@ void CodeCompleteContext::codeCompleteOperationMemberAccess(
ast::OperationType opType) {}
void CodeCompleteContext::codeCompleteConstraintName(
- ast::Type currentType, bool allowNonCoreConstraints,
- bool allowInlineTypeConstraints, const ast::DeclScope *scope) {}
+ ast::Type currentType, bool allowInlineTypeConstraints,
+ const ast::DeclScope *scope) {}
diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
index de19f577133e1..3af285e5152bf 100644
--- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp
+++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
@@ -297,13 +297,10 @@ class Parser {
/// existing constraints that have already been parsed for the same entity
/// that will be constrained by this constraint. `allowInlineTypeConstraints`
/// allows the use of inline Type constraints, e.g. `Value<valueType: Type>`.
- /// If `allowNonCoreConstraints` is true, then complex (e.g. user defined
- /// constraints) may be used with the variable.
FailureOr<ast::ConstraintRef>
parseConstraint(Optional<SMRange> &typeConstraint,
ArrayRef<ast::ConstraintRef> existingConstraints,
- bool allowInlineTypeConstraints,
- bool allowNonCoreConstraints);
+ bool allowInlineTypeConstraints);
/// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl
/// argument or result variable. The constraints for these variables do not
@@ -389,20 +386,16 @@ class Parser {
/// `inferredType` is the type of the variable inferred by the constraints
/// within the list, and is updated to the most refined type as determined by
/// the constraints. Returns success if the constraint list is valid, failure
- /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user
- /// defined constraints) may be used with the variable.
+ /// otherwise.
LogicalResult
validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
- ast::Type &inferredType,
- bool allowNonCoreConstraints = true);
+ ast::Type &inferredType);
/// Validate a single reference to a constraint. `inferredType` contains the
/// currently inferred variabled type and is refined within the type defined
/// by the constraint. Returns success if the constraint is valid, failure
- /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user
- /// defined constraints) may be used with the variable.
+ /// otherwise.
LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref,
- ast::Type &inferredType,
- bool allowNonCoreConstraints = true);
+ ast::Type &inferredType);
LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr);
LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr);
@@ -469,7 +462,6 @@ class Parser {
LogicalResult codeCompleteMemberAccess(ast::Expr *parentExpr);
LogicalResult codeCompleteAttributeName(Optional<StringRef> opName);
LogicalResult codeCompleteConstraintName(ast::Type inferredType,
- bool allowNonCoreConstraints,
bool allowInlineTypeConstraints);
LogicalResult codeCompleteDialectName();
LogicalResult codeCompleteOperationName(StringRef dialectName);
@@ -1129,18 +1121,7 @@ FailureOr<ast::VariableDecl *> Parser::parseResultDecl(unsigned resultNum) {
// Check to see if this result is named.
if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
// Check to see if this name actually refers to a Constraint.
- ast::Decl *existingDecl = curDeclScope->lookup(curToken.getSpelling());
- if (isa_and_nonnull<ast::ConstraintDecl>(existingDecl)) {
- // If yes, and this is a Rewrite, give a nice error message as non-Core
- // constraints are not supported on Rewrite results.
- if (parserContext == ParserContext::Rewrite) {
- return emitError(
- "`Rewrite` results are only permitted to use core constraints, "
- "such as `Attr`, `Op`, `Type`, `TypeRange`, `Value`, `ValueRange`");
- }
-
- // Otherwise, parse this as an unnamed result variable.
- } else {
+ if (!curDeclScope->lookup<ast::ConstraintDecl>(curToken.getSpelling())) {
// If it wasn't a constraint, parse the result similarly to a variable. If
// there is already an existing decl, we will emit an error when defining
// this variable later.
@@ -1662,8 +1643,7 @@ LogicalResult Parser::parseVariableDeclConstraintList(
Optional<SMRange> typeConstraint;
auto parseSingleConstraint = [&] {
FailureOr<ast::ConstraintRef> constraint = parseConstraint(
- typeConstraint, constraints, /*allowInlineTypeConstraints=*/true,
- /*allowNonCoreConstraints=*/true);
+ typeConstraint, constraints, /*allowInlineTypeConstraints=*/true);
if (failed(constraint))
return failure();
constraints.push_back(*constraint);
@@ -1684,8 +1664,7 @@ LogicalResult Parser::parseVariableDeclConstraintList(
FailureOr<ast::ConstraintRef>
Parser::parseConstraint(Optional<SMRange> &typeConstraint,
ArrayRef<ast::ConstraintRef> existingConstraints,
- bool allowInlineTypeConstraints,
- bool allowNonCoreConstraints) {
+ bool allowInlineTypeConstraints) {
auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult {
if (!allowInlineTypeConstraints) {
return emitError(
@@ -1791,12 +1770,10 @@ Parser::parseConstraint(Optional<SMRange> &typeConstraint,
case Token::code_complete: {
// Try to infer the current type for use by code completion.
ast::Type inferredType;
- if (failed(validateVariableConstraints(existingConstraints, inferredType,
- allowNonCoreConstraints)))
+ if (failed(validateVariableConstraints(existingConstraints, inferredType)))
return failure();
- return codeCompleteConstraintName(inferredType, allowNonCoreConstraints,
- allowInlineTypeConstraints);
+ return codeCompleteConstraintName(inferredType, allowInlineTypeConstraints);
}
default:
break;
@@ -1805,13 +1782,9 @@ Parser::parseConstraint(Optional<SMRange> &typeConstraint,
}
FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() {
- // Constraint arguments may apply more complex constraints via the arguments.
- bool allowNonCoreConstraints = parserContext == ParserContext::Constraint;
-
Optional<SMRange> typeConstraint;
return parseConstraint(typeConstraint, /*existingConstraints=*/llvm::None,
- /*allowInlineTypeConstraints=*/false,
- allowNonCoreConstraints);
+ /*allowInlineTypeConstraints=*/false);
}
//===----------------------------------------------------------------------===//
@@ -2598,29 +2571,23 @@ Parser::createVariableDecl(StringRef name, SMRange loc, ast::Expr *initializer,
FailureOr<ast::VariableDecl *>
Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc,
const ast::ConstraintRef &constraint) {
- // Constraint arguments may apply more complex constraints via the arguments.
- bool allowNonCoreConstraints = parserContext == ParserContext::Constraint;
ast::Type argType;
- if (failed(validateVariableConstraint(constraint, argType,
- allowNonCoreConstraints)))
+ if (failed(validateVariableConstraint(constraint, argType)))
return failure();
return defineVariableDecl(name, loc, argType, constraint);
}
LogicalResult
Parser::validateVariableConstraints(ArrayRef<ast::ConstraintRef> constraints,
- ast::Type &inferredType,
- bool allowNonCoreConstraints) {
+ ast::Type &inferredType) {
for (const ast::ConstraintRef &ref : constraints)
- if (failed(validateVariableConstraint(ref, inferredType,
- allowNonCoreConstraints)))
+ if (failed(validateVariableConstraint(ref, inferredType)))
return failure();
return success();
}
LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref,
- ast::Type &inferredType,
- bool allowNonCoreConstraints) {
+ ast::Type &inferredType) {
ast::Type constraintType;
if (const auto *cst = dyn_cast<ast::AttrConstraintDecl>(ref.constraint)) {
if (const ast::Expr *typeExpr = cst->getTypeExpr()) {
@@ -2652,13 +2619,6 @@ LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref,
constraintType = valueRangeTy;
} else if (const auto *cst =
dyn_cast<ast::UserConstraintDecl>(ref.constraint)) {
- if (!allowNonCoreConstraints) {
- return emitError(ref.referenceLoc,
- "`Rewrite` arguments and results are only permitted to "
- "use core constraints, such as `Attr`, `Op`, `Type`, "
- "`TypeRange`, `Value`, `ValueRange`");
- }
-
ArrayRef<ast::VariableDecl *> inputs = cst->getInputs();
if (inputs.size() != 1) {
return emitErrorAndNote(ref.referenceLoc,
@@ -3160,11 +3120,9 @@ LogicalResult Parser::codeCompleteAttributeName(Optional<StringRef> opName) {
LogicalResult
Parser::codeCompleteConstraintName(ast::Type inferredType,
- bool allowNonCoreConstraints,
bool allowInlineTypeConstraints) {
codeCompleteContext->codeCompleteConstraintName(
- inferredType, allowNonCoreConstraints, allowInlineTypeConstraints,
- curDeclScope);
+ inferredType, allowInlineTypeConstraints, curDeclScope);
return failure();
}
diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
index 476bf16efd7d5..7846103ff9dad 100644
--- a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
+++ b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
@@ -760,7 +760,6 @@ class LSPCodeCompleteContext : public CodeCompleteContext {
}
void codeCompleteConstraintName(ast::Type currentType,
- bool allowNonCoreConstraints,
bool allowInlineTypeConstraints,
const ast::DeclScope *scope) final {
auto addCoreConstraint = [&](StringRef constraint, StringRef mlirType,
@@ -808,9 +807,6 @@ class LSPCodeCompleteContext : public CodeCompleteContext {
while (scope) {
for (const ast::Decl *decl : scope->getDecls()) {
if (const auto *cst = dyn_cast<ast::UserConstraintDecl>(decl)) {
- if (!allowNonCoreConstraints)
- continue;
-
lsp::CompletionItem item;
item.label = cst->getName().getName().str();
item.kind = lsp::CompletionItemKind::Interface;
diff --git a/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll b/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll
index f97530700b1d4..21a89661708f7 100644
--- a/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll
+++ b/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll
@@ -1,4 +1,4 @@
-// RUN: mlir-pdll %s -I %S -split-input-file -x cpp | FileCheck %s
+// RUN: mlir-pdll %s -I %S -I %S/../../../../include -split-input-file -x cpp | FileCheck %s
// Check that we generate a wrapper pattern for each PDL pattern. Also
// add in a pattern awkwardly named the same as our generated patterns to
@@ -44,6 +44,8 @@ Pattern => erase op<test.op3>;
// Check the generation of native constraints and rewrites.
+#include "include/ods.td"
+
// CHECK: static ::mlir::LogicalResult TestCstPDLFn(::mlir::PatternRewriter &rewriter,
// CHECK-SAME: ::mlir::Attribute attr, ::mlir::Operation * op, ::mlir::Type type,
// CHECK-SAME: ::mlir::Value value, ::mlir::TypeRange typeRange, ::mlir::ValueRange valueRange) {
@@ -58,6 +60,7 @@ Pattern => erase op<test.op3>;
// CHECK: foo;
// CHECK: }
+// CHECK: TestAttrInterface TestRewriteODSPDLFn(::mlir::PatternRewriter &rewriter, TestAttrInterface attr) {
// CHECK: static ::mlir::Attribute TestRewriteSinglePDLFn(::mlir::PatternRewriter &rewriter) {
// CHECK: std::tuple<::mlir::Attribute, ::mlir::Type> TestRewriteTuplePDLFn(::mlir::PatternRewriter &rewriter) {
@@ -73,6 +76,7 @@ Constraint TestCst(attr: Attr, op: Op, type: Type, value: Value, typeRange: Type
Constraint TestUnusedCst() [{ return success(); }];
Rewrite TestRewrite(attr: Attr, op: Op, type: Type, value: Value, typeRange: TypeRange, valueRange: ValueRange) [{ foo; }];
+Rewrite TestRewriteODS(attr: TestAttrInterface) -> TestAttrInterface [{}];
Rewrite TestRewriteSingle() -> Attr [{}];
Rewrite TestRewriteTuple() -> (Attr, Type) [{}];
Rewrite TestUnusedRewrite(op: Op) [{}];
@@ -82,6 +86,7 @@ Pattern TestCstAndRewrite {
TestCst(attr<"true">, root, type, operand, types, operands);
rewrite root with {
TestRewrite(attr<"true">, root, type, operand, types, operands);
+ TestRewriteODS(attr<"true">);
TestRewriteSingle();
TestRewriteTuple();
erase root;
diff --git a/mlir/test/mlir-pdll/CodeGen/CPP/include/ods.td b/mlir/test/mlir-pdll/CodeGen/CPP/include/ods.td
new file mode 100644
index 0000000000000..3eb57a43849cf
--- /dev/null
+++ b/mlir/test/mlir-pdll/CodeGen/CPP/include/ods.td
@@ -0,0 +1,3 @@
+include "mlir/IR/OpBase.td"
+
+def TestAttrInterface : AttrInterface<"TestAttrInterface">;
diff --git a/mlir/test/mlir-pdll/Parser/rewrite-failure.pdll b/mlir/test/mlir-pdll/Parser/rewrite-failure.pdll
index 1cdb32b5d6b0e..dd8843d8f85d3 100644
--- a/mlir/test/mlir-pdll/Parser/rewrite-failure.pdll
+++ b/mlir/test/mlir-pdll/Parser/rewrite-failure.pdll
@@ -88,13 +88,6 @@ Rewrite Foo(arg: Value<type>){}
// -----
-Constraint ValueConstraint(value: Value);
-
-// CHECK: arguments and results are only permitted to use core constraints, such as `Attr`, `Op`, `Type`, `TypeRange`, `Value`, `ValueRange`
-Rewrite Foo(arg: ValueConstraint);
-
-// -----
-
// CHECK: expected `)` to end argument list
Rewrite Foo(arg: Value{}
@@ -139,13 +132,6 @@ Rewrite Foo() -> Value<type>){}
// -----
-Constraint ValueConstraint(value: Value);
-
-// CHECK: results are only permitted to use core constraints, such as `Attr`, `Op`, `Type`, `TypeRange`, `Value`, `ValueRange`
-Rewrite Foo() -> ValueConstraint;
-
-// -----
-
//===----------------------------------------------------------------------===//
// Native Rewrites
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list