[Mlir-commits] [mlir] f584633 - Added static verification for Linalg Ops.

Hanhan Wang llvmlistbot at llvm.org
Tue Mar 30 07:11:08 PDT 2021


Author: Inho Seo
Date: 2021-03-30T07:10:17-07:00
New Revision: f58463345415508b1fb5e3d35330ed288f1a0357

URL: https://github.com/llvm/llvm-project/commit/f58463345415508b1fb5e3d35330ed288f1a0357
DIFF: https://github.com/llvm/llvm-project/commit/f58463345415508b1fb5e3d35330ed288f1a0357.diff

LOG: Added static verification for Linalg Ops.

This verification checks whether the indices computed for statically
shaped operands of Linalg ops would access memory out of bounds.
Dynamically shaped operands cannot be checked at this stage; they can
only be checked at runtime.

Several existing Linalg op test cases turned out to be invalid under this
check; they have been fixed.

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D98390

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
    mlir/test/Dialect/Linalg/fusion-2-level.mlir
    mlir/test/Dialect/Linalg/generalize-named-ops.mlir
    mlir/test/Dialect/Linalg/invalid.mlir
    mlir/test/Dialect/Linalg/named-ops.mlir
    mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir
    mlir/test/Dialect/Linalg/sparse_nd.mlir
    mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
    mlir/test/Dialect/Linalg/tile-indexed-generic.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index bb87f205b68b..f1bf22cf3d69 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -433,5 +433,54 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
     ++idx;
   }
 
+  // Check if given shapes match to inferred shapes.
+  Optional<SmallVector<int64_t, 4>> loopRanges = linalgOp.getStaticLoopRanges();
+  if (!loopRanges)
+    return linalgOp.emitError("unable to find loop range for operation");
+
+  // Verify only static cases since we can't get exact dimension sizes and loop
+  // ranges for dynamic cases in this stage.
+  if (llvm::none_of(*loopRanges, [](int64_t &range) {
+        return range == ShapedType::kDynamicSize;
+      })) {
+    for (int64_t &range : *loopRanges)
+      range -= 1;
+    for (const auto &en : llvm::enumerate(linalgOp.getShapedOperandTypes())) {
+      auto indices = indexingMaps[en.index()].compose(*loopRanges);
+      for (auto j : llvm::seq<unsigned>(0, en.value().getRank())) {
+
+        // Ignore dynamic dimension or the case that the inferred last index is
+        // zero. The index is increasing or decreasing in Linalg, for example,
+        // the last index should be `0` or `size-1`. We only check the cases
+        // that are non-zero because most of cases are increasing and it is too
+        // expensive to find the shape of decreasing cases.
+        if (en.value().isDynamicDim(j) || indices[j] == 0)
+          continue;
+
+        // The size of shaped operands and inferred dimension size should be
+        // same. But, for now we check if the inferred sizes are in boundary of
+        // shaped operands' size or not in case that Affine Expressions are
+        // complicated such as d0 * 3 + d1 since it is not easy to handle the
+        // issues.
+        auto inferredSize = indices[j] + 1;
+        auto shapedDimSize = en.value().getDimSize(j);
+        if (indexingMaps[en.index()].getResult(j).dyn_cast<AffineDimExpr>()) {
+          if (inferredSize != shapedDimSize) {
+            return linalgOp.emitOpError("inferred shaped operand #")
+                   << en.index() << " has shape's dimension #" << j << " to be "
+                   << inferredSize << ", but found " << shapedDimSize;
+          }
+        } else {
+          if (inferredSize > shapedDimSize) {
+            return linalgOp.emitOpError("inferred shaped operand #")
+                   << en.index() << " has shape's dimension #" << j
+                   << " to be greater than or equal to " << inferredSize
+                   << ", but found " << shapedDimSize;
+          }
+        }
+      }
+    }
+  }
+
   return success();
 }

