[Mlir-commits] [mlir] c56bd7a - [mlir][linalg] Enable masked vectorisation for depthwise convolutions (#81625)

llvmlistbot at llvm.org
Thu Mar 14 13:19:50 PDT 2024


Author: Andrzej WarzyƄski
Date: 2024-03-14T20:19:46Z
New Revision: c56bd7ab7934355ed19d4d3ec299b5784bf02379

URL: https://github.com/llvm/llvm-project/commit/c56bd7ab7934355ed19d4d3ec299b5784bf02379
DIFF: https://github.com/llvm/llvm-project/commit/c56bd7ab7934355ed19d4d3ec299b5784bf02379.diff

LOG: [mlir][linalg] Enable masked vectorisation for depthwise convolutions (#81625)

This patch adds support for masked vectorisation of 1D depthwise WC
convolutions (`linalg.depthwise_conv_1d_nwc_wc`) by extending the
Linalg vectoriser with masking support for these Ops.

Two major assumptions are made:
  * only the channel dimension can be dynamic/scalable (i.e. the
    trailing dim),
  * when specifying vector sizes to use in the vectoriser, only the size
    corresponding to the channel dim is effectively used (other dims are
    inferred from the context).

In terms of scalable vectorisation, this should be sufficient to cover
all practical cases (i.e. making an arbitrary dim scalable wouldn't make
much sense). As for more generic cases with dynamic shapes (e.g. W or N
dims being dynamic), more work would be needed. In particular, one would
have to consider the filter and input/output tensors separately.
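
For example (excerpted from the scalable test added below), masked
vectorisation is requested via the transform dialect, with the channel
dim marked as scalable; only the 3rd vector size (the channel dim) is
effectively used:

  %res = linalg.depthwise_conv_1d_nwc_wc
    {dilations = dense<1> : vector<1xi64>, strides = dense<1> : vector<1xi64>}
    ins(%input, %filter : tensor<1x8x?xi8>, tensor<1x?xi8>)
    outs(%output : tensor<1x8x?xi8>) -> tensor<1x8x?xi8>

  // Vector sizes are for {n, w, c, kw}; the channel dim size is scalable.
  transform.structured.vectorize %0 vector_sizes [1, 8, [4], 1] : !transform.any_op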

Added: 
    mlir/test/Dialect/Linalg/vectorize-conv-masked-and-scalable.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
    mlir/test/Dialect/Linalg/vectorization-unsupported.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index c64ecb79c5ca51..feb3b3f03cf538 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -460,7 +460,8 @@ LogicalResult promoteSubviewsPrecondition(Operation *op,
 LogicalResult vectorizeOpPrecondition(Operation *op,
                                       ArrayRef<int64_t> inputVectorSizes = {},
                                       ArrayRef<bool> inputScalableVecDims = {},
-                                      bool vectorizeNDExtract = false);
+                                      bool vectorizeNDExtract = false,
+                                      bool flatten1DDepthwiseConv = false);
 
 //===----------------------------------------------------------------------===//
 // Transformations exposed as functional-style API calls.

diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index f6b03a0f2c8007..3ce16ef361f37d 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -9,12 +9,15 @@
 #ifndef MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
 #define MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
 
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/Support/LLVM.h"
 
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 namespace mlir {
 
@@ -98,6 +101,17 @@ bool isContiguousSlice(MemRefType memrefType, VectorType vectorType);
 std::optional<StaticTileOffsetRange>
 createUnrollIterator(VectorType vType, int64_t targetRank = 1);
 
+/// A wrapper for getMixedSizes for vector.transfer_read and
+/// vector.transfer_write Ops (for source and destination, respectively).
+///
+/// Tensor and MemRef types implement their own, very similar version of
+/// getMixedSizes. This method will call the appropriate version (depending on
+/// `hasTensorSemantics`). This method will call the appropriate version
+/// and automatically extract the operand on which to call it (the source
+/// for "read" and the destination for "write" Ops).
+SmallVector<OpFoldResult> getMixedSizesXfer(bool hasTensorSemantics,
+                                            Operation *xfer,
+                                            RewriterBase &rewriter);
+
 } // namespace vector
 
 /// Constructs a permutation map of invariant memref indices to vector
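
For illustration, this is how the vectoriser changes below use the new
helper to mask a transfer Op; a minimal C++ sketch, assuming that
`rewriter`, `loc`, `maskShape`, `scalableDims` and the transfer Op
`opToMask` are in scope:

  // Fetch the (mixed static/dynamic) sizes of the xfer Op's source (for
  // "read" Ops) or destination (for "write" Ops).
  SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(
      /*hasTensorSemantics=*/true, opToMask, rewriter);

  // Create a mask sized from those dims and use it to mask the xfer Op.
  auto maskType =
      VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);
  Value mask = rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedDims);
  Operation *maskedXfer = mlir::vector::maskOperation(rewriter, opToMask, mask);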

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 1e703dacfd0c75..c74ab1e6448bec 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -24,6 +24,7 @@
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
@@ -54,6 +55,8 @@ using namespace mlir::linalg;
 /// Try to vectorize `convOp` as a convolution.
 static FailureOr<Operation *>
 vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
+                     ArrayRef<int64_t> inputVecSizes = {},
+                     ArrayRef<bool> inputVecScalableFlags = {},
                      bool flatten1DDepthwiseConv = false);
 
 /// Return the unique instance of OpType in `block` if it is indeed unique.
@@ -1712,7 +1715,40 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
   return success();
 }
 
-static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op) {
+static LogicalResult
+vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
+                                   bool flatten1DDepthwiseConv) {
+  if (flatten1DDepthwiseConv) {
+    LDBG("Vectorization of flattened convs with dynamic shapes is not "
+         "supported\n");
+    return failure();
+  }
+
+  if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
+    LDBG("Not a 1D depth-wise WC conv, dynamic shapes are not supported\n");
+    return failure();
+  }
+
+  // Support dynamic shapes in 1D depthwise convolution, but only in the
+  // _channel_ dimension.
+  Value lhs = conv.getDpsInputOperand(0)->get();
+  ArrayRef<int64_t> lhsShape = cast<ShapedType>(lhs.getType()).getShape();
+  auto shapeWithoutCh = lhsShape.drop_back(1);
+  if (ShapedType::isDynamicShape(shapeWithoutCh)) {
+    LDBG("Dynamically-shaped op vectorization precondition failed: only "
+         "channel dim can be dynamic\n");
+    return failure();
+  }
+
+  return success();
+}
+
+static LogicalResult
+vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
+                                     bool flatten1DDepthwiseConv) {
+  if (isa<ConvolutionOpInterface>(op.getOperation()))
+    return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);
+
   // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
   // linalg.copy ops and ops that implement ContractionOpInterface for now.
   if (!isElementwise(op) &&
@@ -1778,10 +1814,9 @@ vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
   return success();
 }
 
-static LogicalResult
-vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
-                              ArrayRef<int64_t> inputVectorSizes,
-                              bool vectorizeNDExtract) {
+static LogicalResult vectorizeLinalgOpPrecondition(
+    LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
+    bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
   // tensor with dimension of 0 cannot be vectorized.
   if (llvm::is_contained(linalgOp.getStaticShape(), 0))
     return failure();
@@ -1791,8 +1826,8 @@ vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
                                       inputVectorSizes)))
     return failure();
 
-  if (linalgOp.hasDynamicShape() &&
-      failed(vectorizeDynamicLinalgOpPrecondition(linalgOp))) {
+  if (linalgOp.hasDynamicShape() && failed(vectorizeDynamicLinalgOpPrecondition(
+                                        linalgOp, flatten1DDepthwiseConv))) {
     LDBG("Dynamically-shaped op failed vectorization pre-conditions\n");
     return failure();
   }
@@ -1911,14 +1946,17 @@ vectorizeScalableVectorPrecondition(Operation *op,
   if (!isScalable)
     return success();
 
-  // Only element-wise ops supported in the presence of scalable dims.
+  // Only element-wise and 1D depthwise conv Ops are supported in the
+  // presence of scalable dims.
   auto linalgOp = dyn_cast<LinalgOp>(op);
-  return success(linalgOp && isElementwise(linalgOp));
+  return success(linalgOp && (isElementwise(linalgOp) ||
+                              isa<linalg::DepthwiseConv1DNwcWcOp>(op)));
 }
 
 LogicalResult mlir::linalg::vectorizeOpPrecondition(
     Operation *op, ArrayRef<int64_t> inputVectorSizes,
-    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract) {
+    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
+    bool flatten1DDepthwiseConv) {
   if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
                                                  inputScalableVecDims)))
     return failure();
@@ -1926,7 +1964,8 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition(
   return TypeSwitch<Operation *, LogicalResult>(op)
       .Case<linalg::LinalgOp>([&](auto linalgOp) {
         return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
-                                             vectorizeNDExtract);
+                                             vectorizeNDExtract,
+                                             flatten1DDepthwiseConv);
       })
       .Case<tensor::PadOp>([&](auto padOp) {
         return vectorizePadOpPrecondition(padOp, inputVectorSizes);
@@ -1975,7 +2014,8 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
   LLVM_DEBUG(llvm::dbgs() << "\n");
 
   if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims,
-                                     vectorizeNDExtract))) {
+                                     vectorizeNDExtract,
+                                     flatten1DDepthwiseConv))) {
     LDBG("Vectorization pre-conditions failed\n");
     return failure();
   }
@@ -1999,7 +2039,8 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
             // inference.
             if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
               FailureOr<Operation *> convOr = vectorizeConvolution(
-                  rewriter, linalgOp, flatten1DDepthwiseConv);
+                  rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
+                  flatten1DDepthwiseConv);
               if (succeeded(convOr)) {
                 llvm::append_range(results, (*convOr)->getResults());
                 return success();
@@ -3131,13 +3172,30 @@ struct Conv1DGenerator
   /// kw is always unrolled.
   /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
   /// > 1.
-  FailureOr<Operation *> depthwiseConv(bool flatten) {
+  FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
+                                       bool channelDimScalableFlag,
+                                       bool flatten) {
     if (!valid)
       return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
 
+    bool scalableChDim = false;
+    bool useMasking = false;
     int64_t nSize, wSize, cSize, kwSize;
     // kernel{kw, c}
     bindShapeDims(rhsShapedType, kwSize, cSize);
+    if (ShapedType::isDynamic(cSize)) {
+      assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
+      cSize = channelDimVecSize;
+      // Scalable vectors are only used when both conditions are met:
+      //  1. channel dim is dynamic
+      //  2. channelDimScalableFlag is set
+      scalableChDim = channelDimScalableFlag;
+      useMasking = true;
+    }
+
+    assert(!(useMasking && flatten) &&
+           "Unsupported flattened conv with dynamic shapes");
+
     // out{n, w, c}
     bindShapeDims(resShapedType, nSize, wSize);
 
@@ -3158,20 +3216,51 @@ struct Conv1DGenerator
          //   (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
          ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
          cSize},
-        lhsEltType);
-    VectorType rhsType = VectorType::get({kwSize, cSize}, rhsEltType);
-    VectorType resType = VectorType::get({nSize, wSize, cSize}, resEltType);
+        lhsEltType, /*scalableDims=*/{false, false, scalableChDim});
+    VectorType rhsType =
+        VectorType::get({kwSize, cSize}, rhsEltType,
+                        /*scalableDims=*/{false, scalableChDim});
+    VectorType resType =
+        VectorType::get({nSize, wSize, cSize}, resEltType,
+                        /*scalableDims=*/{false, false, scalableChDim});
+
+    // Masks the given xfer Op along the channel dim, iff masking is required
+    // (i.e. the channel dim is dynamic).
+    auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
+                               ArrayRef<bool> scalableDims,
+                               Operation *opToMask) {
+      if (!useMasking)
+        return opToMask;
+      auto maskType =
+          VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);
+
+      SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(
+          cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);
+
+      Value maskOp =
+          rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedDims);
+
+      return mlir::vector::maskOperation(rewriter, opToMask, maskOp);
+    };
 
     // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
     // 0].
     Value lhs = rewriter.create<vector::TransferReadOp>(
         loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
+    auto maybeMaskedLhs = maybeMaskXferOp(
+        lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());
+
     // Read rhs slice of size {kw, c} @ [0, 0].
     Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
                                                         ValueRange{zero, zero});
+    auto maybeMaskedRhs = maybeMaskXferOp(
+        rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());
+
     // Read res slice of size {n, w, c} @ [0, 0, 0].
     Value res = rewriter.create<vector::TransferReadOp>(
         loc, resType, resShaped, ValueRange{zero, zero, zero});
+    auto maybeMaskedRes = maybeMaskXferOp(
+        resType.getShape(), resType.getScalableDims(), res.getDefiningOp());
 
     //===------------------------------------------------------------------===//
     // Begin vector-only rewrite part
@@ -3186,7 +3275,7 @@ struct Conv1DGenerator
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       for (int64_t w = 0; w < wSize; w += wSizeStep) {
         lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
-            loc, lhs,
+            loc, maybeMaskedLhs->getResult(0),
             /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
             inOutSliceSizes, inOutStrides));
       }
@@ -3194,12 +3283,13 @@ struct Conv1DGenerator
     // Extract rhs slice of size {c} @ [kw].
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       rhsVals.push_back(rewriter.create<vector::ExtractOp>(
-          loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
+          loc, maybeMaskedRhs->getResult(0),
+          /*offsets=*/ArrayRef<int64_t>{kw}));
     }
     // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
     for (int64_t w = 0; w < wSize; w += wSizeStep) {
       resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
-          loc, res,
+          loc, maybeMaskedRes->getResult(0),
           /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
           inOutStrides));
     }
@@ -3208,10 +3298,15 @@ struct Conv1DGenerator
       return kw * (wSize / wSizeStep) + w;
     };
 
+    // Note - the scalable flags are ignored as flattening combined with
+    // scalable vectorization is not supported.
     auto inOutFlattenSliceSizes =
         SmallVector<int64_t>{nSize, wSizeStep * cSize};
-    auto lhsCastType = VectorType::get(inOutFlattenSliceSizes, lhsEltType);
-    auto resCastType = VectorType::get(inOutFlattenSliceSizes, resEltType);
+    auto lhsTypeAfterFlattening =
+        VectorType::get(inOutFlattenSliceSizes, lhsEltType);
+    auto resTypeAfterFlattening =
+        VectorType::get(inOutFlattenSliceSizes, resEltType);
+
     // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       for (int64_t w = 0; w < wSize; w += wSizeStep) {
@@ -3221,9 +3316,9 @@ struct Conv1DGenerator
           // Flatten the input and output vectors (collapse the channel
           // dimension)
           lhsVal = rewriter.create<vector::ShapeCastOp>(
-              loc, lhsCastType, lhsVals[linearIndex(kw, w)]);
-          resVal = rewriter.create<vector::ShapeCastOp>(loc, resCastType,
-                                                        resVals[w]);
+              loc, lhsTypeAfterFlattening, lhsVals[linearIndex(kw, w)]);
+          resVal = rewriter.create<vector::ShapeCastOp>(
+              loc, resTypeAfterFlattening, resVals[w]);
         }
         resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
                                                   rhsVals[kw], resVal, flatten);
@@ -3248,8 +3343,8 @@ struct Conv1DGenerator
     // Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
     // This does not depend on kw.
     for (int64_t w = 0; w < wSize; w += wSizeStep) {
-      res = rewriter.create<vector::InsertStridedSliceOp>(
-          loc, resVals[w], res,
+      maybeMaskedRes = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, resVals[w], maybeMaskedRes->getResult(0),
           /*offsets=*/ArrayRef<int64_t>{0, w, 0},
           /*strides=*/ArrayRef<int64_t>{1, 1, 1});
     }
@@ -3258,10 +3353,11 @@ struct Conv1DGenerator
     //===------------------------------------------------------------------===//
 
     // Write back res slice of size {n, w, c} @ [0, 0, 0].
-    return rewriter
-        .create<vector::TransferWriteOp>(loc, res, resShaped,
-                                         ValueRange{zero, zero, zero})
-        .getOperation();
+    Operation *resOut = rewriter.create<vector::TransferWriteOp>(
+        loc, maybeMaskedRes->getResult(0), resShaped,
+        ValueRange{zero, zero, zero});
+    return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
+                           resOut);
   }
 
   /// Lower:
@@ -3278,6 +3374,10 @@ struct Conv1DGenerator
     lhs = promote(rewriter, loc, lhs, resTy);
 
     if (flatten) {
+      // NOTE: The following logic won't work for scalable vectors. For this
+      // reason, "flattening" is not supported when shapes are dynamic (this
+      // should be captured by one of the pre-conditions).
+
       // There are two options for handling the filter:
       //  * shape_cast(broadcast(filter))
       //  * broadcast(shuffle(filter))
@@ -3399,7 +3499,9 @@ struct Conv1DGenerator
 
   /// Entry point that transposes into the common form:
   ///   {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
-  FailureOr<Operation *> generateDilatedConv(bool flatten = false) {
+  FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
+                                             bool vecChDimScalableFlag = false,
+                                             bool flatten = false) {
     AffineExpr n, w, c, kw;
     bindDims(ctx, n, w, c, kw);
     if (!iters({Par(), Par(), Par(), Red()}))
@@ -3410,7 +3512,7 @@ struct Conv1DGenerator
     if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                 /*rhsIndex*/ {kw, c},
                 /*resIndex*/ {n, w, c}}))
-      return depthwiseConv(flatten);
+      return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);
 
     return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
   }
@@ -3475,9 +3577,9 @@ struct Conv1DGenerator
 
 /// Helper function to vectorize a LinalgOp with convolution semantics.
 // TODO: extend the generic vectorization to support windows and drop this.
-static FailureOr<Operation *>
-vectorizeConvolution(RewriterBase &rewriter, LinalgOp op,
-                     bool flatten1DDepthwiseConv) {
+static FailureOr<Operation *> vectorizeConvolution(
+    RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
+    ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
   // The ConvolutionOpInterface gives us guarantees of existence for
   // strides/dilations. However, we do not need to rely on those, we can
   // simply use them if present, otherwise use the default and let the generic
@@ -3502,7 +3604,28 @@ vectorizeConvolution(RewriterBase &rewriter, LinalgOp op,
   res = e.generateNcwPooling();
   if (succeeded(res))
     return res;
-  return e.generateDilatedConv(flatten1DDepthwiseConv);
+
+  // Only depthwise 1D NWC convs are left - these can be vectorized using masks
+  // and scalable vectors. Note that ATM the only dim that can be dynamic (i.e.
+  // masked/scalable) is the channel dim (i.e. the trailing dim).
+  uint64_t vecChDimSize = ShapedType::kDynamic;
+  bool vecChDimScalableFlag = false;
+  if (!inputVecSizes.empty()) {
+    // Only use the input vector size corresponding to the channel dim. Other
+    // vector dims will be inferred from the Ops.
+    assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
+            isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
+           "Not a 1D depthwise conv!");
+    size_t chDimIdx =
+        TypeSwitch<Operation *, size_t>(op)
+            .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
+            .Case<linalg::DepthwiseConv1DNcwCwOp>([](auto conv) { return 1; });
+
+    vecChDimSize = inputVecSizes[chDimIdx];
+    vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
+  }
+  return e.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
+                               flatten1DDepthwiseConv);
 }
 
 struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
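
Schematically, when the channel dim is dynamic, each transfer Op now
ends up wrapped in `vector.mask`, with the mask sized from the runtime
channel dim; a simplified excerpt of the IR checked in the new test
below:

  %ch = tensor.dim %input, %c2 : tensor<1x8x?xi8>
  %mask = vector.create_mask %c1, %c8, %ch : vector<1x8x[4]xi1>
  %vec = vector.mask %mask {
    vector.transfer_read %input[%c0, %c0, %c0], %pad
        : tensor<1x8x?xi8>, vector<1x8x[4]xi8>
  } : vector<1x8x[4]xi1> -> vector<1x8x[4]xi8>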

diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index d613672608c3ad..63ed0947cf6ce2 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -300,3 +300,20 @@ vector::createUnrollIterator(VectorType vType, int64_t targetRank) {
   shapeToUnroll = shapeToUnroll.slice(0, firstScalableDim);
   return StaticTileOffsetRange(shapeToUnroll, /*unrollStep=*/1);
 }
+
+SmallVector<OpFoldResult> vector::getMixedSizesXfer(bool hasTensorSemantics,
+                                                    Operation *xfer,
+                                                    RewriterBase &rewriter) {
+  auto loc = xfer->getLoc();
+
+  Value base = TypeSwitch<Operation *, Value>(xfer)
+                   .Case<vector::TransferReadOp>(
+                       [&](auto readOp) { return readOp.getSource(); })
+                   .Case<vector::TransferWriteOp>(
+                       [&](auto writeOp) { return writeOp.getOperand(1); });
+
+  SmallVector<OpFoldResult> mixedSourceDims =
+      hasTensorSemantics ? tensor::getMixedSizes(rewriter, loc, base)
+                         : memref::getMixedSizes(rewriter, loc, base);
+  return mixedSourceDims;
+}
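
As a hypothetical illustration (not part of this patch): for a
`vector.transfer_read` from a `memref<3x5x?xf32>`, the helper returns
one `OpFoldResult` per dim: attributes for the static dims and, for the
dynamic dim, the result of a `memref.dim` Op created on the fly:

  // Materialised by memref::getMixedSizes for the dynamic trailing dim:
  %c2 = arith.constant 2 : index
  %d = memref.dim %base, %c2 : memref<3x5x?xf32>
  // mixedSourceDims == {3, 5, %d}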

diff --git a/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir b/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
index a1a52397386c97..9127eac5da9510 100644
--- a/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
@@ -19,10 +19,49 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-func.func @depthwise_conv1d_nwc_wc_dyn_ch_dim(%input: memref<3x5x?xf32>, %filter: memref<2x?xf32>, %output: memref<3x2x?xf32>) {
+// Masked vectorisation of 1D depthwise CW convs is not yet supported
+
+func.func @depthwise_conv1d_ncw_cw(%input: memref<3x?x4xf32>, %filter: memref<?x1xf32>, %output: memref<3x?x4xf32>) {
+  // expected-error @+1 {{Attempted to vectorize, but failed}}
+  linalg.depthwise_conv_1d_ncw_cw
+    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+    ins(%input, %filter : memref<3x?x4xf32>, memref<?x1xf32>)
+    outs(%output : memref<3x?x4xf32>)
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_ncw_cw"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [3, 4, 5, 1] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @depthwise_conv1d_nwc_wc_dyn_w_dim(%input: memref<3x?x4xf32>, %filter: memref<?x4xf32>, %output: memref<3x?x4xf32>) {
   // expected-error @+1 {{Attempted to vectorize, but failed}}
   linalg.depthwise_conv_1d_nwc_wc
     {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+    ins(%input, %filter : memref<3x?x4xf32>, memref<?x4xf32>)
+    outs(%output : memref<3x?x4xf32>)
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [3, 2, 4, 2] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @depthwise_conv1d_nwc_wc_dyn_ch_dim(%input: memref<3x5x?xf32>, %filter: memref<2x?xf32>, %output: memref<3x2x?xf32>) {
+  // expected-error @+1 {{Attempted to vectorize, but failed}}
+  linalg.depthwise_conv_1d_nwc_wc
     ins(%input, %filter : memref<3x5x?xf32>, memref<2x?xf32>)
     outs(%output : memref<3x2x?xf32>)
   return
@@ -41,7 +80,6 @@ module attributes {transform.with_named_sequence} {
 func.func @depthwise_conv1d_nwc_wc_dyn_w_dim(%input: memref<3x?x3xf32>, %filter: memref<2x3xf32>, %output: memref<3x?x3xf32>) {
   // expected-error @+1 {{Attempted to vectorize, but failed}}
   linalg.depthwise_conv_1d_nwc_wc
-    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
     ins(%input, %filter : memref<3x?x3xf32>, memref<2x3xf32>)
     outs(%output : memref<3x?x3xf32>)
   return

diff --git a/mlir/test/Dialect/Linalg/vectorize-conv-masked-and-scalable.mlir b/mlir/test/Dialect/Linalg/vectorize-conv-masked-and-scalable.mlir
new file mode 100644
index 00000000000000..84b556d80887c0
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/vectorize-conv-masked-and-scalable.mlir
@@ -0,0 +1,185 @@
+// RUN: mlir-opt -split-input-file -transform-interpreter -cse %s | FileCheck %s
+
+func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_tensor(%input: tensor<1x8x?xi8>,
+                                                   %filter: tensor<1x?xi8>,
+                                                   %output: tensor<1x8x?xi8>) -> (tensor<1x8x?xi8>) {
+  %res = linalg.depthwise_conv_1d_nwc_wc
+    {dilations = dense<1> : vector<1xi64>,
+    strides = dense<1> : vector<1xi64>}
+    ins(%input, %filter : tensor<1x8x?xi8>, tensor<1x?xi8>)
+    outs(%output : tensor<1x8x?xi8>) -> tensor<1x8x?xi8>
+  return %res : tensor<1x8x?xi8>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [1, 8, 4, 1] : !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL:   func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_tensor(
+// CHECK-SAME:      %[[INPUT:.*]]: tensor<1x8x?xi8>,
+// CHECK-SAME:      %[[FILTER:.*]]: tensor<1x?xi8>,
+// CHECK-SAME:      %[[OUTPUT:.*]]: tensor<1x8x?xi8>) -> tensor<1x8x?xi8> {
+
+// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[PAD:.*]] = arith.constant 0 : i8
+
+/// Create a mask for the input tensor
+// CHECK:           %[[C2:.*]] = arith.constant 2 : index
+// CHECK:           %[[CH_DIM_IN:.*]] = tensor.dim %[[INPUT]], %[[C2]] : tensor<1x8x?xi8>
+// CHECK:           %[[C8:.*]] = arith.constant 8 : index
+// CHECK:           %[[MASK_IN:.*]] = vector.create_mask %[[C1]], %[[C8]], %[[CH_DIM_IN]] : vector<1x8x4xi1>
+/// Read the input tensor
+// CHECK:           %[[VEC_IN:.*]] = vector.mask %[[MASK_IN]] { vector.transfer_read %[[INPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] : tensor<1x8x?xi8>, vector<1x8x4xi8> } : vector<1x8x4xi1> -> vector<1x8x4xi8>
+
+/// Create a mask for the filter tensor
+// CHECK:           %[[CH_DIM_FLT:.*]] = tensor.dim %[[FILTER]], %[[C1]] : tensor<1x?xi8>
+// CHECK:           %[[MASK_FLT:.*]] = vector.create_mask %[[C1]], %[[CH_DIM_FLT]] : vector<1x4xi1>
+/// Read the filter tensor
+// CHECK:           %[[VEC_FLT:.*]] = vector.mask %[[MASK_FLT]] { vector.transfer_read %[[FILTER]]{{\[}}%[[C0]], %[[C0]]], %[[PAD]] : tensor<1x?xi8>, vector<1x4xi8> } : vector<1x4xi1> -> vector<1x4xi8>
+
+/// Create a mask for the output tensor
+// CHECK:           %[[CH_DIM_OUT:.*]] = tensor.dim %[[OUTPUT]], %[[C2]] : tensor<1x8x?xi8>
+// CHECK:           %[[MASK_OUT:.*]] = vector.create_mask %[[C1]], %[[C8]], %[[CH_DIM_OUT]] : vector<1x8x4xi1>
+// CHECK:           %[[VEC_OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_read %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] : tensor<1x8x?xi8>, vector<1x8x4xi8> } : vector<1x8x4xi1> -> vector<1x8x4xi8>
+
+/// Convolution
+// CHECK:           %[[IN_1:.*]] = vector.extract_strided_slice %[[VEC_IN]] {offsets = [0, 0, 0], sizes = [1, 8, 4], strides = [1, 1, 1]} : vector<1x8x4xi8> to vector<1x8x4xi8>
+// CHECK:           %[[FLT_1:.*]] = vector.extract %[[VEC_FLT]][0] : vector<4xi8> from vector<1x4xi8>
+// CHECK:           %[[OUT_1:.*]] = vector.extract_strided_slice %[[VEC_OUT]] {offsets = [0, 0, 0], sizes = [1, 8, 4], strides = [1, 1, 1]} : vector<1x8x4xi8> to vector<1x8x4xi8>
+// CHECK:           %[[FLT_1_B:.*]] = vector.broadcast %[[FLT_1]] : vector<4xi8> to vector<1x8x4xi8>
+// CHECK:           %[[MULI:.*]] = arith.muli %[[IN_1]], %[[FLT_1_B]] : vector<1x8x4xi8>
+// CHECK:           %[[ADDI:.*]] = arith.addi %[[MULI]], %[[OUT_1]] : vector<1x8x4xi8>
+// CHECK:           %[[OUT_INS:.*]] = vector.insert_strided_slice %[[ADDI]], %[[VEC_OUT]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<1x8x4xi8> into vector<1x8x4xi8>
+// CHECK:           %[[OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_write %[[OUT_INS]], %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] : vector<1x8x4xi8>, tensor<1x8x?xi8> } : vector<1x8x4xi1> -> tensor<1x8x?xi8>
+// CHECK:           return %[[OUT]] : tensor<1x8x?xi8>
+
+// -----
+
+func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_tensor_scalable(
+      %input: tensor<1x8x?xi8>,
+      %filter: tensor<1x?xi8>,
+      %output: tensor<1x8x?xi8>) -> (tensor<1x8x?xi8>) {
+  %res = linalg.depthwise_conv_1d_nwc_wc
+    {dilations = dense<1> : vector<1xi64>,
+    strides = dense<1> : vector<1xi64>}
+    ins(%input, %filter : tensor<1x8x?xi8>, tensor<1x?xi8>)
+    outs(%output : tensor<1x8x?xi8>) -> tensor<1x8x?xi8>
+  return %res : tensor<1x8x?xi8>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [1, 8, [4], 1] : !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL:   func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_tensor_scalable(
+// CHECK-SAME:      %[[INPUT:.*]]: tensor<1x8x?xi8>,
+// CHECK-SAME:      %[[FILTER:.*]]: tensor<1x?xi8>,
+// CHECK-SAME:      %[[OUTPUT:.*]]: tensor<1x8x?xi8>) -> tensor<1x8x?xi8> {
+
+// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[PAD:.*]] = arith.constant 0 : i8
+
+/// Create a mask for the input tensor
+// CHECK:           %[[C2:.*]] = arith.constant 2 : index
+// CHECK:           %[[CH_DIM_IN:.*]] = tensor.dim %[[INPUT]], %[[C2]] : tensor<1x8x?xi8>
+// CHECK:           %[[C8:.*]] = arith.constant 8 : index
+// CHECK:           %[[MASK_IN:.*]] = vector.create_mask %[[C1]], %[[C8]], %[[CH_DIM_IN]] : vector<1x8x[4]xi1>
+/// Read the input tensor
+// CHECK:           %[[VEC_IN:.*]] = vector.mask %[[MASK_IN]] { vector.transfer_read %[[INPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] : tensor<1x8x?xi8>, vector<1x8x[4]xi8> } : vector<1x8x[4]xi1> -> vector<1x8x[4]xi8>
+
+/// Create a mask for the filter tensor
+// CHECK:           %[[CH_DIM_FLT:.*]] = tensor.dim %[[FILTER]], %[[C1]] : tensor<1x?xi8>
+// CHECK:           %[[MASK_FLT:.*]] = vector.create_mask %[[C1]], %[[CH_DIM_FLT]] : vector<1x[4]xi1>
+/// Read the filter tensor
+// CHECK:           %[[VEC_FLT:.*]] = vector.mask %[[MASK_FLT]] { vector.transfer_read %[[FILTER]]{{\[}}%[[C0]], %[[C0]]], %[[PAD]] : tensor<1x?xi8>, vector<1x[4]xi8> } : vector<1x[4]xi1> -> vector<1x[4]xi8>
+
+/// Create a mask for the output tensor
+// CHECK:           %[[CH_DIM_OUT:.*]] = tensor.dim %[[OUTPUT]], %[[C2]] : tensor<1x8x?xi8>
+// CHECK:           %[[MASK_OUT:.*]] = vector.create_mask %[[C1]], %[[C8]], %[[CH_DIM_OUT]] : vector<1x8x[4]xi1>
+/// Read the output tensor
+// CHECK:           %[[VEC_OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_read %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] : tensor<1x8x?xi8>, vector<1x8x[4]xi8> } : vector<1x8x[4]xi1> -> vector<1x8x[4]xi8>
+
+/// Convolution
+// CHECK:           %[[IN_1:.*]] = vector.extract_strided_slice %[[VEC_IN]] {offsets = [0, 0, 0], sizes = [1, 8, 4], strides = [1, 1, 1]} : vector<1x8x[4]xi8> to vector<1x8x[4]xi8>
+// CHECK:           %[[FLT_1:.*]] = vector.extract %[[VEC_FLT]][0] : vector<[4]xi8> from vector<1x[4]xi8>
+// CHECK:           %[[OUT_1:.*]] = vector.extract_strided_slice %[[VEC_OUT]] {offsets = [0, 0, 0], sizes = [1, 8, 4], strides = [1, 1, 1]} : vector<1x8x[4]xi8> to vector<1x8x[4]xi8>
+// CHECK:           %[[FLT_1_B:.*]] = vector.broadcast %[[FLT_1]] : vector<[4]xi8> to vector<1x8x[4]xi8>
+// CHECK:           %[[MULI:.*]] = arith.muli %[[IN_1]], %[[FLT_1_B]] : vector<1x8x[4]xi8>
+// CHECK:           %[[ADDI:.*]] = arith.addi %[[MULI]], %[[OUT_1]] : vector<1x8x[4]xi8>
+// CHECK:           %[[OUT_INS:.*]] = vector.insert_strided_slice %[[ADDI]], %[[VEC_OUT]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<1x8x[4]xi8> into vector<1x8x[4]xi8>
+// CHECK:           %[[OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_write %[[OUT_INS]], %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] : vector<1x8x[4]xi8>, tensor<1x8x?xi8> } : vector<1x8x[4]xi1> -> tensor<1x8x?xi8>
+// CHECK:           return %[[OUT]] : tensor<1x8x?xi8>
+
+// -----
+
+func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dilation_2(
+      %input: memref<3x5x?xf32>,
+      %filter: memref<2x?xf32>,
+      %output: memref<3x2x?xf32>) {
+  linalg.depthwise_conv_1d_nwc_wc
+    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+    ins(%input, %filter : memref<3x5x?xf32>, memref<2x?xf32>)
+    outs(%output : memref<3x2x?xf32>)
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [3, 2, [4], 2] : !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL:   func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dilation_2(
+// CHECK-SAME:      %[[INPUT:.*]]: memref<3x5x?xf32>,
+// CHECK-SAME:      %[[FILTER:.*]]: memref<2x?xf32>,
+// CHECK-SAME:      %[[OUTPUT:.*]]: memref<3x2x?xf32>) {
+
+// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[C2:.*]] = arith.constant 2 : index
+
+/// Create a mask for the input tensor
+// CHECK:           %[[CH_DIM_IN:.*]] = memref.dim %[[INPUT]], %[[C2]] : memref<3x5x?xf32>
+// CHECK:           %[[C3:.*]] = arith.constant 3 : index
+// CHECK:           %[[C5:.*]] = arith.constant 5 : index
+// CHECK:           %[[MASK_IN:.*]] = vector.create_mask %[[C3]], %[[C5]], %[[CH_DIM_IN]] : vector<3x4x[4]xi1>
+/// Read the input tensor
+// CHECK:           %[[VEC_IN:.*]] = vector.mask %[[MASK_IN]] { vector.transfer_read %[[INPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] : memref<3x5x?xf32>, vector<3x4x[4]xf32> } : vector<3x4x[4]xi1> -> vector<3x4x[4]xf32>
+
+/// Create a mask for the filter tensor
+// CHECK:           %[[CH_DIM_FLT:.*]] = memref.dim %[[FILTER]], %[[C1]] : memref<2x?xf32>
+// CHECK:           %[[MASK_FLT:.*]] = vector.create_mask %[[C2]], %[[CH_DIM_FLT]] : vector<2x[4]xi1>
+/// Read the filter tensor
+// CHECK:           %[[VEC_FLT:.*]] = vector.mask %[[MASK_FLT]] { vector.transfer_read %[[FILTER]]{{\[}}%[[C0]], %[[C0]]], %[[PAD]] : memref<2x?xf32>, vector<2x[4]xf32> } : vector<2x[4]xi1> -> vector<2x[4]xf32>
+
+/// Create a mask for the output tensor
+// CHECK:           %[[CH_DIM_OUT:.*]] = memref.dim %[[OUTPUT]], %[[C2]] : memref<3x2x?xf32>
+// CHECK:           %[[MASK_OUT:.*]] = vector.create_mask %[[C3]], %[[C2]], %[[CH_DIM_OUT]] : vector<3x2x[4]xi1>
+/// Read the output tensor
+// CHECK:           %[[VEC_OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_read %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] : memref<3x2x?xf32>, vector<3x2x[4]xf32> } : vector<3x2x[4]xi1> -> vector<3x2x[4]xf32>
+
+/// Convolution
+// CHECK:           %[[IN_1:.*]] = vector.extract_strided_slice %[[VEC_IN]] {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x[4]xf32> to vector<3x2x[4]xf32>
+// CHECK:           %[[IN_2:.*]] = vector.extract_strided_slice %[[VEC_IN]] {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x[4]xf32> to vector<3x2x[4]xf32>
+// CHECK:           %[[FLT_1:.*]] = vector.extract %[[VEC_FLT]][0] : vector<[4]xf32> from vector<2x[4]xf32>
+// CHECK:           %[[FLT_2:.*]] = vector.extract %[[VEC_FLT]][1] : vector<[4]xf32> from vector<2x[4]xf32>
+// CHECK:           %[[OUT_1:.*]] = vector.extract_strided_slice %[[VEC_OUT]] {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x2x[4]xf32> to vector<3x2x[4]xf32>
+// CHECK:           %[[FLT_1_B:.*]] = vector.broadcast %[[FLT_1]] : vector<[4]xf32> to vector<3x2x[4]xf32>
+// CHECK:           %[[FMA_1:.*]] = vector.fma %[[IN_1]], %[[FLT_1_B]], %[[OUT_1]] : vector<3x2x[4]xf32>
+// CHECK:           %[[FLT_2_B:.*]] = vector.broadcast %[[FLT_2]] : vector<[4]xf32> to vector<3x2x[4]xf32>
+// CHECK:           %[[FMA_2:.*]] = vector.fma %[[IN_2]], %[[FLT_2_B]], %[[FMA_1]] : vector<3x2x[4]xf32>
+// CHECK:           %[[OUT_INS:.*]] = vector.insert_strided_slice %[[FMA_2]], %[[VEC_OUT]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<3x2x[4]xf32> into vector<3x2x[4]xf32>
+// CHECK:           vector.mask %[[MASK_OUT]] { vector.transfer_write %[[OUT_INS]], %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] : vector<3x2x[4]xf32>, memref<3x2x?xf32> } : vector<3x2x[4]xi1>
