[Mlir-commits] [mlir] [mlir][linalg] Enable masked vectorisation for depthwise convolutions (PR #81625)

Andrzej Warzyński llvmlistbot at llvm.org
Thu Mar 14 03:35:46 PDT 2024


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/81625

>From 2c7931e12d4172bf45d60d012f016a48d3e3de32 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 8 Feb 2024 12:45:40 +0000
Subject: [PATCH 1/9] [mlir][linalg] Add scalable vectorisation for depthwise
 convolutions

This patch adds support for scalable vectorisation of depthwise 1D NWC
convolutions, `linalg.depthwise_conv_1d_nwc_wc`. This is implemented by
adding support for masking.

Two major assumptions are made:
  * only the channel dimension can be scalable/dynamic (i.e. the
    trailing dim),
  * when specifying vector sizes to use in the vectoriser, only the size
    corresponding to the channel dim is effectively used (other dims are
    inferred from the context).

In terms of scalable vectorisation, this should be sufficient to cover
all practical cases (i.e. making an arbitrary dim scalable wouldn't make
much sense). As for more generic cases with dynamic shapes (e.g. w or n
dims being dynamic), more work would be needed. In particular, one would
have to consider the filter and input/output tensors separately. However,
it's not clear whether that would be of any use in practice.
---
 .../Linalg/Transforms/Vectorization.cpp       | 131 +++++++++++---
 .../Linalg/vectorize-conv-scalable.mlir       | 161 ++++++++++++++++++
 2 files changed, 273 insertions(+), 19 deletions(-)
 create mode 100644 mlir/test/Dialect/Linalg/vectorize-conv-scalable.mlir

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 1e703dacfd0c75..60f4f9deeb1cda 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -54,6 +54,7 @@ using namespace mlir::linalg;
 /// Try to vectorize `convOp` as a convolution.
 static FailureOr<Operation *>
 vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
+                     ArrayRef<int64_t> inputVecSizes = {},
                      bool flatten1DDepthwiseConv = false);
 
 /// Return the unique instance of OpType in `block` if it is indeed unique.
@@ -1713,6 +1714,19 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
 }
 
 static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op) {
+  // Support dynamic shapes in 1D depthwise convolution, but only in the
+  // _channel_ dimension. That's exclusively to support scalable vectorisation.
+  if (auto conv = dyn_cast<linalg::DepthwiseConv1DNwcWcOp>(op.getOperation())) {
+    auto lhsShaped = op.getDpsInputOperand(0)->get();
+    ArrayRef<int64_t> lhsShape =
+        dyn_cast<ShapedType>(lhsShaped.getType()).getShape();
+    auto shapeWithoutCh = lhsShape.drop_back(1);
+    if (ShapedType::isDynamicShape(shapeWithoutCh))
+      return failure();
+
+    return success();
+  }
+
   // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
   // linalg.copy ops and ops that implement ContractionOpInterface for now.
   if (!isElementwise(op) &&
@@ -1913,7 +1927,8 @@ vectorizeScalableVectorPrecondition(Operation *op,
 
   // Only element-wise ops supported in the presence of scalable dims.
   auto linalgOp = dyn_cast<LinalgOp>(op);
-  return success(linalgOp && isElementwise(linalgOp));
+  return success(linalgOp && (isElementwise(linalgOp) ||
+                              isa<linalg::DepthwiseConv1DNwcWcOp>(op)));
 }
 
 LogicalResult mlir::linalg::vectorizeOpPrecondition(
@@ -1999,7 +2014,7 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
             // inference.
             if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
               FailureOr<Operation *> convOr = vectorizeConvolution(
-                  rewriter, linalgOp, flatten1DDepthwiseConv);
+                  rewriter, linalgOp, inputVectorSizes, flatten1DDepthwiseConv);
               if (succeeded(convOr)) {
                 llvm::append_range(results, (*convOr)->getResults());
                 return success();
@@ -2828,6 +2843,7 @@ struct Conv1DGenerator
         return;
       break;
     }
+    hasTensorSemantics = linalgOp.hasPureTensorSemantics();
     // The op is now known to be valid.
     valid = true;
   }
@@ -3131,13 +3147,21 @@ struct Conv1DGenerator
   /// kw is always unrolled.
   /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
   /// > 1.
-  FailureOr<Operation *> depthwiseConv(bool flatten) {
+  FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
+                                       bool flatten) {
     if (!valid)
       return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
 
+    bool scalableChDim = false;
     int64_t nSize, wSize, cSize, kwSize;
     // kernel{kw, c}
     bindShapeDims(rhsShapedType, kwSize, cSize);
+    // Dynamic channel size implies scalable vectorisation
+    if (ShapedType::isDynamic(cSize)) {
+      assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
+      cSize = channelDimVecSize;
+      scalableChDim = true;
+    }
     // out{n, w, c}
     bindShapeDims(resShapedType, nSize, wSize);
 
@@ -3158,20 +3182,74 @@ struct Conv1DGenerator
          //   (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
          ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
          cSize},
-        lhsEltType);
-    VectorType rhsType = VectorType::get({kwSize, cSize}, rhsEltType);
-    VectorType resType = VectorType::get({nSize, wSize, cSize}, resEltType);
+        lhsEltType, {false, false, scalableChDim});
+    VectorType rhsType =
+        VectorType::get({kwSize, cSize}, rhsEltType,
+                        /*scalableDims=*/{false, scalableChDim});
+    VectorType resType =
+        VectorType::get({nSize, wSize, cSize}, resEltType,
+                        /*scalableDims=*/{false, false, scalableChDim});
+
+    // Masks the input xfer Op along the channel dim, iff the corresponding
+    // scalable flag is set.
+    auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
+                               ArrayRef<bool> scalableDims,
+                               Operation *opToMask) {
+      bool scalableChDim = scalableDims.back();
+      if (!scalableChDim)
+        return opToMask;
+
+      auto maskType =
+          VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);
+
+      SmallVector<OpFoldResult> mixedSourceDims =
+          hasTensorSemantics
+              ? TypeSwitch<Operation *, SmallVector<OpFoldResult>>(opToMask)
+                    .Case<vector::TransferReadOp>([&](auto readOp) {
+                      return tensor::getMixedSizes(rewriter, loc,
+                                                   readOp.getSource());
+                    })
+                    .Case<vector::TransferWriteOp>([&](auto writeOp) {
+                      return tensor::getMixedSizes(rewriter, loc,
+                                                   writeOp.getOperand(1));
+                    })
+              : TypeSwitch<Operation *, SmallVector<OpFoldResult>>(opToMask)
+                    .Case<vector::TransferReadOp>([&](auto readOp) {
+                      return memref::getMixedSizes(rewriter, loc,
+                                                   readOp.getSource());
+                    })
+                    .Case<vector::TransferWriteOp>([&](auto writeOp) {
+                      return memref::getMixedSizes(rewriter, loc,
+                                                   writeOp.getOperand(1));
+                    });
+
+      Value maskOp =
+          rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
+
+      return mlir::vector::maskOperation(rewriter, opToMask, maskOp);
+    };
 
     // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
     // 0].
     Value lhs = rewriter.create<vector::TransferReadOp>(
         loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
+    auto maybeMaskedLHS = maybeMaskXferOp(
+        lhsType.getShape(),
+        /*scalableDims=*/{false, false, scalableChDim}, lhs.getDefiningOp());
+
     // Read rhs slice of size {kw, c} @ [0, 0].
     Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
                                                         ValueRange{zero, zero});
+    auto maybeMaskedRHS = maybeMaskXferOp(
+        rhsType.getShape(),
+        /*scalableDims=*/{false, scalableChDim}, rhs.getDefiningOp());
+
     // Read res slice of size {n, w, c} @ [0, 0, 0].
     Value res = rewriter.create<vector::TransferReadOp>(
         loc, resType, resShaped, ValueRange{zero, zero, zero});
+    auto maybeMaskedRES = maybeMaskXferOp(
+        resType.getShape(),
+        /*scalableDims=*/{false, false, scalableChDim}, res.getDefiningOp());
 
     //===------------------------------------------------------------------===//
     // Begin vector-only rewrite part
@@ -3186,7 +3264,7 @@ struct Conv1DGenerator
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       for (int64_t w = 0; w < wSize; w += wSizeStep) {
         lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
-            loc, lhs,
+            loc, maybeMaskedLHS->getResult(0),
             /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
             inOutSliceSizes, inOutStrides));
       }
@@ -3194,12 +3272,13 @@ struct Conv1DGenerator
     // Extract rhs slice of size {c} @ [kw].
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       rhsVals.push_back(rewriter.create<vector::ExtractOp>(
-          loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
+          loc, maybeMaskedRHS->getResult(0),
+          /*offsets=*/ArrayRef<int64_t>{kw}));
     }
     // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
     for (int64_t w = 0; w < wSize; w += wSizeStep) {
       resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
-          loc, res,
+          loc, maybeMaskedRES->getResult(0),
           /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
           inOutStrides));
     }
@@ -3238,6 +3317,7 @@ struct Conv1DGenerator
     // Its possible we failed to create the Fma.
     if (!llvm::all_of(resVals, [](Value v) { return v; })) {
       // Manually revert (in reverse order) to avoid leaving a bad IR state.
+      // TODO: Replace with maybeMasked
       for (auto &collection :
            {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
         for (Value v : collection)
@@ -3248,8 +3328,8 @@ struct Conv1DGenerator
     // Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
     // This does not depend on kw.
     for (int64_t w = 0; w < wSize; w += wSizeStep) {
-      res = rewriter.create<vector::InsertStridedSliceOp>(
-          loc, resVals[w], res,
+      maybeMaskedRES = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, resVals[w], maybeMaskedRES->getResult(0),
           /*offsets=*/ArrayRef<int64_t>{0, w, 0},
           /*strides=*/ArrayRef<int64_t>{1, 1, 1});
     }
@@ -3258,10 +3338,12 @@ struct Conv1DGenerator
     //===------------------------------------------------------------------===//
 
     // Write back res slice of size {n, w, c} @ [0, 0, 0].
-    return rewriter
-        .create<vector::TransferWriteOp>(loc, res, resShaped,
-                                         ValueRange{zero, zero, zero})
-        .getOperation();
+    Operation *resOut = rewriter.create<vector::TransferWriteOp>(
+        loc, maybeMaskedRES->getResult(0), resShaped,
+        ValueRange{zero, zero, zero});
+    return maybeMaskXferOp(resType.getShape(),
+                           /*scalableDims=*/{false, false, scalableChDim},
+                           resOut);
   }
 
   /// Lower:
@@ -3302,8 +3384,9 @@ struct Conv1DGenerator
     if (!lhs || !rhs)
       return nullptr;
 
-    if (isa<FloatType>(resTy.getElementType()))
+    if (isa<FloatType>(resTy.getElementType())) {
       return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);
+    }
 
     auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
     return rewriter.create<arith::AddIOp>(loc, mul, res);
@@ -3399,7 +3482,8 @@ struct Conv1DGenerator
 
   /// Entry point that transposes into the common form:
   ///   {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
-  FailureOr<Operation *> generateDilatedConv(bool flatten = false) {
+  FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
+                                             bool flatten = false) {
     AffineExpr n, w, c, kw;
     bindDims(ctx, n, w, c, kw);
     if (!iters({Par(), Par(), Par(), Red()}))
@@ -3410,7 +3494,7 @@ struct Conv1DGenerator
     if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                 /*rhsIndex*/ {kw, c},
                 /*resIndex*/ {n, w, c}}))
-      return depthwiseConv(flatten);
+      return depthwiseConv(vecChDimSize, flatten);
 
     return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
   }
@@ -3422,6 +3506,7 @@ struct Conv1DGenerator
   StringAttr redOp;
   StringAttr poolExtOp;
   bool isPoolExt = false;
+  bool hasTensorSemantics = false;
   int strideW, dilationW;
   Value lhsShaped, rhsShaped, resShaped;
   ShapedType lhsShapedType, rhsShapedType, resShapedType;
