[Mlir-commits] [mlir] 03c2f5d - [mlir][linalg][conv] Flatten the channel dimension when vectorizing (#71918)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Dec 6 13:35:07 PST 2023


Author: Andrzej WarzyƄski
Date: 2023-12-06T21:35:03Z
New Revision: 03c2f5d8bbcf31239a631d9343ac7f4b6b3094c1

URL: https://github.com/llvm/llvm-project/commit/03c2f5d8bbcf31239a631d9343ac7f4b6b3094c1
DIFF: https://github.com/llvm/llvm-project/commit/03c2f5d8bbcf31239a631d9343ac7f4b6b3094c1.diff

LOG: [mlir][linalg][conv] Flatten the channel dimension when vectorizing (#71918)

The current vectorization of 1D depthwise convolutions in Linalg is
_sub-optimal_ for tensors with a small number of channels, e.g.:

```mlir
linalg.depthwise_conv_1d_nwc_wc
    {dilations = dense<1> : vector<1xi64>,
    strides = dense<1> : vector<1xi64>}
    ins(%input, %filter : tensor<1x8x3xi8>, tensor<1x3xi8>)
    outs(%output : tensor<1x8x3xi8>) -> tensor<1x8x3xi8>
```

That's due to the fact that ultimately (i.e. at the LLVM level),
vectorization happens along the trailing dimension (i.e. the channel
dimension). In this case it leads to vectors with 3 elements (or worse,
a single element if there is only one channel). For comparison, a
128-bit wide vector register can hold 16 x i8.

Instead, this patch adds an option to flatten/collapse the channel
dimension into the width dimension of the input/filter/output using
the `vector.shape_cast` operation:

```mlir
    %sc_input = vector.shape_cast %input : vector<1x8x3xi8> to vector<1x24xi8>
    %sc_output = vector.shape_cast %output : vector<1x8x3xi8> to vector<1x24xi8>
    %b_filter = vector.broadcast %filter : vector<3xi8> to vector<1x8x3xi8>
    %sc_filter = vector.shape_cast %b_filter : vector<1x8x3xi8> to vector<1x24xi8>
```

This new vectorization mode is implemented in `depthwiseConv` by
inserting `vector.shape_cast` Ops before and after
`depthwiseConv1dSliceAsMulAcc` is invoked. It can be selected, e.g.,
through a transform dialect attribute:

```mlir
  transform.structured.vectorize_children_and_apply_patterns %conv {flatten_1d_depthwise_conv}
```

A forthcoming patch will implement a strategy to automatically switch
between the two implementations, depending on the shape of the input
tensors.

Co-authored by: Bradley Smith <bradley.smith at arm.com>

Added: 
    mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 002926ff965fd..de65f3176c46a 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2038,6 +2038,7 @@ def VectorizeChildrenAndApplyPatternsOp :
   let arguments = (ins TransformHandleTypeInterface:$target,
                    UnitAttr:$vectorize_padding,
                    UnitAttr:$vectorize_nd_extract,
+                   UnitAttr:$flatten_1d_depthwise_conv,
                    UnitAttr:$disable_multi_reduction_to_contract_patterns,
                    UnitAttr:$disable_transfer_permutation_map_lowering_patterns);
   let results = (outs TransformHandleTypeInterface:$transformed);
@@ -2049,7 +2050,8 @@ def VectorizeChildrenAndApplyPatternsOp :
   let builders = [
     OpBuilder<(ins "Value":$target,
                CArg<"bool", "false">:$vectorizePadding,
-               CArg<"bool", "false">:$vectorizeNDExtract)>,
+               CArg<"bool", "false">:$vectorizeNDExtract,
+               CArg<"bool", "false">:$flatten1DDepthwise)>
   ];
   let extraClassDeclaration = [{
     ::mlir::DiagnosedSilenceableFailure applyToOne(

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 6c4e16bd94f47..3f4dfe42b71fd 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -753,7 +753,8 @@ LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value /*buffer*/);
 LogicalResult vectorize(RewriterBase &rewriter, Operation *op,
                         ArrayRef<int64_t> inputVectorSizes = {},
                         ArrayRef<bool> inputScalableVecDims = {},
-                        bool vectorizeNDExtract = false);
+                        bool vectorizeNDExtract = false,
+                        bool flatten1DDepthwiseConv = false);
 
 /// Emit a suitable vector form for a Copy op with fully static shape.
 LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 18ee36efab9d8..e3713457e8412 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2946,7 +2946,7 @@ LogicalResult TileUsingForallOp::verify() {
 
 void transform::VectorizeChildrenAndApplyPatternsOp::build(
     OpBuilder &builder, OperationState &result, Value target,
-    bool vectorizePadding, bool vectorizeExtract) {
+    bool vectorizePadding, bool vectorizeExtract, bool flatten1DDepthwiseConv) {
   result.addOperands(target);
   if (vectorizePadding) {
     result.addAttribute(
@@ -2960,6 +2960,12 @@ void transform::VectorizeChildrenAndApplyPatternsOp::build(
             result.name),
         builder.getUnitAttr());
   }
+  if (flatten1DDepthwiseConv) {
+    result.addAttribute(
+        VectorizeChildrenAndApplyPatternsOp::getFlatten_1dDepthwiseConvAttrName(
+            result.name),
+        builder.getUnitAttr());
+  }
   result.addTypes(transform::AnyOpType::get(builder.getContext()));
 }
 
@@ -2968,22 +2974,29 @@ namespace {
 /// VectorizeChildrenAndApplyPatternsOp::applyToOne.
 struct VectorizationPattern : public RewritePattern {
   explicit VectorizationPattern(MLIRContext *context,
-                                bool vectorizeExtract = false)
+                                bool vectorizeExtract = false,
+                                bool flattenConv = false)
       : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
-        vectorizeNDExtract(vectorizeExtract) {}
+        vectorizeNDExtract(vectorizeExtract),
+        flatten1DDepthwiseConv(flattenConv) {}
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
     LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
     if (!linalgOp)
       return rewriter.notifyMatchFailure(op, "expected Linalg Op");
     return vectorize(rewriter, linalgOp, /*inputVectorSizes=*/{},
-                     /*scalableVecDims=*/{}, vectorizeNDExtract);
+                     /*scalableVecDims=*/{}, vectorizeNDExtract,
+                     flatten1DDepthwiseConv);
   }
 
 private:
   /// Controls whether to vectorize `tensor.extract` when the input tensor is
   /// rank >= 2.
   bool vectorizeNDExtract = false;
+  /// Controls whether to "flatten" the channel dimension when vectorising 1D
+  /// depthwise convolutions. This should lead to better vectorization for
+  /// tensors with a small number of channels.
+  bool flatten1DDepthwiseConv = false;
 };
 } // namespace
 
