[Mlir-commits] [mlir] 9325b8d - [mlir][Linalg] Add conv ops with TF definition.

Hanhan Wang llvmlistbot at llvm.org
Wed Feb 10 23:00:13 PST 2021


Author: Hanhan Wang
Date: 2021-02-10T22:59:38-08:00
New Revision: 9325b8da1702238f15f837b6e07f099baf4dcd94

URL: https://github.com/llvm/llvm-project/commit/9325b8da1702238f15f837b6e07f099baf4dcd94
DIFF: https://github.com/llvm/llvm-project/commit/9325b8da1702238f15f837b6e07f099baf4dcd94.diff

LOG: [mlir][Linalg] Add conv ops with TF definition.

The dimension order of a filter in TensorFlow is
[filter_height, filter_width, in_channels, out_channels], which differs from
the current definition; the current definition follows the TOSA spec. Add
TF-layout versions of the conv ops to the .tc file so that we do not have to
insert a transpose op around a conv op.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D96038

Added: 
    mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-input-ncw-filter-wcf-call.mlir
    mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-input-nwc-filter-wcf-call.mlir
    mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-input-nchw-filter-hwcf-call.mlir
    mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-input-nhwc-filter-hwcf-call.mlir
    mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-input-ncdhw-filter-dhwcf-call.mlir
    mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-input-ndhwc-filter-dhwcf-call.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/Dialect/Linalg/generalize-named-ops.mlir
    mlir/test/Dialect/Linalg/named-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
index 6c6c92332f8f..6692f7d5831e 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
@@ -124,3 +124,114 @@ Note: this op only supports channel multiplier == 1.
   O(n, oh, ow, c) = std_addf<kh, kw>(std_mulf(
     I(n, oh * strides[0] + kh, ow * strides[1] + kw, c), K(kh, kw, c)));
 }
+
+ods_def<ConvInputNWCFilterWCFOp>:
+def conv_1d_input_nwc_filter_wcf(I: f32(N, W, C), K: f32(KW, C, F)) -> (O: f32(N, W, F))
+  attr(strides: 1xi64, dilations: 1xi64)
+""" A 1-D convolution given NWC layout input and WCF layout filter.
+
+Computes a 1-D convolution given 3-D input and filter. The data layout
+of input is NWC and the data layout of filter is WCF.
+
+The indexing maps for these three tensors contain 5 dimensions, following the
+order of (`N`, `W`, `F`, `KW`, `C`).
+"""
+{
+  O(n, w, f) = std_addf<kw>(
+      std_mulf(I(n, w * strides[0] + kw * dilations[0], c), K(kw, c, f)));
+}
+
+ods_def<ConvInputNCWFilterWCFOp>:
+def conv_1d_input_ncw_filter_wcf(I: f32(N, C, W), K: f32(KW, C, F)) -> (O: f32(N, F, W))
+  attr(strides: 1xi64, dilations: 1xi64)
+""" A 1-D convolution given NCW layout input and WCF layout filter.
+
+Computes a 1-D convolution given 3-D input and filter. The data layout
+of input is NCW and the data layout of filter is WCF.
+
+The indexing maps for these three tensors contain 5 dimensions, following the
+order of (`N`, `F`, `W`, `KW`, `C`).
+"""
+{
+  O(n, f, w) = std_addf<kw>(
+      std_mulf(I(n, c, w * strides[0] + kw * dilations[0]), K(kw, c, f)));
+}
+
+ods_def<ConvInputNHWCFilterHWCFOp>:
+def conv_2d_input_nhwc_filter_hwcf(I: f32(N, H, W, C), K: f32(KH, KW, C, F)) -> (O: f32(N, H, W, F))
+  attr(strides: 2xi64, dilations: 2xi64)
+""" A 2-D convolution given NHWC layout input and HWCF layout filter.
+
+Computes a 2-D convolution given 4-D input and filter. The data layout
+of input is NHWC and the data layout of filter is HWCF.
+
+The indexing maps for these three tensors contain 7 dimensions, following the
+order of (`N`, `H`, `W`, `F`, `KH`, `KW`, `C`).
+"""
+{
+  O(n, h, w, f) =
+      std_addf<kh, kw>(std_mulf(I(n, h * strides[0] + kh * dilations[0],
+                                  w * strides[1] + kw * dilations[1], c),
+                                K(kh, kw, c, f)));
+}
+
+ods_def<ConvInputNCHWFilterHWCFOp>:
+def conv_2d_input_nchw_filter_hwcf
+    (I: f32(N, C, H, W), K: f32(KH, KW, C, F))
+  -> (O: f32(N, F, H, W))
+  attr(strides: 2xi64, dilations: 2xi64)
+""" A 2-D convolution given NCHW layout input and HWCF layout filter.
+
+Computes a 2-D convolution given 4-D input and filter. The data layout
+of input is NCHW and the data layout of filter is HWCF.
+
+The indexing maps for these three tensors contain 7 dimensions, following the
+order of (`N`, `F`, `H`, `W`, `KH`, `KW`, `C`).
+"""
+{
+  O(n, f, h, w) =
+      std_addf<kh, kw>(std_mulf(I(n, c, h * strides[0] + kh * dilations[0],
+                                  w * strides[1] + kw * dilations[1]),
+                                K(kh, kw, c, f)));
+}
+
+ods_def<ConvInputNDHWCFilterDHWCFOp>:
+def conv_3d_input_ndhwc_filter_dhwcf
+    (I: f32(N, D, H, W, C), K: f32(KD, KH, KW, C, F))
+  -> (O: f32(N, D, H, W, F))
+  attr(strides: 3xi64, dilations: 3xi64)
+""" A 3-D convolution given NDHWC layout input and DHWCF layout filter.
+
+Computes a 3-D convolution given 5-D input and filter. The data layout
+of input is NDHWC and the data layout of filter is DHWCF.
+
+The indexing maps for these three tensors contain 9 dimensions, following the
+order of (`N`, `D`, `H`, `W`, `F`, `KD`, `KH`, `KW`, `C`).
+"""
+{
+  O(n, d, h, w, f) =
+      std_addf<kd, kh, kw>(std_mulf(I(n, d * strides[0] + kd * dilations[0],
+                                      h * strides[1] + kh * dilations[1],
+                                      w * strides[2] + kw * dilations[2], c),
+                                    K(kd, kh, kw, c, f)));
+}
+
+ods_def<ConvInputNCDHWFilterDHWCFOp>:
+def conv_3d_input_ncdhw_filter_dhwcf
+    (I: f32(N, C, D, H, W), K: f32(KD, KH, KW, C, F))
+  -> (O: f32(N, F, D, H, W))
+  attr(strides: 3xi64, dilations: 3xi64)
+""" A 3-D convolution given NCDHW layout input and DHWCF layout filter.
+
+Computes a 3-D convolution given 5-D input and filter. The data layout
+of input is NCDHW and the data layout of filter is DHWCF.
+
+The indexing maps for these three tensors contain 9 dimensions, following the
+order of (`N`, `F`, `D`, `H`, `W`, `KD`, `KH`, `KW`, `C`).
+"""
+{
+  O(n, f, d, h, w) = std_addf<kd, kh, kw>(std_mulf(
+      I(n, c, d * strides[0] + kd * dilations[0],
+        h * strides[1] + kh * dilations[1], w * strides[2] + kw * dilations[2]),
+      K(kd, kh, kw, c, f)));
+}