@@ -3477,6 +3562,7 @@ struct Conv1DGenerator
 // TODO: extend the generic vectorization to support windows and drop this.
 static FailureOr<Operation *>
 vectorizeConvolution(RewriterBase &rewriter, LinalgOp op,
+                     ArrayRef<int64_t> inputVecSizes,
                      bool flatten1DDepthwiseConv) {
   // The ConvolutionOpInterface gives us guarantees of existence for
   // strides/dilations. However, we do not need to rely on those, we can
@@ -3502,7 +3588,14 @@ vectorizeConvolution(RewriterBase &rewriter, LinalgOp op,
   res = e.generateNcwPooling();
   if (succeeded(res))
     return res;
-  return e.generateDilatedConv(flatten1DDepthwiseConv);
+
+  uint64_t vecChDimSize = ShapedType::kDynamic;
+  if (!inputVecSizes.empty()) {
+    // Only use the input vector size corresponding to the channel dim. Other
+    // vector dims will be inferred from the Ops.
+    vecChDimSize = inputVecSizes[2];
+  }
+  return e.generateDilatedConv(vecChDimSize, flatten1DDepthwiseConv);
 }
 
 struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
diff --git a/mlir/test/Dialect/Linalg/vectorize-conv-scalable.mlir b/mlir/test/Dialect/Linalg/vectorize-conv-scalable.mlir
new file mode 100644
index 00000000000000..d4b3574451c2bf
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/vectorize-conv-scalable.mlir
@@ -0,0 +1,161 @@
+// RUN: mlir-opt -split-input-file -transform-interpreter %s | FileCheck %s
+
+func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_tensor(%input: tensor<1x8x?xi8>,
+                                                   %filter: tensor<1x?xi8>,
+                                                   %output: tensor<1x8x?xi8>) -> (tensor<1x8x?xi8>) {
+  %res = linalg.depthwise_conv_1d_nwc_wc
+    {dilations = dense<1> : vector<1xi64>,
+    strides = dense<1> : vector<1xi64>}
+    ins(%input, %filter : tensor<1x8x?xi8>, tensor<1x?xi8>)
+    outs(%output : tensor<1x8x?xi8>) -> tensor<1x8x?xi8>
+  return %res : tensor<1x8x?xi8>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [1, 8, [4], 1] : !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL:   func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_tensor(
+// CHECK-SAME:      %[[INPUT:.*]]: tensor<1x8x?xi8>,
+// CHECK-SAME:      %[[FILTER:.*]]: tensor<1x?xi8>,
+// CHECK-SAME:      %[[OUTPUT:.*]]: tensor<1x8x?xi8>) -> tensor<1x8x?xi8> {
+
+// CHECK-DAG:       arith.constant 1 : index
+// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[VAL_6:.*]] = tensor.dim %[[FILTER]], %[[VAL_5]] : tensor<1x?xi8>
+// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[C0_I8:.*]] = arith.constant 0 : i8
+
+/// Create a mask for the input tensor
+// CHECK:           %[[C2:.*]] = arith.constant 2 : index
+// CHECK:           %[[CH_DIM_SIZE_INPUT:.*]] = tensor.dim %[[INPUT]], %[[C2]] : tensor<1x8x?xi8>
+// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK:           %[[C8:.*]] = arith.constant 8 : index
+// CHECK:           %[[MASK_IN:.*]] = vector.create_mask %[[C1]], %[[C8]], %[[CH_DIM_SIZE_INPUT]] : vector<1x8x[4]xi1>
+/// Read the input tensor
+// CHECK:           %[[VEC_IN:.*]] = vector.mask %[[MASK_IN]] { vector.transfer_read %[[INPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[C0_I8]] : tensor<1x8x?xi8>, vector<1x8x[4]xi8> } : vector<1x8x[4]xi1> -> vector<1x8x[4]xi8>
+
+/// Create a mask for the filter tensor
+// CHECK:           %[[C0_I8_1:.*]] = arith.constant 0 : i8
+// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK:           %[[CH_DIM_SIZE_FLT:.*]] = tensor.dim %[[FILTER]], %[[C1]] : tensor<1x?xi8>
+// CHECK:           %[[C1_1:.*]] = arith.constant 1 : index
+// CHECK:           %[[MASK_FLT:.*]] = vector.create_mask %[[C1_1]], %[[CH_DIM_SIZE_FLT]] : vector<1x[4]xi1>
+/// Read the filter tensor
+// CHECK:           %[[VEC_FLT:.*]] = vector.mask %[[MASK_FLT]] { vector.transfer_read %[[FILTER]]{{\[}}%[[C0]], %[[C0]]], %[[C0_I8_1]] : tensor<1x?xi8>, vector<1x[4]xi8> } : vector<1x[4]xi1> -> vector<1x[4]xi8>
+
+/// Create a mask for the output tensor
+// CHECK:           %[[VAL_22:.*]] = arith.constant 0 : i8
+// CHECK:           %[[VAL_23:.*]] = arith.constant 2 : index
+// CHECK:           %[[CH_DIM_SIZE_OUT:.*]] = tensor.dim %[[OUTPUT]], %[[VAL_23]] : tensor<1x8x?xi8>
+// CHECK:           %[[VAL_25:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_26:.*]] = arith.constant 8 : index
+// CHECK:           %[[MASK_OUT:.*]] = vector.create_mask %[[VAL_25]], %[[VAL_26]], %[[CH_DIM_SIZE_OUT]] : vector<1x8x[4]xi1>
+/// Read the output tensor
+// CHECK:           %[[VEC_OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_read %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[VAL_22]] : tensor<1x8x?xi8>, vector<1x8x[4]xi8> } : vector<1x8x[4]xi1> -> vector<1x8x[4]xi8>
+
+/// Convolution
+// CHECK:           %[[VEC_IN_0:.*]] = vector.extract_strided_slice %[[VEC_IN]] {offsets = [0, 0, 0], sizes = [1, 8, 4], strides = [1, 1, 1]} : vector<1x8x[4]xi8> to vector<1x8x[4]xi8>
+// CHECK:           %[[VEC_FLT_0:.*]] = vector.extract %[[VEC_FLT]][0] : vector<[4]xi8> from vector<1x[4]xi8>
+// CHECK:           %[[VEC_OUT_0:.*]] = vector.extract_strided_slice %[[VEC_OUT]] {offsets = [0, 0, 0], sizes = [1, 8, 4], strides = [1, 1, 1]} : vector<1x8x[4]xi8> to vector<1x8x[4]xi8>
+// CHECK:           %[[FLT_B:.*]] = vector.broadcast %[[VEC_FLT_0]] : vector<[4]xi8> to vector<1x8x[4]xi8>
+// CHECK:           %[[MULI:.*]] = arith.muli %[[VEC_IN_0]], %[[FLT_B]] : vector<1x8x[4]xi8>
+// CHECK:           %[[ADDI:.*]] = arith.addi %[[MULI]], %[[VEC_OUT_0]] : vector<1x8x[4]xi8>
+// CHECK:           %[[VEC_OUT_1:.*]] = vector.insert_strided_slice %[[ADDI]], %[[VEC_OUT]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<1x8x[4]xi8> into vector<1x8x[4]xi8>
+
+/// Create a mask for the output tensor
+// CHECK:           %[[VAL_36:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_37:.*]] = tensor.dim %[[OUTPUT]], %[[VAL_36]] : tensor<1x8x?xi8>
+// CHECK:           %[[VAL_38:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_39:.*]] = arith.constant 8 : index
+// CHECK:           %[[MASK_OUT:.*]] = vector.create_mask %[[VAL_38]], %[[VAL_39]], %[[VAL_37]] : vector<1x8x[4]xi1>
+
+/// Write the output tensor
+// CHECK:           vector.mask %[[MASK_OUT]] { vector.transfer_write %[[VEC_OUT_1]], %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] : vector<1x8x[4]xi8>, tensor<1x8x?xi8> } : vector<1x8x[4]xi1> -> tensor<1x8x?xi8>
+
+
+// -----
+
+func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dillation_2(%input: memref<3x5x?xf32>,
+                                                                %filter: memref<2x?xf32>,
+                                                                %output: memref<3x2x?xf32>) {
+  linalg.depthwise_conv_1d_nwc_wc
+    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+    ins(%input, %filter : memref<3x5x?xf32>, memref<2x?xf32>)
+    outs(%output : memref<3x2x?xf32>)
+  return
+}
+
+// CHECK-LABEL:   func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dillation_2(
+// CHECK-SAME:      %[[INPUT:.*]]: memref<3x5x?xf32>,
+// CHECK-SAME:      %[[FILTER:.*]]: memref<2x?xf32>,
+// CHECK-SAME:      %[[OUTPUT:.*]]: memref<3x2x?xf32>) {
+
+// CHECK:           %[[VAL_3:.*]] = arith.constant 3 : index
+// CHECK:           %[[VAL_4:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_6:.*]] = memref.dim %[[FILTER]], %[[VAL_5]] : memref<2x?xf32>
+// CHECK:           %[[VAL_7:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_8:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
+
+/// Create a mask for the input tensor
+// CHECK:           %[[VAL_10:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_11:.*]] = memref.dim %[[INPUT]], %[[VAL_10]] : memref<3x5x?xf32>
+// CHECK:           %[[VAL_12:.*]] = arith.constant 3 : index
+// CHECK:           %[[VAL_13:.*]] = arith.constant 5 : index
+// CHECK:           %[[MASK_IN:.*]] = vector.create_mask %[[VAL_12]], %[[VAL_13]], %[[VAL_11]] : vector<3x4x[4]xi1>
+/// Read the input tensor
+// CHECK:           %[[VEC_IN:.*]] = vector.mask %[[MASK_IN]] { vector.transfer_read %[[INPUT]]{{\[}}%[[VAL_8]], %[[VAL_8]], %[[VAL_8]]], %[[VAL_9]] : memref<3x5x?xf32>, vector<3x4x[4]xf32> } : vector<3x4x[4]xi1> -> vector<3x4x[4]xf32>
+
+/// Create a mask for the filter tensor
+// CHECK:           %[[VAL_16:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_17:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_18:.*]] = memref.dim %[[FILTER]], %[[VAL_17]] : memref<2x?xf32>
+// CHECK:           %[[VAL_19:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_20:.*]] = vector.create_mask %[[VAL_19]], %[[VAL_18]] : vector<2x[4]xi1>
+/// Read the filter tensor
+// CHECK:           %[[VEC_FLT:.*]] = vector.mask %[[VAL_20]] { vector.transfer_read %[[FILTER]]{{\[}}%[[VAL_8]], %[[VAL_8]]], %[[VAL_16]] : memref<2x?xf32>, vector<2x[4]xf32> } : vector<2x[4]xi1> -> vector<2x[4]xf32>
+
+/// Create a mask for the output tensor
+// CHECK:           %[[VAL_22:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_23:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_24:.*]] = memref.dim %[[OUTPUT]], %[[VAL_23]] : memref<3x2x?xf32>
+// CHECK:           %[[VAL_25:.*]] = arith.constant 3 : index
+// CHECK:           %[[VAL_26:.*]] = arith.constant 2 : index
+// CHECK:           %[[MASK_OUT:.*]] = vector.create_mask %[[VAL_25]], %[[VAL_26]], %[[VAL_24]] : vector<3x2x[4]xi1>
+/// Read the output tensor
+// CHECK:           %[[VEC_OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_read %[[OUTPUT]]{{\[}}%[[VAL_8]], %[[VAL_8]], %[[VAL_8]]], %[[VAL_22]] : memref<3x2x?xf32>, vector<3x2x[4]xf32> } : vector<3x2x[4]xi1> -> vector<3x2x[4]xf32>
+
+/// Convolution
+// CHECK:           %[[VEC_IN_0:.*]] = vector.extract_strided_slice %[[VEC_IN]] {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x[4]xf32> to vector<3x2x[4]xf32>
+// CHECK:           %[[VEC_IN_1:.*]] = vector.extract_strided_slice %[[VEC_IN]] {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x[4]xf32> to vector<3x2x[4]xf32>
+// CHECK:           %[[VEC_FLT_0:.*]] = vector.extract %[[VEC_FLT]][0] : vector<[4]xf32> from vector<2x[4]xf32>
+// CHECK:           %[[VEC_FLT_1:.*]] = vector.extract %[[VEC_FLT]][1] : vector<[4]xf32> from vector<2x[4]xf32>
+// CHECK:           %[[VEC_OUT_0:.*]] = vector.extract_strided_slice %[[VEC_OUT]] {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x2x[4]xf32> to vector<3x2x[4]xf32>
+// CHECK:           %[[VEC_FLT_0_B:.*]] = vector.broadcast %[[VEC_FLT_0]] : vector<[4]xf32> to vector<3x2x[4]xf32>
+// CHECK:           %[[FMA_1:.*]] = vector.fma %[[VEC_IN_0]], %[[VEC_FLT_0_B]], %[[VEC_OUT_0]] : vector<3x2x[4]xf32>
+// CHECK:           %[[VEC_FLT_1_B:.*]] = vector.broadcast %[[VEC_FLT_1]] : vector<[4]xf32> to vector<3x2x[4]xf32>
+// CHECK:           %[[FMA_2:.*]] = vector.fma %[[VEC_IN_1]], %[[VEC_FLT_1_B]], %[[FMA_1]] : vector<3x2x[4]xf32>
+// CHECK:           %[[VEC_OUT_1:.*]] = vector.insert_strided_slice %[[FMA_2]], %[[VEC_OUT]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<3x2x[4]xf32> into vector<3x2x[4]xf32>
+
+/// Create a mask for the output tensor
+// CHECK:           %[[VAL_39:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_40:.*]] = memref.dim %[[OUTPUT]], %[[VAL_39]] : memref<3x2x?xf32>
+// CHECK:           %[[VAL_41:.*]] = arith.constant 3 : index
+// CHECK:           %[[VAL_42:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_43:.*]] = vector.create_mask %[[VAL_41]], %[[VAL_42]], %[[VAL_40]] : vector<3x2x[4]xi1>
+/// Write the output tensor
+// CHECK:           vector.mask %[[VAL_43]] { vector.transfer_write %[[VEC_OUT_1]], %[[OUTPUT]]{{\[}}%[[VAL_8]], %[[VAL_8]], %[[VAL_8]]] : vector<3x2x[4]xf32>, memref<3x2x?xf32> } : vector<3x2x[4]xi1>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [3, 2, [4], 2] : !transform.any_op
+    transform.yield
+  }
+}

>From d5287b27cda33540bb5a17fb92021789565d4e86 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 14 Feb 2024 22:12:34 +0000
Subject: [PATCH 2/9] fixup! [clang][Driver] Small correction to
 print-runtime-dir

Addressing Cullen's comments
---
 .../Linalg/Transforms/Vectorization.cpp       | 46 +++++++++----------
 .../Linalg/vectorization-unsupported.mlir     | 19 ++++++++
 2 files changed, 40 insertions(+), 25 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 60f4f9deeb1cda..2231aeb5dc135a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1714,12 +1714,13 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
 }
 
 static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op) {
-  // Support dynamic shapes in 1D depthwise convolution, but only in the
-  // _channel_ dimension. That's exclusively to support scalable vectorisation.
   if (auto conv = dyn_cast<linalg::DepthwiseConv1DNwcWcOp>(op.getOperation())) {
+    // Support dynamic shapes in 1D depthwise convolution, but only in the
+    // _channel_ dimension. That's exclusively to support scalable
+    // vectorisation.
     auto lhsShaped = op.getDpsInputOperand(0)->get();
     ArrayRef<int64_t> lhsShape =
-        dyn_cast<ShapedType>(lhsShaped.getType()).getShape();
+        cast<ShapedType>(lhsShaped.getType()).getShape();
     auto shapeWithoutCh = lhsShape.drop_back(1);
     if (ShapedType::isDynamicShape(shapeWithoutCh))
       return failure();
@@ -1925,7 +1926,8 @@ vectorizeScalableVectorPrecondition(Operation *op,
   if (!isScalable)
     return success();
 
-  // Only element-wise ops supported in the presence of scalable dims.
+  // Only element-wise and 1d depthwise conv ops supported in the presence of
+  // scalable dims.
   auto linalgOp = dyn_cast<LinalgOp>(op);
   return success(linalgOp && (isElementwise(linalgOp) ||
                               isa<linalg::DepthwiseConv1DNwcWcOp>(op)));
@@ -3182,7 +3184,7 @@ struct Conv1DGenerator
          //   (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
          ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
          cSize},
-        lhsEltType, {false, false, scalableChDim});
+        lhsEltType, /*scalableDims=*/{false, false, scalableChDim});
     VectorType rhsType =
         VectorType::get({kwSize, cSize}, rhsEltType,
                         /*scalableDims=*/{false, scalableChDim});