diff  --git a/mlir/test/Dialect/Linalg/fusion-2-level.mlir b/mlir/test/Dialect/Linalg/fusion-2-level.mlir
index 3d0a09a22646..e19639479205 100644
--- a/mlir/test/Dialect/Linalg/fusion-2-level.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-2-level.mlir
@@ -28,7 +28,7 @@ func @f1(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>, %B: memref<?x?xf32, of
             scf.for %arg10 = %c0 to %10 step %c4 {
               %14 = memref.subview %5[%arg8, %arg10][%c2, %c4][%c1, %c1] : memref<?x?xf32, offset: ?, strides: [?, ?]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
               %16 = memref.subview %7[%arg10, %arg9][%c4, %c3][%c1, %c1]: memref<?x?xf32, offset: ?, strides: [?, ?]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
-              %17 = memref.subview %8[%arg8, %arg9][%c2, %c4][%c1, %c1] : memref<?x?xf32, offset: ?, strides: [?, ?]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
+              %17 = memref.subview %8[%arg8, %arg9][%c2, %c3][%c1, %c1] : memref<?x?xf32, offset: ?, strides: [?, ?]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
               linalg.matmul ins(%14, %16: memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>)
                            outs(%17: memref<?x?xf32, offset: ?, strides: [?, ?]>)
             }

diff  --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index 90b75abaca7d..7bd6dcb1404b 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt %s -split-input-file -linalg-generalize-named-ops | FileCheck %s
 
-func @generalize_conv(%input : memref<1x225x225x3xf32>, %filter: memref<3x3x3x32xf32>, %output: memref<1x112x112x32xf32>) {
-  linalg.conv(%filter, %input, %output) {dilations = [2, 3], strides = [4, 5]} : memref<3x3x3x32xf32>, memref<1x225x225x3xf32>, memref<1x112x112x32xf32>
+func @generalize_conv(%input : memref<1x449x562x3xf32>, %filter: memref<3x3x3x32xf32>, %output: memref<1x112x112x32xf32>) {
+  linalg.conv(%filter, %input, %output) {dilations = [2, 3], strides = [4, 5]} : memref<3x3x3x32xf32>, memref<1x449x562x3xf32>, memref<1x112x112x32xf32>
   return
 }
 
@@ -10,7 +10,7 @@ func @generalize_conv(%input : memref<1x225x225x3xf32>, %filter: memref<3x3x3x32
 // CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
 
 // CHECK: func @generalize_conv
-// CHECK-SAME:  %[[INPUT:.+]]: memref<1x225x225x3xf32>
+// CHECK-SAME:  %[[INPUT:.+]]: memref<1x449x562x3xf32>
 // CHECK-SAME: %[[FILTER:.+]]: memref<3x3x3x32xf32>
 // CHECK-SAME: %[[OUTPUT:.+]]: memref<1x112x112x32xf32>
 

diff  --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 28ff66abfe62..bdaf0ea351aa 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -703,3 +703,23 @@ func @illegal_fill_tensor_with_memref_return
   %0 = linalg.fill(%arg0, %arg1) : tensor<?x?xf32>, f32 -> memref<?x?xf32>
   return %0 : memref<?x?xf32>
 }
+
+// -----
+
+func @invalid_static_matmul(%arg0: memref<2x4xf32>, %arg1: memref<3x4xf32>, %arg2: memref<2x4xf32>) {
+  // expected-error @+1 {{inferred shaped operand #1 has shape's dimension #0 to be 4, but found 3}}
+  linalg.matmul ins(%arg0, %arg1 : memref<2x4xf32>, memref<3x4xf32>)
+                      outs(%arg2 :memref<2x4xf32>)
+  return
+}
+
+// -----
+
+func @invalid_static_2d_conv(%input : memref<1x3x4x2xf32>, %filter: memref<3x2x2x1xf32>, %output: memref<1x2x3x1xf32>) {
+  // expected-error @+1 {{inferred shaped operand #0 has shape's dimension #1 to be greater than or equal to 4, but found 3}}
+  linalg.conv_2d_input_nhwc_filter_hwcf
+    { dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+    ins(%input, %filter : memref<1x3x4x2xf32>, memref<3x2x2x1xf32>)
+    outs(%output : memref<1x2x3x1xf32>)
+  return
+}

diff  --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 499acc602b51..4e49afb891a7 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -282,15 +282,15 @@ func @conv_3d_input_ncdhw_filter_dhwcf(%input: memref<?x?x?x?x?xf32>, %filter: m
 // CHECK:         %{{.+}} = linalg.pooling_nhwc_sum
 // CHECK-SAME:      dilations = dense<1> : tensor<2xi64>
 // CHECK-SAME:      strides = dense<1> : tensor<2xi64>
-// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : tensor<1x6x6x1xf32>, tensor<3x3xf32>)
+// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : tensor<1x4x4x1xf32>, tensor<3x3xf32>)
 // CHECK-SAME:      outs(%{{.+}} : tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32>
-func @pooling_nhwc_sum_tensor(%input: tensor<1x6x6x1xf32>) -> tensor<1x2x2x1xf32> {
+func @pooling_nhwc_sum_tensor(%input: tensor<1x4x4x1xf32>) -> tensor<1x2x2x1xf32> {
   %fake = linalg.init_tensor [3, 3] : tensor<3x3xf32>
   %init = linalg.init_tensor [1, 2, 2, 1] : tensor<1x2x2x1xf32>
   %cst = constant 0.000000e+00 : f32
   %fill = linalg.fill(%init, %cst) : tensor<1x2x2x1xf32>, f32 -> tensor<1x2x2x1xf32>
   %res = linalg.pooling_nhwc_sum {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
-    ins(%input, %fake: tensor<1x6x6x1xf32>, tensor<3x3xf32>)
+    ins(%input, %fake: tensor<1x4x4x1xf32>, tensor<3x3xf32>)
     outs(%fill: tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32>
   return %res : tensor<1x2x2x1xf32>
 }
@@ -301,11 +301,11 @@ func @pooling_nhwc_sum_tensor(%input: tensor<1x6x6x1xf32>) -> tensor<1x2x2x1xf32
 // CHECK:         linalg.pooling_nhwc_sum
 // CHECK-SAME:      dilations = dense<1> : tensor<2xi64>
 // CHECK-SAME:      strides = dense<1> : tensor<2xi64>
-// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : memref<1x6x6x1xf32>, memref<3x3xf32>)
+// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : memref<1x4x4x1xf32>, memref<3x3xf32>)
 // CHECK-SAME:      outs(%{{.+}} : memref<1x2x2x1xf32>)
-func @pooling_nhwc_sum(%input: memref<1x6x6x1xf32>, %fake: memref<3x3xf32>, %output: memref<1x2x2x1xf32>) {
+func @pooling_nhwc_sum(%input: memref<1x4x4x1xf32>, %fake: memref<3x3xf32>, %output: memref<1x2x2x1xf32>) {
   linalg.pooling_nhwc_sum {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
-    ins(%input, %fake: memref<1x6x6x1xf32>, memref<3x3xf32>)
+    ins(%input, %fake: memref<1x4x4x1xf32>, memref<3x3xf32>)
     outs(%output: memref<1x2x2x1xf32>)
   return
 }
@@ -316,15 +316,15 @@ func @pooling_nhwc_sum(%input: memref<1x6x6x1xf32>, %fake: memref<3x3xf32>, %out
 // CHECK:         %{{.+}} = linalg.pooling_nhwc_max
 // CHECK-SAME:      dilations = dense<1> : tensor<2xi64>
 // CHECK-SAME:      strides = dense<1> : tensor<2xi64>
-// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : tensor<1x6x6x1xf32>, tensor<3x3xf32>)
+// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : tensor<1x4x4x1xf32>, tensor<3x3xf32>)
 // CHECK-SAME:      outs(%{{.+}} : tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32>
-func @pooling_nhwc_max_tensor(%input: tensor<1x6x6x1xf32>) -> tensor<1x2x2x1xf32> {
+func @pooling_nhwc_max_tensor(%input: tensor<1x4x4x1xf32>) -> tensor<1x2x2x1xf32> {
   %fake = linalg.init_tensor [3, 3] : tensor<3x3xf32>
   %init = linalg.init_tensor [1, 2, 2, 1] : tensor<1x2x2x1xf32>
   %cst = constant 0.000000e+00 : f32
   %fill = linalg.fill(%init, %cst) : tensor<1x2x2x1xf32>, f32 -> tensor<1x2x2x1xf32>
   %res = linalg.pooling_nhwc_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
-    ins(%input, %fake: tensor<1x6x6x1xf32>, tensor<3x3xf32>)
+    ins(%input, %fake: tensor<1x4x4x1xf32>, tensor<3x3xf32>)
     outs(%fill: tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32>
   return %res : tensor<1x2x2x1xf32>
 }
@@ -335,11 +335,11 @@ func @pooling_nhwc_max_tensor(%input: tensor<1x6x6x1xf32>) -> tensor<1x2x2x1xf32
 // CHECK:         linalg.pooling_nhwc_max
 // CHECK-SAME:      dilations = dense<1> : tensor<2xi64>
 // CHECK-SAME:      strides = dense<1> : tensor<2xi64>
-// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : memref<1x6x6x1xf32>, memref<3x3xf32>)
+// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : memref<1x4x4x1xf32>, memref<3x3xf32>)
 // CHECK-SAME:      outs(%{{.+}} : memref<1x2x2x1xf32>)
-func @pooling_nhwc_max(%input: memref<1x6x6x1xf32>, %fake: memref<3x3xf32>, %output: memref<1x2x2x1xf32>) {
+func @pooling_nhwc_max(%input: memref<1x4x4x1xf32>, %fake: memref<3x3xf32>, %output: memref<1x2x2x1xf32>) {
   linalg.pooling_nhwc_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
-    ins(%input, %fake: memref<1x6x6x1xf32>, memref<3x3xf32>)
+    ins(%input, %fake: memref<1x4x4x1xf32>, memref<3x3xf32>)
     outs(%output: memref<1x2x2x1xf32>)
   return
 }
@@ -350,15 +350,15 @@ func @pooling_nhwc_max(%input: memref<1x6x6x1xf32>, %fake: memref<3x3xf32>, %out
 // CHECK:         %{{.+}} = linalg.pooling_nhwc_min
 // CHECK-SAME:      dilations = dense<1> : tensor<2xi64>
 // CHECK-SAME:      strides = dense<1> : tensor<2xi64>
-// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : tensor<1x6x6x1xf32>, tensor<3x3xf32>)
+// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : tensor<1x4x4x1xf32>, tensor<3x3xf32>)
 // CHECK-SAME:      outs(%{{.+}} : tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32>
-func @pooling_nhwc_min_tensor(%input: tensor<1x6x6x1xf32>) -> tensor<1x2x2x1xf32> {
+func @pooling_nhwc_min_tensor(%input: tensor<1x4x4x1xf32>) -> tensor<1x2x2x1xf32> {
   %fake = linalg.init_tensor [3, 3] : tensor<3x3xf32>
   %init = linalg.init_tensor [1, 2, 2, 1] : tensor<1x2x2x1xf32>
   %cst = constant 0.000000e+00 : f32
   %fill = linalg.fill(%init, %cst) : tensor<1x2x2x1xf32>, f32 -> tensor<1x2x2x1xf32>
   %res = linalg.pooling_nhwc_min {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
-    ins(%input, %fake: tensor<1x6x6x1xf32>, tensor<3x3xf32>)
+    ins(%input, %fake: tensor<1x4x4x1xf32>, tensor<3x3xf32>)
     outs(%fill: tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32>
   return %res : tensor<1x2x2x1xf32>
 }
@@ -369,11 +369,11 @@ func @pooling_nhwc_min_tensor(%input: tensor<1x6x6x1xf32>) -> tensor<1x2x2x1xf32
 // CHECK:         linalg.pooling_nhwc_min
 // CHECK-SAME:      dilations = dense<1> : tensor<2xi64>
 // CHECK-SAME:      strides = dense<1> : tensor<2xi64>
-// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : memref<1x6x6x1xf32>, memref<3x3xf32>)
+// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : memref<1x4x4x1xf32>, memref<3x3xf32>)
 // CHECK-SAME:      outs(%{{.+}} : memref<1x2x2x1xf32>)
-func @pooling_nhwc_min(%input: memref<1x6x6x1xf32>, %fake: memref<3x3xf32>, %output: memref<1x2x2x1xf32>) {
+func @pooling_nhwc_min(%input: memref<1x4x4x1xf32>, %fake: memref<3x3xf32>, %output: memref<1x2x2x1xf32>) {
   linalg.pooling_nhwc_min {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
-    ins(%input, %fake: memref<1x6x6x1xf32>, memref<3x3xf32>)
+    ins(%input, %fake: memref<1x4x4x1xf32>, memref<3x3xf32>)
     outs(%output: memref<1x2x2x1xf32>)
   return
 }

diff  --git a/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir
index 382f6016ee2d..b1a04db826de 100644
--- a/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir
@@ -168,9 +168,9 @@ func @generic_op_021_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf3
 
 #map0 = affine_map<(d0, d1, d2) -> (d0)>
 #map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
-#map2 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
-#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-func @generic_op_120_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<5x7x3xf32> {
+#map2 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+func @generic_op_120_permutation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<5x7x3xf32> {
   %0 = linalg.tensor_reshape %arg0 [#map0, #map1] : tensor<3x35xf32> into tensor<3x5x7xf32>
   %1 = linalg.init_tensor [5, 7, 3] : tensor<5x7x3xf32>
   %2 = linalg.generic
@@ -183,9 +183,9 @@ func @generic_op_120_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf3
     return %2 : tensor<5x7x3xf32>
 }
 
-//   CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0 * 7 + d1)>
-//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-//       CHECK: func @generic_op_120_permultation_reshape_producer_fusion
+//   CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)>
+//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+//       CHECK: func @generic_op_120_permutation_reshape_producer_fusion
 //   CHECK-NOT:   linalg.tensor_reshape
 //       CHECK:   linalg.generic
 //  CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]]]

diff  --git a/mlir/test/Dialect/Linalg/sparse_nd.mlir b/mlir/test/Dialect/Linalg/sparse_nd.mlir
index 56ab7de4f0f0..bc282ca5f8ba 100644
--- a/mlir/test/Dialect/Linalg/sparse_nd.mlir
+++ b/mlir/test/Dialect/Linalg/sparse_nd.mlir
@@ -21,7 +21,7 @@
 
 // CHECK-LABEL:   func @mul(
 // CHECK-SAME:              %[[VAL_0:.*0]]: tensor<10x20x30x40x50x60x70x80xf32>,
-// CHECK-SAME:              %[[VAL_1:.*1]]: tensor<10x20x30x40x50x60x70x80xf32>,
+// CHECK-SAME:              %[[VAL_1:.*1]]: tensor<80x70x60x50x40x30x20x10xf32>,
 // CHECK-SAME:              %[[VAL_2:.*2]]: tensor<10x20x30x40x50x60x70x80xf32>) -> tensor<10x20x30x40x50x60x70x80xf32> {
 // CHECK:           %[[VAL_3:.*]] = constant 3 : index
 // CHECK:           %[[VAL_4:.*]] = constant 4 : index
@@ -34,11 +34,11 @@
 // CHECK:           %[[VAL_11:.*]] = constant 0 : index
 // CHECK:           %[[VAL_12:.*]] = constant 1 : index
 // CHECK:           %[[VAL_13:.*]] = memref.buffer_cast %[[VAL_0]] : memref<10x20x30x40x50x60x70x80xf32>
-// CHECK:           %[[VAL_14:.*]] = linalg.sparse_pointers %[[VAL_1]], %[[VAL_3]] : tensor<10x20x30x40x50x60x70x80xf32> to memref<?xindex>
-// CHECK:           %[[VAL_15:.*]] = linalg.sparse_indices %[[VAL_1]], %[[VAL_3]] : tensor<10x20x30x40x50x60x70x80xf32> to memref<?xindex>
-// CHECK:           %[[VAL_16:.*]] = linalg.sparse_pointers %[[VAL_1]], %[[VAL_4]] : tensor<10x20x30x40x50x60x70x80xf32> to memref<?xindex>
-// CHECK:           %[[VAL_17:.*]] = linalg.sparse_indices %[[VAL_1]], %[[VAL_4]] : tensor<10x20x30x40x50x60x70x80xf32> to memref<?xindex>
-// CHECK:           %[[VAL_18:.*]] = linalg.sparse_values %[[VAL_1]] : tensor<10x20x30x40x50x60x70x80xf32> to memref<?xf32>
+// CHECK:           %[[VAL_14:.*]] = linalg.sparse_pointers %[[VAL_1]], %[[VAL_3]] : tensor<80x70x60x50x40x30x20x10xf32> to memref<?xindex>
+// CHECK:           %[[VAL_15:.*]] = linalg.sparse_indices %[[VAL_1]], %[[VAL_3]] : tensor<80x70x60x50x40x30x20x10xf32> to memref<?xindex>
+// CHECK:           %[[VAL_16:.*]] = linalg.sparse_pointers %[[VAL_1]], %[[VAL_4]] : tensor<80x70x60x50x40x30x20x10xf32> to memref<?xindex>
+// CHECK:           %[[VAL_17:.*]] = linalg.sparse_indices %[[VAL_1]], %[[VAL_4]] : tensor<80x70x60x50x40x30x20x10xf32> to memref<?xindex>
+// CHECK:           %[[VAL_18:.*]] = linalg.sparse_values %[[VAL_1]] : tensor<80x70x60x50x40x30x20x10xf32> to memref<?xf32>
 // CHECK:           %[[VAL_19:.*]] = memref.buffer_cast %[[VAL_2]] : memref<10x20x30x40x50x60x70x80xf32>
 // CHECK:           %[[VAL_20:.*]] = memref.alloc() : memref<10x20x30x40x50x60x70x80xf32>
 // CHECK:           linalg.copy(%[[VAL_19]], %[[VAL_20]]) : memref<10x20x30x40x50x60x70x80xf32>, memref<10x20x30x40x50x60x70x80xf32>
@@ -84,12 +84,12 @@
 // CHECK:           return %[[VAL_50]] : tensor<10x20x30x40x50x60x70x80xf32>
 // CHECK:         }
 func @mul(%arga: tensor<10x20x30x40x50x60x70x80xf32>,
-          %argb: tensor<10x20x30x40x50x60x70x80xf32>,
+          %argb: tensor<80x70x60x50x40x30x20x10xf32>,
           %argx: tensor<10x20x30x40x50x60x70x80xf32>)
 	      -> tensor<10x20x30x40x50x60x70x80xf32> {
   %0 = linalg.generic #trait_mul
     ins(%arga, %argb: tensor<10x20x30x40x50x60x70x80xf32>,
-                      tensor<10x20x30x40x50x60x70x80xf32>)
+                      tensor<80x70x60x50x40x30x20x10xf32>)
     outs(%argx: tensor<10x20x30x40x50x60x70x80xf32>) {
       ^bb(%a: f32, %b: f32, %x: f32):
         %0 = mulf %a, %b : f32

diff  --git a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
index 200a237cfdc8..a4def4cd4747 100644
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
@@ -67,7 +67,7 @@ func @matmul_tensors(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tens
 
 // -----
 
-func @conv_tensors_static(%input: tensor<1x225x225x32xf32>, %filter: tensor<3x3x3x32xf32>, %elementwise: tensor<1x112x112x32xf32>) -> tensor<1x112x112x32xf32> {
+func @conv_tensors_static(%input: tensor<1x225x225x3xf32>, %filter: tensor<3x3x3x32xf32>, %elementwise: tensor<1x112x112x32xf32>) -> tensor<1x112x112x32xf32> {
   %c112 = constant 112 : index
   %c32 = constant 32 : index
   %c16 = constant 16 : index
@@ -81,7 +81,7 @@ func @conv_tensors_static(%input: tensor<1x225x225x32xf32>, %filter: tensor<3x3x
 
   %conv = linalg.conv_2d_input_nhwc_filter_hwcf
     {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
-    ins(%input, %filter : tensor<1x225x225x32xf32>, tensor<3x3x3x32xf32>)
+    ins(%input, %filter : tensor<1x225x225x3xf32>, tensor<3x3x3x32xf32>)
     outs(%fill : tensor<1x112x112x32xf32>) -> tensor<1x112x112x32xf32>
 
   %for0 = scf.for %iv0 = %c0 to %c112 step %c8 iter_args(%arg0 = %fill) -> tensor<1x112x112x32xf32> {
@@ -118,7 +118,7 @@ func @conv_tensors_static(%input: tensor<1x225x225x32xf32>, %filter: tensor<3x3x
 //      CHECK: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 
 //      CHECK: func @conv_tensors_static
-// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x225x225x32xf32>, %[[FILTER:.+]]: tensor<3x3x3x32xf32>, %[[ELEM:.+]]: tensor<1x112x112x32xf32>)
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x225x225x3xf32>, %[[FILTER:.+]]: tensor<3x3x3x32xf32>, %[[ELEM:.+]]: tensor<1x112x112x32xf32>)
 
 //      CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 112, 112, 32] : tensor<1x112x112x32xf32>
 // CHECK-NEXT: %[[FILL:.+]] = linalg.fill(%[[INIT]], %cst) : tensor<1x112x112x32xf32>, f32 -> tensor<1x112x112x32xf32>
@@ -127,14 +127,14 @@ func @conv_tensors_static(%input: tensor<1x225x225x32xf32>, %filter: tensor<3x3x
 // CHECK-NEXT:   %[[OFFSET_H:.+]] = affine.apply #[[MAP0]](%[[IV0]])
 // CHECK-NEXT:   scf.for %[[IV1:.+]] = %{{.+}} to %{{.+}} step %{{.+}} iter_args(%[[ARG1:.+]] = %[[ARG0]])
 // CHECK-NEXT:     %[[OFFSET_W:.+]] = affine.apply #[[MAP0]](%[[IV1]])
-// CHECK-NEXT:     %[[ST_INPUT:.+]] = subtensor %arg0[0, %[[OFFSET_H]], %[[OFFSET_W]], 0] [1, 17, 33, 32] [1, 1, 1, 1] : tensor<1x225x225x32xf32> to tensor<1x17x33x32xf32>
+// CHECK-NEXT:     %[[ST_INPUT:.+]] = subtensor %arg0[0, %[[OFFSET_H]], %[[OFFSET_W]], 0] [1, 17, 33, 3] [1, 1, 1, 1] : tensor<1x225x225x3xf32> to tensor<1x17x33x3xf32>
 // CHECK-NEXT:     scf.for %[[IV2:.+]] = %{{.+}} to %{{.+}} step %{{.+}} iter_args(%[[ARG2:.+]] = %[[ARG1]])
 // CHECK-NEXT:       %[[ST_ELEM:.+]] = subtensor %[[ELEM]][0, %[[IV0]], %[[IV1]], %[[IV2]]] [1, 8, 16, 4] [1, 1, 1, 1] : tensor<1x112x112x32xf32> to tensor<1x8x16x4xf32>
 // CHECK-NEXT:       %[[ST_ARG2:.+]] = subtensor %[[ARG2]][0, %[[IV0]], %[[IV1]], %[[IV2]]] [1, 8, 16, 4] [1, 1, 1, 1] : tensor<1x112x112x32xf32> to tensor<1x8x16x4xf32>
 // CHECK-NEXT:       %[[ST_FILTER:.+]] = subtensor %[[FILTER]][0, 0, 0, %[[IV2]]] [3, 3, 3, 4] [1, 1, 1, 1] : tensor<3x3x3x32xf32> to tensor<3x3x3x4xf32>
 // CHECK-NEXT:       %[[ST_FILL:.+]] = subtensor %[[FILL]][0, %[[IV0]], %[[IV1]], %[[IV2]]] [1, 8, 16, 4] [1, 1, 1, 1] : tensor<1x112x112x32xf32> to tensor<1x8x16x4xf32>
 // CHECK-NEXT:       %[[ST_CONV:.+]] = linalg.conv_2d_input_nhwc_filter_hwcf
-// CHECK-SAME:         ins(%[[ST_INPUT]], %[[ST_FILTER]] : tensor<1x17x33x32xf32>, tensor<3x3x3x4xf32>)
+// CHECK-SAME:         ins(%[[ST_INPUT]], %[[ST_FILTER]] : tensor<1x17x33x3xf32>, tensor<3x3x3x4xf32>)
 // CHECK-SAME:         outs(%[[ST_FILL]] : tensor<1x8x16x4xf32>)
 // CHECK-NEXT:       %[[ADD:.+]] = linalg.generic
 // CHECK-SAME:         ins(%[[ST_CONV]], %[[ST_ELEM]] : tensor<1x8x16x4xf32>, tensor<1x8x16x4xf32>)

diff  --git a/mlir/test/Dialect/Linalg/tile-indexed-generic.mlir b/mlir/test/Dialect/Linalg/tile-indexed-generic.mlir
index 0d38d4eede19..117c82f2550b 100644
--- a/mlir/test/Dialect/Linalg/tile-indexed-generic.mlir
+++ b/mlir/test/Dialect/Linalg/tile-indexed-generic.mlir
@@ -54,10 +54,10 @@ func @indexed_generic_vector(%operand: memref<50xf32>, %result: memref<50xf32>)
   ],
   iterator_types = ["parallel", "parallel"]
 }
-func @indexed_generic_matrix(%operand: memref<50x100xf32>, %result: memref<50x100xf32>) {
+func @indexed_generic_matrix(%operand: memref<50x99xf32>, %result: memref<50x50xf32>) {
   linalg.indexed_generic #combined_indices_trait
-     ins(%operand : memref<50x100xf32>)
-    outs(%result : memref<50x100xf32>) {
+     ins(%operand : memref<50x99xf32>)
+    outs(%result : memref<50x50xf32>) {
     ^bb0(%i: index, %j: index, %operand_in: f32, %result_in: f32):
       %i_int = index_cast %i: index to i32
       %i_float = sitofp %i_int : i32 to f32


        


More information about the Mlir-commits mailing list