[Mlir-commits] [mlir] 458542f - [mlir][linalg] Relax structured op region filler check (#123741)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jan 28 00:56:03 PST 2025


Author: Adam Siemieniuk
Date: 2025-01-28T09:55:59+01:00
New Revision: 458542f454cdb769801f0b6459405b429503e00a

URL: https://github.com/llvm/llvm-project/commit/458542f454cdb769801f0b6459405b429503e00a
DIFF: https://github.com/llvm/llvm-project/commit/458542f454cdb769801f0b6459405b429503e00a.diff

LOG: [mlir][linalg] Relax structured op region filler check (#123741)

Removes the assert on output types from the structured op region filler to
allow more graceful error handling.

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index c13b663dbf05b1..7c45f8805c205e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -121,14 +121,11 @@ using RegionBuilderFn = llvm::function_ref<void(ImplicitLocOpBuilder &, Block &,
 /// `regionBuilder`. The method is used by both named structured ops created by
 /// ods-gen and by manually defined C++ ops. It is called by both builders and
 /// parsers and creates a block with arguments corresponding to the elemental
-/// types of `inputTypes` and `outputTypes`. All output types are asserted to be
-/// ShapedType.
+/// types of `inputTypes` and `outputTypes`.
 static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
                                    TypeRange inputTypes, TypeRange outputTypes,
                                    ArrayRef<NamedAttribute> attrs,
                                    RegionBuilderFn regionBuilder) {
-  assert(llvm::all_of(outputTypes, llvm::IsaPred<ShapedType>));
-
   SmallVector<Type, 8> argTypes;
   SmallVector<Location, 8> argLocs;
   for (auto containers : {inputTypes, outputTypes}) {

diff  --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index a59472377a732c..0853856d933035 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -370,6 +370,24 @@ func.func @invalid_static_matmul(%arg0: memref<2x4xf32>, %arg1: memref<3x4xf32>,
 
 // -----
 
+func.func @invalid_scalar_input_matmul(%arg0: f32, %arg1: memref<3x4xf32>, %arg2: memref<2x4xf32>) {
+  // expected-error @+1 {{'linalg.matmul' op expected operand rank (0) to match the result rank of indexing_map #0 (2)}}
+  linalg.matmul ins(%arg0, %arg1 : f32, memref<3x4xf32>)
+                outs(%arg2 : memref<2x4xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_scalar_output_matmul(%arg0: memref<2x3xf32>, %arg1: memref<3x4xf32>, %arg2: f32) {
+  // expected-error @+1 {{'linalg.matmul' op operand #2 must be variadic of shaped of any type values, but got 'f32'}}
+  linalg.matmul ins(%arg0, %arg1 : memref<2x3xf32>, memref<3x4xf32>)
+                outs(%arg2 : f32)
+  return
+}
+
+// -----
+
 func.func @invalid_indexing_maps_matmul(%arg0: memref<2x4xf32>, %arg1: memref<3x4xf32>, %arg2: memref<2x4xf32>) {
   // expected-error @+1 {{expected attribute value}}
   linalg.matmul indexing_maps = [


        


More information about the Mlir-commits mailing list