@@ -3000,7 +3013,8 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
 
   MLIRContext *ctx = getContext();
   RewritePatternSet patterns(ctx);
-  patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract());
+  patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract(),
+                                     getFlatten_1dDepthwiseConv());
 
   if (!getDisableTransferPermutationMapLoweringPatterns())
     vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index f9a53a8451a60..c21d007c931b9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -44,8 +44,9 @@ using namespace mlir::linalg;
 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
 
 /// Try to vectorize `convOp` as a convolution.
-static FailureOr<Operation *> vectorizeConvolution(RewriterBase &rewriter,
-                                                   LinalgOp convOp);
+static FailureOr<Operation *>
+vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
+                     bool flatten1DDepthwiseConv = false);
 
 /// Return the unique instance of OpType in `block` if it is indeed unique.
 /// Return null if none or more than 1 instances exist.
@@ -1664,7 +1665,8 @@ static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
 LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
                                       ArrayRef<int64_t> inputVectorSizes,
                                       ArrayRef<bool> inputScalableVecDims,
-                                      bool vectorizeNDExtract) {
+                                      bool vectorizeNDExtract,
+                                      bool flatten1DDepthwiseConv) {
   LDBG("Attempting to vectorize:\n" << *op << "\n");
   LDBG("Input vector sizes: ");
   LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
@@ -1696,8 +1698,8 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
             // TODO: isaConvolutionOpInterface that can also infer from generic
             // features. Will require stride/dilation attributes inference.
             if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
-              FailureOr<Operation *> convOr =
-                  vectorizeConvolution(rewriter, linalgOp);
+              FailureOr<Operation *> convOr = vectorizeConvolution(
+                  rewriter, linalgOp, flatten1DDepthwiseConv);
               if (succeeded(convOr)) {
                 llvm::append_range(results, (*convOr)->getResults());
                 return success();
@@ -2822,7 +2824,7 @@ struct Conv1DGenerator
   /// kw is always unrolled.
   /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
   /// > 1.
-  FailureOr<Operation *> depthwiseConv() {
+  FailureOr<Operation *> depthwiseConv(bool flatten) {
     if (!valid)
       return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
 
@@ -2869,6 +2871,9 @@ struct Conv1DGenerator
     //===------------------------------------------------------------------===//
     // Unroll along kw and read slices of lhs and rhs.
     SmallVector<Value> lhsVals, rhsVals, resVals;
+    auto inOutSliceSizes = SmallVector<int64_t>{nSize, wSizeStep, cSize};
+    auto inOutStrides = SmallVector<int64_t>{1, 1, 1};
+
     // Extract lhs slice of size {n, wSizeStep, c}
     //   @ [0, sw * w + dw * kw, 0].
     for (int64_t kw = 0; kw < kwSize; ++kw) {
@@ -2876,8 +2881,7 @@ struct Conv1DGenerator
         lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
             loc, lhs,
             /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
-            /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
-            /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
+            inOutSliceSizes, inOutStrides));
       }
     }
     // Extract rhs slice of size {c} @ [kw].
@@ -2889,21 +2893,39 @@ struct Conv1DGenerator
     for (int64_t w = 0; w < wSize; w += wSizeStep) {
       resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
           loc, res,
-          /*offsets=*/ArrayRef<int64_t>{0, w, 0},
-          /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
-          /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
+          /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
+          inOutStrides));
     }
 
     auto linearIndex = [&](int64_t kw, int64_t w) {
       return kw * (wSize / wSizeStep) + w;
     };
 
