[Mlir-commits] [mlir] 293c21d - [mlir][tosa] Improve lowering of tosa.conv2d (#74143)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Dec 2 04:29:14 PST 2023


Author: Spenser Bauman
Date: 2023-12-02T12:29:10Z
New Revision: 293c21db9381fde27cda46e5c3ff8bf8578e5399

URL: https://github.com/llvm/llvm-project/commit/293c21db9381fde27cda46e5c3ff8bf8578e5399
DIFF: https://github.com/llvm/llvm-project/commit/293c21db9381fde27cda46e5c3ff8bf8578e5399.diff

LOG: [mlir][tosa] Improve lowering of tosa.conv2d (#74143)

The existing lowering of tosa.conv2d emits a separate linalg.generic
operator to add the bias after computing the convolution.

This change eliminates that additional step by broadcasting the bias to
the result shape and using it as the initial value of the output (outs)
operand of the generated linalg.conv_2d_* operation.

Rather than:

    %init = tensor.empty()
    %conv = linalg.conv_2d ins(%A, %B) outs(%init)

    %init2 = tensor.empty()
    %result = linalg.generic ins(%conv, %bias) outs(%init2) {
      // perform add operation
    }

The lowering now produces:

    %init = tensor.empty()
    %bias_expanded = linalg.broadcast ins(%bias) outs(%init)
    %conv = linalg.conv_2d ins(%A, %B) outs(%bias_expanded)

This is the same strategy as
https://github.com/llvm/llvm-project/pull/73049 applied to convolutions.

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 0accd9d1986a..b3fbc7dd0b22 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -344,15 +344,6 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
                                                   weightPermValue);
     }
 
-    auto resultZeroAttr = rewriter.getZeroAttr(resultETy);
-    Value emptyTensor = rewriter.create<tensor::EmptyOp>(
-        loc, resultTy.getShape(), resultETy, filteredDims);
-    Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
-    Value zeroTensor = rewriter
-                           .create<linalg::FillOp>(loc, ValueRange{zero},
-                                                   ValueRange{emptyTensor})
-                           .result();
-
     // Extract the attributes for convolution.
     ArrayRef<int64_t> stride = strideTosaAttr;
     ArrayRef<int64_t> dilation = dilationTosaAttr;
@@ -361,18 +352,12 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
     auto strideAttr = rewriter.getI64TensorAttr(stride);
     auto dilationAttr = rewriter.getI64TensorAttr(dilation);
 
-    // Create maps for the bias broadcasting
-    SmallVector<AffineMap, 4> indexingMaps;
-    indexingMaps.push_back(AffineMap::get(
-        /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
-        {rewriter.getAffineDimExpr(resultTy.getRank() - 1)},
-        rewriter.getContext()));
-    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));
-    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));
-
     Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
         loc, resultTy.getShape(), resultETy, filteredDims);
 
+    Value broadcastBias =
+        linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);
+
     if (isQuantized) {
       auto quantizationInfo = *op.getQuantizationInfo();
       auto iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp());
@@ -380,38 +365,25 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
 
       auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
       auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
+
       Value conv =
           rewriter
               .create<LinalgConvQOp>(
                   loc, resultTy, ValueRange{input, weight, iZpVal, kZpVal},
-                  ValueRange{zeroTensor}, strideAttr, dilationAttr)
+                  ValueRange{broadcastBias}, strideAttr, dilationAttr)
               ->getResult(0);
-      Value result = linalgIntBroadcastExtSIAdd(rewriter, loc, bias, conv,
-                                                biasEmptyTensor, indexingMaps);
-      rewriter.replaceOp(op, result);
+
+      rewriter.replaceOp(op, conv);
       return success();
     }
 
     Value conv = rewriter
                      .create<LinalgConvOp>(
                          loc, resultTy, ValueRange{input, weight},
-                         ValueRange{zeroTensor}, strideAttr, dilationAttr)
+                         ValueRange{broadcastBias}, strideAttr, dilationAttr)
                      ->getResult(0);
 
