[Mlir-commits] [mlir] f67171a - [mlir][Linalg] Make depthwise convolution naming scheme consistent.
Nicolas Vasilache
llvmlistbot at llvm.org
Mon Nov 15 00:01:40 PST 2021
Author: Nicolas Vasilache
Date: 2021-11-15T07:54:29Z
New Revision: f67171ac5896d232b4b9826937ce97064909770b
URL: https://github.com/llvm/llvm-project/commit/f67171ac5896d232b4b9826937ce97064909770b
DIFF: https://github.com/llvm/llvm-project/commit/f67171ac5896d232b4b9826937ce97064909770b.diff
LOG: [mlir][Linalg] Make depthwise convolution naming scheme consistent.
Names should be consistent across all operations; otherwise, painful bugs will surface.
Reviewed By: rsuderman
Differential Revision: https://reviews.llvm.org/D113762
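
For context, the new names spell out both the input layout (e.g. nhwc) and the kernel layout (e.g. hwc, or hwcm when a channel multiplier is present). A minimal sketch of the renamed spelling, lifted from the updated named-ops test below (the function name is illustrative and the shapes are only an example):

// Depthwise 2-D convolution without a channel multiplier:
// NHWC input, HWC kernel, NHWC output.
func @depthwise_conv_2d_nhwc_hwc_example(%input: tensor<1x113x113x96xf32>,
                                         %filter: tensor<3x3x96xf32>) -> tensor<1x56x56x96xf32> {
  // Allocate and use the output tensor as the accumulator.
  %init = linalg.init_tensor [1, 56, 56, 96] : tensor<1x56x56x96xf32>
  // Former spelling: linalg.depthwise_conv2D_nhw
  %0 = linalg.depthwise_conv_2d_nhwc_hwc
         {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
         ins(%input, %filter : tensor<1x113x113x96xf32>, tensor<3x3x96xf32>)
         outs(%init : tensor<1x56x56x96xf32>) -> tensor<1x56x56x96xf32>
  return %0 : tensor<1x56x56x96xf32>
}

The multiplier-carrying variant follows the same scheme: linalg.depthwise_conv_2d_nhwc_hwcm (formerly linalg.depthwise_conv2D_nhwc) takes an HWCM kernel and produces an NHWCM result, as shown in the test updates in the diff.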
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
mlir/test/Dialect/Linalg/canonicalize.mlir
mlir/test/Dialect/Linalg/generalize-named-ops.mlir
mlir/test/Dialect/Linalg/named-ops.mlir
mlir/test/Dialect/Linalg/vectorize-convolution.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 80560b9f363b5..7cc84620665da 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -69,7 +69,7 @@ metadata: !LinalgOpMetadata
name: matmul_unsigned
cpp_class_name: MatmulUnsignedOp
doc: |-
- Performs a unsigned matrix multiplication of two 2D inputs.
+ Performs an unsigned matrix multiplication of two 2D inputs.
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
@@ -1384,14 +1384,14 @@ structured_op: !LinalgStructuredOpConfig
is_unsigned_cast: false
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
- name: depthwise_conv1D_nw
- cpp_class_name: DepthwiseConv1DNwOp
+ name: depthwise_conv_1d_nwc_wc
+ cpp_class_name: DepthwiseConv1DNwcWcOp
doc: |-
Performs depth-wise 1-D convolution.
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output. Multiplier is set to 1
- which is a special case for most dpethwise convolutions.
+ which is a special case for most depthwise convolutions.
implements:
- LinalgConvolutionOpInterface
structured_op: !LinalgStructuredOpConfig
@@ -1461,14 +1461,14 @@ structured_op: !LinalgStructuredOpConfig
is_unsigned_cast: false
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
- name: depthwise_conv2D_nhw
- cpp_class_name: DepthwiseConv2DNhwOp
+ name: depthwise_conv_2d_nhwc_hwc
+ cpp_class_name: DepthwiseConv2DNhwcHwcOp
doc: |-
Performs depth-wise 2-D convolution.
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output. Multiplier is set to 1
- which is a special case for most dpethwise convolutions.
+ which is a special case for most depthwise convolutions.
implements:
- LinalgConvolutionOpInterface
structured_op: !LinalgStructuredOpConfig
@@ -1544,8 +1544,8 @@ structured_op: !LinalgStructuredOpConfig
is_unsigned_cast: false
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
- name: depthwise_conv2D_nhw_q
- cpp_class_name: DepthwiseConv2DNhwQOp
+ name: depthwise_conv_2d_nhwc_hwc_q
+ cpp_class_name: DepthwiseConv2DNhwcHwcQOp
doc: |-
Performs depth-wise 2-D convolution.
@@ -1660,8 +1660,8 @@ structured_op: !LinalgStructuredOpConfig
is_unsigned_cast: false
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
- name: depthwise_conv2D_nhwc
- cpp_class_name: DepthwiseConv2DNhwcOp
+ name: depthwise_conv_2d_nhwc_hwcm
+ cpp_class_name: DepthwiseConv2DNhwcHwcmOp
doc: |-
Performs depth-wise 2-D convolution.
@@ -1746,8 +1746,8 @@ structured_op: !LinalgStructuredOpConfig
is_unsigned_cast: false
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
- name: depthwise_conv2D_nhwc_q
- cpp_class_name: DepthwiseConv2DNhwcQOp
+ name: depthwise_conv_2d_nhwc_hwcm_q
+ cpp_class_name: DepthwiseConv2DNhwcHwcmQOp
doc: |-
Performs depth-wise 2-D convolution.
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 54165266538c2..e90d1533b5c0e 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1230,7 +1230,7 @@ class DepthwiseConvConverter
loc, resultTy.getShape(), resultETy);
if (!isQuantized) {
Value conv = rewriter
- .create<linalg::DepthwiseConv2DNhwcOp>(
+ .create<linalg::DepthwiseConv2DNhwcHwcmOp>(
loc, linalgConvTy, ValueRange{input, weight},
ValueRange{zeroTensor}, strideAttr, dilationAttr)
.getResult(0);
@@ -1254,7 +1254,7 @@ class DepthwiseConvConverter
auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
Value conv =
rewriter
- .create<linalg::DepthwiseConv2DNhwcQOp>(
+ .create<linalg::DepthwiseConv2DNhwcHwcmQOp>(
loc, linalgConvTy, ValueRange{input, weight, iZpVal, kZpVal},
ValueRange{zeroTensor}, strideAttr, dilationAttr)
.getResult(0);
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index d0db26a1bb3bb..3d2f42d174fe1 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3037,16 +3037,16 @@ LogicalResult matchAndReplaceDepthwiseConv(Operation *operation, Value input,
loc, newInitTy, init, collapsedInitDims);
Value newConv;
- if (isa<DepthwiseConv2DNhwcOp>(operation)) {
+ if (isa<DepthwiseConv2DNhwcHwcmOp>(operation)) {
newConv = rewriter
- .create<DepthwiseConv2DNhwOp>(
+ .create<DepthwiseConv2DNhwcHwcOp>(
loc, newInitTy, ValueRange{input, collapsedKernel},
ValueRange{collapsedInit}, stride, dilation)
.getResult(0);
- } else if (isa<DepthwiseConv2DNhwcQOp>(operation)) {
+ } else if (isa<DepthwiseConv2DNhwcHwcmQOp>(operation)) {
newConv =
rewriter
- .create<DepthwiseConv2DNhwQOp>(
+ .create<DepthwiseConv2DNhwcHwcQOp>(
loc, newInitTy, ValueRange{input, collapsedKernel, iZp, kZp},
ValueRange{collapsedInit}, stride, dilation)
.getResult(0);
@@ -3062,10 +3062,10 @@ LogicalResult matchAndReplaceDepthwiseConv(Operation *operation, Value input,
}
struct SimplifyDepthwiseConvOp
- : public OpRewritePattern<DepthwiseConv2DNhwcOp> {
- using OpRewritePattern<DepthwiseConv2DNhwcOp>::OpRewritePattern;
+ : public OpRewritePattern<DepthwiseConv2DNhwcHwcmOp> {
+ using OpRewritePattern<DepthwiseConv2DNhwcHwcmOp>::OpRewritePattern;
- LogicalResult matchAndRewrite(DepthwiseConv2DNhwcOp op,
+ LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmOp op,
PatternRewriter &rewriter) const override {
Operation *operation = op.getOperation();
Value input = op.getInputOperand(0)->get();
@@ -3082,10 +3082,10 @@ struct SimplifyDepthwiseConvOp
};
struct SimplifyDepthwiseConvQOp
- : public OpRewritePattern<DepthwiseConv2DNhwcQOp> {
- using OpRewritePattern<DepthwiseConv2DNhwcQOp>::OpRewritePattern;
+ : public OpRewritePattern<DepthwiseConv2DNhwcHwcmQOp> {
+ using OpRewritePattern<DepthwiseConv2DNhwcHwcmQOp>::OpRewritePattern;
- LogicalResult matchAndRewrite(DepthwiseConv2DNhwcQOp op,
+ LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmQOp op,
PatternRewriter &rewriter) const override {
Operation *operation = op.getOperation();
Value input = op.getInputOperand(0)->get();
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 2e6b04118ef15..85bed25febdd2 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -310,7 +310,7 @@ def conv_3d_ndhwc_dhwcf(
]) * cast(U, K[D.kd, D.kh, D.kw, D.c, D.f])
@linalg_structured_op
-def depthwise_conv1D_nw(
+def depthwise_conv_1d_nwc_wc(
I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.IC),
K=TensorDef(T2, S.KW, S.IC),
O=TensorDef(U, S.N, S.OW, S.IC, output=True),
@@ -320,7 +320,7 @@ def depthwise_conv1D_nw(
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output. Multiplier is set to 1
- which is a special case for most dpethwise convolutions.
+ which is a special case for most depthwise convolutions.
"""
implements(ConvolutionOpInterface)
domain(D.n, D.ow, D.ic, D.kw)
@@ -329,7 +329,7 @@ def depthwise_conv1D_nw(
cast(U, K[D.kw, D.ic])
@linalg_structured_op
-def depthwise_conv2D_nhw(
+def depthwise_conv_2d_nhwc_hwc(
I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC),
K=TensorDef(T2, S.KH, S.KW, S.IC),
O=TensorDef(U, S.N, S.OH, S.OW, S.IC, output=True),
@@ -339,7 +339,7 @@ def depthwise_conv2D_nhw(
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output. Multiplier is set to 1
- which is a special case for most dpethwise convolutions.
+ which is a special case for most depthwise convolutions.
"""
implements(ConvolutionOpInterface)
domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw)
@@ -348,7 +348,7 @@ def depthwise_conv2D_nhw(
D.ic]) * cast(U, K[D.kh, D.kw, D.ic])
@linalg_structured_op
-def depthwise_conv2D_nhw_q(
+def depthwise_conv_2d_nhwc_hwc_q(
I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC),
K=TensorDef(T2, S.KH, S.KW, S.IC),
IZp=ScalarDef(I32),
@@ -369,7 +369,7 @@ def depthwise_conv2D_nhw_q(
(cast(U, K[D.kh, D.kw, D.ic]) - cast(U, KZp)))
@linalg_structured_op
-def depthwise_conv2D_nhwc(
+def depthwise_conv_2d_nhwc_hwcm(
I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC),
K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM),
O=TensorDef(U, S.N, S.OH, S.OW, S.IC, S.CM, output=True),
@@ -387,7 +387,7 @@ def depthwise_conv2D_nhwc(
D.ic]) * cast(U, K[D.kh, D.kw, D.ic, D.cm])
@linalg_structured_op
-def depthwise_conv2D_nhwc_q(
+def depthwise_conv_2d_nhwc_hwcm_q(
I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC),
K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM),
IZp=ScalarDef(I32),
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index c7ddb6bdae5c7..d072808b1b476 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1592,7 +1592,7 @@ func @depthwise_conv(%arg0 : tensor<1x7x5x3xf32>, %arg1 : tensor<3x1x3x11xf32>,
// CHECK: [[CST0:%.+]] = arith.constant 0
// CHECK: [[FILL:%.+]] = linalg.fill([[CST0]], [[INIT]])
// CHECK: [[OUT:%.+]] = linalg.init_tensor [1, 5, 5, 33]
- // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv2D_nhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>) outs([[FILL]] : tensor<1x5x5x3x11xf32>)
+ // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>) outs([[FILL]] : tensor<1x5x5x3x11xf32>)
// CHECK: [[COLLAPSED:%.+]] = linalg.tensor_collapse_shape [[DEPTH]] {{\[}}[0], [1], [2], [3, 4]]
// CHECK: [[BIAS:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, [[COLLAPSED]] : tensor<33xf32>, tensor<1x5x5x33xf32>) outs([[OUT]] : tensor<1x5x5x33xf32>) {
// CHECK: ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
@@ -1614,7 +1614,7 @@ func @depthwise_conv_strides(%arg0 : tensor<1x11x9x3xf32>, %arg1 : tensor<3x1x3x
// CHECK: [[CST0:%.+]] = arith.constant 0
// CHECK: [[FILL:%.+]] = linalg.fill([[CST0]], [[INIT]])
// CHECK: [[OUT:%.+]] = linalg.init_tensor [1, 5, 5, 33]
- // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv2D_nhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x11x9x3xf32>, tensor<3x1x3x11xf32>) outs([[FILL]] : tensor<1x5x5x3x11xf32>)
+ // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x11x9x3xf32>, tensor<3x1x3x11xf32>) outs([[FILL]] : tensor<1x5x5x3x11xf32>)
// CHECK: [[COLLAPSED:%.+]] = linalg.tensor_collapse_shape [[DEPTH]] {{\[}}[0], [1], [2], [3, 4]]
// CHECK: [[BIAS:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, [[COLLAPSED]] : tensor<33xf32>, tensor<1x5x5x33xf32>) outs([[OUT]] : tensor<1x5x5x33xf32>) {
// CHECK: ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
@@ -1642,7 +1642,7 @@ func @depthwise_conv_quant(%arg0 : tensor<1x12x12x4xi8>, %arg1 : tensor<3x3x4x12
// CHECK: [[OUT:%.+]] = linalg.init_tensor [1, 12, 12, 512]
// CHECK: [[C128:%.+]] = arith.constant -128
// CHECK: [[C42:%.+]] = arith.constant 42
- // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv2D_nhwc_q {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins([[PAD]], %arg1, [[C128]], [[C42]] : tensor<1x14x14x4xi8>, tensor<3x3x4x128xi8>, i32, i32) outs([[FILL]] : tensor<1x12x12x4x128xi32>)
+ // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv_2d_nhwc_hwcm_q {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins([[PAD]], %arg1, [[C128]], [[C42]] : tensor<1x14x14x4xi8>, tensor<3x3x4x128xi8>, i32, i32) outs([[FILL]] : tensor<1x12x12x4x128xi32>)
// CHECK: [[COLLAPSED:%.+]] = linalg.tensor_collapse_shape [[DEPTH]] {{\[}}[0], [1], [2], [3, 4]]
// CHECK: [[BIAS:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, [[COLLAPSED]] : tensor<512xi32>, tensor<1x12x12x512xi32>) outs([[OUT]] : tensor<1x12x12x512xi32>) {
// CHECK: ^bb0(%arg3: i32, %arg4: i32, %arg5: i32): // no predecessors
@@ -1666,7 +1666,7 @@ func @depthwise_conv_quant_dilations(%arg0 : tensor<1x14x14x4xi8>, %arg1 : tenso
// CHECK: [[OUT:%.+]] = linalg.init_tensor [1, 10, 10, 512]
// CHECK: [[C128:%.+]] = arith.constant -128
// CHECK: [[C42:%.+]] = arith.constant 42
- // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv2D_nhwc_q {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, [[C128]], [[C42]] : tensor<1x14x14x4xi8>, tensor<3x3x4x128xi8>, i32, i32) outs([[FILL]] : tensor<1x10x10x4x128xi32>)
+ // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv_2d_nhwc_hwcm_q {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, [[C128]], [[C42]] : tensor<1x14x14x4xi8>, tensor<3x3x4x128xi8>, i32, i32) outs([[FILL]] : tensor<1x10x10x4x128xi32>)
// CHECK: [[COLLAPSED:%.+]] = linalg.tensor_collapse_shape [[DEPTH]] {{\[}}[0], [1], [2], [3, 4]]
// CHECK: [[BIAS:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, [[COLLAPSED]] : tensor<512xi32>, tensor<1x10x10x512xi32>) outs([[OUT]] : tensor<1x10x10x512xi32>) {
// CHECK: ^bb0(%arg3: i32, %arg4: i32, %arg5: i32): // no predecessors
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index aa17f1c56a5dc..e938f8c232e44 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1095,9 +1095,9 @@ func @dim_of_tiled_loop_result_no_canonicalize(%arg0: tensor<?x?xf32>, %arg1: te
func @depthwise_conv(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x1xf32>, %arg2: tensor<?x?x?x?x1xf32>) -> tensor<?x?x?x?x1xf32> {
// CHECK-DAG: %[[KERNEL:.+]] = linalg.tensor_collapse_shape %arg1 {{\[\[}}0], [1], [2, 3]]
// CHECK-DAG: %[[INIT:.+]] = linalg.tensor_collapse_shape %arg2 {{\[\[}}0], [1], [2], [3, 4]]
- // CHECK-DAG: %[[CONV:.+]] = linalg.depthwise_conv2D_nhw {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %[[KERNEL]] : tensor<?x?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[INIT]] : tensor<?x?x?x?xf32>)
+ // CHECK-DAG: %[[CONV:.+]] = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %[[KERNEL]] : tensor<?x?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[INIT]] : tensor<?x?x?x?xf32>)
// CHECK: %[[OUT:.+]] = linalg.tensor_expand_shape %[[CONV]] {{\[\[}}0], [1], [2], [3, 4]]
- %0 = linalg.depthwise_conv2D_nhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?x?x1xf32>) outs(%arg2 : tensor<?x?x?x?x1xf32>) -> tensor<?x?x?x?x1xf32>
+ %0 = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?x?x1xf32>) outs(%arg2 : tensor<?x?x?x?x1xf32>) -> tensor<?x?x?x?x1xf32>
return %0 : tensor<?x?x?x?x1xf32>
}
@@ -1108,8 +1108,8 @@ func @depthwise_conv(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x1xf32>, %ar
func @depthwise_conv_q(%arg0: tensor<?x?x?x?xi8>, %arg1: tensor<?x?x?x1xi8>, %arg2: tensor<?x?x?x?x1xi32>, %arg3 : i32, %arg4 : i32) -> tensor<?x?x?x?x1xi32> {
// CHECK-DAG: %[[KERNEL:.+]] = linalg.tensor_collapse_shape %arg1 {{\[\[}}0], [1], [2, 3]]
// CHECK-DAG: %[[INIT:.+]] = linalg.tensor_collapse_shape %arg2 {{\[\[}}0], [1], [2], [3, 4]]
- // CHECK-DAG: %[[CONV:.+]] = linalg.depthwise_conv2D_nhw_q {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %[[KERNEL]], %arg3, %arg4 : tensor<?x?x?x?xi8>, tensor<?x?x?xi8>, i32, i32) outs(%[[INIT]] : tensor<?x?x?x?xi32>)
+ // CHECK-DAG: %[[CONV:.+]] = linalg.depthwise_conv_2d_nhwc_hwc_q {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %[[KERNEL]], %arg3, %arg4 : tensor<?x?x?x?xi8>, tensor<?x?x?xi8>, i32, i32) outs(%[[INIT]] : tensor<?x?x?x?xi32>)
// CHECK: %[[OUT:.+]] = linalg.tensor_expand_shape %[[CONV]] {{\[\[}}0], [1], [2], [3, 4]]
- %0 = linalg.depthwise_conv2D_nhwc_q {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1, %arg3, %arg4 : tensor<?x?x?x?xi8>, tensor<?x?x?x1xi8>, i32, i32) outs(%arg2 : tensor<?x?x?x?x1xi32>) -> tensor<?x?x?x?x1xi32>
+ %0 = linalg.depthwise_conv_2d_nhwc_hwcm_q {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1, %arg3, %arg4 : tensor<?x?x?x?xi8>, tensor<?x?x?x1xi8>, i32, i32) outs(%arg2 : tensor<?x?x?x?x1xi32>) -> tensor<?x?x?x?x1xi32>
return %0 : tensor<?x?x?x?x1xi32>
}
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index 44a3de3f2722a..961a31d293b9a 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -49,8 +49,8 @@ func @generalize_matmul_tensor(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C:
// -----
-func @depthwise_conv2D_nhwc(%input: memref<2x4x5x2xf32>, %filter: memref<2x2x2x3xf32>, %output: memref<2x3x4x2x3xf32>) {
- linalg.depthwise_conv2D_nhwc
+func @depthwise_conv_2d_nhwc_hwcm(%input: memref<2x4x5x2xf32>, %filter: memref<2x2x2x3xf32>, %output: memref<2x3x4x2x3xf32>) {
+ linalg.depthwise_conv_2d_nhwc_hwcm
{ dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
ins(%input, %filter : memref<2x4x5x2xf32>, memref<2x2x2x3xf32>)
outs(%output : memref<2x3x4x2x3xf32>)
@@ -61,7 +61,7 @@ func @depthwise_conv2D_nhwc(%input: memref<2x4x5x2xf32>, %filter: memref<2x2x2x3
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
-// CHECK: func @depthwise_conv2D_nhwc
+// CHECK: func @depthwise_conv_2d_nhwc_hwcm
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
@@ -76,8 +76,8 @@ func @depthwise_conv2D_nhwc(%input: memref<2x4x5x2xf32>, %filter: memref<2x2x2x3
// -----
-func @depthwise_conv2D_nhwc(%input: memref<2x4x5x2xf32>, %filter: memref<2x2x2x3xf32>, %output: memref<2x2x3x2x3xf32>) {
- linalg.depthwise_conv2D_nhwc
+func @depthwise_conv_2d_nhwc_hwcm(%input: memref<2x4x5x2xf32>, %filter: memref<2x2x2x3xf32>, %output: memref<2x2x3x2x3xf32>) {
+ linalg.depthwise_conv_2d_nhwc_hwcm
{ dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
ins(%input, %filter : memref<2x4x5x2xf32>, memref<2x2x2x3xf32>)
outs(%output : memref<2x2x3x2x3xf32>)
@@ -88,7 +88,7 @@ func @depthwise_conv2D_nhwc(%input: memref<2x4x5x2xf32>, %filter: memref<2x2x2x3
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
-// CHECK: func @depthwise_conv2D_nhwc
+// CHECK: func @depthwise_conv_2d_nhwc_hwcm
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
@@ -103,8 +103,8 @@ func @depthwise_conv2D_nhwc(%input: memref<2x4x5x2xf32>, %filter: memref<2x2x2x3
// -----
-func @depthwise_conv2D_nhw(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
- linalg.depthwise_conv2D_nhw {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
+func @depthwise_conv_2d_nhwc_hwc(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
+ linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)
outs(%output: memref<1x56x56x96xf32>)
return
@@ -114,7 +114,7 @@ func @depthwise_conv2D_nhw(%input: memref<1x113x113x96xf32>, %filter: memref<3x3
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
-// CHECK: func @depthwise_conv2D_nhw
+// CHECK: func @depthwise_conv_2d_nhwc_hwc
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index cd01b6a0a920b..7d574980d63a6 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1,94 +1,94 @@
// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
-// CHECK-LABEL: func @depthwise_conv2D_nhwc_tensor
-func @depthwise_conv2D_nhwc_tensor(%input: tensor<2x4x5x2xf32>, %filter: tensor<2x2x2x3xf32>) -> tensor<2x3x4x2x3xf32> {
+// CHECK-LABEL: func @depthwise_conv_2d_nhwc_hwcm_tensor
+func @depthwise_conv_2d_nhwc_hwcm_tensor(%input: tensor<2x4x5x2xf32>, %filter: tensor<2x2x2x3xf32>) -> tensor<2x3x4x2x3xf32> {
%zero = arith.constant 0.000000e+00 : f32
%init = linalg.init_tensor [2, 3, 4, 2, 3] : tensor<2x3x4x2x3xf32>
%fill = linalg.fill(%zero, %init) : f32, tensor<2x3x4x2x3xf32> -> tensor<2x3x4x2x3xf32>
- // CHECK: %{{.+}} = linalg.depthwise_conv2D_nhwc
+ // CHECK: %{{.+}} = linalg.depthwise_conv_2d_nhwc_hwcm
// CHECK-SAME: {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<2x4x5x2xf32>, tensor<2x2x2x3xf32>)
// CHECK-SAME: outs(%{{.+}} : tensor<2x3x4x2x3xf32>)
- %0 = linalg.depthwise_conv2D_nhwc
+ %0 = linalg.depthwise_conv_2d_nhwc_hwcm
{ dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
ins(%input, %filter : tensor<2x4x5x2xf32>, tensor<2x2x2x3xf32>)
outs(%fill : tensor<2x3x4x2x3xf32>) -> tensor<2x3x4x2x3xf32>
return %0 : tensor<2x3x4x2x3xf32>
}
-// CHECK-LABEL: func @depthwise_conv2D_nhwc_memref
-func @depthwise_conv2D_nhwc_memref(%input: memref<2x4x5x2xf32>, %filter: memref<2x2x2x3xf32>, %output: memref<2x3x4x2x3xf32>) {
- // CHECK: linalg.depthwise_conv2D_nhwc
+// CHECK-LABEL: func @depthwise_conv_2d_nhwc_hwcm_memref
+func @depthwise_conv_2d_nhwc_hwcm_memref(%input: memref<2x4x5x2xf32>, %filter: memref<2x2x2x3xf32>, %output: memref<2x3x4x2x3xf32>) {
+ // CHECK: linalg.depthwise_conv_2d_nhwc_hwcm
// CHECK-SAME: {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<2x4x5x2xf32>, memref<2x2x2x3xf32>)
// CHECK-SAME: outs(%{{.+}} : memref<2x3x4x2x3xf32>)
- linalg.depthwise_conv2D_nhwc
+ linalg.depthwise_conv_2d_nhwc_hwcm
{ dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
ins(%input, %filter : memref<2x4x5x2xf32>, memref<2x2x2x3xf32>)
outs(%output : memref<2x3x4x2x3xf32>)
return
}
-// CHECK-LABEL: func @depthwise_conv1D_nw_tensor
-func @depthwise_conv1D_nw_tensor(%input: tensor<1x113x96xf32>, %filter: tensor<3x96xf32>) -> tensor<1x56x96xf32> {
+// CHECK-LABEL: func @depthwise_conv_1d_nw_tensor
+func @depthwise_conv_1d_nw_tensor(%input: tensor<1x113x96xf32>, %filter: tensor<3x96xf32>) -> tensor<1x56x96xf32> {
%init = linalg.init_tensor [1, 56, 96] : tensor<1x56x96xf32>
- // CHECK: %{{.+}} = linalg.depthwise_conv1D_nw
+ // CHECK: %{{.+}} = linalg.depthwise_conv_1d_nw
// CHECK-SAME: {dilations = dense<1> : vector<1xi64>, strides = dense<2> : vector<1xi64>}
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x113x96xf32>, tensor<3x96xf32>)
// CHECK-SAME: outs(%{{.+}} : tensor<1x56x96xf32>) -> tensor<1x56x96xf32>
- %0 = linalg.depthwise_conv1D_nw {dilations = dense<1> : vector<1xi64>, strides = dense<2> : vector<1xi64>}
+ %0 = linalg.depthwise_conv_1d_nwc_wc {dilations = dense<1> : vector<1xi64>, strides = dense<2> : vector<1xi64>}
ins(%input, %filter: tensor<1x113x96xf32>, tensor<3x96xf32>)
outs(%init: tensor<1x56x96xf32>) -> tensor<1x56x96xf32>
return %0: tensor<1x56x96xf32>
}
-// CHECK-LABEL: func @depthwise_conv2D_nhw_tensor
-func @depthwise_conv2D_nhw_tensor(%input: tensor<1x113x113x96xf32>, %filter: tensor<3x3x96xf32>) -> tensor<1x56x56x96xf32> {
+// CHECK-LABEL: func @depthwise_conv_2d_nhwc_hwc_tensor
+func @depthwise_conv_2d_nhwc_hwc_tensor(%input: tensor<1x113x113x96xf32>, %filter: tensor<3x3x96xf32>) -> tensor<1x56x56x96xf32> {
%init = linalg.init_tensor [1, 56, 56, 96] : tensor<1x56x56x96xf32>
- // CHECK: %{{.+}} = linalg.depthwise_conv2D_nhw
+ // CHECK: %{{.+}} = linalg.depthwise_conv_2d_nhwc_hwc
// CHECK-SAME: {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x113x113x96xf32>, tensor<3x3x96xf32>)
// CHECK-SAME: outs(%{{.+}} : tensor<1x56x56x96xf32>) -> tensor<1x56x56x96xf32>
- %0 = linalg.depthwise_conv2D_nhw {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
+ %0 = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
ins(%input, %filter: tensor<1x113x113x96xf32>, tensor<3x3x96xf32>)
outs(%init: tensor<1x56x56x96xf32>) -> tensor<1x56x56x96xf32>
return %0: tensor<1x56x56x96xf32>
}
-// CHECK-LABEL: func @depthwise_conv2D_nhw_memref
-func @depthwise_conv2D_nhw_memref(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
- // CHECK: linalg.depthwise_conv2D_nhw
+// CHECK-LABEL: func @depthwise_conv_2d_nhwc_hwc_memref
+func @depthwise_conv_2d_nhwc_hwc_memref(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
+ // CHECK: linalg.depthwise_conv_2d_nhwc_hwc
// CHECK-SAME: {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x113x113x96xf32>, memref<3x3x96xf32>)
// CHECK-SAME: outs(%{{.+}} : memref<1x56x56x96xf32>)
- linalg.depthwise_conv2D_nhw {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
+ linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)
outs(%output: memref<1x56x56x96xf32>)
return
}
-func @depthwise_conv2D_nhwc_tensor_dilated(%input: tensor<2x8x9x2xf32>, %filter: tensor<2x2x2x3xf32>) -> tensor<2x6x7x2x3xf32> {
+func @depthwise_conv_2d_nhwc_hwcm_tensor_dilated(%input: tensor<2x8x9x2xf32>, %filter: tensor<2x2x2x3xf32>) -> tensor<2x6x7x2x3xf32> {
%zero = arith.constant 0.000000e+00 : f32
%init = linalg.init_tensor [2, 6, 7, 2, 3] : tensor<2x6x7x2x3xf32>
%fill = linalg.fill(%zero, %init) : f32, tensor<2x6x7x2x3xf32> -> tensor<2x6x7x2x3xf32>
- // CHECK: %{{.+}} = linalg.depthwise_conv2D_nhwc
+ // CHECK: %{{.+}} = linalg.depthwise_conv_2d_nhwc_hwcm
// CHECK-SAME: {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<2x8x9x2xf32>, tensor<2x2x2x3xf32>)
// CHECK-SAME: outs(%{{.+}} : tensor<2x6x7x2x3xf32>)
- %0 = linalg.depthwise_conv2D_nhwc
+ %0 = linalg.depthwise_conv_2d_nhwc_hwcm
{ dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
ins(%input, %filter : tensor<2x8x9x2xf32>, tensor<2x2x2x3xf32>)
outs(%fill : tensor<2x6x7x2x3xf32>) -> tensor<2x6x7x2x3xf32>
return %0 : tensor<2x6x7x2x3xf32>
}
-// CHECK-LABEL: func @depthwise_conv2D_nhwc_memref_dilated
-func @depthwise_conv2D_nhwc_memref_dilated(%input: memref<2x8x9x2xf32>, %filter: memref<2x2x2x3xf32>, %output: memref<2x6x7x2x3xf32>) {
- // CHECK: linalg.depthwise_conv2D_nhwc
+// CHECK-LABEL: func @depthwise_conv_2d_nhwc_hwcm_memref_dilated
+func @depthwise_conv_2d_nhwc_hwcm_memref_dilated(%input: memref<2x8x9x2xf32>, %filter: memref<2x2x2x3xf32>, %output: memref<2x6x7x2x3xf32>) {
+ // CHECK: linalg.depthwise_conv_2d_nhwc_hwcm
// CHECK-SAME: {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<2x8x9x2xf32>, memref<2x2x2x3xf32>)
// CHECK-SAME: outs(%{{.+}} : memref<2x6x7x2x3xf32>)
- linalg.depthwise_conv2D_nhwc
+ linalg.depthwise_conv_2d_nhwc_hwcm
{ dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
ins(%input, %filter : memref<2x8x9x2xf32>, memref<2x2x2x3xf32>)
outs(%output : memref<2x6x7x2x3xf32>)
@@ -99,7 +99,7 @@ func @depthwise_conv2D_nhwc_memref_dilated(%input: memref<2x8x9x2xf32>, %filter:
func @depthwise_conv_2d_input_nhwc_filter_missing_stride(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
// expected-error @+1 {{missing indexing map required attribute 'strides'}}
- linalg.depthwise_conv2D_nhw {dilations = dense<1> : vector<2xi64>}
+ linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>}
ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)
outs(%output: memref<1x56x56x96xf32>)
return
@@ -109,7 +109,7 @@ func @depthwise_conv_2d_input_nhwc_filter_missing_stride(%input: memref<1x113x11
func @depthwise_conv_2d_input_nhwc_filter_missing_dilations(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
// expected-error @+1 {{missing indexing map required attribute 'dilations'}}
- linalg.depthwise_conv2D_nhw {strides = dense<1> : vector<2xi64>}
+ linalg.depthwise_conv_2d_nhwc_hwc {strides = dense<1> : vector<2xi64>}
ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)
outs(%output: memref<1x56x56x96xf32>)
return
@@ -119,7 +119,7 @@ func @depthwise_conv_2d_input_nhwc_filter_missing_dilations(%input: memref<1x113
func @depthwise_conv_2d_input_nhwc_filter_wrong_stride_element_type(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
// expected-error @+1 {{incorrect element type for indexing map required attribute 'strides'}}
- linalg.depthwise_conv2D_nhw {dilations = dense<1> : vector<2xi64>, strides = dense<2.0> : vector<2xf32>}
+ linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2.0> : vector<2xf32>}
ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)
outs(%output: memref<1x56x56x96xf32>)
return
@@ -129,7 +129,7 @@ func @depthwise_conv_2d_input_nhwc_filter_wrong_stride_element_type(%input: memr
func @depthwise_conv_2d_input_nhwc_filter_wrong_stride_size(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
// expected-error @+1 {{incorrect shape for indexing map required attribute 'strides'}}
- linalg.depthwise_conv2D_nhw {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<3xi64> }
+ linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<3xi64> }
ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)
outs(%output: memref<1x56x56x96xf32>)
return
diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
index aa3d3f55953c2..7afc46db38891 100644
--- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -192,8 +192,8 @@ func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<2x3x8xf
// -----
-func @depthwise_conv1d_nwc_3x5x4_memref(%input: memref<3x5x4xf32>, %filter: memref<2x4xf32>, %output: memref<3x2x4xf32>) {
- linalg.depthwise_conv1D_nw
+func @depthwise_conv1d_nwc_wc_3x5x4_memref(%input: memref<3x5x4xf32>, %filter: memref<2x4xf32>, %output: memref<3x2x4xf32>) {
+ linalg.depthwise_conv_1d_nwc_wc
{dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
ins(%input, %filter : memref<3x5x4xf32>, memref<2x4xf32>)
outs(%output : memref<3x2x4xf32>)
@@ -203,7 +203,7 @@ func @depthwise_conv1d_nwc_3x5x4_memref(%input: memref<3x5x4xf32>, %filter: memr
// CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2) -> (d2)>
-// CHECK: func @depthwise_conv1d_nwc_3x5x4_memref
+// CHECK: func @depthwise_conv1d_nwc_wc_3x5x4_memref
// CHECK-SAME: (%[[INPUT:[0-9a-z]+]]: memref<3x5x4xf32>, %[[FILTER:[0-9a-z]+]]: memref<2x4xf32>, %[[OUTPUT:[0-9a-z]+]]: memref<3x2x4xf32>)
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index