+    auto inOutFlattenSliceSizes =
+        SmallVector<int64_t>{nSize, wSizeStep * cSize};
+    auto lhsCastType = VectorType::get(inOutFlattenSliceSizes, lhsEltType);
+    auto resCastType = VectorType::get(inOutFlattenSliceSizes, resEltType);
     // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       for (int64_t w = 0; w < wSize; w += wSizeStep) {
-        resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc,
-                                                  lhsVals[linearIndex(kw, w)],
-                                                  rhsVals[kw], resVals[w]);
+        Value lhsVal = lhsVals[linearIndex(kw, w)];
+        Value resVal = resVals[w];
+        ShapedType filterBCastTy = cast<ShapedType>(resVal.getType());
+        if (flatten) {
+          // Flatten the input and filter vectors (collapse the channel
+          // dimension)
+          lhsVal = rewriter.create<vector::ShapeCastOp>(
+              loc, lhsCastType, lhsVals[linearIndex(kw, w)]);
+          resVal = rewriter.create<vector::ShapeCastOp>(loc, resCastType,
+                                                        resVals[w]);
+        }
+        resVals[w] = depthwiseConv1dSliceAsMulAcc(
+            rewriter, loc, lhsVal, rhsVals[kw], resVal, filterBCastTy, flatten);
+        if (flatten) {
+          // Un-flatten the output vector (restore the channel dimension)
+          resVals[w] = rewriter.create<vector::ShapeCastOp>(
+              loc, VectorType::get(inOutSliceSizes, resEltType), resVals[w]);
+        }
       }
     }
 
@@ -2936,9 +2958,13 @@ struct Conv1DGenerator
         .getOperation();
   }
 
-  /// Lower lhs{n, w, c} * rhs{c} -> res{n, w, c} to MulAcc
+  /// Lower:
+  ///   *  lhs{n, w, c} * rhs{c} -> res{n, w, c} (flatten = false)
+  ///   *  lhs{n, w * c} * rhs{c} -> res{n, w * c} (flatten = true)
+  /// to MulAcc.
   Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
-                                     Value lhs, Value rhs, Value res) {
+                                     Value lhs, Value rhs, Value res,
+                                     ShapedType bcastTy, bool flatten) {
     auto rhsTy = cast<ShapedType>(rhs.getType());
     auto resTy = cast<ShapedType>(res.getType());
 
@@ -2946,7 +2972,13 @@ struct Conv1DGenerator
     lhs = promote(rewriter, loc, lhs, resTy);
 
     rhs = rewriter.create<vector::BroadcastOp>(
-        loc, resTy.clone(rhsTy.getElementType()), rhs);
+        loc, bcastTy.clone(rhsTy.getElementType()), rhs);
+    if (flatten) {
+      // Flatten the channel dimension
+      rhs = rewriter.create<vector::ShapeCastOp>(
+          loc, resTy.clone(rhsTy.getElementType()), rhs);
+    }
+
     rhs = promote(rewriter, loc, rhs, resTy);
 
     if (!lhs || !rhs)
@@ -3049,7 +3081,7 @@ struct Conv1DGenerator
 
   /// Entry point that transposes into the common form:
   ///   {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
-  FailureOr<Operation *> generateDilatedConv() {
+  FailureOr<Operation *> generateDilatedConv(bool flatten = false) {
     AffineExpr n, w, c, kw;
     bindDims(ctx, n, w, c, kw);
     if (!iters({Par(), Par(), Par(), Red()}))
@@ -3060,7 +3092,7 @@ struct Conv1DGenerator
     if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                 /*rhsIndex*/ {kw, c},
                 /*resIndex*/ {n, w, c}}))
-      return depthwiseConv();
+      return depthwiseConv(flatten);
 
     return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
   }
@@ -3125,8 +3157,9 @@ struct Conv1DGenerator
 
 /// Helper function to vectorize a LinalgOp with convolution semantics.
 // TODO: extend the generic vectorization to support windows and drop this.