-    Value result =
-        rewriter
-            .create<linalg::GenericOp>(
-                loc, resultTy, ValueRange({bias, conv}), biasEmptyTensor,
-                indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()),
-                [&](OpBuilder &nestedBuilder, Location nestedLoc,
-                    ValueRange args) {
-                  Value added = nestedBuilder.create<arith::AddFOp>(
-                      loc, args[0], args[1]);
-                  nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
-                })
-            .getResult(0);
-
-    rewriter.replaceOp(op, result);
+    rewriter.replaceOp(op, conv);
     return success();
   }
 };

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 230001f7633b..aa010e759a0f 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -378,16 +378,14 @@ func.func @avg_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>)
 func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi8>, %bias: tensor<28xi8>) -> () {
   // HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64>
   // HWCF: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[TRANSPOSE_DIMS]] : (tensor<28x1x1x27xi8>, tensor<4xi64>) -> tensor<1x1x27x28xi8>
-  // CHECK: %[[M_IN:.+]] = tensor.empty()
-  // CHECK: %[[CST:.+]] = arith.constant 0
-  // CHECK: %[[FILL:.+]] = linalg.fill
-  // CHECK: %[[B_IN:.+]] = tensor.empty()
-  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %c0_i32_0, %c0_i32_1 : tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, i32, i32) outs(%[[FILL]] : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
-  // HWCF: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]], %c0_i32_0, %c0_i32_1 : tensor<1x49x42x27xi8>, tensor<1x1x27x28xi8>, i32, i32) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
-  // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xi8>, tensor<1x45x40x28xi32>) outs(%[[B_IN]] : tensor<1x45x40x28xi32>)
+  // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x45x40x28xi32>
+  // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xi8>) outs(%[[INIT]] : tensor<1x45x40x28xi32>) {
   // CHECK:   arith.extsi
-  // CHECK:   arith.addi
   // CHECK:   linalg.yield
+  // CHECK: } -> tensor<1x45x40x28xi32>
+  // CHECK: linalg.conv_2d_nhwc_fhwc_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %c0_i32, %c0_i32_0 : tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, i32, i32) outs(%[[BROADCAST]] : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
+  // HWCF: linalg.conv_2d_nhwc_hwcf_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]], %c0_i32, %c0_i32_0 : tensor<1x49x42x27xi8>, tensor<1x1x27x28xi8>, i32, i32) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
+
   %0 = tosa.conv2d %input, %weights, %bias {dilation = array<i64: 2, 1>, pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>, stride = array<i64: 1, 1>} : (tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, tensor<28xi8>) -> tensor<1x45x40x28xi32>
   return
 }
@@ -401,15 +399,14 @@ func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi
 func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
   // HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64>
   // HWCF: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[TRANSPOSE_DIMS]] : (tensor<28x3x3x27xf32>, tensor<4xi64>) -> tensor<3x3x27x28xf32>
-  // CHECK: %[[M_IN:.+]] = tensor.empty()
-  // CHECK: %[[CST:.+]] = arith.constant 0
-  // CHECK: %[[FILL:.+]] = linalg.fill
-  // CHECK: %[[B_IN:.+]] = tensor.empty()
-  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%[[FILL]] : tensor<1x45x40x28xf32>)
-  // HWCF: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]] : tensor<1x49x42x27xf32>, tensor<3x3x27x28xf32>) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x45x40x28xf32>
-  // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<1x45x40x28xf32>) outs(%[[B_IN]] : tensor<1x45x40x28xf32>)
-  // CHECK:   arith.addf
+
+  // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x45x40x28xf32>
+  // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xf32>) outs(%[[INIT]] : tensor<1x45x40x28xf32>) {
   // CHECK:   linalg.yield
+  // CHECK: } -> tensor<1x45x40x28xf32>
+  // CHECK: linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%1 : tensor<1x45x40x28xf32>) -> tensor<1x45x40x28xf32>
+
+  // HWCF: linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]] : tensor<1x49x42x27xf32>, tensor<3x3x27x28xf32>) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x45x40x28xf32>
   %0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>) -> tensor<1x45x40x28xf32>
   return
 }
@@ -421,16 +418,14 @@ func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27
 
 // CHECK-LABEL: @conv2d_dyn
 func.func @conv2d_dyn(%input: tensor<?x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
-  // CHECK: %[[C0:.+]] = arith.constant 0
-  // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]]
-  // CHECK: %[[M_IN:.+]] = tensor.empty(%[[BATCH]])
-  // CHECK: %[[CST:.+]] = arith.constant 0
-  // CHECK: %[[FILL:.+]] = linalg.fill
-  // CHECK: %[[B_IN:.+]] = tensor.empty(%[[BATCH]])
-  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%[[FILL]] : tensor<?x45x40x28xf32>)
-  // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<?x45x40x28xf32>) outs(%[[B_IN]] : tensor<?x45x40x28xf32>)
-  // CHECK:   %[[ADD:.+]] = arith.addf
-  // CHECK:   linalg.yield %[[ADD]] : f32
+  // CHECK: %[[C0:.+]] = arith.constant 0 : index
+  // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]] : tensor<?x49x42x27xf32>
+  // CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x45x40x28xf32>
+  // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xf32>) outs(%[[INIT]] : tensor<?x45x40x28xf32>) {
+  // CHECK: ^bb0(%[[IN:.+]]: f32, %{{.+}}: f32):
+  // CHECK:   linalg.yield %[[IN]] : f32
+  // CHECK: } -> tensor<?x45x40x28xf32>
+  // CHECK: %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%[[BROADCAST]] : tensor<?x45x40x28xf32>) -> tensor<?x45x40x28xf32>
   %0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<?x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>) -> tensor<?x45x40x28xf32>
   return
 }