@@ -3233,23 +3235,20 @@ struct Conv1DGenerator
     // 0].
     Value lhs = rewriter.create<vector::TransferReadOp>(
         loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
-    auto maybeMaskedLHS = maybeMaskXferOp(
-        lhsType.getShape(),
-        /*scalableDims=*/{false, false, scalableChDim}, lhs.getDefiningOp());
+    auto maybeMaskedLhs = maybeMaskXferOp(
+        lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());
 
     // Read rhs slice of size {kw, c} @ [0, 0].
     Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
                                                         ValueRange{zero, zero});
-    auto maybeMaskedRHS = maybeMaskXferOp(
-        rhsType.getShape(),
-        /*scalableDims=*/{false, scalableChDim}, rhs.getDefiningOp());
+    auto maybeMaskedRhs = maybeMaskXferOp(
+        rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());
 
     // Read res slice of size {n, w, c} @ [0, 0, 0].
     Value res = rewriter.create<vector::TransferReadOp>(
         loc, resType, resShaped, ValueRange{zero, zero, zero});
-    auto maybeMaskedRES = maybeMaskXferOp(
-        resType.getShape(),
-        /*scalableDims=*/{false, false, scalableChDim}, res.getDefiningOp());
+    auto maybeMaskedRes = maybeMaskXferOp(
+        resType.getShape(), resType.getScalableDims(), res.getDefiningOp());
 
     //===------------------------------------------------------------------===//
     // Begin vector-only rewrite part
@@ -3264,7 +3263,7 @@ struct Conv1DGenerator
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       for (int64_t w = 0; w < wSize; w += wSizeStep) {
         lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
-            loc, maybeMaskedLHS->getResult(0),
+            loc, maybeMaskedLhs->getResult(0),
             /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
             inOutSliceSizes, inOutStrides));
       }
@@ -3272,13 +3271,13 @@ struct Conv1DGenerator
     // Extract rhs slice of size {c} @ [kw].
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       rhsVals.push_back(rewriter.create<vector::ExtractOp>(
-          loc, maybeMaskedRHS->getResult(0),
+          loc, maybeMaskedRhs->getResult(0),
           /*offsets=*/ArrayRef<int64_t>{kw}));
     }
     // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
     for (int64_t w = 0; w < wSize; w += wSizeStep) {
       resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
-          loc, maybeMaskedRES->getResult(0),
+          loc, maybeMaskedRes->getResult(0),
           /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
           inOutStrides));
     }
@@ -3317,7 +3316,6 @@ struct Conv1DGenerator
     // Its possible we failed to create the Fma.
     if (!llvm::all_of(resVals, [](Value v) { return v; })) {
       // Manually revert (in reverse order) to avoid leaving a bad IR state.
-      // TODO: Replace with maybeMasked
       for (auto &collection :
            {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
         for (Value v : collection)
@@ -3328,8 +3326,8 @@ struct Conv1DGenerator
     // Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
     // This does not depend on kw.
     for (int64_t w = 0; w < wSize; w += wSizeStep) {
-      maybeMaskedRES = rewriter.create<vector::InsertStridedSliceOp>(
-          loc, resVals[w], maybeMaskedRES->getResult(0),
+      maybeMaskedRes = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, resVals[w], maybeMaskedRes->getResult(0),
           /*offsets=*/ArrayRef<int64_t>{0, w, 0},
           /*strides=*/ArrayRef<int64_t>{1, 1, 1});
     }
@@ -3339,10 +3337,9 @@ struct Conv1DGenerator
 
     // Write back res slice of size {n, w, c} @ [0, 0, 0].
     Operation *resOut = rewriter.create<vector::TransferWriteOp>(
-        loc, maybeMaskedRES->getResult(0), resShaped,
+        loc, maybeMaskedRes->getResult(0), resShaped,
         ValueRange{zero, zero, zero});
-    return maybeMaskXferOp(resType.getShape(),
-                           /*scalableDims=*/{false, false, scalableChDim},
+    return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
                            resOut);
   }
 
@@ -3384,9 +3381,8 @@ struct Conv1DGenerator
     if (!lhs || !rhs)
       return nullptr;
 
-    if (isa<FloatType>(resTy.getElementType())) {
+    if (isa<FloatType>(resTy.getElementType()))
       return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);
-    }
 
     auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
     return rewriter.create<arith::AddIOp>(loc, mul, res);
diff --git a/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir b/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
index a1a52397386c97..4553dc4fbe414b 100644
--- a/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
@@ -19,6 +19,25 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+func.func @depthwise_conv1d_nwc_wc_dyn_w_dim(%input: memref<3x?x4xf32>, %filter: memref<?x4xf32>, %output: memref<3x?x4xf32>) {
+  // expected-error @+1 {{Attempted to vectorize, but failed}}
+  linalg.depthwise_conv_1d_nwc_wc
+    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+    ins(%input, %filter : memref<3x?x4xf32>, memref<?x4xf32>)
+    outs(%output : memref<3x?x4xf32>)
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [3, 2, 4, 2] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 func.func @depthwise_conv1d_nwc_wc_dyn_ch_dim(%input: memref<3x5x?xf32>, %filter: memref<2x?xf32>, %output: memref<3x2x?xf32>) {
   // expected-error @+1 {{Attempted to vectorize, but failed}}
   linalg.depthwise_conv_1d_nwc_wc

>From 0a06be8f14797fd83b1e437540cdcac09c02ea2a Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 15 Feb 2024 20:58:11 +0000
Subject: [PATCH 3/9] fixup! [mlir][linalg] Add scalable vectorisation for
 depthwise convolutions

Addressing PR comments:
- add CSE in tests, update check-lines accordingly
- add support for plain (non-scalable) masked vectorisation
- move pre-conditions for vectorisation to a dedicated hook
---
 .../Linalg/Transforms/Vectorization.cpp       |  64 +++---
 .../Linalg/vectorize-conv-scalable.mlir       | 202 ++++++++++--------
 2 files changed, 153 insertions(+), 113 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 2231aeb5dc135a..bd411244a00450 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -55,6 +55,7 @@ using namespace mlir::linalg;
 static FailureOr<Operation *>
 vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
                      ArrayRef<int64_t> inputVecSizes = {},
+                     ArrayRef<bool> inputVecScalableFlags = {},
                      bool flatten1DDepthwiseConv = false);
 
 /// Return the unique instance of OpType in `block` if it is indeed unique.
@@ -1713,21 +1714,31 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
   return success();
 }
 
-static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op) {
-  if (auto conv = dyn_cast<linalg::DepthwiseConv1DNwcWcOp>(op.getOperation())) {
-    // Support dynamic shapes in 1D depthwise convolution, but only in the
-    // _channel_ dimension. That's exclusively to support scalable
-    // vectorisation.
-    auto lhsShaped = op.getDpsInputOperand(0)->get();
-    ArrayRef<int64_t> lhsShape =
-        cast<ShapedType>(lhsShaped.getType()).getShape();
-    auto shapeWithoutCh = lhsShape.drop_back(1);
-    if (ShapedType::isDynamicShape(shapeWithoutCh))
-      return failure();
+static LogicalResult vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv) {
+  if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv.getOperation())) {
+    LDBG("Not a depth-wise 1D conv, dynamic shapes are not supported\n");
+    return failure();
+  }
 
-    return success();
+  // Support dynamic shapes in 1D depthwise convolution, but only in the
+  // _channel_ dimension. That's exclusively to support scalable
+  // vectorisation.
+  auto lhs = conv.getDpsInputOperand(0)->get();
+  ArrayRef<int64_t> lhsShape = cast<ShapedType>(lhs.getType()).getShape();
+  auto shapeWithoutCh = lhsShape.drop_back(1);
+  if (ShapedType::isDynamicShape(shapeWithoutCh)) {
+    LDBG("Dynamically-shaped op vectorization precondition failed: only "
+         "channel dim can be dynamic\n");
+    return failure();
   }
 
+  return success();
+}
+
+static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op) {
+  if (isa<ConvolutionOpInterface>(op.getOperation()))
+    return vectorizeDynamicConvOpPrecondition(op);
+
   // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
   // linalg.copy ops and ops that implement ContractionOpInterface for now.
   if (!isElementwise(op) &&
@@ -2016,7 +2027,8 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
             // inference.
             if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
               FailureOr<Operation *> convOr = vectorizeConvolution(
-                  rewriter, linalgOp, inputVectorSizes, flatten1DDepthwiseConv);
+                  rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
+                  flatten1DDepthwiseConv);
               if (succeeded(convOr)) {
                 llvm::append_range(results, (*convOr)->getResults());
                 return success();
@@ -3150,19 +3162,21 @@ struct Conv1DGenerator
   /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
   /// > 1.
   FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
+                                       bool channelDimScalableFlag,
                                        bool flatten) {
     if (!valid)
       return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
 
     bool scalableChDim = false;
+    bool useMasking = false;
     int64_t nSize, wSize, cSize, kwSize;
     // kernel{kw, c}
     bindShapeDims(rhsShapedType, kwSize, cSize);
-    // Dynamic channel size implies scalable vectorisation
     if (ShapedType::isDynamic(cSize)) {
       assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
       cSize = channelDimVecSize;
-      scalableChDim = true;
+      scalableChDim = channelDimScalableFlag;
+      useMasking = true;
     }
     // out{n, w, c}
     bindShapeDims(resShapedType, nSize, wSize);
@@ -3197,13 +3211,10 @@ struct Conv1DGenerator
     auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
                                ArrayRef<bool> scalableDims,
                                Operation *opToMask) {
-      bool scalableChDim = scalableDims.back();
-      if (!scalableChDim)
+      if (!useMasking)
         return opToMask;
-
       auto maskType =
           VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);
-
       SmallVector<OpFoldResult> mixedSourceDims =
           hasTensorSemantics
               ? TypeSwitch<Operation *, SmallVector<OpFoldResult>>(opToMask)
@@ -3479,6 +3490,7 @@ struct Conv1DGenerator
   /// Entry point that transposes into the common form:
   ///   {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
   FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
+                                             bool vecChDimScalableFlag = true,
                                              bool flatten = false) {
     AffineExpr n, w, c, kw;
     bindDims(ctx, n, w, c, kw);
@@ -3490,7 +3502,7 @@ struct Conv1DGenerator
     if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                 /*rhsIndex*/ {kw, c},
                 /*resIndex*/ {n, w, c}}))
