[Mlir-commits] [mlir] ec92a12 - [mlir:PDLL] Don't require users to provide operands/results when all are variadic
River Riddle
llvmlistbot at llvm.org
Tue Nov 8 01:58:30 PST 2022
Author: River Riddle
Date: 2022-11-08T01:57:58-08:00
New Revision: ec92a125acc66c67e8d86dcf6e200fe34d204b3d
URL: https://github.com/llvm/llvm-project/commit/ec92a125acc66c67e8d86dcf6e200fe34d204b3d
DIFF: https://github.com/llvm/llvm-project/commit/ec92a125acc66c67e8d86dcf6e200fe34d204b3d.diff
LOG: [mlir:PDLL] Don't require users to provide operands/results when all are variadic
When all operands or results are variadic, providing zero values is perfectly valid,
so we shouldn't force the user to provide values in this case. For example,
when creating a call or a return operation we often don't want or need to provide
any values.
Differential Revision: https://reviews.llvm.org/D133721
Added:
Modified:
mlir/lib/Tools/PDLL/Parser/Parser.cpp
mlir/test/lib/Transforms/TestDialectConversion.pdll
mlir/test/mlir-pdll/Parser/expr.pdll
mlir/test/mlir-pdll/Parser/include/ops.td
Removed:
################################################################################
diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
index ffa7f0cf52ff5..de19f577133e1 100644
--- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp
+++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
@@ -426,23 +426,23 @@ class Parser {
FailureOr<ast::OperationExpr *>
createOperationExpr(SMRange loc, const ast::OpNameDecl *name,
OpResultTypeContext resultTypeContext,
- MutableArrayRef<ast::Expr *> operands,
+ SmallVectorImpl<ast::Expr *> &operands,
MutableArrayRef<ast::NamedAttributeDecl *> attributes,
- MutableArrayRef<ast::Expr *> results);
+ SmallVectorImpl<ast::Expr *> &results);
LogicalResult
validateOperationOperands(SMRange loc, Optional<StringRef> name,
const ods::Operation *odsOp,
- MutableArrayRef<ast::Expr *> operands);
+ SmallVectorImpl<ast::Expr *> &operands);
LogicalResult validateOperationResults(SMRange loc, Optional<StringRef> name,
const ods::Operation *odsOp,
- MutableArrayRef<ast::Expr *> results);
+ SmallVectorImpl<ast::Expr *> &results);
void checkOperationResultTypeInferrence(SMRange loc, StringRef name,
const ods::Operation *odsOp);
LogicalResult validateOperationOperandsOrResults(
StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc,
- Optional<StringRef> name, MutableArrayRef<ast::Expr *> values,
+ Optional<StringRef> name, SmallVectorImpl<ast::Expr *> &values,
ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
- ast::Type rangeTy);
+ ast::RangeType rangeTy);
FailureOr<ast::TupleExpr *> createTupleExpr(SMRange loc,
ArrayRef<ast::Expr *> elements,
ArrayRef<StringRef> elementNames);
@@ -2851,9 +2851,9 @@ FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
SMRange loc, const ast::OpNameDecl *name,
OpResultTypeContext resultTypeContext,
- MutableArrayRef<ast::Expr *> operands,
+ SmallVectorImpl<ast::Expr *> &operands,
MutableArrayRef<ast::NamedAttributeDecl *> attributes,
- MutableArrayRef<ast::Expr *> results) {
+ SmallVectorImpl<ast::Expr *> &results) {
Optional<StringRef> opNameRef = name->getName();
const ods::Operation *odsOp = lookupODSOperation(opNameRef);
@@ -2896,7 +2896,7 @@ FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
LogicalResult
Parser::validateOperationOperands(SMRange loc, Optional<StringRef> name,
const ods::Operation *odsOp,
- MutableArrayRef<ast::Expr *> operands) {
+ SmallVectorImpl<ast::Expr *> &operands) {
return validateOperationOperandsOrResults(
"operand", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name,
operands, odsOp ? odsOp->getOperands() : llvm::None, valueTy,
@@ -2906,7 +2906,7 @@ Parser::validateOperationOperands(SMRange loc, Optional<StringRef> name,
LogicalResult
Parser::validateOperationResults(SMRange loc, Optional<StringRef> name,
const ods::Operation *odsOp,
- MutableArrayRef<ast::Expr *> results) {
+ SmallVectorImpl<ast::Expr *> &results) {
return validateOperationOperandsOrResults(
"result", loc, odsOp ? odsOp->getLoc() : Optional<SMRange>(), name,
results, odsOp ? odsOp->getResults() : llvm::None, typeTy, typeRangeTy);
@@ -2956,9 +2956,9 @@ void Parser::checkOperationResultTypeInferrence(SMRange loc, StringRef opName,
LogicalResult Parser::validateOperationOperandsOrResults(
StringRef groupName, SMRange loc, Optional<SMRange> odsOpLoc,
- Optional<StringRef> name, MutableArrayRef<ast::Expr *> values,
+ Optional<StringRef> name, SmallVectorImpl<ast::Expr *> &values,
ArrayRef<ods::OperandOrResult> odsValues, ast::Type singleTy,
- ast::Type rangeTy) {
+ ast::RangeType rangeTy) {
// All operation types accept a single range parameter.
if (values.size() == 1) {
if (failed(convertExpressionTo(values[0], rangeTy)))
@@ -2969,14 +2969,56 @@ LogicalResult Parser::validateOperationOperandsOrResults(
/// If the operation has ODS information, we can more accurately verify the
/// values.
if (odsOpLoc) {
- if (odsValues.size() != values.size()) {
+ auto emitSizeMismatchError = [&] {
return emitErrorAndNote(
loc,
llvm::formatv("invalid number of {0} groups for `{1}`; expected "
"{2}, but got {3}",
groupName, *name, odsValues.size(), values.size()),
*odsOpLoc, llvm::formatv("see the definition of `{0}` here", *name));
+ };
+
+ // Handle the case where no values were provided.
+ if (values.empty()) {
+ // If we don't expect any on the ODS side, we are done.
+ if (odsValues.empty())
+ return success();
+
+ // If we do, check if we actually need to provide values (i.e. if any of
+ // the values are actually required).
+ unsigned numVariadic = 0;
+ for (const auto &odsValue : odsValues) {
+ if (!odsValue.isVariableLength())
+ return emitSizeMismatchError();
+ ++numVariadic;
+ }
+
+ // If we are in a non-rewrite context, we don't need to do anything more.
+ // Zero-values is a valid constraint on the operation.
+ if (parserContext != ParserContext::Rewrite)
+ return success();
+
+ // Otherwise, when in a rewrite we may need to provide values to match the
+ // ODS signature of the operation to create.
+
+ // If we only have one variadic value, just use an empty list.
+ if (numVariadic == 1)
+ return success();
+
+ // Otherwise, create dummy values for each of the entries so that we
+ // adhere to the ODS signature.
+ for (unsigned i = 0, e = odsValues.size(); i < e; ++i) {
+ values.push_back(
+ ast::RangeExpr::create(ctx, loc, /*elements=*/llvm::None, rangeTy));
+ }
+ return success();
}
+
+ // Verify that the number of values provided matches the number of value
+ // groups ODS expects.
+ if (odsValues.size() != values.size())
+ return emitSizeMismatchError();
+
auto diagFn = [&](ast::Diagnostic &diag) {
diag.attachNote(llvm::formatv("see the definition of `{0}` here", *name),
*odsOpLoc);
diff --git a/mlir/test/lib/Transforms/TestDialectConversion.pdll b/mlir/test/lib/Transforms/TestDialectConversion.pdll
index c29e852feeff3..a6cd21159d7d6 100644
--- a/mlir/test/lib/Transforms/TestDialectConversion.pdll
+++ b/mlir/test/lib/Transforms/TestDialectConversion.pdll
@@ -10,9 +10,8 @@
#include "mlir/Transforms/DialectConversion.pdll"
/// Change the result type of a producer.
-// FIXME: We shouldn't need to specify arguments for the result cast.
-Pattern => replace op<test.cast>(args: ValueRange) -> (results: TypeRange)
- with op<test.cast>(args) -> (convertTypes(results));
+Pattern => replace op<test.cast> -> (results: TypeRange)
+ with op<test.cast> -> (convertTypes(results));
/// Pass through test.return conversion.
Pattern => replace op<test.return>(args: ValueRange)
diff --git a/mlir/test/mlir-pdll/Parser/expr.pdll b/mlir/test/mlir-pdll/Parser/expr.pdll
index 6e68883f4edee..0736962dada78 100644
--- a/mlir/test/mlir-pdll/Parser/expr.pdll
+++ b/mlir/test/mlir-pdll/Parser/expr.pdll
@@ -213,6 +213,34 @@ Pattern {
// -----
+// Test that we don't need to provide values if all elements
+// are variadic.
+
+#include "include/ops.td"
+
+// CHECK: Module
+// CHECK: -OperationExpr {{.*}} Type<Op<test.multi_variadic>>
+// CHECK-NOT: `Operands`
+// CHECK-NOT: `Result Types`
+// CHECK: -OperationExpr {{.*}} Type<Op<test.all_variadic>>
+// CHECK-NOT: `Operands`
+// CHECK-NOT: `Result Types`
+// CHECK: -OperationExpr {{.*}} Type<Op<test.multi_variadic>>
+// CHECK: `Operands`
+// CHECK: -RangeExpr {{.*}} Type<ValueRange>
+// CHECK: -RangeExpr {{.*}} Type<ValueRange>
+// CHECK: `Result Types`
+// CHECK: -RangeExpr {{.*}} Type<TypeRange>
+// CHECK: -RangeExpr {{.*}} Type<TypeRange>
+Pattern {
+ rewrite op<test.multi_variadic>() -> () with {
+ op<test.all_variadic> -> ();
+ op<test.multi_variadic> -> ();
+ };
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// TupleExpr
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/mlir-pdll/Parser/include/ops.td b/mlir/test/mlir-pdll/Parser/include/ops.td
index 91c4122921109..575b4756e70f4 100644
--- a/mlir/test/mlir-pdll/Parser/include/ops.td
+++ b/mlir/test/mlir-pdll/Parser/include/ops.td
@@ -28,3 +28,8 @@ def OpAllVariadic : Op<Test_Dialect, "all_variadic"> {
def OpMultipleSingleResult : Op<Test_Dialect, "multiple_single_result"> {
let results = (outs I64:$result, I64:$result2);
}
+
+def OpMultiVariadic : Op<Test_Dialect, "multi_variadic"> {
+ let arguments = (ins Variadic<I64>:$operands, Variadic<I64>:$operand2);
+ let results = (outs Variadic<I64>:$results, Variadic<I64>:$results2);
+}
More information about the Mlir-commits
mailing list