[Mlir-commits] [mlir] 026fac2 - [mlir][linalg] Vectorization for conv_1d_ncw_fcw

Hanhan Wang llvmlistbot at llvm.org
Wed Sep 14 11:08:15 PDT 2022

Author: Stanley Winata
Date: 2022-09-14T11:07:53-07:00
New Revision: 026fac2a14cdf0b904ec83044d1f271e1ba2c5f9

URL: https://github.com/llvm/llvm-project/commit/026fac2a14cdf0b904ec83044d1f271e1ba2c5f9
DIFF: https://github.com/llvm/llvm-project/commit/026fac2a14cdf0b904ec83044d1f271e1ba2c5f9.diff

LOG: [mlir][linalg] Vectorization for conv_1d_ncw_fcw

Most computer vision Torch models use NCHW/NCW convolutions. A previous patch added a decomposition from conv2dNchw to conv1dNcw. To improve performance on Torch models, this patch adds a vectorization pattern for conv1dNcw, which consequently also improves the performance of conv2dNchw.

On IREE + Intel Xeon 8360 running ResNet50, this yields a ~7x speedup (from ~880 ms to 126 ms).

Reviewed By: nicolasvasilache, hanchung

Differential Revision: https://reviews.llvm.org/D133675




diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index badff43101872..271d96219a0cf 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -91,6 +91,12 @@ static AffineMap reindexIndexingMap(AffineMap map) {
   return res;
+/// Helper enum to represent conv1d input traversal order.
+enum class Conv1DOpOrder {
+  Ncw, // Corresponds to operation that traverses the input in (n, c, w) order.
+  Nwc  // Corresponds to operation that traverses the input in (n, w, c) order.
 /// Helper data structure to represent the result of vectorization.
 /// In certain specific cases, like terminators, we do not want to propagate/
 enum VectorizationStatus {
@@ -1312,14 +1318,23 @@ namespace {
 /// or
 /// ```
+///   Op def: (     n,     c,     w,    f,    kw )
+///    Iters: ({Par(), Par(), Par(), Red(), Red()})
+///   Layout: {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
+/// ```
+/// kw is unrolled, w is unrolled iff dilationW > 1.
+/// or
+/// ```
 ///   Op def: (     n,     w,     c,    kw )
 ///    Iters: ({Par(), Par(), Par(), Red()})
 ///   Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
 /// ```
 /// kw is unrolled, w is unrolled iff dilationW > 1.
-struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
-  Conv1DNwcGenerator(OpBuilder &builder, LinalgOp linalgOp, int strideW,
-                     int dilationW)
+struct Conv1DGenerator : public StructuredGenerator<LinalgOp> {
+  Conv1DGenerator(OpBuilder &builder, LinalgOp linalgOp, int strideW,
+                  int dilationW)
       : StructuredGenerator<LinalgOp>(builder, linalgOp), strideW(strideW),
         dilationW(dilationW) {
     // Determine whether `linalgOp` can be generated with this generator
@@ -1382,15 +1397,45 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
   /// kw is always unrolled.
   /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
   /// > 1.
-  FailureOr<Operation *> conv() {
+  FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
     if (!valid)
       return failure();
     int64_t nSize, wSize, cSize, kwSize, fSize;
-    // kernel{kw, c, f}
-    bindShapeDims(rhsShapedType, kwSize, cSize, fSize);
-    // out{n, w, f}
-    bindShapeDims(resShapedType, nSize, wSize);
+    SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
+    switch (conv1DOpOrder) {
+    case Conv1DOpOrder::Nwc:
+      // kernel{kw, c, f}
+      bindShapeDims(rhsShapedType, kwSize, cSize, fSize);
+      // out{n, w, f}
+      bindShapeDims(resShapedType, nSize, wSize);
+      lhsShape = {nSize,
+                  // iw = ow * sw + kw *  dw - 1
+                  //   (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
+                  // Perform the proper inclusive -> exclusive -> inclusive.
+                  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
+                      1,
+                  cSize};
+      rhsShape = {kwSize, cSize, fSize};
+      resShape = {nSize, wSize, fSize};
+      break;
+    case Conv1DOpOrder::Ncw:
+      // kernel{f, c, kw}
+      bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
+      // out{n, f, w}
+      bindShapeDims(resShapedType, nSize, fSize, wSize);
+      lhsShape = {nSize, cSize,
+                  // iw = ow * sw + kw *  dw - 1
+                  //   (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
+                  // Perform the proper inclusive -> exclusive -> inclusive.
+                  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
+                      1};
+      rhsShape = {fSize, cSize, kwSize};
+      resShape = {nSize, fSize, wSize};
+      break;
+    default:
+      return failure();
+    }
     vector::TransferWriteOp write;
     Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
@@ -1403,17 +1448,9 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
     Type lhsEltType = lhsShapedType.getElementType();
     Type rhsEltType = rhsShapedType.getElementType();
     Type resEltType = resShapedType.getElementType();
-    VectorType lhsType = VectorType::get(
-        {nSize,
-         // iw = ow * sw + kw *  dw - 1
-         //   (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
-         // Perform the proper inclusive -> exclusive -> inclusive.
-         ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
-         cSize},
-        lhsEltType);
-    VectorType rhsType = VectorType::get({kwSize, cSize, fSize}, rhsEltType);
-    VectorType resType = VectorType::get({nSize, wSize, fSize}, resEltType);
+    auto lhsType = VectorType::get(lhsShape, lhsEltType);
+    auto rhsType = VectorType::get(rhsShape, rhsEltType);
+    auto resType = VectorType::get(resShape, resEltType);
     // Read lhs slice of size {w * strideW + kw * dilationW, c, f} @ [0, 0,
     // 0].
     Value lhs = builder.create<vector::TransferReadOp>(
@@ -1425,6 +1462,29 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
     Value res = builder.create<vector::TransferReadOp>(
         loc, resType, resShaped, ValueRange{zero, zero, zero});
+    // The base vectorization case is input: {n,w,c}, weight: {kw,c,f}, output:
+    // {n,w,f}. To reuse the base pattern vectorization case, we do pre
+    // transpose on input, weight, and output.
+    switch (conv1DOpOrder) {
+    case Conv1DOpOrder::Nwc:
+      // Base case, so no transposes necessary.
+      break;
+    case Conv1DOpOrder::Ncw:
+      // To match base vectorization case, we pre-transpose current case.
+      // ncw -> nwc
+      static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
+      lhs = builder.create<vector::TransposeOp>(loc, lhs, permLhs);
+      // fcw -> wcf
+      static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
+      rhs = builder.create<vector::TransposeOp>(loc, rhs, permRhs);
+      // nfw -> nwf
+      static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
+      res = builder.create<vector::TransposeOp>(loc, res, permRes);
+      break;
+    default:
+      return failure();
+    }
     // Begin vector-only rewrite part
@@ -1478,6 +1538,22 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
     // End vector-only rewrite part
+    // The base vectorization case is output: {n,w,f}
+    // To reuse the result from base pattern vectorization case, we post
+    // transpose the base case result.
+    switch (conv1DOpOrder) {
+    case Conv1DOpOrder::Nwc:
+      // Base case, so no transposes necessary.
+      break;
+    case Conv1DOpOrder::Ncw:
+      // nwf -> nfw
+      static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
+      res = builder.create<vector::TransposeOp>(loc, res, perm);
+      break;
+    default:
+      return failure();
+    }
     // Write back res slice of size {n, w, f} @ [0, 0, 0].
     return builder
         .create<vector::TransferWriteOp>(loc, res, resShaped,
@@ -1619,7 +1695,7 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
   /// Entry point that transposes into the common form:
   ///   {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
-  FailureOr<Operation *> generateConv() {
+  FailureOr<Operation *> generateNwcConv() {
     AffineExpr n, w, f, kw, c;
     bindDims(ctx, n, w, f, kw, c);
     if (!iters({Par(), Par(), Par(), Red(), Red()}))
@@ -1629,7 +1705,23 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
     if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                 /*rhsIndex*/ {kw, c, f},
                 /*resIndex*/ {n, w, f}}))
-      return conv();
+      return conv(Conv1DOpOrder::Nwc);
+    return failure();
+  }
+  /// Entry point that transposes into the common form:
+  ///   {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
+  FailureOr<Operation *> generateNcwConv() {
+    AffineExpr n, w, f, kw, c;
+    bindDims(ctx, n, f, w, c, kw);
+    if (!iters({Par(), Par(), Par(), Red(), Red()}))
+      return failure();
+    if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
+                /*rhsIndex*/ {f, c, kw},
+                /*resIndex*/ {n, f, w}}))
+      return conv(Conv1DOpOrder::Ncw);
     return failure();
@@ -1668,8 +1760,11 @@ static FailureOr<Operation *> vectorizeConvolution(OpBuilder &b, LinalgOp op) {
   auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
   auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
   auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
-  Conv1DNwcGenerator e(b, op, stride, dilation);
-  auto res = e.generateConv();
+  Conv1DGenerator e(b, op, stride, dilation);
+  auto res = e.generateNwcConv();
+  if (succeeded(res))
+    return res;
+  res = e.generateNcwConv();
   if (succeeded(res))
     return res;
   return e.generateDilatedConv();

diff  --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
index e3c4d7d1c1de0..e7495765b3ec7 100644
--- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -185,6 +185,218 @@ func.func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<2x
 // Write the result back in one shot.
 //      CHECK:   vector.transfer_write %[[CONTRACT_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+// -----
+func.func @conv1d_ncw_4x8x2_memref(%input: memref<4x3x6xf32>, %filter: memref<8x3x1xf32>, %output: memref<4x8x2xf32>) {
+  linalg.conv_1d_ncw_fcw
+    {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
+    ins(%input, %filter : memref<4x3x6xf32>, memref<8x3x1xf32>)
+    outs(%output : memref<4x8x2xf32>)
+  return
+// CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+// CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+//      CHECK: func @conv1d_ncw_4x8x2_memref
+// CHECK-SAME: (%[[INPUT:.+]]: memref<4x3x6xf32>, %[[FILTER:.+]]: memref<8x3x1xf32>, %[[OUTPUT:.+]]: memref<4x8x2xf32>)
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+/// Read the whole data in one shot.
+//  CHECK-DAG:   %[[V_NWC_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+//  CHECK-DAG:  %[[V_NWC_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+//  CHECK-DAG:  %[[V_NWC_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+/// Transpose result to nwc format.
+//  CHECK-DAG:   %[[V_INPUT_R:.+]] = vector.transpose %[[V_NWC_INPUT_R]], [0, 2, 1]
+//  CHECK-DAG:  %[[V_FILTER_R:.+]] = vector.transpose %[[V_NWC_FILTER_R]], [2, 1, 0]
+//  CHECK-DAG:  %[[V_OUTPUT_R:.+]] = vector.transpose %[[V_NWC_OUTPUT_R]], [0, 2, 1]
+//      CHECK:   %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME:     {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x1x3xf32>
+//      CHECK:   %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME:     {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x1x3xf32>
+//      CHECK:    %[[V_FILTER:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<1x3x8xf32>
+//      CHECK:  %[[V_OUTPUT_0:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
+// CHECK-SAME:     {offsets = [0, 0, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
+//      CHECK:  %[[V_OUTPUT_1:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
+// CHECK-SAME:     {offsets = [0, 1, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
+/// w == 0, kw == 0
+//      CHECK:   %[[CONTRACT_0:.+]] = vector.contract {
+// CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
+// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME:     %[[V_INPUT_0]], %[[V_FILTER]], %[[V_OUTPUT_0]]
+// CHECK-SAME:     : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
+/// w == 1, kw == 0
+//      CHECK:   %[[CONTRACT_1:.+]] = vector.contract {
+// CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
+// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME:     %[[V_INPUT_1]], %[[V_FILTER]], %[[V_OUTPUT_1]]
+// CHECK-SAME:     : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
+/// w == 0, kw == 0
+//      CHECK:   %[[RES_0:.+]] = vector.insert_strided_slice %[[CONTRACT_0]], %[[V_OUTPUT_R]]
+// CHECK-SAME:     {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x8xf32> into vector<4x2x8xf32>
+/// w == 1, kw == 0
+//      CHECK:   %[[RES_1:.+]] = vector.insert_strided_slice %[[CONTRACT_1]], %[[RES_0]]
+// CHECK-SAME:     {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x8xf32> into vector<4x2x8xf32>
+/// Transpose result to ncw format.
+//  CHECK:  %[[RES_2:.+]] = vector.transpose %[[RES_1]], [0, 2, 1]
+// Write the result back in one shot.
+//      CHECK:   vector.transfer_write %[[RES_2]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+// -----
+func.func @conv1d_ncw_4x8x2_memref(%input: memref<4x3x6xf32>, %filter: memref<8x3x2xf32>, %output: memref<4x8x2xf32>) {
+  linalg.conv_1d_ncw_fcw
+    {dilations = dense<2> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
+    ins(%input, %filter : memref<4x3x6xf32>, memref<8x3x2xf32>)
+    outs(%output : memref<4x8x2xf32>)
+  return
+// CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+// CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+//      CHECK: func @conv1d_ncw_4x8x2_memref
+// CHECK-SAME: (%[[INPUT:.+]]: memref<4x3x6xf32>, %[[FILTER:.+]]: memref<8x3x2xf32>, %[[OUTPUT:.+]]: memref<4x8x2xf32>)
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+/// Read the whole data in one shot.
+//  CHECK-DAG:   %[[V_NWC_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+//  CHECK-DAG:   %[[V_NWC_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+//  CHECK-DAG:   %[[V_NWC_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+/// Transpose result to nwc format.
+//  CHECK-DAG:   %[[V_INPUT_R:.+]] = vector.transpose %[[V_NWC_INPUT_R]], [0, 2, 1]
+//  CHECK-DAG:  %[[V_FILTER_R:.+]] = vector.transpose %[[V_NWC_FILTER_R]], [2, 1, 0]
+//  CHECK-DAG:  %[[V_OUTPUT_R:.+]] = vector.transpose %[[V_NWC_OUTPUT_R]], [0, 2, 1]
+//      CHECK:   %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME:     {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
+//      CHECK:   %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME:     {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
+//      CHECK:   %[[V_INPUT_2:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME:     {offsets = [0, 2, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
+//      CHECK:   %[[V_INPUT_3:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME:     {offsets = [0, 5, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
+//      CHECK:  %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x3x8xf32>
+//      CHECK:  %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x3x8xf32>
+//      CHECK:  %[[V_OUTPUT_0:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
+// CHECK-SAME:     {offsets = [0, 0, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
+//      CHECK:  %[[V_OUTPUT_1:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
+// CHECK-SAME:     {offsets = [0, 1, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
+/// w == 0, kw == 0
+//      CHECK:   %[[CONTRACT_0:.+]] = vector.contract {
+// CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
+// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME:     %[[V_INPUT_0]], %[[V_FILTER_0]], %[[V_OUTPUT_0]]
+// CHECK-SAME:     : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
+/// w == 1, kw == 0
+//      CHECK:   %[[CONTRACT_1:.+]] = vector.contract {
+// CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
+// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME:     %[[V_INPUT_1]], %[[V_FILTER_0]], %[[V_OUTPUT_1]]
+// CHECK-SAME:     : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
+/// w == 0, kw == 1
+//      CHECK:   %[[CONTRACT_2:.+]] = vector.contract {
+// CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
+// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME:     %[[V_INPUT_2]], %[[V_FILTER_1]], %[[CONTRACT_0]]
+// CHECK-SAME:     : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
+/// w == 1, kw == 1
+//      CHECK:   %[[CONTRACT_3:.+]] = vector.contract {
+// CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
+// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME:     %[[V_INPUT_3]], %[[V_FILTER_1]], %[[CONTRACT_1]]
+// CHECK-SAME:     : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
+/// w == 0, kw == 0
+//      CHECK:   %[[RES_0:.+]] = vector.insert_strided_slice %[[CONTRACT_2]], %[[V_OUTPUT_R]]
+// CHECK-SAME:     {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x8xf32> into vector<4x2x8xf32>
+/// w == 1, kw == 0
+//      CHECK:   %[[RES_1:.+]] = vector.insert_strided_slice %[[CONTRACT_3]], %[[RES_0]]
+// CHECK-SAME:     {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x8xf32> into vector<4x2x8xf32>
+/// Transpose result to ncw format.
+//  CHECK:  %[[RES_2:.+]] = vector.transpose %[[RES_1]], [0, 2, 1]
+// Write the result back in one shot.
+//      CHECK:   vector.transfer_write %[[RES_2]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+// -----
+func.func @conv1d_ncw_4x8x2_memref(%input: memref<4x3x6xf32>, %filter: memref<8x3x2xf32>, %output: memref<4x8x2xf32>) {
+  linalg.conv_1d_ncw_fcw
+    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+    ins(%input, %filter : memref<4x3x6xf32>, memref<8x3x2xf32>)
+    outs(%output : memref<4x8x2xf32>)
+  return
+// CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+// CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+//      CHECK: func @conv1d_ncw_4x8x2_memref
+// CHECK-SAME: (%[[INPUT:.+]]: memref<4x3x6xf32>, %[[FILTER:.+]]: memref<8x3x2xf32>, %[[OUTPUT:.+]]: memref<4x8x2xf32>)
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+/// Read the whole data in one shot.
+//  CHECK-DAG:   %[[V_NWC_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+//  CHECK-DAG:  %[[V_NWC_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+//  CHECK-DAG:  %[[V_NWC_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
+/// Transpose result to nwc format.
+//  CHECK-DAG:   %[[V_INPUT_R:.+]] = vector.transpose %[[V_NWC_INPUT_R]], [0, 2, 1]
+//  CHECK-DAG:  %[[V_FILTER_R:.+]] = vector.transpose %[[V_NWC_FILTER_R]], [2, 1, 0]
+//  CHECK-DAG:  %[[V_OUTPUT_R:.+]] = vector.transpose %[[V_NWC_OUTPUT_R]], [0, 2, 1]
+//      CHECK:   %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME:     {offsets = [0, 0, 0], sizes = [4, 2, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x2x3xf32>
+//      CHECK:   %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME:     {offsets = [0, 2, 0], sizes = [4, 2, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x2x3xf32>
+//      CHECK:  %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x3x8xf32>
+//      CHECK:  %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x3x8xf32>
+/// w == 0, kw == 0
+//      CHECK:   %[[CONTRACT_0:.+]] = vector.contract {
+// CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
+// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME:     %[[V_INPUT_0]], %[[V_FILTER_0]], %[[V_OUTPUT_R]]
+// CHECK-SAME:     : vector<4x2x3xf32>, vector<3x8xf32> into vector<4x2x8xf32>
+/// w == 0, kw == 1
+//      CHECK:   %[[CONTRACT_1:.+]] = vector.contract {
+// CHECK-SAME:       indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
+// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME:     %[[V_INPUT_1]], %[[V_FILTER_1]], %[[CONTRACT_0]]
+// CHECK-SAME:     : vector<4x2x3xf32>, vector<3x8xf32> into vector<4x2x8xf32>
+/// Transpose result to ncw format.
+//  CHECK:  %[[RES:.+]] = vector.transpose %[[CONTRACT_1]], [0, 2, 1]
+// Write the result back in one shot.
+//      CHECK:   vector.transfer_write %[[RES]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
 // -----
 func.func @depthwise_conv1d_nwc_wc_3x5x4_memref(%input: memref<3x5x4xf32>, %filter: memref<2x4xf32>, %output: memref<3x2x4xf32>) {


More information about the Mlir-commits mailing list