[Mlir-commits] [mlir] [mlir][linalg] Fix crashes in parser on linalg ops without operands (PR #97944)

Felix Schneider llvmlistbot at llvm.org
Sun Jul 7 04:19:45 PDT 2024


https://github.com/ubfx updated https://github.com/llvm/llvm-project/pull/97944

>From bd56160def5f6ac4c2f25a63eb4bcee54f2630ef Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Sun, 7 Jul 2024 11:59:22 +0200
Subject: [PATCH 1/3] [mlir][linalg] Fix crashes in parser on linalg ops
 without operands

`parseDstStyleOp` treats both `ins()` and `outs()` as optional. The parsers
for `linalg.transpose`, `linalg.broadcast` and `linalg.map`, however,
assume that at least one operand is present in the `OperationState`, and
crash when none is.

This patch adds checks to the parsers that keep them from crashing when
no operands were parsed. Once the ops have been parsed successfully, the
verifier takes over and reports the missing operands.

Fix https://github.com/llvm/llvm-project/issues/97857
---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp |  9 ++++---
 mlir/test/Dialect/Linalg/invalid.mlir    | 31 ++++++++++++++++++++++++
 2 files changed, 37 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 57d126603ebd72..dec84ed5790969 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1356,8 +1356,10 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
     return failure();
 
   if (payloadOpName.has_value()) {
-    addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
-                         ArrayRef(result.operands).drop_back());
+    if (!result.operands.empty())
+      addBodyWithPayloadOp(parser, result, payloadOpName.value(),
+                           payloadOpAttrs,
+                           ArrayRef(result.operands).drop_back());
   } else {
     SmallVector<OpAsmParser::Argument> regionArgs;
     if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
@@ -1739,7 +1741,8 @@ static void buildIdentityRegion(OpBuilder &builder, Location loc,
                                 ValueRange outputs) {
   buildGenericRegion(builder, loc, region, inputs, outputs,
                      [](OpBuilder &b, Location loc, ValueRange args) {
-                       b.create<linalg::YieldOp>(loc, args[0]);
+                       if (!args.empty())
+                         b.create<linalg::YieldOp>(loc, args[0]);
                      });
 }
 
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 44c81c31ace0f9..cfd269e3e6e3a8 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -455,6 +455,18 @@ func.func @map_input_output_shape_mismatch(
 
 // -----
 
+func.func @map_no_operands(
+    %lhs: tensor<64xf32>, %rhs: tensor<64xf32>, %init: tensor<64xf32>)
+    -> tensor<64xf32> {
+  // This must not crash the parser.
+  linalg.map { arith.addf }
+  // expected-error @+1 {{cannot name an operation with no results}}
+  %add = linalg.map { arith.addf }
+  func.return %add : tensor<64xf32>
+}
+
+// -----
+
 func.func @reduce_input_vs_init_dimension_mismatch(
     %input: tensor<16x32x64xf32>,
     %init: tensor<16x64xf32>)  -> tensor<16x64xf32> {
@@ -676,6 +688,16 @@ func.func @transpose_input_init_rank_mismatch(%input: tensor<16x32xf32>,
 
 // -----
 
+func.func @transpose_no_operands() -> tensor<32x64x16xf32> {
+  // This must not crash the parser.
+  linalg.transpose permutation = [1, 0, 2]
+  // expected-error @+1 {{cannot name an operation with no results}}
+  %transpose = linalg.transpose permutation = [1, 0, 2]
+  func.return %transpose : tensor<32x64x16xf32>
+}
+
+// -----
+
 func.func @broadcast_input_dims_rank_mismatch(
     %input: tensor<4x16xf32>, %init: tensor<4x8x16xf32>)
     -> tensor<4x8x16xf32> {
@@ -725,6 +747,15 @@ func.func @broadcast_size_1_extension_not_supported(
       dimensions = [1]
   func.return %bcast : tensor<4x?x16xf32>
 }
+// -----
+
+func.func @broadcast_no_operands()
+    -> tensor<4x?x16xf32> {
+  linalg.broadcast dimensions = [1]
+  // expected-error @+1 {{cannot name an operation with no results}}
+  %broadcast = linalg.broadcast dimensions = [1]
+  func.return %broadcast : tensor<32x64x16xf32>
+}
 
 // -----
 

>From 148e1ec43b9aa6acfece7d3917b22de89cc010cd Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Sun, 7 Jul 2024 12:51:35 +0200
Subject: [PATCH 2/3] Add more tests

---
 mlir/test/Dialect/Linalg/invalid.mlir | 56 +++++++++++++++++++++++----
 1 file changed, 49 insertions(+), 7 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index cfd269e3e6e3a8..941851464a0447 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -455,11 +455,23 @@ func.func @map_input_output_shape_mismatch(
 
 // -----
 
-func.func @map_no_operands(
+func.func @map_no_operands1() {
+  // expected-error @+1 {{'linalg.map' op requires one region}}
+  linalg.map { arith.addf }
+}
+
+// -----
+
+func.func @map_no_operands2() {
+  // expected-error @+1 {{'linalg.map' op requires one region}}
+  "linalg.map"() : () -> ()
+}
+
+// -----
+
+func.func @map_no_operands3(
     %lhs: tensor<64xf32>, %rhs: tensor<64xf32>, %init: tensor<64xf32>)
     -> tensor<64xf32> {
-  // This must not crash the parser.
-  linalg.map { arith.addf }
   // expected-error @+1 {{cannot name an operation with no results}}
   %add = linalg.map { arith.addf }
   func.return %add : tensor<64xf32>
@@ -688,9 +700,23 @@ func.func @transpose_input_init_rank_mismatch(%input: tensor<16x32xf32>,
 
 // -----
 
-func.func @transpose_no_operands() -> tensor<32x64x16xf32> {
-  // This must not crash the parser.
+func.func @transpose_no_operands1() {
+  // expected-error @+1 {{'linalg.transpose' op expected 2 operands, but found 0}}
   linalg.transpose permutation = [1, 0, 2]
+}
+
+// -----
+
+func.func @transpose_no_operands2() {
+  // expected-error @+1 {{'linalg.transpose' op expected 2 operands, but found 0}}
+  "linalg.transpose"() <{permutation = array<i64: 1, 0, 2>}> ({
+    ^bb0:
+  }) : () -> ()
+}
+
+// -----
+
+func.func @transpose_no_operands3() -> tensor<32x64x16xf32> {
   // expected-error @+1 {{cannot name an operation with no results}}
   %transpose = linalg.transpose permutation = [1, 0, 2]
   func.return %transpose : tensor<32x64x16xf32>
@@ -747,11 +773,27 @@ func.func @broadcast_size_1_extension_not_supported(
       dimensions = [1]
   func.return %bcast : tensor<4x?x16xf32>
 }
+
 // -----
 
-func.func @broadcast_no_operands()
-    -> tensor<4x?x16xf32> {
+func.func @broadcast_no_operands1() {
+  // expected-error @+1 {{'linalg.broadcast' op expected 2 operands, but found 0}}
   linalg.broadcast dimensions = [1]
+}
+
+// -----
+
+func.func @broadcast_no_operands2() {
+  // expected-error @+1 {{'linalg.broadcast' op expected 2 operands, but found 0}}
+  "linalg.broadcast"() <{dimensions = array<i64: 1>}> ({
+    ^bb0:
+  }) : () -> ()
+}
+
+// -----
+
+func.func @broadcast_no_operands3()
+    -> tensor<4x?x16xf32> {
   // expected-error @+1 {{cannot name an operation with no results}}
   %broadcast = linalg.broadcast dimensions = [1]
   func.return %broadcast : tensor<32x64x16xf32>

>From a295f8b9c2a7bc63359bedec34a1ba71c368f93a Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Sun, 7 Jul 2024 13:19:28 +0200
Subject: [PATCH 3/3] Improve error message for linalg.map

---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 2 ++
 mlir/test/Dialect/Linalg/invalid.mlir    | 8 +++++---
 2 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index dec84ed5790969..0754bd95a90f73 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1360,6 +1360,8 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
       addBodyWithPayloadOp(parser, result, payloadOpName.value(),
                            payloadOpAttrs,
                            ArrayRef(result.operands).drop_back());
+    else
+      result.addRegion();
   } else {
     SmallVector<OpAsmParser::Argument> regionArgs;
     if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 941851464a0447..213ef6c7b2616d 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -456,15 +456,17 @@ func.func @map_input_output_shape_mismatch(
 // -----
 
 func.func @map_no_operands1() {
-  // expected-error @+1 {{'linalg.map' op requires one region}}
+  // expected-error @+1 {{'linalg.map' op expected 1 or more operands, but found 0}}
   linalg.map { arith.addf }
 }
 
 // -----
 
 func.func @map_no_operands2() {
-  // expected-error @+1 {{'linalg.map' op requires one region}}
-  "linalg.map"() : () -> ()
+  // expected-error @+1 {{'linalg.map' op expected 1 or more operands, but found 0}}
+  "linalg.map"() ({
+    ^bb0:
+  }) : () -> ()
 }
 
 // -----



More information about the Mlir-commits mailing list