[Mlir-commits] [mlir] [Linalg] Add matchers to infer which Convolution Op a given linalg.generic is (PR #163374)

Abhishek Varma llvmlistbot at llvm.org
Wed Oct 15 01:40:21 PDT 2025


================
@@ -0,0 +1,615 @@
+// The following test examples of linalg convolution named ops lowered to linalg.generic and then
+// lifted back up to named op.
+// RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s
+
+func.func @conv_1d_nwc_wcf(%input: memref<?x?x?xf32>, %filter: memref<?x?x?xf32>, %output: memref<?x?x?xf32>) {
+  linalg.conv_1d_nwc_wcf {dilations = dense<3> : tensor<1xi64>,
+                                       strides = dense<2> : tensor<1xi64>}
+     ins (%input, %filter: memref<?x?x?xf32>, memref<?x?x?xf32>)
+    outs (%output: memref<?x?x?xf32>)
+  return
+}
+//      CHECK: @conv_1d_nwc_wcf
+//      CHECK:   linalg.conv_1d_nwc_wcf
+// CHECK-SAME:      dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @conv_1d_ncw_fcw(%input: memref<?x?x?xf32>, %filter: memref<?x?x?xf32>, %output: memref<?x?x?xf32>) {
+  linalg.conv_1d_ncw_fcw {dilations = dense<3> : tensor<1xi64>,
+                                       strides = dense<2> : tensor<1xi64>}
+     ins (%input, %filter: memref<?x?x?xf32>, memref<?x?x?xf32>)
+    outs (%output: memref<?x?x?xf32>)
+  return
+}
+//      CHECK: @conv_1d_ncw_fcw
+//      CHECK:   linalg.conv_1d_ncw_fcw
+// CHECK-SAME:      dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @conv_1d(%in : memref<?xf32>, %filter : memref<?xf32>, %out : memref<?xf32>) -> () {
+  linalg.conv_1d ins(%in, %filter : memref<?xf32>, memref<?xf32>)
+                outs(%out : memref<?xf32>)
+  return
+}
+//      CHECK: @conv_1d
+//      CHECK:   linalg.conv_1d
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @depthwise_conv_1d_ncw_cw(%input: memref<?x?x?xf32>, %filter: memref<?x?xf32>, %output: memref<?x?x?xf32>) {
+  linalg.depthwise_conv_1d_ncw_cw {dilations = dense<3> : tensor<1xi64>,
+                                       strides = dense<2> : tensor<1xi64>}
+     ins (%input, %filter: memref<?x?x?xf32>, memref<?x?xf32>)
+    outs (%output: memref<?x?x?xf32>)
+  return
+}
+//      CHECK: @depthwise_conv_1d_ncw_cw
+//      CHECK:   linalg.depthwise_conv_1d_ncw_cw
+// CHECK-SAME:      dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @depthwise_conv_1d_nwc_wc(%input: memref<?x?x?xf32>, %filter: memref<?x?xf32>, %output: memref<?x?x?xf32>) {
+  linalg.depthwise_conv_1d_nwc_wc {dilations = dense<3> : tensor<1xi64>,
+                                       strides = dense<2> : tensor<1xi64>}
+     ins (%input, %filter: memref<?x?x?xf32>, memref<?x?xf32>)
+    outs (%output: memref<?x?x?xf32>)
+  return
+}
+//      CHECK: @depthwise_conv_1d_nwc_wc
+//      CHECK:   linalg.depthwise_conv_1d_nwc_wc
+// CHECK-SAME:      dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @depthwise_conv_1d_nwc_wcm(%input: memref<?x?x?xf32>, %filter: memref<?x?x?xf32>, %output: memref<?x?x?x?xf32>) {
+  linalg.depthwise_conv_1d_nwc_wcm {dilations = dense<1> : tensor<1xi64>,
+                                       strides = dense<1> : tensor<1xi64>}
+     ins (%input, %filter: memref<?x?x?xf32>, memref<?x?x?xf32>)
+    outs (%output: memref<?x?x?x?xf32>)
+  return
+}
+//      CHECK: @depthwise_conv_1d_nwc_wcm
+//      CHECK:   linalg.depthwise_conv_1d_nwc_wcm
+// CHECK-SAME:      dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @conv_2d(%in : memref<?x?xf32>, %filter : memref<?x?xf32>, %out : memref<?x?xf32>) -> () {
+  linalg.conv_2d ins(%in, %filter : memref<?x?xf32>, memref<?x?xf32>)
+                outs(%out: memref<?x?xf32>)
+  return
+}
+//      CHECK: @conv_2d
+//      CHECK:   linalg.conv_2d
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @conv_2d_nchw_fchw(%arg0 : tensor<?x?x?x?xf32>,
+    %arg1 : tensor<?x?x?x?xf32>,  %arg2 : tensor<?x?x?x?xf32>) ->
+    (tensor<4x8x12x16xf32>, tensor<?x?x?x?xf32>) {
+  %0 = linalg.conv_2d_nchw_fchw {dilations = dense<[2,4]> : tensor<2xi64>, strides = dense<[3,5]> : tensor<2xi64>}
+  ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+      outs(%arg2 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  %1 = tensor.cast %0 : tensor<?x?x?x?xf32> to tensor<4x8x12x16xf32>
+  return %1, %0 : tensor<4x8x12x16xf32>, tensor<?x?x?x?xf32>
+}
+//      CHECK: @conv_2d_nchw_fchw
+//      CHECK:   linalg.conv_2d_nchw_fchw
+// CHECK-SAME:      dilations = dense<[2, 4]> : tensor<2xi64>, strides = dense<[3, 5]> : tensor<2xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @conv_2d_nchw_fchw_q(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?xf32>, %inputzp: i32, %filterzp: i32, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.conv_2d_nchw_fchw_q {dilations = dense<1> : tensor<2xi64>,
+                                              strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter, %inputzp, %filterzp: tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>, i32, i32)
+    outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+//      CHECK: @conv_2d_nchw_fchw_q
+//      CHECK:   linalg.conv_2d_nchw_fchw_q
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @conv_2d_ngchw_fgchw(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+  %0 = linalg.conv_2d_ngchw_fgchw {dilations = dense<1> : tensor<2xi64>,
+                                              strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
+    outs (%init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?x?xf32>
+}
+//      CHECK: @conv_2d_ngchw_fgchw
+//      CHECK:   linalg.conv_2d_ngchw_fgchw
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @conv_2d_ngchw_gfchw(%input: memref<?x?x?x?x?xi8>, %filter: memref<?x?x?x?x?xi8>, %output: memref<?x?x?x?x?xi32>) {
+  linalg.conv_2d_ngchw_gfchw {dilations = dense<1> : tensor<2xi64>,
+                                       strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: memref<?x?x?x?x?xi8>, memref<?x?x?x?x?xi8>)
+    outs (%output: memref<?x?x?x?x?xi32>)
+  return
+}
+//      CHECK: @conv_2d_ngchw_gfchw
+//      CHECK:   linalg.conv_2d_ngchw_gfchw
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @conv_2d_ngchw_gfchw_q(%input: memref<?x?x?x?x?xi8>, %filter: memref<?x?x?x?x?xi8>, %inputzp: i32, %filterzp: i32, %output: memref<?x?x?x?x?xi32>) {
+  linalg.conv_2d_ngchw_gfchw_q {dilations = dense<1> : tensor<2xi64>,
+                                       strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter, %inputzp, %filterzp: memref<?x?x?x?x?xi8>, memref<?x?x?x?x?xi8>, i32, i32)
+    outs (%output: memref<?x?x?x?x?xi32>)
+  return
+}
+//      CHECK: @conv_2d_ngchw_gfchw_q
+//      CHECK:   linalg.conv_2d_ngchw_gfchw_q
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @conv_2d_nhwc_hwcf_q(%input: memref<?x?x?x?xf32>, %filter: memref<?x?x?x?xf32>, %inputzp: i32, %filterzp: i32, %output: memref<?x?x?x?xf32>) {
+  linalg.conv_2d_nhwc_hwcf_q {
+    dilations = dense<1> : tensor<2xi64>,
+    strides = dense<1> : tensor<2xi64>
+  } ins(%input, %filter, %inputzp, %filterzp : memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, i32, i32) outs(%output : memref<?x?x?x?xf32>)
+  return
+}
+//      CHECK: @conv_2d_nhwc_hwcf_q
+//      CHECK:   linalg.conv_2d_nhwc_hwcf_q
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @conv_2d_nhwgc_gfhwc_q(%input: memref<?x?x?x?x?xf32>, %filter: memref<?x?x?x?x?xf32>, %inputzp: i32, %filterzp: i32, %output: memref<?x?x?x?x?xf32>) {
+  linalg.conv_2d_nhwgc_gfhwc_q {
+    dilations = dense<1> : tensor<2xi64>,
+    strides = dense<1> : tensor<2xi64>
+  } ins(%input, %filter, %inputzp, %filterzp : memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>, i32, i32) outs(%output : memref<?x?x?x?x?xf32>)
+  return
+}
+//      CHECK: @conv_2d_nhwgc_gfhwc_q
+//      CHECK:   linalg.conv_2d_nhwgc_gfhwc_q
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @depthwise_conv_2d_nhwc_hwc_q(%input: tensor<?x?x?x?xi8>, %filter: tensor<?x?x?xi8>, %inputzp: i32, %filterzp: i32, %output: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32>{
+  %res = linalg.depthwise_conv_2d_nhwc_hwc_q {
+    dilations = dense<1> : tensor<2xi64>,
+    strides = dense<1> : tensor<2xi64>
+  } ins(%input, %filter, %inputzp, %filterzp : tensor<?x?x?x?xi8>, tensor<?x?x?xi8>, i32, i32) outs(%output : tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32>
+  return %res : tensor<?x?x?x?xi32>
+}
+//      CHECK: @depthwise_conv_2d_nhwc_hwc_q
+//      CHECK:   linalg.depthwise_conv_2d_nhwc_hwc_q
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @conv_2d_nhwc_fhwc(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
+                                 strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+    outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+//      CHECK: @conv_2d_nhwc_fhwc
+//      CHECK:   linalg.conv_2d_nhwc_fhwc
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @conv_2d_nhwc_fhwc_q(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?xf32>, %inputzp: i32, %filterzp: i32, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc_q {dilations = dense<1> : tensor<2xi64>,
+                                              strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter, %inputzp, %filterzp: tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>, i32, i32)
+    outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+//      CHECK: @conv_2d_nhwc_fhwc_q
+//      CHECK:   linalg.conv_2d_nhwc_fhwc_q
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @conv_2d_nhwc_hwcf(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>,
+                                              strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+    outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+//      CHECK: @conv_2d_nhwc_hwcf
+//      CHECK:   linalg.conv_2d_nhwc_hwcf
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @conv_2d_nhwgc_gfhwc(%input: memref<?x?x?x?x?xf32>, %filter: memref<?x?x?x?x?xf32>, %output: memref<?x?x?x?x?xf32>) {
+  linalg.conv_2d_nhwgc_gfhwc {dilations = dense<1> : tensor<2xi64>,
+                                         strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>)
+    outs (%output: memref<?x?x?x?x?xf32>)
+  return
+}
+//      CHECK: @conv_2d_nhwgc_gfhwc
+//      CHECK:   linalg.conv_2d_nhwgc_gfhwc
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @depthwise_conv_2d_nchw_chw(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.depthwise_conv_2d_nchw_chw {dilations = dense<[2,3]> : vector<2xi64>, strides = dense<[4,5]> : vector<2xi64>}
+     ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?x?xf32>)
+    outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+//      CHECK: @depthwise_conv_2d_nchw_chw
+//      CHECK:   linalg.depthwise_conv_2d_nchw_chw
+// CHECK-SAME:      dilations = dense<[2, 3]> : tensor<2xi64>, strides = dense<[4, 5]> : tensor<2xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @depthwise_conv_2d_nhwc_hwc(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : tensor<2xi64>,
+                                                strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?x?xf32>)
+    outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+//      CHECK: @depthwise_conv_2d_nhwc_hwc
+//      CHECK:   linalg.depthwise_conv_2d_nhwc_hwc
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @depthwise_conv_2d_nhwc_hwcm(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+  %0 = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>,
+                                                strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+    outs (%init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?x?xf32>
+}
+//      CHECK: @depthwise_conv_2d_nhwc_hwcm
+//      CHECK:   linalg.depthwise_conv_2d_nhwc_hwcm
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @depthwise_conv_2d_nhwc_hwcm_q(%arg0: tensor<?x?x?x?xi8>, %arg1: tensor<?x?x?x?xi8>, %arg2: tensor<?x?x?x?x?xi32>, %arg3 : i32, %arg4 : i32) -> tensor<?x?x?x?x?xi32> {
+  %0 = linalg.depthwise_conv_2d_nhwc_hwcm_q {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %arg3, %arg4 : tensor<?x?x?x?xi8>, tensor<?x?x?x?xi8>, i32, i32) outs(%arg2 : tensor<?x?x?x?x?xi32>) -> tensor<?x?x?x?x?xi32>
+  return %0 : tensor<?x?x?x?x?xi32>
+}
+//      CHECK: @depthwise_conv_2d_nhwc_hwcm_q
+//      CHECK:   linalg.depthwise_conv_2d_nhwc_hwcm_q
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @conv_3d(%in : memref<?x?x?xf32>, %filter : memref<?x?x?xf32>, %out : memref<?x?x?xf32>) -> () {
+  linalg.conv_3d ins(%in, %filter : memref<?x?x?xf32>, memref<?x?x?xf32>)
+                outs(%out : memref<?x?x?xf32>)
+  return
+}
+//      CHECK: @conv_3d
+//      CHECK:   linalg.conv_3d
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @conv_3d_ncdhw_fcdhw(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+  %0 = linalg.conv_3d_ncdhw_fcdhw {dilations = dense<1> : tensor<3xi64>,
+                                                strides = dense<1> : tensor<3xi64>}
+     ins (%input, %filter: tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
+    outs (%init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?x?xf32>
+}
+//      CHECK: @conv_3d_ncdhw_fcdhw
+//      CHECK:   linalg.conv_3d_ncdhw_fcdhw
+// CHECK-SAME:      dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @conv_3d_ndhwc_dhwcf(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+  %0 = linalg.conv_3d_ndhwc_dhwcf {dilations = dense<1> : tensor<3xi64>,
+                                                strides = dense<1> : tensor<3xi64>}
+     ins (%input, %filter: tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
+    outs (%init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?x?xf32>
+}
+//      CHECK: @conv_3d_ndhwc_dhwcf
+//      CHECK:   linalg.conv_3d_ndhwc_dhwcf
+// CHECK-SAME:      dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @conv_3d_ndhwc_dhwcf_q(%input: tensor<?x?x?x?x?xi8>, %filter: tensor<?x?x?x?x?xi8>, %inputzp: i32, %filterzp: i32, %init: tensor<?x?x?x?x?xi32>) -> tensor<?x?x?x?x?xi32> {
+  %0 = linalg.conv_3d_ndhwc_dhwcf_q {dilations = dense<1> : tensor<3xi64>,
+                                                strides = dense<1> : tensor<3xi64>}
+    ins(%input, %filter, %inputzp, %filterzp : tensor<?x?x?x?x?xi8>, tensor<?x?x?x?x?xi8>, i32, i32)
+    outs (%init: tensor<?x?x?x?x?xi32>) -> tensor<?x?x?x?x?xi32>
+  return %0 : tensor<?x?x?x?x?xi32>
+}
+//      CHECK: @conv_3d_ndhwc_dhwcf_q
+//      CHECK:   linalg.conv_3d_ndhwc_dhwcf_q
+// CHECK-SAME:      dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @depthwise_conv_3d_ncdhw_cdhw(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+  %0 = linalg.depthwise_conv_3d_ncdhw_cdhw {dilations = dense<1> : tensor<3xi64>,
+                                strides = dense<1> : tensor<3xi64>}
+     ins (%input, %filter: tensor<?x?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+    outs (%init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?x?xf32>
+}
+//      CHECK: @depthwise_conv_3d_ncdhw_cdhw
+//      CHECK:   linalg.depthwise_conv_3d_ncdhw_cdhw
+// CHECK-SAME:      dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @depthwise_conv_3d_ndhwc_dhwc(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+  %0 = linalg.depthwise_conv_3d_ndhwc_dhwc {dilations = dense<1> : tensor<3xi64>,
+                                strides = dense<1> : tensor<3xi64>}
+     ins (%input, %filter: tensor<?x?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+    outs (%init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?x?xf32>
+}
+//      CHECK: @depthwise_conv_3d_ndhwc_dhwc
+//      CHECK:   linalg.depthwise_conv_3d_ndhwc_dhwc
+// CHECK-SAME:      dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @depthwise_conv_3d_ndhwc_dhwcm(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %init: tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32> {
+  %0 = linalg.depthwise_conv_3d_ndhwc_dhwcm {dilations = dense<1> : tensor<3xi64>,
+                                strides = dense<1> : tensor<3xi64>}
+     ins (%input, %filter: tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
+    outs (%init: tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?x?x?xf32>
+}
+//      CHECK: @depthwise_conv_3d_ndhwc_dhwcm
+//      CHECK:   linalg.depthwise_conv_3d_ndhwc_dhwcm
+// CHECK-SAME:      dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @pooling_nchw_max(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.pooling_nchw_max {dilations = dense<1> : tensor<2xi64>,
+                                strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>)
+    outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+//      CHECK: @pooling_nchw_max
+//      CHECK:   linalg.pooling_nchw_max
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @pooling_nchw_sum(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.pooling_nchw_sum {dilations = dense<1> : tensor<2xi64>,
+                                strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>)
+    outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+//      CHECK: @pooling_nchw_sum
+//      CHECK:   linalg.pooling_nchw_sum
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @pooling_ncw_max(%input: tensor<?x?x?xf32>, %output: tensor<?x?x?xf32>, %filter: tensor<?xf32>) -> tensor<?x?x?xf32> {
+  %0 = linalg.pooling_ncw_max {dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>}
+    ins(%input, %filter: tensor<?x?x?xf32>, tensor<?xf32>)
+    outs(%output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+//      CHECK: @pooling_ncw_max
+//      CHECK:   linalg.pooling_ncw_max
+// CHECK-SAME:      dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @pooling_ncw_sum(%input: tensor<?x?x?xf32>, %output: tensor<?x?x?xf32>, %filter: tensor<?xf32>) -> tensor<?x?x?xf32> {
+  %0 = linalg.pooling_ncw_sum {dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>}
+    ins(%input, %filter: tensor<?x?x?xf32>, tensor<?xf32>)
+    outs(%output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+//      CHECK: @pooling_ncw_sum
+//      CHECK:   linalg.pooling_ncw_sum
+// CHECK-SAME:      dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @pooling_ndhwc_max(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+  %0 = linalg.pooling_ndhwc_max {dilations = dense<1> : tensor<3xi64>,
+                                strides = dense<1> : tensor<3xi64>}
+     ins (%input, %filter: tensor<?x?x?x?x?xf32>, tensor<?x?x?xf32>)
+    outs (%init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?x?xf32>
+}
+//      CHECK: @pooling_ndhwc_max
+//      CHECK:   linalg.pooling_ndhwc_max
+// CHECK-SAME:      dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @pooling_ndhwc_min(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+  %0 = linalg.pooling_ndhwc_min {dilations = dense<1> : tensor<3xi64>,
+                                strides = dense<1> : tensor<3xi64>}
+     ins (%input, %filter: tensor<?x?x?x?x?xf32>, tensor<?x?x?xf32>)
+    outs (%init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?x?xf32>
+}
+//      CHECK: @pooling_ndhwc_min
+//      CHECK:   linalg.pooling_ndhwc_min
+// CHECK-SAME:      dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @pooling_ndhwc_sum(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+  %0 = linalg.pooling_ndhwc_sum {dilations = dense<1> : tensor<3xi64>,
+                                strides = dense<1> : tensor<3xi64>}
+     ins (%input, %filter: tensor<?x?x?x?x?xf32>, tensor<?x?x?xf32>)
+    outs (%init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?x?xf32>
+}
+//      CHECK: @pooling_ndhwc_sum
+//      CHECK:   linalg.pooling_ndhwc_sum
+// CHECK-SAME:      dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @pooling_nhwc_max(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.pooling_nhwc_max {dilations = dense<1> : tensor<2xi64>,
+                                strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>)
+    outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+//      CHECK: @pooling_nhwc_max
+//      CHECK:   linalg.pooling_nhwc_max
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @pooling_nhwc_min(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.pooling_nhwc_min {dilations = dense<1> : tensor<2xi64>,
+                                strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>)
+    outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+//      CHECK: @pooling_nhwc_min
+//      CHECK:   linalg.pooling_nhwc_min
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @pooling_nhwc_sum(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.pooling_nhwc_sum {dilations = dense<1> : tensor<2xi64>,
+                                strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>)
+    outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+//      CHECK: @pooling_nhwc_sum
+//      CHECK:   linalg.pooling_nhwc_sum
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @pooling_nhwc_max_unsigned(%input: tensor<?x?x?x?xi32>, %filter: tensor<?x?xi32>, %init: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32> {
+  %0 = linalg.pooling_nhwc_max_unsigned {dilations = dense<1> : tensor<2xi64>,
+                                strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: tensor<?x?x?x?xi32>, tensor<?x?xi32>)
+    outs (%init: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32>
+  return %0 : tensor<?x?x?x?xi32>
+}
+//      CHECK: @pooling_nhwc_max_unsigned
+//      CHECK:   linalg.pooling_nhwc_max_unsigned
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+//  CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @pooling_nhwc_min_unsigned(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>,
+                                strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>)
+    outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+//      CHECK: @pooling_nhwc_min_unsigned
+//      CHECK:   linalg.pooling_nhwc_min
----------------
Abhishek-Varma wrote:

So I wanted to flag this for `max/min_unsigned` ops - glad Copilot reminded me!

Step 1: Use the following input IR :-
```
func.func @pooling_nhwc_max_unsigned_float(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
  %0 = linalg.pooling_nhwc_max_unsigned {dilations = dense<1> : tensor<2xi64>,
                                strides = dense<1> : tensor<2xi64>}
     ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>)
    outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
  return %0 : tensor<?x?x?x?xf32>
}
func.func @pooling_nhwc_max_signed_float(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
  %0 = linalg.pooling_nhwc_max {dilations = dense<1> : tensor<2xi64>,
                                strides = dense<1> : tensor<2xi64>}
     ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>)
    outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
  return %0 : tensor<?x?x?x?xf32>
}
```

Step 2: Run `mlir-opt --linalg-generalize-named-ops input.mlir`

We'll observe the following :-
```
  func.func @pooling_nhwc_max_unsigned_float(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
    %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?x?x?xf32>) {
    ^bb0(%in: f32, %in_0: f32, %out: f32):
      %1 = arith.maximumf %out, %in : f32
      linalg.yield %1 : f32
    } -> tensor<?x?x?x?xf32>
    return %0 : tensor<?x?x?x?xf32>
  }
  func.func @pooling_nhwc_max_signed_float(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
    %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?x?x?xf32>) {
    ^bb0(%in: f32, %in_0: f32, %out: f32):
      %1 = arith.maximumf %out, %in : f32
      linalg.yield %1 : f32
    } -> tensor<?x?x?x?xf32>
    return %0 : tensor<?x?x?x?xf32>
  }
```

Since the two `linalg.generic` ops are IDENTICAL, there is no way to distinguish between them, and therefore we get `linalg.pooling_nhwc_max` (the less constrained version) in both cases when using `--linalg-specialize-generic-ops`. Basically, the matcher currently matches both named ops - it is just a matter of which one is preferred (as demonstrated in `Specialize.cpp`).

NOTE: This only happens when the input is a float. For integer inputs, the body of the `linalg.generic` contains either `arith.maxui` or `arith.maxsi`, which lets us distinguish between `linalg.pooling_nhwc_max_unsigned` and `linalg.pooling_nhwc_max` respectively.

CC: @hanhanW @rengolin - let me know how we want to proceed for these ops.

https://github.com/llvm/llvm-project/pull/163374


More information about the Mlir-commits mailing list