-      return depthwiseConv(vecChDimSize, flatten);
+      return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);
 
     return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
   }
@@ -3556,10 +3568,9 @@ struct Conv1DGenerator
 
 /// Helper function to vectorize a LinalgOp with convolution semantics.
 // TODO: extend the generic vectorization to support windows and drop this.
-static FailureOr<Operation *>
-vectorizeConvolution(RewriterBase &rewriter, LinalgOp op,
-                     ArrayRef<int64_t> inputVecSizes,
-                     bool flatten1DDepthwiseConv) {
+static FailureOr<Operation *> vectorizeConvolution(
+    RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
+    ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
   // The ConvolutionOpInterface gives us guarantees of existence for
   // strides/dilations. However, we do not need to rely on those, we can
   // simply use them if present, otherwise use the default and let the generic
@@ -3586,12 +3597,15 @@ vectorizeConvolution(RewriterBase &rewriter, LinalgOp op,
     return res;
 
   uint64_t vecChDimSize = ShapedType::kDynamic;
+  bool vecChDimScalableFlag = false;
   if (!inputVecSizes.empty()) {
     // Only use the input vector size corresponding to the channel dim. Other
     // vector dims will be inferred from the Ops.
     vecChDimSize = inputVecSizes[2];
+    vecChDimScalableFlag = inputScalableVecDims[2];
   }
-  return e.generateDilatedConv(vecChDimSize, flatten1DDepthwiseConv);
+  return e.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
+                               flatten1DDepthwiseConv);
 }
 
 struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
diff --git a/mlir/test/Dialect/Linalg/vectorize-conv-scalable.mlir b/mlir/test/Dialect/Linalg/vectorize-conv-scalable.mlir
index d4b3574451c2bf..45b4d53d2a03a9 100644
--- a/mlir/test/Dialect/Linalg/vectorize-conv-scalable.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-conv-scalable.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -split-input-file -transform-interpreter %s | FileCheck %s
+// RUN: mlir-opt -split-input-file -transform-interpreter -cse %s | FileCheck %s
 
 func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_tensor(%input: tensor<1x8x?xi8>,
                                                    %filter: tensor<1x?xi8>,
@@ -14,7 +14,7 @@ func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_tensor(%input: tensor<1x8x?xi8>,
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
-    transform.structured.vectorize %0 vector_sizes [1, 8, [4], 1] : !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [1, 8, 4, 1] : !transform.any_op
     transform.yield
   }
 }
@@ -24,58 +24,102 @@ module attributes {transform.with_named_sequence} {
 // CHECK-SAME:      %[[FILTER:.*]]: tensor<1x?xi8>,
 // CHECK-SAME:      %[[OUTPUT:.*]]: tensor<1x8x?xi8>) -> tensor<1x8x?xi8> {
 
-// CHECK-DAG:       arith.constant 1 : index
-// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
-// CHECK-DAG:       %[[VAL_6:.*]] = tensor.dim %[[FILTER]], %[[VAL_5]] : tensor<1x?xi8>
-// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG:       %[[C0_I8:.*]] = arith.constant 0 : i8
+// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[PAD:.*]] = arith.constant 0 : i8
 
 /// Create a mask for the input tensor
 // CHECK:           %[[C2:.*]] = arith.constant 2 : index
-// CHECK:           %[[CH_DIM_SIZE_INPUT:.*]] = tensor.dim %[[INPUT]], %[[C2]] : tensor<1x8x?xi8>
-// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK:           %[[CH_DIM_IN:.*]] = tensor.dim %[[INPUT]], %[[C2]] : tensor<1x8x?xi8>
 // CHECK:           %[[C8:.*]] = arith.constant 8 : index
-// CHECK:           %[[MASK_IN:.*]] = vector.create_mask %[[C1]], %[[C8]], %[[CH_DIM_SIZE_INPUT]] : vector<1x8x[4]xi1>
+// CHECK:           %[[MASK_IN:.*]] = vector.create_mask %[[C1]], %[[C8]], %[[CH_DIM_IN]] : vector<1x8x4xi1>
 /// Read the input tensor
-// CHECK:           %[[VEC_IN:.*]] = vector.mask %[[MASK_IN]] { vector.transfer_read %[[INPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[C0_I8]] : tensor<1x8x?xi8>, vector<1x8x[4]xi8> } : vector<1x8x[4]xi1> -> vector<1x8x[4]xi8>
+// CHECK:           %[[VEC_IN:.*]] = vector.mask %[[MASK_IN]] { vector.transfer_read %[[INPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] : tensor<1x8x?xi8>, vector<1x8x4xi8> } : vector<1x8x4xi1> -> vector<1x8x4xi8>
 
 /// Create a mask for the filter tensor
-// CHECK:           %[[C0_I8_1:.*]] = arith.constant 0 : i8
-// CHECK:           %[[C1:.*]] = arith.constant 1 : index
-// CHECK:           %[[CH_DIM_SIZE_FLT:.*]] = tensor.dim %[[FILTER]], %[[C1]] : tensor<1x?xi8>
-// CHECK:           %[[C1_1:.*]] = arith.constant 1 : index
-// CHECK:           %[[MASK_FLT:.*]] = vector.create_mask %[[C1_1]], %[[CH_DIM_SIZE_FLT]] : vector<1x[4]xi1>
+// CHECK:           %[[CH_DIM_FLT:.*]] = tensor.dim %[[FILTER]], %[[C1]] : tensor<1x?xi8>
+// CHECK:           %[[MASK_FLT:.*]] = vector.create_mask %[[C1]], %[[CH_DIM_FLT]] : vector<1x4xi1>
 /// Read the filter tensor
-// CHECK:           %[[VEC_FLT:.*]] = vector.mask %[[MASK_FLT]] { vector.transfer_read %[[FILTER]]{{\[}}%[[C0]], %[[C0]]], %[[C0_I8_1]] : tensor<1x?xi8>, vector<1x[4]xi8> } : vector<1x[4]xi1> -> vector<1x[4]xi8>
+// CHECK:           %[[VEC_FLT:.*]] = vector.mask %[[MASK_FLT]] { vector.transfer_read %[[FILTER]]{{\[}}%[[C0]], %[[C0]]], %[[PAD]] : tensor<1x?xi8>, vector<1x4xi8> } : vector<1x4xi1> -> vector<1x4xi8>
 
 /// Create a mask for the output tensor
-// CHECK:           %[[VAL_22:.*]] = arith.constant 0 : i8
-// CHECK:           %[[VAL_23:.*]] = arith.constant 2 : index
-// CHECK:           %[[CH_DIM_SIZE_OUT:.*]] = tensor.dim %[[OUTPUT]], %[[VAL_23]] : tensor<1x8x?xi8>
-// CHECK:           %[[VAL_25:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_26:.*]] = arith.constant 8 : index
-// CHECK:           %[[MASK_OUT:.*]] = vector.create_mask %[[VAL_25]], %[[VAL_26]], %[[CH_DIM_SIZE_OUT]] : vector<1x8x[4]xi1>
-/// Read the output tensor
-// CHECK:           %[[VEC_OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_read %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[VAL_22]] : tensor<1x8x?xi8>, vector<1x8x[4]xi8> } : vector<1x8x[4]xi1> -> vector<1x8x[4]xi8>
+// CHECK:           %[[CH_DIM_OUT:.*]] = tensor.dim %[[OUTPUT]], %[[C2]] : tensor<1x8x?xi8>
+// CHECK:           %[[MASK_OUT:.*]] = vector.create_mask %[[C1]], %[[C8]], %[[CH_DIM_OUT]] : vector<1x8x4xi1>
+// CHECK:           %[[VEC_OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_read %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] : tensor<1x8x?xi8>, vector<1x8x4xi8> } : vector<1x8x4xi1> -> vector<1x8x4xi8>
 
 /// Convolution
-// CHECK:           %[[VEC_IN_0:.*]] = vector.extract_strided_slice %[[VEC_IN]] {offsets = [0, 0, 0], sizes = [1, 8, 4], strides = [1, 1, 1]} : vector<1x8x[4]xi8> to vector<1x8x[4]xi8>
-// CHECK:           %[[VEC_FLT_0:.*]] = vector.extract %[[VEC_FLT]][0] : vector<[4]xi8> from vector<1x[4]xi8>
-// CHECK:           %[[VEC_OUT_0:.*]] = vector.extract_strided_slice %[[VEC_OUT]] {offsets = [0, 0, 0], sizes = [1, 8, 4], strides = [1, 1, 1]} : vector<1x8x[4]xi8> to vector<1x8x[4]xi8>
-// CHECK:           %[[FLT_B:.*]] = vector.broadcast %[[VEC_FLT_0]] : vector<[4]xi8> to vector<1x8x[4]xi8>
-// CHECK:           %[[MULI:.*]] = arith.muli %[[VEC_IN_0]], %[[FLT_B]] : vector<1x8x[4]xi8>
-// CHECK:           %[[ADDI:.*]] = arith.addi %[[MULI]], %[[VEC_OUT_0]] : vector<1x8x[4]xi8>
-// CHECK:           %[[VEC_OUT_1:.*]] = vector.insert_strided_slice %[[ADDI]], %[[VEC_OUT]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<1x8x[4]xi8> into vector<1x8x[4]xi8>
+// CHECK:           %[[IN_1:.*]] = vector.extract_strided_slice %[[VEC_IN]] {offsets = [0, 0, 0], sizes = [1, 8, 4], strides = [1, 1, 1]} : vector<1x8x4xi8> to vector<1x8x4xi8>
+// CHECK:           %[[FLT_1:.*]] = vector.extract %[[VEC_FLT]][0] : vector<4xi8> from vector<1x4xi8>
+// CHECK:           %[[OUT_1:.*]] = vector.extract_strided_slice %[[VEC_OUT]] {offsets = [0, 0, 0], sizes = [1, 8, 4], strides = [1, 1, 1]} : vector<1x8x4xi8> to vector<1x8x4xi8>
+// CHECK:           %[[FLT_1_B:.*]] = vector.broadcast %[[FLT_1]] : vector<4xi8> to vector<1x8x4xi8>
+// CHECK:           %[[MULI:.*]] = arith.muli %[[IN_1]], %[[FLT_1_B]] : vector<1x8x4xi8>
+// CHECK:           %[[ADDI:.*]] = arith.addi %[[MULI]], %[[OUT_1]] : vector<1x8x4xi8>
+// CHECK:           %[[OUT_INS:.*]] = vector.insert_strided_slice %[[ADDI]], %[[VEC_OUT]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<1x8x4xi8> into vector<1x8x4xi8>
+// CHECK:           %[[OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_write %[[OUT_INS]], %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] : vector<1x8x4xi8>, tensor<1x8x?xi8> } : vector<1x8x4xi1> -> tensor<1x8x?xi8>
+// CHECK:           return %[[OUT]] : tensor<1x8x?xi8>
+
+// -----
+
+func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_tensor_scalable(
+      %input: tensor<1x8x?xi8>,
+      %filter: tensor<1x?xi8>,
+      %output: tensor<1x8x?xi8>) -> (tensor<1x8x?xi8>) {
+  %res = linalg.depthwise_conv_1d_nwc_wc
+    {dilations = dense<1> : vector<1xi64>,
+    strides = dense<1> : vector<1xi64>}
+    ins(%input, %filter : tensor<1x8x?xi8>, tensor<1x?xi8>)
+    outs(%output : tensor<1x8x?xi8>) -> tensor<1x8x?xi8>
+  return %res : tensor<1x8x?xi8>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [1, 8, [4], 1] : !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL:   func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_tensor_scalable(
+// CHECK-SAME:      %[[INPUT:.*]]: tensor<1x8x?xi8>,
+// CHECK-SAME:      %[[FILTER:.*]]: tensor<1x?xi8>,
+// CHECK-SAME:      %[[OUTPUT:.*]]: tensor<1x8x?xi8>) -> tensor<1x8x?xi8> {
+
+// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[PAD:.*]] = arith.constant 0 : i8
+
+/// Create a mask for the input tensor
+// CHECK:           %[[C2:.*]] = arith.constant 2 : index
+// CHECK:           %[[CH_DIM_IN:.*]] = tensor.dim %[[INPUT]], %[[C2]] : tensor<1x8x?xi8>
+// CHECK:           %[[C8:.*]] = arith.constant 8 : index
+// CHECK:           %[[MASK_IN:.*]] = vector.create_mask %[[C1]], %[[C8]], %[[CH_DIM_IN]] : vector<1x8x[4]xi1>
+/// Read the input tensor
+// CHECK:           %[[VEC_IN:.*]] = vector.mask %[[MASK_IN]] { vector.transfer_read %[[INPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] : tensor<1x8x?xi8>, vector<1x8x[4]xi8> } : vector<1x8x[4]xi1> -> vector<1x8x[4]xi8>
+
+/// Create a mask for the filter tensor
+// CHECK:           %[[CH_DIM_FLT:.*]] = tensor.dim %[[FILTER]], %[[C1]] : tensor<1x?xi8>
+// CHECK:           %[[MASK_FLT:.*]] = vector.create_mask %[[C1]], %[[CH_DIM_FLT]] : vector<1x[4]xi1>
+/// Read the filter tensor
+// CHECK:           %[[VEC_FLT:.*]] = vector.mask %[[MASK_FLT]] { vector.transfer_read %[[FILTER]]{{\[}}%[[C0]], %[[C0]]], %[[PAD]] : tensor<1x?xi8>, vector<1x[4]xi8> } : vector<1x[4]xi1> -> vector<1x[4]xi8>
 
 /// Create a mask for the output tensor
-// CHECK:           %[[VAL_36:.*]] = arith.constant 2 : index
-// CHECK:           %[[VAL_37:.*]] = tensor.dim %[[OUTPUT]], %[[VAL_36]] : tensor<1x8x?xi8>
-// CHECK:           %[[VAL_38:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_39:.*]] = arith.constant 8 : index
-// CHECK:           %[[MASK_OUT:.*]] = vector.create_mask %[[VAL_38]], %[[VAL_39]], %[[VAL_37]] : vector<1x8x[4]xi1>
+// CHECK:           %[[CH_DIM_OUT:.*]] = tensor.dim %[[OUTPUT]], %[[C2]] : tensor<1x8x?xi8>
+// CHECK:           %[[MASK_OUT:.*]] = vector.create_mask %[[C1]], %[[C8]], %[[CH_DIM_OUT]] : vector<1x8x[4]xi1>
+/// Read the output tensor
+// CHECK:           %[[VEC_OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_read %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] : tensor<1x8x?xi8>, vector<1x8x[4]xi8> } : vector<1x8x[4]xi1> -> vector<1x8x[4]xi8>
+
+/// Convolution
+// CHECK:           %[[IN_1:.*]] = vector.extract_strided_slice %[[VEC_IN]] {offsets = [0, 0, 0], sizes = [1, 8, 4], strides = [1, 1, 1]} : vector<1x8x[4]xi8> to vector<1x8x[4]xi8>
+// CHECK:           %[[FLT_1:.*]] = vector.extract %[[VEC_FLT]][0] : vector<[4]xi8> from vector<1x[4]xi8>
+// CHECK:           %[[OUT_1:.*]] = vector.extract_strided_slice %[[VEC_OUT]] {offsets = [0, 0, 0], sizes = [1, 8, 4], strides = [1, 1, 1]} : vector<1x8x[4]xi8> to vector<1x8x[4]xi8>
+// CHECK:           %[[FLT_1_B:.*]] = vector.broadcast %[[FLT_1]] : vector<[4]xi8> to vector<1x8x[4]xi8>
+// CHECK:           %[[MULI:.*]] = arith.muli %[[IN_1]], %[[FLT_1_B]] : vector<1x8x[4]xi8>
+// CHECK:           %[[ADDI:.*]] = arith.addi %[[MULI]], %[[OUT_1]] : vector<1x8x[4]xi8>
+// CHECK:           %[[OUT_INS:.*]] = vector.insert_strided_slice %[[ADDI]], %[[VEC_OUT]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<1x8x[4]xi8> into vector<1x8x[4]xi8>
+// CHECK:           %[[OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_write %[[OUT_INS]], %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] : vector<1x8x[4]xi8>, tensor<1x8x?xi8> } : vector<1x8x[4]xi1> -> tensor<1x8x?xi8>
+// CHECK:           return %[[OUT]] : tensor<1x8x?xi8>
 
-/// Write the output tensor
-// CHECK:           vector.mask %[[MASK_OUT]] { vector.transfer_write %[[VEC_OUT_1]], %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] : vector<1x8x[4]xi8>, tensor<1x8x?xi8> } : vector<1x8x[4]xi1> -> tensor<1x8x?xi8>
 
 
 // -----
@@ -90,67 +134,49 @@ func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dillation_2(%input: memref<3
   return
 }
 
+// TODO: Replace the auto-generated VAL_N names in the CHECK lines below with descriptive ones
 // CHECK-LABEL:   func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dillation_2(
-// CHECK-SAME:      %[[INPUT:.*]]: memref<3x5x?xf32>,
-// CHECK-SAME:      %[[FILTER:.*]]: memref<2x?xf32>,
-// CHECK-SAME:      %[[OUTPUT:.*]]: memref<3x2x?xf32>) {
-
-// CHECK:           %[[VAL_3:.*]] = arith.constant 3 : index
-// CHECK:           %[[VAL_4:.*]] = arith.constant 2 : index
-// CHECK:           %[[VAL_5:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_6:.*]] = memref.dim %[[FILTER]], %[[VAL_5]] : memref<2x?xf32>
-// CHECK:           %[[VAL_7:.*]] = arith.constant 2 : index
-// CHECK:           %[[VAL_8:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-SAME:     %[[VAL_0:.*]]: memref<3x5x?xf32>,
+// CHECK-SAME:     %[[VAL_1:.*]]: memref<2x?xf32>,
+// CHECK-SAME:     %[[VAL_2:.*]]: memref<3x2x?xf32>) {
+
+// CHECK:           %[[VAL_3:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_5:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_6:.*]] = arith.constant 2 : index
 
 /// Create a mask for the input tensor
-// CHECK:           %[[VAL_10:.*]] = arith.constant 2 : index
-// CHECK:           %[[VAL_11:.*]] = memref.dim %[[INPUT]], %[[VAL_10]] : memref<3x5x?xf32>
-// CHECK:           %[[VAL_12:.*]] = arith.constant 3 : index
-// CHECK:           %[[VAL_13:.*]] = arith.constant 5 : index
-// CHECK:           %[[MASK_IN:.*]] = vector.create_mask %[[VAL_12]], %[[VAL_13]], %[[VAL_11]] : vector<3x4x[4]xi1>
+// CHECK:           %[[VAL_7:.*]] = memref.dim %[[VAL_0]], %[[VAL_6]] : memref<3x5x?xf32>
+// CHECK:           %[[VAL_8:.*]] = arith.constant 3 : index
+// CHECK:           %[[VAL_9:.*]] = arith.constant 5 : index
+// CHECK:           %[[VAL_10:.*]] = vector.create_mask %[[VAL_8]], %[[VAL_9]], %[[VAL_7]] : vector<3x4x[4]xi1>
 /// Read the input tensor
-// CHECK:           %[[VEC_IN:.*]] = vector.mask %[[MASK_IN]] { vector.transfer_read %[[INPUT]]{{\[}}%[[VAL_8]], %[[VAL_8]], %[[VAL_8]]], %[[VAL_9]] : memref<3x5x?xf32>, vector<3x4x[4]xf32> } : vector<3x4x[4]xi1> -> vector<3x4x[4]xf32>
+// CHECK:           %[[VAL_11:.*]] = vector.mask %[[VAL_10]] { vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_4]], %[[VAL_4]], %[[VAL_4]]], %[[VAL_5]] : memref<3x5x?xf32>, vector<3x4x[4]xf32> } : vector<3x4x[4]xi1> -> vector<3x4x[4]xf32>
 
 /// Create a mask for the filter tensor
-// CHECK:           %[[VAL_16:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:           %[[VAL_17:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_18:.*]] = memref.dim %[[FILTER]], %[[VAL_17]] : memref<2x?xf32>
-// CHECK:           %[[VAL_19:.*]] = arith.constant 2 : index
-// CHECK:           %[[VAL_20:.*]] = vector.create_mask %[[VAL_19]], %[[VAL_18]] : vector<2x[4]xi1>
+// CHECK:           %[[VAL_12:.*]] = memref.dim %[[VAL_1]], %[[VAL_3]] : memref<2x?xf32>
+// CHECK:           %[[VAL_13:.*]] = vector.create_mask %[[VAL_6]], %[[VAL_12]] : vector<2x[4]xi1>
 /// Read the filter tensor
-// CHECK:           %[[VEC_FLT:.*]] = vector.mask %[[VAL_20]] { vector.transfer_read %[[FILTER]]{{\[}}%[[VAL_8]], %[[VAL_8]]], %[[VAL_16]] : memref<2x?xf32>, vector<2x[4]xf32> } : vector<2x[4]xi1> -> vector<2x[4]xf32>
+// CHECK:           %[[VAL_14:.*]] = vector.mask %[[VAL_13]] { vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_4]], %[[VAL_4]]], %[[VAL_5]] : memref<2x?xf32>, vector<2x[4]xf32> } : vector<2x[4]xi1> -> vector<2x[4]xf32>
 
 /// Create a mask for the output tensor
-// CHECK:           %[[VAL_22:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:           %[[VAL_23:.*]] = arith.constant 2 : index
-// CHECK:           %[[VAL_24:.*]] = memref.dim %[[OUTPUT]], %[[VAL_23]] : memref<3x2x?xf32>
-// CHECK:           %[[VAL_25:.*]] = arith.constant 3 : index
-// CHECK:           %[[VAL_26:.*]] = arith.constant 2 : index
-// CHECK:           %[[MASK_OUT:.*]] = vector.create_mask %[[VAL_25]], %[[VAL_26]], %[[VAL_24]] : vector<3x2x[4]xi1>
+// CHECK:           %[[VAL_15:.*]] = memref.dim %[[VAL_2]], %[[VAL_6]] : memref<3x2x?xf32>
+// CHECK:           %[[VAL_16:.*]] = vector.create_mask %[[VAL_8]], %[[VAL_6]], %[[VAL_15]] : vector<3x2x[4]xi1>
 /// Read the output tensor
-// CHECK:           %[[VEC_OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_read %[[OUTPUT]]{{\[}}%[[VAL_8]], %[[VAL_8]], %[[VAL_8]]], %[[VAL_22]] : memref<3x2x?xf32>, vector<3x2x[4]xf32> } : vector<3x2x[4]xi1> -> vector<3x2x[4]xf32>
+// CHECK:           %[[VAL_17:.*]] = vector.mask %[[VAL_16]] { vector.transfer_read %[[VAL_2]]{{\[}}%[[VAL_4]], %[[VAL_4]], %[[VAL_4]]], %[[VAL_5]] : memref<3x2x?xf32>, vector<3x2x[4]xf32> } : vector<3x2x[4]xi1> -> vector<3x2x[4]xf32>
 
 /// Convolution
-// CHECK:           %[[VEC_IN_0:.*]] = vector.extract_strided_slice %[[VEC_IN]] {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x[4]xf32> to vector<3x2x[4]xf32>
-// CHECK:           %[[VEC_IN_1:.*]] = vector.extract_strided_slice %[[VEC_IN]] {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x[4]xf32> to vector<3x2x[4]xf32>
-// CHECK:           %[[VEC_FLT_0:.*]] = vector.extract %[[VEC_FLT]][0] : vector<[4]xf32> from vector<2x[4]xf32>
-// CHECK:           %[[VEC_FLT_1:.*]] = vector.extract %[[VEC_FLT]][1] : vector<[4]xf32> from vector<2x[4]xf32>
-// CHECK:           %[[VEC_OUT_0:.*]] = vector.extract_strided_slice %[[VEC_OUT]] {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x2x[4]xf32> to vector<3x2x[4]xf32>
-// CHECK:           %[[VEC_FLT_0_B:.*]] = vector.broadcast %[[VEC_FLT_0]] : vector<[4]xf32> to vector<3x2x[4]xf32>
-// CHECK:           %[[FMA_1:.*]] = vector.fma %[[VEC_IN_0]], %[[VEC_FLT_0_B]], %[[VEC_OUT_0]] : vector<3x2x[4]xf32>
-// CHECK:           %[[VEC_FLT_1_B:.*]] = vector.broadcast %[[VEC_FLT_1]] : vector<[4]xf32> to vector<3x2x[4]xf32>
-// CHECK:           %[[FMA_2:.*]] = vector.fma %[[VEC_IN_1]], %[[VEC_FLT_1_B]], %[[FMA_1]] : vector<3x2x[4]xf32>
-// CHECK:           %[[VEC_OUT_1:.*]] = vector.insert_strided_slice %[[FMA_2]], %[[VEC_OUT]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<3x2x[4]xf32> into vector<3x2x[4]xf32>
-
-/// Create a mask for the output tensor
-// CHECK:           %[[VAL_39:.*]] = arith.constant 2 : index
-// CHECK:           %[[VAL_40:.*]] = memref.dim %[[OUTPUT]], %[[VAL_39]] : memref<3x2x?xf32>
-// CHECK:           %[[VAL_41:.*]] = arith.constant 3 : index
-// CHECK:           %[[VAL_42:.*]] = arith.constant 2 : index
-// CHECK:           %[[VAL_43:.*]] = vector.create_mask %[[VAL_41]], %[[VAL_42]], %[[VAL_40]] : vector<3x2x[4]xi1>
-/// Write the output tensor
-// CHECK:           vector.mask %[[VAL_43]] { vector.transfer_write %[[VEC_OUT_1]], %[[OUTPUT]]{{\[}}%[[VAL_8]], %[[VAL_8]], %[[VAL_8]]] : vector<3x2x[4]xf32>, memref<3x2x?xf32> } : vector<3x2x[4]xi1>
+// CHECK:           %[[VAL_18:.*]] = vector.extract_strided_slice %[[VAL_11]] {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x[4]xf32> to vector<3x2x[4]xf32>
+// CHECK:           %[[VAL_19:.*]] = vector.extract_strided_slice %[[VAL_11]] {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x[4]xf32> to vector<3x2x[4]xf32>
+// CHECK:           %[[VAL_20:.*]] = vector.extract %[[VAL_14]][0] : vector<[4]xf32> from vector<2x[4]xf32>
+// CHECK:           %[[VAL_21:.*]] = vector.extract %[[VAL_14]][1] : vector<[4]xf32> from vector<2x[4]xf32>
+// CHECK:           %[[VAL_22:.*]] = vector.extract_strided_slice %[[VAL_17]] {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x2x[4]xf32> to vector<3x2x[4]xf32>
+// CHECK:           %[[VAL_23:.*]] = vector.broadcast %[[VAL_20]] : vector<[4]xf32> to vector<3x2x[4]xf32>
+// CHECK:           %[[VAL_24:.*]] = vector.fma %[[VAL_18]], %[[VAL_23]], %[[VAL_22]] : vector<3x2x[4]xf32>
+// CHECK:           %[[VAL_25:.*]] = vector.broadcast %[[VAL_21]] : vector<[4]xf32> to vector<3x2x[4]xf32>
+// CHECK:           %[[VAL_26:.*]] = vector.fma %[[VAL_19]], %[[VAL_25]], %[[VAL_24]] : vector<3x2x[4]xf32>
+// CHECK:           %[[VAL_27:.*]] = vector.insert_strided_slice %[[VAL_26]], %[[VAL_17]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<3x2x[4]xf32> into vector<3x2x[4]xf32>
+// CHECK:           vector.mask %[[VAL_16]] { vector.transfer_write %[[VAL_27]], %[[VAL_2]]{{\[}}%[[VAL_4]], %[[VAL_4]], %[[VAL_4]]] : vector<3x2x[4]xf32>, memref<3x2x?xf32> } : vector<3x2x[4]xi1>
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {

>From f9962706b8e3d6bf44bf33286eb45ed92a8d0f58 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 5 Mar 2024 18:02:20 +0000
Subject: [PATCH 4/9] fixup! fixup! [mlir][linalg] Add masked vectorisation for
 depthwise convolutions

Address comments from Ben and Crefeda
---
 mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp  | 9 +++++----
 mlir/test/Dialect/Linalg/vectorize-conv-scalable.mlir | 4 ++--
 2 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index bd411244a00450..a782c70d58732c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2857,7 +2857,6 @@ struct Conv1DGenerator
         return;
       break;
     }
-    hasTensorSemantics = linalgOp.hasPureTensorSemantics();
     // The op is now known to be valid.
     valid = true;
   }
@@ -3175,6 +3174,9 @@ struct Conv1DGenerator
     if (ShapedType::isDynamic(cSize)) {
       assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
       cSize = channelDimVecSize;
+      // Scalable vectors are only used when both conditions are met:
+      //  1. channel dim is dynamic
+      //  2. channelDimScalableFlag is set
       scalableChDim = channelDimScalableFlag;
       useMasking = true;
     }
@@ -3216,7 +3218,7 @@ struct Conv1DGenerator
       auto maskType =
           VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);
       SmallVector<OpFoldResult> mixedSourceDims =
-          hasTensorSemantics
+          cast<LinalgOp>(op).hasPureTensorSemantics()
               ? TypeSwitch<Operation *, SmallVector<OpFoldResult>>(opToMask)
                     .Case<vector::TransferReadOp>([&](auto readOp) {
                       return tensor::getMixedSizes(rewriter, loc,
@@ -3490,7 +3492,7 @@ struct Conv1DGenerator
   /// Entry point that transposes into the common form:
   ///   {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
   FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
-                                             bool vecChDimScalableFlag = true,
+                                             bool vecChDimScalableFlag = false,
                                              bool flatten = false) {
     AffineExpr n, w, c, kw;
     bindDims(ctx, n, w, c, kw);
@@ -3514,7 +3516,6 @@ struct Conv1DGenerator
   StringAttr redOp;
   StringAttr poolExtOp;
   bool isPoolExt = false;
-  bool hasTensorSemantics = false;
   int strideW, dilationW;
   Value lhsShaped, rhsShaped, resShaped;
   ShapedType lhsShapedType, rhsShapedType, resShapedType;
diff --git a/mlir/test/Dialect/Linalg/vectorize-conv-scalable.mlir b/mlir/test/Dialect/Linalg/vectorize-conv-scalable.mlir
index 45b4d53d2a03a9..95d7886c9a22ea 100644
--- a/mlir/test/Dialect/Linalg/vectorize-conv-scalable.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-conv-scalable.mlir
@@ -124,7 +124,7 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dillation_2(%input: memref<3x5x?xf32>,
+func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dilation_2(%input: memref<3x5x?xf32>,
                                                                 %filter: memref<2x?xf32>,
                                                                 %output: memref<3x2x?xf32>) {
   linalg.depthwise_conv_1d_nwc_wc
@@ -135,7 +135,7 @@ func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dillation_2(%input: memref<3
 }
 
 // TODO - nice variable names
-// CHECK-LABEL:   func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dillation_2(
+// CHECK-LABEL:   func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dilation_2(
 // CHECK-SAME:     %[[VAL_0:.*]]: memref<3x5x?xf32>,
 // CHECK-SAME:     %[[VAL_1:.*]]: memref<2x?xf32>,
 // CHECK-SAME:     %[[VAL_2:.*]]: memref<3x2x?xf32>) {

>From 7fe623b87ab33bdfca0acd8e0ecdf3da41b44faf Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 6 Mar 2024 07:55:10 +0000
Subject: [PATCH 5/9] fixup! [mlir][linalg] Add masked vectorisation for
 depthwise convolutions

Better LIT var names in tests
---
 .../Linalg/vectorize-conv-scalable.mlir       | 66 +++++++++----------
 1 file changed, 33 insertions(+), 33 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/vectorize-conv-scalable.mlir b/mlir/test/Dialect/Linalg/vectorize-conv-scalable.mlir
index 95d7886c9a22ea..bb337b13689fa7 100644
--- a/mlir/test/Dialect/Linalg/vectorize-conv-scalable.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-conv-scalable.mlir
@@ -124,9 +124,10 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dilation_2(%input: memref<3x5x?xf32>,
-                                                                %filter: memref<2x?xf32>,
-                                                                %output: memref<3x2x?xf32>) {
+func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dilation_2(
+      %input: memref<3x5x?xf32>,
+      %filter: memref<2x?xf32>,
+      %output: memref<3x2x?xf32>) {
   linalg.depthwise_conv_1d_nwc_wc
     {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
     ins(%input, %filter : memref<3x5x?xf32>, memref<2x?xf32>)
@@ -134,49 +135,48 @@ func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dilation_2(%input: memref<3x
   return
 }
 
-// TODO - nice variable names
 // CHECK-LABEL:   func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dilation_2(
-// CHECK-SAME:     %[[VAL_0:.*]]: memref<3x5x?xf32>,
-// CHECK-SAME:     %[[VAL_1:.*]]: memref<2x?xf32>,
-// CHECK-SAME:     %[[VAL_2:.*]]: memref<3x2x?xf32>) {
+// CHECK-SAME:      %[[INPUT:.*]]: memref<3x5x?xf32>,
+// CHECK-SAME:      %[[FILTER:.*]]: memref<2x?xf32>,
+// CHECK-SAME:      %[[OUTPUT:.*]]: memref<3x2x?xf32>) {
 
-// CHECK:           %[[VAL_3:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_4:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_5:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:           %[[VAL_6:.*]] = arith.constant 2 : index
+// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[C2:.*]] = arith.constant 2 : index
 
 /// Create a mask for the input tensor
-// CHECK:           %[[VAL_7:.*]] = memref.dim %[[VAL_0]], %[[VAL_6]] : memref<3x5x?xf32>
-// CHECK:           %[[VAL_8:.*]] = arith.constant 3 : index
-// CHECK:           %[[VAL_9:.*]] = arith.constant 5 : index
-// CHECK:           %[[VAL_10:.*]] = vector.create_mask %[[VAL_8]], %[[VAL_9]], %[[VAL_7]] : vector<3x4x[4]xi1>
+// CHECK:           %[[CH_DIM_IN:.*]] = memref.dim %[[INPUT]], %[[C2]] : memref<3x5x?xf32>
+// CHECK:           %[[C3:.*]] = arith.constant 3 : index
+// CHECK:           %[[C5:.*]] = arith.constant 5 : index
+// CHECK:           %[[MASK_IN:.*]] = vector.create_mask %[[C3]], %[[C5]], %[[CH_DIM_IN]] : vector<3x4x[4]xi1>
 /// Read the input tensor
-// CHECK:           %[[VAL_11:.*]] = vector.mask %[[VAL_10]] { vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_4]], %[[VAL_4]], %[[VAL_4]]], %[[VAL_5]] : memref<3x5x?xf32>, vector<3x4x[4]xf32> } : vector<3x4x[4]xi1> -> vector<3x4x[4]xf32>
+// CHECK:           %[[VEC_IN:.*]] = vector.mask %[[MASK_IN]] { vector.transfer_read %[[INPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] : memref<3x5x?xf32>, vector<3x4x[4]xf32> } : vector<3x4x[4]xi1> -> vector<3x4x[4]xf32>
 
 /// Create a mask for the filter tensor
-// CHECK:           %[[VAL_12:.*]] = memref.dim %[[VAL_1]], %[[VAL_3]] : memref<2x?xf32>
-// CHECK:           %[[VAL_13:.*]] = vector.create_mask %[[VAL_6]], %[[VAL_12]] : vector<2x[4]xi1>
+// CHECK:           %[[CH_DIM_FLT:.*]] = memref.dim %[[FILTER]], %[[C1]] : memref<2x?xf32>
+// CHECK:           %[[MASK_FLT:.*]] = vector.create_mask %[[C2]], %[[CH_DIM_FLT]] : vector<2x[4]xi1>
 /// Read the filter tensor
-// CHECK:           %[[VAL_14:.*]] = vector.mask %[[VAL_13]] { vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_4]], %[[VAL_4]]], %[[VAL_5]] : memref<2x?xf32>, vector<2x[4]xf32> } : vector<2x[4]xi1> -> vector<2x[4]xf32>
+// CHECK:           %[[VEC_FLT:.*]] = vector.mask %[[MASK_FLT]] { vector.transfer_read %[[FILTER]]{{\[}}%[[C0]], %[[C0]]], %[[PAD]] : memref<2x?xf32>, vector<2x[4]xf32> } : vector<2x[4]xi1> -> vector<2x[4]xf32>
 
 /// Create a mask for the output tensor
-// CHECK:           %[[VAL_15:.*]] = memref.dim %[[VAL_2]], %[[VAL_6]] : memref<3x2x?xf32>
-// CHECK:           %[[VAL_16:.*]] = vector.create_mask %[[VAL_8]], %[[VAL_6]], %[[VAL_15]] : vector<3x2x[4]xi1>
+// CHECK:           %[[CH_DIM_OUT:.*]] = memref.dim %[[OUTPUT]], %[[C2]] : memref<3x2x?xf32>
+// CHECK:           %[[MASK_OUT:.*]] = vector.create_mask %[[C3]], %[[C2]], %[[CH_DIM_OUT]] : vector<3x2x[4]xi1>
 /// Read the output tensor
-// CHECK:           %[[VAL_17:.*]] = vector.mask %[[VAL_16]] { vector.transfer_read %[[VAL_2]]{{\[}}%[[VAL_4]], %[[VAL_4]], %[[VAL_4]]], %[[VAL_5]] : memref<3x2x?xf32>, vector<3x2x[4]xf32> } : vector<3x2x[4]xi1> -> vector<3x2x[4]xf32>
+// CHECK:           %[[VEC_OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_read %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] : memref<3x2x?xf32>, vector<3x2x[4]xf32> } : vector<3x2x[4]xi1> -> vector<3x2x[4]xf32>
 
 /// Convolution
-// CHECK:           %[[VAL_18:.*]] = vector.extract_strided_slice %[[VAL_11]] {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x[4]xf32> to vector<3x2x[4]xf32>
-// CHECK:           %[[VAL_19:.*]] = vector.extract_strided_slice %[[VAL_11]] {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x[4]xf32> to vector<3x2x[4]xf32>
-// CHECK:           %[[VAL_20:.*]] = vector.extract %[[VAL_14]][0] : vector<[4]xf32> from vector<2x[4]xf32>
-// CHECK:           %[[VAL_21:.*]] = vector.extract %[[VAL_14]][1] : vector<[4]xf32> from vector<2x[4]xf32>
-// CHECK:           %[[VAL_22:.*]] = vector.extract_strided_slice %[[VAL_17]] {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x2x[4]xf32> to vector<3x2x[4]xf32>
-// CHECK:           %[[VAL_23:.*]] = vector.broadcast %[[VAL_20]] : vector<[4]xf32> to vector<3x2x[4]xf32>
-// CHECK:           %[[VAL_24:.*]] = vector.fma %[[VAL_18]], %[[VAL_23]], %[[VAL_22]] : vector<3x2x[4]xf32>
-// CHECK:           %[[VAL_25:.*]] = vector.broadcast %[[VAL_21]] : vector<[4]xf32> to vector<3x2x[4]xf32>
-// CHECK:           %[[VAL_26:.*]] = vector.fma %[[VAL_19]], %[[VAL_25]], %[[VAL_24]] : vector<3x2x[4]xf32>
-// CHECK:           %[[VAL_27:.*]] = vector.insert_strided_slice %[[VAL_26]], %[[VAL_17]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<3x2x[4]xf32> into vector<3x2x[4]xf32>
-// CHECK:           vector.mask %[[VAL_16]] { vector.transfer_write %[[VAL_27]], %[[VAL_2]]{{\[}}%[[VAL_4]], %[[VAL_4]], %[[VAL_4]]] : vector<3x2x[4]xf32>, memref<3x2x?xf32> } : vector<3x2x[4]xi1>
+// CHECK:           %[[IN_1:.*]] = vector.extract_strided_slice %[[VEC_IN]] {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x[4]xf32> to vector<3x2x[4]xf32>
+// CHECK:           %[[IN_2:.*]] = vector.extract_strided_slice %[[VEC_IN]] {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x[4]xf32> to vector<3x2x[4]xf32>
+// CHECK:           %[[FLT_1:.*]] = vector.extract %[[VEC_FLT]][0] : vector<[4]xf32> from vector<2x[4]xf32>
+// CHECK:           %[[FLT_2:.*]] = vector.extract %[[VEC_FLT]][1] : vector<[4]xf32> from vector<2x[4]xf32>
+// CHECK:           %[[OUT_1:.*]] = vector.extract_strided_slice %[[VEC_OUT]] {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x2x[4]xf32> to vector<3x2x[4]xf32>
+// CHECK:           %[[FLT_1_B:.*]] = vector.broadcast %[[FLT_1]] : vector<[4]xf32> to vector<3x2x[4]xf32>
+// CHECK:           %[[FMA_1:.*]] = vector.fma %[[IN_1]], %[[FLT_1_B]], %[[OUT_1]] : vector<3x2x[4]xf32>
+// CHECK:           %[[FLT_2_B:.*]] = vector.broadcast %[[FLT_2]] : vector<[4]xf32> to vector<3x2x[4]xf32>
+// CHECK:           %[[FMA_2:.*]] = vector.fma %[[IN_2]], %[[FLT_2_B]], %[[FMA_1]] : vector<3x2x[4]xf32>
+// CHECK:           %[[OUT_INS:.*]] = vector.insert_strided_slice %[[FMA_2]], %[[VEC_OUT]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<3x2x[4]xf32> into vector<3x2x[4]xf32>
+// CHECK:           vector.mask %[[MASK_OUT]] { vector.transfer_write %[[OUT_INS]], %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] : vector<3x2x[4]xf32>, memref<3x2x?xf32> } : vector<3x2x[4]xi1>
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {

>From 5d87b8972dd947674ef565b65c4e2e12913f6c49 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 6 Mar 2024 15:50:03 +0000
Subject: [PATCH 6/9] fixup! [mlir][linalg] Add masked vectorisation for
 depthwise convolutions

More documentation, some simplification (as per Cullen's comments)
---
 .../Dialect/Linalg/Transforms/Vectorization.cpp  |  5 ++++-
 .../Linalg/vectorization-unsupported.mlir        |  2 --
 ...r => vectorize-conv-masked-and-scalable.mlir} | 16 ++++++++--------
 3 files changed, 12 insertions(+), 11 deletions(-)
 rename mlir/test/Dialect/Linalg/{vectorize-conv-scalable.mlir => vectorize-conv-masked-and-scalable.mlir} (100%)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index a782c70d58732c..bcb8dd2f92e594 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1715,7 +1715,7 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
 }
 
 static LogicalResult vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv) {
-  if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv.getOperation())) {
+  if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
     LDBG("Not a depth-wise 1D conv, dynamic shapes are not supported\n");
     return failure();
   }
@@ -3597,6 +3597,9 @@ static FailureOr<Operation *> vectorizeConvolution(
   if (succeeded(res))
     return res;
 
+  // Only depthwise 1D NWC convs are left - these can be vectorized using masks
+  // and scalable vectors. Note that ATM the only dim that can be dynamic (i.e.
+  // masked/scalable) is the channel dim (i.e. the trailing dim).
   uint64_t vecChDimSize = ShapedType::kDynamic;
   bool vecChDimScalableFlag = false;
   if (!inputVecSizes.empty()) {
diff --git a/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir b/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
index 4553dc4fbe414b..212cba2569e9ce 100644
--- a/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
@@ -41,7 +41,6 @@ module attributes {transform.with_named_sequence} {
 func.func @depthwise_conv1d_nwc_wc_dyn_ch_dim(%input: memref<3x5x?xf32>, %filter: memref<2x?xf32>, %output: memref<3x2x?xf32>) {
   // expected-error @+1 {{Attempted to vectorize, but failed}}
   linalg.depthwise_conv_1d_nwc_wc
-    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
     ins(%input, %filter : memref<3x5x?xf32>, memref<2x?xf32>)
     outs(%output : memref<3x2x?xf32>)
   return
@@ -60,7 +59,6 @@ module attributes {transform.with_named_sequence} {
 func.func @depthwise_conv1d_nwc_wc_dyn_w_dim(%input: memref<3x?x3xf32>, %filter: memref<2x3xf32>, %output: memref<3x?x3xf32>) {
   // expected-error @+1 {{Attempted to vectorize, but failed}}
   linalg.depthwise_conv_1d_nwc_wc
-    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
     ins(%input, %filter : memref<3x?x3xf32>, memref<2x3xf32>)
     outs(%output : memref<3x?x3xf32>)
   return
diff --git a/mlir/test/Dialect/Linalg/vectorize-conv-scalable.mlir b/mlir/test/Dialect/Linalg/vectorize-conv-masked-and-scalable.mlir
similarity index 100%
rename from mlir/test/Dialect/Linalg/vectorize-conv-scalable.mlir
rename to mlir/test/Dialect/Linalg/vectorize-conv-masked-and-scalable.mlir
index bb337b13689fa7..27ca7785780db3 100644
--- a/mlir/test/Dialect/Linalg/vectorize-conv-scalable.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-conv-masked-and-scalable.mlir
@@ -135,6 +135,14 @@ func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dilation_2(
   return
 }
 
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [3, 2, [4], 2] : !transform.any_op
+    transform.yield
+  }
+}
+
 // CHECK-LABEL:   func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dilation_2(
 // CHECK-SAME:      %[[INPUT:.*]]: memref<3x5x?xf32>,
 // CHECK-SAME:      %[[FILTER:.*]]: memref<2x?xf32>,
@@ -177,11 +185,3 @@ func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dilation_2(
 // CHECK:           %[[FMA_2:.*]] = vector.fma %[[IN_2]], %[[FLT_2_B]], %[[FMA_1]] : vector<3x2x[4]xf32>
 // CHECK:           %[[OUT_INS:.*]] = vector.insert_strided_slice %[[FMA_2]], %[[VEC_OUT]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<3x2x[4]xf32> into vector<3x2x[4]xf32>
 // CHECK:           vector.mask %[[MASK_OUT]] { vector.transfer_write %[[OUT_INS]], %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] : vector<3x2x[4]xf32>, memref<3x2x?xf32> } : vector<3x2x[4]xi1>
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
-    transform.structured.vectorize %0 vector_sizes [3, 2, [4], 2] : !transform.any_op
-    transform.yield
-  }
-}

>From 11f54be26210f3469598d0451b7845873d9ea685 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 7 Mar 2024 18:43:55 +0000
Subject: [PATCH 7/9] fixup! [mlir][linalg] Add masked vectorisation for
 depthwise convolutions

Address Diego's comments, move getMixedSizesXfer to vector utils
---
 .../mlir/Dialect/Vector/Utils/VectorUtils.h   | 14 +++++++++
 .../Linalg/Transforms/Vectorization.cpp       | 31 +++++--------------
 mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 17 ++++++++++
 3 files changed, 38 insertions(+), 24 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index f6b03a0f2c8007..c5e35a831334fa 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -10,11 +10,14 @@
 #define MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
 
 #include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/Support/LLVM.h"
 
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 namespace mlir {
 
@@ -98,6 +101,17 @@ bool isContiguousSlice(MemRefType memrefType, VectorType vectorType);
 std::optional<StaticTileOffsetRange>
 createUnrollIterator(VectorType vType, int64_t targetRank = 1);
 
+/// A wrapper for getMixedSizes for vector.transfer_read and
+/// vector.transfer_write Ops (for source and destination, respectively).
+///
+/// Tensor and MemRef types implement their own, very similar version of
+/// getMixedSizes. This method will call the appropriate version (depending on
+/// `hasTensorSemantics`). It will also automatically extract the operand for
+/// which to call it on (source for "read" and destination for "write" ops).
+SmallVector<OpFoldResult> getMixedSizesXfer(bool hasTensorSemantics,
+                                            Operation *xfer,
+                                            RewriterBase &rewriter);
+
 } // namespace vector
 
 /// Constructs a permutation map of invariant memref indices to vector
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index bcb8dd2f92e594..d44b2ea95aaa76 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -24,6 +24,7 @@
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
@@ -1721,9 +1722,8 @@ static LogicalResult vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv) {
   }
 
   // Support dynamic shapes in 1D depthwise convolution, but only in the
-  // _channel_ dimension. That's exclusively to support scalable
-  // vectorisation.
-  auto lhs = conv.getDpsInputOperand(0)->get();
+  // _channel_ dimension.
+  Value lhs = conv.getDpsInputOperand(0)->get();
   ArrayRef<int64_t> lhsShape = cast<ShapedType>(lhs.getType()).getShape();
   auto shapeWithoutCh = lhsShape.drop_back(1);
   if (ShapedType::isDynamicShape(shapeWithoutCh)) {
@@ -3217,29 +3217,12 @@ struct Conv1DGenerator
         return opToMask;
       auto maskType =
           VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);
-      SmallVector<OpFoldResult> mixedSourceDims =
-          cast<LinalgOp>(op).hasPureTensorSemantics()
-              ? TypeSwitch<Operation *, SmallVector<OpFoldResult>>(opToMask)
-                    .Case<vector::TransferReadOp>([&](auto readOp) {
-                      return tensor::getMixedSizes(rewriter, loc,
-                                                   readOp.getSource());
-                    })
-                    .Case<vector::TransferWriteOp>([&](auto writeOp) {
-                      return tensor::getMixedSizes(rewriter, loc,
-                                                   writeOp.getOperand(1));
-                    })
-              : TypeSwitch<Operation *, SmallVector<OpFoldResult>>(opToMask)
-                    .Case<vector::TransferReadOp>([&](auto readOp) {
-                      return memref::getMixedSizes(rewriter, loc,
-                                                   readOp.getSource());
-                    })
-                    .Case<vector::TransferWriteOp>([&](auto writeOp) {
-                      return memref::getMixedSizes(rewriter, loc,
-                                                   writeOp.getOperand(1));
-                    });
+
+      SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(
+          cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);
 
       Value maskOp =
-          rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
+          rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedDims);
 
       return mlir::vector::maskOperation(rewriter, opToMask, maskOp);
     };
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index d613672608c3ad..a474cc609f593e 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -300,3 +300,20 @@ vector::createUnrollIterator(VectorType vType, int64_t targetRank) {
   shapeToUnroll = shapeToUnroll.slice(0, firstScalableDim);
   return StaticTileOffsetRange(shapeToUnroll, /*unrollStep=*/1);
 }
+
+SmallVector<OpFoldResult> vector::getMixedSizesXfer(bool hasTensorSemantics,
+                                                    Operation *xfer,
+                                                    RewriterBase &rewriter) {
+  auto loc = xfer->getLoc();
+
+  Value blah = TypeSwitch<Operation *, Value>(xfer)
+                   .Case<vector::TransferReadOp>(
+                       [&](auto readOp) { return readOp.getSource(); })
+                   .Case<vector::TransferWriteOp>(
+                       [&](auto writeOp) { return writeOp.getOperand(1); });
+
+  SmallVector<OpFoldResult> mixedSourceDims =
+      hasTensorSemantics ? tensor::getMixedSizes(rewriter, loc, blah)
+                         : memref::getMixedSizes(rewriter, loc, blah);
+  return mixedSourceDims;
+}

>From 2528f8ec341e36bd420b68f3455d0fe5a906a205 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Fri, 8 Mar 2024 15:49:23 +0000
Subject: [PATCH 8/9] fixup! [mlir][linalg] Add scalable vectorisation for
 depthwise convolutions

Generalise the code a tiny bit to also cover 1D depthwise NCW convs (once
supported by the vectoriser).
---
 .../mlir/Dialect/Vector/Utils/VectorUtils.h   |  2 +-
 .../Linalg/Transforms/Vectorization.cpp       | 12 ++++++++++--
 .../Linalg/vectorization-unsupported.mlir     | 19 +++++++++++++++++++
 3 files changed, 30 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index c5e35a831334fa..3ce16ef361f37d 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -9,9 +9,9 @@
 #ifndef MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
 #define MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
 
-#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/Support/LLVM.h"
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index d44b2ea95aaa76..d5914959349774 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -3588,8 +3588,16 @@ static FailureOr<Operation *> vectorizeConvolution(
   if (!inputVecSizes.empty()) {
     // Only use the input vector size corresponding to the channel dim. Other
     // vector dims will be inferred from the Ops.
-    vecChDimSize = inputVecSizes[2];
-    vecChDimScalableFlag = inputScalableVecDims[2];
+    assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
+            isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
+           "Not a 1D depthwise conv!");
+    size_t chDimIdx =
+        TypeSwitch<Operation *, size_t>(op)
+            .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
+            .Case<linalg::DepthwiseConv1DNcwCwOp>([](auto conv) { return 1; });
+
+    vecChDimSize = inputVecSizes[chDimIdx];
+    vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
   }
   return e.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
                                flatten1DDepthwiseConv);
diff --git a/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir b/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
index 212cba2569e9ce..71ccc698318560 100644
--- a/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
@@ -19,6 +19,25 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+func.func @depthwise_conv1d_ncw_cw(%input: memref<3x5x4xf32>, %filter: memref<5x1xf32>, %output: memref<3x5x4xf32>) {
+  // expected-error @+1 {{Attempted to vectorize, but failed}}
+  linalg.depthwise_conv_1d_ncw_cw
+    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+    ins(%input, %filter : memref<3x5x4xf32>, memref<5x1xf32>)
+    outs(%output : memref<3x5x4xf32>)
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_ncw_cw"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [3, 4, 5, 1] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 func.func @depthwise_conv1d_nwc_wc_dyn_w_dim(%input: memref<3x?x4xf32>, %filter: memref<?x4xf32>, %output: memref<3x?x4xf32>) {
   // expected-error @+1 {{Attempted to vectorize, but failed}}
   linalg.depthwise_conv_1d_nwc_wc

>From 2a8ce8acfd74c2b343df7f5e5583d0e6a6963e0b Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 14 Mar 2024 10:34:20 +0000
Subject: [PATCH 9/9] fixup! [mlir][linalg] Add scalable vectorisation for
 depthwise convolutions

* Add missing dyn dimension in a test
* Make sure "flattening" + "masked vectorisation" are not allowed
---
 .../Dialect/Linalg/Transforms/Transforms.h    |  3 +-
 .../Linalg/Transforms/Vectorization.cpp       | 54 +++++++++++++------
 mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp |  6 +--
 .../Linalg/vectorization-unsupported.mlir     |  8 +--
 .../vectorize-conv-masked-and-scalable.mlir   |  2 -
 5 files changed, 49 insertions(+), 24 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index c64ecb79c5ca51..5f73c7d1bb7348 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -460,7 +460,8 @@ LogicalResult promoteSubviewsPrecondition(Operation *op,
 LogicalResult vectorizeOpPrecondition(Operation *op,
                                       ArrayRef<int64_t> inputVectorSizes = {},
                                       ArrayRef<bool> inputScalableVecDims = {},
-                                      bool vectorizeNDExtract = false);
+                                      bool vectorizeNDExtract = false,
+                                      bool flatten1DDepthwiseConv = false);
 
 //===----------------------------------------------------------------------===//
 // Transformations exposed as functional-style API calls.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index d5914959349774..00d5bb33584808 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1715,9 +1715,17 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
   return success();
 }
 
-static LogicalResult vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv) {
+static LogicalResult
+vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
+                                   bool flatten1DDepthwiseConv) {
+  if (flatten1DDepthwiseConv) {
+    LDBG("Vectorization of flattened convs with dynamic shapes is not "
+         "supported\n");
+    return failure();
+  }
+
   if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
-    LDBG("Not a depth-wise 1D conv, dynamic shapes are not supported\n");
+    LDBG("Not a 1D depth-wise WC conv, dynamic shapes are not supported\n");
     return failure();
   }
 
@@ -1735,9 +1743,11 @@ static LogicalResult vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv) {
   return success();
 }
 
-static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op) {
+static LogicalResult
+vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
+                                     bool flatten1DDepthwiseConv) {
   if (isa<ConvolutionOpInterface>(op.getOperation()))
-    return vectorizeDynamicConvOpPrecondition(op);
+    return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);
 
   // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
   // linalg.copy ops and ops that implement ContractionOpInterface for now.
@@ -1807,7 +1816,8 @@ vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
 static LogicalResult
 vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
                               ArrayRef<int64_t> inputVectorSizes,
-                              bool vectorizeNDExtract) {
+                              bool vectorizeNDExtract,
+                              bool flatten1DDepthwiseConv) {
   // tensor with dimension of 0 cannot be vectorized.
   if (llvm::is_contained(linalgOp.getStaticShape(), 0))
     return failure();
@@ -1817,8 +1827,8 @@ vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
                                       inputVectorSizes)))
     return failure();
 
-  if (linalgOp.hasDynamicShape() &&
-      failed(vectorizeDynamicLinalgOpPrecondition(linalgOp))) {
+  if (linalgOp.hasDynamicShape() && failed(vectorizeDynamicLinalgOpPrecondition(
+                                        linalgOp, flatten1DDepthwiseConv))) {
     LDBG("Dynamically-shaped op failed vectorization pre-conditions\n");
     return failure();
   }
@@ -1946,7 +1956,8 @@ vectorizeScalableVectorPrecondition(Operation *op,
 
 LogicalResult mlir::linalg::vectorizeOpPrecondition(
     Operation *op, ArrayRef<int64_t> inputVectorSizes,
-    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract) {
+    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
+    bool flatten1DDepthwiseConv) {
   if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
                                                  inputScalableVecDims)))
     return failure();
@@ -1954,7 +1965,8 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition(
   return TypeSwitch<Operation *, LogicalResult>(op)
       .Case<linalg::LinalgOp>([&](auto linalgOp) {
         return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
-                                             vectorizeNDExtract);
+                                             vectorizeNDExtract,
+                                             flatten1DDepthwiseConv);
       })
       .Case<tensor::PadOp>([&](auto padOp) {
         return vectorizePadOpPrecondition(padOp, inputVectorSizes);
@@ -2003,7 +2015,7 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
   LLVM_DEBUG(llvm::dbgs() << "\n");
 
   if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims,
-                                     vectorizeNDExtract))) {
+                                     vectorizeNDExtract, flatten1DDepthwiseConv))) {
     LDBG("Vectorization pre-conditions failed\n");
     return failure();
   }
