[Mlir-commits] [mlir] a07b422 - [mlir][linalg] Fix `SemiFunctionType` custom parsing crash on missing `()` (#110365)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Nov 3 06:31:29 PST 2024
Author: Felix Schneider
Date: 2024-11-03T15:31:25+01:00
New Revision: a07b422e90174430213201d0b4b307f5ed089d3f
URL: https://github.com/llvm/llvm-project/commit/a07b422e90174430213201d0b4b307f5ed089d3f
DIFF: https://github.com/llvm/llvm-project/commit/a07b422e90174430213201d0b4b307f5ed089d3f.diff
LOG: [mlir][linalg] Fix `SemiFunctionType` custom parsing crash on missing `()` (#110365)
The `SemiFunctionType` custom assembly directive prints and parses a set of
argument and result types, where there is always exactly one argument type and zero
or more result types. If there are no result types, the argument type
can be written without enclosing parens in the assembly. If there is at
least one result type, the parens are mandatory.
This patch fixes a bug where omitting the parens around the argument
type of a `SemiFunctionType` with a non-optional result type would
crash the parser. It introduces a `bool` argument `resultOptional`
(defaulting to `true`) to the parser and printer; when it is `false`,
the parser requires the parens around the argument type and emits an
error if they are missing.
Fix https://github.com/llvm/llvm-project/issues/109128
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Dialect/Linalg/TransformOps/Syntax.h
mlir/include/mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.td
mlir/lib/Dialect/Linalg/TransformOps/Syntax.cpp
mlir/test/Dialect/Linalg/transform-ops-invalid.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
index cdc29d053e5a4b..2da52bbf861668 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
@@ -541,9 +541,10 @@ def MatchStructuredRankOp : Op<Transform_Dialect, "match.structured.rank", [
let arguments = (ins TransformHandleTypeInterface:$operand_handle);
let results = (outs TransformParamTypeInterface:$rank);
- let assemblyFormat =
- "$operand_handle attr-dict `:`"
- "custom<SemiFunctionType>(type($operand_handle), type($rank))";
+ let assemblyFormat = [{
+ $operand_handle attr-dict `:`
+ custom<SemiFunctionType>(type($operand_handle), type($rank), "false")
+ }];
let extraClassDeclaration = SingleOpMatcher.extraDeclaration;
}
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index abf446887c5442..25a98a16960f37 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -418,9 +418,10 @@ def GeneralizeOp : Op<Transform_Dialect, "structured.generalize",
let arguments = (ins TransformHandleTypeInterface:$target);
let results = (outs TransformHandleTypeInterface:$transformed);
- let assemblyFormat =
- "$target attr-dict `:` "
- "custom<SemiFunctionType>(type($target), type($transformed))";
+ let assemblyFormat = [{
+ $target attr-dict `:`
+ custom<SemiFunctionType>(type($target), type($transformed), "false")
+ }];
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
@@ -455,9 +456,10 @@ def SpecializeOp : Op<Transform_Dialect, "structured.specialize",
let arguments = (ins TransformHandleTypeInterface:$target);
let results = (outs TransformHandleTypeInterface:$transformed);
- let assemblyFormat =
- "$target attr-dict `:` "
- "custom<SemiFunctionType>(type($target), type($transformed))";
+ let assemblyFormat = [{
+ $target attr-dict `:`
+ custom<SemiFunctionType>(type($target), type($transformed), "false")
+ }];
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
@@ -500,7 +502,7 @@ def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
let assemblyFormat = [{
$target
(`iterator_interchange` `=` $iterator_interchange^)? attr-dict
- `:` custom<SemiFunctionType>(type($target), type($transformed))
+ `:` custom<SemiFunctionType>(type($target), type($transformed), "false")
}];
let hasVerifier = 1;
@@ -1233,9 +1235,10 @@ def PromoteOp : Op<Transform_Dialect, "structured.promote",
OptionalAttr<I64Attr>:$alignment);
let results = (outs TransformHandleTypeInterface:$transformed);
- let assemblyFormat =
- "$target attr-dict `:`"
- "custom<SemiFunctionType>(type($target), type($transformed))";
+ let assemblyFormat = [{
+ $target attr-dict `:`
+ custom<SemiFunctionType>(type($target), type($transformed), "false")
+ }];
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
@@ -1269,9 +1272,10 @@ def ReplaceOp : Op<Transform_Dialect, "structured.replace",
let arguments = (ins TransformHandleTypeInterface:$target);
let results = (outs TransformHandleTypeInterface:$replacement);
let regions = (region SizedRegion<1>:$bodyRegion);
- let assemblyFormat =
- "$target attr-dict-with-keyword regions `:` "
- "custom<SemiFunctionType>(type($target), type($replacement))";
+ let assemblyFormat = [{
+ $target attr-dict-with-keyword regions `:`
+ custom<SemiFunctionType>(type($target), type($replacement), "false")
+ }];
let hasVerifier = 1;
}
@@ -1310,9 +1314,10 @@ def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
let arguments = (ins TransformHandleTypeInterface:$target);
let results = (outs TransformHandleTypeInterface:$result);
- let assemblyFormat =
- "$target attr-dict `:`"
- "custom<SemiFunctionType>(type($target), type($result))";
+ let assemblyFormat = [{
+ $target attr-dict `:`
+ custom<SemiFunctionType>(type($target), type($result), "false")
+ }];
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/Syntax.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/Syntax.h
index 50e55e72226120..595e8aac1045fe 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/Syntax.h
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/Syntax.h
@@ -30,7 +30,7 @@ class Operation;
/// the argument type in absence of result types, and does not accept the
/// trailing `-> ()` construct, which makes the syntax nicer for operations.
ParseResult parseSemiFunctionType(OpAsmParser &parser, Type &argumentType,
- Type &resultType);
+ Type &resultType, bool resultOptional = true);
ParseResult parseSemiFunctionType(OpAsmParser &parser, Type &argumentType,
SmallVectorImpl<Type> &resultTypes);
@@ -40,7 +40,8 @@ ParseResult parseSemiFunctionType(OpAsmParser &parser, Type &argumentType,
void printSemiFunctionType(OpAsmPrinter &printer, Operation *op,
Type argumentType, TypeRange resultType);
void printSemiFunctionType(OpAsmPrinter &printer, Operation *op,
- Type argumentType, Type resultType);
+ Type argumentType, Type resultType,
+ bool resultOptional = true);
} // namespace mlir
#endif // MLIR_DIALECT_LINALG_TRANSFORMOPS_SYNTAX_H
diff --git a/mlir/include/mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.td b/mlir/include/mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.td
index e340228795cdef..44eac878394b86 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.td
@@ -32,7 +32,10 @@ def MatchSparseInOut : Op<Transform_Dialect, "sparse_tensor.match.sparse_inout",
let arguments = (ins TransformHandleTypeInterface:$target);
let results = (outs TransformHandleTypeInterface:$result);
- let assemblyFormat = "$target attr-dict `:` custom<SemiFunctionType>(type($target), type($result))";
+ let assemblyFormat = [{
+ $target attr-dict `:`
+ custom<SemiFunctionType>(type($target), type($result), "false")
+ }];
let extraClassDeclaration = SingleOpMatcher.extraDeclaration # [{
::mlir::Value getOperandHandle() { return getTarget(); }
}];
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/Syntax.cpp b/mlir/lib/Dialect/Linalg/TransformOps/Syntax.cpp
index 7ba0a6eb68f48c..266c9ad3314a32 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/Syntax.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/Syntax.cpp
@@ -12,9 +12,13 @@
using namespace mlir;
ParseResult mlir::parseSemiFunctionType(OpAsmParser &parser, Type &argumentType,
- Type &resultType) {
+ Type &resultType, bool resultOptional) {
argumentType = resultType = nullptr;
- bool hasLParen = parser.parseOptionalLParen().succeeded();
+
+ bool hasLParen = resultOptional ? parser.parseOptionalLParen().succeeded()
+ : parser.parseLParen().succeeded();
+ if (!resultOptional && !hasLParen)
+ return failure();
if (parser.parseType(argumentType).failed())
return failure();
if (!hasLParen)
@@ -69,7 +73,9 @@ void mlir::printSemiFunctionType(OpAsmPrinter &printer, Operation *op,
}
void mlir::printSemiFunctionType(OpAsmPrinter &printer, Operation *op,
- Type argumentType, Type resultType) {
+ Type argumentType, Type resultType,
+ bool resultOptional) {
+ assert(resultOptional || resultType != nullptr);
return printSemiFunctionType(printer, op, argumentType,
resultType ? TypeRange(resultType)
: TypeRange());
diff --git a/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir b/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir
index a30b56c7c58e8f..fbebb97a11983e 100644
--- a/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir
+++ b/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir
@@ -92,3 +92,11 @@ transform.sequence failures(propagate) {
transform.structured.vectorize %arg0 vector_sizes [%0 : !transform.param<i64>, 2] : !transform.any_op, !transform.param<i64>
}
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+ // expected-error @below {{expected '('}}
+ %res = transform.structured.generalize %arg0 : !transform.any_op -> !transform.any_op
+}
More information about the Mlir-commits
mailing list