diff  --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-input-ncw-filter-wcf-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-input-ncw-filter-wcf-call.mlir
new file mode 100644
index 000000000000..3b99a112f8f2
--- /dev/null
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-input-ncw-filter-wcf-call.mlir
@@ -0,0 +1,70 @@
+// RUN: mlir-opt %s -convert-linalg-to-loops -convert-scf-to-std -convert-linalg-to-llvm -lower-affine -convert-scf-to-std -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=0,0,4" -convert-linalg-to-loops -convert-scf-to-std \
+// RUN:   -convert-linalg-to-llvm -lower-affine -convert-scf-to-std -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+// RUN: mlir-opt %s -test-conv-vectorization="tile-sizes=1,1,1,3,3" -convert-linalg-to-llvm -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=0,0,4" \
+// RUN:   -test-conv-vectorization="tile-sizes=1,1,1,3,3" -convert-linalg-to-llvm -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+func private @print_memref_f32(memref<*xf32>)
+
+// Creates and returns 3-D buffer of size (%s1, %s2, %s3) filled with the value %f
+func @alloc_3d_filled_f32(%s1 : index, %s2 : index, %s3 : index, %f : f32) -> memref<?x?x?xf32> {
+  %buf = alloc(%s1, %s2, %s3) : memref<?x?x?xf32>
+  linalg.fill(%buf, %f) : memref<?x?x?xf32>, f32
+  return %buf : memref<?x?x?xf32>
+}
+
+func @conv_1d_input_ncw_filter_wcf(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+  linalg.conv_1d_input_ncw_filter_wcf {dilations = dense<1> : tensor<1xi64>,
+                         strides = dense<1> : tensor<1xi64>}
+     ins (%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
+    outs (%arg2: memref<?x?x?xf32>)
+  return
+}
+
+func @main() {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c3 = constant 3 : index
+  %c6 = constant 6 : index
+  %c8 = constant 8 : index
+  %f10 = constant 10.00000e+00 : f32
+  %val = constant 2.00000e+00 : f32
+  %zero = constant 0.00000e+00 : f32
+
+  %filter1D_ncw = call @alloc_3d_filled_f32(%c3, %c1, %c1, %val) : (index, index, index, f32) -> (memref<?x?x?xf32>)
+  %in1D_ncw = call @alloc_3d_filled_f32(%c1, %c1, %c8, %val) : (index, index, index, f32) -> (memref<?x?x?xf32>)
+  %out1D_ncw = call @alloc_3d_filled_f32(%c1, %c1, %c6, %zero) : (index, index, index, f32) -> (memref<?x?x?xf32>)
+
+  store %f10, %in1D_ncw[%c0, %c0, %c3] : memref<?x?x?xf32>
+  call @conv_1d_input_ncw_filter_wcf(%in1D_ncw, %filter1D_ncw, %out1D_ncw) : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>) -> ()
+  %out1D_ncw_ = memref_cast %out1D_ncw : memref<?x?x?xf32> to memref<*xf32>
+  call @print_memref_f32(%out1D_ncw_): (memref<*xf32>) -> ()
+
+  dealloc %filter1D_ncw : memref<?x?x?xf32>
+  dealloc %in1D_ncw : memref<?x?x?xf32>
+  dealloc %out1D_ncw : memref<?x?x?xf32>
+  return
+}
+
+// CHECK:       Unranked Memref {{.*}}
+// CHECK-NEXT:  [
+// CHECK-SAME:   [
+// CHECK-SAME:    [12, 28, 28, 28, 12, 12]
+// CHECK-SAME:   ]
+// CHECK-SAME:  ]

diff  --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-input-nwc-filter-wcf-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-input-nwc-filter-wcf-call.mlir
new file mode 100644
index 000000000000..4936e025202b
--- /dev/null
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-input-nwc-filter-wcf-call.mlir
@@ -0,0 +1,81 @@
+// RUN: mlir-opt %s -convert-linalg-to-loops -convert-scf-to-std -convert-linalg-to-llvm -lower-affine -convert-scf-to-std -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,4" -convert-linalg-to-loops -convert-scf-to-std \
+// RUN:   -convert-linalg-to-llvm -lower-affine -convert-scf-to-std -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+// RUN: mlir-opt %s -test-conv-vectorization="tile-sizes=1,1,1,3,3" -convert-linalg-to-llvm -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,4" \
+// RUN:   -test-conv-vectorization="tile-sizes=1,1,1,3,3" -convert-linalg-to-llvm -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+func private @print_memref_f32(memref<*xf32>)
+
+// Creates and returns 3-D buffer of size (%s1, %s2, %s3) filled with the value %f
+func @alloc_3d_filled_f32(%s1 : index, %s2 : index, %s3 : index, %f : f32) -> memref<?x?x?xf32> {
+  %buf = alloc(%s1, %s2, %s3) : memref<?x?x?xf32>
+  linalg.fill(%buf, %f) : memref<?x?x?xf32>, f32
+  return %buf : memref<?x?x?xf32>
+}
+
+func @conv_1d_input_nwc_filter_wcf(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+  linalg.conv_1d_input_nwc_filter_wcf {dilations = dense<1> : tensor<1xi64>,
+                         strides = dense<1> : tensor<1xi64>}
+     ins (%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
+    outs (%arg2: memref<?x?x?xf32>)
+  return
+}
+
+func @main() {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c3 = constant 3 : index
+  %c6 = constant 6 : index
+  %c8 = constant 8 : index
+  %f10 = constant 10.00000e+00 : f32
+  %val = constant 2.00000e+00 : f32
+  %zero = constant 0.00000e+00 : f32
+
+  %filter1D_nwc = call @alloc_3d_filled_f32(%c3, %c1, %c1, %val) : (index, index, index, f32) -> (memref<?x?x?xf32>)
+  %in1D_nwc = call @alloc_3d_filled_f32(%c3, %c8, %c1, %val) : (index, index, index, f32) -> (memref<?x?x?xf32>)
+  %out1D_nwc = call @alloc_3d_filled_f32(%c3, %c6, %c1, %zero) : (index, index, index, f32) -> (memref<?x?x?xf32>)
+
+  store %f10, %in1D_nwc[%c0, %c3, %c0] : memref<?x?x?xf32>
+  call @conv_1d_input_nwc_filter_wcf(%in1D_nwc, %filter1D_nwc, %out1D_nwc) : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>) -> ()
+  %out1D_nwc_ = memref_cast %out1D_nwc : memref<?x?x?xf32> to memref<*xf32>
+  call @print_memref_f32(%out1D_nwc_): (memref<*xf32>) -> ()
+
+  dealloc %filter1D_nwc : memref<?x?x?xf32>
+  dealloc %in1D_nwc : memref<?x?x?xf32>
+  dealloc %out1D_nwc : memref<?x?x?xf32>
+  return
+}
+
+// CHECK:       Unranked Memref {{.*}}
+// CHECK-NEXT:  [
+// CHECK-SAME:   [
+// CHECK-SAME:    [12],
+// CHECK-COUNT-3: [28],
+// CHECK-NEXT:    [12],
+// CHECK-NEXT:    [12]
+// CHECK-SAME:   ],
+// CHECK-NEXT:   [
+// CHECK-COUNT-5: [12],
+// CHECK-NEXT:    [12]
+// CHECK-SAME:   ],
+// CHECK-NEXT:   [
+// CHECK-COUNT-5: [12],
+// CHECK-NEXT:    [12]
+// CHECK-SAME:   ]
+// CHECK-SAME:  ]