@@ -3180,6 +3192,9 @@ struct Conv1DGenerator
       scalableChDim = channelDimScalableFlag;
       useMasking = true;
     }
+
+    assert(!(useMasking && flatten) && "Unsupported flattened conv with dynamic shapes");
+
     // out{n, w, c}
     bindShapeDims(resShapedType, nSize, wSize);
 
@@ -3282,10 +3297,15 @@ struct Conv1DGenerator
       return kw * (wSize / wSizeStep) + w;
     };
 
+    // Note - the scalable flags are ignored as flattening combined with
+    // scalable vectorization is not supported.
     auto inOutFlattenSliceSizes =
         SmallVector<int64_t>{nSize, wSizeStep * cSize};
-    auto lhsCastType = VectorType::get(inOutFlattenSliceSizes, lhsEltType);
-    auto resCastType = VectorType::get(inOutFlattenSliceSizes, resEltType);
+    auto lhsTypeAfterFlattening =
+        VectorType::get(inOutFlattenSliceSizes, lhsEltType);
+    auto resTypeAfterFlattening =
+        VectorType::get(inOutFlattenSliceSizes, resEltType);
+
     // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       for (int64_t w = 0; w < wSize; w += wSizeStep) {
@@ -3295,9 +3315,9 @@ struct Conv1DGenerator
           // Flatten the input and output vectors (collapse the channel
           // dimension)
           lhsVal = rewriter.create<vector::ShapeCastOp>(
-              loc, lhsCastType, lhsVals[linearIndex(kw, w)]);
-          resVal = rewriter.create<vector::ShapeCastOp>(loc, resCastType,
-                                                        resVals[w]);
+              loc, lhsTypeAfterFlattening, lhsVals[linearIndex(kw, w)]);
+          resVal = rewriter.create<vector::ShapeCastOp>(
+              loc, resTypeAfterFlattening, resVals[w]);
         }
         resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
                                                   rhsVals[kw], resVal, flatten);
