[Mlir-commits] [mlir] [mlir][tosa][linalg] Apply direct tosa -> linalg Conv2D lowering (PR #68304)

Jack Frankland llvmlistbot at llvm.org
Thu Oct 5 05:34:27 PDT 2023


https://github.com/FranklandJack created https://github.com/llvm/llvm-project/pull/68304

TOSA defines the filter channel ordering for the 2D convolution operation `tosa.conv2d` as `[OC, KH, KW, IC]`. The LinAlg dialect supports the `[F, H, W, C]` and `[H, W, C, F]` orderings via the `linalg.conv_2d_nhwc_fhwc` and `linalg.conv_2d_nhwc_hwcf` operations respectively, where `F == OC`, `KH == H`, `KW == W` and `C == IC`.
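
As a rough illustration of this correspondence (plain Python, not part of the patch; the names `TOSA_TO_LINALG` and `layout_perm` are hypothetical), renaming the TOSA filter axes shows that TOSA's order is already FHWC, while HWCF needs a non-trivial permutation:

```python
# Hypothetical sketch: axis-name bookkeeping only, not MLIR code.
TOSA_CONV2D_FILTER = ("OC", "KH", "KW", "IC")

# F == OC, H == KH, W == KW, C == IC
TOSA_TO_LINALG = {"OC": "F", "KH": "H", "KW": "W", "IC": "C"}


def layout_perm(src, dst):
    """Return perm such that dst[i] == src[perm[i]] (transpose convention)."""
    return [src.index(axis) for axis in dst]


tosa_in_linalg_names = tuple(TOSA_TO_LINALG[a] for a in TOSA_CONV2D_FILTER)

to_fhwc = layout_perm(tosa_in_linalg_names, ("F", "H", "W", "C"))
to_hwcf = layout_perm(tosa_in_linalg_names, ("H", "W", "C", "F"))

print(tosa_in_linalg_names)  # ('F', 'H', 'W', 'C')
print(to_fhwc)  # [0, 1, 2, 3]  (identity: no transpose needed)
print(to_hwcf)  # [1, 2, 3, 0]
```

`[1, 2, 3, 0]` is the same constant that the old `arith.constant dense<[1, 2, 3, 0]>` CHECK lines in the lit tests expected.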

Currently `tosa.conv2d` is lowered to `linalg.conv_2d_nhwc_hwcf`, meaning we need to insert a transposition operation to permute the filter channels, `[F, H, W, C]` -> `[H, W, C, F]`, before they can be passed as weights to the linalg op. An analogous transformation needs to be applied to the quantized operation, which lowers to `linalg.conv_2d_nhwc_hwcf_q`.
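
To see why the transpose is pure layout bookkeeping, here is a naive pure-Python reference of both variants, written from the index expressions in `core_named_ops.py` (`K[D.kh, D.kw, D.c, D.f]` for HWCF, `K[D.f, D.kh, D.kw, D.c]` for FHWC). It is an illustrative sketch only: the helper names are hypothetical, padding is omitted, and integer inputs are used so the comparison is exact.

```python
import itertools
import random


def zeros(*shape):
    """Nested-list tensor of zeros with the given shape."""
    if len(shape) == 1:
        return [0] * shape[0]
    return [zeros(*shape[1:]) for _ in range(shape[0])]


def rand_tensor(rng, *shape):
    """Nested-list tensor of small random integers with the given shape."""
    if len(shape) == 1:
        return [rng.randint(-5, 5) for _ in range(shape[0])]
    return [rand_tensor(rng, *shape[1:]) for _ in range(shape[0])]


def _conv(inp, k_at, f_n, kh_n, kw_n, c_n, strides, dilations):
    """Shared NHWC loop nest; k_at(f, kh, kw, c) hides the filter layout."""
    (sh, sw), (dh, dw) = strides, dilations
    n_n, ih, iw = len(inp), len(inp[0]), len(inp[0][0])
    oh_n = (ih - dh * (kh_n - 1) - 1) // sh + 1
    ow_n = (iw - dw * (kw_n - 1) - 1) // sw + 1
    out = zeros(n_n, oh_n, ow_n, f_n)
    for n, oh, ow, f, kh, kw, c in itertools.product(
        range(n_n), range(oh_n), range(ow_n), range(f_n),
        range(kh_n), range(kw_n), range(c_n),
    ):
        x = inp[n][oh * sh + kh * dh][ow * sw + kw * dw][c]
        out[n][oh][ow][f] += x * k_at(f, kh, kw, c)
    return out


def conv_2d_nhwc_hwcf(inp, ker, strides=(1, 1), dilations=(1, 1)):
    # Filter indexed as K[kh, kw, c, f].
    kh_n, kw_n, c_n, f_n = len(ker), len(ker[0]), len(ker[0][0]), len(ker[0][0][0])
    return _conv(inp, lambda f, kh, kw, c: ker[kh][kw][c][f],
                 f_n, kh_n, kw_n, c_n, strides, dilations)


def conv_2d_nhwc_fhwc(inp, ker, strides=(1, 1), dilations=(1, 1)):
    # Filter indexed as K[f, kh, kw, c], i.e. TOSA's [OC, KH, KW, IC].
    f_n, kh_n, kw_n, c_n = len(ker), len(ker[0]), len(ker[0][0]), len(ker[0][0][0])
    return _conv(inp, lambda f, kh, kw, c: ker[f][kh][kw][c],
                 f_n, kh_n, kw_n, c_n, strides, dilations)


def fhwc_to_hwcf(ker):
    """Transpose with perm [1, 2, 3, 0]: out[kh][kw][c][f] = ker[f][kh][kw][c]."""
    f_n, kh_n, kw_n, c_n = len(ker), len(ker[0]), len(ker[0][0]), len(ker[0][0][0])
    return [[[[ker[f][kh][kw][c] for f in range(f_n)]
              for c in range(c_n)]
             for kw in range(kw_n)]
            for kh in range(kh_n)]


rng = random.Random(0)
inp = rand_tensor(rng, 1, 5, 4, 2)       # NHWC: N=1, H=5, W=4, C=2
ker_fhwc = rand_tensor(rng, 3, 2, 2, 2)  # FHWC: F=3, KH=2, KW=2, C=2

direct = conv_2d_nhwc_fhwc(inp, ker_fhwc, dilations=(2, 1))
via_transpose = conv_2d_nhwc_hwcf(inp, fhwc_to_hwcf(ker_fhwc), dilations=(2, 1))
print(direct == via_transpose)  # True
```

Both paths sum exactly the same products, so feeding the TOSA-ordered filter straight into the FHWC op gives the same result as transposing first and using the HWCF op.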

This commit updates the TOSA->LinAlg lowering so that `tosa.conv2d` is lowered to `linalg.conv_2d_nhwc_fhwc`, removing the need to introduce a transposition operation and making the mapping 1-1. It also adds a `linalg.conv_2d_nhwc_fhwc_q` quantized operation to the LinAlg dialect so that the same direct 1-1 mapping can be applied to the quantized variant.
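
The semantics of the new `linalg.conv_2d_nhwc_fhwc_q` op, as given by its OpDSL definition in this patch (`(I - IZp) * (K[D.f, D.kh, D.kw, D.c] - KZp)` accumulated into `O`), can be modeled roughly as follows. This is an illustrative sketch, not the generated op: Python integers stand in for the promoted accumulator type `U` (e.g. i32), so fixed-width overflow is not modeled, and stride and dilation are fixed to 1 for brevity.

```python
import itertools


def conv_2d_nhwc_fhwc_q(inp, ker, izp, kzp):
    """Naive model of O[n, oh, ow, f] += (I[...] - IZp) * (K[f, kh, kw, c] - KZp).

    Assumes stride = dilation = 1; Python ints stand in for the accumulator
    type, so fixed-width overflow is not modeled.
    """
    n_n, ih, iw = len(inp), len(inp[0]), len(inp[0][0])
    f_n, kh_n, kw_n, c_n = len(ker), len(ker[0]), len(ker[0][0]), len(ker[0][0][0])
    oh_n, ow_n = ih - kh_n + 1, iw - kw_n + 1
    out = [[[[0] * f_n for _ in range(ow_n)] for _ in range(oh_n)] for _ in range(n_n)]
    for n, oh, ow, f, kh, kw, c in itertools.product(
        range(n_n), range(oh_n), range(ow_n), range(f_n),
        range(kh_n), range(kw_n), range(c_n),
    ):
        x = inp[n][oh + kh][ow + kw][c] - izp  # input minus its zero point
        w = ker[f][kh][kw][c] - kzp            # FHWC filter minus its zero point
        out[n][oh][ow][f] += x * w
    return out


# One pixel, two channels, one 1x1 filter in TOSA/FHWC order [OC, KH, KW, IC].
inp = [[[[3, 5]]]]   # shape 1x1x1x2
ker = [[[[2, -1]]]]  # shape 1x1x1x2
print(conv_2d_nhwc_fhwc_q(inp, ker, izp=1, kzp=1))
# (3 - 1) * (2 - 1) + (5 - 1) * (-1 - 1) = 2 - 8 -> [[[[-6]]]]
```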

This commit does not add any new lit tests but repurposes the existing TosaToLinalgNamed tests by removing the checks for transpositions and updating the targeted LinAlg operations from `linalg.conv_2d_nhwc_hwcf` to `linalg.conv_2d_nhwc_fhwc`.
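
The result types in the updated tests, such as `tensor<1x45x40x28xf32>` for a `1x49x42x27` input, a `28x3x3x27` filter and `dilation = [2, 1]`, follow from the usual convolution output-size rule. The sketch below assumes that rule with TOSA's `[top, bottom, left, right]` pad order; the helper name is hypothetical. With the FHWC filter, the output channel count is simply dimension 0 of the TOSA weight.

```python
def conv2d_output_shape(input_shape, weight_shape, pad=(0, 0, 0, 0),
                        stride=(1, 1), dilation=(1, 1)):
    """NHWC input, TOSA/FHWC weight [OC, KH, KW, IC]; pad is [top, bottom, left, right]."""
    n, ih, iw, _ = input_shape
    oc, kh, kw, _ = weight_shape
    top, bottom, left, right = pad
    oh = (ih + top + bottom - dilation[0] * (kh - 1) - 1) // stride[0] + 1
    ow = (iw + left + right - dilation[1] * (kw - 1) - 1) // stride[1] + 1
    return (n, oh, ow, oc)


# @conv2d_f32: 1x49x42x27 input, 28x3x3x27 weights, dilation [2, 1]
print(conv2d_output_shape((1, 49, 42, 27), (28, 3, 3, 27), dilation=(2, 1)))
# (1, 45, 40, 28)

# @conv2d_padded_f32: 1x47x40x28 input, 28x3x3x28 weights, pad 1 on every side
print(conv2d_output_shape((1, 47, 40, 28), (28, 3, 3, 28),
                          pad=(1, 1, 1, 1), dilation=(2, 1)))
# (1, 45, 40, 28)
```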


Change-Id: I3b63860806dd9c289acdf013498b6c0e4f221325

From 58be62d9290bfaad4d0f4f21b97d40f7e073dd26 Mon Sep 17 00:00:00 2001
From: Jack Frankland <jack.frankland at arm.com>
Date: Fri, 22 Sep 2023 12:11:07 +0100
Subject: [PATCH] [mlir][tosa][linalg] Apply direct tosa -> linalg Conv2D
 lowering

TOSA defines the filter channel ordering for the 2D convolution
operation `tosa.conv2d` as `[OC, KH, KW, IC]`. The LinAlg dialect
supports the `[F, H, W, C]` and `[H, W, C, F]` orderings via the
`linalg.conv_2d_nhwc_fhwc` and `linalg.conv_2d_nhwc_hwcf` operations
respectively, where `F == OC`, `KH == H`, `KW == W` and `C == IC`.

Currently `tosa.conv2d` is lowered to `linalg.conv_2d_nhwc_hwcf`,
meaning we need to insert a transposition operation to permute the
filter channels, `[F, H, W, C]` -> `[H, W, C, F]`, before they can be
passed as weights to the linalg op. An analogous transformation needs
to be applied to the quantized operation, which lowers to
`linalg.conv_2d_nhwc_hwcf_q`.

This commit updates the TOSA->LinAlg lowering so that `tosa.conv2d` is
lowered to `linalg.conv_2d_nhwc_fhwc`, removing the need to introduce
a transposition operation and making the mapping 1-1. It also adds a
`linalg.conv_2d_nhwc_fhwc_q` quantized operation to the LinAlg dialect
so that the same direct 1-1 mapping can be applied to the quantized
variant.

This commit does not add any new lit tests but repurposes the existing
TosaToLinalgNamed tests by removing the checks for transpositions and
updating the targeted LinAlg operations from `linalg.conv_2d_nhwc_hwcf`
to `linalg.conv_2d_nhwc_fhwc`.

Signed-off-by: Jack Frankland <jack.frankland at arm.com>
Change-Id: I3b63860806dd9c289acdf013498b6c0e4f221325
---
 .../Linalg/IR/LinalgNamedStructuredOps.yaml   | 137 ++++++++++++++++++
 .../TosaToLinalg/TosaToLinalgNamed.cpp        |  43 +++---
 .../linalg/opdsl/ops/core_named_ops.py        |  30 ++++
 .../TosaToLinalg/tosa-to-linalg-named.mlir    |  20 +--
 4 files changed, 196 insertions(+), 34 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 44bcbbab2bbe9de..cd64b813c11e532 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -2575,6 +2575,143 @@ structured_op: !LinalgStructuredOpConfig
                     - !ScalarExpression
                       scalar_arg: KZp
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: conv_2d_nhwc_fhwc_q
+  cpp_class_name: Conv2DNhwcFhwcQOp
+  doc: |-
+    Performs 2-D convolution with zero point offsets.
+
+    Layout:
+      * Input: NHWC.
+      * Kernel: FHWC.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output. This includes the zero
+    point offsets common to quantized operations.
+  implements:
+  - LinalgConvolutionOpInterface
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0,
+      s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9)>
+  - !LinalgOperandDefConfig
+    name: K
+    kind: input_tensor
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s10,
+      s3, s7, s9)>
+  - !LinalgOperandDefConfig
+    name: IZp
+    kind: scalar
+    type_var: I32
+  - !LinalgOperandDefConfig
+    name: KZp
+    kind: scalar
+    type_var: I32
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0,
+      s1, s5, s10)>
+  - !LinalgOperandDefConfig
+    name: strides
+    kind: index_attr
+    index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] ->
+      (s2, s6)>
+    default_indices:
+    - 1
+    - 1
+  - !LinalgOperandDefConfig
+    name: dilations
+    kind: index_attr
+    index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] ->
+      (s4, s8)>
+    default_indices:
+    - 1
+    - 1
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10] -> (d0, d1 * s2 + d4 * s4, d2 * s6 + d5 * s8, d6)>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10] -> (d3, d4, d5, d6)>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10] -> ()>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10] -> ()>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10] -> (d0, d1, d2, d3)>
+  iterator_types:
+  - parallel
+  - parallel
+  - parallel
+  - parallel
+  - reduction
+  - reduction
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: binary
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: O
+        - !ScalarExpression
+          scalar_fn:
+            kind: binary
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              scalar_fn:
+                kind: binary
+                fn_name: sub
+                operands:
+                - !ScalarExpression
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast_signed
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: I
+                - !ScalarExpression
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast_signed
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: IZp
+            - !ScalarExpression
+              scalar_fn:
+                kind: binary
+                fn_name: sub
+                operands:
+                - !ScalarExpression
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast_signed
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: K
+                - !ScalarExpression
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast_signed
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: KZp
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: conv_2d_nchw_fchw
   cpp_class_name: Conv2DNchwFchwOp
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 62ec44bf9c1e1e1..4214bb57563285c 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -248,25 +248,28 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
     pad.resize(pad.size() + 2, 0);
     input = applyPad(loc, input, pad, zeroAttr, rewriter);
 
-    // Transpose the kernel to match dimension ordering of the linalg
-    // convolution operation.
-    // TODO(suderman): See if this can be efficiently folded - check whether
-    // the input is used anywhere else, if not fold the constant.
-    SmallVector<int64_t> weightPerm;
-    for (int i = 1; i < resultTy.getRank(); i++)
-      weightPerm.push_back(i);
-    weightPerm.push_back(0);
-
-    SmallVector<int64_t> newWeightShape;
-    for (auto dim : weightPerm)
-      newWeightShape.push_back(weightShape[dim]);
-    auto weightPermAttr = rewriter.getI64TensorAttr(weightPerm);
-    Value weightPermValue =
-        rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
-    Type newWeightTy =
-        RankedTensorType::get(newWeightShape, weightTy.getElementType());
-    weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
-                                                weightPermValue);
+    // For Conv3D, transpose the kernel to match the dimension ordering of the
+    // linalg convolution operation. Conv2D has a 1-1 mapping in linalg, so it
+    // is better to map it directly and transpose later if desired.
+    if (5 == inputTy.getRank()) {
+      // TODO(suderman): See if this can be efficiently folded - check whether
+      // the input is used anywhere else, if not fold the constant.
+      SmallVector<int64_t> weightPerm;
+      for (int i = 1; i < resultTy.getRank(); i++)
+        weightPerm.push_back(i);
+      weightPerm.push_back(0);
+
+      SmallVector<int64_t> newWeightShape;
+      for (auto dim : weightPerm)
+        newWeightShape.push_back(weightShape[dim]);
+      auto weightPermAttr = rewriter.getI64TensorAttr(weightPerm);
+      Value weightPermValue =
+          rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
+      Type newWeightTy =
+          RankedTensorType::get(newWeightShape, weightTy.getElementType());
+      weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
+                                                  weightPermValue);
+    }
 
     auto resultZeroAttr = rewriter.getZeroAttr(resultETy);
     Value emptyTensor = rewriter.create<tensor::EmptyOp>(
@@ -977,7 +980,7 @@ void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
     RewritePatternSet *patterns) {
   patterns->add<
       // clang-format off
-      ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcHwcfOp, linalg::Conv2DNhwcHwcfQOp>,
+      ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcFhwcOp, linalg::Conv2DNhwcFhwcQOp>,
       ConvConverter<tosa::Conv3DOp, linalg::Conv3DNdhwcDhwcfOp, linalg::Conv3DNdhwcDhwcfQOp>,
       DepthwiseConvConverter,
       MatMulConverter,
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 6eae3d916c92882..a8f8f8e0fbd68b4 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -693,6 +693,36 @@ def conv_2d_nhwc_hwcf_q(
     ) * (TypeFn.cast_signed(U, K[D.kh, D.kw, D.c, D.f]) - TypeFn.cast_signed(U, KZp))
 
 
+@linalg_structured_op
+def conv_2d_nhwc_fhwc_q(
+    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C),
+    K=TensorDef(T2, S.F, S.KH, S.KW, S.C),
+    IZp=ScalarDef(I32),
+    KZp=ScalarDef(I32),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True),
+    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+    """Performs 2-D convolution with zero point offsets.
+
+    Layout:
+      * Input: NHWC.
+      * Kernel: FHWC.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output. This includes the zero
+    point offsets common to quantized operations.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c)
+    O[D.n, D.oh, D.ow, D.f] += (
+        TypeFn.cast_signed(
+            U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]
+        )
+        - TypeFn.cast_signed(U, IZp)
+    ) * (TypeFn.cast_signed(U, K[D.f, D.kh, D.kw, D.c]) - TypeFn.cast_signed(U, KZp))
+
+
 @linalg_structured_op
 def conv_2d_nchw_fchw(
     I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW),
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index bf970c84832e9e5..b601bfb28a4f280 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -363,13 +363,11 @@ func.func @avg_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>)
 
 // CHECK-LABEL: @conv2d_i8
 func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi8>, %bias: tensor<28xi8>) -> () {
-  // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]>
-  // CHECK: %[[W:.+]] = tosa.transpose %arg1, %[[PERM]]
   // CHECK: %[[M_IN:.+]] = tensor.empty()
   // CHECK: %[[CST:.+]] = arith.constant 0
   // CHECK: %[[FILL:.+]] = linalg.fill
   // CHECK: %[[B_IN:.+]] = tensor.empty()
-  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[W]], %c0_i32_0, %c0_i32_1 : tensor<1x49x42x27xi8>, tensor<1x1x27x28xi8>, i32, i32) outs(%[[FILL]] : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
+  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %c0_i32_0, %c0_i32_1 : tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, i32, i32) outs(%[[FILL]] : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
   // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xi8>, tensor<1x45x40x28xi32>) outs(%[[B_IN]] : tensor<1x45x40x28xi32>)
   // CHECK:   arith.extsi
   // CHECK:   arith.addi
@@ -385,13 +383,11 @@ func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi
 
 // CHECK-LABEL: @conv2d_f32
 func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
-  // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]>
-  // CHECK: %[[W:.+]] = tosa.transpose %arg1, %[[PERM]]
   // CHECK: %[[M_IN:.+]] = tensor.empty()
   // CHECK: %[[CST:.+]] = arith.constant 0
   // CHECK: %[[FILL:.+]] = linalg.fill
   // CHECK: %[[B_IN:.+]] = tensor.empty()
-  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[W]] : tensor<1x49x42x27xf32>, tensor<3x3x27x28xf32>) outs(%[[FILL]] : tensor<1x45x40x28xf32>)
+  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%[[FILL]] : tensor<1x45x40x28xf32>)
   // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<1x45x40x28xf32>) outs(%[[B_IN]] : tensor<1x45x40x28xf32>)
   // CHECK:   arith.addf
   // CHECK:   linalg.yield
@@ -408,13 +404,11 @@ func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27
 func.func @conv2d_dyn(%input: tensor<?x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
   // CHECK: %[[C0:.+]] = arith.constant 0
   // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]]
-  // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]>
-  // CHECK: %[[W:.+]] = tosa.transpose %arg1, %[[PERM]]
   // CHECK: %[[M_IN:.+]] = tensor.empty(%[[BATCH]])
   // CHECK: %[[CST:.+]] = arith.constant 0
   // CHECK: %[[FILL:.+]] = linalg.fill
   // CHECK: %[[B_IN:.+]] = tensor.empty(%[[BATCH]])
-  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[W]] : tensor<?x49x42x27xf32>, tensor<3x3x27x28xf32>) outs(%[[FILL]] : tensor<?x45x40x28xf32>)
+  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%[[FILL]] : tensor<?x45x40x28xf32>)
   // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<?x45x40x28xf32>) outs(%[[B_IN]] : tensor<?x45x40x28xf32>)
   // CHECK:   %[[ADD:.+]] = arith.addf
   // CHECK:   linalg.yield %[[ADD]] : f32
@@ -468,13 +462,11 @@ func.func @conv2d_dyn_w_h(%input: tensor<1x?x?x27xf32>, %weights: tensor<28x3x3x
   // CHECK: %[[W_OUT:.+]] = arith.addi %[[DIVIDED_0]], %[[ONE_0]] : index
 
   // Running convolution
-  // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]>
-  // CHECK: %[[WEIGHT:.+]] = tosa.transpose %arg1, %[[PERM]]
   // CHECK: %[[M_IN:.+]] = tensor.empty(%[[H_OUT]], %[[W_OUT]])
   // CHECK: %[[CST:.+]] = arith.constant 0
   // CHECK: %[[FILL:.+]] = linalg.fill
   // CHECK: %[[B_IN:.+]] = tensor.empty(%[[H_OUT]], %[[W_OUT]])
-  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[WEIGHT]] : tensor<1x?x?x27xf32>, tensor<3x3x27x28xf32>) outs(%[[FILL]] : tensor<1x?x?x28xf32>)
+  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x?x?x27xf32>, tensor<28x3x3x27xf32>) outs(%[[FILL]] : tensor<1x?x?x28xf32>)
   // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<1x?x?x28xf32>) outs(%[[B_IN]] : tensor<1x?x?x28xf32>)
   // CHECK:   %[[ADD:.+]] = arith.addf
   // CHECK:   linalg.yield %[[ADD]] : f32
@@ -489,7 +481,7 @@ func.func @conv2d_padded_f32(%input: tensor<1x47x40x28xf32>, %weights: tensor<28
   // CHECK: %[[C0:.+]] = arith.constant 0
   // CHECK: tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
   // CHECK:   tensor.yield %[[C0]]
-  // CHECK: linalg.conv_2d_nhwc_hwcf
+  // CHECK: linalg.conv_2d_nhwc_fhwc
   %0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x47x40x28xf32>, tensor<28x3x3x28xf32>, tensor<28xf32>) -> tensor<1x45x40x28xf32>
   return
 }
@@ -501,7 +493,7 @@ func.func @conv2d_quant(%arg0 : tensor<1x12x12x1xi8>, %arg1 : tensor<1024x3x3x1x
   // CHECK:   %[[C22:.+]] = arith.constant -22
   // CHECK: tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
   // CHECK:   tensor.yield %[[C22]]
-  // CHECK: linalg.conv_2d_nhwc_hwcf_q
+  // CHECK: linalg.conv_2d_nhwc_fhwc_q
   %0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 1, 1, 1, 1>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, stride = array<i64: 1, 1>} : (tensor<1x12x12x1xi8>, tensor<1024x3x3x1xi8>, tensor<1024xi32>) -> tensor<1x12x12x1024xi32>
   return
 }
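
For reference, the kernel permutation that `ConvConverter` still builds on the Conv3D path (the `weightPerm` loop: push `1 .. rank - 1`, then `0`) can be mirrored in Python as below. This is a sketch of the loop in the C++ hunk above, not MLIR API code. The loop is bounded by `resultTy.getRank()`, which equals the weight rank for these convolutions, and the Conv3D weight shape used here is a made-up example.

```python
def kernel_perm(rank):
    """Mirror of the C++ loop: push 1..rank-1, then 0 (moves OC to the end)."""
    perm = list(range(1, rank))
    perm.append(0)
    return perm


def permute_shape(shape, perm):
    # Transpose convention: new_shape[i] = shape[perm[i]]
    return tuple(shape[d] for d in perm)


# Rank 4: the old Conv2D path produced the HWCF weight type seen in the removed CHECKs.
print(permute_shape((28, 3, 3, 27), kernel_perm(4)))  # (3, 3, 27, 28)

# Rank 5: [OC, KD, KH, KW, IC] -> [KD, KH, KW, IC, OC] for linalg.conv_3d_ndhwc_dhwcf.
print(kernel_perm(5))  # [1, 2, 3, 4, 0]
print(permute_shape((16, 2, 3, 3, 8), kernel_perm(5)))  # (2, 3, 3, 8, 16)
```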