diff  --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-input-nchw-filter-hwcf-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-input-nchw-filter-hwcf-call.mlir
new file mode 100644
index 000000000000..58bc537a0194
--- /dev/null
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-input-nchw-filter-hwcf-call.mlir
@@ -0,0 +1,83 @@
+// RUN: mlir-opt %s -convert-linalg-to-loops -convert-scf-to-std -convert-linalg-to-llvm -lower-affine -convert-scf-to-std -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,0,4,4" -convert-linalg-to-loops -convert-scf-to-std \
+// RUN:   -convert-linalg-to-llvm -lower-affine -convert-scf-to-std -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+// RUN: mlir-opt %s -test-conv-vectorization="tile-sizes=1,1,1,1,3,3,3"  -convert-linalg-to-llvm -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,0,4,4" \
+// RUN:   -test-conv-vectorization="tile-sizes=1,1,1,1,3,3,3" -convert-linalg-to-llvm -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+func private @print_memref_f32(memref<*xf32>)
+
+// Creates and returns 4-D buffer of size (%s1, %s2, %s3, %s4) filled with the value %f
+func @alloc_4d_filled_f32(%s1 : index, %s2 : index, %s3 : index, %s4 : index, %f : f32) -> memref<?x?x?x?xf32> {
+  %buf = alloc(%s1, %s2, %s3, %s4) : memref<?x?x?x?xf32>
+  linalg.fill(%buf, %f) : memref<?x?x?x?xf32>, f32
+  return %buf : memref<?x?x?x?xf32>
+}
+
+func @conv_2d_input_nchw_filter_hwcf(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?x?xf32>, %arg2: memref<?x?x?x?xf32>) {
+  linalg.conv_2d_input_nchw_filter_hwcf {dilations = dense<1> : tensor<2xi64>,
+                          strides = dense<1> : tensor<2xi64>}
+     ins (%arg0, %arg1: memref<?x?x?x?xf32>, memref<?x?x?x?xf32>)
+    outs (%arg2: memref<?x?x?x?xf32>)
+  return
+}
+
+func @main() {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c3 = constant 3 : index
+  %c6 = constant 6 : index
+  %c8 = constant 8 : index
+  %f10 = constant 10.00000e+00 : f32
+  %val = constant 2.00000e+00 : f32
+  %zero = constant 0.00000e+00 : f32
+
+  %filter2D_nchw = call @alloc_4d_filled_f32(%c3, %c3, %c1, %c1, %val) : (index, index, index, index, f32) -> (memref<?x?x?x?xf32>)
+  %in2D_nchw = call @alloc_4d_filled_f32(%c3, %c1, %c8, %c8, %val) : (index, index, index, index, f32) -> (memref<?x?x?x?xf32>)
+  %out2D_nchw = call @alloc_4d_filled_f32(%c3, %c1, %c6, %c6, %zero) : (index, index, index, index, f32) -> (memref<?x?x?x?xf32>)
+
+  store %f10, %in2D_nchw[%c0, %c0, %c0, %c3] : memref<?x?x?x?xf32>
+  call @conv_2d_input_nchw_filter_hwcf(%in2D_nchw, %filter2D_nchw, %out2D_nchw) : (memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>) -> ()
+  %out2D_nchw_ = memref_cast %out2D_nchw : memref<?x?x?x?xf32> to memref<*xf32>
+  call @print_memref_f32(%out2D_nchw_): (memref<*xf32>) -> ()
+
+  dealloc %filter2D_nchw : memref<?x?x?x?xf32>
+  dealloc %in2D_nchw : memref<?x?x?x?xf32>
+  dealloc %out2D_nchw : memref<?x?x?x?xf32>
+  return
+}
+
+// CHECK:       Unranked Memref {{.*}}
+// CHECK-NEXT:  [
+// CHECK-SAME:   [
+// CHECK-SAME:    [
+// CHECK-SAME:     [36,     52,     52,     52,     36,     36],
+// CHECK-COUNT-5:  [36,     36,     36,     36,     36,     36]
+// CHECK-SAME:    ]
+// CHECK-SAME:   ],
+// CHECK-NEXT:   [
+// CHECK-SAME:    [
+// CHECK-COUNT-6:  [36,     36,     36,     36,     36,     36]
+// CHECK-SAME:    ]
+// CHECK-SAME:   ],
+// CHECK-NEXT:   [
+// CHECK-SAME:    [
+// CHECK-COUNT-6:  [36,     36,     36,     36,     36,     36]
+// CHECK-SAME:    ]
+// CHECK-SAME:   ]
+// CHECK-SAME:  ]

diff  --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-input-nhwc-filter-hwcf-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-input-nhwc-filter-hwcf-call.mlir
new file mode 100644
index 000000000000..c50c84c63a01
--- /dev/null
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-input-nhwc-filter-hwcf-call.mlir
@@ -0,0 +1,129 @@
+// RUN: mlir-opt %s -convert-linalg-to-loops -convert-scf-to-std -convert-linalg-to-llvm -lower-affine -convert-scf-to-std -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,3,3,2" -convert-linalg-to-loops -convert-scf-to-std \
+// RUN:   -convert-linalg-to-llvm -lower-affine -convert-scf-to-std -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+// RUN: mlir-opt %s -test-conv-vectorization="tile-sizes=1,1,1,1,3,3,3" -convert-linalg-to-llvm -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,3,3,2" \
+// RUN:   -test-conv-vectorization="tile-sizes=1,1,1,1,3,3,3" -convert-linalg-to-llvm -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+func private @print_memref_f32(memref<*xf32>)
+
+// Creates and returns 4-D buffer of size (%s1, %s2, %s3, %s4) filled with the value %f
+func @alloc_4d_filled_f32(%s1 : index, %s2 : index, %s3 : index, %s4 : index, %f : f32) -> memref<?x?x?x?xf32> {
+  %buf = alloc(%s1, %s2, %s3, %s4) : memref<?x?x?x?xf32>
+  linalg.fill(%buf, %f) : memref<?x?x?x?xf32>, f32
+  return %buf : memref<?x?x?x?xf32>
+}
+
+func @conv_2d_input_nhwc_filter_hwcf(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?x?xf32>, %arg2: memref<?x?x?x?xf32>) {
+  linalg.conv_2d_input_nhwc_filter_hwcf {dilations = dense<1> : tensor<2xi64>,
+                          strides = dense<1> : tensor<2xi64>}
+     ins (%arg0, %arg1: memref<?x?x?x?xf32>, memref<?x?x?x?xf32>)
+    outs (%arg2: memref<?x?x?x?xf32>)
+  return
+}
+
+func @main() {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c3 = constant 3 : index
+  %c6 = constant 6 : index
+  %c8 = constant 8 : index
+  %f10 = constant 10.00000e+00 : f32
+  %val = constant 2.00000e+00 : f32
+  %zero = constant 0.00000e+00 : f32
+
+  %filter2D_nhwc = call @alloc_4d_filled_f32(%c3, %c3, %c3, %c1, %val) :(index, index, index, index, f32) -> (memref<?x?x?x?xf32>)
+  %in2D_nhwc = call @alloc_4d_filled_f32(%c3, %c8, %c8, %c3, %val) : (index, index, index, index, f32) -> (memref<?x?x?x?xf32>)
+  %out2D_nhwc = call @alloc_4d_filled_f32(%c3, %c6, %c6, %c1, %zero) : (index, index, index, index, f32) -> (memref<?x?x?x?xf32>)
+
+  store %f10, %in2D_nhwc[%c0, %c0, %c3, %c0] : memref<?x?x?x?xf32>
+  call @conv_2d_input_nhwc_filter_hwcf(%in2D_nhwc, %filter2D_nhwc, %out2D_nhwc) : (memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>) -> ()
+  %out2D_nhwc_ = memref_cast %out2D_nhwc : memref<?x?x?x?xf32> to memref<*xf32>
+  call @print_memref_f32(%out2D_nhwc_): (memref<*xf32>) -> ()
+
+  dealloc %filter2D_nhwc : memref<?x?x?x?xf32>
+  dealloc %in2D_nhwc : memref<?x?x?x?xf32>
+  dealloc %out2D_nhwc : memref<?x?x?x?xf32>
+  return
+}
+
+// CHECK:       Unranked Memref {{.*}}
+// CHECK-NEXT:  [
+// CHECK-SAME:   [
+// CHECK-SAME:    [
+// CHECK-SAME:     [108],
+// CHECK-COUNT-3:  [124],
+// CHECK-COUNT-2:  [108]
+// CHECK-SAME:    ],
+// CHECK-NEXT:    [
+// CHECK-COUNT-6:  [108]
+// CHECK-SAME:    ],
+// CHECK-NEXT:    [
+// CHECK-COUNT-6:  [108]
+// CHECK-SAME:    ],
+// CHECK-NEXT:    [
+// CHECK-COUNT-6:  [108]
+// CHECK-SAME:    ],
+// CHECK-NEXT:    [
+// CHECK-COUNT-6:  [108]
+// CHECK-SAME:    ],
+// CHECK-NEXT:    [
+// CHECK-COUNT-6:  [108]
+// CHECK-SAME:    ]
+// CHECK-SAME:   ],
+// CHECK-NEXT:   [
+// CHECK-SAME:    [
+// CHECK-COUNT-6:  [108]
+// CHECK-SAME:    ],
+// CHECK-NEXT:    [
+// CHECK-COUNT-6:  [108]
+// CHECK-SAME:    ],
+// CHECK-NEXT:    [
+// CHECK-COUNT-6:  [108]
+// CHECK-SAME:    ],
+// CHECK-NEXT:    [
+// CHECK-COUNT-6:  [108]
+// CHECK-SAME:    ],
+// CHECK-NEXT:    [
+// CHECK-COUNT-6:  [108]
+// CHECK-SAME:    ],
+// CHECK-NEXT:    [
+// CHECK-COUNT-6:  [108]
+// CHECK-SAME:    ]
+// CHECK-SAME:   ],
+// CHECK-NEXT:   [
+// CHECK-SAME:    [
+// CHECK-COUNT-6:  [108]
+// CHECK-SAME:    ],
+// CHECK-NEXT:    [
+// CHECK-COUNT-6:  [108]
+// CHECK-SAME:    ],
+// CHECK-NEXT:    [
+// CHECK-COUNT-6:  [108]
+// CHECK-SAME:    ],
+// CHECK-NEXT:    [
+// CHECK-COUNT-6:  [108]
+// CHECK-SAME:    ],
+// CHECK-NEXT:    [
+// CHECK-COUNT-6:  [108]
+// CHECK-SAME:    ],
+// CHECK-NEXT:    [
+// CHECK-COUNT-6:  [108]
+// CHECK-SAME:    ]
+// CHECK-SAME:   ]
+// CHECK-SAME:  ]