@@ -481,14 +476,12 @@ func.func @conv2d_dyn_w_h(%input: tensor<1x?x?x27xf32>, %weights: tensor<28x3x3x
   // CHECK: %[[W_OUT:.+]] = arith.addi %[[DIVIDED_0]], %[[ONE_0]] : index
 
   // Running convolution
-  // CHECK: %[[M_IN:.+]] = tensor.empty(%[[H_OUT]], %[[W_OUT]])
-  // CHECK: %[[CST:.+]] = arith.constant 0
-  // CHECK: %[[FILL:.+]] = linalg.fill
-  // CHECK: %[[B_IN:.+]] = tensor.empty(%[[H_OUT]], %[[W_OUT]])
-  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x?x?x27xf32>, tensor<28x3x3x27xf32>) outs(%[[FILL]] : tensor<1x?x?x28xf32>)
-  // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<1x?x?x28xf32>) outs(%[[B_IN]] : tensor<1x?x?x28xf32>)
-  // CHECK:   %[[ADD:.+]] = arith.addf
-  // CHECK:   linalg.yield %[[ADD]] : f32
+  // CHECK: %[[INIT:.+]] = tensor.empty(%[[H_OUT]], %[[W_OUT]]) : tensor<1x?x?x28xf32>
+  // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xf32>) outs(%[[INIT]] : tensor<1x?x?x28xf32>) {
+  // CHECK:   linalg.yield
+  // CHECK: } -> tensor<1x?x?x28xf32>
+  // CHECK: linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x?x?x27xf32>, tensor<28x3x3x27xf32>) outs(%17 : tensor<1x?x?x28xf32>) -> tensor<1x?x?x28xf32>
+
   %0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x?x?x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>) -> tensor<1x?x?x28xf32>
   return
 }
