[Mlir-commits] [mlir] [mlir][linalg] Fix `SemiFunctionType` custom parsing crash on missing `()` (PR #110365)

Felix Schneider llvmlistbot at llvm.org
Sat Sep 28 07:13:38 PDT 2024


https://github.com/ubfx created https://github.com/llvm/llvm-project/pull/110365

The `SemiFunctionType` allows printing/parsing a set of argument and result types, where there is always exactly one argument type and zero or more result types. If there are no result types, the argument type can be written without enclosing parens in the assembly. If there is at least one result type, the parens are mandatory.

This patch fixes a bug where omitting the parens around the argument type of a `SemiFunctionType` with a non-optional result type would crash the parser. It introduces a `bool` argument `resultOptional` to the parser and printer. When it is `false`, the parser requires the parens around the argument type and emits an error if they are missing instead of crashing, and the printer asserts that a result type is present.

Fix https://github.com/llvm/llvm-project/issues/109128

>From de1bc49cf714106d0f23bcd4d2aa483cc1ecd80b Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Sat, 28 Sep 2024 15:26:18 +0200
Subject: [PATCH] [mlir][linalg] Fix `SemiFunctionType` custom parsing crash

The `SemiFunctionType` allows printing/parsing a set of argument and
result types, where there is always exactly one argument type and zero
or more result types. If there are no result types, the argument type
can be written without enclosing parens in the assembly. If there is at
least one result type, the parens are mandatory.

This patch fixes a bug where omitting the parens around the argument type
of a `SemiFunctionType` with a non-optional result type would crash the
parser. It introduces a `bool` argument `resultOptional` to the parser
and printer. When it is `false`, the parser requires the parens around
the argument type and emits an error if they are missing instead of
crashing, and the printer asserts that a result type is present.

Fix https://github.com/llvm/llvm-project/issues/109128
---
 .../Linalg/TransformOps/LinalgMatchOps.td     |  7 ++--
 .../Linalg/TransformOps/LinalgTransformOps.td | 37 +++++++++++--------
 .../mlir/Dialect/Linalg/TransformOps/Syntax.h |  5 ++-
 .../TransformOps/SparseTensorTransformOps.td  |  5 ++-
 .../Dialect/Linalg/TransformOps/Syntax.cpp    | 12 ++++--
 .../Dialect/Linalg/transform-ops-invalid.mlir |  8 ++++
 6 files changed, 49 insertions(+), 25 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
index cdc29d053e5a4b..2da52bbf861668 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
@@ -541,9 +541,10 @@ def MatchStructuredRankOp : Op<Transform_Dialect, "match.structured.rank", [
 
   let arguments = (ins TransformHandleTypeInterface:$operand_handle);
   let results = (outs TransformParamTypeInterface:$rank);
-  let assemblyFormat =
-      "$operand_handle attr-dict `:`"
-      "custom<SemiFunctionType>(type($operand_handle), type($rank))";
+  let assemblyFormat = [{
+      $operand_handle attr-dict `:`
+      custom<SemiFunctionType>(type($operand_handle), type($rank), "false")
+  }];
 
   let extraClassDeclaration = SingleOpMatcher.extraDeclaration;
 }
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 106f0d79d9792d..efbba1eb065dca 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -382,9 +382,10 @@ def GeneralizeOp : Op<Transform_Dialect, "structured.generalize",
 
   let arguments = (ins TransformHandleTypeInterface:$target);
   let results = (outs TransformHandleTypeInterface:$transformed);
-  let assemblyFormat =
-      "$target attr-dict `:` "
-      "custom<SemiFunctionType>(type($target), type($transformed))";
+  let assemblyFormat = [{
+      $target attr-dict `:` 
+      custom<SemiFunctionType>(type($target), type($transformed), "false")
+  }];
 
   let extraClassDeclaration = [{
     ::mlir::DiagnosedSilenceableFailure applyToOne(
@@ -419,9 +420,10 @@ def SpecializeOp : Op<Transform_Dialect, "structured.specialize",
 
   let arguments = (ins TransformHandleTypeInterface:$target);
   let results = (outs TransformHandleTypeInterface:$transformed);
-  let assemblyFormat =
-      "$target attr-dict `:` "
-      "custom<SemiFunctionType>(type($target), type($transformed))";
+  let assemblyFormat = [{
+      $target attr-dict `:` 
+      custom<SemiFunctionType>(type($target), type($transformed), "false")
+  }];
 
   let extraClassDeclaration = [{
     ::mlir::DiagnosedSilenceableFailure applyToOne(
@@ -464,7 +466,7 @@ def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
   let assemblyFormat = [{
     $target
     (`iterator_interchange` `=` $iterator_interchange^)? attr-dict
-    `:` custom<SemiFunctionType>(type($target), type($transformed))
+    `:` custom<SemiFunctionType>(type($target), type($transformed), "false")
   }];
   let hasVerifier = 1;
 
@@ -1197,9 +1199,10 @@ def PromoteOp : Op<Transform_Dialect, "structured.promote",
                        OptionalAttr<I64Attr>:$alignment);
   let results = (outs TransformHandleTypeInterface:$transformed);
 
-  let assemblyFormat =
-    "$target attr-dict `:`"
-    "custom<SemiFunctionType>(type($target), type($transformed))";
+  let assemblyFormat = [{
+    $target attr-dict `:`
+    custom<SemiFunctionType>(type($target), type($transformed), "false")
+  }];
 
   let extraClassDeclaration = [{
     ::mlir::DiagnosedSilenceableFailure applyToOne(
@@ -1233,9 +1236,10 @@ def ReplaceOp : Op<Transform_Dialect, "structured.replace",
   let arguments = (ins TransformHandleTypeInterface:$target);
   let results = (outs TransformHandleTypeInterface:$replacement);
   let regions = (region SizedRegion<1>:$bodyRegion);
-  let assemblyFormat =
-      "$target attr-dict-with-keyword regions `:` "
-      "custom<SemiFunctionType>(type($target), type($replacement))";
+  let assemblyFormat = [{
+      $target attr-dict-with-keyword regions `:`
+      custom<SemiFunctionType>(type($target), type($replacement), "false")
+  }];
   let hasVerifier = 1;
 }
 
@@ -1274,9 +1278,10 @@ def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
   let arguments = (ins TransformHandleTypeInterface:$target);
   let results = (outs TransformHandleTypeInterface:$result);
 
-  let assemblyFormat =
-    "$target attr-dict `:`"
-    "custom<SemiFunctionType>(type($target), type($result))";
+  let assemblyFormat = [{
+    $target attr-dict `:`
+    custom<SemiFunctionType>(type($target), type($result), "false")
+  }];
 
   let extraClassDeclaration = [{
     ::mlir::DiagnosedSilenceableFailure applyToOne(
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/Syntax.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/Syntax.h
index 50e55e72226120..595e8aac1045fe 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/Syntax.h
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/Syntax.h
@@ -30,7 +30,7 @@ class Operation;
 /// the argument type in absence of result types, and does not accept the
 /// trailing `-> ()` construct, which makes the syntax nicer for operations.
 ParseResult parseSemiFunctionType(OpAsmParser &parser, Type &argumentType,
-                                  Type &resultType);
+                                  Type &resultType, bool resultOptional = true);
 ParseResult parseSemiFunctionType(OpAsmParser &parser, Type &argumentType,
                                   SmallVectorImpl<Type> &resultTypes);
 
@@ -40,7 +40,8 @@ ParseResult parseSemiFunctionType(OpAsmParser &parser, Type &argumentType,
 void printSemiFunctionType(OpAsmPrinter &printer, Operation *op,
                            Type argumentType, TypeRange resultType);
 void printSemiFunctionType(OpAsmPrinter &printer, Operation *op,
-                           Type argumentType, Type resultType);
+                           Type argumentType, Type resultType,
+                           bool resultOptional = true);
 } // namespace mlir
 
 #endif // MLIR_DIALECT_LINALG_TRANSFORMOPS_SYNTAX_H
diff --git a/mlir/include/mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.td b/mlir/include/mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.td
index e340228795cdef..44eac878394b86 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.td
@@ -32,7 +32,10 @@ def MatchSparseInOut : Op<Transform_Dialect, "sparse_tensor.match.sparse_inout",
   let arguments = (ins TransformHandleTypeInterface:$target);
   let results = (outs TransformHandleTypeInterface:$result);
 
-  let assemblyFormat = "$target attr-dict `:` custom<SemiFunctionType>(type($target), type($result))";
+  let assemblyFormat = [{
+    $target attr-dict `:`
+    custom<SemiFunctionType>(type($target), type($result), "false")
+  }];
   let extraClassDeclaration = SingleOpMatcher.extraDeclaration # [{
     ::mlir::Value getOperandHandle() { return getTarget(); }
   }];
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/Syntax.cpp b/mlir/lib/Dialect/Linalg/TransformOps/Syntax.cpp
index 7ba0a6eb68f48c..266c9ad3314a32 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/Syntax.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/Syntax.cpp
@@ -12,9 +12,13 @@
 using namespace mlir;
 
 ParseResult mlir::parseSemiFunctionType(OpAsmParser &parser, Type &argumentType,
-                                        Type &resultType) {
+                                        Type &resultType, bool resultOptional) {
   argumentType = resultType = nullptr;
-  bool hasLParen = parser.parseOptionalLParen().succeeded();
+
+  bool hasLParen = resultOptional ? parser.parseOptionalLParen().succeeded()
+                                  : parser.parseLParen().succeeded();
+  if (!resultOptional && !hasLParen)
+    return failure();
   if (parser.parseType(argumentType).failed())
     return failure();
   if (!hasLParen)
@@ -69,7 +73,9 @@ void mlir::printSemiFunctionType(OpAsmPrinter &printer, Operation *op,
 }
 
 void mlir::printSemiFunctionType(OpAsmPrinter &printer, Operation *op,
-                                 Type argumentType, Type resultType) {
+                                 Type argumentType, Type resultType,
+                                 bool resultOptional) {
+  assert(resultOptional || resultType != nullptr);
   return printSemiFunctionType(printer, op, argumentType,
                                resultType ? TypeRange(resultType)
                                           : TypeRange());
diff --git a/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir b/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir
index e86d4962530a9a..4bbd9bfd1443f4 100644
--- a/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir
+++ b/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir
@@ -92,3 +92,11 @@ transform.sequence failures(propagate) {
   transform.structured.vectorize %arg0 vector_sizes [%0 : !transform.param<i64>, 2] : !transform.any_op, !transform.param<i64>
 
 }
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  // expected-error@below {{expected '('}}
+  %res = transform.structured.generalize %arg0 : !transform.any_op -> !transform.any_op 
+}



More information about the Mlir-commits mailing list