diff  --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-input-ncdhw-filter-dhwcf-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-input-ncdhw-filter-dhwcf-call.mlir
new file mode 100644
index 000000000000..9cf45f1bd18c
--- /dev/null
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-input-ncdhw-filter-dhwcf-call.mlir
@@ -0,0 +1,90 @@
+// RUN: mlir-opt %s -convert-linalg-to-loops -convert-scf-to-std -convert-linalg-to-llvm -lower-affine -convert-scf-to-std -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=0,0,5,5,5" -convert-linalg-to-loops -convert-scf-to-std \
+// RUN:   -convert-linalg-to-llvm -lower-affine -convert-scf-to-std -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+// RUN: mlir-opt %s -test-conv-vectorization="tile-sizes=1,1,1,1,1,3,3,3,3" -convert-linalg-to-llvm -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=0,0,5,5,5" \
+// RUN:   -test-conv-vectorization="tile-sizes=1,1,1,1,1,3,3,3,3" -convert-linalg-to-llvm -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+func private @print_memref_f32(memref<*xf32>)
+
+// Creates and returns 5-D buffer of size (%s1, %s2, %s3, %s4, %s5) filled with the value %f
+func @alloc_5d_filled_f32(%s1 : index, %s2 : index, %s3 : index, %s4 : index, %s5 : index, %f : f32) -> memref<?x?x?x?x?xf32> {
+  %buf = alloc(%s1, %s2, %s3, %s4, %s5) : memref<?x?x?x?x?xf32>
+  linalg.fill(%buf, %f) : memref<?x?x?x?x?xf32>, f32
+  return %buf : memref<?x?x?x?x?xf32>
+}
+
+func @conv_3d_input_ncdhw_filter_dhwcf(%arg0: memref<?x?x?x?x?xf32>, %arg1: memref<?x?x?x?x?xf32>, %arg2: memref<?x?x?x?x?xf32>) {
+  linalg.conv_3d_input_ncdhw_filter_dhwcf {dilations = dense<1> : tensor<3xi64>,
+                           strides = dense<1> : tensor<3xi64>}
+     ins (%arg0, %arg1: memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>)
+    outs (%arg2: memref<?x?x?x?x?xf32>)
+  return
+}
+
+func @main() {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c3 = constant 3 : index
+  %c6 = constant 6 : index
+  %c8 = constant 8 : index
+  %f10 = constant 10.00000e+00 : f32
+  %val = constant 2.00000e+00 : f32
+  %zero = constant 0.00000e+00 : f32
+
+  %filter3D_ncdhw = call @alloc_5d_filled_f32(%c3, %c3, %c3, %c1, %c1, %val) : (index, index, index, index, index, f32) -> (memref<?x?x?x?x?xf32>)
+  %in3D_ncdhw = call @alloc_5d_filled_f32(%c1, %c1, %c8, %c8, %c8, %val) : (index, index, index, index, index, f32) -> (memref<?x?x?x?x?xf32>)
+  %out3D_ncdhw = call @alloc_5d_filled_f32(%c1, %c1, %c6, %c6, %c6, %zero) : (index, index, index, index, index, f32) -> (memref<?x?x?x?x?xf32>)
+
+  store %f10, %in3D_ncdhw[%c0, %c0, %c0, %c0, %c3] : memref<?x?x?x?x?xf32>
+  call @conv_3d_input_ncdhw_filter_dhwcf(%in3D_ncdhw, %filter3D_ncdhw, %out3D_ncdhw) : (memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>) -> ()
+  %out3D_ncdhw_ = memref_cast %out3D_ncdhw : memref<?x?x?x?x?xf32> to memref<*xf32>
+  call @print_memref_f32(%out3D_ncdhw_): (memref<*xf32>) -> ()
+
+  dealloc %filter3D_ncdhw : memref<?x?x?x?x?xf32>
+  dealloc %in3D_ncdhw : memref<?x?x?x?x?xf32>
+  dealloc %out3D_ncdhw : memref<?x?x?x?x?xf32>
+  return
+}
+
+// CHECK:       Unranked Memref {{.*}}
+// CHECK-NEXT:  [
+// CHECK-SAME:   [
+// CHECK-SAME:    [
+// CHECK-SAME:     [
+// CHECK-SAME:      [108,      124,      124,      124,      108,      108],
+// CHECK-COUNT-5:   [108,      108,      108,      108,      108,      108]
+// CHECK-SAME:     ],
+// CHECK-NEXT:     [
+// CHECK-COUNT-6:   [108,      108,      108,      108,      108,      108]
+// CHECK-SAME:     ],
+// CHECK-NEXT:     [
+// CHECK-COUNT-6:   [108,      108,      108,      108,      108,      108]
+// CHECK-SAME:     ],
+// CHECK-NEXT:     [
+// CHECK-COUNT-6:   [108,      108,      108,      108,      108,      108]
+// CHECK-SAME:     ],
+// CHECK-NEXT:     [
+// CHECK-COUNT-6:   [108,      108,      108,      108,      108,      108]
+// CHECK-SAME:     ],
+// CHECK-NEXT:     [
+// CHECK-COUNT-6:   [108,      108,      108,      108,      108,      108]
+// CHECK-SAME:     ]
+// CHECK-SAME:    ]
+// CHECK-SAME:   ]
+// CHECK-SAME:  ]