@@ -678,52 +671,52 @@ func.func @depthwise_conv2d_dyn_w_h(%arg0: tensor<2x?x?x3xf32>, %arg1: tensor<3x
 
 // -----
 
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d4)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+
 // CHECK-LABEL: @conv3d_f32
 func.func @conv3d_f32(%input: tensor<1x49x48x47x27xf32>, %weights: tensor<28x3x4x5x27xf32>, %bias: tensor<28xf32>) -> () {
-  // CHECK-DAG: %[[PERMS:.+]] = arith.constant dense<[1, 2, 3, 4, 0]>
-  // CHECK-DAG: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[PERMS]]
-  // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty()
-  // CHECK-DAG: %[[ZERO:.+]] = arith.constant 0
-  // CHECK-DAG: %[[FILL:.+]] = linalg.fill ins(%[[ZERO]] : f32) outs(%[[EMPTY]] : tensor<1x47x45x43x28xf32>)
-  // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty()
-  // CHECK-DAG: %[[CONV3D:.+]] = linalg.conv_3d_ndhwc_dhwcf
+  // CHECK-DAG:  %[[PERMS:.+]] = arith.constant dense<[1, 2, 3, 4, 0]>
+  // CHECK-DAG:  %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[PERMS]]
+  // CHECK-DAG:  %[[INIT:.+]] = tensor.empty() : tensor<1x47x45x43x28xf32>
+  // CHECK:      %[[BROADCAST:.+]] = linalg.generic
+  // CHECK-SAME: {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
+  // CHECK-SAME: ins(%arg2 : tensor<28xf32>) outs(%1 : tensor<1x47x45x43x28xf32>) {
+  // CHECK:      ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+  // CHECK:        linalg.yield %[[IN]] : f32
+  // CHECK:      } -> tensor<1x47x45x43x28xf32>
+  // CHECK:      linalg.conv_3d_ndhwc_dhwcf
   // CHECK-SAME: {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
   // CHECK-SAME: ins(%arg0, %[[TRANSPOSE]] : tensor<1x49x48x47x27xf32>, tensor<3x4x5x27x28xf32>)
-  // CHECK-SAME: outs(%[[FILL]] : tensor<1x47x45x43x28xf32>) -> tensor<1x47x45x43x28xf32>
-  // CHECK: %[[GENERIC:.+]] = linalg.generic
-  // CHECK-SAME: {indexing_maps = [#map, #map1, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
-  // CHECK-SAME: ins(%arg2, %[[CONV3D]] : tensor<28xf32>, tensor<1x47x45x43x28xf32>)
-  // CHECK-SAME: outs(%[[EMPTY]] : tensor<1x47x45x43x28xf32>) {
-  // CHECK: ^bb0(%[[A1:.+]]: f32, %[[A2:.+]]: f32, %{{.+}}: f32):
-  // CHECK: %[[ADD:.+]] = arith.addf %[[A1]], %[[A2]] : f32
-  // CHECK: linalg.yield %[[ADD]]
+  // CHECK-SAME: outs(%[[BROADCAST]] : tensor<1x47x45x43x28xf32>) -> tensor<1x47x45x43x28xf32>
   %0 = tosa.conv3d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf32>, tensor<28x3x4x5x27xf32>, tensor<28xf32>)  -> tensor<1x47x45x43x28xf32>
   return
 }
 
 // -----
 
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d4)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+
 // CHECK-LABEL: @conv3d_i8
 func.func @conv3d_i8(%input: tensor<1x49x48x47x27xi8>, %weights: tensor<28x3x4x5x27xi8>, %bias: tensor<28xi32>) -> () {
-    // CHECK-DAG: %[[PERMS:.+]] = arith.constant dense<[1, 2, 3, 4, 0]>
-  // CHECK-DAG: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[PERMS]]
-  // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty()
-  // CHECK-DAG: %[[ZERO:.+]] = arith.constant 0
-  // CHECK-DAG: %[[FILL:.+]] = linalg.fill ins(%[[ZERO]] : i32) outs(%[[EMPTY]] : tensor<1x47x45x43x28xi32>)
-  // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty()
-  // CHECK-DAG: %[[IZP:.+]] = arith.constant -128 : i32
-  // CHECK-DAG: %[[FZP:.+]] = arith.constant 42 : i32
-  // CHECK-DAG: %[[CONV3D:.+]] = linalg.conv_3d_ndhwc_dhwcf_q
+  // CHECK-DAG:  %[[PERMS:.+]] = arith.constant dense<[1, 2, 3, 4, 0]>
+  // CHECK-DAG:  %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[PERMS]]
+  // CHECK-DAG:  %[[INIT:.+]] = tensor.empty() : tensor<1x47x45x43x28xi32>
+  // CHECK:      %[[BROADCAST:.+]] = linalg.generic
+  // CHECK-SAME: {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
+  // CHECK-SAME: ins(%arg2 : tensor<28xi32>)
+  // CHECK-SAME: outs(%[[INIT]] : tensor<1x47x45x43x28xi32>) {
+  // CHECK:      ^bb0(%[[IN:.+]]: i32, %[[OUT:.+]]: i32):
+  // CHECK:        linalg.yield %[[IN]] : i32
+  // CHECK:      } -> tensor<1x47x45x43x28xi32>
+  // CHECK:      %[[IZP:.+]] = arith.constant -128 : i32
+  // CHECK:      %[[FZP:.+]] = arith.constant 42 : i32
+  // CHECK:      linalg.conv_3d_ndhwc_dhwcf_q
   // CHECK-SAME: {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
   // CHECK-SAME: ins(%arg0, %[[TRANSPOSE]], %[[IZP]], %[[FZP]] : tensor<1x49x48x47x27xi8>, tensor<3x4x5x27x28xi8>, i32, i32)
-  // CHECK-SAME: outs(%[[FILL]] : tensor<1x47x45x43x28xi32>) -> tensor<1x47x45x43x28xi32>
-  // CHECK: %[[GENERIC:.+]] = linalg.generic
-  // CHECK-SAME: {indexing_maps = [#map, #map1, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
-  // CHECK-SAME: ins(%arg2, %[[CONV3D]] : tensor<28xi32>, tensor<1x47x45x43x28xi32>)
-  // CHECK-SAME: outs(%[[EMPTY]] : tensor<1x47x45x43x28xi32>) {
-  // CHECK: ^bb0(%[[A1:.+]]: i32, %[[A2:.+]]: i32, %{{.+}}: i32):
-  // CHECK: %[[ADD:.+]] = arith.addi %[[A1]], %[[A2]] : i32
-  // CHECK: linalg.yield %[[ADD]]
+  // CHECK-SAME: outs(%[[BROADCAST]] : tensor<1x47x45x43x28xi32>) -> tensor<1x47x45x43x28xi32>
+
   %0 = tosa.conv3d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -128, weight_zp = 42>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xi8>, tensor<28x3x4x5x27xi8>, tensor<28xi32>)  -> tensor<1x47x45x43x28xi32>
   return
 }


        


More information about the Mlir-commits mailing list