[Mlir-commits] [mlir] [mlir][linalg][conv] Flatten the channel dimension when vectorizing (PR #71918)

Andrzej Warzyński llvmlistbot at llvm.org
Thu Nov 16 02:26:28 PST 2023


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/71918

From d0b00b2c494b44fd1da7964b9334633215765c7f Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Sun, 8 Oct 2023 19:53:04 +0100
Subject: [PATCH 1/3] [mlir][linalg][conv] Flatten the channel dimension when
 vectorizing

The current vectorization of 1D depthwise convolutions in Linalg is
_sub-optimal_ for tensors with a low channel count, e.g.:

```mlir
linalg.depthwise_conv_1d_nwc_wc
    {dilations = dense<1> : vector<1xi64>,
    strides = dense<1> : vector<1xi64>}
    ins(%input, %filter : tensor<1x8x3xi8>, tensor<1x3xi8>)
    outs(%output : tensor<1x8x3xi8>) -> tensor<1x8x3xi8>
```

That's because, ultimately (i.e. at the LLVM level), vectorization
happens along the trailing dimension (i.e. the channel dimension). In
this case it leads to vectors with 3 elements (or worse, if there is
e.g. only 1 channel). For comparison, a 128-bit wide vector register
can hold 16 x i8.

Instead, this patch adds an option to flatten/collapse the channel
dimension into the width dimension of the input/output:

```mlir
    %collapsed = tensor.collapse_shape %input [[0], [1, 2]] : tensor<1x8x3xi8> into tensor<1x24xi8>
    %collapsed_0 = tensor.collapse_shape %output [[0], [1, 2]] : tensor<1x8x3xi8> into tensor<1x24xi8>
```

(Note that for this to work, the filter is broadcast rather than
re-shaped; see the test cases for details.)
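
For illustration, the flattened filter ends up looking roughly like this
(a simplified excerpt from the i8 test output below; the SSA value names
are made up):

```mlir
// Extract the (only) kw slice of the vector<1x3xi8> filter, repeat its 3
// channel values 8 times to cover the flattened w * c = 24 elements, then
// broadcast across the batch dimension.
%f = vector.extract %filter_vec[0] : vector<3xi8> from vector<1x3xi8>
%r = vector.shuffle %f, %f [0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2,
                            0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]
    : vector<3xi8>, vector<3xi8>
%b = vector.broadcast %r : vector<24xi8> to vector<1x24xi8>
```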

The new vectorization strategy is implemented in `depthwiseConvFlatten`,
which is based on `depthwiseConvGeneric` (i.e. the original
vectorization hook). The current vectorization is preserved and remains
the default; the new one can be selected through e.g. a transform
dialect attribute:

```mlir
  transform.structured.vectorize_children_and_apply_patterns %conv {flatten_1d_depthwise_conv}
```
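
In a complete transform script this is used as follows (taken verbatim
from the new tests below):

```mlir
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
    %2 = transform.structured.vectorize_children_and_apply_patterns %1 {flatten_1d_depthwise_conv} : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}
```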

A forthcoming patch will implement a strategy to automatically switch
between the two implementations, depending on the shape of the input
tensors.

Co-authored-by: Bradley Smith <bradley.smith at arm.com>
---
 .../Linalg/TransformOps/LinalgTransformOps.td |   4 +-
 .../Dialect/Linalg/Transforms/Transforms.h    |   3 +-
 .../TransformOps/LinalgTransformOps.cpp       |  21 +-
 .../Linalg/Transforms/Vectorization.cpp       | 228 +++++++++++++++++-
 .../Linalg/vectorize-convolution-flatten.mlir | 222 +++++++++++++++++
 .../Dialect/Linalg/vectorize-convolution.mlir |  12 +-
 6 files changed, 466 insertions(+), 24 deletions(-)
 create mode 100644 mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index f1c3d717f1fa951..310efe164f93950 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2034,6 +2034,7 @@ def VectorizeChildrenAndApplyPatternsOp :
   let arguments = (ins TransformHandleTypeInterface:$target,
                    UnitAttr:$vectorize_padding,
                    UnitAttr:$vectorize_nd_extract,
+                   UnitAttr:$flatten_1d_depthwise_conv,
                    UnitAttr:$disable_multi_reduction_to_contract_patterns,
                    UnitAttr:$disable_transfer_permutation_map_lowering_patterns);
   let results = (outs TransformHandleTypeInterface:$transformed);
@@ -2045,7 +2046,8 @@ def VectorizeChildrenAndApplyPatternsOp :
   let builders = [
     OpBuilder<(ins "Value":$target,
                CArg<"bool", "false">:$vectorizePadding,
-               CArg<"bool", "false">:$vectorizeNDExtract)>,
+               CArg<"bool", "false">:$vectorizeNDExtract,
+               CArg<"bool", "false">:$flatten1DDepthwise)>
   ];
   let extraClassDeclaration = [{
     ::mlir::DiagnosedSilenceableFailure applyToOne(
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 6547648f7495c31..a4aee1f45249c2b 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -753,7 +753,8 @@ LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value /*buffer*/);
 LogicalResult vectorize(RewriterBase &rewriter, Operation *op,
                         ArrayRef<int64_t> inputVectorSizes = {},
                         ArrayRef<bool> inputScalableVecDims = {},
-                        bool vectorizeNDExtract = false);
+                        bool vectorizeNDExtract = false,
+                        bool flatten1DDepthwiseConv = false);
 
 /// Emit a suitable vector form for a Copy op with fully static shape.
 LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index de4965f937162ea..35e8be7806928e1 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2937,7 +2937,7 @@ LogicalResult TileUsingForallOp::verify() {
 
 void transform::VectorizeChildrenAndApplyPatternsOp::build(
     OpBuilder &builder, OperationState &result, Value target,
-    bool vectorizePadding, bool vectorizeExtract) {
+    bool vectorizePadding, bool vectorizeExtract, bool flatten1DDepthwiseConv) {
   result.addOperands(target);
   if (vectorizePadding) {
     result.addAttribute(
@@ -2951,6 +2951,12 @@ void transform::VectorizeChildrenAndApplyPatternsOp::build(
             result.name),
         builder.getUnitAttr());
   }
+  if (flatten1DDepthwiseConv) {
+    result.addAttribute(
+        VectorizeChildrenAndApplyPatternsOp::getFlatten_1dDepthwiseConvAttrName(
+            result.name),
+        builder.getUnitAttr());
+  }
   result.addTypes(transform::AnyOpType::get(builder.getContext()));
 }
 
@@ -2959,22 +2965,26 @@ namespace {
 /// VectorizeChildrenAndApplyPatternsOp::applyToOne.
 struct VectorizationPattern : public RewritePattern {
   explicit VectorizationPattern(MLIRContext *context,
-                                bool vectorizeExtract = false)
+                                bool vectorizeExtract = false,
+                                bool flattenConv = false)
       : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
-        vectorizeNDExtract(vectorizeExtract) {}
+        vectorizeNDExtract(vectorizeExtract),
+        flatten1DDepthwiseConv(flattenConv) {}
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
     LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
     if (!linalgOp)
       return rewriter.notifyMatchFailure(op, "expected Linalg Op");
     return vectorize(rewriter, linalgOp, /*inputVectorSizes=*/{},
-                     /*scalableVecDims=*/{}, vectorizeNDExtract);
+                     /*scalableVecDims=*/{}, vectorizeNDExtract,
+                     flatten1DDepthwiseConv);
   }
 
 private:
   /// Controls whether to vectorize `tensor.extract` when the input tensor is
   /// rank >= 2.
   bool vectorizeNDExtract = false;
+  bool flatten1DDepthwiseConv = false;
 };
 } // namespace
 
@@ -2991,7 +3001,8 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
 
   MLIRContext *ctx = getContext();
   RewritePatternSet patterns(ctx);
-  patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract());
+  patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract(),
+                                     getFlatten_1dDepthwiseConv());
 
   if (!getDisableTransferPermutationMapLoweringPatterns())
     vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index b8d82159856825f..f6f74b448edf9a8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -44,8 +44,9 @@ using namespace mlir::linalg;
 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
 
 /// Try to vectorize `convOp` as a convolution.
-static FailureOr<Operation *> vectorizeConvolution(RewriterBase &rewriter,
-                                                   LinalgOp convOp);
+static FailureOr<Operation *>
+vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
+                     bool flatten1DDepthwiseConv = false);
 
 /// Return the unique instance of OpType in `block` if it is indeed unique.
 /// Return null if none or more than 1 instances exist.
@@ -1664,7 +1665,8 @@ static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
 LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
                                       ArrayRef<int64_t> inputVectorSizes,
                                       ArrayRef<bool> inputScalableVecDims,
-                                      bool vectorizeNDExtract) {
+                                      bool vectorizeNDExtract,
+                                      bool flatten1DDepthwiseConv) {
   LDBG("Attempting to vectorize:\n" << *op << "\n");
   LDBG("Input vector sizes: ");
   LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
@@ -1696,8 +1698,8 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
             // TODO: isaConvolutionOpInterface that can also infer from generic
             // features. Will require stride/dilation attributes inference.
             if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
-              FailureOr<Operation *> convOr =
-                  vectorizeConvolution(rewriter, linalgOp);
+              FailureOr<Operation *> convOr = vectorizeConvolution(
+                  rewriter, linalgOp, flatten1DDepthwiseConv);
               if (succeeded(convOr)) {
                 llvm::append_range(results, (*convOr)->getResults());
                 return success();
@@ -2822,7 +2824,7 @@ struct Conv1DGenerator
   /// kw is always unrolled.
   /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
   /// > 1.
-  FailureOr<Operation *> depthwiseConv() {
+  FailureOr<Operation *> depthwiseConvGeneric() {
     if (!valid)
       return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
 
@@ -2936,6 +2938,176 @@ struct Conv1DGenerator
         .getOperation();
   }
 
+  /// Generate a vector implementation for ("flatten channel dim"):
+  /// ```
+  ///   Op def: (     n,     w,     c,    kw)
+  ///    Iters: ({Par(), Par(), Par(), Red()})
+  ///   Layout: {{n, 1 * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
+  /// ```
+  /// c of the input/output is collapsed with w. kw is always unrolled and
+  /// broadcast to match w.
+  ///
+  /// TODO: Add support for non-unit stride/dilation
+  /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
+  /// > 1.
+  FailureOr<Operation *> depthwiseConvFlatten() {
+    if (!valid)
+      return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
+
+    int64_t nSize, iSize, wSize, cSize, kwSize;
+    // kernel{kw, c}
+    bindShapeDims(rhsShapedType, kwSize, cSize);
+    // out{n, w, c}
+    bindShapeDims(resShapedType, nSize, wSize);
+    // in{n, w, c}
+    bindShapeDims(lhsShapedType, nSize, iSize);
+
+    vector::TransferWriteOp write;
+    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+
+    if (strideW != 1)
+      return rewriter.notifyMatchFailure(
+          op, "Non-unit strides are not supported yet");
+    if (dilationW != 1)
+      return rewriter.notifyMatchFailure(
+          op, "Non-unit dilations are not supported yet");
+
+    Type lhsEltType = lhsShapedType.getElementType();
+    Type rhsEltType = rhsShapedType.getElementType();
+    Type resEltType = resShapedType.getElementType();
+    VectorType rhsType = VectorType::get({kwSize, cSize}, rhsEltType);
+    VectorType lhsType = VectorType::get(
+        {nSize,
+         // iw = (ow * sw + kw *  dw - 1) * c
+         //   (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
+         (((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1) *
+             cSize},
+        lhsEltType);
+
+    VectorType resType = VectorType::get({nSize, wSize * cSize}, resEltType);
+
+    Value res, lhs, lhsFlat, resFlat;
+    // Read rhs slice of size {kw, c} @ [0, 0].
+    Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
+                                                        ValueRange{zero, zero});
+
+    SmallVector<ReassociationIndices> reassociation = {{0}, {1, 2}};
+
+    // Flatten w and c dimensions
+    auto lhsTypeCollapsed = VectorType::get({nSize, iSize * cSize}, lhsEltType);
+    auto linalgOp = dyn_cast<LinalgOp>(op);
+    lhsFlat =
+        linalgOp.hasTensorSemantics()
+            ? (Value)rewriter.create<tensor::CollapseShapeOp>(
+                  loc,
+                  RankedTensorType::get(lhsTypeCollapsed.getShape(),
+                                        lhsEltType),
+                  lhsShaped, reassociation)
+            : (Value)rewriter.create<memref::CollapseShapeOp>(
+                  loc, MemRefType::get(lhsTypeCollapsed.getShape(), lhsEltType),
+                  lhsShaped, reassociation);
+    resFlat =
+        linalgOp.hasTensorSemantics()
+            ? (Value)rewriter.create<tensor::CollapseShapeOp>(
+                  loc, RankedTensorType::get(resType.getShape(), resEltType),
+                  resShaped, reassociation)
+            : (Value)rewriter.create<memref::CollapseShapeOp>(
+                  loc, MemRefType::get(resType.getShape(), resEltType),
+                  resShaped, reassociation);
+
+    // Read lhs slice of size {n, (w * wSize + kw * dilationW) * c} @ [0,
+    // 0].
+    lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsFlat,
+                                                  ValueRange{zero, zero});
+    // Read res slice of size {n, w * c} @ [0, 0].
+    res = rewriter.create<vector::TransferReadOp>(loc, resType, resFlat,
+                                                  ValueRange{zero, zero});
+
+    //===------------------------------------------------------------------===//
+    // Begin vector-only rewrite part
+    //===------------------------------------------------------------------===//
+    // Unroll along kw and read slices of lhs and rhs.
+    SmallVector<Value> lhsVals, rhsVals, resVals;
+    // Extract lhs slice of size {n, wSizeStep * c}
+    //   @ [0, (sw * w + dw * kw) * cSize].
+    for (int64_t kw = 0; kw < kwSize; ++kw) {
+      for (int64_t w = 0; w < wSize; w += wSize) {
+        lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
+            loc, lhs,
+            /*offsets=*/
+            ArrayRef<int64_t>{0, (w * wSize + kw * dilationW) * cSize},
+            /*sizes=*/ArrayRef<int64_t>{nSize, wSize * cSize},
+            /*strides=*/ArrayRef<int64_t>{1, 1}));
+      }
+    }
+    // Extract rhs slice of size {c} @ [kw].
+    for (int64_t kw = 0; kw < kwSize; ++kw) {
+      rhsVals.push_back(rewriter.create<vector::ExtractOp>(
+          loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
+    }
+
+    // Extract res slice
+    // Flattened case:  {n, wSizeStep * c} @ [0, w].
+    for (int64_t w = 0; w < wSize; w += wSize) {
+      resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, res,
+          /*offsets=*/ArrayRef<int64_t>{0, w * cSize},
+          /*sizes=*/ArrayRef<int64_t>{nSize, wSize * cSize},
+          /*strides=*/ArrayRef<int64_t>{1, 1}));
+    }
+
+    auto linearIndex = [&](int64_t kw, int64_t w) {
+      return kw * (wSize / wSize) + w;
+    };
+
+    // Compute contraction:
+    //    O{n, w * c} += I{n, (sw * w + dw * kw) * c} * F{c}
+    for (int64_t kw = 0; kw < kwSize; ++kw) {
+      for (int64_t w = 0; w < wSize; w += wSize) {
+        resVals[w] = depthwiseConv1dFlatSliceAsMulAcc(
+            rewriter, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw],
+            resVals[w]);
+      }
+    }
+
+    // It's possible we failed to create the FMA.
+    if (!llvm::all_of(resVals, [](Value v) { return v; })) {
+      // Manually revert (in reverse order) to avoid leaving a bad IR state.
+      for (auto &collection :
+           {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
+        for (Value v : collection)
+          rewriter.eraseOp(v.getDefiningOp());
+      return rewriter.notifyMatchFailure(op, "failed to create FMA");
+    }
+
+    // Write back res slice. This does not depend on kw.
+    // Flattened case: {n, wSizeStep * c} @ [0, w].
+    for (int64_t w = 0; w < wSize; w += wSize) {
+      res = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, resVals[w], res,
+          /*offsets=*/ArrayRef<int64_t>{0, w * cSize},
+          /*strides=*/ArrayRef<int64_t>{1, 1});
+    }
+    //===------------------------------------------------------------------===//
+    // End vector-only rewrite part
+    //===------------------------------------------------------------------===//
+    // Write back res slice of size {n, w * c} @ [0, 0].
+    mlir::vector::TransferWriteOp tWrite =
+        rewriter.create<vector::TransferWriteOp>(loc, res, resFlat,
+                                                 ValueRange{zero, zero});
+
+    // A tensor has to be re-shaped back to its original shape ...
+    if (linalgOp.hasTensorSemantics())
+      // Re-expand shape
+      return rewriter
+          .create<tensor::ExpandShapeOp>(loc, resShapedType, tWrite.getResult(),
+                                         reassociation)
+          .getOperation();
+    /// ... memrefs don't require reshaping (re-shape is just a different view
+    /// into the same memref)
+    return tWrite.getOperation();
+  }
+
   /// Lower lhs{n, w, c} * rhs{c} -> res{n, w, c} to MulAcc
   Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
                                      Value lhs, Value rhs, Value res) {
@@ -2959,6 +3131,39 @@ struct Conv1DGenerator
     return rewriter.create<arith::AddIOp>(loc, mul, res);
   }
 
+  /// Lower lhs{n, w * c} * rhs{c} -> res{n, w * c} to MulAcc
+  Value depthwiseConv1dFlatSliceAsMulAcc(RewriterBase &rewriter, Location loc,
+                                         Value lhs, Value rhs, Value res) {
+    auto rhsTy = rhs.getType().cast<ShapedType>();
+    auto resTy = res.getType().cast<ShapedType>();
+
+    lhs = promote(rewriter, loc, lhs, resTy);
+
+    auto rhsSize = rhs.getType().cast<VectorType>().getShape()[0];
+    auto resSize = res.getType().cast<VectorType>().getShape()[1];
+
+    SmallVector<int64_t, 16> indices;
+    for (int i = 0; i < resSize / rhsSize; ++i) {
+      for (int j = 0; j < rhsSize; ++j)
+        indices.push_back(j);
+    }
+
+    rhs = rewriter.create<vector::ShuffleOp>(loc, rhs, rhs, indices);
+
+    rhs = rewriter.create<vector::BroadcastOp>(
+        loc, resTy.clone(rhsTy.getElementType()), rhs);
+    rhs = promote(rewriter, loc, rhs, resTy);
+
+    if (!lhs || !rhs)
+      return nullptr;
+
+    if (resTy.getElementType().isa<FloatType>())
+      return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);
+
+    auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
+    return rewriter.create<arith::AddIOp>(loc, mul, res);
+  }
+
   /// Entry point for non-channeled convolution:
   ///   {{w + kw}, {kw}, {w}}
   FailureOr<Operation *> generateNonChanneledConv() {
@@ -3049,7 +3254,7 @@ struct Conv1DGenerator
 
   /// Entry point that transposes into the common form:
   ///   {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
-  FailureOr<Operation *> generateDilatedConv() {
+  FailureOr<Operation *> generateDilatedConv(bool flatten = false) {
     AffineExpr n, w, c, kw;
     bindDims(ctx, n, w, c, kw);
     if (!iters({Par(), Par(), Par(), Red()}))
@@ -3060,7 +3265,7 @@ struct Conv1DGenerator
     if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                 /*rhsIndex*/ {kw, c},
                 /*resIndex*/ {n, w, c}}))
-      return depthwiseConv();
+      return flatten ? depthwiseConvFlatten() : depthwiseConvGeneric();
 
     return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
   }
@@ -3125,8 +3330,9 @@ struct Conv1DGenerator
 
 /// Helper function to vectorize a LinalgOp with convolution semantics.
 // TODO: extend the generic vectorization to support windows and drop this.
-static FailureOr<Operation *> vectorizeConvolution(RewriterBase &rewriter,
-                                                   LinalgOp op) {
+static FailureOr<Operation *>
+vectorizeConvolution(RewriterBase &rewriter, LinalgOp op,
+                     bool flatten1DDepthwiseConv) {
   // The ConvolutionOpInterface gives us guarantees of existence for
   // strides/dilations. However, we do not need to rely on those, we can simply
   // use them if present, otherwise use the default and let the generic conv.
@@ -3151,7 +3357,7 @@ static FailureOr<Operation *> vectorizeConvolution(RewriterBase &rewriter,
   res = e.generateNcwPooling();
   if (succeeded(res))
     return res;
-  return e.generateDilatedConv();
+  return e.generateDilatedConv(flatten1DDepthwiseConv);
 }
 
 struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir
new file mode 100644
index 000000000000000..6b0f920bfa42e7f
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir
@@ -0,0 +1,222 @@
+// RUN: mlir-opt -split-input-file -transform-interpreter %s | FileCheck %s
+
+func.func @flatten_tensor(%input: tensor<1x8x3xi8>, %filter: tensor<1x3xi8>, %output: tensor<1x8x3xi8>) -> (tensor<1x8x3xi8>) {
+  %res = linalg.depthwise_conv_1d_nwc_wc
+    {dilations = dense<1> : vector<1xi64>,
+    strides = dense<1> : vector<1xi64>}
+    ins(%input, %filter : tensor<1x8x3xi8>, tensor<1x3xi8>)
+    outs(%output : tensor<1x8x3xi8>) -> tensor<1x8x3xi8>
+  return %res : tensor<1x8x3xi8>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 {flatten_1d_depthwise_conv} : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL:   func.func @flatten_tensor(
+// CHECK-SAME:                              %[[VAL_0:.*]]: tensor<1x8x3xi8>,
+// CHECK-SAME:                              %[[VAL_1:.*]]: tensor<1x3xi8>,
+// CHECK-SAME:                              %[[VAL_2:.*]]: tensor<1x8x3xi8>) -> tensor<1x8x3xi8> {
+// CHECK:           %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_4:.*]] = arith.constant 0 : i8
+// CHECK:           %[[VAL_5:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : tensor<1x3xi8>, vector<1x3xi8>
+// CHECK:           %[[VAL_6:.*]] = tensor.collapse_shape %[[VAL_0]] {{\[\[}}0], [1, 2]] : tensor<1x8x3xi8> into tensor<1x24xi8>
+// CHECK:           %[[VAL_7:.*]] = tensor.collapse_shape %[[VAL_2]] {{\[\[}}0], [1, 2]] : tensor<1x8x3xi8> into tensor<1x24xi8>
+// CHECK:           %[[VAL_8:.*]] = vector.transfer_read %[[VAL_6]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : tensor<1x24xi8>, vector<1x24xi8>
+// CHECK:           %[[VAL_9:.*]] = vector.transfer_read %[[VAL_7]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : tensor<1x24xi8>, vector<1x24xi8>
+// CHECK:           %[[VAL_10:.*]] = vector.extract %[[VAL_5]][0] : vector<3xi8> from vector<1x3xi8>
+// CHECK:           %[[VAL_11:.*]] = vector.shuffle %[[VAL_10]], %[[VAL_10]] [0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2] : vector<3xi8>, vector<3xi8>
+// CHECK:           %[[VAL_12:.*]] = vector.broadcast %[[VAL_11]] : vector<24xi8> to vector<1x24xi8>
+// CHECK:           %[[VAL_13:.*]] = arith.muli %[[VAL_8]], %[[VAL_12]] : vector<1x24xi8>
+// CHECK:           %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_9]] : vector<1x24xi8>
+// CHECK:           %[[VAL_15:.*]] = vector.transfer_write %[[VAL_14]], %[[VAL_7]]{{\[}}%[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true]} : vector<1x24xi8>, tensor<1x24xi8>
+// CHECK:           %[[VAL_16:.*]] = tensor.expand_shape %[[VAL_15]] {{\[\[}}0], [1, 2]] : tensor<1x24xi8> into tensor<1x8x3xi8>
+// CHECK:           return %[[VAL_16]] : tensor<1x8x3xi8>
+// CHECK:         }
+
+// -----
+
+func.func @flatten_memref(%input: memref<1x8x3xi8>, %filter: memref<1x3xi8>, %output: memref<1x8x3xi8>) {
+  linalg.depthwise_conv_1d_nwc_wc
+    {dilations = dense<1> : vector<1xi64>,
+    strides = dense<1> : vector<1xi64>}
+    ins(%input, %filter : memref<1x8x3xi8>, memref<1x3xi8>)
+    outs(%output : memref<1x8x3xi8>)
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 {flatten_1d_depthwise_conv} : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL:   func.func @flatten_memref(
+// CHECK-SAME:                              %[[VAL_0:.*]]: memref<1x8x3xi8>,
+// CHECK-SAME:                              %[[VAL_1:.*]]: memref<1x3xi8>,
+// CHECK-SAME:                              %[[VAL_2:.*]]: memref<1x8x3xi8>) {
+// CHECK:           %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_4:.*]] = arith.constant 0 : i8
+// CHECK:           %[[VAL_5:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<1x3xi8>, vector<1x3xi8>
+// CHECK:           %[[VAL_6:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0], [1, 2]] : memref<1x8x3xi8> into memref<1x24xi8>
+// CHECK:           %[[VAL_7:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0], [1, 2]] : memref<1x8x3xi8> into memref<1x24xi8>
+// CHECK:           %[[VAL_8:.*]] = vector.transfer_read %[[VAL_6]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<1x24xi8>, vector<1x24xi8>
+// CHECK:           %[[VAL_9:.*]] = vector.transfer_read %[[VAL_7]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<1x24xi8>, vector<1x24xi8>
+// CHECK:           %[[VAL_10:.*]] = vector.extract %[[VAL_5]][0] : vector<3xi8> from vector<1x3xi8>
+// CHECK:           %[[VAL_11:.*]] = vector.shuffle %[[VAL_10]], %[[VAL_10]] [0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2] : vector<3xi8>, vector<3xi8>
+// CHECK:           %[[VAL_12:.*]] = vector.broadcast %[[VAL_11]] : vector<24xi8> to vector<1x24xi8>
+// CHECK:           %[[VAL_13:.*]] = arith.muli %[[VAL_8]], %[[VAL_12]] : vector<1x24xi8>
+// CHECK:           %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_9]] : vector<1x24xi8>
+// CHECK:           vector.transfer_write %[[VAL_14]], %[[VAL_7]]{{\[}}%[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true]} : vector<1x24xi8>, memref<1x24xi8>
+// CHECK:           return
+// CHECK:         }
+
+// -----
+
+func.func @flatten_memref_wider_filter(%input: memref<1x8x3xi8>, %filter: memref<2x3xi8>, %output: memref<1x7x3xi8>) {
+  linalg.depthwise_conv_1d_nwc_wc
+    {dilations = dense<1> : vector<1xi64>,
+    strides = dense<1> : vector<1xi64>}
+    ins(%input, %filter : memref<1x8x3xi8>, memref<2x3xi8>)
+    outs(%output : memref<1x7x3xi8>)
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 {flatten_1d_depthwise_conv} : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL:   func.func @flatten_memref_wider_filter(
+// CHECK-SAME:                                           %[[VAL_0:.*]]: memref<1x8x3xi8>,
+// CHECK-SAME:                                           %[[VAL_1:.*]]: memref<2x3xi8>,
+// CHECK-SAME:                                           %[[VAL_2:.*]]: memref<1x7x3xi8>) {
+// CHECK:           %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_4:.*]] = arith.constant 0 : i8
+// CHECK:           %[[VAL_5:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<2x3xi8>, vector<2x3xi8>
+// CHECK:           %[[VAL_6:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0], [1, 2]] : memref<1x8x3xi8> into memref<1x24xi8>
+// CHECK:           %[[VAL_7:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0], [1, 2]] : memref<1x7x3xi8> into memref<1x21xi8>
+// CHECK:           %[[VAL_8:.*]] = vector.transfer_read %[[VAL_6]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<1x24xi8>, vector<1x24xi8>
+// CHECK:           %[[VAL_9:.*]] = vector.transfer_read %[[VAL_7]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<1x21xi8>, vector<1x21xi8>
+// CHECK:           %[[VAL_10:.*]] = vector.extract_strided_slice %[[VAL_8]] {offsets = [0, 0], sizes = [1, 21], strides = [1, 1]} : vector<1x24xi8> to vector<1x21xi8>
+// CHECK:           %[[VAL_11:.*]] = vector.extract_strided_slice %[[VAL_8]] {offsets = [0, 3], sizes = [1, 21], strides = [1, 1]} : vector<1x24xi8> to vector<1x21xi8>
+// CHECK:           %[[VAL_12:.*]] = vector.extract %[[VAL_5]][0] : vector<3xi8> from vector<2x3xi8>
+// CHECK:           %[[VAL_13:.*]] = vector.extract %[[VAL_5]][1] : vector<3xi8> from vector<2x3xi8>
+// CHECK:           %[[VAL_14:.*]] = vector.shuffle %[[VAL_12]], %[[VAL_12]] [0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2] : vector<3xi8>, vector<3xi8>
+// CHECK:           %[[VAL_15:.*]] = vector.broadcast %[[VAL_14]] : vector<21xi8> to vector<1x21xi8>
+// CHECK:           %[[VAL_16:.*]] = arith.muli %[[VAL_10]], %[[VAL_15]] : vector<1x21xi8>
+// CHECK:           %[[VAL_17:.*]] = arith.addi %[[VAL_16]], %[[VAL_9]] : vector<1x21xi8>
+// CHECK:           %[[VAL_18:.*]] = vector.shuffle %[[VAL_13]], %[[VAL_13]] [0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2] : vector<3xi8>, vector<3xi8>
+// CHECK:           %[[VAL_19:.*]] = vector.broadcast %[[VAL_18]] : vector<21xi8> to vector<1x21xi8>
+// CHECK:           %[[VAL_20:.*]] = arith.muli %[[VAL_11]], %[[VAL_19]] : vector<1x21xi8>
+// CHECK:           %[[VAL_21:.*]] = arith.addi %[[VAL_20]], %[[VAL_17]] : vector<1x21xi8>
+// CHECK:           vector.transfer_write %[[VAL_21]], %[[VAL_7]]{{\[}}%[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true]} : vector<1x21xi8>, memref<1x21xi8>
+// CHECK:           return
+// CHECK:         }
+
+// -----
+
+func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref(%input: memref<3x5x4xf32>, %filter: memref<2x4xf32>, %output: memref<3x2x4xf32>) {
+  linalg.depthwise_conv_1d_nwc_wc
+    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+    ins(%input, %filter : memref<3x5x4xf32>, memref<2x4xf32>)
+    outs(%output : memref<3x2x4xf32>)
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 {flatten_1d_depthwise_conv} : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL:   func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref(
+// CHECK-SAME:                                                        %[[VAL_0:.*]]: memref<3x5x4xf32>,
+// CHECK-SAME:                                                        %[[VAL_1:.*]]: memref<2x4xf32>,
+// CHECK-SAME:                                                        %[[VAL_2:.*]]: memref<3x2x4xf32>) {
+// CHECK:           %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_4:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_5:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<2x4xf32>, vector<2x4xf32>
+// CHECK:           %[[VAL_6:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0], [1, 2]] : memref<3x5x4xf32> into memref<3x20xf32>
+// CHECK:           %[[VAL_7:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0], [1, 2]] : memref<3x2x4xf32> into memref<3x8xf32>
+// CHECK:           %[[VAL_8:.*]] = vector.transfer_read %[[VAL_6]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<3x20xf32>, vector<3x16xf32>
+// CHECK:           %[[VAL_9:.*]] = vector.transfer_read %[[VAL_7]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<3x8xf32>, vector<3x8xf32>
+// CHECK:           %[[VAL_10:.*]] = vector.extract_strided_slice %[[VAL_8]] {offsets = [0, 0], sizes = [3, 8], strides = [1, 1]} : vector<3x16xf32> to vector<3x8xf32>
+// CHECK:           %[[VAL_11:.*]] = vector.extract_strided_slice %[[VAL_8]] {offsets = [0, 8], sizes = [3, 8], strides = [1, 1]} : vector<3x16xf32> to vector<3x8xf32>
+// CHECK:           %[[VAL_12:.*]] = vector.extract %[[VAL_5]][0] : vector<4xf32> from vector<2x4xf32>
+// CHECK:           %[[VAL_13:.*]] = vector.extract %[[VAL_5]][1] : vector<4xf32> from vector<2x4xf32>
+// CHECK:           %[[VAL_14:.*]] = vector.shuffle %[[VAL_12]], %[[VAL_12]] [0, 1, 2, 3, 0, 1, 2, 3] : vector<4xf32>, vector<4xf32>
+// CHECK:           %[[VAL_15:.*]] = vector.broadcast %[[VAL_14]] : vector<8xf32> to vector<3x8xf32>
+// CHECK:           %[[VAL_16:.*]] = vector.fma %[[VAL_10]], %[[VAL_15]], %[[VAL_9]] : vector<3x8xf32>
+// CHECK:           %[[VAL_17:.*]] = vector.shuffle %[[VAL_13]], %[[VAL_13]] [0, 1, 2, 3, 0, 1, 2, 3] : vector<4xf32>, vector<4xf32>
+// CHECK:           %[[VAL_18:.*]] = vector.broadcast %[[VAL_17]] : vector<8xf32> to vector<3x8xf32>
+// CHECK:           %[[VAL_19:.*]] = vector.fma %[[VAL_11]], %[[VAL_18]], %[[VAL_16]] : vector<3x8xf32>
+// CHECK:           vector.transfer_write %[[VAL_19]], %[[VAL_7]]{{\[}}%[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true]} : vector<3x8xf32>, memref<3x8xf32>
+// CHECK:           return
+// CHECK:         }
+
+// -----
+
+func.func @depthwise_conv1d_nwc_wc_3x5x4xi8_memref(%input: memref<3x5x4xi8>, %filter: memref<2x4xi8>, %output: memref<3x2x4xi32>) {
+  linalg.depthwise_conv_1d_nwc_wc
+    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+    ins(%input, %filter : memref<3x5x4xi8>, memref<2x4xi8>)
+    outs(%output : memref<3x2x4xi32>)
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL:   func.func @depthwise_conv1d_nwc_wc_3x5x4xi8_memref(
+// CHECK-SAME:                                                       %[[VAL_0:.*]]: memref<3x5x4xi8>,
+// CHECK-SAME:                                                       %[[VAL_1:.*]]: memref<2x4xi8>,
+// CHECK-SAME:                                                       %[[VAL_2:.*]]: memref<3x2x4xi32>) {
+// CHECK:           %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_4:.*]] = arith.constant 0 : i8
+// CHECK:           %[[VAL_5:.*]] = arith.constant 0 : i32
+// CHECK:           %[[VAL_6:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<2x4xi8>, vector<2x4xi8>
+// CHECK:           %[[VAL_7:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0], [1, 2]] : memref<3x5x4xi8> into memref<3x20xi8>
+// CHECK:           %[[VAL_8:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0], [1, 2]] : memref<3x2x4xi32> into memref<3x8xi32>
+// CHECK:           %[[VAL_9:.*]] = vector.transfer_read %[[VAL_7]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<3x20xi8>, vector<3x16xi8>
+// CHECK:           %[[VAL_10:.*]] = vector.transfer_read %[[VAL_8]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_5]] {in_bounds = [true, true]} : memref<3x8xi32>, vector<3x8xi32>
+// CHECK:           %[[VAL_11:.*]] = vector.extract_strided_slice %[[VAL_9]] {offsets = [0, 0], sizes = [3, 8], strides = [1, 1]} : vector<3x16xi8> to vector<3x8xi8>
+// CHECK:           %[[VAL_12:.*]] = vector.extract_strided_slice %[[VAL_9]] {offsets = [0, 8], sizes = [3, 8], strides = [1, 1]} : vector<3x16xi8> to vector<3x8xi8>
+// CHECK:           %[[VAL_13:.*]] = vector.extract %[[VAL_6]][0] : vector<4xi8> from vector<2x4xi8>
+// CHECK:           %[[VAL_14:.*]] = vector.extract %[[VAL_6]][1] : vector<4xi8> from vector<2x4xi8>
+// CHECK:           %[[VAL_15:.*]] = arith.extsi %[[VAL_11]] : vector<3x8xi8> to vector<3x8xi32>
+// CHECK:           %[[VAL_16:.*]] = vector.shuffle %[[VAL_13]], %[[VAL_13]] [0, 1, 2, 3, 0, 1, 2, 3] : vector<4xi8>, vector<4xi8>
+// CHECK:           %[[VAL_17:.*]] = vector.broadcast %[[VAL_16]] : vector<8xi8> to vector<3x8xi8>
+// CHECK:           %[[VAL_18:.*]] = arith.extsi %[[VAL_17]] : vector<3x8xi8> to vector<3x8xi32>
+// CHECK:           %[[VAL_19:.*]] = arith.muli %[[VAL_15]], %[[VAL_18]] : vector<3x8xi32>
+// CHECK:           %[[VAL_20:.*]] = arith.addi %[[VAL_19]], %[[VAL_10]] : vector<3x8xi32>
+// CHECK:           %[[VAL_21:.*]] = arith.extsi %[[VAL_12]] : vector<3x8xi8> to vector<3x8xi32>
+// CHECK:           %[[VAL_22:.*]] = vector.shuffle %[[VAL_14]], %[[VAL_14]] [0, 1, 2, 3, 0, 1, 2, 3] : vector<4xi8>, vector<4xi8>
+// CHECK:           %[[VAL_23:.*]] = vector.broadcast %[[VAL_22]] : vector<8xi8> to vector<3x8xi8>
+// CHECK:           %[[VAL_24:.*]] = arith.extsi %[[VAL_23]] : vector<3x8xi8> to vector<3x8xi32>
+// CHECK:           %[[VAL_25:.*]] = arith.muli %[[VAL_21]], %[[VAL_24]] : vector<3x8xi32>
+// CHECK:           %[[VAL_26:.*]] = arith.addi %[[VAL_25]], %[[VAL_20]] : vector<3x8xi32>
+// CHECK:           vector.transfer_write %[[VAL_26]], %[[VAL_8]]{{\[}}%[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true]} : vector<3x8xi32>, memref<3x8xi32>
+// CHECK:           return
+// CHECK:         }
diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
index 93e36a69567bd59..59b242d530442b3 100644
--- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -535,9 +535,9 @@ func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref(%input: memref<3x5x4xf32>, %
 //   CHECK-DAG:   %[[F0:.+]] = arith.constant 0.000000e+00 : f32
 
 /// Read the whole data in one shot.
-//      CHECK:   %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]]
-//      CHECK:  %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]]]
-//      CHECK:  %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+//      CHECK-DAG:   %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]]
+//      CHECK-DAG:  %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]]]
+//      CHECK-DAG:  %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
 
 //      CHECK:   %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
 // CHECK-SAME:     {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32>
@@ -575,9 +575,9 @@ func.func @depthwise_conv1d_nwc_wc_3x5x4xi8_memref(%input: memref<3x5x4xi8>, %fi
 //   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
 
 /// Read the whole data in one shot.
-//      CHECK:   %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]]
-//      CHECK:  %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]]]
-//      CHECK:  %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+//      CHECK-DAG:   %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]]
+//      CHECK-DAG:  %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]]]
+//      CHECK-DAG:  %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
 
 //      CHECK:   %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
 // CHECK-SAME:     {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xi8> to vector<3x2x4xi8>

From e0f7a0b058d0a1cc35f15ae75c1d317f7710c50e Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 15 Nov 2023 18:38:07 +0000
Subject: [PATCH 2/3] Change the implementation

Following on from Nicolas' observation, this commit refactors the
implementation to simply replace:
```
resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc,
                                          lhsVals[linearIndex(kw, w)],
                                          rhsVals[kw], resVals[w]);
```

with `shape_cast` + `depthwiseConv1dSliceAsMulAcc` + `shape_cast`.
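
Concretely, for the tensor<1x8x3xi8> test the mul/acc is now computed on
shape-cast values (a simplified excerpt from the updated test; the SSA
value names are illustrative):

```mlir
// Flatten the (already read) input and output vectors, broadcast the
// filter slice to the unflattened shape, then flatten it too.
%lhs_f = vector.shape_cast %lhs : vector<1x8x3xi8> to vector<1x24xi8>
%res_f = vector.shape_cast %res : vector<1x8x3xi8> to vector<1x24xi8>
%rhs_b = vector.broadcast %rhs : vector<3xi8> to vector<1x8x3xi8>
%rhs_f = vector.shape_cast %rhs_b : vector<1x8x3xi8> to vector<1x24xi8>
%mul = arith.muli %lhs_f, %rhs_f : vector<1x24xi8>
%acc = arith.addi %mul, %res_f : vector<1x24xi8>
// ... and the result is shape_cast back to vector<1x8x3xi8> before the
// transfer_write.
```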
---
 .../Linalg/Transforms/Vectorization.cpp       | 252 +++---------------
 .../Linalg/vectorize-convolution-flatten.mlir | 162 +++++------
 2 files changed, 118 insertions(+), 296 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index f6f74b448edf9a8..acbf29a9bbd616d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2824,7 +2824,7 @@ struct Conv1DGenerator
   /// kw is always unrolled.
   /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
   /// > 1.
-  FailureOr<Operation *> depthwiseConvGeneric() {
+  FailureOr<Operation *> depthwiseConvGeneric(bool flatten) {
     if (!valid)
       return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
 
@@ -2871,6 +2871,9 @@ struct Conv1DGenerator
     //===------------------------------------------------------------------===//
     // Unroll along kw and read slices of lhs and rhs.
     SmallVector<Value> lhsVals, rhsVals, resVals;
+    auto inOutSliceSizes = SmallVector<int64_t>{nSize, wSizeStep, cSize};
+    auto inOutStrides = SmallVector<int64_t>{1, 1, 1};
+
     // Extract lhs slice of size {n, wSizeStep, c}
     //   @ [0, sw * w + dw * kw, 0].
     for (int64_t kw = 0; kw < kwSize; ++kw) {
@@ -2878,8 +2881,7 @@ struct Conv1DGenerator
         lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
             loc, lhs,
             /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
-            /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
-            /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
+            inOutSliceSizes, inOutStrides));
       }
     }
     // Extract rhs slice of size {c} @ [kw].
@@ -2891,21 +2893,35 @@ struct Conv1DGenerator
     for (int64_t w = 0; w < wSize; w += wSizeStep) {
       resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
           loc, res,
-          /*offsets=*/ArrayRef<int64_t>{0, w, 0},
-          /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
-          /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
+          /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
+          inOutStrides));
     }
 
     auto linearIndex = [&](int64_t kw, int64_t w) {
       return kw * (wSize / wSizeStep) + w;
     };
 
+    auto inOutFlattenSliceSizes =
+        SmallVector<int64_t>{nSize, wSizeStep * cSize};
+    auto lhsCastType = VectorType::get(inOutFlattenSliceSizes, lhsEltType);
+    auto resCastType = VectorType::get(inOutFlattenSliceSizes, resEltType);
     // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       for (int64_t w = 0; w < wSize; w += wSizeStep) {
-        resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc,
-                                                  lhsVals[linearIndex(kw, w)],
-                                                  rhsVals[kw], resVals[w]);
+        Value lhsVal = lhsVals[linearIndex(kw, w)];
+        Value resVal = resVals[w];
+        ShapedType filterBCastTy = cast<ShapedType>(resVal.getType());
+        if (flatten) {
+          lhsVal = rewriter.create<vector::ShapeCastOp>(
+              loc, lhsCastType, lhsVals[linearIndex(kw, w)]);
+          resVal = rewriter.create<vector::ShapeCastOp>(loc, resCastType,
+                                                        resVals[w]);
+        }
+        resVals[w] = depthwiseConv1dSliceAsMulAcc(
+            rewriter, loc, lhsVal, rhsVals[kw], resVal, filterBCastTy, flatten);
+        if (flatten)
+          resVals[w] = rewriter.create<vector::ShapeCastOp>(
+              loc, VectorType::get(inOutSliceSizes, resEltType), resVals[w]);
       }
     }
 
@@ -2938,179 +2954,13 @@ struct Conv1DGenerator
         .getOperation();
   }
 
-  /// Generate a vector implementation for ("flatten channel dim"):
-  /// ```
-  ///   Op def: (     n,     w,     c,    kw)
-  ///    Iters: ({Par(), Par(), Par(), Red()})
-  ///   Layout: {{n, 1 * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
-  /// ```
-  /// c of the input/output is collapsed with w. kw is always unrolled and
-  /// broadcast to match w.
-  ///
-  /// TODO: Add support for non-unit stride/dilation
-  /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
-  /// > 1.
-  FailureOr<Operation *> depthwiseConvFlatten() {
-    if (!valid)
-      return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
-
-    int64_t nSize, iSize, wSize, cSize, kwSize;
-    // kernel{kw, c}
-    bindShapeDims(rhsShapedType, kwSize, cSize);
-    // out{n, w, c}
-    bindShapeDims(resShapedType, nSize, wSize);
-    // in{n, w, c}
-    bindShapeDims(lhsShapedType, nSize, iSize);
-
-    vector::TransferWriteOp write;
-    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-
-    if (strideW != 1)
-      return rewriter.notifyMatchFailure(
-          op, "Non-unit strides are not supported yet");
-    if (dilationW != 1)
-      return rewriter.notifyMatchFailure(
-          op, "Non-unit dilations are not supported yet");
-
-    Type lhsEltType = lhsShapedType.getElementType();
-    Type rhsEltType = rhsShapedType.getElementType();
-    Type resEltType = resShapedType.getElementType();
-    VectorType rhsType = VectorType::get({kwSize, cSize}, rhsEltType);
-    VectorType lhsType = VectorType::get(
-        {nSize,
-         // iw = (ow * sw + kw *  dw - 1) * c
-         //   (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
-         (((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1) *
-             cSize},
-        lhsEltType);
-
-    VectorType resType = VectorType::get({nSize, wSize * cSize}, resEltType);
-
-    Value res, lhs, lhsFlat, resFlat;
-    // Read rhs slice of size {kw, c} @ [0, 0].
-    Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
-                                                        ValueRange{zero, zero});
-
-    SmallVector<ReassociationIndices> reassociation = {{0}, {1, 2}};
-
-    // Flatten w and c dimensions
-    auto lhsTypeCollapsed = VectorType::get({nSize, iSize * cSize}, lhsEltType);
-    auto linalgOp = dyn_cast<LinalgOp>(op);
-    lhsFlat =
-        linalgOp.hasTensorSemantics()
-            ? (Value)rewriter.create<tensor::CollapseShapeOp>(
-                  loc,
-                  RankedTensorType::get(lhsTypeCollapsed.getShape(),
-                                        lhsEltType),
-                  lhsShaped, reassociation)
-            : (Value)rewriter.create<memref::CollapseShapeOp>(
-                  loc, MemRefType::get(lhsTypeCollapsed.getShape(), lhsEltType),
-                  lhsShaped, reassociation);
-    resFlat =
-        linalgOp.hasTensorSemantics()
-            ? (Value)rewriter.create<tensor::CollapseShapeOp>(
-                  loc, RankedTensorType::get(resType.getShape(), resEltType),
-                  resShaped, reassociation)
-            : (Value)rewriter.create<memref::CollapseShapeOp>(
-                  loc, MemRefType::get(resType.getShape(), resEltType),
-                  resShaped, reassociation);
-
-    // Read lhs slice of size {n, (w * wSize + kw * dilationW) * c} @ [0,
-    // 0].
-    lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsFlat,
-                                                  ValueRange{zero, zero});
-    // Read res slice of size {n, w * c} @ [0, 0].
-    res = rewriter.create<vector::TransferReadOp>(loc, resType, resFlat,
-                                                  ValueRange{zero, zero});
-
-    //===------------------------------------------------------------------===//
-    // Begin vector-only rewrite part
-    //===------------------------------------------------------------------===//
-    // Unroll along kw and read slices of lhs and rhs.
-    SmallVector<Value> lhsVals, rhsVals, resVals;
-    // Extract lhs slice of size {n, wSizeStep * c}
-    //   @ [0, (sw * w + dw * kw) * cSize].
-    for (int64_t kw = 0; kw < kwSize; ++kw) {
-      for (int64_t w = 0; w < wSize; w += wSize) {
-        lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
-            loc, lhs,
-            /*offsets=*/
-            ArrayRef<int64_t>{0, (w * wSize + kw * dilationW) * cSize},
-            /*sizes=*/ArrayRef<int64_t>{nSize, wSize * cSize},
-            /*strides=*/ArrayRef<int64_t>{1, 1}));
-      }
-    }
-    // Extract rhs slice of size {c} @ [kw].
-    for (int64_t kw = 0; kw < kwSize; ++kw) {
-      rhsVals.push_back(rewriter.create<vector::ExtractOp>(
-          loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
-    }
-
-    // Extract res slice
-    // Flattened case:  {n, wSizeStep * c} @ [0, w].
-    for (int64_t w = 0; w < wSize; w += wSize) {
-      resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
-          loc, res,
-          /*offsets=*/ArrayRef<int64_t>{0, w * cSize},
-          /*sizes=*/ArrayRef<int64_t>{nSize, wSize * cSize},
-          /*strides=*/ArrayRef<int64_t>{1, 1}));
-    }
-
-    auto linearIndex = [&](int64_t kw, int64_t w) {
-      return kw * (wSize / wSize) + w;
-    };
-
-    // Compute contraction:
-    //    O{n, w * c} += I{n, (sw * w + dw * kw) * c} * F{c}
-    for (int64_t kw = 0; kw < kwSize; ++kw) {
-      for (int64_t w = 0; w < wSize; w += wSize) {
-        resVals[w] = depthwiseConv1dFlatSliceAsMulAcc(
-            rewriter, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw],
-            resVals[w]);
-      }
-    }
-
-    // It's possible we failed to create the FMA.
-    if (!llvm::all_of(resVals, [](Value v) { return v; })) {
-      // Manually revert (in reverse order) to avoid leaving a bad IR state.
-      for (auto &collection :
-           {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
-        for (Value v : collection)
-          rewriter.eraseOp(v.getDefiningOp());
-      return rewriter.notifyMatchFailure(op, "failed to create FMA");
-    }
-
-    // Write back res slice. This does not depend on kw.
-    // Flattened case: {n, wSizeStep * c} @ [0, w].
-    for (int64_t w = 0; w < wSize; w += wSize) {
-      res = rewriter.create<vector::InsertStridedSliceOp>(
-          loc, resVals[w], res,
-          /*offsets=*/ArrayRef<int64_t>{0, w * cSize},
-          /*strides=*/ArrayRef<int64_t>{1, 1});
-    }
-    //===------------------------------------------------------------------===//
-    // End vector-only rewrite part
-    //===------------------------------------------------------------------===//
-    // Write back res slice of size {n, w * c} @ [0, 0].
-    mlir::vector::TransferWriteOp tWrite =
-        rewriter.create<vector::TransferWriteOp>(loc, res, resFlat,
-                                                 ValueRange{zero, zero});
-
-    // A tensor has to be re-shaped back to its original shape ...
-    if (linalgOp.hasTensorSemantics())
-      // Re-expand shape
-      return rewriter
-          .create<tensor::ExpandShapeOp>(loc, resShapedType, tWrite.getResult(),
-                                         reassociation)
-          .getOperation();
-    /// ... memrefs don't require reshaping (re-shape is just a different view
-    /// into the same memref)
-    return tWrite.getOperation();
-  }
-
-  /// Lower lhs{n, w, c} * rhs{c} -> res{n, w, c} to MulAcc
+  /// Lower:
+  ///   *  lhs{n, w, c} * rhs{c} -> res{n, w, c} (flatten = false)
+  ///   *  lhs{n, w * c} * rhs{c} -> res{n, w * c} (flatten = true)
+  /// to MulAcc.
   Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
-                                     Value lhs, Value rhs, Value res) {
+                                     Value lhs, Value rhs, Value res,
+                                     ShapedType bcastTy, bool flatten) {
     auto rhsTy = cast<ShapedType>(rhs.getType());
     auto resTy = cast<ShapedType>(res.getType());
 
@@ -3118,46 +2968,16 @@ struct Conv1DGenerator
     lhs = promote(rewriter, loc, lhs, resTy);
 
     rhs = rewriter.create<vector::BroadcastOp>(
-        loc, resTy.clone(rhsTy.getElementType()), rhs);
-    rhs = promote(rewriter, loc, rhs, resTy);
-
-    if (!lhs || !rhs)
-      return nullptr;
-
-    if (isa<FloatType>(resTy.getElementType()))
-      return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);
+        loc, bcastTy.clone(rhsTy.getElementType()), rhs);
+    if (flatten)
+      rhs = rewriter.create<vector::ShapeCastOp>(loc, resTy, rhs);
 
-    auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
-    return rewriter.create<arith::AddIOp>(loc, mul, res);
-  }
-
-  /// Lower lhs{n, w * c} * rhs{c} -> res{n, w * c} to MulAcc
-  Value depthwiseConv1dFlatSliceAsMulAcc(RewriterBase &rewriter, Location loc,
-                                         Value lhs, Value rhs, Value res) {
-    auto rhsTy = rhs.getType().cast<ShapedType>();
-    auto resTy = res.getType().cast<ShapedType>();
-
-    lhs = promote(rewriter, loc, lhs, resTy);
-
-    auto rhsSize = rhs.getType().cast<VectorType>().getShape()[0];
-    auto resSize = res.getType().cast<VectorType>().getShape()[1];
-
-    SmallVector<int64_t, 16> indices;
-    for (int i = 0; i < resSize / rhsSize; ++i) {
-      for (int j = 0; j < rhsSize; ++j)
-        indices.push_back(j);
-    }
-
-    rhs = rewriter.create<vector::ShuffleOp>(loc, rhs, rhs, indices);
-
-    rhs = rewriter.create<vector::BroadcastOp>(
-        loc, resTy.clone(rhsTy.getElementType()), rhs);
     rhs = promote(rewriter, loc, rhs, resTy);
 
     if (!lhs || !rhs)
       return nullptr;
 
-    if (resTy.getElementType().isa<FloatType>())
+    if (isa<FloatType>(resTy.getElementType()))
       return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);
 
     auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
@@ -3265,7 +3085,7 @@ struct Conv1DGenerator
     if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                 /*rhsIndex*/ {kw, c},
                 /*resIndex*/ {n, w, c}}))
-      return flatten ? depthwiseConvFlatten() : depthwiseConvGeneric();
+      return depthwiseConvGeneric(flatten);
 
     return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
   }
diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir
index 6b0f920bfa42e7f..17f93ed46e77900 100644
--- a/mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir
@@ -17,25 +17,24 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
-
 // CHECK-LABEL:   func.func @flatten_tensor(
 // CHECK-SAME:                              %[[VAL_0:.*]]: tensor<1x8x3xi8>,
 // CHECK-SAME:                              %[[VAL_1:.*]]: tensor<1x3xi8>,
 // CHECK-SAME:                              %[[VAL_2:.*]]: tensor<1x8x3xi8>) -> tensor<1x8x3xi8> {
 // CHECK:           %[[VAL_3:.*]] = arith.constant 0 : index
 // CHECK:           %[[VAL_4:.*]] = arith.constant 0 : i8
-// CHECK:           %[[VAL_5:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : tensor<1x3xi8>, vector<1x3xi8>
-// CHECK:           %[[VAL_6:.*]] = tensor.collapse_shape %[[VAL_0]] {{\[\[}}0], [1, 2]] : tensor<1x8x3xi8> into tensor<1x24xi8>
-// CHECK:           %[[VAL_7:.*]] = tensor.collapse_shape %[[VAL_2]] {{\[\[}}0], [1, 2]] : tensor<1x8x3xi8> into tensor<1x24xi8>
-// CHECK:           %[[VAL_8:.*]] = vector.transfer_read %[[VAL_6]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : tensor<1x24xi8>, vector<1x24xi8>
-// CHECK:           %[[VAL_9:.*]] = vector.transfer_read %[[VAL_7]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : tensor<1x24xi8>, vector<1x24xi8>
-// CHECK:           %[[VAL_10:.*]] = vector.extract %[[VAL_5]][0] : vector<3xi8> from vector<1x3xi8>
-// CHECK:           %[[VAL_11:.*]] = vector.shuffle %[[VAL_10]], %[[VAL_10]] [0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2] : vector<3xi8>, vector<3xi8>
-// CHECK:           %[[VAL_12:.*]] = vector.broadcast %[[VAL_11]] : vector<24xi8> to vector<1x24xi8>
-// CHECK:           %[[VAL_13:.*]] = arith.muli %[[VAL_8]], %[[VAL_12]] : vector<1x24xi8>
-// CHECK:           %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_9]] : vector<1x24xi8>
-// CHECK:           %[[VAL_15:.*]] = vector.transfer_write %[[VAL_14]], %[[VAL_7]]{{\[}}%[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true]} : vector<1x24xi8>, tensor<1x24xi8>
-// CHECK:           %[[VAL_16:.*]] = tensor.expand_shape %[[VAL_15]] {{\[\[}}0], [1, 2]] : tensor<1x24xi8> into tensor<1x8x3xi8>
+// CHECK:           %[[VAL_5:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true, true]} : tensor<1x8x3xi8>, vector<1x8x3xi8>
+// CHECK:           %[[VAL_6:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : tensor<1x3xi8>, vector<1x3xi8>
+// CHECK:           %[[VAL_7:.*]] = vector.transfer_read %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true, true]} : tensor<1x8x3xi8>, vector<1x8x3xi8>
+// CHECK:           %[[VAL_8:.*]] = vector.extract %[[VAL_6]][0] : vector<3xi8> from vector<1x3xi8>
+// CHECK:           %[[VAL_9:.*]] = vector.shape_cast %[[VAL_5]] : vector<1x8x3xi8> to vector<1x24xi8>
+// CHECK:           %[[VAL_10:.*]] = vector.shape_cast %[[VAL_7]] : vector<1x8x3xi8> to vector<1x24xi8>
+// CHECK:           %[[VAL_11:.*]] = vector.broadcast %[[VAL_8]] : vector<3xi8> to vector<1x8x3xi8>
+// CHECK:           %[[VAL_12:.*]] = vector.shape_cast %[[VAL_11]] : vector<1x8x3xi8> to vector<1x24xi8>
+// CHECK:           %[[VAL_13:.*]] = arith.muli %[[VAL_9]], %[[VAL_12]] : vector<1x24xi8>
+// CHECK:           %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_10]] : vector<1x24xi8>
+// CHECK:           %[[VAL_15:.*]] = vector.shape_cast %[[VAL_14]] : vector<1x24xi8> to vector<1x8x3xi8>
+// CHECK:           %[[VAL_16:.*]] = vector.transfer_write %[[VAL_15]], %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true, true]} : vector<1x8x3xi8>, tensor<1x8x3xi8>
 // CHECK:           return %[[VAL_16]] : tensor<1x8x3xi8>
 // CHECK:         }
 
@@ -65,17 +64,18 @@ module attributes {transform.with_named_sequence} {
 // CHECK-SAME:                              %[[VAL_2:.*]]: memref<1x8x3xi8>) {
 // CHECK:           %[[VAL_3:.*]] = arith.constant 0 : index
 // CHECK:           %[[VAL_4:.*]] = arith.constant 0 : i8
-// CHECK:           %[[VAL_5:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<1x3xi8>, vector<1x3xi8>
-// CHECK:           %[[VAL_6:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0], [1, 2]] : memref<1x8x3xi8> into memref<1x24xi8>
-// CHECK:           %[[VAL_7:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0], [1, 2]] : memref<1x8x3xi8> into memref<1x24xi8>
-// CHECK:           %[[VAL_8:.*]] = vector.transfer_read %[[VAL_6]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<1x24xi8>, vector<1x24xi8>
-// CHECK:           %[[VAL_9:.*]] = vector.transfer_read %[[VAL_7]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<1x24xi8>, vector<1x24xi8>
-// CHECK:           %[[VAL_10:.*]] = vector.extract %[[VAL_5]][0] : vector<3xi8> from vector<1x3xi8>
-// CHECK:           %[[VAL_11:.*]] = vector.shuffle %[[VAL_10]], %[[VAL_10]] [0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2] : vector<3xi8>, vector<3xi8>
-// CHECK:           %[[VAL_12:.*]] = vector.broadcast %[[VAL_11]] : vector<24xi8> to vector<1x24xi8>
-// CHECK:           %[[VAL_13:.*]] = arith.muli %[[VAL_8]], %[[VAL_12]] : vector<1x24xi8>
-// CHECK:           %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_9]] : vector<1x24xi8>
-// CHECK:           vector.transfer_write %[[VAL_14]], %[[VAL_7]]{{\[}}%[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true]} : vector<1x24xi8>, memref<1x24xi8>
+// CHECK:           %[[VAL_5:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true, true]} : memref<1x8x3xi8>, vector<1x8x3xi8>
+// CHECK:           %[[VAL_6:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<1x3xi8>, vector<1x3xi8>
+// CHECK:           %[[VAL_7:.*]] = vector.transfer_read %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true, true]} : memref<1x8x3xi8>, vector<1x8x3xi8>
+// CHECK:           %[[VAL_8:.*]] = vector.extract %[[VAL_6]][0] : vector<3xi8> from vector<1x3xi8>
+// CHECK:           %[[VAL_9:.*]] = vector.shape_cast %[[VAL_5]] : vector<1x8x3xi8> to vector<1x24xi8>
+// CHECK:           %[[VAL_10:.*]] = vector.shape_cast %[[VAL_7]] : vector<1x8x3xi8> to vector<1x24xi8>
+// CHECK:           %[[VAL_11:.*]] = vector.broadcast %[[VAL_8]] : vector<3xi8> to vector<1x8x3xi8>
+// CHECK:           %[[VAL_12:.*]] = vector.shape_cast %[[VAL_11]] : vector<1x8x3xi8> to vector<1x24xi8>
+// CHECK:           %[[VAL_13:.*]] = arith.muli %[[VAL_9]], %[[VAL_12]] : vector<1x24xi8>
+// CHECK:           %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_10]] : vector<1x24xi8>
+// CHECK:           %[[VAL_15:.*]] = vector.shape_cast %[[VAL_14]] : vector<1x24xi8> to vector<1x8x3xi8>
+// CHECK:           vector.transfer_write %[[VAL_15]], %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true, true]} : vector<1x8x3xi8>, memref<1x8x3xi8>
 // CHECK:           return
 // CHECK:         }
 
@@ -105,27 +105,30 @@ module attributes {transform.with_named_sequence} {
 // CHECK-SAME:                                           %[[VAL_2:.*]]: memref<1x7x3xi8>) {
 // CHECK:           %[[VAL_3:.*]] = arith.constant 0 : index
 // CHECK:           %[[VAL_4:.*]] = arith.constant 0 : i8
-// CHECK:           %[[VAL_5:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<2x3xi8>, vector<2x3xi8>
-// CHECK:           %[[VAL_6:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0], [1, 2]] : memref<1x8x3xi8> into memref<1x24xi8>
-// CHECK:           %[[VAL_7:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0], [1, 2]] : memref<1x7x3xi8> into memref<1x21xi8>
-// CHECK:           %[[VAL_8:.*]] = vector.transfer_read %[[VAL_6]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<1x24xi8>, vector<1x24xi8>
-// CHECK:           %[[VAL_9:.*]] = vector.transfer_read %[[VAL_7]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<1x21xi8>, vector<1x21xi8>
-// CHECK:           %[[VAL_10:.*]] = vector.extract_strided_slice %[[VAL_8]] {offsets = [0, 0], sizes = [1, 21], strides = [1, 1]} : vector<1x24xi8> to vector<1x21xi8>
-// CHECK:           %[[VAL_11:.*]] = vector.extract_strided_slice %[[VAL_8]] {offsets = [0, 3], sizes = [1, 21], strides = [1, 1]} : vector<1x24xi8> to vector<1x21xi8>
-// CHECK:           %[[VAL_12:.*]] = vector.extract %[[VAL_5]][0] : vector<3xi8> from vector<2x3xi8>
-// CHECK:           %[[VAL_13:.*]] = vector.extract %[[VAL_5]][1] : vector<3xi8> from vector<2x3xi8>
-// CHECK:           %[[VAL_14:.*]] = vector.shuffle %[[VAL_12]], %[[VAL_12]] [0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2] : vector<3xi8>, vector<3xi8>
-// CHECK:           %[[VAL_15:.*]] = vector.broadcast %[[VAL_14]] : vector<21xi8> to vector<1x21xi8>
-// CHECK:           %[[VAL_16:.*]] = arith.muli %[[VAL_10]], %[[VAL_15]] : vector<1x21xi8>
-// CHECK:           %[[VAL_17:.*]] = arith.addi %[[VAL_16]], %[[VAL_9]] : vector<1x21xi8>
-// CHECK:           %[[VAL_18:.*]] = vector.shuffle %[[VAL_13]], %[[VAL_13]] [0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2] : vector<3xi8>, vector<3xi8>
-// CHECK:           %[[VAL_19:.*]] = vector.broadcast %[[VAL_18]] : vector<21xi8> to vector<1x21xi8>
-// CHECK:           %[[VAL_20:.*]] = arith.muli %[[VAL_11]], %[[VAL_19]] : vector<1x21xi8>
-// CHECK:           %[[VAL_21:.*]] = arith.addi %[[VAL_20]], %[[VAL_17]] : vector<1x21xi8>
-// CHECK:           vector.transfer_write %[[VAL_21]], %[[VAL_7]]{{\[}}%[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true]} : vector<1x21xi8>, memref<1x21xi8>
+// CHECK:           %[[VAL_5:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true, true]} : memref<1x8x3xi8>, vector<1x8x3xi8>
+// CHECK:           %[[VAL_6:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<2x3xi8>, vector<2x3xi8>
+// CHECK:           %[[VAL_7:.*]] = vector.transfer_read %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true, true]} : memref<1x7x3xi8>, vector<1x7x3xi8>
+// CHECK:           %[[VAL_8:.*]] = vector.extract_strided_slice %[[VAL_5]] {offsets = [0, 0, 0], sizes = [1, 7, 3], strides = [1, 1, 1]} : vector<1x8x3xi8> to vector<1x7x3xi8>
+// CHECK:           %[[VAL_9:.*]] = vector.extract_strided_slice %[[VAL_5]] {offsets = [0, 1, 0], sizes = [1, 7, 3], strides = [1, 1, 1]} : vector<1x8x3xi8> to vector<1x7x3xi8>
+// CHECK:           %[[VAL_10:.*]] = vector.extract %[[VAL_6]][0] : vector<3xi8> from vector<2x3xi8>
+// CHECK:           %[[VAL_11:.*]] = vector.extract %[[VAL_6]][1] : vector<3xi8> from vector<2x3xi8>
+// CHECK:           %[[VAL_12:.*]] = vector.shape_cast %[[VAL_8]] : vector<1x7x3xi8> to vector<1x21xi8>
+// CHECK:           %[[VAL_13:.*]] = vector.shape_cast %[[VAL_7]] : vector<1x7x3xi8> to vector<1x21xi8>
+// CHECK:           %[[VAL_14:.*]] = vector.broadcast %[[VAL_10]] : vector<3xi8> to vector<1x7x3xi8>
+// CHECK:           %[[VAL_15:.*]] = vector.shape_cast %[[VAL_14]] : vector<1x7x3xi8> to vector<1x21xi8>
+// CHECK:           %[[VAL_16:.*]] = arith.muli %[[VAL_12]], %[[VAL_15]] : vector<1x21xi8>
+// CHECK:           %[[VAL_17:.*]] = arith.addi %[[VAL_16]], %[[VAL_13]] : vector<1x21xi8>
+// CHECK:           %[[VAL_18:.*]] = vector.shape_cast %[[VAL_9]] : vector<1x7x3xi8> to vector<1x21xi8>
+// CHECK:           %[[VAL_19:.*]] = vector.broadcast %[[VAL_11]] : vector<3xi8> to vector<1x7x3xi8>
+// CHECK:           %[[VAL_20:.*]] = vector.shape_cast %[[VAL_19]] : vector<1x7x3xi8> to vector<1x21xi8>
+// CHECK:           %[[VAL_21:.*]] = arith.muli %[[VAL_18]], %[[VAL_20]] : vector<1x21xi8>
+// CHECK:           %[[VAL_22:.*]] = arith.addi %[[VAL_21]], %[[VAL_17]] : vector<1x21xi8>
+// CHECK:           %[[VAL_23:.*]] = vector.shape_cast %[[VAL_22]] : vector<1x21xi8> to vector<1x7x3xi8>
+// CHECK:           vector.transfer_write %[[VAL_23]], %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true, true]} : vector<1x7x3xi8>, memref<1x7x3xi8>
 // CHECK:           return
 // CHECK:         }
 
+
 // -----
 
 func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref(%input: memref<3x5x4xf32>, %filter: memref<2x4xf32>, %output: memref<3x2x4xf32>) {
@@ -151,22 +154,24 @@ module attributes {transform.with_named_sequence} {
 // CHECK-SAME:                                                        %[[VAL_2:.*]]: memref<3x2x4xf32>) {
 // CHECK:           %[[VAL_3:.*]] = arith.constant 0 : index
 // CHECK:           %[[VAL_4:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:           %[[VAL_5:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<2x4xf32>, vector<2x4xf32>
-// CHECK:           %[[VAL_6:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0], [1, 2]] : memref<3x5x4xf32> into memref<3x20xf32>
-// CHECK:           %[[VAL_7:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0], [1, 2]] : memref<3x2x4xf32> into memref<3x8xf32>
-// CHECK:           %[[VAL_8:.*]] = vector.transfer_read %[[VAL_6]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<3x20xf32>, vector<3x16xf32>
-// CHECK:           %[[VAL_9:.*]] = vector.transfer_read %[[VAL_7]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<3x8xf32>, vector<3x8xf32>
-// CHECK:           %[[VAL_10:.*]] = vector.extract_strided_slice %[[VAL_8]] {offsets = [0, 0], sizes = [3, 8], strides = [1, 1]} : vector<3x16xf32> to vector<3x8xf32>
-// CHECK:           %[[VAL_11:.*]] = vector.extract_strided_slice %[[VAL_8]] {offsets = [0, 8], sizes = [3, 8], strides = [1, 1]} : vector<3x16xf32> to vector<3x8xf32>
-// CHECK:           %[[VAL_12:.*]] = vector.extract %[[VAL_5]][0] : vector<4xf32> from vector<2x4xf32>
-// CHECK:           %[[VAL_13:.*]] = vector.extract %[[VAL_5]][1] : vector<4xf32> from vector<2x4xf32>
-// CHECK:           %[[VAL_14:.*]] = vector.shuffle %[[VAL_12]], %[[VAL_12]] [0, 1, 2, 3, 0, 1, 2, 3] : vector<4xf32>, vector<4xf32>
-// CHECK:           %[[VAL_15:.*]] = vector.broadcast %[[VAL_14]] : vector<8xf32> to vector<3x8xf32>
-// CHECK:           %[[VAL_16:.*]] = vector.fma %[[VAL_10]], %[[VAL_15]], %[[VAL_9]] : vector<3x8xf32>
-// CHECK:           %[[VAL_17:.*]] = vector.shuffle %[[VAL_13]], %[[VAL_13]] [0, 1, 2, 3, 0, 1, 2, 3] : vector<4xf32>, vector<4xf32>
-// CHECK:           %[[VAL_18:.*]] = vector.broadcast %[[VAL_17]] : vector<8xf32> to vector<3x8xf32>
-// CHECK:           %[[VAL_19:.*]] = vector.fma %[[VAL_11]], %[[VAL_18]], %[[VAL_16]] : vector<3x8xf32>
-// CHECK:           vector.transfer_write %[[VAL_19]], %[[VAL_7]]{{\[}}%[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true]} : vector<3x8xf32>, memref<3x8xf32>
+// CHECK:           %[[VAL_5:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true, true]} : memref<3x5x4xf32>, vector<3x4x4xf32>
+// CHECK:           %[[VAL_6:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<2x4xf32>, vector<2x4xf32>
+// CHECK:           %[[VAL_7:.*]] = vector.transfer_read %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true, true]} : memref<3x2x4xf32>, vector<3x2x4xf32>
+// CHECK:           %[[VAL_8:.*]] = vector.extract_strided_slice %[[VAL_5]] {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32>
+// CHECK:           %[[VAL_9:.*]] = vector.extract_strided_slice %[[VAL_5]] {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32>
+// CHECK:           %[[VAL_10:.*]] = vector.extract %[[VAL_6]][0] : vector<4xf32> from vector<2x4xf32>
+// CHECK:           %[[VAL_11:.*]] = vector.extract %[[VAL_6]][1] : vector<4xf32> from vector<2x4xf32>
+// CHECK:           %[[VAL_12:.*]] = vector.shape_cast %[[VAL_8]] : vector<3x2x4xf32> to vector<3x8xf32>
+// CHECK:           %[[VAL_13:.*]] = vector.shape_cast %[[VAL_7]] : vector<3x2x4xf32> to vector<3x8xf32>
+// CHECK:           %[[VAL_14:.*]] = vector.broadcast %[[VAL_10]] : vector<4xf32> to vector<3x2x4xf32>
+// CHECK:           %[[VAL_15:.*]] = vector.shape_cast %[[VAL_14]] : vector<3x2x4xf32> to vector<3x8xf32>
+// CHECK:           %[[VAL_16:.*]] = vector.fma %[[VAL_12]], %[[VAL_15]], %[[VAL_13]] : vector<3x8xf32>
+// CHECK:           %[[VAL_17:.*]] = vector.shape_cast %[[VAL_9]] : vector<3x2x4xf32> to vector<3x8xf32>
+// CHECK:           %[[VAL_18:.*]] = vector.broadcast %[[VAL_11]] : vector<4xf32> to vector<3x2x4xf32>
+// CHECK:           %[[VAL_19:.*]] = vector.shape_cast %[[VAL_18]] : vector<3x2x4xf32> to vector<3x8xf32>
+// CHECK:           %[[VAL_20:.*]] = vector.fma %[[VAL_17]], %[[VAL_19]], %[[VAL_16]] : vector<3x8xf32>
+// CHECK:           %[[VAL_21:.*]] = vector.shape_cast %[[VAL_20]] : vector<3x8xf32> to vector<3x2x4xf32>
+// CHECK:           vector.transfer_write %[[VAL_21]], %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true, true]} : vector<3x2x4xf32>, memref<3x2x4xf32>
 // CHECK:           return
 // CHECK:         }
 
@@ -196,27 +201,24 @@ module attributes {transform.with_named_sequence} {
 // CHECK:           %[[VAL_3:.*]] = arith.constant 0 : index
 // CHECK:           %[[VAL_4:.*]] = arith.constant 0 : i8
 // CHECK:           %[[VAL_5:.*]] = arith.constant 0 : i32
-// CHECK:           %[[VAL_6:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<2x4xi8>, vector<2x4xi8>
-// CHECK:           %[[VAL_7:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0], [1, 2]] : memref<3x5x4xi8> into memref<3x20xi8>
-// CHECK:           %[[VAL_8:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0], [1, 2]] : memref<3x2x4xi32> into memref<3x8xi32>
-// CHECK:           %[[VAL_9:.*]] = vector.transfer_read %[[VAL_7]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<3x20xi8>, vector<3x16xi8>
-// CHECK:           %[[VAL_10:.*]] = vector.transfer_read %[[VAL_8]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_5]] {in_bounds = [true, true]} : memref<3x8xi32>, vector<3x8xi32>
-// CHECK:           %[[VAL_11:.*]] = vector.extract_strided_slice %[[VAL_9]] {offsets = [0, 0], sizes = [3, 8], strides = [1, 1]} : vector<3x16xi8> to vector<3x8xi8>
-// CHECK:           %[[VAL_12:.*]] = vector.extract_strided_slice %[[VAL_9]] {offsets = [0, 8], sizes = [3, 8], strides = [1, 1]} : vector<3x16xi8> to vector<3x8xi8>
-// CHECK:           %[[VAL_13:.*]] = vector.extract %[[VAL_6]][0] : vector<4xi8> from vector<2x4xi8>
-// CHECK:           %[[VAL_14:.*]] = vector.extract %[[VAL_6]][1] : vector<4xi8> from vector<2x4xi8>
-// CHECK:           %[[VAL_15:.*]] = arith.extsi %[[VAL_11]] : vector<3x8xi8> to vector<3x8xi32>
-// CHECK:           %[[VAL_16:.*]] = vector.shuffle %[[VAL_13]], %[[VAL_13]] [0, 1, 2, 3, 0, 1, 2, 3] : vector<4xi8>, vector<4xi8>
-// CHECK:           %[[VAL_17:.*]] = vector.broadcast %[[VAL_16]] : vector<8xi8> to vector<3x8xi8>
-// CHECK:           %[[VAL_18:.*]] = arith.extsi %[[VAL_17]] : vector<3x8xi8> to vector<3x8xi32>
-// CHECK:           %[[VAL_19:.*]] = arith.muli %[[VAL_15]], %[[VAL_18]] : vector<3x8xi32>
-// CHECK:           %[[VAL_20:.*]] = arith.addi %[[VAL_19]], %[[VAL_10]] : vector<3x8xi32>
-// CHECK:           %[[VAL_21:.*]] = arith.extsi %[[VAL_12]] : vector<3x8xi8> to vector<3x8xi32>
-// CHECK:           %[[VAL_22:.*]] = vector.shuffle %[[VAL_14]], %[[VAL_14]] [0, 1, 2, 3, 0, 1, 2, 3] : vector<4xi8>, vector<4xi8>
-// CHECK:           %[[VAL_23:.*]] = vector.broadcast %[[VAL_22]] : vector<8xi8> to vector<3x8xi8>
-// CHECK:           %[[VAL_24:.*]] = arith.extsi %[[VAL_23]] : vector<3x8xi8> to vector<3x8xi32>
-// CHECK:           %[[VAL_25:.*]] = arith.muli %[[VAL_21]], %[[VAL_24]] : vector<3x8xi32>
-// CHECK:           %[[VAL_26:.*]] = arith.addi %[[VAL_25]], %[[VAL_20]] : vector<3x8xi32>
-// CHECK:           vector.transfer_write %[[VAL_26]], %[[VAL_8]]{{\[}}%[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true]} : vector<3x8xi32>, memref<3x8xi32>
+// CHECK:           %[[VAL_6:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true, true]} : memref<3x5x4xi8>, vector<3x4x4xi8>
+// CHECK:           %[[VAL_7:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : memref<2x4xi8>, vector<2x4xi8>
+// CHECK:           %[[VAL_8:.*]] = vector.transfer_read %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]], %[[VAL_5]] {in_bounds = [true, true, true]} : memref<3x2x4xi32>, vector<3x2x4xi32>
+// CHECK:           %[[VAL_9:.*]] = vector.extract_strided_slice %[[VAL_6]] {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xi8> to vector<3x2x4xi8>
+// CHECK:           %[[VAL_10:.*]] = vector.extract_strided_slice %[[VAL_6]] {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xi8> to vector<3x2x4xi8>
+// CHECK:           %[[VAL_11:.*]] = vector.extract %[[VAL_7]][0] : vector<4xi8> from vector<2x4xi8>
+// CHECK:           %[[VAL_12:.*]] = vector.extract %[[VAL_7]][1] : vector<4xi8> from vector<2x4xi8>
+// CHECK:           %[[VAL_13:.*]] = arith.extsi %[[VAL_9]] : vector<3x2x4xi8> to vector<3x2x4xi32>
+// CHECK:           %[[VAL_14:.*]] = arith.extsi %[[VAL_11]] : vector<4xi8> to vector<4xi32>
+// CHECK:           %[[VAL_15:.*]] = vector.broadcast %[[VAL_14]] : vector<4xi32> to vector<3x2x4xi32>
+// CHECK:           %[[VAL_16:.*]] = arith.muli %[[VAL_13]], %[[VAL_15]] : vector<3x2x4xi32>
+// CHECK:           %[[VAL_17:.*]] = arith.addi %[[VAL_16]], %[[VAL_8]] : vector<3x2x4xi32>
+// CHECK:           %[[VAL_18:.*]] = arith.extsi %[[VAL_10]] : vector<3x2x4xi8> to vector<3x2x4xi32>
+// CHECK:           %[[VAL_19:.*]] = arith.extsi %[[VAL_12]] : vector<4xi8> to vector<4xi32>
+// CHECK:           %[[VAL_20:.*]] = vector.broadcast %[[VAL_19]] : vector<4xi32> to vector<3x2x4xi32>
+// CHECK:           %[[VAL_21:.*]] = arith.muli %[[VAL_18]], %[[VAL_20]] : vector<3x2x4xi32>
+// CHECK:           %[[VAL_22:.*]] = arith.addi %[[VAL_21]], %[[VAL_17]] : vector<3x2x4xi32>
+// CHECK:           vector.transfer_write %[[VAL_22]], %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true, true]} : vector<3x2x4xi32>, memref<3x2x4xi32>
 // CHECK:           return
 // CHECK:         }
+

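For reference, the flattening path is driven from the transform dialect roughly as
follows — a sketch modelled on the `transform.with_named_sequence` modules in the
test file above; the match and get_parent_op steps are assumptions about the
harness, and only the `flatten_1d_depthwise_conv` unit attribute is the new piece:

```mlir
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    // Find the depthwise convolution and vectorize its enclosing function,
    // requesting the flattened (channel-collapsing) lowering.
    %conv = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0
        : (!transform.any_op) -> !transform.any_op
    %func = transform.get_parent_op %conv {isolated_from_above}
        : (!transform.any_op) -> !transform.any_op
    %vectorized = transform.structured.vectorize_children_and_apply_patterns %func
        {flatten_1d_depthwise_conv} : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}
```
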
>From 2cb7081a4d36fef77598d540483714f3551d9dac Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 16 Nov 2023 10:24:45 +0000
Subject: [PATCH 3/3] Revert renaming + fix formatting

---
 .../lib/Dialect/Linalg/Transforms/Vectorization.cpp | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index acbf29a9bbd616d..2940178db146829 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2824,7 +2824,7 @@ struct Conv1DGenerator
   /// kw is always unrolled.
   /// TODO: w (resp. kw) is unrolled when the strideW (resp. dilationW) is
   /// > 1.
-  FailureOr<Operation *> depthwiseConvGeneric(bool flatten) {
+  FailureOr<Operation *> depthwiseConv(bool flatten) {
     if (!valid)
       return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
 
@@ -2881,7 +2881,8 @@ struct Conv1DGenerator
         lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
             loc, lhs,
             /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
-            inOutSliceSizes, inOutStrides));
+            inOutSliceSizes,
+            inOutStrides));
       }
     }
     // Extract rhs slice of size {c} @ [kw].
@@ -2893,7 +2894,8 @@ struct Conv1DGenerator
     for (int64_t w = 0; w < wSize; w += wSizeStep) {
       resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
           loc, res,
-          /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
+          /*offsets=*/ArrayRef<int64_t>{0, w, 0},
+          inOutSliceSizes,
           inOutStrides));
     }
 
@@ -2901,8 +2903,7 @@ struct Conv1DGenerator
       return kw * (wSize / wSizeStep) + w;
     };
 
-    auto inOutFlattenSliceSizes =
-        SmallVector<int64_t>{nSize, wSizeStep * cSize};
+    auto inOutFlattenSliceSizes = SmallVector<int64_t>{nSize, wSizeStep * cSize};
     auto lhsCastType = VectorType::get(inOutFlattenSliceSizes, lhsEltType);
     auto resCastType = VectorType::get(inOutFlattenSliceSizes, lhsEltType);
     // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
@@ -3085,7 +3086,7 @@ struct Conv1DGenerator
     if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                 /*rhsIndex*/ {kw, c},
                 /*resIndex*/ {n, w, c}}))
-      return depthwiseConvGeneric(flatten);
+      return depthwiseConv(flatten);
 
     return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
   }
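
As the comment above notes, kw is always unrolled: each filter tap contributes one
strided slice of the input and one multiply-accumulate, chained through the
accumulator. A sketch of the flattened two-tap f32 case, distilled from the
3x5x4xf32 test earlier in this patch (%lhs/%res stand for the read input/output
vectors and %f0/%f1 for the extracted filter rows; all names are illustrative):

```mlir
// Two taps -> two input slices (second tap offset by the dilation).
%in0  = vector.extract_strided_slice %lhs {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32>
%in1  = vector.extract_strided_slice %lhs {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32>
// Tap 0: flatten, broadcast filter row 0, fused multiply-add.
%i0   = vector.shape_cast %in0 : vector<3x2x4xf32> to vector<3x8xf32>
%r    = vector.shape_cast %res : vector<3x2x4xf32> to vector<3x8xf32>
%f0b  = vector.broadcast %f0 : vector<4xf32> to vector<3x2x4xf32>
%f0f  = vector.shape_cast %f0b : vector<3x2x4xf32> to vector<3x8xf32>
%acc0 = vector.fma %i0, %f0f, %r : vector<3x8xf32>
// Tap 1 accumulates into the result of tap 0.
%i1   = vector.shape_cast %in1 : vector<3x2x4xf32> to vector<3x8xf32>
%f1b  = vector.broadcast %f1 : vector<4xf32> to vector<3x2x4xf32>
%f1f  = vector.shape_cast %f1b : vector<3x2x4xf32> to vector<3x8xf32>
%acc1 = vector.fma %i1, %f1f, %acc0 : vector<3x8xf32>
```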