@@ -3353,6 +3373,10 @@ struct Conv1DGenerator
     lhs = promote(rewriter, loc, lhs, resTy);
 
     if (flatten) {
+      // NOTE: The following logic won't work for scalable vectors. For this
+      // reason, "flattening" is not supported when shapes are dynamic (this
+      // should be captured by one of the pre-conditions).
+
       // There are two options for handling the filter:
       //  * shape_cast(broadcast(filter))
       //  * broadcast(shuffle(filter))
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index a474cc609f593e..63ed0947cf6ce2 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -306,14 +306,14 @@ SmallVector<OpFoldResult> vector::getMixedSizesXfer(bool hasTensorSemantics,
                                                     RewriterBase &rewriter) {
   auto loc = xfer->getLoc();
 
-  Value blah = TypeSwitch<Operation *, Value>(xfer)
+  Value base = TypeSwitch<Operation *, Value>(xfer)
                    .Case<vector::TransferReadOp>(
                        [&](auto readOp) { return readOp.getSource(); })
                    .Case<vector::TransferWriteOp>(
                        [&](auto writeOp) { return writeOp.getOperand(1); });
 
   SmallVector<OpFoldResult> mixedSourceDims =
-      hasTensorSemantics ? tensor::getMixedSizes(rewriter, loc, blah)
-                         : memref::getMixedSizes(rewriter, loc, blah);
+      hasTensorSemantics ? tensor::getMixedSizes(rewriter, loc, base)
+                         : memref::getMixedSizes(rewriter, loc, base);
   return mixedSourceDims;
 }
diff --git a/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir b/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
index 71ccc698318560..9127eac5da9510 100644
--- a/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
@@ -19,12 +19,14 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-func.func @depthwise_conv1d_ncw_cw(%input: memref<3x5x4xf32>, %filter: memref<5x1xf32>, %output: memref<3x5x4xf32>) {
+// Masked vectorisation of 1D depthwise CW convs is not yet supported
+
+func.func @depthwise_conv1d_ncw_cw(%input: memref<3x?x4xf32>, %filter: memref<?x1xf32>, %output: memref<3x?x4xf32>) {
   // expected-error @+1 {{Attempted to vectorize, but failed}}
   linalg.depthwise_conv_1d_ncw_cw
     {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
-    ins(%input, %filter : memref<3x5x4xf32>, memref<5x1xf32>)
-    outs(%output : memref<3x5x4xf32>)
+    ins(%input, %filter : memref<3x?x4xf32>, memref<?x1xf32>)
+    outs(%output : memref<3x?x4xf32>)
   return
 }
 
diff --git a/mlir/test/Dialect/Linalg/vectorize-conv-masked-and-scalable.mlir b/mlir/test/Dialect/Linalg/vectorize-conv-masked-and-scalable.mlir
index 27ca7785780db3..84b556d80887c0 100644
--- a/mlir/test/Dialect/Linalg/vectorize-conv-masked-and-scalable.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-conv-masked-and-scalable.mlir
@@ -120,8 +120,6 @@ module attributes {transform.with_named_sequence} {
 // CHECK:           %[[OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_write %[[OUT_INS]], %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] : vector<1x8x[4]xi8>, tensor<1x8x?xi8> } : vector<1x8x[4]xi1> -> tensor<1x8x?xi8>
 // CHECK:           return %[[OUT]] : tensor<1x8x?xi8>
 
-
-
 // -----
 
 func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dilation_2(



More information about the Mlir-commits mailing list