-static FailureOr<Operation *> vectorizeConvolution(RewriterBase &rewriter,
-                                                   LinalgOp op) {
+static FailureOr<Operation *>
+vectorizeConvolution(RewriterBase &rewriter, LinalgOp op,
+                     bool flatten1DDepthwiseConv) {
   // The ConvolutionOpInterface gives us guarantees of existence for
   // strides/dilations. However, we do not need to rely on those, we can simply
   // use them if present, otherwise use the default and let the generic conv.
@@ -3151,7 +3184,7 @@ static FailureOr<Operation *> vectorizeConvolution(RewriterBase &rewriter,
   res = e.generateNcwPooling();
   if (succeeded(res))
     return res;
-  return e.generateDilatedConv();
+  return e.generateDilatedConv(flatten1DDepthwiseConv);
 }
 
 struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {

diff  --git a/mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir
new file mode 100644
index 0000000000000..a242d09671825
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir
@@ -0,0 +1,309 @@
+// RUN: mlir-opt -split-input-file -transform-interpreter %s | FileCheck %s
+
+func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_tensor(%input: tensor<1x8x3xi8>,
+                                                   %filter: tensor<1x3xi8>,
+                                                   %output: tensor<1x8x3xi8>) -> (tensor<1x8x3xi8>) {
+  %res = linalg.depthwise_conv_1d_nwc_wc
+    {dilations = dense<1> : vector<1xi64>,
+    strides = dense<1> : vector<1xi64>}
+    ins(%input, %filter : tensor<1x8x3xi8>, tensor<1x3xi8>)
+    outs(%output : tensor<1x8x3xi8>) -> tensor<1x8x3xi8>
+  return %res : tensor<1x8x3xi8>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 {flatten_1d_depthwise_conv} : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+// CHECK-LABEL:   func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_tensor
+// CHECK-SAME:      %[[INPUT:.*]]: tensor<1x8x3xi8>,
+// CHECK-SAME:      %[[FILTER:.*]]: tensor<1x3xi8>,
+// CHECK-SAME:      %[[OUTPUT:.*]]: tensor<1x8x3xi8>) -> tensor<1x8x3xi8> {
+
+// CHECK-DAG:       %[[C0_IDX:.*]] = arith.constant 0 : index
+
+/// Read the whole data in one shot.
+// CHECK:           %[[V_INPUT_R:.*]] = vector.transfer_read %[[INPUT]][%[[C0_IDX]], %[[C0_IDX]], %[[C0_IDX]]]
+// CHECK:           %[[V_FILTER_R:.*]] = vector.transfer_read %[[FILTER]][%[[C0_IDX]], %[[C0_IDX]]]
+// CHECK:           %[[V_OUTPUT_R:.*]] = vector.transfer_read %[[OUTPUT]][%[[C0_IDX]], %[[C0_IDX]], %[[C0_IDX]]]
+
+// CHECK:           %[[V_FILTER_0:.*]] = vector.extract %[[V_FILTER_R]][0] : vector<3xi8> from vector<1x3xi8>
+
+/// w == 0, kw = 0
+// CHECK:           %[[SC_INPUT:.*]] = vector.shape_cast %[[V_INPUT_R]] : vector<1x8x3xi8> to vector<1x24xi8>
+// CHECK:           %[[SC_OUTPUT:.*]] = vector.shape_cast %[[V_OUTPUT_R]] : vector<1x8x3xi8> to vector<1x24xi8>
+// CHECK:           %[[B_FILTER:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<3xi8> to vector<1x8x3xi8>
+// CHECK:           %[[SC_FILTER:.*]] = vector.shape_cast %[[B_FILTER]] : vector<1x8x3xi8> to vector<1x24xi8>
+// CHECK:           %[[MULI:.*]] = arith.muli %[[SC_INPUT]], %[[SC_FILTER]] : vector<1x24xi8>
+// CHECK:           %[[ADDI:.*]] = arith.addi %[[MULI]], %[[SC_OUTPUT]] : vector<1x24xi8>
+
+// Write the result back in one shot.
+// CHECK:           %[[SC_ADDI:.*]] = vector.shape_cast %[[ADDI]] : vector<1x24xi8> to vector<1x8x3xi8>
+// CHECK:           vector.transfer_write %[[SC_ADDI]], %[[OUTPUT]][%[[C0_IDX]], %[[C0_IDX]], %[[C0_IDX]]]
+
+// -----
+
+func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dillation_2(%input: memref<3x5x4xf32>,
+                                                                %filter: memref<2x4xf32>,
+                                                                %output: memref<3x2x4xf32>) {
+  linalg.depthwise_conv_1d_nwc_wc
+    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+    ins(%input, %filter : memref<3x5x4xf32>, memref<2x4xf32>)
+    outs(%output : memref<3x2x4xf32>)
+  return
+}
+
+//       CHECK: func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dillation_2
+//  CHECK-SAME:   (%[[INPUT:[0-9a-z]+]]: memref<3x5x4xf32>, %[[FILTER:[0-9a-z]+]]: memref<2x4xf32>, %[[OUTPUT:[0-9a-z]+]]: memref<3x2x4xf32>)
+
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+
+/// Read the whole data in one shot.
+//      CHECK-DAG:   %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]]
+//      CHECK-DAG:  %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]]]
+//      CHECK-DAG:  %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+
+//      CHECK:   %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME:     {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32>
+//      CHECK:   %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME:     {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32>
+
+//      CHECK:  %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<4xf32> from vector<2x4xf32>
+//      CHECK:  %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<4xf32> from vector<2x4xf32>
+
+
+/// w == 0, kw = 0
+// CHECK:           %[[SC_V_INPUT_0:.*]] = vector.shape_cast %[[V_INPUT_0]] : vector<3x2x4xf32> to vector<3x8xf32>
+// CHECK:           %[[SC_V_OUTPUT_R:.*]] = vector.shape_cast %[[V_OUTPUT_R]] : vector<3x2x4xf32> to vector<3x8xf32>
+// CHECK:           %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xf32> to vector<3x2x4xf32>
+// CHECK:           %[[SC_B_FILTER_0:.*]] = vector.shape_cast %[[B_FILTER_0]] : vector<3x2x4xf32> to vector<3x8xf32>
+// CHECK:           %[[FMA_0:.*]] = vector.fma %[[SC_V_INPUT_0]], %[[SC_B_FILTER_0]], %[[SC_V_OUTPUT_R]] : vector<3x8xf32>
+
+/// w == 0, kw = 1
+// CHECK:           %[[SC_V_INPUT_1:.*]] = vector.shape_cast %[[V_INPUT_1]] : vector<3x2x4xf32> to vector<3x8xf32>
+// CHECK:           %[[B_V_FILTER_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xf32> to vector<3x2x4xf32>
+// CHECK:           %[[SC_B_FILTER_1:.*]] = vector.shape_cast %[[B_V_FILTER_1]] : vector<3x2x4xf32> to vector<3x8xf32>
+// CHECK:           %[[FMA_1:.*]] = vector.fma %[[SC_V_INPUT_1]], %[[SC_B_FILTER_1]], %[[FMA_0]] : vector<3x8xf32>
+
+// Write the result back in one shot.
+//      CHECK:   %[[SC_FMA_1:.*]] = vector.shape_cast %[[FMA_1]] : vector<3x8xf32> to vector<3x2x4xf32>
+//      CHECK:   vector.transfer_write %[[SC_FMA_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 {flatten_1d_depthwise_conv} : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @depthwise_conv1d_nwc_wc_3x5x4xi8_memref_dilation_2(%input: memref<3x5x4xi8>,
+                                                              %filter: memref<2x4xi8>,
+                                                              %output: memref<3x2x4xi32>) {
+  linalg.depthwise_conv_1d_nwc_wc
+    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+    ins(%input, %filter : memref<3x5x4xi8>, memref<2x4xi8>)
+    outs(%output : memref<3x2x4xi32>)
+  return
+}
+
+//       CHECK: func @depthwise_conv1d_nwc_wc_3x5x4xi8_memref_dilation_2
+//  CHECK-SAME:   (%[[INPUT:[0-9a-z]+]]: memref<3x5x4xi8>, %[[FILTER:[0-9a-z]+]]: memref<2x4xi8>, %[[OUTPUT:[0-9a-z]+]]: memref<3x2x4xi32>)
+
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+
+/// Read the whole data in one shot.
+//      CHECK-DAG:   %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]]
+//      CHECK-DAG:  %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]]]
+//      CHECK-DAG:  %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+
+//      CHECK:   %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME:     {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xi8> to vector<3x2x4xi8>
+//      CHECK:   %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME:     {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xi8> to vector<3x2x4xi8>
+
+//      CHECK:  %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<4xi8> from vector<2x4xi8>
+//      CHECK:  %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<4xi8> from vector<2x4xi8>
+
+/// w == 0, kw = 0
+//      CHECK:  %[[SC_V_INPUT_0:.*]] = vector.shape_cast %[[V_INPUT_0]] : vector<3x2x4xi8> to vector<3x8xi8>
+//      CHECK:  %[[SC_V_OUTPUT_R:.*]] = vector.shape_cast %[[V_OUTPUT_R]] : vector<3x2x4xi32> to vector<3x8xi32>
+//      CHECK:  %[[EXT_INPUT_0:.*]] = arith.extsi %[[SC_V_INPUT_0]] : vector<3x8xi8> to vector<3x8xi32>
+//      CHECK:  %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xi8> to vector<3x2x4xi8>
+//      CHECK:  %[[SC_B_FILTER_0:.*]] = vector.shape_cast %[[B_FILTER_0]] : vector<3x2x4xi8> to vector<3x8xi8>
+//      CHECK:  %[[EXT_FILTER_0:.*]] = arith.extsi %[[SC_B_FILTER_0]] : vector<3x8xi8> to vector<3x8xi32>
+//      CHECK:  %[[MUL_0:.*]] = arith.muli %[[EXT_INPUT_0]], %[[EXT_FILTER_0]] : vector<3x8xi32>
+//      CHECK:  %[[ADD_0:.*]] = arith.addi %[[MUL_0]], %[[SC_V_OUTPUT_R]] : vector<3x8xi32>
+
+/// w == 0, kw = 1
+//      CHECK:  %[[SC_V_INPUT_1:.*]] = vector.shape_cast %[[V_INPUT_1]] : vector<3x2x4xi8> to vector<3x8xi8>
+//      CHECK:  %[[EXT_INPUT_1:.*]] = arith.extsi %[[SC_V_INPUT_1]] : vector<3x8xi8> to vector<3x8xi32>
+//      CHECK:  %[[B_FILTER_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xi8> to vector<3x2x4xi8>
+//      CHECK:  %[[SC_B_FILTER_1:.*]] = vector.shape_cast %[[B_FILTER_1]] : vector<3x2x4xi8> to vector<3x8xi8>
+//      CHECK:  %[[EXT_FILTER_1:.*]] = arith.extsi %[[SC_B_FILTER_1]] : vector<3x8xi8> to vector<3x8xi32>
+//      CHECK:  %[[MUL_1:.*]] = arith.muli %[[EXT_INPUT_1]], %[[EXT_FILTER_1]] : vector<3x8xi32>
+//      CHECK:  %[[ADD_1:.*]] = arith.addi %[[MUL_1]], %[[ADD_0]] : vector<3x8xi32>
+
+// Write the result back in one shot.
+//      CHECK:   %[[SC_ADD_1:.*]] = vector.shape_cast %[[ADD_1]] : vector<3x8xi32> to vector<3x2x4xi32>
+//      CHECK:   vector.transfer_write %[[SC_ADD_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 {flatten_1d_depthwise_conv} : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @depthwise_conv1d_nwc_wc_3x9x4xi8_tensor_stride_2(%input: tensor<3x9x4xi8>,
+                                                            %filter: tensor<3x4xi8>,
+                                                            %output: tensor<3x3x4xi8>) -> tensor<3x3x4xi8> {
+  %res = linalg.depthwise_conv_1d_nwc_wc
+    {dilations = dense<1> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>}
+    ins(%input, %filter : tensor<3x9x4xi8>, tensor<3x4xi8>)
+    outs(%output : tensor<3x3x4xi8>) -> tensor<3x3x4xi8>
+  return %res : tensor<3x3x4xi8>
+}
+// CHECK-LABEL:   func.func @depthwise_conv1d_nwc_wc_3x9x4xi8_tensor_stride_2
+// CHECK-SAME:      %[[INPUT:.*]]: tensor<3x9x4xi8>,
+// CHECK-SAME:      %[[FILTER:.*]]: tensor<3x4xi8>,
+// CHECK-SAME:      %[[OUTPUT:.*]]: tensor<3x3x4xi8>) -> tensor<3x3x4xi8> {
+
+// CHECK-DAG:           %[[C0_IDX:.*]] = arith.constant 0 : index
+// CHECK-DAG:           %[[C0_I8:.*]] = arith.constant 0 : i8
+
+/// Read the whole data in one shot.
+// CHECK:           %[[V_INPUT_R:.*]] = vector.transfer_read %[[INPUT]][%[[C0_IDX]], %[[C0_IDX]], %[[C0_IDX]]], %[[C0_I8]]
+// CHECK:           %[[V_FILTER_R:.*]] = vector.transfer_read %[[FILTER]][%[[C0_IDX]], %[[C0_IDX]]], %[[C0_I8]]
+// CHECK:           %[[V_OUTPUT_R:.*]] = vector.transfer_read %[[OUTPUT]][%[[C0_IDX]], %[[C0_IDX]], %[[C0_IDX]]], %[[C0_I8]]
+
+// CHECK:           %[[V_INPUT_0:.*]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME:        {offsets = [0, 0, 0], sizes = [3, 1, 4], strides = [1, 1, 1]} : vector<3x7x4xi8> to vector<3x1x4xi8>
+// CHECK:           %[[V_INPUT_1:.*]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME:        {offsets = [0, 2, 0], sizes = [3, 1, 4], strides = [1, 1, 1]} : vector<3x7x4xi8> to vector<3x1x4xi8>
+// CHECK:           %[[V_INPUT_2:.*]] = vector.extract_strided_slice %[[V_INPUT_R]] 
+// CHECK-SAME:        {offsets = [0, 4, 0], sizes = [3, 1, 4], strides = [1, 1, 1]} : vector<3x7x4xi8> to vector<3x1x4xi8>
+// CHECK:           %[[V_INPUT_3:.*]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME:        {offsets = [0, 1, 0], sizes = [3, 1, 4], strides = [1, 1, 1]} : vector<3x7x4xi8> to vector<3x1x4xi8>
+// CHECK:           %[[V_INPUT_4:.*]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME:        {offsets = [0, 3, 0], sizes = [3, 1, 4], strides = [1, 1, 1]} : vector<3x7x4xi8> to vector<3x1x4xi8>
+// CHECK:           %[[V_INPUT_5:.*]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME:        {offsets = [0, 5, 0], sizes = [3, 1, 4], strides = [1, 1, 1]} : vector<3x7x4xi8> to vector<3x1x4xi8>
+// CHECK:           %[[V_INPUT_6:.*]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME:        {offsets = [0, 2, 0], sizes = [3, 1, 4], strides = [1, 1, 1]} : vector<3x7x4xi8> to vector<3x1x4xi8>
+// CHECK:           %[[V_INPUT_7:.*]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME:        {offsets = [0, 4, 0], sizes = [3, 1, 4], strides = [1, 1, 1]} : vector<3x7x4xi8> to vector<3x1x4xi8>
+// CHECK:           %[[V_INPUT_8:.*]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME:        {offsets = [0, 6, 0], sizes = [3, 1, 4], strides = [1, 1, 1]} : vector<3x7x4xi8> to vector<3x1x4xi8>
+
+// CHECK:           %[[V_FILTER_0:.*]] = vector.extract %[[V_FILTER_R]][0] : vector<4xi8> from vector<3x4xi8>
+// CHECK:           %[[V_FILTER_1:.*]] = vector.extract %[[V_FILTER_R]][1] : vector<4xi8> from vector<3x4xi8>
+// CHECK:           %[[V_FILTER_2:.*]] = vector.extract %[[V_FILTER_R]][2] : vector<4xi8> from vector<3x4xi8>
+
+// CHECK:           %[[V_OUTPUT_0:.*]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
+// CHECK-SAME:        {offsets = [0, 0, 0], sizes = [3, 1, 4], strides = [1, 1, 1]} : vector<3x3x4xi8> to vector<3x1x4xi8>
+// CHECK:           %[[V_OUTPUT_1:.*]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
+// CHECK-SAME:       {offsets = [0, 1, 0], sizes = [3, 1, 4], strides = [1, 1, 1]} : vector<3x3x4xi8> to vector<3x1x4xi8>
+// CHECK:           %[[V_OUTPUT_2:.*]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
+// CHECK-SAME:        {offsets = [0, 2, 0], sizes = [3, 1, 4], strides = [1, 1, 1]} : vector<3x3x4xi8> to vector<3x1x4xi8>
+
+/// w == 0, kw == 0
+// CHECK:           %[[VAL_23:.*]] = vector.shape_cast %[[V_INPUT_0]] : vector<3x1x4xi8> to vector<3x4xi8>
+// CHECK:           %[[VAL_24:.*]] = vector.shape_cast %[[V_OUTPUT_0]] : vector<3x1x4xi8> to vector<3x4xi8>
+// CHECK:           %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xi8> to vector<3x1x4xi8>
+// CHECK:           %[[VAL_26:.*]] = vector.shape_cast %[[B_FILTER_0]] : vector<3x1x4xi8> to vector<3x4xi8>
+// CHECK:           %[[VAL_27:.*]] = arith.muli %[[VAL_23]], %[[VAL_26]] : vector<3x4xi8>
+// CHECK:           %[[VAL_28:.*]] = arith.addi %[[VAL_27]], %[[VAL_24]] : vector<3x4xi8>
+
+/// w == 1, kw == 0
+// CHECK:           %[[VAL_29:.*]] = vector.shape_cast %[[V_INPUT_1]] : vector<3x1x4xi8> to vector<3x4xi8>
+// CHECK:           %[[VAL_30:.*]] = vector.shape_cast %[[V_OUTPUT_1]] : vector<3x1x4xi8> to vector<3x4xi8>
+// CHECK:           %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xi8> to vector<3x1x4xi8>
+// CHECK:           %[[VAL_32:.*]] = vector.shape_cast %[[B_FILTER_0]] : vector<3x1x4xi8> to vector<3x4xi8>
+// CHECK:           %[[VAL_33:.*]] = arith.muli %[[VAL_29]], %[[VAL_32]] : vector<3x4xi8>
+// CHECK:           %[[VAL_34:.*]] = arith.addi %[[VAL_33]], %[[VAL_30]] : vector<3x4xi8>
+
+/// w == 2, kw == 0
+// CHECK:           %[[VAL_35:.*]] = vector.shape_cast %[[V_INPUT_2]] : vector<3x1x4xi8> to vector<3x4xi8>
+// CHECK:           %[[VAL_36:.*]] = vector.shape_cast %[[V_OUTPUT_2]] : vector<3x1x4xi8> to vector<3x4xi8>
+// CHECK:           %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xi8> to vector<3x1x4xi8>
+// CHECK:           %[[VAL_38:.*]] = vector.shape_cast %[[B_FILTER_0]] : vector<3x1x4xi8> to vector<3x4xi8>
+// CHECK:           %[[VAL_39:.*]] = arith.muli %[[VAL_35]], %[[VAL_38]] : vector<3x4xi8>
+// CHECK:           %[[VAL_40:.*]] = arith.addi %[[VAL_39]], %[[VAL_36]] : vector<3x4xi8>
+
+/// w == 0, kw == 1
+// CHECK:           %[[VAL_41:.*]] = vector.shape_cast %[[V_INPUT_3]] : vector<3x1x4xi8> to vector<3x4xi8>
+// CHECK:           %[[B_FILTER_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xi8> to vector<3x1x4xi8>
+// CHECK:           %[[VAL_43:.*]] = vector.shape_cast %[[B_FILTER_1]] : vector<3x1x4xi8> to vector<3x4xi8>
+// CHECK:           %[[VAL_44:.*]] = arith.muli %[[VAL_41]], %[[VAL_43]] : vector<3x4xi8>
+// CHECK:           %[[VAL_45:.*]] = arith.addi %[[VAL_44]], %[[VAL_28]] : vector<3x4xi8>
+
+/// w == 1, kw == 1
+// CHECK:           %[[VAL_46:.*]] = vector.shape_cast %[[V_INPUT_4]] : vector<3x1x4xi8> to vector<3x4xi8>
+// CHECK:           %[[B_FILTER_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xi8> to vector<3x1x4xi8>
+// CHECK:           %[[VAL_48:.*]] = vector.shape_cast %[[B_FILTER_1]] : vector<3x1x4xi8> to vector<3x4xi8>
+// CHECK:           %[[VAL_49:.*]] = arith.muli %[[VAL_46]], %[[VAL_48]] : vector<3x4xi8>
+// CHECK:           %[[VAL_50:.*]] = arith.addi %[[VAL_49]], %[[VAL_34]] : vector<3x4xi8>
+
+/// w == 2, kw == 1
+// CHECK:           %[[VAL_51:.*]] = vector.shape_cast %[[V_INPUT_5]] : vector<3x1x4xi8> to vector<3x4xi8>
+// CHECK:           %[[B_FILTER_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xi8> to vector<3x1x4xi8>
+// CHECK:           %[[VAL_53:.*]] = vector.shape_cast %[[B_FILTER_1]] : vector<3x1x4xi8> to vector<3x4xi8>
+// CHECK:           %[[VAL_54:.*]] = arith.muli %[[VAL_51]], %[[VAL_53]] : vector<3x4xi8>
+// CHECK:           %[[VAL_55:.*]] = arith.addi %[[VAL_54]], %[[VAL_40]] : vector<3x4xi8>
+
+/// w == 0, kw == 2
+// CHECK:           %[[VAL_56:.*]] = vector.shape_cast %[[V_INPUT_6]] : vector<3x1x4xi8> to vector<3x4xi8>
+// CHECK:           %[[B_FILTER_2:.*]] = vector.broadcast %[[V_FILTER_2]] : vector<4xi8> to vector<3x1x4xi8>
+// CHECK:           %[[VAL_58:.*]] = vector.shape_cast %[[B_FILTER_2]] : vector<3x1x4xi8> to vector<3x4xi8>
+// CHECK:           %[[VAL_59:.*]] = arith.muli %[[VAL_56]], %[[VAL_58]] : vector<3x4xi8>
+// CHECK:           %[[VAL_60:.*]] = arith.addi %[[VAL_59]], %[[VAL_45]] : vector<3x4xi8>
+
+/// w == 1, kw == 2 (the finished w == 0 accumulator is first cast back to 3x1x4)
+// CHECK:           %[[VAL_61:.*]] = vector.shape_cast %[[VAL_60]] : vector<3x4xi8> to vector<3x1x4xi8>
+// CHECK:           %[[VAL_62:.*]] = vector.shape_cast %[[V_INPUT_7]] : vector<3x1x4xi8> to vector<3x4xi8>
+// CHECK:           %[[B_FILTER_2:.*]] = vector.broadcast %[[V_FILTER_2]] : vector<4xi8> to vector<3x1x4xi8>
+// CHECK:           %[[VAL_64:.*]] = vector.shape_cast %[[B_FILTER_2]] : vector<3x1x4xi8> to vector<3x4xi8>
+// CHECK:           %[[VAL_65:.*]] = arith.muli %[[VAL_62]], %[[VAL_64]] : vector<3x4xi8>
+// CHECK:           %[[VAL_66:.*]] = arith.addi %[[VAL_65]], %[[VAL_50]] : vector<3x4xi8>
+
+/// w == 2, kw == 2 (the finished w == 1 accumulator is first cast back to 3x1x4)
+// CHECK:           %[[VAL_67:.*]] = vector.shape_cast %[[VAL_66]] : vector<3x4xi8> to vector<3x1x4xi8>
+// CHECK:           %[[VAL_68:.*]] = vector.shape_cast %[[V_INPUT_8]] : vector<3x1x4xi8> to vector<3x4xi8>
+// CHECK:           %[[B_FILTER_2:.*]] = vector.broadcast %[[V_FILTER_2]] : vector<4xi8> to vector<3x1x4xi8>
+// CHECK:           %[[VAL_70:.*]] = vector.shape_cast %[[B_FILTER_2]] : vector<3x1x4xi8> to vector<3x4xi8>
+// CHECK:           %[[VAL_71:.*]] = arith.muli %[[VAL_68]], %[[VAL_70]] : vector<3x4xi8>
+// CHECK:           %[[VAL_72:.*]] = arith.addi %[[VAL_71]], %[[VAL_55]] : vector<3x4xi8>
+
+// Write the result back.
+// CHECK:           %[[VAL_73:.*]] = vector.shape_cast %[[VAL_72]] : vector<3x4xi8> to vector<3x1x4xi8>
+// CHECK:           %[[VAL_74:.*]] = vector.insert_strided_slice %[[VAL_61]], %[[V_OUTPUT_R]]
+// CHECK-SAME:        {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<3x1x4xi8> into vector<3x3x4xi8>
+// CHECK:           %[[VAL_75:.*]] = vector.insert_strided_slice %[[VAL_67]], %[[VAL_74]]
+// CHECK-SAME:        {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<3x1x4xi8> into vector<3x3x4xi8>
+// CHECK:           %[[VAL_76:.*]] = vector.insert_strided_slice %[[VAL_73]], %[[VAL_75]]
+// CHECK-SAME:        {offsets = [0, 2, 0], strides = [1, 1, 1]} : vector<3x1x4xi8> into vector<3x3x4xi8>
+// CHECK:           %[[VAL_77:.*]] = vector.transfer_write %[[VAL_76]], %[[OUTPUT]][%[[C0_IDX]], %[[C0_IDX]], %[[C0_IDX]]]
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    // Find the depthwise conv, then climb to the enclosing isolated-from-above op.
+    %conv = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %root : (!transform.any_op) -> !transform.any_op
+    %parent = transform.get_parent_op %conv {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    // Vectorize everything under %parent with the channel dim flattened into the width dim.
+    %vectorized = transform.structured.vectorize_children_and_apply_patterns %parent {flatten_1d_depthwise_conv} : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+


        


More information about the Mlir-commits mailing list