diff  --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-input-ndhwc-filter-dhwcf-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-input-ndhwc-filter-dhwcf-call.mlir
new file mode 100644
index 000000000000..e425b9a132f4
--- /dev/null
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-input-ndhwc-filter-dhwcf-call.mlir
@@ -0,0 +1,192 @@
+// RUN: mlir-opt %s -convert-linalg-to-loops -convert-scf-to-std -convert-linalg-to-llvm -lower-affine -convert-scf-to-std -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=0,5,5,5" -convert-linalg-to-loops -convert-scf-to-std \
+// RUN:   -convert-linalg-to-llvm -lower-affine -convert-scf-to-std -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+// RUN: mlir-opt %s -test-conv-vectorization="tile-sizes=1,1,1,1,1,3,3,3,3" -convert-linalg-to-llvm -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=0,5,5,5" \
+// RUN:   -test-conv-vectorization="tile-sizes=1,1,1,1,1,3,3,3,3" -convert-linalg-to-llvm -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+// Declaration only; presumably resolved at run time from libmlir_runner_utils, which the RUN lines above pass via -shared-libs.
+func private @print_memref_f32(memref<*xf32>)
+
+// Allocates a memref<?x?x?x?x?xf32> of size (%s1, %s2, %s3, %s4, %s5) and fills every element with %f; the caller owns the result (@main deallocs it).
+func @alloc_5d_filled_f32(%s1 : index, %s2 : index, %s3 : index, %s4 : index, %s5 : index, %f : f32) -> memref<?x?x?x?x?xf32> {
+  %buf = alloc(%s1, %s2, %s3, %s4, %s5) : memref<?x?x?x?x?xf32>
+  linalg.fill(%buf, %f) : memref<?x?x?x?x?xf32>, f32
+  return %buf : memref<?x?x?x?x?xf32>
+}
+// Runs the op under test once with unit dilations and strides: ins are (input, filter), outs is the output; all three are dynamically shaped 5-D memrefs.
+func @conv_3d_input_ndhwc_filter_dhwcf(%arg0: memref<?x?x?x?x?xf32>, %arg1: memref<?x?x?x?x?xf32>, %arg2: memref<?x?x?x?x?xf32>) {
+  linalg.conv_3d_input_ndhwc_filter_dhwcf {dilations = dense<1> : tensor<3xi64>,
+                           strides = dense<1> : tensor<3xi64>}
+     ins (%arg0, %arg1: memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>)
+    outs (%arg2: memref<?x?x?x?x?xf32>)
+  return
+}
+
+// Convolves an all-2.0 (1, 8, 8, 8, 1) input with an all-2.0 (3, 3, 3, 1, 1) filter into a zero-filled (1, 6, 6, 6, 1) output and prints the result.
+func @main() {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c3 = constant 3 : index
+  %c6 = constant 6 : index
+  %c8 = constant 8 : index
+  %f10 = constant 10.00000e+00 : f32
+  %val = constant 2.00000e+00 : f32
+  %zero = constant 0.00000e+00 : f32
+
+  %filter3D_ndhwc = call @alloc_5d_filled_f32(%c3, %c3, %c3, %c1, %c1, %val) : (index, index, index, index, index, f32) -> (memref<?x?x?x?x?xf32>)
+  %in3D_ndhwc = call @alloc_5d_filled_f32(%c1, %c8, %c8, %c8, %c1, %val) : (index, index, index, index, index, f32) -> (memref<?x?x?x?x?xf32>)
+  %out3D_ndhwc = call @alloc_5d_filled_f32(%c1, %c6, %c6, %c6, %c1, %zero) : (index, index, index, index, index, f32) -> (memref<?x?x?x?x?xf32>)
+  // Each output sums 27 taps of 2 * 2 = 108; storing 10.0 at input (d=0, h=0, w=3) raises outputs (d=0, h=0, w=1..3) to 108 - 4 + 20 = 124.
+  store %f10, %in3D_ndhwc[%c0, %c0, %c0, %c3, %c0] : memref<?x?x?x?x?xf32>
+  call @conv_3d_input_ndhwc_filter_dhwcf(%in3D_ndhwc, %filter3D_ndhwc, %out3D_ndhwc) : (memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>) -> ()
+  %out3D_ndhwc_ = memref_cast %out3D_ndhwc : memref<?x?x?x?x?xf32> to memref<*xf32>
+  call @print_memref_f32(%out3D_ndhwc_): (memref<*xf32>) -> ()
+
+  dealloc %filter3D_ndhwc : memref<?x?x?x?x?xf32>
+  dealloc %in3D_ndhwc : memref<?x?x?x?x?xf32>
+  dealloc %out3D_ndhwc : memref<?x?x?x?x?xf32>
+  return
+}
+
+// CHECK:       Unranked Memref {{.*}}
+// CHECK-NEXT:  [
+// CHECK-SAME:   [
+// CHECK-SAME:    [
+// CHECK-SAME:     [
+// CHECK-SAME:      [108],
+// CHECK-COUNT-3:   [124],
+// CHECK-COUNT-2:   [108]
+// CHECK-SAME:     ],
+// CHECK-NEXT:     [
+// CHECK-COUNT-6:   [108]
+// CHECK-SAME:     ],
+// CHECK-NEXT:     [
+// CHECK-COUNT-6:   [108]
+// CHECK-SAME:     ],
+// CHECK-NEXT:     [
+// CHECK-COUNT-6:   [108]
+// CHECK-SAME:     ],
+// CHECK-NEXT:     [
+// CHECK-COUNT-6:   [108]
+// CHECK-SAME:     ],
+// CHECK-NEXT:     [
+// CHECK-COUNT-6:   [108]
+// CHECK-SAME:     ]
+// CHECK-SAME:    ],
+// CHECK-NEXT:    [
+// CHECK-SAME:     [
+// CHECK-COUNT-6:   [108]
+// CHECK-SAME:     ],
+// CHECK-NEXT:     [
+// CHECK-COUNT-6:   [108]
+// CHECK-SAME:     ],
+// CHECK-NEXT:     [
+// CHECK-COUNT-6:   [108]
+// CHECK-SAME:     ],
+// CHECK-NEXT:     [
+// CHECK-COUNT-6:   [108]
+// CHECK-SAME:     ],
+// CHECK-NEXT:     [
+// CHECK-COUNT-6:   [108]
+// CHECK-SAME:     ],
+// CHECK-NEXT:     [
+// CHECK-COUNT-6:   [108]
+// CHECK-SAME:     ]
+// CHECK-SAME:    ],
+// CHECK-NEXT:    [
+// CHECK-SAME:     [
+// CHECK-COUNT-6:   [108]
+// CHECK-SAME:     ],
+// CHECK-NEXT:     [
+// CHECK-COUNT-6:   [108]
+// CHECK-SAME:     ],
+// CHECK-NEXT:     [
+// CHECK-COUNT-6:   [108]
+// CHECK-SAME:     ],
+// CHECK-NEXT:     [
+// CHECK-COUNT-6:   [108]
+// CHECK-SAME:     ],
+// CHECK-NEXT:     [
+// CHECK-COUNT-6:   [108]
+// CHECK-SAME:     ],
+// CHECK-NEXT:     [
+// CHECK-COUNT-6:   [108]
+// CHECK-SAME:     ]
+// CHECK-SAME:    ],
+// CHECK-NEXT:    [
+// CHECK-SAME:     [
+// CHECK-COUNT-6:   [108]
+// CHECK-SAME:     ],
+// CHECK-NEXT:     [
+// CHECK-COUNT-6:   [108]
+// CHECK-SAME:     ],
+// CHECK-NEXT:     [
+// CHECK-COUNT-6:   [108]
+// CHECK-SAME:     ],
+// CHECK-NEXT:     [
+// CHECK-COUNT-6:   [108]
+// CHECK-SAME:     ],
+// CHECK-NEXT:     [
+// CHECK-COUNT-6:   [108]
+// CHECK-SAME:     ],
+// CHECK-NEXT:     [
+// CHECK-COUNT-6:   [108]
+// CHECK-SAME:     ]
+// CHECK-SAME:    ],
+// CHECK-NEXT:    [
+// CHECK-SAME:     [
+// CHECK-COUNT-6:   [108]
+// CHECK-SAME:     ],
+// CHECK-NEXT:     [
+// CHECK-COUNT-6:   [108]
+// CHECK-SAME:     ],
+// CHECK-NEXT:     [
+// CHECK-COUNT-6:   [108]
+// CHECK-SAME:     ],
+// CHECK-NEXT:     [
+// CHECK-COUNT-6:   [108]
+// CHECK-SAME:     ],
+// CHECK-NEXT:     [
+// CHECK-COUNT-6:   [108]
+// CHECK-SAME:     ],
+// CHECK-NEXT:     [
+// CHECK-COUNT-6:   [108]
+// CHECK-SAME:     ]
+// CHECK-SAME:    ],
+// CHECK-NEXT:    [
+// CHECK-SAME:     [
+// CHECK-COUNT-6:   [108]
+// CHECK-SAME:     ],
+// CHECK-NEXT:     [
+// CHECK-COUNT-6:   [108]
+// CHECK-SAME:     ],
+// CHECK-NEXT:     [
+// CHECK-COUNT-6:   [108]
+// CHECK-SAME:     ],
+// CHECK-NEXT:     [
+// CHECK-COUNT-6:   [108]
+// CHECK-SAME:     ],
+// CHECK-NEXT:     [
+// CHECK-COUNT-6:   [108]
+// CHECK-SAME:     ],
+// CHECK-NEXT:     [
+// CHECK-COUNT-6:   [108]
+// CHECK-SAME:     ]
+// CHECK-SAME:    ]
+// CHECK-SAME:   ]
+// CHECK-SAME:  ]

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 86f05c38ed89..49d323aebe92 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -622,27 +622,39 @@ void mlir::linalg::populateConvVectorizationPatterns(
 
   populateVectorizationPatterns<ConvNWCOp, 3>(tiling, promotion, vectorization,
                                               tileSizes, context);
+  populateVectorizationPatterns<ConvInputNWCFilterWCFOp, 3>(
+      tiling, promotion, vectorization, tileSizes, context);
 
   populateVectorizationPatterns<ConvNCWOp, 3>(tiling, promotion, vectorization,
                                               tileSizes, context);
+  populateVectorizationPatterns<ConvInputNCWFilterWCFOp, 3>(
+      tiling, promotion, vectorization, tileSizes, context);
 
   populateVectorizationPatterns<ConvHWOp, 2>(tiling, promotion, vectorization,
                                              tileSizes, context);
 
   populateVectorizationPatterns<ConvNHWCOp, 4>(tiling, promotion, vectorization,
                                                tileSizes, context);
+  populateVectorizationPatterns<ConvInputNHWCFilterHWCFOp, 4>(
+      tiling, promotion, vectorization, tileSizes, context);
 
   populateVectorizationPatterns<ConvNCHWOp, 4>(tiling, promotion, vectorization,
                                                tileSizes, context);
+  populateVectorizationPatterns<ConvInputNCHWFilterHWCFOp, 4>(
+      tiling, promotion, vectorization, tileSizes, context);
 
   populateVectorizationPatterns<ConvDHWOp, 3>(tiling, promotion, vectorization,
                                               tileSizes, context);
 
   populateVectorizationPatterns<ConvNDHWCOp, 5>(
       tiling, promotion, vectorization, tileSizes, context);
+  populateVectorizationPatterns<ConvInputNDHWCFilterDHWCFOp, 5>(
+      tiling, promotion, vectorization, tileSizes, context);
 
   populateVectorizationPatterns<ConvNCDHWOp, 5>(
       tiling, promotion, vectorization, tileSizes, context);
+  populateVectorizationPatterns<ConvInputNCDHWFilterDHWCFOp, 5>(
+      tiling, promotion, vectorization, tileSizes, context);
 
   patterns.push_back(std::move(tiling));
   patterns.push_back(std::move(promotion));

diff  --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index 44207308ab00..74daf57e2584 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -99,3 +99,163 @@ func @depthwise_conv_2d_input_nhwc_filter_hwc(%input: memref<1x113x113x96xf32>,
 // CHECK-NEXT:      %[[MUL:.+]] = mulf %[[BBARG0]], %[[BBARG1]] : f32
 // CHECK-NEXT:      %[[ADD:.+]] = addf %[[BBARG2]], %[[MUL]] : f32
 // CHECK-NEXT:      linalg.yield %[[ADD]] : f32
+
+// -----
+// Verifies that conv_1d_input_nwc_filter_wcf generalizes to a linalg.generic with the indexing maps and iterator types below and a mulf/addf accumulate body.
+func @conv_1d_input_nwc_filter_wcf(%input: memref<?x?x?xf32>, %filter: memref<?x?x?xf32>, %output: memref<?x?x?xf32>) {
+  linalg.conv_1d_input_nwc_filter_wcf {dilations = dense<1> : tensor<1xi64>,
+                                       strides = dense<1> : tensor<1xi64>}
+     ins (%input, %filter: memref<?x?x?xf32>, memref<?x?x?xf32>)
+    outs (%output: memref<?x?x?xf32>)
+  return
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d3, d4)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4, d2)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+// NOTE(review): d4 (input channel: last dim of the input, middle dim of the filter) is absent from the output map, so it is summed over, yet iterator_types below marks it "parallel"; presumably it should be "reduction". Confirm against the op definition.
+// CHECK: func @conv_1d_input_nwc_filter_wcf
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]}
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?xf32>, memref<?x?x?xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<?x?x?xf32>)
+
+// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
+// CHECK-NEXT:      %[[MUL:.+]] = mulf %[[BBARG0]], %[[BBARG1]] : f32
+// CHECK-NEXT:      %[[ADD:.+]] = addf %[[BBARG2]], %[[MUL]] : f32
+// CHECK-NEXT:      linalg.yield %[[ADD]] : f32
+
+// -----
+// Verifies that conv_1d_input_ncw_filter_wcf generalizes to a linalg.generic with the indexing maps and iterator types below and a mulf/addf accumulate body.
+func @conv_1d_input_ncw_filter_wcf(%input: memref<?x?x?xf32>, %filter: memref<?x?x?xf32>, %output: memref<?x?x?xf32>) {
+  linalg.conv_1d_input_ncw_filter_wcf {dilations = dense<1> : tensor<1xi64>,
+                                       strides = dense<1> : tensor<1xi64>}
+     ins (%input, %filter: memref<?x?x?xf32>, memref<?x?x?xf32>)
+    outs (%output: memref<?x?x?xf32>)
+  return
+}
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d2 + d3)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4, d1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+// NOTE(review): d4 (input channel: second dim of both the input and the filter) is absent from the output map, so it is summed over, yet iterator_types below marks it "parallel"; presumably it should be "reduction". Confirm against the op definition.
+// CHECK: func @conv_1d_input_ncw_filter_wcf
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel"]}
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?xf32>, memref<?x?x?xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<?x?x?xf32>)
+
+// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
+// CHECK-NEXT:      %[[MUL:.+]] = mulf %[[BBARG0]], %[[BBARG1]] : f32
+// CHECK-NEXT:      %[[ADD:.+]] = addf %[[BBARG2]], %[[MUL]] : f32
+// CHECK-NEXT:      linalg.yield %[[ADD]] : f32
+
+// -----
+// Verifies that conv_2d_input_nhwc_filter_hwcf generalizes to a linalg.generic; strides 3 and dilations 2 appear in the input map as d1 * 3 + d4 * 2 and d2 * 3 + d5 * 2.
+func @conv_2d_input_nhwc_filter_hwcf(%input: memref<?x?x?x?xf32>, %filter: memref<?x?x?x?xf32>, %output: memref<?x?x?x?xf32>) {
+  linalg.conv_2d_input_nhwc_filter_hwcf {dilations = dense<2> : tensor<2xi64>,
+                                         strides = dense<3> : tensor<2xi64>}
+     ins (%input, %filter: memref<?x?x?x?xf32>, memref<?x?x?x?xf32>)
+    outs (%output: memref<?x?x?x?xf32>)
+  return
+}
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 * 3 + d4 * 2, d2 * 3 + d5 * 2, d6)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+// NOTE(review): d6 (input channel: last dim of the input, third dim of the filter) is absent from the output map, so it is summed over, yet iterator_types below marks it "parallel"; presumably it should be "reduction". Confirm against the op definition.
+// CHECK: func @conv_2d_input_nhwc_filter_hwcf
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]}
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?x?xf32>, memref<?x?x?x?xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<?x?x?x?xf32>)
+
+// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
+// CHECK-NEXT:      %[[MUL:.+]] = mulf %[[BBARG0]], %[[BBARG1]] : f32
+// CHECK-NEXT:      %[[ADD:.+]] = addf %[[BBARG2]], %[[MUL]] : f32
+// CHECK-NEXT:      linalg.yield %[[ADD]] : f32
+
+// -----
+// Verifies that conv_2d_input_nchw_filter_hwcf generalizes to a linalg.generic with the indexing maps and iterator types below and a mulf/addf accumulate body.
+func @conv_2d_input_nchw_filter_hwcf(%input: memref<?x?x?x?xf32>, %filter: memref<?x?x?x?xf32>, %output: memref<?x?x?x?xf32>) {
+  linalg.conv_2d_input_nchw_filter_hwcf {dilations = dense<1> : tensor<2xi64>,
+                                         strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: memref<?x?x?x?xf32>, memref<?x?x?x?xf32>)
+    outs (%output: memref<?x?x?x?xf32>)
+  return
+}
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d6, d2 + d4, d3 + d5)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+// NOTE(review): d6 (input channel: second dim of the input, third dim of the filter) is absent from the output map, so it is summed over, yet iterator_types below marks it "parallel"; presumably it should be "reduction". Confirm against the op definition.
+// CHECK: func @conv_2d_input_nchw_filter_hwcf
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]}
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?x?xf32>, memref<?x?x?x?xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<?x?x?x?xf32>)
+
+// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
+// CHECK-NEXT:      %[[MUL:.+]] = mulf %[[BBARG0]], %[[BBARG1]] : f32
+// CHECK-NEXT:      %[[ADD:.+]] = addf %[[BBARG2]], %[[MUL]] : f32
+// CHECK-NEXT:      linalg.yield %[[ADD]] : f32
+
+// -----
+// Verifies that conv_3d_input_ndhwc_filter_dhwcf generalizes to a linalg.generic with the indexing maps and iterator types below and a mulf/addf accumulate body.
+func @conv_3d_input_ndhwc_filter_dhwcf(%input: memref<?x?x?x?x?xf32>, %filter: memref<?x?x?x?x?xf32>, %output: memref<?x?x?x?x?xf32>) {
+  linalg.conv_3d_input_ndhwc_filter_dhwcf {dilations = dense<1> : tensor<3xi64>,
+                                           strides = dense<1> : tensor<3xi64>}
+     ins (%input, %filter: memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>)
+    outs (%output: memref<?x?x?x?x?xf32>)
+  return
+}
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>
+// NOTE(review): d8 (input channel: last dim of the input, fourth dim of the filter) is absent from the output map, so it is summed over, yet iterator_types below marks it "parallel"; presumably it should be "reduction". Confirm against the op definition.
+// CHECK: func @conv_3d_input_ndhwc_filter_dhwcf
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction", "parallel"]}
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<?x?x?x?x?xf32>)
+
+// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
+// CHECK-NEXT:      %[[MUL:.+]] = mulf %[[BBARG0]], %[[BBARG1]] : f32
+// CHECK-NEXT:      %[[ADD:.+]] = addf %[[BBARG2]], %[[MUL]] : f32
+// CHECK-NEXT:      linalg.yield %[[ADD]] : f32
+
+// -----
+// Verifies that conv_3d_input_ncdhw_filter_dhwcf generalizes to a linalg.generic with the indexing maps and iterator types below and a mulf/addf accumulate body.
+func @conv_3d_input_ncdhw_filter_dhwcf(%input: memref<?x?x?x?x?xf32>, %filter: memref<?x?x?x?x?xf32>, %output: memref<?x?x?x?x?xf32>) {
+  linalg.conv_3d_input_ncdhw_filter_dhwcf {dilations = dense<1> : tensor<3xi64>,
+                                           strides = dense<1> : tensor<3xi64>}
+     ins (%input, %filter: memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>)
+    outs (%output: memref<?x?x?x?x?xf32>)
+  return
+}
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d8, d2 + d5, d3 + d6, d4 + d7)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>
+// NOTE(review): d8 (input channel: second dim of the input, fourth dim of the filter) is absent from the output map, so it is summed over, yet iterator_types below marks it "parallel"; presumably it should be "reduction". Confirm against the op definition.
+// CHECK: func @conv_3d_input_ncdhw_filter_dhwcf
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction", "parallel"]}
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<?x?x?x?x?xf32>)
+
+// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
+// CHECK-NEXT:      %[[MUL:.+]] = mulf %[[BBARG0]], %[[BBARG1]] : f32
+// CHECK-NEXT:      %[[ADD:.+]] = addf %[[BBARG2]], %[[MUL]] : f32
+// CHECK-NEXT:      linalg.yield %[[ADD]] : f32

diff  --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index f4a8fa94a7e0..b68c0ad2591d 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -54,3 +54,195 @@ func @depthwise_conv_2d_input_nhwc_filter_wrong_stride_size(%input: memref<1x113
     outs(%output: memref<1x56x56x96xf32>)
   return
 }
+
+// -----
+// Verifies the tensor form of conv_1d_input_nwc_filter_wcf (init passed via outs, one tensor result) prints back with its dilations, strides, operand types and result type.
+// CHECK-LABEL: func @conv_1d_input_nwc_filter_wcf
+func @conv_1d_input_nwc_filter_wcf(%input: tensor<?x?x?xf32>, %filter: tensor<?x?x?xf32>, %init: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  // CHECK:      %{{.+}} = linalg.conv_1d_input_nwc_filter_wcf
+  // CHECK-SAME:   dilations = dense<1> : tensor<1xi64>
+  // CHECK-SAME:   strides = dense<1> : tensor<1xi64>
+  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+  // CHECK-SAME:   outs(%{{.+}} : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  %0 = linalg.conv_1d_input_nwc_filter_wcf {dilations = dense<1> : tensor<1xi64>,
+                                            strides = dense<1> : tensor<1xi64>}
+     ins (%input, %filter: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+    outs (%init: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+
+// -----
+// Verifies the memref form of conv_1d_input_nwc_filter_wcf (no results) prints back with its dilations, strides and operand types.
+// CHECK-LABEL: func @conv_1d_input_nwc_filter_wcf
+func @conv_1d_input_nwc_filter_wcf(%input: memref<?x?x?xf32>, %filter: memref<?x?x?xf32>, %output: memref<?x?x?xf32>) {
+  // CHECK:      linalg.conv_1d_input_nwc_filter_wcf
+  // CHECK-SAME:   dilations = dense<1> : tensor<1xi64>
+  // CHECK-SAME:   strides = dense<1> : tensor<1xi64>
+  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : memref<?x?x?xf32>, memref<?x?x?xf32>)
+  // CHECK-SAME:   outs(%{{.+}} : memref<?x?x?xf32>)
+  linalg.conv_1d_input_nwc_filter_wcf {dilations = dense<1> : tensor<1xi64>,
+                                       strides = dense<1> : tensor<1xi64>}
+     ins (%input, %filter: memref<?x?x?xf32>, memref<?x?x?xf32>)
+    outs (%output: memref<?x?x?xf32>)
+  return
+}
+
+// -----
+// Verifies the tensor form of conv_1d_input_ncw_filter_wcf (init passed via outs, one tensor result) prints back with its dilations, strides, operand types and result type.
+// CHECK-LABEL: func @conv_1d_input_ncw_filter_wcf
+func @conv_1d_input_ncw_filter_wcf(%input: tensor<?x?x?xf32>, %filter: tensor<?x?x?xf32>, %init: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  // CHECK:      %{{.+}} = linalg.conv_1d_input_ncw_filter_wcf
+  // CHECK-SAME:   dilations = dense<1> : tensor<1xi64>
+  // CHECK-SAME:   strides = dense<1> : tensor<1xi64>
+  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+  // CHECK-SAME:   outs(%{{.+}} : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  %0 = linalg.conv_1d_input_ncw_filter_wcf {dilations = dense<1> : tensor<1xi64>,
+                                            strides = dense<1> : tensor<1xi64>}
+     ins (%input, %filter: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+    outs (%init: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+
+// -----
+// Verifies the memref form of conv_1d_input_ncw_filter_wcf (no results) prints back with its dilations, strides and operand types.
+// CHECK-LABEL: func @conv_1d_input_ncw_filter_wcf
+func @conv_1d_input_ncw_filter_wcf(%input: memref<?x?x?xf32>, %filter: memref<?x?x?xf32>, %output: memref<?x?x?xf32>) {
+  // CHECK:      linalg.conv_1d_input_ncw_filter_wcf
+  // CHECK-SAME:   dilations = dense<1> : tensor<1xi64>
+  // CHECK-SAME:   strides = dense<1> : tensor<1xi64>
+  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : memref<?x?x?xf32>, memref<?x?x?xf32>)
+  // CHECK-SAME:   outs(%{{.+}} : memref<?x?x?xf32>)
+  linalg.conv_1d_input_ncw_filter_wcf {dilations = dense<1> : tensor<1xi64>,
+                                       strides = dense<1> : tensor<1xi64>}
+     ins (%input, %filter: memref<?x?x?xf32>, memref<?x?x?xf32>)
+    outs (%output: memref<?x?x?xf32>)
+  return
+}
+
+// -----
+// Verifies the tensor form of conv_2d_input_nhwc_filter_hwcf (init passed via outs, one tensor result) prints back with its dilations, strides, operand types and result type.
+// CHECK-LABEL: func @conv_2d_input_nhwc_filter_hwcf
+func @conv_2d_input_nhwc_filter_hwcf(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  // CHECK:      %{{.+}} = linalg.conv_2d_input_nhwc_filter_hwcf
+  // CHECK-SAME:   dilations = dense<1> : tensor<2xi64>
+  // CHECK-SAME:   strides = dense<1> : tensor<2xi64>
+  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+  // CHECK-SAME:   outs(%{{.+}} : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  %0 = linalg.conv_2d_input_nhwc_filter_hwcf {dilations = dense<1> : tensor<2xi64>,
+                                              strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+    outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+
+// -----
+// Verifies the memref form of conv_2d_input_nhwc_filter_hwcf (no results) prints back with its dilations, strides and operand types.
+// CHECK-LABEL: func @conv_2d_input_nhwc_filter_hwcf
+func @conv_2d_input_nhwc_filter_hwcf(%input: memref<?x?x?x?xf32>, %filter: memref<?x?x?x?xf32>, %output: memref<?x?x?x?xf32>) {
+  // CHECK:      linalg.conv_2d_input_nhwc_filter_hwcf
+  // CHECK-SAME:   dilations = dense<1> : tensor<2xi64>
+  // CHECK-SAME:   strides = dense<1> : tensor<2xi64>
+  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : memref<?x?x?x?xf32>, memref<?x?x?x?xf32>)
+  // CHECK-SAME:   outs(%{{.+}} : memref<?x?x?x?xf32>)
+  linalg.conv_2d_input_nhwc_filter_hwcf {dilations = dense<1> : tensor<2xi64>,
+                                         strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: memref<?x?x?x?xf32>, memref<?x?x?x?xf32>)
+    outs (%output: memref<?x?x?x?xf32>)
+  return
+}
+
+// -----
+// Verifies the tensor form of conv_2d_input_nchw_filter_hwcf (init passed via outs, one tensor result) prints back with its dilations, strides, operand types and result type.
+// CHECK-LABEL: func @conv_2d_input_nchw_filter_hwcf
+func @conv_2d_input_nchw_filter_hwcf(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  // CHECK:      %{{.+}} = linalg.conv_2d_input_nchw_filter_hwcf
+  // CHECK-SAME:   dilations = dense<1> : tensor<2xi64>
+  // CHECK-SAME:   strides = dense<1> : tensor<2xi64>
+  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+  // CHECK-SAME:   outs(%{{.+}} : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  %0 = linalg.conv_2d_input_nchw_filter_hwcf {dilations = dense<1> : tensor<2xi64>,
+                                              strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+    outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+
+// -----
+// Verifies the memref form of conv_2d_input_nchw_filter_hwcf (no results) prints back with its dilations, strides and operand types.
+// CHECK-LABEL: func @conv_2d_input_nchw_filter_hwcf
+func @conv_2d_input_nchw_filter_hwcf(%input: memref<?x?x?x?xf32>, %filter: memref<?x?x?x?xf32>, %output: memref<?x?x?x?xf32>) {
+  // CHECK:      linalg.conv_2d_input_nchw_filter_hwcf
+  // CHECK-SAME:   dilations = dense<1> : tensor<2xi64>
+  // CHECK-SAME:   strides = dense<1> : tensor<2xi64>
+  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : memref<?x?x?x?xf32>, memref<?x?x?x?xf32>)
+  // CHECK-SAME:   outs(%{{.+}} : memref<?x?x?x?xf32>)
+  linalg.conv_2d_input_nchw_filter_hwcf {dilations = dense<1> : tensor<2xi64>,
+                                         strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: memref<?x?x?x?xf32>, memref<?x?x?x?xf32>)
+    outs (%output: memref<?x?x?x?xf32>)
+  return
+}
+
+// -----
+// Verifies the tensor form of conv_3d_input_ndhwc_filter_dhwcf (init passed via outs, one tensor result) prints back with its dilations, strides, operand types and result type.
+// CHECK-LABEL: func @conv_3d_input_ndhwc_filter_dhwcf
+func @conv_3d_input_ndhwc_filter_dhwcf(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+  // CHECK:      %{{.+}} = linalg.conv_3d_input_ndhwc_filter_dhwcf
+  // CHECK-SAME:   dilations = dense<1> : tensor<3xi64>
+  // CHECK-SAME:   strides = dense<1> : tensor<3xi64>
+  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
+  // CHECK-SAME:   outs(%{{.+}} : tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  %0 = linalg.conv_3d_input_ndhwc_filter_dhwcf {dilations = dense<1> : tensor<3xi64>,
+                                                strides = dense<1> : tensor<3xi64>}
+     ins (%input, %filter: tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
+    outs (%init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?x?xf32>
+}
+
+// -----
+// Verifies the memref form of conv_3d_input_ndhwc_filter_dhwcf (no results) prints back with its dilations, strides and operand types.
+// CHECK-LABEL: func @conv_3d_input_ndhwc_filter_dhwcf
+func @conv_3d_input_ndhwc_filter_dhwcf(%input: memref<?x?x?x?x?xf32>, %filter: memref<?x?x?x?x?xf32>, %output: memref<?x?x?x?x?xf32>) {
+  // CHECK:      linalg.conv_3d_input_ndhwc_filter_dhwcf
+  // CHECK-SAME:   dilations = dense<1> : tensor<3xi64>
+  // CHECK-SAME:   strides = dense<1> : tensor<3xi64>
+  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>)
+  // CHECK-SAME:   outs(%{{.+}} : memref<?x?x?x?x?xf32>)
+  linalg.conv_3d_input_ndhwc_filter_dhwcf {dilations = dense<1> : tensor<3xi64>,
+                                           strides = dense<1> : tensor<3xi64>}
+     ins (%input, %filter: memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>)
+    outs (%output: memref<?x?x?x?x?xf32>)
+  return
+}
+
+// -----
+// Verifies the tensor form of conv_3d_input_ncdhw_filter_dhwcf (init passed via outs, one tensor result) prints back with its dilations, strides, operand types and result type.
+// CHECK-LABEL: func @conv_3d_input_ncdhw_filter_dhwcf
+func @conv_3d_input_ncdhw_filter_dhwcf(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+  // CHECK:      %{{.+}} = linalg.conv_3d_input_ncdhw_filter_dhwcf
+  // CHECK-SAME:   dilations = dense<1> : tensor<3xi64>
+  // CHECK-SAME:   strides = dense<1> : tensor<3xi64>
+  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
+  // CHECK-SAME:   outs(%{{.+}} : tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  %0 = linalg.conv_3d_input_ncdhw_filter_dhwcf {dilations = dense<1> : tensor<3xi64>,
+                                                strides = dense<1> : tensor<3xi64>}
+     ins (%input, %filter: tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
+    outs (%init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?x?xf32>
+}
+
+// -----
+// Verifies the memref form of conv_3d_input_ncdhw_filter_dhwcf (no results) prints back with its dilations, strides and operand types.
+// CHECK-LABEL: func @conv_3d_input_ncdhw_filter_dhwcf
+func @conv_3d_input_ncdhw_filter_dhwcf(%input: memref<?x?x?x?x?xf32>, %filter: memref<?x?x?x?x?xf32>, %output: memref<?x?x?x?x?xf32>) {
+  // CHECK:      linalg.conv_3d_input_ncdhw_filter_dhwcf
+  // CHECK-SAME:   dilations = dense<1> : tensor<3xi64>
+  // CHECK-SAME:   strides = dense<1> : tensor<3xi64>
+  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>)
+  // CHECK-SAME:   outs(%{{.+}} : memref<?x?x?x?x?xf32>)
+  linalg.conv_3d_input_ncdhw_filter_dhwcf {dilations = dense<1> : tensor<3xi64>,
+                                           strides = dense<1> : tensor<3xi64>}
+     ins (%input, %filter: memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>)
+    outs (%output: memref<?x?x?x?x?xf32>)
+  return
+}


        


More information about the Mlir-commits mailing list